train and inference atom23 with custom datasets and na ss transform

This commit is contained in:
Raktim Mitra
2026-01-31 16:14:06 -08:00
committed by Raktim Mitra
parent 716c16a7ec
commit f0ab0fedae
15 changed files with 282 additions and 77 deletions

View File

@@ -0,0 +1,103 @@
# base training dataset for training AF3 design models (atom14 variants):
# protein subsampling only.
defaults:
# Grab datasets
- train/pdb/rfd3_train_interface@train.pdb.sub_datasets.interface
- train/pdb/rfd3_train_pn_unit@train.pdb.sub_datasets.pn_unit
#- train/rfd3_monomer_distillation@train
- train/rna_monomer_distillation@train
# Customized validation datasets
- val/unconditional@val.unconditional
- val/unconditional_deep@val.unconditional_deep
- val/indexed@val.indexed
- val/pseudoknot@val.pseudoknot
# Customized train masks
- conditions/unconditional@global_transform_args.train_conditions.unconditional
- conditions/island@global_transform_args.train_conditions.island
- conditions/tipatom@global_transform_args.train_conditions.tipatom
- conditions/sequence_design@global_transform_args.train_conditions.sequence_design
- conditions/ppi@global_transform_args.train_conditions.ppi
- _self_
# Create a dictionary used for transform arguments
pipeline_target: rfd3.transforms.pipelines.build_atom14_base_pipeline
# Base config overrides:
diffusion_batch_size_train: 32
diffusion_batch_size_inference: 8
crop_size: 384
n_recycles_train: 2
n_recycles_validation: 1
max_atoms_in_crop: 3840 # ~10x crop size.
# Global transform arguments are necessary for arguments shared between training and inference
global_transform_args:
n_atoms_per_token: 14
central_atom: CB
sigma_perturb: 2.0
sigma_perturb_com: 1.0
association_scheme: dense
center_option: diffuse # options are ["all", "motif", "diffuse"]
# Reference conformer policy
generate_conformers: True
generate_conformers_for_non_protein_only: True
provide_reference_conformer_when_unmasked: True
ground_truth_conformer_policy: IGNORE # Other options: REPLACE, ADD, FALLBACK. See atomworks.enums for details
provide_elements_for_unindexed_components: True
use_element_for_atom_names_of_atomized_tokens: True # TODO: correct name, implies unindexed do too
# PPI Cropping
keep_full_binder_in_spatial_crop: False
max_binder_length: 170
# PPI Hotspots
max_ppi_hotspots_frac_to_provide: 0.2
ppi_hotspot_max_distance: 4.5
# Secondary structure features
max_ss_frac_to_provide: 0.4
min_ss_island_len: 1
max_ss_island_len: 10
# Nucleic acid features
add_na_pair_features: false
train_conditions:
unconditional:
frequency: 5.0
sequence_design:
frequency: 2.0
island:
frequency: 1.0
tipatom:
frequency: 0.0
ppi:
frequency: 0.0
# Used to create simple boolean flags for downstream conditioning
meta_conditioning_probabilities:
calculate_NA_SS: 1.0
calculate_hbonds: 0.2
calculate_rasa: 0.6
keep_protein_motif_rasa: 0.1 # Small to prevent noisy input to model
hbond_subsample: 0.5
# fully indexed training
unindex_leak_global_index: 0.10
unindex_insert_random_break: 0.10
unindex_remove_random_break: 0.10
# Probability of adding 1d secondary structure conditioning
add_1d_ss_features: 0.1
featurize_plddt: 0.9 # Applied for monomer distillation only
add_global_is_non_loopy_feature: 0.99
# PPI
add_ppi_hotspots: 0.75
full_binder_crop: 0.75

View File

@@ -0,0 +1,39 @@
defaults:
- pdb/base_transform_args@rna_monomer_distillation
- _self_
rna_monomer_distillation:
dataset:
_target_: atomworks.ml.datasets.StructuralDatasetWrapper
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
# cif parser arguments
cif_parser_args:
cache_dir: null
load_from_cache: False
save_to_cache: False
# metadata parser
dataset_parser:
_target_: atomworks.ml.datasets.parsers.GenericDFParser
pn_unit_iid_colnames: null
# metadata dataset
dataset:
_target_: atomworks.ml.datasets.PandasDataset
name: rna_monomer_distillation
id_column: example_id
data: /projects/ml/afavor/rna_distillation/rna_distillation_filtered_df.parquet
columns_to_load:
- example_id
- path
- cluster_id
- seq_hash
- overall_plddt
- overall_pde
- overall_pae
transform:
crop_contiguous_probability: 0.67
crop_spatial_probability: 0.33

View File

@@ -1,9 +1,9 @@
defaults:
- unconditional
- design_validation_base
- _self_
dataset:
name: pseudoknot
eval_every_n: 1
data: /home/afavor/git/RFD3/modelhub/projects/aa_design/tests/test_data/pseudoknot.json
data: ${paths.data.design_benchmark_data_dir}/pseudoknot.json

View File

@@ -5,6 +5,7 @@ defaults:
- /debug/default
- override /model: rfd3_base
- override /logger: null
- override /datasets: design_base_rfd3na
- _self_
name: rfd3na-SScond
@@ -54,7 +55,7 @@ datasets:
max_atoms_in_crop: 2560 # ~10x crop size.
global_transform_args:
association_scheme: atom23
add_na_pair_features: true
#add_na_pair_features: true
train_conditions:
unconditional:
frequency: 2.0
@@ -69,19 +70,19 @@ datasets:
train:
# These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data.
pdb:
probability: 1.0
# rna_monomer_distillation:
# probability: 1.0
probability: 0.5
rna_monomer_distillation:
probability: 0.5
# val:
# pseudoknot:
# dataset:
# # eval_every_n: 10
# eval_every_n: 2
val:
pseudoknot:
dataset:
# eval_every_n: 10
eval_every_n: 1
trainer:
devices_per_node: 1
limit_train_batches: 10
limit_val_batches: 1
validate_every_n_epochs: 5
prevalidate: false
prevalidate: true

View File

@@ -9,7 +9,7 @@ from os import PathLike
from typing import Any, Dict, List, Optional, Union
import numpy as np
from atomworks.constants import STANDARD_AA
from atomworks.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA
from atomworks.io.parser import parse_atom_array
# from atomworks.ml.datasets.datasets import BaseDataset
@@ -119,7 +119,9 @@ class DesignInputSpecification(BaseModel):
validate_assignment=False,
str_strip_whitespace=True,
str_min_length=1,
extra="forbid",
#extra="forbid", ####################################################
extra="allow"
## for now allowing extra for rfd3na-ss purposes, can decide later ##
)
# fmt: off
# ========================================================================
@@ -494,6 +496,12 @@ class DesignInputSpecification(BaseModel):
aa.is_motif_atom_with_fixed_seq[start:end] = np.full_like(
is_bkbn, False, dtype=int
)
elif aa.res_name[start] in (STANDARD_DNA + STANDARD_RNA) and self.redesign_motif_sidechains:
is_bkbn = np.isin(aa.atom_name[start:end], backbone_atoms_RNA)
aa.is_motif_atom_with_fixed_coord[start:end] = is_bkbn.astype(int)
aa.is_motif_atom_with_fixed_seq[start:end] = np.full_like(
is_bkbn, False, dtype=int
)
# ... Apply selections on top
apply_selections(start, end)
@@ -509,17 +517,18 @@ class DesignInputSpecification(BaseModel):
atom_array_input_annotated = copy.deepcopy(self.atom_array_input)
########## reorder NA atoms ###########
is_dna = np.isin(atom_array_input_annotated.res_name, ["DA", "DC", "DG", "DT"])
is_rna = np.isin(atom_array_input_annotated.res_name, ["A", "C", "G", "U"])
dna_array = atom_array_input_annotated[is_dna]
rna_array = atom_array_input_annotated[is_rna]
if exists(atom_array_input_annotated):
is_dna = np.isin(atom_array_input_annotated.res_name, ["DA", "DC", "DG", "DT"])
is_rna = np.isin(atom_array_input_annotated.res_name, ["A", "C", "G", "U"])
dna_array = atom_array_input_annotated[is_dna]
rna_array = atom_array_input_annotated[is_rna]
atom_array_input_annotated[is_dna] = reorder_atoms_per_residue(
dna_array, backbone_atoms_DNA
)
atom_array_input_annotated[is_rna] = reorder_atoms_per_residue(
rna_array, backbone_atoms_RNA
)
atom_array_input_annotated[is_dna] = reorder_atoms_per_residue(
dna_array, backbone_atoms_DNA
)
atom_array_input_annotated[is_rna] = reorder_atoms_per_residue(
rna_array, backbone_atoms_RNA
)
#######################################
atom_array = self._build_init(atom_array_input_annotated)
@@ -928,11 +937,11 @@ def create_diffused_residues(n, additional_annotations=None, polymer_type="P"):
elif polymer_type == "R":
res_name = "A"
bb_len = len(backbone_atoms_RNA)
bb_atom_names = strip_list(backbone_atoms_RNA)
bb_atom_names = backbone_atoms_RNA
elif polymer_type == "D":
res_name = "DA"
bb_len = len(backbone_atoms_DNA)
bb_atom_names = strip_list(backbone_atoms_DNA)
bb_atom_names = backbone_atoms_DNA
else:
raise ValueError(
f"invalid polymer type detected: {polymer_type}, check contig!"
@@ -955,12 +964,13 @@ def create_diffused_residues(n, additional_annotations=None, polymer_type="P"):
for idx in range(1, n + 1)
]
array = struc.array(atoms)
array.set_annotation("element", np.array(bb_elements * n, dtype="<U2"))
array.set_annotation("atom_name", np.array(bb_atom_names * n, dtype="<U2"))
array.set_annotation("element", np.array(bb_elements * n, dtype="<U3"))
array.set_annotation("atom_name", np.array(bb_atom_names * n, dtype="<U3"))
array = set_default_conditioning_annotations(
array, motif=False, additional=additional_annotations
)
array = set_common_annotations(array)
return array

View File

@@ -536,13 +536,14 @@ def create_atom_array_from_design_specification_legacy(
optional_conditions = []
########## reorder NA atoms ###########
is_dna = np.isin(atom_array_input.res_name, ["DA", "DC", "DG", "DT"])
is_rna = np.isin(atom_array_input.res_name, ["A", "C", "G", "U"])
dna_array = atom_array_input[is_dna]
rna_array = atom_array_input[is_rna]
if exists(atom_array_input):
is_dna = np.isin(atom_array_input.res_name, ["DA", "DC", "DG", "DT"])
is_rna = np.isin(atom_array_input.res_name, ["A", "C", "G", "U"])
dna_array = atom_array_input[is_dna]
rna_array = atom_array_input[is_rna]
atom_array_input[is_dna] = reorder_atoms_per_residue(dna_array, backbone_atoms_DNA)
atom_array_input[is_rna] = reorder_atoms_per_residue(rna_array, backbone_atoms_RNA)
atom_array_input[is_dna] = reorder_atoms_per_residue(dna_array, backbone_atoms_DNA)
atom_array_input[is_rna] = reorder_atoms_per_residue(rna_array, backbone_atoms_RNA)
#######################################
if exists(atomwise_rasa):

View File

@@ -291,6 +291,7 @@ class BackboneMetrics(Metric):
3.0 # maximum closest-neighbour distance before considered a floating atom
)
self.standard_ca_dist = 3.8
self.standard_PP_dist = 6.4
self.compute_for_diffused_region_only = compute_for_diffused_region_only
@property
@@ -310,6 +311,8 @@ class BackboneMetrics(Metric):
) # N_atoms x N_atoms
is_protein = f["is_protein"][tok_idx].cpu().numpy() # n_atoms
is_rna = f["is_rna"][tok_idx].cpu().numpy()
is_dna = f["is_dna"][tok_idx].cpu().numpy()
mask = np.zeros_like(dists, dtype=bool)
mask = mask | (np.eye(dists.shape[-1], dtype=bool))[None]
@@ -362,23 +365,42 @@ class BackboneMetrics(Metric):
if self.compute_for_diffused_region_only:
is_ca = is_ca[diffused_region]
is_protein = is_protein[diffused_region]
idx_mask = is_ca & is_protein
is_dna = is_dna[diffused_region]
is_rna = is_rna[diffused_region]
protein_idx_mask = is_ca & (is_protein)
na_idx_mask = is_ca & (is_rna | is_dna)
if self.compute_for_diffused_region_only:
xyz = X_L.cpu()[:, diffused_region][:, idx_mask]
xyz_protein = X_L.cpu()[:, diffused_region][:, protein_idx_mask]
xyz_na = X_L.cpu()[:, diffused_region][:, na_idx_mask]
else:
xyz = X_L.cpu()[:, idx_mask]
xyz_protein = X_L.cpu()[:, protein_idx_mask]
xyz_na = X_L.cpu()[:, na_idx_mask]
ca_dists = torch.norm(xyz[:, 1:] - xyz[:, :-1], dim=-1)
deviation = torch.abs(ca_dists - self.standard_ca_dist) # B, (I-1)
is_chainbreak = deviation > 0.75
ca_dists_protein = torch.norm(xyz_protein[:, 1:] - xyz_protein[:, :-1], dim=-1)
ca_dists_na = torch.norm(xyz_na[:, 1:] - xyz_na[:, :-1], dim=-1)
deviation_protein = torch.abs(ca_dists_protein - self.standard_ca_dist) # B, (I-1)
deviation_na = torch.abs(ca_dists_na - self.standard_PP_dist) # B, (I-1)
is_chainbreak_protein = deviation_protein > 0.75
is_chainbreak_na = deviation_na > 1
o["max_ca_deviation"] = float(deviation.max(-1).values.mean())
o["fraction_chainbreaks"] = float(is_chainbreak.float().mean(-1).mean())
o["n_chainbreaks"] = float(is_chainbreak.float().sum(-1).mean())
try:
o["max_ca_deviation_protein"] = float(deviation_protein.max(-1).values.mean())
o["fraction_chainbreaks_protein"] = float(is_chainbreak_protein.float().mean(-1).mean())
o["n_chainbreaks_protein"] = float(is_chainbreak_protein.float().sum(-1).mean())
except:
print("No protein in this example, skipping protein chainbreak metrics")
try:
o["max_ca_deviation_na"] = float(deviation_na.max(-1).values.mean())
o["fraction_chainbreaks_na"] = float(is_chainbreak_na.float().mean(-1).mean())
o["n_chainbreaks_na"] = float(is_chainbreak_na.float().sum(-1).mean())
except:
print("No NA in this example, skipping NA chainbreak metrics")
return o
class PPIMetrics(Metric):
"""PPI-specific metrics"""

View File

@@ -243,6 +243,10 @@ class SampleConditioningType(Transform):
)
self.meta_conditioning_probabilities = meta_conditioning_probabilities
self.train_conditions = train_conditions
for item in self.train_conditions:
self.train_conditions[item].association_scheme = association_scheme
self.sequence_encoding = sequence_encoding
self.association_scheme = association_scheme
@@ -261,12 +265,15 @@ class SampleConditioningType(Transform):
assert "conditions" in data, "Conditioning dict not initialized"
def forward(self, data):
#for item in self.train_conditions:
# print(self.train_conditions[item].is_valid_for_example(data))
valid_conditions = [
cond
for cond in self.train_conditions.values()
if cond.frequency > 0 and cond.is_valid_for_example(data)
if cond.is_valid_for_example(data) and cond.frequency > 0
]
if len(valid_conditions) == 0:
raise InvalidSampledConditionException("No valid condition was found.")
@@ -280,8 +287,6 @@ class SampleConditioningType(Transform):
i_cond = np.random.choice(np.arange(len(p_cond)), p=p_cond)
cond = valid_conditions[i_cond]
cond.association_scheme = self.association_scheme
data["sampled_condition"] = cond
data["sampled_condition_name"] = cond.name
data["sampled_condition_cls"] = cond.__class__

View File

@@ -71,8 +71,10 @@ class SubsampleToTypes(Transform):
def __init__(
self,
allowed_types: list | str = ["is_protein"],
association_scheme: str = 'atom14'
):
self.allowed_types = allowed_types
self.association_scheme = association_scheme
if not self.allowed_types == "ALL":
for k in allowed_types:
if not k.startswith("is_"):
@@ -104,7 +106,7 @@ class SubsampleToTypes(Transform):
)
)
if atom_array.is_protein.sum() == 0:
if self.association_scheme != 'atom23' and atom_array.is_protein.sum() == 0:
raise ValueError(
"No protein atoms found in the atom array. Example ID: {}".format(
data.get("example_id", "unknown")

View File

@@ -119,7 +119,7 @@ class CalculateNucleicAcidGeomFeats(Transform):
def __init__(
self,
is_inference,
add_nucleic_ss_feats: bool = True,
meta_conditioning_probabilities,
p_is_nucleic_ss_example: float = 0.3,
p_show_partial_feats: float = 0.5,
@@ -136,9 +136,18 @@ class CalculateNucleicAcidGeomFeats(Transform):
):
# Critical, must always have to know how to handle
self.is_inference = is_inference
if not self.is_inference:
## relevant in training
self.sampling_prob = meta_conditioning_probabilities['calculate_NA_SS']
else:
## irrelevant in inference
self.sampling_prob = 0
# For sampling whether we add nucleic-ss features (extra t2d)
self.add_nucleic_ss_feats = add_nucleic_ss_feats
# relevant in training
self.add_nucleic_ss_feats = (self.sampling_prob > 0)
######
self.p_canonical_bp_filter = p_canonical_bp_filter # enforce that bp labels are only canonical
self.p_is_nucleic_ss_example = p_is_nucleic_ss_example
self.nucleic_ss_min_shown = nucleic_ss_min_shown
@@ -189,7 +198,7 @@ class CalculateNucleicAcidGeomFeats(Transform):
n_tokens = len(token_starts)
print(" DO I NEED TO CHANGE TO TOKEN_ID???")
# Handle the training case with ground truth and masking:
if not self.is_inference:
if not self.is_inference and (np.random.rand() < self.sampling_prob):
# First, annotate as usual
# atom_array = annotate_na_ss(atom_array, **kwargs)
@@ -226,8 +235,9 @@ class CalculateNucleicAcidGeomFeats(Transform):
- 3). Lists of paired indices
"""
is_nucleic_ss_example=True
give_partial_feats=False
#is_nucleic_ss_example=True
#give_partial_feats=False
atom_array = annotate_na_ss_from_data_specification(
data,
overwrite=True,
@@ -282,4 +292,4 @@ class CalculateNucleicAcidGeomFeats(Transform):
n_islands_max=self.n_islands_max,
max_length=max_length,
)

View File

@@ -1321,6 +1321,8 @@ def bp_partner_to_ss_matrix(
*,
feature_info: Optional[dict] = None,
mol_info: Optional[NucMolInfo] = None,
NA_only: Optional[bool] = False,
planar_only: Optional[bool] = False,
include_loops: bool = True,
token_level_data: Optional[dict] = None,
) -> np.ndarray:

View File

@@ -196,6 +196,7 @@ def get_crop_transform(
max_binder_length: int,
max_atoms_in_crop: int | None,
allowed_types: List[str],
association_scheme: str,
):
if (
crop_contiguous_probability > 0
@@ -215,7 +216,7 @@ def get_crop_transform(
), "Crop center cutoff distance must be greater than 0"
pre_crop_transforms = [
SubsampleToTypes(allowed_types=allowed_types),
SubsampleToTypes(allowed_types=allowed_types, association_scheme=association_scheme),
]
cropping_transform = RandomRoute(
@@ -360,8 +361,11 @@ def build_atom14_base_pipeline_(
max_ss_frac_to_provide: float,
min_ss_island_len: int,
max_ss_island_len: int,
# Nucleic acid features
add_na_pair_features: bool,
## Nucleic acid features #####
#add_na_pair_features: bool,
## This should not be necessary, controlled through feature names in model, and meta conditioning probabilities, inference behavior handled in transform itself #####
**_, # dump additional kwargs (e.g. msa stuff)
):
"""
@@ -411,6 +415,7 @@ def build_atom14_base_pipeline_(
max_binder_length=max_binder_length,
max_atoms_in_crop=max_atoms_in_crop,
allowed_types=allowed_types,
association_scheme=association_scheme
)
if zero_occ_on_exposure_after_cropping:
@@ -443,14 +448,15 @@ def build_atom14_base_pipeline_(
)
)
# Add nucleic acid geometry features
if add_na_pair_features:
transforms.append(
CalculateNucleicAcidGeomFeats(
is_inference,
NA_only=False,
planar_only=True,
)
#if add_na_pair_features:
transforms.append(
CalculateNucleicAcidGeomFeats(
is_inference,
meta_conditioning_probabilities,
NA_only=False,
planar_only=True,
)
)
# Design Transforms
transforms += [
@@ -618,7 +624,6 @@ def build_atom14_base_pipeline(
Wrapper around pipeline construction to handle empty training args
Sets default behaviour for inference to keep backward compatibility
"""
if is_inference:
# Provide explicit defaults for training-only args
kwargs.setdefault("crop_size", 512)
@@ -634,7 +639,8 @@ def build_atom14_base_pipeline(
kwargs.setdefault("min_ss_island_len", 0)
kwargs.setdefault("max_ss_island_len", 999)
kwargs.setdefault("max_binder_length", 999)
kwargs.setdefault("add_na_pair_features", False)
# This should not be necessary.
#kwargs.setdefault("add_na_pair_features", False)
kwargs.setdefault("b_factor_min", None)
kwargs.setdefault("zero_occ_on_exposure_after_cropping", False)

View File

@@ -96,13 +96,13 @@ class IslandCondition(TrainingCondition):
is_rna = data["atom_array"].is_rna
### updating this to allow other polymers
if self.association_scheme == "atom23":
if not np.any(is_protein | is_dna | is_rna):
return False
if np.any(is_protein | is_dna | is_rna):
return True
else:
if not np.any(is_protein):
return False
return True
if np.any(is_protein):
return True
return False
def sample_motif_tokens(self, atom_array):
"""

View File

@@ -37,7 +37,7 @@ def assert_single_representative(token, central_atom="CB"):
mask = get_af3_token_representative_masks(token, central_atom=central_atom)
assert (
np.sum(mask) == 1
), f"No representative atom (CB) found. mask: {mask}\nToken: {token}"
), f"No representative atom ({central_atom}) found. mask: {mask}\nToken: {token}"
def assert_single_token(token):

View File

@@ -225,17 +225,21 @@ class PadTokensWithVirtualAtoms(Transform):
# First, pad with virtual atoms if needed
if self.association_scheme == "atom23" and atom_array[start].is_dna:
n_atoms_per_token = 22
central_atom = "C1'"
elif self.association_scheme == "atom23" and atom_array[start].is_rna:
n_atoms_per_token = 23
central_atom = "C1'"
else:
n_atoms_per_token = self.n_atoms_per_token
central_atom = self.atom_to_pad_from
n_pad = n_atoms_per_token - len(token)
if n_pad > 0:
mask = get_af3_token_representative_masks(
token, central_atom=self.atom_to_pad_from
token, central_atom=central_atom
)
assert_single_representative(token)
assert_single_representative(token, central_atom=central_atom)
# ... Create virtual atoms
pad_atoms = token[mask].copy()