mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt] CachePolicy::QueryAndThenReplace for num_parts>1. (#7621)
This commit is contained in:
committed by
GitHub
parent
eafb53013f
commit
cfe0541802
@@ -130,6 +130,8 @@ BaseCachePolicy::QueryAndThenReplaceImpl(
|
||||
// we do not have to check for the uniqueness of the positions.
|
||||
std::get<1>(position_set.insert(it->second->getPos())),
|
||||
"Can't insert all, larger cache capacity is needed.");
|
||||
} else {
|
||||
policy.MarkExistingWriting(it);
|
||||
}
|
||||
auto& cache_key = *it->second;
|
||||
positions_ptr[i] = cache_key.getPos();
|
||||
|
||||
@@ -264,15 +264,13 @@ class S3FifoCachePolicy : public BaseCachePolicy {
|
||||
|
||||
std::pair<map_iterator, bool> Emplace(int64_t key) {
|
||||
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
|
||||
if (it->second != nullptr) {
|
||||
if (it->second != getMapSentinelValue()) {
|
||||
auto& cache_key = *it->second;
|
||||
if (!cache_key.BeingWritten()) {
|
||||
// Not being written so we use StartUse<write=false>() and return
|
||||
// true to indicate the key is ready to read.
|
||||
cache_key.Increment().StartUse<false>();
|
||||
return {it, true};
|
||||
} else {
|
||||
cache_key.Increment().StartUse<true>();
|
||||
}
|
||||
}
|
||||
// First time insertion, return false to indicate not ready to read.
|
||||
@@ -296,6 +294,10 @@ class S3FifoCachePolicy : public BaseCachePolicy {
|
||||
it->second = queue.Push(CacheKey(key, pos));
|
||||
}
|
||||
|
||||
void MarkExistingWriting(map_iterator it) {
|
||||
it->second->Increment().StartUse<true>();
|
||||
}
|
||||
|
||||
template <bool write>
|
||||
void Unmark(CacheKey* cache_key_ptr) {
|
||||
cache_key_ptr->EndUse<write>();
|
||||
@@ -414,15 +416,13 @@ class SieveCachePolicy : public BaseCachePolicy {
|
||||
|
||||
std::pair<map_iterator, bool> Emplace(int64_t key) {
|
||||
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
|
||||
if (it->second != nullptr) {
|
||||
if (it->second != getMapSentinelValue()) {
|
||||
auto& cache_key = *it->second;
|
||||
if (!cache_key.BeingWritten()) {
|
||||
// Not being written so we use StartUse<write=false>() and return
|
||||
// true to indicate the key is ready to read.
|
||||
cache_key.SetFreq().StartUse<false>();
|
||||
return {it, true};
|
||||
} else {
|
||||
cache_key.SetFreq().StartUse<true>();
|
||||
}
|
||||
}
|
||||
// First time insertion, return false to indicate not ready to read.
|
||||
@@ -444,6 +444,10 @@ class SieveCachePolicy : public BaseCachePolicy {
|
||||
it->second = &queue_.front();
|
||||
}
|
||||
|
||||
void MarkExistingWriting(map_iterator it) {
|
||||
it->second->SetFreq().StartUse<true>();
|
||||
}
|
||||
|
||||
template <bool write>
|
||||
void Unmark(CacheKey* cache_key_ptr) {
|
||||
cache_key_ptr->EndUse<write>();
|
||||
@@ -552,16 +556,14 @@ class LruCachePolicy : public BaseCachePolicy {
|
||||
|
||||
std::pair<map_iterator, bool> Emplace(int64_t key) {
|
||||
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
|
||||
if (it->second != queue_.end()) {
|
||||
MoveToFront(it->second);
|
||||
if (it->second != getMapSentinelValue()) {
|
||||
auto& cache_key = *it->second;
|
||||
if (!cache_key.BeingWritten()) {
|
||||
MoveToFront(it->second);
|
||||
// Not being written so we use StartUse<write=false>() and return
|
||||
// true to indicate the key is ready to read.
|
||||
cache_key.StartUse<false>();
|
||||
return {it, true};
|
||||
} else {
|
||||
cache_key.StartUse<true>();
|
||||
}
|
||||
}
|
||||
// First time insertion, return false to indicate not ready to read.
|
||||
@@ -582,6 +584,11 @@ class LruCachePolicy : public BaseCachePolicy {
|
||||
it->second = queue_.begin();
|
||||
}
|
||||
|
||||
void MarkExistingWriting(map_iterator it) {
|
||||
MoveToFront(it->second);
|
||||
it->second->StartUse<true>();
|
||||
}
|
||||
|
||||
template <bool write>
|
||||
void Unmark(CacheKey* cache_key_ptr) {
|
||||
cache_key_ptr->EndUse<write>();
|
||||
@@ -678,15 +685,13 @@ class ClockCachePolicy : public BaseCachePolicy {
|
||||
|
||||
std::pair<map_iterator, bool> Emplace(int64_t key) {
|
||||
auto [it, _] = key_to_cache_key_.emplace(key, getMapSentinelValue());
|
||||
if (it->second != nullptr) {
|
||||
if (it->second != getMapSentinelValue()) {
|
||||
auto& cache_key = *it->second;
|
||||
if (!cache_key.BeingWritten()) {
|
||||
// Not being written so we use StartUse<write=false>() and return
|
||||
// true to indicate the key is ready to read.
|
||||
cache_key.SetFreq().StartUse<false>();
|
||||
return {it, true};
|
||||
} else {
|
||||
cache_key.SetFreq().StartUse<true>();
|
||||
}
|
||||
}
|
||||
// First time insertion, return false to indicate not ready to read.
|
||||
@@ -706,6 +711,10 @@ class ClockCachePolicy : public BaseCachePolicy {
|
||||
it->second = queue_.Push(CacheKey(key, pos));
|
||||
}
|
||||
|
||||
void MarkExistingWriting(map_iterator it) {
|
||||
it->second->SetFreq().StartUse<true>();
|
||||
}
|
||||
|
||||
template <bool write>
|
||||
void Unmark(CacheKey* cache_key_ptr) {
|
||||
cache_key_ptr->EndUse<write>();
|
||||
|
||||
@@ -206,11 +206,10 @@ PartitionedCachePolicy::Query(torch::Tensor keys) {
|
||||
selected_positions_ptr, selected_positions_ptr + num_selected,
|
||||
positions.data_ptr<int64_t>() + begin,
|
||||
[off = tid * capacity_ / policies_.size()](auto x) { return x + off; });
|
||||
std::memcpy(
|
||||
reinterpret_cast<std::byte*>(found_pointers.data_ptr()) +
|
||||
begin * found_pointers.element_size(),
|
||||
std::get<3>(results[tid]).data_ptr(),
|
||||
num_selected * found_pointers.element_size());
|
||||
auto selected_pointers_ptr = std::get<3>(results[tid]).data_ptr<int64_t>();
|
||||
std::copy(
|
||||
selected_pointers_ptr, selected_pointers_ptr + num_selected,
|
||||
found_pointers.data_ptr<int64_t>() + begin);
|
||||
begin = result_offsets[policies_.size() + tid];
|
||||
end = result_offsets[policies_.size() + tid + 1];
|
||||
missing_offsets[tid + 1] = end - result_offsets[policies_.size()];
|
||||
@@ -263,7 +262,102 @@ PartitionedCachePolicy::QueryAndThenReplace(torch::Tensor keys) {
|
||||
auto missing_offsets = found_and_missing_offsets.slice(0, 2);
|
||||
return {positions, output_indices, pointers,
|
||||
missing_keys, found_offsets, missing_offsets};
|
||||
};
|
||||
}
|
||||
torch::Tensor offsets, indices, permuted_keys;
|
||||
std::tie(offsets, indices, permuted_keys) = Partition(keys);
|
||||
auto offsets_ptr = offsets.data_ptr<int64_t>();
|
||||
auto indices_ptr = indices.data_ptr<int64_t>();
|
||||
std::vector<
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>
|
||||
results(policies_.size());
|
||||
torch::Tensor result_offsets_tensor =
|
||||
torch::empty(policies_.size() * 2 + 1, offsets.options());
|
||||
auto result_offsets = result_offsets_tensor.data_ptr<int64_t>();
|
||||
namespace gb = graphbolt;
|
||||
{
|
||||
std::lock_guard lock(mtx_);
|
||||
gb::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
|
||||
if (begin == end) return;
|
||||
TORCH_CHECK(end - begin == 1);
|
||||
const auto tid = begin;
|
||||
begin = offsets_ptr[tid];
|
||||
end = offsets_ptr[tid + 1];
|
||||
results[tid] = policies_.at(tid)->QueryAndThenReplace(
|
||||
permuted_keys.slice(0, begin, end));
|
||||
const auto missing_cnt = std::get<3>(results[tid]).size(0);
|
||||
result_offsets[tid] = end - begin - missing_cnt;
|
||||
result_offsets[tid + policies_.size()] = missing_cnt;
|
||||
});
|
||||
}
|
||||
std::exclusive_scan(
|
||||
result_offsets, result_offsets + result_offsets_tensor.size(0),
|
||||
result_offsets, 0);
|
||||
torch::Tensor positions = torch::empty(
|
||||
keys.size(0),
|
||||
std::get<0>(results[0]).options().pinned_memory(utils::is_pinned(keys)));
|
||||
torch::Tensor output_indices = torch::empty_like(
|
||||
indices, indices.options().pinned_memory(utils::is_pinned(keys)));
|
||||
torch::Tensor pointers = torch::empty(
|
||||
keys.size(0),
|
||||
std::get<2>(results[0]).options().pinned_memory(utils::is_pinned(keys)));
|
||||
torch::Tensor missing_keys = torch::empty(
|
||||
result_offsets[2 * policies_.size()] - result_offsets[policies_.size()],
|
||||
std::get<3>(results[0]).options().pinned_memory(utils::is_pinned(keys)));
|
||||
auto missing_offsets =
|
||||
torch::empty(policies_.size() + 1, result_offsets_tensor.options());
|
||||
auto positions_ptr = positions.data_ptr<int64_t>();
|
||||
auto output_indices_ptr = output_indices.data_ptr<int64_t>();
|
||||
auto pointers_ptr = pointers.data_ptr<int64_t>();
|
||||
auto missing_offsets_ptr = missing_offsets.data_ptr<int64_t>();
|
||||
missing_offsets_ptr[0] = 0;
|
||||
gb::parallel_for(0, policies_.size(), 1, [&](int64_t begin, int64_t end) {
|
||||
if (begin == end) return;
|
||||
const auto tid = begin;
|
||||
auto out_index_ptr = indices_ptr + offsets_ptr[tid];
|
||||
begin = result_offsets[tid];
|
||||
end = result_offsets[tid + 1];
|
||||
const auto num_selected = end - begin;
|
||||
auto indices_ptr = std::get<1>(results[tid]).data_ptr<int64_t>();
|
||||
for (int64_t i = 0; i < num_selected; i++) {
|
||||
output_indices_ptr[begin + i] = out_index_ptr[indices_ptr[i]];
|
||||
}
|
||||
auto selected_positions_ptr = std::get<0>(results[tid]).data_ptr<int64_t>();
|
||||
std::transform(
|
||||
selected_positions_ptr, selected_positions_ptr + num_selected,
|
||||
positions_ptr + begin,
|
||||
[off = tid * capacity_ / policies_.size()](auto x) { return x + off; });
|
||||
auto selected_pointers_ptr = std::get<2>(results[tid]).data_ptr<int64_t>();
|
||||
std::copy(
|
||||
selected_pointers_ptr, selected_pointers_ptr + num_selected,
|
||||
pointers_ptr + begin);
|
||||
begin = result_offsets[policies_.size() + tid];
|
||||
end = result_offsets[policies_.size() + tid + 1];
|
||||
missing_offsets[tid + 1] = end - result_offsets[policies_.size()];
|
||||
const auto num_missing = end - begin;
|
||||
for (int64_t i = 0; i < num_missing; i++) {
|
||||
output_indices_ptr[begin + i] =
|
||||
out_index_ptr[indices_ptr[i + num_selected]];
|
||||
}
|
||||
auto missing_positions_ptr = selected_positions_ptr + num_selected;
|
||||
std::transform(
|
||||
missing_positions_ptr, missing_positions_ptr + num_missing,
|
||||
positions_ptr + begin,
|
||||
[off = tid * capacity_ / policies_.size()](auto x) { return x + off; });
|
||||
auto missing_pointers_ptr = selected_pointers_ptr + num_selected;
|
||||
std::copy(
|
||||
missing_pointers_ptr, missing_pointers_ptr + num_missing,
|
||||
pointers_ptr + begin);
|
||||
std::memcpy(
|
||||
reinterpret_cast<std::byte*>(missing_keys.data_ptr()) +
|
||||
(begin - result_offsets[policies_.size()]) *
|
||||
missing_keys.element_size(),
|
||||
std::get<3>(results[tid]).data_ptr(),
|
||||
num_missing * missing_keys.element_size());
|
||||
});
|
||||
auto found_offsets = result_offsets_tensor.slice(0, 0, policies_.size() + 1);
|
||||
return std::make_tuple(
|
||||
positions, output_indices, pointers, missing_keys, found_offsets,
|
||||
missing_offsets);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<Future<std::vector<torch::Tensor>>>
|
||||
|
||||
@@ -75,16 +75,9 @@ class CPUCachedFeature(Feature):
|
||||
"""
|
||||
if ids is None:
|
||||
return self._fallback_feature.read()
|
||||
(
|
||||
values,
|
||||
missing_index,
|
||||
missing_keys,
|
||||
missing_offsets,
|
||||
) = self._feature.query(ids)
|
||||
missing_values = self._fallback_feature.read(missing_keys)
|
||||
values[missing_index] = missing_values
|
||||
self._feature.replace(missing_keys, missing_values, missing_offsets)
|
||||
return values
|
||||
return self._feature.query_and_then_replace(
|
||||
ids, self._fallback_feature.read
|
||||
)
|
||||
|
||||
def read_async(self, ids: torch.Tensor):
|
||||
"""Read the feature by index asynchronously.
|
||||
|
||||
@@ -51,11 +51,7 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
|
||||
cache.replace(missing_keys, missing_values, missing_offsets)
|
||||
values[missing_index] = missing_values
|
||||
assert torch.equal(values, a[keys])
|
||||
|
||||
if num_parts == 1:
|
||||
assert torch.equal(
|
||||
cache2.query_and_then_replace(keys, reader_fn), a[keys]
|
||||
)
|
||||
assert torch.equal(cache2.query_and_then_replace(keys, reader_fn), a[keys])
|
||||
|
||||
pin_memory = F._default_context_str == "gpu"
|
||||
|
||||
@@ -73,11 +69,7 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
|
||||
cache.replace(missing_keys, missing_values, missing_offsets)
|
||||
values[missing_index] = missing_values
|
||||
assert torch.equal(values, a[keys])
|
||||
|
||||
if num_parts == 1:
|
||||
assert torch.equal(
|
||||
cache2.query_and_then_replace(keys, reader_fn), a[keys]
|
||||
)
|
||||
assert torch.equal(cache2.query_and_then_replace(keys, reader_fn), a[keys])
|
||||
|
||||
values, missing_index, missing_keys, missing_offsets = cache.query(keys)
|
||||
if not offsets:
|
||||
@@ -88,11 +80,7 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
|
||||
cache.replace(missing_keys, missing_values, missing_offsets)
|
||||
values[missing_index] = missing_values
|
||||
assert torch.equal(values, a[keys])
|
||||
|
||||
if num_parts == 1:
|
||||
assert torch.equal(
|
||||
cache2.query_and_then_replace(keys, reader_fn), a[keys]
|
||||
)
|
||||
assert torch.equal(cache2.query_and_then_replace(keys, reader_fn), a[keys])
|
||||
|
||||
values, missing_index, missing_keys, missing_offsets = cache.query(keys)
|
||||
if not offsets:
|
||||
@@ -103,12 +91,9 @@ def test_feature_cache(offsets, dtype, feature_size, num_parts, policy):
|
||||
cache.replace(missing_keys, missing_values, missing_offsets)
|
||||
values[missing_index] = missing_values
|
||||
assert torch.equal(values, a[keys])
|
||||
assert torch.equal(cache2.query_and_then_replace(keys, reader_fn), a[keys])
|
||||
|
||||
if num_parts == 1:
|
||||
assert torch.equal(
|
||||
cache2.query_and_then_replace(keys, reader_fn), a[keys]
|
||||
)
|
||||
assert cache.miss_rate == cache2.miss_rate
|
||||
assert cache.miss_rate == cache2.miss_rate
|
||||
|
||||
raw_feature_cache = torch.ops.graphbolt.feature_cache(
|
||||
(cache_size,) + a.shape[1:], a.dtype, pin_memory
|
||||
|
||||
Reference in New Issue
Block a user