run time fixes; inference fix for readout seq from struct

This commit is contained in:
Raktim Mitra
2026-02-18 12:59:23 -08:00
parent de06df8cbf
commit bcaa6b79a8
4 changed files with 116 additions and 93 deletions

View File

@@ -25,6 +25,7 @@ from scipy.optimize import linear_sum_assignment
from foundry.common import exists
from foundry.utils.ddp import RankedLogger
from atomworks.constants import STANDARD_DNA, STANDARD_RNA
global_logger = RankedLogger(__name__, rank_zero_only=False)
@@ -221,11 +222,15 @@ def _readout_seq_from_struc(
# There might be a better way to do this.
CA_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CA"]
CB_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CB"]
if np.linalg.norm(CA_coord - CB_coord) < threshold:
if cur_res_atom_array.is_dna[0] or cur_res_atom_array.is_rna[0]:
cur_central_atom = "C1'"
elif np.linalg.norm(CA_coord - CB_coord) < threshold:
cur_central_atom = "CA"
else:
cur_central_atom = central_atom
central_mask = cur_res_atom_array.atom_name == cur_central_atom
# ... Calculate the distance to the central atom
@@ -258,8 +263,12 @@ def _readout_seq_from_struc(
ATOM_NAMES = ATOM14_ATOM_NAMES
if restype in STANDARD_DNA:
ATOM_NAMES = ATOM23_ATOM_NAMES_DNA
if not cur_res_atom_array.is_dna[0]:
continue
if restype in STANDARD_RNA:
ATOM_NAMES = ATOM23_ATOM_NAMES_RNA
if not cur_res_atom_array.is_rna[0]:
continue
atom_name_idx_in_atom14_scheme = np.array(
[
@@ -269,7 +278,6 @@ def _readout_seq_from_struc(
) # five backbone atoms + some virtual atoms, returning e.g. [0, 1, 2, 3, 4, 11, 7]
atom14_scheme_mask = np.zeros_like(ATOM_NAMES, dtype=bool)
atom14_scheme_mask[atom_name_idx_in_atom14_scheme] = True
# ... Find the matched restype by checking if all the non-None positions and None positions match
# This is designed to keep virtual atoms and doesn't assign the atom names for now, which will be handled later.
if all(x is not None for x in atom_names[atom14_scheme_mask]) and all(

View File

@@ -870,7 +870,6 @@ class AddAdditional2dFeaturesToFeats(Transform):
return data
for feature_name, n_dims in self.token_2d_features.items():
data = self.generate_token_feature(feature_name, n_dims, data)
return data

View File

@@ -248,3 +248,5 @@ def subsample_one_hot_np(array, fraction):
new_array[i, j] = 1
return new_array

View File

@@ -32,7 +32,7 @@ from rfd3.constants import (
ATOM_REGION_BY_RESI,
PLANAR_ATOMS_BY_RESI,
)
import tempfile
# Derived: True when the residue has any planar sidechain atoms
HAS_PLANAR_SC = {res: bool(atoms) for res, atoms in PLANAR_ATOMS_BY_RESI.items()}
@@ -132,110 +132,124 @@ def calculate_hb_counts(
np.ndarray of shape ``(I, I, 3)`` (int32) where the last axis
encodes: 0 = BBBB, 1 = BBSC, 2 = SCSC H-bond counts.
"""
hbplus_exe = os.environ.get("HBPLUS_PATH")
dtstr = datetime.now().strftime("%Y%m%d%H%M%S")
pdb_path = f"{dtstr}_{np.random.randint(10000)}.pdb"
if hbplus_exe is None or hbplus_exe == "":
raise ValueError(
"HBPLUS_PATH environment variable not set. "
"Please set it to the path of the hbplus executable in order to calculate hydrogen bonds."
)
with tempfile.TemporaryDirectory() as tmpdir:
dtstr = datetime.now().strftime("%Y%m%d%H%M%S")
pdb_filename = f"{dtstr}_{np.random.randint(10000)}.pdb"
pdb_path = os.path.join(tmpdir, pdb_filename)
atom_array, nan_mask, chain_map = save_atomarray_to_pdb(atom_array, pdb_path)
atom_array, nan_mask, chain_map = save_atomarray_to_pdb(atom_array, pdb_path)
subprocess.call(
[
"/projects/ml/hbplus",
"-h",
str(cutoff_HA_dist),
"-d",
str(cutoff_DA_dist),
pdb_path,
pdb_path,
],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
subprocess.call(
[
hbplus_exe,
"-h",
str(cutoff_HA_dist),
"-d",
str(cutoff_DA_dist),
pdb_path,
pdb_path,
],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
cwd=tmpdir,
)
num_resis_total = len(token_level_data["token_id_list"])
num_resis_total = len(token_level_data["token_id_list"])
hbond_count = np.zeros((num_resis_total, num_resis_total, 3), dtype=np.int32)
hbond_count = np.zeros((num_resis_total, num_resis_total, 3), dtype=np.int32)
hb2_path = pdb_path.replace("pdb", "hb2")
if not os.path.exists(hb2_path):
print("WARNING: HB2 file could not be found; skipping NA SS metric")
return hbond_count
with open(hb2_path, "r") as hb2_f:
for i, line in enumerate(hb2_f):
if i < 8:
continue
if len(line) < 28:
continue
hb2_path = pdb_path.replace("pdb", "hb2")
with open(hb2_path, "r") as hb2_f:
for i, line in enumerate(hb2_f):
if i < 8:
continue
if len(line) < 28:
continue
d_chain_iid = chain_map[line[0]]
d_resi = int(line[1:5].strip())
d_resn = line[6:9].strip()
d_atom_name = line[9:13].strip()
d_chain_iid = chain_map[line[0]]
d_resi = int(line[1:5].strip())
d_resn = line[6:9].strip()
d_atom_name = line[9:13].strip()
# Initialize donor/acceptor sidechain/backbone flags:
# then replace with True if valid for summation
d_is_sc = False
d_is_bb = False
a_is_sc = False
a_is_bb = False
# Initialize donor/acceptor sidechain/backbone flags:
# then replace with True if valid for summation
d_is_sc = False
d_is_bb = False
a_is_sc = False
a_is_bb = False
d_mask = (
(atom_array.atom_name == d_atom_name)
& (atom_array.res_name == d_resn)
& (atom_array.res_id == d_resi)
& (atom_array.chain_iid == d_chain_iid)
)
# d_atm = atom_array[d_mask]
# d_idx = d_atm.token_id
d_idx = token_level_data["resi2index"].get(f"{d_chain_iid}__{d_resi}", None)
if d_idx is None:
continue
d_mask = (
(atom_array.atom_name == d_atom_name)
& (atom_array.res_name == d_resn)
& (atom_array.res_id == d_resi)
& (atom_array.chain_iid == d_chain_iid)
)
# d_atm = atom_array[d_mask]
# d_idx = d_atm.token_id
d_idx = token_level_data["resi2index"].get(f"{d_chain_iid}__{d_resi}", None)
if d_idx is None:
continue
# Handle standard polymer residues for donor atom:
if d_resn in ATOM_REGION_BY_RESI.keys():
d_is_sc = (d_atom_name in ATOM_REGION_BY_RESI[d_resn]['sc'])
d_is_bb = (d_atom_name in ATOM_REGION_BY_RESI[d_resn]['bb'])
else:
# If non-polymer, define any ligand HBonding atom as backbone:
if d_mask.sum() > 0:
d_is_bb = atom_array[d_mask][0].is_ligand
# Handle standard polymer residues for donor atom:
if d_resn in ATOM_REGION_BY_RESI.keys():
d_is_sc = (d_atom_name in ATOM_REGION_BY_RESI[d_resn]['sc'])
d_is_bb = (d_atom_name in ATOM_REGION_BY_RESI[d_resn]['bb'])
else:
# If non-polymer, define any ligand HBonding atom as backbone:
if d_mask.sum() > 0:
d_is_bb = atom_array[d_mask][0].is_ligand
a_chain_iid = chain_map[line[14]]
a_resi = int(line[15:19].strip())
a_resn = line[20:23].strip()
a_atom_name = line[23:27].strip()
a_chain_iid = chain_map[line[14]]
a_resi = int(line[15:19].strip())
a_resn = line[20:23].strip()
a_atom_name = line[23:27].strip()
a_mask = (
(atom_array.atom_name == a_atom_name)
& (atom_array.res_name == a_resn)
& (atom_array.res_id == a_resi)
& (atom_array.chain_iid == a_chain_iid)
)
a_idx = token_level_data["resi2index"].get(f"{a_chain_iid}__{a_resi}", None)
if a_idx is None:
continue
a_mask = (
(atom_array.atom_name == a_atom_name)
& (atom_array.res_name == a_resn)
& (atom_array.res_id == a_resi)
& (atom_array.chain_iid == a_chain_iid)
)
a_idx = token_level_data["resi2index"].get(f"{a_chain_iid}__{a_resi}", None)
if a_idx is None:
continue
# Handle standard polymer residues for acceptor atom:
if a_resn in ATOM_REGION_BY_RESI.keys():
a_is_sc = (a_atom_name in ATOM_REGION_BY_RESI[a_resn]['sc'])
a_is_bb = (a_atom_name in ATOM_REGION_BY_RESI[a_resn]['bb'])
else:
# If non-polymer, define any ligand HBonding atom as backbone:
if a_mask.sum() > 0:
a_is_bb = atom_array[a_mask][0].is_ligand
# Handle standard polymer residues for acceptor atom:
if a_resn in ATOM_REGION_BY_RESI.keys():
a_is_sc = (a_atom_name in ATOM_REGION_BY_RESI[a_resn]['sc'])
a_is_bb = (a_atom_name in ATOM_REGION_BY_RESI[a_resn]['bb'])
else:
# If non-polymer, define any ligand HBonding atom as backbone:
if a_mask.sum() > 0:
a_is_bb = atom_array[a_mask][0].is_ligand
# 0 -> both backbone (BB-BB)
hbond_count[a_idx, d_idx, 0] += (a_is_bb * d_is_bb)
hbond_count[d_idx, a_idx, 0] += (d_is_bb * a_is_bb)
# 0 -> both backbone (BB-BB)
hbond_count[a_idx, d_idx, 0] += (a_is_bb * d_is_bb)
hbond_count[d_idx, a_idx, 0] += (d_is_bb * a_is_bb)
# 1 -> one backbone, one sidechain (BB-SC)
hbond_count[a_idx, d_idx, 1] += (a_is_bb * d_is_sc) | (a_is_sc * d_is_bb)
hbond_count[d_idx, a_idx, 1] += (d_is_bb * a_is_sc) | (d_is_sc * a_is_bb)
# 2 -> both sidechain (SC-SC)
hbond_count[a_idx, d_idx, 2] += (a_is_sc * d_is_sc)
hbond_count[d_idx, a_idx, 2] += (d_is_sc * a_is_sc)
os.remove(pdb_path)
os.remove(hb2_path)
# 1 -> one backbone, one sidechain (BB-SC)
hbond_count[a_idx, d_idx, 1] += (a_is_bb * d_is_sc) | (a_is_sc * d_is_bb)
hbond_count[d_idx, a_idx, 1] += (d_is_bb * a_is_sc) | (d_is_sc * a_is_bb)
# 2 -> both sidechain (SC-SC)
hbond_count[a_idx, d_idx, 2] += (a_is_sc * d_is_sc)
hbond_count[d_idx, a_idx, 2] += (d_is_sc * a_is_sc)
'''
try:
os.remove(pdb_path)
os.remove(hb2_path)
except:
print("temp pdb/hb already removed or not created to begin with")
'''
return hbond_count