# Files
# foldingdiff/tests/test_sampling.py
# 2022-11-08 14:41:39 -08:00
#
# 48 lines
# 1.4 KiB
# Python

"""
Unit tests for sampling code
"""
import os
import unittest
import numpy as np
import torch
from foldingdiff import sampling
class TestSamplingReproducible(unittest.TestCase):
    """
    Test that sampling is reproducible.

    Every test draws a single sample from ``self.full_model`` through
    ``sampling.sample_simple``. ``setUp`` also checks that the bundled mini
    test model directory exists next to this file.
    """

    # Seed used by both tests; reseeding with it must reproduce a sample.
    _SEED = 1234
    # Lengths passed as ``sweep_lengths`` on every sampling call. Stored as a
    # tuple and copied to a fresh list per call, so no call can see another
    # call's mutations.
    _SWEEP_LENGTHS = (50, 51)

    def setUp(self) -> None:
        """Locate the bundled mini model and name the full model to sample from."""
        self.mini_model = os.path.join(
            os.path.dirname(__file__), "mini_model_for_testing", "results"
        )
        # A unittest assertion (not a bare ``assert``) so the fixture check
        # still runs under ``python -O``.
        self.assertTrue(
            os.path.isdir(self.mini_model),
            f"Missing mini test model directory: {self.mini_model}",
        )
        # NOTE(review): presumably a model identifier that
        # ``sampling.sample_simple`` resolves/downloads; confirm there.
        self.full_model = "wukevin/foldingdiff"

    def _sample_once(self):
        """Draw one sample from ``self.full_model`` at the shared sweep lengths.

        Calls ``sampling.sample_simple(..., n=1, sweep_lengths=[50, 51])``
        and returns the last element of the returned list (via ``.pop()``).
        """
        return sampling.sample_simple(
            self.full_model, n=1, sweep_lengths=list(self._SWEEP_LENGTHS)
        ).pop()

    def test_repro_simple(self):
        """Test that sampling from the same random seed produces the same result"""
        torch.manual_seed(self._SEED)
        samp_1 = self._sample_once()
        torch.manual_seed(self._SEED)
        samp_2 = self._sample_once()
        # Same tolerances and NaN handling as np.allclose's defaults, but a
        # failure reports which entries differ and by how much.
        np.testing.assert_allclose(
            samp_1.values,
            samp_2.values,
            rtol=1e-05,
            atol=1e-08,
            equal_nan=False,
        )

    def test_diff_seed_diff_result(self):
        """Test that having a different random state produces different results"""
        torch.manual_seed(self._SEED)
        # Deliberately no reseed between draws: the second sample starts from
        # whatever torch RNG state the first draw left behind (presumably
        # advanced by it), so the two samples are expected to differ.
        samp_1 = self._sample_once()
        samp_2 = self._sample_once()
        self.assertFalse(np.allclose(samp_1.values, samp_2.values))