inference fixes

This commit is contained in:
Raktim Mitra
2026-02-09 11:37:38 -08:00
parent 335e580390
commit cff03801ed
3 changed files with 15 additions and 6 deletions

View File

@@ -12,9 +12,10 @@ from rfd3.metrics.metrics_utils import (
from foundry.common import exists
from foundry.metrics.metric import Metric
from rfd3.constants import backbone_atoms_RNA
STANDARD_CACA_DIST = 3.8
STANDARD_P_P_DISTANCE = 6.4  # roughly the average of the B-form (7) and A-form (5.9) P-P distances
def get_clash_metrics(
atom_array,
@@ -28,7 +29,13 @@ def get_clash_metrics(
)
def get_chainbreaks():
ca_atoms = atom_array[atom_array.atom_name == "CA"]
if "CA" in atom_array.atom_name:
ca_atoms = atom_array[atom_array.atom_name == "CA"]
cut_off = STANDARD_CACA_DIST
elif "P" in atom_array.atom_name:
ca_atoms = atom_array[atom_array.atom_name == "P"]
cut_off = STANDARD_P_P_DISTANCE
xyz = ca_atoms.coord
xyz = torch.from_numpy(xyz)
ca_dists = torch.norm(xyz[1:] - xyz[:-1], dim=-1)
@@ -45,7 +52,7 @@ def get_clash_metrics(
}
def get_interresidue_clashes(backbone_only=False):
protein_array = atom_array[atom_array.is_protein]
protein_array = atom_array[atom_array.is_protein | atom_array.is_dna | atom_array.is_rna]
resid = protein_array.res_id - protein_array.res_id.min()
xyz = protein_array.coord
dists = np.linalg.norm(xyz[:, None] - xyz[None], axis=-1) # N_atoms x N_atoms
@@ -58,7 +65,7 @@ def get_clash_metrics(
if backbone_only:
# Block out non-backbone atoms
backbone_mask = np.isin(protein_array.atom_name, ["N", "CA", "C"])
backbone_mask = np.isin(protein_array.atom_name, ["N", "CA", "C"] + backbone_atoms_RNA)
mask = backbone_mask[:, None] & backbone_mask[None, :]
dists[~mask] = 999

View File

@@ -823,9 +823,11 @@ class AddAdditional2dFeaturesToFeats(Transform):
# Don't do this if we already have the feature
if feature_name in data["feats"].keys():
return data
# For these, we need to use a constructor function mapping,
# since pair features may require custom logic/conventions.
# Needed for handling old checkpoints.
if feature_name in self.constructor_functions.keys():
feature_array = self.constructor_functions[feature_name](data["atom_array"])
else:

View File

@@ -196,7 +196,7 @@ class CalculateNucleicAcidGeomFeats(Transform):
token_level_array = atom_array[token_starts]
token_ids = [int(t) for t in token_level_array.token_id]
n_tokens = len(token_starts)
print(" DO I NEED TO CHANGE TO TOKEN_ID???")
# TODO: check whether this needs to change to token_id.
# Handle the training case with ground truth and masking:
if not self.is_inference and (np.random.rand() < self.sampling_prob):