fixed behavior for conditional sampling of NA SS condition

This commit is contained in:
afavor
2026-02-16 18:15:42 -08:00
committed by Raktim Mitra
parent d8b8d0c047
commit fc8514bb04
7 changed files with 167 additions and 142 deletions

View File

@@ -98,3 +98,8 @@ global_transform_args:
# PPI
add_ppi_hotspots: 0.75
full_binder_crop: 0.75
# Nucleic SS:
p_is_nucleic_ss_example: 0.0
p_nucleic_ss_show_partial_feats: 0.0

View File

@@ -81,7 +81,9 @@ global_transform_args:
# Used to create simple boolean flags for downstream conditioning
meta_conditioning_probabilities:
calculate_NA_SS: 1.0
# calculate_NA_SS: 1.0
p_is_nucleic_ss_example: 0.1
p_nucleic_ss_show_partial_feats: 0.7
calculate_hbonds: 0.2
calculate_rasa: 0.6

View File

@@ -6,4 +6,5 @@ defaults:
dataset:
name: pseudoknot
eval_every_n: 1
# data: ${paths.data.design_benchmark_data_dir}/pseudoknot_debug.json
data: ${paths.data.design_benchmark_data_dir}/pseudoknot.json

View File

@@ -2,6 +2,7 @@
# Training configuration for RFD3
defaults:
# - /debug/default
- /debug/default
- override /model: rfd3_base
- override /logger: null
@@ -54,6 +55,10 @@ datasets:
crop_size: 256
max_atoms_in_crop: 2560 # ~10x crop size.
global_transform_args:
meta_conditioning_probabilities:
p_is_nucleic_ss_example: 1.0
p_nucleic_ss_show_partial_feats: 0.7
p_canonical_bp_filter: 0.2
association_scheme: atom23
#add_na_pair_features: true
train_conditions:
@@ -70,19 +75,23 @@ datasets:
train:
# These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data.
pdb:
probability: 0.5
# probability: 0.5
probability: 0.75
# probability: 0.0
rna_monomer_distillation:
probability: 0.5
# probability: 0.2
probability: 0.25
# probability: 1.0
val:
pseudoknot:
dataset:
# eval_every_n: 10
eval_every_n: 1
eval_every_n: 5
trainer:
devices_per_node: 1
limit_train_batches: 10
limit_val_batches: 1
#devices_per_node: 1
#limit_train_batches: 10
#limit_val_batches: 1
validate_every_n_epochs: 5
prevalidate: true

View File

@@ -16,6 +16,14 @@ logging.basicConfig(level=logging.INFO)
global_logger = RankedLogger(__name__, rank_zero_only=False)
def _get_bp_partners_annotation(atom_array: AtomArray):
    """Fetch the ``bp_partners`` annotation stored on ``atom_array``.

    Raises:
        ValueError: if ``atom_array`` does not list ``bp_partners`` among its
            annotation categories.
    """
    # Guard first so a missing annotation surfaces as a clear error.
    if "bp_partners" not in atom_array.get_annotation_categories():
        raise ValueError("atom_array missing bp_partners annotation")
    return atom_array.bp_partners
def _safe_f1_from_sizes(intersection_n: int, pred_n: int, gt_n: int) -> float:
"""Return F1 with sensible empty-set handling."""
if pred_n == 0 and gt_n == 0:
@@ -76,19 +84,16 @@ def _extract_bp_pairs(
*,
allowed_token_ids: set[int],
) -> set[tuple[int, int]]:
"""Extract unordered base-pair edges from bp_partner annotations.
"""Extract unordered base-pair edges from bp-partner annotations.
Pairs are represented as (min_token_id, max_token_id).
"""
if "bp_partner" not in atom_array.get_annotation_categories():
raise ValueError("atom_array missing bp_partner annotation")
token_starts = get_token_starts(atom_array)
token_level_array = atom_array[token_starts]
token_ids = np.asarray(token_level_array.token_id, dtype=int)
token_id_to_pos = {int(tid): i for i, tid in enumerate(token_ids.tolist())}
bp_partner_ann = atom_array.bp_partner
bp_partner_ann = _get_bp_partners_annotation(atom_array)
pairs: set[tuple[int, int]] = set()
for pos, start_idx in enumerate(token_starts.tolist()):
@@ -126,15 +131,12 @@ def _extract_loop_and_paired_token_ids(
allowed_token_ids: set[int],
) -> tuple[set[int], set[int]]:
"""Return (loop_token_ids, paired_token_ids) within the allowed token set."""
if "bp_partner" not in atom_array.get_annotation_categories():
raise ValueError("atom_array missing bp_partner annotation")
token_starts = get_token_starts(atom_array)
token_level_array = atom_array[token_starts]
token_ids = np.asarray(token_level_array.token_id, dtype=int)
token_id_to_pos = {int(tid): i for i, tid in enumerate(token_ids.tolist())}
bp_partner_ann = atom_array.bp_partner
bp_partner_ann = _get_bp_partners_annotation(atom_array)
loop_token_ids: set[int] = set()
paired_token_ids: set[int] = set()
@@ -176,9 +178,9 @@ class NucleicSSSimilarityMetrics(Metric):
"""Secondary-structure similarity for nucleic acids.
Reports:
- `pair_f1`: F1 over the set of basepair edges implied by token-level `bp_partner`.
- `loop_f1`: F1 over explicitly-unpaired loop tokens (`bp_partner == []`).
Unannotated tokens (`bp_partner is None`) are masked.
- `pair_f1`: F1 over basepair edges from token-level bp-partner annotation.
- `loop_f1`: F1 over explicitly-unpaired loop tokens (`bp_partners == []`).
Unannotated tokens (`bp_partners is None`) are masked.
- `weighted_f1`: GT-weighted average of `pair_f1` and `loop_f1`, weighted by
the prevalence of paired vs loop tokens in the GT.
"""
@@ -217,92 +219,85 @@ class NucleicSSSimilarityMetrics(Metric):
n_valid = 0
for gt_arr, pred_arr in zip(ground_truth_atom_array_stack, predicted_atom_array_stack):
try:
if "bp_partner" not in gt_arr.get_annotation_categories():
continue
# Important: predicted AtomArrays are built from a template AtomArray.
# If that template already carries bp_partner (often GT-derived), the
# prediction can inherit it, yielding artificially perfect scores.
# Optionally recompute bp_partner from the *predicted coordinates*.
if self.annotate_predicted_fresh:
annotate_na_ss(
pred_arr,
NA_only=self.annotation_NA_only,
planar_only=self.annotation_planar_only,
overwrite=True,
p_canonical_bp_filter=0.0,
)
if "bp_partner" not in pred_arr.get_annotation_categories():
continue
# Basic sanity check: token counts should match for aligned comparisons.
gt_token_ids = _get_token_ids(gt_arr)
pred_token_ids = _get_token_ids(pred_arr)
if len(gt_token_ids) != len(pred_token_ids):
continue
# Restrict to token_ids that are valid in both arrays.
gt_allowed = _get_candidate_token_ids(
gt_arr,
restrict_to_nucleic=self.restrict_to_nucleic,
compute_for_diffused_region_only=self.compute_for_diffused_region_only,
)
pred_allowed = _get_candidate_token_ids(
pred_arr,
restrict_to_nucleic=self.restrict_to_nucleic,
compute_for_diffused_region_only=self.compute_for_diffused_region_only,
)
allowed = gt_allowed & pred_allowed
if len(allowed) == 0:
continue
gt_pairs = _extract_bp_pairs(gt_arr, allowed_token_ids=allowed)
pred_pairs = _extract_bp_pairs(pred_arr, allowed_token_ids=allowed)
gt_loop, gt_paired_tokens = _extract_loop_and_paired_token_ids(
gt_arr, allowed_token_ids=allowed
)
pred_loop, _pred_paired_tokens = _extract_loop_and_paired_token_ids(
pred_arr, allowed_token_ids=allowed
)
pair_tp = len(gt_pairs & pred_pairs)
pair_pred_n = len(pred_pairs)
pair_gt_n = len(gt_pairs)
loop_tp = len(gt_loop & pred_loop)
loop_pred_n = len(pred_loop)
loop_gt_n = len(gt_loop)
pair_f1 = _safe_f1_from_sizes(pair_tp, pair_pred_n, pair_gt_n)
loop_f1 = _safe_f1_from_sizes(loop_tp, loop_pred_n, loop_gt_n)
pair_weight = len(gt_paired_tokens)
loop_weight = len(gt_loop)
total_weight = pair_weight + loop_weight
if total_weight == 0:
weighted_f1 = 1.0
else:
weighted_f1 = float(
(pair_weight * pair_f1 + loop_weight * loop_f1) / total_weight
)
pair_f1_list.append(pair_f1)
loop_f1_list.append(loop_f1)
weighted_f1_list.append(weighted_f1)
n_valid += 1
except bdb.BdbQuit:
# Allow interactive debuggers (pdb) to cleanly abort without being swallowed.
raise
except Exception as e:
global_logger.error(f"Error computing nucleic-SS similarity: {e} | Skipping")
gt_categories = gt_arr.get_annotation_categories()
if "bp_partners" not in gt_categories:
continue
# Important: predicted AtomArrays are built from a template AtomArray.
# If that template already carries bp_partners (often GT-derived), the
# prediction can inherit it, yielding artificially perfect scores.
# Optionally recompute bp_partners from the *predicted coordinates*.
if self.annotate_predicted_fresh:
annotate_na_ss(
pred_arr,
NA_only=self.annotation_NA_only,
planar_only=self.annotation_planar_only,
overwrite=True,
p_canonical_bp_filter=0.0,
)
pred_categories = pred_arr.get_annotation_categories()
if "bp_partners" not in pred_categories:
continue
# Basic sanity check: token counts should match for aligned comparisons.
gt_token_ids = _get_token_ids(gt_arr)
pred_token_ids = _get_token_ids(pred_arr)
if len(gt_token_ids) != len(pred_token_ids):
continue
# Restrict to token_ids that are valid in both arrays.
gt_allowed = _get_candidate_token_ids(
gt_arr,
restrict_to_nucleic=self.restrict_to_nucleic,
compute_for_diffused_region_only=self.compute_for_diffused_region_only,
)
pred_allowed = _get_candidate_token_ids(
pred_arr,
restrict_to_nucleic=self.restrict_to_nucleic,
compute_for_diffused_region_only=self.compute_for_diffused_region_only,
)
allowed = gt_allowed & pred_allowed
if len(allowed) == 0:
continue
gt_pairs = _extract_bp_pairs(gt_arr, allowed_token_ids=allowed)
pred_pairs = _extract_bp_pairs(pred_arr, allowed_token_ids=allowed)
gt_loop, gt_paired_tokens = _extract_loop_and_paired_token_ids(
gt_arr, allowed_token_ids=allowed
)
pred_loop, _pred_paired_tokens = _extract_loop_and_paired_token_ids(
pred_arr, allowed_token_ids=allowed
)
pair_tp = len(gt_pairs & pred_pairs)
pair_pred_n = len(pred_pairs)
pair_gt_n = len(gt_pairs)
loop_tp = len(gt_loop & pred_loop)
loop_pred_n = len(pred_loop)
loop_gt_n = len(gt_loop)
pair_f1 = _safe_f1_from_sizes(pair_tp, pair_pred_n, pair_gt_n)
loop_f1 = _safe_f1_from_sizes(loop_tp, loop_pred_n, loop_gt_n)
pair_weight = len(gt_paired_tokens)
loop_weight = len(gt_loop)
total_weight = pair_weight + loop_weight
if total_weight == 0:
weighted_f1 = 1.0
else:
weighted_f1 = float(
(pair_weight * pair_f1 + loop_weight * loop_f1) / total_weight
)
pair_f1_list.append(pair_f1)
loop_f1_list.append(loop_f1)
weighted_f1_list.append(weighted_f1)
n_valid += 1
if n_valid == 0:
return {}

View File

@@ -100,12 +100,12 @@ class CalculateNucleicAcidGeomFeats(Transform):
Training:
- Computes geometry/H-bond-based base pairs and writes them onto the AtomArray
via the ``bp_partner`` annotation (annotation-first), then reconstructs the
via the ``bp_partners`` annotation (annotation-first), then reconstructs the
matrix (and optionally masks parts of it) before one-hot encoding.
Inference:
- Interprets user-provided secondary-structure specifications, writes the same
``bp_partner`` annotation, then follows the same matrix + one-hot path.
``bp_partners`` annotation, then follows the same matrix + one-hot path.
Note: helical-parameter features are not implemented/used in this refactored path.
"""
@@ -113,16 +113,14 @@ class CalculateNucleicAcidGeomFeats(Transform):
def __init__(
self,
is_inference,
add_nucleic_ss_feats: bool = True,
# Conditional sampling parameters:
p_is_nucleic_ss_example: float = 0.3,
p_show_partial_feats: float = 0.7,
# Conditional sampling parameters all stored in this dict:
meta_conditioning_probabilities: dict[str, float] = None,
# Mask control parameters:
nucleic_ss_min_shown: float = 0.2,
nucleic_ss_max_shown: float = 1.0,
n_islands_min: int = 1,
n_islands_max: int = 6,
p_canonical_bp_filter: float = 0.0,
# USE_RF2AA_NAMES: bool = False,
NA_only: bool = False,
@@ -131,21 +129,27 @@ class CalculateNucleicAcidGeomFeats(Transform):
):
# Critical: must always be set so we know whether to take the training or inference path
self.is_inference = is_inference
self.meta_conditioning_probabilities = meta_conditioning_probabilities or {}
# Control whether we show some nucleic SS or default to full 2D mask
self.p_is_nucleic_ss_example = self.meta_conditioning_probabilities.get("p_is_nucleic_ss_example", 0.0)
self.add_nucleic_ss_feats = add_nucleic_ss_feats
self.p_canonical_bp_filter = p_canonical_bp_filter # enforce that bp labels are only canonical
self.p_is_nucleic_ss_example = p_is_nucleic_ss_example
self.p_show_partial_feats = p_show_partial_feats
# Control whether we define full SS or just part of it (only applies if is NA SS example)
self.p_show_partial_feats = self.meta_conditioning_probabilities.get("p_nucleic_ss_show_partial_feats", 0.0)
# Some frac of time default to only showing canonical base pairs
self.p_canonical_bp_filter = self.meta_conditioning_probabilities.get("p_canonical_bp_filter", 0.5)
# mask patterning control to make things resemble design scenarios
self.nucleic_ss_min_shown = nucleic_ss_min_shown
self.nucleic_ss_max_shown = nucleic_ss_max_shown
self.n_islands_min = n_islands_min
self.n_islands_max = n_islands_max
# Filters for what can be considered a planar contact interaction
self.NA_only = NA_only # only annotate base-like interactions for nucleic acid residues
self.planar_only = planar_only # only consider planar atoms in sidechains for geometry calculations,
self.p_canonical_bp_filter = p_canonical_bp_filter # probability of enforcing canonical base pair filter
@@ -168,13 +172,8 @@ class CalculateNucleicAcidGeomFeats(Transform):
# Calculate n_tokens (assuming one token per residue for simplicity)
token_starts = get_token_starts(atom_array)
token_level_array = atom_array[token_starts]
n_tokens = len(token_starts)
# Defaults for feature visibility
is_nucleic_ss_example = True
give_partial_feats = False
token_mask_to_show = np.ones(n_tokens, dtype=bool)
# token_level_array = atom_array[token_starts]
# Handle the training case with ground truth and masking
if not self.is_inference:

View File

@@ -1050,6 +1050,40 @@ def compute_nucleic_ss(
hbond_count = np.asarray(hbond_count)[mask_1d, :][:, mask_1d] # [I, I, 3]
seq_neighbors = np.asarray(token_level_data["seq_neighbors"], dtype=bool)[mask_1d, :][:, mask_1d] # [I, I]
# Nothing passed NA/planar filtering for this structure.
# Return empty outputs instead of failing downstream on np.stack([]).
if len(xyz_planar) == 0:
if return_basepairs_only:
return np.zeros((0, 0), dtype=bool)
pair_params: dict[str, np.ndarray] = {
"H_ij": np.zeros((0, 0), dtype=np.float32),
"B_ij": np.zeros((0, 0), dtype=np.float32),
"P_ij": np.zeros((0, 0), dtype=np.float32),
"base_ori_ij": np.zeros((0, 0), dtype=np.float32),
"basepairs_bool_ij": np.zeros((0, 0), dtype=bool),
"basepairs_ij": np.zeros((0, 0), dtype=np.float32),
"hbond_summation": np.zeros((0, 0), dtype=np.float32),
}
if return_opening_angle:
pair_params["O_ij"] = np.zeros((0, 0), dtype=np.float32)
if return_pairwise_geometry:
pair_params["X_ij"] = np.zeros((0, 0), dtype=np.float32)
pair_params["Y_ij"] = np.zeros((0, 0), dtype=np.float32)
pair_params["Z_ij"] = np.zeros((0, 0), dtype=np.float32)
nucleic_ss_data: dict = {"pair_params": pair_params}
if return_local_params:
nucleic_ss_data["local_params"] = {
"X_i": np.zeros((0, 3), dtype=np.float32),
"Y_i": np.zeros((0, 3), dtype=np.float32),
"Z_i": np.zeros((0, 3), dtype=np.float32),
}
return nucleic_ss_data
# --- Precompute centroids and displacement vectors -----------
planar_centers = np.stack( # [I, 3]
[np.nanmean(np.asarray(xyz_i, dtype=np.float32), axis=0) for xyz_i in xyz_planar],
@@ -1441,19 +1475,6 @@ def annotate_na_ss_from_specification(
token_ids: list[int] = [int(t) for t in list(token_level_array.token_id)]
n_tokens = len(token_starts)
# Explicit loops are only meaningful for nucleic-acid tokens.
# Instantiate encoding locally to avoid retaining large arrays at module scope.
sequence_encoding = AF3SequenceEncoding()
is_rna_like = np.isin(
token_level_array.res_name,
sequence_encoding.all_res_names[sequence_encoding.is_rna_like],
)
is_dna_like = np.isin(
token_level_array.res_name,
sequence_encoding.all_res_names[sequence_encoding.is_dna_like],
)
is_na_token = np.asarray(is_rna_like | is_dna_like, dtype=bool)
# Prepare/overwrite annotation array
if (not overwrite) and ("bp_partners" in atom_array.get_annotation_categories()):
bp_partners_ann = atom_array.bp_partners
@@ -1528,8 +1549,6 @@ def annotate_na_ss_from_specification(
return
if i == j:
return
if (not bool(is_na_token[int(i)])) or (not bool(is_na_token[int(j)])):
return
if i in loop_token_idxs or j in loop_token_idxs:
return
partners[i].add(j)
@@ -1619,9 +1638,6 @@ def annotate_na_ss_from_specification(
for i in list(loop_token_idxs):
if not (0 <= i < n_tokens):
continue
if not bool(is_na_token[int(i)]):
loop_token_idxs.discard(int(i))
continue
for j in list(partners[i]):
partners[j].discard(i)
partners[i].clear()
@@ -1630,8 +1646,6 @@ def annotate_na_ss_from_specification(
# Unspecified tokens remain unannotated (None) -> NA_SS_MASK.
for i in range(n_tokens):
atom_i = int(token_starts[i])
if not bool(is_na_token[int(i)]):
continue
if len(partners[i]) > 0:
bp_partners_ann[atom_i] = []
for j in sorted(partners[i]):