diff --git a/protdiff/custom_metrics.py b/protdiff/custom_metrics.py index fb73523..6428d1a 100644 --- a/protdiff/custom_metrics.py +++ b/protdiff/custom_metrics.py @@ -4,6 +4,8 @@ Some custom metrics import functools import multiprocessing import logging +from cmath import rect, phase +from math import radians, degrees import numpy as np from scipy import stats @@ -11,6 +13,8 @@ from scipy import stats import torch from torch.utils.data import Dataset +import utils + def kl_from_empirical(u: np.ndarray, v: np.ndarray, nbins: int = 100) -> float: """ @@ -75,3 +79,15 @@ def kl_from_dset(dset: Dataset, single_thread: bool = False) -> np.ndarray: pool.close() pool.join() return np.array(kl_values) + + +def wrapped_mean(x: np.ndarray, min_val=-np.pi, max_val=np.pi) -> float: + """ + Wrap the mean function about the given range + """ + # https://rosettacode.org/wiki/Averages/Mean_angle + sin_x = np.sin(x) + cos_x = np.cos(x) + + retval = np.arctan2(np.mean(sin_x), np.mean(cos_x)) + return retval diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 052eedf..f53c23f 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -44,3 +44,51 @@ class TestKLFromEmpirical(unittest.TestCase): kl = cm.kl_from_empirical(u, v) self.assertEqual(np.inf, kl) + + +class TestWrappedMean(unittest.TestCase): + """Test for the wrapped mean function""" + + def setUp(self) -> None: + self.rng = np.random.default_rng(seed=6489) + self.rad2deg = lambda x: x * 180 / np.pi + self.deg2rad = lambda x: x * np.pi / 180 + + def test_simple(self): + """Test a hand-engineered example""" + true_mean = 170 + x = np.array([true_mean - 30, true_mean + 30]) + x_rad = self.deg2rad(x) + m = cm.wrapped_mean(x_rad) + m_deg = self.rad2deg(m) + self.assertAlmostEqual(true_mean, m_deg, places=2) + + def test_positive(self): + """Simple test""" + x = self.rng.normal(loc=3.0, scale=0.25, size=100000) + m = cm.wrapped_mean(x) + self.assertAlmostEqual(m, 3.0, places=2) + + def test_negative(self): + """Test that wrapping a negative mean works""" + x = self.rng.normal(loc=-3.0, scale=0.25, size=100000) + m = cm.wrapped_mean(x) + self.assertAlmostEqual(m, -3.0, places=2) + + def test_zero(self): + """Test that a zero mean is still correctly handled""" + x = self.rng.normal(loc=0.0, scale=0.25, size=100000) + m = cm.wrapped_mean(x) + self.assertAlmostEqual(m, 0.0, places=2) + + def test_positive_unwrapped(self): + """Test positive values that don't actually require wrapping""" + x = self.rng.normal(loc=0.5, scale=0.25, size=100000) + m = cm.wrapped_mean(x) + self.assertAlmostEqual(m, 0.5, places=2) + + def test_negative_unwrapped(self): + """Test negative values don't actually require wrapping""" + x = self.rng.normal(loc=-0.5, scale=0.25, size=100000) + m = cm.wrapped_mean(x) + self.assertAlmostEqual(m, -0.5, places=2)