mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt][CUDA] Fetch indices for NS later. (#7665)
This commit is contained in:
committed by
GitHub
parent
01d10e5820
commit
b5ee45fd1a
@@ -54,6 +54,8 @@ namespace ops {
|
||||
* @param layer Boolean indicating whether neighbors should be sampled in a
|
||||
* layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of
|
||||
* sampled edges, see arXiv:2210.13339.
|
||||
* @param returning_indices_is_optional Boolean indicating whether returning
|
||||
* the indices tensor is optional.
|
||||
* @param type_per_edge A tensor representing the type of each edge, if present.
|
||||
* @param probs_or_mask An optional tensor with (unnormalized) probabilities
|
||||
* corresponding to each neighboring edge of a node. It must be
|
||||
@@ -76,6 +78,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
|
||||
torch::optional<torch::Tensor> seeds,
|
||||
torch::optional<std::vector<int64_t>> seed_offsets,
|
||||
const std::vector<int64_t>& fanouts, bool replace, bool layer,
|
||||
bool returning_indices_is_optional,
|
||||
torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
|
||||
torch::optional<torch::Tensor> probs_or_mask = torch::nullopt,
|
||||
torch::optional<torch::Tensor> node_type_offset = torch::nullopt,
|
||||
|
||||
@@ -339,6 +339,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
|
||||
* @param layer Boolean indicating whether neighbors should be sampled in a
|
||||
* layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of
|
||||
* sampled edges, see arXiv:2210.13339.
|
||||
* @param returning_indices_is_optional Boolean indicating whether returning
|
||||
* the indices tensor is optional.
|
||||
* @param probs_or_mask An optional edge attribute tensor for probabilities
|
||||
* or masks. This attribute tensor should contain (unnormalized)
|
||||
* probabilities corresponding to each neighboring edge of a node. It must be
|
||||
@@ -355,6 +357,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
|
||||
torch::optional<torch::Tensor> seeds,
|
||||
torch::optional<std::vector<int64_t>> seed_offsets,
|
||||
const std::vector<int64_t>& fanouts, bool replace, bool layer,
|
||||
bool returning_indices_is_optional,
|
||||
torch::optional<torch::Tensor> probs_or_mask,
|
||||
torch::optional<torch::Tensor> random_seed,
|
||||
double seed2_contribution) const;
|
||||
|
||||
@@ -202,6 +202,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
|
||||
torch::optional<torch::Tensor> seeds,
|
||||
torch::optional<std::vector<int64_t>> seed_offsets,
|
||||
const std::vector<int64_t>& fanouts, bool replace, bool layer,
|
||||
bool returning_indices_is_optional,
|
||||
torch::optional<torch::Tensor> type_per_edge,
|
||||
torch::optional<torch::Tensor> probs_or_mask,
|
||||
torch::optional<torch::Tensor> node_type_offset,
|
||||
@@ -519,9 +520,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
|
||||
}
|
||||
}
|
||||
|
||||
// TODO @mfbalin: remove true from here once fetching indices later is
|
||||
// setup.
|
||||
if (true || layer || utils::is_on_gpu(indices)) {
|
||||
if (!returning_indices_is_optional || utils::is_on_gpu(indices)) {
|
||||
output_indices = Gather(indices, picked_eids);
|
||||
}
|
||||
}));
|
||||
|
||||
@@ -804,6 +804,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
|
||||
torch::optional<torch::Tensor> seeds,
|
||||
torch::optional<std::vector<int64_t>> seed_offsets,
|
||||
const std::vector<int64_t>& fanouts, bool replace, bool layer,
|
||||
bool returning_indices_is_optional,
|
||||
torch::optional<torch::Tensor> probs_or_mask,
|
||||
torch::optional<torch::Tensor> random_seed,
|
||||
double seed2_contribution) const {
|
||||
@@ -828,9 +829,9 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
|
||||
c10::DeviceType::CUDA, "SampleNeighbors", {
|
||||
return ops::SampleNeighbors(
|
||||
indptr_, indices_, seeds, seed_offsets, fanouts, replace, layer,
|
||||
type_per_edge_, probs_or_mask, node_type_offset_,
|
||||
node_type_to_id_, edge_type_to_id_, random_seed,
|
||||
seed2_contribution);
|
||||
returning_indices_is_optional, type_per_edge_, probs_or_mask,
|
||||
node_type_offset_, node_type_to_id_, edge_type_to_id_,
|
||||
random_seed, seed2_contribution);
|
||||
});
|
||||
}
|
||||
TORCH_CHECK(seeds.has_value(), "Nodes can not be None on the CPU.");
|
||||
|
||||
@@ -141,7 +141,7 @@ class DataLoader(torch_data.DataLoader):
|
||||
of the computations can run simultaneously with it. Setting it to a too
|
||||
high value will limit the amount of overlap while setting it too low may
|
||||
cause the PCI-e bandwidth to not get fully utilized. Manually tuned
|
||||
default is 6144, meaning around 3-4 Streaming Multiprocessors.
|
||||
default is 10240, meaning around 5-7 Streaming Multiprocessors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -152,7 +152,7 @@ class DataLoader(torch_data.DataLoader):
|
||||
overlap_graph_fetch=False,
|
||||
num_gpu_cached_edges=0,
|
||||
gpu_cache_threshold=1,
|
||||
max_uva_threads=6144,
|
||||
max_uva_threads=10240,
|
||||
):
|
||||
# Multiprocessing requires two modifications to the datapipe:
|
||||
#
|
||||
@@ -215,15 +215,38 @@ class DataLoader(torch_data.DataLoader):
|
||||
gpu_graph_cache = construct_gpu_graph_cache(
|
||||
sampler, num_gpu_cached_edges, gpu_cache_threshold
|
||||
)
|
||||
datapipe_graph = replace_dp(
|
||||
datapipe_graph,
|
||||
sampler,
|
||||
sampler.fetch_and_sample(
|
||||
gpu_graph_cache,
|
||||
get_host_to_device_uva_stream(),
|
||||
1,
|
||||
),
|
||||
)
|
||||
if (
|
||||
sampler.sampler.__name__ == "sample_layer_neighbors"
|
||||
or gpu_graph_cache is not None
|
||||
):
|
||||
# This code path is not faster for sample_neighbors.
|
||||
datapipe_graph = replace_dp(
|
||||
datapipe_graph,
|
||||
sampler,
|
||||
sampler.fetch_and_sample(
|
||||
gpu_graph_cache,
|
||||
get_host_to_device_uva_stream(),
|
||||
1,
|
||||
),
|
||||
)
|
||||
elif sampler.sampler.__name__ == "sample_neighbors":
|
||||
# This code path is faster for sample_neighbors.
|
||||
datapipe_graph = replace_dp(
|
||||
datapipe_graph,
|
||||
sampler,
|
||||
sampler.datapipe.sample_per_layer(
|
||||
sampler=sampler.sampler,
|
||||
fanout=sampler.fanout,
|
||||
replace=sampler.replace,
|
||||
prob_name=sampler.prob_name,
|
||||
returning_indices_is_optional=True,
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise AssertionError(
|
||||
"overlap_graph_fetch is supported only for "
|
||||
"sample_neighbor and sample_layer_neighbor."
|
||||
)
|
||||
|
||||
# (4) Cut datapipe at CopyTo and wrap with pinning and prefetching
|
||||
# before it. This enables non_blocking copies to the device.
|
||||
|
||||
@@ -694,6 +694,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
|
||||
fanouts: torch.Tensor,
|
||||
replace: bool = False,
|
||||
probs_name: Optional[str] = None,
|
||||
returning_indices_is_optional: bool = False,
|
||||
) -> SampledSubgraphImpl:
|
||||
"""Sample neighboring edges of the given nodes and return the induced
|
||||
subgraph.
|
||||
@@ -733,6 +734,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
|
||||
corresponding to each neighboring edge of a node. It must be a 1D
|
||||
floating-point or boolean tensor, with the number of elements
|
||||
equalling the total number of edges.
|
||||
returning_indices_is_optional: bool
|
||||
Boolean indicating whether it is okay for the call to this function
|
||||
to leave the indices tensor uninitialized. In this case, it is the
|
||||
user's responsibility to gather it using the edge ids.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -776,6 +781,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
|
||||
fanouts,
|
||||
replace=replace,
|
||||
probs_or_mask=probs_or_mask,
|
||||
returning_indices_is_optional=returning_indices_is_optional,
|
||||
)
|
||||
return self._convert_to_sampled_subgraph(
|
||||
C_sampled_subgraph, seed_offsets
|
||||
@@ -827,6 +833,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
|
||||
fanouts: torch.Tensor,
|
||||
replace: bool = False,
|
||||
probs_or_mask: Optional[torch.Tensor] = None,
|
||||
returning_indices_is_optional: bool = False,
|
||||
) -> torch.ScriptObject:
|
||||
"""Sample neighboring edges of the given nodes and return the induced
|
||||
subgraph.
|
||||
@@ -865,6 +872,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
|
||||
corresponding to each neighboring edge of a node. It must be a 1D
|
||||
floating-point or boolean tensor, with the number of elements
|
||||
equalling the total number of edges.
|
||||
returning_indices_is_optional: bool
|
||||
Boolean indicating whether it is okay for the call to this function
|
||||
to leave the indices tensor uninitialized. In this case, it is the
|
||||
user's responsibility to gather it using the edge ids.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -879,6 +890,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
|
||||
fanouts.tolist(),
|
||||
replace,
|
||||
False, # is_labor
|
||||
returning_indices_is_optional,
|
||||
probs_or_mask,
|
||||
None, # random_seed, labor parameter
|
||||
0, # seed2_contribution, labor_parameter
|
||||
@@ -890,6 +902,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
|
||||
fanouts: torch.Tensor,
|
||||
replace: bool = False,
|
||||
probs_name: Optional[str] = None,
|
||||
returning_indices_is_optional: bool = False,
|
||||
random_seed: torch.Tensor = None,
|
||||
seed2_contribution: float = 0.0,
|
||||
) -> SampledSubgraphImpl:
|
||||
@@ -933,6 +946,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
|
||||
corresponding to each neighboring edge of a node. It must be a 1D
|
||||
floating-point or boolean tensor, with the number of elements
|
||||
equalling the total number of edges.
|
||||
returning_indices_is_optional: bool
|
||||
Boolean indicating whether it is okay for the call to this function
|
||||
to leave the indices tensor uninitialized. In this case, it is the
|
||||
user's responsibility to gather it using the edge ids.
|
||||
random_seed: torch.Tensor, optional
|
||||
An int64 tensor with one or two elements.
|
||||
|
||||
@@ -1012,6 +1029,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
|
||||
fanouts.tolist(),
|
||||
replace,
|
||||
True, # is_labor
|
||||
returning_indices_is_optional,
|
||||
probs_or_mask,
|
||||
random_seed,
|
||||
seed2_contribution,
|
||||
|
||||
@@ -6,7 +6,12 @@ import torch
|
||||
from torch.utils.data import functional_datapipe
|
||||
from torch.utils.data.datapipes.iter import Mapper
|
||||
|
||||
from ..base import ORIGINAL_EDGE_ID
|
||||
from ..base import (
|
||||
etype_str_to_tuple,
|
||||
get_host_to_device_uva_stream,
|
||||
index_select,
|
||||
ORIGINAL_EDGE_ID,
|
||||
)
|
||||
from ..internal import compact_csc_format, unique_and_compact_csc_formats
|
||||
from ..minibatch_transformer import MiniBatchTransformer
|
||||
|
||||
@@ -138,6 +143,8 @@ class FetchInsubgraphData(Mapper):
|
||||
delattr(minibatch, "_seeds")
|
||||
delattr(minibatch, "_seed_offsets")
|
||||
|
||||
seeds.record_stream(torch.cuda.current_stream())
|
||||
|
||||
def record_stream(tensor):
|
||||
if stream is not None and tensor.is_cuda:
|
||||
tensor.record_stream(stream)
|
||||
@@ -251,12 +258,40 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer):
|
||||
class SamplePerLayer(MiniBatchTransformer):
|
||||
"""Sample neighbor edges from a graph for a single layer."""
|
||||
|
||||
def __init__(self, datapipe, sampler, fanout, replace, prob_name):
|
||||
super().__init__(datapipe, self._sample_per_layer)
|
||||
def __init__(
|
||||
self,
|
||||
datapipe,
|
||||
sampler,
|
||||
fanout,
|
||||
replace,
|
||||
prob_name,
|
||||
returning_indices_is_optional=False,
|
||||
):
|
||||
graph = sampler.__self__
|
||||
if returning_indices_is_optional and graph.indices.is_pinned():
|
||||
datapipe = datapipe.transform(self._sample_per_layer)
|
||||
datapipe = (
|
||||
datapipe.transform(partial(self._fetch_indices, graph.indices))
|
||||
.buffer()
|
||||
.wait()
|
||||
)
|
||||
if graph.type_per_edge is not None:
|
||||
# Hetero case.
|
||||
datapipe = datapipe.transform(
|
||||
partial(
|
||||
self._subtract_hetero_indices_offset,
|
||||
graph._node_type_offset_list,
|
||||
graph.node_type_to_id,
|
||||
)
|
||||
)
|
||||
super().__init__(datapipe)
|
||||
else:
|
||||
super().__init__(datapipe, self._sample_per_layer)
|
||||
self.sampler = sampler
|
||||
self.fanout = fanout
|
||||
self.replace = replace
|
||||
self.prob_name = prob_name
|
||||
self.returning_indices_is_optional = returning_indices_is_optional
|
||||
|
||||
def _sample_per_layer(self, minibatch):
|
||||
kwargs = {
|
||||
@@ -269,11 +304,61 @@ class SamplePerLayer(MiniBatchTransformer):
|
||||
self.fanout,
|
||||
self.replace,
|
||||
self.prob_name,
|
||||
self.returning_indices_is_optional,
|
||||
**kwargs,
|
||||
)
|
||||
minibatch.sampled_subgraphs.insert(0, subgraph)
|
||||
return minibatch
|
||||
|
||||
@staticmethod
|
||||
def _fetch_indices(indices, minibatch):
|
||||
stream = torch.cuda.current_stream()
|
||||
host_to_device_stream = get_host_to_device_uva_stream()
|
||||
host_to_device_stream.wait_stream(stream)
|
||||
|
||||
def record_stream(tensor):
|
||||
tensor.record_stream(stream)
|
||||
return tensor
|
||||
|
||||
with torch.cuda.stream(host_to_device_stream):
|
||||
minibatch._indices_needs_offset_subtraction = False
|
||||
subgraph = minibatch.sampled_subgraphs[0]
|
||||
if isinstance(subgraph.sampled_csc, dict):
|
||||
for etype, pair in subgraph.sampled_csc.items():
|
||||
if pair.indices is None:
|
||||
edge_ids = subgraph._sampled_edge_ids[etype]
|
||||
edge_ids.record_stream(torch.cuda.current_stream())
|
||||
pair.indices = record_stream(
|
||||
index_select(indices, edge_ids)
|
||||
)
|
||||
minibatch._indices_needs_offset_subtraction = True
|
||||
elif subgraph.sampled_csc.indices is None:
|
||||
subgraph._sampled_edge_ids.record_stream(
|
||||
torch.cuda.current_stream()
|
||||
)
|
||||
subgraph.sampled_csc.indices = record_stream(
|
||||
index_select(indices, subgraph._sampled_edge_ids)
|
||||
)
|
||||
minibatch._indices_needs_offset_subtraction = True
|
||||
subgraph._sampled_edge_ids = None
|
||||
minibatch.wait = torch.cuda.current_stream().record_event().wait
|
||||
|
||||
return minibatch
|
||||
|
||||
@staticmethod
|
||||
def _subtract_hetero_indices_offset(
|
||||
node_type_offset, node_type_to_id, minibatch
|
||||
):
|
||||
if minibatch._indices_needs_offset_subtraction:
|
||||
subgraph = minibatch.sampled_subgraphs[0]
|
||||
for etype, pair in subgraph.sampled_csc.items():
|
||||
src_ntype = etype_str_to_tuple(etype)[0]
|
||||
src_ntype_id = node_type_to_id[src_ntype]
|
||||
pair.indices -= node_type_offset[src_ntype_id]
|
||||
delattr(minibatch, "_indices_needs_offset_subtraction")
|
||||
|
||||
return minibatch
|
||||
|
||||
|
||||
@functional_datapipe("compact_per_layer")
|
||||
class CompactPerLayer(MiniBatchTransformer):
|
||||
|
||||
@@ -44,8 +44,9 @@ def get_hetero_graph():
|
||||
@pytest.mark.parametrize("prob_name", [None, "weight", "mask"])
|
||||
@pytest.mark.parametrize("sorted", [False, True])
|
||||
@pytest.mark.parametrize("num_cached_edges", [0, 10])
|
||||
@pytest.mark.parametrize("is_pinned", [False, True])
|
||||
def test_NeighborSampler_GraphFetch(
|
||||
hetero, prob_name, sorted, num_cached_edges
|
||||
hetero, prob_name, sorted, num_cached_edges, is_pinned
|
||||
):
|
||||
if sorted:
|
||||
items = torch.arange(3)
|
||||
@@ -53,7 +54,8 @@ def test_NeighborSampler_GraphFetch(
|
||||
items = torch.tensor([2, 0, 1])
|
||||
names = "seeds"
|
||||
itemset = gb.ItemSet(items, names=names)
|
||||
graph = get_hetero_graph().to(F.ctx())
|
||||
graph = get_hetero_graph()
|
||||
graph = graph.pin_memory_() if is_pinned else graph.to(F.ctx())
|
||||
if hetero:
|
||||
itemset = gb.HeteroItemSet({"n3": itemset})
|
||||
else:
|
||||
@@ -65,7 +67,7 @@ def test_NeighborSampler_GraphFetch(
|
||||
partial(gb.NeighborSampler._prepare, graph.node_type_to_id)
|
||||
)
|
||||
sample_per_layer = gb.SamplePerLayer(
|
||||
datapipe, graph.sample_neighbors, fanout, False, prob_name
|
||||
datapipe, graph.sample_neighbors, fanout, False, prob_name, False
|
||||
)
|
||||
compact_per_layer = sample_per_layer.compact_per_layer(True)
|
||||
gb.seed(123)
|
||||
@@ -92,6 +94,21 @@ def test_NeighborSampler_GraphFetch(
|
||||
for a, b in zip(expected_results, new_results):
|
||||
assert repr(a) == repr(b)
|
||||
|
||||
def remove_input_nodes(minibatch):
|
||||
minibatch.input_nodes = None
|
||||
return minibatch
|
||||
|
||||
datapipe = item_sampler.sample_neighbor(
|
||||
graph, [fanout], False, prob_name=prob_name
|
||||
)
|
||||
datapipe = datapipe.transform(remove_input_nodes)
|
||||
dataloader = gb.DataLoader(datapipe, overlap_graph_fetch=True)
|
||||
gb.seed(123)
|
||||
new_results = list(dataloader)
|
||||
assert len(expected_results) == len(new_results)
|
||||
for a, b in zip(expected_results, new_results):
|
||||
assert repr(a) == repr(b)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("layer_dependency", [False, True])
|
||||
@pytest.mark.parametrize("overlap_graph_fetch", [False, True])
|
||||
|
||||
@@ -77,9 +77,8 @@ def test_gpu_sampling_DataLoader(
|
||||
B = 4
|
||||
num_layers = 2
|
||||
itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seeds")
|
||||
graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True).to(
|
||||
F.ctx()
|
||||
)
|
||||
graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True)
|
||||
graph = graph.pin_memory_() if overlap_graph_fetch else graph.to(F.ctx())
|
||||
features = {}
|
||||
keys = [
|
||||
("node", None, "a"),
|
||||
|
||||
Reference in New Issue
Block a user