diff --git a/graphbolt/include/graphbolt/cuda_sampling_ops.h b/graphbolt/include/graphbolt/cuda_sampling_ops.h index 8dd001f4ea..3d22204e2e 100644 --- a/graphbolt/include/graphbolt/cuda_sampling_ops.h +++ b/graphbolt/include/graphbolt/cuda_sampling_ops.h @@ -54,6 +54,8 @@ namespace ops { * @param layer Boolean indicating whether neighbors should be sampled in a * layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of * sampled edges, see arXiv:2210.13339. + * @param returning_indices_is_optional Boolean indicating whether returning + * indices tensor is optional. * @param type_per_edge A tensor representing the type of each edge, if present. * @param probs_or_mask An optional tensor with (unnormalized) probabilities * corresponding to each neighboring edge of a node. It must be @@ -76,6 +78,7 @@ c10::intrusive_ptr SampleNeighbors( torch::optional seeds, torch::optional> seed_offsets, const std::vector& fanouts, bool replace, bool layer, + bool returning_indices_is_optional, torch::optional type_per_edge = torch::nullopt, torch::optional probs_or_mask = torch::nullopt, torch::optional node_type_offset = torch::nullopt, diff --git a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h index 420d8522a4..652b768896 100644 --- a/graphbolt/include/graphbolt/fused_csc_sampling_graph.h +++ b/graphbolt/include/graphbolt/fused_csc_sampling_graph.h @@ -339,6 +339,8 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { * @param layer Boolean indicating whether neighbors should be sampled in a * layer sampling fashion. Uses the LABOR-0 algorithm to increase overlap of * sampled edges, see arXiv:2210.13339. + * @param returning_indices_is_optional Boolean indicating whether returning + * indices tensor is optional. * @param probs_or_mask An optional edge attribute tensor for probablities * or masks. This attribute tensor should contain (unnormalized) * probabilities corresponding to each neighboring edge of a node. It must be @@ -355,6 +357,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { torch::optional seeds, torch::optional> seed_offsets, const std::vector& fanouts, bool replace, bool layer, + bool returning_indices_is_optional, torch::optional probs_or_mask, torch::optional random_seed, double seed2_contribution) const; diff --git a/graphbolt/src/cuda/neighbor_sampler.cu b/graphbolt/src/cuda/neighbor_sampler.cu index cecb1dec11..76fe26c6d8 100644 --- a/graphbolt/src/cuda/neighbor_sampler.cu +++ b/graphbolt/src/cuda/neighbor_sampler.cu @@ -202,6 +202,7 @@ c10::intrusive_ptr SampleNeighbors( torch::optional seeds, torch::optional> seed_offsets, const std::vector& fanouts, bool replace, bool layer, + bool returning_indices_is_optional, torch::optional type_per_edge, torch::optional probs_or_mask, torch::optional node_type_offset, @@ -519,9 +520,7 @@ c10::intrusive_ptr SampleNeighbors( } } - // TODO @mfbalin: remove true from here once fetching indices later is - // setup. - if (true || layer || utils::is_on_gpu(indices)) { + if (!returning_indices_is_optional || utils::is_on_gpu(indices)) { output_indices = Gather(indices, picked_eids); } })); diff --git a/graphbolt/src/fused_csc_sampling_graph.cc b/graphbolt/src/fused_csc_sampling_graph.cc index 1a7914775f..4a5c197bae 100644 --- a/graphbolt/src/fused_csc_sampling_graph.cc +++ b/graphbolt/src/fused_csc_sampling_graph.cc @@ -804,6 +804,7 @@ c10::intrusive_ptr FusedCSCSamplingGraph::SampleNeighbors( torch::optional seeds, torch::optional> seed_offsets, const std::vector& fanouts, bool replace, bool layer, + bool returning_indices_is_optional, torch::optional probs_or_mask, torch::optional random_seed, double seed2_contribution) const { @@ -828,9 +829,9 @@ c10::intrusive_ptr FusedCSCSamplingGraph::SampleNeighbors( c10::DeviceType::CUDA, "SampleNeighbors", { return ops::SampleNeighbors( indptr_, indices_, seeds, seed_offsets, fanouts, replace, layer, - type_per_edge_, probs_or_mask, node_type_offset_, - node_type_to_id_, edge_type_to_id_, random_seed, - seed2_contribution); + returning_indices_is_optional, type_per_edge_, probs_or_mask, + node_type_offset_, node_type_to_id_, edge_type_to_id_, + random_seed, seed2_contribution); }); } TORCH_CHECK(seeds.has_value(), "Nodes can not be None on the CPU."); diff --git a/python/dgl/graphbolt/dataloader.py b/python/dgl/graphbolt/dataloader.py index c332befcbc..4caa67ef9f 100644 --- a/python/dgl/graphbolt/dataloader.py +++ b/python/dgl/graphbolt/dataloader.py @@ -141,7 +141,7 @@ class DataLoader(torch_data.DataLoader): of the computations can run simultaneously with it. Setting it to a too high value will limit the amount of overlap while setting it too low may cause the PCI-e bandwidth to not get fully utilized. Manually tuned - default is 6144, meaning around 3-4 Streaming Multiprocessors. + default is 10240, meaning around 5-7 Streaming Multiprocessors. """ def __init__( @@ -152,7 +152,7 @@ class DataLoader(torch_data.DataLoader): overlap_graph_fetch=False, num_gpu_cached_edges=0, gpu_cache_threshold=1, - max_uva_threads=6144, + max_uva_threads=10240, ): # Multiprocessing requires two modifications to the datapipe: # @@ -215,15 +215,38 @@ class DataLoader(torch_data.DataLoader): gpu_graph_cache = construct_gpu_graph_cache( sampler, num_gpu_cached_edges, gpu_cache_threshold ) - datapipe_graph = replace_dp( - datapipe_graph, - sampler, - sampler.fetch_and_sample( - gpu_graph_cache, - get_host_to_device_uva_stream(), - 1, - ), - ) + if ( + sampler.sampler.__name__ == "sample_layer_neighbors" + or gpu_graph_cache is not None + ): + # This code path is not faster for sample_neighbors. + datapipe_graph = replace_dp( + datapipe_graph, + sampler, + sampler.fetch_and_sample( + gpu_graph_cache, + get_host_to_device_uva_stream(), + 1, + ), + ) + elif sampler.sampler.__name__ == "sample_neighbors": + # This code path is faster for sample_neighbors. + datapipe_graph = replace_dp( + datapipe_graph, + sampler, + sampler.datapipe.sample_per_layer( + sampler=sampler.sampler, + fanout=sampler.fanout, + replace=sampler.replace, + prob_name=sampler.prob_name, + returning_indices_is_optional=True, + ), + ) + else: + raise AssertionError( + "overlap_graph_fetch is supported only for " + "sample_neighbor and sample_layer_neighbor." + ) # (4) Cut datapipe at CopyTo and wrap with pinning and prefetching # before it. This enables enables non_blocking copies to the device. diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index aa2d504242..f311a28878 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -694,6 +694,7 @@ class FusedCSCSamplingGraph(SamplingGraph): fanouts: torch.Tensor, replace: bool = False, probs_name: Optional[str] = None, + returning_indices_is_optional: bool = False, ) -> SampledSubgraphImpl: """Sample neighboring edges of the given nodes and return the induced subgraph. @@ -733,6 +734,10 @@ class FusedCSCSamplingGraph(SamplingGraph): corresponding to each neighboring edge of a node. It must be a 1D floating-point or boolean tensor, with the number of elements equalling the total number of edges. + returning_indices_is_optional: bool + Boolean indicating whether it is okay for the call to this function + to leave the indices tensor uninitialized. In this case, it is the + user's responsibility to gather it using the edge ids. Returns ------- @@ -776,6 +781,7 @@ class FusedCSCSamplingGraph(SamplingGraph): fanouts, replace=replace, probs_or_mask=probs_or_mask, + returning_indices_is_optional=returning_indices_is_optional, ) return self._convert_to_sampled_subgraph( C_sampled_subgraph, seed_offsets @@ -827,6 +833,7 @@ class FusedCSCSamplingGraph(SamplingGraph): fanouts: torch.Tensor, replace: bool = False, probs_or_mask: Optional[torch.Tensor] = None, + returning_indices_is_optional: bool = False, ) -> torch.ScriptObject: """Sample neighboring edges of the given nodes and return the induced subgraph. @@ -865,6 +872,10 @@ class FusedCSCSamplingGraph(SamplingGraph): corresponding to each neighboring edge of a node. It must be a 1D floating-point or boolean tensor, with the number of elements equalling the total number of edges. + returning_indices_is_optional: bool + Boolean indicating whether it is okay for the call to this function + to leave the indices tensor uninitialized. In this case, it is the + user's responsibility to gather it using the edge ids. Returns ------- @@ -879,6 +890,7 @@ class FusedCSCSamplingGraph(SamplingGraph): fanouts.tolist(), replace, False, # is_labor + returning_indices_is_optional, probs_or_mask, None, # random_seed, labor parameter 0, # seed2_contribution, labor_parameter @@ -890,6 +902,7 @@ class FusedCSCSamplingGraph(SamplingGraph): fanouts: torch.Tensor, replace: bool = False, probs_name: Optional[str] = None, + returning_indices_is_optional: bool = False, random_seed: torch.Tensor = None, seed2_contribution: float = 0.0, ) -> SampledSubgraphImpl: @@ -933,6 +946,10 @@ class FusedCSCSamplingGraph(SamplingGraph): corresponding to each neighboring edge of a node. It must be a 1D floating-point or boolean tensor, with the number of elements equalling the total number of edges. + returning_indices_is_optional: bool + Boolean indicating whether it is okay for the call to this function + to leave the indices tensor uninitialized. In this case, it is the + user's responsibility to gather it using the edge ids. random_seed: torch.Tensor, optional An int64 tensor with one or two elements. @@ -1012,6 +1029,7 @@ class FusedCSCSamplingGraph(SamplingGraph): fanouts.tolist(), replace, True, # is_labor + returning_indices_is_optional, probs_or_mask, random_seed, seed2_contribution, diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py index e95f432473..8c997eb684 100644 --- a/python/dgl/graphbolt/impl/neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/neighbor_sampler.py @@ -6,7 +6,12 @@ import torch from torch.utils.data import functional_datapipe from torch.utils.data.datapipes.iter import Mapper -from ..base import ORIGINAL_EDGE_ID +from ..base import ( + etype_str_to_tuple, + get_host_to_device_uva_stream, + index_select, + ORIGINAL_EDGE_ID, +) from ..internal import compact_csc_format, unique_and_compact_csc_formats from ..minibatch_transformer import MiniBatchTransformer @@ -138,6 +143,8 @@ class FetchInsubgraphData(Mapper): delattr(minibatch, "_seeds") delattr(minibatch, "_seed_offsets") + seeds.record_stream(torch.cuda.current_stream()) + def record_stream(tensor): if stream is not None and tensor.is_cuda: tensor.record_stream(stream) @@ -251,12 +258,40 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer): class SamplePerLayer(MiniBatchTransformer): """Sample neighbor edges from a graph for a single layer.""" - def __init__(self, datapipe, sampler, fanout, replace, prob_name): - super().__init__(datapipe, self._sample_per_layer) + def __init__( + self, + datapipe, + sampler, + fanout, + replace, + prob_name, + returning_indices_is_optional=False, + ): + graph = sampler.__self__ + if returning_indices_is_optional and graph.indices.is_pinned(): + datapipe = datapipe.transform(self._sample_per_layer) + datapipe = ( + datapipe.transform(partial(self._fetch_indices, graph.indices)) + .buffer() + .wait() + ) + if graph.type_per_edge is not None: + # Hetero case. + datapipe = datapipe.transform( + partial( + self._subtract_hetero_indices_offset, + graph._node_type_offset_list, + graph.node_type_to_id, + ) + ) + super().__init__(datapipe) + else: + super().__init__(datapipe, self._sample_per_layer) self.sampler = sampler self.fanout = fanout self.replace = replace self.prob_name = prob_name + self.returning_indices_is_optional = returning_indices_is_optional def _sample_per_layer(self, minibatch): kwargs = { @@ -269,11 +304,61 @@ class SamplePerLayer(MiniBatchTransformer): self.fanout, self.replace, self.prob_name, + self.returning_indices_is_optional, **kwargs, ) minibatch.sampled_subgraphs.insert(0, subgraph) return minibatch + @staticmethod + def _fetch_indices(indices, minibatch): + stream = torch.cuda.current_stream() + host_to_device_stream = get_host_to_device_uva_stream() + host_to_device_stream.wait_stream(stream) + + def record_stream(tensor): + tensor.record_stream(stream) + return tensor + + with torch.cuda.stream(host_to_device_stream): + minibatch._indices_needs_offset_subtraction = False + subgraph = minibatch.sampled_subgraphs[0] + if isinstance(subgraph.sampled_csc, dict): + for etype, pair in subgraph.sampled_csc.items(): + if pair.indices is None: + edge_ids = subgraph._sampled_edge_ids[etype] + edge_ids.record_stream(torch.cuda.current_stream()) + pair.indices = record_stream( + index_select(indices, edge_ids) + ) + minibatch._indices_needs_offset_subtraction = True + elif subgraph.sampled_csc.indices is None: + subgraph._sampled_edge_ids.record_stream( + torch.cuda.current_stream() + ) + subgraph.sampled_csc.indices = record_stream( + index_select(indices, subgraph._sampled_edge_ids) + ) + minibatch._indices_needs_offset_subtraction = True + subgraph._sampled_edge_ids = None + minibatch.wait = torch.cuda.current_stream().record_event().wait + + return minibatch + + @staticmethod + def _subtract_hetero_indices_offset( + node_type_offset, node_type_to_id, minibatch + ): + if minibatch._indices_needs_offset_subtraction: + subgraph = minibatch.sampled_subgraphs[0] + for etype, pair in subgraph.sampled_csc.items(): + src_ntype = etype_str_to_tuple(etype)[0] + src_ntype_id = node_type_to_id[src_ntype] + pair.indices -= node_type_offset[src_ntype_id] + delattr(minibatch, "_indices_needs_offset_subtraction") + + return minibatch + @functional_datapipe("compact_per_layer") class CompactPerLayer(MiniBatchTransformer): diff --git a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py index 72097c69cb..6c5818bbf6 100644 --- a/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py +++ b/tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py @@ -44,8 +44,9 @@ def get_hetero_graph(): @pytest.mark.parametrize("prob_name", [None, "weight", "mask"]) @pytest.mark.parametrize("sorted", [False, True]) @pytest.mark.parametrize("num_cached_edges", [0, 10]) +@pytest.mark.parametrize("is_pinned", [False, True]) def test_NeighborSampler_GraphFetch( - hetero, prob_name, sorted, num_cached_edges + hetero, prob_name, sorted, num_cached_edges, is_pinned ): if sorted: items = torch.arange(3) @@ -53,7 +54,8 @@ def test_NeighborSampler_GraphFetch( items = torch.tensor([2, 0, 1]) names = "seeds" itemset = gb.ItemSet(items, names=names) - graph = get_hetero_graph().to(F.ctx()) + graph = get_hetero_graph() + graph = graph.pin_memory_() if is_pinned else graph.to(F.ctx()) if hetero: itemset = gb.HeteroItemSet({"n3": itemset}) else: @@ -65,7 +67,7 @@ def test_NeighborSampler_GraphFetch( partial(gb.NeighborSampler._prepare, graph.node_type_to_id) ) sample_per_layer = gb.SamplePerLayer( - datapipe, graph.sample_neighbors, fanout, False, prob_name + datapipe, graph.sample_neighbors, fanout, False, prob_name, False ) compact_per_layer = sample_per_layer.compact_per_layer(True) gb.seed(123) @@ -92,6 +94,21 @@ def test_NeighborSampler_GraphFetch( for a, b in zip(expected_results, new_results): assert repr(a) == repr(b) + def remove_input_nodes(minibatch): + minibatch.input_nodes = None + return minibatch + + datapipe = item_sampler.sample_neighbor( + graph, [fanout], False, prob_name=prob_name + ) + datapipe = datapipe.transform(remove_input_nodes) + dataloader = gb.DataLoader(datapipe, overlap_graph_fetch=True) + gb.seed(123) + new_results = list(dataloader) + assert len(expected_results) == len(new_results) + for a, b in zip(expected_results, new_results): + assert repr(a) == repr(b) + @pytest.mark.parametrize("layer_dependency", [False, True]) @pytest.mark.parametrize("overlap_graph_fetch", [False, True]) diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py index 79791771ed..ba22bfdda2 100644 --- a/tests/python/pytorch/graphbolt/test_dataloader.py +++ b/tests/python/pytorch/graphbolt/test_dataloader.py @@ -77,9 +77,8 @@ def test_gpu_sampling_DataLoader( B = 4 num_layers = 2 itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seeds") - graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True).to( - F.ctx() - ) + graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True) + graph = graph.pin_memory_() if overlap_graph_fetch else graph.to(F.ctx()) features = {} keys = [ ("node", None, "a"),