[GraphBolt][CUDA] Use async for GPUGraphCache. (#7707)

This commit is contained in:
Muhammed Fatih BALIN
2024-08-15 20:39:58 -04:00
committed by GitHub
parent db574f5b0b
commit 9c874d0219
6 changed files with 125 additions and 20 deletions

View File

@@ -25,6 +25,7 @@
#include <cub/cub.cuh>
#include <cuco/static_map.cuh>
#include <cuda/std/atomic>
#include <limits>
#include <numeric>
#include <type_traits>
@@ -168,6 +169,7 @@ std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t> GpuGraphCache::Query(
seeds.device().index() == device_id_,
"Seeds should be on the correct CUDA device.");
TORCH_CHECK(seeds.sizes().size() == 1, "Keys should be a 1D tensor.");
std::lock_guard lock(mtx_);
auto allocator = cuda::GetAllocator();
auto index_dtype = cached_edge_tensors_.at(0).scalar_type();
const dim3 block(kIntBlockSize);
@@ -237,6 +239,12 @@ std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t> GpuGraphCache::Query(
}));
}
/**
 * @brief Asynchronous variant of Query(): hands a task that calls
 * Query(seeds) to async() and returns the Future that async() produces.
 *
 * @param seeds The tensor forwarded unchanged to Query().
 *
 * @return A Future wrapping the same tuple type that Query() returns.
 */
c10::intrusive_ptr<
    Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>>
GpuGraphCache::QueryAsync(torch::Tensor seeds) {
  // Explicit capture list, equivalent to the by-copy default here: the tensor
  // handle is copied into the task and `this` is captured as a pointer.
  // NOTE(review): presumably the cache object outlives the pending task --
  // confirm against the callers.
  return async([this, seeds] { return Query(seeds); });
}
std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
torch::Tensor seeds, torch::Tensor indices, torch::Tensor positions,
int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,
@@ -250,6 +258,7 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
TORCH_CHECK(
indptr.size(0) == num_nodes - num_hit + 1,
"(indptr.size(0) == seeds.size(0) - num_hit + 1) failed.");
std::lock_guard lock(mtx_);
const int64_t num_buffers = num_nodes * num_tensors;
auto allocator = cuda::GetAllocator();
auto index_dtype = cached_edge_tensors_.at(0).scalar_type();
@@ -490,5 +499,18 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
}));
}
/**
 * @brief Asynchronous variant of Replace(): hands a task that calls Replace()
 * with the given arguments to async() and returns the Future that async()
 * produces.
 *
 * All parameters are forwarded unchanged, and in the same order, to Replace().
 *
 * @return A Future wrapping the same tuple type that Replace() returns.
 */
c10::intrusive_ptr<
    Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>
GpuGraphCache::ReplaceAsync(
    torch::Tensor seeds, torch::Tensor indices, torch::Tensor positions,
    int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,
    std::vector<torch::Tensor> edge_tensors) {
  // Explicit capture list, equivalent to the by-copy default here: every
  // argument is copied into the task and `this` is captured as a pointer.
  // NOTE(review): presumably the cache object outlives the pending task --
  // confirm against the callers.
  return async([this, seeds, indices, positions, num_hit, num_threshold,
                indptr, edge_tensors] {
    return Replace(
        seeds, indices, positions, num_hit, num_threshold, indptr,
        edge_tensors);
  });
}
} // namespace cuda
} // namespace graphbolt

View File

@@ -21,11 +21,11 @@
#ifndef GRAPHBOLT_GPU_GRAPH_CACHE_H_
#define GRAPHBOLT_GPU_GRAPH_CACHE_H_
#include <graphbolt/async.h>
#include <torch/custom_class.h>
#include <torch/torch.h>
#include <limits>
#include <type_traits>
#include <mutex>
namespace graphbolt {
namespace cuda {
@@ -69,6 +69,10 @@ class GpuGraphCache : public torch::CustomClassHolder {
std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t> Query(
torch::Tensor seeds);
/**
 * @brief Asynchronous counterpart of Query(): takes the same seeds tensor and
 * returns a Future wrapping the same result tuple type as Query().
 */
c10::intrusive_ptr<
Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>>
QueryAsync(torch::Tensor seeds);
/**
* @brief After the graph structure for the missing node ids is fetched, it
* inserts the node ids which pass the threshold and returns the final
@@ -96,6 +100,13 @@ class GpuGraphCache : public torch::CustomClassHolder {
int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,
std::vector<torch::Tensor> edge_tensors);
/**
 * @brief Asynchronous counterpart of Replace(): takes the same arguments and
 * returns a Future wrapping the same result tuple type as Replace().
 */
c10::intrusive_ptr<
Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>
ReplaceAsync(
torch::Tensor seeds, torch::Tensor indices, torch::Tensor positions,
int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,
std::vector<torch::Tensor> edge_tensors);
static c10::intrusive_ptr<GpuGraphCache> Create(
const int64_t num_edges, const int64_t threshold,
torch::ScalarType indptr_dtype, std::vector<torch::ScalarType> dtypes);
@@ -111,6 +122,7 @@ class GpuGraphCache : public torch::CustomClassHolder {
torch::Tensor offset_; // The original graph's sliced_indptr tensor.
std::vector<torch::Tensor> cached_edge_tensors_; // The cached graph
// structure edge tensors.
std::mutex mtx_; // Protects the data structure and makes it threadsafe.
};
} // namespace cuda

View File

@@ -58,6 +58,17 @@ TORCH_LIBRARY(graphbolt, m) {
"wait",
&Future<std::vector<
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>>::Wait);
m.class_<Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>>(
"GpuGraphCacheQueryFuture")
.def(
"wait",
&Future<std::tuple<torch::Tensor, torch::Tensor, int64_t, int64_t>>::
Wait);
m.class_<Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>>(
"GpuGraphCacheReplaceFuture")
.def(
"wait",
&Future<std::tuple<torch::Tensor, std::vector<torch::Tensor>>>::Wait);
m.class_<storage::OnDiskNpyArray>("OnDiskNpyArray")
.def("index_select", &storage::OnDiskNpyArray::IndexSelect);
m.class_<FusedCSCSamplingGraph>("FusedCSCSamplingGraph")
@@ -114,7 +125,9 @@ TORCH_LIBRARY(graphbolt, m) {
m.def("gpu_cache", &cuda::GpuCache::Create);
m.class_<cuda::GpuGraphCache>("GpuGraphCache")
.def("query", &cuda::GpuGraphCache::Query)
.def("replace", &cuda::GpuGraphCache::Replace);
.def("query_async", &cuda::GpuGraphCache::QueryAsync)
.def("replace", &cuda::GpuGraphCache::Replace)
.def("replace_async", &cuda::GpuGraphCache::ReplaceAsync);
m.def("gpu_graph_cache", &cuda::GpuGraphCache::Create);
#endif
m.def("fused_csc_sampling_graph", &FusedCSCSamplingGraph::Create);

View File

@@ -68,6 +68,45 @@ class GPUGraphCache(object):
return keys[index[num_hit:]], replace_functional
def query_async(self, keys):
    """Queries the GPU cache asynchronously through a staged generator.

    Parameters
    ----------
    keys : Tensor
        The keys to query the GPU graph cache with.

    Returns
    -------
    A generator object.
        The generator is driven in three steps. The first invocation starts
        the asynchronous query and yields nothing. The second invocation
        waits for the query result and yields the missing keys. The third
        and last invocation must pass the fetched indptr and edge tensors
        as a pair via `.send()`; it yields a future object whose result can
        be accessed by calling `.wait()` on it. It is undefined behavior to
        call `.wait()` more than once.
    """
    query_future = self._cache.query_async(keys)
    # Pause here: the query was launched above and is only awaited once the
    # generator is resumed.
    yield
    index, position, num_hit, num_threshold = query_future.wait()
    num_keys = keys.shape[0]
    self.total_queries += num_keys
    self.total_miss += num_keys - num_hit
    missing_keys = keys[index[num_hit:]]
    # The value passed to `.send()` is unpacked into the fetched indptr and
    # the list of fetched edge tensors.
    missing_indptr, missing_edge_tensors = yield missing_keys
    replace_future = self._cache.replace_async(
        keys,
        index,
        position,
        num_hit,
        num_threshold,
        missing_indptr,
        missing_edge_tensors,
    )
    yield replace_future
@property
def miss_rate(self):
"""Returns the cache miss rate since creation."""

View File

@@ -32,18 +32,25 @@ __all__ = [
@functional_datapipe("fetch_cached_insubgraph_data")
class FetchCachedInsubgraphData(Mapper):
"""Queries the GPUGraphCache and returns the missing seeds and a lambda
function that can be called with the fetched graph structure.
"""Queries the GPUGraphCache and returns the missing seeds and a generator
handle that can be called with the fetched graph structure.
"""
def __init__(self, datapipe, gpu_graph_cache):
super().__init__(datapipe, self._fetch_per_layer)
datapipe = datapipe.transform(self._fetch_per_layer).buffer()
super().__init__(datapipe, self._wait_query_future)
self.cache = gpu_graph_cache
def _fetch_per_layer(self, minibatch):
minibatch._seeds, minibatch._replace = self.cache.query(
minibatch._seeds
)
minibatch._async_handle = self.cache.query_async(minibatch._seeds)
# Start first stage
next(minibatch._async_handle)
return minibatch
@staticmethod
def _wait_query_future(minibatch):
minibatch._seeds = next(minibatch._async_handle)
return minibatch
@@ -55,7 +62,8 @@ class CombineCachedAndFetchedInSubgraph(Mapper):
"""
def __init__(self, datapipe, prob_name):
super().__init__(datapipe, self._combine_per_layer)
datapipe = datapipe.transform(self._combine_per_layer).buffer()
super().__init__(datapipe, self._wait_replace_future)
self.prob_name = prob_name
def _combine_per_layer(self, minibatch):
@@ -69,16 +77,24 @@ class CombineCachedAndFetchedInSubgraph(Mapper):
edge_tensors.append(probs_or_mask)
edge_tensors.append(subgraph.edge_attribute(ORIGINAL_EDGE_ID))
subgraph.csc_indptr, edge_tensors = minibatch._replace(
subgraph.csc_indptr, edge_tensors
minibatch._future = minibatch._async_handle.send(
(subgraph.csc_indptr, edge_tensors)
)
delattr(minibatch, "_replace")
delattr(minibatch, "_async_handle")
return minibatch
def _wait_replace_future(self, minibatch):
subgraph = minibatch._sliced_sampling_graph
subgraph.csc_indptr, edge_tensors = minibatch._future.wait()
delattr(minibatch, "_future")
subgraph.indices = edge_tensors[0]
edge_tensors = edge_tensors[1:]
if subgraph.type_per_edge is not None:
subgraph.type_per_edge = edge_tensors[0]
edge_tensors = edge_tensors[1:]
probs_or_mask = subgraph.edge_attribute(self.prob_name)
if probs_or_mask is not None:
subgraph.add_edge_attribute(self.prob_name, edge_tensors[0])
edge_tensors = edge_tensors[1:]
@@ -113,7 +129,7 @@ class ConcatHeteroSeeds(Mapper):
@functional_datapipe("fetch_insubgraph_data")
class FetchInsubgraphData(Mapper):
class FetchInsubgraphData(MiniBatchTransformer):
"""Fetches the insubgraph and wraps it in a FusedCSCSamplingGraph object. If
the provided sample_per_layer_obj has a valid prob_name, then it reads the
probabilities of all the fetched edges. Furthermore, if type_per_array tensor
@@ -131,9 +147,13 @@ class FetchInsubgraphData(Mapper):
datapipe = datapipe.fetch_cached_insubgraph_data(
graph._gpu_graph_cache
)
datapipe = datapipe.transform(self._fetch_per_layer)
datapipe = datapipe.buffer().wait()
if graph._gpu_graph_cache is not None:
datapipe = datapipe.combine_cached_and_fetched_insubgraph(prob_name)
super().__init__(datapipe)
self.graph = graph
self.prob_name = prob_name
super().__init__(datapipe, self._fetch_per_layer)
def _fetch_per_layer(self, minibatch):
stream = torch.cuda.current_stream()
@@ -260,11 +280,6 @@ class SamplePerLayer(MiniBatchTransformer):
self.returning_indices_is_optional = True
elif overlap_fetch:
datapipe = datapipe.fetch_insubgraph_data(graph, prob_name)
datapipe = datapipe.buffer().wait()
if graph._gpu_graph_cache is not None:
datapipe = datapipe.combine_cached_and_fetched_insubgraph(
prob_name
)
datapipe = datapipe.transform(
self._sample_per_layer_from_fetched_subgraph
)

View File

@@ -138,6 +138,10 @@ def test_gpu_sampling_DataLoader(
awaiter_cnt += num_layers
if asynchronous:
bufferer_cnt += 2 * num_layers
if overlap_graph_fetch:
bufferer_cnt += 0 * num_layers
if num_gpu_cached_edges > 0:
bufferer_cnt += 2 * num_layers
datapipe = dataloader.dataset
datapipe_graph = traverse_dps(datapipe)
awaiters = find_dps(