mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt] Test CachePolicy::QueryAndThenReplace for num_parts==1. (#7620)
This commit is contained in:
committed by
GitHub
parent
0462538c5c
commit
87ea76b02a
@@ -136,7 +136,7 @@ BaseCachePolicy::QueryAndThenReplaceImpl(
|
||||
pointers_ptr[i] = &cache_key;
|
||||
}
|
||||
}));
|
||||
return {positions, indices, pointers, missing_keys};
|
||||
return {positions, indices, pointers, missing_keys.slice(0, found_cnt)};
|
||||
}
|
||||
|
||||
template <typename CachePolicy>
|
||||
|
||||
@@ -242,6 +242,41 @@ PartitionedCachePolicy::QueryAsync(torch::Tensor keys) {
|
||||
});
|
||||
}
|
||||
|
||||
std::tuple<
|
||||
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
|
||||
torch::Tensor>
|
||||
PartitionedCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
|
||||
if (policies_.size() == 1) {
|
||||
std::lock_guard lock(mtx_);
|
||||
auto [positions, output_indices, pointers, missing_keys] =
|
||||
policies_[0]->QueryAndThenReplace(keys);
|
||||
auto found_and_missing_offsets = torch::empty(4, pointers.options());
|
||||
auto found_and_missing_offsets_ptr =
|
||||
found_and_missing_offsets.data_ptr<int64_t>();
|
||||
// Found offsets part.
|
||||
found_and_missing_offsets_ptr[0] = 0;
|
||||
found_and_missing_offsets_ptr[1] = keys.size(0) - missing_keys.size(0);
|
||||
// Missing offsets part.
|
||||
found_and_missing_offsets_ptr[2] = 0;
|
||||
found_and_missing_offsets_ptr[3] = missing_keys.size(0);
|
||||
auto found_offsets = found_and_missing_offsets.slice(0, 0, 2);
|
||||
auto missing_offsets = found_and_missing_offsets.slice(0, 2);
|
||||
return {positions, output_indices, pointers,
|
||||
missing_keys, found_offsets, missing_offsets};
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * @brief Runs QueryAndThenReplace through async() and exposes its six result
 * tensors as a vector, in the same order as the synchronous tuple.
 *
 * @param keys The keys to query the cache.
 *
 * @return A future whose value is {positions, output_indices, pointers,
 * missing_keys, found_offsets, missing_offsets}.
 */
c10::intrusive_ptr<Future<std::vector<torch::Tensor>>>
PartitionedCachePolicy::QueryAndThenReplaceAsync(torch::Tensor keys) {
  return async([this, keys] {
    auto results = QueryAndThenReplace(keys);
    // Flatten the tuple into a vector, keeping the element order.
    return std::vector<torch::Tensor>{
        std::get<0>(results), std::get<1>(results), std::get<2>(results),
        std::get<3>(results), std::get<4>(results), std::get<5>(results)};
  });
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
PartitionedCachePolicy::Replace(
|
||||
torch::Tensor keys, torch::optional<torch::Tensor> offsets) {
|
||||
|
||||
@@ -74,6 +74,29 @@ class PartitionedCachePolicy : public torch::CustomClassHolder {
|
||||
c10::intrusive_ptr<Future<std::vector<torch::Tensor>>> QueryAsync(
|
||||
torch::Tensor keys);
|
||||
|
||||
/**
 * @brief Queries the cache with the given keys and inserts the keys that were
 * not found.
 * @param keys The keys to query the cache.
 *
 * @return (positions, indices, pointers, missing_keys, found_offsets,
 * missing_offsets), where:
 * - positions has the locations of the keys which were emplaced into the
 *   cache;
 * - pointers point to the emplaced CacheKey pointers in the cache;
 * - missing_keys has the keys that were not found and were just inserted;
 * - indices is defined such that
 *   keys[indices[:keys.size(0) - missing_keys.size(0)]] gives us the found
 *   keys and keys[indices[keys.size(0) - missing_keys.size(0):]] is identical
 *   to missing_keys;
 * - found_offsets holds the partition offsets for the found pointers;
 * - missing_offsets holds the partition offsets for the missing_keys and the
 *   missing pointers.
 */
|
||||
std::tuple<
|
||||
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
|
||||
torch::Tensor>
|
||||
QueryAndThenReplace(torch::Tensor keys);
|
||||
|
||||
/**
 * @brief Asynchronous counterpart of QueryAndThenReplace.
 * @param keys The keys to query the cache.
 *
 * @return A future holding a vector with the six tensors that
 * QueryAndThenReplace returns, in the same order: positions, indices,
 * pointers, missing_keys, found_offsets and missing_offsets.
 */
c10::intrusive_ptr<Future<std::vector<torch::Tensor>>>
|
||||
QueryAndThenReplaceAsync(torch::Tensor keys);
|
||||
|
||||
/**
|
||||
* @brief The policy replace function.
|
||||
* @param keys The keys to query the cache.
|
||||
|
||||
@@ -107,6 +107,12 @@ TORCH_LIBRARY(graphbolt, m) {
|
||||
m.class_<storage::PartitionedCachePolicy>("PartitionedCachePolicy")
|
||||
.def("query", &storage::PartitionedCachePolicy::Query)
|
||||
.def("query_async", &storage::PartitionedCachePolicy::QueryAsync)
|
||||
.def(
|
||||
"query_and_then_replace",
|
||||
&storage::PartitionedCachePolicy::QueryAndThenReplace)
|
||||
.def(
|
||||
"query_and_then_replace_async",
|
||||
&storage::PartitionedCachePolicy::QueryAndThenReplaceAsync)
|
||||
.def("replace", &storage::PartitionedCachePolicy::Replace)
|
||||
.def("replace_async", &storage::PartitionedCachePolicy::ReplaceAsync)
|
||||
.def(
|
||||
|
||||
@@ -92,6 +92,50 @@ class CPUFeatureCache(object):
|
||||
missing_index = index[positions.size(0) :]
|
||||
return values, missing_index, missing_keys, missing_offsets
|
||||
|
||||
def query_and_then_replace(self, keys, reader_fn):
    """Looks up `keys` in the cache and fills in whatever is missing.

    Keys that are not found are read by calling `reader_fn(missing_keys)`;
    the values read are returned to the caller and also written into the
    cache at the positions chosen by the caching policy.

    Parameters
    ----------
    keys : Tensor
        The keys to query the cache with.
    reader_fn : reader_fn(keys: torch.Tensor) -> torch.Tensor
        A function that takes the tensor of missing keys and returns their
        values.

    Returns
    -------
    Tensor
        A tensor containing values corresponding to the keys. Should equal
        `reader_fn(keys)`, computed in a faster way.
    """
    num_keys = keys.shape[0]
    self.total_queries += num_keys

    policy_result = self._policy.query_and_then_replace(keys)
    (
        positions,
        index,
        pointers,
        missing_keys,
        found_offsets,
        missing_offsets,
    ) = policy_result

    # Each per-key tensor lists the found entries first, then the missing ones.
    num_missing = missing_keys.shape[0]
    found = slice(None, num_keys - num_missing)
    missing = slice(num_keys - num_missing, None)

    # Serve the hits from the cache, then tell the policy the reads are done.
    values = self._cache.query(positions[found], index, num_keys)
    self._policy.reading_completed(pointers[found], found_offsets)

    # Fetch the misses through the reader and place them in the output.
    self.total_miss += num_missing
    fetched = reader_fn(missing_keys)
    values[index[missing]] = fetched

    # Store the fetched values, then tell the policy the writes are done.
    self._cache.replace(positions[missing], fetched)
    self._policy.writing_completed(pointers[missing], missing_offsets)
    return values
|
||||
|
||||
def replace(self, keys, values, offsets=None):
|
||||
"""Inserts key-value pairs into the cache using the selected caching
|
||||
policy algorithm to remove old key-value pairs if it is full.
|
||||
|
||||
@@ -33,6 +33,10 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
|
||||
cache = gb.impl.CPUFeatureCache(
|
||||
(cache_size,) + a.shape[1:], a.dtype, policy, num_parts
|
||||
)
|
||||
cache2 = gb.impl.CPUFeatureCache(
|
||||
(cache_size,) + a.shape[1:], a.dtype, policy, num_parts
|
||||
)
|
||||
reader_fn = lambda keys: a[keys]
|
||||
|
||||
keys = torch.tensor([0, 1])
|
||||
values, missing_index, missing_keys, missing_offsets = cache.query(keys)
|
||||
@@ -48,6 +52,11 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
|
||||
values[missing_index] = missing_values
|
||||
assert torch.equal(values, a[keys])
|
||||
|
||||
if num_parts == 1:
|
||||
assert torch.equal(
|
||||
cache2.query_and_then_replace(keys, reader_fn), a[keys]
|
||||
)
|
||||
|
||||
pin_memory = F._default_context_str == "gpu"
|
||||
|
||||
keys = torch.arange(1, 33, pin_memory=pin_memory)
|
||||
@@ -65,15 +74,10 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
|
||||
values[missing_index] = missing_values
|
||||
assert torch.equal(values, a[keys])
|
||||
|
||||
values, missing_index, missing_keys, missing_offsets = cache.query(keys)
|
||||
if not offsets:
|
||||
missing_offsets = None
|
||||
assert torch.equal(missing_keys.flip([0]), torch.tensor([]))
|
||||
|
||||
missing_values = a[missing_keys]
|
||||
cache.replace(missing_keys, missing_values, missing_offsets)
|
||||
values[missing_index] = missing_values
|
||||
assert torch.equal(values, a[keys])
|
||||
if num_parts == 1:
|
||||
assert torch.equal(
|
||||
cache2.query_and_then_replace(keys, reader_fn), a[keys]
|
||||
)
|
||||
|
||||
values, missing_index, missing_keys, missing_offsets = cache.query(keys)
|
||||
if not offsets:
|
||||
@@ -85,6 +89,27 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
|
||||
values[missing_index] = missing_values
|
||||
assert torch.equal(values, a[keys])
|
||||
|
||||
if num_parts == 1:
|
||||
assert torch.equal(
|
||||
cache2.query_and_then_replace(keys, reader_fn), a[keys]
|
||||
)
|
||||
|
||||
values, missing_index, missing_keys, missing_offsets = cache.query(keys)
|
||||
if not offsets:
|
||||
missing_offsets = None
|
||||
assert torch.equal(missing_keys.flip([0]), torch.tensor([]))
|
||||
|
||||
missing_values = a[missing_keys]
|
||||
cache.replace(missing_keys, missing_values, missing_offsets)
|
||||
values[missing_index] = missing_values
|
||||
assert torch.equal(values, a[keys])
|
||||
|
||||
if num_parts == 1:
|
||||
assert torch.equal(
|
||||
cache2.query_and_then_replace(keys, reader_fn), a[keys]
|
||||
)
|
||||
assert cache.miss_rate == cache2.miss_rate
|
||||
|
||||
raw_feature_cache = torch.ops.graphbolt.feature_cache(
|
||||
(cache_size,) + a.shape[1:], a.dtype, pin_memory
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user