# Mirror of https://github.com/dmlc/dgl.git (synced 2026-06-05 19:54:25 +08:00).
# Original file: 386 lines, 10 KiB, Python.
import backend as F
|
|
|
|
import dgl
|
|
import dgl.function as fn
|
|
import numpy as np
|
|
import scipy.sparse as sp
|
|
from utils import parametrize_idtype
|
|
|
|
# Width of the 2-D node/edge features used throughout this module
# ("f2" in generate_graph, "h" and "w2" in the fallback tests).
D: int = 5


def generate_graph(idtype):
    """Build the 10-node graph shared by the v2v tests.

    Node 0 fans out to nodes 1..8, each of which feeds node 9, and a
    single back edge 9 -> 0 closes the loop (8 * 2 + 1 = 17 edges).

    Node features:
        "f1": shape (10,)    -- 1-D feature
        "f2": shape (10, D)  -- 2-D feature
    Edge features (the same random weights in two layouts):
        "e1": shape (E,)
        "e2": shape (E, 1)

    Parameters
    ----------
    idtype
        ID dtype applied to the graph via ``astype`` (supplied by
        ``parametrize_idtype`` in the tests).

    Returns
    -------
    The populated graph, moved to ``F.ctx()``.
    """
    g = dgl.graph([])
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(10)
    # create a graph where 0 is the source and 9 is the sink; edges are
    # added pairwise so edge IDs interleave as 0->i, i->9.
    for i in range(1, 9):
        g.add_edges(0, i)
        g.add_edges(i, 9)
    # add a back flow from 9 to 0
    g.add_edges(9, 0)
    g.ndata.update({"f1": F.randn((10,)), "f2": F.randn((10, D))})
    # Size the edge weights from the graph itself instead of a hard-coded
    # 17, so they cannot drift out of sync with the topology above.
    weights = F.randn((g.num_edges(),))
    g.edata.update({"e1": weights, "e2": F.unsqueeze(weights, 1)})
    return g


@parametrize_idtype
def test_v2v_update_all(idtype):
    """Builtin ``update_all`` must agree with equivalent UDFs.

    For the 1-D feature "f1" and the 2-D feature "f2", each comparison runs
    ``update_all`` once with builtin message/reduce functions and once with
    hand-written UDFs, starting both runs from the same node features.
    Both a plain copy message and an edge-weighted product message are
    covered; the apply step doubles the reduced feature in every run.
    """

    def _check(fld):
        def copy_src(edges):
            return {"m": edges.src[fld]}

        def weighted_src(edges):
            src = edges.src[fld]
            # 1-D features pair with the (E,) weights, 2-D with (E, 1).
            weight_key = "e1" if len(src.shape) == 1 else "e2"
            return {"m": src * edges.data[weight_key]}

        def sum_mailbox(nodes):
            return {fld: F.sum(nodes.mailbox["m"], 1)}

        def double(nodes):
            return {fld: 2 * nodes.data[fld]}

        def _builtin_vs_udf(graph, builtin_msg, udf_msg):
            # Run the builtin pass, restore the input feature, run the UDF
            # pass, and require both outputs to match.
            start = graph.ndata[fld]
            graph.update_all(
                builtin_msg, fn.sum(msg="m", out=fld), double
            )
            expected = graph.ndata[fld]
            graph.ndata.update({fld: start})
            graph.update_all(udf_msg, sum_mailbox, double)
            assert F.allclose(expected, graph.ndata[fld])

        g = generate_graph(idtype)
        # update all
        _builtin_vs_udf(g, fn.copy_u(u=fld, out="m"), copy_src)
        # update all with edge weights
        _builtin_vs_udf(g, fn.u_mul_e(fld, "e1", "m"), weighted_src)

    # 1-D node features first, then 2-D node features.
    for fld in ("f1", "f2"):
        _check(fld)


@parametrize_idtype
def test_v2v_snr(idtype):
    """Builtin ``send_and_recv`` must agree with equivalent UDFs.

    Messages travel only along the six (src, dst) pairs below, all of which
    are edges of ``generate_graph``. For the 1-D feature "f1" and the 2-D
    feature "f2", a builtin pass and a UDF pass start from the same node
    features and must produce the same result. Both a plain copy message
    and an edge-weighted product message are covered.
    """
    src_ids = F.tensor([0, 0, 0, 3, 4, 9], idtype)
    dst_ids = F.tensor([1, 2, 3, 9, 9, 0], idtype)
    selected = (src_ids, dst_ids)

    def _check(fld):
        def copy_src(edges):
            return {"m": edges.src[fld]}

        def weighted_src(edges):
            src = edges.src[fld]
            # 1-D features pair with the (E,) weights, 2-D with (E, 1).
            weight_key = "e1" if len(src.shape) == 1 else "e2"
            return {"m": src * edges.data[weight_key]}

        def sum_mailbox(nodes):
            return {fld: F.sum(nodes.mailbox["m"], 1)}

        def double(nodes):
            return {fld: 2 * nodes.data[fld]}

        def _builtin_vs_udf(graph, builtin_msg, udf_msg):
            # Run the builtin pass, restore the input feature, run the UDF
            # pass, and require both outputs to match.
            start = graph.ndata[fld]
            graph.send_and_recv(
                selected,
                builtin_msg,
                fn.sum(msg="m", out=fld),
                double,
            )
            expected = graph.ndata[fld]
            graph.ndata.update({fld: start})
            graph.send_and_recv(selected, udf_msg, sum_mailbox, double)
            assert F.allclose(expected, graph.ndata[fld])

        g = generate_graph(idtype)
        # send and recv
        _builtin_vs_udf(g, fn.copy_u(u=fld, out="m"), copy_src)
        # send and recv with edge weights
        _builtin_vs_udf(g, fn.u_mul_e(fld, "e1", "m"), weighted_src)

    # 1-D node features first, then 2-D node features.
    for fld in ("f1", "f2"):
        _check(fld)


@parametrize_idtype
def test_v2v_pull(idtype):
    """Builtin ``pull`` must agree with equivalent UDFs.

    Nodes 1, 2, 3 and 9 pull messages from their in-neighbours in
    ``generate_graph``. For the 1-D feature "f1" and the 2-D feature "f2",
    a builtin pass and a UDF pass start from the same node features and
    must produce the same result. Both a plain copy message and an
    edge-weighted product message are covered.
    """
    # Named ``pull_nodes`` so it is not confused with (or shadowed by) the
    # ``nodes`` argument of the reduce/apply UDFs below.
    pull_nodes = F.tensor([1, 2, 3, 9], idtype)

    def _test(fld):
        def message_func(edges):
            return {"m": edges.src[fld]}

        def message_func_edge(edges):
            # 1-D features pair with the (E,) weights, 2-D with (E, 1).
            if len(edges.src[fld].shape) == 1:
                return {"m": edges.src[fld] * edges.data["e1"]}
            else:
                return {"m": edges.src[fld] * edges.data["e2"]}

        def reduce_func(nodes):
            return {fld: F.sum(nodes.mailbox["m"], 1)}

        def apply_func(nodes):
            return {fld: 2 * nodes.data[fld]}

        g = generate_graph(idtype)
        # pull
        v1 = g.ndata[fld]
        g.pull(
            pull_nodes,
            fn.copy_u(u=fld, out="m"),
            fn.sum(msg="m", out=fld),
            apply_func,
        )
        v2 = g.ndata[fld]
        g.ndata[fld] = v1
        g.pull(pull_nodes, message_func, reduce_func, apply_func)
        v3 = g.ndata[fld]
        assert F.allclose(v2, v3)
        # pull with edge weights
        v1 = g.ndata[fld]
        g.pull(
            pull_nodes,
            fn.u_mul_e(fld, "e1", "m"),
            fn.sum(msg="m", out=fld),
            apply_func,
        )
        v2 = g.ndata[fld]
        g.ndata[fld] = v1
        g.pull(pull_nodes, message_func_edge, reduce_func, apply_func)
        v4 = g.ndata[fld]
        assert F.allclose(v2, v4)

    # test 1d node features
    _test("f1")
    # test 2d node features
    _test("f2")


@parametrize_idtype
def test_update_all_multi_fallback(idtype):
    """Compare builtin ``update_all`` with UDF ground truth on a graph that
    has a zero-in-degree node.

    Node 0 sends to nodes 1..8, and each of those sends to node 9. There is
    no edge into node 0, so it receives no messages. Three message/reduce
    combinations are checked against hand-written UDFs:

    * "o1": ``h * w1`` (per-edge scalar weight, shape (E,)), reduced by sum;
    * "o2": ``h * w2`` (per-edge vector weight, shape (E, D)), reduced by sum;
    * "o3": ``h * w1``, reduced by max.

    The apply function doubles every node field whose name starts with "o".
    """
    # create a graph with zero in degree nodes
    g = dgl.graph([])
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(10)
    for i in range(1, 9):
        g.add_edges(0, i)
        g.add_edges(i, 9)
    # Size edge features from the graph rather than a hard-coded 16.
    num_edges = g.num_edges()
    g.ndata["h"] = F.randn((10, D))
    g.edata["w1"] = F.randn((num_edges,))
    g.edata["w2"] = F.randn((num_edges, D))

    def _mfunc_hxw1(edges):
        # (E,) weight -> (E, 1) so it broadcasts over the D feature columns.
        return {"m1": edges.src["h"] * F.unsqueeze(edges.data["w1"], 1)}

    def _mfunc_hxw2(edges):
        return {"m2": edges.src["h"] * edges.data["w2"]}

    def _rfunc_m1(nodes):
        return {"o1": F.sum(nodes.mailbox["m1"], 1)}

    def _rfunc_m2(nodes):
        return {"o2": F.sum(nodes.mailbox["m2"], 1)}

    def _rfunc_m1max(nodes):
        return {"o3": F.max(nodes.mailbox["m1"], 1)}

    def _afunc(nodes):
        # Double every output field ("o*"); "h" is not returned, so it is
        # left as-is.
        ret = {}
        for k, v in nodes.data.items():
            if k.startswith("o"):
                ret[k] = 2 * v
        return ret

    # compute ground truth
    g.update_all(_mfunc_hxw1, _rfunc_m1, _afunc)
    o1 = g.ndata.pop("o1")
    g.update_all(_mfunc_hxw2, _rfunc_m2, _afunc)
    o2 = g.ndata.pop("o2")
    g.update_all(_mfunc_hxw1, _rfunc_m1max, _afunc)
    o3 = g.ndata.pop("o3")
    # v2v spmv
    g.update_all(
        fn.u_mul_e("h", "w1", "m1"), fn.sum(msg="m1", out="o1"), _afunc
    )
    assert F.allclose(o1, g.ndata.pop("o1"))
    # v2v fallback to e2v
    g.update_all(
        fn.u_mul_e("h", "w2", "m2"), fn.sum(msg="m2", out="o2"), _afunc
    )
    assert F.allclose(o2, g.ndata.pop("o2"))
    # builtin max reduce, checked against the UDF max ground truth o3
    g.update_all(
        fn.u_mul_e("h", "w1", "m1"), fn.max(msg="m1", out="o3"), _afunc
    )
    assert F.allclose(o3, g.ndata.pop("o3"))


@parametrize_idtype
def test_pull_multi_fallback(idtype):
    """Compare builtin ``pull`` with UDF ground truth, including a
    zero-in-degree node among the pulled nodes.

    Node 0 sends to nodes 1..8, and each of those sends to node 9. There is
    no edge into node 0. For each pulled node set, three message/reduce
    combinations are checked against hand-written UDFs:

    * "o1": ``h * w1`` (per-edge scalar weight, shape (E,)), reduced by sum;
    * "o2": ``h * w2`` (per-edge vector weight, shape (E, D)), reduced by sum;
    * "o3": ``h * w1``, reduced by max.

    The apply function doubles every node field whose name starts with "o".
    """
    # create a graph with zero in degree nodes
    g = dgl.graph([])
    g = g.astype(idtype).to(F.ctx())
    g.add_nodes(10)
    for i in range(1, 9):
        g.add_edges(0, i)
        g.add_edges(i, 9)
    # Size edge features from the graph rather than a hard-coded 16.
    num_edges = g.num_edges()
    g.ndata["h"] = F.randn((10, D))
    g.edata["w1"] = F.randn((num_edges,))
    g.edata["w2"] = F.randn((num_edges, D))

    def _mfunc_hxw1(edges):
        # (E,) weight -> (E, 1) so it broadcasts over the D feature columns.
        return {"m1": edges.src["h"] * F.unsqueeze(edges.data["w1"], 1)}

    def _mfunc_hxw2(edges):
        return {"m2": edges.src["h"] * edges.data["w2"]}

    def _rfunc_m1(nodes):
        return {"o1": F.sum(nodes.mailbox["m1"], 1)}

    def _rfunc_m2(nodes):
        return {"o2": F.sum(nodes.mailbox["m2"], 1)}

    def _rfunc_m1max(nodes):
        return {"o3": F.max(nodes.mailbox["m1"], 1)}

    def _afunc(nodes):
        # Double every output field ("o*"); "h" is not returned, so it is
        # left as-is.
        ret = {}
        for k, v in nodes.data.items():
            if k.startswith("o"):
                ret[k] = 2 * v
        return ret

    def _pull_nodes(nodes):
        """Run UDF and builtin pulls into ``nodes`` and compare them."""
        # compute ground truth
        g.pull(nodes, _mfunc_hxw1, _rfunc_m1, _afunc)
        o1 = g.ndata.pop("o1")
        g.pull(nodes, _mfunc_hxw2, _rfunc_m2, _afunc)
        o2 = g.ndata.pop("o2")
        g.pull(nodes, _mfunc_hxw1, _rfunc_m1max, _afunc)
        o3 = g.ndata.pop("o3")
        # v2v spmv
        g.pull(
            nodes,
            fn.u_mul_e("h", "w1", "m1"),
            fn.sum(msg="m1", out="o1"),
            _afunc,
        )
        assert F.allclose(o1, g.ndata.pop("o1"))
        # v2v fallback to e2v
        g.pull(
            nodes,
            fn.u_mul_e("h", "w2", "m2"),
            fn.sum(msg="m2", out="o2"),
            _afunc,
        )
        assert F.allclose(o2, g.ndata.pop("o2"))
        # builtin max reduce, checked against the UDF max ground truth o3
        g.pull(
            nodes,
            fn.u_mul_e("h", "w1", "m1"),
            fn.max(msg="m1", out="o3"),
            _afunc,
        )
        assert F.allclose(o3, g.ndata.pop("o3"))

    # test#1: non-0deg nodes
    nodes = [1, 2, 9]
    _pull_nodes(nodes)
    # test#2: 0deg nodes + non-0deg nodes
    nodes = [0, 1, 2, 9]
    _pull_nodes(nodes)


@parametrize_idtype
def test_spmv_3d_feat(idtype):
    """Check builtin ``u_mul_e`` + ``sum`` on 3-D node features against UDFs.

    Node features have shape (N, 5, 5) on a random 100-node graph. Two
    edge-feature layouts are tested:

    1. per-edge scalars of shape (E,), which the UDF broadcasts explicitly
       by unsqueezing to (E, 1, 1);
    2. full (E, 5, 5) tensors, multiplied elementwise.

    For each layout, the builtin message + builtin reduce result is the
    reference. A UDF message with builtin reduce, and a UDF message with
    UDF reduce, must both reproduce it.
    """

    def _sum_mailbox(nodes):
        return {"h": F.sum(nodes.mailbox["sum"], 1)}

    def _broadcast_scalar_msg(edges):
        # (E,) -> (E, 1, 1) so the weight broadcasts over the 5x5 feature.
        scale = F.unsqueeze(F.unsqueeze(edges.data["h"], 1), 1)
        return {"sum": edges.src["h"] * scale}

    def _elementwise_msg(edges):
        return {"sum": edges.src["h"] * edges.data["h"]}

    num_nodes = 100
    density = 0.1
    adj = sp.random(
        num_nodes, num_nodes, density, data_rvs=lambda k: np.ones(k)
    )
    g = dgl.from_scipy(adj)
    g = g.astype(idtype).to(F.ctx())
    num_edges = g.num_edges()

    def _run(node_feat, edge_feat, message_func, reduce_func):
        # Reset both "h" fields, then aggregate into ndata["h"].
        g.ndata["h"] = node_feat
        g.edata["h"] = edge_feat
        g.update_all(message_func=message_func, reduce_func=reduce_func)
        return g.ndata["h"]

    # test#1: v2v with adj data
    node_feat = F.randn((num_nodes, 5, 5))
    edge_feat = F.randn((num_edges,))
    reference = _run(
        node_feat,
        edge_feat,
        fn.u_mul_e("h", "h", "sum"),
        fn.sum("sum", "h"),
    )
    via_udf_msg = _run(
        node_feat, edge_feat, _broadcast_scalar_msg, fn.sum("sum", "h")
    )
    assert F.allclose(via_udf_msg, reference)
    via_all_udf = _run(
        node_feat, edge_feat, _broadcast_scalar_msg, _sum_mailbox
    )
    assert F.allclose(via_all_udf, reference)

    # test#2: e2v
    node_feat = F.randn((num_nodes, 5, 5))
    edge_feat = F.randn((num_edges, 5, 5))
    reference = _run(
        node_feat,
        edge_feat,
        fn.u_mul_e("h", "h", "sum"),
        fn.sum("sum", "h"),
    )
    via_udf_msg = _run(
        node_feat, edge_feat, _elementwise_msg, fn.sum("sum", "h")
    )
    assert F.allclose(via_udf_msg, reference)
    via_all_udf = _run(node_feat, edge_feat, _elementwise_msg, _sum_mailbox)
    assert F.allclose(via_all_udf, reference)


if __name__ == "__main__":
    # Run every test defined in this module when the file is executed
    # directly. Each test takes an ``idtype`` argument, which pytest
    # normally injects via ``parametrize_idtype``, so it must be passed
    # explicitly here.
    # NOTE(review): assumes the test backend exposes ``F.int32`` and
    # ``F.int64`` (the ID types parametrize_idtype supplies) -- confirm.
    for _idtype in (F.int32, F.int64):
        test_v2v_update_all(_idtype)
        test_v2v_snr(_idtype)
        test_v2v_pull(_idtype)
        test_update_all_multi_fallback(_idtype)
        test_pull_multi_fallback(_idtype)
        test_spmv_3d_feat(_idtype)