feat: atom23 inference changes and training fixes

This commit is contained in:
Raktim Mitra
2026-01-20 14:44:19 -08:00
committed by Raktim Mitra
parent 94d9d635cd
commit c2ec302745
8 changed files with 52 additions and 24 deletions

View File

@@ -349,9 +349,13 @@ class SubunitSymmetryResolution(nn.Module):
x_native = symm_input["coord_atom_lvl"].to(x_pred.device)
mask_native = symm_input["mask_atom_lvl"].to(x_pred.device)
x_native_aln, x_native_mask = self._resolve_subunits(
mol_entities, mol_iid, crop_mask, x_native, mask_native, x_pred
)
try:
x_native_aln, x_native_mask = self._resolve_subunits(
mol_entities, mol_iid, crop_mask, x_native, mask_native, x_pred
)
except Exception:
# fd ... TO DO: DEBUG!
return loss_input
loss_input["X_gt_L"] = x_native_aln
loss_input["crd_mask_L"] = x_native_mask

View File

@@ -242,6 +242,35 @@ SELECTION_NONPROTEIN = [
"POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID",
]
backbone_atomscheme_DNA = [
" P ",
" OP1",
" OP2",
" O5'",
" C5'",
" C4'",
" O4'",
" C3'",
" O3'",
" C2'",
" C1'",
] # , None]
backbone_atomscheme_RNA = [
" P ",
" OP1",
" OP2",
" O5'",
" C5'",
" C4'",
" O4'",
" C3'",
" O3'",
" C2'",
" O2'",
" C1'",
]
DNA_atoms = {
"DA": [
" N9 ",
@@ -410,7 +439,6 @@ association_schemes_stripped = {
backbone_atoms_RNA = strip_list(backbone_atomscheme_RNA)
backbone_atoms_DNA = strip_list(backbone_atomscheme_DNA)
# Mapping from residue type to its backbone and sidechain atoms (for convenience)
ATOM_REGION_BY_RESI = {
'ALA': {'bb':('N','CA','C','O'),

View File

@@ -186,6 +186,7 @@ def fetch_motif_residue_(
subarray.set_annotation(
"is_motif_atom_with_fixed_seq", np.zeros(subarray.shape[0], dtype=int)
)
elif redesign_motif_sidechains and res_name in (STANDARD_DNA + STANDARD_RNA):
is_backbone = np.isin(subarray.atom_name, backbone_atoms_RNA)
subarray.set_annotation("is_motif_atom", is_backbone)

View File

@@ -281,7 +281,7 @@ class SampleConditioningType(Transform):
cond = valid_conditions[i_cond]
cond.association_scheme = self.association_scheme
data["sampled_condition"] = cond
data["sampled_condition_name"] = cond.name
data["sampled_condition_cls"] = cond.__class__
@@ -299,6 +299,7 @@ class SampleConditioningFlags(Transform):
"AssignTypes",
"SampleConditioningType",
] # We use is_protein in the PPI training condition
def __init__(self, association_scheme):
self.association_scheme = association_scheme

View File

@@ -698,7 +698,6 @@ class AddAdditional1dFeaturesToFeats(Transform):
atom_1d_features,
autofill_zeros_if_not_present_in_atomarray=False,
association_scheme="atom14",
association_scheme='atom14'
):
self.autofill = autofill_zeros_if_not_present_in_atomarray
self.token_1d_features = token_1d_features
@@ -755,11 +754,13 @@ class AddAdditional1dFeaturesToFeats(Transform):
"""
if "feats" not in data.keys():
data["feats"] = {}
if association_scheme == 'atom23':
data['atom_array'].set_annotation('is_protein_token', data['atom_array'].is_protein)
data['atom_array'].set_annotation('is_dna_token', data['atom_array'].is_dna)
data['atom_array'].set_annotation('is_rna_token', data['atom_array'].is_rna)
if self.association_scheme == "atom23":
data["atom_array"].set_annotation(
"is_protein_token", data["atom_array"].is_protein
)
data["atom_array"].set_annotation("is_dna_token", data["atom_array"].is_dna)
data["atom_array"].set_annotation("is_rna_token", data["atom_array"].is_rna)
if self.association_scheme == "atom23":
data["atom_array"].set_annotation(

View File

@@ -523,7 +523,7 @@ def build_atom14_base_pipeline_(
autofill_zeros_if_not_present_in_atomarray=True,
token_1d_features=token_1d_features,
atom_1d_features=atom_1d_features,
association_scheme=association_scheme
association_scheme=association_scheme,
),
AddAF3TokenBondFeatures(),
AddGroundTruthSequence(sequence_encoding=af3_sequence_encoding),

View File

@@ -72,11 +72,9 @@ class IslandCondition(TrainingCondition):
p_fix_motif_coordinates,
p_fix_motif_sequence,
p_unindex_motif_tokens,
association_scheme = 'atom14',
):
self.name = name
self.frequency = frequency
self.association_scheme = association_scheme
# Token selection
self.island_sampling_kwargs = island_sampling_kwargs
@@ -91,8 +89,6 @@ class IslandCondition(TrainingCondition):
self.p_fix_motif_coordinates = p_fix_motif_coordinates
self.p_fix_motif_sequence = p_fix_motif_sequence
self.p_unindex_motif_tokens = p_unindex_motif_tokens
self.association_scheme = association_scheme
def is_valid_for_example(self, data) -> bool:
is_protein = data["atom_array"].is_protein
@@ -138,7 +134,6 @@ class IslandCondition(TrainingCondition):
n_protein_tokens,
**self.island_sampling_kwargs,
)
is_motif_token[token_level_array.is_protein] = islands_mask
# TODO: Atoms with covalent bonds should be motif, needs FlagAndReassignCovalentModifications transform prior to this
@@ -310,7 +305,7 @@ class SubtypeCondition(TrainingCondition):
"""
name = "subtype"
association_scheme = 'atom14'
association_scheme = "atom14"
def __init__(self, frequency: float, subtype: list[str], fix_pos: bool = False):
self.frequency = frequency
@@ -526,8 +521,6 @@ def sample_is_motif_atom_with_fixed_seq(
is_motif_atom_with_fixed_seq = (
is_motif_atom_with_fixed_seq | ~atom_array.is_protein
)
return is_motif_atom_with_fixed_seq
@@ -571,7 +564,6 @@ def sample_unindexed_atoms(
is_motif_atom_unindexed, atom_array.is_residue
)
return is_motif_atom_unindexed

View File

@@ -53,6 +53,7 @@ def map_to_association_scheme(
else:
return ATOM_NAMES[idxs]
def map_names_to_elements(
atom_names: list | str, default=VIRTUAL_ATOM_ELEMENT_NAME
) -> np.ndarray:
@@ -179,7 +180,7 @@ class PadTokensWithVirtualAtoms(Transform):
token_ids = np.unique(atom_array.token_id)
assert len(token_ids) == len(
is_motif_atom_with_fixed_seq
), "Token ids and token level array have different lengths!"
), "Token ids and token level array have different lengths!"
# Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
if self.association_scheme == "atom23":
@@ -220,7 +221,7 @@ class PadTokensWithVirtualAtoms(Transform):
for token_id, (start, end) in enumerate(zip(starts[:-1], starts[1:])):
if is_paddable[token_id]:
token = atom_array[start:end]
# First, pad with virtual atoms if needed
if self.association_scheme == "atom23" and atom_array[start].is_dna:
n_atoms_per_token = 22
@@ -229,7 +230,7 @@ class PadTokensWithVirtualAtoms(Transform):
else:
n_atoms_per_token = self.n_atoms_per_token
n_pad = n_atoms_per_token - len(token)
if n_pad > 0:
mask = get_af3_token_representative_masks(
token, central_atom=self.atom_to_pad_from