mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[Dist] backward compatible with dgl.dataloading.DistDataLoader (#7782)
This commit is contained in:
@@ -68,6 +68,10 @@ Distributed Sampling
|
||||
Distributed DataLoader
|
||||
``````````````````````
|
||||
|
||||
.. autoclass:: NodeCollator
|
||||
|
||||
.. autoclass:: EdgeCollator
|
||||
|
||||
.. autoclass:: DistDataLoader
|
||||
|
||||
.. autoclass:: DistNodeDataLoader
|
||||
|
||||
@@ -1474,6 +1474,48 @@ class GraphDataLoader(torch.utils.data.DataLoader):
|
||||
raise DGLError("set_epoch is only available when use_ddp is True.")
|
||||
|
||||
|
||||
class NodeCollator:
|
||||
"""Deprecated. Please use :class:`~dgl.distributed.NodeCollator` instead."""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
dgl_warning(
|
||||
"NodeCollator is defined in dgl.distributed This class is for "
|
||||
"backward compatibility and will be removed soon. Please update "
|
||||
"your code to use `dgl.distributed.NodeCollator`."
|
||||
)
|
||||
from ..distributed import NodeCollator as NewNodeCollator
|
||||
|
||||
return NewNodeCollator(*args, **kwargs)
|
||||
|
||||
|
||||
class EdgeCollator:
|
||||
"""Deprecated. Please use :class:`~dgl.distributed.EdgeCollator` instead."""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
dgl_warning(
|
||||
"EdgeCollator is defined in dgl.distributed This class is for "
|
||||
"backward compatibility and will be removed soon. Please update "
|
||||
"your code to use `dgl.distributed.EdgeCollator`."
|
||||
)
|
||||
from ..distributed import EdgeCollator as NewEdgeCollator
|
||||
|
||||
return NewEdgeCollator(*args, **kwargs)
|
||||
|
||||
|
||||
class DistDataLoader:
|
||||
"""Deprecated. Please use :class:`~dgl.distributed.DistDataLoader` instead."""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
dgl_warning(
|
||||
"DistDataLoader is defined in dgl.distributed This class is for "
|
||||
"backward compatibility and will be removed soon. Please update "
|
||||
"your code to use `dgl.distributed.DistDataLoader`."
|
||||
)
|
||||
from ..distributed import DistDataLoader as NewDistDataLoader
|
||||
|
||||
return NewDistDataLoader(*args, **kwargs)
|
||||
|
||||
|
||||
class DistNodeDataLoader:
|
||||
"""Deprecated. Please use :class:`~dgl.distributed.DistNodeDataLoader`
|
||||
instead.
|
||||
|
||||
@@ -6,6 +6,8 @@ from .dist_dataloader import (
|
||||
DistDataLoader,
|
||||
DistEdgeDataLoader,
|
||||
DistNodeDataLoader,
|
||||
EdgeCollator,
|
||||
NodeCollator,
|
||||
)
|
||||
from .dist_graph import DistGraph, DistGraphServer, edge_split, node_split
|
||||
from .dist_tensor import DistTensor
|
||||
|
||||
@@ -9,7 +9,13 @@ from ..base import EID, NID
|
||||
from ..convert import heterograph
|
||||
from .dist_context import get_sampler_pool
|
||||
|
||||
__all__ = ["DistDataLoader", "DistNodeDataLoader", "DistEdgeDataLoader"]
|
||||
__all__ = [
|
||||
"NodeCollator",
|
||||
"EdgeCollator",
|
||||
"DistDataLoader",
|
||||
"DistNodeDataLoader",
|
||||
"DistEdgeDataLoader",
|
||||
]
|
||||
|
||||
DATALOADER_ID = 0
|
||||
|
||||
|
||||
@@ -821,6 +821,31 @@ def test_dataloader_worker_init_fn():
|
||||
pass
|
||||
|
||||
|
||||
def test_distributed_dataloaders():
|
||||
# Test distributed dataloaders could be successfully imported.
|
||||
try:
|
||||
from dgl.dataloading import (
|
||||
DistDataLoader,
|
||||
DistEdgeDataLoader,
|
||||
DistNodeDataLoader,
|
||||
EdgeCollator,
|
||||
NodeCollator,
|
||||
)
|
||||
except ImportError:
|
||||
pytest.fail("Distributed DataLoader from dataloading import failed")
|
||||
|
||||
try:
|
||||
from dgl.distributed import (
|
||||
DistDataLoader,
|
||||
DistEdgeDataLoader,
|
||||
DistNodeDataLoader,
|
||||
EdgeCollator,
|
||||
NodeCollator,
|
||||
)
|
||||
except ImportError:
|
||||
pytest.fail("Distributed DataLoader from dataloading import failed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test_node_dataloader(F.int32, 'neighbor', None)
|
||||
test_edge_dataloader_excludes(
|
||||
|
||||
Reference in New Issue
Block a user