diff --git a/protdiff/modelling.py b/protdiff/modelling.py new file mode 100644 index 0000000..29b1fa8 --- /dev/null +++ b/protdiff/modelling.py @@ -0,0 +1,227 @@ +""" +Modelling +""" +import logging +import math +from typing import * + +import torch +from torch import nn + +from transformers import BertConfig +from transformers.models.bert.modeling_bert import ( + BertPreTrainedModel, + BertEncoder, + BertPooler, + BaseModelOutputWithPoolingAndCrossAttentions, +) + + +class SinusoidalPositionEmbeddings(nn.Module): + """ + Positional embeddings + """ + + def __init__(self, dim: int) -> None: + super().__init__() + self.dim = dim + + def forward(self, time): + device = time.device + half_dim = self.dim // 2 + embeddings = math.log(10000) / (half_dim - 1) + embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) + embeddings = time[:, None] * embeddings[None, :] + embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) + return embeddings + + +class BertForDiffusion(BertPreTrainedModel): + """ + BERT designed to be used with continuous inputs instead of tokens + + Reference: https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/bert/modeling_bert.py#L870 + """ + + def __init__(self, config, dim: int, add_pooling_layer: bool = False) -> None: + """ + dim should be the dimension of the inputs + """ + super().__init__(config) + self.config = config + + # Needed to project the low dimensional input to hidden dim + self.inputs_to_hidden_dim = nn.Linear( + in_features=4, out_features=config.hidden_size + ) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) if add_pooling_layer else None + + # Set up the network to project token representation to our four outputs + self.token_decoder = nn.Linear(config.hidden_size, 4) + + # Set up the time embedder + self.dim = dim + self.time_dim = dim * 4 + self.time_mlp = nn.Sequential( + SinusoidalPositionEmbeddings(dim), + nn.Linear(dim, self.time_dim), + nn.GELU(), + nn.Linear(self.time_dim, self.time_dim), + ) + + # Initialize weights and apply final processing + # self.post_init() + + def get_input_embeddings(self) -> nn.Module: + raise NotImplementedError + + def set_input_embeddings(self, value: nn.Module): + raise NotImplementedError() + + def forward( + self, + inputs: torch.Tensor, + timestep: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if inputs is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif inputs is not None: + input_shape = inputs.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length, *_ = input_shape + logging.debug(f"Detected batch {batch_size} and seq length {seq_length}") + device = inputs.device if inputs is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, input_shape, device=device + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + raise NotImplementedError + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + inputs_upscaled = self.inputs_to_hidden_dim(inputs) + encoder_outputs = self.encoder( + inputs_upscaled, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + time_encoded = self.time_mlp(timestep) + + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + per_token_decoded = self.token_decoder(sequence_output) + return per_token_decoded + + +def main(): + """on the fly testing""" + import datasets + from torch.utils.data.dataloader import default_collate + + clean_dset = datasets.CathConsecutiveAnglesDataset(toy=True) + noised_dset = datasets.NoisedAnglesDataset(clean_dset) + torch.utils.data.dataloader.default_collate + x = default_collate([noised_dset[i] for i in range(8)]) + print(x["corrupted"].shape, x["corrupted"].dtype) + print(x["t"].shape) + + # Create model + # device = torch.device("cuda") + model = BertForDiffusion(BertConfig(hidden_size=144), dim=512) + # print(model) + y = model.forward(x["corrupted"], x["t"].squeeze()) + print(y.shape) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + main()