mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
feat: atom23 inference changes and training fixes
This commit is contained in:
committed by
Raktim Mitra
parent
94d9d635cd
commit
c2ec302745
@@ -349,9 +349,13 @@ class SubunitSymmetryResolution(nn.Module):
|
||||
x_native = symm_input["coord_atom_lvl"].to(x_pred.device)
|
||||
mask_native = symm_input["mask_atom_lvl"].to(x_pred.device)
|
||||
|
||||
x_native_aln, x_native_mask = self._resolve_subunits(
|
||||
mol_entities, mol_iid, crop_mask, x_native, mask_native, x_pred
|
||||
)
|
||||
try:
|
||||
x_native_aln, x_native_mask = self._resolve_subunits(
|
||||
mol_entities, mol_iid, crop_mask, x_native, mask_native, x_pred
|
||||
)
|
||||
except Exception:
|
||||
# fd ... TO DO: DEBUG!
|
||||
return loss_input
|
||||
|
||||
loss_input["X_gt_L"] = x_native_aln
|
||||
loss_input["crd_mask_L"] = x_native_mask
|
||||
|
||||
@@ -242,6 +242,35 @@ SELECTION_NONPROTEIN = [
|
||||
"POLYDEOXYRIBONUCLEOTIDE/POLYRIBONUCLEOTIDE HYBRID",
|
||||
]
|
||||
|
||||
backbone_atomscheme_DNA = [
|
||||
" P ",
|
||||
" OP1",
|
||||
" OP2",
|
||||
" O5'",
|
||||
" C5'",
|
||||
" C4'",
|
||||
" O4'",
|
||||
" C3'",
|
||||
" O3'",
|
||||
" C2'",
|
||||
" C1'",
|
||||
] # , None]
|
||||
|
||||
backbone_atomscheme_RNA = [
|
||||
" P ",
|
||||
" OP1",
|
||||
" OP2",
|
||||
" O5'",
|
||||
" C5'",
|
||||
" C4'",
|
||||
" O4'",
|
||||
" C3'",
|
||||
" O3'",
|
||||
" C2'",
|
||||
" O2'",
|
||||
" C1'",
|
||||
]
|
||||
|
||||
DNA_atoms = {
|
||||
"DA": [
|
||||
" N9 ",
|
||||
@@ -410,7 +439,6 @@ association_schemes_stripped = {
|
||||
backbone_atoms_RNA = strip_list(backbone_atomscheme_RNA)
|
||||
backbone_atoms_DNA = strip_list(backbone_atomscheme_DNA)
|
||||
|
||||
|
||||
# Mapping from residue type to its backbone and sidechain atoms (for convenience)
|
||||
ATOM_REGION_BY_RESI = {
|
||||
'ALA': {'bb':('N','CA','C','O'),
|
||||
|
||||
@@ -186,6 +186,7 @@ def fetch_motif_residue_(
|
||||
subarray.set_annotation(
|
||||
"is_motif_atom_with_fixed_seq", np.zeros(subarray.shape[0], dtype=int)
|
||||
)
|
||||
|
||||
elif redesign_motif_sidechains and res_name in (STANDARD_DNA + STANDARD_RNA):
|
||||
is_backbone = np.isin(subarray.atom_name, backbone_atoms_RNA)
|
||||
subarray.set_annotation("is_motif_atom", is_backbone)
|
||||
|
||||
@@ -281,7 +281,7 @@ class SampleConditioningType(Transform):
|
||||
cond = valid_conditions[i_cond]
|
||||
|
||||
cond.association_scheme = self.association_scheme
|
||||
|
||||
|
||||
data["sampled_condition"] = cond
|
||||
data["sampled_condition_name"] = cond.name
|
||||
data["sampled_condition_cls"] = cond.__class__
|
||||
@@ -299,6 +299,7 @@ class SampleConditioningFlags(Transform):
|
||||
"AssignTypes",
|
||||
"SampleConditioningType",
|
||||
] # We use is_protein in the PPI training condition
|
||||
|
||||
def __init__(self, association_scheme):
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
|
||||
@@ -698,7 +698,6 @@ class AddAdditional1dFeaturesToFeats(Transform):
|
||||
atom_1d_features,
|
||||
autofill_zeros_if_not_present_in_atomarray=False,
|
||||
association_scheme="atom14",
|
||||
association_scheme='atom14'
|
||||
):
|
||||
self.autofill = autofill_zeros_if_not_present_in_atomarray
|
||||
self.token_1d_features = token_1d_features
|
||||
@@ -755,11 +754,13 @@ class AddAdditional1dFeaturesToFeats(Transform):
|
||||
"""
|
||||
if "feats" not in data.keys():
|
||||
data["feats"] = {}
|
||||
|
||||
if association_scheme == 'atom23':
|
||||
data['atom_array'].set_annotation('is_protein_token', data['atom_array'].is_protein)
|
||||
data['atom_array'].set_annotation('is_dna_token', data['atom_array'].is_dna)
|
||||
data['atom_array'].set_annotation('is_rna_token', data['atom_array'].is_rna)
|
||||
|
||||
if self.association_scheme == "atom23":
|
||||
data["atom_array"].set_annotation(
|
||||
"is_protein_token", data["atom_array"].is_protein
|
||||
)
|
||||
data["atom_array"].set_annotation("is_dna_token", data["atom_array"].is_dna)
|
||||
data["atom_array"].set_annotation("is_rna_token", data["atom_array"].is_rna)
|
||||
|
||||
if self.association_scheme == "atom23":
|
||||
data["atom_array"].set_annotation(
|
||||
|
||||
@@ -523,7 +523,7 @@ def build_atom14_base_pipeline_(
|
||||
autofill_zeros_if_not_present_in_atomarray=True,
|
||||
token_1d_features=token_1d_features,
|
||||
atom_1d_features=atom_1d_features,
|
||||
association_scheme=association_scheme
|
||||
association_scheme=association_scheme,
|
||||
),
|
||||
AddAF3TokenBondFeatures(),
|
||||
AddGroundTruthSequence(sequence_encoding=af3_sequence_encoding),
|
||||
|
||||
@@ -72,11 +72,9 @@ class IslandCondition(TrainingCondition):
|
||||
p_fix_motif_coordinates,
|
||||
p_fix_motif_sequence,
|
||||
p_unindex_motif_tokens,
|
||||
association_scheme = 'atom14',
|
||||
):
|
||||
self.name = name
|
||||
self.frequency = frequency
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
# Token selection
|
||||
self.island_sampling_kwargs = island_sampling_kwargs
|
||||
@@ -91,8 +89,6 @@ class IslandCondition(TrainingCondition):
|
||||
self.p_fix_motif_coordinates = p_fix_motif_coordinates
|
||||
self.p_fix_motif_sequence = p_fix_motif_sequence
|
||||
self.p_unindex_motif_tokens = p_unindex_motif_tokens
|
||||
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
def is_valid_for_example(self, data) -> bool:
|
||||
is_protein = data["atom_array"].is_protein
|
||||
@@ -138,7 +134,6 @@ class IslandCondition(TrainingCondition):
|
||||
n_protein_tokens,
|
||||
**self.island_sampling_kwargs,
|
||||
)
|
||||
|
||||
is_motif_token[token_level_array.is_protein] = islands_mask
|
||||
|
||||
# TODO: Atoms with covalent bonds should be motif, needs FlagAndReassignCovalentModifications transform prior to this
|
||||
@@ -310,7 +305,7 @@ class SubtypeCondition(TrainingCondition):
|
||||
"""
|
||||
|
||||
name = "subtype"
|
||||
association_scheme = 'atom14'
|
||||
association_scheme = "atom14"
|
||||
|
||||
def __init__(self, frequency: float, subtype: list[str], fix_pos: bool = False):
|
||||
self.frequency = frequency
|
||||
@@ -526,8 +521,6 @@ def sample_is_motif_atom_with_fixed_seq(
|
||||
is_motif_atom_with_fixed_seq = (
|
||||
is_motif_atom_with_fixed_seq | ~atom_array.is_protein
|
||||
)
|
||||
|
||||
|
||||
|
||||
return is_motif_atom_with_fixed_seq
|
||||
|
||||
@@ -571,7 +564,6 @@ def sample_unindexed_atoms(
|
||||
is_motif_atom_unindexed, atom_array.is_residue
|
||||
)
|
||||
|
||||
|
||||
return is_motif_atom_unindexed
|
||||
|
||||
|
||||
|
||||
@@ -53,6 +53,7 @@ def map_to_association_scheme(
|
||||
else:
|
||||
return ATOM_NAMES[idxs]
|
||||
|
||||
|
||||
def map_names_to_elements(
|
||||
atom_names: list | str, default=VIRTUAL_ATOM_ELEMENT_NAME
|
||||
) -> np.ndarray:
|
||||
@@ -179,7 +180,7 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
token_ids = np.unique(atom_array.token_id)
|
||||
assert len(token_ids) == len(
|
||||
is_motif_atom_with_fixed_seq
|
||||
), "Token ids and token level array have different lengths!"
|
||||
), "Token ids and token level array have different lengths!"
|
||||
|
||||
# Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
|
||||
if self.association_scheme == "atom23":
|
||||
@@ -220,7 +221,7 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
for token_id, (start, end) in enumerate(zip(starts[:-1], starts[1:])):
|
||||
if is_paddable[token_id]:
|
||||
token = atom_array[start:end]
|
||||
|
||||
|
||||
# First, pad with virtual atoms if needed
|
||||
if self.association_scheme == "atom23" and atom_array[start].is_dna:
|
||||
n_atoms_per_token = 22
|
||||
@@ -229,7 +230,7 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
else:
|
||||
n_atoms_per_token = self.n_atoms_per_token
|
||||
n_pad = n_atoms_per_token - len(token)
|
||||
|
||||
|
||||
if n_pad > 0:
|
||||
mask = get_af3_token_representative_masks(
|
||||
token, central_atom=self.atom_to_pad_from
|
||||
|
||||
Reference in New Issue
Block a user