mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
More tests
This commit is contained in:
@@ -8,7 +8,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from foldingdiff import datasets
|
||||
from foldingdiff import datasets, utils
|
||||
|
||||
|
||||
class TestCathCanonical(unittest.TestCase):
|
||||
@@ -25,7 +25,10 @@ class TestCathCanonical(unittest.TestCase):
|
||||
def test_return_keys(self):
|
||||
"""Test that returned dictionary has expected keys"""
|
||||
d = self.dset[0]
|
||||
self.assertEqual(set(d.keys()), set(["angles", "coords", "position_ids", "attn_mask", "lengths"]))
|
||||
self.assertEqual(
|
||||
set(d.keys()),
|
||||
set(["angles", "coords", "position_ids", "attn_mask", "lengths"]),
|
||||
)
|
||||
|
||||
def test_num_feature(self):
|
||||
"""Test that we have the expected number of features"""
|
||||
@@ -66,7 +69,10 @@ class TestCathCanonicalAnglesOnly(unittest.TestCase):
|
||||
def test_return_keys(self):
|
||||
"""Test that returned dictionary has expected keys"""
|
||||
d = self.dset[0]
|
||||
self.assertEqual(set(d.keys()), set(["angles", "position_ids", "attn_mask", "coords", "lengths"]))
|
||||
self.assertEqual(
|
||||
set(d.keys()),
|
||||
set(["angles", "position_ids", "attn_mask", "coords", "lengths"]),
|
||||
)
|
||||
|
||||
def test_num_features(self):
|
||||
"""Test that we return the expected number of features and have correctly removed distance"""
|
||||
@@ -146,3 +152,16 @@ class TestNoisedDataset(unittest.TestCase):
|
||||
x = self.noised_dset[1]["angles"]
|
||||
y = self.noised_dset[1]["angles"]
|
||||
self.assertTrue(torch.allclose(x, y))
|
||||
|
||||
def test_angles_reconstructed(self):
|
||||
"""Test that subtracting noise from corrupted angles (with constant scaling) recovers original angles"""
|
||||
d = self.noised_dset[3]
|
||||
noised_angles = d["corrupted"]
|
||||
orig_angles = d["angles"]
|
||||
noise = d["known_noise"]
|
||||
|
||||
recovered = (noised_angles - d["sqrt_one_minus_alphas_cumprod_t"] * noise) / d[
|
||||
"sqrt_alphas_cumprod_t"
|
||||
]
|
||||
recovered = utils.modulo_with_wrapped_range(recovered, -np.pi, np.pi)
|
||||
self.assertTrue(torch.allclose(recovered, orig_angles, atol=1e-5))
|
||||
|
||||
Reference in New Issue
Block a user