mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt][CUDA] Eliminate synchronization from exclude edges. (#7757)
This commit is contained in:
committed by
GitHub
parent
03e83ac5a8
commit
d6cf415cbb
@@ -202,8 +202,9 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
|
||||
# the negative samples.
|
||||
############################################################################
|
||||
if is_train and args.exclude_edges:
|
||||
datapipe = datapipe.transform(
|
||||
partial(gb.exclude_seed_edges, include_reverse_edges=True)
|
||||
datapipe = datapipe.exclude_seed_edges(
|
||||
include_reverse_edges=True,
|
||||
asynchronous=args.storage_device != "cpu",
|
||||
)
|
||||
|
||||
############################################################################
|
||||
|
||||
@@ -163,8 +163,9 @@ def create_dataloader(
|
||||
asynchronous=args.graph_device != "cpu",
|
||||
)
|
||||
if job == "train" and args.exclude_edges:
|
||||
datapipe = datapipe.transform(
|
||||
partial(gb.exclude_seed_edges, include_reverse_edges=True)
|
||||
datapipe = datapipe.exclude_seed_edges(
|
||||
include_reverse_edges=True,
|
||||
asynchronous=args.graph_device != "cpu",
|
||||
)
|
||||
# Copy the data to the specified device.
|
||||
if args.feature_device != "cpu" and need_copy:
|
||||
|
||||
@@ -79,10 +79,22 @@ Sort(torch::Tensor input, int num_bits = 0);
|
||||
* @return
|
||||
* A boolean tensor of the same shape as elements that is True for elements
|
||||
* in test_elements and False otherwise.
|
||||
*
|
||||
*/
|
||||
torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements);
|
||||
|
||||
/**
|
||||
* @brief Returns the indexes of the nonzero elements in the given boolean mask
|
||||
* if logical_not is false. Otherwise, returns the indexes of the zero elements
|
||||
* instead.
|
||||
*
|
||||
* @param mask Input boolean mask.
|
||||
* @param logical_not Whether mask should be treated as ~mask.
|
||||
*
|
||||
* @return An int64_t tensor with one entry per selected element: the indexes
|
||||
* of the selected elements.
|
||||
*/
|
||||
torch::Tensor Nonzero(torch::Tensor mask, bool logical_not);
|
||||
|
||||
/**
|
||||
* @brief Select columns for a sparse matrix in a CSC format according to nodes
|
||||
* tensor.
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#ifndef GRAPHBOLT_ISIN_H_
|
||||
#define GRAPHBOLT_ISIN_H_
|
||||
|
||||
#include <graphbolt/async.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
namespace graphbolt {
|
||||
@@ -25,11 +26,27 @@ namespace sampling {
|
||||
* @return
|
||||
* A boolean tensor of the same shape as elements that is True for elements
|
||||
* in test_elements and False otherwise.
|
||||
*
|
||||
*/
|
||||
torch::Tensor IsIn(
|
||||
const torch::Tensor& elements, const torch::Tensor& test_elements);
|
||||
|
||||
/**
|
||||
* @brief Tests if each element of elements is not in test_elements. Returns an
|
||||
* int64_t tensor, one entry per unmatched element, with the indexes of the
|
||||
* elements not found in test_elements.
|
||||
*
|
||||
* @param elements Input elements
|
||||
* @param test_elements Values against which to test for each input element.
|
||||
*
|
||||
* @return An int64_t tensor, one entry per unmatched element, with indexes of
|
||||
* elements not found in test_elements.
|
||||
*/
|
||||
torch::Tensor IsNotInIndex(
|
||||
const torch::Tensor& elements, const torch::Tensor& test_elements);
|
||||
|
||||
c10::intrusive_ptr<Future<torch::Tensor>> IsNotInIndexAsync(
|
||||
const torch::Tensor& elements, const torch::Tensor& test_elements);
|
||||
|
||||
} // namespace sampling
|
||||
} // namespace graphbolt
|
||||
|
||||
|
||||
@@ -20,6 +20,8 @@
|
||||
#include <graphbolt/cuda_ops.h>
|
||||
#include <thrust/binary_search.h>
|
||||
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
#include "./common.h"
|
||||
|
||||
namespace graphbolt {
|
||||
@@ -42,5 +44,25 @@ torch::Tensor IsIn(torch::Tensor elements, torch::Tensor test_elements) {
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
 * @brief Gathers the positions of the entries of a boolean mask that are set,
 * or of the entries that are unset when logical_not is true.
 *
 * @param mask Boolean mask; its storage is read through data_ptr<bool>().
 * @param logical_not Select the false entries instead of the true ones.
 *
 * @return The leading part of an int64_t buffer allocated like mask, sliced
 * along dimension 0 down to the number of positions that were selected.
 */
torch::Tensor Nonzero(torch::Tensor mask, bool logical_not) {
  // The candidates handed to the selection are the positions 0, 1, 2, ...
  thrust::counting_iterator<int64_t> positions(0);
  // Every position may be selected, so the output is allocated like the mask.
  auto indices = torch::empty_like(mask, torch::kInt64);
  auto flags = mask.data_ptr<bool>();
  auto out = indices.data_ptr<int64_t>();
  auto allocator = cuda::GetAllocator();
  // One-element scratch buffer receiving the number of selected positions.
  auto num_selected = allocator.AllocateStorage<int64_t>(1);
  const int64_t num_items = mask.numel();
  if (!logical_not) {
    // Keep the positions whose flag is true.
    CUB_CALL(
        DeviceSelect::Flagged, positions, flags, out, num_selected.get(),
        num_items);
  } else {
    // Keep the positions whose negated flag is true.
    CUB_CALL(
        DeviceSelect::FlaggedIf, positions, flags, out, num_selected.get(),
        num_items, thrust::logical_not<bool>{});
  }
  // NOTE(review): converting CopyScalar to int64_t presumably waits for the
  // count to reach the host -- confirm against cuda::CopyScalar.
  cuda::CopyScalar num_selected_host(num_selected.get());
  return indices.slice(0, 0, static_cast<int64_t>(num_selected_host));
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace graphbolt
|
||||
|
||||
@@ -56,5 +56,22 @@ torch::Tensor IsIn(
|
||||
return IsInCPU(elements, test_elements);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * @brief Returns the indexes of the entries of elements that do not occur in
 * test_elements.
 *
 * The membership mask comes from IsIn. A mask on the GPU is handed to
 * ops::Nonzero with logical_not set; otherwise the negated mask goes through
 * the regular nonzero operation.
 *
 * @param elements Input elements.
 * @param test_elements Values against which each input element is tested.
 *
 * @return Indexes of the entries of elements not found in test_elements.
 */
torch::Tensor IsNotInIndex(
    const torch::Tensor& elements, const torch::Tensor& test_elements) {
  auto found = IsIn(elements, test_elements);
  if (utils::is_on_gpu(found)) {
    GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
        c10::DeviceType::CUDA, "NonzeroOperation",
        { return ops::Nonzero(found, true); });
  }
  // nonzero() yields one row per hit; squeeze(1) drops the size-1 column.
  auto not_found = found.logical_not();
  return not_found.nonzero().squeeze(1);
}
|
||||
|
||||
/**
 * @brief Asynchronous variant of IsNotInIndex: wraps the call in async() and
 * returns the resulting future.
 *
 * @param elements Input elements.
 * @param test_elements Values against which each input element is tested.
 *
 * @return A future yielding the result of IsNotInIndex.
 */
c10::intrusive_ptr<Future<torch::Tensor>> IsNotInIndexAsync(
    const torch::Tensor& elements, const torch::Tensor& test_elements) {
  // Both tensors are captured by copy so the task does not hold references
  // to the caller's arguments.
  return async([elements, test_elements] {
    return IsNotInIndex(elements, test_elements);
  });
}
|
||||
|
||||
} // namespace sampling
|
||||
} // namespace graphbolt
|
||||
|
||||
@@ -181,6 +181,8 @@ TORCH_LIBRARY(graphbolt, m) {
|
||||
m.def("unique_and_compact_batched", &UniqueAndCompactBatched);
|
||||
m.def("unique_and_compact_batched_async", &UniqueAndCompactBatchedAsync);
|
||||
m.def("isin", &IsIn);
|
||||
m.def("is_not_in_index", &IsNotInIndex);
|
||||
m.def("is_not_in_index_async", &IsNotInIndexAsync);
|
||||
m.def("index_select", &ops::IndexSelect);
|
||||
m.def("index_select_async", &ops::IndexSelectAsync);
|
||||
m.def("scatter_async", &ops::ScatterAsync);
|
||||
|
||||
@@ -1,10 +1,60 @@
|
||||
"""Utility functions for external use."""
|
||||
|
||||
from functools import partial
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
|
||||
from torch.utils.data import functional_datapipe
|
||||
|
||||
from .minibatch import MiniBatch
|
||||
from .minibatch_transformer import MiniBatchTransformer
|
||||
|
||||
|
||||
@functional_datapipe("exclude_seed_edges")
class SeedEdgesExcluder(MiniBatchTransformer):
    """A mini-batch transformer that excludes the seed edges of every
    mini-batch from its sampled subgraphs by applying
    :func:`exclude_seed_edges`.

    Functional name: :obj:`exclude_seed_edges`.

    Parameters
    ----------
    datapipe : DataPipe
        The datapipe.
    include_reverse_edges : bool
        Whether reverse edges should be excluded as well. Default is False.
    reverse_etypes_mapping : Dict[str, str] = None
        The mapping from the original edge types to their reverse edge types.
    asynchronous: bool
        Boolean indicating whether edge exclusion stages should run on
        background threads to hide the latency of CPU-GPU synchronization.
        Should be enabled only when sampling on the GPU. Default is False.
    """

    def __init__(
        self,
        datapipe,
        include_reverse_edges: bool = False,
        reverse_etypes_mapping: Dict[str, str] = None,
        asynchronous: bool = False,
    ):
        exclude_seed_edges_fn = partial(
            exclude_seed_edges,
            include_reverse_edges=include_reverse_edges,
            reverse_etypes_mapping=reverse_etypes_mapping,
            async_op=asynchronous,
        )
        datapipe = datapipe.transform(exclude_seed_edges_fn)
        if asynchronous:
            # In asynchronous mode the previous stage leaves waiter objects in
            # `sampled_subgraphs`; a buffer stage is inserted before the stage
            # that resolves them with `wait()`.
            datapipe = datapipe.buffer()
            datapipe = datapipe.transform(self._wait_for_sampled_subgraphs)
        super().__init__(datapipe)

    @staticmethod
    def _wait_for_sampled_subgraphs(minibatch):
        """Replace every entry of `minibatch.sampled_subgraphs` with the
        result of calling `wait()` on it and return the same minibatch."""
        minibatch.sampled_subgraphs = [
            subgraph.wait() for subgraph in minibatch.sampled_subgraphs
        ]
        return minibatch
|
||||
|
||||
|
||||
def add_reverse_edges(
|
||||
@@ -79,6 +129,7 @@ def exclude_seed_edges(
|
||||
minibatch: MiniBatch,
|
||||
include_reverse_edges: bool = False,
|
||||
reverse_etypes_mapping: Dict[str, str] = None,
|
||||
async_op: bool = False,
|
||||
):
|
||||
"""
|
||||
Exclude seed edges with or without their reverse edges from the sampled
|
||||
@@ -88,8 +139,13 @@ def exclude_seed_edges(
|
||||
----------
|
||||
minibatch : MiniBatch
|
||||
The minibatch.
|
||||
include_reverse_edges : bool
|
||||
Whether reverse edges should be excluded as well. Default is False.
|
||||
reverse_etypes_mapping : Dict[str, str] = None
|
||||
The mapping from the original edge types to their reverse edge types.
|
||||
async_op: bool
|
||||
Boolean indicating whether the call is asynchronous. If so, the result
|
||||
can be obtained by calling wait on the modified sampled_subgraphs.
|
||||
"""
|
||||
edges_to_exclude = minibatch.seeds
|
||||
if include_reverse_edges:
|
||||
@@ -97,7 +153,7 @@ def exclude_seed_edges(
|
||||
edges_to_exclude, reverse_etypes_mapping
|
||||
)
|
||||
minibatch.sampled_subgraphs = [
|
||||
subgraph.exclude_edges(edges_to_exclude)
|
||||
subgraph.exclude_edges(edges_to_exclude, async_op=async_op)
|
||||
for subgraph in minibatch.sampled_subgraphs
|
||||
]
|
||||
return minibatch
|
||||
|
||||
@@ -20,6 +20,27 @@ from .internal_utils import recursive_apply
|
||||
__all__ = ["SampledSubgraph"]
|
||||
|
||||
|
||||
class _ExcludeEdgesWaiter:
    """Deferred result of an asynchronous edge exclusion.

    Holds the sampled subgraph together with the pending index: either a
    single object exposing `wait()` or a dict of such objects.
    """

    def __init__(self, sampled_subgraph, index):
        self.sampled_subgraph = sampled_subgraph
        self.index = index

    def wait(self):
        """Resolve the pending index and return the sliced subgraph.

        The result has the same type as the stored subgraph.
        """
        subgraph, pending = self.sampled_subgraph, self.index
        # Drop the stored references so this waiter no longer holds them.
        self.sampled_subgraph = None
        self.index = None

        if not isinstance(pending, dict):
            resolved = pending.wait()
        else:
            # Resolve in place; iterate over a snapshot since values are
            # reassigned during the loop.
            for key, future in list(pending.items()):
                pending[key] = future.wait()
            resolved = pending

        return type(subgraph)(*_slice_subgraph(subgraph, resolved))
|
||||
|
||||
|
||||
class PyGLayerData(NamedTuple):
|
||||
"""A named tuple class to represent homogenous inputs to a PyG model layer.
|
||||
The fields are x (input features), edge_index and size
|
||||
@@ -142,6 +163,7 @@ class SampledSubgraph:
|
||||
torch.Tensor,
|
||||
],
|
||||
assume_num_node_within_int32: bool = True,
|
||||
async_op: bool = False,
|
||||
):
|
||||
r"""Exclude edges from the sampled subgraph.
|
||||
|
||||
@@ -163,6 +185,9 @@ class SampledSubgraph:
|
||||
If True, assumes the value of node IDs in the provided `edges` fall
|
||||
within the int32 range, which can significantly enhance computation
|
||||
speed. Default: True
|
||||
async_op: bool
|
||||
Boolean indicating whether the call is asynchronous. If so, the
|
||||
result can be obtained by calling wait on the returned future.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -222,9 +247,8 @@ class SampledSubgraph:
|
||||
self.original_column_node_ids,
|
||||
)
|
||||
index = _exclude_homo_edges(
|
||||
reverse_edges, edges, assume_num_node_within_int32
|
||||
reverse_edges, edges, assume_num_node_within_int32, async_op
|
||||
)
|
||||
return calling_class(*_slice_subgraph(self, index))
|
||||
else:
|
||||
index = {}
|
||||
for etype, pair in self.sampled_csc.items():
|
||||
@@ -252,7 +276,11 @@ class SampledSubgraph:
|
||||
reverse_edges,
|
||||
edges[etype],
|
||||
assume_num_node_within_int32,
|
||||
async_op,
|
||||
)
|
||||
if async_op:
|
||||
return _ExcludeEdgesWaiter(self, index)
|
||||
else:
|
||||
return calling_class(*_slice_subgraph(self, index))
|
||||
|
||||
def to_pyg(
|
||||
@@ -367,6 +395,7 @@ def _exclude_homo_edges(
|
||||
edges: Tuple[torch.Tensor, torch.Tensor],
|
||||
edges_to_exclude: torch.Tensor,
|
||||
assume_num_node_within_int32: bool,
|
||||
async_op: bool,
|
||||
):
|
||||
"""Return the indices of edges to be included."""
|
||||
if assume_num_node_within_int32:
|
||||
@@ -381,8 +410,11 @@ def _exclude_homo_edges(
|
||||
raise NotImplementedError(
|
||||
"Values out of range int32 are not supported yet"
|
||||
)
|
||||
mask = ~isin(val, val_to_exclude)
|
||||
return torch.nonzero(mask, as_tuple=True)[0]
|
||||
if async_op:
|
||||
return torch.ops.graphbolt.is_not_in_index_async(val, val_to_exclude)
|
||||
else:
|
||||
mask = ~isin(val, val_to_exclude)
|
||||
return torch.nonzero(mask, as_tuple=True)[0]
|
||||
|
||||
|
||||
def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
|
||||
|
||||
@@ -72,7 +72,8 @@ def test_add_reverse_edges_hetero():
|
||||
F._default_context_str == "gpu",
|
||||
reason="Fails due to different result on the GPU.",
|
||||
)
|
||||
def test_exclude_seed_edges_homo_cpu():
|
||||
@pytest.mark.parametrize("use_datapipe", [False, True])
|
||||
def test_exclude_seed_edges_homo_cpu(use_datapipe):
|
||||
graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))
|
||||
graph = gb.from_dglgraph(graph, True).to(F.ctx())
|
||||
items = torch.LongTensor([[0, 3], [4, 4]])
|
||||
@@ -83,7 +84,10 @@ def test_exclude_seed_edges_homo_cpu():
|
||||
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
|
||||
sampler = gb.NeighborSampler
|
||||
datapipe = sampler(datapipe, graph, fanouts)
|
||||
datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
|
||||
if use_datapipe:
|
||||
datapipe = datapipe.exclude_seed_edges()
|
||||
else:
|
||||
datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
|
||||
original_row_node_ids = [
|
||||
torch.tensor([0, 3, 4, 5, 2, 6, 7]).to(F.ctx()),
|
||||
torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
|
||||
@@ -121,7 +125,9 @@ def test_exclude_seed_edges_homo_cpu():
|
||||
F._default_context_str == "cpu",
|
||||
reason="Fails due to different result on the CPU.",
|
||||
)
|
||||
def test_exclude_seed_edges_gpu():
|
||||
@pytest.mark.parametrize("use_datapipe", [False, True])
|
||||
@pytest.mark.parametrize("async_op", [False, True])
|
||||
def test_exclude_seed_edges_gpu(use_datapipe, async_op):
|
||||
graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4]))
|
||||
graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())
|
||||
items = torch.LongTensor([[0, 3], [4, 4]])
|
||||
@@ -137,7 +143,12 @@ def test_exclude_seed_edges_gpu():
|
||||
fanouts,
|
||||
deduplicate=True,
|
||||
)
|
||||
datapipe = datapipe.transform(partial(gb.exclude_seed_edges))
|
||||
if use_datapipe:
|
||||
datapipe = datapipe.exclude_seed_edges(asynchronous=async_op)
|
||||
else:
|
||||
datapipe = datapipe.transform(
|
||||
partial(gb.exclude_seed_edges, async_op=async_op)
|
||||
)
|
||||
if torch.cuda.get_device_capability()[0] < 7:
|
||||
original_row_node_ids = [
|
||||
torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
|
||||
@@ -174,6 +185,8 @@ def test_exclude_seed_edges_gpu():
|
||||
]
|
||||
for data in datapipe:
|
||||
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
|
||||
if async_op and not use_datapipe:
|
||||
sampled_subgraph = sampled_subgraph.wait()
|
||||
assert torch.equal(
|
||||
sampled_subgraph.original_row_node_ids,
|
||||
original_row_node_ids[step],
|
||||
|
||||
Reference in New Issue
Block a user