mirror of
https://github.com/Saoge123/PocketFlow.git
synced 2026-06-04 12:44:22 +08:00
Delete pocket_flow/gdbp_model/.ipynb_checkpoints directory
This commit is contained in:
@@ -1,11 +0,0 @@
|
||||
from .atom_flow import AtomFlow
|
||||
from .bond_predictor import BondPredictor, BondFlow, BondFlowNew
|
||||
from .position_predictor import PositionPredictor
|
||||
from .focal_net import FrontierLayerVN
|
||||
from .encoder import ContextEncoder
|
||||
from .pos_filter import PositionEncoder, PositionFilter
|
||||
from .layers import *
|
||||
from .net_utils import *
|
||||
from .pocket_flow_with_edge import *
|
||||
from .pocket_flow_with_edge_new import *
|
||||
from .pocket_flow import *
|
||||
@@ -1,52 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from .layers import GBPerceptronVN, GBLinear, ST_GBP_Exp
|
||||
|
||||
|
||||
|
||||
class AtomFlow(nn.Module):
|
||||
def __init__(self, in_sca, in_vec, hidden_dim_sca, hidden_dim_vec, num_lig_atom_type=10,
|
||||
num_flow_layers=6, bottleneck=1, use_conv1d=False) -> None:
|
||||
super(AtomFlow, self).__init__()
|
||||
'''self.msg_module = MessageModule(
|
||||
in_sca, in_vec, in_sca, in_vec, hidden_dim_sca, hidden_dim_vec, cutoff=10.
|
||||
)'''
|
||||
self.net = nn.Sequential(
|
||||
GBPerceptronVN(
|
||||
in_sca, in_vec, hidden_dim_sca, hidden_dim_vec, bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
),
|
||||
GBLinear(
|
||||
hidden_dim_sca, hidden_dim_vec, hidden_dim_sca, hidden_dim_vec, bottleneck=bottleneck,
|
||||
use_conv1d=use_conv1d
|
||||
)
|
||||
)
|
||||
|
||||
self.flow_layers = nn.ModuleList()
|
||||
for _ in enumerate(range(num_flow_layers)):
|
||||
layer = ST_GBP_Exp(
|
||||
hidden_dim_sca, hidden_dim_vec, num_lig_atom_type, hidden_dim_vec, bottleneck=bottleneck,
|
||||
use_conv1d=use_conv1d
|
||||
)
|
||||
self.flow_layers.append(layer)
|
||||
|
||||
def forward(self, z_atom, compose_features, focal_idx):
|
||||
sca_focal, vec_focal = compose_features[0][focal_idx], compose_features[1][focal_idx]
|
||||
sca_focal, vec_focal = self.net([sca_focal, vec_focal])
|
||||
for ix in range(len(self.flow_layers)):
|
||||
s, t = self.flow_layers[ix]([sca_focal, vec_focal])
|
||||
s = s.exp()
|
||||
z_atom = (z_atom + t) * s
|
||||
if ix == 0:
|
||||
atom_log_jacob = (torch.abs(s) + 1e-20).log()
|
||||
else:
|
||||
atom_log_jacob += (torch.abs(s) + 1e-20).log()
|
||||
return z_atom, atom_log_jacob
|
||||
|
||||
def reverse(self, atom_latent, compose_features, focal_idx):
|
||||
sca_focal, vec_focal = compose_features[0][focal_idx], compose_features[1][focal_idx]
|
||||
sca_focal, vec_focal = self.net([sca_focal, vec_focal])
|
||||
for ix in range(len(self.flow_layers)):
|
||||
s, t = self.flow_layers[ix]([sca_focal, vec_focal])
|
||||
atom_latent = (atom_latent / s.exp()) - t
|
||||
return atom_latent
|
||||
@@ -1,440 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import Module, Sequential
|
||||
from torch.nn import functional as F
|
||||
from torch_scatter import scatter_add
|
||||
#from .pos_filter import PositionEncoder
|
||||
from .layers import GBPerceptronVN, GBLinear, MessageModule, MessageAttention, AttentionEdges, ST_GBP_Exp, VNLeakyReLU
|
||||
from .net_utils import GaussianSmearing, EdgeExpansion
|
||||
import math
|
||||
|
||||
|
||||
GAUSSIAN_COEF = 1.0 / math.sqrt(2 * math.pi)
|
||||
|
||||
class BondPredictor(Module):
|
||||
def __init__(self, in_sca, in_vec, edge_channels, num_filters, num_bond_types,
|
||||
num_heads=4, cutoff=10.0, with_root=True, bottleneck=1):
|
||||
super(BondPredictor, self).__init__()
|
||||
self.with_root = with_root
|
||||
self.num_bond_types = num_bond_types
|
||||
self.message_module = MessageModule(
|
||||
in_sca, in_vec, edge_channels, edge_channels, num_filters[0], num_filters[1], cutoff=cutoff,
|
||||
bottleneck=bottleneck
|
||||
)
|
||||
self.nn_edge_ij = Sequential(
|
||||
GBPerceptronVN(edge_channels, edge_channels, num_filters[0], num_filters[1], bottleneck=bottleneck),
|
||||
GBLinear(num_filters[0], num_filters[1], num_filters[0], num_filters[1], bottleneck=bottleneck)
|
||||
)
|
||||
|
||||
self.edge_feat = Sequential(
|
||||
GBPerceptronVN(num_filters[0] * 2 + in_sca, num_filters[1] * 2 + in_vec, num_filters[0], num_filters[1],
|
||||
bottleneck=bottleneck),
|
||||
GBLinear(num_filters[0], num_filters[1], num_filters[0], num_filters[1], bottleneck=bottleneck)
|
||||
)
|
||||
self.edge_atten = AttentionEdges(
|
||||
num_filters, num_filters, num_heads, num_bond_types, bottleneck=bottleneck
|
||||
)
|
||||
self.edge_pred = GBLinear(num_filters[0], num_filters[1], num_bond_types + 1, 1, bottleneck=bottleneck)
|
||||
|
||||
if with_root:
|
||||
self.root_lin = GBLinear(in_sca, in_vec, num_filters[0], num_filters[1], bottleneck=bottleneck)
|
||||
self.root_vector_expansion = EdgeExpansion(edge_channels)
|
||||
|
||||
self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_channels)
|
||||
self.distance_expansion_3A = GaussianSmearing(stop=3., num_gaussians=edge_channels)
|
||||
self.vector_expansion = EdgeExpansion(edge_channels) # Linear(in_features=1, out_features=edge_channels, bias=False)
|
||||
|
||||
def forward(self, pos_query, edge_index_query, cpx_pos, node_attr_compose, edge_index_q_cps_knn,
|
||||
index_real_cps_edge_for_atten=[], tri_edge_index=[], tri_edge_feat=[], atom_type_emb=None):
|
||||
|
||||
vec_ij = pos_query[edge_index_q_cps_knn[0]] - cpx_pos[edge_index_q_cps_knn[1]]
|
||||
dist_ij = torch.norm(vec_ij, p=2, dim=-1).view(-1, 1) # (A, 1)
|
||||
edge_ij = self.distance_expansion(dist_ij), self.vector_expansion(vec_ij)
|
||||
|
||||
# node_attr_ctx_j = [node_attr_ctx_[edge_index_q_cps_knn[1]] for node_attr_ctx_ in node_attr_ctx] # (A, H)
|
||||
h = self.message_module(node_attr_compose, edge_ij, edge_index_q_cps_knn[1], dist_ij, annealing=True)
|
||||
# Aggregate messages
|
||||
y = [scatter_add(h[0], index=edge_index_q_cps_knn[0], dim=0, dim_size=pos_query.size(0)), # (N_query, F)
|
||||
scatter_add(h[1], index=edge_index_q_cps_knn[0], dim=0, dim_size=pos_query.size(0))]
|
||||
# add information of new atom
|
||||
if isinstance(atom_type_emb, torch.Tensor):
|
||||
root_vec_ij = self.root_vector_expansion(pos_query)
|
||||
y_root_sca, y_root_vec = self.root_lin([atom_type_emb, root_vec_ij])
|
||||
y = [y_root_sca+y[0], y_root_vec+y[1]]
|
||||
|
||||
if (len(edge_index_query) != 0) and (edge_index_query.size(1) > 0):
|
||||
# print(edge_index_query.shape)
|
||||
idx_node_i = edge_index_query[0]
|
||||
node_mol_i = [
|
||||
y[0][idx_node_i],
|
||||
y[1][idx_node_i]
|
||||
]
|
||||
idx_node_j = edge_index_query[1]
|
||||
node_mol_j = [
|
||||
node_attr_compose[0][idx_node_j],
|
||||
node_attr_compose[1][idx_node_j]
|
||||
]
|
||||
vec_ij = pos_query[idx_node_i] - cpx_pos[idx_node_j]
|
||||
dist_ij = torch.norm(vec_ij, p=2, dim=-1).view(-1, 1) # (E, 1)
|
||||
|
||||
edge_ij = self.distance_expansion_3A(dist_ij), self.vector_expansion(vec_ij)
|
||||
edge_feat = self.nn_edge_ij(edge_ij) # (E, F)
|
||||
|
||||
edge_attr = (torch.cat([node_mol_i[0], node_mol_j[0], edge_feat[0]], dim=-1), # (E, F)
|
||||
torch.cat([node_mol_i[1], node_mol_j[1], edge_feat[1]], dim=1))
|
||||
edge_attr = self.edge_feat(edge_attr) # (E, N_edgetype)
|
||||
edge_attr = self.edge_atten(edge_attr, edge_index_query, cpx_pos, index_real_cps_edge_for_atten, tri_edge_index, tri_edge_feat)
|
||||
edge_pred, _ = self.edge_pred(edge_attr)
|
||||
else:
|
||||
edge_pred = torch.empty([0, self.num_bond_types+1], device=pos_query.device)
|
||||
|
||||
return edge_pred
|
||||
|
||||
|
||||
class ST_AttEdge_Exp(torch.nn.Module):
|
||||
def __init__(self, in_sca, in_vec, edge_channels, num_filters, num_bond_types=3,
|
||||
num_heads=4, bottleneck=1, use_conv1d=False):
|
||||
super(ST_AttEdge_Exp, self).__init__()
|
||||
self.num_bond_types = num_bond_types
|
||||
|
||||
self.nn_edge_ij = Sequential(
|
||||
GBPerceptronVN(
|
||||
edge_channels, edge_channels, num_filters[0], num_filters[1],
|
||||
bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
),
|
||||
GBLinear(
|
||||
num_filters[0], num_filters[1], num_filters[0], num_filters[1],
|
||||
bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
)
|
||||
|
||||
self.edge_feat = Sequential(
|
||||
GBPerceptronVN(
|
||||
num_filters[0] * 2 + in_sca, num_filters[1] * 2 + in_vec, num_filters[0],
|
||||
num_filters[1], bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
),
|
||||
GBLinear(
|
||||
num_filters[0], num_filters[1], num_filters[0], num_filters[1],
|
||||
bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
)
|
||||
self.edge_atten = AttentionEdges(num_filters, num_filters, num_heads, num_bond_types)
|
||||
self.edge_pred = GBLinear(
|
||||
num_filters[0], num_filters[1], (num_bond_types + 1) * 2, 1, bottleneck=bottleneck,
|
||||
use_conv1d=use_conv1d
|
||||
)
|
||||
|
||||
self.distance_expansion_3A = GaussianSmearing(stop=3., num_gaussians=edge_channels)
|
||||
self.vector_expansion = EdgeExpansion(edge_channels) # Linear(in_features=1, out_features=edge_channels, bias=False)
|
||||
|
||||
def forward(self, h_atom, pos_query, edge_index_query, cpx_pos, node_attr_compose, index_real_cps_edge_for_atten=[],
|
||||
tri_edge_index=[], tri_edge_feat=[]):
|
||||
|
||||
if (len(edge_index_query) != 0) and (edge_index_query.size(1) > 0):
|
||||
# print(edge_index_query.shape)
|
||||
idx_node_i = edge_index_query[0]
|
||||
node_mol_i = [
|
||||
h_atom[0][idx_node_i],
|
||||
h_atom[1][idx_node_i]
|
||||
]
|
||||
idx_node_j = edge_index_query[1]
|
||||
node_mol_j = [
|
||||
node_attr_compose[0][idx_node_j],
|
||||
node_attr_compose[1][idx_node_j]
|
||||
]
|
||||
vec_ij = pos_query[idx_node_i] - cpx_pos[idx_node_j]
|
||||
dist_ij = torch.norm(vec_ij, p=2, dim=-1).view(-1, 1) # (E, 1)
|
||||
|
||||
edge_ij = self.distance_expansion_3A(dist_ij), self.vector_expansion(vec_ij)
|
||||
edge_feat = self.nn_edge_ij(edge_ij) # (E, F)
|
||||
|
||||
edge_attr = (torch.cat([node_mol_i[0], node_mol_j[0], edge_feat[0]], dim=-1), # (E, F)
|
||||
torch.cat([node_mol_i[1], node_mol_j[1], edge_feat[1]], dim=1))
|
||||
edge_attr = self.edge_feat(edge_attr) # (E, N_edgetype)
|
||||
edge_attr = self.edge_atten(edge_attr, edge_index_query, cpx_pos, index_real_cps_edge_for_atten, tri_edge_index, tri_edge_feat)
|
||||
edge_pred, _ = self.edge_pred(edge_attr)
|
||||
s_edge, t_edge = edge_pred[:,:self.num_bond_types+1], edge_pred[:,self.num_bond_types+1:]
|
||||
else:
|
||||
s_edge = torch.empty([0, self.num_bond_types+1], device=pos_query.device)
|
||||
t_edge = torch.empty([0, self.num_bond_types+1], device=pos_query.device)
|
||||
return s_edge, t_edge
|
||||
|
||||
|
||||
|
||||
class BondFlow(torch.nn.Module):
|
||||
def __init__(self, in_sca, in_vec, edge_channels, num_filters, num_bond_types,
|
||||
num_heads=4, cutoff=10.0, with_root=True, num_st_layers=3,
|
||||
bottleneck=1, use_conv1d=False):
|
||||
super(BondFlow, self).__init__()
|
||||
self.with_root = with_root
|
||||
self.num_bond_types = num_bond_types
|
||||
self.num_st_layers = num_st_layers
|
||||
|
||||
self.pos_encoder = PositionEncoder(
|
||||
in_sca, in_vec, edge_channels, num_filters, cutoff=cutoff, bottleneck=bottleneck,
|
||||
use_conv1d=use_conv1d
|
||||
)
|
||||
self.pos_filter = torch.nn.Sequential(
|
||||
GBPerceptronVN(
|
||||
num_filters[0], num_filters[1], num_filters[0], num_filters[1],
|
||||
bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
),
|
||||
GBLinear(
|
||||
num_filters[0], num_filters[1], 1, 1, bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
)
|
||||
if with_root:
|
||||
self.root_lin = GBLinear(
|
||||
in_sca, in_vec, num_filters[0], num_filters[1], bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
self.root_vector_expansion = EdgeExpansion(edge_channels)
|
||||
|
||||
self.flow_layers = torch.nn.ModuleList()
|
||||
for _ in range(num_st_layers):
|
||||
flow_layer = ST_AttEdge_Exp(
|
||||
in_sca, in_vec, edge_channels,
|
||||
num_filters, num_bond_types=num_bond_types, num_heads=num_heads,
|
||||
bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
self.flow_layers.append(flow_layer)
|
||||
|
||||
def forward(self, z_edge, pos_query, edge_index_query, cpx_pos, node_attr_compose, edge_index_q_cps_knn,
|
||||
index_real_cps_edge_for_atten=[], tri_edge_index=[], tri_edge_feat=[], atom_type_emb=None,
|
||||
annealing=False):
|
||||
y = self.pos_encoder(
|
||||
pos_query, edge_index_q_cps_knn, cpx_pos, node_attr_compose, atom_type_emb,
|
||||
annealing=annealing
|
||||
)
|
||||
'''if isinstance(atom_type_emb, torch.Tensor) and self.with_root:
|
||||
root_vec_ij = self.root_vector_expansion(pos_query)
|
||||
y_root_sca, y_root_vec = self.root_lin([atom_type_emb, root_vec_ij])
|
||||
y = [y_root_sca+y[0], y_root_vec+y[1]]'''
|
||||
for ix in range(self.num_st_layers):
|
||||
s, t = self.flow_layers[ix](
|
||||
y, pos_query, edge_index_query, cpx_pos, node_attr_compose,
|
||||
index_real_cps_edge_for_atten=index_real_cps_edge_for_atten,
|
||||
tri_edge_index=tri_edge_index, tri_edge_feat=tri_edge_feat
|
||||
)
|
||||
s = s.exp()
|
||||
z_edge = (z_edge + t) * s
|
||||
if ix == 0:
|
||||
edge_log_jacob = (torch.abs(s) + 1e-20).log()
|
||||
else:
|
||||
edge_log_jacob += (torch.abs(s) + 1e-20).log()
|
||||
return z_edge, edge_log_jacob
|
||||
|
||||
def reverse(self, edge_latent, pos_query, edge_index_query, cpx_pos, node_attr_compose, edge_index_q_cps_knn,
|
||||
index_real_cps_edge_for_atten=[], tri_edge_index=[], tri_edge_feat=[], atom_type_emb=None,
|
||||
annealing=False):
|
||||
y = self.pos_encoder(
|
||||
pos_query, edge_index_q_cps_knn, cpx_pos, node_attr_compose, atom_type_emb,
|
||||
annealing=annealing
|
||||
)
|
||||
'''if isinstance(atom_type_emb, torch.Tensor) and self.with_root:
|
||||
root_vec_ij = self.root_vector_expansion(pos_query)
|
||||
y_root_sca, y_root_vec = self.root_lin([atom_type_emb, root_vec_ij])
|
||||
y = [y_root_sca+y[0], y_root_vec+y[1]]'''
|
||||
for ix in range(self.num_st_layers):
|
||||
s, t = self.flow_layers[ix](
|
||||
y, pos_query, edge_index_query, cpx_pos, node_attr_compose,
|
||||
index_real_cps_edge_for_atten=index_real_cps_edge_for_atten,
|
||||
tri_edge_index=tri_edge_index, tri_edge_feat=tri_edge_feat
|
||||
)
|
||||
if s.size(0)==0 and t.size(0)==0:
|
||||
break
|
||||
else:
|
||||
edge_latent = (edge_latent / s.exp()) - t
|
||||
if s.size(0)==0 and t.size(0)==0:
|
||||
return torch.empty([0, self.num_bond_types+1], device=pos_query.device)
|
||||
else:
|
||||
return edge_latent
|
||||
|
||||
def pos_classfier(self, pos_query, edge_index_q_cps_knn, cpx_pos, node_attr_compose, annealing=False):
|
||||
y = self.pos_encoder(
|
||||
pos_query, edge_index_q_cps_knn, cpx_pos, node_attr_compose, annealing=annealing
|
||||
)
|
||||
pred = self.pos_filter(y)
|
||||
return pred
|
||||
|
||||
|
||||
class PositionEncoder(Module):
|
||||
def __init__(self, in_sca, in_vec, edge_channels, num_filters, bottleneck=1, cutoff=10.,
|
||||
num_heads=1, use_conv1d=False, with_root=True) -> None:
|
||||
super(PositionEncoder, self).__init__()
|
||||
self.message_module = MessageModule(
|
||||
in_sca, in_vec, edge_channels, edge_channels, num_filters[0], num_filters[1],
|
||||
bottleneck, cutoff, use_conv1d=use_conv1d
|
||||
)
|
||||
self.message_att = MessageAttention(
|
||||
num_filters[0], num_filters[1], num_filters[0], num_filters[1], bottleneck=bottleneck,
|
||||
num_heads=num_heads, use_conv1d=use_conv1d
|
||||
)
|
||||
self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_channels)
|
||||
self.vector_expansion = EdgeExpansion(edge_channels)
|
||||
self.root_lin = GBLinear(
|
||||
in_sca, in_vec, num_filters[0], num_filters[1], bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
self.root_vector_expansion = EdgeExpansion(edge_channels)
|
||||
|
||||
#self.act_sca = LeakyReLU()
|
||||
#self.act_vec = VNLeakyReLU(hidden_channels[1], share_nonlinearity=True) # 2023.1.13
|
||||
'''self.out_transform = GBLinear(
|
||||
num_filters[0], num_filters[1], num_filters[0], num_filters[1],
|
||||
bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
self.layernorm_sca = nn.LayerNorm([num_filters[0]])
|
||||
self.layernorm_vec = nn.LayerNorm([num_filters[1], 3])''' # 2023.1.13 注释掉
|
||||
|
||||
def forward(self, pos_query, edge_index_q_cps_knn, cpx_pos, node_attr_compose, atom_type_emb,
|
||||
annealing=False):
|
||||
vec_ij = pos_query[edge_index_q_cps_knn[0]] - cpx_pos[edge_index_q_cps_knn[1]]
|
||||
dist_ij = torch.norm(vec_ij, p=2, dim=-1).view(-1, 1) # (A, 1)
|
||||
edge_ij = self.distance_expansion(dist_ij), self.vector_expansion(vec_ij)
|
||||
|
||||
#if isinstance(atom_type_emb, torch.Tensor) and self.with_root:
|
||||
root_vec_ij = self.root_vector_expansion(pos_query)
|
||||
y_root_sca, y_root_vec = self.root_lin([atom_type_emb, root_vec_ij])
|
||||
x = [y_root_sca, y_root_vec]
|
||||
|
||||
# node_attr_ctx_j = [node_attr_ctx_[edge_index_q_cps_knn[1]] for node_attr_ctx_ in node_attr_ctx] # (A, H)
|
||||
h_q = self.message_module(node_attr_compose, edge_ij, edge_index_q_cps_knn[1], dist_ij, annealing=annealing)
|
||||
y = self.message_att(x, h_q, edge_index_q_cps_knn[0])
|
||||
'''# non-linear
|
||||
out_sca = self.layernorm_sca(y[0])
|
||||
out_vec = self.layernorm_vec(y[1])
|
||||
y = self.out_transform((self.act_sca(out_sca), self.act_vec(out_vec)))''' # 2023.1.3
|
||||
return y
|
||||
|
||||
|
||||
class BondFlowNew(Module):
|
||||
def __init__(self, in_sca, in_vec, edge_channels, num_filters, num_bond_types=3, num_heads=4,
|
||||
cutoff=10.0, with_root=True, num_st_layers=6, bottleneck=1, use_conv1d=False):
|
||||
super(BondFlowNew, self).__init__()
|
||||
self.with_root = with_root
|
||||
self.num_bond_types = num_bond_types
|
||||
self.num_st_layers = num_st_layers
|
||||
## query encoder
|
||||
self.pos_encoder = PositionEncoder(
|
||||
in_sca, in_vec, edge_channels, num_filters, cutoff=cutoff, bottleneck=bottleneck,
|
||||
use_conv1d=use_conv1d
|
||||
)
|
||||
## query filter
|
||||
'''self.pos_filter = torch.nn.Sequential(
|
||||
GBPerceptronVN(
|
||||
num_filters[0], num_filters[1], num_filters[0], num_filters[1],
|
||||
bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
),
|
||||
GBLinear(
|
||||
num_filters[0], num_filters[1], 1, 1, bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
)'''
|
||||
## edge pred
|
||||
self.distance_expansion_3A = GaussianSmearing(stop=3., num_gaussians=edge_channels)
|
||||
self.vector_expansion = EdgeExpansion(edge_channels)
|
||||
self.nn_edge_ij = Sequential(
|
||||
GBPerceptronVN(edge_channels, edge_channels, num_filters[0], num_filters[1]),
|
||||
GBLinear(num_filters[0], num_filters[1], num_filters[0], num_filters[1])
|
||||
)
|
||||
self.edge_feat = Sequential(
|
||||
GBPerceptronVN(num_filters[0] * 2 + in_sca, num_filters[1] * 2 + in_vec, num_filters[0], num_filters[1]),
|
||||
GBLinear(num_filters[0], num_filters[1], num_filters[0], num_filters[1])
|
||||
)
|
||||
self.edge_atten = AttentionEdges(num_filters, num_filters, num_heads, num_bond_types)
|
||||
## flow layer
|
||||
self.flow_layers = torch.nn.ModuleList()
|
||||
for _ in range(num_st_layers):
|
||||
flow_layer = ST_GBP_Exp(
|
||||
num_filters[0],
|
||||
num_filters[1],
|
||||
num_bond_types + 1,
|
||||
num_filters[1],
|
||||
bottleneck=bottleneck,
|
||||
use_conv1d=use_conv1d
|
||||
)
|
||||
self.flow_layers.append(flow_layer)
|
||||
|
||||
def forward(self, z_edge, pos_query, edge_index_query, cpx_pos, node_attr_compose, edge_index_q_cps_knn,
|
||||
atom_type_emb, index_real_cps_edge_for_atten=[], tri_edge_index=[], tri_edge_feat=[],
|
||||
annealing=False):
|
||||
y = self.pos_encoder(
|
||||
pos_query, edge_index_q_cps_knn, cpx_pos, node_attr_compose,
|
||||
atom_type_emb, annealing=annealing
|
||||
)
|
||||
if (len(edge_index_query) != 0) and (edge_index_query.size(1) > 0):
|
||||
idx_node_i = edge_index_query[0]
|
||||
node_mol_i = [
|
||||
y[0][idx_node_i],
|
||||
y[1][idx_node_i]
|
||||
]
|
||||
idx_node_j = edge_index_query[1]
|
||||
node_mol_j = [
|
||||
node_attr_compose[0][idx_node_j],
|
||||
node_attr_compose[1][idx_node_j]
|
||||
]
|
||||
vec_ij = pos_query[idx_node_i] - cpx_pos[idx_node_j]
|
||||
dist_ij = torch.norm(vec_ij, p=2, dim=-1).view(-1, 1) # (E, 1)
|
||||
|
||||
edge_ij = self.distance_expansion_3A(dist_ij), self.vector_expansion(vec_ij)
|
||||
edge_feat = self.nn_edge_ij(edge_ij) # (E, F)
|
||||
|
||||
edge_attr = (torch.cat([node_mol_i[0], node_mol_j[0], edge_feat[0]], dim=-1), # (E, F)
|
||||
torch.cat([node_mol_i[1], node_mol_j[1], edge_feat[1]], dim=1))
|
||||
edge_attr = self.edge_feat(edge_attr)
|
||||
edge_attr = self.edge_atten(
|
||||
edge_attr, edge_index_query, cpx_pos, index_real_cps_edge_for_atten,
|
||||
tri_edge_index, tri_edge_feat
|
||||
)
|
||||
#self.edge_atten()
|
||||
for ix in range(len(self.flow_layers)):
|
||||
s, t = self.flow_layers[ix](edge_attr)
|
||||
s = s.exp()
|
||||
z_edge = (z_edge + t) * s
|
||||
if ix == 0:
|
||||
edge_log_jacob = (torch.abs(s) + 1e-20).log()
|
||||
else:
|
||||
edge_log_jacob += (torch.abs(s) + 1e-20).log()
|
||||
return z_edge, edge_log_jacob
|
||||
else:
|
||||
z_edge = torch.empty([0, self.num_bond_types+1], device=pos_query.device)
|
||||
edge_log_jacob = torch.empty([0, self.num_bond_types+1], device=pos_query.device)
|
||||
return z_edge, edge_log_jacob
|
||||
|
||||
def reverse(self, edge_latent, pos_query, edge_index_query, cpx_pos, node_attr_compose, edge_index_q_cps_knn,
|
||||
atom_type_emb, index_real_cps_edge_for_atten=[], tri_edge_index=[], tri_edge_feat=[],
|
||||
annealing=False):
|
||||
y = self.pos_encoder(
|
||||
pos_query, edge_index_q_cps_knn, cpx_pos, node_attr_compose,
|
||||
atom_type_emb, annealing=annealing
|
||||
)
|
||||
if (len(edge_index_query) != 0) and (edge_index_query.size(1) > 0):
|
||||
idx_node_i = edge_index_query[0]
|
||||
node_mol_i = [
|
||||
y[0][idx_node_i],
|
||||
y[1][idx_node_i]
|
||||
]
|
||||
idx_node_j = edge_index_query[1]
|
||||
node_mol_j = [
|
||||
node_attr_compose[0][idx_node_j],
|
||||
node_attr_compose[1][idx_node_j]
|
||||
]
|
||||
vec_ij = pos_query[idx_node_i] - cpx_pos[idx_node_j]
|
||||
dist_ij = torch.norm(vec_ij, p=2, dim=-1).view(-1, 1) # (E, 1)
|
||||
|
||||
edge_ij = self.distance_expansion_3A(dist_ij), self.vector_expansion(vec_ij)
|
||||
edge_feat = self.nn_edge_ij(edge_ij) # (E, F)
|
||||
|
||||
edge_attr = (torch.cat([node_mol_i[0], node_mol_j[0], edge_feat[0]], dim=-1), # (E, F)
|
||||
torch.cat([node_mol_i[1], node_mol_j[1], edge_feat[1]], dim=1))
|
||||
edge_attr = self.edge_feat(edge_attr)
|
||||
edge_attr = self.edge_atten(
|
||||
edge_attr, edge_index_query, cpx_pos, index_real_cps_edge_for_atten,
|
||||
tri_edge_index, tri_edge_feat
|
||||
)
|
||||
#self.edge_atten()
|
||||
for ix in range(len(self.flow_layers)):
|
||||
s, t = self.flow_layers[ix](edge_attr)
|
||||
edge_latent = (edge_latent / s.exp()) - t
|
||||
return edge_latent
|
||||
else:
|
||||
edge_latent = torch.empty([0, self.num_bond_types+1], device=pos_query.device)
|
||||
return edge_latent
|
||||
@@ -1,27 +0,0 @@
|
||||
import torch
|
||||
from torch.nn import Module, Sequential
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .layers import GBPerceptronVN, GBLinear
|
||||
|
||||
|
||||
class FrontierLayerVN(Module):
|
||||
def __init__(self, in_sca, in_vec, hidden_dim_sca, hidden_dim_vec, bottleneck=1,
|
||||
use_conv1d=False):
|
||||
super(FrontierLayerVN, self).__init__()
|
||||
self.net = Sequential(
|
||||
GBPerceptronVN(
|
||||
in_sca, in_vec, hidden_dim_sca, hidden_dim_vec, bottleneck=bottleneck,
|
||||
use_conv1d=use_conv1d
|
||||
),
|
||||
GBLinear(
|
||||
hidden_dim_sca, hidden_dim_vec, 1, 1, bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, h_att, idx_ligans):
|
||||
h_att_ligand = [h_att[0][idx_ligans], h_att[1][idx_ligans]]
|
||||
pred = self.net(h_att_ligand)
|
||||
pred = pred[0]
|
||||
return pred
|
||||
|
||||
@@ -1,477 +0,0 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Module, Linear, LeakyReLU, ModuleList, LayerNorm
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from torch_geometric.nn import global_mean_pool
|
||||
from torch_scatter import scatter_sum, scatter_softmax
|
||||
from math import pi as PI
|
||||
from .net_utils import GaussianSmearing, EdgeExpansion, Rescale
|
||||
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
class GBLinearConv1D(Module):
|
||||
def __init__(self, in_scalar, in_vector, out_scalar, out_vector, bottleneck=1, use_conv1d=True):
|
||||
super(GBLinearConv1D, self).__init__()
|
||||
assert in_vector % bottleneck == 0,\
|
||||
f"Input channel of vector ({in_vector}) must be divisible with bottleneck factor ({bottleneck})"
|
||||
if bottleneck > 1:
|
||||
self.hidden_dim = in_vector // bottleneck
|
||||
else:
|
||||
self.hidden_dim = max(in_vector, out_vector)
|
||||
self.out_vector = out_vector
|
||||
self.lin_vector = VNLinear(in_vector, self.hidden_dim, bias=False)
|
||||
self.lin_vector2 = VNLinear(self.hidden_dim, out_vector, bias=False)
|
||||
|
||||
self.use_conv1d = use_conv1d
|
||||
self.lin_scalar = nn.Conv1d(in_scalar + self.hidden_dim, out_scalar, 1, bias=False)
|
||||
self.scalar_to_vector_gates = nn.Conv1d(out_scalar, out_vector, 1)
|
||||
|
||||
def forward(self, features):
|
||||
feat_scalar, feat_vector = features
|
||||
feat_vector_inter = self.lin_vector(feat_vector) # (N_samples, dim_hid, 3)
|
||||
feat_vector_norm = torch.norm(feat_vector_inter, p=2, dim=-1) # (N_samples, dim_hid)
|
||||
feat_scalar_cat = torch.cat([feat_vector_norm, feat_scalar], dim=-1) # (N_samples, dim_hid+in_scalar)
|
||||
|
||||
out_scalar = self.lin_scalar(feat_scalar_cat.unsqueeze(-1)).squeeze(-1)
|
||||
gating = torch.sigmoid(self.scalar_to_vector_gates(out_scalar.unsqueeze(-1)))
|
||||
|
||||
out_vector = self.lin_vector2(feat_vector_inter)
|
||||
out_vector = gating * out_vector
|
||||
return out_scalar, out_vector
|
||||
|
||||
'''class GBLinear(Module):
|
||||
def __init__(self, in_scalar, in_vector, out_scalar, out_vector, bottleneck=1, use_conv1d=False):
|
||||
super(GBLinear, self).__init__()
|
||||
assert in_vector % bottleneck == 0,\
|
||||
f"Input channel of vector ({in_vector}) must be divisible with bottleneck factor ({bottleneck})"
|
||||
if bottleneck > 1:
|
||||
self.hidden_dim = in_vector // bottleneck
|
||||
else:
|
||||
self.hidden_dim = max(in_vector, out_vector)
|
||||
self.out_vector = out_vector
|
||||
self.lin_vector = VNLinear(in_vector, self.hidden_dim, bias=False)
|
||||
self.lin_vector2 = VNLinear(self.hidden_dim, out_vector, bias=False)
|
||||
|
||||
self.use_conv1d = use_conv1d
|
||||
self.scalar_to_vector_gates = Linear(out_scalar, out_vector)
|
||||
self.lin_scalar = Linear(in_scalar + self.hidden_dim, out_scalar, bias=False)
|
||||
|
||||
def forward(self, features):
|
||||
feat_scalar, feat_vector = features
|
||||
feat_vector_inter = self.lin_vector(feat_vector) # (N_samples, dim_hid, 3)
|
||||
feat_vector_norm = torch.norm(feat_vector_inter, p=2, dim=-1) # (N_samples, dim_hid)
|
||||
feat_scalar_cat = torch.cat([feat_vector_norm, feat_scalar], dim=-1) # (N_samples, dim_hid+in_scalar)
|
||||
|
||||
out_scalar = self.lin_scalar(feat_scalar_cat)
|
||||
gating = torch.sigmoid(self.scalar_to_vector_gates(out_scalar)).unsqueeze(-1)
|
||||
|
||||
out_vector = self.lin_vector2(feat_vector_inter)
|
||||
out_vector = gating * out_vector
|
||||
return out_scalar, out_vector'''
|
||||
|
||||
class GBLinear(Module):
|
||||
def __init__(self, in_scalar, in_vector, out_scalar, out_vector, bottleneck=(1,1), use_conv1d=False):
|
||||
super(GBLinear, self).__init__()
|
||||
if isinstance(bottleneck, int):
|
||||
sca_bottleneck = bottleneck
|
||||
vec_bottleneck = bottleneck
|
||||
else:
|
||||
sca_bottleneck = bottleneck[0]
|
||||
vec_bottleneck = bottleneck[1]
|
||||
assert in_vector % vec_bottleneck == 0,\
|
||||
f"Input channel of vector ({in_vector}) must be divisible with bottleneck factor ({vec_bottleneck})"
|
||||
assert in_scalar % sca_bottleneck == 0,\
|
||||
f"Input channel of vector ({in_scalar}) must be divisible with bottleneck factor ({sca_bottleneck})"
|
||||
if sca_bottleneck > 1:
|
||||
self.sca_hidden_dim = in_scalar // sca_bottleneck
|
||||
else:
|
||||
self.sca_hidden_dim = max(in_vector, out_vector)
|
||||
|
||||
if vec_bottleneck > 1:
|
||||
self.hidden_dim = in_vector // vec_bottleneck
|
||||
else:
|
||||
self.hidden_dim = max(in_vector, out_vector)
|
||||
|
||||
self.out_vector = out_vector
|
||||
self.lin_vector = VNLinear(in_vector, self.hidden_dim, bias=False)
|
||||
self.lin_vector2 = VNLinear(self.hidden_dim, out_vector, bias=False)
|
||||
|
||||
self.use_conv1d = use_conv1d
|
||||
self.scalar_to_vector_gates = Linear(out_scalar, out_vector)
|
||||
self.lin_scalar_1 = Linear(in_scalar, self.sca_hidden_dim, bias=False)
|
||||
self.lin_scalar_2 = Linear(self.hidden_dim + self.sca_hidden_dim, out_scalar, bias=False)
|
||||
|
||||
def forward(self, features):
|
||||
feat_scalar, feat_vector = features
|
||||
feat_vector_inter = self.lin_vector(feat_vector) # (N_samples, dim_hid, 3)
|
||||
feat_vector_norm = torch.norm(feat_vector_inter, p=2, dim=-1) # (N_samples, dim_hid)
|
||||
z_sca = self.lin_scalar_1(feat_scalar)
|
||||
feat_scalar_cat = torch.cat([feat_vector_norm, z_sca], dim=-1) # (N_samples, dim_hid+in_scalar)
|
||||
|
||||
#z_sca = self.lin_scalar_1(feat_scalar_cat)
|
||||
out_scalar = self.lin_scalar_2(feat_scalar_cat)
|
||||
gating = torch.sigmoid(self.scalar_to_vector_gates(out_scalar)).unsqueeze(-1)
|
||||
|
||||
out_vector = self.lin_vector2(feat_vector_inter)
|
||||
out_vector = gating * out_vector
|
||||
return out_scalar, out_vector
|
||||
|
||||
|
||||
class GBPerceptronVN(Module):
|
||||
def __init__(self, in_scalar, in_vector, out_scalar, out_vector, bottleneck=1, use_conv1d=False):
|
||||
super(GBPerceptronVN, self).__init__()
|
||||
self.gb_linear = GBLinear(
|
||||
in_scalar, in_vector, out_scalar, out_vector, bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
self.act_sca = LeakyReLU()
|
||||
self.act_vec = VNLeakyReLU(out_vector)
|
||||
|
||||
def forward(self, x):
|
||||
sca, vec = self.gb_linear(x)
|
||||
vec = self.act_vec(vec)
|
||||
sca = self.act_sca(sca)
|
||||
return sca, vec
|
||||
|
||||
|
||||
class VNLinear(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
||||
super(VNLinear, self).__init__()
|
||||
self.map_to_feat = nn.Linear(in_channels, out_channels, *args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
x: point features of shape [B, N_samples, N_feat, 3]
|
||||
'''
|
||||
x_out = self.map_to_feat(x.transpose(-2,-1)).transpose(-2,-1)
|
||||
return x_out
|
||||
|
||||
|
||||
class VNLeakyReLU(nn.Module):
|
||||
def __init__(self, in_channels, share_nonlinearity=False, negative_slope=0.01):
|
||||
super(VNLeakyReLU, self).__init__()
|
||||
if share_nonlinearity == True:
|
||||
self.map_to_dir = nn.Linear(in_channels, 1, bias=False)
|
||||
else:
|
||||
self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False)
|
||||
self.negative_slope = negative_slope
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
x: point features of shape [B, N_samples, N_feat, 3]
|
||||
'''
|
||||
d = self.map_to_dir(x.transpose(-2,-1)).transpose(-2,-1) # (N_samples, N_feat, 3)
|
||||
dotprod = (x*d).sum(-1, keepdim=True) # sum over 3-value dimension
|
||||
mask = (dotprod >= 0).to(x.dtype)
|
||||
d_norm_sq = (d*d).sum(-1, keepdim=True) # sum over 3-value dimension
|
||||
x_out = (self.negative_slope * x +
|
||||
(1-self.negative_slope) * (mask*x + (1-mask)*(x-(dotprod/(d_norm_sq+EPS))*d)))
|
||||
return x_out
|
||||
|
||||
|
||||
class ST_GBP_Exp(nn.Module):
|
||||
def __init__(self, in_scalar, in_vector, out_scalar, out_vector, bottleneck=1, use_conv1d=False):
|
||||
super(ST_GBP_Exp, self).__init__()
|
||||
self.in_scalar = in_scalar
|
||||
self.in_vector = in_vector
|
||||
self.out_scalar = out_scalar
|
||||
self.out_vector = out_vector
|
||||
|
||||
self.gb_linear1 = GBLinear(
|
||||
in_scalar, in_vector, in_scalar, in_vector, bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
self.gb_linear2 = GBLinear(
|
||||
in_scalar, in_vector, out_scalar*2, out_vector, bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
self.act_sca = nn.Tanh()
|
||||
self.act_vec = VNLeakyReLU(out_vector)
|
||||
self.rescale = Rescale()
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
:param x: (batch * repeat_num for node/edge, emb)
|
||||
:return: w and b for affine operation
|
||||
'''
|
||||
sca, vec = self.gb_linear1(x)
|
||||
sca = self.act_sca(sca)
|
||||
vec = self.act_vec(vec)
|
||||
sca, vec = self.gb_linear2([sca, vec])
|
||||
s = sca[:, :self.out_scalar]
|
||||
t = sca[:, self.out_scalar:]
|
||||
s = self.rescale(torch.tanh(s))
|
||||
return s, t
|
||||
|
||||
|
||||
class MessageAttention(Module):
|
||||
def __init__(self, in_sca, in_vec, out_sca, out_vec, bottleneck=1, num_heads=1, use_conv1d=False) -> None:
|
||||
super(MessageAttention, self).__init__()
|
||||
|
||||
assert (in_sca % num_heads == 0) and (in_vec % num_heads == 0)
|
||||
assert (out_sca % num_heads == 0) and (out_vec % num_heads == 0)
|
||||
|
||||
self.num_heads =num_heads
|
||||
self.lin_v = GBLinear(in_sca, in_vec, out_sca, out_vec, bottleneck=bottleneck, use_conv1d=use_conv1d)
|
||||
self.lin_k = GBLinear(in_sca, in_vec, out_sca, out_vec, bottleneck=bottleneck, use_conv1d=use_conv1d)
|
||||
|
||||
def forward(self, x, query, edge_index_i):
|
||||
N = x[0].size(0)
|
||||
N_msg = len(edge_index_i)
|
||||
msg = [
|
||||
query[0].view(N_msg, self.num_heads, -1),
|
||||
query[1].view(N_msg, self.num_heads, -1, 3)
|
||||
]
|
||||
k = self.lin_k(x)
|
||||
x_i = [
|
||||
k[0][edge_index_i].view(N_msg, self.num_heads, -1),
|
||||
k[1][edge_index_i].view(N_msg, self.num_heads, -1, 3)
|
||||
]
|
||||
#alpha_scale = [x_i[0].size(-1)**0.5, x_i[1].size(-2)**0.5]
|
||||
alpha = [
|
||||
(msg[0] * x_i[0]).sum(-1), #/alpha_scale[0] # (N', heads)
|
||||
(msg[1] * x_i[1]).sum(-1).sum(-1) #/alpha_scale[1] # (N', heads)
|
||||
]
|
||||
alpha = [
|
||||
scatter_softmax(alpha[0], edge_index_i, dim=0),
|
||||
scatter_softmax(alpha[1], edge_index_i, dim=0)
|
||||
]
|
||||
msg = [
|
||||
(alpha[0].unsqueeze(-1) * msg[0]).view(N_msg, -1),
|
||||
(alpha[1].unsqueeze(-1).unsqueeze(-1) * msg[1]).view(N_msg, -1, 3)
|
||||
]
|
||||
sca_msg = scatter_sum(msg[0], edge_index_i, dim=0, dim_size=N)
|
||||
vec_msg = scatter_sum(msg[1], edge_index_i, dim=0, dim_size=N)
|
||||
#return sca_msg, vec_msg
|
||||
root_sca, root_vec = self.lin_v(x)
|
||||
out_sca = sca_msg + root_sca
|
||||
out_vec = vec_msg + root_vec
|
||||
return out_sca, out_vec
|
||||
|
||||
|
||||
class MessageModule(nn.Module):
|
||||
def __init__(self, node_sca, node_vec, edge_sca, edge_vec, out_sca, out_vec,
|
||||
bottleneck=1, cutoff=10., use_conv1d=False):
|
||||
super(MessageModule, self).__init__()
|
||||
hid_sca, hid_vec = edge_sca, edge_vec
|
||||
self.cutoff = cutoff
|
||||
self.node_gblinear = GBLinear(
|
||||
node_sca, node_vec, out_sca, out_vec, bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
self.edge_gbp = GBPerceptronVN(
|
||||
edge_sca, edge_vec, hid_sca, hid_vec, bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
|
||||
self.sca_linear = Linear(hid_sca, out_sca) # edge_sca for y_sca
|
||||
self.e2n_linear = Linear(hid_sca, out_vec)
|
||||
self.n2e_linear = Linear(out_sca, out_vec)
|
||||
self.edge_vnlinear = VNLinear(hid_vec, out_vec)
|
||||
|
||||
self.out_gblienar = GBLinear(
|
||||
out_sca, out_vec, out_sca, out_vec, bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
|
||||
def forward(self, node_features, edge_features, edge_index_node, dist_ij=None, annealing=False):
|
||||
node_scalar, node_vector = self.node_gblinear(node_features)
|
||||
node_scalar, node_vector = node_scalar[edge_index_node], node_vector[edge_index_node]
|
||||
edge_scalar, edge_vector = self.edge_gbp(edge_features)
|
||||
|
||||
y_scalar = node_scalar * self.sca_linear(edge_scalar)
|
||||
y_node_vector = self.e2n_linear(edge_scalar).unsqueeze(-1) * node_vector
|
||||
y_edge_vector = self.n2e_linear(node_scalar).unsqueeze(-1) * self.edge_vnlinear(edge_vector)
|
||||
y_vector = y_node_vector + y_edge_vector
|
||||
|
||||
output = self.out_gblienar((y_scalar, y_vector))
|
||||
|
||||
if annealing:
|
||||
C = 0.5 * (torch.cos(dist_ij * PI / self.cutoff) + 1.0) # (A, 1)
|
||||
C = C * (dist_ij <= self.cutoff) * (dist_ij >= 0.0)
|
||||
output = [output[0] * C.view(-1, 1), output[1] * C.view(-1, 1, 1)] # (A, 1)
|
||||
return output
|
||||
|
||||
|
||||
class AttentionInteractionBlockVN(Module):
|
||||
|
||||
def __init__(self, hidden_channels, edge_channels, num_edge_types, bottleneck=1, num_heads=1,
|
||||
cutoff=10., use_conv1d=False):
|
||||
super(AttentionInteractionBlockVN, self).__init__()
|
||||
self.num_heads = num_heads
|
||||
# edge features
|
||||
self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_channels - num_edge_types)
|
||||
self.vector_expansion = EdgeExpansion(edge_channels) # Linear(in_features=1, out_features=edge_channels, bias=False)
|
||||
## compare encoder and classifier message passing
|
||||
|
||||
# edge weigths and linear for values
|
||||
self.message_module = MessageModule(hidden_channels[0], hidden_channels[1], edge_channels, edge_channels,
|
||||
hidden_channels[0], hidden_channels[1], bottleneck=bottleneck,
|
||||
cutoff=cutoff, use_conv1d=use_conv1d)
|
||||
self.msg_att = MessageAttention(hidden_channels[0], hidden_channels[1], hidden_channels[0], hidden_channels[1],
|
||||
bottleneck=bottleneck, num_heads=num_heads, use_conv1d=use_conv1d)
|
||||
|
||||
# centroid nodes and finall linear
|
||||
self.act_sca = LeakyReLU()
|
||||
self.act_vec = VNLeakyReLU(hidden_channels[1], share_nonlinearity=True)
|
||||
self.out_transform = GBLinear(
|
||||
hidden_channels[0], hidden_channels[1], hidden_channels[0], hidden_channels[1], use_conv1d=use_conv1d,
|
||||
bottleneck=bottleneck
|
||||
)
|
||||
self.layernorm_sca = LayerNorm([hidden_channels[0]])
|
||||
self.layernorm_vec = LayerNorm([hidden_channels[1], 3])
|
||||
|
||||
def forward(self, x, edge_index, edge_feature, edge_vector, edge_dist, annealing=False):
|
||||
"""
|
||||
Args:
|
||||
x: Node features: scalar features (N, feat), vector features(N, feat, 3)
|
||||
edge_index: (2, E).
|
||||
edge_attr: (E, H)
|
||||
"""
|
||||
scalar, vector = x
|
||||
N = scalar.size(0)
|
||||
row, col = edge_index # (E,) , (E,)
|
||||
|
||||
# Compute edge features
|
||||
#edge_dist = torch.norm(edge_vector, dim=-1, p=2)
|
||||
edge_sca_feat = torch.cat([self.distance_expansion(edge_dist), edge_feature], dim=-1)
|
||||
edge_vec_feat = self.vector_expansion(edge_vector)
|
||||
|
||||
msg_j_sca, msg_j_vec = self.message_module(
|
||||
x, (edge_sca_feat, edge_vec_feat), col, edge_dist, annealing=annealing
|
||||
)
|
||||
out_sca, out_vec = self.msg_att(x, (msg_j_sca, msg_j_vec), row)
|
||||
# non-linear
|
||||
out_sca = self.layernorm_sca(out_sca)
|
||||
out_vec = self.layernorm_vec(out_vec)
|
||||
out = self.out_transform((self.act_sca(out_sca), self.act_vec(out_vec)))
|
||||
return out
|
||||
|
||||
|
||||
class AttentionEdges(Module):
|
||||
def __init__(self, hidden_channels, key_channels, num_heads=1, num_bond_types=3, bottleneck=1,
|
||||
use_conv1d=False):
|
||||
super(AttentionEdges, self).__init__()
|
||||
|
||||
assert (hidden_channels[0] % num_heads == 0) and (hidden_channels[1] % num_heads == 0)
|
||||
assert (key_channels[0] % num_heads == 0) and (key_channels[1] % num_heads == 0)
|
||||
|
||||
self.hidden_channels = hidden_channels
|
||||
self.key_channels = key_channels
|
||||
self.num_heads = num_heads
|
||||
|
||||
# linear transformation for attention
|
||||
self.q_lin = GBLinear(
|
||||
hidden_channels[0], hidden_channels[1], key_channels[0], key_channels[1],
|
||||
bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
self.k_lin = GBLinear(
|
||||
hidden_channels[0], hidden_channels[1], key_channels[0], key_channels[1],
|
||||
bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
self.v_lin = GBLinear(
|
||||
hidden_channels[0], hidden_channels[1], hidden_channels[0], hidden_channels[1],
|
||||
bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
|
||||
self.atten_bias_lin = AttentionBias(
|
||||
self.num_heads, hidden_channels, num_bond_types=num_bond_types,
|
||||
bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
|
||||
self.layernorm_sca = LayerNorm([hidden_channels[0]])
|
||||
self.layernorm_vec = LayerNorm([hidden_channels[1], 3])
|
||||
|
||||
def forward(self, edge_attr, edge_index, pos_compose,
|
||||
index_real_cps_edge_for_atten, tri_edge_index, tri_edge_feat,):
|
||||
"""
|
||||
Args:
|
||||
x: edge features: scalar features (N, feat), vector features(N, feat, 3)
|
||||
edge_attr: (E, H)
|
||||
edge_index: (2, E). the row can be seen as batch_edge
|
||||
"""
|
||||
scalar, vector = edge_attr
|
||||
N = scalar.size(0)
|
||||
row, col = edge_index # (N,)
|
||||
|
||||
# Project to multiple key, query and value spaces
|
||||
h_queries = self.q_lin(edge_attr)
|
||||
h_queries = (h_queries[0].view(N, self.num_heads, -1), # (N, heads, K_per_head)
|
||||
h_queries[1].view(N, self.num_heads, -1, 3)) # (N, heads, K_per_head, 3)
|
||||
h_keys = self.k_lin(edge_attr)
|
||||
h_keys = (h_keys[0].view(N, self.num_heads, -1), # (N, heads, K_per_head)
|
||||
h_keys[1].view(N, self.num_heads, -1, 3)) # (N, heads, K_per_head, 3)
|
||||
h_values = self.v_lin(edge_attr)
|
||||
h_values = (h_values[0].view(N, self.num_heads, -1), # (N, heads, K_per_head)
|
||||
h_values[1].view(N, self.num_heads, -1, 3)) # (N, heads, K_per_head, 3)
|
||||
|
||||
# assert (index_edge_i_list == index_real_cps_edge_for_atten[0]).all()
|
||||
# assert (index_edge_j_list == index_real_cps_edge_for_atten[1]).all()
|
||||
index_edge_i_list, index_edge_j_list = index_real_cps_edge_for_atten
|
||||
|
||||
# # get nodes of triangle edges
|
||||
|
||||
atten_bias = self.atten_bias_lin(
|
||||
tri_edge_index,
|
||||
tri_edge_feat,
|
||||
pos_compose,
|
||||
)
|
||||
|
||||
|
||||
# query * key
|
||||
queries_i = [h_queries[0][index_edge_i_list], h_queries[1][index_edge_i_list]]
|
||||
keys_j = [h_keys[0][index_edge_j_list], h_keys[1][index_edge_j_list]]
|
||||
|
||||
qk_ij = [
|
||||
(queries_i[0] * keys_j[0]).sum(-1), # (N', heads)
|
||||
(queries_i[1] * keys_j[1]).sum(-1).sum(-1) # (N', heads)
|
||||
]
|
||||
|
||||
alpha = [
|
||||
atten_bias[0] + qk_ij[0],
|
||||
atten_bias[1] + qk_ij[1]
|
||||
]
|
||||
alpha = [
|
||||
scatter_softmax(alpha[0], index_edge_i_list, dim=0), # (N', heads)
|
||||
scatter_softmax(alpha[1], index_edge_i_list, dim=0) # (N', heads)
|
||||
]
|
||||
|
||||
values_j = [h_values[0][index_edge_j_list], h_values[1][index_edge_j_list]]
|
||||
num_attens = len(index_edge_j_list)
|
||||
output =[
|
||||
scatter_sum((alpha[0].unsqueeze(-1) * values_j[0]).view(num_attens, -1), index_edge_i_list, dim=0, dim_size=N), # (N, H, 3)
|
||||
scatter_sum((alpha[1].unsqueeze(-1).unsqueeze(-1) * values_j[1]).view(num_attens, -1, 3), index_edge_i_list, dim=0, dim_size=N) # (N, H, 3)
|
||||
]
|
||||
|
||||
# output
|
||||
output = [edge_attr[0] + output[0], edge_attr[1] + output[1]]
|
||||
output = [self.layernorm_sca(output[0]), self.layernorm_vec(output[1])]
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class AttentionBias(Module):
|
||||
def __init__(self, num_heads, hidden_channels, cutoff=10., num_bond_types=3,
|
||||
bottleneck=1, use_conv1d=False): #TODO: change the cutoff
|
||||
super(AttentionBias, self).__init__()
|
||||
num_edge_types = num_bond_types + 1
|
||||
self.num_bond_types = num_bond_types
|
||||
self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=hidden_channels[0] - num_edge_types-1) # minus 1 for self edges (e.g. edge 0-0)
|
||||
self.vector_expansion = EdgeExpansion(hidden_channels[1]) # Linear(in_features=1, out_features=hidden_channels[1], bias=False)
|
||||
self.gblinear = GBLinear(
|
||||
hidden_channels[0], hidden_channels[1], num_heads, num_heads,
|
||||
bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
|
||||
def forward(self, tri_edge_index, tri_edge_feat, pos_compose):
|
||||
node_a, node_b = tri_edge_index
|
||||
pos_a = pos_compose[node_a]
|
||||
pos_b = pos_compose[node_b]
|
||||
vector = pos_a - pos_b
|
||||
dist = torch.norm(vector, p=2, dim=-1)
|
||||
|
||||
dist_feat = self.distance_expansion(dist)
|
||||
sca_feat = torch.cat([
|
||||
dist_feat,
|
||||
tri_edge_feat,
|
||||
], dim=-1)
|
||||
vec_feat = self.vector_expansion(vector)
|
||||
output_sca, output_vec = self.gblinear([sca_feat, vec_feat])
|
||||
output_vec = (output_vec * output_vec).sum(-1)
|
||||
return output_sca, output_vec
|
||||
@@ -1,193 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.modules.loss import _WeightedLoss
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
|
||||
|
||||
keys = ['edge_flow.flow_layers.5', 'atom_flow.flow_layers.5',
|
||||
'pos_predictor.mu_net', 'pos_predictor.logsigma_net', 'pos_predictor.pi_net',
|
||||
'focal_net.net']
|
||||
def reset_parameters(model, keys):
|
||||
for name, para in model.named_parameters():
|
||||
for k in keys:
|
||||
if k in name and 'bias' in name:
|
||||
torch.nn.init.constant_(para, 0.)
|
||||
elif k in name and 'layernorm' in name:
|
||||
torch.nn.init.constant_(para, 1.)
|
||||
elif k in name and 'rescale.weight' in name:
|
||||
torch.nn.init.constant_(para, 0.)
|
||||
elif k in name:
|
||||
torch.nn.init.kaiming_normal_(para)
|
||||
return model
|
||||
|
||||
|
||||
def flow_reverse(flow_layers, latent, feat):
|
||||
for i in reversed(range(len(flow_layers))):
|
||||
s_sca, t_sca, vec = flow_layers[i](feat)
|
||||
s_sca = s_sca.exp() # 为什么要计算指数?确保s中的每个元素都大于0,以保证可逆性
|
||||
latent = (latent / s_sca) - t_sca
|
||||
return latent, vec
|
||||
|
||||
|
||||
def flow_forward(flow_layers, x_z, feature):
|
||||
for i in range(len(flow_layers)):
|
||||
s_sca, t_sca, vec = flow_layers[i](feature)
|
||||
s_sca = s_sca.exp() # 为什么要计算指数?
|
||||
x_z = (x_z + t_sca) * s_sca
|
||||
if i == 0:
|
||||
x_log_jacob = (torch.abs(s_sca) + 1e-20).log()
|
||||
else:
|
||||
x_log_jacob += (torch.abs(s_sca) + 1e-20).log()
|
||||
return x_z, x_log_jacob, vec
|
||||
|
||||
|
||||
class GaussianSmearing(nn.Module):
|
||||
def __init__(self, start=0.0, stop=10.0, num_gaussians=50):
|
||||
super(GaussianSmearing, self).__init__()
|
||||
self.stop = stop
|
||||
offset = torch.linspace(start, stop, num_gaussians)
|
||||
self.coeff = -0.5 / (offset[1] - offset[0]).item()**2
|
||||
self.register_buffer('offset', offset)
|
||||
|
||||
def forward(self, dist):
|
||||
dist = dist.clamp_max(self.stop)
|
||||
dist = dist.view(-1, 1) - self.offset.view(1, -1)
|
||||
return torch.exp(self.coeff * torch.pow(dist, 2))
|
||||
|
||||
|
||||
class EdgeExpansion(nn.Module):
|
||||
def __init__(self, edge_channels):
|
||||
super(EdgeExpansion, self).__init__()
|
||||
self.nn = nn.Linear(in_features=1, out_features=edge_channels, bias=False)
|
||||
|
||||
def forward(self, edge_vector):
|
||||
edge_vector = edge_vector / (torch.norm(edge_vector, p=2, dim=1, keepdim=True)+1e-7)
|
||||
expansion = self.nn(edge_vector.unsqueeze(-1)).transpose(1, -1)
|
||||
return expansion
|
||||
|
||||
|
||||
class Scalarize(nn.Module):
|
||||
def __init__(self, sca_in_dim, vec_in_dim, hidden_dim, out_dim, act_fn=nn.Sigmoid()) -> None:
|
||||
super(Scalarize, self).__init__()
|
||||
self.sca_in_dim = sca_in_dim
|
||||
self.vec_in_dim = vec_in_dim
|
||||
self.hidden_dim = hidden_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
self.lin_scalarize_1 = nn.Linear(sca_in_dim+vec_in_dim, hidden_dim)
|
||||
self.lin_scalarize_2 = nn.Linear(hidden_dim, out_dim)
|
||||
self.act_fn = act_fn
|
||||
|
||||
def forward(self, x):
|
||||
sca, vec = x[0].view(-1, self.sca_in_dim), x[1]
|
||||
norm_vec = torch.norm(vec, p=2, dim=-1).view(-1, self.vec_in_dim)
|
||||
sca = torch.cat([sca, norm_vec], dim=1)
|
||||
sca = self.lin_scalarize_1(sca)
|
||||
sca = self.act_fn(sca)
|
||||
sca = self.lin_scalarize_2(sca)
|
||||
return sca
|
||||
|
||||
|
||||
class Rescale(nn.Module):
|
||||
def __init__(self):
|
||||
super(Rescale, self).__init__()
|
||||
self.weight = nn.Parameter(torch.zeros([1]))
|
||||
|
||||
def forward(self, x):
|
||||
if torch.isnan(torch.exp(self.weight)).any():
|
||||
print(self.weight)
|
||||
raise RuntimeError('Rescale factor has NaN entries')
|
||||
|
||||
x = torch.exp(self.weight) * x
|
||||
return x
|
||||
|
||||
'''class ST_DWLayer(nn.Module):
|
||||
def __init__(self, sca_in_features, sca_out_features,
|
||||
vec_in_features, vec_out_features, sca_act=nn.ReLU(),
|
||||
vec_act=nn.Sigmoid(), device=None, dtype=None) -> None:
|
||||
super(ST_DWLayer, self).__init__()
|
||||
self.sca_out_features = sca_out_features
|
||||
self.dw_layer = DWLayer(
|
||||
sca_in_features, sca_out_features*2,
|
||||
vec_in_features, vec_out_features, sca_act=sca_act,
|
||||
vec_act=vec_act, device=device, dtype=dtype
|
||||
)
|
||||
self.rescale = Rescale()
|
||||
|
||||
def forward(self, x):
|
||||
sca, vec = self.dw_layer(x)
|
||||
s_sca = self.rescale(sca[:, :self.sca_out_features])
|
||||
t_sca = sca[:, self.sca_out_features:]
|
||||
return s_sca, t_sca, vec'''
|
||||
|
||||
|
||||
|
||||
|
||||
class AtomEmbedding(nn.Module):
|
||||
def __init__(self, in_scalar, in_vector,
|
||||
out_scalar, out_vector, vector_normalizer=20.):
|
||||
super(AtomEmbedding, self).__init__()
|
||||
assert in_vector == 1
|
||||
self.in_scalar = in_scalar
|
||||
self.vector_normalizer = vector_normalizer
|
||||
self.emb_sca = nn.Linear(in_scalar, out_scalar)
|
||||
self.emb_vec = nn.Linear(in_vector, out_vector)
|
||||
|
||||
def forward(self, scalar_input, vector_input):
|
||||
if isinstance(self.vector_normalizer, float):
|
||||
vector_input = vector_input / self.vector_normalizer
|
||||
else:
|
||||
vector_input = vector_input / torch.norm(vector_input, p=2, dim=-1)
|
||||
assert vector_input.shape[1:] == (3, ), 'Not support. Only one vector can be input'
|
||||
sca_emb = self.emb_sca(scalar_input[:, :self.in_scalar]) # b, f -> b, f'
|
||||
vec_emb = vector_input.unsqueeze(-1) # b, 3 -> b, 3, 1
|
||||
vec_emb = self.emb_vec(vec_emb).transpose(1, -1) # b, 1, 3 -> b, f', 3
|
||||
return sca_emb, vec_emb
|
||||
|
||||
|
||||
def embed_compose(compose_feature, compose_pos, idx_ligand, idx_protein,
|
||||
ligand_atom_emb, protein_atom_emb, emb_dim):
|
||||
h_ligand = ligand_atom_emb(compose_feature[idx_ligand], compose_pos[idx_ligand])
|
||||
h_protein = protein_atom_emb(compose_feature[idx_protein], compose_pos[idx_protein])
|
||||
|
||||
h_sca = torch.zeros([len(compose_pos), emb_dim[0]],).to(h_ligand[0])
|
||||
h_vec = torch.zeros([len(compose_pos), emb_dim[1], 3],).to(h_ligand[1])
|
||||
h_sca[idx_ligand], h_sca[idx_protein] = h_ligand[0], h_protein[0]
|
||||
h_vec[idx_ligand], h_vec[idx_protein] = h_ligand[1], h_protein[1]
|
||||
return [h_sca, h_vec]
|
||||
|
||||
|
||||
class SmoothCrossEntropyLoss(_WeightedLoss):
|
||||
def __init__(self, weight=None, reduction='mean', smoothing=0.0):
|
||||
super().__init__(weight=weight, reduction=reduction)
|
||||
self.smoothing = smoothing
|
||||
self.weight = weight
|
||||
self.reduction = reduction
|
||||
|
||||
@staticmethod
|
||||
def _smooth_one_hot(targets:torch.Tensor, n_classes:int, smoothing=0.0):
|
||||
assert 0 <= smoothing < 1
|
||||
with torch.no_grad():
|
||||
targets = torch.empty(size=(targets.size(0), n_classes),
|
||||
device=targets.device) \
|
||||
.fill_(smoothing /(n_classes-1)) \
|
||||
.scatter_(1, targets.data.unsqueeze(1), 1.-smoothing)
|
||||
return targets
|
||||
|
||||
def forward(self, inputs, targets):
|
||||
targets = SmoothCrossEntropyLoss._smooth_one_hot(targets, inputs.size(-1),
|
||||
self.smoothing)
|
||||
lsm = F.log_softmax(inputs, -1)
|
||||
|
||||
if self.weight is not None:
|
||||
lsm = lsm * self.weight.unsqueeze(0)
|
||||
|
||||
loss = -(targets * lsm).sum(-1)
|
||||
|
||||
if self.reduction == 'sum':
|
||||
loss = loss.sum()
|
||||
elif self.reduction == 'mean':
|
||||
loss = loss.mean()
|
||||
return loss
|
||||
@@ -1,201 +0,0 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from .net_utils import AtomEmbedding, embed_compose
|
||||
from .encoder import ContextEncoder
|
||||
from .atom_flow import AtomFlow
|
||||
from .bond_predictor import BondFlowNew
|
||||
from .position_predictor import PositionPredictor
|
||||
from .pos_filter import PositionFilter
|
||||
from .focal_net import FrontierLayerVN
|
||||
from easydict import EasyDict
|
||||
from torch_scatter import scatter_add
|
||||
#import sys
|
||||
#sys.path.append("..")
|
||||
#from utils import get_tri_edges
|
||||
#from generate_utils import add_ligand_atom_to_data, data2mol
|
||||
#from rdkit import Chem
|
||||
|
||||
|
||||
encoder_cfg = EasyDict(
|
||||
{'edge_channels':8, 'num_interactions':6,
|
||||
'knn':32, 'cutoff':10.0}
|
||||
)
|
||||
focal_net_cfg = EasyDict(
|
||||
{'hidden_dim_sca':32, 'hidden_dim_vec':8}
|
||||
)
|
||||
atom_flow_cfg = EasyDict(
|
||||
{'hidden_dim_sca':32, 'hidden_dim_vec':8, 'num_flow_layers':6}
|
||||
)
|
||||
pos_predictor_cfg = EasyDict(
|
||||
{'num_filters':[64,64], 'n_component':3}
|
||||
)
|
||||
pos_filter_cfg = EasyDict(
|
||||
{'edge_channels':8, 'num_filters':[32,16]}
|
||||
)
|
||||
edge_flow_cfg = EasyDict(
|
||||
{'edge_channels':8, 'num_filters':[32,8], 'num_bond_types':3,
|
||||
'num_heads':2, 'cutoff':10.0, 'num_flow_layers':3}
|
||||
)
|
||||
config = EasyDict(
|
||||
{'deq_coeff':0.9, 'hidden_channels':32, 'hidden_channels_vec':8, 'bottleneck':8, 'use_conv1d':False,
|
||||
'encoder':encoder_cfg, 'atom_flow':atom_flow_cfg, 'pos_predictor':pos_predictor_cfg,
|
||||
'pos_filter':pos_filter_cfg, 'edge_flow':edge_flow_cfg, 'focal_net':focal_net_cfg}
|
||||
)
|
||||
|
||||
|
||||
class PocketFlowWithEdgeNew(nn.Module):
|
||||
def __init__(self, config) -> None:
|
||||
super(PocketFlowWithEdgeNew, self).__init__()
|
||||
self.config = config
|
||||
self.num_bond_types = config.num_bond_types
|
||||
self.msg_annealing = config.msg_annealing
|
||||
self.emb_dim = [config.hidden_channels, config.hidden_channels_vec]
|
||||
self.protein_atom_emb = AtomEmbedding(config.protein_atom_feature_dim, 1, *self.emb_dim)
|
||||
self.ligand_atom_emb = AtomEmbedding(config.ligand_atom_feature_dim, 1, *self.emb_dim)
|
||||
self.atom_type_embedding = nn.Embedding(config.num_atom_type, config.hidden_channels)
|
||||
|
||||
self.encoder = ContextEncoder(
|
||||
hidden_channels=self.emb_dim, edge_channels=config.encoder.edge_channels,
|
||||
num_edge_types=config.num_bond_types, num_interactions=config.encoder.num_interactions,
|
||||
k=config.encoder.knn, cutoff=config.encoder.cutoff, bottleneck=config.bottleneck,
|
||||
use_conv1d=config.use_conv1d, num_heads=config.encoder.num_heads
|
||||
)
|
||||
self.focal_net = FrontierLayerVN(
|
||||
self.emb_dim[0], self.emb_dim[1], config.focal_net.hidden_dim_sca,
|
||||
config.focal_net.hidden_dim_vec, bottleneck=config.bottleneck,
|
||||
use_conv1d=config.use_conv1d
|
||||
)
|
||||
self.atom_flow = AtomFlow(
|
||||
self.emb_dim[0], self.emb_dim[1], config.atom_flow.hidden_dim_sca,
|
||||
config.atom_flow.hidden_dim_vec, num_lig_atom_type=config.num_atom_type,
|
||||
num_flow_layers=config.atom_flow.num_flow_layers, bottleneck=config.bottleneck,
|
||||
use_conv1d=config.use_conv1d
|
||||
)
|
||||
self.pos_predictor = PositionPredictor(
|
||||
self.emb_dim[0], self.emb_dim[1], config.pos_predictor.num_filters,
|
||||
config.pos_predictor.n_component, bottleneck=config.bottleneck,
|
||||
use_conv1d=config.use_conv1d
|
||||
)
|
||||
'''self.pos_filter = PositionFilter(
|
||||
self.emb_dim[0], self.emb_dim[1], config.pos_filter.edge_channels,
|
||||
config.pos_filter.num_filters, bottleneck=config.bottleneck,
|
||||
use_conv1d=config.use_conv1d
|
||||
)'''
|
||||
self.edge_flow = BondFlowNew(
|
||||
self.emb_dim[0], self.emb_dim[1], config.edge_flow.edge_channels,
|
||||
config.edge_flow.num_filters, config.edge_flow.num_bond_types,
|
||||
num_heads=config.edge_flow.num_heads, cutoff=config.edge_flow.cutoff,
|
||||
num_st_layers=config.edge_flow.num_flow_layers, bottleneck=config.bottleneck,
|
||||
use_conv1d=config.use_conv1d
|
||||
)
|
||||
|
||||
def get_parameter_number(self):
|
||||
total_num = sum(p.numel() for p in self.parameters())
|
||||
trainable_num = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
||||
return {'Total': total_num, 'Trainable': trainable_num}
|
||||
|
||||
def get_loss(self, data):
|
||||
h_cpx = embed_compose(data.cpx_feature.float(), data.cpx_pos, data.idx_ligand_ctx_in_cpx,
|
||||
data.idx_protein_in_cpx, self.ligand_atom_emb,
|
||||
self.protein_atom_emb, self.emb_dim)
|
||||
# encoding context
|
||||
h_cpx = self.encoder(
|
||||
node_attr = h_cpx,
|
||||
pos = data.cpx_pos,
|
||||
edge_index = data.cpx_edge_index,
|
||||
edge_feature = data.cpx_edge_feature,
|
||||
annealing=self.msg_annealing
|
||||
)
|
||||
# for focal loss
|
||||
focal_pred = self.focal_net(h_cpx, data.idx_ligand_ctx_in_cpx)
|
||||
focal_loss = F.binary_cross_entropy_with_logits(
|
||||
input=focal_pred, target=data.ligand_frontier.view(-1, 1).float()
|
||||
)
|
||||
# for focal loss in protein
|
||||
focal_pred_apo = self.focal_net(h_cpx, data.apo_protein_idx)
|
||||
surf_loss = F.binary_cross_entropy_with_logits(
|
||||
input=focal_pred_apo, target=data.candidate_focal_label_in_protein.view(-1, 1).float()
|
||||
)
|
||||
# for atom loss
|
||||
x_z = F.one_hot(data.atom_label, num_classes=self.config.num_atom_type).float() #[50,27]
|
||||
x_z += self.config.deq_coeff * torch.rand(x_z.size(), device=x_z.device) #[50,27]
|
||||
z_atom, atom_log_jacob = self.atom_flow(x_z, h_cpx, data.focal_idx_in_context)
|
||||
ll_atom = (1/2 * (z_atom ** 2) - atom_log_jacob).sum(-1)
|
||||
ll_atom = scatter_add(ll_atom, data.atom_label_batch, dim=0).mean()
|
||||
|
||||
# for position loss
|
||||
atom_type_emb = self.atom_type_embedding(data.atom_label)
|
||||
relative_mu, abs_mu, sigma, pi = self.pos_predictor(
|
||||
h_cpx, # +atom_type_emb[data.step_batch]
|
||||
data.focal_idx_in_context,
|
||||
data.cpx_pos,
|
||||
atom_type_emb=atom_type_emb,
|
||||
)
|
||||
#y_pos = torch.rand_like(data.y_pos) * 0.05 + data.y_pos
|
||||
loss_pos = -torch.log(
|
||||
self.pos_predictor.get_mdn_probability(abs_mu, sigma, pi, data.y_pos) + 1e-16
|
||||
).mean()#.clamp_max(10.) # 最大似然log
|
||||
#loss_pos = scatter_add(loss_pos, data.atom_label_batch, dim=0).mean()
|
||||
|
||||
# for edge loss
|
||||
z_edge = F.one_hot(data.edge_label, num_classes=4).float()
|
||||
z_edge += self.config.deq_coeff * torch.rand(z_edge.size(), device=z_edge.device)
|
||||
edge_index_query = torch.stack([data.edge_query_index_0, data.edge_query_index_1])
|
||||
pos_query_knn_edge_idx = torch.stack(
|
||||
[data.pos_query_knn_edge_idx_0, data.pos_query_knn_edge_idx_1]
|
||||
)
|
||||
z_edge, edge_log_jacob = self.edge_flow(
|
||||
z_edge=z_edge,
|
||||
pos_query=data.y_pos,
|
||||
edge_index_query=edge_index_query,
|
||||
cpx_pos=data.cpx_pos,
|
||||
node_attr_compose=h_cpx,
|
||||
edge_index_q_cps_knn=pos_query_knn_edge_idx,
|
||||
index_real_cps_edge_for_atten=data.index_real_cps_edge_for_atten,
|
||||
tri_edge_index=data.tri_edge_index,
|
||||
tri_edge_feat=data.tri_edge_feat,
|
||||
atom_type_emb=atom_type_emb,
|
||||
annealing=self.msg_annealing
|
||||
)
|
||||
ll_edge = (1/2 * (z_edge ** 2) - edge_log_jacob).sum(-1)
|
||||
ll_edge = scatter_add(ll_edge, data.edge_label_batch, dim=0).mean()
|
||||
|
||||
# pos filter
|
||||
'''pos_fake_knn_edge_idx = torch.stack([data.pos_fake_knn_edge_idx_0, data.pos_fake_knn_edge_idx_1])
|
||||
pos_fake_pred, _ = self.pos_filter(
|
||||
pos_query=data.pos_fake,
|
||||
edge_index_q_cps_knn=pos_fake_knn_edge_idx,
|
||||
cpx_pos=data.cpx_pos,
|
||||
node_attr_compose=h_cpx,
|
||||
annealing=annealing
|
||||
)
|
||||
loss_fake = F.binary_cross_entropy_with_logits(
|
||||
input=pos_fake_pred, target=torch.zeros_like(pos_fake_pred)
|
||||
)
|
||||
pos_real_knn_edge_idx = torch.stack([data.pos_real_knn_edge_idx_0, data.pos_real_knn_edge_idx_1])
|
||||
pos_real_pred, _ = self.pos_filter(
|
||||
pos_query=data.pos_real,
|
||||
edge_index_q_cps_knn=pos_real_knn_edge_idx,
|
||||
cpx_pos=data.cpx_pos,
|
||||
node_attr_compose=h_cpx,
|
||||
annealing=annealing
|
||||
)
|
||||
loss_real = F.binary_cross_entropy_with_logits(
|
||||
input=pos_real_pred, target=torch.ones_like(pos_real_pred)
|
||||
)'''
|
||||
# loss all
|
||||
loss = torch.nan_to_num(ll_atom)\
|
||||
+ torch.nan_to_num(loss_pos)\
|
||||
+ torch.nan_to_num(ll_edge)\
|
||||
+ torch.nan_to_num(focal_loss)\
|
||||
+ torch.nan_to_num(surf_loss)\
|
||||
#+ torch.nan_to_num(loss_fake)\
|
||||
#+ torch.nan_to_num(loss_real)\
|
||||
|
||||
out_dict = {
|
||||
'loss':loss, 'loss_atom':ll_atom, 'loss_edge':ll_edge, #'loss_fake':loss_fake, 'loss_real':loss_real,
|
||||
'loss_pos':loss_pos, 'focal_loss':focal_loss, 'surf_loss':torch.nan_to_num(surf_loss)
|
||||
}
|
||||
return out_dict
|
||||
|
||||
@@ -1,92 +0,0 @@
|
||||
import torch
|
||||
from torch.nn import Module, Sequential
|
||||
from torch.nn import functional as F
|
||||
from .layers import GBPerceptronVN, GBLinear
|
||||
import math
|
||||
|
||||
GAUSSIAN_COEF = 1.0 / math.sqrt(2 * math.pi)
|
||||
|
||||
class PositionPredictor(Module):
|
||||
def __init__(self, in_sca, in_vec, num_filters, n_component, bottleneck=1, use_conv1d=False):
|
||||
super(PositionPredictor, self).__init__()
|
||||
self.n_component = n_component
|
||||
self.gvp = Sequential(
|
||||
GBPerceptronVN(
|
||||
in_sca*2, in_vec, num_filters[0], num_filters[1], bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
),
|
||||
GBLinear(
|
||||
num_filters[0], num_filters[1], num_filters[0], num_filters[1], bottleneck=bottleneck,
|
||||
use_conv1d=use_conv1d
|
||||
)
|
||||
)
|
||||
self.mu_net = GBLinear(
|
||||
num_filters[0], num_filters[1], n_component, n_component, bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
self.logsigma_net= GBLinear(
|
||||
num_filters[0], num_filters[1], n_component, n_component, bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
self.pi_net = GBLinear(
|
||||
num_filters[0], num_filters[1], n_component, 1, bottleneck=bottleneck, use_conv1d=use_conv1d
|
||||
)
|
||||
|
||||
def forward(self, h_compose, idx_focal, pos_compose, atom_type_emb=None):#, focal_step_batch=None):
|
||||
h_focal = [h[idx_focal] for h in h_compose]
|
||||
pos_focal = pos_compose[idx_focal]
|
||||
if isinstance(atom_type_emb, torch.Tensor):# and isinstance(focal_step_batch, torch.Tensor):
|
||||
h_focal[0] = torch.cat([h_focal[0], atom_type_emb], dim=1) # 可以直接相乘减少一些训练参数
|
||||
#h_focal = [h_focal[0] * atom_type_emb, h_focal[1]]
|
||||
|
||||
feat_focal = self.gvp(h_focal)
|
||||
relative_mu = self.mu_net(feat_focal)[1] # (N_focal, n_component, 3)
|
||||
logsigma = self.logsigma_net(feat_focal)[1] # (N_focal, n_component, 3)
|
||||
sigma = torch.exp(logsigma)
|
||||
pi = self.pi_net(feat_focal)[0] # (N_focal, n_component)
|
||||
pi = F.softmax(pi, dim=1)
|
||||
|
||||
abs_mu = relative_mu + pos_focal.unsqueeze(dim=1).expand_as(relative_mu)
|
||||
return relative_mu, abs_mu, sigma, pi
|
||||
|
||||
def get_mdn_probability(self, mu, sigma, pi, pos_target):
|
||||
prob_gauss = self._get_gaussian_probability(mu, sigma, pos_target)
|
||||
prob_mdn = pi * prob_gauss
|
||||
prob_mdn = torch.sum(prob_mdn, dim=1)
|
||||
return prob_mdn
|
||||
|
||||
def _get_gaussian_probability(self, mu, sigma, pos_target):
|
||||
"""
|
||||
mu - (N, n_component, 3)
|
||||
sigma - (N, n_component, 3)
|
||||
pos_target - (N, 3)
|
||||
"""
|
||||
target = pos_target.unsqueeze(1).expand_as(mu)
|
||||
errors = target - mu # 最大概率似然没有不变性,使用L2范数?
|
||||
sigma = sigma + 1e-16
|
||||
p = GAUSSIAN_COEF * torch.exp(-0.5 * (errors / sigma)**2) / sigma
|
||||
p = torch.prod(p, dim=2)
|
||||
return p # (N, n_component)
|
||||
|
||||
def sample_batch(self, mu, sigma, pi, num):
|
||||
"""sample from multiple mix gaussian
|
||||
mu - (N_batch, n_cat, 3)
|
||||
sigma - (N_batch, n_cat, 3)
|
||||
pi - (N_batch, n_cat)
|
||||
return
|
||||
(N_batch, num, 3)
|
||||
"""
|
||||
index_cats = torch.multinomial(pi, num, replacement=True) # (N_batch, num)
|
||||
# index_cats = index_cats.unsqueeze(-1)
|
||||
index_batch = torch.arange(len(mu)).unsqueeze(-1).expand(-1, num) # (N_batch, num)
|
||||
mu_sample = mu[index_batch, index_cats] # (N_batch, num, 3)
|
||||
sigma_sample = sigma[index_batch, index_cats]
|
||||
values = torch.normal(mu_sample, sigma_sample) # (N_batch, num, 3)
|
||||
return values
|
||||
|
||||
def get_maximum(self, mu, sigma, pi):
|
||||
"""sample from multiple mix gaussian
|
||||
mu - (N_batch, n_cat, 3)
|
||||
sigma - (N_batch, n_cat, 3)
|
||||
pi - (N_batch, n_cat)
|
||||
return
|
||||
(N_batch, n_cat, 3)
|
||||
"""
|
||||
return mu
|
||||
Reference in New Issue
Block a user