mirror of
https://github.com/samsledje/D-SCRIPT.git
synced 2026-06-04 23:14:22 +08:00
192 lines
5.7 KiB
Python
192 lines
5.7 KiB
Python
import logging as lg
|
|
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchmetrics
|
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
|
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
|
|
|
from .contact import ContactCNN
|
|
from .embedding import FullyConnectedEmbed
|
|
from .interaction import ModelInteraction
|
|
|
|
# Module-level logger obtained under the fixed name "D-SCRIPT".
logg = lg.getLogger("D-SCRIPT")
|
|
|
|
|
|
class LitInteraction(pl.LightningModule):
    """Lightning wrapper around a ``ModelInteraction`` pair predictor.

    The module builds a ``FullyConnectedEmbed`` projection and a
    ``ContactCNN``, combines them in a ``ModelInteraction``, and provides the
    train / validation / test steps, the optimizer and the callbacks needed
    to fit it with a ``pytorch_lightning.Trainer``.

    Every step expects ``batch`` to be three parallel sequences
    ``(x0s, x1s, ys)``; pairs are scored one at a time via ``zip(*batch)``.
    """

    def __init__(
        self,
        projection_dim: int = 100,
        dropout_p: float = 0.5,
        projection_activation=nn.ReLU,
        hidden_dim: int = 50,
        kernel_width: int = 7,
        contact_activation=nn.Sigmoid,
        pool_width: int = 9,
        lr: float = 1e-3,
        weight_decay: float = 0,
        lambda_similarity: float = 0.35,
        save_prefix: str = "ModelCheckpoint",
        save_every: int = 1,
        patience: int = 3,
    ):
        """Build the sub-modules, loss, metrics and stored hyperparameters.

        Args:
            projection_dim: ``nout`` of the ``FullyConnectedEmbed`` and
                ``embed_dim`` of the ``ContactCNN``.
            dropout_p: ``dropout`` passed to the ``FullyConnectedEmbed``.
            projection_activation: Zero-argument factory (e.g. an
                ``nn.Module`` class); it is called once and the result is
                passed as the embedding ``activation``.
            hidden_dim: ``hidden_dim`` of the ``ContactCNN``.
            kernel_width: ``width`` of the ``ContactCNN``.
            contact_activation: Zero-argument factory for the
                ``ContactCNN`` ``activation``.
            pool_width: ``pool_size`` of the ``ModelInteraction``.
            lr: Learning rate handed to Adam.
            weight_decay: Weight decay handed to Adam.
            lambda_similarity: Weight of the BCE term in the training loss;
                the mean contact-map term is weighted by
                ``1 - lambda_similarity``.
            save_prefix: ``dirpath`` of the ``ModelCheckpoint`` callback.
            save_every: Passed as ``save_top_k`` to ``ModelCheckpoint``.
            patience: ``patience`` of the ``EarlyStopping`` callback.
                Previously hard-coded to 3, which remains the default.
        """
        super().__init__()

        # NOTE(review): the input width is fixed at 6165 -- presumably the
        # size of the upstream per-residue embedding; confirm with the
        # data pipeline before changing it.
        self._embedding = FullyConnectedEmbed(
            nin=6165,
            nout=projection_dim,
            dropout=dropout_p,
            activation=projection_activation(),
        )
        self._contact = ContactCNN(
            embed_dim=projection_dim,
            hidden_dim=hidden_dim,
            width=kernel_width,
            activation=contact_activation(),
        )
        self._model = ModelInteraction(
            self._embedding,
            self._contact,
            pool_size=pool_width,
        )
        self.criterion = nn.BCELoss()

        self.train_acc = torchmetrics.Accuracy()
        self.train_aupr = torchmetrics.AveragePrecision()

        self.val_acc = torchmetrics.Accuracy()
        self.val_aupr = torchmetrics.AveragePrecision()

        self.test_acc = torchmetrics.Accuracy()
        self.test_aupr = torchmetrics.AveragePrecision()
        self.test_auroc = torchmetrics.AUROC()
        self.test_spec = torchmetrics.Specificity()
        self.test_sens = torchmetrics.Recall()

        # Captures every constructor argument (including ``patience``) in
        # ``self.hparams`` so it is logged and stored in checkpoints.
        self.save_hyperparameters()

    def forward(self, z0, z1):
        """Return the wrapped ``ModelInteraction`` output for one pair."""
        return self._model(z0, z1)

    def on_train_start(self):
        """Send the stored hyperparameters to the attached logger, if any."""
        # ``self.logger`` is None when the Trainer runs with ``logger=False``.
        if self.logger is not None:
            self.logger.log_hyperparams(params=self.hparams)

    def _predict_batch(self, batch):
        """Score every pair in *batch* and return stacked ``(preds, labels)``."""
        preds = []
        labels = []
        for x0, x1, y in zip(*batch):
            preds.append(self._model(x0, x1))
            labels.append(y)
        return torch.stack(preds, 0).float(), torch.stack(labels, 0)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        The loss is ``lambda_similarity * BCE(preds, labels)`` plus
        ``(1 - lambda_similarity) * mean(contact-map means)``.

        Returns:
            The combined loss tensor to back-propagate.
        """
        preds = []
        cmaps = []
        labels = []
        for x0, x1, y in zip(*batch):
            cmap, y_hat = self._model.map_predict(x0, x1)
            # Reduce each pair to one scalar per quantity so they stack.
            cmaps.append(torch.mean(cmap))
            preds.append(torch.mean(y_hat))
            labels.append(y)

        preds = torch.stack(preds, 0).float()
        cmaps = torch.stack(cmaps, 0)
        labels = torch.stack(labels, 0)

        bce_loss = self.criterion(preds, labels.float())
        cmap_loss = torch.mean(cmaps)
        lam = self.hparams.lambda_similarity
        loss = (lam * bce_loss) + ((1 - lam) * cmap_loss)

        # Metrics only need values, not the autograd graph.
        self.train_acc(preds.detach(), labels)
        self.train_aupr(preds.detach(), labels)
        train_metrics = {
            "train/bce_loss": bce_loss,
            "train/cmap_loss": cmap_loss,
            "train/loss": loss,
            "train/acc": self.train_acc,
            "train/aupr": self.train_aupr,
        }
        self.log_dict(train_metrics, on_step=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log epoch-level validation loss, accuracy and AUPR.

        Returns:
            The BCE loss for this batch.
        """
        preds, labels = self._predict_batch(batch)
        bce_loss = self.criterion(preds, labels.float())

        self.val_acc(preds.detach(), labels)
        self.val_aupr(preds.detach(), labels)
        val_metrics = {
            "val/bce_loss": bce_loss,
            "val/loss": bce_loss,
            "val/acc": self.val_acc,
            "val/aupr": self.val_aupr,
        }
        self.log_dict(
            val_metrics,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return bce_loss

    def test_step(self, batch, batch_idx):
        """Compute and log epoch-level test loss and classification metrics.

        Returns:
            The BCE loss for this batch.
        """
        preds, labels = self._predict_batch(batch)
        bce_loss = self.criterion(preds, labels.float())

        # Detached for consistency with the train/validation steps.
        metric_preds = preds.detach()
        self.test_acc(metric_preds, labels)
        self.test_aupr(metric_preds, labels)
        self.test_auroc.update(metric_preds, labels)
        self.test_spec.update(metric_preds, labels)
        self.test_sens.update(metric_preds, labels)
        test_metrics = {
            "test/bce_loss": bce_loss,
            "test/loss": bce_loss,
            "test/acc": self.test_acc,
            "test/aupr": self.test_aupr,
            "test/auroc": self.test_auroc,
            "test/spec": self.test_spec,
            "test/sens": self.test_sens,
        }

        self.log_dict(
            test_metrics,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return bce_loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module."""
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )

        return {"optimizer": optimizer}

    def configure_callbacks(self):
        """Return early-stopping and checkpoint callbacks keyed on ``val/aupr``."""
        early_stop = EarlyStopping(
            monitor="val/aupr", mode="max", patience=self.hparams.patience
        )
        # AUPR is higher-is-better; without ``mode="max"`` the callback
        # falls back to "min" and would keep the *worst* checkpoints.
        # NOTE(review): the monitored name contains "/", which may yield a
        # nested directory in the checkpoint path -- confirm with the
        # installed Lightning version.
        checkpoint = ModelCheckpoint(
            monitor="val/aupr",
            mode="max",
            dirpath=self.hparams.save_prefix,
            save_top_k=self.hparams.save_every,
            filename="{epoch}-{val/aupr:.2f}",
        )
        return [early_stop, checkpoint]
|