mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt] gb.DataLoader can simply be a datapipe. (#7732)
This commit is contained in:
committed by
GitHub
parent
b3eacd22d7
commit
37d1064c22
@@ -4,7 +4,6 @@ import torch
|
||||
import torch.utils.data as torch_data
|
||||
|
||||
from .base import CopyTo
|
||||
|
||||
from .datapipes import (
|
||||
datapipe_graph_to_adjlist,
|
||||
find_dps,
|
||||
@@ -15,6 +14,7 @@ from .feature_fetcher import FeatureFetcher, FeatureFetcherStartMarker
|
||||
from .impl.neighbor_sampler import SamplePerLayer
|
||||
from .internal_utils import gb_warning
|
||||
from .item_sampler import ItemSampler
|
||||
from .minibatch_transformer import MiniBatchTransformer
|
||||
|
||||
|
||||
__all__ = [
|
||||
@@ -75,7 +75,7 @@ class MultiprocessingWrapper(torch_data.IterDataPipe):
|
||||
yield from self.dataloader
|
||||
|
||||
|
||||
class DataLoader(torch_data.DataLoader):
|
||||
class DataLoader(MiniBatchTransformer):
|
||||
"""Multiprocessing DataLoader.
|
||||
|
||||
Iterates over the data pipeline with everything before feature fetching
|
||||
@@ -122,32 +122,33 @@ class DataLoader(torch_data.DataLoader):
|
||||
datapipe = datapipe.mark_end()
|
||||
datapipe_graph = traverse_dps(datapipe)
|
||||
|
||||
# (1) Insert minibatch distribution.
|
||||
# TODO(BarclayII): Currently I'm using sharding_filter() as a
|
||||
# concept demonstration. Later on minibatch distribution should be
|
||||
# merged into ItemSampler to maximize efficiency.
|
||||
item_samplers = find_dps(
|
||||
datapipe_graph,
|
||||
ItemSampler,
|
||||
)
|
||||
for item_sampler in item_samplers:
|
||||
datapipe_graph = replace_dp(
|
||||
if num_workers > 0:
|
||||
# (1) Insert minibatch distribution.
|
||||
# TODO(BarclayII): Currently I'm using sharding_filter() as a
|
||||
# concept demonstration. Later on minibatch distribution should be
|
||||
# merged into ItemSampler to maximize efficiency.
|
||||
item_samplers = find_dps(
|
||||
datapipe_graph,
|
||||
item_sampler,
|
||||
item_sampler.sharding_filter(),
|
||||
ItemSampler,
|
||||
)
|
||||
for item_sampler in item_samplers:
|
||||
datapipe_graph = replace_dp(
|
||||
datapipe_graph,
|
||||
item_sampler,
|
||||
item_sampler.sharding_filter(),
|
||||
)
|
||||
|
||||
# (2) Cut datapipe at FeatureFetcher and wrap.
|
||||
datapipe_graph = _find_and_wrap_parent(
|
||||
datapipe_graph,
|
||||
FeatureFetcherStartMarker,
|
||||
MultiprocessingWrapper,
|
||||
num_workers=num_workers,
|
||||
persistent_workers=persistent_workers,
|
||||
)
|
||||
|
||||
# (2) Cut datapipe at FeatureFetcher and wrap.
|
||||
datapipe_graph = _find_and_wrap_parent(
|
||||
datapipe_graph,
|
||||
FeatureFetcherStartMarker,
|
||||
MultiprocessingWrapper,
|
||||
num_workers=num_workers,
|
||||
persistent_workers=persistent_workers,
|
||||
)
|
||||
|
||||
# (3) Limit the number of UVA threads used if the feature_fetcher has
|
||||
# overlapping optimization enabled.
|
||||
# (3) Limit the number of UVA threads used if the feature_fetcher
|
||||
# or any of the samplers have overlapping optimization enabled.
|
||||
if num_workers == 0 and torch.cuda.is_available():
|
||||
feature_fetchers = find_dps(
|
||||
datapipe_graph,
|
||||
@@ -187,6 +188,4 @@ class DataLoader(torch_data.DataLoader):
|
||||
),
|
||||
)
|
||||
|
||||
# The stages after feature fetching is still done in the main process.
|
||||
# So we set num_workers to 0 here.
|
||||
super().__init__(datapipe, batch_size=None, num_workers=0)
|
||||
super().__init__(datapipe)
|
||||
|
||||
@@ -138,8 +138,7 @@ def test_gpu_sampling_DataLoader(
|
||||
bufferer_cnt += 2 * num_layers
|
||||
if asynchronous:
|
||||
bufferer_cnt += 2 * num_layers
|
||||
datapipe = dataloader.dataset
|
||||
datapipe_graph = traverse_dps(datapipe)
|
||||
datapipe_graph = traverse_dps(dataloader)
|
||||
bufferers = find_dps(
|
||||
datapipe_graph,
|
||||
dgl.graphbolt.Bufferer,
|
||||
|
||||
@@ -64,55 +64,54 @@ def test_integration_link_prediction():
|
||||
[3, 2],
|
||||
[3, 2],
|
||||
[3, 3],
|
||||
[5, 0],
|
||||
[5, 0],
|
||||
[5, 2],
|
||||
[5, 1],
|
||||
[3, 4],
|
||||
[3, 3],
|
||||
[3, 0],
|
||||
[3, 5],
|
||||
[3, 3],
|
||||
[3, 3],
|
||||
[3, 2],
|
||||
[3, 0],
|
||||
[3, 4]]),
|
||||
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 2, 2, 2, 3], dtype=torch.int32),
|
||||
indices=tensor([0, 5, 4], dtype=torch.int32),
|
||||
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 2, 2], dtype=torch.int32),
|
||||
indices=tensor([4, 5], dtype=torch.int32),
|
||||
),
|
||||
original_row_node_ids=tensor([5, 1, 3, 2, 0, 4]),
|
||||
original_edge_ids=tensor([8, 5, 7]),
|
||||
original_column_node_ids=tensor([5, 1, 3, 2, 0, 4]),
|
||||
),
|
||||
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2], dtype=torch.int32),
|
||||
indices=tensor([5, 4], dtype=torch.int32),
|
||||
),
|
||||
original_row_node_ids=tensor([5, 1, 3, 2, 0, 4]),
|
||||
original_row_node_ids=tensor([5, 1, 3, 2, 4, 0]),
|
||||
original_edge_ids=tensor([9, 7]),
|
||||
original_column_node_ids=tensor([5, 1, 3, 2, 0, 4]),
|
||||
original_column_node_ids=tensor([5, 1, 3, 2, 4, 0]),
|
||||
),
|
||||
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 2, 2], dtype=torch.int32),
|
||||
indices=tensor([0, 5], dtype=torch.int32),
|
||||
),
|
||||
original_row_node_ids=tensor([5, 1, 3, 2, 4, 0]),
|
||||
original_edge_ids=tensor([8, 7]),
|
||||
original_column_node_ids=tensor([5, 1, 3, 2, 4, 0]),
|
||||
)],
|
||||
node_features={'feat': tensor([[0.5160, 0.2486],
|
||||
[0.6172, 0.7865],
|
||||
[0.8672, 0.2276],
|
||||
[0.2109, 0.1089],
|
||||
[0.9634, 0.2294],
|
||||
[0.5503, 0.8223]])},
|
||||
[0.5503, 0.8223],
|
||||
[0.9634, 0.2294]])},
|
||||
labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
|
||||
input_nodes=tensor([5, 1, 3, 2, 0, 4]),
|
||||
input_nodes=tensor([5, 1, 3, 2, 4, 0]),
|
||||
indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),
|
||||
edge_features=[{'feat': tensor([[0.8972, 0.7511, 0.3617],
|
||||
[0.7885, 0.3414, 0.5485],
|
||||
edge_features=[{'feat': tensor([[0.5773, 0.2199, 0.3366],
|
||||
[0.0056, 0.9469, 0.4432]])},
|
||||
{'feat': tensor([[0.5773, 0.2199, 0.3366],
|
||||
{'feat': tensor([[0.8972, 0.7511, 0.3617],
|
||||
[0.0056, 0.9469, 0.4432]])}],
|
||||
compacted_seeds=tensor([[0, 1],
|
||||
[2, 3],
|
||||
[2, 3],
|
||||
[2, 2],
|
||||
[0, 4],
|
||||
[0, 4],
|
||||
[2, 2],
|
||||
[0, 3],
|
||||
[0, 1],
|
||||
[2, 4],
|
||||
[2, 2],
|
||||
[2, 0],
|
||||
[2, 2],
|
||||
[2, 2],
|
||||
[2, 5]]),
|
||||
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3),
|
||||
[2, 3],
|
||||
[2, 5],
|
||||
[2, 4]]),
|
||||
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
|
||||
Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2)],
|
||||
)"""
|
||||
),
|
||||
@@ -121,103 +120,97 @@ def test_integration_link_prediction():
|
||||
[4, 3],
|
||||
[4, 4],
|
||||
[0, 4],
|
||||
[3, 1],
|
||||
[3, 4],
|
||||
[3, 5],
|
||||
[4, 2],
|
||||
[4, 5],
|
||||
[4, 1],
|
||||
[4, 4],
|
||||
[4, 3],
|
||||
[4, 4],
|
||||
[4, 5],
|
||||
[0, 1],
|
||||
[0, 5]]),
|
||||
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 1, 2], dtype=torch.int32),
|
||||
indices=tensor([4, 0], dtype=torch.int32),
|
||||
[0, 3]]),
|
||||
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 0, 1], dtype=torch.int32),
|
||||
indices=tensor([3], dtype=torch.int32),
|
||||
),
|
||||
original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
|
||||
original_edge_ids=tensor([0, 1]),
|
||||
original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
|
||||
original_row_node_ids=tensor([3, 4, 0, 5, 1]),
|
||||
original_edge_ids=tensor([0]),
|
||||
original_column_node_ids=tensor([3, 4, 0, 5, 1]),
|
||||
),
|
||||
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3], dtype=torch.int32),
|
||||
indices=tensor([4, 4, 0], dtype=torch.int32),
|
||||
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2], dtype=torch.int32),
|
||||
indices=tensor([3, 3], dtype=torch.int32),
|
||||
),
|
||||
original_row_node_ids=tensor([3, 4, 0, 1, 5, 2]),
|
||||
original_edge_ids=tensor([0, 8, 1]),
|
||||
original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
|
||||
original_row_node_ids=tensor([3, 4, 0, 5, 1]),
|
||||
original_edge_ids=tensor([8, 0]),
|
||||
original_column_node_ids=tensor([3, 4, 0, 5, 1]),
|
||||
)],
|
||||
node_features={'feat': tensor([[0.8672, 0.2276],
|
||||
[0.5503, 0.8223],
|
||||
[0.9634, 0.2294],
|
||||
[0.6172, 0.7865],
|
||||
[0.5160, 0.2486],
|
||||
[0.2109, 0.1089]])},
|
||||
[0.6172, 0.7865]])},
|
||||
labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
|
||||
input_nodes=tensor([3, 4, 0, 1, 5, 2]),
|
||||
input_nodes=tensor([3, 4, 0, 5, 1]),
|
||||
indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),
|
||||
edge_features=[{'feat': tensor([[0.5123, 0.1709, 0.6150],
|
||||
[0.1476, 0.1902, 0.1314]])},
|
||||
{'feat': tensor([[0.5123, 0.1709, 0.6150],
|
||||
[0.8972, 0.7511, 0.3617],
|
||||
[0.1476, 0.1902, 0.1314]])}],
|
||||
edge_features=[{'feat': tensor([[0.5123, 0.1709, 0.6150]])},
|
||||
{'feat': tensor([[0.8972, 0.7511, 0.3617],
|
||||
[0.5123, 0.1709, 0.6150]])}],
|
||||
compacted_seeds=tensor([[0, 0],
|
||||
[1, 0],
|
||||
[1, 1],
|
||||
[2, 1],
|
||||
[0, 1],
|
||||
[0, 3],
|
||||
[0, 4],
|
||||
[1, 5],
|
||||
[1, 4],
|
||||
[1, 1],
|
||||
[1, 0],
|
||||
[2, 3],
|
||||
[2, 4]]),
|
||||
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
|
||||
Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3)],
|
||||
[1, 1],
|
||||
[1, 3],
|
||||
[2, 4],
|
||||
[2, 0]]),
|
||||
blocks=[Block(num_src_nodes=5, num_dst_nodes=5, num_edges=1),
|
||||
Block(num_src_nodes=5, num_dst_nodes=5, num_edges=2)],
|
||||
)"""
|
||||
),
|
||||
str(
|
||||
"""MiniBatch(seeds=tensor([[5, 5],
|
||||
[4, 5],
|
||||
[5, 0],
|
||||
[5, 4],
|
||||
[5, 5],
|
||||
[5, 5],
|
||||
[4, 0],
|
||||
[4, 1]]),
|
||||
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32),
|
||||
indices=tensor([1, 0], dtype=torch.int32),
|
||||
[4, 0]]),
|
||||
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1], dtype=torch.int32),
|
||||
indices=tensor([1], dtype=torch.int32),
|
||||
),
|
||||
original_row_node_ids=tensor([5, 4, 0, 1]),
|
||||
original_edge_ids=tensor([6, 0]),
|
||||
original_column_node_ids=tensor([5, 4, 0, 1]),
|
||||
original_row_node_ids=tensor([5, 4, 0]),
|
||||
original_edge_ids=tensor([6]),
|
||||
original_column_node_ids=tensor([5, 4, 0]),
|
||||
),
|
||||
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32),
|
||||
indices=tensor([1, 0], dtype=torch.int32),
|
||||
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1], dtype=torch.int32),
|
||||
indices=tensor([2], dtype=torch.int32),
|
||||
),
|
||||
original_row_node_ids=tensor([5, 4, 0, 1]),
|
||||
original_edge_ids=tensor([6, 0]),
|
||||
original_column_node_ids=tensor([5, 4, 0, 1]),
|
||||
original_row_node_ids=tensor([5, 4, 0]),
|
||||
original_edge_ids=tensor([7]),
|
||||
original_column_node_ids=tensor([5, 4, 0]),
|
||||
)],
|
||||
node_features={'feat': tensor([[0.5160, 0.2486],
|
||||
[0.5503, 0.8223],
|
||||
[0.9634, 0.2294],
|
||||
[0.6172, 0.7865]])},
|
||||
[0.9634, 0.2294]])},
|
||||
labels=tensor([1., 1., 0., 0., 0., 0.]),
|
||||
input_nodes=tensor([5, 4, 0, 1]),
|
||||
input_nodes=tensor([5, 4, 0]),
|
||||
indexes=tensor([0, 1, 0, 0, 1, 1]),
|
||||
edge_features=[{'feat': tensor([[0.4088, 0.8200, 0.1851],
|
||||
[0.5123, 0.1709, 0.6150]])},
|
||||
{'feat': tensor([[0.4088, 0.8200, 0.1851],
|
||||
[0.5123, 0.1709, 0.6150]])}],
|
||||
edge_features=[{'feat': tensor([[0.4088, 0.8200, 0.1851]])},
|
||||
{'feat': tensor([[0.0056, 0.9469, 0.4432]])}],
|
||||
compacted_seeds=tensor([[0, 0],
|
||||
[1, 0],
|
||||
[0, 2],
|
||||
[0, 1],
|
||||
[0, 0],
|
||||
[0, 0],
|
||||
[1, 2],
|
||||
[1, 3]]),
|
||||
blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2),
|
||||
Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2)],
|
||||
[1, 2]]),
|
||||
blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=1),
|
||||
Block(num_src_nodes=3, num_dst_nodes=3, num_edges=1)],
|
||||
)"""
|
||||
),
|
||||
]
|
||||
for step, data in enumerate(dataloader):
|
||||
assert expected[step] == str(data), print(data)
|
||||
assert expected[step] == str(data), print(step, data)
|
||||
|
||||
|
||||
def test_integration_node_classification():
|
||||
@@ -275,10 +268,10 @@ def test_integration_node_classification():
|
||||
str(
|
||||
"""MiniBatch(seeds=tensor([5, 1]),
|
||||
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
|
||||
indices=tensor([2, 0], dtype=torch.int32),
|
||||
indices=tensor([0, 0], dtype=torch.int32),
|
||||
),
|
||||
original_row_node_ids=tensor([5, 1, 4]),
|
||||
original_edge_ids=tensor([9, 0]),
|
||||
original_row_node_ids=tensor([5, 1]),
|
||||
original_edge_ids=tensor([8, 0]),
|
||||
original_column_node_ids=tensor([5, 1]),
|
||||
),
|
||||
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
|
||||
@@ -289,51 +282,49 @@ def test_integration_node_classification():
|
||||
original_column_node_ids=tensor([5, 1]),
|
||||
)],
|
||||
node_features={'feat': tensor([[0.5160, 0.2486],
|
||||
[0.6172, 0.7865],
|
||||
[0.5503, 0.8223]])},
|
||||
[0.6172, 0.7865]])},
|
||||
labels=None,
|
||||
input_nodes=tensor([5, 1, 4]),
|
||||
input_nodes=tensor([5, 1]),
|
||||
indexes=None,
|
||||
edge_features=[{'feat': tensor([[0.5773, 0.2199, 0.3366],
|
||||
edge_features=[{'feat': tensor([[0.8972, 0.7511, 0.3617],
|
||||
[0.5123, 0.1709, 0.6150]])},
|
||||
{'feat': tensor([[0.8972, 0.7511, 0.3617],
|
||||
[0.5123, 0.1709, 0.6150]])}],
|
||||
compacted_seeds=None,
|
||||
blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2),
|
||||
blocks=[Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2),
|
||||
Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2)],
|
||||
)"""
|
||||
),
|
||||
str(
|
||||
"""MiniBatch(seeds=tensor([2, 4]),
|
||||
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 3], dtype=torch.int32),
|
||||
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3], dtype=torch.int32),
|
||||
indices=tensor([2, 1, 2], dtype=torch.int32),
|
||||
),
|
||||
original_row_node_ids=tensor([2, 4, 3, 0]),
|
||||
original_edge_ids=tensor([2, 6, 3]),
|
||||
original_column_node_ids=tensor([2, 4, 3, 0]),
|
||||
original_row_node_ids=tensor([2, 4, 3]),
|
||||
original_edge_ids=tensor([1, 6, 3]),
|
||||
original_column_node_ids=tensor([2, 4, 3]),
|
||||
),
|
||||
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
|
||||
indices=tensor([2, 3], dtype=torch.int32),
|
||||
indices=tensor([2, 1], dtype=torch.int32),
|
||||
),
|
||||
original_row_node_ids=tensor([2, 4, 3, 0]),
|
||||
original_edge_ids=tensor([2, 7]),
|
||||
original_row_node_ids=tensor([2, 4, 3]),
|
||||
original_edge_ids=tensor([2, 6]),
|
||||
original_column_node_ids=tensor([2, 4]),
|
||||
)],
|
||||
node_features={'feat': tensor([[0.2109, 0.1089],
|
||||
[0.5503, 0.8223],
|
||||
[0.8672, 0.2276],
|
||||
[0.9634, 0.2294]])},
|
||||
[0.8672, 0.2276]])},
|
||||
labels=None,
|
||||
input_nodes=tensor([2, 4, 3, 0]),
|
||||
input_nodes=tensor([2, 4, 3]),
|
||||
indexes=None,
|
||||
edge_features=[{'feat': tensor([[0.2582, 0.5203, 0.6228],
|
||||
edge_features=[{'feat': tensor([[0.1476, 0.1902, 0.1314],
|
||||
[0.4088, 0.8200, 0.1851],
|
||||
[0.3708, 0.7631, 0.2683]])},
|
||||
{'feat': tensor([[0.2582, 0.5203, 0.6228],
|
||||
[0.0056, 0.9469, 0.4432]])}],
|
||||
[0.4088, 0.8200, 0.1851]])}],
|
||||
compacted_seeds=None,
|
||||
blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=3),
|
||||
Block(num_src_nodes=4, num_dst_nodes=2, num_edges=2)],
|
||||
blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=3),
|
||||
Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2)],
|
||||
)"""
|
||||
),
|
||||
str(
|
||||
@@ -342,14 +333,14 @@ def test_integration_node_classification():
|
||||
indices=tensor([0], dtype=torch.int32),
|
||||
),
|
||||
original_row_node_ids=tensor([3, 0]),
|
||||
original_edge_ids=tensor([4]),
|
||||
original_edge_ids=tensor([3]),
|
||||
original_column_node_ids=tensor([3, 0]),
|
||||
),
|
||||
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1], dtype=torch.int32),
|
||||
indices=tensor([0], dtype=torch.int32),
|
||||
),
|
||||
original_row_node_ids=tensor([3, 0]),
|
||||
original_edge_ids=tensor([4]),
|
||||
original_edge_ids=tensor([3]),
|
||||
original_column_node_ids=tensor([3, 0]),
|
||||
)],
|
||||
node_features={'feat': tensor([[0.8672, 0.2276],
|
||||
@@ -357,8 +348,8 @@ def test_integration_node_classification():
|
||||
labels=None,
|
||||
input_nodes=tensor([3, 0]),
|
||||
indexes=None,
|
||||
edge_features=[{'feat': tensor([[0.2126, 0.7878, 0.7225]])},
|
||||
{'feat': tensor([[0.2126, 0.7878, 0.7225]])}],
|
||||
edge_features=[{'feat': tensor([[0.3708, 0.7631, 0.2683]])},
|
||||
{'feat': tensor([[0.3708, 0.7631, 0.2683]])}],
|
||||
compacted_seeds=None,
|
||||
blocks=[Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1),
|
||||
Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1)],
|
||||
|
||||
Reference in New Issue
Block a user