[GraphBolt] Implement proper IndexSelectCSC for CPU. (#7670)

This commit is contained in:
Muhammed Fatih BALIN
2024-08-07 18:06:05 -04:00
committed by GitHub
parent b5ee45fd1a
commit ce29f5814b
2 changed files with 90 additions and 52 deletions

View File

@@ -19,6 +19,7 @@
#include <vector>
#include "./expand_indptr.h"
#include "./index_select.h"
#include "./macro.h"
#include "./random.h"
#include "./shared_memory_helper.h"
@@ -293,48 +294,21 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
return ops::InSubgraph(indptr_, indices_, nodes, type_per_edge_);
});
}
using namespace torch::indexing;
const int32_t kDefaultGrainSize = 100;
const auto num_seeds = nodes.size(0);
torch::Tensor indptr = torch::empty({num_seeds + 1}, indptr_.dtype());
std::vector<torch::Tensor> indices_arr(num_seeds);
std::vector<torch::Tensor> edge_ids_arr(num_seeds);
std::vector<torch::Tensor> type_per_edge_arr(num_seeds);
std::vector<torch::Tensor> tensors{indices_};
if (type_per_edge_.has_value()) {
tensors.push_back(*type_per_edge_);
}
AT_DISPATCH_INDEX_TYPES(
indptr_.scalar_type(), "InSubgraph::indptr", ([&] {
const auto indptr_data = indptr_.data_ptr<index_t>();
auto out_indptr_data = indptr.data_ptr<index_t>();
out_indptr_data[0] = 0;
AT_DISPATCH_INDEX_TYPES(
nodes.scalar_type(), "InSubgraph::nodes", ([&] {
const auto nodes_data = nodes.data_ptr<index_t>();
torch::parallel_for(
0, num_seeds, kDefaultGrainSize,
[&](size_t start, size_t end) {
for (size_t i = start; i < end; ++i) {
const auto node_id = nodes_data[i];
const auto start_idx = indptr_data[node_id];
const auto end_idx = indptr_data[node_id + 1];
out_indptr_data[i + 1] = end_idx - start_idx;
indices_arr[i] = indices_.slice(0, start_idx, end_idx);
edge_ids_arr[i] = torch::arange(
start_idx, end_idx, indptr_.scalar_type());
if (type_per_edge_) {
type_per_edge_arr[i] =
type_per_edge_.value().slice(0, start_idx, end_idx);
}
}
});
}));
}));
auto [output_indptr, results] =
ops::IndexSelectCSCBatched(indptr_, tensors, nodes, true, torch::nullopt);
torch::optional<torch::Tensor> type_per_edge;
if (type_per_edge_.has_value()) {
type_per_edge = results.at(1);
}
return c10::make_intrusive<FusedSampledSubgraph>(
indptr.cumsum(0), torch::cat(indices_arr), torch::cat(edge_ids_arr),
nodes, torch::arange(0, NumNodes()),
type_per_edge_
? torch::optional<torch::Tensor>{torch::cat(type_per_edge_arr)}
: torch::nullopt);
output_indptr, results.at(0), results.back(), nodes,
torch::arange(0, NumNodes()), type_per_edge);
}
/**

View File

@@ -8,6 +8,9 @@
#include <graphbolt/cuda_ops.h>
#include <graphbolt/fused_csc_sampling_graph.h>
#include <cstring>
#include <numeric>
#include "./macro.h"
#include "./utils.h"
@@ -107,9 +110,9 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
c10::DeviceType::CUDA, "IndexSelectCSCImpl",
{ return IndexSelectCSCImpl(indptr, indices, nodes, output_size); });
}
sampling::FusedCSCSamplingGraph g(indptr, indices);
const auto res = g.InSubgraph(nodes);
return std::make_tuple(res->indptr, res->indices.value());
auto [output_indptr, results] = IndexSelectCSCBatched(
indptr, std::vector{indices}, nodes, false, output_size);
return std::make_tuple(output_indptr, results.at(0));
}
std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
@@ -129,17 +132,78 @@ std::tuple<torch::Tensor, std::vector<torch::Tensor>> IndexSelectCSCBatched(
indptr, indices_list, nodes, with_edge_ids, output_size);
});
}
constexpr int kDefaultGrainSize = 128;
const auto num_nodes = nodes.size(0);
torch::Tensor output_indptr = torch::empty(
{num_nodes + 1}, nodes.options().dtype(indptr.scalar_type()));
std::vector<torch::Tensor> results;
torch::Tensor output_indptr;
torch::Tensor edge_ids;
for (auto& indices : indices_list) {
sampling::FusedCSCSamplingGraph g(indptr, indices);
const auto res = g.InSubgraph(nodes);
output_indptr = res->indptr;
results.push_back(res->indices.value());
edge_ids = res->original_edge_ids;
}
if (with_edge_ids) results.push_back(edge_ids);
torch::optional<torch::Tensor> edge_ids;
AT_DISPATCH_INDEX_TYPES(
indptr.scalar_type(), "IndexSelectCSCBatched::indptr", ([&] {
using indptr_t = index_t;
const auto indptr_data = indptr.data_ptr<indptr_t>();
auto out_indptr_data = output_indptr.data_ptr<indptr_t>();
out_indptr_data[0] = 0;
AT_DISPATCH_INDEX_TYPES(
nodes.scalar_type(), "IndexSelectCSCBatched::nodes", ([&] {
const auto nodes_data = nodes.data_ptr<index_t>();
torch::parallel_for(
0, num_nodes, kDefaultGrainSize,
[&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; i++) {
const auto node_id = nodes_data[i];
const auto degree =
indptr_data[node_id + 1] - indptr_data[node_id];
out_indptr_data[i + 1] = degree;
}
});
output_indptr = output_indptr.cumsum(0, indptr.scalar_type());
out_indptr_data = output_indptr.data_ptr<indptr_t>();
TORCH_CHECK(
!output_size.has_value() ||
out_indptr_data[num_nodes] == *output_size,
"An incorrect output_size argument was provided.");
output_size = out_indptr_data[num_nodes];
for (const auto& indices : indices_list) {
results.push_back(torch::empty(
*output_size,
nodes.options().dtype(indices.scalar_type())));
}
if (with_edge_ids) {
edge_ids = torch::empty(
*output_size, nodes.options().dtype(indptr.scalar_type()));
}
torch::parallel_for(
0, num_nodes, kDefaultGrainSize,
[&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; i++) {
const auto output_offset = out_indptr_data[i];
const auto numel = out_indptr_data[i + 1] - output_offset;
const auto input_offset = indptr_data[nodes_data[i]];
for (size_t tensor_id = 0;
tensor_id < indices_list.size(); tensor_id++) {
auto output = reinterpret_cast<std::byte*>(
results[tensor_id].data_ptr());
const auto input = reinterpret_cast<std::byte*>(
indices_list[tensor_id].data_ptr());
const auto element_size =
indices_list[tensor_id].element_size();
std::memcpy(
output + output_offset * element_size,
input + input_offset * element_size,
element_size * numel);
}
if (edge_ids.has_value()) {
auto output = edge_ids->data_ptr<indptr_t>();
std::iota(
output + output_offset,
output + output_offset + numel, input_offset);
}
}
});
}));
}));
if (edge_ids) results.push_back(*edge_ids);
return std::make_tuple(output_indptr, results);
}