[GraphBolt] CachePolicy::QueryAndThenReplace. (#7619)

This commit is contained in:
Muhammed Fatih BALIN
2024-07-30 13:02:42 -04:00
committed by GitHub
parent 9550f0ecab
commit 0462538c5c
2 changed files with 238 additions and 0 deletions

View File

@@ -73,6 +73,72 @@ BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) {
found_ptr_tensor.slice(0, 0, found_cnt)};
}
template <typename CachePolicy>
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
BaseCachePolicy::QueryAndThenReplaceImpl(
CachePolicy& policy, torch::Tensor keys) {
auto positions = torch::empty_like(
keys, keys.options()
.dtype(torch::kInt64)
.pinned_memory(utils::is_pinned(keys)));
auto indices = torch::empty_like(
keys, keys.options()
.dtype(torch::kInt64)
.pinned_memory(utils::is_pinned(keys)));
auto pointers = torch::empty_like(keys, keys.options().dtype(torch::kInt64));
auto missing_keys = torch::empty_like(
keys, keys.options().pinned_memory(utils::is_pinned(keys)));
int64_t found_cnt = 0;
int64_t missing_cnt = keys.size(0);
AT_DISPATCH_INDEX_TYPES(
keys.scalar_type(), "BaseCachePolicy::Replace", ([&] {
auto keys_ptr = keys.data_ptr<index_t>();
auto positions_ptr = positions.data_ptr<int64_t>();
auto indices_ptr = indices.data_ptr<int64_t>();
static_assert(
sizeof(CacheKey*) == sizeof(int64_t), "You need 64 bit pointers.");
auto pointers_ptr =
reinterpret_cast<CacheKey**>(pointers.data_ptr<int64_t>());
auto missing_keys_ptr = missing_keys.data_ptr<index_t>();
auto iterators = std::unique_ptr<typename CachePolicy::map_iterator[]>(
new typename CachePolicy::map_iterator[keys.size(0)]);
// QueryImpl here.
for (int64_t i = 0; i < keys.size(0); i++) {
const auto key = keys_ptr[i];
const auto [it, can_read] = policy.Emplace(key);
if (can_read) {
auto& cache_key = *it->second;
positions_ptr[found_cnt] = cache_key.getPos();
pointers_ptr[found_cnt] = &cache_key;
indices_ptr[found_cnt++] = i;
} else {
indices_ptr[--missing_cnt] = i;
missing_keys_ptr[missing_cnt] = key;
iterators[missing_cnt] = it;
}
}
// ReplaceImpl here.
set_t<int64_t> position_set;
position_set.reserve(keys.size(0));
for (int64_t i = missing_cnt; i < missing_keys.size(0); i++) {
auto it = iterators[i];
if (it->second == policy.getMapSentinelValue()) {
policy.Insert(it);
// After Insert, it->second is not nullptr anymore.
TORCH_CHECK(
// If there are duplicate values and the key was just inserted,
// we do not have to check for the uniqueness of the positions.
std::get<1>(position_set.insert(it->second->getPos())),
"Can't insert all, larger cache capacity is needed.");
}
auto& cache_key = *it->second;
positions_ptr[i] = cache_key.getPos();
pointers_ptr[i] = &cache_key;
}
}));
return {positions, indices, pointers, missing_keys};
}
template <typename CachePolicy>
std::tuple<torch::Tensor, torch::Tensor> BaseCachePolicy::ReplaceImpl(
CachePolicy& policy, torch::Tensor keys) {
@@ -140,6 +206,11 @@ S3FifoCachePolicy::Query(torch::Tensor keys) {
return QueryImpl(*this, keys);
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
S3FifoCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
return QueryAndThenReplaceImpl(*this, keys);
}
std::tuple<torch::Tensor, torch::Tensor> S3FifoCachePolicy::Replace(
torch::Tensor keys) {
return ReplaceImpl(*this, keys);
@@ -165,6 +236,11 @@ SieveCachePolicy::Query(torch::Tensor keys) {
return QueryImpl(*this, keys);
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
SieveCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
return QueryAndThenReplaceImpl(*this, keys);
}
std::tuple<torch::Tensor, torch::Tensor> SieveCachePolicy::Replace(
torch::Tensor keys) {
return ReplaceImpl(*this, keys);
@@ -189,6 +265,11 @@ LruCachePolicy::Query(torch::Tensor keys) {
return QueryImpl(*this, keys);
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
LruCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
return QueryAndThenReplaceImpl(*this, keys);
}
std::tuple<torch::Tensor, torch::Tensor> LruCachePolicy::Replace(
torch::Tensor keys) {
return ReplaceImpl(*this, keys);
@@ -213,6 +294,11 @@ ClockCachePolicy::Query(torch::Tensor keys) {
return QueryImpl(*this, keys);
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
ClockCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
return QueryAndThenReplaceImpl(*this, keys);
}
std::tuple<torch::Tensor, torch::Tensor> ClockCachePolicy::Replace(
torch::Tensor keys) {
return ReplaceImpl(*this, keys);

View File

@@ -131,6 +131,21 @@ class BaseCachePolicy {
virtual std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
Query(torch::Tensor keys) = 0;
/**
* @brief The policy query function.
* @param keys The keys to query the cache.
*
* @return (positions, indices, pointers, missing_keys), where positions has
* the locations of the keys which were emplaced into the cache, pointers
* point to the emplaced CacheKey pointers in the cache, missing_keys has the
* keys that were not found and just inserted and indices is defined such that
* keys[indices[:keys.size(0) - missing_keys.size(0)]] gives us the keys for
* the found keys and keys[indices[keys.size(0) - missing_keys.size(0):]] is
* identical to missing_keys.
*/
virtual std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
QueryAndThenReplace(torch::Tensor keys) = 0;
/**
* @brief The policy replace function.
* @param keys The keys to query the cache.
@@ -165,6 +180,10 @@ class BaseCachePolicy {
static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
QueryImpl(CachePolicy& policy, torch::Tensor keys);
template <typename CachePolicy>
static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
QueryAndThenReplaceImpl(CachePolicy& policy, torch::Tensor keys);
template <typename CachePolicy>
static std::tuple<torch::Tensor, torch::Tensor> ReplaceImpl(
CachePolicy& policy, torch::Tensor keys);
@@ -180,6 +199,7 @@ class BaseCachePolicy {
**/
class S3FifoCachePolicy : public BaseCachePolicy {
public:
using map_iterator = map_t<int64_t, CacheKey*>::iterator;
/**
* @brief Constructor for the S3FifoCachePolicy class.
*
@@ -199,6 +219,12 @@ class S3FifoCachePolicy : public BaseCachePolicy {
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
torch::Tensor keys);
/**
* @brief See BaseCachePolicy::QueryAndThenReplace.
*/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
QueryAndThenReplace(torch::Tensor keys);
/**
* @brief See BaseCachePolicy::Replace.
*/
@@ -234,6 +260,25 @@ class S3FifoCachePolicy : public BaseCachePolicy {
return std::nullopt;
}
auto getMapSentinelValue() const { return nullptr; }
std::pair<map_iterator, bool> Emplace(int64_t key) {
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
if (it->second != nullptr) {
auto& cache_key = *it->second;
if (!cache_key.BeingWritten()) {
// Not being written so we use StartUse<write=false>() and return
// true to indicate the key is ready to read.
cache_key.Increment().StartUse<false>();
return {it, true};
} else {
cache_key.Increment().StartUse<true>();
}
}
// First time insertion, return false to indicate not ready to read.
return {it, false};
}
std::pair<int64_t, CacheKey*> Insert(int64_t key) {
const auto pos = Evict();
const auto in_ghost_queue = ghost_set_.erase(key);
@@ -243,6 +288,14 @@ class S3FifoCachePolicy : public BaseCachePolicy {
return {pos, cache_key_ptr};
}
void Insert(map_iterator it) {
const auto key = it->first;
const auto pos = Evict();
const auto in_ghost_queue = ghost_set_.erase(key);
auto& queue = in_ghost_queue ? main_queue_ : small_queue_;
it->second = queue.Push(CacheKey(key, pos));
}
template <bool write>
void Unmark(CacheKey* cache_key_ptr) {
cache_key_ptr->EndUse<write>();
@@ -306,6 +359,7 @@ class S3FifoCachePolicy : public BaseCachePolicy {
**/
class SieveCachePolicy : public BaseCachePolicy {
public:
using map_iterator = map_t<int64_t, CacheKey*>::iterator;
/**
* @brief Constructor for the SieveCachePolicy class.
*
@@ -323,6 +377,12 @@ class SieveCachePolicy : public BaseCachePolicy {
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
torch::Tensor keys);
/**
* @brief See BaseCachePolicy::QueryAndThenReplace.
*/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
QueryAndThenReplace(torch::Tensor keys);
/**
* @brief See BaseCachePolicy::Replace.
*/
@@ -350,6 +410,25 @@ class SieveCachePolicy : public BaseCachePolicy {
return std::nullopt;
}
auto getMapSentinelValue() const { return nullptr; }
std::pair<map_iterator, bool> Emplace(int64_t key) {
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
if (it->second != nullptr) {
auto& cache_key = *it->second;
if (!cache_key.BeingWritten()) {
// Not being written so we use StartUse<write=false>() and return
// true to indicate the key is ready to read.
cache_key.SetFreq().StartUse<false>();
return {it, true};
} else {
cache_key.SetFreq().StartUse<true>();
}
}
// First time insertion, return false to indicate not ready to read.
return {it, false};
}
std::pair<int64_t, CacheKey*> Insert(int64_t key) {
const auto pos = Evict();
queue_.push_front(CacheKey(key, pos));
@@ -358,6 +437,13 @@ class SieveCachePolicy : public BaseCachePolicy {
return {pos, cache_key_ptr};
}
void Insert(map_iterator it) {
const auto key = it->first;
const auto pos = Evict();
queue_.push_front(CacheKey(key, pos));
it->second = &queue_.front();
}
template <bool write>
void Unmark(CacheKey* cache_key_ptr) {
cache_key_ptr->EndUse<write>();
@@ -398,6 +484,7 @@ class SieveCachePolicy : public BaseCachePolicy {
**/
class LruCachePolicy : public BaseCachePolicy {
public:
using map_iterator = map_t<int64_t, std::list<CacheKey>::iterator>::iterator;
/**
* @brief Constructor for the LruCachePolicy class.
*
@@ -415,6 +502,12 @@ class LruCachePolicy : public BaseCachePolicy {
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
torch::Tensor keys);
/**
* @brief See BaseCachePolicy::QueryAndThenReplace.
*/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
QueryAndThenReplace(torch::Tensor keys);
/**
* @brief See BaseCachePolicy::Replace.
*/
@@ -455,6 +548,26 @@ class LruCachePolicy : public BaseCachePolicy {
return std::nullopt;
}
auto getMapSentinelValue() { return queue_.end(); }
std::pair<map_iterator, bool> Emplace(int64_t key) {
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
if (it->second != queue_.end()) {
MoveToFront(it->second);
auto& cache_key = *it->second;
if (!cache_key.BeingWritten()) {
// Not being written so we use StartUse<write=false>() and return
// true to indicate the key is ready to read.
cache_key.StartUse<false>();
return {it, true};
} else {
cache_key.StartUse<true>();
}
}
// First time insertion, return false to indicate not ready to read.
return {it, false};
}
std::pair<int64_t, CacheKey*> Insert(int64_t key) {
const auto pos = Evict();
queue_.push_front(CacheKey(key, pos));
@@ -462,6 +575,13 @@ class LruCachePolicy : public BaseCachePolicy {
return {pos, &queue_.front()};
}
void Insert(map_iterator it) {
const auto key = it->first;
const auto pos = Evict();
queue_.push_front(CacheKey(key, pos));
it->second = queue_.begin();
}
template <bool write>
void Unmark(CacheKey* cache_key_ptr) {
cache_key_ptr->EndUse<write>();
@@ -501,6 +621,7 @@ class LruCachePolicy : public BaseCachePolicy {
**/
class ClockCachePolicy : public BaseCachePolicy {
public:
using map_iterator = map_t<int64_t, CacheKey*>::iterator;
/**
* @brief Constructor for the ClockCachePolicy class.
*
@@ -520,6 +641,12 @@ class ClockCachePolicy : public BaseCachePolicy {
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> Query(
torch::Tensor keys);
/**
* @brief See BaseCachePolicy::QueryAndThenReplace.
*/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
QueryAndThenReplace(torch::Tensor keys);
/**
* @brief See BaseCachePolicy::Replace.
*/
@@ -547,6 +674,25 @@ class ClockCachePolicy : public BaseCachePolicy {
return std::nullopt;
}
auto getMapSentinelValue() const { return nullptr; }
std::pair<map_iterator, bool> Emplace(int64_t key) {
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
if (it->second != nullptr) {
auto& cache_key = *it->second;
if (!cache_key.BeingWritten()) {
// Not being written so we use StartUse<write=false>() and return
// true to indicate the key is ready to read.
cache_key.SetFreq().StartUse<false>();
return {it, true};
} else {
cache_key.SetFreq().StartUse<true>();
}
}
// First time insertion, return false to indicate not ready to read.
return {it, false};
}
std::pair<int64_t, CacheKey*> Insert(int64_t key) {
const auto pos = Evict();
auto cache_key_ptr = queue_.Push(CacheKey(key, pos));
@@ -554,6 +700,12 @@ class ClockCachePolicy : public BaseCachePolicy {
return {pos, cache_key_ptr};
}
void Insert(map_iterator it) {
const auto key = it->first;
const auto pos = Evict();
it->second = queue_.Push(CacheKey(key, pos));
}
template <bool write>
void Unmark(CacheKey* cache_key_ptr) {
cache_key_ptr->EndUse<write>();