[GraphBolt] Test CachePolicy::QueryAndThenReplace for num_parts==1. (#7620)

This commit is contained in:
Muhammed Fatih BALIN
2024-07-30 15:16:23 -04:00
committed by GitHub
parent 0462538c5c
commit 87ea76b02a
6 changed files with 143 additions and 10 deletions

View File

@@ -136,7 +136,7 @@ BaseCachePolicy::QueryAndThenReplaceImpl(
pointers_ptr[i] = &cache_key;
}
}));
return {positions, indices, pointers, missing_keys};
return {positions, indices, pointers, missing_keys.slice(0, found_cnt)};
}
template <typename CachePolicy>

View File

@@ -242,6 +242,41 @@ PartitionedCachePolicy::QueryAsync(torch::Tensor keys) {
});
}
std::tuple<
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor>
PartitionedCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
if (policies_.size() == 1) {
std::lock_guard lock(mtx_);
auto [positions, output_indices, pointers, missing_keys] =
policies_[0]->QueryAndThenReplace(keys);
auto found_and_missing_offsets = torch::empty(4, pointers.options());
auto found_and_missing_offsets_ptr =
found_and_missing_offsets.data_ptr<int64_t>();
// Found offsets part.
found_and_missing_offsets_ptr[0] = 0;
found_and_missing_offsets_ptr[1] = keys.size(0) - missing_keys.size(0);
// Missing offsets part.
found_and_missing_offsets_ptr[2] = 0;
found_and_missing_offsets_ptr[3] = missing_keys.size(0);
auto found_offsets = found_and_missing_offsets.slice(0, 0, 2);
auto missing_offsets = found_and_missing_offsets.slice(0, 2);
return {positions, output_indices, pointers,
missing_keys, found_offsets, missing_offsets};
};
}
c10::intrusive_ptr<Future<std::vector<torch::Tensor>>>
PartitionedCachePolicy::QueryAndThenReplaceAsync(torch::Tensor keys) {
return async([=] {
auto
[positions, output_indices, pointers, missing_keys, found_offsets,
missing_offsets] = QueryAndThenReplace(keys);
return std::vector{positions, output_indices, pointers,
missing_keys, found_offsets, missing_offsets};
});
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
PartitionedCachePolicy::Replace(
torch::Tensor keys, torch::optional<torch::Tensor> offsets) {

View File

@@ -74,6 +74,29 @@ class PartitionedCachePolicy : public torch::CustomClassHolder {
c10::intrusive_ptr<Future<std::vector<torch::Tensor>>> QueryAsync(
torch::Tensor keys);
/**
* @brief The policy query and then replace function.
* @param keys The keys to query the cache.
*
* @return (positions, indices, pointers, missing_keys, found_offsets,
* missing_offsets), where positions has the locations of the keys which were
* emplaced into the cache, pointers point to the emplaced CacheKey pointers
* in the cache, missing_keys has the keys that were not found and just
* inserted and indices is defined such that keys[indices[:keys.size(0) -
* missing_keys.size(0)]] gives us the keys for the found keys and
* keys[indices[keys.size(0) - missing_keys.size(0):]] is identical to
* missing_keys. The found_offsets tensor holds the partition offsets for the
* found pointers. The missing_offsets holds the partition offsets for the
* missing_keys and missing pointers.
*/
std::tuple<
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor>
QueryAndThenReplace(torch::Tensor keys);
c10::intrusive_ptr<Future<std::vector<torch::Tensor>>>
QueryAndThenReplaceAsync(torch::Tensor keys);
/**
* @brief The policy replace function.
* @param keys The keys to query the cache.

View File

@@ -107,6 +107,12 @@ TORCH_LIBRARY(graphbolt, m) {
m.class_<storage::PartitionedCachePolicy>("PartitionedCachePolicy")
.def("query", &storage::PartitionedCachePolicy::Query)
.def("query_async", &storage::PartitionedCachePolicy::QueryAsync)
.def(
"query_and_then_replace",
&storage::PartitionedCachePolicy::QueryAndThenReplace)
.def(
"query_and_then_replace_async",
&storage::PartitionedCachePolicy::QueryAndThenReplaceAsync)
.def("replace", &storage::PartitionedCachePolicy::Replace)
.def("replace_async", &storage::PartitionedCachePolicy::ReplaceAsync)
.def(

View File

@@ -92,6 +92,50 @@ class CPUFeatureCache(object):
missing_index = index[positions.size(0) :]
return values, missing_index, missing_keys, missing_offsets
def query_and_then_replace(self, keys, reader_fn):
"""Queries the cache. Then inserts the keys that are not found by
reading them by calling `reader_fn(missing_keys)`, which are then
inserted into the cache using the selected caching policy algorithm
to remove the old entries if it is full.
Parameters
----------
keys : Tensor
The keys to query the cache with.
reader_fn : reader_fn(keys: torch.Tensor) -> torch.Tensor
A function that will take a missing keys tensor and will return
their values.
Returns
-------
Tensor
A tensor containing values corresponding to the keys. Should equal
`reader_fn(keys)`, computed in a faster way.
"""
self.total_queries += keys.shape[0]
(
positions,
index,
pointers,
missing_keys,
found_offsets,
missing_offsets,
) = self._policy.query_and_then_replace(keys)
found_cnt = keys.size(0) - missing_keys.size(0)
found_positions = positions[:found_cnt]
values = self._cache.query(found_positions, index, keys.shape[0])
found_pointers = pointers[:found_cnt]
self._policy.reading_completed(found_pointers, found_offsets)
self.total_miss += missing_keys.shape[0]
missing_index = index[found_cnt:]
missing_values = reader_fn(missing_keys)
values[missing_index] = missing_values
missing_positions = positions[found_cnt:]
self._cache.replace(missing_positions, missing_values)
missing_pointers = pointers[found_cnt:]
self._policy.writing_completed(missing_pointers, missing_offsets)
return values
def replace(self, keys, values, offsets=None):
"""Inserts key-value pairs into the cache using the selected caching
policy algorithm to remove old key-value pairs if it is full.

View File

@@ -33,6 +33,10 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
cache = gb.impl.CPUFeatureCache(
(cache_size,) + a.shape[1:], a.dtype, policy, num_parts
)
cache2 = gb.impl.CPUFeatureCache(
(cache_size,) + a.shape[1:], a.dtype, policy, num_parts
)
reader_fn = lambda keys: a[keys]
keys = torch.tensor([0, 1])
values, missing_index, missing_keys, missing_offsets = cache.query(keys)
@@ -48,6 +52,11 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
values[missing_index] = missing_values
assert torch.equal(values, a[keys])
if num_parts == 1:
assert torch.equal(
cache2.query_and_then_replace(keys, reader_fn), a[keys]
)
pin_memory = F._default_context_str == "gpu"
keys = torch.arange(1, 33, pin_memory=pin_memory)
@@ -65,15 +74,10 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
values[missing_index] = missing_values
assert torch.equal(values, a[keys])
values, missing_index, missing_keys, missing_offsets = cache.query(keys)
if not offsets:
missing_offsets = None
assert torch.equal(missing_keys.flip([0]), torch.tensor([]))
missing_values = a[missing_keys]
cache.replace(missing_keys, missing_values, missing_offsets)
values[missing_index] = missing_values
assert torch.equal(values, a[keys])
if num_parts == 1:
assert torch.equal(
cache2.query_and_then_replace(keys, reader_fn), a[keys]
)
values, missing_index, missing_keys, missing_offsets = cache.query(keys)
if not offsets:
@@ -85,6 +89,27 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
values[missing_index] = missing_values
assert torch.equal(values, a[keys])
if num_parts == 1:
assert torch.equal(
cache2.query_and_then_replace(keys, reader_fn), a[keys]
)
values, missing_index, missing_keys, missing_offsets = cache.query(keys)
if not offsets:
missing_offsets = None
assert torch.equal(missing_keys.flip([0]), torch.tensor([]))
missing_values = a[missing_keys]
cache.replace(missing_keys, missing_values, missing_offsets)
values[missing_index] = missing_values
assert torch.equal(values, a[keys])
if num_parts == 1:
assert torch.equal(
cache2.query_and_then_replace(keys, reader_fn), a[keys]
)
assert cache.miss_rate == cache2.miss_rate
raw_feature_cache = torch.ops.graphbolt.feature_cache(
(cache_size,) + a.shape[1:], a.dtype, pin_memory
)