Files
foundry/RF2_allatom/eval.py
2022-07-11 00:48:54 -07:00

343 lines
13 KiB
Python

import sys, os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
from parsers import parse_a3m, parse_fasta, read_template_pdb
from RoseTTAFoldModel import RoseTTAFoldModule
import util
from collections import namedtuple
from ffindex import *
from data_loader import MSAFeaturize, MSABlockDeletion, merge_a3m_homo
from kinematics import xyz_to_c6d, c6d_to_bins, xyz_to_t2d, get_init_xyz
from util_module import ComputeAllAtomCoords
from chemical import NTOTAL, NTOTALDOFS, NAATOKENS
from memory import mem_report
MAX_CYCLE = 30
NREPLICATES = 5
NBIN = [37, 37, 37, 19]
MODEL_PARAM ={
"n_extra_block" : 4,
"n_main_block" : 32,
"n_ref_block" : 4,
"d_msa" : 256 ,
"d_pair" : 128,
"d_templ" : 64,
"n_head_msa" : 8,
"n_head_pair" : 4,
"n_head_templ" : 4,
"d_hidden" : 32,
"d_hidden_templ" : 64,
"p_drop" : 0.15,
"lj_lin" : 0.75
}
SE3_param = {
"num_layers" : 1,
"num_channels" : 32,
"num_degrees" : 2,
"l0_in_features": 64,
"l0_out_features": 64,
"l1_in_features": 3,
"l1_out_features": 2,
"num_edge_features": 64,
"div": 4,
"n_heads": 4
}
SE3_ref_param = {
"num_layers" : 2,
"num_channels" : 32,
"num_degrees" : 2,
"l0_in_features": 64,
"l0_out_features": 64,
"l1_in_features": 3,
"l1_out_features": 2,
"num_edge_features": 64,
"div": 4,
"n_heads": 4
}
MODEL_PARAM['SE3_param'] = SE3_param
MODEL_PARAM['SE3_ref_param'] = SE3_ref_param
# params for the folding protocol
fold_params = {
"SG7" : np.array([[[-2,3,6,7,6,3,-2]]])/21,
"SG9" : np.array([[[-21,14,39,54,59,54,39,14,-21]]])/231,
"DCUT" : 19.5,
"ALPHA" : 1.57,
# TODO: add Cb to the motif
"NCAC" : np.array([[-0.676, -1.294, 0. ],
[ 0. , 0. , 0. ],
[ 1.5 , -0.174, 0. ]], dtype=np.float32),
"CLASH" : 2.0,
"PCUT" : 0.5,
"DSTEP" : 0.5,
"ASTEP" : np.deg2rad(10.0),
"XYZRAD" : 7.5,
"WANG" : 0.1,
"WCST" : 0.1
}
fold_params["SG"] = fold_params["SG9"]
# compute expected value from binned lddt
def lddt_unbin(pred_lddt):
nbin = pred_lddt.shape[1]
bin_step = 1.0 / nbin
lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device)
pred_lddt = nn.Softmax(dim=1)(pred_lddt)
return torch.sum(lddt_bins[None,:,None]*pred_lddt, dim=1)
class Predictor():
def __init__(self, model_name="BFF", model_dir=None, device="cuda:0"):
if model_dir == None:
self.model_dir = "%s/models"%(os.path.dirname(os.path.abspath(__file__)))
else:
self.model_dir = model_dir
#
# define model name
self.model_name = model_name
self.device = device
self.active_fn = nn.Softmax(dim=1)
# define model & load model
self.model = RoseTTAFoldModule(
**MODEL_PARAM,
aamask=util.allatom_mask.to(self.device),
ljlk_parameters=util.ljlk_parameters.to(self.device),
lj_correction_parameters=util.lj_correction_parameters.to(self.device),
num_bonds=util.num_bonds.to(self.device)
).to(self.device)
could_load = self.load_model(self.model_name)
if not could_load:
print ("ERROR: failed to load model")
sys.exit()
self.compute_allatom_coords = ComputeAllAtomCoords().to(self.device)
def load_model(self, model_name, suffix='last'):
chk_fn = "%s/%s_%s.pt"%(self.model_dir, model_name, suffix)
if not os.path.exists(chk_fn):
return False
checkpoint = torch.load(chk_fn, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
return True
def predict(self, fasta_fn, out_prefix, tmpl_fn=None, atab_fn=None, window=1e9, shift=50, n_latent=256, oligo=1):
msa_orig, ins_orig = parse_fasta(fasta_fn, rmsa_alphabet=True)
msa_orig = torch.tensor(msa_orig).long()
ins_orig = torch.tensor(ins_orig).long()
#sel = torch.arange(941) #, msa_orig.shape[1])
#msa_orig = msa_orig[:,sel]
#ins_orig = ins_orig[:,sel]
if (oligo>1):
msa_orig, ins_orig = merge_a3m_homo(msa_orig, ins_orig, oligo) # make unpaired alignments, for training, we always use two chains
N, L = msa_orig.shape
#
if tmpl_fn and os.path.exists(tmpl_fn):
xyz_t, t1d = read_template_pdb(L, tmpl_fn)
#xyz_t, t1d = read_templates(L, ffdb, hhr_fn, atab_fn, n_templ=4)
else:
xyz_t = torch.full((1,L,3,3),np.nan).float()
t1d = torch.nn.functional.one_hot(torch.full((1, L), 20).long(), num_classes=NAATOKENS-1).float() # all gaps
t1d = torch.cat((t1d, torch.zeros((1,L,1)).float()), -1)
#
# template features
xyz_t = xyz_t.float().unsqueeze(0)
t1d = t1d.float().unsqueeze(0)
t2d = xyz_to_t2d(xyz_t)
same_chain = torch.ones((1,L,L), dtype=torch.bool, device=xyz_t.device)
xyz_t = get_init_xyz(msa_orig[0:1],xyz_t,same_chain) # initialize coordinates with first template
seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L)
alpha, _, alpha_mask, _ = util.get_torsions(
xyz_t.reshape(-1,L,NTOTAL,3),
seq_tmp,
util.torsion_indices,
util.torsion_can_flip,
util.reference_angles
)
alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0]))
alpha[torch.isnan(alpha)] = 0.0
alpha = alpha.reshape(1,-1,L,NTOTALDOFS,2)
alpha_mask = alpha_mask.reshape(1,-1,L,NTOTALDOFS,1)
alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(1, -1, L, 3*NTOTALDOFS)
self.model.eval()
for i_trial in range(NREPLICATES):
if os.path.exists("%s_%02d_init.pdb"%(out_prefix, i_trial)):
continue
self.run_prediction(msa_orig, ins_orig, t1d, t2d, xyz_t, xyz_t[:,0], alpha_t, "%s_%02d"%(out_prefix, i_trial), n_latent=n_latent)
torch.cuda.empty_cache()
def run_prediction(self, msa_orig, ins_orig, t1d, t2d, xyz_t, xyz, alpha_t, out_prefix, n_latent=256):
start = time.time()
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
#
msa = msa_orig.to(self.device) # (N, L)
ins = ins_orig.long().to(self.device)
#if msa_orig.shape[0] > 4096:
# msa, ins = MSABlockDeletion(msa, ins)
#
seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(
msa, ins, p_mask=0.0, params={'MAXLAT': 128, 'MAXSEQ': 1024, 'MAXCYCLE': MAX_CYCLE}, tocpu=True)
_, N, L = msa_seed.shape[:3]
B = 1
#
idx_pdb = torch.arange(L).long().view(1, L)
#
seq = seq.unsqueeze(0)
msa_seed = msa_seed.unsqueeze(0)
msa_extra = msa_extra.unsqueeze(0)
t1d = t1d.to(self.device)
t2d = t2d.to(self.device)
idx_pdb = idx_pdb.to(self.device)
xyz_t = xyz_t.to(self.device)
alpha_t = alpha_t.to(self.device)
xyz = xyz.to(self.device)
self.write_pdb(seq[0, -1], xyz[0], prefix="%s_templ"%(out_prefix))
msa_prev = None
pair_prev = None
alpha_prev = torch.zeros((1,L,NTOTALDOFS,2), device=seq.device)
xyz_prev=xyz
state_prev = None
best_lddt = torch.tensor([-1.0], device=seq.device)
best_xyz = None
best_logit = None
best_aa = None
for i_cycle in range(MAX_CYCLE):
msa_seed_i = msa_seed[:,i_cycle].to(self.device)
msa_extra_i = msa_extra[:,i_cycle].to(self.device)
with torch.cuda.amp.autocast(True):
logit_s, logit_aa_s, init_crds, alpha_prev, _, pred_lddt_binned, msa_prev, pair_prev, state_prev = self.model(
msa_seed_i,
msa_extra_i,
seq[:,i_cycle],
seq[:,i_cycle],
xyz_prev,
alpha_prev,
idx_pdb,
t1d=t1d,
t2d=t2d,
xyz_t=xyz_t,
alpha_t=alpha_t,
msa_prev=msa_prev,
pair_prev=pair_prev,
state_prev=state_prev
)
logit_aa_s = logit_aa_s.reshape(B,-1,N,L)[:,:,0].permute(0,2,1)
xyz_prev = init_crds[-1]
alpha_prev = alpha_prev[-1]
pred_lddt = lddt_unbin(pred_lddt_binned)
print ("RECYCLE", i_cycle, pred_lddt.mean(), best_lddt.mean())
_, all_crds = self.compute_allatom_coords(seq[:,i_cycle], init_crds[-1], alpha_prev)
#self.write_pdb(seq[0, -1], all_crds[0], Bfacts=pred_lddt[0], prefix="%s_cycle_%02d"%(out_prefix, i_cycle))
if pred_lddt.mean() < best_lddt.mean():
continue
best_xyz = all_crds.clone()
best_logit = logit_s
best_aa = logit_aa_s
best_lddt = pred_lddt.clone()
#print (pred_lddt)
prob_s = list()
for logit in logit_s:
prob = self.active_fn(logit.float()) # distogram
prob = prob.reshape(-1, L, L) #.permute(1,2,0).cpu().numpy()
prob_s.append(prob)
end = time.time()
for prob in prob_s:
prob += 1e-8
prob = prob / torch.sum(prob, dim=0)[None]
self.write_pdb(seq[0, -1], best_xyz[0], Bfacts=100*best_lddt[0], prefix="%s_init"%(out_prefix))
prob_s = [prob.permute(1,2,0).detach().cpu().numpy().astype(np.float16) for prob in prob_s]
np.savez_compressed("%s.npz"%(out_prefix), dist=prob_s[0].astype(np.float16), \
omega=prob_s[1].astype(np.float16),\
theta=prob_s[2].astype(np.float16),\
phi=prob_s[3].astype(np.float16),\
lddt=best_lddt[0].detach().cpu().numpy().astype(np.float16))
max_mem = torch.cuda.max_memory_allocated()/1e9
print ("max mem", max_mem)
print ("runtime", end-start)
def write_pdb(self, seq, atoms, Bfacts=None, prefix=None):
L = len(seq)
filename = "%s.pdb"%prefix
ctr = 1
with open(filename, 'wt') as f:
if Bfacts == None:
Bfacts = np.zeros(L)
else:
Bfacts = torch.clamp( Bfacts, 0, 100)
for i,s in enumerate(seq):
if (len(atoms.shape)==2):
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
"ATOM", ctr, " CA ", util.num2aa[s],
"A", i+1, atoms[i,0], atoms[i,1], atoms[i,2],
1.0, Bfacts[i] ) )
ctr += 1
elif atoms.shape[1]==3:
for j,atm_j in enumerate((" N "," CA "," C ")):
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
"ATOM", ctr, atm_j, util.num2aa[s],
"A", i+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2],
1.0, Bfacts[i] ) )
ctr += 1
else:
atms = util.aa2long[s]
for j,atm_j in enumerate(atms):
if (atm_j is not None):
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
"ATOM", ctr, atm_j, util.num2aa[s],
"A", i+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2],
1.0, Bfacts[i] ) )
ctr += 1
def get_args():
#DB="/home/robetta/rosetta_server_beta/external/databases/trRosetta/pdb100_2021Mar03/pdb100_2021Mar03"
DB = "/projects/ml/TrRosetta/pdb100_2020Mar11/pdb100_2020Mar11"
import argparse
parser = argparse.ArgumentParser(description="RoseTTAFold: Protein structure prediction with 3-track attentions on 1D, 2D, and 3D features")
parser.add_argument("fasta", help="fasta for structure prediction")
parser.add_argument("-oligo", type=int, default=1)
parser.add_argument("-prefix", type=str, default="pred")
parser.add_argument("-tmpl", default=None)
parser.add_argument("-model_name", default="BFF", required=False,
help="Prefix for model. The model under models/[model_name]_best.pt will be used. [BFF]")
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
pred = Predictor(model_name=args.model_name)
pred.predict(args.fasta, args.prefix, args.tmpl, oligo=args.oligo)