From 25210816ffbf18788f06b8aaa02fc390c398097e Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Fri, 16 Aug 2024 17:42:14 -0400 Subject: [PATCH] [GraphBolt][CUDA] Eliminate synchronization for `overlap_graph_fetch`. (#7709) --- graphbolt/src/index_select.cc | 14 ++ graphbolt/src/index_select.h | 7 + graphbolt/src/python_binding.cc | 1 + python/dgl/graphbolt/impl/neighbor_sampler.py | 135 +++++++++--------- .../pytorch/graphbolt/test_dataloader.py | 13 +- 5 files changed, 91 insertions(+), 79 deletions(-) diff --git a/graphbolt/src/index_select.cc b/graphbolt/src/index_select.cc index 114fd60191..8fdc6a4987 100644 --- a/graphbolt/src/index_select.cc +++ b/graphbolt/src/index_select.cc @@ -207,5 +207,19 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched( return std::make_tuple(output_indptr, results); } +c10::intrusive_ptr< + Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>> +IndexSelectCSCBatchedAsync( + torch::Tensor indptr, std::vector<torch::Tensor> indices_list, + torch::Tensor nodes, bool with_edge_ids, + torch::optional<int64_t> output_size) { + return async( + [=] { + return IndexSelectCSCBatched( + indptr, indices_list, nodes, with_edge_ids, output_size); + }, + utils::is_on_gpu(nodes)); +} + } // namespace ops } // namespace graphbolt diff --git a/graphbolt/src/index_select.h b/graphbolt/src/index_select.h index f78ad98fe0..2522df6523 100644 --- a/graphbolt/src/index_select.h +++ b/graphbolt/src/index_select.h @@ -92,6 +92,13 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched( torch::Tensor nodes, bool with_edge_ids, torch::optional<int64_t> output_size); +c10::intrusive_ptr< + Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>> +IndexSelectCSCBatchedAsync( + torch::Tensor indptr, std::vector<torch::Tensor> indices_list, + torch::Tensor nodes, bool with_edge_ids, + torch::optional<int64_t> output_size); + } // namespace ops } // namespace graphbolt diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc index df295bd718..e8d54f9f9a 100644 --- a/graphbolt/src/python_binding.cc +++ b/graphbolt/src/python_binding.cc @@ -184,6 +184,7 @@ TORCH_LIBRARY(graphbolt, m) { m.def("scatter_async", 
&ops::ScatterAsync); m.def("index_select_csc", &ops::IndexSelectCSC); m.def("index_select_csc_batched", &ops::IndexSelectCSCBatched); + m.def("index_select_csc_batched_async", &ops::IndexSelectCSCBatchedAsync); m.def("ondisk_npy_array", &storage::OnDiskNpyArray::Create); m.def("detect_io_uring", &io_uring::IsAvailable); m.def("set_num_io_uring_threads", &io_uring::SetNumThreads); diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py index f5bea47bea..fc834718ef 100644 --- a/python/dgl/graphbolt/impl/neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/neighbor_sampler.py @@ -25,7 +25,6 @@ __all__ = [ "LayerNeighborSampler", "SamplePerLayer", "FetchInsubgraphData", - "ConcatHeteroSeeds", "CombineCachedAndFetchedInSubgraph", ] @@ -105,29 +104,6 @@ class CombineCachedAndFetchedInSubgraph(Mapper): return minibatch -@functional_datapipe("concat_hetero_seeds") -class ConcatHeteroSeeds(Mapper): - """Concatenates the seeds into a single tensor in the hetero case.""" - - def __init__(self, datapipe, graph): - super().__init__(datapipe, self._concat) - self.graph = graph - - def _concat(self, minibatch): - seeds = minibatch._seed_nodes - if isinstance(seeds, dict): - ( - seeds, - seed_offsets, - ) = self.graph._convert_to_homogeneous_nodes(seeds) - else: - seed_offsets = None - minibatch._seeds = seeds - minibatch._seed_offsets = seed_offsets - - return minibatch - - @functional_datapipe("fetch_insubgraph_data") class FetchInsubgraphData(MiniBatchTransformer): """Fetches the insubgraph and wraps it in a FusedCSCSamplingGraph object. 
If @@ -142,20 +118,46 @@ class FetchInsubgraphData(MiniBatchTransformer): graph, prob_name, ): - datapipe = datapipe.concat_hetero_seeds(graph) + datapipe = datapipe.transform(self._concat_hetero_seeds) if graph._gpu_graph_cache is not None: datapipe = datapipe.fetch_cached_insubgraph_data( graph._gpu_graph_cache ) - datapipe = datapipe.transform(self._fetch_per_layer) - datapipe = datapipe.buffer().wait() + datapipe = datapipe.transform(self._fetch_per_layer_stage_1) + datapipe = datapipe.buffer() + datapipe = datapipe.transform(self._fetch_per_layer_stage_2) if graph._gpu_graph_cache is not None: datapipe = datapipe.combine_cached_and_fetched_insubgraph(prob_name) super().__init__(datapipe) self.graph = graph self.prob_name = prob_name - def _fetch_per_layer(self, minibatch): + def _concat_hetero_seeds(self, minibatch): + """Concatenates the seeds into a single tensor in the hetero case.""" + seeds = minibatch._seed_nodes + if isinstance(seeds, dict): + ( + seeds, + seed_offsets, + ) = self.graph._convert_to_homogeneous_nodes(seeds) + else: + seed_offsets = None + minibatch._seeds = seeds + minibatch._seed_offsets = seed_offsets + + return minibatch + + def _fetch_per_layer_stage_1(self, minibatch): + minibatch._async_handle_fetch = self._fetch_per_layer_async(minibatch) + next(minibatch._async_handle_fetch) + return minibatch + + def _fetch_per_layer_stage_2(self, minibatch): + minibatch = next(minibatch._async_handle_fetch) + delattr(minibatch, "_async_handle_fetch") + return minibatch + + def _fetch_per_layer_async(self, minibatch): stream = torch.cuda.current_stream() uva_stream = get_host_to_device_uva_stream() uva_stream.wait_stream(stream) @@ -167,11 +169,6 @@ class FetchInsubgraphData(MiniBatchTransformer): seeds.record_stream(torch.cuda.current_stream()) - def record_stream(tensor): - if tensor.is_cuda: - tensor.record_stream(stream) - return tensor - # Packs tensors for batch slicing. 
tensors_to_be_sliced = [self.graph.indices] @@ -190,51 +187,53 @@ class FetchInsubgraphData(MiniBatchTransformer): has_probs_or_mask = True # Slices the batched tensors. - ( - indptr, - sliced_tensors, - ) = torch.ops.graphbolt.index_select_csc_batched( + future = torch.ops.graphbolt.index_select_csc_batched_async( self.graph.csc_indptr, tensors_to_be_sliced, seeds, True, None ) - for tensor in [indptr] + sliced_tensors: - record_stream(tensor) - # Unpacks the sliced tensors. - indices = sliced_tensors[0] + yield + + # graphbolt::async has already recorded a CUDAEvent for us and + # called CUDAStreamWaitEvent for us on the current stream. + indptr, sliced_tensors = future.wait() + + for tensor in [indptr] + sliced_tensors: + tensor.record_stream(stream) + + # Unpacks the sliced tensors. + indices = sliced_tensors[0] + sliced_tensors = sliced_tensors[1:] + + type_per_edge = None + if has_type_per_edge: + type_per_edge = sliced_tensors[0] sliced_tensors = sliced_tensors[1:] - type_per_edge = None - if has_type_per_edge: - type_per_edge = sliced_tensors[0] - sliced_tensors = sliced_tensors[1:] - - probs_or_mask = None - if has_probs_or_mask: - probs_or_mask = sliced_tensors[0] - sliced_tensors = sliced_tensors[1:] - - edge_ids = sliced_tensors[0] + probs_or_mask = None + if has_probs_or_mask: + probs_or_mask = sliced_tensors[0] sliced_tensors = sliced_tensors[1:] - assert len(sliced_tensors) == 0 - subgraph = fused_csc_sampling_graph( - indptr, - indices, - node_type_offset=self.graph.node_type_offset, - type_per_edge=type_per_edge, - node_type_to_id=self.graph.node_type_to_id, - edge_type_to_id=self.graph.edge_type_to_id, - ) - if self.prob_name is not None and probs_or_mask is not None: - subgraph.add_edge_attribute(self.prob_name, probs_or_mask) - subgraph.add_edge_attribute(ORIGINAL_EDGE_ID, edge_ids) + edge_ids = sliced_tensors[0] + sliced_tensors = sliced_tensors[1:] + assert len(sliced_tensors) == 0 - subgraph._indptr_node_type_offset_list = seed_offsets - 
minibatch._sliced_sampling_graph = subgraph + subgraph = fused_csc_sampling_graph( + indptr, + indices, + node_type_offset=self.graph.node_type_offset, + type_per_edge=type_per_edge, + node_type_to_id=self.graph.node_type_to_id, + edge_type_to_id=self.graph.edge_type_to_id, + ) + if self.prob_name is not None and probs_or_mask is not None: + subgraph.add_edge_attribute(self.prob_name, probs_or_mask) + subgraph.add_edge_attribute(ORIGINAL_EDGE_ID, edge_ids) - minibatch.wait = torch.cuda.current_stream().record_event().wait + subgraph._indptr_node_type_offset_list = seed_offsets + minibatch._sliced_sampling_graph = subgraph - return minibatch + yield minibatch @functional_datapipe("sample_per_layer") diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py index 85e034b123..92c6700553 100644 --- a/tests/python/pytorch/graphbolt/test_dataloader.py +++ b/tests/python/pytorch/graphbolt/test_dataloader.py @@ -132,23 +132,14 @@ def test_gpu_sampling_DataLoader( dataloader, dataloader2 = dataloaders bufferer_cnt = int(enable_feature_fetch and overlap_feature_fetch) - awaiter_cnt = 0 if overlap_graph_fetch: bufferer_cnt += num_layers - awaiter_cnt += num_layers - if asynchronous: - bufferer_cnt += 2 * num_layers - if overlap_graph_fetch: - bufferer_cnt += 0 * num_layers if num_gpu_cached_edges > 0: bufferer_cnt += 2 * num_layers + if asynchronous: + bufferer_cnt += 2 * num_layers datapipe = dataloader.dataset datapipe_graph = traverse_dps(datapipe) - awaiters = find_dps( - datapipe_graph, - dgl.graphbolt.Waiter, - ) - assert len(awaiters) == awaiter_cnt bufferers = find_dps( datapipe_graph, dgl.graphbolt.Bufferer,