mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
cleanup; aptamer metric; monomer dist on; send trainininggit status
This commit is contained in:
committed by
Raktim Mitra
parent
91a0eb22ec
commit
de06df8cbf
@@ -5,7 +5,7 @@ defaults:
|
||||
# Grab datasets
|
||||
- train/pdb/rfd3_train_interface@train.pdb.sub_datasets.interface
|
||||
- train/pdb/rfd3_train_pn_unit@train.pdb.sub_datasets.pn_unit
|
||||
#- train/rfd3_monomer_distillation@train
|
||||
- train/rfd3_monomer_distillation@train
|
||||
- train/rna_monomer_distillation@train
|
||||
|
||||
# Customized validation datasets
|
||||
|
||||
@@ -12,6 +12,7 @@ dataset:
|
||||
# filters common across all PDB datasets
|
||||
- 'pdb_id not in ["7rte", "7m5w", "7n5u"]'
|
||||
- 'pdb_id not in ["3di3", "5o45", "1z92", "2gy5", "4zxb"]'
|
||||
- 'pdb_id not in ["1drz", "2m8k", "2miy", "3q3z", "4oqu", "4plx", "4znp", "7kd1", "7kga", "7qr4"]'
|
||||
- "deposition_date < '2024-12-16'"
|
||||
- "resolution < 9.0"
|
||||
- "num_polymer_pn_units <= 300"
|
||||
@@ -19,4 +20,4 @@ dataset:
|
||||
# interface specific filters
|
||||
- "~(pn_unit_1_non_polymer_res_names.notnull() and pn_unit_1_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
|
||||
- "~(pn_unit_2_non_polymer_res_names.notnull() and pn_unit_2_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
|
||||
- "is_inter_molecule"
|
||||
- "is_inter_molecule"
|
||||
|
||||
@@ -15,6 +15,7 @@ dataset:
|
||||
# filters common across all PDB datasets
|
||||
- 'pdb_id not in ["7rte", "7m5w", "7n5u"]'
|
||||
- 'pdb_id not in ["3di3", "5o45", "1z92", "2gy5", "4zxb"]'
|
||||
- 'pdb_id not in ["1drz", "2m8k", "2miy", "3q3z", "4oqu", "4plx", "4znp", "7kd1", "7kga", "7qr4"]'
|
||||
- "deposition_date < '2024-12-16'"
|
||||
- "resolution < 9.0"
|
||||
- "num_polymer_pn_units <= 300"
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
# @package _global_
|
||||
# Training configuration for RFD3
|
||||
|
||||
defaults:
|
||||
#- /debug/default
|
||||
- override /model: rfd3_base
|
||||
- override /logger: wandb
|
||||
- override /datasets: design_base_rfd3na
|
||||
- _self_
|
||||
|
||||
name: rfd3na
|
||||
tags: [print-model]
|
||||
ckpt_path: null
|
||||
|
||||
model:
|
||||
net:
|
||||
token_initializer:
|
||||
token_1d_features:
|
||||
ref_motif_token_type: 3
|
||||
restype: 32
|
||||
is_dna_token: 1
|
||||
is_rna_token: 1
|
||||
is_protein_token: 1
|
||||
token_2d_features:
|
||||
bp_partners: 3 # Unspecified, pair, loop
|
||||
atom_1d_features:
|
||||
ref_atom_name_chars: 256
|
||||
ref_element: 128
|
||||
ref_charge: 1
|
||||
ref_mask: 1
|
||||
ref_is_motif_atom_with_fixed_coord: 1
|
||||
ref_is_motif_atom_unindexed: 1
|
||||
has_zero_occupancy: 1
|
||||
ref_pos: 3
|
||||
|
||||
# Guided features
|
||||
ref_atomwise_rasa: 3
|
||||
active_donor: 1
|
||||
active_acceptor: 1
|
||||
is_atom_level_hotspot: 1
|
||||
diffusion_module:
|
||||
n_recycle: 2
|
||||
use_local_token_attention: True
|
||||
diffusion_transformer:
|
||||
n_local_tokens: 32
|
||||
n_keys: 128
|
||||
|
||||
inference_sampler:
|
||||
num_timesteps: 100
|
||||
|
||||
|
||||
datasets:
|
||||
diffusion_batch_size_train: 16
|
||||
crop_size: 256
|
||||
max_atoms_in_crop: 2560 # ~10x crop size.
|
||||
global_transform_args:
|
||||
meta_conditioning_probabilities:
|
||||
p_is_nucleic_ss_example: 1.0
|
||||
p_nucleic_ss_show_partial_feats: 0.7
|
||||
p_canonical_bp_filter: 0.2
|
||||
calculate_NA_SS: 0.3
|
||||
|
||||
association_scheme: atom23
|
||||
#add_na_pair_features: true
|
||||
train_conditions:
|
||||
unconditional:
|
||||
frequency: 2.0
|
||||
island:
|
||||
frequency: 2.0
|
||||
sequence_design:
|
||||
frequency: 0.5
|
||||
tipatom:
|
||||
frequency: 5.0
|
||||
ppi:
|
||||
frequency: 0.0
|
||||
train:
|
||||
# These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data.
|
||||
pdb:
|
||||
# probability: 0.5
|
||||
probability: 0.75
|
||||
# probability: 0.0
|
||||
rna_monomer_distillation:
|
||||
# probability: 0.2
|
||||
probability: 0.25
|
||||
# probability: 1.0
|
||||
|
||||
val:
|
||||
pseudoknot:
|
||||
dataset:
|
||||
# eval_every_n: 10
|
||||
eval_every_n: 5
|
||||
|
||||
trainer:
|
||||
#devices_per_node: 1
|
||||
#limit_train_batches: 10
|
||||
#limit_val_batches: 1
|
||||
validate_every_n_epochs: 5
|
||||
prevalidate: true
|
||||
@@ -2,14 +2,13 @@
|
||||
# Training configuration for RFD3
|
||||
|
||||
defaults:
|
||||
- /debug/default
|
||||
#- /debug/default
|
||||
- override /model: rfd3_base
|
||||
#- override /datasets: all
|
||||
- override /logger: csv
|
||||
#- override /logger: wandb
|
||||
- override /logger: wandb
|
||||
- override /datasets: design_base_rfd3na
|
||||
- _self_
|
||||
|
||||
name: train-base
|
||||
name: rfd3na_scratch
|
||||
tags: [print-model]
|
||||
ckpt_path: null
|
||||
|
||||
@@ -22,6 +21,8 @@ model:
|
||||
is_dna_token: 1
|
||||
is_rna_token: 1
|
||||
is_protein_token: 1
|
||||
token_2d_features:
|
||||
bp_partners: 3 # Unspecified, pair, loop
|
||||
atom_1d_features:
|
||||
ref_atom_name_chars: 256
|
||||
ref_element: 128
|
||||
@@ -53,7 +54,14 @@ datasets:
|
||||
crop_size: 256
|
||||
max_atoms_in_crop: 2560 # ~10x crop size.
|
||||
global_transform_args:
|
||||
meta_conditioning_probabilities:
|
||||
p_is_nucleic_ss_example: 0.25
|
||||
p_nucleic_ss_show_partial_feats: 0.7
|
||||
p_canonical_bp_filter: 0.2
|
||||
#calculate_NA_SS: 0.3
|
||||
|
||||
association_scheme: atom23
|
||||
#add_na_pair_features: true
|
||||
train_conditions:
|
||||
unconditional:
|
||||
frequency: 2.0
|
||||
@@ -67,16 +75,22 @@ datasets:
|
||||
frequency: 0.0
|
||||
train:
|
||||
# These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data.
|
||||
#pdb:
|
||||
#probability: 0.10
|
||||
#monomer_distillation:
|
||||
#probability: 0.90
|
||||
pdb:
|
||||
probability: 1.0
|
||||
probability: 0.5
|
||||
rna_monomer_distillation:
|
||||
probability: 0.3
|
||||
monomer_distillation:
|
||||
probability: 0.2
|
||||
|
||||
val:
|
||||
pseudoknot:
|
||||
dataset:
|
||||
# eval_every_n: 10
|
||||
eval_every_n: 5
|
||||
|
||||
trainer:
|
||||
devices_per_node: 1
|
||||
limit_train_batches: 10
|
||||
limit_val_batches: 1
|
||||
#devices_per_node: 1
|
||||
#limit_train_batches: 10
|
||||
#limit_val_batches: 1
|
||||
validate_every_n_epochs: 5
|
||||
prevalidate: false
|
||||
prevalidate: true
|
||||
|
||||
@@ -29,3 +29,7 @@ nucleic_ss_similarity:
|
||||
annotation_NA_only: False
|
||||
annotation_planar_only: True
|
||||
|
||||
rna_aptamer_contacts:
|
||||
_target_: rfd3.metrics.rna_aptamer_metrics.LigandContactMetrics
|
||||
restrict_to_nucleic: True
|
||||
|
||||
|
||||
109
models/rfd3/src/rfd3/metrics/rna_aptamer_metrics.py
Normal file
109
models/rfd3/src/rfd3/metrics/rna_aptamer_metrics.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import logging
|
||||
|
||||
|
||||
from foundry.metrics.metric import Metric
|
||||
from foundry.utils.ddp import RankedLogger
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
global_logger = RankedLogger(__name__, rank_zero_only=False)
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def calculate_ligand_contacts(
|
||||
atom_array_stack,
|
||||
cutoff_distance=4.0,
|
||||
):
|
||||
"""
|
||||
Count number of atom contacts within cutoff of any ligand atom.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
atom_array_stack : AtomArrayStack
|
||||
Shape: (n_models, n_atoms)
|
||||
cutoff_distance : float
|
||||
Distance cutoff in Å
|
||||
|
||||
Returns
|
||||
-------
|
||||
total_contacts : int
|
||||
mean_contacts_per_model : float
|
||||
"""
|
||||
|
||||
cutoff_sq = cutoff_distance ** 2
|
||||
contacts_per_model = []
|
||||
|
||||
n_models = len(atom_array_stack)
|
||||
|
||||
for i in range(n_models):
|
||||
|
||||
atoms = atom_array_stack[i]
|
||||
|
||||
coords = atoms.coord
|
||||
hetero_mask = atoms.hetero.astype(bool)
|
||||
|
||||
# Skip if no ligand
|
||||
if not np.any(hetero_mask):
|
||||
contacts_per_model.append(0)
|
||||
continue
|
||||
|
||||
ligand_coords = coords[hetero_mask]
|
||||
non_ligand_coords = coords[~hetero_mask]
|
||||
|
||||
if len(non_ligand_coords) == 0:
|
||||
contacts_per_model.append(0)
|
||||
continue
|
||||
|
||||
# Pairwise squared distances
|
||||
diff = non_ligand_coords[:, None, :] - ligand_coords[None, :, :]
|
||||
dist_sq = np.sum(diff ** 2, axis=-1)
|
||||
|
||||
# Any ligand within cutoff
|
||||
contact_mask = np.any(dist_sq < cutoff_sq, axis=1)
|
||||
|
||||
n_contacts = np.sum(contact_mask)
|
||||
contacts_per_model.append(n_contacts)
|
||||
|
||||
contacts_per_model = np.array(contacts_per_model)
|
||||
|
||||
return int(np.sum(contacts_per_model)), float(np.mean(contacts_per_model)), float(np.mean(contacts_per_model))/hetero_mask.sum()
|
||||
|
||||
|
||||
class LigandContactMetrics(Metric):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
cutoff_distance: float = 4.0,
|
||||
restrict_to_nucleic: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.cutoff_distance = cutoff_distance
|
||||
self.restrict_to_nucleic = restrict_to_nucleic
|
||||
|
||||
@property
|
||||
def kwargs_to_compute_args(self):
|
||||
return {
|
||||
"predicted_atom_array_stack": ("predicted_atom_array_stack",),
|
||||
}
|
||||
|
||||
def compute(self, *, predicted_atom_array_stack):
|
||||
if self.restrict_to_nucleic:
|
||||
if (predicted_atom_array_stack[0].is_rna.sum() + predicted_atom_array_stack[0].is_dna.sum()== 0):
|
||||
return {}
|
||||
try:
|
||||
total_contacts, mean_contacts, mean_contacts_per_atom = calculate_ligand_contacts(
|
||||
atom_array_stack=predicted_atom_array_stack,
|
||||
cutoff_distance=self.cutoff_distance,
|
||||
)
|
||||
except Exception as e:
|
||||
global_logger.error(
|
||||
f"Error calculating ligand contact metrics: {e} | Skipping"
|
||||
)
|
||||
return {}
|
||||
|
||||
return {
|
||||
"mean_ligand_contacts_per_model": float(mean_contacts),
|
||||
"mean_ligand_contacts_per_atom": float(mean_contacts_per_atom),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user