mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt] Lock-free CachePolicy::ReadWriteCompleted. (#7654)
This commit is contained in:
committed by
GitHub
parent
2babaf9d66
commit
7deeff6f4c
@@ -25,8 +25,8 @@
|
||||
#include <tsl/robin_map.h>
|
||||
#include <tsl/robin_set.h>
|
||||
|
||||
#include <cuda/std/atomic>
|
||||
#include <limits>
|
||||
#include <mutex>
|
||||
|
||||
#include "./circular_queue.h"
|
||||
|
||||
@@ -34,14 +34,20 @@ namespace graphbolt {
|
||||
namespace storage {
|
||||
|
||||
struct CacheKey {
|
||||
auto getKey() const {
|
||||
return (static_cast<int64_t>(key_upper_) << 32) + key_lower_;
|
||||
}
|
||||
|
||||
CacheKey(int64_t key) : CacheKey(key, std::numeric_limits<int64_t>::min()) {}
|
||||
|
||||
CacheKey(int64_t key, int64_t position)
|
||||
: freq_(0),
|
||||
// EndUse<true>() should be called to reset the reference count.
|
||||
reference_count_(-1),
|
||||
key_(key),
|
||||
key_upper_(key >> 32),
|
||||
key_lower_(key),
|
||||
position_in_cache_(position) {
|
||||
TORCH_CHECK(key == getKey());
|
||||
static_assert(sizeof(CacheKey) == 2 * sizeof(int64_t));
|
||||
}
|
||||
|
||||
@@ -49,8 +55,6 @@ struct CacheKey {
|
||||
|
||||
auto getFreq() const { return freq_; }
|
||||
|
||||
auto getKey() const { return key_; }
|
||||
|
||||
auto getPos() const { return position_in_cache_; }
|
||||
|
||||
CacheKey& setPos(int64_t pos) {
|
||||
@@ -79,35 +83,58 @@ struct CacheKey {
|
||||
}
|
||||
|
||||
CacheKey& StartRead() {
|
||||
++reference_count_;
|
||||
::cuda::std::atomic_ref ref(reference_count_);
|
||||
// StartRead runs concurrently only with EndUse. EndUse does not need to see
|
||||
// this modification at all. So we can use the relaxed memory order.
|
||||
const auto old_val = ref.fetch_add(1, ::cuda::std::memory_order_relaxed);
|
||||
TORCH_CHECK(
|
||||
old_val < std::numeric_limits<int8_t>::max(),
|
||||
"There are too many in-flight read requests to the same cache entry!");
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <bool write>
|
||||
CacheKey& EndUse() {
|
||||
::cuda::std::atomic_ref ref(reference_count_);
|
||||
// The EndUse operation needs to synchronize with InUse and BeingWritten
|
||||
// operations. So we have an release-acquire ordering here.
|
||||
// https://en.cppreference.com/w/cpp/atomic/memory_order#Release-Acquire_ordering
|
||||
if constexpr (write) {
|
||||
++reference_count_;
|
||||
ref.fetch_add(1, ::cuda::std::memory_order_release);
|
||||
} else {
|
||||
--reference_count_;
|
||||
ref.fetch_add(-1, ::cuda::std::memory_order_release);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool InUse() const { return reference_count_; }
|
||||
bool InUse() const {
|
||||
::cuda::std::atomic_ref ref(reference_count_);
|
||||
// The operations after a call to this function need to happen after the
|
||||
// load operation. Hence the acquire order.
|
||||
return ref.load(::cuda::std::memory_order_acquire);
|
||||
}
|
||||
|
||||
bool BeingWritten() const { return reference_count_ < 0; }
|
||||
bool BeingWritten() const {
|
||||
::cuda::std::atomic_ref ref(reference_count_);
|
||||
// The operations after a call to this function need to happen after the
|
||||
// load operation. Hence the acquire order.
|
||||
return ref.load(::cuda::std::memory_order_acquire) < 0;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const CacheKey& key_ref) {
|
||||
return os << '(' << key_ref.key_ << ", " << key_ref.freq_ << ", "
|
||||
<< key_ref.position_in_cache_ << ", " << key_ref.reference_count_
|
||||
<< ")";
|
||||
::cuda::std::atomic_ref ref(key_ref.reference_count_);
|
||||
return os << '(' << key_ref.getKey() << ", " << key_ref.freq_ << ", "
|
||||
<< key_ref.position_in_cache_ << ", " << ref.load() << ")";
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t freq_ : 3;
|
||||
int8_t freq_;
|
||||
// Negative values indicate writing while positive values indicate reading.
|
||||
int64_t reference_count_ : 13;
|
||||
int64_t key_ : 48;
|
||||
// Access only through an std::atomic_ref instance atomically.
|
||||
int8_t reference_count_;
|
||||
// Keys are restricted to be 48-bit unsigned integers.
|
||||
uint16_t key_upper_;
|
||||
uint32_t key_lower_;
|
||||
int64_t position_in_cache_;
|
||||
};
|
||||
|
||||
|
||||
@@ -451,7 +451,6 @@ template <bool write>
|
||||
void PartitionedCachePolicy::ReadingWritingCompletedImpl(
|
||||
torch::Tensor pointers, torch::Tensor offsets) {
|
||||
if (policies_.size() == 1) {
|
||||
std::lock_guard lock(mtx_);
|
||||
if constexpr (write)
|
||||
policies_[0]->WritingCompleted(pointers);
|
||||
else
|
||||
@@ -460,7 +459,6 @@ void PartitionedCachePolicy::ReadingWritingCompletedImpl(
|
||||
}
|
||||
auto offsets_ptr = offsets.data_ptr<int64_t>();
|
||||
namespace gb = graphbolt;
|
||||
std::lock_guard lock(mtx_);
|
||||
gb::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
|
||||
if (begin == end) return;
|
||||
const auto tid = begin;
|
||||
|
||||
Reference in New Issue
Block a user