mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-05 19:54:25 +08:00
* new functional for creating data splits in graph
* minor fix in data split implementation
* apply suggestions from code review
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
* refactoring + unit tests
* fix test file name
* move imports to the top
* Revert "fix test file name"
This reverts commit 126323e38c.
* remove nccl submodule
* address linter issues
---------
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
103 lines
3.3 KiB
Python
103 lines
3.3 KiB
Python
import gzip
|
|
import io
|
|
import os
|
|
import tarfile
|
|
import tempfile
|
|
import unittest
|
|
|
|
import backend as F
|
|
|
|
import dgl
|
|
import dgl.data as data
|
|
import numpy as np
|
|
import pandas as pd
|
|
import pytest
|
|
import yaml
|
|
from dgl import DGLError
|
|
|
|
|
|
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_add_nodepred_split():
    """Check that ``add_nodepred_split`` leaves a ``train_mask`` on the graph.

    Two datasets are exercised: one whose node data is read through
    ``ndata`` and one where the split targets the ``Publikationen`` node
    type. Whether the mask existed beforehand is printed for reference
    only; the assertion is made after the split has been added.
    """
    # Dataset whose masks are looked up via ``ndata``.
    amazon_ds = data.AmazonCoBuyComputerDataset()
    mask_present_before = "train_mask" in amazon_ds[0].ndata
    print(mask_present_before)
    data.utils.add_nodepred_split(amazon_ds, [0.8, 0.1, 0.1])
    assert "train_mask" in amazon_ds[0].ndata

    # Dataset where the split is requested for a single node type.
    aifb_ds = data.AIFBDataset()
    mask_present_before = (
        "train_mask" in aifb_ds[0].nodes["Publikationen"].data
    )
    print(mask_present_before)
    data.utils.add_nodepred_split(
        aifb_ds, [0.8, 0.1, 0.1], ntype="Publikationen"
    )
    assert "train_mask" in aifb_ds[0].nodes["Publikationen"].data
|
|
|
|
|
|
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_extract_archive():
    """Check that ``extract_archive`` unpacks gzip and tar archives.

    For each format a small archive is written into a temporary source
    directory, extracted into a separate temporary destination, and the
    expected member name is checked to exist there.
    """

    def check_extracted(archive_path, expected_name):
        # Extract into a throwaway directory and verify the member exists.
        with tempfile.TemporaryDirectory() as out_dir:
            data.utils.extract_archive(archive_path, out_dir, overwrite=True)
            assert os.path.exists(os.path.join(out_dir, expected_name))

    # gzip
    with tempfile.TemporaryDirectory() as work_dir:
        gz_name = "gz_archive"
        gz_archive = os.path.join(work_dir, gz_name + ".gz")
        payload = b"test extract archive gzip"
        with gzip.open(gz_archive, "wb") as gz_out:
            gz_out.write(payload)
        check_extracted(gz_archive, gz_name)

    # tar
    with tempfile.TemporaryDirectory() as work_dir:
        tar_name = "tar_archive"
        tar_archive = os.path.join(work_dir, tar_name + ".tar")
        # default encode to utf8
        payload = "test extract archive tar\n".encode()
        member = tarfile.TarInfo(name="tar_archive")
        # The header must carry the payload size for addfile to copy it.
        member.size = len(payload)
        with tarfile.open(tar_archive, "w") as tar_out:
            tar_out.addfile(member, io.BytesIO(payload))
        check_extracted(tar_archive, tar_name)
|
|
|
|
|
|
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_mask_nodes_by_property():
    """Check that ``mask_nodes_by_property`` returns an ``in_valid_mask``.

    The property values are 1000 uniform random samples and the split is
    requested with five part ratios; only the presence of the
    ``in_valid_mask`` key in the result is asserted.
    """
    node_count = 1000
    node_scores = np.random.uniform(size=node_count)
    ratios = [0.3, 0.1, 0.1, 0.3, 0.2]

    masks = data.utils.mask_nodes_by_property(node_scores, ratios)

    assert "in_valid_mask" in masks
|
|
|
|
|
|
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_add_node_property_split():
    """Check ``add_node_property_split`` for every supported property name.

    For each of ``"popularity"``, ``"locality"`` and ``"density"`` the split
    is added to a freshly loaded dataset and ``in_valid_mask`` is asserted
    to be present in the graph's ``ndata`` afterwards.
    """
    part_ratios = [0.3, 0.1, 0.1, 0.3, 0.2]
    for property_name in ["popularity", "locality", "density"]:
        # Load a fresh dataset for every property: the masks are stored in
        # the graph's ndata, so reusing one dataset would let the assertion
        # below pass on leftovers from an earlier iteration even if the
        # current property produced nothing.
        dataset = data.AmazonCoBuyComputerDataset()
        assert "in_valid_mask" not in dataset[0].ndata
        data.utils.add_node_property_split(dataset, part_ratios, property_name)
        assert "in_valid_mask" in dataset[0].ndata
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly: execute each test once, in order.
    for run_test in (
        test_extract_archive,
        test_add_nodepred_split,
        test_mask_nodes_by_property,
        test_add_node_property_split,
    ):
        run_test()
|