mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
feat: atom23 porting
This commit is contained in:
committed by
Raktim Mitra
parent
ece2498c5d
commit
4a7aaf8793
@@ -72,6 +72,37 @@ ccd_ordering_atomchar = {
|
||||
'GLY': (" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None), # gly
|
||||
'UNK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # unk
|
||||
'MSK': (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None), # mask
|
||||
'DA': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'",
|
||||
' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 ',
|
||||
None),
|
||||
|
||||
'DC': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'",
|
||||
' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 ',
|
||||
None, None, None),
|
||||
|
||||
'DG': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'",
|
||||
' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '),
|
||||
|
||||
'DT': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'",
|
||||
' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C7 ', ' C6 ',
|
||||
None, None),
|
||||
|
||||
'A' : (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'",
|
||||
' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 ',
|
||||
None),
|
||||
|
||||
'C' : (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'",
|
||||
' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 ',
|
||||
None, None, None),
|
||||
|
||||
'G' : (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'",
|
||||
' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '),
|
||||
|
||||
'U' : (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'",
|
||||
' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C6 ',
|
||||
None, None, None),
|
||||
'DX': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'", None, None, None, None, None, None, None, None, None, None, None), #dna_mask
|
||||
'X': (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'", None, None, None, None, None, None, None, None, None, None, None), #rna mask
|
||||
}
|
||||
"""Canonical ordering of amino acid and nucleotide (DNA/RNA) atom names in the CCD."""
|
||||
|
||||
@@ -210,3 +241,70 @@ SELECTION_NONPROTEIN = [
|
||||
"MACROLIDE",
|
||||
"POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID",
|
||||
]
|
||||
|
||||
backbone_atomscheme_DNA = [' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'"]#, None]
|
||||
|
||||
backbone_atomscheme_RNA = [' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'"]
|
||||
|
||||
DNA_atoms = {
|
||||
'DA': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 '],
|
||||
'DC': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 '],
|
||||
'DG': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '],
|
||||
'DT': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C7 ', ' C6 ']}
|
||||
|
||||
RNA_atoms = {
|
||||
'A': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' N6 ', ' N1 ', ' C2 ', ' N3 ', ' C4 '],
|
||||
'C': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' N4 ', ' C5 ', ' C6 '],
|
||||
'G': [' N9 ', ' C8 ', ' N7 ', ' C5 ', ' C6 ', ' O6 ', ' N1 ', ' C2 ', ' N2 ', ' N3 ', ' C4 '],
|
||||
'U': [' N1 ', ' C2 ', ' O2 ', ' N3 ', ' C4 ', ' O4 ', ' C5 ', ' C6 ']
|
||||
}
|
||||
|
||||
association_schemes['atom23'] = {}
|
||||
for item in DNA_atoms:
|
||||
association_schemes['atom23'][item] = tuple(backbone_atomscheme_DNA + DNA_atoms[item]+ [None]*(22 - len(DNA_atoms[item] + backbone_atomscheme_DNA)))
|
||||
for item in RNA_atoms:
|
||||
association_schemes['atom23'][item] = tuple(backbone_atomscheme_RNA + RNA_atoms[item]+ [None]*(23 - len(RNA_atoms[item] + backbone_atomscheme_RNA)))
|
||||
|
||||
for item in association_schemes['dense']:
|
||||
association_schemes['atom23'][item] = association_schemes['dense'][item]
|
||||
|
||||
association_schemes['atom23']['DX'] = (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " C1'", None, None, None, None, None, None, None, None, None, None, None) # DNA mask entry: DNA backbone atoms only, remaining slots left as None
|
||||
association_schemes['atom23']['X'] = (' P ', ' OP1', ' OP2', " O5'", " C5'", " C4'", " O4'", " C3'", " O3'", " C2'", " O2'", " C1'", None, None, None, None, None, None, None, None, None, None, None)#rna mask
|
||||
|
||||
ATOM23_ATOM_NAMES_RNA = np.array(
|
||||
[item.strip() for item in backbone_atomscheme_RNA] + [f"V{i}" for i in range(23 - len(backbone_atomscheme_RNA))]
|
||||
)
|
||||
"""Atom23 atom names for RNA: stripped backbone atom names followed by virtual atom names (e.g. P, V1)"""
|
||||
|
||||
ATOM23_ATOM_ELEMENTS_RNA = np.array(
|
||||
["P", "O", "O", "O", "C", "C", "O", "C","O", "C", "O", "C"] + [VIRTUAL_ATOM_ELEMENT_NAME for i in range(23 - len(backbone_atomscheme_RNA))]
|
||||
)
|
||||
"""Atom23 element names (e.g. C, VX)"""
|
||||
|
||||
ATOM23_ATOM_NAME_TO_ELEMENT = {
|
||||
name: elem for name, elem in zip(ATOM23_ATOM_NAMES_RNA, ATOM23_ATOM_ELEMENTS_RNA)
|
||||
}
|
||||
ATOM23_ATOM_NAMES_DNA = np.array(
|
||||
[item.strip() for item in backbone_atomscheme_DNA] + [f"V{i}" for i in range(22 - len(backbone_atomscheme_DNA))]
|
||||
)
|
||||
"""Atom23 atom names for DNA: stripped backbone atom names followed by virtual atom names (e.g. P, V1)"""
|
||||
|
||||
ATOM23_ATOM_ELEMENTS_DNA = np.array(
|
||||
["P", "O", "O", "O", "C", "C", "O", "C","O", "C", "C"] + [VIRTUAL_ATOM_ELEMENT_NAME for i in range(22 - len(backbone_atomscheme_DNA))]
|
||||
)
|
||||
"""Atom23 element names (e.g. C, VX)"""
|
||||
|
||||
|
||||
"""ATOM23_ATOM_NAME_TO_ELEMENT maps atom23 atom names (e.g. P, V1) to their corresponding element names (e.g. P, VX); it is extended with the atom14 mapping below."""
|
||||
## combining name to element mapping, should be fine
|
||||
for item in ATOM14_ATOM_NAME_TO_ELEMENT:
|
||||
ATOM23_ATOM_NAME_TO_ELEMENT[item] = ATOM14_ATOM_NAME_TO_ELEMENT[item]
|
||||
|
||||
association_schemes_stripped = {
|
||||
name: {k: strip_list(v) for k, v in scheme.items()}
|
||||
for name, scheme in association_schemes.items()
|
||||
}
|
||||
|
||||
backbone_atoms_RNA = strip_list(backbone_atomscheme_RNA)
|
||||
backbone_atoms_DNA = strip_list(backbone_atomscheme_DNA)
|
||||
|
||||
|
||||
@@ -235,6 +235,7 @@ class SampleConditioningType(Transform):
|
||||
train_conditions: dict,
|
||||
meta_conditioning_probabilities: dict,
|
||||
sequence_encoding,
|
||||
association_scheme,
|
||||
):
|
||||
if exists(train_conditions):
|
||||
train_conditions = hydra.utils.instantiate(
|
||||
@@ -243,6 +244,7 @@ class SampleConditioningType(Transform):
|
||||
self.meta_conditioning_probabilities = meta_conditioning_probabilities
|
||||
self.train_conditions = train_conditions
|
||||
self.sequence_encoding = sequence_encoding
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
def check_input(self, data: dict):
|
||||
assert not data["is_inference"], "This transform is only used during training!"
|
||||
@@ -278,6 +280,8 @@ class SampleConditioningType(Transform):
|
||||
i_cond = np.random.choice(np.arange(len(p_cond)), p=p_cond)
|
||||
cond = valid_conditions[i_cond]
|
||||
|
||||
cond.association_scheme = self.association_scheme
|
||||
|
||||
data["sampled_condition"] = cond
|
||||
data["sampled_condition_name"] = cond.name
|
||||
data["sampled_condition_cls"] = cond.__class__
|
||||
@@ -295,6 +299,8 @@ class SampleConditioningFlags(Transform):
|
||||
"AssignTypes",
|
||||
"SampleConditioningType",
|
||||
] # We use is_protein in the PPI training condition
|
||||
def __init__(self, association_scheme="atom14"):
    """
    Args:
        association_scheme: Name of the atom association scheme to store on the
            transform (e.g. "atom14" or "atom23"). Defaults to "atom14" so that
            existing zero-argument construction keeps working.
    """
    # Run the base-class initialiser, as the sibling transforms in this file do.
    super().__init__()
    self.association_scheme = association_scheme
|
||||
|
||||
def check_input(self, data):
|
||||
assert not data[
|
||||
@@ -317,13 +323,14 @@ class UnindexFlaggedTokens(Transform):
|
||||
Serves as the merge point between training / inference conditioning pipelines
|
||||
"""
|
||||
|
||||
def __init__(self, central_atom):
|
||||
def __init__(self, central_atom, association_scheme="atom14"):
    """
    Args:
        central_atom: The atom to use as the central atom for unindexed motifs.
        association_scheme: Name of the atom association scheme (e.g. "atom14"
            or "atom23"). Defaults to "atom14" so call sites that only pass
            ``central_atom`` keep working.
    """
    super().__init__()
    self.central_atom = central_atom
    self.association_scheme = association_scheme
|
||||
|
||||
def check_input(self, data: dict):
|
||||
check_contains_keys(data, ["atom_array"])
|
||||
@@ -368,8 +375,16 @@ class UnindexFlaggedTokens(Transform):
|
||||
token.res_id = token.res_id + max_resid
|
||||
token.is_C_terminus[:] = False
|
||||
token.is_N_terminus[:] = False
|
||||
assert token.is_protein.all(), f"Cannot unindex non-protein token: {token}"
|
||||
token = add_representative_atom(token, central_atom=self.central_atom)
|
||||
|
||||
# Read the scheme from the instance (there is no local `association_scheme`
# in this method) and compare by value, not identity.
if self.association_scheme != "atom23":
    # Outside atom23 only protein tokens may be unindexed.
    assert token.is_protein.all(), f"Cannot unindex non-protein token: {token} unless using atom23 association scheme"
    token = add_representative_atom(token, central_atom=self.central_atom)
elif token.is_protein.all():
    token = add_representative_atom(token, central_atom=self.central_atom)
else:
    # Non-protein token under atom23: use the sugar C1' atom as representative.
    token = add_representative_atom(token, central_atom="C1'")
|
||||
|
||||
unindexed_tokens.append(token)
|
||||
|
||||
# ... Remove original tokens e.g. during inference
|
||||
|
||||
@@ -697,6 +697,7 @@ class AddAdditional1dFeaturesToFeats(Transform):
|
||||
token_1d_features,
|
||||
atom_1d_features,
|
||||
autofill_zeros_if_not_present_in_atomarray=False,
|
||||
association_scheme='atom14'
|
||||
):
|
||||
self.autofill = autofill_zeros_if_not_present_in_atomarray
|
||||
self.token_1d_features = token_1d_features
|
||||
@@ -752,6 +753,11 @@ class AddAdditional1dFeaturesToFeats(Transform):
|
||||
"""
|
||||
if "feats" not in data.keys():
|
||||
data["feats"] = {}
|
||||
|
||||
if association_scheme == 'atom23':
|
||||
data['atom_array'].set_annotation('is_protein_token', data['atom_array'].is_protein)
|
||||
data['atom_array'].set_annotation('is_dna_token', data['atom_array'].is_dna)
|
||||
data['atom_array'].set_annotation('is_rna_token', data['atom_array'].is_rna)
|
||||
|
||||
for feature_name, n_dims in self.token_1d_features.items():
|
||||
data = self.generate_feature(feature_name, n_dims, data, "token")
|
||||
|
||||
@@ -383,6 +383,7 @@ def build_atom14_base_pipeline_(
|
||||
train_conditions=train_conditions,
|
||||
meta_conditioning_probabilities=meta_conditioning_probabilities,
|
||||
sequence_encoding=af3_sequence_encoding,
|
||||
association_scheme=association_scheme
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -422,7 +423,7 @@ def build_atom14_base_pipeline_(
|
||||
# ... Add global token features (since number of tokens is fixed after cropping)
|
||||
transforms.append(AddGlobalTokenIdAnnotation())
|
||||
# ... Create masks (NOTE: Modulates token count, and resets global token id if necessary)
|
||||
transforms.append(TrainingRoute(SampleConditioningFlags()))
|
||||
transforms.append(TrainingRoute(SampleConditioningFlags(association_scheme=association_scheme)))
|
||||
|
||||
# Post-crop transforms
|
||||
transforms.append(
|
||||
@@ -518,6 +519,7 @@ def build_atom14_base_pipeline_(
|
||||
autofill_zeros_if_not_present_in_atomarray=True,
|
||||
token_1d_features=token_1d_features,
|
||||
atom_1d_features=atom_1d_features,
|
||||
association_scheme=association_scheme
|
||||
),
|
||||
AddAF3TokenBondFeatures(),
|
||||
AddGroundTruthSequence(sequence_encoding=af3_sequence_encoding),
|
||||
|
||||
@@ -70,9 +70,11 @@ class IslandCondition(TrainingCondition):
|
||||
p_fix_motif_coordinates,
|
||||
p_fix_motif_sequence,
|
||||
p_unindex_motif_tokens,
|
||||
association_scheme = 'atom14',
|
||||
):
|
||||
self.name = name
|
||||
self.frequency = frequency
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
# Token selection
|
||||
self.island_sampling_kwargs = island_sampling_kwargs
|
||||
@@ -87,11 +89,21 @@ class IslandCondition(TrainingCondition):
|
||||
self.p_fix_motif_coordinates = p_fix_motif_coordinates
|
||||
self.p_fix_motif_sequence = p_fix_motif_sequence
|
||||
self.p_unindex_motif_tokens = p_unindex_motif_tokens
|
||||
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
def is_valid_for_example(self, data) -> bool:
    """Return True if the example contains at least one token this condition can sample.

    Under the "atom23" association scheme protein, DNA and RNA atoms all count;
    under any other scheme only protein atoms count.

    Args:
        data: Pipeline dict; ``data["atom_array"]`` must expose boolean
            ``is_protein`` (and, for atom23, ``is_dna`` / ``is_rna``) annotations.

    Returns:
        Whether any eligible polymer atom is present.
    """
    atom_array = data["atom_array"]
    eligible = atom_array.is_protein

    # Compare by value: `is` on a string literal tests object identity and is
    # not guaranteed to match an equal string coming from a config file.
    if self.association_scheme == "atom23":
        eligible = eligible | atom_array.is_dna | atom_array.is_rna

    return bool(np.any(eligible))
|
||||
|
||||
def sample_motif_tokens(self, atom_array):
|
||||
@@ -101,13 +113,24 @@ class IslandCondition(TrainingCondition):
|
||||
token_level_array = atom_array[get_token_starts(atom_array)]
|
||||
|
||||
# initialize motif tokens as all non-protein tokens
|
||||
is_motif_token = np.asarray(~token_level_array.is_protein, dtype=bool).copy()
|
||||
n_protein_tokens = np.sum(token_level_array.is_protein)
|
||||
islands_mask = sample_island_tokens(
|
||||
n_protein_tokens,
|
||||
**self.island_sampling_kwargs,
|
||||
)
|
||||
is_motif_token[token_level_array.is_protein] = islands_mask
|
||||
# Tokens eligible for island sampling depend on the association scheme;
# every other token starts out flagged as motif.
if self.association_scheme == "atom23":
    # atom23: protein, DNA and RNA tokens are all sampled.
    polymer_mask = (
        token_level_array.is_protein
        | token_level_array.is_dna
        | token_level_array.is_rna
    )
    is_motif_token = np.asarray(~polymer_mask, dtype=bool).copy()
    n_polymer_tokens = np.sum(polymer_mask)
    islands_mask = sample_island_tokens(
        n_polymer_tokens,
        **self.island_sampling_kwargs,
    )
    is_motif_token[polymer_mask] = islands_mask
else:
    # Other schemes: only protein tokens are sampled.
    is_motif_token = np.asarray(~token_level_array.is_protein, dtype=bool).copy()
    n_protein_tokens = np.sum(token_level_array.is_protein)
    islands_mask = sample_island_tokens(
        n_protein_tokens,
        **self.island_sampling_kwargs,
    )
    is_motif_token[token_level_array.is_protein] = islands_mask
|
||||
|
||||
# TODO: Atoms with covalent bonds should be motif, needs FlagAndReassignCovalentModifications transform prior to this
|
||||
# atom_with_coval_bond = token_level_array.covale # (n_atoms, )
|
||||
@@ -137,6 +160,7 @@ class IslandCondition(TrainingCondition):
|
||||
is_motif_atom = sample_motif_subgraphs(
|
||||
atom_array=atom_array,
|
||||
**self.subgraph_sampling_kwargs,
|
||||
association_scheme=self.association_scheme
|
||||
)
|
||||
|
||||
# We also only want resolved atoms to be motif
|
||||
@@ -158,6 +182,7 @@ class IslandCondition(TrainingCondition):
|
||||
p_fix_motif_sequence=self.p_fix_motif_sequence,
|
||||
p_fix_motif_coordinates=self.p_fix_motif_coordinates,
|
||||
p_unindex_motif_tokens=self.p_unindex_motif_tokens,
|
||||
association_scheme=self.association_scheme
|
||||
)
|
||||
|
||||
atom_array.set_annotation(
|
||||
@@ -177,6 +202,7 @@ class PPICondition(TrainingCondition):
|
||||
"""Get condition indicating what is motif and what is to be diffused for protein-protein interaction training."""
|
||||
|
||||
name = "ppi"
|
||||
association_scheme = 'atom14'
|
||||
|
||||
def is_valid_for_example(self, data):
|
||||
# Extract relevant data
|
||||
@@ -275,6 +301,7 @@ class SubtypeCondition(TrainingCondition):
|
||||
"""
|
||||
|
||||
name = "subtype"
|
||||
association_scheme = 'atom14'
|
||||
|
||||
def __init__(self, frequency: float, subtype: list[str], fix_pos: bool = False):
|
||||
self.frequency = frequency
|
||||
@@ -370,6 +397,7 @@ def sample_motif_subgraphs(
|
||||
hetatom_n_bond_expectation,
|
||||
residue_p_fix_all,
|
||||
hetatom_p_fix_all,
|
||||
association_scheme = 'atom14'
|
||||
):
|
||||
"""
|
||||
Returns a boolean mask over atoms, indicating which atoms are part of the sampled motif.
|
||||
@@ -402,7 +430,13 @@ def sample_motif_subgraphs(
|
||||
"n_bond_expectation": residue_n_bond_expectation,
|
||||
"p_fix_all": residue_p_fix_all,
|
||||
}
|
||||
if not atom_array_subset.is_protein.all():
|
||||
|
||||
# `clause` is True when the subset is a single polymer type that gets the
# residue-style sampling arguments.
if association_scheme == "atom23":
    clause = (
        atom_array_subset.is_protein.all()
        | atom_array_subset.is_dna.all()
        | atom_array_subset.is_rna.all()
    )
else:
    clause = atom_array_subset.is_protein.all()
|
||||
|
||||
if not clause:
|
||||
args.update(
|
||||
{
|
||||
"p_seed_furthest_from_o": 0.0,
|
||||
@@ -431,11 +465,12 @@ def sample_conditioning_strategy(
|
||||
p_fix_motif_sequence,
|
||||
p_fix_motif_coordinates,
|
||||
p_unindex_motif_tokens,
|
||||
association_scheme
|
||||
):
|
||||
atom_array.set_annotation(
|
||||
"is_motif_atom_with_fixed_seq",
|
||||
sample_is_motif_atom_with_fixed_seq(
|
||||
atom_array, p_fix_motif_sequence=p_fix_motif_sequence
|
||||
atom_array, p_fix_motif_sequence=p_fix_motif_sequence, association_scheme=association_scheme
|
||||
),
|
||||
)
|
||||
|
||||
@@ -456,7 +491,7 @@ def sample_conditioning_strategy(
|
||||
return atom_array
|
||||
|
||||
|
||||
def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence):
|
||||
def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence, association_scheme):
|
||||
"""
|
||||
Samples what kind of conditioning to apply to motif tokens.
|
||||
|
||||
@@ -469,7 +504,11 @@ def sample_is_motif_atom_with_fixed_seq(atom_array, p_fix_motif_sequence):
|
||||
is_motif_atom_with_fixed_seq = np.zeros(atom_array.array_length(), dtype=bool)
|
||||
|
||||
# By default reveal sequence for non-protein
|
||||
is_motif_atom_with_fixed_seq = is_motif_atom_with_fixed_seq | ~atom_array.is_protein
|
||||
|
||||
# Outside atom23, non-protein atoms always have their sequence revealed.
# Compare by value; `is not` on a string literal tests object identity.
if association_scheme != "atom23":
    is_motif_atom_with_fixed_seq = is_motif_atom_with_fixed_seq | ~atom_array.is_protein
|
||||
|
||||
|
||||
return is_motif_atom_with_fixed_seq
|
||||
|
||||
|
||||
@@ -487,7 +526,7 @@ def sample_fix_motif_coordinates(atom_array, p_fix_motif_coordinates):
|
||||
return is_motif_atom_with_fixed_coord
|
||||
|
||||
|
||||
def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens):
|
||||
def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens, association_scheme='atom14'):
|
||||
"""
|
||||
Samples which atoms in motif tokens should be flagged for unindexing.
|
||||
|
||||
@@ -500,9 +539,15 @@ def sample_unindexed_atoms(atom_array, p_unindex_motif_tokens):
|
||||
is_motif_atom_unindexed = np.zeros(atom_array.array_length(), dtype=bool)
|
||||
|
||||
# ensure non-residue atoms are not already flagged
|
||||
is_motif_atom_unindexed = np.logical_and(
|
||||
is_motif_atom_unindexed, atom_array.is_residue
|
||||
)
|
||||
if association_scheme == 'atom23':
|
||||
is_motif_atom_unindexed = np.logical_and(
|
||||
is_motif_atom_unindexed, (atom_array.is_residue | atom_array.is_dna | atom_array.is_rna)
|
||||
) # is_residue refers to is_protein here
|
||||
else:
|
||||
is_motif_atom_unindexed = np.logical_and(
|
||||
is_motif_atom_unindexed, atom_array.is_residue
|
||||
)
|
||||
|
||||
|
||||
return is_motif_atom_unindexed
|
||||
|
||||
|
||||
@@ -252,13 +252,13 @@ def get_af3_token_representative_masks(
|
||||
atom_array: AtomArray, central_atom: str = "CA"
|
||||
) -> np.ndarray:
|
||||
pyrimidine_representative_atom = is_pyrimidine(atom_array.res_name) & (
|
||||
atom_array.atom_name == "C2"
|
||||
atom_array.atom_name == "C1'"
|
||||
)
|
||||
purine_representative_atom = is_purine(atom_array.res_name) & (
|
||||
atom_array.atom_name == "C4"
|
||||
atom_array.atom_name == "C1'"
|
||||
)
|
||||
unknown_na_representative_atom = is_unknown_nucleotide(atom_array.res_name) & (
|
||||
atom_array.atom_name == "C4"
|
||||
atom_array.atom_name == "C1'"
|
||||
)
|
||||
|
||||
glycine_representative_atom = is_glycine(atom_array.res_name) & (
|
||||
|
||||
@@ -10,8 +10,11 @@ from atomworks.ml.transforms.base import (
|
||||
)
|
||||
from atomworks.ml.utils.token import get_token_starts
|
||||
from rfd3.constants import (
|
||||
ATOM23_ATOM_NAME_TO_ELEMENT,
|
||||
ATOM14_ATOM_NAME_TO_ELEMENT,
|
||||
ATOM14_ATOM_NAMES,
|
||||
ATOM23_ATOM_NAMES_RNA,
|
||||
ATOM23_ATOM_NAMES_DNA,
|
||||
VIRTUAL_ATOM_ELEMENT_NAME,
|
||||
association_schemes,
|
||||
association_schemes_stripped,
|
||||
@@ -28,7 +31,7 @@ from rfd3.transforms.util_transforms import (
|
||||
from foundry.common import exists
|
||||
|
||||
|
||||
def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="atom14"):
|
||||
def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="atom14", ATOM_NAMES=None):
|
||||
"""
|
||||
Maps a list of names to the atom14 naming scheme for that particular name (within a specific residue)
|
||||
NB this function is a bit more general since it is used to handle tipatoms too.
|
||||
@@ -37,17 +40,17 @@ def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="ato
|
||||
raise ValueError(
|
||||
f"Scheme {scheme} not found in association_schemes_stripped. Available schemes: {list(association_schemes_stripped.keys())}"
|
||||
)
|
||||
atom_names = (
|
||||
[str(atom_names)] if isinstance(atom_names, (str, np.str_)) else atom_names
|
||||
)
|
||||
atom_names = [atom_names] if isinstance(atom_names, str) else atom_names
|
||||
idxs = np.array(
|
||||
[
|
||||
association_schemes_stripped[scheme][res_name].index(name)
|
||||
for name in atom_names
|
||||
]
|
||||
)
|
||||
return ATOM14_ATOM_NAMES[idxs]
|
||||
|
||||
if ATOM_NAMES is None:
|
||||
return ATOM14_ATOM_NAMES[idxs]
|
||||
else:
|
||||
return ATOM_NAMES[idxs]
|
||||
|
||||
def map_names_to_elements(
|
||||
atom_names: list | str, default=VIRTUAL_ATOM_ELEMENT_NAME
|
||||
@@ -68,17 +71,17 @@ def generate_atom_mappings_(scheme="atom14"):
|
||||
atom_mapping = {}
|
||||
symmetry_mapping = {}
|
||||
|
||||
for aaa, atom14_names in ccd_ordering_atomchar.items():
|
||||
mapping = list(range(14))
|
||||
for aaa, atom_names in ccd_ordering_atomchar.items():
|
||||
mapping = list(range(len(atom_names)))
|
||||
scheme_names = scheme[aaa]
|
||||
|
||||
for ccd_index in range(len(atom14_names)):
|
||||
atom14_name = atom14_names[ccd_index]
|
||||
if atom14_name is not None:
|
||||
for ccd_index in range(len(atom_names)):
|
||||
atom_name = atom_names[ccd_index]
|
||||
if atom_name is not None:
|
||||
assert (
|
||||
atom14_name in scheme_names
|
||||
), f"{atom14_name} not in CCD ordering for {aaa}"
|
||||
scheme_index = scheme_names.index(atom14_name)
|
||||
atom_name in scheme_names
|
||||
), f"{atom_name} not in CCD ordering for {aaa}"
|
||||
scheme_index = scheme_names.index(atom_name)
|
||||
scheme_index_in_cur_mapping = mapping.index(scheme_index)
|
||||
mapping[ccd_index], mapping[scheme_index_in_cur_mapping] = (
|
||||
mapping[scheme_index_in_cur_mapping],
|
||||
@@ -121,6 +124,12 @@ def permute_symmetric_atom_names_(
|
||||
) -> list:
|
||||
# NB: Can leak GT sequence if the model receives the canonical ordering of atoms as input
|
||||
# With the structure-local atom attention it will not unless N_keys(n_attn_seq_neighbours) > n_atom_attn_queries.
|
||||
|
||||
## fail safe, no symmetry confusion in NA bases ##
|
||||
if (atom_names[0] == "P"):
|
||||
return atom_names
|
||||
##################################################
|
||||
|
||||
if res_name in association_map:
|
||||
idx_to_swap = association_map[res_name]
|
||||
atom_names = atom_names[idx_to_swap]
|
||||
@@ -171,20 +180,38 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
token_ids = np.unique(atom_array.token_id)
|
||||
assert len(token_ids) == len(
|
||||
is_motif_atom_with_fixed_seq
|
||||
), "Token ids and token level array have different lengths!"
|
||||
), "Token ids and token level array have different lengths!"
|
||||
|
||||
# Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
|
||||
is_residue = (
|
||||
token_level_array.is_protein & ~token_level_array.atomize
|
||||
) | is_motif_token_unindexed
|
||||
# Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
|
||||
if self.association_scheme == 'atom23':
|
||||
is_residue = (
|
||||
token_level_array.is_protein & ~token_level_array.atomize
|
||||
) | is_motif_token_unindexed
|
||||
|
||||
is_residue_NA = (
|
||||
(token_level_array.is_dna | token_level_array.is_rna) & ~token_level_array.atomize
|
||||
) | is_motif_token_unindexed
|
||||
|
||||
# Unindexed tokens are never padded, and so are treated as residues with fixed sequence.
|
||||
is_paddable = is_residue & ~(
|
||||
is_motif_atom_with_fixed_seq | is_motif_token_unindexed
|
||||
)
|
||||
is_non_paddable_residue = is_residue & (
|
||||
is_motif_atom_with_fixed_seq | is_motif_token_unindexed
|
||||
)
|
||||
# Unindexed tokens are never padded, and so are treated as residues with fixed sequence.
|
||||
is_paddable = (is_residue_NA | is_residue) & ~(
|
||||
is_motif_atom_with_fixed_seq | is_motif_token_unindexed
|
||||
)
|
||||
is_non_paddable_residue = (is_residue_NA | is_residue) & (
|
||||
is_motif_atom_with_fixed_seq | is_motif_token_unindexed
|
||||
)
|
||||
else:
|
||||
is_residue = (
|
||||
token_level_array.is_protein & ~token_level_array.atomize
|
||||
) | is_motif_token_unindexed
|
||||
|
||||
# Unindexed tokens are never padded, and so are treated as residues with fixed sequence.
|
||||
is_paddable = is_residue & ~(
|
||||
is_motif_atom_with_fixed_seq | is_motif_token_unindexed
|
||||
)
|
||||
is_non_paddable_residue = is_residue & (
|
||||
is_motif_atom_with_fixed_seq | is_motif_token_unindexed
|
||||
)
|
||||
|
||||
|
||||
# Collect virtual atoms to insert (we will insert them all at once)
|
||||
virtual_atoms_to_insert = []
|
||||
@@ -194,8 +221,16 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
for token_id, (start, end) in enumerate(zip(starts[:-1], starts[1:])):
|
||||
if is_paddable[token_id]:
|
||||
token = atom_array[start:end]
|
||||
|
||||
# First, pad with virtual atoms if needed
|
||||
n_pad = self.n_atoms_per_token - len(token)
|
||||
if self.association_scheme == "atom23" and atom_array[start].is_dna:
|
||||
n_atoms_per_token = 22
|
||||
elif self.association_scheme == "atom23" and atom_array[start].is_rna:
|
||||
n_atoms_per_token = 23
|
||||
else:
|
||||
n_atoms_per_token = self.n_atoms_per_token
|
||||
n_pad = n_atoms_per_token - len(token)
|
||||
|
||||
if n_pad > 0:
|
||||
mask = get_af3_token_representative_masks(
|
||||
token, central_atom=self.atom_to_pad_from
|
||||
@@ -262,18 +297,26 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
|
||||
for token_id, (start, end) in enumerate(
|
||||
zip(starts_padded[:-1], starts_padded[1:])
|
||||
):
|
||||
):
|
||||
if (atom_array_padded[start].is_dna):
|
||||
ATOM_NAMES = ATOM23_ATOM_NAMES_DNA
|
||||
elif (atom_array_padded[start].is_rna):
|
||||
ATOM_NAMES = ATOM23_ATOM_NAMES_RNA
|
||||
else:
|
||||
ATOM_NAMES = ATOM14_ATOM_NAMES
|
||||
|
||||
if is_paddable[token_id]:
|
||||
# ... Permutation of atom names during training
|
||||
if not data["is_inference"] and exists(self.association_scheme):
|
||||
atom_names = permute_symmetric_atom_names_(
|
||||
ATOM14_ATOM_NAMES,
|
||||
ATOM_NAMES,
|
||||
atom_array_padded.res_name[start],
|
||||
association_map=self.association_map_,
|
||||
symmetry_map=self.symmetry_map_,
|
||||
)
|
||||
else:
|
||||
atom_names = ATOM14_ATOM_NAMES
|
||||
atom_names = ATOM_NAMES
|
||||
|
||||
atom_array_padded.atom_name[start:end] = atom_names
|
||||
atom_array_padded.get_annotation("gt_atom_name")[start:end] = atom_names
|
||||
|
||||
@@ -285,7 +328,7 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
)
|
||||
atom_array_padded.get_annotation("gt_atom_name")[start:end] = atom_names
|
||||
atom_names = map_to_association_scheme(
|
||||
atom_names, res_name, scheme=self.association_scheme
|
||||
atom_names, res_name, scheme=self.association_scheme, ATOM_NAMES=ATOM_NAMES
|
||||
)
|
||||
atom_array_padded.atom_name[start:end] = atom_names
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user