mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt][CUDA] Use async for GPUGraphCache. (#7707)
This commit is contained in:
committed by
GitHub
parent
db574f5b0b
commit
9c874d0219
@@ -25,6 +25,7 @@
|
||||
#include <cub/cub.cuh>
|
||||
#include <cuco/static_map.cuh>
|
||||
#include <cuda/std/atomic>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
|
||||
@@ -168,6 +169,7 @@ std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t> GpuGraphCache::Query(
|
||||
seeds.device().index() == device_id_,
|
||||
"Seeds should be on the correct CUDA device.");
|
||||
TORCH_CHECK(seeds.sizes().size() == 1, "Keys should be a 1D tensor.");
|
||||
std::lock_guard lock(mtx_);
|
||||
auto allocator = cuda::GetAllocator();
|
||||
auto index_dtype = cached_edge_tensors_.at(0).scalar_type();
|
||||
const dim3 block(kIntBlockSize);
|
||||
@@ -237,6 +239,12 @@ std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t> GpuGraphCache::Query(
|
||||
}));
|
||||
}
|
||||
|
||||
/**
 * @brief Asynchronous counterpart of Query().
 *
 * Wraps a call to Query(seeds) in a task, hands it to async() and returns the
 * future produced by async().
 *
 * @param seeds The tensor forwarded unchanged to Query().
 *
 * @return A future whose result is the tuple returned by Query(seeds).
 */
c10::intrusive_ptr<
    Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>>
GpuGraphCache::QueryAsync(torch::Tensor seeds) {
  // Capture explicitly instead of using [=]: implicitly capturing `this`
  // through [=] is deprecated since C++20, and the explicit list makes it
  // visible that the task holds a raw pointer to this cache.
  // NOTE(review): presumably the cache object has to stay alive until the
  // task has run -- confirm against how async() schedules its work.
  return async([this, seeds] { return Query(seeds); });
}
|
||||
|
||||
std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
|
||||
torch::Tensor seeds, torch::Tensor indices, torch::Tensor positions,
|
||||
int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,
|
||||
@@ -250,6 +258,7 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
|
||||
TORCH_CHECK(
|
||||
indptr.size(0) == num_nodes - num_hit + 1,
|
||||
"(indptr.size(0) == seeds.size(0) - num_hit + 1) failed.");
|
||||
std::lock_guard lock(mtx_);
|
||||
const int64_t num_buffers = num_nodes * num_tensors;
|
||||
auto allocator = cuda::GetAllocator();
|
||||
auto index_dtype = cached_edge_tensors_.at(0).scalar_type();
|
||||
@@ -490,5 +499,18 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
|
||||
}));
|
||||
}
|
||||
|
||||
/**
 * @brief Asynchronous counterpart of Replace().
 *
 * Wraps a call to Replace() with the given arguments in a task, hands it to
 * async() and returns the future produced by async(). All arguments are
 * forwarded to Replace() unchanged and in the same order.
 *
 * @param seeds Forwarded to Replace().
 * @param indices Forwarded to Replace().
 * @param positions Forwarded to Replace().
 * @param num_hit Forwarded to Replace().
 * @param num_threshold Forwarded to Replace().
 * @param indptr Forwarded to Replace().
 * @param edge_tensors Forwarded to Replace().
 *
 * @return A future whose result is the tuple returned by Replace().
 */
c10::intrusive_ptr<
    Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>
GpuGraphCache::ReplaceAsync(
    torch::Tensor seeds, torch::Tensor indices, torch::Tensor positions,
    int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,
    std::vector<torch::Tensor> edge_tensors) {
  // Capture explicitly instead of using [=]: implicitly capturing `this`
  // through [=] is deprecated since C++20, and the explicit list makes it
  // visible that the task holds a raw pointer to this cache.
  // NOTE(review): presumably the cache object has to stay alive until the
  // task has run -- confirm against how async() schedules its work.
  return async([this, seeds, indices, positions, num_hit, num_threshold,
                indptr, edge_tensors] {
    return Replace(
        seeds, indices, positions, num_hit, num_threshold, indptr,
        edge_tensors);
  });
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace graphbolt
|
||||
|
||||
@@ -21,11 +21,11 @@
|
||||
#ifndef GRAPHBOLT_GPU_GRAPH_CACHE_H_
|
||||
#define GRAPHBOLT_GPU_GRAPH_CACHE_H_
|
||||
|
||||
#include <graphbolt/async.h>
|
||||
#include <torch/custom_class.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
#include <mutex>
|
||||
|
||||
namespace graphbolt {
|
||||
namespace cuda {
|
||||
@@ -69,6 +69,10 @@ class GpuGraphCache : public torch::CustomClassHolder {
|
||||
std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t> Query(
|
||||
torch::Tensor seeds);
|
||||
|
||||
c10::intrusive_ptr<
|
||||
Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>>
|
||||
QueryAsync(torch::Tensor seeds);
|
||||
|
||||
/**
|
||||
* @brief After the graph structure for the missing node ids is fetched, it
|
||||
* inserts the node ids which pass the threshold and returns the final
|
||||
@@ -96,6 +100,13 @@ class GpuGraphCache : public torch::CustomClassHolder {
|
||||
int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,
|
||||
std::vector<torch::Tensor> edge_tensors);
|
||||
|
||||
c10::intrusive_ptr<
|
||||
Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>
|
||||
ReplaceAsync(
|
||||
torch::Tensor seeds, torch::Tensor indices, torch::Tensor positions,
|
||||
int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,
|
||||
std::vector<torch::Tensor> edge_tensors);
|
||||
|
||||
static c10::intrusive_ptr<GpuGraphCache> Create(
|
||||
const int64_t num_edges, const int64_t threshold,
|
||||
torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes);
|
||||
@@ -111,6 +122,7 @@ class GpuGraphCache : public torch::CustomClassHolder {
|
||||
torch::Tensor offset_; // The original graph's sliced_indptr tensor.
|
||||
std::vector<torch::Tensor> cached_edge_tensors_; // The cached graph
|
||||
// structure edge tensors.
|
||||
std::mutex mtx_; // Protects the data structure and makes it threadsafe.
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
|
||||
@@ -58,6 +58,17 @@ TORCH_LIBRARY(graphbolt, m) {
|
||||
"wait",
|
||||
&Future<std::vector<
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>::Wait);
|
||||
m.class_<Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>>(
|
||||
"GpuGraphCacheQueryFuture")
|
||||
.def(
|
||||
"wait",
|
||||
&Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>::
|
||||
Wait);
|
||||
m.class_<Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>(
|
||||
"GpuGraphCacheReplaceFuture")
|
||||
.def(
|
||||
"wait",
|
||||
&Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>::Wait);
|
||||
m.class_<storage::OnDiskNpyArray>("OnDiskNpyArray")
|
||||
.def("index_select", &storage::OnDiskNpyArray::IndexSelect);
|
||||
m.class_<FusedCSCSamplingGraph>("FusedCSCSamplingGraph")
|
||||
@@ -114,7 +125,9 @@ TORCH_LIBRARY(graphbolt, m) {
|
||||
m.def("gpu_cache", &cuda::GpuCache::Create);
|
||||
m.class_<cuda::GpuGraphCache>("GpuGraphCache")
|
||||
.def("query", &cuda::GpuGraphCache::Query)
|
||||
.def("replace", &cuda::GpuGraphCache::Replace);
|
||||
.def("query_async", &cuda::GpuGraphCache::QueryAsync)
|
||||
.def("replace", &cuda::GpuGraphCache::Replace)
|
||||
.def("replace_async", &cuda::GpuGraphCache::ReplaceAsync);
|
||||
m.def("gpu_graph_cache", &cuda::GpuGraphCache::Create);
|
||||
#endif
|
||||
m.def("fused_csc_sampling_graph", &FusedCSCSamplingGraph::Create);
|
||||
|
||||
@@ -68,6 +68,45 @@ class GPUGraphCache(object):
|
||||
|
||||
return keys[index[num_hit:]], replace_functional
|
||||
|
||||
def query_async(self, keys):
    """Queries the GPU cache asynchronously through a three-step generator.

    Parameters
    ----------
    keys : Tensor
        The keys to query the GPU graph cache with.

    Returns
    -------
    A generator object.
        The generator is driven in three steps:

        1. The first ``next()`` submits the asynchronous query and yields
           ``None``.
        2. The second ``next()`` waits for the query result, updates
           ``total_queries`` and ``total_miss``, and yields the missing
           keys.
        3. ``send((indptr, edge_tensors))`` with the fetched indptr and
           edge tensors yields a future object for the replace operation;
           its result can be accessed by calling ``.wait()`` on it. It is
           undefined behavior to call ``.wait()`` more than once.
    """
    # Step 1: submit the query, then hand control back to the caller.
    pending_query = self._cache.query_async(keys)
    yield

    # Step 2: collect the query result and update the hit/miss statistics.
    index, position, num_hit, num_threshold = pending_query.wait()
    self.total_queries += keys.shape[0]
    self.total_miss += keys.shape[0] - num_hit

    # Entries of `index` from `num_hit` onwards select the keys reported
    # back to the caller as missing.
    missing_keys = keys[index[num_hit:]]
    missing_indptr, missing_edge_tensors = yield missing_keys

    # Step 3: submit the replace operation and yield its future.
    replace_args = (
        keys,
        index,
        position,
        num_hit,
        num_threshold,
        missing_indptr,
        missing_edge_tensors,
    )
    yield self._cache.replace_async(*replace_args)
|
||||
|
||||
@property
|
||||
def miss_rate(self):
|
||||
"""Returns the cache miss rate since creation."""
|
||||
|
||||
@@ -32,18 +32,25 @@ __all__ = [
|
||||
|
||||
@functional_datapipe("fetch_cached_insubgraph_data")
|
||||
class FetchCachedInsubgraphData(Mapper):
|
||||
"""Queries the GPUGraphCache and returns the missing seeds and a lambda
|
||||
function that can be called with the fetched graph structure.
|
||||
"""Queries the GPUGraphCache and returns the missing seeds and a generator
|
||||
handle that can be called with the fetched graph structure.
|
||||
"""
|
||||
|
||||
def __init__(self, datapipe, gpu_graph_cache):
|
||||
super().__init__(datapipe, self._fetch_per_layer)
|
||||
datapipe = datapipe.transform(self._fetch_per_layer).buffer()
|
||||
super().__init__(datapipe, self._wait_query_future)
|
||||
self.cache = gpu_graph_cache
|
||||
|
||||
def _fetch_per_layer(self, minibatch):
|
||||
minibatch._seeds, minibatch._replace = self.cache.query(
|
||||
minibatch._seeds
|
||||
)
|
||||
minibatch._async_handle = self.cache.query_async(minibatch._seeds)
|
||||
# Start first stage
|
||||
next(minibatch._async_handle)
|
||||
|
||||
return minibatch
|
||||
|
||||
@staticmethod
def _wait_query_future(minibatch):
    """Advances the minibatch's async handle by one step and stores the
    value it yields in ``minibatch._seeds``.

    Returns the same minibatch object.
    """
    yielded_seeds = next(minibatch._async_handle)
    minibatch._seeds = yielded_seeds
    return minibatch
|
||||
|
||||
@@ -55,7 +62,8 @@ class CombineCachedAndFetchedInSubgraph(Mapper):
|
||||
"""
|
||||
|
||||
def __init__(self, datapipe, prob_name):
|
||||
super().__init__(datapipe, self._combine_per_layer)
|
||||
datapipe = datapipe.transform(self._combine_per_layer).buffer()
|
||||
super().__init__(datapipe, self._wait_replace_future)
|
||||
self.prob_name = prob_name
|
||||
|
||||
def _combine_per_layer(self, minibatch):
|
||||
@@ -69,16 +77,24 @@ class CombineCachedAndFetchedInSubgraph(Mapper):
|
||||
edge_tensors.append(probs_or_mask)
|
||||
edge_tensors.append(subgraph.edge_attribute(ORIGINAL_EDGE_ID))
|
||||
|
||||
subgraph.csc_indptr, edge_tensors = minibatch._replace(
|
||||
subgraph.csc_indptr, edge_tensors
|
||||
minibatch._future = minibatch._async_handle.send(
|
||||
(subgraph.csc_indptr, edge_tensors)
|
||||
)
|
||||
delattr(minibatch, "_replace")
|
||||
delattr(minibatch, "_async_handle")
|
||||
|
||||
return minibatch
|
||||
|
||||
def _wait_replace_future(self, minibatch):
|
||||
subgraph = minibatch._sliced_sampling_graph
|
||||
subgraph.csc_indptr, edge_tensors = minibatch._future.wait()
|
||||
delattr(minibatch, "_future")
|
||||
|
||||
subgraph.indices = edge_tensors[0]
|
||||
edge_tensors = edge_tensors[1:]
|
||||
if subgraph.type_per_edge is not None:
|
||||
subgraph.type_per_edge = edge_tensors[0]
|
||||
edge_tensors = edge_tensors[1:]
|
||||
probs_or_mask = subgraph.edge_attribute(self.prob_name)
|
||||
if probs_or_mask is not None:
|
||||
subgraph.add_edge_attribute(self.prob_name, edge_tensors[0])
|
||||
edge_tensors = edge_tensors[1:]
|
||||
@@ -113,7 +129,7 @@ class ConcatHeteroSeeds(Mapper):
|
||||
|
||||
|
||||
@functional_datapipe("fetch_insubgraph_data")
|
||||
class FetchInsubgraphData(Mapper):
|
||||
class FetchInsubgraphData(MiniBatchTransformer):
|
||||
"""Fetches the insubgraph and wraps it in a FusedCSCSamplingGraph object. If
|
||||
the provided sample_per_layer_obj has a valid prob_name, then it reads the
|
||||
probabilities of all the fetched edges. Furthermore, if type_per_array tensor
|
||||
@@ -131,9 +147,13 @@ class FetchInsubgraphData(Mapper):
|
||||
datapipe = datapipe.fetch_cached_insubgraph_data(
|
||||
graph._gpu_graph_cache
|
||||
)
|
||||
datapipe = datapipe.transform(self._fetch_per_layer)
|
||||
datapipe = datapipe.buffer().wait()
|
||||
if graph._gpu_graph_cache is not None:
|
||||
datapipe = datapipe.combine_cached_and_fetched_insubgraph(prob_name)
|
||||
super().__init__(datapipe)
|
||||
self.graph = graph
|
||||
self.prob_name = prob_name
|
||||
super().__init__(datapipe, self._fetch_per_layer)
|
||||
|
||||
def _fetch_per_layer(self, minibatch):
|
||||
stream = torch.cuda.current_stream()
|
||||
@@ -260,11 +280,6 @@ class SamplePerLayer(MiniBatchTransformer):
|
||||
self.returning_indices_is_optional = True
|
||||
elif overlap_fetch:
|
||||
datapipe = datapipe.fetch_insubgraph_data(graph, prob_name)
|
||||
datapipe = datapipe.buffer().wait()
|
||||
if graph._gpu_graph_cache is not None:
|
||||
datapipe = datapipe.combine_cached_and_fetched_insubgraph(
|
||||
prob_name
|
||||
)
|
||||
datapipe = datapipe.transform(
|
||||
self._sample_per_layer_from_fetched_subgraph
|
||||
)
|
||||
|
||||
@@ -138,6 +138,10 @@ def test_gpu_sampling_DataLoader(
|
||||
awaiter_cnt += num_layers
|
||||
if asynchronous:
|
||||
bufferer_cnt += 2 * num_layers
|
||||
if overlap_graph_fetch:
|
||||
bufferer_cnt += 0 * num_layers
|
||||
if num_gpu_cached_edges > 0:
|
||||
bufferer_cnt += 2 * num_layers
|
||||
datapipe = dataloader.dataset
|
||||
datapipe_graph = traverse_dps(datapipe)
|
||||
awaiters = find_dps(
|
||||
|
||||
Reference in New Issue
Block a user