mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
Remove dependency on torchdata. (#7638)
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
This commit is contained in:
committed by
GitHub
parent
cb4604aca2
commit
26ff09fdbf
@@ -4,14 +4,18 @@ from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.utils.data as torch_data
|
||||
import torchdata.dataloader2.graph as dp_utils
|
||||
|
||||
from .base import CopyTo, get_host_to_device_uva_stream
|
||||
from .feature_fetcher import FeatureFetcher, FeatureFetcherStartMarker
|
||||
from .impl.gpu_graph_cache import GPUGraphCache
|
||||
from .impl.neighbor_sampler import SamplePerLayer
|
||||
|
||||
from .internal import datapipe_graph_to_adjlist
|
||||
from .internal import (
|
||||
datapipe_graph_to_adjlist,
|
||||
find_dps,
|
||||
replace_dp,
|
||||
traverse_dps,
|
||||
)
|
||||
from .item_sampler import ItemSampler
|
||||
|
||||
|
||||
@@ -47,7 +51,7 @@ def construct_gpu_graph_cache(
|
||||
|
||||
def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
|
||||
"""Find parent of target_datapipe and wrap it with ."""
|
||||
datapipes = dp_utils.find_dps(
|
||||
datapipes = find_dps(
|
||||
datapipe_graph,
|
||||
target_datapipe,
|
||||
)
|
||||
@@ -56,7 +60,7 @@ def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
|
||||
datapipe_id = id(datapipe)
|
||||
for parent_datapipe_id in datapipe_adjlist[datapipe_id][1]:
|
||||
parent_datapipe, _ = datapipe_adjlist[parent_datapipe_id]
|
||||
datapipe_graph = dp_utils.replace_dp(
|
||||
datapipe_graph = replace_dp(
|
||||
datapipe_graph,
|
||||
parent_datapipe,
|
||||
wrapper(parent_datapipe, **kwargs),
|
||||
@@ -157,18 +161,18 @@ class DataLoader(torch_data.DataLoader):
|
||||
# of the FeatureFetcher with a multiprocessing PyTorch DataLoader.
|
||||
|
||||
datapipe = datapipe.mark_end()
|
||||
datapipe_graph = dp_utils.traverse_dps(datapipe)
|
||||
datapipe_graph = traverse_dps(datapipe)
|
||||
|
||||
# (1) Insert minibatch distribution.
|
||||
# TODO(BarclayII): Currently I'm using sharding_filter() as a
|
||||
# concept demonstration. Later on minibatch distribution should be
|
||||
# merged into ItemSampler to maximize efficiency.
|
||||
item_samplers = dp_utils.find_dps(
|
||||
item_samplers = find_dps(
|
||||
datapipe_graph,
|
||||
ItemSampler,
|
||||
)
|
||||
for item_sampler in item_samplers:
|
||||
datapipe_graph = dp_utils.replace_dp(
|
||||
datapipe_graph = replace_dp(
|
||||
datapipe_graph,
|
||||
item_sampler,
|
||||
item_sampler.sharding_filter(),
|
||||
@@ -186,7 +190,7 @@ class DataLoader(torch_data.DataLoader):
|
||||
# (3) Limit the number of UVA threads used if the feature_fetcher has
|
||||
# overlapping optimization enabled.
|
||||
if num_workers == 0 and torch.cuda.is_available():
|
||||
feature_fetchers = dp_utils.find_dps(
|
||||
feature_fetchers = find_dps(
|
||||
datapipe_graph,
|
||||
FeatureFetcher,
|
||||
)
|
||||
@@ -200,7 +204,7 @@ class DataLoader(torch_data.DataLoader):
|
||||
and torch.cuda.is_available()
|
||||
):
|
||||
torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)
|
||||
samplers = dp_utils.find_dps(
|
||||
samplers = find_dps(
|
||||
datapipe_graph,
|
||||
SamplePerLayer,
|
||||
)
|
||||
@@ -210,7 +214,7 @@ class DataLoader(torch_data.DataLoader):
|
||||
gpu_graph_cache = construct_gpu_graph_cache(
|
||||
sampler, num_gpu_cached_edges, gpu_cache_threshold
|
||||
)
|
||||
datapipe_graph = dp_utils.replace_dp(
|
||||
datapipe_graph = replace_dp(
|
||||
datapipe_graph,
|
||||
sampler,
|
||||
sampler.fetch_and_sample(
|
||||
@@ -225,10 +229,10 @@ class DataLoader(torch_data.DataLoader):
|
||||
# Prefetching enables the data pipeline up to the CopyTo to run in a
|
||||
# separate thread.
|
||||
if torch.cuda.is_available():
|
||||
copiers = dp_utils.find_dps(datapipe_graph, CopyTo)
|
||||
copiers = find_dps(datapipe_graph, CopyTo)
|
||||
for copier in copiers:
|
||||
if copier.device.type == "cuda":
|
||||
datapipe_graph = dp_utils.replace_dp(
|
||||
datapipe_graph = replace_dp(
|
||||
datapipe_graph,
|
||||
copier,
|
||||
# Add prefetch so that CPU and GPU can run concurrently.
|
||||
|
||||
@@ -1,6 +1,27 @@
|
||||
"""DataPipe utilities"""
|
||||
|
||||
__all__ = ["datapipe_graph_to_adjlist"]
|
||||
import threading
|
||||
import time
|
||||
|
||||
from collections import deque
|
||||
from typing import final, List, Set, Type # pylint: disable=no-name-in-module
|
||||
|
||||
from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe
|
||||
from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps
|
||||
|
||||
__all__ = [
|
||||
"datapipe_graph_to_adjlist",
|
||||
"find_dps",
|
||||
"replace_dp",
|
||||
"traverse_dps",
|
||||
]
|
||||
|
||||
# Copied from:
|
||||
# https://github.com/pytorch/data/blob/88c8bdc6662f37649b7ea5df0bd90a4b24a56876/torchdata/datapipes/iter/util/prefetcher.py#L19-L20
|
||||
# Interval between buffer fulfillment checks
|
||||
PRODUCER_SLEEP_INTERVAL = 0.0001
|
||||
# Interval between checking items availability in buffer
|
||||
CONSUMER_SLEEP_INTERVAL = 0.0001
|
||||
|
||||
|
||||
def _get_parents(result_dict, datapipe_graph):
|
||||
@@ -51,3 +72,301 @@ def datapipe_graph_to_adjlist(datapipe_graph):
|
||||
result_dict = {}
|
||||
_get_parents(result_dict, datapipe_graph)
|
||||
return result_dict
|
||||
|
||||
|
||||
# Copied from:
|
||||
# https://github.com/pytorch/data/blob/88c8bdc6662f37649b7ea5df0bd90a4b24a56876/torchdata/dataloader2/graph/utils.py#L16-L35
|
||||
def find_dps(graph: DataPipeGraph, dp_type: Type[DataPipe]) -> List[DataPipe]:
|
||||
r"""
|
||||
Given the graph of DataPipe generated by ``traverse_dps`` function, return DataPipe
|
||||
instances with the provided DataPipe type.
|
||||
"""
|
||||
dps: List[DataPipe] = []
|
||||
cache: Set[int] = set()
|
||||
|
||||
def helper(g) -> None: # pyre-ignore
|
||||
for dp_id, (dp, src_graph) in g.items():
|
||||
if dp_id in cache:
|
||||
continue
|
||||
cache.add(dp_id)
|
||||
# Please not use `isinstance`, there is a bug.
|
||||
if type(dp) is dp_type: # pylint: disable=unidiomatic-typecheck
|
||||
dps.append(dp)
|
||||
helper(src_graph)
|
||||
|
||||
helper(graph)
|
||||
|
||||
return dps
|
||||
|
||||
|
||||
# Copied from:
|
||||
# https://github.com/pytorch/data/blob/88c8bdc6662f37649b7ea5df0bd90a4b24a56876/torchdata/dataloader2/graph/utils.py#L82-L97
|
||||
# Given the DataPipe needs to be replaced and the expected DataPipe, return a new graph
|
||||
def replace_dp(
|
||||
graph: DataPipeGraph, old_datapipe: DataPipe, new_datapipe: DataPipe
|
||||
) -> DataPipeGraph:
|
||||
r"""
|
||||
Given the graph of DataPipe generated by ``traverse_dps`` function and the
|
||||
DataPipe to be replaced and the new DataPipe, return the new graph of
|
||||
DataPipe.
|
||||
"""
|
||||
assert len(graph) == 1
|
||||
|
||||
if id(old_datapipe) in graph:
|
||||
graph = traverse_dps(new_datapipe)
|
||||
|
||||
final_datapipe = list(graph.values())[0][0]
|
||||
|
||||
for recv_dp, send_graph in graph.values():
|
||||
_replace_dp(recv_dp, send_graph, old_datapipe, new_datapipe)
|
||||
|
||||
return traverse_dps(final_datapipe)
|
||||
|
||||
|
||||
# For each `recv_dp`, find if the source_datapipe needs to be replaced by the new one.
|
||||
# If found, find where the `old_dp` is located in `recv_dp` and switch it to the `new_dp`
|
||||
def _replace_dp(
|
||||
recv_dp, send_graph: DataPipeGraph, old_dp: DataPipe, new_dp: DataPipe
|
||||
) -> None:
|
||||
old_dp_id = id(old_dp)
|
||||
for send_id in send_graph:
|
||||
if send_id == old_dp_id:
|
||||
_assign_attr(recv_dp, old_dp, new_dp, inner_dp=True)
|
||||
else:
|
||||
send_dp, sub_send_graph = send_graph[send_id]
|
||||
_replace_dp(send_dp, sub_send_graph, old_dp, new_dp)
|
||||
|
||||
|
||||
# Recursively re-assign datapipe for the sake of nested data structure
|
||||
# `inner_dp` is used to prevent recursive call if we have already met a `DataPipe`
|
||||
def _assign_attr(obj, old_dp, new_dp, inner_dp: bool = False):
|
||||
if obj is old_dp:
|
||||
return new_dp
|
||||
elif isinstance(obj, (IterDataPipe, MapDataPipe)):
|
||||
# Prevent recursive call for DataPipe
|
||||
if not inner_dp:
|
||||
return None
|
||||
for k in list(obj.__dict__.keys()):
|
||||
new_obj = _assign_attr(obj.__dict__[k], old_dp, new_dp)
|
||||
if new_obj is not None:
|
||||
obj.__dict__[k] = new_obj
|
||||
break
|
||||
return None
|
||||
elif isinstance(obj, dict):
|
||||
for k in list(obj.keys()):
|
||||
new_obj = _assign_attr(obj[k], old_dp, new_dp)
|
||||
if new_obj is not None:
|
||||
obj[k] = new_obj
|
||||
break
|
||||
return None
|
||||
# Tuple is immutable, has to re-create a tuple
|
||||
elif isinstance(obj, tuple):
|
||||
temp_list = []
|
||||
flag = False
|
||||
for item in obj:
|
||||
new_obj = _assign_attr(item, old_dp, new_dp, inner_dp)
|
||||
if new_obj is not None:
|
||||
flag = True
|
||||
temp_list.append(new_dp)
|
||||
else:
|
||||
temp_list.append(item)
|
||||
if flag:
|
||||
return tuple(temp_list) # Special case
|
||||
else:
|
||||
return None
|
||||
elif isinstance(obj, list):
|
||||
for i in range(len(obj)): # pylint: disable=consider-using-enumerate
|
||||
new_obj = _assign_attr(obj[i], old_dp, new_dp, inner_dp)
|
||||
if new_obj is not None:
|
||||
obj[i] = new_obj
|
||||
break
|
||||
return None
|
||||
elif isinstance(obj, set):
|
||||
new_obj = None
|
||||
for item in obj:
|
||||
if _assign_attr(item, old_dp, new_dp, inner_dp) is not None:
|
||||
new_obj = new_dp
|
||||
break
|
||||
if new_obj is not None:
|
||||
obj.remove(old_dp)
|
||||
obj.add(new_dp)
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class _PrefetchData:
|
||||
def __init__(self, source_datapipe, buffer_size: int):
|
||||
self.run_prefetcher: bool = True
|
||||
self.prefetch_buffer: Deque = deque()
|
||||
self.buffer_size: int = buffer_size
|
||||
self.source_datapipe = source_datapipe
|
||||
self.stop_iteration: bool = False
|
||||
self.paused: bool = False
|
||||
|
||||
|
||||
# Copied from:
|
||||
# https://github.com/pytorch/data/blob/88c8bdc6662f37649b7ea5df0bd90a4b24a56876/torchdata/datapipes/iter/util/prefetcher.py#L34-L172
|
||||
@functional_datapipe("prefetch")
|
||||
class PrefetcherIterDataPipe(IterDataPipe):
|
||||
r"""
|
||||
Prefetches elements from the source DataPipe and puts them into a buffer
|
||||
(functional name: ``prefetch``). Prefetching performs the operations (e.g.
|
||||
I/O, computations) of the DataPipes up to this one ahead of time and stores
|
||||
the result in the buffer, ready to be consumed by the subsequent DataPipe.
|
||||
It has no effect aside from getting the sample ready ahead of time.
|
||||
|
||||
This is used by ``MultiProcessingReadingService`` when the arguments
|
||||
``worker_prefetch_cnt`` (for prefetching at each worker process) or
|
||||
``main_prefetch_cnt`` (for prefetching at the main loop) are greater than 0.
|
||||
|
||||
Beyond the built-in use cases, this can be useful to put after I/O DataPipes
|
||||
that have expensive I/O operations (e.g. takes a long time to request a file
|
||||
from a remote server).
|
||||
|
||||
Args:
|
||||
source_datapipe: IterDataPipe from which samples are prefetched
|
||||
buffer_size: the size of the buffer which stores the prefetched samples
|
||||
|
||||
Example:
|
||||
>>> from torchdata.datapipes.iter import IterableWrapper
|
||||
>>> dp = IterableWrapper(file_paths).open_files().prefetch(5)
|
||||
"""
|
||||
|
||||
def __init__(self, source_datapipe, buffer_size: int = 10):
|
||||
self.source_datapipe = source_datapipe
|
||||
if buffer_size <= 0:
|
||||
raise ValueError(
|
||||
"'buffer_size' is required to be a positive integer."
|
||||
)
|
||||
self.buffer_size = buffer_size
|
||||
self.thread: Optional[threading.Thread] = None
|
||||
self.prefetch_data: Optional[_PrefetchData] = None
|
||||
|
||||
@staticmethod
|
||||
def thread_worker(
|
||||
prefetch_data: _PrefetchData,
|
||||
): # pylint: disable=missing-function-docstring
|
||||
itr = iter(prefetch_data.source_datapipe)
|
||||
while not prefetch_data.stop_iteration:
|
||||
# Run if not paused
|
||||
while prefetch_data.run_prefetcher:
|
||||
if (
|
||||
len(prefetch_data.prefetch_buffer)
|
||||
< prefetch_data.buffer_size
|
||||
):
|
||||
try:
|
||||
item = next(itr)
|
||||
prefetch_data.prefetch_buffer.append(item)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
prefetch_data.run_prefetcher = False
|
||||
prefetch_data.stop_iteration = True
|
||||
prefetch_data.prefetch_buffer.append(e)
|
||||
else: # Buffer is full, waiting for main thread to consume items
|
||||
# TODO: Calculate sleep interval based on previous consumption speed
|
||||
time.sleep(PRODUCER_SLEEP_INTERVAL)
|
||||
prefetch_data.paused = True
|
||||
# Sleep longer when this prefetcher thread is paused
|
||||
time.sleep(PRODUCER_SLEEP_INTERVAL * 10)
|
||||
|
||||
def __iter__(self):
|
||||
try:
|
||||
prefetch_data = _PrefetchData(
|
||||
self.source_datapipe, self.buffer_size
|
||||
)
|
||||
self.prefetch_data = prefetch_data
|
||||
thread = threading.Thread(
|
||||
target=PrefetcherIterDataPipe.thread_worker,
|
||||
args=(prefetch_data,),
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
self.thread = thread
|
||||
|
||||
while (
|
||||
not prefetch_data.stop_iteration
|
||||
or len(prefetch_data.prefetch_buffer) > 0
|
||||
):
|
||||
if len(prefetch_data.prefetch_buffer) > 0:
|
||||
data = prefetch_data.prefetch_buffer.popleft()
|
||||
if isinstance(data, Exception):
|
||||
if isinstance(data, StopIteration):
|
||||
break
|
||||
raise data
|
||||
yield data
|
||||
else:
|
||||
time.sleep(CONSUMER_SLEEP_INTERVAL)
|
||||
finally:
|
||||
if "prefetch_data" in locals():
|
||||
prefetch_data.run_prefetcher = False
|
||||
prefetch_data.stop_iteration = True
|
||||
prefetch_data.paused = False
|
||||
if "thread" in locals():
|
||||
thread.join()
|
||||
|
||||
def __getstate__(self):
|
||||
"""
|
||||
Getting state in threading environment requires next operations:
|
||||
1) Stopping of the producer thread.
|
||||
2) Saving buffer.
|
||||
3) Adding lazy restart of producer thread when __next__ is called again
|
||||
(this will guarantee that you only change state of the source_datapipe
|
||||
after entire state of the graph is saved).
|
||||
"""
|
||||
# TODO: Update __getstate__ and __setstate__ to support snapshotting and restoration
|
||||
return {
|
||||
"source_datapipe": self.source_datapipe,
|
||||
"buffer_size": self.buffer_size,
|
||||
}
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.source_datapipe = state["source_datapipe"]
|
||||
self.buffer_size = state["buffer_size"]
|
||||
self.thread = None
|
||||
|
||||
@final
|
||||
def reset(self): # pylint: disable=missing-function-docstring
|
||||
self.shutdown()
|
||||
|
||||
def pause(self): # pylint: disable=missing-function-docstring
|
||||
if self.thread is not None:
|
||||
assert self.prefetch_data is not None
|
||||
self.prefetch_data.run_prefetcher = False
|
||||
if self.thread.is_alive():
|
||||
# Blocking until the thread is paused
|
||||
while not self.prefetch_data.paused:
|
||||
time.sleep(PRODUCER_SLEEP_INTERVAL * 10)
|
||||
|
||||
@final
|
||||
def resume(self): # pylint: disable=missing-function-docstring
|
||||
if (
|
||||
self.thread is not None
|
||||
and self.prefetch_data is not None
|
||||
and (
|
||||
not self.prefetch_data.stop_iteration
|
||||
or len(self.prefetch_data.prefetch_buffer) > 0
|
||||
)
|
||||
):
|
||||
self.prefetch_data.run_prefetcher = True
|
||||
self.prefetch_data.paused = False
|
||||
|
||||
@final
|
||||
def shutdown(self): # pylint: disable=missing-function-docstring
|
||||
if hasattr(self, "prefetch_data") and self.prefetch_data is not None:
|
||||
self.prefetch_data.run_prefetcher = False
|
||||
self.prefetch_data.stop_iteration = True
|
||||
self.prefetch_data.paused = False
|
||||
self.prefetch_data = None
|
||||
if hasattr(self, "thread") and self.thread is not None:
|
||||
self.thread.join()
|
||||
self.thread = None
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
def __len__(self) -> int:
|
||||
if isinstance(self.source_datapipe, Sized):
|
||||
return len(self.source_datapipe)
|
||||
raise TypeError(
|
||||
f"{type(self).__name__} instance doesn't have valid length"
|
||||
)
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Callable, Iterator, Optional, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torchdata.datapipes.iter import IterDataPipe
|
||||
from torch.utils.data import IterDataPipe
|
||||
|
||||
from .internal import calculate_range
|
||||
from .internal_utils import gb_warning
|
||||
@@ -112,9 +112,9 @@ class ItemSampler(IterDataPipe):
|
||||
pairs with negative sources/destinations.
|
||||
|
||||
Note: This class `ItemSampler` is not decorated with
|
||||
`torchdata.datapipes.functional_datapipe` on purpose. This indicates it
|
||||
`torch.utils.data.functional_datapipe` on purpose. This indicates it
|
||||
does not support function-like call. But any iterable datapipes from
|
||||
`torchdata` can be further appended.
|
||||
`torch.utils.data.datapipes` can be further appended.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -195,7 +195,7 @@ class ItemSampler(IterDataPipe):
|
||||
compacted_seeds=None, blocks=None,)
|
||||
|
||||
5. Further process batches with other datapipes such as
|
||||
:class:`torchdata.datapipes.iter.Mapper`.
|
||||
:class:`torch.utils.data.datapipes.iter.Mapper`.
|
||||
|
||||
>>> item_set = gb.ItemSet(torch.arange(0, 10))
|
||||
>>> data_pipe = gb.ItemSampler(item_set, 4)
|
||||
@@ -365,9 +365,9 @@ class DistributedItemSampler(ItemSampler):
|
||||
of items.
|
||||
|
||||
Note: This class `DistributedItemSampler` is not decorated with
|
||||
`torchdata.datapipes.functional_datapipe` on purpose. This indicates it
|
||||
`torch.utils.data.functional_datapipe` on purpose. This indicates it
|
||||
does not support function-like call. But any iterable datapipes from
|
||||
`torchdata` can be further appended.
|
||||
`torch.utils.data.datapipes` can be further appended.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
@@ -7,7 +7,7 @@ import dgl.graphbolt
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import torchdata.dataloader2.graph as dp_utils
|
||||
from dgl.graphbolt.internal import find_dps, traverse_dps
|
||||
|
||||
from . import gb_test_utils
|
||||
|
||||
@@ -137,13 +137,13 @@ def test_gpu_sampling_DataLoader(
|
||||
bufferer_cnt += num_layers
|
||||
awaiter_cnt += num_layers
|
||||
datapipe = dataloader.dataset
|
||||
datapipe_graph = dp_utils.traverse_dps(datapipe)
|
||||
awaiters = dp_utils.find_dps(
|
||||
datapipe_graph = traverse_dps(datapipe)
|
||||
awaiters = find_dps(
|
||||
datapipe_graph,
|
||||
dgl.graphbolt.Waiter,
|
||||
)
|
||||
assert len(awaiters) == awaiter_cnt
|
||||
bufferers = dp_utils.find_dps(
|
||||
bufferers = find_dps(
|
||||
datapipe_graph,
|
||||
dgl.graphbolt.Bufferer,
|
||||
)
|
||||
|
||||
@@ -408,11 +408,10 @@ def test_append_with_other_datapipes():
|
||||
batch_size = 4
|
||||
item_set = gb.ItemSet(torch.arange(0, num_ids), names="seeds")
|
||||
data_pipe = gb.ItemSampler(item_set, batch_size)
|
||||
# torchdata.datapipes.iter.Enumerator
|
||||
data_pipe = data_pipe.enumerate()
|
||||
for i, (idx, data) in enumerate(data_pipe):
|
||||
assert i == idx
|
||||
assert len(data.seeds) == batch_size
|
||||
for i, data in enumerate(data_pipe):
|
||||
expected = torch.full((batch_size,), i * batch_size)
|
||||
expected = expected + torch.tensor([0, 1, 2, 3])
|
||||
assert torch.equal(data.seeds, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 4])
|
||||
|
||||
Reference in New Issue
Block a user