mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
48 lines
1.3 KiB
Python
48 lines
1.3 KiB
Python
import backend as F
|
|
import pytest
|
|
import torch
|
|
|
|
from .utils import (
|
|
rand_coo,
|
|
rand_csc,
|
|
rand_csr,
|
|
rand_diag,
|
|
sparse_matrix_to_dense,
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"create_func", [rand_diag, rand_csr, rand_csc, rand_coo]
|
|
)
|
|
@pytest.mark.parametrize("dim", [0, 1])
|
|
@pytest.mark.parametrize("index", [None, (1, 3), (4, 0, 2)])
|
|
def test_compact(create_func, dim, index):
|
|
ctx = F.ctx()
|
|
shape = (5, 5)
|
|
ans_idx = []
|
|
if index is not None:
|
|
ans_idx = list(dict.fromkeys(index))
|
|
index = torch.tensor(index).to(ctx)
|
|
|
|
A = create_func(shape, 8, ctx)
|
|
|
|
A_compact, ret_id = A.compact(dim, index)
|
|
A_compact_dense = sparse_matrix_to_dense(A_compact)
|
|
|
|
A_dense = sparse_matrix_to_dense(A)
|
|
|
|
for i in range(shape[dim]):
|
|
if dim == 0:
|
|
row = list(A_dense[i, :].nonzero().reshape(-1))
|
|
else:
|
|
row = list(A_dense[:, i].nonzero().reshape(-1))
|
|
if (i not in list(ans_idx)) and len(row) > 0:
|
|
ans_idx.append(i)
|
|
if len(ans_idx):
|
|
ans_idx = torch.tensor(ans_idx).to(ctx)
|
|
A_dense_select = sparse_matrix_to_dense(A.index_select(dim, ans_idx))
|
|
|
|
assert A_compact_dense.shape == A_dense_select.shape
|
|
assert torch.allclose(A_compact_dense, A_dense_select)
|
|
assert torch.allclose(ans_idx, ret_id)
|