mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 21:34:32 +08:00
584 lines
21 KiB
Python
584 lines
21 KiB
Python
"""
|
|
Training script.
|
|
|
|
Example usage: python ~/protdiff/bin/train.py ~/protdiff/config_jsons/full_run_canonical_angles_only_zero_centered_1000_timesteps_reduced_len.json
|
|
"""
|
|
|
|
import os, sys
|
|
import shutil
|
|
import json
|
|
import logging
|
|
from pathlib import Path
|
|
import multiprocessing
|
|
import argparse
|
|
import functools
|
|
from datetime import datetime
|
|
from typing import *
|
|
|
|
import numpy as np
|
|
from matplotlib import pyplot as plt
|
|
|
|
import torch
|
|
from torch.utils.data import Dataset, Subset
|
|
from torch.utils.data.dataloader import DataLoader
|
|
import torch.nn.functional as F
|
|
|
|
import pytorch_lightning as pl
|
|
from pytorch_lightning.strategies.ddp import DDPStrategy
|
|
|
|
from transformers import BertConfig
|
|
|
|
from foldingdiff import datasets
|
|
from foldingdiff import modelling
|
|
from foldingdiff import losses
|
|
from foldingdiff import beta_schedules
|
|
from foldingdiff import plotting
|
|
from foldingdiff import utils
|
|
from foldingdiff import custom_metrics as cm
|
|
|
|
# Training is GPU-only: fail fast when this module is imported without CUDA.
# NOTE(review): since this check runs at import time, `train(cpu_only=True)` /
# `--cpu` can never be reached on a machine without CUDA — confirm intended.
assert torch.cuda.is_available(), "Requires CUDA to train"

# Reproducibility: turn off cuDNN autotuning (it can select different
# algorithms from run to run) and fix the global torch RNG seed.
torch.backends.cudnn.benchmark = False
torch.manual_seed(6489)
# torch.use_deterministic_algorithms(True)

# Supported featurizations; these are the keys of the clean-dataset class
# lookup in get_train_valid_test_sets.
ANGLES_DEFINITIONS = Literal[
    "canonical",
    "canonical-full-angles",
    "canonical-minimal-angles",
    "cart-coords",
]
|
|
|
|
|
|
@pl.utilities.rank_zero_only
def plot_timestep_distributions(
    train_dset,
    timesteps: int,
    plots_folder: Path,
    shift_angles_zero_twopi: bool = False,
    n_intervals: int = 11,
) -> None:
    """
    Plot the training-set distributions at evenly spaced diffusion timesteps.

    One file per sampled timestep is written to
    ``plots_folder / f"train_dists_at_t_{t}.pdf"`` by
    ``plotting.plot_val_dists_at_t``. The work is spread over a process pool
    with one task per timestep. Decorated with ``rank_zero_only`` so only the
    rank-0 process plots under distributed training.

    Args:
        train_dset: dataset forwarded to ``plotting.plot_val_dists_at_t``.
        timesteps: total number of diffusion timesteps; sampled values are
            clipped to ``timesteps - 1``.
        plots_folder: directory the PDFs are written into.
        shift_angles_zero_twopi: its negation is forwarded as the fourth
            positional argument of ``plotting.plot_val_dists_at_t``.
        n_intervals: number of evenly spaced samples taken over
            ``[0, timesteps]`` before clipping and de-duplication.
    """
    ts = np.linspace(0, timesteps, num=n_intervals, endpoint=True).astype(int)
    # Clip to the last valid timestep and drop duplicates. When timesteps is
    # small relative to n_intervals several samples collapse onto the same t,
    # and two workers would otherwise race to write the same output file.
    ts = sorted(set(np.minimum(ts, timesteps - 1).tolist()))
    if not ts:
        # Pool(processes=0) raises, so there is nothing to do here
        logging.info("No timesteps selected for distribution plots; skipping")
        return
    logging.info(f"Plotting distributions at {ts} to {plots_folder}")
    args = [
        (
            t,
            train_dset,
            True,
            not shift_angles_zero_twopi,
            plots_folder / f"train_dists_at_t_{t}.pdf",
        )
        for t in ts
    ]

    # Parallelize the plotting. The context manager terminates the workers if
    # starmap raises; on success we close/join so workers exit cleanly.
    with multiprocessing.Pool(
        processes=min(multiprocessing.cpu_count(), len(ts))
    ) as pool:
        pool.starmap(plotting.plot_val_dists_at_t, args)
        pool.close()
        pool.join()
|
|
|
|
|
|
@pl.utilities.rank_zero_only
def plot_kl_divergence(train_dset, plots_folder: Path) -> None:
    """
    Plot the per-feature KL divergence across diffusion timesteps.

    The values come from ``cm.kl_from_dset(train_dset)``, an array of shape
    ``(n_timesteps, n_features)``. The figure is titled
    "KL(empirical || Gaussian)". One subplot is drawn per feature in
    ``train_dset.feature_names["angles"]``, and the figure is saved to
    ``plots_folder / "kl_divergence_timesteps.pdf"``. If that file already
    exists, nothing is computed.

    Args:
        train_dset: noised training dataset. It must expose
            ``feature_names["angles"]`` and ``timesteps``.
        plots_folder: directory to write the plot into.
    """
    # This works because the main body of this script should clean out the dir
    # between runs
    outname = plots_folder / "kl_divergence_timesteps.pdf"
    if outname.is_file():
        logging.info(f"KL divergence plot exists at {outname}; skipping...")
        return

    kl_at_timesteps = cm.kl_from_dset(train_dset)  # Shape (n_timesteps, n_features)
    n_timesteps, n_features = kl_at_timesteps.shape

    # squeeze=False always yields a 2D array of Axes, even for a single
    # feature, so iterating over axes.flat below is always valid.
    fig, axes = plt.subplots(
        dpi=300,
        figsize=(n_features * 3.05, 2.5),
        ncols=n_features,
        sharey=True,
        squeeze=False,
    )
    for i, (ft_name, ax) in enumerate(
        zip(train_dset.feature_names["angles"], axes.flat)
    ):
        ax.plot(np.arange(n_timesteps), kl_at_timesteps[:, i], label=ft_name)
        ax.axhline(0, color="grey", linestyle="--", alpha=0.5)
        ax.set(title=ft_name)
        if i == 0:
            ax.set(ylabel="KL divergence")
        ax.set(xlabel="Timestep")
    fig.suptitle(
        f"KL(empirical || Gaussian) over timesteps={train_dset.timesteps}", y=1.05
    )
    fig.savefig(outname, bbox_inches="tight")
    # Release the figure; pyplot otherwise keeps it alive for the whole run
    plt.close(fig)
|
|
|
|
|
|
def get_train_valid_test_sets(
    dataset_key: str = "cath",
    angles_definitions: ANGLES_DEFINITIONS = "canonical-full-angles",
    max_seq_len: int = 512,
    min_seq_len: int = 0,
    seq_trim_strategy: datasets.TRIM_STRATEGIES = "leftalign",
    timesteps: int = 250,
    variance_schedule: beta_schedules.SCHEDULES = "linear",
    var_scale: float = np.pi,
    toy: Union[int, bool] = False,
    exhaustive_t: bool = False,
    syn_noiser: str = "",
    single_angle_debug: int = -1,  # Noise and return a single angle. -1 to disable, 1-3 for omega/theta/phi
    single_time_debug: bool = False,  # Noise and return a single time
    train_only: bool = False,
) -> Tuple[Dataset, Optional[Dataset], Optional[Dataset]]:
    """
    Build the noised dataset objects used for train/valid/test.

    A "clean" dataset is built per split, with the class chosen by
    ``angles_definitions``. Each one is then wrapped in a noiser dataset. The
    noiser is ``NoisedAnglesDataset`` unless one of the synthetic or debug
    options selects another.

    Args:
        dataset_key: forwarded as ``pdbs`` to the clean dataset class.
        angles_definitions: which featurization / clean dataset class to use.
        max_seq_len: forwarded as ``pad`` to the clean dataset class.
        min_seq_len: forwarded as ``min_length`` to the clean dataset class.
        seq_trim_strategy: forwarded as ``trim_strategy``.
        timesteps: number of diffusion timesteps for the noiser.
        variance_schedule: beta schedule name for the noiser.
        var_scale: forwarded as ``angular_variance`` to the noiser.
        toy: forwarded as ``toy`` to the clean dataset class.
        exhaustive_t: enable ``exhaustive_t`` on the non-train splits only.
        syn_noiser: ``""`` for none, ``"halfhalf"`` for the synthetic
            position-based noiser.
        single_angle_debug: ``-1`` disables it; a positive index noises only
            that feature.
        single_time_debug: use the single angle + single time noiser.
        train_only: only build the train split.

    Returns:
        A 3-tuple ``(train, valid, test)``. When ``train_only`` is True the
        valid and test entries are ``None``. These still need to be wrapped
        in DataLoaders.

    Raises:
        AssertionError: if ``single_angle_debug`` is 0.
        KeyError: if ``angles_definitions`` is not a supported key.
        ValueError: if ``syn_noiser`` is non-empty and not ``"halfhalf"``.
    """
    assert (
        single_angle_debug != 0
    ), f"Invalid value for single_angle_debug: {single_angle_debug}"

    clean_dset_class = {
        "canonical": datasets.CathCanonicalAnglesDataset,
        "canonical-full-angles": datasets.CathCanonicalAnglesOnlyDataset,
        "canonical-minimal-angles": datasets.CathCanonicalMinimalAnglesDataset,
        "cart-coords": datasets.CathCanonicalCoordsDataset,
    }[angles_definitions]
    logging.info(f"Clean dataset class: {clean_dset_class}")

    splits = ["train"] if train_only else ["train", "validation", "test"]
    logging.info(f"Creating data splits: {splits}")
    clean_dsets = [
        clean_dset_class(
            pdbs=dataset_key,
            split=s,
            pad=max_seq_len,
            min_length=min_seq_len,
            trim_strategy=seq_trim_strategy,
            zero_center=False if angles_definitions == "cart-coords" else True,
            toy=toy,
        )
        for s in splits
    ]
    assert len(clean_dsets) == len(splits)
    # Copy the TRAINING set's mean offset onto the validation/test sets so that
    # every split is (presumably, via zero_center) centered with train stats.
    if len(clean_dsets) > 1 and clean_dsets[0].means is not None:
        logging.info(f"Updating valid/test mean offset to {clean_dsets[0].means}")
        for ds in clean_dsets[1:]:
            ds.means = clean_dsets[0].means

    if syn_noiser != "":
        if syn_noiser == "halfhalf":
            logging.warning("Using synthetic half-half noiser")
            dset_noiser_class = datasets.SynNoisedByPositionDataset
        else:
            raise ValueError(f"Unknown synthetic noiser {syn_noiser}")
    else:
        if single_angle_debug > 0:
            logging.warning("Using single angle noise!")
            dset_noiser_class = functools.partial(
                datasets.SingleNoisedAngleDataset, ft_idx=single_angle_debug
            )
        elif single_time_debug:
            logging.warning("Using single angle and single time noise!")
            dset_noiser_class = datasets.SingleNoisedAngleAndTimeDataset
        else:
            dset_noiser_class = datasets.NoisedAnglesDataset

    logging.info(f"Using {dset_noiser_class} for noise")
    noised_dsets = [
        dset_noiser_class(
            dset=ds,
            dset_key="coords" if angles_definitions == "cart-coords" else "angles",
            timesteps=timesteps,
            # Exhaustive t enumeration only applies to the non-train splits
            exhaustive_t=(i != 0) and exhaustive_t,
            beta_schedule=variance_schedule,
            nonangular_variance=1.0,
            angular_variance=var_scale,
        )
        for i, ds in enumerate(clean_dsets)
    ]
    for dsname, ds in zip(splits, noised_dsets):
        logging.info(f"{dsname}: {ds}")

    # Pad with None values so callers can always unpack (train, valid, test)
    noised_dsets.extend([None] * (3 - len(noised_dsets)))
    assert len(noised_dsets) == 3

    return tuple(noised_dsets)
|
|
|
|
|
|
def build_callbacks(
    outdir: str, early_stop_patience: Optional[int] = None, swa: bool = False
):
    """
    Create the output subdirectories and assemble the Lightning callbacks.

    The returned list contains, in order:
      1. a checkpoint callback keeping the 5 lowest-``val_loss`` weights;
      2. a checkpoint callback keeping the 5 lowest-``train_loss`` weights;
      3. an epoch-level learning-rate monitor;
      4. (if ``early_stop_patience`` is a positive number) early stopping
         on ``val_loss``;
      5. (if ``swa``) stochastic weight averaging.

    Args:
        outdir: root output directory; log and checkpoint dirs go below it.
        early_stop_patience: patience for early stopping; ``None`` or a value
            ``<= 0`` disables it.
        swa: whether to append a ``StochasticWeightAveraging`` callback.
    """
    # Create the logging and checkpoint dirs up front
    for subdir in (
        "logs/lightning_logs",
        "models/best_by_valid",
        "models/best_by_train",
    ):
        os.makedirs(os.path.join(outdir, subdir), exist_ok=True)

    # One "keep the 5 lowest" checkpointer per monitored metric
    checkpoint_targets = (
        ("val_loss", "models/best_by_valid"),
        ("train_loss", "models/best_by_train"),
    )
    callbacks = [
        pl.callbacks.ModelCheckpoint(
            monitor=metric_name,
            dirpath=os.path.join(outdir, ckpt_subdir),
            save_top_k=5,
            save_weights_only=True,
            mode="min",
        )
        for metric_name, ckpt_subdir in checkpoint_targets
    ]
    callbacks.append(pl.callbacks.LearningRateMonitor(logging_interval="epoch"))

    # None is treated like 0 (disabled)
    if (early_stop_patience or 0) > 0:
        logging.info(f"Using early stopping with patience {early_stop_patience}")
        callbacks.append(
            pl.callbacks.early_stopping.EarlyStopping(
                monitor="val_loss",
                patience=early_stop_patience,
                verbose=True,
                mode="min",
            )
        )

    if swa:
        # Stochastic weight averaging
        callbacks.append(pl.callbacks.StochasticWeightAveraging())

    logging.info(f"Model callbacks: {callbacks}")
    return callbacks
|
|
|
|
|
|
# For some arg defaults, see as reference:
|
|
# https://huggingface.co/docs/transformers/main/en/main_classes/trainer.html
|
|
|
|
|
|
@pl.utilities.rank_zero_only
def record_args_and_metadata(func_args: Dict[str, Any], results_folder: Path):
    """
    Reset the results directory and record the training arguments and Git SHA.

    Any existing ``results_folder`` is deleted and recreated. ``func_args`` is
    written as JSON to ``training_args.json`` and each entry is logged. If
    GitPython is importable and this file is inside a Git repository, the HEAD
    commit SHA is written to ``git_sha.txt``. Failure to determine the SHA is
    logged as a warning and does not abort training. Decorated with
    ``rank_zero_only`` so only the rank-0 process touches the filesystem.

    Args:
        func_args: argument name -> value mapping. It must be
            JSON-serializable.
        results_folder: output directory to (re)create.
    """
    # Create results directory
    if results_folder.exists():
        logging.warning(f"Removing old results directory: {results_folder}")
        shutil.rmtree(results_folder)
    # parents=True so nested results paths (e.g. runs/exp1) also work
    results_folder.mkdir(parents=True, exist_ok=True)
    with open(results_folder / "training_args.json", "w") as sink:
        logging.info(f"Writing training args to {sink.name}")
        json.dump(func_args, sink, indent=4)
    for k, v in func_args.items():
        logging.info(f"Training argument: {k}={v}")

    # Record current Git version. The import is handled on its own: if it
    # fails, the name `git` is unbound, so an `except git.exc...` clause
    # evaluated before the ImportError handler would itself raise.
    try:
        import git
    except ImportError as err:
        logging.warning(
            f"Could not determine Git repo status -- GitPython is not installed ({err})"
        )
        return

    try:
        repo = git.Repo(
            path=os.path.dirname(os.path.abspath(__file__)),
            search_parent_directories=True,
        )
        sha = repo.head.object.hexsha
    except git.exc.InvalidGitRepositoryError:
        logging.warning("Could not determine Git repo status -- not a git repo")
        return

    with open(results_folder / "git_sha.txt", "w") as sink:
        sink.write(sha + "\n")
|
|
|
|
|
|
def train(
    # Controls output
    results_dir: str = "./results",
    # Controls data loading and noising process
    dataset_key: str = "cath",  # cath, alphafold, or a directory containing pdb files
    angles_definitions: ANGLES_DEFINITIONS = "canonical-full-angles",
    max_seq_len: int = 512,
    min_seq_len: int = 0,  # 0 means no filtering based on min sequence length
    trim_strategy: datasets.TRIM_STRATEGIES = "leftalign",
    timesteps: int = 250,
    variance_schedule: beta_schedules.SCHEDULES = "linear",  # cosine better on single angle toy test
    variance_scale: float = 1.0,
    # Related to model architecture
    time_encoding: modelling.TIME_ENCODING = "gaussian_fourier",
    num_hidden_layers: int = 12,  # Default 12
    hidden_size: int = 384,  # Default 768
    intermediate_size: int = 768,  # Default 3072
    num_heads: int = 12,  # Default 12
    position_embedding_type: Literal[
        "absolute", "relative_key", "relative_key_query"
    ] = "absolute",  # relative_key = https://arxiv.org/pdf/1803.02155.pdf | relative_key_query = https://arxiv.org/pdf/2009.13658.pdf
    dropout_p: float = 0.1,  # Default 0.1, can disable for debugging
    decoder: modelling.DECODER_HEAD = "mlp",
    # Related to training strategy
    gradient_clip: float = 1.0,  # From BERT trainer
    batch_size: int = 64,
    lr: float = 5e-5,  # Default lr for huggingface BERT trainer
    loss: modelling.LOSS_KEYS = "smooth_l1",
    use_pdist_loss: Union[
        float, Tuple[float, float]
    ] = 0.0,  # Use the pairwise distances between CAs as an additional loss term, multiplied by this scalar
    l2_norm: float = 0.0,  # AdamW default has 0.01 L2 regularization, but BERT trainer uses 0.0
    l1_norm: float = 0.0,
    circle_reg: float = 0.0,
    min_epochs: Optional[int] = None,
    max_epochs: int = 10000,
    early_stop_patience: int = 0,  # Set to 0 to disable early stopping
    lr_scheduler: modelling.LR_SCHEDULE = None,
    use_swa: bool = False,  # Stochastic weight averaging can improve training generalization
    # Misc. and debugging
    multithread: bool = True,
    subset: Union[bool, int] = False,  # Subset to n training examples
    exhaustive_validation_t: bool = False,  # Exhaustively enumerate t for validation/test
    syn_noiser: str = "",  # If specified, use a synthetic noiser
    single_angle_debug: int = -1,  # Noise and return a single angle, choose [1, 2, 3] or -1 to disable
    single_timestep_debug: bool = False,  # Noise and return a single timestep
    cpu_only: bool = False,
    ngpu: int = -1,  # -1 for all GPUs
    write_valid_preds: bool = False,  # Write validation predictions to disk at each epoch
    dryrun: bool = False,  # Disable some frills for a fast run to just train
):
    """
    Main training loop.

    The steps are:
      1. Record the arguments (and Git SHA) into ``results_dir``.
      2. Build the noised train/valid/test datasets and their DataLoaders.
      3. Optionally plot diagnostic distributions.
      4. Construct a ``BertForDiffusion`` model and fit it with PyTorch
         Lightning.
      5. Plot the logged losses.

    All parameters are keyword options (``main`` passes them from a JSON
    config plus CLI overrides). See the inline comment on each parameter.
    """
    # Record the args given to the function before we create more vars; this
    # MUST stay the first statement so locals() holds only the arguments.
    # https://stackoverflow.com/questions/10724495/getting-all-arguments-and-values-passed-to-a-function
    func_args = locals()

    results_folder = Path(results_dir)
    record_args_and_metadata(func_args, results_folder)

    # Get datasets and wrap them in dataloaders
    dsets = get_train_valid_test_sets(
        dataset_key=dataset_key,
        angles_definitions=angles_definitions,
        max_seq_len=max_seq_len,
        min_seq_len=min_seq_len,
        seq_trim_strategy=trim_strategy,
        timesteps=timesteps,
        variance_schedule=variance_schedule,
        var_scale=variance_scale,
        toy=subset,
        syn_noiser=syn_noiser,
        exhaustive_t=exhaustive_validation_t,
        single_angle_debug=single_angle_debug,
        single_time_debug=single_timestep_debug,
    )
    # Record the masked means in the output directory
    np.save(
        results_folder / "training_mean_offset.npy",
        dsets[0].dset.get_masked_means(),
        fix_imports=False,
    )
    # Record the exact files used for training
    for dset_name, dset in zip(["train", "valid", "test"], dsets):
        with open(results_folder / f"{dset_name}_files.txt", "w") as f:
            f.write("\n".join(dset.dset.filenames))

    # Work out how many GPUs this run will actually use. This respects
    # cpu_only and a positive ngpu limit, and drives both the per-device batch
    # size and the choice of distributed strategy below.
    use_cuda = not cpu_only and torch.cuda.is_available()
    n_gpus = torch.cuda.device_count() if use_cuda else 0
    if n_gpus > 0 and ngpu > 0:
        n_gpus = min(ngpu, n_gpus)

    # Calculate the per-device batch size.
    # https://pytorch-lightning.readthedocs.io/en/1.4.0/advanced/multi_gpu.html#batch-size
    # Under DDP the global batch is per-device batch * num_gpus * num_nodes, so
    # split the requested batch size across the GPUs that will be used.
    effective_batch_size = batch_size
    if n_gpus > 0:
        # Never drop to 0: DataLoader rejects a non-positive batch size
        effective_batch_size = max(1, int(batch_size / n_gpus))
        pl.utilities.rank_zero_info(
            f"Given batch size: {batch_size} --> effective batch size with {n_gpus} GPUs: {effective_batch_size}"
        )

    train_dataloader, valid_dataloader, test_dataloader = [
        DataLoader(
            dataset=ds,
            batch_size=effective_batch_size,
            shuffle=i == 0,  # Shuffle only train loader
            num_workers=multiprocessing.cpu_count() if multithread else 1,
            pin_memory=True,
        )
        for i, ds in enumerate(dsets)
    ]

    # Create plots in output directories of distributions from different timesteps
    plots_folder = results_folder / "plots"
    os.makedirs(plots_folder, exist_ok=True)
    # Skip this for debug runs
    if (
        single_angle_debug < 0
        and not single_timestep_debug
        and not syn_noiser
        and not dryrun
    ):
        plot_kl_divergence(dsets[0], plots_folder)
        plot_timestep_distributions(
            dsets[0],
            timesteps=timesteps,
            plots_folder=plots_folder,
        )

    # https://jaketae.github.io/study/relative-positional-encoding/
    # looking at the relative distance between things is more robust

    loss_fn = loss
    if single_angle_debug > 0 or single_timestep_debug or syn_noiser:
        loss_fn = functools.partial(losses.radian_smooth_l1_loss, beta=0.1 * np.pi)
    logging.info(f"Using loss function: {loss_fn}")

    # Shape of the input is (batch_size, timesteps, features)
    sample_input = dsets[0][0]["corrupted"]  # First item of the training dset
    model_n_inputs = sample_input.shape[-1]
    logging.info(f"Auto detected {model_n_inputs} inputs")

    cfg = BertConfig(
        max_position_embeddings=max_seq_len,
        num_attention_heads=num_heads,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        num_hidden_layers=num_hidden_layers,
        position_embedding_type=position_embedding_type,
        hidden_dropout_prob=dropout_p,
        attention_probs_dropout_prob=dropout_p,
        use_cache=False,
    )
    # ft_is_angular from the clean datasets angularity definition
    ft_key = "coords" if angles_definitions == "cart-coords" else "angles"
    model = modelling.BertForDiffusion(
        config=cfg,
        time_encoding=time_encoding,
        decoder=decoder,
        ft_is_angular=dsets[0].dset.feature_is_angular[ft_key],
        ft_names=dsets[0].dset.feature_names[ft_key],
        lr=lr,
        loss=loss_fn,
        use_pairwise_dist_loss=use_pdist_loss
        if isinstance(use_pdist_loss, float)
        else [*use_pdist_loss, timesteps],
        l2=l2_norm,
        l1=l1_norm,
        circle_reg=circle_reg,
        epochs=max_epochs,
        steps_per_epoch=len(train_dataloader),
        lr_scheduler=lr_scheduler,
        write_preds_to_dir=results_folder / "valid_preds"
        if write_valid_preds
        else None,
    )
    # https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f"Model has {num_params} trainable parameters")

    cfg.save_pretrained(results_folder)

    callbacks = build_callbacks(
        outdir=results_folder, early_stop_patience=early_stop_patience, swa=use_swa
    )

    # Get accelerator and distributed strategy
    accelerator, strategy = "cpu", None
    if use_cuda:
        accelerator = "cuda"
        if n_gpus > 1:
            # https://github.com/Lightning-AI/lightning/discussions/6761
            strategy = DDPStrategy(find_unused_parameters=False)

    logging.info(f"Using {accelerator} with strategy {strategy}")
    trainer = pl.Trainer(
        default_root_dir=results_folder,
        gradient_clip_val=gradient_clip,
        min_epochs=min_epochs,
        max_epochs=max_epochs,
        check_val_every_n_epoch=1,
        callbacks=callbacks,
        logger=pl.loggers.CSVLogger(save_dir=results_folder / "logs"),
        log_every_n_steps=min(200, len(train_dataloader)),  # Log >= once per epoch
        accelerator=accelerator,
        strategy=strategy,
        # Only request GPUs for CUDA runs; the default ngpu=-1 ("all GPUs")
        # must not be forwarded alongside accelerator="cpu".
        gpus=ngpu if accelerator == "cuda" else None,
        enable_progress_bar=False,
        move_metrics_to_cpu=False,  # Saves memory
    )
    trainer.fit(
        model=model,
        train_dataloaders=train_dataloader,
        val_dataloaders=valid_dataloader,
    )

    # Plot the losses
    metrics_csv = os.path.join(
        trainer.logger.save_dir, "lightning_logs/version_0/metrics.csv"
    )
    assert os.path.isfile(metrics_csv)
    plotting.plot_losses(
        metrics_csv, out_fname=plots_folder / "losses.pdf", simple=True
    )
|
|
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """
    Construct the command-line interface for the training script.

    The single optional positional argument is a JSON file of ``train``
    parameters. The flags cover the output directory, toy-data and
    single-timestep debugging, CPU/GPU selection, and dry runs.
    """
    parser = argparse.ArgumentParser(
        usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    add = parser.add_argument

    # Optional positional: https://stackoverflow.com/questions/4480075/argparse-optional-positional-arguments
    add("config", nargs="?", default="", type=str, help="json of params")
    add(
        "-o",
        "--outdir",
        type=str,
        default=os.path.join(os.getcwd(), "results"),
        help="Directory to write model training outputs",
    )
    add(
        "--toy",
        type=int,
        default=None,
        help="Use a toy dataset of n items rather than full dataset",
    )
    add(
        "--debug_single_time",
        action="store_true",
        help="Debug single angle and timestep",
    )
    add("--cpu", action="store_true", help="Force use CPU")
    add("--ngpu", type=int, default=-1, help="Number of GPUs to use (-1 for all)")
    add("--dryrun", action="store_true", help="Dry run")
    return parser
|
|
|
|
|
|
def main():
    """
    Entry point: parse the CLI, merge it with the optional JSON config, train.

    The JSON file named by the positional ``config`` argument (if any)
    supplies the base keyword arguments for ``train``. The CLI values are then
    merged over it via ``utils.update_dict_nonnull``.
    """
    cli = build_parser().parse_args()

    # Base parameters come from the JSON config; empty when none is given
    config_args = {}
    if cli.config:
        with open(cli.config) as source:
            config_args = json.load(source)

    # NOTE(review): the store_true flags default to False, and --ngpu/--outdir
    # have non-None defaults. If update_dict_nonnull only skips None values,
    # these CLI defaults always override the matching JSON entries — confirm
    # that is intended.
    cli_overrides = {
        "results_dir": cli.outdir,
        "subset": cli.toy,
        "single_timestep_debug": cli.debug_single_time,
        "cpu_only": cli.cpu,
        "ngpu": cli.ngpu,
        "dryrun": cli.dryrun,
    }
    train(**utils.update_dict_nonnull(config_args, cli_overrides))
|
|
|
|
|
|
if __name__ == "__main__":
    # Log INFO and above both to the console and to a timestamped file in
    # the current working directory.
    log_filename = f"training_{datetime.now().strftime('%y%m%d_%H%M%S')}.log"
    log_handlers = [logging.FileHandler(log_filename), logging.StreamHandler()]
    logging.basicConfig(level=logging.INFO, handlers=log_handlers)

    main()
|