mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
20 lines
600 B
Python
20 lines
600 B
Python
from typing import Dict
|
|
|
|
from rf2aa.metrics.predicted_error import PAE, PLDDT
|
|
|
|
|
|
class MetricManager:
|
|
def __init__(self, config) -> None:
|
|
self.config = config
|
|
self.metrics = {metric: metrics_factory[metric] for metric in config.metrics}
|
|
|
|
def __call__(self, rf_outputs, loss_calc_items) -> Dict:
|
|
metrics_dict = {}
|
|
for metric_name, metric in self.metrics:
|
|
metric_value = metric(rf_outputs, loss_calc_items)
|
|
metrics_dict[metric_name] = metric_value
|
|
return metrics_dict
|
|
|
|
|
|
metrics_factory = {"mean_pae": PAE(), "mean_plddt": PLDDT()}
|