mirror of
https://github.com/AngxiaoYue/ReQFlow.git
synced 2026-06-04 20:24:22 +08:00
45 lines
1.5 KiB
Python
45 lines
1.5 KiB
Python
import torch
|
|
from torch import nn
|
|
from models.utils import get_index_embedding, get_time_embedding
|
|
|
|
|
|
class NodeFeatureNet(nn.Module):
|
|
|
|
def __init__(self, module_cfg):
|
|
super(NodeFeatureNet, self).__init__()
|
|
self._cfg = module_cfg
|
|
self.c_s = self._cfg.c_s
|
|
self.c_pos_emb = self._cfg.c_pos_emb
|
|
self.c_timestep_emb = self._cfg.c_timestep_emb
|
|
embed_size = self._cfg.c_pos_emb + self._cfg.c_timestep_emb * 2 + 1
|
|
if self._cfg.embed_chain:
|
|
embed_size += self._cfg.c_pos_emb
|
|
self.linear = nn.Linear(embed_size, self.c_s)
|
|
|
|
def embed_t(self, timesteps, mask):
|
|
timestep_emb = get_time_embedding(
|
|
timesteps[:, 0],
|
|
self.c_timestep_emb,
|
|
max_positions=2056
|
|
)[:, None, :].repeat(1, mask.shape[1], 1)
|
|
return timestep_emb * mask.unsqueeze(-1)
|
|
|
|
def forward(self, so3_t, r3_t, res_mask, diffuse_mask, pos):
|
|
# s: [b]
|
|
|
|
b, num_res, device = res_mask.shape[0], res_mask.shape[1], res_mask.device
|
|
|
|
# [b, n_res, c_pos_emb]
|
|
# pos = torch.arange(num_res, dtype=torch.float32).to(device)[None]
|
|
pos_emb = get_index_embedding(pos, self.c_pos_emb, max_len=2056)
|
|
pos_emb = pos_emb * res_mask.unsqueeze(-1)
|
|
|
|
# [b, n_res, c_timestep_emb]
|
|
input_feats = [
|
|
pos_emb,
|
|
diffuse_mask[..., None],
|
|
self.embed_t(so3_t, res_mask),
|
|
self.embed_t(r3_t, res_mask)
|
|
]
|
|
return self.linear(torch.cat(input_feats, dim=-1))
|