[GraphBolt] CachePolicy::QueryAndThenReplace for num_parts>1. (#7621)

This commit is contained in:
Muhammed Fatih BALIN
2024-07-31 02:41:00 -04:00
committed by GitHub
parent eafb53013f
commit cfe0541802
5 changed files with 132 additions and 49 deletions

View File

@@ -130,6 +130,8 @@ BaseCachePolicy::QueryAndThenReplaceImpl(
// we do not have to check for the uniqueness of the positions.
std::get<1>(position_set.insert(it->second->getPos())),
"Can't insert all, larger cache capacity is needed.");
} else {
policy.MarkExistingWriting(it);
}
auto& cache_key = *it->second;
positions_ptr[i] = cache_key.getPos();

View File

@@ -264,15 +264,13 @@ class S3FifoCachePolicy : public BaseCachePolicy {
std::pair<map_iterator, bool> Emplace(int64_t key) {
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
if (it->second != nullptr) {
if (it->second != getMapSentinelValue()) {
auto& cache_key = *it->second;
if (!cache_key.BeingWritten()) {
// Not being written so we use StartUse<write=false>() and return
// true to indicate the key is ready to read.
cache_key.Increment().StartUse<false>();
return {it, true};
} else {
cache_key.Increment().StartUse<true>();
}
}
// First time insertion, return false to indicate not ready to read.
@@ -296,6 +294,10 @@ class S3FifoCachePolicy : public BaseCachePolicy {
it->second = queue.Push(CacheKey(key, pos));
}
void MarkExistingWriting(map_iterator it) {
it->second->Increment().StartUse<true>();
}
template <bool write>
void Unmark(CacheKey* cache_key_ptr) {
cache_key_ptr->EndUse<write>();
@@ -414,15 +416,13 @@ class SieveCachePolicy : public BaseCachePolicy {
std::pair<map_iterator, bool> Emplace(int64_t key) {
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
if (it->second != nullptr) {
if (it->second != getMapSentinelValue()) {
auto& cache_key = *it->second;
if (!cache_key.BeingWritten()) {
// Not being written so we use StartUse<write=false>() and return
// true to indicate the key is ready to read.
cache_key.SetFreq().StartUse<false>();
return {it, true};
} else {
cache_key.SetFreq().StartUse<true>();
}
}
// First time insertion, return false to indicate not ready to read.
@@ -444,6 +444,10 @@ class SieveCachePolicy : public BaseCachePolicy {
it->second = &queue_.front();
}
void MarkExistingWriting(map_iterator it) {
it->second->SetFreq().StartUse<true>();
}
template <bool write>
void Unmark(CacheKey* cache_key_ptr) {
cache_key_ptr->EndUse<write>();
@@ -552,16 +556,14 @@ class LruCachePolicy : public BaseCachePolicy {
std::pair<map_iterator, bool> Emplace(int64_t key) {
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
if (it->second != queue_.end()) {
MoveToFront(it->second);
if (it->second != getMapSentinelValue()) {
auto& cache_key = *it->second;
if (!cache_key.BeingWritten()) {
MoveToFront(it->second);
// Not being written so we use StartUse<write=false>() and return
// true to indicate the key is ready to read.
cache_key.StartUse<false>();
return {it, true};
} else {
cache_key.StartUse<true>();
}
}
// First time insertion, return false to indicate not ready to read.
@@ -582,6 +584,11 @@ class LruCachePolicy : public BaseCachePolicy {
it->second = queue_.begin();
}
void MarkExistingWriting(map_iterator it) {
MoveToFront(it->second);
it->second->StartUse<true>();
}
template <bool write>
void Unmark(CacheKey* cache_key_ptr) {
cache_key_ptr->EndUse<write>();
@@ -678,15 +685,13 @@ class ClockCachePolicy : public BaseCachePolicy {
std::pair<map_iterator, bool> Emplace(int64_t key) {
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
if (it->second != nullptr) {
if (it->second != getMapSentinelValue()) {
auto& cache_key = *it->second;
if (!cache_key.BeingWritten()) {
// Not being written so we use StartUse<write=false>() and return
// true to indicate the key is ready to read.
cache_key.SetFreq().StartUse<false>();
return {it, true};
} else {
cache_key.SetFreq().StartUse<true>();
}
}
// First time insertion, return false to indicate not ready to read.
@@ -706,6 +711,10 @@ class ClockCachePolicy : public BaseCachePolicy {
it->second = queue_.Push(CacheKey(key, pos));
}
void MarkExistingWriting(map_iterator it) {
it->second->SetFreq().StartUse<true>();
}
template <bool write>
void Unmark(CacheKey* cache_key_ptr) {
cache_key_ptr->EndUse<write>();

View File

@@ -206,11 +206,10 @@ PartitionedCachePolicy::Query(torch::Tensor keys) {
selected_positions_ptr, selected_positions_ptr + num_selected,
positions.data_ptr<int64_t>() + begin,
[off = tid * capacity_ / policies_.size()](auto x) { return x + off; });
std::memcpy(
reinterpret_cast<std::byte*>(found_pointers.data_ptr()) +
begin * found_pointers.element_size(),
std::get<3>(results[tid]).data_ptr(),
num_selected * found_pointers.element_size());
auto selected_pointers_ptr = std::get<3>(results[tid]).data_ptr<int64_t>();
std::copy(
selected_pointers_ptr, selected_pointers_ptr + num_selected,
found_pointers.data_ptr<int64_t>() + begin);
begin = result_offsets[policies_.size() + tid];
end = result_offsets[policies_.size() + tid + 1];
missing_offsets[tid + 1] = end - result_offsets[policies_.size()];
@@ -263,7 +262,102 @@ PartitionedCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
auto missing_offsets = found_and_missing_offsets.slice(0, 2);
return {positions, output_indices, pointers,
missing_keys, found_offsets, missing_offsets};
};
}
torch::Tensor offsets, indices, permuted_keys;
std::tie(offsets, indices, permuted_keys) = Partition(keys);
auto offsets_ptr = offsets.data_ptr<int64_t>();
auto indices_ptr = indices.data_ptr<int64_t>();
std::vector<
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>
results(policies_.size());
torch::Tensor result_offsets_tensor =
torch::empty(policies_.size() * 2 + 1, offsets.options());
auto result_offsets = result_offsets_tensor.data_ptr<int64_t>();
namespace gb = graphbolt;
{
std::lock_guard lock(mtx_);
gb::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
if (begin == end) return;
TORCH_CHECK(end - begin == 1);
const auto tid = begin;
begin = offsets_ptr[tid];
end = offsets_ptr[tid + 1];
results[tid] = policies_.at(tid)->QueryAndThenReplace(
permuted_keys.slice(0, begin, end));
const auto missing_cnt = std::get<3>(results[tid]).size(0);
result_offsets[tid] = end - begin - missing_cnt;
result_offsets[tid + policies_.size()] = missing_cnt;
});
}
std::exclusive_scan(
result_offsets, result_offsets + result_offsets_tensor.size(0),
result_offsets, 0);
torch::Tensor positions = torch::empty(
keys.size(0),
std::get<0>(results[0]).options().pinned_memory(utils::is_pinned(keys)));
torch::Tensor output_indices = torch::empty_like(
indices, indices.options().pinned_memory(utils::is_pinned(keys)));
torch::Tensor pointers = torch::empty(
keys.size(0),
std::get<2>(results[0]).options().pinned_memory(utils::is_pinned(keys)));
torch::Tensor missing_keys = torch::empty(
result_offsets[2 * policies_.size()] - result_offsets[policies_.size()],
std::get<3>(results[0]).options().pinned_memory(utils::is_pinned(keys)));
auto missing_offsets =
torch::empty(policies_.size() + 1, result_offsets_tensor.options());
auto positions_ptr = positions.data_ptr<int64_t>();
auto output_indices_ptr = output_indices.data_ptr<int64_t>();
auto pointers_ptr = pointers.data_ptr<int64_t>();
auto missing_offsets_ptr = missing_offsets.data_ptr<int64_t>();
missing_offsets_ptr[0] = 0;
gb::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
if (begin == end) return;
const auto tid = begin;
auto out_index_ptr = indices_ptr + offsets_ptr[tid];
begin = result_offsets[tid];
end = result_offsets[tid + 1];
const auto num_selected = end - begin;
auto indices_ptr = std::get<1>(results[tid]).data_ptr<int64_t>();
for (int64_t i = 0; i < num_selected; i++) {
output_indices_ptr[begin + i] = out_index_ptr[indices_ptr[i]];
}
auto selected_positions_ptr = std::get<0>(results[tid]).data_ptr<int64_t>();
std::transform(
selected_positions_ptr, selected_positions_ptr + num_selected,
positions_ptr + begin,
[off = tid * capacity_ / policies_.size()](auto x) { return x + off; });
auto selected_pointers_ptr = std::get<2>(results[tid]).data_ptr<int64_t>();
std::copy(
selected_pointers_ptr, selected_pointers_ptr + num_selected,
pointers_ptr + begin);
begin = result_offsets[policies_.size() + tid];
end = result_offsets[policies_.size() + tid + 1];
missing_offsets[tid + 1] = end - result_offsets[policies_.size()];
const auto num_missing = end - begin;
for (int64_t i = 0; i < num_missing; i++) {
output_indices_ptr[begin + i] =
out_index_ptr[indices_ptr[i + num_selected]];
}
auto missing_positions_ptr = selected_positions_ptr + num_selected;
std::transform(
missing_positions_ptr, missing_positions_ptr + num_missing,
positions_ptr + begin,
[off = tid * capacity_ / policies_.size()](auto x) { return x + off; });
auto missing_pointers_ptr = selected_pointers_ptr + num_selected;
std::copy(
missing_pointers_ptr, missing_pointers_ptr + num_missing,
pointers_ptr + begin);
std::memcpy(
reinterpret_cast<std::byte*>(missing_keys.data_ptr()) +
(begin - result_offsets[policies_.size()]) *
missing_keys.element_size(),
std::get<3>(results[tid]).data_ptr(),
num_missing * missing_keys.element_size());
});
auto found_offsets = result_offsets_tensor.slice(0, 0, policies_.size() + 1);
return std::make_tuple(
positions, output_indices, pointers, missing_keys, found_offsets,
missing_offsets);
}
c10::intrusive_ptr<Future<std::vector<torch::Tensor>>>

View File

@@ -75,16 +75,9 @@ class CPUCachedFeature(Feature):
"""
if ids is None:
return self._fallback_feature.read()
(
values,
missing_index,
missing_keys,
missing_offsets,
) = self._feature.query(ids)
missing_values = self._fallback_feature.read(missing_keys)
values[missing_index] = missing_values
self._feature.replace(missing_keys, missing_values, missing_offsets)
return values
return self._feature.query_and_then_replace(
ids, self._fallback_feature.read
)
def read_async(self, ids: torch.Tensor):
"""Read the feature by index asynchronously.

View File

@@ -51,11 +51,7 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
cache.replace(missing_keys, missing_values, missing_offsets)
values[missing_index] = missing_values
assert torch.equal(values, a[keys])
if num_parts == 1:
assert torch.equal(
cache2.query_and_then_replace(keys, reader_fn), a[keys]
)
assert torch.equal(cache2.query_and_then_replace(keys, reader_fn), a[keys])
pin_memory = F._default_context_str == "gpu"
@@ -73,11 +69,7 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
cache.replace(missing_keys, missing_values, missing_offsets)
values[missing_index] = missing_values
assert torch.equal(values, a[keys])
if num_parts == 1:
assert torch.equal(
cache2.query_and_then_replace(keys, reader_fn), a[keys]
)
assert torch.equal(cache2.query_and_then_replace(keys, reader_fn), a[keys])
values, missing_index, missing_keys, missing_offsets = cache.query(keys)
if not offsets:
@@ -88,11 +80,7 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
cache.replace(missing_keys, missing_values, missing_offsets)
values[missing_index] = missing_values
assert torch.equal(values, a[keys])
if num_parts == 1:
assert torch.equal(
cache2.query_and_then_replace(keys, reader_fn), a[keys]
)
assert torch.equal(cache2.query_and_then_replace(keys, reader_fn), a[keys])
values, missing_index, missing_keys, missing_offsets = cache.query(keys)
if not offsets:
@@ -103,12 +91,9 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
cache.replace(missing_keys, missing_values, missing_offsets)
values[missing_index] = missing_values
assert torch.equal(values, a[keys])
assert torch.equal(cache2.query_and_then_replace(keys, reader_fn), a[keys])
if num_parts == 1:
assert torch.equal(
cache2.query_and_then_replace(keys, reader_fn), a[keys]
)
assert cache.miss_rate == cache2.miss_rate
assert cache.miss_rate == cache2.miss_rate
raw_feature_cache = torch.ops.graphbolt.feature_cache(
(cache_size,) + a.shape[1:], a.dtype, pin_memory