mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
Fix num_labels to num_classes in dataset files (#6666)
This commit is contained in:
@@ -93,7 +93,7 @@ def main(args):
|
||||
val_mask = g.ndata["val_mask"]
|
||||
test_mask = g.ndata["test_mask"]
|
||||
in_feats = features.shape[1]
|
||||
n_classes = data.num_labels
|
||||
n_classes = data.num_classes
|
||||
n_edges = data.graph.number_of_edges()
|
||||
print(
|
||||
"""----Data statistics------'
|
||||
|
||||
@@ -69,7 +69,7 @@ def main(args):
|
||||
test_mask = g.ndata["test_mask"]
|
||||
test_mask = mx.nd.array(np.nonzero(test_mask.asnumpy())[0], ctx=ctx)
|
||||
in_feats = features.shape[1]
|
||||
n_classes = data.num_labels
|
||||
n_classes = data.num_classes
|
||||
n_edges = data.graph.number_of_edges()
|
||||
|
||||
g = dgl.remove_self_loop(g)
|
||||
|
||||
@@ -94,7 +94,7 @@ def main(args):
|
||||
val_mask = g.ndata["val_mask"]
|
||||
test_mask = g.ndata["test_mask"]
|
||||
in_feats = features.shape[1]
|
||||
n_classes = data.num_labels
|
||||
n_classes = data.num_classes
|
||||
n_edges = data.graph.number_of_edges()
|
||||
print(
|
||||
"""----Data statistics------'
|
||||
|
||||
@@ -46,7 +46,7 @@ def main(args):
|
||||
val_mask = g.ndata["val_mask"]
|
||||
test_mask = g.ndata["test_mask"]
|
||||
in_feats = features.shape[1]
|
||||
n_classes = data.num_labels
|
||||
n_classes = data.num_classes
|
||||
n_edges = data.graph.number_of_edges()
|
||||
print(
|
||||
"""----Data statistics------'
|
||||
|
||||
@@ -111,7 +111,7 @@ def main(args):
|
||||
val_mask = g.ndata["val_mask"]
|
||||
test_mask = g.ndata["test_mask"]
|
||||
in_feats = features.shape[1]
|
||||
n_classes = data.num_labels
|
||||
n_classes = data.num_classes
|
||||
n_edges = data.graph.number_of_edges()
|
||||
print(
|
||||
"""----Data statistics------'
|
||||
|
||||
@@ -93,7 +93,7 @@ def main(args):
|
||||
val_mask = g.ndata["val_mask"]
|
||||
test_mask = g.ndata["test_mask"]
|
||||
in_feats = features.shape[1]
|
||||
n_classes = data.num_labels
|
||||
n_classes = data.num_classes
|
||||
n_edges = data.graph.number_of_edges()
|
||||
print(
|
||||
"""----Data statistics------'
|
||||
|
||||
@@ -56,7 +56,7 @@ def main(args):
|
||||
val_mask = g.ndata["val_mask"]
|
||||
test_mask = g.ndata["test_mask"]
|
||||
in_feats = features.shape[1]
|
||||
n_classes = data.num_labels
|
||||
n_classes = data.num_classes
|
||||
n_edges = data.graph.number_of_edges()
|
||||
print(
|
||||
"""----Data statistics------'
|
||||
|
||||
@@ -48,7 +48,7 @@ def main(args):
|
||||
val_mask = g.ndata["val_mask"]
|
||||
test_mask = g.ndata["test_mask"]
|
||||
in_feats = features.shape[1]
|
||||
n_classes = data.num_labels
|
||||
n_classes = data.num_classes
|
||||
n_edges = data.graph.number_of_edges()
|
||||
print(
|
||||
"""----Data statistics------'
|
||||
|
||||
@@ -86,7 +86,7 @@ class QM9(QM9Dataset):
|
||||
Examples
|
||||
--------
|
||||
>>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0)
|
||||
>>> data.num_labels
|
||||
>>> data.num_classes
|
||||
2
|
||||
>>>
|
||||
>>> # iterate over the dataset
|
||||
|
||||
@@ -116,7 +116,7 @@ if __name__ == "__main__":
|
||||
|
||||
# create GAT model
|
||||
in_size = features.shape[1]
|
||||
out_size = train_dataset.num_labels
|
||||
out_size = train_dataset.num_classes
|
||||
model = GAT(in_size, 256, out_size, heads=[4, 4, 6]).to(device)
|
||||
|
||||
# model training
|
||||
|
||||
@@ -49,7 +49,7 @@ def main(args):
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
num_classes = train_dataset.num_labels
|
||||
num_classes = train_dataset.num_classes
|
||||
|
||||
# Extract node features
|
||||
graph = train_dataset[0]
|
||||
|
||||
@@ -59,7 +59,7 @@ def main(args):
|
||||
val_mask = g.ndata["val_mask"]
|
||||
test_mask = g.ndata["test_mask"]
|
||||
in_feats = features.shape[1]
|
||||
n_classes = data.num_labels
|
||||
n_classes = data.num_classes
|
||||
n_edges = g.num_edges()
|
||||
print(
|
||||
"""----Data statistics------'
|
||||
|
||||
@@ -51,7 +51,7 @@ def main(args):
|
||||
val_mask = g.ndata["val_mask"]
|
||||
test_mask = g.ndata["test_mask"]
|
||||
in_feats = features.shape[1]
|
||||
n_classes = data.num_labels
|
||||
n_classes = data.num_classes
|
||||
n_edges = g.num_edges()
|
||||
print(
|
||||
"""----Data statistics------'
|
||||
|
||||
@@ -28,7 +28,7 @@ def load_dataset(name):
|
||||
|
||||
data = CitationGraphDataset("cora")
|
||||
g = data[0]
|
||||
n_classes = data.num_labels
|
||||
n_classes = data.num_classes
|
||||
train_mask = g.ndata["train_mask"]
|
||||
val_mask = g.ndata["val_mask"]
|
||||
test_mask = g.ndata["test_mask"]
|
||||
|
||||
@@ -38,7 +38,7 @@ def main(args):
|
||||
val_mask = g.ndata["val_mask"]
|
||||
test_mask = g.ndata["test_mask"]
|
||||
in_feats = features.shape[1]
|
||||
n_classes = data.num_labels
|
||||
n_classes = data.num_classes
|
||||
n_edges = g.num_edges()
|
||||
print(
|
||||
"""----Data statistics------'
|
||||
|
||||
@@ -50,7 +50,7 @@ def main(args):
|
||||
val_mask = g.ndata["val_mask"]
|
||||
test_mask = g.ndata["test_mask"]
|
||||
in_feats = features.shape[1]
|
||||
n_classes = data.num_labels
|
||||
n_classes = data.num_classes
|
||||
n_edges = g.number_of_edges()
|
||||
|
||||
# add self loop
|
||||
|
||||
@@ -66,7 +66,7 @@ def main(args):
|
||||
val_mask = g.ndata["val_mask"]
|
||||
test_mask = g.ndata["test_mask"]
|
||||
num_feats = features.shape[1]
|
||||
n_classes = data.num_labels
|
||||
n_classes = data.num_classes
|
||||
n_edges = g.number_of_edges()
|
||||
print(
|
||||
"""----Data statistics------'
|
||||
|
||||
@@ -114,7 +114,7 @@ def main(args):
|
||||
val_mask = g.ndata["val_mask"]
|
||||
test_mask = g.ndata["test_mask"]
|
||||
in_feats = features.shape[1]
|
||||
n_classes = data.num_labels
|
||||
n_classes = data.num_classes
|
||||
n_edges = data.graph.number_of_edges()
|
||||
print(
|
||||
"""----Data statistics------'
|
||||
|
||||
@@ -121,7 +121,7 @@ def main(args):
|
||||
val_mask = g.ndata["val_mask"]
|
||||
test_mask = g.ndata["test_mask"]
|
||||
in_feats = features.shape[1]
|
||||
n_classes = data.num_labels
|
||||
n_classes = data.num_classes
|
||||
n_edges = data.graph.number_of_edges()
|
||||
print(
|
||||
"""----Data statistics------'
|
||||
|
||||
@@ -43,7 +43,7 @@ def main(args):
|
||||
val_mask = g.ndata["val_mask"]
|
||||
test_mask = g.ndata["test_mask"]
|
||||
in_feats = features.shape[1]
|
||||
n_classes = data.num_labels
|
||||
n_classes = data.num_classes
|
||||
n_edges = g.number_of_edges()
|
||||
print(
|
||||
"""----Data statistics------'
|
||||
|
||||
@@ -59,7 +59,7 @@ class PPIDataset(DGLBuiltinDataset):
|
||||
Examples
|
||||
--------
|
||||
>>> dataset = PPIDataset(mode='valid')
|
||||
>>> num_labels = dataset.num_labels
|
||||
>>> num_classes = dataset.num_classes
|
||||
>>> for g in dataset:
|
||||
.... feat = g.ndata['feat']
|
||||
.... label = g.ndata['label']
|
||||
@@ -173,6 +173,10 @@ class PPIDataset(DGLBuiltinDataset):
|
||||
def num_labels(self):
|
||||
return 121
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
return 121
|
||||
|
||||
def __len__(self):
|
||||
"""Return number of samples in this dataset."""
|
||||
return len(self.graphs)
|
||||
|
||||
@@ -141,6 +141,11 @@ class QM7bDataset(DGLDataset):
|
||||
"""Number of prediction tasks."""
|
||||
return 14
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
"""Number of prediction tasks."""
|
||||
return 14
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r"""Get graph and label by index
|
||||
|
||||
|
||||
@@ -157,6 +157,16 @@ class QM9Dataset(DGLDataset):
|
||||
"""
|
||||
return self.label.shape[1]
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
r"""
|
||||
Returns
|
||||
--------
|
||||
int
|
||||
Number of prediction tasks.
|
||||
"""
|
||||
return self.label.shape[1]
|
||||
|
||||
@property
|
||||
def num_tasks(self):
|
||||
r"""
|
||||
|
||||
Reference in New Issue
Block a user