unindex fix final for real

This commit is contained in:
Raktim Mitra
2026-03-19 12:56:45 -07:00
committed by Raktim Mitra
parent 42680bdf1e
commit 5a73df5c2f
3 changed files with 10 additions and 8 deletions

View File

@@ -470,8 +470,11 @@ def process_unindexed_outputs(
)
else:
dist = float(dists[row_ind[join_atom], col_ind[join_atom]])
try:
metadata["join_point_rmsd_by_token"][token_pdb_id] = dist
except:
pass
metadata["diffused_index_map"][token_pdb_id] = f"{chain_id}{res_id}"

View File

@@ -434,7 +434,6 @@ class UnindexFlaggedTokens(Transform):
f"Failed to create uniquely recognised tokens after concatenation.\n"
f"Concatenated tokens: {get_token_count(atom_array_full)}, unindexed: {n_unindexed_tokens}"
)
return atom_array_full
def create_unindexed_masks(

View File

@@ -19,6 +19,8 @@ from rfd3.transforms.conditioning_utils import (
sample_subgraph_atoms,
)
from rfd3.constants import backbone_atoms_RNA
nx.from_numpy_matrix = nx.from_numpy_array
logger = logging.getLogger(__name__)
@@ -158,7 +160,9 @@ class IslandCondition(TrainingCondition):
is_motif_atom = np.asarray(atom_array.is_motif_token, dtype=bool).copy()
if random_condition(self.p_diffuse_motif_sidechains):
backbone_atoms = ["N", "C", "CA"]
backbone_atoms = backbone_atoms_RNA.copy()
backbone_atoms.remove("C1'")
backbone_atoms = ["N", "C", "CA"] + backbone_atoms #covers DNA also
if random_condition(self.p_include_oxygen_in_backbone_mask):
backbone_atoms.append("O")
is_motif_atom = is_motif_atom & np.isin(
@@ -173,7 +177,6 @@ class IslandCondition(TrainingCondition):
# We also only want resolved atoms to be motif
is_motif_atom = (is_motif_atom) & (atom_array.occupancy > 0.0)
return is_motif_atom
def sample(self, data):
@@ -202,7 +205,6 @@ class IslandCondition(TrainingCondition):
leak_global_index=data["conditions"]["unindex_leak_global_index"],
),
)
return atom_array
@@ -498,10 +500,9 @@ def sample_conditioning_strategy(
atom_array.set_annotation(
"is_motif_atom_unindexed",
sample_unindexed_atoms(
atom_array, p_unindex_motif_tokens=p_unindex_motif_tokens
atom_array, p_unindex_motif_tokens=p_unindex_motif_tokens, association_scheme=association_scheme
),
)
return atom_array
@@ -557,7 +558,6 @@ def sample_unindexed_atoms(
is_motif_atom_unindexed = atom_array.is_motif_atom.copy()
else:
is_motif_atom_unindexed = np.zeros(atom_array.array_length(), dtype=bool)
# ensure non-residue atoms are not already flagged
if association_scheme == "atom23":
is_motif_atom_unindexed = np.logical_and(