Add nucleic SS metrics + conditioning + pseudoknot dataset config

This commit is contained in:
afavor
2026-01-30 13:32:46 -08:00
committed by Raktim Mitra
parent 3952c976a3
commit 716c16a7ec
14 changed files with 2567 additions and 2 deletions

View File

@@ -62,6 +62,9 @@ global_transform_args:
min_ss_island_len: 1
max_ss_island_len: 10
# Nucleic acid features
add_na_pair_features: false
train_conditions:
unconditional:
frequency: 5.0

View File

@@ -43,6 +43,9 @@ dataset:
min_ss_island_len: ${datasets.global_transform_args.min_ss_island_len}
max_ss_island_len: ${datasets.global_transform_args.max_ss_island_len}
# Nucleic acid features
add_na_pair_features: ${datasets.global_transform_args.add_na_pair_features}
# Cropping
crop_size: ${datasets.crop_size}
max_atoms_in_crop: ${datasets.max_atoms_in_crop}
@@ -56,4 +59,5 @@ dataset:
# Other dataset-specific parameters
atom_1d_features: ${model.net.token_initializer.atom_1d_features}
token_1d_features: ${model.net.token_initializer.token_1d_features}
token_1d_features: ${model.net.token_initializer.token_1d_features}
token_2d_features: ${model.net.token_initializer.token_2d_features}

View File

@@ -37,4 +37,5 @@ dataset:
# Other dataset-specific parameters
atom_1d_features: ${model.net.token_initializer.atom_1d_features}
token_1d_features: ${model.net.token_initializer.token_1d_features}
token_1d_features: ${model.net.token_initializer.token_1d_features}
token_2d_features: ${model.net.token_initializer.token_2d_features}

View File

@@ -0,0 +1,9 @@
defaults:
- unconditional
- _self_
dataset:
name: pseudoknot
eval_every_n: 1
data: /home/afavor/git/RFD3/modelhub/projects/aa_design/tests/test_data/pseudoknot.json

View File

@@ -0,0 +1,87 @@
# @package _global_
# Training configuration for RFD3
defaults:
- /debug/default
- override /model: rfd3_base
- override /logger: null
- _self_
name: rfd3na-SScond
tags: [print-model]
ckpt_path: null
model:
net:
token_initializer:
token_1d_features:
ref_motif_token_type: 3
restype: 32
is_dna_token: 1
is_rna_token: 1
is_protein_token: 1
token_2d_features:
bp_partners: 3 # Unspecified, pair, loop
atom_1d_features:
ref_atom_name_chars: 256
ref_element: 128
ref_charge: 1
ref_mask: 1
ref_is_motif_atom_with_fixed_coord: 1
ref_is_motif_atom_unindexed: 1
has_zero_occupancy: 1
ref_pos: 3
# Guided features
ref_atomwise_rasa: 3
active_donor: 1
active_acceptor: 1
is_atom_level_hotspot: 1
diffusion_module:
n_recycle: 2
use_local_token_attention: True
diffusion_transformer:
n_local_tokens: 32
n_keys: 128
inference_sampler:
num_timesteps: 100
datasets:
diffusion_batch_size_train: 16
crop_size: 256
max_atoms_in_crop: 2560 # ~10x crop size.
global_transform_args:
association_scheme: atom23
add_na_pair_features: true
train_conditions:
unconditional:
frequency: 2.0
island:
frequency: 2.0
sequence_design:
frequency: 0.5
tipatom:
frequency: 5.0
ppi:
frequency: 0.0
train:
# These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data.
pdb:
probability: 1.0
# rna_monomer_distillation:
# probability: 1.0
# val:
# pseudoknot:
# dataset:
# # eval_every_n: 10
# eval_every_n: 2
trainer:
devices_per_node: 1
limit_train_batches: 10
limit_val_batches: 1
validate_every_n_epochs: 5
prevalidate: false

View File

@@ -25,6 +25,9 @@ token_initializer: # formerly known as the trunk
ref_plddt: 1
is_non_loopy: 1
# Optional 2D token feature definitions (empty by default)
token_2d_features: {}
downcast: ${model.net.diffusion_module.downcast}
atom_1d_features:
ref_atom_name_chars: 256

View File

@@ -20,3 +20,12 @@ hbond_metrics:
_target_: rfd3.metrics.hbonds_hbplus_metrics.HbondMetrics
cutoff_HA_dist: 3
cutoff_DA_distance: 3.5
nucleic_ss_similarity:
_target_: rfd3.metrics.nucleic_ss_metrics.NucleicSSSimilarityMetrics
restrict_to_nucleic: True
compute_for_diffused_region_only: False
annotate_predicted_fresh: True
annotation_NA_only: False
annotation_planar_only: True

View File

@@ -0,0 +1,314 @@
import logging
import bdb
import numpy as np
from biotite.structure import AtomArray
from atomworks.ml.utils.token import (
get_token_starts,
)
from rfd3.transforms.na_geom_utils import annotate_na_ss
from foundry.metrics.metric import Metric
from foundry.utils.ddp import RankedLogger
logging.basicConfig(level=logging.INFO)
global_logger = RankedLogger(__name__, rank_zero_only=False)
def _safe_f1_from_sizes(intersection_n: int, pred_n: int, gt_n: int) -> float:
"""Return F1 with sensible empty-set handling."""
if pred_n == 0 and gt_n == 0:
return 1.0
precision = float(intersection_n / pred_n) if pred_n > 0 else 0.0
recall = float(intersection_n / gt_n) if gt_n > 0 else 0.0
if precision + recall == 0.0:
return 0.0
return float(2.0 * precision * recall / (precision + recall))
def _get_token_ids(atom_array: AtomArray) -> np.ndarray:
token_starts = get_token_starts(atom_array)
token_level_array = atom_array[token_starts]
return np.asarray(token_level_array.token_id, dtype=int)
def _get_candidate_token_ids(
atom_array: AtomArray,
*,
restrict_to_nucleic: bool,
compute_for_diffused_region_only: bool,
) -> set[int]:
"""Return a set of token_ids to include for scoring."""
token_starts = get_token_starts(atom_array)
token_level_array = atom_array[token_starts]
token_ids = np.asarray(token_level_array.token_id, dtype=int)
token_mask = np.ones(len(token_ids), dtype=bool)
if restrict_to_nucleic:
is_rna = (
np.asarray(getattr(token_level_array, "is_rna"), dtype=bool)
if hasattr(token_level_array, "is_rna")
else np.zeros(len(token_ids), dtype=bool)
)
is_dna = (
np.asarray(getattr(token_level_array, "is_dna"), dtype=bool)
if hasattr(token_level_array, "is_dna")
else np.zeros(len(token_ids), dtype=bool)
)
token_mask &= (is_rna | is_dna) if (is_rna.any() or is_dna.any()) else token_mask
if compute_for_diffused_region_only:
if hasattr(token_level_array, "is_motif_atom"):
token_mask &= ~np.asarray(token_level_array.is_motif_atom, dtype=bool)
elif hasattr(token_level_array, "is_motif_token"):
token_mask &= ~np.asarray(token_level_array.is_motif_token, dtype=bool)
return set(int(t) for t in token_ids[token_mask].tolist())
def _extract_bp_pairs(
atom_array: AtomArray,
*,
allowed_token_ids: set[int],
) -> set[tuple[int, int]]:
"""Extract unordered base-pair edges from bp_partner annotations.
Pairs are represented as (min_token_id, max_token_id).
"""
if "bp_partner" not in atom_array.get_annotation_categories():
raise ValueError("atom_array missing bp_partner annotation")
token_starts = get_token_starts(atom_array)
token_level_array = atom_array[token_starts]
token_ids = np.asarray(token_level_array.token_id, dtype=int)
token_id_to_pos = {int(tid): i for i, tid in enumerate(token_ids.tolist())}
bp_partner_ann = atom_array.bp_partner
pairs: set[tuple[int, int]] = set()
for pos, start_idx in enumerate(token_starts.tolist()):
i_tid = int(token_ids[pos])
if i_tid not in allowed_token_ids:
continue
partners = bp_partner_ann[int(start_idx)]
if partners is None:
continue
if not isinstance(partners, (list, tuple, np.ndarray)):
continue
for partner_token_id in partners:
try:
j_tid = int(partner_token_id)
except Exception:
continue
if j_tid == i_tid or j_tid not in allowed_token_ids:
continue
if j_tid not in token_id_to_pos:
continue
a, b = (i_tid, j_tid) if i_tid < j_tid else (j_tid, i_tid)
pairs.add((a, b))
return pairs
def _extract_loop_and_paired_token_ids(
atom_array: AtomArray,
*,
allowed_token_ids: set[int],
) -> tuple[set[int], set[int]]:
"""Return (loop_token_ids, paired_token_ids) within the allowed token set."""
if "bp_partner" not in atom_array.get_annotation_categories():
raise ValueError("atom_array missing bp_partner annotation")
token_starts = get_token_starts(atom_array)
token_level_array = atom_array[token_starts]
token_ids = np.asarray(token_level_array.token_id, dtype=int)
token_id_to_pos = {int(tid): i for i, tid in enumerate(token_ids.tolist())}
bp_partner_ann = atom_array.bp_partner
loop_token_ids: set[int] = set()
paired_token_ids: set[int] = set()
for pos, start_idx in enumerate(token_starts.tolist()):
i_tid = int(token_ids[pos])
if i_tid not in allowed_token_ids:
continue
partners = bp_partner_ann[int(start_idx)]
# New semantics:
# - None => unannotated/masked (NOT a loop)
# - [] => explicitly unpaired loop
if partners is None:
continue
if not isinstance(partners, (list, tuple, np.ndarray)):
continue
if len(partners) == 0:
loop_token_ids.add(i_tid)
continue
for partner_token_id in partners:
try:
j_tid = int(partner_token_id)
except Exception:
continue
if j_tid == i_tid or j_tid not in allowed_token_ids:
continue
if j_tid not in token_id_to_pos:
continue
paired_token_ids.add(i_tid)
paired_token_ids.add(j_tid)
return loop_token_ids, paired_token_ids
class NucleicSSSimilarityMetrics(Metric):
"""Secondary-structure similarity for nucleic acids.
Reports:
- `pair_f1`: F1 over the set of basepair edges implied by token-level `bp_partner`.
- `loop_f1`: F1 over explicitly-unpaired loop tokens (`bp_partner == []`).
Unannotated tokens (`bp_partner is None`) are masked.
- `weighted_f1`: GT-weighted average of `pair_f1` and `loop_f1`, weighted by
the prevalence of paired vs loop tokens in the GT.
"""
def __init__(
self,
*,
restrict_to_nucleic: bool = True,
compute_for_diffused_region_only: bool = False,
annotate_predicted_fresh: bool = False,
annotation_NA_only: bool = False,
annotation_planar_only: bool = True,
):
super().__init__()
self.restrict_to_nucleic = restrict_to_nucleic
self.compute_for_diffused_region_only = compute_for_diffused_region_only
self.annotate_predicted_fresh = annotate_predicted_fresh
self.annotation_NA_only = annotation_NA_only
self.annotation_planar_only = annotation_planar_only
@property
def kwargs_to_compute_args(self):
return {
"ground_truth_atom_array_stack": ("ground_truth_atom_array_stack",),
"predicted_atom_array_stack": ("predicted_atom_array_stack",),
}
def compute(self, *, ground_truth_atom_array_stack, predicted_atom_array_stack):
if ground_truth_atom_array_stack is None or predicted_atom_array_stack is None:
return {}
pair_f1_list: list[float] = []
loop_f1_list: list[float] = []
weighted_f1_list: list[float] = []
n_valid = 0
for gt_arr, pred_arr in zip(ground_truth_atom_array_stack, predicted_atom_array_stack):
try:
if "bp_partner" not in gt_arr.get_annotation_categories():
continue
# Important: predicted AtomArrays are built from a template AtomArray.
# If that template already carries bp_partner (often GT-derived), the
# prediction can inherit it, yielding artificially perfect scores.
# Optionally recompute bp_partner from the *predicted coordinates*.
if self.annotate_predicted_fresh:
annotate_na_ss(
pred_arr,
NA_only=self.annotation_NA_only,
planar_only=self.annotation_planar_only,
overwrite=True,
p_canonical_bp_filter=0.0,
)
if "bp_partner" not in pred_arr.get_annotation_categories():
continue
# Basic sanity check: token counts should match for aligned comparisons.
gt_token_ids = _get_token_ids(gt_arr)
pred_token_ids = _get_token_ids(pred_arr)
if len(gt_token_ids) != len(pred_token_ids):
continue
# Restrict to token_ids that are valid in both arrays.
gt_allowed = _get_candidate_token_ids(
gt_arr,
restrict_to_nucleic=self.restrict_to_nucleic,
compute_for_diffused_region_only=self.compute_for_diffused_region_only,
)
pred_allowed = _get_candidate_token_ids(
pred_arr,
restrict_to_nucleic=self.restrict_to_nucleic,
compute_for_diffused_region_only=self.compute_for_diffused_region_only,
)
allowed = gt_allowed & pred_allowed
if len(allowed) == 0:
continue
gt_pairs = _extract_bp_pairs(gt_arr, allowed_token_ids=allowed)
pred_pairs = _extract_bp_pairs(pred_arr, allowed_token_ids=allowed)
gt_loop, gt_paired_tokens = _extract_loop_and_paired_token_ids(
gt_arr, allowed_token_ids=allowed
)
pred_loop, _pred_paired_tokens = _extract_loop_and_paired_token_ids(
pred_arr, allowed_token_ids=allowed
)
pair_tp = len(gt_pairs & pred_pairs)
pair_pred_n = len(pred_pairs)
pair_gt_n = len(gt_pairs)
loop_tp = len(gt_loop & pred_loop)
loop_pred_n = len(pred_loop)
loop_gt_n = len(gt_loop)
pair_f1 = _safe_f1_from_sizes(pair_tp, pair_pred_n, pair_gt_n)
loop_f1 = _safe_f1_from_sizes(loop_tp, loop_pred_n, loop_gt_n)
pair_weight = len(gt_paired_tokens)
loop_weight = len(gt_loop)
total_weight = pair_weight + loop_weight
if total_weight == 0:
weighted_f1 = 1.0
else:
weighted_f1 = float(
(pair_weight * pair_f1 + loop_weight * loop_f1) / total_weight
)
pair_f1_list.append(pair_f1)
loop_f1_list.append(loop_f1)
weighted_f1_list.append(weighted_f1)
n_valid += 1
except bdb.BdbQuit:
# Allow interactive debuggers (pdb) to cleanly abort without being swallowed.
raise
except Exception as e:
global_logger.error(f"Error computing nucleic-SS similarity: {e} | Skipping")
continue
if n_valid == 0:
return {}
return {
"pair_f1": float(np.mean(pair_f1_list)),
"loop_f1": float(np.mean(loop_f1_list)),
"weighted_f1": float(np.mean(weighted_f1_list)),
"n_valid_samples": int(n_valid),
}

View File

@@ -143,6 +143,38 @@ class OneDFeatureEmbedder(nn.Module):
)
)
class TwoDFeatureEmbedder(nn.Module):
"""
Embeds 2D features into a single vector.
Args:
features (dict): Dictionary of feature names and their number of channels.
output_channels (int): Output dimension of the projected embedding.
"""
def __init__(self, features, output_channels):
super().__init__()
self.features = {k: v for k, v in features.items() if exists(v)}
total_embedding_input_features = sum(self.features.values())
self.embedders = nn.ModuleDict(
{
feature: EmbeddingLayer(
n_channels, total_embedding_input_features, output_channels
)
for feature, n_channels in self.features.items()
}
)
def collapse2D(self, x, L):
return x.reshape((L, L, x.numel() // (L * L)))
def forward(self, f, collapse_length):
return sum(
tuple(
self.embedders[feature](self.collapse2D(f[feature].float(), collapse_length))
for feature, n_channels in self.features.items()
if exists(n_channels)
)
)
class SinusoidalDistEmbed(nn.Module):
"""

View File

@@ -11,6 +11,7 @@ from rfd3.model.layers.blocks import (
Downcast,
LocalAtomTransformer,
OneDFeatureEmbedder,
TwoDFeatureEmbedder,
PositionPairDistEmbedder,
RelativePositionEncodingWithIndexRemoval,
SinusoidalDistEmbed,
@@ -49,6 +50,7 @@ class TokenInitializer(nn.Module):
pairformer_block,
downcast,
token_1d_features,
token_2d_features,
atom_1d_features,
atom_transformer,
use_chunked_pll=False, # New parameter for memory optimization
@@ -62,6 +64,7 @@ class TokenInitializer(nn.Module):
self.atom_1d_embedder_1 = OneDFeatureEmbedder(atom_1d_features, c_s)
self.atom_1d_embedder_2 = OneDFeatureEmbedder(atom_1d_features, c_atom)
self.token_1d_embedder = OneDFeatureEmbedder(token_1d_features, c_s)
self.token_2d_embedder = TwoDFeatureEmbedder(token_2d_features, c_z)
self.downcast_atom = Downcast(c_atom=c_s, c_token=c_s, c_s=None, **downcast)
self.transition_post_token = Transition(c=c_s, n=2)
@@ -202,6 +205,8 @@ class TokenInitializer(nn.Module):
Z_init_II = Z_init_II + self.ref_pos_embedder_tok(
f["ref_pos"][f["is_ca"]], valid_mask
)
# Add extra token pair features
Z_init_II = Z_init_II + self.token_2d_embedder(f, I)
# Run a small transformer to provide position encodings to single.
for block in self.transformer_stack:

View File

@@ -45,6 +45,7 @@ from rfd3.transforms.util_transforms import (
get_af3_token_representative_masks,
)
from rfd3.transforms.virtual_atoms import PadTokensWithVirtualAtoms
from rfd3.transforms.na_geom import get_bp_feats_from_atom_array
from foundry.utils.ddp import RankedLogger # noqa
@@ -776,6 +777,90 @@ class AddAdditional1dFeaturesToFeats(Transform):
data = self.generate_feature(feature_name, n_dims, data, "atom")
return data
class AddAdditional2dFeaturesToFeats(Transform):
"""
Adds any net.token_initializer.token_2d_features and net.diffusion_module.diffusion_atom_encoder.atom_2d_features present in the atomarray but not in data['feats'] to data['feats']
Args:
- autofill_zeros_if_not_present_in_atomarray: self explanatory
- token_2d_features: List of single-item dictionaries, corresponding to feature_name: n_feature_dims. Should be hydra interpolated from
net.token_initializer.token_2d_features
"""
incompatible_previous_transforms = ["AddAdditional2dFeaturesToFeats"]
def __init__(
self,
token_2d_features,
autofill_zeros_if_not_present_in_atomarray=False,
association_scheme="atom14",
):
self.autofill = autofill_zeros_if_not_present_in_atomarray
self.token_2d_features = token_2d_features
self.association_scheme = association_scheme
# Need to pre-define custom constructor functions
# to map from atomarray annotations to tensors.
self.constructor_functions = {
'bp_partners': get_bp_feats_from_atom_array,
}
def check_input(self, data) -> None:
check_contains_keys(data, ["atom_array"])
check_is_instance(data, "atom_array", AtomArray)
def generate_token_feature(self, feature_name, n_dims, data):
# Don't do this if we already have the feature
if feature_name in data["feats"].keys():
return data
# For these, we need to use a constructor function mapping,
# since pair features may require custom logic/conventions.
if feature_name in self.constructor_functions.keys():
feature_array = self.constructor_functions[feature_name](data["atom_array"])
else:
raise ValueError(
f"No constructor function found for 2d feature `{feature_name}`"
)
# We can fix shape issues here:
if len(feature_array.shape) == 2 and n_dims == 1:
feature_array = feature_array.unsqueeze(1)
# ensure that feature_array is a 3d array with third dim == n_dims:
if len(feature_array.shape) != 3:
raise ValueError(
f"token 2d_feature `{feature_name}` must be a 3d array, got {len(feature_array.shape)}d."
)
if feature_array.shape[2] != n_dims:
raise ValueError(
f"token 2d_feature `{feature_name}` dimensions in atomarray ({feature_array.shape[-1]}) does not match dimension declared in config, ({n_dims})"
)
# Ensure correct shape in first two dims (I,I,...)
if feature_array.shape[0] != feature_array.shape[1]:
raise ValueError(
f"token 2d_feature `{feature_name}` first two dimensions must be equal (square matrix), got {feature_array.shape[0]} and {feature_array.shape[1]}"
)
data["feats"][feature_name] = feature_array
return data
def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Checks if the 2d_features are present in data['feats']. If not present, adds them from the atomarray.
If annotation is not present in atomarray, either autofills the feature with 0s or throws an error
"""
if "feats" not in data.keys():
data["feats"] = {}
# Only apply for features that the model is expecting:
for feature_name, n_dims in self.token_2d_features.items():
data = self.generate_token_feature(feature_name, n_dims, data)
return data
class FeaturizepLDDT(Transform):

View File

@@ -0,0 +1,285 @@
from typing import Any
import numpy as np
from functools import partial
from biotite.structure import AtomArray
from atomworks.ml.transforms._checks import (
check_atom_array_annotation,
check_contains_keys,
check_is_instance,
)
from atomworks.ml.transforms.base import Transform
from rfd3.transforms.conditioning_utils import sample_island_tokens
from rfd3.transforms.na_geom_utils import (
annotate_na_ss,
annotate_na_ss_from_data_specification,
bp_partner_to_ss_matrix,
)
from atomworks.ml.utils.token import spread_token_wise, get_token_starts
def get_bp_feats_from_atom_array(
atom_array: AtomArray,
) -> np.ndarray:
"""Build NA-SS features from atom_array annotations, assuming 'bp_partners' is present.
This function reconstructs the SS matrix from the 'bp_partners' annotation on the atom_array,
then one-hot encodes it into a 3-class matrix (mask, pair, loop).
"""
# Fixed feature info (inferred from usage in other functions)
feature_info = {
'NA_SS_MASK': 0, # Unspecified
'NA_SS_PAIR': 1, # Paired
'NA_SS_LOOP': 2, # Loop / unpaired
'num_classes_nucleic_ss': 3,
}
# Check for required annotation
if "bp_partners" not in atom_array.get_annotation_categories():
raise ValueError("atom_array must have 'bp_partners' annotation for NA-SS feature building.")
# Reconstruct SS matrix from annotations
na_ss_matrix = np.asarray(
bp_partner_to_ss_matrix(
atom_array,
feature_info=feature_info,
NA_only=False, # Include all residues (logic from other utils)
planar_only=True, # Use planar interactions (common default)
include_loops=True, # Include loop states
),
dtype=np.int64,
)
# One-hot encode the matrix
na_ss_matrix_int = np.asarray(na_ss_matrix, dtype=np.int64)
eye = np.eye(int(feature_info['num_classes_nucleic_ss']), dtype=np.int64)
return eye[na_ss_matrix_int]
def _build_na_ss_features_from_annotations(
atom_array: AtomArray,
*,
feature_info: dict,
num_classes: int,
NA_only: bool,
planar_only: bool,
is_nucleic_ss_example: bool,
give_partial_feats: bool,
get_feature_mask_fn,
) -> np.ndarray:
"""Reconstruct SS matrix from annotations, optionally mask, then one-hot."""
na_ss_matrix = np.asarray(
bp_partner_to_ss_matrix(
atom_array,
feature_info=feature_info,
NA_only=NA_only,
planar_only=planar_only,
include_loops=True,
),
dtype=np.int64,
)
n_tokens = int(na_ss_matrix.shape[0])
if give_partial_feats:
is_shown = (
np.asarray(get_feature_mask_fn(n_tokens), dtype=bool)
if is_nucleic_ss_example
else np.zeros((n_tokens,), dtype=bool)
)
na_ss_matrix[~is_shown, :] = feature_info["NA_SS_MASK"]
na_ss_matrix[:, ~is_shown] = feature_info["NA_SS_MASK"]
na_ss_matrix_int = np.asarray(na_ss_matrix, dtype=np.int64)
eye = np.eye(int(num_classes), dtype=np.int64)
return eye[na_ss_matrix_int]
class CalculateNucleicAcidGeomFeats(Transform):
"""
Transform for constructing nucleic-acid conditioning features.
This transform currently produces only nucleic-acid secondary-structure (NA-SS)
features as a 2D token-token matrix with 3 bins:
* 0: mask / unspecified
* 1: paired
* 2: loop / explicitly unpaired
Training:
- Computes geometry/H-bond-based base pairs and writes them onto the AtomArray
via the ``bp_partner`` annotation (annotation-first), then reconstructs the
matrix (and optionally masks parts of it) before one-hot encoding.
Inference:
- Interprets user-provided secondary-structure specifications, writes the same
``bp_partner`` annotation, then follows the same matrix + one-hot path.
Note: helical-parameter features are not implemented/used in this refactored path.
"""
def __init__(
self,
is_inference,
add_nucleic_ss_feats: bool = True,
p_is_nucleic_ss_example: float = 0.3,
p_show_partial_feats: float = 0.5,
nucleic_ss_min_shown: float = 0.0,
nucleic_ss_max_shown: float = 1.0,
n_islands_min: int = 1,
n_islands_max: int = 6,
p_canonical_bp_filter: float = 0.0,
# USE_RF2AA_NAMES: bool = False,
NA_only: bool = False,
planar_only : bool = True,
):
# Critical, must always have to know how to handle
self.is_inference = is_inference
# For sampling whether we add nucleic-ss features (extra t2d)
self.add_nucleic_ss_feats = add_nucleic_ss_feats
self.p_canonical_bp_filter = p_canonical_bp_filter # enforce that bp labels are only canonical
self.p_is_nucleic_ss_example = p_is_nucleic_ss_example
self.nucleic_ss_min_shown = nucleic_ss_min_shown
self.nucleic_ss_max_shown = nucleic_ss_max_shown
self.n_islands_min = n_islands_min
self.n_islands_max = n_islands_max
self.p_show_partial_feats = p_show_partial_feats
# Filters for what can be considered a planar contact interaction
self.NA_only = NA_only # only annotate base-like interactions for nucleic acid residues
self.planar_only = planar_only # only consider planar atoms in sidechains for geometry calculations,
self.p_canonical_bp_filter = p_canonical_bp_filter # probability of enforcing canonical base pair filter
# Inds of annotation types in the nucleic-ss features (stack of 3 matrices):
self.feature_info = {
'NA_SS_MASK' : 0, # Unspecified, or sm, or protein:
'NA_SS_PAIR' : 1,
'NA_SS_LOOP' : 2,
'num_classes_nucleic_ss' : 3,
}
def check_input(self, data: dict[str, Any]) -> None:
check_contains_keys(data, ["atom_array"])
check_is_instance(data, "atom_array", AtomArray)
check_atom_array_annotation(data, ["res_name"])
# maybe do later: check_atom_array_has_hydrogen(data)
def _sample_training_flags(self) -> tuple[bool, bool]:
"""Sample booleans controlling whether/how features are shown in training."""
is_nucleic_ss_example = bool(
self.add_nucleic_ss_feats
and (np.random.rand() < self.p_is_nucleic_ss_example)
)
give_partial_feats = bool(
np.random.rand() < self.p_show_partial_feats
)
return is_nucleic_ss_example, give_partial_feats
def forward(self, data: dict) -> dict:
atom_array = data["atom_array"]
# Calculate n_tokens (assuming one token per residue for simplicity)
token_starts = get_token_starts(atom_array)
token_level_array = atom_array[token_starts]
token_ids = [int(t) for t in token_level_array.token_id]
n_tokens = len(token_starts)
print(" DO I NEED TO CHANGE TO TOKEN_ID???")
# Handle the training case with ground truth and masking:
if not self.is_inference:
# First, annotate as usual
# atom_array = annotate_na_ss(atom_array, **kwargs)
atom_array = annotate_na_ss(atom_array,
NA_only=self.NA_only,
planar_only=self.planar_only,
p_canonical_bp_filter=self.p_canonical_bp_filter,
)
# Sample mask on token level:
is_nucleic_ss_example, give_partial_feats = self._sample_training_flags()
is_ss_shown = self._sample_where_to_show_ss(n_tokens,
is_nucleic_ss_example=is_nucleic_ss_example,
give_partial_feats=give_partial_feats) # Mask vec for tokens where ss shown
# Spread mask to atom level
is_ss_shown = spread_token_wise(atom_array, is_ss_shown)
# Extract the base pair annotations
bp_partners_atom = atom_array.get_annotation("bp_partners")
# Remove unshown positions from bp_partners annotation
bp_partners_atom[~is_ss_shown] = None
# Reset the annotation with newly hidden positions
atom_array.set_annotation("bp_partners", bp_partners_atom)
# Inference case: create from commandline args
else:
"""
Different cases handled:
- 1). Single dot-bracket string
- 2). multiple dot bracket strings with chain/ind ranges specified
- 3). Lists of paired indices
"""
is_nucleic_ss_example=True
give_partial_feats=False
atom_array = annotate_na_ss_from_data_specification(
data,
overwrite=True,
)
# Check feats existence and update:
if "feats" not in data:
data["feats"] = {}
# data["feats"].update(nucleic_features)
data.setdefault("log_dict", {})
log_dict = data["log_dict"]
data["log_dict"] = log_dict
data["atom_array"] = atom_array
return data
def _sample_where_to_show_ss(self, n_tokens: int,
is_nucleic_ss_example: bool = True,
give_partial_feats: bool = True,
) -> np.ndarray:
"""Sample token-level islands indicating which SS rows/cols to reveal."""
# If NOT is_nucleic_ss_example, set is_shown to all False
if not is_nucleic_ss_example:
return np.zeros((n_tokens,), dtype=bool)
# If NOT give_partial_feats, set is_shown to all True
if not give_partial_feats:
return np.ones((n_tokens,), dtype=bool)
else:
frac_shown = (
self.nucleic_ss_min_shown
+ (self.nucleic_ss_max_shown - self.nucleic_ss_min_shown) * np.random.rand()
)
frac_shown = float(np.clip(frac_shown, 0.0, 1.0))
max_length = int(np.ceil(frac_shown * n_tokens))
if max_length <= 0:
return np.zeros((n_tokens,), dtype=bool)
island_len_min = max(1, int(frac_shown * n_tokens // max(int(self.n_islands_max), 1)))
island_len_max = max(1, int(frac_shown * n_tokens // max(int(self.n_islands_min), 1)))
island_len_min = min(island_len_min, n_tokens)
island_len_max = min(island_len_max, n_tokens)
island_len_max = max(island_len_max, island_len_min)
return sample_island_tokens(
n_tokens,
island_len_min=island_len_min,
island_len_max=island_len_max,
n_islands_min=self.n_islands_min,
n_islands_max=self.n_islands_max,
max_length=max_length,
)

File diff suppressed because it is too large Load Diff

View File

@@ -72,6 +72,7 @@ from rfd3.transforms.conditioning_base import (
)
from rfd3.transforms.design_transforms import (
AddAdditional1dFeaturesToFeats,
AddAdditional2dFeaturesToFeats,
AddGroundTruthSequence,
AddIsXFeats,
AssignTypes,
@@ -84,6 +85,7 @@ from rfd3.transforms.design_transforms import (
)
from rfd3.transforms.dna_crop import ProteinDNAContactContiguousCrop
from rfd3.transforms.hbonds_hbplus import CalculateHbondsPlus
from rfd3.transforms.na_geom import CalculateNucleicAcidGeomFeats
from rfd3.transforms.ppi_transforms import (
Add1DSSFeature,
AddGlobalIsNonLoopyFeature,
@@ -350,6 +352,7 @@ def build_atom14_base_pipeline_(
center_option: str,
atom_1d_features: dict | None,
token_1d_features: dict | None,
token_2d_features: dict | None,
# PPI features
max_ppi_hotspots_frac_to_provide: float,
ppi_hotspot_max_distance: float,
@@ -357,6 +360,8 @@ def build_atom14_base_pipeline_(
max_ss_frac_to_provide: float,
min_ss_island_len: int,
max_ss_island_len: int,
# Nucleic acid features
add_na_pair_features: bool,
**_, # dump additional kwargs (e.g. msa stuff)
):
"""
@@ -437,6 +442,15 @@ def build_atom14_base_pipeline_(
),
)
)
# Add nucleic acid geometry features
if add_na_pair_features:
transforms.append(
CalculateNucleicAcidGeomFeats(
is_inference,
NA_only=False,
planar_only=True,
)
)
# Design Transforms
transforms += [
@@ -525,6 +539,11 @@ def build_atom14_base_pipeline_(
atom_1d_features=atom_1d_features,
association_scheme=association_scheme,
),
AddAdditional2dFeaturesToFeats(
autofill_zeros_if_not_present_in_atomarray=True,
token_2d_features=token_2d_features,
association_scheme=association_scheme,
),
AddAF3TokenBondFeatures(),
AddGroundTruthSequence(sequence_encoding=af3_sequence_encoding),
ConditionalRoute(
@@ -615,6 +634,7 @@ def build_atom14_base_pipeline(
kwargs.setdefault("min_ss_island_len", 0)
kwargs.setdefault("max_ss_island_len", 999)
kwargs.setdefault("max_binder_length", 999)
kwargs.setdefault("add_na_pair_features", False)
kwargs.setdefault("b_factor_min", None)
kwargs.setdefault("zero_occ_on_exposure_after_cropping", False)