diff --git a/graphbolt/src/cache_policy.cc b/graphbolt/src/cache_policy.cc index d313bfcb1c..a707b1a2f2 100644 --- a/graphbolt/src/cache_policy.cc +++ b/graphbolt/src/cache_policy.cc @@ -130,6 +130,8 @@ BaseCachePolicy::QueryAndThenReplaceImpl( // we do not have to check for the uniqueness of the positions. std::get<1>(position_set.insert(it->second->getPos())), "Can't insert all, larger cache capacity is needed."); + } else { + policy.MarkExistingWriting(it); } auto& cache_key = *it->second; positions_ptr[i] = cache_key.getPos(); diff --git a/graphbolt/src/cache_policy.h b/graphbolt/src/cache_policy.h index da24a83513..024ceef059 100644 --- a/graphbolt/src/cache_policy.h +++ b/graphbolt/src/cache_policy.h @@ -264,15 +264,13 @@ class S3FifoCachePolicy : public BaseCachePolicy { std::pair Emplace(int64_t key) { auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue()); - if (it->second != nullptr) { + if (it->second != getMapSentinelValue()) { auto& cache_key = *it->second; if (!cache_key.BeingWritten()) { // Not being written so we use StartUse() and return // true to indicate the key is ready to read. cache_key.Increment().StartUse(); return {it, true}; - } else { - cache_key.Increment().StartUse(); } } // First time insertion, return false to indicate not ready to read. @@ -296,6 +294,10 @@ class S3FifoCachePolicy : public BaseCachePolicy { it->second = queue.Push(CacheKey(key, pos)); } + void MarkExistingWriting(map_iterator it) { + it->second->Increment().StartUse(); + } + template void Unmark(CacheKey* cache_key_ptr) { cache_key_ptr->EndUse(); @@ -414,15 +416,13 @@ class SieveCachePolicy : public BaseCachePolicy { std::pair Emplace(int64_t key) { auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue()); - if (it->second != nullptr) { + if (it->second != getMapSentinelValue()) { auto& cache_key = *it->second; if (!cache_key.BeingWritten()) { // Not being written so we use StartUse() and return // true to indicate the key is ready to read. cache_key.SetFreq().StartUse(); return {it, true}; - } else { - cache_key.SetFreq().StartUse(); } } // First time insertion, return false to indicate not ready to read. @@ -444,6 +444,10 @@ class SieveCachePolicy : public BaseCachePolicy { it->second = &queue_.front(); } + void MarkExistingWriting(map_iterator it) { + it->second->SetFreq().StartUse(); + } + template void Unmark(CacheKey* cache_key_ptr) { cache_key_ptr->EndUse(); @@ -552,16 +556,14 @@ class LruCachePolicy : public BaseCachePolicy { std::pair Emplace(int64_t key) { auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue()); - if (it->second != queue_.end()) { - MoveToFront(it->second); + if (it->second != getMapSentinelValue()) { auto& cache_key = *it->second; if (!cache_key.BeingWritten()) { + MoveToFront(it->second); // Not being written so we use StartUse() and return // true to indicate the key is ready to read. cache_key.StartUse(); return {it, true}; - } else { - cache_key.StartUse(); } } // First time insertion, return false to indicate not ready to read. @@ -582,6 +584,11 @@ class LruCachePolicy : public BaseCachePolicy { it->second = queue_.begin(); } + void MarkExistingWriting(map_iterator it) { + MoveToFront(it->second); + it->second->StartUse(); + } + template void Unmark(CacheKey* cache_key_ptr) { cache_key_ptr->EndUse(); @@ -678,15 +685,13 @@ class ClockCachePolicy : public BaseCachePolicy { std::pair Emplace(int64_t key) { auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue()); - if (it->second != nullptr) { + if (it->second != getMapSentinelValue()) { auto& cache_key = *it->second; if (!cache_key.BeingWritten()) { // Not being written so we use StartUse() and return // true to indicate the key is ready to read. cache_key.SetFreq().StartUse(); return {it, true}; - } else { - cache_key.SetFreq().StartUse(); } } // First time insertion, return false to indicate not ready to read. @@ -706,6 +711,10 @@ class ClockCachePolicy : public BaseCachePolicy { it->second = queue_.Push(CacheKey(key, pos)); } + void MarkExistingWriting(map_iterator it) { + it->second->SetFreq().StartUse(); + } + template void Unmark(CacheKey* cache_key_ptr) { cache_key_ptr->EndUse(); diff --git a/graphbolt/src/partitioned_cache_policy.cc b/graphbolt/src/partitioned_cache_policy.cc index 6717faa350..c416181413 100644 --- a/graphbolt/src/partitioned_cache_policy.cc +++ b/graphbolt/src/partitioned_cache_policy.cc @@ -206,11 +206,10 @@ PartitionedCachePolicy::Query(torch::Tensor keys) { selected_positions_ptr, selected_positions_ptr + num_selected, positions.data_ptr() + begin, [off = tid * capacity_ / policies_.size()](auto x) { return x + off; }); - std::memcpy( - reinterpret_cast(found_pointers.data_ptr()) + - begin * found_pointers.element_size(), - std::get<3>(results[tid]).data_ptr(), - num_selected * found_pointers.element_size()); + auto selected_pointers_ptr = std::get<3>(results[tid]).data_ptr(); + std::copy( + selected_pointers_ptr, selected_pointers_ptr + num_selected, + found_pointers.data_ptr() + begin); begin = result_offsets[policies_.size() + tid]; end = result_offsets[policies_.size() + tid + 1]; missing_offsets[tid + 1] = end - result_offsets[policies_.size()]; @@ -263,7 +262,102 @@ PartitionedCachePolicy::QueryAndThenReplace(torch::Tensor keys) { auto missing_offsets = found_and_missing_offsets.slice(0, 2); return {positions, output_indices, pointers, missing_keys, found_offsets, missing_offsets}; - }; + } + torch::Tensor offsets, indices, permuted_keys; + std::tie(offsets, indices, permuted_keys) = Partition(keys); + auto offsets_ptr = offsets.data_ptr(); + auto indices_ptr = indices.data_ptr(); + std::vector< + std::tuple> + results(policies_.size()); + torch::Tensor result_offsets_tensor = + torch::empty(policies_.size() * 2 + 1, offsets.options()); + auto result_offsets = result_offsets_tensor.data_ptr(); + namespace gb = graphbolt; + { + std::lock_guard lock(mtx_); + gb::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) { + if (begin == end) return; + TORCH_CHECK(end - begin == 1); + const auto tid = begin; + begin = offsets_ptr[tid]; + end = offsets_ptr[tid + 1]; + results[tid] = policies_.at(tid)->QueryAndThenReplace( + permuted_keys.slice(0, begin, end)); + const auto missing_cnt = std::get<3>(results[tid]).size(0); + result_offsets[tid] = end - begin - missing_cnt; + result_offsets[tid + policies_.size()] = missing_cnt; + }); + } + std::exclusive_scan( + result_offsets, result_offsets + result_offsets_tensor.size(0), + result_offsets, 0); + torch::Tensor positions = torch::empty( + keys.size(0), + std::get<0>(results[0]).options().pinned_memory(utils::is_pinned(keys))); + torch::Tensor output_indices = torch::empty_like( + indices, indices.options().pinned_memory(utils::is_pinned(keys))); + torch::Tensor pointers = torch::empty( + keys.size(0), + std::get<2>(results[0]).options().pinned_memory(utils::is_pinned(keys))); + torch::Tensor missing_keys = torch::empty( + result_offsets[2 * policies_.size()] - result_offsets[policies_.size()], + std::get<3>(results[0]).options().pinned_memory(utils::is_pinned(keys))); + auto missing_offsets = + torch::empty(policies_.size() + 1, result_offsets_tensor.options()); + auto positions_ptr = positions.data_ptr(); + auto output_indices_ptr = output_indices.data_ptr(); + auto pointers_ptr = pointers.data_ptr(); + auto missing_offsets_ptr = missing_offsets.data_ptr(); + missing_offsets_ptr[0] = 0; + gb::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) { + if (begin == end) return; + const auto tid = begin; + auto out_index_ptr = indices_ptr + offsets_ptr[tid]; + begin = result_offsets[tid]; + end = result_offsets[tid + 1]; + const auto num_selected = end - begin; + auto indices_ptr = std::get<1>(results[tid]).data_ptr(); + for (int64_t i = 0; i < num_selected; i++) { + output_indices_ptr[begin + i] = out_index_ptr[indices_ptr[i]]; + } + auto selected_positions_ptr = std::get<0>(results[tid]).data_ptr(); + std::transform( + selected_positions_ptr, selected_positions_ptr + num_selected, + positions_ptr + begin, + [off = tid * capacity_ / policies_.size()](auto x) { return x + off; }); + auto selected_pointers_ptr = std::get<2>(results[tid]).data_ptr(); + std::copy( + selected_pointers_ptr, selected_pointers_ptr + num_selected, + pointers_ptr + begin); + begin = result_offsets[policies_.size() + tid]; + end = result_offsets[policies_.size() + tid + 1]; + missing_offsets[tid + 1] = end - result_offsets[policies_.size()]; + const auto num_missing = end - begin; + for (int64_t i = 0; i < num_missing; i++) { + output_indices_ptr[begin + i] = + out_index_ptr[indices_ptr[i + num_selected]]; + } + auto missing_positions_ptr = selected_positions_ptr + num_selected; + std::transform( + missing_positions_ptr, missing_positions_ptr + num_missing, + positions_ptr + begin, + [off = tid * capacity_ / policies_.size()](auto x) { return x + off; }); + auto missing_pointers_ptr = selected_pointers_ptr + num_selected; + std::copy( + missing_pointers_ptr, missing_pointers_ptr + num_missing, + pointers_ptr + begin); + std::memcpy( + reinterpret_cast(missing_keys.data_ptr()) + + (begin - result_offsets[policies_.size()]) * + missing_keys.element_size(), + std::get<3>(results[tid]).data_ptr(), + num_missing * missing_keys.element_size()); + }); + auto found_offsets = result_offsets_tensor.slice(0, 0, policies_.size() + 1); + return std::make_tuple( + positions, output_indices, pointers, missing_keys, found_offsets, + missing_offsets); } c10::intrusive_ptr>> diff --git a/python/dgl/graphbolt/impl/cpu_cached_feature.py b/python/dgl/graphbolt/impl/cpu_cached_feature.py index 7b572aec71..797e02156b 100644 --- a/python/dgl/graphbolt/impl/cpu_cached_feature.py +++ b/python/dgl/graphbolt/impl/cpu_cached_feature.py @@ -75,16 +75,9 @@ class CPUCachedFeature(Feature): """ if ids is None: return self._fallback_feature.read() - ( - values, - missing_index, - missing_keys, - missing_offsets, - ) = self._feature.query(ids) - missing_values = self._fallback_feature.read(missing_keys) - values[missing_index] = missing_values - self._feature.replace(missing_keys, missing_values, missing_offsets) - return values + return self._feature.query_and_then_replace( + ids, self._fallback_feature.read + ) def read_async(self, ids: torch.Tensor): """Read the feature by index asynchronously. diff --git a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py index 5daa4b7af8..07cd749cba 100644 --- a/tests/python/pytorch/graphbolt/impl/test_feature_cache.py +++ b/tests/python/pytorch/graphbolt/impl/test_feature_cache.py @@ -51,11 +51,7 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy): cache.replace(missing_keys, missing_values, missing_offsets) values[missing_index] = missing_values assert torch.equal(values, a[keys]) - - if num_parts == 1: - assert torch.equal( - cache2.query_and_then_replace(keys, reader_fn), a[keys] - ) + assert torch.equal(cache2.query_and_then_replace(keys, reader_fn), a[keys]) pin_memory = F._default_context_str == "gpu" @@ -73,11 +69,7 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy): cache.replace(missing_keys, missing_values, missing_offsets) values[missing_index] = missing_values assert torch.equal(values, a[keys]) - - if num_parts == 1: - assert torch.equal( - cache2.query_and_then_replace(keys, reader_fn), a[keys] - ) + assert torch.equal(cache2.query_and_then_replace(keys, reader_fn), a[keys]) values, missing_index, missing_keys, missing_offsets = cache.query(keys) if not offsets: @@ -88,11 +80,7 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy): cache.replace(missing_keys, missing_values, missing_offsets) values[missing_index] = missing_values assert torch.equal(values, a[keys]) - - if num_parts == 1: - assert torch.equal( - cache2.query_and_then_replace(keys, reader_fn), a[keys] - ) + assert torch.equal(cache2.query_and_then_replace(keys, reader_fn), a[keys]) values, missing_index, missing_keys, missing_offsets = cache.query(keys) if not offsets: @@ -103,12 +91,9 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy): cache.replace(missing_keys, missing_values, missing_offsets) values[missing_index] = missing_values assert torch.equal(values, a[keys]) + assert torch.equal(cache2.query_and_then_replace(keys, reader_fn), a[keys]) - if num_parts == 1: - assert torch.equal( - cache2.query_and_then_replace(keys, reader_fn), a[keys] - ) - assert cache.miss_rate == cache2.miss_rate + assert cache.miss_rate == cache2.miss_rate raw_feature_cache = torch.ops.graphbolt.feature_cache( (cache_size,) + a.shape[1:], a.dtype, pin_memory