feat: atom23 porting

This commit is contained in:
Raktim Mitra
2026-01-18 15:41:00 -08:00
committed by Raktim Mitra
parent ebec466e4f
commit 94d9d635cd
6 changed files with 144 additions and 42 deletions

View File

@@ -242,35 +242,6 @@ SELECTION_NONPROTEIN = [
"POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID",
]
backbone_atomscheme_DNA = [
" P ",
" OP1",
" OP2",
" O5'",
" C5'",
" C4'",
" O4'",
" C3'",
" O3'",
" C2'",
" C1'",
] # , None]
backbone_atomscheme_RNA = [
" P ",
" OP1",
" OP2",
" O5'",
" C5'",
" C4'",
" O4'",
" C3'",
" O3'",
" C2'",
" O2'",
" C1'",
]
DNA_atoms = {
"DA": [
" N9 ",
@@ -438,3 +409,124 @@ association_schemes_stripped = {
backbone_atoms_RNA = strip_list(backbone_atomscheme_RNA)
backbone_atoms_DNA = strip_list(backbone_atomscheme_DNA)
# Mapping from residue type to its backbone and sidechain atoms (for convenience)
ATOM_REGION_BY_RESI = {
'ALA': {'bb':('N','CA','C','O'),
'sc':('CB')},
'ARG': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD','NE','CZ','NH1','NH2')},
'ASN': {'bb':('N','CA','C','O'),
'sc':('CB','CG','OD1','ND2')},
'ASP': {'bb':('N','CA','C','O'),
'sc':('CB','CG','OD1','OD2')},
'CYS': {'bb':('N','CA','C','O'),
'sc':('CB','SG')},
'GLN': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD','OE1','NE2')},
'GLU': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD','OE1','OE2')},
'GLY': {'bb':('N','CA','C','O'),
'sc':()},
'HIS': {'bb':('N','CA','C','O'),
'sc':('CB','CG','ND1','CD2','CE1','NE2')},
'ILE': {'bb':('N','CA','C','O'),
'sc':('CB','CG1','CG2','CD1')},
'LEU': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD1','CD2')},
'LYS': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD','CE','NZ')},
'MET': {'bb':('N','CA','C','O'),
'sc':('CB','CG','SD','CE')},
'PHE': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD1','CD2','CE1','CE2','CZ')},
'PRO': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD')},
'SER': {'bb':('N','CA','C','O'),
'sc':('CB','OG')},
'THR': {'bb':('N','CA','C','O'),
'sc':('CB','OG1','CG2')},
'TRP': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD1','CD2','CE2','CE3','NE1','CZ2','CZ3','CH2')},
'TYR': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD1','CD2','CE1','CE2','CZ','OH')},
'VAL': {'bb':('N','CA','C','O'),
'sc':('CB','CG1','CG2')},
'UNK': {'bb':('N','CA','C','O'),
'sc':('CB')},
'MAS': {'bb':('N','CA','C','O'),
'sc':('CB')},
'DA': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
'sc':('N9','C4','N3','C2','N1','C6','C5','N7','C8','N6')},
'DC': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
'sc':('N1','C2','O2','N3','C4','N4','C5','C6')},
'DG': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
'sc':('N9','C4','N3','C2','N1','C6','C5','N7','C8','N2','O6')},
'DT': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
'sc':('N1','C2','O2','N3','C4','O4','C5','C7','C6')},
'DX': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
'sc':()},
'A': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
'sc':('N1','C2','N3','C4','C5','C6','N6','N7','C8','N9')},
'C': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
'sc':('N1','C2','O2','N3','C4','N4','C5','C6')},
'G': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
'sc':('N1','C2','N2','N3','C4','C5','C6','O6','N7','C8','N9')},
'U': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
'sc':('N1','C2','O2','N3','C4','O4','C5','C6')},
'X': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
'sc':()},
'HIS_D': {'bb':('N','CA','C','O'),
'sc':('CB','CG','NE2','CD2','CE1','ND1')},
}
# Known planar sidechain atoms for each canonical residue type:
PLANAR_ATOMS_BY_RESI = {
'ALA': [],
'ARG': ['NH1', 'NH2', 'CZ', 'NE', 'CD'],
'ASN': ['OD1', 'ND2', 'CG', 'CB'],
'ASP': ['OD1', 'OD2', 'CG', 'CB'],
'CYS': [],
'GLN': ['OE1', 'NE2', 'CD', 'CG'],
'GLU': ['OE1', 'OE2', 'CD', 'CG'],
'GLY': [],
'HIS': ['ND1', 'CE1', 'NE2', 'CD2', 'CG', 'CB'],
'ILE': [],
'LEU': [],
'LYS': [],
'MET': [],
'PHE': ['CZ', 'CE1', 'CE2', 'CD1', 'CD2', 'CG', 'CB'],
'PRO': [],
'SER': [],
'THR': [],
'TRP': ['CH2', 'CZ3', 'CZ2', 'CE3', 'CE2', 'CD2', 'NE1', 'CD1', 'CG', 'CB'],
'TYR': ['OH', 'CZ', 'CE1', 'CE2', 'CD1', 'CD2', 'CG', 'CB'],
'VAL': [],
'UNK': [],
'MAS': [],
'DA': ['N6', 'C6', 'N1', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'],
'DC': ['N4', 'C4', 'N3', 'O2', 'C2', 'C5', 'C6', 'N1'],
'DG': ['O6', 'C6', 'N1', 'N2', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'],
'DT': ['O4', 'O2', 'N3', 'C4', 'C2', 'C5', 'C6', 'N1', 'C7'],
'DX': [],
'A': ['N6', 'C6', 'N1', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'],
'C': ['N4', 'C4', 'N3', 'O2', 'C2', 'C5', 'C6', 'N1'],
'G': ['O6', 'C6', 'N1', 'N2', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'],
'U': ['O4', 'O2', 'N3', 'C4', 'C2', 'C5', 'C6', 'N1'],
'X': [],
'HIS_D': ['ND1', 'CD2', 'CE1', 'NE2', 'CG', 'CB'],
}
# fix C/U symmetry
temp = list(association_schemes['atom23']['U'])
temp[19], temp[20] = temp[20], temp[19]
association_schemes['atom23']['U'] = tuple(temp)
association_schemes_stripped = {
name: {k: strip_list(v) for k, v in scheme.items()}
for name, scheme in association_schemes.items()
}
if __name__ == "__main__":
import pdb; pdb.set_trace()

View File

@@ -281,7 +281,7 @@ class SampleConditioningType(Transform):
cond = valid_conditions[i_cond]
cond.association_scheme = self.association_scheme
data["sampled_condition"] = cond
data["sampled_condition_name"] = cond.name
data["sampled_condition_cls"] = cond.__class__
@@ -299,6 +299,8 @@ class SampleConditioningFlags(Transform):
"AssignTypes",
"SampleConditioningType",
] # We use is_protein in the PPI training condition
def __init__(self, association_scheme):
self.association_scheme = association_scheme
def __init__(self, association_scheme):
self.association_scheme = association_scheme

View File

@@ -698,6 +698,7 @@ class AddAdditional1dFeaturesToFeats(Transform):
atom_1d_features,
autofill_zeros_if_not_present_in_atomarray=False,
association_scheme="atom14",
association_scheme='atom14'
):
self.autofill = autofill_zeros_if_not_present_in_atomarray
self.token_1d_features = token_1d_features
@@ -754,6 +755,11 @@ class AddAdditional1dFeaturesToFeats(Transform):
"""
if "feats" not in data.keys():
data["feats"] = {}
if association_scheme == 'atom23':
data['atom_array'].set_annotation('is_protein_token', data['atom_array'].is_protein)
data['atom_array'].set_annotation('is_dna_token', data['atom_array'].is_dna)
data['atom_array'].set_annotation('is_rna_token', data['atom_array'].is_rna)
if self.association_scheme == "atom23":
data["atom_array"].set_annotation(

View File

@@ -523,7 +523,7 @@ def build_atom14_base_pipeline_(
autofill_zeros_if_not_present_in_atomarray=True,
token_1d_features=token_1d_features,
atom_1d_features=atom_1d_features,
association_scheme=association_scheme,
association_scheme=association_scheme
),
AddAF3TokenBondFeatures(),
AddGroundTruthSequence(sequence_encoding=af3_sequence_encoding),

View File

@@ -72,9 +72,11 @@ class IslandCondition(TrainingCondition):
p_fix_motif_coordinates,
p_fix_motif_sequence,
p_unindex_motif_tokens,
association_scheme = 'atom14',
):
self.name = name
self.frequency = frequency
self.association_scheme = association_scheme
# Token selection
self.island_sampling_kwargs = island_sampling_kwargs
@@ -89,13 +91,15 @@ class IslandCondition(TrainingCondition):
self.p_fix_motif_coordinates = p_fix_motif_coordinates
self.p_fix_motif_sequence = p_fix_motif_sequence
self.p_unindex_motif_tokens = p_unindex_motif_tokens
self.association_scheme = association_scheme
def is_valid_for_example(self, data) -> bool:
is_protein = data["atom_array"].is_protein
is_dna = data["atom_array"].is_dna
is_rna = data["atom_array"].is_rna
### updating this to allow other polymers
if not self.association_scheme == "atom23":
if self.association_scheme == "atom23":
if not np.any(is_protein | is_dna | is_rna):
return False
else:
@@ -134,6 +138,7 @@ class IslandCondition(TrainingCondition):
n_protein_tokens,
**self.island_sampling_kwargs,
)
is_motif_token[token_level_array.is_protein] = islands_mask
# TODO: Atoms with covalent bonds should be motif, needs FlagAndReassignCovalentModifications transform prior to this
@@ -305,7 +310,7 @@ class SubtypeCondition(TrainingCondition):
"""
name = "subtype"
association_scheme = "atom14"
association_scheme = 'atom14'
def __init__(self, frequency: float, subtype: list[str], fix_pos: bool = False):
self.frequency = frequency
@@ -521,6 +526,8 @@ def sample_is_motif_atom_with_fixed_seq(
is_motif_atom_with_fixed_seq = (
is_motif_atom_with_fixed_seq | ~atom_array.is_protein
)
return is_motif_atom_with_fixed_seq
@@ -564,6 +571,7 @@ def sample_unindexed_atoms(
is_motif_atom_unindexed, atom_array.is_residue
)
return is_motif_atom_unindexed

View File

@@ -53,7 +53,6 @@ def map_to_association_scheme(
else:
return ATOM_NAMES[idxs]
def map_names_to_elements(
atom_names: list | str, default=VIRTUAL_ATOM_ELEMENT_NAME
) -> np.ndarray:
@@ -130,11 +129,6 @@ def permute_symmetric_atom_names_(
# NB: Can leak GT sequence if the model receives the canconical ordering of atoms as input
# With the structure-local atom attention it will not unless N_keys(n_attn_seq_neighbours) > n_atom_attn_queries.
## fail safe, no symmetry confusion in NA bases ##
if atom_names[0] == "P":
return atom_names
##################################################
if res_name in association_map:
idx_to_swap = association_map[res_name]
atom_names = atom_names[idx_to_swap]
@@ -185,7 +179,7 @@ class PadTokensWithVirtualAtoms(Transform):
token_ids = np.unique(atom_array.token_id)
assert len(token_ids) == len(
is_motif_atom_with_fixed_seq
), "Token ids and token level array have different lengths!"
), "Token ids and token level array have different lengths!"
# Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
if self.association_scheme == "atom23":
@@ -226,7 +220,7 @@ class PadTokensWithVirtualAtoms(Transform):
for token_id, (start, end) in enumerate(zip(starts[:-1], starts[1:])):
if is_paddable[token_id]:
token = atom_array[start:end]
# First, pad with virtual atoms if needed
if self.association_scheme == "atom23" and atom_array[start].is_dna:
n_atoms_per_token = 22
@@ -235,7 +229,7 @@ class PadTokensWithVirtualAtoms(Transform):
else:
n_atoms_per_token = self.n_atoms_per_token
n_pad = n_atoms_per_token - len(token)
if n_pad > 0:
mask = get_af3_token_representative_masks(
token, central_atom=self.atom_to_pad_from