[Dist] backward compatible with dgl.dataloading.DistDataLoader (#7782)

This commit is contained in:
Rhett Ying
2024-09-06 11:21:52 +08:00
committed by GitHub
parent 32b11c98e5
commit 12841c675b
5 changed files with 80 additions and 1 deletions

View File

@@ -68,6 +68,10 @@ Distributed Sampling
Distributed DataLoader
``````````````````````
.. autoclass:: NodeCollator
.. autoclass:: EdgeCollator
.. autoclass:: DistDataLoader
.. autoclass:: DistNodeDataLoader

View File

@@ -1474,6 +1474,48 @@ class GraphDataLoader(torch.utils.data.DataLoader):
raise DGLError("set_epoch is only available when use_ddp is True.")
class NodeCollator:
"""Deprecated. Please use :class:`~dgl.distributed.NodeCollator` instead."""
def __new__(cls, *args, **kwargs):
dgl_warning(
"NodeCollator is defined in dgl.distributed This class is for "
"backward compatibility and will be removed soon. Please update "
"your code to use `dgl.distributed.NodeCollator`."
)
from ..distributed import NodeCollator as NewNodeCollator
return NewNodeCollator(*args, **kwargs)
class EdgeCollator:
"""Deprecated. Please use :class:`~dgl.distributed.EdgeCollator` instead."""
def __new__(cls, *args, **kwargs):
dgl_warning(
"EdgeCollator is defined in dgl.distributed This class is for "
"backward compatibility and will be removed soon. Please update "
"your code to use `dgl.distributed.EdgeCollator`."
)
from ..distributed import EdgeCollator as NewEdgeCollator
return NewEdgeCollator(*args, **kwargs)
class DistDataLoader:
"""Deprecated. Please use :class:`~dgl.distributed.DistDataLoader` instead."""
def __new__(cls, *args, **kwargs):
dgl_warning(
"DistDataLoader is defined in dgl.distributed This class is for "
"backward compatibility and will be removed soon. Please update "
"your code to use `dgl.distributed.DistDataLoader`."
)
from ..distributed import DistDataLoader as NewDistDataLoader
return NewDistDataLoader(*args, **kwargs)
class DistNodeDataLoader:
"""Deprecated. Please use :class:`~dgl.distributed.DistNodeDataLoader`
instead.

View File

@@ -6,6 +6,8 @@ from .dist_dataloader import (
DistDataLoader,
DistEdgeDataLoader,
DistNodeDataLoader,
EdgeCollator,
NodeCollator,
)
from .dist_graph import DistGraph, DistGraphServer, edge_split, node_split
from .dist_tensor import DistTensor

View File

@@ -9,7 +9,13 @@ from ..base import EID, NID
from ..convert import heterograph
from .dist_context import get_sampler_pool
__all__ = ["DistDataLoader", "DistNodeDataLoader", "DistEdgeDataLoader"]
__all__ = [
"NodeCollator",
"EdgeCollator",
"DistDataLoader",
"DistNodeDataLoader",
"DistEdgeDataLoader",
]
DATALOADER_ID = 0

View File

@@ -821,6 +821,31 @@ def test_dataloader_worker_init_fn():
pass
def test_distributed_dataloaders():
# Test distributed dataloaders could be successfully imported.
try:
from dgl.dataloading import (
DistDataLoader,
DistEdgeDataLoader,
DistNodeDataLoader,
EdgeCollator,
NodeCollator,
)
except ImportError:
pytest.fail("Distributed DataLoader from dataloading import failed")
try:
from dgl.distributed import (
DistDataLoader,
DistEdgeDataLoader,
DistNodeDataLoader,
EdgeCollator,
NodeCollator,
)
except ImportError:
pytest.fail("Distributed DataLoader from dataloading import failed")
if __name__ == "__main__":
# test_node_dataloader(F.int32, 'neighbor', None)
test_edge_dataloader_excludes(