mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
time to fix ss metric line 242
This commit is contained in:
@@ -6,5 +6,5 @@ defaults:
|
||||
dataset:
|
||||
name: pseudoknot
|
||||
eval_every_n: 1
|
||||
# data: ${paths.data.design_benchmark_data_dir}/pseudoknot_debug.json
|
||||
data: ${paths.data.design_benchmark_data_dir}/pseudoknot.json
|
||||
data: ${paths.data.design_benchmark_data_dir}/pseudoknot_debug.json
|
||||
#data: ${paths.data.design_benchmark_data_dir}/pseudoknot.json
|
||||
|
||||
@@ -2,15 +2,18 @@
|
||||
# Training configuration for RFD3
|
||||
|
||||
defaults:
|
||||
#- /debug/default
|
||||
- /debug/default
|
||||
- override /model: rfd3_base
|
||||
- override /logger: wandb
|
||||
#- override /logger: wandb
|
||||
- override /datasets: design_base_rfd3na
|
||||
- _self_
|
||||
|
||||
name: rfd3na_scratch
|
||||
name: rfd3na_scratch_clean_test
|
||||
tags: [print-model]
|
||||
ckpt_path: null
|
||||
|
||||
#ckpt_path: /net/scratch/raktim/training/logs/train/rfd3na-fine-tune/2026-02-17_15-21_JOB_3608285/ckpt/epoch-0590.ckpt
|
||||
#ckpt_path: /net/scratch/raktim/training/logs/train/rfd3na_scratch/2026-02-17_17-56_JOB_3620867/ckpt/epoch-0030.ckpt
|
||||
|
||||
|
||||
model:
|
||||
net:
|
||||
@@ -55,7 +58,7 @@ datasets:
|
||||
max_atoms_in_crop: 2560 # ~10x crop size.
|
||||
global_transform_args:
|
||||
meta_conditioning_probabilities:
|
||||
p_is_nucleic_ss_example: 0.25
|
||||
p_is_nucleic_ss_example: 1.0
|
||||
p_nucleic_ss_show_partial_feats: 0.7
|
||||
p_canonical_bp_filter: 0.2
|
||||
#calculate_NA_SS: 0.3
|
||||
|
||||
96
models/rfd3/configs/experiment/rfd3na_fine_tune.yaml
Normal file
96
models/rfd3/configs/experiment/rfd3na_fine_tune.yaml
Normal file
@@ -0,0 +1,96 @@
|
||||
# @package _global_
|
||||
# Training configuration for RFD3
|
||||
|
||||
defaults:
|
||||
#- /debug/default
|
||||
- override /model: rfd3_base
|
||||
- override /logger: wandb
|
||||
- override /datasets: design_base_rfd3na
|
||||
- _self_
|
||||
|
||||
name: rfd3na-fine-tune
|
||||
tags: [print-model]
|
||||
ckpt_path: /projects/ml/aa_design/models/rfd3_latest_foundry.ckpt
|
||||
|
||||
model:
|
||||
net:
|
||||
token_initializer:
|
||||
token_1d_features:
|
||||
ref_motif_token_type: 3
|
||||
restype: 32
|
||||
is_dna_token: 1
|
||||
is_rna_token: 1
|
||||
is_protein_token: 1
|
||||
token_2d_features:
|
||||
bp_partners: 3 # Unspecified, pair, loop
|
||||
atom_1d_features:
|
||||
ref_atom_name_chars: 256
|
||||
ref_element: 128
|
||||
ref_charge: 1
|
||||
ref_mask: 1
|
||||
ref_is_motif_atom_with_fixed_coord: 1
|
||||
ref_is_motif_atom_unindexed: 1
|
||||
has_zero_occupancy: 1
|
||||
ref_pos: 3
|
||||
|
||||
# Guided features
|
||||
ref_atomwise_rasa: 3
|
||||
active_donor: 1
|
||||
active_acceptor: 1
|
||||
is_atom_level_hotspot: 1
|
||||
diffusion_module:
|
||||
n_recycle: 2
|
||||
use_local_token_attention: True
|
||||
diffusion_transformer:
|
||||
n_local_tokens: 32
|
||||
n_keys: 128
|
||||
|
||||
inference_sampler:
|
||||
num_timesteps: 100
|
||||
|
||||
|
||||
datasets:
|
||||
diffusion_batch_size_train: 16
|
||||
crop_size: 256
|
||||
max_atoms_in_crop: 2560 # ~10x crop size.
|
||||
global_transform_args:
|
||||
meta_conditioning_probabilities:
|
||||
p_is_nucleic_ss_example: 0.25
|
||||
p_nucleic_ss_show_partial_feats: 0.7
|
||||
p_canonical_bp_filter: 0.2
|
||||
#calculate_NA_SS: 0.3
|
||||
|
||||
association_scheme: atom23
|
||||
#add_na_pair_features: true
|
||||
train_conditions:
|
||||
unconditional:
|
||||
frequency: 2.0
|
||||
island:
|
||||
frequency: 2.0
|
||||
sequence_design:
|
||||
frequency: 0.5
|
||||
tipatom:
|
||||
frequency: 5.0
|
||||
ppi:
|
||||
frequency: 0.0
|
||||
train:
|
||||
# These are the sampling ratios used in the preprint. By default, however, all sampling is from the PDB, since not everyone will have downloaded the distillation data.
|
||||
pdb:
|
||||
probability: 0.5
|
||||
rna_monomer_distillation:
|
||||
probability: 0.3
|
||||
monomer_distillation:
|
||||
probability: 0.2
|
||||
|
||||
val:
|
||||
pseudoknot:
|
||||
dataset:
|
||||
# eval_every_n: 10
|
||||
eval_every_n: 5
|
||||
|
||||
trainer:
|
||||
#devices_per_node: 1
|
||||
#limit_train_batches: 10
|
||||
#limit_val_batches: 1
|
||||
validate_every_n_epochs: 5
|
||||
prevalidate: true
|
||||
96
models/rfd3/configs/experiment/rfd3na_no_distill.yaml
Normal file
96
models/rfd3/configs/experiment/rfd3na_no_distill.yaml
Normal file
@@ -0,0 +1,96 @@
|
||||
# @package _global_
|
||||
# Training configuration for RFD3
|
||||
|
||||
defaults:
|
||||
#- /debug/default
|
||||
- override /model: rfd3_base
|
||||
- override /logger: wandb
|
||||
- override /datasets: design_base_rfd3na
|
||||
- _self_
|
||||
|
||||
name: rfd3na_no_distill
|
||||
tags: [print-model]
|
||||
ckpt_path: /net/scratch/raktim/training/logs/train/rfd3na_no_distill/2026-02-17_15-21_JOB_3608348/ckpt/epoch-0020.ckpt
|
||||
|
||||
model:
|
||||
net:
|
||||
token_initializer:
|
||||
token_1d_features:
|
||||
ref_motif_token_type: 3
|
||||
restype: 32
|
||||
is_dna_token: 1
|
||||
is_rna_token: 1
|
||||
is_protein_token: 1
|
||||
token_2d_features:
|
||||
bp_partners: 3 # Unspecified, pair, loop
|
||||
atom_1d_features:
|
||||
ref_atom_name_chars: 256
|
||||
ref_element: 128
|
||||
ref_charge: 1
|
||||
ref_mask: 1
|
||||
ref_is_motif_atom_with_fixed_coord: 1
|
||||
ref_is_motif_atom_unindexed: 1
|
||||
has_zero_occupancy: 1
|
||||
ref_pos: 3
|
||||
|
||||
# Guided features
|
||||
ref_atomwise_rasa: 3
|
||||
active_donor: 1
|
||||
active_acceptor: 1
|
||||
is_atom_level_hotspot: 1
|
||||
diffusion_module:
|
||||
n_recycle: 2
|
||||
use_local_token_attention: True
|
||||
diffusion_transformer:
|
||||
n_local_tokens: 32
|
||||
n_keys: 128
|
||||
|
||||
inference_sampler:
|
||||
num_timesteps: 100
|
||||
|
||||
|
||||
datasets:
|
||||
diffusion_batch_size_train: 16
|
||||
crop_size: 256
|
||||
max_atoms_in_crop: 2560 # ~10x crop size.
|
||||
global_transform_args:
|
||||
meta_conditioning_probabilities:
|
||||
p_is_nucleic_ss_example: 0.25
|
||||
p_nucleic_ss_show_partial_feats: 0.7
|
||||
p_canonical_bp_filter: 0.2
|
||||
#calculate_NA_SS: 0.3
|
||||
|
||||
association_scheme: atom23
|
||||
#add_na_pair_features: true
|
||||
train_conditions:
|
||||
unconditional:
|
||||
frequency: 2.0
|
||||
island:
|
||||
frequency: 2.0
|
||||
sequence_design:
|
||||
frequency: 0.5
|
||||
tipatom:
|
||||
frequency: 5.0
|
||||
ppi:
|
||||
frequency: 0.0
|
||||
train:
|
||||
# These are the sampling ratios used in the preprint. By default, however, all sampling is from the PDB, since not everyone will have downloaded the distillation data.
|
||||
pdb:
|
||||
probability: 0.6
|
||||
rna_monomer_distillation:
|
||||
probability: 0.4
|
||||
monomer_distillation:
|
||||
probability: 0.0
|
||||
|
||||
val:
|
||||
pseudoknot:
|
||||
dataset:
|
||||
# eval_every_n: 10
|
||||
eval_every_n: 5
|
||||
|
||||
trainer:
|
||||
#devices_per_node: 1
|
||||
#limit_train_batches: 10
|
||||
#limit_val_batches: 1
|
||||
validate_every_n_epochs: 5
|
||||
prevalidate: true
|
||||
90
models/rfd3/configs/experiment/rfd3na_rm.yaml
Normal file
90
models/rfd3/configs/experiment/rfd3na_rm.yaml
Normal file
@@ -0,0 +1,90 @@
|
||||
# @package _global_
|
||||
# Training configuration for RFD3
|
||||
|
||||
defaults:
|
||||
#- /debug/default
|
||||
- override /model: rfd3_base
|
||||
- override /logger: wandb
|
||||
- override /datasets: design_base_rfd3na
|
||||
- _self_
|
||||
|
||||
name: rfd3na
|
||||
tags: [print-model]
|
||||
ckpt_path: null
|
||||
|
||||
model:
|
||||
net:
|
||||
token_initializer:
|
||||
token_1d_features:
|
||||
ref_motif_token_type: 3
|
||||
restype: 32
|
||||
is_dna_token: 1
|
||||
is_rna_token: 1
|
||||
is_protein_token: 1
|
||||
#token_2d_features:
|
||||
#bp_partners: 3 # Unspecified, pair, loop
|
||||
atom_1d_features:
|
||||
ref_atom_name_chars: 256
|
||||
ref_element: 128
|
||||
ref_charge: 1
|
||||
ref_mask: 1
|
||||
ref_is_motif_atom_with_fixed_coord: 1
|
||||
ref_is_motif_atom_unindexed: 1
|
||||
has_zero_occupancy: 1
|
||||
ref_pos: 3
|
||||
|
||||
# Guided features
|
||||
ref_atomwise_rasa: 3
|
||||
active_donor: 1
|
||||
active_acceptor: 1
|
||||
is_atom_level_hotspot: 1
|
||||
diffusion_module:
|
||||
n_recycle: 2
|
||||
use_local_token_attention: True
|
||||
diffusion_transformer:
|
||||
n_local_tokens: 32
|
||||
n_keys: 128
|
||||
|
||||
inference_sampler:
|
||||
num_timesteps: 100
|
||||
|
||||
|
||||
datasets:
|
||||
diffusion_batch_size_train: 16
|
||||
crop_size: 256
|
||||
max_atoms_in_crop: 2560 # ~10x crop size.
|
||||
global_transform_args:
|
||||
meta_conditioning_probabilities:
|
||||
calculate_NA_SS: 0.0
|
||||
association_scheme: atom23
|
||||
#add_na_pair_features: true
|
||||
train_conditions:
|
||||
unconditional:
|
||||
frequency: 2.0
|
||||
island:
|
||||
frequency: 2.0
|
||||
sequence_design:
|
||||
frequency: 0.5
|
||||
tipatom:
|
||||
frequency: 5.0
|
||||
ppi:
|
||||
frequency: 0.0
|
||||
train:
|
||||
# These are the sampling ratios used in the preprint. By default, however, all sampling is from the PDB, since not everyone will have downloaded the distillation data.
|
||||
pdb:
|
||||
probability: 0.5
|
||||
rna_monomer_distillation:
|
||||
probability: 0.5
|
||||
|
||||
val:
|
||||
pseudoknot:
|
||||
dataset:
|
||||
# eval_every_n: 10
|
||||
eval_every_n: 5
|
||||
|
||||
trainer:
|
||||
#devices_per_node: 1
|
||||
#limit_train_batches: 10
|
||||
limit_val_batches: 1
|
||||
validate_every_n_epochs: 5
|
||||
prevalidate: false
|
||||
@@ -12,6 +12,8 @@ from rfd3.transforms.na_geom_utils import annotate_na_ss
|
||||
from foundry.metrics.metric import Metric
|
||||
from foundry.utils.ddp import RankedLogger
|
||||
|
||||
from rfd3.trainer.trainer_utils import _readout_seq_from_struc, _cleanup_virtual_atoms_and_assign_atom_name_elements
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
global_logger = RankedLogger(__name__, rank_zero_only=False)
|
||||
|
||||
@@ -228,6 +230,8 @@ class NucleicSSSimilarityMetrics(Metric):
|
||||
# prediction can inherit it, yielding artificially perfect scores.
|
||||
# Optionally recompute bp_partners from the *predicted coordinates*.
|
||||
if self.annotate_predicted_fresh:
|
||||
#pred_arr = _cleanup_virtual_atoms_and_assign_atom_name_elements(pred_arr, association_scheme = "atom23")
|
||||
pred_arr = _readout_seq_from_struc(pred_arr, central_atom="C1'", threshold=0.5, association_scheme = "atom23")
|
||||
annotate_na_ss(
|
||||
pred_arr,
|
||||
NA_only=self.annotation_NA_only,
|
||||
@@ -235,7 +239,7 @@ class NucleicSSSimilarityMetrics(Metric):
|
||||
overwrite=True,
|
||||
p_canonical_bp_filter=0.0,
|
||||
)
|
||||
|
||||
import pdb; pdb.set_trace()#TODO
|
||||
pred_categories = pred_arr.get_annotation_categories()
|
||||
if "bp_partners" not in pred_categories:
|
||||
continue
|
||||
|
||||
@@ -152,7 +152,6 @@ def _cleanup_virtual_atoms_and_assign_atom_name_elements(
|
||||
is_seq_known = all(
|
||||
np.array(res_array.is_motif_atom_with_fixed_seq, dtype=bool)
|
||||
) or all(np.array(res_array.is_motif_atom_unindexed, dtype=bool))
|
||||
|
||||
# ... If sequence is known for the original atom array, just skip
|
||||
if is_seq_known:
|
||||
ret_mask += [True] * len(res_array)
|
||||
|
||||
Reference in New Issue
Block a user