Files
dgl/tests/python/common/backend/test_tensor.py
Hongzhi (Steve), Chen 74c9d27d16 [Misc] Auto-format tests. (#5313)
* [Misc] Auto-format tests.

* more

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
2023-02-17 23:33:44 +08:00

55 lines
1.3 KiB
Python

import unittest
import backend as F
import dgl
import dgl.ndarray as nd
import numpy as np
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow",
    reason="TF doesn't support inplace update",
)
def test_dlpack():
    """Exercise DLPack exchange between DGL NDArrays and backend tensors.

    Each nested scenario converts in one direction, mutates one side in
    place where applicable, and checks the data seen through the other
    side.  The test is skipped on the TensorFlow backend because the
    scenarios rely on in-place updates.
    """

    def _first_row_ones():
        # Reference result shared by the two round-trip scenarios:
        # a 3x4 array of zeros whose first row has been set to one.
        reference = np.zeros((3, 4))
        reference[0] = 1.0
        return reference

    def nd2th():
        # NDArray -> backend tensor: writing through the tensor must be
        # observable when reading the NDArray back.
        source = nd.array(np.zeros((3, 4), dtype=np.float32))
        capsule = source.to_dlpack()
        tensor_view = F.zerocopy_from_dlpack(capsule)
        tensor_view[0] = 1
        print(source)
        print(tensor_view)
        assert np.allclose(source.asnumpy(), _first_row_ones())

    def th2nd():
        # Backend tensor -> NDArray: writing to the tensor after the
        # conversion must be observable when reading the NDArray.
        tensor = F.zeros((3, 4))
        capsule = F.zerocopy_to_dlpack(tensor)
        ndarray_view = nd.from_dlpack(capsule)
        tensor[0] = 1
        print(tensor)
        print(ndarray_view)
        assert np.allclose(ndarray_view.asnumpy(), _first_row_ones())

    def th2nd_incontiguous():
        # Backend tensor -> NDArray for a column slice of a 2x2 tensor
        # (the "incontiguous" case this scenario is named after).
        base = F.astype(F.tensor([[0, 1], [2, 3]]), F.int64)
        first_column = base[:2, 0]
        # NOTE: per the original author, exporting `first_column` with a
        # plain `dlpack.to_dlpack(...)` call instead of the backend helper
        # below produces an error.
        capsule = F.zerocopy_to_dlpack(first_column)
        converted = nd.from_dlpack(capsule)
        print(base)
        print(converted)
        assert np.allclose(converted.asnumpy(), np.array([0, 2]))

    # Run the scenarios in a fixed order.
    for scenario in (nd2th, th2nd, th2nd_incontiguous):
        scenario()