mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
PPIDataset
This commit is contained in:
@@ -30,34 +30,26 @@ def main(args):
|
||||
|
||||
# load and preprocess dataset
|
||||
data = load_data(args)
|
||||
g = data.g
|
||||
train_mask = g.ndata['train_mask']
|
||||
val_mask = g.ndata['val_mask']
|
||||
test_mask = g.ndata['test_mask']
|
||||
labels = g.ndata['label']
|
||||
|
||||
train_nid = np.nonzero(data.train_mask)[0].astype(np.int64)
|
||||
train_nid = np.nonzero(train_mask.data.numpy())[0].astype(np.int64)
|
||||
|
||||
# Normalize features
|
||||
if args.normalize:
|
||||
train_feats = data.features[train_nid]
|
||||
feats = g.ndata['feat']
|
||||
train_feats = feats[train_mask]
|
||||
scaler = sklearn.preprocessing.StandardScaler()
|
||||
scaler.fit(train_feats)
|
||||
features = scaler.transform(data.features)
|
||||
else:
|
||||
features = data.features
|
||||
scaler.fit(train_feats.data.numpy())
|
||||
features = scaler.transform(feats.data.numpy())
|
||||
g.ndata['feat'] = torch.FloatTensor(features)
|
||||
|
||||
features = torch.FloatTensor(features)
|
||||
if not multitask:
|
||||
labels = torch.LongTensor(data.labels)
|
||||
else:
|
||||
labels = torch.FloatTensor(data.labels)
|
||||
if hasattr(torch, 'BoolTensor'):
|
||||
train_mask = torch.BoolTensor(data.train_mask)
|
||||
val_mask = torch.BoolTensor(data.val_mask)
|
||||
test_mask = torch.BoolTensor(data.test_mask)
|
||||
else:
|
||||
train_mask = torch.ByteTensor(data.train_mask)
|
||||
val_mask = torch.ByteTensor(data.val_mask)
|
||||
test_mask = torch.ByteTensor(data.test_mask)
|
||||
in_feats = features.shape[1]
|
||||
n_classes = data.num_labels
|
||||
n_edges = data.graph.number_of_edges()
|
||||
in_feats = g.ndata['feat'].shape[1]
|
||||
n_classes = data.num_classes
|
||||
n_edges = g.number_of_edges()
|
||||
|
||||
n_train_samples = train_mask.int().sum().item()
|
||||
n_val_samples = val_mask.int().sum().item()
|
||||
@@ -74,17 +66,12 @@ def main(args):
|
||||
n_val_samples,
|
||||
n_test_samples))
|
||||
# create GCN model
|
||||
g = data.graph
|
||||
g = dgl.graph(g)
|
||||
if args.self_loop and not args.dataset.startswith('reddit'):
|
||||
g = dgl.remove_self_loop(g)
|
||||
g = dgl.add_self_loop(g)
|
||||
print("adding self-loop edges")
|
||||
# metis only support int64 graph
|
||||
g = g.long()
|
||||
g.ndata['features'] = features
|
||||
g.ndata['labels'] = labels
|
||||
g.ndata['train_mask'] = train_mask
|
||||
|
||||
cluster_iterator = ClusterIter(
|
||||
args.dataset, g, args.psize, args.batch_size, train_nid, use_pp=args.use_pp)
|
||||
@@ -99,9 +86,8 @@ def main(args):
|
||||
test_mask = test_mask.cuda()
|
||||
g = g.to(args.gpu)
|
||||
|
||||
print(torch.cuda.get_device_name(0))
|
||||
print('labels shape:', labels.shape)
|
||||
print("features shape, ", features.shape)
|
||||
print('labels shape:', g.ndata['label'].shape)
|
||||
print("features shape, ", g.ndata['feat'].shape)
|
||||
|
||||
model = GraphSAGE(in_feats,
|
||||
args.n_hidden,
|
||||
@@ -136,19 +122,20 @@ def main(args):
|
||||
# set train_nids to cuda tensor
|
||||
if cuda:
|
||||
train_nid = torch.from_numpy(train_nid).cuda()
|
||||
print("current memory after model before training",
|
||||
torch.cuda.memory_allocated(device=train_nid.device) / 1024 / 1024)
|
||||
print("current memory after model before training",
|
||||
torch.cuda.memory_allocated(device=train_nid.device) / 1024 / 1024)
|
||||
start_time = time.time()
|
||||
best_f1 = -1
|
||||
|
||||
for epoch in range(args.n_epochs):
|
||||
for j, cluster in enumerate(cluster_iterator):
|
||||
# sync with upper level training graph
|
||||
cluster = cluster.to(torch.cuda.current_device())
|
||||
if cuda:
|
||||
cluster = cluster.to(torch.cuda.current_device())
|
||||
model.train()
|
||||
# forward
|
||||
pred = model(cluster)
|
||||
batch_labels = cluster.ndata['labels']
|
||||
batch_labels = cluster.ndata['label']
|
||||
batch_train_mask = cluster.ndata['train_mask']
|
||||
loss = loss_f(pred[batch_train_mask],
|
||||
batch_labels[batch_train_mask])
|
||||
|
||||
@@ -90,7 +90,7 @@ class GraphSAGE(nn.Module):
|
||||
dropout=dropout, use_pp=False, use_lynorm=False))
|
||||
|
||||
def forward(self, g):
|
||||
h = g.ndata['features']
|
||||
h = g.ndata['feat']
|
||||
for layer in self.layers:
|
||||
h = layer(g, h)
|
||||
return h
|
||||
|
||||
@@ -57,21 +57,21 @@ class ClusterIter(object):
|
||||
def precalc(self, g):
|
||||
norm = self.get_norm(g)
|
||||
g.ndata['norm'] = norm
|
||||
features = g.ndata['features']
|
||||
features = g.ndata['feat']
|
||||
print("features shape, ", features.shape)
|
||||
with torch.no_grad():
|
||||
g.update_all(fn.copy_src(src='features', out='m'),
|
||||
fn.sum(msg='m', out='features'),
|
||||
g.update_all(fn.copy_src(src='feat', out='m'),
|
||||
fn.sum(msg='m', out='feat'),
|
||||
None)
|
||||
pre_feats = g.ndata['features'] * norm
|
||||
pre_feats = g.ndata['feat'] * norm
|
||||
# use graphsage embedding aggregation style
|
||||
g.ndata['features'] = torch.cat([features, pre_feats], dim=1)
|
||||
g.ndata['feat'] = torch.cat([features, pre_feats], dim=1)
|
||||
|
||||
# use one side normalization
|
||||
def get_norm(self, g):
|
||||
norm = 1. / g.in_degrees().float().unsqueeze(1)
|
||||
norm[torch.isinf(norm)] = 0
|
||||
norm = norm.to(self.g.ndata['features'].device)
|
||||
norm = norm.to(self.g.ndata['feat'].device)
|
||||
return norm
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -60,22 +60,23 @@ def evaluate(model, g, labels, mask, multitask=False):
|
||||
|
||||
def load_data(args):
|
||||
'''Wraps the dgl's load_data utility to handle ppi special case'''
|
||||
DataType = namedtuple('Dataset', ['num_classes', 'g'])
|
||||
if args.dataset != 'ppi':
|
||||
return _load_data(args)
|
||||
dataset = _load_data(args)
|
||||
data = DataType(g=dataset[0], num_classes=dataset.num_classes)
|
||||
return data
|
||||
train_dataset = PPIDataset('train')
|
||||
train_graph = dgl.batch([train_dataset[i] for i in range(len(train_dataset))], edge_attrs=None, node_attrs=None)
|
||||
val_dataset = PPIDataset('valid')
|
||||
val_graph = dgl.batch([val_dataset[i] for i in range(len(val_dataset))], edge_attrs=None, node_attrs=None)
|
||||
test_dataset = PPIDataset('test')
|
||||
PPIDataType = namedtuple('PPIDataset', ['train_mask', 'test_mask',
|
||||
'val_mask', 'features', 'labels', 'num_labels', 'graph'])
|
||||
test_graph = dgl.batch([test_dataset[i] for i in range(len(test_dataset))], edge_attrs=None, node_attrs=None)
|
||||
G = dgl.batch(
|
||||
[train_dataset.graph, val_dataset.graph, test_dataset.graph], edge_attrs=None, node_attrs=None)
|
||||
G = G.to_networkx()
|
||||
# hack to dodge the potential bugs of to_networkx
|
||||
for (n1, n2, d) in G.edges(data=True):
|
||||
d.clear()
|
||||
train_nodes_num = train_dataset.graph.number_of_nodes()
|
||||
test_nodes_num = test_dataset.graph.number_of_nodes()
|
||||
val_nodes_num = val_dataset.graph.number_of_nodes()
|
||||
[train_graph, val_graph, test_graph], edge_attrs=None, node_attrs=None)
|
||||
|
||||
train_nodes_num = train_graph.number_of_nodes()
|
||||
test_nodes_num = test_graph.number_of_nodes()
|
||||
val_nodes_num = val_graph.number_of_nodes()
|
||||
nodes_num = G.number_of_nodes()
|
||||
assert(nodes_num == (train_nodes_num + test_nodes_num + val_nodes_num))
|
||||
# construct mask
|
||||
@@ -87,13 +88,9 @@ def load_data(args):
|
||||
test_mask = mask.copy()
|
||||
test_mask[-test_nodes_num:] = True
|
||||
|
||||
# construct features
|
||||
features = np.concatenate(
|
||||
[train_dataset.features, val_dataset.features, test_dataset.features], axis=0)
|
||||
G.ndata['train_mask'] = torch.tensor(train_mask, dtype=torch.bool)
|
||||
G.ndata['val_mask'] = torch.tensor(val_mask, dtype=torch.bool)
|
||||
G.ndata['test_mask'] = torch.tensor(test_mask, dtype=torch.bool)
|
||||
|
||||
labels = np.concatenate(
|
||||
[train_dataset.labels, val_dataset.labels, test_dataset.labels], axis=0)
|
||||
|
||||
data = PPIDataType(graph=G, train_mask=train_mask, test_mask=test_mask,
|
||||
val_mask=val_mask, features=features, labels=labels, num_labels=121)
|
||||
data = DataType(g=G, num_classes=train_dataset.num_labels)
|
||||
return data
|
||||
|
||||
@@ -17,15 +17,12 @@ import torch.nn.functional as F
|
||||
import argparse
|
||||
from sklearn.metrics import f1_score
|
||||
from gat import GAT
|
||||
from dgl.data.ppi import LegacyPPIDataset
|
||||
from dgl.data.ppi import PPIDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
def collate(sample):
|
||||
graphs, feats, labels =map(list, zip(*sample))
|
||||
def collate(graphs):
|
||||
graph = dgl.batch(graphs)
|
||||
feats = torch.from_numpy(np.concatenate(feats))
|
||||
labels = torch.from_numpy(np.concatenate(labels))
|
||||
return graph, feats, labels
|
||||
return graph
|
||||
|
||||
def evaluate(feats, model, subgraph, labels, loss_fcn):
|
||||
with torch.no_grad():
|
||||
@@ -54,15 +51,15 @@ def main(args):
|
||||
# define loss function
|
||||
loss_fcn = torch.nn.BCEWithLogitsLoss()
|
||||
# create the dataset
|
||||
train_dataset = LegacyPPIDataset(mode='train')
|
||||
valid_dataset = LegacyPPIDataset(mode='valid')
|
||||
test_dataset = LegacyPPIDataset(mode='test')
|
||||
train_dataset = PPIDataset(mode='train')
|
||||
valid_dataset = PPIDataset(mode='valid')
|
||||
test_dataset = PPIDataset(mode='test')
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate)
|
||||
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, collate_fn=collate)
|
||||
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate)
|
||||
n_classes = train_dataset.labels.shape[1]
|
||||
num_feats = train_dataset.features.shape[1]
|
||||
g = train_dataset.graph
|
||||
g = train_dataset[0]
|
||||
n_classes = train_dataset.num_labels
|
||||
num_feats = g.ndata['feat'].shape[1]
|
||||
g = g.to(device)
|
||||
heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
|
||||
# define the model
|
||||
@@ -83,16 +80,13 @@ def main(args):
|
||||
for epoch in range(args.epochs):
|
||||
model.train()
|
||||
loss_list = []
|
||||
for batch, data in enumerate(train_dataloader):
|
||||
subgraph, feats, labels = data
|
||||
for batch, subgraph in enumerate(train_dataloader):
|
||||
subgraph = subgraph.to(device)
|
||||
feats = feats.to(device)
|
||||
labels = labels.to(device)
|
||||
model.g = subgraph
|
||||
for layer in model.gat_layers:
|
||||
layer.g = subgraph
|
||||
logits = model(feats.float())
|
||||
loss = loss_fcn(logits, labels.float())
|
||||
logits = model(subgraph.ndata['feat'].float())
|
||||
loss = loss_fcn(logits, subgraph.ndata['label'])
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
@@ -1,154 +1,158 @@
|
||||
"""PPI Dataset.
|
||||
(zhang hao): Used for inductive learning.
|
||||
"""
|
||||
""" PPIDataset for inductive learning. """
|
||||
import json
|
||||
import numpy as np
|
||||
import networkx as nx
|
||||
from networkx.readwrite import json_graph
|
||||
import os
|
||||
|
||||
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
|
||||
from ..utils import retry_method_with_fix
|
||||
from .dgl_dataset import DGLBuiltinDataset
|
||||
from .utils import _get_dgl_url, save_graphs, save_info, load_info, load_graphs, deprecate_property
|
||||
from .. import backend as F
|
||||
from ..convert import from_networkx
|
||||
|
||||
_url = 'dataset/ppi.zip'
|
||||
|
||||
class PPIDataset(DGLBuiltinDataset):
|
||||
""" Protein-Protein Interaction dataset for inductive node classification
|
||||
|
||||
class PPIDataset(object):
|
||||
"""A toy Protein-Protein Interaction network dataset.
|
||||
A toy Protein-Protein Interaction network dataset. The dataset contains
|
||||
24 graphs. The average number of nodes per graph is 2372. Each node has
|
||||
50 features and 121 labels. 20 graphs for training, 2 for validation
|
||||
and 2 for testing.
|
||||
|
||||
Adapted from https://github.com/williamleif/GraphSAGE/tree/master/example_data.
|
||||
Reference: http://snap.stanford.edu/graphsage/
|
||||
|
||||
The dataset contains 24 graphs. The average number of nodes per graph
|
||||
is 2372. Each node has 50 features and 121 labels.
|
||||
Statistics
|
||||
----------
|
||||
Train examples: 20
|
||||
Valid examples: 2
|
||||
Test examples: 2
|
||||
|
||||
We use 20 graphs for training, 2 for validation and 2 for testing.
|
||||
Parameters
|
||||
----------
|
||||
mode : str
|
||||
Must be one of ('train', 'valid', 'test'). Default: 'train'
|
||||
raw_dir : str
|
||||
Raw file directory to download/contains the input data directory.
|
||||
Default: ~/.dgl/
|
||||
force_reload : bool
|
||||
Whether to reload the dataset. Default: False
|
||||
verbose: bool
|
||||
Whether to print out progress information. Default: True.
|
||||
|
||||
Returns
|
||||
-------
|
||||
PPIDataset object with two properties
|
||||
graphs: list of DGLGraph objects that contains graph structure, node features and node labels:
|
||||
- ndata['feat']: node features
|
||||
- ndata['label']: nodel labels
|
||||
graph: DGLGraph, the graph of the dataset
|
||||
num_labels: int, number of labels for each node
|
||||
Examples
|
||||
--------
|
||||
>>> data = PPIDataset(mode='valid')
|
||||
>>> num_labels = data.num_labels
|
||||
>>> for g, _ in data:
|
||||
.... feat = g.ndata['feat']
|
||||
.... label = g.ndata['label']
|
||||
.... # your code here
|
||||
"""
|
||||
def __init__(self, mode):
|
||||
"""Initialize the dataset.
|
||||
|
||||
Paramters
|
||||
---------
|
||||
mode : str
|
||||
('train', 'valid', 'test').
|
||||
"""
|
||||
def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False):
|
||||
assert mode in ['train', 'valid', 'test']
|
||||
self.mode = mode
|
||||
self._name = 'ppi'
|
||||
self._dir = get_download_dir()
|
||||
self._zip_file_path = '{}/{}.zip'.format(self._dir, self._name)
|
||||
self._load()
|
||||
self._preprocess()
|
||||
_url = _get_dgl_url('dataset/ppi.zip')
|
||||
super(PPIDataset, self).__init__(name='ppi',
|
||||
url=_url,
|
||||
raw_dir=raw_dir,
|
||||
force_reload=force_reload,
|
||||
verbose=verbose)
|
||||
|
||||
def _download(self):
|
||||
download(_get_dgl_url(_url), path=self._zip_file_path)
|
||||
extract_archive(self._zip_file_path,
|
||||
'{}/{}'.format(self._dir, self._name))
|
||||
def process(self):
|
||||
graph_file = os.path.join(self.save_path, '{}_graph.json'.format(self.mode))
|
||||
label_file = os.path.join(self.save_path, '{}_labels.npy'.format(self.mode))
|
||||
feat_file = os.path.join(self.save_path, '{}_feats.npy'.format(self.mode))
|
||||
graph_id_file = os.path.join(self.save_path, '{}_graph_id.npy'.format(self.mode))
|
||||
|
||||
@retry_method_with_fix(_download)
|
||||
def _load(self):
|
||||
"""Loads input data.
|
||||
g_data = json.load(open(graph_file))
|
||||
self._labels = np.load(label_file)
|
||||
self._feats = np.load(feat_file)
|
||||
self.graph = from_networkx(nx.DiGraph(json_graph.node_link_graph(g_data)))
|
||||
graph_id = np.load(graph_id_file)
|
||||
|
||||
train/test/valid_graph.json => the graph data used for training,
|
||||
test and validation as json format;
|
||||
train/test/valid_feats.npy => the feature vectors of nodes as
|
||||
numpy.ndarry object, it's shape is [n, v],
|
||||
n is the number of nodes, v is the feature's dimension;
|
||||
train/test/valid_labels.npy=> the labels of the input nodes, it
|
||||
is a numpy ndarry, it's like[[0, 0, 1, ... 0],
|
||||
[0, 1, 1, 0 ...1]], shape of it is n*h, n is the number of nodes,
|
||||
h is the label's dimension;
|
||||
train/test/valid/_graph_id.npy => the element in it indicates which
|
||||
graph the nodes belong to, it is a one dimensional numpy.ndarray
|
||||
object and the length of it is equal the number of nodes,
|
||||
it's like [1, 1, 2, 1...20].
|
||||
"""
|
||||
print('Loading G...')
|
||||
if self.mode == 'train':
|
||||
with open('{}/ppi/train_graph.json'.format(self._dir)) as jsonfile:
|
||||
g_data = json.load(jsonfile)
|
||||
self.labels = np.load('{}/ppi/train_labels.npy'.format(self._dir))
|
||||
self.features = np.load('{}/ppi/train_feats.npy'.format(self._dir))
|
||||
self.graph = from_networkx(nx.DiGraph(json_graph.node_link_graph(g_data)))
|
||||
self.graph_id = np.load('{}/ppi/train_graph_id.npy'.format(self._dir))
|
||||
# lo, hi means the range of graph ids for different portion of the dataset,
|
||||
# 20 graphs for training, 2 for validation and 2 for testing.
|
||||
lo, hi = 1, 21
|
||||
if self.mode == 'valid':
|
||||
with open('{}/ppi/valid_graph.json'.format(self._dir)) as jsonfile:
|
||||
g_data = json.load(jsonfile)
|
||||
self.labels = np.load('{}/ppi/valid_labels.npy'.format(self._dir))
|
||||
self.features = np.load('{}/ppi/valid_feats.npy'.format(self._dir))
|
||||
self.graph = from_networkx(nx.DiGraph(json_graph.node_link_graph(g_data)))
|
||||
self.graph_id = np.load('{}/ppi/valid_graph_id.npy'.format(self._dir))
|
||||
if self.mode == 'test':
|
||||
with open('{}/ppi/test_graph.json'.format(self._dir)) as jsonfile:
|
||||
g_data = json.load(jsonfile)
|
||||
self.labels = np.load('{}/ppi/test_labels.npy'.format(self._dir))
|
||||
self.features = np.load('{}/ppi/test_feats.npy'.format(self._dir))
|
||||
self.graph = from_networkx(nx.DiGraph(json_graph.node_link_graph(g_data)))
|
||||
self.graph_id = np.load('{}/ppi/test_graph_id.npy'.format(self._dir))
|
||||
lo, hi = 21, 23
|
||||
elif self.mode == 'test':
|
||||
lo, hi = 23, 25
|
||||
|
||||
def _preprocess(self):
|
||||
if self.mode == 'train':
|
||||
self.train_mask_list = []
|
||||
self.train_graphs = []
|
||||
self.train_labels = []
|
||||
for train_graph_id in range(1, 21):
|
||||
train_graph_mask = np.where(self.graph_id == train_graph_id)[0]
|
||||
self.train_mask_list.append(train_graph_mask)
|
||||
self.train_graphs.append(self.graph.subgraph(train_graph_mask))
|
||||
self.train_labels.append(self.labels[train_graph_mask])
|
||||
if self.mode == 'valid':
|
||||
self.valid_mask_list = []
|
||||
self.valid_graphs = []
|
||||
self.valid_labels = []
|
||||
for valid_graph_id in range(21, 23):
|
||||
valid_graph_mask = np.where(self.graph_id == valid_graph_id)[0]
|
||||
self.valid_mask_list.append(valid_graph_mask)
|
||||
self.valid_graphs.append(self.graph.subgraph(valid_graph_mask))
|
||||
self.valid_labels.append(self.labels[valid_graph_mask])
|
||||
if self.mode == 'test':
|
||||
self.test_mask_list = []
|
||||
self.test_graphs = []
|
||||
self.test_labels = []
|
||||
for test_graph_id in range(23, 25):
|
||||
test_graph_mask = np.where(self.graph_id == test_graph_id)[0]
|
||||
self.test_mask_list.append(test_graph_mask)
|
||||
self.test_graphs.append(self.graph.subgraph(test_graph_mask))
|
||||
self.test_labels.append(self.labels[test_graph_mask])
|
||||
graph_masks = []
|
||||
self.graphs = []
|
||||
for g_id in range(lo, hi):
|
||||
g_mask = np.where(graph_id == g_id)[0]
|
||||
graph_masks.append(g_mask)
|
||||
g = self.graph.subgraph(g_mask)
|
||||
g.ndata['feat'] = F.tensor(self._feats[g_mask], dtype=F.data_type_dict['float32'])
|
||||
g.ndata['label'] = F.tensor(self._labels[g_mask], dtype=F.data_type_dict['float32'])
|
||||
self.graphs.append(g)
|
||||
|
||||
def has_cache(self):
|
||||
graph_list_path = os.path.join(self.save_path, '{}_dgl_graph_list.bin'.format(self.mode))
|
||||
g_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode))
|
||||
info_path = os.path.join(self.save_path, '{}_info.pkl'.format(self.mode))
|
||||
return os.path.exists(graph_list_path) and os.path.exists(g_path) and os.path.exists(info_path)
|
||||
|
||||
def save(self):
|
||||
graph_list_path = os.path.join(self.save_path, '{}_dgl_graph_list.bin'.format(self.mode))
|
||||
g_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode))
|
||||
info_path = os.path.join(self.save_path, '{}_info.pkl'.format(self.mode))
|
||||
save_graphs(graph_list_path, self.graphs)
|
||||
save_graphs(g_path, self.graph)
|
||||
save_info(info_path, {'labels': self._labels, 'feats': self._feats})
|
||||
|
||||
def load(self):
|
||||
graph_list_path = os.path.join(self.save_path, '{}_dgl_graph_list.bin'.format(self.mode))
|
||||
g_path = os.path.join(self.save_path, '{}_dgl_graph.bin'.format(self.mode))
|
||||
info_path = os.path.join(self.save_path, '{}_info.pkl'.format(self.mode))
|
||||
self.graphs = load_graphs(graph_list_path)[0]
|
||||
g, _ = load_graphs(g_path)
|
||||
self.graph = g[0]
|
||||
info = load_info(info_path)
|
||||
self._labels = info['labels']
|
||||
self._feats = info['feats']
|
||||
|
||||
@property
|
||||
def num_labels(self):
|
||||
return 121
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
deprecate_property('dataset.labels', 'dataset.graphs[i].ndata[\'label\']')
|
||||
return self._labels
|
||||
|
||||
@property
|
||||
def features(self):
|
||||
deprecate_property('dataset.features', 'dataset.graphs[i].ndata[\'feat\']')
|
||||
return self._feats
|
||||
|
||||
def __len__(self):
|
||||
"""Return number of samples in this dataset."""
|
||||
if self.mode == 'train':
|
||||
return len(self.train_mask_list)
|
||||
if self.mode == 'valid':
|
||||
return len(self.valid_mask_list)
|
||||
if self.mode == 'test':
|
||||
return len(self.test_mask_list)
|
||||
return len(self.graphs)
|
||||
|
||||
def __getitem__(self, item):
|
||||
"""Get the i^th sample.
|
||||
|
||||
Paramters
|
||||
Parameters
|
||||
---------
|
||||
idx : int
|
||||
item : int
|
||||
The sample index.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(dgl.DGLGraph, ndarray)
|
||||
The graph, and its label.
|
||||
dgl.DGLGraph
|
||||
The graph
|
||||
"""
|
||||
if self.mode == 'train':
|
||||
g = self.train_graphs[item]
|
||||
g.ndata['feat'] = self.features[self.train_mask_list[item]]
|
||||
label = self.train_labels[item]
|
||||
elif self.mode == 'valid':
|
||||
g = self.valid_graphs[item]
|
||||
g.ndata['feat'] = self.features[self.valid_mask_list[item]]
|
||||
label = self.valid_labels[item]
|
||||
elif self.mode == 'test':
|
||||
g = self.test_graphs[item]
|
||||
g.ndata['feat'] = self.features[self.test_mask_list[item]]
|
||||
label = self.test_labels[item]
|
||||
return g, label
|
||||
return self.graphs[item]
|
||||
|
||||
|
||||
class LegacyPPIDataset(PPIDataset):
|
||||
@@ -168,9 +172,5 @@ class LegacyPPIDataset(PPIDataset):
|
||||
(dgl.DGLGraph, ndarray, ndarray)
|
||||
The graph, features and its label.
|
||||
"""
|
||||
if self.mode == 'train':
|
||||
return self.train_graphs[item], self.features[self.train_mask_list[item]], self.train_labels[item]
|
||||
if self.mode == 'valid':
|
||||
return self.valid_graphs[item], self.features[self.valid_mask_list[item]], self.valid_labels[item]
|
||||
if self.mode == 'test':
|
||||
return self.test_graphs[item], self.features[self.test_mask_list[item]], self.test_labels[item]
|
||||
|
||||
return self.graphs[item], self.graphs[item].ndata['feat'], self.graphs[item].ndata['label']
|
||||
|
||||
Reference in New Issue
Block a user