Initial transformer model

This commit is contained in:
Kevin Wu
2022-07-06 01:25:36 +00:00
parent 8601dcb27e
commit 33dd1e2890

227
protdiff/modelling.py Normal file
View File

@@ -0,0 +1,227 @@
"""
Modelling
"""
import logging
import math
from typing import *
import torch
from torch import nn
from transformers import BertConfig
from transformers.models.bert.modeling_bert import (
BertPreTrainedModel,
BertEncoder,
BertPooler,
BaseModelOutputWithPoolingAndCrossAttentions,
)
class SinusoidalPositionEmbeddings(nn.Module):
"""
Positional embeddings
"""
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
def forward(self, time):
device = time.device
half_dim = self.dim // 2
embeddings = math.log(10000) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
embeddings = time[:, None] * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
return embeddings
class BertForDiffusion(BertPreTrainedModel):
"""
BERT designed to be used with continuous inputs instead of tokens
Reference: https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/bert/modeling_bert.py#L870
"""
def __init__(self, config, dim: int, add_pooling_layer: bool = False) -> None:
"""
dim should be the dimension of the inputs
"""
super().__init__(config)
self.config = config
# Needed to project the low dimensional input to hidden dim
self.inputs_to_hidden_dim = nn.Linear(
in_features=4, out_features=config.hidden_size
)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config) if add_pooling_layer else None
# Set up the network to project token representation to our four outputs
self.token_decoder = nn.Linear(config.hidden_size, 4)
# Set up the time embedder
self.dim = dim
self.time_dim = dim * 4
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(dim),
nn.Linear(dim, self.time_dim),
nn.GELU(),
nn.Linear(self.time_dim, self.time_dim),
)
# Initialize weights and apply final processing
# self.post_init()
def get_input_embeddings(self) -> nn.Module:
raise NotImplementedError
def set_input_embeddings(self, value: nn.Module):
raise NotImplementedError()
def forward(
self,
inputs: torch.Tensor,
timestep: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if inputs is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif inputs is not None:
input_shape = inputs.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length, *_ = input_shape
logging.debug(f"Detected batch {batch_size} and seq length {seq_length}")
device = inputs.device if inputs is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = (
past_key_values[0][0].shape[2] if past_key_values is not None else 0
)
if attention_mask is None:
attention_mask = torch.ones(
((batch_size, seq_length + past_key_values_length)), device=device
)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
attention_mask, input_shape, device=device
)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
raise NotImplementedError
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
inputs_upscaled = self.inputs_to_hidden_dim(inputs)
encoder_outputs = self.encoder(
inputs_upscaled,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
time_encoded = self.time_mlp(timestep)
sequence_output = encoder_outputs[0]
pooled_output = (
self.pooler(sequence_output) if self.pooler is not None else None
)
per_token_decoded = self.token_decoder(sequence_output)
return per_token_decoded
def main():
"""on the fly testing"""
import datasets
from torch.utils.data.dataloader import default_collate
clean_dset = datasets.CathConsecutiveAnglesDataset(toy=True)
noised_dset = datasets.NoisedAnglesDataset(clean_dset)
torch.utils.data.dataloader.default_collate
x = default_collate([noised_dset[i] for i in range(8)])
print(x["corrupted"].shape, x["corrupted"].dtype)
print(x["t"].shape)
# Create model
# device = torch.device("cuda")
model = BertForDiffusion(BertConfig(hidden_size=144), dim=512)
# print(model)
y = model.forward(x["corrupted"], x["t"].squeeze())
print(y.shape)
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
main()