From e8022e9494b556c90e563fdd21e0f19da5da9ffd Mon Sep 17 00:00:00 2001 From: Muhammed Fatih BALIN Date: Sun, 8 Sep 2024 20:37:52 -0400 Subject: [PATCH] [GraphBolt][CUDA] Expose `UniqueAndCompact` offsets. (#7789) --- graphbolt/include/graphbolt/cuda_ops.h | 11 ++- .../include/graphbolt/unique_and_compact.h | 15 ++- graphbolt/src/cuda/unique_and_compact_impl.cu | 31 ++++--- graphbolt/src/python_binding.cc | 9 +- graphbolt/src/unique_and_compact.cc | 19 ++-- .../dgl/graphbolt/impl/in_subgraph_sampler.py | 1 + python/dgl/graphbolt/impl/neighbor_sampler.py | 7 +- python/dgl/graphbolt/internal/sample_utils.py | 93 +++++++++++++++---- python/dgl/graphbolt/subgraph_sampler.py | 4 +- .../graphbolt/internal/test_sample_utils.py | 8 +- 10 files changed, 144 insertions(+), 54 deletions(-) diff --git a/graphbolt/include/graphbolt/cuda_ops.h b/graphbolt/include/graphbolt/cuda_ops.h index 07feaeb6b5..857f634246 100644 --- a/graphbolt/include/graphbolt/cuda_ops.h +++ b/graphbolt/include/graphbolt/cuda_ops.h @@ -288,7 +288,7 @@ torch::Tensor IndptrEdgeIdsImpl( * @param rank The rank of the current GPU. * @param world_size The total # GPUs, world size. * - * @return + * @return (unique_ids, compacted_src_ids, compacted_dst_ids, unique_offsets) * - A tensor representing all unique elements in 'src_ids' and 'dst_ids' after * removing duplicates. The indices in this tensor precisely match the compacted * IDs of the corresponding elements. @@ -296,6 +296,9 @@ torch::Tensor IndptrEdgeIdsImpl( * mapped to compacted IDs. * - The tensor corresponding to the 'dst_ids' tensor, where the entries are * mapped to compacted IDs. + * - The tensor corresponding to the offsets into the unique_ids tensor. Has + * size `world_size + 1` and unique_ids[offsets[i]: offsets[i + 1]] belongs to + * the rank `(rank + i) % world_size`. 
* * @example * torch::Tensor src_ids = src @@ -306,7 +309,8 @@ torch::Tensor IndptrEdgeIdsImpl( * torch::Tensor compacted_src_ids = std::get<1>(result); * torch::Tensor compacted_dst_ids = std::get<2>(result); */ -std::tuple UniqueAndCompact( +std::tuple +UniqueAndCompact( const torch::Tensor src_ids, const torch::Tensor dst_ids, const torch::Tensor unique_dst_ids, const int64_t rank, const int64_t world_size); @@ -316,7 +320,8 @@ std::tuple UniqueAndCompact( * value is equal to the passing the ith elements of the input arguments to * UniqueAndCompact. */ -std::vector> +std::vector< + std::tuple> UniqueAndCompactBatched( const std::vector& src_ids, const std::vector& dst_ids, diff --git a/graphbolt/include/graphbolt/unique_and_compact.h b/graphbolt/include/graphbolt/unique_and_compact.h index db61c2b6f9..6a7a5cb3b3 100644 --- a/graphbolt/include/graphbolt/unique_and_compact.h +++ b/graphbolt/include/graphbolt/unique_and_compact.h @@ -38,7 +38,7 @@ namespace sampling { * @param rank The rank of the current GPU. * @param world_size The total # GPUs, world size. * - * @return + * @return (unique_ids, compacted_src_ids, compacted_dst_ids, unique_offsets) * - A tensor representing all unique elements in 'src_ids' and 'dst_ids' after * removing duplicates. The indices in this tensor precisely match the compacted * IDs of the corresponding elements. @@ -46,6 +46,9 @@ namespace sampling { * mapped to compacted IDs. * - The tensor corresponding to the 'dst_ids' tensor, where the entries are * mapped to compacted IDs. + * - The tensor corresponding to the offsets into the unique_ids tensor. Has + * size `world_size + 1` and unique_ids[offsets[i]: offsets[i + 1]] belongs to + * the rank `(rank + i) % world_size`. 
* * @example * torch::Tensor src_ids = src @@ -56,20 +59,22 @@ namespace sampling { * torch::Tensor compacted_src_ids = std::get<1>(result); * torch::Tensor compacted_dst_ids = std::get<2>(result); */ -std::tuple UniqueAndCompact( +std::tuple +UniqueAndCompact( const torch::Tensor& src_ids, const torch::Tensor& dst_ids, const torch::Tensor unique_dst_ids, const int64_t rank, const int64_t world_size); -std::vector> +std::vector< + std::tuple> UniqueAndCompactBatched( const std::vector& src_ids, const std::vector& dst_ids, const std::vector unique_dst_ids, const int64_t rank, const int64_t world_size); -c10::intrusive_ptr>>> +c10::intrusive_ptr>>> UniqueAndCompactBatchedAsync( const std::vector& src_ids, const std::vector& dst_ids, diff --git a/graphbolt/src/cuda/unique_and_compact_impl.cu b/graphbolt/src/cuda/unique_and_compact_impl.cu index a630b78e14..ffe0501de7 100644 --- a/graphbolt/src/cuda/unique_and_compact_impl.cu +++ b/graphbolt/src/cuda/unique_and_compact_impl.cu @@ -272,7 +272,8 @@ UniqueAndCompactBatchedSortBased( })); } -std::vector> +std::vector< + std::tuple> UniqueAndCompactBatched( const std::vector& src_ids, const std::vector& dst_ids, @@ -282,15 +283,8 @@ UniqueAndCompactBatched( // Utilizes a hash table based implementation, the mapped id of a vertex // will be monotonically increasing as the first occurrence index of it in // torch.cat([unique_dst_ids, src_ids]). Thus, it is deterministic. - auto results4 = UniqueAndCompactBatchedHashMapBased( + return UniqueAndCompactBatchedHashMapBased( src_ids, dst_ids, unique_dst_ids, rank, world_size); - std::vector> - results3; - // TODO @mfbalin: expose the `d` result in a later PR. 
- for (const auto& [a, b, c, d] : results4) { - results3.emplace_back(a, b, c); - } - return results3; } TORCH_CHECK( world_size <= 1, @@ -299,10 +293,25 @@ UniqueAndCompactBatched( // Utilizes a sort based algorithm, the mapped id of a vertex part of the // src_ids but not part of the unique_dst_ids will be monotonically increasing // as the actual vertex id increases. Thus, it is deterministic. - return UniqueAndCompactBatchedSortBased(src_ids, dst_ids, unique_dst_ids); + auto results3 = + UniqueAndCompactBatchedSortBased(src_ids, dst_ids, unique_dst_ids); + std::vector< + std::tuple> + results4; + auto offsets = torch::zeros( + 2 * results3.size(), + c10::TensorOptions().dtype(torch::kInt64).pinned_memory(true)); + for (const auto& [a, b, c] : results3) { + auto d = offsets.slice(0, 0, 2); + d.data_ptr()[1] = a.size(0); + results4.emplace_back(a, b, c, d); + offsets = offsets.slice(0, 2); + } + return results4; } -std::tuple UniqueAndCompact( +std::tuple +UniqueAndCompact( const torch::Tensor src_ids, const torch::Tensor dst_ids, const torch::Tensor unique_dst_ids, const int64_t rank, const int64_t world_size) { diff --git a/graphbolt/src/python_binding.cc b/graphbolt/src/python_binding.cc index 20c6d59be5..35ab345c56 100644 --- a/graphbolt/src/python_binding.cc +++ b/graphbolt/src/python_binding.cc @@ -51,13 +51,14 @@ TORCH_LIBRARY(graphbolt, m) { m.class_>>( "FusedSampledSubgraphFuture") .def("wait", &Future>::Wait); - m.class_>>>( + m.class_>>>( "UniqueAndCompactBatchedFuture") .def( "wait", - &Future>>::Wait); + &Future>>:: + Wait); m.class_>>( "GpuGraphCacheQueryFuture") .def( diff --git a/graphbolt/src/unique_and_compact.cc b/graphbolt/src/unique_and_compact.cc index 03fb8f514f..bbed379f79 100644 --- a/graphbolt/src/unique_and_compact.cc +++ b/graphbolt/src/unique_and_compact.cc @@ -14,7 +14,8 @@ namespace graphbolt { namespace sampling { -std::tuple UniqueAndCompact( +std::tuple +UniqueAndCompact( const torch::Tensor& src_ids, const torch::Tensor& 
dst_ids, const torch::Tensor unique_dst_ids, const int64_t rank, const int64_t world_size) { @@ -31,16 +32,20 @@ std::tuple UniqueAndCompact( "Cooperative Minibatching (arXiv:2310.12403) is supported only on GPUs."); auto num_dst = unique_dst_ids.size(0); torch::Tensor ids = torch::cat({unique_dst_ids, src_ids}); - return AT_DISPATCH_INDEX_TYPES( + auto [unique_ids, compacted_src, compacted_dst] = AT_DISPATCH_INDEX_TYPES( ids.scalar_type(), "unique_and_compact", ([&] { ConcurrentIdHashMap id_map(ids, num_dst); return std::make_tuple( id_map.GetUniqueIds(), id_map.MapIds(src_ids), id_map.MapIds(dst_ids)); })); + auto offsets = torch::zeros(2, c10::TensorOptions().dtype(torch::kInt64)); + offsets.data_ptr()[1] = unique_ids.size(0); + return {unique_ids, compacted_src, compacted_dst, offsets}; } -std::vector> +std::vector< + std::tuple> UniqueAndCompactBatched( const std::vector& src_ids, const std::vector& dst_ids, @@ -64,7 +69,9 @@ UniqueAndCompactBatched( src_ids, dst_ids, unique_dst_ids, rank, world_size); }); } - std::vector> results; + std::vector< + std::tuple> + results; results.reserve(src_ids.size()); for (std::size_t i = 0; i < src_ids.size(); i++) { results.emplace_back(UniqueAndCompact( @@ -73,8 +80,8 @@ UniqueAndCompactBatched( return results; } -c10::intrusive_ptr>>> +c10::intrusive_ptr>>> UniqueAndCompactBatchedAsync( const std::vector& src_ids, const std::vector& dst_ids, diff --git a/python/dgl/graphbolt/impl/in_subgraph_sampler.py b/python/dgl/graphbolt/impl/in_subgraph_sampler.py index de98f24ba0..cc2d515de0 100644 --- a/python/dgl/graphbolt/impl/in_subgraph_sampler.py +++ b/python/dgl/graphbolt/impl/in_subgraph_sampler.py @@ -74,6 +74,7 @@ class InSubgraphSampler(SubgraphSampler): ( original_row_node_ids, compacted_csc_formats, + _, ) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds) subgraph = SampledSubgraphImpl( sampled_csc=compacted_csc_formats, diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py 
b/python/dgl/graphbolt/impl/neighbor_sampler.py index 4229edb6be..e14a586fc5 100644 --- a/python/dgl/graphbolt/impl/neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/neighbor_sampler.py @@ -471,6 +471,7 @@ class CompactPerLayer(MiniBatchTransformer): ( original_row_node_ids, compacted_csc_format, + _, ) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds) subgraph = SampledSubgraphImpl( sampled_csc=compacted_csc_format, @@ -506,7 +507,11 @@ class CompactPerLayer(MiniBatchTransformer): def _compact_per_layer_wait_future(minibatch): subgraph = minibatch.sampled_subgraphs[0] seeds = minibatch._seed_nodes - original_row_node_ids, compacted_csc_format = minibatch._future.wait() + ( + original_row_node_ids, + compacted_csc_format, + _, + ) = minibatch._future.wait() delattr(minibatch, "_future") subgraph = SampledSubgraphImpl( sampled_csc=compacted_csc_format, diff --git a/python/dgl/graphbolt/internal/sample_utils.py b/python/dgl/graphbolt/internal/sample_utils.py index d840d5bf4a..013f102e35 100644 --- a/python/dgl/graphbolt/internal/sample_utils.py +++ b/python/dgl/graphbolt/internal/sample_utils.py @@ -13,9 +13,26 @@ def unique_and_compact( List[torch.Tensor], Dict[str, List[torch.Tensor]], ], + rank: int = 0, + world_size: int = 1, ): """ - Compact a list of nodes tensor. + Compact a list of nodes tensor. The `rank` and `world_size` parameters are + relevant when using Cooperative Minibatching, which was initially proposed + in `Deep Graph Library PR#4337`__ and + was later first fully described in + `Cooperative Minibatching in Graph Neural Networks + `__ + Cooperation between the GPUs eliminates duplicate work performed across the + GPUs due to the overlapping sampled k-hop neighborhoods of seed nodes when + performing GNN minibatching. + + When `world_size` is greater than 1, then the given ids are partitioned + between the available ranks. The ids corresponding to the given rank are + guaranteed to come before the ids of other ranks. 
To do this, the + partitioned ids are rotated backwards by the given rank so that the ids are + ordered as: `[rank, rank + 1, ..., world_size - 1, 0, ..., rank - 1]`. This is + supported only for Volta and later generation NVIDIA GPUs. Parameters ---------- @@ -27,15 +44,22 @@ def unique_and_compact( - If `nodes` is a list of dictionary: The keys should be node type and the values should be corresponding nodes, the unique and compact will be done per type, usually it is used for heterogeneous graph. + rank : int + The rank of the current process. + world_size : int + The number of processes. Returns ------- - Tuple[unique_nodes, compacted_node_list] + Tuple[unique_nodes, compacted_node_list, unique_nodes_offsets] The Unique nodes (per type) of all nodes in the input. And the compacted nodes list, where IDs inside are replaced with compacted node IDs. "Compacted node list" indicates that the node IDs in the input node list are replaced with mapped node IDs, where each type of node is mapped to a contiguous space of IDs ranging from 0 to N. + The unique nodes offsets tensor partitions the unique_nodes tensor. Has + size `world_size + 1` and unique_nodes[offsets[i]: offsets[i + 1]] + belongs to the rank `(rank + i) % world_size`. 
""" is_heterogeneous = isinstance(nodes, dict) @@ -43,19 +67,21 @@ def unique_and_compact( nums = [node.size(0) for node in nodes] nodes = torch.cat(nodes) empty_tensor = nodes.new_empty(0) - unique, compacted, _ = torch.ops.graphbolt.unique_and_compact( - nodes, empty_tensor, empty_tensor, 0, 1 + unique, compacted, _, offsets = torch.ops.graphbolt.unique_and_compact( + nodes, empty_tensor, empty_tensor, rank, world_size ) compacted = compacted.split(nums) - return unique, list(compacted) + return unique, list(compacted), offsets if is_heterogeneous: - unique, compacted = {}, {} + unique, compacted, offsets = {}, {}, {} for ntype, nodes_of_type in nodes.items(): - unique[ntype], compacted[ntype] = unique_and_compact_per_type( - nodes_of_type - ) - return unique, compacted + ( + unique[ntype], + compacted[ntype], + offsets[ntype], + ) = unique_and_compact_per_type(nodes_of_type) + return unique, compacted, offsets else: return unique_and_compact_per_type(nodes) @@ -124,10 +150,28 @@ def unique_and_compact_csc_formats( torch.Tensor, Dict[str, torch.Tensor], ], + rank: int = 0, + world_size: int = 1, async_op: bool = False, ): """ - Compact csc formats and return unique nodes (per type). + Compact csc formats and return unique nodes (per type). The `rank` and + `world_size` parameters are relevant when using Cooperative Minibatching, + which was initially proposed in + `Deep Graph Library PR#4337`__ + and was later first fully described in + `Cooperative Minibatching in Graph Neural Networks + `__ + Cooperation between the GPUs eliminates duplicate work performed across the + GPUs due to the overlapping sampled k-hop neighborhoods of seed nodes when + performing GNN minibatching. + + When `world_size` is greater than 1, then the given ids are partitioned + between the available ranks. The ids corresponding to the given rank are + guaranteed to come before the ids of other ranks. 
To do this, the + partitioned ids are rotated backwards by the given rank so that the ids are + ordered as: `[rank, rank + 1, ..., world_size - 1, 0, ..., rank - 1]`. This is + supported only for Volta and later generation NVIDIA GPUs. Parameters ---------- @@ -145,18 +189,25 @@ def unique_and_compact_csc_formats( - If `unique_dst_nodes` is a tensor: It means the graph is homogeneous. - If `csc_formats` is a dictionary: The keys are node type and the values are corresponding nodes. And IDs inside are heterogeneous ids. + rank : int + The rank of the current process. + world_size : int + The number of processes. async_op: bool Boolean indicating whether the call is asynchronous. If so, the result can be obtained by calling wait on the returned future. Returns ------- - Tuple[csc_formats, unique_nodes] + Tuple[unique_nodes, csc_formats, unique_nodes_offsets] The compacted csc formats, where node IDs are replaced with mapped node IDs, and the unique nodes (per type). "Compacted csc formats" indicates that the node IDs in the input node pairs are replaced with mapped node IDs, where each type of node is - mapped to a contiguous space of IDs ranging from 0 to N. + mapped to a contiguous space of IDs ranging from 0 to N. The unique + nodes offsets tensor partitions the unique_nodes tensor. Has size + `world_size + 1` and unique_nodes[offsets[i]: offsets[i + 1]] belongs to + the rank `(rank + i) % world_size`. Examples -------- @@ -169,7 +220,7 @@ def unique_and_compact_csc_formats( >>> csc_formats = { ... "n1:e1:n2": gb.CSCFormatBase(indptr=torch.tensor([0, 2, 3]),indices=N1), ... "n2:e2:n1": gb.CSCFormatBase(indptr=torch.tensor([0, 1, 3]),indices=N2)} - >>> unique_nodes, compacted_csc_formats = gb.unique_and_compact_csc_formats( + >>> unique_nodes, compacted_csc_formats, _ = gb.unique_and_compact_csc_formats( ... csc_formats, unique_dst ... 
) >>> print(unique_nodes) @@ -213,12 +264,12 @@ def unique_and_compact_csc_formats( dst_list = [torch.tensor([], dtype=dtype, device=device)] * len( unique_dst_list ) - unique_fn = ( + uniq_fn = ( torch.ops.graphbolt.unique_and_compact_batched_async if async_op else torch.ops.graphbolt.unique_and_compact_batched ) - results = unique_fn(indice_list, dst_list, unique_dst_list, 0, 1) + results = uniq_fn(indice_list, dst_list, unique_dst_list, rank, world_size) class _Waiter: def __init__(self, future, csc_formats): @@ -234,8 +285,14 @@ def unique_and_compact_csc_formats( unique_nodes = {} compacted_indices = {} + offsets = {} for i, ntype in enumerate(ntypes): - unique_nodes[ntype], compacted_indices[ntype], _ = results[i] + ( + unique_nodes[ntype], + compacted_indices[ntype], + _, + offsets[ntype], + ) = results[i] compacted_csc_formats = {} # Map back with the same order. @@ -256,7 +313,7 @@ def unique_and_compact_csc_formats( compacted_csc_formats = list(compacted_csc_formats.values())[0] unique_nodes = list(unique_nodes.values())[0] - return unique_nodes, compacted_csc_formats + return unique_nodes, compacted_csc_formats, offsets post_processer = _Waiter(results, csc_formats) if async_op: diff --git a/python/dgl/graphbolt/subgraph_sampler.py b/python/dgl/graphbolt/subgraph_sampler.py index 3e56bd75ff..97086d3153 100644 --- a/python/dgl/graphbolt/subgraph_sampler.py +++ b/python/dgl/graphbolt/subgraph_sampler.py @@ -163,7 +163,7 @@ class SubgraphSampler(MiniBatchTransformer): compacted, ) = compact_temporal_nodes(nodes, nodes_timestamp) else: - unique_seeds, compacted = unique_and_compact(nodes) + unique_seeds, compacted, _ = unique_and_compact(nodes) nodes_timestamp = None compacted_seeds = {} # Map back in same order as collect. 
@@ -212,7 +212,7 @@ class SubgraphSampler(MiniBatchTransformer): compacted, ) = compact_temporal_nodes(nodes, nodes_timestamp) else: - unique_seeds, compacted = unique_and_compact(nodes) + unique_seeds, compacted, _ = unique_and_compact(nodes) nodes_timestamp = None # Map back in same order as collect. compacted_seeds = compacted[0].view(seeds.shape) diff --git a/tests/python/pytorch/graphbolt/internal/test_sample_utils.py b/tests/python/pytorch/graphbolt/internal/test_sample_utils.py index bb27cb8c18..aadf4081f2 100644 --- a/tests/python/pytorch/graphbolt/internal/test_sample_utils.py +++ b/tests/python/pytorch/graphbolt/internal/test_sample_utils.py @@ -50,7 +50,7 @@ def test_unique_and_compact_hetero(): ], } - unique, compacted = gb.unique_and_compact(nodes_dict) + unique, compacted, _ = gb.unique_and_compact(nodes_dict) for ntype, nodes in unique.items(): expected_nodes = expected_unique[ntype] assert torch.equal(nodes, expected_nodes) @@ -84,7 +84,7 @@ def test_unique_and_compact_homo(): torch.tensor([7, 8, 9, 0, 5], device=F.ctx()), ] - unique, compacted = gb.unique_and_compact(nodes_list) + unique, compacted, _ = gb.unique_and_compact(nodes_list) assert torch.equal(unique, expected_unique_N) assert isinstance(compacted, list) @@ -133,7 +133,7 @@ def test_unique_and_compact_csc_formats_hetero(): ), } - unique_nodes, compacted_csc_formats = gb.unique_and_compact_csc_formats( + unique_nodes, compacted_csc_formats, _ = gb.unique_and_compact_csc_formats( csc_formats, dst_nodes ) @@ -159,7 +159,7 @@ def test_unique_and_compact_csc_formats_homo(): expected_indptr = indptr expected_indices = torch.tensor([3, 1, 0, 5, 2, 3, 2, 0, 5, 5, 4]) - unique_nodes, compacted_csc_formats = gb.unique_and_compact_csc_formats( + unique_nodes, compacted_csc_formats, _ = gb.unique_and_compact_csc_formats( csc_formats, seeds )