mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-06 06:44:23 +08:00
439 lines
15 KiB
Python
439 lines
15 KiB
Python
"""
|
|
Training script
|
|
"""
|
|
|
|
import os, sys
|
|
from posixpath import abspath
|
|
import shutil
|
|
import json
|
|
import logging
|
|
from pathlib import Path
|
|
import multiprocessing
|
|
import argparse
|
|
import functools
|
|
from typing import *
|
|
|
|
import git
|
|
import numpy as np
|
|
from matplotlib import pyplot as plt
|
|
|
|
import torch
|
|
from torch.utils.data import Dataset, Subset
|
|
from torch.utils.data.dataloader import DataLoader
|
|
import torch.nn.functional as F
|
|
|
|
import pytorch_lightning as pl
|
|
|
|
from transformers import BertConfig
|
|
|
|
# Resolve the "protdiff" directory that sits next to this script's directory
# and put it on sys.path so the project-local modules imported right after
# this (datasets, modelling, losses, ...) can be found.
# NOTE(review): presumably those modules live in ../protdiff — confirm.
SRC_DIR = (Path(os.path.dirname(os.path.abspath(__file__))) / "../protdiff").resolve()
# Fail early if the expected source directory is missing
assert SRC_DIR.is_dir()
sys.path.append(str(SRC_DIR))
|
|
|
|
import datasets
|
|
import modelling
|
|
import losses
|
|
from beta_schedules import SCHEDULES
|
|
import plotting
|
|
import utils
|
|
|
|
|
|
# reproducibility
# Seed PyTorch's global RNG with a fixed value at import time
torch.manual_seed(6489)
# Full determinism is left disabled:
# torch.use_deterministic_algorithms(True)
# Turn off cuDNN benchmark mode (its algorithm auto-selection is a known
# source of run-to-run variation per the PyTorch reproducibility notes)
torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
def plot_epoch_losses(loss_values, fname: str):
    """
    Plot the loss values and save to fname

    Args:
        loss_values: sized sequence of loss values; the i-th value is drawn
            at x = i, and the x axis is labelled "Epoch".
        fname: destination passed to ``fig.savefig``.

    Returns:
        None. The figure is written to disk and then closed.
    """
    fig, ax = plt.subplots(dpi=300)
    ax.plot(np.arange(len(loss_values)), loss_values)
    ax.set(xlabel="Epoch", ylabel="Loss", title="Loss over epochs")
    fig.savefig(fname)
    # pyplot keeps a reference to every figure it creates until it is closed;
    # release it so repeated calls do not accumulate open figures.
    plt.close(fig)
|
|
|
|
|
|
def get_train_valid_test_sets(
    timesteps: int,
    variance_schedule: SCHEDULES,
    noise_prior: Literal["gaussian", "uniform"] = "gaussian",
    shift_to_zero_twopi: bool = False,
    var_scale: float = np.pi,
    toy: Union[int, bool] = False,
    exhaustive_t: bool = False,
    syn_noiser: str = "",
    single_dist_debug: bool = False,
    single_angle_debug: int = -1,  # Noise and return a single angle. -1 to disable, 1-3 for omega/theta/phi
    single_time_debug: bool = False,  # Noise and return a single time
) -> Tuple[Dataset, Dataset, Dataset]:
    """
    Build the noised (train, validation, test) datasets.

    A clean CathConsecutiveAnglesDataset is created for each split and then
    wrapped in a noising dataset class chosen from the arguments:

    * a non-empty ``syn_noiser`` takes precedence (only "halfhalf" is known);
    * otherwise ``noise_prior`` decides, and for the "gaussian" prior the
      debug flags are consulted in the order ``single_dist_debug``,
      ``single_angle_debug`` (> 0), ``single_time_debug``.

    ``exhaustive_t`` is only forwarded to the validation and test wrappers;
    the train wrapper always receives ``False``.

    Raises ValueError for an unknown ``syn_noiser`` or ``noise_prior``.

    Note, the returned datasets still need to be wrapped in data loaders
    """
    assert (
        single_angle_debug != 0
    ), f"Invalid value for single_angle_debug: {single_angle_debug}"

    # Clean (un-noised) datasets, one per split, in train/valid/test order
    clean_splits = []
    for split_name in ["train", "validation", "test"]:
        clean_splits.append(
            datasets.CathConsecutiveAnglesDataset(
                split=split_name, shift_to_zero_twopi=shift_to_zero_twopi, toy=toy
            )
        )

    # Pick the class (or partial) used to wrap each clean dataset with noise
    if syn_noiser == "halfhalf":
        logging.warning("Using synthetic half-half noiser")
        dset_noiser_class = datasets.SynNoisedByPositionDataset
    elif syn_noiser != "":
        raise ValueError(f"Unknown synthetic noiser {syn_noiser}")
    elif noise_prior == "uniform":
        dset_noiser_class = datasets.GaussianDistUniformAnglesNoisedDataset
    elif noise_prior != "gaussian":
        raise ValueError(f"Unrecognized noise prior: {noise_prior}")
    elif single_dist_debug:
        logging.warning("Using single dist debug")
        dset_noiser_class = datasets.SingleNoisedBondDistanceDataset
    elif single_angle_debug > 0:
        logging.warning("Using single angle noise!")
        dset_noiser_class = functools.partial(
            datasets.SingleNoisedAngleDataset, ft_idx=single_angle_debug
        )
    elif single_time_debug:
        logging.warning("Using single angle and single time noise!")
        dset_noiser_class = datasets.SingleNoisedAngleAndTimeDataset
    else:
        dset_noiser_class = datasets.NoisedAnglesDataset

    logging.info(f"Using {dset_noiser_class} for noise")

    noised_dsets = []
    for split_idx, clean in enumerate(clean_splits):
        # Index 0 is the train split, which never enumerates t exhaustively
        split_exhaustive_t = exhaustive_t if split_idx != 0 else False
        noised_dsets.append(
            dset_noiser_class(
                dset=clean,
                dset_key="angles",
                timesteps=timesteps,
                exhaustive_t=split_exhaustive_t,
                beta_schedule=variance_schedule,
                shift_to_zero_twopi=shift_to_zero_twopi,
                variances=[1.0, var_scale, var_scale, var_scale],
            )
        )

    for label, wrapped in zip(["train", "val", "test"], noised_dsets):
        logging.info(f"{label}: {wrapped}")

    # Log an example of the data (each line indexes the train dataset afresh)
    logging.debug(f"Example clean vals: {noised_dsets[0][0]['angles']}")
    logging.debug(f"Example noised vals: {noised_dsets[0][0]['corrupted']}")
    logging.debug(f"Example noise: {noised_dsets[0][0]['known_noise']}")

    return tuple(noised_dsets)
|
|
|
|
|
|
def build_callbacks(early_stop_patience: Optional[int] = None, swa: bool = False):
    """
    Assemble the list of Lightning callbacks used for training.

    Always included: a ModelCheckpoint monitoring "val_loss" (top-1, weights
    only) and a LearningRateMonitor logging per epoch.  EarlyStopping on
    "val_loss" is appended only when ``early_stop_patience`` is a positive
    number, and StochasticWeightAveraging only when ``swa`` is truthy.
    """
    checkpoint_cb = pl.callbacks.ModelCheckpoint(
        monitor="val_loss", save_top_k=1, save_weights_only=True,
    )
    lr_monitor_cb = pl.callbacks.LearningRateMonitor(
        logging_interval="epoch", log_momentum=True
    )
    callbacks = [checkpoint_cb, lr_monitor_cb]

    # None and non-positive patience both mean "no early stopping"
    wants_early_stop = early_stop_patience is not None and early_stop_patience > 0
    if wants_early_stop:
        logging.info(f"Using early stopping with patience {early_stop_patience}")
        early_stop_cb = pl.callbacks.early_stopping.EarlyStopping(
            monitor="val_loss",
            patience=early_stop_patience,
            verbose=True,
            mode="min",
        )
        callbacks.append(early_stop_cb)

    if swa:
        # Stochastic weight averaging
        callbacks.append(pl.callbacks.StochasticWeightAveraging())

    logging.info(f"Model callbacks: {callbacks}")
    return callbacks
|
|
|
|
|
|
# For some arg defaults, see as reference:
|
|
# https://huggingface.co/docs/transformers/main/en/main_classes/trainer.html
|
|
|
|
|
|
def train(
    # Controls output
    results_dir: str = "./results",
    # Controls data loading and noising process
    shift_angles_zero_twopi: bool = False,
    noise_prior: Literal["gaussian", "uniform"] = "gaussian",  # Uniform not tested
    timesteps: int = 1000,
    variance_schedule: SCHEDULES = "cosine",  # cosine better on single angle toy test
    variance_scale: float = np.pi,
    # Related to model architecture
    implementation: Literal[
        "pytorch_encoder", "huggingface_encoder"
    ] = "pytorch_encoder",
    time_encoding: Literal["gaussian_fourier", "sinusoidal"] = "gaussian_fourier",
    num_hidden_layers: int = 6,  # Default 12
    hidden_size: int = 72,  # Default 768
    intermediate_size: int = 144,  # Default 3072
    num_heads: int = 8,  # Default 12
    position_embedding_type: Literal[
        "absolute", "relative_key", "relative_key_query"
    ] = "relative_key_query",
    dropout_p: float = 0.1,  # Default 0.1, can disable for debugging
    # Related to training strategy
    gradient_clip: float = 1.0,  # From BERT trainer
    batch_size: int = 64,
    lr: float = 5e-5,  # Default lr for huggingface BERT trainer
    loss: Literal["huber", "radian_l1", "radian_l1_smooth"] = "radian_l1_smooth",
    l2_norm: float = 0.0,  # AdamW default has 0.01 L2 regularization, but BERT trainer uses 0.0
    l1_norm: float = 0.0,
    circle_reg: float = 0.0,
    min_epochs: int = 500,
    max_epochs: int = 2000,
    early_stop_patience: int = 10,  # Set to 0 to disable early stopping
    lr_scheduler: Optional[Literal["OneCycleLR"]] = None,
    use_swa: bool = False,  # Stochastic weight averaging can improve training generalization
    # Misc. and debugging
    multithread: bool = True,
    subset: Union[bool, int] = False,  # Subset to n training examples
    exhaustive_validation_t: bool = False,  # Exhaustively enumerate t for validation/test
    syn_noiser: str = "",  # If specified, use a synthetic noiser
    single_dist_debug: bool = False,  # Debug on distance (no periodicity)
    single_angle_debug: int = -1,  # Noise and return a single angle, choose [1, 2, 3] or -1 to disable
    single_timestep_debug: bool = False,  # Noise and return a single timestep
):
    """
    Main training loop

    Steps, in order:

    1. Delete ``results_dir`` if it already exists, recreate it, and write
       the call arguments to ``training_args.json`` and the current Git
       commit hash to ``git_sha.txt``.
    2. Build the train/valid/test noised datasets and wrap them in
       DataLoaders (only the train loader is shuffled).
    3. Unless a debug/synthetic mode is active, plot the training data
       distribution at 11 evenly spaced timesteps into ``plots/``.
    4. Build the model selected by ``implementation``, fit it with a
       Lightning Trainer, and plot the logged losses to ``plots/losses.pdf``.

    ``position_embedding_type`` is only consumed by the
    "huggingface_encoder" implementation; ``min_epochs``, ``lr_scheduler``
    and the dataloader length are only forwarded to the model by the
    "pytorch_encoder" implementation.

    Raises:
        ValueError: if ``implementation`` is not one of the known values.
    """
    # Record the args given to the function before we create more vars
    # https://stackoverflow.com/questions/10724495/getting-all-arguments-and-values-passed-to-a-function
    # Copy the mapping: on older CPython the dict returned by locals() is
    # shared with the frame and can be refreshed (e.g. by a tracer/debugger)
    # to include later, non-JSON-serialisable locals before we dump it.
    func_args = dict(locals())

    # Create results directory
    results_folder = Path(results_dir)
    if results_folder.exists():
        logging.warning(f"Removing old results directory: {results_folder}")
        shutil.rmtree(results_folder)
    # parents=True so a nested results_dir works when its parents are missing
    results_folder.mkdir(parents=True, exist_ok=True)
    with open(results_folder / "training_args.json", "w") as sink:
        logging.info(f"Writing training args to {sink.name}")
        json.dump(func_args, sink, indent=4)
        for k, v in func_args.items():
            logging.info(f"Training argument: {k}={v}")

    # Record current Git version
    repo = git.Repo(
        path=os.path.dirname(os.path.abspath(__file__)), search_parent_directories=True
    )
    sha = repo.head.object.hexsha
    with open(results_folder / "git_sha.txt", "w") as sink:
        sink.write(sha + "\n")

    # Get datasets and wrap them in dataloaders
    dsets = get_train_valid_test_sets(
        timesteps=timesteps,
        variance_schedule=variance_schedule,
        noise_prior=noise_prior,
        shift_to_zero_twopi=shift_angles_zero_twopi,
        var_scale=variance_scale,
        toy=subset,
        syn_noiser=syn_noiser,
        exhaustive_t=exhaustive_validation_t,
        single_dist_debug=single_dist_debug,
        single_angle_debug=single_angle_debug,
        single_time_debug=single_timestep_debug,
    )
    train_dataloader, valid_dataloader, test_dataloader = [
        DataLoader(
            dataset=ds,
            batch_size=batch_size,
            # Shuffle only train loader (index 0); valid/test keep their order
            shuffle=(i == 0),
            num_workers=multiprocessing.cpu_count() if multithread else 1,
        )
        for i, ds in enumerate(dsets)
    ]

    # Create plots in output directories of distributions from different timesteps
    plots_folder = results_folder / "plots"
    os.makedirs(plots_folder, exist_ok=True)
    # Skip this for debug runs
    if (
        single_angle_debug < 0
        and not single_timestep_debug
        and not single_dist_debug
        and not syn_noiser
    ):
        for t in np.linspace(0, timesteps, num=11, endpoint=True).astype(int):
            t = min(t, timesteps - 1)  # Ensure we don't exceed the number of timesteps
            logging.info(f"Plotting distribution at time {t}")
            plotting.plot_val_dists_at_t(
                dsets[0],
                t=t,
                share_axes=False,
                zero_center_angles=not shift_angles_zero_twopi,
                fname=plots_folder / f"train_dists_at_t_{t}.pdf",
            )

    # https://jaketae.github.io/study/relative-positional-encoding/
    # looking at the relative distance between things is more robust

    # The requested loss name is passed through as-is, except in debug or
    # synthetic modes where a fixed loss callable replaces it
    loss_fn = loss
    if single_angle_debug > 0 or single_timestep_debug or syn_noiser:
        loss_fn = functools.partial(losses.radian_smooth_l1_loss, beta=0.1 * np.pi)
    elif single_dist_debug:
        loss_fn = F.smooth_l1_loss

    # Debug/synthetic modes feed the model a single feature instead of four
    model_n_inputs = (
        1
        if single_angle_debug > 0
        or single_timestep_debug
        or single_dist_debug
        or syn_noiser
        else 4
    )

    if implementation == "pytorch_encoder":
        logging.info("Using PyTorch encoder implementation")
        model = modelling.BertDenoiserEncoderModel(
            n_inputs=model_n_inputs,
            time_encoding=time_encoding,
            num_layers=num_hidden_layers,
            d_model=hidden_size,
            intermediate_size=intermediate_size,
            num_heads=num_heads,
            dropout=dropout_p,
            lr=lr,
            loss=loss_fn,
            l2=l2_norm,
            l1=l1_norm,
            circle_reg=circle_reg,
            min_epochs=min_epochs,
            steps_per_epoch=len(train_dataloader),
            lr_scheduler=lr_scheduler,
            write_preds_to_dir=results_folder / "valid_preds",
        )
    elif implementation == "huggingface_encoder":
        logging.info("Using HuggingFace encoder implementation")
        cfg = BertConfig(
            num_attention_heads=num_heads,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            position_embedding_type=position_embedding_type,
            hidden_dropout_prob=dropout_p,
            attention_probs_dropout_prob=dropout_p,
            use_cache=False,
        )
        model = modelling.BertForDiffusion(
            cfg,
            time_encoding=time_encoding,
            n_inputs=model_n_inputs,
            lr=lr,
            loss=loss_fn,
            l2=l2_norm,
            l1=l1_norm,
            circle_reg=circle_reg,
            write_preds_to_dir=results_folder / "valid_preds",
        )
        # Keep the model config alongside the results
        cfg.save_pretrained(results_folder)
    else:
        raise ValueError(f"Unknown implementation: {implementation}")

    callbacks = build_callbacks(early_stop_patience=early_stop_patience, swa=use_swa)
    trainer = pl.Trainer(
        default_root_dir=results_folder,
        gradient_clip_val=gradient_clip,
        min_epochs=min_epochs,
        max_epochs=max_epochs,
        check_val_every_n_epoch=1,
        callbacks=callbacks,
        logger=pl.loggers.CSVLogger(save_dir=results_folder / "logs"),
        log_every_n_steps=min(50, len(train_dataloader)),  # Log at least once per epoch
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
    )
    trainer.fit(
        model=model,
        train_dataloaders=train_dataloader,
        val_dataloaders=valid_dataloader,
    )

    # Locate the metrics CSV written by the CSV logger. The results directory
    # was recreated above, so this run is expected to be version_0.
    metrics_csv = os.path.join(
        trainer.logger.save_dir, "lightning_logs/version_0/metrics.csv"
    )
    assert os.path.isfile(metrics_csv)
    # Plot the losses
    plotting.plot_losses(
        metrics_csv, out_fname=plots_folder / "losses.pdf", simple=True
    )
|
|
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """
    Construct the command-line interface for this script.

    Arguments defined: an optional positional ``config`` (defaults to the
    empty string), ``-o/--outdir`` (defaults to ``<cwd>/results``),
    ``--implementation``, ``--toy`` and the two boolean debug switches
    ``--debug_dist`` and ``--debug_single_time``.
    """
    cli = argparse.ArgumentParser(
        usage=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # nargs="?" makes the positional optional, see
    # https://stackoverflow.com/questions/4480075/argparse-optional-positional-arguments
    cli.add_argument("config", type=str, nargs="?", default="", help="json of params")

    default_outdir = os.path.join(os.getcwd(), "results")
    cli.add_argument(
        "-o",
        "--outdir",
        default=default_outdir,
        type=str,
        help="Directory to write model training outputs",
    )

    known_implementations = ["pytorch_encoder", "huggingface_encoder"]
    cli.add_argument(
        "--implementation",
        default="pytorch_encoder",
        choices=known_implementations,
        type=str,
        help="Which implementation to use",
    )

    cli.add_argument(
        "--toy",
        default=None,
        type=int,
        help="Use a toy dataset of n items rather than full dataset",
    )

    # Boolean debugging switches, off unless given
    cli.add_argument("--debug_dist", help="Debug distances", action="store_true")
    cli.add_argument(
        "--debug_single_time",
        help="Debug single angle and timestep",
        action="store_true",
    )
    return cli
|
|
|
|
|
|
def main():
    """
    Parse the command line, merge it with the optional JSON config file,
    and launch training with the resulting keyword arguments.
    """
    args = build_parser().parse_args()

    # Parameters from the JSON config; stays empty when no config is given
    params = {}
    if args.config:
        with open(args.config) as source:
            params = json.load(source)

    # Values taken from the command line, keyed by train()'s parameter names
    cli_overrides = {
        "results_dir": args.outdir,
        "implementation": args.implementation,
        "subset": args.toy,
        "single_dist_debug": args.debug_dist,
        "single_timestep_debug": args.debug_single_time,
    }
    # NOTE(review): presumably only non-None CLI values replace config
    # entries here — confirm against utils.update_dict_nonnull
    params = utils.update_dict_nonnull(params, cli_overrides)

    train(**params)
|
|
|
|
|
|
# Script entry point: only runs when the file is executed directly
if __name__ == "__main__":
    # Show INFO-level messages (and above) from the logging calls in this script
    logging.basicConfig(level=logging.INFO)
    main()
|