Files
D-SCRIPT/dscript/commands/train.py
2022-08-18 12:27:49 -04:00

343 lines
9.2 KiB
Python

"""
Train a new model.
"""
from __future__ import annotations
import argparse
import datetime
import gzip as gz
import logging as lg
import os
import subprocess as sp
import sys
import h5py
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimizers
from pytorch_lightning import loggers as pl_loggers
from tqdm import tqdm
from typing import Callable, NamedTuple, Optional
from ..datamodules import PPIDataModule
# from ..models.contact import ContactCNN
# from ..models.embedding import FullyConnectedEmbed, IdentityEmbed
# from ..models.interaction import ModelInteraction
from ..models.lightning import LitInteraction
from ..utils import config_logger
class TrainArguments(NamedTuple):
cmd: str
device: int
train: str
test: str
embedding: str
no_augment: bool
input_dim: int
projection_dim: int
dropout: float
hidden_dim: int
kernel_width: int
no_w: bool
no_sigmoid: bool
do_pool: bool
pool_width: int
num_epochs: int
batch_size: int
weight_decay: float
lr: float
interaction_weight: float
run_tt: bool
glider_weight: float
glider_thresh: float
outfile: Optional[str]
save_prefix: Optional[str]
checkpoint: Optional[str]
func: Callable[[TrainArguments], None]
def add_args(parser):
"""
Create parser for command line utility.
:meta private:
"""
data_grp = parser.add_argument_group("Data")
proj_grp = parser.add_argument_group("Projection Module")
contact_grp = parser.add_argument_group("Contact Module")
inter_grp = parser.add_argument_group("Interaction Module")
train_grp = parser.add_argument_group("Training")
misc_grp = parser.add_argument_group("Output and Device")
# Data
data_grp.add_argument("--train", help="Training data", required=True)
data_grp.add_argument("--val", help="Validation data", required=True)
data_grp.add_argument("--test", help="Testing data")
data_grp.add_argument(
"--embedding",
required=True,
help="h5py path containing embedded sequences",
)
data_grp.add_argument(
"--no-augment",
action="store_true",
help="data is automatically augmented by adding (B A) for all pairs (A B). Set this flag to not augment data",
)
data_grp.add_argument(
"--preload", action="store_true", help="Preload embeddings into memory"
)
# data_grp.add_argument(
# "--val_split",
# default=0.1,
# help="Proportion of data to use for validation",
# )
# Embedding model
proj_grp.add_argument(
"--input-dim",
type=int,
default=6165,
help="dimension of input language model embedding (per amino acid) (default: 6165)",
)
proj_grp.add_argument(
"--projection-dim",
type=int,
default=100,
help="dimension of embedding projection layer (default: 100)",
)
proj_grp.add_argument(
"--dropout-p",
type=float,
default=0.5,
help="parameter p for embedding dropout layer (default: 0.5)",
)
# Contact model
contact_grp.add_argument(
"--hidden-dim",
type=int,
default=50,
help="number of hidden units for comparison layer in contact prediction (default: 50)",
)
contact_grp.add_argument(
"--kernel-width",
type=int,
default=7,
help="width of convolutional filter for contact prediction (default: 7)",
)
# Interaction Model
inter_grp.add_argument(
"--no-w",
action="store_true",
help="don't use weight matrix in interaction prediction model",
)
inter_grp.add_argument(
"--no-sigmoid",
action="store_true",
help="don't use sigmoid activation at end of interaction model",
)
inter_grp.add_argument(
"--do-pool",
action="store_true",
help="use max pool layer in interaction prediction model",
)
inter_grp.add_argument(
"--pool-width",
type=int,
default=9,
help="size of max-pool in interaction model (default: 9)",
)
# Training
train_grp.add_argument(
"--epoch-scale",
type=int,
default=1,
help="Report heldout performance every this many epochs (default: 1)",
)
train_grp.add_argument(
"--num-epochs",
type=int,
default=10,
help="Number of epochs (default: 10)",
)
train_grp.add_argument(
"--batch-size",
type=int,
default=25,
help="Minibatch size (default: 25)",
)
train_grp.add_argument(
"--weight-decay",
type=float,
default=0,
help="L2 regularization (default: 0)",
)
train_grp.add_argument(
"--lr",
type=float,
default=0.001,
help="Learning rate (default: 0.001)",
)
train_grp.add_argument(
"--lambda",
dest="lambda_",
type=float,
default=0.35,
help="weight on the similarity objective (default: 0.35)",
)
# Output
misc_grp.add_argument(
"-o", "--outfile", help="Output file path (default: stdout)"
)
misc_grp.add_argument(
"-v",
"--verbosity",
type=int,
default=2,
help="Verbosity level (default: 2 [info])",
)
misc_grp.add_argument(
"--debug",
action="store_true",
help="Run in debug mode",
)
misc_grp.add_argument(
"--save-prefix", help="Path prefix for saving models"
)
misc_grp.add_argument(
"-d", "--device", type=int, default=-1, help="Compute device to use"
)
misc_grp.add_argument(
"--checkpoint", help="Checkpoint model to start training from"
)
return parser
def main(args):
"""
Run training from arguments.
:meta private:
"""
args.wandb = True
if args.debug:
args.verbosity = 3
args.wandb = False
logg = config_logger(
args.outfile,
"%(asctime)s [%(levelname)s] %(message)s",
args.verbosity,
use_stdout=True,
)
# logg.info(f"Beginning experiment {conf.experiment_id}")
if args.debug:
logg.warning("RUNNING IN DEBUG MODE")
if args.test is None:
args.test = args.val
logg.info("Data:")
logg.info(f"\ttrain file: {args.train}")
logg.info(f"\tval file: {args.val}")
logg.info(f"\ttest file: {args.test}")
logg.info(f"\tdata_augmentation: {not args.no_augment}")
logg.info(f"\tbatch_size: {args.batch_size}")
datamod = PPIDataModule(
args.embedding,
args.train,
args.val,
args.test,
batch_size=args.batch_size,
preload=args.preload,
shuffle=True,
num_workers=0,
augment_train=(not args.no_augment),
)
logg.info("Preparing data")
datamod.prepare_data()
logg.info("Running DataModule set up")
datamod.setup()
logg.info("Configuring model")
logg.info("Initializing embedding model with:")
logg.info(f"\tprojection_dim: {args.projection_dim}")
logg.info(f"\tdropout_p: {args.dropout_p}")
logg.info("Initializing contact model with:")
logg.info(f"\thidden_dim: {args.hidden_dim}")
logg.info(f"\tkernel_width: {args.kernel_width}")
logg.info("Initializing interaction model with:")
logg.info(f"\tpool_width: {args.pool_width}")
logg.info(f"\tinteraction weight: {args.lambda_}")
logg.info(f"\tcontact map weight: {1 - args.lambda_}")
model = LitInteraction(
projection_dim=args.projection_dim,
dropout_p=args.dropout_p,
hidden_dim=args.hidden_dim,
kernel_width=args.kernel_width,
pool_width=args.pool_width,
lr=args.lr,
weight_decay=args.weight_decay,
lambda_similarity=args.lambda_,
save_prefix=args.save_prefix,
save_every=args.epoch_scale,
)
logger_list = [
# pl_loggers.TensorBoardLogger(".", name=conf.experiment_id, default_hp_metric=False),
pl_loggers.CSVLogger(args.save_prefix, name=args.outfile)
]
# if conf.wandb:
# logg.info(f"Logging to WandB {conf.experiment_id}")
# wandb_lg = pl.loggers.WandbLogger(conf.experiment_id,
# save_dir=conf.log_dir,
# project=conf.wandb_proj,
# )
# logger_list.append(wandb_lg)
logg.info(f"Saving checkpoints to '{args.save_prefix}'")
logg.info(
f"Training with Adam: lr={args.lr}, weight_decay={args.weight_decay}"
)
logg.info(f"\tnum_epochs: {args.num_epochs}")
logg.info(f"\tepoch_scale: {args.epoch_scale}")
num_gpus = 1 if torch.cuda.is_available else 0
trainer = pl.Trainer(
logger=logger_list,
max_epochs=args.num_epochs,
gpus=num_gpus,
)
trainer.fit(model, datamod)
trainer.test(model, datamod)
output = args.outfile
if output is None:
output = sys.stdout
else:
output = open(output, "w")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)
add_args(parser)
main(parser.parse_args())