mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
run time fixes; inference fix for readout seq from struct
This commit is contained in:
@@ -25,6 +25,7 @@ from scipy.optimize import linear_sum_assignment
|
||||
|
||||
from foundry.common import exists
|
||||
from foundry.utils.ddp import RankedLogger
|
||||
from atomworks.constants import STANDARD_DNA, STANDARD_RNA
|
||||
|
||||
global_logger = RankedLogger(__name__, rank_zero_only=False)
|
||||
|
||||
@@ -221,11 +222,15 @@ def _readout_seq_from_struc(
|
||||
# There might be a better way to do this.
|
||||
CA_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CA"]
|
||||
CB_coord = cur_res_atom_array.coord[cur_res_atom_array.atom_name == "CB"]
|
||||
if np.linalg.norm(CA_coord - CB_coord) < threshold:
|
||||
|
||||
if cur_res_atom_array.is_dna[0] or cur_res_atom_array.is_rna[0]:
|
||||
cur_central_atom = "C1'"
|
||||
elif np.linalg.norm(CA_coord - CB_coord) < threshold:
|
||||
cur_central_atom = "CA"
|
||||
else:
|
||||
cur_central_atom = central_atom
|
||||
|
||||
|
||||
central_mask = cur_res_atom_array.atom_name == cur_central_atom
|
||||
|
||||
# ... Calculate the distance to the central atom
|
||||
@@ -258,8 +263,12 @@ def _readout_seq_from_struc(
|
||||
ATOM_NAMES = ATOM14_ATOM_NAMES
|
||||
if restype in STANDARD_DNA:
|
||||
ATOM_NAMES = ATOM23_ATOM_NAMES_DNA
|
||||
if not cur_res_atom_array.is_dna[0]:
|
||||
continue
|
||||
if restype in STANDARD_RNA:
|
||||
ATOM_NAMES = ATOM23_ATOM_NAMES_RNA
|
||||
if not cur_res_atom_array.is_rna[0]:
|
||||
continue
|
||||
|
||||
atom_name_idx_in_atom14_scheme = np.array(
|
||||
[
|
||||
@@ -269,7 +278,6 @@ def _readout_seq_from_struc(
|
||||
) # five backbone atoms + some virtual atoms, returning e.g. [0, 1, 2, 3, 4, 11, 7]
|
||||
atom14_scheme_mask = np.zeros_like(ATOM_NAMES, dtype=bool)
|
||||
atom14_scheme_mask[atom_name_idx_in_atom14_scheme] = True
|
||||
|
||||
# ... Find the matched restype by checking if all the non-None positions and None positions match
|
||||
# This is designed to keep virtual atoms and doesn't assign the atom names for now, which will be handled later.
|
||||
if all(x is not None for x in atom_names[atom14_scheme_mask]) and all(
|
||||
|
||||
@@ -870,7 +870,6 @@ class AddAdditional2dFeaturesToFeats(Transform):
|
||||
return data
|
||||
for feature_name, n_dims in self.token_2d_features.items():
|
||||
data = self.generate_token_feature(feature_name, n_dims, data)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
|
||||
@@ -248,3 +248,5 @@ def subsample_one_hot_np(array, fraction):
|
||||
new_array[i, j] = 1
|
||||
|
||||
return new_array
|
||||
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ from rfd3.constants import (
|
||||
ATOM_REGION_BY_RESI,
|
||||
PLANAR_ATOMS_BY_RESI,
|
||||
)
|
||||
|
||||
import tempfile
|
||||
|
||||
# Derived: True when the residue has any planar sidechain atoms
|
||||
HAS_PLANAR_SC = {res: bool(atoms) for res, atoms in PLANAR_ATOMS_BY_RESI.items()}
|
||||
@@ -132,110 +132,124 @@ def calculate_hb_counts(
|
||||
np.ndarray of shape ``(I, I, 3)`` (int32) where the last axis
|
||||
encodes: 0 = BB–BB, 1 = BB–SC, 2 = SC–SC H-bond counts.
|
||||
"""
|
||||
hbplus_exe = os.environ.get("HBPLUS_PATH")
|
||||
|
||||
dtstr = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
pdb_path = f"{dtstr}_{np.random.randint(10000)}.pdb"
|
||||
if hbplus_exe is None or hbplus_exe == "":
|
||||
raise ValueError(
|
||||
"HBPLUS_PATH environment variable not set. "
|
||||
"Please set it to the path of the hbplus executable in order to calculate hydrogen bonds."
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
dtstr = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
pdb_filename = f"{dtstr}_{np.random.randint(10000)}.pdb"
|
||||
pdb_path = os.path.join(tmpdir, pdb_filename)
|
||||
atom_array, nan_mask, chain_map = save_atomarray_to_pdb(atom_array, pdb_path)
|
||||
|
||||
atom_array, nan_mask, chain_map = save_atomarray_to_pdb(atom_array, pdb_path)
|
||||
subprocess.call(
|
||||
[
|
||||
"/projects/ml/hbplus",
|
||||
"-h",
|
||||
str(cutoff_HA_dist),
|
||||
"-d",
|
||||
str(cutoff_DA_dist),
|
||||
pdb_path,
|
||||
pdb_path,
|
||||
],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
subprocess.call(
|
||||
[
|
||||
hbplus_exe,
|
||||
"-h",
|
||||
str(cutoff_HA_dist),
|
||||
"-d",
|
||||
str(cutoff_DA_dist),
|
||||
pdb_path,
|
||||
pdb_path,
|
||||
],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
cwd=tmpdir,
|
||||
)
|
||||
|
||||
num_resis_total = len(token_level_data["token_id_list"])
|
||||
|
||||
num_resis_total = len(token_level_data["token_id_list"])
|
||||
hbond_count = np.zeros((num_resis_total, num_resis_total, 3), dtype=np.int32)
|
||||
|
||||
hbond_count = np.zeros((num_resis_total, num_resis_total, 3), dtype=np.int32)
|
||||
hb2_path = pdb_path.replace("pdb", "hb2")
|
||||
if not os.path.exists(hb2_path):
|
||||
print("WARNING: HB2 file could not be found; skipping NA SS metric")
|
||||
return hbond_count
|
||||
with open(hb2_path, "r") as hb2_f:
|
||||
for i, line in enumerate(hb2_f):
|
||||
if i < 8:
|
||||
continue
|
||||
if len(line) < 28:
|
||||
continue
|
||||
|
||||
hb2_path = pdb_path.replace("pdb", "hb2")
|
||||
with open(hb2_path, "r") as hb2_f:
|
||||
for i, line in enumerate(hb2_f):
|
||||
if i < 8:
|
||||
continue
|
||||
if len(line) < 28:
|
||||
continue
|
||||
d_chain_iid = chain_map[line[0]]
|
||||
d_resi = int(line[1:5].strip())
|
||||
d_resn = line[6:9].strip()
|
||||
d_atom_name = line[9:13].strip()
|
||||
|
||||
d_chain_iid = chain_map[line[0]]
|
||||
d_resi = int(line[1:5].strip())
|
||||
d_resn = line[6:9].strip()
|
||||
d_atom_name = line[9:13].strip()
|
||||
# Initialize donor/acceptor sidechain/backbone flags:
|
||||
# then replace with True if valid for summation
|
||||
d_is_sc = False
|
||||
d_is_bb = False
|
||||
a_is_sc = False
|
||||
a_is_bb = False
|
||||
|
||||
# Initialize donor/acceptor sidechain/backbone flags:
|
||||
# then replace with True if valid for summation
|
||||
d_is_sc = False
|
||||
d_is_bb = False
|
||||
a_is_sc = False
|
||||
a_is_bb = False
|
||||
d_mask = (
|
||||
(atom_array.atom_name == d_atom_name)
|
||||
& (atom_array.res_name == d_resn)
|
||||
& (atom_array.res_id == d_resi)
|
||||
& (atom_array.chain_iid == d_chain_iid)
|
||||
)
|
||||
# d_atm = atom_array[d_mask]
|
||||
# d_idx = d_atm.token_id
|
||||
d_idx = token_level_data["resi2index"].get(f"{d_chain_iid}__{d_resi}", None)
|
||||
if d_idx is None:
|
||||
continue
|
||||
|
||||
d_mask = (
|
||||
(atom_array.atom_name == d_atom_name)
|
||||
& (atom_array.res_name == d_resn)
|
||||
& (atom_array.res_id == d_resi)
|
||||
& (atom_array.chain_iid == d_chain_iid)
|
||||
)
|
||||
# d_atm = atom_array[d_mask]
|
||||
# d_idx = d_atm.token_id
|
||||
d_idx = token_level_data["resi2index"].get(f"{d_chain_iid}__{d_resi}", None)
|
||||
if d_idx is None:
|
||||
continue
|
||||
# Handle standard polymer residues for donor atom:
|
||||
if d_resn in ATOM_REGION_BY_RESI.keys():
|
||||
d_is_sc = (d_atom_name in ATOM_REGION_BY_RESI[d_resn]['sc'])
|
||||
d_is_bb = (d_atom_name in ATOM_REGION_BY_RESI[d_resn]['bb'])
|
||||
else:
|
||||
# If non-polymer, define any ligand HBonding atom as backbone:
|
||||
if d_mask.sum() > 0:
|
||||
d_is_bb = atom_array[d_mask][0].is_ligand
|
||||
|
||||
# Handle standard polymer residues for donor atom:
|
||||
if d_resn in ATOM_REGION_BY_RESI.keys():
|
||||
d_is_sc = (d_atom_name in ATOM_REGION_BY_RESI[d_resn]['sc'])
|
||||
d_is_bb = (d_atom_name in ATOM_REGION_BY_RESI[d_resn]['bb'])
|
||||
else:
|
||||
# If non-polymer, define any ligand HBonding atom as backbone:
|
||||
if d_mask.sum() > 0:
|
||||
d_is_bb = atom_array[d_mask][0].is_ligand
|
||||
a_chain_iid = chain_map[line[14]]
|
||||
a_resi = int(line[15:19].strip())
|
||||
a_resn = line[20:23].strip()
|
||||
a_atom_name = line[23:27].strip()
|
||||
|
||||
a_chain_iid = chain_map[line[14]]
|
||||
a_resi = int(line[15:19].strip())
|
||||
a_resn = line[20:23].strip()
|
||||
a_atom_name = line[23:27].strip()
|
||||
a_mask = (
|
||||
(atom_array.atom_name == a_atom_name)
|
||||
& (atom_array.res_name == a_resn)
|
||||
& (atom_array.res_id == a_resi)
|
||||
& (atom_array.chain_iid == a_chain_iid)
|
||||
)
|
||||
a_idx = token_level_data["resi2index"].get(f"{a_chain_iid}__{a_resi}", None)
|
||||
if a_idx is None:
|
||||
continue
|
||||
|
||||
a_mask = (
|
||||
(atom_array.atom_name == a_atom_name)
|
||||
& (atom_array.res_name == a_resn)
|
||||
& (atom_array.res_id == a_resi)
|
||||
& (atom_array.chain_iid == a_chain_iid)
|
||||
)
|
||||
a_idx = token_level_data["resi2index"].get(f"{a_chain_iid}__{a_resi}", None)
|
||||
if a_idx is None:
|
||||
continue
|
||||
# Handle standard polymer residues for acceptor atom:
|
||||
if a_resn in ATOM_REGION_BY_RESI.keys():
|
||||
a_is_sc = (a_atom_name in ATOM_REGION_BY_RESI[a_resn]['sc'])
|
||||
a_is_bb = (a_atom_name in ATOM_REGION_BY_RESI[a_resn]['bb'])
|
||||
else:
|
||||
# If non-polymer, define any ligand HBonding atom as backbone:
|
||||
if a_mask.sum() > 0:
|
||||
a_is_bb = atom_array[a_mask][0].is_ligand
|
||||
|
||||
# Handle standard polymer residues for acceptor atom:
|
||||
if a_resn in ATOM_REGION_BY_RESI.keys():
|
||||
a_is_sc = (a_atom_name in ATOM_REGION_BY_RESI[a_resn]['sc'])
|
||||
a_is_bb = (a_atom_name in ATOM_REGION_BY_RESI[a_resn]['bb'])
|
||||
else:
|
||||
# If non-polymer, define any ligand HBonding atom as backbone:
|
||||
if a_mask.sum() > 0:
|
||||
a_is_bb = atom_array[a_mask][0].is_ligand
|
||||
# 0 -> both backbone (BB-BB)
|
||||
hbond_count[a_idx, d_idx, 0] += (a_is_bb * d_is_bb)
|
||||
hbond_count[d_idx, a_idx, 0] += (d_is_bb * a_is_bb)
|
||||
|
||||
# 0 -> both backbone (BB-BB)
|
||||
hbond_count[a_idx, d_idx, 0] += (a_is_bb * d_is_bb)
|
||||
hbond_count[d_idx, a_idx, 0] += (d_is_bb * a_is_bb)
|
||||
|
||||
# 1 -> one backbone, one sidechain (BB-SC)
|
||||
hbond_count[a_idx, d_idx, 1] += (a_is_bb * d_is_sc) | (a_is_sc * d_is_bb)
|
||||
hbond_count[d_idx, a_idx, 1] += (d_is_bb * a_is_sc) | (d_is_sc * a_is_bb)
|
||||
|
||||
# 2 -> both sidechain (SC-SC)
|
||||
hbond_count[a_idx, d_idx, 2] += (a_is_sc * d_is_sc)
|
||||
hbond_count[d_idx, a_idx, 2] += (d_is_sc * a_is_sc)
|
||||
|
||||
os.remove(pdb_path)
|
||||
os.remove(hb2_path)
|
||||
# 1 -> one backbone, one sidechain (BB-SC)
|
||||
hbond_count[a_idx, d_idx, 1] += (a_is_bb * d_is_sc) | (a_is_sc * d_is_bb)
|
||||
hbond_count[d_idx, a_idx, 1] += (d_is_bb * a_is_sc) | (d_is_sc * a_is_bb)
|
||||
|
||||
# 2 -> both sidechain (SC-SC)
|
||||
hbond_count[a_idx, d_idx, 2] += (a_is_sc * d_is_sc)
|
||||
hbond_count[d_idx, a_idx, 2] += (d_is_sc * a_is_sc)
|
||||
'''
|
||||
try:
|
||||
os.remove(pdb_path)
|
||||
os.remove(hb2_path)
|
||||
except:
|
||||
print("temp pdb/hb already removed or not created to begin with")
|
||||
'''
|
||||
return hbond_count
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user