feat: atom23 inference changes and training fixes

This commit is contained in:
Raktim Mitra
2026-01-20 14:44:19 -08:00
committed by Raktim Mitra
parent 4a7aaf8793
commit ebec466e4f
13 changed files with 527 additions and 113 deletions

View File

@@ -9,6 +9,6 @@ weights:
beta: 0.5
alphas:
a_prot: 3.0 # 3 for AF-3
a_nuc: 0.0 # 3 for AF-3
a_nuc: 3.0 # 3 for AF-3
a_ligand: 1.0 # 1 for AF-3
a_loi: 5.0 # 5 for AF-3

View File

@@ -0,0 +1,82 @@
# @package _global_
# Training configuration for RFD3
defaults:
- /debug/default
- override /model: rfd3_base
#- override /datasets: all
- override /logger: csv
#- override /logger: wandb
- _self_
name: train-base
tags: [print-model]
ckpt_path: null
model:
net:
token_initializer:
token_1d_features:
ref_motif_token_type: 3
restype: 32
is_dna_token: 1
is_rna_token: 1
is_protein_token: 1
atom_1d_features:
ref_atom_name_chars: 256
ref_element: 128
ref_charge: 1
ref_mask: 1
ref_is_motif_atom_with_fixed_coord: 1
ref_is_motif_atom_unindexed: 1
has_zero_occupancy: 1
ref_pos: 3
# Guided features
ref_atomwise_rasa: 3
active_donor: 1
active_acceptor: 1
is_atom_level_hotspot: 1
diffusion_module:
n_recycle: 2
use_local_token_attention: True
diffusion_transformer:
n_local_tokens: 32
n_keys: 128
inference_sampler:
num_timesteps: 100
datasets:
diffusion_batch_size_train: 16
crop_size: 256
max_atoms_in_crop: 2560 # ~10x crop size.
global_transform_args:
association_scheme: atom23
train_conditions:
unconditional:
frequency: 2.0
island:
frequency: 2.0
sequence_design:
frequency: 0.5
tipatom:
frequency: 5.0
ppi:
frequency: 0.0
train:
# These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data.
#pdb:
#probability: 0.10
#monomer_distillation:
#probability: 0.90
pdb:
probability: 1.0
trainer:
devices_per_node: 1
limit_train_batches: 10
limit_val_batches: 1
validate_every_n_epochs: 5
prevalidate: false

View File

@@ -242,42 +242,171 @@ SELECTION_NONPROTEIN = [
"POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID",
]
backbone_atomscheme_DNA = [' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'"]#, None]
backbone_atomscheme_DNA = [
" P ",
" OP1",
" OP2",
" O5'",
" C5'",
" C4'",
" O4'",
" C3'",
" O3'",
" C2'",
" C1'",
] # , None]
backbone_atomscheme_RNA = [' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'"]
backbone_atomscheme_RNA = [
" P ",
" OP1",
" OP2",
" O5'",
" C5'",
" C4'",
" O4'",
" C3'",
" O3'",
" C2'",
" O2'",
" C1'",
]
DNA_atoms = {
'DA': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 '],
'DC': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 '],
'DG': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '],
'DT': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C7 ', ' C6 ']}
RNA_atoms = {
'A': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 '],
'C': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 '],
'G': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '],
'U': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C6 ']
"DA": [
" N9 ",
" C8 ",
" N7 ",
" C5 ",
" C6 ",
" N6 ",
" N1 ",
" C2 ",
" N3 ",
" C4 ",
],
"DC": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " N4 ", " C5 ", " C6 "],
"DG": [
" N9 ",
" C8 ",
" N7 ",
" C5 ",
" C6 ",
" O6 ",
" N1 ",
" C2 ",
" N2 ",
" N3 ",
" C4 ",
],
"DT": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " O4 ", " C5 ", " C7 ", " C6 "],
}
association_schemes['atom23'] = {}
RNA_atoms = {
"A": [
" N9 ",
" C8 ",
" N7 ",
" C5 ",
" C6 ",
" N6 ",
" N1 ",
" C2 ",
" N3 ",
" C4 ",
],
"C": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " N4 ", " C5 ", " C6 "],
"G": [
" N9 ",
" C8 ",
" N7 ",
" C5 ",
" C6 ",
" O6 ",
" N1 ",
" C2 ",
" N2 ",
" N3 ",
" C4 ",
],
"U": [" N1 ", " C2 ", " O2 ", " N3 ", " C4 ", " O4 ", " C5 ", " C6 "],
}
association_schemes["atom23"] = {}
for item in DNA_atoms:
association_schemes['atom23'][item] = tuple(backbone_atomscheme_DNA + DNA_atoms[item]+ [None]*(22 - len(DNA_atoms[item] + backbone_atomscheme_DNA)))
association_schemes["atom23"][item] = tuple(
backbone_atomscheme_DNA
+ DNA_atoms[item]
+ [None] * (22 - len(DNA_atoms[item] + backbone_atomscheme_DNA))
)
for item in RNA_atoms:
association_schemes['atom23'][item] = tuple(backbone_atomscheme_RNA + RNA_atoms[item]+ [None]*(23 - len(RNA_atoms[item] + backbone_atomscheme_RNA)))
association_schemes["atom23"][item] = tuple(
backbone_atomscheme_RNA
+ RNA_atoms[item]
+ [None] * (23 - len(RNA_atoms[item] + backbone_atomscheme_RNA))
)
for item in association_schemes['dense']:
association_schemes['atom23'][item] = association_schemes['dense'][item]
for item in association_schemes["dense"]:
association_schemes["atom23"][item] = association_schemes["dense"][item]
association_schemes['atom23']['DX'] = (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'", None, None, None, None, None, None, None, None, None, None, None) #rna_mask
association_schemes['atom23']['X'] = (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'", None, None, None, None, None, None, None, None, None, None, None)#rna mask
association_schemes["atom23"]["DX"] = (
" P ",
" OP1",
" OP2",
" O5'",
" C5'",
" C4'",
" O4'",
" C3'",
" O3'",
" C2'",
" C1'",
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
) # rna_mask
association_schemes["atom23"]["X"] = (
" P ",
" OP1",
" OP2",
" O5'",
" C5'",
" C4'",
" O4'",
" C3'",
" O3'",
" C2'",
" O2'",
" C1'",
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
) # rna mask
ATOM23_ATOM_NAMES_RNA = np.array(
[item.strip() for item in backbone_atomscheme_RNA] + [f"V{i}" for i in range(23 - len(backbone_atomscheme_RNA))]
[item.strip() for item in backbone_atomscheme_RNA]
+ [f"V{i}" for i in range(23 - len(backbone_atomscheme_RNA))]
)
"""Atom23 atom names (e.g. CA, V1)"""
ATOM23_ATOM_ELEMENTS_RNA = np.array(
["P", "O", "O", "O", "C", "C", "O", "C","O", "C", "O", "C"] + [VIRTUAL_ATOM_ELEMENT_NAME for i in range(23 - len(backbone_atomscheme_RNA))]
["P", "O", "O", "O", "C", "C", "O", "C", "O", "C", "O", "C"]
+ [VIRTUAL_ATOM_ELEMENT_NAME for i in range(23 - len(backbone_atomscheme_RNA))]
)
"""Atom23 element names (e.g. C, VX)"""
@@ -285,12 +414,14 @@ ATOM23_ATOM_NAME_TO_ELEMENT = {
name: elem for name, elem in zip(ATOM23_ATOM_NAMES_RNA, ATOM23_ATOM_ELEMENTS_RNA)
}
ATOM23_ATOM_NAMES_DNA = np.array(
[item.strip() for item in backbone_atomscheme_DNA] + [f"V{i}" for i in range(22 - len(backbone_atomscheme_DNA))]
[item.strip() for item in backbone_atomscheme_DNA]
+ [f"V{i}" for i in range(22 - len(backbone_atomscheme_DNA))]
)
"""Atom23 atom names (e.g. CA, V1)"""
ATOM23_ATOM_ELEMENTS_DNA = np.array(
["P", "O", "O", "O", "C", "C", "O", "C","O", "C", "C"] + [VIRTUAL_ATOM_ELEMENT_NAME for i in range(22 - len(backbone_atomscheme_DNA))]
["P", "O", "O", "O", "C", "C", "O", "C", "O", "C", "C"]
+ [VIRTUAL_ATOM_ELEMENT_NAME for i in range(22 - len(backbone_atomscheme_DNA))]
)
"""Atom23 element names (e.g. C, VX)"""
@@ -307,4 +438,3 @@ association_schemes_stripped = {
backbone_atoms_RNA = strip_list(backbone_atomscheme_RNA)
backbone_atoms_DNA = strip_list(backbone_atomscheme_DNA)

View File

@@ -30,9 +30,12 @@ from rfd3.constants import (
OPTIONAL_CONDITIONING_VALUES,
REQUIRED_CONDITIONING_ANNOTATION_VALUES,
REQUIRED_INFERENCE_ANNOTATIONS,
backbone_atoms_DNA,
backbone_atoms_RNA,
)
from rfd3.inference.legacy_input_parsing import (
create_atom_array_from_design_specification_legacy,
reorder_atoms_per_residue,
)
from rfd3.inference.parsing import InputSelection
from rfd3.inference.symmetry.symmetry_utils import (
@@ -67,7 +70,6 @@ logging.basicConfig(level=logging.DEBUG)
logger = RankedLogger(__name__, rank_zero_only=True)
#################################################################################
# Custom infer_ori functions
#################################################################################
@@ -505,6 +507,21 @@ class DesignInputSpecification(BaseModel):
def build(self, return_metadata=False):
"""Main build pipeline."""
atom_array_input_annotated = copy.deepcopy(self.atom_array_input)
########## reorder NA atoms ###########
is_dna = np.isin(atom_array_input_annotated.res_name, ["DA", "DC", "DG", "DT"])
is_rna = np.isin(atom_array_input_annotated.res_name, ["A", "C", "G", "U"])
dna_array = atom_array_input_annotated[is_dna]
rna_array = atom_array_input_annotated[is_rna]
atom_array_input_annotated[is_dna] = reorder_atoms_per_residue(
dna_array, backbone_atoms_DNA
)
atom_array_input_annotated[is_rna] = reorder_atoms_per_residue(
rna_array, backbone_atoms_RNA
)
#######################################
atom_array = self._build_init(atom_array_input_annotated)
# Apply post-processing
@@ -894,31 +911,52 @@ def validator_context(validator_name: str, data: dict = None):
raise e
def create_diffused_residues(n, additional_annotations=None):
def create_diffused_residues(n, additional_annotations=None, polymer_type="P"):
from rfd3.constants import (
ATOM23_ATOM_NAME_TO_ELEMENT,
backbone_atoms_DNA,
backbone_atoms_RNA,
)
if n <= 0:
raise ValueError(f"Negative/null residue count ({n}) not allowed.")
if polymer_type == "P":
res_name = "ALA"
bb_len = 5
bb_atom_names = ["N", "CA", "C", "O", "CB"]
elif polymer_type == "R":
res_name = "A"
bb_len = len(backbone_atoms_RNA)
bb_atom_names = strip_list(backbone_atoms_RNA)
elif polymer_type == "D":
res_name = "DA"
bb_len = len(backbone_atoms_DNA)
bb_atom_names = strip_list(backbone_atoms_DNA)
else:
raise ValueError(
f"invalid polymer type detected: {polymer_type}, check contig!"
)
bb_elements = [ATOM23_ATOM_NAME_TO_ELEMENT[item] for item in bb_atom_names]
atoms = []
[
atoms.extend(
[
struc.Atom(
np.array([0.0, 0.0, 0.0], dtype=np.float32),
res_name="ALA",
res_name=res_name,
res_id=idx,
)
for _ in range(5)
for _ in range(bb_len)
]
)
for idx in range(1, n + 1)
]
array = struc.array(atoms)
array.set_annotation(
"element", np.array(["N", "C", "C", "O", "C"] * n, dtype="<U2")
)
array.set_annotation(
"atom_name", np.array(["N", "CA", "C", "O", "CB"] * n, dtype="<U2")
)
array.set_annotation("element", np.array(bb_elements * n, dtype="<U2"))
array.set_annotation("atom_name", np.array(bb_atom_names * n, dtype="<U2"))
array = set_default_conditioning_annotations(
array, motif=False, additional=additional_annotations
)
@@ -1065,14 +1103,20 @@ def accumulate_components(
np.zeros(token.shape[0], dtype=int),
)
else:
n = int(component)
if component[-1] in ["P", "R", "D"]: # if polymer type specified
polymer_type = component[-1] # can be 'P'rotein, 'R'NA, 'D'NA
n = int(component[:-1])
else:
polymer_type = "P"
n = int(components)
# ... Skip if none or unindexed
if n == 0 or unindexed_components_started:
res_id += n
continue
# ... Create diffused residues
token = create_diffused_residues(n, all_annots)
token = create_diffused_residues(n, all_annots, polymer_type)
# ... Set index of insertion
token = set_indices(

View File

@@ -5,16 +5,19 @@ from os import PathLike
import biotite.structure as struc
import numpy as np
from atomworks.constants import STANDARD_AA
from atomworks.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA
from atomworks.io.utils.io_utils import to_cif_file
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
from atomworks.ml.utils.token import (
get_token_starts,
)
from biotite.structure import AtomArray, concatenate, get_residue_starts
from rfd3.constants import (
INFERENCE_ANNOTATIONS,
OPTIONAL_CONDITIONING_VALUES,
REQUIRED_INFERENCE_ANNOTATIONS,
backbone_atoms_DNA,
backbone_atoms_RNA,
)
from rfd3.inference.symmetry.symmetry_utils import (
center_symmetric_src_atom_array,
@@ -183,6 +186,27 @@ def fetch_motif_residue_(
subarray.set_annotation(
"is_motif_atom_with_fixed_seq", np.zeros(subarray.shape[0], dtype=int)
)
elif redesign_motif_sidechains and res_name in (STANDARD_DNA + STANDARD_RNA):
is_backbone = np.isin(subarray.atom_name, backbone_atoms_RNA)
subarray.set_annotation("is_motif_atom", is_backbone)
subarray.set_annotation(
"is_motif_atom", is_backbone & (subarray.atom_name != "C1'")
)
subarray.set_annotation(
"is_motif_atom_with_fixed_seq", np.zeros(subarray.shape[0], dtype=int)
)
if res_name in STANDARD_DNA:
subarray.res_name = np.full_like(
subarray.res_name, "DA", dtype=subarray.res_name.dtype
)
else:
subarray.res_name = np.full_like(
subarray.res_name, "A", dtype=subarray.res_name.dtype
)
subarray = subarray[is_backbone]
elif to_index or to_unindex:
# If the residue is in the contig or unindexed components,
# we set all atoms in the residue to be motif atoms
@@ -195,6 +219,8 @@ def fetch_motif_residue_(
f"{src_chain}{src_resid} is not found in fixed_atoms, contig or unindex contig."
"Please check your input and contig specification."
)
if unfix_all or f"{src_chain}{src_resid}" in unfix_residues:
subarray.set_annotation(
"is_motif_atom_with_fixed_coord", np.zeros(subarray.shape[0], dtype=int)
@@ -507,6 +533,17 @@ def create_atom_array_from_design_specification_legacy(
fixed_atoms = {}
optional_conditions = []
########## reorder NA atoms ###########
is_dna = np.isin(atom_array_input.res_name, ["DA", "DC", "DG", "DT"])
is_rna = np.isin(atom_array_input.res_name, ["A", "C", "G", "U"])
dna_array = atom_array_input[is_dna]
rna_array = atom_array_input[is_rna]
atom_array_input[is_dna] = reorder_atoms_per_residue(dna_array, backbone_atoms_DNA)
atom_array_input[is_rna] = reorder_atoms_per_residue(rna_array, backbone_atoms_RNA)
#######################################
if exists(atomwise_rasa):
set_atom_level_argument(atom_array_input, atomwise_rasa, "rasa_bin")
optional_conditions.append("rasa_bin")
@@ -731,3 +768,46 @@ def create_atom_array_from_design_specification_legacy(
to_cif_file(atom_array, out_path, extra_fields=INFERENCE_ANNOTATIONS)
return atom_array
def reorder_atoms_per_residue(
atom_array: AtomArray, desired_order: list[str]
) -> AtomArray:
"""
Reorder atoms within each residue of an AtomArray.
Atoms in `desired_order` appear first (in that order), followed by all others
in original order. Faster version using get_residue_starts().
Parameters:
- atom_array: AtomArray to reorder.
- desired_order: List of atom names in the desired per-residue order.
Returns:
- AtomArray with reordered atoms per residue.
"""
if len(atom_array) == 0:
return atom_array
res_starts = get_residue_starts(atom_array)
res_starts = np.append(res_starts, len(atom_array)) # add end index for slicing
reordered_chunks = []
order_dict = {name: i for i, name in enumerate(desired_order)}
for i in range(len(res_starts) - 1):
start, end = res_starts[i], res_starts[i + 1]
residue = atom_array[start:end]
# Boolean masks for matching and non-matching atom names
in_order_mask = np.isin(residue.atom_name, desired_order)
not_in_order_mask = ~in_order_mask
# Sort matching atoms by desired order
atoms_in_order = residue[in_order_mask]
sort_idx = np.argsort([order_dict[name] for name in atoms_in_order.atom_name])
ordered_atoms = atoms_in_order[sort_idx]
# Remaining atoms as-is
remaining_atoms = residue[not_in_order_mask]
# Concatenate reordered residue
reordered_chunks.append(concatenate([ordered_atoms, remaining_atoms]))
return concatenate(reordered_chunks)

View File

@@ -428,9 +428,14 @@ class AADesignTrainer(FabricTrainer):
# ... Delete virtual atoms and assign atom names and elements
if self.cleanup_virtual_atoms:
atom_array = _cleanup_virtual_atoms_and_assign_atom_name_elements(
atom_array, association_scheme=self.association_scheme
)
try:
atom_array = _cleanup_virtual_atoms_and_assign_atom_name_elements(
atom_array, association_scheme=self.association_scheme
)
except Exception as e:
global_logger.warning(
f"Failed to cleanup virtual atoms from diffusion output: {e}"
)
# ... When cleaning up virtual atoms, we can also calculate native_array_metricsl
metadata_dict[i]["metrics"] |= get_all_backbone_metrics(

View File

@@ -11,6 +11,8 @@ from biotite.structure import concatenate, infer_elements
from jaxtyping import Float, Int
from rfd3.constants import (
ATOM14_ATOM_NAMES,
ATOM23_ATOM_NAMES_DNA,
ATOM23_ATOM_NAMES_RNA,
VIRTUAL_ATOM_ELEMENT_NAME,
association_schemes,
association_schemes_stripped,
@@ -252,13 +254,19 @@ def _readout_seq_from_struc(
continue
# ... Find the index of virtual atom names in the standard atom14 names
ATOM_NAMES = ATOM14_ATOM_NAMES
if restype in STANDARD_DNA:
ATOM_NAMES = ATOM23_ATOM_NAMES_DNA
if restype in STANDARD_RNA:
ATOM_NAMES = ATOM23_ATOM_NAMES_RNA
atom_name_idx_in_atom14_scheme = np.array(
[
np.where(ATOM14_ATOM_NAMES == atom_name)[0][0]
np.where(ATOM_NAMES == atom_name)[0][0]
for atom_name in cur_pred_res_atom_names
]
) # five backbone atoms + some virtual atoms, returning e.g. [0, 1, 2, 3, 4, 11, 7]
atom14_scheme_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool)
atom14_scheme_mask = np.zeros_like(ATOM_NAMES, dtype=bool)
atom14_scheme_mask[atom_name_idx_in_atom14_scheme] = True
# ... Find the matched restype by checking if all the non-None posititons and None positions match
@@ -427,12 +435,30 @@ def process_unindexed_outputs(
else:
join_atom = None
if join_atom is None:
pass
else:
dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
elif not np.any(
np.isin(
token.atom_name, [item.replace(" ", "") for item in backbone_atoms_RNA]
)
):
if np.sum(token.atomize) == 1:
join_atom = np.where(token.atomize)[0][0]
elif "C1'" in token.atom_name:
join_atom = np.where(token.atom_name == "C1'")[0][0]
else:
join_atom = None
if join_atom is None:
global_logger.warning(
f"Token {token_pdb_id} does not contain backbone atoms or CB, skipping join point distance calculation {token}."
"Skipping joint point rmsd, neither protein or NA backbone"
)
else:
dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
metadata["join_point_rmsd_by_token"][token_pdb_id] = dist
metadata["diffused_index_map"][token_pdb_id] = f"{chain_id}{res_id}"

View File

@@ -281,7 +281,7 @@ class SampleConditioningType(Transform):
cond = valid_conditions[i_cond]
cond.association_scheme = self.association_scheme
data["sampled_condition"] = cond
data["sampled_condition_name"] = cond.name
data["sampled_condition_cls"] = cond.__class__
@@ -299,6 +299,7 @@ class SampleConditioningFlags(Transform):
"AssignTypes",
"SampleConditioningType",
] # We use is_protein in the PPI training condition
def __init__(self, association_scheme):
self.association_scheme = association_scheme
@@ -375,13 +376,15 @@ class UnindexFlaggedTokens(Transform):
token.res_id = token.res_id + max_resid
token.is_C_terminus[:] = False
token.is_N_terminus[:] = False
if association_scheme is not 'atom23':
if not self.association_scheme == "atom23":
assert token.is_protein.all(), f"Cannot unindex non-protein token: {token} unless using atom23 association scheme"
token = add_representative_atom(token, central_atom=self.central_atom)
else:
if token.is_protein.all():
token = add_representative_atom(token, central_atom=self.central_atom)
token = add_representative_atom(
token, central_atom=self.central_atom
)
else:
token = add_representative_atom(token, central_atom="C1'")

View File

@@ -697,11 +697,12 @@ class AddAdditional1dFeaturesToFeats(Transform):
token_1d_features,
atom_1d_features,
autofill_zeros_if_not_present_in_atomarray=False,
association_scheme='atom14'
association_scheme="atom14",
):
self.autofill = autofill_zeros_if_not_present_in_atomarray
self.token_1d_features = token_1d_features
self.atom_1d_features = atom_1d_features
self.association_scheme = association_scheme
def check_input(self, data) -> None:
check_contains_keys(data, ["atom_array"])
@@ -753,11 +754,13 @@ class AddAdditional1dFeaturesToFeats(Transform):
"""
if "feats" not in data.keys():
data["feats"] = {}
if association_scheme == 'atom23':
data['atom_array'].set_annotation('is_protein_token', data['atom_array'].is_protein)
data['atom_array'].set_annotation('is_dna_token', data['atom_array'].is_dna)
data['atom_array'].set_annotation('is_rna_token', data['atom_array'].is_rna)
if self.association_scheme == "atom23":
data["atom_array"].set_annotation(
"is_protein_token", data["atom_array"].is_protein
)
data["atom_array"].set_annotation("is_dna_token", data["atom_array"].is_dna)
data["atom_array"].set_annotation("is_rna_token", data["atom_array"].is_rna)
for feature_name, n_dims in self.token_1d_features.items():
data = self.generate_feature(feature_name, n_dims, data, "token")

View File

@@ -383,7 +383,7 @@ def build_atom14_base_pipeline_(
train_conditions=train_conditions,
meta_conditioning_probabilities=meta_conditioning_probabilities,
sequence_encoding=af3_sequence_encoding,
association_scheme=association_scheme
association_scheme=association_scheme,
),
),
]
@@ -423,7 +423,9 @@ def build_atom14_base_pipeline_(
# ... Add global token features (since number of tokens is fixed after cropping)
transforms.append(AddGlobalTokenIdAnnotation())
# ... Create masks (NOTE: Modulates token count, and resets global token id if necessary)
transforms.append(TrainingRoute(SampleConditioningFlags(association_scheme=association_scheme)))
transforms.append(
TrainingRoute(SampleConditioningFlags(association_scheme=association_scheme))
)
# Post-crop transforms
transforms.append(
@@ -443,7 +445,9 @@ def build_atom14_base_pipeline_(
sharding_depth=1,
),
# ... Fuse inference and training conditioning assignments
UnindexFlaggedTokens(central_atom=central_atom),
UnindexFlaggedTokens(
central_atom=central_atom, association_scheme=association_scheme
),
# ... Virtual atom padding (NOTE: Last transform which modulates atom count)
PadTokensWithVirtualAtoms(
n_atoms_per_token=n_atoms_per_token,
@@ -519,7 +523,7 @@ def build_atom14_base_pipeline_(
autofill_zeros_if_not_present_in_atomarray=True,
token_1d_features=token_1d_features,
atom_1d_features=atom_1d_features,
association_scheme=association_scheme
association_scheme=association_scheme,
),
AddAF3TokenBondFeatures(),
AddGroundTruthSequence(sequence_encoding=af3_sequence_encoding),

View File

@@ -58,6 +58,8 @@ class IslandCondition(TrainingCondition):
Select islands as motif and assign conditioning strategies.
"""
association_scheme = "atom14"
def __init__(
self,
*,
@@ -70,11 +72,9 @@ class IslandCondition(TrainingCondition):
p_fix_motif_coordinates,
p_fix_motif_sequence,
p_unindex_motif_tokens,
association_scheme = 'atom14',
):
self.name = name
self.frequency = frequency
self.association_scheme = association_scheme
# Token selection
self.island_sampling_kwargs = island_sampling_kwargs
@@ -89,15 +89,13 @@ class IslandCondition(TrainingCondition):
self.p_fix_motif_coordinates = p_fix_motif_coordinates
self.p_fix_motif_sequence = p_fix_motif_sequence
self.p_unindex_motif_tokens = p_unindex_motif_tokens
self.association_scheme = association_scheme
def is_valid_for_example(self, data) -> bool:
is_protein = data["atom_array"].is_protein
is_dna = data["atom_array"].is_dna
is_rna = data["atom_array"].is_rna
### updating this to allow other polymers
if self.association_scheme is not 'atom23':
if not self.association_scheme == "atom23":
if not np.any(is_protein | is_dna | is_rna):
return False
else:
@@ -113,8 +111,12 @@ class IslandCondition(TrainingCondition):
token_level_array = atom_array[get_token_starts(atom_array)]
# initialize motif tokens as all non-protein tokens
if self.association_scheme is 'atom23':
polymer_mask = (token_level_array.is_protein | token_level_array.is_dna | token_level_array.is_rna)
if self.association_scheme == "atom23":
polymer_mask = (
token_level_array.is_protein
| token_level_array.is_dna
| token_level_array.is_rna
)
is_motif_token = np.asarray(~polymer_mask, dtype=bool).copy()
n_polymer_tokens = np.sum(polymer_mask)
islands_mask = sample_island_tokens(
@@ -123,13 +125,15 @@ class IslandCondition(TrainingCondition):
)
is_motif_token[polymer_mask] = islands_mask
else:
is_motif_token = np.asarray(~token_level_array.is_protein, dtype=bool).copy()
is_motif_token = np.asarray(
~token_level_array.is_protein, dtype=bool
).copy()
n_protein_tokens = np.sum(token_level_array.is_protein)
slands_mask = sample_island_tokens(
_protein_tokens,
*self.island_sampling_kwargs,
islands_mask = sample_island_tokens(
n_protein_tokens,
**self.island_sampling_kwargs,
)
is_motif_token[token_level_array.is_protein] = islands_mask
# TODO: Atoms with covalent bonds should be motif, needs FlagAndReassignCovalentModifications transform prior to this
@@ -160,7 +164,7 @@ class IslandCondition(TrainingCondition):
is_motif_atom = sample_motif_subgraphs(
atom_array=atom_array,
**self.subgraph_sampling_kwargs,
association_scheme=self.association_scheme
association_scheme=self.association_scheme,
)
# We also only want resolved atoms to be motif
@@ -182,7 +186,7 @@ class IslandCondition(TrainingCondition):
p_fix_motif_sequence=self.p_fix_motif_sequence,
p_fix_motif_coordinates=self.p_fix_motif_coordinates,
p_unindex_motif_tokens=self.p_unindex_motif_tokens,
association_scheme=self.association_scheme
association_scheme=self.association_scheme,
)
atom_array.set_annotation(
@@ -202,7 +206,7 @@ class PPICondition(TrainingCondition):
"""Get condition indicating what is motif and what is to be diffused for protein-protein interaction training."""
name = "ppi"
association_scheme = 'atom14'
association_scheme = "atom14"
def is_valid_for_example(self, data):
# Extract relevant data
@@ -301,7 +305,7 @@ class SubtypeCondition(TrainingCondition):
"""
name = "subtype"
association_scheme = 'atom14'
association_scheme = "atom14"
def __init__(self, frequency: float, subtype: list[str], fix_pos: bool = False):
self.frequency = frequency
@@ -397,7 +401,7 @@ def sample_motif_subgraphs(
hetatom_n_bond_expectation,
residue_p_fix_all,
hetatom_p_fix_all,
association_scheme = 'atom14'
association_scheme="atom14",
):
"""
Returns a boolean mask over atoms, indicating which atoms are part of the sampled motif.
@@ -431,10 +435,14 @@ def sample_motif_subgraphs(
"p_fix_all": residue_p_fix_all,
}
if association_scheme is 'atom23':
clause = atom_array_subset.is_protein.all() | atom_array_subset.is_dna.all() | atom_array_subset.is_rna.all()
if association_scheme == "atom23":
clause = (
atom_array_subset.is_protein.all()
| atom_array_subset.is_dna.all()
| atom_array_subset.is_rna.all()
)
else:
clause = atom_array_subset.is_potein.all()
clause = atom_array_subset.is_protein.all()
if not clause:
args.update(
@@ -465,12 +473,14 @@ def sample_conditioning_strategy(
p_fix_motif_sequence,
p_fix_motif_coordinates,
p_unindex_motif_tokens,
association_scheme
association_scheme,
):
atom_array.set_annotation(
"is_motif_atom_with_fixed_seq",
sample_is_motif_atom_with_fixed_seq(
atom_array, p_fix_motif_sequence=p_fix_motif_sequence, association_scheme=association_scheme
atom_array,
p_fix_motif_sequence=p_fix_motif_sequence,
association_scheme=association_scheme,
),
)
@@ -491,7 +501,9 @@ def sample_conditioning_strategy(
return atom_array
def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence, association_scheme):
def sample_is_motif_atom_with_fixed_seq(
atom_array, p_fix_motif_sequence, association_scheme
):
"""
Samples what kind of conditioning to apply to motif tokens.
@@ -504,10 +516,11 @@ def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence, associ
is_motif_atom_with_fixed_seq = np.zeros(atom_array.array_length(), dtype=bool)
# By default reveal sequence for non-protein
if association_scheme is not 'atom23':
is_motif_atom_with_fixed_seq = is_motif_atom_with_fixed_seq | ~atom_array.is_protein
if not association_scheme == "atom23":
is_motif_atom_with_fixed_seq = (
is_motif_atom_with_fixed_seq | ~atom_array.is_protein
)
return is_motif_atom_with_fixed_seq
@@ -526,7 +539,9 @@ def sample_fix_motif_coordinates(atom_array, p_fix_motif_coordinates):
return is_motif_atom_with_fixed_coord
def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens, association_scheme='atom14'):
def sample_unindexed_atoms(
atom_array, p_unindex_motif_tokens, association_scheme="atom14"
):
"""
Samples which atoms in motif tokens should be flagged for unindexing.
@@ -539,15 +554,15 @@ def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens, association_schem
is_motif_atom_unindexed = np.zeros(atom_array.array_length(), dtype=bool)
# ensure non-residue atoms are not already flagged
if association_scheme == 'atom23':
if association_scheme == "atom23":
is_motif_atom_unindexed = np.logical_and(
is_motif_atom_unindexed, (atom_array.is_residue | atom_array.is_dna | atom_array.is_rna)
) # is_residue refers to is_protein here
is_motif_atom_unindexed,
(atom_array.is_residue | atom_array.is_dna | atom_array.is_rna),
) # is_residue refers to is_protein here
else:
is_motif_atom_unindexed = np.logical_and(
is_motif_atom_unindexed, atom_array.is_residue
)
)
return is_motif_atom_unindexed

View File

@@ -10,11 +10,10 @@ from atomworks.ml.transforms.base import (
)
from atomworks.ml.utils.token import get_token_starts
from rfd3.constants import (
ATOM23_ATOM_NAME_TO_ELEMENT,
ATOM14_ATOM_NAME_TO_ELEMENT,
ATOM14_ATOM_NAMES,
ATOM23_ATOM_NAMES_RNA,
ATOM23_ATOM_NAME_TO_ELEMENT,
ATOM23_ATOM_NAMES_DNA,
ATOM23_ATOM_NAMES_RNA,
VIRTUAL_ATOM_ELEMENT_NAME,
association_schemes,
association_schemes_stripped,
@@ -31,7 +30,9 @@ from rfd3.transforms.util_transforms import (
from foundry.common import exists
def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="atom14", ATOM_NAMES=None):
def map_to_association_scheme(
atom_names: list | str, res_name: str, scheme="atom14", ATOM_NAMES=None
):
"""
Maps a list of names to the atom14 naming scheme for that particular name (within a specific residue)
NB this function is a bit more general since it is used to handle tipatoms too.
@@ -52,6 +53,7 @@ def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="ato
else:
return ATOM_NAMES[idxs]
def map_names_to_elements(
atom_names: list | str, default=VIRTUAL_ATOM_ELEMENT_NAME
) -> np.ndarray:
@@ -61,7 +63,7 @@ def map_names_to_elements(
then it returns the default value
"""
atom_names = [atom_names] if isinstance(atom_names, str) else atom_names
elements = [ATOM14_ATOM_NAME_TO_ELEMENT.get(name, default) for name in atom_names]
elements = [ATOM23_ATOM_NAME_TO_ELEMENT.get(name, default) for name in atom_names]
return np.array(elements)
@@ -72,6 +74,9 @@ def generate_atom_mappings_(scheme="atom14"):
symmetry_mapping = {}
for aaa, atom_names in ccd_ordering_atomchar.items():
if aaa not in scheme:
continue
mapping = list(range(len(atom_names)))
scheme_names = scheme[aaa]
@@ -126,10 +131,10 @@ def permute_symmetric_atom_names_(
# With the structure-local atom attention it will not unless N_keys(n_attn_seq_neighbours) > n_atom_attn_queries.
## fail safe, no symmetry confusion in NA bases ##
if (atom_names[0] == "P"):
if atom_names[0] == "P":
return atom_names
##################################################
if res_name in association_map:
idx_to_swap = association_map[res_name]
atom_names = atom_names[idx_to_swap]
@@ -180,16 +185,17 @@ class PadTokensWithVirtualAtoms(Transform):
token_ids = np.unique(atom_array.token_id)
assert len(token_ids) == len(
is_motif_atom_with_fixed_seq
), "Token ids and token level array have different lengths!"
), "Token ids and token level array have different lengths!"
# Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
if self.association_scheme == 'atom23':
is_residue = (
# Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
if self.association_scheme == "atom23":
is_residue = (
token_level_array.is_protein & ~token_level_array.atomize
) | is_motif_token_unindexed
is_residue_NA = (
(token_level_array.is_dna | token_level_array.is_rna) & ~token_level_array.atomize
(token_level_array.is_dna | token_level_array.is_rna)
& ~token_level_array.atomize
) | is_motif_token_unindexed
# Unindexed tokens are never padded, and so are treated as residues with fixed sequence.
@@ -211,7 +217,6 @@ class PadTokensWithVirtualAtoms(Transform):
is_non_paddable_residue = is_residue & (
is_motif_atom_with_fixed_seq | is_motif_token_unindexed
)
# Collect virtual atoms to insert (we will insert them all at once)
virtual_atoms_to_insert = []
@@ -221,7 +226,7 @@ class PadTokensWithVirtualAtoms(Transform):
for token_id, (start, end) in enumerate(zip(starts[:-1], starts[1:])):
if is_paddable[token_id]:
token = atom_array[start:end]
# First, pad with virtual atoms if needed
if self.association_scheme == "atom23" and atom_array[start].is_dna:
n_atoms_per_token = 22
@@ -230,7 +235,7 @@ class PadTokensWithVirtualAtoms(Transform):
else:
n_atoms_per_token = self.n_atoms_per_token
n_pad = n_atoms_per_token - len(token)
if n_pad > 0:
mask = get_af3_token_representative_masks(
token, central_atom=self.atom_to_pad_from
@@ -297,10 +302,10 @@ class PadTokensWithVirtualAtoms(Transform):
for token_id, (start, end) in enumerate(
zip(starts_padded[:-1], starts_padded[1:])
):
if (atom_array_padded[start].is_dna):
):
if atom_array_padded[start].is_dna:
ATOM_NAMES = ATOM23_ATOM_NAMES_DNA
elif (atom_array_padded[start].is_rna):
elif atom_array_padded[start].is_rna:
ATOM_NAMES = ATOM23_ATOM_NAMES_RNA
else:
ATOM_NAMES = ATOM14_ATOM_NAMES
@@ -328,7 +333,10 @@ class PadTokensWithVirtualAtoms(Transform):
)
atom_array_padded.get_annotation("gt_atom_name")[start:end] = atom_names
atom_names = map_to_association_scheme(
atom_names, res_name, scheme=self.association_scheme, ATOM_NAMES=ATOM_NAMES
atom_names,
res_name,
scheme=self.association_scheme,
ATOM_NAMES=ATOM_NAMES,
)
atom_array_padded.atom_name[start:end] = atom_names
else:

View File

@@ -96,8 +96,21 @@ def get_design_pattern_with_constraints(contig, length=None):
fixed_parts = []
pos_to_put_motif = []
suff = [] # suffixes for diffused regions P(optional),R,D
for part in contig_parts:
if any(c.isalpha() for c in part): # Detect parts containing letters as fixed
## updating to include DNA and RNA generation
if part[-1] in ["R", "D"]: ##Detect non-fixed RNA and DNA contig part
suff.append(part[-1])
part = part[:-1]
if "-" in part:
start, end = map(int, part.split("-"))
else:
start = end = int(part)
variable_ranges.append([start, end])
pos_to_put_motif.append(0)
elif any(c.isalpha() for c in part): # Detect parts containing letters as fixed
pn_unit_id, pn_unit_start, pn_unit_end = extract_pn_unit_info(part)
fixed_parts.append([pn_unit_id, pn_unit_start, pn_unit_end])
pos_to_put_motif.append(1)
@@ -110,6 +123,7 @@ def get_design_pattern_with_constraints(contig, length=None):
start = end = int(part)
variable_ranges.append([start, end])
pos_to_put_motif.append(0)
suff.append("P")
# adjust the total length to solely for free residues
num_motif_residues = sum([i[2] - i[1] + 1 for i in fixed_parts])
@@ -167,7 +181,7 @@ def get_design_pattern_with_constraints(contig, length=None):
atoms_with_motif.append(f"{pn_unit_id}{index}")
elif pos_to_put_motif[idx] == 0:
free_atom = num_free_atoms.pop(0)
atoms_with_motif.append(free_atom)
atoms_with_motif.append(str(free_atom) + suff.pop(0))
elif pos_to_put_motif[idx] == 2:
atoms_with_motif.append("/0")