Delete pocket_flow/gdbp_model/.ipynb_checkpoints directory

This commit is contained in:
Saoge123
2023-06-16 18:06:00 +08:00
committed by GitHub
parent ab21175d39
commit 0035092502
8 changed files with 0 additions and 1493 deletions

View File

@@ -1,11 +0,0 @@
from .atom_flow import AtomFlow
from .bond_predictor import BondPredictor, BondFlow, BondFlowNew
from .position_predictor import PositionPredictor
from .focal_net import FrontierLayerVN
from .encoder import ContextEncoder
from .pos_filter import PositionEncoder, PositionFilter
from .layers import *
from .net_utils import *
from .pocket_flow_with_edge import *
from .pocket_flow_with_edge_new import *
from .pocket_flow import *

View File

@@ -1,52 +0,0 @@
import torch
from torch import nn
from torch.nn import functional as F
from .layers import GBPerceptronVN, GBLinear, ST_GBP_Exp
class AtomFlow(nn.Module):
def __init__(self, in_sca, in_vec, hidden_dim_sca, hidden_dim_vec, num_lig_atom_type=10,
num_flow_layers=6, bottleneck=1, use_conv1d=False) -> None:
super(AtomFlow, self).__init__()
'''self.msg_module = MessageModule(
in_sca, in_vec, in_sca, in_vec, hidden_dim_sca, hidden_dim_vec, cutoff=10.
)'''
self.net = nn.Sequential(
GBPerceptronVN(
in_sca, in_vec, hidden_dim_sca, hidden_dim_vec, bottleneck=bottleneck, use_conv1d=use_conv1d
),
GBLinear(
hidden_dim_sca, hidden_dim_vec, hidden_dim_sca, hidden_dim_vec, bottleneck=bottleneck,
use_conv1d=use_conv1d
)
)
self.flow_layers = nn.ModuleList()
for _ in enumerate(range(num_flow_layers)):
layer = ST_GBP_Exp(
hidden_dim_sca, hidden_dim_vec, num_lig_atom_type, hidden_dim_vec, bottleneck=bottleneck,
use_conv1d=use_conv1d
)
self.flow_layers.append(layer)
def forward(self, z_atom, compose_features, focal_idx):
sca_focal, vec_focal = compose_features[0][focal_idx], compose_features[1][focal_idx]
sca_focal, vec_focal = self.net([sca_focal, vec_focal])
for ix in range(len(self.flow_layers)):
s, t = self.flow_layers[ix]([sca_focal, vec_focal])
s = s.exp()
z_atom = (z_atom + t) * s
if ix == 0:
atom_log_jacob = (torch.abs(s) + 1e-20).log()
else:
atom_log_jacob += (torch.abs(s) + 1e-20).log()
return z_atom, atom_log_jacob
def reverse(self, atom_latent, compose_features, focal_idx):
sca_focal, vec_focal = compose_features[0][focal_idx], compose_features[1][focal_idx]
sca_focal, vec_focal = self.net([sca_focal, vec_focal])
for ix in range(len(self.flow_layers)):
s, t = self.flow_layers[ix]([sca_focal, vec_focal])
atom_latent = (atom_latent / s.exp()) - t
return atom_latent

View File

@@ -1,440 +0,0 @@
import torch
from torch import nn
from torch.nn import Module, Sequential
from torch.nn import functional as F
from torch_scatter import scatter_add
#from .pos_filter import PositionEncoder
from .layers import GBPerceptronVN, GBLinear, MessageModule, MessageAttention, AttentionEdges, ST_GBP_Exp, VNLeakyReLU
from .net_utils import GaussianSmearing, EdgeExpansion
import math
GAUSSIAN_COEF = 1.0 / math.sqrt(2 * math.pi)
class BondPredictor(Module):
def __init__(self, in_sca, in_vec, edge_channels, num_filters, num_bond_types,
num_heads=4, cutoff=10.0, with_root=True, bottleneck=1):
super(BondPredictor, self).__init__()
self.with_root = with_root
self.num_bond_types = num_bond_types
self.message_module = MessageModule(
in_sca, in_vec, edge_channels, edge_channels, num_filters[0], num_filters[1], cutoff=cutoff,
bottleneck=bottleneck
)
self.nn_edge_ij = Sequential(
GBPerceptronVN(edge_channels, edge_channels, num_filters[0], num_filters[1], bottleneck=bottleneck),
GBLinear(num_filters[0], num_filters[1], num_filters[0], num_filters[1], bottleneck=bottleneck)
)
self.edge_feat = Sequential(
GBPerceptronVN(num_filters[0] * 2 + in_sca, num_filters[1] * 2 + in_vec, num_filters[0], num_filters[1],
bottleneck=bottleneck),
GBLinear(num_filters[0], num_filters[1], num_filters[0], num_filters[1], bottleneck=bottleneck)
)
self.edge_atten = AttentionEdges(
num_filters, num_filters, num_heads, num_bond_types, bottleneck=bottleneck
)
self.edge_pred = GBLinear(num_filters[0], num_filters[1], num_bond_types + 1, 1, bottleneck=bottleneck)
if with_root:
self.root_lin = GBLinear(in_sca, in_vec, num_filters[0], num_filters[1], bottleneck=bottleneck)
self.root_vector_expansion = EdgeExpansion(edge_channels)
self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_channels)
self.distance_expansion_3A = GaussianSmearing(stop=3., num_gaussians=edge_channels)
self.vector_expansion = EdgeExpansion(edge_channels) # Linear(in_features=1, out_features=edge_channels, bias=False)
def forward(self, pos_query, edge_index_query, cpx_pos, node_attr_compose, edge_index_q_cps_knn,
index_real_cps_edge_for_atten=[], tri_edge_index=[], tri_edge_feat=[], atom_type_emb=None):
vec_ij = pos_query[edge_index_q_cps_knn[0]] - cpx_pos[edge_index_q_cps_knn[1]]
dist_ij = torch.norm(vec_ij, p=2, dim=-1).view(-1, 1) # (A, 1)
edge_ij = self.distance_expansion(dist_ij), self.vector_expansion(vec_ij)
# node_attr_ctx_j = [node_attr_ctx_[edge_index_q_cps_knn[1]] for node_attr_ctx_ in node_attr_ctx] # (A, H)
h = self.message_module(node_attr_compose, edge_ij, edge_index_q_cps_knn[1], dist_ij, annealing=True)
# Aggregate messages
y = [scatter_add(h[0], index=edge_index_q_cps_knn[0], dim=0, dim_size=pos_query.size(0)), # (N_query, F)
scatter_add(h[1], index=edge_index_q_cps_knn[0], dim=0, dim_size=pos_query.size(0))]
# add information of new atom
if isinstance(atom_type_emb, torch.Tensor):
root_vec_ij = self.root_vector_expansion(pos_query)
y_root_sca, y_root_vec = self.root_lin([atom_type_emb, root_vec_ij])
y = [y_root_sca+y[0], y_root_vec+y[1]]
if (len(edge_index_query) != 0) and (edge_index_query.size(1) > 0):
# print(edge_index_query.shape)
idx_node_i = edge_index_query[0]
node_mol_i = [
y[0][idx_node_i],
y[1][idx_node_i]
]
idx_node_j = edge_index_query[1]
node_mol_j = [
node_attr_compose[0][idx_node_j],
node_attr_compose[1][idx_node_j]
]
vec_ij = pos_query[idx_node_i] - cpx_pos[idx_node_j]
dist_ij = torch.norm(vec_ij, p=2, dim=-1).view(-1, 1) # (E, 1)
edge_ij = self.distance_expansion_3A(dist_ij), self.vector_expansion(vec_ij)
edge_feat = self.nn_edge_ij(edge_ij) # (E, F)
edge_attr = (torch.cat([node_mol_i[0], node_mol_j[0], edge_feat[0]], dim=-1), # (E, F)
torch.cat([node_mol_i[1], node_mol_j[1], edge_feat[1]], dim=1))
edge_attr = self.edge_feat(edge_attr) # (E, N_edgetype)
edge_attr = self.edge_atten(edge_attr, edge_index_query, cpx_pos, index_real_cps_edge_for_atten, tri_edge_index, tri_edge_feat)
edge_pred, _ = self.edge_pred(edge_attr)
else:
edge_pred = torch.empty([0, self.num_bond_types+1], device=pos_query.device)
return edge_pred
class ST_AttEdge_Exp(torch.nn.Module):
def __init__(self, in_sca, in_vec, edge_channels, num_filters, num_bond_types=3,
num_heads=4, bottleneck=1, use_conv1d=False):
super(ST_AttEdge_Exp, self).__init__()
self.num_bond_types = num_bond_types
self.nn_edge_ij = Sequential(
GBPerceptronVN(
edge_channels, edge_channels, num_filters[0], num_filters[1],
bottleneck=bottleneck, use_conv1d=use_conv1d
),
GBLinear(
num_filters[0], num_filters[1], num_filters[0], num_filters[1],
bottleneck=bottleneck, use_conv1d=use_conv1d
)
)
self.edge_feat = Sequential(
GBPerceptronVN(
num_filters[0] * 2 + in_sca, num_filters[1] * 2 + in_vec, num_filters[0],
num_filters[1], bottleneck=bottleneck, use_conv1d=use_conv1d
),
GBLinear(
num_filters[0], num_filters[1], num_filters[0], num_filters[1],
bottleneck=bottleneck, use_conv1d=use_conv1d
)
)
self.edge_atten = AttentionEdges(num_filters, num_filters, num_heads, num_bond_types)
self.edge_pred = GBLinear(
num_filters[0], num_filters[1], (num_bond_types + 1) * 2, 1, bottleneck=bottleneck,
use_conv1d=use_conv1d
)
self.distance_expansion_3A = GaussianSmearing(stop=3., num_gaussians=edge_channels)
self.vector_expansion = EdgeExpansion(edge_channels) # Linear(in_features=1, out_features=edge_channels, bias=False)
def forward(self, h_atom, pos_query, edge_index_query, cpx_pos, node_attr_compose, index_real_cps_edge_for_atten=[],
tri_edge_index=[], tri_edge_feat=[]):
if (len(edge_index_query) != 0) and (edge_index_query.size(1) > 0):
# print(edge_index_query.shape)
idx_node_i = edge_index_query[0]
node_mol_i = [
h_atom[0][idx_node_i],
h_atom[1][idx_node_i]
]
idx_node_j = edge_index_query[1]
node_mol_j = [
node_attr_compose[0][idx_node_j],
node_attr_compose[1][idx_node_j]
]
vec_ij = pos_query[idx_node_i] - cpx_pos[idx_node_j]
dist_ij = torch.norm(vec_ij, p=2, dim=-1).view(-1, 1) # (E, 1)
edge_ij = self.distance_expansion_3A(dist_ij), self.vector_expansion(vec_ij)
edge_feat = self.nn_edge_ij(edge_ij) # (E, F)
edge_attr = (torch.cat([node_mol_i[0], node_mol_j[0], edge_feat[0]], dim=-1), # (E, F)
torch.cat([node_mol_i[1], node_mol_j[1], edge_feat[1]], dim=1))
edge_attr = self.edge_feat(edge_attr) # (E, N_edgetype)
edge_attr = self.edge_atten(edge_attr, edge_index_query, cpx_pos, index_real_cps_edge_for_atten, tri_edge_index, tri_edge_feat)
edge_pred, _ = self.edge_pred(edge_attr)
s_edge, t_edge = edge_pred[:,:self.num_bond_types+1], edge_pred[:,self.num_bond_types+1:]
else:
s_edge = torch.empty([0, self.num_bond_types+1], device=pos_query.device)
t_edge = torch.empty([0, self.num_bond_types+1], device=pos_query.device)
return s_edge, t_edge
class BondFlow(torch.nn.Module):
def __init__(self, in_sca, in_vec, edge_channels, num_filters, num_bond_types,
num_heads=4, cutoff=10.0, with_root=True, num_st_layers=3,
bottleneck=1, use_conv1d=False):
super(BondFlow, self).__init__()
self.with_root = with_root
self.num_bond_types = num_bond_types
self.num_st_layers = num_st_layers
self.pos_encoder = PositionEncoder(
in_sca, in_vec, edge_channels, num_filters, cutoff=cutoff, bottleneck=bottleneck,
use_conv1d=use_conv1d
)
self.pos_filter = torch.nn.Sequential(
GBPerceptronVN(
num_filters[0], num_filters[1], num_filters[0], num_filters[1],
bottleneck=bottleneck, use_conv1d=use_conv1d
),
GBLinear(
num_filters[0], num_filters[1], 1, 1, bottleneck=bottleneck, use_conv1d=use_conv1d
)
)
if with_root:
self.root_lin = GBLinear(
in_sca, in_vec, num_filters[0], num_filters[1], bottleneck=bottleneck, use_conv1d=use_conv1d
)
self.root_vector_expansion = EdgeExpansion(edge_channels)
self.flow_layers = torch.nn.ModuleList()
for _ in range(num_st_layers):
flow_layer = ST_AttEdge_Exp(
in_sca, in_vec, edge_channels,
num_filters, num_bond_types=num_bond_types, num_heads=num_heads,
bottleneck=bottleneck, use_conv1d=use_conv1d
)
self.flow_layers.append(flow_layer)
def forward(self, z_edge, pos_query, edge_index_query, cpx_pos, node_attr_compose, edge_index_q_cps_knn,
index_real_cps_edge_for_atten=[], tri_edge_index=[], tri_edge_feat=[], atom_type_emb=None,
annealing=False):
y = self.pos_encoder(
pos_query, edge_index_q_cps_knn, cpx_pos, node_attr_compose, atom_type_emb,
annealing=annealing
)
'''if isinstance(atom_type_emb, torch.Tensor) and self.with_root:
root_vec_ij = self.root_vector_expansion(pos_query)
y_root_sca, y_root_vec = self.root_lin([atom_type_emb, root_vec_ij])
y = [y_root_sca+y[0], y_root_vec+y[1]]'''
for ix in range(self.num_st_layers):
s, t = self.flow_layers[ix](
y, pos_query, edge_index_query, cpx_pos, node_attr_compose,
index_real_cps_edge_for_atten=index_real_cps_edge_for_atten,
tri_edge_index=tri_edge_index, tri_edge_feat=tri_edge_feat
)
s = s.exp()
z_edge = (z_edge + t) * s
if ix == 0:
edge_log_jacob = (torch.abs(s) + 1e-20).log()
else:
edge_log_jacob += (torch.abs(s) + 1e-20).log()
return z_edge, edge_log_jacob
def reverse(self, edge_latent, pos_query, edge_index_query, cpx_pos, node_attr_compose, edge_index_q_cps_knn,
index_real_cps_edge_for_atten=[], tri_edge_index=[], tri_edge_feat=[], atom_type_emb=None,
annealing=False):
y = self.pos_encoder(
pos_query, edge_index_q_cps_knn, cpx_pos, node_attr_compose, atom_type_emb,
annealing=annealing
)
'''if isinstance(atom_type_emb, torch.Tensor) and self.with_root:
root_vec_ij = self.root_vector_expansion(pos_query)
y_root_sca, y_root_vec = self.root_lin([atom_type_emb, root_vec_ij])
y = [y_root_sca+y[0], y_root_vec+y[1]]'''
for ix in range(self.num_st_layers):
s, t = self.flow_layers[ix](
y, pos_query, edge_index_query, cpx_pos, node_attr_compose,
index_real_cps_edge_for_atten=index_real_cps_edge_for_atten,
tri_edge_index=tri_edge_index, tri_edge_feat=tri_edge_feat
)
if s.size(0)==0 and t.size(0)==0:
break
else:
edge_latent = (edge_latent / s.exp()) - t
if s.size(0)==0 and t.size(0)==0:
return torch.empty([0, self.num_bond_types+1], device=pos_query.device)
else:
return edge_latent
def pos_classfier(self, pos_query, edge_index_q_cps_knn, cpx_pos, node_attr_compose, annealing=False):
y = self.pos_encoder(
pos_query, edge_index_q_cps_knn, cpx_pos, node_attr_compose, annealing=annealing
)
pred = self.pos_filter(y)
return pred
class PositionEncoder(Module):
def __init__(self, in_sca, in_vec, edge_channels, num_filters, bottleneck=1, cutoff=10.,
num_heads=1, use_conv1d=False, with_root=True) -> None:
super(PositionEncoder, self).__init__()
self.message_module = MessageModule(
in_sca, in_vec, edge_channels, edge_channels, num_filters[0], num_filters[1],
bottleneck, cutoff, use_conv1d=use_conv1d
)
self.message_att = MessageAttention(
num_filters[0], num_filters[1], num_filters[0], num_filters[1], bottleneck=bottleneck,
num_heads=num_heads, use_conv1d=use_conv1d
)
self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_channels)
self.vector_expansion = EdgeExpansion(edge_channels)
self.root_lin = GBLinear(
in_sca, in_vec, num_filters[0], num_filters[1], bottleneck=bottleneck, use_conv1d=use_conv1d
)
self.root_vector_expansion = EdgeExpansion(edge_channels)
#self.act_sca = LeakyReLU()
#self.act_vec = VNLeakyReLU(hidden_channels[1], share_nonlinearity=True) # 2023.1.13
'''self.out_transform = GBLinear(
num_filters[0], num_filters[1], num_filters[0], num_filters[1],
bottleneck=bottleneck, use_conv1d=use_conv1d
)
self.layernorm_sca = nn.LayerNorm([num_filters[0]])
self.layernorm_vec = nn.LayerNorm([num_filters[1], 3])''' # 2023.1.13 注释掉
def forward(self, pos_query, edge_index_q_cps_knn, cpx_pos, node_attr_compose, atom_type_emb,
annealing=False):
vec_ij = pos_query[edge_index_q_cps_knn[0]] - cpx_pos[edge_index_q_cps_knn[1]]
dist_ij = torch.norm(vec_ij, p=2, dim=-1).view(-1, 1) # (A, 1)
edge_ij = self.distance_expansion(dist_ij), self.vector_expansion(vec_ij)
#if isinstance(atom_type_emb, torch.Tensor) and self.with_root:
root_vec_ij = self.root_vector_expansion(pos_query)
y_root_sca, y_root_vec = self.root_lin([atom_type_emb, root_vec_ij])
x = [y_root_sca, y_root_vec]
# node_attr_ctx_j = [node_attr_ctx_[edge_index_q_cps_knn[1]] for node_attr_ctx_ in node_attr_ctx] # (A, H)
h_q = self.message_module(node_attr_compose, edge_ij, edge_index_q_cps_knn[1], dist_ij, annealing=annealing)
y = self.message_att(x, h_q, edge_index_q_cps_knn[0])
'''# non-linear
out_sca = self.layernorm_sca(y[0])
out_vec = self.layernorm_vec(y[1])
y = self.out_transform((self.act_sca(out_sca), self.act_vec(out_vec)))''' # 2023.1.3
return y
class BondFlowNew(Module):
def __init__(self, in_sca, in_vec, edge_channels, num_filters, num_bond_types=3, num_heads=4,
cutoff=10.0, with_root=True, num_st_layers=6, bottleneck=1, use_conv1d=False):
super(BondFlowNew, self).__init__()
self.with_root = with_root
self.num_bond_types = num_bond_types
self.num_st_layers = num_st_layers
## query encoder
self.pos_encoder = PositionEncoder(
in_sca, in_vec, edge_channels, num_filters, cutoff=cutoff, bottleneck=bottleneck,
use_conv1d=use_conv1d
)
## query filter
'''self.pos_filter = torch.nn.Sequential(
GBPerceptronVN(
num_filters[0], num_filters[1], num_filters[0], num_filters[1],
bottleneck=bottleneck, use_conv1d=use_conv1d
),
GBLinear(
num_filters[0], num_filters[1], 1, 1, bottleneck=bottleneck, use_conv1d=use_conv1d
)
)'''
## edge pred
self.distance_expansion_3A = GaussianSmearing(stop=3., num_gaussians=edge_channels)
self.vector_expansion = EdgeExpansion(edge_channels)
self.nn_edge_ij = Sequential(
GBPerceptronVN(edge_channels, edge_channels, num_filters[0], num_filters[1]),
GBLinear(num_filters[0], num_filters[1], num_filters[0], num_filters[1])
)
self.edge_feat = Sequential(
GBPerceptronVN(num_filters[0] * 2 + in_sca, num_filters[1] * 2 + in_vec, num_filters[0], num_filters[1]),
GBLinear(num_filters[0], num_filters[1], num_filters[0], num_filters[1])
)
self.edge_atten = AttentionEdges(num_filters, num_filters, num_heads, num_bond_types)
## flow layer
self.flow_layers = torch.nn.ModuleList()
for _ in range(num_st_layers):
flow_layer = ST_GBP_Exp(
num_filters[0],
num_filters[1],
num_bond_types + 1,
num_filters[1],
bottleneck=bottleneck,
use_conv1d=use_conv1d
)
self.flow_layers.append(flow_layer)
def forward(self, z_edge, pos_query, edge_index_query, cpx_pos, node_attr_compose, edge_index_q_cps_knn,
atom_type_emb, index_real_cps_edge_for_atten=[], tri_edge_index=[], tri_edge_feat=[],
annealing=False):
y = self.pos_encoder(
pos_query, edge_index_q_cps_knn, cpx_pos, node_attr_compose,
atom_type_emb, annealing=annealing
)
if (len(edge_index_query) != 0) and (edge_index_query.size(1) > 0):
idx_node_i = edge_index_query[0]
node_mol_i = [
y[0][idx_node_i],
y[1][idx_node_i]
]
idx_node_j = edge_index_query[1]
node_mol_j = [
node_attr_compose[0][idx_node_j],
node_attr_compose[1][idx_node_j]
]
vec_ij = pos_query[idx_node_i] - cpx_pos[idx_node_j]
dist_ij = torch.norm(vec_ij, p=2, dim=-1).view(-1, 1) # (E, 1)
edge_ij = self.distance_expansion_3A(dist_ij), self.vector_expansion(vec_ij)
edge_feat = self.nn_edge_ij(edge_ij) # (E, F)
edge_attr = (torch.cat([node_mol_i[0], node_mol_j[0], edge_feat[0]], dim=-1), # (E, F)
torch.cat([node_mol_i[1], node_mol_j[1], edge_feat[1]], dim=1))
edge_attr = self.edge_feat(edge_attr)
edge_attr = self.edge_atten(
edge_attr, edge_index_query, cpx_pos, index_real_cps_edge_for_atten,
tri_edge_index, tri_edge_feat
)
#self.edge_atten()
for ix in range(len(self.flow_layers)):
s, t = self.flow_layers[ix](edge_attr)
s = s.exp()
z_edge = (z_edge + t) * s
if ix == 0:
edge_log_jacob = (torch.abs(s) + 1e-20).log()
else:
edge_log_jacob += (torch.abs(s) + 1e-20).log()
return z_edge, edge_log_jacob
else:
z_edge = torch.empty([0, self.num_bond_types+1], device=pos_query.device)
edge_log_jacob = torch.empty([0, self.num_bond_types+1], device=pos_query.device)
return z_edge, edge_log_jacob
def reverse(self, edge_latent, pos_query, edge_index_query, cpx_pos, node_attr_compose, edge_index_q_cps_knn,
atom_type_emb, index_real_cps_edge_for_atten=[], tri_edge_index=[], tri_edge_feat=[],
annealing=False):
y = self.pos_encoder(
pos_query, edge_index_q_cps_knn, cpx_pos, node_attr_compose,
atom_type_emb, annealing=annealing
)
if (len(edge_index_query) != 0) and (edge_index_query.size(1) > 0):
idx_node_i = edge_index_query[0]
node_mol_i = [
y[0][idx_node_i],
y[1][idx_node_i]
]
idx_node_j = edge_index_query[1]
node_mol_j = [
node_attr_compose[0][idx_node_j],
node_attr_compose[1][idx_node_j]
]
vec_ij = pos_query[idx_node_i] - cpx_pos[idx_node_j]
dist_ij = torch.norm(vec_ij, p=2, dim=-1).view(-1, 1) # (E, 1)
edge_ij = self.distance_expansion_3A(dist_ij), self.vector_expansion(vec_ij)
edge_feat = self.nn_edge_ij(edge_ij) # (E, F)
edge_attr = (torch.cat([node_mol_i[0], node_mol_j[0], edge_feat[0]], dim=-1), # (E, F)
torch.cat([node_mol_i[1], node_mol_j[1], edge_feat[1]], dim=1))
edge_attr = self.edge_feat(edge_attr)
edge_attr = self.edge_atten(
edge_attr, edge_index_query, cpx_pos, index_real_cps_edge_for_atten,
tri_edge_index, tri_edge_feat
)
#self.edge_atten()
for ix in range(len(self.flow_layers)):
s, t = self.flow_layers[ix](edge_attr)
edge_latent = (edge_latent / s.exp()) - t
return edge_latent
else:
edge_latent = torch.empty([0, self.num_bond_types+1], device=pos_query.device)
return edge_latent

View File

@@ -1,27 +0,0 @@
import torch
from torch.nn import Module, Sequential
from torch.nn import functional as F
from .layers import GBPerceptronVN, GBLinear
class FrontierLayerVN(Module):
def __init__(self, in_sca, in_vec, hidden_dim_sca, hidden_dim_vec, bottleneck=1,
use_conv1d=False):
super(FrontierLayerVN, self).__init__()
self.net = Sequential(
GBPerceptronVN(
in_sca, in_vec, hidden_dim_sca, hidden_dim_vec, bottleneck=bottleneck,
use_conv1d=use_conv1d
),
GBLinear(
hidden_dim_sca, hidden_dim_vec, 1, 1, bottleneck=bottleneck, use_conv1d=use_conv1d
)
)
def forward(self, h_att, idx_ligans):
h_att_ligand = [h_att[0][idx_ligans], h_att[1][idx_ligans]]
pred = self.net(h_att_ligand)
pred = pred[0]
return pred

View File

@@ -1,477 +0,0 @@
import torch
import torch.nn.functional as F
from torch.nn import Module, Linear, LeakyReLU, ModuleList, LayerNorm
import numpy as np
import torch.nn as nn
from torch_geometric.nn import global_mean_pool
from torch_scatter import scatter_sum, scatter_softmax
from math import pi as PI
from .net_utils import GaussianSmearing, EdgeExpansion, Rescale
EPS = 1e-6
class GBLinearConv1D(Module):
def __init__(self, in_scalar, in_vector, out_scalar, out_vector, bottleneck=1, use_conv1d=True):
super(GBLinearConv1D, self).__init__()
assert in_vector % bottleneck == 0,\
f"Input channel of vector ({in_vector}) must be divisible with bottleneck factor ({bottleneck})"
if bottleneck > 1:
self.hidden_dim = in_vector // bottleneck
else:
self.hidden_dim = max(in_vector, out_vector)
self.out_vector = out_vector
self.lin_vector = VNLinear(in_vector, self.hidden_dim, bias=False)
self.lin_vector2 = VNLinear(self.hidden_dim, out_vector, bias=False)
self.use_conv1d = use_conv1d
self.lin_scalar = nn.Conv1d(in_scalar + self.hidden_dim, out_scalar, 1, bias=False)
self.scalar_to_vector_gates = nn.Conv1d(out_scalar, out_vector, 1)
def forward(self, features):
feat_scalar, feat_vector = features
feat_vector_inter = self.lin_vector(feat_vector) # (N_samples, dim_hid, 3)
feat_vector_norm = torch.norm(feat_vector_inter, p=2, dim=-1) # (N_samples, dim_hid)
feat_scalar_cat = torch.cat([feat_vector_norm, feat_scalar], dim=-1) # (N_samples, dim_hid+in_scalar)
out_scalar = self.lin_scalar(feat_scalar_cat.unsqueeze(-1)).squeeze(-1)
gating = torch.sigmoid(self.scalar_to_vector_gates(out_scalar.unsqueeze(-1)))
out_vector = self.lin_vector2(feat_vector_inter)
out_vector = gating * out_vector
return out_scalar, out_vector
'''class GBLinear(Module):
def __init__(self, in_scalar, in_vector, out_scalar, out_vector, bottleneck=1, use_conv1d=False):
super(GBLinear, self).__init__()
assert in_vector % bottleneck == 0,\
f"Input channel of vector ({in_vector}) must be divisible with bottleneck factor ({bottleneck})"
if bottleneck > 1:
self.hidden_dim = in_vector // bottleneck
else:
self.hidden_dim = max(in_vector, out_vector)
self.out_vector = out_vector
self.lin_vector = VNLinear(in_vector, self.hidden_dim, bias=False)
self.lin_vector2 = VNLinear(self.hidden_dim, out_vector, bias=False)
self.use_conv1d = use_conv1d
self.scalar_to_vector_gates = Linear(out_scalar, out_vector)
self.lin_scalar = Linear(in_scalar + self.hidden_dim, out_scalar, bias=False)
def forward(self, features):
feat_scalar, feat_vector = features
feat_vector_inter = self.lin_vector(feat_vector) # (N_samples, dim_hid, 3)
feat_vector_norm = torch.norm(feat_vector_inter, p=2, dim=-1) # (N_samples, dim_hid)
feat_scalar_cat = torch.cat([feat_vector_norm, feat_scalar], dim=-1) # (N_samples, dim_hid+in_scalar)
out_scalar = self.lin_scalar(feat_scalar_cat)
gating = torch.sigmoid(self.scalar_to_vector_gates(out_scalar)).unsqueeze(-1)
out_vector = self.lin_vector2(feat_vector_inter)
out_vector = gating * out_vector
return out_scalar, out_vector'''
class GBLinear(Module):
def __init__(self, in_scalar, in_vector, out_scalar, out_vector, bottleneck=(1,1), use_conv1d=False):
super(GBLinear, self).__init__()
if isinstance(bottleneck, int):
sca_bottleneck = bottleneck
vec_bottleneck = bottleneck
else:
sca_bottleneck = bottleneck[0]
vec_bottleneck = bottleneck[1]
assert in_vector % vec_bottleneck == 0,\
f"Input channel of vector ({in_vector}) must be divisible with bottleneck factor ({vec_bottleneck})"
assert in_scalar % sca_bottleneck == 0,\
f"Input channel of vector ({in_scalar}) must be divisible with bottleneck factor ({sca_bottleneck})"
if sca_bottleneck > 1:
self.sca_hidden_dim = in_scalar // sca_bottleneck
else:
self.sca_hidden_dim = max(in_vector, out_vector)
if vec_bottleneck > 1:
self.hidden_dim = in_vector // vec_bottleneck
else:
self.hidden_dim = max(in_vector, out_vector)
self.out_vector = out_vector
self.lin_vector = VNLinear(in_vector, self.hidden_dim, bias=False)
self.lin_vector2 = VNLinear(self.hidden_dim, out_vector, bias=False)
self.use_conv1d = use_conv1d
self.scalar_to_vector_gates = Linear(out_scalar, out_vector)
self.lin_scalar_1 = Linear(in_scalar, self.sca_hidden_dim, bias=False)
self.lin_scalar_2 = Linear(self.hidden_dim + self.sca_hidden_dim, out_scalar, bias=False)
def forward(self, features):
feat_scalar, feat_vector = features
feat_vector_inter = self.lin_vector(feat_vector) # (N_samples, dim_hid, 3)
feat_vector_norm = torch.norm(feat_vector_inter, p=2, dim=-1) # (N_samples, dim_hid)
z_sca = self.lin_scalar_1(feat_scalar)
feat_scalar_cat = torch.cat([feat_vector_norm, z_sca], dim=-1) # (N_samples, dim_hid+in_scalar)
#z_sca = self.lin_scalar_1(feat_scalar_cat)
out_scalar = self.lin_scalar_2(feat_scalar_cat)
gating = torch.sigmoid(self.scalar_to_vector_gates(out_scalar)).unsqueeze(-1)
out_vector = self.lin_vector2(feat_vector_inter)
out_vector = gating * out_vector
return out_scalar, out_vector
class GBPerceptronVN(Module):
def __init__(self, in_scalar, in_vector, out_scalar, out_vector, bottleneck=1, use_conv1d=False):
super(GBPerceptronVN, self).__init__()
self.gb_linear = GBLinear(
in_scalar, in_vector, out_scalar, out_vector, bottleneck=bottleneck, use_conv1d=use_conv1d
)
self.act_sca = LeakyReLU()
self.act_vec = VNLeakyReLU(out_vector)
def forward(self, x):
sca, vec = self.gb_linear(x)
vec = self.act_vec(vec)
sca = self.act_sca(sca)
return sca, vec
class VNLinear(nn.Module):
def __init__(self, in_channels, out_channels, *args, **kwargs):
super(VNLinear, self).__init__()
self.map_to_feat = nn.Linear(in_channels, out_channels, *args, **kwargs)
def forward(self, x):
'''
x: point features of shape [B, N_samples, N_feat, 3]
'''
x_out = self.map_to_feat(x.transpose(-2,-1)).transpose(-2,-1)
return x_out
class VNLeakyReLU(nn.Module):
def __init__(self, in_channels, share_nonlinearity=False, negative_slope=0.01):
super(VNLeakyReLU, self).__init__()
if share_nonlinearity == True:
self.map_to_dir = nn.Linear(in_channels, 1, bias=False)
else:
self.map_to_dir = nn.Linear(in_channels, in_channels, bias=False)
self.negative_slope = negative_slope
def forward(self, x):
'''
x: point features of shape [B, N_samples, N_feat, 3]
'''
d = self.map_to_dir(x.transpose(-2,-1)).transpose(-2,-1) # (N_samples, N_feat, 3)
dotprod = (x*d).sum(-1, keepdim=True) # sum over 3-value dimension
mask = (dotprod >= 0).to(x.dtype)
d_norm_sq = (d*d).sum(-1, keepdim=True) # sum over 3-value dimension
x_out = (self.negative_slope * x +
(1-self.negative_slope) * (mask*x + (1-mask)*(x-(dotprod/(d_norm_sq+EPS))*d)))
return x_out
class ST_GBP_Exp(nn.Module):
def __init__(self, in_scalar, in_vector, out_scalar, out_vector, bottleneck=1, use_conv1d=False):
super(ST_GBP_Exp, self).__init__()
self.in_scalar = in_scalar
self.in_vector = in_vector
self.out_scalar = out_scalar
self.out_vector = out_vector
self.gb_linear1 = GBLinear(
in_scalar, in_vector, in_scalar, in_vector, bottleneck=bottleneck, use_conv1d=use_conv1d
)
self.gb_linear2 = GBLinear(
in_scalar, in_vector, out_scalar*2, out_vector, bottleneck=bottleneck, use_conv1d=use_conv1d
)
self.act_sca = nn.Tanh()
self.act_vec = VNLeakyReLU(out_vector)
self.rescale = Rescale()
def forward(self, x):
'''
:param x: (batch * repeat_num for node/edge, emb)
:return: w and b for affine operation
'''
sca, vec = self.gb_linear1(x)
sca = self.act_sca(sca)
vec = self.act_vec(vec)
sca, vec = self.gb_linear2([sca, vec])
s = sca[:, :self.out_scalar]
t = sca[:, self.out_scalar:]
s = self.rescale(torch.tanh(s))
return s, t
class MessageAttention(Module):
def __init__(self, in_sca, in_vec, out_sca, out_vec, bottleneck=1, num_heads=1, use_conv1d=False) -> None:
super(MessageAttention, self).__init__()
assert (in_sca % num_heads == 0) and (in_vec % num_heads == 0)
assert (out_sca % num_heads == 0) and (out_vec % num_heads == 0)
self.num_heads =num_heads
self.lin_v = GBLinear(in_sca, in_vec, out_sca, out_vec, bottleneck=bottleneck, use_conv1d=use_conv1d)
self.lin_k = GBLinear(in_sca, in_vec, out_sca, out_vec, bottleneck=bottleneck, use_conv1d=use_conv1d)
def forward(self, x, query, edge_index_i):
N = x[0].size(0)
N_msg = len(edge_index_i)
msg = [
query[0].view(N_msg, self.num_heads, -1),
query[1].view(N_msg, self.num_heads, -1, 3)
]
k = self.lin_k(x)
x_i = [
k[0][edge_index_i].view(N_msg, self.num_heads, -1),
k[1][edge_index_i].view(N_msg, self.num_heads, -1, 3)
]
#alpha_scale = [x_i[0].size(-1)**0.5, x_i[1].size(-2)**0.5]
alpha = [
(msg[0] * x_i[0]).sum(-1), #/alpha_scale[0] # (N', heads)
(msg[1] * x_i[1]).sum(-1).sum(-1) #/alpha_scale[1] # (N', heads)
]
alpha = [
scatter_softmax(alpha[0], edge_index_i, dim=0),
scatter_softmax(alpha[1], edge_index_i, dim=0)
]
msg = [
(alpha[0].unsqueeze(-1) * msg[0]).view(N_msg, -1),
(alpha[1].unsqueeze(-1).unsqueeze(-1) * msg[1]).view(N_msg, -1, 3)
]
sca_msg = scatter_sum(msg[0], edge_index_i, dim=0, dim_size=N)
vec_msg = scatter_sum(msg[1], edge_index_i, dim=0, dim_size=N)
#return sca_msg, vec_msg
root_sca, root_vec = self.lin_v(x)
out_sca = sca_msg + root_sca
out_vec = vec_msg + root_vec
return out_sca, out_vec
class MessageModule(nn.Module):
def __init__(self, node_sca, node_vec, edge_sca, edge_vec, out_sca, out_vec,
bottleneck=1, cutoff=10., use_conv1d=False):
super(MessageModule, self).__init__()
hid_sca, hid_vec = edge_sca, edge_vec
self.cutoff = cutoff
self.node_gblinear = GBLinear(
node_sca, node_vec, out_sca, out_vec, bottleneck=bottleneck, use_conv1d=use_conv1d
)
self.edge_gbp = GBPerceptronVN(
edge_sca, edge_vec, hid_sca, hid_vec, bottleneck=bottleneck, use_conv1d=use_conv1d
)
self.sca_linear = Linear(hid_sca, out_sca) # edge_sca for y_sca
self.e2n_linear = Linear(hid_sca, out_vec)
self.n2e_linear = Linear(out_sca, out_vec)
self.edge_vnlinear = VNLinear(hid_vec, out_vec)
self.out_gblienar = GBLinear(
out_sca, out_vec, out_sca, out_vec, bottleneck=bottleneck, use_conv1d=use_conv1d
)
def forward(self, node_features, edge_features, edge_index_node, dist_ij=None, annealing=False):
node_scalar, node_vector = self.node_gblinear(node_features)
node_scalar, node_vector = node_scalar[edge_index_node], node_vector[edge_index_node]
edge_scalar, edge_vector = self.edge_gbp(edge_features)
y_scalar = node_scalar * self.sca_linear(edge_scalar)
y_node_vector = self.e2n_linear(edge_scalar).unsqueeze(-1) * node_vector
y_edge_vector = self.n2e_linear(node_scalar).unsqueeze(-1) * self.edge_vnlinear(edge_vector)
y_vector = y_node_vector + y_edge_vector
output = self.out_gblienar((y_scalar, y_vector))
if annealing:
C = 0.5 * (torch.cos(dist_ij * PI / self.cutoff) + 1.0) # (A, 1)
C = C * (dist_ij <= self.cutoff) * (dist_ij >= 0.0)
output = [output[0] * C.view(-1, 1), output[1] * C.view(-1, 1, 1)] # (A, 1)
return output
class AttentionInteractionBlockVN(Module):
def __init__(self, hidden_channels, edge_channels, num_edge_types, bottleneck=1, num_heads=1,
cutoff=10., use_conv1d=False):
super(AttentionInteractionBlockVN, self).__init__()
self.num_heads = num_heads
# edge features
self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=edge_channels - num_edge_types)
self.vector_expansion = EdgeExpansion(edge_channels) # Linear(in_features=1, out_features=edge_channels, bias=False)
## compare encoder and classifier message passing
# edge weigths and linear for values
self.message_module = MessageModule(hidden_channels[0], hidden_channels[1], edge_channels, edge_channels,
hidden_channels[0], hidden_channels[1], bottleneck=bottleneck,
cutoff=cutoff, use_conv1d=use_conv1d)
self.msg_att = MessageAttention(hidden_channels[0], hidden_channels[1], hidden_channels[0], hidden_channels[1],
bottleneck=bottleneck, num_heads=num_heads, use_conv1d=use_conv1d)
# centroid nodes and finall linear
self.act_sca = LeakyReLU()
self.act_vec = VNLeakyReLU(hidden_channels[1], share_nonlinearity=True)
self.out_transform = GBLinear(
hidden_channels[0], hidden_channels[1], hidden_channels[0], hidden_channels[1], use_conv1d=use_conv1d,
bottleneck=bottleneck
)
self.layernorm_sca = LayerNorm([hidden_channels[0]])
self.layernorm_vec = LayerNorm([hidden_channels[1], 3])
def forward(self, x, edge_index, edge_feature, edge_vector, edge_dist, annealing=False):
"""
Args:
x: Node features: scalar features (N, feat), vector features(N, feat, 3)
edge_index: (2, E).
edge_attr: (E, H)
"""
scalar, vector = x
N = scalar.size(0)
row, col = edge_index # (E,) , (E,)
# Compute edge features
#edge_dist = torch.norm(edge_vector, dim=-1, p=2)
edge_sca_feat = torch.cat([self.distance_expansion(edge_dist), edge_feature], dim=-1)
edge_vec_feat = self.vector_expansion(edge_vector)
msg_j_sca, msg_j_vec = self.message_module(
x, (edge_sca_feat, edge_vec_feat), col, edge_dist, annealing=annealing
)
out_sca, out_vec = self.msg_att(x, (msg_j_sca, msg_j_vec), row)
# non-linear
out_sca = self.layernorm_sca(out_sca)
out_vec = self.layernorm_vec(out_vec)
out = self.out_transform((self.act_sca(out_sca), self.act_vec(out_vec)))
return out
class AttentionEdges(Module):
def __init__(self, hidden_channels, key_channels, num_heads=1, num_bond_types=3, bottleneck=1,
use_conv1d=False):
super(AttentionEdges, self).__init__()
assert (hidden_channels[0] % num_heads == 0) and (hidden_channels[1] % num_heads == 0)
assert (key_channels[0] % num_heads == 0) and (key_channels[1] % num_heads == 0)
self.hidden_channels = hidden_channels
self.key_channels = key_channels
self.num_heads = num_heads
# linear transformation for attention
self.q_lin = GBLinear(
hidden_channels[0], hidden_channels[1], key_channels[0], key_channels[1],
bottleneck=bottleneck, use_conv1d=use_conv1d
)
self.k_lin = GBLinear(
hidden_channels[0], hidden_channels[1], key_channels[0], key_channels[1],
bottleneck=bottleneck, use_conv1d=use_conv1d
)
self.v_lin = GBLinear(
hidden_channels[0], hidden_channels[1], hidden_channels[0], hidden_channels[1],
bottleneck=bottleneck, use_conv1d=use_conv1d
)
self.atten_bias_lin = AttentionBias(
self.num_heads, hidden_channels, num_bond_types=num_bond_types,
bottleneck=bottleneck, use_conv1d=use_conv1d
)
self.layernorm_sca = LayerNorm([hidden_channels[0]])
self.layernorm_vec = LayerNorm([hidden_channels[1], 3])
def forward(self, edge_attr, edge_index, pos_compose,
index_real_cps_edge_for_atten, tri_edge_index, tri_edge_feat,):
"""
Args:
x: edge features: scalar features (N, feat), vector features(N, feat, 3)
edge_attr: (E, H)
edge_index: (2, E). the row can be seen as batch_edge
"""
scalar, vector = edge_attr
N = scalar.size(0)
row, col = edge_index # (N,)
# Project to multiple key, query and value spaces
h_queries = self.q_lin(edge_attr)
h_queries = (h_queries[0].view(N, self.num_heads, -1), # (N, heads, K_per_head)
h_queries[1].view(N, self.num_heads, -1, 3)) # (N, heads, K_per_head, 3)
h_keys = self.k_lin(edge_attr)
h_keys = (h_keys[0].view(N, self.num_heads, -1), # (N, heads, K_per_head)
h_keys[1].view(N, self.num_heads, -1, 3)) # (N, heads, K_per_head, 3)
h_values = self.v_lin(edge_attr)
h_values = (h_values[0].view(N, self.num_heads, -1), # (N, heads, K_per_head)
h_values[1].view(N, self.num_heads, -1, 3)) # (N, heads, K_per_head, 3)
# assert (index_edge_i_list == index_real_cps_edge_for_atten[0]).all()
# assert (index_edge_j_list == index_real_cps_edge_for_atten[1]).all()
index_edge_i_list, index_edge_j_list = index_real_cps_edge_for_atten
# # get nodes of triangle edges
atten_bias = self.atten_bias_lin(
tri_edge_index,
tri_edge_feat,
pos_compose,
)
# query * key
queries_i = [h_queries[0][index_edge_i_list], h_queries[1][index_edge_i_list]]
keys_j = [h_keys[0][index_edge_j_list], h_keys[1][index_edge_j_list]]
qk_ij = [
(queries_i[0] * keys_j[0]).sum(-1), # (N', heads)
(queries_i[1] * keys_j[1]).sum(-1).sum(-1) # (N', heads)
]
alpha = [
atten_bias[0] + qk_ij[0],
atten_bias[1] + qk_ij[1]
]
alpha = [
scatter_softmax(alpha[0], index_edge_i_list, dim=0), # (N', heads)
scatter_softmax(alpha[1], index_edge_i_list, dim=0) # (N', heads)
]
values_j = [h_values[0][index_edge_j_list], h_values[1][index_edge_j_list]]
num_attens = len(index_edge_j_list)
output =[
scatter_sum((alpha[0].unsqueeze(-1) * values_j[0]).view(num_attens, -1), index_edge_i_list, dim=0, dim_size=N), # (N, H, 3)
scatter_sum((alpha[1].unsqueeze(-1).unsqueeze(-1) * values_j[1]).view(num_attens, -1, 3), index_edge_i_list, dim=0, dim_size=N) # (N, H, 3)
]
# output
output = [edge_attr[0] + output[0], edge_attr[1] + output[1]]
output = [self.layernorm_sca(output[0]), self.layernorm_vec(output[1])]
return output
class AttentionBias(Module):
def __init__(self, num_heads, hidden_channels, cutoff=10., num_bond_types=3,
bottleneck=1, use_conv1d=False): #TODO: change the cutoff
super(AttentionBias, self).__init__()
num_edge_types = num_bond_types + 1
self.num_bond_types = num_bond_types
self.distance_expansion = GaussianSmearing(stop=cutoff, num_gaussians=hidden_channels[0] - num_edge_types-1) # minus 1 for self edges (e.g. edge 0-0)
self.vector_expansion = EdgeExpansion(hidden_channels[1]) # Linear(in_features=1, out_features=hidden_channels[1], bias=False)
self.gblinear = GBLinear(
hidden_channels[0], hidden_channels[1], num_heads, num_heads,
bottleneck=bottleneck, use_conv1d=use_conv1d
)
def forward(self, tri_edge_index, tri_edge_feat, pos_compose):
node_a, node_b = tri_edge_index
pos_a = pos_compose[node_a]
pos_b = pos_compose[node_b]
vector = pos_a - pos_b
dist = torch.norm(vector, p=2, dim=-1)
dist_feat = self.distance_expansion(dist)
sca_feat = torch.cat([
dist_feat,
tri_edge_feat,
], dim=-1)
vec_feat = self.vector_expansion(vector)
output_sca, output_vec = self.gblinear([sca_feat, vec_feat])
output_vec = (output_vec * output_vec).sum(-1)
return output_sca, output_vec

View File

@@ -1,193 +0,0 @@
import torch
from torch import nn
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F
keys = ['edge_flow.flow_layers.5', 'atom_flow.flow_layers.5',
'pos_predictor.mu_net', 'pos_predictor.logsigma_net', 'pos_predictor.pi_net',
'focal_net.net']
def reset_parameters(model, keys):
for name, para in model.named_parameters():
for k in keys:
if k in name and 'bias' in name:
torch.nn.init.constant_(para, 0.)
elif k in name and 'layernorm' in name:
torch.nn.init.constant_(para, 1.)
elif k in name and 'rescale.weight' in name:
torch.nn.init.constant_(para, 0.)
elif k in name:
torch.nn.init.kaiming_normal_(para)
return model
def flow_reverse(flow_layers, latent, feat):
for i in reversed(range(len(flow_layers))):
s_sca, t_sca, vec = flow_layers[i](feat)
s_sca = s_sca.exp() # 为什么要计算指数确保s中的每个元素都大于0以保证可逆性
latent = (latent / s_sca) - t_sca
return latent, vec
def flow_forward(flow_layers, x_z, feature):
for i in range(len(flow_layers)):
s_sca, t_sca, vec = flow_layers[i](feature)
s_sca = s_sca.exp() # 为什么要计算指数?
x_z = (x_z + t_sca) * s_sca
if i == 0:
x_log_jacob = (torch.abs(s_sca) + 1e-20).log()
else:
x_log_jacob += (torch.abs(s_sca) + 1e-20).log()
return x_z, x_log_jacob, vec
class GaussianSmearing(nn.Module):
def __init__(self, start=0.0, stop=10.0, num_gaussians=50):
super(GaussianSmearing, self).__init__()
self.stop = stop
offset = torch.linspace(start, stop, num_gaussians)
self.coeff = -0.5 / (offset[1] - offset[0]).item()**2
self.register_buffer('offset', offset)
def forward(self, dist):
dist = dist.clamp_max(self.stop)
dist = dist.view(-1, 1) - self.offset.view(1, -1)
return torch.exp(self.coeff * torch.pow(dist, 2))
class EdgeExpansion(nn.Module):
def __init__(self, edge_channels):
super(EdgeExpansion, self).__init__()
self.nn = nn.Linear(in_features=1, out_features=edge_channels, bias=False)
def forward(self, edge_vector):
edge_vector = edge_vector / (torch.norm(edge_vector, p=2, dim=1, keepdim=True)+1e-7)
expansion = self.nn(edge_vector.unsqueeze(-1)).transpose(1, -1)
return expansion
class Scalarize(nn.Module):
def __init__(self, sca_in_dim, vec_in_dim, hidden_dim, out_dim, act_fn=nn.Sigmoid()) -> None:
super(Scalarize, self).__init__()
self.sca_in_dim = sca_in_dim
self.vec_in_dim = vec_in_dim
self.hidden_dim = hidden_dim
self.out_dim = out_dim
self.lin_scalarize_1 = nn.Linear(sca_in_dim+vec_in_dim, hidden_dim)
self.lin_scalarize_2 = nn.Linear(hidden_dim, out_dim)
self.act_fn = act_fn
def forward(self, x):
sca, vec = x[0].view(-1, self.sca_in_dim), x[1]
norm_vec = torch.norm(vec, p=2, dim=-1).view(-1, self.vec_in_dim)
sca = torch.cat([sca, norm_vec], dim=1)
sca = self.lin_scalarize_1(sca)
sca = self.act_fn(sca)
sca = self.lin_scalarize_2(sca)
return sca
class Rescale(nn.Module):
def __init__(self):
super(Rescale, self).__init__()
self.weight = nn.Parameter(torch.zeros([1]))
def forward(self, x):
if torch.isnan(torch.exp(self.weight)).any():
print(self.weight)
raise RuntimeError('Rescale factor has NaN entries')
x = torch.exp(self.weight) * x
return x
'''class ST_DWLayer(nn.Module):
def __init__(self, sca_in_features, sca_out_features,
vec_in_features, vec_out_features, sca_act=nn.ReLU(),
vec_act=nn.Sigmoid(), device=None, dtype=None) -> None:
super(ST_DWLayer, self).__init__()
self.sca_out_features = sca_out_features
self.dw_layer = DWLayer(
sca_in_features, sca_out_features*2,
vec_in_features, vec_out_features, sca_act=sca_act,
vec_act=vec_act, device=device, dtype=dtype
)
self.rescale = Rescale()
def forward(self, x):
sca, vec = self.dw_layer(x)
s_sca = self.rescale(sca[:, :self.sca_out_features])
t_sca = sca[:, self.sca_out_features:]
return s_sca, t_sca, vec'''
class AtomEmbedding(nn.Module):
def __init__(self, in_scalar, in_vector,
out_scalar, out_vector, vector_normalizer=20.):
super(AtomEmbedding, self).__init__()
assert in_vector == 1
self.in_scalar = in_scalar
self.vector_normalizer = vector_normalizer
self.emb_sca = nn.Linear(in_scalar, out_scalar)
self.emb_vec = nn.Linear(in_vector, out_vector)
def forward(self, scalar_input, vector_input):
if isinstance(self.vector_normalizer, float):
vector_input = vector_input / self.vector_normalizer
else:
vector_input = vector_input / torch.norm(vector_input, p=2, dim=-1)
assert vector_input.shape[1:] == (3, ), 'Not support. Only one vector can be input'
sca_emb = self.emb_sca(scalar_input[:, :self.in_scalar]) # b, f -> b, f'
vec_emb = vector_input.unsqueeze(-1) # b, 3 -> b, 3, 1
vec_emb = self.emb_vec(vec_emb).transpose(1, -1) # b, 1, 3 -> b, f', 3
return sca_emb, vec_emb
def embed_compose(compose_feature, compose_pos, idx_ligand, idx_protein,
ligand_atom_emb, protein_atom_emb, emb_dim):
h_ligand = ligand_atom_emb(compose_feature[idx_ligand], compose_pos[idx_ligand])
h_protein = protein_atom_emb(compose_feature[idx_protein], compose_pos[idx_protein])
h_sca = torch.zeros([len(compose_pos), emb_dim[0]],).to(h_ligand[0])
h_vec = torch.zeros([len(compose_pos), emb_dim[1], 3],).to(h_ligand[1])
h_sca[idx_ligand], h_sca[idx_protein] = h_ligand[0], h_protein[0]
h_vec[idx_ligand], h_vec[idx_protein] = h_ligand[1], h_protein[1]
return [h_sca, h_vec]
class SmoothCrossEntropyLoss(_WeightedLoss):
def __init__(self, weight=None, reduction='mean', smoothing=0.0):
super().__init__(weight=weight, reduction=reduction)
self.smoothing = smoothing
self.weight = weight
self.reduction = reduction
@staticmethod
def _smooth_one_hot(targets:torch.Tensor, n_classes:int, smoothing=0.0):
assert 0 <= smoothing < 1
with torch.no_grad():
targets = torch.empty(size=(targets.size(0), n_classes),
device=targets.device) \
.fill_(smoothing /(n_classes-1)) \
.scatter_(1, targets.data.unsqueeze(1), 1.-smoothing)
return targets
def forward(self, inputs, targets):
targets = SmoothCrossEntropyLoss._smooth_one_hot(targets, inputs.size(-1),
self.smoothing)
lsm = F.log_softmax(inputs, -1)
if self.weight is not None:
lsm = lsm * self.weight.unsqueeze(0)
loss = -(targets * lsm).sum(-1)
if self.reduction == 'sum':
loss = loss.sum()
elif self.reduction == 'mean':
loss = loss.mean()
return loss

View File

@@ -1,201 +0,0 @@
import torch
from torch import nn
import torch.nn.functional as F
from .net_utils import AtomEmbedding, embed_compose
from .encoder import ContextEncoder
from .atom_flow import AtomFlow
from .bond_predictor import BondFlowNew
from .position_predictor import PositionPredictor
from .pos_filter import PositionFilter
from .focal_net import FrontierLayerVN
from easydict import EasyDict
from torch_scatter import scatter_add
#import sys
#sys.path.append("..")
#from utils import get_tri_edges
#from generate_utils import add_ligand_atom_to_data, data2mol
#from rdkit import Chem
encoder_cfg = EasyDict(
{'edge_channels':8, 'num_interactions':6,
'knn':32, 'cutoff':10.0}
)
focal_net_cfg = EasyDict(
{'hidden_dim_sca':32, 'hidden_dim_vec':8}
)
atom_flow_cfg = EasyDict(
{'hidden_dim_sca':32, 'hidden_dim_vec':8, 'num_flow_layers':6}
)
pos_predictor_cfg = EasyDict(
{'num_filters':[64,64], 'n_component':3}
)
pos_filter_cfg = EasyDict(
{'edge_channels':8, 'num_filters':[32,16]}
)
edge_flow_cfg = EasyDict(
{'edge_channels':8, 'num_filters':[32,8], 'num_bond_types':3,
'num_heads':2, 'cutoff':10.0, 'num_flow_layers':3}
)
config = EasyDict(
{'deq_coeff':0.9, 'hidden_channels':32, 'hidden_channels_vec':8, 'bottleneck':8, 'use_conv1d':False,
'encoder':encoder_cfg, 'atom_flow':atom_flow_cfg, 'pos_predictor':pos_predictor_cfg,
'pos_filter':pos_filter_cfg, 'edge_flow':edge_flow_cfg, 'focal_net':focal_net_cfg}
)
class PocketFlowWithEdgeNew(nn.Module):
def __init__(self, config) -> None:
super(PocketFlowWithEdgeNew, self).__init__()
self.config = config
self.num_bond_types = config.num_bond_types
self.msg_annealing = config.msg_annealing
self.emb_dim = [config.hidden_channels, config.hidden_channels_vec]
self.protein_atom_emb = AtomEmbedding(config.protein_atom_feature_dim, 1, *self.emb_dim)
self.ligand_atom_emb = AtomEmbedding(config.ligand_atom_feature_dim, 1, *self.emb_dim)
self.atom_type_embedding = nn.Embedding(config.num_atom_type, config.hidden_channels)
self.encoder = ContextEncoder(
hidden_channels=self.emb_dim, edge_channels=config.encoder.edge_channels,
num_edge_types=config.num_bond_types, num_interactions=config.encoder.num_interactions,
k=config.encoder.knn, cutoff=config.encoder.cutoff, bottleneck=config.bottleneck,
use_conv1d=config.use_conv1d, num_heads=config.encoder.num_heads
)
self.focal_net = FrontierLayerVN(
self.emb_dim[0], self.emb_dim[1], config.focal_net.hidden_dim_sca,
config.focal_net.hidden_dim_vec, bottleneck=config.bottleneck,
use_conv1d=config.use_conv1d
)
self.atom_flow = AtomFlow(
self.emb_dim[0], self.emb_dim[1], config.atom_flow.hidden_dim_sca,
config.atom_flow.hidden_dim_vec, num_lig_atom_type=config.num_atom_type,
num_flow_layers=config.atom_flow.num_flow_layers, bottleneck=config.bottleneck,
use_conv1d=config.use_conv1d
)
self.pos_predictor = PositionPredictor(
self.emb_dim[0], self.emb_dim[1], config.pos_predictor.num_filters,
config.pos_predictor.n_component, bottleneck=config.bottleneck,
use_conv1d=config.use_conv1d
)
'''self.pos_filter = PositionFilter(
self.emb_dim[0], self.emb_dim[1], config.pos_filter.edge_channels,
config.pos_filter.num_filters, bottleneck=config.bottleneck,
use_conv1d=config.use_conv1d
)'''
self.edge_flow = BondFlowNew(
self.emb_dim[0], self.emb_dim[1], config.edge_flow.edge_channels,
config.edge_flow.num_filters, config.edge_flow.num_bond_types,
num_heads=config.edge_flow.num_heads, cutoff=config.edge_flow.cutoff,
num_st_layers=config.edge_flow.num_flow_layers, bottleneck=config.bottleneck,
use_conv1d=config.use_conv1d
)
def get_parameter_number(self):
total_num = sum(p.numel() for p in self.parameters())
trainable_num = sum(p.numel() for p in self.parameters() if p.requires_grad)
return {'Total': total_num, 'Trainable': trainable_num}
def get_loss(self, data):
h_cpx = embed_compose(data.cpx_feature.float(), data.cpx_pos, data.idx_ligand_ctx_in_cpx,
data.idx_protein_in_cpx, self.ligand_atom_emb,
self.protein_atom_emb, self.emb_dim)
# encoding context
h_cpx = self.encoder(
node_attr = h_cpx,
pos = data.cpx_pos,
edge_index = data.cpx_edge_index,
edge_feature = data.cpx_edge_feature,
annealing=self.msg_annealing
)
# for focal loss
focal_pred = self.focal_net(h_cpx, data.idx_ligand_ctx_in_cpx)
focal_loss = F.binary_cross_entropy_with_logits(
input=focal_pred, target=data.ligand_frontier.view(-1, 1).float()
)
# for focal loss in protein
focal_pred_apo = self.focal_net(h_cpx, data.apo_protein_idx)
surf_loss = F.binary_cross_entropy_with_logits(
input=focal_pred_apo, target=data.candidate_focal_label_in_protein.view(-1, 1).float()
)
# for atom loss
x_z = F.one_hot(data.atom_label, num_classes=self.config.num_atom_type).float() #[50,27]
x_z += self.config.deq_coeff * torch.rand(x_z.size(), device=x_z.device) #[50,27]
z_atom, atom_log_jacob = self.atom_flow(x_z, h_cpx, data.focal_idx_in_context)
ll_atom = (1/2 * (z_atom ** 2) - atom_log_jacob).sum(-1)
ll_atom = scatter_add(ll_atom, data.atom_label_batch, dim=0).mean()
# for position loss
atom_type_emb = self.atom_type_embedding(data.atom_label)
relative_mu, abs_mu, sigma, pi = self.pos_predictor(
h_cpx, # +atom_type_emb[data.step_batch]
data.focal_idx_in_context,
data.cpx_pos,
atom_type_emb=atom_type_emb,
)
#y_pos = torch.rand_like(data.y_pos) * 0.05 + data.y_pos
loss_pos = -torch.log(
self.pos_predictor.get_mdn_probability(abs_mu, sigma, pi, data.y_pos) + 1e-16
).mean()#.clamp_max(10.) # 最大似然log
#loss_pos = scatter_add(loss_pos, data.atom_label_batch, dim=0).mean()
# for edge loss
z_edge = F.one_hot(data.edge_label, num_classes=4).float()
z_edge += self.config.deq_coeff * torch.rand(z_edge.size(), device=z_edge.device)
edge_index_query = torch.stack([data.edge_query_index_0, data.edge_query_index_1])
pos_query_knn_edge_idx = torch.stack(
[data.pos_query_knn_edge_idx_0, data.pos_query_knn_edge_idx_1]
)
z_edge, edge_log_jacob = self.edge_flow(
z_edge=z_edge,
pos_query=data.y_pos,
edge_index_query=edge_index_query,
cpx_pos=data.cpx_pos,
node_attr_compose=h_cpx,
edge_index_q_cps_knn=pos_query_knn_edge_idx,
index_real_cps_edge_for_atten=data.index_real_cps_edge_for_atten,
tri_edge_index=data.tri_edge_index,
tri_edge_feat=data.tri_edge_feat,
atom_type_emb=atom_type_emb,
annealing=self.msg_annealing
)
ll_edge = (1/2 * (z_edge ** 2) - edge_log_jacob).sum(-1)
ll_edge = scatter_add(ll_edge, data.edge_label_batch, dim=0).mean()
# pos filter
'''pos_fake_knn_edge_idx = torch.stack([data.pos_fake_knn_edge_idx_0, data.pos_fake_knn_edge_idx_1])
pos_fake_pred, _ = self.pos_filter(
pos_query=data.pos_fake,
edge_index_q_cps_knn=pos_fake_knn_edge_idx,
cpx_pos=data.cpx_pos,
node_attr_compose=h_cpx,
annealing=annealing
)
loss_fake = F.binary_cross_entropy_with_logits(
input=pos_fake_pred, target=torch.zeros_like(pos_fake_pred)
)
pos_real_knn_edge_idx = torch.stack([data.pos_real_knn_edge_idx_0, data.pos_real_knn_edge_idx_1])
pos_real_pred, _ = self.pos_filter(
pos_query=data.pos_real,
edge_index_q_cps_knn=pos_real_knn_edge_idx,
cpx_pos=data.cpx_pos,
node_attr_compose=h_cpx,
annealing=annealing
)
loss_real = F.binary_cross_entropy_with_logits(
input=pos_real_pred, target=torch.ones_like(pos_real_pred)
)'''
# loss all
loss = torch.nan_to_num(ll_atom)\
+ torch.nan_to_num(loss_pos)\
+ torch.nan_to_num(ll_edge)\
+ torch.nan_to_num(focal_loss)\
+ torch.nan_to_num(surf_loss)\
#+ torch.nan_to_num(loss_fake)\
#+ torch.nan_to_num(loss_real)\
out_dict = {
'loss':loss, 'loss_atom':ll_atom, 'loss_edge':ll_edge, #'loss_fake':loss_fake, 'loss_real':loss_real,
'loss_pos':loss_pos, 'focal_loss':focal_loss, 'surf_loss':torch.nan_to_num(surf_loss)
}
return out_dict

View File

@@ -1,92 +0,0 @@
import torch
from torch.nn import Module, Sequential
from torch.nn import functional as F
from .layers import GBPerceptronVN, GBLinear
import math
GAUSSIAN_COEF = 1.0 / math.sqrt(2 * math.pi)
class PositionPredictor(Module):
def __init__(self, in_sca, in_vec, num_filters, n_component, bottleneck=1, use_conv1d=False):
super(PositionPredictor, self).__init__()
self.n_component = n_component
self.gvp = Sequential(
GBPerceptronVN(
in_sca*2, in_vec, num_filters[0], num_filters[1], bottleneck=bottleneck, use_conv1d=use_conv1d
),
GBLinear(
num_filters[0], num_filters[1], num_filters[0], num_filters[1], bottleneck=bottleneck,
use_conv1d=use_conv1d
)
)
self.mu_net = GBLinear(
num_filters[0], num_filters[1], n_component, n_component, bottleneck=bottleneck, use_conv1d=use_conv1d
)
self.logsigma_net= GBLinear(
num_filters[0], num_filters[1], n_component, n_component, bottleneck=bottleneck, use_conv1d=use_conv1d
)
self.pi_net = GBLinear(
num_filters[0], num_filters[1], n_component, 1, bottleneck=bottleneck, use_conv1d=use_conv1d
)
def forward(self, h_compose, idx_focal, pos_compose, atom_type_emb=None):#, focal_step_batch=None):
h_focal = [h[idx_focal] for h in h_compose]
pos_focal = pos_compose[idx_focal]
if isinstance(atom_type_emb, torch.Tensor):# and isinstance(focal_step_batch, torch.Tensor):
h_focal[0] = torch.cat([h_focal[0], atom_type_emb], dim=1) # 可以直接相乘减少一些训练参数
#h_focal = [h_focal[0] * atom_type_emb, h_focal[1]]
feat_focal = self.gvp(h_focal)
relative_mu = self.mu_net(feat_focal)[1] # (N_focal, n_component, 3)
logsigma = self.logsigma_net(feat_focal)[1] # (N_focal, n_component, 3)
sigma = torch.exp(logsigma)
pi = self.pi_net(feat_focal)[0] # (N_focal, n_component)
pi = F.softmax(pi, dim=1)
abs_mu = relative_mu + pos_focal.unsqueeze(dim=1).expand_as(relative_mu)
return relative_mu, abs_mu, sigma, pi
def get_mdn_probability(self, mu, sigma, pi, pos_target):
prob_gauss = self._get_gaussian_probability(mu, sigma, pos_target)
prob_mdn = pi * prob_gauss
prob_mdn = torch.sum(prob_mdn, dim=1)
return prob_mdn
def _get_gaussian_probability(self, mu, sigma, pos_target):
"""
mu - (N, n_component, 3)
sigma - (N, n_component, 3)
pos_target - (N, 3)
"""
target = pos_target.unsqueeze(1).expand_as(mu)
errors = target - mu # 最大概率似然没有不变性使用L2范数
sigma = sigma + 1e-16
p = GAUSSIAN_COEF * torch.exp(-0.5 * (errors / sigma)**2) / sigma
p = torch.prod(p, dim=2)
return p # (N, n_component)
def sample_batch(self, mu, sigma, pi, num):
"""sample from multiple mix gaussian
mu - (N_batch, n_cat, 3)
sigma - (N_batch, n_cat, 3)
pi - (N_batch, n_cat)
return
(N_batch, num, 3)
"""
index_cats = torch.multinomial(pi, num, replacement=True) # (N_batch, num)
# index_cats = index_cats.unsqueeze(-1)
index_batch = torch.arange(len(mu)).unsqueeze(-1).expand(-1, num) # (N_batch, num)
mu_sample = mu[index_batch, index_cats] # (N_batch, num, 3)
sigma_sample = sigma[index_batch, index_cats]
values = torch.normal(mu_sample, sigma_sample) # (N_batch, num, 3)
return values
def get_maximum(self, mu, sigma, pi):
"""sample from multiple mix gaussian
mu - (N_batch, n_cat, 3)
sigma - (N_batch, n_cat, 3)
pi - (N_batch, n_cat)
return
(N_batch, n_cat, 3)
"""
return mu