Files
Delete/delete_train.py
HaotianZhang 7652568ecc update
2024-11-25 09:59:13 -08:00

222 lines
9.0 KiB
Python

import os
import shutil
import argparse
from tqdm.auto import tqdm
import torch
from torch.nn.utils import clip_grad_norm_
from torch_geometric.loader import DataLoader
from models.delete import Delete
from utils.datasets import *
from utils.transforms import *
from utils.misc import *
from utils.train import *
from time import time
from utils.train import get_model_loss
from utils.datasets.pl import SurfLigandPairDataset
# ---------------------------------------------------------------------------
# Command-line arguments, run configuration and logging setup.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='./configs/train_frag_moad.yml')
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--logdir', type=str, default='./logs')
args = parser.parse_args()

config = load_config(args.config)
# Config file name without its extension, used as the log-directory prefix.
# splitext() leaves an extension-less name intact, whereas slicing at
# rfind('.') would drop the last character when there is no dot (rfind -> -1).
config_name = os.path.splitext(os.path.basename(args.config))[0]
seed_all(config.train.seed)

log_dir = get_new_log_dir(args.logdir, prefix=config_name)
ckpt_dir = os.path.join(log_dir, 'checkpoints')
os.makedirs(ckpt_dir, exist_ok=True)
logger = get_logger('train', log_dir)
logger.info(args)
logger.info(config)

# Keep a copy of the config file and of the ./models source tree next to the
# logs so the run can be inspected later. The path is relative to the current
# working directory.
shutil.copyfile(args.config, os.path.join(log_dir, os.path.basename(args.config)))
shutil.copytree('./models', os.path.join(log_dir, 'models'))
# ---------------------------------------------------------------------------
# Per-sample transform pipeline handed to the dataset.
# ---------------------------------------------------------------------------
protein_featurizer = FeaturizeProteinAtom()
ligand_featurizer = FeaturizeLigandAtom()
masking = get_mask(config.train.transform.mask)
composer = AtomComposer(
    protein_featurizer.feature_dim,
    ligand_featurizer.feature_dim,
    config.model.encoder.knn,
)
edge_sampler = EdgeSample(config.train.transform.edgesampler)

# Settings for the contrastive sampler live in their own config section.
cfg_ctr = config.train.transform.contrastive
contrastive_sampler = ContrastiveSample(
    cfg_ctr.num_real,
    cfg_ctr.num_fake,
    cfg_ctr.pos_real_std,
    cfg_ctr.pos_fake_std,
    config.model.field.knn,
)

# The transforms are composed in exactly this order.
transform = Compose([
    RefineData(),
    LigandCountNeighbors(),
    protein_featurizer,
    ligand_featurizer,
    masking,
    composer,
    FocalBuilder(),
    edge_sampler,
    contrastive_sampler,
])
def get_dataset(config, *args, **kwargs):
    """Build the dataset named in ``config`` and, if configured, its splits.

    Args:
        config: Dataset section of the run configuration. ``config.name``
            selects the dataset class (only ``'pl'`` is handled here),
            ``config.path`` is passed to it as the root, and an optional
            ``split`` entry is a path that is read with ``torch.load``.
        *args: Forwarded to the dataset constructor.
        **kwargs: Forwarded to the dataset constructor.

    Returns:
        ``(dataset, subsets)`` when ``'split' in config``, where ``subsets``
        maps every key of the loaded split file to a ``Subset`` of
        ``dataset``; otherwise just ``dataset``.

    Raises:
        NotImplementedError: If ``config.name`` is not a known dataset name.
    """
    name = config.name
    root = config.path
    if name != 'pl':
        raise NotImplementedError('Unknown dataset: %s' % name)
    dataset = SurfLigandPairDataset(root, *args, **kwargs)

    if 'split' not in config:
        return dataset

    # The split file maps split name -> list of entry names. Names that are
    # absent from dataset.name2id are skipped.
    split_by_name = torch.load(config.split)
    split = {
        k: [dataset.name2id[n] for n in names if n in dataset.name2id]
        for k, names in split_by_name.items()
    }
    subsets = {k: Subset(dataset, indices=v) for k, v in split.items()}
    return dataset, subsets
# ---------------------------------------------------------------------------
# Dataset, splits and data loaders.
# ---------------------------------------------------------------------------
dataset, subsets = get_dataset(config=config.dataset, transform=transform)
# The 'test' split is what this script uses for validation.
train_set, val_set = subsets['train'], subsets['test']

follow_batch = []
collate_exclude_keys = ['ligand_nbh_list']
val_loader = DataLoader(
    val_set,
    config.train.batch_size,
    shuffle=False,
    follow_batch=follow_batch,
    exclude_keys=collate_exclude_keys,
)
# NOTE(review): the training loader is not shuffled either; confirm that this
# is intentional.
train_loader = DataLoader(
    train_set,
    config.train.batch_size,
    shuffle=False,
    exclude_keys=collate_exclude_keys,
)
# ---------------------------------------------------------------------------
# Model, optimizer and learning-rate scheduler.
# ---------------------------------------------------------------------------
model = Delete(
    config.model,
    num_classes=contrastive_sampler.num_elements,
    num_bond_types=edge_sampler.num_bond_types,
    protein_atom_feature_dim=protein_featurizer.feature_dim,
    ligand_atom_feature_dim=ligand_featurizer.feature_dim,
).to(args.device)

# Report the parameter count in millions: the message is suffixed with "M",
# so the divisor has to be 1e6.
num_params = sum(p.numel() for p in model.parameters())
print('Num of parameters is {0:.4}M'.format(num_params / 1e6))

optimizer = get_optimizer(config.train.optimizer, model)
scheduler = get_scheduler(config.train.scheduler, optimizer)
def update_losses(eval_loss, loss, loss_frontier, loss_pos, loss_cls, loss_edge, loss_real, loss_fake, loss_surf):
    """Append one set of loss values to the per-component history lists.

    Args:
        eval_loss: Dict mapping the component names ``'total'``,
            ``'frontier'``, ``'pos'``, ``'cls'``, ``'edge'``, ``'real'``,
            ``'fake'`` and ``'surf'`` to lists. A missing key is created with
            an empty list, so an empty dict is accepted as well.
        loss: Value appended under ``'total'``.
        loss_frontier, loss_pos, loss_cls, loss_edge, loss_real, loss_fake,
        loss_surf: Values appended under the correspondingly named keys.

    Returns:
        The same ``eval_loss`` dict, mutated in place.
    """
    components = (
        ('total', loss),
        ('frontier', loss_frontier),
        ('pos', loss_pos),
        ('cls', loss_cls),
        ('edge', loss_edge),
        ('real', loss_real),
        ('fake', loss_fake),
        ('surf', loss_surf),
    )
    for key, value in components:
        eval_loss.setdefault(key, []).append(value)
    return eval_loss
def evaluate(epoch, verbose=1):
    """Run the model over ``val_loader`` and return the mean total loss.

    Uses the module-level ``model``, ``val_loader``, ``config``, ``args`` and
    ``logger``. The model is switched to eval mode and is left in that mode.

    Args:
        epoch: Epoch number, used only in the log message.
        verbose: When truthy, log the average loss together with the loss
            terms of the last validation batch.

    Returns:
        The mean of the per-batch total loss over the validation loader.

    Raises:
        ValueError: If the validation loader yields no batches.
    """
    model.eval()
    eval_losses = []
    # Nothing is back-propagated here, so skip building the autograd graph.
    # NOTE(review): assumes get_model_loss does not need autograd internally.
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(args.device)
            # Order: total, frontier, pos, cls, edge, real, fake, surf.
            losses = get_model_loss(model, batch, config)
            eval_losses.append(losses[0].item())

    if not eval_losses:
        raise ValueError('validation loader produced no batches')

    average_loss = sum(eval_losses) / len(eval_losses)
    if verbose:
        # The individual terms reported here come from the last batch only.
        last_batch_values = tuple(term.item() for term in losses)
        logger.info('Evaluate Epoch %d | Average_Loss %.5f | Single Batch Loss %.6f | Loss(Fron) %.6f | Loss(Pos) %.6f | Loss(Cls) %.6f | Loss(Edge) %.6f | Loss(Real) %.6f | Loss(Fake) %.6f | Loss(Surf) %.6f ' % (
            (epoch, average_loss) + last_batch_values
        ))
    return average_loss
def load(config, model, optimizer=False, scheduler=False):
    """Restore model (and optionally optimizer / scheduler) state from disk.

    The checkpoint is read from
    ``config.train.checkpoint_path / config.train.ckpt_name`` and is expected
    to contain the keys ``'model'`` and ``'best_loss'``, plus ``'scheduler'``
    and ``'optimizer'`` when those objects are passed in.

    Args:
        config: Run configuration; ``config.train`` supplies ``ckpt_name``,
            ``checkpoint_path`` and ``start_epoch``.
        model: Module whose ``load_state_dict`` receives ``ckpt['model']``.
        optimizer: Optimizer to restore, or a falsy value to skip it.
        scheduler: Scheduler to restore, or a falsy value to skip it.

    Returns:
        ``(model, best_loss, resume_epoch)``. ``resume_epoch`` is taken from
        ``config.train.start_epoch``, not from the checkpoint itself.
    """
    ckpt_name = config.train.ckpt_name
    resume_epoch = int(config.train.start_epoch)
    ckpt_path = os.path.join(config.train.checkpoint_path, ckpt_name)
    # map_location lets a checkpoint saved on one device be restored on
    # another (e.g. a GPU checkpoint on a CPU-only machine).
    ckpt = torch.load(ckpt_path, map_location=args.device)
    best_loss = float(ckpt['best_loss'])
    model.load_state_dict(ckpt["model"])
    if scheduler:
        scheduler.load_state_dict(ckpt["scheduler"])
    if optimizer:
        optimizer.load_state_dict(ckpt["optimizer"])
        # Make sure the optimizer state tensors live on the requested CUDA
        # device; matches 'cuda' as well as indexed forms such as 'cuda:1'.
        if str(args.device).startswith('cuda'):
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(args.device)
    return model, best_loss, resume_epoch
def _save_checkpoint(model, ckpt_path, epoch, best_loss):
    """Write model, optimizer and scheduler state plus bookkeeping to ``ckpt_path``."""
    torch.save({
        'config': config,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'epoch': epoch,
        'best_loss': best_loss,
    }, ckpt_path)


def train(model, verbose=1, num_epoches=300):
    """Train ``model`` for ``num_epoches`` epochs, validating after each one.

    Uses the module-level ``config``, ``args``, ``logger``, ``optimizer``,
    ``scheduler``, ``train_loader`` and ``ckpt_dir``. When
    ``config.train.resume_train`` is set, state is first restored with
    ``load`` and the returned start epoch offsets the epoch numbers used in
    log messages and checkpoint names.

    Checkpoints are written only when ``config.train.save`` is truthy:
    ``val_<epoch>.pt`` whenever the validation loss improves on the best seen
    so far; otherwise ``train_<epoch>.pt`` if more than 20 epochs have been
    recorded and the training loss dropped relative to the previous epoch.

    Args:
        model: The model to optimise.
        verbose: When truthy, log a summary line after every epoch.
        num_epoches: Number of epochs to run in this call.
    """
    train_losses = []
    val_losses = []
    start_epoch = 0
    # Start from +inf so that the first validation result always counts as
    # an improvement, whatever the scale of the loss.
    best_loss = float('inf')
    if config.train.resume_train:
        ckpt_name = config.train.ckpt_name
        model, best_loss, start_epoch = load(config, model, optimizer, scheduler)
        logger.info('load pretrained model from {}'.format(ckpt_name))
    logger.info('start training...')

    for epoch in range(num_epoches):
        global_epoch = epoch + start_epoch
        model.train()
        batch_losses = []
        for batch_cnt, batch in enumerate(train_loader, start=1):
            batch = batch.to(args.device)
            # Order: total, frontier, pos, cls, edge, real, fake, surf.
            losses = get_model_loss(model, batch, config)
            loss = losses[0]
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
            # Log the first ten steps of the first epoch of this call.
            if epoch == 0 and batch_cnt <= 10:
                logger.info('Training Epoch %d | Step %d | Loss %.6f | Loss(Fron) %.6f | Loss(Pos) %.6f | Loss(Cls) %.6f | Loss(Edge) %.6f | Loss(Real) %.6f | Loss(Fake) %.6f | Loss(Surf) %.6f ' % (
                    (global_epoch, batch_cnt) + tuple(term.item() for term in losses)
                ))

        # Mean over the batches actually seen; max() only guards against
        # dividing by zero.
        average_loss = sum(batch_losses) / max(len(batch_losses), 1)
        train_losses.append(average_loss)
        if verbose:
            # The individual terms reported here come from the last batch.
            logger.info('Training Epoch %d | Average_Loss %.5f | Loss %.6f | Loss(Fron) %.6f | Loss(Pos) %.6f | Loss(Cls) %.6f | Loss(Edge) %.6f | Loss(Real) %.6f | Loss(Fake) %.6f | Loss(Surf) %.6f ' % (
                (global_epoch, average_loss) + tuple(term.item() for term in losses)
            ))

        average_eval_loss = evaluate(global_epoch, verbose=1)
        val_losses.append(average_eval_loss)

        # A "plateau" scheduler is stepped with the monitored metric.
        if config.train.scheduler.type == "plateau":
            scheduler.step(average_eval_loss)
        else:
            scheduler.step()

        if val_losses[-1] < best_loss:
            best_loss = val_losses[-1]
            if config.train.save:
                ckpt_path = os.path.join(ckpt_dir, 'val_%d.pt' % int(global_epoch))
                _save_checkpoint(model, ckpt_path, global_epoch, best_loss)
        elif len(train_losses) > 20 and train_losses[-1] < train_losses[-2]:
            if config.train.save:
                ckpt_path = os.path.join(ckpt_dir, 'train_%d.pt' % int(global_epoch))
                _save_checkpoint(model, ckpt_path, global_epoch, best_loss)

        torch.cuda.empty_cache()
# Script entry point: start training with train()'s default arguments.
train(model)