mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt][CUDA] Get world_size=1 somewhat for cooperative sampling. (#7796)
This commit is contained in:
committed by
GitHub
parent
165e2507e7
commit
189b83c28c
@@ -3,6 +3,7 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.distributed as thd
|
||||
from torch.utils.data import functional_datapipe
|
||||
from torch.utils.data.datapipes.iter import Mapper
|
||||
|
||||
@@ -12,10 +13,14 @@ from ..base import (
|
||||
index_select,
|
||||
ORIGINAL_EDGE_ID,
|
||||
)
|
||||
from ..internal import compact_csc_format, unique_and_compact_csc_formats
|
||||
from ..internal import (
|
||||
compact_csc_format,
|
||||
unique_and_compact,
|
||||
unique_and_compact_csc_formats,
|
||||
)
|
||||
from ..minibatch_transformer import MiniBatchTransformer
|
||||
|
||||
from ..subgraph_sampler import SubgraphSampler
|
||||
from ..subgraph_sampler import all_to_all, revert_to_homo, SubgraphSampler
|
||||
from .fused_csc_sampling_graph import fused_csc_sampling_graph
|
||||
from .sampled_subgraph_impl import SampledSubgraphImpl
|
||||
|
||||
@@ -455,12 +460,32 @@ class SamplePerLayer(MiniBatchTransformer):
|
||||
class CompactPerLayer(MiniBatchTransformer):
|
||||
"""Compact the sampled edges for a single layer."""
|
||||
|
||||
def __init__(self, datapipe, deduplicate, asynchronous=False):
|
||||
def __init__(
|
||||
self, datapipe, deduplicate, cooperative=False, asynchronous=False
|
||||
):
|
||||
self.deduplicate = deduplicate
|
||||
self.cooperative = cooperative
|
||||
if asynchronous and deduplicate:
|
||||
datapipe = datapipe.transform(self._compact_per_layer_async)
|
||||
datapipe = datapipe.buffer()
|
||||
super().__init__(datapipe, self._compact_per_layer_wait_future)
|
||||
datapipe = datapipe.transform(self._compact_per_layer_wait_future)
|
||||
if cooperative:
|
||||
datapipe = datapipe.transform(
|
||||
self._seeds_cooperative_exchange_1
|
||||
)
|
||||
datapipe = datapipe.buffer()
|
||||
datapipe = datapipe.transform(
|
||||
self._seeds_cooperative_exchange_2
|
||||
)
|
||||
datapipe = datapipe.buffer()
|
||||
datapipe = datapipe.transform(
|
||||
self._seeds_cooperative_exchange_3
|
||||
)
|
||||
datapipe = datapipe.buffer()
|
||||
datapipe = datapipe.transform(
|
||||
self._seeds_cooperative_exchange_4
|
||||
)
|
||||
super().__init__(datapipe)
|
||||
else:
|
||||
super().__init__(datapipe, self._compact_per_layer)
|
||||
|
||||
@@ -498,19 +523,20 @@ class CompactPerLayer(MiniBatchTransformer):
|
||||
subgraph = minibatch.sampled_subgraphs[0]
|
||||
seeds = minibatch._seed_nodes
|
||||
assert self.deduplicate
|
||||
rank = thd.get_rank() if self.cooperative else 0
|
||||
world_size = thd.get_world_size() if self.cooperative else 1
|
||||
minibatch._future = unique_and_compact_csc_formats(
|
||||
subgraph.sampled_csc, seeds, async_op=True
|
||||
subgraph.sampled_csc, seeds, rank, world_size, async_op=True
|
||||
)
|
||||
return minibatch
|
||||
|
||||
@staticmethod
|
||||
def _compact_per_layer_wait_future(minibatch):
|
||||
def _compact_per_layer_wait_future(self, minibatch):
|
||||
subgraph = minibatch.sampled_subgraphs[0]
|
||||
seeds = minibatch._seed_nodes
|
||||
(
|
||||
original_row_node_ids,
|
||||
compacted_csc_format,
|
||||
_,
|
||||
seeds_offsets,
|
||||
) = minibatch._future.wait()
|
||||
delattr(minibatch, "_future")
|
||||
subgraph = SampledSubgraphImpl(
|
||||
@@ -521,6 +547,87 @@ class CompactPerLayer(MiniBatchTransformer):
|
||||
)
|
||||
minibatch._seed_nodes = original_row_node_ids
|
||||
minibatch.sampled_subgraphs[0] = subgraph
|
||||
if self.cooperative:
|
||||
subgraph._seeds_offsets = seeds_offsets
|
||||
return minibatch
|
||||
|
||||
@staticmethod
def _seeds_cooperative_exchange_1(minibatch):
    """Start the asynchronous exchange of per-rank seed counts.

    Reads ``_seeds_offsets`` from the first sampled subgraph, turns the
    offsets of every node type into counts via ``diff()`` and launches an
    asynchronous ``all_to_all`` of those counts. The pending handle and both
    count buffers are stored on the subgraph as ``_counts_future``,
    ``_counts_sent`` and ``_counts_received``.

    Parameters
    ----------
    minibatch
        The minibatch whose ``sampled_subgraphs[0]`` carries
        ``_seeds_offsets`` (either a single value or a dict keyed by node
        type).

    Returns
    -------
    The same ``minibatch`` object.
    """
    world_size = thd.get_world_size()
    layer = minibatch.sampled_subgraphs[0]
    offsets_by_type = layer._seeds_offsets
    # A non-dict value is wrapped under the "_N" key so that a single code
    # path handles both layouts.
    if not isinstance(offsets_by_type, dict):
        offsets_by_type = {"_N": offsets_by_type}
    num_ntypes = len(offsets_by_type)
    total = world_size * num_ntypes
    outgoing = torch.empty(total, dtype=torch.int64)
    for type_pos, typed_offsets in enumerate(offsets_by_type.values()):
        # Counts of one node type occupy every `num_ntypes`-th slot, so each
        # chunk of `num_ntypes` entries groups all node types together.
        # NOTE(review): presumably one chunk per destination rank, i.e. the
        # offsets hold world_size + 1 entries -- confirm against the producer.
        slots = torch.arange(type_pos, total, num_ntypes)
        outgoing[slots] = typed_offsets.diff()
    incoming = torch.empty_like(outgoing)
    layer._counts_future = all_to_all(
        incoming.split(num_ntypes),
        outgoing.split(num_ntypes),
        async_op=True,
    )
    layer._counts_sent = outgoing
    layer._counts_received = incoming
    return minibatch
|
||||
|
||||
@staticmethod
def _seeds_cooperative_exchange_2(minibatch):
    """Exchange the seed nodes themselves once the counts are known.

    Waits for the count exchange stored in ``subgraph._counts_future``, then
    for every node type performs a blocking ``all_to_all`` that sends the
    local seeds split by the sent counts and receives into a freshly
    allocated buffer split by the received counts.

    Afterwards the first sampled subgraph carries:

    * ``_seeds_received``: dict mapping node type to the received seeds,
    * ``_counts_sent`` / ``_counts_received``: per-rank count lists, as a
      dict keyed by node type, or the bare list when the only key is "_N".

    Parameters
    ----------
    minibatch
        The minibatch whose ``_seed_nodes`` (a single value or a dict keyed
        by node type) are to be exchanged.

    Returns
    -------
    The same ``minibatch`` object.
    """
    world_size = thd.get_world_size()
    seeds = minibatch._seed_nodes
    # A non-dict value is wrapped under the "_N" key so that a single code
    # path handles both layouts.
    if not isinstance(seeds, dict):
        seeds = {"_N": seeds}
    subgraph = minibatch.sampled_subgraphs[0]
    # The count buffers are only valid after the asynchronous exchange ends.
    subgraph._counts_future.wait()
    delattr(subgraph, "_counts_future")
    num_ntypes = len(seeds)
    seeds_received = {}
    counts_sent = {}
    counts_received = {}
    for i, (ntype, typed_seeds) in enumerate(seeds.items()):
        # Counts of one node type sit in every `num_ntypes`-th slot of the
        # flat count buffers.
        idx = torch.arange(i, world_size * num_ntypes, num_ntypes)
        typed_counts_sent = subgraph._counts_sent[idx].tolist()
        typed_counts_received = subgraph._counts_received[idx].tolist()
        typed_seeds_received = typed_seeds.new_empty(
            sum(typed_counts_received)
        )
        all_to_all(
            typed_seeds_received.split(typed_counts_received),
            typed_seeds.split(typed_counts_sent),
        )
        seeds_received[ntype] = typed_seeds_received
        # Record the per-type counts; without these assignments the dicts
        # stay empty and the count attributes below would be overwritten
        # with `{}`, discarding the split sizes.
        counts_sent[ntype] = typed_counts_sent
        counts_received[ntype] = typed_counts_received
    subgraph._seeds_received = seeds_received
    subgraph._counts_sent = revert_to_homo(counts_sent)
    subgraph._counts_received = revert_to_homo(counts_received)
    return minibatch
|
||||
|
||||
@staticmethod
def _seeds_cooperative_exchange_3(minibatch):
    """Launch asynchronous deduplication of the received seeds.

    Wraps the received seeds of every node type (taken from
    ``sampled_subgraphs[0]._seeds_received``) in a one-element list, passes
    the resulting dict to ``unique_and_compact`` with ``async_op=True`` and
    stores the returned future on the minibatch as ``_unique_future``.

    Parameters
    ----------
    minibatch
        The minibatch whose first sampled subgraph holds ``_seeds_received``.

    Returns
    -------
    The same ``minibatch`` object.
    """
    received = minibatch.sampled_subgraphs[0]._seeds_received
    nodes = {}
    for ntype, typed_seeds in received.items():
        nodes[ntype] = [typed_seeds]
    # The literal arguments 0 and 1 are passed positionally after the nodes.
    # NOTE(review): presumably rank and world_size -- confirm against the
    # signature of `unique_and_compact`.
    minibatch._unique_future = unique_and_compact(nodes, 0, 1, async_op=True)
    return minibatch
|
||||
|
||||
@staticmethod
def _seeds_cooperative_exchange_4(minibatch):
    """Collect the deduplication result and publish it on the minibatch.

    Waits on ``minibatch._unique_future`` and removes that attribute. The
    unique seeds become ``minibatch._seed_nodes``; the first element of each
    per-type inverse entry becomes
    ``sampled_subgraphs[0]._seed_inverse_ids``. Both are passed through
    ``revert_to_homo`` before being stored.

    Parameters
    ----------
    minibatch
        The minibatch carrying the pending ``_unique_future``.

    Returns
    -------
    The same ``minibatch`` object.
    """
    # The third element of the result is not used here.
    unique_seeds, inverse_lists, _ = minibatch._unique_future.wait()
    del minibatch._unique_future
    first_inverse = {}
    for ntype, typed_inv in inverse_lists.items():
        # Only element 0 of each per-type entry is kept.
        first_inverse[ntype] = typed_inv[0]
    minibatch._seed_nodes = revert_to_homo(unique_seeds)
    layer = minibatch.sampled_subgraphs[0]
    layer._seed_inverse_ids = revert_to_homo(first_inverse)
    return minibatch
|
||||
|
||||
|
||||
@@ -541,6 +648,7 @@ class NeighborSamplerImpl(SubgraphSampler):
|
||||
overlap_fetch,
|
||||
num_gpu_cached_edges,
|
||||
gpu_cache_threshold,
|
||||
cooperative,
|
||||
asynchronous,
|
||||
layer_dependency=None,
|
||||
batch_dependency=None,
|
||||
@@ -561,6 +669,7 @@ class NeighborSamplerImpl(SubgraphSampler):
|
||||
deduplicate,
|
||||
sampler,
|
||||
overlap_fetch,
|
||||
cooperative=cooperative,
|
||||
asynchronous=asynchronous,
|
||||
layer_dependency=layer_dependency,
|
||||
)
|
||||
@@ -637,6 +746,7 @@ class NeighborSamplerImpl(SubgraphSampler):
|
||||
deduplicate,
|
||||
sampler,
|
||||
overlap_fetch,
|
||||
cooperative,
|
||||
asynchronous,
|
||||
layer_dependency,
|
||||
):
|
||||
@@ -653,7 +763,9 @@ class NeighborSamplerImpl(SubgraphSampler):
|
||||
datapipe = datapipe.sample_per_layer(
|
||||
sampler, fanout, replace, prob_name, overlap_fetch, asynchronous
|
||||
)
|
||||
datapipe = datapipe.compact_per_layer(deduplicate, asynchronous)
|
||||
datapipe = datapipe.compact_per_layer(
|
||||
deduplicate, cooperative, asynchronous
|
||||
)
|
||||
if is_labor and not layer_dependency:
|
||||
datapipe = datapipe.transform(self._increment_seed)
|
||||
if is_labor:
|
||||
@@ -775,6 +887,7 @@ class NeighborSampler(NeighborSamplerImpl):
|
||||
overlap_fetch=False,
|
||||
num_gpu_cached_edges=0,
|
||||
gpu_cache_threshold=1,
|
||||
cooperative=False,
|
||||
asynchronous=False,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -788,6 +901,7 @@ class NeighborSampler(NeighborSamplerImpl):
|
||||
overlap_fetch,
|
||||
num_gpu_cached_edges,
|
||||
gpu_cache_threshold,
|
||||
cooperative,
|
||||
asynchronous,
|
||||
)
|
||||
|
||||
@@ -937,6 +1051,7 @@ class LayerNeighborSampler(NeighborSamplerImpl):
|
||||
overlap_fetch=False,
|
||||
num_gpu_cached_edges=0,
|
||||
gpu_cache_threshold=1,
|
||||
cooperative=False,
|
||||
asynchronous=False,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -950,6 +1065,7 @@ class LayerNeighborSampler(NeighborSamplerImpl):
|
||||
overlap_fetch,
|
||||
num_gpu_cached_edges,
|
||||
gpu_cache_threshold,
|
||||
cooperative,
|
||||
asynchronous,
|
||||
layer_dependency,
|
||||
batch_dependency,
|
||||
|
||||
@@ -15,6 +15,8 @@ from .minibatch_transformer import MiniBatchTransformer
|
||||
|
||||
__all__ = [
|
||||
"SubgraphSampler",
|
||||
"all_to_all",
|
||||
"revert_to_homo",
|
||||
]
|
||||
|
||||
|
||||
@@ -41,10 +43,48 @@ def all_to_all(outputs, inputs, group=None, async_op=False):
|
||||
`rank, ..., world_size - 1, 0, ..., rank - 1` and we make it
|
||||
`0, world_size - 1` before calling `thd.all_to_all`."""
|
||||
shift_fn = partial(_shift, group=group)
|
||||
return thd.all_to_all(shift_fn(outputs), shift_fn(inputs), group, async_op)
|
||||
outputs = shift_fn(list(outputs))
|
||||
inputs = shift_fn(list(inputs))
|
||||
if outputs[0].is_cuda:
|
||||
return thd.all_to_all(outputs, inputs, group, async_op)
|
||||
# gloo backend will be used.
|
||||
outputs_single = torch.cat(outputs)
|
||||
output_split_sizes = [o.size(0) for o in outputs]
|
||||
handle = thd.all_to_all_single(
|
||||
outputs_single,
|
||||
torch.cat(inputs),
|
||||
output_split_sizes,
|
||||
[i.size(0) for i in inputs],
|
||||
group,
|
||||
async_op,
|
||||
)
|
||||
temp_outputs = outputs_single.split(output_split_sizes)
|
||||
|
||||
class _Waiter:
    """Deferred completion step for a staged collective.

    Holds an optional communication handle together with the destination
    tensors and the temporary tensors whose contents must be copied into
    them once the communication has finished.
    """

    def __init__(self, handle, outputs, temp_outputs):
        self.handle = handle
        self.outputs = outputs
        self.temp_outputs = temp_outputs

    def wait(self):
        """Wait on the handle, if any, then copy each temporary output into
        its destination tensor."""
        handle, outputs, temp_outputs = (
            self.handle,
            self.outputs,
            self.temp_outputs,
        )
        # Drop the stored references before doing any work so that this
        # object does not keep the tensors or the handle alive afterwards.
        self.handle = self.outputs = self.temp_outputs = None

        if handle is not None:
            handle.wait()
        for destination, staged in zip(outputs, temp_outputs):
            destination.copy_(staged)
|
||||
|
||||
post_processor = _Waiter(handle, outputs, temp_outputs)
|
||||
return post_processor if async_op else post_processor.wait()
|
||||
|
||||
|
||||
def _revert_to_homo(d: dict):
|
||||
def revert_to_homo(d: dict):
    """Unwrap a dictionary that only stores homogeneous data.

    When the dictionary has exactly one entry and its key is "_N", the value
    stored under that key is returned; any other dictionary is returned
    unchanged.
    """
    if len(d) == 1 and "_N" in d:
        return d["_N"]
    return d
|
||||
|
||||
@@ -148,45 +188,31 @@ class SubgraphSampler(MiniBatchTransformer):
|
||||
def _seeds_cooperative_exchange_1(minibatch, group=None):
|
||||
rank = thd.get_rank(group)
|
||||
world_size = thd.get_world_size(group)
|
||||
assert world_size > 1
|
||||
seeds = minibatch._seed_nodes
|
||||
is_homogeneous = not isinstance(seeds, dict)
|
||||
if is_homogeneous:
|
||||
seeds = {"_N": seeds}
|
||||
if minibatch._seeds_offsets is None:
|
||||
seeds_list = list(seeds.values())
|
||||
(
|
||||
sorted_seeds_list,
|
||||
index_list,
|
||||
offsets_list,
|
||||
) = torch.ops.graphbolt.rank_sort(seeds_list, rank, world_size)
|
||||
result = torch.ops.graphbolt.rank_sort(seeds_list, rank, world_size)
|
||||
assert minibatch.compacted_seeds is None
|
||||
sorted_seeds, sorted_compacted, sorted_offsets = {}, {}, {}
|
||||
num_ntypes = len(seeds.keys())
|
||||
for i, (
|
||||
seed_type,
|
||||
typed_sorted_seeds,
|
||||
typed_index,
|
||||
typed_offsets,
|
||||
) in enumerate(
|
||||
zip(
|
||||
seeds.keys(),
|
||||
sorted_seeds_list,
|
||||
index_list,
|
||||
offsets_list,
|
||||
)
|
||||
):
|
||||
(typed_sorted_seeds, typed_index, typed_offsets),
|
||||
) in enumerate(zip(seeds.keys(), result)):
|
||||
sorted_seeds[seed_type] = typed_sorted_seeds
|
||||
sorted_compacted[seed_type] = typed_index
|
||||
sorted_offsets[seed_type] = typed_offsets.tolist()
|
||||
sorted_offsets[seed_type] = typed_offsets
|
||||
|
||||
minibatch._seed_nodes = sorted_seeds
|
||||
minibatch.compacted_seeds = sorted_compacted
|
||||
minibatch.compacted_seeds = revert_to_homo(sorted_compacted)
|
||||
minibatch._seeds_offsets = sorted_offsets
|
||||
else:
|
||||
minibatch._seeds_offsets = {"_N": minibatch._seeds_offsets}
|
||||
counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)
|
||||
for i, offsets in enumerate(minibatch._seeds_offsets[0].values()):
|
||||
for i, offsets in enumerate(minibatch._seeds_offsets.values()):
|
||||
counts_sent[
|
||||
torch.arange(i, world_size * num_ntypes, num_ntypes)
|
||||
] = offsets.diff()
|
||||
@@ -208,7 +234,6 @@ class SubgraphSampler(MiniBatchTransformer):
|
||||
seeds = minibatch._seed_nodes
|
||||
minibatch._counts_future.wait()
|
||||
delattr(minibatch, "_counts_future")
|
||||
counts_received = minibatch._counts_received
|
||||
num_ntypes = len(seeds.keys())
|
||||
seeds_received = {}
|
||||
counts_sent = {}
|
||||
@@ -226,15 +251,19 @@ class SubgraphSampler(MiniBatchTransformer):
|
||||
group,
|
||||
)
|
||||
seeds_received[ntype] = typed_seeds_received
|
||||
minibatch._seed_nodes = _revert_to_homo(seeds_received)
|
||||
minibatch._counts_sent = _revert_to_homo(counts_sent)
|
||||
minibatch._counts_received = _revert_to_homo(counts_received)
|
||||
minibatch._seed_nodes = seeds_received
|
||||
minibatch._counts_sent = revert_to_homo(counts_sent)
|
||||
minibatch._counts_received = revert_to_homo(counts_received)
|
||||
return minibatch
|
||||
|
||||
@staticmethod
|
||||
def _seeds_cooperative_exchange_3(minibatch):
|
||||
nodes = {
|
||||
ntype: [typed_seeds]
|
||||
for ntype, typed_seeds in minibatch._seed_nodes.items()
|
||||
}
|
||||
minibatch._unique_future = unique_and_compact(
|
||||
minibatch._seed_nodes, 0, 1, async_op=True
|
||||
nodes, 0, 1, async_op=True
|
||||
)
|
||||
return minibatch
|
||||
|
||||
@@ -242,8 +271,11 @@ class SubgraphSampler(MiniBatchTransformer):
|
||||
def _seeds_cooperative_exchange_4(minibatch):
|
||||
unique_seeds, inverse_seeds, _ = minibatch._unique_future.wait()
|
||||
delattr(minibatch, "_unique_future")
|
||||
minibatch._seed_nodes = _revert_to_homo(unique_seeds)
|
||||
minibatch._seed_inverse_ids = _revert_to_homo(inverse_seeds)
|
||||
inverse_seeds = {
|
||||
ntype: typed_inv[0] for ntype, typed_inv in inverse_seeds.items()
|
||||
}
|
||||
minibatch._seed_nodes = revert_to_homo(unique_seeds)
|
||||
minibatch._seed_inverse_ids = revert_to_homo(inverse_seeds)
|
||||
return minibatch
|
||||
|
||||
def _sample(self, minibatch):
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import os
|
||||
import unittest
|
||||
from sys import platform
|
||||
|
||||
import backend as F
|
||||
|
||||
@@ -6,6 +8,7 @@ import dgl
|
||||
import dgl.graphbolt
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as thd
|
||||
|
||||
from dgl.graphbolt.datapipes import find_dps, traverse_dps
|
||||
|
||||
@@ -63,6 +66,7 @@ def test_DataLoader(overlap_feature_fetch):
|
||||
@pytest.mark.parametrize("enable_feature_fetch", [True, False])
|
||||
@pytest.mark.parametrize("overlap_feature_fetch", [True, False])
|
||||
@pytest.mark.parametrize("overlap_graph_fetch", [True, False])
|
||||
@pytest.mark.parametrize("cooperative", [True, False])
|
||||
@pytest.mark.parametrize("asynchronous", [True, False])
|
||||
@pytest.mark.parametrize("num_gpu_cached_edges", [0, 1024])
|
||||
@pytest.mark.parametrize("gpu_cache_threshold", [1, 3])
|
||||
@@ -71,10 +75,23 @@ def test_gpu_sampling_DataLoader(
|
||||
enable_feature_fetch,
|
||||
overlap_feature_fetch,
|
||||
overlap_graph_fetch,
|
||||
cooperative,
|
||||
asynchronous,
|
||||
num_gpu_cached_edges,
|
||||
gpu_cache_threshold,
|
||||
):
|
||||
if cooperative and not thd.is_initialized():
|
||||
# On Windows, the init method can only be file.
|
||||
init_method = (
|
||||
f"file:///{os.path.join(os.getcwd(), 'dis_tempfile')}"
|
||||
if platform == "win32"
|
||||
else "tcp://127.0.0.1:12345"
|
||||
)
|
||||
thd.init_process_group(
|
||||
init_method=init_method,
|
||||
world_size=1,
|
||||
rank=0,
|
||||
)
|
||||
N = 40
|
||||
B = 4
|
||||
num_layers = 2
|
||||
@@ -110,6 +127,7 @@ def test_gpu_sampling_DataLoader(
|
||||
"overlap_fetch": overlap_graph_fetch,
|
||||
"num_gpu_cached_edges": num_gpu_cached_edges,
|
||||
"gpu_cache_threshold": gpu_cache_threshold,
|
||||
"cooperative": cooperative,
|
||||
"asynchronous": asynchronous,
|
||||
}
|
||||
if i != 0:
|
||||
@@ -118,7 +136,7 @@ def test_gpu_sampling_DataLoader(
|
||||
datapipe,
|
||||
graph,
|
||||
fanouts=[torch.LongTensor([2]) for _ in range(num_layers)],
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
if enable_feature_fetch:
|
||||
datapipe = dgl.graphbolt.FeatureFetcher(
|
||||
@@ -138,6 +156,11 @@ def test_gpu_sampling_DataLoader(
|
||||
bufferer_cnt += 2 * num_layers
|
||||
if asynchronous:
|
||||
bufferer_cnt += 2 * num_layers + 1 # _preprocess stage has 1.
|
||||
if cooperative:
|
||||
bufferer_cnt += 3 * num_layers
|
||||
if cooperative:
|
||||
# _preprocess stage and each sampling layer.
|
||||
bufferer_cnt += 3
|
||||
datapipe_graph = traverse_dps(dataloader)
|
||||
bufferers = find_dps(
|
||||
datapipe_graph,
|
||||
@@ -171,3 +194,5 @@ def test_gpu_sampling_DataLoader(
|
||||
if sampler_name == "LayerNeighborSampler":
|
||||
assert torch.equal(edge_feature, edge_feature_ref)
|
||||
assert len(list(dataloader)) == N // B
|
||||
if thd.is_initialized():
|
||||
thd.destroy_process_group()
|
||||
|
||||
Reference in New Issue
Block a user