mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Add failing test case for loading without pytorch lightning
This commit is contained in:
@@ -197,5 +197,29 @@ class TestTransformerLoadingSaving(unittest.TestCase):
|
||||
self.assertNotEqual(p1.data.ne(p2.data).sum(), 0)
|
||||
|
||||
|
||||
class TestTransformerBaseLoadingSaving(unittest.TestCase):
|
||||
"""
|
||||
Test the loading and saving and re-loading of transformer models without
|
||||
pytorch lightning
|
||||
"""
|
||||
def setUp(self) -> None:
|
||||
self.orig_model_dir = os.path.join(
|
||||
os.path.dirname(__file__), "mini_model_for_testing", "results"
|
||||
)
|
||||
assert os.path.isdir(self.orig_model_dir)
|
||||
|
||||
|
||||
def test_saving_and_loading(self):
|
||||
"""Test that we can load, save, and reload model"""
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
orig_model = modelling.BertForDiffusionBase.from_dir(
|
||||
self.orig_model_dir, copy_to=tempdir
|
||||
)
|
||||
new_model = modelling.BertForDiffusionBase.from_dir(tempdir)
|
||||
|
||||
# https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351
|
||||
for p1, p2 in zip(orig_model.parameters(), new_model.parameters()):
|
||||
self.assertAlmostEqual(p1.data.ne(p2.data).sum(), 0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user