mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt][CUDA] Return partition offsets to be utilized in all_to_all. (#7775)
This commit is contained in:
committed by
GitHub
parent
0734e33e0e
commit
f71427f33f
@@ -49,10 +49,12 @@ torch::Tensor RankAssignment(
|
||||
return part_ids;
|
||||
}
|
||||
|
||||
std::pair<torch::Tensor, torch::Tensor> RankSortImpl(
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, at::cuda::CUDAEvent>
|
||||
RankSortImpl(
|
||||
torch::Tensor nodes, torch::Tensor part_ids, torch::Tensor offsets_dev,
|
||||
const int64_t world_size) {
|
||||
const int num_bits = cuda::NumberOfBits(world_size);
|
||||
const auto num_batches = offsets_dev.numel() - 1;
|
||||
auto offsets_dev_ptr = offsets_dev.data_ptr<int64_t>();
|
||||
auto part_ids_sorted = torch::empty_like(part_ids);
|
||||
auto part_ids2 = part_ids.clone();
|
||||
@@ -60,27 +62,47 @@ std::pair<torch::Tensor, torch::Tensor> RankSortImpl(
|
||||
auto nodes_sorted = torch::empty_like(nodes);
|
||||
auto index = torch::arange(nodes.numel(), nodes.options());
|
||||
auto index_sorted = torch::empty_like(index);
|
||||
AT_DISPATCH_INDEX_TYPES(
|
||||
return AT_DISPATCH_INDEX_TYPES(
|
||||
nodes.scalar_type(), "RankSortImpl", ([&] {
|
||||
CUB_CALL(
|
||||
DeviceSegmentedRadixSort::SortPairs,
|
||||
part_ids.data_ptr<cuda::part_t>(),
|
||||
part_ids_sorted.data_ptr<cuda::part_t>(), nodes.data_ptr<index_t>(),
|
||||
nodes_sorted.data_ptr<index_t>(), nodes.numel(),
|
||||
offsets_dev.numel() - 1, offsets_dev_ptr, offsets_dev_ptr + 1, 0,
|
||||
num_bits);
|
||||
nodes_sorted.data_ptr<index_t>(), nodes.numel(), num_batches,
|
||||
offsets_dev_ptr, offsets_dev_ptr + 1, 0, num_bits);
|
||||
auto offsets = torch::empty(
|
||||
num_batches * world_size + 1, c10::TensorOptions()
|
||||
.dtype(offsets_dev.scalar_type())
|
||||
.pinned_memory(true));
|
||||
CUB_CALL(
|
||||
DeviceFor::Bulk, num_batches * world_size + 1,
|
||||
[=, part_ids = part_ids_sorted.data_ptr<cuda::part_t>(),
|
||||
offsets = offsets.data_ptr<int64_t>()] __device__(int64_t i) {
|
||||
const auto batch_id = i / world_size;
|
||||
const auto rank = i % world_size;
|
||||
const auto offset_begin = offsets_dev_ptr[batch_id];
|
||||
const auto offset_end =
|
||||
offsets_dev_ptr[::cuda::std::min(batch_id + 1, num_batches)];
|
||||
offsets[i] = cub::LowerBound(
|
||||
part_ids + offset_begin,
|
||||
offset_end - offset_begin, rank) +
|
||||
offset_begin;
|
||||
});
|
||||
at::cuda::CUDAEvent offsets_event;
|
||||
offsets_event.record();
|
||||
CUB_CALL(
|
||||
DeviceSegmentedRadixSort::SortPairs,
|
||||
part_ids2.data_ptr<cuda::part_t>(),
|
||||
part_ids2_sorted.data_ptr<cuda::part_t>(),
|
||||
index.data_ptr<index_t>(), index_sorted.data_ptr<index_t>(),
|
||||
nodes.numel(), offsets_dev.numel() - 1, offsets_dev_ptr,
|
||||
offsets_dev_ptr + 1, 0, num_bits);
|
||||
nodes.numel(), num_batches, offsets_dev_ptr, offsets_dev_ptr + 1, 0,
|
||||
num_bits);
|
||||
return std::make_tuple(
|
||||
nodes_sorted, index_sorted, offsets, std::move(offsets_event));
|
||||
}));
|
||||
return {nodes_sorted, index_sorted};
|
||||
}
|
||||
|
||||
std::vector<std::tuple<torch::Tensor, torch::Tensor>> RankSort(
|
||||
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> RankSort(
|
||||
std::vector<torch::Tensor>& nodes_list, const int64_t rank,
|
||||
const int64_t world_size) {
|
||||
const auto num_batches = nodes_list.size();
|
||||
@@ -100,13 +122,15 @@ std::vector<std::tuple<torch::Tensor, torch::Tensor>> RankSort(
|
||||
offsets_dev.data_ptr<int64_t>(), offsets_ptr,
|
||||
sizeof(int64_t) * offsets.numel(), cudaMemcpyHostToDevice,
|
||||
cuda::GetCurrentStream()));
|
||||
auto [nodes_sorted, index_sorted] =
|
||||
auto [nodes_sorted, index_sorted, rank_offsets, rank_offsets_event] =
|
||||
RankSortImpl(nodes, part_ids, offsets_dev, world_size);
|
||||
std::vector<std::tuple<torch::Tensor, torch::Tensor>> results;
|
||||
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> results;
|
||||
rank_offsets_event.synchronize();
|
||||
for (int64_t i = 0; i < num_batches; i++) {
|
||||
results.emplace_back(
|
||||
nodes_sorted.slice(0, offsets_ptr[i], offsets_ptr[i + 1]),
|
||||
index_sorted.slice(0, offsets_ptr[i], offsets_ptr[i + 1]));
|
||||
index_sorted.slice(0, offsets_ptr[i], offsets_ptr[i + 1]),
|
||||
rank_offsets.slice(0, i * world_size, (i + 1) * world_size + 1));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
@@ -74,11 +74,15 @@ torch::Tensor RankAssignment(
|
||||
* @param offsets_dev Offsets to separate different node types.
|
||||
* @param world_size World size, the total number of cooperating GPUs.
|
||||
*
|
||||
* @return (sorted_nodes, original_positions), where the first
|
||||
* one includes sorted nodes, the second contains original positions of the
|
||||
* sorted nodes.
|
||||
* @return (sorted_nodes, original_positions, rank_offsets, rank_offsets_event),
|
||||
* where the first one includes sorted nodes, the second contains original
|
||||
* positions of the sorted nodes and the third contains the offsets of the
|
||||
* sorted_nodes indicating sorted_nodes[rank_offsets[i]: rank_offsets[i + 1]]
|
||||
* contains nodes that belong to the `i`th rank. Before accessing rank_offsets
|
||||
* on the CPU, `rank_offsets_event.synchronize()` is required.
|
||||
*/
|
||||
std::pair<torch::Tensor, torch::Tensor> RankSortImpl(
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, at::cuda::CUDAEvent>
|
||||
RankSortImpl(
|
||||
torch::Tensor nodes, torch::Tensor part_ids, torch::Tensor offsets_dev,
|
||||
int64_t world_size);
|
||||
|
||||
@@ -91,11 +95,13 @@ std::pair<torch::Tensor, torch::Tensor> RankSortImpl(
|
||||
* @param rank Rank of the current GPU.
|
||||
* @param world_size World size, the total number of cooperating GPUs.
|
||||
*
|
||||
* @return vector of (sorted_nodes, original_positions), where the first
|
||||
* one includes sorted nodes, the second contains original positions of the
|
||||
* sorted nodes.
|
||||
* @return vector of (sorted_nodes, original_positions, rank_offsets), where the
|
||||
* first one includes sorted nodes, the second contains original positions of
|
||||
* the sorted nodes and the third contains the offsets of the sorted_nodes
|
||||
* indicating sorted_nodes[rank_offsets[i]: rank_offsets[i + 1]] contains nodes
|
||||
* that belong to the `i`th rank.
|
||||
*/
|
||||
std::vector<std::tuple<torch::Tensor, torch::Tensor>> RankSort(
|
||||
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>> RankSort(
|
||||
std::vector<torch::Tensor>& nodes_list, int64_t rank, int64_t world_size);
|
||||
|
||||
} // namespace cuda
|
||||
|
||||
@@ -28,7 +28,8 @@
|
||||
namespace graphbolt {
|
||||
namespace ops {
|
||||
|
||||
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> >
|
||||
std::vector<
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>
|
||||
UniqueAndCompactBatchedHashMapBased(
|
||||
const std::vector<torch::Tensor>& src_ids,
|
||||
const std::vector<torch::Tensor>& dst_ids,
|
||||
|
||||
@@ -106,7 +106,8 @@ __global__ void _MapIdsBatched(
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
|
||||
std::vector<
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>
|
||||
UniqueAndCompactBatchedHashMapBased(
|
||||
const std::vector<torch::Tensor>& src_ids,
|
||||
const std::vector<torch::Tensor>& dst_ids,
|
||||
@@ -258,7 +259,6 @@ UniqueAndCompactBatchedHashMapBased(
|
||||
auto unique_ids_offsets = torch::empty(
|
||||
num_batches + 1,
|
||||
c10::TensorOptions().dtype(torch::kInt64).pinned_memory(true));
|
||||
auto unique_ids_offsets_ptr = unique_ids_offsets.data_ptr<int64_t>();
|
||||
{
|
||||
auto unique_ids_offsets_dev2 =
|
||||
torch::empty_like(unique_ids_offsets_dev);
|
||||
@@ -271,7 +271,7 @@ UniqueAndCompactBatchedHashMapBased(
|
||||
thrust::make_transform_output_iterator(
|
||||
thrust::make_zip_iterator(
|
||||
unique_ids_offsets_dev2.data_ptr<int64_t>(),
|
||||
unique_ids_offsets_ptr),
|
||||
unique_ids_offsets.data_ptr<int64_t>()),
|
||||
::cuda::proclaim_return_type<
|
||||
thrust::tuple<int64_t, int64_t>>(
|
||||
[=] __device__(const auto x) {
|
||||
@@ -283,11 +283,14 @@ UniqueAndCompactBatchedHashMapBased(
|
||||
unique_ids_offsets_dev.data_ptr<int64_t>();
|
||||
}
|
||||
at::cuda::CUDAEvent unique_ids_offsets_event;
|
||||
unique_ids_offsets_event.record();
|
||||
torch::optional<torch::Tensor> index;
|
||||
if (part_ids) {
|
||||
std::tie(unique_ids, index) = cuda::RankSortImpl(
|
||||
unique_ids, *part_ids, unique_ids_offsets_dev, world_size);
|
||||
std::tie(
|
||||
unique_ids, index, unique_ids_offsets, unique_ids_offsets_event) =
|
||||
cuda::RankSortImpl(
|
||||
unique_ids, *part_ids, unique_ids_offsets_dev, world_size);
|
||||
} else {
|
||||
unique_ids_offsets_event.record();
|
||||
}
|
||||
auto mapped_ids =
|
||||
torch::empty(offsets_ptr[3 * num_batches], unique_ids.options());
|
||||
@@ -297,18 +300,23 @@ UniqueAndCompactBatchedHashMapBased(
|
||||
pointers_dev_ptr, offsets_dev_ptr, unique_ids_offsets_dev_ptr,
|
||||
index ? index->data_ptr<index_t>() : nullptr, map.ref(cuco::find),
|
||||
mapped_ids.data_ptr<index_t>());
|
||||
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
|
||||
std::vector<std::tuple<
|
||||
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>>
|
||||
results;
|
||||
unique_ids_offsets_event.synchronize();
|
||||
auto unique_ids_offsets_ptr = unique_ids_offsets.data_ptr<int64_t>();
|
||||
for (int64_t i = 0; i < num_batches; i++) {
|
||||
results.emplace_back(
|
||||
unique_ids.slice(
|
||||
0, unique_ids_offsets_ptr[i], unique_ids_offsets_ptr[i + 1]),
|
||||
0, unique_ids_offsets_ptr[i * world_size],
|
||||
unique_ids_offsets_ptr[(i + 1) * world_size]),
|
||||
mapped_ids.slice(
|
||||
0, offsets_ptr[2 * i + 1], offsets_ptr[2 * i + 2]),
|
||||
mapped_ids.slice(
|
||||
0, offsets_ptr[2 * num_batches + i],
|
||||
offsets_ptr[2 * num_batches + i + 1]));
|
||||
offsets_ptr[2 * num_batches + i + 1]),
|
||||
unique_ids_offsets.slice(
|
||||
0, i * world_size, (i + 1) * world_size + 1));
|
||||
}
|
||||
return results;
|
||||
}));
|
||||
|
||||
@@ -106,7 +106,7 @@ struct EdgeTypeSearch {
|
||||
const auto indptr_i = sub_indptr[homo_i];
|
||||
const auto degree = sub_indptr[homo_i + 1] - indptr_i;
|
||||
const etype_t etype = i % num_fanouts;
|
||||
auto offset = cuda::LowerBound(etypes + indptr_i, degree, etype);
|
||||
auto offset = cub::LowerBound(etypes + indptr_i, degree, etype);
|
||||
new_sub_indptr[i] = indptr_i + offset;
|
||||
new_sliced_indptr[i] = sliced_indptr[homo_i] + offset;
|
||||
if (i == num_rows - 1) new_sub_indptr[num_rows] = indptr_i + degree;
|
||||
|
||||
@@ -282,8 +282,15 @@ UniqueAndCompactBatched(
|
||||
// Utilizes a hash table based implementation, the mapped id of a vertex
|
||||
// will be monotonically increasing as the first occurrence index of it in
|
||||
// torch.cat([unique_dst_ids, src_ids]). Thus, it is deterministic.
|
||||
return UniqueAndCompactBatchedHashMapBased(
|
||||
auto results4 = UniqueAndCompactBatchedHashMapBased(
|
||||
src_ids, dst_ids, unique_dst_ids, rank, world_size);
|
||||
std::vector<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>>
|
||||
results3;
|
||||
// TODO @mfbalin: expose the `d` result in a later PR.
|
||||
for (const auto& [a, b, c, d] : results4) {
|
||||
results3.emplace_back(a, b, c);
|
||||
}
|
||||
return results3;
|
||||
}
|
||||
TORCH_CHECK(
|
||||
world_size <= 1,
|
||||
|
||||
@@ -61,56 +61,6 @@ int NumberOfBits(const T& range) {
|
||||
return bits;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Given a sorted array and a value this function returns the index
|
||||
* of the first element which compares greater than or equal to value.
|
||||
*
|
||||
* This function assumes 0-based index
|
||||
* @param A: ascending sorted array
|
||||
* @param n: size of the A
|
||||
* @param x: value to search in A
|
||||
* @return index, i, of the first element st. A[i]>=x. If x>A[n-1] returns n.
|
||||
* if x<A[0] then it returns 0.
|
||||
*/
|
||||
template <typename indptr_t, typename indices_t>
|
||||
__device__ indices_t LowerBound(const indptr_t* A, indices_t n, indptr_t x) {
|
||||
indices_t l = 0, r = n;
|
||||
while (l < r) {
|
||||
const auto m = l + (r - l) / 2;
|
||||
if (x > A[m]) {
|
||||
l = m + 1;
|
||||
} else {
|
||||
r = m;
|
||||
}
|
||||
}
|
||||
return l;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Given a sorted array and a value this function returns the index
|
||||
* of the first element which compares greater than value.
|
||||
*
|
||||
* This function assumes 0-based index
|
||||
* @param A: ascending sorted array
|
||||
* @param n: size of the A
|
||||
* @param x: value to search in A
|
||||
* @return index, i, of the first element st. A[i]>x. If x>=A[n-1] returns n.
|
||||
* if x<A[0] then it returns 0.
|
||||
*/
|
||||
template <typename indptr_t, typename indices_t>
|
||||
__device__ indices_t UpperBound(const indptr_t* A, indices_t n, indptr_t x) {
|
||||
indices_t l = 0, r = n;
|
||||
while (l < r) {
|
||||
const auto m = l + (r - l) / 2;
|
||||
if (x >= A[m]) {
|
||||
l = m + 1;
|
||||
} else {
|
||||
r = m;
|
||||
}
|
||||
}
|
||||
return l;
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace graphbolt
|
||||
|
||||
|
||||
Reference in New Issue
Block a user