Files
foundry/RF2_allatom/train_multi_EMA.py
2022-08-03 16:44:44 -07:00

1607 lines
72 KiB
Python

import sys, os
import time
import numpy as np
from copy import deepcopy
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.utils import data
from functools import partial
from data_loader import (
get_train_valid_set, loader_pdb, loader_fb, loader_complex, loader_na_complex, loader_rna, loader_sm_compl,
Dataset, DatasetComplex, DatasetNAComplex, DatasetRNA, DatasetSMComplex, DistilledDataset, DistributedWeightedSampler
)
from kinematics import xyz_to_c6d, c6d_to_bins, xyz_to_t2d, xyz_to_bbtor, get_init_xyz
from RoseTTAFoldModel import RoseTTAFoldModule
from loss import *
from util import *
from util_module import ComputeAllAtomCoords
from scheduler import get_linear_schedule_with_warmup, get_stepwise_decay_schedule_with_warmup
# distributed data parallel
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
# ---- Global run configuration (executed once at import time) ----
# Turn on autograd anomaly detection for every backward pass.
# NOTE(review): PyTorch documents this as a debugging aid that slows down
# training — confirm it is meant to stay enabled for production runs.
torch.autograd.set_detect_anomaly(True)
# Fixed torch seed plus deterministic (non-benchmarking) cuDNN kernels.
torch.manual_seed(5924)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
## To reproduce errors
#import random
np.random.seed(6636)
#random.seed(0)
# Passed as `enabled=` to the GradScaler built in Trainer.train_model.
USE_AMP = False
# Cap on intra-op CPU threads used by torch in this process.
torch.set_num_threads(4)
# NOTE(review): not referenced in the visible part of the file; presumably the
# logging interval (in steps) of the training loop — confirm against train_cycle.
N_PRINT_TRAIN = 16
#BATCH_SIZE = 1 * torch.cuda.device_count()
# num structs per epoch
# must be divisible by #GPUs
N_EXAMPLE_PER_EPOCH = 1208
# Keyword arguments shared by every DataLoader built in Trainer.train_model.
# shuffle stays False because each loader is given an explicit sampler.
LOAD_PARAM = {'shuffle': False,
'num_workers': 3,
'pin_memory': True}
def add_weight_decay(model, l2_coeff):
    """Build optimizer parameter groups with selective weight decay.

    Trainable parameters whose name contains "norm" or ends with ".bias" go
    into a group with ``weight_decay`` 0.0; every other trainable parameter
    goes into a group decayed by ``l2_coeff``. Parameters that do not require
    gradients are left out entirely.

    Returns:
        A two-element list ``[no_decay_group, decay_group]`` of dicts with
        keys ``'params'`` and ``'weight_decay'``.
    """
    exempt = []
    regularized = []
    for pname, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue
        skip_decay = pname.endswith(".bias") or "norm" in pname
        bucket = exempt if skip_decay else regularized
        bucket.append(tensor)
    return [
        {'params': exempt, 'weight_decay': 0.0},
        {'params': regularized, 'weight_decay': l2_coeff},
    ]
def count_parameters(model):
    """Return the total element count of all parameters that require gradients."""
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total
class EMA(nn.Module):
    """Wrap a model together with an exponential-moving-average copy of it.

    ``self.model`` is the live model; ``self.shadow`` is a deep copy whose
    parameters are detached from autograd and are only changed by
    :meth:`update`. ``forward`` runs the live model in training mode and the
    shadow copy otherwise.
    """

    def __init__(self, model, decay):
        """Store ``model``, the EMA ``decay`` factor, and a detached deep copy."""
        super().__init__()
        self.decay = decay

        self.model = model
        self.shadow = deepcopy(self.model)

        # The shadow copy is never trained directly; cut it out of autograd.
        for param in self.shadow.parameters():
            param.detach_()

    @torch.no_grad()
    def update(self):
        """Move the shadow parameters one EMA step towards the live model.

        Does nothing except print a warning to stderr when the wrapper is not
        in training mode. Buffers are copied verbatim rather than averaged.
        """
        if not self.training:
            # Fixed: the original wrote to the undefined name `stderr`, which
            # raised NameError instead of emitting this warning.
            print("EMA update should only be called during training", file=sys.stderr, flush=True)
            return

        model_params = OrderedDict(self.model.named_parameters())
        shadow_params = OrderedDict(self.shadow.named_parameters())

        # check if both model contains the same set of keys
        assert model_params.keys() == shadow_params.keys()

        for name, param in model_params.items():
            # see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
            # shadow_variable -= (1 - decay) * (shadow_variable - variable)
            if param.requires_grad:
                shadow_params[name].sub_((1. - self.decay) * (shadow_params[name] - param))

        model_buffers = OrderedDict(self.model.named_buffers())
        shadow_buffers = OrderedDict(self.shadow.named_buffers())

        # check if both model contains the same set of keys
        assert model_buffers.keys() == shadow_buffers.keys()

        for name, buffer in model_buffers.items():
            # buffers are copied
            shadow_buffers[name].copy_(buffer)

    def forward(self, *args, **kwargs):
        """Dispatch to the live model when training, to the shadow copy otherwise."""
        if self.training:
            return self.model(*args, **kwargs)
        else:
            return self.shadow(*args, **kwargs)
class Trainer():
def __init__(self, model_name='BFF',
             n_epoch=100, step_lr=100, lr=1.0e-4, l2_coeff=1.0e-2, port=None, interactive=False,
             model_param=None, loader_param=None, loss_param=None, batch_size=1, accum_step=1, maxcycle=4, eval=False):
    """Store the training configuration and the constants used by the losses.

    Args:
        model_name: prefix used for checkpoint file names.
        n_epoch: last epoch index (exclusive) of the training loop.
        step_lr: stored as ``self.step_lr``.
        lr: initial learning rate handed to the optimizer.
        l2_coeff: weight-decay coefficient for the decayed parameter group.
        port: value written to ``MASTER_PORT`` when that variable is unset.
        interactive: when true, processes are always spawned locally even if
            SLURM variables are present.
        model_param: keyword arguments for the model constructor.
        loader_param: parameters handed to the data-loading helpers.
        loss_param: loss settings (``'lj_lin'`` is read when the model is built).
        batch_size: batch size of the training DataLoader.
        accum_step: stored as ``self.ACCUM_STEP``.
        maxcycle: stored as ``self.maxcycle``.
        eval: when true, ``train_model`` only runs validation and returns.

    The three ``*_param`` arguments default to ``None`` (meaning "empty dict")
    rather than a shared ``{}`` literal, so separate Trainer instances never
    alias one another's configuration. Dicts passed in are stored by
    reference, as before.
    """
    self.model_name = model_name #"BFF"
    #self.model_name = "%s_%d_%d_%d_%d"%(model_name, model_param['n_module'],
    #                                    model_param['n_module_str'],
    #                                    model_param['d_msa'],
    #                                    model_param['d_pair'])
    #
    self.n_epoch = n_epoch
    self.step_lr = step_lr
    self.init_lr = lr
    self.l2_coeff = l2_coeff
    self.port = port
    self.interactive = interactive
    self.eval = eval
    #
    self.model_param = {} if model_param is None else model_param
    self.loader_param = {} if loader_param is None else loader_param
    self.loss_param = {} if loss_param is None else loss_param
    self.ACCUM_STEP = accum_step
    self.batch_size = batch_size

    # for all-atom str loss
    # (module-level tables; they are not defined in this file's visible part
    # and presumably come from the star imports — moved to the GPU in train_model)
    self.ti_dev = torsion_indices
    self.ti_flip = torsion_can_flip
    self.ang_ref = reference_angles
    self.fi_dev = frame_indices
    self.l2a = long2alt
    self.aamask = allatom_mask
    self.num_bonds = num_bonds
    self.atom_type_index = atom_type_index
    self.ljlk_parameters = ljlk_parameters
    self.lj_correction_parameters = lj_correction_parameters
    self.hbtypes = hbtypes
    self.hbbaseatoms = hbbaseatoms
    self.hbpolys = hbpolys
    self.cb_len = cb_length_t
    self.cb_ang = cb_angle_t
    self.cb_tor = cb_torsion_t

    # module torsion -> allatom
    self.compute_allatom_coords = ComputeAllAtomCoords()

    # loss & final activation function
    self.loss_fn = nn.CrossEntropyLoss(reduction='none')
    self.active_fn = nn.Softmax(dim=1)

    self.maxcycle = maxcycle
    self.pdb_counter=0
def calc_loss(self, logit_s, label_s,
              logit_aa_s, label_aa_s, mask_aa_s,
              pred, pred_tors, pred_allatom, true,
              mask_crds, mask_BB, mask_2d, same_chain,
              pred_lddt, idx, atom_frames=None, unclamp=False, negative=False, interface=False,
              verbose=False, ctr=0,
              w_dist=1.0, w_aa=1.0, w_str=1.0, w_lddt=1.0, w_bond=1.0, w_clash=0.0, w_hb=0.0, w_dih=0.0,
              lj_lin=0.85, eps=1e-6
              ):
    """Compute the weighted total training loss and a vector of its components.

    The total sums, in this order: four masked 2D cross-entropy terms
    (``w_dist``), the masked-token cross-entropy (``w_aa``), a backbone-frame
    FAPE (``0.5*w_str``), the torsion-angle loss (``w_str``), the lDDT
    prediction loss (``w_lddt``), an all-frame FAPE against the
    symmetry-resolved native (``w_str``), bond-geometry terms (``w_bond``)
    and the LJ clash term (``w_clash``). ``w_hb`` and ``w_dih`` are accepted
    but unused because the hydrogen-bond block is commented out.

    Args:
        unclamp: use a FAPE distance clamp of 300.0 instead of 30.0 for the
            backbone-frame FAPE.
        negative: evaluate both FAPE terms per chain (split at
            ``same_chain[0,0,:].sum()``) and mix them by chain length, so
            inter-chain errors are ignored.
        interface: forwarded to the lDDT helpers.
        verbose: print a summary line and write three PDB files tagged with
            ``ctr`` into the working directory.

    Returns:
        ``(tot_loss, loss_s)`` where ``loss_s`` is the concatenation (dim 0)
        of the detached individual terms, including diagnostic lDDT values
        that do not contribute to ``tot_loss``.

    The code asserts a batch size of 1.
    """
    B, L = true.shape[:2]
    seq = label_aa_s[:,0].clone()

    assert (B==1) # fd - code assumes a batch size of 1

    loss_s = list()
    tot_loss = 0.0

    # c6d loss
    for i in range(4):
        loss = self.loss_fn(logit_s[i], label_s[...,i]) # (B, L, L)
        loss = (mask_2d*loss).sum() / (mask_2d.sum() + eps)
        tot_loss += w_dist*loss
        loss_s.append(loss[None].detach())

    # masked token prediction loss
    loss = self.loss_fn(logit_aa_s, label_aa_s.reshape(B, -1))
    loss = loss * mask_aa_s.reshape(B, -1)
    loss = loss.sum() / (mask_aa_s.sum() + 1e-8)
    tot_loss += w_aa*loss
    loss_s.append(loss[None].detach())

    ### GENERAL LAYERS
    # Structural loss
    dclamp = 300.0 if unclamp else 30.0

    frames, frame_mask = get_frames(
        pred_allatom[-1,None,...], mask_crds, seq, self.fi_dev, atom_frames)
    # keep only the first frame of each residue for the backbone FAPE
    frame_mask_BB = frame_mask.clone()
    frame_mask_BB[...,1:] =False

    if negative: # inter-chain fapes should be ignored for negative cases
        L1 = same_chain[0,0,:].sum()
        mask_BBA = mask_BB.clone()
        mask_BBA[0, L1:] = False
        l_fape_A = compute_general_FAPE(
            pred_allatom[:,mask_BBA[0],:,:3],
            true[:,mask_BBA[0],:,:3],
            mask_crds[:,mask_BBA[0]],
            frames[:,mask_BBA[0]],
            frame_mask_BB[:,mask_BBA[0]],
            dclamp=dclamp
        )

        mask_BBB = mask_BB.clone()
        mask_BBB[0,:L1] = False
        l_fape_B = compute_general_FAPE(
            pred_allatom[:,mask_BBB[0],:,:3],
            true[:,mask_BBB[0],:,:3],
            mask_crds[:,mask_BBB[0]],
            frames[:,mask_BBB[0]],
            frame_mask_BB[:,mask_BBB[0]],
            dclamp=dclamp
        )
        # length-weighted mix of the two per-chain terms
        fracA = float(L1)/len(same_chain[0,0])
        tot_str = fracA*l_fape_A + (1.0-fracA)*l_fape_B
    else:
        tot_str = compute_general_FAPE(
            pred_allatom[:,mask_BB[0],:,:3],
            true[:,mask_BB[0],:,:3],
            mask_crds[:,mask_BB[0]],
            frames[:,mask_BB[0]],
            frame_mask_BB[:,mask_BB[0]],
            dclamp=dclamp
        )

    tot_loss += 0.5*w_str*tot_str[0]
    loss_s.append(tot_str.detach())

    # AllAtom loss
    # get ground-truth torsion angles
    true_tors, true_tors_alt, tors_mask, tors_planar = get_torsions(
        true, seq, self.ti_dev, self.ti_flip, self.ang_ref, mask_in=mask_crds)
    tors_mask *= mask_BB[...,None]

    # get alternative coordinates for ground-truth
    true_alt = torch.zeros_like(true)
    true_alt.scatter_(2, self.l2a[seq,:,None].repeat(1,1,1,3), true)
    # (a leftover debug `print(true_alt)` that dumped this tensor on every
    # call was removed here)

    natRs_all, _n0 = self.compute_allatom_coords(seq, true[...,:3,:], true_tors)
    natRs_all_alt, _n1 = self.compute_allatom_coords(seq, true_alt[...,:3,:], true_tors_alt)
    predTs = pred[-1,...]
    predRs_all, pred_all = self.compute_allatom_coords(seq, predTs, pred_tors[-1])

    # - resolve symmetry
    xs_mask = self.aamask[seq] # (B, L, 27)
    xs_mask[0,:,14:]=False # (ignore hydrogens except lj loss)
    xs_mask *= mask_crds # mask missing atoms & residues as well
    natRs_all_symm, nat_symm = resolve_symmetry(pred_allatom[-1], natRs_all[0], true[0], natRs_all_alt[0], true_alt[0], xs_mask[0])

    # torsion angle loss
    l_tors = torsionAngleLoss(
        pred_tors,
        true_tors,
        true_tors_alt,
        tors_mask,
        tors_planar,
        eps = 1e-10)
    tot_loss += w_str*l_tors
    loss_s.append(l_tors[None].detach())

    ### FINETUNING LAYERS
    # lddts (CA)
    ca_lddt = calc_lddt(pred[:,:,:,1].detach(), true[:,:,1], mask_BB, mask_2d, same_chain, negative=negative, interface=interface)
    loss_s.append(ca_lddt.detach())

    # lddts (allatom) + lddt loss
    lddt_loss, allatom_lddt = calc_allatom_lddt_loss(
        pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, same_chain, negative=negative, interface=interface)
    tot_loss += w_lddt*lddt_loss
    loss_s.append(lddt_loss.detach()[None])
    loss_s.append(allatom_lddt.detach())

    # FAPE losses
    # allatom fape and torsion angle loss
    # frames, frame_mask = get_frames(
    #     pred_allatom[-1,None,...], mask_crds, seq, self.fi_dev, atom_frames)
    if negative: # inter-chain fapes should be ignored for negative cases
        # mask_BBA / mask_BBB / L1 are reused from the backbone FAPE branch above
        # L1 = same_chain[0,0,:].sum()
        # mask_BBA = mask_BB.clone()
        # mask_BBA[0, L1:] = False
        l_fape_A = compute_general_FAPE(
            pred_allatom[:,mask_BBA[0],:,:3],
            nat_symm[None,mask_BBA[0],:,:3],
            xs_mask[:,mask_BBA[0]],
            frames[:,mask_BBA[0]],
            frame_mask[:,mask_BBA[0]]
        )

        # mask_BBB = mask_BB.clone()
        # mask_BBB[0,:L1] = False
        l_fape_B = compute_general_FAPE(
            pred_allatom[:,mask_BBB[0],:,:3],
            nat_symm[None,mask_BBB[0],:,:3],
            xs_mask[:,mask_BBB[0]],
            frames[:,mask_BBB[0]],
            frame_mask[:,mask_BBB[0]]
        )
        fracA = float(L1)/len(same_chain[0,0])
        l_fape = fracA*l_fape_A + (1.0-fracA)*l_fape_B
    else:
        l_fape = compute_general_FAPE(
            pred_allatom[:,mask_BB[0],:,:3],
            nat_symm[None,mask_BB[0],:,:3],
            xs_mask[:,mask_BB[0]],
            frames[:,mask_BB[0]],
            frame_mask[:,mask_BB[0]]
        )

    loss_s.append(l_fape.detach())
    tot_loss += w_str*l_fape.mean()

    # cart bonded (bond geometry)
    # first prediction: backbone bond geometry only
    bond_loss = calc_BB_bond_geom(seq[0], pred_allatom[0:1], idx)
    if w_bond > 0.0:
        tot_loss += w_bond*bond_loss
    loss_s.append( bond_loss[None].detach() )

    # remaining predictions (if any): full cartesian bonded term
    if (pred_allatom.shape[0] > 1):
        bond_loss = calc_cart_bonded(seq, pred_allatom[1:], idx, self.cb_len, self.cb_ang, self.cb_tor)
        if w_bond > 0.0:
            tot_loss += w_bond*bond_loss.mean()
        loss_s.append( bond_loss.detach() )

    # clash [use all atoms not just those in native]
    clash_loss = calc_lj(
        seq[0], pred_allatom,
        self.aamask, self.ljlk_parameters, self.lj_correction_parameters, self.num_bonds,
        lj_lin=lj_lin
    )
    if w_clash > 0.0:
        tot_loss += w_clash*clash_loss.mean()
    loss_s.append( clash_loss.detach() )

    # diagnostics only: all-atom lDDT restricted to the first chain, the
    # second chain, and the interface (none of these are added to tot_loss)
    L0 = same_chain[0,0,:].sum()
    chain1 = torch.zeros_like(same_chain, dtype=bool)
    chain1[:,:L0,:L0] = True
    _, allatom_lddt_c1 = calc_allatom_lddt_loss(
        pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, chain1, negative=True)
    loss_s.append(allatom_lddt_c1.detach())

    chain2 = torch.zeros_like(same_chain, dtype=bool)
    chain2[:,L0:,L0:] = True
    _, allatom_lddt_c2 = calc_allatom_lddt_loss(
        pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, chain2, negative=True, bin_scaling=0.5)
    loss_s.append(allatom_lddt_c2.detach())

    _, allatom_lddt_inter = calc_allatom_lddt_loss(
        pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, same_chain, interface=True)
    loss_s.append(allatom_lddt_inter.detach())

    # hbond [use all atoms not just those in native]
    #hb_loss = calc_hb(
    #    seq[0], pred_all[0,...,:3],
    #    self.aamask, self.hbtypes, self.hbbaseatoms, self.hbpolys,
    #    normalize=(not verbose)
    #)
    #if w_hb > 0.0:
    #    tot_loss += w_hb*hb_loss
    #loss_s.append(torch.stack((hb_loss, clash_loss, bond_loss)).detach())

    if (verbose):
        print (
            ctr,
            allatom_lddt.cpu().detach().numpy(),
            l_fape.cpu().detach().numpy(),
            mask_BB[0].sum()
        )
        # prediction, native, and native rebuilt from its own torsions
        writepdb("p_"+self.model_name+"_"+str(ctr)+".pdb", pred_all[-1,mask_BB[0]][:,:23], seq[mask_BB][:])
        writepdb("n_"+str(ctr)+".pdb", true[mask_BB][:,:23], seq[mask_BB][:])
        writepdb("nre_"+str(ctr)+".pdb", _n0[mask_BB], seq[mask_BB][:])

    return tot_loss, torch.cat(loss_s, dim=0)
def calc_acc(self, prob, dist, idx_pdb, mask_2d, return_cnt=False):
    """Precision, recall and F1 of predicted contacts for well-separated pairs.

    Only upper-triangle pairs with ``|idx_i - idx_j| + 1 > 24`` that are set
    in ``mask_2d`` are scored. A reference contact is ``dist < 20``; the
    contact score is the sum of the first 20 channels of ``prob``. Per batch
    element, the pairs scoring strictly above the smallest of the top-L
    scores are called predicted contacts.

    Returns:
        A tensor ``[precision, recall, F1]``; with ``return_cnt`` true, also
        the masked contact scores and the masked reference contacts.
    """
    B, L = idx_pdb.shape[0], idx_pdb.shape[1] # (B, L)

    # pair selection mask: sequence separation, upper triangle, valid pairs
    offsets = idx_pdb[:,:,None] - idx_pdb[:,None,:]
    pair_mask = torch.triu((torch.abs(offsets) + 1 > 24).float())
    pair_mask *= mask_2d

    contact_true = (dist < 20).float() * pair_mask
    contact_score = prob[:,:20,:,:].sum(dim=1) * pair_mask

    # threshold = lowest score among the L best pairs of each batch element
    threshold = torch.topk(contact_score.view(B,-1), L).values.min(dim=-1).values
    contact_pred = (contact_score > threshold.view(B, 1, 1)).float() * pair_mask

    one_true = torch.ones_like(contact_true)
    one_pred = torch.ones_like(contact_pred)
    hits = torch.logical_and(contact_pred == contact_true, contact_true == one_true)
    n_good = hits.float().sum()
    n_total = (contact_true == one_true).float().sum() + 1e-9
    n_total_pred = (contact_pred == one_pred).float().sum() + 1e-9

    prec = n_good / n_total_pred
    recall = n_good / n_total
    F1 = 2.0*prec*recall / (prec+recall+1e-9)

    scores = torch.stack([prec, recall, F1])
    if return_cnt:
        return scores, contact_score, contact_true
    return scores
def load_model(self, model, optimizer, scheduler, scaler, model_name, rank, suffix='last', resume_train=False):
    """Load ``models/<model_name>_<suffix>.pt`` into both the live and shadow model.

    Only checkpoint tensors whose shape matches the current model are
    loaded, apart from one special case that remaps a leading dimension of
    30 onto 32. When any parameter is missing or has an unexpected shape, the
    optimizer, scaler and scheduler state are NOT restored, even if
    ``resume_train`` is true.

    Args:
        model: DDP-wrapped EMA module (accessed via ``model.module.model``
            and ``model.module.shadow``).
        rank: device index that tensors saved on ``cuda:0`` are mapped to.

    Returns:
        ``(loaded_epoch, best_valid_loss)``; ``loaded_epoch`` is -1 when no
        checkpoint exists or training state was not restored, and
        ``best_valid_loss`` is always 999999.9.
    """
    best_valid_loss = 999999.9
    chk_fn = "models/%s_%s.pt"%(model_name, suffix)
    if not os.path.exists(chk_fn):
        print ('no model found', model_name)
        return -1, best_valid_loss

    print ('loading model', model_name)
    checkpoint = torch.load(chk_fn, map_location={"cuda:%d"%0: "cuda:%d"%rank})

    saved_sd = checkpoint['model_state_dict']
    current_sd = model.module.model.state_dict()

    incompatible = False  # set when any key is missing or mis-shaped
    usable = {}
    for key in current_sd:
        if key not in saved_sd:
            print ('missing',key)
            incompatible = True
            continue

        saved, current = saved_sd[key], current_sd[key]
        if current.shape == saved.shape:
            usable[key] = saved
            continue

        # fd hack for new encoding
        if (saved.shape[0]==30 and current.shape[0]==32 and 'compute_allatom_coords' not in key):
            print ('Fixing',key)
            remapped = torch.zeros_like(current)
            remapped[:26] = saved[:26]
            remapped[27:31] = saved[26:30]
            usable[key] = remapped
            continue

        #wrong size latent_emb.emb.weight torch.Size([256, 64]) torch.Size([256, 68])
        #wrong size templ_emb.emb.weight torch.Size([64, 104]) torch.Size([64, 108])
        #wrong size full_emb.emb.weight torch.Size([64, 33]) torch.Size([64, 35])
        print (
            'wrong size',key,
            saved.shape,
            current.shape )
        incompatible = True

    model.module.model.load_state_dict(usable, strict=False)
    model.module.shadow.load_state_dict(usable, strict=False)

    loaded_epoch = -1
    if resume_train and not incompatible:
        print (' ... loading optimization params')
        loaded_epoch = checkpoint['epoch']

        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
        if 'scheduler_state_dict' not in checkpoint:
            scheduler.last_epoch = loaded_epoch + 1
        else:
            print (' ... loading scheduler params')
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        #if 'best_loss' in checkpoint:
        #    best_valid_loss = checkpoint['best_loss']

    return loaded_epoch, best_valid_loss
def checkpoint_fn(self, model_name, description):
    """Return ``models/<model_name>_<description>.pt``, creating ``models/`` if needed.

    The directory is created with ``exist_ok=True`` so that two processes
    asking for a checkpoint path at the same time cannot fail between an
    existence check and the directory creation.
    """
    os.makedirs("models", exist_ok=True)
    name = "%s_%s.pt"%(model_name, description)
    return os.path.join("models", name)
# main entry function of training
# 1) make sure ddp env vars set
# 2) figure out if we launched using slurm or interactively
# - if slurm, assume 1 job launched per GPU
# - if interactive, launch one job for each GPU on node
def run_model_training(self, world_size):
    """Set the DDP rendezvous variables and start ``train_model``.

    ``MASTER_ADDR`` and ``MASTER_PORT`` are only filled in when unset (from
    ``127.0.0.1`` and ``self.port``). Under SLURM, unless ``self.interactive``
    is set, this process runs ``train_model`` directly with rank and world
    size read from ``SLURM_PROCID`` / ``SLURM_NTASKS``. Otherwise one process
    per visible CUDA device is spawned. The ``world_size`` argument is
    overridden in both cases.
    """
    env = os.environ
    if 'MASTER_ADDR' not in env:
        env['MASTER_ADDR'] = '127.0.0.1' # multinode requires this set in submit script
    if 'MASTER_PORT' not in env:
        env['MASTER_PORT'] = '%d'%self.port

    under_slurm = (not self.interactive) and ("SLURM_NTASKS" in env) and ("SLURM_PROCID" in env)
    if not under_slurm:
        print ("Launched from interactive")
        n_proc = torch.cuda.device_count()
        mp.spawn(self.train_model, args=(n_proc,), nprocs=n_proc, join=True)
        return

    n_proc = int(env["SLURM_NTASKS"])
    proc_id = int (env["SLURM_PROCID"])
    print ("Launched from slurm", proc_id, n_proc)
    self.train_model(proc_id, n_proc)
def train_model(self, rank, world_size):
    """Per-process training entry point: build data, model and optimizer, then train.

    Steps: join the NCCL process group; build the training set and the
    monomer / small-molecule validation sets with their distributed samplers
    and loaders; move the loss constant tables to this process's GPU; wrap
    the model in EMA and DDP; try to resume from the "best" checkpoint; then
    either run validation only (``self.eval``) or loop over epochs, saving a
    checkpoint from rank 0 after every epoch.

    Args:
        rank: global rank of this process; the GPU index is
            ``rank % torch.cuda.device_count()``.
        world_size: total number of processes in the group.

    Every exit path tears the process group down with
    ``dist.destroy_process_group()``.
    """
    #print ("running ddp on rank %d, world_size %d"%(rank, world_size))
    gpu = rank % torch.cuda.device_count()
    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
    torch.cuda.set_device("cuda:%d"%gpu)

    #define dataset & data loader
    (
        pdb_items, fb_items, compl_items, neg_items, na_compl_items, na_neg_items, rna_items,
        sm_compl_items, valid_pdb, valid_homo, valid_compl, valid_neg, valid_na_compl,
        valid_na_neg, valid_rna, valid_sm_compl, homo
    ) = get_train_valid_set(self.loader_param)
    pdb_IDs, pdb_weights, pdb_dict = pdb_items
    fb_IDs, fb_weights, fb_dict = fb_items
    compl_IDs, compl_weights, compl_dict = compl_items
    neg_IDs, neg_weights, neg_dict = neg_items
    na_compl_IDs, na_compl_weights, na_compl_dict = na_compl_items
    na_neg_IDs, na_neg_weights, na_neg_dict = na_neg_items
    rna_IDs, rna_weights, rna_dict = rna_items
    sm_compl_IDs, sm_compl_weights, sm_compl_dict = sm_compl_items

    self.n_train = N_EXAMPLE_PER_EPOCH
    self.n_valid_pdb = len(valid_pdb.keys())
    #self.n_valid_pdb = (self.n_valid_pdb // world_size)*world_size
    self.n_valid_homo = len(valid_homo.keys())
    #self.n_valid_homo = (self.n_valid_homo // world_size)*world_size
    self.n_valid_compl = len(valid_compl.keys())
    #self.n_valid_compl = (self.n_valid_compl // world_size)*world_size
    self.n_valid_neg = len(valid_neg.keys())
    #self.n_valid_neg = (self.n_valid_neg // world_size)*world_size
    self.n_valid_na_compl = len(valid_na_compl.keys())
    #self.n_valid_na_compl = (self.n_valid_na_compl // world_size)*world_size
    self.n_valid_na_neg = len(valid_na_neg.keys())
    #self.n_valid_na_neg = (self.n_valid_na_neg // world_size)*world_size
    self.n_valid_rna = len(valid_rna.keys())
    #self.n_valid_rna = (self.n_valid_rna // world_size)*world_size
    self.n_valid_sm_compl = len(valid_sm_compl.keys())

    # hard cap on the number of monomer validation examples actually used
    self.n_valid_pdb = 200
    #self.n_valid_homo = 4
    #self.n_valid_compl = 4
    #self.n_valid_neg = 4
    #self.n_valid_na_compl = 4
    #self.n_valid_na_neg = 4
    #self.n_valid_rna = 4

    if (rank==0):
        print ('Loaded (training)',
            len(pdb_IDs),'monomers/homomers,',
            len(fb_IDs),'distilled monomers,',
            len(compl_IDs),'heteromers,',
            len(neg_IDs),'negative heteromers,',
            len(na_compl_IDs),'nucleic-acid complexes,',
            len(na_neg_IDs),'negative nucleic-acid complexes,',
            len(rna_IDs),'RNA structures, and',
            len(sm_compl_IDs), 'small molecule complexes'
        )
        print ('Loaded (valid)',
            len(valid_pdb.keys()),'monomers,',
            len(valid_homo.keys()),'homomers,',
            len(valid_compl.keys()),'heteromers,',
            len(valid_neg.keys()),'negative heteromers,',
            len(valid_na_compl.keys()),'nucleic-acid complexes,',
            len(valid_na_neg.keys()),'negative nucleic-acid complexes,',
            len(valid_rna),'RNA structures, and',
            len(valid_sm_compl), 'small molecule complexes'
        )
        print ('Using',
            self.n_valid_pdb,'monomers,',
            self.n_valid_homo,'homomers,',
            self.n_valid_compl,'heteromers,',
            self.n_valid_neg,'negative heteromers',
            self.n_valid_na_compl,'nucleic-acid complexes,',
            self.n_valid_na_neg,'negative nucleic-acid complexes,',
            self.n_valid_rna,'RNA structures, and',
            self.n_valid_sm_compl, 'small molecule complexes'
        )

    train_set = DistilledDataset(
        pdb_IDs, loader_pdb, pdb_dict,
        compl_IDs, loader_complex, compl_dict,
        neg_IDs, loader_complex, neg_dict,
        na_compl_IDs, loader_na_complex, na_compl_dict,
        na_neg_IDs, loader_na_complex, na_neg_dict,
        fb_IDs, loader_fb, fb_dict,
        rna_IDs, loader_rna, rna_dict,
        sm_compl_IDs, loader_sm_compl, sm_compl_dict,
        homo,
        self.loader_param,
        native_NA_frac=0.25
    )
    valid_pdb_set = Dataset(
        list(valid_pdb.keys())[:self.n_valid_pdb],
        loader_pdb, valid_pdb,
        self.loader_param, homo, p_homo_cut=-1.0
    )
    # valid_homo_set = Dataset(
    #     list(valid_homo.keys())[:self.n_valid_homo],
    #     loader_pdb, valid_homo,
    #     self.loader_param, homo, p_homo_cut=2.0
    # )
    # valid_compl_set = DatasetComplex(
    #     list(valid_compl.keys())[:self.n_valid_compl],
    #     loader_complex, valid_compl,
    #     self.loader_param, negative=False
    # )
    # valid_neg_set = DatasetComplex(
    #     list(valid_neg.keys())[:self.n_valid_neg],
    #     loader_complex, valid_neg,
    #     self.loader_param, negative=True
    # )
    # valid_na_compl_set = DatasetNAComplex(
    #     list(valid_na_compl.keys())[:self.n_valid_na_compl],
    #     loader_na_complex, valid_na_compl,
    #     self.loader_param, negative=False, native_NA_frac=1.0
    # )
    # valid_na_neg_set = DatasetNAComplex(
    #     list(valid_na_neg.keys())[:self.n_valid_na_neg],
    #     loader_na_complex, valid_na_neg,
    #     self.loader_param, negative=True, native_NA_frac=1.0
    # )
    # valid_na_from_scratch_compl_set = DatasetNAComplex(
    #     list(valid_na_compl.keys())[:self.n_valid_na_compl],
    #     loader_na_complex, valid_na_compl,
    #     self.loader_param, negative=False, native_NA_frac=0.0
    # )
    # valid_na_from_scratch_neg_set = DatasetNAComplex(
    #     list(valid_na_neg.keys())[:self.n_valid_na_neg],
    #     loader_na_complex, valid_na_neg,
    #     self.loader_param, negative=True, native_NA_frac=0.0
    # )
    # valid_rna_set = DatasetRNA(
    #     list(valid_rna.keys())[:self.n_valid_rna],
    #     loader_rna, valid_rna,
    #     self.loader_param
    # )
    valid_sm_compl_set = DatasetSMComplex(
        list(valid_sm_compl.keys())[:self.n_valid_sm_compl],
        loader_sm_compl, valid_sm_compl,
        self.loader_param
    )

    # with these fractions every training example is drawn from the
    # small-molecule complex pool
    train_sampler = DistributedWeightedSampler(
        train_set,
        pdb_weights,
        fb_weights,
        compl_weights,
        neg_weights,
        na_compl_weights,
        na_neg_weights,
        rna_weights,
        sm_compl_weights,
        num_example_per_epoch=N_EXAMPLE_PER_EPOCH,
        num_replicas=world_size,
        rank=rank,
        fraction_fb=0.0,
        fraction_compl=0.0,
        fraction_na_compl=0.0,
        fraction_rna=0.0,
        fraction_sm_compl=1.0,
        replacement=True
    )
    valid_pdb_sampler = data.distributed.DistributedSampler(valid_pdb_set, num_replicas=world_size, rank=rank)
    # valid_homo_sampler = data.distributed.DistributedSampler(valid_homo_set, num_replicas=world_size, rank=rank)
    # valid_compl_sampler = data.distributed.DistributedSampler(valid_compl_set, num_replicas=world_size, rank=rank)
    # valid_neg_sampler = data.distributed.DistributedSampler(valid_neg_set, num_replicas=world_size, rank=rank)
    # valid_na_compl_sampler = data.distributed.DistributedSampler(valid_na_compl_set, num_replicas=world_size, rank=rank)
    # valid_na_neg_sampler = data.distributed.DistributedSampler(valid_na_neg_set, num_replicas=world_size, rank=rank)
    # valid_na_from_scratch_compl_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_compl_set, num_replicas=world_size, rank=rank)
    # valid_na_from_scratch_neg_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_neg_set, num_replicas=world_size, rank=rank)
    # valid_rna_sampler = data.distributed.DistributedSampler(valid_rna_set, num_replicas=world_size, rank=rank)
    valid_sm_compl_sampler = data.distributed.DistributedSampler(valid_sm_compl_set, num_replicas=world_size, rank=rank)

    train_loader = data.DataLoader(train_set, sampler=train_sampler, batch_size=self.batch_size, **LOAD_PARAM)
    valid_pdb_loader = data.DataLoader(valid_pdb_set, sampler=valid_pdb_sampler, **LOAD_PARAM)
    # valid_homo_loader = data.DataLoader(valid_homo_set, sampler=valid_homo_sampler, **LOAD_PARAM)
    # valid_compl_loader = data.DataLoader(valid_compl_set, sampler=valid_compl_sampler, **LOAD_PARAM)
    # valid_neg_loader = data.DataLoader(valid_neg_set, sampler=valid_neg_sampler, **LOAD_PARAM)
    # valid_na_compl_loader = data.DataLoader(valid_na_compl_set, sampler=valid_na_compl_sampler, **LOAD_PARAM)
    # valid_na_neg_loader = data.DataLoader(valid_na_neg_set, sampler=valid_na_neg_sampler, **LOAD_PARAM)
    # valid_na_from_scratch_compl_loader = data.DataLoader(valid_na_from_scratch_compl_set, sampler=valid_na_from_scratch_compl_sampler, **LOAD_PARAM)
    # valid_na_from_scratch_neg_loader = data.DataLoader(valid_na_from_scratch_neg_set, sampler=valid_na_from_scratch_neg_sampler, **LOAD_PARAM)
    # valid_rna_loader = data.DataLoader(valid_rna_set, sampler=valid_rna_sampler, **LOAD_PARAM)
    valid_sm_compl_loader = data.DataLoader(valid_sm_compl_set, sampler=valid_sm_compl_sampler, **LOAD_PARAM)

    # move some global data to cuda device
    self.ti_dev = self.ti_dev.to(gpu)
    self.ti_flip = self.ti_flip.to(gpu)
    self.ang_ref = self.ang_ref.to(gpu)
    self.fi_dev = self.fi_dev.to(gpu)
    self.l2a = self.l2a.to(gpu)
    self.aamask = self.aamask.to(gpu)
    self.compute_allatom_coords = self.compute_allatom_coords.to(gpu)
    self.num_bonds = self.num_bonds.to(gpu)
    self.atom_type_index = self.atom_type_index.to(gpu)
    self.ljlk_parameters = self.ljlk_parameters.to(gpu)
    self.lj_correction_parameters = self.lj_correction_parameters.to(gpu)
    self.hbtypes = self.hbtypes.to(gpu)
    self.hbbaseatoms = self.hbbaseatoms.to(gpu)
    self.hbpolys = self.hbpolys.to(gpu)
    self.cb_len = self.cb_len.to(gpu)
    self.cb_ang = self.cb_ang.to(gpu)
    self.cb_tor = self.cb_tor.to(gpu)

    # define model (EMA wrapper with decay 0.999 around the network)
    model = EMA(RoseTTAFoldModule(
        **self.model_param,
        aamask=self.aamask,
        atom_type_index=self.atom_type_index,
        ljlk_parameters=self.ljlk_parameters,
        lj_correction_parameters=self.lj_correction_parameters,
        num_bonds=self.num_bonds,
        cb_len = self.cb_len,
        cb_ang = self.cb_ang,
        cb_tor = self.cb_tor,
        lj_lin=self.loss_param['lj_lin']
    ).to(gpu), 0.999)
    #for n,p in model.named_parameters():
    #    if ("finetune_refiner" not in n and "residue_embed" not in n and "allatom_embed" not in n):
    #        p.requires_grad_(False)
    ddp_model = DDP(model, device_ids=[gpu], find_unused_parameters=False)
    if rank == 0:
        print ("# of parameters:", count_parameters(ddp_model))

    # define optimizer and scheduler
    opt_params = add_weight_decay(ddp_model, self.l2_coeff)
    optimizer = torch.optim.AdamW(opt_params, lr=self.init_lr)
    #scheduler = get_stepwise_decay_schedule_with_warmup(optimizer, 1000, 5000, 0.95)
    scheduler = get_stepwise_decay_schedule_with_warmup(optimizer, 0, 5000, 0.95)
    scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

    # load model
    loaded_epoch, best_valid_loss = self.load_model(ddp_model, optimizer, scheduler, scaler,
                                                    self.model_name, gpu, suffix="best", resume_train=True)

    if (self.eval):
        # run protein/NA prediction (TEMPLATED)
        #_, _, _ = self.valid_ppi_cycle(
        #    ddp_model, valid_na_compl_loader, valid_na_neg_loader,
        #    rank, gpu, world_size, 0, header="NA", report_interface=False, verbose=True)
        # run protein/NA prediction (NON-TEMPLATED)
        #_, _, _ = self.valid_ppi_cycle(
        #    ddp_model, valid_na_from_scratch_compl_loader, valid_na_from_scratch_neg_loader,
        #    rank, gpu, world_size, 0, header="NA", report_interface=False, verbose=True)
        # run RNA prediction
        #_,_,_ = self.valid_pdb_cycle(ddp_model, valid_rna_loader, rank, gpu, world_size, 0, verbose=True)
        _, _, _ = self.valid_pdb_cycle(ddp_model, valid_sm_compl_loader, rank, gpu, world_size, 0, verbose=True)
        dist.destroy_process_group()
        return

    if loaded_epoch >= self.n_epoch:
        # Nothing left to train. The original called `DDP_cleanup()` here, a
        # name that is not defined in the visible part of this file; use the
        # same teardown as the other exit paths.
        dist.destroy_process_group()
        return

    #_, _, _ = self.valid_pdb_cycle(ddp_model, valid_homo_loader, rank, gpu, world_size, epoch, header="Homo")
    #_, _, _ = self.valid_ppi_cycle(ddp_model, valid_compl_loader, valid_neg_loader, rank, gpu, world_size, epoch, report_interface=True)
    #_, _, _ = self.valid_ppi_cycle(
    #    ddp_model, valid_na_compl_loader, valid_na_neg_loader,
    #    rank, gpu, world_size, epoch, header="NA", report_interface=False)
    #_, _, _ = self.valid_ppi_cycle(
    #    ddp_model, valid_na_from_scratch_compl_loader, valid_na_from_scratch_neg_loader,
    #    rank, gpu, world_size, epoch, header="NAfs", report_interface=False)
    #_,_,_ = self.valid_pdb_cycle(ddp_model, valid_rna_loader, rank, gpu, world_size, epoch, header="RNA")
    for epoch in range(loaded_epoch+1, self.n_epoch):
        train_sampler.set_epoch(epoch)
        valid_pdb_sampler.set_epoch(epoch)
        #valid_homo_sampler.set_epoch(epoch)
        #valid_compl_sampler.set_epoch(epoch)
        #valid_neg_sampler.set_epoch(epoch)

        train_tot, train_loss, train_acc = self.train_cycle(ddp_model, train_loader, optimizer, scheduler, scaler, rank, gpu, world_size, epoch)

        valid_tot, valid_loss, valid_acc = self.valid_pdb_cycle(ddp_model, valid_pdb_loader, rank, gpu, world_size, epoch)
        #_, _, _ = self.valid_pdb_cycle(ddp_model, valid_homo_loader, rank, gpu, world_size, epoch, header="Homo")
        #_, _, _ = self.valid_ppi_cycle(ddp_model, valid_compl_loader, valid_neg_loader, rank, gpu, world_size, epoch, report_interface=True)
        # _, _, _ = self.valid_ppi_cycle(
        #     ddp_model, valid_na_compl_loader, valid_na_neg_loader,
        #     rank, gpu, world_size, epoch, header="NA", report_interface=False)
        # _, _, _ = self.valid_ppi_cycle(
        #     ddp_model, valid_na_from_scratch_compl_loader, valid_na_from_scratch_neg_loader,
        #     rank, gpu, world_size, epoch, header="NAfs", report_interface=False)
        # _,_,_ = self.valid_pdb_cycle(ddp_model, valid_rna_loader, rank, gpu, world_size, epoch, header="RNA")
        _,_,_ = self.valid_pdb_cycle(ddp_model, valid_sm_compl_loader, rank, gpu, world_size, epoch, header="SM Compl")

        if rank == 0: # save model
            # "best" checkpoint: EMA (shadow) weights, tracked on the
            # monomer validation total
            if valid_tot < best_valid_loss:
                best_valid_loss = valid_tot
                torch.save({'epoch': epoch,
                            #'model_state_dict': ddp_model.state_dict(),
                            'model_state_dict': ddp_model.module.shadow.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'scheduler_state_dict': scheduler.state_dict(),
                            'scaler_state_dict': scaler.state_dict(),
                            'best_loss': best_valid_loss,
                            'train_loss': train_loss,
                            'train_acc': train_acc,
                            'valid_loss': valid_loss,
                            'valid_acc': valid_acc},
                            self.checkpoint_fn(self.model_name, 'best'))

            # per-epoch checkpoint: EMA weights plus the raw (non-averaged)
            # weights under 'final_state_dict'
            torch.save({'epoch': epoch,
                        #'model_state_dict': ddp_model.state_dict(),
                        'model_state_dict': ddp_model.module.shadow.state_dict(),
                        'final_state_dict': ddp_model.module.model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'scaler_state_dict': scaler.state_dict(),
                        'train_loss': train_loss,
                        'train_acc': train_acc,
                        'valid_loss': valid_loss,
                        'valid_acc': valid_acc,
                        'best_loss': best_valid_loss},
                        self.checkpoint_fn(self.model_name, str(epoch)))

    dist.destroy_process_group()
def train_cycle(self, ddp_model, train_loader, optimizer, scheduler, scaler, rank, gpu, world_size, epoch, verbose=False):
    """Run one training epoch and return the epoch-averaged metrics.

    For every batch the model is first run ``N_cycle-1`` times without
    gradients ("recycling", with ``N_cycle`` drawn uniformly from
    ``1..self.maxcycle``), then once with gradients.  Gradients are
    accumulated over ``self.ACCUM_STEP`` batches; on the last batch of each
    group they are clipped, the optimizer/scaler/scheduler are stepped and
    the EMA weights are refreshed through ``ddp_model.module.update()``.

    Args:
        ddp_model: DDP-wrapped model; ``ddp_model.module`` must expose ``update()``.
        train_loader: iterable yielding 16-field batches (see the loop header).
        optimizer, scheduler, scaler: optimizer, LR scheduler and AMP grad scaler.
        rank: process rank; only rank 0 writes log lines.
        gpu: device handle passed to ``Tensor.to``.
        world_size: number of participating processes.
        epoch: epoch index, used for logging only.
        verbose: forwarded to ``self.calc_loss``.

    Returns:
        ``(train_tot, train_loss, train_acc)``: the total loss as a CPU
        tensor and the per-term losses / accuracies as numpy arrays, each
        averaged over batches and ranks.
    """
    # Local import keeps this block self-contained: nullcontext replaces
    # ddp_model.no_sync() on the batches where gradients must be all-reduced.
    from contextlib import nullcontext

    # Turn on training mode
    ddp_model.train()

    # clear gradients
    optimizer.zero_grad()

    start_time = time.time()

    # Running sums for the intermediate ("Local") log line
    local_tot = 0.0
    local_loss = None
    local_acc = None

    # Running sums for the whole epoch
    train_tot = 0.0
    train_loss = None
    train_acc = None

    counter = 0

    for (seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb,
         xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats) in train_loader:
        # transfer inputs to device
        B, _, N, L = msa.shape

        idx_pdb = idx_pdb.to(gpu, non_blocking=True) # (B, L)
        true_crds = true_crds.to(gpu, non_blocking=True) # (B, N?, L, Natms, 3)
        atom_mask = atom_mask.to(gpu, non_blocking=True) # (B, L, Natms)
        same_chain = same_chain.to(gpu, non_blocking=True) # (B, L, L)

        xyz_t = xyz_t.to(gpu, non_blocking=True)
        t1d = t1d.to(gpu, non_blocking=True)
        # Moved to the device like every other model input (valid_ppi_cycle
        # already does this); a no-op if the tensor is already there.
        xyz_prev = xyz_prev.to(gpu, non_blocking=True)

        seq = seq.to(gpu, non_blocking=True)
        msa = msa.to(gpu, non_blocking=True)
        msa_masked = msa_masked.to(gpu, non_blocking=True)
        msa_full = msa_full.to(gpu, non_blocking=True)
        mask_msa = mask_msa.to(gpu, non_blocking=True)
        atom_frames = atom_frames.to(gpu, non_blocking=True)
        bond_feats = bond_feats.to(gpu, non_blocking=True)

        # processing template features
        # get torsion angles from templates
        seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L)
        xyz_t_frames = xyz_t_to_frame_xyz(xyz_t, seq_tmp, atom_frames)
        t2d = xyz_to_t2d(xyz_t_frames)
        alpha, _, alpha_mask, _ = get_torsions(
            xyz_t.reshape(-1,L,NTOTAL,3), seq_tmp, self.ti_dev, self.ti_flip, self.ang_ref)
        # Torsions that evaluate to NaN are masked out and zeroed.
        alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0]))
        alpha[torch.isnan(alpha)] = 0.0
        alpha = alpha.reshape(B,-1,L,NTOTALDOFS,2)
        alpha_mask = alpha_mask.reshape(B,-1,L,NTOTALDOFS,1)
        alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, L, 3*NTOTALDOFS)

        # processing template coordinates
        xyz_t = get_init_xyz(seq[:,0],xyz_t,same_chain)
        xyz_prev = get_init_xyz(seq[:,0],xyz_prev[:,None],same_chain).reshape(B, L, NTOTAL, 3)

        counter += 1

        N_cycle = np.random.randint(1, self.maxcycle+1) # number of recycling

        msa_prev = None
        pair_prev = None
        alpha_prev = torch.zeros((B,L,NTOTALDOFS,2)).to(gpu, non_blocking=True)
        state_prev = None

        # Gradient-free recycling passes.
        # NOTE(review): the torsions returned here are discarded, so
        # alpha_prev stays all-zero for every pass -- confirm this is intended.
        with torch.no_grad():
            for i_cycle in range(N_cycle-1):
                with ddp_model.no_sync():
                    with torch.cuda.amp.autocast(enabled=USE_AMP):
                        msa_prev, pair_prev, xyz_prev, state_prev, _alpha = ddp_model(
                            msa_masked[:,i_cycle],
                            msa_full[:,i_cycle],
                            seq[:,i_cycle],
                            msa[:,i_cycle,0], # unmasked seq
                            xyz_prev,
                            alpha_prev,
                            idx_pdb,
                            bond_feats,
                            t1d=t1d,
                            t2d=t2d,
                            xyz_t=xyz_t,
                            alpha_t=alpha_t,
                            msa_prev=msa_prev,
                            pair_prev=pair_prev,
                            state_prev=state_prev,
                            return_raw=True,
                            use_checkpoint=False
                        )

        i_cycle = N_cycle-1

        # On pure accumulation batches the DDP gradient all-reduce is skipped;
        # on the last batch of each ACCUM_STEP group gradients are synchronised
        # and the optimizer is stepped.
        accumulate_only = (counter%self.ACCUM_STEP != 0)
        sync_ctx = ddp_model.no_sync() if accumulate_only else nullcontext()

        with sync_ctx:
            # Forward pass and loss share one autocast context (a no-op while
            # USE_AMP is False); backward runs outside of it.
            with torch.cuda.amp.autocast(enabled=USE_AMP):
                logit_s, logit_aa_s, pred_crds, alphas, pred_allatom, pred_lddts, _, _, _ = ddp_model(
                    msa_masked[:,i_cycle],
                    msa_full[:,i_cycle],
                    seq[:,i_cycle],
                    msa[:,i_cycle,0], # unmasked seq
                    xyz_prev,
                    alpha_prev,
                    idx_pdb,
                    bond_feats,
                    t1d=t1d,
                    t2d=t2d,
                    xyz_t=xyz_t,
                    alpha_t=alpha_t,
                    msa_prev=msa_prev,
                    pair_prev=pair_prev,
                    state_prev=state_prev,
                    use_checkpoint=True
                )

                true_crds, atom_mask = resolve_equiv_natives(pred_crds[-1], true_crds, atom_mask)
                # A position is dropped when fewer than 3 of its first three
                # atoms are present AND is_atom() is false for it.
                res_mask = ~((atom_mask[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(msa[:,i_cycle,0])))
                mask_2d = res_mask[:,None,:] * res_mask[:,:,None]

                true_crds_frame = xyz_to_frame_xyz(true_crds, msa[:, i_cycle, 0],atom_frames)
                c6d, _ = xyz_to_c6d(true_crds_frame)
                c6d = c6d_to_bins(c6d, same_chain, negative=negative)

                prob = self.active_fn(logit_s[0]) # distogram
                acc_s = self.calc_acc(prob, c6d[...,0], idx_pdb, mask_2d)

                ctrid = len(train_loader)*rank+counter
                loss, loss_s = self.calc_loss(
                    logit_s, c6d,
                    logit_aa_s, msa[:, i_cycle], mask_msa[:,i_cycle],
                    pred_crds, alphas, pred_allatom, true_crds,
                    atom_mask, res_mask, mask_2d, same_chain,
                    pred_lddts, idx_pdb, atom_frames=atom_frames,
                    unclamp=unclamp, negative=negative,
                    verbose=verbose, ctr=ctrid, **self.loss_param
                )
            # Scale so the accumulated gradient is the mean over the group.
            loss = loss / self.ACCUM_STEP
            scaler.scale(loss).backward()

        if not accumulate_only:
            # gradient clipping
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 0.2)
            scaler.step(optimizer)
            # A changed scale after update() is used as the signal that the
            # LR schedule should not advance on this step.
            scale = scaler.get_scale()
            scaler.update()
            skip_lr_sched = (scale != scaler.get_scale())
            optimizer.zero_grad()
            if not skip_lr_sched:
                scheduler.step()
            ddp_model.module.update() # apply EMA

        # Undo the 1/ACCUM_STEP scaling so logged totals are per-batch losses.
        local_tot += loss.detach()*self.ACCUM_STEP
        if local_loss is None:
            local_loss = torch.zeros_like(loss_s.detach())
            local_acc = torch.zeros_like(acc_s.detach())
        local_loss += loss_s.detach()
        local_acc += acc_s.detach()

        train_tot += loss.detach()*self.ACCUM_STEP
        if train_loss is None:
            train_loss = torch.zeros_like(loss_s.detach())
            train_acc = torch.zeros_like(acc_s.detach())
        train_loss += loss_s.detach()
        train_acc += acc_s.detach()

        if counter % N_PRINT_TRAIN == 0:
            if rank == 0:
                max_mem = torch.cuda.max_memory_allocated()/1e9
                train_time = time.time() - start_time
                local_tot /= float(N_PRINT_TRAIN)
                local_loss /= float(N_PRINT_TRAIN)
                local_acc /= float(N_PRINT_TRAIN)

                local_tot = local_tot.cpu().detach()
                local_loss = local_loss.cpu().detach().numpy()
                local_acc = local_acc.cpu().detach().numpy()

                sys.stdout.write("Local: [%04d/%04d] Batch: [%05d/%05d] Time: %16.1f | total_loss: %8.4f | %s | %.4f %.4f %.4f | Max mem %.4f\n"%(
                    epoch, self.n_epoch, counter*self.batch_size*world_size, self.n_train, train_time, local_tot,
                    " ".join(["%8.4f"%l for l in local_loss]),
                    local_acc[0], local_acc[1], local_acc[2], max_mem))
                sys.stdout.flush()
            # Start a fresh logging window (only rank 0 ever reports it).
            local_tot = 0.0
            local_loss = None
            local_acc = None
            torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

    # write total train loss
    # Each rank contributes sum/(counter*world_size); the SUM all-reduce then
    # yields the mean over batches and ranks.
    train_tot /= float(counter * world_size)
    train_loss /= float(counter * world_size)
    train_acc /= float(counter * world_size)

    dist.all_reduce(train_tot, op=dist.ReduceOp.SUM)
    dist.all_reduce(train_loss, op=dist.ReduceOp.SUM)
    dist.all_reduce(train_acc, op=dist.ReduceOp.SUM)
    train_tot = train_tot.cpu().detach()
    train_loss = train_loss.cpu().detach().numpy()
    train_acc = train_acc.cpu().detach().numpy()

    if rank == 0:
        train_time = time.time() - start_time
        sys.stdout.write("Train: [%04d/%04d] Batch: [%05d/%05d] Time: %16.1f | total_loss: %8.4f | %s | %.4f %.4f %.4f\n"%(
            epoch, self.n_epoch, self.n_train, self.n_train, train_time, train_tot,
            " ".join(["%8.4f"%l for l in train_loss]),
            train_acc[0], train_acc[1], train_acc[2]))
        sys.stdout.flush()

    return train_tot, train_loss, train_acc
def valid_pdb_cycle(self, ddp_model, valid_loader, rank, gpu, world_size, epoch, header='Monomer', verbose=False):
    """Evaluate the model on ``valid_loader`` and return averaged metrics.

    Every batch is run for exactly ``self.maxcycle`` passes (the first
    ``maxcycle-1`` with ``return_raw=True`` to feed the next pass), all
    under ``torch.no_grad()`` with the model in eval mode.  Loss and
    accuracy are summed per rank, divided by ``counter*world_size`` and
    SUM all-reduced, which gives the mean over batches and ranks.

    Args:
        ddp_model: DDP-wrapped model.
        valid_loader: iterable yielding 16-field batches (see the loop header).
        rank: process rank; only rank 0 writes the summary line.
        gpu: device handle passed to ``Tensor.to``.
        world_size: number of participating processes.
        epoch: epoch index, used for logging only.
        header: label printed at the start of the summary line.
        verbose: forwarded to ``self.calc_loss``.

    Returns:
        ``(valid_tot, valid_loss, valid_acc)`` as numpy values.
    """
    valid_tot = 0.0
    valid_loss = None
    valid_acc = None
    counter = 0

    start_time = time.time()

    with torch.no_grad(): # no need to calculate gradient
        ddp_model.eval() # change it to eval mode
        for (seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb,
             xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats) in valid_loader:
            # transfer inputs to device
            B, _, N, L = msa.shape

            idx_pdb = idx_pdb.to(gpu, non_blocking=True) # (B, L)
            true_crds = true_crds.to(gpu, non_blocking=True) # (B, L, 27, 3)
            atom_mask = atom_mask.to(gpu, non_blocking=True) # (B, L, 27)
            same_chain = same_chain.to(gpu, non_blocking=True)

            xyz_t = xyz_t.to(gpu, non_blocking=True)
            t1d = t1d.to(gpu, non_blocking=True)
            # Moved to the device like every other model input (valid_ppi_cycle
            # already does this); a no-op if the tensor is already there.
            xyz_prev = xyz_prev.to(gpu, non_blocking=True)

            seq = seq.to(gpu, non_blocking=True)
            msa = msa.to(gpu, non_blocking=True)
            msa_masked = msa_masked.to(gpu, non_blocking=True)
            msa_full = msa_full.to(gpu, non_blocking=True)
            mask_msa = mask_msa.to(gpu, non_blocking=True)
            atom_frames = atom_frames.to(gpu, non_blocking=True)
            bond_feats = bond_feats.to(gpu, non_blocking=True)

            # processing template features
            # get torsion angles from templates
            seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L)
            xyz_t_frames = xyz_t_to_frame_xyz(xyz_t, seq_tmp, atom_frames)
            t2d = xyz_to_t2d(xyz_t_frames)
            alpha, _, alpha_mask, _ = get_torsions(
                xyz_t.reshape(-1,L,NTOTAL,3), seq_tmp, self.ti_dev, self.ti_flip, self.ang_ref)
            # Torsions that evaluate to NaN are masked out and zeroed.
            alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0]))
            alpha[torch.isnan(alpha)] = 0.0
            alpha = alpha.reshape(B,-1,L,NTOTALDOFS,2)
            alpha_mask = alpha_mask.reshape(B,-1,L,NTOTALDOFS,1)
            alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, L, 3*NTOTALDOFS)

            # processing template coordinates
            xyz_t = get_init_xyz(seq[:,0],xyz_t,same_chain)
            xyz_prev = get_init_xyz(seq[:,0],xyz_prev[:,None],same_chain).reshape(B, L, NTOTAL, 3)

            # set number of recycles
            N_cycle = self.maxcycle

            msa_prev = None
            pair_prev = None
            alpha_prev = torch.zeros((B,L,NTOTALDOFS,2)).to(gpu, non_blocking=True) #fd we could get this from the template...
            state_prev = None

            # Recycling passes; the returned torsions are not fed back
            # (alpha_prev stays all-zero), matching train_cycle.
            for i_cycle in range(N_cycle-1):
                msa_prev, pair_prev, xyz_prev, state_prev, _alpha = ddp_model(
                    msa_masked[:,i_cycle],
                    msa_full[:,i_cycle],
                    seq[:,i_cycle],
                    msa[:,i_cycle,0], # unmasked seq
                    xyz_prev,
                    alpha_prev,
                    idx_pdb,
                    bond_feats,
                    t1d=t1d,
                    t2d=t2d,
                    xyz_t=xyz_t,
                    alpha_t=alpha_t,
                    msa_prev=msa_prev,
                    pair_prev=pair_prev,
                    state_prev=state_prev,
                    return_raw=True,
                    use_checkpoint=False
                )

            # Final pass: full set of outputs for loss/accuracy.
            i_cycle = N_cycle-1
            logit_s, logit_aa_s, pred_crds, alphas, pred_allatom, pred_lddts, _, _, _ = ddp_model(
                msa_masked[:,i_cycle],
                msa_full[:,i_cycle],
                seq[:,i_cycle],
                msa[:,i_cycle,0], # unmasked seq
                xyz_prev,
                alpha_prev,
                idx_pdb,
                bond_feats,
                t1d=t1d,
                t2d=t2d,
                xyz_t=xyz_t,
                alpha_t=alpha_t,
                msa_prev=msa_prev,
                pair_prev=pair_prev,
                state_prev=state_prev,
                use_checkpoint=False
            )

            true_crds, atom_mask = resolve_equiv_natives(pred_crds[-1], true_crds, atom_mask)
            # A position is dropped when fewer than 3 of its first three
            # atoms are present AND is_atom() is false for it.
            res_mask = ~((atom_mask[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(msa[:,i_cycle,0])))
            mask_2d = res_mask[:,None,:] * res_mask[:,:,None]

            true_crds_frame = xyz_to_frame_xyz(true_crds, msa[:, i_cycle, 0],atom_frames)
            c6d, _ = xyz_to_c6d(true_crds_frame)
            c6d = c6d_to_bins(c6d, same_chain, negative=negative)

            prob = self.active_fn(logit_s[0]) # distogram
            acc_s = self.calc_acc(prob, c6d[...,0], idx_pdb, mask_2d)

            ctrid = len(valid_loader)*rank+counter
            loss, loss_s = self.calc_loss(
                logit_s, c6d,
                logit_aa_s, msa[:, i_cycle], mask_msa[:,i_cycle],
                pred_crds, alphas, pred_allatom, true_crds,
                atom_mask, res_mask, mask_2d, same_chain,
                pred_lddts, idx_pdb, atom_frames, unclamp=unclamp, negative=negative,
                verbose=verbose, ctr=ctrid, **self.loss_param
            )

            valid_tot += loss.detach()
            if valid_loss is None:
                valid_loss = torch.zeros_like(loss_s.detach())
                valid_acc = torch.zeros_like(acc_s.detach())
            valid_loss += loss_s.detach()
            valid_acc += acc_s.detach()
            counter += 1

    # Pre-divide so that the SUM all-reduce yields the mean over ranks.
    valid_tot /= float(counter*world_size)
    valid_loss /= float(counter*world_size)
    valid_acc /= float(counter*world_size)

    dist.all_reduce(valid_tot, op=dist.ReduceOp.SUM)
    dist.all_reduce(valid_loss, op=dist.ReduceOp.SUM)
    dist.all_reduce(valid_acc, op=dist.ReduceOp.SUM)

    valid_tot = valid_tot.cpu().detach().numpy()
    valid_loss = valid_loss.cpu().detach().numpy()
    valid_acc = valid_acc.cpu().detach().numpy()

    if rank == 0:
        train_time = time.time() - start_time
        sys.stdout.write("%s: [%04d/%04d] Batch: [%05d/%05d] Time: %16.1f | total_loss: %8.4f | %s | %.4f %.4f %.4f\n"%(
            header, epoch, self.n_epoch, world_size*len(valid_loader), world_size*len(valid_loader), train_time, valid_tot,
            " ".join(["%8.4f"%l for l in valid_loss]),
            valid_acc[0], valid_acc[1], valid_acc[2]))
        sys.stdout.flush()

    return valid_tot, valid_loss, valid_acc
def valid_ppi_cycle(self, ddp_model, valid_pos_loader, valid_neg_loader, rank, gpu, world_size, epoch, header='Protein', report_interface=True, verbose=False):
    """Evaluate on a positive and a negative complex set and report both.

    The positive loader is evaluated first and summarised on a
    ``"<header>-interface"`` line; the negative loader follows and is
    summarised on a ``"<header>-PPI"`` line together with precision,
    recall and F1.  Those three are computed from a per-example binary
    decision: an example counts as a predicted binder when the maximum of
    the inter-chain part of ``cnt_pred`` exceeds 0.5, and as a true binder
    when any inter-chain entry of ``cnt_ref`` is positive.  Confusion
    counts are pooled over both loaders and all ranks.

    Args:
        ddp_model: DDP-wrapped model.
        valid_pos_loader, valid_neg_loader: iterables yielding 16-field
            batches (see the loop header in ``_evaluate``).
        rank: process rank; only rank 0 writes the summary lines.
        gpu: device handle passed to ``Tensor.to``.
        world_size: number of participating processes.
        epoch: epoch index, used for logging only.
        header: label printed at the start of each summary line.
        report_interface: forwarded to ``self.calc_loss`` as ``interface=``
            for the positive set only.
        verbose: forwarded to ``self.calc_loss``.

    Returns:
        ``(valid_tot, valid_loss, valid_acc)`` of the *negative* set, as
        numpy values averaged over batches and ranks.
    """
    # Per-rank confusion counts shared by the positive and negative passes.
    # Each example is counted exactly once; the tensor is built from these
    # after both loaders are done.
    confusion = {'TP': 0.0, 'FP': 0.0, 'TN': 0.0, 'FN': 0.0}

    def _evaluate(loader, extra_loss_kwargs):
        """Run every batch of `loader`; return (loss sum, loss_s sum, acc_s sum, n_batches)."""
        tot = 0.0
        loss_sum = None
        acc_sum = None
        n_batches = 0

        for (seq, msa, msa_masked, msa_full, mask_msa, true_crds, mask_crds, idx_pdb,
             xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats) in loader:
            # transfer inputs to device
            B, _, N, L = msa.shape

            idx_pdb = idx_pdb.to(gpu, non_blocking=True) # (B, L)
            true_crds = true_crds.to(gpu, non_blocking=True) # (B, L, 27, 3)
            atom_mask = mask_crds.to(gpu, non_blocking=True) # (B, L, 27)
            same_chain = same_chain.to(gpu, non_blocking=True)

            xyz_t = xyz_t.to(gpu, non_blocking=True)
            t1d = t1d.to(gpu, non_blocking=True)
            xyz_prev = xyz_prev.to(gpu, non_blocking=True)

            seq = seq.to(gpu, non_blocking=True)
            msa = msa.to(gpu, non_blocking=True)
            msa_masked = msa_masked.to(gpu, non_blocking=True)
            msa_full = msa_full.to(gpu, non_blocking=True)
            mask_msa = mask_msa.to(gpu, non_blocking=True)
            atom_frames = atom_frames.to(gpu, non_blocking=True)
            bond_feats = bond_feats.to(gpu, non_blocking=True)

            # processing template features
            # get torsion angles from templates
            seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L)
            xyz_t_frames = xyz_t_to_frame_xyz(xyz_t, seq_tmp, atom_frames)
            t2d = xyz_to_t2d(xyz_t_frames)
            alpha, _, alpha_mask, _ = get_torsions(
                xyz_t.reshape(-1,L,NTOTAL,3), seq_tmp, self.ti_dev, self.ti_flip, self.ang_ref)
            # Torsions that evaluate to NaN are masked out and zeroed.
            alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0]))
            alpha[torch.isnan(alpha)] = 0.0
            alpha = alpha.reshape(B,-1,L,NTOTALDOFS,2)
            alpha_mask = alpha_mask.reshape(B,-1,L,NTOTALDOFS,1)
            alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, L, 3*NTOTALDOFS)

            # processing template coordinates
            xyz_t = get_init_xyz(seq[:,0],xyz_t,same_chain)
            xyz_prev = get_init_xyz(seq[:,0],xyz_prev[:,None],same_chain).reshape(B, L, NTOTAL, 3)

            N_cycle = self.maxcycle # number of recycling

            msa_prev = None
            pair_prev = None
            alpha_prev = torch.zeros((B,L,NTOTALDOFS,2)).to(gpu, non_blocking=True) #fd we could get this from the template...
            state_prev = None

            # Recycling passes; the returned torsions are not fed back
            # (alpha_prev stays all-zero), matching the other cycles.
            for i_cycle in range(N_cycle-1):
                msa_prev, pair_prev, xyz_prev, state_prev, _alpha = ddp_model(
                    msa_masked[:,i_cycle],
                    msa_full[:,i_cycle],
                    seq[:,i_cycle],
                    msa[:,i_cycle,0], # unmasked seq
                    xyz_prev,
                    alpha_prev,
                    idx_pdb,
                    bond_feats,
                    t1d=t1d,
                    t2d=t2d,
                    xyz_t=xyz_t,
                    alpha_t=alpha_t,
                    msa_prev=msa_prev,
                    pair_prev=pair_prev,
                    state_prev=state_prev,
                    return_raw=True,
                    use_checkpoint=False
                )

            # Final pass: full set of outputs for loss/accuracy.
            i_cycle = N_cycle-1
            logit_s, logit_aa_s, pred_crds, alphas, pred_allatom, pred_lddts, _, _, _ = ddp_model(
                msa_masked[:,i_cycle],
                msa_full[:,i_cycle],
                seq[:,i_cycle],
                msa[:,i_cycle,0], # unmasked seq
                xyz_prev,
                alpha_prev,
                idx_pdb,
                bond_feats,
                t1d=t1d,
                t2d=t2d,
                xyz_t=xyz_t,
                alpha_t=alpha_t,
                msa_prev=msa_prev,
                pair_prev=pair_prev,
                state_prev=state_prev,
                use_checkpoint=False
            )

            true_crds, atom_mask = resolve_equiv_natives(pred_crds[-1], true_crds, atom_mask)
            # Element-wise product of the two boolean masks (Python `and`
            # cannot be applied to multi-element tensors).
            res_mask = ~((atom_mask[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(msa[:,i_cycle,0])))
            mask_2d = res_mask[:,None,:] * res_mask[:,:,None]

            true_crds_frame = xyz_to_frame_xyz(true_crds, msa[:, i_cycle, 0],atom_frames)
            c6d, _ = xyz_to_c6d(true_crds_frame)
            c6d = c6d_to_bins(c6d, same_chain, negative=negative)

            prob = self.active_fn(logit_s[0]) # distogram
            acc_s, cnt_pred, cnt_ref = self.calc_acc(prob, c6d[...,0], idx_pdb, mask_2d, return_cnt=True)

            # inter-chain contact prob
            # NOTE(review): `1-same_chain` requires a non-bool dtype -- confirm.
            cnt_pred = cnt_pred * (1-same_chain).float()
            cnt_ref = cnt_ref * (1-same_chain).float()

            predicted_binder = bool(cnt_pred.max() > 0.5)
            true_binder = bool((cnt_ref > 0).any())
            if predicted_binder:
                confusion['TP' if true_binder else 'FP'] += 1.0
            else:
                confusion['FN' if true_binder else 'TN'] += 1.0

            ctrid = len(loader)*rank+n_batches
            loss, loss_s = self.calc_loss(
                logit_s, c6d,
                logit_aa_s, msa[:, i_cycle], mask_msa[:,i_cycle],
                pred_crds, alphas, pred_allatom, true_crds,
                atom_mask, res_mask, mask_2d, same_chain,
                pred_lddts, idx_pdb, atom_frames, unclamp=unclamp, negative=negative,
                verbose=verbose, ctr=ctrid, **extra_loss_kwargs, **self.loss_param
            )

            tot += loss.detach()
            if loss_sum is None:
                loss_sum = torch.zeros_like(loss_s.detach())
                acc_sum = torch.zeros_like(acc_s.detach())
            loss_sum += loss_s.detach()
            acc_sum += acc_s.detach()
            n_batches += 1

        return tot, loss_sum, acc_sum, n_batches

    # ---- positive set -------------------------------------------------
    start_time = time.time()
    with torch.no_grad(): # no need to calculate gradient
        ddp_model.eval() # change it to eval mode
        valid_tot, valid_loss, valid_acc, counter = _evaluate(
            valid_pos_loader, {'interface': report_interface})

    # Pre-divide so that the SUM all-reduce yields the mean over ranks.
    valid_tot /= float(counter*world_size)
    valid_loss /= float(counter*world_size)
    valid_acc /= float(counter*world_size)

    dist.all_reduce(valid_tot, op=dist.ReduceOp.SUM)
    dist.all_reduce(valid_loss, op=dist.ReduceOp.SUM)
    dist.all_reduce(valid_acc, op=dist.ReduceOp.SUM)

    valid_tot = valid_tot.cpu().detach().numpy()
    valid_loss = valid_loss.cpu().detach().numpy()
    valid_acc = valid_acc.cpu().detach().numpy()

    if rank == 0:
        train_time = time.time() - start_time
        sys.stdout.write("%s-interface: [%04d/%04d] Batch: [%05d/%05d] Time: %16.1f | total_loss: %8.4f | %s | %.4f %.4f %.4f\n"%(
            header, epoch, self.n_epoch, counter*world_size, counter*world_size, train_time, valid_tot,
            " ".join(["%8.4f"%l for l in valid_loss]),
            valid_acc[0], valid_acc[1], valid_acc[2]))
        sys.stdout.flush()

    # ---- negative set -------------------------------------------------
    # `interface=` is not passed to calc_loss here, as in the original.
    start_time = time.time()
    with torch.no_grad(): # no need to calculate gradient
        ddp_model.eval() # change it to eval mode
        valid_tot, valid_loss, valid_acc, counter = _evaluate(valid_neg_loader, {})

    valid_tot /= float(counter*world_size)
    valid_loss /= float(counter*world_size)
    valid_acc /= float(counter*world_size)

    # Confusion counts are raw per-rank totals; the SUM all-reduce pools them.
    valid_inter = torch.tensor(
        [confusion['TP'], confusion['FP'], confusion['TN'], confusion['FN']]
    ).float().to(gpu)

    dist.all_reduce(valid_tot, op=dist.ReduceOp.SUM)
    dist.all_reduce(valid_loss, op=dist.ReduceOp.SUM)
    dist.all_reduce(valid_acc, op=dist.ReduceOp.SUM)
    dist.all_reduce(valid_inter, op=dist.ReduceOp.SUM)

    valid_tot = valid_tot.cpu().detach().numpy()
    valid_loss = valid_loss.cpu().detach().numpy()
    valid_acc = valid_acc.cpu().detach().numpy()
    valid_inter = valid_inter.cpu().detach().numpy()

    if rank == 0:
        TP, FP, TN, FN = valid_inter
        # 1e-4 guards against division by zero when a class is empty.
        prec = TP/(TP+FP+1e-4)
        recall = TP/(TP+FN+1e-4)
        F1 = 2*TP/(2*TP+FP+FN+1e-4)

        train_time = time.time() - start_time
        sys.stdout.write("%s-PPI: [%04d/%04d] Batch: [%05d/%05d] Time: %16.1f | total_loss: %8.4f | %s | %.4f %.4f %.4f | %.4f %.4f %.4f\n"%(
            header, epoch, self.n_epoch, counter*world_size, counter*world_size, train_time, valid_tot,
            " ".join(["%8.4f"%l for l in valid_loss]),
            valid_acc[0], valid_acc[1], valid_acc[2],
            prec, recall, F1))
        sys.stdout.flush()

    return valid_tot, valid_loss, valid_acc
if __name__ == "__main__":
    # Script entry point: read the command-line configuration, build the
    # Trainer from it and start training, passing the number of CUDA
    # devices reported by torch.
    from arguments import get_args

    args, model_param, loader_param, loss_param = get_args()
    print(args)

    mp.freeze_support()

    # Keyword arguments for Trainer, collected in one place.  The L2
    # coefficient is fixed here; everything else comes from get_args().
    trainer_options = dict(
        model_name=args.model_name,
        n_epoch=args.num_epochs,
        step_lr=args.step_lr,
        lr=args.lr,
        l2_coeff=1.0e-2,
        port=args.port,
        model_param=model_param,
        loader_param=loader_param,
        loss_param=loss_param,
        batch_size=args.batch_size,
        accum_step=args.accum,
        maxcycle=args.maxcycle,
        eval=args.eval,
    )
    train = Trainer(**trainer_options)

    n_devices = torch.cuda.device_count()
    train.run_model_training(n_devices)