Files
dgl/tests/python/common/data/test_utils.py
Gleb Bazhenov 54b4bd0acd [Data] new functional for creating data splits in graph (#5418)
* new functional for creating data splits in graph

* minor fix in data split implementation

* apply suggestions from code review

Co-authored-by: Mufei Li <mufeili1996@gmail.com>

* refactoring + unit tests

* fix test file name

* move imports to the top

* Revert "fix test file name"

This reverts commit 126323e38c.

* remove nccl submodule

* address linter issues

---------

Co-authored-by: Mufei Li <mufeili1996@gmail.com>
Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
2023-03-09 18:55:25 +08:00

103 lines
3.3 KiB
Python

import gzip
import io
import os
import tarfile
import tempfile
import unittest
import backend as F
import dgl
import dgl.data as data
import numpy as np
import pandas as pd
import pytest
import yaml
from dgl import DGLError
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_add_nodepred_split():
    """Check that ``add_nodepred_split`` attaches a ``"train_mask"`` entry.

    The check is run twice: once on ``AmazonCoBuyComputerDataset``, reading
    the first graph's ``ndata``, and once on ``AIFBDataset``, reading the
    ``"Publikationen"`` node type's data. Before each split, whether the
    mask is already present is printed.
    """
    mask_key = "train_mask"

    # Case 1: mask looked up through ``ndata``.
    cobuy = data.AmazonCoBuyComputerDataset()
    print(mask_key in cobuy[0].ndata)
    data.utils.add_nodepred_split(cobuy, [0.8, 0.1, 0.1])
    assert mask_key in cobuy[0].ndata

    # Case 2: mask looked up on an explicitly named node type.
    node_type = "Publikationen"
    aifb = data.AIFBDataset()
    print(mask_key in aifb[0].nodes[node_type].data)
    data.utils.add_nodepred_split(aifb, [0.8, 0.1, 0.1], ntype=node_type)
    assert mask_key in aifb[0].nodes[node_type].data
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_extract_archive():
    """Check that ``extract_archive`` unpacks ``.gz`` and ``.tar`` files.

    For each format an archive is written into a temporary source
    directory and extracted into a separate temporary destination
    directory; the test then asserts that the expected file name exists
    in the destination.
    """
    # gzip: "<src>/gz_archive.gz" is expected to yield "<dst>/gz_archive".
    with tempfile.TemporaryDirectory() as source_root:
        gz_name = "gz_archive"
        gz_archive = os.path.join(source_root, f"{gz_name}.gz")
        payload = b"test extract archive gzip"
        with gzip.open(gz_archive, "wb") as gz_handle:
            gz_handle.write(payload)
        # The destination lives inside the source scope so the archive
        # still exists while it is being extracted.
        with tempfile.TemporaryDirectory() as target_root:
            data.utils.extract_archive(gz_archive, target_root, overwrite=True)
            extracted = os.path.join(target_root, gz_name)
            assert os.path.exists(extracted)

    # tar: a single in-memory member named "tar_archive" is archived.
    with tempfile.TemporaryDirectory() as source_root:
        tar_name = "tar_archive"
        tar_archive = os.path.join(source_root, f"{tar_name}.tar")
        # str.encode() defaults to UTF-8.
        payload = "test extract archive tar\n".encode()
        member = tarfile.TarInfo(name="tar_archive")
        member.size = len(payload)
        with tarfile.open(tar_archive, "w") as tar_handle:
            tar_handle.addfile(member, io.BytesIO(payload))
        with tempfile.TemporaryDirectory() as target_root:
            data.utils.extract_archive(
                tar_archive, target_root, overwrite=True
            )
            extracted = os.path.join(target_root, tar_name)
            assert os.path.exists(extracted)
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_mask_nodes_by_property():
    """Check that ``mask_nodes_by_property`` yields an ``"in_valid_mask"``.

    The function is fed 1000 uniformly sampled property values together
    with five part ratios, and the result is checked for the
    ``"in_valid_mask"`` entry.
    """
    node_count = 1000
    ratios = [0.3, 0.1, 0.1, 0.3, 0.2]
    values = np.random.uniform(size=node_count)

    masks = data.utils.mask_nodes_by_property(values, ratios)

    assert "in_valid_mask" in masks
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_add_node_property_split():
    """Check ``add_node_property_split`` for each property name.

    For ``"popularity"``, ``"locality"`` and ``"density"`` in turn, the
    split is applied to ``AmazonCoBuyComputerDataset`` and the first
    graph's ``ndata`` is checked for ``"in_valid_mask"``.
    """
    cobuy = data.AmazonCoBuyComputerDataset()
    ratios = [0.3, 0.1, 0.1, 0.3, 0.2]
    properties = ("popularity", "locality", "density")

    for prop in properties:
        data.utils.add_node_property_split(cobuy, ratios, prop)
        assert "in_valid_mask" in cobuy[0].ndata
if __name__ == "__main__":
    # Script entry point: run each test once, in this fixed order.
    for run_test in (
        test_extract_archive,
        test_add_nodepred_split,
        test_mask_nodes_by_property,
        test_add_node_property_split,
    ):
        run_test()