[GraphBolt][CUDA] Fetch indices for NS later. (#7665)

This commit is contained in:
Muhammed Fatih BALIN
2024-08-07 10:52:06 -04:00
committed by GitHub
parent 01d10e5820
commit b5ee45fd1a
9 changed files with 174 additions and 26 deletions

View File

@@ -54,6 +54,8 @@ namespace ops {
* @param layer Boolean indicating whether neighbors should be sampled in a
* layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of
* sampled edges, see arXiv:2210.13339.
* @param returning_indices_is_optional Boolean indicating whether returning
* the indices tensor is optional.
* @param type_per_edge A tensor representing the type of each edge, if present.
* @param probs_or_mask An optional tensor with (unnormalized) probabilities
* corresponding to each neighboring edge of a node. It must be
@@ -76,6 +78,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool returning_indices_is_optional,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
torch::optional<torch::Tensor> probs_or_mask = torch::nullopt,
torch::optional<torch::Tensor> node_type_offset = torch::nullopt,

View File

@@ -339,6 +339,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @param layer Boolean indicating whether neighbors should be sampled in a
* layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of
* sampled edges, see arXiv:2210.13339.
* @param returning_indices_is_optional Boolean indicating whether returning
* the indices tensor is optional.
* @param probs_or_mask An optional edge attribute tensor for probabilities
* or masks. This attribute tensor should contain (unnormalized)
* probabilities corresponding to each neighboring edge of a node. It must be
@@ -355,6 +357,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool returning_indices_is_optional,
torch::optional<torch::Tensor> probs_or_mask,
torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const;

View File

@@ -202,6 +202,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool returning_indices_is_optional,
torch::optional<torch::Tensor> type_per_edge,
torch::optional<torch::Tensor> probs_or_mask,
torch::optional<torch::Tensor> node_type_offset,
@@ -519,9 +520,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
}
}
// TODO @mfbalin: remove true from here once fetching indices later is
// setup.
if (true || layer || utils::is_on_gpu(indices)) {
if (!returning_indices_is_optional || utils::is_on_gpu(indices)) {
output_indices = Gather(indices, picked_eids);
}
}));

View File

@@ -804,6 +804,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
torch::optional<torch::Tensor> seeds,
torch::optional<std::vector<int64_t>> seed_offsets,
const std::vector<int64_t>& fanouts, bool replace, bool layer,
bool returning_indices_is_optional,
torch::optional<torch::Tensor> probs_or_mask,
torch::optional<torch::Tensor> random_seed,
double seed2_contribution) const {
@@ -828,9 +829,9 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
c10::DeviceType::CUDA, "SampleNeighbors", {
return ops::SampleNeighbors(
indptr_, indices_, seeds, seed_offsets, fanouts, replace, layer,
type_per_edge_, probs_or_mask, node_type_offset_,
node_type_to_id_, edge_type_to_id_, random_seed,
seed2_contribution);
returning_indices_is_optional, type_per_edge_, probs_or_mask,
node_type_offset_, node_type_to_id_, edge_type_to_id_,
random_seed, seed2_contribution);
});
}
TORCH_CHECK(seeds.has_value(), "Nodes can not be None on the CPU.");

View File

@@ -141,7 +141,7 @@ class DataLoader(torch_data.DataLoader):
of the computations can run simultaneously with it. Setting it to a too
high value will limit the amount of overlap while setting it too low may
cause the PCI-e bandwidth to not get fully utilized. Manually tuned
default is 6144, meaning around 3-4 Streaming Multiprocessors.
default is 10240, meaning around 5-7 Streaming Multiprocessors.
"""
def __init__(
@@ -152,7 +152,7 @@ class DataLoader(torch_data.DataLoader):
overlap_graph_fetch=False,
num_gpu_cached_edges=0,
gpu_cache_threshold=1,
max_uva_threads=6144,
max_uva_threads=10240,
):
# Multiprocessing requires two modifications to the datapipe:
#
@@ -215,15 +215,38 @@ class DataLoader(torch_data.DataLoader):
gpu_graph_cache = construct_gpu_graph_cache(
sampler, num_gpu_cached_edges, gpu_cache_threshold
)
datapipe_graph = replace_dp(
datapipe_graph,
sampler,
sampler.fetch_and_sample(
gpu_graph_cache,
get_host_to_device_uva_stream(),
1,
),
)
if (
sampler.sampler.__name__ == "sample_layer_neighbors"
or gpu_graph_cache is not None
):
# This code path is not faster for sample_neighbors.
datapipe_graph = replace_dp(
datapipe_graph,
sampler,
sampler.fetch_and_sample(
gpu_graph_cache,
get_host_to_device_uva_stream(),
1,
),
)
elif sampler.sampler.__name__ == "sample_neighbors":
# This code path is faster for sample_neighbors.
datapipe_graph = replace_dp(
datapipe_graph,
sampler,
sampler.datapipe.sample_per_layer(
sampler=sampler.sampler,
fanout=sampler.fanout,
replace=sampler.replace,
prob_name=sampler.prob_name,
returning_indices_is_optional=True,
),
)
else:
raise AssertionError(
"overlap_graph_fetch is supported only for "
"sample_neighbor and sample_layer_neighbor."
)
# (4) Cut datapipe at CopyTo and wrap with pinning and prefetching
# before it. This enables non_blocking copies to the device.

View File

@@ -694,6 +694,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
returning_indices_is_optional: bool = False,
) -> SampledSubgraphImpl:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
@@ -733,6 +734,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
corresponding to each neighboring edge of a node. It must be a 1D
floating-point or boolean tensor, with the number of elements
equalling the total number of edges.
returning_indices_is_optional: bool
Boolean indicating whether it is okay for the call to this function
to leave the indices tensor uninitialized. In this case, it is the
user's responsibility to gather it using the edge ids.
Returns
-------
@@ -776,6 +781,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts,
replace=replace,
probs_or_mask=probs_or_mask,
returning_indices_is_optional=returning_indices_is_optional,
)
return self._convert_to_sampled_subgraph(
C_sampled_subgraph, seed_offsets
@@ -827,6 +833,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor,
replace: bool = False,
probs_or_mask: Optional[torch.Tensor] = None,
returning_indices_is_optional: bool = False,
) -> torch.ScriptObject:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
@@ -865,6 +872,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
corresponding to each neighboring edge of a node. It must be a 1D
floating-point or boolean tensor, with the number of elements
equalling the total number of edges.
returning_indices_is_optional: bool
Boolean indicating whether it is okay for the call to this function
to leave the indices tensor uninitialized. In this case, it is the
user's responsibility to gather it using the edge ids.
Returns
-------
@@ -879,6 +890,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts.tolist(),
replace,
False, # is_labor
returning_indices_is_optional,
probs_or_mask,
None, # random_seed, labor parameter
0, # seed2_contribution, labor_parameter
@@ -890,6 +902,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts: torch.Tensor,
replace: bool = False,
probs_name: Optional[str] = None,
returning_indices_is_optional: bool = False,
random_seed: torch.Tensor = None,
seed2_contribution: float = 0.0,
) -> SampledSubgraphImpl:
@@ -933,6 +946,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
corresponding to each neighboring edge of a node. It must be a 1D
floating-point or boolean tensor, with the number of elements
equalling the total number of edges.
returning_indices_is_optional: bool
Boolean indicating whether it is okay for the call to this function
to leave the indices tensor uninitialized. In this case, it is the
user's responsibility to gather it using the edge ids.
random_seed: torch.Tensor, optional
An int64 tensor with one or two elements.
@@ -1012,6 +1029,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts.tolist(),
replace,
True, # is_labor
returning_indices_is_optional,
probs_or_mask,
random_seed,
seed2_contribution,

View File

@@ -6,7 +6,12 @@ import torch
from torch.utils.data import functional_datapipe
from torch.utils.data.datapipes.iter import Mapper
from ..base import ORIGINAL_EDGE_ID
from ..base import (
etype_str_to_tuple,
get_host_to_device_uva_stream,
index_select,
ORIGINAL_EDGE_ID,
)
from ..internal import compact_csc_format, unique_and_compact_csc_formats
from ..minibatch_transformer import MiniBatchTransformer
@@ -138,6 +143,8 @@ class FetchInsubgraphData(Mapper):
delattr(minibatch, "_seeds")
delattr(minibatch, "_seed_offsets")
seeds.record_stream(torch.cuda.current_stream())
def record_stream(tensor):
if stream is not None and tensor.is_cuda:
tensor.record_stream(stream)
@@ -251,12 +258,40 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer):
class SamplePerLayer(MiniBatchTransformer):
"""Sample neighbor edges from a graph for a single layer."""
def __init__(self, datapipe, sampler, fanout, replace, prob_name):
super().__init__(datapipe, self._sample_per_layer)
def __init__(
self,
datapipe,
sampler,
fanout,
replace,
prob_name,
returning_indices_is_optional=False,
):
graph = sampler.__self__
if returning_indices_is_optional and graph.indices.is_pinned():
datapipe = datapipe.transform(self._sample_per_layer)
datapipe = (
datapipe.transform(partial(self._fetch_indices, graph.indices))
.buffer()
.wait()
)
if graph.type_per_edge is not None:
# Hetero case.
datapipe = datapipe.transform(
partial(
self._subtract_hetero_indices_offset,
graph._node_type_offset_list,
graph.node_type_to_id,
)
)
super().__init__(datapipe)
else:
super().__init__(datapipe, self._sample_per_layer)
self.sampler = sampler
self.fanout = fanout
self.replace = replace
self.prob_name = prob_name
self.returning_indices_is_optional = returning_indices_is_optional
def _sample_per_layer(self, minibatch):
kwargs = {
@@ -269,11 +304,61 @@ class SamplePerLayer(MiniBatchTransformer):
self.fanout,
self.replace,
self.prob_name,
self.returning_indices_is_optional,
**kwargs,
)
minibatch.sampled_subgraphs.insert(0, subgraph)
return minibatch
# Fills in the `indices` of the first sampled subgraph in `minibatch` wherever
# it is still None, by calling index_select on the `indices` argument with the
# sampled edge ids. Also sets the `_indices_needs_offset_subtraction` flag and
# a `wait` callable on the minibatch, then returns the minibatch.
# NOTE(review): leading indentation is not visible in this view, so exactly
# which statements are nested inside the `with` block cannot be confirmed here.
@staticmethod
def _fetch_indices(indices, minibatch):
# Remember the stream that is current on entry; the host-to-device UVA stream
# is made to wait on it before any work is issued.
stream = torch.cuda.current_stream()
host_to_device_stream = get_host_to_device_uva_stream()
host_to_device_stream.wait_stream(stream)
# Helper: calls Tensor.record_stream with the entry stream on `tensor` and
# returns the same tensor so the call can be chained.
def record_stream(tensor):
tensor.record_stream(stream)
return tensor
with torch.cuda.stream(host_to_device_stream):
# Starts as False; switched to True below whenever an indices tensor is
# actually fetched in this call.
minibatch._indices_needs_offset_subtraction = False
subgraph = minibatch.sampled_subgraphs[0]
# Dict-valued sampled_csc: one (etype -> pair) entry per edge type, with the
# matching edge ids looked up per etype.
if isinstance(subgraph.sampled_csc, dict):
for etype, pair in subgraph.sampled_csc.items():
if pair.indices is None:
edge_ids = subgraph._sampled_edge_ids[etype]
edge_ids.record_stream(torch.cuda.current_stream())
pair.indices = record_stream(
index_select(indices, edge_ids)
)
minibatch._indices_needs_offset_subtraction = True
# Non-dict sampled_csc: a single indices tensor and a single edge id tensor.
elif subgraph.sampled_csc.indices is None:
subgraph._sampled_edge_ids.record_stream(
torch.cuda.current_stream()
)
subgraph.sampled_csc.indices = record_stream(
index_select(indices, subgraph._sampled_edge_ids)
)
minibatch._indices_needs_offset_subtraction = True
# The temporary edge ids are dropped once they have been consumed.
subgraph._sampled_edge_ids = None
# Records an event on the current stream and exposes its `wait` method on the
# minibatch; presumably consumed by a later `.wait()` datapipe stage - confirm.
minibatch.wait = torch.cuda.current_stream().record_event().wait
return minibatch
# When `minibatch._indices_needs_offset_subtraction` is truthy, subtracts a
# per-edge-type offset in place from `pair.indices` for each entry in the
# first sampled subgraph's `sampled_csc`. The flag attribute is removed with
# delattr and the minibatch is returned.
# NOTE(review): presumably this turns ids gathered from the graph-wide indices
# tensor into per-node-type ids - confirm against the sampler's hetero path.
# NOTE(review): indentation is not visible in this view, so whether the
# delattr runs unconditionally cannot be confirmed from here.
@staticmethod
def _subtract_hetero_indices_offset(
node_type_offset, node_type_to_id, minibatch
):
if minibatch._indices_needs_offset_subtraction:
subgraph = minibatch.sampled_subgraphs[0]
for etype, pair in subgraph.sampled_csc.items():
# The first element of the parsed edge type is used as the source node type;
# its id selects the offset to subtract.
src_ntype = etype_str_to_tuple(etype)[0]
src_ntype_id = node_type_to_id[src_ntype]
pair.indices -= node_type_offset[src_ntype_id]
delattr(minibatch, "_indices_needs_offset_subtraction")
return minibatch
@functional_datapipe("compact_per_layer")
class CompactPerLayer(MiniBatchTransformer):

View File

@@ -44,8 +44,9 @@ def get_hetero_graph():
@pytest.mark.parametrize("prob_name", [None, "weight", "mask"])
@pytest.mark.parametrize("sorted", [False, True])
@pytest.mark.parametrize("num_cached_edges", [0, 10])
@pytest.mark.parametrize("is_pinned", [False, True])
def test_NeighborSampler_GraphFetch(
hetero, prob_name, sorted, num_cached_edges
hetero, prob_name, sorted, num_cached_edges, is_pinned
):
if sorted:
items = torch.arange(3)
@@ -53,7 +54,8 @@ def test_NeighborSampler_GraphFetch(
items = torch.tensor([2, 0, 1])
names = "seeds"
itemset = gb.ItemSet(items, names=names)
graph = get_hetero_graph().to(F.ctx())
graph = get_hetero_graph()
graph = graph.pin_memory_() if is_pinned else graph.to(F.ctx())
if hetero:
itemset = gb.HeteroItemSet({"n3": itemset})
else:
@@ -65,7 +67,7 @@ def test_NeighborSampler_GraphFetch(
partial(gb.NeighborSampler._prepare, graph.node_type_to_id)
)
sample_per_layer = gb.SamplePerLayer(
datapipe, graph.sample_neighbors, fanout, False, prob_name
datapipe, graph.sample_neighbors, fanout, False, prob_name, False
)
compact_per_layer = sample_per_layer.compact_per_layer(True)
gb.seed(123)
@@ -92,6 +94,21 @@ def test_NeighborSampler_GraphFetch(
for a, b in zip(expected_results, new_results):
assert repr(a) == repr(b)
def remove_input_nodes(minibatch):
minibatch.input_nodes = None
return minibatch
datapipe = item_sampler.sample_neighbor(
graph, [fanout], False, prob_name=prob_name
)
datapipe = datapipe.transform(remove_input_nodes)
dataloader = gb.DataLoader(datapipe, overlap_graph_fetch=True)
gb.seed(123)
new_results = list(dataloader)
assert len(expected_results) == len(new_results)
for a, b in zip(expected_results, new_results):
assert repr(a) == repr(b)
@pytest.mark.parametrize("layer_dependency", [False, True])
@pytest.mark.parametrize("overlap_graph_fetch", [False, True])

View File

@@ -77,9 +77,8 @@ def test_gpu_sampling_DataLoader(
B = 4
num_layers = 2
itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seeds")
graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True).to(
F.ctx()
)
graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True)
graph = graph.pin_memory_() if overlap_graph_fetch else graph.to(F.ctx())
features = {}
keys = [
("node", None, "a"),