mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
* merge * format * rename * sort * sort * update * update * update * Update tests/utils/checks.py Co-authored-by: Mufei Li <mufeili1996@gmail.com> --------- Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal> Co-authored-by: Mufei Li <mufeili1996@gmail.com>
53 lines
1.4 KiB
Python
53 lines
1.4 KiB
Python
import unittest
|
|
|
|
import backend as F
|
|
|
|
import dgl
|
|
from dgl.dataloading import (
|
|
as_edge_prediction_sampler,
|
|
negative_sampler,
|
|
NeighborSampler,
|
|
)
|
|
from utils import parametrize_idtype
|
|
|
|
|
|
def create_test_graph(idtype):
    """Build a small heterograph on the test device for sampler tests.

    This is the example heterograph (3 users, 2 games, 2 developers; node
    counts follow from the largest ID used per node type) extended with an
    extra ``user -- wishes --> game`` relation. Its canonical edge types are:

    * ``('user', 'follows', 'user')``   -- 2 edges
    * ``('user', 'plays', 'game')``     -- 4 edges
    * ``('user', 'wishes', 'game')``    -- 2 edges
    * ``('developer', 'develops', 'game')`` -- 2 edges

    Parameters
    ----------
    idtype
        ID dtype for the graph's node/edge IDs (presumably a backend dtype
        such as ``F.int32`` or ``F.int64`` -- confirm against callers).

    Returns
    -------
    DGLGraph
        The heterograph, checked to use ``idtype`` and to live on ``F.ctx()``.
    """
    device = F.ctx()
    edges = {
        ("user", "follows", "user"): ([0, 1], [1, 2]),
        ("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
        ("user", "wishes", "game"): ([0, 2], [1, 0]),
        ("developer", "develops", "game"): ([0, 1], [0, 1]),
    }
    graph = dgl.heterograph(edges, idtype=idtype, device=device)

    # Guard against the constructor silently ignoring the requested
    # ID type or device.
    assert graph.idtype == idtype
    assert graph.device == device
    return graph
|
|
|
|
|
|
@parametrize_idtype
def test_edge_prediction_sampler(idtype):
    """Smoke-test ``as_edge_prediction_sampler`` on a heterograph.

    A ``NeighborSampler([10, 10])`` is wrapped for edge prediction with a
    ``negative_sampler.Uniform(1)`` negative sampler. It then samples around
    the first two edges of the ``follows`` relation. The test passes as long
    as sampling raises nothing; the returned minibatch is not inspected.

    Parameters
    ----------
    idtype
        ID dtype for the test graph and the seed edge IDs.
    """
    g = create_test_graph(idtype)
    sampler = as_edge_prediction_sampler(
        NeighborSampler([10, 10]),
        negative_sampler=negative_sampler.Uniform(1),
    )

    # Edge IDs 0 and 1 -- i.e. every edge of the "follows" relation in the
    # test graph. The seeds must share the graph's idtype and device.
    seeds = F.copy_to(F.arange(0, 2, dtype=idtype), ctx=F.ctx())
    # Just a smoke test to make sure we don't fail internal assertions, so
    # the sampled minibatch is deliberately discarded.
    sampler.sample(g, {"follows": seeds})
|
|
|
|
|
|
if __name__ == "__main__":
    # The test takes a required ``idtype`` argument, which pytest fills in
    # via ``@parametrize_idtype``. Calling it with no arguments raises a
    # TypeError, so run it explicitly for each ID type here.
    # NOTE(review): assumes the ``backend`` module exposes ``int32``/``int64``
    # (the dtypes ``parametrize_idtype`` is expected to cover) -- confirm.
    for _idtype in (F.int32, F.int64):
        test_edge_prediction_sampler(_idtype)
|