mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
committed by
GitHub
parent
87ea76b02a
commit
eafb53013f
@@ -63,7 +63,7 @@ struct AdjacentDifference {
|
||||
torch::Tensor ExpandIndptrImpl(
|
||||
torch::Tensor indptr, torch::ScalarType dtype,
|
||||
torch::optional<torch::Tensor> nodes, torch::optional<int64_t> output_size,
|
||||
const bool edge_ids) {
|
||||
const bool is_edge_ids_variant) {
|
||||
if (!output_size.has_value()) {
|
||||
output_size = AT_DISPATCH_INTEGRAL_TYPES(
|
||||
indptr.scalar_type(), "ExpandIndptrIndptr[-1]", ([&]() -> int64_t {
|
||||
@@ -102,7 +102,7 @@ torch::Tensor ExpandIndptrImpl(
|
||||
constexpr int64_t max_copy_at_once =
|
||||
std::numeric_limits<int32_t>::max();
|
||||
|
||||
if (edge_ids) {
|
||||
if (is_edge_ids_variant) {
|
||||
auto input_buffer = thrust::make_transform_iterator(
|
||||
iota, IotaIndex<indices_t, nodes_t>{nodes_ptr});
|
||||
for (int64_t i = 0; i < num_rows; i += max_copy_at_once) {
|
||||
|
||||
Reference in New Issue
Block a user