mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
48 lines
1.4 KiB
Python
48 lines
1.4 KiB
Python
"""
|
|
Unit tests for sampling code
|
|
"""
|
|
import os
|
|
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from foldingdiff import sampling
|
|
|
|
|
|
class TestSamplingReproducible(unittest.TestCase):
    """
    Test that sampling is reproducible
    """

    # Seed shared by the tests below; the specific value is arbitrary.
    SEED = 1234
    # Arguments forwarded to sampling.sample_simple on every draw.
    N_SAMPLES = 1
    SWEEP_LENGTHS = (50, 51)

    def setUp(self) -> None:
        """Locate the bundled mini model and name the full model to sample from."""
        self.mini_model = os.path.join(
            os.path.dirname(__file__), "mini_model_for_testing", "results"
        )
        # A unittest assertion (unlike a bare ``assert``) still runs under
        # ``python -O`` and reports which path is missing.
        self.assertTrue(
            os.path.isdir(self.mini_model),
            msg=f"Missing test model directory: {self.mini_model}",
        )
        # NOTE(review): looks like a remote model identifier, which may need
        # network access to resolve — confirm against sampling.sample_simple.
        self.full_model = "wukevin/foldingdiff"

    def _sample_once(self) -> np.ndarray:
        """
        Draw from the full model and return the last sample as an array.

        The caller is responsible for seeding the RNG beforehand.
        """
        samples = sampling.sample_simple(
            self.full_model,
            n=self.N_SAMPLES,
            # Pass a fresh list each call, as the original inline calls did.
            sweep_lengths=list(self.SWEEP_LENGTHS),
        )
        # .pop() takes the last returned sample; ``.values`` is presumably a
        # DataFrame -> ndarray conversion — TODO confirm the returned type.
        return samples.pop().values

    def test_repro_simple(self):
        """Test that sampling from the same random seed produces the same result"""
        # NOTE(review): only the torch RNG is seeded; if sample_simple also
        # draws from numpy/random, this test could be flaky — confirm.
        torch.manual_seed(self.SEED)
        values_1 = self._sample_once()
        torch.manual_seed(self.SEED)
        values_2 = self._sample_once()
        # np.allclose broadcasts its inputs, so e.g. (1, k) vs (m, k) could
        # compare "close"; reproducibility requires identical shapes first.
        self.assertEqual(values_1.shape, values_2.shape)
        self.assertTrue(np.allclose(values_1, values_2))

    def test_diff_seed_diff_result(self):
        """Test that having a different random state produces different results"""
        # Seed once and draw twice: the second draw starts from the RNG state
        # left behind by the first, so its output should differ.
        torch.manual_seed(self.SEED)
        values_1 = self._sample_once()
        values_2 = self._sample_once()
        # Differently shaped outputs are already different results; checking
        # shape first also keeps np.allclose from raising on them.
        self.assertFalse(
            values_1.shape == values_2.shape and np.allclose(values_1, values_2)
        )
|