Files
dgl/tests/python/pytorch/sparse/test_matrix_op.py
2023-10-09 08:36:49 +08:00

48 lines
1.3 KiB
Python

import backend as F
import pytest
import torch
from .utils import (
rand_coo,
rand_csc,
rand_csr,
rand_diag,
sparse_matrix_to_dense,
)
@pytest.mark.parametrize(
"create_func", [rand_diag, rand_csr, rand_csc, rand_coo]
)
@pytest.mark.parametrize("dim", [0, 1])
@pytest.mark.parametrize("index", [None, (1, 3), (4, 0, 2)])
def test_compact(create_func, dim, index):
ctx = F.ctx()
shape = (5, 5)
ans_idx = []
if index is not None:
ans_idx = list(dict.fromkeys(index))
index = torch.tensor(index).to(ctx)
A = create_func(shape, 8, ctx)
A_compact, ret_id = A.compact(dim, index)
A_compact_dense = sparse_matrix_to_dense(A_compact)
A_dense = sparse_matrix_to_dense(A)
for i in range(shape[dim]):
if dim == 0:
row = list(A_dense[i, :].nonzero().reshape(-1))
else:
row = list(A_dense[:, i].nonzero().reshape(-1))
if (i not in list(ans_idx)) and len(row) > 0:
ans_idx.append(i)
if len(ans_idx):
ans_idx = torch.tensor(ans_idx).to(ctx)
A_dense_select = sparse_matrix_to_dense(A.index_select(dim, ans_idx))
assert A_compact_dense.shape == A_dense_select.shape
assert torch.allclose(A_compact_dense, A_dense_select)
assert torch.allclose(ans_idx, ret_id)