mirror of
https://github.com/AngxiaoYue/ReQFlow.git
synced 2026-06-06 13:24:29 +08:00
821 lines
24 KiB
Python
821 lines
24 KiB
Python
# Copyright 2021 AlQuraishi Laboratory
|
|
# Copyright 2021 DeepMind Technologies Limited
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from functools import reduce
|
|
import importlib
|
|
import math
|
|
import sys
|
|
from operator import mul
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from typing import Optional, Tuple, Sequence
|
|
|
|
from openfold.model.primitives import Linear, LayerNorm, ipa_point_weights_init_
|
|
from openfold.np.residue_constants import (
|
|
restype_rigid_group_default_frame,
|
|
restype_atom14_to_rigid_group,
|
|
restype_atom14_mask,
|
|
restype_atom14_rigid_group_positions,
|
|
)
|
|
from openfold.utils.feats import (
|
|
frames_and_literature_positions_to_atom14_pos,
|
|
torsion_angles_to_frames,
|
|
)
|
|
from openfold.utils.precision_utils import is_fp16_enabled
|
|
from openfold.utils.rigid_utils import Rotation, Rigid
|
|
from openfold.utils.tensor_utils import (
|
|
dict_multimap,
|
|
permute_final_dims,
|
|
flatten_final_dims,
|
|
)
|
|
|
|
# attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")
|
|
|
|
|
|
class AngleResnetBlock(nn.Module):
    """
    Residual block of the angle resnet: two (ReLU -> Linear) stages whose
    result is added back onto the block input.
    """

    def __init__(self, c_hidden):
        """
        Args:
            c_hidden:
                Hidden channel dimension (input and output width of both
                linear layers)
        """
        super().__init__()

        self.c_hidden = c_hidden

        # Creation order is kept so parameter registration order is unchanged.
        self.linear_1 = Linear(self.c_hidden, self.c_hidden, init="relu")
        self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="final")

        self.relu = nn.ReLU()

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        """
        Args:
            a:
                Input activations; the last dimension is fed to linear_1
        Returns:
            Tensor equal to the transformed activations plus the input
            (skip connection)
        """
        # ReLU is applied *before* each linear layer (pre-activation form).
        residual = self.linear_1(self.relu(a))
        residual = self.linear_2(self.relu(residual))

        return residual + a
|
|
|
|
|
|
class AngleResnet(nn.Module):
    """
    Implements Algorithm 20, lines 11-14
    """

    def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Hidden channel dimension
            no_blocks:
                Number of resnet blocks
            no_angles:
                Number of torsion angles to generate
            epsilon:
                Small constant used as the lower clamp of the squared norm
                during normalization
        """
        super().__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_blocks = no_blocks
        self.no_angles = no_angles
        self.eps = epsilon

        self.linear_in = Linear(self.c_in, self.c_hidden)
        self.linear_initial = Linear(self.c_in, self.c_hidden)

        self.layers = nn.ModuleList(
            [
                AngleResnetBlock(c_hidden=self.c_hidden)
                for _ in range(self.no_blocks)
            ]
        )

        # Two outputs per angle; forward() reshapes them to [..., 2].
        self.linear_out = Linear(self.c_hidden, self.no_angles * 2)

        self.relu = nn.ReLU()

    def forward(
        self, s: torch.Tensor, s_initial: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            s:
                [*, C_in] single embedding
            s_initial:
                [*, C_in] single embedding as of the start of the
                StructureModule
        Returns:
            Tuple of (unnormalized, normalized) predicted angles, each of
            shape [*, no_angles, 2]
        """
        # NOTE: The ReLU's applied to the inputs are absent from the supplement
        # pseudocode but present in the source. They are kept for
        # compatibility with the pretrained weights.

        # [*, C_hidden]
        hidden = self.linear_in(self.relu(s)) + self.linear_initial(
            self.relu(s_initial)
        )

        for block in self.layers:
            hidden = block(hidden)

        # [*, no_angles * 2]
        flat_angles = self.linear_out(self.relu(hidden))

        # [*, no_angles, 2]
        unnormalized_s = flat_angles.view(flat_angles.shape[:-1] + (-1, 2))

        # Clamp the squared norm from below so the division stays finite.
        squared_norm = torch.sum(unnormalized_s ** 2, dim=-1, keepdim=True)
        norm_denom = torch.sqrt(torch.clamp(squared_norm, min=self.eps))

        return unnormalized_s, unnormalized_s / norm_denom
|
|
|
|
|
|
class InvariantPointAttention(nn.Module):
    """
    Implements Algorithm 22.
    """

    # Class-level cache for the optional in-place softmax extension so the
    # import is attempted at most once per class.
    _inplace_softmax_kernel = None
    _inplace_softmax_kernel_checked = False

    def __init__(
        self,
        c_s: int,
        c_z: int,
        c_hidden: int,
        no_heads: int,
        no_qk_points: int,
        no_v_points: int,
        inf: float = 1e5,
        eps: float = 1e-8,
    ):
        """
        Args:
            c_s:
                Single representation channel dimension
            c_z:
                Pair representation channel dimension
            c_hidden:
                Hidden channel dimension
            no_heads:
                Number of attention heads
            no_qk_points:
                Number of query/key points to generate
            no_v_points:
                Number of value points to generate
            inf:
                Large number used to build the additive attention mask
            eps:
                Small number added under the square root when computing
                the norms of the output points
        """
        super(InvariantPointAttention, self).__init__()

        self.c_s = c_s
        self.c_z = c_z
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.no_qk_points = no_qk_points
        self.no_v_points = no_v_points
        self.inf = inf
        self.eps = eps

        # These linear layers differ from their specifications in the
        # supplement. There, they lack bias and use Glorot initialization.
        # Here as in the official source, they have bias and use the default
        # Lecun initialization.
        hc = self.c_hidden * self.no_heads
        self.linear_q = Linear(self.c_s, hc)
        self.linear_kv = Linear(self.c_s, 2 * hc)

        hpq = self.no_heads * self.no_qk_points * 3
        self.linear_q_points = Linear(self.c_s, hpq)

        hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
        self.linear_kv_points = Linear(self.c_s, hpkv)

        self.linear_b = Linear(self.c_z, self.no_heads)

        self.head_weights = nn.Parameter(torch.zeros((no_heads)))
        ipa_point_weights_init_(self.head_weights)

        # Per head: C_z pair features, C_hidden scalar features, and for each
        # value point its 3 coordinates plus its norm (hence the factor 4).
        concat_out_dim = self.no_heads * (
            self.c_z + self.c_hidden + self.no_v_points * 4
        )
        self.linear_out = Linear(concat_out_dim, self.c_s, init="final")

        self.softmax = nn.Softmax(dim=-1)
        self.softplus = nn.Softplus()

    @classmethod
    def _get_inplace_softmax_kernel(cls):
        """
        Return the optional ``attn_core_inplace_cuda`` extension module, or
        None if it cannot be imported. The result is cached on the class.
        """
        if not cls._inplace_softmax_kernel_checked:
            try:
                cls._inplace_softmax_kernel = importlib.import_module(
                    "attn_core_inplace_cuda"
                )
            except ImportError:
                cls._inplace_softmax_kernel = None
            cls._inplace_softmax_kernel_checked = True
        return cls._inplace_softmax_kernel

    def forward(
        self,
        s: torch.Tensor,
        z: Optional[torch.Tensor],
        r: Rigid,
        mask: torch.Tensor,
        inplace_safe: bool = False,
        _offload_inference: bool = False,
        _z_reference_list: Optional[Sequence[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
            z:
                [*, N_res, N_res, C_z] pair representation
            r:
                [*, N_res] transformation object
            mask:
                [*, N_res] mask
            inplace_safe:
                If True, attention logits are accumulated in place and the
                point outputs are computed one coordinate at a time
            _offload_inference:
                If True, the pair representation is moved to the CPU after
                the pair bias is computed and moved back before the pair
                output is computed. Asserts that no extra references to
                the tensor exist.
            _z_reference_list:
                Single-element list holding the pair representation. Used
                instead of ``z`` only when both ``_offload_inference`` and
                ``inplace_safe`` are True.
        Returns:
            [*, N_res, C_s] single representation update
        """
        # The pair representation is always accessed as z[0] so that the
        # offloading code below can replace it without extra references.
        if(_offload_inference and inplace_safe):
            z = _z_reference_list
        else:
            z = [z]

        #######################################
        # Generate scalar and point activations
        #######################################
        # [*, N_res, H * C_hidden]
        q = self.linear_q(s)
        kv = self.linear_kv(s)

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, 2 * C_hidden]
        kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, H, C_hidden]
        k, v = torch.split(kv, self.c_hidden, dim=-1)

        # [*, N_res, H * P_q * 3]
        q_pts = self.linear_q_points(s)

        # This is kind of clunky, but it's how the original does it
        # [*, N_res, H * P_q, 3]
        q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
        q_pts = torch.stack(q_pts, dim=-1)
        q_pts = r[..., None].apply(q_pts)

        # [*, N_res, H, P_q, 3]
        q_pts = q_pts.view(
            q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
        )

        # [*, N_res, H * (P_q + P_v) * 3]
        kv_pts = self.linear_kv_points(s)

        # [*, N_res, H * (P_q + P_v), 3]
        kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
        kv_pts = torch.stack(kv_pts, dim=-1)
        kv_pts = r[..., None].apply(kv_pts)

        # [*, N_res, H, (P_q + P_v), 3]
        kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))

        # [*, N_res, H, P_q/P_v, 3]
        k_pts, v_pts = torch.split(
            kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
        )

        ##########################
        # Compute attention scores
        ##########################
        # [*, N_res, N_res, H]
        b = self.linear_b(z[0])

        if(_offload_inference):
            # Do not bind z[0] to a local name before this check: the
            # expected count of 2 is the list slot plus the call argument.
            assert(sys.getrefcount(z[0]) == 2)
            z[0] = z[0].cpu()

        # [*, H, N_res, N_res]
        if(is_fp16_enabled()):
            # Compute the scalar logits in float32 with autocast disabled.
            with torch.cuda.amp.autocast(enabled=False):
                a = torch.matmul(
                    permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
                    permute_final_dims(k.float(), (1, 2, 0)),  # [*, H, C_hidden, N_res]
                )
        else:
            a = torch.matmul(
                permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
                permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
            )

        a *= math.sqrt(1.0 / (3 * self.c_hidden))
        a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))

        # [*, N_res, N_res, H, P_q, 3]
        pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
        if(inplace_safe):
            pt_att *= pt_att
        else:
            pt_att = pt_att ** 2

        # [*, N_res, N_res, H, P_q]
        pt_att = sum(torch.unbind(pt_att, dim=-1))
        head_weights = self.softplus(self.head_weights).view(
            *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
        )
        head_weights = head_weights * math.sqrt(
            1.0 / (3 * (self.no_qk_points * 9.0 / 2))
        )
        if(inplace_safe):
            pt_att *= head_weights
        else:
            pt_att = pt_att * head_weights

        # [*, N_res, N_res, H]
        pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
        # [*, N_res, N_res]
        square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
        # 0 where both residues are unmasked, -inf (approx.) elsewhere.
        square_mask = self.inf * (square_mask - 1)

        # [*, H, N_res, N_res]
        pt_att = permute_final_dims(pt_att, (2, 0, 1))

        if(inplace_safe):
            a += pt_att
            del pt_att
            a += square_mask.unsqueeze(-3)
            # In-place softmax when the optional extension is importable.
            # The extension is CUDA-specific by name, so CPU tensors and
            # installs without it use the regular softmax instead of
            # failing on an undefined module name.
            kernel = self._get_inplace_softmax_kernel()
            if kernel is not None and a.is_cuda:
                kernel.forward_(
                    a,
                    reduce(mul, a.shape[:-1]),
                    a.shape[-1],
                )
            else:
                a = self.softmax(a)
        else:
            a = a + pt_att
            a = a + square_mask.unsqueeze(-3)
            a = self.softmax(a)

        ################
        # Compute output
        ################
        # [*, N_res, H, C_hidden]
        o = torch.matmul(
            a, v.transpose(-2, -3).to(dtype=a.dtype)
        ).transpose(-2, -3)

        # [*, N_res, H * C_hidden]
        o = flatten_final_dims(o, 2)

        # [*, H, 3, N_res, P_v]
        if(inplace_safe):
            v_pts = permute_final_dims(v_pts, (1, 3, 0, 2))
            # One matmul per coordinate instead of one large broadcast
            # product.
            o_pt = [
                torch.matmul(a, v_coord.to(a.dtype))
                for v_coord in torch.unbind(v_pts, dim=-3)
            ]
            o_pt = torch.stack(o_pt, dim=-3)
        else:
            o_pt = torch.sum(
                (
                    a[..., None, :, :, None]
                    * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
                ),
                dim=-2,
            )

        # [*, N_res, H, P_v, 3]
        o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
        o_pt = r[..., None, None].invert_apply(o_pt)

        # [*, N_res, H * P_v]
        o_pt_norm = flatten_final_dims(
            torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
        )

        # [*, N_res, H * P_v, 3]
        o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)

        if(_offload_inference):
            # Bring the pair representation back from the CPU.
            z[0] = z[0].to(o_pt.device)

        # [*, N_res, H, C_z]
        o_pair = torch.matmul(a.transpose(-2, -3), z[0].to(dtype=a.dtype))

        # [*, N_res, H * C_z]
        o_pair = flatten_final_dims(o_pair, 2)

        # [*, N_res, C_s]
        s = self.linear_out(
            torch.cat(
                (o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
            ).to(dtype=z[0].dtype)
        )

        return s
|
|
|
|
|
|
class BackboneUpdate(nn.Module):
    """
    Implements part of Algorithm 23.
    """

    def __init__(self, c_s):
        """
        Args:
            c_s:
                Single representation channel dimension
        """
        super(BackboneUpdate, self).__init__()

        self.c_s = c_s

        self.linear = Linear(self.c_s, 6, init="final")

    def forward(self, s: torch.Tensor) -> torch.Tensor:
        """
        Args:
            s:
                [*, N_res, C_s] single representation
        Returns:
            [*, N_res, 6] update vector (a single tensor, not a tuple)
        """
        # [*, 6]
        update = self.linear(s)

        return update
|
|
|
|
|
|
class StructureModuleTransitionLayer(nn.Module):
    """
    Three linear layers with ReLUs between them, wrapped in a skip
    connection. All layers map ``c`` channels to ``c`` channels.
    """

    def __init__(self, c):
        """
        Args:
            c:
                Channel dimension of the input and of every linear layer
        """
        super().__init__()

        self.c = c

        self.linear_1 = Linear(self.c, self.c, init="relu")
        self.linear_2 = Linear(self.c, self.c, init="relu")
        self.linear_3 = Linear(self.c, self.c, init="final")

        self.relu = nn.ReLU()

    def forward(self, s):
        """
        Args:
            s:
                Input tensor whose last dimension has size ``c``
        Returns:
            Tensor of the same shape: the MLP output plus the input
        """
        update = self.relu(self.linear_1(s))
        update = self.relu(self.linear_2(update))
        update = self.linear_3(update)

        # Skip connection around the three-layer MLP.
        return update + s
|
|
|
|
|
|
class StructureModuleTransition(nn.Module):
    """
    Stack of StructureModuleTransitionLayer blocks followed by dropout and
    a layer norm.
    """

    def __init__(self, c, num_layers, dropout_rate):
        """
        Args:
            c:
                Channel dimension
            num_layers:
                Number of transition layers in the stack
            dropout_rate:
                Dropout probability applied after the stack
        """
        super().__init__()

        self.c = c
        self.num_layers = num_layers
        self.dropout_rate = dropout_rate

        self.layers = nn.ModuleList(
            [
                StructureModuleTransitionLayer(self.c)
                for _ in range(self.num_layers)
            ]
        )

        self.dropout = nn.Dropout(self.dropout_rate)
        self.layer_norm = LayerNorm(self.c)

    def forward(self, s):
        """
        Args:
            s:
                Input tensor whose last dimension has size ``c``
        Returns:
            Output of the layer stack after dropout and layer norm
        """
        for layer in self.layers:
            s = layer(s)

        # Dropout first, then normalization.
        return self.layer_norm(self.dropout(s))
|
|
|
|
|
|
class StructureModule(nn.Module):
    """
    Iteratively refines per-residue backbone frames with invariant point
    attention and emits frames, torsion angles and atom14 positions for
    every block.
    """

    def __init__(
        self,
        c_s,
        c_z,
        c_ipa,
        c_resnet,
        no_heads_ipa,
        no_qk_points,
        no_v_points,
        dropout_rate,
        no_blocks,
        no_transition_layers,
        no_resnet_blocks,
        no_angles,
        trans_scale_factor,
        epsilon,
        inf,
        **kwargs,
    ):
        """
        Args:
            c_s:
                Single representation channel dimension
            c_z:
                Pair representation channel dimension
            c_ipa:
                IPA hidden channel dimension
            c_resnet:
                Angle resnet hidden channel dimension
            no_heads_ipa:
                Number of IPA heads
            no_qk_points:
                Number of query/key points to generate during IPA
            no_v_points:
                Number of value points to generate during IPA
            dropout_rate:
                Dropout rate used throughout the layer
            no_blocks:
                Number of structure module blocks
            no_transition_layers:
                Number of layers in the single representation transition
            no_resnet_blocks:
                Number of blocks in the angle resnet
            no_angles:
                Number of angles to generate in the angle resnet
            trans_scale_factor:
                Factor passed to ``scale_translation`` on the predicted
                frames before they are used or returned
            epsilon:
                Small number passed to both the IPA and the angle resnet
            inf:
                Large number used for attention masking
            **kwargs:
                Accepted and ignored
        """
        super().__init__()

        self.c_s = c_s
        self.c_z = c_z
        self.c_ipa = c_ipa
        self.c_resnet = c_resnet
        self.no_heads_ipa = no_heads_ipa
        self.no_qk_points = no_qk_points
        self.no_v_points = no_v_points
        self.dropout_rate = dropout_rate
        self.no_blocks = no_blocks
        self.no_transition_layers = no_transition_layers
        self.no_resnet_blocks = no_resnet_blocks
        self.no_angles = no_angles
        self.trans_scale_factor = trans_scale_factor
        self.epsilon = epsilon
        self.inf = inf

        # Buffers registered lazily by _init_residue_constants:
        #   self.default_frames
        #   self.group_idx
        #   self.atom_mask
        #   self.lit_positions

        self.layer_norm_s = LayerNorm(self.c_s)
        self.layer_norm_z = LayerNorm(self.c_z)

        self.linear_in = Linear(self.c_s, self.c_s)

        self.ipa = InvariantPointAttention(
            self.c_s,
            self.c_z,
            self.c_ipa,
            self.no_heads_ipa,
            self.no_qk_points,
            self.no_v_points,
            inf=self.inf,
            eps=self.epsilon,
        )

        self.ipa_dropout = nn.Dropout(self.dropout_rate)
        self.layer_norm_ipa = LayerNorm(self.c_s)

        self.transition = StructureModuleTransition(
            self.c_s,
            self.no_transition_layers,
            self.dropout_rate,
        )

        self.bb_update = BackboneUpdate(self.c_s)

        self.angle_resnet = AngleResnet(
            self.c_s,
            self.c_resnet,
            self.no_resnet_blocks,
            self.no_angles,
            self.epsilon,
        )

    def forward(
        self,
        evoformer_output_dict,
        aatype,
        mask=None,
        inplace_safe=False,
        _offload_inference=False,
    ):
        """
        Args:
            evoformer_output_dict:
                Dictionary containing:
                    "single":
                        [*, N_res, C_s] single representation
                    "pair":
                        [*, N_res, N_res, C_z] pair representation
            aatype:
                [*, N_res] amino acid indices
            mask:
                Optional [*, N_res] sequence mask; defaults to all ones
            inplace_safe:
                Forwarded to the IPA module
            _offload_inference:
                If True, the "pair" entry of the input dictionary is moved
                to the CPU for the duration of the call and moved back to
                the device of the single representation at the end
        Returns:
            A dictionary of outputs: the per-block predictions stacked with
            torch.stack, plus the final single representation under
            "single"
        """
        s = evoformer_output_dict["single"]

        if mask is None:
            # [*, N]
            mask = s.new_ones(s.shape[:-1])

        # [*, N, C_s]
        s = self.layer_norm_s(s)

        # [*, N, N, C_z]
        z = self.layer_norm_z(evoformer_output_dict["pair"])

        z_reference_list = None
        if _offload_inference:
            # The dictionary must hold the only reference to the pair tensor
            # (the second count is the getrefcount argument itself).
            assert(sys.getrefcount(evoformer_output_dict["pair"]) == 2)
            evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
            # Hand the normalized pair tensor over via a list and drop the
            # local name so the list holds the only reference.
            z_reference_list = [z]
            z = None

        # [*, N, C_s]
        s_initial = s
        s = self.linear_in(s)

        # [*, N]
        rigids = Rigid.identity(
            s.shape[:-1],
            s.dtype,
            s.device,
            self.training,
            fmt="quat",
        )

        block_outputs = []
        for _ in range(self.no_blocks):
            # [*, N, C_s]
            ipa_update = self.ipa(
                s,
                z,
                rigids,
                mask,
                inplace_safe=inplace_safe,
                _offload_inference=_offload_inference,
                _z_reference_list=z_reference_list,
            )
            s = s + ipa_update
            s = self.layer_norm_ipa(self.ipa_dropout(s))
            s = self.transition(s)

            # [*, N]
            rigids = rigids.compose_q_update_vec(self.bb_update(s))

            # To hew as closely as possible to AlphaFold, the
            # quaternion-based transformations are converted to
            # rotation-matrix ones here.
            rot_mats = rigids.get_rots().get_rot_mats()
            backb_to_global = Rigid(
                Rotation(rot_mats=rot_mats, quats=None),
                rigids.get_trans(),
            ).scale_translation(self.trans_scale_factor)

            # [*, N, 7, 2]
            unnormalized_angles, angles = self.angle_resnet(s, s_initial)

            all_frames_to_global = self.torsion_angles_to_frames(
                backb_to_global,
                angles,
                aatype,
            )

            pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
                all_frames_to_global,
                aatype,
            )

            scaled_rigids = rigids.scale_translation(self.trans_scale_factor)

            block_outputs.append(
                {
                    "frames": scaled_rigids.to_tensor_7(),
                    "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
                    "unnormalized_angles": unnormalized_angles,
                    "angles": angles,
                    "positions": pred_xyz,
                    "states": s,
                }
            )

            rigids = rigids.stop_rot_gradient()

        del z, z_reference_list

        if _offload_inference:
            evoformer_output_dict["pair"] = (
                evoformer_output_dict["pair"].to(s.device)
            )

        outputs = dict_multimap(torch.stack, block_outputs)
        outputs["single"] = s

        return outputs

    def _init_residue_constants(self, float_dtype, device):
        """
        Register the residue-constant buffers that do not exist yet.

        Args:
            float_dtype:
                dtype used for the floating-point buffers
            device:
                Device on which the buffers are created
        """
        # (buffer name, source constant, dtype); None leaves the dtype to
        # torch.tensor's inference, as for the index table.
        buffer_specs = (
            ("default_frames", restype_rigid_group_default_frame, float_dtype),
            ("group_idx", restype_atom14_to_rigid_group, None),
            ("atom_mask", restype_atom14_mask, float_dtype),
            ("lit_positions", restype_atom14_rigid_group_positions, float_dtype),
        )
        for buffer_name, values, dtype in buffer_specs:
            if hasattr(self, buffer_name):
                continue
            self.register_buffer(
                buffer_name,
                torch.tensor(
                    values,
                    dtype=dtype,
                    device=device,
                    requires_grad=False,
                ),
                persistent=False,
            )

    def torsion_angles_to_frames(self, r, alpha, f):
        """
        Thin wrapper around the module-level torsion_angles_to_frames that
        supplies the lazily initialized default frames.
        """
        # Lazily initialize the residue constants on the correct device
        self._init_residue_constants(alpha.dtype, alpha.device)
        # Kept as a separate method to make testing easier
        return torsion_angles_to_frames(r, alpha, f, self.default_frames)

    def frames_and_literature_positions_to_atom14_pos(
        self, r, f  # [*, N, 8]  # [*, N]
    ):
        """
        Thin wrapper around the module-level function of the same name that
        supplies the lazily initialized residue-constant buffers.
        """
        # Lazily initialize the residue constants on the correct device
        rots = r.get_rots()
        self._init_residue_constants(rots.dtype, rots.device)
        return frames_and_literature_positions_to_atom14_pos(
            r,
            f,
            self.default_frames,
            self.group_idx,
            self.atom_mask,
            self.lit_positions,
        )
|