Remove dependency on torchdata. (#7638)

Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
This commit is contained in:
Hongzhi (Steve), Chen
2024-08-06 11:20:26 +08:00
committed by GitHub
parent cb4604aca2
commit 26ff09fdbf
5 changed files with 350 additions and 28 deletions

View File

@@ -4,14 +4,18 @@ from collections import OrderedDict
import torch
import torch.utils.data as torch_data
import torchdata.dataloader2.graph as dp_utils
from .base import CopyTo, get_host_to_device_uva_stream
from .feature_fetcher import FeatureFetcher, FeatureFetcherStartMarker
from .impl.gpu_graph_cache import GPUGraphCache
from .impl.neighbor_sampler import SamplePerLayer
from .internal import datapipe_graph_to_adjlist
from .internal import (
datapipe_graph_to_adjlist,
find_dps,
replace_dp,
traverse_dps,
)
from .item_sampler import ItemSampler
@@ -47,7 +51,7 @@ def construct_gpu_graph_cache(
def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
"""Find parent of target_datapipe and wrap it with ."""
datapipes = dp_utils.find_dps(
datapipes = find_dps(
datapipe_graph,
target_datapipe,
)
@@ -56,7 +60,7 @@ def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
datapipe_id = id(datapipe)
for parent_datapipe_id in datapipe_adjlist[datapipe_id][1]:
parent_datapipe, _ = datapipe_adjlist[parent_datapipe_id]
datapipe_graph = dp_utils.replace_dp(
datapipe_graph = replace_dp(
datapipe_graph,
parent_datapipe,
wrapper(parent_datapipe, **kwargs),
@@ -157,18 +161,18 @@ class DataLoader(torch_data.DataLoader):
# of the FeatureFetcher with a multiprocessing PyTorch DataLoader.
datapipe = datapipe.mark_end()
datapipe_graph = dp_utils.traverse_dps(datapipe)
datapipe_graph = traverse_dps(datapipe)
# (1) Insert minibatch distribution.
# TODO(BarclayII): Currently I'm using sharding_filter() as a
# concept demonstration. Later on minibatch distribution should be
# merged into ItemSampler to maximize efficiency.
item_samplers = dp_utils.find_dps(
item_samplers = find_dps(
datapipe_graph,
ItemSampler,
)
for item_sampler in item_samplers:
datapipe_graph = dp_utils.replace_dp(
datapipe_graph = replace_dp(
datapipe_graph,
item_sampler,
item_sampler.sharding_filter(),
@@ -186,7 +190,7 @@ class DataLoader(torch_data.DataLoader):
# (3) Limit the number of UVA threads used if the feature_fetcher has
# overlapping optimization enabled.
if num_workers == 0 and torch.cuda.is_available():
feature_fetchers = dp_utils.find_dps(
feature_fetchers = find_dps(
datapipe_graph,
FeatureFetcher,
)
@@ -200,7 +204,7 @@ class DataLoader(torch_data.DataLoader):
and torch.cuda.is_available()
):
torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)
samplers = dp_utils.find_dps(
samplers = find_dps(
datapipe_graph,
SamplePerLayer,
)
@@ -210,7 +214,7 @@ class DataLoader(torch_data.DataLoader):
gpu_graph_cache = construct_gpu_graph_cache(
sampler, num_gpu_cached_edges, gpu_cache_threshold
)
datapipe_graph = dp_utils.replace_dp(
datapipe_graph = replace_dp(
datapipe_graph,
sampler,
sampler.fetch_and_sample(
@@ -225,10 +229,10 @@ class DataLoader(torch_data.DataLoader):
# Prefetching enables the data pipeline up to the CopyTo to run in a
# separate thread.
if torch.cuda.is_available():
copiers = dp_utils.find_dps(datapipe_graph, CopyTo)
copiers = find_dps(datapipe_graph, CopyTo)
for copier in copiers:
if copier.device.type == "cuda":
datapipe_graph = dp_utils.replace_dp(
datapipe_graph = replace_dp(
datapipe_graph,
copier,
# Add prefetch so that CPU and GPU can run concurrently.

View File

@@ -1,6 +1,27 @@
"""DataPipe utilities"""
__all__ = ["datapipe_graph_to_adjlist"]
import threading
import time
from collections import deque
from typing import final, List, Set, Type # pylint: disable=no-name-in-module
from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe
from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps
__all__ = [
"datapipe_graph_to_adjlist",
"find_dps",
"replace_dp",
"traverse_dps",
]
# Copied from:
# https://github.com/pytorch/data/blob/88c8bdc6662f37649b7ea5df0bd90a4b24a56876/torchdata/datapipes/iter/util/prefetcher.py#L19-L20
# Interval between buffer fulfillment checks
PRODUCER_SLEEP_INTERVAL = 0.0001
# Interval between checking items availability in buffer
CONSUMER_SLEEP_INTERVAL = 0.0001
def _get_parents(result_dict, datapipe_graph):
@@ -51,3 +72,301 @@ def datapipe_graph_to_adjlist(datapipe_graph):
result_dict = {}
_get_parents(result_dict, datapipe_graph)
return result_dict
# Copied from:
# https://github.com/pytorch/data/blob/88c8bdc6662f37649b7ea5df0bd90a4b24a56876/torchdata/dataloader2/graph/utils.py#L16-L35
def find_dps(graph: DataPipeGraph, dp_type: Type[DataPipe]) -> List[DataPipe]:
r"""
Given the graph of DataPipe generated by ``traverse_dps`` function, return DataPipe
instances with the provided DataPipe type.
"""
dps: List[DataPipe] = []
cache: Set[int] = set()
def helper(g) -> None: # pyre-ignore
for dp_id, (dp, src_graph) in g.items():
if dp_id in cache:
continue
cache.add(dp_id)
# Please not use `isinstance`, there is a bug.
if type(dp) is dp_type: # pylint: disable=unidiomatic-typecheck
dps.append(dp)
helper(src_graph)
helper(graph)
return dps
# Copied from:
# https://github.com/pytorch/data/blob/88c8bdc6662f37649b7ea5df0bd90a4b24a56876/torchdata/dataloader2/graph/utils.py#L82-L97
# Given the DataPipe needs to be replaced and the expected DataPipe, return a new graph
def replace_dp(
graph: DataPipeGraph, old_datapipe: DataPipe, new_datapipe: DataPipe
) -> DataPipeGraph:
r"""
Given the graph of DataPipe generated by ``traverse_dps`` function and the
DataPipe to be replaced and the new DataPipe, return the new graph of
DataPipe.
"""
assert len(graph) == 1
if id(old_datapipe) in graph:
graph = traverse_dps(new_datapipe)
final_datapipe = list(graph.values())[0][0]
for recv_dp, send_graph in graph.values():
_replace_dp(recv_dp, send_graph, old_datapipe, new_datapipe)
return traverse_dps(final_datapipe)
# For each `recv_dp`, find if the source_datapipe needs to be replaced by the new one.
# If found, find where the `old_dp` is located in `recv_dp` and switch it to the `new_dp`
def _replace_dp(
recv_dp, send_graph: DataPipeGraph, old_dp: DataPipe, new_dp: DataPipe
) -> None:
old_dp_id = id(old_dp)
for send_id in send_graph:
if send_id == old_dp_id:
_assign_attr(recv_dp, old_dp, new_dp, inner_dp=True)
else:
send_dp, sub_send_graph = send_graph[send_id]
_replace_dp(send_dp, sub_send_graph, old_dp, new_dp)
# Recursively re-assign datapipe for the sake of nested data structure
# `inner_dp` is used to prevent recursive call if we have already met a `DataPipe`
def _assign_attr(obj, old_dp, new_dp, inner_dp: bool = False):
if obj is old_dp:
return new_dp
elif isinstance(obj, (IterDataPipe, MapDataPipe)):
# Prevent recursive call for DataPipe
if not inner_dp:
return None
for k in list(obj.__dict__.keys()):
new_obj = _assign_attr(obj.__dict__[k], old_dp, new_dp)
if new_obj is not None:
obj.__dict__[k] = new_obj
break
return None
elif isinstance(obj, dict):
for k in list(obj.keys()):
new_obj = _assign_attr(obj[k], old_dp, new_dp)
if new_obj is not None:
obj[k] = new_obj
break
return None
# Tuple is immutable, has to re-create a tuple
elif isinstance(obj, tuple):
temp_list = []
flag = False
for item in obj:
new_obj = _assign_attr(item, old_dp, new_dp, inner_dp)
if new_obj is not None:
flag = True
temp_list.append(new_dp)
else:
temp_list.append(item)
if flag:
return tuple(temp_list) # Special case
else:
return None
elif isinstance(obj, list):
for i in range(len(obj)): # pylint: disable=consider-using-enumerate
new_obj = _assign_attr(obj[i], old_dp, new_dp, inner_dp)
if new_obj is not None:
obj[i] = new_obj
break
return None
elif isinstance(obj, set):
new_obj = None
for item in obj:
if _assign_attr(item, old_dp, new_dp, inner_dp) is not None:
new_obj = new_dp
break
if new_obj is not None:
obj.remove(old_dp)
obj.add(new_dp)
return None
else:
return None
class _PrefetchData:
def __init__(self, source_datapipe, buffer_size: int):
self.run_prefetcher: bool = True
self.prefetch_buffer: Deque = deque()
self.buffer_size: int = buffer_size
self.source_datapipe = source_datapipe
self.stop_iteration: bool = False
self.paused: bool = False
# Copied from:
# https://github.com/pytorch/data/blob/88c8bdc6662f37649b7ea5df0bd90a4b24a56876/torchdata/datapipes/iter/util/prefetcher.py#L34-L172
@functional_datapipe("prefetch")
class PrefetcherIterDataPipe(IterDataPipe):
r"""
Prefetches elements from the source DataPipe and puts them into a buffer
(functional name: ``prefetch``). Prefetching performs the operations (e.g.
I/O, computations) of the DataPipes up to this one ahead of time and stores
the result in the buffer, ready to be consumed by the subsequent DataPipe.
It has no effect aside from getting the sample ready ahead of time.
This is used by ``MultiProcessingReadingService`` when the arguments
``worker_prefetch_cnt`` (for prefetching at each worker process) or
``main_prefetch_cnt`` (for prefetching at the main loop) are greater than 0.
Beyond the built-in use cases, this can be useful to put after I/O DataPipes
that have expensive I/O operations (e.g. takes a long time to request a file
from a remote server).
Args:
source_datapipe: IterDataPipe from which samples are prefetched
buffer_size: the size of the buffer which stores the prefetched samples
Example:
>>> from torchdata.datapipes.iter import IterableWrapper
>>> dp = IterableWrapper(file_paths).open_files().prefetch(5)
"""
def __init__(self, source_datapipe, buffer_size: int = 10):
self.source_datapipe = source_datapipe
if buffer_size <= 0:
raise ValueError(
"'buffer_size' is required to be a positive integer."
)
self.buffer_size = buffer_size
self.thread: Optional[threading.Thread] = None
self.prefetch_data: Optional[_PrefetchData] = None
@staticmethod
def thread_worker(
prefetch_data: _PrefetchData,
): # pylint: disable=missing-function-docstring
itr = iter(prefetch_data.source_datapipe)
while not prefetch_data.stop_iteration:
# Run if not paused
while prefetch_data.run_prefetcher:
if (
len(prefetch_data.prefetch_buffer)
< prefetch_data.buffer_size
):
try:
item = next(itr)
prefetch_data.prefetch_buffer.append(item)
except Exception as e: # pylint: disable=broad-except
prefetch_data.run_prefetcher = False
prefetch_data.stop_iteration = True
prefetch_data.prefetch_buffer.append(e)
else: # Buffer is full, waiting for main thread to consume items
# TODO: Calculate sleep interval based on previous consumption speed
time.sleep(PRODUCER_SLEEP_INTERVAL)
prefetch_data.paused = True
# Sleep longer when this prefetcher thread is paused
time.sleep(PRODUCER_SLEEP_INTERVAL * 10)
def __iter__(self):
try:
prefetch_data = _PrefetchData(
self.source_datapipe, self.buffer_size
)
self.prefetch_data = prefetch_data
thread = threading.Thread(
target=PrefetcherIterDataPipe.thread_worker,
args=(prefetch_data,),
daemon=True,
)
thread.start()
self.thread = thread
while (
not prefetch_data.stop_iteration
or len(prefetch_data.prefetch_buffer) > 0
):
if len(prefetch_data.prefetch_buffer) > 0:
data = prefetch_data.prefetch_buffer.popleft()
if isinstance(data, Exception):
if isinstance(data, StopIteration):
break
raise data
yield data
else:
time.sleep(CONSUMER_SLEEP_INTERVAL)
finally:
if "prefetch_data" in locals():
prefetch_data.run_prefetcher = False
prefetch_data.stop_iteration = True
prefetch_data.paused = False
if "thread" in locals():
thread.join()
def __getstate__(self):
"""
Getting state in threading environment requires next operations:
1) Stopping of the producer thread.
2) Saving buffer.
3) Adding lazy restart of producer thread when __next__ is called again
(this will guarantee that you only change state of the source_datapipe
after entire state of the graph is saved).
"""
# TODO: Update __getstate__ and __setstate__ to support snapshotting and restoration
return {
"source_datapipe": self.source_datapipe,
"buffer_size": self.buffer_size,
}
def __setstate__(self, state):
self.source_datapipe = state["source_datapipe"]
self.buffer_size = state["buffer_size"]
self.thread = None
@final
def reset(self): # pylint: disable=missing-function-docstring
self.shutdown()
def pause(self): # pylint: disable=missing-function-docstring
if self.thread is not None:
assert self.prefetch_data is not None
self.prefetch_data.run_prefetcher = False
if self.thread.is_alive():
# Blocking until the thread is paused
while not self.prefetch_data.paused:
time.sleep(PRODUCER_SLEEP_INTERVAL * 10)
@final
def resume(self): # pylint: disable=missing-function-docstring
if (
self.thread is not None
and self.prefetch_data is not None
and (
not self.prefetch_data.stop_iteration
or len(self.prefetch_data.prefetch_buffer) > 0
)
):
self.prefetch_data.run_prefetcher = True
self.prefetch_data.paused = False
@final
def shutdown(self): # pylint: disable=missing-function-docstring
if hasattr(self, "prefetch_data") and self.prefetch_data is not None:
self.prefetch_data.run_prefetcher = False
self.prefetch_data.stop_iteration = True
self.prefetch_data.paused = False
self.prefetch_data = None
if hasattr(self, "thread") and self.thread is not None:
self.thread.join()
self.thread = None
def __del__(self):
self.shutdown()
def __len__(self) -> int:
if isinstance(self.source_datapipe, Sized):
return len(self.source_datapipe)
raise TypeError(
f"{type(self).__name__} instance doesn't have valid length"
)

View File

@@ -6,7 +6,7 @@ from typing import Callable, Iterator, Optional, Union
import numpy as np
import torch
import torch.distributed as dist
from torchdata.datapipes.iter import IterDataPipe
from torch.utils.data import IterDataPipe
from .internal import calculate_range
from .internal_utils import gb_warning
@@ -112,9 +112,9 @@ class ItemSampler(IterDataPipe):
pairs with negative sources/destinations.
Note: This class `ItemSampler` is not decorated with
`torchdata.datapipes.functional_datapipe` on purpose. This indicates it
`torch.utils.data.functional_datapipe` on purpose. This indicates it
does not support function-like call. But any iterable datapipes from
`torchdata` can be further appended.
`torch.utils.data.datapipes` can be further appended.
Parameters
----------
@@ -195,7 +195,7 @@ class ItemSampler(IterDataPipe):
compacted_seeds=None, blocks=None,)
5. Further process batches with other datapipes such as
:class:`torchdata.datapipes.iter.Mapper`.
:class:`torch.utils.data.datapipes.iter.Mapper`.
>>> item_set = gb.ItemSet(torch.arange(0, 10))
>>> data_pipe = gb.ItemSampler(item_set, 4)
@@ -365,9 +365,9 @@ class DistributedItemSampler(ItemSampler):
of items.
Note: This class `DistributedItemSampler` is not decorated with
`torchdata.datapipes.functional_datapipe` on purpose. This indicates it
`torch.utils.data.functional_datapipe` on purpose. This indicates it
does not support function-like call. But any iterable datapipes from
`torchdata` can be further appended.
`torch.utils.data.datapipes` can be further appended.
Parameters
----------

View File

@@ -7,7 +7,7 @@ import dgl.graphbolt
import pytest
import torch
import torchdata.dataloader2.graph as dp_utils
from dgl.graphbolt.internal import find_dps, traverse_dps
from . import gb_test_utils
@@ -137,13 +137,13 @@ def test_gpu_sampling_DataLoader(
bufferer_cnt += num_layers
awaiter_cnt += num_layers
datapipe = dataloader.dataset
datapipe_graph = dp_utils.traverse_dps(datapipe)
awaiters = dp_utils.find_dps(
datapipe_graph = traverse_dps(datapipe)
awaiters = find_dps(
datapipe_graph,
dgl.graphbolt.Waiter,
)
assert len(awaiters) == awaiter_cnt
bufferers = dp_utils.find_dps(
bufferers = find_dps(
datapipe_graph,
dgl.graphbolt.Bufferer,
)

View File

@@ -408,11 +408,10 @@ def test_append_with_other_datapipes():
batch_size = 4
item_set = gb.ItemSet(torch.arange(0, num_ids), names="seeds")
data_pipe = gb.ItemSampler(item_set, batch_size)
# torchdata.datapipes.iter.Enumerator
data_pipe = data_pipe.enumerate()
for i, (idx, data) in enumerate(data_pipe):
assert i == idx
assert len(data.seeds) == batch_size
for i, data in enumerate(data_pipe):
expected = torch.full((batch_size,), i * batch_size)
expected = expected + torch.tensor([0, 1, 2, 3])
assert torch.equal(data.seeds, expected)
@pytest.mark.parametrize("batch_size", [1, 4])