Test wrapped mean function

This commit is contained in:
Kevin Wu
2022-08-30 19:26:02 -07:00
parent 187b924386
commit 80bb4b8c9f
2 changed files with 64 additions and 0 deletions

View File

@@ -4,6 +4,8 @@ Some custom metrics
import functools
import multiprocessing
import logging
from cmath import rect, phase
from math import radians, degrees
import numpy as np
from scipy import stats
@@ -11,6 +13,8 @@ from scipy import stats
import torch
from torch.utils.data import Dataset
import utils
def kl_from_empirical(u: np.ndarray, v: np.ndarray, nbins: int = 100) -> float:
"""
@@ -75,3 +79,15 @@ def kl_from_dset(dset: Dataset, single_thread: bool = False) -> np.ndarray:
pool.close()
pool.join()
return np.array(kl_values)
def wrapped_mean(x: np.ndarray, min_val=-np.pi, max_val=np.pi) -> float:
"""
Wrap the mean function about the given range
"""
# https://rosettacode.org/wiki/Averages/Mean_angle
sin_x = np.sin(x)
cos_x = np.cos(x)
retval = np.arctan2(np.mean(sin_x), np.mean(cos_x))
return retval

View File

@@ -44,3 +44,51 @@ class TestKLFromEmpirical(unittest.TestCase):
kl = cm.kl_from_empirical(u, v)
self.assertEqual(np.inf, kl)
class TestWrappedMean(unittest.TestCase):
"""Test for the wrapped mean function"""
def setUp(self) -> None:
self.rng = np.random.default_rng(seed=6489)
self.rad2deg = lambda x: x * 180 / np.pi
self.deg2rad = lambda x: x * np.pi / 180
def test_simple(self):
"""Test a hand-engineered example"""
true_mean = 170
x = np.array([true_mean - 30, true_mean + 30])
x_rad = self.deg2rad(x)
m = cm.wrapped_mean(x_rad)
m_deg = self.rad2deg(m)
self.assertAlmostEqual(true_mean, m_deg, places=2)
def test_positive(self):
"""Simple test"""
x = self.rng.normal(loc=3.0, scale=0.25, size=100000)
m = cm.wrapped_mean(x)
self.assertAlmostEqual(m, 3.0, places=2)
def test_negative(self):
"""Test that wrapping a negative mean works"""
x = self.rng.normal(loc=-3.0, scale=0.25, size=100000)
m = cm.wrapped_mean(x)
self.assertAlmostEqual(m, -3.0, places=2)
def test_zero(self):
"""Test that a zero mean is still correctly handled"""
x = self.rng.normal(loc=0.0, scale=0.25, size=100000)
m = cm.wrapped_mean(x)
self.assertAlmostEqual(m, 0.0, places=2)
def test_positive_unwrapped(self):
"""Test positive values that don't actually require wrapping"""
x = self.rng.normal(loc=0.5, scale=0.25, size=100000)
m = cm.wrapped_mean(x)
self.assertAlmostEqual(m, 0.5, places=2)
def test_negative_unwrapped(self):
"""Test negative values don't actually require wrapping"""
x = self.rng.normal(loc=-0.5, scale=0.25, size=100000)
m = cm.wrapped_mean(x)
self.assertAlmostEqual(m, -0.5, places=2)