Files
openfold/train_openfold.py
Lukas Jarosch 9660a43d10 Fix distributed seeding behavior
This adds workers=True to the Lightning seed_everything function, which guarantees different random states across all processes in distributed training. Prior to this change, some processes on different GPUs with the same worker ID could share the same random state.

Note that this will break reproducibility between runs prior to and after this change.

Also removes the seed and supress_output modules, which were no longer used in OpenFold.
2024-05-09 17:18:42 +07:00

702 lines
25 KiB
Python

import argparse
import logging
import os
import sys
import json
import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy
from pytorch_lightning.plugins.environments import MPIEnvironment
from pytorch_lightning import seed_everything
import torch
import wandb
from deepspeed.utils import zero_to_fp32
from openfold.config import model_config
from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.argparse_utils import remove_arguments
from openfold.utils.callbacks import (
EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
drmsd,
gdt_ts,
gdt_ha,
)
from openfold.utils.import_weights import (
import_jax_weights_,
import_openfold_weights_
)
from openfold.utils.logger import PerformanceLoggingCallback
class OpenFoldWrapper(pl.LightningModule):
    """Lightning module wrapping an ``AlphaFold`` model for training.

    Bundles the model with its ``AlphaFoldLoss`` and an
    ``ExponentialMovingAverage`` (EMA) of the model parameters. During
    validation the EMA parameters are loaded into the model; the training
    parameters are cached beforehand and restored when the validation epoch
    ends. The EMA state is stored in / restored from checkpoints under the
    ``"ema"`` key.
    """

    def __init__(self, config):
        """Build the model, loss and EMA from ``config``."""
        super().__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.is_multimer = self.config.globals.is_multimer
        self.loss = AlphaFoldLoss(config.loss)
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )

        # Copy of the training weights while EMA weights are loaded for
        # validation; None outside of validation.
        self.cached_weights = None
        # Step passed to the LR scheduler as ``last_epoch``; -1 means
        # "not resuming".
        self.last_lr_step = -1
        self.save_hyperparameters()

    def forward(self, batch):
        """Run the wrapped model on ``batch`` and return its outputs."""
        return self.model(batch)

    def _log(self, loss_breakdown, batch, outputs, train=True):
        """Log every loss term plus the structural metrics for one step.

        Args:
            loss_breakdown: Mapping from loss name to value, as returned by
                the loss when called with ``_return_breakdown=True``.
            batch: Feature dict passed on to ``_compute_validation_metrics``.
            outputs: Model outputs for ``batch``.
            train: Selects the ``"train"`` or ``"val"`` logging prefix.
                Superimposition metrics are only computed when this is False.
        """
        phase = "train" if train else "val"
        for loss_name, indiv_loss in loss_breakdown.items():
            self.log(
                f"{phase}/{loss_name}",
                indiv_loss,
                prog_bar=(loss_name == 'loss'),
                on_step=train, on_epoch=(not train), logger=True, sync_dist=False,
            )

            if train:
                # Training values are logged per step above; also log an
                # epoch-level entry under a separate key.
                self.log(
                    f"{phase}/{loss_name}_epoch",
                    indiv_loss,
                    on_step=False, on_epoch=True, logger=True, sync_dist=False,
                )

        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
                batch,
                outputs,
                superimposition_metrics=(not train)
            )

        for k, v in other_metrics.items():
            self.log(
                f"{phase}/{k}",
                torch.mean(v),
                prog_bar=(k == 'loss'),
                on_step=False, on_epoch=True, logger=True, sync_dist=False,
            )

    def training_step(self, batch, batch_idx):
        """Run one training step and return the total loss."""
        # Keep the EMA copy on the same device as the incoming batch.
        if self.ema.device != batch["aatype"].device:
            self.ema.to(batch["aatype"].device)

        ground_truth = batch.pop('gt_features', None)

        # Run the model
        outputs = self(batch)

        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

        # Compute loss
        loss, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        # Log it
        self._log(loss_breakdown, batch, outputs)

        return loss

    def on_before_zero_grad(self, *args, **kwargs):
        """Update the EMA from the current model parameters."""
        self.ema.update(self.model)

    def validation_step(self, batch, batch_idx):
        """Run one validation step using the EMA weights and log the results."""
        # At the start of validation, load the EMA weights
        if self.cached_weights is None:
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling
            # load_state_dict().
            self.cached_weights = tensor_tree_map(
                lambda t: t.detach().clone(), self.model.state_dict()
            )
            self.model.load_state_dict(self.ema.state_dict()["params"])

        ground_truth = batch.pop('gt_features', None)

        # Run the model
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        batch["use_clamped_fape"] = 0.

        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

        # Compute loss and other metrics
        _, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        self._log(loss_breakdown, batch, outputs, train=False)

    def on_validation_epoch_end(self):
        """Restore the training weights cached by ``validation_step``."""
        # If no validation step ran, nothing was cached and the model still
        # holds its training weights; load_state_dict(None) would raise.
        if self.cached_weights is None:
            return

        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def _compute_validation_metrics(self,
        batch,
        outputs,
        superimposition_metrics=False
    ):
        """Compute structural quality metrics for a prediction.

        Args:
            batch: Feature dict providing ``"all_atom_positions"`` and
                ``"all_atom_mask"``.
            outputs: Model outputs providing ``"final_atom_positions"``.
            superimposition_metrics: When True, additionally superimpose the
                CA coordinates and report ``alignment_rmsd``, ``gdt_ts`` and
                ``gdt_ha``.

        Returns:
            Dict with ``"lddt_ca"`` and ``"drmsd_ca"``, plus the
            superimposition metrics when requested.
        """
        metrics = {}

        gt_coords = batch["all_atom_positions"]
        pred_coords = outputs["final_atom_positions"]
        all_atom_mask = batch["all_atom_mask"]

        # This is super janky for superimposition. Fix later
        gt_coords_masked = gt_coords * all_atom_mask[..., None]
        pred_coords_masked = pred_coords * all_atom_mask[..., None]
        ca_pos = residue_constants.atom_order["CA"]
        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
        all_atom_mask_ca = all_atom_mask[..., ca_pos]

        lddt_ca_score = lddt_ca(
            pred_coords,
            gt_coords,
            all_atom_mask,
            eps=self.config.globals.eps,
            per_residue=False,
        )
        metrics["lddt_ca"] = lddt_ca_score

        drmsd_ca_score = drmsd(
            pred_coords_masked_ca,
            gt_coords_masked_ca,
            mask=all_atom_mask_ca,  # still required here to compute n
        )
        metrics["drmsd_ca"] = drmsd_ca_score

        if superimposition_metrics:
            superimposed_pred, alignment_rmsd = superimpose(
                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
            )
            gdt_ts_score = gdt_ts(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )
            gdt_ha_score = gdt_ha(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )

            metrics["alignment_rmsd"] = alignment_rmsd
            metrics["gdt_ts"] = gdt_ts_score
            metrics["gdt_ha"] = gdt_ha_score

        return metrics

    def configure_optimizers(self,
        learning_rate: float = 1e-3,
        eps: float = 1e-5,
    ) -> dict:
        """Create the Adam optimizer and the per-step LR scheduler.

        Args:
            learning_rate: Learning rate given to Adam (and used as
                ``initial_lr`` when resuming).
            eps: Adam epsilon.

        Returns:
            Dict with the ``"optimizer"`` and an ``"lr_scheduler"`` config
            whose interval is ``"step"``.
        """
        # Ignored as long as a DeepSpeed optimizer is configured
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            eps=eps
        )

        # When resuming at a given step, make sure every param group has an
        # 'initial_lr' entry before the scheduler is built with last_epoch set.
        if self.last_lr_step != -1:
            for group in optimizer.param_groups:
                if 'initial_lr' not in group:
                    group['initial_lr'] = learning_rate

        lr_scheduler = AlphaFoldLRScheduler(
            optimizer,
            last_epoch=self.last_lr_step
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "name": "AlphaFoldLRScheduler",
            }
        }

    def on_load_checkpoint(self, checkpoint):
        """Restore the EMA state stored under ``checkpoint["ema"]``."""
        ema = checkpoint["ema"]
        if not self.model.template_config.enabled:
            # Drop EMA entries whose key mentions "template" when the model
            # has templates disabled.
            ema["params"] = {
                k: v for k, v in ema["params"].items() if "template" not in k
            }
        self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint):
        """Store the EMA state in the checkpoint under ``"ema"``."""
        checkpoint["ema"] = self.ema.state_dict()

    def resume_last_lr_step(self, lr_step):
        """Record the step from which the LR scheduler should resume."""
        self.last_lr_step = lr_step

    def load_from_jax(self, jax_path):
        """Import JAX weights from ``jax_path`` into the wrapped model.

        The version passed to the importer is taken from the file name: the
        base name without extension, minus its first ``_``-separated token.
        """
        model_basename = os.path.splitext(
            os.path.basename(
                os.path.normpath(jax_path)
            )
        )[0]
        model_version = "_".join(model_basename.split("_")[1:])
        import_jax_weights_(
            self.model, jax_path, version=model_version
        )
def get_model_state_dict_from_ds_checkpoint(checkpoint_dir):
    """Load the model state file of the latest DeepSpeed checkpoint.

    Reads the checkpoint tag from the ``latest`` file inside
    ``checkpoint_dir``, asks ``zero_to_fp32`` for the model state file in
    the tagged sub-directory, and returns the result of ``torch.load`` on it.

    Raises:
        ValueError: If ``checkpoint_dir`` contains no ``latest`` file.
    """
    latest_path = os.path.join(checkpoint_dir, 'latest')
    if not os.path.isfile(latest_path):
        raise ValueError(f"Unable to find 'latest' file at {latest_path}")

    with open(latest_path, 'r') as latest_file:
        tag = latest_file.read().strip()

    # based on manual parsing of checkpoint files
    ds_checkpoint_version = 2
    state_file = zero_to_fp32.get_model_state_file(
        os.path.join(checkpoint_dir, tag), ds_checkpoint_version
    )
    return torch.load(state_file)
def main(args):
    """Configure and launch an OpenFold training run.

    Seeds the run (when a seed is given), builds the model config and
    ``OpenFoldWrapper``, optionally loads weights / the last LR step from a
    checkpoint or JAX parameters, builds the data module, callbacks, loggers
    and distributed strategy, and finally calls ``Trainer.fit``.

    Args:
        args: Parsed command-line namespace, as produced by the argument
            parser defined at the bottom of this file.
    """
    if args.seed is not None:
        seed_everything(args.seed, workers=True)

    is_low_precision = args.precision in (
        "bf16-mixed", "16", "bf16", "16-true", "16-mixed")

    config = model_config(
        args.config_preset,
        train=True,
        low_prec=is_low_precision,
    )
    if args.experiment_config_json:
        with open(args.experiment_config_json, 'r') as f:
            custom_config_dict = json.load(f)
        config.update_from_flattened_dict(custom_config_dict)

    model_module = OpenFoldWrapper(config)

    if args.resume_from_ckpt:
        if args.resume_model_weights_only:
            # Load the checkpoint
            if os.path.isdir(args.resume_from_ckpt):
                sd = zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(
                    args.resume_from_ckpt)
            else:
                sd = torch.load(args.resume_from_ckpt)
            # Process the state dict
            if 'module' in sd:
                sd = {k[len('module.'):]: v for k, v in sd['module'].items()}
                import_openfold_weights_(model=model_module, state_dict=sd)
            elif 'state_dict' in sd:
                import_openfold_weights_(
                    model=model_module, state_dict=sd['state_dict'])
            else:
                # Loading from pre-trained model
                sd = {'model.'+k: v for k, v in sd.items()}
                import_openfold_weights_(model=model_module, state_dict=sd)
            logging.info("Successfully loaded model weights...")

        else:  # Loads a checkpoint to start from a specific time step
            if os.path.isdir(args.resume_from_ckpt):
                sd = get_model_state_dict_from_ds_checkpoint(args.resume_from_ckpt)
            else:
                sd = torch.load(args.resume_from_ckpt)
            last_global_step = int(sd['global_step'])
            model_module.resume_last_lr_step(last_global_step)
            logging.info("Successfully loaded last lr step...")

    if args.resume_from_jax_params:
        model_module.load_from_jax(args.resume_from_jax_params)
        logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")

    # TorchScript components of the model
    if args.script_modules:
        script_preset_(model_module)

    if "multimer" in args.config_preset:
        data_module = OpenFoldMultimerDataModule(
            config=config.data,
            batch_seed=args.seed,
            **vars(args)
        )
    else:
        data_module = OpenFoldDataModule(
            config=config.data,
            batch_seed=args.seed,
            **vars(args)
        )

    data_module.prepare_data()
    data_module.setup()

    callbacks = []
    if args.checkpoint_every_epoch:
        mc = ModelCheckpoint(
            every_n_epochs=1,
            auto_insert_metric_name=False,
            save_top_k=-1,
        )
        callbacks.append(mc)

    if args.early_stopping:
        es = EarlyStoppingVerbose(
            monitor="val/lddt_ca",
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
            mode="max",
            check_finite=True,
            strict=True,
        )
        callbacks.append(es)

    if args.log_performance:
        global_batch_size = args.num_nodes * args.gpus
        perf = PerformanceLoggingCallback(
            log_file=os.path.join(args.output_dir, "performance_log.json"),
            global_batch_size=global_batch_size,
        )
        callbacks.append(perf)

    if args.log_lr:
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)

    loggers = []
    # Only evaluated with --mpi_plugin; in that case the PMI_RANK environment
    # variable must be set. Without --mpi_plugin this is always falsy.
    is_rank_zero = args.mpi_plugin and (int(os.environ.get("PMI_RANK")) == 0)
    if args.wandb:
        if args.mpi_plugin and is_rank_zero:
            wandb_init_dict = dict(
                name=args.experiment_name,
                project=args.wandb_project,
                id=args.wandb_id,
                dir=args.output_dir,
                resume="allow",
                anonymous=None,
                entity=args.wandb_entity
            )
            wandb.run = wandb.init(**wandb_init_dict)

        wdb_logger = WandbLogger(
            name=args.experiment_name,
            save_dir=args.output_dir,
            id=args.wandb_id,
            project=args.wandb_project,
            **{"entity": args.wandb_entity}
        )
        loggers.append(wdb_logger)

    cluster_environment = MPIEnvironment() if args.mpi_plugin else None
    if args.deepspeed_config_path is not None:
        strategy = DeepSpeedStrategy(
            config=args.deepspeed_config_path,
            cluster_environment=cluster_environment,
        )
        if args.wandb and is_rank_zero:
            wdb_logger.experiment.save(args.deepspeed_config_path)
            wdb_logger.experiment.save("openfold/config.py")
    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
        strategy = DDPStrategy(find_unused_parameters=False,
                               cluster_environment=cluster_environment)
    else:
        strategy = None

    if args.wandb and is_rank_zero:
        freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
        os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
        wdb_logger.experiment.save(f"{freeze_path}")

    # Names of parsed arguments that are forwarded verbatim to pl.Trainer.
    # 'accumulate_grad_batches' is included so that the
    # --accumulate_grad_batches flag actually reaches the Trainer.
    # NOTE(review): 'flush_logs_ever_n_steps' does not match the
    # --flush_logs_every_n_steps flag, so that flag is never forwarded. Left
    # unchanged on purpose: confirm the installed Lightning Trainer still
    # accepts that argument before correcting the spelling.
    trainer_kws = ['num_nodes', 'precision', 'max_epochs', 'log_every_n_steps',
                   'flush_logs_ever_n_steps', 'num_sanity_val_steps',
                   'reload_dataloaders_every_n_epochs',
                   'accumulate_grad_batches']
    trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
    trainer_args.update({
        'default_root_dir': args.output_dir,
        'strategy': strategy,
        'callbacks': callbacks,
        'logger': loggers,
    })
    trainer = pl.Trainer(**trainer_args)

    # When only the weights were restored above, the trainer starts fresh
    # instead of resuming its state from the checkpoint.
    if args.resume_model_weights_only:
        ckpt_path = None
    else:
        ckpt_path = args.resume_from_ckpt

    trainer.fit(
        model_module,
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )
def bool_type(bool_str: str):
    """Parse a command-line string into a bool (case-insensitively).

    ``true``/``t``/``yes``/``y``/``1`` give True and
    ``false``/``f``/``no``/``n``/``0`` give False.

    Raises:
        ValueError: If ``bool_str`` is none of the recognised spellings.
    """
    normalized = bool_str.lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError(f'Cannot interpret {bool_str} as bool')
if __name__ == "__main__":
    # Command-line entry point: define the arguments, validate a few
    # combinations, then hand the parsed namespace to main().
    # Continuation lines of multi-line help strings are flush-left so that
    # the literal help text is unchanged.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "output_dir", type=str,
        help='''Directory in which to output checkpoints, logs, etc. Ignored
if not on rank 0'''
    )
    parser.add_argument(
        "max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target'''
    )
    parser.add_argument(
        "--train_mmcif_data_cache_path", type=str, default=None,
        help="Path to the json file which records all the information of mmcif structures used during training"
    )
    # Parsed with bool_type (like the other boolean flags) so that an
    # explicit "False" on the command line is not a truthy string.
    parser.add_argument(
        "--use_single_seq_mode", type=bool_type, default=False,
        help="Use single sequence embeddings instead of MSAs."
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--val_mmcif_data_cache_path", type=str, default=None,
        help="path to the json file which records all the information of mmcif structures used during validation"
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_filter_path", type=str, default=None,
        help='''Optional path to a text file containing names of training
examples to include, one per line. Used to filter the training
set'''
    )
    parser.add_argument(
        "--distillation_filter_path", type=str, default=None,
        help="""See --train_filter_path"""
    )
    parser.add_argument(
        "--obsolete_pdbs_file_path", type=str, default=None,
        help="""Path to obsolete.dat file containing list of obsolete PDBs and
their replacements."""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
files."""
    )
    parser.add_argument(
        "--use_small_bfd", type=bool_type, default=False,
        help="Whether to use a reduced version of the BFD database"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed"
    )
    parser.add_argument(
        "--deepspeed_config_path", type=str, default=None,
        help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
    )
    parser.add_argument(
        "--checkpoint_every_epoch", action="store_true", default=False,
        help="""Whether to checkpoint at the end of every training epoch"""
    )
    parser.add_argument(
        "--early_stopping", type=bool_type, default=False,
        help="Whether to stop training when validation loss fails to decrease"
    )
    parser.add_argument(
        "--min_delta", type=float, default=0,
        help="""The smallest decrease in validation loss that counts as an
improvement for the purposes of early stopping"""
    )
    parser.add_argument(
        "--patience", type=int, default=3,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--resume_from_ckpt", type=str, default=None,
        help="Path to a model checkpoint from which to restore training state"
    )
    parser.add_argument(
        "--resume_model_weights_only", type=bool_type, default=False,
        help="Whether to load just model weights as opposed to training state"
    )
    parser.add_argument(
        "--resume_from_jax_params", type=str, default=None,
        help="""Path to an .npz JAX parameter file with which to initialize the model"""
    )
    parser.add_argument(
        "--log_performance", type=bool_type, default=False,
        help="Measure performance"
    )
    parser.add_argument(
        "--wandb", action="store_true", default=False,
        help="Whether to log metrics to Weights & Biases"
    )
    parser.add_argument(
        "--experiment_name", type=str, default=None,
        help="Name of the current experiment. Used for wandb logging"
    )
    parser.add_argument(
        "--wandb_id", type=str, default=None,
        help="ID of a previous run to be resumed"
    )
    parser.add_argument(
        "--wandb_project", type=str, default=None,
        help="Name of the wandb project to which this run will belong"
    )
    parser.add_argument(
        "--wandb_entity", type=str, default=None,
        help="wandb username or team name to which runs are attributed"
    )
    parser.add_argument(
        "--script_modules", type=bool_type, default=False,
        help="Whether to TorchScript eligible components of them model"
    )
    parser.add_argument(
        "--train_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--train_epoch_len", type=int, default=10000,
        help=(
            "The virtual length of each training epoch. Stochastic filtering "
            "of training data means that training datasets have no "
            "well-defined length. This virtual length affects frequency of "
            "validation & checkpointing (by default, one of each per epoch)."
        )
    )
    parser.add_argument(
        "--log_lr", action="store_true", default=False,
        help="Whether to log the actual learning rate"
    )
    parser.add_argument(
        "--config_preset", type=str, default="initial_training",
        help=(
            'Config setting. Choose e.g. "initial_training", "finetuning", '
            '"model_1", etc. By default, the actual values in the config are '
            'used.'
        )
    )
    parser.add_argument(
        "--_distillation_structure_index_path", type=str, default=None,
    )
    parser.add_argument(
        "--alignment_index_path", type=str, default=None,
        help="Training alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--distillation_alignment_index_path", type=str, default=None,
        help="Distillation alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
    )
    parser.add_argument(
        "--gpus", type=int, default=1, help='For determining optimal strategy and effective batch size.'
    )
    parser.add_argument("--mpi_plugin", action="store_true", default=False,
                        help="Whether to use MPI for parallele processing")

    trainer_group = parser.add_argument_group(
        'Arguments to pass to PyTorch Lightning Trainer')
    trainer_group.add_argument(
        "--num_nodes", type=int, default=1,
    )
    trainer_group.add_argument(
        "--precision", type=str, default='bf16',
        help='Sets precision, lower precision improves runtime performance.',
    )
    trainer_group.add_argument(
        "--max_epochs", type=int, default=1,
    )
    trainer_group.add_argument(
        "--log_every_n_steps", type=int, default=25,
    )
    trainer_group.add_argument(
        "--flush_logs_every_n_steps", type=int, default=5,
    )
    trainer_group.add_argument(
        "--num_sanity_val_steps", type=int, default=0,
    )
    trainer_group.add_argument(
        "--reload_dataloaders_every_n_epochs", type=int, default=1,
    )
    trainer_group.add_argument("--accumulate_grad_batches", type=int, default=1,
                               help="Accumulate gradients over k batches before next optimizer step.")

    args = parser.parse_args()

    # Multi-GPU / multi-node runs must be seeded explicitly.
    if(args.seed is None and
        ((args.gpus is not None and args.gpus > 1) or
         (args.num_nodes is not None and args.num_nodes > 1))):
        raise ValueError("For distributed training, --seed must be specified")

    if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
        raise ValueError("DeepSpeed and FP16 training are not compatible")

    # The two weight sources are mutually exclusive.
    if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
        raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path")

    main(args)