mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
inference fixes
This commit is contained in:
@@ -12,9 +12,10 @@ from rfd3.metrics.metrics_utils import (
|
||||
|
||||
from foundry.common import exists
|
||||
from foundry.metrics.metric import Metric
|
||||
from rfd3.constants import backbone_atoms_RNA
|
||||
|
||||
STANDARD_CACA_DIST = 3.8
|
||||
|
||||
STANDARD_P_P_DISTANCE = 6.4 ## average of B and A form 7 and 5.9
|
||||
|
||||
def get_clash_metrics(
|
||||
atom_array,
|
||||
@@ -28,7 +29,13 @@ def get_clash_metrics(
|
||||
)
|
||||
|
||||
def get_chainbreaks():
|
||||
ca_atoms = atom_array[atom_array.atom_name == "CA"]
|
||||
if "CA" in atom_array.atom_name:
|
||||
ca_atoms = atom_array[atom_array.atom_name == "CA"]
|
||||
cut_off = STANDARD_CACA_DIST
|
||||
elif "P" in atom_array.atom_name:
|
||||
ca_atoms = atom_array[atom_array.atom_name == "P"]
|
||||
cut_off = STANDARD_P_P_DISTANCE
|
||||
|
||||
xyz = ca_atoms.coord
|
||||
xyz = torch.from_numpy(xyz)
|
||||
ca_dists = torch.norm(xyz[1:] - xyz[:-1], dim=-1)
|
||||
@@ -45,7 +52,7 @@ def get_clash_metrics(
|
||||
}
|
||||
|
||||
def get_interresidue_clashes(backbone_only=False):
|
||||
protein_array = atom_array[atom_array.is_protein]
|
||||
protein_array = atom_array[atom_array.is_protein | atom_array.is_dna | atom_array.is_rna]
|
||||
resid = protein_array.res_id - protein_array.res_id.min()
|
||||
xyz = protein_array.coord
|
||||
dists = np.linalg.norm(xyz[:, None] - xyz[None], axis=-1) # N_atoms x N_atoms
|
||||
@@ -58,7 +65,7 @@ def get_clash_metrics(
|
||||
|
||||
if backbone_only:
|
||||
# Block out non-backbone atoms
|
||||
backbone_mask = np.isin(protein_array.atom_name, ["N", "CA", "C"])
|
||||
backbone_mask = np.isin(protein_array.atom_name, ["N", "CA", "C"] + backbone_atoms_RNA)
|
||||
mask = backbone_mask[:, None] & backbone_mask[None, :]
|
||||
dists[~mask] = 999
|
||||
|
||||
|
||||
@@ -823,9 +823,11 @@ class AddAdditional2dFeaturesToFeats(Transform):
|
||||
# Don't do this if we already have the feature
|
||||
if feature_name in data["feats"].keys():
|
||||
return data
|
||||
|
||||
|
||||
# For these, we need to use a constructor function mapping,
|
||||
# since pair features may require custom logic/conventions.
|
||||
|
||||
## for old ckpt handling ##
|
||||
if feature_name in self.constructor_functions.keys():
|
||||
feature_array = self.constructor_functions[feature_name](data["atom_array"])
|
||||
else:
|
||||
|
||||
@@ -196,7 +196,7 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
token_level_array = atom_array[token_starts]
|
||||
token_ids = [int(t) for t in token_level_array.token_id]
|
||||
n_tokens = len(token_starts)
|
||||
print(" DO I NEED TO CHANGE TO TOKEN_ID???")
|
||||
#TODO print(" DO I NEED TO CHANGE TO TOKEN_ID???")
|
||||
# Handle the training case with ground truth and masking:
|
||||
if not self.is_inference and (np.random.rand() < self.sampling_prob):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user