diff --git a/protdiff/modelling.py b/protdiff/modelling.py index 6d06bf6..f8f667c 100644 --- a/protdiff/modelling.py +++ b/protdiff/modelling.py @@ -701,349 +701,6 @@ class BertForDiffusion(BertPreTrainedModel, pl.LightningModule): return retval -class BertDenoiserEncoderModel(pl.LightningModule): - """ - Self implementation. Make sure that we know every bit of what goes into here so - there's no more issues - """ - - loss_fn_dict = { - "huber": F.smooth_l1_loss, - "radian_l1": [ - F.smooth_l1_loss, - losses.radian_l1_loss, - losses.radian_l1_loss, - losses.radian_l1_loss, - ], - "radian_l1_smooth": [ - F.smooth_l1_loss, - functools.partial(losses.radian_smooth_l1_loss, beta=torch.pi / 10), - functools.partial(losses.radian_smooth_l1_loss, beta=torch.pi / 10), - functools.partial(losses.radian_smooth_l1_loss, beta=torch.pi / 10), - ], - } - - def __init__( - self, - n_inputs: int = 4, - d_model: int = 256, - num_layers: int = 6, - intermediate_size: int = 512, - max_seq_len: int = 512, - num_heads: int = 8, - dropout: float = 0.1, - time_encoding: Literal["gaussian_fourier", "sinusoidal"] = "gaussian_fourier", - decoder: Literal["mlp", "linear"] = "linear", - loss: Union[ - Callable, Literal["huber", "radian_l1", "radian_l1_smooth"] - ] = "huber", - lr: float = 1e-4, - l2: float = 0.0, - l1: float = 0.0, - circle_reg: float = 0.0, - min_epochs: int = 500, - steps_per_epoch: int = 100, # Dummy value - lr_scheduler: Optional[Literal["OneCycleLR"]] = None, - write_preds_to_dir: Optional[str] = None, - ) -> None: - super().__init__() - self.n_inputs = n_inputs - self.max_seq_len = max_seq_len - self.learning_rate = lr - self.l2_lambda = l2 - self.l1_lambda = l1 - self.circ_lambda = circle_reg - if self.circ_lambda > 0: - raise NotImplementedError - self.min_epochs = min_epochs - self.steps_per_epoch = steps_per_epoch - self.lr_scheduler = lr_scheduler - - self.loss_func = self.loss_fn_dict[loss] if isinstance(loss, str) else loss - pl.utilities.rank_zero_info(f"Using loss: {self.loss_func}") - - # Define the positional embedding. Called as self.pos_encoder(x) and - # returns the input + the positional embedding - self.pos_encoder = PositionalEncoding( - d_model, max_len=self.max_seq_len, dropout=dropout - ) - - # Define the time embedding - if time_encoding == "gaussian_fourier": - self.time_encoder = GaussianFourierProjection(d_model) - elif time_encoding == "sinusoidal": - self.time_encoder = SinusoidalPositionEmbeddings(d_model) - else: - raise ValueError(f"Unknown time encoding {time_encoding}") - pl.utilities.rank_zero_info(f"Time encoding: {self.time_encoder}") - - self.num_layers = num_layers - self.d_model = d_model - self.dropout = dropout - self.intermediate_size = intermediate_size - self.num_heads = num_heads - self.src_proj = nn.Linear(n_inputs, d_model) - - # Set up the network to project token representation to our four outputs - if decoder == "linear": - self.tgt_out = nn.Linear(d_model, n_inputs) - elif decoder == "mlp": - self.tgt_out = AnglesPredictor(d_model, n_inputs) - else: - raise ValueError(f"Unrecognized decoder: {decoder}") - - # Define the transformer model itself - # https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html#torch.nn.Transformer - self.transformer = self.get_transformer() - - self._init_weights() - - self.write_preds_to_dir = write_preds_to_dir - self.write_preds_counter = 0 - if self.write_preds_to_dir: - os.makedirs(self.write_preds_to_dir, exist_ok=True) - - def get_transformer(self) -> nn.Module: - """ - Return the transformer model. Allows for easy overriding of the - transformer aspect of the model for alternative architectures - """ - # https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html#torch.nn.TransformerEncoderLayer - enc_layer = nn.TransformerEncoderLayer( - d_model=self.d_model, - nhead=self.num_heads, - dim_feedforward=self.intermediate_size, - dropout=self.dropout, - activation="gelu", - layer_norm_eps=1e-5, - batch_first=False, # Must do a permute to get batch first - norm_first=True, - ) - # https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html#torch.nn.TransformerEncoder - encoder = nn.TransformerEncoder(enc_layer, num_layers=self.num_layers) - return encoder - - def _init_weights(self) -> None: - # Initialize transformer with xavier uniform - for p in self.transformer.parameters(): - if p.dim() > 1: - nn.init.xavier_uniform_(p) - - def forward( - self, - x: torch.Tensor, - timestep: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - # Upscale the features to (N, S, E) = (batch, seq_len, emb_size) - x_upscaled = self.src_proj(x) - - # Add positional embeddings - src_with_pos = self.pos_encoder(x_upscaled) - - # Add time embeddings - # time_embed shape (batch, n_features) --> (batch, 1, n_features) - time_embed = self.time_encoder(timestep.squeeze(1)).unsqueeze(1) - assert time_embed.shape == (x.shape[0], 1, self.d_model) - # src_with_pos shape (batch, seq_len, n_features) - src_with_pos_time = src_with_pos + time_embed - - # Generate the src mask, follows https://pytorch.org/tutorials/beginner/translation_transformer.html - # True --> NOT allowed to attend - # False --> allowed to attend - # Generate a vector of False - src_mask = torch.zeros((self.max_seq_len, self.max_seq_len)).type(torch.bool) - # Feed through transformer - # shape (batch, seq_len, d_model) --> (seq_len, batch, d_model) --> (batch, seq_len, d_model) - decoded = self.transformer( - src_with_pos_time.permute(1, 0, 2), - mask=src_mask, - src_key_padding_mask=attn_mask, - ).permute(1, 0, 2) - - # Decode to targets - out = self.tgt_out(decoded) - assert out.shape == x.shape - return out - - def ensure_mask_fmt(self, mask: torch.Tensor) -> torch.BoolTensor: - """ - Ensure that the mask is given in the correct format (i.e., a True - value indicates masked and a False indicates not masked). This is - required because HuggingFace transformers use the opposite where - a 1/True value indicates a position to be attended and 0/False - indicates a position that is masked - """ - assert torch.all(mask >= 0) and torch.all(mask <= 1) - first_item = mask.flatten()[0] - # if the first item is a 1.0 then we know that we have received - # huggingface standard where 1 = attended. Flip to be 0 = attended - if torch.isclose(first_item, torch.ones_like(first_item)): - flipped_mask = ~(mask.bool()) - assert torch.all( - torch.sum(mask) == torch.numel(flipped_mask) - torch.sum(flipped_mask) - ) - assert torch.all(flipped_mask[torch.where(mask)] == False) - return flipped_mask.bool() - return mask.bool() - - def _get_loss_terms(self, batch, write_preds: Optional[str] = None) -> torch.Tensor: - """ - Gets the loss terms for the model - """ - known_noise = batch["known_noise"] - corrupted = batch["corrupted"] - # Make sure the attention mask is False for unmasked - attn_mask = self.ensure_mask_fmt(batch["attn_mask"]) - assert ( - attn_mask.dtype == torch.bool - ), f"{attn_mask} is not boolean - {attn_mask.dtype}" - - predicted_noise = self.forward( - corrupted, timestep=batch["t"], attn_mask=attn_mask - ) - - # Under pytorch convention, 0 = not masked - unmask_idx = torch.where(attn_mask == 0) - assert len(unmask_idx) == 2 - loss_terms = [] - for i in range(known_noise.shape[-1]): - loss_fn = ( - self.loss_func[i] - if isinstance(self.loss_func, list) - else self.loss_func - ) - logging.debug(f"Using loss function {loss_fn}") - - l = loss_fn( - predicted_noise[unmask_idx[0], unmask_idx[1], i], - known_noise[unmask_idx[0], unmask_idx[1], i], - ) - loss_terms.append(l) - - if write_preds is not None: - with open(write_preds, "w") as f: - d_to_write = { - "known_noise": known_noise.cpu().numpy().tolist(), - "predicted_noise": predicted_noise.cpu().numpy().tolist(), - "attn_mask": attn_mask.cpu().numpy().tolist(), - "losses": [l.item() for l in loss_terms], - } - json.dump(d_to_write, f) - - return torch.stack(loss_terms) - - def training_step(self, batch, batch_idx): - """ - Training step for the model - """ - loss = self._get_loss_terms(batch) - avg_loss = torch.mean(loss) - - # L1 regularization - if self.l1_lambda > 0: - l1_penalty = sum(torch.linalg.norm(p, 1) for p in self.parameters()) - self.log("l1_penalty", l1_penalty, sync_dist=True, rank_zero_only=True) - avg_loss += self.l1_lambda * l1_penalty - - for loss_name, loss_val in zip(["bond_dist", "omega", "theta", "phi"], loss): - self.log( - f"train_{loss_name}", loss_val, sync_dist=True, rank_zero_only=True - ) - self.log("train_loss", avg_loss, sync_dist=True, rank_zero_only=True) - return avg_loss - - def validation_step(self, batch, batch_idx): - with torch.no_grad(): - loss_terms = self._get_loss_terms( - batch, - write_preds=os.path.join( - self.write_preds_to_dir, f"{self.write_preds_counter}_preds.json" - ), - ) - self.write_preds_counter += 1 - - avg_loss = torch.mean(loss_terms) - - for loss_name, loss_val in zip( - ["bond_dist", "omega", "theta", "phi"], loss_terms - ): - self.log(f"val_{loss_name}", loss_val, sync_dist=True, rank_zero_only=True) - self.log("val_loss", avg_loss, sync_dist=True, rank_zero_only=True) - - def configure_optimizers(self) -> Dict[str, Any]: - """ - References: - * https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.LightningModule.html - * https://pytorch.org/docs/stable/optim.html - """ - optim = torch.optim.AdamW( - self.parameters(), - lr=self.learning_rate, - weight_decay=self.l2_lambda, - ) - retval = {"optimizer": optim} - if self.lr_scheduler: - if self.lr_scheduler == "OneCycleLR": - retval["lr_scheduler"] = { - "scheduler": torch.optim.lr_scheduler.OneCycleLR( - optim, - max_lr=1e-2, - epochs=self.min_epochs, - steps_per_epoch=self.steps_per_epoch, - ), - "monitor": "val_loss", - "frequency": 1, - "interval": "step", - } - else: - raise ValueError(f"Unknown lr scheduler {self.lr_scheduler}") - pl.utilities.rank_zero_info(f"Using optimizer {retval}") - return retval - - -class BertDenoiserSeq2SeqModel(BertDenoiserEncoderModel): - """ - Use a seq2seq model instead of a encoder only transformer - """ - - def get_transformer(self) -> nn.Module: - nn.Transformer( - d_model=self.d_model, - nhead=8, - num_encoder_layers=6, - num_decoder_layers=6, - dim_feedforward=2048, - dropout=0.1, - activation="gelu", - norm_first=False, - batch_first=True, - ) - raise NotImplementedError - - def get_causal_tgt_mask( - self, tgt_seq_len: Optional[int] = None - ) -> torch.BoolTensor: - """ - Get a causal mask for target sequence where each row allows only - the next token to be seen. This is important because otherwise - the decoder can simply pass through the known target sequence. - Example output: - # [F, T, T, T] - # [F, F, T, T] - # [F, F, F, T] - # [F, F, F, F] - """ - # Lower triangular matrix - if tgt_seq_len is None: - tgt_seq_len = self.max_seq_len - mask = ~torch.tril(torch.ones(tgt_seq_len, tgt_seq_len) == 1).bool() - # If a BoolTensor is provided, positions with True are not allowed to attend - # while False values will be unchanged (i.e., True => masked) - # If a FloatTensor is provided, it will be added to the attention weight - return mask - - def main(): """on the fly testing""" import datasets diff --git a/protdiff/sampling.py b/protdiff/sampling.py index 76f8a9e..77ba653 100644 --- a/protdiff/sampling.py +++ b/protdiff/sampling.py @@ -1,7 +1,6 @@ """ Code for sampling from diffusion models """ -from cmath import isnan import logging from typing import * diff --git a/tests/test_transformer_invariance.py b/tests/test_transformer_invariance.py index f34ca68..e68e373 100644 --- a/tests/test_transformer_invariance.py +++ b/tests/test_transformer_invariance.py @@ -165,193 +165,5 @@ class TestHuggingFaceBertModel(unittest.TestCase): ) -class TestBertDenoiserEncoderModel(unittest.TestCase): - def setUp(self) -> None: - self.model = modelling.BertDenoiserEncoderModel( - d_model=32, num_heads=4, num_layers=6 - ) - self.model.eval() - - self.bs = 32 - rng = np.random.default_rng(6489) - torch.random.manual_seed(6489) - - # Generate attention masks by huggingface convention - # These should be auto-converted to the pytorch convention - # Do not generate sequences in the last 5 so we never have attention there - # This helps create an easy test for attention masking - self.always_masked_residues = 5 - lengths = [ - rng.integers(100, self.model.max_seq_len - self.always_masked_residues) - for _ in range(self.bs) - ] - self.attn_masks = torch.zeros((self.bs, self.model.max_seq_len)) - for i, l in enumerate(lengths): - self.attn_masks[i][:l] = 1.0 - assert sum(self.attn_masks[i]) == l - - # Generate random timesteps - self.timesteps = torch.from_numpy( - rng.integers(0, 250, size=(self.bs, 1)) - ).long() - - # Generate random inputs - self.inputs = torch.randn( - (self.bs, self.model.max_seq_len, self.model.n_inputs) - ) - - # Generate noise vectors that correspond to masked positions - unmask_positions = torch.where(self.attn_masks == 1.0) - self.noise_on_masked = torch.randn_like(self.inputs) - # zero out the positions that are not masked so we do NOT noise them - self.noise_on_masked[unmask_positions] = 0.0 - for i, l in enumerate(lengths): - assert torch.all( - torch.isclose(self.noise_on_masked[i][:l], torch.tensor(0.0)) - ) - - # Inputs with noise on masked positions - self.inputs_with_noise_on_mask = self.inputs + self.noise_on_masked - for i, l in enumerate(lengths): - # Check that the unmaske indices are unmodified - assert torch.all( - torch.isclose(self.inputs_with_noise_on_mask[i][:l], self.inputs[i][:l]) - ) - - assert ( - self.inputs.shape[0] - == self.timesteps.shape[0] - == self.attn_masks.shape[0] - == self.bs - ) - assert self.inputs.shape[1] == self.attn_masks.shape[1] - - def test_consistency(self): - """ - Test that given the same input the model gives the same output - """ - x = self.model( - x=self.inputs, timestep=self.timesteps, attn_mask=self.attn_masks - ) - y = self.model( - x=self.inputs, timestep=self.timesteps, attn_mask=self.attn_masks - ) - self.assertTrue(torch.allclose(x, y)) - - def test_batch_order_consistency(self): - """ - Test that the model is invariant to the order of inputs in a batch - """ - # Run the inputs through as a "baseline" set of values - x = self.inputs - with torch.no_grad(): - out = self.model(x=x, timestep=self.timesteps, attn_mask=self.attn_masks) - - # Shuffle the order of the inputs and run them through the model again, expect same output - # https://pytorch.org/docs/stable/generated/torch.randperm.html - idx = torch.randperm(x.shape[0]) - assert idx.shape[0] == x.shape[0] - with torch.no_grad(): - shuffled_out = self.model( - x=x[idx], timestep=self.timesteps[idx], attn_mask=self.attn_masks[idx] - ) - # Shuffle the known outputs to match - out_reordered = out[idx] - self.assertTrue( - torch.allclose(out_reordered, shuffled_out, atol=ATOL, rtol=RTOL), - msg=f"Got different outputs: {out.flatten()[:5]} {shuffled_out.flatten()[:5]}", - ) - - def test_batch_order_consistency_reversed(self): - """ - Test that reversing the batch order of inputs does not change output - """ - x = self.inputs - with torch.no_grad(): - out = self.model(x=x, timestep=self.timesteps, attn_mask=self.attn_masks) - - with torch.no_grad(): - rev_out = self.model( - x=torch.flip(x, dims=(0,)), - timestep=torch.flip(self.timesteps, dims=(0,)), - attn_mask=torch.flip(self.attn_masks, dims=(0,)), - ) - - self.assertEqual(self.bs, out.shape[0]) - self.assertEqual(self.bs, rev_out.shape[0]) - self.assertTrue( - torch.allclose(torch.flip(out, dims=(0,)), rev_out, atol=ATOL, rtol=RTOL), - msg=f"Mismatch on reversal: {out[-2]} != {rev_out[1]}", - ) - - def test_attn_mask_reformat(self): - """ - Test that the mask format is detected and converted to correct - PyTorch native format. Specifically, huggingface gives masked - positions as 0.0, pytorch expects masked positions as True - """ - converted_mask = self.model.ensure_mask_fmt(self.attn_masks) - # huggingface masked indices are indicated by 0. - orig_masked_indices = torch.where(self.attn_masks == 0.0) - # pytorch masked indicies are indicated by True - conv_masked_indices = torch.where(converted_mask) - for i, j in zip(orig_masked_indices, conv_masked_indices): - self.assertTrue(torch.all(i == j)) - - def test_noise_invariance_easy(self): - """ - Easy test for noise invariance that focuses on the last few - residues that should never be attended to - """ - x = self.inputs - with torch.no_grad(): - out = self.model(x=x, timestep=self.timesteps, attn_mask=self.attn_masks) - - noise = torch.randn_like(x) - noise[:, : -self.always_masked_residues] = 0.0 - # Check that there is no noise in the leading residues - assert torch.all(noise[:, : -self.always_masked_residues] == 0.0) - noised_x = x + noise - - with torch.no_grad(): - noised_out = self.model( - x=noised_x, timestep=self.timesteps, attn_mask=self.attn_masks, - ) - - unmasked_idx = torch.where(self.attn_masks == 1.0) - self.assertTrue( - torch.allclose(out[unmasked_idx], noised_out[unmasked_idx]), - msg=f"Got different outputs: {out.flatten()[:5]} {noised_out.flatten()[:5]}", - ) - - def test_noise_invariance(self): - """ - Test that noising masked positions should not affect output - """ - # Run the inputs through as a "baseline" set of values - x = self.inputs - with torch.no_grad(): - out = self.model(x=x, timestep=self.timesteps, attn_mask=self.attn_masks) - - # Noise the inputs and run them through the model again, expect same output - noised_x = x + self.noise_on_masked - unmasked_idx = torch.where(self.attn_masks == 1.0) - # Check that the x and noised x agree at unmasked indices - assert torch.allclose(x[unmasked_idx], noised_x[unmasked_idx]) - - with torch.no_grad(): - noised_out = self.model( - x=noised_x, timestep=self.timesteps, attn_mask=self.attn_masks, - ) - - unmasked_idx = torch.where(self.attn_masks == 1.0) - self.assertTrue( - torch.allclose( - out[unmasked_idx], noised_out[unmasked_idx], atol=ATOL, rtol=RTOL - ), - msg=f"Got different outputs: {out[unmasked_idx].flatten()[:5]} {noised_out[unmasked_idx].flatten()[:5]}", - ) - - if __name__ == "__main__": unittest.main()