Files
dgl/tests/python/common/dataloading/test_dataloader.py
Hongzhi (Steve), Chen 8a83027274 [Misc] Merge test utils. (#5440)
* merge

* format

* rename

* sort

* sort

* update

* update

* update

* Update tests/utils/checks.py

Co-authored-by: Mufei Li <mufeili1996@gmail.com>

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
2023-03-20 15:32:45 +08:00

53 lines
1.4 KiB
Python

import unittest
import backend as F
import dgl
from dgl.dataloading import (
as_edge_prediction_sampler,
negative_sampler,
NeighborSampler,
)
from utils import parametrize_idtype
def create_test_graph(idtype):
# test heterograph from the docstring, plus a user -- wishes -- game relation
# 3 users, 2 games, 2 developers
# metagraph:
# ('user', 'follows', 'user'),
# ('user', 'plays', 'game'),
# ('user', 'wishes', 'game'),
# ('developer', 'develops', 'game')])
g = dgl.heterograph(
{
("user", "follows", "user"): ([0, 1], [1, 2]),
("user", "plays", "game"): ([0, 1, 2, 1], [0, 0, 1, 1]),
("user", "wishes", "game"): ([0, 2], [1, 0]),
("developer", "develops", "game"): ([0, 1], [0, 1]),
},
idtype=idtype,
device=F.ctx(),
)
assert g.idtype == idtype
assert g.device == F.ctx()
return g
@parametrize_idtype
def test_edge_prediction_sampler(idtype):
g = create_test_graph(idtype)
sampler = NeighborSampler([10, 10])
sampler = as_edge_prediction_sampler(
sampler, negative_sampler=negative_sampler.Uniform(1)
)
seeds = F.copy_to(F.arange(0, 2, dtype=idtype), ctx=F.ctx())
# just a smoke test to make sure we don't fail internal assertions
result = sampler.sample(g, {"follows": seeds})
if __name__ == "__main__":
test_edge_prediction_sampler()