# Files
# foldingdiff/bin/train.py
# 2022-08-11 18:31:50 +00:00
#
# 439 lines
# 15 KiB
# Python

"""
Training script
"""
import os, sys
from posixpath import abspath
import shutil
import json
import logging
from pathlib import Path
import multiprocessing
import argparse
import functools
from typing import *
import git
import numpy as np
from matplotlib import pyplot as plt
import torch
from torch.utils.data import Dataset, Subset
from torch.utils.data.dataloader import DataLoader
import torch.nn.functional as F
import pytorch_lightning as pl
from transformers import BertConfig
# Locate the "protdiff" directory that sits next to this script's directory,
# fail fast if it is missing, and add it to the module search path.
# The project-local imports that follow are presumably resolved from there.
# NOTE(review): sys.path.append gives this directory the lowest precedence, so
# an installed package with the same name as a local module (e.g. "datasets"
# or "utils") would be imported instead -- confirm this cannot happen in the
# target environment.
SRC_DIR = (Path(os.path.dirname(os.path.abspath(__file__))) / "../protdiff").resolve()
assert SRC_DIR.is_dir()
sys.path.append(str(SRC_DIR))
import datasets
import modelling
import losses
from beta_schedules import SCHEDULES
import plotting
import utils
# reproducibility
# Fix the seed of PyTorch's global random number generator. Only torch is
# seeded here; no other random source is seeded in this file.
torch.manual_seed(6489)
# torch.use_deterministic_algorithms(True)
# Turn off cuDNN's benchmark mode, which otherwise picks convolution
# algorithms by timing them and can therefore vary between runs.
torch.backends.cudnn.benchmark = False
def plot_epoch_losses(loss_values, fname: str):
    """Plot the loss values against their epoch index and save to fname

    Args:
        loss_values: Sized sequence of loss values, one entry per epoch. The
            x axis is simply the position of each value in the sequence.
        fname: Destination passed to ``Figure.savefig``.
    """
    fig, ax = plt.subplots(dpi=300)
    try:
        ax.plot(np.arange(len(loss_values)), loss_values)
        ax.set(xlabel="Epoch", ylabel="Loss", title="Loss over epochs")
        fig.savefig(fname)
    finally:
        # Figures created through pyplot stay registered (and in memory) until
        # they are explicitly closed, so release it even if saving fails.
        plt.close(fig)
def get_train_valid_test_sets(
    timesteps: int,
    variance_schedule: SCHEDULES,
    noise_prior: Literal["gaussian", "uniform"] = "gaussian",
    shift_to_zero_twopi: bool = False,
    var_scale: float = np.pi,
    toy: Union[int, bool] = False,
    exhaustive_t: bool = False,
    syn_noiser: str = "",
    single_dist_debug: bool = False,
    single_angle_debug: int = -1,  # Noise and return a single angle. -1 to disable, 1-3 for omega/theta/phi
    single_time_debug: bool = False,  # Noise and return a single time
) -> Tuple[Dataset, Dataset, Dataset]:
    """
    Get the dataset objects to use for train/valid/test
    Note, these need to be wrapped in data loaders later

    A clean ``CathConsecutiveAnglesDataset`` is built for each of the
    "train", "validation" and "test" splits, and each one is then wrapped in
    a noising dataset class chosen as follows (first match wins):

    1. ``syn_noiser`` non-empty: only "halfhalf" is accepted.
    2. ``noise_prior == "gaussian"``: ``single_dist_debug``, then
       ``single_angle_debug > 0``, then ``single_time_debug`` select a debug
       noiser; otherwise the regular ``NoisedAnglesDataset`` is used.
    3. ``noise_prior == "uniform"``.

    Args:
        timesteps: Number of timesteps, forwarded to the noiser.
        variance_schedule: Forwarded to the noiser as ``beta_schedule``.
        noise_prior: Which noise prior to use (ignored if ``syn_noiser`` is set).
        shift_to_zero_twopi: Forwarded to both the clean and noised datasets.
        var_scale: Used for the last three entries of the ``variances`` list
            given to the noiser; the first entry is fixed at 1.0.
        toy: Forwarded to the clean dataset constructor.
        exhaustive_t: Forwarded to the noiser for the validation and test
            splits only; the train split always receives ``False``.
        syn_noiser: Name of a synthetic noiser, or "" to disable.
        single_dist_debug: Use the single bond distance debug noiser.
        single_angle_debug: Feature index for the single angle debug noiser,
            or a negative value to disable. 0 is rejected.
        single_time_debug: Use the single angle and single time debug noiser.

    Returns:
        Tuple of (train, validation, test) noised datasets.

    Raises:
        AssertionError: If ``single_angle_debug`` is 0.
        ValueError: If ``syn_noiser`` or ``noise_prior`` is not recognized.
    """
    assert (
        single_angle_debug != 0
    ), f"Invalid value for single_angle_debug: {single_angle_debug}"

    clean_dsets = [
        datasets.CathConsecutiveAnglesDataset(
            split=s, shift_to_zero_twopi=shift_to_zero_twopi, toy=toy
        )
        for s in ["train", "validation", "test"]
    ]

    if syn_noiser != "":
        if syn_noiser == "halfhalf":
            logging.warning("Using synthetic half-half noiser")
            dset_noiser_class = datasets.SynNoisedByPositionDataset
        else:
            raise ValueError(f"Unknown synthetic noiser {syn_noiser}")
    elif noise_prior == "gaussian":
        if single_dist_debug:
            logging.warning("Using single dist debug")
            dset_noiser_class = datasets.SingleNoisedBondDistanceDataset
        elif single_angle_debug > 0:
            logging.warning("Using single angle noise!")
            dset_noiser_class = functools.partial(
                datasets.SingleNoisedAngleDataset, ft_idx=single_angle_debug
            )
        elif single_time_debug:
            logging.warning("Using single angle and single time noise!")
            dset_noiser_class = datasets.SingleNoisedAngleAndTimeDataset
        else:
            dset_noiser_class = datasets.NoisedAnglesDataset
    elif noise_prior == "uniform":
        dset_noiser_class = datasets.GaussianDistUniformAnglesNoisedDataset
    else:
        raise ValueError(f"Unrecognized noise prior: {noise_prior}")
    logging.info(f"Using {dset_noiser_class} for noise")

    noised_dsets = [
        dset_noiser_class(
            dset=ds,
            dset_key="angles",
            timesteps=timesteps,
            # Index 0 is the train split, which never enumerates t exhaustively
            exhaustive_t=(i != 0) and exhaustive_t,
            beta_schedule=variance_schedule,
            shift_to_zero_twopi=shift_to_zero_twopi,
            variances=[1.0, var_scale, var_scale, var_scale],
        )
        for i, ds in enumerate(clean_dsets)
    ]
    for dsname, ds in zip(["train", "val", "test"], noised_dsets):
        logging.info(f"{dsname}: {ds}")

    # Log an example of the data. The item is fetched exactly once so that the
    # three logged fields describe the same example; indexing the dataset
    # separately for each field repeats the lookup and (presumably, since the
    # noiser is stochastic -- confirm against the datasets module) could pair
    # a corrupted value with noise from a different draw.
    example = noised_dsets[0][0]
    logging.debug(f"Example clean vals: {example['angles']}")
    logging.debug(f"Example noised vals: {example['corrupted']}")
    logging.debug(f"Example noise: {example['known_noise']}")
    return tuple(noised_dsets)
def build_callbacks(early_stop_patience: Optional[int] = None, swa: bool = False):
    """
    Assemble the list of Lightning callbacks for a training run

    A checkpoint callback (best single model by "val_loss", weights only) and
    an epoch-level learning rate monitor are always present. Early stopping on
    "val_loss" is added when early_stop_patience is a positive integer, and
    stochastic weight averaging is added when swa is set.
    """
    checkpointer = pl.callbacks.ModelCheckpoint(
        monitor="val_loss", save_top_k=1, save_weights_only=True
    )
    lr_monitor = pl.callbacks.LearningRateMonitor(
        logging_interval="epoch", log_momentum=True
    )
    callbacks = [checkpointer, lr_monitor]

    # Both None and non-positive patience values disable early stopping
    wants_early_stop = early_stop_patience is not None and early_stop_patience > 0
    if wants_early_stop:
        logging.info(f"Using early stopping with patience {early_stop_patience}")
        stopper = pl.callbacks.early_stopping.EarlyStopping(
            monitor="val_loss",
            patience=early_stop_patience,
            verbose=True,
            mode="min",
        )
        callbacks.append(stopper)

    if swa:
        # Stochastic weight averaging
        averager = pl.callbacks.StochasticWeightAveraging()
        callbacks.append(averager)

    logging.info(f"Model callbacks: {callbacks}")
    return callbacks
# For some arg defaults, see as reference:
# https://huggingface.co/docs/transformers/main/en/main_classes/trainer.html
def train(
    # Controls output
    results_dir: str = "./results",
    # Controls data loading and noising process
    shift_angles_zero_twopi: bool = False,
    noise_prior: Literal["gaussian", "uniform"] = "gaussian",  # Uniform not tested
    timesteps: int = 1000,
    variance_schedule: SCHEDULES = "cosine",  # cosine better on single angle toy test
    variance_scale: float = np.pi,
    # Related to model architecture
    implementation: Literal[
        "pytorch_encoder", "huggingface_encoder"
    ] = "pytorch_encoder",
    time_encoding: Literal["gaussian_fourier", "sinusoidal"] = "gaussian_fourier",
    num_hidden_layers: int = 6,  # Default 12
    hidden_size: int = 72,  # Default 768
    intermediate_size: int = 144,  # Default 3072
    num_heads: int = 8,  # Default 12
    position_embedding_type: Literal[
        "absolute", "relative_key", "relative_key_query"
    ] = "relative_key_query",
    dropout_p: float = 0.1,  # Default 0.1, can disable for debugging
    # Related to training strategy
    gradient_clip: float = 1.0,  # From BERT trainer
    batch_size: int = 64,
    lr: float = 5e-5,  # Default lr for huggingface BERT trainer
    loss: Literal["huber", "radian_l1", "radian_l1_smooth"] = "radian_l1_smooth",
    l2_norm: float = 0.0,  # AdamW default has 0.01 L2 regularization, but BERT trainer uses 0.0
    l1_norm: float = 0.0,
    circle_reg: float = 0.0,
    min_epochs: int = 500,
    max_epochs: int = 2000,
    early_stop_patience: int = 10,  # Set to 0 to disable early stopping
    lr_scheduler: Optional[Literal["OneCycleLR"]] = None,
    use_swa: bool = False,  # Stochastic weight averaging can improve training generalization
    # Misc. and debugging
    multithread: bool = True,
    subset: Union[bool, int] = False,  # Subset to n training examples
    exhaustive_validation_t: bool = False,  # Exhaustively enumerate t for validation/test
    syn_noiser: str = "",  # If specified, use a synthetic noiser
    single_dist_debug: bool = False,  # Debug on distance (no periodicity)
    single_angle_debug: int = -1,  # Noise and return a single angle, choose [1, 2, 3] or -1 to disable
    single_timestep_debug: bool = False,  # Noise and return a single timestep
):
    """Main training loop

    Steps, in order:

    1. Delete ``results_dir`` if it already exists, recreate it, and record
       the call arguments (``training_args.json``) and the current Git commit
       (``git_sha.txt``) inside it.
    2. Build the train/validation/test datasets and wrap them in data loaders.
    3. Unless a debug noiser is active, plot the training data distribution at
       11 evenly spaced timesteps into ``<results_dir>/plots``.
    4. Build the model selected by ``implementation`` and fit it with a
       PyTorch Lightning trainer, validating every epoch.
    5. Plot the logged losses to ``<results_dir>/plots/losses.pdf``.

    The inline comments on the parameters describe each argument. Two
    interactions are worth calling out:

    * ``loss`` is replaced by a fixed loss function, and the model is built
      with 1 input feature instead of 4, whenever one of the debug modes
      (``single_angle_debug > 0``, ``single_timestep_debug``,
      ``single_dist_debug``) or ``syn_noiser`` is active.
    * ``position_embedding_type`` is only used by the "huggingface_encoder"
      implementation, and ``lr_scheduler`` only by "pytorch_encoder".

    Raises:
        ValueError: If ``implementation`` is not one of the known values.
    """
    # Record the args given to the function before we create more vars
    # https://stackoverflow.com/questions/10724495/getting-all-arguments-and-values-passed-to-a-function
    func_args = locals()

    # Create results directory
    results_folder = Path(results_dir)
    if results_folder.exists():
        logging.warning(f"Removing old results directory: {results_folder}")
        shutil.rmtree(results_folder)
    # parents=True so that a nested results_dir whose parent directories do
    # not exist yet can be created as well
    results_folder.mkdir(parents=True, exist_ok=True)
    with open(results_folder / "training_args.json", "w") as sink:
        logging.info(f"Writing training args to {sink.name}")
        json.dump(func_args, sink, indent=4)
        for k, v in func_args.items():
            logging.info(f"Training argument: {k}={v}")

    # Record current Git version
    repo = git.Repo(
        path=os.path.dirname(os.path.abspath(__file__)), search_parent_directories=True
    )
    sha = repo.head.object.hexsha
    with open(results_folder / "git_sha.txt", "w") as sink:
        sink.write(sha + "\n")

    # Get datasets and wrap them in dataloaders
    dsets = get_train_valid_test_sets(
        timesteps=timesteps,
        variance_schedule=variance_schedule,
        noise_prior=noise_prior,
        shift_to_zero_twopi=shift_angles_zero_twopi,
        var_scale=variance_scale,
        toy=subset,
        syn_noiser=syn_noiser,
        exhaustive_t=exhaustive_validation_t,
        single_dist_debug=single_dist_debug,
        single_angle_debug=single_angle_debug,
        single_time_debug=single_timestep_debug,
    )
    train_dataloader, valid_dataloader, test_dataloader = [
        DataLoader(
            dataset=ds,
            batch_size=batch_size,
            # Shuffle only train loader (index 0); validation and test keep a
            # fixed order
            shuffle=(i == 0),
            num_workers=multiprocessing.cpu_count() if multithread else 1,
        )
        for i, ds in enumerate(dsets)
    ]

    # Create plots in output directories of distributions from different timesteps
    plots_folder = results_folder / "plots"
    os.makedirs(plots_folder, exist_ok=True)
    # Skip this for debug runs
    if (
        single_angle_debug < 0
        and not single_timestep_debug
        and not single_dist_debug
        and not syn_noiser
    ):
        for t in np.linspace(0, timesteps, num=11, endpoint=True).astype(int):
            t = min(t, timesteps - 1)  # Ensure we don't exceed the number of timesteps
            logging.info(f"Plotting distribution at time {t}")
            plotting.plot_val_dists_at_t(
                dsets[0],
                t=t,
                share_axes=False,
                zero_center_angles=not shift_angles_zero_twopi,
                fname=plots_folder / f"train_dists_at_t_{t}.pdf",
            )

    # https://jaketae.github.io/study/relative-positional-encoding/
    # looking at the relative distance between things is more robust

    # The debug modes override the requested loss with a fixed loss function
    loss_fn = loss
    if single_angle_debug > 0 or single_timestep_debug or syn_noiser:
        loss_fn = functools.partial(losses.radian_smooth_l1_loss, beta=0.1 * np.pi)
    elif single_dist_debug:
        loss_fn = F.smooth_l1_loss

    # Debug and synthetic noisers work on a single feature rather than four
    model_n_inputs = (
        1
        if single_angle_debug > 0
        or single_timestep_debug
        or single_dist_debug
        or syn_noiser
        else 4
    )

    if implementation == "pytorch_encoder":
        logging.info("Using PyTorch encoder implementation")
        model = modelling.BertDenoiserEncoderModel(
            n_inputs=model_n_inputs,
            time_encoding=time_encoding,
            num_layers=num_hidden_layers,
            d_model=hidden_size,
            intermediate_size=intermediate_size,
            num_heads=num_heads,
            dropout=dropout_p,
            lr=lr,
            loss=loss_fn,
            l2=l2_norm,
            l1=l1_norm,
            circle_reg=circle_reg,
            min_epochs=min_epochs,
            steps_per_epoch=len(train_dataloader),
            lr_scheduler=lr_scheduler,
            write_preds_to_dir=results_folder / "valid_preds",
        )
    elif implementation == "huggingface_encoder":
        logging.info("Using HuggingFace encoder implementation")
        cfg = BertConfig(
            num_attention_heads=num_heads,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            position_embedding_type=position_embedding_type,
            hidden_dropout_prob=dropout_p,
            attention_probs_dropout_prob=dropout_p,
            use_cache=False,
        )
        model = modelling.BertForDiffusion(
            cfg,
            time_encoding=time_encoding,
            n_inputs=model_n_inputs,
            lr=lr,
            loss=loss_fn,
            l2=l2_norm,
            l1=l1_norm,
            circle_reg=circle_reg,
            write_preds_to_dir=results_folder / "valid_preds",
        )
        cfg.save_pretrained(results_folder)
    else:
        raise ValueError(f"Unknown implementation: {implementation}")

    callbacks = build_callbacks(early_stop_patience=early_stop_patience, swa=use_swa)

    trainer = pl.Trainer(
        default_root_dir=results_folder,
        gradient_clip_val=gradient_clip,
        min_epochs=min_epochs,
        max_epochs=max_epochs,
        check_val_every_n_epoch=1,
        callbacks=callbacks,
        logger=pl.loggers.CSVLogger(save_dir=results_folder / "logs"),
        log_every_n_steps=min(50, len(train_dataloader)),  # Log at least once per epoch
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
    )
    trainer.fit(
        model=model,
        train_dataloaders=train_dataloader,
        val_dataloaders=valid_dataloader,
    )

    # Locate the metrics file written by the CSV logger. The results directory
    # was recreated above, so this run is expected to be "version_0".
    metrics_csv = os.path.join(
        trainer.logger.save_dir, "lightning_logs/version_0/metrics.csv"
    )
    assert os.path.isfile(metrics_csv)
    # Plot the losses
    plotting.plot_losses(
        metrics_csv, out_fname=plots_folder / "losses.pdf", simple=True
    )
def build_parser() -> argparse.ArgumentParser:
    """
    Build CLI parser

    The parser accepts one optional positional argument (path to a json file
    of parameters) plus a handful of options that are applied on top of it.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        # The module docstring belongs in the description: passing it as
        # ``usage`` would replace the auto-generated usage line, hiding the
        # actual command syntax in --help and in error messages.
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # https://stackoverflow.com/questions/4480075/argparse-optional-positional-arguments
    parser.add_argument(
        "config", nargs="?", default="", type=str, help="json of params"
    )
    parser.add_argument(
        "-o",
        "--outdir",
        type=str,
        default=os.path.join(os.getcwd(), "results"),
        help="Directory to write model training outputs",
    )
    parser.add_argument(
        "--implementation",
        type=str,
        choices=["pytorch_encoder", "huggingface_encoder"],
        default="pytorch_encoder",
        help="Which implementation to use",
    )
    parser.add_argument(
        "--toy",
        type=int,
        default=None,
        help="Use a toy dataset of n items rather than full dataset",
    )
    parser.add_argument("--debug_dist", action="store_true", help="Debug distances")
    parser.add_argument(
        "--debug_single_time",
        action="store_true",
        help="Debug single angle and timestep",
    )
    return parser
def main():
    """Parse the command line, merge it with the optional json config, and train"""
    cli = build_parser().parse_args()

    # Values taken from the command line, keyed by the train() argument they map to
    cli_overrides = {
        "results_dir": cli.outdir,
        "implementation": cli.implementation,
        "subset": cli.toy,
        "single_dist_debug": cli.debug_dist,
        "single_timestep_debug": cli.debug_single_time,
    }

    # Start from the json file when one is given, otherwise from nothing
    if cli.config:
        with open(cli.config) as handle:
            file_params = json.load(handle)
    else:
        file_params = {}

    # NOTE(review): presumably only non-None override values are applied. All
    # of the options above except --toy always have a non-None value, so they
    # would replace the corresponding json entries even when not given on the
    # command line -- confirm against utils.update_dict_nonnull.
    train_kwargs = utils.update_dict_nonnull(file_params, cli_overrides)
    train(**train_kwargs)
# Script entry point: only runs when the file is executed directly.
if __name__ == "__main__":
    # Configure the root logger at INFO level so that the logging.info
    # messages emitted during training are shown.
    logging.basicConfig(level=logging.INFO)
    main()