mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 21:34:32 +08:00
Cleanup old code
This commit is contained in:
@@ -701,349 +701,6 @@ class BertForDiffusion(BertPreTrainedModel, pl.LightningModule):
|
||||
return retval
|
||||
|
||||
|
||||
class BertDenoiserEncoderModel(pl.LightningModule):
    """
    Encoder-only transformer denoiser written directly against torch.nn so
    that every component of the model is explicit.

    Input features are linearly projected to ``d_model``, combined with a
    positional encoding and a timestep embedding, passed through a stack of
    ``nn.TransformerEncoderLayer`` blocks, and decoded back to ``n_inputs``
    values per position. Training minimises a per-feature loss between the
    model output and ``batch["known_noise"]`` over the unmasked positions.
    """

    # Maps a loss name to either one callable (used for every feature) or a
    # list of callables indexed by feature position (see _get_loss_terms).
    loss_fn_dict = {
        "huber": F.smooth_l1_loss,
        "radian_l1": [F.smooth_l1_loss] + [losses.radian_l1_loss] * 3,
        "radian_l1_smooth": [
            F.smooth_l1_loss,
            *[
                functools.partial(losses.radian_smooth_l1_loss, beta=torch.pi / 10)
                for _ in range(3)
            ],
        ],
    }

    # Names under which the per-feature loss terms are logged, in feature order
    _feature_names = ("bond_dist", "omega", "theta", "phi")

    def __init__(
        self,
        n_inputs: int = 4,
        d_model: int = 256,
        num_layers: int = 6,
        intermediate_size: int = 512,
        max_seq_len: int = 512,
        num_heads: int = 8,
        dropout: float = 0.1,
        time_encoding: Literal["gaussian_fourier", "sinusoidal"] = "gaussian_fourier",
        decoder: Literal["mlp", "linear"] = "linear",
        loss: Union[
            Callable, Literal["huber", "radian_l1", "radian_l1_smooth"]
        ] = "huber",
        lr: float = 1e-4,
        l2: float = 0.0,
        l1: float = 0.0,
        circle_reg: float = 0.0,
        min_epochs: int = 500,
        steps_per_epoch: int = 100,  # Dummy value
        lr_scheduler: Optional[Literal["OneCycleLR"]] = None,
        write_preds_to_dir: Optional[str] = None,
    ) -> None:
        """
        Build the model.

        Args:
            n_inputs: number of features per sequence position (input and output).
            d_model: transformer embedding size.
            num_layers: number of encoder layers.
            intermediate_size: feed-forward width inside each encoder layer.
            max_seq_len: maximum length handed to the positional encoding.
            num_heads: number of attention heads.
            dropout: dropout used by the positional encoding and encoder layers.
            time_encoding: which timestep embedding to use.
            decoder: "linear" for a single nn.Linear head, "mlp" for AnglesPredictor.
            loss: a key of ``loss_fn_dict`` or a callable ``loss(pred, target)``.
            lr: AdamW learning rate.
            l2: AdamW weight decay.
            l1: weight of the L1 parameter penalty added in training_step.
            circle_reg: not implemented; any positive value raises.
            min_epochs: ``epochs`` argument for OneCycleLR.
            steps_per_epoch: ``steps_per_epoch`` argument for OneCycleLR.
            lr_scheduler: optional scheduler name; only "OneCycleLR" is known.
            write_preds_to_dir: if set, directory (created if missing) where
                validation_step dumps predictions as JSON.

        Raises:
            NotImplementedError: if ``circle_reg`` is positive.
            ValueError: on an unknown ``time_encoding`` or ``decoder``.
            KeyError: if ``loss`` is a string that is not in ``loss_fn_dict``.
        """
        super().__init__()
        self.n_inputs = n_inputs
        self.max_seq_len = max_seq_len
        self.learning_rate = lr
        self.l2_lambda = l2
        self.l1_lambda = l1
        self.circ_lambda = circle_reg
        if self.circ_lambda > 0:
            raise NotImplementedError("circle_reg is not supported")
        self.min_epochs = min_epochs
        self.steps_per_epoch = steps_per_epoch
        self.lr_scheduler = lr_scheduler

        self.loss_func = self.loss_fn_dict[loss] if isinstance(loss, str) else loss
        pl.utilities.rank_zero_info(f"Using loss: {self.loss_func}")

        # Positional embedding. Called as self.pos_encoder(x) and returns the
        # input + the positional embedding
        self.pos_encoder = PositionalEncoding(
            d_model, max_len=self.max_seq_len, dropout=dropout
        )

        # Timestep embedding
        if time_encoding == "gaussian_fourier":
            self.time_encoder = GaussianFourierProjection(d_model)
        elif time_encoding == "sinusoidal":
            self.time_encoder = SinusoidalPositionEmbeddings(d_model)
        else:
            raise ValueError(f"Unknown time encoding {time_encoding}")
        pl.utilities.rank_zero_info(f"Time encoding: {self.time_encoder}")

        self.num_layers = num_layers
        self.d_model = d_model
        self.dropout = dropout
        self.intermediate_size = intermediate_size
        self.num_heads = num_heads
        self.src_proj = nn.Linear(n_inputs, d_model)

        # Head projecting each token representation back to the n_inputs outputs
        if decoder == "linear":
            self.tgt_out = nn.Linear(d_model, n_inputs)
        elif decoder == "mlp":
            self.tgt_out = AnglesPredictor(d_model, n_inputs)
        else:
            raise ValueError(f"Unrecognized decoder: {decoder}")

        # The transformer itself; subclasses may override get_transformer
        # https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html#torch.nn.Transformer
        self.transformer = self.get_transformer()

        self._init_weights()

        self.write_preds_to_dir = write_preds_to_dir
        self.write_preds_counter = 0
        if self.write_preds_to_dir:
            os.makedirs(self.write_preds_to_dir, exist_ok=True)

    def get_transformer(self) -> nn.Module:
        """
        Return the transformer model. Kept as a separate method so that
        subclasses can swap in an alternative architecture.
        """
        # https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html#torch.nn.TransformerEncoderLayer
        enc_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=self.num_heads,
            dim_feedforward=self.intermediate_size,
            dropout=self.dropout,
            activation="gelu",
            layer_norm_eps=1e-5,
            batch_first=False,  # forward() permutes to (seq, batch, emb) and back
            norm_first=True,
        )
        # https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html#torch.nn.TransformerEncoder
        return nn.TransformerEncoder(enc_layer, num_layers=self.num_layers)

    def _init_weights(self) -> None:
        """Xavier-uniform initialise every transformer parameter with >1 dim."""
        for p in self.transformer.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Predict an output of the same shape as ``x``.

        Args:
            x: features of shape (batch, seq_len, n_inputs).
            timestep: timesteps of shape (batch, 1).
            attn_mask: optional (batch, seq_len) padding mask. It is forwarded
                unchanged as ``src_key_padding_mask``, so it must already use
                the PyTorch convention (True = position is ignored); use
                ensure_mask_fmt to convert a HuggingFace-style mask.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Upscale the features to (N, S, E) = (batch, seq_len, emb_size)
        x_upscaled = self.src_proj(x)

        # Add positional embeddings
        src_with_pos = self.pos_encoder(x_upscaled)

        # Add time embeddings
        # time_embed shape (batch, n_features) --> (batch, 1, n_features)
        time_embed = self.time_encoder(timestep.squeeze(1)).unsqueeze(1)
        assert time_embed.shape == (x.shape[0], 1, self.d_model)
        # Broadcasts over seq_len: (batch, seq_len, n_features) + (batch, 1, n_features)
        src_with_pos_time = src_with_pos + time_embed

        # Source mask, follows https://pytorch.org/tutorials/beginner/translation_transformer.html
        # True --> NOT allowed to attend, False --> allowed to attend.
        # All False: every position may attend to every other position.
        # Sized from the actual input and created on the input's device so that
        # sequences shorter than max_seq_len and non-CPU inputs both work.
        seq_len = x.shape[1]
        src_mask = torch.zeros((seq_len, seq_len), dtype=torch.bool, device=x.device)

        # Feed through transformer
        # shape (batch, seq_len, d_model) --> (seq_len, batch, d_model) --> (batch, seq_len, d_model)
        decoded = self.transformer(
            src_with_pos_time.permute(1, 0, 2),
            mask=src_mask,
            src_key_padding_mask=attn_mask,
        ).permute(1, 0, 2)

        # Decode to targets
        out = self.tgt_out(decoded)
        assert out.shape == x.shape
        return out

    def ensure_mask_fmt(self, mask: torch.Tensor) -> torch.BoolTensor:
        """
        Return ``mask`` as a boolean tensor in which True means "masked".

        HuggingFace uses the opposite convention (1/True = attend, 0/False =
        masked). The format is guessed from the first element only: if it is
        1 the mask is treated as HuggingFace-style and inverted, otherwise it
        is just cast to bool. This relies on the first position being an
        attended one.
        """
        assert torch.all(mask >= 0) and torch.all(mask <= 1)
        first_item = mask.flatten()[0]
        if not torch.isclose(first_item, torch.ones_like(first_item)):
            # Already 0 = attended
            return mask.bool()

        # First item is 1 --> HuggingFace standard where 1 = attended. Flip it.
        flipped_mask = ~(mask.bool())
        # Sanity checks: the number of attended positions is preserved and
        # every originally-attended position is now False
        assert torch.all(
            torch.sum(mask) == torch.numel(flipped_mask) - torch.sum(flipped_mask)
        )
        assert torch.all(flipped_mask[torch.where(mask)] == False)
        return flipped_mask.bool()

    def _get_loss_terms(self, batch, write_preds: Optional[str] = None) -> torch.Tensor:
        """
        Compute one loss term per feature over the unmasked positions.

        Args:
            batch: mapping with keys "known_noise", "corrupted", "attn_mask"
                and "t".
            write_preds: optional path; if given, targets, predictions, mask
                and loss values are dumped there as JSON.

        Returns:
            1-D tensor with one loss value per entry of the last dimension of
            ``batch["known_noise"]``.
        """
        known_noise = batch["known_noise"]
        corrupted = batch["corrupted"]
        # Make sure the attention mask is False for unmasked
        attn_mask = self.ensure_mask_fmt(batch["attn_mask"])
        assert (
            attn_mask.dtype == torch.bool
        ), f"{attn_mask} is not boolean - {attn_mask.dtype}"

        predicted_noise = self.forward(
            corrupted, timestep=batch["t"], attn_mask=attn_mask
        )

        # Under pytorch convention, 0 = not masked
        unmask_idx = torch.where(attn_mask == 0)
        assert len(unmask_idx) == 2
        batch_idx, pos_idx = unmask_idx

        per_feature_fns = isinstance(self.loss_func, list)
        loss_terms = []
        for i in range(known_noise.shape[-1]):
            loss_fn = self.loss_func[i] if per_feature_fns else self.loss_func
            logging.debug(f"Using loss function {loss_fn}")
            loss_terms.append(
                loss_fn(
                    predicted_noise[batch_idx, pos_idx, i],
                    known_noise[batch_idx, pos_idx, i],
                )
            )

        if write_preds is not None:
            # detach() so the dump also works when gradients are being tracked
            d_to_write = {
                "known_noise": known_noise.detach().cpu().tolist(),
                "predicted_noise": predicted_noise.detach().cpu().tolist(),
                "attn_mask": attn_mask.detach().cpu().tolist(),
                "losses": [term.item() for term in loss_terms],
            }
            with open(write_preds, "w") as f:
                json.dump(d_to_write, f)

        return torch.stack(loss_terms)

    def training_step(self, batch, batch_idx):
        """
        Training step: mean of the per-feature losses plus an optional L1
        parameter penalty. Logs each term and returns the scalar loss.
        """
        loss = self._get_loss_terms(batch)
        avg_loss = torch.mean(loss)

        # L1 regularization
        if self.l1_lambda > 0:
            # Sum of absolute values of every parameter. torch.linalg.norm(p, 1)
            # is NOT used here because for 2-D weights it computes the matrix
            # 1-norm (maximum absolute column sum) rather than the L1 penalty.
            l1_penalty = sum(p.abs().sum() for p in self.parameters())
            self.log("l1_penalty", l1_penalty, sync_dist=True, rank_zero_only=True)
            avg_loss = avg_loss + self.l1_lambda * l1_penalty

        for loss_name, loss_val in zip(self._feature_names, loss):
            self.log(
                f"train_{loss_name}", loss_val, sync_dist=True, rank_zero_only=True
            )
        self.log("train_loss", avg_loss, sync_dist=True, rank_zero_only=True)
        return avg_loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step: logs the per-feature and mean losses. Predictions are
        written to ``write_preds_to_dir`` only when that directory is set.
        """
        # write_preds_to_dir defaults to None; os.path.join(None, ...) would raise
        write_preds = None
        if self.write_preds_to_dir:
            write_preds = os.path.join(
                self.write_preds_to_dir, f"{self.write_preds_counter}_preds.json"
            )

        with torch.no_grad():
            loss_terms = self._get_loss_terms(batch, write_preds=write_preds)
            self.write_preds_counter += 1

        avg_loss = torch.mean(loss_terms)

        for loss_name, loss_val in zip(self._feature_names, loss_terms):
            self.log(f"val_{loss_name}", loss_val, sync_dist=True, rank_zero_only=True)
        self.log("val_loss", avg_loss, sync_dist=True, rank_zero_only=True)

    def configure_optimizers(self) -> Dict[str, Any]:
        """
        Return the AdamW optimizer and, if requested, a OneCycleLR scheduler.

        References:
        * https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.LightningModule.html
        * https://pytorch.org/docs/stable/optim.html

        Raises:
            ValueError: if ``lr_scheduler`` is set to an unknown name.
        """
        optim = torch.optim.AdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.l2_lambda,
        )
        retval = {"optimizer": optim}

        if self.lr_scheduler == "OneCycleLR":
            # NOTE: max_lr is fixed at 1e-2 and does not depend on self.learning_rate
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optim,
                max_lr=1e-2,
                epochs=self.min_epochs,
                steps_per_epoch=self.steps_per_epoch,
            )
            retval["lr_scheduler"] = {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "frequency": 1,
                "interval": "step",
            }
        elif self.lr_scheduler:
            raise ValueError(f"Unknown lr scheduler {self.lr_scheduler}")

        pl.utilities.rank_zero_info(f"Using optimizer {retval}")
        return retval
|
||||
|
||||
|
||||
class BertDenoiserSeq2SeqModel(BertDenoiserEncoderModel):
    """
    Placeholder for a seq2seq (encoder-decoder) variant of the denoiser.

    The transformer is not implemented yet: get_transformer always raises
    NotImplementedError.
    """

    def get_transformer(self) -> nn.Module:
        """
        Not implemented.

        The draft configuration was an nn.Transformer with
        d_model=self.d_model, nhead=8, num_encoder_layers=6,
        num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
        activation="gelu", norm_first=False, batch_first=True.

        Raises:
            NotImplementedError: always.
        """
        # Raise immediately instead of allocating a full nn.Transformer that
        # would be thrown away on the next line.
        raise NotImplementedError("seq2seq transformer is not implemented")

    def get_causal_tgt_mask(
        self, tgt_seq_len: Optional[int] = None
    ) -> torch.BoolTensor:
        """
        Build a causal mask for the target sequence.

        True means "not allowed to attend": row i is False for columns <= i
        and True for columns > i, so a position can never see later
        positions. Without this the decoder could simply pass through the
        known target sequence.

        Example output for a length of 4:
        # [F, T, T, T]
        # [F, F, T, T]
        # [F, F, F, T]
        # [F, F, F, F]

        Args:
            tgt_seq_len: side length of the square mask; defaults to
                ``self.max_seq_len``.
        """
        if tgt_seq_len is None:
            tgt_seq_len = self.max_seq_len
        # Strictly upper-triangular True == negation of the lower triangle
        # (diagonal included).
        # If a BoolTensor is provided, positions with True are not allowed to attend
        # while False values will be unchanged (i.e., True => masked)
        # If a FloatTensor is provided, it will be added to the attention weight
        all_true = torch.ones(tgt_seq_len, tgt_seq_len, dtype=torch.bool)
        return torch.triu(all_true, diagonal=1)
|
||||
|
||||
|
||||
def main():
|
||||
"""on the fly testing"""
|
||||
import datasets
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""
|
||||
Code for sampling from diffusion models
|
||||
"""
|
||||
from cmath import isnan
|
||||
import logging
|
||||
from typing import *
|
||||
|
||||
|
||||
@@ -165,193 +165,5 @@ class TestHuggingFaceBertModel(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
class TestBertDenoiserEncoderModel(unittest.TestCase):
    """
    Behavioural tests for modelling.BertDenoiserEncoderModel: determinism,
    batch-order invariance, mask-format conversion, and insensitivity of the
    unmasked outputs to changes at masked (padded) positions.
    """

    def setUp(self) -> None:
        self.model = modelling.BertDenoiserEncoderModel(
            d_model=32, num_heads=4, num_layers=6
        )
        self.model.eval()

        self.bs = 32
        rng = np.random.default_rng(6489)
        torch.random.manual_seed(6489)

        # Generate attention masks by huggingface convention (1 = attended).
        # Do not generate sequences in the last 5 so we never have attention there
        # This helps create an easy test for attention masking
        self.always_masked_residues = 5
        lengths = [
            rng.integers(100, self.model.max_seq_len - self.always_masked_residues)
            for _ in range(self.bs)
        ]
        self.attn_masks = torch.zeros((self.bs, self.model.max_seq_len))
        for i, length in enumerate(lengths):
            self.attn_masks[i][:length] = 1.0
            assert sum(self.attn_masks[i]) == length

        # forward() hands its attn_mask straight to the transformer as
        # src_key_padding_mask (True = ignored) and does no conversion itself,
        # so the model must be called with the converted mask.
        self.torch_attn_masks = self.model.ensure_mask_fmt(self.attn_masks)

        # Generate random timesteps
        self.timesteps = torch.from_numpy(
            rng.integers(0, 250, size=(self.bs, 1))
        ).long()

        # Generate random inputs
        self.inputs = torch.randn(
            (self.bs, self.model.max_seq_len, self.model.n_inputs)
        )

        # Generate noise vectors that are non-zero only at masked positions
        unmask_positions = torch.where(self.attn_masks == 1.0)
        self.noise_on_masked = torch.randn_like(self.inputs)
        # zero out the positions that are not masked so we do NOT noise them
        self.noise_on_masked[unmask_positions] = 0.0
        for i, length in enumerate(lengths):
            assert torch.all(
                torch.isclose(self.noise_on_masked[i][:length], torch.tensor(0.0))
            )

        # Inputs with noise on masked positions
        self.inputs_with_noise_on_mask = self.inputs + self.noise_on_masked
        for i, length in enumerate(lengths):
            # Check that the unmasked indices are unmodified
            assert torch.all(
                torch.isclose(
                    self.inputs_with_noise_on_mask[i][:length],
                    self.inputs[i][:length],
                )
            )

        assert (
            self.inputs.shape[0]
            == self.timesteps.shape[0]
            == self.attn_masks.shape[0]
            == self.bs
        )
        assert self.inputs.shape[1] == self.attn_masks.shape[1]

    def _predict(self, x, timestep=None, attn_mask=None):
        """
        Run the model without gradient tracking. timestep and attn_mask
        default to the fixtures built in setUp (the mask in PyTorch format).
        """
        if timestep is None:
            timestep = self.timesteps
        if attn_mask is None:
            attn_mask = self.torch_attn_masks
        with torch.no_grad():
            return self.model(x=x, timestep=timestep, attn_mask=attn_mask)

    def test_consistency(self):
        """
        Test that given the same input the model gives the same output
        """
        first = self._predict(self.inputs)
        second = self._predict(self.inputs)
        self.assertTrue(torch.allclose(first, second))

    def test_batch_order_consistency(self):
        """
        Test that the model is invariant to the order of inputs in a batch
        """
        # Run the inputs through as a "baseline" set of values
        x = self.inputs
        out = self._predict(x)

        # Shuffle the order of the inputs and run them through the model again, expect same output
        # https://pytorch.org/docs/stable/generated/torch.randperm.html
        idx = torch.randperm(x.shape[0])
        assert idx.shape[0] == x.shape[0]
        shuffled_out = self._predict(
            x[idx],
            timestep=self.timesteps[idx],
            attn_mask=self.torch_attn_masks[idx],
        )

        # Shuffle the known outputs to match
        out_reordered = out[idx]
        self.assertTrue(
            torch.allclose(out_reordered, shuffled_out, atol=ATOL, rtol=RTOL),
            msg=f"Got different outputs: {out_reordered.flatten()[:5]} {shuffled_out.flatten()[:5]}",
        )

    def test_batch_order_consistency_reversed(self):
        """
        Test that reversing the batch order of inputs does not change output
        """
        x = self.inputs
        out = self._predict(x)
        rev_out = self._predict(
            torch.flip(x, dims=(0,)),
            timestep=torch.flip(self.timesteps, dims=(0,)),
            attn_mask=torch.flip(self.torch_attn_masks, dims=(0,)),
        )

        self.assertEqual(self.bs, out.shape[0])
        self.assertEqual(self.bs, rev_out.shape[0])
        self.assertTrue(
            torch.allclose(torch.flip(out, dims=(0,)), rev_out, atol=ATOL, rtol=RTOL),
            msg=f"Mismatch on reversal: {out[-2]} != {rev_out[1]}",
        )

    def test_attn_mask_reformat(self):
        """
        Test that the mask format is detected and converted to correct
        PyTorch native format. Specifically, huggingface gives masked
        positions as 0.0, pytorch expects masked positions as True
        """
        converted_mask = self.model.ensure_mask_fmt(self.attn_masks)
        # huggingface masked indices are indicated by 0.
        orig_masked_indices = torch.where(self.attn_masks == 0.0)
        # pytorch masked indices are indicated by True
        conv_masked_indices = torch.where(converted_mask)
        for i, j in zip(orig_masked_indices, conv_masked_indices):
            self.assertTrue(torch.all(i == j))

    def test_noise_invariance_easy(self):
        """
        Easy test for noise invariance that focuses on the last few
        residues that should never be attended to
        """
        x = self.inputs
        out = self._predict(x)

        noise = torch.randn_like(x)
        noise[:, : -self.always_masked_residues] = 0.0
        # Check that there is no noise in the leading residues
        assert torch.all(noise[:, : -self.always_masked_residues] == 0.0)
        noised_out = self._predict(x + noise)

        unmasked_idx = torch.where(self.attn_masks == 1.0)
        self.assertTrue(
            torch.allclose(out[unmasked_idx], noised_out[unmasked_idx]),
            msg=f"Got different outputs: {out.flatten()[:5]} {noised_out.flatten()[:5]}",
        )

    def test_noise_invariance(self):
        """
        Test that noising masked positions should not affect output
        """
        # Run the inputs through as a "baseline" set of values
        x = self.inputs
        out = self._predict(x)

        # Noise the inputs and run them through the model again, expect same output
        noised_x = x + self.noise_on_masked
        unmasked_idx = torch.where(self.attn_masks == 1.0)
        # Check that the x and noised x agree at unmasked indices
        assert torch.allclose(x[unmasked_idx], noised_x[unmasked_idx])

        noised_out = self._predict(noised_x)

        self.assertTrue(
            torch.allclose(
                out[unmasked_idx], noised_out[unmasked_idx], atol=ATOL, rtol=RTOL
            ),
            msg=f"Got different outputs: {out[unmasked_idx].flatten()[:5]} {noised_out[unmasked_idx].flatten()[:5]}",
        )
|
||||
|
||||
|
||||
# Run every test case in this module when the file is executed as a script
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user