diff --git a/tests/test_data.py b/tests/test_data.py index 48c3932..90c1646 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -25,7 +25,7 @@ class TestCathCanonical(unittest.TestCase): def test_return_keys(self): """Test that returned dictionary has expected keys""" d = self.dset[0] - self.assertEqual(set(d.keys()), set(["angles", "coords", "position_ids", "attn_mask"])) + self.assertEqual(set(d.keys()), set(["angles", "coords", "position_ids", "attn_mask", "lengths"])) def test_num_feature(self): """Test that we have the expected number of features""" @@ -66,7 +66,7 @@ class TestCathCanonicalAnglesOnly(unittest.TestCase): def test_return_keys(self): """Test that returned dictionary has expected keys""" d = self.dset[0] - self.assertEqual(set(d.keys()), set(["angles", "position_ids", "attn_mask"])) + self.assertEqual(set(d.keys()), set(["angles", "position_ids", "attn_mask", "coords", "lengths"])) def test_num_features(self): """Test that we return the expected number of features and have correctly removed distance"""