mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-07 15:34:23 +08:00
317 lines
12 KiB
Python
317 lines
12 KiB
Python
"""
|
|
Modelling
|
|
"""
|
|
import logging
|
|
import math
|
|
from typing import *
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
import pytorch_lightning as pl
|
|
|
|
from transformers import BertConfig
|
|
from transformers.models.bert.modeling_bert import (
|
|
BertPreTrainedModel,
|
|
BertEncoder,
|
|
BertPooler,
|
|
)
|
|
|
|
import losses
|
|
|
|
|
|
class SinusoidalPositionEmbeddings(nn.Module):
    """
    Sinusoidal embeddings of one scalar per batch element (e.g. a diffusion timestep).

    For ``time`` of shape (batch,), ``forward`` returns a (batch, dim) tensor whose
    first ``dim // 2`` columns are ``sin(t * w_k)`` and next ``dim // 2`` columns
    are ``cos(t * w_k)``. The frequencies ``w_k`` are geometrically spaced from
    1 down to 1/10000. When ``dim`` is odd, a final all-zero column is appended
    so the output always has exactly ``dim`` columns.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.dim = dim

    def forward(self, time: torch.Tensor) -> torch.Tensor:
        """
        Embed ``time``.

        Args:
            time: tensor of shape (batch,); a 0-d tensor is treated as batch of 1.

        Returns:
            Tensor of shape (batch, dim).
        """
        device = time.device
        half_dim = self.dim // 2
        # Guard the denominator: with half_dim == 1 (dim 2 or 3) the plain
        # formula log(10000) / (half_dim - 1) would divide by zero.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        # (half_dim,) frequencies exp(-k * log_step)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -log_step)
        if time.dim() == 0:
            # Give a scalar timestep a batch axis so the outer product works
            time = time.unsqueeze(0)
        # outer product (batch, 1) x (1, half_dim) -> (batch, half_dim)
        args = time[:, None] * freqs[None, :]
        # sin and cosine embeddings
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos only fill 2 * (dim // 2) columns; pad the last axis so the
            # result can be added to dim-sized hidden states.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings
|
|
|
|
|
|
class BertForDiffusion(BertPreTrainedModel, pl.LightningModule):
    """
    BERT encoder used as a noise predictor over continuous inputs instead of tokens.

    Each input position carries 4 features. They are linearly projected to
    ``config.hidden_size`` and summed with a sinusoidal embedding of the
    diffusion timestep. The result is encoded by a ``BertEncoder`` and projected
    back to 4 values per position.

    Reference: https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/bert/modeling_bert.py#L870
    """

    def __init__(
        self,
        config,
        lr: float = 1e-4,
        loss: Literal["huber", "radian_l1"] = "huber",
        l2: float = 0.0,
        l1: float = 0.0,
        add_pooling_layer: bool = False,
    ) -> None:
        """
        Args:
            config: Hugging Face BERT configuration for the encoder (e.g. ``BertConfig``).
            lr: learning rate for the Adam optimizer built in ``configure_optimizers``.
            loss: ``"huber"`` applies ``F.smooth_l1_loss`` to all four features;
                ``"radian_l1"`` applies ``F.smooth_l1_loss`` to the first feature
                and ``losses.radian_l1_loss`` to the remaining three.
            l2: weight decay passed to Adam.
            l1: weight of an L1 penalty (sum of absolute parameter values) added
                to the training loss; disabled when 0.
            add_pooling_layer: build a ``BertPooler``; ``forward`` does not use it.

        Raises:
            KeyError: if ``loss`` is not one of the supported names.
        """
        super().__init__(config)
        self.config = config

        # Store information about learning rates and loss
        self.learning_rate = lr
        # The loss function is either one callable applied to every feature, or
        # a list holding one callable per feature position.
        self.loss_func = {
            "huber": F.smooth_l1_loss,
            "radian_l1": [
                F.smooth_l1_loss,
                losses.radian_l1_loss,
                losses.radian_l1_loss,
                losses.radian_l1_loss,
            ],
        }[loss]
        self.l1_lambda = l1
        self.l2_lambda = l2

        # Project the low dimensional (4-feature) input to the hidden dim
        self.inputs_to_hidden_dim = nn.Linear(
            in_features=4, out_features=config.hidden_size
        )
        self.encoder = BertEncoder(config)
        # Only built on request; its output is not used by forward().
        self.pooler = BertPooler(config) if add_pooling_layer else None

        # Project each token representation to our four outputs
        self.token_decoder = nn.Linear(config.hidden_size, 4)

        # Embeds the diffusion timestep into hidden_size dimensions
        self.time_embed = SinusoidalPositionEmbeddings(config.hidden_size)

        # Initialize weights and apply final processing
        # self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Not supported: inputs are projected by a linear layer, not a token embedding."""
        raise NotImplementedError

    def set_input_embeddings(self, value: nn.Module):
        """Not supported: inputs are projected by a linear layer, not a token embedding."""
        raise NotImplementedError()

    def forward(
        self,
        inputs: torch.Tensor,
        timestep: torch.Tensor,  # (batch,) or (batch, 1) timesteps
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Predict a 4-value output for every position of the noised ``inputs``.

        Args:
            inputs: continuous inputs of shape (batch, seq_len, 4). May be None
                when ``inputs_embeds`` is given instead.
            timestep: per-example timesteps of shape (batch,) or (batch, 1).
            attention_mask: (batch, seq_len) mask, 1 for positions to attend to
                and 0 for masked ones; defaults to all ones.
            position_ids: accepted for API compatibility; not used.
            head_mask: optional mask over attention heads (see ``get_head_mask``).
            inputs_embeds: inputs already projected to
                (batch, seq_len, hidden_size), used in place of ``inputs``.
            encoder_hidden_states: cross-attention states. Raises
                NotImplementedError if the config is a decoder and these are given.
            encoder_attention_mask: cross-attention mask (see above).
            past_key_values: cached key/value states; their sequence length
                extends the default attention mask.
            use_cache: forwarded to the encoder; forced to False unless
                ``config.is_decoder``.
            output_attentions: forwarded to the encoder; defaults to the
                corresponding ``config`` value.
            output_hidden_states: forwarded to the encoder; defaults to the
                corresponding ``config`` value.
            return_dict: forwarded to the encoder; defaults to the
                corresponding ``config`` value.

        Returns:
            Tensor of shape (batch, seq_len, 4).

        Raises:
            ValueError: if both or neither of ``inputs`` / ``inputs_embeds`` are given.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if inputs is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both inputs and inputs_embeds at the same time"
            )
        elif inputs is not None:
            # Drop the trailing feature axis so input_shape is (batch, seq_len),
            # matching the inputs_embeds branch and what the decoder causal-mask
            # path of get_extended_attention_mask unpacks.
            input_shape = inputs.size()[:-1]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either inputs or inputs_embeds")

        batch_size, seq_length, *_ = input_shape
        logging.debug("Detected batch %d and seq length %d", batch_size, seq_length)
        device = inputs.device if inputs is not None else inputs_embeds.device

        # Length of any cached prefix, to size the default attention mask
        past_key_values_length = (
            past_key_values[0][0].shape[2] if past_key_values is not None else 0
        )

        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length + past_key_values_length), device=device
            )

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, input_shape, device=device
        )

        # Cross-attention masks are not supported by this model
        if self.config.is_decoder and encoder_hidden_states is not None:
            raise NotImplementedError(
                "Cross-attention (encoder_hidden_states) is not supported"
            )
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        if inputs is not None:
            hidden = self.inputs_to_hidden_dim(inputs)  # (batch, seq_len, hidden)
        else:
            hidden = inputs_embeds
        # Flatten (batch, 1) or (batch,) timesteps to (batch,). Unlike
        # squeeze(-1), reshape keeps the batch axis when batch_size == 1.
        # The embedding is (batch, hidden); unsqueeze to (batch, 1, hidden) so it
        # broadcasts over the sequence.
        time_encoded = self.time_embed(timestep.reshape(-1)).unsqueeze(1)
        inputs_with_time = hidden + time_encoded
        encoder_outputs = self.encoder(
            inputs_with_time,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        per_token_decoded = self.token_decoder(sequence_output)
        return per_token_decoded

    def _get_loss_terms(self, batch):
        """
        Returns a list with one scalar loss per output feature.

        Losses compare ``batch["known_noise"]`` with the model's prediction for
        ``batch["corrupted"]`` at timesteps ``batch["t"]``. Only positions where
        ``batch["attn_mask"]`` is nonzero are included.
        """
        known_noise = batch["known_noise"]
        predicted_noise = self.forward(
            batch["corrupted"], batch["t"], attention_mask=batch["attn_mask"]
        )

        # Indexes into batch then indices along sequence length
        unmask_idx = torch.where(batch["attn_mask"])
        loss_terms = []
        for i in range(known_noise.shape[-1]):
            loss_fn = (
                self.loss_func[i]
                if isinstance(self.loss_func, list)
                else self.loss_func
            )
            loss_terms.append(
                loss_fn(
                    known_noise[:, :, i][unmask_idx],
                    predicted_noise[:, :, i][unmask_idx],
                )
            )
        return loss_terms

    def training_step(self, batch, batch_idx):
        """
        Training step.

        Returns the mean of the per-feature losses. When ``l1_lambda > 0``, it
        also adds ``l1_lambda`` times the sum of absolute values of all parameters.
        """
        loss_terms = self._get_loss_terms(batch)
        avg_loss = torch.mean(torch.stack(loss_terms))

        # L1 regularization. Sum |p| elementwise: torch.linalg.norm(p, 1) would
        # compute the matrix 1-norm (max column sum) for 2-D weight matrices.
        if self.l1_lambda > 0:
            l1_penalty = sum(p.abs().sum() for p in self.parameters())
            avg_loss = avg_loss + self.l1_lambda * l1_penalty

        return avg_loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step: logs each per-feature loss and their mean as ``val_loss``.
        """
        with torch.no_grad():
            loss_terms = self._get_loss_terms(batch)

        # Log each of the loss terms
        for val_name, val in zip(["bond_dist", "omega", "theta", "phi"], loss_terms):
            self.log(f"val_loss_{val_name}", val)

        avg_loss = torch.mean(torch.stack(loss_terms))
        self.log("val_loss", avg_loss)

    def configure_optimizers(self):
        """
        Return an Adam optimizer over all parameters. ``l2`` is applied as
        Adam's (coupled) weight decay.
        """
        return torch.optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=self.l2_lambda
        )
|
|
|
|
|
|
def main():
    """On-the-fly smoke test: run an untrained model on a small batch of noised toy data."""
    import datasets
    from torch.utils.data.dataloader import default_collate

    clean_dset = datasets.CathConsecutiveAnglesDataset(toy=True)
    noised_dset = datasets.NoisedAnglesDataset(clean_dset)
    x = default_collate([noised_dset[i] for i in range(8)])
    print(x["corrupted"].shape, x["corrupted"].dtype)
    print(x["t"].shape)

    # Create model
    model = BertForDiffusion(
        BertConfig(hidden_size=144, position_embedding_type="relative_key_query")
    )
    # Only the output shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        y = model(x["corrupted"], x["t"].squeeze())
    print(y.shape)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: turn on DEBUG logging on the root logger so debug
    # messages (such as the batch/sequence-length one in forward) are printed,
    # then run the smoke test.
    logging.basicConfig(level="DEBUG")
    main()
|