[GraphBolt][CUDA][Temporal] Tests and example enablement. (#7678)

This commit is contained in:
Muhammed Fatih BALIN
2024-08-09 08:52:02 -04:00
committed by GitHub
parent 6d55515dcd
commit 90c26be268
3 changed files with 59 additions and 59 deletions

View File

@@ -121,6 +121,9 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
shuffle=is_train,
)
if args.storage_device != "cpu":
datapipe = datapipe.copy_to(device=args.device)
############################################################################
# [Input]:
# 'datapipe' is either 'ItemSampler' or 'UniformNegativeSampler' depending
@@ -250,7 +253,7 @@ def parse_args():
parser.add_argument(
"--mode",
default="cpu-cuda",
choices=["cpu-cpu", "cpu-cuda"],
choices=["cpu-cpu", "cpu-cuda", "cuda-cuda"],
help="Dataset storage placement and Train device: 'cpu' for CPU and RAM,"
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
)

View File

@@ -830,10 +830,6 @@ def test_in_subgraph_hetero():
)
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("replace", [False, True])
@@ -848,6 +844,8 @@ def test_temporal_sample_neighbors_homo(
use_node_timestamp,
use_edge_timestamp,
):
if replace and F._default_context_str == "gpu":
pytest.skip("Sampling with replacement not yet implemented on the GPU.")
"""Original graph in COO:
1 0 1 0 1
1 0 1 1 0
@@ -867,7 +865,7 @@ def test_temporal_sample_neighbors_homo(
assert len(indptr) == total_num_nodes + 1
# Construct FusedCSCSamplingGraph.
graph = gb.fused_csc_sampling_graph(indptr, indices)
graph = gb.fused_csc_sampling_graph(indptr, indices).to(F.ctx())
# Generate subgraph via sample neighbors.
fanouts = torch.LongTensor([2])
@@ -878,15 +876,17 @@ def test_temporal_sample_neighbors_homo(
)
seed_list = [1, 3, 4]
seed_timestamp = torch.randint(0, 100, (len(seed_list),), dtype=torch.int64)
seed_timestamp = torch.randint(
0, 100, (len(seed_list),), dtype=torch.int64, device=F.ctx()
)
if use_node_timestamp:
node_timestamp = torch.randint(
0, 100, (total_num_nodes,), dtype=torch.int64
0, 100, (total_num_nodes,), dtype=torch.int64, device=F.ctx()
)
graph.node_attributes = {"timestamp": node_timestamp}
if use_edge_timestamp:
edge_timestamp = torch.randint(
0, 100, (total_num_edges,), dtype=torch.int64
0, 100, (total_num_edges,), dtype=torch.int64, device=F.ctx()
)
graph.edge_attributes = {"timestamp": edge_timestamp}
@@ -936,7 +936,7 @@ def test_temporal_sample_neighbors_homo(
available_neighbors.append(neighbors)
return available_neighbors
nodes = torch.tensor(seed_list, dtype=indices_dtype)
nodes = torch.tensor(seed_list, dtype=indices_dtype, device=F.ctx())
subgraph = sampler(
nodes,
seed_timestamp,
@@ -947,6 +947,7 @@ def test_temporal_sample_neighbors_homo(
)
sampled_count = torch.diff(subgraph.sampled_csc.indptr).tolist()
available_neighbors = _get_available_neighbors()
assert len(available_neighbors) == len(sampled_count)
for i, count in enumerate(sampled_count):
if not replace:
expect_count = min(fanouts[0], len(available_neighbors[i]))
@@ -958,10 +959,6 @@ def test_temporal_sample_neighbors_homo(
assert set(neighbors.tolist()).issubset(set(available_neighbors[i]))
@unittest.skipIf(
F._default_context_str == "gpu",
reason="Graph is CPU only at present.",
)
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("replace", [False, True])
@@ -976,6 +973,8 @@ def test_temporal_sample_neighbors_hetero(
use_node_timestamp,
use_edge_timestamp,
):
if replace and F._default_context_str == "gpu":
pytest.skip("Sampling with replacement not yet implemented on the GPU.")
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
@@ -1006,7 +1005,7 @@ def test_temporal_sample_neighbors_hetero(
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
).to(F.ctx())
# Generate subgraph via sample neighbors.
fanouts = torch.LongTensor([-1, -1])
@@ -1017,8 +1016,8 @@ def test_temporal_sample_neighbors_hetero(
)
seeds = {
"n1": torch.tensor([0], dtype=indices_dtype),
"n2": torch.tensor([0], dtype=indices_dtype),
"n1": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),
"n2": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),
}
per_etype_destination_nodes = {
"n1:e1:n2": torch.tensor([1], dtype=indices_dtype),
@@ -1026,17 +1025,17 @@ def test_temporal_sample_neighbors_hetero(
}
seed_timestamp = {
"n1": torch.randint(0, 100, (1,), dtype=torch.int64),
"n2": torch.randint(0, 100, (1,), dtype=torch.int64),
"n1": torch.randint(0, 100, (1,), dtype=torch.int64, device=F.ctx()),
"n2": torch.randint(0, 100, (1,), dtype=torch.int64, device=F.ctx()),
}
if use_node_timestamp:
node_timestamp = torch.randint(
0, 100, (total_num_nodes,), dtype=torch.int64
0, 100, (total_num_nodes,), dtype=torch.int64, device=F.ctx()
)
graph.node_attributes = {"timestamp": node_timestamp}
if use_edge_timestamp:
edge_timestamp = torch.randint(
0, 100, (total_num_edges,), dtype=torch.int64
0, 100, (total_num_edges,), dtype=torch.int64, device=F.ctx()
)
graph.edge_attributes = {"timestamp": edge_timestamp}

View File

@@ -14,14 +14,6 @@ import torch
from . import gb_test_utils
# Skip all tests on GPU when sampling with TemporalNeighborSampler.
def _check_sampler_type(sampler_type):
if F._default_context_str != "cpu" and _is_temporal(sampler_type):
pytest.skip(
"TemporalNeighborSampler sampling tests are only supported on CPU."
)
def _check_sampler_len(sampler, lenExp):
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
@@ -199,7 +191,6 @@ def test_NeighborSampler_fanouts(labor):
],
)
def test_SubgraphSampler_Node(sampler_type):
_check_sampler_type(sampler_type)
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx()
)
@@ -231,7 +222,6 @@ def test_SubgraphSampler_Node(sampler_type):
],
)
def test_SubgraphSampler_Link(sampler_type):
_check_sampler_type(sampler_type)
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx()
)
@@ -268,7 +258,6 @@ def test_SubgraphSampler_Link(sampler_type):
],
)
def test_SubgraphSampler_Link_With_Negative(sampler_type):
_check_sampler_type(sampler_type)
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx()
)
@@ -302,7 +291,6 @@ def test_SubgraphSampler_Link_With_Negative(sampler_type):
],
)
def test_SubgraphSampler_HyperLink(sampler_type):
_check_sampler_type(sampler_type)
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
F.ctx()
)
@@ -339,7 +327,6 @@ def test_SubgraphSampler_HyperLink(sampler_type):
],
)
def test_SubgraphSampler_Node_Hetero(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
items = torch.arange(3)
names = "seeds"
@@ -375,7 +362,6 @@ def test_SubgraphSampler_Node_Hetero(sampler_type):
],
)
def test_SubgraphSampler_Link_Hetero(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "seeds"
@@ -435,7 +421,6 @@ def test_SubgraphSampler_Link_Hetero(sampler_type):
],
)
def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "seeds"
@@ -485,7 +470,6 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(sampler_type):
],
)
def test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "seeds"
@@ -535,7 +519,6 @@ def test_SubgraphSampler_Link_Hetero_Unknown_Etype(sampler_type):
],
)
def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
first_items = torch.LongTensor([[0, 0, 1, 1], [0, 2, 0, 1]]).T
first_names = "seeds"
@@ -586,7 +569,6 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type):
],
)
def test_SubgraphSampler_HyperLink_Hetero(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
items = torch.LongTensor([[2, 0, 1, 1, 2], [0, 1, 1, 0, 0]])
names = "seeds"
@@ -646,7 +628,6 @@ def test_SubgraphSampler_HyperLink_Hetero(sampler_type):
[False, True],
)
def test_SubgraphSampler_Random_Hetero_Graph(sampler_type, replace):
_check_sampler_type(sampler_type)
if F._default_context_str == "gpu" and replace == True:
pytest.skip("Sampling with replacement not yet supported on GPU.")
num_nodes = 5
@@ -748,7 +729,6 @@ def test_SubgraphSampler_Random_Hetero_Graph(sampler_type, replace):
],
)
def test_SubgraphSampler_without_deduplication_Homo_Node(sampler_type):
_check_sampler_type(sampler_type)
graph = dgl.graph(
([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
)
@@ -758,10 +738,14 @@ def test_SubgraphSampler_without_deduplication_Homo_Node(sampler_type):
names = "seeds"
if _is_temporal(sampler_type):
graph.node_attributes = {
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
"timestamp": torch.zeros(
graph.csc_indptr.numel() - 1, dtype=torch.int64
).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
"timestamp": torch.zeros(
graph.indices.numel(), dtype=torch.int64
).to(F.ctx())
}
items = (items, torch.randint(1, 10, (3,)))
names = (names, "timestamp")
@@ -822,16 +806,19 @@ def test_SubgraphSampler_without_deduplication_Homo_Node(sampler_type):
],
)
def test_SubgraphSampler_without_deduplication_Hetero_Node(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
items = torch.arange(2)
names = "seeds"
if _is_temporal(sampler_type):
graph.node_attributes = {
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
"timestamp": torch.zeros(
graph.csc_indptr.numel() - 1, dtype=torch.int64, device=F.ctx()
)
}
graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
"timestamp": torch.zeros(
graph.indices.numel(), dtype=torch.int64, device=F.ctx()
)
}
items = (items, torch.randint(1, 10, (2,)))
names = (names, "timestamp")
@@ -1084,7 +1071,6 @@ def test_SubgraphSampler_unique_csc_format_Hetero_Node(labor):
],
)
def test_SubgraphSampler_Hetero_multifanout_per_layer(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
items_n1 = torch.tensor([0])
items_n2 = torch.tensor([1])
@@ -1160,7 +1146,6 @@ def test_SubgraphSampler_Hetero_multifanout_per_layer(sampler_type):
],
)
def test_SubgraphSampler_without_deduplication_Homo_Link(sampler_type):
_check_sampler_type(sampler_type)
graph = dgl.graph(
([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
)
@@ -1170,10 +1155,14 @@ def test_SubgraphSampler_without_deduplication_Homo_Link(sampler_type):
names = "seeds"
if _is_temporal(sampler_type):
graph.node_attributes = {
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
"timestamp": torch.zeros(
graph.csc_indptr.numel() - 1, dtype=torch.int64
).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
"timestamp": torch.zeros(
graph.indices.numel(), dtype=torch.int64
).to(F.ctx())
}
items = (items, torch.randint(1, 10, (2,)))
names = (names, "timestamp")
@@ -1227,16 +1216,19 @@ def test_SubgraphSampler_without_deduplication_Homo_Link(sampler_type):
],
)
def test_SubgraphSampler_without_deduplication_Hetero_Link(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
items = torch.arange(2).view(1, 2)
names = "seeds"
if _is_temporal(sampler_type):
graph.node_attributes = {
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
"timestamp": torch.zeros(
graph.csc_indptr.numel() - 1, dtype=torch.int64
).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
"timestamp": torch.zeros(
graph.indices.numel(), dtype=torch.int64
).to(F.ctx())
}
items = (items, torch.randint(1, 10, (1,)))
names = (names, "timestamp")
@@ -1542,7 +1534,6 @@ def test_SubgraphSampler_unique_csc_format_Hetero_Link(labor):
],
)
def test_SubgraphSampler_without_deduplication_Homo_HyperLink(sampler_type):
_check_sampler_type(sampler_type)
graph = dgl.graph(
([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
)
@@ -1551,10 +1542,14 @@ def test_SubgraphSampler_without_deduplication_Homo_HyperLink(sampler_type):
names = "seeds"
if _is_temporal(sampler_type):
graph.node_attributes = {
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
"timestamp": torch.zeros(
graph.csc_indptr.numel() - 1, dtype=torch.int64
).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
"timestamp": torch.zeros(
graph.indices.numel(), dtype=torch.int64
).to(F.ctx())
}
items = (items, torch.randint(1, 10, (2,)))
names = (names, "timestamp")
@@ -1608,16 +1603,19 @@ def test_SubgraphSampler_without_deduplication_Homo_HyperLink(sampler_type):
],
)
def test_SubgraphSampler_without_deduplication_Hetero_HyperLink(sampler_type):
_check_sampler_type(sampler_type)
graph = get_hetero_graph().to(F.ctx())
items = torch.arange(3).view(1, 3)
names = "seeds"
if _is_temporal(sampler_type):
graph.node_attributes = {
"timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
"timestamp": torch.zeros(
graph.csc_indptr.numel() - 1, dtype=torch.int64
).to(F.ctx())
}
graph.edge_attributes = {
"timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
"timestamp": torch.zeros(
graph.indices.numel(), dtype=torch.int64
).to(F.ctx())
}
items = (items, torch.randint(1, 10, (1,)))
names = (names, "timestamp")