mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt] CachePolicy::QueryAndThenReplace. (#7619)
This commit is contained in:
committed by
GitHub
parent
9550f0ecab
commit
0462538c5c
@@ -73,6 +73,72 @@ BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) {
|
||||
found_ptr_tensor.slice(0, 0, found_cnt)};
|
||||
}
|
||||
|
||||
template <typename CachePolicy>
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
BaseCachePolicy::QueryAndThenReplaceImpl(
|
||||
CachePolicy& policy, torch::Tensor keys) {
|
||||
auto positions = torch::empty_like(
|
||||
keys, keys.options()
|
||||
.dtype(torch::kInt64)
|
||||
.pinned_memory(utils::is_pinned(keys)));
|
||||
auto indices = torch::empty_like(
|
||||
keys, keys.options()
|
||||
.dtype(torch::kInt64)
|
||||
.pinned_memory(utils::is_pinned(keys)));
|
||||
auto pointers = torch::empty_like(keys, keys.options().dtype(torch::kInt64));
|
||||
auto missing_keys = torch::empty_like(
|
||||
keys, keys.options().pinned_memory(utils::is_pinned(keys)));
|
||||
int64_t found_cnt = 0;
|
||||
int64_t missing_cnt = keys.size(0);
|
||||
AT_DISPATCH_INDEX_TYPES(
|
||||
keys.scalar_type(), "BaseCachePolicy::Replace", ([&] {
|
||||
auto keys_ptr = keys.data_ptr<index_t>();
|
||||
auto positions_ptr = positions.data_ptr<int64_t>();
|
||||
auto indices_ptr = indices.data_ptr<int64_t>();
|
||||
static_assert(
|
||||
sizeof(CacheKey*) == sizeof(int64_t), "You need 64 bit pointers.");
|
||||
auto pointers_ptr =
|
||||
reinterpret_cast<CacheKey**>(pointers.data_ptr<int64_t>());
|
||||
auto missing_keys_ptr = missing_keys.data_ptr<index_t>();
|
||||
auto iterators = std::unique_ptr<typename CachePolicy::map_iterator[]>(
|
||||
new typename CachePolicy::map_iterator[keys.size(0)]);
|
||||
// QueryImpl here.
|
||||
for (int64_t i = 0; i < keys.size(0); i++) {
|
||||
const auto key = keys_ptr[i];
|
||||
const auto [it, can_read] = policy.Emplace(key);
|
||||
if (can_read) {
|
||||
auto& cache_key = *it->second;
|
||||
positions_ptr[found_cnt] = cache_key.getPos();
|
||||
pointers_ptr[found_cnt] = &cache_key;
|
||||
indices_ptr[found_cnt++] = i;
|
||||
} else {
|
||||
indices_ptr[--missing_cnt] = i;
|
||||
missing_keys_ptr[missing_cnt] = key;
|
||||
iterators[missing_cnt] = it;
|
||||
}
|
||||
}
|
||||
// ReplaceImpl here.
|
||||
set_t<int64_t> position_set;
|
||||
position_set.reserve(keys.size(0));
|
||||
for (int64_t i = missing_cnt; i < missing_keys.size(0); i++) {
|
||||
auto it = iterators[i];
|
||||
if (it->second == policy.getMapSentinelValue()) {
|
||||
policy.Insert(it);
|
||||
// After Insert, it->second is not nullptr anymore.
|
||||
TORCH_CHECK(
|
||||
// If there are duplicate values and the key was just inserted,
|
||||
// we do not have to check for the uniqueness of the positions.
|
||||
std::get<1>(position_set.insert(it->second->getPos())),
|
||||
"Can't insert all, larger cache capacity is needed.");
|
||||
}
|
||||
auto& cache_key = *it->second;
|
||||
positions_ptr[i] = cache_key.getPos();
|
||||
pointers_ptr[i] = &cache_key;
|
||||
}
|
||||
}));
|
||||
return {positions, indices, pointers, missing_keys};
|
||||
}
|
||||
|
||||
template <typename CachePolicy>
|
||||
std::tuple<torch::Tensor, torch::Tensor> BaseCachePolicy::ReplaceImpl(
|
||||
CachePolicy& policy, torch::Tensor keys) {
|
||||
@@ -140,6 +206,11 @@ S3FifoCachePolicy::Query(torch::Tensor keys) {
|
||||
return QueryImpl(*this, keys);
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
S3FifoCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
|
||||
return QueryAndThenReplaceImpl(*this, keys);
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> S3FifoCachePolicy::Replace(
|
||||
torch::Tensor keys) {
|
||||
return ReplaceImpl(*this, keys);
|
||||
@@ -165,6 +236,11 @@ SieveCachePolicy::Query(torch::Tensor keys) {
|
||||
return QueryImpl(*this, keys);
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
SieveCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
|
||||
return QueryAndThenReplaceImpl(*this, keys);
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> SieveCachePolicy::Replace(
|
||||
torch::Tensor keys) {
|
||||
return ReplaceImpl(*this, keys);
|
||||
@@ -189,6 +265,11 @@ LruCachePolicy::Query(torch::Tensor keys) {
|
||||
return QueryImpl(*this, keys);
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
LruCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
|
||||
return QueryAndThenReplaceImpl(*this, keys);
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> LruCachePolicy::Replace(
|
||||
torch::Tensor keys) {
|
||||
return ReplaceImpl(*this, keys);
|
||||
@@ -213,6 +294,11 @@ ClockCachePolicy::Query(torch::Tensor keys) {
|
||||
return QueryImpl(*this, keys);
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
ClockCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
|
||||
return QueryAndThenReplaceImpl(*this, keys);
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> ClockCachePolicy::Replace(
|
||||
torch::Tensor keys) {
|
||||
return ReplaceImpl(*this, keys);
|
||||
|
||||
@@ -131,6 +131,21 @@ class BaseCachePolicy {
|
||||
virtual std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
Query(torch::Tensor keys) = 0;
|
||||
|
||||
/**
|
||||
* @brief The policy query function.
|
||||
* @param keys The keys to query the cache.
|
||||
*
|
||||
* @return (positions, indices, pointers, missing_keys), where positions has
|
||||
* the locations of the keys which were emplaced into the cache, pointers
|
||||
* point to the emplaced CacheKey pointers in the cache, missing_keys has the
|
||||
* keys that were not found and just inserted and indices is defined such that
|
||||
* keys[indices[:keys.size(0) - missing_keys.size(0)]] gives us the keys for
|
||||
* the found keys and keys[indices[keys.size(0) - missing_keys.size(0):]] is
|
||||
* identical to missing_keys.
|
||||
*/
|
||||
virtual std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
QueryAndThenReplace(torch::Tensor keys) = 0;
|
||||
|
||||
/**
|
||||
* @brief The policy replace function.
|
||||
* @param keys The keys to query the cache.
|
||||
@@ -165,6 +180,10 @@ class BaseCachePolicy {
|
||||
static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
QueryImpl(CachePolicy& policy, torch::Tensor keys);
|
||||
|
||||
template <typename CachePolicy>
|
||||
static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
QueryAndThenReplaceImpl(CachePolicy& policy, torch::Tensor keys);
|
||||
|
||||
template <typename CachePolicy>
|
||||
static std::tuple<torch::Tensor, torch::Tensor> ReplaceImpl(
|
||||
CachePolicy& policy, torch::Tensor keys);
|
||||
@@ -180,6 +199,7 @@ class BaseCachePolicy {
|
||||
**/
|
||||
class S3FifoCachePolicy : public BaseCachePolicy {
|
||||
public:
|
||||
using map_iterator = map_t<int64_t, CacheKey*>::iterator;
|
||||
/**
|
||||
* @brief Constructor for the S3FifoCachePolicy class.
|
||||
*
|
||||
@@ -199,6 +219,12 @@ class S3FifoCachePolicy : public BaseCachePolicy {
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
|
||||
torch::Tensor keys);
|
||||
|
||||
/**
|
||||
* @brief See BaseCachePolicy::QueryAndThenReplace.
|
||||
*/
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
QueryAndThenReplace(torch::Tensor keys);
|
||||
|
||||
/**
|
||||
* @brief See BaseCachePolicy::Replace.
|
||||
*/
|
||||
@@ -234,6 +260,25 @@ class S3FifoCachePolicy : public BaseCachePolicy {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto getMapSentinelValue() const { return nullptr; }
|
||||
|
||||
std::pair<map_iterator, bool> Emplace(int64_t key) {
|
||||
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
|
||||
if (it->second != nullptr) {
|
||||
auto& cache_key = *it->second;
|
||||
if (!cache_key.BeingWritten()) {
|
||||
// Not being written so we use StartUse<write=false>() and return
|
||||
// true to indicate the key is ready to read.
|
||||
cache_key.Increment().StartUse<false>();
|
||||
return {it, true};
|
||||
} else {
|
||||
cache_key.Increment().StartUse<true>();
|
||||
}
|
||||
}
|
||||
// First time insertion, return false to indicate not ready to read.
|
||||
return {it, false};
|
||||
}
|
||||
|
||||
std::pair<int64_t, CacheKey*> Insert(int64_t key) {
|
||||
const auto pos = Evict();
|
||||
const auto in_ghost_queue = ghost_set_.erase(key);
|
||||
@@ -243,6 +288,14 @@ class S3FifoCachePolicy : public BaseCachePolicy {
|
||||
return {pos, cache_key_ptr};
|
||||
}
|
||||
|
||||
void Insert(map_iterator it) {
|
||||
const auto key = it->first;
|
||||
const auto pos = Evict();
|
||||
const auto in_ghost_queue = ghost_set_.erase(key);
|
||||
auto& queue = in_ghost_queue ? main_queue_ : small_queue_;
|
||||
it->second = queue.Push(CacheKey(key, pos));
|
||||
}
|
||||
|
||||
template <bool write>
|
||||
void Unmark(CacheKey* cache_key_ptr) {
|
||||
cache_key_ptr->EndUse<write>();
|
||||
@@ -306,6 +359,7 @@ class S3FifoCachePolicy : public BaseCachePolicy {
|
||||
**/
|
||||
class SieveCachePolicy : public BaseCachePolicy {
|
||||
public:
|
||||
using map_iterator = map_t<int64_t, CacheKey*>::iterator;
|
||||
/**
|
||||
* @brief Constructor for the SieveCachePolicy class.
|
||||
*
|
||||
@@ -323,6 +377,12 @@ class SieveCachePolicy : public BaseCachePolicy {
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
|
||||
torch::Tensor keys);
|
||||
|
||||
/**
|
||||
* @brief See BaseCachePolicy::QueryAndThenReplace.
|
||||
*/
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
QueryAndThenReplace(torch::Tensor keys);
|
||||
|
||||
/**
|
||||
* @brief See BaseCachePolicy::Replace.
|
||||
*/
|
||||
@@ -350,6 +410,25 @@ class SieveCachePolicy : public BaseCachePolicy {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto getMapSentinelValue() const { return nullptr; }
|
||||
|
||||
std::pair<map_iterator, bool> Emplace(int64_t key) {
|
||||
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
|
||||
if (it->second != nullptr) {
|
||||
auto& cache_key = *it->second;
|
||||
if (!cache_key.BeingWritten()) {
|
||||
// Not being written so we use StartUse<write=false>() and return
|
||||
// true to indicate the key is ready to read.
|
||||
cache_key.SetFreq().StartUse<false>();
|
||||
return {it, true};
|
||||
} else {
|
||||
cache_key.SetFreq().StartUse<true>();
|
||||
}
|
||||
}
|
||||
// First time insertion, return false to indicate not ready to read.
|
||||
return {it, false};
|
||||
}
|
||||
|
||||
std::pair<int64_t, CacheKey*> Insert(int64_t key) {
|
||||
const auto pos = Evict();
|
||||
queue_.push_front(CacheKey(key, pos));
|
||||
@@ -358,6 +437,13 @@ class SieveCachePolicy : public BaseCachePolicy {
|
||||
return {pos, cache_key_ptr};
|
||||
}
|
||||
|
||||
void Insert(map_iterator it) {
|
||||
const auto key = it->first;
|
||||
const auto pos = Evict();
|
||||
queue_.push_front(CacheKey(key, pos));
|
||||
it->second = &queue_.front();
|
||||
}
|
||||
|
||||
template <bool write>
|
||||
void Unmark(CacheKey* cache_key_ptr) {
|
||||
cache_key_ptr->EndUse<write>();
|
||||
@@ -398,6 +484,7 @@ class SieveCachePolicy : public BaseCachePolicy {
|
||||
**/
|
||||
class LruCachePolicy : public BaseCachePolicy {
|
||||
public:
|
||||
using map_iterator = map_t<int64_t, std::list<CacheKey>::iterator>::iterator;
|
||||
/**
|
||||
* @brief Constructor for the LruCachePolicy class.
|
||||
*
|
||||
@@ -415,6 +502,12 @@ class LruCachePolicy : public BaseCachePolicy {
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
|
||||
torch::Tensor keys);
|
||||
|
||||
/**
|
||||
* @brief See BaseCachePolicy::QueryAndThenReplace.
|
||||
*/
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
QueryAndThenReplace(torch::Tensor keys);
|
||||
|
||||
/**
|
||||
* @brief See BaseCachePolicy::Replace.
|
||||
*/
|
||||
@@ -455,6 +548,26 @@ class LruCachePolicy : public BaseCachePolicy {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto getMapSentinelValue() { return queue_.end(); }
|
||||
|
||||
std::pair<map_iterator, bool> Emplace(int64_t key) {
|
||||
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
|
||||
if (it->second != queue_.end()) {
|
||||
MoveToFront(it->second);
|
||||
auto& cache_key = *it->second;
|
||||
if (!cache_key.BeingWritten()) {
|
||||
// Not being written so we use StartUse<write=false>() and return
|
||||
// true to indicate the key is ready to read.
|
||||
cache_key.StartUse<false>();
|
||||
return {it, true};
|
||||
} else {
|
||||
cache_key.StartUse<true>();
|
||||
}
|
||||
}
|
||||
// First time insertion, return false to indicate not ready to read.
|
||||
return {it, false};
|
||||
}
|
||||
|
||||
std::pair<int64_t, CacheKey*> Insert(int64_t key) {
|
||||
const auto pos = Evict();
|
||||
queue_.push_front(CacheKey(key, pos));
|
||||
@@ -462,6 +575,13 @@ class LruCachePolicy : public BaseCachePolicy {
|
||||
return {pos, &queue_.front()};
|
||||
}
|
||||
|
||||
void Insert(map_iterator it) {
|
||||
const auto key = it->first;
|
||||
const auto pos = Evict();
|
||||
queue_.push_front(CacheKey(key, pos));
|
||||
it->second = queue_.begin();
|
||||
}
|
||||
|
||||
template <bool write>
|
||||
void Unmark(CacheKey* cache_key_ptr) {
|
||||
cache_key_ptr->EndUse<write>();
|
||||
@@ -501,6 +621,7 @@ class LruCachePolicy : public BaseCachePolicy {
|
||||
**/
|
||||
class ClockCachePolicy : public BaseCachePolicy {
|
||||
public:
|
||||
using map_iterator = map_t<int64_t, CacheKey*>::iterator;
|
||||
/**
|
||||
* @brief Constructor for the ClockCachePolicy class.
|
||||
*
|
||||
@@ -520,6 +641,12 @@ class ClockCachePolicy : public BaseCachePolicy {
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
|
||||
torch::Tensor keys);
|
||||
|
||||
/**
|
||||
* @brief See BaseCachePolicy::QueryAndThenReplace.
|
||||
*/
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
QueryAndThenReplace(torch::Tensor keys);
|
||||
|
||||
/**
|
||||
* @brief See BaseCachePolicy::Replace.
|
||||
*/
|
||||
@@ -547,6 +674,25 @@ class ClockCachePolicy : public BaseCachePolicy {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
auto getMapSentinelValue() const { return nullptr; }
|
||||
|
||||
std::pair<map_iterator, bool> Emplace(int64_t key) {
|
||||
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
|
||||
if (it->second != nullptr) {
|
||||
auto& cache_key = *it->second;
|
||||
if (!cache_key.BeingWritten()) {
|
||||
// Not being written so we use StartUse<write=false>() and return
|
||||
// true to indicate the key is ready to read.
|
||||
cache_key.SetFreq().StartUse<false>();
|
||||
return {it, true};
|
||||
} else {
|
||||
cache_key.SetFreq().StartUse<true>();
|
||||
}
|
||||
}
|
||||
// First time insertion, return false to indicate not ready to read.
|
||||
return {it, false};
|
||||
}
|
||||
|
||||
std::pair<int64_t, CacheKey*> Insert(int64_t key) {
|
||||
const auto pos = Evict();
|
||||
auto cache_key_ptr = queue_.Push(CacheKey(key, pos));
|
||||
@@ -554,6 +700,12 @@ class ClockCachePolicy : public BaseCachePolicy {
|
||||
return {pos, cache_key_ptr};
|
||||
}
|
||||
|
||||
void Insert(map_iterator it) {
|
||||
const auto key = it->first;
|
||||
const auto pos = Evict();
|
||||
it->second = queue_.Push(CacheKey(key, pos));
|
||||
}
|
||||
|
||||
template <bool write>
|
||||
void Unmark(CacheKey* cache_key_ptr) {
|
||||
cache_key_ptr->EndUse<write>();
|
||||
|
||||
Reference in New Issue
Block a user