diff --git a/models/rfd3/src/rfd3/constants.py b/models/rfd3/src/rfd3/constants.py index bba1925..1e023ed 100644 --- a/models/rfd3/src/rfd3/constants.py +++ b/models/rfd3/src/rfd3/constants.py @@ -440,74 +440,73 @@ backbone_atoms_DNA = strip_list(backbone_atomscheme_DNA) # Mapping from residue type to its backbone and sidechain atoms (for convenience) ATOM_REGION_BY_RESI = { - 'ALA': {'bb':('N','CA','C','O'), + 'ALA': {'bb':('N','CA','C','O'), 'sc':('CB')}, - 'ARG': {'bb':('N','CA','C','O'), + 'ARG': {'bb':('N','CA','C','O'), 'sc':('CB','CG','CD','NE','CZ','NH1','NH2')}, - 'ASN': {'bb':('N','CA','C','O'), + 'ASN': {'bb':('N','CA','C','O'), 'sc':('CB','CG','OD1','ND2')}, - 'ASP': {'bb':('N','CA','C','O'), + 'ASP': {'bb':('N','CA','C','O'), 'sc':('CB','CG','OD1','OD2')}, - 'CYS': {'bb':('N','CA','C','O'), + 'CYS': {'bb':('N','CA','C','O'), 'sc':('CB','SG')}, - 'GLN': {'bb':('N','CA','C','O'), + 'GLN': {'bb':('N','CA','C','O'), 'sc':('CB','CG','CD','OE1','NE2')}, - 'GLU': {'bb':('N','CA','C','O'), + 'GLU': {'bb':('N','CA','C','O'), 'sc':('CB','CG','CD','OE1','OE2')}, - 'GLY': {'bb':('N','CA','C','O'), + 'GLY': {'bb':('N','CA','C','O'), 'sc':()}, - 'HIS': {'bb':('N','CA','C','O'), + 'HIS': {'bb':('N','CA','C','O'), 'sc':('CB','CG','ND1','CD2','CE1','NE2')}, - 'ILE': {'bb':('N','CA','C','O'), + 'ILE': {'bb':('N','CA','C','O'), 'sc':('CB','CG1','CG2','CD1')}, - 'LEU': {'bb':('N','CA','C','O'), + 'LEU': {'bb':('N','CA','C','O'), 'sc':('CB','CG','CD1','CD2')}, - 'LYS': {'bb':('N','CA','C','O'), + 'LYS': {'bb':('N','CA','C','O'), 'sc':('CB','CG','CD','CE','NZ')}, - 'MET': {'bb':('N','CA','C','O'), + 'MET': {'bb':('N','CA','C','O'), 'sc':('CB','CG','SD','CE')}, - 'PHE': {'bb':('N','CA','C','O'), + 'PHE': {'bb':('N','CA','C','O'), 'sc':('CB','CG','CD1','CD2','CE1','CE2','CZ')}, - 'PRO': {'bb':('N','CA','C','O'), + 'PRO': {'bb':('N','CA','C','O'), 'sc':('CB','CG','CD')}, - 'SER': {'bb':('N','CA','C','O'), + 'SER': {'bb':('N','CA','C','O'), 'sc':('CB','OG')}, - 'THR': 
{'bb':('N','CA','C','O'), + 'THR': {'bb':('N','CA','C','O'), 'sc':('CB','OG1','CG2')}, - 'TRP': {'bb':('N','CA','C','O'), + 'TRP': {'bb':('N','CA','C','O'), 'sc':('CB','CG','CD1','CD2','CE2','CE3','NE1','CZ2','CZ3','CH2')}, - 'TYR': {'bb':('N','CA','C','O'), + 'TYR': {'bb':('N','CA','C','O'), 'sc':('CB','CG','CD1','CD2','CE1','CE2','CZ','OH')}, - 'VAL': {'bb':('N','CA','C','O'), + 'VAL': {'bb':('N','CA','C','O'), 'sc':('CB','CG1','CG2')}, - 'UNK': {'bb':('N','CA','C','O'), + 'UNK': {'bb':('N','CA','C','O'), 'sc':('CB')}, - 'MAS': {'bb':('N','CA','C','O'), + 'MAS': {'bb':('N','CA','C','O'), 'sc':('CB')}, - 'DA': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"), + 'DA': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"), 'sc':('N9','C4','N3','C2','N1','C6','C5','N7','C8','N6')}, - 'DC': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"), + 'DC': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"), 'sc':('N1','C2','O2','N3','C4','N4','C5','C6')}, - 'DG': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"), + 'DG': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"), 'sc':('N9','C4','N3','C2','N1','C6','C5','N7','C8','N2','O6')}, - 'DT': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"), + 'DT': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"), 'sc':('N1','C2','O2','N3','C4','O4','C5','C7','C6')}, - 'DX': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"), + 'DX': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"), 'sc':()}, - 'A': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"), + 'A': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"), 'sc':('N1','C2','N3','C4','C5','C6','N6','N7','C8','N9')}, - 'C': {'bb':("O4'", 
"C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"), + 'C': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"), 'sc':('N1','C2','O2','N3','C4','N4','C5','C6')}, - 'G': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"), + 'G': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"), 'sc':('N1','C2','N2','N3','C4','C5','C6','O6','N7','C8','N9')}, - 'U': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"), + 'U': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"), 'sc':('N1','C2','O2','N3','C4','O4','C5','C6')}, - 'X': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"), + 'X': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"), 'sc':()}, - 'HIS_D': {'bb':('N','CA','C','O'), + 'HIS_D': {'bb':('N','CA','C','O'), 'sc':('CB','CG','NE2','CD2','CE1','ND1')}, } - # Known planar sidechain atoms for each canonical residue type: PLANAR_ATOMS_BY_RESI = { 'ALA': [], diff --git a/models/rfd3/src/rfd3/transforms/design_transforms.py b/models/rfd3/src/rfd3/transforms/design_transforms.py index 1ee14ce..c1193a3 100644 --- a/models/rfd3/src/rfd3/transforms/design_transforms.py +++ b/models/rfd3/src/rfd3/transforms/design_transforms.py @@ -45,7 +45,7 @@ from rfd3.transforms.util_transforms import ( get_af3_token_representative_masks, ) from rfd3.transforms.virtual_atoms import PadTokensWithVirtualAtoms -from rfd3.transforms.na_geom import get_bp_feats_from_atom_array +from rfd3.transforms.na_geom import na_ss_feats_from_annotation from foundry.utils.ddp import RankedLogger # noqa @@ -811,7 +811,7 @@ class AddAdditional2dFeaturesToFeats(Transform): # Need to pre-define custom constructor functions # to map from atomarray annotations to tensors. 
self.constructor_functions = { - 'bp_partners': get_bp_feats_from_atom_array, + 'bp_partners': na_ss_feats_from_annotation, } def check_input(self, data) -> None: diff --git a/models/rfd3/src/rfd3/transforms/na_geom.py b/models/rfd3/src/rfd3/transforms/na_geom.py index 2a93150..a1c8565 100644 --- a/models/rfd3/src/rfd3/transforms/na_geom.py +++ b/models/rfd3/src/rfd3/transforms/na_geom.py @@ -12,86 +12,80 @@ from rfd3.transforms.conditioning_utils import sample_island_tokens from rfd3.transforms.na_geom_utils import ( annotate_na_ss, annotate_na_ss_from_data_specification, - bp_partner_to_ss_matrix, + DEFAULT_NA_SS_FEATURE_INFO, ) from atomworks.ml.utils.token import spread_token_wise, get_token_starts -def get_bp_feats_from_atom_array( - atom_array: AtomArray, -) -> np.ndarray: - """Build NA-SS features from atom_array annotations, assuming 'bp_partners' is present. - - This function reconstructs the SS matrix from the 'bp_partners' annotation on the atom_array, - then one-hot encodes it into a 3-class matrix (mask, pair, loop). +def na_ss_feats_from_annotation(atom_array: AtomArray, + token_starts= None, + n_tokens = None, + return_as_onehot = True, + ) -> np.ndarray: """ - # Fixed feature info (inferred from usage in other functions) - feature_info = { - 'NA_SS_MASK': 0, # Unspecified - 'NA_SS_PAIR': 1, # Paired - 'NA_SS_LOOP': 2, # Loop / unpaired - 'num_classes_nucleic_ss': 3, - } + Takes in atom array and constucts a base pair feature matrix from annotations, + according to to custom feature constuction + masking system. + This featurization utilizes info from BasePairEnum to assign int values + to paired, unpaired, and masked positions in the matrix. 
- # Check for required annotation - if "bp_partners" not in atom_array.get_annotation_categories(): - raise ValueError("atom_array must have 'bp_partners' annotation for NA-SS feature building.") + Args: + * atom_array: AtomArray with bp_partners annotation at atom level + * token_starts (optional): indices of token starts in the atom array + * n_tokens (optional): number of tokens (length of token_starts) + * return_as_onehot (optional): if False, return integer-encoded + matrix instead of one-hot encoded matrix - # Reconstruct SS matrix from annotations - na_ss_matrix = np.asarray( - bp_partner_to_ss_matrix( - atom_array, - feature_info=feature_info, - NA_only=False, # Include all residues (logic from other utils) - planar_only=True, # Use planar interactions (common default) - include_loops=True, # Include loop states - ), - dtype=np.int64, - ) + returns: + * na_ss_matrix: + If ``return_as_onehot`` is True (default): + np.ndarray of shape (n_tokens, n_tokens, n_classes) + with one-hot encoded values according to BasePairEnum - # One-hot encode the matrix - na_ss_matrix_int = np.asarray(na_ss_matrix, dtype=np.int64) - eye = np.eye(int(feature_info['num_classes_nucleic_ss']), dtype=np.int64) - return eye[na_ss_matrix_int] + If ``return_as_onehot`` is False : + np.ndarray of shape (n_tokens, n_tokens) + with int values according to BasePairEnum -def _build_na_ss_features_from_annotations( - atom_array: AtomArray, - *, - feature_info: dict, - num_classes: int, - NA_only: bool, - planar_only: bool, - is_nucleic_ss_example: bool, - give_partial_feats: bool, - get_feature_mask_fn, -) -> np.ndarray: - """Reconstruct SS matrix from annotations, optionally mask, then one-hot.""" - na_ss_matrix = np.asarray( - bp_partner_to_ss_matrix( - atom_array, - feature_info=feature_info, - NA_only=NA_only, - planar_only=planar_only, - include_loops=True, - ), - dtype=np.int64, - ) + """ + # Get this info from atom_array, or avoid if given + if (token_starts is None) or (n_tokens is 
None): + token_starts = get_token_starts(atom_array) + n_tokens = len(token_starts) + - n_tokens = int(na_ss_matrix.shape[0]) + # Collect token inds for paired or loop positions: + pair_inds = [] + loop_inds = [] + token_bp_partners = atom_array.get_annotation("bp_partners")[token_starts] # get bp_partners at token level + assert len(token_bp_partners) == n_tokens, "Length of token_bp_partners should match n_tokens" + for i, j_list in enumerate(token_bp_partners): + if j_list is not None: + if len(j_list) > 0: + for j in j_list: + pair_inds.append((i, j)) + else: + loop_inds.append(i) - if give_partial_feats: - is_shown = ( - np.asarray(get_feature_mask_fn(n_tokens), dtype=bool) - if is_nucleic_ss_example - else np.zeros((n_tokens,), dtype=bool) - ) - na_ss_matrix[~is_shown, :] = feature_info["NA_SS_MASK"] - na_ss_matrix[:, ~is_shown] = feature_info["NA_SS_MASK"] + # The standard system for constructing meaningful base pair features: + # 0). Initialize with values of UNSPECIFIED (0): int matrix of shape (n_tokens, n_tokens) + na_ss_matrix = np.full((n_tokens, n_tokens), DEFAULT_NA_SS_FEATURE_INFO["NA_SS_MASK"], dtype=np.int64) - na_ss_matrix_int = np.asarray(na_ss_matrix, dtype=np.int64) - eye = np.eye(int(num_classes), dtype=np.int64) - return eye[na_ss_matrix_int] + # 1). Fill in with values of PAIR (1) at positions that have bp_partners annotated as a non-empty list + for pair_i, pair_j in pair_inds: + na_ss_matrix[pair_i, pair_j] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_PAIR"] + na_ss_matrix[pair_j, pair_i] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_PAIR"] # ensure symmetry + + # 2). 
Fill in with values of LOOP (2) at positions that have bp_partners annotated as an empty list (explicitly unpaired) + # (we make full stripes across that position's row/col to indicate that NONE of those other positions are paired ) + for loop_i in loop_inds: + na_ss_matrix[loop_i, :] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_LOOP"] + na_ss_matrix[:, loop_i] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_LOOP"] # ensure symmetry + + # Optional: convert NA-SS matrix to one-hot encoding according for model input: + if return_as_onehot: + na_ss_matrix = np.eye(len(DEFAULT_NA_SS_FEATURE_INFO), dtype=np.int64)[na_ss_matrix] + + return na_ss_matrix class CalculateNucleicAcidGeomFeats(Transform): @@ -119,11 +113,12 @@ class CalculateNucleicAcidGeomFeats(Transform): def __init__( self, is_inference, - meta_conditioning_probabilities, - + add_nucleic_ss_feats: bool = True, + # Conditional sampling parameters: p_is_nucleic_ss_example: float = 0.3, - p_show_partial_feats: float = 0.5, - nucleic_ss_min_shown: float = 0.0, + p_show_partial_feats: float = 0.7, + # Mask control paramerers: + nucleic_ss_min_shown: float = 0.2, nucleic_ss_max_shown: float = 1.0, n_islands_min: int = 1, n_islands_max: int = 6, @@ -136,39 +131,22 @@ class CalculateNucleicAcidGeomFeats(Transform): ): # Critical, must always have to know how to handle self.is_inference = is_inference - if not self.is_inference: - ## relevant in training - self.sampling_prob = meta_conditioning_probabilities['calculate_NA_SS'] - else: - ## irrelevant in inference - self.sampling_prob = 0 - # For sampling whether we add nucleic-ss features (extra t2d) - - # relevant in training - self.add_nucleic_ss_feats = (self.sampling_prob > 0) - ###### + self.add_nucleic_ss_feats = add_nucleic_ss_feats self.p_canonical_bp_filter = p_canonical_bp_filter # enforce that bp labels are only canonical self.p_is_nucleic_ss_example = p_is_nucleic_ss_example + self.p_show_partial_feats = p_show_partial_feats self.nucleic_ss_min_shown = nucleic_ss_min_shown 
self.nucleic_ss_max_shown = nucleic_ss_max_shown self.n_islands_min = n_islands_min self.n_islands_max = n_islands_max - self.p_show_partial_feats = p_show_partial_feats # Filters for what can be considered a planar contact interaction self.NA_only = NA_only # only annotate base-like interactions for nucleic acid residues self.planar_only = planar_only # only consider planar atoms in sidechains for geometry calculations, self.p_canonical_bp_filter = p_canonical_bp_filter # probability of enforcing canonical base pair filter - # Inds of annotation types in the nucleic-ss features (stack of 3 matrices): - self.feature_info = { - 'NA_SS_MASK' : 0, # Unspecified, or sm, or protein: - 'NA_SS_PAIR' : 1, - 'NA_SS_LOOP' : 2, - 'num_classes_nucleic_ss' : 3, - } def check_input(self, data: dict[str, Any]) -> None: @@ -179,10 +157,7 @@ class CalculateNucleicAcidGeomFeats(Transform): def _sample_training_flags(self) -> tuple[bool, bool]: """Sample booleans controlling whether/how features are shown in training.""" - is_nucleic_ss_example = bool( - self.add_nucleic_ss_feats - and (np.random.rand() < self.p_is_nucleic_ss_example) - ) + is_nucleic_ss_example = bool(np.random.rand() < self.p_is_nucleic_ss_example) give_partial_feats = bool( np.random.rand() < self.p_show_partial_feats ) @@ -194,28 +169,41 @@ class CalculateNucleicAcidGeomFeats(Transform): # Calculate n_tokens (assuming one token per residue for simplicity) token_starts = get_token_starts(atom_array) token_level_array = atom_array[token_starts] - token_ids = [int(t) for t in token_level_array.token_id] n_tokens = len(token_starts) - #TODO print(" DO I NEED TO CHANGE TO TOKEN_ID???") - # Handle the training case with ground truth and masking: - if not self.is_inference and (np.random.rand() < self.sampling_prob): + + # Defaults for feature visibility + is_nucleic_ss_example = True + give_partial_feats = False + token_mask_to_show = np.ones(n_tokens, dtype=bool) + + # Handle the training case with ground truth and 
masking + if not self.is_inference: # First, annotate as usual - # atom_array = annotate_na_ss(atom_array, **kwargs) atom_array = annotate_na_ss(atom_array, NA_only=self.NA_only, planar_only=self.planar_only, p_canonical_bp_filter=self.p_canonical_bp_filter, ) - - # Sample mask on token level: - is_nucleic_ss_example, give_partial_feats = self._sample_training_flags() - is_ss_shown = self._sample_where_to_show_ss(n_tokens, - is_nucleic_ss_example=is_nucleic_ss_example, - give_partial_feats=give_partial_feats) # Mask vec for tokens where ss shown - # Spread mask to atom level - is_ss_shown = spread_token_wise(atom_array, is_ss_shown) + # Generate symmetric partner annotations at the token level for masking purposes. + # choice for object-consistency: if already masked/undefined: be a list mapping to self-index. + partner_sym_map = { + i: atom_array.bp_partners[ts_i] if atom_array.bp_partners[ts_i] is not None else [i] + for i, ts_i in enumerate(token_starts) + } + + # # Sample mask on token level: + is_nucleic_ss_example, give_partial_feats = self._sample_training_flags() + token_mask_to_show = self._sample_where_to_show_ss( + n_tokens, + is_nucleic_ss_example=is_nucleic_ss_example, + give_partial_feats=give_partial_feats, + partner_sym_map=partner_sym_map, + ) # Mask vec for tokens where ss shown + + # Spread mask to atom level + is_ss_shown = spread_token_wise(atom_array, token_mask_to_show) # Extract the base pair annotations bp_partners_atom = atom_array.get_annotation("bp_partners") @@ -233,11 +221,7 @@ class CalculateNucleicAcidGeomFeats(Transform): - 1). Single dot-bracket string - 2). multiple dot bracket strings with chain/ind ranges specified - 3). 
Lists of paired indices - """ - #is_nucleic_ss_example=True - #give_partial_feats=False - atom_array = annotate_na_ss_from_data_specification( data, overwrite=True, @@ -247,28 +231,34 @@ class CalculateNucleicAcidGeomFeats(Transform): if "feats" not in data: data["feats"] = {} - # data["feats"].update(nucleic_features) data.setdefault("log_dict", {}) log_dict = data["log_dict"] data["log_dict"] = log_dict data["atom_array"] = atom_array - + return data def _sample_where_to_show_ss(self, n_tokens: int, is_nucleic_ss_example: bool = True, give_partial_feats: bool = True, + partner_sym_map: dict[int, list[int]] = None, ) -> np.ndarray: - """Sample token-level islands indicating which SS rows/cols to reveal.""" + """Sample token-level islands indicating which SS rows/cols to reveal. + This custom function allows for enforcing symmetry in the shown features according + to the partner_sym_map, which encodes which tokens are partners in the SS + matrix and thus should be masked/unmasked together to maintain consistency. 
+ + """ # If NOT is_nucleic_ss_example, set is_shown to all False if not is_nucleic_ss_example: - return np.zeros((n_tokens,), dtype=bool) + token_mask_to_show = np.zeros((n_tokens,), dtype=bool) # If NOT give_partial_feats, set is_shown to all True if not give_partial_feats: - return np.ones((n_tokens,), dtype=bool) + token_mask_to_show = np.ones((n_tokens,), dtype=bool) else: + # Get numerical parameters for that govern the mask pattern frac_shown = ( self.nucleic_ss_min_shown + (self.nucleic_ss_max_shown - self.nucleic_ss_min_shown) * np.random.rand() @@ -276,15 +266,15 @@ class CalculateNucleicAcidGeomFeats(Transform): frac_shown = float(np.clip(frac_shown, 0.0, 1.0)) max_length = int(np.ceil(frac_shown * n_tokens)) if max_length <= 0: - return np.zeros((n_tokens,), dtype=bool) - + token_mask_to_show = np.zeros((n_tokens,), dtype=bool) island_len_min = max(1, int(frac_shown * n_tokens // max(int(self.n_islands_max), 1))) island_len_max = max(1, int(frac_shown * n_tokens // max(int(self.n_islands_min), 1))) island_len_min = min(island_len_min, n_tokens) island_len_max = min(island_len_max, n_tokens) island_len_max = max(island_len_max, island_len_min) - - return sample_island_tokens( + + # Sample the actual mask using the utility function: + token_mask_to_show = sample_island_tokens( n_tokens, island_len_min=island_len_min, island_len_max=island_len_max, @@ -292,4 +282,17 @@ class CalculateNucleicAcidGeomFeats(Transform): n_islands_max=self.n_islands_max, max_length=max_length, ) + + # Handle symmetry by iterating through the partner_sym_map items and setting + # `partner_mask_to_show` at partner positions to match `token_mask_to_show` + # initialize as all shown so effect comes from hiding + logical AND condition + partner_mask_to_show = np.ones_like(token_mask_to_show) + for token_i, partner_ind_list in partner_sym_map.items(): + for partner_ind in partner_ind_list: + partner_mask_to_show[partner_ind] = token_mask_to_show[token_i] + + # Combine the original 
mask with the partner mask to ensure symmetry + token_mask_to_show = token_mask_to_show & partner_mask_to_show + + return token_mask_to_show diff --git a/models/rfd3/src/rfd3/transforms/na_geom_utils.py b/models/rfd3/src/rfd3/transforms/na_geom_utils.py index 29be56f..8ab2e71 100644 --- a/models/rfd3/src/rfd3/transforms/na_geom_utils.py +++ b/models/rfd3/src/rfd3/transforms/na_geom_utils.py @@ -4,9 +4,11 @@ from datetime import datetime from typing import Dict, Optional import math import numpy as np +import biotite.structure as struc from biotite.structure import AtomArray from atomworks.constants import ( + STANDARD_AA, STANDARD_DNA, STANDARD_RNA, ) @@ -26,72 +28,63 @@ from rfd3.transforms.hbonds_hbplus import save_atomarray_to_pdb from atomworks.ml.encoding_definitions import AF3SequenceEncoding +from rfd3.constants import ( +ATOM_REGION_BY_RESI, +PLANAR_ATOMS_BY_RESI, +) + + +# Derived: True when the residue has any planar sidechain atoms +HAS_PLANAR_SC = {res: bool(atoms) for res, atoms in PLANAR_ATOMS_BY_RESI.items()} DEFAULT_NA_SS_FEATURE_INFO: dict[str, int] = { "NA_SS_MASK": 0, "NA_SS_PAIR": 1, "NA_SS_LOOP": 2, - "num_classes_nucleic_ss": 3, } +AA_PLANAR_ATOMS = sorted(set( + atom for res in STANDARD_AA if res in PLANAR_ATOMS_BY_RESI + for atom in PLANAR_ATOMS_BY_RESI[res] +)) -# Move to function scope to avoid module-level memory retention -def _get_sequence_encoding_data(): - """Get sequence encoding data on demand to avoid persistent module-level variables.""" - sequence_encoding = AF3SequenceEncoding() - return { - 'aa_like_res_names': sequence_encoding.all_res_names[sequence_encoding.is_aa_like], - 'rna_like_res_names': sequence_encoding.all_res_names[sequence_encoding.is_rna_like], - 'dna_like_res_names': sequence_encoding.all_res_names[sequence_encoding.is_dna_like], - 'sequence_encoding': sequence_encoding - } - +NA_PLANAR_ATOMS = sorted(set( + atom for res in (*STANDARD_RNA, *STANDARD_DNA) if res in PLANAR_ATOMS_BY_RESI + for atom in 
PLANAR_ATOMS_BY_RESI[res] +)) class NucMolInfo: + """Constants and parameters for nucleic-acid geometry and interaction scoring. + + All parameters are set to empirically validated defaults. No constructor + arguments are currently accepted. """ - Initializes constants and parameters relevant for computing nucleic acid geometry and interactions. - """ - def __init__(self, - cutoff_HA_dist = 2.5, - cutoff_DA_dist = 3.9, - ): - """ - Args: - kwargs: Optional keyword arguments for customization. - """ + + def __init__(self) -> None: - # Optional parameters with default values - # self.incl_protein = True - self.eps = 1e-8 - # self.clamp_pairwise_params = True - # self.use_eigennormals = kwargs.get('use_eigennormals', True) - # self.use_all_base_atoms_for_MBD = kwargs.get('use_all_base_atoms_for_MBD', False) - self.edges_to_compute = ['S'] # list base edges to compute, if we want to analyze WC/Hoog/etc - self.perp_base_edge = 'S' # edge orthogonal to x- and z-directions in base frames (which is generally the sugar edge) - - self.cutoff_HA_dist = cutoff_HA_dist - self.cutoff_DA_dist = cutoff_DA_dist - self.seq_cutoff = 2 - self.gap_length = 200 - - - - - # Hbond interaction type inds when counting: - self.BB_BB = 0 - self.BB_SC = 1 - self.SC_SC = 2 + # Hbond interaction-class indices of the `hbond_count`` array: + # `hbond_count`` array is (L, L, 3), where the last dimension + # encodes interaction type between tokens i & j + self.BB_BB = 0 # backbone-backbone hbond interactions + self.BB_SC = 1 # backbone-sidechain hbond interactions + self.SC_SC = 2 # sidechain-sidechain hbond interactions + # We sum over the last dimension of the hbond_count array, scaling + # count by the following weights to get the interaction score: self.bp_weight_BB_BB = 0.0 self.bp_weight_BB_SC = 0.5 self.bp_weight_SC_SC = 1.0 - self.bp_summation_weights = [self.bp_weight_BB_BB, self.bp_weight_BB_SC, self.bp_weight_SC_SC] + # Parameters fo sigmoid function that gives us a continuous step 
function for + # meeting basepair interaction criteria based on hbond counts alone (1st filter). + # Calibrated such that: + # >= 2 base-base H-bonds -> ~1.0 + # 1 base-base H-bond + 1 base-backbone H-bond -> ~0.5 self.min_hbonds_for_bp = 2.0 self.bp_hbond_coeff = 9.8 # determined heuristically self.bp_val_cutoff = 0.5 # minimum basepairing score for binarizing basepairs when needed @@ -102,91 +95,8 @@ class NucMolInfo: self.base_geometry_limits['P_ij'] = math.pi/5 self.base_geometry_limits['B_ij'] = math.pi/5 - # For interaction-edge classification (Watson-Crick, Hoogstein, Sugar, Base-other): - # self.edge_to_ind = {'W':0 , 'H':1 , 'S':2 ,'B':3} self.rep_atom_dict={"protein": "CA", "rna": "C1'", "dna": "C1'"} - self.has_planar_sc = { - 'ALA': False, - 'ARG': True, - 'ASN': True, - 'ASP': True, - 'CYS': False, - 'GLN': True, - 'GLU': True, - 'GLY': False, - 'HIS': True, - 'ILE': False, - 'LEU': False, - 'LYS': False, - 'MET': False, - 'PHE': True, - 'PRO': False, - 'SER': False, - 'THR': False, - 'TRP': True, - 'TYR': True, - 'VAL': False, - 'UNK': False, - 'MAS': False, - 'DA': True, - 'DC': True, - 'DG': True, - 'DT': True, - 'DX': False, - 'A': True, - 'C': True, - 'G': True, - 'U': True, - 'X': False, - 'HIS_D': True, - } - - - - # Make self.planar_atom_list_dict based on known planar atoms for each residue type: - self.planar_atom_list_dict = { - 'ALA': [], - 'ARG': ['NH1', 'NH2', 'CZ', 'NE', 'CD'], - 'ASN': ['OD1', 'ND2', 'CG', 'CB'], - 'ASP': ['OD1', 'OD2', 'CG', 'CB'], - 'CYS': [], - 'GLN': ['OE1', 'NE2', 'CD', 'CG'], - 'GLU': ['OE1', 'OE2', 'CD', 'CG'], - 'GLY': [], - 'HIS': ['ND1', 'CE1', 'NE2', 'CD2', 'CG', 'CB'], - 'ILE': [], - 'LEU': [], - 'LYS': [], - 'MET': [], - 'PHE': ['CZ', 'CE1', 'CE2', 'CD1', 'CD2', 'CG', 'CB'], - 'PRO': [], - 'SER': [], - 'THR': [], - 'TRP': ['CH2', 'CZ3', 'CZ2', 'CE3', 'CE2', 'CD2', 'NE1', 'CD1', 'CG', 'CB'], - 'TYR': ['OH', 'CZ', 'CE1', 'CE2', 'CD1', 'CD2', 'CG', 'CB'], - 'VAL': [], - 'UNK': [], - 'MAS': [], - 'DA': ['N6', 
'C6', 'N1', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'], - 'DC': ['N4', 'C4', 'N3', 'O2', 'C2', 'C5', 'C6', 'N1'], - 'DG': ['O6', 'C6', 'N1', 'N2', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'], - 'DT': ['O4', 'O2', 'N3', 'C4', 'C2', 'C5', 'C6', 'N1', 'C7'], - 'DX': [], - 'A': ['N6', 'C6', 'N1', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'], - 'C': ['N4', 'C4', 'N3', 'O2', 'C2', 'C5', 'C6', 'N1'], - 'G': ['O6', 'C6', 'N1', 'N2', 'C2', 'N3', 'C4', 'C5', 'N7', 'C8', 'N9'], - 'U': ['O4', 'O2', 'N3', 'C4', 'C2', 'C5', 'C6', 'N1'], - 'X': [], - 'HIS_D': ['ND1', 'CD2', 'CE1', 'NE2', 'CG', 'CB'], - } - - - # from pdb import set_trace; set_trace() - - self.nuc_resi_3letter = ["DA","DG","DC","DT","A","G","C","U"] - self.ring_atom_list = ["N1","C2","N3","C4","C6","C5"] - # go through self.vec_atom_dict and remove spaces from atom names (values in inner dicts), and remove spaces from keys + replace 'R' with '' in outer dict keys self.vec_atom_dict = { "DA": {"W_start":"N1", "W_stop":"N6", "H_start":"N7", "H_stop":"N6", "S_start":"C1'", "S_stop":"N3", "B_start":"C1'", "B_stop":"N9" }, @@ -201,96 +111,154 @@ class NucMolInfo: - self.atom_region_dict = { - 'ALA': {'bb':('N','CA','C','O'), - 'sc':('CB')}, - 'ARG': {'bb':('N','CA','C','O'), - 'sc':('CB','CG','CD','NE','CZ','NH1','NH2')}, - 'ASN': {'bb':('N','CA','C','O'), - 'sc':('CB','CG','OD1','ND2')}, - 'ASP': {'bb':('N','CA','C','O'), - 'sc':('CB','CG','OD1','OD2')}, - 'CYS': {'bb':('N','CA','C','O'), - 'sc':('CB','SG')}, - 'GLN': {'bb':('N','CA','C','O'), - 'sc':('CB','CG','CD','OE1','NE2')}, - 'GLU': {'bb':('N','CA','C','O'), - 'sc':('CB','CG','CD','OE1','OE2')}, - 'GLY': {'bb':('N','CA','C','O'), - 'sc':()}, - 'HIS': {'bb':('N','CA','C','O'), - 'sc':('CB','CG','ND1','CD2','CE1','NE2')}, - 'ILE': {'bb':('N','CA','C','O'), - 'sc':('CB','CG1','CG2','CD1')}, - 'LEU': {'bb':('N','CA','C','O'), - 'sc':('CB','CG','CD1','CD2')}, - 'LYS': {'bb':('N','CA','C','O'), - 'sc':('CB','CG','CD','CE','NZ')}, - 'MET': {'bb':('N','CA','C','O'), - 
'sc':('CB','CG','SD','CE')}, - 'PHE': {'bb':('N','CA','C','O'), - 'sc':('CB','CG','CD1','CD2','CE1','CE2','CZ')}, - 'PRO': {'bb':('N','CA','C','O'), - 'sc':('CB','CG','CD')}, - 'SER': {'bb':('N','CA','C','O'), - 'sc':('CB','OG')}, - 'THR': {'bb':('N','CA','C','O'), - 'sc':('CB','OG1','CG2')}, - 'TRP': {'bb':('N','CA','C','O'), - 'sc':('CB','CG','CD1','CD2','CE2','CE3','NE1','CZ2','CZ3','CH2')}, - 'TYR': {'bb':('N','CA','C','O'), - 'sc':('CB','CG','CD1','CD2','CE1','CE2','CZ','OH')}, - 'VAL': {'bb':('N','CA','C','O'), - 'sc':('CB','CG1','CG2')}, - 'UNK': {'bb':('N','CA','C','O'), - 'sc':('CB')}, - 'MAS': {'bb':('N','CA','C','O'), - 'sc':('CB')}, - 'DA': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"), - 'sc':('N9','C4','N3','C2','N1','C6','C5','N7','C8','N6')}, - 'DC': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"), - 'sc':('N1','C2','O2','N3','C4','N4','C5','C6')}, - 'DG': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"), - 'sc':('N9','C4','N3','C2','N1','C6','C5','N7','C8','N2','O6')}, - 'DT': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"), - 'sc':('N1','C2','O2','N3','C4','O4','C5','C7','C6')}, - 'DX': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"), - 'sc':()}, - 'A': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"), - 'sc':('N1','C2','N3','C4','C5','C6','N6','N7','C8','N9')}, - 'C': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"), - 'sc':('N1','C2','O2','N3','C4','N4','C5','C6')}, - 'G': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"), - 'sc':('N1','C2','N2','N3','C4','C5','C6','O6','N7','C8','N9')}, - 'U': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"), - 'sc':('N1','C2','O2','N3','C4','O4','C5','C6')}, - 'X': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', 
"O5'", "C5'", "C4'", "C3'", "O3'", "O2'"), - 'sc':()}, - 'HIS_D': {'bb':('N','CA','C','O'), - 'sc':('CB','CG','NE2','CD2','CE1','ND1')}, - } +def calculate_hb_counts( + atom_array: AtomArray, + token_level_data: dict, + mol_info: NucMolInfo, + cutoff_HA_dist: float = 2.5, + cutoff_DA_dist: float = 3.9, + ): + """Count hydrogen bonds between residue pairs using HBPLUS. + + Args: + atom_array: Structure to analyse. + token_level_data: Token-level metadata dict (must contain + ``token_id_list`` and ``resi2index``). + mol_info: Molecular-info object for backbone/sidechain atom lookup. + cutoff_HA_dist: H–A distance cutoff (Å) passed to HBPLUS. + cutoff_DA_dist: D–A distance cutoff (Å) passed to HBPLUS. + + Returns: + np.ndarray of shape ``(I, I, 3)`` (int32) where the last axis + encodes: 0 = BB–BB, 1 = BB–SC, 2 = SC–SC H-bond counts. + """ + + dtstr = datetime.now().strftime("%Y%m%d%H%M%S") + pdb_path = f"{dtstr}_{np.random.randint(10000)}.pdb" + + atom_array, nan_mask, chain_map = save_atomarray_to_pdb(atom_array, pdb_path) + subprocess.call( + [ + "/projects/ml/hbplus", + "-h", + str(cutoff_HA_dist), + "-d", + str(cutoff_DA_dist), + pdb_path, + pdb_path, + ], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) - self.aa_planar_atoms = ['NH1', 'NH2', 'CZ', 'NE', 'OD1', 'ND2', - 'OD2', 'OE1', 'NE2', 'CD', 'OE2', 'ND1', - 'CD2', 'CE1', 'CD1', 'CE2', 'NE1', 'CZ2', - 'CZ3', 'CH2', 'CE3', 'OH', 'CG', 'CB',] + num_resis_total = len(token_level_data["token_id_list"]) + + hbond_count = np.zeros((num_resis_total, num_resis_total, 3), dtype=np.int32) + + hb2_path = pdb_path.replace("pdb", "hb2") + with open(hb2_path, "r") as hb2_f: + for i, line in enumerate(hb2_f): + if i < 8: + continue + if len(line) < 28: + continue + + d_chain_iid = chain_map[line[0]] + d_resi = int(line[1:5].strip()) + d_resn = line[6:9].strip() + d_atom_name = line[9:13].strip() + + # Initialize donor/acceptor sidechain/backbone flags: + # then replace with True if valid for summation + 
d_is_sc = False + d_is_bb = False + a_is_sc = False + a_is_bb = False + + d_mask = ( + (atom_array.atom_name == d_atom_name) + & (atom_array.res_name == d_resn) + & (atom_array.res_id == d_resi) + & (atom_array.chain_iid == d_chain_iid) + ) + # d_atm = atom_array[d_mask] + # d_idx = d_atm.token_id + d_idx = token_level_data["resi2index"].get(f"{d_chain_iid}__{d_resi}", None) + if d_idx is None: + continue + + # Handle standard polymer residues for donor atom: + if d_resn in ATOM_REGION_BY_RESI.keys(): + d_is_sc = (d_atom_name in ATOM_REGION_BY_RESI[d_resn]['sc']) + d_is_bb = (d_atom_name in ATOM_REGION_BY_RESI[d_resn]['bb']) + else: + # If non-polymer, define any ligand HBonding atom as backbone: + if d_mask.sum() > 0: + d_is_bb = atom_array[d_mask][0].is_ligand + + a_chain_iid = chain_map[line[14]] + a_resi = int(line[15:19].strip()) + a_resn = line[20:23].strip() + a_atom_name = line[23:27].strip() + + a_mask = ( + (atom_array.atom_name == a_atom_name) + & (atom_array.res_name == a_resn) + & (atom_array.res_id == a_resi) + & (atom_array.chain_iid == a_chain_iid) + ) + a_idx = token_level_data["resi2index"].get(f"{a_chain_iid}__{a_resi}", None) + if a_idx is None: + continue + + # Handle standard polymer residues for acceptor atom: + if a_resn in ATOM_REGION_BY_RESI.keys(): + a_is_sc = (a_atom_name in ATOM_REGION_BY_RESI[a_resn]['sc']) + a_is_bb = (a_atom_name in ATOM_REGION_BY_RESI[a_resn]['bb']) + else: + # If non-polymer, define any ligand HBonding atom as backbone: + if a_mask.sum() > 0: + a_is_bb = atom_array[a_mask][0].is_ligand + + # 0 -> both backbone (BB-BB) + hbond_count[a_idx, d_idx, 0] += (a_is_bb * d_is_bb) + hbond_count[d_idx, a_idx, 0] += (d_is_bb * a_is_bb) + + # 1 -> one backbone, one sidechain (BB-SC) + hbond_count[a_idx, d_idx, 1] += (a_is_bb * d_is_sc) | (a_is_sc * d_is_bb) + hbond_count[d_idx, a_idx, 1] += (d_is_bb * a_is_sc) | (d_is_sc * a_is_bb) + + # 2 -> both sidechain (SC-SC) + hbond_count[a_idx, d_idx, 2] += (a_is_sc * d_is_sc) + 
hbond_count[d_idx, a_idx, 2] += (d_is_sc * a_is_sc) + + os.remove(pdb_path) + os.remove(hb2_path) + + return hbond_count - self.na_planar_atoms = ['C4', 'N3', 'C2', 'C6', 'C5', 'N7', 'C8', - 'N6', 'O2', 'N4', 'N2', 'O6', 'O4', 'C7', - 'N9', 'N1'] def find_planar_positions( atom_array: AtomArray, mol_info: NucMolInfo, tol: float = 1e-2, ) -> Dict: - """ - Finds residues with planar sidechains based on four tip-most atoms, - but also checks for valid atoms to use for this type of calculation. + """Identify residues with planar sidechains via known atom lists or PCA plane-fitting. + + For canonical residues the planar atoms are looked up from ``mol_info``; + for non-canonical residues a plane is fitted to the four tip-most sidechain + atoms, and all atoms within *tol* of that plane are returned. + + Args: + atom_array: Structure to analyse. + mol_info: Molecular-info object supplying per-residue planar atom lists. + tol: Distance tolerance (Å) from the fitted plane for an atom to be + considered planar. Returns: - dict of planar atom lists + Dictionary ``{(chain_iid, res_id): [atom_name, ...]}`` mapping each + unique residue position to its list of planar sidechain atom names. 
""" unique_positions_list = [] for atm in atom_array: @@ -312,11 +280,11 @@ def find_planar_positions( res_atoms = atom_array[mask] # If possible, speed up by using known planar atoms for this residue type: - if res_name in mol_info.planar_atom_list_dict.keys(): + if res_name in PLANAR_ATOMS_BY_RESI.keys(): # Shared atoms between residue and known planar atoms for that residue type: planar_atom_list = list( set([atm.atom_name for atm in res_atoms]) & - set(mol_info.planar_atom_list_dict[res_name]) + set(PLANAR_ATOMS_BY_RESI[res_name]) ) planar_atom_list_dict[(chain_iid, res_id)] = planar_atom_list @@ -327,11 +295,11 @@ def find_planar_positions( for atm in res_atoms: # Can pre-filter protein planar atoms: - if atm.is_protein and (atm.atom_name in mol_info.aa_planar_atoms): + if atm.is_protein and (atm.atom_name in AA_PLANAR_ATOMS): candidate_planar_atm_names.append(atm.atom_name) candidate_planar_atm_coords.append(atm.coord) # Can pre-filter nucleic acid planar atoms: - elif (atm.is_rna or atm.is_dna) and (atm.atom_name in mol_info.na_planar_atoms): + elif (atm.is_rna or atm.is_dna) and (atm.atom_name in NA_PLANAR_ATOMS): candidate_planar_atm_names.append(atm.atom_name) candidate_planar_atm_coords.append(atm.coord) # Otherwise, consider all atoms for plane fitting: @@ -389,144 +357,25 @@ def find_planar_positions( return planar_atom_list_dict - - -def calculate_hb_counts( - atom_array: AtomArray, - token_level_data: dict, - mol_info: NucMolInfo, - cutoff_HA_dist: float = 2.5, - cutoff_DA_dist: float = 3.9, - ): - """ - Compute hbond counts between residues and return an (L, L, 3) - numpy array where the last dimension encodes: - 0 -> both backbone (BB-BB) - 1 -> one backbone, one sidechain (BB-SC) - 2 -> both sidechain (SC-SC) - """ - dtstr = datetime.now().strftime("%Y%m%d%H%M%S") - pdb_path = f"{dtstr}_{np.random.randint(10000)}.pdb" - - atom_array, nan_mask, chain_map = save_atomarray_to_pdb(atom_array, pdb_path) - subprocess.call( - [ - "/projects/ml/hbplus", - 
"-h", - str(cutoff_HA_dist), - "-d", - str(cutoff_DA_dist), - pdb_path, - pdb_path, - ], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - ) - - - num_resis_total = len(token_level_data["token_id_list"]) - - hbond_count = np.zeros((num_resis_total, num_resis_total, 3), dtype=np.int32) - - hb2_path = pdb_path.replace("pdb", "hb2") - with open(hb2_path, "r") as hb2_f: - for i, line in enumerate(hb2_f): - if i < 8: - continue - if len(line) < 28: - continue - - # Initialize donor/acceptor sidechain/backbone flags: - # then replace with True if valid for summation - d_is_sc = False - d_is_bb = False - a_is_sc = False - a_is_bb = False - - d_chain_iid = chain_map[line[0]] - d_resi = int(line[1:5].strip()) - d_resn = line[6:9].strip() - d_atom_name = line[9:13].strip() - - d_mask = ( - (atom_array.atom_name == d_atom_name) - & (atom_array.res_name == d_resn) - & (atom_array.res_id == d_resi) - & (atom_array.chain_iid == d_chain_iid) - ) - d_atm = atom_array[d_mask] - d_idx = d_atm.token_id - - # Handle standard polymer residues for donor atom: - if d_resn in mol_info.atom_region_dict.keys(): - d_is_sc = (d_atom_name in mol_info.atom_region_dict[d_resn]['sc']) - d_is_bb = (d_atom_name in mol_info.atom_region_dict[d_resn]['bb']) - else: - # If non-polymer, define any ligand HBonding atom as backbone: - if d_mask.sum() > 0: - d_is_bb = atom_array[d_mask][0].is_ligand - - a_chain_iid = chain_map[line[14]] - a_resi = int(line[15:19].strip()) - a_resn = line[20:23].strip() - a_atom_name = line[23:27].strip() - - a_mask = ( - (atom_array.atom_name == a_atom_name) - & (atom_array.res_name == a_resn) - & (atom_array.res_id == a_resi) - & (atom_array.chain_iid == a_chain_iid) - ) - a_atm = atom_array[a_mask] - a_idx = a_atm.token_id - - # Handle standard polymer residues for acceptor atom: - if a_resn in mol_info.atom_region_dict.keys(): - a_is_sc = (a_atom_name in mol_info.atom_region_dict[a_resn]['sc']) - a_is_bb = (a_atom_name in 
mol_info.atom_region_dict[a_resn]['bb']) - else: - # If non-polymer, define any ligand HBonding atom as backbone: - if a_mask.sum() > 0: - a_is_bb = atom_array[a_mask][0].is_ligand - - # 0 -> both backbone (BB-BB) - hbond_count[a_idx, d_idx, 0] += (a_is_bb * d_is_bb) - hbond_count[d_idx, a_idx, 0] += (d_is_bb * a_is_bb) - - # 1 -> one backbone, one sidechain (BB-SC) - hbond_count[a_idx, d_idx, 1] += (a_is_bb * d_is_sc) | (a_is_sc * d_is_bb) - hbond_count[d_idx, a_idx, 1] += (d_is_bb * a_is_sc) | (d_is_sc * a_is_bb) - - # 2 -> both sidechain (SC-SC) - hbond_count[a_idx, d_idx, 2] += (a_is_sc * d_is_sc) - hbond_count[d_idx, a_idx, 2] += (d_is_sc * a_is_sc) - - os.remove(pdb_path) - os.remove(hb2_path) - - return hbond_count - - - - def make_coord_list(atom_array: AtomArray, residue_list: list[str], chain_list: list[str], atom_list: list[str], ) -> list[list[str]]: - """ - Given an atom array, and lists of residues, chains, and atom names, - return a list of coordinates for the specified atoms in the specified residues and chains. - If the atom is not found, return [NaN, NaN, NaN] for that atom. - The the three input lists must be of the same length, and the output list will have the same length as well. - Args: - atom_array: BioTite atom_array object - residue_list: list of residue names to consider - chain_list: list of chain identifiers to consider - atom_list: list of atom names to extract coordinates for - Returns: - coord_list: list of lists of coordinates for the specified atoms + """Extract per-residue representative coordinates from an AtomArray. + All three input lists must have the same length. Missing atoms are + filled with ``[NaN, NaN, NaN]``. + + Args: + atom_array: Biotite AtomArray to query. + residue_list: Residue IDs (one per token). + chain_list: Chain identifiers (one per token). + atom_list: Atom names to extract (use ``"atomized"`` to take the + first atom of the residue). 
+ + Returns: + List of ``[x, y, z]`` coordinate lists, same length as input. """ coord_list = [] for res_id, chain_id, atom_name in zip(residue_list, chain_list, atom_list): @@ -565,31 +414,49 @@ def get_token_level_metadata( *, NA_only: bool = False, planar_only: bool = True, + seq_cutoff = 2, + gap_length = 200 ) -> dict: - """Lightweight token-level metadata. + """Build lightweight token-level metadata (no coordinate geometry). - This intentionally avoids expensive coordinate-derived computations - (e.g., planar plane-fitting and geometry coordinate extraction). + Sufficient for SS reconstruction, loop labeling from ``bp_partners``, + and inference-time SS specification parsing. For geometry keys + (``xyz_planar``, ``frame_xyz``, ``M_i``), follow up with + :func:`add_token_level_geometry_data`. - It is sufficient for: - - SS reconstruction / loop labeling from ``bp_partners`` - - inference-time SS specification parsing + Args: + atom_array: Structure to analyse. + mol_info: Molecular-info constants. + NA_only: If True, restrict filter_mask to nucleic-acid tokens. + planar_only: If True, restrict filter_mask to tokens with planar + sidechains. + seq_cutoff: Sequence-distance threshold for the ``seq_neighbors`` + boolean mask. + gap_length: Artificial gap inserted between chains for relative + sequence position computation. - If you later need geometry keys (``xyz_planar``, ``frame_xyz``, ``M_i``), - call :func:`add_token_level_geometry_data`. + Returns: + Dict with keys: ``token_starts``, ``token_index``, ``is_na``, + ``is_planar``, ``chain_list``, ``chain_iid_list``, ``resi_list``, + ``resn_list``, ``token_id_list``, ``resi2index``, ``len_s``, + ``seq_neighbors``, ``na_inds``, ``na_tensor_inds``, + ``filter_mask``, ``rep_atom_list``, ``S_start_atom_list``, + ``S_stop_atom_list``, ``include_geometry`` (False). 
""" - token_starts = get_token_starts(atom_array) + # Use residue starts (not token starts) so atomized atoms within one residue + # map to a single NA-SS position. + token_starts = struc.get_residue_starts(atom_array) token_level_array = atom_array[token_starts] token_index = np.arange(len(token_starts)) # molecule type flags - seq_data = _get_sequence_encoding_data() - is_protein = np.isin(token_level_array.res_name, seq_data["aa_like_res_names"]) - is_rna = np.isin(token_level_array.res_name, seq_data["rna_like_res_names"]) - is_dna = np.isin(token_level_array.res_name, seq_data["dna_like_res_names"]) - del seq_data + # Instantiate encoding locally to avoid retaining large arrays at module scope. + sequence_encoding = AF3SequenceEncoding() + is_protein = np.isin(token_level_array.res_name, sequence_encoding.all_res_names[sequence_encoding.is_aa_like]) + is_rna = np.isin(token_level_array.res_name, sequence_encoding.all_res_names[sequence_encoding.is_rna_like]) + is_dna = np.isin(token_level_array.res_name, sequence_encoding.all_res_names[sequence_encoding.is_dna_like]) is_na_arr = (is_dna | is_rna).astype(bool) @@ -613,8 +480,8 @@ def get_token_level_metadata( res_name_list.append(atm.res_name) token_id_list.append(str(atm.token_id)) - if atm.is_polymer and (atm.res_name in mol_info.has_planar_sc.keys()): - sc_planarity_list.append(bool(mol_info.has_planar_sc[atm.res_name])) + if atm.is_polymer and (atm.res_name in HAS_PLANAR_SC.keys()): + sc_planarity_list.append(bool(HAS_PLANAR_SC[atm.res_name])) else: sc_planarity_list.append(False) @@ -658,16 +525,16 @@ def get_token_level_metadata( # relative sequence positions w/ chain gaps rel_pos_list: list[int] = [] current_chain = "" - chn_bias = -mol_info.gap_length + chn_bias = -gap_length for r, c in zip(resi_list, chain_iid_list): if c != current_chain: - chn_bias += mol_info.gap_length + chn_bias += gap_length current_chain = c rel_pos_list.append(int(r + chn_bias)) rel_pos = np.asarray(rel_pos_list, 
dtype=np.int64) seq_neighbors = ( - np.abs(rel_pos[:, None] - rel_pos[None, :]) <= int(mol_info.seq_cutoff) + np.abs(rel_pos[:, None] - rel_pos[None, :]) <= int(seq_cutoff) ) na_inds = np.nonzero(is_na_arr)[0].tolist() @@ -717,13 +584,24 @@ def add_token_level_geometry_data( NA_only: bool = False, planar_only: bool = True, ) -> dict: - """Augment a metadata-only token_level_data dict with geometry fields. + """Augment token-level metadata with coordinate-derived geometry fields. - Populates: - - xyz_planar, xyz_S_start, xyz_S_stop - - frame_xyz, M_i - - updates is_planar and filter_mask using coordinate-derived planarity - - sets include_geometry=True + Populates ``xyz_planar``, ``xyz_S_start``, ``xyz_S_stop``, + ``frame_xyz``, ``M_i`` and updates ``is_planar`` / ``filter_mask`` + using coordinate-derived planarity. Sets ``include_geometry=True``. + + No-ops if geometry was already computed. + + Args: + atom_array: Structure to extract coordinates from. + mol_info: Molecular-info constants. + token_level_data: Dict produced by :func:`get_token_level_metadata` + (modified in-place and returned). + NA_only: Restrict filter_mask to nucleic-acid tokens. + planar_only: Restrict filter_mask to tokens with planar sidechains. + + Returns: + The same ``token_level_data`` dict, augmented with geometry keys. 
""" if bool(token_level_data.get("include_geometry", False)): @@ -755,12 +633,12 @@ def add_token_level_geometry_data( S_start_atom_list: list[str | None] = token_level_data["S_start_atom_list"] S_stop_atom_list: list[str | None] = token_level_data["S_stop_atom_list"] - planar_atom_list_dict = find_planar_positions(atom_array, mol_info) + planar_atom_list_dict = find_planar_positions(atom_array, mol_info) # {(chain_iid, res_id): [atom_name, ...]} has_planar_sc: list[bool] = [] - xyz_planar: list[list[list[float]]] = [] - xyz_S_start: list[list[float]] = [] - xyz_S_stop: list[list[float]] = [] + xyz_planar: list[list[list[float]]] = [] # list[I] of [K_i, 3] (K_i varies per residue) + xyz_S_start: list[list[float]] = [] # list[I] of [3] + xyz_S_stop: list[list[float]] = [] # list[I] of [3] for c, r, S_start_atm, S_stop_atm in zip( chain_iid_list, @@ -806,21 +684,21 @@ def add_token_level_geometry_data( del atom_array_i # frame coordinates and backbone direction - frame_xyz = np.asarray( + frame_xyz = np.asarray( # [I, 3] representative-atom coordinates make_coord_list(atom_array, resi_list, chain_list, rep_atom_list), dtype=np.float32, ) - padded_centers = np.concatenate([frame_xyz[:1], frame_xyz, frame_xyz[-1:]], axis=0) - M_i = ( + padded_centers = np.concatenate([frame_xyz[:1], frame_xyz, frame_xyz[-1:]], axis=0) # [I+2, 3] + M_i = ( # [I, 3] smoothed backbone-direction vectors (padded_centers[1:-1] - padded_centers[:-2]) + (padded_centers[2:] - padded_centers[1:-1]) ) / 2.0 - is_planar_arr = np.asarray(has_planar_sc, dtype=bool) + is_planar_arr = np.asarray(has_planar_sc, dtype=bool) # [I] token_level_data["is_planar"] = is_planar_arr - is_na_arr = np.asarray(token_level_data["is_na"], dtype=bool) + is_na_arr = np.asarray(token_level_data["is_na"], dtype=bool) # [I] if NA_only and planar_only: filter_mask = is_na_arr & is_planar_arr elif NA_only and (not planar_only): @@ -829,7 +707,7 @@ def add_token_level_geometry_data( filter_mask = is_planar_arr.copy() else: 
filter_mask = np.ones_like(is_na_arr, dtype=bool) - token_level_data["filter_mask"] = filter_mask + token_level_data["filter_mask"] = filter_mask # [I] bool token_level_data.update( { @@ -846,7 +724,277 @@ def add_token_level_geometry_data( return token_level_data -def _compute_nucleic_ss_impl( +# --------------------------------------------------------------------------- +# Sub-calculations used by compute_nucleic_ss +# --------------------------------------------------------------------------- + + +def _compute_local_frames( + xyz_planar: list[np.ndarray], + planar_centers: np.ndarray, + M_i: np.ndarray, + *, + xyz_S_start: list | None = None, + xyz_S_stop: list | None = None, + compute_full_frame: bool = False, + eps: float = 1e-8, +) -> dict[str, np.ndarray]: + """Build per-residue local coordinate frames from planar sidechain atoms. + + The base-normal direction Z_i is always computed via PCA on the planar + atom cloud, corrected for backbone direction. When *compute_full_frame* + is True the sugar-edge vector is used to derive X_i and Y_i as well. + + Args: + xyz_planar: Per-residue planar-atom coordinates, list[I] of [K_i, 3]. + planar_centers: Sidechain planar-atom centroids, [I, 3]. + M_i: Backbone-direction vectors, [I, 3]. + xyz_S_start: Sugar-edge start coordinates, list[I] of [3]. + Required when *compute_full_frame* is True. + xyz_S_stop: Sugar-edge stop coordinates, list[I] of [3]. + Required when *compute_full_frame* is True. + compute_full_frame: If True, also compute X_i and Y_i. + eps: Small constant for numerical stability. + + Returns: + Dict with ``"Z_i"`` (always), and ``"X_i"``, ``"Y_i"`` when + *compute_full_frame* is True. Each array has shape ``[I, 3]``. 
+ """ + n_tokens = len(xyz_planar) + + # Mean-centre the planar atoms per residue + centered_points = [ # list[I] of [K_i, 3] + np.asarray(xyz_i, dtype=np.float32) - cen_i + for xyz_i, cen_i in zip(xyz_planar, planar_centers) + ] + + # PCA → eigenvectors per residue + eigenvectors = np.full((n_tokens, 3, 3), np.nan, dtype=np.float32) # [I, 3, 3] + + for i, xyz_i in enumerate(centered_points): + xyz_i = xyz_i[~np.isnan(xyz_i).any(axis=1)] + if xyz_i.shape[0] >= 3: + cov_matrix = np.einsum("ij,ik->jk", xyz_i, xyz_i) / max( # [3, 3] + xyz_i.shape[0] - 1, 1 + ) + _, eigvecs = np.linalg.eigh(cov_matrix) # [3, 3] + eigenvectors[i] = eigvecs + + # Base-normal: smallest-eigenvalue direction, corrected for backbone dir + N_i = eigenvectors[:, :, 0] # [I, 3] + N_i = N_i / (np.linalg.norm(N_i, axis=1, keepdims=True) + eps) + + Z_i = N_i * np.sum(M_i * N_i, axis=-1, keepdims=True) # [I, 3] + Z_i = Z_i / (np.linalg.norm(Z_i, axis=-1, keepdims=True) + eps) + + result: dict[str, np.ndarray] = {"Z_i": Z_i} + + if compute_full_frame: + if xyz_S_start is None or xyz_S_stop is None: + raise ValueError("xyz_S_start and xyz_S_stop are required for full frame") + + X_s_i = ( # [I, 3] sugar-edge direction + np.asarray(xyz_S_stop, dtype=np.float32) + - np.asarray(xyz_S_start, dtype=np.float32) + ) + X_s_i = X_s_i / (np.linalg.norm(X_s_i, axis=-1, keepdims=True) + eps) + + X_i = np.cross(Z_i, X_s_i) # [I, 3] + X_i = X_i / (np.linalg.norm(X_i, axis=-1, keepdims=True) + eps) + result["X_i"] = X_i + + Y_i = np.cross(X_i, Z_i) # [I, 3] + Y_i = Y_i / (np.linalg.norm(Y_i, axis=-1, keepdims=True) + eps) + result["Y_i"] = Y_i + + return result + + +def _compute_pairwise_geometry( + Z_i: np.ndarray, + frame_D_ij_vec: np.ndarray, + sc_D_ij_vec: np.ndarray, + *, + X_i: np.ndarray | None = None, + clamp: bool = True, + compute_opening: bool = False, + eps: float = 1e-8, +) -> dict[str, np.ndarray]: + """Compute pairwise base-step geometry between all residue pairs. 
+ + Derives the pairwise coordinate frame (X_ij, Y_ij, Z_ij) and the + base-pair geometry parameters: rise (H_ij), buckle (B_ij), propeller + (P_ij), and optionally opening angle (O_ij). + + Args: + Z_i: Per-residue base-normal vectors, [I, 3]. + frame_D_ij_vec: Pairwise backbone displacement vectors, [I, I, 3]. + sc_D_ij_vec: Pairwise sidechain-centroid displacement vectors, [I, I, 3]. + X_i: Per-residue local X-axis, [I, 3]. Required when + *compute_opening* is True. + clamp: Clamp cosines to [-1, 1] before ``arccos``. + compute_opening: If True, compute opening angle O_ij. + eps: Small constant for numerical stability. + + Returns: + Dict with keys ``"H_ij"`` [I, I], ``"B_ij"`` [I, I], + ``"P_ij"`` [I, I], ``"base_ori_ij"`` [I, I], + ``"X_ij"`` [I, I, 3], ``"Y_ij"`` [I, I, 3], + ``"Z_ij"`` [I, I, 3], and optionally ``"O_ij"`` [I, I]. + """ + # Orientation-selected pairwise Z-axis + Z_sum = Z_i[:, None, :] + Z_i[None, :, :] # [I, I, 3] + Z_diff = Z_i[:, None, :] - Z_i[None, :, :] # [I, I, 3] + Z_ij_oris = 0.5 * np.stack((Z_sum, Z_diff), axis=0) # [2, I, I, 3] + + base_ori_ij = ( # [I, I] 0=parallel, 1=antiparallel + np.linalg.norm(Z_ij_oris[1], axis=-1) > np.linalg.norm(Z_ij_oris[0], axis=-1) + ).astype(np.int64) + + Z_ij = np.where(base_ori_ij[..., None] == 0, Z_ij_oris[0], Z_ij_oris[1]) # [I, I, 3] + Z_ij = Z_ij / (np.linalg.norm(Z_ij, axis=-1, keepdims=True) + eps) + + # Pairwise Y (inter-residue direction) and X axes + Y_ij = frame_D_ij_vec / (np.linalg.norm(frame_D_ij_vec, axis=-1, keepdims=True) + eps) # [I, I, 3] + X_ij = np.cross(Z_ij, Y_ij) # [I, I, 3] + X_ij = X_ij / (np.linalg.norm(X_ij, axis=-1, keepdims=True) + eps) + + # Rise (H_ij) + H_ij = np.sum(sc_D_ij_vec * Z_ij, axis=-1) # [I, I] + + # Buckle (B_ij) + proj_Z_i_YZ = ( # [I, I, 3] + np.sum(Z_i[:, None, :] * Y_ij, axis=-1, keepdims=True) * Y_ij + + np.sum(Z_i[:, None, :] * Z_ij, axis=-1, keepdims=True) * Z_ij + ) + proj_Z_i_YZ_norm = proj_Z_i_YZ / (np.linalg.norm(proj_Z_i_YZ, axis=-1, 
keepdims=True) + eps) + cos_buckle = np.sum(proj_Z_i_YZ_norm * (-proj_Z_i_YZ_norm.swapaxes(0, 1)), axis=-1) # [I, I] + + # Propeller (P_ij) + proj_Z_i_ZX = ( # [I, I, 3] + np.sum(Z_i[:, None, :] * Z_ij, axis=-1, keepdims=True) * Z_ij + + np.sum(Z_i[:, None, :] * X_ij, axis=-1, keepdims=True) * X_ij + ) + proj_Z_i_ZX_norm = proj_Z_i_ZX / (np.linalg.norm(proj_Z_i_ZX, axis=-1, keepdims=True) + eps) + cos_propeller = np.sum(proj_Z_i_ZX_norm * (-proj_Z_i_ZX_norm.swapaxes(0, 1)), axis=-1) # [I, I] + + if clamp: + cos_buckle = np.clip(cos_buckle, -1.0, 1.0) + cos_propeller = np.clip(cos_propeller, -1.0, 1.0) + + B_ij = np.arccos(cos_buckle) # [I, I] + P_ij = np.arccos(cos_propeller) # [I, I] + + result: dict[str, np.ndarray] = { + "H_ij": H_ij, + "B_ij": B_ij, + "P_ij": P_ij, + "base_ori_ij": base_ori_ij, + "X_ij": X_ij, + "Y_ij": Y_ij, + "Z_ij": Z_ij, + } + + # Opening angle (O_ij) — purely diagnostic + if compute_opening: + if X_i is None: + raise ValueError("X_i is required to compute opening angle") + + proj_X_i_XY = ( # [I, I, 3] + np.sum(X_i[:, None, :] * X_ij, axis=-1, keepdims=True) * X_ij + + np.sum(X_i[:, None, :] * Y_ij, axis=-1, keepdims=True) * Y_ij + ) + proj_X_i_XY_norm = proj_X_i_XY / (np.linalg.norm(proj_X_i_XY, axis=-1, keepdims=True) + eps) + cos_opening = np.sum(proj_X_i_XY_norm * proj_X_i_XY_norm.swapaxes(0, 1), axis=-1) # [I, I] + if clamp: + cos_opening = np.clip(cos_opening, -1.0, 1.0) + result["O_ij"] = np.arccos(cos_opening) # [I, I] + + return result + + +def _compute_basepair_mask( + hbond_count: np.ndarray, + seq_neighbors: np.ndarray, + H_ij: np.ndarray, + B_ij: np.ndarray, + P_ij: np.ndarray, + mol_info, + *, + bool_only: bool = False, + eps: float = 1e-8, +) -> dict[str, np.ndarray] | np.ndarray: + """Identify base pairs by combining H-bond scores with geometry filters. + + Computes a sigmoid-based base-pair probability from weighted H-bond + counts and gates it with rise / buckle / propeller geometry limits. 
+ + Args: + hbond_count: H-bond counts, [I, I, 3] (BB-BB / BB-SC / SC-SC). + seq_neighbors: Sequence-neighbor boolean mask, [I, I]. + H_ij: Rise displacement, [I, I]. + B_ij: Buckle angle (radians), [I, I]. + P_ij: Propeller angle (radians), [I, I]. + mol_info: Molecular-info object with ``bp_summation_weights``, + ``bp_hbond_coeff``, ``min_hbonds_for_bp``, ``bp_val_cutoff``, + and ``base_geometry_limits``. + bool_only: If True, return only the boolean mask array. + eps: Small constant for numerical stability. + + Returns: + If *bool_only*: ``np.ndarray`` of shape ``(I, I)`` (bool). + Otherwise: dict with ``"basepairs_bool_ij"`` [I, I] (bool), + ``"basepairs_ij"`` [I, I] (float), and + ``"hbond_summation"`` [I, I] (float). + """ + hbond_summation = np.tensordot( # [I, I] + hbond_count.astype(np.float32), + np.asarray(mol_info.bp_summation_weights, dtype=np.float32), + axes=([2], [0]), + ) + + logits = mol_info.bp_hbond_coeff * ( # [I, I] + hbond_summation - (mol_info.min_hbonds_for_bp - 1) + ) + bp_preds = (1.0 / (1.0 + np.exp(-logits))) + eps # [I, I] + + # Geometry filters + H_ij_filter = ( # [I, I] + (H_ij >= -mol_info.base_geometry_limits["H_ij"]) + & (H_ij <= mol_info.base_geometry_limits["H_ij"]) + ) + B_ij_filter = ( # [I, I] + (B_ij <= mol_info.base_geometry_limits["B_ij"]) + | (B_ij >= math.pi - mol_info.base_geometry_limits["B_ij"]) + ) + P_ij_filter = ( # [I, I] + (P_ij <= mol_info.base_geometry_limits["P_ij"]) + | (P_ij >= math.pi - mol_info.base_geometry_limits["P_ij"]) + ) + bp_geom_filter = H_ij_filter & B_ij_filter & P_ij_filter # [I, I] + + if bool_only: + basepairs_bool_ij = ( # [I, I] + (~seq_neighbors) & bp_geom_filter + & (bp_preds >= float(mol_info.bp_val_cutoff)) + ) + return basepairs_bool_ij + + basepairs_ij = ( # [I, I] + (~seq_neighbors).astype(np.float32) + * bp_geom_filter.astype(np.float32) + * bp_preds.astype(np.float32) + ) + basepairs_bool_ij = basepairs_ij >= mol_info.bp_val_cutoff # [I, I] + + return { + "basepairs_bool_ij": 
basepairs_bool_ij, + "basepairs_ij": basepairs_ij, + "hbond_summation": hbond_summation, + } + + +def compute_nucleic_ss( mol_info, token_level_data, hbond_count, @@ -858,276 +1006,132 @@ def _compute_nucleic_ss_impl( return_opening_angle: bool = False, return_basepairs_only: bool = False, ): - """ - Compute nucleic secondary structure–related quantities and pairwise base params. + """Compute nucleic-acid pairwise base-pair geometry and filters. - Notes - ----- - This function is used in two modes: + Operates in two modes: - - Fast annotation mode (default): computes only what is needed to derive - ``basepairs_bool_ij`` and does *not* retain large intermediate pairwise - geometry arrays (X_ij/Y_ij/Z_ij/O_ij). - - Diagnostic mode: set ``return_pairwise_geometry=True`` (and optionally - ``return_local_params=True`` / ``return_opening_angle=True``) to also - return additional geometry arrays. + * **Fast annotation** (default / ``return_basepairs_only=True``): returns + only ``basepairs_bool_ij`` and frees intermediate arrays. + * **Diagnostic**: additionally returns local/pairwise geometry when + ``return_pairwise_geometry``, ``return_local_params``, or + ``return_opening_angle`` are set. + + Args: + mol_info: Molecular-info constants (geometry limits, H-bond weights). + token_level_data: Token-level dict with geometry (from + :func:`add_token_level_geometry_data`). + hbond_count: H-bond count array, shape ``(I_full, I_full, 3)``. + clamp_pairwise_params: Clamp cosines to [-1, 1] before ``arccos``. + eps: Small constant for numerical stability. + return_local_params: Return per-residue X/Y/Z local frames. + return_pairwise_geometry: Return pairwise X_ij/Y_ij/Z_ij arrays. + return_opening_angle: Return pairwise opening angle O_ij. + return_basepairs_only: Return only the boolean base-pair mask + (fastest path). + + Returns: + If ``return_basepairs_only``: ``np.ndarray`` of shape ``(I, I)`` + (bool) — the base-pair boolean mask. 
+ + Otherwise: dict ``{"pair_params": {...}, "local_params": {...}}`` + containing the requested geometry arrays (all shape ``(I, I)`` or + ``(I, 3)``). """ - mask_1d = np.asarray(token_level_data["filter_mask"], dtype=bool) - len_mask = int(mask_1d.sum()) - # len_full = len(mask_1d) - - # unpack 1D data from token_level_data and apply filters - M_i = np.asarray(token_level_data["M_i"], dtype=np.float32)[mask_1d] - frame_xyz = np.asarray(token_level_data["frame_xyz"], dtype=np.float32)[mask_1d] - - is_na = np.asarray(token_level_data["is_na"], dtype=bool)[mask_1d] - - xyz_S_start = [xyz_list_i for xyz_list_i, keep_i in zip(token_level_data["xyz_S_start"], mask_1d) if keep_i] - xyz_S_stop = [xyz_list_i for xyz_list_i, keep_i in zip(token_level_data["xyz_S_stop"], mask_1d) if keep_i] - xyz_planar = [xyz_list_i for xyz_list_i, keep_i in zip(token_level_data["xyz_planar"], mask_1d) if keep_i] - + mask_1d = np.asarray(token_level_data["filter_mask"], dtype=bool) # [I_full] - # unpack 2D data from token_level_data and apply filters - hbond_count = np.asarray(hbond_count)[mask_1d, :][:, mask_1d] - seq_neighbors = np.asarray(token_level_data["seq_neighbors"], dtype=bool)[mask_1d, :][:, mask_1d] + # --- Unpack and filter token-level data ---------------------- + M_i = np.asarray(token_level_data["M_i"], dtype=np.float32)[mask_1d] # [I, 3] + frame_xyz = np.asarray(token_level_data["frame_xyz"], dtype=np.float32)[mask_1d] # [I, 3] + xyz_S_start = [v for v, k in zip(token_level_data["xyz_S_start"], mask_1d) if k] # list[I] of [3] + xyz_S_stop = [v for v, k in zip(token_level_data["xyz_S_stop"], mask_1d) if k] # list[I] of [3] + xyz_planar = [v for v, k in zip(token_level_data["xyz_planar"], mask_1d) if k] # list[I] of [K_i, 3] - # --- CALC 0: precompute displacement vectors / distances ---- - planar_centers = np.stack( - [ - np.nanmean(np.asarray(xyz_i, dtype=np.float32), axis=0) - for xyz_i in xyz_planar - ], + hbond_count = np.asarray(hbond_count)[mask_1d, :][:, mask_1d] # 
[I, I, 3] + seq_neighbors = np.asarray(token_level_data["seq_neighbors"], dtype=bool)[mask_1d, :][:, mask_1d] # [I, I] + + # --- Precompute centroids and displacement vectors ----------- + planar_centers = np.stack( # [I, 3] + [np.nanmean(np.asarray(xyz_i, dtype=np.float32), axis=0) for xyz_i in xyz_planar], axis=0, ).astype(np.float32) - - frame_D_ij_vec = frame_xyz[None, :, :] - frame_xyz[:, None, :] # [L, L, 3] - sc_D_ij_vec = planar_centers[None, :, :] - planar_centers[:, None, :] # [L, L, 3] - # D_ij = frame_D_ij_vec.norm(dim=-1) # [L, L] + frame_D_ij_vec = frame_xyz[None, :, :] - frame_xyz[:, None, :] # [I, I, 3] + sc_D_ij_vec = planar_centers[None, :, :] - planar_centers[:, None, :] # [I, I, 3] - - # --- CALC I: local base params (canonical frames) ------------ - centered_points = [ - np.asarray(xyz_i, dtype=np.float32) - cen_i - for xyz_i, cen_i in zip(xyz_planar, planar_centers) - ] - - # eigenvectors per residue: [L, 3, 3] (NaNs where invalid) - eigenvectors = np.full((len_mask, 3, 3), np.nan, dtype=np.float32) - - for i, xyz_i in enumerate(centered_points): - xyz_i = xyz_i[~np.isnan(xyz_i).any(axis=1)] - if xyz_i.shape[0] >= 3: - cov_matrix = np.einsum("ij,ik->jk", xyz_i, xyz_i) / max( - xyz_i.shape[0] - 1, 1 - ) - _, eigvecs = np.linalg.eigh(cov_matrix) - eigenvectors[i] = eigvecs - - - # base-normal (principal) direction N_i, then corrected Z_i - N_i = eigenvectors[:, :, 0] - N_i = N_i / (np.linalg.norm(N_i, axis=1, keepdims=True) + eps) - - Z_i = N_i * np.sum(M_i * N_i, axis=-1, keepdims=True) - Z_i = Z_i / (np.linalg.norm(Z_i, axis=-1, keepdims=True) + eps) - - # Only compute full local frames when requested. - # Basepair filters only need Z_i (via Z_ij) and do not require X_i/Y_i. 
- local_base_params = None - if return_local_params or return_opening_angle: - # Sugar-edge vectors X_s_i built from S_start/stop - X_s_i = ( - np.asarray(xyz_S_stop, dtype=np.float32) - - np.asarray(xyz_S_start, dtype=np.float32) - ) - X_s_i = X_s_i / (np.linalg.norm(X_s_i, axis=-1, keepdims=True) + eps) - - X_i = np.cross(Z_i, X_s_i) - X_i = X_i / (np.linalg.norm(X_i, axis=-1, keepdims=True) + eps) - - if return_local_params: - Y_i = np.cross(X_i, Z_i) - Y_i = Y_i / (np.linalg.norm(Y_i, axis=-1, keepdims=True) + eps) - local_base_params = {"X_i": X_i, "Y_i": Y_i, "Z_i": Z_i} - else: - # Opening needs X_i but not the local params dict. - local_base_params = None - - # --- CALC II: pairwise base parameters ----------------------- - - # stack mean Z-direction vectors for parallel (0) and antiparallel (1) - Z_sum = Z_i[:, None, :] + Z_i[None, :, :] - Z_diff = Z_i[:, None, :] - Z_i[None, :, :] - Z_ij_oris = 0.5 * np.stack((Z_sum, Z_diff), axis=0) # [2, L, L, 3] - - base_ori_ij = ( - np.linalg.norm(Z_ij_oris[1], axis=-1) > np.linalg.norm(Z_ij_oris[0], axis=-1) - ).astype(np.int64) # [L, L] - - Z_ij = np.where(base_ori_ij[..., None] == 0, Z_ij_oris[0], Z_ij_oris[1]) - Z_ij = Z_ij / (np.linalg.norm(Z_ij, axis=-1, keepdims=True) + eps) - - Y_ij = frame_D_ij_vec / (np.linalg.norm(frame_D_ij_vec, axis=-1, keepdims=True) + eps) - X_ij = np.cross(Z_ij, Y_ij) - X_ij = X_ij / (np.linalg.norm(X_ij, axis=-1, keepdims=True) + eps) - - # vertical displacement using sidechain centroids - H_ij = np.sum(sc_D_ij_vec * Z_ij, axis=-1) - # H_ij_vec = H_ij[..., None] * Z_ij - - # Opening (O_ij) is purely diagnostic; compute only if requested. 
- O_ij = None - if return_opening_angle: - if not (return_local_params or return_opening_angle): - raise RuntimeError("Internal error: opening angle requested without local frame") - - proj_X_i_XY = ( - np.sum(X_i[:, None, :] * X_ij, axis=-1, keepdims=True) * X_ij - + np.sum(X_i[:, None, :] * Y_ij, axis=-1, keepdims=True) * Y_ij - ) - proj_X_i_XY_norm = proj_X_i_XY / ( - np.linalg.norm(proj_X_i_XY, axis=-1, keepdims=True) + eps - ) - cos_opening = np.sum( - proj_X_i_XY_norm * proj_X_i_XY_norm.swapaxes(0, 1), - axis=-1, - ) - if clamp_pairwise_params: - cos_opening = np.clip(cos_opening, -1.0, 1.0) - O_ij = np.arccos(cos_opening) - - # Buckle (B_ij) - proj_Z_i_YZ = ( - np.sum(Z_i[:, None, :] * Y_ij, axis=-1, keepdims=True) * Y_ij - + np.sum(Z_i[:, None, :] * Z_ij, axis=-1, keepdims=True) * Z_ij + # --- CALC I: per-residue local coordinate frames ------------- + need_full_frame = return_local_params or return_opening_angle + local_frames = _compute_local_frames( + xyz_planar, + planar_centers, + M_i, + xyz_S_start=xyz_S_start if need_full_frame else None, + xyz_S_stop=xyz_S_stop if need_full_frame else None, + compute_full_frame=need_full_frame, + eps=eps, ) - proj_Z_i_YZ_norm = proj_Z_i_YZ / ( - np.linalg.norm(proj_Z_i_YZ, axis=-1, keepdims=True) + eps - ) - cos_buckle = np.sum( - proj_Z_i_YZ_norm * (-proj_Z_i_YZ_norm.swapaxes(0, 1)), - axis=-1, + Z_i = local_frames["Z_i"] # [I, 3] + X_i = local_frames.get("X_i") # [I, 3] or None + + # --- CALC II: pairwise base-step geometry -------------------- + pw_geom = _compute_pairwise_geometry( + Z_i, + frame_D_ij_vec, + sc_D_ij_vec, + X_i=X_i, + clamp=clamp_pairwise_params, + compute_opening=return_opening_angle, + eps=eps, ) - # Propeller (P_ij) - proj_Z_i_ZX = ( - np.sum(Z_i[:, None, :] * Z_ij, axis=-1, keepdims=True) * Z_ij - + np.sum(Z_i[:, None, :] * X_ij, axis=-1, keepdims=True) * X_ij - ) - proj_Z_i_ZX_norm = proj_Z_i_ZX / ( - np.linalg.norm(proj_Z_i_ZX, axis=-1, keepdims=True) + eps - ) - cos_propeller = np.sum( - 
proj_Z_i_ZX_norm * (-proj_Z_i_ZX_norm.swapaxes(0, 1)), - axis=-1, + # --- CALC III: base-pair identification ---------------------- + bp_result = _compute_basepair_mask( + hbond_count, + seq_neighbors, + pw_geom["H_ij"], + pw_geom["B_ij"], + pw_geom["P_ij"], + mol_info, + bool_only=return_basepairs_only, + eps=eps, ) - if clamp_pairwise_params: - cos_buckle = np.clip(cos_buckle, -1.0, 1.0) - cos_propeller = np.clip(cos_propeller, -1.0, 1.0) - - B_ij = np.arccos(cos_buckle) - P_ij = np.arccos(cos_propeller) - - pair_params: dict | None if return_basepairs_only: - pair_params = None - else: - pair_params = { - "H_ij": H_ij, - "B_ij": B_ij, - "P_ij": P_ij, - "base_ori_ij": base_ori_ij, + return bp_result # np.ndarray [I, I] bool + + # --- Assemble output dict ------------------------------------ + assert isinstance(bp_result, dict) + + pair_params: dict[str, np.ndarray] = { + "H_ij": pw_geom["H_ij"], + "B_ij": pw_geom["B_ij"], + "P_ij": pw_geom["P_ij"], + "base_ori_ij": pw_geom["base_ori_ij"], + "basepairs_bool_ij": bp_result["basepairs_bool_ij"], + "basepairs_ij": bp_result["basepairs_ij"], + "hbond_summation": bp_result["hbond_summation"], + } + + if return_opening_angle and "O_ij" in pw_geom: + pair_params["O_ij"] = pw_geom["O_ij"] + + if return_pairwise_geometry: + pair_params["X_ij"] = pw_geom["X_ij"] + pair_params["Y_ij"] = pw_geom["Y_ij"] + pair_params["Z_ij"] = pw_geom["Z_ij"] + + nucleic_ss_data: dict = {"pair_params": pair_params} + if return_local_params and "Y_i" in local_frames: + nucleic_ss_data["local_params"] = { + "X_i": local_frames["X_i"], + "Y_i": local_frames["Y_i"], + "Z_i": local_frames["Z_i"], } - if return_opening_angle and O_ij is not None: - pair_params["O_ij"] = O_ij - - if return_pairwise_geometry: - pair_params["X_ij"] = X_ij - pair_params["Y_ij"] = Y_ij - pair_params["Z_ij"] = Z_ij - - # --- CALC III: basepair filters / probabilities -------------- - hbond_summation = np.tensordot( - hbond_count.astype(np.float32), - 
np.asarray(mol_info.bp_summation_weights, dtype=np.float32), - axes=([2], [0]), - ) # [L, L] - - logits = mol_info.bp_hbond_coeff * ( - hbond_summation - (mol_info.min_hbonds_for_bp - 1) - ) - bp_preds = (1.0 / (1.0 + np.exp(-logits))) + eps - - # Filter Height geometry - H_ij_filter = (H_ij >= -mol_info.base_geometry_limits["H_ij"]) & ( - H_ij <= mol_info.base_geometry_limits["H_ij"] - ) - # Filter Buckle geometry - B_ij_filter = (B_ij <= mol_info.base_geometry_limits["B_ij"]) | ( - B_ij >= math.pi - mol_info.base_geometry_limits["B_ij"] - ) - # Filter Propeller geometry - P_ij_filter = (P_ij <= mol_info.base_geometry_limits["P_ij"]) | ( - P_ij >= math.pi - mol_info.base_geometry_limits["P_ij"] - ) - - bp_geom_filter = (H_ij_filter & B_ij_filter & P_ij_filter) - - if return_basepairs_only: - # Avoid allocating basepairs_ij float matrix when only the boolean mask is needed. - basepairs_bool_ij = (~seq_neighbors) & bp_geom_filter & ( - bp_preds >= float(mol_info.bp_val_cutoff) - ) - basepairs_ij = None - else: - basepairs_ij = (~seq_neighbors).astype(np.float32) * ( - bp_geom_filter.astype(np.float32) * bp_preds.astype(np.float32) - ) - basepairs_bool_ij = basepairs_ij >= mol_info.bp_val_cutoff - - if return_basepairs_only: - # Cleanup intermediate tensors to free memory - del frame_D_ij_vec, sc_D_ij_vec - del hbond_summation, bp_preds - del H_ij_filter, B_ij_filter, P_ij_filter, bp_geom_filter - - # Explicitly drop the largest pairwise arrays. 
- del X_ij, Y_ij, Z_ij - if O_ij is not None: - del O_ij - if local_base_params is not None: - del local_base_params - return basepairs_bool_ij - - assert pair_params is not None - - pair_params["basepairs_bool_ij"] = basepairs_bool_ij - pair_params["hbond_summation"] = hbond_summation - pair_params["basepairs_ij"] = basepairs_ij - - nucleic_ss_data = {"pair_params": pair_params} - if return_local_params and local_base_params is not None: - nucleic_ss_data["local_params"] = local_base_params - - # Cleanup intermediate tensors to free memory - del frame_D_ij_vec, sc_D_ij_vec - del hbond_summation, bp_preds - del H_ij_filter, B_ij_filter, P_ij_filter, bp_geom_filter - - # If not returning, explicitly drop the largest pairwise arrays. - if not return_pairwise_geometry: - del X_ij, Y_ij, Z_ij - if not return_opening_angle and O_ij is not None: - del O_ij - if not return_local_params and local_base_params is not None: - del local_base_params - return nucleic_ss_data + + def annotate_na_ss( atom_array: AtomArray, *, @@ -1137,41 +1141,47 @@ def annotate_na_ss( mol_info: Optional[NucMolInfo] = None, overwrite: bool = True, token_level_data: Optional[dict] = None, + cutoff_HA_dist: float = 3.5, + cutoff_DA_dist: float = 3.5, ) -> AtomArray: - """Annotate base-pair partners directly onto the AtomArray. + """Compute base pairs and write a ``bp_partners`` annotation onto *atom_array*. - This computes nucleic-acid base pairing similarly to - :func:`get_gt_nucleic_geom_feats` but instead of returning an integer - secondary-structure matrix, it writes an AtomArray annotation - ``bp_partners``. 
+ Uses H-bond counts and pairwise geometry filters to identify base pairs, + then stores the result as a per-atom annotation with the following + semantics: - The annotation is stored on the *full* ``atom_array`` (length N atoms), - but only nucleic-acid token-representative atoms (indices ``token_starts`` - from :func:`get_token_starts`) that are included in this call's - ``annotation_mask`` get a list value. + * ``[]`` — explicitly unpaired (loop) + * ``[token_id, ...]`` — paired partner token IDs + * ``None`` — unannotated / masked (non-NA or filtered-out tokens) - Semantics: - - ``[]`` (empty list): explicitly unpaired nucleic-acid loop - - ``[token_id, ...]``: paired nucleic-acid token(s) - - ``None``: unannotated/masked (non-NA tokens, or tokens filtered out) + Args: + atom_array: Structure to annotate (modified in-place). + NA_only: Restrict geometry filter to nucleic-acid tokens. + planar_only: Restrict geometry filter to tokens with planar + sidechains. + p_canonical_bp_filter: Probability of discarding non-canonical + base pairs (keeps only A–U, A–T, G–C). + mol_info: Molecular-info constants; created if ``None``. + overwrite: If False, merge with existing ``bp_partners``. + token_level_data: Pre-computed metadata dict; augmented with + geometry as needed. + cutoff_HA_dist: H–A distance cutoff (Å) for HBPLUS. + cutoff_DA_dist: D–A distance cutoff (Å) for HBPLUS. - Each list element is the partner token identifier (``token_id`` as int) - for the paired residue. This is sufficient to recover the partner's - token-representative atom via ``token_starts`` + token_id mapping. - - Notes - ----- - - ``token_level_data`` may be metadata-only; this function will augment it - with geometry as needed. - - If ``p_canonical_bp_filter > 0``, then with that probability we discard - any non-canonical NA basepairs (keeps only A-U, A-T, G-C). + Returns: + The same *atom_array* with the ``bp_partners`` annotation set. 
""" if mol_info is None: mol_info = NucMolInfo() - # Token representatives (0..L-1) and their corresponding atom indices (into atom_array) - token_starts = get_token_starts(atom_array) + # Residue representatives (0..L-1) and their corresponding atom indices. + # Keep this aligned with get_token_level_metadata(), which uses residue starts. + if token_level_data is not None and "token_starts" in token_level_data: + token_starts = np.asarray(token_level_data["token_starts"], dtype=int) + else: + token_starts = struc.get_residue_starts(atom_array) + residue_start_end = np.concatenate([token_starts, [atom_array.array_length()]]) token_level_array = atom_array[token_starts] # token_id is assigned token-wise and matches get_token_starts() segmentation. token_ids: list[int] = [int(t) for t in list(token_level_array.token_id)] @@ -1204,11 +1214,11 @@ def annotate_na_ss( atom_array, token_level_data, mol_info, - cutoff_HA_dist=mol_info.cutoff_HA_dist, - cutoff_DA_dist=mol_info.cutoff_DA_dist, + cutoff_HA_dist=cutoff_HA_dist, + cutoff_DA_dist=cutoff_DA_dist, ) bp_bool = np.asarray( - _compute_nucleic_ss_impl( + compute_nucleic_ss( mol_info, token_level_data, hbond_count, @@ -1222,6 +1232,18 @@ def annotate_na_ss( dtype=bool, ) + # Apply optional filters + if NA_only: + bp_bool &= is_na_full[:, None] + bp_bool &= is_na_full[None, :] + if planar_only: + n_tokens = bp_bool.shape[0] + has_planar_sc = np.asarray( + token_level_data.get("has_planar_sc", np.ones(n_tokens, dtype=bool)), dtype=bool + ) + bp_bool &= has_planar_sc[:, None] + bp_bool &= has_planar_sc[None, :] + # Optional: filter to canonical Watson-Crick basepairs only. # Sampled probabilistically to allow mixed supervision during training. 
do_canonical_filter = bool(p_canonical_bp_filter and (np.random.rand() < float(p_canonical_bp_filter))) @@ -1255,16 +1277,18 @@ def annotate_na_ss( bp_bool = np.asarray(bp_bool, dtype=bool) bp_rows, bp_cols = np.nonzero(bp_bool) - # Prepare/overwrite annotation array + # Build residue-level annotation first, then spread to all atoms in each residue. if (not overwrite) and ("bp_partners" in atom_array.get_annotation_categories()): - bp_partners_ann = atom_array.bp_partners - if len(bp_partners_ann) != len(atom_array): - raise ValueError( - "Existing bp_partners annotation has wrong length" - ) + existing_ann = atom_array.bp_partners + if len(existing_ann) != len(atom_array): + raise ValueError("Existing bp_partners annotation has wrong length") + residue_bp_partners = np.empty(len(token_starts), dtype=object) + residue_bp_partners[:] = None + for i, start in enumerate(token_starts.tolist()): + residue_bp_partners[i] = existing_ann[int(start)] else: - bp_partners_ann = np.empty(len(atom_array), dtype=object) - bp_partners_ann[:] = None + residue_bp_partners = np.empty(len(token_starts), dtype=object) + residue_bp_partners[:] = None # Explicit-loop semantics: # - Only nucleic-acid token-start atoms *within subset_idxs* get a list container. @@ -1273,9 +1297,8 @@ def annotate_na_ss( for full_i in subset_idxs.tolist(): if not bool(is_na_full[int(full_i)]): continue - atom_i = int(token_starts[int(full_i)]) - if bp_partners_ann[atom_i] is None: - bp_partners_ann[atom_i] = [] + if residue_bp_partners[int(full_i)] is None: + residue_bp_partners[int(full_i)] = [] # Populate partners using token_id ints # We only process each unordered pair once to avoid duplicates. 
@@ -1296,149 +1319,58 @@ def annotate_na_ss( if full_j < full_i: continue - atom_i = int(token_starts[full_i]) - atom_j = int(token_starts[full_j]) partner_i = int(token_ids[full_j]) partner_j = int(token_ids[full_i]) - if bp_partners_ann[atom_i] is None: - bp_partners_ann[atom_i] = [] - if bp_partners_ann[atom_j] is None: - bp_partners_ann[atom_j] = [] + if residue_bp_partners[full_i] is None: + residue_bp_partners[full_i] = [] + if residue_bp_partners[full_j] is None: + residue_bp_partners[full_j] = [] # Add if not present - if partner_i not in bp_partners_ann[atom_i]: - bp_partners_ann[atom_i].append(partner_i) - if partner_j not in bp_partners_ann[atom_j]: - bp_partners_ann[atom_j].append(partner_j) + if partner_i not in residue_bp_partners[full_i]: + residue_bp_partners[full_i].append(partner_i) + if partner_j not in residue_bp_partners[full_j]: + residue_bp_partners[full_j].append(partner_j) + + # Project residue-level annotations back to atom-level storage: + # - atomized residues: spread to all atoms in that residue + # - non-atomized residues: keep only on token-start representative atom + bp_partners_ann = np.empty(len(atom_array), dtype=object) + bp_partners_ann[:] = None + for i, start in enumerate(token_starts.tolist()): + stop = int(residue_start_end[i + 1]) + value = residue_bp_partners[i] + if value is None: + continue + # A residue is treated as atomized if any atom in the residue carries atomize=True. 
+ if "atomize" in atom_array.get_annotation_categories(): + residue_is_atomized = bool(np.any(np.asarray(atom_array.atomize[int(start):stop], dtype=bool))) + else: + residue_is_atomized = False + if residue_is_atomized: + for atom_idx in range(int(start), stop): + bp_partners_ann[atom_idx] = list(value) + else: + bp_partners_ann[int(start)] = list(value) atom_array.set_annotation("bp_partners", bp_partners_ann) return atom_array -def bp_partner_to_ss_matrix( - atom_array: AtomArray, - *, - feature_info: Optional[dict] = None, - mol_info: Optional[NucMolInfo] = None, - NA_only: Optional[bool] = False, - planar_only: Optional[bool] = False, - include_loops: bool = True, - token_level_data: Optional[dict] = None, -) -> np.ndarray: - """Reconstruct an integer NA secondary-structure matrix from annotations. - - Requires that ``atom_array`` has a ``bp_partners`` annotation created by - :func:`annotate_na_ss`. - - Returns - ------- - ss_matrix : np.ndarray - Shape (L, L) with values from ``feature_info``. - - Loop semantics: - - Only nucleic-acid tokens can be loops. - - Only tokens with an explicit empty list ``bp_partners == []`` are loops. - Unannotated tokens (``bp_partners is None``) remain masked. 
- """ - - if mol_info is None: - mol_info = NucMolInfo() - - if feature_info is None: - feature_info = DEFAULT_NA_SS_FEATURE_INFO - - if "bp_partners" not in atom_array.get_annotation_categories(): - raise ValueError( - "atom_array is missing bp_partners annotation; run annotate_na_ss() first" - ) - - token_starts = get_token_starts(atom_array) - token_level_array = atom_array[token_starts] - token_ids_int: list[int] = [int(t) for t in list(token_level_array.token_id)] - token_id_to_index_int = {int(tid): i for i, tid in enumerate(token_ids_int)} - L = len(token_starts) - - ss_matrix = feature_info["NA_SS_MASK"] * np.ones((L, L), dtype=np.int64) - - if token_level_data is None: - token_level_data = get_token_level_metadata( - atom_array, - mol_info, - # NA_only=NA_only, - # planar_only=planar_only, - ) - - mask_1d = np.asarray(token_level_data["filter_mask"], dtype=bool) - subset_idxs = np.nonzero(mask_1d)[0] - subset_set = set(int(x) for x in subset_idxs.tolist()) - is_na = np.asarray(token_level_data["is_na"], dtype=bool) - subset_na_idxs = subset_idxs[np.asarray(is_na[subset_idxs], dtype=bool)] - subset_na_set = set(int(x) for x in subset_na_idxs.tolist()) - - # Fill base-pair edges (only within subset, and only NA-NA) - bp_partners_ann = atom_array.bp_partners - for i in subset_idxs.tolist(): - if not bool(is_na[int(i)]): - continue - atom_i = int(token_starts[int(i)]) - partners = bp_partners_ann[atom_i] - if partners is None: - continue - if not isinstance(partners, (list, tuple, np.ndarray)): - continue - for partner_token_id in partners: - # Support int, numpy scalar, and legacy stringified token_id. 
- try: - partner_tid_int = int(partner_token_id) - except Exception: - partner_tid_int = None - j = token_id_to_index_int.get(partner_tid_int) if partner_tid_int is not None else None - if j is None or j == i: - continue - if int(j) not in subset_set: - continue - if not bool(is_na[int(j)]): - continue - ss_matrix[i, j] = feature_info["NA_SS_PAIR"] - ss_matrix[j, i] = feature_info["NA_SS_PAIR"] - - if not include_loops: - return ss_matrix - - # Loop labeling is explicit and NA-only: - # - only nucleic tokens can be loops - # - only tokens with an explicit empty list annotation are loops - loop_idxs_list: list[int] = [] - for i in subset_idxs.tolist(): - if not bool(is_na[int(i)]): - continue - atom_i = int(token_starts[int(i)]) - partners = bp_partners_ann[atom_i] - if not isinstance(partners, (list, tuple, np.ndarray)): - continue - if len(partners) == 0: - loop_idxs_list.append(int(i)) - - loop_idxs = np.asarray(loop_idxs_list, dtype=np.int64) - if loop_idxs.size > 0: - ss_matrix[loop_idxs[:, None], subset_na_idxs[None, :]] = feature_info["NA_SS_LOOP"] - ss_matrix[subset_na_idxs[:, None], loop_idxs[None, :]] = feature_info["NA_SS_LOOP"] - - return ss_matrix - - def parse_dot_bracket(dot_bracket: str) -> tuple[list[tuple[int, int]], list[int]]: """Parse a dot-bracket string into base pairs and unpaired positions. - Supports (), [], {}, <>, and A..E / a..e bracket pairs. + Supports standard ``()``, ``[]``, ``{}``, ``<>`` and pseudoknot + brackets ``A``–``E`` / ``a``–``e``. - Returns - ------- - pairs : list of (i, j) - 0-based indices in the string for paired positions. - unpaired : list of int - 0-based indices that are '.' (unpaired). + Args: + dot_bracket: Dot-bracket notation string. + + Returns: + Tuple of ``(pairs, unpaired)`` where *pairs* is a list of 0-based + ``(i, j)`` index tuples and *unpaired* is a list of 0-based indices + corresponding to ``.`` characters. 
""" stack: dict[str, list[int]] = {} @@ -1480,19 +1412,27 @@ def annotate_na_ss_from_specification( *, overwrite: bool = True, ) -> AtomArray: - """Annotate ``bp_partners`` from an inference-time specification. + """Write ``bp_partners`` annotation from an inference-time specification. - This is the inference analogue of :func:`annotate_na_ss`, except instead - of computing base pairs from geometry/H-bonds, it interprets a user-provided - specification (dot-bracket strings and/or residue ranges/positions) and - writes the same ``bp_partners`` annotation on token-representative atoms. + Inference analogue of :func:`annotate_na_ss`: interprets user-provided + dot-bracket strings and/or residue ranges rather than computing base + pairs from geometry. - Supported spec keys (all optional): - - ``ss_dbn``: global dot-bracket string applied to the first L tokens. - - ``ss_dbn_dict``: mapping like {"A5-15": dbn_str, ...}. - - ``paired_region_list``: list of "A5-15,B1-11" entries. - - ``paired_position_list``: list of "A19,A61,A20" groups. - - ``loop_region_list``: list of "A5-10" regions forced unpaired. + Supported *specification* keys (all optional): + + * ``ss_dbn``: global dot-bracket string (applied to the first *L* tokens). + * ``ss_dbn_dict``: ``{"A5-15": dbn_str, ...}``. + * ``paired_region_list``: ``["A5-15,B1-11", ...]``. + * ``paired_position_list``: ``["A19,A61,A20", ...]``. + * ``loop_region_list``: ``["A5-10", ...]`` (forced unpaired). + + Args: + atom_array: Structure to annotate (modified in-place). + specification: Specification dict as described above. + overwrite: If False, merge with existing ``bp_partners``. + + Returns: + The same *atom_array* with the ``bp_partners`` annotation set. """ spec = specification or {} @@ -1502,11 +1442,17 @@ def annotate_na_ss_from_specification( n_tokens = len(token_starts) # Explicit loops are only meaningful for nucleic-acid tokens.
- seq_data = _get_sequence_encoding_data() - is_rna_like = np.isin(token_level_array.res_name, seq_data["rna_like_res_names"]) - is_dna_like = np.isin(token_level_array.res_name, seq_data["dna_like_res_names"]) + # Instantiate encoding locally to avoid retaining large arrays at module scope. + sequence_encoding = AF3SequenceEncoding() + is_rna_like = np.isin( + token_level_array.res_name, + sequence_encoding.all_res_names[sequence_encoding.is_rna_like], + ) + is_dna_like = np.isin( + token_level_array.res_name, + sequence_encoding.all_res_names[sequence_encoding.is_dna_like], + ) is_na_token = np.asarray(is_rna_like | is_dna_like, dtype=bool) - del seq_data # Prepare/overwrite annotation array if (not overwrite) and ("bp_partners" in atom_array.get_annotation_categories()): @@ -1703,8 +1649,18 @@ def annotate_na_ss_from_data_specification( *, overwrite: bool = True, ) -> AtomArray: - """Convenience wrapper: annotate bp partners from ``data['specification']``.""" + """Annotate ``bp_partners`` from ``data["specification"]``. + + Convenience wrapper around :func:`annotate_na_ss_from_specification`. + + Args: + data: Pipeline data dict containing ``atom_array`` and optionally + ``specification``. + overwrite: If False, merge with existing ``bp_partners``. + + Returns: + The annotated AtomArray (also stored back in *data*). + """ atom_array = data["atom_array"] spec = data.get("specification", {}) or {} return annotate_na_ss_from_specification(atom_array, spec, overwrite=overwrite) -