mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
[GraphBolt] Refactor CachePolicy more. (#7649)
This commit is contained in:
committed by
GitHub
parent
683a25a8ec
commit
0b31bdd4e7
@@ -169,9 +169,8 @@ std::tuple<torch::Tensor, torch::Tensor> BaseCachePolicy::ReplaceImpl(
|
||||
return {positions, pointers};
|
||||
}
|
||||
|
||||
template <bool write, typename CachePolicy>
|
||||
void BaseCachePolicy::ReadingWritingCompletedImpl(
|
||||
CachePolicy& policy, torch::Tensor pointers) {
|
||||
template <bool write>
|
||||
void BaseCachePolicy::ReadingWritingCompletedImpl(torch::Tensor pointers) {
|
||||
static_assert(
|
||||
sizeof(CacheKey*) == sizeof(int64_t), "You need 64 bit pointers.");
|
||||
auto pointers_ptr =
|
||||
@@ -184,6 +183,14 @@ void BaseCachePolicy::ReadingWritingCompletedImpl(
|
||||
}
|
||||
}
|
||||
|
||||
void BaseCachePolicy::ReadingCompleted(torch::Tensor pointers) {
|
||||
ReadingWritingCompletedImpl<false>(pointers);
|
||||
}
|
||||
|
||||
void BaseCachePolicy::WritingCompleted(torch::Tensor pointers) {
|
||||
ReadingWritingCompletedImpl<true>(pointers);
|
||||
}
|
||||
|
||||
S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity)
|
||||
: BaseCachePolicy(capacity),
|
||||
ghost_queue_(capacity - capacity / 10),
|
||||
@@ -209,14 +216,6 @@ std::tuple<torch::Tensor, torch::Tensor> S3FifoCachePolicy::Replace(
|
||||
return ReplaceImpl(*this, keys);
|
||||
}
|
||||
|
||||
void S3FifoCachePolicy::ReadingCompleted(torch::Tensor keys) {
|
||||
ReadingWritingCompletedImpl<false>(*this, keys);
|
||||
}
|
||||
|
||||
void S3FifoCachePolicy::WritingCompleted(torch::Tensor keys) {
|
||||
ReadingWritingCompletedImpl<true>(*this, keys);
|
||||
}
|
||||
|
||||
SieveCachePolicy::SieveCachePolicy(int64_t capacity)
|
||||
// Ensure that queue_ is constructed first before accessing its `.end()`.
|
||||
: BaseCachePolicy(capacity), queue_(), hand_(queue_.end()) {
|
||||
@@ -239,14 +238,6 @@ std::tuple<torch::Tensor, torch::Tensor> SieveCachePolicy::Replace(
|
||||
return ReplaceImpl(*this, keys);
|
||||
}
|
||||
|
||||
void SieveCachePolicy::ReadingCompleted(torch::Tensor keys) {
|
||||
ReadingWritingCompletedImpl<false>(*this, keys);
|
||||
}
|
||||
|
||||
void SieveCachePolicy::WritingCompleted(torch::Tensor keys) {
|
||||
ReadingWritingCompletedImpl<true>(*this, keys);
|
||||
}
|
||||
|
||||
LruCachePolicy::LruCachePolicy(int64_t capacity) : BaseCachePolicy(capacity) {
|
||||
TORCH_CHECK(capacity > 0, "Capacity needs to be positive.");
|
||||
key_to_cache_key_.reserve(kCapacityFactor * (capacity + 1));
|
||||
@@ -267,14 +258,6 @@ std::tuple<torch::Tensor, torch::Tensor> LruCachePolicy::Replace(
|
||||
return ReplaceImpl(*this, keys);
|
||||
}
|
||||
|
||||
void LruCachePolicy::ReadingCompleted(torch::Tensor keys) {
|
||||
ReadingWritingCompletedImpl<false>(*this, keys);
|
||||
}
|
||||
|
||||
void LruCachePolicy::WritingCompleted(torch::Tensor keys) {
|
||||
ReadingWritingCompletedImpl<true>(*this, keys);
|
||||
}
|
||||
|
||||
ClockCachePolicy::ClockCachePolicy(int64_t capacity)
|
||||
: BaseCachePolicy(capacity) {
|
||||
TORCH_CHECK(capacity > 0, "Capacity needs to be positive.");
|
||||
@@ -296,13 +279,5 @@ std::tuple<torch::Tensor, torch::Tensor> ClockCachePolicy::Replace(
|
||||
return ReplaceImpl(*this, keys);
|
||||
}
|
||||
|
||||
void ClockCachePolicy::ReadingCompleted(torch::Tensor keys) {
|
||||
ReadingWritingCompletedImpl<false>(*this, keys);
|
||||
}
|
||||
|
||||
void ClockCachePolicy::WritingCompleted(torch::Tensor keys) {
|
||||
ReadingWritingCompletedImpl<true>(*this, keys);
|
||||
}
|
||||
|
||||
} // namespace storage
|
||||
} // namespace graphbolt
|
||||
|
||||
@@ -40,9 +40,8 @@ struct CacheKey {
|
||||
: freq_(0),
|
||||
key_(key),
|
||||
position_in_cache_(position),
|
||||
read_reference_count_(0),
|
||||
// EndUse<true>() should be called to reset the write_reference_count.
|
||||
write_reference_count_(1) {
|
||||
// EndUse<true>() should be called to reset the reference count.
|
||||
reference_count_(-1) {
|
||||
static_assert(sizeof(CacheKey) == 2 * sizeof(int64_t));
|
||||
}
|
||||
|
||||
@@ -80,23 +79,26 @@ struct CacheKey {
|
||||
}
|
||||
|
||||
CacheKey& StartRead() {
|
||||
TORCH_CHECK(read_reference_count_++ < std::numeric_limits<int8_t>::max());
|
||||
TORCH_CHECK(reference_count_ >= 0);
|
||||
TORCH_CHECK(reference_count_++ < std::numeric_limits<int16_t>::max());
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <bool write>
|
||||
CacheKey& EndUse() {
|
||||
if constexpr (write) {
|
||||
--write_reference_count_;
|
||||
TORCH_CHECK(reference_count_ == -1);
|
||||
++reference_count_;
|
||||
} else {
|
||||
--read_reference_count_;
|
||||
TORCH_CHECK(reference_count_ > 0);
|
||||
--reference_count_;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool InUse() const { return read_reference_count_ || write_reference_count_; }
|
||||
bool InUse() const { return reference_count_; }
|
||||
|
||||
bool BeingWritten() const { return write_reference_count_; }
|
||||
bool BeingWritten() const { return reference_count_ < 0; }
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const CacheKey& key_ref) {
|
||||
return os << '(' << key_ref.key_ << ", " << key_ref.freq_ << ", "
|
||||
@@ -106,10 +108,9 @@ struct CacheKey {
|
||||
private:
|
||||
int64_t freq_ : 3;
|
||||
int64_t key_ : 61;
|
||||
int64_t position_in_cache_ : 40;
|
||||
int64_t read_reference_count_ : 8;
|
||||
// There could be a chain of writes so it is better to have larger bit count.
|
||||
int64_t write_reference_count_ : 16;
|
||||
int64_t position_in_cache_ : 48;
|
||||
// Negative values indicate writing while positive values indicate reading.
|
||||
int64_t reference_count_ : 16;
|
||||
};
|
||||
|
||||
class BaseCachePolicy {
|
||||
@@ -167,13 +168,13 @@ class BaseCachePolicy {
|
||||
* @brief A reader has finished reading these keys, so they can be evicted.
|
||||
* @param pointers The CacheKey pointers in the cache to unmark.
|
||||
*/
|
||||
virtual void ReadingCompleted(torch::Tensor pointers) = 0;
|
||||
static void ReadingCompleted(torch::Tensor pointers);
|
||||
|
||||
/**
|
||||
* @brief A writer has finished writing these keys, so they can be evicted.
|
||||
* @param pointers The CacheKey pointers in the cache to unmark.
|
||||
*/
|
||||
virtual void WritingCompleted(torch::Tensor pointers) = 0;
|
||||
static void WritingCompleted(torch::Tensor pointers);
|
||||
|
||||
protected:
|
||||
template <typename K, typename V>
|
||||
@@ -198,10 +199,6 @@ class BaseCachePolicy {
|
||||
static std::tuple<torch::Tensor, torch::Tensor> ReplaceImpl(
|
||||
CachePolicy& policy, torch::Tensor keys);
|
||||
|
||||
template <bool write, typename CachePolicy>
|
||||
static void ReadingWritingCompletedImpl(
|
||||
CachePolicy& policy, torch::Tensor pointers);
|
||||
|
||||
template <typename T>
|
||||
static void MoveToFront(
|
||||
std::list<T>& from, std::list<T>& to,
|
||||
@@ -219,6 +216,10 @@ class BaseCachePolicy {
|
||||
|
||||
int64_t capacity_;
|
||||
int64_t cache_usage_;
|
||||
|
||||
private:
|
||||
template <bool write>
|
||||
static void ReadingWritingCompletedImpl(torch::Tensor pointers);
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -258,16 +259,6 @@ class S3FifoCachePolicy : public BaseCachePolicy {
|
||||
*/
|
||||
std::tuple<torch::Tensor, torch::Tensor> Replace(torch::Tensor keys);
|
||||
|
||||
/**
|
||||
* @brief See BaseCachePolicy::ReadingCompleted.
|
||||
*/
|
||||
void ReadingCompleted(torch::Tensor pointers);
|
||||
|
||||
/**
|
||||
* @brief See BaseCachePolicy::WritingCompleted.
|
||||
*/
|
||||
void WritingCompleted(torch::Tensor pointers);
|
||||
|
||||
template <bool write>
|
||||
CacheKey* Read(int64_t key) {
|
||||
auto it = key_to_cache_key_.find(key);
|
||||
@@ -413,16 +404,6 @@ class SieveCachePolicy : public BaseCachePolicy {
|
||||
*/
|
||||
std::tuple<torch::Tensor, torch::Tensor> Replace(torch::Tensor keys);
|
||||
|
||||
/**
|
||||
* @brief See BaseCachePolicy::ReadingCompleted.
|
||||
*/
|
||||
void ReadingCompleted(torch::Tensor pointers);
|
||||
|
||||
/**
|
||||
* @brief See BaseCachePolicy::WritingCompleted.
|
||||
*/
|
||||
void WritingCompleted(torch::Tensor pointers);
|
||||
|
||||
template <bool write>
|
||||
CacheKey* Read(int64_t key) {
|
||||
auto it = key_to_cache_key_.find(key);
|
||||
@@ -530,16 +511,6 @@ class LruCachePolicy : public BaseCachePolicy {
|
||||
*/
|
||||
std::tuple<torch::Tensor, torch::Tensor> Replace(torch::Tensor keys);
|
||||
|
||||
/**
|
||||
* @brief See BaseCachePolicy::ReadingCompleted.
|
||||
*/
|
||||
void ReadingCompleted(torch::Tensor pointers);
|
||||
|
||||
/**
|
||||
* @brief See BaseCachePolicy::WritingCompleted.
|
||||
*/
|
||||
void WritingCompleted(torch::Tensor pointers);
|
||||
|
||||
template <bool write>
|
||||
CacheKey* Read(int64_t key) {
|
||||
auto it = key_to_cache_key_.find(key);
|
||||
@@ -646,16 +617,6 @@ class ClockCachePolicy : public BaseCachePolicy {
|
||||
*/
|
||||
std::tuple<torch::Tensor, torch::Tensor> Replace(torch::Tensor keys);
|
||||
|
||||
/**
|
||||
* @brief See BaseCachePolicy::ReadingCompleted.
|
||||
*/
|
||||
void ReadingCompleted(torch::Tensor pointers);
|
||||
|
||||
/**
|
||||
* @brief See BaseCachePolicy::WritingCompleted.
|
||||
*/
|
||||
void WritingCompleted(torch::Tensor pointers);
|
||||
|
||||
template <bool write>
|
||||
CacheKey* Read(int64_t key) {
|
||||
auto it = key_to_cache_key_.find(key);
|
||||
|
||||
@@ -72,7 +72,6 @@ void FeatureCache::Replace(torch::Tensor positions, torch::Tensor values) {
|
||||
auto values_ptr = reinterpret_cast<std::byte*>(values.data_ptr());
|
||||
const auto tensor_ptr = reinterpret_cast<std::byte*>(tensor_.data_ptr());
|
||||
const auto positions_ptr = positions.data_ptr<int64_t>();
|
||||
std::lock_guard lock(mtx_);
|
||||
graphbolt::parallel_for(
|
||||
0, positions.size(0), kIntGrainSize, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t i = begin; i < end; i++) {
|
||||
|
||||
@@ -89,8 +89,6 @@ struct FeatureCache : public torch::CustomClassHolder {
|
||||
|
||||
private:
|
||||
torch::Tensor tensor_;
|
||||
// Protects writes only as reads are guaranteed to be safe.
|
||||
std::mutex mtx_;
|
||||
};
|
||||
|
||||
} // namespace storage
|
||||
|
||||
@@ -162,7 +162,7 @@ class CPUCachedFeature(Feature):
|
||||
missing_values_future = next(fallback_reader, None)
|
||||
yield # fallback feature stages.
|
||||
|
||||
values_from_cpu_copy_event.wait()
|
||||
values_from_cpu_copy_event.synchronize()
|
||||
reading_completed = policy.reading_completed_async(
|
||||
found_pointers, found_offsets
|
||||
)
|
||||
@@ -187,7 +187,6 @@ class CPUCachedFeature(Feature):
|
||||
|
||||
reading_completed.wait()
|
||||
replace_future.wait()
|
||||
missing_values_copy_event.wait()
|
||||
writing_completed = policy.writing_completed_async(
|
||||
missing_pointers, missing_offsets
|
||||
)
|
||||
@@ -219,7 +218,11 @@ class CPUCachedFeature(Feature):
|
||||
return values
|
||||
|
||||
yield _Waiter(
|
||||
[writing_completed],
|
||||
[
|
||||
writing_completed,
|
||||
values_from_cpu_copy_event,
|
||||
missing_values_copy_event,
|
||||
],
|
||||
values_from_cpu,
|
||||
missing_values,
|
||||
index,
|
||||
|
||||
Reference in New Issue
Block a user