mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
cleaned up NA-SS conditioning code before rebase
This commit is contained in:
@@ -440,74 +440,73 @@ backbone_atoms_DNA = strip_list(backbone_atomscheme_DNA)
|
||||
|
||||
# Mapping from residue type to its backbone and sidechain atoms (for convenience)
|
||||
ATOM_REGION_BY_RESI = {
|
||||
'ALA': {'bb':('N','CA','C','O'),
|
||||
'ALA': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB')},
|
||||
'ARG': {'bb':('N','CA','C','O'),
|
||||
'ARG': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD','NE','CZ','NH1','NH2')},
|
||||
'ASN': {'bb':('N','CA','C','O'),
|
||||
'ASN': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','OD1','ND2')},
|
||||
'ASP': {'bb':('N','CA','C','O'),
|
||||
'ASP': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','OD1','OD2')},
|
||||
'CYS': {'bb':('N','CA','C','O'),
|
||||
'CYS': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','SG')},
|
||||
'GLN': {'bb':('N','CA','C','O'),
|
||||
'GLN': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD','OE1','NE2')},
|
||||
'GLU': {'bb':('N','CA','C','O'),
|
||||
'GLU': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD','OE1','OE2')},
|
||||
'GLY': {'bb':('N','CA','C','O'),
|
||||
'GLY': {'bb':('N','CA','C','O'),
|
||||
'sc':()},
|
||||
'HIS': {'bb':('N','CA','C','O'),
|
||||
'HIS': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','ND1','CD2','CE1','NE2')},
|
||||
'ILE': {'bb':('N','CA','C','O'),
|
||||
'ILE': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG1','CG2','CD1')},
|
||||
'LEU': {'bb':('N','CA','C','O'),
|
||||
'LEU': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD1','CD2')},
|
||||
'LYS': {'bb':('N','CA','C','O'),
|
||||
'LYS': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD','CE','NZ')},
|
||||
'MET': {'bb':('N','CA','C','O'),
|
||||
'MET': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','SD','CE')},
|
||||
'PHE': {'bb':('N','CA','C','O'),
|
||||
'PHE': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD1','CD2','CE1','CE2','CZ')},
|
||||
'PRO': {'bb':('N','CA','C','O'),
|
||||
'PRO': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD')},
|
||||
'SER': {'bb':('N','CA','C','O'),
|
||||
'SER': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','OG')},
|
||||
'THR': {'bb':('N','CA','C','O'),
|
||||
'THR': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','OG1','CG2')},
|
||||
'TRP': {'bb':('N','CA','C','O'),
|
||||
'TRP': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD1','CD2','CE2','CE3','NE1','CZ2','CZ3','CH2')},
|
||||
'TYR': {'bb':('N','CA','C','O'),
|
||||
'TYR': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','CD1','CD2','CE1','CE2','CZ','OH')},
|
||||
'VAL': {'bb':('N','CA','C','O'),
|
||||
'VAL': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG1','CG2')},
|
||||
'UNK': {'bb':('N','CA','C','O'),
|
||||
'UNK': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB')},
|
||||
'MAS': {'bb':('N','CA','C','O'),
|
||||
'MAS': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB')},
|
||||
'DA': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'DA': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'sc':('N9','C4','N3','C2','N1','C6','C5','N7','C8','N6')},
|
||||
'DC': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'DC': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'sc':('N1','C2','O2','N3','C4','N4','C5','C6')},
|
||||
'DG': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'DG': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'sc':('N9','C4','N3','C2','N1','C6','C5','N7','C8','N2','O6')},
|
||||
'DT': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'DT': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'sc':('N1','C2','O2','N3','C4','O4','C5','C7','C6')},
|
||||
'DX': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'DX': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
|
||||
'sc':()},
|
||||
'A': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'A': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'sc':('N1','C2','N3','C4','C5','C6','N6','N7','C8','N9')},
|
||||
'C': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'C': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'sc':('N1','C2','O2','N3','C4','N4','C5','C6')},
|
||||
'G': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'G': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'sc':('N1','C2','N2','N3','C4','C5','C6','O6','N7','C8','N9')},
|
||||
'U': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'U': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'sc':('N1','C2','O2','N3','C4','O4','C5','C6')},
|
||||
'X': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'X': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
|
||||
'sc':()},
|
||||
'HIS_D': {'bb':('N','CA','C','O'),
|
||||
'HIS_D': {'bb':('N','CA','C','O'),
|
||||
'sc':('CB','CG','NE2','CD2','CE1','ND1')},
|
||||
}
|
||||
|
||||
# Known planar sidechain atoms for each canonical residue type:
|
||||
PLANAR_ATOMS_BY_RESI = {
|
||||
'ALA': [],
|
||||
|
||||
@@ -45,7 +45,7 @@ from rfd3.transforms.util_transforms import (
|
||||
get_af3_token_representative_masks,
|
||||
)
|
||||
from rfd3.transforms.virtual_atoms import PadTokensWithVirtualAtoms
|
||||
from rfd3.transforms.na_geom import get_bp_feats_from_atom_array
|
||||
from rfd3.transforms.na_geom import na_ss_feats_from_annotation
|
||||
|
||||
from foundry.utils.ddp import RankedLogger # noqa
|
||||
|
||||
@@ -811,7 +811,7 @@ class AddAdditional2dFeaturesToFeats(Transform):
|
||||
# Need to pre-define custom constructor functions
|
||||
# to map from atomarray annotations to tensors.
|
||||
self.constructor_functions = {
|
||||
'bp_partners': get_bp_feats_from_atom_array,
|
||||
'bp_partners': na_ss_feats_from_annotation,
|
||||
}
|
||||
|
||||
def check_input(self, data) -> None:
|
||||
|
||||
@@ -12,86 +12,80 @@ from rfd3.transforms.conditioning_utils import sample_island_tokens
|
||||
from rfd3.transforms.na_geom_utils import (
|
||||
annotate_na_ss,
|
||||
annotate_na_ss_from_data_specification,
|
||||
bp_partner_to_ss_matrix,
|
||||
DEFAULT_NA_SS_FEATURE_INFO,
|
||||
)
|
||||
|
||||
from atomworks.ml.utils.token import spread_token_wise, get_token_starts
|
||||
|
||||
def get_bp_feats_from_atom_array(
|
||||
atom_array: AtomArray,
|
||||
) -> np.ndarray:
|
||||
"""Build NA-SS features from atom_array annotations, assuming 'bp_partners' is present.
|
||||
|
||||
This function reconstructs the SS matrix from the 'bp_partners' annotation on the atom_array,
|
||||
then one-hot encodes it into a 3-class matrix (mask, pair, loop).
|
||||
def na_ss_feats_from_annotation(atom_array: AtomArray,
|
||||
token_starts= None,
|
||||
n_tokens = None,
|
||||
return_as_onehot = True,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
# Fixed feature info (inferred from usage in other functions)
|
||||
feature_info = {
|
||||
'NA_SS_MASK': 0, # Unspecified
|
||||
'NA_SS_PAIR': 1, # Paired
|
||||
'NA_SS_LOOP': 2, # Loop / unpaired
|
||||
'num_classes_nucleic_ss': 3,
|
||||
}
|
||||
Takes in atom array and constucts a base pair feature matrix from annotations,
|
||||
according to to custom feature constuction + masking system.
|
||||
This featurization utilizes info from BasePairEnum to assign int values
|
||||
to paired, unpaired, and masked positions in the matrix.
|
||||
|
||||
# Check for required annotation
|
||||
if "bp_partners" not in atom_array.get_annotation_categories():
|
||||
raise ValueError("atom_array must have 'bp_partners' annotation for NA-SS feature building.")
|
||||
Args:
|
||||
* atom_array: AtomArray with bp_partners annotation at atom level
|
||||
* token_starts (optional): indices of token starts in the atom array
|
||||
* n_tokens (optional): number of tokens (length of token_starts)
|
||||
* return_as_onehot (optional): if False, return integer-encoded
|
||||
matrix instead of one-hot encoded matrix
|
||||
|
||||
# Reconstruct SS matrix from annotations
|
||||
na_ss_matrix = np.asarray(
|
||||
bp_partner_to_ss_matrix(
|
||||
atom_array,
|
||||
feature_info=feature_info,
|
||||
NA_only=False, # Include all residues (logic from other utils)
|
||||
planar_only=True, # Use planar interactions (common default)
|
||||
include_loops=True, # Include loop states
|
||||
),
|
||||
dtype=np.int64,
|
||||
)
|
||||
returns:
|
||||
* na_ss_matrix:
|
||||
If ``return_as_onehot`` is True (default):
|
||||
np.ndarray of shape (n_tokens, n_tokens, n_classes)
|
||||
with one-hot encoded values according to BasePairEnum
|
||||
|
||||
# One-hot encode the matrix
|
||||
na_ss_matrix_int = np.asarray(na_ss_matrix, dtype=np.int64)
|
||||
eye = np.eye(int(feature_info['num_classes_nucleic_ss']), dtype=np.int64)
|
||||
return eye[na_ss_matrix_int]
|
||||
If ``return_as_onehot`` is False :
|
||||
np.ndarray of shape (n_tokens, n_tokens)
|
||||
with int values according to BasePairEnum
|
||||
|
||||
|
||||
def _build_na_ss_features_from_annotations(
|
||||
atom_array: AtomArray,
|
||||
*,
|
||||
feature_info: dict,
|
||||
num_classes: int,
|
||||
NA_only: bool,
|
||||
planar_only: bool,
|
||||
is_nucleic_ss_example: bool,
|
||||
give_partial_feats: bool,
|
||||
get_feature_mask_fn,
|
||||
) -> np.ndarray:
|
||||
"""Reconstruct SS matrix from annotations, optionally mask, then one-hot."""
|
||||
na_ss_matrix = np.asarray(
|
||||
bp_partner_to_ss_matrix(
|
||||
atom_array,
|
||||
feature_info=feature_info,
|
||||
NA_only=NA_only,
|
||||
planar_only=planar_only,
|
||||
include_loops=True,
|
||||
),
|
||||
dtype=np.int64,
|
||||
)
|
||||
"""
|
||||
# Get this info from atom_array, or avoid if given
|
||||
if (token_starts is None) or (n_tokens is None):
|
||||
token_starts = get_token_starts(atom_array)
|
||||
n_tokens = len(token_starts)
|
||||
|
||||
|
||||
n_tokens = int(na_ss_matrix.shape[0])
|
||||
# Collect token inds for paired or loop positions:
|
||||
pair_inds = []
|
||||
loop_inds = []
|
||||
token_bp_partners = atom_array.get_annotation("bp_partners")[token_starts] # get bp_partners at token level
|
||||
assert len(token_bp_partners) == n_tokens, "Length of token_bp_partners should match n_tokens"
|
||||
for i, j_list in enumerate(token_bp_partners):
|
||||
if j_list is not None:
|
||||
if len(j_list) > 0:
|
||||
for j in j_list:
|
||||
pair_inds.append((i, j))
|
||||
else:
|
||||
loop_inds.append(i)
|
||||
|
||||
if give_partial_feats:
|
||||
is_shown = (
|
||||
np.asarray(get_feature_mask_fn(n_tokens), dtype=bool)
|
||||
if is_nucleic_ss_example
|
||||
else np.zeros((n_tokens,), dtype=bool)
|
||||
)
|
||||
na_ss_matrix[~is_shown, :] = feature_info["NA_SS_MASK"]
|
||||
na_ss_matrix[:, ~is_shown] = feature_info["NA_SS_MASK"]
|
||||
# The standard system for constructing meaningful base pair features:
|
||||
# 0). Initialize with values of UNSPECIFIED (0): int matrix of shape (n_tokens, n_tokens)
|
||||
na_ss_matrix = np.full((n_tokens, n_tokens), DEFAULT_NA_SS_FEATURE_INFO["NA_SS_MASK"], dtype=np.int64)
|
||||
|
||||
na_ss_matrix_int = np.asarray(na_ss_matrix, dtype=np.int64)
|
||||
eye = np.eye(int(num_classes), dtype=np.int64)
|
||||
return eye[na_ss_matrix_int]
|
||||
# 1). Fill in with values of PAIR (1) at positions that have bp_partners annotated as a non-empty list
|
||||
for pair_i, pair_j in pair_inds:
|
||||
na_ss_matrix[pair_i, pair_j] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_PAIR"]
|
||||
na_ss_matrix[pair_j, pair_i] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_PAIR"] # ensure symmetry
|
||||
|
||||
# 2). Fill in with values of LOOP (2) at positions that have bp_partners annotated as an empty list (explicitly unpaired)
|
||||
# (we make full stripes across that position's row/col to indicate that NONE of those other positions are paired )
|
||||
for loop_i in loop_inds:
|
||||
na_ss_matrix[loop_i, :] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_LOOP"]
|
||||
na_ss_matrix[:, loop_i] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_LOOP"] # ensure symmetry
|
||||
|
||||
# Optional: convert NA-SS matrix to one-hot encoding according for model input:
|
||||
if return_as_onehot:
|
||||
na_ss_matrix = np.eye(len(DEFAULT_NA_SS_FEATURE_INFO), dtype=np.int64)[na_ss_matrix]
|
||||
|
||||
return na_ss_matrix
|
||||
|
||||
|
||||
class CalculateNucleicAcidGeomFeats(Transform):
|
||||
@@ -119,11 +113,12 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
def __init__(
|
||||
self,
|
||||
is_inference,
|
||||
meta_conditioning_probabilities,
|
||||
|
||||
add_nucleic_ss_feats: bool = True,
|
||||
# Conditional sampling parameters:
|
||||
p_is_nucleic_ss_example: float = 0.3,
|
||||
p_show_partial_feats: float = 0.5,
|
||||
nucleic_ss_min_shown: float = 0.0,
|
||||
p_show_partial_feats: float = 0.7,
|
||||
# Mask control paramerers:
|
||||
nucleic_ss_min_shown: float = 0.2,
|
||||
nucleic_ss_max_shown: float = 1.0,
|
||||
n_islands_min: int = 1,
|
||||
n_islands_max: int = 6,
|
||||
@@ -136,39 +131,22 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
):
|
||||
# Critical, must always have to know how to handle
|
||||
self.is_inference = is_inference
|
||||
if not self.is_inference:
|
||||
## relevant in training
|
||||
self.sampling_prob = meta_conditioning_probabilities['calculate_NA_SS']
|
||||
else:
|
||||
## irrelevant in inference
|
||||
self.sampling_prob = 0
|
||||
# For sampling whether we add nucleic-ss features (extra t2d)
|
||||
|
||||
# relevant in training
|
||||
self.add_nucleic_ss_feats = (self.sampling_prob > 0)
|
||||
######
|
||||
|
||||
self.add_nucleic_ss_feats = add_nucleic_ss_feats
|
||||
self.p_canonical_bp_filter = p_canonical_bp_filter # enforce that bp labels are only canonical
|
||||
self.p_is_nucleic_ss_example = p_is_nucleic_ss_example
|
||||
self.p_show_partial_feats = p_show_partial_feats
|
||||
self.nucleic_ss_min_shown = nucleic_ss_min_shown
|
||||
self.nucleic_ss_max_shown = nucleic_ss_max_shown
|
||||
self.n_islands_min = n_islands_min
|
||||
self.n_islands_max = n_islands_max
|
||||
|
||||
self.p_show_partial_feats = p_show_partial_feats
|
||||
|
||||
# Filters for what can be considered a planar contact interaction
|
||||
self.NA_only = NA_only # only annotate base-like interactions for nucleic acid residues
|
||||
self.planar_only = planar_only # only consider planar atoms in sidechains for geometry calculations,
|
||||
self.p_canonical_bp_filter = p_canonical_bp_filter # probability of enforcing canonical base pair filter
|
||||
|
||||
# Inds of annotation types in the nucleic-ss features (stack of 3 matrices):
|
||||
self.feature_info = {
|
||||
'NA_SS_MASK' : 0, # Unspecified, or sm, or protein:
|
||||
'NA_SS_PAIR' : 1,
|
||||
'NA_SS_LOOP' : 2,
|
||||
'num_classes_nucleic_ss' : 3,
|
||||
}
|
||||
|
||||
|
||||
def check_input(self, data: dict[str, Any]) -> None:
|
||||
@@ -179,10 +157,7 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
|
||||
def _sample_training_flags(self) -> tuple[bool, bool]:
|
||||
"""Sample booleans controlling whether/how features are shown in training."""
|
||||
is_nucleic_ss_example = bool(
|
||||
self.add_nucleic_ss_feats
|
||||
and (np.random.rand() < self.p_is_nucleic_ss_example)
|
||||
)
|
||||
is_nucleic_ss_example = bool(np.random.rand() < self.p_is_nucleic_ss_example)
|
||||
give_partial_feats = bool(
|
||||
np.random.rand() < self.p_show_partial_feats
|
||||
)
|
||||
@@ -194,28 +169,41 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
# Calculate n_tokens (assuming one token per residue for simplicity)
|
||||
token_starts = get_token_starts(atom_array)
|
||||
token_level_array = atom_array[token_starts]
|
||||
token_ids = [int(t) for t in token_level_array.token_id]
|
||||
n_tokens = len(token_starts)
|
||||
#TODO print(" DO I NEED TO CHANGE TO TOKEN_ID???")
|
||||
# Handle the training case with ground truth and masking:
|
||||
if not self.is_inference and (np.random.rand() < self.sampling_prob):
|
||||
|
||||
# Defaults for feature visibility
|
||||
is_nucleic_ss_example = True
|
||||
give_partial_feats = False
|
||||
token_mask_to_show = np.ones(n_tokens, dtype=bool)
|
||||
|
||||
# Handle the training case with ground truth and masking
|
||||
if not self.is_inference:
|
||||
|
||||
# First, annotate as usual
|
||||
# atom_array = annotate_na_ss(atom_array, **kwargs)
|
||||
atom_array = annotate_na_ss(atom_array,
|
||||
NA_only=self.NA_only,
|
||||
planar_only=self.planar_only,
|
||||
p_canonical_bp_filter=self.p_canonical_bp_filter,
|
||||
)
|
||||
|
||||
# Sample mask on token level:
|
||||
is_nucleic_ss_example, give_partial_feats = self._sample_training_flags()
|
||||
is_ss_shown = self._sample_where_to_show_ss(n_tokens,
|
||||
is_nucleic_ss_example=is_nucleic_ss_example,
|
||||
give_partial_feats=give_partial_feats) # Mask vec for tokens where ss shown
|
||||
# Spread mask to atom level
|
||||
is_ss_shown = spread_token_wise(atom_array, is_ss_shown)
|
||||
|
||||
# Generate symmetric partner annotations at the token level for masking purposes.
|
||||
# choice for object-consistency: if already masked/undefined: be a list mapping to self-index.
|
||||
partner_sym_map = {
|
||||
i: atom_array.bp_partners[ts_i] if atom_array.bp_partners[ts_i] is not None else [i]
|
||||
for i, ts_i in enumerate(token_starts)
|
||||
}
|
||||
|
||||
# # Sample mask on token level:
|
||||
is_nucleic_ss_example, give_partial_feats = self._sample_training_flags()
|
||||
token_mask_to_show = self._sample_where_to_show_ss(
|
||||
n_tokens,
|
||||
is_nucleic_ss_example=is_nucleic_ss_example,
|
||||
give_partial_feats=give_partial_feats,
|
||||
partner_sym_map=partner_sym_map,
|
||||
) # Mask vec for tokens where ss shown
|
||||
|
||||
# Spread mask to atom level
|
||||
is_ss_shown = spread_token_wise(atom_array, token_mask_to_show)
|
||||
|
||||
# Extract the base pair annotations
|
||||
bp_partners_atom = atom_array.get_annotation("bp_partners")
|
||||
@@ -233,11 +221,7 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
- 1). Single dot-bracket string
|
||||
- 2). multiple dot bracket strings with chain/ind ranges specified
|
||||
- 3). Lists of paired indices
|
||||
|
||||
"""
|
||||
#is_nucleic_ss_example=True
|
||||
#give_partial_feats=False
|
||||
|
||||
atom_array = annotate_na_ss_from_data_specification(
|
||||
data,
|
||||
overwrite=True,
|
||||
@@ -247,28 +231,34 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
if "feats" not in data:
|
||||
data["feats"] = {}
|
||||
|
||||
# data["feats"].update(nucleic_features)
|
||||
data.setdefault("log_dict", {})
|
||||
log_dict = data["log_dict"]
|
||||
data["log_dict"] = log_dict
|
||||
data["atom_array"] = atom_array
|
||||
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _sample_where_to_show_ss(self, n_tokens: int,
|
||||
is_nucleic_ss_example: bool = True,
|
||||
give_partial_feats: bool = True,
|
||||
partner_sym_map: dict[int, list[int]] = None,
|
||||
) -> np.ndarray:
|
||||
"""Sample token-level islands indicating which SS rows/cols to reveal."""
|
||||
"""Sample token-level islands indicating which SS rows/cols to reveal.
|
||||
This custom function allows for enforcing symmetry in the shown features according
|
||||
to the partner_sym_map, which encodes which tokens are partners in the SS
|
||||
matrix and thus should be masked/unmasked together to maintain consistency.
|
||||
|
||||
"""
|
||||
# If NOT is_nucleic_ss_example, set is_shown to all False
|
||||
if not is_nucleic_ss_example:
|
||||
return np.zeros((n_tokens,), dtype=bool)
|
||||
token_mask_to_show = np.zeros((n_tokens,), dtype=bool)
|
||||
|
||||
# If NOT give_partial_feats, set is_shown to all True
|
||||
if not give_partial_feats:
|
||||
return np.ones((n_tokens,), dtype=bool)
|
||||
token_mask_to_show = np.ones((n_tokens,), dtype=bool)
|
||||
else:
|
||||
# Get numerical parameters for that govern the mask pattern
|
||||
frac_shown = (
|
||||
self.nucleic_ss_min_shown
|
||||
+ (self.nucleic_ss_max_shown - self.nucleic_ss_min_shown) * np.random.rand()
|
||||
@@ -276,15 +266,15 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
frac_shown = float(np.clip(frac_shown, 0.0, 1.0))
|
||||
max_length = int(np.ceil(frac_shown * n_tokens))
|
||||
if max_length <= 0:
|
||||
return np.zeros((n_tokens,), dtype=bool)
|
||||
|
||||
token_mask_to_show = np.zeros((n_tokens,), dtype=bool)
|
||||
island_len_min = max(1, int(frac_shown * n_tokens // max(int(self.n_islands_max), 1)))
|
||||
island_len_max = max(1, int(frac_shown * n_tokens // max(int(self.n_islands_min), 1)))
|
||||
island_len_min = min(island_len_min, n_tokens)
|
||||
island_len_max = min(island_len_max, n_tokens)
|
||||
island_len_max = max(island_len_max, island_len_min)
|
||||
|
||||
return sample_island_tokens(
|
||||
|
||||
# Sample the actual mask using the utility function:
|
||||
token_mask_to_show = sample_island_tokens(
|
||||
n_tokens,
|
||||
island_len_min=island_len_min,
|
||||
island_len_max=island_len_max,
|
||||
@@ -292,4 +282,17 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
n_islands_max=self.n_islands_max,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
# Handle symmetry by iterating through the partner_sym_map items and setting
|
||||
# `partner_mask_to_show` at partner positions to match `token_mask_to_show`
|
||||
# initialize as all shown so effect comes from hiding + logical AND condition
|
||||
partner_mask_to_show = np.ones_like(token_mask_to_show)
|
||||
for token_i, partner_ind_list in partner_sym_map.items():
|
||||
for partner_ind in partner_ind_list:
|
||||
partner_mask_to_show[partner_ind] = token_mask_to_show[token_i]
|
||||
|
||||
# Combine the original mask with the partner mask to ensure symmetry
|
||||
token_mask_to_show = token_mask_to_show & partner_mask_to_show
|
||||
|
||||
return token_mask_to_show
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user