mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-07 20:41:20 +08:00
[Dataset] Add transform argument to built-in datasets (#3733)
* Update * Fix * Update
This commit is contained in:
@@ -39,6 +39,10 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
|
||||
verbose: bool
|
||||
Whether to print out progress information.
|
||||
Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -67,12 +71,13 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
|
||||
_url = 'https://snap.stanford.edu/data/soc-sign-bitcoinotc.csv.gz'
|
||||
_sha1_str = 'c14281f9e252de0bd0b5f1c6e2bae03123938641'
|
||||
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
|
||||
super(BitcoinOTCDataset, self).__init__(name='bitcoinotc',
|
||||
url=self._url,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def download(self):
|
||||
gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')
|
||||
@@ -143,7 +148,10 @@ class BitcoinOTCDataset(DGLBuiltinDataset):
|
||||
|
||||
- ``edata['h']`` : edge weights
|
||||
"""
|
||||
return self.graphs[item]
|
||||
if self._transform is None:
|
||||
return self.graphs[item]
|
||||
else:
|
||||
return self._transform(self.graphs[item])
|
||||
|
||||
@property
|
||||
def is_temporal(self):
|
||||
|
||||
@@ -43,10 +43,14 @@ class CitationGraphDataset(DGLBuiltinDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
reverse_edge: bool
|
||||
reverse_edge : bool
|
||||
Whether to add reverse edges in graph. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
"""
|
||||
_urls = {
|
||||
'cora_v2' : 'dataset/cora_v2.zip',
|
||||
@@ -54,7 +58,8 @@ class CitationGraphDataset(DGLBuiltinDataset):
|
||||
'pubmed' : 'dataset/pubmed.zip',
|
||||
}
|
||||
|
||||
def __init__(self, name, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
|
||||
def __init__(self, name, raw_dir=None, force_reload=False,
|
||||
verbose=True, reverse_edge=True, transform=None):
|
||||
assert name.lower() in ['cora', 'citeseer', 'pubmed']
|
||||
|
||||
# Previously we use the pre-processing in pygcn (https://github.com/tkipf/pygcn)
|
||||
@@ -69,7 +74,8 @@ class CitationGraphDataset(DGLBuiltinDataset):
|
||||
url=url,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def process(self):
|
||||
"""Loads input data from data directory and reorder graph for better locality
|
||||
@@ -213,7 +219,10 @@ class CitationGraphDataset(DGLBuiltinDataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
assert idx == 0, "This dataset has only one graph"
|
||||
return self._g
|
||||
if self._transform is None:
|
||||
return self._g
|
||||
else:
|
||||
return self._transform(self._g)
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
@@ -267,7 +276,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
|
||||
@property
|
||||
def reverse_edge(self):
|
||||
return self._reverse_edge
|
||||
|
||||
|
||||
|
||||
def _preprocess_features(features):
|
||||
"""Row-normalize feature matrix and convert to tuple representation"""
|
||||
@@ -356,10 +365,14 @@ class CoraGraphDataset(CitationGraphDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
reverse_edge: bool
|
||||
reverse_edge : bool
|
||||
Whether to add reverse edges in graph. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -400,10 +413,12 @@ class CoraGraphDataset(CitationGraphDataset):
|
||||
>>> label = g.ndata['label']
|
||||
|
||||
"""
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=True,
|
||||
reverse_edge=True, transform=None):
|
||||
name = 'cora'
|
||||
|
||||
super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
|
||||
super(CoraGraphDataset, self).__init__(name, raw_dir, force_reload,
|
||||
verbose, reverse_edge, transform)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r"""Gets the graph object
|
||||
@@ -496,10 +511,14 @@ class CiteseerGraphDataset(CitationGraphDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
reverse_edge: bool
|
||||
reverse_edge : bool
|
||||
Whether to add reverse edges in graph. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -543,10 +562,12 @@ class CiteseerGraphDataset(CitationGraphDataset):
|
||||
>>> label = g.ndata['label']
|
||||
|
||||
"""
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
|
||||
def __init__(self, raw_dir=None, force_reload=False,
|
||||
verbose=True, reverse_edge=True, transform=None):
|
||||
name = 'citeseer'
|
||||
|
||||
super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
|
||||
super(CiteseerGraphDataset, self).__init__(name, raw_dir, force_reload,
|
||||
verbose, reverse_edge, transform)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r"""Gets the graph object
|
||||
@@ -639,10 +660,14 @@ class PubmedGraphDataset(CitationGraphDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
reverse_edge: bool
|
||||
reverse_edge : bool
|
||||
Whether to add reverse edges in graph. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -683,10 +708,12 @@ class PubmedGraphDataset(CitationGraphDataset):
|
||||
>>> label = g.ndata['label']
|
||||
|
||||
"""
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=True,
|
||||
reverse_edge=True, transform=None):
|
||||
name = 'pubmed'
|
||||
|
||||
super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload, verbose, reverse_edge)
|
||||
super(PubmedGraphDataset, self).__init__(name, raw_dir, force_reload,
|
||||
verbose, reverse_edge, transform)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r"""Gets the graph object
|
||||
@@ -714,7 +741,7 @@ class PubmedGraphDataset(CitationGraphDataset):
|
||||
r"""The number of graphs in the dataset."""
|
||||
return super(PubmedGraphDataset, self).__len__()
|
||||
|
||||
def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
|
||||
def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True, transform=None):
|
||||
"""Get CoraGraphDataset
|
||||
|
||||
Parameters
|
||||
@@ -724,19 +751,24 @@ def load_cora(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True)
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
Whether to print out progress information. Default: True.
|
||||
reverse_edge: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
reverse_edge : bool
|
||||
Whether to add reverse edges in graph. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Return
|
||||
-------
|
||||
CoraGraphDataset
|
||||
"""
|
||||
data = CoraGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
|
||||
data = CoraGraphDataset(raw_dir, force_reload, verbose, reverse_edge, transform)
|
||||
return data
|
||||
|
||||
def load_citeseer(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
|
||||
def load_citeseer(raw_dir=None, force_reload=False, verbose=True,
|
||||
reverse_edge=True, transform=None):
|
||||
"""Get CiteseerGraphDataset
|
||||
|
||||
Parameters
|
||||
@@ -746,38 +778,47 @@ def load_citeseer(raw_dir=None, force_reload=False, verbose=True, reverse_edge=T
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
Whether to print out progress information. Default: True.
|
||||
reverse_edge: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
reverse_edge : bool
|
||||
Whether to add reverse edges in graph. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Return
|
||||
-------
|
||||
CiteseerGraphDataset
|
||||
"""
|
||||
data = CiteseerGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
|
||||
data = CiteseerGraphDataset(raw_dir, force_reload, verbose, reverse_edge, transform)
|
||||
return data
|
||||
|
||||
def load_pubmed(raw_dir=None, force_reload=False, verbose=True, reverse_edge=True):
|
||||
def load_pubmed(raw_dir=None, force_reload=False, verbose=True,
|
||||
reverse_edge=True, transform=None):
|
||||
"""Get PubmedGraphDataset
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
raw_dir : str
|
||||
Raw file directory to download/contains the input data directory.
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
raw_dir : str
|
||||
Raw file directory to download/contains the input data directory.
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
reverse_edge: bool
|
||||
reverse_edge : bool
|
||||
Whether to add reverse edges in graph. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Return
|
||||
-------
|
||||
PubmedGraphDataset
|
||||
"""
|
||||
data = PubmedGraphDataset(raw_dir, force_reload, verbose, reverse_edge)
|
||||
data = PubmedGraphDataset(raw_dir, force_reload, verbose, reverse_edge, transform)
|
||||
return data
|
||||
|
||||
class CoraBinary(DGLBuiltinDataset):
|
||||
@@ -798,15 +839,20 @@ class CoraBinary(DGLBuiltinDataset):
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
"""
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=True):
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=True, transform=None):
|
||||
name = 'cora_binary'
|
||||
url = _get_dgl_url('dataset/cora_binary.zip')
|
||||
super(CoraBinary, self).__init__(name,
|
||||
url=url,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def process(self):
|
||||
root = self.raw_path
|
||||
@@ -894,7 +940,11 @@ class CoraBinary(DGLBuiltinDataset):
|
||||
(dgl.DGLGraph, scipy.sparse.coo_matrix, int)
|
||||
The graph, scipy sparse coo_matrix and its label.
|
||||
"""
|
||||
return (self.graphs[i], self.pmpds[i], self.labels[i])
|
||||
if self._transform is None:
|
||||
g = self.graphs[i]
|
||||
else:
|
||||
g = self._transform(self.graphs[i])
|
||||
return (g, self.pmpds[i], self.labels[i])
|
||||
|
||||
@property
|
||||
def save_name(self):
|
||||
|
||||
@@ -33,6 +33,10 @@ class DGLCSVDataset(DGLDataset):
|
||||
A callable object which is used to parse corresponding column graph
|
||||
data. Default: None. If None, a default data parser is applied
|
||||
which load data directly and tries to convert list into array.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -46,7 +50,8 @@ class DGLCSVDataset(DGLDataset):
|
||||
"""
|
||||
META_YAML_NAME = 'meta.yaml'
|
||||
|
||||
def __init__(self, data_path, force_reload=False, verbose=True, node_data_parser=None, edge_data_parser=None, graph_data_parser=None):
|
||||
def __init__(self, data_path, force_reload=False, verbose=True, node_data_parser=None,
|
||||
edge_data_parser=None, graph_data_parser=None, transform=None):
|
||||
from .csv_dataset_base import load_yaml_with_sanity_check, DefaultDataParser
|
||||
self.graphs = None
|
||||
self.data = None
|
||||
@@ -61,7 +66,7 @@ class DGLCSVDataset(DGLDataset):
|
||||
self.meta_yaml = load_yaml_with_sanity_check(meta_yaml_path)
|
||||
ds_name = self.meta_yaml.dataset_name
|
||||
super().__init__(ds_name, raw_dir=os.path.dirname(
|
||||
meta_yaml_path), force_reload=force_reload, verbose=verbose)
|
||||
meta_yaml_path), force_reload=force_reload, verbose=verbose, transform=transform)
|
||||
|
||||
|
||||
def process(self):
|
||||
@@ -122,10 +127,15 @@ class DGLCSVDataset(DGLDataset):
|
||||
self.graphs, self.data = load_graphs(graph_path)
|
||||
|
||||
def __getitem__(self, i):
|
||||
if 'label' in self.data:
|
||||
return self.graphs[i], self.data['label'][i]
|
||||
if self._transform is None:
|
||||
g = self.graphs[i]
|
||||
else:
|
||||
return self.graphs[i]
|
||||
g = self._transform(self.graphs[i])
|
||||
|
||||
if 'label' in self.data:
|
||||
return g, self.data['label'][i]
|
||||
else:
|
||||
return g
|
||||
|
||||
def __len__(self):
|
||||
return len(self.graphs)
|
||||
|
||||
@@ -49,6 +49,10 @@ class DGLDataset(object):
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose : bool
|
||||
Whether to print out progress information
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -72,13 +76,14 @@ class DGLDataset(object):
|
||||
Hash value for the dataset and the setting.
|
||||
"""
|
||||
def __init__(self, name, url=None, raw_dir=None, save_dir=None,
|
||||
hash_key=(), force_reload=False, verbose=False):
|
||||
hash_key=(), force_reload=False, verbose=False, transform=None):
|
||||
self._name = name
|
||||
self._url = url
|
||||
self._force_reload = force_reload
|
||||
self._verbose = verbose
|
||||
self._hash_key = hash_key
|
||||
self._hash = self._get_hash()
|
||||
self._transform = transform
|
||||
|
||||
# if no dir is provided, the default dgl download dir is used.
|
||||
if raw_dir is None:
|
||||
@@ -142,7 +147,7 @@ class DGLDataset(object):
|
||||
def _download(self):
|
||||
"""Download dataset by calling ``self.download()``
|
||||
if the dataset does not exists under ``self.raw_path``.
|
||||
|
||||
|
||||
By default ``self.raw_path = os.path.join(self.raw_dir, self.name)``
|
||||
One can overwrite ``raw_path()`` function to change the path.
|
||||
"""
|
||||
@@ -161,7 +166,7 @@ class DGLDataset(object):
|
||||
- If loadin process fails, re-download and process the dataset.
|
||||
|
||||
else:
|
||||
|
||||
|
||||
- Download the dataset if needed.
|
||||
- Process the dataset and build the dgl graph.
|
||||
- Save the processed dataset into files.
|
||||
@@ -287,17 +292,23 @@ class DGLBuiltinDataset(DGLDataset):
|
||||
from the same dataset class by comparing the hash values.
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: False
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
"""
|
||||
def __init__(self, name, url, raw_dir=None, hash_key=(), force_reload=False, verbose=False):
|
||||
def __init__(self, name, url, raw_dir=None, hash_key=(),
|
||||
force_reload=False, verbose=False, transform=None):
|
||||
super(DGLBuiltinDataset, self).__init__(name,
|
||||
url=url,
|
||||
raw_dir=raw_dir,
|
||||
save_dir=None,
|
||||
hash_key=hash_key,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def download(self):
|
||||
r""" Automatically download data and extract it.
|
||||
|
||||
@@ -76,6 +76,10 @@ class FakeNewsDataset(DGLBuiltinDataset):
|
||||
downloaded data or the directory that
|
||||
already stores the input data.
|
||||
Default: ~/.dgl/
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -113,7 +117,7 @@ class FakeNewsDataset(DGLBuiltinDataset):
|
||||
'politifact': 'dataset/FakeNewsPOL.zip'
|
||||
}
|
||||
|
||||
def __init__(self, name, feature_name, raw_dir=None):
|
||||
def __init__(self, name, feature_name, raw_dir=None, transform=None):
|
||||
assert name in ['gossipcop', 'politifact'], \
|
||||
"Only supports 'gossipcop' or 'politifact'."
|
||||
url = _get_dgl_url(self.file_urls[name])
|
||||
@@ -123,7 +127,8 @@ class FakeNewsDataset(DGLBuiltinDataset):
|
||||
self.feature_name = feature_name
|
||||
super(FakeNewsDataset, self).__init__(name=name,
|
||||
url=url,
|
||||
raw_dir=raw_dir)
|
||||
raw_dir=raw_dir,
|
||||
transform=transform)
|
||||
|
||||
def process(self):
|
||||
"""process raw data to graph, labels and masks"""
|
||||
@@ -213,7 +218,11 @@ class FakeNewsDataset(DGLBuiltinDataset):
|
||||
-------
|
||||
(:class:`dgl.DGLGraph`, Tensor)
|
||||
"""
|
||||
return self.graphs[i], self.labels[i]
|
||||
if self._transform is None:
|
||||
g = self.graphs[i]
|
||||
else:
|
||||
g = self._transform(self.graphs[i])
|
||||
return g, self.labels[i]
|
||||
|
||||
def __len__(self):
|
||||
r"""Number of graphs in the dataset.
|
||||
|
||||
@@ -48,8 +48,12 @@ class FraudDataset(DGLBuiltinDataset):
|
||||
Default: 0.1
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -88,9 +92,9 @@ class FraudDataset(DGLBuiltinDataset):
|
||||
'yelp': 'review',
|
||||
'amazon': 'user'
|
||||
}
|
||||
|
||||
|
||||
def __init__(self, name, raw_dir=None, random_seed=717, train_size=0.7,
|
||||
val_size=0.1, force_reload=False, verbose=True):
|
||||
val_size=0.1, force_reload=False, verbose=True, transform=None):
|
||||
assert name in ['yelp', 'amazon'], "only supports 'yelp', or 'amazon'"
|
||||
url = _get_dgl_url(self.file_urls[name])
|
||||
self.seed = random_seed
|
||||
@@ -101,30 +105,31 @@ class FraudDataset(DGLBuiltinDataset):
|
||||
raw_dir=raw_dir,
|
||||
hash_key=(random_seed, train_size, val_size),
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def process(self):
|
||||
"""process raw data to graph, labels, splitting masks"""
|
||||
file_path = os.path.join(self.raw_path, self.file_names[self.name])
|
||||
|
||||
|
||||
data = io.loadmat(file_path)
|
||||
node_features = data['features'].todense()
|
||||
# remove additional dimension of length 1 in raw .mat file
|
||||
node_labels = data['label'].squeeze()
|
||||
|
||||
|
||||
graph_data = {}
|
||||
for relation in self.relations[self.name]:
|
||||
adj = data[relation].tocoo()
|
||||
row, col = adj.row, adj.col
|
||||
graph_data[(self.node_name[self.name], relation, self.node_name[self.name])] = (row, col)
|
||||
g = heterograph(graph_data)
|
||||
|
||||
|
||||
g.ndata['feature'] = F.tensor(node_features, dtype=F.data_type_dict['float32'])
|
||||
g.ndata['label'] = F.tensor(node_labels, dtype=F.data_type_dict['int64'])
|
||||
self.graph = g
|
||||
|
||||
|
||||
self._random_split(g.ndata['feature'], self.seed, self.train_size, self.val_size)
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r""" Get graph object
|
||||
|
||||
@@ -145,12 +150,15 @@ class FraudDataset(DGLBuiltinDataset):
|
||||
- ``ndata['test_mask']``: mask of testing set
|
||||
"""
|
||||
assert idx == 0, "This dataset has only one graph"
|
||||
return self.graph
|
||||
|
||||
if self._transform is None:
|
||||
return self.graph
|
||||
else:
|
||||
return self._transform(self.graph)
|
||||
|
||||
def __len__(self):
|
||||
"""number of data examples"""
|
||||
return len(self.graph)
|
||||
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
"""Number of classes.
|
||||
@@ -160,37 +168,37 @@ class FraudDataset(DGLBuiltinDataset):
|
||||
int
|
||||
"""
|
||||
return 2
|
||||
|
||||
|
||||
def save(self):
|
||||
"""save processed data to directory `self.save_path`"""
|
||||
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
|
||||
save_graphs(str(graph_path), self.graph)
|
||||
|
||||
|
||||
def load(self):
|
||||
"""load processed data from directory `self.save_path`"""
|
||||
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
|
||||
graph_list, _ = load_graphs(str(graph_path))
|
||||
g = graph_list[0]
|
||||
self.graph = g
|
||||
|
||||
|
||||
def has_cache(self):
|
||||
"""check whether there are processed data in `self.save_path`"""
|
||||
graph_path = os.path.join(self.save_path, self.name + '_dgl_graph_{}.bin'.format(self.hash))
|
||||
return os.path.exists(graph_path)
|
||||
|
||||
|
||||
def _random_split(self, x, seed=717, train_size=0.7, val_size=0.1):
|
||||
"""split the dataset into training set, validation set and testing set"""
|
||||
|
||||
|
||||
assert 0 <= train_size + val_size <= 1, \
|
||||
"The sum of valid training set size and validation set size " \
|
||||
"must between 0 and 1 (inclusive)."
|
||||
|
||||
|
||||
N = x.shape[0]
|
||||
index = np.arange(N)
|
||||
if self.name == 'amazon':
|
||||
# 0-3304 are unlabeled nodes
|
||||
index = np.arange(3305, N)
|
||||
|
||||
|
||||
index = np.random.RandomState(seed).permutation(index)
|
||||
train_idx = index[:int(train_size * len(index))]
|
||||
val_idx = index[len(index) - int(val_size * len(index)):]
|
||||
@@ -254,8 +262,12 @@ class FraudYelpDataset(FraudDataset):
|
||||
Default: 0.1
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -265,16 +277,17 @@ class FraudYelpDataset(FraudDataset):
|
||||
>>> feat = graph.ndata['feature']
|
||||
>>> label = graph.ndata['label']
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7,
|
||||
val_size=0.1, force_reload=False, verbose=True):
|
||||
val_size=0.1, force_reload=False, verbose=True, transform=None):
|
||||
super(FraudYelpDataset, self).__init__(name='yelp',
|
||||
raw_dir=raw_dir,
|
||||
random_seed=random_seed,
|
||||
train_size=train_size,
|
||||
val_size=val_size,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
|
||||
class FraudAmazonDataset(FraudDataset):
|
||||
@@ -330,8 +343,12 @@ class FraudAmazonDataset(FraudDataset):
|
||||
Default: 0.1
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -341,13 +358,14 @@ class FraudAmazonDataset(FraudDataset):
|
||||
>>> feat = graph.ndata['feature']
|
||||
>>> label = graph.ndata['label']
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, raw_dir=None, random_seed=717, train_size=0.7,
|
||||
val_size=0.1, force_reload=False, verbose=True):
|
||||
val_size=0.1, force_reload=False, verbose=True, transform=None):
|
||||
super(FraudAmazonDataset, self).__init__(name='amazon',
|
||||
raw_dir=raw_dir,
|
||||
random_seed=random_seed,
|
||||
train_size=train_size,
|
||||
val_size=val_size,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
@@ -37,8 +37,12 @@ class GDELTDataset(DGLBuiltinDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -65,7 +69,8 @@ class GDELTDataset(DGLBuiltinDataset):
|
||||
....
|
||||
>>>
|
||||
"""
|
||||
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False):
|
||||
def __init__(self, mode='train', raw_dir=None,
|
||||
force_reload=False, verbose=False, transform=None):
|
||||
mode = mode.lower()
|
||||
assert mode in ['train', 'valid', 'test'], "Mode not valid."
|
||||
self.mode = mode
|
||||
@@ -75,7 +80,8 @@ class GDELTDataset(DGLBuiltinDataset):
|
||||
url=_url,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def process(self):
|
||||
file_path = os.path.join(self.raw_path, self.mode + '.txt')
|
||||
@@ -148,6 +154,8 @@ class GDELTDataset(DGLBuiltinDataset):
|
||||
rate = self.data[row_mask][:, 1]
|
||||
g = dgl_graph((edges[:, 0], edges[:, 1]))
|
||||
g.edata['rel_type'] = F.tensor(rate.reshape(-1, 1), dtype=F.data_type_dict['int64'])
|
||||
if self._transform is not None:
|
||||
g = self._transform(g)
|
||||
return g
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -18,9 +18,9 @@ from ..convert import graph as dgl_graph
|
||||
|
||||
class GINDataset(DGLBuiltinDataset):
|
||||
"""Dataset Class for `How Powerful Are Graph Neural Networks? <https://arxiv.org/abs/1810.00826>`_.
|
||||
|
||||
|
||||
This is adapted from `<https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip>`_.
|
||||
|
||||
|
||||
The class provides an interface for nine datasets used in the paper along with the paper-specific
|
||||
settings. The datasets are ``'MUTAG'``, ``'COLLAB'``, ``'IMDBBINARY'``, ``'IMDBMULTI'``,
|
||||
``'NCI1'``, ``'PROTEINS'``, ``'PTC'``, ``'REDDITBINARY'``, ``'REDDITMULTI5K'``.
|
||||
@@ -44,6 +44,10 @@ class GINDataset(DGLBuiltinDataset):
|
||||
add self to self edge if true
|
||||
degree_as_nlabel: bool
|
||||
take node degree as label and feature if true
|
||||
transform: callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -73,7 +77,7 @@ class GINDataset(DGLBuiltinDataset):
|
||||
"""
|
||||
|
||||
def __init__(self, name, self_loop, degree_as_nlabel=False,
|
||||
raw_dir=None, force_reload=False, verbose=False):
|
||||
raw_dir=None, force_reload=False, verbose=False, transform=None):
|
||||
|
||||
self._name = name # MUTAG
|
||||
gin_url = 'https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip'
|
||||
@@ -106,7 +110,8 @@ class GINDataset(DGLBuiltinDataset):
|
||||
self.nlabels_flag = False
|
||||
|
||||
super(GINDataset, self).__init__(name=name, url=gin_url, hash_key=(name, self_loop, degree_as_nlabel),
|
||||
raw_dir=raw_dir, force_reload=force_reload, verbose=verbose)
|
||||
raw_dir=raw_dir, force_reload=force_reload,
|
||||
verbose=verbose, transform=transform)
|
||||
|
||||
@property
|
||||
def raw_path(self):
|
||||
@@ -136,7 +141,11 @@ class GINDataset(DGLBuiltinDataset):
|
||||
(:class:`dgl.Graph`, Tensor)
|
||||
The graph and its label.
|
||||
"""
|
||||
return self.graphs[idx], self.labels[idx]
|
||||
if self._transform is None:
|
||||
g = self.graphs[idx]
|
||||
else:
|
||||
g = self._transform(self.graphs[idx])
|
||||
return g, self.labels[idx]
|
||||
|
||||
def _file_path(self):
|
||||
return os.path.join(self.raw_dir, "GINDataset", 'dataset', self.name, "{}.txt".format(self.name))
|
||||
|
||||
@@ -27,13 +27,14 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
|
||||
|
||||
Reference: https://github.com/shchur/gnn-benchmark#datasets
|
||||
"""
|
||||
def __init__(self, name, raw_dir=None, force_reload=False, verbose=False):
|
||||
def __init__(self, name, raw_dir=None, force_reload=False, verbose=False, transform=None):
|
||||
_url = _get_dgl_url('dataset/' + name + '.zip')
|
||||
super(GNNBenchmarkDataset, self).__init__(name=name,
|
||||
url=_url,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def process(self):
|
||||
npz_path = os.path.join(self.raw_path, self.name + '.npz')
|
||||
@@ -128,7 +129,10 @@ class GNNBenchmarkDataset(DGLBuiltinDataset):
|
||||
- ``ndata['label']``: node labels
|
||||
"""
|
||||
assert idx == 0, "This dataset has only one graph"
|
||||
return self._graph
|
||||
if self._transform is None:
|
||||
return self._graph
|
||||
else:
|
||||
return self._transform(self._graph)
|
||||
|
||||
def __len__(self):
|
||||
r"""Number of graphs in the dataset"""
|
||||
@@ -164,8 +168,12 @@ class CoraFullDataset(GNNBenchmarkDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -182,11 +190,12 @@ class CoraFullDataset(GNNBenchmarkDataset):
|
||||
>>> feat = g.ndata['feat'] # get node feature
|
||||
>>> label = g.ndata['label'] # get node labels
|
||||
"""
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
|
||||
super(CoraFullDataset, self).__init__(name="cora_full",
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
@@ -231,8 +240,12 @@ class CoauthorCSDataset(GNNBenchmarkDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -249,11 +262,12 @@ class CoauthorCSDataset(GNNBenchmarkDataset):
|
||||
>>> feat = g.ndata['feat'] # get node feature
|
||||
>>> label = g.ndata['label'] # get node labels
|
||||
"""
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
|
||||
super(CoauthorCSDataset, self).__init__(name='coauthor_cs',
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
@@ -298,8 +312,12 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -316,11 +334,12 @@ class CoauthorPhysicsDataset(GNNBenchmarkDataset):
|
||||
>>> feat = g.ndata['feat'] # get node feature
|
||||
>>> label = g.ndata['label'] # get node labels
|
||||
"""
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
|
||||
super(CoauthorPhysicsDataset, self).__init__(name='coauthor_physics',
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
@@ -364,8 +383,12 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -382,11 +405,12 @@ class AmazonCoBuyComputerDataset(GNNBenchmarkDataset):
|
||||
>>> feat = g.ndata['feat'] # get node feature
|
||||
>>> label = g.ndata['label'] # get node labels
|
||||
"""
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
|
||||
super(AmazonCoBuyComputerDataset, self).__init__(name='amazon_co_buy_computer',
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
@@ -430,8 +454,12 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -448,11 +476,12 @@ class AmazonCoBuyPhotoDataset(GNNBenchmarkDataset):
|
||||
>>> feat = g.ndata['feat'] # get node feature
|
||||
>>> label = g.ndata['label'] # get node labels
|
||||
"""
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
|
||||
super(AmazonCoBuyPhotoDataset, self).__init__(name='amazon_co_buy_photo',
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
|
||||
@@ -39,8 +39,12 @@ class ICEWS18Dataset(DGLBuiltinDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
-------
|
||||
@@ -61,7 +65,7 @@ class ICEWS18Dataset(DGLBuiltinDataset):
|
||||
....
|
||||
>>>
|
||||
"""
|
||||
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False):
|
||||
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False, transform=None):
|
||||
mode = mode.lower()
|
||||
assert mode in ['train', 'valid', 'test'], "Mode not valid"
|
||||
self.mode = mode
|
||||
@@ -70,7 +74,8 @@ class ICEWS18Dataset(DGLBuiltinDataset):
|
||||
url=_url,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def process(self):
|
||||
data = loadtxt(os.path.join(self.save_path, '{}.txt'.format(self.mode)),
|
||||
@@ -118,7 +123,10 @@ class ICEWS18Dataset(DGLBuiltinDataset):
|
||||
|
||||
- ``edata['rel_type']``: edge type
|
||||
"""
|
||||
return self._graphs[idx]
|
||||
if self._transform is None:
|
||||
return self._graphs[idx]
|
||||
else:
|
||||
return self._transform(self._graphs[idx])
|
||||
|
||||
def __len__(self):
|
||||
r"""Number of graphs in the dataset.
|
||||
|
||||
@@ -34,6 +34,13 @@ class KarateClubDataset(DGLDataset):
|
||||
- Edges: 156
|
||||
- Number of Classes: 2
|
||||
|
||||
Parameters
|
||||
----------
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
num_classes : int
|
||||
@@ -48,8 +55,8 @@ class KarateClubDataset(DGLDataset):
|
||||
>>> g = dataset[0]
|
||||
>>> labels = g.ndata['label']
|
||||
"""
|
||||
def __init__(self):
|
||||
super(KarateClubDataset, self).__init__(name='karate_club')
|
||||
def __init__(self, transform=None):
|
||||
super(KarateClubDataset, self).__init__(name='karate_club', transform=transform)
|
||||
|
||||
def process(self):
|
||||
kc_graph = nx.karate_club_graph()
|
||||
@@ -88,7 +95,10 @@ class KarateClubDataset(DGLDataset):
|
||||
- ``ndata['label']``: ground truth labels
|
||||
"""
|
||||
assert idx == 0, "This dataset has only one graph"
|
||||
return self._graph
|
||||
if self._transform is None:
|
||||
return self._graph
|
||||
else:
|
||||
return self._transform(self._graph)
|
||||
|
||||
def __len__(self):
|
||||
r"""The number of graphs in the dataset."""
|
||||
|
||||
@@ -25,19 +25,24 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
name: str
|
||||
name : str
|
||||
Name can be 'FB15k-237', 'FB15k' or 'wn18'.
|
||||
reverse: bool
|
||||
reverse : bool
|
||||
Whether add reverse edges. Default: True.
|
||||
raw_dir : str
|
||||
Raw file directory to download/contains the input data directory.
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
Whether to print out progress information. Default: True.
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
"""
|
||||
def __init__(self, name, reverse=True, raw_dir=None, force_reload=False, verbose=True):
|
||||
def __init__(self, name, reverse=True, raw_dir=None, force_reload=False,
|
||||
verbose=True, transform=None):
|
||||
self._name = name
|
||||
self.reverse = reverse
|
||||
url = _get_dgl_url('dataset/') + '{}.tgz'.format(name)
|
||||
@@ -45,7 +50,8 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
|
||||
url=url,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def download(self):
|
||||
r""" Automatically download data and extract it.
|
||||
@@ -112,7 +118,10 @@ class KnowledgeGraphDataset(DGLBuiltinDataset):
|
||||
|
||||
def __getitem__(self, idx):
|
||||
assert idx == 0, "This dataset has only one graph"
|
||||
return self._g
|
||||
if self._transform is None:
|
||||
return self._g
|
||||
else:
|
||||
return self._transform(self._g)
|
||||
|
||||
def __len__(self):
|
||||
return 1
|
||||
@@ -389,8 +398,12 @@ class FB15k237Dataset(KnowledgeGraphDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -433,9 +446,11 @@ class FB15k237Dataset(KnowledgeGraphDataset):
|
||||
>>>
|
||||
>>> # Train, Validation and Test
|
||||
"""
|
||||
def __init__(self, reverse=True, raw_dir=None, force_reload=False, verbose=True):
|
||||
def __init__(self, reverse=True, raw_dir=None, force_reload=False,
|
||||
verbose=True, transform=None):
|
||||
name = 'FB15k-237'
|
||||
super(FB15k237Dataset, self).__init__(name, reverse, raw_dir, force_reload, verbose)
|
||||
super(FB15k237Dataset, self).__init__(name, reverse, raw_dir,
|
||||
force_reload, verbose, transform)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r"""Gets the graph object
|
||||
@@ -526,8 +541,12 @@ class FB15kDataset(KnowledgeGraphDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -570,9 +589,11 @@ class FB15kDataset(KnowledgeGraphDataset):
|
||||
>>> # Train, Validation and Test
|
||||
>>>
|
||||
"""
|
||||
def __init__(self, reverse=True, raw_dir=None, force_reload=False, verbose=True):
|
||||
def __init__(self, reverse=True, raw_dir=None, force_reload=False,
|
||||
verbose=True, transform=None):
|
||||
name = 'FB15k'
|
||||
super(FB15kDataset, self).__init__(name, reverse, raw_dir, force_reload, verbose)
|
||||
super(FB15kDataset, self).__init__(name, reverse, raw_dir,
|
||||
force_reload, verbose, transform)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r"""Gets the graph object
|
||||
@@ -662,8 +683,12 @@ class WN18Dataset(KnowledgeGraphDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -706,9 +731,11 @@ class WN18Dataset(KnowledgeGraphDataset):
|
||||
>>> # Train, Validation and Test
|
||||
>>>
|
||||
"""
|
||||
def __init__(self, reverse=True, raw_dir=None, force_reload=False, verbose=True):
|
||||
def __init__(self, reverse=True, raw_dir=None, force_reload=False,
|
||||
verbose=True, transform=None):
|
||||
name = 'wn18'
|
||||
super(WN18Dataset, self).__init__(name, reverse, raw_dir, force_reload, verbose)
|
||||
super(WN18Dataset, self).__init__(name, reverse, raw_dir,
|
||||
force_reload, verbose, transform)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r"""Gets the graph object
|
||||
|
||||
@@ -33,8 +33,12 @@ class MiniGCDataset(DGLDataset):
|
||||
Minimum number of nodes for graphs
|
||||
max_num_v: int
|
||||
Maximum number of nodes for graphs
|
||||
seed : int, default is 0
|
||||
seed: int, default is 0
|
||||
Random seed for data generation
|
||||
transform: callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -75,7 +79,7 @@ class MiniGCDataset(DGLDataset):
|
||||
"""
|
||||
|
||||
def __init__(self, num_graphs, min_num_v, max_num_v, seed=0,
|
||||
save_graph=True, force_reload=False, verbose=False):
|
||||
save_graph=True, force_reload=False, verbose=False, transform=None):
|
||||
self.num_graphs = num_graphs
|
||||
self.min_num_v = min_num_v
|
||||
self.max_num_v = max_num_v
|
||||
@@ -84,7 +88,7 @@ class MiniGCDataset(DGLDataset):
|
||||
|
||||
super(MiniGCDataset, self).__init__(name="minigc", hash_key=(num_graphs, min_num_v, max_num_v, seed),
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose, transform=transform)
|
||||
|
||||
def process(self):
|
||||
self.graphs = []
|
||||
@@ -108,7 +112,11 @@ class MiniGCDataset(DGLDataset):
|
||||
(:class:`dgl.Graph`, Tensor)
|
||||
The graph and its label.
|
||||
"""
|
||||
return self.graphs[idx], self.labels[idx]
|
||||
if self._transform is None:
|
||||
g = self.graphs[idx]
|
||||
else:
|
||||
g = self._transform(self.graphs[idx])
|
||||
return g, self.labels[idx]
|
||||
|
||||
def has_cache(self):
|
||||
graph_path = os.path.join(self.save_path, 'dgl_graph_{}.bin'.format(self.hash))
|
||||
|
||||
@@ -56,9 +56,13 @@ class PPIDataset(DGLBuiltinDataset):
|
||||
force_reload : bool
|
||||
Whether to reload the dataset.
|
||||
Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information.
|
||||
Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -79,7 +83,8 @@ class PPIDataset(DGLBuiltinDataset):
|
||||
.... # your code here
|
||||
>>>
|
||||
"""
|
||||
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False):
|
||||
def __init__(self, mode='train', raw_dir=None, force_reload=False,
|
||||
verbose=False, transform=None):
|
||||
assert mode in ['train', 'valid', 'test']
|
||||
self.mode = mode
|
||||
_url = _get_dgl_url('dataset/ppi.zip')
|
||||
@@ -87,7 +92,8 @@ class PPIDataset(DGLBuiltinDataset):
|
||||
url=_url,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def process(self):
|
||||
graph_file = os.path.join(self.save_path, '{}_graph.json'.format(self.mode))
|
||||
@@ -178,7 +184,10 @@ class PPIDataset(DGLBuiltinDataset):
|
||||
- ``ndata['feat']``: node features
|
||||
- ``ndata['label']``: node labels
|
||||
"""
|
||||
return self.graphs[item]
|
||||
if self._transform is None:
|
||||
return self.graphs[item]
|
||||
else:
|
||||
return self._transform(self.graphs[item])
|
||||
|
||||
|
||||
class LegacyPPIDataset(PPIDataset):
|
||||
@@ -198,5 +207,8 @@ class LegacyPPIDataset(PPIDataset):
|
||||
(dgl.DGLGraph, Tensor, Tensor)
|
||||
The graph, features and its label.
|
||||
"""
|
||||
|
||||
return self.graphs[item], self.graphs[item].ndata['feat'], self.graphs[item].ndata['label']
|
||||
if self._transform is None:
|
||||
g = self.graphs[item]
|
||||
else:
|
||||
g = self._transform(self.graphs[item])
|
||||
return g, g.ndata['feat'], g.ndata['label']
|
||||
|
||||
@@ -34,8 +34,12 @@ class QM7bDataset(DGLDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -65,12 +69,13 @@ class QM7bDataset(DGLDataset):
|
||||
'datasets/qm7b.mat'
|
||||
_sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'
|
||||
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=False):
|
||||
def __init__(self, raw_dir=None, force_reload=False, verbose=False, transform=None):
|
||||
super(QM7bDataset, self).__init__(name='qm7b',
|
||||
url=self._url,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def process(self):
|
||||
mat_path = self.raw_path + '.mat'
|
||||
@@ -129,7 +134,11 @@ class QM7bDataset(DGLDataset):
|
||||
-------
|
||||
(:class:`dgl.DGLGraph`, Tensor)
|
||||
"""
|
||||
return self.graphs[idx], self.label[idx]
|
||||
if self._transform is None:
|
||||
g = self.graphs[idx]
|
||||
else:
|
||||
g = self._transform(self.graphs[idx])
|
||||
return g, self.label[idx]
|
||||
|
||||
def __len__(self):
|
||||
r"""Number of graphs in the dataset.
|
||||
|
||||
@@ -20,11 +20,11 @@ class QM9Dataset(DGLDataset):
|
||||
2. It only provides atoms' coordinates and atomic numbers as node features
|
||||
3. It only provides 12 regression targets.
|
||||
|
||||
Reference:
|
||||
|
||||
Reference:
|
||||
|
||||
- `"Quantum-Machine.org" <http://quantum-machine.org/datasets/>`_,
|
||||
- `"Directional Message Passing for Molecular Graphs" <https://arxiv.org/abs/2003.03123>`_
|
||||
|
||||
|
||||
Statistics:
|
||||
|
||||
- Number of graphs: 130,831
|
||||
@@ -60,9 +60,9 @@ class QM9Dataset(DGLDataset):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
label_keys: list
|
||||
label_keys : list
|
||||
Names of the regression property, which should be a subset of the keys in the table above.
|
||||
cutoff: float
|
||||
cutoff : float
|
||||
Cutoff distance for interatomic interactions, i.e. two atoms are connected in the corresponding graph if the distance between them is no larger than this.
|
||||
Default: 5.0 Angstrom
|
||||
raw_dir : str
|
||||
@@ -70,8 +70,12 @@ class QM9Dataset(DGLDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -82,7 +86,7 @@ class QM9Dataset(DGLDataset):
|
||||
------
|
||||
UserWarning
|
||||
If the raw data is changed in the remote server by the author.
|
||||
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0)
|
||||
@@ -102,8 +106,9 @@ class QM9Dataset(DGLDataset):
|
||||
cutoff=5.0,
|
||||
raw_dir=None,
|
||||
force_reload=False,
|
||||
verbose=False):
|
||||
|
||||
verbose=False,
|
||||
transform=None):
|
||||
|
||||
self.cutoff = cutoff
|
||||
self.label_keys = label_keys
|
||||
self._url = _get_dgl_url('dataset/qm9_eV.npz')
|
||||
@@ -112,7 +117,8 @@ class QM9Dataset(DGLDataset):
|
||||
url=self._url,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def process(self):
|
||||
npz_path = f'{self.raw_dir}/qm9_eV.npz'
|
||||
@@ -148,7 +154,7 @@ class QM9Dataset(DGLDataset):
|
||||
----------
|
||||
idx : int
|
||||
Item index
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
dgl.DGLGraph
|
||||
@@ -170,8 +176,12 @@ class QM9Dataset(DGLDataset):
|
||||
g = dgl_graph((u, v))
|
||||
g = to_bidirected(g)
|
||||
g.ndata['R'] = F.tensor(R, dtype=F.data_type_dict['float32'])
|
||||
g.ndata['Z'] = F.tensor(self.Z[self.N_cumsum[idx]:self.N_cumsum[idx + 1]],
|
||||
g.ndata['Z'] = F.tensor(self.Z[self.N_cumsum[idx]:self.N_cumsum[idx + 1]],
|
||||
dtype=F.data_type_dict['int64'])
|
||||
|
||||
if self._transform is not None:
|
||||
g = self._transform(g)
|
||||
|
||||
return g, label
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -14,34 +14,34 @@ class QM9EdgeDataset(DGLDataset):
|
||||
|
||||
This dataset consists of 130,831 molecules with 19 regression targets.
|
||||
Nodes correspond to atoms and edges correspond to bonds.
|
||||
|
||||
|
||||
This dataset differs from :class:`~dgl.data.QM9Dataset` in the following aspects:
|
||||
1. It includes the bonds in a molecule in the edges of the corresponding graph while the edges in :class:`~dgl.data.QM9Dataset` are purely distance-based.
|
||||
2. It provides edge features, and node features in addition to the atoms' coordinates and atomic numbers.
|
||||
3. It provides another 7 regression tasks(from 12 to 19).
|
||||
|
||||
This class is built based on a preprocessed version of the dataset, and we provide the preprocessing datails `here <https://gist.github.com/hengruizhang98/a2da30213b2356fff18b25385c9d3cd2>`_.
|
||||
|
||||
|
||||
Reference:
|
||||
|
||||
- `"MoleculeNet: A Benchmark for Molecular Machine Learning" <https://arxiv.org/abs/1703.00564>`_
|
||||
- `"Neural Message Passing for Quantum Chemistry" <https://arxiv.org/abs/1704.01212>`_
|
||||
|
||||
For
|
||||
|
||||
For
|
||||
Statistics:
|
||||
|
||||
- Number of graphs: 130,831.
|
||||
- Number of regression targets: 19.
|
||||
|
||||
|
||||
Node attributes:
|
||||
|
||||
- pos: the 3D coordinates of each atom.
|
||||
- attr: the 11D atom features.
|
||||
|
||||
- pos: the 3D coordinates of each atom.
|
||||
- attr: the 11D atom features.
|
||||
|
||||
Edge attributes:
|
||||
|
||||
|
||||
- edge_attr: the 4D bond features.
|
||||
|
||||
|
||||
Regression targets:
|
||||
|
||||
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
|
||||
@@ -85,10 +85,10 @@ class QM9EdgeDataset(DGLDataset):
|
||||
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
|
||||
| C | :math:`C` | Rotational constant | :math:`\textrm{GHz}` |
|
||||
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
label_keys: list
|
||||
label_keys : list
|
||||
Names of the regression property, which should be a subset of the keys in the table above.
|
||||
If not provided, it will load all the labels.
|
||||
raw_dir : str
|
||||
@@ -96,8 +96,12 @@ class QM9EdgeDataset(DGLDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False.
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -108,13 +112,13 @@ class QM9EdgeDataset(DGLDataset):
|
||||
------
|
||||
UserWarning
|
||||
If the raw data is changed in the remote server by the author.
|
||||
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> data = QM9EdgeDataset(label_keys=['mu', 'alpha'])
|
||||
>>> data.num_labels
|
||||
2
|
||||
|
||||
|
||||
>>> # iterate over the dataset
|
||||
>>> for graph, labels in data:
|
||||
... print(graph) # get information of each graph
|
||||
@@ -122,47 +126,49 @@ class QM9EdgeDataset(DGLDataset):
|
||||
... # your code here...
|
||||
>>>
|
||||
"""
|
||||
|
||||
|
||||
keys = ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv', 'U0_atom',
|
||||
'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C']
|
||||
map_dict = {}
|
||||
|
||||
|
||||
for i, key in enumerate(keys):
|
||||
map_dict[key] = i
|
||||
|
||||
def __init__(self,
|
||||
|
||||
def __init__(self,
|
||||
label_keys=None,
|
||||
raw_dir=None,
|
||||
force_reload=False,
|
||||
verbose=True):
|
||||
raw_dir=None,
|
||||
force_reload=False,
|
||||
verbose=True,
|
||||
transform=None):
|
||||
if label_keys is None:
|
||||
self.label_keys = None
|
||||
self.num_labels = 19
|
||||
else:
|
||||
self.label_keys = [self.map_dict[i] for i in label_keys]
|
||||
self.num_labels = len(label_keys)
|
||||
|
||||
|
||||
|
||||
|
||||
self._url = _get_dgl_url('dataset/qm9_edge.npz')
|
||||
|
||||
|
||||
super(QM9EdgeDataset, self).__init__(name='qm9Edge',
|
||||
raw_dir=raw_dir,
|
||||
url=self._url,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def download(self):
|
||||
file_path = f'{self.raw_dir}/qm9_edge.npz'
|
||||
if not os.path.exists(file_path):
|
||||
download(self._url, path=file_path)
|
||||
|
||||
|
||||
def process(self):
|
||||
self.load()
|
||||
|
||||
def has_cache(self):
|
||||
npz_path = f'{self.raw_dir}/qm9_edge.npz'
|
||||
return os.path.exists(npz_path)
|
||||
|
||||
|
||||
def save(self):
|
||||
np.savez_compressed(f'{self.raw_dir}/qm9_edge.npz',
|
||||
n_node=self.n_node,
|
||||
@@ -171,7 +177,7 @@ class QM9EdgeDataset(DGLDataset):
|
||||
node_pos=self.node_pos,
|
||||
edge_attr=self.edge_attr,
|
||||
src=self.src,
|
||||
dst=self.dst,
|
||||
dst=self.dst,
|
||||
targets=self.targets)
|
||||
|
||||
def load(self):
|
||||
@@ -184,52 +190,55 @@ class QM9EdgeDataset(DGLDataset):
|
||||
self.node_pos = data_dict['node_pos']
|
||||
self.edge_attr = data_dict['edge_attr']
|
||||
self.targets = data_dict['targets']
|
||||
|
||||
|
||||
self.src = data_dict['src']
|
||||
self.dst = data_dict['dst']
|
||||
|
||||
|
||||
self.n_cumsum = np.concatenate([[0], np.cumsum(self.n_node)])
|
||||
self.ne_cumsum = np.concatenate([[0], np.cumsum(self.n_edge)])
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r""" Get graph and label by index
|
||||
|
||||
r""" Get graph and label by index
|
||||
|
||||
Parameters
|
||||
----------
|
||||
idx : int
|
||||
Item index
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
dgl.DGLGraph
|
||||
The graph contains:
|
||||
|
||||
|
||||
- ``ndata['pos']``: the coordinates of each atom
|
||||
- ``ndata['attr']``: the features of each atom
|
||||
- ``edata['edge_attr']``: the features of each bond
|
||||
|
||||
|
||||
Tensor
|
||||
Property values of molecular graphs
|
||||
"""
|
||||
|
||||
|
||||
pos = self.node_pos[self.n_cumsum[idx]:self.n_cumsum[idx+1]]
|
||||
src = self.src[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]]
|
||||
dst = self.dst[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]]
|
||||
|
||||
g = dgl_graph((src, dst))
|
||||
|
||||
|
||||
g.ndata['pos'] = F.tensor(pos, dtype=F.data_type_dict['float32'])
|
||||
g.ndata['attr'] = F.tensor(self.node_attr[self.n_cumsum[idx]:self.n_cumsum[idx+1]], dtype=F.data_type_dict['float32'])
|
||||
g.edata['edge_attr'] = F.tensor(self.edge_attr[self.ne_cumsum[idx]:self.ne_cumsum[idx+1]], dtype=F.data_type_dict['float32'])
|
||||
|
||||
|
||||
|
||||
|
||||
label = F.tensor(self.targets[idx][self.label_keys], dtype=F.data_type_dict['float32'])
|
||||
|
||||
|
||||
if self._transform is not None:
|
||||
g = self._transform(g)
|
||||
|
||||
return g, label
|
||||
|
||||
def __len__(self):
|
||||
r""" Number of graphs in the dataset.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
|
||||
@@ -94,15 +94,20 @@ class RDFGraphDataset(DGLBuiltinDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool, optional
|
||||
If true, force load and process from raw data. Ignore cached pre-processed data.
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
"""
|
||||
def __init__(self, name, url, predict_category,
|
||||
print_every=10000,
|
||||
insert_reverse=True,
|
||||
raw_dir=None,
|
||||
force_reload=False,
|
||||
verbose=True):
|
||||
verbose=True,
|
||||
transform=None):
|
||||
self._insert_reverse = insert_reverse
|
||||
self._print_every = print_every
|
||||
self._predict_category = predict_category
|
||||
@@ -110,7 +115,8 @@ class RDFGraphDataset(DGLBuiltinDataset):
|
||||
super(RDFGraphDataset, self).__init__(name, url,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def process(self):
|
||||
raw_tuples = self.load_raw_tuples(self.raw_path)
|
||||
@@ -409,6 +415,8 @@ class RDFGraphDataset(DGLBuiltinDataset):
|
||||
r"""Gets the graph object
|
||||
"""
|
||||
g = self._hg
|
||||
if self._transform is not None:
|
||||
g = self._transform(g)
|
||||
return g
|
||||
|
||||
def __len__(self):
|
||||
@@ -523,17 +531,21 @@ class AIFBDataset(RDFGraphDataset):
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
print_every: int
|
||||
print_every : int
|
||||
Preprocessing log for every X tuples. Default: 10000.
|
||||
insert_reverse: bool
|
||||
insert_reverse : bool
|
||||
If true, add reverse edge and reverse relations to the final graph. Default: True.
|
||||
raw_dir : str
|
||||
Raw file directory to download/contains the input data directory.
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
Whether to print out progress information. Default: True.
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -562,7 +574,8 @@ class AIFBDataset(RDFGraphDataset):
|
||||
insert_reverse=True,
|
||||
raw_dir=None,
|
||||
force_reload=False,
|
||||
verbose=True):
|
||||
verbose=True,
|
||||
transform=None):
|
||||
import rdflib as rdf
|
||||
self.employs = rdf.term.URIRef("http://swrc.ontoware.org/ontology#employs")
|
||||
self.affiliation = rdf.term.URIRef("http://swrc.ontoware.org/ontology#affiliation")
|
||||
@@ -574,7 +587,8 @@ class AIFBDataset(RDFGraphDataset):
|
||||
insert_reverse=insert_reverse,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r"""Gets the graph object
|
||||
@@ -653,17 +667,21 @@ class MUTAGDataset(RDFGraphDataset):
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
print_every: int
|
||||
print_every : int
|
||||
Preprocessing log for every X tuples. Default: 10000.
|
||||
insert_reverse: bool
|
||||
insert_reverse : bool
|
||||
If true, add reverse edge and reverse relations to the final graph. Default: True.
|
||||
raw_dir : str
|
||||
Raw file directory to download/contains the input data directory.
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
Whether to print out progress information. Default: True.
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -697,7 +715,8 @@ class MUTAGDataset(RDFGraphDataset):
|
||||
insert_reverse=True,
|
||||
raw_dir=None,
|
||||
force_reload=False,
|
||||
verbose=True):
|
||||
verbose=True,
|
||||
transform=None):
|
||||
import rdflib as rdf
|
||||
self.is_mutagenic = rdf.term.URIRef("http://dl-learner.org/carcinogenesis#isMutagenic")
|
||||
self.rdf_type = rdf.term.URIRef("http://www.w3.org/1999/02/22-rdf-syntax-ns#type")
|
||||
@@ -712,7 +731,8 @@ class MUTAGDataset(RDFGraphDataset):
|
||||
insert_reverse=insert_reverse,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r"""Gets the graph object
|
||||
@@ -814,17 +834,21 @@ class BGSDataset(RDFGraphDataset):
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
print_every: int
|
||||
print_every : int
|
||||
Preprocessing log for every X tuples. Default: 10000.
|
||||
insert_reverse: bool
|
||||
insert_reverse : bool
|
||||
If true, add reverse edge and reverse relations to the final graph. Default: True.
|
||||
raw_dir : str
|
||||
Raw file directory to download/contains the input data directory.
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
Whether to print out progress information. Default: True.
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -854,7 +878,8 @@ class BGSDataset(RDFGraphDataset):
|
||||
insert_reverse=True,
|
||||
raw_dir=None,
|
||||
force_reload=False,
|
||||
verbose=True):
|
||||
verbose=True,
|
||||
transform=None):
|
||||
import rdflib as rdf
|
||||
url = _get_dgl_url('dataset/rdf/bgs-hetero.zip')
|
||||
name = 'bgs-hetero'
|
||||
@@ -865,7 +890,8 @@ class BGSDataset(RDFGraphDataset):
|
||||
insert_reverse=insert_reverse,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r"""Gets the graph object
|
||||
@@ -964,17 +990,21 @@ class AMDataset(RDFGraphDataset):
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
print_every: int
|
||||
print_every : int
|
||||
Preprocessing log for every X tuples. Default: 10000.
|
||||
insert_reverse: bool
|
||||
insert_reverse : bool
|
||||
If true, add reverse edge and reverse relations to the final graph. Default: True.
|
||||
raw_dir : str
|
||||
Raw file directory to download/contains the input data directory.
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
Whether to print out progress information. Default: True.
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -1003,7 +1033,8 @@ class AMDataset(RDFGraphDataset):
|
||||
insert_reverse=True,
|
||||
raw_dir=None,
|
||||
force_reload=False,
|
||||
verbose=True):
|
||||
verbose=True,
|
||||
transform=None):
|
||||
import rdflib as rdf
|
||||
self.objectCategory = rdf.term.URIRef("http://purl.org/collections/nl/am/objectCategory")
|
||||
self.material = rdf.term.URIRef("http://purl.org/collections/nl/am/material")
|
||||
@@ -1015,7 +1046,8 @@ class AMDataset(RDFGraphDataset):
|
||||
insert_reverse=insert_reverse,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r"""Gets the graph object
|
||||
|
||||
@@ -84,8 +84,12 @@ class RedditDataset(DGLBuiltinDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -125,7 +129,8 @@ class RedditDataset(DGLBuiltinDataset):
|
||||
>>>
|
||||
>>> # Train, Validation and Test
|
||||
"""
|
||||
def __init__(self, self_loop=False, raw_dir=None, force_reload=False, verbose=False):
|
||||
def __init__(self, self_loop=False, raw_dir=None, force_reload=False,
|
||||
verbose=False, transform=None):
|
||||
self_loop_str = ""
|
||||
if self_loop:
|
||||
self_loop_str = "_self_loop"
|
||||
@@ -135,7 +140,8 @@ class RedditDataset(DGLBuiltinDataset):
|
||||
url=_url,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def process(self):
|
||||
# graph
|
||||
@@ -251,7 +257,10 @@ class RedditDataset(DGLBuiltinDataset):
|
||||
- ``ndata['test_mask']:`` mask for test node set
|
||||
"""
|
||||
assert idx == 0, "Reddit Dataset only has one graph"
|
||||
return self._graph
|
||||
if self._transform is None:
|
||||
return self._graph
|
||||
else:
|
||||
return self._transform(self._graph)
|
||||
|
||||
def __len__(self):
|
||||
r"""Number of graphs in the dataset"""
|
||||
|
||||
@@ -23,7 +23,7 @@ class SSTDataset(DGLBuiltinDataset):
|
||||
r"""Stanford Sentiment Treebank dataset.
|
||||
|
||||
.. deprecated:: 0.5.0
|
||||
|
||||
|
||||
- ``trees`` is deprecated, it is replaced by:
|
||||
|
||||
>>> dataset = SSTDataset()
|
||||
@@ -63,8 +63,12 @@ class SSTDataset(DGLBuiltinDataset):
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
verbose : bool
|
||||
Whether to print out progress information. Default: True.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -120,7 +124,8 @@ class SSTDataset(DGLBuiltinDataset):
|
||||
vocab_file=None,
|
||||
raw_dir=None,
|
||||
force_reload=False,
|
||||
verbose=False):
|
||||
verbose=False,
|
||||
transform=None):
|
||||
assert mode in ['train', 'dev', 'test', 'tiny']
|
||||
_url = _get_dgl_url('dataset/sst.zip')
|
||||
self._glove_embed_file = glove_embed_file if mode == 'train' else None
|
||||
@@ -130,7 +135,8 @@ class SSTDataset(DGLBuiltinDataset):
|
||||
url=_url,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
verbose=verbose,
|
||||
transform=transform)
|
||||
|
||||
def process(self):
|
||||
from nltk.corpus.reader import BracketParseCorpusReader
|
||||
@@ -255,7 +261,10 @@ class SSTDataset(DGLBuiltinDataset):
|
||||
- ``ndata['y']:`` label of the node
|
||||
- ``ndata['mask']``: 1 if the node is a leaf, otherwise 0
|
||||
"""
|
||||
return self._trees[idx]
|
||||
if self._transform is None:
|
||||
return self._trees[idx]
|
||||
else:
|
||||
return self._transform(self._trees[idx])
|
||||
|
||||
def __len__(self):
|
||||
r"""Number of graphs in the dataset."""
|
||||
|
||||
@@ -13,7 +13,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
|
||||
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
|
||||
datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.
|
||||
use_pandas : bool
|
||||
Numpy's file read function has performance issue when file is large,
|
||||
@@ -26,6 +26,10 @@ class LegacyTUDataset(DGLBuiltinDataset):
|
||||
max_allow_node : int
|
||||
Remove graphs that contains more nodes than ``max_allow_node``.
|
||||
Default : None
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -39,7 +43,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
|
||||
LegacyTUDataset uses provided node feature by default. If no feature provided, it uses one-hot node label instead.
|
||||
If neither labels provided, it uses constant for node feature.
|
||||
|
||||
The dataset sorts graphs by their labels.
|
||||
The dataset sorts graphs by their labels.
|
||||
Shuffle is preferred before manual train/val split.
|
||||
|
||||
Examples
|
||||
@@ -73,7 +77,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
|
||||
|
||||
def __init__(self, name, use_pandas=False,
|
||||
hidden_size=10, max_allow_node=None,
|
||||
raw_dir=None, force_reload=False, verbose=False):
|
||||
raw_dir=None, force_reload=False, verbose=False, transform=None):
|
||||
|
||||
url = self._url.format(name)
|
||||
self.hidden_size = hidden_size
|
||||
@@ -81,7 +85,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
|
||||
self.use_pandas = use_pandas
|
||||
super(LegacyTUDataset, self).__init__(name=name, url=url, raw_dir=raw_dir,
|
||||
hash_key=(name, use_pandas, hidden_size, max_allow_node),
|
||||
force_reload=force_reload, verbose=verbose)
|
||||
force_reload=force_reload, verbose=verbose, transform=transform)
|
||||
|
||||
def process(self):
|
||||
self.data_mode = None
|
||||
@@ -100,7 +104,7 @@ class LegacyTUDataset(DGLBuiltinDataset):
|
||||
DS_graph_labels = self._idx_from_zero(
|
||||
np.genfromtxt(self._file_path("graph_labels"), dtype=int))
|
||||
self.num_labels = max(DS_graph_labels) + 1
|
||||
self.graph_labels = DS_graph_labels
|
||||
self.graph_labels = DS_graph_labels
|
||||
elif os.path.exists(self._file_path("graph_attributes")):
|
||||
DS_graph_labels = np.genfromtxt(self._file_path("graph_attributes"), dtype=float)
|
||||
self.num_labels = None
|
||||
@@ -211,6 +215,8 @@ class LegacyTUDataset(DGLBuiltinDataset):
|
||||
And its label.
|
||||
"""
|
||||
g = self.graph_lists[idx]
|
||||
if self._transform is not None:
|
||||
g = self._transform(g)
|
||||
return g, self.graph_labels[idx]
|
||||
|
||||
def __len__(self):
|
||||
@@ -245,8 +251,12 @@ class TUDataset(DGLBuiltinDataset):
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
|
||||
Dataset Name, such as ``ENZYMES``, ``DD``, ``COLLAB``, ``MUTAG``, can be the
|
||||
datasets name on `<https://chrsmrrs.github.io/datasets/docs/datasets/>`_.
|
||||
transform : callable, optional
|
||||
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
|
||||
a transformed version. The :class:`~dgl.DGLGraph` object will be
|
||||
transformed before every access.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
@@ -271,7 +281,7 @@ class TUDataset(DGLBuiltinDataset):
|
||||
label was added so that :math:`\lbrace -1, 1 \rbrace` was mapped to
|
||||
:math:`\lbrace 0, 2 \rbrace`.
|
||||
|
||||
The dataset sorts graphs by their labels.
|
||||
The dataset sorts graphs by their labels.
|
||||
Shuffle is preferred before manual train/val split.
|
||||
|
||||
Examples
|
||||
@@ -299,32 +309,32 @@ class TUDataset(DGLBuiltinDataset):
|
||||
Graph(num_nodes=9539, num_edges=47382,
|
||||
ndata_schemes={'node_labels': Scheme(shape=(1,), dtype=torch.int64), '_ID': Scheme(shape=(), dtype=torch.int64)}
|
||||
edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})
|
||||
|
||||
|
||||
"""
|
||||
|
||||
_url = r"https://www.chrsmrrs.com/graphkerneldatasets/{}.zip"
|
||||
|
||||
def __init__(self, name, raw_dir=None, force_reload=False, verbose=False):
|
||||
def __init__(self, name, raw_dir=None, force_reload=False, verbose=False, transform=None):
|
||||
url = self._url.format(name)
|
||||
super(TUDataset, self).__init__(name=name, url=url,
|
||||
raw_dir=raw_dir, force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
|
||||
verbose=verbose, transform=transform)
|
||||
|
||||
def process(self):
|
||||
DS_edge_list = self._idx_from_zero(
|
||||
loadtxt(self._file_path("A"), delimiter=",").astype(int))
|
||||
DS_indicator = self._idx_from_zero(
|
||||
loadtxt(self._file_path("graph_indicator"), delimiter=",").astype(int))
|
||||
|
||||
|
||||
if os.path.exists(self._file_path("graph_labels")):
|
||||
DS_graph_labels = self._idx_reset(
|
||||
loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int))
|
||||
loadtxt(self._file_path("graph_labels"), delimiter=",").astype(int))
|
||||
self.num_labels = max(DS_graph_labels) + 1
|
||||
self.graph_labels = F.tensor(DS_graph_labels)
|
||||
self.graph_labels = F.tensor(DS_graph_labels)
|
||||
elif os.path.exists(self._file_path("graph_attributes")):
|
||||
DS_graph_labels = loadtxt(self._file_path("graph_attributes"), delimiter=",").astype(float)
|
||||
self.num_labels = None
|
||||
self.graph_labels = F.tensor(DS_graph_labels)
|
||||
self.graph_labels = F.tensor(DS_graph_labels)
|
||||
else:
|
||||
raise Exception("Unknown graph label or graph attributes")
|
||||
|
||||
@@ -404,6 +414,8 @@ class TUDataset(DGLBuiltinDataset):
|
||||
And its label.
|
||||
"""
|
||||
g = self.graph_lists[idx]
|
||||
if self._transform is not None:
|
||||
g = self._transform(g)
|
||||
return g, self.graph_labels[idx]
|
||||
|
||||
def __len__(self):
|
||||
|
||||
Reference in New Issue
Block a user