mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
[Fix][Readability] Improving the Capsule Network example. (#5985)
Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
This commit is contained in:
@@ -68,15 +68,16 @@ def squash(s, dim=1):
|
||||
|
||||
|
||||
def init_graph(in_nodes, out_nodes, f_size, device="cpu"):
|
||||
g = dgl.DGLGraph()
|
||||
g.set_n_initializer(dgl.frame.zero_initializer)
|
||||
all_nodes = in_nodes + out_nodes
|
||||
g.add_nodes(all_nodes)
|
||||
src, dst = [], []
|
||||
in_indx = list(range(in_nodes))
|
||||
out_indx = list(range(in_nodes, in_nodes + out_nodes))
|
||||
# add edges use edge broadcasting
|
||||
for u in in_indx:
|
||||
g.add_edges(u, out_indx)
|
||||
src += [u] * len(out_indx)
|
||||
dst += out_indx
|
||||
|
||||
g = dgl.graph((src, dst)) # dgl.graph once;
|
||||
g.set_n_initializer(dgl.frame.zero_initializer)
|
||||
g = g.to(device)
|
||||
g.edata["b"] = th.zeros(in_nodes * out_nodes, 1).to(device)
|
||||
return g
|
||||
|
||||
Reference in New Issue
Block a user