# Mirror of https://github.com/AngxiaoYue/ReQFlow.git
# (synced 2026-06-04 20:24:22 +08:00; 72 lines, 2.6 KiB, Python)
import torch
|
|
from torch import nn
|
|
|
|
from models.utils import get_index_embedding, calc_distogram
|
|
|
|
class EdgeFeatureNet(nn.Module):
    """Build pair (edge) embeddings from per-residue features and coordinates.

    The edge features concatenated along the last dimension are, in order:

    * cross-concatenated projected node features (``2 * feat_dim`` channels),
    * a linear projection of the relative-position embedding (``feat_dim``),
    * a distogram of ``t`` and a distogram of ``sc_t`` (``num_bins`` each),
    * optionally a same-chain indicator (1 channel, ``cfg.embed_chain``),
    * optionally the cross-concatenated diffuse mask (2 channels,
      ``cfg.embed_diffuse_mask``).

    The concatenation is passed through a three-layer MLP followed by
    ``LayerNorm`` to produce ``c_p`` channels per residue pair.

    Config fields read from ``module_cfg``: ``c_s``, ``c_p``, ``feat_dim``,
    ``num_bins``, ``embed_chain``, ``embed_diffuse_mask``.
    """

    def __init__(self, module_cfg):
        """Create the projection layers and the edge MLP from ``module_cfg``."""
        super().__init__()
        self._cfg = module_cfg

        self.c_s = self._cfg.c_s
        self.c_p = self._cfg.c_p
        self.feat_dim = self._cfg.feat_dim

        self.linear_s_p = nn.Linear(self.c_s, self.feat_dim)
        self.linear_relpos = nn.Linear(self.feat_dim, self.feat_dim)

        # 2 * feat_dim (cross node feats) + feat_dim (relpos) + 2 distograms.
        total_edge_feats = self.feat_dim * 3 + self._cfg.num_bins * 2
        if self._cfg.embed_chain:
            # One channel for the same-chain indicator built in forward().
            total_edge_feats += 1
        if self._cfg.embed_diffuse_mask:
            # Two channels: the diffuse mask of residue i and of residue j.
            total_edge_feats += 2
        self.edge_embedder = nn.Sequential(
            nn.Linear(total_edge_feats, self.c_p),
            nn.ReLU(),
            nn.Linear(self.c_p, self.c_p),
            nn.ReLU(),
            nn.Linear(self.c_p, self.c_p),
            nn.LayerNorm(self.c_p),
        )

    def embed_relpos(self, r):
        """Embed pairwise index offsets and project them with ``linear_relpos``.

        Args:
            r: residue indices of shape ``[b, n_res]``.

        Returns:
            ``linear_relpos`` applied to the index embedding of the pairwise
            offsets ``r_i - r_j`` (presumably ``[b, n_res, n_res, feat_dim]``;
            the exact layout is determined by ``get_index_embedding``).
        """
        # AlphaFold 2 Algorithm 4 & 5
        # Based on OpenFold utils/tensor_utils.py
        # Input: [b, n_res]
        # [b, n_res, n_res]
        d = r[:, :, None] - r[:, None, :]
        pos_emb = get_index_embedding(d, self._cfg.feat_dim, max_len=2056)
        return self.linear_relpos(pos_emb)

    def _cross_concat(self, feats_1d, num_batch, num_res):
        """Pair up per-residue features as ``[feats_i, feats_j]``.

        Args:
            feats_1d: tensor of shape ``[num_batch, num_res, c]``.
            num_batch: batch size.
            num_res: number of residues.

        Returns:
            Float tensor of shape ``[num_batch, num_res, num_res, 2 * c]``
            whose entry ``[b, i, j]`` is ``feats_1d[b, i]`` followed by
            ``feats_1d[b, j]``.
        """
        # expand() creates broadcast views; torch.cat materialises the result,
        # so this matches tiling without the two intermediate copies.
        row_feats = feats_1d[:, :, None, :].expand(-1, -1, num_res, -1)
        col_feats = feats_1d[:, None, :, :].expand(-1, num_res, -1, -1)
        paired = torch.cat([row_feats, col_feats], dim=-1).float()
        return paired.reshape([num_batch, num_res, num_res, -1])

    @staticmethod
    def _same_chain_feats(chain_index):
        """Return a ``[b, n_res, n_res, 1]`` float indicator of same-chain pairs.

        Args:
            chain_index: per-residue chain identifiers of shape ``[b, n_res]``.
        """
        same_chain = chain_index[:, :, None] == chain_index[:, None, :]
        return same_chain.float()[..., None]

    def forward(self, s, t, sc_t, p_mask, diffuse_mask, chain_index=None):
        """Compute masked edge embeddings.

        Args:
            s: node features of shape ``[b, n_res, c_s]``.
            t: input to ``calc_distogram`` (presumably per-residue
                translations ``[b, n_res, 3]`` -- confirm against the helper).
            sc_t: self-conditioning counterpart of ``t``, binned the same way.
            p_mask: pair mask multiplied onto the output; it is unsqueezed on
                the last axis, so ``[b, n_res, n_res]`` is expected.
            diffuse_mask: per-residue mask of shape ``[b, n_res]``; only used
                when ``cfg.embed_diffuse_mask`` is set.
            chain_index: per-residue chain identifiers of shape ``[b, n_res]``.
                Required when ``cfg.embed_chain`` is set, ignored otherwise.

        Returns:
            Edge embeddings with ``c_p`` channels, multiplied by ``p_mask``.

        Raises:
            ValueError: if ``cfg.embed_chain`` is set and ``chain_index`` is
                not provided.
        """
        # The constructor reserves an input channel for the chain feature when
        # embed_chain is on; fail early with a clear message instead of a
        # shape mismatch inside the first linear layer.
        if self._cfg.embed_chain and chain_index is None:
            raise ValueError(
                "chain_index must be provided when cfg.embed_chain is enabled")

        # Input: [b, n_res, c_s]
        num_batch, num_res, _ = s.shape

        # [b, n_res, feat_dim]
        p_i = self.linear_s_p(s)
        cross_node_feats = self._cross_concat(p_i, num_batch, num_res)

        # [b, n_res]
        r = torch.arange(
            num_res, device=s.device).unsqueeze(0).repeat(num_batch, 1)
        relpos_feats = self.embed_relpos(r)

        dist_feats = calc_distogram(
            t, min_bin=1e-3, max_bin=20.0, num_bins=self._cfg.num_bins)
        sc_feats = calc_distogram(
            sc_t, min_bin=1e-3, max_bin=20.0, num_bins=self._cfg.num_bins)

        all_edge_feats = [cross_node_feats, relpos_feats, dist_feats, sc_feats]
        if self._cfg.embed_chain:
            all_edge_feats.append(self._same_chain_feats(chain_index))
        if self._cfg.embed_diffuse_mask:
            diff_feat = self._cross_concat(
                diffuse_mask[..., None], num_batch, num_res)
            all_edge_feats.append(diff_feat)
        edge_feats = self.edge_embedder(torch.cat(all_edge_feats, dim=-1))
        # Zero out embeddings of masked residue pairs.
        edge_feats *= p_mask.unsqueeze(-1)
        return edge_feats