Files
dgl/tests/python/common/test_convert.py
2023-06-01 19:49:15 +08:00

123 lines
3.7 KiB
Python

import unittest
import backend as F
import dgl
from utils import parametrize_idtype
def get_nodes_by_ntype(nodes, ntype):
    """Select the NetworkX nodes whose ``"ntype"`` attribute equals *ntype*.

    Args:
        nodes: Mapping of node ID to node-attribute dict, e.g. the result of
            ``dict(nx_graph.nodes(data=True))``.
        ntype: Node type name to keep.

    Returns:
        A new dict containing only the matching entries, in the input's
        iteration order.
    """
    return {k: v for k, v in nodes.items() if v["ntype"] == ntype}
def edge_attrs(edge):
    """Return the attribute dict of a NetworkX edge tuple.

    Edges obtained from NetworkX with ``edges(data=True)`` have the shape
    ``(src, dst, attrs)``; the attributes sit right after the two endpoints.
    """
    attrs = edge[2]
    return attrs
def get_edges_by_etype(edges, etype):
    """Return, as a list, the edges whose ``"etype"`` attribute equals *etype*.

    *edges* is an iterable of NetworkX ``(src, dst, attrs)`` tuples; the
    relative order of the kept edges matches the input order.
    """

    def _has_etype(edge):
        # Compare the edge's type tag against the requested type.
        return edge_attrs(edge)["etype"] == etype

    return list(filter(_has_etype, edges))
def check_attrs_for_nodes(nodes, attrs):
    """Return True if every node carries exactly the attribute names *attrs*.

    Args:
        nodes: Mapping of node ID to node-attribute dict.
        attrs: Expected set of attribute names (compared with ``keys()``).

    An empty *nodes* mapping yields True.
    """
    for node_data in nodes.values():
        if node_data.keys() != attrs:
            return False
    return True
def check_attr_values_for_nodes(nodes, attr_name, values):
    """Return whether the nodes' *attr_name* features match *values*.

    The per-node feature tensors are combined with ``F.stack(..., 0)`` in the
    mapping's iteration order and then compared to *values* via
    ``F.allclose``.
    """
    per_node = [node_data[attr_name] for node_data in nodes.values()]
    stacked = F.stack(per_node, 0)
    return F.allclose(stacked, values)
def check_attrs_for_edges(edges, attrs):
    """Return True if every edge's attribute names equal exactly *attrs*.

    *edges* is an iterable of NetworkX ``(src, dst, attrs)`` tuples. An
    empty iterable yields True.
    """
    return not any(edge_attrs(edge).keys() != attrs for edge in edges)
def check_attr_values_for_edges(edges, attr_name, values):
    """Return whether the edges' *attr_name* features match *values* row-wise.

    ``nx_graph.edges(data=True)`` reports edges grouped by source node, not
    by edge ID. Row ``i`` of *values* therefore would not necessarily line
    up with the ``i``-th edge yielded. To make the comparison
    order-independent, the edges are first sorted by their ``"id"``
    attribute. The tests in this file expect that attribute on every
    converted edge. The features are then stacked with ``F.stack(..., 0)``
    and compared with ``F.allclose``.

    NOTE(review): assumes ``"id"`` increases with the per-etype edge ID
    (true whether it is the per-type or the homogeneous ID) — confirm
    against ``dgl.to_networkx``.
    """
    ordered = sorted(edges, key=lambda edge: edge_attrs(edge)["id"])
    return F.allclose(
        F.stack([edge_attrs(edge)[attr_name] for edge in ordered], 0), values
    )
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="`to_networkx` does not support graphs on GPU",
)
@parametrize_idtype
def test_to_networkx(idtype):
    """Check ``dgl.to_networkx`` on a small heterograph.

    Checks that the NetworkX graph:
    - has the same number of nodes and edges overall and per type;
    - tags nodes with ``ntype`` and edges with ``id`` and canonical ``etype``;
    - carries exactly the requested ``n`` / ``e`` features, only where they
      were set.
    """
    # TODO: adapt and move code from the _test_nx_conversion function in
    # tests/python/common/function/test_basics.py to here
    # (pending resolution of https://github.com/dmlc/dgl/issues/5735).
    follows_user = ("user", "follows", "user")
    follows_topic = ("user", "follows", "topic")
    plays_game = ("user", "plays", "game")
    g = dgl.heterograph(
        {
            follows_user: ([0, 1], [1, 2]),
            follows_topic: ([1, 1], [1, 2]),
            plays_game: ([0, 3], [3, 4]),
        },
        idtype=idtype,
        device=F.ctx(),
    )

    # Random features: 5 "game" nodes (dst IDs up to 4), 4 "user" nodes
    # (IDs up to 3), and 2 edges each for follows-user and plays.
    # "topic" nodes and follows-topic edges deliberately get no features.
    game_feat = F.randn((5, 3))
    user_feat = F.randn((4, 2))
    follows_feat = F.randn((2, 3))
    plays_feat = F.randn((2, 2))
    g.nodes["game"].data["n"] = F.copy_to(game_feat, ctx=F.ctx())
    g.nodes["user"].data["n"] = F.copy_to(user_feat, ctx=F.ctx())
    g.edges[follows_user].data["e"] = F.copy_to(follows_feat, ctx=F.ctx())
    g.edges["plays"].data["e"] = F.copy_to(plays_feat, ctx=F.ctx())

    nx_graph = dgl.to_networkx(g, node_attrs=["n"], edge_attrs=["e"])

    # ---- Nodes ----
    nx_nodes = dict(nx_graph.nodes(data=True))
    assert len(nx_nodes) == g.num_nodes()
    assert {data["ntype"] for data in nx_nodes.values()} == set(g.ntypes)
    nodes_of = {nt: get_nodes_by_ntype(nx_nodes, nt) for nt in g.ntypes}
    for nt, members in nodes_of.items():
        assert g.num_nodes(nt) == len(members)
    for nt, feat in (("game", game_feat), ("user", user_feat)):
        assert check_attrs_for_nodes(nodes_of[nt], {"ntype", "n"})
        assert check_attr_values_for_nodes(nodes_of[nt], "n", feat)
    # Nodes without node attributes
    assert check_attrs_for_nodes(nodes_of["topic"], {"ntype"})

    # ---- Edges ----
    nx_edges = list(nx_graph.edges(data=True))
    assert len(nx_edges) == g.num_edges()
    seen_etypes = {edge_attrs(edge)["etype"] for edge in nx_edges}
    assert seen_etypes == set(g.canonical_etypes)
    edges_of = {
        et: get_edges_by_etype(nx_edges, et) for et in g.canonical_etypes
    }
    for et, members in edges_of.items():
        assert g.num_edges(et) == len(members)
    for et, feat in ((follows_user, follows_feat), (plays_game, plays_feat)):
        assert check_attrs_for_edges(edges_of[et], {"id", "etype", "e"})
        assert check_attr_values_for_edges(edges_of[et], "e", feat)
    # Edges without edge attributes
    assert check_attrs_for_edges(edges_of[follows_topic], {"id", "etype"})