mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt] Implement proper IndexSelectCSC for CPU. (#7670)
This commit is contained in:
committed by
GitHub
parent
b5ee45fd1a
commit
ce29f5814b
@@ -19,6 +19,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "./expand_indptr.h"
|
||||
#include "./index_select.h"
|
||||
#include "./macro.h"
|
||||
#include "./random.h"
|
||||
#include "./shared_memory_helper.h"
|
||||
@@ -293,48 +294,21 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
|
||||
return ops::InSubgraph(indptr_, indices_, nodes, type_per_edge_);
|
||||
});
|
||||
}
|
||||
using namespace torch::indexing;
|
||||
const int32_t kDefaultGrainSize = 100;
|
||||
const auto num_seeds = nodes.size(0);
|
||||
torch::Tensor indptr = torch::empty({num_seeds + 1}, indptr_.dtype());
|
||||
std::vector<torch::Tensor> indices_arr(num_seeds);
|
||||
std::vector<torch::Tensor> edge_ids_arr(num_seeds);
|
||||
std::vector<torch::Tensor> type_per_edge_arr(num_seeds);
|
||||
std::vector<torch::Tensor> tensors{indices_};
|
||||
if (type_per_edge_.has_value()) {
|
||||
tensors.push_back(*type_per_edge_);
|
||||
}
|
||||
|
||||
AT_DISPATCH_INDEX_TYPES(
|
||||
indptr_.scalar_type(), "InSubgraph::indptr", ([&] {
|
||||
const auto indptr_data = indptr_.data_ptr<index_t>();
|
||||
auto out_indptr_data = indptr.data_ptr<index_t>();
|
||||
out_indptr_data[0] = 0;
|
||||
AT_DISPATCH_INDEX_TYPES(
|
||||
nodes.scalar_type(), "InSubgraph::nodes", ([&] {
|
||||
const auto nodes_data = nodes.data_ptr<index_t>();
|
||||
torch::parallel_for(
|
||||
0, num_seeds, kDefaultGrainSize,
|
||||
[&](size_t start, size_t end) {
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
const auto node_id = nodes_data[i];
|
||||
const auto start_idx = indptr_data[node_id];
|
||||
const auto end_idx = indptr_data[node_id + 1];
|
||||
out_indptr_data[i + 1] = end_idx - start_idx;
|
||||
indices_arr[i] = indices_.slice(0, start_idx, end_idx);
|
||||
edge_ids_arr[i] = torch::arange(
|
||||
start_idx, end_idx, indptr_.scalar_type());
|
||||
if (type_per_edge_) {
|
||||
type_per_edge_arr[i] =
|
||||
type_per_edge_.value().slice(0, start_idx, end_idx);
|
||||
}
|
||||
}
|
||||
});
|
||||
}));
|
||||
}));
|
||||
auto [output_indptr, results] =
|
||||
ops::IndexSelectCSCBatched(indptr_, tensors, nodes, true, torch::nullopt);
|
||||
torch::optional<torch::Tensor> type_per_edge;
|
||||
if (type_per_edge_.has_value()) {
|
||||
type_per_edge = results.at(1);
|
||||
}
|
||||
|
||||
return c10::make_intrusive<FusedSampledSubgraph>(
|
||||
indptr.cumsum(0), torch::cat(indices_arr), torch::cat(edge_ids_arr),
|
||||
nodes, torch::arange(0, NumNodes()),
|
||||
type_per_edge_
|
||||
? torch::optional<torch::Tensor>{torch::cat(type_per_edge_arr)}
|
||||
: torch::nullopt);
|
||||
output_indptr, results.at(0), results.back(), nodes,
|
||||
torch::arange(0, NumNodes()), type_per_edge);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -8,6 +8,9 @@
|
||||
#include <graphbolt/cuda_ops.h>
|
||||
#include <graphbolt/fused_csc_sampling_graph.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <numeric>
|
||||
|
||||
#include "./macro.h"
|
||||
#include "./utils.h"
|
||||
|
||||
@@ -107,9 +110,9 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
|
||||
c10::DeviceType::CUDA, "IndexSelectCSCImpl",
|
||||
{ return IndexSelectCSCImpl(indptr, indices, nodes, output_size); });
|
||||
}
|
||||
sampling::FusedCSCSamplingGraph g(indptr, indices);
|
||||
const auto res = g.InSubgraph(nodes);
|
||||
return std::make_tuple(res->indptr, res->indices.value());
|
||||
auto [output_indptr, results] = IndexSelectCSCBatched(
|
||||
indptr, std::vector{indices}, nodes, false, output_size);
|
||||
return std::make_tuple(output_indptr, results.at(0));
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
|
||||
@@ -129,17 +132,78 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
|
||||
indptr, indices_list, nodes, with_edge_ids, output_size);
|
||||
});
|
||||
}
|
||||
constexpr int kDefaultGrainSize = 128;
|
||||
const auto num_nodes = nodes.size(0);
|
||||
torch::Tensor output_indptr = torch::empty(
|
||||
{num_nodes + 1}, nodes.options().dtype(indptr.scalar_type()));
|
||||
std::vector<torch::Tensor> results;
|
||||
torch::Tensor output_indptr;
|
||||
torch::Tensor edge_ids;
|
||||
for (auto& indices : indices_list) {
|
||||
sampling::FusedCSCSamplingGraph g(indptr, indices);
|
||||
const auto res = g.InSubgraph(nodes);
|
||||
output_indptr = res->indptr;
|
||||
results.push_back(res->indices.value());
|
||||
edge_ids = res->original_edge_ids;
|
||||
}
|
||||
if (with_edge_ids) results.push_back(edge_ids);
|
||||
torch::optional<torch::Tensor> edge_ids;
|
||||
AT_DISPATCH_INDEX_TYPES(
|
||||
indptr.scalar_type(), "IndexSelectCSCBatched::indptr", ([&] {
|
||||
using indptr_t = index_t;
|
||||
const auto indptr_data = indptr.data_ptr<indptr_t>();
|
||||
auto out_indptr_data = output_indptr.data_ptr<indptr_t>();
|
||||
out_indptr_data[0] = 0;
|
||||
AT_DISPATCH_INDEX_TYPES(
|
||||
nodes.scalar_type(), "IndexSelectCSCBatched::nodes", ([&] {
|
||||
const auto nodes_data = nodes.data_ptr<index_t>();
|
||||
torch::parallel_for(
|
||||
0, num_nodes, kDefaultGrainSize,
|
||||
[&](int64_t begin, int64_t end) {
|
||||
for (int64_t i = begin; i < end; i++) {
|
||||
const auto node_id = nodes_data[i];
|
||||
const auto degree =
|
||||
indptr_data[node_id + 1] - indptr_data[node_id];
|
||||
out_indptr_data[i + 1] = degree;
|
||||
}
|
||||
});
|
||||
output_indptr = output_indptr.cumsum(0, indptr.scalar_type());
|
||||
out_indptr_data = output_indptr.data_ptr<indptr_t>();
|
||||
TORCH_CHECK(
|
||||
!output_size.has_value() ||
|
||||
out_indptr_data[num_nodes] == *output_size,
|
||||
"An incorrect output_size argument was provided.");
|
||||
output_size = out_indptr_data[num_nodes];
|
||||
for (const auto& indices : indices_list) {
|
||||
results.push_back(torch::empty(
|
||||
*output_size,
|
||||
nodes.options().dtype(indices.scalar_type())));
|
||||
}
|
||||
if (with_edge_ids) {
|
||||
edge_ids = torch::empty(
|
||||
*output_size, nodes.options().dtype(indptr.scalar_type()));
|
||||
}
|
||||
torch::parallel_for(
|
||||
0, num_nodes, kDefaultGrainSize,
|
||||
[&](int64_t begin, int64_t end) {
|
||||
for (int64_t i = begin; i < end; i++) {
|
||||
const auto output_offset = out_indptr_data[i];
|
||||
const auto numel = out_indptr_data[i + 1] - output_offset;
|
||||
const auto input_offset = indptr_data[nodes_data[i]];
|
||||
for (size_t tensor_id = 0;
|
||||
tensor_id < indices_list.size(); tensor_id++) {
|
||||
auto output = reinterpret_cast<std::byte*>(
|
||||
results[tensor_id].data_ptr());
|
||||
const auto input = reinterpret_cast<std::byte*>(
|
||||
indices_list[tensor_id].data_ptr());
|
||||
const auto element_size =
|
||||
indices_list[tensor_id].element_size();
|
||||
std::memcpy(
|
||||
output + output_offset * element_size,
|
||||
input + input_offset * element_size,
|
||||
element_size * numel);
|
||||
}
|
||||
if (edge_ids.has_value()) {
|
||||
auto output = edge_ids->data_ptr<indptr_t>();
|
||||
std::iota(
|
||||
output + output_offset,
|
||||
output + output_offset + numel, input_offset);
|
||||
}
|
||||
}
|
||||
});
|
||||
}));
|
||||
}));
|
||||
if (edge_ids) results.push_back(*edge_ids);
|
||||
return std::make_tuple(output_indptr, results);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user