Files
D-SCRIPT/dscript/models/lightning.py

192 lines
5.7 KiB
Python

import logging as lg
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchmetrics
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from .contact import ContactCNN
from .embedding import FullyConnectedEmbed
from .interaction import ModelInteraction
# Module-level logger obtained under the fixed name "D-SCRIPT".
logg = lg.getLogger("D-SCRIPT")
class LitInteraction(pl.LightningModule):
    """Lightning module that trains and evaluates a ``ModelInteraction``.

    The constructor builds a ``FullyConnectedEmbed`` projection and a
    ``ContactCNN`` and combines them in a ``ModelInteraction``. The class
    defines the training / validation / test steps, an Adam optimizer, and
    early-stopping and checkpoint callbacks that both monitor ``val/aupr``.

    Every step expects ``batch`` to be a sequence of three equally long
    iterables; ``zip(*batch)`` is used to walk over ``(x0, x1, y)`` triples,
    one per example, where ``x0`` / ``x1`` are the two model inputs and ``y``
    is the label tensor.

    Args:
        projection_dim: Output width of the embedding projection; also used
            as ``embed_dim`` of the contact module.
        dropout_p: Dropout value passed to ``FullyConnectedEmbed``.
        projection_activation: Zero-argument callable (e.g. an ``nn.Module``
            class) that is called once to create the projection activation.
        hidden_dim: ``hidden_dim`` passed to ``ContactCNN``.
        kernel_width: ``width`` passed to ``ContactCNN``.
        contact_activation: Zero-argument callable that is called once to
            create the contact-module activation.
        pool_width: ``pool_size`` passed to ``ModelInteraction``.
        lr: Learning rate for Adam.
        weight_decay: Weight decay for Adam.
        lambda_similarity: Weight of the BCE loss in the training loss; the
            mean contact-map value is weighted by ``1 - lambda_similarity``.
        save_prefix: Directory handed to ``ModelCheckpoint`` as ``dirpath``.
        save_every: Handed to ``ModelCheckpoint`` as ``save_top_k``.
            NOTE(review): despite the name this is the number of best
            checkpoints kept, not a save frequency.
        patience: Handed to ``EarlyStopping`` as ``patience``. Defaults to 3,
            the value that was previously hard-coded.
        input_dim: ``nin`` of the embedding projection. Defaults to 6165, the
            value that was previously hard-coded (presumably the width of the
            upstream sequence embeddings -- confirm against the data loader).
    """

    def __init__(
        self,
        projection_dim: int = 100,
        dropout_p: float = 0.5,
        projection_activation=nn.ReLU,
        hidden_dim: int = 50,
        kernel_width: int = 7,
        contact_activation=nn.Sigmoid,
        pool_width: int = 9,
        lr: float = 1e-3,
        weight_decay: float = 0,
        lambda_similarity: float = 0.35,
        save_prefix: str = "ModelCheckpoint",
        save_every: int = 1,
        patience: int = 3,
        input_dim: int = 6165,
    ):
        super().__init__()

        self._embedding = FullyConnectedEmbed(
            nin=input_dim,
            nout=projection_dim,
            dropout=dropout_p,
            activation=projection_activation(),
        )
        self._contact = ContactCNN(
            embed_dim=projection_dim,
            hidden_dim=hidden_dim,
            width=kernel_width,
            activation=contact_activation(),
        )
        self._model = ModelInteraction(
            self._embedding,
            self._contact,
            pool_size=pool_width,
        )

        self.criterion = nn.BCELoss()

        # One metric object per stage so that their accumulated state is
        # never shared between train / val / test.
        # NOTE(review): recent torchmetrics releases require a ``task``
        # argument for these constructors -- confirm the pinned version.
        self.train_acc = torchmetrics.Accuracy()
        self.train_aupr = torchmetrics.AveragePrecision()

        self.val_acc = torchmetrics.Accuracy()
        self.val_aupr = torchmetrics.AveragePrecision()

        self.test_acc = torchmetrics.Accuracy()
        self.test_aupr = torchmetrics.AveragePrecision()
        self.test_auroc = torchmetrics.AUROC()
        self.test_spec = torchmetrics.Specificity()
        self.test_sens = torchmetrics.Recall()

        # Captures every constructor argument (including ``patience``) into
        # ``self.hparams``; the other methods read their settings from there.
        self.save_hyperparameters()

    def forward(self, z0, z1):
        """Return the wrapped ``ModelInteraction`` output for ``(z0, z1)``."""
        return self._model(z0, z1)

    def on_train_start(self):
        """Send the hyperparameters to the attached logger, if there is one."""
        # ``self.logger`` is None when the Trainer runs without a logger.
        if self.logger is not None:
            self.logger.log_hyperparams(params=self.hparams)

    def _score_batch(self, batch):
        """Run the model on each pair in ``batch`` and compute the BCE loss.

        Returns:
            A ``(preds, labels, bce_loss)`` tuple: the stacked float
            predictions, the stacked labels, and the BCE loss between them.
        """
        preds = []
        labels = []
        for x0, x1, y in zip(*batch):
            preds.append(self._model(x0, x1))
            labels.append(y)

        preds = torch.stack(preds, 0).float()
        labels = torch.stack(labels, 0)
        bce_loss = self.criterion(preds, labels.float())
        return preds, labels, bce_loss

    def training_step(self, batch, batch_idx):
        """Compute and log the combined training loss for one batch.

        The loss is ``lambda_similarity * BCE(preds, labels)`` plus
        ``(1 - lambda_similarity)`` times the mean of the per-pair mean
        contact-map values returned by ``map_predict``.

        Returns:
            The combined loss tensor.
        """
        preds = []
        cmaps = []
        labels = []
        for x0, x1, y in zip(*batch):
            cmap, y_hat = self._model.map_predict(x0, x1)
            # Reduce each pair to one scalar so the per-pair results stack.
            cmaps.append(torch.mean(cmap))
            preds.append(torch.mean(y_hat))
            labels.append(y)

        preds = torch.stack(preds, 0).float()
        cmaps = torch.stack(cmaps, 0)
        labels = torch.stack(labels, 0)

        bce_loss = self.criterion(preds, labels.float())
        cmap_loss = torch.mean(cmaps)
        weight = self.hparams.lambda_similarity
        loss = (weight * bce_loss) + ((1 - weight) * cmap_loss)

        # Metrics receive detached predictions so they hold no graph.
        self.train_acc(preds.detach(), labels)
        self.train_aupr(preds.detach(), labels)

        train_metrics = {
            "train/bce_loss": bce_loss,
            "train/cmap_loss": cmap_loss,
            "train/loss": loss,
            "train/acc": self.train_acc,
            "train/aupr": self.train_aupr,
        }
        self.log_dict(train_metrics, on_step=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the BCE loss and log epoch-level validation metrics.

        Returns:
            The BCE loss tensor for this batch.
        """
        preds, labels, bce_loss = self._score_batch(batch)

        self.val_acc(preds.detach(), labels)
        self.val_aupr(preds.detach(), labels)

        val_metrics = {
            "val/bce_loss": bce_loss,
            "val/loss": bce_loss,
            "val/acc": self.val_acc,
            "val/aupr": self.val_aupr,
        }
        self.log_dict(
            val_metrics,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return bce_loss

    def test_step(self, batch, batch_idx):
        """Compute the BCE loss and log epoch-level test metrics.

        Returns:
            The BCE loss tensor for this batch.
        """
        preds, labels, bce_loss = self._score_batch(batch)

        self.test_acc(preds, labels)
        self.test_aupr(preds, labels)
        self.test_auroc.update(preds, labels)
        self.test_spec.update(preds, labels)
        self.test_sens.update(preds, labels)

        test_metrics = {
            "test/bce_loss": bce_loss,
            "test/loss": bce_loss,
            "test/acc": self.test_acc,
            "test/aupr": self.test_aupr,
            "test/auroc": self.test_auroc,
            "test/spec": self.test_spec,
            "test/sens": self.test_sens,
        }
        self.log_dict(
            test_metrics,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return bce_loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module."""
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        return {"optimizer": optimizer}

    def configure_callbacks(self):
        """Return the early-stopping and checkpoint callbacks.

        Both callbacks monitor ``val/aupr`` and treat larger values as better.
        """
        early_stop = EarlyStopping(
            monitor="val/aupr",
            mode="max",
            patience=self.hparams.patience,
        )
        checkpoint = ModelCheckpoint(
            monitor="val/aupr",
            # ModelCheckpoint defaults to mode="min", which would keep the
            # checkpoints with the *lowest* AUPR.
            mode="max",
            dirpath=self.hparams.save_prefix,
            save_top_k=self.hparams.save_every,
            # The metric name contains "/", so automatic name insertion would
            # put a path separator into the file name and create an extra
            # directory per checkpoint. Spell the name out explicitly instead.
            filename="epoch={epoch}-val_aupr={val/aupr:.2f}",
            auto_insert_metric_name=False,
        )
        return [early_stop, checkpoint]