[GraphBolt][CUDA] Eliminate synchronization for overlap_graph_fetch. (#7709)

This commit is contained in:
Muhammed Fatih BALIN
2024-08-16 17:42:14 -04:00
committed by GitHub
parent 396f5f1c5d
commit 25210816ff
5 changed files with 91 additions and 79 deletions

View File

@@ -207,5 +207,19 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
return std::make_tuple(output_indptr, results);
}
/**
 * @brief Asynchronous counterpart of IndexSelectCSCBatched.
 *
 * Wraps a call to IndexSelectCSCBatched in a task handed to async() and
 * returns the resulting future instead of the computed tuple. All arguments
 * are forwarded to IndexSelectCSCBatched unchanged.
 *
 * @param indptr Forwarded as the first argument of IndexSelectCSCBatched.
 * @param indices_list Forwarded as the list of tensors to be sliced.
 * @param nodes Forwarded as the nodes argument; it also decides the flag
 * passed to async() via utils::is_on_gpu(nodes).
 * @param with_edge_ids Forwarded unchanged.
 * @param output_size Forwarded unchanged.
 *
 * @return A future wrapping the tuple returned by IndexSelectCSCBatched.
 */
c10::intrusive_ptr<
    Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>
IndexSelectCSCBatchedAsync(
    torch::Tensor indptr, std::vector<torch::Tensor> indices_list,
    torch::Tensor nodes, bool with_edge_ids,
    torch::optional<int64_t> output_size) {
  // Second argument of async(): whether `nodes` lives on the GPU.
  const bool nodes_on_gpu = utils::is_on_gpu(nodes);
  // Every argument is captured by value, so the closure does not refer to
  // this function's parameters.
  return async(
      [indptr, indices_list, nodes, with_edge_ids, output_size] {
        return IndexSelectCSCBatched(
            indptr, indices_list, nodes, with_edge_ids, output_size);
      },
      nodes_on_gpu);
}
} // namespace ops
} // namespace graphbolt

View File

@@ -92,6 +92,13 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
torch::Tensor nodes, bool with_edge_ids,
torch::optional<int64_t> output_size);
/**
 * @brief Asynchronous variant of IndexSelectCSCBatched.
 *
 * Takes the same parameter list as IndexSelectCSCBatched, but returns a
 * Future wrapping the (tensor, vector of tensors) tuple rather than the tuple
 * itself.
 *
 * @param indptr See IndexSelectCSCBatched.
 * @param indices_list See IndexSelectCSCBatched.
 * @param nodes See IndexSelectCSCBatched.
 * @param with_edge_ids See IndexSelectCSCBatched.
 * @param output_size See IndexSelectCSCBatched.
 *
 * @return Future holding the same tuple type IndexSelectCSCBatched returns.
 */
c10::intrusive_ptr<
Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>
IndexSelectCSCBatchedAsync(
torch::Tensor indptr, std::vector<torch::Tensor> indices_list,
torch::Tensor nodes, bool with_edge_ids,
torch::optional<int64_t> output_size);
} // namespace ops
} // namespace graphbolt

View File

@@ -184,6 +184,7 @@ TORCH_LIBRARY(graphbolt, m) {
m.def("scatter_async", &ops::ScatterAsync);
m.def("index_select_csc", &ops::IndexSelectCSC);
m.def("index_select_csc_batched", &ops::IndexSelectCSCBatched);
m.def("index_select_csc_batched_async", &ops::IndexSelectCSCBatchedAsync);
m.def("ondisk_npy_array", &storage::OnDiskNpyArray::Create);
m.def("detect_io_uring", &io_uring::IsAvailable);
m.def("set_num_io_uring_threads", &io_uring::SetNumThreads);

View File

@@ -25,7 +25,6 @@ __all__ = [
"LayerNeighborSampler",
"SamplePerLayer",
"FetchInsubgraphData",
"ConcatHeteroSeeds",
"CombineCachedAndFetchedInSubgraph",
]
@@ -105,29 +104,6 @@ class CombineCachedAndFetchedInSubgraph(Mapper):
return minibatch
@functional_datapipe("concat_hetero_seeds")
class ConcatHeteroSeeds(Mapper):
    """Datapipe that stores a homogeneous view of the seeds on each minibatch.

    Registered under the functional name ``concat_hetero_seeds``. Each
    minibatch passing through gets ``_seeds`` and ``_seed_offsets`` set and is
    then passed on unchanged otherwise.
    """

    def __init__(self, datapipe, graph):
        super().__init__(datapipe, self._concat)
        # The graph supplies `_convert_to_homogeneous_nodes` for dict seeds.
        self.graph = graph

    def _concat(self, minibatch):
        """Populate ``minibatch._seeds`` and ``minibatch._seed_offsets``.

        When ``minibatch._seed_nodes`` is a dict, it is converted through
        ``self.graph._convert_to_homogeneous_nodes``, which yields the seeds
        and their offsets. Otherwise the seeds are used as they are and the
        offsets are ``None``.
        """
        seed_nodes = minibatch._seed_nodes
        offsets = None
        if isinstance(seed_nodes, dict):
            seed_nodes, offsets = self.graph._convert_to_homogeneous_nodes(
                seed_nodes
            )
        minibatch._seeds = seed_nodes
        minibatch._seed_offsets = offsets
        return minibatch
@functional_datapipe("fetch_insubgraph_data")
class FetchInsubgraphData(MiniBatchTransformer):
"""Fetches the insubgraph and wraps it in a FusedCSCSamplingGraph object. If
@@ -142,20 +118,46 @@ class FetchInsubgraphData(MiniBatchTransformer):
graph,
prob_name,
):
datapipe = datapipe.concat_hetero_seeds(graph)
datapipe = datapipe.transform(self._concat_hetero_seeds)
if graph._gpu_graph_cache is not None:
datapipe = datapipe.fetch_cached_insubgraph_data(
graph._gpu_graph_cache
)
datapipe = datapipe.transform(self._fetch_per_layer)
datapipe = datapipe.buffer().wait()
datapipe = datapipe.transform(self._fetch_per_layer_stage_1)
datapipe = datapipe.buffer()
datapipe = datapipe.transform(self._fetch_per_layer_stage_2)
if graph._gpu_graph_cache is not None:
datapipe = datapipe.combine_cached_and_fetched_insubgraph(prob_name)
super().__init__(datapipe)
self.graph = graph
self.prob_name = prob_name
def _fetch_per_layer(self, minibatch):
def _concat_hetero_seeds(self, minibatch):
"""Concatenates the seeds into a single tensor in the hetero case."""
seeds = minibatch._seed_nodes
if isinstance(seeds, dict):
(
seeds,
seed_offsets,
) = self.graph._convert_to_homogeneous_nodes(seeds)
else:
seed_offsets = None
minibatch._seeds = seeds
minibatch._seed_offsets = seed_offsets
return minibatch
def _fetch_per_layer_stage_1(self, minibatch):
minibatch._async_handle_fetch = self._fetch_per_layer_async(minibatch)
next(minibatch._async_handle_fetch)
return minibatch
def _fetch_per_layer_stage_2(self, minibatch):
minibatch = next(minibatch._async_handle_fetch)
delattr(minibatch, "_async_handle_fetch")
return minibatch
def _fetch_per_layer_async(self, minibatch):
stream = torch.cuda.current_stream()
uva_stream = get_host_to_device_uva_stream()
uva_stream.wait_stream(stream)
@@ -167,11 +169,6 @@ class FetchInsubgraphData(MiniBatchTransformer):
seeds.record_stream(torch.cuda.current_stream())
def record_stream(tensor):
if tensor.is_cuda:
tensor.record_stream(stream)
return tensor
# Packs tensors for batch slicing.
tensors_to_be_sliced = [self.graph.indices]
@@ -190,51 +187,53 @@ class FetchInsubgraphData(MiniBatchTransformer):
has_probs_or_mask = True
# Slices the batched tensors.
(
indptr,
sliced_tensors,
) = torch.ops.graphbolt.index_select_csc_batched(
future = torch.ops.graphbolt.index_select_csc_batched_async(
self.graph.csc_indptr, tensors_to_be_sliced, seeds, True, None
)
for tensor in [indptr] + sliced_tensors:
record_stream(tensor)
# Unpacks the sliced tensors.
indices = sliced_tensors[0]
yield
# graphbolt::async has already recorded a CUDAEvent for us and
# called CUDAStreamWaitEvent for us on the current stream.
indptr, sliced_tensors = future.wait()
for tensor in [indptr] + sliced_tensors:
tensor.record_stream(stream)
# Unpacks the sliced tensors.
indices = sliced_tensors[0]
sliced_tensors = sliced_tensors[1:]
type_per_edge = None
if has_type_per_edge:
type_per_edge = sliced_tensors[0]
sliced_tensors = sliced_tensors[1:]
type_per_edge = None
if has_type_per_edge:
type_per_edge = sliced_tensors[0]
sliced_tensors = sliced_tensors[1:]
probs_or_mask = None
if has_probs_or_mask:
probs_or_mask = sliced_tensors[0]
sliced_tensors = sliced_tensors[1:]
edge_ids = sliced_tensors[0]
probs_or_mask = None
if has_probs_or_mask:
probs_or_mask = sliced_tensors[0]
sliced_tensors = sliced_tensors[1:]
assert len(sliced_tensors) == 0
subgraph = fused_csc_sampling_graph(
indptr,
indices,
node_type_offset=self.graph.node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=self.graph.node_type_to_id,
edge_type_to_id=self.graph.edge_type_to_id,
)
if self.prob_name is not None and probs_or_mask is not None:
subgraph.add_edge_attribute(self.prob_name, probs_or_mask)
subgraph.add_edge_attribute(ORIGINAL_EDGE_ID, edge_ids)
edge_ids = sliced_tensors[0]
sliced_tensors = sliced_tensors[1:]
assert len(sliced_tensors) == 0
subgraph._indptr_node_type_offset_list = seed_offsets
minibatch._sliced_sampling_graph = subgraph
subgraph = fused_csc_sampling_graph(
indptr,
indices,
node_type_offset=self.graph.node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=self.graph.node_type_to_id,
edge_type_to_id=self.graph.edge_type_to_id,
)
if self.prob_name is not None and probs_or_mask is not None:
subgraph.add_edge_attribute(self.prob_name, probs_or_mask)
subgraph.add_edge_attribute(ORIGINAL_EDGE_ID, edge_ids)
minibatch.wait = torch.cuda.current_stream().record_event().wait
subgraph._indptr_node_type_offset_list = seed_offsets
minibatch._sliced_sampling_graph = subgraph
return minibatch
yield minibatch
@functional_datapipe("sample_per_layer")

View File

@@ -132,23 +132,14 @@ def test_gpu_sampling_DataLoader(
dataloader, dataloader2 = dataloaders
bufferer_cnt = int(enable_feature_fetch and overlap_feature_fetch)
awaiter_cnt = 0
if overlap_graph_fetch:
bufferer_cnt += num_layers
awaiter_cnt += num_layers
if asynchronous:
bufferer_cnt += 2 * num_layers
if overlap_graph_fetch:
bufferer_cnt += 0 * num_layers
if num_gpu_cached_edges > 0:
bufferer_cnt += 2 * num_layers
if asynchronous:
bufferer_cnt += 2 * num_layers
datapipe = dataloader.dataset
datapipe_graph = traverse_dps(datapipe)
awaiters = find_dps(
datapipe_graph,
dgl.graphbolt.Waiter,
)
assert len(awaiters) == awaiter_cnt
bufferers = find_dps(
datapipe_graph,
dgl.graphbolt.Bufferer,