Fix num_labels to num_classes in dataset files (#6666)

This commit is contained in:
Zhen Liu
2023-12-04 13:25:44 +08:00
committed by GitHub
parent 5e64481be3
commit 15d05be387
23 changed files with 40 additions and 21 deletions

View File

@@ -93,7 +93,7 @@ def main(args):
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = data.graph.number_of_edges()
print(
"""----Data statistics------'

View File

@@ -69,7 +69,7 @@ def main(args):
test_mask = g.ndata["test_mask"]
test_mask = mx.nd.array(np.nonzero(test_mask.asnumpy())[0], ctx=ctx)
in_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = data.graph.number_of_edges()
g = dgl.remove_self_loop(g)

View File

@@ -94,7 +94,7 @@ def main(args):
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = data.graph.number_of_edges()
print(
"""----Data statistics------'

View File

@@ -46,7 +46,7 @@ def main(args):
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = data.graph.number_of_edges()
print(
"""----Data statistics------'

View File

@@ -111,7 +111,7 @@ def main(args):
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = data.graph.number_of_edges()
print(
"""----Data statistics------'

View File

@@ -93,7 +93,7 @@ def main(args):
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = data.graph.number_of_edges()
print(
"""----Data statistics------'

View File

@@ -56,7 +56,7 @@ def main(args):
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = data.graph.number_of_edges()
print(
"""----Data statistics------'

View File

@@ -48,7 +48,7 @@ def main(args):
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = data.graph.number_of_edges()
print(
"""----Data statistics------'

View File

@@ -86,7 +86,7 @@ class QM9(QM9Dataset):
Examples
--------
>>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0)
>>> data.num_labels
>>> data.num_classes
2
>>>
>>> # iterate over the dataset

View File

@@ -116,7 +116,7 @@ if __name__ == "__main__":
# create GAT model
in_size = features.shape[1]
out_size = train_dataset.num_labels
out_size = train_dataset.num_classes
model = GAT(in_size, 256, out_size, heads=[4, 4, 6]).to(device)
# model training

View File

@@ -49,7 +49,7 @@ def main(args):
else:
device = "cpu"
num_classes = train_dataset.num_labels
num_classes = train_dataset.num_classes
# Extract node features
graph = train_dataset[0]

View File

@@ -59,7 +59,7 @@ def main(args):
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = g.num_edges()
print(
"""----Data statistics------'

View File

@@ -51,7 +51,7 @@ def main(args):
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = g.num_edges()
print(
"""----Data statistics------'

View File

@@ -28,7 +28,7 @@ def load_dataset(name):
data = CitationGraphDataset("cora")
g = data[0]
n_classes = data.num_labels
n_classes = data.num_classes
train_mask = g.ndata["train_mask"]
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]

View File

@@ -38,7 +38,7 @@ def main(args):
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = g.num_edges()
print(
"""----Data statistics------'

View File

@@ -50,7 +50,7 @@ def main(args):
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = g.number_of_edges()
# add self loop

View File

@@ -66,7 +66,7 @@ def main(args):
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
num_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = g.number_of_edges()
print(
"""----Data statistics------'

View File

@@ -114,7 +114,7 @@ def main(args):
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = data.graph.number_of_edges()
print(
"""----Data statistics------'

View File

@@ -121,7 +121,7 @@ def main(args):
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = data.graph.number_of_edges()
print(
"""----Data statistics------'

View File

@@ -43,7 +43,7 @@ def main(args):
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_classes = data.num_classes
n_edges = g.number_of_edges()
print(
"""----Data statistics------'

View File

@@ -59,7 +59,7 @@ class PPIDataset(DGLBuiltinDataset):
Examples
--------
>>> dataset = PPIDataset(mode='valid')
>>> num_labels = dataset.num_labels
>>> num_classes = dataset.num_classes
>>> for g in dataset:
.... feat = g.ndata['feat']
.... label = g.ndata['label']
@@ -173,6 +173,10 @@ class PPIDataset(DGLBuiltinDataset):
def num_labels(self):
return 121
@property
def num_classes(self):
return 121
def __len__(self):
"""Return number of samples in this dataset."""
return len(self.graphs)

View File

@@ -141,6 +141,11 @@ class QM7bDataset(DGLDataset):
"""Number of prediction tasks."""
return 14
@property
def num_classes(self):
"""Number of prediction tasks."""
return 14
def __getitem__(self, idx):
r"""Get graph and label by index

View File

@@ -157,6 +157,16 @@ class QM9Dataset(DGLDataset):
"""
return self.label.shape[1]
@property
def num_classes(self):
r"""
Returns
--------
int
Number of prediction tasks.
"""
return self.label.shape[1]
@property
def num_tasks(self):
r"""