mirror of
https://github.com/AngxiaoYue/ReQFlow.git
synced 2026-06-04 12:14:23 +08:00
123 lines
5.0 KiB
Python
123 lines
5.0 KiB
Python
|
|
import torch
|
|
from torch import nn
|
|
|
|
from models.node_feature_net import NodeFeatureNet
|
|
from models.edge_feature_net import EdgeFeatureNet
|
|
from models import ipa_pytorch
|
|
from data import utils as du
|
|
|
|
|
|
class FlowModel(nn.Module):
    """Trunk network that iteratively updates backbone rigid frames.

    The network first builds node embeddings (``NodeFeatureNet``) and edge
    embeddings (``EdgeFeatureNet``) from the noised inputs. It then runs
    ``model_conf.ipa.num_blocks`` trunk blocks. Each block applies:

    - Invariant Point Attention followed by a LayerNorm,
    - a sequence transformer with a residual projection,
    - a node transition,
    - a backbone update that is composed into the current rigids,
    - an edge transition, on every block except the last.

    Args:
        model_conf: Model configuration. Fields read here are ``ipa``
            (``num_blocks``, ``c_s``, ``seq_tfmr_num_heads`` and
            ``seq_tfmr_num_layers``; the whole ``ipa`` config is also passed
            to ``InvariantPointAttention``), ``node_features``,
            ``edge_features`` and ``edge_embed_size``.
    """

    def __init__(self, model_conf):
        super().__init__()
        self._model_conf = model_conf
        self._ipa_conf = model_conf.ipa
        self.node_feature_net = NodeFeatureNet(model_conf.node_features)
        self.edge_feature_net = EdgeFeatureNet(model_conf.edge_features)

        # Attention trunk. Sub-modules are registered in a fixed order so that
        # random parameter initialisation consumes the RNG deterministically.
        ipa_conf = self._ipa_conf
        self.trunk = nn.ModuleDict()
        for b in range(ipa_conf.num_blocks):
            self.trunk[f'ipa_{b}'] = ipa_pytorch.InvariantPointAttention(ipa_conf)
            self.trunk[f'ipa_ln_{b}'] = nn.LayerNorm(ipa_conf.c_s)
            tfmr_in = ipa_conf.c_s
            tfmr_layer = torch.nn.TransformerEncoderLayer(
                d_model=tfmr_in,
                nhead=ipa_conf.seq_tfmr_num_heads,
                dim_feedforward=tfmr_in,
                batch_first=True,
                dropout=0.0,
                norm_first=False,
            )
            self.trunk[f'seq_tfmr_{b}'] = torch.nn.TransformerEncoder(
                tfmr_layer, ipa_conf.seq_tfmr_num_layers, enable_nested_tensor=False)
            self.trunk[f'post_tfmr_{b}'] = ipa_pytorch.Linear(
                tfmr_in, ipa_conf.c_s, init="final")
            self.trunk[f'node_transition_{b}'] = ipa_pytorch.StructureModuleTransition(
                c=ipa_conf.c_s)
            self.trunk[f'bb_update_{b}'] = ipa_pytorch.BackboneUpdate(
                ipa_conf.c_s, use_rot_updates=True)

            if b < ipa_conf.num_blocks - 1:
                # No edge update on the last block.
                edge_in = self._model_conf.edge_embed_size
                self.trunk[f'edge_transition_{b}'] = ipa_pytorch.EdgeTransition(
                    node_embed_size=ipa_conf.c_s,
                    edge_embed_in=edge_in,
                    edge_embed_out=self._model_conf.edge_embed_size,
                )

    # Defined as methods rather than lambda instance attributes: lambdas stored
    # on the module cannot be pickled, which breaks ``torch.save(model)`` and
    # spawn-based multiprocessing. The call signature ``f(x)`` is unchanged.
    def rigids_ang_to_nm(self, x):
        """Return ``x`` with its translations multiplied by ``du.ANG_TO_NM_SCALE``.

        ``x`` is presumably an OpenFold-style ``Rigid`` exposing
        ``apply_trans_fn`` (TODO confirm).
        """
        return x.apply_trans_fn(lambda trans: trans * du.ANG_TO_NM_SCALE)

    def rigids_nm_to_ang(self, x):
        """Return ``x`` with its translations multiplied by ``du.NM_TO_ANG_SCALE``."""
        return x.apply_trans_fn(lambda trans: trans * du.NM_TO_ANG_SCALE)

    def forward(self, input_feats):
        """Run the trunk and return the predicted backbone frames.

        Args:
            input_feats: Mapping with keys ``res_mask``, ``diffuse_mask``,
                ``res_idx``, ``so3_t``, ``r3_t``, ``trans_t`` and
                ``rotquats_t``. The optional key ``trans_sc`` holds
                self-conditioning translations; ``zeros_like(trans_t)`` is
                used when it is absent. ``res_mask`` is presumably a
                (batch, num_res) 0/1 tensor (TODO confirm).

        Returns:
            dict with ``pred_trans``, ``pred_rotmats`` and ``pred_rotquats``
            taken from the final rigids, after their translations are
            rescaled by ``du.NM_TO_ANG_SCALE``.
        """
        node_mask = input_feats['res_mask']
        edge_mask = node_mask[:, None] * node_mask[:, :, None]
        diffuse_mask = input_feats['diffuse_mask']
        res_index = input_feats['res_idx']
        so3_t = input_feats['so3_t']
        r3_t = input_feats['r3_t']
        trans_t = input_feats['trans_t']
        rotquats_t = input_feats['rotquats_t']

        # Loop-invariant mask tensors, built once rather than once per block.
        node_mask_exp = node_mask[..., None]
        edge_mask_exp = edge_mask[..., None]
        # True marks padded residues that the sequence transformer must ignore.
        padding_mask = (1 - node_mask).to(torch.bool)
        # Only residues that are both present and diffused receive rigid updates.
        update_mask = (node_mask * diffuse_mask)[..., None]

        # Initialize node and edge embeddings.
        init_node_embed = self.node_feature_net(
            so3_t,
            r3_t,
            node_mask,
            diffuse_mask,
            res_index,
        )
        if 'trans_sc' in input_feats:
            trans_sc = input_feats['trans_sc']
        else:
            trans_sc = torch.zeros_like(trans_t)
        init_edge_embed = self.edge_feature_net(
            init_node_embed,
            trans_t,
            trans_sc,
            edge_mask,
            diffuse_mask,
        )

        # Initial rigids; the trunk operates on nm-scaled translations.
        curr_rigids = self.rigids_ang_to_nm(
            du.create_rigid_quats(rotquats_t, trans_t))

        # The node embedding is masked twice, exactly as before, to keep the
        # numerics identical. The second multiply is a no-op for 0/1 masks.
        node_embed = init_node_embed * node_mask_exp * node_mask_exp
        edge_embed = init_edge_embed * edge_mask_exp

        num_blocks = self._ipa_conf.num_blocks
        for b in range(num_blocks):
            ipa_embed = self.trunk[f'ipa_{b}'](
                node_embed,
                edge_embed,
                curr_rigids,
                node_mask)
            # Out-of-place so the sub-module's output tensor is never mutated.
            ipa_embed = ipa_embed * node_mask_exp
            node_embed = self.trunk[f'ipa_ln_{b}'](node_embed + ipa_embed)

            seq_tfmr_out = self.trunk[f'seq_tfmr_{b}'](
                node_embed, src_key_padding_mask=padding_mask)
            node_embed = node_embed + self.trunk[f'post_tfmr_{b}'](seq_tfmr_out)
            node_embed = self.trunk[f'node_transition_{b}'](node_embed)
            node_embed = node_embed * node_mask_exp

            rigid_update = self.trunk[f'bb_update_{b}'](node_embed * node_mask_exp)
            curr_rigids = curr_rigids.compose_q_update_vec(rigid_update, update_mask)

            if b < num_blocks - 1:
                # The last block has no edge transition, matching __init__.
                edge_embed = self.trunk[f'edge_transition_{b}'](node_embed, edge_embed)
                edge_embed = edge_embed * edge_mask_exp

        curr_rigids = self.rigids_nm_to_ang(curr_rigids)
        pred_rots = curr_rigids.get_rots()
        return {
            'pred_trans': curr_rigids.get_trans(),
            'pred_rotmats': pred_rots.get_rot_mats(),
            'pred_rotquats': pred_rots.get_quats(),
        }
|