Files
dgl/tests/python/common/test_frame.py
Hongzhi (Steve), Chen 8a83027274 [Misc] Merge test utils. (#5440)
* merge

* format

* rename

* sort

* sort

* update

* update

* update

* Update tests/utils/checks.py

Co-authored-by: Mufei Li <mufeili1996@gmail.com>

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
Co-authored-by: Mufei Li <mufeili1996@gmail.com>
2023-03-20 15:32:45 +08:00

121 lines
2.9 KiB
Python

import pickle
import unittest
import backend as F
import dgl
import dgl.ndarray as nd
import numpy as np
from dgl.frame import Column
from utils import parametrize_idtype
def test_column_subcolumn():
    """Chained ``Column.subcolumn`` calls must compose their row indices.

    A first subcolumn is taken with an index created by ``F.tensor`` without
    an explicit device; two further subcolumns are then taken with indices
    copied to ``F.ctx()``. After every step the subcolumn's length and rows
    are compared with a single ``F.gather_row`` on the base tensor using the
    composed index.
    """
    rows = [
        [1.0, 1.0, 1.0, 1.0],
        [0.0, 2.0, 9.0, 0.0],
        [3.0, 2.0, 1.0, 0.0],
        [1.0, 1.0, 1.0, 1.0],
        [0.0, 2.0, 4.0, 0.0],
    ]
    base = F.copy_to(F.tensor(rows), F.ctx())

    # First level: the index is created on the default (cpu) context.
    composed = F.tensor([0, 2, 1, 3], dtype=F.int64)
    current = Column(base).subcolumn(composed)
    assert len(current) == composed.shape[0]
    assert F.array_equal(current.data, F.gather_row(base, composed))

    # Deeper levels: each step index lives on the target context. The
    # composed index is gathered on its own device (so the step index is
    # copied there first) and then moved back to the target context.
    for positions in ([0, 2], [1]):
        step = F.copy_to(F.tensor(positions, dtype=F.int64), F.ctx())
        current = current.subcolumn(step)
        assert len(current) == step.shape[0]
        composed = F.copy_to(
            F.gather_row(composed, F.copy_to(step, F.context(composed))),
            F.ctx(),
        )
        assert F.array_equal(current.data, F.gather_row(base, composed))
def test_serialize_deserialize_plain():
    """A plain (unsliced) Column must survive a pickle round-trip unchanged.

    The unpickled column is compared with the original on row count, dtype,
    and data.
    """
    data = F.copy_to(
        F.tensor(
            [
                [1.0, 1.0, 1.0, 1.0],
                [0.0, 2.0, 9.0, 0.0],
                [3.0, 2.0, 1.0, 0.0],
                [1.0, 1.0, 1.0, 1.0],
                [0.0, 2.0, 4.0, 0.0],
            ]
        ),
        F.ctx(),
    )
    original = Column(data)

    new = pickle.loads(pickle.dumps(original))

    # Check the metadata explicitly instead of printing the column, so a
    # failure points directly at what did not round-trip.
    assert len(new) == len(original)
    assert new.dtype == original.dtype
    assert F.array_equal(new.data, original.data)
def test_serialize_deserialize_subcolumn():
    """Pickling a subcolumn must preserve exactly the rows it selects."""
    rows = [
        [1.0, 1.0, 1.0, 1.0],
        [0.0, 2.0, 9.0, 0.0],
        [3.0, 2.0, 1.0, 0.0],
        [1.0, 1.0, 1.0, 1.0],
        [0.0, 2.0, 4.0, 0.0],
    ]
    base = F.copy_to(F.tensor(rows), F.ctx())

    # The row index is created on the default (cpu) context.
    picks = F.tensor([0, 2, 1, 3], dtype=F.int64)
    sliced = Column(base).subcolumn(picks)

    restored = pickle.loads(pickle.dumps(sliced))
    assert F.array_equal(restored.data, sliced.data)
def test_serialize_deserialize_dtype():
    """A Column cast with ``astype`` must keep the cast through pickling.

    Both the reported dtype and the cast values of the unpickled column are
    checked. A dtype-only check would still pass if the pickled storage were
    lost or corrupted.
    """
    data = F.copy_to(
        F.tensor(
            [
                [1.0, 1.0, 1.0, 1.0],
                [0.0, 2.0, 9.0, 0.0],
                [3.0, 2.0, 1.0, 0.0],
                [1.0, 1.0, 1.0, 1.0],
                [0.0, 2.0, 4.0, 0.0],
            ]
        ),
        F.ctx(),
    )
    original = Column(data).astype(F.int64)

    new = pickle.loads(pickle.dumps(original))

    assert new.dtype == F.int64
    # Compare against the original column's own (cast) data so the check
    # does not depend on how Column performs the conversion.
    assert F.array_equal(new.data, original.data)