From 0f7d4f6939696fbf5e4cb890d794f5be3574cc71 Mon Sep 17 00:00:00 2001 From: Rohith Krishna Date: Mon, 11 Jul 2022 00:48:54 -0700 Subject: [PATCH] created new folder for all atom --- RF2_allatom/.gitignore | 3 + RF2_allatom/Attention_module.py | 473 ++ RF2_allatom/AuxiliaryPredictor.py | 75 + RF2_allatom/Embeddings.py | 279 + RF2_allatom/RoseTTAFoldModel.py | 111 + RF2_allatom/SE3_network.py | 83 + RF2_allatom/Track_module.py | 693 +++ RF2_allatom/arguments.py | 172 + RF2_allatom/cartbonded.json | 9052 +++++++++++++++++++++++++++++ RF2_allatom/chemical.py | 1076 ++++ RF2_allatom/coords6d.py | 78 + RF2_allatom/data_loader.py | 1954 +++++++ RF2_allatom/eval.py | 342 ++ RF2_allatom/eval_fb.py | 395 ++ RF2_allatom/eval_model1.py | 51 + RF2_allatom/ffindex.py | 91 + RF2_allatom/kinematics.py | 266 + RF2_allatom/loss.py | 812 +++ RF2_allatom/memory.py | 57 + RF2_allatom/parsers.py | 439 ++ RF2_allatom/predict_casp14.py | 334 ++ RF2_allatom/resnet.py | 72 + RF2_allatom/run.sh | 30 + RF2_allatom/scheduler.py | 180 + RF2_allatom/scoring.py | 310 + RF2_allatom/tests.py | 227 + RF2_allatom/train_multi_EMA.py | 1591 +++++ RF2_allatom/util.py | 883 +++ RF2_allatom/util_module.py | 439 ++ 29 files changed, 20568 insertions(+) create mode 100644 RF2_allatom/.gitignore create mode 100644 RF2_allatom/Attention_module.py create mode 100644 RF2_allatom/AuxiliaryPredictor.py create mode 100644 RF2_allatom/Embeddings.py create mode 100644 RF2_allatom/RoseTTAFoldModel.py create mode 100644 RF2_allatom/SE3_network.py create mode 100644 RF2_allatom/Track_module.py create mode 100644 RF2_allatom/arguments.py create mode 100644 RF2_allatom/cartbonded.json create mode 100644 RF2_allatom/chemical.py create mode 100644 RF2_allatom/coords6d.py create mode 100644 RF2_allatom/data_loader.py create mode 100644 RF2_allatom/eval.py create mode 100644 RF2_allatom/eval_fb.py create mode 100644 RF2_allatom/eval_model1.py create mode 100644 RF2_allatom/ffindex.py create mode 100644 
RF2_allatom/kinematics.py create mode 100644 RF2_allatom/loss.py create mode 100644 RF2_allatom/memory.py create mode 100644 RF2_allatom/parsers.py create mode 100644 RF2_allatom/predict_casp14.py create mode 100644 RF2_allatom/resnet.py create mode 100755 RF2_allatom/run.sh create mode 100644 RF2_allatom/scheduler.py create mode 100644 RF2_allatom/scoring.py create mode 100644 RF2_allatom/tests.py create mode 100644 RF2_allatom/train_multi_EMA.py create mode 100644 RF2_allatom/util.py create mode 100644 RF2_allatom/util_module.py diff --git a/RF2_allatom/.gitignore b/RF2_allatom/.gitignore new file mode 100644 index 0000000..64d7c7a --- /dev/null +++ b/RF2_allatom/.gitignore @@ -0,0 +1,3 @@ +valid_remapped +lig_test +dataset.pkl diff --git a/RF2_allatom/Attention_module.py b/RF2_allatom/Attention_module.py new file mode 100644 index 0000000..790ba2d --- /dev/null +++ b/RF2_allatom/Attention_module.py @@ -0,0 +1,473 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from opt_einsum import contract as einsum +from util_module import init_lecun_normal +class FeedForwardLayer(nn.Module): + def __init__(self, d_model, r_ff, p_drop=0.1): + super(FeedForwardLayer, self).__init__() + self.norm = nn.LayerNorm(d_model) + self.linear1 = nn.Linear(d_model, d_model*r_ff) + self.dropout = nn.Dropout(p_drop) + self.linear2 = nn.Linear(d_model*r_ff, d_model) + + self.reset_parameter() + + def reset_parameter(self): + # initialize linear layer right before ReLu: He initializer (kaiming normal) + nn.init.kaiming_normal_(self.linear1.weight, nonlinearity='relu') + nn.init.zeros_(self.linear1.bias) + + # initialize linear layer right before residual connection: zero initialize + nn.init.zeros_(self.linear2.weight) + nn.init.zeros_(self.linear2.bias) + + def forward(self, src): + src = self.norm(src) + src = self.linear2(self.dropout(F.relu_(self.linear1(src)))) + return src + +class Attention(nn.Module): + # calculate multi-head attention + def 
__init__(self, d_query, d_key, n_head, d_hidden, d_out, p_drop=0.1): + super(Attention, self).__init__() + self.h = n_head + self.dim = d_hidden + # + self.to_q = nn.Linear(d_query, n_head*d_hidden, bias=False) + self.to_k = nn.Linear(d_key, n_head*d_hidden, bias=False) + self.to_v = nn.Linear(d_key, n_head*d_hidden, bias=False) + # + self.to_out = nn.Linear(n_head*d_hidden, d_out) + self.scaling = 1/math.sqrt(d_hidden) + # + # initialize all parameters properly + self.reset_parameter() + + def reset_parameter(self): + # query/key/value projection: Glorot uniform / Xavier uniform + nn.init.xavier_uniform_(self.to_q.weight) + nn.init.xavier_uniform_(self.to_k.weight) + nn.init.xavier_uniform_(self.to_v.weight) + + # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining + nn.init.zeros_(self.to_out.weight) + nn.init.zeros_(self.to_out.bias) + + def forward(self, query, key, value): + B, Q = query.shape[:2] + B, K = key.shape[:2] + # + query = self.to_q(query).reshape(B, Q, self.h, self.dim) + key = self.to_k(key).reshape(B, K, self.h, self.dim) + value = self.to_v(value).reshape(B, K, self.h, self.dim) + # + query = query * self.scaling + attn = einsum('bqhd,bkhd->bhqk', query, key) + attn = F.softmax(attn, dim=-1) + # + out = einsum('bhqk,bkhd->bqhd', attn, value) + out = out.reshape(B, Q, self.h*self.dim) + # + out = self.to_out(out) + + return out + +# MSA Attention (row/column) from AlphaFold architecture +class SequenceWeight(nn.Module): + def __init__(self, d_msa, n_head, d_hidden, p_drop=0.1): + super(SequenceWeight, self).__init__() + self.h = n_head + self.dim = d_hidden + self.scale = 1.0 / math.sqrt(self.dim) + + self.to_query = nn.Linear(d_msa, n_head*d_hidden) + self.to_key = nn.Linear(d_msa, n_head*d_hidden) + self.dropout = nn.Dropout(p_drop) + + self.reset_parameter() + + def reset_parameter(self): + # query/key/value projection: Glorot uniform / Xavier uniform + 
nn.init.xavier_uniform_(self.to_query.weight) + nn.init.xavier_uniform_(self.to_key.weight) + + def forward(self, msa): + B, N, L = msa.shape[:3] + + tar_seq = msa[:,0] + + q = self.to_query(tar_seq).view(B, 1, L, self.h, self.dim) + k = self.to_key(msa).view(B, N, L, self.h, self.dim) + + q = q * self.scale + attn = einsum('bqihd,bkihd->bkihq', q, k) + attn = F.softmax(attn, dim=1) + return self.dropout(attn) + +class MSARowAttentionWithBias(nn.Module): + def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32): + super(MSARowAttentionWithBias, self).__init__() + self.norm_msa = nn.LayerNorm(d_msa) + self.norm_pair = nn.LayerNorm(d_pair) + # + self.seq_weight = SequenceWeight(d_msa, n_head, d_hidden, p_drop=0.1) + self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False) + self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False) + self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False) + self.to_b = nn.Linear(d_pair, n_head, bias=False) + self.to_g = nn.Linear(d_msa, n_head*d_hidden) + self.to_out = nn.Linear(n_head*d_hidden, d_msa) + + self.scaling = 1/math.sqrt(d_hidden) + self.h = n_head + self.dim = d_hidden + + self.reset_parameter() + + def reset_parameter(self): + # query/key/value projection: Glorot uniform / Xavier uniform + nn.init.xavier_uniform_(self.to_q.weight) + nn.init.xavier_uniform_(self.to_k.weight) + nn.init.xavier_uniform_(self.to_v.weight) + + # bias: normal distribution + self.to_b = init_lecun_normal(self.to_b) + + # gating: zero weights, one biases (mostly open gate at the begining) + nn.init.zeros_(self.to_g.weight) + nn.init.ones_(self.to_g.bias) + + # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining + nn.init.zeros_(self.to_out.weight) + nn.init.zeros_(self.to_out.bias) + + def forward(self, msa, pair): # TODO: make this as tied-attention + B, N, L = msa.shape[:3] + # + msa = self.norm_msa(msa) + pair = self.norm_pair(pair) + # + seq_weight 
= self.seq_weight(msa) # (B, N, L, h, 1) + query = self.to_q(msa).reshape(B, N, L, self.h, self.dim) + key = self.to_k(msa).reshape(B, N, L, self.h, self.dim) + value = self.to_v(msa).reshape(B, N, L, self.h, self.dim) + bias = self.to_b(pair) # (B, L, L, h) + gate = torch.sigmoid(self.to_g(msa)) + # + query = query * seq_weight.expand(-1, -1, -1, -1, self.dim) + key = key * self.scaling + attn = einsum('bsqhd,bskhd->bqkh', query, key) + attn = attn + bias + attn = F.softmax(attn, dim=-2) + # + out = einsum('bqkh,bskhd->bsqhd', attn, value).reshape(B, N, L, -1) + out = gate * out + # + out = self.to_out(out) + return out + +class MSAColAttention(nn.Module): + def __init__(self, d_msa=256, n_head=8, d_hidden=32): + super(MSAColAttention, self).__init__() + self.norm_msa = nn.LayerNorm(d_msa) + # + self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False) + self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False) + self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False) + self.to_g = nn.Linear(d_msa, n_head*d_hidden) + self.to_out = nn.Linear(n_head*d_hidden, d_msa) + + self.scaling = 1/math.sqrt(d_hidden) + self.h = n_head + self.dim = d_hidden + + self.reset_parameter() + + def reset_parameter(self): + # query/key/value projection: Glorot uniform / Xavier uniform + nn.init.xavier_uniform_(self.to_q.weight) + nn.init.xavier_uniform_(self.to_k.weight) + nn.init.xavier_uniform_(self.to_v.weight) + + # gating: zero weights, one biases (mostly open gate at the begining) + nn.init.zeros_(self.to_g.weight) + nn.init.ones_(self.to_g.bias) + + # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining + nn.init.zeros_(self.to_out.weight) + nn.init.zeros_(self.to_out.bias) + + def forward(self, msa): + B, N, L = msa.shape[:3] + # + msa = self.norm_msa(msa) + # + query = self.to_q(msa).reshape(B, N, L, self.h, self.dim) + key = self.to_k(msa).reshape(B, N, L, self.h, self.dim) + value = 
self.to_v(msa).reshape(B, N, L, self.h, self.dim) + gate = torch.sigmoid(self.to_g(msa)) + # + query = query * self.scaling + attn = einsum('bqihd,bkihd->bihqk', query, key) + attn = F.softmax(attn, dim=-1) + # + out = einsum('bihqk,bkihd->bqihd', attn, value).reshape(B, N, L, -1) + out = gate * out + # + out = self.to_out(out) + return out + +class MSAColGlobalAttention(nn.Module): + def __init__(self, d_msa=64, n_head=8, d_hidden=8): + super(MSAColGlobalAttention, self).__init__() + self.norm_msa = nn.LayerNorm(d_msa) + # + self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False) + self.to_k = nn.Linear(d_msa, d_hidden, bias=False) + self.to_v = nn.Linear(d_msa, d_hidden, bias=False) + self.to_g = nn.Linear(d_msa, n_head*d_hidden) + self.to_out = nn.Linear(n_head*d_hidden, d_msa) + + self.scaling = 1/math.sqrt(d_hidden) + self.h = n_head + self.dim = d_hidden + + self.reset_parameter() + + def reset_parameter(self): + # query/key/value projection: Glorot uniform / Xavier uniform + nn.init.xavier_uniform_(self.to_q.weight) + nn.init.xavier_uniform_(self.to_k.weight) + nn.init.xavier_uniform_(self.to_v.weight) + + # gating: zero weights, one biases (mostly open gate at the begining) + nn.init.zeros_(self.to_g.weight) + nn.init.ones_(self.to_g.bias) + + # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining + nn.init.zeros_(self.to_out.weight) + nn.init.zeros_(self.to_out.bias) + + def forward(self, msa): + B, N, L = msa.shape[:3] + # + msa = self.norm_msa(msa) + # + query = self.to_q(msa).reshape(B, N, L, self.h, self.dim) + query = query.mean(dim=1) # (B, L, h, dim) + key = self.to_k(msa) # (B, N, L, dim) + value = self.to_v(msa) # (B, N, L, dim) + gate = torch.sigmoid(self.to_g(msa)) # (B, N, L, h*dim) + # + query = query * self.scaling + attn = einsum('bihd,bkid->bihk', query, key) # (B, L, h, N) + attn = F.softmax(attn, dim=-1) + # + out = einsum('bihk,bkid->bihd', attn, 
value).reshape(B, 1, L, -1) # (B, 1, L, h*dim) + out = gate * out # (B, N, L, h*dim) + # + out = self.to_out(out) + return out + +# TriangleAttention & TriangleMultiplication from AlphaFold architecture +class TriangleAttention(nn.Module): + def __init__(self, d_pair, n_head=4, d_hidden=32, p_drop=0.1, start_node=True): + super(TriangleAttention, self).__init__() + self.norm = nn.LayerNorm(d_pair) + self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False) + self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False) + self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False) + + self.to_b = nn.Linear(d_pair, n_head, bias=False) + self.to_g = nn.Linear(d_pair, n_head*d_hidden) + + self.to_out = nn.Linear(n_head*d_hidden, d_pair) + + self.scaling = 1/math.sqrt(d_hidden) + + self.h = n_head + self.dim = d_hidden + self.start_node=start_node + + self.reset_parameter() + + def reset_parameter(self): + # query/key/value projection: Glorot uniform / Xavier uniform + nn.init.xavier_uniform_(self.to_q.weight) + nn.init.xavier_uniform_(self.to_k.weight) + nn.init.xavier_uniform_(self.to_v.weight) + + # bias: normal distribution + self.to_b = init_lecun_normal(self.to_b) + + # gating: zero weights, one biases (mostly open gate at the begining) + nn.init.zeros_(self.to_g.weight) + nn.init.ones_(self.to_g.bias) + + # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining + nn.init.zeros_(self.to_out.weight) + nn.init.zeros_(self.to_out.bias) + + def forward(self, pair): + B, L = pair.shape[:2] + + pair = self.norm(pair) + + # input projection + query = self.to_q(pair).reshape(B, L, L, self.h, -1) + key = self.to_k(pair).reshape(B, L, L, self.h, -1) + value = self.to_v(pair).reshape(B, L, L, self.h, -1) + bias = self.to_b(pair) # (B, L, L, h) + gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim) + + # attention + query = query * self.scaling + if self.start_node: + attn = 
einsum('bijhd,bikhd->bijkh', query, key) + else: + attn = einsum('bijhd,bkjhd->bijkh', query, key) + attn = attn + bias.unsqueeze(1).expand(-1,L,-1,-1,-1) # (bijkh) + attn = F.softmax(attn, dim=-2) + if self.start_node: + out = einsum('bijkh,bikhd->bijhd', attn, value).reshape(B, L, L, -1) + else: + out = einsum('bijkh,bkjhd->bijhd', attn, value).reshape(B, L, L, -1) + out = gate * out # gated attention + + # output projection + out = self.to_out(out) + return out + +class TriangleMultiplication(nn.Module): + def __init__(self, d_pair, d_hidden=128, outgoing=True): + super(TriangleMultiplication, self).__init__() + self.norm = nn.LayerNorm(d_pair) + self.left_proj = nn.Linear(d_pair, d_hidden) + self.right_proj = nn.Linear(d_pair, d_hidden) + self.left_gate = nn.Linear(d_pair, d_hidden) + self.right_gate = nn.Linear(d_pair, d_hidden) + # + self.gate = nn.Linear(d_pair, d_pair) + self.norm_out = nn.LayerNorm(d_hidden) + self.out_proj = nn.Linear(d_hidden, d_pair) + + self.outgoing = outgoing + + self.reset_parameter() + + def reset_parameter(self): + # normal distribution for regular linear weights + self.left_proj = init_lecun_normal(self.left_proj) + self.right_proj = init_lecun_normal(self.right_proj) + + # Set Bias of Linear layers to zeros + nn.init.zeros_(self.left_proj.bias) + nn.init.zeros_(self.right_proj.bias) + + # gating: zero weights, one biases (mostly open gate at the begining) + nn.init.zeros_(self.left_gate.weight) + nn.init.ones_(self.left_gate.bias) + + nn.init.zeros_(self.right_gate.weight) + nn.init.ones_(self.right_gate.bias) + + nn.init.zeros_(self.gate.weight) + nn.init.ones_(self.gate.bias) + + # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining + nn.init.zeros_(self.out_proj.weight) + nn.init.zeros_(self.out_proj.bias) + + def forward(self, pair): + B, L = pair.shape[:2] + pair = self.norm(pair) + + left = self.left_proj(pair) # (B, L, L, d_h) + left_gate 
= torch.sigmoid(self.left_gate(pair)) + left = left_gate * left + + right = self.right_proj(pair) # (B, L, L, d_h) + right_gate = torch.sigmoid(self.right_gate(pair)) + right = right_gate * right + + if self.outgoing: + out = einsum('bikd,bjkd->bijd', left, right/float(L)) + else: + out = einsum('bkid,bkjd->bijd', left, right/float(L)) + out = self.norm_out(out) + out = self.out_proj(out) + + gate = torch.sigmoid(self.gate(pair)) # (B, L, L, d_pair) + out = gate * out + return out + +# Instead of triangle attention, use Tied axail attention with bias from coordinates..? +class BiasedAxialAttention(nn.Module): + def __init__(self, d_pair, d_bias, n_head, d_hidden, p_drop=0.1, is_row=True): + super(BiasedAxialAttention, self).__init__() + # + self.is_row = is_row + self.norm_pair = nn.LayerNorm(d_pair) + + self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False) + self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False) + self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False) + self.to_b = nn.Linear(d_bias, n_head, bias=False) + self.to_g = nn.Linear(d_pair, n_head*d_hidden) + self.to_out = nn.Linear(n_head*d_hidden, d_pair) + + self.scaling = 1/math.sqrt(d_hidden) + self.h = n_head + self.dim = d_hidden + + # initialize all parameters properly + self.reset_parameter() + + def reset_parameter(self): + # query/key/value projection: Glorot uniform / Xavier uniform + nn.init.xavier_uniform_(self.to_q.weight) + nn.init.xavier_uniform_(self.to_k.weight) + nn.init.xavier_uniform_(self.to_v.weight) + + # bias: normal distribution + self.to_b = init_lecun_normal(self.to_b) + + # gating: zero weights, one biases (mostly open gate at the begining) + nn.init.zeros_(self.to_g.weight) + nn.init.ones_(self.to_g.bias) + + # to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining + nn.init.zeros_(self.to_out.weight) + nn.init.zeros_(self.to_out.bias) + + def forward(self, pair, bias): + # pair: 
(B, L, L, d_pair) + B, L = pair.shape[:2] + + if self.is_row: + pair = pair.permute(0,2,1,3) + + pair = self.norm_pair(pair) + + query = self.to_q(pair).reshape(B, L, L, self.h, self.dim) + key = self.to_k(pair).reshape(B, L, L, self.h, self.dim) + value = self.to_v(pair).reshape(B, L, L, self.h, self.dim) + bias = self.to_b(bias) # (B, L, L, h) + gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim) + + query = query * self.scaling + key = key / math.sqrt(L) # normalize for tied attention + attn = einsum('bnihk,bnjhk->bijh', query, key) # tied attention + attn = attn + bias # apply bias + attn = F.softmax(attn, dim=-2) # (B, L, L, h) + + out = einsum('bijh,bkjhd->bikhd', attn, value).reshape(B, L, L, -1) + out = gate * out + + out = self.to_out(out) + if self.is_row: + out = out.permute(0,2,1,3) + return out + diff --git a/RF2_allatom/AuxiliaryPredictor.py b/RF2_allatom/AuxiliaryPredictor.py new file mode 100644 index 0000000..2145686 --- /dev/null +++ b/RF2_allatom/AuxiliaryPredictor.py @@ -0,0 +1,75 @@ +import torch +import torch.nn as nn + +from chemical import NAATOKENS + +class DistanceNetwork(nn.Module): + def __init__(self, n_feat, p_drop=0.1): + super(DistanceNetwork, self).__init__() + # + self.proj_symm = nn.Linear(n_feat, 37*2) + self.proj_asymm = nn.Linear(n_feat, 37+19) + + self.reset_parameter() + + def reset_parameter(self): + # initialize linear layer for final logit prediction + nn.init.zeros_(self.proj_symm.weight) + nn.init.zeros_(self.proj_asymm.weight) + nn.init.zeros_(self.proj_symm.bias) + nn.init.zeros_(self.proj_asymm.bias) + + def forward(self, x): + # input: pair info (B, L, L, C) + + # predict theta, phi (non-symmetric) + logits_asymm = self.proj_asymm(x) + logits_theta = logits_asymm[:,:,:,:37].permute(0,3,1,2) + logits_phi = logits_asymm[:,:,:,37:].permute(0,3,1,2) + + # predict dist, omega + logits_symm = self.proj_symm(x) + logits_symm = logits_symm + logits_symm.permute(0,2,1,3) + logits_dist = 
logits_symm[:,:,:,:37].permute(0,3,1,2) + logits_omega = logits_symm[:,:,:,37:].permute(0,3,1,2) + + return logits_dist, logits_omega, logits_theta, logits_phi + +class MaskedTokenNetwork(nn.Module): + def __init__(self, n_feat, p_drop=0.1): + super(MaskedTokenNetwork, self).__init__() + + #fd note this predicts probability for the mask token (which is never in ground truth) + # it should be ok though(?) + self.proj = nn.Linear(n_feat, NAATOKENS) + + self.reset_parameter() + + def reset_parameter(self): + nn.init.zeros_(self.proj.weight) + nn.init.zeros_(self.proj.bias) + + def forward(self, x): + B, N, L = x.shape[:3] + logits = self.proj(x).permute(0,3,1,2).reshape(B, -1, N*L) + + return logits + +class LDDTNetwork(nn.Module): + def __init__(self, n_feat, n_bin_lddt=50): + super(LDDTNetwork, self).__init__() + self.proj = nn.Linear(n_feat, n_bin_lddt) + + self.reset_parameter() + + def reset_parameter(self): + nn.init.zeros_(self.proj.weight) + nn.init.zeros_(self.proj.bias) + + def forward(self, x): + logits = self.proj(x) # (B, L, 50) + + return logits.permute(0,2,1) + + + diff --git a/RF2_allatom/Embeddings.py b/RF2_allatom/Embeddings.py new file mode 100644 index 0000000..0c3cb82 --- /dev/null +++ b/RF2_allatom/Embeddings.py @@ -0,0 +1,279 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from opt_einsum import contract as einsum +import torch.utils.checkpoint as checkpoint +from util import * +from util_module import Dropout, get_clones, create_custom_forward, rbf, init_lecun_normal +from Attention_module import Attention, TriangleMultiplication, TriangleAttention, FeedForwardLayer +from Track_module import PairStr2Pair +from chemical import NAATOKENS,NTOTALDOFS, NBTYPES + +# Module contains classes and functions to generate initial embeddings + +class PositionalEncoding2D(nn.Module): + # Add relative positional encoding to pair features + def __init__(self, d_model, minpos=-32, maxpos=32, p_drop=0.1): + super(PositionalEncoding2D, 
self).__init__() + self.minpos = minpos + self.maxpos = maxpos + self.nbin = abs(minpos)+maxpos+1 + self.emb = nn.Embedding(self.nbin, d_model) + + def forward(self, x, idx): + bins = torch.arange(self.minpos, self.maxpos, device=x.device) + seqsep = idx[:,None,:] - idx[:,:,None] # (B, L, L) + # + ib = torch.bucketize(seqsep, bins).long() # (B, L, L) + emb = self.emb(ib) #(B, L, L, d_model) + x = x + emb # add relative positional encoding + return x + +class MSA_emb(nn.Module): + # Get initial seed MSA embedding + def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=2*NAATOKENS+2+2, + minpos=-32, maxpos=32, p_drop=0.1): + super(MSA_emb, self).__init__() + self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA + self.emb_q = nn.Embedding(NAATOKENS, d_msa) # embedding for query sequence -- used for MSA embedding + self.emb_left = nn.Embedding(NAATOKENS, d_pair) # embedding for query sequence -- used for pair embedding + self.emb_right = nn.Embedding(NAATOKENS, d_pair) # embedding for query sequence -- used for pair embedding + self.emb_state = nn.Embedding(NAATOKENS, d_state) + self.pos = PositionalEncoding2D(d_pair, minpos=minpos, maxpos=maxpos, p_drop=p_drop) + + self.reset_parameter() + + def reset_parameter(self): + self.emb = init_lecun_normal(self.emb) + self.emb_q = init_lecun_normal(self.emb_q) + self.emb_left = init_lecun_normal(self.emb_left) + self.emb_right = init_lecun_normal(self.emb_right) + self.emb_state = init_lecun_normal(self.emb_state) + + nn.init.zeros_(self.emb.bias) + + def forward(self, msa, seq, idx): + # Inputs: + # - msa: Input MSA (B, N, L, d_init) + # - seq: Input Sequence (B, L) + # - idx: Residue index + # Outputs: + # - msa: Initial MSA embedding (B, N, L, d_msa) + # - pair: Initial Pair embedding (B, L, L, d_pair) + + N = msa.shape[1] # number of sequenes in MSA + + # msa embedding + msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding + tmp = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_model) -- query embedding + 
msa = msa + tmp.expand(-1, N, -1, -1) # adding query embedding to MSA + #msa = self.drop(msa) + + # pair embedding + left = self.emb_left(seq)[:,None] # (B, 1, L, d_pair) + right = self.emb_right(seq)[:,:,None] # (B, L, 1, d_pair) + pair = left + right # (B, L, L, d_pair) + pair = self.pos(pair, idx) # add relative position + + # state embedding + state = self.emb_state(seq) + + return msa, pair, state + +class Extra_emb(nn.Module): + # Get initial seed MSA embedding + def __init__(self, d_msa=256, d_init=NAATOKENS+1+2, p_drop=0.1): + super(Extra_emb, self).__init__() + self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA + self.emb_q = nn.Embedding(NAATOKENS, d_msa) # embedding for query sequence + #self.drop = nn.Dropout(p_drop) + + self.reset_parameter() + + def reset_parameter(self): + self.emb = init_lecun_normal(self.emb) + nn.init.zeros_(self.emb.bias) + + def forward(self, msa, seq, idx): + # Inputs: + # - msa: Input MSA (B, N, L, d_init) + # - seq: Input Sequence (B, L) + # - idx: Residue index + # Outputs: + # - msa: Initial MSA embedding (B, N, L, d_msa) + N = msa.shape[1] # number of sequenes in MSA + msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding + seq = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_model) -- query embedding + msa = msa + seq.expand(-1, N, -1, -1) # adding query embedding to MSA + #return self.drop(msa) + return (msa) + +class Bond_emb(nn.Module): + def __init__(self, d_pair=128, d_init=NBTYPES): + super(Bond_emb, self).__init__() + self.emb = nn.Linear(d_init, d_pair) + + self.reset_parameter() + + def reset_parameter(self): + self.emb = init_lecun_normal(self.emb) + nn.init.zeros_(self.emb.bias) + + def forward(self, bond_feats): + return self.emb(bond_feats.float()) + +# TODO: Update template embedding not to use triangles.... 
+# Use input xyz_t with biased attention +class TemplatePairStack(nn.Module): + def __init__(self, n_block=2, d_templ=64, n_head=4, d_hidden=32, rbf_sigma=1.0, p_drop=0.25): + super(TemplatePairStack, self).__init__() + self.n_block = n_block + self.rbf_sigma = rbf_sigma + proc_s = [PairStr2Pair(d_pair=d_templ, n_head=n_head, d_hidden=d_hidden, p_drop=p_drop) for i in range(n_block)] + self.block = nn.ModuleList(proc_s) + self.norm = nn.LayerNorm(d_templ) + + def forward(self, templ, xyz_t, use_checkpoint=False): + B, T, L = templ.shape[:3] + templ = templ.reshape(B*T, L, L, -1) + xyz_t = xyz_t.reshape(B*T, L, -1, 3) + rbf_feat = rbf(torch.cdist(xyz_t[:,:,1], xyz_t[:,:,1]), self.rbf_sigma) + + for i_block in range(self.n_block): + if use_checkpoint: + templ = checkpoint.checkpoint(create_custom_forward(self.block[i_block]), templ, rbf_feat) + else: + templ = self.block[i_block](templ, rbf_feat) + return self.norm(templ).reshape(B, T, L, L, -1) + + +class Templ_emb(nn.Module): + # Get template embedding + # Features are + # t2d: + # - 37 distogram bins + 6 orientations (43) + # - Mask (missing/unaligned) (1) + # t1d: + # - tiled AA sequence (20 standard aa + gap) + # - confidence (1) + # + def __init__(self, d_t1d=(NAATOKENS-1)+1, d_t2d=43+1, d_tor=3*NTOTALDOFS, d_pair=128, d_state=32, + n_block=2, d_templ=64, + n_head=4, d_hidden=16, p_drop=0.25): + super(Templ_emb, self).__init__() + # process 2D features + self.emb = nn.Linear(d_t1d*2+d_t2d, d_templ) + self.templ_stack = TemplatePairStack(n_block=n_block, d_templ=d_templ, n_head=n_head, + d_hidden=d_hidden, p_drop=p_drop) + + self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair, p_drop=p_drop) + + # process torsion angles + self.emb_t1d = nn.Linear(d_t1d+d_tor, d_templ) + self.proj_t1d = nn.Linear(d_templ, d_templ) + #self.tor_stack = TemplateTorsionStack(n_block=n_block, d_templ=d_templ, n_head=n_head, + # d_hidden=d_hidden, p_drop=p_drop) + self.attn_tor = Attention(d_state, d_templ, n_head, 
d_hidden, d_state, p_drop=p_drop) + + self.reset_parameter() + + def reset_parameter(self): + self.emb = init_lecun_normal(self.emb) + nn.init.zeros_(self.emb.bias) + + nn.init.kaiming_normal_(self.emb_t1d.weight, nonlinearity='relu') + nn.init.zeros_(self.emb_t1d.bias) + + self.proj_t1d = init_lecun_normal(self.proj_t1d) + nn.init.zeros_(self.proj_t1d.bias) + + def forward(self, t1d, t2d, alpha_t, xyz_t, pair, state, use_checkpoint=False): + # Input + # - t1d: 1D template info (B, T, L, 30) + # - t2d: 2D template info (B, T, L, L, 44) + B, T, L, _ = t1d.shape + + # Prepare 2D template features + left = t1d.unsqueeze(3).expand(-1,-1,-1,L,-1) + right = t1d.unsqueeze(2).expand(-1,-1,L,-1,-1) + # + templ = torch.cat((t2d, left, right), -1) # (B, T, L, L, 88) + templ = self.emb(templ) # Template templures (B, T, L, L, d_templ) + # process each template features + xyz_t = xyz_t.reshape(B*T, L, -1, 3) + templ = self.templ_stack(templ, xyz_t, use_checkpoint=use_checkpoint) # (B, T, L,L, d_templ) + + # Prepare 1D template torsion angle features + t1d = torch.cat((t1d, alpha_t), dim=-1) # (B, T, L, 30+3*17) + # process each template features + t1d = self.proj_t1d(F.relu_(self.emb_t1d(t1d))) + + # mixing query state features to template state features + state = state.reshape(B*L, 1, -1) + t1d = t1d.permute(0,2,1,3).reshape(B*L, T, -1) + if use_checkpoint: + out = checkpoint.checkpoint(create_custom_forward(self.attn_tor), state, t1d, t1d) + out = out.reshape(B, L, -1) + else: + out = self.attn_tor(state, t1d, t1d).reshape(B, L, -1) + state = state.reshape(B, L, -1) + state = state + out + + # mixing query pair features to template information (Template pointwise attention) + pair = pair.reshape(B*L*L, 1, -1) + templ = templ.permute(0, 2, 3, 1, 4).reshape(B*L*L, T, -1) + if use_checkpoint: + out = checkpoint.checkpoint(create_custom_forward(self.attn), pair, templ, templ) + out = out.reshape(B, L, L, -1) + else: + out = self.attn(pair, templ, templ).reshape(B, L, L, -1) + # + 
class RoseTTAFoldModule(nn.Module):
    """Top-level network: input embeddings -> recycling -> iterative simulator -> prediction heads."""

    def __init__(
        self, n_extra_block=4, n_main_block=8, n_ref_block=4, n_finetune_block=0,
        d_msa=256, d_msa_full=64, d_pair=128, d_templ=64,
        n_head_msa=8, n_head_pair=4, n_head_templ=4,
        d_hidden=32, d_hidden_templ=64,
        rbf_sigma=1.0, p_drop=0.15,
        SE3_param={}, SE3_ref_param={},
        atom_type_index=None, aamask=None, ljlk_parameters=None, lj_correction_parameters=None,
        cb_len=None, cb_ang=None, cb_tor=None,
        num_bonds=None, lj_lin=0.6
    ):
        """Build the embedding layers, the recycling layer, the simulator and the output heads.

        ``SE3_param`` must contain ``'l0_out_features'``; its value is used as the
        state dimension everywhere below. The dict defaults are kept for interface
        compatibility and are only read here, never mutated.
        """
        super(RoseTTAFoldModule, self).__init__()

        # Input embeddings
        d_state = SE3_param['l0_out_features']
        self.latent_emb = MSA_emb(d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop)
        self.full_emb = Extra_emb(d_msa=d_msa_full, d_init=NAATOKENS-1+4, p_drop=p_drop)
        self.bond_emb = Bond_emb(d_pair=d_pair, d_init=NBTYPES)
        self.templ_emb = Templ_emb(d_pair=d_pair, d_templ=d_templ, d_state=d_state, n_head=n_head_templ,
                                   d_hidden=d_hidden_templ, p_drop=0.25)

        # Update inputs with outputs from previous round
        self.recycle = Recycling(d_msa=d_msa, d_pair=d_pair, d_state=d_state, rbf_sigma=rbf_sigma)

        self.simulator = IterativeSimulator(
            n_extra_block=n_extra_block,
            n_main_block=n_main_block,
            n_ref_block=n_ref_block,
            n_finetune_block=n_finetune_block,
            d_msa=d_msa,
            d_msa_full=d_msa_full,
            d_pair=d_pair,
            d_hidden=d_hidden,
            n_head_msa=n_head_msa,
            n_head_pair=n_head_pair,
            SE3_param=SE3_param,
            SE3_ref_param=SE3_ref_param,
            rbf_sigma=rbf_sigma,
            p_drop=p_drop,
            atom_type_index=atom_type_index,  # change if encoding elements instead of atomtype
            aamask=aamask,
            ljlk_parameters=ljlk_parameters,
            lj_correction_parameters=lj_correction_parameters,
            num_bonds=num_bonds,
            cb_len=cb_len,
            cb_ang=cb_ang,
            cb_tor=cb_tor,
            lj_lin=lj_lin
        )

        # Prediction heads
        self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
        self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
        self.lddt_pred = LDDTNetwork(d_state)

    def forward(
        self, msa_latent, msa_full, seq, seq_unmasked, xyz, sctors, idx, bond_feats,
        t1d=None, t2d=None, xyz_t=None, alpha_t=None,
        msa_prev=None, pair_prev=None, state_prev=None,
        return_raw=False, return_full=False,
        use_checkpoint=False
    ):
        """Run one pass (one recycle) of the network.

        Any of ``msa_prev`` / ``pair_prev`` / ``state_prev`` left as ``None`` is
        replaced by zeros shaped like the freshly embedded tensor, so the first
        cycle can be called without recycled inputs.

        Returns:
            If ``return_raw``: ``(msa[:,0], pair, xyz_last, state, alpha_s[-1])`` where
            ``xyz_last`` is the last all-atom structure with a leading axis added.
            Otherwise: ``(logits, logits_aa, xyz, alpha_s, xyz_allatom, lddt,
            msa[:,0], pair, state)``.

        ``return_full`` is accepted for interface compatibility but is not used.
        """
        B, _, L = msa_latent.shape[:3]

        # Get embeddings
        msa_latent, pair, state = self.latent_emb(msa_latent, seq, idx)
        msa_full = self.full_emb(msa_full, seq, idx)
        pair = pair + self.bond_emb(bond_feats)

        # Do recycling. Identity checks are used because `==` on a tensor is an
        # elementwise operator; each input is defaulted independently so a
        # partially supplied set of recycled tensors neither crashes nor is dropped.
        if msa_prev is None:
            msa_prev = torch.zeros_like(msa_latent[:,0])
        if pair_prev is None:
            pair_prev = torch.zeros_like(pair)
        if state_prev is None:
            state_prev = torch.zeros_like(state)

        msa_recycle, pair_recycle, state_recycle = self.recycle(msa_prev, pair_prev, xyz, state_prev, sctors)
        # only the first (query) row of the latent MSA receives the recycled features
        msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
        pair = pair + pair_recycle
        state = state + state_recycle

        # add template embedding
        pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, pair, state, use_checkpoint=use_checkpoint)

        # Predict coordinates from given inputs (only the first three atoms of xyz are passed on)
        msa, pair, xyz, alpha_s, xyz_allatom, state = self.simulator(
            seq_unmasked, msa_latent, msa_full, pair, xyz[:,:,:3], state, idx, use_checkpoint=use_checkpoint)

        if return_raw:
            # get last structure
            xyz_last = xyz_allatom[-1].unsqueeze(0)
            return msa[:,0], pair, xyz_last, state, alpha_s[-1]

        # predict masked amino acids
        logits_aa = self.aa_pred(msa)

        # predict distogram & orientograms
        logits = self.c6d_pred(pair)

        # Predict LDDT
        lddt = self.lddt_pred(state)

        return logits, logits_aa, xyz, alpha_s, xyz_allatom, lddt, msa[:,0], pair, state
class SE3TransformerWrapper(nn.Module):
    """SE(3) equivariant GCN with attention, configured from plain feature counts."""

    def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
                 l0_in_features=32, l0_out_features=32,
                 l1_in_features=3, l1_out_features=2,
                 num_edge_features=32):
        super().__init__()
        # remembered so forward() knows whether to pass type-1 node features
        self.l1_in = l1_in_features

        # A degree-1 entry is declared only when its channel count is positive.
        in_structure = {0: l0_in_features}
        if l1_in_features > 0:
            in_structure[1] = l1_in_features
        out_structure = {0: l0_out_features}
        if l1_out_features > 0:
            out_structure[1] = l1_out_features

        fiber_edge = Fiber({0: num_edge_features})
        fiber_in = Fiber(in_structure)
        fiber_hidden = Fiber.create(num_degrees, num_channels)
        fiber_out = Fiber(out_structure)

        self.se3 = SE3Transformer(
            num_layers=num_layers,
            fiber_in=fiber_in,
            fiber_hidden=fiber_hidden,
            fiber_out=fiber_out,
            num_heads=n_heads,
            channels_div=div,
            fiber_edge=fiber_edge,
            use_layer_norm=True,
        )

        self.reset_parameter()

    def reset_parameter(self):
        """Re-initialise the wrapped network's parameters by name and rank."""
        for name, param in self.se3.named_parameters():
            if "bias" in name:
                nn.init.zeros_(param)
                continue
            if len(param.shape) == 1:
                # remaining 1-D parameters are left at their existing values
                continue
            if "radial_func" not in name:
                init_lecun_normal_param(param)
            elif "net.6" in name:
                nn.init.zeros_(param)
            else:
                # radial-function layers other than "net.6": He initialisation
                nn.init.kaiming_normal_(param, nonlinearity='relu')

        # zero the self-interaction kernels of the last graph module
        # NOTE(review): indexes both '0' and '1' unconditionally — confirm '1'
        # exists when the network is built with l1_out_features <= 0.
        last_module = self.se3.graph_modules[-1]
        nn.init.zeros_(last_module.to_kernel_self['0'])
        nn.init.zeros_(last_module.to_kernel_self['1'])

    def forward(self, G, type_0_features, type_1_features=None, edge_features=None):
        """Run the SE(3) transformer on graph ``G`` and return its output.

        Type-1 node features are forwarded only when the wrapper was built with
        ``l1_in_features > 0``.
        """
        node_features = {'0': type_0_features}
        if self.l1_in > 0:
            node_features['1'] = type_1_features
        return self.se3(G, node_features, {'0': edge_features})
class MSAPairStr2MSA(nn.Module):
    """MSA update with row attention biased by pair + distance features, then column attention and a feed-forward layer."""

    def __init__(self, d_msa=256, d_pair=128, n_head=8, d_state=16,
                 d_hidden=32, p_drop=0.15, use_global_attn=False):
        super(MSAPairStr2MSA, self).__init__()
        self.norm_pair = nn.LayerNorm(d_pair)
        # the 36 extra input channels are the distance (rbf) features concatenated in forward()
        self.proj_pair = nn.Linear(d_pair+36, d_pair)
        self.norm_state = nn.LayerNorm(d_state)
        self.proj_state = nn.Linear(d_state, d_msa)
        self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
        self.row_attn = MSARowAttentionWithBias(d_msa=d_msa, d_pair=d_pair,
                                                n_head=n_head, d_hidden=d_hidden)
        column_attention = MSAColGlobalAttention if use_global_attn else MSAColAttention
        self.col_attn = column_attention(d_msa=d_msa, n_head=n_head, d_hidden=d_hidden)
        self.ff = FeedForwardLayer(d_msa, 4, p_drop=p_drop)

        # Do proper initialization
        self.reset_parameter()

    def reset_parameter(self):
        """LeCun-normal weights and zero biases for the two projection layers."""
        for attr in ("proj_pair", "proj_state"):
            layer = init_lecun_normal(getattr(self, attr))
            nn.init.zeros_(layer.bias)
            setattr(self, attr, layer)

    def forward(self, msa, pair, rbf_feat, state):
        '''
        Inputs:
            - msa: MSA feature (B, N, L, d_msa)
            - pair: Pair feature (B, L, L, d_pair)
            - rbf_feat: distance feature calculated from xyz coordinates (B, L, L, 36)
            - state: updated node features after SE(3)-Transformer layer (B, L, d_state)
        Output:
            - msa: Updated MSA feature (B, N, L, d_msa)
        '''
        n_batch, _, n_res = msa.shape[:3]

        # attention bias: normalised pair features joined with the distance features
        bias = self.proj_pair(torch.cat((self.norm_pair(pair), rbf_feat), dim=-1))  # (B, L, L, d_pair)

        # feed the structure state back into the first sequence of the MSA only
        feedback = self.proj_state(self.norm_state(state)).reshape(n_batch, 1, n_res, -1)
        first_row = torch.tensor([0,], device=feedback.device)
        msa = msa.index_add(1, first_row, feedback.float())

        # row attention (with dropout), column attention, feed-forward — all residual
        msa = msa + self.drop_row(self.row_attn(msa, bias))
        msa = msa + self.col_attn(msa)
        msa = msa + self.ff(msa)

        return msa
class Str2Str(nn.Module):
    """Structure update: an SE(3)-Transformer over residues that outputs a new state, a rigid-body move per residue and torsion predictions."""

    def __init__(self, d_msa=256, d_pair=128, d_state=16,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 nextra_l0=0, nextra_l1=0,
                 rbf_sigma=1.0, p_drop=0.1
                 ):
        """Build the embedding layers, the SE(3) network and the torsion predictor.

        ``SE3_param`` is copied before ``nextra_l0`` / ``nextra_l1`` are added to its
        input feature counts, so the caller's dict (and the shared default) is
        never modified.
        """
        super(Str2Str, self).__init__()

        # initial node & pair feature process
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_state = nn.LayerNorm(d_state)

        self.embed_x = nn.Linear(d_msa+d_state, SE3_param['l0_in_features'])
        self.embed_e1 = nn.Linear(d_pair, SE3_param['num_edge_features'])
        # +36 distance (rbf) channels and +1 sequence-separation channel, see forward()
        self.embed_e2 = nn.Linear(SE3_param['num_edge_features']+36+1, SE3_param['num_edge_features'])

        self.norm_node = nn.LayerNorm(SE3_param['l0_in_features'])
        self.norm_edge1 = nn.LayerNorm(SE3_param['num_edge_features'])
        self.norm_edge2 = nn.LayerNorm(SE3_param['num_edge_features'])

        SE3_param_temp = SE3_param.copy()
        SE3_param_temp['l0_in_features'] += nextra_l0
        # 'l1_in_features' is optional (this class's own default dict omits it);
        # fall back to SE3TransformerWrapper's default of 3 rather than raising KeyError.
        SE3_param_temp['l1_in_features'] = SE3_param_temp.get('l1_in_features', 3) + nextra_l1

        self.se3 = SE3TransformerWrapper(**SE3_param_temp)
        self.rbf_sigma = rbf_sigma

        self.sc_predictor = SCPred(
            d_msa=d_msa,
            d_state=SE3_param['l0_out_features'],
            p_drop=p_drop)

        self.reset_parameter()

    def reset_parameter(self):
        """LeCun-normal weights and zero biases for the three embedding layers."""
        # initialize weights to normal distribution
        self.embed_x = init_lecun_normal(self.embed_x)
        self.embed_e1 = init_lecun_normal(self.embed_e1)
        self.embed_e2 = init_lecun_normal(self.embed_e2)

        # initialize bias to zeros
        nn.init.zeros_(self.embed_x.bias)
        nn.init.zeros_(self.embed_e1.bias)
        nn.init.zeros_(self.embed_e2.bias)

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, msa, pair, xyz, state, idx, rotation_mask, extra_l0=None, extra_l1=None, top_k=128, eps=1e-5):
        """Update coordinates, state and torsions from the current MSA/pair/structure.

        Args:
            msa: MSA features; only the first sequence ``msa[:,0]`` is used.
            pair: pair features, embedded into edge features.
            xyz: coordinates (B, L, n_atom, 3); atom index 1 is the per-residue centre.
            state: per-residue state features.
            idx: residue indices, passed to ``get_seqsep`` and the graph builders.
            rotation_mask: boolean mask reshapeable to (B, L, 1, 1); where True the
                predicted rotation is replaced by the identity (translation still
                applies). ``None`` leaves every rotation as predicted.
            extra_l0, extra_l1: optional extra type-0 / type-1 node features.
            top_k: neighbours per node; ``0`` builds the full graph.
            eps: unused, kept for interface compatibility.

        Returns:
            ``(xyz, state, alpha)``: moved coordinates, new state (B, L, C) and the
            output of the side-chain/torsion predictor.
        """
        # process msa & pair features
        B, _, L = msa.shape[:3]
        node = self.norm_msa(msa[:,0])
        pair = self.norm_pair(pair)
        state = self.norm_state(state)

        node = torch.cat((node, state), dim=-1)
        node = self.norm_node(self.embed_x(node))
        pair = self.norm_edge1(self.embed_e1(pair))

        neighbor = get_seqsep(idx)
        cas = xyz[:,:,1].contiguous()
        rbf_feat = rbf(torch.cdist(cas, cas), self.rbf_sigma)
        pair = torch.cat((pair, rbf_feat, neighbor), dim=-1)
        pair = self.norm_edge2(self.embed_e2(pair))

        # define graph
        if top_k != 0:
            G, edge_feats = make_topk_graph(xyz[:,:,1,:], pair, idx, top_k=top_k)
        else:
            G, edge_feats = make_full_graph(xyz[:,:,1,:], pair, idx)
        # type-1 inputs: every atom expressed relative to the residue's centre atom
        l1_feats = xyz - xyz[:,:,1,:].unsqueeze(2)
        l1_feats = l1_feats.reshape(B*L, -1, 3)
        if extra_l1 is not None:
            l1_feats = torch.cat((l1_feats, extra_l1), dim=1)
        if extra_l0 is not None:
            node = torch.cat((node, extra_l0), dim=2)

        # apply SE(3) Transformer & update coordinates
        shift = self.se3(G, node.reshape(B*L, -1, 1), l1_feats, edge_feats)

        state = shift['0'].reshape(B, L, -1)  # (B, L, C)

        # the two type-1 outputs are a translation and a rotation vector (scaled down)
        offset = shift['1'].reshape(B, L, 2, 3)
        T = offset[:,:,0,:] / 10.0
        R = offset[:,:,1,:] / 100.0

        # unit quaternion (1, R) / |(1, R)|, then the corresponding rotation matrix
        Qnorm = torch.sqrt(1 + torch.sum(R*R, dim=-1))
        qA, qB, qC, qD = 1/Qnorm, R[:,:,0]/Qnorm, R[:,:,1]/Qnorm, R[:,:,2]/Qnorm

        v = xyz - xyz[:,:,1:2,:]
        Rout = torch.zeros((B,L,3,3), device=xyz.device)
        Rout[:,:,0,0] = qA*qA+qB*qB-qC*qC-qD*qD
        Rout[:,:,0,1] = 2*qB*qC - 2*qA*qD
        Rout[:,:,0,2] = 2*qB*qD + 2*qA*qC
        Rout[:,:,1,0] = 2*qB*qC + 2*qA*qD
        Rout[:,:,1,1] = qA*qA-qB*qB+qC*qC-qD*qD
        Rout[:,:,1,2] = 2*qC*qD - 2*qA*qB
        Rout[:,:,2,0] = 2*qB*qD - 2*qA*qC
        Rout[:,:,2,1] = 2*qC*qD + 2*qA*qB
        Rout[:,:,2,2] = qA*qA-qB*qB-qC*qC+qD*qD
        if rotation_mask is not None:
            # masked positions keep their orientation and are only translated
            I = torch.eye(3, device=Rout.device).expand(B,L,3,3)
            Rout = torch.where(rotation_mask.reshape(B, L, 1, 1), I, Rout)
        # rotate about the centre atom, put the centre back, then translate
        xyz = torch.einsum('blij,blaj->blai', Rout, v)+xyz[:,:,1:2,:]+T[:,:,None,:]

        alpha = self.sc_predictor(msa[:,0], state)

        return xyz, state, alpha
class SCPred(nn.Module):
    """Two-block residual MLP that maps query-sequence embeddings and state to torsion outputs."""

    def __init__(self, d_msa=256, d_state=32, d_hidden=128, p_drop=0.15):
        super(SCPred, self).__init__()
        self.norm_s0 = nn.LayerNorm(d_msa)
        self.norm_si = nn.LayerNorm(d_state)
        self.linear_s0 = nn.Linear(d_msa, d_hidden)
        self.linear_si = nn.Linear(d_state, d_hidden)

        # ResNet layers
        self.linear_1 = nn.Linear(d_hidden, d_hidden)
        self.linear_2 = nn.Linear(d_hidden, d_hidden)
        self.linear_3 = nn.Linear(d_hidden, d_hidden)
        self.linear_4 = nn.Linear(d_hidden, d_hidden)

        # Final outputs: two values per degree of freedom
        self.linear_out = nn.Linear(d_hidden, 2*NTOTALDOFS)

        self.reset_parameter()

    def reset_parameter(self):
        """Initialise input/output projections, pre-ReLU layers and pre-residual layers."""
        # input and output projections: LeCun normal weights, zero bias
        for attr in ("linear_s0", "linear_si", "linear_out"):
            setattr(self, attr, init_lecun_normal(getattr(self, attr)))
        for attr in ("linear_s0", "linear_si", "linear_out"):
            nn.init.zeros_(getattr(self, attr).bias)

        # layers feeding a ReLU: He initializer (kaiming normal)
        for pre_relu in (self.linear_1, self.linear_3):
            nn.init.kaiming_normal_(pre_relu.weight, nonlinearity='relu')
            nn.init.zeros_(pre_relu.bias)

        # layers feeding a residual add: start at zero
        for pre_residual in (self.linear_2, self.linear_4):
            nn.init.zeros_(pre_residual.weight)
            nn.init.zeros_(pre_residual.bias)

    def forward(self, seq, state):
        '''
        Predict side-chain torsion angles along with backbone torsions
        Inputs:
            - seq: hidden embeddings corresponding to query sequence (B, L, d_msa)
            - state: state feature (output l0 feature) from previous SE3 layer (B, L, d_state)
        Outputs:
            - si: predicted values, two per degree of freedom (B, L, NTOTALDOFS, 2)
        '''
        n_batch, n_res = seq.shape[:2]
        hidden = self.linear_s0(self.norm_s0(seq)) + self.linear_si(self.norm_si(state))

        residual_blocks = ((self.linear_1, self.linear_2), (self.linear_3, self.linear_4))
        for inner, outer in residual_blocks:
            # relu_ works in place, so `hidden` is already rectified when the
            # residual sum below is taken
            update = outer(F.relu_(inner(F.relu_(hidden))))
            hidden = hidden + update

        out = self.linear_out(F.relu_(hidden))
        return out.view(n_batch, n_res, NTOTALDOFS, 2)
class IterativeSimulator(nn.Module):
    """Stack of extra-MSA blocks, main blocks, a shared structure refiner and an optional all-atom fine-tuning stage."""

    def __init__(self, n_extra_block=4, n_main_block=12, n_ref_block=4, n_finetune_block=0,
                 d_msa=256, d_msa_full=64, d_pair=128, d_hidden=32,
                 n_head_msa=8, n_head_pair=4,
                 SE3_param={}, SE3_ref_param={},
                 rbf_sigma=1.0, p_drop=0.15,
                 atom_type_index=None, aamask=None,
                 ljlk_parameters=None, lj_correction_parameters=None,
                 cb_len=None, cb_ang=None, cb_tor=None,
                 num_bonds=None, lj_lin=0.6
                 ):
        """Create the sub-modules; each stage is built only when its block count is positive.

        The dict defaults are kept for interface compatibility and are only read here.
        The chemistry tables (``aamask``, ``num_bonds``, ...) are stored as plain
        attributes for use by the fine-tuning stage.
        """
        super(IterativeSimulator, self).__init__()
        self.n_extra_block = n_extra_block
        self.n_main_block = n_main_block
        self.n_ref_block = n_ref_block
        self.n_finetune_block = n_finetune_block

        self.atom_type_index = atom_type_index
        self.aamask = aamask
        self.ljlk_parameters = ljlk_parameters
        self.lj_correction_parameters = lj_correction_parameters
        self.num_bonds = num_bonds
        self.lj_lin = lj_lin
        self.cb_len = cb_len
        self.cb_ang = cb_ang
        self.cb_tor = cb_tor

        # Update with extra sequences
        if n_extra_block > 0:
            self.extra_block = nn.ModuleList([IterBlock(d_msa=d_msa_full, d_pair=d_pair,
                                                        n_head_msa=n_head_msa,
                                                        n_head_pair=n_head_pair,
                                                        d_hidden_msa=8,
                                                        d_hidden=d_hidden,
                                                        p_drop=p_drop,
                                                        rbf_sigma=rbf_sigma,
                                                        use_global_attn=True,
                                                        SE3_param=SE3_param)
                                              for i in range(n_extra_block)])

        # Update with seed sequences
        if n_main_block > 0:
            self.main_block = nn.ModuleList([IterBlock(d_msa=d_msa, d_pair=d_pair,
                                                       n_head_msa=n_head_msa,
                                                       n_head_pair=n_head_pair,
                                                       d_hidden=d_hidden,
                                                       p_drop=p_drop,
                                                       rbf_sigma=rbf_sigma,
                                                       use_global_attn=False,
                                                       SE3_param=SE3_param)
                                             for i in range(n_main_block)])

        # Final SE(3) refinement: a single module applied n_ref_block times in forward()
        if n_ref_block > 0:
            self.str_refiner = Str2Str(d_msa=d_msa, d_pair=d_pair,
                                       d_state=SE3_param['l0_out_features'],
                                       SE3_param=SE3_ref_param,
                                       rbf_sigma=rbf_sigma,
                                       p_drop=p_drop,
                                       )

        # Fine-tuning all-atom SE(3) refinement
        if n_finetune_block > 0:
            d_state = 16
            self.allatom_embed = AllatomEmbed(
                d_state_in = SE3_param['l0_out_features'],
                d_state_out = d_state,
                p_mask = 0.15
            )
            self.finetune_refiner = Allatom2Allatom(
                SE3_param = {
                    'num_layers':1,
                    'num_channels':16,
                    'num_degrees':2,
                    'l0_in_features':d_state,
                    'l0_out_features':d_state,
                    'l1_in_features':2,
                    'l1_out_features':1,
                    'num_edge_features':4,
                    'n_heads':4,
                    'div':2,
                }
            )
            self.residue_embed = ResidueEmbed(
                d_state_in = d_state,
                d_state_out = SE3_param['l0_out_features']
            )

        # To get all-atom coordinates
        self.compute_allatom_coords = ComputeAllAtomCoords()

    def forward(self, seq_unmasked, msa, msa_full, pair, xyz, state, idx, use_checkpoint=False):
        """Run every configured stage and collect the intermediate structures.

        Args:
            seq_unmasked: unmasked sequence; used for the rotation mask and for
                building all-atom coordinates.
            msa: initial MSA embeddings, updated by the main blocks.
            msa_full: extra-sequence embeddings, updated by the extra blocks.
            pair: initial residue pair embeddings.
            xyz, state, idx: current coordinates, state features and residue indices.
            use_checkpoint: forwarded to each ``IterBlock``.

        Returns:
            ``(msa, pair, xyz, alpha_s, xyzallatom_s, state)`` where ``xyz`` and
            ``alpha_s`` are stacked over all blocks (dim 0) and ``xyzallatom_s``
            is the concatenation (dim 0) of the all-atom structure built after
            refinement plus one entry per fine-tuning block.
        """
        rotation_mask = is_atom(seq_unmasked)
        xyz_s = []
        alpha_s = []

        for i_m in range(self.n_extra_block):
            msa_full, pair, xyz, state, alpha = self.extra_block[i_m](
                msa_full, pair, xyz, state, idx,
                use_checkpoint=use_checkpoint, top_k=0, rotation_mask=rotation_mask)
            xyz_s.append(xyz)
            alpha_s.append(alpha)

        for i_m in range(self.n_main_block):
            msa, pair, xyz, state, alpha = self.main_block[i_m](
                msa, pair, xyz, state, idx,
                use_checkpoint=use_checkpoint, top_k=0, rotation_mask=rotation_mask)
            xyz_s.append(xyz)
            alpha_s.append(alpha)

        # The all-atom build that used to sit here was always overwritten by the
        # one after the refinement loop, so only that one is kept.

        # now use unmasked seq (no cross-talk for msa prediction)
        for i_m in range(self.n_ref_block):
            # Energy-gradient inputs (bond geometry / LJ) are currently disabled,
            # so no extra type-0 / type-1 features are supplied to the refiner.
            extra_l0 = None
            extra_l1 = None

            xyz, state, alpha = self.str_refiner(
                msa, pair, xyz.detach(), state, idx, rotation_mask,
                extra_l0, extra_l1, top_k=128)

            xyz_s.append(xyz)
            alpha_s.append(alpha)

        _, xyzallatom = self.compute_allatom_coords(seq_unmasked, xyz, alpha)  # think about detach here...
        xyzallatom_s = [xyzallatom.clone()]

        if self.n_finetune_block > 0:
            state = self.allatom_embed(state, seq_unmasked, self.atom_type_index)

            for i_m in range(self.n_finetune_block):
                xyzallatom = xyzallatom.detach().float()
                # NOTE: the cart-bonded / LJ gradient terms that should feed the
                # refiner are disabled. The refiner is built with
                # 'l1_in_features': 2, so supply two all-zero terms laid out as
                # (n_terms, *xyzallatom.shape) instead of calling .float() on None.
                extra_l1 = torch.zeros(
                    (2,) + tuple(xyzallatom.shape),
                    dtype=xyzallatom.dtype, device=xyzallatom.device)

                xyzallatom, state = self.finetune_refiner(
                    seq_unmasked,
                    xyzallatom,
                    self.aamask,
                    self.num_bonds,
                    state,
                    extra_l1
                )

                xyzallatom_s.append(xyzallatom.clone())

            state = self.residue_embed(state)

        xyz = torch.stack(xyz_s, dim=0)
        alpha_s = torch.stack(alpha_s, dim=0)
        xyzallatom_s = torch.cat(xyzallatom_s, dim=0)

        return msa, pair, xyz, alpha_s, xyzallatom_s, state
default=1024, + help="Maximum depth of subsampled MSA [1024]") + data_group.add_argument('-maxtoken', type=int, default=2**18, + help="Maximum depth of subsampled MSA [2**18]") + data_group.add_argument('-maxlat', type=int, default=128, + help="Maximum depth of subsampled MSA [128]") + data_group.add_argument("-crop", type=int, default=260, + help="Upper limit of crop size [260]") + data_group.add_argument("-rescut", type=float, default=4.5, + help="Resolution cutoff [4.5]") + data_group.add_argument("-slice", type=str, default="DISCONT", + help="How to make crops [CONT / DISCONT (default)]") + data_group.add_argument("-subsmp", type=str, default="UNI", + help="How to subsample MSAs [UNI (default) / LOG / CONST]") + data_group.add_argument('-mintplt', type=int, default=1, + help="Minimum number of templates to select [1]") + data_group.add_argument('-maxtplt', type=int, default=4, + help="maximum number of templates to select [4]") + data_group.add_argument('-seqid', type=float, default=150.0, + help="maximum sequence identity cutoff for template selection [150.0]") + data_group.add_argument('-maxcycle', type=int, default=4, + help="maximum number of recycle [4]") + + # Trunk module properties + trunk_group = parser.add_argument_group("Trunk module parameters") + trunk_group.add_argument('-n_extra_block', type=int, default=4, + help="Number of iteration blocks for extra sequences [4]") + trunk_group.add_argument('-n_main_block', type=int, default=8, + help="Number of iteration blocks for main sequences [8]") + trunk_group.add_argument('-n_ref_block', type=int, default=4, + help="Number of refinement layers") + trunk_group.add_argument('-n_finetune_block', type=int, default=0, + help="Number of finetune layers" [0]) + trunk_group.add_argument('-d_msa', type=int, default=256, + help="Number of MSA features [256]") + trunk_group.add_argument('-d_msa_full', type=int, default=64, + help="Number of MSA features [64]") + trunk_group.add_argument('-d_pair', type=int, 
default=128, + help="Number of pair features [128]") + trunk_group.add_argument('-d_templ', type=int, default=64, + help="Number of templ features [64]") + trunk_group.add_argument('-n_head_msa', type=int, default=8, + help="Number of attention heads for MSA2MSA [8]") + trunk_group.add_argument('-n_head_pair', type=int, default=4, + help="Number of attention heads for Pair2Pair [4]") + trunk_group.add_argument('-n_head_templ', type=int, default=4, + help="Number of attention heads for template [4]") + trunk_group.add_argument("-d_hidden", type=int, default=32, + help="Number of hidden features [32]") + trunk_group.add_argument("-d_hidden_templ", type=int, default=64, + help="Number of hidden features for templates [64]") + trunk_group.add_argument("-p_drop", type=float, default=0.15, + help="Dropout ratio [0.15]") + trunk_group.add_argument("-rbf_sigma", type=float, default=1.0, + help="Sigma scale factor for RBF [1.0]") + + # Structure module properties + str_group = parser.add_argument_group("structure module parameters") + str_group.add_argument('-num_layers', type=int, default=1, + help="Number of equivariant layers in structure module block [1]") + str_group.add_argument('-num_channels', type=int, default=32, + help="Number of channels [32]") + str_group.add_argument('-num_degrees', type=int, default=2, + help="Number of degrees for SE(3) network [2]") + str_group.add_argument('-l0_in_features', type=int, default=64, + help="Number of type 0 input features [64]") + str_group.add_argument('-l0_out_features', type=int, default=64, + help="Number of type 0 output features [64]") + str_group.add_argument('-l1_in_features', type=int, default=3, + help="Number of type 1 input features [3]") + str_group.add_argument('-l1_out_features', type=int, default=2, + help="Number of type 1 output features [2]") + str_group.add_argument('-num_edge_features', type=int, default=64, + help="Number of edge features [64]") + str_group.add_argument('-n_heads', type=int, default=4, + 
help="Number of attention heads for SE3-Transformer [4]") + str_group.add_argument("-div", type=int, default=4, + help="Div parameter for SE3-Transformer [4]") + str_group.add_argument('-ref_num_layers', type=int, default=1, + help="Number of equivariant layers in structure module block [1]") + str_group.add_argument('-ref_num_channels', type=int, default=32, + help="Number of channels [32]") + + + # Loss function parameters + loss_group = parser.add_argument_group("loss parameters") + loss_group.add_argument('-w_dist', type=float, default=1.0, + help="Weight on distd in loss function [1.0]") + loss_group.add_argument('-w_str', type=float, default=10.0, + help="Weight on strd in loss function [10.0]") + loss_group.add_argument('-w_lddt', type=float, default=0.1, + help="Weight on predicted lddt loss [0.1]") + loss_group.add_argument('-w_aa', type=float, default=3.0, + help="Weight on MSA masked token prediction loss [3.0]") + loss_group.add_argument('-w_bond', type=float, default=0.0, + help="Weight on predicted bond loss [0.0]") + loss_group.add_argument('-w_dih', type=float, default=0.0, + help="Weight on pseudodihedral loss [0.0]") + loss_group.add_argument('-w_clash', type=float, default=0.0, + help="Weight on clash loss [0.0]") + loss_group.add_argument('-w_hb', type=float, default=0.0, + help="Weight on clash loss [0.0]") + loss_group.add_argument('-lj_lin', type=float, default=0.75, + help="linear inflection for lj [0.75]") + + # parse arguments + args = parser.parse_args() + + # Setup dataloader parameters: + loader_param = data_loader.set_data_loader_params(args) + + # make dictionary for each parameters + trunk_param = {} + for param in TRUNK_PARAMS: + trunk_param[param] = getattr(args, param) + SE3_param = {} + for param in SE3_PARAMS: + if hasattr(args, param): + SE3_param[param] = getattr(args, param) + + SE3_ref_param = SE3_param.copy() + + for param in SE3_PARAMS: + if hasattr(args, 'ref_'+param): + SE3_ref_param[param] = getattr(args, 'ref_'+param) 
+ + #print (SE3_param) + #print (SE3_ref_param) + trunk_param['SE3_param'] = SE3_param + trunk_param['SE3_ref_param'] = SE3_ref_param + + loss_param = {} + for param in ['w_dist', 'w_str', 'w_aa', 'w_lddt', 'w_bond', 'w_dih', 'w_clash', 'w_hb', 'lj_lin']: + loss_param[param] = getattr(args, param) + + return args, trunk_param, loader_param, loss_param diff --git a/RF2_allatom/cartbonded.json b/RF2_allatom/cartbonded.json new file mode 100644 index 0000000..f755e12 --- /dev/null +++ b/RF2_allatom/cartbonded.json @@ -0,0 +1,9052 @@ +{ + "lengths": [ + { + "res": "CYS", + "atm1": " SG ", + "atm2": " SG ", + "x0": 0, + "K": 0 + }, + { + "res": "ALA", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23101, + "K": 563.084 + }, + { + "res": "ALA", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "ALA", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.52174, + "K": 260.414 + }, + { + "res": "ALA", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.09008, + "K": 306.735 + }, + { + "res": "ALA", + "atm1": " CB ", + "atm2": "1HB ", + "x0": 1.09004, + "K": 122.8108 + }, + { + "res": "ALA", + "atm1": " CB ", + "atm2": "2HB ", + "x0": 1.09007, + "K": 122.8108 + }, + { + "res": "ALA", + "atm1": " CB ", + "atm2": "3HB ", + "x0": 1.0888, + "K": 122.8108 + }, + { + "res": "ALA", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "ALA", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "ARG", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23102, + "K": 563.084 + }, + { + "res": "ARG", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "ARG", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.52157, + "K": 260.414 + }, + { + "res": "ARG", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.09042, + "K": 306.735 + }, + { + "res": "ARG", + "atm1": " CB ", + "atm2": "1HB ", + "x0": 1.08903, + "K": 117.8526 + }, + { + "res": "ARG", + "atm1": " CB ", + "atm2": "2HB ", + "x0": 1.09007, 
+ "K": 117.8526 + }, + { + "res": "ARG", + "atm1": " CB ", + "atm2": " CG ", + "x0": 1.52044, + "K": 84.8615 + }, + { + "res": "ARG", + "atm1": " CD ", + "atm2": "1HD ", + "x0": 1.08956, + "K": 117.8526 + }, + { + "res": "ARG", + "atm1": " CD ", + "atm2": "2HD ", + "x0": 1.08936, + "K": 117.8526 + }, + { + "res": "ARG", + "atm1": " CD ", + "atm2": " NE ", + "x0": 1.45407, + "K": 99.5454 + }, + { + "res": "ARG", + "atm1": " CG ", + "atm2": "1HG ", + "x0": 1.09024, + "K": 117.8526 + }, + { + "res": "ARG", + "atm1": " CG ", + "atm2": "2HG ", + "x0": 1.09062, + "K": 117.8526 + }, + { + "res": "ARG", + "atm1": " CG ", + "atm2": " CD ", + "x0": 1.48537, + "K": 84.8615 + }, + { + "res": "ARG", + "atm1": " CZ ", + "atm2": " NH1", + "x0": 1.31462, + "K": 176.5882 + }, + { + "res": "ARG", + "atm1": " CZ ", + "atm2": " NH2", + "x0": 1.32163, + "K": 176.5882 + }, + { + "res": "ARG", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "ARG", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "ARG", + "atm1": " NE ", + "atm2": " CZ ", + "x0": 1.34729, + "K": 176.5882 + }, + { + "res": "ARG", + "atm1": " NE ", + "atm2": " HE ", + "x0": 1.01136, + "K": 173.537 + }, + { + "res": "ARG", + "atm1": " NH1", + "atm2": "1HH1", + "x0": 1.01026, + "K": 173.537 + }, + { + "res": "ARG", + "atm1": " NH1", + "atm2": "2HH1", + "x0": 1.00968, + "K": 173.537 + }, + { + "res": "ARG", + "atm1": " NH2", + "atm2": "1HH2", + "x0": 1.01091, + "K": 173.537 + }, + { + "res": "ARG", + "atm1": " NH2", + "atm2": "2HH2", + "x0": 1.00912, + "K": 173.537 + }, + { + "res": "ASN", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23102, + "K": 563.084 + }, + { + "res": "ASN", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "ASN", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.51768, + "K": 260.414 + }, + { + "res": "ASN", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.09116, + "K": 306.735 + }, + { + "res": "ASN", + "atm1": 
" CB ", + "atm2": "1HB ", + "x0": 1.0908, + "K": 117.8526 + }, + { + "res": "ASN", + "atm1": " CB ", + "atm2": "2HB ", + "x0": 1.09009, + "K": 117.8526 + }, + { + "res": "ASN", + "atm1": " CB ", + "atm2": " CG ", + "x0": 1.50351, + "K": 76.28 + }, + { + "res": "ASN", + "atm1": " CG ", + "atm2": " ND2", + "x0": 1.30864, + "K": 164.002 + }, + { + "res": "ASN", + "atm1": " CG ", + "atm2": " OD1", + "x0": 1.2364, + "K": 247.91 + }, + { + "res": "ASN", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "ASN", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "ASN", + "atm1": " ND2", + "atm2": "1HD2", + "x0": 1.00047, + "K": 183.072 + }, + { + "res": "ASN", + "atm1": " ND2", + "atm2": "2HD2", + "x0": 0.999495, + "K": 183.072 + }, + { + "res": "ASP", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23102, + "K": 563.084 + }, + { + "res": "ASP", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "ASP", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.53065, + "K": 260.414 + }, + { + "res": "ASP", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.09038, + "K": 306.735 + }, + { + "res": "ASP", + "atm1": " CB ", + "atm2": "1HB ", + "x0": 1.09024, + "K": 117.8526 + }, + { + "res": "ASP", + "atm1": " CB ", + "atm2": "2HB ", + "x0": 1.09039, + "K": 117.8526 + }, + { + "res": "ASP", + "atm1": " CB ", + "atm2": " CG ", + "x0": 1.52279, + "K": 76.28 + }, + { + "res": "ASP", + "atm1": " CG ", + "atm2": " OD1", + "x0": 1.20825, + "K": 200.235 + }, + { + "res": "ASP", + "atm1": " CG ", + "atm2": " OD2", + "x0": 1.20776, + "K": 200.235 + }, + { + "res": "ASP", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "ASP", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "CYS", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23102, + "K": 563.084 + }, + { + "res": "CYS", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + 
"res": "CYS", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.52886, + "K": 260.414 + }, + { + "res": "CYS", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.09006, + "K": 306.735 + }, + { + "res": "CYS", + "atm1": " CB ", + "atm2": "1HB ", + "x0": 1.09025, + "K": 117.8526 + }, + { + "res": "CYS", + "atm1": " CB ", + "atm2": "2HB ", + "x0": 1.08982, + "K": 117.8526 + }, + { + "res": "CYS", + "atm1": " CB ", + "atm2": " SG ", + "x0": 1.8088, + "K": 75.5172 + }, + { + "res": "CYS", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "CYS", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "CYS", + "atm1": " SG ", + "atm2": " HG ", + "x0": 1.32937, + "K": 104.885 + }, + { + "res": "GLN", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23102, + "K": 563.084 + }, + { + "res": "GLN", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "GLN", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.53107, + "K": 260.414 + }, + { + "res": "GLN", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.08987, + "K": 306.735 + }, + { + "res": "GLN", + "atm1": " CB ", + "atm2": "1HB ", + "x0": 1.09, + "K": 117.8526 + }, + { + "res": "GLN", + "atm1": " CB ", + "atm2": "2HB ", + "x0": 1.09, + "K": 117.8526 + }, + { + "res": "GLN", + "atm1": " CB ", + "atm2": " CG ", + "x0": 1.51911, + "K": 84.8615 + }, + { + "res": "GLN", + "atm1": " CD ", + "atm2": " NE2", + "x0": 1.32811, + "K": 164.002 + }, + { + "res": "GLN", + "atm1": " CD ", + "atm2": " OE1", + "x0": 1.23416, + "K": 247.91 + }, + { + "res": "GLN", + "atm1": " CG ", + "atm2": "1HG ", + "x0": 1.08971, + "K": 117.8526 + }, + { + "res": "GLN", + "atm1": " CG ", + "atm2": "2HG ", + "x0": 1.09, + "K": 117.8526 + }, + { + "res": "GLN", + "atm1": " CG ", + "atm2": " CD ", + "x0": 1.51688, + "K": 76.28 + }, + { + "res": "GLN", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "GLN", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + 
"K": 458.304 + }, + { + "res": "GLN", + "atm1": " NE2", + "atm2": "1HE2", + "x0": 1.00096, + "K": 183.072 + }, + { + "res": "GLN", + "atm1": " NE2", + "atm2": "2HE2", + "x0": 1.00008, + "K": 183.072 + }, + { + "res": "GLU", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23102, + "K": 563.084 + }, + { + "res": "GLU", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "GLU", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.53032, + "K": 260.414 + }, + { + "res": "GLU", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.09051, + "K": 306.735 + }, + { + "res": "GLU", + "atm1": " CB ", + "atm2": "1HB ", + "x0": 1.09015, + "K": 117.8526 + }, + { + "res": "GLU", + "atm1": " CB ", + "atm2": "2HB ", + "x0": 1.09006, + "K": 117.8526 + }, + { + "res": "GLU", + "atm1": " CB ", + "atm2": " CG ", + "x0": 1.52211, + "K": 84.8615 + }, + { + "res": "GLU", + "atm1": " CD ", + "atm2": " OE1", + "x0": 1.20758, + "K": 200.235 + }, + { + "res": "GLU", + "atm1": " CD ", + "atm2": " OE2", + "x0": 1.20854, + "K": 200.235 + }, + { + "res": "GLU", + "atm1": " CG ", + "atm2": "1HG ", + "x0": 1.08969, + "K": 117.8526 + }, + { + "res": "GLU", + "atm1": " CG ", + "atm2": "2HG ", + "x0": 1.08911, + "K": 117.8526 + }, + { + "res": "GLU", + "atm1": " CG ", + "atm2": " CD ", + "x0": 1.50336, + "K": 76.28 + }, + { + "res": "GLU", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "GLU", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "GLY", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23102, + "K": 563.084 + }, + { + "res": "GLY", + "atm1": " CA ", + "atm2": "1HA ", + "x0": 1.09017, + "K": 330 + }, + { + "res": "GLY", + "atm1": " CA ", + "atm2": "2HA ", + "x0": 1.08935, + "K": 330 + }, + { + "res": "GLY", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "GLY", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "GLY", + "atm1": " N ", + "atm2": " H 
", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "HIS", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23102, + "K": 563.084 + }, + { + "res": "HIS", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "HIS", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.53212, + "K": 260.414 + }, + { + "res": "HIS", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.08962, + "K": 306.735 + }, + { + "res": "HIS", + "atm1": " CB ", + "atm2": "1HB ", + "x0": 1.09058, + "K": 117.8526 + }, + { + "res": "HIS", + "atm1": " CB ", + "atm2": "2HB ", + "x0": 1.08966, + "K": 117.8526 + }, + { + "res": "HIS", + "atm1": " CB ", + "atm2": " CG ", + "x0": 1.49716, + "K": 87.56944 + }, + { + "res": "HIS", + "atm1": " CD2", + "atm2": "2HD ", + "x0": 1.09034, + "K": 139.211 + }, + { + "res": "HIS", + "atm1": " CD2", + "atm2": " NE2", + "x0": 1.37321, + "K": 152.56 + }, + { + "res": "HIS", + "atm1": " CE1", + "atm2": "1HE ", + "x0": 1.08979, + "K": 129.676 + }, + { + "res": "HIS", + "atm1": " CE1", + "atm2": " NE2", + "x0": 1.32038, + "K": 152.56 + }, + { + "res": "HIS", + "atm1": " CG ", + "atm2": " CD2", + "x0": 1.35365, + "K": 156.374 + }, + { + "res": "HIS", + "atm1": " CG ", + "atm2": " ND1", + "x0": 1.37916, + "K": 152.56 + }, + { + "res": "HIS", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "HIS", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "HIS", + "atm1": " ND1", + "atm2": " CE1", + "x0": 1.32193, + "K": 152.56 + }, + { + "res": "HIS", + "atm1": " NE2", + "atm2": "2HE ", + "x0": 1.01008, + "K": 177.7324 + }, + { + "res": "HIS_D", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23102, + "K": 563.084 + }, + { + "res": "HIS_D", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "HIS_D", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.53212, + "K": 260.414 + }, + { + "res": "HIS_D", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.08962, + "K": 306.735 + }, + { + "res": 
"HIS_D", + "atm1": " CB ", + "atm2": "1HB ", + "x0": 1.09058, + "K": 117.8526 + }, + { + "res": "HIS_D", + "atm1": " CB ", + "atm2": "2HB ", + "x0": 1.08966, + "K": 117.8526 + }, + { + "res": "HIS_D", + "atm1": " CB ", + "atm2": " CG ", + "x0": 1.49716, + "K": 87.56944 + }, + { + "res": "HIS_D", + "atm1": " CD2", + "atm2": "2HD ", + "x0": 1.09034, + "K": 139.211 + }, + { + "res": "HIS_D", + "atm1": " CD2", + "atm2": " NE2", + "x0": 1.37321, + "K": 152.56 + }, + { + "res": "HIS_D", + "atm1": " CE1", + "atm2": "1HE ", + "x0": 1.08979, + "K": 129.676 + }, + { + "res": "HIS_D", + "atm1": " CE1", + "atm2": " NE2", + "x0": 1.32038, + "K": 152.56 + }, + { + "res": "HIS_D", + "atm1": " CG ", + "atm2": " CD2", + "x0": 1.35365, + "K": 156.374 + }, + { + "res": "HIS_D", + "atm1": " CG ", + "atm2": " ND1", + "x0": 1.37916, + "K": 152.56 + }, + { + "res": "HIS_D", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "HIS_D", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "HIS_D", + "atm1": " ND1", + "atm2": " CE1", + "x0": 1.32193, + "K": 152.56 + }, + { + "res": "HIS_D", + "atm1": " ND1", + "atm2": "1HD ", + "x0": 1.00024, + "K": 177.7324 + }, + { + "res": "ILE", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23102, + "K": 563.084 + }, + { + "res": "ILE", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "ILE", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.53963, + "K": 260.414 + }, + { + "res": "ILE", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.08933, + "K": 306.735 + }, + { + "res": "ILE", + "atm1": " CB ", + "atm2": " CG1", + "x0": 1.5309, + "K": 84.8615 + }, + { + "res": "ILE", + "atm1": " CB ", + "atm2": " CG2", + "x0": 1.52091, + "K": 84.8615 + }, + { + "res": "ILE", + "atm1": " CB ", + "atm2": " HB ", + "x0": 1.08955, + "K": 117.8526 + }, + { + "res": "ILE", + "atm1": " CD1", + "atm2": "1HD1", + "x0": 1.09029, + "K": 122.8108 + }, + { + "res": "ILE", + "atm1": " CD1", + 
"atm2": "2HD1", + "x0": 1.09058, + "K": 122.8108 + }, + { + "res": "ILE", + "atm1": " CD1", + "atm2": "3HD1", + "x0": 1.08906, + "K": 122.8108 + }, + { + "res": "ILE", + "atm1": " CG1", + "atm2": "1HG1", + "x0": 1.08947, + "K": 117.8526 + }, + { + "res": "ILE", + "atm1": " CG1", + "atm2": "2HG1", + "x0": 1.09034, + "K": 117.8526 + }, + { + "res": "ILE", + "atm1": " CG1", + "atm2": " CD1", + "x0": 1.51168, + "K": 84.8615 + }, + { + "res": "ILE", + "atm1": " CG2", + "atm2": "1HG2", + "x0": 1.08976, + "K": 122.8108 + }, + { + "res": "ILE", + "atm1": " CG2", + "atm2": "2HG2", + "x0": 1.08915, + "K": 122.8108 + }, + { + "res": "ILE", + "atm1": " CG2", + "atm2": "3HG2", + "x0": 1.09015, + "K": 122.8108 + }, + { + "res": "ILE", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "ILE", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "LEU", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23102, + "K": 563.084 + }, + { + "res": "LEU", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "LEU", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.53385, + "K": 260.414 + }, + { + "res": "LEU", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.08944, + "K": 306.735 + }, + { + "res": "LEU", + "atm1": " CB ", + "atm2": "1HB ", + "x0": 1.0888, + "K": 117.8526 + }, + { + "res": "LEU", + "atm1": " CB ", + "atm2": "2HB ", + "x0": 1.08992, + "K": 117.8526 + }, + { + "res": "LEU", + "atm1": " CB ", + "atm2": " CG ", + "x0": 1.53403, + "K": 84.8615 + }, + { + "res": "LEU", + "atm1": " CD1", + "atm2": "1HD1", + "x0": 1.0901, + "K": 122.8108 + }, + { + "res": "LEU", + "atm1": " CD1", + "atm2": "2HD1", + "x0": 1.09008, + "K": 122.8108 + }, + { + "res": "LEU", + "atm1": " CD1", + "atm2": "3HD1", + "x0": 1.08937, + "K": 122.8108 + }, + { + "res": "LEU", + "atm1": " CD2", + "atm2": "1HD2", + "x0": 1.09047, + "K": 122.8108 + }, + { + "res": "LEU", + "atm1": " CD2", + "atm2": "2HD2", + "x0": 1.09021, + "K": 
122.8108 + }, + { + "res": "LEU", + "atm1": " CD2", + "atm2": "3HD2", + "x0": 1.0901, + "K": 122.8108 + }, + { + "res": "LEU", + "atm1": " CG ", + "atm2": " CD1", + "x0": 1.52267, + "K": 84.8615 + }, + { + "res": "LEU", + "atm1": " CG ", + "atm2": " CD2", + "x0": 1.52143, + "K": 84.8615 + }, + { + "res": "LEU", + "atm1": " CG ", + "atm2": " HG ", + "x0": 1.09033, + "K": 117.8526 + }, + { + "res": "LEU", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "LEU", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "LYS", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23101, + "K": 563.084 + }, + { + "res": "LYS", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "LYS", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.52951, + "K": 260.414 + }, + { + "res": "LYS", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.09053, + "K": 306.735 + }, + { + "res": "LYS", + "atm1": " CB ", + "atm2": "1HB ", + "x0": 1.09033, + "K": 117.8526 + }, + { + "res": "LYS", + "atm1": " CB ", + "atm2": "2HB ", + "x0": 1.09021, + "K": 117.8526 + }, + { + "res": "LYS", + "atm1": " CB ", + "atm2": " CG ", + "x0": 1.52293, + "K": 84.8615 + }, + { + "res": "LYS", + "atm1": " CD ", + "atm2": "1HD ", + "x0": 1.09086, + "K": 117.8526 + }, + { + "res": "LYS", + "atm1": " CD ", + "atm2": "2HD ", + "x0": 1.08996, + "K": 117.8526 + }, + { + "res": "LYS", + "atm1": " CD ", + "atm2": " CE ", + "x0": 1.52158, + "K": 84.8615 + }, + { + "res": "LYS", + "atm1": " CE ", + "atm2": "1HE ", + "x0": 1.08899, + "K": 117.8526 + }, + { + "res": "LYS", + "atm1": " CE ", + "atm2": "2HE ", + "x0": 1.08993, + "K": 117.8526 + }, + { + "res": "LYS", + "atm1": " CE ", + "atm2": " NZ ", + "x0": 1.48811, + "K": 76.28 + }, + { + "res": "LYS", + "atm1": " CG ", + "atm2": "1HG ", + "x0": 1.09065, + "K": 117.8526 + }, + { + "res": "LYS", + "atm1": " CG ", + "atm2": "2HG ", + "x0": 1.08954, + "K": 117.8526 + }, + { + "res": "LYS", + "atm1": " CG 
", + "atm2": " CD ", + "x0": 1.52135, + "K": 84.8615 + }, + { + "res": "LYS", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "LYS", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "LYS", + "atm1": " NZ ", + "atm2": "1HZ ", + "x0": 1.01022, + "K": 153.7042 + }, + { + "res": "LYS", + "atm1": " NZ ", + "atm2": "2HZ ", + "x0": 1.01051, + "K": 153.7042 + }, + { + "res": "LYS", + "atm1": " NZ ", + "atm2": "3HZ ", + "x0": 1.00991, + "K": 153.7042 + }, + { + "res": "MET", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23102, + "K": 563.084 + }, + { + "res": "MET", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "MET", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.52739, + "K": 260.414 + }, + { + "res": "MET", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.08982, + "K": 306.735 + }, + { + "res": "MET", + "atm1": " CB ", + "atm2": "1HB ", + "x0": 1.08903, + "K": 117.8526 + }, + { + "res": "MET", + "atm1": " CB ", + "atm2": "2HB ", + "x0": 1.09075, + "K": 117.8526 + }, + { + "res": "MET", + "atm1": " CB ", + "atm2": " CG ", + "x0": 1.5222, + "K": 84.8615 + }, + { + "res": "MET", + "atm1": " CE ", + "atm2": "1HE ", + "x0": 1.09027, + "K": 122.8108 + }, + { + "res": "MET", + "atm1": " CE ", + "atm2": "2HE ", + "x0": 1.09085, + "K": 122.8108 + }, + { + "res": "MET", + "atm1": " CE ", + "atm2": "3HE ", + "x0": 1.09057, + "K": 122.8108 + }, + { + "res": "MET", + "atm1": " CG ", + "atm2": "1HG ", + "x0": 1.08947, + "K": 117.8526 + }, + { + "res": "MET", + "atm1": " CG ", + "atm2": "2HG ", + "x0": 1.08981, + "K": 117.8526 + }, + { + "res": "MET", + "atm1": " CG ", + "atm2": " SD ", + "x0": 1.80384, + "K": 75.5172 + }, + { + "res": "MET", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "MET", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "MET", + "atm1": " SD ", + "atm2": " CE ", + "x0": 1.79039, + "K": 91.536 + }, 
+ { + "res": "PHE", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23102, + "K": 563.084 + }, + { + "res": "PHE", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "PHE", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.5298, + "K": 260.414 + }, + { + "res": "PHE", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.0909, + "K": 306.735 + }, + { + "res": "PHE", + "atm1": " CB ", + "atm2": "1HB ", + "x0": 1.08922, + "K": 117.8526 + }, + { + "res": "PHE", + "atm1": " CB ", + "atm2": "2HB ", + "x0": 1.08919, + "K": 117.8526 + }, + { + "res": "PHE", + "atm1": " CB ", + "atm2": " CG ", + "x0": 1.50223, + "K": 87.722 + }, + { + "res": "PHE", + "atm1": " CD1", + "atm2": " CE1", + "x0": 1.38218, + "K": 116.327 + }, + { + "res": "PHE", + "atm1": " CD1", + "atm2": "1HD ", + "x0": 1.09033, + "K": 129.676 + }, + { + "res": "PHE", + "atm1": " CD2", + "atm2": " CE2", + "x0": 1.38128, + "K": 116.327 + }, + { + "res": "PHE", + "atm1": " CD2", + "atm2": "2HD ", + "x0": 1.0906, + "K": 129.676 + }, + { + "res": "PHE", + "atm1": " CE1", + "atm2": " CZ ", + "x0": 1.37858, + "K": 116.327 + }, + { + "res": "PHE", + "atm1": " CE1", + "atm2": "1HE ", + "x0": 1.08972, + "K": 129.676 + }, + { + "res": "PHE", + "atm1": " CE2", + "atm2": " CZ ", + "x0": 1.38049, + "K": 116.327 + }, + { + "res": "PHE", + "atm1": " CE2", + "atm2": "2HE ", + "x0": 1.08983, + "K": 129.676 + }, + { + "res": "PHE", + "atm1": " CG ", + "atm2": " CD1", + "x0": 1.38696, + "K": 116.327 + }, + { + "res": "PHE", + "atm1": " CG ", + "atm2": " CD2", + "x0": 1.3866, + "K": 116.327 + }, + { + "res": "PHE", + "atm1": " CZ ", + "atm2": " HZ ", + "x0": 1.08908, + "K": 129.676 + }, + { + "res": "PHE", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "PHE", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "PRO", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23101, + "K": 563.084 + }, + { + "res": "PRO", + "atm1": " CA ", + "atm2": " C ", + "x0": 
1.52326, + "K": 301.675 + }, + { + "res": "PRO", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.532, + "K": 260.414 + }, + { + "res": "PRO", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.1, + "K": 306.735 + }, + { + "res": "PRO", + "atm1": " CB ", + "atm2": "1HB ", + "x0": 1.1, + "K": 117.8526 + }, + { + "res": "PRO", + "atm1": " CB ", + "atm2": "2HB ", + "x0": 1.1, + "K": 117.8526 + }, + { + "res": "PRO", + "atm1": " CB ", + "atm2": " CG ", + "x0": 1.4906, + "K": 84.8615 + }, + { + "res": "PRO", + "atm1": " CD ", + "atm2": "1HD ", + "x0": 1.1, + "K": 117.8526 + }, + { + "res": "PRO", + "atm1": " CD ", + "atm2": "2HD ", + "x0": 1.1, + "K": 117.8526 + }, + { + "res": "PRO", + "atm1": " CG ", + "atm2": "1HG ", + "x0": 1.1, + "K": 117.8526 + }, + { + "res": "PRO", + "atm1": " CG ", + "atm2": "2HG ", + "x0": 1.1, + "K": 117.8526 + }, + { + "res": "PRO", + "atm1": " CG ", + "atm2": " CD ", + "x0": 1.5055, + "K": 84.8615 + }, + { + "res": "PRO", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "PRO", + "atm1": " N ", + "atm2": " CD ", + "x0": 1.473, + "K": 122.048 + }, + { + "res": "SER", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23102, + "K": 563.084 + }, + { + "res": "SER", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "SER", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.51626, + "K": 260.414 + }, + { + "res": "SER", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.09092, + "K": 306.735 + }, + { + "res": "SER", + "atm1": " CB ", + "atm2": "1HB ", + "x0": 1.08969, + "K": 117.8526 + }, + { + "res": "SER", + "atm1": " CB ", + "atm2": "2HB ", + "x0": 1.08918, + "K": 117.8526 + }, + { + "res": "SER", + "atm1": " CB ", + "atm2": " OG ", + "x0": 1.40119, + "K": 163.2392 + }, + { + "res": "SER", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "SER", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "SER", + "atm1": " OG ", + "atm2": " HG ", + 
"x0": 0.960175, + "K": 207.863 + }, + { + "res": "THR", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23102, + "K": 563.084 + }, + { + "res": "THR", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "THR", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.53992, + "K": 260.414 + }, + { + "res": "THR", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.09026, + "K": 306.735 + }, + { + "res": "THR", + "atm1": " CB ", + "atm2": " CG2", + "x0": 1.52099, + "K": 84.8615 + }, + { + "res": "THR", + "atm1": " CB ", + "atm2": " HB ", + "x0": 1.08982, + "K": 117.8526 + }, + { + "res": "THR", + "atm1": " CB ", + "atm2": " OG1", + "x0": 1.43355, + "K": 163.2392 + }, + { + "res": "THR", + "atm1": " CG2", + "atm2": "1HG2", + "x0": 1.08983, + "K": 122.8108 + }, + { + "res": "THR", + "atm1": " CG2", + "atm2": "2HG2", + "x0": 1.08986, + "K": 122.8108 + }, + { + "res": "THR", + "atm1": " CG2", + "atm2": "3HG2", + "x0": 1.08924, + "K": 122.8108 + }, + { + "res": "THR", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "THR", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "THR", + "atm1": " OG1", + "atm2": " HG1", + "x0": 0.960297, + "K": 207.863 + }, + { + "res": "TRP", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23101, + "K": 563.084 + }, + { + "res": "TRP", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "TRP", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.52982, + "K": 260.414 + }, + { + "res": "TRP", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.08988, + "K": 306.735 + }, + { + "res": "TRP", + "atm1": " CB ", + "atm2": "1HB ", + "x0": 1.09017, + "K": 117.8526 + }, + { + "res": "TRP", + "atm1": " CB ", + "atm2": "2HB ", + "x0": 1.08979, + "K": 117.8526 + }, + { + "res": "TRP", + "atm1": " CB ", + "atm2": " CG ", + "x0": 1.49875, + "K": 87.722 + }, + { + "res": "TRP", + "atm1": " CD1", + "atm2": "1HD ", + "x0": 1.08852, + "K": 129.676 + }, + { + "res": 
"TRP", + "atm1": " CD1", + "atm2": " NE1", + "x0": 1.37294, + "K": 102.978 + }, + { + "res": "TRP", + "atm1": " CD2", + "atm2": " CE2", + "x0": 1.39734, + "K": 137.304 + }, + { + "res": "TRP", + "atm1": " CD2", + "atm2": " CE3", + "x0": 1.40038, + "K": 116.327 + }, + { + "res": "TRP", + "atm1": " CE2", + "atm2": " CZ2", + "x0": 1.38595, + "K": 116.327 + }, + { + "res": "TRP", + "atm1": " CE3", + "atm2": " CZ3", + "x0": 1.38985, + "K": 116.327 + }, + { + "res": "TRP", + "atm1": " CE3", + "atm2": " HE3", + "x0": 1.08954, + "K": 129.676 + }, + { + "res": "TRP", + "atm1": " CG ", + "atm2": " CD1", + "x0": 1.36272, + "K": 133.49 + }, + { + "res": "TRP", + "atm1": " CG ", + "atm2": " CD2", + "x0": 1.44821, + "K": 133.49 + }, + { + "res": "TRP", + "atm1": " CH2", + "atm2": " HH2", + "x0": 1.09029, + "K": 129.676 + }, + { + "res": "TRP", + "atm1": " CZ2", + "atm2": " CH2", + "x0": 1.39502, + "K": 116.327 + }, + { + "res": "TRP", + "atm1": " CZ2", + "atm2": " HZ2", + "x0": 1.09024, + "K": 129.676 + }, + { + "res": "TRP", + "atm1": " CZ3", + "atm2": " CH2", + "x0": 1.37211, + "K": 116.327 + }, + { + "res": "TRP", + "atm1": " CZ3", + "atm2": " HZ3", + "x0": 1.09029, + "K": 129.676 + }, + { + "res": "TRP", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "TRP", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "TRP", + "atm1": " NE1", + "atm2": " CE2", + "x0": 1.37213, + "K": 102.978 + }, + { + "res": "TRP", + "atm1": " NE1", + "atm2": "1HE ", + "x0": 1.00989, + "K": 177.351 + }, + { + "res": "TYR", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23101, + "K": 563.084 + }, + { + "res": "TYR", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "TYR", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.53035, + "K": 260.414 + }, + { + "res": "TYR", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.09032, + "K": 306.735 + }, + { + "res": "TYR", + "atm1": " CB ", + "atm2": "1HB ", + "x0": 1.09097, 
+ "K": 117.8526 + }, + { + "res": "TYR", + "atm1": " CB ", + "atm2": "2HB ", + "x0": 1.08957, + "K": 117.8526 + }, + { + "res": "TYR", + "atm1": " CB ", + "atm2": " CG ", + "x0": 1.51266, + "K": 87.722 + }, + { + "res": "TYR", + "atm1": " CD1", + "atm2": " CE1", + "x0": 1.38155, + "K": 116.327 + }, + { + "res": "TYR", + "atm1": " CD1", + "atm2": "1HD ", + "x0": 1.09023, + "K": 129.676 + }, + { + "res": "TYR", + "atm1": " CD2", + "atm2": " CE2", + "x0": 1.38136, + "K": 116.327 + }, + { + "res": "TYR", + "atm1": " CD2", + "atm2": "2HD ", + "x0": 1.09002, + "K": 129.676 + }, + { + "res": "TYR", + "atm1": " CE1", + "atm2": " CZ ", + "x0": 1.39041, + "K": 116.327 + }, + { + "res": "TYR", + "atm1": " CE1", + "atm2": "1HE ", + "x0": 1.08966, + "K": 129.676 + }, + { + "res": "TYR", + "atm1": " CE2", + "atm2": " CZ ", + "x0": 1.37999, + "K": 116.327 + }, + { + "res": "TYR", + "atm1": " CE2", + "atm2": "2HE ", + "x0": 1.09029, + "K": 129.676 + }, + { + "res": "TYR", + "atm1": " CG ", + "atm2": " CD1", + "x0": 1.38719, + "K": 116.327 + }, + { + "res": "TYR", + "atm1": " CG ", + "atm2": " CD2", + "x0": 1.38689, + "K": 116.327 + }, + { + "res": "TYR", + "atm1": " CZ ", + "atm2": " OH ", + "x0": 1.37596, + "K": 127.50202 + }, + { + "res": "TYR", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "TYR", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + }, + { + "res": "TYR", + "atm1": " OH ", + "atm2": " HH ", + "x0": 0.960239, + "K": 207.863 + }, + { + "res": "VAL", + "atm1": " C ", + "atm2": " O ", + "x0": 1.23102, + "K": 563.084 + }, + { + "res": "VAL", + "atm1": " CA ", + "atm2": " C ", + "x0": 1.52326, + "K": 301.675 + }, + { + "res": "VAL", + "atm1": " CA ", + "atm2": " CB ", + "x0": 1.54025, + "K": 260.414 + }, + { + "res": "VAL", + "atm1": " CA ", + "atm2": " HA ", + "x0": 1.09065, + "K": 306.735 + }, + { + "res": "VAL", + "atm1": " CB ", + "atm2": " CG1", + "x0": 1.52142, + "K": 84.8615 + }, + { + "res": "VAL", + "atm1": " 
CB ", + "atm2": " CG2", + "x0": 1.52106, + "K": 84.8615 + }, + { + "res": "VAL", + "atm1": " CB ", + "atm2": " HB ", + "x0": 1.09011, + "K": 117.8526 + }, + { + "res": "VAL", + "atm1": " CG1", + "atm2": "1HG1", + "x0": 1.09035, + "K": 122.8108 + }, + { + "res": "VAL", + "atm1": " CG1", + "atm2": "2HG1", + "x0": 1.08982, + "K": 122.8108 + }, + { + "res": "VAL", + "atm1": " CG1", + "atm2": "3HG1", + "x0": 1.08981, + "K": 122.8108 + }, + { + "res": "VAL", + "atm1": " CG2", + "atm2": "1HG2", + "x0": 1.08961, + "K": 122.8108 + }, + { + "res": "VAL", + "atm1": " CG2", + "atm2": "2HG2", + "x0": 1.09018, + "K": 122.8108 + }, + { + "res": "VAL", + "atm1": " CG2", + "atm2": "3HG2", + "x0": 1.09017, + "K": 122.8108 + }, + { + "res": "VAL", + "atm1": " N ", + "atm2": " CA ", + "x0": 1.458, + "K": 361.504 + }, + { + "res": "VAL", + "atm1": " N ", + "atm2": " H ", + "x0": 1.01, + "K": 458.304 + } + ], + "angles": [ + { + "res": "CYS", + "atm1": " CB ", + "atm2": " SG ", + "atm3": " SG ", + "x0": 0, + "K": 0 + }, + { + "res": "ALA", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.91013, + "K": 60.7618 + }, + { + "res": "ALA", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": "3HB ", + "x0": 1.91013, + "K": 60.7618 + }, + { + "res": "ALA", + "atm1": "2HB ", + "atm2": " CB ", + "atm3": "3HB ", + "x0": 1.91013, + "K": 60.7618 + }, + { + "res": "ALA", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.92575, + "K": 103.9584 + }, + { + "res": "ALA", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.88223, + "K": 85.58 + }, + { + "res": "ALA", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "ALA", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "ALA", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "3HB ", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "ALA", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.90057, + 
"K": 59.906 + }, + { + "res": "ALA", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.9264, + "K": 144.466 + }, + { + "res": "ALA", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.8867, + "K": 82.1568 + }, + { + "res": "ARG", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.87575, + "K": 60.7618 + }, + { + "res": "ARG", + "atm1": "1HD ", + "atm2": " CD ", + "atm3": "2HD ", + "x0": 1.91013, + "K": 60.7618 + }, + { + "res": "ARG", + "atm1": "1HG ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 1.91013, + "K": 60.7618 + }, + { + "res": "ARG", + "atm1": "1HH1", + "atm2": " NH1", + "atm3": "2HH1", + "x0": 2.0944, + "K": 42.79 + }, + { + "res": "ARG", + "atm1": "1HH2", + "atm2": " NH2", + "atm3": "2HH2", + "x0": 2.0944, + "K": 42.79 + }, + { + "res": "ARG", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.90853, + "K": 103.9584 + }, + { + "res": "ARG", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89435, + "K": 85.58 + }, + { + "res": "ARG", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "ARG", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "ARG", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG ", + "x0": 2.00495, + "K": 121.74144 + }, + { + "res": "ARG", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89517, + "K": 59.906 + }, + { + "res": "ARG", + "atm1": " CB ", + "atm2": " CG ", + "atm3": "1HG ", + "x0": 1.91114, + "K": 45.3574 + }, + { + "res": "ARG", + "atm1": " CB ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 1.91114, + "K": 45.3574 + }, + { + "res": "ARG", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD ", + "x0": 1.94953, + "K": 108.09921 + }, + { + "res": "ARG", + "atm1": " CD ", + "atm2": " CG ", + "atm3": "1HG ", + "x0": 1.89072, + "K": 45.3574 + }, + { + "res": "ARG", + "atm1": " CD ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 1.89072, + "K": 
45.3574 + }, + { + "res": "ARG", + "atm1": " CD ", + "atm2": " NE ", + "atm3": " CZ ", + "x0": 2.17468, + "K": 115.41698 + }, + { + "res": "ARG", + "atm1": " CD ", + "atm2": " NE ", + "atm3": " HE ", + "x0": 2.05425, + "K": 69.14864 + }, + { + "res": "ARG", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.87822, + "K": 45.3574 + }, + { + "res": "ARG", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.87822, + "K": 45.3574 + }, + { + "res": "ARG", + "atm1": " CG ", + "atm2": " CD ", + "atm3": "1HD ", + "x0": 1.91114, + "K": 45.3574 + }, + { + "res": "ARG", + "atm1": " CG ", + "atm2": " CD ", + "atm3": "2HD ", + "x0": 1.91114, + "K": 45.3574 + }, + { + "res": "ARG", + "atm1": " CG ", + "atm2": " CD ", + "atm3": " NE ", + "x0": 1.95302, + "K": 125.42102 + }, + { + "res": "ARG", + "atm1": " CZ ", + "atm2": " NE ", + "atm3": " HE ", + "x0": 2.05425, + "K": 83.8684 + }, + { + "res": "ARG", + "atm1": " CZ ", + "atm2": " NH1", + "atm3": "1HH1", + "x0": 2.0944, + "K": 83.8684 + }, + { + "res": "ARG", + "atm1": " CZ ", + "atm2": " NH1", + "atm3": "2HH1", + "x0": 2.0944, + "K": 83.8684 + }, + { + "res": "ARG", + "atm1": " CZ ", + "atm2": " NH2", + "atm3": "1HH2", + "x0": 2.0944, + "K": 83.8684 + }, + { + "res": "ARG", + "atm1": " CZ ", + "atm2": " NH2", + "atm3": "2HH2", + "x0": 2.0944, + "K": 83.8684 + }, + { + "res": "ARG", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.93033, + "K": 144.466 + }, + { + "res": "ARG", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89368, + "K": 82.1568 + }, + { + "res": "ARG", + "atm1": " NE ", + "atm2": " CD ", + "atm3": "1HD ", + "x0": 1.88894, + "K": 88.1474 + }, + { + "res": "ARG", + "atm1": " NE ", + "atm2": " CD ", + "atm3": "2HD ", + "x0": 1.88894, + "K": 88.1474 + }, + { + "res": "ARG", + "atm1": " NE ", + "atm2": " CZ ", + "atm3": " NH1", + "x0": 2.0944, + "K": 96.3352 + }, + { + "res": "ARG", + "atm1": " NE ", + "atm2": " CZ ", + "atm3": " NH2", + "x0": 2.0944, + "K": 
96.3352 + }, + { + "res": "ARG", + "atm1": " NH1", + "atm2": " CZ ", + "atm3": " NH2", + "x0": 2.0944, + "K": 96.3352 + }, + { + "res": "ASN", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.87575, + "K": 60.7618 + }, + { + "res": "ASN", + "atm1": "1HD2", + "atm2": " ND2", + "atm3": "2HD2", + "x0": 2.0944, + "K": 39.3668 + }, + { + "res": "ASN", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.9304, + "K": 103.9584 + }, + { + "res": "ASN", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.87155, + "K": 85.58 + }, + { + "res": "ASN", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "ASN", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "ASN", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG ", + "x0": 1.96524, + "K": 108.4928 + }, + { + "res": "ASN", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89517, + "K": 59.906 + }, + { + "res": "ASN", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " ND2", + "x0": 2.03331, + "K": 92.63 + }, + { + "res": "ASN", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " OD1", + "x0": 2.10836, + "K": 27.789 + }, + { + "res": "ASN", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.8992, + "K": 56.4828 + }, + { + "res": "ASN", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.8992, + "K": 56.4828 + }, + { + "res": "ASN", + "atm1": " CG ", + "atm2": " ND2", + "atm3": "1HD2", + "x0": 2.0944, + "K": 85.58 + }, + { + "res": "ASN", + "atm1": " CG ", + "atm2": " ND2", + "atm3": "2HD2", + "x0": 2.0944, + "K": 85.58 + }, + { + "res": "ASN", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.93033, + "K": 144.466 + }, + { + "res": "ASN", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89368, + "K": 82.1568 + }, + { + "res": "ASN", + "atm1": " OD1", + "atm2": " CG ", + "atm3": " ND2", + "x0": 2.14152, + "K": 138.945 + }, + { 
+ "res": "ASP", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.87575, + "K": 60.7618 + }, + { + "res": "ASP", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.91479, + "K": 103.9584 + }, + { + "res": "ASP", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.88867, + "K": 85.58 + }, + { + "res": "ASP", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "ASP", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "ASP", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG ", + "x0": 1.97048, + "K": 108.4928 + }, + { + "res": "ASP", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89586, + "K": 59.906 + }, + { + "res": "ASP", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " OD1", + "x0": 2.06637, + "K": 74.104 + }, + { + "res": "ASP", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " OD2", + "x0": 2.06591, + "K": 74.104 + }, + { + "res": "ASP", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.89646, + "K": 56.4828 + }, + { + "res": "ASP", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.89646, + "K": 56.4828 + }, + { + "res": "ASP", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.92892, + "K": 144.466 + }, + { + "res": "ASP", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89368, + "K": 82.1568 + }, + { + "res": "ASP", + "atm1": " OD1", + "atm2": " CG ", + "atm3": " OD2", + "x0": 2.1509, + "K": 185.26 + }, + { + "res": "CYS", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.87575, + "K": 60.7618 + }, + { + "res": "CYS", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.90716, + "K": 103.9584 + }, + { + "res": "CYS", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89577, + "K": 85.58 + }, + { + "res": "CYS", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "CYS", + 
"atm1": " CA ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "CYS", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " SG ", + "x0": 1.99142, + "K": 121.0112 + }, + { + "res": "CYS", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89517, + "K": 59.906 + }, + { + "res": "CYS", + "atm1": " CB ", + "atm2": " SG ", + "atm3": " HG ", + "x0": 1.67531, + "K": 66.41008 + }, + { + "res": "CYS", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.93033, + "K": 144.466 + }, + { + "res": "CYS", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89368, + "K": 82.1568 + }, + { + "res": "CYS", + "atm1": " SG ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.88541, + "K": 78.90476 + }, + { + "res": "CYS", + "atm1": " SG ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.88541, + "K": 78.90476 + }, + { + "res": "GLN", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.87575, + "K": 60.7618 + }, + { + "res": "GLN", + "atm1": "1HE2", + "atm2": " NE2", + "atm3": "2HE2", + "x0": 2.09341, + "K": 39.3668 + }, + { + "res": "GLN", + "atm1": "1HG ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 1.9205, + "K": 60.7618 + }, + { + "res": "GLN", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.91915, + "K": 103.9584 + }, + { + "res": "GLN", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.8844, + "K": 85.58 + }, + { + "res": "GLN", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "GLN", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "GLN", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG ", + "x0": 1.99298, + "K": 121.74144 + }, + { + "res": "GLN", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89608, + "K": 59.906 + }, + { + "res": "GLN", + "atm1": " CB ", + "atm2": " CG ", + "atm3": "1HG ", + "x0": 1.90066, + "K": 45.3574 + }, + { + "res": "GLN", + "atm1": " 
CB ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 1.90066, + "K": 45.3574 + }, + { + "res": "GLN", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD ", + "x0": 1.96227, + "K": 96.3352 + }, + { + "res": "GLN", + "atm1": " CD ", + "atm2": " CG ", + "atm3": "1HG ", + "x0": 1.88985, + "K": 56.4828 + }, + { + "res": "GLN", + "atm1": " CD ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 1.88985, + "K": 56.4828 + }, + { + "res": "GLN", + "atm1": " CD ", + "atm2": " NE2", + "atm3": "1HE2", + "x0": 2.09481, + "K": 85.58 + }, + { + "res": "GLN", + "atm1": " CD ", + "atm2": " NE2", + "atm3": "2HE2", + "x0": 2.09497, + "K": 85.58 + }, + { + "res": "GLN", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.88458, + "K": 45.3574 + }, + { + "res": "GLN", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.88458, + "K": 45.3574 + }, + { + "res": "GLN", + "atm1": " CG ", + "atm2": " CD ", + "atm3": " NE2", + "x0": 2.03179, + "K": 92.63 + }, + { + "res": "GLN", + "atm1": " CG ", + "atm2": " CD ", + "atm3": " OE1", + "x0": 2.11099, + "K": 27.789 + }, + { + "res": "GLN", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.92846, + "K": 144.466 + }, + { + "res": "GLN", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89368, + "K": 82.1568 + }, + { + "res": "GLN", + "atm1": " OE1", + "atm2": " CD ", + "atm3": " NE2", + "x0": 2.1404, + "K": 138.945 + }, + { + "res": "GLU", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.87575, + "K": 60.7618 + }, + { + "res": "GLU", + "atm1": "1HG ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 1.88428, + "K": 60.7618 + }, + { + "res": "GLU", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.91729, + "K": 103.9584 + }, + { + "res": "GLU", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.88725, + "K": 85.58 + }, + { + "res": "GLU", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "GLU", + "atm1": " CA ", + "atm2": 
" CB ", + "atm3": "2HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "GLU", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG ", + "x0": 1.9966, + "K": 121.74144 + }, + { + "res": "GLU", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89686, + "K": 59.906 + }, + { + "res": "GLU", + "atm1": " CB ", + "atm2": " CG ", + "atm3": "1HG ", + "x0": 1.90066, + "K": 45.3574 + }, + { + "res": "GLU", + "atm1": " CB ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 1.90066, + "K": 45.3574 + }, + { + "res": "GLU", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD ", + "x0": 1.97048, + "K": 96.3352 + }, + { + "res": "GLU", + "atm1": " CD ", + "atm2": " CG ", + "atm3": "1HG ", + "x0": 1.89951, + "K": 56.4828 + }, + { + "res": "GLU", + "atm1": " CD ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 1.90638, + "K": 56.4828 + }, + { + "res": "GLU", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.88266, + "K": 45.3574 + }, + { + "res": "GLU", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.88266, + "K": 45.3574 + }, + { + "res": "GLU", + "atm1": " CG ", + "atm2": " CD ", + "atm3": " OE1", + "x0": 2.06728, + "K": 74.104 + }, + { + "res": "GLU", + "atm1": " CG ", + "atm2": " CD ", + "atm3": " OE2", + "x0": 2.06581, + "K": 74.104 + }, + { + "res": "GLU", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.92684, + "K": 144.466 + }, + { + "res": "GLU", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89368, + "K": 82.1568 + }, + { + "res": "GLU", + "atm1": " OE1", + "atm2": " CD ", + "atm3": " OE2", + "x0": 2.15009, + "K": 185.26 + }, + { + "res": "GLY", + "atm1": "1HA ", + "atm2": " CA ", + "atm3": "2HA ", + "x0": 1.86998, + "K": 72 + }, + { + "res": "GLY", + "atm1": " C ", + "atm2": " CA ", + "atm3": "1HA ", + "x0": 1.91471, + "K": 100 + }, + { + "res": "GLY", + "atm1": " C ", + "atm2": " CA ", + "atm3": "2HA ", + "x0": 1.91471, + "K": 100 + }, + { + "res": "GLY", + "atm1": " N ", + "atm2": " CA ", + "atm3": "1HA 
", + "x0": 1.91114, + "K": 96 + }, + { + "res": "GLY", + "atm1": " N ", + "atm2": " CA ", + "atm3": "2HA ", + "x0": 1.91114, + "K": 96 + }, + { + "res": "HIS", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.87575, + "K": 60.7618 + }, + { + "res": "HIS", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.91454, + "K": 103.9584 + }, + { + "res": "HIS", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.87829, + "K": 85.58 + }, + { + "res": "HIS", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "HIS", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "HIS", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG ", + "x0": 1.98427, + "K": 121.74144 + }, + { + "res": "HIS", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.88557, + "K": 78.39128 + }, + { + "res": "HIS", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "x0": 2.29001, + "K": 84.84908 + }, + { + "res": "HIS", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " ND1", + "x0": 2.1403, + "K": 240.838 + }, + { + "res": "HIS", + "atm1": " CD2", + "atm2": " NE2", + "atm3": " CE1", + "x0": 1.9025, + "K": 55.578 + }, + { + "res": "HIS", + "atm1": " CD2", + "atm2": " NE2", + "atm3": "2HE ", + "x0": 2.18506, + "K": 42.79 + }, + { + "res": "HIS", + "atm1": " CE1", + "atm2": " NE2", + "atm3": "2HE ", + "x0": 2.19562, + "K": 51.348 + }, + { + "res": "HIS", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.88919, + "K": 57.218788 + }, + { + "res": "HIS", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.88919, + "K": 57.218788 + }, + { + "res": "HIS", + "atm1": " CG ", + "atm2": " CD2", + "atm3": "2HD ", + "x0": 2.20671, + "K": 42.79 + }, + { + "res": "HIS", + "atm1": " CG ", + "atm2": " CD2", + "atm3": " NE2", + "x0": 1.87038, + "K": 240.838 + }, + { + "res": "HIS", + "atm1": " CG ", + "atm2": " ND1", + "atm3": " CE1", + "x0": 1.90739, 
+ "K": 240.838 + }, + { + "res": "HIS", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.93189, + "K": 144.466 + }, + { + "res": "HIS", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.91114, + "K": 82.1568 + }, + { + "res": "HIS", + "atm1": " ND1", + "atm2": " CE1", + "atm3": "1HE ", + "x0": 2.19562, + "K": 42.79 + }, + { + "res": "HIS", + "atm1": " ND1", + "atm2": " CE1", + "atm3": " NE2", + "x0": 1.89164, + "K": 240.838 + }, + { + "res": "HIS", + "atm1": " ND1", + "atm2": " CG ", + "atm3": " CD2", + "x0": 1.85288, + "K": 240.838 + }, + { + "res": "HIS", + "atm1": " NE2", + "atm2": " CD2", + "atm3": "2HD ", + "x0": 2.2061, + "K": 42.79 + }, + { + "res": "HIS", + "atm1": " NE2", + "atm2": " CE1", + "atm3": "1HE ", + "x0": 2.19593, + "K": 42.79 + }, + { + "res": "HIS_D", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.87575, + "K": 60.7618 + }, + { + "res": "HIS_D", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.91454, + "K": 103.9584 + }, + { + "res": "HIS_D", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.8783, + "K": 85.58 + }, + { + "res": "HIS_D", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "HIS_D", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "HIS_D", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG ", + "x0": 1.98427, + "K": 121.74144 + }, + { + "res": "HIS_D", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.88557, + "K": 78.39128 + }, + { + "res": "HIS_D", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "x0": 2.29002, + "K": 84.84908 + }, + { + "res": "HIS_D", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " ND1", + "x0": 2.14029, + "K": 240.838 + }, + { + "res": "HIS_D", + "atm1": " CD2", + "atm2": " NE2", + "atm3": " CE1", + "x0": 1.90251, + "K": 55.578 + }, + { + "res": "HIS_D", + "atm1": " CE1", + "atm2": " ND1", + "atm3": "1HD ", + "x0": 
2.16847, + "K": 42.79 + }, + { + "res": "HIS_D", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.88919, + "K": 57.218788 + }, + { + "res": "HIS_D", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.88919, + "K": 57.218788 + }, + { + "res": "HIS_D", + "atm1": " CG ", + "atm2": " CD2", + "atm3": "2HD ", + "x0": 2.31752, + "K": 42.79 + }, + { + "res": "HIS_D", + "atm1": " CG ", + "atm2": " CD2", + "atm3": " NE2", + "x0": 1.87037, + "K": 240.838 + }, + { + "res": "HIS_D", + "atm1": " CG ", + "atm2": " ND1", + "atm3": " CE1", + "x0": 1.90738, + "K": 240.838 + }, + { + "res": "HIS_D", + "atm1": " CG ", + "atm2": " ND1", + "atm3": "1HD ", + "x0": 2.20336, + "K": 42.79 + }, + { + "res": "HIS_D", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.93189, + "K": 144.466 + }, + { + "res": "HIS_D", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.91114, + "K": 82.1568 + }, + { + "res": "HIS_D", + "atm1": " ND1", + "atm2": " CE1", + "atm3": "1HE ", + "x0": 2.10501, + "K": 42.79 + }, + { + "res": "HIS_D", + "atm1": " ND1", + "atm2": " CE1", + "atm3": " NE2", + "x0": 1.89163, + "K": 240.838 + }, + { + "res": "HIS_D", + "atm1": " ND1", + "atm2": " CG ", + "atm3": " CD2", + "x0": 1.85288, + "K": 240.838 + }, + { + "res": "HIS_D", + "atm1": " NE2", + "atm2": " CD2", + "atm3": "2HD ", + "x0": 2.09529, + "K": 42.79 + }, + { + "res": "HIS_D", + "atm1": " NE2", + "atm2": " CE1", + "atm3": "1HE ", + "x0": 2.28654, + "K": 42.79 + }, + { + "res": "ILE", + "atm1": "1HD1", + "atm2": " CD1", + "atm3": "2HD1", + "x0": 1.90998, + "K": 60.7618 + }, + { + "res": "ILE", + "atm1": "1HD1", + "atm2": " CD1", + "atm3": "3HD1", + "x0": 1.90964, + "K": 60.7618 + }, + { + "res": "ILE", + "atm1": "1HG1", + "atm2": " CG1", + "atm3": "2HG1", + "x0": 1.8743, + "K": 60.7618 + }, + { + "res": "ILE", + "atm1": "1HG2", + "atm2": " CG2", + "atm3": "2HG2", + "x0": 1.91024, + "K": 60.7618 + }, + { + "res": "ILE", + "atm1": "1HG2", + "atm2": " CG2", + "atm3": 
"3HG2", + "x0": 1.91002, + "K": 60.7618 + }, + { + "res": "ILE", + "atm1": "2HD1", + "atm2": " CD1", + "atm3": "3HD1", + "x0": 1.91077, + "K": 60.7618 + }, + { + "res": "ILE", + "atm1": "2HG2", + "atm2": " CG2", + "atm3": "3HG2", + "x0": 1.91013, + "K": 60.7618 + }, + { + "res": "ILE", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.91116, + "K": 103.9584 + }, + { + "res": "ILE", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.88867, + "K": 85.58 + }, + { + "res": "ILE", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG1", + "x0": 1.92703, + "K": 111.30944 + }, + { + "res": "ILE", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG2", + "x0": 1.92804, + "K": 111.30944 + }, + { + "res": "ILE", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " HB ", + "x0": 1.91114, + "K": 69 + }, + { + "res": "ILE", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89264, + "K": 59.906 + }, + { + "res": "ILE", + "atm1": " CB ", + "atm2": " CG1", + "atm3": "1HG1", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "ILE", + "atm1": " CB ", + "atm2": " CG1", + "atm3": "2HG1", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "ILE", + "atm1": " CB ", + "atm2": " CG1", + "atm3": " CD1", + "x0": 1.98662, + "K": 108.09921 + }, + { + "res": "ILE", + "atm1": " CB ", + "atm2": " CG2", + "atm3": "1HG2", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "ILE", + "atm1": " CB ", + "atm2": " CG2", + "atm3": "2HG2", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "ILE", + "atm1": " CB ", + "atm2": " CG2", + "atm3": "3HG2", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "ILE", + "atm1": " CD1", + "atm2": " CG1", + "atm3": "1HG1", + "x0": 1.8895, + "K": 59.22136 + }, + { + "res": "ILE", + "atm1": " CD1", + "atm2": " CG1", + "atm3": "2HG1", + "x0": 1.88775, + "K": 59.22136 + }, + { + "res": "ILE", + "atm1": " CG1", + "atm2": " CB ", + "atm3": " CG2", + "x0": 1.9306, + "K": 98.83621 + }, + { + "res": "ILE", + "atm1": " CG1", + "atm2": " CB ", + 
"atm3": " HB ", + "x0": 1.8641, + "K": 59.0502 + }, + { + "res": "ILE", + "atm1": " CG1", + "atm2": " CD1", + "atm3": "1HD1", + "x0": 1.91114, + "K": 59.22136 + }, + { + "res": "ILE", + "atm1": " CG1", + "atm2": " CD1", + "atm3": "2HD1", + "x0": 1.91114, + "K": 59.22136 + }, + { + "res": "ILE", + "atm1": " CG1", + "atm2": " CD1", + "atm3": "3HD1", + "x0": 1.91114, + "K": 59.22136 + }, + { + "res": "ILE", + "atm1": " CG2", + "atm2": " CB ", + "atm3": " HB ", + "x0": 1.90163, + "K": 59.0502 + }, + { + "res": "ILE", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.93557, + "K": 144.466 + }, + { + "res": "ILE", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89368, + "K": 82.1568 + }, + { + "res": "LEU", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.89065, + "K": 60.7618 + }, + { + "res": "LEU", + "atm1": "1HD1", + "atm2": " CD1", + "atm3": "2HD1", + "x0": 1.90928, + "K": 60.7618 + }, + { + "res": "LEU", + "atm1": "1HD1", + "atm2": " CD1", + "atm3": "3HD1", + "x0": 1.91003, + "K": 60.7618 + }, + { + "res": "LEU", + "atm1": "1HD2", + "atm2": " CD2", + "atm3": "2HD2", + "x0": 1.91013, + "K": 60.7618 + }, + { + "res": "LEU", + "atm1": "1HD2", + "atm2": " CD2", + "atm3": "3HD2", + "x0": 1.91013, + "K": 60.7618 + }, + { + "res": "LEU", + "atm1": "2HD1", + "atm2": " CD1", + "atm3": "3HD1", + "x0": 1.91109, + "K": 60.7618 + }, + { + "res": "LEU", + "atm1": "2HD2", + "atm2": " CD2", + "atm3": "3HD2", + "x0": 1.91013, + "K": 60.7618 + }, + { + "res": "LEU", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.91492, + "K": 103.9584 + }, + { + "res": "LEU", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89151, + "K": 85.58 + }, + { + "res": "LEU", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.89543, + "K": 66.86 + }, + { + "res": "LEU", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.89543, + "K": 66.86 + }, + { + "res": "LEU", + "atm1": " CA ", + "atm2": " CB ", + "atm3": 
" CG ", + "x0": 2.01935, + "K": 121.74144 + }, + { + "res": "LEU", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89842, + "K": 59.906 + }, + { + "res": "LEU", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD1", + "x0": 1.91114, + "K": 98.83621 + }, + { + "res": "LEU", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "x0": 1.91114, + "K": 98.83621 + }, + { + "res": "LEU", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " HG ", + "x0": 1.91114, + "K": 59.0502 + }, + { + "res": "LEU", + "atm1": " CD1", + "atm2": " CG ", + "atm3": " CD2", + "x0": 1.93845, + "K": 98.83621 + }, + { + "res": "LEU", + "atm1": " CD1", + "atm2": " CG ", + "atm3": " HG ", + "x0": 1.88153, + "K": 59.0502 + }, + { + "res": "LEU", + "atm1": " CD2", + "atm2": " CG ", + "atm3": " HG ", + "x0": 1.91013, + "K": 59.0502 + }, + { + "res": "LEU", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.87977, + "K": 57.218788 + }, + { + "res": "LEU", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.87977, + "K": 57.218788 + }, + { + "res": "LEU", + "atm1": " CG ", + "atm2": " CD1", + "atm3": "1HD1", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "LEU", + "atm1": " CG ", + "atm2": " CD1", + "atm3": "2HD1", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "LEU", + "atm1": " CG ", + "atm2": " CD1", + "atm3": "3HD1", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "LEU", + "atm1": " CG ", + "atm2": " CD2", + "atm3": "1HD2", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "LEU", + "atm1": " CG ", + "atm2": " CD2", + "atm3": "2HD2", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "LEU", + "atm1": " CG ", + "atm2": " CD2", + "atm3": "3HD2", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "LEU", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.92361, + "K": 144.466 + }, + { + "res": "LEU", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89368, + "K": 82.1568 + }, + { + "res": "LYS", + "atm1": "1HB ", + "atm2": " 
CB ", + "atm3": "2HB ", + "x0": 1.87575, + "K": 60.7618 + }, + { + "res": "LYS", + "atm1": "1HD ", + "atm2": " CD ", + "atm3": "2HD ", + "x0": 1.89356, + "K": 60.7618 + }, + { + "res": "LYS", + "atm1": "1HE ", + "atm2": " CE ", + "atm3": "2HE ", + "x0": 1.88918, + "K": 60.7618 + }, + { + "res": "LYS", + "atm1": "1HG ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 1.89349, + "K": 60.7618 + }, + { + "res": "LYS", + "atm1": "1HZ ", + "atm2": " NZ ", + "atm3": "2HZ ", + "x0": 1.91021, + "K": 75.3104 + }, + { + "res": "LYS", + "atm1": "1HZ ", + "atm2": " NZ ", + "atm3": "3HZ ", + "x0": 1.91002, + "K": 75.3104 + }, + { + "res": "LYS", + "atm1": "2HZ ", + "atm2": " NZ ", + "atm3": "3HZ ", + "x0": 1.91016, + "K": 75.3104 + }, + { + "res": "LYS", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.92181, + "K": 103.9584 + }, + { + "res": "LYS", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.88155, + "K": 85.58 + }, + { + "res": "LYS", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "LYS", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "LYS", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG ", + "x0": 1.99661, + "K": 121.74144 + }, + { + "res": "LYS", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89601, + "K": 59.906 + }, + { + "res": "LYS", + "atm1": " CB ", + "atm2": " CG ", + "atm3": "1HG ", + "x0": 1.91114, + "K": 45.3574 + }, + { + "res": "LYS", + "atm1": " CB ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 1.91114, + "K": 45.3574 + }, + { + "res": "LYS", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD ", + "x0": 1.94255, + "K": 108.09921 + }, + { + "res": "LYS", + "atm1": " CD ", + "atm2": " CE ", + "atm3": "1HE ", + "x0": 1.91114, + "K": 45.3574 + }, + { + "res": "LYS", + "atm1": " CD ", + "atm2": " CE ", + "atm3": "2HE ", + "x0": 1.91114, + "K": 45.3574 + }, + { + "res": "LYS", + "atm1": " CD ", + "atm2": " CE ", 
+ "atm3": " NZ ", + "x0": 1.95413, + "K": 125.42102 + }, + { + "res": "LYS", + "atm1": " CD ", + "atm2": " CG ", + "atm3": "1HG ", + "x0": 1.90229, + "K": 45.3574 + }, + { + "res": "LYS", + "atm1": " CD ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 1.90254, + "K": 45.3574 + }, + { + "res": "LYS", + "atm1": " CE ", + "atm2": " CD ", + "atm3": "1HD ", + "x0": 1.90247, + "K": 45.3574 + }, + { + "res": "LYS", + "atm1": " CE ", + "atm2": " CD ", + "atm3": "2HD ", + "x0": 1.90109, + "K": 45.3574 + }, + { + "res": "LYS", + "atm1": " CE ", + "atm2": " NZ ", + "atm3": "1HZ ", + "x0": 1.91114, + "K": 51.348 + }, + { + "res": "LYS", + "atm1": " CE ", + "atm2": " NZ ", + "atm3": "2HZ ", + "x0": 1.91114, + "K": 51.348 + }, + { + "res": "LYS", + "atm1": " CE ", + "atm2": " NZ ", + "atm3": "3HZ ", + "x0": 1.91114, + "K": 51.348 + }, + { + "res": "LYS", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.88265, + "K": 45.3574 + }, + { + "res": "LYS", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.88265, + "K": 45.3574 + }, + { + "res": "LYS", + "atm1": " CG ", + "atm2": " CD ", + "atm3": "1HD ", + "x0": 1.91114, + "K": 45.3574 + }, + { + "res": "LYS", + "atm1": " CG ", + "atm2": " CD ", + "atm3": "2HD ", + "x0": 1.91114, + "K": 45.3574 + }, + { + "res": "LYS", + "atm1": " CG ", + "atm2": " CD ", + "atm3": " CE ", + "x0": 1.94373, + "K": 108.09921 + }, + { + "res": "LYS", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.92861, + "K": 144.466 + }, + { + "res": "LYS", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89368, + "K": 82.1568 + }, + { + "res": "LYS", + "atm1": " NZ ", + "atm2": " CE ", + "atm3": "1HE ", + "x0": 1.89906, + "K": 77.022 + }, + { + "res": "LYS", + "atm1": " NZ ", + "atm2": " CE ", + "atm3": "2HE ", + "x0": 1.89802, + "K": 77.022 + }, + { + "res": "MET", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.87575, + "K": 60.7618 + }, + { + "res": "MET", + "atm1": "1HE ", + "atm2": " CE ", + 
"atm3": "2HE ", + "x0": 1.91081, + "K": 60.7618 + }, + { + "res": "MET", + "atm1": "1HE ", + "atm2": " CE ", + "atm3": "3HE ", + "x0": 1.90982, + "K": 60.7618 + }, + { + "res": "MET", + "atm1": "1HG ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 1.88852, + "K": 60.7618 + }, + { + "res": "MET", + "atm1": "2HE ", + "atm2": " CE ", + "atm3": "3HE ", + "x0": 1.90976, + "K": 60.7618 + }, + { + "res": "MET", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.92732, + "K": 103.9584 + }, + { + "res": "MET", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.8787, + "K": 85.58 + }, + { + "res": "MET", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "MET", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "MET", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG ", + "x0": 1.99733, + "K": 121.74144 + }, + { + "res": "MET", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89845, + "K": 59.906 + }, + { + "res": "MET", + "atm1": " CB ", + "atm2": " CG ", + "atm3": "1HG ", + "x0": 1.91114, + "K": 45.3574 + }, + { + "res": "MET", + "atm1": " CB ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 1.91114, + "K": 45.3574 + }, + { + "res": "MET", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " SD ", + "x0": 1.96649, + "K": 107.4508 + }, + { + "res": "MET", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.88227, + "K": 45.3574 + }, + { + "res": "MET", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.88227, + "K": 45.3574 + }, + { + "res": "MET", + "atm1": " CG ", + "atm2": " SD ", + "atm3": " CE ", + "x0": 1.76091, + "K": 62.9884 + }, + { + "res": "MET", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.92354, + "K": 144.466 + }, + { + "res": "MET", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89368, + "K": 82.1568 + }, + { + "res": "MET", + "atm1": " SD ", + "atm2": " CE ", + "atm3": 
"1HE ", + "x0": 1.91114, + "K": 78.90476 + }, + { + "res": "MET", + "atm1": " SD ", + "atm2": " CE ", + "atm3": "2HE ", + "x0": 1.91114, + "K": 78.90476 + }, + { + "res": "MET", + "atm1": " SD ", + "atm2": " CE ", + "atm3": "3HE ", + "x0": 1.91114, + "K": 78.90476 + }, + { + "res": "MET", + "atm1": " SD ", + "atm2": " CG ", + "atm3": "1HG ", + "x0": 1.89281, + "K": 78.90476 + }, + { + "res": "MET", + "atm1": " SD ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 1.89208, + "K": 78.90476 + }, + { + "res": "PHE", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.87575, + "K": 60.7618 + }, + { + "res": "PHE", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.91412, + "K": 103.9584 + }, + { + "res": "PHE", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89009, + "K": 85.58 + }, + { + "res": "PHE", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "PHE", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "PHE", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG ", + "x0": 1.98604, + "K": 108.07552 + }, + { + "res": "PHE", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89648, + "K": 59.906 + }, + { + "res": "PHE", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD1", + "x0": 2.10636, + "K": 84.84908 + }, + { + "res": "PHE", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "x0": 2.107, + "K": 84.84908 + }, + { + "res": "PHE", + "atm1": " CD1", + "atm2": " CE1", + "atm3": " CZ ", + "x0": 2.09493, + "K": 74.104 + }, + { + "res": "PHE", + "atm1": " CD1", + "atm2": " CE1", + "atm3": "1HE ", + "x0": 2.09383, + "K": 51.348 + }, + { + "res": "PHE", + "atm1": " CD1", + "atm2": " CG ", + "atm3": " CD2", + "x0": 2.06983, + "K": 74.104 + }, + { + "res": "PHE", + "atm1": " CD2", + "atm2": " CE2", + "atm3": " CZ ", + "x0": 2.09386, + "K": 74.104 + }, + { + "res": "PHE", + "atm1": " CD2", + "atm2": " CE2", + "atm3": "2HE 
", + "x0": 2.0956, + "K": 51.348 + }, + { + "res": "PHE", + "atm1": " CE1", + "atm2": " CD1", + "atm3": "1HD ", + "x0": 2.09495, + "K": 51.348 + }, + { + "res": "PHE", + "atm1": " CE1", + "atm2": " CZ ", + "atm3": " CE2", + "x0": 2.09214, + "K": 74.104 + }, + { + "res": "PHE", + "atm1": " CE1", + "atm2": " CZ ", + "atm3": " HZ ", + "x0": 2.09582, + "K": 51.348 + }, + { + "res": "PHE", + "atm1": " CE2", + "atm2": " CD2", + "atm3": "2HD ", + "x0": 2.09398, + "K": 51.348 + }, + { + "res": "PHE", + "atm1": " CE2", + "atm2": " CZ ", + "atm3": " HZ ", + "x0": 2.09522, + "K": 51.348 + }, + { + "res": "PHE", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.88826, + "K": 84.38188 + }, + { + "res": "PHE", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.88826, + "K": 84.38188 + }, + { + "res": "PHE", + "atm1": " CG ", + "atm2": " CD1", + "atm3": " CE1", + "x0": 2.10722, + "K": 74.104 + }, + { + "res": "PHE", + "atm1": " CG ", + "atm2": " CD1", + "atm3": "1HD ", + "x0": 2.08102, + "K": 51.348 + }, + { + "res": "PHE", + "atm1": " CG ", + "atm2": " CD2", + "atm3": " CE2", + "x0": 2.10839, + "K": 74.104 + }, + { + "res": "PHE", + "atm1": " CG ", + "atm2": " CD2", + "atm3": "2HD ", + "x0": 2.08081, + "K": 51.348 + }, + { + "res": "PHE", + "atm1": " CZ ", + "atm2": " CE1", + "atm3": "1HE ", + "x0": 2.09443, + "K": 51.348 + }, + { + "res": "PHE", + "atm1": " CZ ", + "atm2": " CE2", + "atm3": "2HE ", + "x0": 2.09372, + "K": 51.348 + }, + { + "res": "PHE", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.92762, + "K": 144.466 + }, + { + "res": "PHE", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89368, + "K": 82.1568 + }, + { + "res": "PRO", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.86724, + "K": 60.7618 + }, + { + "res": "PRO", + "atm1": "1HD ", + "atm2": " CD ", + "atm3": "2HD ", + "x0": 1.93141, + "K": 60.7618 + }, + { + "res": "PRO", + "atm1": "1HG ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 
1.8638, + "K": 60.7618 + }, + { + "res": "PRO", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.94845, + "K": 103.9584 + }, + { + "res": "PRO", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.93002, + "K": 85.58 + }, + { + "res": "PRO", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.91986, + "K": 66.86 + }, + { + "res": "PRO", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.91986, + "K": 66.86 + }, + { + "res": "PRO", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG ", + "x0": 1.81863, + "K": 146.048 + }, + { + "res": "PRO", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.93694, + "K": 59.906 + }, + { + "res": "PRO", + "atm1": " CB ", + "atm2": " CG ", + "atm3": "1HG ", + "x0": 1.92335, + "K": 45.3574 + }, + { + "res": "PRO", + "atm1": " CB ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 1.92335, + "K": 45.3574 + }, + { + "res": "PRO", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD ", + "x0": 1.82212, + "K": 129.682 + }, + { + "res": "PRO", + "atm1": " CD ", + "atm2": " CG ", + "atm3": "1HG ", + "x0": 1.96602, + "K": 45.3574 + }, + { + "res": "PRO", + "atm1": " CD ", + "atm2": " CG ", + "atm3": "2HG ", + "x0": 1.96602, + "K": 45.3574 + }, + { + "res": "PRO", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.96927, + "K": 45.3574 + }, + { + "res": "PRO", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.96927, + "K": 45.3574 + }, + { + "res": "PRO", + "atm1": " CG ", + "atm2": " CD ", + "atm3": "1HD ", + "x0": 1.91114, + "K": 45.3574 + }, + { + "res": "PRO", + "atm1": " CG ", + "atm2": " CD ", + "atm3": "2HD ", + "x0": 1.91114, + "K": 45.3574 + }, + { + "res": "PRO", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.79769, + "K": 144.466 + }, + { + "res": "PRO", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.9059, + "K": 82.1568 + }, + { + "res": "PRO", + "atm1": " N ", + "atm2": " CD ", + "atm3": "1HD ", + "x0": 1.9024, + "K": 
60.7618 + }, + { + "res": "PRO", + "atm1": " N ", + "atm2": " CD ", + "atm3": "2HD ", + "x0": 1.9024, + "K": 60.7618 + }, + { + "res": "PRO", + "atm1": " N ", + "atm2": " CD ", + "atm3": " CG ", + "x0": 1.8012, + "K": 125.184 + }, + { + "res": "SER", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.87575, + "K": 71 + }, + { + "res": "SER", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.91703, + "K": 103.9584 + }, + { + "res": "SER", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89009, + "K": 85.58 + }, + { + "res": "SER", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "SER", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "SER", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " OG ", + "x0": 1.93732, + "K": 157.94048 + }, + { + "res": "SER", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89908, + "K": 59.906 + }, + { + "res": "SER", + "atm1": " CB ", + "atm2": " OG ", + "atm3": " HG ", + "x0": 1.85005, + "K": 115 + }, + { + "res": "SER", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.92225, + "K": 144.466 + }, + { + "res": "SER", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89368, + "K": 82.1568 + }, + { + "res": "SER", + "atm1": " OG ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.91373, + "K": 91.8 + }, + { + "res": "SER", + "atm1": " OG ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.91373, + "K": 91.8 + }, + { + "res": "THR", + "atm1": "1HG2", + "atm2": " CG2", + "atm3": "2HG2", + "x0": 1.91059, + "K": 60.7618 + }, + { + "res": "THR", + "atm1": "1HG2", + "atm2": " CG2", + "atm3": "3HG2", + "x0": 1.90984, + "K": 60.7618 + }, + { + "res": "THR", + "atm1": "2HG2", + "atm2": " CG2", + "atm3": "3HG2", + "x0": 1.90996, + "K": 60.7618 + }, + { + "res": "THR", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.91334, + "K": 103.9584 + }, + { + "res": 
"THR", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.8844, + "K": 85.58 + }, + { + "res": "THR", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG2", + "x0": 1.92913, + "K": 111.30944 + }, + { + "res": "THR", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " HB ", + "x0": 1.91114, + "K": 69 + }, + { + "res": "THR", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " OG1", + "x0": 1.91255, + "K": 157.94048 + }, + { + "res": "THR", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89094, + "K": 59.906 + }, + { + "res": "THR", + "atm1": " CB ", + "atm2": " CG2", + "atm3": "1HG2", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "THR", + "atm1": " CB ", + "atm2": " CG2", + "atm3": "2HG2", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "THR", + "atm1": " CB ", + "atm2": " CG2", + "atm3": "3HG2", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "THR", + "atm1": " CB ", + "atm2": " OG1", + "atm3": " HG1", + "x0": 1.90986, + "K": 98.417 + }, + { + "res": "THR", + "atm1": " CG2", + "atm2": " CB ", + "atm3": " HB ", + "x0": 1.90521, + "K": 59.0502 + }, + { + "res": "THR", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.93906, + "K": 144.466 + }, + { + "res": "THR", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89368, + "K": 82.1568 + }, + { + "res": "THR", + "atm1": " OG1", + "atm2": " CB ", + "atm3": " CG2", + "x0": 1.90802, + "K": 140.24182 + }, + { + "res": "THR", + "atm1": " OG1", + "atm2": " CB ", + "atm3": " HB ", + "x0": 1.89749, + "K": 78.56244 + }, + { + "res": "TRP", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.87575, + "K": 60.7618 + }, + { + "res": "TRP", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.91317, + "K": 103.9584 + }, + { + "res": "TRP", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89151, + "K": 85.58 + }, + { + "res": "TRP", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "TRP", 
+ "atm1": " CA ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "TRP", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG ", + "x0": 1.98154, + "K": 121.74144 + }, + { + "res": "TRP", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89686, + "K": 59.906 + }, + { + "res": "TRP", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD1", + "x0": 2.21197, + "K": 84.84908 + }, + { + "res": "TRP", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "x0": 2.21981, + "K": 84.84908 + }, + { + "res": "TRP", + "atm1": " CD1", + "atm2": " CG ", + "atm3": " CD2", + "x0": 1.8514, + "K": 222.312 + }, + { + "res": "TRP", + "atm1": " CD1", + "atm2": " NE1", + "atm3": " CE2", + "x0": 1.90066, + "K": 203.786 + }, + { + "res": "TRP", + "atm1": " CD1", + "atm2": " NE1", + "atm3": "1HE ", + "x0": 2.19273, + "K": 47.9248 + }, + { + "res": "TRP", + "atm1": " CD2", + "atm2": " CE2", + "atm3": " CZ2", + "x0": 2.12856, + "K": 111.156 + }, + { + "res": "TRP", + "atm1": " CD2", + "atm2": " CE2", + "atm3": " NE1", + "x0": 1.88329, + "K": 188.276 + }, + { + "res": "TRP", + "atm1": " CD2", + "atm2": " CE3", + "atm3": " CZ3", + "x0": 2.07225, + "K": 111.156 + }, + { + "res": "TRP", + "atm1": " CD2", + "atm2": " CE3", + "atm3": " HE3", + "x0": 2.10432, + "K": 51.348 + }, + { + "res": "TRP", + "atm1": " CE2", + "atm2": " CD2", + "atm3": " CE3", + "x0": 2.08238, + "K": 111.156 + }, + { + "res": "TRP", + "atm1": " CE2", + "atm2": " CZ2", + "atm3": " CH2", + "x0": 2.05076, + "K": 111.156 + }, + { + "res": "TRP", + "atm1": " CE2", + "atm2": " CZ2", + "atm3": " HZ2", + "x0": 2.13105, + "K": 51.348 + }, + { + "res": "TRP", + "atm1": " CE2", + "atm2": " NE1", + "atm3": "1HE ", + "x0": 2.18979, + "K": 47.9248 + }, + { + "res": "TRP", + "atm1": " CE3", + "atm2": " CZ3", + "atm3": " CH2", + "x0": 2.11185, + "K": 74.104 + }, + { + "res": "TRP", + "atm1": " CE3", + "atm2": " CZ3", + "atm3": " HZ3", + "x0": 2.06422, + "K": 51.348 + }, + { + "res": "TRP", + 
"atm1": " CG ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.89063, + "K": 57.218788 + }, + { + "res": "TRP", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.89063, + "K": 57.218788 + }, + { + "res": "TRP", + "atm1": " CG ", + "atm2": " CD1", + "atm3": "1HD ", + "x0": 2.17992, + "K": 54.7712 + }, + { + "res": "TRP", + "atm1": " CG ", + "atm2": " CD1", + "atm3": " NE1", + "x0": 1.92259, + "K": 222.312 + }, + { + "res": "TRP", + "atm1": " CG ", + "atm2": " CD2", + "atm3": " CE2", + "x0": 1.86683, + "K": 203.786 + }, + { + "res": "TRP", + "atm1": " CG ", + "atm2": " CD2", + "atm3": " CE3", + "x0": 2.33398, + "K": 296.416 + }, + { + "res": "TRP", + "atm1": " CH2", + "atm2": " CZ2", + "atm3": " HZ2", + "x0": 2.10137, + "K": 51.348 + }, + { + "res": "TRP", + "atm1": " CH2", + "atm2": " CZ3", + "atm3": " HZ3", + "x0": 2.10711, + "K": 51.348 + }, + { + "res": "TRP", + "atm1": " CZ2", + "atm2": " CH2", + "atm3": " CZ3", + "x0": 2.12058, + "K": 74.104 + }, + { + "res": "TRP", + "atm1": " CZ2", + "atm2": " CH2", + "atm3": " HH2", + "x0": 2.08408, + "K": 51.348 + }, + { + "res": "TRP", + "atm1": " CZ3", + "atm2": " CE3", + "atm3": " HE3", + "x0": 2.10661, + "K": 51.348 + }, + { + "res": "TRP", + "atm1": " CZ3", + "atm2": " CH2", + "atm3": " HH2", + "x0": 2.07853, + "K": 51.348 + }, + { + "res": "TRP", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.92684, + "K": 144.466 + }, + { + "res": "TRP", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89368, + "K": 82.1568 + }, + { + "res": "TRP", + "atm1": " NE1", + "atm2": " CD1", + "atm3": "1HD ", + "x0": 2.18068, + "K": 54.7712 + }, + { + "res": "TRP", + "atm1": " NE1", + "atm2": " CE2", + "atm3": " CZ2", + "x0": 2.27134, + "K": 296.416 + }, + { + "res": "TYR", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.87575, + "K": 60.7618 + }, + { + "res": "TYR", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.92479, + "K": 103.9584 + }, + { + "res": "TYR", + 
"atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.8787, + "K": 85.58 + }, + { + "res": "TYR", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "TYR", + "atm1": " CA ", + "atm2": " CB ", + "atm3": "2HB ", + "x0": 1.91114, + "K": 66.86 + }, + { + "res": "TYR", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG ", + "x0": 1.98618, + "K": 108.07552 + }, + { + "res": "TYR", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89623, + "K": 59.906 + }, + { + "res": "TYR", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD1", + "x0": 2.0944, + "K": 84.84908 + }, + { + "res": "TYR", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "x0": 2.0944, + "K": 84.84908 + }, + { + "res": "TYR", + "atm1": " CD1", + "atm2": " CE1", + "atm3": " CZ ", + "x0": 2.08979, + "K": 74.104 + }, + { + "res": "TYR", + "atm1": " CD1", + "atm2": " CE1", + "atm3": "1HE ", + "x0": 2.0944, + "K": 51.348 + }, + { + "res": "TYR", + "atm1": " CD1", + "atm2": " CG ", + "atm3": " CD2", + "x0": 2.0944, + "K": 74.104 + }, + { + "res": "TYR", + "atm1": " CD2", + "atm2": " CE2", + "atm3": " CZ ", + "x0": 2.0944, + "K": 74.104 + }, + { + "res": "TYR", + "atm1": " CD2", + "atm2": " CE2", + "atm3": "2HE ", + "x0": 2.0944, + "K": 51.348 + }, + { + "res": "TYR", + "atm1": " CE1", + "atm2": " CD1", + "atm3": "1HD ", + "x0": 2.0944, + "K": 51.348 + }, + { + "res": "TYR", + "atm1": " CE1", + "atm2": " CZ ", + "atm3": " CE2", + "x0": 2.099, + "K": 74.104 + }, + { + "res": "TYR", + "atm1": " CE1", + "atm2": " CZ ", + "atm3": " OH ", + "x0": 2.08979, + "K": 83.73752 + }, + { + "res": "TYR", + "atm1": " CE2", + "atm2": " CD2", + "atm3": "2HD ", + "x0": 2.09439, + "K": 51.348 + }, + { + "res": "TYR", + "atm1": " CE2", + "atm2": " CZ ", + "atm3": " OH ", + "x0": 2.0944, + "K": 83.73752 + }, + { + "res": "TYR", + "atm1": " CG ", + "atm2": " CB ", + "atm3": "1HB ", + "x0": 1.88818, + "K": 84.38188 + }, + { + "res": "TYR", + "atm1": " CG ", + 
"atm2": " CB ", + "atm3": "2HB ", + "x0": 1.88818, + "K": 84.38188 + }, + { + "res": "TYR", + "atm1": " CG ", + "atm2": " CD1", + "atm3": " CE1", + "x0": 2.0944, + "K": 74.104 + }, + { + "res": "TYR", + "atm1": " CG ", + "atm2": " CD1", + "atm3": "1HD ", + "x0": 2.0944, + "K": 51.348 + }, + { + "res": "TYR", + "atm1": " CG ", + "atm2": " CD2", + "atm3": " CE2", + "x0": 2.0944, + "K": 74.104 + }, + { + "res": "TYR", + "atm1": " CG ", + "atm2": " CD2", + "atm3": "2HD ", + "x0": 2.0944, + "K": 51.348 + }, + { + "res": "TYR", + "atm1": " CZ ", + "atm2": " CE1", + "atm3": "1HE ", + "x0": 2.099, + "K": 51.348 + }, + { + "res": "TYR", + "atm1": " CZ ", + "atm2": " CE2", + "atm3": "2HE ", + "x0": 2.09439, + "K": 51.348 + }, + { + "res": "TYR", + "atm1": " CZ ", + "atm2": " OH ", + "atm3": " HH ", + "x0": 1.90939, + "K": 111.254 + }, + { + "res": "TYR", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.92815, + "K": 144.466 + }, + { + "res": "TYR", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89368, + "K": 82.1568 + }, + { + "res": "VAL", + "atm1": "1HG1", + "atm2": " CG1", + "atm3": "2HG1", + "x0": 1.90965, + "K": 60.7618 + }, + { + "res": "VAL", + "atm1": "1HG1", + "atm2": " CG1", + "atm3": "3HG1", + "x0": 1.90936, + "K": 60.7618 + }, + { + "res": "VAL", + "atm1": "1HG2", + "atm2": " CG2", + "atm3": "2HG2", + "x0": 1.91055, + "K": 60.7618 + }, + { + "res": "VAL", + "atm1": "1HG2", + "atm2": " CG2", + "atm3": "3HG2", + "x0": 1.90981, + "K": 60.7618 + }, + { + "res": "VAL", + "atm1": "2HG1", + "atm2": " CG1", + "atm3": "3HG1", + "x0": 1.91138, + "K": 60.7618 + }, + { + "res": "VAL", + "atm1": "2HG2", + "atm2": " CG2", + "atm3": "3HG2", + "x0": 1.91003, + "K": 60.7618 + }, + { + "res": "VAL", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.90861, + "K": 103.9584 + }, + { + "res": "VAL", + "atm1": " C ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89719, + "K": 85.58 + }, + { + "res": "VAL", + "atm1": " CA ", + "atm2": " CB ", + 
"atm3": " CG1", + "x0": 1.92842, + "K": 111.30944 + }, + { + "res": "VAL", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG2", + "x0": 1.91812, + "K": 111.30944 + }, + { + "res": "VAL", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " HB ", + "x0": 1.90167, + "K": 69 + }, + { + "res": "VAL", + "atm1": " CB ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.8977, + "K": 59.906 + }, + { + "res": "VAL", + "atm1": " CB ", + "atm2": " CG1", + "atm3": "1HG1", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "VAL", + "atm1": " CB ", + "atm2": " CG1", + "atm3": "2HG1", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "VAL", + "atm1": " CB ", + "atm2": " CG1", + "atm3": "3HG1", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "VAL", + "atm1": " CB ", + "atm2": " CG2", + "atm3": "1HG2", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "VAL", + "atm1": " CB ", + "atm2": " CG2", + "atm3": "2HG2", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "VAL", + "atm1": " CB ", + "atm2": " CG2", + "atm3": "3HG2", + "x0": 1.91114, + "K": 57.218788 + }, + { + "res": "VAL", + "atm1": " CG1", + "atm2": " CB ", + "atm3": " CG2", + "x0": 1.93332, + "K": 98.83621 + }, + { + "res": "VAL", + "atm1": " CG1", + "atm2": " CB ", + "atm3": " HB ", + "x0": 1.87003, + "K": 59.0502 + }, + { + "res": "VAL", + "atm1": " CG2", + "atm2": " CB ", + "atm3": " HB ", + "x0": 1.91141, + "K": 59.0502 + }, + { + "res": "VAL", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "x0": 1.9251, + "K": 144.466 + }, + { + "res": "VAL", + "atm1": " N ", + "atm2": " CA ", + "atm3": " HA ", + "x0": 1.89368, + "K": 82.1568 + } + ], + "torsions": [ + { + "res": "ALA", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "atm4": "1HB ", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ALA", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", + "atm4": "2HB ", + "x0": 3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "ALA", + "atm1": " C ", + "atm2": " CA ", + "atm3": " CB ", 
+ "atm4": "3HB ", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ALA", + "atm1": " HA ", + "atm2": " CA ", + "atm3": " CB ", + "atm4": "1HB ", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ALA", + "atm1": " HA ", + "atm2": " CA ", + "atm3": " CB ", + "atm4": "2HB ", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ALA", + "atm1": " HA ", + "atm2": " CA ", + "atm3": " CB ", + "atm4": "3HB ", + "x0": 3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "ALA", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "atm4": "1HB ", + "x0": -3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "ALA", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "atm4": "2HB ", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ALA", + "atm1": " N ", + "atm2": " CA ", + "atm3": " CB ", + "atm4": "3HB ", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ARG", + "atm1": " CD ", + "atm2": " NE ", + "atm3": " CZ ", + "atm4": " NH1", + "x0": 0, + "K": 43.352, + "period": 2 + }, + { + "res": "ARG", + "atm1": " CD ", + "atm2": " NE ", + "atm3": " CZ ", + "atm4": " NH2", + "x0": 3.141592654, + "K": 43.352, + "period": 2 + }, + { + "res": "ARG", + "atm1": " HE ", + "atm2": " NE ", + "atm3": " CZ ", + "atm4": " NH1", + "x0": 3.141592654, + "K": 34.552, + "period": 2 + }, + { + "res": "ARG", + "atm1": " HE ", + "atm2": " NE ", + "atm3": " CZ ", + "atm4": " NH2", + "x0": 0, + "K": 34.552, + "period": 2 + }, + { + "res": "ARG", + "atm1": " NE ", + "atm2": " CZ ", + "atm3": " NH1", + "atm4": "1HH1", + "x0": 0, + "K": 34.552, + "period": 2 + }, + { + "res": "ARG", + "atm1": " NE ", + "atm2": " CZ ", + "atm3": " NH1", + "atm4": "2HH1", + "x0": -3.141592654, + "K": 34.552, + "period": 2 + }, + { + "res": "ARG", + "atm1": " NE ", + "atm2": " CZ ", + "atm3": " NH2", + "atm4": "1HH2", + "x0": 0, + "K": 34.552, + "period": 2 + }, + { + "res": "ARG", + "atm1": " NE ", + "atm2": " CZ 
", + "atm3": " NH2", + "atm4": "2HH2", + "x0": 3.141592654, + "K": 34.552, + "period": 2 + }, + { + "res": "ARG", + "atm1": " NH1", + "atm2": " CZ ", + "atm3": " NH2", + "atm4": "1HH2", + "x0": -3.141592654, + "K": 34.552, + "period": 2 + }, + { + "res": "ARG", + "atm1": " NH1", + "atm2": " CZ ", + "atm3": " NH2", + "atm4": "2HH2", + "x0": 0, + "K": 34.552, + "period": 2 + }, + { + "res": "ARG", + "atm1": " NH2", + "atm2": " CZ ", + "atm3": " NH1", + "atm4": "1HH1", + "x0": 3.141592654, + "K": 34.552, + "period": 2 + }, + { + "res": "ARG", + "atm1": " NH2", + "atm2": " CZ ", + "atm3": " NH1", + "atm4": "2HH1", + "x0": 0, + "K": 34.552, + "period": 2 + }, + { + "res": "ASN", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " ND2", + "atm4": "1HD2", + "x0": -3.141592654, + "K": 34.552, + "period": 2 + }, + { + "res": "ASN", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " ND2", + "atm4": "2HD2", + "x0": 0, + "K": 34.552, + "period": 2 + }, + { + "res": "ASN", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " OD1", + "atm4": " ND2", + "x0": 0, + "K": 43.352, + "period": 2 + }, + { + "res": "ASN", + "atm1": " OD1", + "atm2": " CG ", + "atm3": " ND2", + "atm4": "1HD2", + "x0": 0, + "K": 34.552, + "period": 2 + }, + { + "res": "ASN", + "atm1": " OD1", + "atm2": " CG ", + "atm3": " ND2", + "atm4": "2HD2", + "x0": -3.141592654, + "K": 34.552, + "period": 2 + }, + { + "res": "ASP", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " OD1", + "atm4": " OD2", + "x0": 0, + "K": 43.352, + "period": 2 + }, + { + "res": "CYS", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": " SG ", + "atm4": " HG ", + "x0": 1.047197551, + "K": 1, + "period": 3 + }, + { + "res": "CYS", + "atm1": "2HB ", + "atm2": " CB ", + "atm3": " SG ", + "atm4": " HG ", + "x0": 1.047197551, + "K": 1, + "period": 3 + }, + { + "res": "CYS", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " SG ", + "atm4": " HG ", + "x0": 1.047197551, + "K": 1, + "period": 3 + }, + { + "res": "GLN", + "atm1": " CG ", + "atm2": " CD ", + "atm3": 
" NE2", + "atm4": "1HE2", + "x0": 3.141592654, + "K": 34.552, + "period": 2 + }, + { + "res": "GLN", + "atm1": " CG ", + "atm2": " CD ", + "atm3": " NE2", + "atm4": "2HE2", + "x0": 0, + "K": 34.552, + "period": 2 + }, + { + "res": "GLN", + "atm1": " CG ", + "atm2": " CD ", + "atm3": " OE1", + "atm4": " NE2", + "x0": 0, + "K": 43.352, + "period": 2 + }, + { + "res": "GLN", + "atm1": " OE1", + "atm2": " CD ", + "atm3": " NE2", + "atm4": "1HE2", + "x0": 0, + "K": 34.552, + "period": 2 + }, + { + "res": "GLN", + "atm1": " OE1", + "atm2": " CD ", + "atm3": " NE2", + "atm4": "2HE2", + "x0": -3.141592654, + "K": 34.552, + "period": 2 + }, + { + "res": "GLU", + "atm1": " CG ", + "atm2": " CD ", + "atm3": " OE1", + "atm4": " OE2", + "x0": 0, + "K": 43.352, + "period": 2 + }, + { + "res": "HIS", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "atm4": "2HD ", + "x0": 0, + "K": 34.552, + "period": 2 + }, + { + "res": "HIS", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "atm4": " NE2", + "x0": 3.141592654, + "K": 43.352, + "period": 2 + }, + { + "res": "HIS", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " ND1", + "atm4": " CE1", + "x0": 3.141592654, + "K": 43.352, + "period": 2 + }, + { + "res": "HIS", + "atm1": " CD2", + "atm2": " CG ", + "atm3": " ND1", + "atm4": " CE1", + "x0": 0, + "K": 43.352, + "period": 2 + }, + { + "res": "HIS", + "atm1": " CG ", + "atm2": " CD2", + "atm3": " NE2", + "atm4": " CE1", + "x0": 0, + "K": 43.352, + "period": 2 + }, + { + "res": "HIS", + "atm1": " CG ", + "atm2": " CD2", + "atm3": " NE2", + "atm4": "2HE ", + "x0": 3.141592654, + "K": 34.552, + "period": 2 + }, + { + "res": "HIS", + "atm1": " CG ", + "atm2": " ND1", + "atm3": " CE1", + "atm4": "1HE ", + "x0": -3.141592654, + "K": 34.552, + "period": 2 + }, + { + "res": "HIS", + "atm1": " CG ", + "atm2": " ND1", + "atm3": " CE1", + "atm4": " NE2", + "x0": 0, + "K": 43.352, + "period": 2 + }, + { + "res": "HIS", + "atm1": "2HD ", + "atm2": " CD2", + "atm3": " NE2", + "atm4": " 
CE1", + "x0": 3.141592654, + "K": 34.552, + "period": 2 + }, + { + "res": "HIS", + "atm1": "2HD ", + "atm2": " CD2", + "atm3": " NE2", + "atm4": "2HE ", + "x0": 0, + "K": 34.552, + "period": 2 + }, + { + "res": "HIS", + "atm1": "1HE ", + "atm2": " CE1", + "atm3": " NE2", + "atm4": " CD2", + "x0": 3.141592654, + "K": 34.552, + "period": 2 + }, + { + "res": "HIS", + "atm1": "1HE ", + "atm2": " CE1", + "atm3": " NE2", + "atm4": "2HE ", + "x0": 0, + "K": 34.552, + "period": 2 + }, + { + "res": "HIS", + "atm1": " ND1", + "atm2": " CE1", + "atm3": " NE2", + "atm4": " CD2", + "x0": 0, + "K": 43.352, + "period": 2 + }, + { + "res": "HIS", + "atm1": " ND1", + "atm2": " CE1", + "atm3": " NE2", + "atm4": "2HE ", + "x0": 3.141592654, + "K": 43.352, + "period": 2 + }, + { + "res": "HIS", + "atm1": " ND1", + "atm2": " CG ", + "atm3": " CD2", + "atm4": "2HD ", + "x0": -3.141592654, + "K": 34.552, + "period": 2 + }, + { + "res": "HIS", + "atm1": " ND1", + "atm2": " CG ", + "atm3": " CD2", + "atm4": " NE2", + "x0": 0, + "K": 43.352, + "period": 2 + }, + { + "res": "HIS_D", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "atm4": "2HD ", + "x0": 0, + "K": 34.552, + "period": 2 + }, + { + "res": "HIS_D", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "atm4": " NE2", + "x0": 3.141592654, + "K": 43.352, + "period": 2 + }, + { + "res": "HIS_D", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " ND1", + "atm4": " CE1", + "x0": 3.141592654, + "K": 43.352, + "period": 2 + }, + { + "res": "HIS_D", + "atm1": " CD2", + "atm2": " CG ", + "atm3": " ND1", + "atm4": " CE1", + "x0": 0, + "K": 43.352, + "period": 2 + }, + { + "res": "HIS_D", + "atm1": " CG ", + "atm2": " CD2", + "atm3": " NE2", + "atm4": " CE1", + "x0": 0, + "K": 43.352, + "period": 2 + }, + { + "res": "HIS_D", + "atm1": " CG ", + "atm2": " ND1", + "atm3": " CE1", + "atm4": "1HE ", + "x0": -3.141592654, + "K": 34.552, + "period": 2 + }, + { + "res": "HIS_D", + "atm1": " CG ", + "atm2": " ND1", + "atm3": " CE1", + 
"atm4": " NE2", + "x0": 0, + "K": 43.352, + "period": 2 + }, + { + "res": "HIS_D", + "atm1": "2HD ", + "atm2": " CD2", + "atm3": " NE2", + "atm4": " CE1", + "x0": 3.141592654, + "K": 43.352, + "period": 2 + }, + { + "res": "HIS_D", + "atm1": "1HE ", + "atm2": " CE1", + "atm3": " NE2", + "atm4": " CD2", + "x0": 3.141592654, + "K": 34.552, + "period": 2 + }, + { + "res": "HIS_D", + "atm1": " ND1", + "atm2": " CE1", + "atm3": " NE2", + "atm4": " CD2", + "x0": 0, + "K": 43.352, + "period": 2 + }, + { + "res": "HIS_D", + "atm1": " ND1", + "atm2": " CG ", + "atm3": " CD2", + "atm4": "2HD ", + "x0": -3.141592654, + "K": 34.552, + "period": 2 + }, + { + "res": "HIS_D", + "atm1": " ND1", + "atm2": " CG ", + "atm3": " CD2", + "atm4": " NE2", + "x0": 0, + "K": 43.352, + "period": 2 + }, + { + "res": "ILE", + "atm1": "1HG1", + "atm2": " CG1", + "atm3": " CD1", + "atm4": "1HD1", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ILE", + "atm1": "1HG1", + "atm2": " CG1", + "atm3": " CD1", + "atm4": "2HD1", + "x0": 3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "ILE", + "atm1": "1HG1", + "atm2": " CG1", + "atm3": " CD1", + "atm4": "3HD1", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ILE", + "atm1": "2HG1", + "atm2": " CG1", + "atm3": " CD1", + "atm4": "1HD1", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ILE", + "atm1": "2HG1", + "atm2": " CG1", + "atm3": " CD1", + "atm4": "2HD1", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ILE", + "atm1": "2HG1", + "atm2": " CG1", + "atm3": " CD1", + "atm4": "3HD1", + "x0": -3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "ILE", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "1HG2", + "x0": 3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "ILE", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "2HG2", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ILE", + "atm1": " 
CA ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "3HG2", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ILE", + "atm1": " CB ", + "atm2": " CG1", + "atm3": " CD1", + "atm4": "1HD1", + "x0": -3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "ILE", + "atm1": " CB ", + "atm2": " CG1", + "atm3": " CD1", + "atm4": "2HD1", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ILE", + "atm1": " CB ", + "atm2": " CG1", + "atm3": " CD1", + "atm4": "3HD1", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ILE", + "atm1": " CG1", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "1HG2", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ILE", + "atm1": " CG1", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "2HG2", + "x0": 3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "ILE", + "atm1": " CG1", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "3HG2", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ILE", + "atm1": " HB ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "1HG2", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ILE", + "atm1": " HB ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "2HG2", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ILE", + "atm1": " HB ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "3HG2", + "x0": -3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD1", + "atm4": "1HD1", + "x0": -3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD1", + "atm4": "2HD1", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD1", + "atm4": "3HD1", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "atm4": "1HD2", + "x0": 3.141592654, 
+ "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "atm4": "2HD2", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "atm4": "3HD2", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " CD1", + "atm2": " CG ", + "atm3": " CD2", + "atm4": "1HD2", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " CD1", + "atm2": " CG ", + "atm3": " CD2", + "atm4": "2HD2", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " CD1", + "atm2": " CG ", + "atm3": " CD2", + "atm4": "3HD2", + "x0": -3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " CD2", + "atm2": " CG ", + "atm3": " CD1", + "atm4": "1HD1", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " CD2", + "atm2": " CG ", + "atm3": " CD1", + "atm4": "2HD1", + "x0": 3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " CD2", + "atm2": " CG ", + "atm3": " CD1", + "atm4": "3HD1", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " HG ", + "atm2": " CG ", + "atm3": " CD1", + "atm4": "1HD1", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " HG ", + "atm2": " CG ", + "atm3": " CD1", + "atm4": "2HD1", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " HG ", + "atm2": " CG ", + "atm3": " CD1", + "atm4": "3HD1", + "x0": 3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " HG ", + "atm2": " CG ", + "atm3": " CD2", + "atm4": "1HD2", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " HG ", + "atm2": " CG ", + "atm3": " CD2", + "atm4": "2HD2", + "x0": -3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "LEU", + "atm1": " HG ", + "atm2": 
" CG ", + "atm3": " CD2", + "atm4": "3HD2", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "LYS", + "atm1": "1HE ", + "atm2": " CE ", + "atm3": " NZ ", + "atm4": "1HZ ", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "LYS", + "atm1": "1HE ", + "atm2": " CE ", + "atm3": " NZ ", + "atm4": "2HZ ", + "x0": 3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "LYS", + "atm1": "1HE ", + "atm2": " CE ", + "atm3": " NZ ", + "atm4": "3HZ ", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "LYS", + "atm1": "2HE ", + "atm2": " CE ", + "atm3": " NZ ", + "atm4": "1HZ ", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "LYS", + "atm1": "2HE ", + "atm2": " CE ", + "atm3": " NZ ", + "atm4": "2HZ ", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "LYS", + "atm1": "2HE ", + "atm2": " CE ", + "atm3": " NZ ", + "atm4": "3HZ ", + "x0": -3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "LYS", + "atm1": " CD ", + "atm2": " CE ", + "atm3": " NZ ", + "atm4": "1HZ ", + "x0": -3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "LYS", + "atm1": " CD ", + "atm2": " CE ", + "atm3": " NZ ", + "atm4": "2HZ ", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "LYS", + "atm1": " CD ", + "atm2": " CE ", + "atm3": " NZ ", + "atm4": "3HZ ", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "MET", + "atm1": " CG ", + "atm2": " SD ", + "atm3": " CE ", + "atm4": "1HE ", + "x0": 3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "MET", + "atm1": " CG ", + "atm2": " SD ", + "atm3": " CE ", + "atm4": "2HE ", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "MET", + "atm1": " CG ", + "atm2": " SD ", + "atm3": " CE ", + "atm4": "3HE ", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "PHE", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD1", + "atm4": " CE1", + "x0": -3.141592654, + "K": 43.352, 
+ "period": 1 + }, + { + "res": "PHE", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD1", + "atm4": "1HD ", + "x0": 0, + "K": 29.8460176, + "period": 1 + }, + { + "res": "PHE", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "atm4": " CE2", + "x0": -3.141592654, + "K": 43.352, + "period": 1 + }, + { + "res": "PHE", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "atm4": "2HD ", + "x0": 0, + "K": 29.8460176, + "period": 1 + }, + { + "res": "PHE", + "atm1": " CD1", + "atm2": " CE1", + "atm3": " CZ ", + "atm4": " CE2", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "PHE", + "atm1": " CD1", + "atm2": " CE1", + "atm3": " CZ ", + "atm4": " HZ ", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "PHE", + "atm1": " CD1", + "atm2": " CG ", + "atm3": " CD2", + "atm4": " CE2", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "PHE", + "atm1": " CD1", + "atm2": " CG ", + "atm3": " CD2", + "atm4": "2HD ", + "x0": -3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "PHE", + "atm1": " CD2", + "atm2": " CE2", + "atm3": " CZ ", + "atm4": " CE1", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "PHE", + "atm1": " CD2", + "atm2": " CE2", + "atm3": " CZ ", + "atm4": " HZ ", + "x0": -3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "PHE", + "atm1": " CD2", + "atm2": " CG ", + "atm3": " CD1", + "atm4": " CE1", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "PHE", + "atm1": " CD2", + "atm2": " CG ", + "atm3": " CD1", + "atm4": "1HD ", + "x0": -3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "PHE", + "atm1": " CG ", + "atm2": " CD1", + "atm3": " CE1", + "atm4": " CZ ", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "PHE", + "atm1": " CG ", + "atm2": " CD1", + "atm3": " CE1", + "atm4": "1HE ", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "PHE", + "atm1": " CG ", + "atm2": " CD2", + "atm3": " CE2", + "atm4": " CZ ", + "x0": 0, + "K": 43.352, + 
"period": 1 + }, + { + "res": "PHE", + "atm1": " CG ", + "atm2": " CD2", + "atm3": " CE2", + "atm4": "2HE ", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "PHE", + "atm1": "1HD ", + "atm2": " CD1", + "atm3": " CE1", + "atm4": " CZ ", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "PHE", + "atm1": "1HD ", + "atm2": " CD1", + "atm3": " CE1", + "atm4": "1HE ", + "x0": 0, + "K": 34.552, + "period": 1 + }, + { + "res": "PHE", + "atm1": "2HD ", + "atm2": " CD2", + "atm3": " CE2", + "atm4": " CZ ", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "PHE", + "atm1": "2HD ", + "atm2": " CD2", + "atm3": " CE2", + "atm4": "2HE ", + "x0": 0, + "K": 34.552, + "period": 1 + }, + { + "res": "PHE", + "atm1": "1HE ", + "atm2": " CE1", + "atm3": " CZ ", + "atm4": " CE2", + "x0": -3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "PHE", + "atm1": "1HE ", + "atm2": " CE1", + "atm3": " CZ ", + "atm4": " HZ ", + "x0": 0, + "K": 34.552, + "period": 1 + }, + { + "res": "PHE", + "atm1": "2HE ", + "atm2": " CE2", + "atm3": " CZ ", + "atm4": " CE1", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "PHE", + "atm1": "2HE ", + "atm2": " CE2", + "atm3": " CZ ", + "atm4": " HZ ", + "x0": 0, + "K": 34.552, + "period": 1 + }, + { + "res": "SER", + "atm1": "1HB ", + "atm2": " CB ", + "atm3": " OG ", + "atm4": " HG ", + "x0": 1.047197551, + "K": 1.2086, + "period": 3 + }, + { + "res": "SER", + "atm1": "2HB ", + "atm2": " CB ", + "atm3": " OG ", + "atm4": " HG ", + "x0": 1.047197551, + "K": 1.2086, + "period": 3 + }, + { + "res": "SER", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " OG ", + "atm4": " HG ", + "x0": 1.047197551, + "K": 1.2086, + "period": 3 + }, + { + "res": "THR", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "1HG2", + "x0": -3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "THR", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "2HG2", + "x0": 
-1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "THR", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "3HG2", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "THR", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " OG1", + "atm4": " HG1", + "x0": 1.047197551, + "K": 1.2086, + "period": 3 + }, + { + "res": "THR", + "atm1": " CG2", + "atm2": " CB ", + "atm3": " OG1", + "atm4": " HG1", + "x0": 1.047197551, + "K": 1.2086, + "period": 3 + }, + { + "res": "THR", + "atm1": " HB ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "1HG2", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "THR", + "atm1": " HB ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "2HG2", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "THR", + "atm1": " HB ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "3HG2", + "x0": -3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "THR", + "atm1": " HB ", + "atm2": " CB ", + "atm3": " OG1", + "atm4": " HG1", + "x0": 1.047197551, + "K": 1.2086, + "period": 3 + }, + { + "res": "THR", + "atm1": " OG1", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "1HG2", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "THR", + "atm1": " OG1", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "2HG2", + "x0": 3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "THR", + "atm1": " OG1", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "3HG2", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "TRP", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD1", + "atm4": "1HD ", + "x0": 0, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD1", + "atm4": " NE1", + "x0": -3.141592654, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "atm4": " CE2", + "x0": -3.141592654, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CB ", + 
"atm2": " CG ", + "atm3": " CD2", + "atm4": " CE3", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CD1", + "atm2": " CG ", + "atm3": " CD2", + "atm4": " CE2", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CD1", + "atm2": " CG ", + "atm3": " CD2", + "atm4": " CE3", + "x0": 3.141592654, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CD1", + "atm2": " NE1", + "atm3": " CE2", + "atm4": " CD2", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CD1", + "atm2": " NE1", + "atm3": " CE2", + "atm4": " CZ2", + "x0": 3.141592654, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CD2", + "atm2": " CE2", + "atm3": " CZ2", + "atm4": " CH2", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CD2", + "atm2": " CE2", + "atm3": " CZ2", + "atm4": " HZ2", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CD2", + "atm2": " CE3", + "atm3": " CZ3", + "atm4": " CH2", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CD2", + "atm2": " CE3", + "atm3": " CZ3", + "atm4": " HZ3", + "x0": -3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CD2", + "atm2": " CG ", + "atm3": " CD1", + "atm4": "1HD ", + "x0": -3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CD2", + "atm2": " CG ", + "atm3": " CD1", + "atm4": " NE1", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CE2", + "atm2": " CD2", + "atm3": " CE3", + "atm4": " CZ3", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CE2", + "atm2": " CD2", + "atm3": " CE3", + "atm4": " HE3", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CE2", + "atm2": " CZ2", + "atm3": " CH2", + "atm4": " CZ3", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CE2", + "atm2": " CZ2", + 
"atm3": " CH2", + "atm4": " HH2", + "x0": -3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CE3", + "atm2": " CD2", + "atm3": " CE2", + "atm4": " CZ2", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CE3", + "atm2": " CD2", + "atm3": " CE2", + "atm4": " NE1", + "x0": 3.141592654, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CE3", + "atm2": " CZ3", + "atm3": " CH2", + "atm4": " CZ2", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CE3", + "atm2": " CZ3", + "atm3": " CH2", + "atm4": " HH2", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CG ", + "atm2": " CD1", + "atm3": " NE1", + "atm4": " CE2", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CG ", + "atm2": " CD1", + "atm3": " NE1", + "atm4": "1HE ", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CG ", + "atm2": " CD2", + "atm3": " CE2", + "atm4": " CZ2", + "x0": -3.141592654, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CG ", + "atm2": " CD2", + "atm3": " CE2", + "atm4": " NE1", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CG ", + "atm2": " CD2", + "atm3": " CE3", + "atm4": " CZ3", + "x0": -3.141592654, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " CG ", + "atm2": " CD2", + "atm3": " CE3", + "atm4": " HE3", + "x0": 0, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": "1HD ", + "atm2": " CD1", + "atm3": " NE1", + "atm4": " CE2", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": "1HD ", + "atm2": " CD1", + "atm3": " NE1", + "atm4": "1HE ", + "x0": 0, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": "1HE ", + "atm2": " NE1", + "atm3": " CE2", + "atm4": " CD2", + "x0": -3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": "1HE ", + "atm2": " NE1", 
+ "atm3": " CE2", + "atm4": " CZ2", + "x0": 0, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": " HE3", + "atm2": " CE3", + "atm3": " CZ3", + "atm4": " CH2", + "x0": -3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": " HE3", + "atm2": " CE3", + "atm3": " CZ3", + "atm4": " HZ3", + "x0": 0, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": " HZ2", + "atm2": " CZ2", + "atm3": " CH2", + "atm4": " CZ3", + "x0": -3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": " HZ2", + "atm2": " CZ2", + "atm3": " CH2", + "atm4": " HH2", + "x0": 0, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": " HZ3", + "atm2": " CZ3", + "atm3": " CH2", + "atm4": " CZ2", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": " HZ3", + "atm2": " CZ3", + "atm3": " CH2", + "atm4": " HH2", + "x0": 0, + "K": 34.552, + "period": 1 + }, + { + "res": "TRP", + "atm1": " NE1", + "atm2": " CE2", + "atm3": " CZ2", + "atm4": " CH2", + "x0": -3.141592654, + "K": 43.352, + "period": 1 + }, + { + "res": "TRP", + "atm1": " NE1", + "atm2": " CE2", + "atm3": " CZ2", + "atm4": " HZ2", + "x0": 0, + "K": 34.552, + "period": 1 + }, + { + "res": "TYR", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD1", + "atm4": " CE1", + "x0": -3.141592654, + "K": 43.352, + "period": 1 + }, + { + "res": "TYR", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD1", + "atm4": "1HD ", + "x0": 0, + "K": 29.8460176, + "period": 1 + }, + { + "res": "TYR", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "atm4": " CE2", + "x0": -3.141592654, + "K": 43.352, + "period": 1 + }, + { + "res": "TYR", + "atm1": " CB ", + "atm2": " CG ", + "atm3": " CD2", + "atm4": "2HD ", + "x0": 0, + "K": 29.8460176, + "period": 1 + }, + { + "res": "TYR", + "atm1": " CD1", + "atm2": " CE1", + "atm3": " CZ ", + "atm4": " CE2", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TYR", + "atm1": " CD1", + "atm2": " CE1", + 
"atm3": " CZ ", + "atm4": " OH ", + "x0": -3.141592654, + "K": 43.352, + "period": 2 + }, + { + "res": "TYR", + "atm1": " CD1", + "atm2": " CG ", + "atm3": " CD2", + "atm4": " CE2", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TYR", + "atm1": " CD1", + "atm2": " CG ", + "atm3": " CD2", + "atm4": "2HD ", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TYR", + "atm1": " CD2", + "atm2": " CE2", + "atm3": " CZ ", + "atm4": " CE1", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TYR", + "atm1": " CD2", + "atm2": " CE2", + "atm3": " CZ ", + "atm4": " OH ", + "x0": 3.141592654, + "K": 43.352, + "period": 1 + }, + { + "res": "TYR", + "atm1": " CD2", + "atm2": " CG ", + "atm3": " CD1", + "atm4": " CE1", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TYR", + "atm1": " CD2", + "atm2": " CG ", + "atm3": " CD1", + "atm4": "1HD ", + "x0": -3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TYR", + "atm1": " CE1", + "atm2": " CZ ", + "atm3": " OH ", + "atm4": " HH ", + "x0": 0, + "K": 34.552, + "period": 2 + }, + { + "res": "TYR", + "atm1": " CE2", + "atm2": " CZ ", + "atm3": " OH ", + "atm4": " HH ", + "x0": 0, + "K": 34.552, + "period": 2 + }, + { + "res": "TYR", + "atm1": " CG ", + "atm2": " CD1", + "atm3": " CE1", + "atm4": " CZ ", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TYR", + "atm1": " CG ", + "atm2": " CD1", + "atm3": " CE1", + "atm4": "1HE ", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TYR", + "atm1": " CG ", + "atm2": " CD2", + "atm3": " CE2", + "atm4": " CZ ", + "x0": 0, + "K": 43.352, + "period": 1 + }, + { + "res": "TYR", + "atm1": " CG ", + "atm2": " CD2", + "atm3": " CE2", + "atm4": "2HE ", + "x0": -3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TYR", + "atm1": "1HD ", + "atm2": " CD1", + "atm3": " CE1", + "atm4": " CZ ", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TYR", + "atm1": "1HD ", + "atm2": " CD1", + "atm3": 
" CE1", + "atm4": "1HE ", + "x0": 0, + "K": 34.552, + "period": 1 + }, + { + "res": "TYR", + "atm1": "2HD ", + "atm2": " CD2", + "atm3": " CE2", + "atm4": " CZ ", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TYR", + "atm1": "2HD ", + "atm2": " CD2", + "atm3": " CE2", + "atm4": "2HE ", + "x0": 0, + "K": 34.552, + "period": 1 + }, + { + "res": "TYR", + "atm1": "1HE ", + "atm2": " CE1", + "atm3": " CZ ", + "atm4": " CE2", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TYR", + "atm1": "1HE ", + "atm2": " CE1", + "atm3": " CZ ", + "atm4": " OH ", + "x0": 0, + "K": 34.552, + "period": 2 + }, + { + "res": "TYR", + "atm1": "2HE ", + "atm2": " CE2", + "atm3": " CZ ", + "atm4": " CE1", + "x0": 3.141592654, + "K": 34.552, + "period": 1 + }, + { + "res": "TYR", + "atm1": "2HE ", + "atm2": " CE2", + "atm3": " CZ ", + "atm4": " OH ", + "x0": 0, + "K": 34.552, + "period": 2 + }, + { + "res": "VAL", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG1", + "atm4": "1HG1", + "x0": -3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "VAL", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG1", + "atm4": "2HG1", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "VAL", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG1", + "atm4": "3HG1", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "VAL", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "1HG2", + "x0": 3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "VAL", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "2HG2", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "VAL", + "atm1": " CA ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "3HG2", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "VAL", + "atm1": " CG1", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "1HG2", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "VAL", + "atm1": " CG1", + 
"atm2": " CB ", + "atm3": " CG2", + "atm4": "2HG2", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "VAL", + "atm1": " CG1", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "3HG2", + "x0": -3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "VAL", + "atm1": " CG2", + "atm2": " CB ", + "atm3": " CG1", + "atm4": "1HG1", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "VAL", + "atm1": " CG2", + "atm2": " CB ", + "atm3": " CG1", + "atm4": "2HG1", + "x0": 3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "VAL", + "atm1": " CG2", + "atm2": " CB ", + "atm3": " CG1", + "atm4": "3HG1", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "VAL", + "atm1": " HB ", + "atm2": " CB ", + "atm3": " CG1", + "atm4": "1HG1", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "VAL", + "atm1": " HB ", + "atm2": " CB ", + "atm3": " CG1", + "atm4": "2HG1", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "VAL", + "atm1": " HB ", + "atm2": " CB ", + "atm3": " CG1", + "atm4": "3HG1", + "x0": -3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "VAL", + "atm1": " HB ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "1HG2", + "x0": 1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "VAL", + "atm1": " HB ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "2HG2", + "x0": -3.141592654, + "K": 24.336, + "period": 3 + }, + { + "res": "VAL", + "atm1": " HB ", + "atm2": " CB ", + "atm3": " CG2", + "atm4": "3HG2", + "x0": -1.047197551, + "K": 24.336, + "period": 3 + }, + { + "res": "ALA", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.1445, + "K": 45.8703, + "period": 1 + }, + { + "res": "ALA", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1402, + "K": 47.5692, + "period": 1 + }, + { + "res": "ALA", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.06684, + "K": 39.904, + "period": 
1 + }, + { + "res": "ARG", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.1462, + "K": 27.7487, + "period": 1 + }, + { + "res": "ARG", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1411, + "K": 32.8454, + "period": 1 + }, + { + "res": "ARG", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.08116, + "K": 39.904, + "period": 1 + }, + { + "res": "ASN", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.1343, + "K": 23.2183, + "period": 1 + }, + { + "res": "ASN", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1497, + "K": 24.3509, + "period": 1 + }, + { + "res": "ASN", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.06748, + "K": 39.904, + "period": 1 + }, + { + "res": "ASP", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.1413, + "K": 23.2183, + "period": 1 + }, + { + "res": "ASP", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1454, + "K": 23.7846, + "period": 1 + }, + { + "res": "ASP", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.07852, + "K": 39.904, + "period": 1 + }, + { + "res": "CYS", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.1558, + "K": 24.9172, + "period": 1 + }, + { + "res": "CYS", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1332, + "K": 29.4476, + "period": 1 + }, + { + "res": "CYS", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.0829, + "K": 39.904, + "period": 1 + }, + { + "res": "GLN", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.1488, + "K": 27.7487, + "period": 1 + }, + { + "res": "GLN", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1398, + "K": 31.7128, + "period": 1 + }, + { + "res": 
"GLN", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.07464, + "K": 39.904, + "period": 1 + }, + { + "res": "GLU", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.1466, + "K": 28.8813, + "period": 1 + }, + { + "res": "GLU", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1398, + "K": 31.7128, + "period": 1 + }, + { + "res": "GLU", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.07654, + "K": 39.904, + "period": 1 + }, + { + "res": "HIS", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.1507, + "K": 23.7846, + "period": 1 + }, + { + "res": "HIS", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1339, + "K": 25.4835, + "period": 1 + }, + { + "res": "HIS", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.09335, + "K": 39.904, + "period": 1 + }, + { + "res": "HIS_D", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.1507, + "K": 23.7846, + "period": 1 + }, + { + "res": "HIS_D", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1339, + "K": 25.4835, + "period": 1 + }, + { + "res": "HIS_D", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.09335, + "K": 39.904, + "period": 1 + }, + { + "res": "ILE", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.1456, + "K": 23.7846, + "period": 1 + }, + { + "res": "ILE", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1473, + "K": 27.7487, + "period": 1 + }, + { + "res": "ILE", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.0785, + "K": 39.904, + "period": 1 + }, + { + "res": "LEU", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.1565, + "K": 28.8813, + "period": 1 + }, + { + "res": "LEU", + 
"atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1322, + "K": 32.2791, + "period": 1 + }, + { + "res": "LEU", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.07936, + "K": 39.904, + "period": 1 + }, + { + "res": "LYS", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.1479, + "K": 28.315, + "period": 1 + }, + { + "res": "LYS", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.14, + "K": 32.2791, + "period": 1 + }, + { + "res": "LYS", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.07474, + "K": 39.904, + "period": 1 + }, + { + "res": "MET", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.1497, + "K": 22.0857, + "period": 1 + }, + { + "res": "MET", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1398, + "K": 26.0498, + "period": 1 + }, + { + "res": "MET", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.07171, + "K": 39.904, + "period": 1 + }, + { + "res": "PHE", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.1502, + "K": 22.652, + "period": 1 + }, + { + "res": "PHE", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1378, + "K": 26.6161, + "period": 1 + }, + { + "res": "PHE", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.07838, + "K": 39.904, + "period": 1 + }, + { + "res": "PRO", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.2002, + "K": 41.3399, + "period": 1 + }, + { + "res": "PRO", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.0058, + "K": 48.1355, + "period": 1 + }, + { + "res": "SER", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.153, + "K": 23.7846, + "period": 1 + }, + { + "res": "SER", + "atm1": " N ", + "atm2": " 
C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1381, + "K": 25.4835, + "period": 1 + }, + { + "res": "SER", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.07875, + "K": 39.904, + "period": 1 + }, + { + "res": "THR", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.1555, + "K": 26.0498, + "period": 1 + }, + { + "res": "THR", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1439, + "K": 30.0139, + "period": 1 + }, + { + "res": "THR", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.0753, + "K": 39.904, + "period": 1 + }, + { + "res": "TRP", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.153, + "K": 23.7846, + "period": 1 + }, + { + "res": "TRP", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1358, + "K": 26.0498, + "period": 1 + }, + { + "res": "TRP", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.0801, + "K": 39.904, + "period": 1 + }, + { + "res": "TYR", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.1522, + "K": 24.9172, + "period": 1 + }, + { + "res": "TYR", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1358, + "K": 28.8813, + "period": 1 + }, + { + "res": "TYR", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.07132, + "K": 39.904, + "period": 1 + }, + { + "res": "VAL", + "atm1": " C ", + "atm2": " N ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 4.1464, + "K": 26.6161, + "period": 1 + }, + { + "res": "VAL", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " CB ", + "x0": 2.1466, + "K": 29.4476, + "period": 1 + }, + { + "res": "VAL", + "atm1": " N ", + "atm2": " C ", + "atm3": " CA ", + "atm4": " HA ", + "x0": -2.08278, + "K": 39.904, + "period": 1 + } + ] +} \ No newline at end of file diff --git a/RF2_allatom/chemical.py 
b/RF2_allatom/chemical.py new file mode 100644 index 0000000..bbd4af5 --- /dev/null +++ b/RF2_allatom/chemical.py @@ -0,0 +1,1076 @@ +import torch +import numpy as np +import rdkit + +num2aa=[ + 'ALA','ARG','ASN','ASP','CYS', + 'GLN','GLU','GLY','HIS','ILE', + 'LEU','LYS','MET','PHE','PRO', + 'SER','THR','TRP','TYR','VAL', + 'UNK','MAS', + ' DA',' DC',' DG',' DU', ' DX', + ' RA',' RC',' RG',' RU', ' RX', + 'HIS_D', # only used for cart_bonded + 'Al', 'As', 'Au', 'B', + 'Be', 'Br', 'C', 'Ca', 'Cl', + 'Co', 'Cr', 'Cu', 'F', 'Fe', + 'Hg', 'I', 'Ir', 'Li', 'Mg', + 'Mn', 'Mo', 'N', 'Ni', 'O', + 'Os', 'P', 'Pb', 'Pd', 'Pr', + 'Pt', 'Re', 'Rh', 'Ru', 'S', + 'Sb', 'Se', 'Si', 'Sn', 'Tb', + 'Te', 'U', 'W', 'Y', 'Zn', + 'ATM' +] + +aa2num= {x:i for i,x in enumerate(num2aa)} + +NAATOKENS = 20+2+10+1+45 # 20 AAs, UNK, MASK, 10 NAs, HIS_D, 45 atoms +MASKINDEX = 21 # protein mask + +NHEAVY = 23 +NTOTAL = 36 +NNAPROTAAS = 32 +NPROTAAS = 22 # include UNK/MAS + +# internal coords +NPROTTORS = 7 +NPROTANGS = 3 +NNATORS = 10 +NTOTALTORS = NPROTTORS+NNATORS +NTOTALDOFS = NTOTALTORS+NPROTANGS + +#bond types +num2btype = [0,1,2,3,4] # UNK, SINGLE, DOUBLE, TRIPLE, AROMATIC + +# reindexes the rdkit btypes +rdkit2btype = { + int(rdkit.Chem.rdchem.BondType.UNSPECIFIED): 0, + int(rdkit.Chem.rdchem.BondType.SINGLE): 1, + int(rdkit.Chem.rdchem.BondType.DOUBLE): 2, + int(rdkit.Chem.rdchem.BondType.TRIPLE): 3, + int(rdkit.Chem.rdchem.BondType.AROMATIC): 4 +} +NBTYPES = len(num2btype) +# full sc atom representation +aa2long=[ + (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","3HB ", None, None, None, None, None, None, None, None), #0 ala + (" N "," CA "," C "," O "," CB "," CG "," CD "," NE "," CZ "," NH1"," NH2", None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD "," HE ","1HH1","2HH1","1HH2","2HH2"), #1 arg
+ (" N "," CA "," C "," O "," CB "," CG "," OD1"," ND2", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HD2","2HD2", None, None, None, None, None, None, None), #2 asn + (" N "," CA "," C "," O "," CB "," CG "," OD1"," OD2", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ", None, None, None, None, None, None, None, None, None), #3 asp + (" N "," CA "," C "," O "," CB "," SG ", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB "," HG ", None, None, None, None, None, None, None, None), #4 cys + (" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," NE2", None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ","1HE2","2HE2", None, None, None, None, None), #5 gln + (" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," OE2", None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ", None, None, None, None, None, None, None), #6 glu + (" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H ","1HA ","2HA ", None, None, None, None, None, None, None, None, None, None), #7 gly + (" N "," CA "," C "," O "," CB "," CG "," ND1"," CD2"," CE1"," NE2", None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","2HD ","1HE ","2HE ", None, None, None, None, None, None), #8 his + (" N "," CA "," C "," O "," CB "," CG1"," CG2"," CD1", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA "," HB ","1HG2","2HG2","3HG2","1HG1","2HG1","1HD1","2HD1","3HD1", None, None), #9 ile + (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2", None, None, None, None, None, None, 
None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB "," HG ","1HD1","2HD1","3HD1","1HD2","2HD2","3HD2", None, None), #10 leu + (" N "," CA "," C "," O "," CB "," CG "," CD "," CE "," NZ ", None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD ","1HE ","2HE ","1HZ ","2HZ ","3HZ "), #11 lys + (" N "," CA "," C "," O "," CB "," CG "," SD "," CE ", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ","1HE ","2HE ","3HE ", None, None, None, None), #12 met + (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ ", None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HD ","2HD ","1HE ","2HE "," HZ ", None, None, None, None), #13 phe + (" N "," CA "," C "," O "," CB "," CG "," CD ", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD ", None, None, None, None, None, None), #14 pro + (" N "," CA "," C "," O "," CB "," OG ", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HG "," HA ","1HB ","2HB ", None, None, None, None, None, None, None, None), #15 ser + (" N "," CA "," C "," O "," CB "," OG1"," CG2", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HG1"," HA "," HB ","1HG2","2HG2","3HG2", None, None, None, None, None, None), #16 thr + (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," NE1"," CE2"," CE3"," CZ2"," CZ3"," CH2", None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HD ","1HE "," HZ2"," HH2"," HZ3"," HE3", None, None, None), #17 trp + (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ "," OH ", None, None, None, None, None, None, None, None, None, None, 
None," H "," HA ","1HB ","2HB ","1HD ","1HE ","2HE ","2HD "," HH ", None, None, None, None), #18 tyr + (" N "," CA "," C "," O "," CB "," CG1"," CG2", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA "," HB ","1HG1","2HG1","3HG1","1HG2","2HG2","3HG2", None, None, None, None), #19 val + (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","3HB ", None, None, None, None, None, None, None, None), #20 unk + (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","3HB ", None, None, None, None, None, None, None, None), #21 mask + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C2'"," C1'"," N9 "," C4 "," N3 "," C2 "," N1 "," C6 "," C5 "," N7 "," C8 "," N6 ", None, None,"H5''"," H5'"," H4'"," H3'","H2''"," H2'"," H1'"," H2 "," H61"," H62"," H8 ", None, None), #22 DA + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C2'"," C1'"," N1 "," C2 "," O2 "," N3 "," C4 "," N4 "," C5 "," C6 ", None, None, None, None,"H5''"," H5'"," H4'"," H3'","H2''"," H2'"," H1'"," H42"," H41"," H5 "," H6 ", None, None), #23 DC + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C2'"," C1'"," N9 "," C4 "," N3 "," C2 "," N1 "," C6 "," C5 "," N7 "," C8 "," N2 "," O6 ", None,"H5''"," H5'"," H4'"," H3'","H2''"," H2'"," H1'"," H1 "," H22"," H21"," H8 ", None, None), #24 DG + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C2'"," C1'"," N1 "," C2 "," O2 "," N3 "," C4 "," O4 "," C5 "," C7 "," C6 ", None, None, None,"H5''"," H5'"," H4'"," H3'","H2''"," H2'"," H1'"," H3 "," H71"," H72"," H73"," H6 ", None), #25 DT + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C2'"," C1'", None, None, None, None, None, None, None, None, None, None, None, 
None,"H5''"," H5'"," H4'"," H3'","H2''"," H2'"," H1'", None, None, None, None, None, None), #26 DX (unk DNA) + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C1'"," C2'"," O2'"," N1 "," C2 "," N3 "," C4 "," C5 "," C6 "," N6 "," N7 "," C8 "," N9 ", None," H5'","H5''"," H4'"," H3'"," H2'","HO2'"," H1'"," H2 "," H61"," H62"," H8 ", None, None), #27 A + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C1'"," C2'"," O2'"," N1 "," C2 "," O2 "," N3 "," C4 "," N4 "," C5 "," C6 ", None, None, None," H5'","H5''"," H4'"," H3'"," H2'","HO2'"," H1'"," H42"," H41"," H5 "," H6 ", None, None), #28 C + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C1'"," C2'"," O2'"," N1 "," C2 "," N2 "," N3 "," C4 "," C5 "," C6 "," O6 "," N7 "," C8 "," N9 "," H5'","H5''"," H4'"," H3'"," H2'","HO2'"," H1'"," H1 "," H22"," H21"," H8 ", None, None), #29 G + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C1'"," C2'"," O2'"," N1 "," C2 "," O2 "," N3 "," C4 "," O4 "," C5 "," C6 ", None, None, None," H5'","H5''"," H4'"," H3'"," H2'","HO2'"," H1'"," H3 "," H5 "," H6 ", None, None, None), #30 U + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C1'"," C2'"," O2'", None, None, None, None, None, None, None, None, None, None, None," H5'","H5''"," H4'"," H3'"," H2'","HO2'"," H1'", None, None, None, None, None, None), #31 RX (unk RNA) + (" N "," CA "," C "," O "," CB "," CG "," NE2"," CD2"," CE1"," ND1", None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","2HD ","1HE ","1HD ", None, None, None, None, None, None), #-1 his_d +] + + +# build the "alternate" sc mapping +aa2longalt=[ + (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","3HB ", None, None, None, None, None, None, None, None), # ala + (" N "," CA "," C "," O "," CB "," CG "," CD "," NE "," CZ "," 
NH1"," NH2", None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD "," HE ","1HH1","2HH1","1HH2","2HH2"), # arg + (" N "," CA "," C "," O "," CB "," CG "," OD1"," ND2", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HD2","2HD2", None, None, None, None, None, None, None), # asn + (" N "," CA "," C "," O "," CB "," CG "," OD2"," OD1", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ", None, None, None, None, None, None, None, None, None), # asp + (" N "," CA "," C "," O "," CB "," SG ", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB "," HG ", None, None, None, None, None, None, None, None), # cys + (" N "," CA "," C "," O "," CB "," CG "," CD "," OE1"," NE2", None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ","1HE2","2HE2", None, None, None, None, None), # gln + (" N "," CA "," C "," O "," CB "," CG "," CD "," OE2"," OE1", None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ", None, None, None, None, None, None, None), # glu + (" N "," CA "," C "," O ", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H ","1HA ","2HA ", None, None, None, None, None, None, None, None, None, None), # gly + (" N "," CA "," C "," O "," CB "," CG "," NE2"," CD2"," CE1"," ND1", None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","2HD ","1HE ","2HE ", None, None, None, None, None, None), # his + (" N "," CA "," C "," O "," CB "," CG1"," CG2"," CD1", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA "," HB 
","1HG2","2HG2","3HG2","1HG1","2HG1","1HD1","2HD1","3HD1", None, None), # ile + (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB "," HG ","1HD1","2HD1","3HD1","1HD2","2HD2","3HD2", None, None), # leu + (" N "," CA "," C "," O "," CB "," CG "," CD "," CE "," NZ ", None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD ","1HE ","2HE ","1HZ ","2HZ ","3HZ "), # lys + (" N "," CA "," C "," O "," CB "," CG "," SD "," CE ", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HG ","2HG ","1HE ","2HE ","3HE ", None, None, None, None), # met + (" N "," CA "," C "," O "," CB "," CG "," CD2"," CD1"," CE2"," CE1"," CZ ", None, None, None, None, None, None, None, None, None, None, None, None," H ","2HD ","2HE "," HZ ","1HE ","1HD "," HA ","1HB ","2HB ", None, None, None, None), # phe + (" N "," CA "," C "," O "," CB "," CG "," CD ", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD ", None, None, None, None, None, None), # pro + (" N "," CA "," C "," O "," CB "," OG ", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HG "," HA ","1HB ","2HB ", None, None, None, None, None, None, None, None), # ser + (" N "," CA "," C "," O "," CB "," OG1"," CG2", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HG1"," HA "," HB ","1HG2","2HG2","3HG2", None, None, None, None, None, None), # thr + (" N "," CA "," C "," O "," CB "," CG "," CD1"," CD2"," NE1"," CE2"," CE3"," CZ2"," CZ3"," CH2", None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","1HD ","1HE "," HZ2"," HH2"," HZ3"," HE3", None, None, None), # 
trp + (" N "," CA "," C "," O "," CB "," CG "," CD2"," CD1"," CE2"," CE1"," CZ "," OH ", None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","2HD ","2HE ","1HE ","1HD "," HH ", None, None, None, None), # tyr + (" N "," CA "," C "," O "," CB "," CG1"," CG2", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA "," HB ","1HG1","2HG1","3HG1","1HG2","2HG2","3HG2", None, None, None, None), # val + (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","3HB ", None, None, None, None, None, None, None, None), # unk + (" N "," CA "," C "," O "," CB ", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None," H "," HA ","1HB ","2HB ","3HB ", None, None, None, None, None, None, None, None), # mask + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C2'"," C1'"," N9 "," C4 "," N3 "," C2 "," N1 "," C6 "," C5 "," N7 "," C8 "," N6 ", None, None,"H5''"," H5'"," H4'"," H3'","H2''"," H2'"," H1'"," H2 "," H61"," H62"," H8 ", None, None), # DA + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C2'"," C1'"," N1 "," C2 "," O2 "," N3 "," C4 "," N4 "," C5 "," C6 ", None, None, None, None,"H5''"," H5'"," H4'"," H3'","H2''"," H2'"," H1'"," H42"," H41"," H5 "," H6 ", None, None), # DC + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C2'"," C1'"," N9 "," C4 "," N3 "," C2 "," N1 "," C6 "," C5 "," N7 "," C8 "," N2 "," O6 ", None,"H5''"," H5'"," H4'"," H3'","H2''"," H2'"," H1'"," H1 "," H22"," H21"," H8 ", None, None), # DG + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C2'"," C1'"," N1 "," C2 "," O2 "," N3 "," C4 "," O4 "," C5 "," C7 "," C6 ", None, None, None,"H5''"," H5'"," H4'"," H3'","H2''"," H2'"," H1'"," H3 "," H71"," H72"," H73"," H6 ", None), # DT + (" OP1"," P "," 
OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C2'"," C1'", None, None, None, None, None, None, None, None, None, None, None, None,"H5''"," H5'"," H4'"," H3'","H2''"," H2'"," H1'", None, None, None, None, None, None), # DX (unk DNA) + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C1'"," C2'"," O2'"," N1 "," C2 "," N3 "," C4 "," C5 "," C6 "," N6 "," N7 "," C8 "," N9 ", None," H5'","H5''"," H4'"," H3'"," H2'","HO2'"," H1'"," H2 "," H61"," H62"," H8 ", None, None), # A + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C1'"," C2'"," O2'"," N1 "," C2 "," O2 "," N3 "," C4 "," N4 "," C5 "," C6 ", None, None, None," H5'","H5''"," H4'"," H3'"," H2'","HO2'"," H1'"," H42"," H41"," H5 "," H6 ", None, None), # C + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C1'"," C2'"," O2'"," N1 "," C2 "," N2 "," N3 "," C4 "," C5 "," C6 "," O6 "," N7 "," C8 "," N9 "," H5'","H5''"," H4'"," H3'"," H2'","HO2'"," H1'"," H1 "," H22"," H21"," H8 ", None, None), # G + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C1'"," C2'"," O2'"," N1 "," C2 "," O2 "," N3 "," C4 "," O4 "," C5 "," C6 ", None, None, None," H5'","H5''"," H4'"," H3'"," H2'","HO2'"," H1'"," H3 "," H5 "," H6 ", None, None, None), # U + (" OP1"," P "," OP2"," O5'"," C5'"," C4'"," O4'"," C3'"," O3'"," C1'"," C2'"," O2'", None, None, None, None, None, None, None, None, None, None, None," H5'","H5''"," H4'"," H3'"," H2'","HO2'"," H1'", None, None, None, None, None, None), # RX (unk RNA) +] + +aabonds=[ + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB ","1HB "),(" CB ","2HB "),(" CB ","3HB ")) , # ala + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD "),(" CG ","1HG "),(" CG ","2HG "),(" CD "," NE "),(" CD ","1HD "),(" CD ","2HD "),(" NE "," CZ "),(" NE "," HE "),(" CZ "," NH1"),(" CZ "," NH2"),(" NH1","1HH1"),(" 
NH1","2HH1"),(" NH2","1HH2"),(" NH2","2HH2")) , # arg + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," OD1"),(" CG "," ND2"),(" ND2","1HD2"),(" ND2","2HD2")) , # asn + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," OD1"),(" CG "," OD2")) , # asp + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," SG "),(" CB ","1HB "),(" CB ","2HB "),(" SG "," HG ")) , # cys + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD "),(" CG ","1HG "),(" CG ","2HG "),(" CD "," OE1"),(" CD "," NE2"),(" NE2","1HE2"),(" NE2","2HE2")) , # gln + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD "),(" CG ","1HG "),(" CG ","2HG "),(" CD "," OE1"),(" CD "," OE2")) , # glu + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA ","1HA "),(" CA ","2HA "),(" C "," O ")) , # gly + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," ND1"),(" CG "," CD2"),(" ND1"," CE1"),(" CD2"," NE2"),(" CD2","2HD "),(" CE1"," NE2"),(" CE1","1HE "),(" NE2","2HE ")) , # his + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG1"),(" CB "," CG2"),(" CB "," HB "),(" CG1"," CD1"),(" CG1","1HG1"),(" CG1","2HG1"),(" CG2","1HG2"),(" CG2","2HG2"),(" CG2","3HG2"),(" CD1","1HD1"),(" CD1","2HD1"),(" CD1","3HD1")) , # ile + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD1"),(" CG "," CD2"),(" CG "," HG "),(" CD1","1HD1"),(" 
CD1","2HD1"),(" CD1","3HD1"),(" CD2","1HD2"),(" CD2","2HD2"),(" CD2","3HD2")) , # leu + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD "),(" CG ","1HG "),(" CG ","2HG "),(" CD "," CE "),(" CD ","1HD "),(" CD ","2HD "),(" CE "," NZ "),(" CE ","1HE "),(" CE ","2HE "),(" NZ ","1HZ "),(" NZ ","2HZ "),(" NZ ","3HZ ")) , # lys + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," SD "),(" CG ","1HG "),(" CG ","2HG "),(" SD "," CE "),(" CE ","1HE "),(" CE ","2HE "),(" CE ","3HE ")) , # met + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD1"),(" CG "," CD2"),(" CD1"," CE1"),(" CD1","1HD "),(" CD2"," CE2"),(" CD2","2HD "),(" CE1"," CZ "),(" CE1","1HE "),(" CE2"," CZ "),(" CE2","2HE "),(" CZ "," HZ ")) , # phe + ((" N "," CA "),(" N "," CD "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD "),(" CG ","1HG "),(" CG ","2HG "),(" CD ","1HD "),(" CD ","2HD ")) , # pro + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," OG "),(" CB ","1HB "),(" CB ","2HB "),(" OG "," HG ")) , # ser + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," OG1"),(" CB "," CG2"),(" CB "," HB "),(" OG1"," HG1"),(" CG2","1HG2"),(" CG2","2HG2"),(" CG2","3HG2")) , # thr + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD1"),(" CG "," CD2"),(" CD1"," NE1"),(" CD1","1HD "),(" CD2"," CE2"),(" CD2"," CE3"),(" NE1"," CE2"),(" NE1","1HE "),(" CE2"," CZ2"),(" CE3"," CZ3"),(" CE3"," HE3"),(" CZ2"," CH2"),(" CZ2"," HZ2"),(" CZ3"," CH2"),(" 
CZ3"," HZ3"),(" CH2"," HH2")) , # trp + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG "),(" CB ","1HB "),(" CB ","2HB "),(" CG "," CD1"),(" CG "," CD2"),(" CD1"," CE1"),(" CD1","1HD "),(" CD2"," CE2"),(" CD2","2HD "),(" CE1"," CZ "),(" CE1","1HE "),(" CE2"," CZ "),(" CE2","2HE "),(" CZ "," OH "),(" OH "," HH ")) , # tyr + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB "," CG1"),(" CB "," CG2"),(" CB "," HB "),(" CG1","1HG1"),(" CG1","2HG1"),(" CG1","3HG1"),(" CG2","1HG2"),(" CG2","2HG2"),(" CG2","3HG2")), # val + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB ","1HB "),(" CB ","2HB "),(" CB ","3HB ")) , # unk + ((" N "," CA "),(" N "," H "),(" CA "," C "),(" CA "," CB "),(" CA "," HA "),(" C "," O "),(" CB ","1HB "),(" CB ","2HB "),(" CB ","3HB ")) , # mask + ((" P "," OP2"),(" P "," OP1"),(" P "," O5'"),(" O5'"," C5'"),(" C5'"," C4'"),(" C5'","H5''"),(" C5'"," H5'"),(" C4'"," O4'"),(" C4'"," C3'"),(" C4'"," H4'"),(" O4'"," C1'"),(" C3'"," O3'"),(" C3'"," C2'"),(" C3'"," H3'"),(" C2'"," C1'"),(" C2'","H2''"),(" C2'"," H2'"),(" C1'"," N9 "),(" C1'"," H1'"),(" N1 "," C2 "),(" N1 "," C6 "),(" C2 "," N3 "),(" C2 "," H2 "),(" N3 "," C4 "),(" C4 "," C5 "),(" C4 "," N9 "),(" C5 "," C6 "),(" C5 "," N7 "),(" C6 "," N6 "),(" N6 "," H61"),(" N6 "," H62"),(" N7 "," C8 "),(" C8 "," N9 "),(" C8 "," H8 ")) , # DA + ((" P "," OP2"),(" P "," OP1"),(" P "," O5'"),(" O5'"," C5'"),(" C5'"," C4'"),(" C5'","H5''"),(" C5'"," H5'"),(" C4'"," O4'"),(" C4'"," C3'"),(" C4'"," H4'"),(" O4'"," C1'"),(" C3'"," O3'"),(" C3'"," C2'"),(" C3'"," H3'"),(" C2'"," C1'"),(" C2'","H2''"),(" C2'"," H2'"),(" C1'"," N1 "),(" C1'"," H1'"),(" N1 "," C2 "),(" N1 "," C6 "),(" C2 "," O2 "),(" C2 "," N3 "),(" N3 "," C4 "),(" C4 "," N4 "),(" C4 "," C5 "),(" N4 "," H42"),(" N4 "," H41"),(" C5 "," C6 "),(" C5 "," H5 "),(" C6 "," H6 ")), # DC + ((" P "," 
OP2"),(" P "," OP1"),(" P "," O5'"),(" O5'"," C5'"),(" C5'"," C4'"),(" C5'","H5''"),(" C5'"," H5'"),(" C4'"," O4'"),(" C4'"," C3'"),(" C4'"," H4'"),(" O4'"," C1'"),(" C3'"," O3'"),(" C3'"," C2'"),(" C3'"," H3'"),(" C2'"," C1'"),(" C2'","H2''"),(" C2'"," H2'"),(" C1'"," N9 "),(" C1'"," H1'"),(" N1 "," C2 "),(" N1 "," C6 "),(" N1 "," H1 "),(" C2 "," N2 "),(" C2 "," N3 "),(" N2 "," H22"),(" N2 "," H21"),(" N3 "," C4 "),(" C4 "," C5 "),(" C4 "," N9 "),(" C5 "," C6 "),(" C5 "," N7 "),(" C6 "," O6 "),(" N7 "," C8 "),(" C8 "," N9 "),(" C8 "," H8 ")), # DG + ((" P "," OP2"),(" P "," OP1"),(" P "," O5'"),(" O5'"," C5'"),(" C5'"," C4'"),(" C5'","H5''"),(" C5'"," H5'"),(" C4'"," O4'"),(" C4'"," C3'"),(" C4'"," H4'"),(" O4'"," C1'"),(" C3'"," O3'"),(" C3'"," C2'"),(" C3'"," H3'"),(" C2'"," C1'"),(" C2'","H2''"),(" C2'"," H2'"),(" C1'"," N1 "),(" C1'"," H1'"),(" N1 "," C2 "),(" N1 "," C6 "),(" C2 "," O2 "),(" C2 "," N3 "),(" N3 "," C4 "),(" N3 "," H3 "),(" C4 "," O4 "),(" C4 "," C5 "),(" C5 "," C7 "),(" C5 "," C6 "),(" C7 "," H71"),(" C7 "," H72"),(" C7 "," H73"),(" C6 "," H6 ")), # DT + ((" P "," OP2"),(" P "," OP1"),(" P "," O5'"),(" O5'"," C5'"),(" C5'"," C4'"),(" C5'","H5''"),(" C5'"," H5'"),(" C4'"," O4'"),(" C4'"," C3'"),(" C4'"," H4'"),(" O4'"," C1'"),(" C3'"," O3'"),(" C3'"," C2'"),(" C3'"," H3'"),(" C2'"," C1'"),(" C2'","H2''"),(" C2'"," H2'"),(" C1'"," H1'")) , # DX + ((" P "," OP2"),(" P "," OP1"),(" P "," O5'"),(" O5'"," C5'"),(" C5'"," C4'"),(" C5'"," H5'"),(" C5'","H5''"),(" C4'"," O4'"),(" C4'"," C3'"),(" C4'"," H4'"),(" O4'"," C1'"),(" C3'"," O3'"),(" C3'"," C2'"),(" C3'"," H3'"),(" C2'"," C1'"),(" C2'"," O2'"),(" C2'"," H2'"),(" O2'","HO2'"),(" C1'"," N9 "),(" C1'"," H1'"),(" N1 "," C2 "),(" N1 "," C6 "),(" C2 "," N3 "),(" C2 "," H2 "),(" N3 "," C4 "),(" C4 "," C5 "),(" C4 "," N9 "),(" C5 "," C6 "),(" C5 "," N7 "),(" C6 "," N6 "),(" N6 "," H61"),(" N6 "," H62"),(" N7 "," C8 "),(" C8 "," N9 "),(" C8 "," H8 ")), # A + ((" P "," OP2"),(" P "," OP1"),(" P "," 
O5'"),(" O5'"," C5'"),(" C5'"," C4'"),(" C5'"," H5'"),(" C5'","H5''"),(" C4'"," O4'"),(" C4'"," C3'"),(" C4'"," H4'"),(" O4'"," C1'"),(" C3'"," O3'"),(" C3'"," C2'"),(" C3'"," H3'"),(" C2'"," C1'"),(" C2'"," O2'"),(" C2'"," H2'"),(" O2'","HO2'"),(" C1'"," N1 "),(" C1'"," H1'"),(" N1 "," C2 "),(" N1 "," C6 "),(" C2 "," O2 "),(" C2 "," N3 "),(" N3 "," C4 "),(" C4 "," N4 "),(" C4 "," C5 "),(" N4 "," H42"),(" N4 "," H41"),(" C5 "," C6 "),(" C5 "," H5 "),(" C6 "," H6 ")), # C + ((" P "," OP2"),(" P "," OP1"),(" P "," O5'"),(" O5'"," C5'"),(" C5'"," C4'"),(" C5'"," H5'"),(" C5'","H5''"),(" C4'"," O4'"),(" C4'"," C3'"),(" C4'"," H4'"),(" O4'"," C1'"),(" C3'"," O3'"),(" C3'"," C2'"),(" C3'"," H3'"),(" C2'"," C1'"),(" C2'"," O2'"),(" C2'"," H2'"),(" O2'","HO2'"),(" C1'"," N9 "),(" C1'"," H1'"),(" N1 "," C2 "),(" N1 "," C6 "),(" N1 "," H1 "),(" C2 "," N2 "),(" C2 "," N3 "),(" N2 "," H22"),(" N2 "," H21"),(" N3 "," C4 "),(" C4 "," C5 "),(" C4 "," N9 "),(" C5 "," C6 "),(" C5 "," N7 "),(" C6 "," O6 "),(" N7 "," C8 "),(" C8 "," N9 "),(" C8 "," H8 ")), # G + ((" P "," OP2"),(" P "," OP1"),(" P "," O5'"),(" O5'"," C5'"),(" C5'"," C4'"),(" C5'"," H5'"),(" C5'","H5''"),(" C4'"," O4'"),(" C4'"," C3'"),(" C4'"," H4'"),(" O4'"," C1'"),(" C3'"," O3'"),(" C3'"," C2'"),(" C3'"," H3'"),(" C2'"," C1'"),(" C2'"," O2'"),(" C2'"," H2'"),(" O2'","HO2'"),(" C1'"," N1 "),(" C1'"," H1'"),(" N1 "," C2 "),(" N1 "," C6 "),(" C2 "," O2 "),(" C2 "," N3 "),(" N3 "," C4 "),(" N3 "," H3 "),(" C4 "," O4 "),(" C4 "," C5 "),(" C5 "," C6 "),(" C5 "," H5 "),(" C6 "," H6 ")), # U + ((" P "," OP2"),(" P "," OP1"),(" P "," O5'"),(" O5'"," C5'"),(" C5'"," C4'"),(" C5'"," H5'"),(" C5'","H5''"),(" C4'"," O4'"),(" C4'"," C3'"),(" C4'"," H4'"),(" O4'"," C1'"),(" C3'"," O3'"),(" C3'"," C2'"),(" C3'"," H3'"),(" C2'"," C1'"),(" C2'"," O2'"),(" C2'"," H2'"),(" O2'","HO2'"),(" C1'"," H1'")), # RX +] + +aa2type = [ + ("Nbb", "CAbb","CObb","OCbb","CH3", None, None, None, None, None, None, None, None, None, None, None, None, 
None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo", None, None, None, None, None, None, None, None), # ala + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH2", "CH2", "NtrR","aroC","Narg","Narg", None, None, None, None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hpol","Hpol","Hpol","Hpol","Hpol"), # arg + ("Nbb", "CAbb","CObb","OCbb","CH2", "CNH2","ONH2","NH2O", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hpol","Hpol", None, None, None, None, None, None, None), # asn + ("Nbb", "CAbb","CObb","OCbb","CH2", "COO", "OOC", "OOC", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo", None, None, None, None, None, None, None, None, None), # asp + ("Nbb", "CAbb","CObb","OCbb","CH2", "SH1", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","HS", None, None, None, None, None, None, None, None), # cys + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH2", "CNH2","ONH2","NH2O", None, None, None, None, None, None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo","Hapo","Hpol","Hpol", None, None, None, None, None), # gln + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH2", "COO", "OOC", "OOC", None, None, None, None, None, None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo","Hapo", None, None, None, None, None, None, None), # glu + ("Nbb", "CAbb","CObb","OCbb", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo", None, None, None, None, None, None, None, None, None, None), # gly + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH0", "Nhis","aroC","aroC","Ntrp", None, None, None, None, None, None, None, None, None, None, None, None, 
None,"HNbb","Hapo","Hapo","Hapo","Hpol","Hapo","Hapo", None, None, None, None, None, None), # his + ("Nbb", "CAbb","CObb","OCbb","CH1", "CH2", "CH3", "CH3", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo", None, None), # ile + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH1", "CH3", "CH3", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo", None, None), # leu + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH2", "CH2", "CH2", "Nlys", None, None, None, None, None, None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hpol","Hpol","Hpol"), # lys + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH2", "S", "CH3", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo", None, None, None, None), # met + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH0", "aroC","aroC","aroC","aroC","aroC", None, None, None, None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Haro","Haro","Haro","Haro","Haro", None, None, None, None), # phe + ("Npro","CAbb","CObb","OCbb","CH2", "CH2", "CH2", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,"Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo", None, None, None, None, None, None), # pro + ("Nbb", "CAbb","CObb","OCbb","CH2", "OH", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,"HNbb","Hpol","Hapo","Hapo","Hapo", None, None, None, None, None, None, None, None), # ser + ("Nbb", "CAbb","CObb","OCbb","CH1", "OH", "CH3", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,"HNbb","Hpol","Hapo","Hapo","Hapo","Hapo","Hapo", 
None, None, None, None, None, None), # thr + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH0", "aroC","CH0", "Ntrp","CH0", "aroC","aroC","aroC","aroC", None, None, None, None, None, None, None, None, None,"HNbb","Haro","Hapo","Hapo","Hapo","Hpol","Haro","Haro","Haro","Haro", None, None, None), # trp + ("Nbb", "CAbb","CObb","OCbb","CH2", "CH0", "aroC","aroC","aroC","aroC","CH0", "OHY", None, None, None, None, None, None, None, None, None, None, None,"HNbb","Haro","Haro","Haro","Haro","Hapo","Hapo","Hapo","Hpol", None, None, None, None), # tyr + ("Nbb", "CAbb","CObb","OCbb","CH1", "CH3", "CH3", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo", None, None, None, None), # val + ("Nbb", "CAbb","CObb","OCbb","CH3", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo", None, None, None, None, None, None, None, None), # unk + ("Nbb", "CAbb","CObb","OCbb","CH3", None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None,"HNbb","Hapo","Hapo","Hapo","Hapo", None, None, None, None, None, None, None, None), # mask + ("OOC","Phos", "OOC", "Oet2","CH2", "CH1", "Oet3","CH1", "Oet2","CH2", "CH1", "Npro","aroC","Nhis","aroC","Nhis","aroC","aroC","Nhis","aroC","NH2O", None, None,"Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Haro","Hpol","Hpol","Haro", None, None), # DA + ("OOC","Phos", "OOC", "Oet2","CH2", "CH1", "Oet3","CH1", "Oet2","CH2", "CH1", "Npro","CObb","OCbb","Nhis","aroC","NH2O","aroC","aroC", None, None, None, None,"Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hpol","Hpol","Haro","Haro", None, None), # DC + ("OOC","Phos", "OOC", "Oet2","CH2", "CH1", "Oet3","CH1", "Oet2","CH2", "CH1", "Npro","aroC","Nhis","aroC","Ntrp","CObb","aroC","Nhis","aroC","NH2O","OCbb", 
None,"Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hpol","Hpol","Hpol","Haro", None, None), # DG + ("OOC","Phos", "OOC", "Oet2","CH2", "CH1", "Oet3","CH1", "Oet2","CH2", "CH1", "Npro","CObb","OCbb","Ntrp","CObb","OCbb","aroC","CH3", "aroC", None, None, None,"Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hpol","Hapo","Hapo","Hapo","Haro", None), # DT + ("OOC","Phos", "OOC", "Oet2","CH2", "CH1", "Oet3","CH1", "Oet2","CH2", "CH1", None, None, None, None, None, None, None, None, None, None, None, None,"Hapo","Hapo","Hapo","Hapo","Hapo","Hapo","Hapo", None, None, None, None, None, None), # DX (unk DNA) + ("OOC","Phos", "OOC", "Oet2","CH2", "CH1", "Oet3","CH1", "Oet2","CH1", "CH2", "OH", "Nhis","aroC","Nhis","aroC","aroC","aroC","NH2O","Nhis","aroC","Npro", None,"Hapo","Hapo","Hapo","Hapo","Hapo","Hpol","Hapo","Haro","Hpol","Hpol","Haro", None, None), # A + ("OOC","Phos", "OOC", "Oet2","CH2", "CH1", "Oet3","CH1", "Oet2","CH1", "CH2", "OH", "Npro","CObb","OCbb","Nhis","aroC","NH2O","aroC","aroC", None, None, None,"Hapo","Hapo","Hapo","Hapo","Hapo","Hpol","Hapo","Hpol","Hpol","Haro","Haro", None, None), # C + ("OOC","Phos", "OOC", "Oet2","CH2", "CH1", "Oet3","CH1", "Oet2","CH1", "CH2", "OH", "Ntrp","aroC","NH2O","Nhis","aroC","aroC","CObb","OCbb","Nhis","aroC","Npro","Hapo","Hapo","Hapo","Hapo","Hapo","Hpol","Hapo","Hpol","Hpol","Hpol","Haro", None, None), # G + ("OOC","Phos", "OOC", "Oet2","CH2", "CH1", "Oet3","CH1", "Oet2","CH1", "CH2", "OH", "Npro","CObb","OCbb","Ntrp","CObb","OCbb","aroC","aroC", None, None, None,"Hapo","Hapo","Hapo","Hapo","Hapo","Hpol","Hapo","Hpol","Hapo","Haro", None, None, None), # U + ("OOC","Phos", "OOC", "Oet2","CH2", "CH1", "Oet3","CH1", "Oet2","CH1", "CH2", "OH", None, None, None, None, None, None, None, None, None, None, None,"Hapo","Hapo","Hapo","Hapo","Hapo","Hpol","Hapo", None, None, None, None, None, None), # RX (unk RNA) +] + +aa2elt = [ + 
("N","C","C","O","C",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H",None,None,None,None,None,None,None,None),#ala + ("N","C","C","O","C","C","C","N","C","N","N",None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H","H","H","H","H","H","H"),#arg + ("N","C","C","O","C","C","O","N",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H",None,None,None,None,None,None,None),#asn + ("N","C","C","O","C","C","O","O",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H",None,None,None,None,None,None,None,None,None),#asp + ("N","C","C","O","C","S",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H",None,None,None,None,None,None,None,None),#cys + ("N","C","C","O","C","C","C","O","N",None,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H","H",None,None,None,None,None),#gln + ("N","C","C","O","C","C","C","O","O",None,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H",None,None,None,None,None,None,None),#glu + ("N","C","C","O",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H",None,None,None,None,None,None,None,None,None,None),#gly + ("N","C","C","O","C","C","N","C","C","N",None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H",None,None,None,None,None,None),#his + ("N","C","C","O","C","C","C","C",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H","H","H","H","H",None,None),#ile + ("N","C","C","O","C","C","C","C",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H","H","H","H","H",None,None),#leu + 
("N","C","C","O","C","C","C","C","N",None,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H","H","H","H","H","H","H"),#lys + ("N","C","C","O","C","C","S","C",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H","H","H",None,None,None,None),#met + ("N","C","C","O","C","C","C","C","C","C","C",None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H","H","H",None,None,None,None),#phe + ("N","C","C","O","C","C","C",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H",None,None,None,None,None,None),#pro + ("N","C","C","O","C","O",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H",None,None,None,None,None,None,None,None),#ser + ("N","C","C","O","C","O","C",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H",None,None,None,None,None,None),#thr + ("N","C","C","O","C","C","C","C","N","C","C","C","C","C",None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H","H","H","H",None,None,None),#trp + ("N","C","C","O","C","C","C","C","C","C","C","O",None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H","H","H",None,None,None,None),#tyr + ("N","C","C","O","C","C","C",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H","H","H",None,None,None,None),#val + ("N","C","C","O","C",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H",None,None,None,None,None,None,None,None),#unk + ("N","C","C","O","C",None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H",None,None,None,None,None,None,None,None),#mask + 
("O","P","O","O","C","C","O","C","O","C","C","N","C","N","C","N","C","C","N","C","N",None,None,"H","H","H","H","H","H","H","H","H","H","H",None,None),#DA + ("O","P","O","O","C","C","O","C","O","C","C","N","C","O","N","C","N","C","C",None,None,None,None,"H","H","H","H","H","H","H","H","H","H","H",None,None),#DC + ("O","P","O","O","C","C","O","C","O","C","C","N","C","N","C","N","C","C","N","C","N","O",None,"H","H","H","H","H","H","H","H","H","H","H",None,None),#DG + ("O","P","O","O","C","C","O","C","O","C","C","N","C","O","N","C","O","C","C","C",None,None,None,"H","H","H","H","H","H","H","H","H","H","H","H",None),#DT + ("O","P","O","O","C","C","O","C","O","C","C",None,None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H",None,None,None,None,None,None),#DX + ("O","P","O","O","C","C","O","C","O","C","C","O","N","C","N","C","C","C","N","N","C","N",None,"H","H","H","H","H","H","H","H","H","H","H",None,None),#A + ("O","P","O","O","C","C","O","C","O","C","C","O","N","C","O","N","C","N","C","C",None,None,None,"H","H","H","H","H","H","H","H","H","H","H",None,None),#C + ("O","P","O","O","C","C","O","C","O","C","C","O","N","C","N","N","C","C","C","O","N","C","N","H","H","H","H","H","H","H","H","H","H","H",None,None),#G + ("O","P","O","O","C","C","O","C","O","C","C","O","N","C","O","N","C","O","C","C",None,None,None,"H","H","H","H","H","H","H","H","H","H",None,None,None),#U + ("O","P","O","O","C","C","O","C","O","C","C","O",None,None,None,None,None,None,None,None,None,None,None,"H","H","H","H","H","H","H","H",None,None,None,None,None),#RX +] + + +# tip atom +aa2tip = [ + " CB ", # ala + " CZ ", # arg + " ND2", # asn + " CG ", # asp + " SG ", # cys + " NE2", # gln + " CD ", # glu + " CA ", # gly + " NE2", # his + " CD1", # ile + " CG ", # leu + " NZ ", # lys + " SD ", # met + " CZ ", # phe + " CG ", # pro + " OG ", # ser + " OG1", # thr + " CH2", # trp + " OH ", # tyr + " CB ", # val + " CB ", # unknown (gap etc) + " CB ", # masked + " N1 ", # DA + " 
N3 ", # DC + " N1 ", # DG + " N3 ", # DT + " C1'", # DX + " N1 ", # A + " N3 ", # C + " N1 ", # G + " N3 ", # U + " C1'", # RX + ] + +# ideal N, CA, C initial coordinates (protein) +init_N = torch.tensor([-0.5272, 1.3593, 0.000]).float() +init_CA = torch.zeros_like(init_N) +init_C = torch.tensor([1.5233, 0.000, 0.000]).float() +INIT_CRDS = torch.full((NTOTAL, 3), np.nan) +INIT_CRDS[:3] = torch.stack((init_N, init_CA, init_C), dim=0) # (3,3) + +# ideal OP1,P,OP2 initial coordinates (nucleic acid) +init_OP1 = torch.tensor([-0.7319, 1.2920, 0.000]).float() +init_P = torch.zeros_like(init_OP1) +init_OP2 = torch.tensor([1.5233, 0.000, 0.000]).float() +INIT_NA_CRDS = torch.full((NTOTAL, 3), np.nan) +INIT_NA_CRDS[:3] = torch.stack((init_OP1, init_P, init_OP2), dim=0) # (3,3) + +# non-backbone torsions +# (bb torsions are hard-coded) +torsions=[ + [ None, None, None, None ], # ala + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD "], [" CB "," CG "," CD "," NE "], [" CG "," CD "," NE "," CZ "] ], # arg + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," OD1"], None, None ], # asn + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," OD1"], None, None ], # asp + [ [" N "," CA "," CB "," SG "], [" CA "," CB "," SG "," HG "], None, None ], # cys + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD "], [" CB "," CG "," CD "," OE1"], None ], # gln + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD "], [" CB "," CG "," CD "," OE1"], None ], # glu + [ None, None, None, None ], # gly + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," ND1"], [" CD2"," CE1","1HE "," NE2"], None ], # his (protonation handled as a pseudo-torsion) + [ [" N "," CA "," CB "," CG1"], [" CA "," CB "," CG1"," CD1"], None, None ], # ile + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD1"], None, None ], # leu + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD "], [" CB "," CG "," CD "," CE "], [" CG "," CD "," CE "," NZ "] ], # lys + [ [" N 
"," CA "," CB "," CG "], [" CA "," CB "," CG "," SD "], [" CB "," CG "," SD "," CE "], None ], # met + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD1"], None, None ], # phe + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD "], [" CB "," CG "," CD ","1HD "], None ], # pro + [ [" N "," CA "," CB "," OG "], [" CA "," CB "," OG "," HG "], None, None ], # ser + [ [" N "," CA "," CB "," OG1"], [" CA "," CB "," OG1"," HG1"], None, None ], # thr + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD1"], None, None ], # trp + [ [" N "," CA "," CB "," CG "], [" CA "," CB "," CG "," CD1"], [" CE1"," CZ "," OH "," HH "], None ], # tyr + [ [" N "," CA "," CB "," CG1"], None, None, None ], # val + [ None, None, None, None ], # unk + [ None, None, None, None ], # mask + [ [" O4'"," C1'"," N9 "," C4 "], None, None, None ],#DA + [ [" O4'"," C1'"," N1 "," C2 "], None, None, None ],#DC + [ [" O4'"," C1'"," N9 "," C4 "], None, None, None ],#DG + [ [" O4'"," C1'"," N1 "," C2 "], None, None, None ],#DT + [ None, None, None, None ], # DX + [ [" O4'"," C1'"," N9 "," C4 "], None, None, None ],#A + [ [" O4'"," C1'"," N1 "," C2 "], None, None, None ],#C + [ [" O4'"," C1'"," N9 "," C4 "], None, None, None ],#G + [ [" O4'"," C1'"," N1 "," C2 "], None, None, None ],#U + [ None, None, None, None ], # RX +] + +# frames for generic FAPE +frames=[ + [ [" N "," CA "," C "],[" CA "," C "," O "] ], # ala + [ [" N "," CA "," C "],[" CA "," C "," O "],[" N "," CA "," CB "], [" CA "," CB "," CG "], [" CB "," CG "," CD "], [" CG "," CD "," NE "] ], # arg + [ [" N "," CA "," C "],[" CA "," C "," O "],[" N "," CA "," CB "], [" CA "," CB "," CG "] ], # asn + [ [" N "," CA "," C "],[" CA "," C "," O "],[" N "," CA "," CB "], [" CA "," CB "," CG "] ], # asp + [ [" N "," CA "," C "],[" CA "," C "," O "],[" N "," CA "," CB "] ], # cys + [ [" N "," CA "," C "],[" CA "," C "," O "],[" N "," CA "," CB "], [" CA "," CB "," CG "], [" CB "," CG "," CD "] ], # gln + [ [" N "," CA "," C 
"],[" CA "," C "," O "],[" N "," CA "," CB "], [" CA "," CB "," CG "], [" CB "," CG "," CD "] ], # glu + [ [" N "," CA "," C "],[" CA "," C "," O "] ], # gly + [ [" N "," CA "," C "],[" CA "," C "," O "],[" N "," CA "," CB "], [" CA "," CB "," CG "] ], # his + [ [" N "," CA "," C "],[" CA "," C "," O "],[" N "," CA "," CB "], [" CA "," CB "," CG1"] ], # ile + [ [" N "," CA "," C "],[" CA "," C "," O "],[" N "," CA "," CB "], [" CA "," CB "," CG "] ], # leu + [ [" N "," CA "," C "],[" CA "," C "," O "],[" N "," CA "," CB "], [" CA "," CB "," CG "], [" CB "," CG "," CD "], [" CG "," CD "," CE "] ], # lys + [ [" N "," CA "," C "],[" CA "," C "," O "],[" N "," CA "," CB "], [" CA "," CB "," CG "], [" CB "," CG "," SD "] ], # met + [ [" N "," CA "," C "],[" CA "," C "," O "],[" N "," CA "," CB "], [" CA "," CB "," CG "] ], # phe + [ [" N "," CA "," C "],[" CA "," C "," O "],[" N "," CA "," CB "], [" CA "," CB "," CG "], [" CB "," CG "," CD "]], # pro + [ [" N "," CA "," C "],[" CA "," C "," O "],[" N "," CA "," CB "], [" CA "," CB "," OG "] ], # ser + [ [" N "," CA "," C "],[" CA "," C "," O "],[" N "," CA "," CB "], [" CA "," CB "," OG1"] ], # thr + [ [" N "," CA "," C "],[" CA "," C "," O "],[" N "," CA "," CB "], [" CA "," CB "," CG "] ], # trp + [ [" N "," CA "," C "],[" CA "," C "," O "],[" N "," CA "," CB "], [" CA "," CB "," CG "] ], # tyr + [ [" N "," CA "," C "],[" CA "," C "," O "],[" N "," CA "," CB "] ], # val + [ [" N "," CA "," C "],[" CA "," C "," O "] ], # unk + [ [" N "," CA "," C "],[" CA "," C "," O "] ], # mask + [ [" OP1"," P "," O5'"], [" P "," O5'"," C5'"], [" O5'"," C5'"," C4'"], [" C5'"," C4'"," C3'"], [" C5'"," C4'"," O4'"], [" C4'"," O4'"," C1'"], [" C2'"," C1'"," N9 "], [" C4'"," C3'"," O3'"] ], #DA + [ [" OP1"," P "," O5'"], [" P "," O5'"," C5'"], [" O5'"," C5'"," C4'"], [" C5'"," C4'"," C3'"], [" C5'"," C4'"," O4'"], [" C4'"," O4'"," C1'"], [" C2'"," C1'"," N1 "], [" C4'"," C3'"," O3'"] ], #DC + [ [" OP1"," P "," O5'"], [" P "," O5'"," 
C5'"], [" O5'"," C5'"," C4'"], [" C5'"," C4'"," C3'"], [" C5'"," C4'"," O4'"], [" C4'"," O4'"," C1'"], [" C2'"," C1'"," N9 "], [" C4'"," C3'"," O3'"] ], #DG + [ [" OP1"," P "," O5'"], [" P "," O5'"," C5'"], [" O5'"," C5'"," C4'"], [" C5'"," C4'"," C3'"], [" C5'"," C4'"," O4'"], [" C4'"," O4'"," C1'"], [" C2'"," C1'"," N1 "], [" C4'"," C3'"," O3'"] ], #DT + [ [" OP1"," P "," O5'"], [" P "," O5'"," C5'"], [" O5'"," C5'"," C4'"], [" C5'"," C4'"," C3'"], [" C5'"," C4'"," O4'"], [" C4'"," O4'"," C1'"], [" C4'"," C3'"," O3'"] ], #DX + [ [" OP1"," P "," O5'"], [" P "," O5'"," C5'"], [" O5'"," C5'"," C4'"], [" C5'"," C4'"," C3'"], [" C5'"," C4'"," O4'"], [" C4'"," O4'"," C1'"], [" C2'"," C1'"," N9 "], [" C4'"," C3'"," O3'"] ], #A + [ [" OP1"," P "," O5'"], [" P "," O5'"," C5'"], [" O5'"," C5'"," C4'"], [" C5'"," C4'"," C3'"], [" C5'"," C4'"," O4'"], [" C4'"," O4'"," C1'"], [" C2'"," C1'"," N1 "], [" C4'"," C3'"," O3'"] ], #C + [ [" OP1"," P "," O5'"], [" P "," O5'"," C5'"], [" O5'"," C5'"," C4'"], [" C5'"," C4'"," C3'"], [" C5'"," C4'"," O4'"], [" C4'"," O4'"," C1'"], [" C2'"," C1'"," N9 "], [" C4'"," C3'"," O3'"] ], #G + [ [" OP1"," P "," O5'"], [" P "," O5'"," C5'"], [" O5'"," C5'"," C4'"], [" C5'"," C4'"," C3'"], [" C5'"," C4'"," O4'"], [" C4'"," O4'"," C1'"], [" C2'"," C1'"," N1 "], [" C4'"," C3'"," O3'"] ], #U + [ [" OP1"," P "," O5'"], [" P "," O5'"," C5'"], [" O5'"," C5'"," C4'"], [" C5'"," C4'"," C3'"], [" C5'"," C4'"," O4'"], [" C4'"," O4'"," C1'"], [" C4'"," C3'"," O3'"] ], #RX +] +NFRAMES = max([len(f) for f in frames]) + + + +#fd Rosetta ideal coords +#fd - uses same "frame-building" as AF2 +# FRAMES: +# base = 0 +# omega/phi/psi = 1-3 (omega unused) +# chi_1-4(prot) = 4-7 +# CB_bend = 8 +# NA alpha/beta/gamma/delta = 9-12 (NA epsilon/zeta no frame) +# NA nu2/nu1/nu0 = 13-15 +# chi_1(NA) = 16 +ideal_coords = [ + [ # 0 ala + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 
1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3341, -0.4928, 0.9132)], + [' CB ', 8, (-0.5289,-0.7734,-1.1991)], + ['1HB ', 8, (-0.1265, -1.7863, -1.1851)], + ['2HB ', 8, (-1.6173, -0.8147, -1.1541)], + ['3HB ', 8, (-0.2229, -0.2744, -2.1172)], + ], + [ # 1 arg + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3467, -0.5055, 0.9018)], + [' CB ', 8, (-0.5042,-0.7698,-1.2118)], + ['1HB ', 4, ( 0.3635, -0.5318, 0.8781)], + ['2HB ', 4, ( 0.3639, -0.5323, -0.8789)], + [' CG ', 4, (0.6396,1.3794, 0.000)], + ['1HG ', 5, (0.3639, -0.5139, 0.8900)], + ['2HG ', 5, (0.3641, -0.5140, -0.8903)], + [' CD ', 5, (0.5492,1.3801, 0.000)], + ['1HD ', 6, (0.3637, -0.5135, 0.8895)], + ['2HD ', 6, (0.3636, -0.5134, -0.8893)], + [' NE ', 6, (0.5423,1.3491, 0.000)], + [' NH1', 7, (0.2012,2.2965, 0.000)], + [' NH2', 7, (2.0824,1.0030, 0.000)], + [' CZ ', 7, (0.7650,1.1090, 0.000)], + [' HE ', 7, (0.4701,-0.8955, 0.000)], + ['1HH1', 7, (-0.8059,2.3776, 0.000)], + ['1HH2', 7, (2.5160,0.0898, 0.000)], + ['2HH1', 7, (0.7745,3.1277, 0.000)], + ['2HH2', 7, (2.6554,1.8336, 0.000)], + ], + [ # 2 asn + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3233, -0.4967, 0.9162)], + [' CB ', 8, (-0.5341,-0.7799,-1.1874)], + ['1HB ', 4, ( 0.3641, -0.5327, 0.8795)], + ['2HB ', 4, ( 0.3639, -0.5323, -0.8789)], + [' CG ', 4, (0.5778,1.3881, 0.000)], + [' ND2', 5, (0.5839,-1.1711, 0.000)], + [' OD1', 5, (0.6331,1.0620, 0.000)], + ['1HD2', 5, (1.5825, -1.2322, 0.000)], + ['2HD2', 5, (0.0323, -2.0046, 0.000)], + ], + [ # 3 asp + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 
1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3233, -0.4967, 0.9162)], + [' CB ', 8, (-0.5162,-0.7757,-1.2144)], + ['1HB ', 4, ( 0.3639, -0.5324, 0.8791)], + ['2HB ', 4, ( 0.3640, -0.5325, -0.8792)], + [' CG ', 4, (0.5926,1.4028, 0.000)], + [' OD1', 5, (0.5746,1.0629, 0.000)], + [' OD2', 5, (0.5738,-1.0627, 0.000)], + ], + [ # 4 cys + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3481, -0.5059, 0.9006)], + [' CB ', 8, (-0.5046,-0.7727,-1.2189)], + ['1HB ', 4, ( 0.3639, -0.5324, 0.8791)], + ['2HB ', 4, ( 0.3638, -0.5322, -0.8787)], + [' SG ', 4, (0.7386,1.6511, 0.000)], + [' HG ', 5, (0.1387,1.3221, 0.000)], + ], + [ # 5 gln + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3363, -0.5013, 0.9074)], + [' CB ', 8, (-0.5226,-0.7776,-1.2109)], + ['1HB ', 4, ( 0.3638, -0.5323, 0.8789)], + ['2HB ', 4, ( 0.3638, -0.5322, -0.8788)], + [' CG ', 4, (0.6225,1.3857, 0.000)], + ['1HG ', 5, ( 0.3531, -0.5156, 0.8931)], + ['2HG ', 5, ( 0.3531, -0.5156, -0.8931)], + [' CD ', 5, (0.5788,1.4021, 0.000)], + [' NE2', 6, (0.5908,-1.1895, 0.000)], + [' OE1', 6, (0.6347,1.0584, 0.000)], + ['1HE2', 6, (1.5825, -1.2525, 0.000)], + ['2HE2', 6, (0.0380, -2.0229, 0.000)], + ], + [ # 6 glu + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3363, -0.5013, 0.9074)], + [' CB ', 8, (-0.5197,-0.7737,-1.2137)], + ['1HB ', 4, ( 0.3638, -0.5323, 0.8789)], + ['2HB ', 4, ( 0.3638, -0.5322, -0.8788)], + [' CG ', 4, (0.6287,1.3862, 0.000)], + ['1HG ', 5, ( 0.3531, -0.5156, 0.8931)], + ['2HG ', 5, 
( 0.3531, -0.5156, -0.8931)], + [' CD ', 5, (0.5850,1.3849, 0.000)], + [' OE1', 6, (0.5752,1.0618, 0.000)], + [' OE2', 6, (0.5741,-1.0635, 0.000)], + ], + [ # 7 gly + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + ['1HA ', 0, ( -0.3676, -0.5329, 0.8771)], + ['2HA ', 0, ( -0.3674, -0.5325, -0.8765)], + ], + [ # 8 his + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3299, -0.5180, 0.9001)], + [' CB ', 8, (-0.5163,-0.7809,-1.2129)], + ['1HB ', 4, ( 0.3640, -0.5325, 0.8793)], + ['2HB ', 4, ( 0.3637, -0.5321, -0.8786)], + [' CG ', 4, (0.6016,1.3710, 0.000)], + [' CD2', 5, (0.8918,-1.0184, 0.000)], + [' CE1', 5, (2.0299,0.8564, 0.000)], + ['1HE ', 5, (2.8542, 1.5693, 0.000)], + ['2HD ', 5, ( 0.6584, -2.0835, 0.000) ], + [' ND1', 6, (-1.8631, -1.0722, 0.000)], + [' NE2', 6, (-1.8625, 1.0707, 0.000)], + ['2HE ', 6, (-1.5439, 2.0292, 0.000)], + ], + [ # 9 ile + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3405, -0.5028, 0.9044)], + [' CB ', 8, (-0.5140,-0.7885,-1.2184)], + [' HB ', 4, (0.3637, -0.4714, 0.9125)], + [' CG1', 4, (0.5339,1.4348,0.000)], + [' CG2', 4, (0.5319,-0.7693,-1.1994)], + ['1HG2', 4, (1.6215, -0.7588, -1.1842)], + ['2HG2', 4, (0.1785, -1.7986, -1.1569)], + ['3HG2', 4, (0.1773, -0.3016, -2.1180)], + [' CD1', 5, (0.6106,1.3829, 0.000)], + ['1HG1', 5, (0.3637, -0.5338, 0.8774)], + ['2HG1', 5, (0.3640, -0.5322, -0.8793)], + ['1HD1', 5, (1.6978, 1.3006, 0.000)], + ['2HD1', 5, (0.2873, 1.9236, -0.8902)], + ['3HD1', 5, (0.2888, 1.9224, 0.8896)], + ], + [ # 10 leu + [' N ', 0, (-0.5272, 
1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.525, -0.000, -0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3435, -0.5040, 0.9027)], + [' CB ', 8, (-0.5175,-0.7692,-1.2220)], + ['1HB ', 4, ( 0.3473, -0.5346, 0.8827)], + ['2HB ', 4, ( 0.3476, -0.5351, -0.8836)], + [' CG ', 4, (0.6652,1.3823, 0.000)], + [' CD1', 5, (0.5083,1.4353, 0.000)], + [' CD2', 5, (0.5079,-0.7600,1.2163)], + [' HG ', 5, (0.3640, -0.4825, -0.9075)], + ['1HD1', 5, (1.5984, 1.4353, 0.000)], + ['2HD1', 5, (0.1462, 1.9496, -0.8903)], + ['3HD1', 5, (0.1459, 1.9494, 0.8895)], + ['1HD2', 5, (1.5983, -0.7606, 1.2158)], + ['2HD2', 5, (0.1456, -0.2774, 2.1243)], + ['3HD2', 5, (0.1444, -1.7871, 1.1815)], + ], + [ # 11 lys + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3335, -0.5005, 0.9097)], + ['1HB ', 4, ( 0.3640, -0.5324, 0.8791)], + ['2HB ', 4, ( 0.3639, -0.5324, -0.8790)], + [' CB ', 8, (-0.5259,-0.7785,-1.2069)], + ['1HG ', 5, (0.3641, -0.5229, 0.8852)], + ['2HG ', 5, (0.3637, -0.5227, -0.8841)], + [' CG ', 4, (0.6291,1.3869, 0.000)], + [' CD ', 5, (0.5526,1.4174, 0.000)], + ['1HD ', 6, (0.3641, -0.5239, 0.8848)], + ['2HD ', 6, (0.3638, -0.5219, -0.8850)], + [' CE ', 6, (0.5544,1.4170, 0.000)], + [' NZ ', 7, (0.5566,1.3801, 0.000)], + ['1HE ', 7, (0.4199, -0.4638, 0.9482)], + ['2HE ', 7, (0.4202, -0.4631, -0.8172)], + ['1HZ ', 7, (1.6223, 1.3980, 0.0658)], + ['2HZ ', 7, (0.2970, 1.9326, -0.7584)], + ['3HZ ', 7, (0.2981, 1.9319, 0.8909)], + ], + [ # 12 met + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3303, -0.4990, 0.9108)], + ['1HB ', 4, ( 0.3635, -0.5318, 0.8781)], + ['2HB ', 4, ( 0.3641, 
-0.5326, -0.8795)], + [' CB ', 8, (-0.5331,-0.7727,-1.2048)], + ['1HG ', 5, (0.3637, -0.5256, 0.8823)], + ['2HG ', 5, (0.3638, -0.5249, -0.8831)], + [' CG ', 4, (0.6298,1.3858,0.000)], + [' SD ', 5, (0.6953,1.6645,0.000)], + [' CE ', 6, (0.3383,1.7581,0.000)], + ['1HE ', 6, (1.7054, 2.0532, -0.0063)], + ['2HE ', 6, (0.1906, 2.3099, -0.9072)], + ['3HE ', 6, (0.1917, 2.3792, 0.8720)], + ], + [ # 13 phe + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3303, -0.4990, 0.9108)], + ['1HB ', 4, ( 0.3635, -0.5318, 0.8781)], + ['2HB ', 4, ( 0.3641, -0.5326, -0.8795)], + [' CB ', 8, (-0.5150,-0.7729,-1.2156)], + [' CG ', 4, (0.6060,1.3746, 0.000)], + [' CD1', 5, (0.7078,1.1928, 0.000)], + [' CD2', 5, (0.7084,-1.1920, 0.000)], + [' CE1', 5, (2.0900,1.1940, 0.000)], + [' CE2', 5, (2.0897,-1.1939, 0.000)], + [' CZ ', 5, (2.7809, 0.000, 0.000)], + ['1HD ', 5, (0.1613, 2.1362, 0.000)], + ['2HD ', 5, (0.1621, -2.1360, 0.000)], + ['1HE ', 5, (2.6335, 2.1384, 0.000)], + ['2HE ', 5, (2.6344, -2.1378, 0.000)], + [' HZ ', 5, (3.8700, 0.000, 0.000)], + ], + [ # 14 pro + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' HA ', 0, (-0.3868, -0.5380, 0.8781)], + ['1HB ', 4, ( 0.3762, -0.5355, 0.8842)], + ['2HB ', 4, ( 0.3762, -0.5355, -0.8842)], + [' CB ', 8, (-0.5649,-0.5888,-1.2966)], + [' CG ', 4, (0.3657,1.4451,0.0000)], + [' CD ', 5, (0.3744,1.4582, 0.0)], + ['1HG ', 5, (0.3798, -0.5348, 0.8830)], + ['2HG ', 5, (0.3798, -0.5348, -0.8830)], + ['1HD ', 6, (0.3798, -0.5348, 0.8830)], + ['2HD ', 6, (0.3798, -0.5348, -0.8830)], + ], + [ # 15 ser + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, 
(0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3425, -0.5041, 0.9048)], + ['1HB ', 4, ( 0.3637, -0.5321, 0.8786)], + ['2HB ', 4, ( 0.3636, -0.5319, -0.8782)], + [' CB ', 8, (-0.5146,-0.7595,-1.2073)], + [' OG ', 4, (0.5021,1.3081, 0.000)], + [' HG ', 5, (0.2647, 0.9230, 0.000)], + ], + [ # 16 thr + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3364, -0.5015, 0.9078)], + [' HB ', 4, ( 0.3638, -0.5006, 0.8971)], + ['1HG2', 4, ( 1.6231, -0.7142, -1.2097)], + ['2HG2', 4, ( 0.1792, -1.7546, -1.2237)], + ['3HG2', 4, ( 0.1808, -0.2222, -2.1269)], + [' CB ', 8, (-0.5172,-0.7952,-1.2130)], + [' CG2', 4, (0.5334,-0.7239,-1.2267)], + [' OG1', 4, (0.4804,1.3506,0.000)], + [' HG1', 5, (0.3194, 0.9056, 0.000)], + ], + [ # 17 trp + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3436, -0.5042, 0.9031)], + ['1HB ', 4, ( 0.3639, -0.5323, 0.8790)], + ['2HB ', 4, ( 0.3638, -0.5322, -0.8787)], + [' CB ', 8, (-0.5136,-0.7712,-1.2173)], + [' CG ', 4, (0.5984,1.3741, 0.000)], + [' CD1', 5, (0.8151,1.0921, 0.000)], + [' CD2', 5, (0.8753,-1.1538, 0.000)], + [' CE2', 5, (2.1865,-0.6707, 0.000)], + [' CE3', 5, (0.6541,-2.5366, 0.000)], + [' NE1', 5, (2.1309,0.7003, 0.000)], + [' CH2', 5, (3.0315,-2.8930, 0.000)], + [' CZ2', 5, (3.2813,-1.5205, 0.000)], + [' CZ3', 5, (1.7521,-3.3888, 0.000)], + ['1HD ', 5, (0.4722, 2.1252, 0.000)], + ['1HE ', 5, ( 2.9291, 1.3191, 0.000)], + [' HE3', 5, (-0.3597, -2.9356, 0.000)], + [' HZ2', 5, (4.3053, -1.1462, 0.000)], + [' HZ3', 5, ( 1.5712, -4.4640, 0.000)], + [' HH2', 5, ( 3.8700, -3.5898, 0.000)], + ], + [ # 18 tyr + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' 
O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3305, -0.4992, 0.9112)], + ['1HB ', 4, ( 0.3642, -0.5327, 0.8797)], + ['2HB ', 4, ( 0.3637, -0.5321, -0.8785)], + [' CB ', 8, (-0.5305,-0.7799,-1.2051)], + [' CG ', 4, (0.6104,1.3840, 0.000)], + [' CD1', 5, (0.6936,1.2013, 0.000)], + [' CD2', 5, (0.6934,-1.2011, 0.000)], + [' CE1', 5, (2.0751,1.2013, 0.000)], + [' CE2', 5, (2.0748,-1.2011, 0.000)], + [' OH ', 5, (4.1408, 0.000, 0.000)], + [' CZ ', 5, (2.7648, 0.000, 0.000)], + ['1HD ', 5, (0.1485, 2.1455, 0.000)], + ['2HD ', 5, (0.1484, -2.1451, 0.000)], + ['1HE ', 5, (2.6200, 2.1450, 0.000)], + ['2HE ', 5, (2.6199, -2.1453, 0.000)], + [' HH ', 6, (0.3190, 0.9057, 0.000)], + ], + [ # 19 val + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3497, -0.5068, 0.9002)], + [' CB ', 8, (-0.5105,-0.7712,-1.2317)], + [' CG1', 4, (0.5326,1.4252, 0.000)], + [' CG2', 4, (0.5177,-0.7693,1.2057)], + [' HB ', 4, (0.3541, -0.4754, -0.9148)], + ['1HG1', 4, (1.6228, 1.4063, 0.000)], + ['2HG1', 4, (0.1790, 1.9457, -0.8898)], + ['3HG1', 4, (0.1798, 1.9453, 0.8903)], + ['1HG2', 4, (1.6073, -0.7659, 1.1989)], + ['2HG2', 4, (0.1586, -0.2971, 2.1203)], + ['3HG2', 4, (0.1582, -1.7976, 1.1631)], + ], + [ # 20 unk + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + [' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3341, -0.4928, 0.9132)], + [' CB ', 8, (-0.5289,-0.7734,-1.1991)], + ['1HB ', 8, (-0.1265, -1.7863, -1.1851)], + ['2HB ', 8, (-1.6173, -0.8147, -1.1541)], + ['3HB ', 8, (-0.2229, -0.2744, -2.1172)], + ], + [ # 21 mask + [' N ', 0, (-0.5272, 1.3593, 0.000)], + [' CA ', 0, (0.000, 0.000, 0.000)], + [' C ', 0, (1.5233, 0.000, 0.000)], + [' O ', 3, (0.6303, 1.0574, 0.000)], + 
[' H ', 2, (0.4920,-0.8821, 0.0000)], + [' HA ', 0, (-0.3341, -0.4928, 0.9132)], + [' CB ', 8, (-0.5289,-0.7734,-1.1991)], + ['1HB ', 8, (-0.1265, -1.7863, -1.1851)], + ['2HB ', 8, (-1.6173, -0.8147, -1.1541)], + ['3HB ', 8, (-0.2229, -0.2744, -2.1172)], + ], + [ # 22 DA + [" OP1", 0, (-0.7319, 1.2920, 0.000)], + [" P ", 0, (0.000, 0.000, 0.000)], + [" OP2", 0, (1.4855, 0.000, 0.000)], + [" O5'", 0, (-0.4948, -0.8559, 1.2489)], + [" C5'", 9, (0.7411, 1.2354, 0.000)], + [" C4'",10, (0.5207, 1.4178, 0.000)], + [" H5'",10, (0.3748, -0.5360, -0.8839)], + ["H5''",10, (0.3744, -0.5368, 0.8842)], + [" C3'",11, ( 0.6388, 1.3889, 0.000)], + [" H4'",11, ( 0.2823, -0.5105, 0.9326)], + [" O4'",11, (0.4804, -0.6610, -1.1947)], + [" C1'",13, (0.4913, 1.3316, 0.0000)], + [" H1'",14, (0.4561, -0.4898, 0.8726)], + [" N9 ",14, (0.4467, -0.7474, -1.1746)], + [" C2'",14, (0.4167, 1.4603, 0.0000)], + [" H2'",15, (0.4107, -0.5097, -0.8844)], + ["H2''",15, (0.4106, -0.5096, 0.8840)], + [" O3'",12, ( 0.4966, 1.3432, 0.000)], + [" H3'",12, (0.4359, -0.4915, -0.8827)], + [" C4 ",16, (0.8119, 1.1084, 0.0000)], + [" N3 ",16, (0.4328, 2.3976, 0.0000)], + [" C2 ",16, (1.4957, 3.1983, 0.0000)], + [" N1 ",16, (2.7960, 2.8816, 0.0000)], + [" C6 ",16, (3.1433, 1.5760, 0.0000)], + [" C5 ",16, (2.1084, 0.6255, 0.0000)], + [" N7 ",16, (2.1145, -0.7627, 0.0000)], + [" C8 ",16, (0.8438, -1.0825, 0.0000)], + [" N6 ",16, (4.4402, 1.2598, 0.0000)], + [" H2 ",16, (1.2740, 4.2755, 0.0000)], + [" H8 ",16, (0.4867, -2.1227, 0.0000)], + [" H61",16, (5.1313, 1.9828, 0.0000)], + [" H62",16, (4.7211, 0.3001, 0.0000)], + ], + [ # 23 DC + [" OP1", 0, (-0.7319, 1.2920, 0.000)], + [" P ", 0, (0.000, 0.000, 0.000)], + [" OP2", 0, (1.4855, 0.000, 0.000)], + [" O5'", 0, (-0.4948, -0.8559, 1.2489)], + [" C5'", 9, (0.7411, 1.2354, 0.000)], + [" C4'",10, (0.5207, 1.4178, 0.000)], + [" H5'",10, (0.3748, -0.5360, -0.8839)], + ["H5''",10, (0.3744, -0.5368, 0.8842)], + [" C3'",11, ( 0.6388, 1.3889, 0.000)], + [" H4'",11, ( 
0.2823, -0.5105, 0.9326)], + [" O4'",11, (0.4804, -0.6610, -1.1947)], + [" C1'",13, (0.4913, 1.3316, 0.0000)], + [" H1'",14, (0.4561, -0.4898, 0.8726)], + [" N1 ",14, (0.4467, -0.7474, -1.1746)], + [" C2'",14, (0.4167, 1.4603, 0.0000)], + [" H2'",15, (0.4107, -0.5097, -0.8844)], + ["H2''",15, (0.4106, -0.5096, 0.8840)], + [" O3'",12, ( 0.4966, 1.3432, 0.000)], + [" H3'",12, (0.4359, -0.4915, -0.8827)], + [" C2 ",16, (0.6758, 1.2249, 0.0000)], + [" O2 ",16, (0.0158, 2.2756, 0.0000)], + [" N3 ",16, (2.0283, 1.2334, 0.0000)], + [" C4 ",16, (2.7022, 0.0815, 0.0000)], + [" N4 ",16, (4.0356, 0.1372, 0.0000)], + [" C5 ",16, (2.0394, -1.1794, 0.0000)], + [" C6 ",16, (0.7007, -1.1745, 0.0000)], + [" H42",16, (4.5715, -0.7074, 0.0000)], + [" H41",16, (4.4992, 1.0229, 0.0000)], + [" H5 ",16, (2.6061, -2.1225, 0.0000)], + [" H6 ",16, (0.1563, -2.1302, 0.0000)], + ], + [ # 24 DG + [" OP1", 0, (-0.7319, 1.2920, 0.000)], + [" P ", 0, (0.000, 0.000, 0.000)], + [" OP2", 0, (1.4855, 0.000, 0.000)], + [" O5'", 0, (-0.4948, -0.8559, 1.2489)], + [" C5'", 9, (0.7411, 1.2354, 0.000)], + [" C4'",10, (0.5207, 1.4178, 0.000)], + [" H5'",10, (0.3748, -0.5360, -0.8839)], + ["H5''",10, (0.3744, -0.5368, 0.8842)], + [" C3'",11, ( 0.6388, 1.3889, 0.000)], + [" H4'",11, ( 0.2823, -0.5105, 0.9326)], + [" O4'",11, (0.4804, -0.6610, -1.1947)], + [" C1'",13, (0.4913, 1.3316, 0.0000)], + [" H1'",14, (0.4561, -0.4898, 0.8726)], + [" N9 ",14, (0.4467, -0.7474, -1.1746)], + [" C2'",14, (0.4167, 1.4603, 0.0000)], + [" H2'",15, (0.4107, -0.5097, -0.8844)], + ["H2''",15, (0.4106, -0.5096, 0.8840)], + [" O3'",12, ( 0.4966, 1.3432, 0.000)], + [" H3'",12, (0.4359, -0.4915, -0.8827)], + [" C4 ",16, (0.8171, 1.1043, 0.0000)], + [" N3 ",16, (0.4110, 2.3918, 0.0000)], + [" C2 ",16, (1.4330, 3.2319, 0.0000)], + [" N1 ",16, (2.7493, 2.8397, 0.0000)], + [" C6 ",16, (3.1894, 1.5195, 0.0000)], + [" C5 ",16, (2.1029, 0.6070, 0.0000)], + [" N7 ",16, (2.0942, -0.7800, 0.0000)], + [" C8 ",16, (0.8285, -1.0956, 0.0000)], + 
[" N2 ",16, (1.2085, 4.5537, 0.0000)], + [" O6 ",16, (4.4017, 1.2743, 0.0000)], + [" H1 ",16, (3.4453, 3.5579, 0.0000)], + [" H8 ",16, (0.4623, -2.1330, 0.0000)], + [" H22",16, (0.2708, 4.9015, 0.0000)], + [" H21",16, (1.9785, 5.1920, 0.0000)], + ], + [ # 25 DT + [" OP1", 0, (-0.7319, 1.2920, 0.000)], + [" P ", 0, (0.000, 0.000, 0.000)], + [" OP2", 0, (1.4855, 0.000, 0.000)], + [" O5'", 0, (-0.4948, -0.8559, 1.2489)], + [" C5'", 9, (0.7411, 1.2354, 0.000)], + [" C4'",10, (0.5207, 1.4178, 0.000)], + [" H5'",10, (0.3748, -0.5360, -0.8839)], + ["H5''",10, (0.3744, -0.5368, 0.8842)], + [" C3'",11, ( 0.6388, 1.3889, 0.000)], + [" H4'",11, ( 0.2823, -0.5105, 0.9326)], + [" O4'",11, (0.4804, -0.6610, -1.1947)], + [" C1'",13, (0.4913, 1.3316, 0.0000)], + [" H1'",14, (0.4561, -0.4898, 0.8726)], + [" N1 ",14, (0.4467, -0.7474, -1.1746)], + [" C2'",14, (0.4167, 1.4603, 0.0000)], + [" H2'",15, (0.4107, -0.5097, -0.8844)], + ["H2''",15, (0.4106, -0.5096, 0.8840)], + [" O3'",12, ( 0.4966, 1.3432, 0.000)], + [" H3'",12, (0.4359, -0.4915, -0.8827)], + [" C2 ",16, (0.6495, 1.2140, 0.0000)], + [" O2 ",16, (0.0636, 2.2854, 0.0000)], + [" N3 ",16, (2.0191, 1.1297, 0.0000)], + [" C4 ",16, (2.7859, -0.0198, 0.0000)], + [" O4 ",16, (4.0113, 0.0622, 0.0000)], + [" C5 ",16, (2.0397, -1.2580, 0.0000)], + [" C7 ",16, (2.7845, -2.5550, 0.0000)], + [" C6 ",16, (0.7021, -1.1863, 0.0000)], + [" H3 ",16, (2.5175, 1.9968, 0.0000)], + [" H71",16, (2.0680, -3.3898, 0.0000)], + [" H72",16, (3.4147, -2.6153, -0.9071)], + [" H73",16, (3.4193, -2.6153, 0.8885)], + [" H6 ",16, (0.1317, -2.1273, 0.0000)], + ], + [ # 26 DX + [" OP1", 0, (-0.7319, 1.2920, 0.000)], + [" P ", 0, (0.000, 0.000, 0.000)], + [" OP2", 0, (1.4855, 0.000, 0.000)], + [" O5'", 0, (-0.4948, -0.8559, 1.2489)], + [" C5'", 9, (0.7411, 1.2354, 0.000)], + [" C4'",10, (0.5207, 1.4178, 0.000)], + [" H5'",10, (0.3748, -0.5360, -0.8839)], + ["H5''",10, (0.3744, -0.5368, 0.8842)], + [" C3'",11, ( 0.6388, 1.3889, 0.000)], + [" H4'",11, ( 0.2823, 
-0.5105, 0.9326)], + [" O4'",11, (0.4804, -0.6610, -1.1947)], + [" C1'",13, (0.4913, 1.3316, 0.0000)], + [" H1'",14, (0.4561, -0.4898, 0.8726)], + [" C2'",14, (0.4167, 1.4603, 0.0000)], + [" H2'",15, (0.4107, -0.5097, -0.8844)], + ["H2''",15, (0.4106, -0.5096, 0.8840)], + [" O3'",12, ( 0.4966, 1.3432, 0.000)], + [" H3'",12, (0.4359, -0.4915, -0.8827)], + ], + [ # 27 A + [" OP1", 0, (-0.7319, 1.2920, 0.000)], + [" P ", 0, (0.000, 0.000, 0.000)], + [" OP2", 0, (1.4855, 0.000, 0.000)], + [" O5'", 0, (-0.4948, -0.8559, 1.2489)], + [" C5'", 9, (0.7289, 1.2185, 0.000)], + [" C4'",10, (0.5541, 1.4027, 0.000)], + [" H5'",10, (0.3201, -0.4698, -0.7986)], + ["H5''",10, (0.3206, -0.4706, 0.7970)], + [" C3'",11, ( 0.6673, 1.3669, 0.000)], + [" H4'",11, ( 0.3173, -0.5074, 0.7763)], + [" O4'",11, ( 0.4914, -0.6338, -1.2098)], + [" C1'",13, (0.4828, 1.3277, -0.0000)], + [" H1'",14, (0.3265, -0.4460, 0.8101)], + [" N9 ",14, (0.4722, -0.7339, -1.1894)], + [" C2'",14, (0.4641, 1.4573, 0.0000)], + [" H2'",15, (0.3582, -0.4393, -0.7998)], + [" O2'",15, (0.4613, -0.6189, 1.1921)], + ["HO2'",15, (0.2499, -1.5749, 1.1568)], + [" O3'",12, ( 0.5548, 1.3039, 0.000)], + [" H3'",12, ( 0.3215, -0.4857, -0.7888)], + [" N1 ",16, (2.7963, 2.8824, 0.0000)], + [" C2 ",16, (1.4955, 3.2007, 0.0000)], + [" N3 ",16, (0.4333, 2.3980, 0.0000)], + [" C4 ",16, (0.8127, 1.1078, 0.0000)], + [" C5 ",16, (2.1082, 0.6254, 0.0000)], + [" C6 ",16, (3.1432, 1.5774, 0.0000)], + [" N6 ",16, (4.4400, 1.2609, 0.0000)], + [" N7 ",16, (2.1146, -0.7630, 0.0000)], + [" C8 ",16, (0.8442, -1.0830, 0.0000)], + [" H2 ",16, (1.2972, 4.1608, 0.0000)], + [" H61",16, (5.1172, 1.9697, 0.0000)], + [" H62",16, (4.7154, 0.3206, 0.0000)], + [" H8 ",16, (0.5258, -2.0104, 0.0000)], + ], + [ # 28 C + [" OP1", 0, (-0.7319, 1.2920, 0.000)], + [" P ", 0, (0.000, 0.000, 0.000)], + [" OP2", 0, (1.4855, 0.000, 0.000)], + [" O5'", 0, (-0.4948, -0.8559, 1.2489)], + [" C5'", 9, (0.7289, 1.2185, 0.000)], + [" C4'",10, (0.5541, 1.4027, 0.000)], + 
[" H5'",10, (0.3201, -0.4698, -0.7986)], + ["H5''",10, (0.3206, -0.4706, 0.7970)], + [" C3'",11, ( 0.6673, 1.3669, 0.000)], + [" H4'",11, ( 0.3173, -0.5074, 0.7763)], + [" O4'",11, ( 0.4914, -0.6338, -1.2098)], + [" C1'",13, (0.4828, 1.3277, -0.0000)], + [" H1'",14, (0.3265, -0.4460, 0.8101)], + [" N1 ",14, (0.4722, -0.7339, -1.1894)], + [" C2'",14, (0.4641, 1.4573, 0.0000)], + [" H2'",15, (0.3582, -0.4393, -0.7998)], + [" O2'",15, (0.4613, -0.6189, 1.1921)], + ["HO2'",15, (0.2499, -1.5749, 1.1568)], + [" O3'",12, ( 0.5548, 1.3039, 0.000)], + [" H3'",12, ( 0.3215, -0.4857, -0.7888)], + [" C2 ",16, (0.6650, 1.2325, 0.0000)], + [" O2 ",16, (-0.0001, 2.2799, 0.0000)], + [" N3 ",16, (2.0175, 1.2603, 0.0000)], + [" C4 ",16, (2.7090, 0.1210, 0.0000)], + [" N4 ",16, (4.0423, 0.1969, 0.0000)], + [" C5 ",16, (2.0635, -1.1476, 0.0000)], + [" C6 ",16, (0.7250, -1.1627, 0.0000)], + [" H42",16, (4.5791, -0.6226, 0.0000)], + [" H41",16, (4.4833, 1.0723, 0.0000)], + [" H5 ",16, (2.5806, -1.9803, 0.0000)], + [" H6 ",16, (0.2622, -2.0258, 0.0000)], + ], + [ # 29 G + [" OP1", 0, (-0.7319, 1.2920, 0.000)], + [" P ", 0, (0.000, 0.000, 0.000)], + [" OP2", 0, (1.4855, 0.000, 0.000)], + [" O5'", 0, (-0.4948, -0.8559, 1.2489)], + [" C5'", 9, (0.7289, 1.2185, 0.000)], + [" C4'",10, (0.5541, 1.4027, 0.000)], + [" H5'",10, (0.3201, -0.4698, -0.7986)], + ["H5''",10, (0.3206, -0.4706, 0.7970)], + [" C3'",11, ( 0.6673, 1.3669, 0.000)], + [" H4'",11, ( 0.3173, -0.5074, 0.7763)], + [" O4'",11, ( 0.4914, -0.6338, -1.2098)], + [" C1'",13, (0.4828, 1.3277, -0.0000)], + [" H1'",14, (0.3265, -0.4460, 0.8101)], + [" N9 ",14, (0.4722, -0.7339, -1.1894)], + [" C2'",14, (0.4641, 1.4573, 0.0000)], + [" H2'",15, (0.3582, -0.4393, -0.7998)], + [" O2'",15, (0.4613, -0.6189, 1.1921)], + ["HO2'",15, (0.2499, -1.5749, 1.1568)], + [" O3'",12, ( 0.5548, 1.3039, 0.000)], + [" H3'",12, ( 0.3215, -0.4857, -0.7888)], + [" N1 ",16, (2.7458, 2.8461, 0.0000)], + [" C2 ",16, (1.4286, 3.2360, 0.0000)], + [" N2 ",16, 
(1.1989, 4.5575, 0.0000)], + [" N3 ",16, (0.4087, 2.3932, 0.0000)], + [" C4 ",16, (0.8167, 1.1068, 0.0000)], + [" C5 ",16, (2.1036, 0.6115, 0.0000)], + [" C6 ",16, (3.1883, 1.5266, 0.0000)], + [" O6 ",16, (4.4006, 1.2842, 0.0000)], + [" N7 ",16, (2.0980, -0.7759, 0.0000)], + [" C8 ",16, (0.8317, -1.0936, 0.0000)], + [" H1 ",16, (3.4279, 3.5496, 0.0000)], + [" H22",16, (0.2781, 4.8947, 0.0000)], + [" H21",16, (1.9487, 5.1879, 0.0000)], + [" H8 ",16, (0.5085, -2.0185, 0.0000)], + ], + [ # 30 U + [" OP1", 0, (-0.7319, 1.2920, 0.000)], + [" P ", 0, (0.000, 0.000, 0.000)], + [" OP2", 0, (1.4855, 0.000, 0.000)], + [" O5'", 0, (-0.4948, -0.8559, 1.2489)], + [" C5'", 9, (0.7289, 1.2185, 0.000)], + [" C4'",10, (0.5541, 1.4027, 0.000)], + [" H5'",10, (0.3201, -0.4698, -0.7986)], + ["H5''",10, (0.3206, -0.4706, 0.7970)], + [" C3'",11, ( 0.6673, 1.3669, 0.000)], + [" H4'",11, ( 0.3173, -0.5074, 0.7763)], + [" O4'",11, ( 0.4914, -0.6338, -1.2098)], + [" C1'",13, (0.4828, 1.3277, -0.0000)], + [" H1'",14, (0.3265, -0.4460, 0.8101)], + [" N1 ",14, (0.4722, -0.7339, -1.1894)], + [" C2'",14, (0.4641, 1.4573, 0.0000)], + [" H2'",15, (0.3582, -0.4393, -0.7998)], + [" O2'",15, (0.4613, -0.6189, 1.1921)], + ["HO2'",15, (0.2499, -1.5749, 1.1568)], + [" O3'",12, ( 0.5548, 1.3039, 0.000)], + [" H3'",12, ( 0.3215, -0.4857, -0.7888)], + [" C2 ",16, (0.6307, 1.2305, 0.0000)], + [" O2 ",16, (0.0260, 2.2886, 0.0000)], + [" N3 ",16, (2.0031, 1.1816, 0.0000)], + [" C4 ",16, (2.7953, 0.0532, 0.0000)], + [" O4 ",16, (4.0212, 0.1751, 0.0000)], + [" C5 ",16, (2.0746, -1.1833, 0.0000)], + [" C6 ",16, (0.7378, -1.1648, 0.0000)], + [" H3 ",16, (2.4701, 2.0428, 0.0000)], + [" H5 ",16, (2.5579, -2.0361, 0.0000)], + [" H6 ",16, (0.2681, -2.0239, 0.0000)], + ], + [ # 31 RX + [" OP1", 0, (-0.7319, 1.2920, 0.000)], + [" P ", 0, (0.000, 0.000, 0.000)], + [" OP2", 0, (1.4855, 0.000, 0.000)], + [" O5'", 0, (-0.4948, -0.8559, 1.2489)], + [" C5'", 9, (0.7289, 1.2185, 0.000)], + [" C4'",10, (0.5541, 1.4027, 
def get_dihedrals(a, b, c, d):
    """Return the signed dihedral angle, in radians, of each a-b-c-d quadruple.

    All four inputs are arrays of points whose last axis holds the
    coordinates; any number of leading axes is accepted and the result has
    the shape of the inputs without the last axis. Values come from
    ``np.arctan2`` and therefore lie in [-pi, pi].
    """
    b0 = -1.0 * (b - a)
    b1 = c - b
    b2 = d - c

    # Normalise the central bond. Out-of-place division keeps integer inputs
    # working (an in-place ``/=`` cannot store floats into an int array).
    b1 = b1 / np.linalg.norm(b1, axis=-1, keepdims=True)

    # Components of the outer bonds perpendicular to the central bond.
    v = b0 - np.sum(b0 * b1, axis=-1, keepdims=True) * b1
    w = b2 - np.sum(b2 * b1, axis=-1, keepdims=True) * b1

    x = np.sum(v * w, axis=-1)
    y = np.sum(np.cross(b1, v) * w, axis=-1)

    return np.arctan2(y, x)


def get_angles(a, b, c):
    """Return the planar angle a-b-c, in radians, for each point triple.

    Inputs are arrays of points whose last axis holds the coordinates. The
    cosine is clipped to [-1, 1] before ``arccos`` so rounding noise on
    (anti)parallel vectors cannot produce NaN.
    """
    v = a - b
    v = v / np.linalg.norm(v, axis=-1, keepdims=True)

    w = c - b
    w = w / np.linalg.norm(w, axis=-1, keepdims=True)

    cosine = np.sum(v * w, axis=-1)
    return np.arccos(np.clip(cosine, -1.0, 1.0))


def get_coords6d(xyz, dmax):
    """Compute pairwise 6D coordinates from N, Ca, C backbone atoms.

    Args:
        xyz: array indexed as ``xyz[atom, residue]`` where atoms 0, 1, 2 are
            read as N, Ca and C (presumably shape (3, nres, 3) - confirm
            against the caller).
        dmax: Cb-Cb distance cutoff; only residue pairs whose rebuilt Cb
            atoms are within ``dmax`` are filled in.

    Returns:
        ``(dist6d, omega6d, theta6d, phi6d, mask)``, each a float32
        ``(nres, nres)`` matrix. ``dist6d`` is 999.9 and the angle matrices
        and ``mask`` are 0 for pairs outside the cutoff (and on the
        diagonal); ``mask`` is 1.0 for contacting pairs.
    """
    nres = xyz.shape[1]

    # three anchor atoms
    N = xyz[0]
    Ca = xyz[1]
    C = xyz[2]

    # recreate Cb given N,Ca,C
    Cb = generate_Cbeta(N, Ca, C)

    # fast neighbors search to collect all Cb-Cb pairs within dmax
    kdCb = scipy.spatial.cKDTree(Cb)
    indices = kdCb.query_ball_tree(kdCb, dmax)

    # Indices of contacting residues. The reshape keeps the array (2, 0)
    # when nothing is in contact, so the unpacking below cannot fail.
    pairs = [(i, j) for i, nbrs in enumerate(indices) for j in nbrs if i != j]
    idx0, idx1 = np.array(pairs, dtype=int).reshape(-1, 2).T

    # Cb-Cb distance matrix
    dist6d = np.full((nres, nres), 999.9, dtype=np.float32)
    dist6d[idx0, idx1] = np.linalg.norm(Cb[idx1] - Cb[idx0], axis=-1)

    # matrix of Ca-Cb-Cb-Ca dihedrals
    omega6d = np.zeros((nres, nres), dtype=np.float32)
    omega6d[idx0, idx1] = get_dihedrals(Ca[idx0], Cb[idx0], Cb[idx1], Ca[idx1])

    # matrix of polar coord theta
    theta6d = np.zeros((nres, nres), dtype=np.float32)
    theta6d[idx0, idx1] = get_dihedrals(N[idx0], Ca[idx0], Cb[idx0], Cb[idx1])

    # matrix of polar coord phi
    phi6d = np.zeros((nres, nres), dtype=np.float32)
    phi6d[idx0, idx1] = get_angles(Ca[idx0], Cb[idx0], Cb[idx1])

    mask = np.zeros((nres, nres), dtype=np.float32)
    mask[idx0, idx1] = 1.0
    return dist6d, omega6d, theta6d, phi6d, mask
def MSABlockDeletion(msa, ins, nb=5):
    '''
    Delete random contiguous blocks of sequences from an MSA.

    Input: MSA having shape (N, L) and insertion counts indexed the same way
    output: new MSA (and matching insertions) with block deletion

    ``nb`` block starts are drawn from rows 1..N-1 and each block covers
    ``max(int(0.3*N), 1)`` consecutive rows, clipped to the MSA. Row 0 (the
    query) is never deleted. An MSA with fewer than two rows has nothing
    that may be deleted and is returned without any row removed.
    '''
    N, L = msa.shape

    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the
    # same dtype and works on every NumPy version.
    mask = np.ones(N, dtype=bool)

    if N > 1:  # randint(low=1, high=N) is invalid when N == 1
        block_size = max(int(N*0.3), 1)
        block_start = np.random.randint(low=1, high=N, size=nb) # (nb)
        to_delete = block_start[:,None] + np.arange(block_size)[None,:]
        to_delete = np.unique(np.clip(to_delete, 1, N-1))
        mask[to_delete] = False

    return msa[mask], ins[mask]
msa[1:,:][sample[:Nclust-1]]), dim=0) + ins_clust = torch.cat((ins[:1,:], ins[1:,:][sample[:Nclust-1]]), dim=0) + + # 15% random masking + # - 10%: aa replaced with a uniformly sampled random amino acid + # - 10%: aa replaced with an amino acid sampled from the MSA profile + # - 10%: not replaced + # - 70%: replaced with a special token ("mask") + random_aa = torch.tensor([[0.05]*20 + [0.0]*(NAATOKENS-20)], device=msa.device) + same_aa = torch.nn.functional.one_hot(msa_clust, num_classes=NAATOKENS) + probs = 0.1*random_aa + 0.1*raw_profile + 0.1*same_aa + #probs = torch.nn.functional.pad(probs, (0, 1), "constant", 0.7) + probs[...,MASKINDEX]=0.7 + + sampler = torch.distributions.categorical.Categorical(probs=probs) + mask_sample = sampler.sample() + + mask_pos = torch.rand(msa_clust.shape, device=msa_clust.device) < p_mask + mask_pos[msa_clust>MASKINDEX]=False # no masking on NAs + + msa_masked = torch.where(mask_pos, mask_sample, msa_clust) + b_seq.append(msa_masked[0].clone()) + + ## get extra sequenes + if N > Nclust*2: # there are enough extra sequences + msa_extra = msa[1:,:][sample[Nclust-1:]] + ins_extra = ins[1:,:][sample[Nclust-1:]] + extra_mask = torch.full(msa_extra.shape, False, device=msa_extra.device) + elif N - Nclust < 1: + msa_extra = msa_masked.clone() + ins_extra = ins_clust.clone() + extra_mask = mask_pos.clone() + else: + msa_add = msa[1:,:][sample[Nclust-1:]] + ins_add = ins[1:,:][sample[Nclust-1:]] + mask_add = torch.full(msa_add.shape, False, device=msa_add.device) + msa_extra = torch.cat((msa_masked, msa_add), dim=0) + ins_extra = torch.cat((ins_clust, ins_add), dim=0) + extra_mask = torch.cat((mask_pos, mask_add), dim=0) + N_extra = msa_extra.shape[0] + + # clustering (assign remaining sequences to their closest cluster by Hamming distance + msa_clust_onehot = torch.nn.functional.one_hot(msa_masked, num_classes=NAATOKENS) + msa_extra_onehot = torch.nn.functional.one_hot(msa_extra, num_classes=NAATOKENS) + count_clust = 
torch.logical_and(~mask_pos, msa_clust != 20).float() # 20: index for gap, ignore both masked & gaps + count_extra = torch.logical_and(~extra_mask, msa_extra != 20).float() + agreement = torch.matmul((count_extra[:,:,None]*msa_extra_onehot).view(N_extra, -1), (count_clust[:,:,None]*msa_clust_onehot).view(Nclust, -1).T) + assignment = torch.argmax(agreement, dim=-1) + + # seed MSA features + # 1. one_hot encoded aatype: msa_clust_onehot + # 2. cluster profile + count_extra = ~extra_mask + count_clust = ~mask_pos + msa_clust_profile = cluster_sum(count_extra[:,:,None]*msa_extra_onehot, assignment, Nclust, L) + msa_clust_profile += count_clust[:,:,None]*msa_clust_profile + count_profile = cluster_sum(count_extra[:,:,None], assignment, Nclust, L).view(Nclust, L) + count_profile += count_clust + count_profile += eps + msa_clust_profile /= count_profile[:,:,None] + # 3. insertion statistics + msa_clust_del = cluster_sum((count_extra*ins_extra)[:,:,None], assignment, Nclust, L).view(Nclust, L) + msa_clust_del += count_clust*ins_clust + msa_clust_del /= count_profile + ins_clust = (2.0/np.pi)*torch.arctan(ins_clust.float()/3.0) # (from 0 to 1) + msa_clust_del = (2.0/np.pi)*torch.arctan(msa_clust_del.float()/3.0) # (from 0 to 1) + ins_clust = torch.stack((ins_clust, msa_clust_del), dim=-1) + # + msa_seed = torch.cat((msa_clust_onehot, msa_clust_profile, ins_clust, term_info[None].expand(Nclust,-1,-1)), dim=-1) + + # extra MSA features + ins_extra = (2.0/np.pi)*torch.arctan(ins_extra[:Nextra].float()/3.0) # (from 0 to 1) + msa_extra = torch.cat((msa_extra_onehot[:Nextra], ins_extra[:,:,None], term_info[None].expand(Nextra,-1,-1)), dim=-1) + + if (tocpu): + b_msa_clust.append(msa_clust.cpu()) + b_msa_seed.append(msa_seed.cpu()) + b_msa_extra.append(msa_extra.cpu()) + b_mask_pos.append(mask_pos.cpu()) + else: + b_msa_clust.append(msa_clust) + b_msa_seed.append(msa_seed) + b_msa_extra.append(msa_extra) + b_mask_pos.append(mask_pos) + + b_seq = torch.stack(b_seq) + b_msa_clust = 
def TemplFeaturize(tplt, qlen, params, offset=0, npick=1, pick_top=True):
    """Build template coordinate and 1-D features for a query of length ``qlen``.

    Args:
        tplt: template dictionary; the keys read here are ``ids``, ``f0d``,
            ``qmap``, ``xyz``, ``mask``, ``seq`` and ``f1d``. When the
            sequence-identity filter runs, ``ids``, ``qmap``, ``xyz``,
            ``seq`` and ``f1d`` are overwritten in place with the filtered
            versions.
        qlen: number of query positions in the output.
        params: loader parameters; only ``params['SEQID']`` is read.
        offset: value added to every query position taken from ``qmap``.
        npick: maximum number of templates to return.
        pick_top: if true, sample among the first ``min(50, ntplt)``
            templates; otherwise sample among all of them.

    Returns:
        ``(xyz, t1d)``. ``xyz`` is ``(T, qlen, NTOTAL, 3)`` with NaN wherever
        no template atom is present; ``t1d`` is ``(T, qlen, NAATOKENS)``: a
        one-hot over ``NAATOKENS-1`` classes (no mask token) followed by one
        alignment-confidence channel. ``T`` is the number of templates
        picked, or 1 (a single all-gap template) when none is usable.
    """
    n_classes = NAATOKENS - 1  # template one-hot carries no mask token

    def _blank_template():
        # One placeholder template: NaN coordinates, gap token (20)
        # everywhere, zero confidence.
        coords = torch.full((1, qlen, NTOTAL, 3), np.nan).float()
        gaps = torch.full((1, qlen), 20).long()
        onehot = torch.nn.functional.one_hot(gaps, num_classes=n_classes).float()
        zero_conf = torch.zeros((1, qlen, 1)).float()
        return coords, torch.cat((onehot, zero_conf), -1)

    max_seqid = params['SEQID']

    # no templates in hhsearch file, or caller does not want templates
    if len(tplt['ids']) < 1 or npick < 1:
        return _blank_template()

    # Ignore templates whose f0d[..., 4] value (presumably percent sequence
    # identity - confirm) reaches the cutoff.
    # NOTE(review): tplt['mask'] is not filtered here although it is indexed
    # alongside tplt['xyz'] further down, and ``keep`` (derived from f0d) is
    # applied to the same axis that qmap-derived indices use later - verify
    # against the template file schema.
    if max_seqid <= 100.0:
        keep = torch.where(tplt['f0d'][0, :, 4] < max_seqid)[0]
        tplt['ids'] = np.array(tplt['ids'])[keep]
        for key in ('qmap', 'xyz', 'seq', 'f1d'):
            tplt[key] = tplt[key][:, keep]

    # check again whether any template survived the filter
    ntplt = len(tplt['ids'])
    npick = min(npick, ntplt)
    if npick < 1:
        return _blank_template()

    # sample among the top 50 templates, or among all of them
    pool = min(50, ntplt) if pick_top else ntplt
    sample = torch.randperm(pool)[:npick]

    xyz = torch.full((npick, qlen, NTOTAL, 3), np.nan).float()
    mask = torch.full((npick, qlen, NTOTAL), False)
    t1d = torch.full((npick, qlen), 20).long()  # all gaps
    t1d_val = torch.zeros((npick, qlen)).float()

    natoms = tplt['xyz'].shape[2]  # will be bigger for NA templates
    for row, tid in enumerate(sample):
        hits = torch.where(tplt['qmap'][0, :, 1] == tid)[0]
        qpos = tplt['qmap'][0, hits, 0] + offset
        xyz[row, qpos, :natoms] = tplt['xyz'][0, hits]
        mask[row, qpos, :natoms] = tplt['mask'][0, hits]
        t1d[row, qpos] = tplt['seq'][0, hits]
        t1d_val[row, qpos] = tplt['f1d'][0, hits, 2]  # alignment confidence

    t1d = torch.nn.functional.one_hot(t1d, num_classes=n_classes).float()
    t1d = torch.cat((t1d, t1d_val[..., None]), dim=-1)

    # blank out coordinates of atoms the template mask marks as absent
    nan_fill = torch.full((npick, qlen, NTOTAL, 3), np.nan).float()
    xyz = torch.where(mask[..., None], xyz.float(), nan_fill)

    return xyz, t1d
+ if r[2] in val_pdb_ids: + if r[2] in valid_sm_compl.keys(): + valid_sm_compl[r[2]].append((r[:2], r[3], r[-1])) + else: + valid_sm_compl[r[2]] = [(r[:2], r[3], r[-1])] + else: + if r[2] in train_sm_compl.keys(): + train_sm_compl[r[2]].append((r[:2], r[3], r[-1])) + else: + train_sm_compl[r[2]] = [(r[:2], r[3], r[-1])] + + # read homo-oligomer list + homo = {} + with open(params['HOMO_LIST'], 'r') as f: + reader = csv.reader(f) + next(reader) + # read pdbA, pdbB, bioA, opA, bioB, opB + rows = [[r[0], r[1], int(r[2]), int(r[3]), int(r[4]), int(r[5])] for r in reader] + for r in rows: + if r[0] in homo.keys(): + homo[r[0]].append(r[1:]) + else: + homo[r[0]] = [r[1:]] + + # read & clean list.csv + with open(params['PDB_LIST'], 'r') as f: + reader = csv.reader(f) + next(reader) + rows = [[r[0],r[3],int(r[4]), int(r[-1].strip())] for r in reader + if float(r[2])<=params['RESCUT'] and + parser.parse(r[1])<=parser.parse(params['DATCUT'])] + + # compile training and validation sets + val_hash = list() + train_pdb = {} + valid_pdb = {} + valid_homo = {} + for r in rows: + if r[2] in val_pdb_ids or r[2] in test_sm_ids: + val_hash.append(r[1]) + if r[2] in valid_pdb.keys(): + valid_pdb[r[2]].append((r[:2], r[-1])) + else: + valid_pdb[r[2]] = [(r[:2], r[-1])] + # + if r[0] in homo: + if r[2] in valid_homo.keys(): + valid_homo[r[2]].append((r[:2], r[-1])) + else: + valid_homo[r[2]] = [(r[:2], r[-1])] + else: + if r[2] in train_pdb.keys(): + train_pdb[r[2]].append((r[:2], r[-1])) + else: + train_pdb[r[2]] = [(r[:2], r[-1])] + + # compile facebook model sets + with open(params['FB_LIST'], 'r') as f: + reader = csv.reader(f) + next(reader) + rows = [[r[0],r[2],int(r[3]),len(r[-1].strip())] for r in reader + if float(r[1]) > 80.0 and + len(r[-1].strip()) > 200] + fb = {} + for r in rows: + if r[2] in fb.keys(): + fb[r[2]].append((r[:2], r[-1])) + else: + fb[r[2]] = [(r[:2], r[-1])] + + # compile complex sets + with open(params['COMPL_LIST'], 'r') as f: + reader = csv.reader(f) + 
next(reader) + # read complex_pdb, pMSA_hash, complex_cluster, length, taxID, assembly (bioA,opA,bioB,opB) + rows = [[r[0], r[3], int(r[4]), [int(plen) for plen in r[5].split(':')], r[6] , [int(r[7]), int(r[8]), int(r[9]), int(r[10])]] for r in reader + if float(r[2]) <= params['RESCUT'] and + parser.parse(r[1]) <= parser.parse(params['DATCUT'])] + + train_compl = {} + valid_compl = {} + for r in rows: + if r[2] in val_compl_ids: + if r[2] in valid_compl.keys(): + valid_compl[r[2]].append((r[:2], r[-3], r[-2], r[-1])) # ((pdb, hash), length, taxID, assembly, negative?) + else: + valid_compl[r[2]] = [(r[:2], r[-3], r[-2], r[-1])] + else: + # if subunits are included in PDB validation set, exclude them from training + hashA, hashB = r[1].split('_') + if hashA in val_hash: + continue + if hashB in val_hash: + continue + if r[2] in train_compl.keys(): + train_compl[r[2]].append((r[:2], r[-3], r[-2], r[-1])) + else: + train_compl[r[2]] = [(r[:2], r[-3], r[-2], r[-1])] + + # compile negative examples + # remove pairs if any of the subunits are included in validation set + with open(params['NEGATIVE_LIST'], 'r') as f: + reader = csv.reader(f) + next(reader) + # read complex_pdb, pMSA_hash, complex_cluster, length, taxonomy + rows = [[r[0],r[3],OFFSET+int(r[4]),[int(plen) for plen in r[5].split(':')],r[6]] for r in reader + if float(r[2])<=params['RESCUT'] and + parser.parse(r[1])<=parser.parse(params['DATCUT'])] + + train_neg = {} + valid_neg = {} + for r in rows: + if r[2] in val_neg_ids: + if r[2] in valid_neg.keys(): + valid_neg[r[2]].append((r[:2], r[-2], r[-1], [])) + else: + valid_neg[r[2]] = [(r[:2], r[-2], r[-1], [])] + else: + hashA, hashB = r[1].split('_') + if hashA in val_hash: + continue + if hashB in val_hash: + continue + if r[2] in train_neg.keys(): + train_neg[r[2]].append((r[:2], r[-2], r[-1], [])) + else: + train_neg[r[2]] = [(r[:2], r[-2], r[-1], [])] + + # compile NA complex sets + # use PDB validation set as validation set + with 
open(params['NA_COMPL_LIST'], 'r') as f: + reader = csv.reader(f) + next(reader) + # read complex_pdb, pMSA_hash, complex_cluster, length + rows = [[r[0], r[3], int(r[4]), [int(plen) for plen in r[5].split(':')]] for r in reader + if float(r[2]) <= params['RESCUT'] and + parser.parse(r[1]) <= parser.parse(params['DATCUT'])] + + train_na_compl = {} + valid_na_compl = {} + for r in rows: + if r[2] in val_compl_ids: + if r[2] in valid_na_compl.keys(): + valid_na_compl[r[2]].append((r[:2], r[-1])) # ((pdb, hash), length) + else: + valid_na_compl[r[2]] = [(r[:2], r[-1])] + else: + if r[2] in train_na_compl.keys(): + train_na_compl[r[2]].append((r[:2], r[-1])) + else: + train_na_compl[r[2]] = [(r[:2], r[-1])] + + # compile negative examples + # remove pairs if any of the subunits are included in validation set + with open(params['NEG_NA_COMPL_LIST'], 'r') as f: + reader = csv.reader(f) + next(reader) + # read complex_pdb, pMSA_hash, complex_cluster, length, taxonomy + rows = [[r[0],r[3],OFFSET+int(r[4]),[int(plen) for plen in r[5].split(':')]] for r in reader + if float(r[2])<=params['RESCUT'] and + parser.parse(r[1])<=parser.parse(params['DATCUT'])] + + train_na_neg = {} + valid_na_neg = {} + for r in rows: + if r[2] in val_neg_ids: + if r[2] in valid_na_neg.keys(): + valid_na_neg[r[2]].append((r[:2], r[-1])) + else: + valid_na_neg[r[2]] = [(r[:2], r[-1])] + else: + if r[2] in train_na_neg.keys(): + train_na_neg[r[2]].append((r[:2], r[-1])) + else: + train_na_neg[r[2]] = [(r[:2], r[-1])] + + # Get average chain length in each cluster and calculate weights + pdb_IDs = list(train_pdb.keys()) + fb_IDs = list(fb.keys()) + compl_IDs = list(train_compl.keys()) + neg_IDs = list(train_neg.keys()) + na_compl_IDs = list(train_na_compl.keys()) + na_neg_IDs = list(train_na_neg.keys()) + rna_IDs = list(train_rna.keys()) + sm_compl_IDs = list(train_sm_compl.keys()) + + # + pdb_weights = np.array([train_pdb[key][0][1] for key in pdb_IDs]) + pdb_weights = (1/512.)*np.clip(pdb_weights, 
256, 512) + fb_weights = np.array([fb[key][0][1] for key in fb_IDs]) + fb_weights = (1/512.)*np.clip(fb_weights, 256, 512) + compl_weights = np.array([sum(train_compl[key][0][1]) for key in compl_IDs]) + compl_weights = (1/512.)*np.clip(compl_weights, 256, 512) + neg_weights = np.array([sum(train_neg[key][0][1]) for key in neg_IDs]) + neg_weights = (1/512.)*np.clip(neg_weights, 256, 512) + na_compl_weights = np.array([sum(train_na_compl[key][0][1]) for key in na_compl_IDs]) + na_compl_weights = (1/512.)*np.clip(na_compl_weights, 256, 512) + na_neg_weights = np.array([sum(train_na_neg[key][0][1]) for key in na_neg_IDs]) + na_neg_weights = (1/512.)*np.clip(na_neg_weights, 256, 512) + rna_weights = np.ones(len(rna_IDs)) # no weighing + sm_compl_weights = np.array([train_sm_compl[key][0][1] for key in sm_compl_IDs]) + sm_compl_weights = (1/512.)*np.clip(sm_compl_weights, 256, 512) + + # save + obj = ( + pdb_IDs, pdb_weights, train_pdb, + fb_IDs, fb_weights, fb, + compl_IDs, compl_weights, train_compl, + neg_IDs, neg_weights, train_neg, + na_compl_IDs, na_compl_weights, train_na_compl, + na_neg_IDs, na_neg_weights, train_na_neg, + rna_IDs, rna_weights, train_rna, + sm_compl_IDs, sm_compl_weights, train_sm_compl, + valid_pdb, valid_homo, + valid_compl, valid_neg, + valid_na_compl, valid_na_neg, + valid_rna, valid_sm_compl, + homo + ) + with open(params["DATAPKL"], "wb") as f: + print ('Writing',params["DATAPKL"],'...') + pickle.dump(obj, f) + print ('...done') + else: + with open(params["DATAPKL"], "rb") as f: + print ('Loading',params["DATAPKL"],'...') + ( + pdb_IDs, pdb_weights, train_pdb, + fb_IDs, fb_weights, fb, + compl_IDs, compl_weights, train_compl, + neg_IDs, neg_weights, train_neg, + na_compl_IDs, na_compl_weights, train_na_compl, + na_neg_IDs, na_neg_weights, train_na_neg, + rna_IDs, rna_weights, train_rna, + sm_compl_IDs, sm_compl_weights, train_sm_compl, + valid_pdb, valid_homo, + valid_compl, valid_neg, + valid_na_compl, valid_na_neg, + valid_rna, 
def get_crop(l, mask, device, params, unclamp=False):
    """Pick a contiguous window of ``params['CROP']`` residues from a chain.

    Args:
        l: chain length.
        mask: boolean per-atom mask indexed ``[residue, atom]``; a residue
            counts as resolved when its first three atoms are all present.
        device: device for the returned index tensor.
        params: loader parameters; only ``params['CROP']`` is read.
        unclamp: accepted for interface compatibility; not used here.

    Returns:
        1-D tensor of residue indices. All ``l`` indices are returned when
        the chain already fits in the crop; otherwise a window of exactly
        ``CROP`` consecutive indices that contains a randomly chosen resolved
        residue.

    Raises:
        IndexError: if the chain is longer than the crop and no residue is
            resolved.
    """
    sel = torch.arange(l, device=device)
    size = params['CROP']
    if l <= size:
        return sel

    resolved = mask[:, :3].sum(dim=-1) >= 3
    # ``nonzero()`` on a tensor returns an (n, 1) matrix, so ``[0]`` would be
    # just the first hit; ``as_tuple=True`` gives the full index vector.
    exists = resolved.nonzero(as_tuple=True)[0]
    res_idx = exists[torch.randperm(len(exists))[0]].item()

    # Valid window starts are max(0, res_idx-size+1) .. min(l-size, res_idx)
    # inclusive; randint's upper bound is exclusive, hence the +1.
    lower_bound = max(0, res_idx - size + 1)
    upper_bound = min(l - size, res_idx) + 1
    start = np.random.randint(lower_bound, upper_bound)
    return sel[start:start + size]
def get_complex_crop(len_s, mask, device, params):
    """Crop every chain of a complex to a contiguous window.

    The per-chain window sizes come from ``rand_crops(len_s, params['CROP'])``.
    Each window is placed so that it contains a randomly chosen resolved
    residue of that chain (first three atoms all present in ``mask``).

    Args:
        len_s: chain lengths, in the order the chains are concatenated.
        mask: boolean per-atom mask indexed ``[residue, atom]`` over the
            concatenated chains.
        device: device for the returned index tensor.
        params: loader parameters; only ``params['CROP']`` is read.

    Returns:
        1-D tensor of selected residue indices into the concatenated complex.

    Raises:
        IndexError: if a chain has no resolved residue.
    """
    tot_len = sum(len_s)
    sel = torch.arange(tot_len, device=device)

    crops = rand_crops(len_s, params['CROP'])

    offset = 0
    sel_s = []
    for chain_len, n_keep in zip(len_s, crops):
        n_keep = int(n_keep)  # plain int, so the bounds below are not 0-d tensors
        resolved = mask[offset:offset + chain_len, :3].sum(dim=-1) >= 3
        # ``nonzero()`` on a tensor returns an (n, 1) matrix, so ``[0]`` would
        # be just the first hit; ``as_tuple=True`` gives the index vector.
        exists = resolved.nonzero(as_tuple=True)[0]
        res_idx = exists[torch.randperm(len(exists))[0]].item()
        lower_bound = max(0, res_idx - n_keep + 1)
        upper_bound = min(chain_len - n_keep, res_idx) + 1
        start = np.random.randint(lower_bound, upper_bound) + offset
        sel_s.append(sel[start:start + n_keep])
        offset += chain_len
    return torch.cat(sel_s)
def get_na_crop(seq, xyz, mask, sel, len_s, params, negative=False, incl_protein=True, cutoff=12.0, bp_cutoff=4.0, eps=1e-6):
    """Crop a nucleic-acid (or protein/nucleic-acid) complex around a contact.

    A residue graph is built in which neighbours along a chain cost their
    sequence separation, contacting residues in different chains cost 1
    (``eps`` for base pairs between the two nucleic-acid chains of a
    3-chain complex) and all other inter-chain pairs cost 999. The
    ``params['CROP']`` residues closest (by graph distance) to a randomly
    chosen contact residue are kept.

    Args:
        seq: per-residue token indices; tokens 22-25 and 27-30 select the
            representative base atom used for contact detection.
        xyz: coordinates indexed ``[residue, atom, xyz]``.
        mask: boolean per-atom mask indexed ``[residue, atom]``.
        sel: residue index tensor to crop (one entry per residue).
        len_s: chain lengths; with ``incl_protein`` the first chain is the
            protein and the remaining one or two are nucleic acid.
        params: loader parameters; only ``params['CROP']`` is read.
        negative: with ``incl_protein``, skip interface detection and seed
            the crop from a random protein/NA residue pair.
        incl_protein: whether the first chain is a protein.
        cutoff: protein-to-base distance defining an interface contact.
        bp_cutoff: base-to-base distance defining a base pair.
        eps: cost of a base-pair edge. It must be > 0: SciPy reads a zero in
            a dense graph as "no edge", which would disconnect the pairs.

    Returns:
        Sorted 1-D tensor with the ``params['CROP']`` selected entries of
        ``sel``.
    """
    device = xyz.device
    n_first = len_s[0]

    # representative (base-pairing) atom index for every residue
    repatom = torch.zeros(sum(len_s), dtype=torch.long, device=xyz.device)
    repatom[seq==22] = 15 # DA - N1
    repatom[seq==23] = 14 # DC - N3
    repatom[seq==24] = 15 # DG - N1
    repatom[seq==25] = 14 # DT - N3
    repatom[seq==27] = 12 # A - N1
    repatom[seq==28] = 15 # C - N3
    repatom[seq==29] = 12 # G - N1
    repatom[seq==30] = 15 # U - N3

    def _rep_xyz(lo, hi):
        # coordinates of the representative atom of residues lo:hi
        return torch.gather(xyz[lo:hi], 1, repatom[lo:hi, None, None].repeat(1, 1, 3)).squeeze(1)

    def _rep_mask(lo, hi):
        # presence flag of the representative atom of residues lo:hi
        return torch.gather(mask[lo:hi], 1, repatom[lo:hi, None]).squeeze(1)

    if not incl_protein:
        if len(len_s) == 2:
            # 2 RNA chains: contacts are base pairs between the chains
            cond = torch.cdist(_rep_xyz(0, n_first), _rep_xyz(n_first, None)) < bp_cutoff
            cond = torch.logical_and(
                cond, _rep_mask(0, n_first)[:, None] * _rep_mask(n_first, None)[None, :]
            )
        else:
            # 1 RNA chain: contacts are base pairs within the chain
            xyz_na_rep = _rep_xyz(0, None)
            cond = torch.cdist(xyz_na_rep, xyz_na_rep) < bp_cutoff
            mask_na_rep = _rep_mask(0, None)
            cond = torch.logical_and(cond, mask_na_rep[:, None] * mask_na_rep[None, :])

        if torch.sum(cond) == 0:
            # No contact found: seed with a random (i, i+1) pair whose atom 1
            # is present in both residues. Every draw must keep i+1 inside
            # both ``mask`` and ``cond``; the original redraw used
            # randint(len_s[0]) and could step one past the end.
            # NOTE(review): loops forever if no such adjacent pair exists.
            hi = min(n_first, cond.shape[1]) - 1
            i = np.random.randint(hi)
            while (not mask[i, 1]) or (not mask[i + 1, 1]):
                i = np.random.randint(hi)
            cond[i, i + 1] = True

    else:
        if len(len_s) == 3:
            # base pairs between the two nucleic-acid chains
            n_12 = n_first + len_s[1]
            cond_bp = torch.cdist(_rep_xyz(n_first, n_12), _rep_xyz(n_12, None)) < bp_cutoff
            cond_bp = torch.logical_and(
                cond_bp, _rep_mask(n_first, n_12)[:, None] * _rep_mask(n_12, None)[None, :]
            )

        if not negative:
            # get interface residues
            # interface defined as chain 1 versus all other chains
            cond = torch.cdist(xyz[:n_first, 1], _rep_xyz(n_first, None)) < cutoff
            cond = torch.logical_and(
                cond,
                mask[:n_first, None, 1] * _rep_mask(n_first, None)[None, :]
            )

        if negative or torch.sum(cond) == 0:
            # pick a random pair of residues
            n_rest = sum(len_s[1:])
            cond = torch.zeros((n_first, n_rest), dtype=torch.bool)
            i, j = np.random.randint(n_first), np.random.randint(n_rest)
            while not mask[i, 1]:
                i = np.random.randint(n_first)
            while not mask[n_first + j, 1]:
                j = np.random.randint(n_rest)
            cond[i, j] = True

    # a) build a graph of costs:
    #     cost (i,j in same chain)      = abs(i-j)
    #     cost (i,j in different chains) = 1 (eps for base pairs) if i,j are in contact
    #                                    = 999 if i,j are NOT in contact
    def _seq_sep(n):
        r = np.arange(n)
        return np.abs(r[:, None] - r[None, :])

    # Index the NumPy cost matrices with NumPy masks (also works when the
    # tensors live on an accelerator).
    cond_np = cond.cpu().numpy()

    if len(len_s) == 3:
        int_1_2 = np.full((len_s[0], len_s[1]), 999)
        int_1_3 = np.full((len_s[0], len_s[2]), 999)
        int_2_3 = np.full((len_s[1], len_s[2]), 999.0)
        int_1_2[cond_np[:, :len_s[1]]] = 1
        int_1_3[cond_np[:, len_s[1]:]] = 1
        # A literal 0 would be read by scipy's dense csgraph as "no edge";
        # use the tiny positive cost ``eps`` so base pairs stay connected.
        int_2_3[cond_bp.cpu().numpy()] = eps
        inter = np.block([
            [_seq_sep(len_s[0]), int_1_2, int_1_3],
            [int_1_2.T, _seq_sep(len_s[1]), int_2_3],
            [int_1_3.T, int_2_3.T, _seq_sep(len_s[2])]
        ])
    elif len(len_s) == 2:
        int_1_2 = np.full((len_s[0], len_s[1]), 999)
        int_1_2[cond_np] = 1
        inter = np.block([
            [_seq_sep(len_s[0]), int_1_2],
            [int_1_2.T, _seq_sep(len_s[1])]
        ])
    else:
        inter = _seq_sep(len_s[0])
        inter[cond_np] = 1

    # b) pick a random interface residue (row index of a contact)
    intface, _ = torch.where(cond)
    startres = int(intface[np.random.randint(len(intface))])

    # c) traverse graph starting from chosen residue
    d_res = shortest_path(inter, directed=False, indices=startres)
    _, idx = torch.topk(torch.from_numpy(d_res).to(device=device), params['CROP'], largest=False)

    sel, _ = torch.sort(sel[idx])

    return sel
msa[start2:start2+(N-1), start:start+L] = msa_orig[1:] + ins[0, start:start+L] = ins_orig[0] + ins[start2:start2+(N-1), start:start+L] = ins_orig[1:] + start += L + start2 += (N-1) + return msa, ins + +# Generate input features for single-chain +def featurize_single_chain(msa, ins, tplt, pdb, params, unclamp=False, pick_top=True): + seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(msa, ins, params) + + # get template features + ntempl = np.random.randint(params['MINTPLT'], params['MAXTPLT']+1) + xyz_t,f1d_t = TemplFeaturize(tplt, msa.shape[1], params, npick=ntempl, offset=0, pick_top=pick_top) + + # get ground-truth structures + idx = torch.arange(len(pdb['xyz'])) + xyz = torch.full((len(idx),NTOTAL,3),np.nan).float() + xyz[:,:14,:] = pdb['xyz'] + mask = torch.full((len(idx), NTOTAL), False) + mask[:,:14] = pdb['mask'] + + # Residue cropping + crop_idx = get_crop(len(idx), mask, msa_seed_orig.device, params, unclamp=unclamp) + seq = seq[:,crop_idx] + msa_seed_orig = msa_seed_orig[:,:,crop_idx] + msa_seed = msa_seed[:,:,crop_idx] + msa_extra = msa_extra[:,:,crop_idx] + mask_msa = mask_msa[:,:,crop_idx] + xyz_t = xyz_t[:,crop_idx] + f1d_t = f1d_t[:,crop_idx] + xyz = xyz[crop_idx] + mask = mask[crop_idx] + idx = idx[crop_idx] + + # get initial coordinates + xyz_prev = xyz_t[0] + chain_idx = torch.ones((len(crop_idx), len(crop_idx))).long() + bond_feats = get_protein_bond_feats(len(crop_idx)).long() + bond_feats = torch.nn.functional.one_hot(bond_feats, num_classes=NBTYPES) + # replace missing with blackholes & conovert NaN to zeros to avoid any NaN problems during loss calculation + init = INIT_CRDS.reshape(1, NTOTAL, 3).repeat(len(xyz), 1, 1) + xyz = torch.where(mask[...,None], xyz, init).contiguous() + xyz = torch.nan_to_num(xyz) + + #print ("loader_single", mask.shape, xyz_t.shape, f1d_t.shape, xyz_prev.shape) + + return seq.long(), msa_seed_orig.long(), msa_seed.float(), msa_extra.float(), mask_msa, \ + xyz.float(), mask, idx.long(),\ + 
xyz_t.float(), f1d_t.float(), xyz_prev.float(), \ + chain_idx, unclamp, False, torch.zeros(seq.shape), bond_feats + +# Generate input features for homo-oligomers +def featurize_homo(msa_orig, ins_orig, tplt, pdbA, pdbid, interfaces, params, pick_top=True): + L = msa_orig.shape[1] + + msa, ins = merge_a3m_homo(msa_orig, ins_orig, 2) # make unpaired alignments, for training, we always use two chains + seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(msa, ins, params, nmer=2, L_s=[L,L]) + + # get template features + ntempl = np.random.randint(params['MINTPLT'], params['MAXTPLT']//2+1) + xyz_t_single, f1d_t_single = TemplFeaturize(tplt, L, params, npick=ntempl, offset=0, pick_top=pick_top) + ntempl = max(1, ntempl) + # duplicate + xyz_t = torch.full((2*ntempl, L*2, NTOTAL, 3), np.nan).float() + f1d_t = torch.full((2*ntempl, L*2), 20).long() + f1d_t = torch.cat((torch.nn.functional.one_hot(f1d_t, num_classes=NAATOKENS-1).float(), torch.zeros((2*ntempl, L*2, 1)).float()), dim=-1) + xyz_t[:ntempl,:L] = xyz_t_single + xyz_t[ntempl:,L:] = xyz_t_single + f1d_t[:ntempl,:L] = f1d_t_single + f1d_t[ntempl:,L:] = f1d_t_single + + # get initial coordinates + xyz_prev = torch.cat((xyz_t_single[0], xyz_t_single[0]), dim=0) + + # get ground-truth structures + # load metadata + PREFIX = "%s/torch/pdb/%s/%s"%(params['PDB_DIR'],pdbid[1:3],pdbid) + meta = torch.load(PREFIX+".pt") + + npairs = len(interfaces) + xyz = torch.full((npairs, 2*L, NTOTAL, 3), np.nan).float() + mask = torch.full((npairs, 2*L, NTOTAL), False) + for i_int,interface in enumerate(interfaces): + pdbB = torch.load(params['PDB_DIR']+'/torch/pdb/'+interface[0][1:3]+'/'+interface[0]+'.pt') + xformA = meta['asmb_xform%d'%interface[1]][interface[2]] + xformB = meta['asmb_xform%d'%interface[3]][interface[4]] + xyzA = torch.einsum('ij,raj->rai', xformA[:3,:3], pdbA['xyz']) + xformA[:3,3][None,None,:] + xyzB = torch.einsum('ij,raj->rai', xformB[:3,:3], pdbB['xyz']) + xformB[:3,3][None,None,:] + 
xyz[i_int,:,:14] = torch.cat((xyzA, xyzB), dim=0) + mask[i_int,:,:14] = torch.cat((pdbA['mask'], pdbB['mask']), dim=0) + + idx = torch.arange(L*2) + idx[L:] += 200 # to let network know about chain breaks + + # indicator for which residues are in same chain + chain_idx = torch.zeros((2*L, 2*L)).long() + chain_idx[:L, :L] = 1 + chain_idx[L:, L:] = 1 + bond_feats = torch.zeros((2*L, 2*L)).long() + bond_feats[:L, :L] = get_protein_bond_feats(L) + bond_feats[L:, L:] = get_protein_bond_feats(L) + + # Residue cropping + if 2*L > params['CROP']: + # crop so there are contacts in AT LEAST ONE of the interfaces + spatial_crop_tgt = np.random.randint(0, npairs) + crop_idx = get_spatial_crop( + xyz[spatial_crop_tgt], mask[spatial_crop_tgt], torch.arange(L*2), [L,L], params) + seq = seq[:,crop_idx] + msa_seed_orig = msa_seed_orig[:,:,crop_idx] + msa_seed = msa_seed[:,:,crop_idx] + msa_extra = msa_extra[:,:,crop_idx] + mask_msa = mask_msa[:,:,crop_idx] + xyz_t = xyz_t[:,crop_idx] + f1d_t = f1d_t[:,crop_idx] + xyz = xyz[:,crop_idx] + mask = mask[:,crop_idx] + idx = idx[crop_idx] + chain_idx = chain_idx[crop_idx][:,crop_idx] + bond_feats = bond_feats[crop_idx][:,crop_idx] + xyz_prev = xyz_prev[crop_idx] + + bond_feats = torch.nn.functional.one_hot(bond_feats, num_classes=NBTYPES) + # replace missing with blackholes & conovert NaN to zeros to avoid any NaN problems during loss calculation + init = INIT_CRDS.reshape(1, 1, NTOTAL, 3).repeat(npairs, xyz.shape[1], 1, 1) + + xyz = torch.where(mask[...,None], xyz, init).contiguous() + xyz = torch.nan_to_num(xyz) + + #print ("loader_homo", mask.shape, xyz_t.shape, f1d_t.shape, xyz_prev.shape) + + return seq.long(), msa_seed_orig.long(), msa_seed.float(), msa_extra.float(), mask_msa, \ + xyz.float(), mask, idx.long(),\ + xyz_t.float(), f1d_t.float(), xyz_prev.float(), \ + chain_idx, False, False, torch.zeros(seq.shape), bond_feats + + +def get_pdb(pdbfilename, plddtfilename, item, lddtcut, sccut): + xyz, mask, res_idx = 
parse_pdb(pdbfilename) + plddt = np.load(plddtfilename) + + # update mask info with plddt (ignore sidechains if plddt < 90.0) + mask_lddt = np.full_like(mask, False) + mask_lddt[plddt > sccut] = True + mask_lddt[:,:5] = True + mask = np.logical_and(mask, mask_lddt) + mask = np.logical_and(mask, (plddt > lddtcut)[:,None]) + + return {'xyz':torch.tensor(xyz), 'mask':torch.tensor(mask), 'idx': torch.tensor(res_idx), 'label':item} + +def get_msa(a3mfilename, item, unzip=True): + msa,ins = parse_a3m(a3mfilename, unzip=unzip) + return {'msa':torch.tensor(msa), 'ins':torch.tensor(ins), 'label':item} + +# Load PDB examples +def loader_pdb(item, params, homo, unclamp=False, pick_top=True, p_homo_cut=0.5): + # load MSA, PDB, template info + pdb = torch.load(params['PDB_DIR']+'/torch/pdb/'+item[0][1:3]+'/'+item[0]+'.pt') + a3m = get_msa(params['PDB_DIR'] + '/a3m/' + item[1][:3] + '/' + item[1] + '.a3m.gz', item[1]) + tplt = torch.load(params['PDB_DIR']+'/torch/hhr/'+item[1][:3]+'/'+item[1]+'.pt') + + # get msa features + msa = a3m['msa'].long() + ins = a3m['ins'].long() + if len(msa) > params['BLOCKCUT']: + msa, ins = MSABlockDeletion(msa, ins) + + if item[0] in homo: # Target is homo-oligomer + p_homo = np.random.rand() + if p_homo < p_homo_cut: # model as homo-oligomer with p_homo_cut prob + pdbid = item[0].split('_')[0] + # choose one from all possible dimer copies of original homomers + #sel_idx = np.random.randint(0, len(homo[item[0]])) + #homo_item = homo[item[0]][sel_idx] + interfaces = homo[item[0]] + feats = featurize_homo(msa, ins, tplt, pdb, pdbid, interfaces, params, pick_top=pick_top) + return feats + else: + return featurize_single_chain(msa, ins, tplt, pdb, params, unclamp=unclamp, pick_top=pick_top) + else: + return featurize_single_chain(msa, ins, tplt, pdb, params, unclamp=unclamp, pick_top=pick_top) + +def loader_fb(item, params, unclamp=False): + + # loads sequence/structure/plddt information + a3m = get_msa(os.path.join(params["FB_DIR"], "a3m", 
item[-1][:2], item[-1][2:], item[0]+".a3m.gz"), item[0]) + pdb = get_pdb(os.path.join(params["FB_DIR"], "pdb", item[-1][:2], item[-1][2:], item[0]+".pdb"), + os.path.join(params["FB_DIR"], "pdb", item[-1][:2], item[-1][2:], item[0]+".plddt.npy"), + item[0], params['PLDDTCUT'], params['SCCUT']) + + # get msa features + msa = a3m['msa'].long() + ins = a3m['ins'].long() + l_orig = msa.shape[1] + if len(msa) > params['BLOCKCUT']: + msa, ins = MSABlockDeletion(msa, ins) + seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(msa, ins, params) + + # get template features -- None + xyz_t = torch.full((1,l_orig,NTOTAL,3),np.nan).float() + f1d_t = torch.nn.functional.one_hot(torch.full((1, l_orig), 20).long(), num_classes=NAATOKENS-1).float() # all gaps + conf = torch.zeros((1,l_orig,1)).float() # zero confidence + f1d_t = torch.cat((f1d_t, conf), -1) + + idx = pdb['idx'] + xyz = torch.full((len(idx),NTOTAL,3),np.nan).float() + xyz[:,:27,:] = pdb['xyz'] + mask = torch.full((len(idx),NTOTAL), False) + mask[:,:27] = pdb['mask'] + + # Residue cropping + crop_idx = get_crop(len(idx), mask, msa_seed_orig.device, params, unclamp=unclamp) + seq = seq[:,crop_idx] + msa_seed_orig = msa_seed_orig[:,:,crop_idx] + msa_seed = msa_seed[:,:,crop_idx] + msa_extra = msa_extra[:,:,crop_idx] + mask_msa = mask_msa[:,:,crop_idx] + xyz_t = xyz_t[:,crop_idx] + f1d_t = f1d_t[:,crop_idx] + xyz = xyz[crop_idx] + mask = mask[crop_idx] + idx = idx[crop_idx] + + # initial structure + xyz_prev = xyz_t[0] + chain_idx = torch.ones((len(crop_idx), len(crop_idx))).long() + bond_feats = get_protein_bond_feats(len(crop_idx)).long() + bond_feats = torch.nn.functional.one_hot(bond_feats, num_classes=NBTYPES) + + #print ("loader_fb", mask.shape, xyz_t.shape, f1d_t.shape, xyz_prev.shape) + + return seq.long(), msa_seed_orig.long(), msa_seed.float(), msa_extra.float(), mask_msa, \ + xyz.float(), mask, idx.long(),\ + xyz_t.float(), f1d_t.float(), xyz_prev.float(), \ + chain_idx, unclamp, False, 
torch.zeros(seq.shape), bond_feats + + +def loader_complex(item, L_s, taxID, assem, params, negative=False, pick_top=True): + pdb_pair = item[0] + pMSA_hash = item[1] + + msaA_id, msaB_id = pMSA_hash.split('_') + if len(set(taxID.split(':'))) == 1: # two proteins have same taxID -- use paired MSA + # read pMSA + if negative: + pMSA_fn = params['COMPL_DIR'] + '/pMSA.negative/' + msaA_id[:3] + '/' + msaB_id[:3] + '/' + pMSA_hash + '.a3m' + else: + pMSA_fn = params['COMPL_DIR'] + '/pMSA/' + msaA_id[:3] + '/' + msaB_id[:3] + '/' + pMSA_hash + '.a3m' + a3m = get_msa(pMSA_fn, pMSA_hash, unzip=False) + else: + # read MSA for each subunit & merge them + a3mA_fn = params['PDB_DIR'] + '/a3m/' + msaA_id[:3] + '/' + msaA_id + '.a3m.gz' + a3mB_fn = params['PDB_DIR'] + '/a3m/' + msaB_id[:3] + '/' + msaB_id + '.a3m.gz' + a3mA = get_msa(a3mA_fn, msaA_id) + a3mB = get_msa(a3mB_fn, msaB_id) + a3m = merge_a3m_hetero(a3mA, a3mB, L_s) + + # get MSA features + msa = a3m['msa'].long() + if negative: # Qian's paired MSA for true-pairs have no insertions... (ignore insertion to avoid any weird bias..) 
+ ins = torch.zeros_like(msa) + else: + ins = a3m['ins'].long() + if len(msa) > params['BLOCKCUT']: + msa, ins = MSABlockDeletion(msa, ins) + seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(msa, ins, params, L_s=L_s) + + # read template info + tpltA_fn = params['PDB_DIR'] + '/torch/hhr/' + msaA_id[:3] + '/' + msaA_id + '.pt' + tpltB_fn = params['PDB_DIR'] + '/torch/hhr/' + msaB_id[:3] + '/' + msaB_id + '.pt' + tpltA = torch.load(tpltA_fn) + tpltB = torch.load(tpltB_fn) + ntempl = np.random.randint(params['MINTPLT'], params['MAXTPLT']//2+1) + xyz_t_A, f1d_t_A = TemplFeaturize(tpltA, sum(L_s), params, offset=0, npick=ntempl, pick_top=pick_top) + xyz_t_B, f1d_t_B = TemplFeaturize(tpltB, sum(L_s), params, offset=L_s[0], npick=ntempl, pick_top=pick_top) + xyz_t = torch.cat((xyz_t_A, xyz_t_B), dim=0) + f1d_t = torch.cat((f1d_t_A, f1d_t_B), dim=0) + + # get initial coordinates + xyz_prev = torch.cat((xyz_t_A[0][:L_s[0]], xyz_t_B[0][L_s[0]:]), dim=0) + + # read PDB + pdbA_id, pdbB_id = pdb_pair.split(':') + pdbA = torch.load(params['PDB_DIR']+'/torch/pdb/'+pdbA_id[1:3]+'/'+pdbA_id+'.pt') + pdbB = torch.load(params['PDB_DIR']+'/torch/pdb/'+pdbB_id[1:3]+'/'+pdbB_id+'.pt') + + if len(assem) > 0: + # read metadata + pdbid = pdbA_id.split('_')[0] + meta = torch.load(params['PDB_DIR']+'/torch/pdb/'+pdbid[1:3]+'/'+pdbid+'.pt') + + # get transform + xformA = meta['asmb_xform%d'%assem[0]][assem[1]] + xformB = meta['asmb_xform%d'%assem[2]][assem[3]] + + # apply transform + xyzA = torch.einsum('ij,raj->rai', xformA[:3,:3], pdbA['xyz']) + xformA[:3,3][None,None,:] + xyzB = torch.einsum('ij,raj->rai', xformB[:3,:3], pdbB['xyz']) + xformB[:3,3][None,None,:] + xyz = torch.full((sum(L_s), NTOTAL, 3), np.nan).float() + xyz[:,:14] = torch.cat((xyzA, xyzB), dim=0) + mask = torch.full((sum(L_s), NTOTAL), False) + mask[:,:14] = torch.cat((pdbA['mask'], pdbB['mask']), dim=0) + else: + xyz = torch.full((sum(L_s), NTOTAL, 3), np.nan).float() + xyz[:,:14] = torch.cat((pdbA['xyz'], 
pdbB['xyz']), dim=0) + mask = torch.full((sum(L_s), NTOTAL), False) + mask[:,:14] = torch.cat((pdbA['mask'], pdbB['mask']), dim=0) + idx = torch.arange(sum(L_s)) + idx[L_s[0]:] += 200 + + chain_idx = torch.zeros((sum(L_s), sum(L_s))).long() + chain_idx[:L_s[0], :L_s[0]] = 1 + chain_idx[L_s[0]:, L_s[0]:] = 1 + bond_feats = torch.zeros((sum(L_s), sum(L_s))).long() + bond_feats[:L_s[0], :L_s[0]] = get_protein_bond_feats(L_s[0]) + bond_feats[L_s[0]:, L_s[0]:] = get_protein_bond_feats(sum(L_s[1:])) + + # Do cropping + if sum(L_s) > params['CROP']: + if negative: + sel = get_complex_crop(L_s, mask, seq.device, params) + else: + sel = get_spatial_crop(xyz, mask, torch.arange(sum(L_s)), L_s, params) + # + seq = seq[:,sel] + msa_seed_orig = msa_seed_orig[:,:,sel] + msa_seed = msa_seed[:,:,sel] + msa_extra = msa_extra[:,:,sel] + mask_msa = mask_msa[:,:,sel] + xyz = xyz[sel] + mask = mask[sel] + xyz_t = xyz_t[:,sel] + f1d_t = f1d_t[:,sel] + xyz_prev = xyz_prev[sel] + # + idx = idx[sel] + chain_idx = chain_idx[sel][:,sel] + bond_feats = bond_feats[sel][:,sel] + bond_feats = torch.nn.functional.one_hot(bond_feats, num_classes=NBTYPES) + + # replace missing with blackholes & conovert NaN to zeros to avoid any NaN problems during loss calculation + init = INIT_CRDS.reshape(1, NTOTAL, 3).repeat(len(xyz), 1, 1) + xyz = torch.where(mask[...,None], xyz, init).contiguous() + xyz = torch.nan_to_num(xyz) + + #print ("loader_compl", mask.shape, xyz_t.shape, f1d_t.shape, xyz_prev.shape, negative) + + return seq.long(), msa_seed_orig.long(), msa_seed.float(), msa_extra.float(), mask_msa,\ + xyz.float(), mask, idx.long(), \ + xyz_t.float(), f1d_t.float(), xyz_prev.float(), \ + chain_idx, False, negative, torch.zeros(seq.shape), bond_feats + +def loader_na_complex(item, Ls, params, native_NA_frac=0.25, negative=False, pick_top=True): + pdb_set = item[0] + msa_id = item[1] + + # read MSA for protein + a3mA = get_msa(params['PDB_DIR'] + '/a3m/' + msa_id[:3] + '/' + msa_id + '.a3m.gz', msa_id) 
+ + # read PDBs + pdb_ids = pdb_set.split(':') + pdbA = torch.load(params['PDB_DIR']+'/torch/pdb/'+pdb_ids[0][1:3]+'/'+pdb_ids[0]+'.pt') + pdbB = torch.load(params['NA_DIR']+'/torch/'+pdb_ids[1][1:3]+'/'+pdb_ids[1]+'.pt') + pdbC = None + if (len(pdb_ids)==3): + pdbC = torch.load(params['NA_DIR']+'/torch/'+pdb_ids[2][1:3]+'/'+pdb_ids[2]+'.pt') + + # msa for NA is sequence only + #alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8) # -0 are UNK/mask + #if (len(pdb_ids)==2): + # a3mB = np.array([list(pdbB['seq'])], dtype='|S1').view(np.uint8) + #else: + # a3mB = np.array([list(pdbB['seq']+pdbC['seq'])], dtype='|S1').view(np.uint8) # separate entries? + #for i in range(alphabet.shape[0]): + # a3mB[a3mB == alphabet[i]] = i + #a3mB = { + # 'msa':torch.from_numpy(a3mB), + # 'ins':torch.zeros(a3mB.shape, dtype=torch.uint8), + #} + msaB,insB = parse_fasta_if_exists(pdbB['seq'], params['NA_DIR']+'/torch/'+pdb_ids[1][1:3]+'/'+pdb_ids[1]+'.afa', rmsa_alphabet=True) + a3mB = {'msa':torch.from_numpy(msaB), 'ins':torch.from_numpy(insB)} + if (len(pdb_ids)==3): + msaC,insC = parse_fasta_if_exists(pdbC['seq'], params['NA_DIR']+'/torch/'+pdb_ids[2][1:3]+'/'+pdb_ids[2]+'.afa', rmsa_alphabet=True) + a3mC = {'msa':torch.from_numpy(msaC), 'ins':torch.from_numpy(insC)} + a3mB = merge_a3m_hetero(a3mB, a3mC, Ls[1:]) + a3m = merge_a3m_hetero(a3mA, a3mB, [Ls[0],sum(Ls[1:])]) + + # note: the block below is due to differences in the way RNA and DNA structures are processed + # to support NMR, RNA structs return multiple states + # For protein/NA complexes get rid of the 'NMODEL' dimension (if present) + # NOTE there are a very small number of protein/NA NMR models: + # - ideally these should return the ensemble, but that requires reprocessing of PDBs + if (len(pdbB['xyz'].shape) > 3): + pdbB['xyz'] = pdbB['xyz'][0,...] + pdbB['mask'] = pdbB['mask'][0,...] + if (pdbC is not None and len(pdbC['xyz'].shape) > 3): + pdbC['xyz'] = pdbC['xyz'][0,...] 
+ pdbC['mask'] = pdbC['mask'][0,...] + + # read template info + tpltA = torch.load(params['PDB_DIR'] + '/torch/hhr/' + msa_id[:3] + '/' + msa_id + '.pt') + ntempl = np.random.randint(params['MINTPLT'], params['MAXTPLT']-1) + xyz_t, f1d_t = TemplFeaturize(tpltA, sum(Ls), params, offset=0, npick=ntempl, pick_top=pick_top) + + xyz_prev = xyz_t[0] + + if (np.random.rand()<=native_NA_frac): + natNA_templ = pdbB['xyz'] + if pdbC is not None: + natNA_templ = torch.cat((pdbB['xyz'], pdbC['xyz']), dim=0) + + # construct template from NA + xyz_t_B = torch.full((1,sum(Ls),NTOTAL,3),np.nan).float() + xyz_t_B[:,Ls[0]:sum(Ls),:23] = natNA_templ + seq_t_B = torch.cat( (torch.full((1, Ls[0]), 20).long(), a3mB['msa'][0:1]), dim=1) + seq_t_B[seq_t_B>21] -= 1 # remove mask token + f1d_t_B = torch.nn.functional.one_hot(seq_t_B, num_classes=NAATOKENS-1).float() + conf_B = torch.cat( ( + torch.zeros((1,Ls[0],1)), + torch.full((1,sum(Ls[1:]),1),1.0), + ),dim=1).float() + f1d_t_B = torch.cat((f1d_t_B, conf_B), -1) + + xyz_t = torch.cat((xyz_t,xyz_t_B),dim=0) + f1d_t = torch.cat((f1d_t,f1d_t_B),dim=0) + + xyz_prev = xyz_t_B[0] # initialize NA only + + # get MSA features + msa = a3m['msa'].long() + ins = a3m['ins'].long() + if len(msa) > params['BLOCKCUT']: + msa, ins = MSABlockDeletion(msa, ins) + seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(msa, ins, params, L_s=Ls) + + xyz = torch.full((sum(Ls), NTOTAL, 3), np.nan).float() + mask = torch.full((sum(Ls), NTOTAL), False) + + if (len(pdb_ids)==3): + xyz[:Ls[0],:14] = pdbA['xyz'] + xyz[Ls[0]:,:23] = torch.cat((pdbB['xyz'], pdbC['xyz']), dim=0) + mask[:Ls[0],:14] = pdbA['mask'] + mask[Ls[0]:,:23] = torch.cat((pdbB['mask'], pdbC['mask']), dim=0) + else: + xyz[:Ls[0],:14] = pdbA['xyz'] + xyz[Ls[0]:,:23] = pdbB['xyz'] + mask[:Ls[0],:14] = pdbA['mask'] + mask[Ls[0]:,:23] = pdbB['mask'] + + idx = torch.arange(sum(Ls)) + idx[Ls[0]:] += 200 + if (len(pdb_ids)==3): + idx[Ls[1]:] += 200 + + chain_idx = torch.zeros((sum(Ls), 
sum(Ls))).long() + chain_idx[:Ls[0], :Ls[0]] = 1 + chain_idx[Ls[0]:, Ls[0]:] = 1 # fd - "negatives" still predict DNA double helix + bond_feats = torch.zeros((sum(Ls), sum(Ls))).long() + bond_feats[:L_s[0], :L_s[0]] = get_protein_bond_feats(L_s[0]) + bond_feats[L_s[0]:, L_s[0]:] = get_protein_bond_feats(sum(L_s[1:])) + + init = torch.cat(( + INIT_CRDS.reshape(1, NTOTAL, 3).repeat(Ls[0], 1, 1), + INIT_NA_CRDS.reshape(1, NTOTAL, 3).repeat(sum(Ls[1:]), 1, 1) + ), dim=0) + + # Do cropping + #print (item) + if sum(Ls) > params['CROP']: + sel = get_na_crop(seq[0], xyz, mask, torch.arange(sum(Ls)), Ls, params, negative) + + seq = seq[:,sel] + msa_seed_orig = msa_seed_orig[:,:,sel] + msa_seed = msa_seed[:,:,sel] + msa_extra = msa_extra[:,:,sel] + mask_msa = mask_msa[:,:,sel] + xyz = xyz[sel] + mask = mask[sel] + xyz_t = xyz_t[:,sel] + f1d_t = f1d_t[:,sel] + xyz_prev = xyz_prev[sel] + # + idx = idx[sel] + chain_idx = chain_idx[sel][:,sel] + bond_feats = bond_feats[sel][:,sel] + init = init[sel] + bond_feats = torch.nn.functional.one_hot(bond_feats, num_classes=NBTYPES) + # replace missing with blackholes & conovert NaN to zeros to avoid any NaN problems during loss calculation + xyz = torch.where(mask[...,None], xyz, init).contiguous() + xyz = torch.nan_to_num(xyz) + + #print ("loader_na_complex", mask.shape, xyz_t.shape, f1d_t.shape, xyz_prev.shape) + + return seq.long(), msa_seed_orig.long(), msa_seed.float(), msa_extra.float(), mask_msa,\ + xyz.float(), mask, idx.long(), \ + xyz_t.float(), f1d_t.float(), xyz_prev.float(), \ + chain_idx, False, negative, torch.zeros(seq.shape), bond_feats + +def loader_rna(pdb_set, Ls, params): + # read PDBs + pdb_ids = pdb_set.split(':') + pdbA = torch.load(params['NA_DIR']+'/torch/'+pdb_ids[0][1:3]+'/'+pdb_ids[0]+'.pt') + pdbB = None + if (len(pdb_ids)==2): + pdbB = torch.load(params['NA_DIR']+'/torch/'+pdb_ids[1][1:3]+'/'+pdb_ids[1]+'.pt') + + # msa for NA is sequence only + msaA,insA = parse_fasta_if_exists(pdbA['seq'], 
params['NA_DIR']+'/torch/'+pdb_ids[0][1:3]+'/'+pdb_ids[0]+'.afa', rmsa_alphabet=True) + a3m = {'msa':torch.from_numpy(msaA), 'ins':torch.from_numpy(insA)} + if (len(pdb_ids)==2): + msaB,insB = parse_fasta_if_exists(pdbB['seq'], params['NA_DIR']+'/torch/'+pdb_ids[1][1:3]+'/'+pdb_ids[1]+'.afa', rmsa_alphabet=True) + a3mB = {'msa':torch.from_numpy(msaB), 'ins':torch.from_numpy(insB)} + a3m = merge_a3m_hetero(a3m, a3mB, Ls) + + # get template features -- None + L = sum(Ls) + xyz_t = torch.full((1,L,NTOTAL,3),np.nan).float() + f1d_t = torch.nn.functional.one_hot(torch.full((1, L), 20).long(), num_classes=NAATOKENS-1).float() # all gaps + conf = torch.zeros((1,L,1)).float() # zero confidence + f1d_t = torch.cat((f1d_t, conf), -1) + + xyz_prev = xyz_t[0] + + NMDLS = pdbA['xyz'].shape[0] + + # get MSA features + msa = a3m['msa'].long() + ins = a3m['ins'].long() + seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(msa, ins, params, L_s=Ls) + + xyz = torch.full((NMDLS, L, NTOTAL, 3), np.nan).float() + mask = torch.full((NMDLS, L, NTOTAL), False) + if (len(pdb_ids)==2): + xyz[:,:,:23] = torch.cat((pdbA['xyz'], pdbB['xyz']), dim=1) + mask[:,:,:23] = torch.cat((pdbA['mask'], pdbB['mask']), dim=1) + else: + xyz[:,:,:23] = pdbA['xyz'] + mask[:,:,:23] = pdbA['mask'] + + idx = torch.arange(L) + if (len(pdb_ids)==2): + idx[Ls[0]:] += 200 + + chain_idx = torch.ones(L,L).long() + bond_feats = get_protein_bond_feats(L) + init = INIT_NA_CRDS.reshape(1, NTOTAL, 3).repeat(L, 1, 1) + + # Do cropping + #print (item) + if sum(Ls) > params['CROP']: + cropref = np.random.randint(xyz.shape[0]) + sel = get_na_crop(seq[0], xyz[cropref], mask[cropref], torch.arange(L), Ls, params, incl_protein=False) + + seq = seq[:,sel] + msa_seed_orig = msa_seed_orig[:,:,sel] + msa_seed = msa_seed[:,:,sel] + msa_extra = msa_extra[:,:,sel] + mask_msa = mask_msa[:,:,sel] + xyz = xyz[:,sel] + mask = mask[:,sel] + xyz_t = xyz_t[:,sel] + f1d_t = f1d_t[:,sel] + xyz_prev = xyz_prev[sel] + # + idx = 
idx[sel] + chain_idx = chain_idx[sel][:,sel] + bond_feats = bond_feats[sel][:, sel] + init = init[sel] + bond_feats = torch.nn.functional.one_hot(bond_feats, num_classes=NBTYPES) + # replace missing with blackholes & conovert NaN to zeros to avoid any NaN problems during loss calculation + xyz = torch.where(mask[...,None], xyz, init).contiguous() + xyz = torch.nan_to_num(xyz) + + #print ("loader_rna", mask.shape, xyz_t.shape, f1d_t.shape, xyz_prev.shape) + + return seq.long(), msa_seed_orig.long(), msa_seed.float(), msa_extra.float(), mask_msa,\ + xyz.float(), mask, idx.long(), \ + xyz_t.float(), f1d_t.float(), xyz_prev.float(), \ + chain_idx, False, False, torch.zeros(seq.shape), bond_feats + +def loader_sm_compl(item, sm_chains, params, pick_top=True): + """Load protein/SM complex with mixed residue and atom tokens. Also, compute frames for atom FAPE loss calc""" + # Load protein information + pdbA = torch.load(params['PDB_DIR']+'/torch/pdb/'+item[0][1:3]+'/'+item[0]+'.pt') + a3mA = get_msa(params['PDB_DIR'] + '/a3m/'+item[1][:3] + '/'+ item[1] + '.a3m.gz', item[1]) + tpltA = torch.load(params['PDB_DIR']+'/torch/hhr/'+item[1][:3]+'/'+item[1]+'.pt') + + # get msa features + msa_prot = a3mA['msa'].long() + ins_prot = a3mA['ins'].long() + + if len(msa_prot) > params['BLOCKCUT']: + msa_prot, ins_prot = MSABlockDeletion(msa_prot, ins_prot) + a3m_prot = {"msa": msa_prot, "ins": ins_prot} + xyz_prot, mask_prot = pdbA["xyz"], pdbA["mask"] + protein_L, nprotatoms, _ = xyz_prot.shape + # Load small molecule + + mol, msa_sm, ins_sm = parse_mol(params["MOL_DIR"]+"/mol2/"+item[0][1:3]+"/"+item[0][:-1]+random.choice(sm_chains)+".mol2") + a3m_sm = {"msa": msa_sm.unsqueeze(0), "ins": ins_sm.unsqueeze(0)} + G = get_nxgraph(mol) + frames = get_atom_frames(msa_sm, mol, G) + xyz_sm, mask_sm = get_ligand_xyz(mol) + + N_symmetry, sm_L, _ = xyz_sm.shape + # Generate ground truth structure: account for ligand symmetry + xyz = torch.full((N_symmetry, protein_L+sm_L, NTOTAL, 3), 
np.nan).float() + mask = torch.full(xyz.shape[:-1], False).bool() + xyz[:, :protein_L, :nprotatoms, :] = xyz_prot.expand(N_symmetry, protein_L, nprotatoms, 3) + xyz[:, protein_L:, 1, :] = xyz_sm + mask[:, :protein_L, :nprotatoms] = mask_prot.expand(N_symmetry, protein_L, nprotatoms) + mask[:, protein_L:, 1] = mask_sm + + Ls = [xyz_prot.shape[0], xyz_sm.shape[1]] + a3m = merge_a3m_hetero(a3m_prot, a3m_sm, Ls) + msa = a3m['msa'].long() + ins = a3m['ins'].long() + + seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(msa, ins, params) + + idx = torch.arange(sum(Ls)) + idx[Ls[0]:] += 200 + + chain_idx = torch.zeros((sum(Ls), sum(Ls))).long() + chain_idx[:Ls[0], :Ls[0]] = 1 + chain_idx[Ls[0]:, Ls[0]:] = 1 + bond_feats = torch.zeros((sum(Ls), sum(Ls))).long() + bond_feats[:Ls[0], :Ls[0]] = get_protein_bond_feats(Ls[0]) + bond_feats[Ls[0]:, Ls[0]:] = get_bond_feats(mol, G) + + ntempl = np.random.randint(params['MINTPLT'], params['MAXTPLT']-1) + xyz_t, f1d_t = TemplFeaturize(tpltA, sum(Ls), params, offset=0, npick=ntempl, pick_top=pick_top) + # give template of native backbone if none exists + + #generate initial coordinates + xyz_prev = xyz_t[0] + + if sum(Ls) > params["CROP"]: + sel = crop_small_molecule(xyz_prot, xyz_sm[0], Ls, params) + + seq = seq[:,sel] + msa_seed_orig = msa_seed_orig[:,:,sel] + msa_seed = msa_seed[:,:,sel] + msa_extra = msa_extra[:,:,sel] + mask_msa = mask_msa[:,:,sel] + xyz = xyz[:,sel] + mask = mask[:,sel] + xyz_t = xyz_t[:,sel] + f1d_t = f1d_t[:,sel] + xyz_prev = xyz_prev[sel] # need to initialize ligand atoms + # + idx = idx[sel] + chain_idx = chain_idx[sel][:,sel] + bond_feats = bond_feats[sel][:, sel] + bond_feats = torch.nn.functional.one_hot(bond_feats, num_classes=NBTYPES) + # replace missing with blackholes & conovert NaN to zeros to avoid any NaN problems during loss calculation + # init = INIT_CRDS.reshape(1, NTOTAL, 3).repeat(len(xyz), 1, 1) + # xyz = torch.where(mask[...,None], xyz, init).contiguous() + # xyz = 
torch.nan_to_num(xyz) + + return seq.long(), msa_seed_orig.long(), msa_seed.float(), msa_extra.float(), mask_msa,\ + xyz.float(), mask, idx.long(), \ + xyz_t.float(), f1d_t.float(), xyz_prev.float(), \ + chain_idx, False, False, frames, bond_feats + +def crop_small_molecule(prot_xyz, lig_xyz,Ls, params): + """choose residues with calphas close to the ligand center of mass""" + ligand_com = torch.nanmean(lig_xyz, dim=[0,1]).expand(1,3) + dist = torch.cdist(prot_xyz[:,1].double(), ligand_com).flatten() + _, idx = torch.topk(dist, params["CROP"]-len(lig_xyz), largest=False) + sel, _ = torch.sort(idx) + # select the whole ligand + lig_sel = torch.arange(lig_xyz.shape[0])+Ls[0] + return torch.cat((sel, lig_sel)) + + +class Dataset(data.Dataset): + def __init__(self, IDs, loader, item_dict, params, homo, unclamp_cut=0.9, pick_top=True, p_homo_cut=-1.0): + self.IDs = IDs + self.item_dict = item_dict + self.loader = loader + self.params = params + self.homo = homo + self.pick_top = pick_top + self.unclamp_cut = unclamp_cut + self.p_homo_cut = p_homo_cut + + def __len__(self): + return len(self.IDs) + + def __getitem__(self, index): + ID = self.IDs[index] + sel_idx = np.random.randint(0, len(self.item_dict[ID])) + p_unclamp = np.random.rand() + if p_unclamp > self.unclamp_cut: + out = self.loader(self.item_dict[ID][sel_idx][0], self.params, self.homo, + unclamp=True, + pick_top=self.pick_top, + p_homo_cut=self.p_homo_cut) + else: + out = self.loader(self.item_dict[ID][sel_idx][0], self.params, self.homo, + pick_top=self.pick_top, + p_homo_cut=self.p_homo_cut) + return out + +class DatasetComplex(data.Dataset): + def __init__(self, IDs, loader, item_dict, params, pick_top=True, negative=False): + self.IDs = IDs + self.item_dict = item_dict + self.loader = loader + self.params = params + self.pick_top = pick_top + self.negative = negative + + def __len__(self): + return len(self.IDs) + + def __getitem__(self, index): + ID = self.IDs[index] + sel_idx = np.random.randint(0, 
len(self.item_dict[ID])) + out = self.loader(self.item_dict[ID][sel_idx][0], + self.item_dict[ID][sel_idx][1], + self.item_dict[ID][sel_idx][2], + self.item_dict[ID][sel_idx][3], + self.params, + pick_top = self.pick_top, + negative = self.negative) + return out + +class DatasetNAComplex(data.Dataset): + def __init__(self, IDs, loader, item_dict, params, pick_top=True, negative=False, native_NA_frac=0.0): + self.IDs = IDs + self.item_dict = item_dict + self.loader = loader + self.params = params + self.pick_top = pick_top + self.negative = negative + self.native_NA_frac = native_NA_frac + + def __len__(self): + return len(self.IDs) + + def __getitem__(self, index): + ID = self.IDs[index] + sel_idx = np.random.randint(0, len(self.item_dict[ID])) + out = self.loader( + self.item_dict[ID][sel_idx][0], + self.item_dict[ID][sel_idx][1], + self.params, + pick_top = self.pick_top, + negative = self.negative, + native_NA_frac = self.native_NA_frac + ) + return out + +class DatasetRNA(data.Dataset): + def __init__(self, IDs, loader, item_dict, params): + self.IDs = IDs + self.item_dict = item_dict + self.loader = loader + self.params = params + + def __len__(self): + return len(self.IDs) + + def __getitem__(self, index): + ID = self.IDs[index] + sel_idx = np.random.randint(0, len(self.item_dict[ID])) + out = self.loader( + self.item_dict[ID][sel_idx][0], + self.item_dict[ID][sel_idx][1], + self.params + ) + return out + + +class DatasetSMComplex(data.Dataset): + def __init__(self, IDs, loader, item_dict, params): + self.IDs = IDs + self.item_dict = item_dict + self.loader = loader + self.params = params + + def __len__(self): + return len(self.IDs) + + def __getitem__(self, index): + ID = self.IDs[index] + sel_idx = np.random.randint(0, len(self.item_dict[ID])) + out = self.loader( + self.item_dict[ID][sel_idx][0], + self.item_dict[ID][sel_idx][2], + self.params + ) + return out + + +class DistilledDataset(data.Dataset): + def __init__( + self, + pdb_IDs, pdb_loader, 
pdb_dict, + compl_IDs, compl_loader, compl_dict, + neg_IDs, neg_loader, neg_dict, + na_compl_IDs, na_compl_loader, na_compl_dict, + na_neg_IDs, na_neg_loader, na_neg_dict, + fb_IDs, fb_loader, fb_dict, + rna_IDs, rna_loader, rna_dict, + sm_compl_IDs, sm_compl_loader, sm_compl_dict, + homo, + params, + native_NA_frac=0.25, + unclamp_cut=0.9 + ): + # + self.pdb_IDs = pdb_IDs + self.pdb_dict = pdb_dict + self.pdb_loader = pdb_loader + self.compl_IDs = compl_IDs + self.compl_loader = compl_loader + self.compl_dict = compl_dict + self.neg_IDs = neg_IDs + self.neg_loader = neg_loader + self.neg_dict = neg_dict + self.na_compl_IDs = na_compl_IDs + self.na_compl_loader = na_compl_loader + self.na_compl_dict = na_compl_dict + self.na_neg_IDs = na_neg_IDs + self.na_neg_loader = na_neg_loader + self.na_neg_dict = na_neg_dict + self.fb_IDs = fb_IDs + self.fb_dict = fb_dict + self.fb_loader = fb_loader + self.rna_IDs = rna_IDs + self.rna_dict = rna_dict + self.rna_loader = rna_loader + self.sm_compl_IDs = sm_compl_IDs + self.sm_compl_loader = sm_compl_loader + self.sm_compl_dict = sm_compl_dict + self.homo = homo + self.params = params + self.unclamp_cut = unclamp_cut + self.native_NA_frac = native_NA_frac + + self.compl_inds = np.arange(len(self.compl_IDs)) + self.neg_inds = np.arange(len(self.neg_IDs)) + self.na_compl_inds = np.arange(len(self.na_compl_IDs)) + self.na_neg_inds = np.arange(len(self.na_neg_IDs)) + self.fb_inds = np.arange(len(self.fb_IDs)) + self.pdb_inds = np.arange(len(self.pdb_IDs)) + self.rna_inds = np.arange(len(self.rna_IDs)) + self.sm_compl_inds = np.arange(len(self.sm_compl_IDs)) + + def __len__(self): + return ( + len(self.fb_inds) + + len(self.pdb_inds) + + len(self.compl_inds) + + len(self.neg_inds) + + len(self.na_compl_inds) + + len(self.na_neg_inds) + + len(self.rna_inds) + + len(self.sm_compl_inds) + ) + + # order: + # 0 - nfb-1 = FB + # nfb - nfb+npdb-1 = PDB + # "+npdb - "+ncmpl-1 = COMPLEX + # "+ncmpl - "+nneg-1 = COMPLEX NEGATIVES + # "+nneg 
- "+nna_cmpl-1 = NA COMPLEX + # "+nna_cmpl - "+nrna-1 = NA COMPLEX NEGATIVES + # "+nrna-1 - "nsm_compl-1 = RNA + # nsm_compl -1 - = SM COMPLEX + def __getitem__(self, index): + p_unclamp = np.random.rand() + + if index < len(self.fb_inds): + ID = self.fb_IDs[index] + sel_idx = np.random.randint(0, len(self.fb_dict[ID])) + out = self.fb_loader(self.fb_dict[ID][sel_idx][0], self.params, unclamp=(p_unclamp > self.unclamp_cut)) + + offset = len(self.fb_inds) + if index >= offset and index < offset + len(self.pdb_inds): + ID = self.pdb_IDs[index-offset] + sel_idx = np.random.randint(0, len(self.pdb_dict[ID])) + out = self.pdb_loader(self.pdb_dict[ID][sel_idx][0], self.params, self.homo, unclamp=(p_unclamp > self.unclamp_cut)) + + offset += len(self.pdb_inds) + if index >= offset and index < offset + len(self.compl_inds): + ID = self.compl_IDs[index-offset] + sel_idx = np.random.randint(0, len(self.compl_dict[ID])) + out = self.compl_loader( + self.compl_dict[ID][sel_idx][0], + self.compl_dict[ID][sel_idx][1], + self.compl_dict[ID][sel_idx][2], + self.compl_dict[ID][sel_idx][3], + self.params, + negative=False + ) + + offset += len(self.compl_inds) + if index >= offset and index < offset + len(self.neg_inds): + ID = self.neg_IDs[index-offset] + sel_idx = np.random.randint(0, len(self.neg_dict[ID])) + out = self.neg_loader( + self.neg_dict[ID][sel_idx][0], + self.neg_dict[ID][sel_idx][1], + self.neg_dict[ID][sel_idx][2], + self.neg_dict[ID][sel_idx][3], + self.params, + negative=True + ) + + offset += len(self.neg_inds) + if index >= offset and index < offset + len(self.na_compl_inds): + ID = self.na_compl_IDs[index-offset] + sel_idx = np.random.randint(0, len(self.na_compl_dict[ID])) + out = self.na_compl_loader( + self.na_compl_dict[ID][sel_idx][0], + self.na_compl_dict[ID][sel_idx][1], + self.params, + negative=False, + native_NA_frac=self.native_NA_frac + ) + + offset += len(self.na_compl_inds) + if index >= offset and index < offset + len(self.na_neg_inds): + ID = 
self.na_neg_IDs[index-offset] + sel_idx = np.random.randint(0, len(self.na_neg_dict[ID])) + out = self.na_neg_loader( + self.na_neg_dict[ID][sel_idx][0], + self.na_neg_dict[ID][sel_idx][1], + self.params, + negative=True, + native_NA_frac=self.native_NA_frac + ) + + offset += len(self.na_neg_inds) + if index >= offset and index < offset + len(self.rna_inds): + ID = self.rna_IDs[index-offset] + sel_idx = np.random.randint(0, len(self.rna_dict[ID])) + out = self.rna_loader( + self.rna_dict[ID][sel_idx][0], + self.rna_dict[ID][sel_idx][1], + self.params + ) + offset += len(self.rna_inds) + if index >= offset: + ID = self.sm_compl_IDs[index-offset] + sel_idx = np.random.randint(0, len(self.sm_compl_dict[ID])) + out = self.sm_compl_loader( + self.sm_compl_dict[ID][sel_idx][0], + self.sm_compl_dict[ID][sel_idx][2], + self.params + ) + return out + +class DistributedWeightedSampler(data.Sampler): + def __init__( + self, + dataset, + pdb_weights, + fb_weights, + compl_weights, + neg_weights, + na_compl_weights, + neg_na_compl_weights, + rna_weights, + sm_compl_weights, + num_example_per_epoch=25600, + fraction_fb=0.16, + fraction_compl=0.16, # half neg, half pos + fraction_na_compl=0.16, # half neg, half pos + fraction_rna=0.16, + fraction_sm_compl=0.16, + num_replicas=None, + rank=None, + replacement=False + ): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + + assert num_example_per_epoch % num_replicas == 0 + assert (fraction_fb+fraction_compl+fraction_na_compl+fraction_rna+fraction_sm_compl<= 1.0) + + self.dataset = dataset + self.num_replicas = num_replicas + self.num_fb_per_epoch = int(round(num_example_per_epoch*fraction_fb)) + self.num_compl_per_epoch = int(round(0.5*num_example_per_epoch*fraction_compl)) + 
self.num_neg_per_epoch = self.num_compl_per_epoch + self.num_na_compl_per_epoch = int(round(0.5*num_example_per_epoch*fraction_na_compl)) + self.num_neg_na_compl_per_epoch = self.num_na_compl_per_epoch + self.num_rna_per_epoch = int(round(num_example_per_epoch*fraction_rna)) + self.num_sm_compl_per_epoch = int(round(num_example_per_epoch*fraction_sm_compl)) + + self.num_pdb_per_epoch = num_example_per_epoch - ( + self.num_fb_per_epoch + + self.num_compl_per_epoch + + self.num_neg_per_epoch + + self.num_na_compl_per_epoch + + self.num_neg_na_compl_per_epoch + + self.num_rna_per_epoch + + self.num_sm_compl_per_epoch + ) + + if (rank==0): + print ( + "Per epoch:", + self.num_pdb_per_epoch,"pdb,", + self.num_fb_per_epoch,"fb,", + self.num_compl_per_epoch,"compl,", + self.num_neg_per_epoch,"neg,", + self.num_na_compl_per_epoch,"NA compl,", + self.num_neg_na_compl_per_epoch,"NA neg,", + self.num_rna_per_epoch,"RNA,", + self.num_sm_compl_per_epoch, "SM Compl." + ) + + + self.total_size = num_example_per_epoch + self.num_samples = self.total_size // self.num_replicas + self.rank = rank + self.epoch = 0 + self.replacement = replacement + + self.pdb_weights = pdb_weights + self.fb_weights = fb_weights + + self.compl_weights = compl_weights + self.neg_weights = neg_weights + + self.na_compl_weights = na_compl_weights + self.neg_na_compl_weights = neg_na_compl_weights + + self.rna_weights = rna_weights + self.sm_compl_weights = sm_compl_weights + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + + # get indices (fb + pdb models) + indices = torch.arange(len(self.dataset)) + + # weighted subsampling + # order: + # 0 - nfb-1 = FB + # nfb - nfb+npdb-1 = PDB + # "+npdb - "+ncmpl-1 = COMPLEX + # "+ncmpl - "+nneg-1 = COMPLEX NEGATIVES + # "+nneg - "+nna_cmpl-1 = NA COMPLEX + # "+nna_cmpl - "+nrna-1 = NA COMPLEX NEGATIVES + # "+nrna-1 - = RNA + sel_indices = torch.tensor((),dtype=int) + if 
(self.num_fb_per_epoch>0): + fb_sampled = torch.multinomial(self.fb_weights, self.num_fb_per_epoch, self.replacement, generator=g) + sel_indices = torch.cat((sel_indices, indices[fb_sampled])) + + if (self.num_pdb_per_epoch>0): + offset = len(self.dataset.fb_IDs) + pdb_sampled = torch.multinomial(self.pdb_weights, self.num_pdb_per_epoch, self.replacement, generator=g) + sel_indices = torch.cat((sel_indices, indices[pdb_sampled + offset])) + + if (self.num_compl_per_epoch>0): + offset = len(self.dataset.fb_IDs) + len(self.dataset.pdb_IDs) + compl_sampled = torch.multinomial(self.compl_weights, self.num_compl_per_epoch, self.replacement, generator=g) + sel_indices = torch.cat((sel_indices, indices[compl_sampled + offset])) + + if (self.num_neg_per_epoch>0): + offset = len(self.dataset.fb_IDs) + len(self.dataset.pdb_IDs) + len(self.dataset.compl_IDs) + neg_sampled = torch.multinomial(self.neg_weights, self.num_neg_per_epoch, self.replacement, generator=g) + sel_indices = torch.cat((sel_indices, indices[neg_sampled + offset])) + + if (self.num_na_compl_per_epoch>0): + offset = ( + len(self.dataset.fb_IDs) + + len(self.dataset.pdb_IDs) + + len(self.dataset.compl_IDs) + + len(self.dataset.neg_IDs) + ) + na_compl_sampled = torch.multinomial(self.na_compl_weights, self.num_na_compl_per_epoch, self.replacement, generator=g) + sel_indices = torch.cat((sel_indices, indices[na_compl_sampled + offset])) + + if (self.num_neg_na_compl_per_epoch>0): + offset = ( + len(self.dataset.fb_IDs) + + len(self.dataset.pdb_IDs) + + len(self.dataset.compl_IDs) + + len(self.dataset.neg_IDs) + + len(self.dataset.na_compl_IDs) + ) + neg_na_sampled = torch.multinomial(self.neg_na_compl_weights, self.num_neg_na_compl_per_epoch, self.replacement, generator=g) + sel_indices = torch.cat((sel_indices, indices[neg_na_sampled + offset])) + + if (self.num_rna_per_epoch>0): + offset = ( + len(self.dataset.fb_IDs) + + len(self.dataset.pdb_IDs) + + len(self.dataset.compl_IDs) + + len(self.dataset.neg_IDs) 
+ + len(self.dataset.na_compl_IDs) + + len(self.dataset.na_neg_IDs) + ) + rna_sampled = torch.multinomial(self.rna_weights, self.num_rna_per_epoch, self.replacement, generator=g) + sel_indices = torch.cat((sel_indices, indices[rna_sampled + offset])) + + if (self.num_sm_compl_per_epoch>0): + offset = ( + len(self.dataset.fb_IDs) + + len(self.dataset.pdb_IDs) + + len(self.dataset.compl_IDs) + + len(self.dataset.neg_IDs) + + len(self.dataset.na_compl_IDs) + + len(self.dataset.na_neg_IDs) + + len(self.dataset.rna_IDs) + ) + sm_compl_sampled = torch.multinomial(self.sm_compl_weights, self.num_sm_compl_per_epoch, self.replacement, generator=g) + sel_indices = torch.cat((sel_indices, indices[sm_compl_sampled + offset])) + + + # shuffle indices + indices = sel_indices[torch.randperm(len(sel_indices), generator=g)] + + # per each gpu + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices.tolist()) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch + diff --git a/RF2_allatom/eval.py b/RF2_allatom/eval.py new file mode 100644 index 0000000..77541e8 --- /dev/null +++ b/RF2_allatom/eval.py @@ -0,0 +1,342 @@ +import sys, os +import time +import numpy as np +import torch +import torch.nn as nn +from torch.utils import data +from parsers import parse_a3m, parse_fasta, read_template_pdb +from RoseTTAFoldModel import RoseTTAFoldModule +import util +from collections import namedtuple +from ffindex import * +from data_loader import MSAFeaturize, MSABlockDeletion, merge_a3m_homo +from kinematics import xyz_to_c6d, c6d_to_bins, xyz_to_t2d, get_init_xyz +from util_module import ComputeAllAtomCoords +from chemical import NTOTAL, NTOTALDOFS, NAATOKENS + +from memory import mem_report + +MAX_CYCLE = 30 +NREPLICATES = 5 +NBIN = [37, 37, 37, 19] + +MODEL_PARAM ={ + "n_extra_block" : 4, + "n_main_block" : 32, + "n_ref_block" : 4, + "d_msa" : 256 , + "d_pair" : 128, + 
"d_templ" : 64, + "n_head_msa" : 8, + "n_head_pair" : 4, + "n_head_templ" : 4, + "d_hidden" : 32, + "d_hidden_templ" : 64, + "p_drop" : 0.15, + "lj_lin" : 0.75 + } + +SE3_param = { + "num_layers" : 1, + "num_channels" : 32, + "num_degrees" : 2, + "l0_in_features": 64, + "l0_out_features": 64, + "l1_in_features": 3, + "l1_out_features": 2, + "num_edge_features": 64, + "div": 4, + "n_heads": 4 + } +SE3_ref_param = { + "num_layers" : 2, + "num_channels" : 32, + "num_degrees" : 2, + "l0_in_features": 64, + "l0_out_features": 64, + "l1_in_features": 3, + "l1_out_features": 2, + "num_edge_features": 64, + "div": 4, + "n_heads": 4 + } +MODEL_PARAM['SE3_param'] = SE3_param +MODEL_PARAM['SE3_ref_param'] = SE3_ref_param + + +# params for the folding protocol +fold_params = { + "SG7" : np.array([[[-2,3,6,7,6,3,-2]]])/21, + "SG9" : np.array([[[-21,14,39,54,59,54,39,14,-21]]])/231, + "DCUT" : 19.5, + "ALPHA" : 1.57, + + # TODO: add Cb to the motif + "NCAC" : np.array([[-0.676, -1.294, 0. ], + [ 0. , 0. , 0. ], + [ 1.5 , -0.174, 0. 
]], dtype=np.float32), + "CLASH" : 2.0, + "PCUT" : 0.5, + "DSTEP" : 0.5, + "ASTEP" : np.deg2rad(10.0), + "XYZRAD" : 7.5, + "WANG" : 0.1, + "WCST" : 0.1 +} + +fold_params["SG"] = fold_params["SG9"] + +# compute expected value from binned lddt +def lddt_unbin(pred_lddt): + nbin = pred_lddt.shape[1] + bin_step = 1.0 / nbin + lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device) + + pred_lddt = nn.Softmax(dim=1)(pred_lddt) + return torch.sum(lddt_bins[None,:,None]*pred_lddt, dim=1) + +class Predictor(): + def __init__(self, model_name="BFF", model_dir=None, device="cuda:0"): + if model_dir == None: + self.model_dir = "%s/models"%(os.path.dirname(os.path.abspath(__file__))) + else: + self.model_dir = model_dir + # + # define model name + self.model_name = model_name + self.device = device + self.active_fn = nn.Softmax(dim=1) + + # define model & load model + self.model = RoseTTAFoldModule( + **MODEL_PARAM, + aamask=util.allatom_mask.to(self.device), + ljlk_parameters=util.ljlk_parameters.to(self.device), + lj_correction_parameters=util.lj_correction_parameters.to(self.device), + num_bonds=util.num_bonds.to(self.device) + ).to(self.device) + + could_load = self.load_model(self.model_name) + if not could_load: + print ("ERROR: failed to load model") + sys.exit() + + self.compute_allatom_coords = ComputeAllAtomCoords().to(self.device) + + def load_model(self, model_name, suffix='last'): + chk_fn = "%s/%s_%s.pt"%(self.model_dir, model_name, suffix) + if not os.path.exists(chk_fn): + return False + checkpoint = torch.load(chk_fn, map_location=self.device) + self.model.load_state_dict(checkpoint['model_state_dict']) + return True + + def predict(self, fasta_fn, out_prefix, tmpl_fn=None, atab_fn=None, window=1e9, shift=50, n_latent=256, oligo=1): + msa_orig, ins_orig = parse_fasta(fasta_fn, rmsa_alphabet=True) + msa_orig = torch.tensor(msa_orig).long() + ins_orig = torch.tensor(ins_orig).long() + + #sel = torch.arange(941) #, 
msa_orig.shape[1]) + #msa_orig = msa_orig[:,sel] + #ins_orig = ins_orig[:,sel] + if (oligo>1): + msa_orig, ins_orig = merge_a3m_homo(msa_orig, ins_orig, oligo) # make unpaired alignments, for training, we always use two chains + + N, L = msa_orig.shape + # + if tmpl_fn and os.path.exists(tmpl_fn): + xyz_t, t1d = read_template_pdb(L, tmpl_fn) + #xyz_t, t1d = read_templates(L, ffdb, hhr_fn, atab_fn, n_templ=4) + else: + xyz_t = torch.full((1,L,3,3),np.nan).float() + t1d = torch.nn.functional.one_hot(torch.full((1, L), 20).long(), num_classes=NAATOKENS-1).float() # all gaps + t1d = torch.cat((t1d, torch.zeros((1,L,1)).float()), -1) + # + # template features + xyz_t = xyz_t.float().unsqueeze(0) + t1d = t1d.float().unsqueeze(0) + t2d = xyz_to_t2d(xyz_t) + + same_chain = torch.ones((1,L,L), dtype=torch.bool, device=xyz_t.device) + xyz_t = get_init_xyz(msa_orig[0:1],xyz_t,same_chain) # initialize coordinates with first template + + seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L) + + alpha, _, alpha_mask, _ = util.get_torsions( + xyz_t.reshape(-1,L,NTOTAL,3), + seq_tmp, + util.torsion_indices, + util.torsion_can_flip, + util.reference_angles + ) + alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0])) + alpha[torch.isnan(alpha)] = 0.0 + alpha = alpha.reshape(1,-1,L,NTOTALDOFS,2) + alpha_mask = alpha_mask.reshape(1,-1,L,NTOTALDOFS,1) + alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(1, -1, L, 3*NTOTALDOFS) + + self.model.eval() + for i_trial in range(NREPLICATES): + if os.path.exists("%s_%02d_init.pdb"%(out_prefix, i_trial)): + continue + self.run_prediction(msa_orig, ins_orig, t1d, t2d, xyz_t, xyz_t[:,0], alpha_t, "%s_%02d"%(out_prefix, i_trial), n_latent=n_latent) + torch.cuda.empty_cache() + + def run_prediction(self, msa_orig, ins_orig, t1d, t2d, xyz_t, xyz, alpha_t, out_prefix, n_latent=256): + start = time.time() + torch.cuda.reset_peak_memory_stats() + with torch.no_grad(): + # + msa = msa_orig.to(self.device) # (N, L) + ins = 
ins_orig.long().to(self.device) + #if msa_orig.shape[0] > 4096: + # msa, ins = MSABlockDeletion(msa, ins) + # + seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize( + msa, ins, p_mask=0.0, params={'MAXLAT': 128, 'MAXSEQ': 1024, 'MAXCYCLE': MAX_CYCLE}, tocpu=True) + _, N, L = msa_seed.shape[:3] + B = 1 + # + idx_pdb = torch.arange(L).long().view(1, L) + # + seq = seq.unsqueeze(0) + msa_seed = msa_seed.unsqueeze(0) + msa_extra = msa_extra.unsqueeze(0) + + t1d = t1d.to(self.device) + t2d = t2d.to(self.device) + idx_pdb = idx_pdb.to(self.device) + xyz_t = xyz_t.to(self.device) + alpha_t = alpha_t.to(self.device) + xyz = xyz.to(self.device) + + self.write_pdb(seq[0, -1], xyz[0], prefix="%s_templ"%(out_prefix)) + + msa_prev = None + pair_prev = None + alpha_prev = torch.zeros((1,L,NTOTALDOFS,2), device=seq.device) + xyz_prev=xyz + state_prev = None + + best_lddt = torch.tensor([-1.0], device=seq.device) + best_xyz = None + best_logit = None + best_aa = None + for i_cycle in range(MAX_CYCLE): + msa_seed_i = msa_seed[:,i_cycle].to(self.device) + msa_extra_i = msa_extra[:,i_cycle].to(self.device) + with torch.cuda.amp.autocast(True): + logit_s, logit_aa_s, init_crds, alpha_prev, _, pred_lddt_binned, msa_prev, pair_prev, state_prev = self.model( + msa_seed_i, + msa_extra_i, + seq[:,i_cycle], + seq[:,i_cycle], + xyz_prev, + alpha_prev, + idx_pdb, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev=msa_prev, + pair_prev=pair_prev, + state_prev=state_prev + ) + + logit_aa_s = logit_aa_s.reshape(B,-1,N,L)[:,:,0].permute(0,2,1) + + xyz_prev = init_crds[-1] + alpha_prev = alpha_prev[-1] + pred_lddt = lddt_unbin(pred_lddt_binned) + + print ("RECYCLE", i_cycle, pred_lddt.mean(), best_lddt.mean()) + _, all_crds = self.compute_allatom_coords(seq[:,i_cycle], init_crds[-1], alpha_prev) + #self.write_pdb(seq[0, -1], all_crds[0], Bfacts=pred_lddt[0], prefix="%s_cycle_%02d"%(out_prefix, i_cycle)) + + if pred_lddt.mean() < best_lddt.mean(): + continue + best_xyz 
= all_crds.clone() + best_logit = logit_s + best_aa = logit_aa_s + best_lddt = pred_lddt.clone() + #print (pred_lddt) + + prob_s = list() + for logit in logit_s: + prob = self.active_fn(logit.float()) # distogram + prob = prob.reshape(-1, L, L) #.permute(1,2,0).cpu().numpy() + prob_s.append(prob) + + end = time.time() + + for prob in prob_s: + prob += 1e-8 + prob = prob / torch.sum(prob, dim=0)[None] + self.write_pdb(seq[0, -1], best_xyz[0], Bfacts=100*best_lddt[0], prefix="%s_init"%(out_prefix)) + prob_s = [prob.permute(1,2,0).detach().cpu().numpy().astype(np.float16) for prob in prob_s] + np.savez_compressed("%s.npz"%(out_prefix), dist=prob_s[0].astype(np.float16), \ + omega=prob_s[1].astype(np.float16),\ + theta=prob_s[2].astype(np.float16),\ + phi=prob_s[3].astype(np.float16),\ + lddt=best_lddt[0].detach().cpu().numpy().astype(np.float16)) + + max_mem = torch.cuda.max_memory_allocated()/1e9 + print ("max mem", max_mem) + print ("runtime", end-start) + + + def write_pdb(self, seq, atoms, Bfacts=None, prefix=None): + L = len(seq) + filename = "%s.pdb"%prefix + ctr = 1 + with open(filename, 'wt') as f: + if Bfacts == None: + Bfacts = np.zeros(L) + else: + Bfacts = torch.clamp( Bfacts, 0, 100) + + for i,s in enumerate(seq): + if (len(atoms.shape)==2): + f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%( + "ATOM", ctr, " CA ", util.num2aa[s], + "A", i+1, atoms[i,0], atoms[i,1], atoms[i,2], + 1.0, Bfacts[i] ) ) + ctr += 1 + + elif atoms.shape[1]==3: + for j,atm_j in enumerate((" N "," CA "," C ")): + f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%( + "ATOM", ctr, atm_j, util.num2aa[s], + "A", i+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2], + 1.0, Bfacts[i] ) ) + ctr += 1 + else: + atms = util.aa2long[s] + for j,atm_j in enumerate(atms): + if (atm_j is not None): + f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%( + "ATOM", ctr, atm_j, util.num2aa[s], + "A", i+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2], + 1.0, Bfacts[i] ) ) + ctr 
+= 1 + +def get_args(): + #DB="/home/robetta/rosetta_server_beta/external/databases/trRosetta/pdb100_2021Mar03/pdb100_2021Mar03" + DB = "/projects/ml/TrRosetta/pdb100_2020Mar11/pdb100_2020Mar11" + import argparse + parser = argparse.ArgumentParser(description="RoseTTAFold: Protein structure prediction with 3-track attentions on 1D, 2D, and 3D features") + parser.add_argument("fasta", help="fasta for structure prediction") + parser.add_argument("-oligo", type=int, default=1) + parser.add_argument("-prefix", type=str, default="pred") + parser.add_argument("-tmpl", default=None) + parser.add_argument("-model_name", default="BFF", required=False, + help="Prefix for model. The model under models/[model_name]_best.pt will be used. [BFF]") + args = parser.parse_args() + return args + +if __name__ == "__main__": + args = get_args() + + pred = Predictor(model_name=args.model_name) + + pred.predict(args.fasta, args.prefix, args.tmpl, oligo=args.oligo) diff --git a/RF2_allatom/eval_fb.py b/RF2_allatom/eval_fb.py new file mode 100644 index 0000000..6470caf --- /dev/null +++ b/RF2_allatom/eval_fb.py @@ -0,0 +1,395 @@ +import sys, os +import time +import numpy as np +import torch +import torch.nn as nn +from torch.utils import data +#from parsers import parse_a3m, read_templates +from RoseTTAFoldModel import RoseTTAFoldModule +import util +from collections import namedtuple +#from ffindex import * +from data_loader import * +from kinematics import xyz_to_c6d, c6d_to_bins2, xyz_to_t2d, get_init_xyz +from util_module import ComputeAllAtomCoords +from loss import * + +MAX_CYCLE = 4 +NBIN = [37, 37, 37, 19] + +MODEL_PARAM ={ + "n_extra_block" : 4, + "n_main_block" : 32, + "n_ref_block" : 0, + "n_finetune_block" : 4, + "d_msa" : 256 , + "d_pair" : 128, + "d_templ" : 64, + "n_head_msa" : 8, + "n_head_pair" : 4, + "n_head_templ" : 4, + "d_hidden" : 32, + "d_hidden_templ" : 64, + "p_drop" : 0.0, + "lj_lin" : 0.6 +} + +SE3_param = { + "num_layers" : 1, + "num_channels" : 32, + 
"num_degrees" : 2, + "l0_in_features": 64, + "l0_out_features": 64, + "l1_in_features": 3, + "l1_out_features": 2, + "num_edge_features": 64, + "div": 4, + "n_heads": 4 +} +MODEL_PARAM['SE3_param'] = SE3_param + +LOAD_PARAM = {'shuffle': False, + 'num_workers': 4, + 'pin_memory': True} + +fb_dir = "/projects/ml/TrRosetta/fb_af" +base_dir = "/projects/ml/TrRosetta/PDB30-20FEB17" + +# params for the folding protocol +fold_params = { + "SG7" : np.array([[[-2,3,6,7,6,3,-2]]])/21, + "SG9" : np.array([[[-21,14,39,54,59,54,39,14,-21]]])/231, + "DCUT" : 19.5, + "ALPHA" : 1.57, + + # TODO: add Cb to the motif + "NCAC" : np.array([[-0.676, -1.294, 0. ], + [ 0. , 0. , 0. ], + [ 1.5 , -0.174, 0. ]], dtype=np.float32), + "CLASH" : 2.0, + "PCUT" : 0.5, + "DSTEP" : 0.5, + "ASTEP" : np.deg2rad(10.0), + "XYZRAD" : 7.5, + "WANG" : 0.1, + "WCST" : 0.1 +} + +fold_params["SG"] = fold_params["SG9"] + +# compute expected value from binned lddt +def lddt_unbin(pred_lddt): + nbin = pred_lddt.shape[1] + bin_step = 1.0 / nbin + lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device) + + pred_lddt = nn.Softmax(dim=1)(pred_lddt) + return torch.sum(lddt_bins[None,:,None]*pred_lddt, dim=1) + +class Predictor(): + def __init__(self, model_name="BFF", model_dir=None, device="cuda:0"): + if model_dir == None: + self.model_dir = "%s/models"%(os.path.dirname(os.path.abspath(__file__))) + else: + self.model_dir = model_dir + # + # define model name + self.model_name = model_name + self.device = device + self.active_fn = nn.Softmax(dim=1) + + self.aamask = util.allatom_mask.to(self.device) + self.atom_type_index = util.atom_type_index.to(self.device) + self.ljlk_parameters = util.ljlk_parameters.to(self.device) + self.lj_correction_parameters = util.lj_correction_parameters.to(self.device) + self.num_bonds = util.num_bonds.to(self.device) + self.cb_len = util.cb_length_t.to(self.device) + self.cb_ang = util.cb_angle_t.to(self.device) + self.cb_tor = 
util.cb_torsion_t.to(self.device) + + # define model & load model + self.model = RoseTTAFoldModule( + **MODEL_PARAM, + aamask=self.aamask, + atom_type_index = self.atom_type_index, + ljlk_parameters = self.ljlk_parameters, + lj_correction_parameters = self.lj_correction_parameters, + num_bonds = self.num_bonds, + cb_len = self.cb_len, + cb_ang = self.cb_ang, + cb_tor = self.cb_tor + ).to(self.device) + + could_load = self.load_model(self.model_name) + if not could_load: + print ("ERROR: failed to load model") + sys.exit() + + self.compute_allatom_coords = ComputeAllAtomCoords().to(self.device) + + self.ti_dev = util.torsion_indices.to(self.device) + self.ti_flip = util.torsion_can_flip.to(self.device) + self.ang_ref = util.reference_angles.to(self.device) + self.l2a = util.long2alt.to(self.device) + self.aamask = util.allatom_mask.to(self.device) + + def load_model(self, model_name, suffix='last'): + chk_fn = "%s/%s_%s.pt"%(self.model_dir, model_name, suffix) + if not os.path.exists(chk_fn): + return False + checkpoint = torch.load(chk_fn, map_location=self.device) + self.model.load_state_dict(checkpoint['model_state_dict']) + return True + + def run_prediction(self, seq, msa_seed, msa_extra, true_crds, res_mask, atom_mask, idx_pdb, xyz_t, t1d, alpha_t, tag): + self.model.eval() + with torch.no_grad(): + # transfer inputs to device + B, _, N, L, _ = msa_seed.shape + + idx_pdb = idx_pdb.to(self.device, non_blocking=True) # (B, L) + true_crds = true_crds.to(self.device, non_blocking=True) # (B, L, 27, 3) + res_mask = res_mask.to(self.device, non_blocking=True) # (B, L) + atom_mask = atom_mask.to(self.device, non_blocking=True) # (B, L, 27) + + xyz_t = xyz_t.to(self.device, non_blocking=True) + t1d = t1d.to(self.device, non_blocking=True) + alpha_t = alpha_t.to(self.device, non_blocking=True) + + seq = seq.to(self.device, non_blocking=True) + msa_seed = msa_seed.to(self.device, non_blocking=True) + msa_extra = msa_extra.to(self.device, non_blocking=True) + + # 
processing labels & template features + c6d, _ = xyz_to_c6d(true_crds) + c6d = c6d_to_bins2(c6d) + t2d = xyz_to_t2d(xyz_t) + xyz_t = get_init_xyz(xyz_t) + xyz_prev = xyz_t[:,0] + + # set number of recycles + msa_prev = None + pair_prev = None + alpha_prev = torch.zeros((1,L,10,2)).to(self.device, non_blocking=True) #fd we could get this from the template... + state_prev = None + + best_lddt = torch.tensor([-1.0], device=seq.device) + best_xyz = None + best_logit = None + best_aa = None + for i_cycle in range(MAX_CYCLE): + with torch.cuda.amp.autocast(True): + logit_s, logit_aa_s, init_crds, alpha_prev, init_allatom, pred_lddt_binned, msa_prev, pair_prev, state_prev = self.model( + msa_seed[:,i_cycle], + msa_extra[:,i_cycle], + seq[:,i_cycle], + xyz_prev, + alpha_prev, + idx_pdb, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev=msa_prev, + pair_prev=pair_prev, + state_prev=state_prev + ) + + logit_aa_s = logit_aa_s.reshape(B,-1,N,L)[:,:,0].permute(0,2,1) + + #xyz_prev = init_crds[-1] + xyz_prev = init_allatom[-1].unsqueeze(0) + #msa_prev = msa_prev[:,0] + alpha_prev = alpha_prev[-1] + pred_lddt = lddt_unbin(pred_lddt_binned) + + #print ("RECYCLE", i_cycle, pred_lddt.mean(), best_lddt.mean()) + #_, all_crds = self.compute_allatom_coords(seq[:,i_cycle], xyz_prev, alpha_prev) + #self.write_pdb(seq[0, -1], all_crds[0], Bfacts=pred_lddt[0], prefix="%s_cycle_%02d"%(out_prefix, i_cycle)) + + if pred_lddt.mean() < best_lddt.mean(): + continue + best_xyz = init_allatom[-1].clone() + best_logit = logit_s + best_aa = logit_aa_s + best_lddt = pred_lddt.clone() + + # lddt to native + seq = seq[:,0] + res_mask = res_mask[0] + true_tors, true_tors_alt, tors_mask, tors_planar = util.get_torsions( + true_crds, seq, self.ti_dev, self.ti_flip, self.ang_ref, mask_in=atom_mask) + + # get alternative coordinates for ground-truth + true_alt = torch.zeros_like(true_crds) + true_alt.scatter_(2, self.l2a[seq,:,None].repeat(1,1,1,3), true_crds) + natRs_all, _n0 = 
self.compute_allatom_coords(seq, true_crds[...,:3,:], true_tors) + natRs_all_alt, _n1 = self.compute_allatom_coords(seq, true_alt[...,:3,:], true_tors_alt) + + # - resolve symmetry + xs_mask = self.model.simulator.aamask[seq] # (B, L, 27) + xs_mask[0,:,14:]=False # (ignore hydrogens except lj loss) + xs_mask *= atom_mask # mask missing atoms & residues as well + natRs_all_symm, nat_symm = resolve_symmetry(best_xyz, natRs_all[0], true_crds[0], natRs_all_alt[0], true_alt[0], xs_mask[0]) + + atom_mask_trim = atom_mask[0,res_mask] + true_lddt = calc_allatom_lddt(best_xyz.unsqueeze(0), nat_symm, idx_pdb, atom_mask) + + ljE = calc_lj( + seq[0], init_allatom, + self.aamask, + self.ljlk_parameters, + self.lj_correction_parameters, + self.num_bonds + ) + cbE = calc_cart_bonded( + seq, init_allatom, idx_pdb, self.cb_len, self.cb_ang, self.cb_tor) + + print (tag[0],tag[1],true_lddt.mean().cpu().numpy(), ljE.cpu().numpy(), cbE.cpu().numpy()) + self.write_pdb(seq[0], best_xyz, Bfacts=best_lddt[0], prefix="preds/%s_pred"%(tag[0])) + + #if (true_lddt.mean()<0.9 or true_lddt.mean()>0.97): + # self.write_pdb(seq[0, -1], all_crds[0], Bfacts=best_lddt[0], prefix="%s_pred"%(tag[1])) + # self.write_pdb(seq[0, -1], true_crds[0,...,:14,:], Bfacts=best_lddt[0], prefix="%s_native"%(tag[1])) + + #prob_s = list() + #for logit in logit_s: + # prob = self.active_fn(logit.float()) # distogram + # prob = prob.reshape(-1, L, L) #.permute(1,2,0).cpu().numpy() + # prob_s.append(prob) + #for prob in prob_s: + # prob += 1e-8 + # prob = prob / torch.sum(prob, dim=0)[None] + #self.write_pdb(seq[0, -1], best_xyz[0], Bfacts=best_lddt[0], prefix="%s_init"%(out_prefix)) + #prob_s = [prob.permute(1,2,0).detach().cpu().numpy().astype(np.float16) for prob in prob_s] + #np.savez_compressed("%s.npz"%(out_prefix), dist=prob_s[0].astype(np.float16), \ + # omega=prob_s[1].astype(np.float16),\ + # theta=prob_s[2].astype(np.float16),\ + # phi=prob_s[3].astype(np.float16),\ + # 
lddt=best_lddt[0].detach().cpu().numpy().astype(np.float16)) + + + def write_pdb(self, seq, atoms, Bfacts=None, prefix=None): + L = len(seq) + filename = "%s.pdb"%prefix + ctr = 1 + with open(filename, 'wt') as f: + if Bfacts == None: + Bfacts = np.zeros(L) + else: + Bfacts = torch.clamp( Bfacts, 0, 1) + + for i,s in enumerate(seq): + if (len(atoms.shape)==2): + f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%( + "ATOM", ctr, " CA ", util.num2aa[s], + "A", i+1, atoms[i,0], atoms[i,1], atoms[i,2], + 1.0, Bfacts[i] ) ) + ctr += 1 + + elif atoms.shape[1]==3: + for j,atm_j in enumerate((" N "," CA "," C ")): + f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%( + "ATOM", ctr, atm_j, util.num2aa[s], + "A", i+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2], + 1.0, Bfacts[i] ) ) + ctr += 1 + else: + natoms = atoms.shape[1] + atms = util.aa2long[s] + # his prot hack + if (s==8 and torch.linalg.norm( atoms[i,9,:]-atoms[i,5,:] ) < 1.7): + atms = ( + " N "," CA "," C "," O "," CB "," CG "," NE2"," CD2"," CE1"," ND1", + None, None, None, None," H "," HA ","1HB ","2HB "," HD2"," HE1", + " HD1", None, None, None, None, None, None) # his_d + + for j,atm_j in enumerate(atms): + if (j 85.0 and + len(r[-1].strip()) > 200] + fb = {} + for r in rows: + if r[2] in fb.keys(): + fb[r[2]].append((r[:2], r[-1])) + else: + fb[r[2]] = [(r[:2], r[-1])] + + for i, (id_i, key_i) in enumerate(fb.items()): + if (i%args.j != args.i%args.j): + continue + + item = key_i[0][0] + a3m = get_msa(os.path.join(LOADER_PARAMS["FB_DIR"], "a3m", item[-1][:2], item[-1][2:], item[0]+".a3m.gz"), item[0]) + pdb = get_pdb(os.path.join(LOADER_PARAMS["FB_DIR"], "pdb", item[-1][:2], item[-1][2:], item[0]+".pdb"), + os.path.join(LOADER_PARAMS["FB_DIR"], "pdb", item[-1][:2], item[-1][2:], item[0]+".plddt.npy"), + item[0], LOADER_PARAMS['PLDDTCUT'], LOADER_PARAMS['SCCUT']) + idx = pdb['idx'] + l = a3m['msa'].shape[-1] + + msa = a3m['msa'].long() + ins = a3m['ins'].long() + #if len(msa) > 5: + # msa, 
ins = MSABlockDeletion(msa, ins) + seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(msa, ins, LOADER_PARAMS, p_mask=0.0) + + # No templates + xyz_t = torch.full((1,l,27,3),np.nan).float() + alpha_t = torch.full((1,l,30),0.0).float() + f1d_t = torch.nn.functional.one_hot(torch.full((1, l), 20).long(), num_classes=21).float() # all gaps + f1d_t = torch.cat((f1d_t, torch.zeros((1,l,1)).float()), -1) + + true_crds = torch.full((len(idx),27,3),np.nan).float() + true_crds[:,:,:] = pdb['xyz'] + mask_atoms = torch.zeros((len(idx),27), dtype=torch.bool) + mask_atoms[:,:] = pdb['mask'] + mask_res = mask_atoms.sum(dim=-1)>=3 + + idx_pdb = torch.arange(l).long() + + pred.run_prediction( + seq[None,...], msa_seed[None,...], msa_extra[None,...], + true_crds[None,...], mask_res[None,...], mask_atoms[None,...], + idx_pdb[None,...], xyz_t[None,...], f1d_t[None,...], alpha_t[None,...], + item) + diff --git a/RF2_allatom/eval_model1.py b/RF2_allatom/eval_model1.py new file mode 100644 index 0000000..95f696b --- /dev/null +++ b/RF2_allatom/eval_model1.py @@ -0,0 +1,51 @@ +import os +import numpy as np + +def get_lddt(fn): + data = np.load(fn) + return data['lddt'].mean() + #lddt = list() + #with open(fn) as fp: + # for line in fp: + # if not line.startswith("ATOM"): + # continue + # if line[12:16].strip() == "CA": + # lddt.append(float(line[61:66])) + #return np.mean(lddt) + +ans_home = "/projects/casp/CASP14/eval/official/answers/domainwise" +dom_s = [line.split() for line in open("/projects/casp/CASP14/eval/official/difficulty.domains")] +if not os.path.exists("model1"): + os.mkdir("model1") + +print ("#DomID diff TM TS LDDT") +for domID, diff in dom_s: + tar = domID.split('-')[0] + if not os.path.exists("%s/%s_00_init.pdb"%(tar, tar)): + continue + TM_s = list() + TS_s = list() + lddt_s = list() + pdb_s = list() + for i_iter in range(5): + pdb_fn = "%s/%s_%02d_init.pdb"%(tar, tar, i_iter) + if not os.path.exists(pdb_fn): + continue + npz_fn = "%s/%s_%02d.npz"%(tar, 
tar, i_iter) + lddt = get_lddt(npz_fn) + lddt_s.append(lddt) + pdb_s.append(pdb_fn) + max_idx = np.argmax(lddt_s) + pdb_fn = pdb_s[max_idx] + lines = os.popen("TMscore %s %s/%s.pdb"%(pdb_fn, ans_home, domID)).readlines() + TM = 0.0 + TS = 0.0 + for line in lines: + if line.startswith("TM-score"): + TM = float(line.split()[2]) + elif line.startswith("GDT-TS"): + TS = float(line.split()[1]) + break + lddt = os.popen("lddt %s %s/%s.pdb | grep Glob"%(pdb_fn, ans_home, domID)).readlines()[-1].split()[-1] + print (domID, diff, TM, TS, lddt) + os.system("cp %s model1/%s_pred.pdb"%(pdb_fn, tar)) diff --git a/RF2_allatom/ffindex.py b/RF2_allatom/ffindex.py new file mode 100644 index 0000000..3484eee --- /dev/null +++ b/RF2_allatom/ffindex.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# https://raw.githubusercontent.com/ahcm/ffindex/master/python/ffindex.py + +''' +Created on Apr 30, 2014 + +@author: meiermark +''' + + +import sys +import mmap +from collections import namedtuple + +FFindexEntry = namedtuple("FFindexEntry", "name, offset, length") + + +def read_index(ffindex_filename): + entries = [] + + fh = open(ffindex_filename) + for line in fh: + tokens = line.split("\t") + entries.append(FFindexEntry(tokens[0], int(tokens[1]), int(tokens[2]))) + fh.close() + + return entries + + +def read_data(ffdata_filename): + fh = open(ffdata_filename, "r+b") + data = mmap.mmap(fh.fileno(), 0) + fh.close() + return data + + +def get_entry_by_name(name, index): + #TODO: bsearch + for entry in index: + if(name == entry.name): + return entry + return None + + +def read_entry_lines(entry, data): + lines = data[entry.offset:entry.offset + entry.length - 1].decode("utf-8").split("\n") + return lines + + +def read_entry_data(entry, data): + return data[entry.offset:entry.offset + entry.length - 1] + + +def write_entry(entries, data_fh, entry_name, offset, data): + data_fh.write(data[:-1]) + data_fh.write(bytearray(1)) + + entry = FFindexEntry(entry_name, offset, len(data)) + 
entries.append(entry) + + return offset + len(data) + + +def write_entry_with_file(entries, data_fh, entry_name, offset, file_name): + with open(file_name, "rb") as fh: + data = bytearray(fh.read()) + return write_entry(entries, data_fh, entry_name, offset, data) + + +def finish_db(entries, ffindex_filename, data_fh): + data_fh.close() + write_entries_to_db(entries, ffindex_filename) + + +def write_entries_to_db(entries, ffindex_filename): + sorted(entries, key=lambda x: x.name) + index_fh = open(ffindex_filename, "w") + + for entry in entries: + index_fh.write("{name:.64}\t{offset}\t{length}\n".format(name=entry.name, offset=entry.offset, length=entry.length)) + + index_fh.close() + + +def write_entry_to_file(entry, data, file): + lines = read_lines(entry, data) + + fh = open(file, "w") + for line in lines: + fh.write(line+"\n") + fh.close() diff --git a/RF2_allatom/kinematics.py b/RF2_allatom/kinematics.py new file mode 100644 index 0000000..13d92cc --- /dev/null +++ b/RF2_allatom/kinematics.py @@ -0,0 +1,266 @@ +import numpy as np +import torch +from util import INIT_CRDS, INIT_NA_CRDS, generate_Cbeta, is_nucleic +from chemical import NTOTAL + +PARAMS = { + "DMIN" : 2.0, + "DMAX" : 20.0, + "DBINS" : 36, + "ABINS" : 36, +} + +# ============================================================ +def get_pair_dist(a, b): + """calculate pair distances between two sets of points + + Parameters + ---------- + a,b : pytorch tensors of shape [batch,nres,3] + store Cartesian coordinates of two sets of atoms + Returns + ------- + dist : pytorch tensor of shape [batch,nres,nres] + stores paitwise distances between atoms in a and b + """ + + dist = torch.cdist(a, b, p=2) + return dist + +# ============================================================ +def get_ang(a, b, c, eps=1e-6): + """calculate planar angles for all consecutive triples (a[i],b[i],c[i]) + from Cartesian coordinates of three sets of atoms a,b,c + + Parameters + ---------- + a,b,c : pytorch tensors of shape 
[batch,nres,3] + store Cartesian coordinates of three sets of atoms + Returns + ------- + ang : pytorch tensor of shape [batch,nres] + stores resulting planar angles + """ + v = a - b + w = c - b + vn = v / (torch.norm(v, dim=-1, keepdim=True)+eps) + wn = w / (torch.norm(w, dim=-1, keepdim=True)+eps) + vw = torch.sum(vn*wn, dim=-1) + + return torch.acos(torch.clamp(vw,-0.999,0.999)) + +# ============================================================ +def get_dih(a, b, c, d, eps=1e-6): + """calculate dihedral angles for all consecutive quadruples (a[i],b[i],c[i],d[i]) + given Cartesian coordinates of four sets of atoms a,b,c,d + + Parameters + ---------- + a,b,c,d : pytorch tensors of shape [batch,nres,3] + store Cartesian coordinates of four sets of atoms + Returns + ------- + dih : pytorch tensor of shape [batch,nres] + stores resulting dihedrals + """ + b0 = a - b + b1 = c - b + b2 = d - c + + b1n = b1 / (torch.norm(b1, dim=-1, keepdim=True) + eps) + + v = b0 - torch.sum(b0*b1n, dim=-1, keepdim=True)*b1n + w = b2 - torch.sum(b2*b1n, dim=-1, keepdim=True)*b1n + + x = torch.sum(v*w, dim=-1) + y = torch.sum(torch.cross(b1n,v,dim=-1)*w, dim=-1) + + return torch.atan2(y+eps, x+eps) + + +# ============================================================ +def xyz_to_c6d(xyz, params=PARAMS): + """convert cartesian coordinates into 2d distance + and orientation maps + + Parameters + ---------- + xyz : pytorch tensor of shape [batch,nres,3,3] + stores Cartesian coordinates of backbone N,Ca,C atoms + Returns + ------- + c6d : pytorch tensor of shape [batch,nres,nres,4] + stores stacked dist,omega,theta,phi 2D maps + """ + + batch = xyz.shape[0] + nres = xyz.shape[1] + + # three anchor atoms + N = xyz[:,:,0] + Ca = xyz[:,:,1] + C = xyz[:,:,2] + + # recreate Cb given N,Ca,C + Cb = generate_Cbeta(N,Ca,C) + + # 6d coordinates order: (dist,omega,theta,phi) + c6d = torch.zeros([batch,nres,nres,4],dtype=xyz.dtype,device=xyz.device) + + dist = get_pair_dist(Cb,Cb) + 
dist[torch.isnan(dist)] = 999.9 + c6d[...,0] = dist + 999.9*torch.eye(nres,device=xyz.device)[None,...] + b,i,j = torch.where(c6d[...,0]=params['DMAX']] = 999.9 + + mask = torch.zeros((batch, nres,nres), dtype=xyz.dtype, device=xyz.device) + mask[b,i,j] = 1.0 + return c6d, mask + +def xyz_to_t2d(xyz_t, params=PARAMS): + """convert template cartesian coordinates into 2d distance + and orientation maps + + Parameters + ---------- + xyz_t : pytorch tensor of shape [batch,templ,nres,3,3] + stores Cartesian coordinates of template backbone N,Ca,C atoms + + Returns + ------- + t2d : pytorch tensor of shape [batch,nres,nres,37+6+3] + stores stacked dist,omega,theta,phi 2D maps + """ + B, T, L = xyz_t.shape[:3] + c6d, mask = xyz_to_c6d(xyz_t[:,:,:,:3].view(B*T,L,3,3), params=params) + c6d = c6d.view(B, T, L, L, 4) + mask = mask.view(B, T, L, L, 1) + # + # dist to one-hot encoded + dist = dist_to_onehot(c6d[...,0], params) + orien = torch.cat((torch.sin(c6d[...,1:]), torch.cos(c6d[...,1:])), dim=-1)*mask # (B, T, L, L, 6) + # + mask = torch.isnan(c6d[:,:,:,:,0]) # (B, T, L, L) + t2d = torch.cat((dist, orien, mask.unsqueeze(-1)), dim=-1) + t2d[torch.isnan(t2d)] = 0.0 + return t2d + +def xyz_to_bbtor(xyz, params=PARAMS): + batch = xyz.shape[0] + nres = xyz.shape[1] + + # three anchor atoms + N = xyz[:,:,0] + Ca = xyz[:,:,1] + C = xyz[:,:,2] + + # recreate Cb given N,Ca,C + next_N = torch.roll(N, -1, dims=1) + prev_C = torch.roll(C, 1, dims=1) + phi = get_dih(prev_C, N, Ca, C) + psi = get_dih(N, Ca, C, next_N) + # + phi[:,0] = 0.0 + psi[:,-1] = 0.0 + # + astep = 2.0*np.pi / params['ABINS'] + phi_bin = torch.round((phi+np.pi-astep/2)/astep) + psi_bin = torch.round((psi+np.pi-astep/2)/astep) + return torch.stack([phi_bin, psi_bin], axis=-1).long() + +# ============================================================ +def dist_to_onehot(dist, params=PARAMS): + dist[torch.isnan(dist)] = 999.9 + dstep = (params['DMAX'] - params['DMIN']) / params['DBINS'] + dbins = 
torch.linspace(params['DMIN']+dstep, params['DMAX'], params['DBINS'],dtype=dist.dtype,device=dist.device) + db = torch.bucketize(dist.contiguous(),dbins).long() + dist = torch.nn.functional.one_hot(db, num_classes=params['DBINS']+1).float() + return dist + +# ============================================================ +def dist_to_bins(dist,params=PARAMS): + """bin 2d distance maps + """ + + dstep = (params['DMAX'] - params['DMIN']) / params['DBINS'] + db = torch.round((dist-params['DMIN']-dstep/2)/dstep) + + db[db<0] = 0 + db[db>params['DBINS']] = params['DBINS'] + + return db.long() + + +# ============================================================ +def c6d_to_bins(c6d, same_chain, negative=False, params=PARAMS): + """bin 2d distance and orientation maps + """ + + dstep = (params['DMAX'] - params['DMIN']) / params['DBINS'] + astep = 2.0*np.pi / params['ABINS'] + + db = torch.round((c6d[...,0]-params['DMIN']-dstep/2)/dstep) + ob = torch.round((c6d[...,1]+np.pi-astep/2)/astep) + tb = torch.round((c6d[...,2]+np.pi-astep/2)/astep) + pb = torch.round((c6d[...,3]-astep/2)/astep) + + # put all dparams['DBINS']] = params['DBINS'] + ob[db==params['DBINS']] = params['ABINS'] + tb[db==params['DBINS']] = params['ABINS'] + pb[db==params['DBINS']] = params['ABINS']//2 + + if negative: + db = torch.where(same_chain.bool(), db.long(), params['DBINS']) + ob = torch.where(same_chain.bool(), ob.long(), params['ABINS']) + tb = torch.where(same_chain.bool(), tb.long(), params['ABINS']) + pb = torch.where(same_chain.bool(), pb.long(), params['ABINS']//2) + + return torch.stack([db,ob,tb,pb],axis=-1).long() + +def get_init_xyz(seq, xyz_t, same_chain): + # input: xyz_t (B, T, L, Natms, 3) + # ouput: xyz (B, T, L, Natms, 3) + B, T, L = xyz_t.shape[:3] + + init = torch.full((B,T,L,NTOTAL,3), np.nan, device=xyz_t.device) + na_mask = is_nucleic(seq) + b_nmask,l_nmask = na_mask.nonzero(as_tuple=True) + b_pmask,l_pmask = (~na_mask).nonzero(as_tuple=True) + + init[b_pmask,:,l_pmask] = 
INIT_CRDS[None,...].to(xyz_t.device) + init[b_nmask,:,l_nmask] = INIT_NA_CRDS[None,...].to(xyz_t.device) + + if torch.isnan(xyz_t).all(): + return init + + mask = torch.isnan(xyz_t[:,:,:,:3]).any(dim=-1).any(dim=-1) # (B, T, L) + # + center_CA = ((~mask[:,:,:,None]) * torch.nan_to_num(xyz_t[:,:,:,1,:])).sum(dim=2) / ((~mask[:,:,:,None]).sum(dim=2)+1e-4) # (B, T, 3) + xyz_t = xyz_t - center_CA.view(B,T,1,1,3) + # + idx_s = list() + for i_b in range(B): + for i_T in range(T): + if mask[i_b, i_T].all(): + continue + exist_in_templ = torch.where(~mask[i_b, i_T])[0] # (L_sub) + is_same_chain_in_templ = same_chain[i_b][:,~mask[i_b,i_T]].bool() # (L, L_sub) + seqmap = (torch.arange(L, device=xyz_t.device)[:,None] - exist_in_templ[None,:]).abs() # (L, L_sub) + seqmap[~is_same_chain_in_templ] += 99999 + seqmap = torch.argmin(seqmap, dim=-1) # (L) + idx = torch.gather(exist_in_templ, -1, seqmap) # (L) + offset_CA = torch.gather(xyz_t[i_b, i_T, :, 1, :], 0, idx.reshape(L,1).expand(-1,3)) + init[i_b,i_T] += offset_CA.reshape(L,1,3) + # + xyz = torch.where(mask.view(B, T, L, 1, 1), init, xyz_t) + return xyz diff --git a/RF2_allatom/loss.py b/RF2_allatom/loss.py new file mode 100644 index 0000000..3553f82 --- /dev/null +++ b/RF2_allatom/loss.py @@ -0,0 +1,812 @@ +import torch +import numpy as np + +from util import ( + rigid_from_3_points, + cb_lengths_CN, + cb_angles_CACN, + cb_angles_CNCA, + cb_torsions_CACNH, + cb_torsions_CANCO, + is_nucleic +) +from chemical import NFRAMES + +from kinematics import get_dih, get_ang +from scoring import HbHybType + +# Loss functions for the training +# 1. BB rmsd loss +# 2. distance loss (or 6D loss?) +# 3. bond geometry loss +# 4. 
predicted lddt loss + +#fd use improved coordinate frame generation +def get_t(N, Ca, C, eps=1e-5): + I,B,L=N.shape[:3] + Rs,Ts = rigid_from_3_points(N.view(I*B,L,3), Ca.view(I*B,L,3), C.view(I*B,L,3), eps=eps) + Rs = Rs.view(I,B,L,3,3) + Ts = Ts.view(I,B,L,3) + t = Ts.unsqueeze(-2) - Ts.unsqueeze(-3) + return torch.einsum('iblkj, iblmk -> iblmj', Rs, t) # (I,B,L,L,3) **fixed + +def calc_str_loss(pred, true, mask_2d, same_chain, negative=False, d_clamp_intra=10.0, d_clamp_inter=30.0, A=10.0, gamma=0.99, eps=1e-6): + ''' + Calculate Backbone FAPE loss + Input: + - pred: predicted coordinates (I, B, L, n_atom, 3) + - true: true coordinates (B, L, n_atom, 3) + Output: str loss + ''' + I = pred.shape[0] + true = true.unsqueeze(0) + t_tilde_ij = get_t(true[:,:,:,0], true[:,:,:,1], true[:,:,:,2]) + t_ij = get_t(pred[:,:,:,0], pred[:,:,:,1], pred[:,:,:,2]) + + difference = torch.sqrt(torch.square(t_tilde_ij-t_ij).sum(dim=-1) + eps) + clamp = torch.zeros_like(difference) + clamp[:,same_chain==1] = d_clamp_intra + clamp[:,same_chain==0] = d_clamp_inter + difference = torch.clamp(difference, max=clamp) + loss = difference / A # (I, B, L, L) + + # Get a mask information (ignore missing residue + inter-chain residues) + # for positive cases, mask = mask_2d + # for negative cases (non-interacting pairs) mask = mask_2d*same_chain + if negative: + mask = mask_2d * same_chain + else: + mask = mask_2d + # calculate masked loss (ignore missing regions when calculate loss) + loss = (mask[None]*loss).sum(dim=(1,2,3)) / (mask.sum()+eps) # (I) + + # weighting loss + w_loss = torch.pow(torch.full((I,), gamma, device=pred.device), torch.arange(I, device=pred.device)) + w_loss = torch.flip(w_loss, (0,)) + w_loss = w_loss / w_loss.sum() + + tot_loss = (w_loss * loss).sum() + return tot_loss, loss.detach() + +#resolve rotationally equivalent sidechains +def resolve_symmetry(xs, Rsnat_all, xsnat, Rsnat_all_alt, xsnat_alt, atm_mask): + dists = torch.linalg.norm( xs[:,:,None,:] - 
xs[atm_mask,:][None,None,:,:], dim=-1) + dists_nat = torch.linalg.norm( xsnat[:,:,None,:] - xsnat[atm_mask,:][None,None,:,:], dim=-1) + dists_natalt = torch.linalg.norm( xsnat_alt[:,:,None,:] - xsnat_alt[atm_mask,:][None,None,:,:], dim=-1) + + drms_nat = torch.sum(torch.abs(dists_nat-dists),dim=(-1,-2)) + drms_natalt = torch.sum(torch.abs(dists_nat-dists_natalt), dim=(-1,-2)) + + Rsnat_symm = Rsnat_all + xs_symm = xsnat + + toflip = drms_nataltrsi', Rs, xs[None,...] - Ts[:,None,...]) + xij_t = torch.einsum('rji,rsj->rsi', Rsnat, xsnat[None,...] - Tsnat[:,None,...]) + + #torch.norm(xij-xij_t,dim=-1) + diff = torch.sqrt( torch.sum( torch.square(xij-xij_t), dim=-1 ) + eps ) + + loss = (1.0/Z) * (torch.clamp(diff, max=dclamp)).mean() + + return loss + +# from Ivan: FAPE generalized over atom sets & frames +def compute_general_FAPE(X, Y, atom_mask, frames, frame_mask, Z=10.0, dclamp=10.0, eps=1e-4): + # X (predicted) N x L x 27 x 3 + # Y (native) 1 x L x 27 x 3 + # atom_mask 1 x L x 27 + # frames 1 x L x 6 x 3 x 2 + # frame_mask 1 x L x 6 + + N, L, natoms, _ = X.shape + # flatten middle dims so can gather across residues + X_prime = X.reshape(N, L*natoms, -1, 3).repeat(1,1,NFRAMES,1) + Y_prime = Y.reshape(1, L*natoms, -1, 3).repeat(1,1,NFRAMES,1) + + # reindex frames for flat X + frames_reindex = torch.zeros(frames.shape[:-1], device=frames.device) + for i in range(L): + frames_reindex[:, i, :, :] = (i+frames[..., i, :, :, 0])*natoms + frames[..., i, :, :, 1] + frames_reindex = frames_reindex.long() + + frame_mask *= torch.all( + torch.gather(atom_mask.reshape(1, L*natoms),1,frames_reindex.reshape(1,L*NFRAMES*3)).reshape(1,L,-1,3), + axis=-1) + + X_x = torch.gather(X_prime, 1, frames_reindex[...,0:1].repeat(N,1,1,3)) + X_y = torch.gather(X_prime, 1, frames_reindex[...,1:2].repeat(N,1,1,3)) + X_z = torch.gather(X_prime, 1, frames_reindex[...,2:3].repeat(N,1,1,3)) + uX,tX = rigid_from_3_points(X_x, X_y, X_z) + + Y_x = torch.gather(Y_prime, 1, 
frames_reindex[...,0:1].repeat(1,1,1,3)) + Y_y = torch.gather(Y_prime, 1, frames_reindex[...,1:2].repeat(1,1,1,3)) + Y_z = torch.gather(Y_prime, 1, frames_reindex[...,2:3].repeat(1,1,1,3)) + uY,tY = rigid_from_3_points(Y_x, Y_y, Y_z) + + xij = torch.einsum( + 'brji,brsj->brsi', + uX[:,frame_mask[0]], X[:,atom_mask[0]][:,None,...] - X_y[:,frame_mask[0]][:,:,None,...] + ) + + xij_t = torch.einsum('rji,rsj->rsi', uY[frame_mask], Y[atom_mask][None,...] - Y_y[frame_mask][:,None,...]) + + diff = torch.sqrt( torch.sum( torch.square(xij-xij_t[None,...]), dim=-1 ) + eps ) + + loss = (1.0/Z) * (torch.clamp(diff, max=dclamp)).mean(dim=(1,2)) + return loss + + +def angle(a, b, c, eps=1e-6): + ''' + Calculate cos/sin angle between ab and cb + a,b,c have shape of (B, L, 3) + ''' + B,L = a.shape[:2] + + u1 = a-b + u2 = c-b + + u1_norm = torch.norm(u1, dim=-1, keepdim=True) + eps + u2_norm = torch.norm(u2, dim=-1, keepdim=True) + eps + + # normalize u1 & u2 --> make unit vector + u1 = u1 / u1_norm + u2 = u2 / u2_norm + u1 = u1.reshape(B*L, 3) + u2 = u2.reshape(B*L, 3) + + # sin_theta = norm(a cross b)/(norm(a)*norm(b)) + # cos_theta = norm(a dot b) / (norm(a)*norm(b)) + sin_theta = torch.norm(torch.cross(u1, u2, dim=1), dim=1, keepdim=True).reshape(B, L, 1) # (B,L,1) + cos_theta = torch.matmul(u1[:,None,:], u2[:,:,None]).reshape(B, L, 1) + + return torch.cat([cos_theta, sin_theta], axis=-1) # (B, L, 2) + +def length(a, b): + return torch.norm(a-b, dim=-1) + +def torsion(a,b,c,d, eps=1e-6): + #A function that takes in 4 atom coordinates: + # a - [B,L,3] + # b - [B,L,3] + # c - [B,L,3] + # d - [B,L,3] + # and returns cos and sin of the dihedral angle between those 4 points in order a, b, c, d + # output - [B,L,2] + u1 = b-a + u1 = u1 / (torch.norm(u1, dim=-1, keepdim=True) + eps) + u2 = c-b + u2 = u2 / (torch.norm(u2, dim=-1, keepdim=True) + eps) + u3 = d-c + u3 = u3 / (torch.norm(u3, dim=-1, keepdim=True) + eps) + # + t1 = torch.cross(u1, u2, dim=-1) #[B, L, 3] + t2 = 
torch.cross(u2, u3, dim=-1) + t1_norm = torch.norm(t1, dim=-1, keepdim=True) + t2_norm = torch.norm(t2, dim=-1, keepdim=True) + + cos_angle = torch.matmul(t1[:,:,None,:], t2[:,:,:,None])[:,:,0] + sin_angle = torch.norm(u2, dim=-1,keepdim=True)*(torch.matmul(u1[:,:,None,:], t2[:,:,:,None])[:,:,0]) + + cos_sin = torch.cat([cos_angle, sin_angle], axis=-1)/(t1_norm*t2_norm+eps) #[B,L,2] + return cos_sin + +# ideal N-C distance, ideal cos(CA-C-N angle), ideal cos(C-N-CA angle) +# for NA, we do not compute this as it is not computable from the stubs alone +def calc_BB_bond_geom( + seq, pred, idx, eps=1e-6, + ideal_NC=1.329, ideal_CACN=-0.4415, ideal_CNCA=-0.5255, + sig_len=0.02, sig_ang=0.05): + ''' + Calculate backbone bond geometry (bond length and angle) and put loss on them + Input: + - pred: predicted coords (B, L, :, 3), 0; N / 1; CA / 2; C + - true: True coords (B, L, :, 3) + Output: + - bond length loss, bond angle loss + ''' + def cosangle( A,B,C ): + AB = A-B + BC = C-B + ABn = torch.sqrt( torch.sum(torch.square(AB),dim=-1) + eps) + BCn = torch.sqrt( torch.sum(torch.square(BC),dim=-1) + eps) + return torch.clamp(torch.sum(AB*BC,dim=-1)/(ABn*BCn), -0.999,0.999) + + B, L = pred.shape[:2] + + bonded = (idx[:,1:] - idx[:,:-1])==1 + is_prot = ~is_nucleic(seq)[:-1] + + # bond length: C-N + blen_CN_pred = length(pred[:,:-1,2], pred[:,1:,0]).reshape(B,L-1) # (B, L-1) + CN_loss = torch.clamp( torch.abs(blen_CN_pred - ideal_NC) - sig_len, min=0.0 ) + CN_loss = (bonded*is_prot*CN_loss).sum() / ((bonded*is_prot).sum() + eps) + blen_loss = CN_loss #fd squared loss + + # bond angle: CA-C-N, C-N-CA + bang_CACN_pred = cosangle(pred[:,:-1,2], pred[:,1:,0], pred[:,1:,1]).reshape(B,L-1) + bang_CNCA_pred = cosangle(pred[:,:-1,2], pred[:,1:,0], pred[:,1:,1]).reshape(B,L-1) + CACN_loss = torch.clamp( torch.abs(bang_CACN_pred - ideal_CACN) - sig_ang, min=0.0 ) + CACN_loss = (bonded*is_prot*CACN_loss).sum() / ((bonded*is_prot).sum() + eps) + CNCA_loss = torch.clamp( 
torch.abs(bang_CNCA_pred - ideal_CNCA) - sig_ang, min=0.0 ) + CNCA_loss = (bonded*is_prot*CNCA_loss).sum() / ((bonded*is_prot).sum() + eps) + bang_loss = CACN_loss + CNCA_loss + + return blen_loss+bang_loss + + +def calc_cart_bonded(seq, pred, idx, len_param, ang_param, tor_param, eps=1e-6): + # pred: N x L x 27 x 3 + # idx: 1 x L + # seq: 1 x L + def gen_ang( A,B,C ): + AB = A-B + BC = C-B + ABn = torch.sqrt( torch.sum(torch.square(AB),dim=-1) + eps) + BCn = torch.sqrt( torch.sum(torch.square(BC),dim=-1) + eps) + return torch.acos( torch.clamp(torch.sum(AB*BC,dim=-1)/(ABn*BCn), -0.999,0.999) ) + + # quadratic from [-1,1], linear elsewhere + def boundfunc(X): + Y = torch.abs(X) + Y[Y<1.0] = torch.square(Y[Y<1.0]) + #Y = torch.square(X) + return Y + + N,L = pred.shape[:2] + cb_loss = torch.zeros(N, device=pred.device) + + ## intra-res + cblens = len_param[seq] + len_idx = cblens[...,:2].to(torch.long).reshape(1,L,-1,1).repeat(N,1,1,3) + len_all = torch.gather(pred, 2, len_idx).reshape(N,L,-1,2,3) + len_mask = cblens[...,0]!=cblens[...,1] + E_cb_len = ( + len_mask[None,...] * + cblens[None,...,3] * + boundfunc( length(len_all[...,0,:],len_all[...,1,:]) - cblens[...,2] ) + ).sum(dim=(0,3)) / len_mask.sum() + + # figure out which his are his_d + cblens[seq==8] = len_param[-1] + len_idx = cblens[...,:2].to(torch.long).reshape(1,L,-1,1).repeat(N,1,1,3) + len_all_a = torch.gather(pred, 2, len_idx).reshape(N,L,-1,2,3) + len_mask_a = cblens[...,0]!=cblens[...,1] + E_cb_len_a = ( + len_mask_a[None,...] * + cblens[None,...,3] * + boundfunc( length(len_all_a[...,0,:],len_all_a[...,1,:]) - cblens[...,2] ) + ).sum(dim=(0,3)) / len_mask.sum() # N,L + is_his_d = (seq==8)*(E_cb_len_aC + allmask[torch.arange(L-1),2,torch.arange(1,L),0] = False # ignore N->C + + clash = torch.sum( torch.clamp(DISTCUT-dij[allmask],0.0) ) / torch.sum(mask) + return clash + +# Rosetta-like version of LJ (fa_atr+fa_rep) +# lj_lin is switch from linear to 12-6. 
Smaller values more sharply penalize clashes +def calc_lj( + seq, xs, aamask, ljparams, ljcorr, num_bonds, + lj_lin=0.85, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75, + lj_maxrad=-1.0, eps=1e-8 +): + def ljV(dist, sigma, epsilon, lj_lin, lj_maxrad): + N = dist.shape[0] + linpart = dist0): + sdmax = sigma / lj_maxrad + sd2 = sd*sd + sd6 = sd2 * sd2 * sd2 + sd12 = sd6 * sd6 + ljE = ljE - epsilon * (sd12 - 2 * sd6) + return ljE + + N, L = xs.shape[:2] + + # mask keeps running total of what to compute + mask = aamask[seq][...,None,None]*aamask[seq][None,None,...] + idxes1r = torch.tril_indices(L,L,-1) + mask[idxes1r[0],:,idxes1r[1],:] = False + idxes2r = torch.arange(L) + idxes2a = torch.tril_indices(27,27,0) + mask[idxes2r[:,None],idxes2a[0:1],idxes2r[:,None],idxes2a[1:2]] = False + + # "countpair" can be enforced by making this a weight + mask[idxes2r,:,idxes2r,:] *= num_bonds[seq,:,:] >= 4 #intra-res + mask[idxes2r[:-1],:,idxes2r[1:],:] *= ( + num_bonds[seq[:-1],:,2:3] + num_bonds[seq[1:],0:1,:] + 1 >= 4 #inter-res + ) + si,ai,sj,aj = mask.nonzero(as_tuple=True) + + ds = torch.sqrt( torch.sum ( torch.square( xs[:,si,ai]-xs[:,sj,aj] ), dim=-1 ) + eps ) + + # hbond correction + use_hb_dis = ( + ljcorr[seq[si],ai,0]*ljcorr[seq[sj],aj,1] + + ljcorr[seq[si],ai,1]*ljcorr[seq[sj],aj,0] ) + use_ohdon_dis = ( # OH are both donors & acceptors + ljcorr[seq[si],ai,0]*ljcorr[seq[si],ai,1]*ljcorr[seq[sj],aj,0] + +ljcorr[seq[si],ai,0]*ljcorr[seq[sj],aj,0]*ljcorr[seq[sj],aj,1] + ) + use_hb_hdis = ( + ljcorr[seq[si],ai,2]*ljcorr[seq[sj],aj,1] + +ljcorr[seq[si],ai,1]*ljcorr[seq[sj],aj,2] + ) + + # disulfide correction + potential_disulf = ljcorr[seq[si],ai,3]*ljcorr[seq[sj],aj,3] + + ljrs = ljparams[seq[si],ai,0] + ljparams[seq[sj],aj,0] + ljrs[use_hb_dis] = lj_hb_dis + ljrs[use_ohdon_dis] = lj_OHdon_dis + ljrs[use_hb_hdis] = lj_hbond_hdis + + ljss = torch.sqrt( ljparams[seq[si],ai,1] * ljparams[seq[sj],aj,1] + eps ) + ljss [potential_disulf] = 0.0 + + ljval = 
ljV(ds,ljrs,ljss,lj_lin,lj_maxrad) + + return (torch.sum( ljval, dim=-1 )/torch.sum(aamask[seq])) + + +def calc_hb( + seq, xs, aamask, hbtypes, hbbaseatoms, hbpolys, + hb_sp2_range_span=1.6, hb_sp2_BAH180_rise=0.75, hb_sp2_outer_width=0.357, + hb_sp3_softmax_fade=2.5, threshold_distance=6.0, eps=1e-8, normalize=True +): + def evalpoly( ds, xrange, yrange, coeffs ): + v = coeffs[...,0] + for i in range(1,10): + v = v * ds + coeffs[...,i] + minmask = dsxrange[...,1] + v[maxmask] = yrange[maxmask][...,1] + return v + + def cosangle( A,B,C ): + AB = A-B + BC = C-B + ABn = torch.sqrt( torch.sum(torch.square(AB),dim=-1) + eps) + BCn = torch.sqrt( torch.sum(torch.square(BC),dim=-1) + eps) + return torch.clamp(torch.sum(AB*BC,dim=-1)/(ABn*BCn), -0.999,0.999) + + hbts = hbtypes[seq] + hbba = hbbaseatoms[seq] + + rh,ah = (hbts[...,0]>=0).nonzero(as_tuple=True) + ra,aa = (hbts[...,1]>=0).nonzero(as_tuple=True) + D_xs = xs[rh,hbba[rh,ah,0]][:,None,:] + H_xs = xs[rh,ah][:,None,:] + A_xs = xs[ra,aa][None,:,:] + B_xs = xs[ra,hbba[ra,aa,0]][None,:,:] + B0_xs = xs[ra,hbba[ra,aa,1]][None,:,:] + hyb = hbts[ra,aa,2] + polys = hbpolys[hbts[rh,ah,0][:,None],hbts[ra,aa,1][None,:]] + + AH = torch.sqrt( torch.sum( torch.square( H_xs-A_xs), axis=-1) + eps ) + AHD = torch.acos( cosangle( B_xs, A_xs, H_xs) ) + + Es = polys[...,0,0]*evalpoly( + AH,polys[...,0,1:3],polys[...,0,3:5],polys[...,0,5:]) + Es += polys[...,1,0] * evalpoly( + AHD,polys[...,1,1:3],polys[...,1,3:5],polys[...,1,5:]) + + Bm = 0.5*(B0_xs[:,hyb==HbHybType.RING]+B_xs[:,hyb==HbHybType.RING]) + cosBAH = cosangle( Bm, A_xs[:,hyb==HbHybType.RING], H_xs ) + Es[:,hyb==HbHybType.RING] += polys[:,hyb==HbHybType.RING,2,0] * evalpoly( + cosBAH, + polys[:,hyb==HbHybType.RING,2,1:3], + polys[:,hyb==HbHybType.RING,2,3:5], + polys[:,hyb==HbHybType.RING,2,5:]) + + cosBAH1 = cosangle( B_xs[:,hyb==HbHybType.SP3], A_xs[:,hyb==HbHybType.SP3], H_xs ) + cosBAH2 = cosangle( B0_xs[:,hyb==HbHybType.SP3], A_xs[:,hyb==HbHybType.SP3], H_xs ) + Esp3_1 = 
polys[:,hyb==HbHybType.SP3,2,0] * evalpoly( + cosBAH1, + polys[:,hyb==HbHybType.SP3,2,1:3], + polys[:,hyb==HbHybType.SP3,2,3:5], + polys[:,hyb==HbHybType.SP3,2,5:]) + Esp3_2 = polys[:,hyb==HbHybType.SP3,2,0] * evalpoly( + cosBAH2, + polys[:,hyb==HbHybType.SP3,2,1:3], + polys[:,hyb==HbHybType.SP3,2,3:5], + polys[:,hyb==HbHybType.SP3,2,5:]) + Es[:,hyb==HbHybType.SP3] += torch.log( + torch.exp(Esp3_1 * hb_sp3_softmax_fade) + + torch.exp(Esp3_2 * hb_sp3_softmax_fade) + ) / hb_sp3_softmax_fade + + cosBAH = cosangle( B_xs[:,hyb==HbHybType.SP2], A_xs[:,hyb==HbHybType.SP2], H_xs ) + Es[:,hyb==HbHybType.SP2] += polys[:,hyb==HbHybType.SP2,2,0] * evalpoly( + cosBAH, + polys[:,hyb==HbHybType.SP2,2,1:3], + polys[:,hyb==HbHybType.SP2,2,3:5], + polys[:,hyb==HbHybType.SP2,2,5:]) + + BAH = torch.acos( cosBAH ) + B0BAH = get_dih(B0_xs[:,hyb==HbHybType.SP2], B_xs[:,hyb==HbHybType.SP2], A_xs[:,hyb==HbHybType.SP2], H_xs) + + d,m,l = hb_sp2_BAH180_rise, hb_sp2_range_span, hb_sp2_outer_width + Echi = torch.full_like( B0BAH, m-0.5 ) + + mask1 = BAH>np.pi * 2.0 / 3.0 + H = 0.5 * (torch.cos(2 * B0BAH) + 1) + F = d / 2 * torch.cos(3 * (np.pi - BAH[mask1])) + d / 2 - 0.5 + Echi[mask1] = H[mask1] * F + (1 - H[mask1]) * d - 0.5 + + mask2 = BAH>np.pi * (2.0 / 3.0 - l) + mask2 *= ~mask1 + outer_rise = torch.cos(np.pi - (np.pi * 2 / 3 - BAH[mask2]) / l) + F = m / 2 * outer_rise + m / 2 - 0.5 + G = (m - d) / 2 * outer_rise + (m - d) / 2 + d - 0.5 + Echi[mask2] = H[mask2] * F + (1 - H[mask2]) * d - 0.5 + + Es[:,hyb==HbHybType.SP2] += polys[:,hyb==HbHybType.SP2,2,0] * Echi + + tosquish = torch.logical_and(Es > -0.1,Es < 0.1) + Es[tosquish] = -0.025 + 0.5 * Es[tosquish] - 2.5 * torch.square(Es[tosquish]) + Es[Es > 0.1] = 0. 
+ if (normalize): + return (torch.sum( Es ) / torch.sum(aamask[seq])) + else: + return torch.sum( Es ) + +@torch.enable_grad() +def calc_BB_bond_geom_grads(seq, pred, idx, eps=1e-6, ideal_NC=1.329, ideal_CACN=-0.4415, ideal_CNCA=-0.5255, sig_len=0.02, sig_ang=0.05): + pred.requires_grad_(True) + Ebond = calc_BB_bond_geom(seq, pred, idx, eps, ideal_NC, ideal_CACN, ideal_CNCA, sig_len, sig_ang) + return torch.autograd.grad(Ebond, pred) + +@torch.enable_grad() +def calc_cart_bonded_grads(seq, pred, idx, len_param, ang_param, tor_param, eps=1e-6): + pred.requires_grad_(True) + Ecb = calc_cart_bonded(seq, pred, idx, len_param, ang_param, tor_param, eps) + return torch.autograd.grad(Ecb, pred) + +@torch.enable_grad() +def calc_ljallatom_grads( + seq, xyzaa, + aamask, ljparams, ljcorr, num_bonds, + lj_lin=0.85, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75, + lj_maxrad=-1.0, eps=1e-8 +): + xyzaa.requires_grad_(True) + Elj = calc_lj( + seq[0], + xyzaa[...,:3], + aamask, + ljparams, + ljcorr, + num_bonds, + lj_lin, + lj_hb_dis, + lj_OHdon_dis, + lj_hbond_hdis, + lj_maxrad, + eps + ) + return torch.autograd.grad(Elj, (xyzaa,)) + +@torch.enable_grad() +def calc_lj_grads( + seq, xyz, alpha, toaa, + aamask, ljparams, ljcorr, num_bonds, + lj_lin=0.85, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75, + lj_maxrad=-1.0, eps=1e-8 +): + xyz.requires_grad_(True) + alpha.requires_grad_(True) + _, xyzaa = toaa(seq, xyz, alpha) + Elj = calc_lj( + seq[0], + xyzaa[...,:3], + aamask, + ljparams, + ljcorr, + num_bonds, + lj_lin, + lj_hb_dis, + lj_OHdon_dis, + lj_hbond_hdis, + lj_maxrad, + eps + ) + return torch.autograd.grad(Elj, (xyz,alpha)) + +@torch.enable_grad() +def calc_hb_grads( + seq, xyz, alpha, toaa, + aamask, hbtypes, hbbaseatoms, hbpolys, + hb_sp2_range_span=1.6, hb_sp2_BAH180_rise=0.75, hb_sp2_outer_width=0.357, + hb_sp3_softmax_fade=2.5, threshold_distance=6.0, eps=1e-8, normalize=True +): + xyz.requires_grad_(True) + alpha.requires_grad_(True) + _, xyzaa = toaa(seq, 
xyz, alpha) + Ehb = calc_hb( + seq, + xyzaa[0,...,:3], + aamask, + hbtypes, + hbbaseatoms, + hbpolys, + hb_sp2_range_span, + hb_sp2_BAH180_rise, + hb_sp2_outer_width, + hb_sp3_softmax_fade, + threshold_distance, + eps, + normalize) + return torch.autograd.grad(Ehb, xs) + +def calc_pseudo_dih(pred, true, eps=1e-6): + ''' + calculate pseudo CA dihedral angle and put loss on them + Input: + - predicted & true CA coordinates (I,B,L,3) / (B, L, 3) + Output: + - dihedral angle loss + ''' + I, B, L = pred.shape[:3] + pred = pred.reshape(I*B, L, -1) + true_dih = torsion(true[:,:-3,:],true[:,1:-2,:],true[:,2:-1,:],true[:,3:,:]) # (B, L', 2) + pred_dih = torsion(pred[:,:-3,:],pred[:,1:-2,:],pred[:,2:-1,:],pred[:,3:,:]) # (I*B, L', 2) + pred_dih = pred_dih.reshape(I, B, -1, 2) + dih_loss = torch.square(pred_dih - true_dih).sum(dim=-1).mean() + dih_loss = torch.sqrt(dih_loss + eps) + return dih_loss + +def calc_lddt(pred_ca, true_ca, mask_crds, mask_2d, same_chain, negative=False, interface=False, eps=1e-6): + # Input + # pred_ca: predicted CA coordinates (I, B, L, 3) + # true_ca: true CA coordinates (B, L, 3) + # pred_lddt: predicted lddt values (I-1, B, L) + + I, B, L = pred_ca.shape[:3] + + pred_dist = torch.cdist(pred_ca, pred_ca) # (I, B, L, L) + true_dist = torch.cdist(true_ca, true_ca).unsqueeze(0) # (1, B, L, L) + + mask = torch.logical_and(true_dist > 0.0, true_dist < 15.0) # (1, B, L, L) + # update mask information + mask *= mask_2d[None] + if negative: + mask *= same_chain.bool()[None] + elif interface: + # ignore atoms between the same chain + mask *= ~same_chain.bool()[None] + + mask_crds = mask_crds * (mask[0].sum(dim=-1) != 0) + + delta = torch.abs(pred_dist-true_dist) # (I, B, L, L) + + true_lddt = torch.zeros((I,B,L), device=pred_ca.device) + for distbin in [0.5, 1.0, 2.0, 4.0]: + true_lddt += 0.25*torch.sum((delta<=distbin)*mask, dim=-1) / (torch.sum(mask, dim=-1) + eps) + + true_lddt = mask_crds*true_lddt + true_lddt = true_lddt.sum(dim=(1,2)) / 
(mask_crds.sum() + eps) + return true_lddt + + +#fd allatom lddt +def calc_allatom_lddt(P, Q, idx, atm_mask, eps=1e-6): + # P - N x L x 27 x 3 + # Q - L x 27 x 3 + N, L = P.shape[:2] + + # distance matrix + Pij = torch.square(P[:,:,None,:,None,:]-P[:,None,:,None,:,:]) # (N, L, L, 27, 27) + Pij = torch.sqrt( Pij.sum(dim=-1) + eps) + Qij = torch.square(Q[None,:,None,:,None,:]-Q[None,None,:,None,:,:]) # (1, L, L, 27, 27) + Qij = torch.sqrt( Qij.sum(dim=-1) + eps) + + # get valid pairs + pair_mask = torch.logical_and(Qij>0,Qij<15).float() # only consider atom pairs within 15A + # ignore missing atoms + pair_mask *= (atm_mask[:,:,None,:,None] * atm_mask[:,None,:,None,:]).float() + + # ignore atoms within same residue + pair_mask *= (idx[:,:,None,None,None] != idx[:,None,:,None,None]).float() # (1, L, L, 27, 27) + + delta_PQ = torch.abs(Pij-Qij+eps) # (N, L, L, 14, 14) + + lddt = torch.zeros( (N,L,27), device=P.device ) # (N, L, 27) + for distbin in (0.5,1.0,2.0,4.0): + lddt += 0.25 * torch.sum( (delta_PQ<=distbin)*pair_mask, dim=(2,4) + ) / ( torch.sum( pair_mask, dim=(2,4) ) + 1e-8) + + lddt = (lddt * atm_mask).sum(dim=(1,2)) / (atm_mask.sum() + eps) + return lddt + + +def calc_allatom_lddt_loss(P, Q, pred_lddt, idx, atm_mask, mask_2d, same_chain, negative=False, interface=False, eps=1e-6): + # P - N x L x 27 x 3 + # Q - L x 27 x 3 + # pred_lddt - 1 x nbucket x L + N, L, Natm = P.shape[:3] + + # distance matrix + Pij = torch.square(P[:,:,None,:,None,:]-P[:,None,:,None,:,:]) # (N, L, L, 27, 27) + Pij = torch.sqrt( Pij.sum(dim=-1) + eps) + Qij = torch.square(Q[None,:,None,:,None,:]-Q[None,None,:,None,:,:]) # (1, L, L, 27, 27) + Qij = torch.sqrt( Qij.sum(dim=-1) + eps) + + # get valid pairs + pair_mask = torch.logical_and(Qij>0,Qij<15).float() # only consider atom pairs within 15A + # ignore missing atoms + pair_mask *= (atm_mask[:,:,None,:,None] * atm_mask[:,None,:,None,:]).float() + + # ignore atoms within same residue + pair_mask *= (idx[:,:,None,None,None] != 
idx[:,None,:,None,None]).float() # (1, L, L, 27, 27) + if negative: + # ignore atoms between different chains + pair_mask *= same_chain.bool()[:,:,:,None,None] + + delta_PQ = torch.abs(Pij-Qij+eps) # (N, L, L, 14, 14) + + lddt = torch.zeros( (N,L,Natm), device=P.device ) # (N, L, 27) + for distbin in (0.5,1.0,2.0,4.0): + lddt += 0.25 * torch.sum( (delta_PQ<=distbin)*pair_mask, dim=(2,4) + ) / ( torch.sum( pair_mask, dim=(2,4) ) + eps) + + final_lddt_by_res = torch.clamp( + (lddt[-1]*atm_mask[0]).sum(-1) + / (atm_mask.sum(-1) + eps), min=0.0, max=1.0) + + # calculate lddt prediction loss + nbin = pred_lddt.shape[1] + bin_step = 1.0 / nbin + lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device) + true_lddt_label = torch.bucketize(final_lddt_by_res[None,...], lddt_bins).long() + lddt_loss = torch.nn.CrossEntropyLoss(reduction='none')( + pred_lddt, true_lddt_label[-1]) + + res_mask = atm_mask.any(dim=-1) + lddt_loss = (lddt_loss * res_mask).sum() / (res_mask.sum() + eps) + + # method 1: average per-residue + #lddt = lddt.sum(dim=-1) / (atm_mask.sum(dim=-1)+1e-8) # L + #lddt = (res_mask*lddt).sum() / (res_mask.sum() + 1e-8) + + # method 2: average per-atom + atm_mask = atm_mask * (pair_mask.sum(dim=(1,3)) != 0) + lddt = (lddt * atm_mask).sum(dim=(1,2)) / (atm_mask.sum() + eps) + + return lddt_loss, lddt diff --git a/RF2_allatom/memory.py b/RF2_allatom/memory.py new file mode 100644 index 0000000..38bfdaf --- /dev/null +++ b/RF2_allatom/memory.py @@ -0,0 +1,57 @@ +import gc + +import torch + +## MEM utils ## +def mem_report(): + '''Report the memory usage of the tensor.storage in pytorch + Both on CPUs and GPUs are reported''' + + def _mem_report(tensors, mem_type): + '''Print the selected tensors of type + There are two major storage types in our major concern: + - GPU: tensors transferred to CUDA devices + - CPU: tensors remaining on the system memory (usually unimportant) + Args: + - tensors: the tensors of specified type + - 
mem_type: 'CPU' or 'GPU' in current implementation ''' + print('Storage on %s' %(mem_type)) + print('-'*LEN) + total_numel = 0 + total_mem = 0 + visited_data = [] + for tensor in tensors: + if tensor.is_sparse: + continue + # a data_ptr indicates a memory block allocated + data_ptr = tensor.storage().data_ptr() + if data_ptr in visited_data: + continue + visited_data.append(data_ptr) + + numel = tensor.storage().size() + total_numel += numel + element_size = tensor.storage().element_size() + mem = numel*element_size /1024/1024 # 32bit=4Byte, MByte + total_mem += mem + element_type = type(tensor).__name__ + size = tuple(tensor.size()) + + print('%s\t\t%s\t\t%.2f' % ( + element_type, + size, + mem) ) + print('-'*LEN) + print('Total Tensors: %d \tUsed Memory Space: %.2f MBytes' % (total_numel, total_mem) ) + print('-'*LEN) + + LEN = 65 + print('='*LEN) + objects = gc.get_objects() + print('%s\t%s\t\t\t%s' %('Element type', 'Size', 'Used MEM(MBytes)') ) + tensors = [obj for obj in objects if torch.is_tensor(obj)] + cuda_tensors = [t for t in tensors if t.is_cuda] + host_tensors = [t for t in tensors if not t.is_cuda] + _mem_report(cuda_tensors, 'GPU') + _mem_report(host_tensors, 'CPU') + print('='*LEN) \ No newline at end of file diff --git a/RF2_allatom/parsers.py b/RF2_allatom/parsers.py new file mode 100644 index 0000000..e5bb282 --- /dev/null +++ b/RF2_allatom/parsers.py @@ -0,0 +1,439 @@ +import numpy as np +import scipy +import scipy.spatial +import string +import os,re +from os.path import exists +import random +import util +import gzip +from ffindex import * +import torch +from chemical import NAATOKENS, aa2num, aa2long +from rdkit import Chem + +to1letter = { + "ALA":'A', "ARG":'R', "ASN":'N', "ASP":'D', "CYS":'C', + "GLN":'Q', "GLU":'E', "GLY":'G', "HIS":'H', "ILE":'I', + "LEU":'L', "LYS":'K', "MET":'M', "PHE":'F', "PRO":'P', + "SER":'S', "THR":'T', "TRP":'W', "TYR":'Y', "VAL":'V', + "DA":'a', "DC":'c', "DG":'g', "DT":'t', + "A":'b', "C":'d', "G":'h', 
"U":'u', +} + +def read_template_pdb(L, pdb_fn, target_chain=None): + # get full sequence from given PDB + seq_full = list() + prev_chain='' + with open(pdb_fn) as fp: + for line in fp: + if line[:4] != "ATOM": + continue + if line[12:16].strip() != "CA": + continue + if line[21] != prev_chain: + if len(seq_full) > 0: + L_s.append(len(seq_full)-offset) + offset = len(seq_full) + prev_chain = line[21] + aa = line[17:20] + seq_full.append(aa2num[aa] if aa in aa2num.keys() else 20) + + seq_full = torch.tensor(seq_full).long() + + xyz = torch.full((L, 36, 3), np.nan).float() + seq = torch.full((L,), 20).long() + conf = torch.zeros(L,1).float() + + with open(pdb_fn) as fp: + for line in fp: + if line[:4] != "ATOM": + continue + resNo, atom, aa = int(line[22:26]), line[12:16], line[17:20] + aa_idx = aa2num[aa] if aa in aa2num.keys() else 20 + # + idx = resNo - 1 + for i_atm, tgtatm in enumerate(aa2long[aa_idx]): + if tgtatm == atom: + xyz[idx, i_atm, :] = torch.tensor([float(line[30:38]), float(line[38:46]), float(line[46:54])]) + break + seq[idx] = aa_idx + + mask = torch.logical_not(torch.isnan(xyz[:,:3,0])) # (L, 3) + mask = mask.all(dim=-1)[:,None] + conf = torch.where(mask, torch.full((L,1),0.1), torch.zeros(L,1)).float() + seq_1hot = torch.nn.functional.one_hot(seq, num_classes=32).float() + t1d = torch.cat((seq_1hot, conf), -1) + + #return seq_full[None], ins[None], L_s, xyz[None], t1d[None] + return xyz[None], t1d[None] + +def parse_fasta(filename, maxseq=10000, rmsa_alphabet=False): + msa = [] + ins = [] + + fstream = open(filename,"r") + + for line in fstream: + # skip labels + if line[0] == '>': + continue + + # remove right whitespaces + line = line.rstrip() + + if len(line) == 0: + continue + + # remove lowercase letters and append to MSA + msa.append(line) + + # sequence length + L = len(msa[-1]) + + i = np.zeros((L)) + ins.append(i) + + # convert letters into numbers + if rmsa_alphabet: + alphabet = np.array(list("00000000000000000000-000000ACGTN"), 
dtype='|S1').view(np.uint8) + else: + alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8) + msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8) + for i in range(alphabet.shape[0]): + msa[msa == alphabet[i]] = i + + ins = np.array(ins, dtype=np.uint8) + + return msa,ins + +# parse a fasta alignment IF it exists +# otherwise return single-sequence msa +def parse_fasta_if_exists(seq, filename, maxseq=10000, rmsa_alphabet=False): + if (exists(filename)): + return parse_fasta(filename, maxseq, rmsa_alphabet) + else: + alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8) # -0 are UNK/mask + seq = np.array([list(seq)], dtype='|S1').view(np.uint8) + for i in range(alphabet.shape[0]): + seq[seq == alphabet[i]] = i + + return (seq, np.zeros_like(seq)) + +# read A3M and convert letters into +# integers in the 0..20 range, +# also keep track of insertions +def parse_a3m(filename, unzip=True, maxseq=10000): + msa = [] + ins = [] + + table = str.maketrans(dict.fromkeys(string.ascii_lowercase)) + + # read file line by line + if (unzip): + fstream = gzip.open(filename,"rt") + else: + fstream = open(filename,"r") + + for line in fstream: + + # skip labels + if line[0] == '>': + continue + + # remove right whitespaces + line = line.rstrip() + + if len(line) == 0: + continue + + # remove lowercase letters and append to MSA + msa.append(line.translate(table)) + + # sequence length + L = len(msa[-1]) + + # remove insertion at the end + if (not unzip): + n_remove = 0 + for c in reversed(line): + if c.islower(): + n_remove += 1 + else: + break + line = line[:-n_remove] + + # 0 - match or gap; 1 - insertion + a = np.array([0 if c.isupper() or c=='-' else 1 for c in line]) + i = np.zeros((L)) + + if np.sum(a) > 0: + # positions of insertions + pos = np.where(a==1)[0] + + # shift by occurrence + a = pos - np.arange(pos.shape[0]) + + # position of insertions in cleaned sequence + # and their length + 
pos,num = np.unique(a, return_counts=True) + + # append to the matrix of insetions + i[pos] = num + + ins.append(i) + + if (len(msa) >= maxseq): + break + + # convert letters into numbers + alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8) + msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8) + for i in range(alphabet.shape[0]): + msa[msa == alphabet[i]] = i + + # treat all unknown characters as gaps + msa[msa > 20] = 20 + + ins = np.array(ins, dtype=np.uint8) + + return msa,ins + + +# read and extract xyz coords of N,Ca,C atoms +# from a PDB file +def parse_pdb(filename): + lines = open(filename,'r').readlines() + return parse_pdb_lines(lines) + +#''' +def parse_pdb_lines(lines): + + # indices of residues observed in the structure + idx_s = [int(l[22:26]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"] + + # 4 BB + up to 10 SC atoms + xyz = np.full((len(idx_s), 27, 3), np.nan, dtype=np.float32) + for l in lines: + if l[:4] != "ATOM": + continue + resNo, atom, aa = int(l[22:26]), l[12:16], l[17:20] + idx = idx_s.index(resNo) + for i_atm, tgtatm in enumerate(aa2long[aa2num[aa]]): + if tgtatm == atom: + xyz[idx,i_atm,:] = [float(l[30:38]), float(l[38:46]), float(l[46:54])] + break + + # save atom mask + mask = np.logical_not(np.isnan(xyz[...,0])) + xyz[np.isnan(xyz[...,0])] = 0.0 + + return xyz,mask,np.array(idx_s) + +def parse_pdb_lines_w_seq(lines): + + # indices of residues observed in the structure + #idx_s = [int(l[22:26]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"] + res = [(l[22:26],l[17:20]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"] + idx_s = [int(r[0]) for r in res] + seq = [aa2num[r[1]] if r[1] in aa2num.keys() else 20 for r in res] + + # 4 BB + up to 10 SC atoms + xyz = np.full((len(idx_s), 27, 3), np.nan, dtype=np.float32) + for l in lines: + if l[:4] != "ATOM": + continue + resNo, atom, aa = int(l[22:26]), l[12:16], l[17:20] + idx = idx_s.index(resNo) + for i_atm, 
tgtatm in enumerate(aa2long[aa2num[aa]]): + if tgtatm == atom: + xyz[idx,i_atm,:] = [float(l[30:38]), float(l[38:46]), float(l[46:54])] + break + + # save atom mask + mask = np.logical_not(np.isnan(xyz[...,0])) + xyz[np.isnan(xyz[...,0])] = 0.0 + + return xyz,mask,np.array(idx_s), np.array(seq) + + +def parse_templates(item, params): + + # init FFindexDB of templates + ### and extract template IDs + ### present in the DB + ffdb = FFindexDB(read_index(params['FFDB']+'_pdb.ffindex'), + read_data(params['FFDB']+'_pdb.ffdata')) + #ffids = set([i.name for i in ffdb.index]) + + # process tabulated hhsearch output to get + # matched positions and positional scores + infile = params['DIR']+'/hhr/'+item[-2:]+'/'+item+'.atab' + hits = [] + for l in open(infile, "r").readlines(): + if l[0]=='>': + key = l[1:].split()[0] + hits.append([key,[],[]]) + elif "score" in l or "dssp" in l: + continue + else: + hi = l.split()[:5]+[0.0,0.0,0.0] + hits[-1][1].append([int(hi[0]),int(hi[1])]) + hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])]) + + # get per-hit statistics from an .hhr file + # (!!! assume that .hhr and .atab have the same hits !!!) 
+ # [Probab, E-value, Score, Aligned_cols, + # Identities, Similarity, Sum_probs, Template_Neff] + lines = open(infile[:-4]+'hhr', "r").readlines() + pos = [i+1 for i,l in enumerate(lines) if l[0]=='>'] + for i,posi in enumerate(pos): + hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]]) + + # parse templates from FFDB + for hi in hits: + #if hi[0] not in ffids: + # continue + entry = get_entry_by_name(hi[0], ffdb.index) + if entry == None: + continue + data = read_entry_lines(entry, ffdb.data) + hi += list(parse_pdb_lines(data)) + + # process hits + counter = 0 + xyz,qmap,mask,f0d,f1d,ids = [],[],[],[],[],[] + for data in hits: + if len(data)<7: + continue + + qi,ti = np.array(data[1]).T + _,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True) + ncol = sel1.shape[0] + if ncol < 10: + continue + + ids.append(data[0]) + f0d.append(data[3]) + f1d.append(np.array(data[2])[sel1]) + xyz.append(data[4][sel2]) + mask.append(data[5][sel2]) + qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1)) + counter += 1 + + xyz = np.vstack(xyz).astype(np.float32) + mask = np.vstack(mask).astype(np.bool) + qmap = np.vstack(qmap).astype(np.long) + f0d = np.vstack(f0d).astype(np.float32) + f1d = np.vstack(f1d).astype(np.float32) + ids = ids + + return xyz,mask,qmap,f0d,f1d,ids + +def parse_templates_raw(ffdb, hhr_fn, atab_fn): + # process tabulated hhsearch output to get + # matched positions and positional scores + hits = [] + for l in open(atab_fn, "r").readlines(): + if l[0]=='>': + key = l[1:].split()[0] + hits.append([key,[],[]]) + elif "score" in l or "dssp" in l: + continue + else: + hi = l.split()[:5]+[0.0,0.0,0.0] + hits[-1][1].append([int(hi[0]),int(hi[1])]) + hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])]) + + # get per-hit statistics from an .hhr file + # (!!! assume that .hhr and .atab have the same hits !!!) 
+ # [Probab, E-value, Score, Aligned_cols, + # Identities, Similarity, Sum_probs, Template_Neff] + lines = open(hhr_fn, "r").readlines() + pos = [i+1 for i,l in enumerate(lines) if l[0]=='>'] + for i,posi in enumerate(pos): + hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]]) + + # parse templates from FFDB + for hi in hits: + #if hi[0] not in ffids: + # continue + entry = get_entry_by_name(hi[0], ffdb.index) + if entry == None: + continue + data = read_entry_lines(entry, ffdb.data) + hi += list(parse_pdb_lines_w_seq(data)) + + # process hits + counter = 0 + xyz,qmap,mask,f0d,f1d,ids,seq = [],[],[],[],[],[],[] + for data in hits: + if len(data)<7: + continue + + qi,ti = np.array(data[1]).T + _,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True) + ncol = sel1.shape[0] + if ncol < 10: + continue + + ids.append(data[0]) + f0d.append(data[3]) + f1d.append(np.array(data[2])[sel1]) + xyz.append(data[4][sel2]) + mask.append(data[5][sel2]) + seq.append(data[-1][sel2]) + qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1)) + counter += 1 + + xyz = np.vstack(xyz).astype(np.float32) + qmap = np.vstack(qmap).astype(np.long) + f1d = np.vstack(f1d).astype(np.float32) + seq = np.hstack(seq).astype(np.long) + ids = ids + + return torch.from_numpy(xyz), torch.from_numpy(qmap), \ + torch.from_numpy(f1d), torch.from_numpy(seq), ids + +def read_templates(qlen, ffdb, hhr_fn, atab_fn, n_templ=10): + xyz_t, qmap, t1d, seq, ids = parse_templates_raw(ffdb, hhr_fn, atab_fn) + npick = min(n_templ, len(ids)) + if npick < 1: # no templates + xyz = torch.full((1,qlen,27,3),np.nan).float() + t1d = torch.nn.functional.one_hot(torch.full((1, qlen), 20).long(), num_classes=21).float() # all gaps + t1d = torch.cat((t1d, torch.zeros((1,qlen,1)).float()), -1) + return xyz, t1d + + sample = torch.arange(npick) + # + xyz = torch.full((npick, qlen, 27, 3), np.nan).float() + f1d = torch.full((npick, qlen), 20).long() + f1d_val = torch.zeros((npick, qlen, 
1)).float() + # + for i, nt in enumerate(sample): + sel = torch.where(qmap[:,1] == nt)[0] + pos = qmap[sel, 0] + xyz[i, pos] = xyz_t[sel] + f1d[i, pos] = seq[sel] + f1d_val[i,pos] = t1d[sel, 2].unsqueeze(-1) + + f1d = torch.nn.functional.one_hot(f1d, num_classes=21).float() + f1d = torch.cat((f1d, f1d_val), dim=-1) + + return xyz, f1d + +def parse_mol(filename): + """parses a mol file""" + mol = Chem.MolFromMolFile(filename) + msa = torch.tensor([aa2num[a.GetSymbol()] for a in mol.GetAtoms()]) + ins = torch.zeros_like(msa) + + return mol, msa, ins + +def get_ligand_xyz(mol): + xyz = torch.tensor(np.array([c.GetPositions() for c in mol.GetConformers()])).squeeze(0) + permuts = mol.GetSubstructMatches(mol, uniquify=False, maxMatches=256) + permuts = torch.tensor(permuts) + Y = xyz[permuts].reshape(-1,mol.GetNumAtoms(),3) + mask = torch.full(Y.shape[:-1], True) + return Y, mask diff --git a/RF2_allatom/predict_casp14.py b/RF2_allatom/predict_casp14.py new file mode 100644 index 0000000..55f9bd4 --- /dev/null +++ b/RF2_allatom/predict_casp14.py @@ -0,0 +1,334 @@ +import sys, os +import time +import numpy as np +import torch +import torch.nn as nn +from torch.utils import data +from parsers import parse_a3m, read_templates +from RoseTTAFoldModel import RoseTTAFoldModule +import util +from collections import namedtuple +from ffindex import * +from data_loader import MSAFeaturize, MSABlockDeletion +from kinematics import xyz_to_c6d, c6d_to_bins2, xyz_to_t2d, get_init_xyz +from util_module import ComputeAllAtomCoords + +from memory import mem_report + +MAX_CYCLE = 20 +NREPLICATES = 5 +NBIN = [37, 37, 37, 19] + +MODEL_PARAM ={ + "n_extra_block" : 4, + "n_main_block" : 32, + "n_ref_block" : 0, + "n_finetune_block" : 4, + "d_msa" : 256 , + "d_pair" : 128, + "d_templ" : 64, + "n_head_msa" : 8, + "n_head_pair" : 4, + "n_head_templ" : 4, + "d_hidden" : 32, + "d_hidden_templ" : 64, + "p_drop" : 0.0, + "lj_lin" : 0.7 +} + +SE3_param = { + "num_layers" : 1, + "num_channels" : 32, + 
"num_degrees" : 2, + "l0_in_features": 64, + "l0_out_features": 64, + "l1_in_features": 3, + "l1_out_features": 2, + "num_edge_features": 64, + "div": 4, + "n_heads": 4 +} +MODEL_PARAM['SE3_param'] = SE3_param + + +# params for the folding protocol +fold_params = { + "SG7" : np.array([[[-2,3,6,7,6,3,-2]]])/21, + "SG9" : np.array([[[-21,14,39,54,59,54,39,14,-21]]])/231, + "DCUT" : 19.5, + "ALPHA" : 1.57, + + # TODO: add Cb to the motif + "NCAC" : np.array([[-0.676, -1.294, 0. ], + [ 0. , 0. , 0. ], + [ 1.5 , -0.174, 0. ]], dtype=np.float32), + "CLASH" : 2.0, + "PCUT" : 0.5, + "DSTEP" : 0.5, + "ASTEP" : np.deg2rad(10.0), + "XYZRAD" : 7.5, + "WANG" : 0.1, + "WCST" : 0.1 +} + +fold_params["SG"] = fold_params["SG9"] + +# compute expected value from binned lddt +def lddt_unbin(pred_lddt): + nbin = pred_lddt.shape[1] + bin_step = 1.0 / nbin + lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device) + + pred_lddt = nn.Softmax(dim=1)(pred_lddt) + return torch.sum(lddt_bins[None,:,None]*pred_lddt, dim=1) + +class Predictor(): + def __init__(self, model_name="BFF", model_dir=None, device="cuda:0"): + if model_dir == None: + self.model_dir = "%s/models"%(os.path.dirname(os.path.abspath(__file__))) + else: + self.model_dir = model_dir + # + # define model name + self.model_name = model_name + self.device = device + self.active_fn = nn.Softmax(dim=1) + + # define model & load model + self.model = RoseTTAFoldModule( + **MODEL_PARAM, + aamask=util.allatom_mask.to(self.device), + atom_type_index=util.atom_type_index.to(self.device), + ljlk_parameters=util.ljlk_parameters.to(self.device), + lj_correction_parameters=util.lj_correction_parameters.to(self.device), + num_bonds=util.num_bonds.to(self.device), + cb_len = util.cb_length_t.to(self.device), + cb_ang = util.cb_angle_t.to(self.device), + cb_tor = util.cb_torsion_t.to(self.device), + ).to(self.device) + + could_load = self.load_model(self.model_name) + if not could_load: + print ("ERROR: 
failed to load model") + sys.exit() + + self.compute_allatom_coords = ComputeAllAtomCoords().to(self.device) + + def load_model(self, model_name, suffix='last'): + chk_fn = "%s/%s_%s.pt"%(self.model_dir, model_name, suffix) + if not os.path.exists(chk_fn): + return False + checkpoint = torch.load(chk_fn, map_location=self.device) + self.model.load_state_dict(checkpoint['model_state_dict']) + return True + + def predict(self, a3m_fn, out_prefix, hhr_fn=None, atab_fn=None, window=1e9, shift=50, n_latent=256): + msa_orig, ins_orig = parse_a3m(a3m_fn, unzip=False) + N, L = msa_orig.shape + # + if os.path.exists(hhr_fn): + xyz_t, t1d = read_templates(L, ffdb, hhr_fn, atab_fn, n_templ=4) + else: + xyz_t = torch.full((1,L,3,3),np.nan).float() + t1d = torch.nn.functional.one_hot(torch.full((1, L), 20).long(), num_classes=21).float() # all gaps + t1d = torch.cat((t1d, torch.zeros((1,L,1)).float()), -1) + # + # template features + xyz_t = xyz_t.float().unsqueeze(0) + t1d = t1d.float().unsqueeze(0) + t2d = xyz_to_t2d(xyz_t) + xyz_t = get_init_xyz(xyz_t) # initialize coordinates with first template + + seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L) + + alpha, _, alpha_mask, _ = util.get_torsions( + xyz_t.reshape(-1,L,27,3), + seq_tmp, + util.torsion_indices, + util.torsion_can_flip, + util.reference_angles + ) + alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0])) + alpha[torch.isnan(alpha)] = 0.0 + alpha = alpha.reshape(1,-1,L,10,2) + alpha_mask = alpha_mask.reshape(1,-1,L,10,1) + alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(1, -1, L, 30) + + self.model.eval() + for i_trial in range(NREPLICATES): + if os.path.exists("%s_%02d_init.pdb"%(out_prefix, i_trial)): + continue + self.run_prediction(msa_orig, ins_orig, t1d, t2d, xyz_t, xyz_t[:,0], alpha_t, "%s_%02d"%(out_prefix, i_trial), n_latent=n_latent) + torch.cuda.empty_cache() + + def run_prediction(self, msa_orig, ins_orig, t1d, t2d, xyz_t, xyz, alpha_t, out_prefix, n_latent=256): + start = 
time.time() + torch.cuda.reset_peak_memory_stats() + with torch.no_grad(): + # + msa = torch.tensor(msa_orig).long().to(self.device) # (N, L) + ins = torch.tensor(ins_orig).long().to(self.device) + if msa_orig.shape[0] > 4096: + msa, ins = MSABlockDeletion(msa, ins) + print (msa_orig.shape, msa.shape) + # + seq, msa_seed_orig, msa_seed, msa_extra, mask_msa = MSAFeaturize(msa, ins, p_mask=0.0) + _, N, L = msa_seed.shape[:3] + B = 1 + # + idx_pdb = torch.arange(L).long().view(1, L) + # + seq = seq.unsqueeze(0) + msa_seed = msa_seed.unsqueeze(0) + msa_extra = msa_extra.unsqueeze(0) + t1d = t1d.to(self.device) + t2d = t2d.to(self.device) + idx_pdb = idx_pdb.to(self.device) + xyz_t = xyz_t.to(self.device) + alpha_t = alpha_t.to(self.device) + xyz = xyz.to(self.device) + + self.write_pdb(seq[0, -1], xyz[0], prefix="%s_templ"%(out_prefix)) + + msa_prev = None + pair_prev = None + alpha_prev = torch.zeros((1,L,10,2), device=seq.device) + xyz_prev=xyz + state_prev = None + + best_lddt = torch.tensor([-1.0], device=seq.device) + best_xyz = None + best_logit = None + best_aa = None + for i_cycle in range(MAX_CYCLE): + with torch.cuda.amp.autocast(True): + logit_s, logit_aa_s, init_crds, alpha_prev, init_allatom, pred_lddt_binned, msa_prev, pair_prev, state_prev = self.model( + msa_seed[:,i_cycle], + msa_extra[:,i_cycle], + seq[:,i_cycle], + xyz_prev, + alpha_prev, + idx_pdb, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev=msa_prev, + pair_prev=pair_prev, + state_prev=state_prev + ) + + logit_aa_s = logit_aa_s.reshape(B,-1,N,L)[:,:,0].permute(0,2,1) + + #xyz_prev = init_crds[-1] + xyz_prev = init_allatom[-1].unsqueeze(0) + #msa_prev = msa_prev[:,0] + alpha_prev = alpha_prev[-1] + pred_lddt = lddt_unbin(pred_lddt_binned) + + print ("RECYCLE", i_cycle, pred_lddt.mean(), best_lddt.mean()) + #_, all_crds = self.compute_allatom_coords(seq[:,i_cycle], init_crds[-1], alpha_prev) + self.write_pdb(seq[0, -1], init_allatom[-1], Bfacts=pred_lddt[0], 
prefix="%s_cycle_%02d"%(out_prefix, i_cycle)) + + if pred_lddt.mean() < best_lddt.mean(): + continue + best_xyz = init_allatom[-1].clone() + best_logit = logit_s + best_aa = logit_aa_s + best_lddt = pred_lddt.clone() + + prob_s = list() + for logit in logit_s: + prob = self.active_fn(logit.float()) # distogram + prob = prob.reshape(-1, L, L) #.permute(1,2,0).cpu().numpy() + prob_s.append(prob) + + end = time.time() + + for prob in prob_s: + prob += 1e-8 + prob = prob / torch.sum(prob, dim=0)[None] + self.write_pdb(seq[0, -1], best_xyz, Bfacts=best_lddt[0], prefix="%s_init"%(out_prefix)) + prob_s = [prob.permute(1,2,0).detach().cpu().numpy().astype(np.float16) for prob in prob_s] + np.savez_compressed("%s.npz"%(out_prefix), dist=prob_s[0].astype(np.float16), \ + omega=prob_s[1].astype(np.float16),\ + theta=prob_s[2].astype(np.float16),\ + phi=prob_s[3].astype(np.float16),\ + lddt=best_lddt[0].detach().cpu().numpy().astype(np.float16)) + + max_mem = torch.cuda.max_memory_allocated()/1e9 + print ("max mem", max_mem) + print ("runtime", end-start) + + + def write_pdb(self, seq, atoms, Bfacts=None, prefix=None): + L = len(seq) + filename = "%s.pdb"%prefix + ctr = 1 + with open(filename, 'wt') as f: + if Bfacts == None: + Bfacts = np.zeros(L) + else: + Bfacts = torch.clamp( Bfacts, 0, 1) + + for i,s in enumerate(seq): + if (len(atoms.shape)==2): + f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%( + "ATOM", ctr, " CA ", util.num2aa[s], + "A", i+1, atoms[i,0], atoms[i,1], atoms[i,2], + 1.0, Bfacts[i] ) ) + ctr += 1 + + elif atoms.shape[1]==3: + for j,atm_j in enumerate((" N "," CA "," C ")): + f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%( + "ATOM", ctr, atm_j, util.num2aa[s], + "A", i+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2], + 1.0, Bfacts[i] ) ) + ctr += 1 + else: + atms = util.aa2long[s] + for j,atm_j in enumerate(atms): + if (atm_j is not None): + f.write ("%-6s%5s %4s %3s %s%4d %8.3f%8.3f%8.3f%6.2f%6.2f\n"%( + "ATOM", ctr, atm_j, 
util.num2aa[s], + "A", i+1, atoms[i,j,0], atoms[i,j,1], atoms[i,j,2], + 1.0, Bfacts[i] ) ) + ctr += 1 + +def get_args(): + #DB="/home/robetta/rosetta_server_beta/external/databases/trRosetta/pdb100_2021Mar03/pdb100_2021Mar03" + DB = "/projects/ml/TrRosetta/pdb100_2020Mar11/pdb100_2020Mar11" + import argparse + parser = argparse.ArgumentParser(description="RoseTTAFold: Protein structure prediction with 3-track attentions on 1D, 2D, and 3D features") + parser.add_argument("-db", default=DB, required=False, + help="HHsearch database [%s]"%DB) + parser.add_argument("-model_name", default="BFF", required=False, + help="Prefix for model. The model under models/[model_name]_best.pt will be used. [BFF]") + parser.add_argument("-i", default=1, type=int, required=False, help="Parallelize i of j [1]") + parser.add_argument("-j", default=1, type=int, required=False, help="Parallelize i of j [1]") + args = parser.parse_args() + return args + +casp14_home = "/home/minkbaek/CASP14.bench" +if __name__ == "__main__": + args = get_args() + #tar_s = [line.strip() for line in open("%s/tar_s"%casp14_home)] + tar_s = [line.strip() for line in open("./tar_s")] + FFDB = args.db + FFindexDB = namedtuple("FFindexDB", "index, data") + ffdb = FFindexDB(read_index(FFDB+'_pdb.ffindex'), + read_data(FFDB+'_pdb.ffdata')) + pred = Predictor(model_name=args.model_name) + + for i_str,tar in enumerate(tar_s): + if (i_str % args.j == args.i % args.j): + print (tar) + if not os.path.exists(tar): + os.mkdir(tar) + out_prefix = "%s/%s"%(tar, tar) + a3m_fn = "%s/a3m.final/%s.a3m"%(casp14_home,tar) + hhr_fn = "%s/hhr.final/%s.hhr"%(casp14_home,tar) + atab_fn = "%s/hhr.final/%s.atab"%(casp14_home, tar) + if not os.path.exists("%s_04.npz"%out_prefix): + pred.predict(a3m_fn, out_prefix, hhr_fn, atab_fn) diff --git a/RF2_allatom/resnet.py b/RF2_allatom/resnet.py new file mode 100644 index 0000000..f9134d3 --- /dev/null +++ b/RF2_allatom/resnet.py @@ -0,0 +1,72 @@ +import torch +import torch.nn as nn +import 
torch.utils.checkpoint as checkpoint + +# pre-activation bottleneck resblock +class ResBlock2D_bottleneck(nn.Module): + def __init__(self, n_c, kernel=3, dilation=1, p_drop=0.15): + super(ResBlock2D_bottleneck, self).__init__() + padding = self._get_same_padding(kernel, dilation) + + n_b = n_c // 2 # bottleneck channel + + layer_s = list() + # pre-activation + layer_s.append(nn.InstanceNorm2d(n_c, affine=True, eps=1e-6)) + layer_s.append(nn.ELU(inplace=True)) + # project down to n_b + layer_s.append(nn.Conv2d(n_c, n_b, 1, bias=False)) + layer_s.append(nn.InstanceNorm2d(n_b, affine=True, eps=1e-6)) + layer_s.append(nn.ELU(inplace=True)) + # convolution + layer_s.append(nn.Conv2d(n_b, n_b, kernel, dilation=dilation, padding=padding, bias=False)) + layer_s.append(nn.InstanceNorm2d(n_b, affine=True, eps=1e-6)) + layer_s.append(nn.ELU(inplace=True)) + # dropout + layer_s.append(nn.Dropout(p_drop)) + # project up + layer_s.append(nn.Conv2d(n_b, n_c, 1, bias=False)) + + # make final layer initialize with zeros + #nn.init.zeros_(layer_s[-1].weight) + + self.layer = nn.Sequential(*layer_s) + + self.reset_parameter() + + def reset_parameter(self): + # zero-initialize final layer right before residual connection + nn.init.zeros_(self.layer[-1].weight) + + def _get_same_padding(self, kernel, dilation): + return (kernel + (kernel - 1) * (dilation - 1) - 1) // 2 + + def forward(self, x): + out = self.layer(x) + return x + out + +class ResidualNetwork(nn.Module): + def __init__(self, n_block, n_feat_in, n_feat_block, n_feat_out, + dilation=[1,2,4,8], p_drop=0.15): + super(ResidualNetwork, self).__init__() + + + layer_s = list() + # project to n_feat_block + if n_feat_in != n_feat_block: + layer_s.append(nn.Conv2d(n_feat_in, n_feat_block, 1, bias=False)) + + # add resblocks + for i_block in range(n_block): + d = dilation[i_block%len(dilation)] + res_block = ResBlock2D_bottleneck(n_feat_block, kernel=3, dilation=d, p_drop=p_drop) + layer_s.append(res_block) + + if n_feat_out != 
n_feat_block: + # project to n_feat_out + layer_s.append(nn.Conv2d(n_feat_block, n_feat_out, 1)) + + self.layer = nn.Sequential(*layer_s) + + def forward(self, x): + return self.layer(x) diff --git a/RF2_allatom/run.sh b/RF2_allatom/run.sh new file mode 100755 index 0000000..f560aa2 --- /dev/null +++ b/RF2_allatom/run.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +export CUDA_VISIBLE_DEVICES=1 + +python -u ./train_multi_EMA.py \ + -model_name BFF20h \ + -p_drop 0.0 \ + -maxcycle 4 \ + -n_extra_block 4 \ + -n_main_block 32 \ + -n_ref_block 4 \ + -n_finetune_block 0 \ + -ref_num_layers 2 \ + -accum 4 \ + -crop 256 \ + -w_bond 0.0 \ + -w_dih 0.0 \ + -w_clash 0.0 \ + -w_hb 0.0 \ + -lj_lin 0.7 \ + -w_dist 1.0 \ + -w_str 10.0 \ + -w_lddt 0.1 \ + -w_aa 3.0 \ + -subsmp UNI \ + -num_epochs 400 \ + -slice CONT \ + -lr 0.001 \ + -port 12345 \ + -eval diff --git a/RF2_allatom/scheduler.py b/RF2_allatom/scheduler.py new file mode 100644 index 0000000..8b8150c --- /dev/null +++ b/RF2_allatom/scheduler.py @@ -0,0 +1,180 @@ +import math +import torch +from torch.optim.lr_scheduler import _LRScheduler, LambdaLR + +#def get_cosine_with_hard_restarts_schedule_with_warmup( +# optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1 +#): +# """ +# Create a schedule with a learning rate that decreases following the values of the cosine function between the +# initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases +# linearly between 0 and the initial lr set in the optimizer. +# +# Args: +# optimizer (:class:`~torch.optim.Optimizer`): +# The optimizer for which to schedule the learning rate. +# num_warmup_steps (:obj:`int`): +# The number of steps for the warmup phase. +# num_training_steps (:obj:`int`): +# The total number of training steps. +# num_cycles (:obj:`int`, `optional`, defaults to 1): +# The number of hard restarts to use. 
+# last_epoch (:obj:`int`, `optional`, defaults to -1): +# The index of the last epoch when resuming training. +# +# Return: +# :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. +# """ +# +# def lr_lambda(current_step): +# if current_step < num_warmup_steps: +# return float(current_step) / float(max(1, num_warmup_steps)) +# progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) +# if progress >= 1.0: +# return 0.0 +# return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0)))) +# +# return LambdaLR(optimizer, lr_lambda, last_epoch) +# + +class CosineAnnealingWarmupRestarts(_LRScheduler): + """ + optimizer (Optimizer): Wrapped optimizer. + first_cycle_steps (int): First cycle step size. + cycle_mult(float): Cycle steps magnification. Default: -1. + max_lr(float): First cycle's max learning rate. Default: 0.1. + min_lr(float): Min learning rate. Default: 0.001. + warmup_steps(int): Linear warmup step size. Default: 0. + gamma(float): Decrease rate of max learning rate by cycle. Default: 1. + last_epoch (int): The index of last epoch. Default: -1. 
+ """ + + def __init__(self, + optimizer : torch.optim.Optimizer, + first_cycle_steps : int, + cycle_mult : float = 1., + max_lr : float = 0.1, + min_lr : float = 0.001, + warmup_steps : int = 0, + gamma : float = 1., + last_epoch : int = -1 + ): + assert warmup_steps < first_cycle_steps + + self.first_cycle_steps = first_cycle_steps # first cycle step size + self.cycle_mult = cycle_mult # cycle steps magnification + self.base_max_lr = max_lr # first max learning rate + self.max_lr = max_lr # max learning rate in the current cycle + self.min_lr = min_lr # min learning rate + self.warmup_steps = warmup_steps # warmup step size + self.gamma = gamma # decrease rate of max learning rate by cycle + + self.cur_cycle_steps = first_cycle_steps # first cycle step size + self.cycle = 0 # cycle count + self.step_in_cycle = last_epoch # step size of the current cycle + + super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch) + + # set learning rate min_lr + self.init_lr() + + def init_lr(self): + self.base_lrs = [] + for param_group in self.optimizer.param_groups: + param_group['lr'] = self.min_lr + self.base_lrs.append(self.min_lr) + + def get_lr(self): + if self.step_in_cycle == -1: + return self.base_lrs + elif self.step_in_cycle < self.warmup_steps: + return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs] + else: + return [base_lr + (self.max_lr - base_lr) \ + * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \ + / (self.cur_cycle_steps - self.warmup_steps))) / 2 + for base_lr in self.base_lrs] + + def step(self, epoch=None): + if epoch is None: + epoch = self.last_epoch + 1 + self.step_in_cycle = self.step_in_cycle + 1 + if self.step_in_cycle >= self.cur_cycle_steps: + self.cycle += 1 + self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps + self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps + else: + if epoch >= 
self.first_cycle_steps: + if self.cycle_mult == 1.: + self.step_in_cycle = epoch % self.first_cycle_steps + self.cycle = epoch // self.first_cycle_steps + else: + n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult)) + self.cycle = n + self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1)) + self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n) + else: + self.cur_cycle_steps = self.first_cycle_steps + self.step_in_cycle = epoch + + self.max_lr = self.base_max_lr * (self.gamma**self.cycle) + self.last_epoch = math.floor(epoch) + for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): + param_group['lr'] = lr + + +def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_ratio=0.001, last_epoch=-1): + """ + Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer, floored at ``min_ratio`` times that lr, after + a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer (:class:`~torch.optim.Optimizer`): + The optimizer for which to schedule the learning rate. + num_warmup_steps (:obj:`int`): + The number of steps for the warmup phase. + num_training_steps (:obj:`int`): + The total number of training steps. + last_epoch (:obj:`int`, `optional`, defaults to -1): + The index of the last epoch when resuming training. + + Return: + :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
+ """ + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max( + min_ratio, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) + ) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + +def get_stepwise_decay_schedule_with_warmup(optimizer, num_warmup_steps, num_steps_decay, decay_rate, last_epoch=-1): + """ + Create a schedule with a learning rate that is multiplied by ``decay_rate`` once every ``num_steps_decay`` steps, after + a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. + + Args: + optimizer (:class:`~torch.optim.Optimizer`): + The optimizer for which to schedule the learning rate. + num_warmup_steps (:obj:`int`): + The number of steps for the warmup phase. + num_steps_decay, decay_rate: + After warmup, the lr factor is ``decay_rate`` raised to the number of completed ``num_steps_decay``-step intervals. + last_epoch (:obj:`int`, `optional`, defaults to -1): + The index of the last epoch when resuming training. + + Return: + :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. 
+ """ + + def lr_lambda(current_step: int): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + + num_fades = (current_step-num_warmup_steps)//num_steps_decay + return (decay_rate**num_fades) + + return LambdaLR(optimizer, lr_lambda, last_epoch) diff --git a/RF2_allatom/scoring.py b/RF2_allatom/scoring.py new file mode 100644 index 0000000..541a26a --- /dev/null +++ b/RF2_allatom/scoring.py @@ -0,0 +1,310 @@ +import json + +## +## lk and lk term +#(LJ_RADIUS LJ_WDEPTH LK_DGFREE LK_LAMBDA LK_VOLUME) +type2ljlk = { + "CNH2":(1.968297,0.094638,3.077030,3.5000,13.500000), + "COO":(1.916661,0.141799,-3.332648,3.5000,14.653000), + "CH0":(2.011760,0.062642,1.409284,3.5000,8.998000), + "CH1":(2.011760,0.062642,-3.538387,3.5000,10.686000), + "CH2":(2.011760,0.062642,-1.854658,3.5000,18.331000), + "CH3":(2.011760,0.062642,7.292929,3.5000,25.855000), + "aroC":(2.016441,0.068775,1.797950,3.5000,16.704000), + "Ntrp":(1.802452,0.161725,-8.413116,3.5000,9.522100), + "Nhis":(1.802452,0.161725,-9.739606,3.5000,9.317700), + "NtrR":(1.802452,0.161725,-5.158080,3.5000,9.779200), + "NH2O":(1.802452,0.161725,-8.101638,3.5000,15.689000), + "Nlys":(1.802452,0.161725,-20.864641,3.5000,16.514000), + "Narg":(1.802452,0.161725,-8.968351,3.5000,15.717000), + "Npro":(1.802452,0.161725,-0.984585,3.5000,3.718100), + "OH":(1.542743,0.161947,-8.133520,3.5000,10.722000), + "OHY":(1.542743,0.161947,-8.133520,3.5000,10.722000), + "ONH2":(1.548662,0.182924,-6.591644,3.5000,10.102000), + "OOC":(1.492871,0.099873,-9.239832,3.5000,9.995600), + "S":(1.975967,0.455970,-1.707229,3.5000,17.640000), + "SH1":(1.975967,0.455970,3.291643,3.5000,23.240000), + "Nbb":(1.802452,0.161725,-9.969494,3.5000,15.992000), + "CAbb":(2.011760,0.062642,2.533791,3.5000,12.137000), + "CObb":(1.916661,0.141799,3.104248,3.5000,13.221000), + "OCbb":(1.540580,0.142417,-8.006829,3.5000,12.196000), + "Phos":(2.1500,0.5850,-4.1000,3.5000,14.7000), # phil + 
"Oet2":(1.5500,0.1591,-5.8500,3.5000,10.8000), + "Oet3":(1.5500,0.1591,-6.7000,3.5000,10.8000), + "HNbb":(0.901681,0.005000,0.0000,3.5000,0.0000), + "Hapo":(1.421272,0.021808,0.0000,3.5000,0.0000), + "Haro":(1.374914,0.015909,0.0000,3.5000,0.0000), + "Hpol":(0.901681,0.005000,0.0000,3.5000,0.0000), + "HS":(0.363887,0.050836,0.0000,3.5000,0.0000), +} + +# cartbonded +with open('cartbonded.json', 'r') as j: + cartbonded_data_raw = json.loads(j.read()) + +# hbond donor/acceptors +class HbAtom: + NO = 0 + DO = 1 # donor + AC = 2 # acceptor + DA = 3 # donor & acceptor + HP = 4 # polar H + +type2hb = { + "CNH2":HbAtom.NO, "COO":HbAtom.NO, "CH0":HbAtom.NO, "CH1":HbAtom.NO, + "CH2":HbAtom.NO, "CH3":HbAtom.NO, "aroC":HbAtom.NO, "Ntrp":HbAtom.DO, + "Nhis":HbAtom.AC, "NtrR":HbAtom.DO, "NH2O":HbAtom.DO, "Nlys":HbAtom.DO, + "Narg":HbAtom.DO, "Npro":HbAtom.NO, "OH":HbAtom.DA, "OHY":HbAtom.DA, + "ONH2":HbAtom.AC, "OOC":HbAtom.AC, "S":HbAtom.NO, "SH1":HbAtom.NO, + "Nbb":HbAtom.DO, "CAbb":HbAtom.NO, "CObb":HbAtom.NO, "OCbb":HbAtom.AC, + "HNbb":HbAtom.HP, "Hapo":HbAtom.NO, "Haro":HbAtom.NO, "Hpol":HbAtom.HP, + "HS":HbAtom.HP, # HP in rosetta(?) 
+ "Phos":HbAtom.NO, "Oet2":HbAtom.AC, "Oet3":HbAtom.AC +} + +## +## hbond term +## TO DO: ADD DNA +class HbDonType: + PBA = 0 + IND = 1 + IME = 2 + GDE = 3 + CXA = 4 + AMO = 5 + HXL = 6 + AHX = 7 + NTYPES = 8 + +class HbAccType: + PBA = 0 + CXA = 1 + CXL = 2 + HXL = 3 + AHX = 4 + IME = 5 + NTYPES = 6 + +class HbHybType: + SP2 = 0 + SP3 = 1 + RING = 2 + NTYPES = 3 + +type2dontype = { + "Nbb": HbDonType.PBA, + "Ntrp": HbDonType.IND, + "NtrR": HbDonType.GDE, + "Narg": HbDonType.GDE, + "NH2O": HbDonType.CXA, + "Nlys": HbDonType.AMO, + "OH": HbDonType.HXL, + "OHY": HbDonType.AHX, +} + +type2acctype = { + "OCbb": HbAccType.PBA, + "ONH2": HbAccType.CXA, + "OOC": HbAccType.CXL, + "OH": HbAccType.HXL, + "OHY": HbAccType.AHX, + "Nhis": HbAccType.IME, +} + +type2hybtype = { + "OCbb": HbHybType.SP2, + "ONH2": HbHybType.SP2, + "OOC": HbHybType.SP2, + "OHY": HbHybType.SP3, + "OH": HbHybType.SP3, + "Nhis": HbHybType.RING, +} + +dontype2wt = { + HbDonType.PBA: 1.45, + HbDonType.IND: 1.15, + HbDonType.IME: 1.42, + HbDonType.GDE: 1.11, + HbDonType.CXA: 1.29, + HbDonType.AMO: 1.17, + HbDonType.HXL: 0.99, + HbDonType.AHX: 1.00, +} + +acctype2wt = { + HbAccType.PBA: 1.19, + HbAccType.CXA: 1.21, + HbAccType.CXL: 1.10, + HbAccType.HXL: 1.15, + HbAccType.AHX: 1.15, + HbAccType.IME: 1.17, +} + +class HbPolyType: + ahdist_aASN_dARG = 0 + ahdist_aASN_dASN = 1 + ahdist_aASN_dGLY = 2 + ahdist_aASN_dHIS = 3 + ahdist_aASN_dLYS = 4 + ahdist_aASN_dSER = 5 + ahdist_aASN_dTRP = 6 + ahdist_aASN_dTYR = 7 + ahdist_aASP_dARG = 8 + ahdist_aASP_dASN = 9 + ahdist_aASP_dGLY = 10 + ahdist_aASP_dHIS = 11 + ahdist_aASP_dLYS = 12 + ahdist_aASP_dSER = 13 + ahdist_aASP_dTRP = 14 + ahdist_aASP_dTYR = 15 + ahdist_aGLY_dARG = 16 + ahdist_aGLY_dASN = 17 + ahdist_aGLY_dGLY = 18 + ahdist_aGLY_dHIS = 19 + ahdist_aGLY_dLYS = 20 + ahdist_aGLY_dSER = 21 + ahdist_aGLY_dTRP = 22 + ahdist_aGLY_dTYR = 23 + ahdist_aHIS_dARG = 24 + ahdist_aHIS_dASN = 25 + ahdist_aHIS_dGLY = 26 + ahdist_aHIS_dHIS = 27 + ahdist_aHIS_dLYS = 28 + 
ahdist_aHIS_dSER = 29 + ahdist_aHIS_dTRP = 30 + ahdist_aHIS_dTYR = 31 + ahdist_aSER_dARG = 32 + ahdist_aSER_dASN = 33 + ahdist_aSER_dGLY = 34 + ahdist_aSER_dHIS = 35 + ahdist_aSER_dLYS = 36 + ahdist_aSER_dSER = 37 + ahdist_aSER_dTRP = 38 + ahdist_aSER_dTYR = 39 + ahdist_aTYR_dARG = 40 + ahdist_aTYR_dASN = 41 + ahdist_aTYR_dGLY = 42 + ahdist_aTYR_dHIS = 43 + ahdist_aTYR_dLYS = 44 + ahdist_aTYR_dSER = 45 + ahdist_aTYR_dTRP = 46 + ahdist_aTYR_dTYR = 47 + cosBAH_off = 48 + cosBAH_7 = 49 + cosBAH_6i = 50 + AHD_1h = 51 + AHD_1i = 52 + AHD_1j = 53 + AHD_1k = 54 + +# map donor:acceptor pairs to polynomials +hbtypepair2poly = { + (HbDonType.PBA,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1j), + (HbDonType.CXA,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1j), + (HbDonType.IME,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1j), + (HbDonType.IND,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1j), + (HbDonType.AMO,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h), + (HbDonType.GDE,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1j), + (HbDonType.AHX,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.HXL,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.PBA,HbAccType.CXA): (HbPolyType.ahdist_aASN_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.CXA,HbAccType.CXA): (HbPolyType.ahdist_aASN_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.IME,HbAccType.CXA): (HbPolyType.ahdist_aASN_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.IND,HbAccType.CXA): (HbPolyType.ahdist_aASN_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.AMO,HbAccType.CXA): (HbPolyType.ahdist_aASN_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h), + 
(HbDonType.GDE,HbAccType.CXA): (HbPolyType.ahdist_aASN_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.AHX,HbAccType.CXA): (HbPolyType.ahdist_aASN_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.HXL,HbAccType.CXA): (HbPolyType.ahdist_aASN_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.PBA,HbAccType.CXL): (HbPolyType.ahdist_aASP_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.CXA,HbAccType.CXL): (HbPolyType.ahdist_aASP_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.IME,HbAccType.CXL): (HbPolyType.ahdist_aASP_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.IND,HbAccType.CXL): (HbPolyType.ahdist_aASP_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.AMO,HbAccType.CXL): (HbPolyType.ahdist_aASP_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h), + (HbDonType.GDE,HbAccType.CXL): (HbPolyType.ahdist_aASP_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.AHX,HbAccType.CXL): (HbPolyType.ahdist_aASP_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.HXL,HbAccType.CXL): (HbPolyType.ahdist_aASP_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k), + (HbDonType.PBA,HbAccType.IME): (HbPolyType.ahdist_aHIS_dGLY,HbPolyType.cosBAH_7,HbPolyType.AHD_1i), + (HbDonType.CXA,HbAccType.IME): (HbPolyType.ahdist_aHIS_dASN,HbPolyType.cosBAH_7,HbPolyType.AHD_1i), + (HbDonType.IME,HbAccType.IME): (HbPolyType.ahdist_aHIS_dHIS,HbPolyType.cosBAH_7,HbPolyType.AHD_1h), + (HbDonType.IND,HbAccType.IME): (HbPolyType.ahdist_aHIS_dTRP,HbPolyType.cosBAH_7,HbPolyType.AHD_1h), + (HbDonType.AMO,HbAccType.IME): (HbPolyType.ahdist_aHIS_dLYS,HbPolyType.cosBAH_7,HbPolyType.AHD_1i), + (HbDonType.GDE,HbAccType.IME): (HbPolyType.ahdist_aHIS_dARG,HbPolyType.cosBAH_7,HbPolyType.AHD_1h), + (HbDonType.AHX,HbAccType.IME): (HbPolyType.ahdist_aHIS_dTYR,HbPolyType.cosBAH_7,HbPolyType.AHD_1i), + (HbDonType.HXL,HbAccType.IME): (HbPolyType.ahdist_aHIS_dSER,HbPolyType.cosBAH_7,HbPolyType.AHD_1i), + (HbDonType.PBA,HbAccType.AHX): 
(HbPolyType.ahdist_aTYR_dGLY,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.CXA,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dASN,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.IME,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dHIS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.IND,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dTRP,HbPolyType.cosBAH_6i,HbPolyType.AHD_1h), + (HbDonType.AMO,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dLYS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.GDE,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dARG,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.AHX,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dTYR,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.HXL,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dSER,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.PBA,HbAccType.HXL): (HbPolyType.ahdist_aSER_dGLY,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.CXA,HbAccType.HXL): (HbPolyType.ahdist_aSER_dASN,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.IME,HbAccType.HXL): (HbPolyType.ahdist_aSER_dHIS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.IND,HbAccType.HXL): (HbPolyType.ahdist_aSER_dTRP,HbPolyType.cosBAH_6i,HbPolyType.AHD_1h), + (HbDonType.AMO,HbAccType.HXL): (HbPolyType.ahdist_aSER_dLYS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.GDE,HbAccType.HXL): (HbPolyType.ahdist_aSER_dARG,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.AHX,HbAccType.HXL): (HbPolyType.ahdist_aSER_dTYR,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), + (HbDonType.HXL,HbAccType.HXL): (HbPolyType.ahdist_aSER_dSER,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i), +} + + +# polynomials are triplets, (x_min, x_max), (y[xx_max]), (c_9,...,c_0) +hbpolytype2coeffs = { # Parameters imported from rosetta sp2_elec_params @v2017.48-dev59886 + HbPolyType.ahdist_aASN_dARG: ((0.7019094761929999, 2.86820307153,),(1.1, 1.1,),( 0.58376113, -9.29345473, 64.86270904, -260.3946711, 661.43138077, -1098.01378958, 1183.58371466, -790.82929582, 291.33125475, 
-43.01629727,)), + HbPolyType.ahdist_aASN_dASN: ((0.625841094801, 2.75107708444,),(1.1, 1.1,),( -1.31243015, 18.6745072, -112.63858313, 373.32878091, -734.99145504, 861.38324861, -556.21026097, 143.5626977, 20.03238394, -11.52167705,)), + HbPolyType.ahdist_aASN_dGLY: ((0.7477341047139999, 2.6796350782799996,),(1.1, 1.1,),( -1.61294554, 23.3150793, -144.11313069, 496.13575, -1037.83809166, 1348.76826073, -1065.14368678, 473.89008925, -100.41142701, 7.44453515,)), + HbPolyType.ahdist_aASN_dHIS: ((0.344789524346, 2.8303582266000005,),(1.1, 1.1,),( -0.2657122, 4.1073775, -26.9099632, 97.10486507, -209.96002602, 277.33057268, -218.74766996, 97.42852213, -24.07382402, 3.73962807,)), + HbPolyType.ahdist_aASN_dLYS: ((0.542905671869, 2.45259389314,),(1.1, 1.1,),( 1.38531754, -18.48733797, 106.14444613, -344.70585054, 698.91577956, -917.0879402, 775.32787908, -403.09588787, 113.65054778, -11.66516403,)), + HbPolyType.ahdist_aASN_dSER: ((1.0812774602500002, 2.6832123582599996,),(1.1, 1.1,),( -3.51524353, 47.54032873, -254.40168577, 617.84606386, -255.49935027, -2361.56230539, 6426.85797934, -7760.4403891, 4694.08106855, -1149.83549068,)), + HbPolyType.ahdist_aASN_dTRP: ((0.6689984999999999, 3.0704254,),(1.1, 1.1,),( -0.5284840422, 8.3510150838, -56.4100479414, 212.4884326254, -488.3178610608, 703.7762350506, -628.9936994633999, 331.4294356146, -93.265817571, 11.9691623698,)), + HbPolyType.ahdist_aASN_dTYR: ((1.08950268805, 2.6887046709400004,),(1.1, 1.1,),( -4.4488705, 63.27696281, -371.44187037, 1121.71921621, -1638.11394306, 142.99988401, 3436.65879147, -5496.07011787, 3709.30505237, -962.79669688,)), + HbPolyType.ahdist_aASP_dARG: ((0.8100404642229999, 2.9851230124799994,),(1.1, 1.1,),( -0.66430344, 10.41343145, -70.12656205, 265.12578414, -617.05849171, 911.39378582, -847.25013928, 472.09090981, -141.71513167, 18.57721132,)), + HbPolyType.ahdist_aASP_dASN: ((1.05401125073, 3.11129675908,),(1.1, 1.1,),( 0.02090728, -0.24144928, -0.19578075, 16.80904547, -117.70216251, 
407.18551288, -809.95195924, 939.83137947, -593.94527692, 159.57610528,)), + HbPolyType.ahdist_aASP_dGLY: ((0.886260952629, 2.66843608743,),(1.1, 1.1,),( -7.00699267, 107.33021779, -713.45752385, 2694.43092298, -6353.05100287, 9667.94098394, -9461.9261027, 5721.0086877, -1933.97818198, 279.47763789,)), + HbPolyType.ahdist_aASP_dHIS: ((1.03597611139, 2.78208509117,),(1.1, 1.1,),( -1.34823406, 17.08925926, -78.75087193, 106.32795459, 400.18459698, -2041.04320193, 4033.83557387, -4239.60530204, 2324.00877252, -519.38410941,)), + HbPolyType.ahdist_aASP_dLYS: ((0.97789485082, 2.50496946108,),(1.1, 1.1,),( -0.41300315, 6.59243438, -44.44525308, 163.11796012, -351.2307798, 443.2463146, -297.84582856, 62.38600547, 33.77496227, -14.11652182,)), + HbPolyType.ahdist_aASP_dSER: ((0.542905671869, 2.45259389314,),(1.1, 1.1,),( 1.38531754, -18.48733797, 106.14444613, -344.70585054, 698.91577956, -917.0879402, 775.32787908, -403.09588787, 113.65054778, -11.66516403,)), + HbPolyType.ahdist_aASP_dTRP: ((0.419155746414, 3.0486938610500003,),(1.1, 1.1,),( -0.24563471, 3.85598551, -25.75176874, 95.36525025, -214.13175785, 299.76133553, -259.0691378, 132.06975835, -37.15612683, 5.60445773,)), + HbPolyType.ahdist_aASP_dTYR: ((1.01057521468, 2.7207545786900003,),(1.1, 1.1,),( -0.15808672, -10.21398871, 178.80080949, -1238.0583801, 4736.25248274, -11071.96777725, 16239.07550047, -14593.21092621, 7335.66765017, -1575.08145078,)), + HbPolyType.ahdist_aGLY_dARG: ((0.499016667857, 2.9377031027599996,),(1.1, 1.1,),( -0.15923533, 2.5526639, -17.38788803, 65.71046957, -151.13491186, 218.78048387, -199.15882919, 110.56568974, -35.95143745, 6.47580213,)), + HbPolyType.ahdist_aGLY_dASN: ((0.7194388032060001, 2.9303772333599998,),(1.1, 1.1,),( -1.40718342, 23.65929694, -172.97144348, 720.64417348, -1882.85420815, 3194.87197776, -3515.52467458, 2415.75238278, -941.47705161, 159.84784277,)), + HbPolyType.ahdist_aGLY_dGLY: ((1.38403812683, 2.9981039433,),(1.1, 1.1,),( -0.5307601, 6.47949946, 
-22.39522814, -55.14303544, 708.30945242, -2619.49318162, 5227.8805795, -6043.31211632, 3806.04676175, -1007.66024144,)), + HbPolyType.ahdist_aGLY_dHIS: ((0.47406840932899996, 2.9234200830400003,),(1.1, 1.1,),( -0.12881679, 1.933838, -12.03134888, 39.92691227, -75.41519959, 78.87968016, -37.82769801, -0.13178679, 4.50193019, 0.45408359,)), + HbPolyType.ahdist_aGLY_dLYS: ((0.545347533475, 2.42624380351,),(1.1, 1.1,),( -0.22921901, 2.07015714, -6.2947417, 0.66645697, 45.21805416, -130.26668981, 176.32401031, -126.68226346, 43.96744431, -4.40105281,)), + HbPolyType.ahdist_aGLY_dSER: ((1.2803349239700001, 2.2465996077400003,),(1.1, 1.1,),( 6.72508613, -86.98495585, 454.18518444, -1119.89141452, 715.624663, 3172.36852982, -9455.49113097, 11797.38766934, -7363.28302948, 1885.50119665,)), + HbPolyType.ahdist_aGLY_dTRP: ((0.686512740494, 3.02901351815,),(1.1, 1.1,),( -0.1051487, 1.41597708, -7.42149173, 17.31830704, -6.98293652, -54.76605063, 130.95272289, -132.77575305, 62.75460448, -9.89110842,)), + HbPolyType.ahdist_aGLY_dTYR: ((1.28894687639, 2.26335316892,),(1.1, 1.1,),( 13.84536925, -169.40579865, 893.79467505, -2670.60617561, 5016.46234701, -6293.79378818, 5585.1049063, -3683.50722701, 1709.48661405, -399.5712153,)), + HbPolyType.ahdist_aHIS_dARG: ((0.8967400957230001, 2.96809434226,),(1.1, 1.1,),( 0.43460495, -10.52727665, 103.16979807, -551.42887412, 1793.25378923, -3701.08304991, 4861.05155388, -3922.4285529, 1763.82137881, -335.43441944,)), + HbPolyType.ahdist_aHIS_dASN: ((0.887120931718, 2.59166903153,),(1.1, 1.1,),( -3.50289894, 54.42813924, -368.14395507, 1418.90186454, -3425.60485859, 5360.92334837, -5428.54462336, 3424.68800187, -1221.49631986, 189.27122436,)), + HbPolyType.ahdist_aHIS_dGLY: ((1.01629363411, 2.58523052904,),(1.1, 1.1,),( -1.68095217, 21.31894078, -107.72203494, 251.81021758, -134.07465831, -707.64527046, 1894.6282743, -2156.85951846, 1216.83585872, -275.48078944,)), + HbPolyType.ahdist_aHIS_dHIS: ((0.9773010778919999, 2.72533796329,),(1.1, 
1.1,),( -2.33350626, 35.66072412, -233.98966111, 859.13714961, -1925.30958567, 2685.35293578, -2257.48067507, 1021.49796136, -169.36082523, -12.1348055,)), + HbPolyType.ahdist_aHIS_dLYS: ((0.7080936539849999, 2.47191718632,),(1.1, 1.1,),( -1.88479369, 28.38084382, -185.74039957, 690.81875917, -1605.11404391, 2414.83545623, -2355.9723201, 1442.24496229, -506.45880637, 79.47512505,)), + HbPolyType.ahdist_aHIS_dSER: ((0.90846809159, 2.5477956147,),(1.1, 1.1,),( -0.92004641, 15.91841533, -117.83979251, 488.22211296, -1244.13047376, 2017.43704053, -2076.04468019, 1302.42621488, -451.29138643, 67.15812575,)), + HbPolyType.ahdist_aHIS_dTRP: ((0.991999676806, 2.81296584506,),(1.1, 1.1,),( -1.29358587, 19.97152857, -131.89796017, 485.29199356, -1084.0466445, 1497.3352889, -1234.58042682, 535.8048197, -75.58951691, -9.91148332,)), + HbPolyType.ahdist_aHIS_dTYR: ((0.882661836357, 2.5469016429900004,),(1.1, 1.1,),( -6.94700143, 109.07997256, -747.64035726, 2929.83959536, -7220.15788571, 11583.34170519, -12078.443492, 7881.85479715, -2918.19482068, 468.23988622,)), + HbPolyType.ahdist_aSER_dARG: ((1.0204658147399999, 2.8899566041900004,),(1.1, 1.1,),( 0.33887327, -7.54511361, 70.87316645, -371.88263665, 1206.67454443, -2516.82084076, 3379.45432693, -2819.73384601, 1325.33307517, -265.54533008,)), + HbPolyType.ahdist_aSER_dASN: ((1.01393052233, 3.0024434159299997,),(1.1, 1.1,),( 0.37012361, -7.46486204, 64.85775924, -318.6047209, 974.66322243, -1924.37334018, 2451.63840629, -1943.1915675, 867.07870559, -163.83771761,)), + HbPolyType.ahdist_aSER_dGLY: ((1.3856562156299999, 2.74160605537,),(1.1, 1.1,),( -1.32847415, 22.67528654, -172.53450064, 770.79034865, -2233.48829652, 4354.38807288, -5697.35144236, 4803.38686157, -2361.48028857, 518.28202382,)), + HbPolyType.ahdist_aSER_dHIS: ((0.550992321207, 2.68549261999,),(1.1, 1.1,),( -1.98041793, 29.59668639, -190.36751773, 688.43324385, -1534.68894765, 2175.66568976, -1952.07622113, 1066.28943929, -324.23381388, 43.41006168,)), + 
HbPolyType.ahdist_aSER_dLYS: ((0.8603189393170001, 2.77729502744,),(1.1, 1.1,),( 0.90884741, -17.24690746, 141.78469099, -661.85989315, 1929.7674992, -3636.43392779, 4419.00727923, -3332.43482061, 1410.78913266, -253.53829424,)), + HbPolyType.ahdist_aSER_dSER: ((1.10866545921, 2.61727781204,),(1.1, 1.1,),( -0.38264308, 4.41779675, -10.7016645, -81.91314845, 668.91174735, -2187.50684758, 3983.56103269, -4213.32320546, 2418.41531442, -580.28918569,)), + HbPolyType.ahdist_aSER_dTRP: ((1.4092077245899999, 2.8066121197099996,),(1.1, 1.1,),( 0.73762477, -11.70741276, 73.05154232, -205.00144794, 89.58794368, 1082.94541375, -3343.98293188, 4601.70815729, -3178.53568678, 896.59487831,)), + HbPolyType.ahdist_aSER_dTYR: ((1.10773547919, 2.60403567341,),(1.1, 1.1,),( -1.13249925, 14.66643161, -69.01708791, 93.96846742, 380.56063898, -1984.56675689, 4074.08891127, -4492.76927139, 2613.13168054, -627.71933508,)), + HbPolyType.ahdist_aTYR_dARG: ((1.05581400627, 2.85499888099,),(1.1, 1.1,),( -0.30396592, 5.30288548, -39.75788579, 167.5416547, -435.15958911, 716.52357586, -735.95195083, 439.76284677, -130.00400085, 13.23827556,)), + HbPolyType.ahdist_aTYR_dASN: ((1.0994919065200002, 2.8400869077900004,),(1.1, 1.1,),( 0.33548259, -3.5890451, 8.97769025, 48.1492734, -400.5983616, 1269.89613211, -2238.03101675, 2298.33009115, -1290.42961162, 308.43185147,)), + HbPolyType.ahdist_aTYR_dGLY: ((1.36546155066, 2.7303075916400004,),(1.1, 1.1,),( -1.55312915, 18.62092487, -70.91365499, -41.83066505, 1248.88835245, -4719.81948329, 9186.09528168, -10266.11434548, 6266.21959533, -1622.19652457,)), + HbPolyType.ahdist_aTYR_dHIS: ((0.5955982461899999, 2.6643551317500003,),(1.1, 1.1,),( -0.47442788, 7.16629863, -46.71287553, 171.46128947, -388.17484011, 558.45202337, -506.35587481, 276.46237273, -83.52554392, 12.05709329,)), + HbPolyType.ahdist_aTYR_dLYS: ((0.7978598238760001, 2.7620933782,),(1.1, 1.1,),( -0.20201464, 1.69684984, 0.27677515, -55.05786347, 286.29918332, -725.92372531, 1054.771746, 
-889.33602341, 401.11342256, -73.02221189,)), + HbPolyType.ahdist_aTYR_dSER: ((0.7083554962559999, 2.7032011990599996,),(1.1, 1.1,),( -0.70764192, 11.67978065, -82.80447482, 329.83401367, -810.58976486, 1269.57613941, -1261.04047117, 761.72890446, -254.37526011, 37.24301861,)), + HbPolyType.ahdist_aTYR_dTRP: ((1.10934023051, 2.8819112108,),(1.1, 1.1,),( -11.58453967, 204.88308091, -1589.77384548, 7100.84791905, -20113.61354433, 37457.83646055, -45850.02969172, 35559.8805122, -15854.78726237, 3098.04931146,)), + HbPolyType.ahdist_aTYR_dTYR: ((1.1105954899400001, 2.60081798685,),(1.1, 1.1,),( -1.63120628, 19.48493187, -81.0332905, 56.80517706, 687.42717782, -2842.77799908, 5385.52231471, -5656.74159307, 3178.83470588, -744.70042777,)), + HbPolyType.AHD_1h: ((1.76555274367, 3.1416,),(1.1, 1.1,),( 0.62725838, -9.98558225, 59.39060071, -120.82930213, -333.26536028, 2603.13082592, -6895.51207142, 9651.25238056, -7127.13394872, 2194.77244026,)), + HbPolyType.AHD_1i: ((1.59914724347, 3.1416,),(1.1, 1.1,),( -0.18888801, 3.48241679, -25.65508662, 89.57085435, -95.91708218, -367.93452341, 1589.6904702, -2662.3582135, 2184.40194483, -723.28383545,)), + HbPolyType.AHD_1j: ((1.1435646388, 3.1416,),(1.1, 1.1,),( 0.47683259, -9.54524724, 83.62557693, -420.55867774, 1337.19354878, -2786.26265686, 3803.178227, -3278.62879901, 1619.04116204, -347.50157909,)), + HbPolyType.AHD_1k: ((1.15651981164, 3.1416,),(1.1, 1.1,),( -0.10757999, 2.0276542, -16.51949978, 75.83866839, -214.18025678, 380.55117567, -415.47847283, 255.66998474, -69.94662165, 3.21313428,)), + HbPolyType.cosBAH_off: ((-1234.0, 1.1,),(1.1, 1.1,),( 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,)), + HbPolyType.cosBAH_6i: ((-0.23538144897100002, 1.1,),(1.1, 1.1,),( -0.822093, -3.75364636, 46.88852157, -129.5440564, 146.69151428, -67.60598792, 2.91683129, 9.26673173, -3.84488178, 0.05706659,)), + HbPolyType.cosBAH_7: ((-0.019373850666900002, 1.1,),(1.1, 1.1,),( 0.0, -27.942923450028, 136.039920253368, -268.06959056747, 
275.400462507919, -153.502076215949, 39.741591385461, 0.693861510121, -3.885952320499, 1.024765090788892)), +} \ No newline at end of file diff --git a/RF2_allatom/tests.py b/RF2_allatom/tests.py new file mode 100644 index 0000000..25d8a3d --- /dev/null +++ b/RF2_allatom/tests.py @@ -0,0 +1,227 @@ +import unittest +import torch +from torch.utils import data + +from chemical import NFRAMES +from data_loader import get_train_valid_set, Dataset, DatasetNAComplex, DatasetRNA, DatasetSMComplex, loader_pdb, loader_na_complex, loader_rna, loader_sm_compl,set_data_loader_params +from kinematics import xyz_to_c6d, xyz_to_t2d +from loss import compute_general_FAPE, resolve_equiv_natives, calc_str_loss +from util import get_frames, frame_indices, is_atom, xyz_to_frame_xyz, xyz_t_to_frame_xyz + +class LossTestCase(unittest.TestCase): + + def setUp(self): + self.loader_param = set_data_loader_params({}) + ( + pdb_items, fb_items, compl_items, neg_items, na_compl_items, na_neg_items, rna_items, + sm_compl_items, valid_pdb, valid_homo, valid_compl, valid_neg, valid_na_compl, + valid_na_neg, valid_rna, valid_sm_compl, homo + ) = get_train_valid_set(self.loader_param) + + pdb_IDs, pdb_weights, pdb_dict = pdb_items + na_compl_IDs, na_compl_weights, na_compl_dict = na_compl_items + rna_IDs, rna_weights, rna_dict = rna_items + sm_compl_IDs, sm_compl_weights, sm_compl_dict = sm_compl_items + self.homo = homo + + valid_pdb_set = Dataset( + list(valid_pdb.keys()), + loader_pdb, valid_pdb, + self.loader_param, homo, p_homo_cut=-1.0 + ) + valid_na_compl_set = DatasetNAComplex( + list(valid_na_compl.keys()), + loader_na_complex, valid_na_compl, + self.loader_param, negative=False, native_NA_frac=1.0 + ) + valid_sm_compl_set = DatasetSMComplex( + list(sm_compl_dict.keys()), + loader_sm_compl, sm_compl_dict, + self.loader_param + ) + self.valid_pdb_loader = data.DataLoader(valid_pdb_set) + self.valid_na_compl_loader = data.DataLoader(valid_na_compl_set) + self.valid_sm_compl_loader = 
def test_compute_general_FAPE(self):
    """Exercise compute_general_FAPE on one batch from several validation loaders.

    Every subtest checks that (a) comparing a structure against itself gives a
    loss that truncates to 0, and (b) the loss grows as progressively larger
    uniform noise is added to the coordinates.  The last subtest additionally
    restricts the frame mask to the first frame and compares the result with
    calc_str_loss.
    """
    with self.subTest("test that FAPE loss is correctly calculated for proteins"):
        self._check_fape_on_first_batch(self.valid_pdb_loader)
    #add noise and make sure increasing noise increases loss
    with self.subTest("test that FAPE loss is correctly calculated for protein/NA complexes"):
        self._check_fape_on_first_batch(self.valid_na_compl_loader)
    with self.subTest("test that FAPE loss is correctly calculated for protein/SM complexes"):
        self._check_fape_on_first_batch(self.valid_sm_compl_loader, resolve_natives=True)
    with self.subTest("test that protein backbone FAPE loss can be calculated with compute_general_FAPE"):
        self._check_fape_on_first_batch(self.valid_pdb_loader, backbone_only=True)

def _check_fape_on_first_batch(self, loader, resolve_natives=False, backbone_only=False):
    """Run the shared FAPE checks on the first batch produced by *loader*.

    Args:
        loader: iterable yielding 15-tuples in the order unpacked below.
        resolve_natives: if True, pass the coordinates through
            resolve_equiv_natives before building frames.
        backbone_only: if True, keep only the first frame of every residue in
            the frame mask and also compare each perturbed FAPE value against
            calc_str_loss.

    Does nothing when the loader yields no batch.
    """
    batch = next(iter(loader), None)
    if batch is None:
        return
    (seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb,
     xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames) = batch

    if resolve_natives:
        true_crds, atom_mask = resolve_equiv_natives(true_crds[0, 0].unsqueeze(0), true_crds, atom_mask)

    frames, frame_mask = get_frames(
        true_crds, atom_mask, msa[:, 0, 0], frame_indices, atom_frames
    )
    if backbone_only:
        frame_mask[...,1:] = False

    # first assert that same structure gives you 0 loss
    # NOTE(review): int() truncates, so any loss in (-1, 1) passes this check.
    l_fape = compute_general_FAPE(
        true_crds,
        true_crds,
        atom_mask,
        frames,
        frame_mask
    )
    self.assertAlmostEqual(int(l_fape.numpy()),0)

    if backbone_only:
        res_mask = ~((atom_mask[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(msa[:,0,0])))
        mask_2d = res_mask[:,None,:] * res_mask[:,:,None]

    fapes = []
    for i in range(5):
        # noise amplitude grows with i, so the loss is expected to grow too
        perturbed_crds = true_crds+(torch.rand(true_crds.shape)*(i+1))
        l_fape = compute_general_FAPE(
            perturbed_crds,
            true_crds,
            atom_mask,
            frames,
            frame_mask
        )
        fapes.append(l_fape)

        if backbone_only:
            tot_str, str_loss = calc_str_loss(perturbed_crds.unsqueeze(0), true_crds, mask_2d, same_chain, negative=False,
                                              A=10.0, d_clamp_intra=10.0, d_clamp_inter=10.0, gamma=1.0, eps=1e-4)
            self.assertAlmostEqual(int(l_fape.numpy()), int(tot_str.numpy()))

    for i in range(1,5):
        self.assertLess(fapes[i-1], fapes[i])
natoms, _ = atom_crds.shape + # frames_reindex = torch.zeros(atom_frames.shape[:-1]) + # for i in range(atom_L): + # frames_reindex[:, i, :] = (i+atom_frames[..., i, :, 0])*natoms + atom_frames[..., i, :, 1] + # frames_reindex = frames_reindex.long() + # true_crds[atoms, :, :3] = atom_crds.reshape(atom_L*natoms, 3)[frames_reindex] + true_crds = xyz_to_frame_xyz(true_crds, msa[:, 0,0], atom_frames) + c6d, _ = xyz_to_c6d(true_crds) + + xyz_t = xyz_t_to_frame_xyz(xyz_t, msa[:, 0,0].squeeze(0), atom_frames) + t2d = xyz_to_t2d(xyz_t) + break + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/RF2_allatom/train_multi_EMA.py b/RF2_allatom/train_multi_EMA.py new file mode 100644 index 0000000..1bf3dc3 --- /dev/null +++ b/RF2_allatom/train_multi_EMA.py @@ -0,0 +1,1591 @@ +import sys, os +import time +import numpy as np +from copy import deepcopy +from collections import OrderedDict +import torch +import torch.nn as nn +from torch.utils import data +from functools import partial +from data_loader import ( + get_train_valid_set, loader_pdb, loader_fb, loader_complex, loader_na_complex, loader_rna, loader_sm_compl, + Dataset, DatasetComplex, DatasetNAComplex, DatasetRNA, DatasetSMComplex, DistilledDataset, DistributedWeightedSampler +) +from kinematics import xyz_to_c6d, c6d_to_bins, xyz_to_t2d, xyz_to_bbtor, get_init_xyz +from RoseTTAFoldModel import RoseTTAFoldModule +from loss import * +from util import * +from util_module import ComputeAllAtomCoords +from scheduler import get_linear_schedule_with_warmup, get_stepwise_decay_schedule_with_warmup + +# distributed data parallel +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.nn.parallel import DistributedDataParallel as DDP +#torch.autograd.set_detect_anomaly(True) +torch.manual_seed(5924) +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True + +## To reproduce errors +#import random +np.random.seed(6636) +#random.seed(0) + 
class EMA(nn.Module):
    """Pair a model with an exponential-moving-average ("shadow") copy of it.

    In training mode ``forward`` runs the wrapped model; otherwise it runs the
    shadow copy.  ``update`` moves the shadow parameters towards the current
    model parameters and copies the buffers across.

    Args:
        model: the module to wrap; a deep copy of it becomes the shadow.
        decay: EMA decay factor used in ``update``.
    """

    def __init__(self, model, decay):
        super().__init__()
        self.decay = decay

        self.model = model
        self.shadow = deepcopy(self.model)

        # the shadow copy is never optimised directly
        for param in self.shadow.parameters():
            param.detach_()

    @torch.no_grad()
    def update(self):
        """Blend the current model weights into the shadow copy.

        Only has an effect in training mode; in eval mode a warning is written
        to standard error and the shadow is left untouched.
        """
        if not self.training:
            # was ``file=stderr``: only ``sys`` is imported by this module, so
            # the bare name is not defined by this file's own imports
            print("EMA update should only be called during training", file=sys.stderr, flush=True)
            return

        model_params = OrderedDict(self.model.named_parameters())
        shadow_params = OrderedDict(self.shadow.named_parameters())

        # check if both model contains the same set of keys
        assert model_params.keys() == shadow_params.keys()

        for name, param in model_params.items():
            # see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
            # shadow_variable -= (1 - decay) * (shadow_variable - variable)
            if param.requires_grad:
                shadow_params[name].sub_((1. - self.decay) * (shadow_params[name] - param))

        model_buffers = OrderedDict(self.model.named_buffers())
        shadow_buffers = OrderedDict(self.shadow.named_buffers())

        # check if both model contains the same set of keys
        assert model_buffers.keys() == shadow_buffers.keys()

        for name, buffer in model_buffers.items():
            # buffers are copied
            shadow_buffers[name].copy_(buffer)

    def forward(self, *args, **kwargs):
        """Run the live model when training, the shadow copy otherwise."""
        if self.training:
            return self.model(*args, **kwargs)
        else:
            return self.shadow(*args, **kwargs)
= nn.CrossEntropyLoss(reduction='none') + self.active_fn = nn.Softmax(dim=1) + + self.maxcycle = maxcycle + + self.pdb_counter=0 + + def calc_loss(self, logit_s, label_s, + logit_aa_s, label_aa_s, mask_aa_s, + pred, pred_tors, pred_allatom, true, + mask_crds, mask_BB, mask_2d, same_chain, + pred_lddt, idx, atom_frames=None, unclamp=False, negative=False, interface=False, + verbose=False, ctr=0, + w_dist=1.0, w_aa=1.0, w_str=1.0, w_lddt=1.0, w_bond=1.0, w_clash=0.0, w_hb=0.0, w_dih=0.0, + lj_lin=0.85, eps=1e-6 + ): + B, L = true.shape[:2] + seq = label_aa_s[:,0].clone() + + assert (B==1) # fd - code assumes a batch size of 1 + + loss_s = list() + tot_loss = 0.0 + + # c6d loss + for i in range(4): + loss = self.loss_fn(logit_s[i], label_s[...,i]) # (B, L, L) + loss = (mask_2d*loss).sum() / (mask_2d.sum() + eps) + tot_loss += w_dist*loss + loss_s.append(loss[None].detach()) + + # masked token prediction loss + loss = self.loss_fn(logit_aa_s, label_aa_s.reshape(B, -1)) + loss = loss * mask_aa_s.reshape(B, -1) + loss = loss.sum() / (mask_aa_s.sum() + 1e-8) + tot_loss += w_aa*loss + loss_s.append(loss[None].detach()) + + ### GENERAL LAYERS + # Structural loss + dclamp = 300.0 if unclamp else 30.0 + frames, frame_mask = get_frames( + pred_allatom[-1,None,...], mask_crds, seq, self.fi_dev, atom_frames) + frame_mask_BB = frame_mask.clone() + frame_mask_BB[...,1:] =False + if negative: # inter-chain fapes should be ignored for negative cases + L1 = same_chain[0,0,:].sum() + mask_BBA = mask_BB.clone() + mask_BBA[0, L1:] = False + l_fape_A = compute_general_FAPE( + pred_allatom[:,mask_BBA[0],:,:3], + true[:,mask_BBA[0],:,:3], + mask_crds[:,mask_BBA[0]], + frames[:,mask_BBA[0]], + frame_mask_BB[:,mask_BBA[0]], + dclamp=dclamp + ) + mask_BBB = mask_BB.clone() + mask_BBB[0,:L1] = False + l_fape_B = compute_general_FAPE( + pred_allatom[:,mask_BBB[0],:,:3], + true[:,mask_BBB[0],:,:3], + mask_crds[:,mask_BBB[0]], + frames[:,mask_BBB[0]], + frame_mask_BB[:,mask_BBB[0]], + 
dclamp=dclamp + ) + fracA = float(L1)/len(same_chain[0,0]) + tot_str = fracA*l_fape_A + (1.0-fracA)*l_fape_B + + else: + tot_str = compute_general_FAPE( + pred_allatom[:,mask_BB[0],:,:3], + true[:,mask_BB[0],:,:3], + mask_crds[:,mask_BB[0]], + frames[:,mask_BB[0]], + frame_mask_BB[:,mask_BB[0]], + dclamp=dclamp + ) + tot_loss += 0.5*w_str*tot_str[0] + loss_s.append(tot_str.detach()) + + # AllAtom loss + # get ground-truth torsion angles + true_tors, true_tors_alt, tors_mask, tors_planar = get_torsions( + true, seq, self.ti_dev, self.ti_flip, self.ang_ref, mask_in=mask_crds) + tors_mask *= mask_BB[...,None] + + # get alternative coordinates for ground-truth + true_alt = torch.zeros_like(true) + true_alt.scatter_(2, self.l2a[seq,:,None].repeat(1,1,1,3), true) + natRs_all, _n0 = self.compute_allatom_coords(seq, true[...,:3,:], true_tors) + natRs_all_alt, _n1 = self.compute_allatom_coords(seq, true_alt[...,:3,:], true_tors_alt) + predTs = pred[-1,...] + predRs_all, pred_all = self.compute_allatom_coords(seq, predTs, pred_tors[-1]) + + # - resolve symmetry + xs_mask = self.aamask[seq] # (B, L, 27) + xs_mask[0,:,14:]=False # (ignore hydrogens except lj loss) + xs_mask *= mask_crds # mask missing atoms & residues as well + natRs_all_symm, nat_symm = resolve_symmetry(pred_allatom[-1], natRs_all[0], true[0], natRs_all_alt[0], true_alt[0], xs_mask[0]) + + # torsion angle loss + l_tors = torsionAngleLoss( + pred_tors, + true_tors, + true_tors_alt, + tors_mask, + tors_planar, + eps = 1e-10) + tot_loss += w_str*l_tors + loss_s.append(l_tors[None].detach()) + + ### FINETUNING LAYERS + # lddts (CA) + ca_lddt = calc_lddt(pred[:,:,:,1].detach(), true[:,:,1], mask_BB, mask_2d, same_chain, negative=negative, interface=interface) + loss_s.append(ca_lddt.detach()) + + # lddts (allatom) + lddt loss + lddt_loss, allatom_lddt = calc_allatom_lddt_loss( + pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, same_chain, negative=negative, interface=interface) + tot_loss += 
w_lddt*lddt_loss + loss_s.append(lddt_loss.detach()[None]) + loss_s.append(allatom_lddt.detach()) + + # FAPE losses + # allatom fape and torsion angle loss + # frames, frame_mask = get_frames( + # pred_allatom[-1,None,...], mask_crds, seq, self.fi_dev, atom_frames) + + if negative: # inter-chain fapes should be ignored for negative cases + # L1 = same_chain[0,0,:].sum() + # mask_BBA = mask_BB.clone() + # mask_BBA[0, L1:] = False + l_fape_A = compute_general_FAPE( + pred_allatom[:,mask_BBA[0],:,:3], + nat_symm[None,mask_BBA[0],:,:3], + xs_mask[:,mask_BBA[0]], + frames[:,mask_BBA[0]], + frame_mask[:,mask_BBA[0]] + ) + # mask_BBB = mask_BB.clone() + # mask_BBB[0,:L1] = False + l_fape_B = compute_general_FAPE( + pred_allatom[:,mask_BBB[0],:,:3], + nat_symm[None,mask_BBB[0],:,:3], + xs_mask[:,mask_BBB[0]], + frames[:,mask_BBB[0]], + frame_mask[:,mask_BBB[0]] + ) + fracA = float(L1)/len(same_chain[0,0]) + l_fape = fracA*l_fape_A + (1.0-fracA)*l_fape_B + + else: + l_fape = compute_general_FAPE( + pred_allatom[:,mask_BB[0],:,:3], + nat_symm[None,mask_BB[0],:,:3], + xs_mask[:,mask_BB[0]], + frames[:,mask_BB[0]], + frame_mask[:,mask_BB[0]] + ) + + loss_s.append(l_fape.detach()) + tot_loss += w_str*l_fape.mean() + + # cart bonded (bond geometry) + bond_loss = calc_BB_bond_geom(seq[0], pred_allatom[0:1], idx) + if w_bond > 0.0: + tot_loss += w_bond*bond_loss + loss_s.append( bond_loss[None].detach() ) + + if (pred_allatom.shape[0] > 1): + bond_loss = calc_cart_bonded(seq, pred_allatom[1:], idx, self.cb_len, self.cb_ang, self.cb_tor) + if w_bond > 0.0: + tot_loss += w_bond*bond_loss.mean() + loss_s.append( bond_loss.detach() ) + + # clash [use all atoms not just those in native] + clash_loss = calc_lj( + seq[0], pred_allatom, + self.aamask, self.ljlk_parameters, self.lj_correction_parameters, self.num_bonds, + lj_lin=lj_lin + ) + if w_clash > 0.0: + tot_loss += w_clash*clash_loss.mean() + loss_s.append( clash_loss.detach() ) + + # hbond [use all atoms not just those in native] 
def calc_acc(self, prob, dist, idx_pdb, mask_2d, return_cnt=False):
    """Score predicted contacts against reference contacts.

    A reference contact is a pair whose ``dist`` value is below 20; the
    predicted contact score is the sum of the first 20 channels of ``prob``.
    Only upper-triangle pairs with ``|idx_i - idx_j| + 1 > 24`` that are also
    kept by ``mask_2d`` are considered.  Pairs scoring strictly above the
    L-th best score of their batch entry are treated as predicted contacts.

    Returns:
        A tensor ``[precision, recall, F1]``; when ``return_cnt`` is true,
        also the masked contact scores and the masked reference contact map.
    """
    n_batch, n_res = idx_pdb.shape[0], idx_pdb.shape[1]  # (B, L)

    # pairs far enough apart in sequence, upper triangle only
    pair_mask = torch.triu((torch.abs(idx_pdb[:, :, None] - idx_pdb[:, None, :]) + 1 > 24).float())
    pair_mask *= mask_2d

    cnt_ref = (dist < 20).float() * pair_mask
    cnt_pred = prob[:, :20, :, :].sum(dim=1) * pair_mask

    # per-batch threshold: the smallest of the top-L predicted scores
    kth = torch.topk(cnt_pred.view(n_batch, -1), n_res).values.min(dim=-1).values
    picked = torch.stack([cnt_pred[b] > kth[b] for b in range(n_batch)], dim=0)
    picked = picked.float() * pair_mask

    is_ref = cnt_ref == torch.ones_like(cnt_ref)
    is_picked = picked == torch.ones_like(picked)

    n_good = torch.logical_and(picked == cnt_ref, is_ref).float().sum()
    n_total = is_ref.float().sum() + 1e-9
    n_total_pred = is_picked.float().sum() + 1e-9

    prec = n_good / n_total_pred
    recall = n_good / n_total
    F1 = 2.0 * prec * recall / (prec + recall + 1e-9)

    scores = torch.stack([prec, recall, F1])
    if return_cnt:
        return scores, cnt_pred, cnt_ref
    return scores
loaded_epoch = -1 + best_valid_loss = 999999.9 + if not os.path.exists(chk_fn): + print ('no model found', model_name) + return -1, best_valid_loss + print ('loading model', model_name) + map_location = {"cuda:%d"%0: "cuda:%d"%rank} + checkpoint = torch.load(chk_fn, map_location=map_location) + rename_model = False + new_chk = {} + msd_src = checkpoint['model_state_dict'] + msd_tgt = model.module.model.state_dict() + for param in msd_tgt: + + if param not in msd_src: + print ('missing',param) + rename_model=True + #break + elif (msd_tgt[param].shape == msd_src[param].shape): + new_chk[param] = msd_src[param] + else: + # fd hack for new encoding + if (msd_src[param].shape[0]==30 and msd_tgt[param].shape[0]==32 and 'compute_allatom_coords' not in param): + print ('Fixing',param) + new_chk[param] = torch.zeros_like(msd_tgt[param]) + new_chk[param][:26] = msd_src[param][:26] + new_chk[param][27:31] = msd_src[param][26:30] + + else: + #wrong size latent_emb.emb.weight torch.Size([256, 64]) torch.Size([256, 68]) + #wrong size templ_emb.emb.weight torch.Size([64, 104]) torch.Size([64, 108]) + #wrong size full_emb.emb.weight torch.Size([64, 33]) torch.Size([64, 35]) + + print ( + 'wrong size',param, + checkpoint['model_state_dict'][param].shape, + model.module.model.state_dict()[param].shape ) + rename_model=True + + #new_chk = checkpoint['model_state_dict'] + model.module.model.load_state_dict(new_chk, strict=False) + model.module.shadow.load_state_dict(new_chk, strict=False) + if resume_train and (not rename_model): + print (' ... loading optimization params') + loaded_epoch = checkpoint['epoch'] + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + scaler.load_state_dict(checkpoint['scaler_state_dict']) + if 'scheduler_state_dict' in checkpoint: + print (' ... 
def checkpoint_fn(self, model_name, description):
    """Return the checkpoint path ``models/<model_name>_<description>.pt``.

    The ``models`` directory is created in the current working directory if
    it does not exist yet.  ``exist_ok=True`` replaces the previous
    check-then-mkdir sequence, which could raise FileExistsError when
    another process created the directory between the check and the mkdir.
    """
    os.makedirs("models", exist_ok=True)
    name = "%s_%s.pt"%(model_name, description)
    return os.path.join("models", name)
pdb_items + fb_IDs, fb_weights, fb_dict = fb_items + compl_IDs, compl_weights, compl_dict = compl_items + neg_IDs, neg_weights, neg_dict = neg_items + na_compl_IDs, na_compl_weights, na_compl_dict = na_compl_items + na_neg_IDs, na_neg_weights, na_neg_dict = na_neg_items + rna_IDs, rna_weights, rna_dict = rna_items + sm_compl_IDs, sm_compl_weights, sm_compl_dict = sm_compl_items + + self.n_train = N_EXAMPLE_PER_EPOCH + self.n_valid_pdb = len(valid_pdb.keys()) + #self.n_valid_pdb = (self.n_valid_pdb // world_size)*world_size + self.n_valid_homo = len(valid_homo.keys()) + #self.n_valid_homo = (self.n_valid_homo // world_size)*world_size + self.n_valid_compl = len(valid_compl.keys()) + #self.n_valid_compl = (self.n_valid_compl // world_size)*world_size + self.n_valid_neg = len(valid_neg.keys()) + #self.n_valid_neg = (self.n_valid_neg // world_size)*world_size + self.n_valid_na_compl = len(valid_na_compl.keys()) + #self.n_valid_na_compl = (self.n_valid_na_compl // world_size)*world_size + self.n_valid_na_neg = len(valid_na_neg.keys()) + #self.n_valid_na_neg = (self.n_valid_na_neg // world_size)*world_size + self.n_valid_rna = len(valid_rna.keys()) + #self.n_valid_rna = (self.n_valid_rna // world_size)*world_size + self.n_valid_rna = len(valid_rna.keys()) + self.n_valid_sm_compl = len(valid_sm_compl.keys()) + + #self.n_valid_pdb = 4 + #self.n_valid_homo = 4 + #self.n_valid_compl = 4 + #self.n_valid_neg = 4 + #self.n_valid_na_compl = 4 + #self.n_valid_na_neg = 4 + #self.n_valid_rna = 4 + + if (rank==0): + print ('Loaded (training)', + len(pdb_IDs),'monomers/homomers,', + len(fb_IDs),'distilled monomers,', + len(compl_IDs),'heteromers,', + len(neg_IDs),'negative heteromers,', + len(na_compl_IDs),'nucleic-acid complexes,', + len(na_neg_IDs),'negative nucleic-acid complexes,', + len(rna_IDs),'RNA structures, and', + len(sm_compl_IDs), 'small molecule complexes' + ) + print ('Loaded (valid)', + len(valid_pdb.keys()),'monomers,', + len(valid_homo.keys()),'homomers,', + 
len(valid_compl.keys()),'heteromers,', + len(valid_neg.keys()),'negative heteromers,', + len(valid_na_compl.keys()),'nucleic-acid complexes,', + len(valid_na_neg.keys()),'negative nucleic-acid complexes,', + len(valid_rna),'RNA structures, and', + len(valid_sm_compl), 'small molecule complexes' + ) + print ('Using', + self.n_valid_pdb,'monomers,', + self.n_valid_homo,'homomers,', + self.n_valid_compl,'heteromers,', + self.n_valid_neg,'negative heteromers', + self.n_valid_na_compl,'nucleic-acid complexes,', + self.n_valid_na_neg,'negative nucleic-acid complexes,', + self.n_valid_rna,'RNA structures, and', + self.n_valid_sm_compl, 'small molecule complexes' + ) + + train_set = DistilledDataset( + pdb_IDs, loader_pdb, pdb_dict, + compl_IDs, loader_complex, compl_dict, + neg_IDs, loader_complex, neg_dict, + na_compl_IDs, loader_na_complex, na_compl_dict, + na_neg_IDs, loader_na_complex, na_neg_dict, + fb_IDs, loader_fb, fb_dict, + rna_IDs, loader_rna, rna_dict, + sm_compl_IDs, loader_sm_compl, sm_compl_dict, + homo, + self.loader_param, + native_NA_frac=0.25 + ) + + valid_pdb_set = Dataset( + list(valid_pdb.keys())[:self.n_valid_pdb], + loader_pdb, valid_pdb, + self.loader_param, homo, p_homo_cut=-1.0 + ) + valid_homo_set = Dataset( + list(valid_homo.keys())[:self.n_valid_homo], + loader_pdb, valid_homo, + self.loader_param, homo, p_homo_cut=2.0 + ) + valid_compl_set = DatasetComplex( + list(valid_compl.keys())[:self.n_valid_compl], + loader_complex, valid_compl, + self.loader_param, negative=False + ) + valid_neg_set = DatasetComplex( + list(valid_neg.keys())[:self.n_valid_neg], + loader_complex, valid_neg, + self.loader_param, negative=True + ) + valid_na_compl_set = DatasetNAComplex( + list(valid_na_compl.keys())[:self.n_valid_na_compl], + loader_na_complex, valid_na_compl, + self.loader_param, negative=False, native_NA_frac=1.0 + ) + valid_na_neg_set = DatasetNAComplex( + list(valid_na_neg.keys())[:self.n_valid_na_neg], + loader_na_complex, valid_na_neg, + 
self.loader_param, negative=True, native_NA_frac=1.0 + ) + valid_na_from_scratch_compl_set = DatasetNAComplex( + list(valid_na_compl.keys())[:self.n_valid_na_compl], + loader_na_complex, valid_na_compl, + self.loader_param, negative=False, native_NA_frac=0.0 + ) + valid_na_from_scratch_neg_set = DatasetNAComplex( + list(valid_na_neg.keys())[:self.n_valid_na_neg], + loader_na_complex, valid_na_neg, + self.loader_param, negative=True, native_NA_frac=0.0 + ) + valid_rna_set = DatasetRNA( + list(valid_rna.keys())[:self.n_valid_rna], + loader_rna, valid_rna, + self.loader_param + ) + valid_sm_compl_set = DatasetSMComplex( + list(valid_sm_compl.keys())[:self.n_valid_sm_compl], + loader_sm_compl, valid_sm_compl, + self.loader_param + ) + + train_sampler = DistributedWeightedSampler( + train_set, + pdb_weights, + fb_weights, + compl_weights, + neg_weights, + na_compl_weights, + na_neg_weights, + rna_weights, + sm_compl_weights, + num_example_per_epoch=N_EXAMPLE_PER_EPOCH, + num_replicas=world_size, + rank=rank, + fraction_fb=0.0, + fraction_compl=0.0, + fraction_na_compl=0.0, + fraction_rna=0.0, + fraction_sm_compl=1.0, + replacement=True + ) + + valid_pdb_sampler = data.distributed.DistributedSampler(valid_pdb_set, num_replicas=world_size, rank=rank) + valid_homo_sampler = data.distributed.DistributedSampler(valid_homo_set, num_replicas=world_size, rank=rank) + valid_compl_sampler = data.distributed.DistributedSampler(valid_compl_set, num_replicas=world_size, rank=rank) + valid_neg_sampler = data.distributed.DistributedSampler(valid_neg_set, num_replicas=world_size, rank=rank) + valid_na_compl_sampler = data.distributed.DistributedSampler(valid_na_compl_set, num_replicas=world_size, rank=rank) + valid_na_neg_sampler = data.distributed.DistributedSampler(valid_na_neg_set, num_replicas=world_size, rank=rank) + valid_na_from_scratch_compl_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_compl_set, num_replicas=world_size, rank=rank) + 
valid_na_from_scratch_neg_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_neg_set, num_replicas=world_size, rank=rank) + valid_rna_sampler = data.distributed.DistributedSampler(valid_rna_set, num_replicas=world_size, rank=rank) + valid_sm_compl_sampler = data.distributed.DistributedSampler(valid_sm_compl_set, num_replicas=world_size, rank=rank) + + train_loader = data.DataLoader(train_set, sampler=train_sampler, batch_size=self.batch_size, **LOAD_PARAM) + valid_pdb_loader = data.DataLoader(valid_pdb_set, sampler=valid_pdb_sampler, **LOAD_PARAM) + valid_homo_loader = data.DataLoader(valid_homo_set, sampler=valid_homo_sampler, **LOAD_PARAM) + valid_compl_loader = data.DataLoader(valid_compl_set, sampler=valid_compl_sampler, **LOAD_PARAM) + valid_neg_loader = data.DataLoader(valid_neg_set, sampler=valid_neg_sampler, **LOAD_PARAM) + valid_na_compl_loader = data.DataLoader(valid_na_compl_set, sampler=valid_na_compl_sampler, **LOAD_PARAM) + valid_na_neg_loader = data.DataLoader(valid_na_neg_set, sampler=valid_na_neg_sampler, **LOAD_PARAM) + valid_na_from_scratch_compl_loader = data.DataLoader(valid_na_from_scratch_compl_set, sampler=valid_na_from_scratch_compl_sampler, **LOAD_PARAM) + valid_na_from_scratch_neg_loader = data.DataLoader(valid_na_from_scratch_neg_set, sampler=valid_na_from_scratch_neg_sampler, **LOAD_PARAM) + valid_rna_loader = data.DataLoader(valid_rna_set, sampler=valid_rna_sampler, **LOAD_PARAM) + valid_sm_compl_loader = data.DataLoader(valid_sm_compl_set, sampler=valid_sm_compl_sampler, **LOAD_PARAM) + + # move some global data to cuda device + self.ti_dev = self.ti_dev.to(gpu) + self.ti_flip = self.ti_flip.to(gpu) + self.ang_ref = self.ang_ref.to(gpu) + self.fi_dev = self.fi_dev.to(gpu) + self.l2a = self.l2a.to(gpu) + self.aamask = self.aamask.to(gpu) + self.compute_allatom_coords = self.compute_allatom_coords.to(gpu) + self.num_bonds = self.num_bonds.to(gpu) + self.atom_type_index = self.atom_type_index.to(gpu) + 
self.ljlk_parameters = self.ljlk_parameters.to(gpu) + self.lj_correction_parameters = self.lj_correction_parameters.to(gpu) + self.hbtypes = self.hbtypes.to(gpu) + self.hbbaseatoms = self.hbbaseatoms.to(gpu) + self.hbpolys = self.hbpolys.to(gpu) + self.cb_len = self.cb_len.to(gpu) + self.cb_ang = self.cb_ang.to(gpu) + self.cb_tor = self.cb_tor.to(gpu) + + # define model + model = EMA(RoseTTAFoldModule( + **self.model_param, + aamask=self.aamask, + atom_type_index=self.atom_type_index, + ljlk_parameters=self.ljlk_parameters, + lj_correction_parameters=self.lj_correction_parameters, + num_bonds=self.num_bonds, + cb_len = self.cb_len, + cb_ang = self.cb_ang, + cb_tor = self.cb_tor, + lj_lin=self.loss_param['lj_lin'] + ).to(gpu), 0.999) + + #for n,p in model.named_parameters(): + # if ("finetune_refiner" not in n and "residue_embed" not in n and "allatom_embed" not in n): + # p.requires_grad_(False) + + ddp_model = DDP(model, device_ids=[gpu], find_unused_parameters=False) + if rank == 0: + print ("# of parameters:", count_parameters(ddp_model)) + + # define optimizer and scheduler + opt_params = add_weight_decay(ddp_model, self.l2_coeff) + optimizer = torch.optim.AdamW(opt_params, lr=self.init_lr) + #scheduler = get_stepwise_decay_schedule_with_warmup(optimizer, 1000, 5000, 0.95) + scheduler = get_stepwise_decay_schedule_with_warmup(optimizer, 0, 5000, 0.95) + scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP) + + # load model + loaded_epoch, best_valid_loss = self.load_model(ddp_model, optimizer, scheduler, scaler, + self.model_name, gpu, resume_train=True) + + if (self.eval): + # run protein/NA prediction (TEMPLATED) + #_, _, _ = self.valid_ppi_cycle( + # ddp_model, valid_na_compl_loader, valid_na_neg_loader, + # rank, gpu, world_size, 0, header="NA", report_interface=False, verbose=True) + + # run protein/NA prediction (NON-TEMPLATED) + _, _, _ = self.valid_ppi_cycle( + ddp_model, valid_na_from_scratch_compl_loader, valid_na_from_scratch_neg_loader, + rank, gpu, 
world_size, 0, header="NA", report_interface=False, verbose=True) + + # run RNA prediction + #_,_,_ = self.valid_pdb_cycle(ddp_model, valid_rna_loader, rank, gpu, world_size, 0, verbose=True) + + dist.destroy_process_group() + return + + if loaded_epoch >= self.n_epoch: + DDP_cleanup() + return + + #_, _, _ = self.valid_pdb_cycle(ddp_model, valid_homo_loader, rank, gpu, world_size, epoch, header="Homo") + #_, _, _ = self.valid_ppi_cycle(ddp_model, valid_compl_loader, valid_neg_loader, rank, gpu, world_size, epoch, report_interface=True) + #_, _, _ = self.valid_ppi_cycle( + # ddp_model, valid_na_compl_loader, valid_na_neg_loader, + # rank, gpu, world_size, epoch, header="NA", report_interface=False) + #_, _, _ = self.valid_ppi_cycle( + # ddp_model, valid_na_from_scratch_compl_loader, valid_na_from_scratch_neg_loader, + # rank, gpu, world_size, epoch, header="NAfs", report_interface=False) + #_,_,_ = self.valid_pdb_cycle(ddp_model, valid_rna_loader, rank, gpu, world_size, epoch, header="RNA") + + for epoch in range(loaded_epoch+1, self.n_epoch): + train_sampler.set_epoch(epoch) + valid_pdb_sampler.set_epoch(epoch) + valid_homo_sampler.set_epoch(epoch) + valid_compl_sampler.set_epoch(epoch) + valid_neg_sampler.set_epoch(epoch) + + train_tot, train_loss, train_acc = self.train_cycle(ddp_model, train_loader, optimizer, scheduler, scaler, rank, gpu, world_size, epoch) + + valid_tot, valid_loss, valid_acc = self.valid_pdb_cycle(ddp_model, valid_pdb_loader, rank, gpu, world_size, epoch) + #_, _, _ = self.valid_pdb_cycle(ddp_model, valid_homo_loader, rank, gpu, world_size, epoch, header="Homo") + #_, _, _ = self.valid_ppi_cycle(ddp_model, valid_compl_loader, valid_neg_loader, rank, gpu, world_size, epoch, report_interface=True) + _, _, _ = self.valid_ppi_cycle( + ddp_model, valid_na_compl_loader, valid_na_neg_loader, + rank, gpu, world_size, epoch, header="NA", report_interface=False) + _, _, _ = self.valid_ppi_cycle( + ddp_model, valid_na_from_scratch_compl_loader, 
valid_na_from_scratch_neg_loader, + rank, gpu, world_size, epoch, header="NAfs", report_interface=False) + _,_,_ = self.valid_pdb_cycle(ddp_model, valid_rna_loader, rank, gpu, world_size, epoch, header="RNA") + _,_,_ = self.valid_pdb_cycle(ddp_model, valid_sm_compl_loader, rank, gpu, world_size, epoch, header="SM Compl") + + if rank == 0: # save model + if valid_tot < best_valid_loss: + best_valid_loss = valid_tot + torch.save({'epoch': epoch, + #'model_state_dict': ddp_model.state_dict(), + 'model_state_dict': ddp_model.module.shadow.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), + 'scaler_state_dict': scaler.state_dict(), + 'best_loss': best_valid_loss, + 'train_loss': train_loss, + 'train_acc': train_acc, + 'valid_loss': valid_loss, + 'valid_acc': valid_acc}, + self.checkpoint_fn(self.model_name, 'best')) + + + torch.save({'epoch': epoch, + #'model_state_dict': ddp_model.state_dict(), + 'model_state_dict': ddp_model.module.shadow.state_dict(), + 'final_state_dict': ddp_model.module.model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), + 'scaler_state_dict': scaler.state_dict(), + 'train_loss': train_loss, + 'train_acc': train_acc, + 'valid_loss': valid_loss, + 'valid_acc': valid_acc, + 'best_loss': best_valid_loss}, + self.checkpoint_fn(self.model_name, 'last')) + + dist.destroy_process_group() + + def train_cycle(self, ddp_model, train_loader, optimizer, scheduler, scaler, rank, gpu, world_size, epoch, verbose=False): + # Turn on training mode + ddp_model.train() + + # clear gradients + optimizer.zero_grad() + + start_time = time.time() + + # For intermediate logs + local_tot = 0.0 + local_loss = None + local_acc = None + train_tot = 0.0 + train_loss = None + train_acc = None + + counter = 0 + + for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, 
atom_frames, bond_feats in train_loader: + + # transfer inputs to device + B, _, N, L = msa.shape + + idx_pdb = idx_pdb.to(gpu, non_blocking=True) # (B, L) + true_crds = true_crds.to(gpu, non_blocking=True) # (B, N?, L, Natms, 3) + atom_mask = atom_mask.to(gpu, non_blocking=True) # (B, L, Natms) + same_chain = same_chain.to(gpu, non_blocking=True) # (B, L, L) + + xyz_t = xyz_t.to(gpu, non_blocking=True) + t1d = t1d.to(gpu, non_blocking=True) + + seq = seq.to(gpu, non_blocking=True) + msa = msa.to(gpu, non_blocking=True) + msa_masked = msa_masked.to(gpu, non_blocking=True) + msa_full = msa_full.to(gpu, non_blocking=True) + mask_msa = mask_msa.to(gpu, non_blocking=True) + atom_frames = atom_frames.to(gpu, non_blocking=True) + bond_feats = bond_feats.to(gpu, non_blocking=True) + + # processing template features + # get torsion angles from templates + seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L) + xyz_t_frames = xyz_t_to_frame_xyz(xyz_t, seq_tmp, atom_frames) + t2d = xyz_to_t2d(xyz_t_frames) + + alpha, _, alpha_mask, _ = get_torsions( + xyz_t.reshape(-1,L,NTOTAL,3), seq_tmp, self.ti_dev, self.ti_flip, self.ang_ref) + alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0])) + alpha[torch.isnan(alpha)] = 0.0 + alpha = alpha.reshape(B,-1,L,NTOTALDOFS,2) + alpha_mask = alpha_mask.reshape(B,-1,L,NTOTALDOFS,1) + alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, L, 3*NTOTALDOFS) + + # processing template coordinates + xyz_t = get_init_xyz(seq[:,0],xyz_t,same_chain) + xyz_prev = get_init_xyz(seq[:,0],xyz_prev[:,None],same_chain).reshape(B, L, NTOTAL, 3) + + counter += 1 + + N_cycle = np.random.randint(1, self.maxcycle+1) # number of recycling + + msa_prev = None + pair_prev = None + alpha_prev = torch.zeros((B,L,NTOTALDOFS,2)).to(gpu, non_blocking=True) + state_prev = None + + with torch.no_grad(): + for i_cycle in range(N_cycle-1): + with ddp_model.no_sync(): + with torch.cuda.amp.autocast(enabled=USE_AMP): + msa_prev, pair_prev, xyz_prev, 
state_prev, alpha = ddp_model( + msa_masked[:,i_cycle], + msa_full[:,i_cycle], + seq[:,i_cycle], + msa[:,i_cycle,0], # unmasked seq + xyz_prev, + alpha_prev, + idx_pdb, + bond_feats, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev=msa_prev, + pair_prev=pair_prev, + state_prev=state_prev, + return_raw=True, + use_checkpoint=False + ) + + i_cycle = N_cycle-1 + + if counter%self.ACCUM_STEP != 0: + with ddp_model.no_sync(): + with torch.cuda.amp.autocast(enabled=USE_AMP): + logit_s, logit_aa_s, pred_crds, alphas, pred_allatom, pred_lddts, _, _, _ = ddp_model( + msa_masked[:,i_cycle], + msa_full[:,i_cycle], + seq[:,i_cycle], + msa[:,i_cycle,0], # unmasked seq + xyz_prev, + alpha_prev, + idx_pdb, + bond_feats, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev=msa_prev, + pair_prev=pair_prev, + state_prev=state_prev, + use_checkpoint=True + ) + + true_crds, atom_mask = resolve_equiv_natives(pred_crds[-1], true_crds, atom_mask) + + res_mask = ~((atom_mask[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(msa[:,i_cycle,0]))) + mask_2d = res_mask[:,None,:] * res_mask[:,:,None] + + true_crds_frame = xyz_to_frame_xyz(true_crds, msa[:, i_cycle, 0],atom_frames) + c6d, _ = xyz_to_c6d(true_crds_frame) + c6d = c6d_to_bins(c6d, same_chain, negative=negative) + + prob = self.active_fn(logit_s[0]) # distogram + acc_s = self.calc_acc(prob, c6d[...,0], idx_pdb, mask_2d) + + ctrid = len(train_loader)*rank+counter + loss, loss_s = self.calc_loss( + logit_s, c6d, + logit_aa_s, msa[:, i_cycle], mask_msa[:,i_cycle], + pred_crds, alphas, pred_allatom, true_crds, + atom_mask, res_mask, mask_2d, same_chain, + pred_lddts, idx_pdb, atom_frames=atom_frames, + unclamp=unclamp, negative=negative, + verbose=verbose, ctr=ctrid, **self.loss_param + ) + loss = loss / self.ACCUM_STEP + scaler.scale(loss).backward() + else: + with torch.cuda.amp.autocast(enabled=USE_AMP): + logit_s, logit_aa_s, pred_crds, alphas, pred_allatom, pred_lddts, _, _, _ = ddp_model( + 
msa_masked[:,i_cycle], + msa_full[:,i_cycle], + seq[:,i_cycle], + msa[:,i_cycle,0], # unmasked seq + xyz_prev, + alpha_prev, + idx_pdb, + bond_feats, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev=msa_prev, + pair_prev=pair_prev, + state_prev=state_prev, + use_checkpoint=True + ) + + true_crds, atom_mask = resolve_equiv_natives(pred_crds[-1], true_crds, atom_mask) + + res_mask = ~((atom_mask[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(msa[:,i_cycle,0]))) + mask_2d = res_mask[:,None,:] * res_mask[:,:,None] + + true_crds_frame = xyz_to_frame_xyz(true_crds, msa[:, i_cycle, 0],atom_frames) + c6d, _ = xyz_to_c6d(true_crds_frame) + c6d = c6d_to_bins(c6d, same_chain, negative=negative) + + prob = self.active_fn(logit_s[0]) # distogram + acc_s = self.calc_acc(prob, c6d[...,0], idx_pdb, mask_2d) + + ctrid = len(train_loader)*rank+counter + loss, loss_s = self.calc_loss( + logit_s, c6d, + logit_aa_s, msa[:, i_cycle], mask_msa[:,i_cycle], + pred_crds, alphas, pred_allatom, true_crds, + atom_mask, res_mask, mask_2d, same_chain, + pred_lddts, idx_pdb, atom_frames=atom_frames, unclamp=unclamp, negative=negative, + verbose=verbose, ctr=ctrid, **self.loss_param + ) + loss = loss / self.ACCUM_STEP + scaler.scale(loss).backward() + # gradient clipping + scaler.unscale_(optimizer) + + torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 0.2) + + scaler.step(optimizer) + scale = scaler.get_scale() + scaler.update() + skip_lr_sched = (scale != scaler.get_scale()) + optimizer.zero_grad() + if not skip_lr_sched: + scheduler.step() + ddp_model.module.update() # apply EMA + + local_tot += loss.detach()*self.ACCUM_STEP + if local_loss == None: + local_loss = torch.zeros_like(loss_s.detach()) + local_acc = torch.zeros_like(acc_s.detach()) + local_loss += loss_s.detach() + local_acc += acc_s.detach() + + train_tot += loss.detach()*self.ACCUM_STEP + if train_loss == None: + train_loss = torch.zeros_like(loss_s.detach()) + train_acc = torch.zeros_like(acc_s.detach()) + 
train_loss += loss_s.detach() + train_acc += acc_s.detach() + + + if counter % N_PRINT_TRAIN == 0: + if rank == 0: + max_mem = torch.cuda.max_memory_allocated()/1e9 + train_time = time.time() - start_time + local_tot /= float(N_PRINT_TRAIN) + local_loss /= float(N_PRINT_TRAIN) + local_acc /= float(N_PRINT_TRAIN) + + local_tot = local_tot.cpu().detach() + local_loss = local_loss.cpu().detach().numpy() + local_acc = local_acc.cpu().detach().numpy() + + sys.stdout.write("Local: [%04d/%04d] Batch: [%05d/%05d] Time: %16.1f | total_loss: %8.4f | %s | %.4f %.4f %.4f | Max mem %.4f\n"%(\ + epoch, self.n_epoch, counter*self.batch_size*world_size, self.n_train, train_time, local_tot, \ + " ".join(["%8.4f"%l for l in local_loss]),\ + local_acc[0], local_acc[1], local_acc[2], max_mem)) + sys.stdout.flush() + local_tot = 0.0 + local_loss = None + local_acc = None + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + + # write total train loss + train_tot /= float(counter * world_size) + train_loss /= float(counter * world_size) + train_acc /= float(counter * world_size) + + dist.all_reduce(train_tot, op=dist.ReduceOp.SUM) + dist.all_reduce(train_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(train_acc, op=dist.ReduceOp.SUM) + train_tot = train_tot.cpu().detach() + train_loss = train_loss.cpu().detach().numpy() + train_acc = train_acc.cpu().detach().numpy() + if rank == 0: + + train_time = time.time() - start_time + sys.stdout.write("Train: [%04d/%04d] Batch: [%05d/%05d] Time: %16.1f | total_loss: %8.4f | %s | %.4f %.4f %.4f\n"%(\ + epoch, self.n_epoch, self.n_train, self.n_train, train_time, train_tot, \ + " ".join(["%8.4f"%l for l in train_loss]),\ + train_acc[0], train_acc[1], train_acc[2])) + sys.stdout.flush() + + return train_tot, train_loss, train_acc + + def valid_pdb_cycle(self, ddp_model, valid_loader, rank, gpu, world_size, epoch, header='Monomer', verbose=False): + valid_tot = 0.0 + valid_loss = None + valid_acc = None + counter = 0 + + start_time = 
time.time() + + with torch.no_grad(): # no need to calculate gradient + ddp_model.eval() # change it to eval mode + for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats in valid_loader: + # transfer inputs to device + B, _, N, L = msa.shape + + idx_pdb = idx_pdb.to(gpu, non_blocking=True) # (B, L) + true_crds = true_crds.to(gpu, non_blocking=True) # (B, L, 27, 3) + atom_mask = atom_mask.to(gpu, non_blocking=True) # (B, L, 27) + same_chain = same_chain.to(gpu, non_blocking=True) + + xyz_t = xyz_t.to(gpu, non_blocking=True) + t1d = t1d.to(gpu, non_blocking=True) + + seq = seq.to(gpu, non_blocking=True) + msa = msa.to(gpu, non_blocking=True) + msa_masked = msa_masked.to(gpu, non_blocking=True) + msa_full = msa_full.to(gpu, non_blocking=True) + mask_msa = mask_msa.to(gpu, non_blocking=True) + atom_frames = atom_frames.to(gpu, non_blocking=True) + bond_feats = bond_feats.to(gpu, non_blocking=True) + + res_mask = ~((atom_mask[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(msa[:,i_cycle,0]))) # ignore residues having missing BB atoms for loss calculation + mask_2d = res_mask[:,None,:] * res_mask[:,:,None] # ignore pairs having missing residues + + # processing template features + # get torsion angles from templates + seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L) + xyz_t_frames = xyz_t_to_frame_xyz(xyz_t, seq_tmp, atom_frames) + t2d = xyz_to_t2d(xyz_t_frames) + + alpha, _, alpha_mask, _ = get_torsions(xyz_t.reshape(-1,L,NTOTAL,3), seq_tmp, self.ti_dev, self.ti_flip, self.ang_ref) + alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0])) + alpha[torch.isnan(alpha)] = 0.0 + alpha = alpha.reshape(B,-1,L,NTOTALDOFS,2) + alpha_mask = alpha_mask.reshape(B,-1,L,NTOTALDOFS,1) + alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, L, 3*NTOTALDOFS) + # processing template coordinates + xyz_t = get_init_xyz(seq[:,0],xyz_t,same_chain) + xyz_prev = 
get_init_xyz(seq[:,0],xyz_prev[:,None],same_chain).reshape(B, L, NTOTAL, 3) + + # set number of recycles + N_cycle = self.maxcycle + msa_prev = None + pair_prev = None + alpha_prev = torch.zeros((B,L,NTOTALDOFS,2)).to(gpu, non_blocking=True) #fd we could get this from the template... + state_prev = None + + for i_cycle in range(N_cycle-1): + msa_prev, pair_prev, xyz_prev, state_prev, alpha = ddp_model( + msa_masked[:,i_cycle], + msa_full[:,i_cycle], + seq[:,i_cycle], + msa[:,i_cycle,0], # unmasked seq + xyz_prev, + alpha_prev, + idx_pdb, + bond_feats, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev=msa_prev, + pair_prev=pair_prev, + state_prev=state_prev, + return_raw=True, + use_checkpoint=False + ) + + #true_crds_i, atom_mask_i = resolve_equiv_natives(xyz_prev, true_crds, atom_mask) + + #res_mask = ~(atom_mask_i[:,:,:3].sum(dim=-1) < 3.0) + #mask_2d = res_mask[:,None,:] * res_mask[:,:,None] + + i_cycle = N_cycle-1 + logit_s, logit_aa_s, pred_crds, alphas, pred_allatom, pred_lddts, _, _, _ = ddp_model( + msa_masked[:,i_cycle], + msa_full[:,i_cycle], + seq[:,i_cycle], + msa[:,i_cycle,0], # unmasked seq + xyz_prev, + alpha_prev, + idx_pdb, + bond_feats, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev=msa_prev, + pair_prev=pair_prev, + state_prev=state_prev, + use_checkpoint=False + ) + + true_crds, atom_mask = resolve_equiv_natives(pred_crds[-1], true_crds, atom_mask) + + res_mask = ~((atom_mask[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(msa[:,i_cycle,0]))) + mask_2d = res_mask[:,None,:] * res_mask[:,:,None] + + true_crds_frame = xyz_to_frame_xyz(true_crds, msa[:, i_cycle, 0],atom_frames) + c6d, _ = xyz_to_c6d(true_crds_frame) + c6d = c6d_to_bins(c6d, same_chain, negative=negative) + + prob = self.active_fn(logit_s[0]) # distogram + acc_s = self.calc_acc(prob, c6d[...,0], idx_pdb, mask_2d) + + ctrid = len(valid_loader)*rank+counter + loss, loss_s = self.calc_loss( + logit_s, c6d, + logit_aa_s, msa[:, i_cycle], mask_msa[:,i_cycle], + 
pred_crds, alphas, pred_allatom, true_crds, + atom_mask, res_mask, mask_2d, same_chain, + pred_lddts, idx_pdb, atom_frames, unclamp=unclamp, negative=negative, + verbose=verbose, ctr=ctrid, **self.loss_param + ) + + valid_tot += loss.detach() + if valid_loss == None: + valid_loss = torch.zeros_like(loss_s.detach()) + valid_acc = torch.zeros_like(acc_s.detach()) + valid_loss += loss_s.detach() + valid_acc += acc_s.detach() + counter += 1 + + valid_tot /= float(counter*world_size) + valid_loss /= float(counter*world_size) + valid_acc /= float(counter*world_size) + + dist.all_reduce(valid_tot, op=dist.ReduceOp.SUM) + dist.all_reduce(valid_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(valid_acc, op=dist.ReduceOp.SUM) + + valid_tot = valid_tot.cpu().detach().numpy() + valid_loss = valid_loss.cpu().detach().numpy() + valid_acc = valid_acc.cpu().detach().numpy() + + if rank == 0: + train_time = time.time() - start_time + sys.stdout.write("%s: [%04d/%04d] Batch: [%05d/%05d] Time: %16.1f | total_loss: %8.4f | %s | %.4f %.4f %.4f\n"%(\ + header, epoch, self.n_epoch, world_size*len(valid_loader), world_size*len(valid_loader), train_time, valid_tot, \ + " ".join(["%8.4f"%l for l in valid_loss]),\ + valid_acc[0], valid_acc[1], valid_acc[2])) + sys.stdout.flush() + return valid_tot, valid_loss, valid_acc + + def valid_ppi_cycle(self, ddp_model, valid_pos_loader, valid_neg_loader, rank, gpu, world_size, epoch, header='Protein', report_interface=True, verbose=False): + valid_tot = 0.0 + valid_loss = None + valid_acc = None + valid_inter = None + counter = 0 + + TP = 0 + TN = 0 + FP = 0 + FN = 0 + + start_time = time.time() + + with torch.no_grad(): # no need to calculate gradient + ddp_model.eval() # change it to eval mode + for seq, msa, msa_masked, msa_full, mask_msa, true_crds, mask_crds, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats in valid_pos_loader: + # transfer inputs to device + B, _, N, L = msa.shape + + idx_pdb = idx_pdb.to(gpu, 
non_blocking=True) # (B, L) + true_crds = true_crds.to(gpu, non_blocking=True) # (B, L, 27, 3) + atom_mask = mask_crds.to(gpu, non_blocking=True) # (B, L, 27) + same_chain = same_chain.to(gpu, non_blocking=True) + + xyz_t = xyz_t.to(gpu, non_blocking=True) + t1d = t1d.to(gpu, non_blocking=True) + + xyz_prev = xyz_prev.to(gpu, non_blocking=True) + + seq = seq.to(gpu, non_blocking=True) + msa = msa.to(gpu, non_blocking=True) + msa_masked = msa_masked.to(gpu, non_blocking=True) + msa_full = msa_full.to(gpu, non_blocking=True) + mask_msa = mask_msa.to(gpu, non_blocking=True) + atom_frames = atom_frames.to(gpu, non_blocking=True) + bond_feats = bond_feats.to(gpu, non_blocking=True) + + # processing labels for distogram orientograms + res_mask = ~((atom_mask[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(msa[:,i_cycle,0]))) # ignore residues having missing BB atoms for loss calculation + mask_2d = res_mask[:,None,:] * res_mask[:,:,None] # ignore pairs having missing residues + + # processing template features + # get torsion angles from templates + seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L) + xyz_t_frames = xyz_t_to_frame_xyz(xyz_t, seq_tmp, atom_frames) + t2d = xyz_to_t2d(xyz_t_frames) + + alpha, _, alpha_mask, _ = get_torsions(xyz_t.reshape(-1,L,NTOTAL,3), seq_tmp, self.ti_dev, self.ti_flip, self.ang_ref) + alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0])) + alpha[torch.isnan(alpha)] = 0.0 + alpha = alpha.reshape(B,-1,L,NTOTALDOFS,2) + alpha_mask = alpha_mask.reshape(B,-1,L,NTOTALDOFS,1) + alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, L, 3*NTOTALDOFS) + # processing template coordinates + xyz_t = get_init_xyz(seq[:,0],xyz_t,same_chain) + xyz_prev = get_init_xyz(seq[:,0],xyz_prev[:,None],same_chain).reshape(B, L, NTOTAL, 3) + + N_cycle = self.maxcycle # number of recycling + + msa_prev = None + pair_prev = None + alpha_prev = torch.zeros((B,L,NTOTALDOFS,2)).to(gpu, non_blocking=True) #fd we could get this from the template... 
+ state_prev = None + + for i_cycle in range(N_cycle-1): + msa_prev, pair_prev, xyz_prev, state_prev, alpha = ddp_model( + msa_masked[:,i_cycle], + msa_full[:,i_cycle], + seq[:,i_cycle], + msa[:,i_cycle,0], # unmasked seq + xyz_prev, + alpha_prev, + idx_pdb, + bond_feats, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev=msa_prev, + pair_prev=pair_prev, + state_prev=state_prev, + return_raw=True, + use_checkpoint=False + ) + + #true_crds_i, atom_mask_i = resolve_equiv_natives(xyz_prev, true_crds, atom_mask) + + #res_mask = ~(atom_mask_i[:,:,:3].sum(dim=-1) < 3.0) + #mask_2d = res_mask[:,None,:] * res_mask[:,:,None] + + + i_cycle = N_cycle-1 + logit_s, logit_aa_s, pred_crds, alphas, pred_allatom, pred_lddts, _, _, _ = ddp_model( + msa_masked[:,i_cycle], + msa_full[:,i_cycle], + seq[:,i_cycle], + msa[:,i_cycle,0], # unmasked seq + xyz_prev, + alpha_prev, + idx_pdb, + bond_feats, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev=msa_prev, + pair_prev=pair_prev, + state_prev=state_prev, + use_checkpoint=False + ) + + true_crds, atom_mask = resolve_equiv_natives(pred_crds[-1], true_crds, atom_mask) + + res_mask = ~((atom_mask[:,:,:3].sum(dim=-1) < 3.0) and ~(is_atom(msa[:,i_cycle,0]))) + mask_2d = res_mask[:,None,:] * res_mask[:,:,None] + + true_crds_frame = xyz_to_frame_xyz(true_crds, msa[:, i_cycle, 0],atom_frames) + c6d, _ = xyz_to_c6d(true_crds_frame) + c6d = c6d_to_bins(c6d, same_chain, negative=negative) + + prob = self.active_fn(logit_s[0]) # distogram + acc_s, cnt_pred, cnt_ref = self.calc_acc(prob, c6d[...,0], idx_pdb, mask_2d, return_cnt=True) + + # inter-chain contact prob + cnt_pred = cnt_pred * (1-same_chain).float() + cnt_ref = cnt_ref * (1-same_chain).float() + max_prob = cnt_pred.max() + if max_prob > 0.5: + if (cnt_ref > 0).any(): + TP += 1.0 + else: + FP += 1.0 + else: + if (cnt_ref > 0).any(): + FN += 1.0 + else: + TN += 1.0 + inter_s = torch.tensor([TP, FP, TN, FN], device=prob.device).float() + + ctrid = 
len(valid_pos_loader)*rank+counter + loss, loss_s = self.calc_loss( + logit_s, c6d, + logit_aa_s, msa[:, i_cycle], mask_msa[:,i_cycle], + pred_crds, alphas, pred_allatom, true_crds, + atom_mask, res_mask, mask_2d, same_chain, + pred_lddts, idx_pdb, atom_frames, unclamp=unclamp, negative=negative, interface=report_interface, + verbose=verbose, ctr=ctrid, **self.loss_param + ) + + valid_tot += loss.detach() + if valid_loss == None: + valid_loss = torch.zeros_like(loss_s.detach()) + valid_acc = torch.zeros_like(acc_s.detach()) + valid_inter = torch.zeros_like(inter_s.detach()) + valid_loss += loss_s.detach() + valid_acc += acc_s.detach() + valid_inter += inter_s.detach() + counter += 1 + + + valid_tot /= float(counter*world_size) + valid_loss /= float(counter*world_size) + valid_acc /= float(counter*world_size) + + dist.all_reduce(valid_tot, op=dist.ReduceOp.SUM) + dist.all_reduce(valid_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(valid_acc, op=dist.ReduceOp.SUM) + + valid_tot = valid_tot.cpu().detach().numpy() + valid_loss = valid_loss.cpu().detach().numpy() + valid_acc = valid_acc.cpu().detach().numpy() + + if rank == 0: + + train_time = time.time() - start_time + sys.stdout.write("%s-interface: [%04d/%04d] Batch: [%05d/%05d] Time: %16.1f | total_loss: %8.4f | %s | %.4f %.4f %.4f\n"%(\ + header, epoch, self.n_epoch, counter*world_size, counter*world_size, train_time, valid_tot, \ + " ".join(["%8.4f"%l for l in valid_loss]),\ + valid_acc[0], valid_acc[1], valid_acc[2])) + sys.stdout.flush() + + valid_tot = 0.0 + valid_loss = None + valid_acc = None + counter = 0 + + start_time = time.time() + + with torch.no_grad(): # no need to calculate gradient + ddp_model.eval() # change it to eval mode + for seq, msa, msa_masked, msa_full, mask_msa, true_crds, mask_crds, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames in valid_neg_loader: + # transfer inputs to device + B, _, N, L = msa.shape + + idx_pdb = idx_pdb.to(gpu, non_blocking=True) # (B, L) + 
true_crds = true_crds.to(gpu, non_blocking=True) # (B, L, 27, 3) + atom_mask = mask_crds.to(gpu, non_blocking=True) # (B, L, 27) + same_chain = same_chain.to(gpu, non_blocking=True) + + xyz_t = xyz_t.to(gpu, non_blocking=True) + t1d = t1d.to(gpu, non_blocking=True) + + xyz_prev = xyz_prev.to(gpu, non_blocking=True) + + seq = seq.to(gpu, non_blocking=True) + msa = msa.to(gpu, non_blocking=True) + msa_masked = msa_masked.to(gpu, non_blocking=True) + msa_full = msa_full.to(gpu, non_blocking=True) + mask_msa = mask_msa.to(gpu, non_blocking=True) + atom_frames = atom_frames.to(gpu, non_blocking=True) + bond_feats = bond_feats.to(gpu, non_blocking=True) + + # processing labels for distogram orientograms + res_mask = ~((atom_mask[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(msa[:,i_cycle,0]))) # ignore residues having missing BB atoms for loss calculation + mask_2d = res_mask[:,None,:] * res_mask[:,:,None] # ignore pairs having missing residues + + # processing template features + # get torsion angles from templates + seq_tmp = t1d[...,:-1].argmax(dim=-1).reshape(-1,L) + xyz_t_frames = xyz_t_to_frame_xyz(xyz_t, seq_tmp, atom_frames) + t2d = xyz_to_t2d(xyz_t_frames) + + alpha, _, alpha_mask, _ = get_torsions(xyz_t.reshape(-1,L,NTOTAL,3), seq_tmp, self.ti_dev, self.ti_flip, self.ang_ref) + alpha_mask = torch.logical_and(alpha_mask, ~torch.isnan(alpha[...,0])) + alpha[torch.isnan(alpha)] = 0.0 + alpha = alpha.reshape(B,-1,L,NTOTALDOFS,2) + alpha_mask = alpha_mask.reshape(B,-1,L,NTOTALDOFS,1) + alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, L, 3*NTOTALDOFS) + # processing template coordinates + xyz_t = get_init_xyz(seq[:,0],xyz_t,same_chain) + xyz_prev = get_init_xyz(seq[:,0],xyz_prev[:,None],same_chain).reshape(B, L, NTOTAL, 3) + + N_cycle = self.maxcycle # number of recycling + + msa_prev = None + pair_prev = None + alpha_prev = torch.zeros((B,L,NTOTALDOFS,2)).to(gpu, non_blocking=True) #fd we could get this from the template... 
+ state_prev = None + for i_cycle in range(N_cycle-1): + msa_prev, pair_prev, xyz_prev, state_prev, alpha = ddp_model( + msa_masked[:,i_cycle], + msa_full[:,i_cycle], + seq[:,i_cycle], + msa[:,i_cycle,0], # unmasked seq + xyz_prev, + alpha_prev, + idx_pdb, + bond_feats, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev=msa_prev, + pair_prev=pair_prev, + state_prev=state_prev, + return_raw=True, + use_checkpoint=False + ) + + i_cycle = N_cycle-1 + logit_s, logit_aa_s, pred_crds, alphas, pred_allatom, pred_lddts, _, _, _ = ddp_model( + msa_masked[:,i_cycle], + msa_full[:,i_cycle], + seq[:,i_cycle], + msa[:,i_cycle,0], # unmasked seq + xyz_prev, + alpha_prev, + idx_pdb, + bond_feats, + t1d=t1d, + t2d=t2d, + xyz_t=xyz_t, + alpha_t=alpha_t, + msa_prev=msa_prev, + pair_prev=pair_prev, + state_prev=state_prev, + use_checkpoint=False + ) + + true_crds, atom_mask = resolve_equiv_natives(pred_crds[-1], true_crds, atom_mask) + + res_mask = ~((atom_mask[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(msa[:,i_cycle,0]))) + mask_2d = res_mask[:,None,:] * res_mask[:,:,None] + + true_crds_frame = xyz_to_frame_xyz(true_crds, msa[:, i_cycle, 0],atom_frames) + c6d, _ = xyz_to_c6d(true_crds_frame) + c6d = c6d_to_bins(c6d, same_chain, negative=negative) + + prob = self.active_fn(logit_s[0]) # distogram + acc_s, cnt_pred, cnt_ref = self.calc_acc(prob, c6d[...,0], idx_pdb, mask_2d, return_cnt=True) + + # inter-chain contact prob + cnt_pred = cnt_pred * (1-same_chain).float() + cnt_ref = cnt_ref * (1-same_chain).float() + max_prob = cnt_pred.max() + if max_prob > 0.5: + if (cnt_ref > 0).any(): + TP += 1.0 + else: + FP += 1.0 + else: + if (cnt_ref > 0).any(): + FN += 1.0 + else: + TN += 1.0 + inter_s = torch.tensor([TP, FP, TN, FN], device=prob.device).float() + + loss, loss_s = self.calc_loss( + logit_s, c6d, + logit_aa_s, msa[:, i_cycle], mask_msa[:,i_cycle], + pred_crds, alphas, pred_allatom, true_crds, + atom_mask, res_mask, mask_2d, same_chain, + pred_lddts, idx_pdb, 
atom_frames, unclamp=unclamp, negative=negative, + verbose=verbose, ctr=ctrid, **self.loss_param + ) + + valid_tot += loss.detach() + if valid_loss == None: + valid_loss = torch.zeros_like(loss_s.detach()) + valid_acc = torch.zeros_like(acc_s.detach()) + valid_loss += loss_s.detach() + valid_acc += acc_s.detach() + valid_inter += inter_s.detach() + counter += 1 + + + + valid_tot /= float(counter*world_size) + valid_loss /= float(counter*world_size) + valid_acc /= float(counter*world_size) + + dist.all_reduce(valid_tot, op=dist.ReduceOp.SUM) + dist.all_reduce(valid_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(valid_acc, op=dist.ReduceOp.SUM) + dist.all_reduce(valid_inter, op=dist.ReduceOp.SUM) + + valid_tot = valid_tot.cpu().detach().numpy() + valid_loss = valid_loss.cpu().detach().numpy() + valid_acc = valid_acc.cpu().detach().numpy() + valid_inter = valid_inter.cpu().detach().numpy() + + if rank == 0: + TP, FP, TN, FN = valid_inter + prec = TP/(TP+FP+1e-4) + recall = TP/(TP+FN+1e-4) + F1 = 2*TP/(2*TP+FP+FN+1e-4) + + train_time = time.time() - start_time + sys.stdout.write("%s-PPI: [%04d/%04d] Batch: [%05d/%05d] Time: %16.1f | total_loss: %8.4f | %s | %.4f %.4f %.4f | %.4f %.4f %.4f\n"%(\ + header, epoch, self.n_epoch, counter*world_size, counter*world_size, train_time, valid_tot, \ + " ".join(["%8.4f"%l for l in valid_loss]),\ + valid_acc[0], valid_acc[1], valid_acc[2],\ + prec, recall, F1)) + sys.stdout.flush() + return valid_tot, valid_loss, valid_acc + +if __name__ == "__main__": + from arguments import get_args + args, model_param, loader_param, loss_param = get_args() + + print (args) + + mp.freeze_support() + train = Trainer(model_name=args.model_name, + n_epoch=args.num_epochs, step_lr=args.step_lr, lr=args.lr, l2_coeff=1.0e-2, + port=args.port, model_param=model_param, loader_param=loader_param, + loss_param=loss_param, + batch_size=args.batch_size, + accum_step=args.accum, + maxcycle=args.maxcycle, + eval=args.eval) + 
def _th_norm(x, eps:float=1e-8):
    """Return the Euclidean norm of `x` over its last axis (kept as size 1), with `eps` added under the square root."""
    return x.square().sum(-1,keepdim=True).add(eps).sqrt()

def _th_N(x, alpha:float=0):
    """Return `x` divided by (its norm over the last axis + `alpha`)."""
    return x/_th_norm(x).add(alpha)

def _th_cross(a, b):
    """Cross product over the last axis, after broadcasting `a` and `b` to a common shape."""
    a,b = torch.broadcast_tensors(a,b)
    return torch.cross(a,b, dim=-1)

def th_ang_v(ab,bc,eps:float=1e-8):
    """Return (cos, sin) of the angle between vectors `ab` and `bc`.

    Both inputs are normalised over their last axis.  The result has a
    trailing axis of size 2.  The sine is the non-negative root
    sqrt(1 - cos^2 + eps), so it is never exactly zero.
    """
    ab, bc = _th_N(ab),_th_N(bc)
    cos_angle = torch.clamp( (ab*bc).sum(-1), -1, 1)
    sin_angle = torch.sqrt(1-cos_angle.square() + eps)
    return torch.stack((cos_angle,sin_angle),-1)

def th_dih_v(ab,bc,cd):
    """Return (cos, sin) of the dihedral defined by three consecutive bond vectors.

    `ab`, `bc`, `cd` are normalised over their last axis; the result has a
    trailing axis of size 2.
    """
    ab, bc, cd = _th_N(ab),_th_N(bc),_th_N(cd)
    # unit normals of the two planes sharing the central bond
    n1 = _th_N( _th_cross(ab,bc) )
    n2 = _th_N( _th_cross(bc,cd) )
    sin_angle = (_th_cross(n1,bc)*n2).sum(-1)
    cos_angle = (n1*n2).sum(-1)
    return torch.stack((cos_angle,sin_angle),-1)

def th_dih(a,b,c,d):
    """Return (cos, sin) of the dihedral angle defined by the four points a-b-c-d."""
    return th_dih_v(a-b,b-c,c-d)

# build a frame from 3 points
#fd  - more complicated version splits angle deviations between CA-N and CA-C (giving more accurate CB position)
#fd  - makes no assumptions about input dims (other than last 1 is xyz)
def rigid_from_3_points(N, Ca, C, is_na=None, eps=1e-8):
    """Build a rotation matrix and origin from three points per position.

    Args:
        N, Ca, C: coordinate tensors of identical shape (..., L, 3).
        is_na: optional boolean index into the leading dims; where set, the
            target cosine -0.4929 is used instead of -0.3616.
        eps: small constant guarding the divisions and square roots.

    Returns:
        (R, Ca): R has shape (..., L, 3, 3); the origin is `Ca` itself.
    """
    dims = N.shape[:-1]

    # Gram-Schmidt: e1 along Ca->C, e2 from the component of Ca->N
    # orthogonal to e1, e3 completes the right-handed frame
    v1 = C-Ca
    v2 = N-Ca
    e1 = v1/(torch.norm(v1, dim=-1, keepdim=True)+eps)
    u2 = v2-(torch.einsum('...li, ...li -> ...l', e1, v2)[...,None]*e1)
    e2 = u2/(torch.norm(u2, dim=-1, keepdim=True)+eps)
    e3 = torch.cross(e1, e2, dim=-1)
    R = torch.cat([e1[...,None], e2[...,None], e3[...,None]], dim=-1) #[B,L,3,3] - rotation matrix

    # cosine of the observed angle at Ca vs. the target cosine
    v2 = v2/(torch.norm(v2, dim=-1, keepdim=True)+eps)
    cosref = torch.sum(e1*v2, dim=-1)
    costgt = torch.full(dims, -0.3616, device=N.device)
    if is_na is not None:
        costgt[is_na] = -0.4929

    # cos(observed - target), then half-angle formulas: the frame is rotated
    # in the e1/e2 plane by half of the deviation
    cos2del = torch.clamp( cosref*costgt + torch.sqrt((1-cosref*cosref)*(1-costgt*costgt)+eps), min=-1.0, max=1.0 )
    cosdel = torch.sqrt(0.5*(1+cos2del)+eps)
    sindel = torch.sign(costgt-cosref) * torch.sqrt(1-0.5*(1+cos2del)+eps)
    Rp = torch.eye(3, device=N.device).repeat(*dims,1,1)
    Rp[...,0,0] = cosdel
    Rp[...,0,1] = -sindel
    Rp[...,1,0] = sindel
    Rp[...,1,1] = cosdel

    R = torch.einsum('...ij,...jk->...ik', R,Rp)

    return R, Ca

# note: needs consistency with chemical.py
def is_nucleic(seq):
    """Boolean mask of tokens in the range NPROTAAS..NNAPROTAAS (both inclusive)."""
    return (seq>=NPROTAAS) * (seq <= NNAPROTAAS)

def is_atom(seq):
    """Boolean mask of tokens strictly greater than NNAPROTAAS."""
    return seq > NNAPROTAAS

def idealize_reference_frame(seq, xyz_in):
    """Return a copy of `xyz_in` with atoms 0 and 2 rebuilt at ideal positions.

    A frame is computed per position from atoms 0, 1, 2; atoms 0 and 2 are
    then replaced by fixed ideal offsets rotated into that frame.  Positions
    flagged by `is_nucleic` use the OP1/OP2 offsets, every other position
    uses the N/C offsets.  The input tensor is not modified.
    """
    xyz = xyz_in.clone()

    namask = is_nucleic(seq)
    Rs, Ts = rigid_from_3_points(xyz[...,0,:],xyz[...,1,:],xyz[...,2,:], namask)

    protmask = ~namask

    Nideal = torch.tensor([-0.5272, 1.3593, 0.000], device=xyz_in.device)
    Cideal = torch.tensor([1.5233, 0.000, 0.000], device=xyz_in.device)

    OP1ideal = torch.tensor([-0.7319, 1.2920, 0.000], device=xyz_in.device)
    OP2ideal = torch.tensor([1.4855, 0.000, 0.000], device=xyz_in.device)

    pmask_bs,pmask_rs = protmask.nonzero(as_tuple=True)
    nmask_bs,nmask_rs = namask.nonzero(as_tuple=True)
    xyz[pmask_bs,pmask_rs,0,:] = torch.einsum('...ij,j->...i', Rs[pmask_bs,pmask_rs], Nideal) + Ts[pmask_bs,pmask_rs]
    xyz[pmask_bs,pmask_rs,2,:] = torch.einsum('...ij,j->...i', Rs[pmask_bs,pmask_rs], Cideal) + Ts[pmask_bs,pmask_rs]
    xyz[nmask_bs,nmask_rs,0,:] = torch.einsum('...ij,j->...i', Rs[nmask_bs,nmask_rs], OP1ideal) + Ts[nmask_bs,nmask_rs]
    xyz[nmask_bs,nmask_rs,2,:] = torch.einsum('...ij,j->...i', Rs[nmask_bs,nmask_rs], OP2ideal) + Ts[nmask_bs,nmask_rs]

    return xyz
# works for both dna and protein
# alphas in order:
#    omega/phi/psi: 0-2
#    chi_1-4(prot): 3-6
#    cb/cg bend: 7-9
#    eps(p)/zeta(p): 10-11
#    alpha/beta/gamma/delta: 12-15
#    nu2/nu1/nu0: 16-18
#    chi_1(na): 19
def get_tor_mask(seq, torsion_indices, mask_in=None):
    """Return a boolean (B, L, n_dofs) mask of the torsions defined for each token.

    A torsion counts as defined when the last of its four atom indices in
    `torsion_indices[seq]` is > 0.  When `mask_in` (indexed as
    [batch, residue, atom]) is given, a torsion is additionally required to
    have all four of its atoms flagged there; a negative atom index refers
    to the same atom slot in the previous residue.
    """
    B,L = seq.shape[:2]

    tors_mask = torsion_indices[seq,:,-1] > 0

    if mask_in is not None:
        ts = torsion_indices[seq]
        bs = torch.arange(B, device=seq.device)[:,None,None,None]
        rs = torch.arange(L, device=seq.device)[None,:,None,None] - (ts<0)*1 # ts<-1 ==> prev res
        ts = torch.abs(ts)
        tors_mask *= mask_in[bs,rs,ts].all(dim=-1)

    return tors_mask


def get_torsions(xyz_in, seq, torsion_indices, torsion_can_flip, ref_angles, mask_in=None):
    """Compute (cos, sin) for every torsion / bend degree of freedom.

    Coordinates are first passed through `idealize_reference_frame`.
    Slots 0-6 and 10+ are dihedrals over the atoms listed in
    `torsion_indices`; slots 7-9 are the CB bend, CB twist and CG bend,
    each expressed relative to the reference angle in `ref_angles`.

    Returns:
        torsions:     (B, L, NTOTALDOFS, 2) with NaNs replaced by (1, 0)
        torsions_alt: copy with the entries selected by `torsion_can_flip`
                      negated
        tors_mask:    output of `get_tor_mask`
        tors_planar:  boolean (B, L, NTOTALDOFS), set only at slot 5 for TYR
    """
    B,L = xyz_in.shape[:2]

    tors_mask = get_tor_mask(seq, torsion_indices, mask_in)
    # idealize given xyz coordinates before computing torsion angles
    xyz = idealize_reference_frame(seq, xyz_in)

    # gather the four atoms of every torsion; a negative atom index refers to
    # the previous residue
    ts = torsion_indices[seq]
    bs = torch.arange(B, device=xyz_in.device)[:,None,None,None]
    xs = torch.arange(L, device=xyz_in.device)[None,:,None,None] - (ts<0)*1 # ts<-1 ==> prev res
    ys = torch.abs(ts)
    xyzs_bytor = xyz[bs,xs,ys,:]

    torsions = torch.zeros( (B,L,NTOTALDOFS,2), device=xyz_in.device )
    torsions[...,:7,:] = th_dih(
        xyzs_bytor[...,:7,0,:],xyzs_bytor[...,:7,1,:],xyzs_bytor[...,:7,2,:],xyzs_bytor[...,:7,3,:]
    )
    torsions[:,:,2,:] = -1 * torsions[:,:,2,:] # shift psi by pi
    torsions[...,10:,:] = th_dih(
        xyzs_bytor[...,10:,0,:],xyzs_bytor[...,10:,1,:],xyzs_bytor[...,10:,2,:],xyzs_bytor[...,10:,3,:]
    )

    # angles (hardcoded)
    # each is stored as the angle difference to its reference:
    # (cos(t - t0), sin(t0 - t)) from the (cos, sin) pairs t and t0
    # CB bend
    NC = 0.5*( xyz[:,:,0,:3] + xyz[:,:,2,:3] )
    CA = xyz[:,:,1,:3]
    CB = xyz[:,:,4,:3]
    t = th_ang_v(CB-CA,NC-CA)
    t0 = ref_angles[seq][...,0,:]
    torsions[:,:,7,:] = torch.stack(
        (torch.sum(t*t0,dim=-1),t[...,0]*t0[...,1]-t[...,1]*t0[...,0]),
        dim=-1 )

    # CB twist
    NCCA = NC-CA
    NCp = xyz[:,:,2,:3] - xyz[:,:,0,:3]
    NCpp = NCp - torch.sum(NCp*NCCA, dim=-1, keepdim=True)/ torch.sum(NCCA*NCCA, dim=-1, keepdim=True) * NCCA
    t = th_ang_v(CB-CA,NCpp)
    t0 = ref_angles[seq][...,1,:]
    torsions[:,:,8,:] = torch.stack(
        (torch.sum(t*t0,dim=-1),t[...,0]*t0[...,1]-t[...,1]*t0[...,0]),
        dim=-1 )

    # CG bend
    CG = xyz[:,:,5,:3]
    t = th_ang_v(CG-CB,CA-CB)
    t0 = ref_angles[seq][...,2,:]
    torsions[:,:,9,:] = torch.stack(
        (torch.sum(t*t0,dim=-1),t[...,0]*t0[...,1]-t[...,1]*t0[...,0]),
        dim=-1 )

    # replace undefined values by the zero angle (cos=1, sin=0)
    mask0 = (torch.isnan(torsions[...,0])).nonzero()
    mask1 = (torch.isnan(torsions[...,1])).nonzero()
    torsions[mask0[:,0],mask0[:,1],mask0[:,2],0] = 1.0
    torsions[mask1[:,0],mask1[:,1],mask1[:,2],1] = 0.0

    # alt chis
    torsions_alt = torsions.clone()
    torsions_alt[torsion_can_flip[seq,:]] *= -1

    # torsions to restrain to 0 or 180 degree
    # (this should be specified in chemical?)
    tors_planar = torch.zeros((B, L, NTOTALDOFS), dtype=torch.bool, device=xyz_in.device)
    tors_planar[:,:,5] = seq == aa2num['TYR'] # TYR chi 3 should be planar

    return torsions, torsions_alt, tors_mask, tors_planar

def xyz_to_frame_xyz(xyz, seq_unmasked, atom_frames):
    """
    xyz (1, L, natoms, 3)
    seq_unmasked (1, L)
    atom_frames (1, L, 3, 2)

    For positions flagged by `is_atom`, overwrite coordinates with those of
    the atoms referenced by `atom_frames`.  Each frame entry is a pair
    (relative atom offset, atom slot) that is converted to a flat index into
    the (n_atom_positions * natoms) coordinate table.

    The result is written to a copy, so the caller's `xyz` is left untouched.
    When no position is an atom, the input tensor itself is returned.
    """
    atoms = is_atom(seq_unmasked)
    if torch.all(~atoms):
        return xyz

    # work on a copy: callers keep using `xyz` (e.g. as ground truth) afterwards
    xyz = xyz.clone()
    atom_crds = xyz[atoms]
    atom_L, natoms, _ = atom_crds.shape

    # flat index = (position + relative offset) * natoms + atom slot
    frames_reindex = torch.zeros(atom_frames.shape[:-1], dtype=torch.long, device=atom_crds.device)
    position = torch.arange(atom_L, device=atom_frames.device)[:,None]
    frames_reindex[:, :atom_L] = (position+atom_frames[..., :atom_L, :, 0])*natoms + atom_frames[..., :atom_L, :, 1]

    xyz[atoms, :, :3] = atom_crds.reshape(atom_L*natoms, 3)[frames_reindex]
    return xyz
atom_frames[..., i, :, 1] + frames_reindex = frames_reindex.long() + xyz_t[:, :, atoms, :3] = atom_crds_t.reshape(T, atom_L*natoms, 3)[:, frames_reindex.squeeze(0)] + return xyz_t + +def get_frames(xyz_in, xyz_mask, seq, frame_indices, atom_frames=None): + B,L,natoms = xyz_in.shape[:3] + frames = frame_indices[seq] + + atoms = seq > NNAPROTAAS + if torch.any(atoms): + # print(torch.sum(atoms)) + # print(atom_frames.shape) + # print(atoms[0].nonzero().flatten().shape) + try: + frames[:,atoms[0].nonzero().flatten(), 0] = atom_frames + except Exception as e: + print(e) + print(torch.sum(atoms)) + print(atom_frames.shape) + print(atoms[0].nonzero().flatten().shape) + + frame_mask = ~torch.all(frames[...,0, :] == frames[...,1, :], axis=-1) + + # frame_mask *= torch.all( + # torch.gather(xyz_mask,2,frames.reshape(B,L,-1)).reshape(B,L,-1,3), + # axis=-1) + + return frames, frame_mask + +def generate_Cbeta(N,Ca,C): + # recreate Cb given N,Ca,C + b = Ca - N + c = C - Ca + a = torch.cross(b, c, dim=-1) + #Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca + # fd: below matches sidechain generator (=Rosetta params) + Cb = -0.57910144*a + 0.5689693*b - 0.5441217*c + Ca + + return Cb + +def get_tips(xyz, seq): + B,L = xyz.shape[:2] + + xyz_tips = torch.gather(xyz, 2, tip_indices.to(xyz.device)[seq][:,:,None,None].expand(-1,-1,-1,3)).reshape(B, L, 3) + if torch.isnan(xyz_tips).any(): # replace NaN tip atom with virtual Cb atom + # three anchor atoms + N = xyz[:,:,0] + Ca = xyz[:,:,1] + C = xyz[:,:,2] + + # recreate Cb given N,Ca,C + b = Ca - N + c = C - Ca + a = torch.cross(b, c, dim=-1) + Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca + + xyz_tips = torch.where(torch.isnan(xyz_tips), Cb, xyz_tips) + return xyz_tips + + +# writepdb +def writepdb(filename, atoms, seq, idx_pdb=None, bfacts=None): + f = open(filename,"w") + ctr = 1 + scpu = seq.cpu().squeeze(0) + atomscpu = atoms.cpu().squeeze(0) + + if bfacts is None: + bfacts = torch.zeros(atomscpu.shape[0]) + if 
idx_pdb is None: + idx_pdb = 1 + torch.arange(atomscpu.shape[0]) + + Bfacts = torch.clamp( bfacts.cpu(), 0, 1) + for i,s in enumerate(scpu): + natoms = atomscpu.shape[-2] + if (natoms!=NHEAVY and natoms!=NTOTAL): + print ('bad size!', natoms, NHEAVY, NTOTAL, atoms.shape) + assert(False) + + atms = aa2long[s] + + # his prot hack + if (s==8 and torch.linalg.norm( atomscpu[i,9,:]-atomscpu[i,5,:] ) < 1.7): + atms = ( + " N "," CA "," C "," O "," CB "," CG "," NE2"," CD2"," CE1"," ND1", + None, None, None, None," H "," HA ","1HB ","2HB "," HD2"," HE1", + " HD1", None, None, None, None, None, None) # his_d + + for j,atm_j in enumerate(atms): + if (j NNAPROTAAS-1: + # all atoms are at index 1 in the atom array + tip_indices[i] = 1 + else: + tip_atm = aa2tip[i] + atm_long = aa2long[i] + tip_indices[i] = atm_long.index(tip_atm) + + +# resolve torsion indices +# a negative index indicates the previous residue +# order: +# omega/phi/psi: 0-2 +# chi_1-4(prot): 3-6 +# cb/cg bend: 7-9 +# eps(p)/zeta(p): 10-11 +# alpha/beta/gamma/delta: 12-15 +# nu2/nu1/nu0: 16-18 +# chi_1(na): 19 +torsion_indices = torch.full((NAATOKENS,NTOTALDOFS,4),0) +torsion_can_flip = torch.full((NAATOKENS,NTOTALDOFS),False,dtype=torch.bool) +for i in range(NPROTAAS): + i_l, i_a = aa2long[i], aa2longalt[i] + + # protein omega/phi/psi + torsion_indices[i,0,:] = torch.tensor([-1,-2,0,1]) # omega + torsion_indices[i,1,:] = torch.tensor([-2,0,1,2]) # phi + torsion_indices[i,2,:] = torch.tensor([0,1,2,3]) # psi (+pi) + + # protein chis + for j in range(4): + if torsions[i][j] is None: + continue + for k in range(4): + a = torsions[i][j][k] + torsion_indices[i,3+j,k] = i_l.index(a) + if (i_l.index(a) != i_a.index(a)): + torsion_can_flip[i,3+j] = True ##bb tors never flip + + # CB/CG angles (only masking uses these indices) + torsion_indices[i,7,:] = torch.tensor([0,2,1,4]) # CB ang1 + torsion_indices[i,8,:] = torch.tensor([0,2,1,4]) # CB ang2 + torsion_indices[i,9,:] = torch.tensor([0,2,4,5]) # CG ang (arg 1 
ignored) + +# HIS is a special case for flip +torsion_can_flip[8,4]=False + +for i in range(NPROTAAS,NNAPROTAAS): + # NA BB tors + torsion_indices[i,10,:] = torch.tensor([-5,-7,-8,1]) # epsilon_prev + torsion_indices[i,11,:] = torch.tensor([-7,-8,1,3]) # zeta_prev + torsion_indices[i,12,:] = torch.tensor([0,1,3,4]) # alpha (+2pi/3) + torsion_indices[i,13,:] = torch.tensor([1,3,4,5]) # beta + torsion_indices[i,14,:] = torch.tensor([3,4,5,7]) # gamma + torsion_indices[i,15,:] = torch.tensor([4,5,7,8]) # delta + + if (i=4] = 4 + num_bonds[i,...] = torch.tensor(num_bonds_i) + + +# atom type indices +idx2aatype = [] +for x in aa2type: + for y in x: + if y and y not in idx2aatype: + idx2aatype.append(y) +aatype2idx = {x:i for i,x in enumerate(idx2aatype)} + +# element indices +idx2elt = [] +for x in aa2elt: + for y in x: + if y and y not in idx2elt: + idx2elt.append(y) +elt2idx = {x:i for i,x in enumerate(idx2elt)} + +# LJ/LK scoring parameters +atom_type_index = torch.zeros((NAATOKENS,NTOTAL), dtype=torch.long) +element_index = torch.zeros((NAATOKENS,NTOTAL), dtype=torch.long) + +ljlk_parameters = torch.zeros((NAATOKENS,NTOTAL,5), dtype=torch.float) +lj_correction_parameters = torch.zeros((NAATOKENS,NTOTAL,4), dtype=bool) # donor/acceptor/hpol/disulf +for i in range(NNAPROTAAS): + for j,a in enumerate(aa2type[i]): + if (a is not None): + atom_type_index[i,j] = aatype2idx[a] + ljlk_parameters[i,j,:] = torch.tensor( type2ljlk[a] ) + lj_correction_parameters[i,j,0] = (type2hb[a]==HbAtom.DO)+(type2hb[a]==HbAtom.DA) + lj_correction_parameters[i,j,1] = (type2hb[a]==HbAtom.AC)+(type2hb[a]==HbAtom.DA) + lj_correction_parameters[i,j,2] = (type2hb[a]==HbAtom.HP) + lj_correction_parameters[i,j,3] = (a=="SH1" or a=="HS") + for j,a in enumerate(aa2elt[i]): + if (a is not None): + element_index[i,j] = elt2idx[a] + +# hbond scoring parameters +def donorHs(D,bonds,atoms): + dHs = [] + for (i,j) in bonds: + if (i==D): + idx_j = atoms.index(j) + if (idx_j>=NHEAVY): # if atom j is a 
hydrogen + dHs.append(idx_j) + if (j==D): + idx_i = atoms.index(i) + if (idx_i>=NHEAVY): # if atom j is a hydrogen + dHs.append(idx_i) + assert (len(dHs)>0) + return dHs + +def acceptorBB0(A,hyb,bonds,atoms): + if (hyb == HbHybType.SP2): + for (i,j) in bonds: + if (i==A): + B = atoms.index(j) + if (B0): + cb_length_t[i,:len(cb_lengths[src]),:] = torch.tensor(cb_lengths[src]) + +# (2) intra-res angles +cb_angles = [[] for i in range(NAATOKENS+1)] +for cst in cartbonded_data_raw['angles']: + res_idx = aa2num[ cst['res'] ] + cb_angles[res_idx].append( ( + aa2long[res_idx].index(cst['atm1']), + aa2long[res_idx].index(cst['atm2']), + aa2long[res_idx].index(cst['atm3']), + cst['x0'],cst['K'] + ) ) +ncst_per_res=max([len(i) for i in cb_angles]) +cb_angle_t = torch.zeros(NAATOKENS+1,ncst_per_res,5) +for i in range(NNAPROTAAS+1): + src = i + if (num2aa[i]=='UNK' or num2aa[i]=='MAS'): + src=aa2num['ALA'] + + if (len(cb_angles[src])>0): + cb_angle_t[i,:len(cb_angles[src]),:] = torch.tensor(cb_angles[src]) + +# (3) intra-res torsions +cb_torsions = [[] for i in range(NAATOKENS+1)] +for cst in cartbonded_data_raw['torsions']: + res_idx = aa2num[ cst['res'] ] + cb_torsions[res_idx].append( ( + aa2long[res_idx].index(cst['atm1']), + aa2long[res_idx].index(cst['atm2']), + aa2long[res_idx].index(cst['atm3']), + aa2long[res_idx].index(cst['atm4']), + cst['x0'],cst['K'],cst['period'] + ) ) +ncst_per_res=max([len(i) for i in cb_torsions]) +cb_torsion_t = torch.zeros(NAATOKENS+1,ncst_per_res,7) +cb_torsion_t[...,6]=1.0 # periodicity +for i in range(NNAPROTAAS): + src = i + if (num2aa[i]=='UNK' or num2aa[i]=='MAS'): + src=aa2num['ALA'] + + if (len(cb_torsions[src])>0): + cb_torsion_t[i,:len(cb_torsions[src]),:] = torch.tensor(cb_torsions[src]) + +# kinematic parameters +base_indices = torch.full((NAATOKENS,NTOTAL),0, dtype=torch.long) # base frame that builds each atom +xyzs_in_base_frame = torch.ones((NAATOKENS,NTOTAL,4)) # coords of each atom in the base frame +RTs_by_torsion = 
torch.eye(4).repeat(NAATOKENS,NTOTALTORS,1,1) # torsion frames +reference_angles = torch.ones((NAATOKENS,NPROTANGS,2)) # reference values for bendable angles + +## PROTEIN +for i in range(NPROTAAS): + i_l = aa2long[i] + for name, base, coords in ideal_coords[i]: + idx = i_l.index(name) + base_indices[i,idx] = base + xyzs_in_base_frame[i,idx,:3] = torch.tensor(coords) + + # omega frame + RTs_by_torsion[i,0,:3,:3] = torch.eye(3) + RTs_by_torsion[i,0,:3,3] = torch.zeros(3) + + # phi frame + RTs_by_torsion[i,1,:3,:3] = make_frame( + xyzs_in_base_frame[i,0,:3] - xyzs_in_base_frame[i,1,:3], + torch.tensor([1.,0.,0.]) + ) + RTs_by_torsion[i,1,:3,3] = xyzs_in_base_frame[i,0,:3] + + # psi frame + RTs_by_torsion[i,2,:3,:3] = make_frame( + xyzs_in_base_frame[i,2,:3] - xyzs_in_base_frame[i,1,:3], + xyzs_in_base_frame[i,1,:3] - xyzs_in_base_frame[i,0,:3] + ) + RTs_by_torsion[i,2,:3,3] = xyzs_in_base_frame[i,2,:3] + + # chi1 frame + if torsions[i][0] is not None: + a0,a1,a2 = torsion_indices[i,3,0:3] + RTs_by_torsion[i,3,:3,:3] = make_frame( + xyzs_in_base_frame[i,a2,:3]-xyzs_in_base_frame[i,a1,:3], + xyzs_in_base_frame[i,a0,:3]-xyzs_in_base_frame[i,a1,:3], + ) + RTs_by_torsion[i,3,:3,3] = xyzs_in_base_frame[i,a2,:3] + + # chi2/3/4 frame + for j in range(1,4): + if torsions[i][j] is not None: + a2 = torsion_indices[i,3+j,2] + if ((i==18 and j==2) or (i==8 and j==2)): # TYR CZ-OH & HIS CE1-HE1 a special case + a0,a1 = torsion_indices[i,3+j,0:2] + RTs_by_torsion[i,3+j,:3,:3] = make_frame( + xyzs_in_base_frame[i,a2,:3]-xyzs_in_base_frame[i,a1,:3], + xyzs_in_base_frame[i,a0,:3]-xyzs_in_base_frame[i,a1,:3] ) + else: + RTs_by_torsion[i,3+j,:3,:3] = make_frame( + xyzs_in_base_frame[i,a2,:3], + torch.tensor([-1.,0.,0.]), ) + RTs_by_torsion[i,3+j,:3,3] = xyzs_in_base_frame[i,a2,:3] + + # CB/CG angles + NCr = 0.5*(xyzs_in_base_frame[i,0,:3]+xyzs_in_base_frame[i,2,:3]) + CAr = xyzs_in_base_frame[i,1,:3] + CBr = xyzs_in_base_frame[i,4,:3] + CGr = xyzs_in_base_frame[i,5,:3] + 
reference_angles[i,0,:]=th_ang_v(CBr-CAr,NCr-CAr) + NCp = xyzs_in_base_frame[i,2,:3]-xyzs_in_base_frame[i,0,:3] + NCpp = NCp - torch.dot(NCp,NCr)/ torch.dot(NCr,NCr) * NCr + reference_angles[i,1,:]=th_ang_v(CBr-CAr,NCpp) + reference_angles[i,2,:]=th_ang_v(CGr,torch.tensor([-1.,0.,0.])) + +## NUCLEIC ACIDS +for i in range(NPROTAAS, NNAPROTAAS): + i_l = aa2long[i] + + for name, base, coords in ideal_coords[i]: + idx = i_l.index(name) + base_indices[i,idx] = base + xyzs_in_base_frame[i,idx,:3] = torch.tensor(coords) + + # epsilon(p)/zeta(p) - like omega in protein, not used to build atoms + # - keep as identity + RTs_by_torsion[i,NPROTTORS+0,:3,:3] = torch.eye(3) + RTs_by_torsion[i,NPROTTORS+0,:3,3] = torch.zeros(3) + RTs_by_torsion[i,NPROTTORS+1,:3,:3] = torch.eye(3) + RTs_by_torsion[i,NPROTTORS+1,:3,3] = torch.zeros(3) + + # alpha + RTs_by_torsion[i,NPROTTORS+2,:3,:3] = make_frame( + xyzs_in_base_frame[i,3,:3] - xyzs_in_base_frame[i,1,:3], # P->O5' + xyzs_in_base_frame[i,0,:3] - xyzs_in_base_frame[i,1,:3] # P<-OP1 + ) + RTs_by_torsion[i,NPROTTORS+2,:3,3] = xyzs_in_base_frame[i,3,:3] # O5' + + # beta + RTs_by_torsion[i,NPROTTORS+3,:3,:3] = make_frame( + xyzs_in_base_frame[i,4,:3] , torch.tensor([-1.,0.,0.]) + ) + RTs_by_torsion[i,NPROTTORS+3,:3,3] = xyzs_in_base_frame[i,4,:3] # C5' + + # gamma + RTs_by_torsion[i,NPROTTORS+4,:3,:3] = make_frame( + xyzs_in_base_frame[i,5,:3] , torch.tensor([-1.,0.,0.]) + ) + RTs_by_torsion[i,NPROTTORS+4,:3,3] = xyzs_in_base_frame[i,5,:3] # C4' + + # delta + RTs_by_torsion[i,NPROTTORS+5,:3,:3] = make_frame( + xyzs_in_base_frame[i,7,:3] , torch.tensor([-1.,0.,0.]) + ) + RTs_by_torsion[i,NPROTTORS+5,:3,3] = xyzs_in_base_frame[i,7,:3] # C3' + + # nu2 + RTs_by_torsion[i,NPROTTORS+6,:3,:3] = make_frame( + xyzs_in_base_frame[i,6,:3] , torch.tensor([-1.,0.,0.]) + ) + RTs_by_torsion[i,NPROTTORS+6,:3,3] = xyzs_in_base_frame[i,6,:3] # O4' + + # nu1 + if i nx.Graph: + '''build NetworkX graph from rdkit's molecule''' + + N = mol.GetNumAtoms() + + # 
pairs of bonded atoms + bonds = [(b.GetBeginAtomIdx(),b.GetEndAtomIdx()) for b in mol.GetBonds()] + + # connectivity graph + G = nx.Graph() + G.add_nodes_from(range(N)) + G.add_edges_from(bonds) + + return G + +def find_all_paths_of_length_n(G : nx.Graph, + n : int) -> torch.Tensor: + '''find all paths of length N in a networkx graph + https://stackoverflow.com/questions/28095646/finding-all-paths-walks-of-given-length-in-a-networkx-graph''' + + def findPaths(G,u,n): + if n==0: + return [[u]] + paths = [[u]+path for neighbor in G.neighbors(u) for path in findPaths(G,neighbor,n-1) if u not in path] + return paths + + # all paths of length n + allpaths = [tuple(p) if p[0] 1] = 0.0 # if bonded -- 1.0 / else 0.0 + neigh = sign * neigh + return neigh.unsqueeze(-1) + +def make_full_graph(xyz, pair, idx): + ''' + Input: + - xyz: current backbone cooordinates (B, L, 3, 3) + - pair: pair features from Trunk (B, L, L, E) + - idx: residue index from ground truth pdb + Output: + - G: defined graph + ''' + + B, L = xyz.shape[:2] + device = xyz.device + + # seq sep + sep = idx[:,None,:] - idx[:,:,None] + b,i,j = torch.where(sep.abs() > 0) + + src = b*L+i + tgt = b*L+j + G = dgl.graph((src, tgt), num_nodes=B*L).to(device) + G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]).detach() # no gradient through basis function + + return G, pair[b,i,j][...,None] + +def make_topk_graph(xyz, pair, idx, top_k=128, nlocal=33, topk_incl_local=True, eps=1e-6): + ''' + Input: + - xyz: current backbone cooordinates (B, L, 3, 3) + - pair: pair features from Trunk (B, L, L, E) + - idx: residue index from ground truth pdb + Output: + - G: defined graph + ''' + + B, L = xyz.shape[:2] + device = xyz.device + + # distance map from current CA coordinates + D = torch.cdist(xyz, xyz) + torch.eye(L, device=device).unsqueeze(0)*9999.9 # (B, L, L) + + # seq sep + sep = idx[:,None,:] - idx[:,:,None] + sep = sep.abs() + torch.eye(L, device=device).unsqueeze(0)*9999.9 + + if (topk_incl_local): + D = D + sep*eps + 
D[sep 0.0 + + else: + + D = D + sep*eps + + # get top_k neighbors + D_neigh, E_idx = torch.topk(D, min(top_k, L-1), largest=False) # shape of E_idx: (B, L, top_k) + topk_matrix = torch.zeros((B, L, L), device=device) + topk_matrix.scatter_(2, E_idx, 1.0) + + # put an edge if any of the 3 conditions are met: + # 1) |i-j| <= kmin (connect sequentially adjacent residues) + # 2) top_k neighbors + cond = torch.logical_or(topk_matrix > 0.0, sep < nlocal) + b,i,j = torch.where(cond) + + src = b*L+i + tgt = b*L+j + G = dgl.graph((src, tgt), num_nodes=B*L).to(device) + G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]).detach() # no gradient through basis function + + return G, pair[b,i,j][...,None] + +def make_atom_graph( xyz, mask, num_bonds, top_k=16, maxbonds=4 ): + B,L,A = xyz.shape[:3] + device = xyz.device + + D = torch.norm( + xyz[:,None,None,:,:] - xyz[:,:,:,None,None], dim=-1 + ) + mask2d = mask[:,:,:,None,None]*mask[:,None,None,:,:] + D[~mask2d] = 9999. + D[D==0] = 9999. + + # select top K neighbors for each atom + # keep indices as batch/res/atm indices + D_neigh, E_idx = torch.topk(D.reshape(B,L,A,-1), top_k, largest=False) # shape of E_idx: (B, L, top_k) + Eres, Eatm = torch.div(E_idx,A,rounding_mode='trunc'), E_idx%A + bi,ri,ai = mask.nonzero(as_tuple=True) + bi = bi[:,None].repeat(1,top_k).reshape(-1) + ri = ri[:,None].repeat(1,top_k).reshape(-1) + ai = ai[:,None].repeat(1,top_k).reshape(-1) + rj,aj = Eres[mask].reshape(-1), Eatm[mask].reshape(-1) + + # on each edge, 1-hot encode the number of bonds (up to maxbonds) seperating each atom + edge = torch.full(ri.shape, maxbonds, device=device) + resmask = ri==rj + edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],aj[resmask]]-1 + resmask = ri+1==rj + edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],2]+num_bonds[bi[resmask],rj[resmask],0,aj[resmask]] + resmask = ri-1==rj + edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],0]+num_bonds[bi[resmask],rj[resmask],2,aj[resmask]] + edge 
= edge.clamp(0,maxbonds-1) + edge = F.one_hot(edge)[...,None] + + natm = torch.sum(mask) + index = torch.zeros_like(mask, dtype=torch.long, device=device) + index[mask] = torch.arange(natm, device=device) + src=index[bi,ri,ai] + tgt=index[bi,rj,aj] + + G = dgl.graph((src, tgt), num_nodes=natm).to(device) + G.edata['rel_pos'] = (xyz[bi,ri,ai] - xyz[bi,rj,aj]).detach() # no gradient through basis function + + return G, edge + + +# rotate about the x axis +def make_rotX(angs, eps=1e-6): + B,L = angs.shape[:2] + NORM = torch.linalg.norm(angs, dim=-1) + eps + + RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1) + + RTs[:,:,1,1] = angs[:,:,0]/NORM + RTs[:,:,1,2] = -angs[:,:,1]/NORM + RTs[:,:,2,1] = angs[:,:,1]/NORM + RTs[:,:,2,2] = angs[:,:,0]/NORM + return RTs + +# rotate about the x axis +def make_rotZ(angs, eps=1e-6): + B,L = angs.shape[:2] + NORM = torch.linalg.norm(angs, dim=-1) + eps + + RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1) + + RTs[:,:,0,0] = angs[:,:,0]/NORM + RTs[:,:,0,1] = -angs[:,:,1]/NORM + RTs[:,:,1,0] = angs[:,:,1]/NORM + RTs[:,:,1,1] = angs[:,:,0]/NORM + return RTs + +# rotate about an arbitrary axis +def make_rot_axis(angs, u, eps=1e-6): + B,L = angs.shape[:2] + NORM = torch.linalg.norm(angs, dim=-1) + eps + + RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1) + + ct = angs[:,:,0]/NORM + st = angs[:,:,1]/NORM + u0 = u[:,:,0] + u1 = u[:,:,1] + u2 = u[:,:,2] + + RTs[:,:,0,0] = ct+u0*u0*(1-ct) + RTs[:,:,0,1] = u0*u1*(1-ct)-u2*st + RTs[:,:,0,2] = u0*u2*(1-ct)+u1*st + RTs[:,:,1,0] = u0*u1*(1-ct)+u2*st + RTs[:,:,1,1] = ct+u1*u1*(1-ct) + RTs[:,:,1,2] = u1*u2*(1-ct)-u0*st + RTs[:,:,2,0] = u0*u2*(1-ct)-u1*st + RTs[:,:,2,1] = u1*u2*(1-ct)+u0*st + RTs[:,:,2,2] = ct+u2*u2*(1-ct) + return RTs + + +# compute allatom structure from backbone frames and torsions +# +# alphas: +# omega/phi/psi: 0-2 +# chi_1-4(prot): 3-6 +# cb/cg bend: 7-9 +# eps(p)/zeta(p): 10-11 +# alpha/beta/gamma/delta: 12-15 +# nu2/nu1/nu0: 16-18 +# chi_1(na): 19 +# +# 
# RTs_in_base_frame:
#    omega/phi/psi: 0-2
#    chi_1-4(prot): 3-6
#    eps(p)/zeta(p): 7-8
#    alpha/beta/gamma/delta: 9-12
#    nu2/nu1/nu0: 13-15
#    chi_1(na): 16
#
# RT frames (output):
#    origin: 0
#    omega/phi/psi: 1-3
#    chi_1-4(prot): 4-7
#    cb bend: 8
#    alpha/beta/gamma/delta: 9-12
#    nu2/nu1/nu0: 13-15
#    chi_1(na): 16
#
class ComputeAllAtomCoords(nn.Module):
    """Build all-atom coordinates from backbone frames and torsion angles.

    The module holds three per-residue-type lookup tables (stored as
    non-trainable parameters so they follow the module across devices and
    are saved in the state dict):

    * ``base_indices``       -- for every atom, the index of the frame
                                (see "RT frames (output)" above) it is built in
    * ``RTs_in_base_frame``  -- 4x4 transform of each torsion frame relative
                                to its parent frame
    * ``xyzs_in_base_frame`` -- homogeneous coordinates of each atom inside
                                its base frame
    """

    def __init__(self):
        super(ComputeAllAtomCoords, self).__init__()

        self.base_indices = nn.Parameter(base_indices, requires_grad=False)
        self.RTs_in_base_frame = nn.Parameter(RTs_by_torsion, requires_grad=False)
        self.xyzs_in_base_frame = nn.Parameter(xyzs_in_base_frame, requires_grad=False)

    def _torsion_frame(self, parent, seq, frame_idx, angs):
        """Compose ``parent @ RTs_in_base_frame[seq, frame_idx] @ rotX(angs)``."""
        return torch.einsum(
            'brij,brjk,brkl->bril',
            parent, self.RTs_in_base_frame[seq, frame_idx], make_rotX(angs))

    def forward(self, seq, xyz, alphas):
        """Compute every rigid frame and the resulting atom positions.

        Args:
            seq:    residue-type indices, indexed as (B, L)
            xyz:    coordinates whose first three atoms per residue define the
                    backbone frame; leading dims are (B, L)
            alphas: per-residue angle pairs, read as ``alphas[:, :, k, :]`` for
                    k in 0..19 (layout listed in the "alphas" table); the two
                    components are normalised inside the rotation builders

        Returns:
            RTframes: (B, L, 17, 4, 4) stack of homogeneous frames, ordered as
                      in "RT frames (output)"
            xyzs:     atom coordinates (last dim 3), one per atom slot of
                      ``xyzs_in_base_frame``
        """
        B, L = xyz.shape[:2]

        is_NA = is_nucleic(seq)
        Rs, Ts = rigid_from_3_points(xyz[..., 0, :], xyz[..., 1, :], xyz[..., 2, :], is_NA)

        # backbone frame (rotation + translation in homogeneous form)
        RTF0 = torch.eye(4, device=Rs.device).repeat(B, L, 1, 1)
        RTF0[:, :, :3, :3] = Rs
        RTF0[:, :, :3, 3] = Ts

        # omega / phi / psi -- all hang directly off the backbone frame
        RTF1 = self._torsion_frame(RTF0, seq, 0, alphas[:, :, 0, :])
        RTF2 = self._torsion_frame(RTF0, seq, 1, alphas[:, :, 1, :])
        RTF3 = self._torsion_frame(RTF0, seq, 2, alphas[:, :, 2, :])

        # CB bend axis: normal of the plane spanned by CA->CB and CA->mid(N,C)
        basexyzs = self.xyzs_in_base_frame[seq]
        NCr = 0.5 * (basexyzs[:, :, 2, :3] + basexyzs[:, :, 0, :3])
        CAr = basexyzs[:, :, 1, :3]
        CBr = basexyzs[:, :, 4, :3]
        # dim=-1 is explicit: the legacy default picks the *first* axis of
        # size 3, which would be the batch or length axis when B==3 or L==3.
        CBrotaxis1 = torch.cross(CBr - CAr, NCr - CAr, dim=-1)
        CBrotaxis1 = CBrotaxis1 / (torch.linalg.norm(CBrotaxis1, dim=-1, keepdim=True) + 1e-8)

        # CB twist axis: uses the N->C direction with its NCr component removed
        NCp = basexyzs[:, :, 2, :3] - basexyzs[:, :, 0, :3]
        NCpp = NCp - torch.sum(NCp * NCr, dim=-1, keepdim=True) / torch.sum(NCr * NCr, dim=-1, keepdim=True) * NCr
        CBrotaxis2 = torch.cross(CBr - CAr, NCpp, dim=-1)
        CBrotaxis2 = CBrotaxis2 / (torch.linalg.norm(CBrotaxis2, dim=-1, keepdim=True) + 1e-8)

        CBrot1 = make_rot_axis(alphas[:, :, 7, :], CBrotaxis1)
        CBrot2 = make_rot_axis(alphas[:, :, 8, :], CBrotaxis2)

        RTF8 = torch.einsum(
            'brij,brjk,brkl->bril',
            RTF0, CBrot1, CBrot2)

        # chi1 + CG bend (built on top of the CB-bend frame)
        RTF4 = torch.einsum(
            'brij,brjk,brkl,brlm->brim',
            RTF8,
            self.RTs_in_base_frame[seq, 3, :],
            make_rotX(alphas[:, :, 3, :]),
            make_rotZ(alphas[:, :, 9, :]))

        # chi2 / chi3 / chi4 -- each chained on the previous chi frame
        RTF5 = self._torsion_frame(RTF4, seq, 4, alphas[:, :, 4, :])
        RTF6 = self._torsion_frame(RTF5, seq, 5, alphas[:, :, 5, :])
        RTF7 = self._torsion_frame(RTF6, seq, 6, alphas[:, :, 6, :])

        # ignore RTs_in_base_frame[seq,7:9,:] and alphas[:,:,10:12,:]

        # NA alpha / beta / gamma / delta -- chained from the backbone frame
        RTF9 = self._torsion_frame(RTF0, seq, 9, alphas[:, :, 12, :])
        RTF10 = self._torsion_frame(RTF9, seq, 10, alphas[:, :, 13, :])
        RTF11 = self._torsion_frame(RTF10, seq, 11, alphas[:, :, 14, :])
        RTF12 = self._torsion_frame(RTF11, seq, 12, alphas[:, :, 15, :])

        # NA nu2 - from gamma frame
        RTF13 = self._torsion_frame(RTF11, seq, 13, alphas[:, :, 16, :])

        # NA nu1
        RTF14 = self._torsion_frame(RTF13, seq, 14, alphas[:, :, 17, :])

        # NA nu0
        RTF15 = self._torsion_frame(RTF14, seq, 15, alphas[:, :, 18, :])

        # NA chi - from nu1 frame
        RTF16 = self._torsion_frame(RTF14, seq, 16, alphas[:, :, 19, :])

        RTframes = torch.stack((
            RTF0, RTF1, RTF2, RTF3, RTF4, RTF5, RTF6, RTF7, RTF8,
            RTF9, RTF10, RTF11, RTF12, RTF13, RTF14, RTF15, RTF16
        ), dim=2)

        # pick, for every atom, the frame it is defined in and transform its
        # homogeneous base-frame coordinates into the global frame
        xyzs = torch.einsum(
            'brtij,brtj->brti',
            RTframes.gather(2, self.base_indices[seq][..., None, None].repeat(1, 1, 1, 4, 4)), basexyzs
        )

        return RTframes, xyzs[..., :3]