mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
* [Misc] Auto-format tests. * more --------- Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
55 lines
1.3 KiB
Python
55 lines
1.3 KiB
Python
import unittest
|
|
|
|
import backend as F
|
|
|
|
import dgl
|
|
import dgl.ndarray as nd
|
|
import numpy as np
|
|
|
|
|
|
@unittest.skipIf(
|
|
dgl.backend.backend_name == "tensorflow",
|
|
reason="TF doesn't support inplace update",
|
|
)
|
|
def test_dlpack():
|
|
# test dlpack conversion.
|
|
def nd2th():
|
|
ans = np.array(
|
|
[[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
|
|
)
|
|
x = nd.array(np.zeros((3, 4), dtype=np.float32))
|
|
dl = x.to_dlpack()
|
|
y = F.zerocopy_from_dlpack(dl)
|
|
y[0] = 1
|
|
print(x)
|
|
print(y)
|
|
assert np.allclose(x.asnumpy(), ans)
|
|
|
|
def th2nd():
|
|
ans = np.array(
|
|
[[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
|
|
)
|
|
x = F.zeros((3, 4))
|
|
dl = F.zerocopy_to_dlpack(x)
|
|
y = nd.from_dlpack(dl)
|
|
x[0] = 1
|
|
print(x)
|
|
print(y)
|
|
assert np.allclose(y.asnumpy(), ans)
|
|
|
|
def th2nd_incontiguous():
|
|
x = F.astype(F.tensor([[0, 1], [2, 3]]), F.int64)
|
|
ans = np.array([0, 2])
|
|
y = x[:2, 0]
|
|
# Uncomment this line and comment the one below to observe error
|
|
# dl = dlpack.to_dlpack(y)
|
|
dl = F.zerocopy_to_dlpack(y)
|
|
z = nd.from_dlpack(dl)
|
|
print(x)
|
|
print(z)
|
|
assert np.allclose(z.asnumpy(), ans)
|
|
|
|
nd2th()
|
|
th2nd()
|
|
th2nd_incontiguous()
|