feat: atom23 porting

This commit is contained in:
Raktim Mitra
2026-01-18 15:41:00 -08:00
committed by Raktim Mitra
parent b5beff039e
commit 129efbf590
5 changed files with 15 additions and 4 deletions

View File

@@ -401,7 +401,6 @@ association_schemes["atom23"]["X"] = (
ATOM23_ATOM_NAMES_RNA = np.array(
[item.strip() for item in backbone_atomscheme_RNA]
+ [f"V{i}" for i in range(23 - len(backbone_atomscheme_RNA))]
)
"""Atom23 atom names (e.g. CA, V1)"""
ATOM23_ATOM_ELEMENTS_RNA = np.array(

View File

@@ -287,6 +287,8 @@ class SampleConditioningType(Transform):
i_cond = np.random.choice(np.arange(len(p_cond)), p=p_cond)
cond = valid_conditions[i_cond]
cond.association_scheme = self.association_scheme
data["sampled_condition"] = cond
data["sampled_condition_name"] = cond.name
data["sampled_condition_cls"] = cond.__class__
@@ -304,6 +306,8 @@ class SampleConditioningFlags(Transform):
"AssignTypes",
"SampleConditioningType",
] # We use is_protein in the PPI training condition
def __init__(self, association_scheme):
self.association_scheme = association_scheme
def __init__(self, association_scheme):
self.association_scheme = association_scheme

View File

@@ -757,6 +757,11 @@ class AddAdditional1dFeaturesToFeats(Transform):
"""
if "feats" not in data.keys():
data["feats"] = {}
if association_scheme == 'atom23':
data['atom_array'].set_annotation('is_protein_token', data['atom_array'].is_protein)
data['atom_array'].set_annotation('is_dna_token', data['atom_array'].is_dna)
data['atom_array'].set_annotation('is_rna_token', data['atom_array'].is_rna)
if self.association_scheme == "atom23":
data["atom_array"].set_annotation(

View File

@@ -72,9 +72,11 @@ class IslandCondition(TrainingCondition):
p_fix_motif_coordinates,
p_fix_motif_sequence,
p_unindex_motif_tokens,
association_scheme = 'atom14',
):
self.name = name
self.frequency = frequency
self.association_scheme = association_scheme
# Token selection
self.island_sampling_kwargs = island_sampling_kwargs
@@ -89,6 +91,8 @@ class IslandCondition(TrainingCondition):
self.p_fix_motif_coordinates = p_fix_motif_coordinates
self.p_fix_motif_sequence = p_fix_motif_sequence
self.p_unindex_motif_tokens = p_unindex_motif_tokens
self.association_scheme = association_scheme
def is_valid_for_example(self, data) -> bool:
is_protein = data["atom_array"].is_protein
@@ -521,6 +525,7 @@ def sample_is_motif_atom_with_fixed_seq(
is_motif_atom_with_fixed_seq = (
is_motif_atom_with_fixed_seq | ~atom_array.is_protein
)
return is_motif_atom_with_fixed_seq

View File

@@ -53,7 +53,6 @@ def map_to_association_scheme(
else:
return ATOM_NAMES[idxs]
def map_names_to_elements(
atom_names: list | str, default=VIRTUAL_ATOM_ELEMENT_NAME
) -> np.ndarray:
@@ -180,7 +179,7 @@ class PadTokensWithVirtualAtoms(Transform):
token_ids = np.unique(atom_array.token_id)
assert len(token_ids) == len(
is_motif_atom_with_fixed_seq
), "Token ids and token level array have different lengths!"
), "Token ids and token level array have different lengths!"
# Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
if self.association_scheme == "atom23":
@@ -212,7 +211,6 @@ class PadTokensWithVirtualAtoms(Transform):
is_non_paddable_residue = is_residue & (
is_motif_atom_with_fixed_seq | is_motif_token_unindexed
)
# Collect virtual atoms to insert (we will insert them all at once)
virtual_atoms_to_insert = []
insert_positions = []