[Dataset] Add transform argument to built-in datasets (#3733)

* Update

* Fix

* Update
This commit is contained in:
Mufei Li
2022-02-15 16:45:47 +08:00
committed by GitHub
parent b3d3a2c4b0
commit 8b8fd2c0be
22 changed files with 622 additions and 259 deletions

View File

@@ -39,6 +39,10 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
verbose: bool
Whether to print out progress information.
Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -67,12 +71,13 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
_url = 'https://snap.stanford.edu/data/soc-sign-bitcoinotc.csv.gz'
_sha1_str = 'c14281f9e252de0bd0b5f1c6e2bae03123938641'
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(BitcoinOTCDataset, self).__init__(name='bitcoinotc',
url=self._url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def download(self):
gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')
@@ -143,7 +148,10 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
- ``edata['h']`` : edge weights
"""
return self.graphs[item]
if self._transform is None:
return self.graphs[item]
else:
return self._transform(self.graphs[item])
@property
def is_temporal(self):

View File

@@ -43,10 +43,14 @@ class CitationGraphDataset(DGLBuiltinDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge: bool
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
"""
_urls = {
'cora_v2' : 'dataset/cora_v2.zip',
@@ -54,7 +58,8 @@ class CitationGraphDataset(DGLBuiltinDataset):
'pubmed' : 'dataset/pubmed.zip',
}
def __init__(self, name, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def __init__(self, name, raw_dir=None, force_reload=False,
verbose=True, reverse_edge=True, transform=None):
assert name.lower() in ['cora', 'citeseer', 'pubmed']
# Previously we use the pre-processing in pygcn (https://github.com/tkipf/pygcn)
@@ -69,7 +74,8 @@ class CitationGraphDataset(DGLBuiltinDataset):
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
"""Loads input data from data directory and reorder graph for better locality
@@ -213,7 +219,10 @@ class CitationGraphDataset(DGLBuiltinDataset):
def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph"
return self._g
if self._transform is None:
return self._g
else:
return self._transform(self._g)
def __len__(self):
return 1
@@ -267,7 +276,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
@property
def reverse_edge(self):
return self._reverse_edge
def _preprocess_features(features):
"""Row-normalize feature matrix and convert to tuple representation"""
@@ -356,10 +365,14 @@ class CoraGraphDataset(CitationGraphDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge: bool
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -400,10 +413,12 @@ class CoraGraphDataset(CitationGraphDataset):
>>> label = g.ndata['label']
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def __init__(self, raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
name = 'cora'
super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload,
verbose, reverse_edge, transform)
def __getitem__(self, idx):
r"""Gets the graph object
@@ -496,10 +511,14 @@ class CiteseerGraphDataset(CitationGraphDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge: bool
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -543,10 +562,12 @@ class CiteseerGraphDataset(CitationGraphDataset):
>>> label = g.ndata['label']
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def __init__(self, raw_dir=None, force_reload=False,
verbose=True, reverse_edge=True, transform=None):
name = 'citeseer'
super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload,
verbose, reverse_edge, transform)
def __getitem__(self, idx):
r"""Gets the graph object
@@ -639,10 +660,14 @@ class PubmedGraphDataset(CitationGraphDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge: bool
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -683,10 +708,12 @@ class PubmedGraphDataset(CitationGraphDataset):
>>> label = g.ndata['label']
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def __init__(self, raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
name = 'pubmed'
super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload,
verbose, reverse_edge, transform)
def __getitem__(self, idx):
r"""Gets the graph object
@@ -714,7 +741,7 @@ class PubmedGraphDataset(CitationGraphDataset):
r"""The number of graphs in the dataset."""
return super(PubmedGraphDataset, self).__len__()
def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True, transform=None):
"""Get CoraGraphDataset
Parameters
@@ -724,19 +751,24 @@ def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True)
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
reverse_edge: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Returns
-------
CoraGraphDataset
"""
data = CoraGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
data = CoraGraphDataset(raw_dir, force_reload, verbose, reverse_edge, transform)
return data
def load_citeseer(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def load_citeseer(raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
"""Get CiteseerGraphDataset
Parameters
@@ -746,38 +778,47 @@ def load_citeseer(raw_dir=None, force_reload=False, verbose=True, reverse_edge=T
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
reverse_edge: bool
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Returns
-------
CiteseerGraphDataset
"""
data = CiteseerGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
data = CiteseerGraphDataset(raw_dir, force_reload, verbose, reverse_edge, transform)
return data
def load_pubmed(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
def load_pubmed(raw_dir=None, force_reload=False, verbose=True,
reverse_edge=True, transform=None):
"""Get PubmedGraphDataset
Parameters
-----------
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information. Default: True.
reverse_edge: bool
reverse_edge : bool
Whether to add reverse edges in graph. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Returns
-------
PubmedGraphDataset
"""
data = PubmedGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
data = PubmedGraphDataset(raw_dir, force_reload, verbose, reverse_edge, transform)
return data
class CoraBinary(DGLBuiltinDataset):
@@ -798,15 +839,20 @@ class CoraBinary(DGLBuiltinDataset):
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=True):
def __init__(self, raw_dir=None, force_reload=False, verbose=True, transform=None):
name = 'cora_binary'
url = _get_dgl_url('dataset/cora_binary.zip')
super(CoraBinary, self).__init__(name,
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
root = self.raw_path
@@ -894,7 +940,11 @@ class CoraBinary(DGLBuiltinDataset):
(dgl.DGLGraph, scipy.sparse.coo_matrix, int)
The graph, scipy sparse coo_matrix and its label.
"""
return (self.graphs[i], self.pmpds[i], self.labels[i])
if self._transform is None:
g = self.graphs[i]
else:
g = self._transform(self.graphs[i])
return (g, self.pmpds[i], self.labels[i])
@property
def save_name(self):

View File

@@ -33,6 +33,10 @@ class DGLCSVDataset(DGLDataset):
A callable object which is used to parse corresponding column graph
data. Default: None. If None, a default data parser is applied
which load data directly and tries to convert list into array.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -46,7 +50,8 @@ class DGLCSVDataset(DGLDataset):
"""
META_YAML_NAME = 'meta.yaml'
def __init__(self, data_path, force_reload=False, verbose=True, node_data_parser=None, edge_data_parser=None, graph_data_parser=None):
def __init__(self, data_path, force_reload=False, verbose=True, node_data_parser=None,
edge_data_parser=None, graph_data_parser=None, transform=None):
from .csv_dataset_base import load_yaml_with_sanity_check, DefaultDataParser
self.graphs = None
self.data = None
@@ -61,7 +66,7 @@ class DGLCSVDataset(DGLDataset):
self.meta_yaml = load_yaml_with_sanity_check(meta_yaml_path)
ds_name = self.meta_yaml.dataset_name
super().__init__(ds_name, raw_dir=os.path.dirname(
meta_yaml_path), force_reload=force_reload, verbose=verbose)
meta_yaml_path), force_reload=force_reload, verbose=verbose, transform=transform)
def process(self):
@@ -122,10 +127,15 @@ class DGLCSVDataset(DGLDataset):
self.graphs, self.data = load_graphs(graph_path)
def __getitem__(self, i):
if 'label' in self.data:
return self.graphs[i], self.data['label'][i]
if self._transform is None:
g = self.graphs[i]
else:
return self.graphs[i]
g = self._transform(self.graphs[i])
if 'label' in self.data:
return g, self.data['label'][i]
else:
return g
def __len__(self):
return len(self.graphs)

View File

@@ -49,6 +49,10 @@ class DGLDataset(object):
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -72,13 +76,14 @@ class DGLDataset(object):
Hash value for the dataset and the setting.
"""
def __init__(self, name, url=None, raw_dir=None, save_dir=None,
hash_key=(), force_reload=False, verbose=False):
hash_key=(), force_reload=False, verbose=False, transform=None):
self._name = name
self._url = url
self._force_reload = force_reload
self._verbose = verbose
self._hash_key = hash_key
self._hash = self._get_hash()
self._transform = transform
# if no dir is provided, the default dgl download dir is used.
if raw_dir is None:
@@ -142,7 +147,7 @@ class DGLDataset(object):
def _download(self):
"""Download dataset by calling ``self.download()``
if the dataset does not exist under ``self.raw_path``.
By default ``self.raw_path = os.path.join(self.raw_dir, self.name)``
One can overwrite ``raw_path()`` function to change the path.
"""
@@ -161,7 +166,7 @@ class DGLDataset(object):
- If the loading process fails, re-download and process the dataset.
else:
- Download the dataset if needed.
- Process the dataset and build the dgl graph.
- Save the processed dataset into files.
@@ -287,17 +292,23 @@ class DGLBuiltinDataset(DGLDataset):
from the same dataset class by comparing the hash values.
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: False
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
"""
def __init__(self, name, url, raw_dir=None, hash_key=(), force_reload=False, verbose=False):
def __init__(self, name, url, raw_dir=None, hash_key=(),
force_reload=False, verbose=False, transform=None):
super(DGLBuiltinDataset, self).__init__(name,
url=url,
raw_dir=raw_dir,
save_dir=None,
hash_key=hash_key,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def download(self):
r""" Automatically download data and extract it.

View File

@@ -76,6 +76,10 @@ class FakeNewsDataset(DGLBuiltinDataset):
downloaded data or the directory that
already stores the input data.
Default: ~/.dgl/
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -113,7 +117,7 @@ class FakeNewsDataset(DGLBuiltinDataset):
'politifact': 'dataset/FakeNewsPOL.zip'
}
def __init__(self, name, feature_name, raw_dir=None):
def __init__(self, name, feature_name, raw_dir=None, transform=None):
assert name in ['gossipcop', 'politifact'], \
"Only supports 'gossipcop' or 'politifact'."
url = _get_dgl_url(self.file_urls[name])
@@ -123,7 +127,8 @@ class FakeNewsDataset(DGLBuiltinDataset):
self.feature_name = feature_name
super(FakeNewsDataset, self).__init__(name=name,
url=url,
raw_dir=raw_dir)
raw_dir=raw_dir,
transform=transform)
def process(self):
"""process raw data to graph, labels and masks"""
@@ -213,7 +218,11 @@ class FakeNewsDataset(DGLBuiltinDataset):
-------
(:class:`dgl.DGLGraph`, Tensor)
"""
return self.graphs[i], self.labels[i]
if self._transform is None:
g = self.graphs[i]
else:
g = self._transform(self.graphs[i])
return g, self.labels[i]
def __len__(self):
r"""Number of graphs in the dataset.

View File

@@ -48,8 +48,12 @@ class FraudDataset(DGLBuiltinDataset):
Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -88,9 +92,9 @@ class FraudDataset(DGLBuiltinDataset):
'yelp': 'review',
'amazon': 'user'
}
def __init__(self, name, raw_dir=None, random_seed=717, train_size=0.7,
val_size=0.1, force_reload=False, verbose=True):
val_size=0.1, force_reload=False, verbose=True, transform=None):
assert name in ['yelp', 'amazon'], "only supports 'yelp', or 'amazon'"
url = _get_dgl_url(self.file_urls[name])
self.seed = random_seed
@@ -101,30 +105,31 @@ class FraudDataset(DGLBuiltinDataset):
raw_dir=raw_dir,
hash_key=(random_seed, train_size, val_size),
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
"""Process raw data into the graph, labels and splitting masks.

Loads the ``.mat`` file named ``self.file_names[self.name]`` under
``self.raw_path``, builds one edge type per entry of
``self.relations[self.name]``, attaches node features and labels,
stores the result in ``self.graph`` and finally calls
``self._random_split``.
"""
file_path = os.path.join(self.raw_path, self.file_names[self.name])
data = io.loadmat(file_path)
# 'features' is converted to a dense matrix before being turned into a tensor
node_features = data['features'].todense()
# remove additional dimension of length 1 in raw .mat file
node_labels = data['label'].squeeze()
# maps (src node type, relation, dst node type) -> (row, col) edge endpoints
graph_data = {}
for relation in self.relations[self.name]:
# convert the relation's adjacency to COO to read off edge endpoints
adj = data[relation].tocoo()
row, col = adj.row, adj.col
# source and destination use the same node type name for this dataset
graph_data[(self.node_name[self.name], relation, self.node_name[self.name])] = (row, col)
g = heterograph(graph_data)
g.ndata['feature'] = F.tensor(node_features, dtype=F.data_type_dict['float32'])
g.ndata['label'] = F.tensor(node_labels, dtype=F.data_type_dict['int64'])
self.graph = g
# the split is driven by the seed and sizes supplied at construction time
self._random_split(g.ndata['feature'], self.seed, self.train_size, self.val_size)
def __getitem__(self, idx):
r""" Get graph object
@@ -145,12 +150,15 @@ class FraudDataset(DGLBuiltinDataset):
- ``ndata['test_mask']``: mask of testing set
"""
assert idx == 0, "This dataset has only one graph"
return self.graph
if self._transform is None:
return self.graph
else:
return self._transform(self.graph)
def __len__(self):
    """Number of graphs in the dataset.

    Returns
    -------
    int
        Always 1: ``__getitem__`` only accepts index 0 because the
        dataset holds a single graph.
    """
    # Returning len(self.graph) would report a graph-level size instead of
    # the number of examples, so iterating the dataset would request
    # indices that __getitem__ rejects.
    return 1
@property
def num_classes(self):
"""Number of classes.
@@ -160,37 +168,37 @@ class FraudDataset(DGLBuiltinDataset):
int
"""
return 2
def save(self):
"""Save the processed graph to directory ``self.save_path``.

The graph in ``self.graph`` is written to
``<name>_dgl_graph_<hash>.bin`` via ``save_graphs``.
"""
# the file name embeds self.hash, the same name that load() and has_cache() build
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
save_graphs(str(graph_path), self.graph)
def load(self):
"""Load the processed graph from directory ``self.save_path``.

Reads ``<name>_dgl_graph_<hash>.bin`` via ``load_graphs`` and stores
the first graph of the returned list in ``self.graph``.
"""
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
# the second return value of load_graphs is not used here
graph_list, _ = load_graphs(str(graph_path))
# only the first graph in the list is kept
g = graph_list[0]
self.graph = g
def has_cache(self):
"""Check whether processed data exists in ``self.save_path``.

Returns
-------
bool
    True if ``<name>_dgl_graph_<hash>.bin`` exists under
    ``self.save_path``.
"""
# same path that save() writes and load() reads
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
return os.path.exists(graph_path)
def _random_split(self, x, seed=717, train_size=0.7, val_size=0.1):
"""split the dataset into training set, validation set and testing set"""
assert 0 <= train_size + val_size <= 1, \
"The sum of valid training set size and validation set size " \
"must between 0 and 1 (inclusive)."
N = x.shape[0]
index = np.arange(N)
if self.name == 'amazon':
# 0-3304 are unlabeled nodes
index = np.arange(3305, N)
index = np.random.RandomState(seed).permutation(index)
train_idx = index[:int(train_size * len(index))]
val_idx = index[len(index) - int(val_size * len(index)):]
@@ -254,8 +262,12 @@ class FraudYelpDataset(FraudDataset):
Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Examples
--------
@@ -265,16 +277,17 @@ class FraudYelpDataset(FraudDataset):
>>> feat = graph.ndata['feature']
>>> label = graph.ndata['label']
"""
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7,
val_size=0.1, force_reload=False, verbose=True):
val_size=0.1, force_reload=False, verbose=True, transform=None):
super(FraudYelpDataset, self).__init__(name='yelp',
raw_dir=raw_dir,
random_seed=random_seed,
train_size=train_size,
val_size=val_size,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
class FraudAmazonDataset(FraudDataset):
@@ -330,8 +343,12 @@ class FraudAmazonDataset(FraudDataset):
Default: 0.1
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Examples
--------
@@ -341,13 +358,14 @@ class FraudAmazonDataset(FraudDataset):
>>> feat = graph.ndata['feature']
>>> label = graph.ndata['label']
"""
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7,
val_size=0.1, force_reload=False, verbose=True):
val_size=0.1, force_reload=False, verbose=True, transform=None):
super(FraudAmazonDataset, self).__init__(name='amazon',
raw_dir=raw_dir,
random_seed=random_seed,
train_size=train_size,
val_size=val_size,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)

View File

@@ -37,8 +37,12 @@ class GDELTDataset(DGLBuiltinDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -65,7 +69,8 @@ class GDELTDataset(DGLBuiltinDataset):
....
>>>
"""
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False):
def __init__(self, mode='train', raw_dir=None,
force_reload=False, verbose=False, transform=None):
mode = mode.lower()
assert mode in ['train', 'valid', 'test'], "Mode not valid."
self.mode = mode
@@ -75,7 +80,8 @@ class GDELTDataset(DGLBuiltinDataset):
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
file_path = os.path.join(self.raw_path, self.mode + '.txt')
@@ -148,6 +154,8 @@ class GDELTDataset(DGLBuiltinDataset):
rate = self.data[row_mask][:, 1]
g = dgl_graph((edges[:, 0], edges[:, 1]))
g.edata['rel_type'] = F.tensor(rate.reshape(-1, 1), dtype=F.data_type_dict['int64'])
if self._transform is not None:
g = self._transform(g)
return g
def __len__(self):

View File

@@ -18,9 +18,9 @@ from ..convert import graph as dgl_graph
class GINDataset(DGLBuiltinDataset):
"""Dataset Class for `How Powerful Are Graph Neural Networks? <https://arxiv.org/abs/1810.00826>`_.
This is adapted from `<https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip>`_.
The class provides an interface for nine datasets used in the paper along with the paper-specific
settings. The datasets are ``'MUTAG'``, ``'COLLAB'``, ``'IMDBBINARY'``, ``'IMDBMULTI'``,
``'NCI1'``, ``'PROTEINS'``, ``'PTC'``, ``'REDDITBINARY'``, ``'REDDITMULTI5K'``.
@@ -44,6 +44,10 @@ class GINDataset(DGLBuiltinDataset):
add self to self edge if true
degree_as_nlabel: bool
take node degree as label and feature if true
transform: callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Examples
--------
@@ -73,7 +77,7 @@ class GINDataset(DGLBuiltinDataset):
"""
def __init__(self, name, self_loop, degree_as_nlabel=False,
raw_dir=None, force_reload=False, verbose=False):
raw_dir=None, force_reload=False, verbose=False, transform=None):
self._name = name # MUTAG
gin_url = 'https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip'
@@ -106,7 +110,8 @@ class GINDataset(DGLBuiltinDataset):
self.nlabels_flag = False
super(GINDataset, self).__init__(name=name, url=gin_url, hash_key=(name, self_loop, degree_as_nlabel),
raw_dir=raw_dir, force_reload=force_reload, verbose=verbose)
raw_dir=raw_dir, force_reload=force_reload,
verbose=verbose, transform=transform)
@property
def raw_path(self):
@@ -136,7 +141,11 @@ class GINDataset(DGLBuiltinDataset):
(:class:`dgl.DGLGraph`, Tensor)
The graph and its label.
"""
return self.graphs[idx], self.labels[idx]
if self._transform is None:
g = self.graphs[idx]
else:
g = self._transform(self.graphs[idx])
return g, self.labels[idx]
def _file_path(self):
return os.path.join(self.raw_dir, "GINDataset", 'dataset', self.name, "{}.txt".format(self.name))

View File

@@ -27,13 +27,14 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
Reference: https://github.com/shchur/gnn-benchmark#datasets
"""
def __init__(self, name, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, name, raw_dir=None, force_reload=False, verbose=False, transform=None):
_url = _get_dgl_url('dataset/' + name + '.zip')
super(GNNBenchmarkDataset, self).__init__(name=name,
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
npz_path = os.path.join(self.raw_path, self.name + '.npz')
@@ -128,7 +129,10 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
- ``ndata['label']``: node labels
"""
assert idx == 0, "This dataset has only one graph"
return self._graph
if self._transform is None:
return self._graph
else:
return self._transform(self._graph)
def __len__(self):
r"""Number of graphs in the dataset"""
@@ -164,8 +168,12 @@ class CoraFullDataset(GNNBenchmarkDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -182,11 +190,12 @@ class CoraFullDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(CoraFullDataset, self).__init__(name="cora_full",
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
@property
def num_classes(self):
@@ -231,8 +240,12 @@ class CoauthorCSDataset(GNNBenchmarkDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -249,11 +262,12 @@ class CoauthorCSDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(CoauthorCSDataset, self).__init__(name='coauthor_cs',
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
@property
def num_classes(self):
@@ -298,8 +312,12 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -316,11 +334,12 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(CoauthorPhysicsDataset, self).__init__(name='coauthor_physics',
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
@property
def num_classes(self):
@@ -364,8 +383,12 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -382,11 +405,12 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(AmazonCoBuyComputerDataset, self).__init__(name='amazon_co_buy_computer',
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
@property
def num_classes(self):
@@ -430,8 +454,12 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -448,11 +476,12 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
>>> feat = g.ndata['feat'] # get node feature
>>> label = g.ndata['label'] # get node labels
"""
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(AmazonCoBuyPhotoDataset, self).__init__(name='amazon_co_buy_photo',
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
@property
def num_classes(self):

View File

@@ -39,8 +39,12 @@ class ICEWS18Dataset(DGLBuiltinDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -61,7 +65,7 @@ class ICEWS18Dataset(DGLBuiltinDataset):
....
>>>
"""
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False):
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None):
mode = mode.lower()
assert mode in ['train', 'valid', 'test'], "Mode not valid"
self.mode = mode
@@ -70,7 +74,8 @@ class ICEWS18Dataset(DGLBuiltinDataset):
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
data = loadtxt(os.path.join(self.save_path, '{}.txt'.format(self.mode)),
@@ -118,7 +123,10 @@ class ICEWS18Dataset(DGLBuiltinDataset):
- ``edata['rel_type']``: edge type
"""
return self._graphs[idx]
if self._transform is None:
return self._graphs[idx]
else:
return self._transform(self._graphs[idx])
def __len__(self):
r"""Number of graphs in the dataset.

View File

@@ -34,6 +34,13 @@ class KarateClubDataset(DGLDataset):
- Edges: 156
- Number of Classes: 2
Parameters
----------
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
num_classes : int
@@ -48,8 +55,8 @@ class KarateClubDataset(DGLDataset):
>>> g = dataset[0]
>>> labels = g.ndata['label']
"""
def __init__(self):
super(KarateClubDataset, self).__init__(name='karate_club')
def __init__(self, transform=None):
super(KarateClubDataset, self).__init__(name='karate_club', transform=transform)
def process(self):
kc_graph = nx.karate_club_graph()
@@ -88,7 +95,10 @@ class KarateClubDataset(DGLDataset):
- ``ndata['label']``: ground truth labels
"""
assert idx == 0, "This dataset has only one graph"
return self._graph
if self._transform is None:
return self._graph
else:
return self._transform(self._graph)
def __len__(self):
r"""The number of graphs in the dataset."""

View File

@@ -25,19 +25,24 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
Parameters
-----------
name: str
name : str
Name can be 'FB15k-237', 'FB15k' or 'wn18'.
reverse: bool
reverse : bool
Whether to add reverse edges. Default: True.
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
"""
def __init__(self, name, reverse=True, raw_dir=None, force_reload=False, verbose=True):
def __init__(self, name, reverse=True, raw_dir=None, force_reload=False,
verbose=True, transform=None):
self._name = name
self.reverse = reverse
url = _get_dgl_url('dataset/') + '{}.tgz'.format(name)
@@ -45,7 +50,8 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
url=url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def download(self):
r""" Automatically download data and extract it.
@@ -112,7 +118,10 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
def __getitem__(self, idx):
assert idx == 0, "This dataset has only one graph"
return self._g
if self._transform is None:
return self._g
else:
return self._transform(self._g)
def __len__(self):
return 1
@@ -389,8 +398,12 @@ class FB15k237Dataset(KnowledgeGraphDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -433,9 +446,11 @@ class FB15k237Dataset(KnowledgeGraphDataset):
>>>
>>> # Train, Validation and Test
"""
def __init__(self, reverse=True, raw_dir=None, force_reload=False, verbose=True):
def __init__(self, reverse=True, raw_dir=None, force_reload=False,
verbose=True, transform=None):
name = 'FB15k-237'
super(FB15k237Dataset, self).__init__(name, reverse, raw_dir, force_reload, verbose)
super(FB15k237Dataset, self).__init__(name, reverse, raw_dir,
force_reload, verbose, transform)
def __getitem__(self, idx):
r"""Gets the graph object
@@ -526,8 +541,12 @@ class FB15kDataset(KnowledgeGraphDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -570,9 +589,11 @@ class FB15kDataset(KnowledgeGraphDataset):
>>> # Train, Validation and Test
>>>
"""
def __init__(self, reverse=True, raw_dir=None, force_reload=False, verbose=True):
def __init__(self, reverse=True, raw_dir=None, force_reload=False,
verbose=True, transform=None):
name = 'FB15k'
super(FB15kDataset, self).__init__(name, reverse, raw_dir, force_reload, verbose)
super(FB15kDataset, self).__init__(name, reverse, raw_dir,
force_reload, verbose, transform)
def __getitem__(self, idx):
r"""Gets the graph object
@@ -662,8 +683,12 @@ class WN18Dataset(KnowledgeGraphDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -706,9 +731,11 @@ class WN18Dataset(KnowledgeGraphDataset):
>>> # Train, Validation and Test
>>>
"""
def __init__(self, reverse=True, raw_dir=None, force_reload=False, verbose=True):
def __init__(self, reverse=True, raw_dir=None, force_reload=False,
verbose=True, transform=None):
name = 'wn18'
super(WN18Dataset, self).__init__(name, reverse, raw_dir, force_reload, verbose)
super(WN18Dataset, self).__init__(name, reverse, raw_dir,
force_reload, verbose, transform)
def __getitem__(self, idx):
r"""Gets the graph object

View File

@@ -33,8 +33,12 @@ class MiniGCDataset(DGLDataset):
Minimum number of nodes for graphs
max_num_v: int
Maximum number of nodes for graphs
seed : int, default is 0
seed: int, default is 0
Random seed for data generation
transform: callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -75,7 +79,7 @@ class MiniGCDataset(DGLDataset):
"""
def __init__(self, num_graphs, min_num_v, max_num_v, seed=0,
save_graph=True, force_reload=False, verbose=False):
save_graph=True, force_reload=False, verbose=False, transform=None):
self.num_graphs = num_graphs
self.min_num_v = min_num_v
self.max_num_v = max_num_v
@@ -84,7 +88,7 @@ class MiniGCDataset(DGLDataset):
super(MiniGCDataset, self).__init__(name="minigc", hash_key=(num_graphs, min_num_v, max_num_v, seed),
force_reload=force_reload,
verbose=verbose)
verbose=verbose, transform=transform)
def process(self):
self.graphs = []
@@ -108,7 +112,11 @@ class MiniGCDataset(DGLDataset):
(:class:`dgl.DGLGraph`, Tensor)
The graph and its label.
"""
return self.graphs[idx], self.labels[idx]
if self._transform is None:
g = self.graphs[idx]
else:
g = self._transform(self.graphs[idx])
return g, self.labels[idx]
def has_cache(self):
graph_path = os.path.join(self.save_path, 'dgl_graph_{}.bin'.format(self.hash))

View File

@@ -56,9 +56,13 @@ class PPIDataset(DGLBuiltinDataset):
force_reload : bool
Whether to reload the dataset.
Default: False
verbose: bool
verbose : bool
Whether to print out progress information.
Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -79,7 +83,8 @@ class PPIDataset(DGLBuiltinDataset):
.... # your code here
>>>
"""
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False):
def __init__(self, mode='train', raw_dir=None, force_reload=False,
verbose=False, transform=None):
assert mode in ['train', 'valid', 'test']
self.mode = mode
_url = _get_dgl_url('dataset/ppi.zip')
@@ -87,7 +92,8 @@ class PPIDataset(DGLBuiltinDataset):
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
graph_file = os.path.join(self.save_path, '{}_graph.json'.format(self.mode))
@@ -178,7 +184,10 @@ class PPIDataset(DGLBuiltinDataset):
- ``ndata['feat']``: node features
- ``ndata['label']``: node labels
"""
return self.graphs[item]
if self._transform is None:
return self.graphs[item]
else:
return self._transform(self.graphs[item])
class LegacyPPIDataset(PPIDataset):
@@ -198,5 +207,8 @@ class LegacyPPIDataset(PPIDataset):
(dgl.DGLGraph, Tensor, Tensor)
The graph, features and its label.
"""
return self.graphs[item], self.graphs[item].ndata['feat'], self.graphs[item].ndata['label']
if self._transform is None:
g = self.graphs[item]
else:
g = self._transform(self.graphs[item])
return g, g.ndata['feat'], g.ndata['label']

View File

@@ -34,8 +34,12 @@ class QM7bDataset(DGLDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -65,12 +69,13 @@ class QM7bDataset(DGLDataset):
'datasets/qm7b.mat'
_sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
super(QM7bDataset, self).__init__(name='qm7b',
url=self._url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
mat_path = self.raw_path + '.mat'
@@ -129,7 +134,11 @@ class QM7bDataset(DGLDataset):
-------
(:class:`dgl.DGLGraph`, Tensor)
"""
return self.graphs[idx], self.label[idx]
if self._transform is None:
g = self.graphs[idx]
else:
g = self._transform(self.graphs[idx])
return g, self.label[idx]
def __len__(self):
r"""Number of graphs in the dataset.

View File

@@ -20,11 +20,11 @@ class QM9Dataset(DGLDataset):
2. It only provides atoms' coordinates and atomic numbers as node features
3. It only provides 12 regression targets.
Reference:
Reference:
- `"Quantum-Machine.org" <http://quantum-machine.org/datasets/>`_,
- `"Directional Message Passing for Molecular Graphs" <https://arxiv.org/abs/2003.03123>`_
Statistics:
- Number of graphs: 130,831
@@ -60,9 +60,9 @@ class QM9Dataset(DGLDataset):
Parameters
----------
label_keys: list
label_keys : list
Names of the regression property, which should be a subset of the keys in the table above.
cutoff: float
cutoff : float
Cutoff distance for interatomic interactions, i.e. two atoms are connected in the corresponding graph if the distance between them is no larger than this.
Default: 5.0 Angstrom
raw_dir : str
@@ -70,8 +70,12 @@ class QM9Dataset(DGLDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -82,7 +86,7 @@ class QM9Dataset(DGLDataset):
------
UserWarning
If the raw data is changed in the remote server by the author.
Examples
--------
>>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0)
@@ -102,8 +106,9 @@ class QM9Dataset(DGLDataset):
cutoff=5.0,
raw_dir=None,
force_reload=False,
verbose=False):
verbose=False,
transform=None):
self.cutoff = cutoff
self.label_keys = label_keys
self._url = _get_dgl_url('dataset/qm9_eV.npz')
@@ -112,7 +117,8 @@ class QM9Dataset(DGLDataset):
url=self._url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
npz_path = f'{self.raw_dir}/qm9_eV.npz'
@@ -148,7 +154,7 @@ class QM9Dataset(DGLDataset):
----------
idx : int
Item index
Returns
-------
dgl.DGLGraph
@@ -170,8 +176,12 @@ class QM9Dataset(DGLDataset):
g = dgl_graph((u, v))
g = to_bidirected(g)
g.ndata['R'] = F.tensor(R, dtype=F.data_type_dict['float32'])
g.ndata['Z'] = F.tensor(self.Z[self.N_cumsum[idx]:self.N_cumsum[idx + 1]],
g.ndata['Z'] = F.tensor(self.Z[self.N_cumsum[idx]:self.N_cumsum[idx + 1]],
dtype=F.data_type_dict['int64'])
if self._transform is not None:
g = self._transform(g)
return g, label
def __len__(self):

View File

@@ -14,34 +14,34 @@ class QM9EdgeDataset(DGLDataset):
This dataset consists of 130,831 molecules with 19 regression targets.
Nodes correspond to atoms and edges correspond to bonds.
This dataset differs from :class:`~dgl.data.QM9Dataset` in the following aspects:
1. It includes the bonds in a molecule in the edges of the corresponding graph while the edges in :class:`~dgl.data.QM9Dataset` are purely distance-based.
2. It provides edge features, and node features in addition to the atoms' coordinates and atomic numbers.
3. It provides another 7 regression tasks (from 12 to 19).
This class is built based on a preprocessed version of the dataset, and we provide the preprocessing details `here <https://gist.github.com/hengruizhang98/a2da30213b2356fff18b25385c9d3cd2>`_.
Reference:
- `"MoleculeNet: A Benchmark for Molecular Machine Learning" <https://arxiv.org/abs/1703.00564>`_
- `"Neural Message Passing for Quantum Chemistry" <https://arxiv.org/abs/1704.01212>`_
For
For
Statistics:
- Number of graphs: 130,831.
- Number of regression targets: 19.
Node attributes:
- pos: the 3D coordinates of each atom.
- attr: the 11D atom features.
- pos: the 3D coordinates of each atom.
- attr: the 11D atom features.
Edge attributes:
- edge_attr: the 4D bond features.
Regression targets:
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
@@ -85,10 +85,10 @@ class QM9EdgeDataset(DGLDataset):
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| C | :math:`C` | Rotational constant | :math:`\textrm{GHz}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
Parameters
----------
label_keys: list
label_keys : list
Names of the regression property, which should be a subset of the keys in the table above.
If not provided, it will load all the labels.
raw_dir : str
@@ -96,8 +96,12 @@ class QM9EdgeDataset(DGLDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False.
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -108,13 +112,13 @@ class QM9EdgeDataset(DGLDataset):
------
UserWarning
If the raw data is changed in the remote server by the author.
Examples
--------
>>> data = QM9EdgeDataset(label_keys=['mu', 'alpha'])
>>> data.num_labels
2
>>> # iterate over the dataset
>>> for graph, labels in data:
... print(graph) # get information of each graph
@@ -122,47 +126,49 @@ class QM9EdgeDataset(DGLDataset):
... # your code here...
>>>
"""
keys = ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 'U0_atom',
'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C']
map_dict = {}
for i, key in enumerate(keys):
map_dict[key] = i
def __init__(self,
def __init__(self,
label_keys=None,
raw_dir=None,
force_reload=False,
verbose=True):
raw_dir=None,
force_reload=False,
verbose=True,
transform=None):
if label_keys is None:
self.label_keys = None
self.num_labels = 19
else:
self.label_keys = [self.map_dict[i] for i in label_keys]
self.num_labels = len(label_keys)
self._url = _get_dgl_url('dataset/qm9_edge.npz')
super(QM9EdgeDataset, self).__init__(name='qm9Edge',
raw_dir=raw_dir,
url=self._url,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def download(self):
file_path = f'{self.raw_dir}/qm9_edge.npz'
if not os.path.exists(file_path):
download(self._url, path=file_path)
def process(self):
self.load()
def has_cache(self):
npz_path = f'{self.raw_dir}/qm9_edge.npz'
return os.path.exists(npz_path)
def save(self):
np.savez_compressed(f'{self.raw_dir}/qm9_edge.npz',
n_node=self.n_node,
@@ -171,7 +177,7 @@ class QM9EdgeDataset(DGLDataset):
node_pos=self.node_pos,
edge_attr=self.edge_attr,
src=self.src,
dst=self.dst,
dst=self.dst,
targets=self.targets)
def load(self):
@@ -184,52 +190,55 @@ class QM9EdgeDataset(DGLDataset):
self.node_pos = data_dict['node_pos']
self.edge_attr = data_dict['edge_attr']
self.targets = data_dict['targets']
self.src = data_dict['src']
self.dst = data_dict['dst']
self.n_cumsum = np.concatenate([[0], np.cumsum(self.n_node)])
self.ne_cumsum = np.concatenate([[0], np.cumsum(self.n_edge)])
def __getitem__(self, idx):
r""" Get graph and label by index
r""" Get graph and label by index
Parameters
----------
idx : int
Item index
Returns
-------
dgl.DGLGraph
The graph contains:
- ``ndata['pos']``: the coordinates of each atom
- ``ndata['attr']``: the features of each atom
- ``edata['edge_attr']``: the features of each bond
Tensor
Property values of molecular graphs
"""
pos = self.node_pos[self.n_cumsum[idx]:self.n_cumsum[idx+1]]
src = self.src[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]]
dst = self.dst[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]]
g = dgl_graph((src, dst))
g.ndata['pos'] = F.tensor(pos, dtype=F.data_type_dict['float32'])
g.ndata['attr'] = F.tensor(self.node_attr[self.n_cumsum[idx]:self.n_cumsum[idx+1]], dtype=F.data_type_dict['float32'])
g.edata['edge_attr'] = F.tensor(self.edge_attr[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]], dtype=F.data_type_dict['float32'])
label = F.tensor(self.targets[idx][self.label_keys], dtype=F.data_type_dict['float32'])
if self._transform is not None:
g = self._transform(g)
return g, label
def __len__(self):
r""" Number of graphs in the dataset.
Returns
-------
int

View File

@@ -94,15 +94,20 @@ class RDFGraphDataset(DGLBuiltinDataset):
Default: ~/.dgl/
force_reload : bool, optional
If true, force load and process from raw data. Ignore cached pre-processed data.
verbose: bool
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
"""
def __init__(self, name, url, predict_category,
print_every=10000,
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True):
verbose=True,
transform=None):
self._insert_reverse = insert_reverse
self._print_every = print_every
self._predict_category = predict_category
@@ -110,7 +115,8 @@ class RDFGraphDataset(DGLBuiltinDataset):
super(RDFGraphDataset, self).__init__(name, url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
raw_tuples = self.load_raw_tuples(self.raw_path)
@@ -409,6 +415,8 @@ class RDFGraphDataset(DGLBuiltinDataset):
r"""Gets the graph object
"""
g = self._hg
if self._transform is not None:
g = self._transform(g)
return g
def __len__(self):
@@ -523,17 +531,21 @@ class AIFBDataset(RDFGraphDataset):
Parameters
-----------
print_every: int
print_every : int
Preprocessing log for every X tuples. Default: 10000.
insert_reverse: bool
insert_reverse : bool
If true, add reverse edge and reverse relations to the final graph. Default: True.
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -562,7 +574,8 @@ class AIFBDataset(RDFGraphDataset):
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True):
verbose=True,
transform=None):
import rdflib as rdf
self.employs = rdf.term.URIRef("http://swrc.ontoware.org/ontology#employs")
self.affiliation = rdf.term.URIRef("http://swrc.ontoware.org/ontology#affiliation")
@@ -574,7 +587,8 @@ class AIFBDataset(RDFGraphDataset):
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def __getitem__(self, idx):
r"""Gets the graph object
@@ -653,17 +667,21 @@ class MUTAGDataset(RDFGraphDataset):
Parameters
-----------
print_every: int
print_every : int
Preprocessing log for every X tuples. Default: 10000.
insert_reverse: bool
insert_reverse : bool
If true, add reverse edge and reverse relations to the final graph. Default: True.
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -697,7 +715,8 @@ class MUTAGDataset(RDFGraphDataset):
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True):
verbose=True,
transform=None):
import rdflib as rdf
self.is_mutagenic = rdf.term.URIRef("http://dl-learner.org/carcinogenesis#isMutagenic")
self.rdf_type = rdf.term.URIRef("http://www.w3.org/1999/02/22-rdf-syntax-ns#type")
@@ -712,7 +731,8 @@ class MUTAGDataset(RDFGraphDataset):
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def __getitem__(self, idx):
r"""Gets the graph object
@@ -814,17 +834,21 @@ class BGSDataset(RDFGraphDataset):
Parameters
-----------
print_every: int
print_every : int
Preprocessing log for every X tuples. Default: 10000.
insert_reverse: bool
insert_reverse : bool
If true, add reverse edge and reverse relations to the final graph. Default: True.
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -854,7 +878,8 @@ class BGSDataset(RDFGraphDataset):
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True):
verbose=True,
transform=None):
import rdflib as rdf
url = _get_dgl_url('dataset/rdf/bgs-hetero.zip')
name = 'bgs-hetero'
@@ -865,7 +890,8 @@ class BGSDataset(RDFGraphDataset):
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def __getitem__(self, idx):
r"""Gets the graph object
@@ -964,17 +990,21 @@ class AMDataset(RDFGraphDataset):
Parameters
-----------
print_every: int
print_every : int
Preprocessing log for every X tuples. Default: 10000.
insert_reverse: bool
insert_reverse : bool
If true, add reverse edge and reverse relations to the final graph. Default: True.
raw_dir : str
Raw file directory to download/contains the input data directory.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
Whether to print out progress information. Default: True.
verbose : bool
Whether to print out progress information. Default: True.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -1003,7 +1033,8 @@ class AMDataset(RDFGraphDataset):
insert_reverse=True,
raw_dir=None,
force_reload=False,
verbose=True):
verbose=True,
transform=None):
import rdflib as rdf
self.objectCategory = rdf.term.URIRef("http://purl.org/collections/nl/am/objectCategory")
self.material = rdf.term.URIRef("http://purl.org/collections/nl/am/material")
@@ -1015,7 +1046,8 @@ class AMDataset(RDFGraphDataset):
insert_reverse=insert_reverse,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def __getitem__(self, idx):
r"""Gets the graph object

View File

@@ -84,8 +84,12 @@ class RedditDataset(DGLBuiltinDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -125,7 +129,8 @@ class RedditDataset(DGLBuiltinDataset):
>>>
>>> # Train, Validation and Test
"""
def __init__(self, self_loop=False, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, self_loop=False, raw_dir=None, force_reload=False,
verbose=False, transform=None):
self_loop_str = ""
if self_loop:
self_loop_str = "_self_loop"
@@ -135,7 +140,8 @@ class RedditDataset(DGLBuiltinDataset):
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
# graph
@@ -251,7 +257,10 @@ class RedditDataset(DGLBuiltinDataset):
- ``ndata['test_mask']``: mask for test node set
"""
assert idx == 0, "Reddit Dataset only has one graph"
return self._graph
if self._transform is None:
return self._graph
else:
return self._transform(self._graph)
def __len__(self):
r"""Number of graphs in the dataset"""

View File

@@ -23,7 +23,7 @@ class SSTDataset(DGLBuiltinDataset):
r"""Stanford Sentiment Treebank dataset.
.. deprecated:: 0.5.0
- ``trees`` is deprecated, it is replaced by:
>>> dataset = SSTDataset()
@@ -63,8 +63,12 @@ class SSTDataset(DGLBuiltinDataset):
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose: bool
verbose : bool
Whether to print out progress information. Default: False.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -120,7 +124,8 @@ class SSTDataset(DGLBuiltinDataset):
vocab_file=None,
raw_dir=None,
force_reload=False,
verbose=False):
verbose=False,
transform=None):
assert mode in ['train', 'dev', 'test', 'tiny']
_url = _get_dgl_url('dataset/sst.zip')
self._glove_embed_file = glove_embed_file if mode == 'train' else None
@@ -130,7 +135,8 @@ class SSTDataset(DGLBuiltinDataset):
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
verbose=verbose,
transform=transform)
def process(self):
from nltk.corpus.reader import BracketParseCorpusReader
@@ -255,7 +261,10 @@ class SSTDataset(DGLBuiltinDataset):
- ``ndata['y']``: label of the node
- ``ndata['mask']``: 1 if the node is a leaf, otherwise 0
"""
return self._trees[idx]
if self._transform is None:
return self._trees[idx]
else:
return self._transform(self._trees[idx])
def __len__(self):
r"""Number of graphs in the dataset."""

View File

@@ -13,7 +13,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
Parameters
----------
name : str
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.
use_pandas : bool
Numpy's file read function has performance issue when file is large,
@@ -26,6 +26,10 @@ class LegacyTUDataset(DGLBuiltinDataset):
max_allow_node : int
Remove graphs that contain more nodes than ``max_allow_node``.
Default : None
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -39,7 +43,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
LegacyTUDataset uses the provided node features by default. If no features are provided, it uses one-hot node labels instead.
If no labels are provided either, it uses a constant for the node features.
The dataset sorts graphs by their labels.
The dataset sorts graphs by their labels.
Shuffle is preferred before manual train/val split.
Examples
@@ -73,7 +77,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
def __init__(self, name, use_pandas=False,
hidden_size=10, max_allow_node=None,
raw_dir=None, force_reload=False, verbose=False):
raw_dir=None, force_reload=False, verbose=False, transform=None):
url = self._url.format(name)
self.hidden_size = hidden_size
@@ -81,7 +85,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
self.use_pandas = use_pandas
super(LegacyTUDataset, self).__init__(name=name, url=url, raw_dir=raw_dir,
hash_key=(name, use_pandas, hidden_size, max_allow_node),
force_reload=force_reload, verbose=verbose)
force_reload=force_reload, verbose=verbose, transform=transform)
def process(self):
self.data_mode = None
@@ -100,7 +104,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
DS_graph_labels = self._idx_from_zero(
np.genfromtxt(self._file_path("graph_labels"), dtype=int))
self.num_labels = max(DS_graph_labels) + 1
self.graph_labels = DS_graph_labels
self.graph_labels = DS_graph_labels
elif os.path.exists(self._file_path("graph_attributes")):
DS_graph_labels = np.genfromtxt(self._file_path("graph_attributes"), dtype=float)
self.num_labels = None
@@ -211,6 +215,8 @@ class LegacyTUDataset(DGLBuiltinDataset):
And its label.
"""
g = self.graph_lists[idx]
if self._transform is not None:
g = self._transform(g)
return g, self.graph_labels[idx]
def __len__(self):
@@ -245,8 +251,12 @@ class TUDataset(DGLBuiltinDataset):
Parameters
----------
name : str
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
@@ -271,7 +281,7 @@ class TUDataset(DGLBuiltinDataset):
label was added so that :math:`\lbrace -1, 1 \rbrace` was mapped to
:math:`\lbrace 0, 2 \rbrace`.
The dataset sorts graphs by their labels.
The dataset sorts graphs by their labels.
Shuffle is preferred before manual train/val split.
Examples
@@ -299,32 +309,32 @@ class TUDataset(DGLBuiltinDataset):
Graph(num_nodes=9539, num_edges=47382,
ndata_schemes={'node_labels': Scheme(shape=(1,), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
"""
_url = r"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip"
def __init__(self, name, raw_dir=None, force_reload=False, verbose=False):
def __init__(self, name, raw_dir=None, force_reload=False, verbose=False, transform=None):
url = self._url.format(name)
super(TUDataset, self).__init__(name=name, url=url,
raw_dir=raw_dir, force_reload=force_reload,
verbose=verbose)
verbose=verbose, transform=transform)
def process(self):
DS_edge_list = self._idx_from_zero(
loadtxt(self._file_path("A"), delimiter=",").astype(int))
DS_indicator = self._idx_from_zero(
loadtxt(self._file_path("graph_indicator"), delimiter=",").astype(int))
if os.path.exists(self._file_path("graph_labels")):
DS_graph_labels = self._idx_reset(
loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int))
loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int))
self.num_labels = max(DS_graph_labels) + 1
self.graph_labels = F.tensor(DS_graph_labels)
self.graph_labels = F.tensor(DS_graph_labels)
elif os.path.exists(self._file_path("graph_attributes")):
DS_graph_labels = loadtxt(self._file_path("graph_attributes"), delimiter=",").astype(float)
self.num_labels = None
self.graph_labels = F.tensor(DS_graph_labels)
self.graph_labels = F.tensor(DS_graph_labels)
else:
raise Exception("Unknown graph label or graph attributes")
@@ -404,6 +414,8 @@ class TUDataset(DGLBuiltinDataset):
And its label.
"""
g = self.graph_lists[idx]
if self._transform is not None:
g = self._transform(g)
return g, self.graph_labels[idx]
def __len__(self):