diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py index 6ba83941bb..352dedc067 100644 --- a/python/dgl/graphbolt/impl/neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/neighbor_sampler.py @@ -3,6 +3,7 @@ from functools import partial import torch +import torch.distributed as thd from torch.utils.data import functional_datapipe from torch.utils.data.datapipes.iter import Mapper @@ -12,10 +13,14 @@ from ..base import ( index_select, ORIGINAL_EDGE_ID, ) -from ..internal import compact_csc_format, unique_and_compact_csc_formats +from ..internal import ( + compact_csc_format, + unique_and_compact, + unique_and_compact_csc_formats, +) from ..minibatch_transformer import MiniBatchTransformer -from ..subgraph_sampler import SubgraphSampler +from ..subgraph_sampler import all_to_all, revert_to_homo, SubgraphSampler from .fused_csc_sampling_graph import fused_csc_sampling_graph from .sampled_subgraph_impl import SampledSubgraphImpl @@ -455,12 +460,32 @@ class SamplePerLayer(MiniBatchTransformer): class CompactPerLayer(MiniBatchTransformer): """Compact the sampled edges for a single layer.""" - def __init__(self, datapipe, deduplicate, asynchronous=False): + def __init__( + self, datapipe, deduplicate, cooperative=False, asynchronous=False + ): self.deduplicate = deduplicate + self.cooperative = cooperative if asynchronous and deduplicate: datapipe = datapipe.transform(self._compact_per_layer_async) datapipe = datapipe.buffer() - super().__init__(datapipe, self._compact_per_layer_wait_future) + datapipe = datapipe.transform(self._compact_per_layer_wait_future) + if cooperative: + datapipe = datapipe.transform( + self._seeds_cooperative_exchange_1 + ) + datapipe = datapipe.buffer() + datapipe = datapipe.transform( + self._seeds_cooperative_exchange_2 + ) + datapipe = datapipe.buffer() + datapipe = datapipe.transform( + self._seeds_cooperative_exchange_3 + ) + datapipe = datapipe.buffer() + datapipe = datapipe.transform( + self._seeds_cooperative_exchange_4 + ) + super().__init__(datapipe) else: super().__init__(datapipe, self._compact_per_layer) @@ -498,19 +523,20 @@ class CompactPerLayer(MiniBatchTransformer): subgraph = minibatch.sampled_subgraphs[0] seeds = minibatch._seed_nodes assert self.deduplicate + rank = thd.get_rank() if self.cooperative else 0 + world_size = thd.get_world_size() if self.cooperative else 1 minibatch._future = unique_and_compact_csc_formats( - subgraph.sampled_csc, seeds, async_op=True + subgraph.sampled_csc, seeds, rank, world_size, async_op=True ) return minibatch - @staticmethod - def _compact_per_layer_wait_future(minibatch): + def _compact_per_layer_wait_future(self, minibatch): subgraph = minibatch.sampled_subgraphs[0] seeds = minibatch._seed_nodes ( original_row_node_ids, compacted_csc_format, - _, + seeds_offsets, ) = minibatch._future.wait() delattr(minibatch, "_future") subgraph = SampledSubgraphImpl( @@ -521,6 +547,87 @@ class CompactPerLayer(MiniBatchTransformer): ) minibatch._seed_nodes = original_row_node_ids minibatch.sampled_subgraphs[0] = subgraph + if self.cooperative: + subgraph._seeds_offsets = seeds_offsets + return minibatch + + @staticmethod + def _seeds_cooperative_exchange_1(minibatch): + world_size = thd.get_world_size() + subgraph = minibatch.sampled_subgraphs[0] + seeds_offsets = subgraph._seeds_offsets + is_homogeneous = not isinstance(seeds_offsets, dict) + if is_homogeneous: + seeds_offsets = {"_N": seeds_offsets} + num_ntypes = len(seeds_offsets) + counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64) + for i, offsets in enumerate(seeds_offsets.values()): + counts_sent[ + torch.arange(i, world_size * num_ntypes, num_ntypes) + ] = offsets.diff() + counts_received = torch.empty_like(counts_sent) + subgraph._counts_future = all_to_all( + counts_received.split(num_ntypes), + counts_sent.split(num_ntypes), + async_op=True, + ) + subgraph._counts_sent = counts_sent + subgraph._counts_received = counts_received + return minibatch + + @staticmethod + def _seeds_cooperative_exchange_2(minibatch): + world_size = thd.get_world_size() + seeds = minibatch._seed_nodes + is_homogenous = not isinstance(seeds, dict) + if is_homogenous: + seeds = {"_N": seeds} + subgraph = minibatch.sampled_subgraphs[0] + subgraph._counts_future.wait() + delattr(subgraph, "_counts_future") + num_ntypes = len(seeds.keys()) + seeds_received = {} + counts_sent = {} + counts_received = {} + for i, (ntype, typed_seeds) in enumerate(seeds.items()): + idx = torch.arange(i, world_size * num_ntypes, num_ntypes) + typed_counts_sent = subgraph._counts_sent[idx].tolist() + typed_counts_received = subgraph._counts_received[idx].tolist() + typed_seeds_received = typed_seeds.new_empty( + sum(typed_counts_received) + ) + all_to_all( + typed_seeds_received.split(typed_counts_received), + typed_seeds.split(typed_counts_sent), + ) + seeds_received[ntype] = typed_seeds_received + subgraph._seeds_received = seeds_received + subgraph._counts_sent = revert_to_homo(counts_sent) + subgraph._counts_received = revert_to_homo(counts_received) + return minibatch + + @staticmethod + def _seeds_cooperative_exchange_3(minibatch): + subgraph = minibatch.sampled_subgraphs[0] + nodes = { + ntype: [typed_seeds] + for ntype, typed_seeds in subgraph._seeds_received.items() + } + minibatch._unique_future = unique_and_compact( + nodes, 0, 1, async_op=True + ) + return minibatch + + @staticmethod + def _seeds_cooperative_exchange_4(minibatch): + unique_seeds, inverse_seeds, _ = minibatch._unique_future.wait() + delattr(minibatch, "_unique_future") + inverse_seeds = { + ntype: typed_inv[0] for ntype, typed_inv in inverse_seeds.items() + } + minibatch._seed_nodes = revert_to_homo(unique_seeds) + subgraph = minibatch.sampled_subgraphs[0] + subgraph._seed_inverse_ids = revert_to_homo(inverse_seeds) return minibatch @@ -541,6 +648,7 @@ class NeighborSamplerImpl(SubgraphSampler): overlap_fetch, num_gpu_cached_edges, gpu_cache_threshold, + cooperative, asynchronous, layer_dependency=None, batch_dependency=None, @@ -561,6 +669,7 @@ class NeighborSamplerImpl(SubgraphSampler): deduplicate, sampler, overlap_fetch, + cooperative=cooperative, asynchronous=asynchronous, layer_dependency=layer_dependency, ) @@ -637,6 +746,7 @@ class NeighborSamplerImpl(SubgraphSampler): deduplicate, sampler, overlap_fetch, + cooperative, asynchronous, layer_dependency, ): @@ -653,7 +763,9 @@ class NeighborSamplerImpl(SubgraphSampler): datapipe = datapipe.sample_per_layer( sampler, fanout, replace, prob_name, overlap_fetch, asynchronous ) - datapipe = datapipe.compact_per_layer(deduplicate, asynchronous) + datapipe = datapipe.compact_per_layer( + deduplicate, cooperative, asynchronous + ) if is_labor and not layer_dependency: datapipe = datapipe.transform(self._increment_seed) if is_labor: @@ -775,6 +887,7 @@ class NeighborSampler(NeighborSamplerImpl): overlap_fetch=False, num_gpu_cached_edges=0, gpu_cache_threshold=1, + cooperative=False, asynchronous=False, ): super().__init__( @@ -788,6 +901,7 @@ class NeighborSampler(NeighborSamplerImpl): overlap_fetch, num_gpu_cached_edges, gpu_cache_threshold, + cooperative, asynchronous, ) @@ -937,6 +1051,7 @@ class LayerNeighborSampler(NeighborSamplerImpl): overlap_fetch=False, num_gpu_cached_edges=0, gpu_cache_threshold=1, + cooperative=False, asynchronous=False, ): super().__init__( @@ -950,6 +1065,7 @@ class LayerNeighborSampler(NeighborSamplerImpl): overlap_fetch, num_gpu_cached_edges, gpu_cache_threshold, + cooperative, asynchronous, layer_dependency, batch_dependency, diff --git a/python/dgl/graphbolt/subgraph_sampler.py b/python/dgl/graphbolt/subgraph_sampler.py index 88fdd38087..556950982f 100644 --- a/python/dgl/graphbolt/subgraph_sampler.py +++ b/python/dgl/graphbolt/subgraph_sampler.py @@ -15,6 +15,8 @@ from .minibatch_transformer import MiniBatchTransformer __all__ = [ "SubgraphSampler", + "all_to_all", + "revert_to_homo", ] @@ -41,10 +43,48 @@ def all_to_all(outputs, inputs, group=None, async_op=False): `rank, ..., world_size - 1, 0, ..., rank - 1` and we make it `0, world_size - 1` before calling `thd.all_to_all`.""" shift_fn = partial(_shift, group=group) - return thd.all_to_all(shift_fn(outputs), shift_fn(inputs), group, async_op) + outputs = shift_fn(list(outputs)) + inputs = shift_fn(list(inputs)) + if outputs[0].is_cuda: + return thd.all_to_all(outputs, inputs, group, async_op) + # gloo backend will be used. + outputs_single = torch.cat(outputs) + output_split_sizes = [o.size(0) for o in outputs] + handle = thd.all_to_all_single( + outputs_single, + torch.cat(inputs), + output_split_sizes, + [i.size(0) for i in inputs], + group, + async_op, + ) + temp_outputs = outputs_single.split(output_split_sizes) + + class _Waiter: + def __init__(self, handle, outputs, temp_outputs): + self.handle = handle + self.outputs = outputs + self.temp_outputs = temp_outputs + + def wait(self): + """Returns the stored value when invoked.""" + handle = self.handle + outputs = self.outputs + temp_outputs = self.temp_outputs + # Ensure that there is no leak + self.handle = self.outputs = self.temp_outputs = None + + if handle is not None: + handle.wait() + for output, temp_output in zip(outputs, temp_outputs): + output.copy_(temp_output) + + post_processor = _Waiter(handle, outputs, temp_outputs) + return post_processor if async_op else post_processor.wait() -def _revert_to_homo(d: dict): +def revert_to_homo(d: dict): + """Utility function to convert a dictionary that stores homogenous data.""" is_homogenous = len(d) == 1 and "_N" in d return list(d.values())[0] if is_homogenous else d @@ -148,45 +188,31 @@ class SubgraphSampler(MiniBatchTransformer): def _seeds_cooperative_exchange_1(minibatch, group=None): rank = thd.get_rank(group) world_size = thd.get_world_size(group) - assert world_size > 1 seeds = minibatch._seed_nodes is_homogeneous = not isinstance(seeds, dict) if is_homogeneous: seeds = {"_N": seeds} if minibatch._seeds_offsets is None: seeds_list = list(seeds.values()) - ( - sorted_seeds_list, - index_list, - offsets_list, - ) = torch.ops.graphbolt.rank_sort(seeds_list, rank, world_size) + result = torch.ops.graphbolt.rank_sort(seeds_list, rank, world_size) assert minibatch.compacted_seeds is None sorted_seeds, sorted_compacted, sorted_offsets = {}, {}, {} num_ntypes = len(seeds.keys()) for i, ( seed_type, - typed_sorted_seeds, - typed_index, - typed_offsets, - ) in enumerate( - zip( - seeds.keys(), - sorted_seeds_list, - index_list, - offsets_list, - ) - ): + (typed_sorted_seeds, typed_index, typed_offsets), + ) in enumerate(zip(seeds.keys(), result)): sorted_seeds[seed_type] = typed_sorted_seeds sorted_compacted[seed_type] = typed_index - sorted_offsets[seed_type] = typed_offsets.tolist() + sorted_offsets[seed_type] = typed_offsets minibatch._seed_nodes = sorted_seeds - minibatch.compacted_seeds = sorted_compacted + minibatch.compacted_seeds = revert_to_homo(sorted_compacted) minibatch._seeds_offsets = sorted_offsets else: minibatch._seeds_offsets = {"_N": minibatch._seeds_offsets} counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64) - for i, offsets in enumerate(minibatch._seeds_offsets[0].values()): + for i, offsets in enumerate(minibatch._seeds_offsets.values()): counts_sent[ torch.arange(i, world_size * num_ntypes, num_ntypes) ] = offsets.diff() @@ -208,7 +234,6 @@ class SubgraphSampler(MiniBatchTransformer): seeds = minibatch._seed_nodes minibatch._counts_future.wait() delattr(minibatch, "_counts_future") - counts_received = minibatch._counts_received num_ntypes = len(seeds.keys()) seeds_received = {} counts_sent = {} @@ -226,15 +251,19 @@ class SubgraphSampler(MiniBatchTransformer): group, ) seeds_received[ntype] = typed_seeds_received - minibatch._seed_nodes = _revert_to_homo(seeds_received) - minibatch._counts_sent = _revert_to_homo(counts_sent) - minibatch._counts_received = _revert_to_homo(counts_received) + minibatch._seed_nodes = seeds_received + minibatch._counts_sent = revert_to_homo(counts_sent) + minibatch._counts_received = revert_to_homo(counts_received) return minibatch @staticmethod def _seeds_cooperative_exchange_3(minibatch): + nodes = { + ntype: [typed_seeds] + for ntype, typed_seeds in minibatch._seed_nodes.items() + } minibatch._unique_future = unique_and_compact( - minibatch._seed_nodes, 0, 1, async_op=True + nodes, 0, 1, async_op=True ) return minibatch @@ -242,8 +271,11 @@ class SubgraphSampler(MiniBatchTransformer): def _seeds_cooperative_exchange_4(minibatch): unique_seeds, inverse_seeds, _ = minibatch._unique_future.wait() delattr(minibatch, "_unique_future") - minibatch._seed_nodes = _revert_to_homo(unique_seeds) - minibatch._seed_inverse_ids = _revert_to_homo(inverse_seeds) + inverse_seeds = { + ntype: typed_inv[0] for ntype, typed_inv in inverse_seeds.items() + } + minibatch._seed_nodes = revert_to_homo(unique_seeds) + minibatch._seed_inverse_ids = revert_to_homo(inverse_seeds) return minibatch def _sample(self, minibatch): diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py index 666ab352d2..b02c820dd6 100644 --- a/tests/python/pytorch/graphbolt/test_dataloader.py +++ b/tests/python/pytorch/graphbolt/test_dataloader.py @@ -1,4 +1,6 @@ +import os import unittest +from sys import platform import backend as F @@ -6,6 +8,7 @@ import dgl import dgl.graphbolt import pytest import torch +import torch.distributed as thd from dgl.graphbolt.datapipes import find_dps, traverse_dps @@ -63,6 +66,7 @@ def test_DataLoader(overlap_feature_fetch): @pytest.mark.parametrize("enable_feature_fetch", [True, False]) @pytest.mark.parametrize("overlap_feature_fetch", [True, False]) @pytest.mark.parametrize("overlap_graph_fetch", [True, False]) +@pytest.mark.parametrize("cooperative", [True, False]) @pytest.mark.parametrize("asynchronous", [True, False]) @pytest.mark.parametrize("num_gpu_cached_edges", [0, 1024]) @pytest.mark.parametrize("gpu_cache_threshold", [1, 3]) @@ -71,10 +75,23 @@ def test_gpu_sampling_DataLoader( enable_feature_fetch, overlap_feature_fetch, overlap_graph_fetch, + cooperative, asynchronous, num_gpu_cached_edges, gpu_cache_threshold, ): + if cooperative and not thd.is_initialized(): + # On Windows, the init method can only be file. + init_method = ( + f"file:///{os.path.join(os.getcwd(), 'dis_tempfile')}" + if platform == "win32" + else "tcp://127.0.0.1:12345" + ) + thd.init_process_group( + init_method=init_method, + world_size=1, + rank=0, + ) N = 40 B = 4 num_layers = 2 @@ -110,6 +127,7 @@ def test_gpu_sampling_DataLoader( "overlap_fetch": overlap_graph_fetch, "num_gpu_cached_edges": num_gpu_cached_edges, "gpu_cache_threshold": gpu_cache_threshold, + "cooperative": cooperative, "asynchronous": asynchronous, } if i != 0: @@ -118,7 +136,7 @@ def test_gpu_sampling_DataLoader( datapipe, graph, fanouts=[torch.LongTensor([2]) for _ in range(num_layers)], - **kwargs + **kwargs, ) if enable_feature_fetch: datapipe = dgl.graphbolt.FeatureFetcher( @@ -138,6 +156,11 @@ def test_gpu_sampling_DataLoader( bufferer_cnt += 2 * num_layers if asynchronous: bufferer_cnt += 2 * num_layers + 1 # _preprocess stage has 1. + if cooperative: + bufferer_cnt += 3 * num_layers + if cooperative: + # _preprocess stage and each sampling layer. + bufferer_cnt += 3 datapipe_graph = traverse_dps(dataloader) bufferers = find_dps( datapipe_graph, @@ -171,3 +194,5 @@ def test_gpu_sampling_DataLoader( if sampler_name == "LayerNeighborSampler": assert torch.equal(edge_feature, edge_feature_ref) assert len(list(dataloader)) == N // B + if thd.is_initialized(): + thd.destroy_process_group()