diff --git a/tests/test_data.py b/tests/test_data.py new file mode 100644 index 0000000..c14843d --- /dev/null +++ b/tests/test_data.py @@ -0,0 +1,90 @@ +""" +Unit tests to test data loaders. These primarily check that the data loaders return values +with expected shapes and ranges. +""" + +import os, sys +import unittest + +import numpy as np + +SRC_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "protdiff") +assert os.path.isdir(SRC_DIR) +sys.path.append(SRC_DIR) +import datasets + + +class TestCathCanonical(unittest.TestCase): + """ + Tests for the cath canonical angles dataset (i.e., not the trRosetta ones) + """ + + def setUp(self) -> None: + # Setup the dataset + self.pad = 512 + self.dset = datasets.CathCanonicalAnglesDataset(pad=self.pad) + + def test_return_keys(self): + """Test that returned dictionary has expected keys""" + d = self.dset[0] + self.assertEqual(set(d.keys()), set(["angles", "position_ids", "attn_mask"])) + + def test_num_feature(self): + """Test that we have the expected number of features""" + d = self.dset[0] + self.assertEqual(d["angles"].shape[1], 5) + + def test_shapes(self): + """Test that the returned tensors have expected shapes""" + d = self.dset[1] + self.assertEqual( + d["angles"].shape, (self.pad, len(self.dset.feature_names["angles"])) + ) + self.assertEqual(d["position_ids"].shape, (self.pad,)) + self.assertEqual(d["attn_mask"].shape, (self.pad,)) + + def test_angles(self): + """Test that angles do not fall outside of -pi and pi range""" + d = self.dset[2] + angular_idx = np.where(self.dset.feature_is_angular["angles"])[0] + self.assertTrue(np.all(d["angles"].numpy()[..., angular_idx] >= -np.pi)) + self.assertTrue(np.all(d["angles"].numpy()[..., angular_idx] <= np.pi)) + + +class TestCathCanonicalAnglesOnly(unittest.TestCase): + """ + Tests for the CATCH canonical angles only dataset (i.e. no distance returned) + """ + + def setUp(self) -> None: + self.pad = 512 + self.dset = datasets.CathCanonicalAnglesOnlyDataset(pad=self.pad) + + def test_return_keys(self): + """Test that returned dictionary has expected keys""" + d = self.dset[0] + self.assertEqual(set(d.keys()), set(["angles", "position_ids", "attn_mask"])) + + def test_num_features(self): + """Test that we return the expected number of features and have correctly removed distance""" + d = self.dset[1] + self.assertEqual(d["angles"].shape[1], 4) + + def test_all_angular(self): + """Test that the dataset is all angular features and that this is properly registered""" + self.assertTrue(all(self.dset.feature_is_angular["angles"])) + + def test_shapes(self): + """Test that the returned tensors have expected shapes""" + d = self.dset[1] + self.assertEqual( + d["angles"].shape, (self.pad, len(self.dset.feature_names["angles"])) + ) + self.assertEqual(d["position_ids"].shape, (self.pad,)) + self.assertEqual(d["attn_mask"].shape, (self.pad,)) + + def test_angular_range(self): + """Test that the returned angles are all between -pi and pi""" + d = self.dset[5] + self.assertTrue(np.all(d["angles"].numpy() >= -np.pi)) + self.assertTrue(np.all(d["angles"].numpy() <= np.pi)) \ No newline at end of file