created new folder for all atom

This commit is contained in:
Rohith Krishna
2022-07-11 00:48:54 -07:00
commit 0f7d4f6939
29 changed files with 20568 additions and 0 deletions

3
RF2_allatom/.gitignore vendored Normal file
View File

@@ -0,0 +1,3 @@
valid_remapped
lig_test
dataset.pkl

View File

@@ -0,0 +1,473 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from opt_einsum import contract as einsum
from util_module import init_lecun_normal
class FeedForwardLayer(nn.Module):
def __init__(self, d_model, r_ff, p_drop=0.1):
super(FeedForwardLayer, self).__init__()
self.norm = nn.LayerNorm(d_model)
self.linear1 = nn.Linear(d_model, d_model*r_ff)
self.dropout = nn.Dropout(p_drop)
self.linear2 = nn.Linear(d_model*r_ff, d_model)
self.reset_parameter()
def reset_parameter(self):
# initialize linear layer right before ReLu: He initializer (kaiming normal)
nn.init.kaiming_normal_(self.linear1.weight, nonlinearity='relu')
nn.init.zeros_(self.linear1.bias)
# initialize linear layer right before residual connection: zero initialize
nn.init.zeros_(self.linear2.weight)
nn.init.zeros_(self.linear2.bias)
def forward(self, src):
src = self.norm(src)
src = self.linear2(self.dropout(F.relu_(self.linear1(src))))
return src
class Attention(nn.Module):
# calculate multi-head attention
def __init__(self, d_query, d_key, n_head, d_hidden, d_out, p_drop=0.1):
super(Attention, self).__init__()
self.h = n_head
self.dim = d_hidden
#
self.to_q = nn.Linear(d_query, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_key, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_key, n_head*d_hidden, bias=False)
#
self.to_out = nn.Linear(n_head*d_hidden, d_out)
self.scaling = 1/math.sqrt(d_hidden)
#
# initialize all parameters properly
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, query, key, value):
B, Q = query.shape[:2]
B, K = key.shape[:2]
#
query = self.to_q(query).reshape(B, Q, self.h, self.dim)
key = self.to_k(key).reshape(B, K, self.h, self.dim)
value = self.to_v(value).reshape(B, K, self.h, self.dim)
#
query = query * self.scaling
attn = einsum('bqhd,bkhd->bhqk', query, key)
attn = F.softmax(attn, dim=-1)
#
out = einsum('bhqk,bkhd->bqhd', attn, value)
out = out.reshape(B, Q, self.h*self.dim)
#
out = self.to_out(out)
return out
# MSA Attention (row/column) from AlphaFold architecture
class SequenceWeight(nn.Module):
def __init__(self, d_msa, n_head, d_hidden, p_drop=0.1):
super(SequenceWeight, self).__init__()
self.h = n_head
self.dim = d_hidden
self.scale = 1.0 / math.sqrt(self.dim)
self.to_query = nn.Linear(d_msa, n_head*d_hidden)
self.to_key = nn.Linear(d_msa, n_head*d_hidden)
self.dropout = nn.Dropout(p_drop)
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_query.weight)
nn.init.xavier_uniform_(self.to_key.weight)
def forward(self, msa):
B, N, L = msa.shape[:3]
tar_seq = msa[:,0]
q = self.to_query(tar_seq).view(B, 1, L, self.h, self.dim)
k = self.to_key(msa).view(B, N, L, self.h, self.dim)
q = q * self.scale
attn = einsum('bqihd,bkihd->bkihq', q, k)
attn = F.softmax(attn, dim=1)
return self.dropout(attn)
class MSARowAttentionWithBias(nn.Module):
def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32):
super(MSARowAttentionWithBias, self).__init__()
self.norm_msa = nn.LayerNorm(d_msa)
self.norm_pair = nn.LayerNorm(d_pair)
#
self.seq_weight = SequenceWeight(d_msa, n_head, d_hidden, p_drop=0.1)
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_b = nn.Linear(d_pair, n_head, bias=False)
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# bias: normal distribution
self.to_b = init_lecun_normal(self.to_b)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, msa, pair): # TODO: make this as tied-attention
B, N, L = msa.shape[:3]
#
msa = self.norm_msa(msa)
pair = self.norm_pair(pair)
#
seq_weight = self.seq_weight(msa) # (B, N, L, h, 1)
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
bias = self.to_b(pair) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(msa))
#
query = query * seq_weight.expand(-1, -1, -1, -1, self.dim)
key = key * self.scaling
attn = einsum('bsqhd,bskhd->bqkh', query, key)
attn = attn + bias
attn = F.softmax(attn, dim=-2)
#
out = einsum('bqkh,bskhd->bsqhd', attn, value).reshape(B, N, L, -1)
out = gate * out
#
out = self.to_out(out)
return out
class MSAColAttention(nn.Module):
def __init__(self, d_msa=256, n_head=8, d_hidden=32):
super(MSAColAttention, self).__init__()
self.norm_msa = nn.LayerNorm(d_msa)
#
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, msa):
B, N, L = msa.shape[:3]
#
msa = self.norm_msa(msa)
#
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
gate = torch.sigmoid(self.to_g(msa))
#
query = query * self.scaling
attn = einsum('bqihd,bkihd->bihqk', query, key)
attn = F.softmax(attn, dim=-1)
#
out = einsum('bihqk,bkihd->bqihd', attn, value).reshape(B, N, L, -1)
out = gate * out
#
out = self.to_out(out)
return out
class MSAColGlobalAttention(nn.Module):
def __init__(self, d_msa=64, n_head=8, d_hidden=8):
super(MSAColGlobalAttention, self).__init__()
self.norm_msa = nn.LayerNorm(d_msa)
#
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_msa, d_hidden, bias=False)
self.to_v = nn.Linear(d_msa, d_hidden, bias=False)
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, msa):
B, N, L = msa.shape[:3]
#
msa = self.norm_msa(msa)
#
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
query = query.mean(dim=1) # (B, L, h, dim)
key = self.to_k(msa) # (B, N, L, dim)
value = self.to_v(msa) # (B, N, L, dim)
gate = torch.sigmoid(self.to_g(msa)) # (B, N, L, h*dim)
#
query = query * self.scaling
attn = einsum('bihd,bkid->bihk', query, key) # (B, L, h, N)
attn = F.softmax(attn, dim=-1)
#
out = einsum('bihk,bkid->bihd', attn, value).reshape(B, 1, L, -1) # (B, 1, L, h*dim)
out = gate * out # (B, N, L, h*dim)
#
out = self.to_out(out)
return out
# TriangleAttention & TriangleMultiplication from AlphaFold architecture
class TriangleAttention(nn.Module):
def __init__(self, d_pair, n_head=4, d_hidden=32, p_drop=0.1, start_node=True):
super(TriangleAttention, self).__init__()
self.norm = nn.LayerNorm(d_pair)
self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_b = nn.Linear(d_pair, n_head, bias=False)
self.to_g = nn.Linear(d_pair, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_pair)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.start_node=start_node
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# bias: normal distribution
self.to_b = init_lecun_normal(self.to_b)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, pair):
B, L = pair.shape[:2]
pair = self.norm(pair)
# input projection
query = self.to_q(pair).reshape(B, L, L, self.h, -1)
key = self.to_k(pair).reshape(B, L, L, self.h, -1)
value = self.to_v(pair).reshape(B, L, L, self.h, -1)
bias = self.to_b(pair) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim)
# attention
query = query * self.scaling
if self.start_node:
attn = einsum('bijhd,bikhd->bijkh', query, key)
else:
attn = einsum('bijhd,bkjhd->bijkh', query, key)
attn = attn + bias.unsqueeze(1).expand(-1,L,-1,-1,-1) # (bijkh)
attn = F.softmax(attn, dim=-2)
if self.start_node:
out = einsum('bijkh,bikhd->bijhd', attn, value).reshape(B, L, L, -1)
else:
out = einsum('bijkh,bkjhd->bijhd', attn, value).reshape(B, L, L, -1)
out = gate * out # gated attention
# output projection
out = self.to_out(out)
return out
class TriangleMultiplication(nn.Module):
def __init__(self, d_pair, d_hidden=128, outgoing=True):
super(TriangleMultiplication, self).__init__()
self.norm = nn.LayerNorm(d_pair)
self.left_proj = nn.Linear(d_pair, d_hidden)
self.right_proj = nn.Linear(d_pair, d_hidden)
self.left_gate = nn.Linear(d_pair, d_hidden)
self.right_gate = nn.Linear(d_pair, d_hidden)
#
self.gate = nn.Linear(d_pair, d_pair)
self.norm_out = nn.LayerNorm(d_hidden)
self.out_proj = nn.Linear(d_hidden, d_pair)
self.outgoing = outgoing
self.reset_parameter()
def reset_parameter(self):
# normal distribution for regular linear weights
self.left_proj = init_lecun_normal(self.left_proj)
self.right_proj = init_lecun_normal(self.right_proj)
# Set Bias of Linear layers to zeros
nn.init.zeros_(self.left_proj.bias)
nn.init.zeros_(self.right_proj.bias)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.left_gate.weight)
nn.init.ones_(self.left_gate.bias)
nn.init.zeros_(self.right_gate.weight)
nn.init.ones_(self.right_gate.bias)
nn.init.zeros_(self.gate.weight)
nn.init.ones_(self.gate.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.out_proj.weight)
nn.init.zeros_(self.out_proj.bias)
def forward(self, pair):
B, L = pair.shape[:2]
pair = self.norm(pair)
left = self.left_proj(pair) # (B, L, L, d_h)
left_gate = torch.sigmoid(self.left_gate(pair))
left = left_gate * left
right = self.right_proj(pair) # (B, L, L, d_h)
right_gate = torch.sigmoid(self.right_gate(pair))
right = right_gate * right
if self.outgoing:
out = einsum('bikd,bjkd->bijd', left, right/float(L))
else:
out = einsum('bkid,bkjd->bijd', left, right/float(L))
out = self.norm_out(out)
out = self.out_proj(out)
gate = torch.sigmoid(self.gate(pair)) # (B, L, L, d_pair)
out = gate * out
return out
# Instead of triangle attention, use Tied axail attention with bias from coordinates..?
class BiasedAxialAttention(nn.Module):
def __init__(self, d_pair, d_bias, n_head, d_hidden, p_drop=0.1, is_row=True):
super(BiasedAxialAttention, self).__init__()
#
self.is_row = is_row
self.norm_pair = nn.LayerNorm(d_pair)
self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_b = nn.Linear(d_bias, n_head, bias=False)
self.to_g = nn.Linear(d_pair, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_pair)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
# initialize all parameters properly
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# bias: normal distribution
self.to_b = init_lecun_normal(self.to_b)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, pair, bias):
# pair: (B, L, L, d_pair)
B, L = pair.shape[:2]
if self.is_row:
pair = pair.permute(0,2,1,3)
pair = self.norm_pair(pair)
query = self.to_q(pair).reshape(B, L, L, self.h, self.dim)
key = self.to_k(pair).reshape(B, L, L, self.h, self.dim)
value = self.to_v(pair).reshape(B, L, L, self.h, self.dim)
bias = self.to_b(bias) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim)
query = query * self.scaling
key = key / math.sqrt(L) # normalize for tied attention
attn = einsum('bnihk,bnjhk->bijh', query, key) # tied attention
attn = attn + bias # apply bias
attn = F.softmax(attn, dim=-2) # (B, L, L, h)
out = einsum('bijh,bkjhd->bikhd', attn, value).reshape(B, L, L, -1)
out = gate * out
out = self.to_out(out)
if self.is_row:
out = out.permute(0,2,1,3)
return out

View File

@@ -0,0 +1,75 @@
import torch
import torch.nn as nn
from chemical import NAATOKENS
class DistanceNetwork(nn.Module):
def __init__(self, n_feat, p_drop=0.1):
super(DistanceNetwork, self).__init__()
#
self.proj_symm = nn.Linear(n_feat, 37*2)
self.proj_asymm = nn.Linear(n_feat, 37+19)
self.reset_parameter()
def reset_parameter(self):
# initialize linear layer for final logit prediction
nn.init.zeros_(self.proj_symm.weight)
nn.init.zeros_(self.proj_asymm.weight)
nn.init.zeros_(self.proj_symm.bias)
nn.init.zeros_(self.proj_asymm.bias)
def forward(self, x):
# input: pair info (B, L, L, C)
# predict theta, phi (non-symmetric)
logits_asymm = self.proj_asymm(x)
logits_theta = logits_asymm[:,:,:,:37].permute(0,3,1,2)
logits_phi = logits_asymm[:,:,:,37:].permute(0,3,1,2)
# predict dist, omega
logits_symm = self.proj_symm(x)
logits_symm = logits_symm + logits_symm.permute(0,2,1,3)
logits_dist = logits_symm[:,:,:,:37].permute(0,3,1,2)
logits_omega = logits_symm[:,:,:,37:].permute(0,3,1,2)
return logits_dist, logits_omega, logits_theta, logits_phi
class MaskedTokenNetwork(nn.Module):
def __init__(self, n_feat, p_drop=0.1):
super(MaskedTokenNetwork, self).__init__()
#fd note this predicts probability for the mask token (which is never in ground truth)
# it should be ok though(?)
self.proj = nn.Linear(n_feat, NAATOKENS)
self.reset_parameter()
def reset_parameter(self):
nn.init.zeros_(self.proj.weight)
nn.init.zeros_(self.proj.bias)
def forward(self, x):
B, N, L = x.shape[:3]
logits = self.proj(x).permute(0,3,1,2).reshape(B, -1, N*L)
return logits
class LDDTNetwork(nn.Module):
def __init__(self, n_feat, n_bin_lddt=50):
super(LDDTNetwork, self).__init__()
self.proj = nn.Linear(n_feat, n_bin_lddt)
self.reset_parameter()
def reset_parameter(self):
nn.init.zeros_(self.proj.weight)
nn.init.zeros_(self.proj.bias)
def forward(self, x):
logits = self.proj(x) # (B, L, 50)
return logits.permute(0,2,1)

279
RF2_allatom/Embeddings.py Normal file
View File

@@ -0,0 +1,279 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from opt_einsum import contract as einsum
import torch.utils.checkpoint as checkpoint
from util import *
from util_module import Dropout, get_clones, create_custom_forward, rbf, init_lecun_normal
from Attention_module import Attention, TriangleMultiplication, TriangleAttention, FeedForwardLayer
from Track_module import PairStr2Pair
from chemical import NAATOKENS,NTOTALDOFS, NBTYPES
# Module contains classes and functions to generate initial embeddings
class PositionalEncoding2D(nn.Module):
# Add relative positional encoding to pair features
def __init__(self, d_model, minpos=-32, maxpos=32, p_drop=0.1):
super(PositionalEncoding2D, self).__init__()
self.minpos = minpos
self.maxpos = maxpos
self.nbin = abs(minpos)+maxpos+1
self.emb = nn.Embedding(self.nbin, d_model)
def forward(self, x, idx):
bins = torch.arange(self.minpos, self.maxpos, device=x.device)
seqsep = idx[:,None,:] - idx[:,:,None] # (B, L, L)
#
ib = torch.bucketize(seqsep, bins).long() # (B, L, L)
emb = self.emb(ib) #(B, L, L, d_model)
x = x + emb # add relative positional encoding
return x
class MSA_emb(nn.Module):
# Get initial seed MSA embedding
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=2*NAATOKENS+2+2,
minpos=-32, maxpos=32, p_drop=0.1):
super(MSA_emb, self).__init__()
self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA
self.emb_q = nn.Embedding(NAATOKENS, d_msa) # embedding for query sequence -- used for MSA embedding
self.emb_left = nn.Embedding(NAATOKENS, d_pair) # embedding for query sequence -- used for pair embedding
self.emb_right = nn.Embedding(NAATOKENS, d_pair) # embedding for query sequence -- used for pair embedding
self.emb_state = nn.Embedding(NAATOKENS, d_state)
self.pos = PositionalEncoding2D(d_pair, minpos=minpos, maxpos=maxpos, p_drop=p_drop)
self.reset_parameter()
def reset_parameter(self):
self.emb = init_lecun_normal(self.emb)
self.emb_q = init_lecun_normal(self.emb_q)
self.emb_left = init_lecun_normal(self.emb_left)
self.emb_right = init_lecun_normal(self.emb_right)
self.emb_state = init_lecun_normal(self.emb_state)
nn.init.zeros_(self.emb.bias)
def forward(self, msa, seq, idx):
# Inputs:
# - msa: Input MSA (B, N, L, d_init)
# - seq: Input Sequence (B, L)
# - idx: Residue index
# Outputs:
# - msa: Initial MSA embedding (B, N, L, d_msa)
# - pair: Initial Pair embedding (B, L, L, d_pair)
N = msa.shape[1] # number of sequenes in MSA
# msa embedding
msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding
tmp = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
msa = msa + tmp.expand(-1, N, -1, -1) # adding query embedding to MSA
#msa = self.drop(msa)
# pair embedding
left = self.emb_left(seq)[:,None] # (B, 1, L, d_pair)
right = self.emb_right(seq)[:,:,None] # (B, L, 1, d_pair)
pair = left + right # (B, L, L, d_pair)
pair = self.pos(pair, idx) # add relative position
# state embedding
state = self.emb_state(seq)
return msa, pair, state
class Extra_emb(nn.Module):
# Get initial seed MSA embedding
def __init__(self, d_msa=256, d_init=NAATOKENS+1+2, p_drop=0.1):
super(Extra_emb, self).__init__()
self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA
self.emb_q = nn.Embedding(NAATOKENS, d_msa) # embedding for query sequence
#self.drop = nn.Dropout(p_drop)
self.reset_parameter()
def reset_parameter(self):
self.emb = init_lecun_normal(self.emb)
nn.init.zeros_(self.emb.bias)
def forward(self, msa, seq, idx):
# Inputs:
# - msa: Input MSA (B, N, L, d_init)
# - seq: Input Sequence (B, L)
# - idx: Residue index
# Outputs:
# - msa: Initial MSA embedding (B, N, L, d_msa)
N = msa.shape[1] # number of sequenes in MSA
msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding
seq = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
msa = msa + seq.expand(-1, N, -1, -1) # adding query embedding to MSA
#return self.drop(msa)
return (msa)
class Bond_emb(nn.Module):
def __init__(self, d_pair=128, d_init=NBTYPES):
super(Bond_emb, self).__init__()
self.emb = nn.Linear(d_init, d_pair)
self.reset_parameter()
def reset_parameter(self):
self.emb = init_lecun_normal(self.emb)
nn.init.zeros_(self.emb.bias)
def forward(self, bond_feats):
return self.emb(bond_feats.float())
# TODO: Update template embedding not to use triangles....
# Use input xyz_t with biased attention
class TemplatePairStack(nn.Module):
def __init__(self, n_block=2, d_templ=64, n_head=4, d_hidden=32, rbf_sigma=1.0, p_drop=0.25):
super(TemplatePairStack, self).__init__()
self.n_block = n_block
self.rbf_sigma = rbf_sigma
proc_s = [PairStr2Pair(d_pair=d_templ, n_head=n_head, d_hidden=d_hidden, p_drop=p_drop) for i in range(n_block)]
self.block = nn.ModuleList(proc_s)
self.norm = nn.LayerNorm(d_templ)
def forward(self, templ, xyz_t, use_checkpoint=False):
B, T, L = templ.shape[:3]
templ = templ.reshape(B*T, L, L, -1)
xyz_t = xyz_t.reshape(B*T, L, -1, 3)
rbf_feat = rbf(torch.cdist(xyz_t[:,:,1], xyz_t[:,:,1]), self.rbf_sigma)
for i_block in range(self.n_block):
if use_checkpoint:
templ = checkpoint.checkpoint(create_custom_forward(self.block[i_block]), templ, rbf_feat)
else:
templ = self.block[i_block](templ, rbf_feat)
return self.norm(templ).reshape(B, T, L, L, -1)
class Templ_emb(nn.Module):
# Get template embedding
# Features are
# t2d:
# - 37 distogram bins + 6 orientations (43)
# - Mask (missing/unaligned) (1)
# t1d:
# - tiled AA sequence (20 standard aa + gap)
# - confidence (1)
#
def __init__(self, d_t1d=(NAATOKENS-1)+1, d_t2d=43+1, d_tor=3*NTOTALDOFS, d_pair=128, d_state=32,
n_block=2, d_templ=64,
n_head=4, d_hidden=16, p_drop=0.25):
super(Templ_emb, self).__init__()
# process 2D features
self.emb = nn.Linear(d_t1d*2+d_t2d, d_templ)
self.templ_stack = TemplatePairStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
d_hidden=d_hidden, p_drop=p_drop)
self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair, p_drop=p_drop)
# process torsion angles
self.emb_t1d = nn.Linear(d_t1d+d_tor, d_templ)
self.proj_t1d = nn.Linear(d_templ, d_templ)
#self.tor_stack = TemplateTorsionStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
# d_hidden=d_hidden, p_drop=p_drop)
self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state, p_drop=p_drop)
self.reset_parameter()
def reset_parameter(self):
self.emb = init_lecun_normal(self.emb)
nn.init.zeros_(self.emb.bias)
nn.init.kaiming_normal_(self.emb_t1d.weight, nonlinearity='relu')
nn.init.zeros_(self.emb_t1d.bias)
self.proj_t1d = init_lecun_normal(self.proj_t1d)
nn.init.zeros_(self.proj_t1d.bias)
def forward(self, t1d, t2d, alpha_t, xyz_t, pair, state, use_checkpoint=False):
# Input
# - t1d: 1D template info (B, T, L, 30)
# - t2d: 2D template info (B, T, L, L, 44)
B, T, L, _ = t1d.shape
# Prepare 2D template features
left = t1d.unsqueeze(3).expand(-1,-1,-1,L,-1)
right = t1d.unsqueeze(2).expand(-1,-1,L,-1,-1)
#
templ = torch.cat((t2d, left, right), -1) # (B, T, L, L, 88)
templ = self.emb(templ) # Template templures (B, T, L, L, d_templ)
# process each template features
xyz_t = xyz_t.reshape(B*T, L, -1, 3)
templ = self.templ_stack(templ, xyz_t, use_checkpoint=use_checkpoint) # (B, T, L,L, d_templ)
# Prepare 1D template torsion angle features
t1d = torch.cat((t1d, alpha_t), dim=-1) # (B, T, L, 30+3*17)
# process each template features
t1d = self.proj_t1d(F.relu_(self.emb_t1d(t1d)))
# mixing query state features to template state features
state = state.reshape(B*L, 1, -1)
t1d = t1d.permute(0,2,1,3).reshape(B*L, T, -1)
if use_checkpoint:
out = checkpoint.checkpoint(create_custom_forward(self.attn_tor), state, t1d, t1d)
out = out.reshape(B, L, -1)
else:
out = self.attn_tor(state, t1d, t1d).reshape(B, L, -1)
state = state.reshape(B, L, -1)
state = state + out
# mixing query pair features to template information (Template pointwise attention)
pair = pair.reshape(B*L*L, 1, -1)
templ = templ.permute(0, 2, 3, 1, 4).reshape(B*L*L, T, -1)
if use_checkpoint:
out = checkpoint.checkpoint(create_custom_forward(self.attn), pair, templ, templ)
out = out.reshape(B, L, L, -1)
else:
out = self.attn(pair, templ, templ).reshape(B, L, L, -1)
#
pair = pair.reshape(B, L, L, -1)
pair = pair + out
return pair, state
class Recycling(nn.Module):
def __init__(self, d_msa=256, d_pair=128, d_state=32, rbf_sigma=1.0):
super(Recycling, self).__init__()
self.proj_dist = nn.Linear(36+d_state*2, d_pair)
self.norm_pair = nn.LayerNorm(d_pair)
self.proj_sctors = nn.Linear(2*NTOTALDOFS, d_msa)
self.norm_msa = nn.LayerNorm(d_msa)
self.rbf_sigma = rbf_sigma
self.norm_state = nn.LayerNorm(d_state)
self.reset_parameter()
def reset_parameter(self):
self.proj_dist = init_lecun_normal(self.proj_dist)
nn.init.zeros_(self.proj_dist.bias)
self.proj_sctors = init_lecun_normal(self.proj_sctors)
nn.init.zeros_(self.proj_sctors.bias)
def forward(self, msa, pair, xyz, state, sctors):
B, L = pair.shape[:2]
state = self.norm_state(state)
left = state.unsqueeze(2).expand(-1,-1,L,-1)
right = state.unsqueeze(1).expand(-1,L,-1,-1)
Ca_or_P = xyz[:,:,1]
# recreate Cb given N,Ca,C
#N = xyz[:,:,0]
#C = xyz[:,:,2]
#Cb = generate_Cbeta(N,Ca,C)
#dist = rbf(torch.cdist(Cb, Cb), self.rbf_sigma)
dist = rbf(torch.cdist(Ca_or_P, Ca_or_P), self.rbf_sigma)
dist = torch.cat((dist, left, right), dim=-1)
dist = self.proj_dist(dist)
pair = dist + self.norm_pair(pair)
sctors = self.proj_sctors(sctors.reshape(B,-1,2*NTOTALDOFS))
msa = sctors + self.norm_msa(msa)
return msa, pair, state

View File

@@ -0,0 +1,111 @@
import torch
import torch.nn as nn
from Embeddings import MSA_emb, Extra_emb, Bond_emb, Templ_emb, Recycling
from Track_module import IterativeSimulator
from AuxiliaryPredictor import DistanceNetwork, MaskedTokenNetwork, LDDTNetwork
from chemical import INIT_CRDS,NAATOKENS, NBTYPES
class RoseTTAFoldModule(nn.Module):
def __init__(
self, n_extra_block=4, n_main_block=8, n_ref_block=4, n_finetune_block=0,\
d_msa=256, d_msa_full=64, d_pair=128, d_templ=64,
n_head_msa=8, n_head_pair=4, n_head_templ=4,
d_hidden=32, d_hidden_templ=64,
rbf_sigma=1.0, p_drop=0.15,
SE3_param={}, SE3_ref_param={},
atom_type_index=None, aamask=None, ljlk_parameters=None, lj_correction_parameters=None,
cb_len=None, cb_ang=None, cb_tor=None,
num_bonds=None, lj_lin=0.6
):
super(RoseTTAFoldModule, self).__init__()
#
# Input Embeddings
d_state = SE3_param['l0_out_features']
self.latent_emb = MSA_emb(d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop)
self.full_emb = Extra_emb(d_msa=d_msa_full, d_init=NAATOKENS-1+4, p_drop=p_drop)
self.bond_emb = Bond_emb(d_pair=d_pair, d_init=NBTYPES)
self.templ_emb = Templ_emb(d_pair=d_pair, d_templ=d_templ, d_state=d_state, n_head=n_head_templ,
d_hidden=d_hidden_templ, p_drop=0.25)
# Update inputs with outputs from previous round
self.recycle = Recycling(d_msa=d_msa, d_pair=d_pair, d_state=d_state, rbf_sigma=rbf_sigma)
#
self.simulator = IterativeSimulator(
n_extra_block=n_extra_block,
n_main_block=n_main_block,
n_ref_block=n_ref_block,
n_finetune_block=n_finetune_block,
d_msa=d_msa,
d_msa_full=d_msa_full,
d_pair=d_pair,
d_hidden=d_hidden,
n_head_msa=n_head_msa,
n_head_pair=n_head_pair,
SE3_param=SE3_param,
SE3_ref_param=SE3_ref_param,
rbf_sigma=rbf_sigma,
p_drop=p_drop,
atom_type_index=atom_type_index, # change if encoding elements instead of atomtype
aamask=aamask,
ljlk_parameters=ljlk_parameters,
lj_correction_parameters=lj_correction_parameters,
num_bonds=num_bonds,
cb_len=cb_len,
cb_ang=cb_ang,
cb_tor=cb_tor,
lj_lin=lj_lin
)
##
self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
self.lddt_pred = LDDTNetwork(d_state)
def forward(
self, msa_latent, msa_full, seq, seq_unmasked, xyz, sctors, idx, bond_feats,
t1d=None, t2d=None, xyz_t=None, alpha_t=None,
msa_prev=None, pair_prev=None, state_prev=None,
return_raw=False, return_full=False,
use_checkpoint=False
):
B, N, L = msa_latent.shape[:3]
# Get embeddings
msa_latent, pair, state = self.latent_emb(msa_latent, seq, idx)
msa_full = self.full_emb(msa_full, seq, idx)
pair = pair + self.bond_emb(bond_feats)
#
# Do recycling
if msa_prev == None:
msa_prev = torch.zeros_like(msa_latent[:,0])
pair_prev = torch.zeros_like(pair)
state_prev = torch.zeros_like(state)
msa_recycle, pair_recycle, state_recycle = self.recycle(msa_prev, pair_prev, xyz, state_prev, sctors)
msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
pair = pair + pair_recycle
state = state + state_recycle
# add template embedding
pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, pair, state, use_checkpoint=use_checkpoint)
# Predict coordinates from given inputs
msa, pair, xyz, alpha_s, xyz_allatom, state = self.simulator(
seq_unmasked, msa_latent, msa_full, pair, xyz[:,:,:3], state, idx, use_checkpoint=use_checkpoint)
if return_raw:
# get last structure
xyz_last = xyz_allatom[-1].unsqueeze(0)
return msa[:,0], pair, xyz_last, state, alpha_s[-1]
# predict masked amino acids
logits_aa = self.aa_pred(msa)
# predict distogram & orientograms
logits = self.c6d_pred(pair)
# Predict LDDT
lddt = self.lddt_pred(state)
return logits, logits_aa, xyz, alpha_s, xyz_allatom, lddt, msa[:,0], pair, state

View File

@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
#from equivariant_attention.modules import get_basis_and_r, GSE3Res, GNormBias
#from equivariant_attention.modules import GConvSE3, GNormSE3
#from equivariant_attention.fibers import Fiber
from util_module import init_lecun_normal_param
from se3_transformer.model import SE3Transformer
from se3_transformer.model.fiber import Fiber
class SE3TransformerWrapper(nn.Module):
"""SE(3) equivariant GCN with attention"""
def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
l0_in_features=32, l0_out_features=32,
l1_in_features=3, l1_out_features=2,
num_edge_features=32):
super().__init__()
# Build the network
self.l1_in = l1_in_features
#
fiber_edge = Fiber({0: num_edge_features})
if l1_out_features > 0:
if l1_in_features > 0:
fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
else:
fiber_in = Fiber({0: l0_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
else:
if l1_in_features > 0:
fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features})
else:
fiber_in = Fiber({0: l0_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features})
self.se3 = SE3Transformer(num_layers=num_layers,
fiber_in=fiber_in,
fiber_hidden=fiber_hidden,
fiber_out = fiber_out,
num_heads=n_heads,
channels_div=div,
fiber_edge=fiber_edge,
use_layer_norm=True)
#use_layer_norm=False)
self.reset_parameter()
def reset_parameter(self):
# make sure linear layer before ReLu are initialized with kaiming_normal_
for n, p in self.se3.named_parameters():
if "bias" in n:
nn.init.zeros_(p)
elif len(p.shape) == 1:
continue
else:
if "radial_func" not in n:
p = init_lecun_normal_param(p)
else:
if "net.6" in n:
nn.init.zeros_(p)
else:
nn.init.kaiming_normal_(p, nonlinearity='relu')
# make last layers to be zero-initialized
#self.se3.graph_modules[-1].to_kernel_self['0'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['0'])
#self.se3.graph_modules[-1].to_kernel_self['1'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['1'])
nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['0'])
nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['1'])
def forward(self, G, type_0_features, type_1_features=None, edge_features=None):
if self.l1_in > 0:
node_features = {'0': type_0_features, '1': type_1_features}
else:
node_features = {'0': type_0_features}
edge_features = {'0': edge_features}
return self.se3(G, node_features, edge_features)

693
RF2_allatom/Track_module.py Normal file
View File

@@ -0,0 +1,693 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from opt_einsum import contract as einsum
import torch.utils.checkpoint as checkpoint
from util_module import *
from Attention_module import *
from SE3_network import SE3TransformerWrapper
from resnet import ResidualNetwork
from util import INIT_CRDS, is_atom
from loss import (
calc_BB_bond_geom_grads, calc_lj_grads, calc_hb_grads, calc_cart_bonded_grads, calc_ljallatom_grads,
calc_lj, calc_cart_bonded
)
from chemical import NTOTALDOFS
# Components for three-track blocks
# 1. MSA -> MSA update (biased attention. bias from pair & structure)
# 2. Pair -> Pair update (biased attention. bias from structure)
# 3. MSA -> Pair update (extract coevolution signal)
# 4. Str -> Str update (node from MSA, edge from Pair)
# Update MSA with biased self-attention. bias from Pair & Str
class MSAPairStr2MSA(nn.Module):
def __init__(self, d_msa=256, d_pair=128, n_head=8, d_state=16,
d_hidden=32, p_drop=0.15, use_global_attn=False):
super(MSAPairStr2MSA, self).__init__()
self.norm_pair = nn.LayerNorm(d_pair)
self.proj_pair = nn.Linear(d_pair+36, d_pair)
self.norm_state = nn.LayerNorm(d_state)
self.proj_state = nn.Linear(d_state, d_msa)
self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
self.row_attn = MSARowAttentionWithBias(d_msa=d_msa, d_pair=d_pair,
n_head=n_head, d_hidden=d_hidden)
if use_global_attn:
self.col_attn = MSAColGlobalAttention(d_msa=d_msa, n_head=n_head, d_hidden=d_hidden)
else:
self.col_attn = MSAColAttention(d_msa=d_msa, n_head=n_head, d_hidden=d_hidden)
self.ff = FeedForwardLayer(d_msa, 4, p_drop=p_drop)
# Do proper initialization
self.reset_parameter()
def reset_parameter(self):
# initialize weights to normal distrib
self.proj_pair = init_lecun_normal(self.proj_pair)
self.proj_state = init_lecun_normal(self.proj_state)
# initialize bias to zeros
nn.init.zeros_(self.proj_pair.bias)
nn.init.zeros_(self.proj_state.bias)
def forward(self, msa, pair, rbf_feat, state):
'''
Inputs:
- msa: MSA feature (B, N, L, d_msa)
- pair: Pair feature (B, L, L, d_pair)
- rbf_feat: Ca-Ca distance feature calculated from xyz coordinates (B, L, L, 36)
- xyz: xyz coordinates (B, L, n_atom, 3)
- state: updated node features after SE(3)-Transformer layer (B, L, d_state)
Output:
- msa: Updated MSA feature (B, N, L, d_msa)
'''
B, N, L = msa.shape[:3]
# prepare input bias feature by combining pair & coordinate info
pair = self.norm_pair(pair)
pair = torch.cat((pair, rbf_feat), dim=-1)
pair = self.proj_pair(pair) # (B, L, L, d_pair)
#
# update query sequence feature (first sequence in the MSA) with feedbacks (state) from SE3
state = self.norm_state(state)
state = self.proj_state(state).reshape(B, 1, L, -1)
msa = msa.index_add(1, torch.tensor([0,], device=state.device), state.float())
#
# Apply row/column attention to msa & transform
msa = msa + self.drop_row(self.row_attn(msa, pair))
msa = msa + self.col_attn(msa)
msa = msa + self.ff(msa)
return msa
class PairStr2Pair(nn.Module):
def __init__(self, d_pair=128, n_head=4, d_hidden=32, d_rbf=36, p_drop=0.15):
super(PairStr2Pair, self).__init__()
self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
self.drop_col = Dropout(broadcast_dim=2, p_drop=p_drop)
self.row_attn = BiasedAxialAttention(d_pair, d_rbf, n_head, d_hidden, p_drop=p_drop, is_row=True)
self.col_attn = BiasedAxialAttention(d_pair, d_rbf, n_head, d_hidden, p_drop=p_drop, is_row=False)
self.ff = FeedForwardLayer(d_pair, 2)
def forward(self, pair, rbf_feat):
pair = pair + self.drop_row(self.row_attn(pair, rbf_feat))
pair = pair + self.drop_col(self.col_attn(pair, rbf_feat))
pair = pair + self.ff(pair)
return pair
class MSA2Pair(nn.Module):
def __init__(self, d_msa=256, d_pair=128, d_hidden=16, p_drop=0.15):
super(MSA2Pair, self).__init__()
self.norm = nn.LayerNorm(d_msa)
self.proj_left = nn.Linear(d_msa, d_hidden)
self.proj_right = nn.Linear(d_msa, d_hidden)
self.proj_out = nn.Linear(d_hidden*d_hidden, d_pair)
#self.proj_down = nn.Linear(d_pair*2, d_pair)
#self.update = ResidualNetwork(1, d_pair, d_pair, d_pair, p_drop=p_drop)
self.reset_parameter()
def reset_parameter(self):
# normal initialization
self.proj_left = init_lecun_normal(self.proj_left)
self.proj_right = init_lecun_normal(self.proj_right)
self.proj_out = init_lecun_normal(self.proj_out)
nn.init.zeros_(self.proj_left.bias)
nn.init.zeros_(self.proj_right.bias)
nn.init.zeros_(self.proj_out.bias)
# Identity initialization for proj_down
#nn.init.eye_(self.proj_down.weight)
#nn.init.zeros_(self.proj_down.bias)
def forward(self, msa, pair):
B, N, L = msa.shape[:3]
msa = self.norm(msa)
left = self.proj_left(msa)
right = self.proj_right(msa)
right = right / float(N)
out = einsum('bsli,bsmj->blmij', left, right).reshape(B, L, L, -1)
out = self.proj_out(out)
#pair = torch.cat((pair, out), dim=-1) # (B, L, L, d_pair*2)
#pair = self.proj_down(pair)
#pair = self.update(pair.permute(0,3,1,2).contiguous())
#pair = pair.permute(0,2,3,1).contiguous()
pair = pair + out
return pair
class Str2Str(nn.Module):
def __init__(self, d_msa=256, d_pair=128, d_state=16,
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
nextra_l0=0, nextra_l1=0,
rbf_sigma=1.0, p_drop=0.1
):
super(Str2Str, self).__init__()
# initial node & pair feature process
self.norm_msa = nn.LayerNorm(d_msa)
self.norm_pair = nn.LayerNorm(d_pair)
self.norm_state = nn.LayerNorm(d_state)
self.embed_x = nn.Linear(d_msa+d_state, SE3_param['l0_in_features'])
self.embed_e1 = nn.Linear(d_pair, SE3_param['num_edge_features'])
self.embed_e2 = nn.Linear(SE3_param['num_edge_features']+36+1, SE3_param['num_edge_features'])
self.norm_node = nn.LayerNorm(SE3_param['l0_in_features'])
self.norm_edge1 = nn.LayerNorm(SE3_param['num_edge_features'])
self.norm_edge2 = nn.LayerNorm(SE3_param['num_edge_features'])
SE3_param_temp = SE3_param.copy()
SE3_param_temp['l0_in_features'] += nextra_l0
SE3_param_temp['l1_in_features'] += nextra_l1
self.se3 = SE3TransformerWrapper(**SE3_param_temp)
self.rbf_sigma = rbf_sigma
self.sc_predictor = SCPred(
d_msa=d_msa,
d_state=SE3_param['l0_out_features'],
p_drop=p_drop)
#self.nextra_l0 = nextra_l0
#self.nextra_l1 = nextra_l1
self.reset_parameter()
def reset_parameter(self):
# initialize weights to normal distribution
self.embed_x = init_lecun_normal(self.embed_x)
self.embed_e1 = init_lecun_normal(self.embed_e1)
self.embed_e2 = init_lecun_normal(self.embed_e2)
# initialize bias to zeros
nn.init.zeros_(self.embed_x.bias)
nn.init.zeros_(self.embed_e1.bias)
nn.init.zeros_(self.embed_e2.bias)
@torch.cuda.amp.autocast(enabled=False)
def forward(self, msa, pair, xyz, state, idx, rotation_mask, extra_l0=None, extra_l1=None, top_k=128, eps=1e-5):
# process msa & pair features
B, N, L = msa.shape[:3]
node = self.norm_msa(msa[:,0])
pair = self.norm_pair(pair)
state = self.norm_state(state)
node = torch.cat((node, state), dim=-1)
node = self.norm_node(self.embed_x(node))
pair = self.norm_edge1(self.embed_e1(pair))
neighbor = get_seqsep(idx)
cas = xyz[:,:,1].contiguous()
rbf_feat = rbf(torch.cdist(cas, cas), self.rbf_sigma)
pair = torch.cat((pair, rbf_feat, neighbor), dim=-1)
pair = self.norm_edge2(self.embed_e2(pair))
# define graph
if top_k != 0:
G, edge_feats = make_topk_graph(xyz[:,:,1,:], pair, idx, top_k=top_k)
else:
G, edge_feats = make_full_graph(xyz[:,:,1,:], pair, idx)
l1_feats = xyz - xyz[:,:,1,:].unsqueeze(2)
l1_feats = l1_feats.reshape(B*L, -1, 3)
if extra_l1 is not None:
l1_feats = torch.cat( (l1_feats,extra_l1), dim=1 )
if extra_l0 is not None:
node = torch.cat( (node,extra_l0), dim=2 )
# apply SE(3) Transformer & update coordinates
shift = self.se3(G, node.reshape(B*L, -1, 1), l1_feats, edge_feats)
state = shift['0'].reshape(B, L, -1) # (B, L, C)
offset = shift['1'].reshape(B, L, 2, 3)
T = offset[:,:,0,:] / 10.0
R = offset[:,:,1,:] / 100.0
Qnorm = torch.sqrt( 1 + torch.sum(R*R, dim=-1) )
qA, qB, qC, qD = 1/Qnorm, R[:,:,0]/Qnorm, R[:,:,1]/Qnorm, R[:,:,2]/Qnorm
v = xyz - xyz[:,:,1:2,:]
Rout = torch.zeros((B,L,3,3), device=xyz.device)
Rout[:,:,0,0] = qA*qA+qB*qB-qC*qC-qD*qD
Rout[:,:,0,1] = 2*qB*qC - 2*qA*qD
Rout[:,:,0,2] = 2*qB*qD + 2*qA*qC
Rout[:,:,1,0] = 2*qB*qC + 2*qA*qD
Rout[:,:,1,1] = qA*qA-qB*qB+qC*qC-qD*qD
Rout[:,:,1,2] = 2*qC*qD - 2*qA*qB
Rout[:,:,2,0] = 2*qB*qD - 2*qA*qC
Rout[:,:,2,1] = 2*qC*qD + 2*qA*qB
Rout[:,:,2,2] = qA*qA-qB*qB-qC*qC+qD*qD
I = torch.eye(3, device=Rout.device).expand(B,L,3,3)
Rout = torch.where(rotation_mask.reshape(B, L, 1,1), I, Rout)
xyz = torch.einsum('blij,blaj->blai', Rout,v)+xyz[:,:,1:2,:]+T[:,:,None,:]
alpha = self.sc_predictor(msa[:,0], state)
return xyz, state, alpha
class Allatom2Allatom(nn.Module):
def __init__(
self,
SE3_param
):
super(Allatom2Allatom, self).__init__()
self.se3 = SE3TransformerWrapper(**SE3_param)
@torch.cuda.amp.autocast(enabled=False)
def forward(self, seq, xyz, aamask, num_bonds, state, grads, top_k=24, eps=1e-5):
# seq (B,L)
# xyz (B,L,27,3)
# aamask (22,27) [per-amino-acid]
# num_bonds (22,27,27) [per-amino-acid]
# state (N,B,L,K) [K channels]
# grads (N,B,L,27,3) [N terms]
B, L = xyz.shape[:2]
mask = aamask[seq]
G, edge = make_atom_graph( xyz, mask, num_bonds[seq], top_k, maxbonds=4 )
node = state[mask]
node_l1 = grads[:,mask].permute(1,0,2)
# apply SE(3) Transformer & update coordinates
shift = self.se3(G, node[...,None], node_l1, edge)
state[mask] = shift['0'][...,0]
xyz[mask] = xyz[mask] + shift['1'].squeeze(1) / 100.0
return xyz, state
class AllatomEmbed(nn.Module):
def __init__(
self,
d_state_in=64,
d_state_out=32,
p_mask=0.15
):
super(AllatomEmbed, self).__init__()
self.p_mask = p_mask
# initial node & pair feature process
self.compress_embed = nn.Linear(d_state_in + 29, d_state_out) # 29->5 if using element
self.norm_state = nn.LayerNorm(d_state_out)
self.reset_parameter()
def reset_parameter(self):
# initialize weights to normal distribution
self.compress_embed = init_lecun_normal(self.compress_embed)
# initialize bias to zeros
nn.init.zeros_(self.compress_embed.bias)
def forward(self, state, seq, eltmap):
B,L = state.shape[:2]
mask = torch.rand(B,L) < self.p_mask
state = state.reshape(B,L,1,-1).repeat(1,1,27,1)
state[mask] = 0.0
elements = F.one_hot(eltmap[seq], num_classes=29) # 29->5 if using element
state = self.compress_embed(
torch.cat( (state,elements), dim=-1 )
)
state = self.norm_state( state )
return state
# embed residue state + atomtype -> per-atom state
#
class AllatomEmbed(nn.Module):
def __init__(
self,
d_state_in=64,
d_state_out=32,
p_mask=0.15
):
super(AllatomEmbed, self).__init__()
self.p_mask = p_mask
# initial node & pair feature process
self.compress_embed = nn.Linear(d_state_in + 29, d_state_out) # 29->5 if using element
self.norm_state = nn.LayerNorm(d_state_out)
self.reset_parameter()
def reset_parameter(self):
# initialize weights to normal distribution
self.compress_embed = init_lecun_normal(self.compress_embed)
# initialize bias to zeros
nn.init.zeros_(self.compress_embed.bias)
def forward(self, state, seq, eltmap):
B,L = state.shape[:2]
mask = torch.rand(B,L) < self.p_mask
state = state.reshape(B,L,1,-1).repeat(1,1,27,1)
state[mask] = 0.0
elements = F.one_hot(eltmap[seq], num_classes=29) # 29->5 if using element
state = self.compress_embed(
torch.cat( (state,elements), dim=-1 )
)
state = self.norm_state( state )
return state
# embed per-atom state -> residue state
class ResidueEmbed(nn.Module):
def __init__(
self,
d_state_in=16,
d_state_out=64
):
super(ResidueEmbed, self).__init__()
self.compress_embed = nn.Linear(27*d_state_in, d_state_out)
self.norm_state = nn.LayerNorm(d_state_out)
self.reset_parameter()
def reset_parameter(self):
# initialize weights to normal distribution
self.compress_embed = init_lecun_normal(self.compress_embed)
# initialize bias to zeros
nn.init.zeros_(self.compress_embed.bias)
def forward(self, state):
B,L = state.shape[:2]
state = self.compress_embed( state.reshape(B,L,-1) )
state = self.norm_state( state )
return state
class SCPred(nn.Module):
def __init__(self, d_msa=256, d_state=32, d_hidden=128, p_drop=0.15):
super(SCPred, self).__init__()
self.norm_s0 = nn.LayerNorm(d_msa)
self.norm_si = nn.LayerNorm(d_state)
self.linear_s0 = nn.Linear(d_msa, d_hidden)
self.linear_si = nn.Linear(d_state, d_hidden)
# ResNet layers
self.linear_1 = nn.Linear(d_hidden, d_hidden)
self.linear_2 = nn.Linear(d_hidden, d_hidden)
self.linear_3 = nn.Linear(d_hidden, d_hidden)
self.linear_4 = nn.Linear(d_hidden, d_hidden)
# Final outputs
self.linear_out = nn.Linear(d_hidden, 2*NTOTALDOFS)
self.reset_parameter()
def reset_parameter(self):
# normal initialization
self.linear_s0 = init_lecun_normal(self.linear_s0)
self.linear_si = init_lecun_normal(self.linear_si)
self.linear_out = init_lecun_normal(self.linear_out)
nn.init.zeros_(self.linear_s0.bias)
nn.init.zeros_(self.linear_si.bias)
nn.init.zeros_(self.linear_out.bias)
# right before relu activation: He initializer (kaiming normal)
nn.init.kaiming_normal_(self.linear_1.weight, nonlinearity='relu')
nn.init.zeros_(self.linear_1.bias)
nn.init.kaiming_normal_(self.linear_3.weight, nonlinearity='relu')
nn.init.zeros_(self.linear_3.bias)
# right before residual connection: zero initialize
nn.init.zeros_(self.linear_2.weight)
nn.init.zeros_(self.linear_2.bias)
nn.init.zeros_(self.linear_4.weight)
nn.init.zeros_(self.linear_4.bias)
def forward(self, seq, state):
'''
Predict side-chain torsion angles along with backbone torsions
Inputs:
- seq: hidden embeddings corresponding to query sequence (B, L, d_msa)
- state: state feature (output l0 feature) from previous SE3 layer (B, L, d_state)
Outputs:
- si: predicted torsion/pseudotorsion angles (phi, psi, omega, chi1~4 with cos/sin, theta) (B, L, NTOTALDOFS, 2)
'''
B, L = seq.shape[:2]
seq = self.norm_s0(seq)
state = self.norm_si(state)
si = self.linear_s0(seq) + self.linear_si(state)
si = si + self.linear_2(F.relu_(self.linear_1(F.relu_(si))))
si = si + self.linear_4(F.relu_(self.linear_3(F.relu_(si))))
si = self.linear_out(F.relu_(si))
return si.view(B, L, NTOTALDOFS, 2)
class IterBlock(nn.Module):
def __init__(self, d_msa=256, d_pair=128,
n_head_msa=8, n_head_pair=4,
use_global_attn=False,
d_hidden=32, d_hidden_msa=None, rbf_sigma=1.0, p_drop=0.15,
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}):
super(IterBlock, self).__init__()
if d_hidden_msa == None:
d_hidden_msa = d_hidden
self.msa2msa = MSAPairStr2MSA(d_msa=d_msa, d_pair=d_pair,
n_head=n_head_msa,
d_state=SE3_param['l0_out_features'],
use_global_attn=use_global_attn,
d_hidden=d_hidden_msa, p_drop=p_drop)
self.msa2pair = MSA2Pair(d_msa=d_msa, d_pair=d_pair,
d_hidden=16, p_drop=p_drop) # fd - use only 16 channels
self.pair2pair = PairStr2Pair(d_pair=d_pair, n_head=n_head_pair,
d_hidden=d_hidden, p_drop=p_drop)
self.str2str = Str2Str(d_msa=d_msa, d_pair=d_pair,
d_state=SE3_param['l0_out_features'],
SE3_param=SE3_param,
rbf_sigma=rbf_sigma,
p_drop=p_drop)
self.rbf_sigma = rbf_sigma
def forward(self, msa, pair, xyz, state, idx, use_checkpoint=False, top_k=128, rotation_mask=None):
cas = xyz[:,:,1].contiguous()
rbf_feat = rbf(torch.cdist(cas, cas), self.rbf_sigma)
if use_checkpoint:
msa = checkpoint.checkpoint(create_custom_forward(self.msa2msa), msa, pair, rbf_feat, state)
pair = checkpoint.checkpoint(create_custom_forward(self.msa2pair), msa, pair)
pair = checkpoint.checkpoint(create_custom_forward(self.pair2pair), pair, rbf_feat)
xyz, state, alpha = checkpoint.checkpoint(create_custom_forward(self.str2str, top_k=top_k),
msa.float(), pair.float(), xyz.detach().float(), state.float(), idx, rotation_mask)
else:
msa = self.msa2msa(msa, pair, rbf_feat, state)
pair = self.msa2pair(msa, pair)
pair = self.pair2pair(pair, rbf_feat)
xyz, state, alpha = self.str2str(msa.float(), pair.float(), xyz.detach().float(), state.float(), idx, rotation_mask, top_k=top_k)
return msa, pair, xyz, state, alpha
class IterativeSimulator(nn.Module):
def __init__(self, n_extra_block=4, n_main_block=12, n_ref_block=4, n_finetune_block=0,
d_msa=256, d_msa_full=64, d_pair=128, d_hidden=32,
n_head_msa=8, n_head_pair=4,
SE3_param={}, SE3_ref_param={},
rbf_sigma=1.0, p_drop=0.15,
atom_type_index=None, aamask=None,
ljlk_parameters=None, lj_correction_parameters=None,
cb_len=None, cb_ang=None, cb_tor=None,
num_bonds=None, lj_lin=0.6
):
super(IterativeSimulator, self).__init__()
self.n_extra_block = n_extra_block
self.n_main_block = n_main_block
self.n_ref_block = n_ref_block
self.n_finetune_block = n_finetune_block
self.atom_type_index = atom_type_index
self.aamask = aamask
self.ljlk_parameters = ljlk_parameters
self.lj_correction_parameters = lj_correction_parameters
self.num_bonds = num_bonds
self.lj_lin = lj_lin
self.cb_len = cb_len
self.cb_ang = cb_ang
self.cb_tor = cb_tor
# Update with extra sequences
if n_extra_block > 0:
self.extra_block = nn.ModuleList([IterBlock(d_msa=d_msa_full, d_pair=d_pair,
n_head_msa=n_head_msa,
n_head_pair=n_head_pair,
d_hidden_msa=8,
d_hidden=d_hidden,
p_drop=p_drop,
rbf_sigma=rbf_sigma,
use_global_attn=True,
SE3_param=SE3_param)
for i in range(n_extra_block)])
# Update with seed sequences
if n_main_block > 0:
self.main_block = nn.ModuleList([IterBlock(d_msa=d_msa, d_pair=d_pair,
n_head_msa=n_head_msa,
n_head_pair=n_head_pair,
d_hidden=d_hidden,
p_drop=p_drop,
rbf_sigma=rbf_sigma,
use_global_attn=False,
SE3_param=SE3_param)
for i in range(n_main_block)])
# Final SE(3) refinement
if n_ref_block > 0:
self.str_refiner = Str2Str(d_msa=d_msa, d_pair=d_pair,
d_state=SE3_param['l0_out_features'],
SE3_param=SE3_ref_param,
rbf_sigma=rbf_sigma,
p_drop=p_drop,
# nextra_l0=2*NTOTALDOFS,
# nextra_l1=6
)
# Fine-tuning all-atom SE(3) refinement
if n_finetune_block > 0:
d_state=16
self.allatom_embed = AllatomEmbed(
d_state_in = SE3_param['l0_out_features'],
d_state_out = d_state,
p_mask = 0.15
)
self.finetune_refiner = Allatom2Allatom(
SE3_param = {
'num_layers':1,
'num_channels':16,
'num_degrees':2,
'l0_in_features':d_state,
'l0_out_features':d_state,
'l1_in_features':2,
'l1_out_features':1,
'num_edge_features':4,
'n_heads':4,
'div':2,
}
)
self.residue_embed = ResidueEmbed(
d_state_in = d_state,
d_state_out = SE3_param['l0_out_features']
)
# To get all-atom coordinates
self.compute_allatom_coords = ComputeAllAtomCoords()
def forward(self, seq_unmasked, msa, msa_full, pair, xyz, state, idx, use_checkpoint=False):
# input:
# msa: initial MSA embeddings (N, L, d_msa)
# pair: initial residue pair embeddings (L, L, d_pair)
rotation_mask = is_atom(seq_unmasked)
xyz_s = list()
alpha_s = list()
for i_m in range(self.n_extra_block):
msa_full, pair, xyz, state, alpha = self.extra_block[i_m](msa_full, pair,
xyz, state, idx,
use_checkpoint=use_checkpoint, top_k=0, rotation_mask=rotation_mask)
xyz_s.append(xyz)
alpha_s.append(alpha)
for i_m in range(self.n_main_block):
msa, pair, xyz, state, alpha = self.main_block[i_m](msa, pair,
xyz, state, idx,
use_checkpoint=use_checkpoint, top_k=0, rotation_mask=rotation_mask)
xyz_s.append(xyz)
alpha_s.append(alpha)
_, xyzallatom = self.compute_allatom_coords(seq_unmasked, xyz, alpha) # think about detach here...
# now use unmasked seq (no cross-talk for msa prediction)
for i_m in range(self.n_ref_block):
# dbonddxyz, = calc_BB_bond_geom_grads(seq_unmasked[0], xyz.detach(), idx)
# dljdxyz, dljdalpha = calc_lj_grads(
# seq_unmasked, xyz.detach(), alpha.detach(),
# self.compute_allatom_coords,
# self.aamask,
# self.ljlk_parameters,
# self.lj_correction_parameters,
# self.num_bonds,
# lj_lin=self.lj_lin)
# extra_l1 = torch.cat((dbonddxyz[0].detach(),dljdxyz[0].detach()), dim=1)
# extra_l0 = dljdalpha.reshape(1,-1,2*NTOTALDOFS).detach()
extra_l0 =None
extra_l1= None
xyz, state, alpha = self.str_refiner(
msa, pair, xyz.detach(), state, idx, rotation_mask,
extra_l0, extra_l1, top_k=128)
xyz_s.append(xyz)
alpha_s.append(alpha)
_, xyzallatom = self.compute_allatom_coords(seq_unmasked, xyz, alpha) # think about detach here...
xyzallatom_s = list()
xyzallatom_s.append(xyzallatom.clone())
if (self.n_finetune_block>0):
state = self.allatom_embed(state, seq_unmasked, self.atom_type_index)
for i_m in range(self.n_finetune_block):
# dbonddxyz, = calc_cart_bonded_grads(
# seq_unmasked, xyzallatom.detach(), idx,
# self.cb_len, self.cb_ang, self.cb_tor
# )
# dljdxyz, = calc_ljallatom_grads(
# seq_unmasked,
# xyzallatom.detach(),
# self.aamask,
# self.ljlk_parameters,
# self.lj_correction_parameters,
# self.num_bonds,
# lj_lin=self.lj_lin
# )
# extra_l1 = torch.stack((dbonddxyz.detach(), dljdxyz.detach()))
extra_l1 = None
xyzallatom, state = self.finetune_refiner(
seq_unmasked,
xyzallatom.detach().float(),
self.aamask,
self.num_bonds,
state,
extra_l1.float()
)
# cb_loss = calc_cart_bonded(
# seq_unmasked, xyzallatom.detach(), idx,
# self.cb_len, self.cb_ang, self.cb_tor
# )
# lj_loss = calc_lj(
# seq_unmasked[0],
# xyzallatom.detach(),
# self.aamask,
# self.ljlk_parameters,
# self.lj_correction_parameters,
# self.num_bonds,
# lj_lin=self.lj_lin
# )
xyzallatom_s.append(xyzallatom.clone())
state = self.residue_embed(state)
xyz = torch.stack(xyz_s, dim=0)
alpha_s = torch.stack(alpha_s, dim=0)
xyzallatom_s = torch.cat(xyzallatom_s, dim=0)
return msa, pair, xyz, alpha_s, xyzallatom_s, state

172
RF2_allatom/arguments.py Normal file
View File

@@ -0,0 +1,172 @@
import argparse
import data_loader
import os
TRUNK_PARAMS = ['n_extra_block', 'n_main_block', 'n_ref_block', 'n_finetune_block',\
'd_msa', 'd_msa_full', 'd_pair', 'd_templ',\
'n_head_msa', 'n_head_pair', 'n_head_templ', 'd_hidden', 'd_hidden_templ', 'p_drop', 'rbf_sigma']
SE3_PARAMS = ['num_layers', 'num_channels', 'num_degrees', 'n_heads', 'div',
'l0_in_features', 'l0_out_features', 'l1_in_features', 'l1_out_features', 'num_edge_features'
]
def get_args():
parser = argparse.ArgumentParser()
# training parameters
train_group = parser.add_argument_group("training parameters")
train_group.add_argument("-model_name", default=None,
help="model name for saving")
train_group.add_argument('-batch_size', type=int, default=1,
help="Batch size [1]")
train_group.add_argument('-lr', type=float, default=2.0e-4,
help="Learning rate [5.0e-4]")
train_group.add_argument('-num_epochs', type=int, default=300,
help="Number of epochs [300]")
train_group.add_argument("-step_lr", type=int, default=300,
help="Parameter for Step LR scheduler [300]")
train_group.add_argument("-port", type=int, default=12319,
help="PORT for ddp training, should be randomized [12319]")
train_group.add_argument("-accum", type=int, default=1,
help="Gradient accumulation when it's > 1 [1]")
train_group.add_argument("-eval", action='store_true', default=False,
help="Train structure only")
# data-loading parameters
data_group = parser.add_argument_group("data loading parameters")
data_group.add_argument('-maxseq', type=int, default=1024,
help="Maximum depth of subsampled MSA [1024]")
data_group.add_argument('-maxtoken', type=int, default=2**18,
help="Maximum depth of subsampled MSA [2**18]")
data_group.add_argument('-maxlat', type=int, default=128,
help="Maximum depth of subsampled MSA [128]")
data_group.add_argument("-crop", type=int, default=260,
help="Upper limit of crop size [260]")
data_group.add_argument("-rescut", type=float, default=4.5,
help="Resolution cutoff [4.5]")
data_group.add_argument("-slice", type=str, default="DISCONT",
help="How to make crops [CONT / DISCONT (default)]")
data_group.add_argument("-subsmp", type=str, default="UNI",
help="How to subsample MSAs [UNI (default) / LOG / CONST]")
data_group.add_argument('-mintplt', type=int, default=1,
help="Minimum number of templates to select [1]")
data_group.add_argument('-maxtplt', type=int, default=4,
help="maximum number of templates to select [4]")
data_group.add_argument('-seqid', type=float, default=150.0,
help="maximum sequence identity cutoff for template selection [150.0]")
data_group.add_argument('-maxcycle', type=int, default=4,
help="maximum number of recycle [4]")
# Trunk module properties
trunk_group = parser.add_argument_group("Trunk module parameters")
trunk_group.add_argument('-n_extra_block', type=int, default=4,
help="Number of iteration blocks for extra sequences [4]")
trunk_group.add_argument('-n_main_block', type=int, default=8,
help="Number of iteration blocks for main sequences [8]")
trunk_group.add_argument('-n_ref_block', type=int, default=4,
help="Number of refinement layers")
trunk_group.add_argument('-n_finetune_block', type=int, default=0,
help="Number of finetune layers" [0])
trunk_group.add_argument('-d_msa', type=int, default=256,
help="Number of MSA features [256]")
trunk_group.add_argument('-d_msa_full', type=int, default=64,
help="Number of MSA features [64]")
trunk_group.add_argument('-d_pair', type=int, default=128,
help="Number of pair features [128]")
trunk_group.add_argument('-d_templ', type=int, default=64,
help="Number of templ features [64]")
trunk_group.add_argument('-n_head_msa', type=int, default=8,
help="Number of attention heads for MSA2MSA [8]")
trunk_group.add_argument('-n_head_pair', type=int, default=4,
help="Number of attention heads for Pair2Pair [4]")
trunk_group.add_argument('-n_head_templ', type=int, default=4,
help="Number of attention heads for template [4]")
trunk_group.add_argument("-d_hidden", type=int, default=32,
help="Number of hidden features [32]")
trunk_group.add_argument("-d_hidden_templ", type=int, default=64,
help="Number of hidden features for templates [64]")
trunk_group.add_argument("-p_drop", type=float, default=0.15,
help="Dropout ratio [0.15]")
trunk_group.add_argument("-rbf_sigma", type=float, default=1.0,
help="Sigma scale factor for RBF [1.0]")
# Structure module properties
str_group = parser.add_argument_group("structure module parameters")
str_group.add_argument('-num_layers', type=int, default=1,
help="Number of equivariant layers in structure module block [1]")
str_group.add_argument('-num_channels', type=int, default=32,
help="Number of channels [32]")
str_group.add_argument('-num_degrees', type=int, default=2,
help="Number of degrees for SE(3) network [2]")
str_group.add_argument('-l0_in_features', type=int, default=64,
help="Number of type 0 input features [64]")
str_group.add_argument('-l0_out_features', type=int, default=64,
help="Number of type 0 output features [64]")
str_group.add_argument('-l1_in_features', type=int, default=3,
help="Number of type 1 input features [3]")
str_group.add_argument('-l1_out_features', type=int, default=2,
help="Number of type 1 output features [2]")
str_group.add_argument('-num_edge_features', type=int, default=64,
help="Number of edge features [64]")
str_group.add_argument('-n_heads', type=int, default=4,
help="Number of attention heads for SE3-Transformer [4]")
str_group.add_argument("-div", type=int, default=4,
help="Div parameter for SE3-Transformer [4]")
str_group.add_argument('-ref_num_layers', type=int, default=1,
help="Number of equivariant layers in structure module block [1]")
str_group.add_argument('-ref_num_channels', type=int, default=32,
help="Number of channels [32]")
# Loss function parameters
loss_group = parser.add_argument_group("loss parameters")
loss_group.add_argument('-w_dist', type=float, default=1.0,
help="Weight on distd in loss function [1.0]")
loss_group.add_argument('-w_str', type=float, default=10.0,
help="Weight on strd in loss function [10.0]")
loss_group.add_argument('-w_lddt', type=float, default=0.1,
help="Weight on predicted lddt loss [0.1]")
loss_group.add_argument('-w_aa', type=float, default=3.0,
help="Weight on MSA masked token prediction loss [3.0]")
loss_group.add_argument('-w_bond', type=float, default=0.0,
help="Weight on predicted bond loss [0.0]")
loss_group.add_argument('-w_dih', type=float, default=0.0,
help="Weight on pseudodihedral loss [0.0]")
loss_group.add_argument('-w_clash', type=float, default=0.0,
help="Weight on clash loss [0.0]")
loss_group.add_argument('-w_hb', type=float, default=0.0,
help="Weight on clash loss [0.0]")
loss_group.add_argument('-lj_lin', type=float, default=0.75,
help="linear inflection for lj [0.75]")
# parse arguments
args = parser.parse_args()
# Setup dataloader parameters:
loader_param = data_loader.set_data_loader_params(args)
# make dictionary for each parameters
trunk_param = {}
for param in TRUNK_PARAMS:
trunk_param[param] = getattr(args, param)
SE3_param = {}
for param in SE3_PARAMS:
if hasattr(args, param):
SE3_param[param] = getattr(args, param)
SE3_ref_param = SE3_param.copy()
for param in SE3_PARAMS:
if hasattr(args, 'ref_'+param):
SE3_ref_param[param] = getattr(args, 'ref_'+param)
#print (SE3_param)
#print (SE3_ref_param)
trunk_param['SE3_param'] = SE3_param
trunk_param['SE3_ref_param'] = SE3_ref_param
loss_param = {}
for param in ['w_dist', 'w_str', 'w_aa', 'w_lddt', 'w_bond', 'w_dih', 'w_clash', 'w_hb', 'lj_lin']:
loss_param[param] = getattr(args, param)
return args, trunk_param, loader_param, loss_param

9052
RF2_allatom/cartbonded.json Normal file

File diff suppressed because it is too large Load Diff

1076
RF2_allatom/chemical.py Normal file

File diff suppressed because it is too large Load Diff

78
RF2_allatom/coords6d.py Normal file
View File

@@ -0,0 +1,78 @@
import numpy as np
import scipy
import scipy.spatial
from util import generate_Cbeta
# calculate dihedral angles defined by 4 sets of points
def get_dihedrals(a, b, c, d):
b0 = -1.0*(b - a)
b1 = c - b
b2 = d - c
b1 /= np.linalg.norm(b1, axis=-1)[:,None]
v = b0 - np.sum(b0*b1, axis=-1)[:,None]*b1
w = b2 - np.sum(b2*b1, axis=-1)[:,None]*b1
x = np.sum(v*w, axis=-1)
y = np.sum(np.cross(b1, v)*w, axis=-1)
return np.arctan2(y, x)
# calculate planar angles defined by 3 sets of points
def get_angles(a, b, c):
v = a - b
v /= np.linalg.norm(v, axis=-1)[:,None]
w = c - b
w /= np.linalg.norm(w, axis=-1)[:,None]
x = np.sum(v*w, axis=1)
#return np.arccos(x)
return np.arccos(np.clip(x, -1.0, 1.0))
# get 6d coordinates from x,y,z coords of N,Ca,C atoms
def get_coords6d(xyz, dmax):
nres = xyz.shape[1]
# three anchor atoms
N = xyz[0]
Ca = xyz[1]
C = xyz[2]
# recreate Cb given N,Ca,C
Cb = generate_Cbeta(N,Ca,C)
# fast neighbors search to collect all
# Cb-Cb pairs within dmax
kdCb = scipy.spatial.cKDTree(Cb)
indices = kdCb.query_ball_tree(kdCb, dmax)
# indices of contacting residues
idx = np.array([[i,j] for i in range(len(indices)) for j in indices[i] if i != j]).T
idx0 = idx[0]
idx1 = idx[1]
# Cb-Cb distance matrix
dist6d = np.full((nres, nres),999.9, dtype=np.float32)
dist6d[idx0,idx1] = np.linalg.norm(Cb[idx1]-Cb[idx0], axis=-1)
# matrix of Ca-Cb-Cb-Ca dihedrals
omega6d = np.zeros((nres, nres), dtype=np.float32)
omega6d[idx0,idx1] = get_dihedrals(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1])
# matrix of polar coord theta
theta6d = np.zeros((nres, nres), dtype=np.float32)
theta6d[idx0,idx1] = get_dihedrals(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1])
# matrix of polar coord phi
phi6d = np.zeros((nres, nres), dtype=np.float32)
phi6d[idx0,idx1] = get_angles(Ca[idx0], Cb[idx0], Cb[idx1])
mask = np.zeros((nres, nres), dtype=np.float32)
mask[idx0, idx1] = 1.0
return dist6d, omega6d, theta6d, phi6d, mask

1954
RF2_allatom/data_loader.py Normal file

File diff suppressed because it is too large Load Diff

342
RF2_allatom/eval.py Normal file
View File

@@ -0,0 +1,342 @@
import sys, os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
from parsers import parse_a3m, parse_fasta, read_template_pdb
from RoseTTAFoldModel import RoseTTAFoldModule
import util
from collections import namedtuple
from ffindex import *
from data_loader import MSAFeaturize, MSABlockDeletion, merge_a3m_homo
from kinematics import xyz_to_c6d, c6d_to_bins, xyz_to_t2d, get_init_xyz
from util_module import ComputeAllAtomCoords
from chemical import NTOTAL, NTOTALDOFS, NAATOKENS
from memory import mem_report
MAX_CYCLE = 30
NREPLICATES = 5
NBIN = [37, 37, 37, 19]
MODEL_PARAM ={
"n_extra_block" : 4,
"n_main_block" : 32,
"n_ref_block" : 4,
"d_msa" : 256 ,
"d_pair" : 128,
"d_templ" : 64,
"n_head_msa" : 8,
"n_head_pair" : 4,
"n_head_templ" : 4,
"d_hidden" : 32,
"d_hidden_templ" : 64,
"p_drop" : 0.15,
"lj_lin" : 0.75
}
SE3_param = {
"num_layers" : 1,
"num_channels" : 32,
"num_degrees" : 2,
"l0_in_features": 64,
"l0_out_features": 64,
"l1_in_features": 3,
"l1_out_features": 2,
"num_edge_features": 64,
"div": 4,
"n_heads": 4
}
SE3_ref_param = {
"num_layers" : 2,
"num_channels" : 32,
"num_degrees" : 2,
"l0_in_features": 64,
"l0_out_features": 64,
"l1_in_features": 3,
"l1_out_features": 2,
"num_edge_features": 64,
"div": 4,
"n_heads": 4
}
MODEL_PARAM['SE3_param'] = SE3_param
MODEL_PARAM['SE3_ref_param'] = SE3_ref_param
# params for the folding protocol
fold_params = {
"SG7" : np.array([[[-2,3,6,7,6,3,-2]]])/21,
"SG9" : np.array([[[-21,14,39,54,59,54,39,14,-21]]])/231,
"DCUT" : 19.5,
"ALPHA" : 1.57,
# TODO: add Cb to the motif
"NCAC" : np.array([[-0.676, -1.294, 0. ],
[ 0. , 0. , 0. ],
[ 1.5 , -0.174, 0. ]], dtype=np.float32),
"CLASH" : 2.0,
"PCUT" : 0.5,
"DSTEP" : 0.5,
"ASTEP" : np.deg2rad(10.0),
"XYZRAD" : 7.5,
"WANG" : 0.1,
"WCST" : 0.1
}
fold_params["SG"] = fold_params["SG9"]
# compute expected value from binned lddt
def lddt_unbin(pred_lddt):
nbin = pred_lddt.shape[1]
bin_step = 1.0 / nbin
lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device)
pred_lddt = nn.Softmax(dim=1)(pred_lddt)
return torch.sum(lddt_bins[None,:,None]*pred_lddt, dim=1)
class Predictor():
def __init__(self, model_name="BFF", model_dir=None, device="cuda:0"):
if model_dir == None:
self.model_dir = "%s/models"%(os.path.dirname(os.path.abspath(__file__)))
else:
self.model_dir = model_dir
#
# define model name
self.model_name = model_name
self.device = device
self.active_fn = nn.Softmax(dim=1)
# define model & load model
self.model = RoseTTAFoldModule(
**MODEL_PARAM,
aamask=util.allatom_mask.to(self.device),
ljlk_parameters=util.ljlk_parameters.to(self.device),
lj_correction_parameters=util.lj_correction_parameters.to(self.device),
num_bonds=util.num_bonds.to(self.device)
).to(self.device)
could_load = self.load_model(self.model_name)
if not could_load:
print ("ERROR: failed to load model")
sys.exit()
self.compute_allatom_coords = ComputeAllAtomCoords().to(self.device)
def load_model(self, model_name, suffix='last'):
chk_fn = "%s/%s_%s.pt"%(self.model_dir, model_name, suffix)
if not os.path.exists(chk_fn):
return False
checkpoint = torch.load(chk_fn, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
return True
def predict(self, fasta_fn, out_prefix, tmpl_fn=None, atab_fn=None, window=1e9, shift=50, n_latent=256, oligo=1):
msa_orig, ins_orig = parse_fasta(fasta_fn, rmsa_alphabet=True)
msa_orig = torch.tensor(msa_orig).long()
ins_orig = torch.tensor(ins_orig).long()
#sel = torch.arange(941) #, msa_orig.shape[1])
#msa_orig = msa_orig[:,sel]
#ins_orig = ins_orig[:,sel]
if (oligo>1):
msa_orig, ins_orig = merge_a3m_homo(msa_orig, ins_orig, oligo) # make unpaired alignments, for training, we always use two chains
N, L = msa_orig.shape
#
if tmpl_fn and os.path.exists(tmpl_fn):
xyz_t, t1d = read_template_pdb(L, tmpl_fn)
#xyz_t, t1d = read_templates(L, ffdb, hhr_fn, atab_fn, n_templ=4)
else:
xyz_t = torch.full((1,L,3,3),np.nan).float()
t1d = torch.nn.functional.one_hot(torch.full((1, L), 20).long(), num_classes=NAATOKENS-1).float() # all gaps
t1d = torch.cat((t1d, torch.zeros((1,L,1)).float()), -1)
#
# template features
xyz_t = xyz_t.float().unsqueeze(0)
t1d = t1d.float().unsqueeze(0)
t2d = xyz_to_t2d(xyz_t)
same_chain = torch.ones((1,L,L), dtype=torch.bool, device=xyz_t.device)
xyz_t = get_init_xyz(msa_orig[0:1],xyz_t,same_chain) # initialize coordinates with first template
seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L)
alpha, _, alpha_mask, _ = util.get_torsions(
xyz_t.reshape(-1,L,NTOTAL,3),
seq_tmp,
util.torsion_indices,
util.torsion_can_flip,
util.reference_angles
)
alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0]))
alpha[torch.isnan(alpha)] = 0.0
alpha = alpha.reshape(1,-1,L,NTOTALDOFS,2)
alpha_mask = alpha_mask.reshape(1,-1,L,NTOTALDOFS,1)
alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(1, -1, L, 3*NTOTALDOFS)
self.model.eval()
for i_trial in range(NREPLICATES):
if os.path.exists("%s_%02d_init.pdb"%(out_prefix, i_trial)):
continue
self.run_prediction(msa_orig, ins_orig, t1d, t2d, xyz_t, xyz_t[:,0], alpha_t, "%s_%02d"%(out_prefix, i_trial), n_latent=n_latent)
torch.cuda.empty_cache()
def run_prediction(self, msa_orig, ins_orig, t1d, t2d, xyz_t, xyz, alpha_t, out_prefix, n_latent=256):
start = time.time()
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
#
msa = msa_orig.to(self.device) # (N, L)
ins = ins_orig.long().to(self.device)
#if msa_orig.shape[0] > 4096:
# msa, ins = MSABlockDeletion(msa, ins)
#
seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(
msa, ins, p_mask=0.0, params={'MAXLAT': 128, 'MAXSEQ': 1024, 'MAXCYCLE': MAX_CYCLE}, tocpu=True)
_, N, L = msa_seed.shape[:3]
B = 1
#
idx_pdb = torch.arange(L).long().view(1, L)
#
seq = seq.unsqueeze(0)
msa_seed = msa_seed.unsqueeze(0)
msa_extra = msa_extra.unsqueeze(0)
t1d = t1d.to(self.device)
t2d = t2d.to(self.device)
idx_pdb = idx_pdb.to(self.device)
xyz_t = xyz_t.to(self.device)
alpha_t = alpha_t.to(self.device)
xyz = xyz.to(self.device)
self.write_pdb(seq[0, -1], xyz[0], prefix="%s_templ"%(out_prefix))
msa_prev = None
pair_prev = None
alpha_prev = torch.zeros((1,L,NTOTALDOFS,2), device=seq.device)
xyz_prev=xyz
state_prev = None
best_lddt = torch.tensor([-1.0], device=seq.device)
best_xyz = None
best_logit = None
best_aa = None
for i_cycle in range(MAX_CYCLE):
msa_seed_i = msa_seed[:,i_cycle].to(self.device)
msa_extra_i = msa_extra[:,i_cycle].to(self.device)
with torch.cuda.amp.autocast(True):
logit_s, logit_aa_s, init_crds, alpha_prev, _, pred_lddt_binned, msa_prev, pair_prev, state_prev = self.model(
msa_seed_i,
msa_extra_i,
seq[:,i_cycle],
seq[:,i_cycle],
xyz_prev,
alpha_prev,
idx_pdb,
t1d=t1d,
t2d=t2d,
xyz_t=xyz_t,
alpha_t=alpha_t,
msa_prev=msa_prev,
pair_prev=pair_prev,
state_prev=state_prev
)
logit_aa_s = logit_aa_s.reshape(B,-1,N,L)[:,:,0].permute(0,2,1)
xyz_prev = init_crds[-1]
alpha_prev = alpha_prev[-1]
pred_lddt = lddt_unbin(pred_lddt_binned)
print ("RECYCLE", i_cycle, pred_lddt.mean(), best_lddt.mean())
_, all_crds = self.compute_allatom_coords(seq[:,i_cycle], init_crds[-1], alpha_prev)
#self.write_pdb(seq[0, -1], all_crds[0], Bfacts=pred_lddt[0], prefix="%s_cycle_%02d"%(out_prefix, i_cycle))
if pred_lddt.mean() < best_lddt.mean():
continue
best_xyz = all_crds.clone()
best_logit = logit_s
best_aa = logit_aa_s
best_lddt = pred_lddt.clone()
#print (pred_lddt)
prob_s = list()
for logit in logit_s:
prob = self.active_fn(logit.float()) # distogram
prob = prob.reshape(-1, L, L) #.permute(1,2,0).cpu().numpy()
prob_s.append(prob)
end = time.time()
for prob in prob_s:
prob += 1e-8
prob = prob / torch.sum(prob, dim=0)[None]
self.write_pdb(seq[0, -1], best_xyz[0], Bfacts=100*best_lddt[0], prefix="%s_init"%(out_prefix))
prob_s = [prob.permute(1,2,0).detach().cpu().numpy().astype(np.float16) for prob in prob_s]
np.savez_compressed("%s.npz"%(out_prefix), dist=prob_s[0].astype(np.float16), \
omega=prob_s[1].astype(np.float16),\
theta=prob_s[2].astype(np.float16),\
phi=prob_s[3].astype(np.float16),\
lddt=best_lddt[0].detach().cpu().numpy().astype(np.float16))
max_mem = torch.cuda.max_memory_allocated()/1e9
print ("max mem", max_mem)
print ("runtime", end-start)
def write_pdb(self, seq, atoms, Bfacts=None, prefix=None):
L = len(seq)
filename = "%s.pdb"%prefix
ctr = 1
with open(filename, 'wt') as f:
if Bfacts == None:
Bfacts = np.zeros(L)
else:
Bfacts = torch.clamp( Bfacts, 0, 100)
for i,s in enumerate(seq):
if (len(atoms.shape)==2):
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
"ATOM", ctr, " CA ", util.num2aa[s],
"A", i+1, atoms[i,0], atoms[i,1], atoms[i,2],
1.0, Bfacts[i] ) )
ctr += 1
elif atoms.shape[1]==3:
for j,atm_j in enumerate((" N "," CA "," C ")):
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
"ATOM", ctr, atm_j, util.num2aa[s],
"A", i+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2],
1.0, Bfacts[i] ) )
ctr += 1
else:
atms = util.aa2long[s]
for j,atm_j in enumerate(atms):
if (atm_j is not None):
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
"ATOM", ctr, atm_j, util.num2aa[s],
"A", i+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2],
1.0, Bfacts[i] ) )
ctr += 1
def get_args():
#DB="/home/robetta/rosetta_server_beta/external/databases/trRosetta/pdb100_2021Mar03/pdb100_2021Mar03"
DB = "/projects/ml/TrRosetta/pdb100_2020Mar11/pdb100_2020Mar11"
import argparse
parser = argparse.ArgumentParser(description="RoseTTAFold: Protein structure prediction with 3-track attentions on 1D, 2D, and 3D features")
parser.add_argument("fasta", help="fasta for structure prediction")
parser.add_argument("-oligo", type=int, default=1)
parser.add_argument("-prefix", type=str, default="pred")
parser.add_argument("-tmpl", default=None)
parser.add_argument("-model_name", default="BFF", required=False,
help="Prefix for model. The model under models/[model_name]_best.pt will be used. [BFF]")
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
pred = Predictor(model_name=args.model_name)
pred.predict(args.fasta, args.prefix, args.tmpl, oligo=args.oligo)

395
RF2_allatom/eval_fb.py Normal file
View File

@@ -0,0 +1,395 @@
import sys, os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
#from parsers import parse_a3m, read_templates
from RoseTTAFoldModel import RoseTTAFoldModule
import util
from collections import namedtuple
#from ffindex import *
from data_loader import *
from kinematics import xyz_to_c6d, c6d_to_bins2, xyz_to_t2d, get_init_xyz
from util_module import ComputeAllAtomCoords
from loss import *
MAX_CYCLE = 4
NBIN = [37, 37, 37, 19]
MODEL_PARAM ={
"n_extra_block" : 4,
"n_main_block" : 32,
"n_ref_block" : 0,
"n_finetune_block" : 4,
"d_msa" : 256 ,
"d_pair" : 128,
"d_templ" : 64,
"n_head_msa" : 8,
"n_head_pair" : 4,
"n_head_templ" : 4,
"d_hidden" : 32,
"d_hidden_templ" : 64,
"p_drop" : 0.0,
"lj_lin" : 0.6
}
SE3_param = {
"num_layers" : 1,
"num_channels" : 32,
"num_degrees" : 2,
"l0_in_features": 64,
"l0_out_features": 64,
"l1_in_features": 3,
"l1_out_features": 2,
"num_edge_features": 64,
"div": 4,
"n_heads": 4
}
MODEL_PARAM['SE3_param'] = SE3_param
LOAD_PARAM = {'shuffle': False,
'num_workers': 4,
'pin_memory': True}
fb_dir = "/projects/ml/TrRosetta/fb_af"
base_dir = "/projects/ml/TrRosetta/PDB30-20FEB17"
# params for the folding protocol
fold_params = {
"SG7" : np.array([[[-2,3,6,7,6,3,-2]]])/21,
"SG9" : np.array([[[-21,14,39,54,59,54,39,14,-21]]])/231,
"DCUT" : 19.5,
"ALPHA" : 1.57,
# TODO: add Cb to the motif
"NCAC" : np.array([[-0.676, -1.294, 0. ],
[ 0. , 0. , 0. ],
[ 1.5 , -0.174, 0. ]], dtype=np.float32),
"CLASH" : 2.0,
"PCUT" : 0.5,
"DSTEP" : 0.5,
"ASTEP" : np.deg2rad(10.0),
"XYZRAD" : 7.5,
"WANG" : 0.1,
"WCST" : 0.1
}
fold_params["SG"] = fold_params["SG9"]
# compute expected value from binned lddt
def lddt_unbin(pred_lddt):
nbin = pred_lddt.shape[1]
bin_step = 1.0 / nbin
lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device)
pred_lddt = nn.Softmax(dim=1)(pred_lddt)
return torch.sum(lddt_bins[None,:,None]*pred_lddt, dim=1)
class Predictor():
def __init__(self, model_name="BFF", model_dir=None, device="cuda:0"):
if model_dir == None:
self.model_dir = "%s/models"%(os.path.dirname(os.path.abspath(__file__)))
else:
self.model_dir = model_dir
#
# define model name
self.model_name = model_name
self.device = device
self.active_fn = nn.Softmax(dim=1)
self.aamask = util.allatom_mask.to(self.device)
self.atom_type_index = util.atom_type_index.to(self.device)
self.ljlk_parameters = util.ljlk_parameters.to(self.device)
self.lj_correction_parameters = util.lj_correction_parameters.to(self.device)
self.num_bonds = util.num_bonds.to(self.device)
self.cb_len = util.cb_length_t.to(self.device)
self.cb_ang = util.cb_angle_t.to(self.device)
self.cb_tor = util.cb_torsion_t.to(self.device)
# define model & load model
self.model = RoseTTAFoldModule(
**MODEL_PARAM,
aamask=self.aamask,
atom_type_index = self.atom_type_index,
ljlk_parameters = self.ljlk_parameters,
lj_correction_parameters = self.lj_correction_parameters,
num_bonds = self.num_bonds,
cb_len = self.cb_len,
cb_ang = self.cb_ang,
cb_tor = self.cb_tor
).to(self.device)
could_load = self.load_model(self.model_name)
if not could_load:
print ("ERROR: failed to load model")
sys.exit()
self.compute_allatom_coords = ComputeAllAtomCoords().to(self.device)
self.ti_dev = util.torsion_indices.to(self.device)
self.ti_flip = util.torsion_can_flip.to(self.device)
self.ang_ref = util.reference_angles.to(self.device)
self.l2a = util.long2alt.to(self.device)
self.aamask = util.allatom_mask.to(self.device)
def load_model(self, model_name, suffix='last'):
chk_fn = "%s/%s_%s.pt"%(self.model_dir, model_name, suffix)
if not os.path.exists(chk_fn):
return False
checkpoint = torch.load(chk_fn, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
return True
def run_prediction(self, seq, msa_seed, msa_extra, true_crds, res_mask, atom_mask, idx_pdb, xyz_t, t1d, alpha_t, tag):
self.model.eval()
with torch.no_grad():
# transfer inputs to device
B, _, N, L, _ = msa_seed.shape
idx_pdb = idx_pdb.to(self.device, non_blocking=True) # (B, L)
true_crds = true_crds.to(self.device, non_blocking=True) # (B, L, 27, 3)
res_mask = res_mask.to(self.device, non_blocking=True) # (B, L)
atom_mask = atom_mask.to(self.device, non_blocking=True) # (B, L, 27)
xyz_t = xyz_t.to(self.device, non_blocking=True)
t1d = t1d.to(self.device, non_blocking=True)
alpha_t = alpha_t.to(self.device, non_blocking=True)
seq = seq.to(self.device, non_blocking=True)
msa_seed = msa_seed.to(self.device, non_blocking=True)
msa_extra = msa_extra.to(self.device, non_blocking=True)
# processing labels & template features
c6d, _ = xyz_to_c6d(true_crds)
c6d = c6d_to_bins2(c6d)
t2d = xyz_to_t2d(xyz_t)
xyz_t = get_init_xyz(xyz_t)
xyz_prev = xyz_t[:,0]
# set number of recycles
msa_prev = None
pair_prev = None
alpha_prev = torch.zeros((1,L,10,2)).to(self.device, non_blocking=True) #fd we could get this from the template...
state_prev = None
best_lddt = torch.tensor([-1.0], device=seq.device)
best_xyz = None
best_logit = None
best_aa = None
for i_cycle in range(MAX_CYCLE):
with torch.cuda.amp.autocast(True):
logit_s, logit_aa_s, init_crds, alpha_prev, init_allatom, pred_lddt_binned, msa_prev, pair_prev, state_prev = self.model(
msa_seed[:,i_cycle],
msa_extra[:,i_cycle],
seq[:,i_cycle],
xyz_prev,
alpha_prev,
idx_pdb,
t1d=t1d,
t2d=t2d,
xyz_t=xyz_t,
alpha_t=alpha_t,
msa_prev=msa_prev,
pair_prev=pair_prev,
state_prev=state_prev
)
logit_aa_s = logit_aa_s.reshape(B,-1,N,L)[:,:,0].permute(0,2,1)
#xyz_prev = init_crds[-1]
xyz_prev = init_allatom[-1].unsqueeze(0)
#msa_prev = msa_prev[:,0]
alpha_prev = alpha_prev[-1]
pred_lddt = lddt_unbin(pred_lddt_binned)
#print ("RECYCLE", i_cycle, pred_lddt.mean(), best_lddt.mean())
#_, all_crds = self.compute_allatom_coords(seq[:,i_cycle], xyz_prev, alpha_prev)
#self.write_pdb(seq[0, -1], all_crds[0], Bfacts=pred_lddt[0], prefix="%s_cycle_%02d"%(out_prefix, i_cycle))
if pred_lddt.mean() < best_lddt.mean():
continue
best_xyz = init_allatom[-1].clone()
best_logit = logit_s
best_aa = logit_aa_s
best_lddt = pred_lddt.clone()
# lddt to native
seq = seq[:,0]
res_mask = res_mask[0]
true_tors, true_tors_alt, tors_mask, tors_planar = util.get_torsions(
true_crds, seq, self.ti_dev, self.ti_flip, self.ang_ref, mask_in=atom_mask)
# get alternative coordinates for ground-truth
true_alt = torch.zeros_like(true_crds)
true_alt.scatter_(2, self.l2a[seq,:,None].repeat(1,1,1,3), true_crds)
natRs_all, _n0 = self.compute_allatom_coords(seq, true_crds[...,:3,:], true_tors)
natRs_all_alt, _n1 = self.compute_allatom_coords(seq, true_alt[...,:3,:], true_tors_alt)
# - resolve symmetry
xs_mask = self.model.simulator.aamask[seq] # (B, L, 27)
xs_mask[0,:,14:]=False # (ignore hydrogens except lj loss)
xs_mask *= atom_mask # mask missing atoms & residues as well
natRs_all_symm, nat_symm = resolve_symmetry(best_xyz, natRs_all[0], true_crds[0], natRs_all_alt[0], true_alt[0], xs_mask[0])
atom_mask_trim = atom_mask[0,res_mask]
true_lddt = calc_allatom_lddt(best_xyz.unsqueeze(0), nat_symm, idx_pdb, atom_mask)
ljE = calc_lj(
seq[0], init_allatom,
self.aamask,
self.ljlk_parameters,
self.lj_correction_parameters,
self.num_bonds
)
cbE = calc_cart_bonded(
seq, init_allatom, idx_pdb, self.cb_len, self.cb_ang, self.cb_tor)
print (tag[0],tag[1],true_lddt.mean().cpu().numpy(), ljE.cpu().numpy(), cbE.cpu().numpy())
self.write_pdb(seq[0], best_xyz, Bfacts=best_lddt[0], prefix="preds/%s_pred"%(tag[0]))
#if (true_lddt.mean()<0.9 or true_lddt.mean()>0.97):
# self.write_pdb(seq[0, -1], all_crds[0], Bfacts=best_lddt[0], prefix="%s_pred"%(tag[1]))
# self.write_pdb(seq[0, -1], true_crds[0,...,:14,:], Bfacts=best_lddt[0], prefix="%s_native"%(tag[1]))
#prob_s = list()
#for logit in logit_s:
# prob = self.active_fn(logit.float()) # distogram
# prob = prob.reshape(-1, L, L) #.permute(1,2,0).cpu().numpy()
# prob_s.append(prob)
#for prob in prob_s:
# prob += 1e-8
# prob = prob / torch.sum(prob, dim=0)[None]
#self.write_pdb(seq[0, -1], best_xyz[0], Bfacts=best_lddt[0], prefix="%s_init"%(out_prefix))
#prob_s = [prob.permute(1,2,0).detach().cpu().numpy().astype(np.float16) for prob in prob_s]
#np.savez_compressed("%s.npz"%(out_prefix), dist=prob_s[0].astype(np.float16), \
# omega=prob_s[1].astype(np.float16),\
# theta=prob_s[2].astype(np.float16),\
# phi=prob_s[3].astype(np.float16),\
# lddt=best_lddt[0].detach().cpu().numpy().astype(np.float16))
def write_pdb(self, seq, atoms, Bfacts=None, prefix=None):
L = len(seq)
filename = "%s.pdb"%prefix
ctr = 1
with open(filename, 'wt') as f:
if Bfacts == None:
Bfacts = np.zeros(L)
else:
Bfacts = torch.clamp( Bfacts, 0, 1)
for i,s in enumerate(seq):
if (len(atoms.shape)==2):
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
"ATOM", ctr, " CA ", util.num2aa[s],
"A", i+1, atoms[i,0], atoms[i,1], atoms[i,2],
1.0, Bfacts[i] ) )
ctr += 1
elif atoms.shape[1]==3:
for j,atm_j in enumerate((" N "," CA "," C ")):
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
"ATOM", ctr, atm_j, util.num2aa[s],
"A", i+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2],
1.0, Bfacts[i] ) )
ctr += 1
else:
natoms = atoms.shape[1]
atms = util.aa2long[s]
# his prot hack
if (s==8 and torch.linalg.norm( atoms[i,9,:]-atoms[i,5,:] ) < 1.7):
atms = (
" N "," CA "," C "," O "," CB "," CG "," NE2"," CD2"," CE1"," ND1",
None, None, None, None," H "," HA ","1HB ","2HB "," HD2"," HE1",
" HD1", None, None, None, None, None, None) # his_d
for j,atm_j in enumerate(atms):
if (j<natoms and atm_j is not None): # and not torch.isnan(atomscpu[i,j,:]).any()):
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
"ATOM", ctr, atm_j, util.num2aa[s],
"A", i+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2],
1.0, Bfacts[i] ) )
ctr += 1
def get_args():
#DB="/home/robetta/rosetta_server_beta/external/databases/trRosetta/pdb100_2021Mar03/pdb100_2021Mar03"
DB = "/projects/ml/TrRosetta/pdb100_2020Mar11/pdb100_2020Mar11"
import argparse
parser = argparse.ArgumentParser(description="RoseTTAFold: Protein structure prediction with 3-track attentions on 1D, 2D, and 3D features")
parser.add_argument("-model_name", default="BFF", required=False,
help="Prefix for model. The model under models/[model_name]_best.pt will be used. [BFF]")
parser.add_argument("-i", default=1, required=False, type=int,
help="parallelize i of j")
parser.add_argument("-j", default=1, required=False, type=int,
help="parallelize i of j")
args = parser.parse_args()
return args
LOADER_PARAMS = {
"FB_LIST" : "%s/list_b1-3.csv"%fb_dir,
"FB_DIR" : fb_dir,
"PLDDTCUT": 70.0,
#"seqID" : 50.0,
"MAXLAT" : 256,
"MAXSEQ" : 2048,
"MAXCYCLE": 4,
"SCCUT" : 90.0
}
if __name__ == "__main__":
args = get_args()
pred = Predictor(model_name=args.model_name)
# compile facebook model sets
with open(LOADER_PARAMS['FB_LIST'], 'r') as f:
reader = csv.reader(f)
next(reader)
rows = [[r[0],r[2],int(r[3]),len(r[-1].strip())] for r in reader
if float(r[1]) > 85.0 and
len(r[-1].strip()) > 200]
fb = {}
for r in rows:
if r[2] in fb.keys():
fb[r[2]].append((r[:2], r[-1]))
else:
fb[r[2]] = [(r[:2], r[-1])]
for i, (id_i, key_i) in enumerate(fb.items()):
if (i%args.j != args.i%args.j):
continue
item = key_i[0][0]
a3m = get_msa(os.path.join(LOADER_PARAMS["FB_DIR"], "a3m", item[-1][:2], item[-1][2:], item[0]+".a3m.gz"), item[0])
pdb = get_pdb(os.path.join(LOADER_PARAMS["FB_DIR"], "pdb", item[-1][:2], item[-1][2:], item[0]+".pdb"),
os.path.join(LOADER_PARAMS["FB_DIR"], "pdb", item[-1][:2], item[-1][2:], item[0]+".plddt.npy"),
item[0], LOADER_PARAMS['PLDDTCUT'], LOADER_PARAMS['SCCUT'])
idx = pdb['idx']
l = a3m['msa'].shape[-1]
msa = a3m['msa'].long()
ins = a3m['ins'].long()
#if len(msa) > 5:
# msa, ins = MSABlockDeletion(msa, ins)
seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(msa, ins, LOADER_PARAMS, p_mask=0.0)
# No templates
xyz_t = torch.full((1,l,27,3),np.nan).float()
alpha_t = torch.full((1,l,30),0.0).float()
f1d_t = torch.nn.functional.one_hot(torch.full((1, l), 20).long(), num_classes=21).float() # all gaps
f1d_t = torch.cat((f1d_t, torch.zeros((1,l,1)).float()), -1)
true_crds = torch.full((len(idx),27,3),np.nan).float()
true_crds[:,:,:] = pdb['xyz']
mask_atoms = torch.zeros((len(idx),27), dtype=torch.bool)
mask_atoms[:,:] = pdb['mask']
mask_res = mask_atoms.sum(dim=-1)>=3
idx_pdb = torch.arange(l).long()
pred.run_prediction(
seq[None,...], msa_seed[None,...], msa_extra[None,...],
true_crds[None,...], mask_res[None,...], mask_atoms[None,...],
idx_pdb[None,...], xyz_t[None,...], f1d_t[None,...], alpha_t[None,...],
item)

View File

@@ -0,0 +1,51 @@
import os
import numpy as np
def get_lddt(fn):
data = np.load(fn)
return data['lddt'].mean()
#lddt = list()
#with open(fn) as fp:
# for line in fp:
# if not line.startswith("ATOM"):
# continue
# if line[12:16].strip() == "CA":
# lddt.append(float(line[61:66]))
#return np.mean(lddt)
ans_home = "/projects/casp/CASP14/eval/official/answers/domainwise"
dom_s = [line.split() for line in open("/projects/casp/CASP14/eval/official/difficulty.domains")]
if not os.path.exists("model1"):
os.mkdir("model1")
print ("#DomID diff TM TS LDDT")
for domID, diff in dom_s:
tar = domID.split('-')[0]
if not os.path.exists("%s/%s_00_init.pdb"%(tar, tar)):
continue
TM_s = list()
TS_s = list()
lddt_s = list()
pdb_s = list()
for i_iter in range(5):
pdb_fn = "%s/%s_%02d_init.pdb"%(tar, tar, i_iter)
if not os.path.exists(pdb_fn):
continue
npz_fn = "%s/%s_%02d.npz"%(tar, tar, i_iter)
lddt = get_lddt(npz_fn)
lddt_s.append(lddt)
pdb_s.append(pdb_fn)
max_idx = np.argmax(lddt_s)
pdb_fn = pdb_s[max_idx]
lines = os.popen("TMscore %s %s/%s.pdb"%(pdb_fn, ans_home, domID)).readlines()
TM = 0.0
TS = 0.0
for line in lines:
if line.startswith("TM-score"):
TM = float(line.split()[2])
elif line.startswith("GDT-TS"):
TS = float(line.split()[1])
break
lddt = os.popen("lddt %s %s/%s.pdb | grep Glob"%(pdb_fn, ans_home, domID)).readlines()[-1].split()[-1]
print (domID, diff, TM, TS, lddt)
os.system("cp %s model1/%s_pred.pdb"%(pdb_fn, tar))

91
RF2_allatom/ffindex.py Normal file
View File

@@ -0,0 +1,91 @@
#!/usr/bin/env python
# https://raw.githubusercontent.com/ahcm/ffindex/master/python/ffindex.py
'''
Created on Apr 30, 2014
@author: meiermark
'''
import sys
import mmap
from collections import namedtuple
FFindexEntry = namedtuple("FFindexEntry", "name, offset, length")
def read_index(ffindex_filename):
entries = []
fh = open(ffindex_filename)
for line in fh:
tokens = line.split("\t")
entries.append(FFindexEntry(tokens[0], int(tokens[1]), int(tokens[2])))
fh.close()
return entries
def read_data(ffdata_filename):
fh = open(ffdata_filename, "r+b")
data = mmap.mmap(fh.fileno(), 0)
fh.close()
return data
def get_entry_by_name(name, index):
#TODO: bsearch
for entry in index:
if(name == entry.name):
return entry
return None
def read_entry_lines(entry, data):
lines = data[entry.offset:entry.offset + entry.length - 1].decode("utf-8").split("\n")
return lines
def read_entry_data(entry, data):
return data[entry.offset:entry.offset + entry.length - 1]
def write_entry(entries, data_fh, entry_name, offset, data):
data_fh.write(data[:-1])
data_fh.write(bytearray(1))
entry = FFindexEntry(entry_name, offset, len(data))
entries.append(entry)
return offset + len(data)
def write_entry_with_file(entries, data_fh, entry_name, offset, file_name):
with open(file_name, "rb") as fh:
data = bytearray(fh.read())
return write_entry(entries, data_fh, entry_name, offset, data)
def finish_db(entries, ffindex_filename, data_fh):
data_fh.close()
write_entries_to_db(entries, ffindex_filename)
def write_entries_to_db(entries, ffindex_filename):
sorted(entries, key=lambda x: x.name)
index_fh = open(ffindex_filename, "w")
for entry in entries:
index_fh.write("{name:.64}\t{offset}\t{length}\n".format(name=entry.name, offset=entry.offset, length=entry.length))
index_fh.close()
def write_entry_to_file(entry, data, file):
lines = read_lines(entry, data)
fh = open(file, "w")
for line in lines:
fh.write(line+"\n")
fh.close()

266
RF2_allatom/kinematics.py Normal file
View File

@@ -0,0 +1,266 @@
import numpy as np
import torch
from util import INIT_CRDS, INIT_NA_CRDS, generate_Cbeta, is_nucleic
from chemical import NTOTAL
PARAMS = {
"DMIN" : 2.0,
"DMAX" : 20.0,
"DBINS" : 36,
"ABINS" : 36,
}
# ============================================================
def get_pair_dist(a, b):
"""calculate pair distances between two sets of points
Parameters
----------
a,b : pytorch tensors of shape [batch,nres,3]
store Cartesian coordinates of two sets of atoms
Returns
-------
dist : pytorch tensor of shape [batch,nres,nres]
stores paitwise distances between atoms in a and b
"""
dist = torch.cdist(a, b, p=2)
return dist
# ============================================================
def get_ang(a, b, c, eps=1e-6):
"""calculate planar angles for all consecutive triples (a[i],b[i],c[i])
from Cartesian coordinates of three sets of atoms a,b,c
Parameters
----------
a,b,c : pytorch tensors of shape [batch,nres,3]
store Cartesian coordinates of three sets of atoms
Returns
-------
ang : pytorch tensor of shape [batch,nres]
stores resulting planar angles
"""
v = a - b
w = c - b
vn = v / (torch.norm(v, dim=-1, keepdim=True)+eps)
wn = w / (torch.norm(w, dim=-1, keepdim=True)+eps)
vw = torch.sum(vn*wn, dim=-1)
return torch.acos(torch.clamp(vw,-0.999,0.999))
# ============================================================
def get_dih(a, b, c, d, eps=1e-6):
"""calculate dihedral angles for all consecutive quadruples (a[i],b[i],c[i],d[i])
given Cartesian coordinates of four sets of atoms a,b,c,d
Parameters
----------
a,b,c,d : pytorch tensors of shape [batch,nres,3]
store Cartesian coordinates of four sets of atoms
Returns
-------
dih : pytorch tensor of shape [batch,nres]
stores resulting dihedrals
"""
b0 = a - b
b1 = c - b
b2 = d - c
b1n = b1 / (torch.norm(b1, dim=-1, keepdim=True) + eps)
v = b0 - torch.sum(b0*b1n, dim=-1, keepdim=True)*b1n
w = b2 - torch.sum(b2*b1n, dim=-1, keepdim=True)*b1n
x = torch.sum(v*w, dim=-1)
y = torch.sum(torch.cross(b1n,v,dim=-1)*w, dim=-1)
return torch.atan2(y+eps, x+eps)
# ============================================================
def xyz_to_c6d(xyz, params=PARAMS):
"""convert cartesian coordinates into 2d distance
and orientation maps
Parameters
----------
xyz : pytorch tensor of shape [batch,nres,3,3]
stores Cartesian coordinates of backbone N,Ca,C atoms
Returns
-------
c6d : pytorch tensor of shape [batch,nres,nres,4]
stores stacked dist,omega,theta,phi 2D maps
"""
batch = xyz.shape[0]
nres = xyz.shape[1]
# three anchor atoms
N = xyz[:,:,0]
Ca = xyz[:,:,1]
C = xyz[:,:,2]
# recreate Cb given N,Ca,C
Cb = generate_Cbeta(N,Ca,C)
# 6d coordinates order: (dist,omega,theta,phi)
c6d = torch.zeros([batch,nres,nres,4],dtype=xyz.dtype,device=xyz.device)
dist = get_pair_dist(Cb,Cb)
dist[torch.isnan(dist)] = 999.9
c6d[...,0] = dist + 999.9*torch.eye(nres,device=xyz.device)[None,...]
b,i,j = torch.where(c6d[...,0]<params['DMAX'])
c6d[b,i,j,torch.full_like(b,1)] = get_dih(Ca[b,i], Cb[b,i], Cb[b,j], Ca[b,j])
c6d[b,i,j,torch.full_like(b,2)] = get_dih(N[b,i], Ca[b,i], Cb[b,i], Cb[b,j])
c6d[b,i,j,torch.full_like(b,3)] = get_ang(Ca[b,i], Cb[b,i], Cb[b,j])
# fix long-range distances
c6d[...,0][c6d[...,0]>=params['DMAX']] = 999.9
mask = torch.zeros((batch, nres,nres), dtype=xyz.dtype, device=xyz.device)
mask[b,i,j] = 1.0
return c6d, mask
def xyz_to_t2d(xyz_t, params=PARAMS):
"""convert template cartesian coordinates into 2d distance
and orientation maps
Parameters
----------
xyz_t : pytorch tensor of shape [batch,templ,nres,3,3]
stores Cartesian coordinates of template backbone N,Ca,C atoms
Returns
-------
t2d : pytorch tensor of shape [batch,nres,nres,37+6+3]
stores stacked dist,omega,theta,phi 2D maps
"""
B, T, L = xyz_t.shape[:3]
c6d, mask = xyz_to_c6d(xyz_t[:,:,:,:3].view(B*T,L,3,3), params=params)
c6d = c6d.view(B, T, L, L, 4)
mask = mask.view(B, T, L, L, 1)
#
# dist to one-hot encoded
dist = dist_to_onehot(c6d[...,0], params)
orien = torch.cat((torch.sin(c6d[...,1:]), torch.cos(c6d[...,1:])), dim=-1)*mask # (B, T, L, L, 6)
#
mask = torch.isnan(c6d[:,:,:,:,0]) # (B, T, L, L)
t2d = torch.cat((dist, orien, mask.unsqueeze(-1)), dim=-1)
t2d[torch.isnan(t2d)] = 0.0
return t2d
def xyz_to_bbtor(xyz, params=PARAMS):
batch = xyz.shape[0]
nres = xyz.shape[1]
# three anchor atoms
N = xyz[:,:,0]
Ca = xyz[:,:,1]
C = xyz[:,:,2]
# recreate Cb given N,Ca,C
next_N = torch.roll(N, -1, dims=1)
prev_C = torch.roll(C, 1, dims=1)
phi = get_dih(prev_C, N, Ca, C)
psi = get_dih(N, Ca, C, next_N)
#
phi[:,0] = 0.0
psi[:,-1] = 0.0
#
astep = 2.0*np.pi / params['ABINS']
phi_bin = torch.round((phi+np.pi-astep/2)/astep)
psi_bin = torch.round((psi+np.pi-astep/2)/astep)
return torch.stack([phi_bin, psi_bin], axis=-1).long()
# ============================================================
def dist_to_onehot(dist, params=PARAMS):
dist[torch.isnan(dist)] = 999.9
dstep = (params['DMAX'] - params['DMIN']) / params['DBINS']
dbins = torch.linspace(params['DMIN']+dstep, params['DMAX'], params['DBINS'],dtype=dist.dtype,device=dist.device)
db = torch.bucketize(dist.contiguous(),dbins).long()
dist = torch.nn.functional.one_hot(db, num_classes=params['DBINS']+1).float()
return dist
# ============================================================
def dist_to_bins(dist,params=PARAMS):
"""bin 2d distance maps
"""
dstep = (params['DMAX'] - params['DMIN']) / params['DBINS']
db = torch.round((dist-params['DMIN']-dstep/2)/dstep)
db[db<0] = 0
db[db>params['DBINS']] = params['DBINS']
return db.long()
# ============================================================
def c6d_to_bins(c6d, same_chain, negative=False, params=PARAMS):
"""bin 2d distance and orientation maps
"""
dstep = (params['DMAX'] - params['DMIN']) / params['DBINS']
astep = 2.0*np.pi / params['ABINS']
db = torch.round((c6d[...,0]-params['DMIN']-dstep/2)/dstep)
ob = torch.round((c6d[...,1]+np.pi-astep/2)/astep)
tb = torch.round((c6d[...,2]+np.pi-astep/2)/astep)
pb = torch.round((c6d[...,3]-astep/2)/astep)
# put all d<dmin into one bin
db[db<0] = 0
# synchronize no-contact bins
db[db>params['DBINS']] = params['DBINS']
ob[db==params['DBINS']] = params['ABINS']
tb[db==params['DBINS']] = params['ABINS']
pb[db==params['DBINS']] = params['ABINS']//2
if negative:
db = torch.where(same_chain.bool(), db.long(), params['DBINS'])
ob = torch.where(same_chain.bool(), ob.long(), params['ABINS'])
tb = torch.where(same_chain.bool(), tb.long(), params['ABINS'])
pb = torch.where(same_chain.bool(), pb.long(), params['ABINS']//2)
return torch.stack([db,ob,tb,pb],axis=-1).long()
def get_init_xyz(seq, xyz_t, same_chain):
# input: xyz_t (B, T, L, Natms, 3)
# ouput: xyz (B, T, L, Natms, 3)
B, T, L = xyz_t.shape[:3]
init = torch.full((B,T,L,NTOTAL,3), np.nan, device=xyz_t.device)
na_mask = is_nucleic(seq)
b_nmask,l_nmask = na_mask.nonzero(as_tuple=True)
b_pmask,l_pmask = (~na_mask).nonzero(as_tuple=True)
init[b_pmask,:,l_pmask] = INIT_CRDS[None,...].to(xyz_t.device)
init[b_nmask,:,l_nmask] = INIT_NA_CRDS[None,...].to(xyz_t.device)
if torch.isnan(xyz_t).all():
return init
mask = torch.isnan(xyz_t[:,:,:,:3]).any(dim=-1).any(dim=-1) # (B, T, L)
#
center_CA = ((~mask[:,:,:,None]) * torch.nan_to_num(xyz_t[:,:,:,1,:])).sum(dim=2) / ((~mask[:,:,:,None]).sum(dim=2)+1e-4) # (B, T, 3)
xyz_t = xyz_t - center_CA.view(B,T,1,1,3)
#
idx_s = list()
for i_b in range(B):
for i_T in range(T):
if mask[i_b, i_T].all():
continue
exist_in_templ = torch.where(~mask[i_b, i_T])[0] # (L_sub)
is_same_chain_in_templ = same_chain[i_b][:,~mask[i_b,i_T]].bool() # (L, L_sub)
seqmap = (torch.arange(L, device=xyz_t.device)[:,None] - exist_in_templ[None,:]).abs() # (L, L_sub)
seqmap[~is_same_chain_in_templ] += 99999
seqmap = torch.argmin(seqmap, dim=-1) # (L)
idx = torch.gather(exist_in_templ, -1, seqmap) # (L)
offset_CA = torch.gather(xyz_t[i_b, i_T, :, 1, :], 0, idx.reshape(L,1).expand(-1,3))
init[i_b,i_T] += offset_CA.reshape(L,1,3)
#
xyz = torch.where(mask.view(B, T, L, 1, 1), init, xyz_t)
return xyz

812
RF2_allatom/loss.py Normal file
View File

@@ -0,0 +1,812 @@
import torch
import numpy as np
from util import (
rigid_from_3_points,
cb_lengths_CN,
cb_angles_CACN,
cb_angles_CNCA,
cb_torsions_CACNH,
cb_torsions_CANCO,
is_nucleic
)
from chemical import NFRAMES
from kinematics import get_dih, get_ang
from scoring import HbHybType
# Loss functions for the training
# 1. BB rmsd loss
# 2. distance loss (or 6D loss?)
# 3. bond geometry loss
# 4. predicted lddt loss
#fd use improved coordinate frame generation
def get_t(N, Ca, C, eps=1e-5):
I,B,L=N.shape[:3]
Rs,Ts = rigid_from_3_points(N.view(I*B,L,3), Ca.view(I*B,L,3), C.view(I*B,L,3), eps=eps)
Rs = Rs.view(I,B,L,3,3)
Ts = Ts.view(I,B,L,3)
t = Ts.unsqueeze(-2) - Ts.unsqueeze(-3)
return torch.einsum('iblkj, iblmk -> iblmj', Rs, t) # (I,B,L,L,3) **fixed
def calc_str_loss(pred, true, mask_2d, same_chain, negative=False, d_clamp_intra=10.0, d_clamp_inter=30.0, A=10.0, gamma=0.99, eps=1e-6):
'''
Calculate Backbone FAPE loss
Input:
- pred: predicted coordinates (I, B, L, n_atom, 3)
- true: true coordinates (B, L, n_atom, 3)
Output: str loss
'''
I = pred.shape[0]
true = true.unsqueeze(0)
t_tilde_ij = get_t(true[:,:,:,0], true[:,:,:,1], true[:,:,:,2])
t_ij = get_t(pred[:,:,:,0], pred[:,:,:,1], pred[:,:,:,2])
difference = torch.sqrt(torch.square(t_tilde_ij-t_ij).sum(dim=-1) + eps)
clamp = torch.zeros_like(difference)
clamp[:,same_chain==1] = d_clamp_intra
clamp[:,same_chain==0] = d_clamp_inter
difference = torch.clamp(difference, max=clamp)
loss = difference / A # (I, B, L, L)
# Get a mask information (ignore missing residue + inter-chain residues)
# for positive cases, mask = mask_2d
# for negative cases (non-interacting pairs) mask = mask_2d*same_chain
if negative:
mask = mask_2d * same_chain
else:
mask = mask_2d
# calculate masked loss (ignore missing regions when calculate loss)
loss = (mask[None]*loss).sum(dim=(1,2,3)) / (mask.sum()+eps) # (I)
# weighting loss
w_loss = torch.pow(torch.full((I,), gamma, device=pred.device), torch.arange(I, device=pred.device))
w_loss = torch.flip(w_loss, (0,))
w_loss = w_loss / w_loss.sum()
tot_loss = (w_loss * loss).sum()
return tot_loss, loss.detach()
#resolve rotationally equivalent sidechains
def resolve_symmetry(xs, Rsnat_all, xsnat, Rsnat_all_alt, xsnat_alt, atm_mask):
dists = torch.linalg.norm( xs[:,:,None,:] - xs[atm_mask,:][None,None,:,:], dim=-1)
dists_nat = torch.linalg.norm( xsnat[:,:,None,:] - xsnat[atm_mask,:][None,None,:,:], dim=-1)
dists_natalt = torch.linalg.norm( xsnat_alt[:,:,None,:] - xsnat_alt[atm_mask,:][None,None,:,:], dim=-1)
drms_nat = torch.sum(torch.abs(dists_nat-dists),dim=(-1,-2))
drms_natalt = torch.sum(torch.abs(dists_nat-dists_natalt), dim=(-1,-2))
Rsnat_symm = Rsnat_all
xs_symm = xsnat
toflip = drms_natalt<drms_nat
Rsnat_symm[toflip,...] = Rsnat_all_alt[toflip,...]
xs_symm[toflip,...] = xsnat_alt[toflip,...]
return Rsnat_symm, xs_symm
# resolve "equivalent" natives
def resolve_equiv_natives(xs, natstack, maskstack):
if (len(natstack.shape)==4):
return natstack, maskstack
if (natstack.shape[1]==1):
return natstack[:,0,...], maskstack[:,0,...]
dx = torch.norm( xs[:,None,:,None,1,:]-xs[:,None,None,:,1,:], dim=-1)
dnat = torch.norm( natstack[:,:,:,None,1,:]-natstack[:,:,None,:,1,:], dim=-1)
delta = torch.sum( torch.abs(dnat-dx), dim=(-2,-1))
return natstack[:,torch.argmin(delta),...], maskstack[:,torch.argmin(delta),...]
#torsion angle predictor loss
def torsionAngleLoss( alpha, alphanat, alphanat_alt, tors_mask, tors_planar, eps=1e-8 ):
I = alpha.shape[0]
lnat = torch.sqrt( torch.sum( torch.square(alpha), dim=-1 ) + eps )
anorm = alpha / (lnat[...,None])
l_tors_ij = torch.min(
torch.sum(torch.square( anorm - alphanat[None] ),dim=-1),
torch.sum(torch.square( anorm - alphanat_alt[None] ),dim=-1)
)
l_tors = torch.sum( l_tors_ij*tors_mask[None] ) / (torch.sum( tors_mask )*I + eps)
l_norm = torch.sum( torch.abs(lnat-1.0)*tors_mask[None] ) / (torch.sum( tors_mask )*I + eps)
l_planar = torch.sum( torch.abs( alpha[...,0] )*tors_planar[None] ) / (torch.sum( tors_planar )*I + eps)
return l_tors+0.02*l_norm+0.02*l_planar
def compute_FAPE(Rs, Ts, xs, Rsnat, Tsnat, xsnat, Z=10.0, dclamp=10.0, eps=1e-4):
xij = torch.einsum('rji,rsj->rsi', Rs, xs[None,...] - Ts[:,None,...])
xij_t = torch.einsum('rji,rsj->rsi', Rsnat, xsnat[None,...] - Tsnat[:,None,...])
#torch.norm(xij-xij_t,dim=-1)
diff = torch.sqrt( torch.sum( torch.square(xij-xij_t), dim=-1 ) + eps )
loss = (1.0/Z) * (torch.clamp(diff, max=dclamp)).mean()
return loss
# from Ivan: FAPE generalized over atom sets & frames
def compute_general_FAPE(X, Y, atom_mask, frames, frame_mask, Z=10.0, dclamp=10.0, eps=1e-4):
# X (predicted) N x L x 27 x 3
# Y (native) 1 x L x 27 x 3
# atom_mask 1 x L x 27
# frames 1 x L x 6 x 3 x 2
# frame_mask 1 x L x 6
N, L, natoms, _ = X.shape
# flatten middle dims so can gather across residues
X_prime = X.reshape(N, L*natoms, -1, 3).repeat(1,1,NFRAMES,1)
Y_prime = Y.reshape(1, L*natoms, -1, 3).repeat(1,1,NFRAMES,1)
# reindex frames for flat X
frames_reindex = torch.zeros(frames.shape[:-1], device=frames.device)
for i in range(L):
frames_reindex[:, i, :, :] = (i+frames[..., i, :, :, 0])*natoms + frames[..., i, :, :, 1]
frames_reindex = frames_reindex.long()
frame_mask *= torch.all(
torch.gather(atom_mask.reshape(1, L*natoms),1,frames_reindex.reshape(1,L*NFRAMES*3)).reshape(1,L,-1,3),
axis=-1)
X_x = torch.gather(X_prime, 1, frames_reindex[...,0:1].repeat(N,1,1,3))
X_y = torch.gather(X_prime, 1, frames_reindex[...,1:2].repeat(N,1,1,3))
X_z = torch.gather(X_prime, 1, frames_reindex[...,2:3].repeat(N,1,1,3))
uX,tX = rigid_from_3_points(X_x, X_y, X_z)
Y_x = torch.gather(Y_prime, 1, frames_reindex[...,0:1].repeat(1,1,1,3))
Y_y = torch.gather(Y_prime, 1, frames_reindex[...,1:2].repeat(1,1,1,3))
Y_z = torch.gather(Y_prime, 1, frames_reindex[...,2:3].repeat(1,1,1,3))
uY,tY = rigid_from_3_points(Y_x, Y_y, Y_z)
xij = torch.einsum(
'brji,brsj->brsi',
uX[:,frame_mask[0]], X[:,atom_mask[0]][:,None,...] - X_y[:,frame_mask[0]][:,:,None,...]
)
xij_t = torch.einsum('rji,rsj->rsi', uY[frame_mask], Y[atom_mask][None,...] - Y_y[frame_mask][:,None,...])
diff = torch.sqrt( torch.sum( torch.square(xij-xij_t[None,...]), dim=-1 ) + eps )
loss = (1.0/Z) * (torch.clamp(diff, max=dclamp)).mean(dim=(1,2))
return loss
def angle(a, b, c, eps=1e-6):
'''
Calculate cos/sin angle between ab and cb
a,b,c have shape of (B, L, 3)
'''
B,L = a.shape[:2]
u1 = a-b
u2 = c-b
u1_norm = torch.norm(u1, dim=-1, keepdim=True) + eps
u2_norm = torch.norm(u2, dim=-1, keepdim=True) + eps
# normalize u1 & u2 --> make unit vector
u1 = u1 / u1_norm
u2 = u2 / u2_norm
u1 = u1.reshape(B*L, 3)
u2 = u2.reshape(B*L, 3)
# sin_theta = norm(a cross b)/(norm(a)*norm(b))
# cos_theta = norm(a dot b) / (norm(a)*norm(b))
sin_theta = torch.norm(torch.cross(u1, u2, dim=1), dim=1, keepdim=True).reshape(B, L, 1) # (B,L,1)
cos_theta = torch.matmul(u1[:,None,:], u2[:,:,None]).reshape(B, L, 1)
return torch.cat([cos_theta, sin_theta], axis=-1) # (B, L, 2)
def length(a, b):
return torch.norm(a-b, dim=-1)
def torsion(a,b,c,d, eps=1e-6):
#A function that takes in 4 atom coordinates:
# a - [B,L,3]
# b - [B,L,3]
# c - [B,L,3]
# d - [B,L,3]
# and returns cos and sin of the dihedral angle between those 4 points in order a, b, c, d
# output - [B,L,2]
u1 = b-a
u1 = u1 / (torch.norm(u1, dim=-1, keepdim=True) + eps)
u2 = c-b
u2 = u2 / (torch.norm(u2, dim=-1, keepdim=True) + eps)
u3 = d-c
u3 = u3 / (torch.norm(u3, dim=-1, keepdim=True) + eps)
#
t1 = torch.cross(u1, u2, dim=-1) #[B, L, 3]
t2 = torch.cross(u2, u3, dim=-1)
t1_norm = torch.norm(t1, dim=-1, keepdim=True)
t2_norm = torch.norm(t2, dim=-1, keepdim=True)
cos_angle = torch.matmul(t1[:,:,None,:], t2[:,:,:,None])[:,:,0]
sin_angle = torch.norm(u2, dim=-1,keepdim=True)*(torch.matmul(u1[:,:,None,:], t2[:,:,:,None])[:,:,0])
cos_sin = torch.cat([cos_angle, sin_angle], axis=-1)/(t1_norm*t2_norm+eps) #[B,L,2]
return cos_sin
# ideal N-C distance, ideal cos(CA-C-N angle), ideal cos(C-N-CA angle)
# for NA, we do not compute this as it is not computable from the stubs alone
def calc_BB_bond_geom(
seq, pred, idx, eps=1e-6,
ideal_NC=1.329, ideal_CACN=-0.4415, ideal_CNCA=-0.5255,
sig_len=0.02, sig_ang=0.05):
'''
Calculate backbone bond geometry (bond length and angle) and put loss on them
Input:
- pred: predicted coords (B, L, :, 3), 0; N / 1; CA / 2; C
- true: True coords (B, L, :, 3)
Output:
- bond length loss, bond angle loss
'''
def cosangle( A,B,C ):
AB = A-B
BC = C-B
ABn = torch.sqrt( torch.sum(torch.square(AB),dim=-1) + eps)
BCn = torch.sqrt( torch.sum(torch.square(BC),dim=-1) + eps)
return torch.clamp(torch.sum(AB*BC,dim=-1)/(ABn*BCn), -0.999,0.999)
B, L = pred.shape[:2]
bonded = (idx[:,1:] - idx[:,:-1])==1
is_prot = ~is_nucleic(seq)[:-1]
# bond length: C-N
blen_CN_pred = length(pred[:,:-1,2], pred[:,1:,0]).reshape(B,L-1) # (B, L-1)
CN_loss = torch.clamp( torch.abs(blen_CN_pred - ideal_NC) - sig_len, min=0.0 )
CN_loss = (bonded*is_prot*CN_loss).sum() / ((bonded*is_prot).sum() + eps)
blen_loss = CN_loss #fd squared loss
# bond angle: CA-C-N, C-N-CA
bang_CACN_pred = cosangle(pred[:,:-1,2], pred[:,1:,0], pred[:,1:,1]).reshape(B,L-1)
bang_CNCA_pred = cosangle(pred[:,:-1,2], pred[:,1:,0], pred[:,1:,1]).reshape(B,L-1)
CACN_loss = torch.clamp( torch.abs(bang_CACN_pred - ideal_CACN) - sig_ang, min=0.0 )
CACN_loss = (bonded*is_prot*CACN_loss).sum() / ((bonded*is_prot).sum() + eps)
CNCA_loss = torch.clamp( torch.abs(bang_CNCA_pred - ideal_CNCA) - sig_ang, min=0.0 )
CNCA_loss = (bonded*is_prot*CNCA_loss).sum() / ((bonded*is_prot).sum() + eps)
bang_loss = CACN_loss + CNCA_loss
return blen_loss+bang_loss
def calc_cart_bonded(seq, pred, idx, len_param, ang_param, tor_param, eps=1e-6):
# pred: N x L x 27 x 3
# idx: 1 x L
# seq: 1 x L
def gen_ang( A,B,C ):
AB = A-B
BC = C-B
ABn = torch.sqrt( torch.sum(torch.square(AB),dim=-1) + eps)
BCn = torch.sqrt( torch.sum(torch.square(BC),dim=-1) + eps)
return torch.acos( torch.clamp(torch.sum(AB*BC,dim=-1)/(ABn*BCn), -0.999,0.999) )
# quadratic from [-1,1], linear elsewhere
def boundfunc(X):
Y = torch.abs(X)
Y[Y<1.0] = torch.square(Y[Y<1.0])
#Y = torch.square(X)
return Y
N,L = pred.shape[:2]
cb_loss = torch.zeros(N, device=pred.device)
## intra-res
cblens = len_param[seq]
len_idx = cblens[...,:2].to(torch.long).reshape(1,L,-1,1).repeat(N,1,1,3)
len_all = torch.gather(pred, 2, len_idx).reshape(N,L,-1,2,3)
len_mask = cblens[...,0]!=cblens[...,1]
E_cb_len = (
len_mask[None,...] *
cblens[None,...,3] *
boundfunc( length(len_all[...,0,:],len_all[...,1,:]) - cblens[...,2] )
).sum(dim=(0,3)) / len_mask.sum()
# figure out which his are his_d
cblens[seq==8] = len_param[-1]
len_idx = cblens[...,:2].to(torch.long).reshape(1,L,-1,1).repeat(N,1,1,3)
len_all_a = torch.gather(pred, 2, len_idx).reshape(N,L,-1,2,3)
len_mask_a = cblens[...,0]!=cblens[...,1]
E_cb_len_a = (
len_mask_a[None,...] *
cblens[None,...,3] *
boundfunc( length(len_all_a[...,0,:],len_all_a[...,1,:]) - cblens[...,2] )
).sum(dim=(0,3)) / len_mask.sum() # N,L
is_his_d = (seq==8)*(E_cb_len_a<E_cb_len)
cb_loss += torch.min(E_cb_len_a,E_cb_len).sum(dim=1)
cbangs = ang_param[seq].repeat(N,1,1,1)
cbangs[is_his_d] = ang_param[-1]
ang_idx = cbangs[...,:3].to(torch.long).reshape(N,L,-1,1).repeat(1,1,1,3)
ang_all = torch.gather(pred, 2, ang_idx).reshape(N,L,-1,3,3)
ang_mask = cbangs[...,0]!=cbangs[...,1]
E_cb_ang = (
ang_mask[None,...] *
cbangs[None,...,4] *
boundfunc( get_ang(ang_all[...,0,:],ang_all[...,1,:],ang_all[...,2,:]) - cbangs[None,...,3] )
).sum(dim=(0,2,3)) / ang_mask.sum()
cb_loss += E_cb_ang
cbtors = tor_param[seq].repeat(N,1,1,1)
cbtors[is_his_d] = tor_param[-1]
tor_idx = cbtors[...,:4].to(torch.long).reshape(N,L,-1,1).repeat(1,1,1,3)
tor_all = torch.gather(pred, 2, tor_idx).reshape(N,L,-1,4,3)
tor_mask = cbtors[...,0]!=cbtors[...,1]
offset = 2*np.pi/cbtors[None,...,6]
tor_deltas = (
get_dih(
tor_all[...,0,:],tor_all[...,1,:],tor_all[...,2,:],tor_all[...,3,:]
) - cbtors[None,...,4] + 0.5*offset
) % offset - 0.5*offset
dihs = get_dih(
tor_all[...,0,:],tor_all[...,1,:],tor_all[...,2,:],tor_all[...,3,:]
)
E_cb_tor = (
tor_mask[None,...] *
cbtors[None,...,5] *
boundfunc( tor_deltas )
).sum(dim=(0,2,3)) / tor_mask.sum()
cb_loss += E_cb_tor
# inter-res
# bond length: C-N
bonded = (idx[:,1:] - idx[:,:-1])==1
blen_CN_pred = length(pred[:,:-1,2], pred[:,1:,0]).reshape(N,L-1) # (B, L-1)
CN_loss = cb_lengths_CN[1] * boundfunc(blen_CN_pred - cb_lengths_CN[0])
cb_loss += (bonded*CN_loss).sum(dim=1) / (bonded.sum())
# bond angle: CA-C-N, C-N-CA
bang_CACN_pred = get_ang(pred[:,:-1,2], pred[:,1:,0], pred[:,1:,1]).reshape(N,L-1)
CACN_loss = cb_angles_CACN[1] * boundfunc(bang_CACN_pred - cb_angles_CACN[0])
cb_loss += (bonded*CACN_loss).sum(dim=1) / (bonded.sum())
bang_CNCA_pred = get_ang(pred[:,:-1,2], pred[:,1:,0], pred[:,1:,1]).reshape(N,L-1)
CNCA_loss = cb_angles_CNCA[1] * boundfunc(bang_CNCA_pred - cb_angles_CNCA[0])
cb_loss += (bonded*CNCA_loss).sum(dim=1) / (bonded.sum())
# improper torsions CA-C-N-H (CD-C-N-CA), CA-N-C-O
# planarity around N (H for non-pro, CD for pro)
atom4idx = torch.full_like(seq, 14)
atom4idx[seq==14] = 6 # set to CD for proline
atom4 = torch.gather( pred, 2, atom4idx[:,:,None,None].repeat(1,1,1,3) )
btor_CACNH_delta = (
get_dih(
pred[:,:-1,1], pred[:,:-1,2], pred[:,1:,0], atom4[:,1:,0]
) - cb_torsions_CACNH[0] + np.pi/2
) % np.pi - np.pi/2
CACNH_loss = cb_torsions_CACNH[1] * boundfunc( btor_CACNH_delta )
cb_loss += (bonded*CACNH_loss).sum(dim=1) / (bonded.sum())
# planarity around C
btor_CANCO_delta = (
get_dih(
pred[:,:-1,1], pred[:,1:,0], pred[:,:-1,2], pred[:,:-1,3]
) - cb_torsions_CANCO[0] + np.pi/2
) % np.pi - np.pi/2
CANCO_loss = cb_torsions_CANCO[1] * boundfunc( btor_CANCO_delta )
cb_loss += (bonded*CANCO_loss).sum(dim=1) / (bonded.sum())
return cb_loss
# AF2-like version of clash score
def calc_clash(xs, mask):
DISTCUT=2.0 # (d_lit - tau) from AF2 MS
L = xs.shape[0]
dij = torch.sqrt(
torch.sum( torch.square( xs[:,:,None,None,:]-xs[None,None,:,:,:] ), dim=-1 ) + 1e-8
)
allmask = mask[:,:,None,None]*mask[None,None,:,:]
allmask[torch.arange(L),:,torch.arange(L),:] = False # ignore res-self
allmask[torch.arange(1,L),0,torch.arange(L-1),2] = False # ignore N->C
allmask[torch.arange(L-1),2,torch.arange(1,L),0] = False # ignore N->C
clash = torch.sum( torch.clamp(DISTCUT-dij[allmask],0.0) ) / torch.sum(mask)
return clash
# Rosetta-like version of LJ (fa_atr+fa_rep)
# lj_lin is switch from linear to 12-6. Smaller values more sharply penalize clashes
def calc_lj(
seq, xs, aamask, ljparams, ljcorr, num_bonds,
lj_lin=0.85, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75,
lj_maxrad=-1.0, eps=1e-8
):
def ljV(dist, sigma, epsilon, lj_lin, lj_maxrad):
N = dist.shape[0]
linpart = dist<lj_lin*sigma[None]
deff = dist.clone()
deff[linpart] = lj_lin*sigma.repeat(N,1)[linpart]
sd = sigma[None] / deff
sd2 = sd*sd
sd6 = sd2 * sd2 * sd2
sd12 = sd6 * sd6
ljE = epsilon * (sd12 - 2 * sd6)
ljE[linpart] += epsilon.repeat(N,1)[linpart] * (
-12 * sd12[linpart]/deff[linpart] + 12 * sd6[linpart]/deff[linpart]
) * (dist[linpart]-deff[linpart])
if (lj_maxrad>0):
sdmax = sigma / lj_maxrad
sd2 = sd*sd
sd6 = sd2 * sd2 * sd2
sd12 = sd6 * sd6
ljE = ljE - epsilon * (sd12 - 2 * sd6)
return ljE
N, L = xs.shape[:2]
# mask keeps running total of what to compute
mask = aamask[seq][...,None,None]*aamask[seq][None,None,...]
idxes1r = torch.tril_indices(L,L,-1)
mask[idxes1r[0],:,idxes1r[1],:] = False
idxes2r = torch.arange(L)
idxes2a = torch.tril_indices(27,27,0)
mask[idxes2r[:,None],idxes2a[0:1],idxes2r[:,None],idxes2a[1:2]] = False
# "countpair" can be enforced by making this a weight
mask[idxes2r,:,idxes2r,:] *= num_bonds[seq,:,:] >= 4 #intra-res
mask[idxes2r[:-1],:,idxes2r[1:],:] *= (
num_bonds[seq[:-1],:,2:3] + num_bonds[seq[1:],0:1,:] + 1 >= 4 #inter-res
)
si,ai,sj,aj = mask.nonzero(as_tuple=True)
ds = torch.sqrt( torch.sum ( torch.square( xs[:,si,ai]-xs[:,sj,aj] ), dim=-1 ) + eps )
# hbond correction
use_hb_dis = (
ljcorr[seq[si],ai,0]*ljcorr[seq[sj],aj,1]
+ ljcorr[seq[si],ai,1]*ljcorr[seq[sj],aj,0] )
use_ohdon_dis = ( # OH are both donors & acceptors
ljcorr[seq[si],ai,0]*ljcorr[seq[si],ai,1]*ljcorr[seq[sj],aj,0]
+ljcorr[seq[si],ai,0]*ljcorr[seq[sj],aj,0]*ljcorr[seq[sj],aj,1]
)
use_hb_hdis = (
ljcorr[seq[si],ai,2]*ljcorr[seq[sj],aj,1]
+ljcorr[seq[si],ai,1]*ljcorr[seq[sj],aj,2]
)
# disulfide correction
potential_disulf = ljcorr[seq[si],ai,3]*ljcorr[seq[sj],aj,3]
ljrs = ljparams[seq[si],ai,0] + ljparams[seq[sj],aj,0]
ljrs[use_hb_dis] = lj_hb_dis
ljrs[use_ohdon_dis] = lj_OHdon_dis
ljrs[use_hb_hdis] = lj_hbond_hdis
ljss = torch.sqrt( ljparams[seq[si],ai,1] * ljparams[seq[sj],aj,1] + eps )
ljss [potential_disulf] = 0.0
ljval = ljV(ds,ljrs,ljss,lj_lin,lj_maxrad)
return (torch.sum( ljval, dim=-1 )/torch.sum(aamask[seq]))
def calc_hb(
seq, xs, aamask, hbtypes, hbbaseatoms, hbpolys,
hb_sp2_range_span=1.6, hb_sp2_BAH180_rise=0.75, hb_sp2_outer_width=0.357,
hb_sp3_softmax_fade=2.5, threshold_distance=6.0, eps=1e-8, normalize=True
):
def evalpoly( ds, xrange, yrange, coeffs ):
v = coeffs[...,0]
for i in range(1,10):
v = v * ds + coeffs[...,i]
minmask = ds<xrange[...,0]
v[minmask] = yrange[minmask][...,0]
maxmask = ds>xrange[...,1]
v[maxmask] = yrange[maxmask][...,1]
return v
def cosangle( A,B,C ):
AB = A-B
BC = C-B
ABn = torch.sqrt( torch.sum(torch.square(AB),dim=-1) + eps)
BCn = torch.sqrt( torch.sum(torch.square(BC),dim=-1) + eps)
return torch.clamp(torch.sum(AB*BC,dim=-1)/(ABn*BCn), -0.999,0.999)
hbts = hbtypes[seq]
hbba = hbbaseatoms[seq]
rh,ah = (hbts[...,0]>=0).nonzero(as_tuple=True)
ra,aa = (hbts[...,1]>=0).nonzero(as_tuple=True)
D_xs = xs[rh,hbba[rh,ah,0]][:,None,:]
H_xs = xs[rh,ah][:,None,:]
A_xs = xs[ra,aa][None,:,:]
B_xs = xs[ra,hbba[ra,aa,0]][None,:,:]
B0_xs = xs[ra,hbba[ra,aa,1]][None,:,:]
hyb = hbts[ra,aa,2]
polys = hbpolys[hbts[rh,ah,0][:,None],hbts[ra,aa,1][None,:]]
AH = torch.sqrt( torch.sum( torch.square( H_xs-A_xs), axis=-1) + eps )
AHD = torch.acos( cosangle( B_xs, A_xs, H_xs) )
Es = polys[...,0,0]*evalpoly(
AH,polys[...,0,1:3],polys[...,0,3:5],polys[...,0,5:])
Es += polys[...,1,0] * evalpoly(
AHD,polys[...,1,1:3],polys[...,1,3:5],polys[...,1,5:])
Bm = 0.5*(B0_xs[:,hyb==HbHybType.RING]+B_xs[:,hyb==HbHybType.RING])
cosBAH = cosangle( Bm, A_xs[:,hyb==HbHybType.RING], H_xs )
Es[:,hyb==HbHybType.RING] += polys[:,hyb==HbHybType.RING,2,0] * evalpoly(
cosBAH,
polys[:,hyb==HbHybType.RING,2,1:3],
polys[:,hyb==HbHybType.RING,2,3:5],
polys[:,hyb==HbHybType.RING,2,5:])
cosBAH1 = cosangle( B_xs[:,hyb==HbHybType.SP3], A_xs[:,hyb==HbHybType.SP3], H_xs )
cosBAH2 = cosangle( B0_xs[:,hyb==HbHybType.SP3], A_xs[:,hyb==HbHybType.SP3], H_xs )
Esp3_1 = polys[:,hyb==HbHybType.SP3,2,0] * evalpoly(
cosBAH1,
polys[:,hyb==HbHybType.SP3,2,1:3],
polys[:,hyb==HbHybType.SP3,2,3:5],
polys[:,hyb==HbHybType.SP3,2,5:])
Esp3_2 = polys[:,hyb==HbHybType.SP3,2,0] * evalpoly(
cosBAH2,
polys[:,hyb==HbHybType.SP3,2,1:3],
polys[:,hyb==HbHybType.SP3,2,3:5],
polys[:,hyb==HbHybType.SP3,2,5:])
Es[:,hyb==HbHybType.SP3] += torch.log(
torch.exp(Esp3_1 * hb_sp3_softmax_fade)
+ torch.exp(Esp3_2 * hb_sp3_softmax_fade)
) / hb_sp3_softmax_fade
cosBAH = cosangle( B_xs[:,hyb==HbHybType.SP2], A_xs[:,hyb==HbHybType.SP2], H_xs )
Es[:,hyb==HbHybType.SP2] += polys[:,hyb==HbHybType.SP2,2,0] * evalpoly(
cosBAH,
polys[:,hyb==HbHybType.SP2,2,1:3],
polys[:,hyb==HbHybType.SP2,2,3:5],
polys[:,hyb==HbHybType.SP2,2,5:])
BAH = torch.acos( cosBAH )
B0BAH = get_dih(B0_xs[:,hyb==HbHybType.SP2], B_xs[:,hyb==HbHybType.SP2], A_xs[:,hyb==HbHybType.SP2], H_xs)
d,m,l = hb_sp2_BAH180_rise, hb_sp2_range_span, hb_sp2_outer_width
Echi = torch.full_like( B0BAH, m-0.5 )
mask1 = BAH>np.pi * 2.0 / 3.0
H = 0.5 * (torch.cos(2 * B0BAH) + 1)
F = d / 2 * torch.cos(3 * (np.pi - BAH[mask1])) + d / 2 - 0.5
Echi[mask1] = H[mask1] * F + (1 - H[mask1]) * d - 0.5
mask2 = BAH>np.pi * (2.0 / 3.0 - l)
mask2 *= ~mask1
outer_rise = torch.cos(np.pi - (np.pi * 2 / 3 - BAH[mask2]) / l)
F = m / 2 * outer_rise + m / 2 - 0.5
G = (m - d) / 2 * outer_rise + (m - d) / 2 + d - 0.5
Echi[mask2] = H[mask2] * F + (1 - H[mask2]) * d - 0.5
Es[:,hyb==HbHybType.SP2] += polys[:,hyb==HbHybType.SP2,2,0] * Echi
tosquish = torch.logical_and(Es > -0.1,Es < 0.1)
Es[tosquish] = -0.025 + 0.5 * Es[tosquish] - 2.5 * torch.square(Es[tosquish])
Es[Es > 0.1] = 0.
if (normalize):
return (torch.sum( Es ) / torch.sum(aamask[seq]))
else:
return torch.sum( Es )
@torch.enable_grad()
def calc_BB_bond_geom_grads(seq, pred, idx, eps=1e-6, ideal_NC=1.329, ideal_CACN=-0.4415, ideal_CNCA=-0.5255, sig_len=0.02, sig_ang=0.05):
pred.requires_grad_(True)
Ebond = calc_BB_bond_geom(seq, pred, idx, eps, ideal_NC, ideal_CACN, ideal_CNCA, sig_len, sig_ang)
return torch.autograd.grad(Ebond, pred)
@torch.enable_grad()
def calc_cart_bonded_grads(seq, pred, idx, len_param, ang_param, tor_param, eps=1e-6):
pred.requires_grad_(True)
Ecb = calc_cart_bonded(seq, pred, idx, len_param, ang_param, tor_param, eps)
return torch.autograd.grad(Ecb, pred)
@torch.enable_grad()
def calc_ljallatom_grads(
seq, xyzaa,
aamask, ljparams, ljcorr, num_bonds,
lj_lin=0.85, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75,
lj_maxrad=-1.0, eps=1e-8
):
xyzaa.requires_grad_(True)
Elj = calc_lj(
seq[0],
xyzaa[...,:3],
aamask,
ljparams,
ljcorr,
num_bonds,
lj_lin,
lj_hb_dis,
lj_OHdon_dis,
lj_hbond_hdis,
lj_maxrad,
eps
)
return torch.autograd.grad(Elj, (xyzaa,))
@torch.enable_grad()
def calc_lj_grads(
seq, xyz, alpha, toaa,
aamask, ljparams, ljcorr, num_bonds,
lj_lin=0.85, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75,
lj_maxrad=-1.0, eps=1e-8
):
xyz.requires_grad_(True)
alpha.requires_grad_(True)
_, xyzaa = toaa(seq, xyz, alpha)
Elj = calc_lj(
seq[0],
xyzaa[...,:3],
aamask,
ljparams,
ljcorr,
num_bonds,
lj_lin,
lj_hb_dis,
lj_OHdon_dis,
lj_hbond_hdis,
lj_maxrad,
eps
)
return torch.autograd.grad(Elj, (xyz,alpha))
@torch.enable_grad()
def calc_hb_grads(
seq, xyz, alpha, toaa,
aamask, hbtypes, hbbaseatoms, hbpolys,
hb_sp2_range_span=1.6, hb_sp2_BAH180_rise=0.75, hb_sp2_outer_width=0.357,
hb_sp3_softmax_fade=2.5, threshold_distance=6.0, eps=1e-8, normalize=True
):
xyz.requires_grad_(True)
alpha.requires_grad_(True)
_, xyzaa = toaa(seq, xyz, alpha)
Ehb = calc_hb(
seq,
xyzaa[0,...,:3],
aamask,
hbtypes,
hbbaseatoms,
hbpolys,
hb_sp2_range_span,
hb_sp2_BAH180_rise,
hb_sp2_outer_width,
hb_sp3_softmax_fade,
threshold_distance,
eps,
normalize)
return torch.autograd.grad(Ehb, xs)
def calc_pseudo_dih(pred, true, eps=1e-6):
'''
calculate pseudo CA dihedral angle and put loss on them
Input:
- predicted & true CA coordinates (I,B,L,3) / (B, L, 3)
Output:
- dihedral angle loss
'''
I, B, L = pred.shape[:3]
pred = pred.reshape(I*B, L, -1)
true_dih = torsion(true[:,:-3,:],true[:,1:-2,:],true[:,2:-1,:],true[:,3:,:]) # (B, L', 2)
pred_dih = torsion(pred[:,:-3,:],pred[:,1:-2,:],pred[:,2:-1,:],pred[:,3:,:]) # (I*B, L', 2)
pred_dih = pred_dih.reshape(I, B, -1, 2)
dih_loss = torch.square(pred_dih - true_dih).sum(dim=-1).mean()
dih_loss = torch.sqrt(dih_loss + eps)
return dih_loss
def calc_lddt(pred_ca, true_ca, mask_crds, mask_2d, same_chain, negative=False, interface=False, eps=1e-6):
# Input
# pred_ca: predicted CA coordinates (I, B, L, 3)
# true_ca: true CA coordinates (B, L, 3)
# pred_lddt: predicted lddt values (I-1, B, L)
I, B, L = pred_ca.shape[:3]
pred_dist = torch.cdist(pred_ca, pred_ca) # (I, B, L, L)
true_dist = torch.cdist(true_ca, true_ca).unsqueeze(0) # (1, B, L, L)
mask = torch.logical_and(true_dist > 0.0, true_dist < 15.0) # (1, B, L, L)
# update mask information
mask *= mask_2d[None]
if negative:
mask *= same_chain.bool()[None]
elif interface:
# ignore atoms between the same chain
mask *= ~same_chain.bool()[None]
mask_crds = mask_crds * (mask[0].sum(dim=-1) != 0)
delta = torch.abs(pred_dist-true_dist) # (I, B, L, L)
true_lddt = torch.zeros((I,B,L), device=pred_ca.device)
for distbin in [0.5, 1.0, 2.0, 4.0]:
true_lddt += 0.25*torch.sum((delta<=distbin)*mask, dim=-1) / (torch.sum(mask, dim=-1) + eps)
true_lddt = mask_crds*true_lddt
true_lddt = true_lddt.sum(dim=(1,2)) / (mask_crds.sum() + eps)
return true_lddt
#fd allatom lddt
def calc_allatom_lddt(P, Q, idx, atm_mask, eps=1e-6):
# P - N x L x 27 x 3
# Q - L x 27 x 3
N, L = P.shape[:2]
# distance matrix
Pij = torch.square(P[:,:,None,:,None,:]-P[:,None,:,None,:,:]) # (N, L, L, 27, 27)
Pij = torch.sqrt( Pij.sum(dim=-1) + eps)
Qij = torch.square(Q[None,:,None,:,None,:]-Q[None,None,:,None,:,:]) # (1, L, L, 27, 27)
Qij = torch.sqrt( Qij.sum(dim=-1) + eps)
# get valid pairs
pair_mask = torch.logical_and(Qij>0,Qij<15).float() # only consider atom pairs within 15A
# ignore missing atoms
pair_mask *= (atm_mask[:,:,None,:,None] * atm_mask[:,None,:,None,:]).float()
# ignore atoms within same residue
pair_mask *= (idx[:,:,None,None,None] != idx[:,None,:,None,None]).float() # (1, L, L, 27, 27)
delta_PQ = torch.abs(Pij-Qij+eps) # (N, L, L, 14, 14)
lddt = torch.zeros( (N,L,27), device=P.device ) # (N, L, 27)
for distbin in (0.5,1.0,2.0,4.0):
lddt += 0.25 * torch.sum( (delta_PQ<=distbin)*pair_mask, dim=(2,4)
) / ( torch.sum( pair_mask, dim=(2,4) ) + 1e-8)
lddt = (lddt * atm_mask).sum(dim=(1,2)) / (atm_mask.sum() + eps)
return lddt
def calc_allatom_lddt_loss(P, Q, pred_lddt, idx, atm_mask, mask_2d, same_chain, negative=False, interface=False, eps=1e-6):
# P - N x L x 27 x 3
# Q - L x 27 x 3
# pred_lddt - 1 x nbucket x L
N, L, Natm = P.shape[:3]
# distance matrix
Pij = torch.square(P[:,:,None,:,None,:]-P[:,None,:,None,:,:]) # (N, L, L, 27, 27)
Pij = torch.sqrt( Pij.sum(dim=-1) + eps)
Qij = torch.square(Q[None,:,None,:,None,:]-Q[None,None,:,None,:,:]) # (1, L, L, 27, 27)
Qij = torch.sqrt( Qij.sum(dim=-1) + eps)
# get valid pairs
pair_mask = torch.logical_and(Qij>0,Qij<15).float() # only consider atom pairs within 15A
# ignore missing atoms
pair_mask *= (atm_mask[:,:,None,:,None] * atm_mask[:,None,:,None,:]).float()
# ignore atoms within same residue
pair_mask *= (idx[:,:,None,None,None] != idx[:,None,:,None,None]).float() # (1, L, L, 27, 27)
if negative:
# ignore atoms between different chains
pair_mask *= same_chain.bool()[:,:,:,None,None]
delta_PQ = torch.abs(Pij-Qij+eps) # (N, L, L, 14, 14)
lddt = torch.zeros( (N,L,Natm), device=P.device ) # (N, L, 27)
for distbin in (0.5,1.0,2.0,4.0):
lddt += 0.25 * torch.sum( (delta_PQ<=distbin)*pair_mask, dim=(2,4)
) / ( torch.sum( pair_mask, dim=(2,4) ) + eps)
final_lddt_by_res = torch.clamp(
(lddt[-1]*atm_mask[0]).sum(-1)
/ (atm_mask.sum(-1) + eps), min=0.0, max=1.0)
# calculate lddt prediction loss
nbin = pred_lddt.shape[1]
bin_step = 1.0 / nbin
lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device)
true_lddt_label = torch.bucketize(final_lddt_by_res[None,...], lddt_bins).long()
lddt_loss = torch.nn.CrossEntropyLoss(reduction='none')(
pred_lddt, true_lddt_label[-1])
res_mask = atm_mask.any(dim=-1)
lddt_loss = (lddt_loss * res_mask).sum() / (res_mask.sum() + eps)
# method 1: average per-residue
#lddt = lddt.sum(dim=-1) / (atm_mask.sum(dim=-1)+1e-8) # L
#lddt = (res_mask*lddt).sum() / (res_mask.sum() + 1e-8)
# method 2: average per-atom
atm_mask = atm_mask * (pair_mask.sum(dim=(1,3)) != 0)
lddt = (lddt * atm_mask).sum(dim=(1,2)) / (atm_mask.sum() + eps)
return lddt_loss, lddt

57
RF2_allatom/memory.py Normal file
View File

@@ -0,0 +1,57 @@
import gc
import torch
## MEM utils ##
def mem_report():
'''Report the memory usage of the tensor.storage in pytorch
Both on CPUs and GPUs are reported'''
def _mem_report(tensors, mem_type):
'''Print the selected tensors of type
There are two major storage types in our major concern:
- GPU: tensors transferred to CUDA devices
- CPU: tensors remaining on the system memory (usually unimportant)
Args:
- tensors: the tensors of specified type
- mem_type: 'CPU' or 'GPU' in current implementation '''
print('Storage on %s' %(mem_type))
print('-'*LEN)
total_numel = 0
total_mem = 0
visited_data = []
for tensor in tensors:
if tensor.is_sparse:
continue
# a data_ptr indicates a memory block allocated
data_ptr = tensor.storage().data_ptr()
if data_ptr in visited_data:
continue
visited_data.append(data_ptr)
numel = tensor.storage().size()
total_numel += numel
element_size = tensor.storage().element_size()
mem = numel*element_size /1024/1024 # 32bit=4Byte, MByte
total_mem += mem
element_type = type(tensor).__name__
size = tuple(tensor.size())
print('%s\t\t%s\t\t%.2f' % (
element_type,
size,
mem) )
print('-'*LEN)
print('Total Tensors: %d \tUsed Memory Space: %.2f MBytes' % (total_numel, total_mem) )
print('-'*LEN)
LEN = 65
print('='*LEN)
objects = gc.get_objects()
print('%s\t%s\t\t\t%s' %('Element type', 'Size', 'Used MEM(MBytes)') )
tensors = [obj for obj in objects if torch.is_tensor(obj)]
cuda_tensors = [t for t in tensors if t.is_cuda]
host_tensors = [t for t in tensors if not t.is_cuda]
_mem_report(cuda_tensors, 'GPU')
_mem_report(host_tensors, 'CPU')
print('='*LEN)

439
RF2_allatom/parsers.py Normal file
View File

@@ -0,0 +1,439 @@
import numpy as np
import scipy
import scipy.spatial
import string
import os,re
from os.path import exists
import random
import util
import gzip
from ffindex import *
import torch
from chemical import NAATOKENS, aa2num, aa2long
from rdkit import Chem
to1letter = {
"ALA":'A', "ARG":'R', "ASN":'N', "ASP":'D', "CYS":'C',
"GLN":'Q', "GLU":'E', "GLY":'G', "HIS":'H', "ILE":'I',
"LEU":'L', "LYS":'K', "MET":'M', "PHE":'F', "PRO":'P',
"SER":'S', "THR":'T', "TRP":'W', "TYR":'Y', "VAL":'V',
"DA":'a', "DC":'c', "DG":'g', "DT":'t',
"A":'b', "C":'d', "G":'h', "U":'u',
}
def read_template_pdb(L, pdb_fn, target_chain=None):
# get full sequence from given PDB
seq_full = list()
prev_chain=''
with open(pdb_fn) as fp:
for line in fp:
if line[:4] != "ATOM":
continue
if line[12:16].strip() != "CA":
continue
if line[21] != prev_chain:
if len(seq_full) > 0:
L_s.append(len(seq_full)-offset)
offset = len(seq_full)
prev_chain = line[21]
aa = line[17:20]
seq_full.append(aa2num[aa] if aa in aa2num.keys() else 20)
seq_full = torch.tensor(seq_full).long()
xyz = torch.full((L, 36, 3), np.nan).float()
seq = torch.full((L,), 20).long()
conf = torch.zeros(L,1).float()
with open(pdb_fn) as fp:
for line in fp:
if line[:4] != "ATOM":
continue
resNo, atom, aa = int(line[22:26]), line[12:16], line[17:20]
aa_idx = aa2num[aa] if aa in aa2num.keys() else 20
#
idx = resNo - 1
for i_atm, tgtatm in enumerate(aa2long[aa_idx]):
if tgtatm == atom:
xyz[idx, i_atm, :] = torch.tensor([float(line[30:38]), float(line[38:46]), float(line[46:54])])
break
seq[idx] = aa_idx
mask = torch.logical_not(torch.isnan(xyz[:,:3,0])) # (L, 3)
mask = mask.all(dim=-1)[:,None]
conf = torch.where(mask, torch.full((L,1),0.1), torch.zeros(L,1)).float()
seq_1hot = torch.nn.functional.one_hot(seq, num_classes=32).float()
t1d = torch.cat((seq_1hot, conf), -1)
#return seq_full[None], ins[None], L_s, xyz[None], t1d[None]
return xyz[None], t1d[None]
def parse_fasta(filename, maxseq=10000, rmsa_alphabet=False):
msa = []
ins = []
fstream = open(filename,"r")
for line in fstream:
# skip labels
if line[0] == '>':
continue
# remove right whitespaces
line = line.rstrip()
if len(line) == 0:
continue
# remove lowercase letters and append to MSA
msa.append(line)
# sequence length
L = len(msa[-1])
i = np.zeros((L))
ins.append(i)
# convert letters into numbers
if rmsa_alphabet:
alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
else:
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8)
msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
for i in range(alphabet.shape[0]):
msa[msa == alphabet[i]] = i
ins = np.array(ins, dtype=np.uint8)
return msa,ins
# parse a fasta alignment IF it exists
# otherwise return single-sequence msa
def parse_fasta_if_exists(seq, filename, maxseq=10000, rmsa_alphabet=False):
if (exists(filename)):
return parse_fasta(filename, maxseq, rmsa_alphabet)
else:
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8) # -0 are UNK/mask
seq = np.array([list(seq)], dtype='|S1').view(np.uint8)
for i in range(alphabet.shape[0]):
seq[seq == alphabet[i]] = i
return (seq, np.zeros_like(seq))
# read A3M and convert letters into
# integers in the 0..20 range,
# also keep track of insertions
def parse_a3m(filename, unzip=True, maxseq=10000):
msa = []
ins = []
table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
# read file line by line
if (unzip):
fstream = gzip.open(filename,"rt")
else:
fstream = open(filename,"r")
for line in fstream:
# skip labels
if line[0] == '>':
continue
# remove right whitespaces
line = line.rstrip()
if len(line) == 0:
continue
# remove lowercase letters and append to MSA
msa.append(line.translate(table))
# sequence length
L = len(msa[-1])
# remove insertion at the end
if (not unzip):
n_remove = 0
for c in reversed(line):
if c.islower():
n_remove += 1
else:
break
line = line[:-n_remove]
# 0 - match or gap; 1 - insertion
a = np.array([0 if c.isupper() or c=='-' else 1 for c in line])
i = np.zeros((L))
if np.sum(a) > 0:
# positions of insertions
pos = np.where(a==1)[0]
# shift by occurrence
a = pos - np.arange(pos.shape[0])
# position of insertions in cleaned sequence
# and their length
pos,num = np.unique(a, return_counts=True)
# append to the matrix of insetions
i[pos] = num
ins.append(i)
if (len(msa) >= maxseq):
break
# convert letters into numbers
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
for i in range(alphabet.shape[0]):
msa[msa == alphabet[i]] = i
# treat all unknown characters as gaps
msa[msa > 20] = 20
ins = np.array(ins, dtype=np.uint8)
return msa,ins
# read and extract xyz coords of N,Ca,C atoms
# from a PDB file
def parse_pdb(filename):
lines = open(filename,'r').readlines()
return parse_pdb_lines(lines)
#'''
def parse_pdb_lines(lines):
# indices of residues observed in the structure
idx_s = [int(l[22:26]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"]
# 4 BB + up to 10 SC atoms
xyz = np.full((len(idx_s), 27, 3), np.nan, dtype=np.float32)
for l in lines:
if l[:4] != "ATOM":
continue
resNo, atom, aa = int(l[22:26]), l[12:16], l[17:20]
idx = idx_s.index(resNo)
for i_atm, tgtatm in enumerate(aa2long[aa2num[aa]]):
if tgtatm == atom:
xyz[idx,i_atm,:] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
break
# save atom mask
mask = np.logical_not(np.isnan(xyz[...,0]))
xyz[np.isnan(xyz[...,0])] = 0.0
return xyz,mask,np.array(idx_s)
def parse_pdb_lines_w_seq(lines):
# indices of residues observed in the structure
#idx_s = [int(l[22:26]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"]
res = [(l[22:26],l[17:20]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"]
idx_s = [int(r[0]) for r in res]
seq = [aa2num[r[1]] if r[1] in aa2num.keys() else 20 for r in res]
# 4 BB + up to 10 SC atoms
xyz = np.full((len(idx_s), 27, 3), np.nan, dtype=np.float32)
for l in lines:
if l[:4] != "ATOM":
continue
resNo, atom, aa = int(l[22:26]), l[12:16], l[17:20]
idx = idx_s.index(resNo)
for i_atm, tgtatm in enumerate(aa2long[aa2num[aa]]):
if tgtatm == atom:
xyz[idx,i_atm,:] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
break
# save atom mask
mask = np.logical_not(np.isnan(xyz[...,0]))
xyz[np.isnan(xyz[...,0])] = 0.0
return xyz,mask,np.array(idx_s), np.array(seq)
def parse_templates(item, params):
# init FFindexDB of templates
### and extract template IDs
### present in the DB
ffdb = FFindexDB(read_index(params['FFDB']+'_pdb.ffindex'),
read_data(params['FFDB']+'_pdb.ffdata'))
#ffids = set([i.name for i in ffdb.index])
# process tabulated hhsearch output to get
# matched positions and positional scores
infile = params['DIR']+'/hhr/'+item[-2:]+'/'+item+'.atab'
hits = []
for l in open(infile, "r").readlines():
if l[0]=='>':
key = l[1:].split()[0]
hits.append([key,[],[]])
elif "score" in l or "dssp" in l:
continue
else:
hi = l.split()[:5]+[0.0,0.0,0.0]
hits[-1][1].append([int(hi[0]),int(hi[1])])
hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])])
# get per-hit statistics from an .hhr file
# (!!! assume that .hhr and .atab have the same hits !!!)
# [Probab, E-value, Score, Aligned_cols,
# Identities, Similarity, Sum_probs, Template_Neff]
lines = open(infile[:-4]+'hhr', "r").readlines()
pos = [i+1 for i,l in enumerate(lines) if l[0]=='>']
for i,posi in enumerate(pos):
hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]])
# parse templates from FFDB
for hi in hits:
#if hi[0] not in ffids:
# continue
entry = get_entry_by_name(hi[0], ffdb.index)
if entry == None:
continue
data = read_entry_lines(entry, ffdb.data)
hi += list(parse_pdb_lines(data))
# process hits
counter = 0
xyz,qmap,mask,f0d,f1d,ids = [],[],[],[],[],[]
for data in hits:
if len(data)<7:
continue
qi,ti = np.array(data[1]).T
_,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True)
ncol = sel1.shape[0]
if ncol < 10:
continue
ids.append(data[0])
f0d.append(data[3])
f1d.append(np.array(data[2])[sel1])
xyz.append(data[4][sel2])
mask.append(data[5][sel2])
qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1))
counter += 1
xyz = np.vstack(xyz).astype(np.float32)
mask = np.vstack(mask).astype(np.bool)
qmap = np.vstack(qmap).astype(np.long)
f0d = np.vstack(f0d).astype(np.float32)
f1d = np.vstack(f1d).astype(np.float32)
ids = ids
return xyz,mask,qmap,f0d,f1d,ids
def parse_templates_raw(ffdb, hhr_fn, atab_fn):
# process tabulated hhsearch output to get
# matched positions and positional scores
hits = []
for l in open(atab_fn, "r").readlines():
if l[0]=='>':
key = l[1:].split()[0]
hits.append([key,[],[]])
elif "score" in l or "dssp" in l:
continue
else:
hi = l.split()[:5]+[0.0,0.0,0.0]
hits[-1][1].append([int(hi[0]),int(hi[1])])
hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])])
# get per-hit statistics from an .hhr file
# (!!! assume that .hhr and .atab have the same hits !!!)
# [Probab, E-value, Score, Aligned_cols,
# Identities, Similarity, Sum_probs, Template_Neff]
lines = open(hhr_fn, "r").readlines()
pos = [i+1 for i,l in enumerate(lines) if l[0]=='>']
for i,posi in enumerate(pos):
hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]])
# parse templates from FFDB
for hi in hits:
#if hi[0] not in ffids:
# continue
entry = get_entry_by_name(hi[0], ffdb.index)
if entry == None:
continue
data = read_entry_lines(entry, ffdb.data)
hi += list(parse_pdb_lines_w_seq(data))
# process hits
counter = 0
xyz,qmap,mask,f0d,f1d,ids,seq = [],[],[],[],[],[],[]
for data in hits:
if len(data)<7:
continue
qi,ti = np.array(data[1]).T
_,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True)
ncol = sel1.shape[0]
if ncol < 10:
continue
ids.append(data[0])
f0d.append(data[3])
f1d.append(np.array(data[2])[sel1])
xyz.append(data[4][sel2])
mask.append(data[5][sel2])
seq.append(data[-1][sel2])
qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1))
counter += 1
xyz = np.vstack(xyz).astype(np.float32)
qmap = np.vstack(qmap).astype(np.long)
f1d = np.vstack(f1d).astype(np.float32)
seq = np.hstack(seq).astype(np.long)
ids = ids
return torch.from_numpy(xyz), torch.from_numpy(qmap), \
torch.from_numpy(f1d), torch.from_numpy(seq), ids
def read_templates(qlen, ffdb, hhr_fn, atab_fn, n_templ=10):
xyz_t, qmap, t1d, seq, ids = parse_templates_raw(ffdb, hhr_fn, atab_fn)
npick = min(n_templ, len(ids))
if npick < 1: # no templates
xyz = torch.full((1,qlen,27,3),np.nan).float()
t1d = torch.nn.functional.one_hot(torch.full((1, qlen), 20).long(), num_classes=21).float() # all gaps
t1d = torch.cat((t1d, torch.zeros((1,qlen,1)).float()), -1)
return xyz, t1d
sample = torch.arange(npick)
#
xyz = torch.full((npick, qlen, 27, 3), np.nan).float()
f1d = torch.full((npick, qlen), 20).long()
f1d_val = torch.zeros((npick, qlen, 1)).float()
#
for i, nt in enumerate(sample):
sel = torch.where(qmap[:,1] == nt)[0]
pos = qmap[sel, 0]
xyz[i, pos] = xyz_t[sel]
f1d[i, pos] = seq[sel]
f1d_val[i,pos] = t1d[sel, 2].unsqueeze(-1)
f1d = torch.nn.functional.one_hot(f1d, num_classes=21).float()
f1d = torch.cat((f1d, f1d_val), dim=-1)
return xyz, f1d
def parse_mol(filename):
"""parses a mol file"""
mol = Chem.MolFromMolFile(filename)
msa = torch.tensor([aa2num[a.GetSymbol()] for a in mol.GetAtoms()])
ins = torch.zeros_like(msa)
return mol, msa, ins
def get_ligand_xyz(mol):
xyz = torch.tensor(np.array([c.GetPositions() for c in mol.GetConformers()])).squeeze(0)
permuts = mol.GetSubstructMatches(mol, uniquify=False, maxMatches=256)
permuts = torch.tensor(permuts)
Y = xyz[permuts].reshape(-1,mol.GetNumAtoms(),3)
mask = torch.full(Y.shape[:-1], True)
return Y, mask

View File

@@ -0,0 +1,334 @@
import sys, os
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
from parsers import parse_a3m, read_templates
from RoseTTAFoldModel import RoseTTAFoldModule
import util
from collections import namedtuple
from ffindex import *
from data_loader import MSAFeaturize, MSABlockDeletion
from kinematics import xyz_to_c6d, c6d_to_bins2, xyz_to_t2d, get_init_xyz
from util_module import ComputeAllAtomCoords
from memory import mem_report
MAX_CYCLE = 20
NREPLICATES = 5
NBIN = [37, 37, 37, 19]
MODEL_PARAM ={
"n_extra_block" : 4,
"n_main_block" : 32,
"n_ref_block" : 0,
"n_finetune_block" : 4,
"d_msa" : 256 ,
"d_pair" : 128,
"d_templ" : 64,
"n_head_msa" : 8,
"n_head_pair" : 4,
"n_head_templ" : 4,
"d_hidden" : 32,
"d_hidden_templ" : 64,
"p_drop" : 0.0,
"lj_lin" : 0.7
}
SE3_param = {
"num_layers" : 1,
"num_channels" : 32,
"num_degrees" : 2,
"l0_in_features": 64,
"l0_out_features": 64,
"l1_in_features": 3,
"l1_out_features": 2,
"num_edge_features": 64,
"div": 4,
"n_heads": 4
}
MODEL_PARAM['SE3_param'] = SE3_param
# params for the folding protocol
fold_params = {
"SG7" : np.array([[[-2,3,6,7,6,3,-2]]])/21,
"SG9" : np.array([[[-21,14,39,54,59,54,39,14,-21]]])/231,
"DCUT" : 19.5,
"ALPHA" : 1.57,
# TODO: add Cb to the motif
"NCAC" : np.array([[-0.676, -1.294, 0. ],
[ 0. , 0. , 0. ],
[ 1.5 , -0.174, 0. ]], dtype=np.float32),
"CLASH" : 2.0,
"PCUT" : 0.5,
"DSTEP" : 0.5,
"ASTEP" : np.deg2rad(10.0),
"XYZRAD" : 7.5,
"WANG" : 0.1,
"WCST" : 0.1
}
fold_params["SG"] = fold_params["SG9"]
# compute expected value from binned lddt
def lddt_unbin(pred_lddt):
nbin = pred_lddt.shape[1]
bin_step = 1.0 / nbin
lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device)
pred_lddt = nn.Softmax(dim=1)(pred_lddt)
return torch.sum(lddt_bins[None,:,None]*pred_lddt, dim=1)
class Predictor():
def __init__(self, model_name="BFF", model_dir=None, device="cuda:0"):
if model_dir == None:
self.model_dir = "%s/models"%(os.path.dirname(os.path.abspath(__file__)))
else:
self.model_dir = model_dir
#
# define model name
self.model_name = model_name
self.device = device
self.active_fn = nn.Softmax(dim=1)
# define model & load model
self.model = RoseTTAFoldModule(
**MODEL_PARAM,
aamask=util.allatom_mask.to(self.device),
atom_type_index=util.atom_type_index.to(self.device),
ljlk_parameters=util.ljlk_parameters.to(self.device),
lj_correction_parameters=util.lj_correction_parameters.to(self.device),
num_bonds=util.num_bonds.to(self.device),
cb_len = util.cb_length_t.to(self.device),
cb_ang = util.cb_angle_t.to(self.device),
cb_tor = util.cb_torsion_t.to(self.device),
).to(self.device)
could_load = self.load_model(self.model_name)
if not could_load:
print ("ERROR: failed to load model")
sys.exit()
self.compute_allatom_coords = ComputeAllAtomCoords().to(self.device)
def load_model(self, model_name, suffix='last'):
chk_fn = "%s/%s_%s.pt"%(self.model_dir, model_name, suffix)
if not os.path.exists(chk_fn):
return False
checkpoint = torch.load(chk_fn, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
return True
def predict(self, a3m_fn, out_prefix, hhr_fn=None, atab_fn=None, window=1e9, shift=50, n_latent=256):
msa_orig, ins_orig = parse_a3m(a3m_fn, unzip=False)
N, L = msa_orig.shape
#
if os.path.exists(hhr_fn):
xyz_t, t1d = read_templates(L, ffdb, hhr_fn, atab_fn, n_templ=4)
else:
xyz_t = torch.full((1,L,3,3),np.nan).float()
t1d = torch.nn.functional.one_hot(torch.full((1, L), 20).long(), num_classes=21).float() # all gaps
t1d = torch.cat((t1d, torch.zeros((1,L,1)).float()), -1)
#
# template features
xyz_t = xyz_t.float().unsqueeze(0)
t1d = t1d.float().unsqueeze(0)
t2d = xyz_to_t2d(xyz_t)
xyz_t = get_init_xyz(xyz_t) # initialize coordinates with first template
seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L)
alpha, _, alpha_mask, _ = util.get_torsions(
xyz_t.reshape(-1,L,27,3),
seq_tmp,
util.torsion_indices,
util.torsion_can_flip,
util.reference_angles
)
alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0]))
alpha[torch.isnan(alpha)] = 0.0
alpha = alpha.reshape(1,-1,L,10,2)
alpha_mask = alpha_mask.reshape(1,-1,L,10,1)
alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(1, -1, L, 30)
self.model.eval()
for i_trial in range(NREPLICATES):
if os.path.exists("%s_%02d_init.pdb"%(out_prefix, i_trial)):
continue
self.run_prediction(msa_orig, ins_orig, t1d, t2d, xyz_t, xyz_t[:,0], alpha_t, "%s_%02d"%(out_prefix, i_trial), n_latent=n_latent)
torch.cuda.empty_cache()
def run_prediction(self, msa_orig, ins_orig, t1d, t2d, xyz_t, xyz, alpha_t, out_prefix, n_latent=256):
start = time.time()
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
#
msa = torch.tensor(msa_orig).long().to(self.device) # (N, L)
ins = torch.tensor(ins_orig).long().to(self.device)
if msa_orig.shape[0] > 4096:
msa, ins = MSABlockDeletion(msa, ins)
print (msa_orig.shape, msa.shape)
#
seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(msa, ins, p_mask=0.0)
_, N, L = msa_seed.shape[:3]
B = 1
#
idx_pdb = torch.arange(L).long().view(1, L)
#
seq = seq.unsqueeze(0)
msa_seed = msa_seed.unsqueeze(0)
msa_extra = msa_extra.unsqueeze(0)
t1d = t1d.to(self.device)
t2d = t2d.to(self.device)
idx_pdb = idx_pdb.to(self.device)
xyz_t = xyz_t.to(self.device)
alpha_t = alpha_t.to(self.device)
xyz = xyz.to(self.device)
self.write_pdb(seq[0, -1], xyz[0], prefix="%s_templ"%(out_prefix))
msa_prev = None
pair_prev = None
alpha_prev = torch.zeros((1,L,10,2), device=seq.device)
xyz_prev=xyz
state_prev = None
best_lddt = torch.tensor([-1.0], device=seq.device)
best_xyz = None
best_logit = None
best_aa = None
for i_cycle in range(MAX_CYCLE):
with torch.cuda.amp.autocast(True):
logit_s, logit_aa_s, init_crds, alpha_prev, init_allatom, pred_lddt_binned, msa_prev, pair_prev, state_prev = self.model(
msa_seed[:,i_cycle],
msa_extra[:,i_cycle],
seq[:,i_cycle],
xyz_prev,
alpha_prev,
idx_pdb,
t1d=t1d,
t2d=t2d,
xyz_t=xyz_t,
alpha_t=alpha_t,
msa_prev=msa_prev,
pair_prev=pair_prev,
state_prev=state_prev
)
logit_aa_s = logit_aa_s.reshape(B,-1,N,L)[:,:,0].permute(0,2,1)
#xyz_prev = init_crds[-1]
xyz_prev = init_allatom[-1].unsqueeze(0)
#msa_prev = msa_prev[:,0]
alpha_prev = alpha_prev[-1]
pred_lddt = lddt_unbin(pred_lddt_binned)
print ("RECYCLE", i_cycle, pred_lddt.mean(), best_lddt.mean())
#_, all_crds = self.compute_allatom_coords(seq[:,i_cycle], init_crds[-1], alpha_prev)
self.write_pdb(seq[0, -1], init_allatom[-1], Bfacts=pred_lddt[0], prefix="%s_cycle_%02d"%(out_prefix, i_cycle))
if pred_lddt.mean() < best_lddt.mean():
continue
best_xyz = init_allatom[-1].clone()
best_logit = logit_s
best_aa = logit_aa_s
best_lddt = pred_lddt.clone()
prob_s = list()
for logit in logit_s:
prob = self.active_fn(logit.float()) # distogram
prob = prob.reshape(-1, L, L) #.permute(1,2,0).cpu().numpy()
prob_s.append(prob)
end = time.time()
for prob in prob_s:
prob += 1e-8
prob = prob / torch.sum(prob, dim=0)[None]
self.write_pdb(seq[0, -1], best_xyz, Bfacts=best_lddt[0], prefix="%s_init"%(out_prefix))
prob_s = [prob.permute(1,2,0).detach().cpu().numpy().astype(np.float16) for prob in prob_s]
np.savez_compressed("%s.npz"%(out_prefix), dist=prob_s[0].astype(np.float16), \
omega=prob_s[1].astype(np.float16),\
theta=prob_s[2].astype(np.float16),\
phi=prob_s[3].astype(np.float16),\
lddt=best_lddt[0].detach().cpu().numpy().astype(np.float16))
max_mem = torch.cuda.max_memory_allocated()/1e9
print ("max mem", max_mem)
print ("runtime", end-start)
def write_pdb(self, seq, atoms, Bfacts=None, prefix=None):
L = len(seq)
filename = "%s.pdb"%prefix
ctr = 1
with open(filename, 'wt') as f:
if Bfacts == None:
Bfacts = np.zeros(L)
else:
Bfacts = torch.clamp( Bfacts, 0, 1)
for i,s in enumerate(seq):
if (len(atoms.shape)==2):
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
"ATOM", ctr, " CA ", util.num2aa[s],
"A", i+1, atoms[i,0], atoms[i,1], atoms[i,2],
1.0, Bfacts[i] ) )
ctr += 1
elif atoms.shape[1]==3:
for j,atm_j in enumerate((" N "," CA "," C ")):
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
"ATOM", ctr, atm_j, util.num2aa[s],
"A", i+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2],
1.0, Bfacts[i] ) )
ctr += 1
else:
atms = util.aa2long[s]
for j,atm_j in enumerate(atms):
if (atm_j is not None):
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
"ATOM", ctr, atm_j, util.num2aa[s],
"A", i+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2],
1.0, Bfacts[i] ) )
ctr += 1
def get_args():
#DB="/home/robetta/rosetta_server_beta/external/databases/trRosetta/pdb100_2021Mar03/pdb100_2021Mar03"
DB = "/projects/ml/TrRosetta/pdb100_2020Mar11/pdb100_2020Mar11"
import argparse
parser = argparse.ArgumentParser(description="RoseTTAFold: Protein structure prediction with 3-track attentions on 1D, 2D, and 3D features")
parser.add_argument("-db", default=DB, required=False,
help="HHsearch database [%s]"%DB)
parser.add_argument("-model_name", default="BFF", required=False,
help="Prefix for model. The model under models/[model_name]_best.pt will be used. [BFF]")
parser.add_argument("-i", default=1, type=int, required=False, help="Parallelize i of j [1]")
parser.add_argument("-j", default=1, type=int, required=False, help="Parallelize i of j [1]")
args = parser.parse_args()
return args
casp14_home = "/home/minkbaek/CASP14.bench"
if __name__ == "__main__":
args = get_args()
#tar_s = [line.strip() for line in open("%s/tar_s"%casp14_home)]
tar_s = [line.strip() for line in open("./tar_s")]
FFDB = args.db
FFindexDB = namedtuple("FFindexDB", "index, data")
ffdb = FFindexDB(read_index(FFDB+'_pdb.ffindex'),
read_data(FFDB+'_pdb.ffdata'))
pred = Predictor(model_name=args.model_name)
for i_str,tar in enumerate(tar_s):
if (i_str % args.j == args.i % args.j):
print (tar)
if not os.path.exists(tar):
os.mkdir(tar)
out_prefix = "%s/%s"%(tar, tar)
a3m_fn = "%s/a3m.final/%s.a3m"%(casp14_home,tar)
hhr_fn = "%s/hhr.final/%s.hhr"%(casp14_home,tar)
atab_fn = "%s/hhr.final/%s.atab"%(casp14_home, tar)
if not os.path.exists("%s_04.npz"%out_prefix):
pred.predict(a3m_fn, out_prefix, hhr_fn, atab_fn)

72
RF2_allatom/resnet.py Normal file
View File

@@ -0,0 +1,72 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
# pre-activation bottleneck resblock
class ResBlock2D_bottleneck(nn.Module):
def __init__(self, n_c, kernel=3, dilation=1, p_drop=0.15):
super(ResBlock2D_bottleneck, self).__init__()
padding = self._get_same_padding(kernel, dilation)
n_b = n_c // 2 # bottleneck channel
layer_s = list()
# pre-activation
layer_s.append(nn.InstanceNorm2d(n_c, affine=True, eps=1e-6))
layer_s.append(nn.ELU(inplace=True))
# project down to n_b
layer_s.append(nn.Conv2d(n_c, n_b, 1, bias=False))
layer_s.append(nn.InstanceNorm2d(n_b, affine=True, eps=1e-6))
layer_s.append(nn.ELU(inplace=True))
# convolution
layer_s.append(nn.Conv2d(n_b, n_b, kernel, dilation=dilation, padding=padding, bias=False))
layer_s.append(nn.InstanceNorm2d(n_b, affine=True, eps=1e-6))
layer_s.append(nn.ELU(inplace=True))
# dropout
layer_s.append(nn.Dropout(p_drop))
# project up
layer_s.append(nn.Conv2d(n_b, n_c, 1, bias=False))
# make final layer initialize with zeros
#nn.init.zeros_(layer_s[-1].weight)
self.layer = nn.Sequential(*layer_s)
self.reset_parameter()
def reset_parameter(self):
# zero-initialize final layer right before residual connection
nn.init.zeros_(self.layer[-1].weight)
def _get_same_padding(self, kernel, dilation):
return (kernel + (kernel - 1) * (dilation - 1) - 1) // 2
def forward(self, x):
out = self.layer(x)
return x + out
class ResidualNetwork(nn.Module):
def __init__(self, n_block, n_feat_in, n_feat_block, n_feat_out,
dilation=[1,2,4,8], p_drop=0.15):
super(ResidualNetwork, self).__init__()
layer_s = list()
# project to n_feat_block
if n_feat_in != n_feat_block:
layer_s.append(nn.Conv2d(n_feat_in, n_feat_block, 1, bias=False))
# add resblocks
for i_block in range(n_block):
d = dilation[i_block%len(dilation)]
res_block = ResBlock2D_bottleneck(n_feat_block, kernel=3, dilation=d, p_drop=p_drop)
layer_s.append(res_block)
if n_feat_out != n_feat_block:
# project to n_feat_out
layer_s.append(nn.Conv2d(n_feat_block, n_feat_out, 1))
self.layer = nn.Sequential(*layer_s)
def forward(self, x):
return self.layer(x)

30
RF2_allatom/run.sh Executable file
View File

@@ -0,0 +1,30 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES=1
python -u ./train_multi_EMA.py \
-model_name BFF20h \
-p_drop 0.0 \
-maxcycle 4 \
-n_extra_block 4 \
-n_main_block 32 \
-n_ref_block 4 \
-n_finetune_block 0 \
-ref_num_layers 2 \
-accum 4 \
-crop 256 \
-w_bond 0.0 \
-w_dih 0.0 \
-w_clash 0.0 \
-w_hb 0.0 \
-lj_lin 0.7 \
-w_dist 1.0 \
-w_str 10.0 \
-w_lddt 0.1 \
-w_aa 3.0 \
-subsmp UNI \
-num_epochs 400 \
-slice CONT \
-lr 0.001 \
-port 12345 \
-eval

180
RF2_allatom/scheduler.py Normal file
View File

@@ -0,0 +1,180 @@
import math
import torch
from torch.optim.lr_scheduler import _LRScheduler, LambdaLR
#def get_cosine_with_hard_restarts_schedule_with_warmup(
# optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
#):
# """
# Create a schedule with a learning rate that decreases following the values of the cosine function between the
# initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
# linearly between 0 and the initial lr set in the optimizer.
#
# Args:
# optimizer (:class:`~torch.optim.Optimizer`):
# The optimizer for which to schedule the learning rate.
# num_warmup_steps (:obj:`int`):
# The number of steps for the warmup phase.
# num_training_steps (:obj:`int`):
# The total number of training steps.
# num_cycles (:obj:`int`, `optional`, defaults to 1):
# The number of hard restarts to use.
# last_epoch (:obj:`int`, `optional`, defaults to -1):
# The index of the last epoch when resuming training.
#
# Return:
# :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
# """
#
# def lr_lambda(current_step):
# if current_step < num_warmup_steps:
# return float(current_step) / float(max(1, num_warmup_steps))
# progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
# if progress >= 1.0:
# return 0.0
# return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
#
# return LambdaLR(optimizer, lr_lambda, last_epoch)
#
class CosineAnnealingWarmupRestarts(_LRScheduler):
"""
optimizer (Optimizer): Wrapped optimizer.
first_cycle_steps (int): First cycle step size.
cycle_mult(float): Cycle steps magnification. Default: -1.
max_lr(float): First cycle's max learning rate. Default: 0.1.
min_lr(float): Min learning rate. Default: 0.001.
warmup_steps(int): Linear warmup step size. Default: 0.
gamma(float): Decrease rate of max learning rate by cycle. Default: 1.
last_epoch (int): The index of last epoch. Default: -1.
"""
def __init__(self,
optimizer : torch.optim.Optimizer,
first_cycle_steps : int,
cycle_mult : float = 1.,
max_lr : float = 0.1,
min_lr : float = 0.001,
warmup_steps : int = 0,
gamma : float = 1.,
last_epoch : int = -1
):
assert warmup_steps < first_cycle_steps
self.first_cycle_steps = first_cycle_steps # first cycle step size
self.cycle_mult = cycle_mult # cycle steps magnification
self.base_max_lr = max_lr # first max learning rate
self.max_lr = max_lr # max learning rate in the current cycle
self.min_lr = min_lr # min learning rate
self.warmup_steps = warmup_steps # warmup step size
self.gamma = gamma # decrease rate of max learning rate by cycle
self.cur_cycle_steps = first_cycle_steps # first cycle step size
self.cycle = 0 # cycle count
self.step_in_cycle = last_epoch # step size of the current cycle
super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch)
# set learning rate min_lr
self.init_lr()
def init_lr(self):
self.base_lrs = []
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.min_lr
self.base_lrs.append(self.min_lr)
def get_lr(self):
if self.step_in_cycle == -1:
return self.base_lrs
elif self.step_in_cycle < self.warmup_steps:
return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs]
else:
return [base_lr + (self.max_lr - base_lr) \
* (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \
/ (self.cur_cycle_steps - self.warmup_steps))) / 2
for base_lr in self.base_lrs]
def step(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.step_in_cycle = self.step_in_cycle + 1
if self.step_in_cycle >= self.cur_cycle_steps:
self.cycle += 1
self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
else:
if epoch >= self.first_cycle_steps:
if self.cycle_mult == 1.:
self.step_in_cycle = epoch % self.first_cycle_steps
self.cycle = epoch // self.first_cycle_steps
else:
n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
self.cycle = n
self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n)
else:
self.cur_cycle_steps = self.first_cycle_steps
self.step_in_cycle = epoch
self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
self.last_epoch = math.floor(epoch)
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_ratio=0.001, last_epoch=-1):
"""
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
Args:
optimizer (:class:`~torch.optim.Optimizer`):
The optimizer for which to schedule the learning rate.
num_warmup_steps (:obj:`int`):
The number of steps for the warmup phase.
num_training_steps (:obj:`int`):
The total number of training steps.
last_epoch (:obj:`int`, `optional`, defaults to -1):
The index of the last epoch when resuming training.
Return:
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
def lr_lambda(current_step: int):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return max(
min_ratio, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
)
return LambdaLR(optimizer, lr_lambda, last_epoch)
def get_stepwise_decay_schedule_with_warmup(optimizer, num_warmup_steps, num_steps_decay, decay_rate, last_epoch=-1):
"""
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
Args:
optimizer (:class:`~torch.optim.Optimizer`):
The optimizer for which to schedule the learning rate.
num_warmup_steps (:obj:`int`):
The number of steps for the warmup phase.
num_training_steps (:obj:`int`):
The total number of training steps.
last_epoch (:obj:`int`, `optional`, defaults to -1):
The index of the last epoch when resuming training.
Return:
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
def lr_lambda(current_step: int):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
num_fades = (current_step-num_warmup_steps)//num_steps_decay
return (decay_rate**num_fades)
return LambdaLR(optimizer, lr_lambda, last_epoch)

310
RF2_allatom/scoring.py Normal file
View File

@@ -0,0 +1,310 @@
import json
##
## lk and lk term
#(LJ_RADIUS LJ_WDEPTH LK_DGFREE LK_LAMBDA LK_VOLUME)
type2ljlk = {
"CNH2":(1.968297,0.094638,3.077030,3.5000,13.500000),
"COO":(1.916661,0.141799,-3.332648,3.5000,14.653000),
"CH0":(2.011760,0.062642,1.409284,3.5000,8.998000),
"CH1":(2.011760,0.062642,-3.538387,3.5000,10.686000),
"CH2":(2.011760,0.062642,-1.854658,3.5000,18.331000),
"CH3":(2.011760,0.062642,7.292929,3.5000,25.855000),
"aroC":(2.016441,0.068775,1.797950,3.5000,16.704000),
"Ntrp":(1.802452,0.161725,-8.413116,3.5000,9.522100),
"Nhis":(1.802452,0.161725,-9.739606,3.5000,9.317700),
"NtrR":(1.802452,0.161725,-5.158080,3.5000,9.779200),
"NH2O":(1.802452,0.161725,-8.101638,3.5000,15.689000),
"Nlys":(1.802452,0.161725,-20.864641,3.5000,16.514000),
"Narg":(1.802452,0.161725,-8.968351,3.5000,15.717000),
"Npro":(1.802452,0.161725,-0.984585,3.5000,3.718100),
"OH":(1.542743,0.161947,-8.133520,3.5000,10.722000),
"OHY":(1.542743,0.161947,-8.133520,3.5000,10.722000),
"ONH2":(1.548662,0.182924,-6.591644,3.5000,10.102000),
"OOC":(1.492871,0.099873,-9.239832,3.5000,9.995600),
"S":(1.975967,0.455970,-1.707229,3.5000,17.640000),
"SH1":(1.975967,0.455970,3.291643,3.5000,23.240000),
"Nbb":(1.802452,0.161725,-9.969494,3.5000,15.992000),
"CAbb":(2.011760,0.062642,2.533791,3.5000,12.137000),
"CObb":(1.916661,0.141799,3.104248,3.5000,13.221000),
"OCbb":(1.540580,0.142417,-8.006829,3.5000,12.196000),
"Phos":(2.1500,0.5850,-4.1000,3.5000,14.7000), # phil
"Oet2":(1.5500,0.1591,-5.8500,3.5000,10.8000),
"Oet3":(1.5500,0.1591,-6.7000,3.5000,10.8000),
"HNbb":(0.901681,0.005000,0.0000,3.5000,0.0000),
"Hapo":(1.421272,0.021808,0.0000,3.5000,0.0000),
"Haro":(1.374914,0.015909,0.0000,3.5000,0.0000),
"Hpol":(0.901681,0.005000,0.0000,3.5000,0.0000),
"HS":(0.363887,0.050836,0.0000,3.5000,0.0000),
}
# cartbonded
with open('cartbonded.json', 'r') as j:
cartbonded_data_raw = json.loads(j.read())
# hbond donor/acceptors
class HbAtom:
NO = 0
DO = 1 # donor
AC = 2 # acceptor
DA = 3 # donor & acceptor
HP = 4 # polar H
type2hb = {
"CNH2":HbAtom.NO, "COO":HbAtom.NO, "CH0":HbAtom.NO, "CH1":HbAtom.NO,
"CH2":HbAtom.NO, "CH3":HbAtom.NO, "aroC":HbAtom.NO, "Ntrp":HbAtom.DO,
"Nhis":HbAtom.AC, "NtrR":HbAtom.DO, "NH2O":HbAtom.DO, "Nlys":HbAtom.DO,
"Narg":HbAtom.DO, "Npro":HbAtom.NO, "OH":HbAtom.DA, "OHY":HbAtom.DA,
"ONH2":HbAtom.AC, "OOC":HbAtom.AC, "S":HbAtom.NO, "SH1":HbAtom.NO,
"Nbb":HbAtom.DO, "CAbb":HbAtom.NO, "CObb":HbAtom.NO, "OCbb":HbAtom.AC,
"HNbb":HbAtom.HP, "Hapo":HbAtom.NO, "Haro":HbAtom.NO, "Hpol":HbAtom.HP,
"HS":HbAtom.HP, # HP in rosetta(?)
"Phos":HbAtom.NO, "Oet2":HbAtom.AC, "Oet3":HbAtom.AC
}
##
## hbond term
## TO DO: ADD DNA
class HbDonType:
PBA = 0
IND = 1
IME = 2
GDE = 3
CXA = 4
AMO = 5
HXL = 6
AHX = 7
NTYPES = 8
class HbAccType:
PBA = 0
CXA = 1
CXL = 2
HXL = 3
AHX = 4
IME = 5
NTYPES = 6
class HbHybType:
SP2 = 0
SP3 = 1
RING = 2
NTYPES = 3
type2dontype = {
"Nbb": HbDonType.PBA,
"Ntrp": HbDonType.IND,
"NtrR": HbDonType.GDE,
"Narg": HbDonType.GDE,
"NH2O": HbDonType.CXA,
"Nlys": HbDonType.AMO,
"OH": HbDonType.HXL,
"OHY": HbDonType.AHX,
}
type2acctype = {
"OCbb": HbAccType.PBA,
"ONH2": HbAccType.CXA,
"OOC": HbAccType.CXL,
"OH": HbAccType.HXL,
"OHY": HbAccType.AHX,
"Nhis": HbAccType.IME,
}
type2hybtype = {
"OCbb": HbHybType.SP2,
"ONH2": HbHybType.SP2,
"OOC": HbHybType.SP2,
"OHY": HbHybType.SP3,
"OH": HbHybType.SP3,
"Nhis": HbHybType.RING,
}
dontype2wt = {
HbDonType.PBA: 1.45,
HbDonType.IND: 1.15,
HbDonType.IME: 1.42,
HbDonType.GDE: 1.11,
HbDonType.CXA: 1.29,
HbDonType.AMO: 1.17,
HbDonType.HXL: 0.99,
HbDonType.AHX: 1.00,
}
acctype2wt = {
HbAccType.PBA: 1.19,
HbAccType.CXA: 1.21,
HbAccType.CXL: 1.10,
HbAccType.HXL: 1.15,
HbAccType.AHX: 1.15,
HbAccType.IME: 1.17,
}
class HbPolyType:
ahdist_aASN_dARG = 0
ahdist_aASN_dASN = 1
ahdist_aASN_dGLY = 2
ahdist_aASN_dHIS = 3
ahdist_aASN_dLYS = 4
ahdist_aASN_dSER = 5
ahdist_aASN_dTRP = 6
ahdist_aASN_dTYR = 7
ahdist_aASP_dARG = 8
ahdist_aASP_dASN = 9
ahdist_aASP_dGLY = 10
ahdist_aASP_dHIS = 11
ahdist_aASP_dLYS = 12
ahdist_aASP_dSER = 13
ahdist_aASP_dTRP = 14
ahdist_aASP_dTYR = 15
ahdist_aGLY_dARG = 16
ahdist_aGLY_dASN = 17
ahdist_aGLY_dGLY = 18
ahdist_aGLY_dHIS = 19
ahdist_aGLY_dLYS = 20
ahdist_aGLY_dSER = 21
ahdist_aGLY_dTRP = 22
ahdist_aGLY_dTYR = 23
ahdist_aHIS_dARG = 24
ahdist_aHIS_dASN = 25
ahdist_aHIS_dGLY = 26
ahdist_aHIS_dHIS = 27
ahdist_aHIS_dLYS = 28
ahdist_aHIS_dSER = 29
ahdist_aHIS_dTRP = 30
ahdist_aHIS_dTYR = 31
ahdist_aSER_dARG = 32
ahdist_aSER_dASN = 33
ahdist_aSER_dGLY = 34
ahdist_aSER_dHIS = 35
ahdist_aSER_dLYS = 36
ahdist_aSER_dSER = 37
ahdist_aSER_dTRP = 38
ahdist_aSER_dTYR = 39
ahdist_aTYR_dARG = 40
ahdist_aTYR_dASN = 41
ahdist_aTYR_dGLY = 42
ahdist_aTYR_dHIS = 43
ahdist_aTYR_dLYS = 44
ahdist_aTYR_dSER = 45
ahdist_aTYR_dTRP = 46
ahdist_aTYR_dTYR = 47
cosBAH_off = 48
cosBAH_7 = 49
cosBAH_6i = 50
AHD_1h = 51
AHD_1i = 52
AHD_1j = 53
AHD_1k = 54
# map donor:acceptor pairs to polynomials
hbtypepair2poly = {
(HbDonType.PBA,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
(HbDonType.CXA,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
(HbDonType.IME,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
(HbDonType.IND,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
(HbDonType.AMO,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h),
(HbDonType.GDE,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
(HbDonType.AHX,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.HXL,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.PBA,HbAccType.CXA): (HbPolyType.ahdist_aASN_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.CXA,HbAccType.CXA): (HbPolyType.ahdist_aASN_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.IME,HbAccType.CXA): (HbPolyType.ahdist_aASN_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.IND,HbAccType.CXA): (HbPolyType.ahdist_aASN_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.AMO,HbAccType.CXA): (HbPolyType.ahdist_aASN_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h),
(HbDonType.GDE,HbAccType.CXA): (HbPolyType.ahdist_aASN_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.AHX,HbAccType.CXA): (HbPolyType.ahdist_aASN_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.HXL,HbAccType.CXA): (HbPolyType.ahdist_aASN_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.PBA,HbAccType.CXL): (HbPolyType.ahdist_aASP_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.CXA,HbAccType.CXL): (HbPolyType.ahdist_aASP_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.IME,HbAccType.CXL): (HbPolyType.ahdist_aASP_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.IND,HbAccType.CXL): (HbPolyType.ahdist_aASP_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.AMO,HbAccType.CXL): (HbPolyType.ahdist_aASP_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h),
(HbDonType.GDE,HbAccType.CXL): (HbPolyType.ahdist_aASP_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.AHX,HbAccType.CXL): (HbPolyType.ahdist_aASP_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.HXL,HbAccType.CXL): (HbPolyType.ahdist_aASP_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.PBA,HbAccType.IME): (HbPolyType.ahdist_aHIS_dGLY,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
(HbDonType.CXA,HbAccType.IME): (HbPolyType.ahdist_aHIS_dASN,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
(HbDonType.IME,HbAccType.IME): (HbPolyType.ahdist_aHIS_dHIS,HbPolyType.cosBAH_7,HbPolyType.AHD_1h),
(HbDonType.IND,HbAccType.IME): (HbPolyType.ahdist_aHIS_dTRP,HbPolyType.cosBAH_7,HbPolyType.AHD_1h),
(HbDonType.AMO,HbAccType.IME): (HbPolyType.ahdist_aHIS_dLYS,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
(HbDonType.GDE,HbAccType.IME): (HbPolyType.ahdist_aHIS_dARG,HbPolyType.cosBAH_7,HbPolyType.AHD_1h),
(HbDonType.AHX,HbAccType.IME): (HbPolyType.ahdist_aHIS_dTYR,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
(HbDonType.HXL,HbAccType.IME): (HbPolyType.ahdist_aHIS_dSER,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
(HbDonType.PBA,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dGLY,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.CXA,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dASN,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.IME,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dHIS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.IND,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dTRP,HbPolyType.cosBAH_6i,HbPolyType.AHD_1h),
(HbDonType.AMO,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dLYS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.GDE,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dARG,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.AHX,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dTYR,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.HXL,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dSER,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.PBA,HbAccType.HXL): (HbPolyType.ahdist_aSER_dGLY,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.CXA,HbAccType.HXL): (HbPolyType.ahdist_aSER_dASN,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.IME,HbAccType.HXL): (HbPolyType.ahdist_aSER_dHIS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.IND,HbAccType.HXL): (HbPolyType.ahdist_aSER_dTRP,HbPolyType.cosBAH_6i,HbPolyType.AHD_1h),
(HbDonType.AMO,HbAccType.HXL): (HbPolyType.ahdist_aSER_dLYS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.GDE,HbAccType.HXL): (HbPolyType.ahdist_aSER_dARG,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.AHX,HbAccType.HXL): (HbPolyType.ahdist_aSER_dTYR,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.HXL,HbAccType.HXL): (HbPolyType.ahdist_aSER_dSER,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
}
# polynomials are triplets, (x_min, x_max), (y[x<x_min],y[x>x_max]), (c_9,...,c_0)
hbpolytype2coeffs = { # Parameters imported from rosetta sp2_elec_params @v2017.48-dev59886
HbPolyType.ahdist_aASN_dARG: ((0.7019094761929999, 2.86820307153,),(1.1, 1.1,),( 0.58376113, -9.29345473, 64.86270904, -260.3946711, 661.43138077, -1098.01378958, 1183.58371466, -790.82929582, 291.33125475, -43.01629727,)),
HbPolyType.ahdist_aASN_dASN: ((0.625841094801, 2.75107708444,),(1.1, 1.1,),( -1.31243015, 18.6745072, -112.63858313, 373.32878091, -734.99145504, 861.38324861, -556.21026097, 143.5626977, 20.03238394, -11.52167705,)),
HbPolyType.ahdist_aASN_dGLY: ((0.7477341047139999, 2.6796350782799996,),(1.1, 1.1,),( -1.61294554, 23.3150793, -144.11313069, 496.13575, -1037.83809166, 1348.76826073, -1065.14368678, 473.89008925, -100.41142701, 7.44453515,)),
HbPolyType.ahdist_aASN_dHIS: ((0.344789524346, 2.8303582266000005,),(1.1, 1.1,),( -0.2657122, 4.1073775, -26.9099632, 97.10486507, -209.96002602, 277.33057268, -218.74766996, 97.42852213, -24.07382402, 3.73962807,)),
HbPolyType.ahdist_aASN_dLYS: ((0.542905671869, 2.45259389314,),(1.1, 1.1,),( 1.38531754, -18.48733797, 106.14444613, -344.70585054, 698.91577956, -917.0879402, 775.32787908, -403.09588787, 113.65054778, -11.66516403,)),
HbPolyType.ahdist_aASN_dSER: ((1.0812774602500002, 2.6832123582599996,),(1.1, 1.1,),( -3.51524353, 47.54032873, -254.40168577, 617.84606386, -255.49935027, -2361.56230539, 6426.85797934, -7760.4403891, 4694.08106855, -1149.83549068,)),
HbPolyType.ahdist_aASN_dTRP: ((0.6689984999999999, 3.0704254,),(1.1, 1.1,),( -0.5284840422, 8.3510150838, -56.4100479414, 212.4884326254, -488.3178610608, 703.7762350506, -628.9936994633999, 331.4294356146, -93.265817571, 11.9691623698,)),
HbPolyType.ahdist_aASN_dTYR: ((1.08950268805, 2.6887046709400004,),(1.1, 1.1,),( -4.4488705, 63.27696281, -371.44187037, 1121.71921621, -1638.11394306, 142.99988401, 3436.65879147, -5496.07011787, 3709.30505237, -962.79669688,)),
HbPolyType.ahdist_aASP_dARG: ((0.8100404642229999, 2.9851230124799994,),(1.1, 1.1,),( -0.66430344, 10.41343145, -70.12656205, 265.12578414, -617.05849171, 911.39378582, -847.25013928, 472.09090981, -141.71513167, 18.57721132,)),
HbPolyType.ahdist_aASP_dASN: ((1.05401125073, 3.11129675908,),(1.1, 1.1,),( 0.02090728, -0.24144928, -0.19578075, 16.80904547, -117.70216251, 407.18551288, -809.95195924, 939.83137947, -593.94527692, 159.57610528,)),
HbPolyType.ahdist_aASP_dGLY: ((0.886260952629, 2.66843608743,),(1.1, 1.1,),( -7.00699267, 107.33021779, -713.45752385, 2694.43092298, -6353.05100287, 9667.94098394, -9461.9261027, 5721.0086877, -1933.97818198, 279.47763789,)),
HbPolyType.ahdist_aASP_dHIS: ((1.03597611139, 2.78208509117,),(1.1, 1.1,),( -1.34823406, 17.08925926, -78.75087193, 106.32795459, 400.18459698, -2041.04320193, 4033.83557387, -4239.60530204, 2324.00877252, -519.38410941,)),
HbPolyType.ahdist_aASP_dLYS: ((0.97789485082, 2.50496946108,),(1.1, 1.1,),( -0.41300315, 6.59243438, -44.44525308, 163.11796012, -351.2307798, 443.2463146, -297.84582856, 62.38600547, 33.77496227, -14.11652182,)),
HbPolyType.ahdist_aASP_dSER: ((0.542905671869, 2.45259389314,),(1.1, 1.1,),( 1.38531754, -18.48733797, 106.14444613, -344.70585054, 698.91577956, -917.0879402, 775.32787908, -403.09588787, 113.65054778, -11.66516403,)),
HbPolyType.ahdist_aASP_dTRP: ((0.419155746414, 3.0486938610500003,),(1.1, 1.1,),( -0.24563471, 3.85598551, -25.75176874, 95.36525025, -214.13175785, 299.76133553, -259.0691378, 132.06975835, -37.15612683, 5.60445773,)),
HbPolyType.ahdist_aASP_dTYR: ((1.01057521468, 2.7207545786900003,),(1.1, 1.1,),( -0.15808672, -10.21398871, 178.80080949, -1238.0583801, 4736.25248274, -11071.96777725, 16239.07550047, -14593.21092621, 7335.66765017, -1575.08145078,)),
HbPolyType.ahdist_aGLY_dARG: ((0.499016667857, 2.9377031027599996,),(1.1, 1.1,),( -0.15923533, 2.5526639, -17.38788803, 65.71046957, -151.13491186, 218.78048387, -199.15882919, 110.56568974, -35.95143745, 6.47580213,)),
HbPolyType.ahdist_aGLY_dASN: ((0.7194388032060001, 2.9303772333599998,),(1.1, 1.1,),( -1.40718342, 23.65929694, -172.97144348, 720.64417348, -1882.85420815, 3194.87197776, -3515.52467458, 2415.75238278, -941.47705161, 159.84784277,)),
HbPolyType.ahdist_aGLY_dGLY: ((1.38403812683, 2.9981039433,),(1.1, 1.1,),( -0.5307601, 6.47949946, -22.39522814, -55.14303544, 708.30945242, -2619.49318162, 5227.8805795, -6043.31211632, 3806.04676175, -1007.66024144,)),
HbPolyType.ahdist_aGLY_dHIS: ((0.47406840932899996, 2.9234200830400003,),(1.1, 1.1,),( -0.12881679, 1.933838, -12.03134888, 39.92691227, -75.41519959, 78.87968016, -37.82769801, -0.13178679, 4.50193019, 0.45408359,)),
HbPolyType.ahdist_aGLY_dLYS: ((0.545347533475, 2.42624380351,),(1.1, 1.1,),( -0.22921901, 2.07015714, -6.2947417, 0.66645697, 45.21805416, -130.26668981, 176.32401031, -126.68226346, 43.96744431, -4.40105281,)),
HbPolyType.ahdist_aGLY_dSER: ((1.2803349239700001, 2.2465996077400003,),(1.1, 1.1,),( 6.72508613, -86.98495585, 454.18518444, -1119.89141452, 715.624663, 3172.36852982, -9455.49113097, 11797.38766934, -7363.28302948, 1885.50119665,)),
HbPolyType.ahdist_aGLY_dTRP: ((0.686512740494, 3.02901351815,),(1.1, 1.1,),( -0.1051487, 1.41597708, -7.42149173, 17.31830704, -6.98293652, -54.76605063, 130.95272289, -132.77575305, 62.75460448, -9.89110842,)),
HbPolyType.ahdist_aGLY_dTYR: ((1.28894687639, 2.26335316892,),(1.1, 1.1,),( 13.84536925, -169.40579865, 893.79467505, -2670.60617561, 5016.46234701, -6293.79378818, 5585.1049063, -3683.50722701, 1709.48661405, -399.5712153,)),
HbPolyType.ahdist_aHIS_dARG: ((0.8967400957230001, 2.96809434226,),(1.1, 1.1,),( 0.43460495, -10.52727665, 103.16979807, -551.42887412, 1793.25378923, -3701.08304991, 4861.05155388, -3922.4285529, 1763.82137881, -335.43441944,)),
HbPolyType.ahdist_aHIS_dASN: ((0.887120931718, 2.59166903153,),(1.1, 1.1,),( -3.50289894, 54.42813924, -368.14395507, 1418.90186454, -3425.60485859, 5360.92334837, -5428.54462336, 3424.68800187, -1221.49631986, 189.27122436,)),
HbPolyType.ahdist_aHIS_dGLY: ((1.01629363411, 2.58523052904,),(1.1, 1.1,),( -1.68095217, 21.31894078, -107.72203494, 251.81021758, -134.07465831, -707.64527046, 1894.6282743, -2156.85951846, 1216.83585872, -275.48078944,)),
HbPolyType.ahdist_aHIS_dHIS: ((0.9773010778919999, 2.72533796329,),(1.1, 1.1,),( -2.33350626, 35.66072412, -233.98966111, 859.13714961, -1925.30958567, 2685.35293578, -2257.48067507, 1021.49796136, -169.36082523, -12.1348055,)),
HbPolyType.ahdist_aHIS_dLYS: ((0.7080936539849999, 2.47191718632,),(1.1, 1.1,),( -1.88479369, 28.38084382, -185.74039957, 690.81875917, -1605.11404391, 2414.83545623, -2355.9723201, 1442.24496229, -506.45880637, 79.47512505,)),
HbPolyType.ahdist_aHIS_dSER: ((0.90846809159, 2.5477956147,),(1.1, 1.1,),( -0.92004641, 15.91841533, -117.83979251, 488.22211296, -1244.13047376, 2017.43704053, -2076.04468019, 1302.42621488, -451.29138643, 67.15812575,)),
HbPolyType.ahdist_aHIS_dTRP: ((0.991999676806, 2.81296584506,),(1.1, 1.1,),( -1.29358587, 19.97152857, -131.89796017, 485.29199356, -1084.0466445, 1497.3352889, -1234.58042682, 535.8048197, -75.58951691, -9.91148332,)),
HbPolyType.ahdist_aHIS_dTYR: ((0.882661836357, 2.5469016429900004,),(1.1, 1.1,),( -6.94700143, 109.07997256, -747.64035726, 2929.83959536, -7220.15788571, 11583.34170519, -12078.443492, 7881.85479715, -2918.19482068, 468.23988622,)),
HbPolyType.ahdist_aSER_dARG: ((1.0204658147399999, 2.8899566041900004,),(1.1, 1.1,),( 0.33887327, -7.54511361, 70.87316645, -371.88263665, 1206.67454443, -2516.82084076, 3379.45432693, -2819.73384601, 1325.33307517, -265.54533008,)),
HbPolyType.ahdist_aSER_dASN: ((1.01393052233, 3.0024434159299997,),(1.1, 1.1,),( 0.37012361, -7.46486204, 64.85775924, -318.6047209, 974.66322243, -1924.37334018, 2451.63840629, -1943.1915675, 867.07870559, -163.83771761,)),
HbPolyType.ahdist_aSER_dGLY: ((1.3856562156299999, 2.74160605537,),(1.1, 1.1,),( -1.32847415, 22.67528654, -172.53450064, 770.79034865, -2233.48829652, 4354.38807288, -5697.35144236, 4803.38686157, -2361.48028857, 518.28202382,)),
HbPolyType.ahdist_aSER_dHIS: ((0.550992321207, 2.68549261999,),(1.1, 1.1,),( -1.98041793, 29.59668639, -190.36751773, 688.43324385, -1534.68894765, 2175.66568976, -1952.07622113, 1066.28943929, -324.23381388, 43.41006168,)),
HbPolyType.ahdist_aSER_dLYS: ((0.8603189393170001, 2.77729502744,),(1.1, 1.1,),( 0.90884741, -17.24690746, 141.78469099, -661.85989315, 1929.7674992, -3636.43392779, 4419.00727923, -3332.43482061, 1410.78913266, -253.53829424,)),
HbPolyType.ahdist_aSER_dSER: ((1.10866545921, 2.61727781204,),(1.1, 1.1,),( -0.38264308, 4.41779675, -10.7016645, -81.91314845, 668.91174735, -2187.50684758, 3983.56103269, -4213.32320546, 2418.41531442, -580.28918569,)),
HbPolyType.ahdist_aSER_dTRP: ((1.4092077245899999, 2.8066121197099996,),(1.1, 1.1,),( 0.73762477, -11.70741276, 73.05154232, -205.00144794, 89.58794368, 1082.94541375, -3343.98293188, 4601.70815729, -3178.53568678, 896.59487831,)),
HbPolyType.ahdist_aSER_dTYR: ((1.10773547919, 2.60403567341,),(1.1, 1.1,),( -1.13249925, 14.66643161, -69.01708791, 93.96846742, 380.56063898, -1984.56675689, 4074.08891127, -4492.76927139, 2613.13168054, -627.71933508,)),
HbPolyType.ahdist_aTYR_dARG: ((1.05581400627, 2.85499888099,),(1.1, 1.1,),( -0.30396592, 5.30288548, -39.75788579, 167.5416547, -435.15958911, 716.52357586, -735.95195083, 439.76284677, -130.00400085, 13.23827556,)),
HbPolyType.ahdist_aTYR_dASN: ((1.0994919065200002, 2.8400869077900004,),(1.1, 1.1,),( 0.33548259, -3.5890451, 8.97769025, 48.1492734, -400.5983616, 1269.89613211, -2238.03101675, 2298.33009115, -1290.42961162, 308.43185147,)),
HbPolyType.ahdist_aTYR_dGLY: ((1.36546155066, 2.7303075916400004,),(1.1, 1.1,),( -1.55312915, 18.62092487, -70.91365499, -41.83066505, 1248.88835245, -4719.81948329, 9186.09528168, -10266.11434548, 6266.21959533, -1622.19652457,)),
HbPolyType.ahdist_aTYR_dHIS: ((0.5955982461899999, 2.6643551317500003,),(1.1, 1.1,),( -0.47442788, 7.16629863, -46.71287553, 171.46128947, -388.17484011, 558.45202337, -506.35587481, 276.46237273, -83.52554392, 12.05709329,)),
HbPolyType.ahdist_aTYR_dLYS: ((0.7978598238760001, 2.7620933782,),(1.1, 1.1,),( -0.20201464, 1.69684984, 0.27677515, -55.05786347, 286.29918332, -725.92372531, 1054.771746, -889.33602341, 401.11342256, -73.02221189,)),
HbPolyType.ahdist_aTYR_dSER: ((0.7083554962559999, 2.7032011990599996,),(1.1, 1.1,),( -0.70764192, 11.67978065, -82.80447482, 329.83401367, -810.58976486, 1269.57613941, -1261.04047117, 761.72890446, -254.37526011, 37.24301861,)),
HbPolyType.ahdist_aTYR_dTRP: ((1.10934023051, 2.8819112108,),(1.1, 1.1,),( -11.58453967, 204.88308091, -1589.77384548, 7100.84791905, -20113.61354433, 37457.83646055, -45850.02969172, 35559.8805122, -15854.78726237, 3098.04931146,)),
HbPolyType.ahdist_aTYR_dTYR: ((1.1105954899400001, 2.60081798685,),(1.1, 1.1,),( -1.63120628, 19.48493187, -81.0332905, 56.80517706, 687.42717782, -2842.77799908, 5385.52231471, -5656.74159307, 3178.83470588, -744.70042777,)),
HbPolyType.AHD_1h: ((1.76555274367, 3.1416,),(1.1, 1.1,),( 0.62725838, -9.98558225, 59.39060071, -120.82930213, -333.26536028, 2603.13082592, -6895.51207142, 9651.25238056, -7127.13394872, 2194.77244026,)),
HbPolyType.AHD_1i: ((1.59914724347, 3.1416,),(1.1, 1.1,),( -0.18888801, 3.48241679, -25.65508662, 89.57085435, -95.91708218, -367.93452341, 1589.6904702, -2662.3582135, 2184.40194483, -723.28383545,)),
HbPolyType.AHD_1j: ((1.1435646388, 3.1416,),(1.1, 1.1,),( 0.47683259, -9.54524724, 83.62557693, -420.55867774, 1337.19354878, -2786.26265686, 3803.178227, -3278.62879901, 1619.04116204, -347.50157909,)),
HbPolyType.AHD_1k: ((1.15651981164, 3.1416,),(1.1, 1.1,),( -0.10757999, 2.0276542, -16.51949978, 75.83866839, -214.18025678, 380.55117567, -415.47847283, 255.66998474, -69.94662165, 3.21313428,)),
HbPolyType.cosBAH_off: ((-1234.0, 1.1,),(1.1, 1.1,),( 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,)),
HbPolyType.cosBAH_6i: ((-0.23538144897100002, 1.1,),(1.1, 1.1,),( -0.822093, -3.75364636, 46.88852157, -129.5440564, 146.69151428, -67.60598792, 2.91683129, 9.26673173, -3.84488178, 0.05706659,)),
HbPolyType.cosBAH_7: ((-0.019373850666900002, 1.1,),(1.1, 1.1,),( 0.0, -27.942923450028, 136.039920253368, -268.06959056747, 275.400462507919, -153.502076215949, 39.741591385461, 0.693861510121, -3.885952320499, 1.024765090788892)),
}

227
RF2_allatom/tests.py Normal file
View File

@@ -0,0 +1,227 @@
import unittest
import torch
from torch.utils import data
from chemical import NFRAMES
from data_loader import get_train_valid_set, Dataset, DatasetNAComplex, DatasetRNA, DatasetSMComplex, loader_pdb, loader_na_complex, loader_rna, loader_sm_compl,set_data_loader_params
from kinematics import xyz_to_c6d, xyz_to_t2d
from loss import compute_general_FAPE, resolve_equiv_natives, calc_str_loss
from util import get_frames, frame_indices, is_atom, xyz_to_frame_xyz, xyz_t_to_frame_xyz
class LossTestCase(unittest.TestCase):
def setUp(self):
self.loader_param = set_data_loader_params({})
(
pdb_items, fb_items, compl_items, neg_items, na_compl_items, na_neg_items, rna_items,
sm_compl_items, valid_pdb, valid_homo, valid_compl, valid_neg, valid_na_compl,
valid_na_neg, valid_rna, valid_sm_compl, homo
) = get_train_valid_set(self.loader_param)
pdb_IDs, pdb_weights, pdb_dict = pdb_items
na_compl_IDs, na_compl_weights, na_compl_dict = na_compl_items
rna_IDs, rna_weights, rna_dict = rna_items
sm_compl_IDs, sm_compl_weights, sm_compl_dict = sm_compl_items
self.homo = homo
valid_pdb_set = Dataset(
list(valid_pdb.keys()),
loader_pdb, valid_pdb,
self.loader_param, homo, p_homo_cut=-1.0
)
valid_na_compl_set = DatasetNAComplex(
list(valid_na_compl.keys()),
loader_na_complex, valid_na_compl,
self.loader_param, negative=False, native_NA_frac=1.0
)
valid_sm_compl_set = DatasetSMComplex(
list(sm_compl_dict.keys()),
loader_sm_compl, sm_compl_dict,
self.loader_param
)
self.valid_pdb_loader = data.DataLoader(valid_pdb_set)
self.valid_na_compl_loader = data.DataLoader(valid_na_compl_set)
self.valid_sm_compl_loader = data.DataLoader(valid_sm_compl_set)
def test_compute_general_FAPE(self):
with self.subTest("test that FAPE loss is correctly calculated for proteins"):
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames in self.valid_pdb_loader:
# first assert that same structure gives you 0 loss
frames, frame_mask = get_frames(
true_crds, atom_mask, msa[:, 0, 0], frame_indices, atom_frames
)
l_fape = compute_general_FAPE(
true_crds,
true_crds,
atom_mask,
frames,
frame_mask
)
self.assertAlmostEqual(int(l_fape.numpy()),0)
fapes = []
for i in range(5):
perturbed_crds = true_crds+(torch.rand(true_crds.shape)*(i+1))
l_fape = compute_general_FAPE(
perturbed_crds,
true_crds,
atom_mask,
frames,
frame_mask
)
fapes.append(l_fape)
for i in range(1,5):
self.assertLess(fapes[i-1], fapes[i])
break
#add noise and make sure increasing noise increases loss
with self.subTest("test that FAPE loss is correctly calculated for protein/NA complexes"):
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames in self.valid_na_compl_loader:
# first assert that same structure gives you 0 loss
frames, frame_mask = get_frames(
true_crds, atom_mask, msa[:, 0, 0], frame_indices, atom_frames
)
l_fape = compute_general_FAPE(
true_crds,
true_crds,
atom_mask,
frames,
frame_mask
)
self.assertAlmostEqual(int(l_fape.numpy()),0)
fapes = []
for i in range(5):
perturbed_crds = true_crds+(torch.rand(true_crds.shape)*(i+1))
l_fape = compute_general_FAPE(
perturbed_crds,
true_crds,
atom_mask,
frames,
frame_mask
)
fapes.append(l_fape)
for i in range(1,5):
self.assertLess(fapes[i-1], fapes[i])
break
with self.subTest("test that FAPE loss is correctly calculated for protein/SM complexes"):
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames in self.valid_sm_compl_loader:
# first assert that same structure gives you 0 loss
true_crds, atom_mask = resolve_equiv_natives(true_crds[0, 0].unsqueeze(0), true_crds, atom_mask)
frames, frame_mask = get_frames(
true_crds, atom_mask, msa[:, 0, 0], frame_indices, atom_frames
)
l_fape = compute_general_FAPE(
true_crds,
true_crds,
atom_mask,
frames,
frame_mask
)
self.assertAlmostEqual(int(l_fape.numpy()),0)
fapes = []
for i in range(5):
perturbed_crds = true_crds+(torch.rand(true_crds.shape)*(i+1))
l_fape = compute_general_FAPE(
perturbed_crds,
true_crds,
atom_mask,
frames,
frame_mask
)
fapes.append(l_fape)
for i in range(1,5):
self.assertLess(fapes[i-1], fapes[i])
break
with self.subTest("test that protein backbone FAPE loss can be calculated with compute_general_FAPE"):
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames in self.valid_pdb_loader:
frames, frame_mask = get_frames(
true_crds, atom_mask, msa[:, 0, 0], frame_indices, atom_frames
)
frame_mask[...,1:] = False
l_fape = compute_general_FAPE(
true_crds,
true_crds,
atom_mask,
frames,
frame_mask
)
self.assertAlmostEqual(int(l_fape.numpy()),0)
res_mask = ~((atom_mask[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(msa[:,0,0])))
mask_2d = res_mask[:,None,:] * res_mask[:,:,None]
fapes = []
for i in range(5):
perturbed_crds = true_crds+(torch.rand(true_crds.shape)*(i+1))
l_fape = compute_general_FAPE(
perturbed_crds,
true_crds,
atom_mask,
frames,
frame_mask
)
fapes.append(l_fape)
tot_str, str_loss = calc_str_loss(perturbed_crds.unsqueeze(0), true_crds, mask_2d, same_chain, negative=False,
A=10.0, d_clamp_intra=10.0, d_clamp_inter=10.0, gamma=1.0, eps=1e-4)
self.assertAlmostEqual(int(l_fape.numpy()), int(tot_str.numpy()))
for i in range(1,5):
self.assertLess(fapes[i-1], fapes[i])
break
def test_get_frames(self):
"""test that nodes in atom frames are relatively close to each other (because they should be bonded)"""
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames in self.valid_sm_compl_loader:
true_crds, atom_mask = resolve_equiv_natives(true_crds[0, 0].unsqueeze(0), true_crds, atom_mask)
frames, frame_mask = get_frames(
true_crds, atom_mask, msa[:, 0, 0], frame_indices, atom_frames
)
N, L, natoms, _ = true_crds.shape
# flatten middle dims so can gather across residues
X_prime = true_crds.reshape(N, L*natoms, -1, 3).repeat(1,1,NFRAMES,1)
frames_reindex = torch.zeros(frames.shape[:-1])
for i in range(L):
frames_reindex[:, i, :, :] = (i+frames[..., i, :, :, 0])*natoms + frames[..., i, :, :, 1]
frames_reindex = frames_reindex.long()
frame_mask *= torch.all(
torch.gather(atom_mask.reshape(1, L*natoms),1,frames_reindex.reshape(1,L*NFRAMES*3)).reshape(1,L,-1,3),
axis=-1)
X_x = torch.gather(X_prime, 1, frames_reindex[...,0:1].repeat(N,1,1,3))
X_y = torch.gather(X_prime, 1, frames_reindex[...,1:2].repeat(N,1,1,3))
X_z = torch.gather(X_prime, 1, frames_reindex[...,2:3].repeat(N,1,1,3))
atoms = is_atom(msa[:, 0,0])
frame_distance1 = torch.cdist(X_x[atoms], X_y[atoms])
frame_distance2 = torch.cdist(X_y[atoms], X_z[atoms])
self.assertTrue(torch.all(frame_distance1[:,0,0] <2))
self.assertTrue(torch.all(frame_distance2[:,0,0] <2))
break
def test_xyz_to_c6d(self):
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames in self.valid_sm_compl_loader:
true_crds, atom_mask = resolve_equiv_natives(true_crds[0, 0].unsqueeze(0), true_crds, atom_mask)
# atoms = is_atom(msa[:, 0,0])
# atom_crds = true_crds[atoms]
# atom_L, natoms, _ = atom_crds.shape
# frames_reindex = torch.zeros(atom_frames.shape[:-1])
# for i in range(atom_L):
# frames_reindex[:, i, :] = (i+atom_frames[..., i, :, 0])*natoms + atom_frames[..., i, :, 1]
# frames_reindex = frames_reindex.long()
# true_crds[atoms, :, :3] = atom_crds.reshape(atom_L*natoms, 3)[frames_reindex]
true_crds = xyz_to_frame_xyz(true_crds, msa[:, 0,0], atom_frames)
c6d, _ = xyz_to_c6d(true_crds)
xyz_t = xyz_t_to_frame_xyz(xyz_t, msa[:, 0,0].squeeze(0), atom_frames)
t2d = xyz_to_t2d(xyz_t)
break
if __name__ == '__main__':
unittest.main()

File diff suppressed because it is too large Load Diff

883
RF2_allatom/util.py Normal file
View File

@@ -0,0 +1,883 @@
import sys
import numpy as np
import torch
import scipy.sparse
import networkx as nx
import rdkit
from rdkit import Chem
from chemical import *
from scoring import *
def th_ang_v(ab,bc,eps:float=1e-8):
def th_norm(x,eps:float=1e-8):
return x.square().sum(-1,keepdim=True).add(eps).sqrt()
def th_N(x,alpha:float=0):
return x/th_norm(x).add(alpha)
ab, bc = th_N(ab),th_N(bc)
cos_angle = torch.clamp( (ab*bc).sum(-1), -1, 1)
sin_angle = torch.sqrt(1-cos_angle.square() + eps)
dih = torch.stack((cos_angle,sin_angle),-1)
return dih
def th_dih_v(ab,bc,cd):
def th_cross(a,b):
a,b = torch.broadcast_tensors(a,b)
return torch.cross(a,b, dim=-1)
def th_norm(x,eps:float=1e-8):
return x.square().sum(-1,keepdim=True).add(eps).sqrt()
def th_N(x,alpha:float=0):
return x/th_norm(x).add(alpha)
ab, bc, cd = th_N(ab),th_N(bc),th_N(cd)
n1 = th_N( th_cross(ab,bc) )
n2 = th_N( th_cross(bc,cd) )
sin_angle = (th_cross(n1,bc)*n2).sum(-1)
cos_angle = (n1*n2).sum(-1)
dih = torch.stack((cos_angle,sin_angle),-1)
return dih
def th_dih(a,b,c,d):
return th_dih_v(a-b,b-c,c-d)
# build a frame from 3 points
#fd - more complicated version splits angle deviations between CA-N and CA-C (giving more accurate CB position)
#fd - makes no assumptions about input dims (other than last 1 is xyz)
def rigid_from_3_points(N, Ca, C, is_na=None, eps=1e-8):
dims = N.shape[:-1]
v1 = C-Ca
v2 = N-Ca
e1 = v1/(torch.norm(v1, dim=-1, keepdim=True)+eps)
u2 = v2-(torch.einsum('...li, ...li -> ...l', e1, v2)[...,None]*e1)
e2 = u2/(torch.norm(u2, dim=-1, keepdim=True)+eps)
e3 = torch.cross(e1, e2, dim=-1)
R = torch.cat([e1[...,None], e2[...,None], e3[...,None]], axis=-1) #[B,L,3,3] - rotation matrix
v2 = v2/(torch.norm(v2, dim=-1, keepdim=True)+eps)
cosref = torch.sum(e1*v2, dim=-1)
costgt = torch.full(dims, -0.3616, device=N.device)
if is_na is not None:
costgt[is_na] = -0.4929
cos2del = torch.clamp( cosref*costgt + torch.sqrt((1-cosref*cosref)*(1-costgt*costgt)+eps), min=-1.0, max=1.0 )
cosdel = torch.sqrt(0.5*(1+cos2del)+eps)
sindel = torch.sign(costgt-cosref) * torch.sqrt(1-0.5*(1+cos2del)+eps)
Rp = torch.eye(3, device=N.device).repeat(*dims,1,1)
Rp[...,0,0] = cosdel
Rp[...,0,1] = -sindel
Rp[...,1,0] = sindel
Rp[...,1,1] = cosdel
R = torch.einsum('...ij,...jk->...ik', R,Rp)
return R, Ca
# note: needs consistency with chemical.py
def is_nucleic(seq):
return (seq>=NPROTAAS) * (seq <= NNAPROTAAS)
def is_atom(seq):
return seq > NNAPROTAAS
def idealize_reference_frame(seq, xyz_in):
xyz = xyz_in.clone()
namask = is_nucleic(seq)
Rs, Ts = rigid_from_3_points(xyz[...,0,:],xyz[...,1,:],xyz[...,2,:], namask)
protmask = ~namask
Nideal = torch.tensor([-0.5272, 1.3593, 0.000], device=xyz_in.device)
Cideal = torch.tensor([1.5233, 0.000, 0.000], device=xyz_in.device)
OP1ideal = torch.tensor([-0.7319, 1.2920, 0.000], device=xyz_in.device)
OP2ideal = torch.tensor([1.4855, 0.000, 0.000], device=xyz_in.device)
pmask_bs,pmask_rs = protmask.nonzero(as_tuple=True)
nmask_bs,nmask_rs = namask.nonzero(as_tuple=True)
xyz[pmask_bs,pmask_rs,0,:] = torch.einsum('...ij,j->...i', Rs[pmask_bs,pmask_rs], Nideal) + Ts[pmask_bs,pmask_rs]
xyz[pmask_bs,pmask_rs,2,:] = torch.einsum('...ij,j->...i', Rs[pmask_bs,pmask_rs], Cideal) + Ts[pmask_bs,pmask_rs]
xyz[nmask_bs,nmask_rs,0,:] = torch.einsum('...ij,j->...i', Rs[nmask_bs,nmask_rs], OP1ideal) + Ts[nmask_bs,nmask_rs]
xyz[nmask_bs,nmask_rs,2,:] = torch.einsum('...ij,j->...i', Rs[nmask_bs,nmask_rs], OP2ideal) + Ts[nmask_bs,nmask_rs]
return xyz
# works for both dna and protein
# alphas in order:
# omega/phi/psi: 0-2
# chi_1-4(prot): 3-6
# cb/cg bend: 7-9
# eps(p)/zeta(p): 10-11
# alpha/beta/gamma/delta: 12-15
# nu2/nu1/nu0: 16-18
# chi_1(na): 19
def get_tor_mask(seq, torsion_indices, mask_in=None):
B,L = seq.shape[:2]
dna_mask = is_nucleic(seq)
prot_mask = ~dna_mask
tors_mask = torsion_indices[seq,:,-1] > 0
if mask_in != None:
N = mask_in.shape[2]
ts = torsion_indices[seq]
bs = torch.arange(B, device=seq.device)[:,None,None,None]
rs = torch.arange(L, device=seq.device)[None,:,None,None] - (ts<0)*1 # ts<-1 ==> prev res
ts = torch.abs(ts)
tors_mask *= mask_in[bs,rs,ts].all(dim=-1)
return tors_mask
def get_torsions(xyz_in, seq, torsion_indices, torsion_can_flip, ref_angles, mask_in=None):
B,L = xyz_in.shape[:2]
tors_mask = get_tor_mask(seq, torsion_indices, mask_in)
# idealize given xyz coordinates before computing torsion angles
xyz = idealize_reference_frame(seq, xyz_in)
ts = torsion_indices[seq]
bs = torch.arange(B, device=xyz_in.device)[:,None,None,None]
xs = torch.arange(L, device=xyz_in.device)[None,:,None,None] - (ts<0)*1 # ts<-1 ==> prev res
ys = torch.abs(ts)
xyzs_bytor = xyz[bs,xs,ys,:]
torsions = torch.zeros( (B,L,NTOTALDOFS,2), device=xyz_in.device )
torsions[...,:7,:] = th_dih(
xyzs_bytor[...,:7,0,:],xyzs_bytor[...,:7,1,:],xyzs_bytor[...,:7,2,:],xyzs_bytor[...,:7,3,:]
)
torsions[:,:,2,:] = -1 * torsions[:,:,2,:] # shift psi by pi
torsions[...,10:,:] = th_dih(
xyzs_bytor[...,10:,0,:],xyzs_bytor[...,10:,1,:],xyzs_bytor[...,10:,2,:],xyzs_bytor[...,10:,3,:]
)
# angles (hardcoded)
# CB bend
NC = 0.5*( xyz[:,:,0,:3] + xyz[:,:,2,:3] )
CA = xyz[:,:,1,:3]
CB = xyz[:,:,4,:3]
t = th_ang_v(CB-CA,NC-CA)
t0 = ref_angles[seq][...,0,:]
torsions[:,:,7,:] = torch.stack(
(torch.sum(t*t0,dim=-1),t[...,0]*t0[...,1]-t[...,1]*t0[...,0]),
dim=-1 )
# CB twist
NCCA = NC-CA
NCp = xyz[:,:,2,:3] - xyz[:,:,0,:3]
NCpp = NCp - torch.sum(NCp*NCCA, dim=-1, keepdim=True)/ torch.sum(NCCA*NCCA, dim=-1, keepdim=True) * NCCA
t = th_ang_v(CB-CA,NCpp)
t0 = ref_angles[seq][...,1,:]
torsions[:,:,8,:] = torch.stack(
(torch.sum(t*t0,dim=-1),t[...,0]*t0[...,1]-t[...,1]*t0[...,0]),
dim=-1 )
# CG bend
CG = xyz[:,:,5,:3]
t = th_ang_v(CG-CB,CA-CB)
t0 = ref_angles[seq][...,2,:]
torsions[:,:,9,:] = torch.stack(
(torch.sum(t*t0,dim=-1),t[...,0]*t0[...,1]-t[...,1]*t0[...,0]),
dim=-1 )
mask0 = (torch.isnan(torsions[...,0])).nonzero()
mask1 = (torch.isnan(torsions[...,1])).nonzero()
torsions[mask0[:,0],mask0[:,1],mask0[:,2],0] = 1.0
torsions[mask1[:,0],mask1[:,1],mask1[:,2],1] = 0.0
# alt chis
torsions_alt = torsions.clone()
torsions_alt[torsion_can_flip[seq,:]] *= -1
# torsions to restrain to 0 or 180 degree
# (this should be specified in chemical?)
tors_planar = torch.zeros((B, L, NTOTALDOFS), dtype=torch.bool, device=xyz_in.device)
tors_planar[:,:,5] = seq == aa2num['TYR'] # TYR chi 3 should be planar
return torsions, torsions_alt, tors_mask, tors_planar
def xyz_to_frame_xyz(xyz, seq_unmasked, atom_frames):
"""
xyz (1, L, natoms, 3)
seq_unmasked (1, L)
atom_frames (1, L, 3, 2)
"""
atoms = is_atom(seq_unmasked)
if torch.all(~atoms):
return xyz
atom_crds = xyz[atoms]
atom_L, natoms, _ = atom_crds.shape
frames_reindex = torch.zeros(atom_frames.shape[:-1])
for i in range(atom_L):
frames_reindex[:, i, :] = (i+atom_frames[..., i, :, 0])*natoms + atom_frames[..., i, :, 1]
frames_reindex = frames_reindex.long()
xyz[atoms, :, :3] = atom_crds.reshape(atom_L*natoms, 3)[frames_reindex]
return xyz
def xyz_t_to_frame_xyz(xyz_t, seq_unmasked, atom_frames):
"""
xyz (1, T, L, natoms, 3)
seq_unmasked (L)
atom_frames (1, L, 3, 2)
"""
atoms = is_atom(seq_unmasked)
if torch.all(~atoms):
return xyz_t
atom_crds_t = xyz_t[:, :, atoms]
B, T, atom_L, natoms, _ = atom_crds_t.shape
frames_reindex = torch.zeros(atom_frames.shape[:-1])
for i in range(atom_L):
frames_reindex[:, i, :] = (i+atom_frames[..., i, :, 0])*natoms + atom_frames[..., i, :, 1]
frames_reindex = frames_reindex.long()
xyz_t[:, :, atoms, :3] = atom_crds_t.reshape(T, atom_L*natoms, 3)[:, frames_reindex.squeeze(0)]
return xyz_t
def get_frames(xyz_in, xyz_mask, seq, frame_indices, atom_frames=None):
B,L,natoms = xyz_in.shape[:3]
frames = frame_indices[seq]
atoms = seq > NNAPROTAAS
if torch.any(atoms):
# print(torch.sum(atoms))
# print(atom_frames.shape)
# print(atoms[0].nonzero().flatten().shape)
try:
frames[:,atoms[0].nonzero().flatten(), 0] = atom_frames
except Exception as e:
print(e)
print(torch.sum(atoms))
print(atom_frames.shape)
print(atoms[0].nonzero().flatten().shape)
frame_mask = ~torch.all(frames[...,0, :] == frames[...,1, :], axis=-1)
# frame_mask *= torch.all(
# torch.gather(xyz_mask,2,frames.reshape(B,L,-1)).reshape(B,L,-1,3),
# axis=-1)
return frames, frame_mask
def generate_Cbeta(N,Ca,C):
# recreate Cb given N,Ca,C
b = Ca - N
c = C - Ca
a = torch.cross(b, c, dim=-1)
#Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
# fd: below matches sidechain generator (=Rosetta params)
Cb = -0.57910144*a + 0.5689693*b - 0.5441217*c + Ca
return Cb
def get_tips(xyz, seq):
B,L = xyz.shape[:2]
xyz_tips = torch.gather(xyz, 2, tip_indices.to(xyz.device)[seq][:,:,None,None].expand(-1,-1,-1,3)).reshape(B, L, 3)
if torch.isnan(xyz_tips).any(): # replace NaN tip atom with virtual Cb atom
# three anchor atoms
N = xyz[:,:,0]
Ca = xyz[:,:,1]
C = xyz[:,:,2]
# recreate Cb given N,Ca,C
b = Ca - N
c = C - Ca
a = torch.cross(b, c, dim=-1)
Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
xyz_tips = torch.where(torch.isnan(xyz_tips), Cb, xyz_tips)
return xyz_tips
# writepdb
def writepdb(filename, atoms, seq, idx_pdb=None, bfacts=None):
f = open(filename,"w")
ctr = 1
scpu = seq.cpu().squeeze(0)
atomscpu = atoms.cpu().squeeze(0)
if bfacts is None:
bfacts = torch.zeros(atomscpu.shape[0])
if idx_pdb is None:
idx_pdb = 1 + torch.arange(atomscpu.shape[0])
Bfacts = torch.clamp( bfacts.cpu(), 0, 1)
for i,s in enumerate(scpu):
natoms = atomscpu.shape[-2]
if (natoms!=NHEAVY and natoms!=NTOTAL):
print ('bad size!', natoms, NHEAVY, NTOTAL, atoms.shape)
assert(False)
atms = aa2long[s]
# his prot hack
if (s==8 and torch.linalg.norm( atomscpu[i,9,:]-atomscpu[i,5,:] ) < 1.7):
atms = (
" N "," CA "," C "," O "," CB "," CG "," NE2"," CD2"," CE1"," ND1",
None, None, None, None," H "," HA ","1HB ","2HB "," HD2"," HE1",
" HD1", None, None, None, None, None, None) # his_d
for j,atm_j in enumerate(atms):
if (j<natoms and atm_j is not None and not torch.isnan(atomscpu[i,j,:]).any()):
f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%(
"ATOM", ctr, atm_j, num2aa[s],
"A", idx_pdb[i], atomscpu[i,j,0], atomscpu[i,j,1], atomscpu[i,j,2],
1.0, Bfacts[i] ) )
ctr += 1
######
######
# process ideal frames
def make_frame(X, Y):
Xn = X / torch.linalg.norm(X)
Y = Y - torch.dot(Y, Xn) * Xn
Yn = Y / torch.linalg.norm(Y)
Z = torch.cross(Xn,Yn)
Zn = Z / torch.linalg.norm(Z)
return torch.stack((Xn,Yn,Zn), dim=-1)
# resolve tip atom indices
tip_indices = torch.full((NAATOKENS,), 0)
for i in range(NAATOKENS):
if i > NNAPROTAAS-1:
# all atoms are at index 1 in the atom array
tip_indices[i] = 1
else:
tip_atm = aa2tip[i]
atm_long = aa2long[i]
tip_indices[i] = atm_long.index(tip_atm)
# resolve torsion indices
# a negative index indicates the previous residue
# order:
# omega/phi/psi: 0-2
# chi_1-4(prot): 3-6
# cb/cg bend: 7-9
# eps(p)/zeta(p): 10-11
# alpha/beta/gamma/delta: 12-15
# nu2/nu1/nu0: 16-18
# chi_1(na): 19
torsion_indices = torch.full((NAATOKENS,NTOTALDOFS,4),0)
torsion_can_flip = torch.full((NAATOKENS,NTOTALDOFS),False,dtype=torch.bool)
for i in range(NPROTAAS):
i_l, i_a = aa2long[i], aa2longalt[i]
# protein omega/phi/psi
torsion_indices[i,0,:] = torch.tensor([-1,-2,0,1]) # omega
torsion_indices[i,1,:] = torch.tensor([-2,0,1,2]) # phi
torsion_indices[i,2,:] = torch.tensor([0,1,2,3]) # psi (+pi)
# protein chis
for j in range(4):
if torsions[i][j] is None:
continue
for k in range(4):
a = torsions[i][j][k]
torsion_indices[i,3+j,k] = i_l.index(a)
if (i_l.index(a) != i_a.index(a)):
torsion_can_flip[i,3+j] = True ##bb tors never flip
# CB/CG angles (only masking uses these indices)
torsion_indices[i,7,:] = torch.tensor([0,2,1,4]) # CB ang1
torsion_indices[i,8,:] = torch.tensor([0,2,1,4]) # CB ang2
torsion_indices[i,9,:] = torch.tensor([0,2,4,5]) # CG ang (arg 1 ignored)
# HIS is a special case for flip
torsion_can_flip[8,4]=False
for i in range(NPROTAAS,NNAPROTAAS):
# NA BB tors
torsion_indices[i,10,:] = torch.tensor([-5,-7,-8,1]) # epsilon_prev
torsion_indices[i,11,:] = torch.tensor([-7,-8,1,3]) # zeta_prev
torsion_indices[i,12,:] = torch.tensor([0,1,3,4]) # alpha (+2pi/3)
torsion_indices[i,13,:] = torch.tensor([1,3,4,5]) # beta
torsion_indices[i,14,:] = torch.tensor([3,4,5,7]) # gamma
torsion_indices[i,15,:] = torch.tensor([4,5,7,8]) # delta
if (i<NPROTAAS+5):
# is DNA
torsion_indices[i,16,:] = torch.tensor([4,5,6,10]) # nu2
torsion_indices[i,17,:] = torch.tensor([5,6,10,9]) # nu1
torsion_indices[i,18,:] = torch.tensor([6,10,9,7]) # nu0
else:
# is RNA (fd: my fault since I flipped C1'/C2' order for DNA and RNA)
torsion_indices[i,16,:] = torch.tensor([4,5,6,9]) # nu2
torsion_indices[i,17,:] = torch.tensor([5,6,9,10]) # nu1
torsion_indices[i,18,:] = torch.tensor([6,9,10,7]) # nu0
# NA chi
if torsions[i][0] is not None:
i_l = aa2long[i]
for k in range(4):
a = torsions[i][0][k]
torsion_indices[i,19,k] = i_l.index(a) # chi
# no NA torsion flips
# build the mapping from atoms in the full rep (Nx27) to the "alternate" rep
allatom_mask = torch.zeros((NAATOKENS,NTOTAL), dtype=torch.bool)
long2alt = torch.zeros((NAATOKENS,NTOTAL), dtype=torch.long)
for i in range(NNAPROTAAS):
i_l, i_lalt = aa2long[i], aa2longalt[i]
for j,a in enumerate(i_l):
if (a is None):
long2alt[i,j] = j
else:
long2alt[i,j] = i_lalt.index(a)
allatom_mask[i,j] = True
# bond graph traversal
num_bonds = torch.zeros((NAATOKENS,NTOTAL,NTOTAL), dtype=torch.long)
for i in range(NNAPROTAAS):
num_bonds_i = np.zeros((NTOTAL,NTOTAL))
for (bnamei,bnamej) in aabonds[i]:
bi,bj = aa2long[i].index(bnamei),aa2long[i].index(bnamej)
num_bonds_i[bi,bj] = 1
num_bonds_i = scipy.sparse.csgraph.shortest_path (num_bonds_i,directed=False)
num_bonds_i[num_bonds_i>=4] = 4
num_bonds[i,...] = torch.tensor(num_bonds_i)
# atom type indices
idx2aatype = []
for x in aa2type:
for y in x:
if y and y not in idx2aatype:
idx2aatype.append(y)
aatype2idx = {x:i for i,x in enumerate(idx2aatype)}
# element indices
idx2elt = []
for x in aa2elt:
for y in x:
if y and y not in idx2elt:
idx2elt.append(y)
elt2idx = {x:i for i,x in enumerate(idx2elt)}
# LJ/LK scoring parameters
atom_type_index = torch.zeros((NAATOKENS,NTOTAL), dtype=torch.long)
element_index = torch.zeros((NAATOKENS,NTOTAL), dtype=torch.long)
ljlk_parameters = torch.zeros((NAATOKENS,NTOTAL,5), dtype=torch.float)
lj_correction_parameters = torch.zeros((NAATOKENS,NTOTAL,4), dtype=bool) # donor/acceptor/hpol/disulf
for i in range(NNAPROTAAS):
for j,a in enumerate(aa2type[i]):
if (a is not None):
atom_type_index[i,j] = aatype2idx[a]
ljlk_parameters[i,j,:] = torch.tensor( type2ljlk[a] )
lj_correction_parameters[i,j,0] = (type2hb[a]==HbAtom.DO)+(type2hb[a]==HbAtom.DA)
lj_correction_parameters[i,j,1] = (type2hb[a]==HbAtom.AC)+(type2hb[a]==HbAtom.DA)
lj_correction_parameters[i,j,2] = (type2hb[a]==HbAtom.HP)
lj_correction_parameters[i,j,3] = (a=="SH1" or a=="HS")
for j,a in enumerate(aa2elt[i]):
if (a is not None):
element_index[i,j] = elt2idx[a]
# hbond scoring parameters
def donorHs(D,bonds,atoms):
dHs = []
for (i,j) in bonds:
if (i==D):
idx_j = atoms.index(j)
if (idx_j>=NHEAVY): # if atom j is a hydrogen
dHs.append(idx_j)
if (j==D):
idx_i = atoms.index(i)
if (idx_i>=NHEAVY): # if atom j is a hydrogen
dHs.append(idx_i)
assert (len(dHs)>0)
return dHs
def acceptorBB0(A,hyb,bonds,atoms):
if (hyb == HbHybType.SP2):
for (i,j) in bonds:
if (i==A):
B = atoms.index(j)
if (B<NHEAVY):
break
if (j==A):
B = atoms.index(i)
if (B<NHEAVY):
break
for (i,j) in bonds:
if (i==atoms[B]):
B0 = atoms.index(j)
if (B0<NHEAVY):
break
if (j==atoms[B]):
B0 = atoms.index(i)
if (B0<NHEAVY):
break
elif (hyb == HbHybType.SP3 or hyb == HbHybType.RING):
for (i,j) in bonds:
if (i==A):
B = atoms.index(j)
if (B<NHEAVY):
break
if (j==A):
B = atoms.index(i)
if (B<NHEAVY):
break
for (i,j) in bonds:
if (i==A and j!=atoms[B]):
B0 = atoms.index(j)
break
if (j==A and i!=atoms[B]):
B0 = atoms.index(i)
break
return B,B0
hbtypes = torch.full((NAATOKENS,NTOTAL,3),-1, dtype=torch.long) # (donortype, acceptortype, acchybtype)
hbbaseatoms = torch.full((NAATOKENS,NTOTAL,2),-1, dtype=torch.long) # (B,B0) for acc; (D,-1) for don
hbpolys = torch.zeros((HbDonType.NTYPES,HbAccType.NTYPES,3,15)) # weight,xmin,xmax,ymin,ymax,c9,...,c0
for i in range(NNAPROTAAS):
for j,a in enumerate(aa2type[i]):
if (a in type2dontype):
j_hs = donorHs(aa2long[i][j],aabonds[i],aa2long[i])
for j_h in j_hs:
hbtypes[i,j_h,0] = type2dontype[a]
hbbaseatoms[i,j_h,0] = j
if (a in type2acctype):
j_b, j_b0 = acceptorBB0(aa2long[i][j],type2hybtype[a],aabonds[i],aa2long[i])
hbtypes[i,j,1] = type2acctype[a]
hbtypes[i,j,2] = type2hybtype[a]
hbbaseatoms[i,j,0] = j_b
hbbaseatoms[i,j,1] = j_b0
for i in range(HbDonType.NTYPES):
for j in range(HbAccType.NTYPES):
weight = dontype2wt[i]*acctype2wt[j]
pdist,pbah,pahd = hbtypepair2poly[(i,j)]
xrange,yrange,coeffs = hbpolytype2coeffs[pdist]
hbpolys[i,j,0,0] = weight
hbpolys[i,j,0,1:3] = torch.tensor(xrange)
hbpolys[i,j,0,3:5] = torch.tensor(yrange)
hbpolys[i,j,0,5:] = torch.tensor(coeffs)
xrange,yrange,coeffs = hbpolytype2coeffs[pahd]
hbpolys[i,j,1,0] = weight
hbpolys[i,j,1,1:3] = torch.tensor(xrange)
hbpolys[i,j,1,3:5] = torch.tensor(yrange)
hbpolys[i,j,1,5:] = torch.tensor(coeffs)
xrange,yrange,coeffs = hbpolytype2coeffs[pbah]
hbpolys[i,j,2,0] = weight
hbpolys[i,j,2,1:3] = torch.tensor(xrange)
hbpolys[i,j,2,3:5] = torch.tensor(yrange)
hbpolys[i,j,2,5:] = torch.tensor(coeffs)
# cartbonded scoring parameters
# (0) inter-res
cb_lengths_CN = (1.32868, 369.445)
cb_angles_CACN = (2.02807,160)
cb_angles_CNCA = (2.12407,96.53)
cb_torsions_CACNH = (0.0,41.830) # also used for proline CACNCD
cb_torsions_CANCO = (0.0,38.668)
# note for the below, the extra amino acid corrsponds to cb params for HIS_D
# (1) intra-res lengths
cb_lengths = [[] for i in range(NAATOKENS+1)]
for cst in cartbonded_data_raw['lengths']:
res_idx = aa2num[ cst['res'] ]
cb_lengths[res_idx].append( (
aa2long[res_idx].index(cst['atm1']),
aa2long[res_idx].index(cst['atm2']),
cst['x0'],cst['K']
) )
ncst_per_res=max([len(i) for i in cb_lengths])
cb_length_t = torch.zeros(NAATOKENS+1,ncst_per_res,4)
for i in range(NNAPROTAAS+1):
src = i
if (num2aa[i]=='UNK' or num2aa[i]=='MAS'):
src=aa2num['ALA']
if (len(cb_lengths[src])>0):
cb_length_t[i,:len(cb_lengths[src]),:] = torch.tensor(cb_lengths[src])
# (2) intra-res angles
cb_angles = [[] for i in range(NAATOKENS+1)]
for cst in cartbonded_data_raw['angles']:
res_idx = aa2num[ cst['res'] ]
cb_angles[res_idx].append( (
aa2long[res_idx].index(cst['atm1']),
aa2long[res_idx].index(cst['atm2']),
aa2long[res_idx].index(cst['atm3']),
cst['x0'],cst['K']
) )
ncst_per_res=max([len(i) for i in cb_angles])
cb_angle_t = torch.zeros(NAATOKENS+1,ncst_per_res,5)
for i in range(NNAPROTAAS+1):
src = i
if (num2aa[i]=='UNK' or num2aa[i]=='MAS'):
src=aa2num['ALA']
if (len(cb_angles[src])>0):
cb_angle_t[i,:len(cb_angles[src]),:] = torch.tensor(cb_angles[src])
# (3) intra-res torsions
cb_torsions = [[] for i in range(NAATOKENS+1)]
for cst in cartbonded_data_raw['torsions']:
res_idx = aa2num[ cst['res'] ]
cb_torsions[res_idx].append( (
aa2long[res_idx].index(cst['atm1']),
aa2long[res_idx].index(cst['atm2']),
aa2long[res_idx].index(cst['atm3']),
aa2long[res_idx].index(cst['atm4']),
cst['x0'],cst['K'],cst['period']
) )
ncst_per_res=max([len(i) for i in cb_torsions])
cb_torsion_t = torch.zeros(NAATOKENS+1,ncst_per_res,7)
cb_torsion_t[...,6]=1.0 # periodicity
for i in range(NNAPROTAAS):
src = i
if (num2aa[i]=='UNK' or num2aa[i]=='MAS'):
src=aa2num['ALA']
if (len(cb_torsions[src])>0):
cb_torsion_t[i,:len(cb_torsions[src]),:] = torch.tensor(cb_torsions[src])
# kinematic parameters
base_indices = torch.full((NAATOKENS,NTOTAL),0, dtype=torch.long) # base frame that builds each atom
xyzs_in_base_frame = torch.ones((NAATOKENS,NTOTAL,4)) # coords of each atom in the base frame
RTs_by_torsion = torch.eye(4).repeat(NAATOKENS,NTOTALTORS,1,1) # torsion frames
reference_angles = torch.ones((NAATOKENS,NPROTANGS,2)) # reference values for bendable angles
## PROTEIN
for i in range(NPROTAAS):
i_l = aa2long[i]
for name, base, coords in ideal_coords[i]:
idx = i_l.index(name)
base_indices[i,idx] = base
xyzs_in_base_frame[i,idx,:3] = torch.tensor(coords)
# omega frame
RTs_by_torsion[i,0,:3,:3] = torch.eye(3)
RTs_by_torsion[i,0,:3,3] = torch.zeros(3)
# phi frame
RTs_by_torsion[i,1,:3,:3] = make_frame(
xyzs_in_base_frame[i,0,:3] - xyzs_in_base_frame[i,1,:3],
torch.tensor([1.,0.,0.])
)
RTs_by_torsion[i,1,:3,3] = xyzs_in_base_frame[i,0,:3]
# psi frame
RTs_by_torsion[i,2,:3,:3] = make_frame(
xyzs_in_base_frame[i,2,:3] - xyzs_in_base_frame[i,1,:3],
xyzs_in_base_frame[i,1,:3] - xyzs_in_base_frame[i,0,:3]
)
RTs_by_torsion[i,2,:3,3] = xyzs_in_base_frame[i,2,:3]
# chi1 frame
if torsions[i][0] is not None:
a0,a1,a2 = torsion_indices[i,3,0:3]
RTs_by_torsion[i,3,:3,:3] = make_frame(
xyzs_in_base_frame[i,a2,:3]-xyzs_in_base_frame[i,a1,:3],
xyzs_in_base_frame[i,a0,:3]-xyzs_in_base_frame[i,a1,:3],
)
RTs_by_torsion[i,3,:3,3] = xyzs_in_base_frame[i,a2,:3]
# chi2/3/4 frame
for j in range(1,4):
if torsions[i][j] is not None:
a2 = torsion_indices[i,3+j,2]
if ((i==18 and j==2) or (i==8 and j==2)): # TYR CZ-OH & HIS CE1-HE1 a special case
a0,a1 = torsion_indices[i,3+j,0:2]
RTs_by_torsion[i,3+j,:3,:3] = make_frame(
xyzs_in_base_frame[i,a2,:3]-xyzs_in_base_frame[i,a1,:3],
xyzs_in_base_frame[i,a0,:3]-xyzs_in_base_frame[i,a1,:3] )
else:
RTs_by_torsion[i,3+j,:3,:3] = make_frame(
xyzs_in_base_frame[i,a2,:3],
torch.tensor([-1.,0.,0.]), )
RTs_by_torsion[i,3+j,:3,3] = xyzs_in_base_frame[i,a2,:3]
# CB/CG angles
NCr = 0.5*(xyzs_in_base_frame[i,0,:3]+xyzs_in_base_frame[i,2,:3])
CAr = xyzs_in_base_frame[i,1,:3]
CBr = xyzs_in_base_frame[i,4,:3]
CGr = xyzs_in_base_frame[i,5,:3]
reference_angles[i,0,:]=th_ang_v(CBr-CAr,NCr-CAr)
NCp = xyzs_in_base_frame[i,2,:3]-xyzs_in_base_frame[i,0,:3]
NCpp = NCp - torch.dot(NCp,NCr)/ torch.dot(NCr,NCr) * NCr
reference_angles[i,1,:]=th_ang_v(CBr-CAr,NCpp)
reference_angles[i,2,:]=th_ang_v(CGr,torch.tensor([-1.,0.,0.]))
## NUCLEIC ACIDS
for i in range(NPROTAAS, NNAPROTAAS):
i_l = aa2long[i]
for name, base, coords in ideal_coords[i]:
idx = i_l.index(name)
base_indices[i,idx] = base
xyzs_in_base_frame[i,idx,:3] = torch.tensor(coords)
# epsilon(p)/zeta(p) - like omega in protein, not used to build atoms
# - keep as identity
RTs_by_torsion[i,NPROTTORS+0,:3,:3] = torch.eye(3)
RTs_by_torsion[i,NPROTTORS+0,:3,3] = torch.zeros(3)
RTs_by_torsion[i,NPROTTORS+1,:3,:3] = torch.eye(3)
RTs_by_torsion[i,NPROTTORS+1,:3,3] = torch.zeros(3)
# alpha
RTs_by_torsion[i,NPROTTORS+2,:3,:3] = make_frame(
xyzs_in_base_frame[i,3,:3] - xyzs_in_base_frame[i,1,:3], # P->O5'
xyzs_in_base_frame[i,0,:3] - xyzs_in_base_frame[i,1,:3] # P<-OP1
)
RTs_by_torsion[i,NPROTTORS+2,:3,3] = xyzs_in_base_frame[i,3,:3] # O5'
# beta
RTs_by_torsion[i,NPROTTORS+3,:3,:3] = make_frame(
xyzs_in_base_frame[i,4,:3] , torch.tensor([-1.,0.,0.])
)
RTs_by_torsion[i,NPROTTORS+3,:3,3] = xyzs_in_base_frame[i,4,:3] # C5'
# gamma
RTs_by_torsion[i,NPROTTORS+4,:3,:3] = make_frame(
xyzs_in_base_frame[i,5,:3] , torch.tensor([-1.,0.,0.])
)
RTs_by_torsion[i,NPROTTORS+4,:3,3] = xyzs_in_base_frame[i,5,:3] # C4'
# delta
RTs_by_torsion[i,NPROTTORS+5,:3,:3] = make_frame(
xyzs_in_base_frame[i,7,:3] , torch.tensor([-1.,0.,0.])
)
RTs_by_torsion[i,NPROTTORS+5,:3,3] = xyzs_in_base_frame[i,7,:3] # C3'
# nu2
RTs_by_torsion[i,NPROTTORS+6,:3,:3] = make_frame(
xyzs_in_base_frame[i,6,:3] , torch.tensor([-1.,0.,0.])
)
RTs_by_torsion[i,NPROTTORS+6,:3,3] = xyzs_in_base_frame[i,6,:3] # O4'
# nu1
if i<NPROTAAS+5:
# is DNA
C1idx,C2idx = 10,9
else:
# is RNA
C1idx,C2idx = 9,10
RTs_by_torsion[i,NPROTTORS+7,:3,:3] = make_frame(
xyzs_in_base_frame[i,C1idx,:3] , torch.tensor([-1.,0.,0.])
)
RTs_by_torsion[i,NPROTTORS+7,:3,3] = xyzs_in_base_frame[i,C1idx,:3] # C1'
# nu0
RTs_by_torsion[i,NPROTTORS+8,:3,:3] = make_frame(
xyzs_in_base_frame[i,C2idx,:3] , torch.tensor([-1.,0.,0.])
)
RTs_by_torsion[i,NPROTTORS+8,:3,3] = xyzs_in_base_frame[i,C2idx,:3] # C2'
# NA chi
if torsions[i][0] is not None:
a2 = torsion_indices[i,19,2]
RTs_by_torsion[i,NPROTTORS+9,:3,:3] = make_frame(
xyzs_in_base_frame[i,a2,:3] , torch.tensor([-1.,0.,0.])
)
RTs_by_torsion[i,NPROTTORS+9,:3,3] = xyzs_in_base_frame[i,a2,:3]
# general FAPE parameters
frame_indices = torch.full((NAATOKENS,NFRAMES,3,2),0, dtype=torch.long)
for i in range(NNAPROTAAS):
i_l = aa2long[i]
for j,x in enumerate(frames[i]):
if x is not None:
# frames are stored as (residue offset, atom position)
frame_indices[i,j,0] = torch.tensor((0, i_l.index(x[0])))
frame_indices[i,j,1] = torch.tensor((0, i_l.index(x[1])))
frame_indices[i,j,2] = torch.tensor((0, i_l.index(x[2])))
### Create atom frames for FAPE loss calculation ###
def get_nxgraph(mol : rdkit.Chem.rdchem.Mol) -> nx.Graph:
'''build NetworkX graph from rdkit's molecule'''
N = mol.GetNumAtoms()
# pairs of bonded atoms
bonds = [(b.GetBeginAtomIdx(),b.GetEndAtomIdx()) for b in mol.GetBonds()]
# connectivity graph
G = nx.Graph()
G.add_nodes_from(range(N))
G.add_edges_from(bonds)
return G
def find_all_paths_of_length_n(G : nx.Graph,
n : int) -> torch.Tensor:
'''find all paths of length N in a networkx graph
https://stackoverflow.com/questions/28095646/finding-all-paths-walks-of-given-length-in-a-networkx-graph'''
def findPaths(G,u,n):
if n==0:
return [[u]]
paths = [[u]+path for neighbor in G.neighbors(u) for path in findPaths(G,neighbor,n-1) if u not in path]
return paths
# all paths of length n
allpaths = [tuple(p) if p[0]<p[-1] else tuple(reversed(p))
for node in G for p in findPaths(G,node,n)]
# unique paths
allpaths = list(set(allpaths))
#return torch.tensor(allpaths)
return allpaths
def get_atom_frames(msa, mol, G):
"""choose a frame of 3 bonded atoms for each atom in the molecule, rule based system that chooses frame based on atom priorities"""
query_seq = msa
frames = find_all_paths_of_length_n(G, 2)
selected_frames = []
for n in range(mol.GetNumAtoms()):
frames_with_n = [frame for frame in frames if n == frame[1]]
# some chemical groups don't have two bonded heavy atoms; so choose a frame with an atom 2 bonds away
if not frames_with_n:
frames_with_n = [frame for frame in frames if n in frame]
# if the atom isn't in a 3 atom frame, it should be ignored in loss calc, set all the atoms to n
if not frames_with_n:
selected_frames.append([n,n,n])
continue
frame_priorities = []
for frame in frames_with_n:
# hacky but uses the "query_seq" to convert index of the atom into an "atom type" and converts that into a priority
indices = [index for index in frame if index!=n]
aas = [num2aa[int(query_seq[index].numpy())] for index in indices]
frame_priorities.append(sorted([atom2frame_priority[aa] for aa in aas]))
# np.argsort doesn't sort tuples correctly so just sort a list of indices using a key
sorted_indices = sorted(range(len(frame_priorities)), key=lambda i: frame_priorities[i])
# calculate residue offset for frame
frame = [(frame-n, 1) for frame in frames_with_n[sorted_indices[0]]]
selected_frames.append(frame)
assert msa.shape[0] == len(selected_frames)
return torch.Tensor(selected_frames).long()
### Generate bond features for small molecules ###
def get_bond_feats(mol, G):
"""creates 2d bond graph for small molecules"""
N = mol.GetNumAtoms()
bond_feats = torch.zeros((N, N)).long()
i,j = np.array(G.edges).T
bond_feats[i,j] = torch.tensor([rdkit2btype[int(b.GetBondType())] for b in mol.GetBonds()]).long()
bond_feats[j,i] = bond_feats[i,j]
return bond_feats
def get_protein_bond_feats(protein_L):
""" creates protein residue connectivity graphs """
bond_feats = torch.zeros((protein_L, protein_L))
residues = torch.arange(protein_L-1)
bond_feats[residues, residues+1] = 1
bond_feats[residues+1, residues] = 1
return bond_feats

439
RF2_allatom/util_module.py Normal file
View File

@@ -0,0 +1,439 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from opt_einsum import contract as einsum
import copy
import dgl
from util import base_indices, RTs_by_torsion, xyzs_in_base_frame, rigid_from_3_points, is_nucleic
def init_lecun_normal(module, scale=1.0):
def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
normal = torch.distributions.normal.Normal(0, 1)
alpha = (a - mu) / sigma
beta = (b - mu) / sigma
alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform
v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
x = torch.clamp(x, a, b)
return x
def sample_truncated_normal(shape, scale=1.0):
stddev = np.sqrt(scale/shape[-1])/.87962566103423978 # shape[-1] = fan_in
return stddev * truncated_normal(torch.rand(shape))
module.weight = torch.nn.Parameter( (sample_truncated_normal(module.weight.shape)) )
return module
def init_lecun_normal_param(weight, scale=1.0):
def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
normal = torch.distributions.normal.Normal(0, 1)
alpha = (a - mu) / sigma
beta = (b - mu) / sigma
alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform
v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
x = torch.clamp(x, a, b)
return x
def sample_truncated_normal(shape, scale=1.0):
stddev = np.sqrt(scale/shape[-1])/.87962566103423978 # shape[-1] = fan_in
return stddev * truncated_normal(torch.rand(shape))
weight = torch.nn.Parameter( (sample_truncated_normal(weight.shape)) )
return weight
# for gradient checkpointing
def create_custom_forward(module, **kwargs):
def custom_forward(*inputs):
return module(*inputs, **kwargs)
return custom_forward
def get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
class Dropout(nn.Module):
# Dropout entire row or column
def __init__(self, broadcast_dim=None, p_drop=0.15):
super(Dropout, self).__init__()
# give ones with probability of 1-p_drop / zeros with p_drop
self.sampler = torch.distributions.bernoulli.Bernoulli(torch.tensor([1-p_drop]))
self.broadcast_dim=broadcast_dim
self.p_drop=p_drop
def forward(self, x):
if not self.training: # no drophead during evaluation mode
return x
shape = list(x.shape)
if not self.broadcast_dim == None:
shape[self.broadcast_dim] = 1
mask = self.sampler.sample(shape).to(x.device).view(shape)
x = mask * x / (1.0 - self.p_drop)
return x
def rbf(D, scale=1.0):
# Distance radial basis function
D_min, D_max, D_count = 0., 20., 36
D_mu = torch.linspace(D_min, D_max, D_count).to(D.device)
D_mu = D_mu[None,:]
D_sigma = scale * (D_max - D_min) / D_count #fd add factor (?)
D_expand = torch.unsqueeze(D, -1)
RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
return RBF
def get_seqsep(idx):
'''
Input:
- idx: residue indices of given sequence (B,L)
Output:
- seqsep: sequence separation feature with sign (B, L, L, 1)
Sergey found that having sign in seqsep features helps a little
'''
seqsep = idx[:,None,:] - idx[:,:,None]
sign = torch.sign(seqsep)
neigh = torch.abs(seqsep)
neigh[neigh > 1] = 0.0 # if bonded -- 1.0 / else 0.0
neigh = sign * neigh
return neigh.unsqueeze(-1)
def make_full_graph(xyz, pair, idx):
'''
Input:
- xyz: current backbone cooordinates (B, L, 3, 3)
- pair: pair features from Trunk (B, L, L, E)
- idx: residue index from ground truth pdb
Output:
- G: defined graph
'''
B, L = xyz.shape[:2]
device = xyz.device
# seq sep
sep = idx[:,None,:] - idx[:,:,None]
b,i,j = torch.where(sep.abs() > 0)
src = b*L+i
tgt = b*L+j
G = dgl.graph((src, tgt), num_nodes=B*L).to(device)
G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]).detach() # no gradient through basis function
return G, pair[b,i,j][...,None]
def make_topk_graph(xyz, pair, idx, top_k=128, nlocal=33, topk_incl_local=True, eps=1e-6):
'''
Input:
- xyz: current backbone cooordinates (B, L, 3, 3)
- pair: pair features from Trunk (B, L, L, E)
- idx: residue index from ground truth pdb
Output:
- G: defined graph
'''
B, L = xyz.shape[:2]
device = xyz.device
# distance map from current CA coordinates
D = torch.cdist(xyz, xyz) + torch.eye(L, device=device).unsqueeze(0)*9999.9 # (B, L, L)
# seq sep
sep = idx[:,None,:] - idx[:,:,None]
sep = sep.abs() + torch.eye(L, device=device).unsqueeze(0)*9999.9
if (topk_incl_local):
D = D + sep*eps
D[sep<nlocal] = 0.0
# get top_k neighbors
D_neigh, E_idx = torch.topk(D, min(top_k, L-1), largest=False) # shape of E_idx: (B, L, top_k)
topk_matrix = torch.zeros((B, L, L), device=device)
topk_matrix.scatter_(2, E_idx, 1.0)
cond = topk_matrix > 0.0
else:
D = D + sep*eps
# get top_k neighbors
D_neigh, E_idx = torch.topk(D, min(top_k, L-1), largest=False) # shape of E_idx: (B, L, top_k)
topk_matrix = torch.zeros((B, L, L), device=device)
topk_matrix.scatter_(2, E_idx, 1.0)
# put an edge if any of the 3 conditions are met:
# 1) |i-j| <= kmin (connect sequentially adjacent residues)
# 2) top_k neighbors
cond = torch.logical_or(topk_matrix > 0.0, sep < nlocal)
b,i,j = torch.where(cond)
src = b*L+i
tgt = b*L+j
G = dgl.graph((src, tgt), num_nodes=B*L).to(device)
G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]).detach() # no gradient through basis function
return G, pair[b,i,j][...,None]
def make_atom_graph( xyz, mask, num_bonds, top_k=16, maxbonds=4 ):
B,L,A = xyz.shape[:3]
device = xyz.device
D = torch.norm(
xyz[:,None,None,:,:] - xyz[:,:,:,None,None], dim=-1
)
mask2d = mask[:,:,:,None,None]*mask[:,None,None,:,:]
D[~mask2d] = 9999.
D[D==0] = 9999.
# select top K neighbors for each atom
# keep indices as batch/res/atm indices
D_neigh, E_idx = torch.topk(D.reshape(B,L,A,-1), top_k, largest=False) # shape of E_idx: (B, L, top_k)
Eres, Eatm = torch.div(E_idx,A,rounding_mode='trunc'), E_idx%A
bi,ri,ai = mask.nonzero(as_tuple=True)
bi = bi[:,None].repeat(1,top_k).reshape(-1)
ri = ri[:,None].repeat(1,top_k).reshape(-1)
ai = ai[:,None].repeat(1,top_k).reshape(-1)
rj,aj = Eres[mask].reshape(-1), Eatm[mask].reshape(-1)
# on each edge, 1-hot encode the number of bonds (up to maxbonds) seperating each atom
edge = torch.full(ri.shape, maxbonds, device=device)
resmask = ri==rj
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],aj[resmask]]-1
resmask = ri+1==rj
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],2]+num_bonds[bi[resmask],rj[resmask],0,aj[resmask]]
resmask = ri-1==rj
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],0]+num_bonds[bi[resmask],rj[resmask],2,aj[resmask]]
edge = edge.clamp(0,maxbonds-1)
edge = F.one_hot(edge)[...,None]
natm = torch.sum(mask)
index = torch.zeros_like(mask, dtype=torch.long, device=device)
index[mask] = torch.arange(natm, device=device)
src=index[bi,ri,ai]
tgt=index[bi,rj,aj]
G = dgl.graph((src, tgt), num_nodes=natm).to(device)
G.edata['rel_pos'] = (xyz[bi,ri,ai] - xyz[bi,rj,aj]).detach() # no gradient through basis function
return G, edge
# rotate about the x axis
def make_rotX(angs, eps=1e-6):
B,L = angs.shape[:2]
NORM = torch.linalg.norm(angs, dim=-1) + eps
RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1)
RTs[:,:,1,1] = angs[:,:,0]/NORM
RTs[:,:,1,2] = -angs[:,:,1]/NORM
RTs[:,:,2,1] = angs[:,:,1]/NORM
RTs[:,:,2,2] = angs[:,:,0]/NORM
return RTs
# rotate about the x axis
def make_rotZ(angs, eps=1e-6):
B,L = angs.shape[:2]
NORM = torch.linalg.norm(angs, dim=-1) + eps
RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1)
RTs[:,:,0,0] = angs[:,:,0]/NORM
RTs[:,:,0,1] = -angs[:,:,1]/NORM
RTs[:,:,1,0] = angs[:,:,1]/NORM
RTs[:,:,1,1] = angs[:,:,0]/NORM
return RTs
# rotate about an arbitrary axis
def make_rot_axis(angs, u, eps=1e-6):
B,L = angs.shape[:2]
NORM = torch.linalg.norm(angs, dim=-1) + eps
RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1)
ct = angs[:,:,0]/NORM
st = angs[:,:,1]/NORM
u0 = u[:,:,0]
u1 = u[:,:,1]
u2 = u[:,:,2]
RTs[:,:,0,0] = ct+u0*u0*(1-ct)
RTs[:,:,0,1] = u0*u1*(1-ct)-u2*st
RTs[:,:,0,2] = u0*u2*(1-ct)+u1*st
RTs[:,:,1,0] = u0*u1*(1-ct)+u2*st
RTs[:,:,1,1] = ct+u1*u1*(1-ct)
RTs[:,:,1,2] = u1*u2*(1-ct)-u0*st
RTs[:,:,2,0] = u0*u2*(1-ct)-u1*st
RTs[:,:,2,1] = u1*u2*(1-ct)+u0*st
RTs[:,:,2,2] = ct+u2*u2*(1-ct)
return RTs
# compute allatom structure from backbone frames and torsions
#
# alphas:
# omega/phi/psi: 0-2
# chi_1-4(prot): 3-6
# cb/cg bend: 7-9
# eps(p)/zeta(p): 10-11
# alpha/beta/gamma/delta: 12-15
# nu2/nu1/nu0: 16-18
# chi_1(na): 19
#
# RTs_in_base_frame:
# omega/phi/psi: 0-2
# chi_1-4(prot): 3-6
# eps(p)/zeta(p): 7-8
# alpha/beta/gamma/delta: 9-12
# nu2/nu1/nu0: 13-15
# chi_1(na): 16
#
# RT frames (output):
# origin: 0
# omega/phi/psi: 1-3
# chi_1-4(prot): 4-7
# cb bend: 8
# alpha/beta/gamma/delta: 9-12
# nu2/nu1/nu0: 13-15
# chi_1(na): 16
#
class ComputeAllAtomCoords(nn.Module):
def __init__(self):
super(ComputeAllAtomCoords, self).__init__()
self.base_indices = nn.Parameter(base_indices, requires_grad=False)
self.RTs_in_base_frame = nn.Parameter(RTs_by_torsion, requires_grad=False)
self.xyzs_in_base_frame = nn.Parameter(xyzs_in_base_frame, requires_grad=False)
def forward(self, seq, xyz, alphas):
B,L = xyz.shape[:2]
is_NA = is_nucleic(seq)
Rs, Ts = rigid_from_3_points(xyz[...,0,:],xyz[...,1,:],xyz[...,2,:], is_NA)
RTF0 = torch.eye(4).repeat(B,L,1,1).to(device=Rs.device)
# bb
RTF0[:,:,:3,:3] = Rs
RTF0[:,:,:3,3] = Ts
# omega
RTF1 = torch.einsum(
'brij,brjk,brkl->bril',
RTF0, self.RTs_in_base_frame[seq,0,:], make_rotX(alphas[:,:,0,:]))
# phi
RTF2 = torch.einsum(
'brij,brjk,brkl->bril',
RTF0, self.RTs_in_base_frame[seq,1,:], make_rotX(alphas[:,:,1,:]))
# psi
RTF3 = torch.einsum(
'brij,brjk,brkl->bril',
RTF0, self.RTs_in_base_frame[seq,2,:], make_rotX(alphas[:,:,2,:]))
# CB bend
basexyzs = self.xyzs_in_base_frame[seq]
NCr = 0.5*(basexyzs[:,:,2,:3]+basexyzs[:,:,0,:3])
CAr = (basexyzs[:,:,1,:3])
CBr = (basexyzs[:,:,4,:3])
CBrotaxis1 = (CBr-CAr).cross(NCr-CAr)
CBrotaxis1 /= torch.linalg.norm(CBrotaxis1, dim=-1, keepdim=True)+1e-8
# CB twist
NCp = basexyzs[:,:,2,:3] - basexyzs[:,:,0,:3]
NCpp = NCp - torch.sum(NCp*NCr, dim=-1, keepdim=True)/ torch.sum(NCr*NCr, dim=-1, keepdim=True) * NCr
CBrotaxis2 = (CBr-CAr).cross(NCpp)
CBrotaxis2 /= torch.linalg.norm(CBrotaxis2, dim=-1, keepdim=True)+1e-8
CBrot1 = make_rot_axis(alphas[:,:,7,:], CBrotaxis1 )
CBrot2 = make_rot_axis(alphas[:,:,8,:], CBrotaxis2 )
RTF8 = torch.einsum(
'brij,brjk,brkl->bril',
RTF0, CBrot1,CBrot2)
# chi1 + CG bend
RTF4 = torch.einsum(
'brij,brjk,brkl,brlm->brim',
RTF8,
self.RTs_in_base_frame[seq,3,:],
make_rotX(alphas[:,:,3,:]),
make_rotZ(alphas[:,:,9,:]))
# chi2
RTF5 = torch.einsum(
'brij,brjk,brkl->bril',
RTF4, self.RTs_in_base_frame[seq,4,:],make_rotX(alphas[:,:,4,:]))
# chi3
RTF6 = torch.einsum(
'brij,brjk,brkl->bril',
RTF5,self.RTs_in_base_frame[seq,5,:],make_rotX(alphas[:,:,5,:]))
# chi4
RTF7 = torch.einsum(
'brij,brjk,brkl->bril',
RTF6,self.RTs_in_base_frame[seq,6,:],make_rotX(alphas[:,:,6,:]))
# ignore RTs_in_base_frame[seq,7:9,:] and alphas[:,:,10:12,:]
# NA alpha
RTF9 = torch.einsum(
'brij,brjk,brkl->bril',
RTF0, self.RTs_in_base_frame[seq,9,:], make_rotX(alphas[:,:,12,:]))
# NA beta
RTF10 = torch.einsum(
'brij,brjk,brkl->bril',
RTF9, self.RTs_in_base_frame[seq,10,:], make_rotX(alphas[:,:,13,:]))
# NA gamma
RTF11 = torch.einsum(
'brij,brjk,brkl->bril',
RTF10, self.RTs_in_base_frame[seq,11,:], make_rotX(alphas[:,:,14,:]))
# NA delta
RTF12 = torch.einsum(
'brij,brjk,brkl->bril',
RTF11, self.RTs_in_base_frame[seq,12,:], make_rotX(alphas[:,:,15,:]))
# NA nu2 - from gamma frame
RTF13 = torch.einsum(
'brij,brjk,brkl->bril',
RTF11, self.RTs_in_base_frame[seq,13,:], make_rotX(alphas[:,:,16,:]))
# NA nu1
RTF14 = torch.einsum(
'brij,brjk,brkl->bril',
RTF13, self.RTs_in_base_frame[seq,14,:], make_rotX(alphas[:,:,17,:]))
# NA nu0
RTF15 = torch.einsum(
'brij,brjk,brkl->bril',
RTF14, self.RTs_in_base_frame[seq,15,:], make_rotX(alphas[:,:,18,:]))
# NA chi - from nu1 frame
RTF16= torch.einsum(
'brij,brjk,brkl->bril',
RTF14, self.RTs_in_base_frame[seq,16,:], make_rotX(alphas[:,:,19,:]))
RTframes = torch.stack((
RTF0,RTF1,RTF2,RTF3,RTF4,RTF5,RTF6,RTF7,RTF8,
RTF9,RTF10,RTF11,RTF12,RTF13,RTF14,RTF15,RTF16
),dim=2)
xyzs = torch.einsum(
'brtij,brtj->brti',
RTframes.gather(2,self.base_indices[seq][...,None,None].repeat(1,1,1,4,4)), basexyzs
)
return RTframes, xyzs[...,:3]