extra filter for huge base pair centroid distances, and also cleaned inference specification code

This commit is contained in:
afavor
2026-02-25 21:03:30 -08:00
committed by Raktim Mitra
parent d04989f65f
commit ad3d95f351

View File

@@ -90,7 +90,7 @@ class NucMolInfo:
self.bp_val_cutoff = 0.5 # minimum basepairing score for binarizing basepairs when needed
self.base_geometry_limits = {}
self.base_geometry_limits['D_ij'] = 20.0
self.base_geometry_limits['D_ij'] = 16.0
self.base_geometry_limits['H_ij'] = 1.5
self.base_geometry_limits['P_ij'] = math.pi/5
self.base_geometry_limits['B_ij'] = math.pi/5
@@ -875,6 +875,7 @@ def _compute_pairwise_geometry(
# Rise (H_ij)
H_ij = np.sum(sc_D_ij_vec * Z_ij, axis=-1) # [I, I]
D_ij = np.linalg.norm(sc_D_ij_vec, axis=-1) # [I, I]
# Buckle (B_ij)
proj_Z_i_YZ = ( # [I, I, 3]
@@ -903,6 +904,7 @@ def _compute_pairwise_geometry(
"H_ij": H_ij,
"B_ij": B_ij,
"P_ij": P_ij,
"D_ij": D_ij,
"base_ori_ij": base_ori_ij,
"X_ij": X_ij,
"Y_ij": Y_ij,
@@ -933,6 +935,7 @@ def _compute_basepair_mask(
H_ij: np.ndarray,
B_ij: np.ndarray,
P_ij: np.ndarray,
D_ij: np.ndarray,
mol_info,
*,
bool_only: bool = False,
@@ -985,7 +988,10 @@ def _compute_basepair_mask(
(P_ij <= mol_info.base_geometry_limits["P_ij"])
| (P_ij >= math.pi - mol_info.base_geometry_limits["P_ij"])
)
bp_geom_filter = H_ij_filter & B_ij_filter & P_ij_filter # [I, I]
D_ij_filter = (D_ij <= mol_info.base_geometry_limits["D_ij"])
bp_geom_filter = H_ij_filter & B_ij_filter & P_ij_filter & D_ij_filter # [I, I]
if bool_only:
basepairs_bool_ij = ( # [I, I]
@@ -1074,6 +1080,7 @@ def compute_nucleic_ss(
"H_ij": np.zeros((0, 0), dtype=np.float32),
"B_ij": np.zeros((0, 0), dtype=np.float32),
"P_ij": np.zeros((0, 0), dtype=np.float32),
"D_ij": np.zeros((0, 0), dtype=np.float32),
"base_ori_ij": np.zeros((0, 0), dtype=np.float32),
"basepairs_bool_ij": np.zeros((0, 0), dtype=bool),
"basepairs_ij": np.zeros((0, 0), dtype=np.float32),
@@ -1139,6 +1146,7 @@ def compute_nucleic_ss(
pw_geom["H_ij"],
pw_geom["B_ij"],
pw_geom["P_ij"],
pw_geom["D_ij"],
mol_info,
bool_only=return_basepairs_only,
eps=eps,
@@ -1498,12 +1506,23 @@ def annotate_na_ss_from_specification(
bp_partners_ann = np.empty(len(atom_array), dtype=object)
bp_partners_ann[:] = None
# Build chain/res -> token index map for region/position specs
# Build chain/res -> token index map for region/position specs.
# Accept both chain_iid-like keys (e.g. "A_1") and plain chain IDs (e.g. "A")
# so CLI/json specs like "A1,B3" work reliably in inference.
chain_iid_list: list[str] = [str(atm.chain_iid) for atm in token_level_array]
chain_id_list: list[str] = [str(atm.chain_id) for atm in token_level_array]
resi_list: list[int] = [int(atm.res_id) for atm in token_level_array]
chain_res_to_tok: dict[tuple[str, int], int] = {
(c, r): i for i, (c, r) in enumerate(zip(chain_iid_list, resi_list))
}
chain_res_to_tok: dict[tuple[str, int], int] = {}
for i, (chain_iid, chain_id, res_id) in enumerate(
zip(chain_iid_list, chain_id_list, resi_list)
):
key_iid = (chain_iid, int(res_id))
key_chain = (chain_id, int(res_id))
chain_res_to_tok.setdefault(key_iid, int(i))
chain_res_to_tok.setdefault(key_chain, int(i))
# Also support the short alias from chain_iid (e.g. "A_1" -> "A")
short_chain = chain_iid.split("_", 1)[0]
chain_res_to_tok.setdefault((short_chain, int(res_id)), int(i))
def _parse_region(region_str: str) -> tuple[str, int, int] | None:
region_str = str(region_str).strip()