From e2e524df41cb84b99d3ff7fc9508e14634fdb044 Mon Sep 17 00:00:00 2001 From: Tong He Date: Thu, 7 Jan 2021 13:38:21 +0800 Subject: [PATCH] [Feature] Add GraphDataLoader implementation (#2496) * add graph dataloader * add to doc * fix * fix * fix docstring * update according to torch default_collate * add unittest * fix * fix lint * fix --- docs/source/api/python/dgl.dataloading.rst | 1 + python/dgl/dataloading/dataloader.py | 84 +++++++++++++++++++++- python/dgl/dataloading/pytorch/__init__.py | 52 +++++++++++++- tests/pytorch/test_dataloader.py | 9 +++ 4 files changed, 144 insertions(+), 2 deletions(-) diff --git a/docs/source/api/python/dgl.dataloading.rst b/docs/source/api/python/dgl.dataloading.rst index 4e4bf03d50..292f8f8791 100644 --- a/docs/source/api/python/dgl.dataloading.rst +++ b/docs/source/api/python/dgl.dataloading.rst @@ -16,6 +16,7 @@ and an ``EdgeDataLoader`` for edge/link prediction task. .. autoclass:: NodeDataLoader .. autoclass:: EdgeDataLoader +.. autoclass:: GraphDataLoader .. _api-dataloading-neighbor-sampling: Neighbor Sampler diff --git a/python/dgl/dataloading/dataloader.py b/python/dgl/dataloading/dataloader.py index 5482bfe9b2..7074c82940 100644 --- a/python/dgl/dataloading/dataloader.py +++ b/python/dgl/dataloading/dataloader.py @@ -1,13 +1,16 @@ """Data loaders""" -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from abc import ABC, abstractproperty, abstractmethod +import re import numpy as np from .. import transform from ..base import NID, EID from .. import backend as F from .. import utils +from ..batch import batch from ..convert import heterograph +from ..heterograph import DGLHeteroGraph as DGLGraph from ..distributed.dist_graph import DistGraph # pylint: disable=unused-argument @@ -678,3 +681,82 @@ class EdgeCollator(Collator): return self._collate(items) else: return self._collate_with_negative_sampling(items) + +class GraphCollator(object): + """Given a set of graphs as well as their graph-level data, the collate function will batch the + graphs into a batched graph, and stack the tensors into a single bigger tensor. If the + example is a container (such as sequences or mapping), the collate function preserves + the structure and collates each of the elements recursively. + + If the set of graphs has no graph-level data, the collate function will yield a batched graph. + + Examples + -------- + To train a GNN for graph classification on a set of graphs in ``dataset`` (assume + the backend is PyTorch): + + >>> dataloader = dgl.dataloading.GraphDataLoader( + ... dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4) + >>> for batched_graph, labels in dataloader: + ... train_on(batched_graph, labels) + """ + def __init__(self): + self.graph_collate_err_msg_format = ( + "graph_collate: batch must contain DGLGraph, tensors, numpy arrays, " + "numbers, dicts or lists; found {}") + self.np_str_obj_array_pattern = re.compile(r'[SaUO]') + + #This implementation is based on torch.utils.data._utils.collate.default_collate + def collate(self, items): + """This function is similar to ``torch.utils.data._utils.collate.default_collate``. + It combines the sampled graphs and corresponding graph-level data + into a batched graph and tensors. + + Parameters + ---------- + items : list of data points or tuples + Elements in the list are expected to have the same length. + Each sub-element will be batched as a batched graph, or a + batched tensor correspondingly. + + Returns + ------- + A tuple of the batching results. + """ + elem = items[0] + elem_type = type(elem) + if isinstance(elem, DGLGraph): + batched_graphs = batch(items) + return batched_graphs + elif F.is_tensor(elem): + return F.stack(items, 0) + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': + # array of string classes and object + if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError(self.graph_collate_err_msg_format.format(elem.dtype)) + + return self.collate([F.tensor(b) for b in items]) + elif elem.shape == (): # scalars + return F.tensor(items) + elif isinstance(elem, float): + return F.tensor(items, dtype=F.float64) + elif isinstance(elem, int): + return F.tensor(items) + elif isinstance(elem, (str, bytes)): + return items + elif isinstance(elem, Mapping): + return {key: self.collate([d[key] for d in items]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple + return elem_type(*(self.collate(samples) for samples in zip(*items))) + elif isinstance(elem, Sequence): + # check to make sure that the elements in batch have consistent size + item_iter = iter(items) + elem_size = len(next(item_iter)) + if not all(len(elem) == elem_size for elem in item_iter): + raise RuntimeError('each element in list of batch should be of equal size') + transposed = zip(*items) + return [self.collate(samples) for samples in transposed] + + raise TypeError(self.graph_collate_err_msg_format.format(elem_type)) diff --git a/python/dgl/dataloading/pytorch/__init__.py b/python/dgl/dataloading/pytorch/__init__.py index 0f04a115ff..3262d94223 100644 --- a/python/dgl/dataloading/pytorch/__init__.py +++ b/python/dgl/dataloading/pytorch/__init__.py @@ -1,7 +1,7 @@ """DGL PyTorch DataLoaders""" import inspect from torch.utils.data import DataLoader -from ..dataloader import NodeCollator, EdgeCollator +from ..dataloader import NodeCollator, EdgeCollator, GraphCollator from ...distributed import DistGraph from ...distributed import DistDataLoader @@ -414,3 +414,53 @@ class EdgeDataLoader: def __len__(self): """Return the number of batches of the data loader.""" return len(self.dataloader) + +class GraphDataLoader: + """PyTorch dataloader for batch-iterating over a set of graphs, generating the batched + graph and corresponding label tensor (if provided) of the said minibatch. + + Parameters + ---------- + collate : Function, default is None + The customized collate function. Will use the default collate + function if not given. + kwargs : dict + Arguments being passed to :py:class:`torch.utils.data.DataLoader`. + + Examples + -------- + To train a GNN for graph classification on a set of graphs in ``dataset`` (assume + the backend is PyTorch): + + >>> dataloader = dgl.dataloading.GraphDataLoader( + ... dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4) + >>> for batched_graph, labels in dataloader: + ... train_on(batched_graph, labels) + """ + collator_arglist = inspect.getfullargspec(GraphCollator).args + + def __init__(self, dataset, collate=None, **kwargs): + collator_kwargs = {} + dataloader_kwargs = {} + for k, v in kwargs.items(): + if k in self.collator_arglist: + collator_kwargs[k] = v + else: + dataloader_kwargs[k] = v + + if collate is None: + self.collate = GraphCollator(**collator_kwargs).collate + else: + self.collate = collate + + self.dataloader = DataLoader(dataset=dataset, + collate_fn=self.collate, + **dataloader_kwargs) + + def __iter__(self): + """Return the iterator of the data loader.""" + return iter(self.dataloader) + + def __len__(self): + """Return the number of batches of the data loader.""" + return len(self.dataloader) diff --git a/tests/pytorch/test_dataloader.py b/tests/pytorch/test_dataloader.py index e88626c7f5..73c7592e74 100644 --- a/tests/pytorch/test_dataloader.py +++ b/tests/pytorch/test_dataloader.py @@ -181,6 +181,15 @@ def test_neighbor_sampler_dataloader(): collator.dataset, collate_fn=collator.collate, batch_size=2, shuffle=True, drop_last=False) _check_neighbor_sampling_dataloader(_g, nid, dl, mode) +def test_graph_dataloader(): + batch_size = 16 + num_batches = 2 + minigc_dataset = dgl.data.MiniGCDataset(batch_size * num_batches, 10, 20) + data_loader = dgl.dataloading.GraphDataLoader(minigc_dataset, batch_size=batch_size, shuffle=True) + for graph, label in data_loader: + assert isinstance(graph, dgl.DGLGraph) + assert F.asnumpy(label).shape[0] == batch_size if __name__ == '__main__': test_neighbor_sampler_dataloader() + test_graph_dataloader()