mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-07 15:34:23 +08:00
Test wrapped mean function
This commit is contained in:
@@ -4,6 +4,8 @@ Some custom metrics
|
||||
import functools
|
||||
import multiprocessing
|
||||
import logging
|
||||
from cmath import rect, phase
|
||||
from math import radians, degrees
|
||||
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
@@ -11,6 +13,8 @@ from scipy import stats
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import utils
|
||||
|
||||
|
||||
def kl_from_empirical(u: np.ndarray, v: np.ndarray, nbins: int = 100) -> float:
|
||||
"""
|
||||
@@ -75,3 +79,15 @@ def kl_from_dset(dset: Dataset, single_thread: bool = False) -> np.ndarray:
|
||||
pool.close()
|
||||
pool.join()
|
||||
return np.array(kl_values)
|
||||
|
||||
|
||||
def wrapped_mean(x: np.ndarray, min_val=-np.pi, max_val=np.pi) -> float:
|
||||
"""
|
||||
Wrap the mean function about the given range
|
||||
"""
|
||||
# https://rosettacode.org/wiki/Averages/Mean_angle
|
||||
sin_x = np.sin(x)
|
||||
cos_x = np.cos(x)
|
||||
|
||||
retval = np.arctan2(np.mean(sin_x), np.mean(cos_x))
|
||||
return retval
|
||||
|
||||
@@ -44,3 +44,51 @@ class TestKLFromEmpirical(unittest.TestCase):
|
||||
|
||||
kl = cm.kl_from_empirical(u, v)
|
||||
self.assertEqual(np.inf, kl)
|
||||
|
||||
|
||||
class TestWrappedMean(unittest.TestCase):
|
||||
"""Test for the wrapped mean function"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.rng = np.random.default_rng(seed=6489)
|
||||
self.rad2deg = lambda x: x * 180 / np.pi
|
||||
self.deg2rad = lambda x: x * np.pi / 180
|
||||
|
||||
def test_simple(self):
|
||||
"""Test a hand-engineered example"""
|
||||
true_mean = 170
|
||||
x = np.array([true_mean - 30, true_mean + 30])
|
||||
x_rad = self.deg2rad(x)
|
||||
m = cm.wrapped_mean(x_rad)
|
||||
m_deg = self.rad2deg(m)
|
||||
self.assertAlmostEqual(true_mean, m_deg, places=2)
|
||||
|
||||
def test_positive(self):
|
||||
"""Simple test"""
|
||||
x = self.rng.normal(loc=3.0, scale=0.25, size=100000)
|
||||
m = cm.wrapped_mean(x)
|
||||
self.assertAlmostEqual(m, 3.0, places=2)
|
||||
|
||||
def test_negative(self):
|
||||
"""Test that wrapping a negative mean works"""
|
||||
x = self.rng.normal(loc=-3.0, scale=0.25, size=100000)
|
||||
m = cm.wrapped_mean(x)
|
||||
self.assertAlmostEqual(m, -3.0, places=2)
|
||||
|
||||
def test_zero(self):
|
||||
"""Test that a zero mean is still correctly handled"""
|
||||
x = self.rng.normal(loc=0.0, scale=0.25, size=100000)
|
||||
m = cm.wrapped_mean(x)
|
||||
self.assertAlmostEqual(m, 0.0, places=2)
|
||||
|
||||
def test_positive_unwrapped(self):
|
||||
"""Test positive values that don't actually require wrapping"""
|
||||
x = self.rng.normal(loc=0.5, scale=0.25, size=100000)
|
||||
m = cm.wrapped_mean(x)
|
||||
self.assertAlmostEqual(m, 0.5, places=2)
|
||||
|
||||
def test_negative_unwrapped(self):
|
||||
"""Test negative values don't actually require wrapping"""
|
||||
x = self.rng.normal(loc=-0.5, scale=0.25, size=100000)
|
||||
m = cm.wrapped_mean(x)
|
||||
self.assertAlmostEqual(m, -0.5, places=2)
|
||||
|
||||
Reference in New Issue
Block a user