mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
Fix missing random module; Implement logging to W&B with...
This commit is contained in:
committed by
Rohith Krishna
parent
666559cd2a
commit
bb06653ab5
2
.gitignore
vendored
2
.gitignore
vendored
@@ -14,3 +14,5 @@ __pycache__/
|
||||
unit_tests/
|
||||
ruff.toml
|
||||
*/scratch/
|
||||
*/wandb/
|
||||
rf2aa/dataset_20240318.pkl
|
||||
|
||||
@@ -133,6 +133,8 @@ loss_param:
|
||||
|
||||
log_params:
|
||||
log_every_n_examples: 1
|
||||
use_wandb: False
|
||||
wandb_project: 'rf2aa'
|
||||
|
||||
eval_params: null
|
||||
|
||||
|
||||
@@ -6,4 +6,4 @@ def seed_all(seed=0):
|
||||
random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
@@ -3,12 +3,15 @@ import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import torch.multiprocessing as mp
|
||||
import numpy as np
|
||||
|
||||
from functools import partial
|
||||
|
||||
import hydra
|
||||
import os
|
||||
import time
|
||||
import wandb
|
||||
import omegaconf
|
||||
from contextlib import nullcontext
|
||||
import datetime
|
||||
import certifi
|
||||
|
||||
from rf2aa.data.compose_dataset import compose_dataset, compose_single_item_dataset
|
||||
from rf2aa.data.dataloader_adaptor import prepare_input, get_loss_calc_items
|
||||
@@ -31,7 +34,8 @@ from rf2aa.set_seed import seed_all
|
||||
os.environ['OMP_NUM_THREADS'] = '4'
|
||||
os.environ['OPENBLAS_NUM_THREADS'] = '4'
|
||||
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "max_split_size_mb:512"
|
||||
|
||||
# Point REQUESTS_CA_BUNDLE at the CA certificate bundle shipped with certifi (needed for W&B upload)
|
||||
os.environ['REQUESTS_CA_BUNDLE'] = certifi.where()
|
||||
## To reproduce errors
|
||||
|
||||
torch.set_num_threads(4)
|
||||
@@ -48,7 +52,7 @@ class Trainer:
|
||||
self.output_dir = "models/"
|
||||
if not os.path.exists(self.output_dir):
|
||||
os.makedirs(self.output_dir)
|
||||
|
||||
|
||||
def construct_model(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -153,7 +157,6 @@ class Trainer:
|
||||
rank = int (os.environ["SLURM_PROCID"])
|
||||
print ("Launched from slurm", rank, world_size)
|
||||
self.train_model(rank, world_size)
|
||||
#mp.spawn(self.train_model, args=(world_size,), nprocs=world_size, join=True)
|
||||
|
||||
else:
|
||||
print ("Launched from interactive")
|
||||
@@ -174,7 +177,8 @@ class Trainer:
|
||||
return gpu
|
||||
|
||||
def cleanup(self):
|
||||
dist.destroy_process_group()
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
def train_model(self, rank, world_size):
|
||||
""" runs model training on each gpu """
|
||||
@@ -185,51 +189,67 @@ class Trainer:
|
||||
init = partial(initialize_chemdata,self.config.chem_params)
|
||||
init()
|
||||
|
||||
train_loader, train_sampler, valid_loaders, valid_samplers = self.construct_dataset(
|
||||
init, rank, world_size
|
||||
# Define context manager for training run (either nullcontext or W&B)
|
||||
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
context_manager = (
|
||||
wandb.init(
|
||||
project=self.config.log_params.wandb_project,
|
||||
config=omegaconf.OmegaConf.to_container(
|
||||
self.config, resolve=True, throw_on_missing=True
|
||||
),
|
||||
name = f"{self.config.experiment.name}_{timestamp}"
|
||||
)
|
||||
if self.config.log_params.use_wandb and rank == 0
|
||||
else nullcontext() # Does nothing
|
||||
)
|
||||
|
||||
self.train_loader = train_loader
|
||||
self.valid_loaders = valid_loaders
|
||||
# Without W&B, context manager does nothing
|
||||
with context_manager:
|
||||
train_loader, train_sampler, valid_loaders, valid_samplers = self.construct_dataset(
|
||||
init, rank, world_size
|
||||
)
|
||||
|
||||
# move global information to device
|
||||
self.move_constants_to_device(gpu)
|
||||
self.train_loader = train_loader
|
||||
self.valid_loaders = valid_loaders
|
||||
|
||||
self.construct_model(device=gpu)
|
||||
self.model = DDP(self.model, device_ids=[gpu], find_unused_parameters=False, broadcast_buffers=False)
|
||||
if rank == 0:
|
||||
print(f"Loading model with {count_parameters(self.model)} parameters")
|
||||
# move global information to device
|
||||
self.move_constants_to_device(gpu)
|
||||
|
||||
self.construct_optimizer()
|
||||
self.construct_scheduler()
|
||||
self.construct_scaler()
|
||||
if self.config.training_params.resume_train:
|
||||
self.load_checkpoint(gpu)
|
||||
self.load_model()
|
||||
try:
|
||||
self.load_optimizer()
|
||||
self.load_scheduler()
|
||||
self.load_scaler()
|
||||
except Exception as ex:
|
||||
print ('Error in loading optimizer parameters:',ex)
|
||||
print ('Continuing...')
|
||||
self.construct_model(device=gpu)
|
||||
self.model = DDP(self.model, device_ids=[gpu], find_unused_parameters=False, broadcast_buffers=False)
|
||||
if rank == 0:
|
||||
print(f"Loading model with {count_parameters(self.model)} parameters")
|
||||
|
||||
self.recycle_schedule = recycle_sampling["by_batch"](self.config.loader_params.maxcycle,
|
||||
self.config.experiment.n_epoch,
|
||||
self.config.dataset_params.n_train,
|
||||
world_size)
|
||||
#self.valid_epoch(-1, rank, world_size)
|
||||
for epoch in range(self.config.experiment.n_epoch):
|
||||
train_sampler.set_epoch(epoch) #TODO: need to make sure each gpu gets a different example
|
||||
self.train_epoch(epoch, rank, world_size)
|
||||
for _, valid_sampler in valid_samplers.items():
|
||||
valid_sampler.set_epoch(epoch)
|
||||
self.construct_optimizer()
|
||||
self.construct_scheduler()
|
||||
self.construct_scaler()
|
||||
if self.config.training_params.resume_train:
|
||||
self.load_checkpoint(gpu)
|
||||
self.load_model()
|
||||
try:
|
||||
self.load_optimizer()
|
||||
self.load_scheduler()
|
||||
self.load_scaler()
|
||||
except Exception as ex:
|
||||
print ('Error in loading optimizer parameters:',ex)
|
||||
print ('Continuing...')
|
||||
|
||||
if (
|
||||
self.config.dataset_params.validate_every_n_epochs > 0
|
||||
and epoch % self.config.dataset_params.validate_every_n_epochs==0
|
||||
):
|
||||
self.valid_epoch(epoch, rank, world_size)
|
||||
self.recycle_schedule = recycle_sampling["by_batch"](self.config.loader_params.maxcycle,
|
||||
self.config.experiment.n_epoch,
|
||||
self.config.dataset_params.n_train,
|
||||
world_size)
|
||||
|
||||
for epoch in range(self.config.experiment.n_epoch):
|
||||
train_sampler.set_epoch(epoch) #TODO: need to make sure each gpu gets a different example
|
||||
self.train_epoch(epoch, rank, world_size)
|
||||
for _, valid_sampler in valid_samplers.items():
|
||||
valid_sampler.set_epoch(epoch)
|
||||
|
||||
if (
|
||||
self.config.dataset_params.validate_every_n_epochs > 0
|
||||
and epoch % self.config.dataset_params.validate_every_n_epochs==0
|
||||
):
|
||||
self.valid_epoch(epoch, rank, world_size)
|
||||
|
||||
self.cleanup()
|
||||
|
||||
@@ -260,6 +280,10 @@ class Trainer:
|
||||
inputs, loss_dict, n_cycle,
|
||||
(train_idx+1)*world_size, len(self.train_loader)*world_size, train_time
|
||||
)
|
||||
|
||||
# If using W&B, log the intermediate losses (note: this is only done for rank = 0)
|
||||
if self.config.log_params.use_wandb:
|
||||
wandb.log(loss_dict)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if rank == 0:
|
||||
@@ -296,6 +320,9 @@ class Trainer:
|
||||
|
||||
if rank==0:
|
||||
self.log_validation_losses(dataset_name, valid_loss_dict)
|
||||
# If using W&B, log the validation losses (note: this is only done for rank = 0)
|
||||
if self.config.log_params.use_wandb:
|
||||
wandb.log(valid_loss_dict)
|
||||
|
||||
def train_step(self, inputs, n_cycle):
|
||||
""" take an input from dataloader, run the model and compute a loss """
|
||||
@@ -419,7 +446,18 @@ class ComposedTrainer(Trainer):
|
||||
def main(config):
|
||||
seed_all()
|
||||
trainer = trainer_factory[config.experiment.trainer](config=config)
|
||||
trainer.launch_distributed_training()
|
||||
|
||||
# Wrap the training in a try/except/finally block so that cleanup always runs, even after an interrupt or an exception (otherwise, we'd need to change the SLURM id each run)
|
||||
try:
|
||||
trainer.launch_distributed_training()
|
||||
except KeyboardInterrupt:
|
||||
print("Training interrupted by user.")
|
||||
except Exception as e:
|
||||
print("Training interrupted by exception:", e)
|
||||
raise e
|
||||
finally:
|
||||
print("Cleaning up...")
|
||||
trainer.cleanup()
|
||||
|
||||
trainer_factory = {
|
||||
"legacy": LegacyTrainer,
|
||||
|
||||
Reference in New Issue
Block a user