diff --git a/models/rfd3/configs/datasets/design_base.yaml b/models/rfd3/configs/datasets/design_base.yaml index 02ee92a..c862900 100644 --- a/models/rfd3/configs/datasets/design_base.yaml +++ b/models/rfd3/configs/datasets/design_base.yaml @@ -98,3 +98,8 @@ global_transform_args: # PPI add_ppi_hotspots: 0.75 full_binder_crop: 0.75 + + # Nucleic SS: + p_is_nucleic_ss_example: 0.0 + p_nucleic_ss_show_partial_feats: 0.0 + \ No newline at end of file diff --git a/models/rfd3/configs/datasets/design_base_rfd3na.yaml b/models/rfd3/configs/datasets/design_base_rfd3na.yaml index 7181617..dbebee8 100644 --- a/models/rfd3/configs/datasets/design_base_rfd3na.yaml +++ b/models/rfd3/configs/datasets/design_base_rfd3na.yaml @@ -81,7 +81,9 @@ global_transform_args: # Used to create simple boolean flags for downstream conditioning meta_conditioning_probabilities: - calculate_NA_SS: 1.0 + # calculate_NA_SS: 1.0 + p_is_nucleic_ss_example: 0.1 + p_nucleic_ss_show_partial_feats: 0.7 calculate_hbonds: 0.2 calculate_rasa: 0.6 diff --git a/models/rfd3/configs/datasets/val/pseudoknot.yaml b/models/rfd3/configs/datasets/val/pseudoknot.yaml index 7cf5ce1..895e6c2 100644 --- a/models/rfd3/configs/datasets/val/pseudoknot.yaml +++ b/models/rfd3/configs/datasets/val/pseudoknot.yaml @@ -6,4 +6,5 @@ defaults: dataset: name: pseudoknot eval_every_n: 1 + # data: ${paths.data.design_benchmark_data_dir}/pseudoknot_debug.json data: ${paths.data.design_benchmark_data_dir}/pseudoknot.json diff --git a/models/rfd3/configs/experiment/rfd3na-ss.yaml b/models/rfd3/configs/experiment/rfd3na-ss.yaml index 6d2245c..082ed81 100644 --- a/models/rfd3/configs/experiment/rfd3na-ss.yaml +++ b/models/rfd3/configs/experiment/rfd3na-ss.yaml @@ -2,6 +2,7 @@ # Training configuration for RFD3 defaults: + # - /debug/default - /debug/default - override /model: rfd3_base - override /logger: null @@ -54,6 +55,10 @@ datasets: crop_size: 256 max_atoms_in_crop: 2560 # ~10x crop size. global_transform_args: + meta_conditioning_probabilities: + p_is_nucleic_ss_example: 1.0 + p_nucleic_ss_show_partial_feats: 0.7 + p_canonical_bp_filter: 0.2 association_scheme: atom23 #add_na_pair_features: true train_conditions: @@ -70,19 +75,23 @@ datasets: train: # These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data. pdb: - probability: 0.5 + # probability: 0.5 + probability: 0.75 + # probability: 0.0 rna_monomer_distillation: - probability: 0.5 + # probability: 0.2 + probability: 0.25 + # probability: 1.0 val: pseudoknot: dataset: # eval_every_n: 10 - eval_every_n: 1 + eval_every_n: 5 trainer: - devices_per_node: 1 - limit_train_batches: 10 - limit_val_batches: 1 + #devices_per_node: 1 + #limit_train_batches: 10 + #limit_val_batches: 1 validate_every_n_epochs: 5 prevalidate: true diff --git a/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py b/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py index 6a0c8d4..62d077f 100644 --- a/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py +++ b/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py @@ -16,6 +16,14 @@ logging.basicConfig(level=logging.INFO) global_logger = RankedLogger(__name__, rank_zero_only=False) +def _get_bp_partners_annotation(atom_array: AtomArray): + """Return bp-partners annotation.""" + categories = atom_array.get_annotation_categories() + if "bp_partners" in categories: + return atom_array.bp_partners + raise ValueError("atom_array missing bp_partners annotation") + + def _safe_f1_from_sizes(intersection_n: int, pred_n: int, gt_n: int) -> float: """Return F1 with sensible empty-set handling.""" if pred_n == 0 and gt_n == 0: @@ -76,19 +84,16 @@ def _extract_bp_pairs( *, allowed_token_ids: set[int], ) -> set[tuple[int, int]]: - """Extract unordered base-pair edges from bp_partner annotations. + """Extract unordered base-pair edges from bp-partner annotations. Pairs are represented as (min_token_id, max_token_id). """ - if "bp_partner" not in atom_array.get_annotation_categories(): - raise ValueError("atom_array missing bp_partner annotation") - token_starts = get_token_starts(atom_array) token_level_array = atom_array[token_starts] token_ids = np.asarray(token_level_array.token_id, dtype=int) token_id_to_pos = {int(tid): i for i, tid in enumerate(token_ids.tolist())} - bp_partner_ann = atom_array.bp_partner + bp_partner_ann = _get_bp_partners_annotation(atom_array) pairs: set[tuple[int, int]] = set() for pos, start_idx in enumerate(token_starts.tolist()): @@ -126,15 +131,12 @@ def _extract_loop_and_paired_token_ids( allowed_token_ids: set[int], ) -> tuple[set[int], set[int]]: """Return (loop_token_ids, paired_token_ids) within the allowed token set.""" - if "bp_partner" not in atom_array.get_annotation_categories(): - raise ValueError("atom_array missing bp_partner annotation") - token_starts = get_token_starts(atom_array) token_level_array = atom_array[token_starts] token_ids = np.asarray(token_level_array.token_id, dtype=int) token_id_to_pos = {int(tid): i for i, tid in enumerate(token_ids.tolist())} - bp_partner_ann = atom_array.bp_partner + bp_partner_ann = _get_bp_partners_annotation(atom_array) loop_token_ids: set[int] = set() paired_token_ids: set[int] = set() @@ -176,9 +178,9 @@ class NucleicSSSimilarityMetrics(Metric): """Secondary-structure similarity for nucleic acids. Reports: - - `pair_f1`: F1 over the set of basepair edges implied by token-level `bp_partner`. - - `loop_f1`: F1 over explicitly-unpaired loop tokens (`bp_partner == []`). - Unannotated tokens (`bp_partner is None`) are masked. + - `pair_f1`: F1 over basepair edges from token-level bp-partner annotation. + - `loop_f1`: F1 over explicitly-unpaired loop tokens (`bp_partners == []`). + Unannotated tokens (`bp_partners is None`) are masked. - `weighted_f1`: GT-weighted average of `pair_f1` and `loop_f1`, weighted by the prevalence of paired vs loop tokens in the GT. """ @@ -217,92 +219,85 @@ class NucleicSSSimilarityMetrics(Metric): n_valid = 0 for gt_arr, pred_arr in zip(ground_truth_atom_array_stack, predicted_atom_array_stack): - try: - if "bp_partner" not in gt_arr.get_annotation_categories(): - continue - - # Important: predicted AtomArrays are built from a template AtomArray. - # If that template already carries bp_partner (often GT-derived), the - # prediction can inherit it, yielding artificially perfect scores. - # Optionally recompute bp_partner from the *predicted coordinates*. - - if self.annotate_predicted_fresh: - annotate_na_ss( - pred_arr, - NA_only=self.annotation_NA_only, - planar_only=self.annotation_planar_only, - overwrite=True, - p_canonical_bp_filter=0.0, - ) - - if "bp_partner" not in pred_arr.get_annotation_categories(): - continue - - # Basic sanity check: token counts should match for aligned comparisons. - gt_token_ids = _get_token_ids(gt_arr) - pred_token_ids = _get_token_ids(pred_arr) - if len(gt_token_ids) != len(pred_token_ids): - continue - - # Restrict to token_ids that are valid in both arrays. - gt_allowed = _get_candidate_token_ids( - gt_arr, - restrict_to_nucleic=self.restrict_to_nucleic, - compute_for_diffused_region_only=self.compute_for_diffused_region_only, - ) - pred_allowed = _get_candidate_token_ids( - pred_arr, - restrict_to_nucleic=self.restrict_to_nucleic, - compute_for_diffused_region_only=self.compute_for_diffused_region_only, - ) - allowed = gt_allowed & pred_allowed - - if len(allowed) == 0: - continue - - gt_pairs = _extract_bp_pairs(gt_arr, allowed_token_ids=allowed) - pred_pairs = _extract_bp_pairs(pred_arr, allowed_token_ids=allowed) - - gt_loop, gt_paired_tokens = _extract_loop_and_paired_token_ids( - gt_arr, allowed_token_ids=allowed - ) - pred_loop, _pred_paired_tokens = _extract_loop_and_paired_token_ids( - pred_arr, allowed_token_ids=allowed - ) - - pair_tp = len(gt_pairs & pred_pairs) - pair_pred_n = len(pred_pairs) - pair_gt_n = len(gt_pairs) - - loop_tp = len(gt_loop & pred_loop) - loop_pred_n = len(pred_loop) - loop_gt_n = len(gt_loop) - - pair_f1 = _safe_f1_from_sizes(pair_tp, pair_pred_n, pair_gt_n) - loop_f1 = _safe_f1_from_sizes(loop_tp, loop_pred_n, loop_gt_n) - - pair_weight = len(gt_paired_tokens) - loop_weight = len(gt_loop) - total_weight = pair_weight + loop_weight - if total_weight == 0: - weighted_f1 = 1.0 - else: - weighted_f1 = float( - (pair_weight * pair_f1 + loop_weight * loop_f1) / total_weight - ) - - pair_f1_list.append(pair_f1) - loop_f1_list.append(loop_f1) - weighted_f1_list.append(weighted_f1) - n_valid += 1 - - except bdb.BdbQuit: - # Allow interactive debuggers (pdb) to cleanly abort without being swallowed. - raise - except Exception as e: - global_logger.error(f"Error computing nucleic-SS similarity: {e} | Skipping") + gt_categories = gt_arr.get_annotation_categories() + if "bp_partners" not in gt_categories: continue + # Important: predicted AtomArrays are built from a template AtomArray. + # If that template already carries bp_partners (often GT-derived), the + # prediction can inherit it, yielding artificially perfect scores. + # Optionally recompute bp_partners from the *predicted coordinates*. + if self.annotate_predicted_fresh: + annotate_na_ss( + pred_arr, + NA_only=self.annotation_NA_only, + planar_only=self.annotation_planar_only, + overwrite=True, + p_canonical_bp_filter=0.0, + ) + + pred_categories = pred_arr.get_annotation_categories() + if "bp_partners" not in pred_categories: + continue + + # Basic sanity check: token counts should match for aligned comparisons. + gt_token_ids = _get_token_ids(gt_arr) + pred_token_ids = _get_token_ids(pred_arr) + if len(gt_token_ids) != len(pred_token_ids): + continue + + # Restrict to token_ids that are valid in both arrays. + gt_allowed = _get_candidate_token_ids( + gt_arr, + restrict_to_nucleic=self.restrict_to_nucleic, + compute_for_diffused_region_only=self.compute_for_diffused_region_only, + ) + pred_allowed = _get_candidate_token_ids( + pred_arr, + restrict_to_nucleic=self.restrict_to_nucleic, + compute_for_diffused_region_only=self.compute_for_diffused_region_only, + ) + allowed = gt_allowed & pred_allowed + + if len(allowed) == 0: + continue + + gt_pairs = _extract_bp_pairs(gt_arr, allowed_token_ids=allowed) + pred_pairs = _extract_bp_pairs(pred_arr, allowed_token_ids=allowed) + + gt_loop, gt_paired_tokens = _extract_loop_and_paired_token_ids( + gt_arr, allowed_token_ids=allowed + ) + pred_loop, _pred_paired_tokens = _extract_loop_and_paired_token_ids( + pred_arr, allowed_token_ids=allowed + ) + + pair_tp = len(gt_pairs & pred_pairs) + pair_pred_n = len(pred_pairs) + pair_gt_n = len(gt_pairs) + + loop_tp = len(gt_loop & pred_loop) + loop_pred_n = len(pred_loop) + loop_gt_n = len(gt_loop) + + pair_f1 = _safe_f1_from_sizes(pair_tp, pair_pred_n, pair_gt_n) + loop_f1 = _safe_f1_from_sizes(loop_tp, loop_pred_n, loop_gt_n) + + pair_weight = len(gt_paired_tokens) + loop_weight = len(gt_loop) + total_weight = pair_weight + loop_weight + if total_weight == 0: + weighted_f1 = 1.0 + else: + weighted_f1 = float( + (pair_weight * pair_f1 + loop_weight * loop_f1) / total_weight + ) + + pair_f1_list.append(pair_f1) + loop_f1_list.append(loop_f1) + weighted_f1_list.append(weighted_f1) + n_valid += 1 + if n_valid == 0: return {} diff --git a/models/rfd3/src/rfd3/transforms/na_geom.py b/models/rfd3/src/rfd3/transforms/na_geom.py index a1c8565..f1f400a 100644 --- a/models/rfd3/src/rfd3/transforms/na_geom.py +++ b/models/rfd3/src/rfd3/transforms/na_geom.py @@ -100,12 +100,12 @@ class CalculateNucleicAcidGeomFeats(Transform): Training: - Computes geometry/H-bond-based base pairs and writes them onto the AtomArray - via the ``bp_partner`` annotation (annotation-first), then reconstructs the + via the ``bp_partners`` annotation (annotation-first), then reconstructs the matrix (and optionally masks parts of it) before one-hot encoding. Inference: - Interprets user-provided secondary-structure specifications, writes the same - ``bp_partner`` annotation, then follows the same matrix + one-hot path. + ``bp_partners`` annotation, then follows the same matrix + one-hot path. Note: helical-parameter features are not implemented/used in this refactored path. """ @@ -113,16 +113,14 @@ class CalculateNucleicAcidGeomFeats(Transform): def __init__( self, is_inference, - add_nucleic_ss_feats: bool = True, - # Conditional sampling parameters: - p_is_nucleic_ss_example: float = 0.3, - p_show_partial_feats: float = 0.7, + # Conditional sampling parameters all stored in this dict: + meta_conditioning_probabilities: dict[str, float] = None, + # Mask control paramerers: nucleic_ss_min_shown: float = 0.2, nucleic_ss_max_shown: float = 1.0, n_islands_min: int = 1, n_islands_max: int = 6, - p_canonical_bp_filter: float = 0.0, # USE_RF2AA_NAMES: bool = False, NA_only: bool = False, @@ -131,21 +129,27 @@ class CalculateNucleicAcidGeomFeats(Transform): ): # Critical, must always have to know how to handle self.is_inference = is_inference + + self.meta_conditioning_probabilities = meta_conditioning_probabilities or {} + + # Control whether we show some nucleic SS or default to full 2D mask + self.p_is_nucleic_ss_example = self.meta_conditioning_probabilities.get("p_is_nucleic_ss_example", 0.0) - self.add_nucleic_ss_feats = add_nucleic_ss_feats - self.p_canonical_bp_filter = p_canonical_bp_filter # enforce that bp labels are only canonical - self.p_is_nucleic_ss_example = p_is_nucleic_ss_example - self.p_show_partial_feats = p_show_partial_feats + # Control whether we define full SS or just part of it (only applies if is NA SS example) + self.p_show_partial_feats = self.meta_conditioning_probabilities.get("p_nucleic_ss_show_partial_feats", 0.0) + + # Some frac of time default to only showing canonical base pairs + self.p_canonical_bp_filter = self.meta_conditioning_probabilities.get("p_canonical_bp_filter", 0.5) + + # mask patterning control to make things resemble design scenarios self.nucleic_ss_min_shown = nucleic_ss_min_shown self.nucleic_ss_max_shown = nucleic_ss_max_shown self.n_islands_min = n_islands_min self.n_islands_max = n_islands_max - # Filters for what can be considered a planar contact interaction self.NA_only = NA_only # only annotate base-like interactions for nucleic acid residues self.planar_only = planar_only # only consider planar atoms in sidechains for geometry calculations, - self.p_canonical_bp_filter = p_canonical_bp_filter # probability of enforcing canonical base pair filter @@ -168,13 +172,8 @@ class CalculateNucleicAcidGeomFeats(Transform): # Calculate n_tokens (assuming one token per residue for simplicity) token_starts = get_token_starts(atom_array) - token_level_array = atom_array[token_starts] n_tokens = len(token_starts) - - # Defaults for feature visibility - is_nucleic_ss_example = True - give_partial_feats = False - token_mask_to_show = np.ones(n_tokens, dtype=bool) + # token_level_array = atom_array[token_starts] # Handle the training case with ground truth and masking if not self.is_inference: diff --git a/models/rfd3/src/rfd3/transforms/na_geom_utils.py b/models/rfd3/src/rfd3/transforms/na_geom_utils.py index 8ab2e71..23292a8 100644 --- a/models/rfd3/src/rfd3/transforms/na_geom_utils.py +++ b/models/rfd3/src/rfd3/transforms/na_geom_utils.py @@ -1050,6 +1050,40 @@ def compute_nucleic_ss( hbond_count = np.asarray(hbond_count)[mask_1d, :][:, mask_1d] # [I, I, 3] seq_neighbors = np.asarray(token_level_data["seq_neighbors"], dtype=bool)[mask_1d, :][:, mask_1d] # [I, I] + # Nothing passed NA/planar filtering for this structure. + # Return empty outputs instead of failing downstream on np.stack([]). + if len(xyz_planar) == 0: + if return_basepairs_only: + return np.zeros((0, 0), dtype=bool) + + pair_params: dict[str, np.ndarray] = { + "H_ij": np.zeros((0, 0), dtype=np.float32), + "B_ij": np.zeros((0, 0), dtype=np.float32), + "P_ij": np.zeros((0, 0), dtype=np.float32), + "base_ori_ij": np.zeros((0, 0), dtype=np.float32), + "basepairs_bool_ij": np.zeros((0, 0), dtype=bool), + "basepairs_ij": np.zeros((0, 0), dtype=np.float32), + "hbond_summation": np.zeros((0, 0), dtype=np.float32), + } + + if return_opening_angle: + pair_params["O_ij"] = np.zeros((0, 0), dtype=np.float32) + + if return_pairwise_geometry: + pair_params["X_ij"] = np.zeros((0, 0), dtype=np.float32) + pair_params["Y_ij"] = np.zeros((0, 0), dtype=np.float32) + pair_params["Z_ij"] = np.zeros((0, 0), dtype=np.float32) + + nucleic_ss_data: dict = {"pair_params": pair_params} + if return_local_params: + nucleic_ss_data["local_params"] = { + "X_i": np.zeros((0, 3), dtype=np.float32), + "Y_i": np.zeros((0, 3), dtype=np.float32), + "Z_i": np.zeros((0, 3), dtype=np.float32), + } + + return nucleic_ss_data + # --- Precompute centroids and displacement vectors ----------- planar_centers = np.stack( # [I, 3] [np.nanmean(np.asarray(xyz_i, dtype=np.float32), axis=0) for xyz_i in xyz_planar], @@ -1441,19 +1475,6 @@ def annotate_na_ss_from_specification( token_ids: list[int] = [int(t) for t in list(token_level_array.token_id)] n_tokens = len(token_starts) - # Explicit loops are only meaningful for nucleic-acid tokens. - # Instantiate encoding locally to avoid retaining large arrays at module scope. - sequence_encoding = AF3SequenceEncoding() - is_rna_like = np.isin( - token_level_array.res_name, - sequence_encoding.all_res_names[sequence_encoding.is_rna_like], - ) - is_dna_like = np.isin( - token_level_array.res_name, - sequence_encoding.all_res_names[sequence_encoding.is_dna_like], - ) - is_na_token = np.asarray(is_rna_like | is_dna_like, dtype=bool) - # Prepare/overwrite annotation array if (not overwrite) and ("bp_partners" in atom_array.get_annotation_categories()): bp_partners_ann = atom_array.bp_partners @@ -1528,8 +1549,6 @@ def annotate_na_ss_from_specification( return if i == j: return - if (not bool(is_na_token[int(i)])) or (not bool(is_na_token[int(j)])): - return if i in loop_token_idxs or j in loop_token_idxs: return partners[i].add(j) @@ -1619,9 +1638,6 @@ def annotate_na_ss_from_specification( for i in list(loop_token_idxs): if not (0 <= i < n_tokens): continue - if not bool(is_na_token[int(i)]): - loop_token_idxs.discard(int(i)) - continue for j in list(partners[i]): partners[j].discard(i) partners[i].clear() @@ -1630,8 +1646,6 @@ def annotate_na_ss_from_specification( # Unspecified tokens remain unannotated (None) -> NA_SS_MASK. for i in range(n_tokens): atom_i = int(token_starts[i]) - if not bool(is_na_token[int(i)]): - continue if len(partners[i]) > 0: bp_partners_ann[atom_i] = [] for j in sorted(partners[i]):