mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
feat: atom23 porting
This commit is contained in:
committed by
Raktim Mitra
parent
b5beff039e
commit
129efbf590
@@ -401,7 +401,6 @@ association_schemes["atom23"]["X"] = (
|
||||
ATOM23_ATOM_NAMES_RNA = np.array(
|
||||
[item.strip() for item in backbone_atomscheme_RNA]
|
||||
+ [f"V{i}" for i in range(23 - len(backbone_atomscheme_RNA))]
|
||||
)
|
||||
"""Atom23 atom names (e.g. CA, V1)"""
|
||||
|
||||
ATOM23_ATOM_ELEMENTS_RNA = np.array(
|
||||
|
||||
@@ -287,6 +287,8 @@ class SampleConditioningType(Transform):
|
||||
i_cond = np.random.choice(np.arange(len(p_cond)), p=p_cond)
|
||||
cond = valid_conditions[i_cond]
|
||||
|
||||
cond.association_scheme = self.association_scheme
|
||||
|
||||
data["sampled_condition"] = cond
|
||||
data["sampled_condition_name"] = cond.name
|
||||
data["sampled_condition_cls"] = cond.__class__
|
||||
@@ -304,6 +306,8 @@ class SampleConditioningFlags(Transform):
|
||||
"AssignTypes",
|
||||
"SampleConditioningType",
|
||||
] # We use is_protein in the PPI training condition
|
||||
def __init__(self, association_scheme):
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
def __init__(self, association_scheme):
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
@@ -757,6 +757,11 @@ class AddAdditional1dFeaturesToFeats(Transform):
|
||||
"""
|
||||
if "feats" not in data.keys():
|
||||
data["feats"] = {}
|
||||
|
||||
if association_scheme == 'atom23':
|
||||
data['atom_array'].set_annotation('is_protein_token', data['atom_array'].is_protein)
|
||||
data['atom_array'].set_annotation('is_dna_token', data['atom_array'].is_dna)
|
||||
data['atom_array'].set_annotation('is_rna_token', data['atom_array'].is_rna)
|
||||
|
||||
if self.association_scheme == "atom23":
|
||||
data["atom_array"].set_annotation(
|
||||
|
||||
@@ -72,9 +72,11 @@ class IslandCondition(TrainingCondition):
|
||||
p_fix_motif_coordinates,
|
||||
p_fix_motif_sequence,
|
||||
p_unindex_motif_tokens,
|
||||
association_scheme = 'atom14',
|
||||
):
|
||||
self.name = name
|
||||
self.frequency = frequency
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
# Token selection
|
||||
self.island_sampling_kwargs = island_sampling_kwargs
|
||||
@@ -89,6 +91,8 @@ class IslandCondition(TrainingCondition):
|
||||
self.p_fix_motif_coordinates = p_fix_motif_coordinates
|
||||
self.p_fix_motif_sequence = p_fix_motif_sequence
|
||||
self.p_unindex_motif_tokens = p_unindex_motif_tokens
|
||||
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
def is_valid_for_example(self, data) -> bool:
|
||||
is_protein = data["atom_array"].is_protein
|
||||
@@ -521,6 +525,7 @@ def sample_is_motif_atom_with_fixed_seq(
|
||||
is_motif_atom_with_fixed_seq = (
|
||||
is_motif_atom_with_fixed_seq | ~atom_array.is_protein
|
||||
)
|
||||
|
||||
|
||||
return is_motif_atom_with_fixed_seq
|
||||
|
||||
|
||||
@@ -53,7 +53,6 @@ def map_to_association_scheme(
|
||||
else:
|
||||
return ATOM_NAMES[idxs]
|
||||
|
||||
|
||||
def map_names_to_elements(
|
||||
atom_names: list | str, default=VIRTUAL_ATOM_ELEMENT_NAME
|
||||
) -> np.ndarray:
|
||||
@@ -180,7 +179,7 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
token_ids = np.unique(atom_array.token_id)
|
||||
assert len(token_ids) == len(
|
||||
is_motif_atom_with_fixed_seq
|
||||
), "Token ids and token level array have different lengths!"
|
||||
), "Token ids and token level array have different lengths!"
|
||||
|
||||
# Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
|
||||
if self.association_scheme == "atom23":
|
||||
@@ -212,7 +211,6 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
is_non_paddable_residue = is_residue & (
|
||||
is_motif_atom_with_fixed_seq | is_motif_token_unindexed
|
||||
)
|
||||
|
||||
# Collect virtual atoms to insert (we will insert them all at once)
|
||||
virtual_atoms_to_insert = []
|
||||
insert_positions = []
|
||||
|
||||
Reference in New Issue
Block a user