Basic unit tests for data loaders

This commit is contained in:
Kevin Wu
2022-08-26 22:10:44 +00:00
parent 972a1dddc8
commit 5b78ca5472

90
tests/test_data.py Normal file
View File

@@ -0,0 +1,90 @@
"""
Unit tests to test data loaders. These primarily check that the data loaders return values
with expected shapes and ranges.
"""
import os, sys
import unittest
import numpy as np
SRC_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "protdiff")
assert os.path.isdir(SRC_DIR)
sys.path.append(SRC_DIR)
import datasets
class TestCathCanonical(unittest.TestCase):
"""
Tests for the cath canonical angles dataset (i.e., not the trRosetta ones)
"""
def setUp(self) -> None:
# Setup the dataset
self.pad = 512
self.dset = datasets.CathCanonicalAnglesDataset(pad=self.pad)
def test_return_keys(self):
"""Test that returned dictionary has expected keys"""
d = self.dset[0]
self.assertEqual(set(d.keys()), set(["angles", "position_ids", "attn_mask"]))
def test_num_feature(self):
"""Test that we have the expected number of features"""
d = self.dset[0]
self.assertEqual(d["angles"].shape[1], 5)
def test_shapes(self):
"""Test that the returned tensors have expected shapes"""
d = self.dset[1]
self.assertEqual(
d["angles"].shape, (self.pad, len(self.dset.feature_names["angles"]))
)
self.assertEqual(d["position_ids"].shape, (self.pad,))
self.assertEqual(d["attn_mask"].shape, (self.pad,))
def test_angles(self):
"""Test that angles do not fall outside of -pi and pi range"""
d = self.dset[2]
angular_idx = np.where(self.dset.feature_is_angular["angles"])[0]
self.assertTrue(np.all(d["angles"].numpy()[..., angular_idx] >= -np.pi))
self.assertTrue(np.all(d["angles"].numpy()[..., angular_idx] <= np.pi))
class TestCathCanonicalAnglesOnly(unittest.TestCase):
"""
Tests for the CATCH canonical angles only dataset (i.e. no distance returned)
"""
def setUp(self) -> None:
self.pad = 512
self.dset = datasets.CathCanonicalAnglesOnlyDataset(pad=self.pad)
def test_return_keys(self):
"""Test that returned dictionary has expected keys"""
d = self.dset[0]
self.assertEqual(set(d.keys()), set(["angles", "position_ids", "attn_mask"]))
def test_num_features(self):
"""Test that we return the expected number of features and have correctly removed distance"""
d = self.dset[1]
self.assertEqual(d["angles"].shape[1], 4)
def test_all_angular(self):
"""Test that the dataset is all angular features and that this is properly registered"""
self.assertTrue(all(self.dset.feature_is_angular["angles"]))
def test_shapes(self):
"""Test that the returned tensors have expected shapes"""
d = self.dset[1]
self.assertEqual(
d["angles"].shape, (self.pad, len(self.dset.feature_names["angles"]))
)
self.assertEqual(d["position_ids"].shape, (self.pad,))
self.assertEqual(d["attn_mask"].shape, (self.pad,))
def test_angular_range(self):
"""Test that the returned angles are all between -pi and pi"""
d = self.dset[5]
self.assertTrue(np.all(d["angles"].numpy() >= -np.pi))
self.assertTrue(np.all(d["angles"].numpy() <= np.pi))