From 87ea76b02ac5743871181e8d38bd18c2dc103115 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih BALIN
Date: Tue, 30 Jul 2024 15:16:23 -0400
Subject: [PATCH] [GraphBolt] Test `CachePolicy::QueryAndThenReplace` for `num_parts==1`. (#7620)

---
 graphbolt/src/cache_policy.cc                 |  2 +-
 graphbolt/src/partitioned_cache_policy.cc     | 37 +++++++++++++++
 graphbolt/src/partitioned_cache_policy.h      | 25 ++++++++++
 graphbolt/src/python_binding.cc               |  6 +++
 python/dgl/graphbolt/impl/feature_cache.py    | 44 +++++++++++++++++++
 .../graphbolt/impl/test_feature_cache.py      | 43 ++++++++++++++----
 6 files changed, 147 insertions(+), 10 deletions(-)

diff --git a/graphbolt/src/cache_policy.cc b/graphbolt/src/cache_policy.cc
index e34e55cb6e..d313bfcb1c 100644
--- a/graphbolt/src/cache_policy.cc
+++ b/graphbolt/src/cache_policy.cc
@@ -136,7 +136,7 @@ BaseCachePolicy::QueryAndThenReplaceImpl(
           pointers_ptr[i] = &cache_key;
         }
       }));
-  return {positions, indices, pointers, missing_keys};
+  return {positions, indices, pointers, missing_keys.slice(0, found_cnt)};
 }
 
 template
diff --git a/graphbolt/src/partitioned_cache_policy.cc b/graphbolt/src/partitioned_cache_policy.cc
index 1a2136c307..6717faa350 100644
--- a/graphbolt/src/partitioned_cache_policy.cc
+++ b/graphbolt/src/partitioned_cache_policy.cc
@@ -242,6 +242,43 @@ PartitionedCachePolicy::QueryAsync(torch::Tensor keys) {
   });
 }
 
+std::tuple<
+    torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
+    torch::Tensor>
+PartitionedCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
+  // Only the single-partition case is implemented; fail loudly otherwise
+  // instead of flowing off the end of a value-returning function (UB).
+  TORCH_CHECK(
+      policies_.size() == 1, "QueryAndThenReplace requires num_parts == 1.");
+  std::lock_guard lock(mtx_);
+  auto [positions, output_indices, pointers, missing_keys] =
+      policies_[0]->QueryAndThenReplace(keys);
+  auto found_and_missing_offsets = torch::empty(4, pointers.options());
+  auto found_and_missing_offsets_ptr =
+      found_and_missing_offsets.data_ptr();
+  // Found offsets part.
+  found_and_missing_offsets_ptr[0] = 0;
+  found_and_missing_offsets_ptr[1] = keys.size(0) - missing_keys.size(0);
+  // Missing offsets part.
+  found_and_missing_offsets_ptr[2] = 0;
+  found_and_missing_offsets_ptr[3] = missing_keys.size(0);
+  auto found_offsets = found_and_missing_offsets.slice(0, 0, 2);
+  auto missing_offsets = found_and_missing_offsets.slice(0, 2);
+  return {positions, output_indices, pointers,
+          missing_keys, found_offsets, missing_offsets};
+}
+
+c10::intrusive_ptr>>
+PartitionedCachePolicy::QueryAndThenReplaceAsync(torch::Tensor keys) {
+  return async([=] {
+    auto
+        [positions, output_indices, pointers, missing_keys, found_offsets,
+         missing_offsets] = QueryAndThenReplace(keys);
+    return std::vector{positions, output_indices, pointers,
+                       missing_keys, found_offsets, missing_offsets};
+  });
+}
+
 std::tuple
 PartitionedCachePolicy::Replace(
     torch::Tensor keys, torch::optional offsets) {
diff --git a/graphbolt/src/partitioned_cache_policy.h b/graphbolt/src/partitioned_cache_policy.h
index efd0f724b3..3ee81dd9c5 100644
--- a/graphbolt/src/partitioned_cache_policy.h
+++ b/graphbolt/src/partitioned_cache_policy.h
@@ -74,6 +74,31 @@ class PartitionedCachePolicy : public torch::CustomClassHolder {
   c10::intrusive_ptr>> QueryAsync(
       torch::Tensor keys);
 
+  /**
+   * @brief The policy query and then replace function.
+   * @param keys The keys to query the cache.
+   *
+   * @return (positions, indices, pointers, missing_keys, found_offsets,
+   * missing_offsets), where positions has the locations of the keys which were
+   * emplaced into the cache, pointers point to the emplaced CacheKey pointers
+   * in the cache, missing_keys has the keys that were not found and just
+   * inserted and indices is defined such that keys[indices[:keys.size(0) -
+   * missing_keys.size(0)]] gives us the keys for the found keys and
+   * keys[indices[keys.size(0) - missing_keys.size(0):]] is identical to
+   * missing_keys. The found_offsets tensor holds the partition offsets for the
+   * found pointers. The missing_offsets holds the partition offsets for the
+   * missing_keys and missing pointers.
+   * @note Only implemented for a single partition (num_parts == 1); throws
+   * otherwise.
+   */
+  std::tuple<
+      torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
+      torch::Tensor>
+  QueryAndThenReplace(torch::Tensor keys);
+
+  c10::intrusive_ptr>>
+  QueryAndThenReplaceAsync(torch::Tensor keys);
+
   /**
    * @brief The policy replace function.
    * @param keys The keys to query the cache.
diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc
index c1125a4dfd..223b7b20b2 100644
--- a/graphbolt/src/python_binding.cc
+++ b/graphbolt/src/python_binding.cc
@@ -107,6 +107,12 @@ TORCH_LIBRARY(graphbolt, m) {
   m.class_("PartitionedCachePolicy")
       .def("query", &storage::PartitionedCachePolicy::Query)
       .def("query_async", &storage::PartitionedCachePolicy::QueryAsync)
+      .def(
+          "query_and_then_replace",
+          &storage::PartitionedCachePolicy::QueryAndThenReplace)
+      .def(
+          "query_and_then_replace_async",
+          &storage::PartitionedCachePolicy::QueryAndThenReplaceAsync)
       .def("replace", &storage::PartitionedCachePolicy::Replace)
       .def("replace_async", &storage::PartitionedCachePolicy::ReplaceAsync)
       .def(
diff --git a/python/dgl/graphbolt/impl/feature_cache.py b/python/dgl/graphbolt/impl/feature_cache.py
index b2a7ea7ca1..ba185c58d4 100644
--- a/python/dgl/graphbolt/impl/feature_cache.py
+++ b/python/dgl/graphbolt/impl/feature_cache.py
@@ -92,6 +92,50 @@ class CPUFeatureCache(object):
         missing_index = index[positions.size(0) :]
         return values, missing_index, missing_keys, missing_offsets
 
+    def query_and_then_replace(self, keys, reader_fn):
+        """Queries the cache. Then inserts the keys that are not found by
+        reading them by calling `reader_fn(missing_keys)`, which are then
+        inserted into the cache using the selected caching policy algorithm
+        to remove the old entries if it is full.
+
+        Parameters
+        ----------
+        keys : Tensor
+            The keys to query the cache with.
+        reader_fn : reader_fn(keys: torch.Tensor) -> torch.Tensor
+            A function that will take a missing keys tensor and will return
+            their values.
+
+        Returns
+        -------
+        Tensor
+            A tensor containing values corresponding to the keys. Should equal
+            `reader_fn(keys)`, computed in a faster way.
+        """
+        self.total_queries += keys.shape[0]
+        (
+            positions,
+            index,
+            pointers,
+            missing_keys,
+            found_offsets,
+            missing_offsets,
+        ) = self._policy.query_and_then_replace(keys)
+        found_cnt = keys.size(0) - missing_keys.size(0)
+        found_positions = positions[:found_cnt]
+        values = self._cache.query(found_positions, index, keys.shape[0])
+        found_pointers = pointers[:found_cnt]
+        self._policy.reading_completed(found_pointers, found_offsets)
+        self.total_miss += missing_keys.shape[0]
+        missing_index = index[found_cnt:]
+        missing_values = reader_fn(missing_keys)
+        values[missing_index] = missing_values
+        missing_positions = positions[found_cnt:]
+        self._cache.replace(missing_positions, missing_values)
+        missing_pointers = pointers[found_cnt:]
+        self._policy.writing_completed(missing_pointers, missing_offsets)
+        return values
+
     def replace(self, keys, values, offsets=None):
         """Inserts key-value pairs into the cache using the selected caching
         policy algorithm to remove old key-value pairs if it is full.
diff --git a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py
index 274ddc4808..5daa4b7af8 100644
--- a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py
+++ b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py
@@ -33,6 +33,10 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
     cache = gb.impl.CPUFeatureCache(
         (cache_size,) + a.shape[1:], a.dtype, policy, num_parts
     )
+    cache2 = gb.impl.CPUFeatureCache(
+        (cache_size,) + a.shape[1:], a.dtype, policy, num_parts
+    )
+    reader_fn = lambda keys: a[keys]
 
     keys = torch.tensor([0, 1])
     values, missing_index, missing_keys, missing_offsets = cache.query(keys)
@@ -48,6 +52,11 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
     values[missing_index] = missing_values
     assert torch.equal(values, a[keys])
 
+    if num_parts == 1:
+        assert torch.equal(
+            cache2.query_and_then_replace(keys, reader_fn), a[keys]
+        )
+
     pin_memory = F._default_context_str == "gpu"
 
     keys = torch.arange(1, 33, pin_memory=pin_memory)
@@ -65,15 +74,10 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
     values[missing_index] = missing_values
     assert torch.equal(values, a[keys])
 
-    values, missing_index, missing_keys, missing_offsets = cache.query(keys)
-    if not offsets:
-        missing_offsets = None
-    assert torch.equal(missing_keys.flip([0]), torch.tensor([]))
-
-    missing_values = a[missing_keys]
-    cache.replace(missing_keys, missing_values, missing_offsets)
-    values[missing_index] = missing_values
-    assert torch.equal(values, a[keys])
+    if num_parts == 1:
+        assert torch.equal(
+            cache2.query_and_then_replace(keys, reader_fn), a[keys]
+        )
 
     values, missing_index, missing_keys, missing_offsets = cache.query(keys)
     if not offsets:
@@ -85,6 +89,27 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
     values[missing_index] = missing_values
     assert torch.equal(values, a[keys])
 
+    if num_parts == 1:
+        assert torch.equal(
+            cache2.query_and_then_replace(keys, reader_fn), a[keys]
+        )
+
+    values, missing_index, missing_keys, missing_offsets = cache.query(keys)
+    if not offsets:
+        missing_offsets = None
+    assert torch.equal(missing_keys.flip([0]), torch.tensor([]))
+
+    missing_values = a[missing_keys]
+    cache.replace(missing_keys, missing_values, missing_offsets)
+    values[missing_index] = missing_values
+    assert torch.equal(values, a[keys])
+
+    if num_parts == 1:
+        assert torch.equal(
+            cache2.query_and_then_replace(keys, reader_fn), a[keys]
+        )
+        assert cache.miss_rate == cache2.miss_rate
+
     raw_feature_cache = torch.ops.graphbolt.feature_cache(
         (cache_size,) + a.shape[1:], a.dtype, pin_memory
     )