refactor ss metric computatino section to be called during inference to write in the json

This commit is contained in:
Raktim Mitra
2026-02-26 12:25:41 -08:00
parent 60b18c281b
commit bfe513ab17
2 changed files with 98 additions and 58 deletions

View File

@@ -175,6 +175,88 @@ def _extract_loop_and_paired_token_ids(
return loop_token_ids, paired_token_ids
def compute_from_two_arr(gt_arr, pred_arr, restrict_to_nucleic=True, compute_for_diffused_region_only = False):
gt_token_ids = _get_token_ids(gt_arr)
pred_token_ids = _get_token_ids(pred_arr)
if len(gt_token_ids) != len(pred_token_ids):
None
# Restrict to token_ids that are valid in both arrays.
gt_allowed = _get_candidate_token_ids(
gt_arr,
restrict_to_nucleic=restrict_to_nucleic,
compute_for_diffused_region_only=compute_for_diffused_region_only,
)
pred_allowed = _get_candidate_token_ids(
pred_arr,
restrict_to_nucleic=restrict_to_nucleic,
compute_for_diffused_region_only=compute_for_diffused_region_only,
)
allowed = gt_allowed & pred_allowed
if len(allowed) == 0:
return None
gt_pairs = _extract_bp_pairs(gt_arr, allowed_token_ids=allowed)
pred_pairs = _extract_bp_pairs(pred_arr, allowed_token_ids=allowed)
gt_loop, gt_paired_tokens = _extract_loop_and_paired_token_ids(
gt_arr, allowed_token_ids=allowed
)
pred_loop, _pred_paired_tokens = _extract_loop_and_paired_token_ids(
pred_arr, allowed_token_ids=allowed
)
pair_tp = len(gt_pairs & pred_pairs)
pair_pred_n = len(pred_pairs)
pair_gt_n = len(gt_pairs)
loop_tp = len(gt_loop & pred_loop)
loop_pred_n = len(pred_loop)
loop_gt_n = len(gt_loop)
pair_f1 = _safe_f1_from_sizes(pair_tp, pair_pred_n, pair_gt_n)
loop_f1 = _safe_f1_from_sizes(loop_tp, loop_pred_n, loop_gt_n)
pair_weight = len(gt_paired_tokens)
loop_weight = len(gt_loop)
total_weight = pair_weight + loop_weight
if total_weight == 0:
weighted_f1 = 1.0
else:
weighted_f1 = float(
(pair_weight * pair_f1 + loop_weight * loop_f1) / total_weight
)
return pair_f1, loop_f1, weighted_f1
def get_NA_SS_F1(pred_array):
## save the original bop_partner annotation
gt_array = pred_array.copy()
## replace by annotating again
pred_array = annotate_na_ss(
pred_array,
NA_only=True,
planar_only=True,
overwrite=True,
p_canonical_bp_filter=0.0,
)
try:
pair_f1, loop_f1, weighted_f1 = compute_from_two_arr(gt_array, pred_array)
except:
# fails when returns None because expects three returns
return {}
return {
"pair_f1": pair_f1,
"loop_f1": loop_f1,
"weighted_f1": weighted_f1,
}
class NucleicSSSimilarityMetrics(Metric):
"""Secondary-structure similarity for nucleic acids.
@@ -264,59 +346,13 @@ class NucleicSSSimilarityMetrics(Metric):
if "bp_partners" not in pred_categories:
continue
# Basic sanity check: token counts should match for aligned comparisons.
gt_token_ids = _get_token_ids(gt_arr)
pred_token_ids = _get_token_ids(pred_arr)
if len(gt_token_ids) != len(pred_token_ids):
# Basic sanity check: token counts should match for aligned comparisons
try:
pair_f1, loop_f1, weighted_f1 = compute_from_two_arr(gt_arr, pred_arr, restrict_to_nucleic=self.restrict_to_nucleic, compute_for_diffused_region_only = self.compute_for_diffused_region_only)
except:
# fails when returns None because expects three returns
continue
# Restrict to token_ids that are valid in both arrays.
gt_allowed = _get_candidate_token_ids(
gt_arr,
restrict_to_nucleic=self.restrict_to_nucleic,
compute_for_diffused_region_only=self.compute_for_diffused_region_only,
)
pred_allowed = _get_candidate_token_ids(
pred_arr,
restrict_to_nucleic=self.restrict_to_nucleic,
compute_for_diffused_region_only=self.compute_for_diffused_region_only,
)
allowed = gt_allowed & pred_allowed
if len(allowed) == 0:
continue
gt_pairs = _extract_bp_pairs(gt_arr, allowed_token_ids=allowed)
pred_pairs = _extract_bp_pairs(pred_arr, allowed_token_ids=allowed)
gt_loop, gt_paired_tokens = _extract_loop_and_paired_token_ids(
gt_arr, allowed_token_ids=allowed
)
pred_loop, _pred_paired_tokens = _extract_loop_and_paired_token_ids(
pred_arr, allowed_token_ids=allowed
)
pair_tp = len(gt_pairs & pred_pairs)
pair_pred_n = len(pred_pairs)
pair_gt_n = len(gt_pairs)
loop_tp = len(gt_loop & pred_loop)
loop_pred_n = len(pred_loop)
loop_gt_n = len(gt_loop)
pair_f1 = _safe_f1_from_sizes(pair_tp, pair_pred_n, pair_gt_n)
loop_f1 = _safe_f1_from_sizes(loop_tp, loop_pred_n, loop_gt_n)
pair_weight = len(gt_paired_tokens)
loop_weight = len(gt_loop)
total_weight = pair_weight + loop_weight
if total_weight == 0:
weighted_f1 = 1.0
else:
weighted_f1 = float(
(pair_weight * pair_f1 + loop_weight * loop_f1) / total_weight
)
pair_f1_list.append(pair_f1)
loop_f1_list.append(loop_f1)
weighted_f1_list.append(weighted_f1)
@@ -324,13 +360,7 @@ class NucleicSSSimilarityMetrics(Metric):
if n_valid == 0:
return {}
aaa = {
"pair_f1": float(np.mean(pair_f1_list)),
"loop_f1": float(np.mean(loop_f1_list)),
"weighted_f1": float(np.mean(weighted_f1_list)),
"n_valid_samples": int(n_valid),
}
print(aaa)
return {
"pair_f1": float(np.mean(pair_f1_list)),
"loop_f1": float(np.mean(loop_f1_list)),

View File

@@ -8,6 +8,7 @@ from lightning_utilities import apply_to_collection
from omegaconf import DictConfig
from rfd3.metrics.design_metrics import get_all_backbone_metrics
from rfd3.metrics.hbonds_hbplus_metrics import get_hbond_metrics
from rfd3.metrics.nucleic_ss_metrics import get_NA_SS_F1
from rfd3.trainer.recycling import get_recycle_schedule
from rfd3.trainer.trainer_utils import (
_build_atom_array_stack,
@@ -449,6 +450,15 @@ class AADesignTrainer(FabricTrainer):
):
metadata_dict[i]["metrics"] |= get_hbond_metrics(atom_array)
if (
"bp_partners" in atom_array.get_annotation_categories()
):
if not np.all(atom_array.bp_partners == None):
try:
metadata_dict[i]["metrics"] |= get_NA_SS_F1(atom_array)
except:
pass
if "partial_t" in f:
# Try calcualte a CA RMSD to input:
aa_in = example["atom_array"]