From 3f5408fe05128e077c059014f21d3f420aef0750 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Tue, 4 Oct 2022 19:46:27 -0700 Subject: [PATCH] Add failing test case for loading without pytorch lightning --- tests/test_transformer.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 5a992fd..16621b3 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -197,5 +197,29 @@ class TestTransformerLoadingSaving(unittest.TestCase): self.assertNotEqual(p1.data.ne(p2.data).sum(), 0) +class TestTransformerBaseLoadingSaving(unittest.TestCase): + """ + Test the loading and saving and re-loading of transformer models without + pytorch lightning + """ + def setUp(self) -> None: + self.orig_model_dir = os.path.join( + os.path.dirname(__file__), "mini_model_for_testing", "results" + ) + assert os.path.isdir(self.orig_model_dir) + + + def test_saving_and_loading(self): + """Test that we can load, save, and reload model""" + with tempfile.TemporaryDirectory() as tempdir: + orig_model = modelling.BertForDiffusionBase.from_dir( + self.orig_model_dir, copy_to=tempdir + ) + new_model = modelling.BertForDiffusionBase.from_dir(tempdir) + + # https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351 + for p1, p2 in zip(orig_model.parameters(), new_model.parameters()): + self.assertAlmostEqual(p1.data.ne(p2.data).sum(), 0) + if __name__ == "__main__": unittest.main()