Cleanup old code

This commit is contained in:
Kevin Wu
2022-09-01 23:10:49 +00:00
parent d18318b1cf
commit ffaffc5d1f
3 changed files with 0 additions and 532 deletions

View File

@@ -701,349 +701,6 @@ class BertForDiffusion(BertPreTrainedModel, pl.LightningModule):
return retval
class BertDenoiserEncoderModel(pl.LightningModule):
"""
Self implementation. Make sure that we know every bit of what goes into here so
there's no more issues
"""
loss_fn_dict = {
"huber": F.smooth_l1_loss,
"radian_l1": [
F.smooth_l1_loss,
losses.radian_l1_loss,
losses.radian_l1_loss,
losses.radian_l1_loss,
],
"radian_l1_smooth": [
F.smooth_l1_loss,
functools.partial(losses.radian_smooth_l1_loss, beta=torch.pi / 10),
functools.partial(losses.radian_smooth_l1_loss, beta=torch.pi / 10),
functools.partial(losses.radian_smooth_l1_loss, beta=torch.pi / 10),
],
}
def __init__(
self,
n_inputs: int = 4,
d_model: int = 256,
num_layers: int = 6,
intermediate_size: int = 512,
max_seq_len: int = 512,
num_heads: int = 8,
dropout: float = 0.1,
time_encoding: Literal["gaussian_fourier", "sinusoidal"] = "gaussian_fourier",
decoder: Literal["mlp", "linear"] = "linear",
loss: Union[
Callable, Literal["huber", "radian_l1", "radian_l1_smooth"]
] = "huber",
lr: float = 1e-4,
l2: float = 0.0,
l1: float = 0.0,
circle_reg: float = 0.0,
min_epochs: int = 500,
steps_per_epoch: int = 100, # Dummy value
lr_scheduler: Optional[Literal["OneCycleLR"]] = None,
write_preds_to_dir: Optional[str] = None,
) -> None:
super().__init__()
self.n_inputs = n_inputs
self.max_seq_len = max_seq_len
self.learning_rate = lr
self.l2_lambda = l2
self.l1_lambda = l1
self.circ_lambda = circle_reg
if self.circ_lambda > 0:
raise NotImplementedError
self.min_epochs = min_epochs
self.steps_per_epoch = steps_per_epoch
self.lr_scheduler = lr_scheduler
self.loss_func = self.loss_fn_dict[loss] if isinstance(loss, str) else loss
pl.utilities.rank_zero_info(f"Using loss: {self.loss_func}")
# Define the positional embedding. Called as self.pos_encoder(x) and
# returns the input + the positional embedding
self.pos_encoder = PositionalEncoding(
d_model, max_len=self.max_seq_len, dropout=dropout
)
# Define the time embedding
if time_encoding == "gaussian_fourier":
self.time_encoder = GaussianFourierProjection(d_model)
elif time_encoding == "sinusoidal":
self.time_encoder = SinusoidalPositionEmbeddings(d_model)
else:
raise ValueError(f"Unknown time encoding {time_encoding}")
pl.utilities.rank_zero_info(f"Time encoding: {self.time_encoder}")
self.num_layers = num_layers
self.d_model = d_model
self.dropout = dropout
self.intermediate_size = intermediate_size
self.num_heads = num_heads
self.src_proj = nn.Linear(n_inputs, d_model)
# Set up the network to project token representation to our four outputs
if decoder == "linear":
self.tgt_out = nn.Linear(d_model, n_inputs)
elif decoder == "mlp":
self.tgt_out = AnglesPredictor(d_model, n_inputs)
else:
raise ValueError(f"Unrecognized decoder: {decoder}")
# Define the transformer model itself
# https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html#torch.nn.Transformer
self.transformer = self.get_transformer()
self._init_weights()
self.write_preds_to_dir = write_preds_to_dir
self.write_preds_counter = 0
if self.write_preds_to_dir:
os.makedirs(self.write_preds_to_dir, exist_ok=True)
def get_transformer(self) -> nn.Module:
"""
Return the transformer model. Allows for easy overriding of the
transformer aspect of the model for alternative architectures
"""
# https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html#torch.nn.TransformerEncoderLayer
enc_layer = nn.TransformerEncoderLayer(
d_model=self.d_model,
nhead=self.num_heads,
dim_feedforward=self.intermediate_size,
dropout=self.dropout,
activation="gelu",
layer_norm_eps=1e-5,
batch_first=False, # Must do a permute to get batch first
norm_first=True,
)
# https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html#torch.nn.TransformerEncoder
encoder = nn.TransformerEncoder(enc_layer, num_layers=self.num_layers)
return encoder
def _init_weights(self) -> None:
# Initialize transformer with xavier uniform
for p in self.transformer.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(
self,
x: torch.Tensor,
timestep: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Upscale the features to (N, S, E) = (batch, seq_len, emb_size)
x_upscaled = self.src_proj(x)
# Add positional embeddings
src_with_pos = self.pos_encoder(x_upscaled)
# Add time embeddings
# time_embed shape (batch, n_features) --> (batch, 1, n_features)
time_embed = self.time_encoder(timestep.squeeze(1)).unsqueeze(1)
assert time_embed.shape == (x.shape[0], 1, self.d_model)
# src_with_pos shape (batch, seq_len, n_features)
src_with_pos_time = src_with_pos + time_embed
# Generate the src mask, follows https://pytorch.org/tutorials/beginner/translation_transformer.html
# True --> NOT allowed to attend
# False --> allowed to attend
# Generate a vector of False
src_mask = torch.zeros((self.max_seq_len, self.max_seq_len)).type(torch.bool)
# Feed through transformer
# shape (batch, seq_len, d_model) --> (seq_len, batch, d_model) --> (batch, seq_len, d_model)
decoded = self.transformer(
src_with_pos_time.permute(1, 0, 2),
mask=src_mask,
src_key_padding_mask=attn_mask,
).permute(1, 0, 2)
# Decode to targets
out = self.tgt_out(decoded)
assert out.shape == x.shape
return out
def ensure_mask_fmt(self, mask: torch.Tensor) -> torch.BoolTensor:
"""
Ensure that the mask is given in the correct format (i.e., a True
value indicates masked and a False indicates not masked). This is
required because HuggingFace transformers use the opposite where
a 1/True value indicates a position to be attended and 0/False
indicates a position that is masked
"""
assert torch.all(mask >= 0) and torch.all(mask <= 1)
first_item = mask.flatten()[0]
# if the first item is a 1.0 then we know that we have received
# huggingface standard where 1 = attended. Flip to be 0 = attended
if torch.isclose(first_item, torch.ones_like(first_item)):
flipped_mask = ~(mask.bool())
assert torch.all(
torch.sum(mask) == torch.numel(flipped_mask) - torch.sum(flipped_mask)
)
assert torch.all(flipped_mask[torch.where(mask)] == False)
return flipped_mask.bool()
return mask.bool()
def _get_loss_terms(self, batch, write_preds: Optional[str] = None) -> torch.Tensor:
"""
Gets the loss terms for the model
"""
known_noise = batch["known_noise"]
corrupted = batch["corrupted"]
# Make sure the attention mask is False for unmasked
attn_mask = self.ensure_mask_fmt(batch["attn_mask"])
assert (
attn_mask.dtype == torch.bool
), f"{attn_mask} is not boolean - {attn_mask.dtype}"
predicted_noise = self.forward(
corrupted, timestep=batch["t"], attn_mask=attn_mask
)
# Under pytorch convention, 0 = not masked
unmask_idx = torch.where(attn_mask == 0)
assert len(unmask_idx) == 2
loss_terms = []
for i in range(known_noise.shape[-1]):
loss_fn = (
self.loss_func[i]
if isinstance(self.loss_func, list)
else self.loss_func
)
logging.debug(f"Using loss function {loss_fn}")
l = loss_fn(
predicted_noise[unmask_idx[0], unmask_idx[1], i],
known_noise[unmask_idx[0], unmask_idx[1], i],
)
loss_terms.append(l)
if write_preds is not None:
with open(write_preds, "w") as f:
d_to_write = {
"known_noise": known_noise.cpu().numpy().tolist(),
"predicted_noise": predicted_noise.cpu().numpy().tolist(),
"attn_mask": attn_mask.cpu().numpy().tolist(),
"losses": [l.item() for l in loss_terms],
}
json.dump(d_to_write, f)
return torch.stack(loss_terms)
def training_step(self, batch, batch_idx):
"""
Training step for the model
"""
loss = self._get_loss_terms(batch)
avg_loss = torch.mean(loss)
# L1 regularization
if self.l1_lambda > 0:
l1_penalty = sum(torch.linalg.norm(p, 1) for p in self.parameters())
self.log("l1_penalty", l1_penalty, sync_dist=True, rank_zero_only=True)
avg_loss += self.l1_lambda * l1_penalty
for loss_name, loss_val in zip(["bond_dist", "omega", "theta", "phi"], loss):
self.log(
f"train_{loss_name}", loss_val, sync_dist=True, rank_zero_only=True
)
self.log("train_loss", avg_loss, sync_dist=True, rank_zero_only=True)
return avg_loss
def validation_step(self, batch, batch_idx):
with torch.no_grad():
loss_terms = self._get_loss_terms(
batch,
write_preds=os.path.join(
self.write_preds_to_dir, f"{self.write_preds_counter}_preds.json"
),
)
self.write_preds_counter += 1
avg_loss = torch.mean(loss_terms)
for loss_name, loss_val in zip(
["bond_dist", "omega", "theta", "phi"], loss_terms
):
self.log(f"val_{loss_name}", loss_val, sync_dist=True, rank_zero_only=True)
self.log("val_loss", avg_loss, sync_dist=True, rank_zero_only=True)
def configure_optimizers(self) -> Dict[str, Any]:
"""
References:
* https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.LightningModule.html
* https://pytorch.org/docs/stable/optim.html
"""
optim = torch.optim.AdamW(
self.parameters(),
lr=self.learning_rate,
weight_decay=self.l2_lambda,
)
retval = {"optimizer": optim}
if self.lr_scheduler:
if self.lr_scheduler == "OneCycleLR":
retval["lr_scheduler"] = {
"scheduler": torch.optim.lr_scheduler.OneCycleLR(
optim,
max_lr=1e-2,
epochs=self.min_epochs,
steps_per_epoch=self.steps_per_epoch,
),
"monitor": "val_loss",
"frequency": 1,
"interval": "step",
}
else:
raise ValueError(f"Unknown lr scheduler {self.lr_scheduler}")
pl.utilities.rank_zero_info(f"Using optimizer {retval}")
return retval
class BertDenoiserSeq2SeqModel(BertDenoiserEncoderModel):
"""
Use a seq2seq model instead of a encoder only transformer
"""
def get_transformer(self) -> nn.Module:
nn.Transformer(
d_model=self.d_model,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
activation="gelu",
norm_first=False,
batch_first=True,
)
raise NotImplementedError
def get_causal_tgt_mask(
self, tgt_seq_len: Optional[int] = None
) -> torch.BoolTensor:
"""
Get a causal mask for target sequence where each row allows only
the next token to be seen. This is important because otherwise
the decoder can simply pass through the known target sequence.
Example output:
# [F, T, T, T]
# [F, F, T, T]
# [F, F, F, T]
# [F, F, F, F]
"""
# Lower triangular matrix
if tgt_seq_len is None:
tgt_seq_len = self.max_seq_len
mask = ~torch.tril(torch.ones(tgt_seq_len, tgt_seq_len) == 1).bool()
# If a BoolTensor is provided, positions with True are not allowed to attend
# while False values will be unchanged (i.e., True => masked)
# If a FloatTensor is provided, it will be added to the attention weight
return mask
def main():
"""on the fly testing"""
import datasets

View File

@@ -1,7 +1,6 @@
"""
Code for sampling from diffusion models
"""
from cmath import isnan
import logging
from typing import *

View File

@@ -165,193 +165,5 @@ class TestHuggingFaceBertModel(unittest.TestCase):
)
class TestBertDenoiserEncoderModel(unittest.TestCase):
def setUp(self) -> None:
self.model = modelling.BertDenoiserEncoderModel(
d_model=32, num_heads=4, num_layers=6
)
self.model.eval()
self.bs = 32
rng = np.random.default_rng(6489)
torch.random.manual_seed(6489)
# Generate attention masks by huggingface convention
# These should be auto-converted to the pytorch convention
# Do not generate sequences in the last 5 so we never have attention there
# This helps create an easy test for attention masking
self.always_masked_residues = 5
lengths = [
rng.integers(100, self.model.max_seq_len - self.always_masked_residues)
for _ in range(self.bs)
]
self.attn_masks = torch.zeros((self.bs, self.model.max_seq_len))
for i, l in enumerate(lengths):
self.attn_masks[i][:l] = 1.0
assert sum(self.attn_masks[i]) == l
# Generate random timesteps
self.timesteps = torch.from_numpy(
rng.integers(0, 250, size=(self.bs, 1))
).long()
# Generate random inputs
self.inputs = torch.randn(
(self.bs, self.model.max_seq_len, self.model.n_inputs)
)
# Generate noise vectors that correspond to masked positions
unmask_positions = torch.where(self.attn_masks == 1.0)
self.noise_on_masked = torch.randn_like(self.inputs)
# zero out the positions that are not masked so we do NOT noise them
self.noise_on_masked[unmask_positions] = 0.0
for i, l in enumerate(lengths):
assert torch.all(
torch.isclose(self.noise_on_masked[i][:l], torch.tensor(0.0))
)
# Inputs with noise on masked positions
self.inputs_with_noise_on_mask = self.inputs + self.noise_on_masked
for i, l in enumerate(lengths):
# Check that the unmaske indices are unmodified
assert torch.all(
torch.isclose(self.inputs_with_noise_on_mask[i][:l], self.inputs[i][:l])
)
assert (
self.inputs.shape[0]
== self.timesteps.shape[0]
== self.attn_masks.shape[0]
== self.bs
)
assert self.inputs.shape[1] == self.attn_masks.shape[1]
def test_consistency(self):
"""
Test that given the same input the model gives the same output
"""
x = self.model(
x=self.inputs, timestep=self.timesteps, attn_mask=self.attn_masks
)
y = self.model(
x=self.inputs, timestep=self.timesteps, attn_mask=self.attn_masks
)
self.assertTrue(torch.allclose(x, y))
def test_batch_order_consistency(self):
"""
Test that the model is invariant to the order of inputs in a batch
"""
# Run the inputs through as a "baseline" set of values
x = self.inputs
with torch.no_grad():
out = self.model(x=x, timestep=self.timesteps, attn_mask=self.attn_masks)
# Shuffle the order of the inputs and run them through the model again, expect same output
# https://pytorch.org/docs/stable/generated/torch.randperm.html
idx = torch.randperm(x.shape[0])
assert idx.shape[0] == x.shape[0]
with torch.no_grad():
shuffled_out = self.model(
x=x[idx], timestep=self.timesteps[idx], attn_mask=self.attn_masks[idx]
)
# Shuffle the known outputs to match
out_reordered = out[idx]
self.assertTrue(
torch.allclose(out_reordered, shuffled_out, atol=ATOL, rtol=RTOL),
msg=f"Got different outputs: {out.flatten()[:5]} {shuffled_out.flatten()[:5]}",
)
def test_batch_order_consistency_reversed(self):
"""
Test that reversing the batch order of inputs does not change output
"""
x = self.inputs
with torch.no_grad():
out = self.model(x=x, timestep=self.timesteps, attn_mask=self.attn_masks)
with torch.no_grad():
rev_out = self.model(
x=torch.flip(x, dims=(0,)),
timestep=torch.flip(self.timesteps, dims=(0,)),
attn_mask=torch.flip(self.attn_masks, dims=(0,)),
)
self.assertEqual(self.bs, out.shape[0])
self.assertEqual(self.bs, rev_out.shape[0])
self.assertTrue(
torch.allclose(torch.flip(out, dims=(0,)), rev_out, atol=ATOL, rtol=RTOL),
msg=f"Mismatch on reversal: {out[-2]} != {rev_out[1]}",
)
def test_attn_mask_reformat(self):
"""
Test that the mask format is detected and converted to correct
PyTorch native format. Specifically, huggingface gives masked
positions as 0.0, pytorch expects masked positions as True
"""
converted_mask = self.model.ensure_mask_fmt(self.attn_masks)
# huggingface masked indices are indicated by 0.
orig_masked_indices = torch.where(self.attn_masks == 0.0)
# pytorch masked indicies are indicated by True
conv_masked_indices = torch.where(converted_mask)
for i, j in zip(orig_masked_indices, conv_masked_indices):
self.assertTrue(torch.all(i == j))
def test_noise_invariance_easy(self):
"""
Easy test for noise invariance that focuses on the last few
residues that should never be attended to
"""
x = self.inputs
with torch.no_grad():
out = self.model(x=x, timestep=self.timesteps, attn_mask=self.attn_masks)
noise = torch.randn_like(x)
noise[:, : -self.always_masked_residues] = 0.0
# Check that there is no noise in the leading residues
assert torch.all(noise[:, : -self.always_masked_residues] == 0.0)
noised_x = x + noise
with torch.no_grad():
noised_out = self.model(
x=noised_x, timestep=self.timesteps, attn_mask=self.attn_masks,
)
unmasked_idx = torch.where(self.attn_masks == 1.0)
self.assertTrue(
torch.allclose(out[unmasked_idx], noised_out[unmasked_idx]),
msg=f"Got different outputs: {out.flatten()[:5]} {noised_out.flatten()[:5]}",
)
def test_noise_invariance(self):
"""
Test that noising masked positions should not affect output
"""
# Run the inputs through as a "baseline" set of values
x = self.inputs
with torch.no_grad():
out = self.model(x=x, timestep=self.timesteps, attn_mask=self.attn_masks)
# Noise the inputs and run them through the model again, expect same output
noised_x = x + self.noise_on_masked
unmasked_idx = torch.where(self.attn_masks == 1.0)
# Check that the x and noised x agree at unmasked indices
assert torch.allclose(x[unmasked_idx], noised_x[unmasked_idx])
with torch.no_grad():
noised_out = self.model(
x=noised_x, timestep=self.timesteps, attn_mask=self.attn_masks,
)
unmasked_idx = torch.where(self.attn_masks == 1.0)
self.assertTrue(
torch.allclose(
out[unmasked_idx], noised_out[unmasked_idx], atol=ATOL, rtol=RTOL
),
msg=f"Got different outputs: {out[unmasked_idx].flatten()[:5]} {noised_out[unmasked_idx].flatten()[:5]}",
)
if __name__ == "__main__":
unittest.main()