[GraphBolt] gb.DataLoader can simply be a datapipe. (#7732)

This commit is contained in:
Muhammed Fatih BALIN
2024-08-23 01:02:17 -04:00
committed by GitHub
parent b3eacd22d7
commit 37d1064c22
3 changed files with 130 additions and 141 deletions

View File

@@ -4,7 +4,6 @@ import torch
import torch.utils.data as torch_data
from .base import CopyTo
from .datapipes import (
datapipe_graph_to_adjlist,
find_dps,
@@ -15,6 +14,7 @@ from .feature_fetcher import FeatureFetcher, FeatureFetcherStartMarker
from .impl.neighbor_sampler import SamplePerLayer
from .internal_utils import gb_warning
from .item_sampler import ItemSampler
from .minibatch_transformer import MiniBatchTransformer
__all__ = [
@@ -75,7 +75,7 @@ class MultiprocessingWrapper(torch_data.IterDataPipe):
yield from self.dataloader
class DataLoader(torch_data.DataLoader):
class DataLoader(MiniBatchTransformer):
"""Multiprocessing DataLoader.
Iterates over the data pipeline with everything before feature fetching
@@ -122,32 +122,33 @@ class DataLoader(torch_data.DataLoader):
datapipe = datapipe.mark_end()
datapipe_graph = traverse_dps(datapipe)
# (1) Insert minibatch distribution.
# TODO(BarclayII): Currently I'm using sharding_filter() as a
# concept demonstration. Later on minibatch distribution should be
# merged into ItemSampler to maximize efficiency.
item_samplers = find_dps(
datapipe_graph,
ItemSampler,
)
for item_sampler in item_samplers:
datapipe_graph = replace_dp(
if num_workers > 0:
# (1) Insert minibatch distribution.
# TODO(BarclayII): Currently I'm using sharding_filter() as a
# concept demonstration. Later on minibatch distribution should be
# merged into ItemSampler to maximize efficiency.
item_samplers = find_dps(
datapipe_graph,
item_sampler,
item_sampler.sharding_filter(),
ItemSampler,
)
for item_sampler in item_samplers:
datapipe_graph = replace_dp(
datapipe_graph,
item_sampler,
item_sampler.sharding_filter(),
)
# (2) Cut datapipe at FeatureFetcher and wrap.
datapipe_graph = _find_and_wrap_parent(
datapipe_graph,
FeatureFetcherStartMarker,
MultiprocessingWrapper,
num_workers=num_workers,
persistent_workers=persistent_workers,
)
# (2) Cut datapipe at FeatureFetcher and wrap.
datapipe_graph = _find_and_wrap_parent(
datapipe_graph,
FeatureFetcherStartMarker,
MultiprocessingWrapper,
num_workers=num_workers,
persistent_workers=persistent_workers,
)
# (3) Limit the number of UVA threads used if the feature_fetcher has
# overlapping optimization enabled.
# (3) Limit the number of UVA threads used if the feature_fetcher
# or any of the samplers have overlapping optimization enabled.
if num_workers == 0 and torch.cuda.is_available():
feature_fetchers = find_dps(
datapipe_graph,
@@ -187,6 +188,4 @@ class DataLoader(torch_data.DataLoader):
),
)
# The stages after feature fetching are still done in the main process.
# So we set num_workers to 0 here.
super().__init__(datapipe, batch_size=None, num_workers=0)
super().__init__(datapipe)

View File

@@ -138,8 +138,7 @@ def test_gpu_sampling_DataLoader(
bufferer_cnt += 2 * num_layers
if asynchronous:
bufferer_cnt += 2 * num_layers
datapipe = dataloader.dataset
datapipe_graph = traverse_dps(datapipe)
datapipe_graph = traverse_dps(dataloader)
bufferers = find_dps(
datapipe_graph,
dgl.graphbolt.Bufferer,

View File

@@ -64,55 +64,54 @@ def test_integration_link_prediction():
[3, 2],
[3, 2],
[3, 3],
[5, 0],
[5, 0],
[5, 2],
[5, 1],
[3, 4],
[3, 3],
[3, 0],
[3, 5],
[3, 3],
[3, 3],
[3, 2],
[3, 0],
[3, 4]]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 2, 2, 2, 3], dtype=torch.int32),
indices=tensor([0, 5, 4], dtype=torch.int32),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 2, 2], dtype=torch.int32),
indices=tensor([4, 5], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 1, 3, 2, 0, 4]),
original_edge_ids=tensor([8, 5, 7]),
original_column_node_ids=tensor([5, 1, 3, 2, 0, 4]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2], dtype=torch.int32),
indices=tensor([5, 4], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 1, 3, 2, 0, 4]),
original_row_node_ids=tensor([5, 1, 3, 2, 4, 0]),
original_edge_ids=tensor([9, 7]),
original_column_node_ids=tensor([5, 1, 3, 2, 0, 4]),
original_column_node_ids=tensor([5, 1, 3, 2, 4, 0]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 2, 2], dtype=torch.int32),
indices=tensor([0, 5], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 1, 3, 2, 4, 0]),
original_edge_ids=tensor([8, 7]),
original_column_node_ids=tensor([5, 1, 3, 2, 4, 0]),
)],
node_features={'feat': tensor([[0.5160, 0.2486],
[0.6172, 0.7865],
[0.8672, 0.2276],
[0.2109, 0.1089],
[0.9634, 0.2294],
[0.5503, 0.8223]])},
[0.5503, 0.8223],
[0.9634, 0.2294]])},
labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
input_nodes=tensor([5, 1, 3, 2, 0, 4]),
input_nodes=tensor([5, 1, 3, 2, 4, 0]),
indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),
edge_features=[{'feat': tensor([[0.8972, 0.7511, 0.3617],
[0.7885, 0.3414, 0.5485],
edge_features=[{'feat': tensor([[0.5773, 0.2199, 0.3366],
[0.0056, 0.9469, 0.4432]])},
{'feat': tensor([[0.5773, 0.2199, 0.3366],
{'feat': tensor([[0.8972, 0.7511, 0.3617],
[0.0056, 0.9469, 0.4432]])}],
compacted_seeds=tensor([[0, 1],
[2, 3],
[2, 3],
[2, 2],
[0, 4],
[0, 4],
[2, 2],
[0, 3],
[0, 1],
[2, 4],
[2, 2],
[2, 0],
[2, 2],
[2, 2],
[2, 5]]),
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3),
[2, 3],
[2, 5],
[2, 4]]),
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2)],
)"""
),
@@ -121,103 +120,97 @@ def test_integration_link_prediction():
[4, 3],
[4, 4],
[0, 4],
[3, 1],
[3, 4],
[3, 5],
[4, 2],
[4, 5],
[4, 1],
[4, 4],
[4, 3],
[4, 4],
[4, 5],
[0, 1],
[0, 5]]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 1, 2], dtype=torch.int32),
indices=tensor([4, 0], dtype=torch.int32),
[0, 3]]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 0, 1], dtype=torch.int32),
indices=tensor([3], dtype=torch.int32),
),
original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
original_edge_ids=tensor([0, 1]),
original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
original_row_node_ids=tensor([3, 4, 0, 5, 1]),
original_edge_ids=tensor([0]),
original_column_node_ids=tensor([3, 4, 0, 5, 1]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3], dtype=torch.int32),
indices=tensor([4, 4, 0], dtype=torch.int32),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2], dtype=torch.int32),
indices=tensor([3, 3], dtype=torch.int32),
),
original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
original_edge_ids=tensor([0, 8, 1]),
original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
original_row_node_ids=tensor([3, 4, 0, 5, 1]),
original_edge_ids=tensor([8, 0]),
original_column_node_ids=tensor([3, 4, 0, 5, 1]),
)],
node_features={'feat': tensor([[0.8672, 0.2276],
[0.5503, 0.8223],
[0.9634, 0.2294],
[0.6172, 0.7865],
[0.5160, 0.2486],
[0.2109, 0.1089]])},
[0.6172, 0.7865]])},
labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
input_nodes=tensor([3, 4, 0, 1, 5, 2]),
input_nodes=tensor([3, 4, 0, 5, 1]),
indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),
edge_features=[{'feat': tensor([[0.5123, 0.1709, 0.6150],
[0.1476, 0.1902, 0.1314]])},
{'feat': tensor([[0.5123, 0.1709, 0.6150],
[0.8972, 0.7511, 0.3617],
[0.1476, 0.1902, 0.1314]])}],
edge_features=[{'feat': tensor([[0.5123, 0.1709, 0.6150]])},
{'feat': tensor([[0.8972, 0.7511, 0.3617],
[0.5123, 0.1709, 0.6150]])}],
compacted_seeds=tensor([[0, 0],
[1, 0],
[1, 1],
[2, 1],
[0, 1],
[0, 3],
[0, 4],
[1, 5],
[1, 4],
[1, 1],
[1, 0],
[2, 3],
[2, 4]]),
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3)],
[1, 1],
[1, 3],
[2, 4],
[2, 0]]),
blocks=[Block(num_src_nodes=5, num_dst_nodes=5, num_edges=1),
Block(num_src_nodes=5, num_dst_nodes=5, num_edges=2)],
)"""
),
str(
"""MiniBatch(seeds=tensor([[5, 5],
[4, 5],
[5, 0],
[5, 4],
[5, 5],
[5, 5],
[4, 0],
[4, 1]]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32),
[4, 0]]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1], dtype=torch.int32),
indices=tensor([1], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 4, 0, 1]),
original_edge_ids=tensor([6, 0]),
original_column_node_ids=tensor([5, 4, 0, 1]),
original_row_node_ids=tensor([5, 4, 0]),
original_edge_ids=tensor([6]),
original_column_node_ids=tensor([5, 4, 0]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1], dtype=torch.int32),
indices=tensor([2], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 4, 0, 1]),
original_edge_ids=tensor([6, 0]),
original_column_node_ids=tensor([5, 4, 0, 1]),
original_row_node_ids=tensor([5, 4, 0]),
original_edge_ids=tensor([7]),
original_column_node_ids=tensor([5, 4, 0]),
)],
node_features={'feat': tensor([[0.5160, 0.2486],
[0.5503, 0.8223],
[0.9634, 0.2294],
[0.6172, 0.7865]])},
[0.9634, 0.2294]])},
labels=tensor([1., 1., 0., 0., 0., 0.]),
input_nodes=tensor([5, 4, 0, 1]),
input_nodes=tensor([5, 4, 0]),
indexes=tensor([0, 1, 0, 0, 1, 1]),
edge_features=[{'feat': tensor([[0.4088, 0.8200, 0.1851],
[0.5123, 0.1709, 0.6150]])},
{'feat': tensor([[0.4088, 0.8200, 0.1851],
[0.5123, 0.1709, 0.6150]])}],
edge_features=[{'feat': tensor([[0.4088, 0.8200, 0.1851]])},
{'feat': tensor([[0.0056, 0.9469, 0.4432]])}],
compacted_seeds=tensor([[0, 0],
[1, 0],
[0, 2],
[0, 1],
[0, 0],
[0, 0],
[1, 2],
[1, 3]]),
blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2),
Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2)],
[1, 2]]),
blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=1),
Block(num_src_nodes=3, num_dst_nodes=3, num_edges=1)],
)"""
),
]
for step, data in enumerate(dataloader):
assert expected[step] == str(data), print(data)
assert expected[step] == str(data), print(step, data)
def test_integration_node_classification():
@@ -275,10 +268,10 @@ def test_integration_node_classification():
str(
"""MiniBatch(seeds=tensor([5, 1]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([2, 0], dtype=torch.int32),
indices=tensor([0, 0], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 1, 4]),
original_edge_ids=tensor([9, 0]),
original_row_node_ids=tensor([5, 1]),
original_edge_ids=tensor([8, 0]),
original_column_node_ids=tensor([5, 1]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
@@ -289,51 +282,49 @@ def test_integration_node_classification():
original_column_node_ids=tensor([5, 1]),
)],
node_features={'feat': tensor([[0.5160, 0.2486],
[0.6172, 0.7865],
[0.5503, 0.8223]])},
[0.6172, 0.7865]])},
labels=None,
input_nodes=tensor([5, 1, 4]),
input_nodes=tensor([5, 1]),
indexes=None,
edge_features=[{'feat': tensor([[0.5773, 0.2199, 0.3366],
edge_features=[{'feat': tensor([[0.8972, 0.7511, 0.3617],
[0.5123, 0.1709, 0.6150]])},
{'feat': tensor([[0.8972, 0.7511, 0.3617],
[0.5123, 0.1709, 0.6150]])}],
compacted_seeds=None,
blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2),
blocks=[Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2),
Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2)],
)"""
),
str(
"""MiniBatch(seeds=tensor([2, 4]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 3], dtype=torch.int32),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3], dtype=torch.int32),
indices=tensor([2, 1, 2], dtype=torch.int32),
),
original_row_node_ids=tensor([2, 4, 3, 0]),
original_edge_ids=tensor([2, 6, 3]),
original_column_node_ids=tensor([2, 4, 3, 0]),
original_row_node_ids=tensor([2, 4, 3]),
original_edge_ids=tensor([1, 6, 3]),
original_column_node_ids=tensor([2, 4, 3]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([2, 3], dtype=torch.int32),
indices=tensor([2, 1], dtype=torch.int32),
),
original_row_node_ids=tensor([2, 4, 3, 0]),
original_edge_ids=tensor([2, 7]),
original_row_node_ids=tensor([2, 4, 3]),
original_edge_ids=tensor([2, 6]),
original_column_node_ids=tensor([2, 4]),
)],
node_features={'feat': tensor([[0.2109, 0.1089],
[0.5503, 0.8223],
[0.8672, 0.2276],
[0.9634, 0.2294]])},
[0.8672, 0.2276]])},
labels=None,
input_nodes=tensor([2, 4, 3, 0]),
input_nodes=tensor([2, 4, 3]),
indexes=None,
edge_features=[{'feat': tensor([[0.2582, 0.5203, 0.6228],
edge_features=[{'feat': tensor([[0.1476, 0.1902, 0.1314],
[0.4088, 0.8200, 0.1851],
[0.3708, 0.7631, 0.2683]])},
{'feat': tensor([[0.2582, 0.5203, 0.6228],
[0.0056, 0.9469, 0.4432]])}],
[0.4088, 0.8200, 0.1851]])}],
compacted_seeds=None,
blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=3),
Block(num_src_nodes=4, num_dst_nodes=2, num_edges=2)],
blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=3),
Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2)],
)"""
),
str(
@@ -342,14 +333,14 @@ def test_integration_node_classification():
indices=tensor([0], dtype=torch.int32),
),
original_row_node_ids=tensor([3, 0]),
original_edge_ids=tensor([4]),
original_edge_ids=tensor([3]),
original_column_node_ids=tensor([3, 0]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1], dtype=torch.int32),
indices=tensor([0], dtype=torch.int32),
),
original_row_node_ids=tensor([3, 0]),
original_edge_ids=tensor([4]),
original_edge_ids=tensor([3]),
original_column_node_ids=tensor([3, 0]),
)],
node_features={'feat': tensor([[0.8672, 0.2276],
@@ -357,8 +348,8 @@ def test_integration_node_classification():
labels=None,
input_nodes=tensor([3, 0]),
indexes=None,
edge_features=[{'feat': tensor([[0.2126, 0.7878, 0.7225]])},
{'feat': tensor([[0.2126, 0.7878, 0.7225]])}],
edge_features=[{'feat': tensor([[0.3708, 0.7631, 0.2683]])},
{'feat': tensor([[0.3708, 0.7631, 0.2683]])}],
compacted_seeds=None,
blocks=[Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1),
Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1)],