cleaned up NA-SS conditioning code before rebase

This commit is contained in:
afavor
2026-02-16 13:34:07 -08:00
committed by Raktim Mitra
parent 4478253fa9
commit d8b8d0c047
4 changed files with 927 additions and 969 deletions

View File

@@ -440,74 +440,73 @@ backbone_atoms_DNA = strip_list(backbone_atomscheme_DNA)
# Mapping from residue type to its backbone and sidechain atoms (for convenience)
ATOM_REGION_BY_RESI = {
'ALA': {'bb':('N','CA','C','O'),
'ALA': {'bb':('N','CA','C','O'),
'sc':('CB')},
'ARG': {'bb':('N','CA','C','O'),
'ARG': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD','NE','CZ','NH1','NH2')},
'ASN': {'bb':('N','CA','C','O'),
'ASN': {'bb':('N','CA','C','O'),
'sc':('CB','CG','OD1','ND2')},
'ASP': {'bb':('N','CA','C','O'),
'ASP': {'bb':('N','CA','C','O'),
'sc':('CB','CG','OD1','OD2')},
'CYS': {'bb':('N','CA','C','O'),
'CYS': {'bb':('N','CA','C','O'),
'sc':('CB','SG')},
'GLN': {'bb':('N','CA','C','O'),
'GLN': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD','OE1','NE2')},
'GLU': {'bb':('N','CA','C','O'),
'GLU': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD','OE1','OE2')},
'GLY': {'bb':('N','CA','C','O'),
'GLY': {'bb':('N','CA','C','O'),
'sc':()},
'HIS': {'bb':('N','CA','C','O'),
'HIS': {'bb':('N','CA','C','O'),
'sc':('CB','CG','ND1','CD2','CE1','NE2')},
'ILE': {'bb':('N','CA','C','O'),
'ILE': {'bb':('N','CA','C','O'),
'sc':('CB','CG1','CG2','CD1')},
'LEU': {'bb':('N','CA','C','O'),
'LEU': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD1','CD2')},
'LYS': {'bb':('N','CA','C','O'),
'LYS': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD','CE','NZ')},
'MET': {'bb':('N','CA','C','O'),
'MET': {'bb':('N','CA','C','O'),
'sc':('CB','CG','SD','CE')},
'PHE': {'bb':('N','CA','C','O'),
'PHE': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD1','CD2','CE1','CE2','CZ')},
'PRO': {'bb':('N','CA','C','O'),
'PRO': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD')},
'SER': {'bb':('N','CA','C','O'),
'SER': {'bb':('N','CA','C','O'),
'sc':('CB','OG')},
'THR': {'bb':('N','CA','C','O'),
'THR': {'bb':('N','CA','C','O'),
'sc':('CB','OG1','CG2')},
'TRP': {'bb':('N','CA','C','O'),
'TRP': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD1','CD2','CE2','CE3','NE1','CZ2','CZ3','CH2')},
'TYR': {'bb':('N','CA','C','O'),
'TYR': {'bb':('N','CA','C','O'),
'sc':('CB','CG','CD1','CD2','CE1','CE2','CZ','OH')},
'VAL': {'bb':('N','CA','C','O'),
'VAL': {'bb':('N','CA','C','O'),
'sc':('CB','CG1','CG2')},
'UNK': {'bb':('N','CA','C','O'),
'UNK': {'bb':('N','CA','C','O'),
'sc':('CB')},
'MAS': {'bb':('N','CA','C','O'),
'MAS': {'bb':('N','CA','C','O'),
'sc':('CB')},
'DA': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
'DA': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
'sc':('N9','C4','N3','C2','N1','C6','C5','N7','C8','N6')},
'DC': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
'DC': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
'sc':('N1','C2','O2','N3','C4','N4','C5','C6')},
'DG': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
'DG': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
'sc':('N9','C4','N3','C2','N1','C6','C5','N7','C8','N2','O6')},
'DT': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
'DT': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
'sc':('N1','C2','O2','N3','C4','O4','C5','C7','C6')},
'DX': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
'DX': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'"),
'sc':()},
'A': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
'A': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
'sc':('N1','C2','N3','C4','C5','C6','N6','N7','C8','N9')},
'C': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
'C': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
'sc':('N1','C2','O2','N3','C4','N4','C5','C6')},
'G': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
'G': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
'sc':('N1','C2','N2','N3','C4','C5','C6','O6','N7','C8','N9')},
'U': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
'U': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
'sc':('N1','C2','O2','N3','C4','O4','C5','C6')},
'X': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
'X': {'bb':("O4'", "C1'", "C2'",'OP1','P','OP2', "O5'", "C5'", "C4'", "C3'", "O3'", "O2'"),
'sc':()},
'HIS_D': {'bb':('N','CA','C','O'),
'HIS_D': {'bb':('N','CA','C','O'),
'sc':('CB','CG','NE2','CD2','CE1','ND1')},
}
# Known planar sidechain atoms for each canonical residue type:
PLANAR_ATOMS_BY_RESI = {
'ALA': [],

View File

@@ -45,7 +45,7 @@ from rfd3.transforms.util_transforms import (
get_af3_token_representative_masks,
)
from rfd3.transforms.virtual_atoms import PadTokensWithVirtualAtoms
from rfd3.transforms.na_geom import get_bp_feats_from_atom_array
from rfd3.transforms.na_geom import na_ss_feats_from_annotation
from foundry.utils.ddp import RankedLogger # noqa
@@ -811,7 +811,7 @@ class AddAdditional2dFeaturesToFeats(Transform):
# Need to pre-define custom constructor functions
# to map from atomarray annotations to tensors.
self.constructor_functions = {
'bp_partners': get_bp_feats_from_atom_array,
'bp_partners': na_ss_feats_from_annotation,
}
def check_input(self, data) -> None:

View File

@@ -12,86 +12,80 @@ from rfd3.transforms.conditioning_utils import sample_island_tokens
from rfd3.transforms.na_geom_utils import (
annotate_na_ss,
annotate_na_ss_from_data_specification,
bp_partner_to_ss_matrix,
DEFAULT_NA_SS_FEATURE_INFO,
)
from atomworks.ml.utils.token import spread_token_wise, get_token_starts
def get_bp_feats_from_atom_array(
atom_array: AtomArray,
) -> np.ndarray:
"""Build NA-SS features from atom_array annotations, assuming 'bp_partners' is present.
This function reconstructs the SS matrix from the 'bp_partners' annotation on the atom_array,
then one-hot encodes it into a 3-class matrix (mask, pair, loop).
def na_ss_feats_from_annotation(atom_array: AtomArray,
token_starts= None,
n_tokens = None,
return_as_onehot = True,
) -> np.ndarray:
"""
# Fixed feature info (inferred from usage in other functions)
feature_info = {
'NA_SS_MASK': 0, # Unspecified
'NA_SS_PAIR': 1, # Paired
'NA_SS_LOOP': 2, # Loop / unpaired
'num_classes_nucleic_ss': 3,
}
Takes in atom array and constucts a base pair feature matrix from annotations,
according to to custom feature constuction + masking system.
This featurization utilizes info from BasePairEnum to assign int values
to paired, unpaired, and masked positions in the matrix.
# Check for required annotation
if "bp_partners" not in atom_array.get_annotation_categories():
raise ValueError("atom_array must have 'bp_partners' annotation for NA-SS feature building.")
Args:
* atom_array: AtomArray with bp_partners annotation at atom level
* token_starts (optional): indices of token starts in the atom array
* n_tokens (optional): number of tokens (length of token_starts)
* return_as_onehot (optional): if False, return integer-encoded
matrix instead of one-hot encoded matrix
# Reconstruct SS matrix from annotations
na_ss_matrix = np.asarray(
bp_partner_to_ss_matrix(
atom_array,
feature_info=feature_info,
NA_only=False, # Include all residues (logic from other utils)
planar_only=True, # Use planar interactions (common default)
include_loops=True, # Include loop states
),
dtype=np.int64,
)
returns:
* na_ss_matrix:
If ``return_as_onehot`` is True (default):
np.ndarray of shape (n_tokens, n_tokens, n_classes)
with one-hot encoded values according to BasePairEnum
# One-hot encode the matrix
na_ss_matrix_int = np.asarray(na_ss_matrix, dtype=np.int64)
eye = np.eye(int(feature_info['num_classes_nucleic_ss']), dtype=np.int64)
return eye[na_ss_matrix_int]
If ``return_as_onehot`` is False :
np.ndarray of shape (n_tokens, n_tokens)
with int values according to BasePairEnum
def _build_na_ss_features_from_annotations(
atom_array: AtomArray,
*,
feature_info: dict,
num_classes: int,
NA_only: bool,
planar_only: bool,
is_nucleic_ss_example: bool,
give_partial_feats: bool,
get_feature_mask_fn,
) -> np.ndarray:
"""Reconstruct SS matrix from annotations, optionally mask, then one-hot."""
na_ss_matrix = np.asarray(
bp_partner_to_ss_matrix(
atom_array,
feature_info=feature_info,
NA_only=NA_only,
planar_only=planar_only,
include_loops=True,
),
dtype=np.int64,
)
"""
# Get this info from atom_array, or avoid if given
if (token_starts is None) or (n_tokens is None):
token_starts = get_token_starts(atom_array)
n_tokens = len(token_starts)
n_tokens = int(na_ss_matrix.shape[0])
# Collect token inds for paired or loop positions:
pair_inds = []
loop_inds = []
token_bp_partners = atom_array.get_annotation("bp_partners")[token_starts] # get bp_partners at token level
assert len(token_bp_partners) == n_tokens, "Length of token_bp_partners should match n_tokens"
for i, j_list in enumerate(token_bp_partners):
if j_list is not None:
if len(j_list) > 0:
for j in j_list:
pair_inds.append((i, j))
else:
loop_inds.append(i)
if give_partial_feats:
is_shown = (
np.asarray(get_feature_mask_fn(n_tokens), dtype=bool)
if is_nucleic_ss_example
else np.zeros((n_tokens,), dtype=bool)
)
na_ss_matrix[~is_shown, :] = feature_info["NA_SS_MASK"]
na_ss_matrix[:, ~is_shown] = feature_info["NA_SS_MASK"]
# The standard system for constructing meaningful base pair features:
# 0). Initialize with values of UNSPECIFIED (0): int matrix of shape (n_tokens, n_tokens)
na_ss_matrix = np.full((n_tokens, n_tokens), DEFAULT_NA_SS_FEATURE_INFO["NA_SS_MASK"], dtype=np.int64)
na_ss_matrix_int = np.asarray(na_ss_matrix, dtype=np.int64)
eye = np.eye(int(num_classes), dtype=np.int64)
return eye[na_ss_matrix_int]
# 1). Fill in with values of PAIR (1) at positions that have bp_partners annotated as a non-empty list
for pair_i, pair_j in pair_inds:
na_ss_matrix[pair_i, pair_j] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_PAIR"]
na_ss_matrix[pair_j, pair_i] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_PAIR"] # ensure symmetry
# 2). Fill in with values of LOOP (2) at positions that have bp_partners annotated as an empty list (explicitly unpaired)
# (we make full stripes across that position's row/col to indicate that NONE of those other positions are paired )
for loop_i in loop_inds:
na_ss_matrix[loop_i, :] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_LOOP"]
na_ss_matrix[:, loop_i] = DEFAULT_NA_SS_FEATURE_INFO["NA_SS_LOOP"] # ensure symmetry
# Optional: convert NA-SS matrix to one-hot encoding according for model input:
if return_as_onehot:
na_ss_matrix = np.eye(len(DEFAULT_NA_SS_FEATURE_INFO), dtype=np.int64)[na_ss_matrix]
return na_ss_matrix
class CalculateNucleicAcidGeomFeats(Transform):
@@ -119,11 +113,12 @@ class CalculateNucleicAcidGeomFeats(Transform):
def __init__(
self,
is_inference,
meta_conditioning_probabilities,
add_nucleic_ss_feats: bool = True,
# Conditional sampling parameters:
p_is_nucleic_ss_example: float = 0.3,
p_show_partial_feats: float = 0.5,
nucleic_ss_min_shown: float = 0.0,
p_show_partial_feats: float = 0.7,
# Mask control paramerers:
nucleic_ss_min_shown: float = 0.2,
nucleic_ss_max_shown: float = 1.0,
n_islands_min: int = 1,
n_islands_max: int = 6,
@@ -136,39 +131,22 @@ class CalculateNucleicAcidGeomFeats(Transform):
):
# Critical, must always have to know how to handle
self.is_inference = is_inference
if not self.is_inference:
## relevant in training
self.sampling_prob = meta_conditioning_probabilities['calculate_NA_SS']
else:
## irrelevant in inference
self.sampling_prob = 0
# For sampling whether we add nucleic-ss features (extra t2d)
# relevant in training
self.add_nucleic_ss_feats = (self.sampling_prob > 0)
######
self.add_nucleic_ss_feats = add_nucleic_ss_feats
self.p_canonical_bp_filter = p_canonical_bp_filter # enforce that bp labels are only canonical
self.p_is_nucleic_ss_example = p_is_nucleic_ss_example
self.p_show_partial_feats = p_show_partial_feats
self.nucleic_ss_min_shown = nucleic_ss_min_shown
self.nucleic_ss_max_shown = nucleic_ss_max_shown
self.n_islands_min = n_islands_min
self.n_islands_max = n_islands_max
self.p_show_partial_feats = p_show_partial_feats
# Filters for what can be considered a planar contact interaction
self.NA_only = NA_only # only annotate base-like interactions for nucleic acid residues
self.planar_only = planar_only # only consider planar atoms in sidechains for geometry calculations,
self.p_canonical_bp_filter = p_canonical_bp_filter # probability of enforcing canonical base pair filter
# Inds of annotation types in the nucleic-ss features (stack of 3 matrices):
self.feature_info = {
'NA_SS_MASK' : 0, # Unspecified, or sm, or protein:
'NA_SS_PAIR' : 1,
'NA_SS_LOOP' : 2,
'num_classes_nucleic_ss' : 3,
}
def check_input(self, data: dict[str, Any]) -> None:
@@ -179,10 +157,7 @@ class CalculateNucleicAcidGeomFeats(Transform):
def _sample_training_flags(self) -> tuple[bool, bool]:
"""Sample booleans controlling whether/how features are shown in training."""
is_nucleic_ss_example = bool(
self.add_nucleic_ss_feats
and (np.random.rand() < self.p_is_nucleic_ss_example)
)
is_nucleic_ss_example = bool(np.random.rand() < self.p_is_nucleic_ss_example)
give_partial_feats = bool(
np.random.rand() < self.p_show_partial_feats
)
@@ -194,28 +169,41 @@ class CalculateNucleicAcidGeomFeats(Transform):
# Calculate n_tokens (assuming one token per residue for simplicity)
token_starts = get_token_starts(atom_array)
token_level_array = atom_array[token_starts]
token_ids = [int(t) for t in token_level_array.token_id]
n_tokens = len(token_starts)
#TODO print(" DO I NEED TO CHANGE TO TOKEN_ID???")
# Handle the training case with ground truth and masking:
if not self.is_inference and (np.random.rand() < self.sampling_prob):
# Defaults for feature visibility
is_nucleic_ss_example = True
give_partial_feats = False
token_mask_to_show = np.ones(n_tokens, dtype=bool)
# Handle the training case with ground truth and masking
if not self.is_inference:
# First, annotate as usual
# atom_array = annotate_na_ss(atom_array, **kwargs)
atom_array = annotate_na_ss(atom_array,
NA_only=self.NA_only,
planar_only=self.planar_only,
p_canonical_bp_filter=self.p_canonical_bp_filter,
)
# Sample mask on token level:
is_nucleic_ss_example, give_partial_feats = self._sample_training_flags()
is_ss_shown = self._sample_where_to_show_ss(n_tokens,
is_nucleic_ss_example=is_nucleic_ss_example,
give_partial_feats=give_partial_feats) # Mask vec for tokens where ss shown
# Spread mask to atom level
is_ss_shown = spread_token_wise(atom_array, is_ss_shown)
# Generate symmetric partner annotations at the token level for masking purposes.
# choice for object-consistency: if already masked/undefined: be a list mapping to self-index.
partner_sym_map = {
i: atom_array.bp_partners[ts_i] if atom_array.bp_partners[ts_i] is not None else [i]
for i, ts_i in enumerate(token_starts)
}
# # Sample mask on token level:
is_nucleic_ss_example, give_partial_feats = self._sample_training_flags()
token_mask_to_show = self._sample_where_to_show_ss(
n_tokens,
is_nucleic_ss_example=is_nucleic_ss_example,
give_partial_feats=give_partial_feats,
partner_sym_map=partner_sym_map,
) # Mask vec for tokens where ss shown
# Spread mask to atom level
is_ss_shown = spread_token_wise(atom_array, token_mask_to_show)
# Extract the base pair annotations
bp_partners_atom = atom_array.get_annotation("bp_partners")
@@ -233,11 +221,7 @@ class CalculateNucleicAcidGeomFeats(Transform):
- 1). Single dot-bracket string
- 2). multiple dot bracket strings with chain/ind ranges specified
- 3). Lists of paired indices
"""
#is_nucleic_ss_example=True
#give_partial_feats=False
atom_array = annotate_na_ss_from_data_specification(
data,
overwrite=True,
@@ -247,28 +231,34 @@ class CalculateNucleicAcidGeomFeats(Transform):
if "feats" not in data:
data["feats"] = {}
# data["feats"].update(nucleic_features)
data.setdefault("log_dict", {})
log_dict = data["log_dict"]
data["log_dict"] = log_dict
data["atom_array"] = atom_array
return data
def _sample_where_to_show_ss(self, n_tokens: int,
is_nucleic_ss_example: bool = True,
give_partial_feats: bool = True,
partner_sym_map: dict[int, list[int]] = None,
) -> np.ndarray:
"""Sample token-level islands indicating which SS rows/cols to reveal."""
"""Sample token-level islands indicating which SS rows/cols to reveal.
This custom function allows for enforcing symmetry in the shown features according
to the partner_sym_map, which encodes which tokens are partners in the SS
matrix and thus should be masked/unmasked together to maintain consistency.
"""
# If NOT is_nucleic_ss_example, set is_shown to all False
if not is_nucleic_ss_example:
return np.zeros((n_tokens,), dtype=bool)
token_mask_to_show = np.zeros((n_tokens,), dtype=bool)
# If NOT give_partial_feats, set is_shown to all True
if not give_partial_feats:
return np.ones((n_tokens,), dtype=bool)
token_mask_to_show = np.ones((n_tokens,), dtype=bool)
else:
# Get numerical parameters for that govern the mask pattern
frac_shown = (
self.nucleic_ss_min_shown
+ (self.nucleic_ss_max_shown - self.nucleic_ss_min_shown) * np.random.rand()
@@ -276,15 +266,15 @@ class CalculateNucleicAcidGeomFeats(Transform):
frac_shown = float(np.clip(frac_shown, 0.0, 1.0))
max_length = int(np.ceil(frac_shown * n_tokens))
if max_length <= 0:
return np.zeros((n_tokens,), dtype=bool)
token_mask_to_show = np.zeros((n_tokens,), dtype=bool)
island_len_min = max(1, int(frac_shown * n_tokens // max(int(self.n_islands_max), 1)))
island_len_max = max(1, int(frac_shown * n_tokens // max(int(self.n_islands_min), 1)))
island_len_min = min(island_len_min, n_tokens)
island_len_max = min(island_len_max, n_tokens)
island_len_max = max(island_len_max, island_len_min)
return sample_island_tokens(
# Sample the actual mask using the utility function:
token_mask_to_show = sample_island_tokens(
n_tokens,
island_len_min=island_len_min,
island_len_max=island_len_max,
@@ -292,4 +282,17 @@ class CalculateNucleicAcidGeomFeats(Transform):
n_islands_max=self.n_islands_max,
max_length=max_length,
)
# Handle symmetry by iterating through the partner_sym_map items and setting
# `partner_mask_to_show` at partner positions to match `token_mask_to_show`
# initialize as all shown so effect comes from hiding + logical AND condition
partner_mask_to_show = np.ones_like(token_mask_to_show)
for token_i, partner_ind_list in partner_sym_map.items():
for partner_ind in partner_ind_list:
partner_mask_to_show[partner_ind] = token_mask_to_show[token_i]
# Combine the original mask with the partner mask to ensure symmetry
token_mask_to_show = token_mask_to_show & partner_mask_to_show
return token_mask_to_show

File diff suppressed because it is too large Load Diff