From bfe513ab178307f5ba58c7a0e0ef486d13b34109 Mon Sep 17 00:00:00 2001
From: Raktim Mitra
Date: Thu, 26 Feb 2026 12:25:41 -0800
Subject: [PATCH] refactor ss metric computation section to be called during
 inference to write in the json

---
Review notes (not part of the commit message):
- compute_from_two_arr now returns None on a token-count mismatch; the
  refactor had left a bare `None` expression there, so mismatched samples
  were scored instead of skipped.
- Callers test for None instead of wrapping the tuple unpacking in a bare
  `except:`, so genuine errors are no longer hidden.
- Leading whitespace was reconstructed from a whitespace-collapsed copy of
  this patch. If context lines fail to match, apply with
  `git apply --ignore-whitespace`. The post-image blob ids on the `index`
  lines still refer to the original revision of this patch.

 .../src/rfd3/metrics/nucleic_ss_metrics.py    | 178 ++++++++++++------
 models/rfd3/src/rfd3/trainer/rfd3.py          |  12 ++
 2 files changed, 132 insertions(+), 58 deletions(-)

diff --git a/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py b/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py
index f021639..55bd8cb 100644
--- a/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py
+++ b/models/rfd3/src/rfd3/metrics/nucleic_ss_metrics.py
@@ -175,6 +175,117 @@ def _extract_loop_and_paired_token_ids(
     return loop_token_ids, paired_token_ids
 
 
+def compute_from_two_arr(
+    gt_arr,
+    pred_arr,
+    restrict_to_nucleic=True,
+    compute_for_diffused_region_only=False,
+):
+    """Compute pair, loop and weighted F1 between two annotated arrays.
+
+    Args:
+        gt_arr: Array carrying the reference base-pair annotation.
+        pred_arr: Array carrying the base-pair annotation to score.
+        restrict_to_nucleic: Forwarded to ``_get_candidate_token_ids``.
+        compute_for_diffused_region_only: Forwarded to
+            ``_get_candidate_token_ids``.
+
+    Returns:
+        ``(pair_f1, loop_f1, weighted_f1)``, or ``None`` when the token
+        counts differ or no token is a candidate in both arrays.
+    """
+    # Token counts must match for the comparison to be aligned.
+    gt_token_ids = _get_token_ids(gt_arr)
+    pred_token_ids = _get_token_ids(pred_arr)
+    if len(gt_token_ids) != len(pred_token_ids):
+        return None
+
+    # Restrict to token_ids that are valid in both arrays.
+    gt_allowed = _get_candidate_token_ids(
+        gt_arr,
+        restrict_to_nucleic=restrict_to_nucleic,
+        compute_for_diffused_region_only=compute_for_diffused_region_only,
+    )
+    pred_allowed = _get_candidate_token_ids(
+        pred_arr,
+        restrict_to_nucleic=restrict_to_nucleic,
+        compute_for_diffused_region_only=compute_for_diffused_region_only,
+    )
+    allowed = gt_allowed & pred_allowed
+
+    if len(allowed) == 0:
+        return None
+
+    gt_pairs = _extract_bp_pairs(gt_arr, allowed_token_ids=allowed)
+    pred_pairs = _extract_bp_pairs(pred_arr, allowed_token_ids=allowed)
+
+    gt_loop, gt_paired_tokens = _extract_loop_and_paired_token_ids(
+        gt_arr, allowed_token_ids=allowed
+    )
+    pred_loop, _pred_paired_tokens = _extract_loop_and_paired_token_ids(
+        pred_arr, allowed_token_ids=allowed
+    )
+
+    pair_tp = len(gt_pairs & pred_pairs)
+    pair_pred_n = len(pred_pairs)
+    pair_gt_n = len(gt_pairs)
+
+    loop_tp = len(gt_loop & pred_loop)
+    loop_pred_n = len(pred_loop)
+    loop_gt_n = len(gt_loop)
+
+    pair_f1 = _safe_f1_from_sizes(pair_tp, pair_pred_n, pair_gt_n)
+    loop_f1 = _safe_f1_from_sizes(loop_tp, loop_pred_n, loop_gt_n)
+
+    # Weight each F1 by the number of ground-truth tokens it covers.
+    pair_weight = len(gt_paired_tokens)
+    loop_weight = len(gt_loop)
+    total_weight = pair_weight + loop_weight
+    if total_weight == 0:
+        weighted_f1 = 1.0
+    else:
+        weighted_f1 = float(
+            (pair_weight * pair_f1 + loop_weight * loop_f1) / total_weight
+        )
+
+    return pair_f1, loop_f1, weighted_f1
+
+
+def get_NA_SS_F1(pred_array):
+    """Compare the incoming bp_partners annotation with a re-annotation.
+
+    A copy of ``pred_array`` keeps the incoming ``bp_partners`` annotation
+    and serves as the reference; ``annotate_na_ss`` then re-annotates
+    ``pred_array`` with ``overwrite=True``.
+
+    Returns:
+        Dict with ``pair_f1``, ``loop_f1`` and ``weighted_f1``, or an empty
+        dict when ``compute_from_two_arr`` has nothing to compare.
+    """
+    # Keep the original bp_partners annotation as the reference.
+    gt_array = pred_array.copy()
+
+    # Re-annotate; overwrite=True replaces the existing annotation.
+    pred_array = annotate_na_ss(
+        pred_array,
+        NA_only=True,
+        planar_only=True,
+        overwrite=True,
+        p_canonical_bp_filter=0.0,
+    )
+
+    scores = compute_from_two_arr(gt_array, pred_array)
+    if scores is None:
+        return {}
+
+    pair_f1, loop_f1, weighted_f1 = scores
+    return {
+        "pair_f1": pair_f1,
+        "loop_f1": loop_f1,
+        "weighted_f1": weighted_f1,
+    }
+
+
 class NucleicSSSimilarityMetrics(Metric):
     """Secondary-structure similarity for nucleic acids.
 
@@ -264,59 +375,17 @@ class NucleicSSSimilarityMetrics(Metric):
             if "bp_partners" not in pred_categories:
                 continue
 
-            # Basic sanity check: token counts should match for aligned comparisons.
-            gt_token_ids = _get_token_ids(gt_arr)
-            pred_token_ids = _get_token_ids(pred_arr)
-            if len(gt_token_ids) != len(pred_token_ids):
+            # Skip samples that compute_from_two_arr cannot score.
+            scores = compute_from_two_arr(
+                gt_arr,
+                pred_arr,
+                restrict_to_nucleic=self.restrict_to_nucleic,
+                compute_for_diffused_region_only=self.compute_for_diffused_region_only,
+            )
+            if scores is None:
                 continue
 
-            # Restrict to token_ids that are valid in both arrays.
-            gt_allowed = _get_candidate_token_ids(
-                gt_arr,
-                restrict_to_nucleic=self.restrict_to_nucleic,
-                compute_for_diffused_region_only=self.compute_for_diffused_region_only,
-            )
-            pred_allowed = _get_candidate_token_ids(
-                pred_arr,
-                restrict_to_nucleic=self.restrict_to_nucleic,
-                compute_for_diffused_region_only=self.compute_for_diffused_region_only,
-            )
-            allowed = gt_allowed & pred_allowed
-
-            if len(allowed) == 0:
-                continue
-
-            gt_pairs = _extract_bp_pairs(gt_arr, allowed_token_ids=allowed)
-            pred_pairs = _extract_bp_pairs(pred_arr, allowed_token_ids=allowed)
-
-            gt_loop, gt_paired_tokens = _extract_loop_and_paired_token_ids(
-                gt_arr, allowed_token_ids=allowed
-            )
-            pred_loop, _pred_paired_tokens = _extract_loop_and_paired_token_ids(
-                pred_arr, allowed_token_ids=allowed
-            )
-
-            pair_tp = len(gt_pairs & pred_pairs)
-            pair_pred_n = len(pred_pairs)
-            pair_gt_n = len(gt_pairs)
-
-            loop_tp = len(gt_loop & pred_loop)
-            loop_pred_n = len(pred_loop)
-            loop_gt_n = len(gt_loop)
-
-            pair_f1 = _safe_f1_from_sizes(pair_tp, pair_pred_n, pair_gt_n)
-            loop_f1 = _safe_f1_from_sizes(loop_tp, loop_pred_n, loop_gt_n)
-
-            pair_weight = len(gt_paired_tokens)
-            loop_weight = len(gt_loop)
-            total_weight = pair_weight + loop_weight
-            if total_weight == 0:
-                weighted_f1 = 1.0
-            else:
-                weighted_f1 = float(
-                    (pair_weight * pair_f1 + loop_weight * loop_f1) / total_weight
-                )
-
+            pair_f1, loop_f1, weighted_f1 = scores
             pair_f1_list.append(pair_f1)
             loop_f1_list.append(loop_f1)
             weighted_f1_list.append(weighted_f1)
@@ -324,13 +393,6 @@ class NucleicSSSimilarityMetrics(Metric):
         if n_valid == 0:
             return {}
 
-        aaa = {
-            "pair_f1": float(np.mean(pair_f1_list)),
-            "loop_f1": float(np.mean(loop_f1_list)),
-            "weighted_f1": float(np.mean(weighted_f1_list)),
-            "n_valid_samples": int(n_valid),
-        }
-        print(aaa)
         return {
             "pair_f1": float(np.mean(pair_f1_list)),
             "loop_f1": float(np.mean(loop_f1_list)),
diff --git a/models/rfd3/src/rfd3/trainer/rfd3.py b/models/rfd3/src/rfd3/trainer/rfd3.py
index 9735978..6701f8d 100644
--- a/models/rfd3/src/rfd3/trainer/rfd3.py
+++ b/models/rfd3/src/rfd3/trainer/rfd3.py
@@ -8,6 +8,7 @@ from lightning_utilities import apply_to_collection
 from omegaconf import DictConfig
 from rfd3.metrics.design_metrics import get_all_backbone_metrics
 from rfd3.metrics.hbonds_hbplus_metrics import get_hbond_metrics
+from rfd3.metrics.nucleic_ss_metrics import get_NA_SS_F1
 from rfd3.trainer.recycling import get_recycle_schedule
 from rfd3.trainer.trainer_utils import (
     _build_atom_array_stack,
@@ -449,6 +450,17 @@ class AADesignTrainer(FabricTrainer):
                 ):
                     metadata_dict[i]["metrics"] |= get_hbond_metrics(atom_array)
 
+                if "bp_partners" in atom_array.get_annotation_categories():
+                    # `== None` is deliberate: the comparison is elementwise on
+                    # the annotation array, which `is None` would not be.
+                    if not np.all(atom_array.bp_partners == None):  # noqa: E711
+                        try:
+                            metadata_dict[i]["metrics"] |= get_NA_SS_F1(atom_array)
+                        except Exception:
+                            # Best-effort: a failed SS metric must not abort
+                            # inference, so the metric is simply omitted.
+                            pass
+
                 if "partial_t" in f:
                     # Try calcualte a CA RMSD to input:
                     aa_in = example["atom_array"]