diff --git a/models/rfd3/configs/experiment/rfd3na_af.yaml b/models/rfd3/configs/experiment/rfd3na_af.yaml new file mode 100644 index 0000000..e9331b7 --- /dev/null +++ b/models/rfd3/configs/experiment/rfd3na_af.yaml @@ -0,0 +1,99 @@ +# @package _global_ +# Training configuration for RFD3 + +defaults: + - /debug/default + - override /model: rfd3_base + # - override /logger: wandb + - override /datasets: design_base_rfd3na + - _self_ + +name: rfd3na-fine-tune +tags: [print-model] +ckpt_path: /projects/ml/aa_design/models/rfd3_latest_foundry.ckpt + +model: + net: + token_initializer: + token_1d_features: + ref_motif_token_type: 3 + restype: 32 + is_dna_token: 1 + is_rna_token: 1 + is_protein_token: 1 + token_2d_features: + bp_partners: 3 # Unspecified, pair, loop + atom_1d_features: + ref_atom_name_chars: 256 + ref_element: 128 + ref_charge: 1 + ref_mask: 1 + ref_is_motif_atom_with_fixed_coord: 1 + ref_is_motif_atom_unindexed: 1 + has_zero_occupancy: 1 + ref_pos: 3 + + # Guided features + ref_atomwise_rasa: 3 + active_donor: 1 + active_acceptor: 1 + is_atom_level_hotspot: 1 + diffusion_module: + n_recycle: 2 + use_local_token_attention: True + diffusion_transformer: + n_local_tokens: 32 + n_keys: 128 + + inference_sampler: + num_timesteps: 100 + + +datasets: + diffusion_batch_size_train: 16 + crop_size: 256 + max_atoms_in_crop: 2560 # ~10x crop size. 
+ global_transform_args: + meta_conditioning_probabilities: + # p_is_nucleic_ss_example: 0.25 + # p_nucleic_ss_show_partial_feats: 0.7 + # p_canonical_bp_filter: 0.2 + p_is_nucleic_ss_example: 1.0 + p_nucleic_ss_show_partial_feats: 0.0 + p_canonical_bp_filter: 0.0 + #calculate_NA_SS: 0.3 + + association_scheme: atom23 + #add_na_pair_features: true + train_conditions: + unconditional: + frequency: 2.0 + island: + frequency: 2.0 + sequence_design: + frequency: 0.5 + tipatom: + frequency: 5.0 + ppi: + frequency: 0.0 + train: + # NOTE: the usual default is all-pdb sampling (not everyone downloads the distillation data); this fine-tuning config overrides it to sample only from rna_monomer_distillation (pdb and monomer_distillation are set to 0.0 below). + pdb: + probability: 0.0 + rna_monomer_distillation: + probability: 1.0 + monomer_distillation: + probability: 0.0 + + val: + pseudoknot: + dataset: + # eval_every_n: 10 + eval_every_n: 5 + +trainer: + #devices_per_node: 1 + #limit_train_batches: 10 + #limit_val_batches: 1 + validate_every_n_epochs: 5 + prevalidate: true diff --git a/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py b/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py index cc7633e..345c26b 100644 --- a/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py +++ b/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py @@ -230,8 +230,24 @@ class NucleicSSSimilarityMetrics(Metric): # prediction can inherit it, yielding artificially perfect scores. # Optionally recompute bp_partners from the *predicted coordinates*. 
if self.annotate_predicted_fresh: - #pred_arr = _cleanup_virtual_atoms_and_assign_atom_name_elements(pred_arr, association_scheme = "atom23") - pred_arr = _readout_seq_from_struc(pred_arr, central_atom="C1'", threshold=0.5, association_scheme = "atom23") + + # Infer res name from geometry first + pred_arr = _readout_seq_from_struc( + pred_arr, + central_atom="C1'", + threshold=0.5, + association_scheme="atom23", + ) + # strip virtuals and set final atom names/elements + pred_arr = _cleanup_virtual_atoms_and_assign_atom_name_elements( + pred_arr, + association_scheme="atom23", + ) + # clear annotation to avoid potential info leak + if "bp_partners" in pred_arr.get_annotation_categories(): + pred_arr.del_annotation("bp_partners") + + # add nucleic-ss annotations annotate_na_ss( pred_arr, NA_only=self.annotation_NA_only, @@ -239,7 +255,7 @@ class NucleicSSSimilarityMetrics(Metric): overwrite=True, p_canonical_bp_filter=0.0, ) - import pdb; pdb.set_trace()#TODO + pred_categories = pred_arr.get_annotation_categories() if "bp_partners" not in pred_categories: continue @@ -304,7 +320,13 @@ class NucleicSSSimilarityMetrics(Metric): if n_valid == 0: return {} - + aaa = { + "pair_f1": float(np.mean(pair_f1_list)), + "loop_f1": float(np.mean(loop_f1_list)), + "weighted_f1": float(np.mean(weighted_f1_list)), + "n_valid_samples": int(n_valid), + } + print(aaa) return { "pair_f1": float(np.mean(pair_f1_list)), "loop_f1": float(np.mean(loop_f1_list)),