mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
refactor ss metric computation section to be called during inference to write in the json
This commit is contained in:
@@ -175,6 +175,88 @@ def _extract_loop_and_paired_token_ids(
|
||||
|
||||
return loop_token_ids, paired_token_ids
|
||||
|
||||
def compute_from_two_arr(
    gt_arr,
    pred_arr,
    restrict_to_nucleic=True,
    compute_for_diffused_region_only=False,
):
    """Compute pair, loop and weighted F1 scores between two annotated arrays.

    Args:
        gt_arr: Reference array carrying the base-pair annotation.
        pred_arr: Array to score against ``gt_arr``.
        restrict_to_nucleic: Forwarded to ``_get_candidate_token_ids`` for
            both arrays.
        compute_for_diffused_region_only: Forwarded to
            ``_get_candidate_token_ids`` for both arrays.

    Returns:
        ``(pair_f1, loop_f1, weighted_f1)``, or ``None`` when the two arrays
        have different token counts or share no candidate token ids.
    """
    # Basic sanity check: token counts must match for an aligned comparison.
    # The original code evaluated a bare ``None`` here (missing ``return``),
    # so mismatched arrays silently fell through and were scored anyway.
    if len(_get_token_ids(gt_arr)) != len(_get_token_ids(pred_arr)):
        return None

    # Restrict to token_ids that are valid in both arrays.
    gt_allowed = _get_candidate_token_ids(
        gt_arr,
        restrict_to_nucleic=restrict_to_nucleic,
        compute_for_diffused_region_only=compute_for_diffused_region_only,
    )
    pred_allowed = _get_candidate_token_ids(
        pred_arr,
        restrict_to_nucleic=restrict_to_nucleic,
        compute_for_diffused_region_only=compute_for_diffused_region_only,
    )
    allowed = gt_allowed & pred_allowed

    if len(allowed) == 0:
        return None

    gt_pairs = _extract_bp_pairs(gt_arr, allowed_token_ids=allowed)
    pred_pairs = _extract_bp_pairs(pred_arr, allowed_token_ids=allowed)

    gt_loop, gt_paired_tokens = _extract_loop_and_paired_token_ids(
        gt_arr, allowed_token_ids=allowed
    )
    # Only the predicted loop set is needed; the paired-token set is unused.
    pred_loop, _ = _extract_loop_and_paired_token_ids(
        pred_arr, allowed_token_ids=allowed
    )

    pair_f1 = _safe_f1_from_sizes(
        len(gt_pairs & pred_pairs), len(pred_pairs), len(gt_pairs)
    )
    loop_f1 = _safe_f1_from_sizes(
        len(gt_loop & pred_loop), len(pred_loop), len(gt_loop)
    )

    # Weight each F1 by the number of reference tokens in that class.
    pair_weight = len(gt_paired_tokens)
    loop_weight = len(gt_loop)
    total_weight = pair_weight + loop_weight
    if total_weight == 0:
        # No reference paired or loop tokens: score as a perfect match.
        weighted_f1 = 1.0
    else:
        weighted_f1 = float(
            (pair_weight * pair_f1 + loop_weight * loop_f1) / total_weight
        )

    return pair_f1, loop_f1, weighted_f1
|
||||
|
||||
def get_NA_SS_F1(pred_array):
    """Score an array's existing base-pair annotation against a re-annotation.

    A copy of ``pred_array`` is taken first and used as the reference; the
    array is then passed through ``annotate_na_ss`` with ``overwrite=True``
    and the two are compared with ``compute_from_two_arr``.

    Args:
        pred_array: Array carrying the annotation to evaluate.
            NOTE(review): presumably a biotite ``AtomArray`` with a
            ``bp_partners`` annotation -- confirm against the caller.

    Returns:
        A dict with keys ``"pair_f1"``, ``"loop_f1"`` and ``"weighted_f1"``,
        or an empty dict when the comparison yields no result or fails.
    """
    # Save the original annotation before it is overwritten below.
    gt_array = pred_array.copy()

    # Replace the annotation by annotating again.
    pred_array = annotate_na_ss(
        pred_array,
        NA_only=True,
        planar_only=True,
        overwrite=True,
        p_canonical_bp_filter=0.0,
    )

    try:
        scores = compute_from_two_arr(gt_array, pred_array)
    except Exception:
        # Best-effort metric: a scoring failure must not abort the caller.
        # ``except Exception`` (not a bare ``except``) lets KeyboardInterrupt
        # and SystemExit propagate.
        return {}

    # ``compute_from_two_arr`` returns None when the arrays are not comparable.
    if scores is None:
        return {}

    pair_f1, loop_f1, weighted_f1 = scores
    return {
        "pair_f1": pair_f1,
        "loop_f1": loop_f1,
        "weighted_f1": weighted_f1,
    }
|
||||
|
||||
|
||||
class NucleicSSSimilarityMetrics(Metric):
|
||||
"""Secondary-structure similarity for nucleic acids.
|
||||
@@ -264,59 +346,13 @@ class NucleicSSSimilarityMetrics(Metric):
|
||||
if "bp_partners" not in pred_categories:
|
||||
continue
|
||||
|
||||
# Basic sanity check: token counts should match for aligned comparisons.
|
||||
gt_token_ids = _get_token_ids(gt_arr)
|
||||
pred_token_ids = _get_token_ids(pred_arr)
|
||||
if len(gt_token_ids) != len(pred_token_ids):
|
||||
# Basic sanity check: token counts should match for aligned comparisons
|
||||
try:
|
||||
pair_f1, loop_f1, weighted_f1 = compute_from_two_arr(gt_arr, pred_arr, restrict_to_nucleic=self.restrict_to_nucleic, compute_for_diffused_region_only = self.compute_for_diffused_region_only)
|
||||
except:
|
||||
# fails when returns None because expects three returns
|
||||
continue
|
||||
|
||||
# Restrict to token_ids that are valid in both arrays.
|
||||
gt_allowed = _get_candidate_token_ids(
|
||||
gt_arr,
|
||||
restrict_to_nucleic=self.restrict_to_nucleic,
|
||||
compute_for_diffused_region_only=self.compute_for_diffused_region_only,
|
||||
)
|
||||
pred_allowed = _get_candidate_token_ids(
|
||||
pred_arr,
|
||||
restrict_to_nucleic=self.restrict_to_nucleic,
|
||||
compute_for_diffused_region_only=self.compute_for_diffused_region_only,
|
||||
)
|
||||
allowed = gt_allowed & pred_allowed
|
||||
|
||||
if len(allowed) == 0:
|
||||
continue
|
||||
|
||||
gt_pairs = _extract_bp_pairs(gt_arr, allowed_token_ids=allowed)
|
||||
pred_pairs = _extract_bp_pairs(pred_arr, allowed_token_ids=allowed)
|
||||
|
||||
gt_loop, gt_paired_tokens = _extract_loop_and_paired_token_ids(
|
||||
gt_arr, allowed_token_ids=allowed
|
||||
)
|
||||
pred_loop, _pred_paired_tokens = _extract_loop_and_paired_token_ids(
|
||||
pred_arr, allowed_token_ids=allowed
|
||||
)
|
||||
|
||||
pair_tp = len(gt_pairs & pred_pairs)
|
||||
pair_pred_n = len(pred_pairs)
|
||||
pair_gt_n = len(gt_pairs)
|
||||
|
||||
loop_tp = len(gt_loop & pred_loop)
|
||||
loop_pred_n = len(pred_loop)
|
||||
loop_gt_n = len(gt_loop)
|
||||
|
||||
pair_f1 = _safe_f1_from_sizes(pair_tp, pair_pred_n, pair_gt_n)
|
||||
loop_f1 = _safe_f1_from_sizes(loop_tp, loop_pred_n, loop_gt_n)
|
||||
|
||||
pair_weight = len(gt_paired_tokens)
|
||||
loop_weight = len(gt_loop)
|
||||
total_weight = pair_weight + loop_weight
|
||||
if total_weight == 0:
|
||||
weighted_f1 = 1.0
|
||||
else:
|
||||
weighted_f1 = float(
|
||||
(pair_weight * pair_f1 + loop_weight * loop_f1) / total_weight
|
||||
)
|
||||
|
||||
pair_f1_list.append(pair_f1)
|
||||
loop_f1_list.append(loop_f1)
|
||||
weighted_f1_list.append(weighted_f1)
|
||||
@@ -324,13 +360,7 @@ class NucleicSSSimilarityMetrics(Metric):
|
||||
|
||||
if n_valid == 0:
|
||||
return {}
|
||||
aaa = {
|
||||
"pair_f1": float(np.mean(pair_f1_list)),
|
||||
"loop_f1": float(np.mean(loop_f1_list)),
|
||||
"weighted_f1": float(np.mean(weighted_f1_list)),
|
||||
"n_valid_samples": int(n_valid),
|
||||
}
|
||||
print(aaa)
|
||||
|
||||
return {
|
||||
"pair_f1": float(np.mean(pair_f1_list)),
|
||||
"loop_f1": float(np.mean(loop_f1_list)),
|
||||
|
||||
@@ -8,6 +8,7 @@ from lightning_utilities import apply_to_collection
|
||||
from omegaconf import DictConfig
|
||||
from rfd3.metrics.design_metrics import get_all_backbone_metrics
|
||||
from rfd3.metrics.hbonds_hbplus_metrics import get_hbond_metrics
|
||||
from rfd3.metrics.nucleic_ss_metrics import get_NA_SS_F1
|
||||
from rfd3.trainer.recycling import get_recycle_schedule
|
||||
from rfd3.trainer.trainer_utils import (
|
||||
_build_atom_array_stack,
|
||||
@@ -449,6 +450,15 @@ class AADesignTrainer(FabricTrainer):
|
||||
):
|
||||
metadata_dict[i]["metrics"] |= get_hbond_metrics(atom_array)
|
||||
|
||||
if (
|
||||
"bp_partners" in atom_array.get_annotation_categories()
|
||||
):
|
||||
if not np.all(atom_array.bp_partners == None):
|
||||
try:
|
||||
metadata_dict[i]["metrics"] |= get_NA_SS_F1(atom_array)
|
||||
except:
|
||||
pass
|
||||
|
||||
if "partial_t" in f:
|
||||
# Try to calculate a CA RMSD to input:
|
||||
aa_in = example["atom_array"]
|
||||
|
||||
Reference in New Issue
Block a user