[Fix][Readability] Improving the Capsule Network example. (#5985)

Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
This commit is contained in:
Andrei Ivanov
2023-07-26 22:15:48 -07:00
committed by GitHub
parent 416e2425b1
commit 017d9d4010

View File

@@ -68,15 +68,16 @@ def squash(s, dim=1):
def init_graph(in_nodes, out_nodes, f_size, device="cpu"):
g = dgl.DGLGraph()
g.set_n_initializer(dgl.frame.zero_initializer)
all_nodes = in_nodes + out_nodes
g.add_nodes(all_nodes)
src, dst = [], []
in_indx = list(range(in_nodes))
out_indx = list(range(in_nodes, in_nodes + out_nodes))
# add edges use edge broadcasting
for u in in_indx:
g.add_edges(u, out_indx)
src += [u] * len(out_indx)
dst += out_indx
g = dgl.graph((src, dst)) # dgl.graph once;
g.set_n_initializer(dgl.frame.zero_initializer)
g = g.to(device)
g.edata["b"] = th.zeros(in_nodes * out_nodes, 1).to(device)
return g