diff --git a/models/rfd3/configs/datasets/train/pdb/base.yaml b/models/rfd3/configs/datasets/train/pdb/base.yaml index d3f4562..9c0a8c8 100644 --- a/models/rfd3/configs/datasets/train/pdb/base.yaml +++ b/models/rfd3/configs/datasets/train/pdb/base.yaml @@ -9,6 +9,6 @@ weights: beta: 0.5 alphas: a_prot: 3.0 # 3 for AF-3 - a_nuc: 0.0 # 3 for AF-3 + a_nuc: 3.0 # 3 for AF-3 a_ligand: 1.0 # 1 for AF-3 a_loi: 5.0 # 5 for AF-3 diff --git a/models/rfd3/configs/experiment/rfd3na.yaml b/models/rfd3/configs/experiment/rfd3na.yaml new file mode 100644 index 0000000..6d80720 --- /dev/null +++ b/models/rfd3/configs/experiment/rfd3na.yaml @@ -0,0 +1,82 @@ +# @package _global_ +# Training configuration for RFD3 + +defaults: + - /debug/default + - override /model: rfd3_base + #- override /datasets: all + - override /logger: csv + #- override /logger: wandb + - _self_ + +name: train-base +tags: [print-model] +ckpt_path: null + +model: + net: + token_initializer: + token_1d_features: + ref_motif_token_type: 3 + restype: 32 + is_dna_token: 1 + is_rna_token: 1 + is_protein_token: 1 + atom_1d_features: + ref_atom_name_chars: 256 + ref_element: 128 + ref_charge: 1 + ref_mask: 1 + ref_is_motif_atom_with_fixed_coord: 1 + ref_is_motif_atom_unindexed: 1 + has_zero_occupancy: 1 + ref_pos: 3 + + # Guided features + ref_atomwise_rasa: 3 + active_donor: 1 + active_acceptor: 1 + is_atom_level_hotspot: 1 + diffusion_module: + n_recycle: 2 + use_local_token_attention: True + diffusion_transformer: + n_local_tokens: 32 + n_keys: 128 + + inference_sampler: + num_timesteps: 100 + + +datasets: + diffusion_batch_size_train: 16 + crop_size: 256 + max_atoms_in_crop: 2560 # ~10x crop size. + global_transform_args: + association_scheme: atom23 + train_conditions: + unconditional: + frequency: 2.0 + island: + frequency: 2.0 + sequence_design: + frequency: 0.5 + tipatom: + frequency: 5.0 + ppi: + frequency: 0.0 + train: + # These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data. + #pdb: + #probability: 0.10 + #monomer_distillation: + #probability: 0.90 + pdb: + probability: 1.0 + +trainer: + devices_per_node: 1 + limit_train_batches: 10 + limit_val_batches: 1 + validate_every_n_epochs: 5 + prevalidate: false diff --git a/models/rfd3/src/rfd3/constants.py b/models/rfd3/src/rfd3/constants.py index a1a6e48..529ca58 100644 --- a/models/rfd3/src/rfd3/constants.py +++ b/models/rfd3/src/rfd3/constants.py @@ -242,42 +242,171 @@ SELECTION_NONPROTEIN = [ "POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID", ] -backbone_atomscheme_DNA = [' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'"]#, None] +backbone_atomscheme_DNA = [ + " P ", + " OP1", + " OP2", + " O5'", + " C5'", + " C4'", + " O4'", + " C3'", + " O3'", + " C2'", + " C1'", +] # , None] -backbone_atomscheme_RNA = [' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'"] +backbone_atomscheme_RNA = [ + " P ", + " OP1", + " OP2", + " O5'", + " C5'", + " C4'", + " O4'", + " C3'", + " O3'", + " C2'", + " O2'", + " C1'", +] DNA_atoms = { - 'DA': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 '], - 'DC': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 '], - 'DG': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '], - 'DT': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C7 ', ' C6 ']} - -RNA_atoms = { - 'A': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 '], - 'C': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 '], - 'G': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '], - 'U': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C6 '] + "DA": [ + " N9 ", + " C8 ", + " N7 ", + " C5 ", + " C6 ", + " N6 ", + " N1 ", + " C2 ", + " N3 ", + " C4 ", + ], + "DC": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " N4 ", " C5 ", " C6 "], + "DG": [ + " N9 ", + " C8 ", + " N7 ", + " C5 ", + " C6 ", + " O6 ", + " N1 ", + " C2 ", + " N2 ", + " N3 ", + " C4 ", + ], + "DT": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " O4 ", " C5 ", " C7 ", " C6 "], } -association_schemes['atom23'] = {} +RNA_atoms = { + "A": [ + " N9 ", + " C8 ", + " N7 ", + " C5 ", + " C6 ", + " N6 ", + " N1 ", + " C2 ", + " N3 ", + " C4 ", + ], + "C": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " N4 ", " C5 ", " C6 "], + "G": [ + " N9 ", + " C8 ", + " N7 ", + " C5 ", + " C6 ", + " O6 ", + " N1 ", + " C2 ", + " N2 ", + " N3 ", + " C4 ", + ], + "U": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " O4 ", " C5 ", " C6 "], +} + +association_schemes["atom23"] = {} for item in DNA_atoms: - association_schemes['atom23'][item] = tuple(backbone_atomscheme_DNA + DNA_atoms[item]+ [None]*(22 - len(DNA_atoms[item] + backbone_atomscheme_DNA))) + association_schemes["atom23"][item] = tuple( + backbone_atomscheme_DNA + + DNA_atoms[item] + + [None] * (22 - len(DNA_atoms[item] + backbone_atomscheme_DNA)) + ) for item in RNA_atoms: - association_schemes['atom23'][item] = tuple(backbone_atomscheme_RNA + RNA_atoms[item]+ [None]*(23 - len(RNA_atoms[item] + backbone_atomscheme_RNA))) + association_schemes["atom23"][item] = tuple( + backbone_atomscheme_RNA + + RNA_atoms[item] + + [None] * (23 - len(RNA_atoms[item] + backbone_atomscheme_RNA)) + ) -for item in association_schemes['dense']: - association_schemes['atom23'][item] = association_schemes['dense'][item] +for item in association_schemes["dense"]: + association_schemes["atom23"][item] = association_schemes["dense"][item] -association_schemes['atom23']['DX'] = (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'", None, None, None, None, None, None, None, None, None, None, None) #rna_mask -association_schemes['atom23']['X'] = (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'", None, None, None, None, None, None, None, None, None, None, None)#rna mask +association_schemes["atom23"]["DX"] = ( + " P ", + " OP1", + " OP2", + " O5'", + " C5'", + " C4'", + " O4'", + " C3'", + " O3'", + " C2'", + " C1'", + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, +) # rna_mask +association_schemes["atom23"]["X"] = ( + " P ", + " OP1", + " OP2", + " O5'", + " C5'", + " C4'", + " O4'", + " C3'", + " O3'", + " C2'", + " O2'", + " C1'", + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, +) # rna mask ATOM23_ATOM_NAMES_RNA = np.array( - [item.strip() for item in backbone_atomscheme_RNA] + [f"V{i}" for i in range(23 - len(backbone_atomscheme_RNA))] + [item.strip() for item in backbone_atomscheme_RNA] + + [f"V{i}" for i in range(23 - len(backbone_atomscheme_RNA))] ) """Atom23 atom names (e.g. CA, V1)""" ATOM23_ATOM_ELEMENTS_RNA = np.array( - ["P", "O", "O", "O", "C", "C", "O", "C","O", "C", "O", "C"] + [VIRTUAL_ATOM_ELEMENT_NAME for i in range(23 - len(backbone_atomscheme_RNA))] + ["P", "O", "O", "O", "C", "C", "O", "C", "O", "C", "O", "C"] + + [VIRTUAL_ATOM_ELEMENT_NAME for i in range(23 - len(backbone_atomscheme_RNA))] ) """Atom23 element names (e.g. C, VX)""" @@ -285,12 +414,14 @@ ATOM23_ATOM_NAME_TO_ELEMENT = { name: elem for name, elem in zip(ATOM23_ATOM_NAMES_RNA, ATOM23_ATOM_ELEMENTS_RNA) } ATOM23_ATOM_NAMES_DNA = np.array( - [item.strip() for item in backbone_atomscheme_DNA] + [f"V{i}" for i in range(22 - len(backbone_atomscheme_DNA))] + [item.strip() for item in backbone_atomscheme_DNA] + + [f"V{i}" for i in range(22 - len(backbone_atomscheme_DNA))] ) """Atom23 atom names (e.g. CA, V1)""" ATOM23_ATOM_ELEMENTS_DNA = np.array( - ["P", "O", "O", "O", "C", "C", "O", "C","O", "C", "C"] + [VIRTUAL_ATOM_ELEMENT_NAME for i in range(22 - len(backbone_atomscheme_DNA))] + ["P", "O", "O", "O", "C", "C", "O", "C", "O", "C", "C"] + + [VIRTUAL_ATOM_ELEMENT_NAME for i in range(22 - len(backbone_atomscheme_DNA))] ) """Atom23 element names (e.g. C, VX)""" @@ -307,4 +438,3 @@ association_schemes_stripped = { backbone_atoms_RNA = strip_list(backbone_atomscheme_RNA) backbone_atoms_DNA = strip_list(backbone_atomscheme_DNA) - diff --git a/models/rfd3/src/rfd3/inference/input_parsing.py b/models/rfd3/src/rfd3/inference/input_parsing.py index d97b3be..b3117fc 100644 --- a/models/rfd3/src/rfd3/inference/input_parsing.py +++ b/models/rfd3/src/rfd3/inference/input_parsing.py @@ -30,9 +30,12 @@ from rfd3.constants import ( OPTIONAL_CONDITIONING_VALUES, REQUIRED_CONDITIONING_ANNOTATION_VALUES, REQUIRED_INFERENCE_ANNOTATIONS, + backbone_atoms_DNA, + backbone_atoms_RNA, ) from rfd3.inference.legacy_input_parsing import ( create_atom_array_from_design_specification_legacy, + reorder_atoms_per_residue, ) from rfd3.inference.parsing import InputSelection from rfd3.inference.symmetry.symmetry_utils import ( @@ -67,7 +70,6 @@ logging.basicConfig(level=logging.DEBUG) logger = RankedLogger(__name__, rank_zero_only=True) - ################################################################################# # Custom infer_ori functions ################################################################################# @@ -505,6 +507,21 @@ class DesignInputSpecification(BaseModel): def build(self, return_metadata=False): """Main build pipeline.""" atom_array_input_annotated = copy.deepcopy(self.atom_array_input) + + ########## reorder NA atoms ########### + is_dna = np.isin(atom_array_input_annotated.res_name, ["DA", "DC", "DG", "DT"]) + is_rna = np.isin(atom_array_input_annotated.res_name, ["A", "C", "G", "U"]) + dna_array = atom_array_input_annotated[is_dna] + rna_array = atom_array_input_annotated[is_rna] + + atom_array_input_annotated[is_dna] = reorder_atoms_per_residue( + dna_array, backbone_atoms_DNA + ) + atom_array_input_annotated[is_rna] = reorder_atoms_per_residue( + rna_array, backbone_atoms_RNA + ) + ####################################### + atom_array = self._build_init(atom_array_input_annotated) # Apply post-processing @@ -894,31 +911,52 @@ def validator_context(validator_name: str, data: dict = None): raise e -def create_diffused_residues(n, additional_annotations=None): +def create_diffused_residues(n, additional_annotations=None, polymer_type="P"): + from rfd3.constants import ( + ATOM23_ATOM_NAME_TO_ELEMENT, + backbone_atoms_DNA, + backbone_atoms_RNA, + ) + if n <= 0: raise ValueError(f"Negative/null residue count ({n}) not allowed.") + if polymer_type == "P": + res_name = "ALA" + bb_len = 5 + bb_atom_names = ["N", "CA", "C", "O", "CB"] + elif polymer_type == "R": + res_name = "A" + bb_len = len(backbone_atoms_RNA) + bb_atom_names = strip_list(backbone_atoms_RNA) + elif polymer_type == "D": + res_name = "DA" + bb_len = len(backbone_atoms_DNA) + bb_atom_names = strip_list(backbone_atoms_DNA) + else: + raise ValueError( + f"invalid polymer type detected: {polymer_type}, check contig!" + ) + + bb_elements = [ATOM23_ATOM_NAME_TO_ELEMENT[item] for item in bb_atom_names] + atoms = [] [ atoms.extend( [ struc.Atom( np.array([0.0, 0.0, 0.0], dtype=np.float32), - res_name="ALA", + res_name=res_name, res_id=idx, ) - for _ in range(5) + for _ in range(bb_len) ] ) for idx in range(1, n + 1) ] array = struc.array(atoms) - array.set_annotation( - "element", np.array(["N", "C", "C", "O", "C"] * n, dtype=" AtomArray: + """ + Reorder atoms within each residue of an AtomArray. + Atoms in `desired_order` appear first (in that order), followed by all others + in original order. Faster version using get_residue_starts(). + + Parameters: + - atom_array: AtomArray to reorder. + - desired_order: List of atom names in the desired per-residue order. + + Returns: + - AtomArray with reordered atoms per residue. + """ + if len(atom_array) == 0: + return atom_array + res_starts = get_residue_starts(atom_array) + res_starts = np.append(res_starts, len(atom_array)) # add end index for slicing + reordered_chunks = [] + order_dict = {name: i for i, name in enumerate(desired_order)} + + for i in range(len(res_starts) - 1): + start, end = res_starts[i], res_starts[i + 1] + residue = atom_array[start:end] + + # Boolean masks for matching and non-matching atom names + in_order_mask = np.isin(residue.atom_name, desired_order) + not_in_order_mask = ~in_order_mask + + # Sort matching atoms by desired order + atoms_in_order = residue[in_order_mask] + sort_idx = np.argsort([order_dict[name] for name in atoms_in_order.atom_name]) + ordered_atoms = atoms_in_order[sort_idx] + + # Remaining atoms as-is + remaining_atoms = residue[not_in_order_mask] + + # Concatenate reordered residue + reordered_chunks.append(concatenate([ordered_atoms, remaining_atoms])) + return concatenate(reordered_chunks) diff --git a/models/rfd3/src/rfd3/trainer/rfd3.py b/models/rfd3/src/rfd3/trainer/rfd3.py index a7f72e6..9735978 100644 --- a/models/rfd3/src/rfd3/trainer/rfd3.py +++ b/models/rfd3/src/rfd3/trainer/rfd3.py @@ -428,9 +428,14 @@ class AADesignTrainer(FabricTrainer): # ... Delete virtual atoms and assign atom names and elements if self.cleanup_virtual_atoms: - atom_array = _cleanup_virtual_atoms_and_assign_atom_name_elements( - atom_array, association_scheme=self.association_scheme - ) + try: + atom_array = _cleanup_virtual_atoms_and_assign_atom_name_elements( + atom_array, association_scheme=self.association_scheme + ) + except Exception as e: + global_logger.warning( + f"Failed to cleanup virtual atoms from diffusion output: {e}" + ) # ... When cleaning up virtual atoms, we can also calculate native_array_metricsl metadata_dict[i]["metrics"] |= get_all_backbone_metrics( diff --git a/models/rfd3/src/rfd3/trainer/trainer_utils.py b/models/rfd3/src/rfd3/trainer/trainer_utils.py index 59f43a1..9ee7939 100644 --- a/models/rfd3/src/rfd3/trainer/trainer_utils.py +++ b/models/rfd3/src/rfd3/trainer/trainer_utils.py @@ -11,6 +11,8 @@ from biotite.structure import concatenate, infer_elements from jaxtyping import Float, Int from rfd3.constants import ( ATOM14_ATOM_NAMES, + ATOM23_ATOM_NAMES_DNA, + ATOM23_ATOM_NAMES_RNA, VIRTUAL_ATOM_ELEMENT_NAME, association_schemes, association_schemes_stripped, @@ -252,13 +254,19 @@ def _readout_seq_from_struc( continue # ... Find the index of virtual atom names in the standard atom14 names + ATOM_NAMES = ATOM14_ATOM_NAMES + if restype in STANDARD_DNA: + ATOM_NAMES = ATOM23_ATOM_NAMES_DNA + if restype in STANDARD_RNA: + ATOM_NAMES = ATOM23_ATOM_NAMES_RNA + atom_name_idx_in_atom14_scheme = np.array( [ - np.where(ATOM14_ATOM_NAMES == atom_name)[0][0] + np.where(ATOM_NAMES == atom_name)[0][0] for atom_name in cur_pred_res_atom_names ] ) # five backbone atoms + some virtual atoms, returning e.g. [0, 1, 2, 3, 4, 11, 7] - atom14_scheme_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool) + atom14_scheme_mask = np.zeros_like(ATOM_NAMES, dtype=bool) atom14_scheme_mask[atom_name_idx_in_atom14_scheme] = True # ... Find the matched restype by checking if all the non-None posititons and None positions match @@ -427,12 +435,30 @@ def process_unindexed_outputs( else: join_atom = None + if join_atom is None: + pass + else: + dist = float(dists[row_ind[join_atom], col_ind[join_atom]]) + + elif not np.any( + np.isin( + token.atom_name, [item.replace(" ", "") for item in backbone_atoms_RNA] + ) + ): + if np.sum(token.atomize) == 1: + join_atom = np.where(token.atomize)[0][0] + elif "C1'" in token.atom_name: + join_atom = np.where(token.atom_name == "C1'")[0][0] + else: + join_atom = None + if join_atom is None: global_logger.warning( - f"Token {token_pdb_id} does not contain backbone atoms or CB, skipping join point distance calculation {token}." + "Skipping joint point rmsd, neither protein or NA backbone" ) else: dist = float(dists[row_ind[join_atom], col_ind[join_atom]]) + metadata["join_point_rmsd_by_token"][token_pdb_id] = dist metadata["diffused_index_map"][token_pdb_id] = f"{chain_id}{res_id}" diff --git a/models/rfd3/src/rfd3/transforms/conditioning_base.py b/models/rfd3/src/rfd3/transforms/conditioning_base.py index 3a551a5..769fcf4 100644 --- a/models/rfd3/src/rfd3/transforms/conditioning_base.py +++ b/models/rfd3/src/rfd3/transforms/conditioning_base.py @@ -281,7 +281,7 @@ class SampleConditioningType(Transform): cond = valid_conditions[i_cond] cond.association_scheme = self.association_scheme - + data["sampled_condition"] = cond data["sampled_condition_name"] = cond.name data["sampled_condition_cls"] = cond.__class__ @@ -299,6 +299,7 @@ class SampleConditioningFlags(Transform): "AssignTypes", "SampleConditioningType", ] # We use is_protein in the PPI training condition + def __init__(self, association_scheme): self.association_scheme = association_scheme @@ -375,13 +376,15 @@ class UnindexFlaggedTokens(Transform): token.res_id = token.res_id + max_resid token.is_C_terminus[:] = False token.is_N_terminus[:] = False - - if association_scheme is not 'atom23': + + if not self.association_scheme == "atom23": assert token.is_protein.all(), f"Cannot unindex non-protein token: {token} unless using atom23 association scheme" token = add_representative_atom(token, central_atom=self.central_atom) else: if token.is_protein.all(): - token = add_representative_atom(token, central_atom=self.central_atom) + token = add_representative_atom( + token, central_atom=self.central_atom + ) else: token = add_representative_atom(token, central_atom="C1'") diff --git a/models/rfd3/src/rfd3/transforms/design_transforms.py b/models/rfd3/src/rfd3/transforms/design_transforms.py index 8225bcc..5cc84fa 100644 --- a/models/rfd3/src/rfd3/transforms/design_transforms.py +++ b/models/rfd3/src/rfd3/transforms/design_transforms.py @@ -697,11 +697,12 @@ class AddAdditional1dFeaturesToFeats(Transform): token_1d_features, atom_1d_features, autofill_zeros_if_not_present_in_atomarray=False, - association_scheme='atom14' + association_scheme="atom14", ): self.autofill = autofill_zeros_if_not_present_in_atomarray self.token_1d_features = token_1d_features self.atom_1d_features = atom_1d_features + self.association_scheme = association_scheme def check_input(self, data) -> None: check_contains_keys(data, ["atom_array"]) @@ -753,11 +754,13 @@ class AddAdditional1dFeaturesToFeats(Transform): """ if "feats" not in data.keys(): data["feats"] = {} - - if association_scheme == 'atom23': - data['atom_array'].set_annotation('is_protein_token', data['atom_array'].is_protein) - data['atom_array'].set_annotation('is_dna_token', data['atom_array'].is_dna) - data['atom_array'].set_annotation('is_rna_token', data['atom_array'].is_rna) + + if self.association_scheme == "atom23": + data["atom_array"].set_annotation( + "is_protein_token", data["atom_array"].is_protein + ) + data["atom_array"].set_annotation("is_dna_token", data["atom_array"].is_dna) + data["atom_array"].set_annotation("is_rna_token", data["atom_array"].is_rna) for feature_name, n_dims in self.token_1d_features.items(): data = self.generate_feature(feature_name, n_dims, data, "token") diff --git a/models/rfd3/src/rfd3/transforms/pipelines.py b/models/rfd3/src/rfd3/transforms/pipelines.py index fff8685..0cd656b 100644 --- a/models/rfd3/src/rfd3/transforms/pipelines.py +++ b/models/rfd3/src/rfd3/transforms/pipelines.py @@ -383,7 +383,7 @@ def build_atom14_base_pipeline_( train_conditions=train_conditions, meta_conditioning_probabilities=meta_conditioning_probabilities, sequence_encoding=af3_sequence_encoding, - association_scheme=association_scheme + association_scheme=association_scheme, ), ), ] @@ -423,7 +423,9 @@ def build_atom14_base_pipeline_( # ... Add global token features (since number of tokens is fixed after cropping) transforms.append(AddGlobalTokenIdAnnotation()) # ... Create masks (NOTE: Modulates token count, and resets global token id if necessary) - transforms.append(TrainingRoute(SampleConditioningFlags(association_scheme=association_scheme))) + transforms.append( + TrainingRoute(SampleConditioningFlags(association_scheme=association_scheme)) + ) # Post-crop transforms transforms.append( @@ -443,7 +445,9 @@ def build_atom14_base_pipeline_( sharding_depth=1, ), # ... Fuse inference and training conditioning assignments - UnindexFlaggedTokens(central_atom=central_atom), + UnindexFlaggedTokens( + central_atom=central_atom, association_scheme=association_scheme + ), # ... Virtual atom padding (NOTE: Last transform which modulates atom count) PadTokensWithVirtualAtoms( n_atoms_per_token=n_atoms_per_token, @@ -519,7 +523,7 @@ def build_atom14_base_pipeline_( autofill_zeros_if_not_present_in_atomarray=True, token_1d_features=token_1d_features, atom_1d_features=atom_1d_features, - association_scheme=association_scheme + association_scheme=association_scheme, ), AddAF3TokenBondFeatures(), AddGroundTruthSequence(sequence_encoding=af3_sequence_encoding), diff --git a/models/rfd3/src/rfd3/transforms/training_conditions.py b/models/rfd3/src/rfd3/transforms/training_conditions.py index 8f708e0..6354c1f 100644 --- a/models/rfd3/src/rfd3/transforms/training_conditions.py +++ b/models/rfd3/src/rfd3/transforms/training_conditions.py @@ -58,6 +58,8 @@ class IslandCondition(TrainingCondition): Select islands as motif and assign conditioning strategies. """ + association_scheme = "atom14" + def __init__( self, *, @@ -70,11 +72,9 @@ class IslandCondition(TrainingCondition): p_fix_motif_coordinates, p_fix_motif_sequence, p_unindex_motif_tokens, - association_scheme = 'atom14', ): self.name = name self.frequency = frequency - self.association_scheme = association_scheme # Token selection self.island_sampling_kwargs = island_sampling_kwargs @@ -89,15 +89,13 @@ class IslandCondition(TrainingCondition): self.p_fix_motif_coordinates = p_fix_motif_coordinates self.p_fix_motif_sequence = p_fix_motif_sequence self.p_unindex_motif_tokens = p_unindex_motif_tokens - - self.association_scheme = association_scheme def is_valid_for_example(self, data) -> bool: is_protein = data["atom_array"].is_protein is_dna = data["atom_array"].is_dna is_rna = data["atom_array"].is_rna ### updating this to allow other polymers - if self.association_scheme is not 'atom23': + if not self.association_scheme == "atom23": if not np.any(is_protein | is_dna | is_rna): return False else: @@ -113,8 +111,12 @@ class IslandCondition(TrainingCondition): token_level_array = atom_array[get_token_starts(atom_array)] # initialize motif tokens as all non-protein tokens - if self.association_scheme is 'atom23': - polymer_mask = (token_level_array.is_protein | token_level_array.is_dna | token_level_array.is_rna) + if self.association_scheme == "atom23": + polymer_mask = ( + token_level_array.is_protein + | token_level_array.is_dna + | token_level_array.is_rna + ) is_motif_token = np.asarray(~polymer_mask, dtype=bool).copy() n_polymer_tokens = np.sum(polymer_mask) islands_mask = sample_island_tokens( @@ -123,13 +125,15 @@ class IslandCondition(TrainingCondition): ) is_motif_token[polymer_mask] = islands_mask else: - is_motif_token = np.asarray(~token_level_array.is_protein, dtype=bool).copy() + is_motif_token = np.asarray( + ~token_level_array.is_protein, dtype=bool + ).copy() n_protein_tokens = np.sum(token_level_array.is_protein) - slands_mask = sample_island_tokens( - _protein_tokens, - *self.island_sampling_kwargs, - + islands_mask = sample_island_tokens( + n_protein_tokens, + **self.island_sampling_kwargs, + ) is_motif_token[token_level_array.is_protein] = islands_mask # TODO: Atoms with covalent bonds should be motif, needs FlagAndReassignCovalentModifications transform prior to this @@ -160,7 +164,7 @@ class IslandCondition(TrainingCondition): is_motif_atom = sample_motif_subgraphs( atom_array=atom_array, **self.subgraph_sampling_kwargs, - association_scheme=self.association_scheme + association_scheme=self.association_scheme, ) # We also only want resolved atoms to be motif @@ -182,7 +186,7 @@ class IslandCondition(TrainingCondition): p_fix_motif_sequence=self.p_fix_motif_sequence, p_fix_motif_coordinates=self.p_fix_motif_coordinates, p_unindex_motif_tokens=self.p_unindex_motif_tokens, - association_scheme=self.association_scheme + association_scheme=self.association_scheme, ) atom_array.set_annotation( @@ -202,7 +206,7 @@ class PPICondition(TrainingCondition): """Get condition indicating what is motif and what is to be diffused for protein-protein interaction training.""" name = "ppi" - association_scheme = 'atom14' + association_scheme = "atom14" def is_valid_for_example(self, data): # Extract relevant data @@ -301,7 +305,7 @@ class SubtypeCondition(TrainingCondition): """ name = "subtype" - association_scheme = 'atom14' + association_scheme = "atom14" def __init__(self, frequency: float, subtype: list[str], fix_pos: bool = False): self.frequency = frequency @@ -397,7 +401,7 @@ def sample_motif_subgraphs( hetatom_n_bond_expectation, residue_p_fix_all, hetatom_p_fix_all, - association_scheme = 'atom14' + association_scheme="atom14", ): """ Returns a boolean mask over atoms, indicating which atoms are part of the sampled motif. @@ -431,10 +435,14 @@ def sample_motif_subgraphs( "p_fix_all": residue_p_fix_all, } - if association_scheme is 'atom23': - clause = atom_array_subset.is_protein.all() | atom_array_subset.is_dna.all() | atom_array_subset.is_rna.all() + if association_scheme == "atom23": + clause = ( + atom_array_subset.is_protein.all() + | atom_array_subset.is_dna.all() + | atom_array_subset.is_rna.all() + ) else: - clause = atom_array_subset.is_potein.all() + clause = atom_array_subset.is_protein.all() if not clause: args.update( @@ -465,12 +473,14 @@ def sample_conditioning_strategy( p_fix_motif_sequence, p_fix_motif_coordinates, p_unindex_motif_tokens, - association_scheme + association_scheme, ): atom_array.set_annotation( "is_motif_atom_with_fixed_seq", sample_is_motif_atom_with_fixed_seq( - atom_array, p_fix_motif_sequence=p_fix_motif_sequence, association_scheme=association_scheme + atom_array, + p_fix_motif_sequence=p_fix_motif_sequence, + association_scheme=association_scheme, ), ) @@ -491,7 +501,9 @@ def sample_conditioning_strategy( return atom_array -def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence, association_scheme): +def sample_is_motif_atom_with_fixed_seq( + atom_array, p_fix_motif_sequence, association_scheme +): """ Samples what kind of conditioning to apply to motif tokens. @@ -504,10 +516,11 @@ def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence, associ is_motif_atom_with_fixed_seq = np.zeros(atom_array.array_length(), dtype=bool) # By default reveal sequence for non-protein - - if association_scheme is not 'atom23': - is_motif_atom_with_fixed_seq = is_motif_atom_with_fixed_seq | ~atom_array.is_protein - + + if not association_scheme == "atom23": + is_motif_atom_with_fixed_seq = ( + is_motif_atom_with_fixed_seq | ~atom_array.is_protein + ) return is_motif_atom_with_fixed_seq @@ -526,7 +539,9 @@ def sample_fix_motif_coordinates(atom_array, p_fix_motif_coordinates): return is_motif_atom_with_fixed_coord -def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens, association_scheme='atom14'): +def sample_unindexed_atoms( + atom_array, p_unindex_motif_tokens, association_scheme="atom14" +): """ Samples which atoms in motif tokens should be flagged for unindexing. @@ -539,15 +554,15 @@ def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens, association_schem is_motif_atom_unindexed = np.zeros(atom_array.array_length(), dtype=bool) # ensure non-residue atoms are not already flagged - if association_scheme == 'atom23': + if association_scheme == "atom23": is_motif_atom_unindexed = np.logical_and( - is_motif_atom_unindexed, (atom_array.is_residue | atom_array.is_dna | atom_array.is_rna) - ) # is_residue refers to is_protein here + is_motif_atom_unindexed, + (atom_array.is_residue | atom_array.is_dna | atom_array.is_rna), + ) # is_residue refers to is_protein here else: is_motif_atom_unindexed = np.logical_and( is_motif_atom_unindexed, atom_array.is_residue - ) - + ) return is_motif_atom_unindexed diff --git a/models/rfd3/src/rfd3/transforms/virtual_atoms.py b/models/rfd3/src/rfd3/transforms/virtual_atoms.py index fbc8c74..7d240a2 100644 --- a/models/rfd3/src/rfd3/transforms/virtual_atoms.py +++ b/models/rfd3/src/rfd3/transforms/virtual_atoms.py @@ -10,11 +10,10 @@ from atomworks.ml.transforms.base import ( ) from atomworks.ml.utils.token import get_token_starts from rfd3.constants import ( - ATOM23_ATOM_NAME_TO_ELEMENT, - ATOM14_ATOM_NAME_TO_ELEMENT, ATOM14_ATOM_NAMES, - ATOM23_ATOM_NAMES_RNA, + ATOM23_ATOM_NAME_TO_ELEMENT, ATOM23_ATOM_NAMES_DNA, + ATOM23_ATOM_NAMES_RNA, VIRTUAL_ATOM_ELEMENT_NAME, association_schemes, association_schemes_stripped, @@ -31,7 +30,9 @@ from rfd3.transforms.util_transforms import ( from foundry.common import exists -def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="atom14", ATOM_NAMES=None): +def map_to_association_scheme( + atom_names: list | str, res_name: str, scheme="atom14", ATOM_NAMES=None +): """ Maps a list of names to the atom14 naming scheme for that particular name (within a specific residue) NB this function is a bit more general since it is used to handle tipatoms too. @@ -52,6 +53,7 @@ def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="ato else: return ATOM_NAMES[idxs] + def map_names_to_elements( atom_names: list | str, default=VIRTUAL_ATOM_ELEMENT_NAME ) -> np.ndarray: @@ -61,7 +63,7 @@ def map_names_to_elements( then it returns the default value """ atom_names = [atom_names] if isinstance(atom_names, str) else atom_names - elements = [ATOM14_ATOM_NAME_TO_ELEMENT.get(name, default) for name in atom_names] + elements = [ATOM23_ATOM_NAME_TO_ELEMENT.get(name, default) for name in atom_names] return np.array(elements) @@ -72,6 +74,9 @@ def generate_atom_mappings_(scheme="atom14"): symmetry_mapping = {} for aaa, atom_names in ccd_ordering_atomchar.items(): + if aaa not in scheme: + continue + mapping = list(range(len(atom_names))) scheme_names = scheme[aaa] @@ -126,10 +131,10 @@ def permute_symmetric_atom_names_( # With the structure-local atom attention it will not unless N_keys(n_attn_seq_neighbours) > n_atom_attn_queries. ## fail safe, no symmetry confusion in NA bases ## - if (atom_names[0] == "P"): + if atom_names[0] == "P": return atom_names ################################################## - + if res_name in association_map: idx_to_swap = association_map[res_name] atom_names = atom_names[idx_to_swap] @@ -180,16 +185,17 @@ class PadTokensWithVirtualAtoms(Transform): token_ids = np.unique(atom_array.token_id) assert len(token_ids) == len( is_motif_atom_with_fixed_seq - ), "Token ids and token level array have different lengths!" + ), "Token ids and token level array have different lengths!" - # Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms: - if self.association_scheme == 'atom23': - is_residue = ( + # Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms: + if self.association_scheme == "atom23": + is_residue = ( token_level_array.is_protein & ~token_level_array.atomize ) | is_motif_token_unindexed - + is_residue_NA = ( - (token_level_array.is_dna | token_level_array.is_rna) & ~token_level_array.atomize + (token_level_array.is_dna | token_level_array.is_rna) + & ~token_level_array.atomize ) | is_motif_token_unindexed # Unindexed tokens are never padded, and so are treated as residues with fixed sequence. @@ -211,7 +217,6 @@ class PadTokensWithVirtualAtoms(Transform): is_non_paddable_residue = is_residue & ( is_motif_atom_with_fixed_seq | is_motif_token_unindexed ) - # Collect virtual atoms to insert (we will insert them all at once) virtual_atoms_to_insert = [] @@ -221,7 +226,7 @@ class PadTokensWithVirtualAtoms(Transform): for token_id, (start, end) in enumerate(zip(starts[:-1], starts[1:])): if is_paddable[token_id]: token = atom_array[start:end] - + # First, pad with virtual atoms if needed if self.association_scheme == "atom23" and atom_array[start].is_dna: n_atoms_per_token = 22 @@ -230,7 +235,7 @@ class PadTokensWithVirtualAtoms(Transform): else: n_atoms_per_token = self.n_atoms_per_token n_pad = n_atoms_per_token - len(token) - + if n_pad > 0: mask = get_af3_token_representative_masks( token, central_atom=self.atom_to_pad_from @@ -297,10 +302,10 @@ class PadTokensWithVirtualAtoms(Transform): for token_id, (start, end) in enumerate( zip(starts_padded[:-1], starts_padded[1:]) - ): - if (atom_array_padded[start].is_dna): + ): + if atom_array_padded[start].is_dna: ATOM_NAMES = ATOM23_ATOM_NAMES_DNA - elif (atom_array_padded[start].is_rna): + elif atom_array_padded[start].is_rna: ATOM_NAMES = ATOM23_ATOM_NAMES_RNA else: ATOM_NAMES = ATOM14_ATOM_NAMES @@ -328,7 +333,10 @@ class PadTokensWithVirtualAtoms(Transform): ) atom_array_padded.get_annotation("gt_atom_name")[start:end] = atom_names atom_names = map_to_association_scheme( - atom_names, res_name, scheme=self.association_scheme, ATOM_NAMES=ATOM_NAMES + atom_names, + res_name, + scheme=self.association_scheme, + ATOM_NAMES=ATOM_NAMES, ) atom_array_padded.atom_name[start:end] = atom_names else: diff --git a/src/foundry/utils/components.py b/src/foundry/utils/components.py index 75bc87f..0af9c7e 100644 --- a/src/foundry/utils/components.py +++ b/src/foundry/utils/components.py @@ -96,8 +96,21 @@ def get_design_pattern_with_constraints(contig, length=None): fixed_parts = [] pos_to_put_motif = [] + suff = [] # suffixes for diffused regions P(optional),R,D + for part in contig_parts: - if any(c.isalpha() for c in part): # Detect parts containing letters as fixed + ## updating to include DNA and RNA generation + if part[-1] in ["R", "D"]: ##Detect non-fixed RNA and DNA contig part + suff.append(part[-1]) + part = part[:-1] + if "-" in part: + start, end = map(int, part.split("-")) + else: + start = end = int(part) + variable_ranges.append([start, end]) + pos_to_put_motif.append(0) + + elif any(c.isalpha() for c in part): # Detect parts containing letters as fixed pn_unit_id, pn_unit_start, pn_unit_end = extract_pn_unit_info(part) fixed_parts.append([pn_unit_id, pn_unit_start, pn_unit_end]) pos_to_put_motif.append(1) @@ -110,6 +123,7 @@ def get_design_pattern_with_constraints(contig, length=None): start = end = int(part) variable_ranges.append([start, end]) pos_to_put_motif.append(0) + suff.append("P") # adjust the total length to solely for free residues num_motif_residues = sum([i[2] - i[1] + 1 for i in fixed_parts]) @@ -167,7 +181,7 @@ def get_design_pattern_with_constraints(contig, length=None): atoms_with_motif.append(f"{pn_unit_id}{index}") elif pos_to_put_motif[idx] == 0: free_atom = num_free_atoms.pop(0) - atoms_with_motif.append(free_atom) + atoms_with_motif.append(str(free_atom) + suff.pop(0)) elif pos_to_put_motif[idx] == 2: atoms_with_motif.append("/0")