[GraphBolt] Lock-free CachePolicy::ReadWriteCompleted. (#7654)

This commit is contained in:
Muhammed Fatih BALIN
2024-08-04 19:55:18 -04:00
committed by GitHub
parent 2babaf9d66
commit 7deeff6f4c
2 changed files with 42 additions and 17 deletions

View File

@@ -25,8 +25,8 @@
#include <tsl/robin_map.h>
#include <tsl/robin_set.h>
#include <cuda/std/atomic>
#include <limits>
#include <mutex>
#include "./circular_queue.h"
@@ -34,14 +34,20 @@ namespace graphbolt {
namespace storage {
struct CacheKey {
auto getKey() const {
return (static_cast<int64_t>(key_upper_) << 32) + key_lower_;
}
CacheKey(int64_t key) : CacheKey(key, std::numeric_limits<int64_t>::min()) {}
CacheKey(int64_t key, int64_t position)
: freq_(0),
// EndUse<true>() should be called to reset the reference count.
reference_count_(-1),
key_(key),
key_upper_(key >> 32),
key_lower_(key),
position_in_cache_(position) {
TORCH_CHECK(key == getKey());
static_assert(sizeof(CacheKey) == 2 * sizeof(int64_t));
}
@@ -49,8 +55,6 @@ struct CacheKey {
auto getFreq() const { return freq_; }
auto getKey() const { return key_; }
auto getPos() const { return position_in_cache_; }
CacheKey& setPos(int64_t pos) {
@@ -79,35 +83,58 @@ struct CacheKey {
}
CacheKey& StartRead() {
++reference_count_;
::cuda::std::atomic_ref ref(reference_count_);
// StartRead runs concurrently only with EndUse. EndUse does not need to see
// this modification at all. So we can use the relaxed memory order.
const auto old_val = ref.fetch_add(1, ::cuda::std::memory_order_relaxed);
TORCH_CHECK(
old_val < std::numeric_limits<int8_t>::max(),
"There are too many in-flight read requests to the same cache entry!");
return *this;
}
template <bool write>
CacheKey& EndUse() {
::cuda::std::atomic_ref ref(reference_count_);
// The EndUse operation needs to synchronize with InUse and BeingWritten
// operations. So we have an release-acquire ordering here.
// https://en.cppreference.com/w/cpp/atomic/memory_order#Release-Acquire_ordering
if constexpr (write) {
++reference_count_;
ref.fetch_add(1, ::cuda::std::memory_order_release);
} else {
--reference_count_;
ref.fetch_add(-1, ::cuda::std::memory_order_release);
}
return *this;
}
bool InUse() const { return reference_count_; }
bool InUse() const {
::cuda::std::atomic_ref ref(reference_count_);
// The operations after a call to this function need to happen after the
// load operation. Hence the acquire order.
return ref.load(::cuda::std::memory_order_acquire);
}
bool BeingWritten() const { return reference_count_ < 0; }
bool BeingWritten() const {
::cuda::std::atomic_ref ref(reference_count_);
// The operations after a call to this function need to happen after the
// load operation. Hence the acquire order.
return ref.load(::cuda::std::memory_order_acquire) < 0;
}
friend std::ostream& operator<<(std::ostream& os, const CacheKey& key_ref) {
return os << '(' << key_ref.key_ << ", " << key_ref.freq_ << ", "
<< key_ref.position_in_cache_ << ", " << key_ref.reference_count_
<< ")";
::cuda::std::atomic_ref ref(key_ref.reference_count_);
return os << '(' << key_ref.getKey() << ", " << key_ref.freq_ << ", "
<< key_ref.position_in_cache_ << ", " << ref.load() << ")";
}
private:
int64_t freq_ : 3;
int8_t freq_;
// Negative values indicate writing while positive values indicate reading.
int64_t reference_count_ : 13;
int64_t key_ : 48;
// Access only through an std::atomic_ref instance atomically.
int8_t reference_count_;
// Keys are restricted to be 48-bit unsigned integers.
uint16_t key_upper_;
uint32_t key_lower_;
int64_t position_in_cache_;
};

View File

@@ -451,7 +451,6 @@ template <bool write>
void PartitionedCachePolicy::ReadingWritingCompletedImpl(
torch::Tensor pointers, torch::Tensor offsets) {
if (policies_.size() == 1) {
std::lock_guard lock(mtx_);
if constexpr (write)
policies_[0]->WritingCompleted(pointers);
else
@@ -460,7 +459,6 @@ void PartitionedCachePolicy::ReadingWritingCompletedImpl(
}
auto offsets_ptr = offsets.data_ptr<int64_t>();
namespace gb = graphbolt;
std::lock_guard lock(mtx_);
gb::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
if (begin == end) return;
const auto tid = begin;