mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-07 20:41:20 +08:00
215 lines
6.5 KiB
Python
215 lines
6.5 KiB
Python
import copy
|
|
from functools import partial
|
|
|
|
import dgl
|
|
import dgl.function as fn
|
|
import dgl.nn as dglnn
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.nn import functional as F
|
|
|
|
|
|
class MLP(nn.Module):
    """
    Multi-layer perceptron with ReLU between consecutive linear layers.

    Every linear layer (including the input layer) has its weight drawn from
    a normal distribution with std 0.1 and its bias set to zero.

    Parameters
    ==========
    in_feats : int
        Width of the input features

    out_feats : int
        Width of the output features

    num_layers : int
        Total number of linear layers. Values below 2 still produce the
        minimal two-layer network (input layer + output layer).

    hidden : int
        Width of every hidden representation
    """

    def __init__(self, in_feats, out_feats, num_layers=2, hidden=128):
        super().__init__()
        # Layer widths: in_feats -> hidden (x (num_layers - 1)) -> out_feats.
        # max(..., 1) keeps the two-layer minimum for num_layers < 2.
        widths = [in_feats] + [hidden] * max(num_layers - 1, 1) + [out_feats]
        self.layers = nn.ModuleList(
            self._make_linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )

    @staticmethod
    def _make_linear(fan_in, fan_out):
        """Build a linear layer with normal(std=0.1) weights and zero bias."""
        layer = nn.Linear(fan_in, fan_out)
        nn.init.normal_(layer.weight, std=0.1)
        nn.init.zeros_(layer.bias)
        return layer

    def forward(self, x):
        """Apply all layers to ``x``; ReLU follows every layer but the last."""
        *inner_layers, output_layer = self.layers
        for layer in inner_layers:
            x = F.relu(layer(x))
        return output_layer(x)
|
|
|
|
|
|
class PrepareLayer(nn.Module):
    """
    Prepare the model input: normalize the node features and derive edge
    features from them.

    Parameters
    ==========
    node_feats : int
        Number of node features

    stat : dict
        Statistics used for normalization, with keys ``'median'``,
        ``'max'`` and ``'min'``
    """

    def __init__(self, node_feats, stat):
        super(PrepareLayer, self).__init__()
        # stat {'median':median,'max':max,'min':min}
        self.stat = stat
        self.node_feats = node_feats

    def normalize_input(self, node_feature):
        """Center on the median and scale by ``2 / (max - min)``."""
        centered = node_feature - self.stat["median"]
        scale = 2 / (self.stat["max"] - self.stat["min"])
        return centered * scale

    def forward(self, g, node_feature):
        """
        Return ``(normalized node feature, edge feature)``.

        The edge feature is produced by DGL's ``u_sub_v`` builtin applied to
        the normalized node feature. The graph is only modified inside a
        local scope.
        """
        with g.local_scope():
            normalized = self.normalize_input(node_feature)
            # Only dynamic feature
            g.ndata["feat"] = normalized
            g.apply_edges(fn.u_sub_v("feat", "feat", "e"))
            return normalized, g.edata["e"]
|
|
|
|
|
|
class InteractionNet(nn.Module):
    """
    Simple Interaction Network

    A single-layer interaction network for simulating the stellar
    multi-body problem. According to the original author's note it is
    meant for systems of no more than 12 bodies.

    Parameters
    ==========
    node_feats : int
        Number of node features

    stat : dict
        Statistics used for denormalization (keys ``'median'``, ``'max'``
        and ``'min'``)
    """

    def __init__(self, node_feats, stat):
        super(InteractionNet, self).__init__()
        self.node_feats = node_feats
        self.stat = stat

        self.in_layer = InteractionLayer(
            in_node_feats=node_feats - 3,  # Use velocity only
            in_edge_feats=node_feats,
            out_node_feats=2,
            out_edge_feats=50,
            edge_fn=partial(MLP, num_layers=5, hidden=150),
            node_fn=partial(MLP, num_layers=2, hidden=100),
            mode="n_n",
        )

    def denormalize_output(self, out):
        """
        Undo the normalization using entries ``3:5`` of the statistics
        (velocity only, per the original comment).
        """
        span = self.stat["max"][3:5] - self.stat["min"][3:5]
        offset = self.stat["median"][3:5]
        return out * span / 2 + offset

    def forward(self, g, n_feat, e_feat, global_feats, relation_feats):
        """
        Run the interaction layer inside a local graph scope and return
        ``(denormalized node output, edge output)``.
        """
        with g.local_scope():
            node_out, edge_out = self.in_layer(
                g, n_feat, e_feat, global_feats, relation_feats
            )
            return self.denormalize_output(node_out), edge_out
|
|
|
|
|
|
class InteractionLayer(nn.Module):
    """
    Implementation of single layer of interaction network

    Parameters
    ==========
    in_node_feats : int
        Number of node features (without the global features)

    in_edge_feats : int
        Number of edge features (without the relation features)

    out_node_feats : int
        Number of node features after one interaction

    out_edge_feats : int
        Number of edge features after one interaction

    global_feats : int
        Number of global features; ``forward`` concatenates them to the
        node features

    relate_feats : int
        Number of relation features; ``forward`` concatenates them to the
        edge features

    edge_fn : callable
        Factory called as ``edge_fn(in_width, out_width)`` returning the
        module that updates edge features in message generation

    node_fn : callable
        Factory called as ``node_fn(in_width, out_width)`` returning the
        module that updates node features in message aggregation

    mode : str
        Type of message the edge should carry
        nne : [src_feat, dst_feat, edge_feat] node features concatenated
              with the edge feature, followed by ReLU.
        any other value (e.g. 'n_n') : ``edge_fn`` is applied to the edge
              feature alone, without ReLU.
    """

    def __init__(
        self,
        in_node_feats,
        in_edge_feats,
        out_node_feats,
        out_edge_feats,
        global_feats=1,
        relate_feats=1,
        edge_fn=nn.Linear,
        node_fn=nn.Linear,
        mode="nne",
    ):
        super().__init__()
        self.in_node_feats = in_node_feats
        self.in_edge_feats = in_edge_feats
        self.out_edge_feats = out_edge_feats
        self.out_node_feats = out_node_feats
        self.mode = mode

        # Widths of g.ndata["feat"] / g.edata["feat"] as assembled in
        # forward(): node features + global features, and edge features +
        # relation features.
        node_width = in_node_feats + global_feats
        edge_width = in_edge_feats + relate_feats

        # MLP for message passing. In 'nne' mode update_edge_fn concatenates
        # the source, destination and edge "feat" tensors, so the input has
        # to cover two full node vectors plus one full edge vector.
        if mode == "nne":
            input_shape = 2 * node_width + edge_width
        else:
            input_shape = edge_width
        self.edge_fn = edge_fn(input_shape, out_edge_feats)  # 50 in IN paper

        # update_node_fn concatenates the node "feat" with the aggregated
        # edge messages.
        self.node_fn = node_fn(node_width + out_edge_feats, out_node_feats)

    # Should be done by apply edge
    def update_edge_fn(self, edges):
        """Edge UDF: concatenate src/dst/edge ``feat`` and apply ``edge_fn``."""
        x = torch.cat(
            [edges.src["feat"], edges.dst["feat"], edges.data["feat"]], dim=1
        )
        ret = self.edge_fn(x)
        if self.mode == "nne":
            ret = F.relu(ret)
        return {"e": ret}

    # Assume agg comes from build in reduce
    def update_node_fn(self, nodes):
        """Node UDF: concatenate ``feat`` with ``agg`` and apply ``node_fn``."""
        x = torch.cat([nodes.data["feat"], nodes.data["agg"]], dim=1)
        ret = self.node_fn(x)
        if self.mode == "nne":
            ret = F.relu(ret)
        return {"n": ret}

    def forward(self, g, node_feats, edge_feats, global_feats, relation_feats):
        """
        Run one interaction step on graph ``g``.

        Writes ``feat`` (and the results ``e`` / ``n``) into ``g.ndata`` and
        ``g.edata``; wrap the call in ``g.local_scope()`` if the graph must
        stay untouched. Returns ``(g.ndata["n"], g.edata["e"])``.
        """
        g.ndata["feat"] = torch.cat([node_feats, global_feats], dim=1)
        g.edata["feat"] = torch.cat([edge_feats, relation_feats], dim=1)
        if self.mode == "nne":
            g.apply_edges(self.update_edge_fn)
        else:
            g.edata["e"] = self.edge_fn(g.edata["feat"])

        # Sum the edge messages per destination node, then update the nodes.
        g.update_all(
            fn.copy_e("e", "msg"), fn.sum("msg", "agg"), self.update_node_fn
        )
        return g.ndata["n"], g.edata["e"]
|