More tests

This commit is contained in:
Kevin Wu
2022-10-31 15:30:35 -07:00
parent fabfb694cc
commit d4646309e5

View File

@@ -8,7 +8,7 @@ import unittest
import numpy as np
import torch
from foldingdiff import datasets
from foldingdiff import datasets, utils
class TestCathCanonical(unittest.TestCase):
@@ -25,7 +25,10 @@ class TestCathCanonical(unittest.TestCase):
def test_return_keys(self):
"""Test that returned dictionary has expected keys"""
d = self.dset[0]
self.assertEqual(set(d.keys()), set(["angles", "coords", "position_ids", "attn_mask", "lengths"]))
self.assertEqual(
set(d.keys()),
set(["angles", "coords", "position_ids", "attn_mask", "lengths"]),
)
def test_num_feature(self):
"""Test that we have the expected number of features"""
@@ -66,7 +69,10 @@ class TestCathCanonicalAnglesOnly(unittest.TestCase):
def test_return_keys(self):
"""Test that returned dictionary has expected keys"""
d = self.dset[0]
self.assertEqual(set(d.keys()), set(["angles", "position_ids", "attn_mask", "coords", "lengths"]))
self.assertEqual(
set(d.keys()),
set(["angles", "position_ids", "attn_mask", "coords", "lengths"]),
)
def test_num_features(self):
"""Test that we return the expected number of features and have correctly removed distance"""
@@ -146,3 +152,16 @@ class TestNoisedDataset(unittest.TestCase):
x = self.noised_dset[1]["angles"]
y = self.noised_dset[1]["angles"]
self.assertTrue(torch.allclose(x, y))
def test_angles_reconstructed(self):
"""Test that subtracting noise from corrupted angles (with constant scaling) recovers original angles"""
d = self.noised_dset[3]
noised_angles = d["corrupted"]
orig_angles = d["angles"]
noise = d["known_noise"]
recovered = (noised_angles - d["sqrt_one_minus_alphas_cumprod_t"] * noise) / d[
"sqrt_alphas_cumprod_t"
]
recovered = utils.modulo_with_wrapped_range(recovered, -np.pi, np.pi)
self.assertTrue(torch.allclose(recovered, orig_angles, atol=1e-5))