feat: atom23 porting

This commit is contained in:
Raktim Mitra
2026-01-18 15:41:00 -08:00
committed by Raktim Mitra
parent ece2498c5d
commit 4a7aaf8793
7 changed files with 264 additions and 55 deletions

View File

@@ -72,6 +72,37 @@ ccd_ordering_atomchar = {
'GLY': (" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None), # gly
'UNK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # unk
'MSK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # mask
'DA': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'",
' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 ',
None),
'DC': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'",
' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 ',
None, None, None),
'DG': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'",
' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '),
'DT': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'",
' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C7 ', ' C6 ',
None, None),
'A' : (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'",
' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 ',
None),
'C' : (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'",
' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 ',
None, None, None),
'G' : (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'",
' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '),
'U' : (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'",
' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C6 ',
None, None, None),
'DX': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'", None, None, None, None, None, None, None, None, None, None, None), #dna_mask
'X': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'", None, None, None, None, None, None, None, None, None, None, None), #rna mask
}
"""Canonical ordering of amino acid atom names in the CCD."""
@@ -210,3 +241,70 @@ SELECTION_NONPROTEIN = [
"MACROLIDE",
"POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID",
]
# Backbone (phosphate + sugar) atom names that open every nucleotide token,
# in slot order. RNA differs from DNA only by the extra O2' atom.
backbone_atomscheme_DNA = [' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'"]
backbone_atomscheme_RNA = [' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'"]

# Base atom names per nucleotide; they follow the backbone atoms within a token.
DNA_atoms = {
    'DA': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 '],
    'DC': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 '],
    'DG': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '],
    'DT': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C7 ', ' C6 '],
}
RNA_atoms = {
    'A': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 '],
    'C': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 '],
    'G': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '],
    'U': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C6 '],
}

# Number of atom slots per nucleotide token in the atom23 scheme
# (11 backbone + up to 11 base atoms for DNA; one more backbone atom for RNA).
_ATOM23_N_SLOTS_DNA = 22
_ATOM23_N_SLOTS_RNA = 23


def _pad_atom23(names, n_slots):
    """Return `names` as a tuple, right-padded with None up to `n_slots` entries."""
    return tuple(names + [None] * (n_slots - len(names)))


association_schemes['atom23'] = {}
for res_name, base_atoms in DNA_atoms.items():
    association_schemes['atom23'][res_name] = _pad_atom23(
        backbone_atomscheme_DNA + base_atoms, _ATOM23_N_SLOTS_DNA
    )
for res_name, base_atoms in RNA_atoms.items():
    association_schemes['atom23'][res_name] = _pad_atom23(
        backbone_atomscheme_RNA + base_atoms, _ATOM23_N_SLOTS_RNA
    )
# Every entry of the 'dense' scheme is copied over unchanged; because it is
# applied after the nucleotide entries, it wins on any shared residue name.
association_schemes['atom23'].update(association_schemes['dense'])
# Masked nucleotides expose only their backbone atoms.
association_schemes['atom23']['DX'] = _pad_atom23(backbone_atomscheme_DNA, _ATOM23_N_SLOTS_DNA)  # DNA mask
association_schemes['atom23']['X'] = _pad_atom23(backbone_atomscheme_RNA, _ATOM23_N_SLOTS_RNA)  # RNA mask

ATOM23_ATOM_NAMES_RNA = np.array(
    [name.strip() for name in backbone_atomscheme_RNA]
    + [f"V{i}" for i in range(_ATOM23_N_SLOTS_RNA - len(backbone_atomscheme_RNA))]
)
"""Atom23 RNA atom names: stripped backbone names followed by virtual names (e.g. P, C1', V0)"""

ATOM23_ATOM_ELEMENTS_RNA = np.array(
    ["P", "O", "O", "O", "C", "C", "O", "C", "O", "C", "O", "C"]
    + [VIRTUAL_ATOM_ELEMENT_NAME] * (_ATOM23_N_SLOTS_RNA - len(backbone_atomscheme_RNA))
)
"""Atom23 RNA element names, aligned with ATOM23_ATOM_NAMES_RNA"""

ATOM23_ATOM_NAMES_DNA = np.array(
    [name.strip() for name in backbone_atomscheme_DNA]
    + [f"V{i}" for i in range(_ATOM23_N_SLOTS_DNA - len(backbone_atomscheme_DNA))]
)
"""Atom23 DNA atom names: stripped backbone names followed by virtual names (e.g. P, C1', V0)"""

ATOM23_ATOM_ELEMENTS_DNA = np.array(
    ["P", "O", "O", "O", "C", "C", "O", "C", "O", "C", "C"]
    + [VIRTUAL_ATOM_ELEMENT_NAME] * (_ATOM23_N_SLOTS_DNA - len(backbone_atomscheme_DNA))
)
"""Atom23 DNA element names, aligned with ATOM23_ATOM_NAMES_DNA"""

# Built from the RNA arrays only: every DNA name above (backbone and virtual)
# also appears in the RNA names. The atom14 entries are merged in afterwards
# and take precedence on shared names.
ATOM23_ATOM_NAME_TO_ELEMENT = dict(zip(ATOM23_ATOM_NAMES_RNA, ATOM23_ATOM_ELEMENTS_RNA))
ATOM23_ATOM_NAME_TO_ELEMENT.update(ATOM14_ATOM_NAME_TO_ELEMENT)
"""Mapping from atom23 / atom14 atom names (e.g. P, CA, V1) to their corresponding element names (e.g. P, C, VX)"""

association_schemes_stripped = {
    name: {k: strip_list(v) for k, v in scheme.items()}
    for name, scheme in association_schemes.items()
}
backbone_atoms_RNA = strip_list(backbone_atomscheme_RNA)
backbone_atoms_DNA = strip_list(backbone_atomscheme_DNA)

View File

@@ -235,6 +235,7 @@ class SampleConditioningType(Transform):
train_conditions: dict,
meta_conditioning_probabilities: dict,
sequence_encoding,
association_scheme,
):
if exists(train_conditions):
train_conditions = hydra.utils.instantiate(
@@ -243,6 +244,7 @@ class SampleConditioningType(Transform):
self.meta_conditioning_probabilities = meta_conditioning_probabilities
self.train_conditions = train_conditions
self.sequence_encoding = sequence_encoding
self.association_scheme = association_scheme
def check_input(self, data: dict):
assert not data["is_inference"], "This transform is only used during training!"
@@ -278,6 +280,8 @@ class SampleConditioningType(Transform):
i_cond = np.random.choice(np.arange(len(p_cond)), p=p_cond)
cond = valid_conditions[i_cond]
cond.association_scheme = self.association_scheme
data["sampled_condition"] = cond
data["sampled_condition_name"] = cond.name
data["sampled_condition_cls"] = cond.__class__
@@ -295,6 +299,8 @@ class SampleConditioningFlags(Transform):
"AssignTypes",
"SampleConditioningType",
] # We use is_protein in the PPI training condition
def __init__(self, association_scheme="atom14"):
    """
    Args:
        association_scheme: Name of the atom association scheme (e.g. 'atom14',
            'atom23'); stored on the instance as ``self.association_scheme``.
            Defaults to 'atom14' so that the previous no-argument construction
            ``SampleConditioningFlags()`` keeps working.
    """
    # Before this override existed the inherited initializer ran implicitly;
    # keep running it so base-class state is still set up.
    super().__init__()
    self.association_scheme = association_scheme
def check_input(self, data):
assert not data[
@@ -317,13 +323,14 @@ class UnindexFlaggedTokens(Transform):
Serves as the merge point between training / inference conditioning pipelines
"""
def __init__(self, central_atom):
def __init__(self, central_atom, association_scheme):
"""
Args:
central_atom: The atom to use as the central atom for unindexed motifs.
association_scheme: Name of the atom association scheme (e.g. 'atom14', 'atom23');
stored on the instance as ``self.association_scheme``.
"""
super().__init__()
self.central_atom = central_atom
self.association_scheme = association_scheme
def check_input(self, data: dict):
check_contains_keys(data, ["atom_array"])
@@ -368,8 +375,16 @@ class UnindexFlaggedTokens(Transform):
token.res_id = token.res_id + max_resid
token.is_C_terminus[:] = False
token.is_N_terminus[:] = False
assert token.is_protein.all(), f"Cannot unindex non-protein token: {token}"
token = add_representative_atom(token, central_atom=self.central_atom)
# Compare the scheme by value (`==`), not identity: `is` against a string
# literal depends on interning and is not reliable for configured values.
# The scheme is read from the instance, where __init__ stores it.
uses_atom23 = self.association_scheme == "atom23"
token_is_protein = token.is_protein.all()
if not uses_atom23:
    assert token_is_protein, f"Cannot unindex non-protein token: {token} unless using atom23 association scheme"
if uses_atom23 and not token_is_protein:
    # Non-protein tokens under atom23 use the sugar C1' atom as representative.
    token = add_representative_atom(token, central_atom="C1'")
else:
    token = add_representative_atom(token, central_atom=self.central_atom)
unindexed_tokens.append(token)
# ... Remove original tokens e.g. during inference

View File

@@ -697,6 +697,7 @@ class AddAdditional1dFeaturesToFeats(Transform):
token_1d_features,
atom_1d_features,
autofill_zeros_if_not_present_in_atomarray=False,
association_scheme='atom14'
):
self.autofill = autofill_zeros_if_not_present_in_atomarray
self.token_1d_features = token_1d_features
@@ -752,6 +753,11 @@ class AddAdditional1dFeaturesToFeats(Transform):
"""
if "feats" not in data.keys():
data["feats"] = {}
# Under the atom23 scheme, copy the is_protein / is_dna / is_rna annotations
# onto the atom array under `*_token` names (presumably so they can be
# requested as 1d features by the loops below -- confirm against the config).
# NOTE(review): `association_scheme` is a bare name here, not `self.association_scheme`.
# The constructor hunk shown adds the parameter but no assignment to `self` is
# visible, so this likely raises NameError unless a same-named global exists -- confirm.
if association_scheme == 'atom23':
data['atom_array'].set_annotation('is_protein_token', data['atom_array'].is_protein)
data['atom_array'].set_annotation('is_dna_token', data['atom_array'].is_dna)
data['atom_array'].set_annotation('is_rna_token', data['atom_array'].is_rna)
for feature_name, n_dims in self.token_1d_features.items():
data = self.generate_feature(feature_name, n_dims, data, "token")

View File

@@ -383,6 +383,7 @@ def build_atom14_base_pipeline_(
train_conditions=train_conditions,
meta_conditioning_probabilities=meta_conditioning_probabilities,
sequence_encoding=af3_sequence_encoding,
association_scheme=association_scheme
),
),
]
@@ -422,7 +423,7 @@ def build_atom14_base_pipeline_(
# ... Add global token features (since number of tokens is fixed after cropping)
transforms.append(AddGlobalTokenIdAnnotation())
# ... Create masks (NOTE: Modulates token count, and resets global token id if necessary)
transforms.append(TrainingRoute(SampleConditioningFlags()))
transforms.append(TrainingRoute(SampleConditioningFlags(association_scheme=association_scheme)))
# Post-crop transforms
transforms.append(
@@ -518,6 +519,7 @@ def build_atom14_base_pipeline_(
autofill_zeros_if_not_present_in_atomarray=True,
token_1d_features=token_1d_features,
atom_1d_features=atom_1d_features,
association_scheme=association_scheme
),
AddAF3TokenBondFeatures(),
AddGroundTruthSequence(sequence_encoding=af3_sequence_encoding),

View File

@@ -70,9 +70,11 @@ class IslandCondition(TrainingCondition):
p_fix_motif_coordinates,
p_fix_motif_sequence,
p_unindex_motif_tokens,
association_scheme = 'atom14',
):
self.name = name
self.frequency = frequency
self.association_scheme = association_scheme
# Token selection
self.island_sampling_kwargs = island_sampling_kwargs
@@ -87,11 +89,21 @@ class IslandCondition(TrainingCondition):
self.p_fix_motif_coordinates = p_fix_motif_coordinates
self.p_fix_motif_sequence = p_fix_motif_sequence
self.p_unindex_motif_tokens = p_unindex_motif_tokens
self.association_scheme = association_scheme
def is_valid_for_example(self, data) -> bool:
    """Return whether the example contains any atoms this condition can act on.

    Under the 'atom23' association scheme protein, DNA and RNA atoms all count;
    under any other scheme only protein atoms do.

    Args:
        data: Pipeline dict whose "atom_array" entry exposes a boolean
            ``is_protein`` annotation (and ``is_dna`` / ``is_rna`` when the
            scheme is 'atom23').

    Returns:
        True if at least one eligible atom is present, False otherwise.
    """
    atom_array = data["atom_array"]
    eligible = atom_array.is_protein
    # Compare by value: `is` / `is not` against a string literal tests object
    # identity and silently misfires for strings that are not interned.
    if self.association_scheme == "atom23":
        # atom23 is the scheme that supports nucleic-acid polymers.
        eligible = eligible | atom_array.is_dna | atom_array.is_rna
    return bool(np.any(eligible))
def sample_motif_tokens(self, atom_array):
@@ -101,13 +113,24 @@ class IslandCondition(TrainingCondition):
token_level_array = atom_array[get_token_starts(atom_array)]
# initialize motif tokens as all non-protein tokens
is_motif_token = np.asarray(~token_level_array.is_protein, dtype=bool).copy()
n_protein_tokens = np.sum(token_level_array.is_protein)
islands_mask = sample_island_tokens(
n_protein_tokens,
**self.island_sampling_kwargs,
)
is_motif_token[token_level_array.is_protein] = islands_mask
# Tokens among which islands are sampled: protein, DNA and RNA tokens under
# atom23, protein tokens only otherwise. The scheme is compared by value
# (`==`); `is` against a string literal tests identity, not equality.
if self.association_scheme == "atom23":
    island_candidate_mask = (
        token_level_array.is_protein
        | token_level_array.is_dna
        | token_level_array.is_rna
    )
else:
    island_candidate_mask = token_level_array.is_protein
# initialize motif tokens as all tokens outside the candidate set
is_motif_token = np.asarray(~island_candidate_mask, dtype=bool).copy()
islands_mask = sample_island_tokens(
    np.sum(island_candidate_mask),
    **self.island_sampling_kwargs,
)
# Candidate tokens are motif exactly where an island was sampled.
is_motif_token[island_candidate_mask] = islands_mask
# TODO: Atoms with covalent bonds should be motif, needs FlagAndReassignCovalentModifications transform prior to this
# atom_with_coval_bond = token_level_array.covale # (n_atoms, )
@@ -137,6 +160,7 @@ class IslandCondition(TrainingCondition):
is_motif_atom = sample_motif_subgraphs(
atom_array=atom_array,
**self.subgraph_sampling_kwargs,
association_scheme=self.association_scheme
)
# We also only want resolved atoms to be motif
@@ -158,6 +182,7 @@ class IslandCondition(TrainingCondition):
p_fix_motif_sequence=self.p_fix_motif_sequence,
p_fix_motif_coordinates=self.p_fix_motif_coordinates,
p_unindex_motif_tokens=self.p_unindex_motif_tokens,
association_scheme=self.association_scheme
)
atom_array.set_annotation(
@@ -177,6 +202,7 @@ class PPICondition(TrainingCondition):
"""Get condition indicating what is motif and what is to be diffused for protein-protein interaction training."""
name = "ppi"
association_scheme = 'atom14'
def is_valid_for_example(self, data):
# Extract relevant data
@@ -275,6 +301,7 @@ class SubtypeCondition(TrainingCondition):
"""
name = "subtype"
association_scheme = 'atom14'
def __init__(self, frequency: float, subtype: list[str], fix_pos: bool = False):
self.frequency = frequency
@@ -370,6 +397,7 @@ def sample_motif_subgraphs(
hetatom_n_bond_expectation,
residue_p_fix_all,
hetatom_p_fix_all,
association_scheme = 'atom14'
):
"""
Returns a boolean mask over atoms, indicating which atoms are part of the sampled motif.
@@ -402,7 +430,13 @@ def sample_motif_subgraphs(
"n_bond_expectation": residue_n_bond_expectation,
"p_fix_all": residue_p_fix_all,
}
if not atom_array_subset.is_protein.all():
# `clause` is True when the subset consists of a single polymer type: protein
# only, or under atom23 also all-DNA or all-RNA.
# The scheme is compared by value; `is` against a literal tests identity.
if association_scheme == 'atom23':
    clause = (
        atom_array_subset.is_protein.all()
        | atom_array_subset.is_dna.all()
        | atom_array_subset.is_rna.all()
    )
else:
    clause = atom_array_subset.is_protein.all()
if not clause:
args.update(
{
"p_seed_furthest_from_o": 0.0,
@@ -431,11 +465,12 @@ def sample_conditioning_strategy(
p_fix_motif_sequence,
p_fix_motif_coordinates,
p_unindex_motif_tokens,
association_scheme
):
atom_array.set_annotation(
"is_motif_atom_with_fixed_seq",
sample_is_motif_atom_with_fixed_seq(
atom_array, p_fix_motif_sequence=p_fix_motif_sequence
atom_array, p_fix_motif_sequence=p_fix_motif_sequence, association_scheme=association_scheme
),
)
@@ -456,7 +491,7 @@ def sample_conditioning_strategy(
return atom_array
def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence):
def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence, association_scheme):
"""
Samples what kind of conditioning to apply to motif tokens.
@@ -469,7 +504,11 @@ def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence):
is_motif_atom_with_fixed_seq = np.zeros(atom_array.array_length(), dtype=bool)
# By default reveal sequence for non-protein
is_motif_atom_with_fixed_seq = is_motif_atom_with_fixed_seq | ~atom_array.is_protein
# Compare the scheme by value; `is not` against a string literal tests object
# identity and is not a reliable equality check.
# NOTE(review): under atom23 this leaves every atom unrevealed by default,
# including non-polymer atoms -- confirm that is intended.
if association_scheme != 'atom23':
    is_motif_atom_with_fixed_seq = is_motif_atom_with_fixed_seq | ~atom_array.is_protein
return is_motif_atom_with_fixed_seq
@@ -487,7 +526,7 @@ def sample_fix_motif_coordinates(atom_array, p_fix_motif_coordinates):
return is_motif_atom_with_fixed_coord
def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens):
def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens, association_scheme='atom14'):
"""
Samples which atoms in motif tokens should be flagged for unindexing.
@@ -500,9 +539,15 @@ def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens):
is_motif_atom_unindexed = np.zeros(atom_array.array_length(), dtype=bool)
# ensure non-residue atoms are not already flagged
is_motif_atom_unindexed = np.logical_and(
is_motif_atom_unindexed, atom_array.is_residue
)
if association_scheme == 'atom23':
is_motif_atom_unindexed = np.logical_and(
is_motif_atom_unindexed, (atom_array.is_residue | atom_array.is_dna | atom_array.is_rna)
) # is_residue refers to is_protein here
else:
is_motif_atom_unindexed = np.logical_and(
is_motif_atom_unindexed, atom_array.is_residue
)
return is_motif_atom_unindexed

View File

@@ -252,13 +252,13 @@ def get_af3_token_representative_masks(
atom_array: AtomArray, central_atom: str = "CA"
) -> np.ndarray:
pyrimidine_representative_atom = is_pyrimidine(atom_array.res_name) & (
atom_array.atom_name == "C2"
atom_array.atom_name == "C1'"
)
purine_representative_atom = is_purine(atom_array.res_name) & (
atom_array.atom_name == "C4"
atom_array.atom_name == "C1'"
)
unknown_na_representative_atom = is_unknown_nucleotide(atom_array.res_name) & (
atom_array.atom_name == "C4"
atom_array.atom_name == "C1'"
)
glycine_representative_atom = is_glycine(atom_array.res_name) & (

View File

@@ -10,8 +10,11 @@ from atomworks.ml.transforms.base import (
)
from atomworks.ml.utils.token import get_token_starts
from rfd3.constants import (
ATOM23_ATOM_NAME_TO_ELEMENT,
ATOM14_ATOM_NAME_TO_ELEMENT,
ATOM14_ATOM_NAMES,
ATOM23_ATOM_NAMES_RNA,
ATOM23_ATOM_NAMES_DNA,
VIRTUAL_ATOM_ELEMENT_NAME,
association_schemes,
association_schemes_stripped,
@@ -28,7 +31,7 @@ from rfd3.transforms.util_transforms import (
from foundry.common import exists
def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="atom14"):
def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="atom14", ATOM_NAMES=None):
"""
Maps a list of names to the atom14 naming scheme for that particular name (within a specific residue)
NB this function is a bit more general since it is used to handle tipatoms too.
@@ -37,17 +40,17 @@ def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="ato
raise ValueError(
f"Scheme {scheme} not found in association_schemes_stripped. Available schemes: {list(association_schemes_stripped.keys())}"
)
atom_names = (
[str(atom_names)] if isinstance(atom_names, (str, np.str_)) else atom_names
)
atom_names = [atom_names] if isinstance(atom_names, str) else atom_names
idxs = np.array(
[
association_schemes_stripped[scheme][res_name].index(name)
for name in atom_names
]
)
return ATOM14_ATOM_NAMES[idxs]
if ATOM_NAMES is None:
return ATOM14_ATOM_NAMES[idxs]
else:
return ATOM_NAMES[idxs]
def map_names_to_elements(
atom_names: list | str, default=VIRTUAL_ATOM_ELEMENT_NAME
@@ -68,17 +71,17 @@ def generate_atom_mappings_(scheme="atom14"):
atom_mapping = {}
symmetry_mapping = {}
for aaa, atom14_names in ccd_ordering_atomchar.items():
mapping = list(range(14))
for aaa, atom_names in ccd_ordering_atomchar.items():
mapping = list(range(len(atom_names)))
scheme_names = scheme[aaa]
for ccd_index in range(len(atom14_names)):
atom14_name = atom14_names[ccd_index]
if atom14_name is not None:
for ccd_index in range(len(atom_names)):
atom_name = atom_names[ccd_index]
if atom_name is not None:
assert (
atom14_name in scheme_names
), f"{atom14_name} not in CCD ordering for {aaa}"
scheme_index = scheme_names.index(atom14_name)
atom_name in scheme_names
), f"{atom_name} not in CCD ordering for {aaa}"
scheme_index = scheme_names.index(atom_name)
scheme_index_in_cur_mapping = mapping.index(scheme_index)
mapping[ccd_index], mapping[scheme_index_in_cur_mapping] = (
mapping[scheme_index_in_cur_mapping],
@@ -121,6 +124,12 @@ def permute_symmetric_atom_names_(
) -> list:
# NB: Can leak GT sequence if the model receives the canconical ordering of atoms as input
# With the structure-local atom attention it will not unless N_keys(n_attn_seq_neighbours) > n_atom_attn_queries.
## fail safe, no symmetry confusion in NA bases ##
if (atom_names[0] == "P"):
return atom_names
##################################################
if res_name in association_map:
idx_to_swap = association_map[res_name]
atom_names = atom_names[idx_to_swap]
@@ -171,20 +180,38 @@ class PadTokensWithVirtualAtoms(Transform):
token_ids = np.unique(atom_array.token_id)
assert len(token_ids) == len(
is_motif_atom_with_fixed_seq
), "Token ids and token level array have different lengths!"
), "Token ids and token level array have different lengths!"
# Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
is_residue = (
token_level_array.is_protein & ~token_level_array.atomize
) | is_motif_token_unindexed
# Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
if self.association_scheme == 'atom23':
is_residue = (
token_level_array.is_protein & ~token_level_array.atomize
) | is_motif_token_unindexed
is_residue_NA = (
(token_level_array.is_dna | token_level_array.is_rna) & ~token_level_array.atomize
) | is_motif_token_unindexed
# Unindexed tokens are never padded, and so are treated as residues with fixed sequence.
is_paddable = is_residue & ~(
is_motif_atom_with_fixed_seq | is_motif_token_unindexed
)
is_non_paddable_residue = is_residue & (
is_motif_atom_with_fixed_seq | is_motif_token_unindexed
)
# Unindexed tokens are never padded, and so are treated as residues with fixed sequence.
is_paddable = (is_residue_NA | is_residue) & ~(
is_motif_atom_with_fixed_seq | is_motif_token_unindexed
)
is_non_paddable_residue = (is_residue_NA | is_residue) & (
is_motif_atom_with_fixed_seq | is_motif_token_unindexed
)
else:
is_residue = (
token_level_array.is_protein & ~token_level_array.atomize
) | is_motif_token_unindexed
# Unindexed tokens are never padded, and so are treated as residues with fixed sequence.
is_paddable = is_residue & ~(
is_motif_atom_with_fixed_seq | is_motif_token_unindexed
)
is_non_paddable_residue = is_residue & (
is_motif_atom_with_fixed_seq | is_motif_token_unindexed
)
# Collect virtual atoms to insert (we will insert them all at once)
virtual_atoms_to_insert = []
@@ -194,8 +221,16 @@ class PadTokensWithVirtualAtoms(Transform):
for token_id, (start, end) in enumerate(zip(starts[:-1], starts[1:])):
if is_paddable[token_id]:
token = atom_array[start:end]
# First, pad with virtual atoms if needed
n_pad = self.n_atoms_per_token - len(token)
if self.association_scheme == "atom23" and atom_array[start].is_dna:
n_atoms_per_token = 22
elif self.association_scheme == "atom23" and atom_array[start].is_rna:
n_atoms_per_token = 23
else:
n_atoms_per_token = self.n_atoms_per_token
n_pad = n_atoms_per_token - len(token)
if n_pad > 0:
mask = get_af3_token_representative_masks(
token, central_atom=self.atom_to_pad_from
@@ -262,18 +297,26 @@ class PadTokensWithVirtualAtoms(Transform):
for token_id, (start, end) in enumerate(
zip(starts_padded[:-1], starts_padded[1:])
):
):
if (atom_array_padded[start].is_dna):
ATOM_NAMES = ATOM23_ATOM_NAMES_DNA
elif (atom_array_padded[start].is_rna):
ATOM_NAMES = ATOM23_ATOM_NAMES_RNA
else:
ATOM_NAMES = ATOM14_ATOM_NAMES
if is_paddable[token_id]:
# ... Permutation of atom names during training
if not data["is_inference"] and exists(self.association_scheme):
atom_names = permute_symmetric_atom_names_(
ATOM14_ATOM_NAMES,
ATOM_NAMES,
atom_array_padded.res_name[start],
association_map=self.association_map_,
symmetry_map=self.symmetry_map_,
)
else:
atom_names = ATOM14_ATOM_NAMES
atom_names = ATOM_NAMES
atom_array_padded.atom_name[start:end] = atom_names
atom_array_padded.get_annotation("gt_atom_name")[start:end] = atom_names
@@ -285,7 +328,7 @@ class PadTokensWithVirtualAtoms(Transform):
)
atom_array_padded.get_annotation("gt_atom_name")[start:end] = atom_names
atom_names = map_to_association_scheme(
atom_names, res_name, scheme=self.association_scheme
atom_names, res_name, scheme=self.association_scheme, ATOM_NAMES=ATOM_NAMES
)
atom_array_padded.atom_name[start:end] = atom_names
else: