mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt][CUDA] Eliminate synchronization for overlap_graph_fetch. (#7709)
This commit is contained in:
committed by
GitHub
parent
396f5f1c5d
commit
25210816ff
@@ -207,5 +207,19 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
|
||||
return std::make_tuple(output_indptr, results);
|
||||
}
|
||||
|
||||
// Asynchronous counterpart of IndexSelectCSCBatched.
//
// Wraps a call to IndexSelectCSCBatched(indptr, indices_list, nodes,
// with_edge_ids, output_size) in a lambda and passes it to async(), returning
// the Future that async() produces. The Future's value type is the same
// (output_indptr, results) tuple type that IndexSelectCSCBatched returns.
//
// All arguments are taken by value and the lambda captures them by copy
// ([=]), so the closure does not refer to this function's parameters after
// the function has returned.
//
// NOTE(review): async() and utils::is_on_gpu() are defined outside this view.
// The second argument presumably tells async() whether CUDA-aware handling is
// needed for the task - confirm against the definition of async().
c10::intrusive_ptr<
|
||||
Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>
|
||||
IndexSelectCSCBatchedAsync(
|
||||
torch::Tensor indptr, std::vector<torch::Tensor> indices_list,
|
||||
torch::Tensor nodes, bool with_edge_ids,
|
||||
torch::optional<int64_t> output_size) {
|
||||
return async(
|
||||
// Capture by copy: the task may run after this function has returned.
[=] {
|
||||
return IndexSelectCSCBatched(
|
||||
indptr, indices_list, nodes, with_edge_ids, output_size);
|
||||
},
|
||||
utils::is_on_gpu(nodes));
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace graphbolt
|
||||
|
||||
@@ -92,6 +92,13 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
|
||||
torch::Tensor nodes, bool with_edge_ids,
|
||||
torch::optional<int64_t> output_size);
|
||||
|
||||
// Declaration of the asynchronous variant of IndexSelectCSCBatched.
//
// Returns an intrusive_ptr to a Future whose value type is
// std::tuple<torch::Tensor, std::vector<torch::Tensor>>, i.e. the same tuple
// type that the IndexSelectCSCBatched declaration directly above returns.
//
// NOTE(review): the meaning of each parameter is presumably identical to the
// synchronous IndexSelectCSCBatched - confirm against its documentation,
// which is outside this view.
c10::intrusive_ptr<
|
||||
Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>
|
||||
IndexSelectCSCBatchedAsync(
|
||||
torch::Tensor indptr, std::vector<torch::Tensor> indices_list,
|
||||
torch::Tensor nodes, bool with_edge_ids,
|
||||
torch::optional<int64_t> output_size);
|
||||
|
||||
} // namespace ops
|
||||
} // namespace graphbolt
|
||||
|
||||
|
||||
@@ -184,6 +184,7 @@ TORCH_LIBRARY(graphbolt, m) {
|
||||
m.def("scatter_async", &ops::ScatterAsync);
|
||||
m.def("index_select_csc", &ops::IndexSelectCSC);
|
||||
m.def("index_select_csc_batched", &ops::IndexSelectCSCBatched);
|
||||
m.def("index_select_csc_batched_async", &ops::IndexSelectCSCBatchedAsync);
|
||||
m.def("ondisk_npy_array", &storage::OnDiskNpyArray::Create);
|
||||
m.def("detect_io_uring", &io_uring::IsAvailable);
|
||||
m.def("set_num_io_uring_threads", &io_uring::SetNumThreads);
|
||||
|
||||
@@ -25,7 +25,6 @@ __all__ = [
|
||||
"LayerNeighborSampler",
|
||||
"SamplePerLayer",
|
||||
"FetchInsubgraphData",
|
||||
"ConcatHeteroSeeds",
|
||||
"CombineCachedAndFetchedInSubgraph",
|
||||
]
|
||||
|
||||
@@ -105,29 +104,6 @@ class CombineCachedAndFetchedInSubgraph(Mapper):
|
||||
return minibatch
|
||||
|
||||
|
||||
@functional_datapipe("concat_hetero_seeds")
|
||||
class ConcatHeteroSeeds(Mapper):
|
||||
"""Concatenates the seeds into a single tensor in the hetero case."""
|
||||
|
||||
def __init__(self, datapipe, graph):
|
||||
super().__init__(datapipe, self._concat)
|
||||
self.graph = graph
|
||||
|
||||
def _concat(self, minibatch):
|
||||
seeds = minibatch._seed_nodes
|
||||
if isinstance(seeds, dict):
|
||||
(
|
||||
seeds,
|
||||
seed_offsets,
|
||||
) = self.graph._convert_to_homogeneous_nodes(seeds)
|
||||
else:
|
||||
seed_offsets = None
|
||||
minibatch._seeds = seeds
|
||||
minibatch._seed_offsets = seed_offsets
|
||||
|
||||
return minibatch
|
||||
|
||||
|
||||
@functional_datapipe("fetch_insubgraph_data")
|
||||
class FetchInsubgraphData(MiniBatchTransformer):
|
||||
"""Fetches the insubgraph and wraps it in a FusedCSCSamplingGraph object. If
|
||||
@@ -142,20 +118,46 @@ class FetchInsubgraphData(MiniBatchTransformer):
|
||||
graph,
|
||||
prob_name,
|
||||
):
|
||||
datapipe = datapipe.concat_hetero_seeds(graph)
|
||||
datapipe = datapipe.transform(self._concat_hetero_seeds)
|
||||
if graph._gpu_graph_cache is not None:
|
||||
datapipe = datapipe.fetch_cached_insubgraph_data(
|
||||
graph._gpu_graph_cache
|
||||
)
|
||||
datapipe = datapipe.transform(self._fetch_per_layer)
|
||||
datapipe = datapipe.buffer().wait()
|
||||
datapipe = datapipe.transform(self._fetch_per_layer_stage_1)
|
||||
datapipe = datapipe.buffer()
|
||||
datapipe = datapipe.transform(self._fetch_per_layer_stage_2)
|
||||
if graph._gpu_graph_cache is not None:
|
||||
datapipe = datapipe.combine_cached_and_fetched_insubgraph(prob_name)
|
||||
super().__init__(datapipe)
|
||||
self.graph = graph
|
||||
self.prob_name = prob_name
|
||||
|
||||
def _fetch_per_layer(self, minibatch):
|
||||
def _concat_hetero_seeds(self, minibatch):
|
||||
"""Concatenates the seeds into a single tensor in the hetero case."""
|
||||
# Input: the seeds stored on the minibatch under `_seed_nodes`.
seeds = minibatch._seed_nodes
|
||||
# Dict seeds are handed to the graph's `_convert_to_homogeneous_nodes`, which
# returns the converted seeds together with their offsets.
if isinstance(seeds, dict):
|
||||
(
|
||||
seeds,
|
||||
seed_offsets,
|
||||
) = self.graph._convert_to_homogeneous_nodes(seeds)
|
||||
# Non-dict seeds are used as they are and have no offsets.
else:
|
||||
seed_offsets = None
|
||||
# Output: the results are attached to the minibatch as `_seeds` and
# `_seed_offsets`; the same minibatch object is returned.
minibatch._seeds = seeds
|
||||
minibatch._seed_offsets = seed_offsets
|
||||
|
||||
return minibatch
|
||||
|
||||
def _fetch_per_layer_stage_1(self, minibatch):
# First half of the two-stage fetch. Creates the `_fetch_per_layer_async`
# generator for this minibatch, stores it on the minibatch as
# `_async_handle_fetch`, and advances it once with `next()` so that it runs
# up to its first `yield`. Returns the same minibatch.
# NOTE(review): the first `yield` appears to come right after the
# `index_select_csc_batched_async` call, i.e. this stage only launches the
# fetch - confirm against the full body of `_fetch_per_layer_async`.
|
||||
minibatch._async_handle_fetch = self._fetch_per_layer_async(minibatch)
|
||||
next(minibatch._async_handle_fetch)
|
||||
return minibatch
|
||||
|
||||
def _fetch_per_layer_stage_2(self, minibatch):
# Second half of the two-stage fetch. Resumes the generator stored on the
# minibatch as `_async_handle_fetch`; the value it yields replaces the local
# `minibatch`. The handle attribute is then deleted from that minibatch
# before it is returned.
# NOTE(review): expects `_fetch_per_layer_stage_1` to have run on this
# minibatch first, since that is where `_async_handle_fetch` is set.
|
||||
minibatch = next(minibatch._async_handle_fetch)
|
||||
delattr(minibatch, "_async_handle_fetch")
|
||||
return minibatch
|
||||
|
||||
def _fetch_per_layer_async(self, minibatch):
|
||||
stream = torch.cuda.current_stream()
|
||||
uva_stream = get_host_to_device_uva_stream()
|
||||
uva_stream.wait_stream(stream)
|
||||
@@ -167,11 +169,6 @@ class FetchInsubgraphData(MiniBatchTransformer):
|
||||
|
||||
seeds.record_stream(torch.cuda.current_stream())
|
||||
|
||||
def record_stream(tensor):
|
||||
if tensor.is_cuda:
|
||||
tensor.record_stream(stream)
|
||||
return tensor
|
||||
|
||||
# Packs tensors for batch slicing.
|
||||
tensors_to_be_sliced = [self.graph.indices]
|
||||
|
||||
@@ -190,51 +187,53 @@ class FetchInsubgraphData(MiniBatchTransformer):
|
||||
has_probs_or_mask = True
|
||||
|
||||
# Slices the batched tensors.
|
||||
(
|
||||
indptr,
|
||||
sliced_tensors,
|
||||
) = torch.ops.graphbolt.index_select_csc_batched(
|
||||
future = torch.ops.graphbolt.index_select_csc_batched_async(
|
||||
self.graph.csc_indptr, tensors_to_be_sliced, seeds, True, None
|
||||
)
|
||||
for tensor in [indptr] + sliced_tensors:
|
||||
record_stream(tensor)
|
||||
|
||||
# Unpacks the sliced tensors.
|
||||
indices = sliced_tensors[0]
|
||||
yield
|
||||
|
||||
# graphbolt::async has already recorded a CUDAEvent for us and
|
||||
# called CUDAStreamWaitEvent for us on the current stream.
|
||||
indptr, sliced_tensors = future.wait()
|
||||
|
||||
for tensor in [indptr] + sliced_tensors:
|
||||
tensor.record_stream(stream)
|
||||
|
||||
# Unpacks the sliced tensors.
|
||||
indices = sliced_tensors[0]
|
||||
sliced_tensors = sliced_tensors[1:]
|
||||
|
||||
type_per_edge = None
|
||||
if has_type_per_edge:
|
||||
type_per_edge = sliced_tensors[0]
|
||||
sliced_tensors = sliced_tensors[1:]
|
||||
|
||||
type_per_edge = None
|
||||
if has_type_per_edge:
|
||||
type_per_edge = sliced_tensors[0]
|
||||
sliced_tensors = sliced_tensors[1:]
|
||||
|
||||
probs_or_mask = None
|
||||
if has_probs_or_mask:
|
||||
probs_or_mask = sliced_tensors[0]
|
||||
sliced_tensors = sliced_tensors[1:]
|
||||
|
||||
edge_ids = sliced_tensors[0]
|
||||
probs_or_mask = None
|
||||
if has_probs_or_mask:
|
||||
probs_or_mask = sliced_tensors[0]
|
||||
sliced_tensors = sliced_tensors[1:]
|
||||
assert len(sliced_tensors) == 0
|
||||
|
||||
subgraph = fused_csc_sampling_graph(
|
||||
indptr,
|
||||
indices,
|
||||
node_type_offset=self.graph.node_type_offset,
|
||||
type_per_edge=type_per_edge,
|
||||
node_type_to_id=self.graph.node_type_to_id,
|
||||
edge_type_to_id=self.graph.edge_type_to_id,
|
||||
)
|
||||
if self.prob_name is not None and probs_or_mask is not None:
|
||||
subgraph.add_edge_attribute(self.prob_name, probs_or_mask)
|
||||
subgraph.add_edge_attribute(ORIGINAL_EDGE_ID, edge_ids)
|
||||
edge_ids = sliced_tensors[0]
|
||||
sliced_tensors = sliced_tensors[1:]
|
||||
assert len(sliced_tensors) == 0
|
||||
|
||||
subgraph._indptr_node_type_offset_list = seed_offsets
|
||||
minibatch._sliced_sampling_graph = subgraph
|
||||
subgraph = fused_csc_sampling_graph(
|
||||
indptr,
|
||||
indices,
|
||||
node_type_offset=self.graph.node_type_offset,
|
||||
type_per_edge=type_per_edge,
|
||||
node_type_to_id=self.graph.node_type_to_id,
|
||||
edge_type_to_id=self.graph.edge_type_to_id,
|
||||
)
|
||||
if self.prob_name is not None and probs_or_mask is not None:
|
||||
subgraph.add_edge_attribute(self.prob_name, probs_or_mask)
|
||||
subgraph.add_edge_attribute(ORIGINAL_EDGE_ID, edge_ids)
|
||||
|
||||
minibatch.wait = torch.cuda.current_stream().record_event().wait
|
||||
subgraph._indptr_node_type_offset_list = seed_offsets
|
||||
minibatch._sliced_sampling_graph = subgraph
|
||||
|
||||
return minibatch
|
||||
yield minibatch
|
||||
|
||||
|
||||
@functional_datapipe("sample_per_layer")
|
||||
|
||||
@@ -132,23 +132,14 @@ def test_gpu_sampling_DataLoader(
|
||||
dataloader, dataloader2 = dataloaders
|
||||
|
||||
bufferer_cnt = int(enable_feature_fetch and overlap_feature_fetch)
|
||||
awaiter_cnt = 0
|
||||
if overlap_graph_fetch:
|
||||
bufferer_cnt += num_layers
|
||||
awaiter_cnt += num_layers
|
||||
if asynchronous:
|
||||
bufferer_cnt += 2 * num_layers
|
||||
if overlap_graph_fetch:
|
||||
bufferer_cnt += 0 * num_layers
|
||||
if num_gpu_cached_edges > 0:
|
||||
bufferer_cnt += 2 * num_layers
|
||||
if asynchronous:
|
||||
bufferer_cnt += 2 * num_layers
|
||||
datapipe = dataloader.dataset
|
||||
datapipe_graph = traverse_dps(datapipe)
|
||||
awaiters = find_dps(
|
||||
datapipe_graph,
|
||||
dgl.graphbolt.Waiter,
|
||||
)
|
||||
assert len(awaiters) == awaiter_cnt
|
||||
bufferers = find_dps(
|
||||
datapipe_graph,
|
||||
dgl.graphbolt.Bufferer,
|
||||
|
||||
Reference in New Issue
Block a user