mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt][CUDA] Fix overlap_graph_fetch edge_ids cached case. (#7595)
This commit is contained in:
committed by
GitHub
parent
bea83ed738
commit
db13f05f58
@@ -148,6 +148,7 @@ GpuGraphCache::GpuGraphCache(
|
||||
num_edges_ = 0;
|
||||
indptr_ =
|
||||
torch::zeros(initial_node_capacity + 1, options.dtype(indptr_dtype));
|
||||
offset_ = torch::empty(indptr_.size(0) - 1, indptr_.options());
|
||||
for (auto dtype : dtypes) {
|
||||
cached_edge_tensors_.push_back(
|
||||
torch::empty(num_edges, options.dtype(dtype)));
|
||||
@@ -240,7 +241,8 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
|
||||
torch::Tensor seeds, torch::Tensor indices, torch::Tensor positions,
|
||||
int64_t num_hit, int64_t num_threshold, torch::Tensor indptr,
|
||||
std::vector<torch::Tensor> edge_tensors) {
|
||||
const auto num_tensors = edge_tensors.size();
|
||||
// The last element of edge_tensors has the edge ids.
|
||||
const auto num_tensors = edge_tensors.size() - 1;
|
||||
TORCH_CHECK(
|
||||
num_tensors == cached_edge_tensors_.size(),
|
||||
"Same number of tensors need to be passed!");
|
||||
@@ -301,14 +303,21 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
|
||||
auto input = allocator.AllocateStorage<std::byte*>(num_buffers);
|
||||
auto input_size =
|
||||
allocator.AllocateStorage<size_t>(num_buffers + 1);
|
||||
auto edge_id_offsets = torch::empty(
|
||||
num_nodes, seeds.options().dtype(offset_.scalar_type()));
|
||||
const auto cache_missing_dtype_dev_ptr =
|
||||
cache_missing_dtype_dev.get();
|
||||
const auto indices_ptr = indices.data_ptr<indices_t>();
|
||||
const auto positions_ptr = positions.data_ptr<indices_t>();
|
||||
const auto input_ptr = input.get();
|
||||
const auto input_size_ptr = input_size.get();
|
||||
const auto edge_id_offsets_ptr =
|
||||
edge_id_offsets.data_ptr<indptr_t>();
|
||||
const auto cache_indptr = indptr_.data_ptr<indptr_t>();
|
||||
const auto missing_indptr = indptr.data_ptr<indptr_t>();
|
||||
const auto cache_offset = offset_.data_ptr<indptr_t>();
|
||||
const auto missing_edge_ids =
|
||||
edge_tensors.back().data_ptr<indptr_t>();
|
||||
CUB_CALL(DeviceFor::Bulk, num_buffers, [=] __device__(int64_t i) {
|
||||
const auto tensor_idx = i / num_nodes;
|
||||
const auto idx = i % num_nodes;
|
||||
@@ -322,11 +331,16 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
|
||||
const auto offset_end = is_cached
|
||||
? cache_indptr[pos + 1]
|
||||
: missing_indptr[idx - num_hit + 1];
|
||||
const auto edge_id =
|
||||
is_cached ? cache_offset[pos] : missing_edge_ids[offset];
|
||||
const auto out_idx = tensor_idx * num_nodes + original_idx;
|
||||
|
||||
input_ptr[out_idx] =
|
||||
(is_cached ? cache_ptr : missing_ptr) + offset * size;
|
||||
input_size_ptr[out_idx] = size * (offset_end - offset);
|
||||
if (i < num_nodes) {
|
||||
edge_id_offsets_ptr[out_idx] = edge_id;
|
||||
}
|
||||
});
|
||||
auto output_indptr = torch::empty(
|
||||
num_nodes + 1, seeds.options().dtype(indptr_.scalar_type()));
|
||||
@@ -367,11 +381,15 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
|
||||
indptr_.size(0) * kIntGrowthFactor, indptr_.options());
|
||||
new_indptr.slice(0, 0, indptr_.size(0)) = indptr_;
|
||||
indptr_ = new_indptr;
|
||||
auto new_offset =
|
||||
torch::empty(indptr_.size(0) - 1, offset_.options());
|
||||
new_offset.slice(0, 0, offset_.size(0)) = offset_;
|
||||
offset_ = new_offset;
|
||||
}
|
||||
torch::Tensor sindptr;
|
||||
bool enough_space;
|
||||
torch::optional<int64_t> cached_output_size;
|
||||
for (size_t i = 0; i < edge_tensors.size(); i++) {
|
||||
for (size_t i = 0; i < num_tensors; i++) {
|
||||
torch::Tensor sindices;
|
||||
std::tie(sindptr, sindices) = ops::IndexSelectCSCImpl(
|
||||
in_degree, sliced_indptr, edge_tensors[i], output_indices,
|
||||
@@ -388,12 +406,21 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
|
||||
}
|
||||
if (enough_space) {
|
||||
auto num_edges = num_edges_;
|
||||
THRUST_CALL(
|
||||
transform, sindptr.data_ptr<indptr_t>() + 1,
|
||||
sindptr.data_ptr<indptr_t>() + sindptr.size(0),
|
||||
auto transform_input_it = thrust::make_zip_iterator(
|
||||
sindptr.data_ptr<indptr_t>() + 1,
|
||||
sliced_indptr.data_ptr<indptr_t>());
|
||||
auto transform_output_it = thrust::make_zip_iterator(
|
||||
indptr_.data_ptr<indptr_t>() + num_nodes_ + 1,
|
||||
[=] __host__ __device__(indptr_t x) {
|
||||
return x + num_edges;
|
||||
offset_.data_ptr<indptr_t>() + num_nodes_);
|
||||
THRUST_CALL(
|
||||
transform, transform_input_it,
|
||||
transform_input_it + sindptr.size(0) - 1,
|
||||
transform_output_it,
|
||||
[=] __host__ __device__(
|
||||
const thrust::tuple<indptr_t, indptr_t>& x) {
|
||||
return thrust::make_tuple(
|
||||
thrust::get<0>(x) + num_edges,
|
||||
missing_edge_ids[thrust::get<1>(x)]);
|
||||
});
|
||||
auto map = reinterpret_cast<map_t<indices_t>*>(map_);
|
||||
const dim3 block(kIntBlockSize);
|
||||
@@ -431,6 +458,10 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> GpuGraphCache::Replace(
|
||||
.view(edge_tensors[i].scalar_type())
|
||||
.slice(0, 0, static_cast<indptr_t>(output_size)));
|
||||
}
|
||||
// Append the edge ids as the last element of the output.
|
||||
output_edge_tensors.push_back(ops::IndptrEdgeIdsImpl(
|
||||
output_indptr, output_indptr.scalar_type(), edge_id_offsets,
|
||||
static_cast<int64_t>(static_cast<indptr_t>(output_size))));
|
||||
|
||||
{
|
||||
thrust::counting_iterator<int64_t> iota{0};
|
||||
|
||||
@@ -84,7 +84,9 @@ class GpuGraphCache : public torch::CustomClassHolder {
|
||||
* @param num_threshold The number of seeds among the missing node ids that
|
||||
* will be inserted into the cache.
|
||||
* @param indptr The indptr for the missing seeds fetched from remote.
|
||||
* @param edge_tensors The edge tensors for the missing seeds.
|
||||
* @param edge_tensors The edge tensors for the missing seeds. The last
|
||||
* element of edge_tensors is treated as the edge ids tensor with
|
||||
* indptr_dtype.
|
||||
*
|
||||
* @return (torch::Tensor, std::vector<torch::Tensor>) The final indptr and
|
||||
* edge_tensors, directly corresponding to the seeds tensor.
|
||||
@@ -106,6 +108,7 @@ class GpuGraphCache : public torch::CustomClassHolder {
|
||||
int64_t num_nodes_; // The number of cached nodes in the cache.
|
||||
int64_t num_edges_; // The number of cached edges in the cache.
|
||||
torch::Tensor indptr_; // The cached graph structure indptr tensor.
|
||||
torch::Tensor offset_; // The original graph's sliced_indptr tensor.
|
||||
std::vector<torch::Tensor> cached_edge_tensors_; // The cached graph
|
||||
// structure edge tensors.
|
||||
};
|
||||
|
||||
@@ -64,6 +64,7 @@ class CombineCachedAndFetchedInSubgraph(Mapper):
|
||||
probs_or_mask = subgraph.edge_attribute(self.prob_name)
|
||||
if probs_or_mask is not None:
|
||||
edge_tensors.append(probs_or_mask)
|
||||
edge_tensors.append(subgraph.edge_attribute(ORIGINAL_EDGE_ID))
|
||||
|
||||
subgraph.csc_indptr, edge_tensors = minibatch._replace(
|
||||
subgraph.csc_indptr, edge_tensors
|
||||
@@ -78,11 +79,9 @@ class CombineCachedAndFetchedInSubgraph(Mapper):
|
||||
if probs_or_mask is not None:
|
||||
subgraph.add_edge_attribute(self.prob_name, edge_tensors[0])
|
||||
edge_tensors = edge_tensors[1:]
|
||||
subgraph.add_edge_attribute(ORIGINAL_EDGE_ID, edge_tensors[0])
|
||||
edge_tensors = edge_tensors[1:]
|
||||
assert len(edge_tensors) == 0
|
||||
# TODO @mfbalin: remove these lines after fixing cache for edge ids.
|
||||
edge_attributes = subgraph.edge_attributes
|
||||
edge_attributes.pop(ORIGINAL_EDGE_ID)
|
||||
subgraph.edge_attributes = edge_attributes
|
||||
|
||||
return minibatch
|
||||
|
||||
|
||||
@@ -55,24 +55,22 @@ def test_gpu_graph_cache(indptr_dtype, dtype, cache_size):
|
||||
torch.arange(2, dtype=indices_dtype, device=F.ctx()) + i * 2
|
||||
) % (indptr.size(0) - 1)
|
||||
missing_keys, replace = g.query(keys)
|
||||
missing_edge_tensors = []
|
||||
for e in edge_tensors:
|
||||
missing_indptr, missing_e = torch.ops.graphbolt.index_select_csc(
|
||||
indptr, e, missing_keys, None
|
||||
)
|
||||
missing_edge_tensors.append(missing_e)
|
||||
|
||||
(
|
||||
missing_indptr,
|
||||
missing_edge_tensors,
|
||||
) = torch.ops.graphbolt.index_select_csc_batched(
|
||||
indptr, edge_tensors, missing_keys, True, None
|
||||
)
|
||||
output_indptr, output_edge_tensors = replace(
|
||||
missing_indptr, missing_edge_tensors
|
||||
)
|
||||
|
||||
reference_edge_tensors = []
|
||||
for e in edge_tensors:
|
||||
(
|
||||
reference_indptr,
|
||||
reference_e,
|
||||
) = torch.ops.graphbolt.index_select_csc(indptr, e, keys, None)
|
||||
reference_edge_tensors.append(reference_e)
|
||||
(
|
||||
reference_indptr,
|
||||
reference_edge_tensors,
|
||||
) = torch.ops.graphbolt.index_select_csc_batched(
|
||||
indptr, edge_tensors, keys, True, None
|
||||
)
|
||||
|
||||
assert torch.equal(output_indptr, reference_indptr)
|
||||
assert len(output_edge_tensors) == len(reference_edge_tensors)
|
||||
|
||||
@@ -91,8 +91,7 @@ def test_NeighborSampler_GraphFetch(
|
||||
new_results = list(datapipe)
|
||||
assert len(expected_results) == len(new_results)
|
||||
for a, b in zip(expected_results, new_results):
|
||||
# TODO @mfbalin: Fix the edge id bug and enable this test.
|
||||
assert num_cached_edges != 0 or repr(a) == repr(b)
|
||||
assert repr(a) == repr(b)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("layer_dependency", [False, True])
|
||||
|
||||
@@ -164,11 +164,7 @@ def test_gpu_sampling_DataLoader(
|
||||
assert "a" in minibatch.node_features
|
||||
assert "b" in minibatch.node_features
|
||||
assert "c" in minibatch.node_features
|
||||
# TODO @mfbalin: enable this if.
|
||||
if (
|
||||
num_gpu_cached_edges == 0
|
||||
and sampler_name == "LayerNeighborSampler"
|
||||
):
|
||||
if sampler_name == "LayerNeighborSampler":
|
||||
assert torch.equal(
|
||||
minibatch.node_features["a"], minibatch2.node_features["a"]
|
||||
)
|
||||
@@ -176,10 +172,6 @@ def test_gpu_sampling_DataLoader(
|
||||
assert "d" in minibatch.edge_features[layer_id]
|
||||
edge_feature = minibatch.edge_features[layer_id]["d"]
|
||||
edge_feature_ref = minibatch2.edge_features[layer_id]["d"]
|
||||
# TODO @mfbalin: enable this if.
|
||||
if (
|
||||
num_gpu_cached_edges == 0
|
||||
and sampler_name == "LayerNeighborSampler"
|
||||
):
|
||||
if sampler_name == "LayerNeighborSampler":
|
||||
assert torch.equal(edge_feature, edge_feature_ref)
|
||||
assert len(list(dataloader)) == N // B
|
||||
|
||||
Reference in New Issue
Block a user