mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
241 lines
8.5 KiB
Python
241 lines
8.5 KiB
Python
import os
|
|
import unittest
|
|
import tempfile
|
|
|
|
import numpy as np
|
|
import torch
|
|
from transformers import BertConfig
|
|
|
|
from foldingdiff import modelling
|
|
|
|
# Absolute and relative tolerances passed to torch.allclose in the tests below.
ATOL = 1e-6
RTOL = 1e-3
|
|
|
|
|
|
class TestHuggingFaceBertModel(unittest.TestCase):
    """
    Test the BERT model for diffusion built on HuggingFace library
    """

    def setUp(self) -> None:
        """Build an eval-mode model and a seeded random batch shared by the tests."""
        seed = 6489
        self.max_seq_len = 512

        # Seed before constructing the model so that any torch-RNG-driven weight
        # initialisation does not depend on what ran earlier in the process.
        torch.random.manual_seed(seed)
        self.cfg = BertConfig(max_position_embeddings=self.max_seq_len, use_cache=False)
        self.model = modelling.BertForDiffusion(config=self.cfg)
        self.model.eval()

        self.bs = 32
        rng = np.random.default_rng(seed)
        # Re-seed so the random tensors drawn below do not depend on how many
        # random numbers model construction consumed.
        torch.random.manual_seed(seed)

        # Generate attention masks by huggingface convention
        # These should be auto-converted to the pytorch convention
        # Do not generate sequences in the last 5 so we never have attention there
        # This helps create an easy test for attention masking
        self.always_masked_residues = 5
        lengths = [
            rng.integers(100, self.max_seq_len - self.always_masked_residues)
            for _ in range(self.bs)
        ]
        self.attn_masks = torch.zeros((self.bs, self.max_seq_len))
        for i, length in enumerate(lengths):
            self.attn_masks[i][:length] = 1.0

        # Generate random timesteps
        self.timesteps = torch.from_numpy(
            rng.integers(0, 250, size=(self.bs, 1))
        ).long()

        # Generate random inputs
        self.inputs = torch.randn((self.bs, self.max_seq_len, self.model.n_inputs))

        # Generate noise vectors that correspond to masked positions
        unmask_positions = torch.where(self.attn_masks == 1.0)
        self.noise_on_masked = torch.randn_like(self.inputs)
        # zero out the positions that are not masked so we do NOT noise them
        self.noise_on_masked[unmask_positions] = 0.0
        for i, length in enumerate(lengths):
            self.assertTrue(
                torch.all(
                    torch.isclose(self.noise_on_masked[i][:length], torch.tensor(0.0))
                )
            )

        # Inputs with noise on masked positions
        self.inputs_with_noise_on_mask = self.inputs + self.noise_on_masked
        for i, length in enumerate(lengths):
            # Check that the unmasked indices are unmodified
            self.assertTrue(
                torch.all(
                    torch.isclose(
                        self.inputs_with_noise_on_mask[i][:length],
                        self.inputs[i][:length],
                    )
                )
            )

        # set position ids
        self.position_ids = (
            torch.arange(0, self.max_seq_len)
            .unsqueeze(0)
            .expand(self.bs, self.max_seq_len)
        )

        # Sanity-check that every per-batch tensor agrees on the batch size and
        # that inputs and masks agree on the sequence length.
        self.assertEqual(self.inputs.shape[0], self.bs)
        self.assertEqual(self.timesteps.shape[0], self.bs)
        self.assertEqual(self.attn_masks.shape[0], self.bs)
        self.assertEqual(self.inputs.shape[1], self.attn_masks.shape[1])

    def _forward(self, inputs, timesteps=None, attn_masks=None, position_ids=None):
        """
        Run the model without gradient tracking and return its output.

        Any argument left as None falls back to the corresponding tensor
        prepared in setUp.
        """
        with torch.no_grad():
            return self.model(
                inputs=inputs,
                timestep=self.timesteps if timesteps is None else timesteps,
                attention_mask=self.attn_masks if attn_masks is None else attn_masks,
                position_ids=self.position_ids if position_ids is None else position_ids,
            )

    def test_consistency(self):
        """
        Test that given the same input the model gives the same output
        """
        x = self._forward(self.inputs)
        y = self._forward(self.inputs)
        self.assertTrue(torch.allclose(x, y))

    def test_noise_invariance_easy(self):
        """
        Easy test for noise invariance that focuses on the last few
        residues that should never be attended to
        """
        x = self.inputs
        out = self._forward(x)

        noise = torch.randn_like(x)
        noise[:, : -self.always_masked_residues] = 0.0
        # Check that there is no noise in the leading residues
        self.assertTrue(torch.all(noise[:, : -self.always_masked_residues] == 0.0))
        noised_x = x + noise

        noised_out = self._forward(noised_x)

        # Only compare positions the attention mask marks as attended (== 1).
        unmasked_idx = torch.where(self.attn_masks == 1.0)
        self.assertTrue(
            torch.allclose(
                out[unmasked_idx], noised_out[unmasked_idx], rtol=RTOL, atol=ATOL
            ),
            msg=f"Got different outputs: {out.flatten()[:5]} {noised_out.flatten()[:5]}",
        )

    def test_batch_order_agnostic(self):
        """
        Reversing the batch order of inputs does not change output values
        """
        x = self.inputs
        out = self._forward(x)

        # Flip every per-batch tensor along the batch axis so rows stay paired.
        rev_out = self._forward(
            torch.flip(x, dims=(0,)),
            timesteps=torch.flip(self.timesteps, dims=(0,)),
            attn_masks=torch.flip(self.attn_masks, dims=(0,)),
            position_ids=torch.flip(self.position_ids, dims=(0,)),
        )

        self.assertEqual(self.bs, out.shape[0])
        self.assertEqual(self.bs, rev_out.shape[0])
        self.assertTrue(
            torch.allclose(torch.flip(out, dims=(0,)), rev_out, atol=ATOL, rtol=RTOL),
            msg=f"Mismatch on reversal: {out[-2]} != {rev_out[1]}",
        )
|
|
|
|
|
|
class TestTransformerLoadingSaving(unittest.TestCase):
    """
    Test the loading, saving, and then re-loading of transformer models.
    """

    def setUp(self) -> None:
        """Locate the mini model directory that sits next to this test file."""
        self.orig_model_dir = os.path.join(
            os.path.dirname(__file__), "mini_model_for_testing", "results"
        )
        self.assertTrue(
            os.path.isdir(self.orig_model_dir),
            msg=f"Missing test model directory: {self.orig_model_dir}",
        )

    def _paired_parameters(self, model_a, model_b):
        """
        Return the two models' parameters as a list of (a, b) pairs.

        Fails the test if either model has no parameters or the counts differ;
        a bare zip() would silently truncate and let the test pass vacuously.
        """
        params_a = list(model_a.parameters())
        params_b = list(model_b.parameters())
        self.assertGreater(len(params_a), 0, msg="Model has no parameters")
        self.assertEqual(
            len(params_a), len(params_b), msg="Models differ in parameter count"
        )
        return list(zip(params_a, params_b))

    def test_saving_and_loading(self):
        """Test that we can load, save, and reload model"""
        with tempfile.TemporaryDirectory() as tempdir:
            orig_model = modelling.BertForDiffusion.from_dir(
                self.orig_model_dir, copy_to=tempdir
            )
            new_model = modelling.BertForDiffusion.from_dir(tempdir)

            # https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351
            for idx, (p1, p2) in enumerate(
                self._paired_parameters(orig_model, new_model)
            ):
                n_mismatched = p1.data.ne(p2.data).sum().item()
                self.assertEqual(
                    n_mismatched,
                    0,
                    msg=f"Parameter {idx} differs in {n_mismatched} elements",
                )

    def test_loading_without_model_checkpoint(self):
        """Test that loading ignores weights correctly"""
        with tempfile.TemporaryDirectory() as tempdir:
            orig_model = modelling.BertForDiffusion.from_dir(
                self.orig_model_dir, copy_to=tempdir
            )
            new_model = modelling.BertForDiffusion.from_dir(tempdir, load_weights=False)

            # Every parameter is expected to differ from the stored weights
            # when load_weights=False.
            for idx, (p1, p2) in enumerate(
                self._paired_parameters(orig_model, new_model)
            ):
                n_mismatched = p1.data.ne(p2.data).sum().item()
                self.assertNotEqual(
                    n_mismatched,
                    0,
                    msg=f"Parameter {idx} unexpectedly matches the saved weights",
                )
|
|
|
|
|
|
class TestTransformerBaseLoadingSaving(unittest.TestCase):
    """
    Test the loading and saving and re-loading of transformer models without
    pytorch lightning
    """

    def setUp(self) -> None:
        """Locate the mini model directory that sits next to this test file."""
        self.orig_model_dir = os.path.join(
            os.path.dirname(__file__), "mini_model_for_testing", "results"
        )
        self.assertTrue(
            os.path.isdir(self.orig_model_dir),
            msg=f"Missing test model directory: {self.orig_model_dir}",
        )

    def _assert_same_parameters(self, model_a, model_b):
        """
        Assert that both models hold element-wise identical parameters.

        The parameter lists must be non-empty and of equal length; a bare
        zip() would silently truncate and let the test pass vacuously.
        """
        params_a = list(model_a.parameters())
        params_b = list(model_b.parameters())
        self.assertGreater(len(params_a), 0, msg="Model has no parameters")
        self.assertEqual(
            len(params_a), len(params_b), msg="Models differ in parameter count"
        )
        # https://discuss.pytorch.org/t/check-if-models-have-same-weights/4351
        for idx, (p1, p2) in enumerate(zip(params_a, params_b)):
            n_mismatched = p1.data.ne(p2.data).sum().item()
            self.assertEqual(
                n_mismatched,
                0,
                msg=f"Parameter {idx} differs in {n_mismatched} elements",
            )

    def test_saving_and_loading(self):
        """Test that we can load, save, and reload model"""
        with tempfile.TemporaryDirectory() as tempdir:
            orig_model = modelling.BertForDiffusionBase.from_dir(
                self.orig_model_dir, copy_to=tempdir
            )
            new_model = modelling.BertForDiffusionBase.from_dir(tempdir)

            self._assert_same_parameters(orig_model, new_model)

    def test_against_pl(self):
        """
        Test that loading with or without pl lightning produces the same results
        """
        with tempfile.TemporaryDirectory() as tempdir:
            # NOTE(review): both loads copy into the same tempdir; confirm that
            # from_dir tolerates a destination that already holds a copy.
            nn_model = modelling.BertForDiffusionBase.from_dir(
                self.orig_model_dir, copy_to=tempdir
            )
            pl_model = modelling.BertForDiffusion.from_dir(
                self.orig_model_dir, copy_to=tempdir
            )
            self._assert_same_parameters(nn_model, pl_model)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run every TestCase in this module when the file is executed directly.
    unittest.main()
|