mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt][CUDA][Temporal] Tests and example enablement. (#7678)
This commit is contained in:
committed by
GitHub
parent
6d55515dcd
commit
90c26be268
@@ -121,6 +121,9 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
|
||||
shuffle=is_train,
|
||||
)
|
||||
|
||||
if args.storage_device != "cpu":
|
||||
datapipe = datapipe.copy_to(device=args.device)
|
||||
|
||||
############################################################################
|
||||
# [Input]:
|
||||
# 'datapipe' is either 'ItemSampler' or 'UniformNegativeSampler' depending
|
||||
@@ -250,7 +253,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
default="cpu-cuda",
|
||||
choices=["cpu-cpu", "cpu-cuda"],
|
||||
choices=["cpu-cpu", "cpu-cuda", "cuda-cuda"],
|
||||
help="Dataset storage placement and Train device: 'cpu' for CPU and RAM,"
|
||||
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
|
||||
)
|
||||
|
||||
@@ -830,10 +830,6 @@ def test_in_subgraph_hetero():
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
F._default_context_str == "gpu",
|
||||
reason="Graph is CPU only at present.",
|
||||
)
|
||||
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
|
||||
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
|
||||
@pytest.mark.parametrize("replace", [False, True])
|
||||
@@ -848,6 +844,8 @@ def test_temporal_sample_neighbors_homo(
|
||||
use_node_timestamp,
|
||||
use_edge_timestamp,
|
||||
):
|
||||
if replace and F._default_context_str == "gpu":
|
||||
pytest.skip("Sampling with replacement not yet implemented on the GPU.")
|
||||
"""Original graph in COO:
|
||||
1 0 1 0 1
|
||||
1 0 1 1 0
|
||||
@@ -867,7 +865,7 @@ def test_temporal_sample_neighbors_homo(
|
||||
assert len(indptr) == total_num_nodes + 1
|
||||
|
||||
# Construct FusedCSCSamplingGraph.
|
||||
graph = gb.fused_csc_sampling_graph(indptr, indices)
|
||||
graph = gb.fused_csc_sampling_graph(indptr, indices).to(F.ctx())
|
||||
|
||||
# Generate subgraph via sample neighbors.
|
||||
fanouts = torch.LongTensor([2])
|
||||
@@ -878,15 +876,17 @@ def test_temporal_sample_neighbors_homo(
|
||||
)
|
||||
|
||||
seed_list = [1, 3, 4]
|
||||
seed_timestamp = torch.randint(0, 100, (len(seed_list),), dtype=torch.int64)
|
||||
seed_timestamp = torch.randint(
|
||||
0, 100, (len(seed_list),), dtype=torch.int64, device=F.ctx()
|
||||
)
|
||||
if use_node_timestamp:
|
||||
node_timestamp = torch.randint(
|
||||
0, 100, (total_num_nodes,), dtype=torch.int64
|
||||
0, 100, (total_num_nodes,), dtype=torch.int64, device=F.ctx()
|
||||
)
|
||||
graph.node_attributes = {"timestamp": node_timestamp}
|
||||
if use_edge_timestamp:
|
||||
edge_timestamp = torch.randint(
|
||||
0, 100, (total_num_edges,), dtype=torch.int64
|
||||
0, 100, (total_num_edges,), dtype=torch.int64, device=F.ctx()
|
||||
)
|
||||
graph.edge_attributes = {"timestamp": edge_timestamp}
|
||||
|
||||
@@ -936,7 +936,7 @@ def test_temporal_sample_neighbors_homo(
|
||||
available_neighbors.append(neighbors)
|
||||
return available_neighbors
|
||||
|
||||
nodes = torch.tensor(seed_list, dtype=indices_dtype)
|
||||
nodes = torch.tensor(seed_list, dtype=indices_dtype, device=F.ctx())
|
||||
subgraph = sampler(
|
||||
nodes,
|
||||
seed_timestamp,
|
||||
@@ -947,6 +947,7 @@ def test_temporal_sample_neighbors_homo(
|
||||
)
|
||||
sampled_count = torch.diff(subgraph.sampled_csc.indptr).tolist()
|
||||
available_neighbors = _get_available_neighbors()
|
||||
assert len(available_neighbors) == len(sampled_count)
|
||||
for i, count in enumerate(sampled_count):
|
||||
if not replace:
|
||||
expect_count = min(fanouts[0], len(available_neighbors[i]))
|
||||
@@ -958,10 +959,6 @@ def test_temporal_sample_neighbors_homo(
|
||||
assert set(neighbors.tolist()).issubset(set(available_neighbors[i]))
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
F._default_context_str == "gpu",
|
||||
reason="Graph is CPU only at present.",
|
||||
)
|
||||
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
|
||||
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
|
||||
@pytest.mark.parametrize("replace", [False, True])
|
||||
@@ -976,6 +973,8 @@ def test_temporal_sample_neighbors_hetero(
|
||||
use_node_timestamp,
|
||||
use_edge_timestamp,
|
||||
):
|
||||
if replace and F._default_context_str == "gpu":
|
||||
pytest.skip("Sampling with replacement not yet implemented on the GPU.")
|
||||
"""Original graph in COO:
|
||||
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
|
||||
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
|
||||
@@ -1006,7 +1005,7 @@ def test_temporal_sample_neighbors_hetero(
|
||||
type_per_edge=type_per_edge,
|
||||
node_type_to_id=ntypes,
|
||||
edge_type_to_id=etypes,
|
||||
)
|
||||
).to(F.ctx())
|
||||
|
||||
# Generate subgraph via sample neighbors.
|
||||
fanouts = torch.LongTensor([-1, -1])
|
||||
@@ -1017,8 +1016,8 @@ def test_temporal_sample_neighbors_hetero(
|
||||
)
|
||||
|
||||
seeds = {
|
||||
"n1": torch.tensor([0], dtype=indices_dtype),
|
||||
"n2": torch.tensor([0], dtype=indices_dtype),
|
||||
"n1": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),
|
||||
"n2": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),
|
||||
}
|
||||
per_etype_destination_nodes = {
|
||||
"n1:e1:n2": torch.tensor([1], dtype=indices_dtype),
|
||||
@@ -1026,17 +1025,17 @@ def test_temporal_sample_neighbors_hetero(
|
||||
}
|
||||
|
||||
seed_timestamp = {
|
||||
"n1": torch.randint(0, 100, (1,), dtype=torch.int64),
|
||||
"n2": torch.randint(0, 100, (1,), dtype=torch.int64),
|
||||
"n1": torch.randint(0, 100, (1,), dtype=torch.int64, device=F.ctx()),
|
||||
"n2": torch.randint(0, 100, (1,), dtype=torch.int64, device=F.ctx()),
|
||||
}
|
||||
if use_node_timestamp:
|
||||
node_timestamp = torch.randint(
|
||||
0, 100, (total_num_nodes,), dtype=torch.int64
|
||||
0, 100, (total_num_nodes,), dtype=torch.int64, device=F.ctx()
|
||||
)
|
||||
graph.node_attributes = {"timestamp": node_timestamp}
|
||||
if use_edge_timestamp:
|
||||
edge_timestamp = torch.randint(
|
||||
0, 100, (total_num_edges,), dtype=torch.int64
|
||||
0, 100, (total_num_edges,), dtype=torch.int64, device=F.ctx()
|
||||
)
|
||||
graph.edge_attributes = {"timestamp": edge_timestamp}
|
||||
|
||||
|
||||
@@ -14,14 +14,6 @@ import torch
|
||||
from . import gb_test_utils
|
||||
|
||||
|
||||
# Skip all tests on GPU when sampling with TemporalNeighborSampler.
|
||||
def _check_sampler_type(sampler_type):
|
||||
if F._default_context_str != "cpu" and _is_temporal(sampler_type):
|
||||
pytest.skip(
|
||||
"TemporalNeighborSampler sampling tests are only supported on CPU."
|
||||
)
|
||||
|
||||
|
||||
def _check_sampler_len(sampler, lenExp):
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", category=UserWarning)
|
||||
@@ -199,7 +191,6 @@ def test_NeighborSampler_fanouts(labor):
|
||||
],
|
||||
)
|
||||
def test_SubgraphSampler_Node(sampler_type):
|
||||
_check_sampler_type(sampler_type)
|
||||
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
|
||||
F.ctx()
|
||||
)
|
||||
@@ -231,7 +222,6 @@ def test_SubgraphSampler_Node(sampler_type):
|
||||
],
|
||||
)
|
||||
def test_SubgraphSampler_Link(sampler_type):
|
||||
_check_sampler_type(sampler_type)
|
||||
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
|
||||
F.ctx()
|
||||
)
|
||||
@@ -268,7 +258,6 @@ def test_SubgraphSampler_Link(sampler_type):
|
||||
],
|
||||
)
|
||||
def test_SubgraphSampler_Link_With_Negative(sampler_type):
|
||||
_check_sampler_type(sampler_type)
|
||||
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
|
||||
F.ctx()
|
||||
)
|
||||
@@ -302,7 +291,6 @@ def test_SubgraphSampler_Link_With_Negative(sampler_type):
|
||||
],
|
||||
)
|
||||
def test_SubgraphSampler_HyperLink(sampler_type):
|
||||
_check_sampler_type(sampler_type)
|
||||
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
|
||||
F.ctx()
|
||||
)
|
||||
@@ -339,7 +327,6 @@ def test_SubgraphSampler_HyperLink(sampler_type):
|
||||
],
|
||||
)
|
||||
def test_SubgraphSampler_Node_Hetero(sampler_type):
|
||||
_check_sampler_type(sampler_type)
|
||||
graph = get_hetero_graph().to(F.ctx())
|
||||
items = torch.arange(3)
|
||||
names = "seeds"
|
||||
@@ -375,7 +362,6 @@ def test_SubgraphSampler_Node_Hetero(sampler_type):
|
||||
],
|
||||
)
|
||||
def test_SubgraphSampler_Link_Hetero(sampler_type):
|
||||
_check_sampler_type(sampler_type)
|
||||
graph = get_hetero_graph().to(F.ctx())
|
||||
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
|
||||
first_names = "seeds"
|
||||
@@ -435,7 +421,6 @@ def test_SubgraphSampler_Link_Hetero(sampler_type):
|
||||
],
|
||||
)
|
||||
def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):
|
||||
_check_sampler_type(sampler_type)
|
||||
graph = get_hetero_graph().to(F.ctx())
|
||||
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
|
||||
first_names = "seeds"
|
||||
@@ -485,7 +470,6 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):
|
||||
],
|
||||
)
|
||||
def test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type):
|
||||
_check_sampler_type(sampler_type)
|
||||
graph = get_hetero_graph().to(F.ctx())
|
||||
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
|
||||
first_names = "seeds"
|
||||
@@ -535,7 +519,6 @@ def test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type):
|
||||
],
|
||||
)
|
||||
def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type):
|
||||
_check_sampler_type(sampler_type)
|
||||
graph = get_hetero_graph().to(F.ctx())
|
||||
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
|
||||
first_names = "seeds"
|
||||
@@ -586,7 +569,6 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type):
|
||||
],
|
||||
)
|
||||
def test_SubgraphSampler_HyperLink_Hetero(sampler_type):
|
||||
_check_sampler_type(sampler_type)
|
||||
graph = get_hetero_graph().to(F.ctx())
|
||||
items = torch.LongTensor([[2, 0, 1, 1, 2], [0, 1, 1, 0, 0]])
|
||||
names = "seeds"
|
||||
@@ -646,7 +628,6 @@ def test_SubgraphSampler_HyperLink_Hetero(sampler_type):
|
||||
[False, True],
|
||||
)
|
||||
def test_SubgraphSampler_Random_Hetero_Graph(sampler_type, replace):
|
||||
_check_sampler_type(sampler_type)
|
||||
if F._default_context_str == "gpu" and replace == True:
|
||||
pytest.skip("Sampling with replacement not yet supported on GPU.")
|
||||
num_nodes = 5
|
||||
@@ -748,7 +729,6 @@ def test_SubgraphSampler_Random_Hetero_Graph(sampler_type, replace):
|
||||
],
|
||||
)
|
||||
def test_SubgraphSampler_without_deduplication_Homo_Node(sampler_type):
|
||||
_check_sampler_type(sampler_type)
|
||||
graph = dgl.graph(
|
||||
([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
|
||||
)
|
||||
@@ -758,10 +738,14 @@ def test_SubgraphSampler_without_deduplication_Homo_Node(sampler_type):
|
||||
names = "seeds"
|
||||
if _is_temporal(sampler_type):
|
||||
graph.node_attributes = {
|
||||
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
|
||||
"timestamp": torch.zeros(
|
||||
graph.csc_indptr.numel() - 1, dtype=torch.int64
|
||||
).to(F.ctx())
|
||||
}
|
||||
graph.edge_attributes = {
|
||||
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
|
||||
"timestamp": torch.zeros(
|
||||
graph.indices.numel(), dtype=torch.int64
|
||||
).to(F.ctx())
|
||||
}
|
||||
items = (items, torch.randint(1, 10, (3,)))
|
||||
names = (names, "timestamp")
|
||||
@@ -822,16 +806,19 @@ def test_SubgraphSampler_without_deduplication_Homo_Node(sampler_type):
|
||||
],
|
||||
)
|
||||
def test_SubgraphSampler_without_deduplication_Hetero_Node(sampler_type):
|
||||
_check_sampler_type(sampler_type)
|
||||
graph = get_hetero_graph().to(F.ctx())
|
||||
items = torch.arange(2)
|
||||
names = "seeds"
|
||||
if _is_temporal(sampler_type):
|
||||
graph.node_attributes = {
|
||||
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
|
||||
"timestamp": torch.zeros(
|
||||
graph.csc_indptr.numel() - 1, dtype=torch.int64, device=F.ctx()
|
||||
)
|
||||
}
|
||||
graph.edge_attributes = {
|
||||
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
|
||||
"timestamp": torch.zeros(
|
||||
graph.indices.numel(), dtype=torch.int64, device=F.ctx()
|
||||
)
|
||||
}
|
||||
items = (items, torch.randint(1, 10, (2,)))
|
||||
names = (names, "timestamp")
|
||||
@@ -1084,7 +1071,6 @@ def test_SubgraphSampler_unique_csc_format_Hetero_Node(labor):
|
||||
],
|
||||
)
|
||||
def test_SubgraphSampler_Hetero_multifanout_per_layer(sampler_type):
|
||||
_check_sampler_type(sampler_type)
|
||||
graph = get_hetero_graph().to(F.ctx())
|
||||
items_n1 = torch.tensor([0])
|
||||
items_n2 = torch.tensor([1])
|
||||
@@ -1160,7 +1146,6 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer(sampler_type):
|
||||
],
|
||||
)
|
||||
def test_SubgraphSampler_without_deduplication_Homo_Link(sampler_type):
|
||||
_check_sampler_type(sampler_type)
|
||||
graph = dgl.graph(
|
||||
([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
|
||||
)
|
||||
@@ -1170,10 +1155,14 @@ def test_SubgraphSampler_without_deduplication_Homo_Link(sampler_type):
|
||||
names = "seeds"
|
||||
if _is_temporal(sampler_type):
|
||||
graph.node_attributes = {
|
||||
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
|
||||
"timestamp": torch.zeros(
|
||||
graph.csc_indptr.numel() - 1, dtype=torch.int64
|
||||
).to(F.ctx())
|
||||
}
|
||||
graph.edge_attributes = {
|
||||
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
|
||||
"timestamp": torch.zeros(
|
||||
graph.indices.numel(), dtype=torch.int64
|
||||
).to(F.ctx())
|
||||
}
|
||||
items = (items, torch.randint(1, 10, (2,)))
|
||||
names = (names, "timestamp")
|
||||
@@ -1227,16 +1216,19 @@ def test_SubgraphSampler_without_deduplication_Homo_Link(sampler_type):
|
||||
],
|
||||
)
|
||||
def test_SubgraphSampler_without_deduplication_Hetero_Link(sampler_type):
|
||||
_check_sampler_type(sampler_type)
|
||||
graph = get_hetero_graph().to(F.ctx())
|
||||
items = torch.arange(2).view(1, 2)
|
||||
names = "seeds"
|
||||
if _is_temporal(sampler_type):
|
||||
graph.node_attributes = {
|
||||
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
|
||||
"timestamp": torch.zeros(
|
||||
graph.csc_indptr.numel() - 1, dtype=torch.int64
|
||||
).to(F.ctx())
|
||||
}
|
||||
graph.edge_attributes = {
|
||||
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
|
||||
"timestamp": torch.zeros(
|
||||
graph.indices.numel(), dtype=torch.int64
|
||||
).to(F.ctx())
|
||||
}
|
||||
items = (items, torch.randint(1, 10, (1,)))
|
||||
names = (names, "timestamp")
|
||||
@@ -1542,7 +1534,6 @@ def test_SubgraphSampler_unique_csc_format_Hetero_Link(labor):
|
||||
],
|
||||
)
|
||||
def test_SubgraphSampler_without_deduplication_Homo_HyperLink(sampler_type):
|
||||
_check_sampler_type(sampler_type)
|
||||
graph = dgl.graph(
|
||||
([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
|
||||
)
|
||||
@@ -1551,10 +1542,14 @@ def test_SubgraphSampler_without_deduplication_Homo_HyperLink(sampler_type):
|
||||
names = "seeds"
|
||||
if _is_temporal(sampler_type):
|
||||
graph.node_attributes = {
|
||||
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
|
||||
"timestamp": torch.zeros(
|
||||
graph.csc_indptr.numel() - 1, dtype=torch.int64
|
||||
).to(F.ctx())
|
||||
}
|
||||
graph.edge_attributes = {
|
||||
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
|
||||
"timestamp": torch.zeros(
|
||||
graph.indices.numel(), dtype=torch.int64
|
||||
).to(F.ctx())
|
||||
}
|
||||
items = (items, torch.randint(1, 10, (2,)))
|
||||
names = (names, "timestamp")
|
||||
@@ -1608,16 +1603,19 @@ def test_SubgraphSampler_without_deduplication_Homo_HyperLink(sampler_type):
|
||||
],
|
||||
)
|
||||
def test_SubgraphSampler_without_deduplication_Hetero_HyperLink(sampler_type):
|
||||
_check_sampler_type(sampler_type)
|
||||
graph = get_hetero_graph().to(F.ctx())
|
||||
items = torch.arange(3).view(1, 3)
|
||||
names = "seeds"
|
||||
if _is_temporal(sampler_type):
|
||||
graph.node_attributes = {
|
||||
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
|
||||
"timestamp": torch.zeros(
|
||||
graph.csc_indptr.numel() - 1, dtype=torch.int64
|
||||
).to(F.ctx())
|
||||
}
|
||||
graph.edge_attributes = {
|
||||
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
|
||||
"timestamp": torch.zeros(
|
||||
graph.indices.numel(), dtype=torch.int64
|
||||
).to(F.ctx())
|
||||
}
|
||||
items = (items, torch.randint(1, 10, (1,)))
|
||||
names = (names, "timestamp")
|
||||
|
||||
Reference in New Issue
Block a user