[GraphBolt] Refactor CachePolicy more. (#7649)

This commit is contained in:
Muhammed Fatih BALIN
2024-08-03 23:12:34 -04:00
committed by GitHub
parent 683a25a8ec
commit 0b31bdd4e7
5 changed files with 35 additions and 99 deletions

View File

@@ -169,9 +169,8 @@ std::tuple<torch::Tensor, torch::Tensor> BaseCachePolicy::ReplaceImpl(
return {positions, pointers};
}
template <bool write, typename CachePolicy>
void BaseCachePolicy::ReadingWritingCompletedImpl(
CachePolicy& policy, torch::Tensor pointers) {
template <bool write>
void BaseCachePolicy::ReadingWritingCompletedImpl(torch::Tensor pointers) {
static_assert(
sizeof(CacheKey*) == sizeof(int64_t), "You need 64 bit pointers.");
auto pointers_ptr =
@@ -184,6 +183,14 @@ void BaseCachePolicy::ReadingWritingCompletedImpl(
}
}
void BaseCachePolicy::ReadingCompleted(torch::Tensor pointers) {
ReadingWritingCompletedImpl<false>(pointers);
}
void BaseCachePolicy::WritingCompleted(torch::Tensor pointers) {
ReadingWritingCompletedImpl<true>(pointers);
}
S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity)
: BaseCachePolicy(capacity),
ghost_queue_(capacity - capacity / 10),
@@ -209,14 +216,6 @@ std::tuple<torch::Tensor, torch::Tensor> S3FifoCachePolicy::Replace(
return ReplaceImpl(*this, keys);
}
void S3FifoCachePolicy::ReadingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<false>(*this, keys);
}
void S3FifoCachePolicy::WritingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<true>(*this, keys);
}
SieveCachePolicy::SieveCachePolicy(int64_t capacity)
// Ensure that queue_ is constructed first before accessing its `.end()`.
: BaseCachePolicy(capacity), queue_(), hand_(queue_.end()) {
@@ -239,14 +238,6 @@ std::tuple<torch::Tensor, torch::Tensor> SieveCachePolicy::Replace(
return ReplaceImpl(*this, keys);
}
void SieveCachePolicy::ReadingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<false>(*this, keys);
}
void SieveCachePolicy::WritingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<true>(*this, keys);
}
LruCachePolicy::LruCachePolicy(int64_t capacity) : BaseCachePolicy(capacity) {
TORCH_CHECK(capacity > 0, "Capacity needs to be positive.");
key_to_cache_key_.reserve(kCapacityFactor * (capacity + 1));
@@ -267,14 +258,6 @@ std::tuple<torch::Tensor, torch::Tensor> LruCachePolicy::Replace(
return ReplaceImpl(*this, keys);
}
void LruCachePolicy::ReadingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<false>(*this, keys);
}
void LruCachePolicy::WritingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<true>(*this, keys);
}
ClockCachePolicy::ClockCachePolicy(int64_t capacity)
: BaseCachePolicy(capacity) {
TORCH_CHECK(capacity > 0, "Capacity needs to be positive.");
@@ -296,13 +279,5 @@ std::tuple<torch::Tensor, torch::Tensor> ClockCachePolicy::Replace(
return ReplaceImpl(*this, keys);
}
void ClockCachePolicy::ReadingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<false>(*this, keys);
}
void ClockCachePolicy::WritingCompleted(torch::Tensor keys) {
ReadingWritingCompletedImpl<true>(*this, keys);
}
} // namespace storage
} // namespace graphbolt

View File

@@ -40,9 +40,8 @@ struct CacheKey {
: freq_(0),
key_(key),
position_in_cache_(position),
read_reference_count_(0),
// EndUse<true>() should be called to reset the write_reference_count.
write_reference_count_(1) {
// EndUse<true>() should be called to reset the reference count.
reference_count_(-1) {
static_assert(sizeof(CacheKey) == 2 * sizeof(int64_t));
}
@@ -80,23 +79,26 @@ struct CacheKey {
}
CacheKey& StartRead() {
TORCH_CHECK(read_reference_count_++ < std::numeric_limits<int8_t>::max());
TORCH_CHECK(reference_count_ >= 0);
TORCH_CHECK(reference_count_++ < std::numeric_limits<int16_t>::max());
return *this;
}
template <bool write>
CacheKey& EndUse() {
if constexpr (write) {
--write_reference_count_;
TORCH_CHECK(reference_count_ == -1);
++reference_count_;
} else {
--read_reference_count_;
TORCH_CHECK(reference_count_ > 0);
--reference_count_;
}
return *this;
}
bool InUse() const { return read_reference_count_ || write_reference_count_; }
bool InUse() const { return reference_count_; }
bool BeingWritten() const { return write_reference_count_; }
bool BeingWritten() const { return reference_count_ < 0; }
friend std::ostream& operator<<(std::ostream& os, const CacheKey& key_ref) {
return os << '(' << key_ref.key_ << ", " << key_ref.freq_ << ", "
@@ -106,10 +108,9 @@ struct CacheKey {
private:
int64_t freq_ : 3;
int64_t key_ : 61;
int64_t position_in_cache_ : 40;
int64_t read_reference_count_ : 8;
// There could be a chain of writes so it is better to have larger bit count.
int64_t write_reference_count_ : 16;
int64_t position_in_cache_ : 48;
// Negative values indicate writing while positive values indicate reading.
int64_t reference_count_ : 16;
};
class BaseCachePolicy {
@@ -167,13 +168,13 @@ class BaseCachePolicy {
* @brief A reader has finished reading these keys, so they can be evicted.
* @param pointers The CacheKey pointers in the cache to unmark.
*/
virtual void ReadingCompleted(torch::Tensor pointers) = 0;
static void ReadingCompleted(torch::Tensor pointers);
/**
* @brief A writer has finished writing these keys, so they can be evicted.
* @param pointers The CacheKey pointers in the cache to unmark.
*/
virtual void WritingCompleted(torch::Tensor pointers) = 0;
static void WritingCompleted(torch::Tensor pointers);
protected:
template <typename K, typename V>
@@ -198,10 +199,6 @@ class BaseCachePolicy {
static std::tuple<torch::Tensor, torch::Tensor> ReplaceImpl(
CachePolicy& policy, torch::Tensor keys);
template <bool write, typename CachePolicy>
static void ReadingWritingCompletedImpl(
CachePolicy& policy, torch::Tensor pointers);
template <typename T>
static void MoveToFront(
std::list<T>& from, std::list<T>& to,
@@ -219,6 +216,10 @@ class BaseCachePolicy {
int64_t capacity_;
int64_t cache_usage_;
private:
template <bool write>
static void ReadingWritingCompletedImpl(torch::Tensor pointers);
};
/**
@@ -258,16 +259,6 @@ class S3FifoCachePolicy : public BaseCachePolicy {
*/
std::tuple<torch::Tensor, torch::Tensor> Replace(torch::Tensor keys);
/**
* @brief See BaseCachePolicy::ReadingCompleted.
*/
void ReadingCompleted(torch::Tensor pointers);
/**
* @brief See BaseCachePolicy::WritingCompleted.
*/
void WritingCompleted(torch::Tensor pointers);
template <bool write>
CacheKey* Read(int64_t key) {
auto it = key_to_cache_key_.find(key);
@@ -413,16 +404,6 @@ class SieveCachePolicy : public BaseCachePolicy {
*/
std::tuple<torch::Tensor, torch::Tensor> Replace(torch::Tensor keys);
/**
* @brief See BaseCachePolicy::ReadingCompleted.
*/
void ReadingCompleted(torch::Tensor pointers);
/**
* @brief See BaseCachePolicy::WritingCompleted.
*/
void WritingCompleted(torch::Tensor pointers);
template <bool write>
CacheKey* Read(int64_t key) {
auto it = key_to_cache_key_.find(key);
@@ -530,16 +511,6 @@ class LruCachePolicy : public BaseCachePolicy {
*/
std::tuple<torch::Tensor, torch::Tensor> Replace(torch::Tensor keys);
/**
* @brief See BaseCachePolicy::ReadingCompleted.
*/
void ReadingCompleted(torch::Tensor pointers);
/**
* @brief See BaseCachePolicy::WritingCompleted.
*/
void WritingCompleted(torch::Tensor pointers);
template <bool write>
CacheKey* Read(int64_t key) {
auto it = key_to_cache_key_.find(key);
@@ -646,16 +617,6 @@ class ClockCachePolicy : public BaseCachePolicy {
*/
std::tuple<torch::Tensor, torch::Tensor> Replace(torch::Tensor keys);
/**
* @brief See BaseCachePolicy::ReadingCompleted.
*/
void ReadingCompleted(torch::Tensor pointers);
/**
* @brief See BaseCachePolicy::WritingCompleted.
*/
void WritingCompleted(torch::Tensor pointers);
template <bool write>
CacheKey* Read(int64_t key) {
auto it = key_to_cache_key_.find(key);

View File

@@ -72,7 +72,6 @@ void FeatureCache::Replace(torch::Tensor positions, torch::Tensor values) {
auto values_ptr = reinterpret_cast<std::byte*>(values.data_ptr());
const auto tensor_ptr = reinterpret_cast<std::byte*>(tensor_.data_ptr());
const auto positions_ptr = positions.data_ptr<int64_t>();
std::lock_guard lock(mtx_);
graphbolt::parallel_for(
0, positions.size(0), kIntGrainSize, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; i++) {

View File

@@ -89,8 +89,6 @@ struct FeatureCache : public torch::CustomClassHolder {
private:
torch::Tensor tensor_;
// Protects writes only as reads are guaranteed to be safe.
std::mutex mtx_;
};
} // namespace storage

View File

@@ -162,7 +162,7 @@ class CPUCachedFeature(Feature):
missing_values_future = next(fallback_reader, None)
yield # fallback feature stages.
values_from_cpu_copy_event.wait()
values_from_cpu_copy_event.synchronize()
reading_completed = policy.reading_completed_async(
found_pointers, found_offsets
)
@@ -187,7 +187,6 @@ class CPUCachedFeature(Feature):
reading_completed.wait()
replace_future.wait()
missing_values_copy_event.wait()
writing_completed = policy.writing_completed_async(
missing_pointers, missing_offsets
)
@@ -219,7 +218,11 @@ class CPUCachedFeature(Feature):
return values
yield _Waiter(
[writing_completed],
[
writing_completed,
values_from_cpu_copy_event,
missing_values_copy_event,
],
values_from_cpu,
missing_values,
index,