mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
feat: atom23 porting
This commit is contained in:
committed by
Raktim Mitra
parent
ebec466e4f
commit
94d9d635cd
@@ -242,35 +242,6 @@ SELECTION_NONPROTEIN = [
|
||||
"POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID",
|
||||
]
|
||||
|
||||
backbone_atomscheme_DNA = [
|
||||
" P ",
|
||||
" OP1",
|
||||
" OP2",
|
||||
" O5'",
|
||||
" C5'",
|
||||
" C4'",
|
||||
" O4'",
|
||||
" C3'",
|
||||
" O3'",
|
||||
" C2'",
|
||||
" C1'",
|
||||
] # , None]
|
||||
|
||||
backbone_atomscheme_RNA = [
|
||||
" P ",
|
||||
" OP1",
|
||||
" OP2",
|
||||
" O5'",
|
||||
" C5'",
|
||||
" C4'",
|
||||
" O4'",
|
||||
" C3'",
|
||||
" O3'",
|
||||
" C2'",
|
||||
" O2'",
|
||||
" C1'",
|
||||
]
|
||||
|
||||
DNA_atoms = {
|
||||
"DA": [
|
||||
" N9 ",
|
||||
@@ -438,3 +409,124 @@ association_schemes_stripped = {
|
||||
|
||||
backbone_atoms_RNA = strip_list(backbone_atomscheme_RNA)
|
||||
backbone_atoms_DNA = strip_list(backbone_atomscheme_DNA)
|
||||
|
||||
|
||||
# Mapping from residue type to its backbone and sidechain atoms (for convenience)
|
||||
ATOM_REGION_BY_RESI = {
|
||||
'ALA': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB')},
|
||||
'ARG': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD','NE','CZ','NH1','NH2')},
|
||||
'ASN': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','OD1','ND2')},
|
||||
'ASP': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','OD1','OD2')},
|
||||
'CYS': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','SG')},
|
||||
'GLN': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD','OE1','NE2')},
|
||||
'GLU': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD','OE1','OE2')},
|
||||
'GLY': {'bb':('N','CA','C','O'),
|
||||
'sc':()},
|
||||
'HIS': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','ND1','CD2','CE1','NE2')},
|
||||
'ILE': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG1','CG2','CD1')},
|
||||
'LEU': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD1','CD2')},
|
||||
'LYS': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD','CE','NZ')},
|
||||
'MET': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','SD','CE')},
|
||||
'PHE': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD1','CD2','CE1','CE2','CZ')},
|
||||
'PRO': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD')},
|
||||
'SER': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','OG')},
|
||||
'THR': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','OG1','CG2')},
|
||||
'TRP': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD1','CD2','CE2','CE3','NE1','CZ2','CZ3','CH2')},
|
||||
'TYR': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD1','CD2','CE1','CE2','CZ','OH')},
|
||||
'VAL': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG1','CG2')},
|
||||
'UNK': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB')},
|
||||
'MAS': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB')},
|
||||
'DA': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'sc':('N9','C4','N3','C2','N1','C6','C5','N7','C8','N6')},
|
||||
'DC': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'sc':('N1','C2','O2','N3','C4','N4','C5','C6')},
|
||||
'DG': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'sc':('N9','C4','N3','C2','N1','C6','C5','N7','C8','N2','O6')},
|
||||
'DT': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'sc':('N1','C2','O2','N3','C4','O4','C5','C7','C6')},
|
||||
'DX': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'sc':()},
|
||||
'A': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'sc':('N1','C2','N3','C4','C5','C6','N6','N7','C8','N9')},
|
||||
'C': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'sc':('N1','C2','O2','N3','C4','N4','C5','C6')},
|
||||
'G': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'sc':('N1','C2','N2','N3','C4','C5','C6','O6','N7','C8','N9')},
|
||||
'U': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'sc':('N1','C2','O2','N3','C4','O4','C5','C6')},
|
||||
'X': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'sc':()},
|
||||
'HIS_D': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','NE2','CD2','CE1','ND1')},
|
||||
}
|
||||
|
||||
# Known planar sidechain atoms for each canonical residue type:
|
||||
PLANAR_ATOMS_BY_RESI = {
|
||||
'ALA': [],
|
||||
'ARG': ['NH1', 'NH2', 'CZ', 'NE', 'CD'],
|
||||
'ASN': ['OD1', 'ND2', 'CG', 'CB'],
|
||||
'ASP': ['OD1', 'OD2', 'CG', 'CB'],
|
||||
'CYS': [],
|
||||
'GLN': ['OE1', 'NE2', 'CD', 'CG'],
|
||||
'GLU': ['OE1', 'OE2', 'CD', 'CG'],
|
||||
'GLY': [],
|
||||
'HIS': ['ND1', 'CE1', 'NE2', 'CD2', 'CG', 'CB'],
|
||||
'ILE': [],
|
||||
'LEU': [],
|
||||
'LYS': [],
|
||||
'MET': [],
|
||||
'PHE': ['CZ', 'CE1', 'CE2', 'CD1', 'CD2', 'CG', 'CB'],
|
||||
'PRO': [],
|
||||
'SER': [],
|
||||
'THR': [],
|
||||
'TRP': ['CH2', 'CZ3', 'CZ2', 'CE3', 'CE2', 'CD2', 'NE1', 'CD1', 'CG', 'CB'],
|
||||
'TYR': ['OH', 'CZ', 'CE1', 'CE2', 'CD1', 'CD2', 'CG', 'CB'],
|
||||
'VAL': [],
|
||||
'UNK': [],
|
||||
'MAS': [],
|
||||
'DA': ['N6', 'C6', 'N1', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'],
|
||||
'DC': ['N4', 'C4', 'N3', 'O2', 'C2', 'C5', 'C6', 'N1'],
|
||||
'DG': ['O6', 'C6', 'N1', 'N2', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'],
|
||||
'DT': ['O4', 'O2', 'N3', 'C4', 'C2', 'C5', 'C6', 'N1', 'C7'],
|
||||
'DX': [],
|
||||
'A': ['N6', 'C6', 'N1', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'],
|
||||
'C': ['N4', 'C4', 'N3', 'O2', 'C2', 'C5', 'C6', 'N1'],
|
||||
'G': ['O6', 'C6', 'N1', 'N2', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'],
|
||||
'U': ['O4', 'O2', 'N3', 'C4', 'C2', 'C5', 'C6', 'N1'],
|
||||
'X': [],
|
||||
'HIS_D': ['ND1', 'CD2', 'CE1', 'NE2', 'CG', 'CB'],
|
||||
}
|
||||
|
||||
# fix C/U symmetry
|
||||
temp = list(association_schemes['atom23']['U'])
|
||||
temp[19], temp[20] = temp[20], temp[19]
|
||||
association_schemes['atom23']['U'] = tuple(temp)
|
||||
|
||||
association_schemes_stripped = {
|
||||
name: {k: strip_list(v) for k, v in scheme.items()}
|
||||
for name, scheme in association_schemes.items()
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pdb; pdb.set_trace()
|
||||
|
||||
@@ -281,7 +281,7 @@ class SampleConditioningType(Transform):
|
||||
cond = valid_conditions[i_cond]
|
||||
|
||||
cond.association_scheme = self.association_scheme
|
||||
|
||||
|
||||
data["sampled_condition"] = cond
|
||||
data["sampled_condition_name"] = cond.name
|
||||
data["sampled_condition_cls"] = cond.__class__
|
||||
@@ -299,6 +299,8 @@ class SampleConditioningFlags(Transform):
|
||||
"AssignTypes",
|
||||
"SampleConditioningType",
|
||||
] # We use is_protein in the PPI training condition
|
||||
def __init__(self, association_scheme):
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
def __init__(self, association_scheme):
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
@@ -698,6 +698,7 @@ class AddAdditional1dFeaturesToFeats(Transform):
|
||||
atom_1d_features,
|
||||
autofill_zeros_if_not_present_in_atomarray=False,
|
||||
association_scheme="atom14",
|
||||
association_scheme='atom14'
|
||||
):
|
||||
self.autofill = autofill_zeros_if_not_present_in_atomarray
|
||||
self.token_1d_features = token_1d_features
|
||||
@@ -754,6 +755,11 @@ class AddAdditional1dFeaturesToFeats(Transform):
|
||||
"""
|
||||
if "feats" not in data.keys():
|
||||
data["feats"] = {}
|
||||
|
||||
if association_scheme == 'atom23':
|
||||
data['atom_array'].set_annotation('is_protein_token', data['atom_array'].is_protein)
|
||||
data['atom_array'].set_annotation('is_dna_token', data['atom_array'].is_dna)
|
||||
data['atom_array'].set_annotation('is_rna_token', data['atom_array'].is_rna)
|
||||
|
||||
if self.association_scheme == "atom23":
|
||||
data["atom_array"].set_annotation(
|
||||
|
||||
@@ -523,7 +523,7 @@ def build_atom14_base_pipeline_(
|
||||
autofill_zeros_if_not_present_in_atomarray=True,
|
||||
token_1d_features=token_1d_features,
|
||||
atom_1d_features=atom_1d_features,
|
||||
association_scheme=association_scheme,
|
||||
association_scheme=association_scheme
|
||||
),
|
||||
AddAF3TokenBondFeatures(),
|
||||
AddGroundTruthSequence(sequence_encoding=af3_sequence_encoding),
|
||||
|
||||
@@ -72,9 +72,11 @@ class IslandCondition(TrainingCondition):
|
||||
p_fix_motif_coordinates,
|
||||
p_fix_motif_sequence,
|
||||
p_unindex_motif_tokens,
|
||||
association_scheme = 'atom14',
|
||||
):
|
||||
self.name = name
|
||||
self.frequency = frequency
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
# Token selection
|
||||
self.island_sampling_kwargs = island_sampling_kwargs
|
||||
@@ -89,13 +91,15 @@ class IslandCondition(TrainingCondition):
|
||||
self.p_fix_motif_coordinates = p_fix_motif_coordinates
|
||||
self.p_fix_motif_sequence = p_fix_motif_sequence
|
||||
self.p_unindex_motif_tokens = p_unindex_motif_tokens
|
||||
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
def is_valid_for_example(self, data) -> bool:
|
||||
is_protein = data["atom_array"].is_protein
|
||||
is_dna = data["atom_array"].is_dna
|
||||
is_rna = data["atom_array"].is_rna
|
||||
### updating this to allow other polymers
|
||||
if not self.association_scheme == "atom23":
|
||||
if self.association_scheme == "atom23":
|
||||
if not np.any(is_protein | is_dna | is_rna):
|
||||
return False
|
||||
else:
|
||||
@@ -134,6 +138,7 @@ class IslandCondition(TrainingCondition):
|
||||
n_protein_tokens,
|
||||
**self.island_sampling_kwargs,
|
||||
)
|
||||
|
||||
is_motif_token[token_level_array.is_protein] = islands_mask
|
||||
|
||||
# TODO: Atoms with covalent bonds should be motif, needs FlagAndReassignCovalentModifications transform prior to this
|
||||
@@ -305,7 +310,7 @@ class SubtypeCondition(TrainingCondition):
|
||||
"""
|
||||
|
||||
name = "subtype"
|
||||
association_scheme = "atom14"
|
||||
association_scheme = 'atom14'
|
||||
|
||||
def __init__(self, frequency: float, subtype: list[str], fix_pos: bool = False):
|
||||
self.frequency = frequency
|
||||
@@ -521,6 +526,8 @@ def sample_is_motif_atom_with_fixed_seq(
|
||||
is_motif_atom_with_fixed_seq = (
|
||||
is_motif_atom_with_fixed_seq | ~atom_array.is_protein
|
||||
)
|
||||
|
||||
|
||||
|
||||
return is_motif_atom_with_fixed_seq
|
||||
|
||||
@@ -564,6 +571,7 @@ def sample_unindexed_atoms(
|
||||
is_motif_atom_unindexed, atom_array.is_residue
|
||||
)
|
||||
|
||||
|
||||
return is_motif_atom_unindexed
|
||||
|
||||
|
||||
|
||||
@@ -53,7 +53,6 @@ def map_to_association_scheme(
|
||||
else:
|
||||
return ATOM_NAMES[idxs]
|
||||
|
||||
|
||||
def map_names_to_elements(
|
||||
atom_names: list | str, default=VIRTUAL_ATOM_ELEMENT_NAME
|
||||
) -> np.ndarray:
|
||||
@@ -130,11 +129,6 @@ def permute_symmetric_atom_names_(
|
||||
# NB: Can leak GT sequence if the model receives the canconical ordering of atoms as input
|
||||
# With the structure-local atom attention it will not unless N_keys(n_attn_seq_neighbours) > n_atom_attn_queries.
|
||||
|
||||
## fail safe, no symmetry confusion in NA bases ##
|
||||
if atom_names[0] == "P":
|
||||
return atom_names
|
||||
##################################################
|
||||
|
||||
if res_name in association_map:
|
||||
idx_to_swap = association_map[res_name]
|
||||
atom_names = atom_names[idx_to_swap]
|
||||
@@ -185,7 +179,7 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
token_ids = np.unique(atom_array.token_id)
|
||||
assert len(token_ids) == len(
|
||||
is_motif_atom_with_fixed_seq
|
||||
), "Token ids and token level array have different lengths!"
|
||||
), "Token ids and token level array have different lengths!"
|
||||
|
||||
# Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
|
||||
if self.association_scheme == "atom23":
|
||||
@@ -226,7 +220,7 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
for token_id, (start, end) in enumerate(zip(starts[:-1], starts[1:])):
|
||||
if is_paddable[token_id]:
|
||||
token = atom_array[start:end]
|
||||
|
||||
|
||||
# First, pad with virtual atoms if needed
|
||||
if self.association_scheme == "atom23" and atom_array[start].is_dna:
|
||||
n_atoms_per_token = 22
|
||||
@@ -235,7 +229,7 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
else:
|
||||
n_atoms_per_token = self.n_atoms_per_token
|
||||
n_pad = n_atoms_per_token - len(token)
|
||||
|
||||
|
||||
if n_pad > 0:
|
||||
mask = get_af3_token_representative_masks(
|
||||
token, central_atom=self.atom_to_pad_from
|
||||
|
||||
Reference in New Issue
Block a user