diff --git a/models/rfd3/configs/datasets/design_base_rfd3na.yaml b/models/rfd3/configs/datasets/design_base_rfd3na.yaml new file mode 100644 index 0000000..7181617 --- /dev/null +++ b/models/rfd3/configs/datasets/design_base_rfd3na.yaml @@ -0,0 +1,103 @@ +# base training dataset for training AF3 design models (atom14 variants): +# protein subsampling only. + +defaults: + # Grab datasets + - train/pdb/rfd3_train_interface@train.pdb.sub_datasets.interface + - train/pdb/rfd3_train_pn_unit@train.pdb.sub_datasets.pn_unit + #- train/rfd3_monomer_distillation@train + - train/rna_monomer_distillation@train + + # Customized validation datasets + - val/unconditional@val.unconditional + - val/unconditional_deep@val.unconditional_deep + - val/indexed@val.indexed + - val/pseudoknot@val.pseudoknot + + # Customized train masks + - conditions/unconditional@global_transform_args.train_conditions.unconditional + - conditions/island@global_transform_args.train_conditions.island + - conditions/tipatom@global_transform_args.train_conditions.tipatom + - conditions/sequence_design@global_transform_args.train_conditions.sequence_design + - conditions/ppi@global_transform_args.train_conditions.ppi + + - _self_ + +# Create a dictionary used for transform arguments +pipeline_target: rfd3.transforms.pipelines.build_atom14_base_pipeline + +# Base config overrides: +diffusion_batch_size_train: 32 +diffusion_batch_size_inference: 8 +crop_size: 384 +n_recycles_train: 2 +n_recycles_validation: 1 +max_atoms_in_crop: 3840 # ~10x crop size. + +# Global transform arguments are necessary for arguments shared between training and inference +global_transform_args: + n_atoms_per_token: 14 + central_atom: CB + sigma_perturb: 2.0 + sigma_perturb_com: 1.0 + association_scheme: dense + center_option: diffuse # options are ["all", "motif", "diffuse"] + + # Reference conformer policy + generate_conformers: True + generate_conformers_for_non_protein_only: True + provide_reference_conformer_when_unmasked: True + ground_truth_conformer_policy: IGNORE # Other options: REPLACE, ADD, FALLBACK. See atomworks.enums for details + provide_elements_for_unindexed_components: True + use_element_for_atom_names_of_atomized_tokens: True # TODO: correct name, implies unindexed do too + + # PPI Cropping + keep_full_binder_in_spatial_crop: False + max_binder_length: 170 + + # PPI Hotspots + max_ppi_hotspots_frac_to_provide: 0.2 + ppi_hotspot_max_distance: 4.5 + + # Secondary structure features + max_ss_frac_to_provide: 0.4 + min_ss_island_len: 1 + max_ss_island_len: 10 + + # Nucleic acid features + add_na_pair_features: false + + train_conditions: + unconditional: + frequency: 5.0 + sequence_design: + frequency: 2.0 + island: + frequency: 1.0 + tipatom: + frequency: 0.0 + ppi: + frequency: 0.0 + + # Used to create simple boolean flags for downstream conditioning + meta_conditioning_probabilities: + calculate_NA_SS: 1.0 + calculate_hbonds: 0.2 + calculate_rasa: 0.6 + + keep_protein_motif_rasa: 0.1 # Small to prevent noisy input to model + hbond_subsample: 0.5 + + # fully indexed training + unindex_leak_global_index: 0.10 + unindex_insert_random_break: 0.10 + unindex_remove_random_break: 0.10 + + # Probability of adding 1d secondary structure conditioning + add_1d_ss_features: 0.1 + featurize_plddt: 0.9 # Applied for monomer distillation only + add_global_is_non_loopy_feature: 0.99 + + # PPI + add_ppi_hotspots: 0.75 + full_binder_crop: 0.75 diff --git a/models/rfd3/configs/datasets/train/rna_monomer_distillation.yaml b/models/rfd3/configs/datasets/train/rna_monomer_distillation.yaml new file mode 100644 index 0000000..05ac847 --- /dev/null +++ b/models/rfd3/configs/datasets/train/rna_monomer_distillation.yaml @@ -0,0 +1,39 @@ +defaults: + - pdb/base_transform_args@rna_monomer_distillation + - _self_ + +rna_monomer_distillation: + dataset: + _target_: atomworks.ml.datasets.StructuralDatasetWrapper + save_failed_examples_to_dir: ${paths.data.failed_examples_dir} + + # cif parser arguments + cif_parser_args: + cache_dir: null + load_from_cache: False + save_to_cache: False + + # metadata parser + dataset_parser: + _target_: atomworks.ml.datasets.parsers.GenericDFParser + pn_unit_iid_colnames: null + + # metadata dataset + dataset: + _target_: atomworks.ml.datasets.PandasDataset + name: rna_monomer_distillation + id_column: example_id + data: /projects/ml/afavor/rna_distillation/rna_distillation_filtered_df.parquet + columns_to_load: + - example_id + - path + - cluster_id + - seq_hash + - overall_plddt + - overall_pde + - overall_pae + + transform: + crop_contiguous_probability: 0.67 + crop_spatial_probability: 0.33 + diff --git a/models/rfd3/configs/datasets/val/pseudoknot.yaml b/models/rfd3/configs/datasets/val/pseudoknot.yaml index 27b801c..7cf5ce1 100644 --- a/models/rfd3/configs/datasets/val/pseudoknot.yaml +++ b/models/rfd3/configs/datasets/val/pseudoknot.yaml @@ -1,9 +1,9 @@ defaults: - - unconditional + - design_validation_base - _self_ dataset: name: pseudoknot eval_every_n: 1 - data: /home/afavor/git/RFD3/modelhub/projects/aa_design/tests/test_data/pseudoknot.json + data: ${paths.data.design_benchmark_data_dir}/pseudoknot.json diff --git a/models/rfd3/configs/experiment/rfd3na-ss.yaml b/models/rfd3/configs/experiment/rfd3na-ss.yaml index 614f20f..6d2245c 100644 --- a/models/rfd3/configs/experiment/rfd3na-ss.yaml +++ b/models/rfd3/configs/experiment/rfd3na-ss.yaml @@ -5,6 +5,7 @@ defaults: - /debug/default - override /model: rfd3_base - override /logger: null + - override /datasets: design_base_rfd3na - _self_ name: rfd3na-SScond @@ -54,7 +55,7 @@ datasets: max_atoms_in_crop: 2560 # ~10x crop size. global_transform_args: association_scheme: atom23 - add_na_pair_features: true + #add_na_pair_features: true train_conditions: unconditional: frequency: 2.0 @@ -69,19 +70,19 @@ datasets: train: # These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data. pdb: - probability: 1.0 - # rna_monomer_distillation: - # probability: 1.0 + probability: 0.5 + rna_monomer_distillation: + probability: 0.5 - # val: - # pseudoknot: - # dataset: - # # eval_every_n: 10 - # eval_every_n: 2 + val: + pseudoknot: + dataset: + # eval_every_n: 10 + eval_every_n: 1 trainer: devices_per_node: 1 limit_train_batches: 10 limit_val_batches: 1 validate_every_n_epochs: 5 - prevalidate: false + prevalidate: true diff --git a/models/rfd3/src/rfd3/inference/input_parsing.py b/models/rfd3/src/rfd3/inference/input_parsing.py index b3117fc..17032e4 100644 --- a/models/rfd3/src/rfd3/inference/input_parsing.py +++ b/models/rfd3/src/rfd3/inference/input_parsing.py @@ -9,7 +9,7 @@ from os import PathLike from typing import Any, Dict, List, Optional, Union import numpy as np -from atomworks.constants import STANDARD_AA +from atomworks.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA from atomworks.io.parser import parse_atom_array # from atomworks.ml.datasets.datasets import BaseDataset @@ -119,7 +119,9 @@ class DesignInputSpecification(BaseModel): validate_assignment=False, str_strip_whitespace=True, str_min_length=1, - extra="forbid", + #extra="forbid", #################################################### + extra="allow" + ## for now allowing extra for rfd3na-ss purposes, can decide later ## ) # fmt: off # ======================================================================== @@ -494,6 +496,12 @@ class DesignInputSpecification(BaseModel): aa.is_motif_atom_with_fixed_seq[start:end] = np.full_like( is_bkbn, False, dtype=int ) + elif aa.res_name[start] in (STANDARD_DNA + STANDARD_RNA) and self.redesign_motif_sidechains: + is_bkbn = np.isin(aa.atom_name[start:end], backbone_atoms_RNA) + aa.is_motif_atom_with_fixed_coord[start:end] = is_bkbn.astype(int) + aa.is_motif_atom_with_fixed_seq[start:end] = np.full_like( + is_bkbn, False, dtype=int + ) # ... Apply selections on top apply_selections(start, end) @@ -509,17 +517,18 @@ class DesignInputSpecification(BaseModel): atom_array_input_annotated = copy.deepcopy(self.atom_array_input) ########## reorder NA atoms ########### - is_dna = np.isin(atom_array_input_annotated.res_name, ["DA", "DC", "DG", "DT"]) - is_rna = np.isin(atom_array_input_annotated.res_name, ["A", "C", "G", "U"]) - dna_array = atom_array_input_annotated[is_dna] - rna_array = atom_array_input_annotated[is_rna] + if exists(atom_array_input_annotated): + is_dna = np.isin(atom_array_input_annotated.res_name, ["DA", "DC", "DG", "DT"]) + is_rna = np.isin(atom_array_input_annotated.res_name, ["A", "C", "G", "U"]) + dna_array = atom_array_input_annotated[is_dna] + rna_array = atom_array_input_annotated[is_rna] - atom_array_input_annotated[is_dna] = reorder_atoms_per_residue( - dna_array, backbone_atoms_DNA - ) - atom_array_input_annotated[is_rna] = reorder_atoms_per_residue( - rna_array, backbone_atoms_RNA - ) + atom_array_input_annotated[is_dna] = reorder_atoms_per_residue( + dna_array, backbone_atoms_DNA + ) + atom_array_input_annotated[is_rna] = reorder_atoms_per_residue( + rna_array, backbone_atoms_RNA + ) ####################################### atom_array = self._build_init(atom_array_input_annotated) @@ -928,11 +937,11 @@ def create_diffused_residues(n, additional_annotations=None, polymer_type="P"): elif polymer_type == "R": res_name = "A" bb_len = len(backbone_atoms_RNA) - bb_atom_names = strip_list(backbone_atoms_RNA) + bb_atom_names = backbone_atoms_RNA elif polymer_type == "D": res_name = "DA" bb_len = len(backbone_atoms_DNA) - bb_atom_names = strip_list(backbone_atoms_DNA) + bb_atom_names = backbone_atoms_DNA else: raise ValueError( f"invalid polymer type detected: {polymer_type}, check contig!" @@ -955,12 +964,13 @@ def create_diffused_residues(n, additional_annotations=None, polymer_type="P"): for idx in range(1, n + 1) ] array = struc.array(atoms) - array.set_annotation("element", np.array(bb_elements * n, dtype=" 0.75 + ca_dists_protein = torch.norm(xyz_protein[:, 1:] - xyz_protein[:, :-1], dim=-1) + ca_dists_na = torch.norm(xyz_na[:, 1:] - xyz_na[:, :-1], dim=-1) + + deviation_protein = torch.abs(ca_dists_protein - self.standard_ca_dist) # B, (I-1) + deviation_na = torch.abs(ca_dists_na - self.standard_PP_dist) # B, (I-1) + is_chainbreak_protein = deviation_protein > 0.75 + is_chainbreak_na = deviation_na > 1 - o["max_ca_deviation"] = float(deviation.max(-1).values.mean()) - o["fraction_chainbreaks"] = float(is_chainbreak.float().mean(-1).mean()) - o["n_chainbreaks"] = float(is_chainbreak.float().sum(-1).mean()) + try: + o["max_ca_deviation_protein"] = float(deviation_protein.max(-1).values.mean()) + o["fraction_chainbreaks_protein"] = float(is_chainbreak_protein.float().mean(-1).mean()) + o["n_chainbreaks_protein"] = float(is_chainbreak_protein.float().sum(-1).mean()) + except: + print("No protein in this example, skipping protein chainbreak metrics") + + try: + o["max_ca_deviation_na"] = float(deviation_na.max(-1).values.mean()) + o["fraction_chainbreaks_na"] = float(is_chainbreak_na.float().mean(-1).mean()) + o["n_chainbreaks_na"] = float(is_chainbreak_na.float().sum(-1).mean()) + except: + print("No NA in this example, skipping NA chainbreak metrics") return o - class PPIMetrics(Metric): """PPI-specific metrics""" diff --git a/models/rfd3/src/rfd3/transforms/conditioning_base.py b/models/rfd3/src/rfd3/transforms/conditioning_base.py index 25b51a7..4050041 100644 --- a/models/rfd3/src/rfd3/transforms/conditioning_base.py +++ b/models/rfd3/src/rfd3/transforms/conditioning_base.py @@ -243,6 +243,10 @@ class SampleConditioningType(Transform): ) self.meta_conditioning_probabilities = meta_conditioning_probabilities self.train_conditions = train_conditions + + for item in self.train_conditions: + self.train_conditions[item].association_scheme = association_scheme + self.sequence_encoding = sequence_encoding self.association_scheme = association_scheme @@ -261,12 +265,15 @@ class SampleConditioningType(Transform): assert "conditions" in data, "Conditioning dict not initialized" def forward(self, data): + #for item in self.train_conditions: + # print(self.train_conditions[item].is_valid_for_example(data)) + valid_conditions = [ cond for cond in self.train_conditions.values() - if cond.frequency > 0 and cond.is_valid_for_example(data) + if cond.is_valid_for_example(data) and cond.frequency > 0 ] - + if len(valid_conditions) == 0: raise InvalidSampledConditionException("No valid condition was found.") @@ -280,8 +287,6 @@ class SampleConditioningType(Transform): i_cond = np.random.choice(np.arange(len(p_cond)), p=p_cond) cond = valid_conditions[i_cond] - cond.association_scheme = self.association_scheme - data["sampled_condition"] = cond data["sampled_condition_name"] = cond.name data["sampled_condition_cls"] = cond.__class__ diff --git a/models/rfd3/src/rfd3/transforms/design_transforms.py b/models/rfd3/src/rfd3/transforms/design_transforms.py index 36a8b43..8d8da60 100644 --- a/models/rfd3/src/rfd3/transforms/design_transforms.py +++ b/models/rfd3/src/rfd3/transforms/design_transforms.py @@ -71,8 +71,10 @@ class SubsampleToTypes(Transform): def __init__( self, allowed_types: list | str = ["is_protein"], + association_scheme: str = 'atom14' ): self.allowed_types = allowed_types + self.association_scheme = association_scheme if not self.allowed_types == "ALL": for k in allowed_types: if not k.startswith("is_"): @@ -104,7 +106,7 @@ class SubsampleToTypes(Transform): ) ) - if atom_array.is_protein.sum() == 0: + if self.association_scheme != 'atom23' and atom_array.is_protein.sum() == 0: raise ValueError( "No protein atoms found in the atom array. Example ID: {}".format( data.get("example_id", "unknown") diff --git a/models/rfd3/src/rfd3/transforms/na_geom.py b/models/rfd3/src/rfd3/transforms/na_geom.py index ebb076d..2cef782 100644 --- a/models/rfd3/src/rfd3/transforms/na_geom.py +++ b/models/rfd3/src/rfd3/transforms/na_geom.py @@ -119,7 +119,7 @@ class CalculateNucleicAcidGeomFeats(Transform): def __init__( self, is_inference, - add_nucleic_ss_feats: bool = True, + meta_conditioning_probabilities, p_is_nucleic_ss_example: float = 0.3, p_show_partial_feats: float = 0.5, @@ -136,9 +136,18 @@ class CalculateNucleicAcidGeomFeats(Transform): ): # Critical, must always have to know how to handle self.is_inference = is_inference - + if not self.is_inference: + ## relevant in training + self.sampling_prob = meta_conditioning_probabilities['calculate_NA_SS'] + else: + ## irrelevant in inference + self.sampling_prob = 0 # For sampling whether we add nucleic-ss features (extra t2d) - self.add_nucleic_ss_feats = add_nucleic_ss_feats + + # relevant in training + self.add_nucleic_ss_feats = (self.sampling_prob > 0) + ###### + self.p_canonical_bp_filter = p_canonical_bp_filter # enforce that bp labels are only canonical self.p_is_nucleic_ss_example = p_is_nucleic_ss_example self.nucleic_ss_min_shown = nucleic_ss_min_shown @@ -189,7 +198,7 @@ class CalculateNucleicAcidGeomFeats(Transform): n_tokens = len(token_starts) print(" DO I NEED TO CHANGE TO TOKEN_ID???") # Handle the training case with ground truth and masking: - if not self.is_inference: + if not self.is_inference and (np.random.rand() < self.sampling_prob): # First, annotate as usual # atom_array = annotate_na_ss(atom_array, **kwargs) @@ -226,8 +235,9 @@ class CalculateNucleicAcidGeomFeats(Transform): - 3). Lists of paired indices """ - is_nucleic_ss_example=True - give_partial_feats=False + #is_nucleic_ss_example=True + #give_partial_feats=False + atom_array = annotate_na_ss_from_data_specification( data, overwrite=True, @@ -282,4 +292,4 @@ class CalculateNucleicAcidGeomFeats(Transform): n_islands_max=self.n_islands_max, max_length=max_length, ) - \ No newline at end of file + diff --git a/models/rfd3/src/rfd3/transforms/na_geom_utils.py b/models/rfd3/src/rfd3/transforms/na_geom_utils.py index 61f8059..29be56f 100644 --- a/models/rfd3/src/rfd3/transforms/na_geom_utils.py +++ b/models/rfd3/src/rfd3/transforms/na_geom_utils.py @@ -1321,6 +1321,8 @@ def bp_partner_to_ss_matrix( *, feature_info: Optional[dict] = None, mol_info: Optional[NucMolInfo] = None, + NA_only: Optional[bool] = False, + planar_only: Optional[bool] = False, include_loops: bool = True, token_level_data: Optional[dict] = None, ) -> np.ndarray: diff --git a/models/rfd3/src/rfd3/transforms/pipelines.py b/models/rfd3/src/rfd3/transforms/pipelines.py index d4c3361..f7eb9a0 100644 --- a/models/rfd3/src/rfd3/transforms/pipelines.py +++ b/models/rfd3/src/rfd3/transforms/pipelines.py @@ -196,6 +196,7 @@ def get_crop_transform( max_binder_length: int, max_atoms_in_crop: int | None, allowed_types: List[str], + association_scheme: str, ): if ( crop_contiguous_probability > 0 @@ -215,7 +216,7 @@ def get_crop_transform( ), "Crop center cutoff distance must be greater than 0" pre_crop_transforms = [ - SubsampleToTypes(allowed_types=allowed_types), + SubsampleToTypes(allowed_types=allowed_types, association_scheme=association_scheme), ] cropping_transform = RandomRoute( @@ -360,8 +361,11 @@ def build_atom14_base_pipeline_( max_ss_frac_to_provide: float, min_ss_island_len: int, max_ss_island_len: int, - # Nucleic acid features - add_na_pair_features: bool, + + ## Nucleic acid features ##### + #add_na_pair_features: bool, + ## This should not be necessary, controlled through feature names in model, and meta conditioning probabilities, inference behavior handled in transform itself ##### + **_, # dump additional kwargs (e.g. msa stuff) ): """ @@ -411,6 +415,7 @@ def build_atom14_base_pipeline_( max_binder_length=max_binder_length, max_atoms_in_crop=max_atoms_in_crop, allowed_types=allowed_types, + association_scheme=association_scheme ) if zero_occ_on_exposure_after_cropping: @@ -443,14 +448,15 @@ def build_atom14_base_pipeline_( ) ) # Add nucleic acid geometry features - if add_na_pair_features: - transforms.append( - CalculateNucleicAcidGeomFeats( - is_inference, - NA_only=False, - planar_only=True, - ) + #if add_na_pair_features: + transforms.append( + CalculateNucleicAcidGeomFeats( + is_inference, + meta_conditioning_probabilities, + NA_only=False, + planar_only=True, ) + ) # Design Transforms transforms += [ @@ -618,7 +624,6 @@ def build_atom14_base_pipeline( Wrapper around pipeline construction to handle empty training args Sets default behaviour for inference to keep backward compatibility """ - if is_inference: # Provide explicit defaults for training-only args kwargs.setdefault("crop_size", 512) @@ -634,7 +639,8 @@ def build_atom14_base_pipeline( kwargs.setdefault("min_ss_island_len", 0) kwargs.setdefault("max_ss_island_len", 999) kwargs.setdefault("max_binder_length", 999) - kwargs.setdefault("add_na_pair_features", False) + # This should not be necessary. + #kwargs.setdefault("add_na_pair_features", False) kwargs.setdefault("b_factor_min", None) kwargs.setdefault("zero_occ_on_exposure_after_cropping", False) diff --git a/models/rfd3/src/rfd3/transforms/training_conditions.py b/models/rfd3/src/rfd3/transforms/training_conditions.py index dbb07b7..fae2768 100644 --- a/models/rfd3/src/rfd3/transforms/training_conditions.py +++ b/models/rfd3/src/rfd3/transforms/training_conditions.py @@ -96,13 +96,13 @@ class IslandCondition(TrainingCondition): is_rna = data["atom_array"].is_rna ### updating this to allow other polymers if self.association_scheme == "atom23": - if not np.any(is_protein | is_dna | is_rna): - return False + if np.any(is_protein | is_dna | is_rna): + return True else: - if not np.any(is_protein): - return False - - return True + if np.any(is_protein): + return True + + return False def sample_motif_tokens(self, atom_array): """ diff --git a/models/rfd3/src/rfd3/transforms/util_transforms.py b/models/rfd3/src/rfd3/transforms/util_transforms.py index 75d69dd..37f08f1 100644 --- a/models/rfd3/src/rfd3/transforms/util_transforms.py +++ b/models/rfd3/src/rfd3/transforms/util_transforms.py @@ -37,7 +37,7 @@ def assert_single_representative(token, central_atom="CB"): mask = get_af3_token_representative_masks(token, central_atom=central_atom) assert ( np.sum(mask) == 1 - ), f"No representative atom (CB) found. mask: {mask}\nToken: {token}" + ), f"No representative atom ({central_atom}) found. mask: {mask}\nToken: {token}" def assert_single_token(token): diff --git a/models/rfd3/src/rfd3/transforms/virtual_atoms.py b/models/rfd3/src/rfd3/transforms/virtual_atoms.py index 6ce1675..07db4e1 100644 --- a/models/rfd3/src/rfd3/transforms/virtual_atoms.py +++ b/models/rfd3/src/rfd3/transforms/virtual_atoms.py @@ -225,17 +225,21 @@ class PadTokensWithVirtualAtoms(Transform): # First, pad with virtual atoms if needed if self.association_scheme == "atom23" and atom_array[start].is_dna: n_atoms_per_token = 22 + central_atom = "C1'" elif self.association_scheme == "atom23" and atom_array[start].is_rna: n_atoms_per_token = 23 + central_atom = "C1'" else: n_atoms_per_token = self.n_atoms_per_token + central_atom = self.atom_to_pad_from + n_pad = n_atoms_per_token - len(token) if n_pad > 0: mask = get_af3_token_representative_masks( - token, central_atom=self.atom_to_pad_from + token, central_atom=central_atom ) - assert_single_representative(token) + assert_single_representative(token, central_atom=central_atom) # ... Create virtual atoms pad_atoms = token[mask].copy()