mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt][CUDA] Keep CUDAStream only if cuda is available. (#7701)
This commit is contained in:
committed by
GitHub
parent
0d68130f92
commit
1c0ff2c924
@@ -39,6 +39,7 @@
|
||||
#ifdef GRAPHBOLT_USE_CUDA
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/csrc/api/include/torch/cuda.h>
|
||||
#endif
|
||||
|
||||
namespace graphbolt {
|
||||
@@ -111,13 +112,23 @@ template <typename F>
|
||||
inline auto async(F&& function) {
|
||||
using T = decltype(function());
|
||||
#ifdef GRAPHBOLT_USE_CUDA
|
||||
auto stream = c10::cuda::getCurrentCUDAStream();
|
||||
const auto is_cuda_available = torch::cuda::is_available();
|
||||
struct c10::StreamData3 stream_data;
|
||||
if (is_cuda_available) {
|
||||
stream_data = c10::cuda::getCurrentCUDAStream().pack3();
|
||||
}
|
||||
#endif
|
||||
auto fn = [=, func = std::move(function)] {
|
||||
#ifdef GRAPHBOLT_USE_CUDA
|
||||
// We make sure to use the same CUDA stream as the thread launching the
|
||||
// async operation.
|
||||
c10::cuda::CUDAStreamGuard guard(stream);
|
||||
if (is_cuda_available) {
|
||||
auto stream = c10::cuda::CUDAStream::unpack3(
|
||||
stream_data.stream_id, stream_data.device_index,
|
||||
stream_data.device_type);
|
||||
c10::cuda::CUDAStreamGuard guard(stream);
|
||||
return func();
|
||||
}
|
||||
#endif
|
||||
return func();
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user