cleanup; aptamer metric; monomer dist on; send training

This commit is contained in:
Raktim Mitra
2026-02-17 12:26:13 -08:00
committed by Raktim Mitra
parent 91a0eb22ec
commit de06df8cbf
7 changed files with 145 additions and 114 deletions

View File

@@ -5,7 +5,7 @@ defaults:
# Grab datasets
- train/pdb/rfd3_train_interface@train.pdb.sub_datasets.interface
- train/pdb/rfd3_train_pn_unit@train.pdb.sub_datasets.pn_unit
#- train/rfd3_monomer_distillation@train
- train/rfd3_monomer_distillation@train
- train/rna_monomer_distillation@train
# Customized validation datasets

View File

@@ -12,6 +12,7 @@ dataset:
# filters common across all PDB datasets
- 'pdb_id not in ["7rte", "7m5w", "7n5u"]'
- 'pdb_id not in ["3di3", "5o45", "1z92", "2gy5", "4zxb"]'
- 'pdb_id not in ["1drz", "2m8k", "2miy", "3q3z", "4oqu", "4plx", "4znp", "7kd1", "7kga", "7qr4"]'
- "deposition_date < '2024-12-16'"
- "resolution < 9.0"
- "num_polymer_pn_units <= 300"
@@ -19,4 +20,4 @@ dataset:
# interface specific filters
- "~(pn_unit_1_non_polymer_res_names.notnull() and pn_unit_1_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
- "~(pn_unit_2_non_polymer_res_names.notnull() and pn_unit_2_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
- "is_inter_molecule"
- "is_inter_molecule"

View File

@@ -15,6 +15,7 @@ dataset:
# filters common across all PDB datasets
- 'pdb_id not in ["7rte", "7m5w", "7n5u"]'
- 'pdb_id not in ["3di3", "5o45", "1z92", "2gy5", "4zxb"]'
- 'pdb_id not in ["1drz", "2m8k", "2miy", "3q3z", "4oqu", "4plx", "4znp", "7kd1", "7kga", "7qr4"]'
- "deposition_date < '2024-12-16'"
- "resolution < 9.0"
- "num_polymer_pn_units <= 300"

View File

@@ -1,98 +0,0 @@
# @package _global_
# Training configuration for RFD3
defaults:
#- /debug/default
- override /model: rfd3_base
- override /logger: wandb
- override /datasets: design_base_rfd3na
- _self_
name: rfd3na
tags: [print-model]
ckpt_path: null
model:
net:
token_initializer:
token_1d_features:
ref_motif_token_type: 3
restype: 32
is_dna_token: 1
is_rna_token: 1
is_protein_token: 1
token_2d_features:
bp_partners: 3 # Unspecified, pair, loop
atom_1d_features:
ref_atom_name_chars: 256
ref_element: 128
ref_charge: 1
ref_mask: 1
ref_is_motif_atom_with_fixed_coord: 1
ref_is_motif_atom_unindexed: 1
has_zero_occupancy: 1
ref_pos: 3
# Guided features
ref_atomwise_rasa: 3
active_donor: 1
active_acceptor: 1
is_atom_level_hotspot: 1
diffusion_module:
n_recycle: 2
use_local_token_attention: True
diffusion_transformer:
n_local_tokens: 32
n_keys: 128
inference_sampler:
num_timesteps: 100
datasets:
diffusion_batch_size_train: 16
crop_size: 256
max_atoms_in_crop: 2560 # ~10x crop size.
global_transform_args:
meta_conditioning_probabilities:
p_is_nucleic_ss_example: 1.0
p_nucleic_ss_show_partial_feats: 0.7
p_canonical_bp_filter: 0.2
calculate_NA_SS: 0.3
association_scheme: atom23
#add_na_pair_features: true
train_conditions:
unconditional:
frequency: 2.0
island:
frequency: 2.0
sequence_design:
frequency: 0.5
tipatom:
frequency: 5.0
ppi:
frequency: 0.0
train:
# These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data.
pdb:
# probability: 0.5
probability: 0.75
# probability: 0.0
rna_monomer_distillation:
# probability: 0.2
probability: 0.25
# probability: 1.0
val:
pseudoknot:
dataset:
# eval_every_n: 10
eval_every_n: 5
trainer:
#devices_per_node: 1
#limit_train_batches: 10
#limit_val_batches: 1
validate_every_n_epochs: 5
prevalidate: true

View File

@@ -2,14 +2,13 @@
# Training configuration for RFD3
defaults:
- /debug/default
#- /debug/default
- override /model: rfd3_base
#- override /datasets: all
- override /logger: csv
#- override /logger: wandb
- override /logger: wandb
- override /datasets: design_base_rfd3na
- _self_
name: train-base
name: rfd3na_scratch
tags: [print-model]
ckpt_path: null
@@ -22,6 +21,8 @@ model:
is_dna_token: 1
is_rna_token: 1
is_protein_token: 1
token_2d_features:
bp_partners: 3 # Unspecified, pair, loop
atom_1d_features:
ref_atom_name_chars: 256
ref_element: 128
@@ -53,7 +54,14 @@ datasets:
crop_size: 256
max_atoms_in_crop: 2560 # ~10x crop size.
global_transform_args:
meta_conditioning_probabilities:
p_is_nucleic_ss_example: 0.25
p_nucleic_ss_show_partial_feats: 0.7
p_canonical_bp_filter: 0.2
#calculate_NA_SS: 0.3
association_scheme: atom23
#add_na_pair_features: true
train_conditions:
unconditional:
frequency: 2.0
@@ -67,16 +75,22 @@ datasets:
frequency: 0.0
train:
# These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data.
#pdb:
#probability: 0.10
#monomer_distillation:
#probability: 0.90
pdb:
probability: 1.0
probability: 0.5
rna_monomer_distillation:
probability: 0.3
monomer_distillation:
probability: 0.2
val:
pseudoknot:
dataset:
# eval_every_n: 10
eval_every_n: 5
trainer:
devices_per_node: 1
limit_train_batches: 10
limit_val_batches: 1
#devices_per_node: 1
#limit_train_batches: 10
#limit_val_batches: 1
validate_every_n_epochs: 5
prevalidate: false
prevalidate: true

View File

@@ -29,3 +29,7 @@ nucleic_ss_similarity:
annotation_NA_only: False
annotation_planar_only: True
rna_aptamer_contacts:
_target_: rfd3.metrics.rna_aptamer_metrics.LigandContactMetrics
restrict_to_nucleic: True

View File

@@ -0,0 +1,109 @@
import logging
from foundry.metrics.metric import Metric
from foundry.utils.ddp import RankedLogger
# Configure the root logger at INFO level as an import-time side effect.
# Per the stdlib docs, basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the metric below to report failures.
# NOTE(review): rank_zero_only=False presumably lets every rank emit records — confirm against RankedLogger.
global_logger = RankedLogger(__name__, rank_zero_only=False)
import numpy as np
def calculate_ligand_contacts(
    atom_array_stack,
    cutoff_distance=4.0,
):
    """
    Count non-ligand atoms lying within a cutoff of any ligand atom.

    Ligand atoms are the ones flagged by the ``hetero`` annotation; every
    other atom is treated as non-ligand. A non-ligand atom is counted once
    per model, no matter how many ligand atoms it is close to.

    Parameters
    ----------
    atom_array_stack : AtomArrayStack
        Shape: (n_models, n_atoms). Must support ``len()`` and integer
        indexing; each model must expose ``coord`` (one coordinate row per
        atom) and ``hetero`` (one flag per atom).
    cutoff_distance : float
        Distance cutoff in Å. The comparison is strict (``<``).

    Returns
    -------
    total_contacts : int
        Contacts summed over all models.
    mean_contacts_per_model : float
        Mean number of contacts per model.
    mean_contacts_per_ligand_atom : float
        Mean over models of (contacts / number of ligand atoms in that
        model). Models without ligand atoms contribute 0.0.

    Notes
    -----
    An empty stack, or a stack without any ligand atoms, yields zeros
    instead of raising or returning NaN.
    """
    n_models = len(atom_array_stack)
    if n_models == 0:
        # Nothing to average over; avoid np.mean([]) -> NaN.
        return 0, 0.0, 0.0

    cutoff_sq = cutoff_distance ** 2
    contacts_per_model = np.zeros(n_models, dtype=np.int64)
    contacts_per_ligand_atom = np.zeros(n_models, dtype=float)

    for i in range(n_models):
        atoms = atom_array_stack[i]
        coords = np.asarray(atoms.coord)
        hetero_mask = np.asarray(atoms.hetero).astype(bool)

        # Skip if no ligand: the pre-filled zeros already describe this model.
        n_ligand_atoms = int(hetero_mask.sum())
        if n_ligand_atoms == 0:
            continue

        ligand_coords = coords[hetero_mask]
        non_ligand_coords = coords[~hetero_mask]
        if len(non_ligand_coords) == 0:
            continue

        # Pairwise squared distances, shape (n_non_ligand, n_ligand).
        # Squared values are compared so no square root is needed.
        diff = non_ligand_coords[:, None, :] - ligand_coords[None, :, :]
        dist_sq = np.sum(diff ** 2, axis=-1)

        # A non-ligand atom is a contact if any ligand atom is within cutoff.
        n_contacts = int(np.count_nonzero(np.any(dist_sq < cutoff_sq, axis=1)))

        contacts_per_model[i] = n_contacts
        # Normalise by this model's own ligand size rather than by whatever
        # mask happened to be left over from the last loop iteration.
        contacts_per_ligand_atom[i] = n_contacts / n_ligand_atoms

    return (
        int(contacts_per_model.sum()),
        float(contacts_per_model.mean()),
        float(contacts_per_ligand_atom.mean()),
    )
class LigandContactMetrics(Metric):
    """Metric reporting ligand contact statistics for a predicted structure stack.

    Wraps ``calculate_ligand_contacts`` and exposes its per-model and
    per-atom means under fixed dictionary keys.
    """

    def __init__(
        self,
        *,
        cutoff_distance: float = 4.0,
        restrict_to_nucleic: bool = True,
    ):
        """Store the distance cutoff and the nucleic-acid-only switch."""
        super().__init__()
        self.restrict_to_nucleic = restrict_to_nucleic
        self.cutoff_distance = cutoff_distance

    @property
    def kwargs_to_compute_args(self):
        """Map the ``compute`` keyword onto the key it is looked up under.

        NOTE(review): the tuple is presumably a key path into the metric
        inputs — confirm against the ``Metric`` base class.
        """
        mapping = {}
        mapping["predicted_atom_array_stack"] = ("predicted_atom_array_stack",)
        return mapping

    def compute(self, *, predicted_atom_array_stack):
        """Return ligand contact means, or ``{}`` when skipped or on failure.

        When ``restrict_to_nucleic`` is set, the first model's ``is_rna`` and
        ``is_dna`` annotations are summed and the metric is skipped if both
        are zero. Any exception from the contact calculation is logged and
        turned into an empty result.
        """
        if self.restrict_to_nucleic:
            n_rna = predicted_atom_array_stack[0].is_rna.sum()
            n_dna = predicted_atom_array_stack[0].is_dna.sum()
            if n_rna + n_dna == 0:
                return {}

        results = {}
        try:
            # The total is returned as well but is not reported by this metric.
            _total, per_model, per_atom = calculate_ligand_contacts(
                atom_array_stack=predicted_atom_array_stack,
                cutoff_distance=self.cutoff_distance,
            )
        except Exception as e:
            global_logger.error(
                f"Error calculating ligand contact metrics: {e} | Skipping"
            )
        else:
            results["mean_ligand_contacts_per_model"] = float(per_model)
            results["mean_ligand_contacts_per_atom"] = float(per_atom)
        return results