Files
dgl/tests/python/pytorch/sparse/test_sddmm.py

93 lines
2.8 KiB
Python

import sys
import backend as F
import pytest
import torch
from dgl.sparse import bsddmm, sddmm
from .utils import (
clone_detach_and_grad,
rand_coo,
rand_csc,
rand_csr,
rand_stride,
)
@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape", [(5, 5), (5, 4)])
@pytest.mark.parametrize("nnz", [2, 10])
@pytest.mark.parametrize("hidden", [1, 5])
def test_sddmm(create_func, shape, nnz, hidden):
    """Check ``sddmm`` values and gradients against a dense reference.

    ``sddmm(A, B, C)`` is compared with ``(B @ C)[row, col] * A.val``, where
    ``(row, col)`` are the COO indices of ``A``. The reference is computed on
    separate copies of the inputs, so the forward values and the gradients of
    ``B``, ``C`` and ``A.val`` can all be compared.

    When ``hidden == 1`` the operands are 1-D vectors of length ``shape[0]``
    and ``shape[1]``, and the dense reference is their outer product.
    Otherwise ``B`` is ``(shape[0], hidden)`` and ``C`` is
    ``(hidden, shape[1])``.
    """
    dev = F.ctx()
    A = create_func(shape, nnz, dev)

    # A single predicate drives both operand construction and the dense
    # reference, so the two can never disagree about operand rank.
    vector_operands = hidden == 1
    if vector_operands:
        B = torch.rand(shape[0], requires_grad=True, device=dev)
        C = torch.rand(shape[1], requires_grad=True, device=dev)
    else:
        B = torch.rand(shape[0], hidden, requires_grad=True, device=dev)
        C = torch.rand(hidden, shape[1], requires_grad=True, device=dev)

    # NOTE(review): rand_stride (from .utils) presumably returns a
    # non-contiguous leaf copy so sddmm is exercised on strided inputs —
    # confirm against .utils.
    B = rand_stride(B)
    C = rand_stride(C)

    # Separate copies for the dense reference; presumably detached clones
    # with requires_grad set (see .utils.clone_detach_and_grad).
    A_val_clone = clone_detach_and_grad(A.val)
    dense_B = clone_detach_and_grad(B)
    dense_C = clone_detach_and_grad(C)

    sparse_result = sddmm(A, B, C)
    # The same upstream gradient is fed to both graphs.
    grad = torch.rand_like(sparse_result.val)
    sparse_result.val.backward(grad)

    if vector_operands:
        # Column vector times row vector: (shape[0], 1) @ (1, shape[1]).
        dense_result = dense_B.view(-1, 1) @ dense_C.view(1, -1)
    else:
        dense_result = dense_B @ dense_C

    # Sample the dense product at A's nonzero positions and scale by A's values.
    row, col = A.coo()
    dense_val = dense_result[row, col] * A_val_clone
    dense_val.backward(grad)

    assert torch.allclose(dense_val, sparse_result.val, atol=1e-05)
    assert torch.allclose(dense_C.grad, C.grad, atol=1e-05)
    assert torch.allclose(dense_B.grad, B.grad, atol=1e-05)
    assert torch.allclose(A_val_clone.grad, A.val.grad, atol=1e-05)
@pytest.mark.parametrize("create_func", [rand_coo, rand_csr, rand_csc])
@pytest.mark.parametrize("shape", [(5, 5), (5, 4)])
@pytest.mark.parametrize("nnz", [2, 10])
@pytest.mark.parametrize("nz_dim", [2, 10])
def test_bsddmm(create_func, shape, nnz, nz_dim):
    """Check batched ``bsddmm`` values and gradients against a dense reference.

    ``B`` has shape ``(shape[0], inner, nz_dim)`` and ``C`` has shape
    ``(inner, shape[1], nz_dim)``. The reference does one dense ``B @ C`` per
    trailing index, samples the result at ``A``'s COO positions and multiplies
    by ``A``'s values. Forward values and the gradients of ``B``, ``C`` and
    ``A.val`` are then compared with the sparse kernel.
    """
    device = F.ctx()
    inner = 2
    num_rows, num_cols = shape

    # NOTE(review): passing nz_dim presumably gives A nz_dim values per
    # nonzero, i.e. A.val of shape (nnz, nz_dim) — confirm against .utils.
    mat = create_func(shape, nnz, device, nz_dim)
    lhs = torch.rand(num_rows, inner, nz_dim, requires_grad=True, device=device)
    rhs = torch.rand(inner, num_cols, nz_dim, requires_grad=True, device=device)
    lhs, rhs = rand_stride(lhs), rand_stride(rhs)

    # Reference-side copies, kept out of the sparse kernel's autograd graph.
    ref_mat_val = clone_detach_and_grad(mat.val)
    ref_lhs = clone_detach_and_grad(lhs)
    ref_rhs = clone_detach_and_grad(rhs)

    out = bsddmm(mat, lhs, rhs)
    upstream = torch.rand_like(out.val)
    out.val.backward(upstream)

    # Bring the trailing nz_dim axis to the front, so the matmul batches over
    # it: (nz_dim, rows, inner) @ (nz_dim, inner, cols). Then move it back to
    # get (rows, cols, nz_dim).
    ref_full = (ref_lhs.permute(2, 0, 1) @ ref_rhs.permute(2, 0, 1)).permute(1, 2, 0)
    rows, cols = mat.coo()
    ref_val = ref_full[rows, cols] * ref_mat_val
    ref_val.backward(upstream)

    assert torch.allclose(ref_val, out.val, atol=1e-05)
    assert torch.allclose(ref_rhs.grad, rhs.grad, atol=1e-05)
    assert torch.allclose(ref_lhs.grad, lhs.grad, atol=1e-05)
    assert torch.allclose(ref_mat_val.grad, mat.val.grad, atol=1e-05)