mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-06 20:04:24 +08:00
* auto-reformat * lintrunner --------- Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
256 lines
8.4 KiB
Python
256 lines
8.4 KiB
Python
"""
|
||
Make Your Own Dataset
|
||
=====================
|
||
|
||
This tutorial assumes that you already know :doc:`the basics of training a
|
||
GNN for node classification <1_introduction>` and :doc:`how to
|
||
create, load, and store a DGL graph <2_dglgraph>`.
|
||
|
||
By the end of this tutorial, you will be able to
|
||
|
||
- Create your own graph dataset for node classification, link
|
||
prediction, or graph classification.
|
||
|
||
(Time estimate: 15 minutes)
|
||
"""
|
||
|
||
|
||
######################################################################
|
||
# ``DGLDataset`` Object Overview
|
||
# ------------------------------
|
||
#
|
||
# Your custom graph dataset should inherit the ``dgl.data.DGLDataset``
|
||
# class and implement the following methods:
|
||
#
|
||
# - ``__getitem__(self, i)``: retrieve the ``i``-th example of the
|
||
# dataset. An example often contains a single DGL graph, and
|
||
# occasionally its label.
|
||
# - ``__len__(self)``: the number of examples in the dataset.
|
||
# - ``process(self)``: load and process raw data from disk.
|
||
#
|
||
|
||
|
||
######################################################################
|
||
# Creating a Dataset for Node Classification or Link Prediction from CSV
|
||
# ----------------------------------------------------------------------
|
||
#
|
||
# A node classification dataset often consists of a single graph, as well
|
||
# as its node and edge features.
|
||
#
|
||
# This tutorial takes a small dataset based on `Zachary’s Karate Club
|
||
# network <https://en.wikipedia.org/wiki/Zachary%27s_karate_club>`__. It
|
||
# contains
|
||
#
|
||
# * A ``members.csv`` file containing the list of all
|
||
# members, as well as their attributes.
|
||
#
|
||
# * An ``interactions.csv`` file
|
||
# containing the pair-wise interactions between two club members.
|
||
#
|
||
|
||
import urllib.request
|
||
|
||
import pandas as pd
|
||
|
||
# Download the two CSV files of the Karate Club dataset into the current
# working directory.
urllib.request.urlretrieve(
    url="https://data.dgl.ai/tutorial/dataset/members.csv",
    filename="./members.csv",
)
urllib.request.urlretrieve(
    url="https://data.dgl.ai/tutorial/dataset/interactions.csv",
    filename="./interactions.csv",
)

# Member table: later used as the node table of the graph.
members = pd.read_csv("./members.csv")
members.head()

# Interaction table: later used as the edge table of the graph.
interactions = pd.read_csv("./interactions.csv")
interactions.head()
|
||
|
||
|
||
######################################################################
|
||
# This tutorial treats the members as nodes and interactions as edges. It
|
||
# takes age as a numeric feature of the nodes, affiliated club as the label
|
||
# of the nodes, and edge weight as a numeric feature of the edges.
|
||
#
|
||
# .. note::
|
||
#
|
||
# The original Zachary’s Karate Club network does not have
|
||
# member ages. The ages in this tutorial are generated synthetically
|
||
# for demonstrating how to add node features into the graph for dataset
|
||
# creation.
|
||
#
|
||
# .. note::
|
||
#
|
||
# In practice, taking age directly as a numeric feature may
|
||
# not work well in machine learning; strategies like binning or
|
||
# normalizing the feature would work better. This tutorial directly
|
||
# takes the values as-is for simplicity.
|
||
#
|
||
|
||
import os
|
||
|
||
os.environ["DGLBACKEND"] = "pytorch"
|
||
import dgl
|
||
import torch
|
||
from dgl.data import DGLDataset
|
||
|
||
|
||
class KarateClubDataset(DGLDataset):
    """Single-graph dataset built from the Karate Club CSV files.

    The graph is assembled from ``./members.csv`` (one row per node) and
    ``./interactions.csv`` (one row per edge), both read from the current
    working directory.  After processing, ``self.graph`` carries:

    - ``ndata["feat"]``: the ``Age`` column of the members table.
    - ``ndata["label"]``: the ``Club`` column encoded as category codes.
    - ``edata["weight"]``: the ``Weight`` column of the interactions table.
    - ``ndata["train_mask"]``, ``ndata["val_mask"]``, ``ndata["test_mask"]``:
      boolean masks that split the nodes, in row order, into the first 60%,
      the next 20%, and the remainder.
    """

    def __init__(self):
        super().__init__(name="karate_club")

    def process(self):
        """Load the CSV files and build ``self.graph`` with its features."""
        nodes_data = pd.read_csv("./members.csv")
        edges_data = pd.read_csv("./interactions.csv")
        n_nodes = nodes_data.shape[0]

        node_features = torch.from_numpy(nodes_data["Age"].to_numpy())
        # Club names are strings; turn them into integer class IDs.
        node_labels = torch.from_numpy(
            nodes_data["Club"].astype("category").cat.codes.to_numpy()
        )
        edge_features = torch.from_numpy(edges_data["Weight"].to_numpy())
        edges_src = torch.from_numpy(edges_data["Src"].to_numpy())
        edges_dst = torch.from_numpy(edges_data["Dst"].to_numpy())

        # Pass the node count explicitly so that members without any
        # interaction still appear as nodes of the graph.
        self.graph = dgl.graph((edges_src, edges_dst), num_nodes=n_nodes)
        self.graph.ndata["feat"] = node_features
        self.graph.ndata["label"] = node_labels
        self.graph.edata["weight"] = edge_features

        # If your dataset is a node classification dataset, you will need to
        # assign masks indicating whether a node belongs to the training,
        # validation, or test set.
        n_train = int(n_nodes * 0.6)
        n_val = int(n_nodes * 0.2)
        train_mask = torch.zeros(n_nodes, dtype=torch.bool)
        val_mask = torch.zeros(n_nodes, dtype=torch.bool)
        test_mask = torch.zeros(n_nodes, dtype=torch.bool)
        train_mask[:n_train] = True
        val_mask[n_train : n_train + n_val] = True
        # Everything after the training and validation slices is test data.
        test_mask[n_train + n_val :] = True
        self.graph.ndata["train_mask"] = train_mask
        self.graph.ndata["val_mask"] = val_mask
        self.graph.ndata["test_mask"] = test_mask

    def __getitem__(self, i):
        """Return the only graph of the dataset.

        Raises ``IndexError`` for any index other than ``0`` (or ``-1``), so
        that out-of-range access is reported and iterating over the dataset
        terminates after the single example.
        """
        if i not in (0, -1):
            raise IndexError(
                f"index {i} is out of range for a dataset with 1 graph"
            )
        return self.graph

    def __len__(self):
        """Return the number of examples, which is always one graph."""
        return 1
|
||
|
||
|
||
dataset = KarateClubDataset()
# The dataset contains a single example: the whole Karate Club graph.
graph = dataset[0]

print(graph)
|
||
|
||
|
||
######################################################################
|
||
# Since a link prediction dataset only involves a single graph, preparing
|
||
# a link prediction dataset follows the same procedure as preparing a
|
||
# node classification dataset.
|
||
#
|
||
|
||
|
||
######################################################################
|
||
# Creating a Dataset for Graph Classification from CSV
|
||
# ----------------------------------------------------
|
||
#
|
||
# Creating a graph classification dataset involves implementing
|
||
# ``__getitem__`` to return both the graph and its graph-level label.
|
||
#
|
||
# This tutorial demonstrates how to create a graph classification dataset
|
||
# with the following synthetic CSV data:
|
||
#
|
||
# - ``graph_edges.csv``: containing three columns:
|
||
#
|
||
# - ``graph_id``: the ID of the graph.
|
||
# - ``src``: the source node of an edge of the given graph.
|
||
# - ``dst``: the destination node of an edge of the given graph.
|
||
#
|
||
# - ``graph_properties.csv``: containing three columns:
|
||
#
|
||
# - ``graph_id``: the ID of the graph.
|
||
# - ``label``: the label of the graph.
|
||
# - ``num_nodes``: the number of nodes in the graph.
|
||
#
|
||
|
||
# Download the two synthetic CSV files into the current working directory.
urllib.request.urlretrieve(
    url="https://data.dgl.ai/tutorial/dataset/graph_edges.csv",
    filename="./graph_edges.csv",
)
urllib.request.urlretrieve(
    url="https://data.dgl.ai/tutorial/dataset/graph_properties.csv",
    filename="./graph_properties.csv",
)

# Edge table and per-graph property table, loaded here for inspection.
edges = pd.read_csv("./graph_edges.csv")
properties = pd.read_csv("./graph_properties.csv")

edges.head()

properties.head()
|
||
|
||
|
||
class SyntheticDataset(DGLDataset):
    """Graph classification dataset built from two synthetic CSV files.

    ``./graph_edges.csv`` lists the edges of every graph (columns
    ``graph_id``, ``src``, ``dst``) and ``./graph_properties.csv`` gives each
    graph's ``label`` and ``num_nodes``.  Both files are read from the current
    working directory.  Each example is a ``(graph, label)`` pair.
    """

    def __init__(self):
        super().__init__(name="synthetic")

    def process(self):
        """Load the CSV files and build ``self.graphs`` and ``self.labels``."""
        edges = pd.read_csv("./graph_edges.csv")
        properties = pd.read_csv("./graph_properties.csv")
        self.graphs = []
        self.labels = []

        # Map every graph ID to its label and to its number of nodes.
        # Zipping whole columns, instead of iterating over rows, keeps each
        # column's own dtype: a row-wise view would upcast all values of a row
        # to one common dtype.
        label_dict = dict(zip(properties["graph_id"], properties["label"]))
        num_nodes_dict = dict(
            zip(properties["graph_id"], properties["num_nodes"])
        )

        # Grouping the edges table by graph ID yields one sub-table per graph.
        # Only graph IDs that occur in the edges table produce a graph.
        for graph_id, edges_of_id in edges.groupby("graph_id"):
            # Find the edges as well as the number of nodes and the label.
            src = edges_of_id["src"].to_numpy()
            dst = edges_of_id["dst"].to_numpy()
            num_nodes = num_nodes_dict[graph_id]
            label = label_dict[graph_id]

            # Create a graph and add it to the list of graphs and labels.
            g = dgl.graph((src, dst), num_nodes=num_nodes)
            self.graphs.append(g)
            self.labels.append(label)

        # Convert the label list to tensor for saving.
        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, i):
        """Return the ``i``-th graph together with its graph-level label."""
        return self.graphs[i], self.labels[i]

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return len(self.graphs)
|
||
|
||
|
||
dataset = SyntheticDataset()
# Every example of this dataset is a (graph, label) pair.
graph, label = dataset[0]
print(graph, label)
|
||
|
||
######################################################################
|
||
# Creating Dataset from CSV via :class:`~dgl.data.CSVDataset`
|
||
# ------------------------------------------------------------
|
||
#
|
||
# The previous examples describe how to create a dataset from CSV files
|
||
# step-by-step. DGL also provides a utility class :class:`~dgl.data.CSVDataset`
|
||
# for reading and parsing data from CSV files. See :ref:`guide-data-pipeline-loadcsv`
|
||
# for more details.
|
||
#
|
||
|
||
|
||
# Thumbnail credits: (Un)common Use Cases for Graph Databases, Michal Bachman
|
||
# sphinx_gallery_thumbnail_path = '_static/blitz_6_load_data.png'
|