Fix missing random module; Implement logging to W&B with...

This commit is contained in:
Nate Corley
2024-03-26 22:16:05 +00:00
committed by Rohith Krishna
parent 666559cd2a
commit bb06653ab5
4 changed files with 88 additions and 46 deletions

2
.gitignore vendored
View File

@@ -14,3 +14,5 @@ __pycache__/
unit_tests/
ruff.toml
*/scratch/
*/wandb/
rf2aa/dataset_20240318.pkl

View File

@@ -133,6 +133,8 @@ loss_param:
log_params:
log_every_n_examples: 1
use_wandb: False
wandb_project: 'rf2aa'
eval_params: null

View File

@@ -6,4 +6,4 @@ def seed_all(seed=0):
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
np.random.seed(seed)

View File

@@ -3,12 +3,15 @@ import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import numpy as np
from functools import partial
import hydra
import os
import time
import wandb
import omegaconf
from contextlib import nullcontext
import datetime
import certifi
from rf2aa.data.compose_dataset import compose_dataset, compose_single_item_dataset
from rf2aa.data.dataloader_adaptor import prepare_input, get_loss_calc_items
@@ -31,7 +34,8 @@ from rf2aa.set_seed import seed_all
os.environ['OMP_NUM_THREADS'] = '4'
os.environ['OPENBLAS_NUM_THREADS'] = '4'
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "max_split_size_mb:512"
# Update environment variable with correct path (needed for W&B upload)
os.environ['REQUESTS_CA_BUNDLE'] = certifi.where()
## To reproduce errors
torch.set_num_threads(4)
@@ -48,7 +52,7 @@ class Trainer:
self.output_dir = "models/"
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
def construct_model(self):
raise NotImplementedError()
@@ -153,7 +157,6 @@ class Trainer:
rank = int (os.environ["SLURM_PROCID"])
print ("Launched from slurm", rank, world_size)
self.train_model(rank, world_size)
#mp.spawn(self.train_model, args=(world_size,), nprocs=world_size, join=True)
else:
print ("Launched from interactive")
@@ -174,7 +177,8 @@ class Trainer:
return gpu
def cleanup(self):
dist.destroy_process_group()
if dist.is_initialized():
dist.destroy_process_group()
def train_model(self, rank, world_size):
""" runs model training on each gpu """
@@ -185,51 +189,67 @@ class Trainer:
init = partial(initialize_chemdata,self.config.chem_params)
init()
train_loader, train_sampler, valid_loaders, valid_samplers = self.construct_dataset(
init, rank, world_size
# Define context manager for training run (either nullcontext or W&B)
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
context_manager = (
wandb.init(
project=self.config.log_params.wandb_project,
config=omegaconf.OmegaConf.to_container(
self.config, resolve=True, throw_on_missing=True
),
name = f"{self.config.experiment.name}_{timestamp}"
)
if self.config.log_params.use_wandb and rank == 0
else nullcontext() # Does nothing
)
self.train_loader = train_loader
self.valid_loaders = valid_loaders
# Without W&B, context manager does nothing
with context_manager:
train_loader, train_sampler, valid_loaders, valid_samplers = self.construct_dataset(
init, rank, world_size
)
# move global information to device
self.move_constants_to_device(gpu)
self.train_loader = train_loader
self.valid_loaders = valid_loaders
self.construct_model(device=gpu)
self.model = DDP(self.model, device_ids=[gpu], find_unused_parameters=False, broadcast_buffers=False)
if rank == 0:
print(f"Loading model with {count_parameters(self.model)} parameters")
# move global information to device
self.move_constants_to_device(gpu)
self.construct_optimizer()
self.construct_scheduler()
self.construct_scaler()
if self.config.training_params.resume_train:
self.load_checkpoint(gpu)
self.load_model()
try:
self.load_optimizer()
self.load_scheduler()
self.load_scaler()
except Exception as ex:
print ('Error in loading optimizer parameters:',ex)
print ('Continuing...')
self.construct_model(device=gpu)
self.model = DDP(self.model, device_ids=[gpu], find_unused_parameters=False, broadcast_buffers=False)
if rank == 0:
print(f"Loading model with {count_parameters(self.model)} parameters")
self.recycle_schedule = recycle_sampling["by_batch"](self.config.loader_params.maxcycle,
self.config.experiment.n_epoch,
self.config.dataset_params.n_train,
world_size)
#self.valid_epoch(-1, rank, world_size)
for epoch in range(self.config.experiment.n_epoch):
train_sampler.set_epoch(epoch) #TODO: need to make sure each gpu gets a different example
self.train_epoch(epoch, rank, world_size)
for _, valid_sampler in valid_samplers.items():
valid_sampler.set_epoch(epoch)
self.construct_optimizer()
self.construct_scheduler()
self.construct_scaler()
if self.config.training_params.resume_train:
self.load_checkpoint(gpu)
self.load_model()
try:
self.load_optimizer()
self.load_scheduler()
self.load_scaler()
except Exception as ex:
print ('Error in loading optimizer parameters:',ex)
print ('Continuing...')
if (
self.config.dataset_params.validate_every_n_epochs > 0
and epoch % self.config.dataset_params.validate_every_n_epochs==0
):
self.valid_epoch(epoch, rank, world_size)
self.recycle_schedule = recycle_sampling["by_batch"](self.config.loader_params.maxcycle,
self.config.experiment.n_epoch,
self.config.dataset_params.n_train,
world_size)
for epoch in range(self.config.experiment.n_epoch):
train_sampler.set_epoch(epoch) #TODO: need to make sure each gpu gets a different example
self.train_epoch(epoch, rank, world_size)
for _, valid_sampler in valid_samplers.items():
valid_sampler.set_epoch(epoch)
if (
self.config.dataset_params.validate_every_n_epochs > 0
and epoch % self.config.dataset_params.validate_every_n_epochs==0
):
self.valid_epoch(epoch, rank, world_size)
self.cleanup()
@@ -260,6 +280,10 @@ class Trainer:
inputs, loss_dict, n_cycle,
(train_idx+1)*world_size, len(self.train_loader)*world_size, train_time
)
# If using W&B, log the intermediate losses (note: this is only done for rank = 0)
if self.config.log_params.use_wandb:
wandb.log(loss_dict)
torch.cuda.empty_cache()
if rank == 0:
@@ -296,6 +320,9 @@ class Trainer:
if rank==0:
self.log_validation_losses(dataset_name, valid_loss_dict)
# If using W&B, log the validation losses (note: this is only done for rank = 0)
if self.config.log_params.use_wandb:
wandb.log(valid_loss_dict)
def train_step(self, inputs, n_cycle):
""" take an input from dataloader, run the model and compute a loss """
@@ -419,7 +446,18 @@ class ComposedTrainer(Trainer):
def main(config):
seed_all()
trainer = trainer_factory[config.experiment.trainer](config=config)
trainer.launch_distributed_training()
# Wrap the training in a try-except block to ensure SLURM cleanup post-interrupt (otherwise, we'd need to change the SLURM id each run)
try:
trainer.launch_distributed_training()
except KeyboardInterrupt:
print("Training interrupted by user.")
except Exception as e:
print("Training interrupted by exception:", e)
raise e
finally:
print("Cleaning up...")
trainer.cleanup()
trainer_factory = {
"legacy": LegacyTrainer,