diff --git a/models/rfd3/configs/datasets/val/pseudoknot.yaml b/models/rfd3/configs/datasets/val/pseudoknot.yaml index 895e6c2..d3fa6f7 100644 --- a/models/rfd3/configs/datasets/val/pseudoknot.yaml +++ b/models/rfd3/configs/datasets/val/pseudoknot.yaml @@ -6,5 +6,5 @@ defaults: dataset: name: pseudoknot eval_every_n: 1 - # data: ${paths.data.design_benchmark_data_dir}/pseudoknot_debug.json - data: ${paths.data.design_benchmark_data_dir}/pseudoknot.json + data: ${paths.data.design_benchmark_data_dir}/pseudoknot_debug.json + #data: ${paths.data.design_benchmark_data_dir}/pseudoknot.json diff --git a/models/rfd3/configs/experiment/rfd3na.yaml b/models/rfd3/configs/experiment/rfd3na.yaml index 32e4382..e1194be 100644 --- a/models/rfd3/configs/experiment/rfd3na.yaml +++ b/models/rfd3/configs/experiment/rfd3na.yaml @@ -2,15 +2,18 @@ # Training configuration for RFD3 defaults: - #- /debug/default + - /debug/default - override /model: rfd3_base - - override /logger: wandb + #- override /logger: wandb - override /datasets: design_base_rfd3na - _self_ -name: rfd3na_scratch +name: rfd3na_scratch_clean_test tags: [print-model] -ckpt_path: null + +#ckpt_path: /net/scratch/raktim/training/logs/train/rfd3na-fine-tune/2026-02-17_15-21_JOB_3608285/ckpt/epoch-0590.ckpt +#ckpt_path: /net/scratch/raktim/training/logs/train/rfd3na_scratch/2026-02-17_17-56_JOB_3620867/ckpt/epoch-0030.ckpt + model: net: @@ -55,7 +58,7 @@ datasets: max_atoms_in_crop: 2560 # ~10x crop size. global_transform_args: meta_conditioning_probabilities: - p_is_nucleic_ss_example: 0.25 + p_is_nucleic_ss_example: 1.0 p_nucleic_ss_show_partial_feats: 0.7 p_canonical_bp_filter: 0.2 #calculate_NA_SS: 0.3 diff --git a/models/rfd3/configs/experiment/rfd3na_fine_tune.yaml b/models/rfd3/configs/experiment/rfd3na_fine_tune.yaml new file mode 100644 index 0000000..e731063 --- /dev/null +++ b/models/rfd3/configs/experiment/rfd3na_fine_tune.yaml @@ -0,0 +1,96 @@ +# @package _global_ +# Training configuration for RFD3 + +defaults: + #- /debug/default + - override /model: rfd3_base + - override /logger: wandb + - override /datasets: design_base_rfd3na + - _self_ + +name: rfd3na-fine-tune +tags: [print-model] +ckpt_path: /projects/ml/aa_design/models/rfd3_latest_foundry.ckpt + +model: + net: + token_initializer: + token_1d_features: + ref_motif_token_type: 3 + restype: 32 + is_dna_token: 1 + is_rna_token: 1 + is_protein_token: 1 + token_2d_features: + bp_partners: 3 # Unspecified, pair, loop + atom_1d_features: + ref_atom_name_chars: 256 + ref_element: 128 + ref_charge: 1 + ref_mask: 1 + ref_is_motif_atom_with_fixed_coord: 1 + ref_is_motif_atom_unindexed: 1 + has_zero_occupancy: 1 + ref_pos: 3 + + # Guided features + ref_atomwise_rasa: 3 + active_donor: 1 + active_acceptor: 1 + is_atom_level_hotspot: 1 + diffusion_module: + n_recycle: 2 + use_local_token_attention: True + diffusion_transformer: + n_local_tokens: 32 + n_keys: 128 + + inference_sampler: + num_timesteps: 100 + + +datasets: + diffusion_batch_size_train: 16 + crop_size: 256 + max_atoms_in_crop: 2560 # ~10x crop size. + global_transform_args: + meta_conditioning_probabilities: + p_is_nucleic_ss_example: 0.25 + p_nucleic_ss_show_partial_feats: 0.7 + p_canonical_bp_filter: 0.2 + #calculate_NA_SS: 0.3 + + association_scheme: atom23 + #add_na_pair_features: true + train_conditions: + unconditional: + frequency: 2.0 + island: + frequency: 2.0 + sequence_design: + frequency: 0.5 + tipatom: + frequency: 5.0 + ppi: + frequency: 0.0 + train: + # These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data. + pdb: + probability: 0.5 + rna_monomer_distillation: + probability: 0.3 + monomer_distillation: + probability: 0.2 + + val: + pseudoknot: + dataset: + # eval_every_n: 10 + eval_every_n: 5 + +trainer: + #devices_per_node: 1 + #limit_train_batches: 10 + #limit_val_batches: 1 + validate_every_n_epochs: 5 + prevalidate: true diff --git a/models/rfd3/configs/experiment/rfd3na_no_distill.yaml b/models/rfd3/configs/experiment/rfd3na_no_distill.yaml new file mode 100644 index 0000000..e3642cc --- /dev/null +++ b/models/rfd3/configs/experiment/rfd3na_no_distill.yaml @@ -0,0 +1,96 @@ +# @package _global_ +# Training configuration for RFD3 + +defaults: + #- /debug/default + - override /model: rfd3_base + - override /logger: wandb + - override /datasets: design_base_rfd3na + - _self_ + +name: rfd3na_no_distill +tags: [print-model] +ckpt_path: /net/scratch/raktim/training/logs/train/rfd3na_no_distill/2026-02-17_15-21_JOB_3608348/ckpt/epoch-0020.ckpt + +model: + net: + token_initializer: + token_1d_features: + ref_motif_token_type: 3 + restype: 32 + is_dna_token: 1 + is_rna_token: 1 + is_protein_token: 1 + token_2d_features: + bp_partners: 3 # Unspecified, pair, loop + atom_1d_features: + ref_atom_name_chars: 256 + ref_element: 128 + ref_charge: 1 + ref_mask: 1 + ref_is_motif_atom_with_fixed_coord: 1 + ref_is_motif_atom_unindexed: 1 + has_zero_occupancy: 1 + ref_pos: 3 + + # Guided features + ref_atomwise_rasa: 3 + active_donor: 1 + active_acceptor: 1 + is_atom_level_hotspot: 1 + diffusion_module: + n_recycle: 2 + use_local_token_attention: True + diffusion_transformer: + n_local_tokens: 32 + n_keys: 128 + + inference_sampler: + num_timesteps: 100 + + +datasets: + diffusion_batch_size_train: 16 + crop_size: 256 + max_atoms_in_crop: 2560 # ~10x crop size. + global_transform_args: + meta_conditioning_probabilities: + p_is_nucleic_ss_example: 0.25 + p_nucleic_ss_show_partial_feats: 0.7 + p_canonical_bp_filter: 0.2 + #calculate_NA_SS: 0.3 + + association_scheme: atom23 + #add_na_pair_features: true + train_conditions: + unconditional: + frequency: 2.0 + island: + frequency: 2.0 + sequence_design: + frequency: 0.5 + tipatom: + frequency: 5.0 + ppi: + frequency: 0.0 + train: + # These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data. + pdb: + probability: 0.6 + rna_monomer_distillation: + probability: 0.4 + monomer_distillation: + probability: 0.0 + + val: + pseudoknot: + dataset: + # eval_every_n: 10 + eval_every_n: 5 + +trainer: + #devices_per_node: 1 + #limit_train_batches: 10 + #limit_val_batches: 1 + validate_every_n_epochs: 5 + prevalidate: true diff --git a/models/rfd3/configs/experiment/rfd3na_rm.yaml b/models/rfd3/configs/experiment/rfd3na_rm.yaml new file mode 100644 index 0000000..1596cb1 --- /dev/null +++ b/models/rfd3/configs/experiment/rfd3na_rm.yaml @@ -0,0 +1,90 @@ +# @package _global_ +# Training configuration for RFD3 + +defaults: + #- /debug/default + - override /model: rfd3_base + - override /logger: wandb + - override /datasets: design_base_rfd3na + - _self_ + +name: rfd3na +tags: [print-model] +ckpt_path: null + +model: + net: + token_initializer: + token_1d_features: + ref_motif_token_type: 3 + restype: 32 + is_dna_token: 1 + is_rna_token: 1 + is_protein_token: 1 + #token_2d_features: + #bp_partners: 3 # Unspecified, pair, loop + atom_1d_features: + ref_atom_name_chars: 256 + ref_element: 128 + ref_charge: 1 + ref_mask: 1 + ref_is_motif_atom_with_fixed_coord: 1 + ref_is_motif_atom_unindexed: 1 + has_zero_occupancy: 1 + ref_pos: 3 + + # Guided features + ref_atomwise_rasa: 3 + active_donor: 1 + active_acceptor: 1 + is_atom_level_hotspot: 1 + diffusion_module: + n_recycle: 2 + use_local_token_attention: True + diffusion_transformer: + n_local_tokens: 32 + n_keys: 128 + + inference_sampler: + num_timesteps: 100 + + +datasets: + diffusion_batch_size_train: 16 + crop_size: 256 + max_atoms_in_crop: 2560 # ~10x crop size. + global_transform_args: + meta_conditioning_probabilities: + calculate_NA_SS: 0.0 + association_scheme: atom23 + #add_na_pair_features: true + train_conditions: + unconditional: + frequency: 2.0 + island: + frequency: 2.0 + sequence_design: + frequency: 0.5 + tipatom: + frequency: 5.0 + ppi: + frequency: 0.0 + train: + # These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data. + pdb: + probability: 0.5 + rna_monomer_distillation: + probability: 0.5 + + val: + pseudoknot: + dataset: + # eval_every_n: 10 + eval_every_n: 5 + +trainer: + #devices_per_node: 1 + #limit_train_batches: 10 + limit_val_batches: 1 + validate_every_n_epochs: 5 + prevalidate: false diff --git a/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py b/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py index 62d077f..cc7633e 100644 --- a/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py +++ b/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py @@ -12,6 +12,8 @@ from rfd3.transforms.na_geom_utils import annotate_na_ss from foundry.metrics.metric import Metric from foundry.utils.ddp import RankedLogger +from rfd3.trainer.trainer_utils import _readout_seq_from_struc, _cleanup_virtual_atoms_and_assign_atom_name_elements + logging.basicConfig(level=logging.INFO) global_logger = RankedLogger(__name__, rank_zero_only=False) @@ -228,6 +230,8 @@ class NucleicSSSimilarityMetrics(Metric): # prediction can inherit it, yielding artificially perfect scores. # Optionally recompute bp_partners from the *predicted coordinates*. if self.annotate_predicted_fresh: + #pred_arr = _cleanup_virtual_atoms_and_assign_atom_name_elements(pred_arr, association_scheme = "atom23") + pred_arr = _readout_seq_from_struc(pred_arr, central_atom="C1'", threshold=0.5, association_scheme = "atom23") annotate_na_ss( pred_arr, NA_only=self.annotation_NA_only, @@ -235,7 +239,7 @@ class NucleicSSSimilarityMetrics(Metric): overwrite=True, p_canonical_bp_filter=0.0, ) - + import pdb; pdb.set_trace()#TODO pred_categories = pred_arr.get_annotation_categories() if "bp_partners" not in pred_categories: continue diff --git a/models/rfd3/src/rfd3/trainer/trainer_utils.py b/models/rfd3/src/rfd3/trainer/trainer_utils.py index 652665a..436d80e 100644 --- a/models/rfd3/src/rfd3/trainer/trainer_utils.py +++ b/models/rfd3/src/rfd3/trainer/trainer_utils.py @@ -152,7 +152,6 @@ def _cleanup_virtual_atoms_and_assign_atom_name_elements( is_seq_known = all( np.array(res_array.is_motif_atom_with_fixed_seq, dtype=bool) ) or all(np.array(res_array.is_motif_atom_unindexed, dtype=bool)) - # ... If sequence is known for the original atom array, just skip if is_seq_known: ret_mask += [True] * len(res_array)