From c2ec30274579695b3c92c888eca8c409ad91bba3 Mon Sep 17 00:00:00 2001 From: Raktim Mitra Date: Tue, 20 Jan 2026 14:44:19 -0800 Subject: [PATCH] feat: atom23 inference changes and training fixes --- models/rf3/src/rf3/loss/af3_losses.py | 10 +++++-- models/rfd3/src/rfd3/constants.py | 30 ++++++++++++++++++- .../rfd3/inference/legacy_input_parsing.py | 1 + .../src/rfd3/transforms/conditioning_base.py | 3 +- .../src/rfd3/transforms/design_transforms.py | 13 ++++---- models/rfd3/src/rfd3/transforms/pipelines.py | 2 +- .../rfd3/transforms/training_conditions.py | 10 +------ .../rfd3/src/rfd3/transforms/virtual_atoms.py | 7 +++-- 8 files changed, 52 insertions(+), 24 deletions(-) diff --git a/models/rf3/src/rf3/loss/af3_losses.py b/models/rf3/src/rf3/loss/af3_losses.py index 974db26..1bb0cd2 100644 --- a/models/rf3/src/rf3/loss/af3_losses.py +++ b/models/rf3/src/rf3/loss/af3_losses.py @@ -349,9 +349,13 @@ class SubunitSymmetryResolution(nn.Module): x_native = symm_input["coord_atom_lvl"].to(x_pred.device) mask_native = symm_input["mask_atom_lvl"].to(x_pred.device) - x_native_aln, x_native_mask = self._resolve_subunits( - mol_entities, mol_iid, crop_mask, x_native, mask_native, x_pred - ) + try: + x_native_aln, x_native_mask = self._resolve_subunits( + mol_entities, mol_iid, crop_mask, x_native, mask_native, x_pred + ) + except Exception: + # fd ... TO DO: DEBUG! + return loss_input loss_input["X_gt_L"] = x_native_aln loss_input["crd_mask_L"] = x_native_mask diff --git a/models/rfd3/src/rfd3/constants.py b/models/rfd3/src/rfd3/constants.py index 78d4892..4d76262 100644 --- a/models/rfd3/src/rfd3/constants.py +++ b/models/rfd3/src/rfd3/constants.py @@ -242,6 +242,35 @@ SELECTION_NONPROTEIN = [ "POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID", ] +backbone_atomscheme_DNA = [ + " P ", + " OP1", + " OP2", + " O5'", + " C5'", + " C4'", + " O4'", + " C3'", + " O3'", + " C2'", + " C1'", +] # , None] + +backbone_atomscheme_RNA = [ + " P ", + " OP1", + " OP2", + " O5'", + " C5'", + " C4'", + " O4'", + " C3'", + " O3'", + " C2'", + " O2'", + " C1'", +] + DNA_atoms = { "DA": [ " N9 ", @@ -410,7 +439,6 @@ association_schemes_stripped = { backbone_atoms_RNA = strip_list(backbone_atomscheme_RNA) backbone_atoms_DNA = strip_list(backbone_atomscheme_DNA) - # Mapping from residue type to its backbone and sidechain atoms (for convenience) ATOM_REGION_BY_RESI = { 'ALA': {'bb':('N','CA','C','O'), diff --git a/models/rfd3/src/rfd3/inference/legacy_input_parsing.py b/models/rfd3/src/rfd3/inference/legacy_input_parsing.py index 401cd2a..5f77bdf 100644 --- a/models/rfd3/src/rfd3/inference/legacy_input_parsing.py +++ b/models/rfd3/src/rfd3/inference/legacy_input_parsing.py @@ -186,6 +186,7 @@ def fetch_motif_residue_( subarray.set_annotation( "is_motif_atom_with_fixed_seq", np.zeros(subarray.shape[0], dtype=int) ) + elif redesign_motif_sidechains and res_name in (STANDARD_DNA + STANDARD_RNA): is_backbone = np.isin(subarray.atom_name, backbone_atoms_RNA) subarray.set_annotation("is_motif_atom", is_backbone) diff --git a/models/rfd3/src/rfd3/transforms/conditioning_base.py b/models/rfd3/src/rfd3/transforms/conditioning_base.py index aec455c..25b51a7 100644 --- a/models/rfd3/src/rfd3/transforms/conditioning_base.py +++ b/models/rfd3/src/rfd3/transforms/conditioning_base.py @@ -281,7 +281,7 @@ class SampleConditioningType(Transform): cond = valid_conditions[i_cond] cond.association_scheme = self.association_scheme - + data["sampled_condition"] = cond data["sampled_condition_name"] = cond.name data["sampled_condition_cls"] = cond.__class__ @@ -299,6 +299,7 @@ class SampleConditioningFlags(Transform): "AssignTypes", "SampleConditioningType", ] # We use is_protein in the PPI training condition + def __init__(self, association_scheme): self.association_scheme = association_scheme diff --git a/models/rfd3/src/rfd3/transforms/design_transforms.py b/models/rfd3/src/rfd3/transforms/design_transforms.py index e853322..ac17ec2 100644 --- a/models/rfd3/src/rfd3/transforms/design_transforms.py +++ b/models/rfd3/src/rfd3/transforms/design_transforms.py @@ -698,7 +698,6 @@ class AddAdditional1dFeaturesToFeats(Transform): atom_1d_features, autofill_zeros_if_not_present_in_atomarray=False, association_scheme="atom14", - association_scheme='atom14' ): self.autofill = autofill_zeros_if_not_present_in_atomarray self.token_1d_features = token_1d_features @@ -755,11 +754,13 @@ class AddAdditional1dFeaturesToFeats(Transform): """ if "feats" not in data.keys(): data["feats"] = {} - - if association_scheme == 'atom23': - data['atom_array'].set_annotation('is_protein_token', data['atom_array'].is_protein) - data['atom_array'].set_annotation('is_dna_token', data['atom_array'].is_dna) - data['atom_array'].set_annotation('is_rna_token', data['atom_array'].is_rna) + + if self.association_scheme == "atom23": + data["atom_array"].set_annotation( + "is_protein_token", data["atom_array"].is_protein + ) + data["atom_array"].set_annotation("is_dna_token", data["atom_array"].is_dna) + data["atom_array"].set_annotation("is_rna_token", data["atom_array"].is_rna) if self.association_scheme == "atom23": data["atom_array"].set_annotation( diff --git a/models/rfd3/src/rfd3/transforms/pipelines.py b/models/rfd3/src/rfd3/transforms/pipelines.py index 378251f..0cd656b 100644 --- a/models/rfd3/src/rfd3/transforms/pipelines.py +++ b/models/rfd3/src/rfd3/transforms/pipelines.py @@ -523,7 +523,7 @@ def build_atom14_base_pipeline_( autofill_zeros_if_not_present_in_atomarray=True, token_1d_features=token_1d_features, atom_1d_features=atom_1d_features, - association_scheme=association_scheme + association_scheme=association_scheme, ), AddAF3TokenBondFeatures(), AddGroundTruthSequence(sequence_encoding=af3_sequence_encoding), diff --git a/models/rfd3/src/rfd3/transforms/training_conditions.py b/models/rfd3/src/rfd3/transforms/training_conditions.py index bfd5257..dbb07b7 100644 --- a/models/rfd3/src/rfd3/transforms/training_conditions.py +++ b/models/rfd3/src/rfd3/transforms/training_conditions.py @@ -72,11 +72,9 @@ class IslandCondition(TrainingCondition): p_fix_motif_coordinates, p_fix_motif_sequence, p_unindex_motif_tokens, - association_scheme = 'atom14', ): self.name = name self.frequency = frequency - self.association_scheme = association_scheme # Token selection self.island_sampling_kwargs = island_sampling_kwargs @@ -91,8 +89,6 @@ class IslandCondition(TrainingCondition): self.p_fix_motif_coordinates = p_fix_motif_coordinates self.p_fix_motif_sequence = p_fix_motif_sequence self.p_unindex_motif_tokens = p_unindex_motif_tokens - - self.association_scheme = association_scheme def is_valid_for_example(self, data) -> bool: is_protein = data["atom_array"].is_protein @@ -138,7 +134,6 @@ class IslandCondition(TrainingCondition): n_protein_tokens, **self.island_sampling_kwargs, ) - is_motif_token[token_level_array.is_protein] = islands_mask # TODO: Atoms with covalent bonds should be motif, needs FlagAndReassignCovalentModifications transform prior to this @@ -310,7 +305,7 @@ class SubtypeCondition(TrainingCondition): """ name = "subtype" - association_scheme = 'atom14' + association_scheme = "atom14" def __init__(self, frequency: float, subtype: list[str], fix_pos: bool = False): self.frequency = frequency @@ -526,8 +521,6 @@ def sample_is_motif_atom_with_fixed_seq( is_motif_atom_with_fixed_seq = ( is_motif_atom_with_fixed_seq | ~atom_array.is_protein ) - - return is_motif_atom_with_fixed_seq @@ -571,7 +564,6 @@ def sample_unindexed_atoms( is_motif_atom_unindexed, atom_array.is_residue ) - return is_motif_atom_unindexed diff --git a/models/rfd3/src/rfd3/transforms/virtual_atoms.py b/models/rfd3/src/rfd3/transforms/virtual_atoms.py index a574412..6ce1675 100644 --- a/models/rfd3/src/rfd3/transforms/virtual_atoms.py +++ b/models/rfd3/src/rfd3/transforms/virtual_atoms.py @@ -53,6 +53,7 @@ def map_to_association_scheme( else: return ATOM_NAMES[idxs] + def map_names_to_elements( atom_names: list | str, default=VIRTUAL_ATOM_ELEMENT_NAME ) -> np.ndarray: @@ -179,7 +180,7 @@ class PadTokensWithVirtualAtoms(Transform): token_ids = np.unique(atom_array.token_id) assert len(token_ids) == len( is_motif_atom_with_fixed_seq - ), "Token ids and token level array have different lengths!" + ), "Token ids and token level array have different lengths!" # Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms: if self.association_scheme == "atom23": @@ -220,7 +221,7 @@ class PadTokensWithVirtualAtoms(Transform): for token_id, (start, end) in enumerate(zip(starts[:-1], starts[1:])): if is_paddable[token_id]: token = atom_array[start:end] - + # First, pad with virtual atoms if needed if self.association_scheme == "atom23" and atom_array[start].is_dna: n_atoms_per_token = 22 @@ -229,7 +230,7 @@ class PadTokensWithVirtualAtoms(Transform): else: n_atoms_per_token = self.n_atoms_per_token n_pad = n_atoms_per_token - len(token) - + if n_pad > 0: mask = get_af3_token_representative_masks( token, central_atom=self.atom_to_pad_from