[GraphBolt][CUDA] Get world_size=1 somewhat working for cooperative sampling. (#7796)

This commit is contained in:
Muhammed Fatih BALIN
2024-09-12 17:12:31 -04:00
committed by GitHub
parent 165e2507e7
commit 189b83c28c
3 changed files with 212 additions and 39 deletions

View File

@@ -3,6 +3,7 @@
from functools import partial
import torch
import torch.distributed as thd
from torch.utils.data import functional_datapipe
from torch.utils.data.datapipes.iter import Mapper
@@ -12,10 +13,14 @@ from ..base import (
index_select,
ORIGINAL_EDGE_ID,
)
from ..internal import compact_csc_format, unique_and_compact_csc_formats
from ..internal import (
compact_csc_format,
unique_and_compact,
unique_and_compact_csc_formats,
)
from ..minibatch_transformer import MiniBatchTransformer
from ..subgraph_sampler import SubgraphSampler
from ..subgraph_sampler import all_to_all, revert_to_homo, SubgraphSampler
from .fused_csc_sampling_graph import fused_csc_sampling_graph
from .sampled_subgraph_impl import SampledSubgraphImpl
@@ -455,12 +460,32 @@ class SamplePerLayer(MiniBatchTransformer):
class CompactPerLayer(MiniBatchTransformer):
"""Compact the sampled edges for a single layer."""
def __init__(self, datapipe, deduplicate, asynchronous=False):
def __init__(
self, datapipe, deduplicate, cooperative=False, asynchronous=False
):
self.deduplicate = deduplicate
self.cooperative = cooperative
if asynchronous and deduplicate:
datapipe = datapipe.transform(self._compact_per_layer_async)
datapipe = datapipe.buffer()
super().__init__(datapipe, self._compact_per_layer_wait_future)
datapipe = datapipe.transform(self._compact_per_layer_wait_future)
if cooperative:
datapipe = datapipe.transform(
self._seeds_cooperative_exchange_1
)
datapipe = datapipe.buffer()
datapipe = datapipe.transform(
self._seeds_cooperative_exchange_2
)
datapipe = datapipe.buffer()
datapipe = datapipe.transform(
self._seeds_cooperative_exchange_3
)
datapipe = datapipe.buffer()
datapipe = datapipe.transform(
self._seeds_cooperative_exchange_4
)
super().__init__(datapipe)
else:
super().__init__(datapipe, self._compact_per_layer)
@@ -498,19 +523,20 @@ class CompactPerLayer(MiniBatchTransformer):
subgraph = minibatch.sampled_subgraphs[0]
seeds = minibatch._seed_nodes
assert self.deduplicate
rank = thd.get_rank() if self.cooperative else 0
world_size = thd.get_world_size() if self.cooperative else 1
minibatch._future = unique_and_compact_csc_formats(
subgraph.sampled_csc, seeds, async_op=True
subgraph.sampled_csc, seeds, rank, world_size, async_op=True
)
return minibatch
@staticmethod
def _compact_per_layer_wait_future(minibatch):
def _compact_per_layer_wait_future(self, minibatch):
subgraph = minibatch.sampled_subgraphs[0]
seeds = minibatch._seed_nodes
(
original_row_node_ids,
compacted_csc_format,
_,
seeds_offsets,
) = minibatch._future.wait()
delattr(minibatch, "_future")
subgraph = SampledSubgraphImpl(
@@ -521,6 +547,87 @@ class CompactPerLayer(MiniBatchTransformer):
)
minibatch._seed_nodes = original_row_node_ids
minibatch.sampled_subgraphs[0] = subgraph
if self.cooperative:
subgraph._seeds_offsets = seeds_offsets
return minibatch
@staticmethod
def _seeds_cooperative_exchange_1(minibatch):
    """Start the asynchronous exchange of per-type seed counts.

    Reads ``_seeds_offsets`` from the first sampled subgraph, converts the
    offsets of every node type into counts via ``diff()`` and launches an
    asynchronous ``all_to_all`` on those counts.

    The following attributes are stored on the subgraph for the next
    stage: ``_counts_future`` (handle returned by ``all_to_all``),
    ``_counts_sent`` and ``_counts_received`` (int64 tensors of length
    ``world_size * num_ntypes``).

    Returns the same ``minibatch`` object.
    """
    world_size = thd.get_world_size()
    subgraph = minibatch.sampled_subgraphs[0]
    offsets_by_type = subgraph._seeds_offsets
    # A non-dict value is wrapped under the "_N" key so that a single code
    # path handles both cases.
    if not isinstance(offsets_by_type, dict):
        offsets_by_type = {"_N": offsets_by_type}
    num_ntypes = len(offsets_by_type)
    total_slots = world_size * num_ntypes
    counts_sent = torch.empty(total_slots, dtype=torch.int64)
    # Counts of node type number `type_pos` occupy every `num_ntypes`-th
    # slot starting at `type_pos`, so each consecutive chunk of
    # `num_ntypes` entries holds one count per node type.
    # NOTE(review): presumably each chunk corresponds to one rank and
    # `diff()` yields `world_size` values -- confirm against rank_sort.
    for type_pos, typed_offsets in enumerate(offsets_by_type.values()):
        slots = torch.arange(type_pos, total_slots, num_ntypes)
        counts_sent[slots] = typed_offsets.diff()
    counts_received = torch.empty_like(counts_sent)
    subgraph._counts_future = all_to_all(
        counts_received.split(num_ntypes),
        counts_sent.split(num_ntypes),
        async_op=True,
    )
    subgraph._counts_sent = counts_sent
    subgraph._counts_received = counts_received
    return minibatch
@staticmethod
def _seeds_cooperative_exchange_2(minibatch):
    """Exchange the seed node ids themselves, one node type at a time.

    Waits for the count exchange stored in ``subgraph._counts_future`` and
    then, for every node type, calls ``all_to_all`` with
    ``minibatch._seed_nodes`` split by the sent counts as input and a newly
    allocated tensor split by the received counts as output.

    Effects on ``minibatch.sampled_subgraphs[0]``:

    * ``_counts_future`` is removed.
    * ``_seeds_received`` is set to a dict mapping node type to the
      received seed tensor.
    * ``_counts_sent`` / ``_counts_received`` are replaced by the per-type
      count lists (passed through ``revert_to_homo``).

    Returns the same ``minibatch`` object.
    """
    world_size = thd.get_world_size()
    seeds = minibatch._seed_nodes
    # A non-dict value is wrapped under the "_N" key so that a single code
    # path handles both cases.
    if not isinstance(seeds, dict):
        seeds = {"_N": seeds}
    subgraph = minibatch.sampled_subgraphs[0]
    # The counts must have arrived before they can be read below.
    subgraph._counts_future.wait()
    delattr(subgraph, "_counts_future")
    num_ntypes = len(seeds)
    seeds_received = {}
    counts_sent = {}
    counts_received = {}
    for i, (ntype, typed_seeds) in enumerate(seeds.items()):
        # Counts of the i-th node type sit in every `num_ntypes`-th slot.
        idx = torch.arange(i, world_size * num_ntypes, num_ntypes)
        typed_counts_sent = subgraph._counts_sent[idx].tolist()
        typed_counts_received = subgraph._counts_received[idx].tolist()
        typed_seeds_received = typed_seeds.new_empty(
            sum(typed_counts_received)
        )
        all_to_all(
            typed_seeds_received.split(typed_counts_received),
            typed_seeds.split(typed_counts_sent),
        )
        seeds_received[ntype] = typed_seeds_received
        # Record the per-type counts; without this the dictionaries stay
        # empty and the assignments below would discard the counts.
        counts_sent[ntype] = typed_counts_sent
        counts_received[ntype] = typed_counts_received
    subgraph._seeds_received = seeds_received
    subgraph._counts_sent = revert_to_homo(counts_sent)
    subgraph._counts_received = revert_to_homo(counts_received)
    return minibatch
@staticmethod
def _seeds_cooperative_exchange_3(minibatch):
    """Launch the asynchronous deduplication of the received seeds.

    Takes ``_seeds_received`` from the first sampled subgraph, wraps every
    per-type tensor in a one-element list and passes the result to
    ``unique_and_compact`` with ``async_op=True``. The returned future is
    stored as ``minibatch._unique_future``.

    Returns the same ``minibatch`` object.
    """
    received = minibatch.sampled_subgraphs[0]._seeds_received
    nodes = {}
    for ntype, typed_seeds in received.items():
        nodes[ntype] = [typed_seeds]
    # NOTE(review): the positional 0 and 1 are presumably rank and
    # world_size, i.e. a purely local unique -- confirm against the
    # signature of unique_and_compact.
    minibatch._unique_future = unique_and_compact(nodes, 0, 1, async_op=True)
    return minibatch
@staticmethod
def _seeds_cooperative_exchange_4(minibatch):
    """Collect the result of the asynchronous seed deduplication.

    Waits on ``minibatch._unique_future`` and removes that attribute. The
    unique seeds become ``minibatch._seed_nodes`` and the first inverse
    mapping of every node type is stored on the first sampled subgraph as
    ``_seed_inverse_ids``; both go through ``revert_to_homo`` first.

    Returns the same ``minibatch`` object.
    """
    unique_seeds, inverse_lists, _ = minibatch._unique_future.wait()
    delattr(minibatch, "_unique_future")
    # Each node type carries a sequence of inverse mappings; only the
    # first entry is kept.
    first_inverse = {}
    for ntype, typed_inv in inverse_lists.items():
        first_inverse[ntype] = typed_inv[0]
    minibatch._seed_nodes = revert_to_homo(unique_seeds)
    minibatch.sampled_subgraphs[0]._seed_inverse_ids = revert_to_homo(
        first_inverse
    )
    return minibatch
@@ -541,6 +648,7 @@ class NeighborSamplerImpl(SubgraphSampler):
overlap_fetch,
num_gpu_cached_edges,
gpu_cache_threshold,
cooperative,
asynchronous,
layer_dependency=None,
batch_dependency=None,
@@ -561,6 +669,7 @@ class NeighborSamplerImpl(SubgraphSampler):
deduplicate,
sampler,
overlap_fetch,
cooperative=cooperative,
asynchronous=asynchronous,
layer_dependency=layer_dependency,
)
@@ -637,6 +746,7 @@ class NeighborSamplerImpl(SubgraphSampler):
deduplicate,
sampler,
overlap_fetch,
cooperative,
asynchronous,
layer_dependency,
):
@@ -653,7 +763,9 @@ class NeighborSamplerImpl(SubgraphSampler):
datapipe = datapipe.sample_per_layer(
sampler, fanout, replace, prob_name, overlap_fetch, asynchronous
)
datapipe = datapipe.compact_per_layer(deduplicate, asynchronous)
datapipe = datapipe.compact_per_layer(
deduplicate, cooperative, asynchronous
)
if is_labor and not layer_dependency:
datapipe = datapipe.transform(self._increment_seed)
if is_labor:
@@ -775,6 +887,7 @@ class NeighborSampler(NeighborSamplerImpl):
overlap_fetch=False,
num_gpu_cached_edges=0,
gpu_cache_threshold=1,
cooperative=False,
asynchronous=False,
):
super().__init__(
@@ -788,6 +901,7 @@ class NeighborSampler(NeighborSamplerImpl):
overlap_fetch,
num_gpu_cached_edges,
gpu_cache_threshold,
cooperative,
asynchronous,
)
@@ -937,6 +1051,7 @@ class LayerNeighborSampler(NeighborSamplerImpl):
overlap_fetch=False,
num_gpu_cached_edges=0,
gpu_cache_threshold=1,
cooperative=False,
asynchronous=False,
):
super().__init__(
@@ -950,6 +1065,7 @@ class LayerNeighborSampler(NeighborSamplerImpl):
overlap_fetch,
num_gpu_cached_edges,
gpu_cache_threshold,
cooperative,
asynchronous,
layer_dependency,
batch_dependency,

View File

@@ -15,6 +15,8 @@ from .minibatch_transformer import MiniBatchTransformer
__all__ = [
"SubgraphSampler",
"all_to_all",
"revert_to_homo",
]
@@ -41,10 +43,48 @@ def all_to_all(outputs, inputs, group=None, async_op=False):
`rank, ..., world_size - 1, 0, ..., rank - 1` and we make it
`0, world_size - 1` before calling `thd.all_to_all`."""
shift_fn = partial(_shift, group=group)
return thd.all_to_all(shift_fn(outputs), shift_fn(inputs), group, async_op)
outputs = shift_fn(list(outputs))
inputs = shift_fn(list(inputs))
if outputs[0].is_cuda:
return thd.all_to_all(outputs, inputs, group, async_op)
# gloo backend will be used.
outputs_single = torch.cat(outputs)
output_split_sizes = [o.size(0) for o in outputs]
handle = thd.all_to_all_single(
outputs_single,
torch.cat(inputs),
output_split_sizes,
[i.size(0) for i in inputs],
group,
async_op,
)
temp_outputs = outputs_single.split(output_split_sizes)
class _Waiter:
    """Completion handle that copies staged results into the real outputs.

    Parameters
    ----------
    handle
        Object with a ``wait()`` method, or ``None`` when there is nothing
        to wait for.
    outputs
        Sequence of tensors that must end up holding the results.
    temp_outputs
        Sequence of tensors, parallel to ``outputs``, holding the staged
        results once ``handle`` has completed.
    """

    def __init__(self, handle, outputs, temp_outputs):
        self.handle = handle
        self.outputs = outputs
        self.temp_outputs = temp_outputs

    def wait(self):
        """Wait for the handle, then copy each staged tensor to its output.

        Returns ``None``. The stored references are dropped on the first
        call, so calling ``wait()`` again is a no-op.
        """
        handle = self.handle
        outputs = self.outputs
        temp_outputs = self.temp_outputs
        # Ensure that there is no leak
        self.handle = self.outputs = self.temp_outputs = None
        if outputs is None:
            # Already completed by an earlier call; nothing left to copy.
            return
        if handle is not None:
            handle.wait()
        for output, temp_output in zip(outputs, temp_outputs):
            output.copy_(temp_output)
post_processor = _Waiter(handle, outputs, temp_outputs)
return post_processor if async_op else post_processor.wait()
def _revert_to_homo(d: dict):
def revert_to_homo(d: dict):
    """Unwrap a dictionary whose only key is ``"_N"``.

    Returns ``d["_N"]`` when ``d`` has exactly one entry and that entry is
    keyed by ``"_N"``; otherwise returns ``d`` itself unchanged.
    """
    if len(d) == 1 and "_N" in d:
        return d["_N"]
    return d
@@ -148,45 +188,31 @@ class SubgraphSampler(MiniBatchTransformer):
def _seeds_cooperative_exchange_1(minibatch, group=None):
rank = thd.get_rank(group)
world_size = thd.get_world_size(group)
assert world_size > 1
seeds = minibatch._seed_nodes
is_homogeneous = not isinstance(seeds, dict)
if is_homogeneous:
seeds = {"_N": seeds}
if minibatch._seeds_offsets is None:
seeds_list = list(seeds.values())
(
sorted_seeds_list,
index_list,
offsets_list,
) = torch.ops.graphbolt.rank_sort(seeds_list, rank, world_size)
result = torch.ops.graphbolt.rank_sort(seeds_list, rank, world_size)
assert minibatch.compacted_seeds is None
sorted_seeds, sorted_compacted, sorted_offsets = {}, {}, {}
num_ntypes = len(seeds.keys())
for i, (
seed_type,
typed_sorted_seeds,
typed_index,
typed_offsets,
) in enumerate(
zip(
seeds.keys(),
sorted_seeds_list,
index_list,
offsets_list,
)
):
(typed_sorted_seeds, typed_index, typed_offsets),
) in enumerate(zip(seeds.keys(), result)):
sorted_seeds[seed_type] = typed_sorted_seeds
sorted_compacted[seed_type] = typed_index
sorted_offsets[seed_type] = typed_offsets.tolist()
sorted_offsets[seed_type] = typed_offsets
minibatch._seed_nodes = sorted_seeds
minibatch.compacted_seeds = sorted_compacted
minibatch.compacted_seeds = revert_to_homo(sorted_compacted)
minibatch._seeds_offsets = sorted_offsets
else:
minibatch._seeds_offsets = {"_N": minibatch._seeds_offsets}
counts_sent = torch.empty(world_size * num_ntypes, dtype=torch.int64)
for i, offsets in enumerate(minibatch._seeds_offsets[0].values()):
for i, offsets in enumerate(minibatch._seeds_offsets.values()):
counts_sent[
torch.arange(i, world_size * num_ntypes, num_ntypes)
] = offsets.diff()
@@ -208,7 +234,6 @@ class SubgraphSampler(MiniBatchTransformer):
seeds = minibatch._seed_nodes
minibatch._counts_future.wait()
delattr(minibatch, "_counts_future")
counts_received = minibatch._counts_received
num_ntypes = len(seeds.keys())
seeds_received = {}
counts_sent = {}
@@ -226,15 +251,19 @@ class SubgraphSampler(MiniBatchTransformer):
group,
)
seeds_received[ntype] = typed_seeds_received
minibatch._seed_nodes = _revert_to_homo(seeds_received)
minibatch._counts_sent = _revert_to_homo(counts_sent)
minibatch._counts_received = _revert_to_homo(counts_received)
minibatch._seed_nodes = seeds_received
minibatch._counts_sent = revert_to_homo(counts_sent)
minibatch._counts_received = revert_to_homo(counts_received)
return minibatch
@staticmethod
def _seeds_cooperative_exchange_3(minibatch):
nodes = {
ntype: [typed_seeds]
for ntype, typed_seeds in minibatch._seed_nodes.items()
}
minibatch._unique_future = unique_and_compact(
minibatch._seed_nodes, 0, 1, async_op=True
nodes, 0, 1, async_op=True
)
return minibatch
@@ -242,8 +271,11 @@ class SubgraphSampler(MiniBatchTransformer):
def _seeds_cooperative_exchange_4(minibatch):
unique_seeds, inverse_seeds, _ = minibatch._unique_future.wait()
delattr(minibatch, "_unique_future")
minibatch._seed_nodes = _revert_to_homo(unique_seeds)
minibatch._seed_inverse_ids = _revert_to_homo(inverse_seeds)
inverse_seeds = {
ntype: typed_inv[0] for ntype, typed_inv in inverse_seeds.items()
}
minibatch._seed_nodes = revert_to_homo(unique_seeds)
minibatch._seed_inverse_ids = revert_to_homo(inverse_seeds)
return minibatch
def _sample(self, minibatch):

View File

@@ -1,4 +1,6 @@
import os
import unittest
from sys import platform
import backend as F
@@ -6,6 +8,7 @@ import dgl
import dgl.graphbolt
import pytest
import torch
import torch.distributed as thd
from dgl.graphbolt.datapipes import find_dps, traverse_dps
@@ -63,6 +66,7 @@ def test_DataLoader(overlap_feature_fetch):
@pytest.mark.parametrize("enable_feature_fetch", [True, False])
@pytest.mark.parametrize("overlap_feature_fetch", [True, False])
@pytest.mark.parametrize("overlap_graph_fetch", [True, False])
@pytest.mark.parametrize("cooperative", [True, False])
@pytest.mark.parametrize("asynchronous", [True, False])
@pytest.mark.parametrize("num_gpu_cached_edges", [0, 1024])
@pytest.mark.parametrize("gpu_cache_threshold", [1, 3])
@@ -71,10 +75,23 @@ def test_gpu_sampling_DataLoader(
enable_feature_fetch,
overlap_feature_fetch,
overlap_graph_fetch,
cooperative,
asynchronous,
num_gpu_cached_edges,
gpu_cache_threshold,
):
if cooperative and not thd.is_initialized():
# On Windows, the init method can only be file.
init_method = (
f"file:///{os.path.join(os.getcwd(), 'dis_tempfile')}"
if platform == "win32"
else "tcp://127.0.0.1:12345"
)
thd.init_process_group(
init_method=init_method,
world_size=1,
rank=0,
)
N = 40
B = 4
num_layers = 2
@@ -110,6 +127,7 @@ def test_gpu_sampling_DataLoader(
"overlap_fetch": overlap_graph_fetch,
"num_gpu_cached_edges": num_gpu_cached_edges,
"gpu_cache_threshold": gpu_cache_threshold,
"cooperative": cooperative,
"asynchronous": asynchronous,
}
if i != 0:
@@ -118,7 +136,7 @@ def test_gpu_sampling_DataLoader(
datapipe,
graph,
fanouts=[torch.LongTensor([2]) for _ in range(num_layers)],
**kwargs
**kwargs,
)
if enable_feature_fetch:
datapipe = dgl.graphbolt.FeatureFetcher(
@@ -138,6 +156,11 @@ def test_gpu_sampling_DataLoader(
bufferer_cnt += 2 * num_layers
if asynchronous:
bufferer_cnt += 2 * num_layers + 1 # _preprocess stage has 1.
if cooperative:
bufferer_cnt += 3 * num_layers
if cooperative:
# _preprocess stage and each sampling layer.
bufferer_cnt += 3
datapipe_graph = traverse_dps(dataloader)
bufferers = find_dps(
datapipe_graph,
@@ -171,3 +194,5 @@ def test_gpu_sampling_DataLoader(
if sampler_name == "LayerNeighborSampler":
assert torch.equal(edge_feature, edge_feature_ref)
assert len(list(dataloader)) == N // B
if thd.is_initialized():
thd.destroy_process_group()