mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
extra filter for huge base pair centroid distances, and also cleaned inference specification code
This commit is contained in:
@@ -90,7 +90,7 @@ class NucMolInfo:
|
||||
self.bp_val_cutoff = 0.5 # minimum basepairing score for binarizing basepairs when needed
|
||||
|
||||
self.base_geometry_limits = {}
|
||||
self.base_geometry_limits['D_ij'] = 20.0
|
||||
self.base_geometry_limits['D_ij'] = 16.0
|
||||
self.base_geometry_limits['H_ij'] = 1.5
|
||||
self.base_geometry_limits['P_ij'] = math.pi/5
|
||||
self.base_geometry_limits['B_ij'] = math.pi/5
|
||||
@@ -875,6 +875,7 @@ def _compute_pairwise_geometry(
|
||||
|
||||
# Rise (H_ij)
|
||||
H_ij = np.sum(sc_D_ij_vec * Z_ij, axis=-1) # [I, I]
|
||||
D_ij = np.linalg.norm(sc_D_ij_vec, axis=-1) # [I, I]
|
||||
|
||||
# Buckle (B_ij)
|
||||
proj_Z_i_YZ = ( # [I, I, 3]
|
||||
@@ -903,6 +904,7 @@ def _compute_pairwise_geometry(
|
||||
"H_ij": H_ij,
|
||||
"B_ij": B_ij,
|
||||
"P_ij": P_ij,
|
||||
"D_ij": D_ij,
|
||||
"base_ori_ij": base_ori_ij,
|
||||
"X_ij": X_ij,
|
||||
"Y_ij": Y_ij,
|
||||
@@ -933,6 +935,7 @@ def _compute_basepair_mask(
|
||||
H_ij: np.ndarray,
|
||||
B_ij: np.ndarray,
|
||||
P_ij: np.ndarray,
|
||||
D_ij: np.ndarray,
|
||||
mol_info,
|
||||
*,
|
||||
bool_only: bool = False,
|
||||
@@ -985,7 +988,10 @@ def _compute_basepair_mask(
|
||||
(P_ij <= mol_info.base_geometry_limits["P_ij"])
|
||||
| (P_ij >= math.pi - mol_info.base_geometry_limits["P_ij"])
|
||||
)
|
||||
bp_geom_filter = H_ij_filter & B_ij_filter & P_ij_filter # [I, I]
|
||||
|
||||
D_ij_filter = (D_ij <= mol_info.base_geometry_limits["D_ij"])
|
||||
|
||||
bp_geom_filter = H_ij_filter & B_ij_filter & P_ij_filter & D_ij_filter # [I, I]
|
||||
|
||||
if bool_only:
|
||||
basepairs_bool_ij = ( # [I, I]
|
||||
@@ -1074,6 +1080,7 @@ def compute_nucleic_ss(
|
||||
"H_ij": np.zeros((0, 0), dtype=np.float32),
|
||||
"B_ij": np.zeros((0, 0), dtype=np.float32),
|
||||
"P_ij": np.zeros((0, 0), dtype=np.float32),
|
||||
"D_ij": np.zeros((0, 0), dtype=np.float32),
|
||||
"base_ori_ij": np.zeros((0, 0), dtype=np.float32),
|
||||
"basepairs_bool_ij": np.zeros((0, 0), dtype=bool),
|
||||
"basepairs_ij": np.zeros((0, 0), dtype=np.float32),
|
||||
@@ -1139,6 +1146,7 @@ def compute_nucleic_ss(
|
||||
pw_geom["H_ij"],
|
||||
pw_geom["B_ij"],
|
||||
pw_geom["P_ij"],
|
||||
pw_geom["D_ij"],
|
||||
mol_info,
|
||||
bool_only=return_basepairs_only,
|
||||
eps=eps,
|
||||
@@ -1498,12 +1506,23 @@ def annotate_na_ss_from_specification(
|
||||
bp_partners_ann = np.empty(len(atom_array), dtype=object)
|
||||
bp_partners_ann[:] = None
|
||||
|
||||
# Build chain/res -> token index map for region/position specs
|
||||
# Build chain/res -> token index map for region/position specs.
|
||||
# Accept both chain_iid-like keys (e.g. "A_1") and plain chain IDs (e.g. "A")
|
||||
# so CLI/json specs like "A1,B3" work reliably in inference.
|
||||
chain_iid_list: list[str] = [str(atm.chain_iid) for atm in token_level_array]
|
||||
chain_id_list: list[str] = [str(atm.chain_id) for atm in token_level_array]
|
||||
resi_list: list[int] = [int(atm.res_id) for atm in token_level_array]
|
||||
chain_res_to_tok: dict[tuple[str, int], int] = {
|
||||
(c, r): i for i, (c, r) in enumerate(zip(chain_iid_list, resi_list))
|
||||
}
|
||||
chain_res_to_tok: dict[tuple[str, int], int] = {}
|
||||
for i, (chain_iid, chain_id, res_id) in enumerate(
|
||||
zip(chain_iid_list, chain_id_list, resi_list)
|
||||
):
|
||||
key_iid = (chain_iid, int(res_id))
|
||||
key_chain = (chain_id, int(res_id))
|
||||
chain_res_to_tok.setdefault(key_iid, int(i))
|
||||
chain_res_to_tok.setdefault(key_chain, int(i))
|
||||
# Also support the short alias from chain_iid (e.g. "A_1" -> "A")
|
||||
short_chain = chain_iid.split("_", 1)[0]
|
||||
chain_res_to_tok.setdefault((short_chain, int(res_id)), int(i))
|
||||
|
||||
def _parse_region(region_str: str) -> tuple[str, int, int] | None:
|
||||
region_str = str(region_str).strip()
|
||||
|
||||
Reference in New Issue
Block a user