Add failing test case for loading without pytorch lightning

This commit is contained in:
Kevin Wu
2022-10-04 19:46:27 -07:00
parent 6ac88f7a44
commit 3f5408fe05

View File

@@ -197,5 +197,29 @@ class TestTransformerLoadingSaving(unittest.TestCase):
self.assertNotEqual(p1.data.ne(p2.data).sum(), 0)
class TestTransformerBaseLoadingSaving(unittest.TestCase):
"""
Test the loading and saving and re-loading of transformer models without
pytorch lightning
"""
def setUp(self) -> None:
self.orig_model_dir = os.path.join(
os.path.dirname(__file__), "mini_model_for_testing", "results"
)
assert os.path.isdir(self.orig_model_dir)
def test_saving_and_loading(self):
"""Test that we can load, save, and reload model"""
with tempfile.TemporaryDirectory() as tempdir:
orig_model = modelling.BertForDiffusionBase.from_dir(
self.orig_model_dir, copy_to=tempdir
)
new_model = modelling.BertForDiffusionBase.from_dir(tempdir)
# https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351
for p1, p2 in zip(orig_model.parameters(), new_model.parameters()):
self.assertAlmostEqual(p1.data.ne(p2.data).sum(), 0)
if __name__ == "__main__":
unittest.main()