mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
Revert "[Sparse] Add SpMM and SDDMM." (#5014)
* Revert "[Sparse] Add SpMM and SDDMM. (#4999)"
This reverts commit 15365d7855.
* lint
This commit is contained in:
@@ -357,6 +357,5 @@ endif(BUILD_CPP_TEST)
|
||||
|
||||
if(BUILD_SPARSE)
|
||||
set(DGL_INCLUDE "${CMAKE_CURRENT_SOURCE_DIR}/include")
|
||||
list(APPEND DGL_INCLUDE "${CMAKE_CURRENT_SOURCE_DIR}/src")
|
||||
add_subdirectory(dgl_sparse)
|
||||
endif(BUILD_SPARSE)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/**
|
||||
* Copyright (c) 2022 by Contributors
|
||||
* @file sparse/dgl_headers.h
|
||||
* @file dgl_headers.h
|
||||
* @brief DGL headers used in the sparse library. This is a workaround to
|
||||
* avoid the macro naming conflict between dmlc/logging.h and torch logger. This
|
||||
* file includes all the DGL headers used in the sparse library and
|
||||
@@ -14,14 +14,9 @@
|
||||
|
||||
#include <dgl/aten/coo.h>
|
||||
#include <dgl/aten/csr.h>
|
||||
#include <dgl/graph.h>
|
||||
#include <dgl/kernel.h>
|
||||
#include <dgl/runtime/dlpack_convert.h>
|
||||
#include <dmlc/logging.h>
|
||||
|
||||
// Headers not in DGL include directory
|
||||
#include "graph/unit_graph.h"
|
||||
|
||||
#undef CHECK
|
||||
#undef CHECK_OP
|
||||
#undef CHECK_EQ
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
/**
|
||||
* Copyright (c) 2022 by Contributors
|
||||
* @file sparse/sddmm.h
|
||||
* @brief DGL C++ SDDMM operator.
|
||||
*/
|
||||
#ifndef SPARSE_SDDMM_H_
|
||||
#define SPARSE_SDDMM_H_
|
||||
|
||||
// clang-format off
|
||||
#include <sparse/dgl_headers.h>
|
||||
// clang-format on
|
||||
|
||||
#include <sparse/sparse_matrix.h>
|
||||
#include <torch/script.h>
|
||||
|
||||
namespace dgl {
|
||||
namespace sparse {
|
||||
|
||||
/**
|
||||
* @brief Perform a sampled matrix multiplication of a sparse matrix and two
|
||||
* dense matrices. For efficiency, `mat2_tr` is the transposition of the matrix
|
||||
* to be multiplied. If the sparse matrix has shape (n, m), `mat1` and `mat2_tr`
|
||||
* must have shapes of `(n, k)` and `(m, k)` or
|
||||
* `(n,)` and `(m,)` respectively. And the returned tensor has shape
|
||||
* `(sparse_matrix->nnz())`.
|
||||
*
|
||||
* This function does not take care of autograd.
|
||||
*
|
||||
* @param sparse_mat The sparse matrix.
|
||||
* @param mat1 The first dense matrix.
|
||||
* @param mat2_tr Transposition of the second matrix.
|
||||
*
|
||||
* @return Dense tensor.
|
||||
*/
|
||||
torch::Tensor SDDMMImpl(
|
||||
const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor mat1,
|
||||
torch::Tensor mat2_tr);
|
||||
|
||||
} // namespace sparse
|
||||
} // namespace dgl
|
||||
|
||||
#endif // SPARSE_SDDMM_H_
|
||||
@@ -83,15 +83,6 @@ std::shared_ptr<CSR> CSRToCSC(const std::shared_ptr<CSR>& csr);
|
||||
/** @brief COO transposition. */
|
||||
std::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo);
|
||||
|
||||
/** @brief Convert a COO sparse format to DGL Graph. */
|
||||
HeteroGraphPtr COOToDGLGraph(const std::shared_ptr<COO>& coo);
|
||||
|
||||
/** @brief Convert a CSR sparse format to DGL Graph. */
|
||||
HeteroGraphPtr CSRToDGLGraph(const std::shared_ptr<CSR>& csr);
|
||||
|
||||
/** @brief Convert a CSC sparse format to DGL Graph. */
|
||||
HeteroGraphPtr CSCToDGLGraph(const std::shared_ptr<CSR>& csc);
|
||||
|
||||
} // namespace sparse
|
||||
} // namespace dgl
|
||||
|
||||
|
||||
@@ -1,59 +0,0 @@
|
||||
/**
|
||||
* Copyright (c) 2022 by Contributors
|
||||
* @file sparse/spmm.h
|
||||
* @brief DGL C++ SpMM operator.
|
||||
*/
|
||||
#ifndef SPARSE_SPMM_H_
|
||||
#define SPARSE_SPMM_H_
|
||||
|
||||
// clang-format off
|
||||
#include <sparse/dgl_headers.h>
|
||||
// clang-format on
|
||||
|
||||
#include <sparse/sparse_matrix.h>
|
||||
#include <torch/script.h>
|
||||
|
||||
namespace dgl {
|
||||
namespace sparse {
|
||||
|
||||
/**
|
||||
* @brief Perform a matrix multiplication of the sparse matrix and dense
|
||||
* matrix. It uses the sparse formats of `sparse_mat` and non-zero values of
|
||||
* `sparse_val` for SpMM. The `sparse_val` must be 1-dimensional. If the sparse
|
||||
* matrix has shape (n, m), the dense matrix must have shape (m, k) or (m,). and
|
||||
* the returned dense matrix has shape (n, k) or (n,
|
||||
* ).
|
||||
*
|
||||
* This function does not take care of autograd.
|
||||
*
|
||||
* @param sparse_mat The sparse matrix.
|
||||
* @param sparse_val Non-zero values of the sparse matrix.
|
||||
* @param dense_mat The dense matrix.
|
||||
*
|
||||
* @return Dense tensor.
|
||||
*/
|
||||
torch::Tensor SpMMImpl(
|
||||
const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
|
||||
torch::Tensor sparse_val, torch::Tensor dense_mat);
|
||||
|
||||
/**
|
||||
* @brief Perform a matrix multiplication of the sparse matrix and dense
|
||||
* matrix. The sparse matrix must have 1-dimensional values. If the sparse
|
||||
* matrix has shape (n, m), the dense matrix must have shape (m, k) or (m,), and
|
||||
* the returned dense matrix has shape (n, k) or (n,).
|
||||
*
|
||||
* This function supports autograd for both the sparse and dense matrix.
|
||||
*
|
||||
* @param sparse_mat The sparse matrix.
|
||||
* @param dense_mat The dense matrix.
|
||||
*
|
||||
* @return Dense matrix.
|
||||
*/
|
||||
torch::Tensor SpMM(
|
||||
const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
|
||||
torch::Tensor dense_mat);
|
||||
|
||||
} // namespace sparse
|
||||
} // namespace dgl
|
||||
|
||||
#endif // SPARSE_SPMM_H_
|
||||
@@ -9,7 +9,6 @@
|
||||
|
||||
#include <sparse/elementwise_op.h>
|
||||
#include <sparse/sparse_matrix.h>
|
||||
#include <sparse/spmm.h>
|
||||
#include <torch/custom_class.h>
|
||||
#include <torch/script.h>
|
||||
|
||||
@@ -30,8 +29,7 @@ TORCH_LIBRARY(dgl_sparse, m) {
|
||||
.def("create_from_csr", &CreateFromCSR)
|
||||
.def("create_from_csc", &CreateFromCSC)
|
||||
.def("spsp_add", &SpSpAdd)
|
||||
.def("val_like", &CreateValLike)
|
||||
.def("spmm", &SpMM);
|
||||
.def("val_like", &CreateValLike);
|
||||
}
|
||||
|
||||
} // namespace sparse
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
/**
|
||||
* Copyright (c) 2022 by Contributors
|
||||
* @file sddmm.cc
|
||||
* @brief DGL C++ sparse SDDMM operator implementation.
|
||||
*/
|
||||
// clang-format off
|
||||
#include <sparse/dgl_headers.h>
|
||||
// clang-format on
|
||||
|
||||
#include <sparse/sparse_matrix.h>
|
||||
#include <torch/script.h>
|
||||
|
||||
#include "./utils.h"
|
||||
|
||||
namespace dgl {
|
||||
namespace sparse {
|
||||
|
||||
torch::Tensor SDDMMImpl(
|
||||
const c10::intrusive_ptr<SparseMatrix>& sparse_mat, torch::Tensor mat1,
|
||||
torch::Tensor mat2_tr) {
|
||||
HeteroGraphPtr dgl_graph;
|
||||
// Use CSR if the spars matrix has CSR or does not have COO. Otherwise use
|
||||
// COO.
|
||||
if (sparse_mat->HasCSR() || !sparse_mat->HasCOO()) {
|
||||
auto csr = sparse_mat->CSRPtr();
|
||||
dgl_graph = CSRToDGLGraph(csr);
|
||||
} else {
|
||||
auto coo = sparse_mat->COOPtr();
|
||||
dgl_graph = COOToDGLGraph(coo);
|
||||
}
|
||||
if (mat2_tr.dim() == 1) {
|
||||
mat1 = mat1.view({-1, 1});
|
||||
mat2_tr = mat2_tr.view({-1, 1});
|
||||
}
|
||||
int64_t out_row = sparse_mat->nnz();
|
||||
auto shape = std::vector<int64_t>({out_row});
|
||||
auto ret = torch::zeros(shape, mat1.options());
|
||||
const std::string op = "dot";
|
||||
aten::SDDMM(
|
||||
op.c_str(), dgl_graph, TorchTensorToDGLArray(mat1),
|
||||
TorchTensorToDGLArray(mat2_tr), TorchTensorToDGLArray(ret),
|
||||
0 /* Lhs target: u */, 2 /* rhs target: v */);
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace sparse
|
||||
} // namespace dgl
|
||||
@@ -92,20 +92,5 @@ std::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo) {
|
||||
return COOFromOldDGLCOO(dgl_coo_tr);
|
||||
}
|
||||
|
||||
HeteroGraphPtr COOToDGLGraph(const std::shared_ptr<COO>& coo) {
|
||||
auto dgl_coo = COOToOldDGLCOO(coo);
|
||||
return UnitGraph::CreateFromCOO(2 /* Number of node types */, dgl_coo);
|
||||
}
|
||||
|
||||
HeteroGraphPtr CSRToDGLGraph(const std::shared_ptr<CSR>& csr) {
|
||||
auto dgl_csr = CSRToOldDGLCSR(csr);
|
||||
return UnitGraph::CreateFromCSR(2 /* Number of node types */, dgl_csr);
|
||||
}
|
||||
|
||||
HeteroGraphPtr CSCToDGLGraph(const std::shared_ptr<CSR>& csc) {
|
||||
auto dgl_csc = CSRToOldDGLCSR(csc);
|
||||
return UnitGraph::CreateFromCSC(2 /* Number of node types */, dgl_csc);
|
||||
}
|
||||
|
||||
} // namespace sparse
|
||||
} // namespace dgl
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
/**
|
||||
* Copyright (c) 2022 by Contributors
|
||||
* @file spmm.cc
|
||||
* @brief DGL C++ sparse SpMM operator implementation.
|
||||
*/
|
||||
// clang-format off
|
||||
#include <sparse/dgl_headers.h>
|
||||
// clang-format on
|
||||
|
||||
#include <sparse/sddmm.h>
|
||||
#include <sparse/sparse_matrix.h>
|
||||
#include <sparse/spmm.h>
|
||||
#include <torch/script.h>
|
||||
|
||||
#include "./utils.h"
|
||||
|
||||
namespace dgl {
|
||||
namespace sparse {
|
||||
|
||||
using namespace torch::autograd;
|
||||
|
||||
class SpMMAutoGrad : public Function<SpMMAutoGrad> {
|
||||
public:
|
||||
static torch::Tensor forward(
|
||||
AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,
|
||||
torch::Tensor sparse_val, torch::Tensor dense_mat);
|
||||
|
||||
static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
|
||||
};
|
||||
|
||||
void _SpMMSanityCheck(
|
||||
c10::intrusive_ptr<SparseMatrix> sparse_mat, torch::Tensor sparse_val,
|
||||
torch::Tensor dense_mat) {
|
||||
const auto& sparse_mat_shape = sparse_mat->shape();
|
||||
auto val_shape = sparse_val.sizes();
|
||||
auto dense_shape = dense_mat.sizes();
|
||||
CHECK(sparse_mat_shape[1] == dense_shape[0]);
|
||||
CHECK(val_shape.size() == 1 && val_shape[0] == sparse_mat->nnz());
|
||||
CHECK_LE(dense_shape.size(), 2);
|
||||
CHECK(sparse_val.dtype() == dense_mat.dtype());
|
||||
CHECK(
|
||||
sparse_val.device() == sparse_mat->device() &&
|
||||
sparse_val.device() == dense_mat.device());
|
||||
}
|
||||
|
||||
torch::Tensor SpMMImpl(
|
||||
const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
|
||||
torch::Tensor sparse_val, torch::Tensor dense_mat) {
|
||||
// Transpose the sparse matrix because dgl::aten::SpMM calculates A^T @ X.
|
||||
auto sparse_mat_tr = sparse_mat->Transpose();
|
||||
HeteroGraphPtr dgl_graph;
|
||||
// Use CSR if the spars matrix has CSR or does not have COO. Otherwise use
|
||||
// COO.
|
||||
if (sparse_mat->HasCSC() || !sparse_mat->HasCOO()) {
|
||||
auto csc = sparse_mat_tr->CSCPtr();
|
||||
dgl_graph = CSCToDGLGraph(csc);
|
||||
} else {
|
||||
auto coo = sparse_mat_tr->COOPtr();
|
||||
dgl_graph = COOToDGLGraph(coo);
|
||||
}
|
||||
const std::string op = "mul";
|
||||
const std::string reduce = "sum";
|
||||
int64_t out_row = sparse_mat->shape()[0];
|
||||
std::vector<int64_t> shape;
|
||||
|
||||
if (dense_mat.dim() == 1) {
|
||||
shape = {out_row};
|
||||
} else {
|
||||
shape = {out_row, dense_mat.size(1)};
|
||||
}
|
||||
auto ret = torch::zeros(shape, dense_mat.options());
|
||||
|
||||
aten::SpMM(
|
||||
op.c_str(), reduce.c_str(), dgl_graph, TorchTensorToDGLArray(dense_mat),
|
||||
TorchTensorToDGLArray(sparse_val), TorchTensorToDGLArray(ret),
|
||||
std::vector<runtime::NDArray>());
|
||||
return ret;
|
||||
}
|
||||
|
||||
torch::Tensor SpMMAutoGrad::forward(
|
||||
AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,
|
||||
torch::Tensor sparse_val, torch::Tensor dense_mat) {
|
||||
_SpMMSanityCheck(sparse_mat, sparse_val, dense_mat);
|
||||
auto ret = SpMMImpl(sparse_mat, sparse_val, dense_mat);
|
||||
|
||||
bool sparse_require_grad = sparse_val.requires_grad();
|
||||
bool dense_require_grad = dense_mat.requires_grad();
|
||||
torch::Tensor cache_sparse_val, cache_dense_mat;
|
||||
if (dense_require_grad) {
|
||||
cache_sparse_val = sparse_val;
|
||||
}
|
||||
if (sparse_require_grad) {
|
||||
cache_dense_mat = dense_mat;
|
||||
}
|
||||
ctx->saved_data["sparse_matrix"] = sparse_mat;
|
||||
ctx->saved_data["sparse_require_grad"] = sparse_require_grad;
|
||||
ctx->saved_data["dense_require_grad"] = dense_require_grad;
|
||||
ctx->save_for_backward({cache_sparse_val, cache_dense_mat});
|
||||
return ret;
|
||||
}
|
||||
|
||||
tensor_list SpMMAutoGrad::backward(
|
||||
AutogradContext* ctx, tensor_list grad_outputs) {
|
||||
auto saved = ctx->get_saved_variables();
|
||||
auto sparse_val = saved[0];
|
||||
auto dense_mat = saved[1];
|
||||
auto output_grad = grad_outputs[0];
|
||||
|
||||
auto sparse_mat =
|
||||
ctx->saved_data["sparse_matrix"].toCustomClass<SparseMatrix>();
|
||||
bool sparse_require_grad = ctx->saved_data["sparse_require_grad"].toBool();
|
||||
bool dense_require_grad = ctx->saved_data["dense_require_grad"].toBool();
|
||||
|
||||
torch::Tensor dense_mat_grad, sparse_val_grad;
|
||||
if (sparse_require_grad) {
|
||||
sparse_val_grad = SDDMMImpl(sparse_mat, output_grad, dense_mat);
|
||||
}
|
||||
if (dense_require_grad) {
|
||||
auto sparse_mat_tr = sparse_mat->Transpose();
|
||||
dense_mat_grad = SpMMImpl(sparse_mat_tr, sparse_val, output_grad);
|
||||
}
|
||||
return {torch::Tensor(), sparse_val_grad, dense_mat_grad};
|
||||
}
|
||||
|
||||
torch::Tensor SpMM(
|
||||
const c10::intrusive_ptr<SparseMatrix>& sparse_mat,
|
||||
torch::Tensor dense_mat) {
|
||||
return SpMMAutoGrad::apply(sparse_mat, sparse_mat->value(), dense_mat);
|
||||
}
|
||||
|
||||
} // namespace sparse
|
||||
} // namespace dgl
|
||||
@@ -34,10 +34,6 @@ void SpMM(
|
||||
const std::string& op, const std::string& reduce, HeteroGraphPtr graph,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
|
||||
void SpMM(
|
||||
const char* op, const char* reduce, HeteroGraphPtr graph, NDArray ufeat,
|
||||
NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
|
||||
/**
|
||||
* @brief Generalized Sampled Dense-Dense Matrix Multiplication.
|
||||
* @param op The binary operator, could be `add`, `sub', `mul`, 'div',
|
||||
@@ -46,16 +42,10 @@ void SpMM(
|
||||
* @param ufeat The source node feature.
|
||||
* @param vfeat The destination node feature.
|
||||
* @param out The output feature on edge.
|
||||
* @param lhs_target Type of `ufeat`. (0: source; 1: edge; 2: destination)
|
||||
* @param rhs_target Type of `vfeat`. (0: source; 1: edge; 2: destination)
|
||||
*/
|
||||
void SDDMM(
|
||||
const std::string& op, HeteroGraphPtr graph, NDArray ufeat, NDArray vfeat,
|
||||
NDArray out, int lhs_target, int rhs_target);
|
||||
|
||||
void SDDMM(
|
||||
const char* op, HeteroGraphPtr graph, NDArray ufeat, NDArray vfeat,
|
||||
NDArray out, int lhs_target, int rhs_target);
|
||||
const std::string& op, HeteroGraphPtr graph, NDArray ufeat, NDArray efeat,
|
||||
NDArray out);
|
||||
|
||||
/**
|
||||
* @brief Sparse-sparse matrix multiplication.
|
||||
|
||||
@@ -10,7 +10,6 @@ from .elementwise_op import *
|
||||
from .sparse_matrix import *
|
||||
from .unary_op_diag import *
|
||||
from .unary_op_sp import *
|
||||
from .matmul import *
|
||||
|
||||
|
||||
def load_dgl_sparse():
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
"""Matmul ops for SparseMatrix"""
|
||||
# pylint: disable=invalid-name
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from .diag_matrix import DiagMatrix
|
||||
|
||||
from .sparse_matrix import SparseMatrix
|
||||
|
||||
__all__ = ["spmm"]
|
||||
|
||||
|
||||
def spmm(A: Union[SparseMatrix, DiagMatrix], X: torch.Tensor) -> torch.Tensor:
|
||||
"""Multiply a sparse matrix by a dense matrix
|
||||
|
||||
Parameters
|
||||
----------
|
||||
A : SparseMatrix or DiagMatrix
|
||||
Sparse matrix of shape (N, M) with values of shape (nnz)
|
||||
X : torch.Tensor
|
||||
Dense tensor of shape (M, F) or (M)
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The result of multiplication
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> row = torch.tensor([0, 1, 1])
|
||||
>>> col = torch.tensor([1, 0, 1])
|
||||
>>> val = torch.randn(len(row))
|
||||
>>> A = create_from_coo(row, col, val)
|
||||
>>> X = torch.randn(2, 3)
|
||||
>>> result = A @ X
|
||||
>>> print(type(result))
|
||||
<class 'torch.Tensor'>
|
||||
>>> print(result.shape)
|
||||
torch.Size([2, 3])
|
||||
"""
|
||||
assert isinstance(
|
||||
A, (SparseMatrix, DiagMatrix)
|
||||
), f"Expect arg1 to be a SparseMatrix or DiagMatrix object, got {type(A)}"
|
||||
assert isinstance(
|
||||
X, torch.Tensor
|
||||
), f"Expect arg2 to be a torch.Tensor, got {type(X)}"
|
||||
assert (
|
||||
A.shape[1] == X.shape[0]
|
||||
), f"Expect arg1.shape[1] == arg2.shape[0], got {A.shape[1]} and {X.shape[0]}"
|
||||
val_dim = len(A.val.shape)
|
||||
assert val_dim == 1, f"Expect arg1.val to be a 1D tensor, got {val_dim}D"
|
||||
val_dim = len(X.shape)
|
||||
assert val_dim <= 2, f"Expect arg2 to be a 1D/2D tensor, got {val_dim}D"
|
||||
|
||||
if not isinstance(A, SparseMatrix):
|
||||
A = A.as_sparse()
|
||||
return torch.ops.dgl_sparse.spmm(A.c_sparse_matrix, X)
|
||||
|
||||
|
||||
def mm_sp(
|
||||
A1: SparseMatrix, A2: Union[torch.Tensor, SparseMatrix, DiagMatrix]
|
||||
) -> Union[torch.Tensor, SparseMatrix]:
|
||||
"""Internal function for multiplying a sparse matrix by a dense/sparse/diagonal matrix
|
||||
|
||||
Parameters
|
||||
----------
|
||||
A1 : SparseMatrix
|
||||
Matrix of shape (N, M), with values of shape (nnz1)
|
||||
A2 : torch.Tensor, SparseMatrix, or DiagMatrix
|
||||
If A2 is a dense tensor, it can have shapes of (M, P) or (M, ).
|
||||
Otherwise it must have a shape of (M, P).
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor or SparseMatrix
|
||||
The result of multiplication.
|
||||
|
||||
* It is a dense torch tensor if :attr:`A2` is so.
|
||||
* It is a SparseMatrix object otherwise.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
>>> row = torch.tensor([0, 1, 1])
|
||||
>>> col = torch.tensor([1, 0, 1])
|
||||
>>> val = torch.randn(len(row))
|
||||
>>> A1 = create_from_coo(row, col, val)
|
||||
>>> A2 = torch.randn(2, 3)
|
||||
>>> result = A1 @ A2
|
||||
>>> print(type(result))
|
||||
<class 'torch.Tensor'>
|
||||
>>> print(result.shape)
|
||||
torch.Size([2, 3])
|
||||
"""
|
||||
assert isinstance(
|
||||
A2, (torch.Tensor, SparseMatrix, DiagMatrix)
|
||||
), f"Expect arg2 to be a torch Tensor, SparseMatrix, or DiagMatrix object, got {type(A2)}"
|
||||
|
||||
if isinstance(A2, torch.Tensor):
|
||||
return spmm(A1, A2)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
SparseMatrix.__matmul__ = mm_sp
|
||||
@@ -48,12 +48,6 @@ void SpMM(
|
||||
});
|
||||
}
|
||||
|
||||
void SpMM(
|
||||
const char* op, const char* reduce, HeteroGraphPtr graph, NDArray ufeat,
|
||||
NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {
|
||||
SpMM(std::string(op), std::string(reduce), graph, ufeat, efeat, out, out_aux);
|
||||
}
|
||||
|
||||
/** @brief Generalized segmented dense Matrix-Matrix Multiplication. */
|
||||
void SegmentMM(
|
||||
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
|
||||
@@ -258,12 +252,6 @@ void SDDMM(
|
||||
});
|
||||
}
|
||||
|
||||
void SDDMM(
|
||||
const char* op, HeteroGraphPtr graph, NDArray ufeat, NDArray vfeat,
|
||||
NDArray out, int lhs_target, int rhs_target) {
|
||||
SDDMM(std::string(op), graph, ufeat, vfeat, out, lhs_target, rhs_target);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Find the src/dst/etype id based on the target 'u', 'v' or 'e'.
|
||||
*
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
import sys
|
||||
|
||||
import backend as F
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from dgl.mock_sparse2 import create_from_coo, create_from_csc, create_from_csr
|
||||
|
||||
# TODO(#4818): Skipping tests on win.
|
||||
if not sys.platform.startswith("linux"):
|
||||
pytest.skip("skipping tests on win", allow_module_level=True)
|
||||
|
||||
|
||||
def get_adj(A):
|
||||
row, col = A.coo()
|
||||
edge_index = torch.cat((row.unsqueeze(0), col.unsqueeze(0)), 0)
|
||||
shape = A.shape
|
||||
val = A.val.detach()
|
||||
if len(A.val.shape) > 1:
|
||||
shape += (A.val.shape[-1],)
|
||||
return torch.sparse_coo_tensor(edge_index, val, shape).coalesce()
|
||||
|
||||
|
||||
def test_spmm_coo():
|
||||
dev = F.ctx()
|
||||
# A: shape (N, M), X: shape (M, F)
|
||||
row = torch.tensor([0, 1, 1, 1]).to(dev)
|
||||
col = torch.tensor([1, 0, 1, 2]).to(dev)
|
||||
val = torch.randn(len(row), requires_grad=True, device=dev)
|
||||
A = create_from_coo(row, col, val)
|
||||
X = torch.randn(3, 4, requires_grad=True, device=dev)
|
||||
sparse_result = A @ X
|
||||
grad = torch.randn_like(sparse_result)
|
||||
sparse_result.backward(grad)
|
||||
|
||||
adj = get_adj(A)
|
||||
adj.requires_grad_()
|
||||
XX = X.clone().detach()
|
||||
XX.requires_grad_()
|
||||
dense_result = torch.sparse.mm(adj, XX)
|
||||
dense_result.backward(grad)
|
||||
assert torch.allclose(sparse_result, dense_result)
|
||||
assert torch.allclose(X.grad, XX.grad)
|
||||
assert torch.allclose(adj.grad.coalesce().values(), val.grad)
|
||||
|
||||
|
||||
def test_spmm_coo_one_dim_rhs():
|
||||
dev = F.ctx()
|
||||
# A: shape (N, M), X: shape (M,)
|
||||
row = torch.tensor([0, 1, 1, 1]).to(dev)
|
||||
col = torch.tensor([1, 0, 1, 2]).to(dev)
|
||||
val = torch.randn(len(row), requires_grad=True, device=dev)
|
||||
A = create_from_coo(row, col, val)
|
||||
X = torch.randn(3, requires_grad=True, device=dev)
|
||||
sparse_result = A @ X
|
||||
grad = torch.randn_like(sparse_result)
|
||||
sparse_result.backward(grad)
|
||||
|
||||
adj = get_adj(A)
|
||||
adj.requires_grad_()
|
||||
XX = X.clone().detach()
|
||||
XX.requires_grad_()
|
||||
dense_result = torch.sparse.mm(adj, XX.view(-1, 1))
|
||||
dense_result = dense_result.view(-1)
|
||||
dense_result.backward(grad)
|
||||
assert torch.allclose(sparse_result, dense_result)
|
||||
assert torch.allclose(X.grad, XX.grad)
|
||||
assert torch.allclose(adj.grad.coalesce().values(), val.grad)
|
||||
|
||||
|
||||
def test_spmm_csr():
|
||||
dev = F.ctx()
|
||||
# A: shape (N, M), X: shape (M, F)
|
||||
indptr = torch.tensor([0, 1, 4]).to(dev)
|
||||
indices = torch.tensor([1, 0, 1, 2]).to(dev)
|
||||
val = torch.randn(len(indices), requires_grad=True, device=dev)
|
||||
A = create_from_csr(indptr, indices, val, shape=(2, 3))
|
||||
X = torch.randn(3, 4, requires_grad=True, device=dev)
|
||||
sparse_result = A @ X
|
||||
grad = torch.randn_like(sparse_result)
|
||||
sparse_result.backward(grad)
|
||||
|
||||
adj = get_adj(A)
|
||||
adj.requires_grad_()
|
||||
XX = X.clone().detach()
|
||||
XX.requires_grad_()
|
||||
dense_result = torch.sparse.mm(adj, XX)
|
||||
dense_result.backward(grad)
|
||||
assert torch.allclose(sparse_result, dense_result)
|
||||
assert torch.allclose(X.grad, XX.grad)
|
||||
assert torch.allclose(adj.grad.coalesce().values(), val.grad)
|
||||
|
||||
|
||||
def test_spmm_csc():
|
||||
dev = F.ctx()
|
||||
# A: shape (N, M), X: shape (M, F)
|
||||
indptr = torch.tensor([0, 1, 3, 4]).to(dev)
|
||||
indices = torch.tensor([0, 0, 1, 1]).to(dev)
|
||||
val = torch.randn(len(indices), requires_grad=True, device=dev)
|
||||
A = create_from_csc(indptr, indices, val, shape=(2, 3))
|
||||
X = torch.randn(3, 4, requires_grad=True, device=dev)
|
||||
sparse_result = A @ X
|
||||
grad = torch.randn_like(sparse_result)
|
||||
sparse_result.backward(grad)
|
||||
|
||||
adj = get_adj(A)
|
||||
adj.requires_grad_()
|
||||
XX = X.clone().detach()
|
||||
XX.requires_grad_()
|
||||
dense_result = torch.sparse.mm(adj, XX)
|
||||
dense_result.backward(grad)
|
||||
assert torch.allclose(sparse_result, dense_result)
|
||||
assert torch.allclose(X.grad, XX.grad)
|
||||
assert torch.allclose(adj.grad.coalesce().values(), val.grad)
|
||||
Reference in New Issue
Block a user