diff --git a/src/eckit/linalg/sparse/LinearAlgebraMKL.cc b/src/eckit/linalg/sparse/LinearAlgebraMKL.cc index 9e0a62206..1b036dff1 100644 --- a/src/eckit/linalg/sparse/LinearAlgebraMKL.cc +++ b/src/eckit/linalg/sparse/LinearAlgebraMKL.cc @@ -26,6 +26,49 @@ namespace eckit { namespace linalg { namespace sparse { +namespace { + +struct MKLSparseHandle { + sparse_matrix_t value = nullptr; + + ~MKLSparseHandle() { + if (value != nullptr) { + mkl_sparse_destroy(value); + } + } + + operator sparse_matrix_t() const { return value; } +}; + +void checkSparseStatus(sparse_status_t status, const char* call) { + if (status != SPARSE_STATUS_SUCCESS) { + throw SeriousBug(std::string(call) + " failed with MKL sparse status " + std::to_string(status), Here()); + } +} + +matrix_descr generalMatrixDescription() { + matrix_descr description{}; + description.type = SPARSE_MATRIX_TYPE_GENERAL; + description.mode = SPARSE_FILL_MODE_FULL; + description.diag = SPARSE_DIAG_NON_UNIT; + return description; +} + +MKLSparseHandle createSparseMatrixHandle(const SparseMatrix& A) { + MKLSparseHandle handle; + auto* outer = const_cast(static_cast(A.outer())); + auto* inner = const_cast(static_cast(A.inner())); + auto* data = const_cast(static_cast(A.data())); + + checkSparseStatus(mkl_sparse_d_create_csr(&handle.value, SPARSE_INDEX_BASE_ZERO, static_cast(A.rows()), + static_cast(A.cols()), outer, outer + 1, inner, data), + "mkl_sparse_d_create_csr"); + checkSparseStatus(mkl_sparse_optimize(handle), "mkl_sparse_optimize"); + return handle; +} + +} // namespace + static const LinearAlgebraMKL __la("mkl"); @@ -45,21 +88,12 @@ void LinearAlgebraMKL::spmv(const SparseMatrix& A, const Vector& x, Vector& y) c // We expect indices to be 0-based ASSERT(A.outer()[0] == 0); - const auto m = static_cast(A.rows()); - const auto k = static_cast(A.cols()); - - // FIXME: mkl_dcsrmv is deprecated, use mkl_sparse_d_mv instead - // void mkl_dcsrmv(const char *transa, const MKL_INT *m, const MKL_INT *k, - // const double *alpha, const char *matdescra, - // const double *val, const MKL_INT *indx, const MKL_INT *pntrb, const MKL_INT *pntre, - // const double *x, const double *beta, double *y); - - const auto* matrix = static_cast(A.data()); - const auto* inner = static_cast(A.inner()); - const auto* outer = static_cast(A.outer()); - const auto* vector = static_cast(x.data()); + const auto description = generalMatrixDescription(); + const auto handle = createSparseMatrixHandle(A); - mkl_dcsrmv("N", &m, &k, &alpha, "G__C", matrix, inner, outer, outer + 1, vector, &beta, y.data()); + checkSparseStatus( + mkl_sparse_d_mv(SPARSE_OPERATION_NON_TRANSPOSE, alpha, handle, description, x.data(), beta, y.data()), + "mkl_sparse_d_mv"); } @@ -75,28 +109,12 @@ void LinearAlgebraMKL::spmm(const SparseMatrix& A, const Matrix& B, Matrix& C) c const auto n = static_cast(C.cols()); const auto k = static_cast(A.cols()); - // FIXME: with 0-based indexing, MKL assumes row-major ordering for B and C - // We need to use 1-based indexing i.e. offset outer and inner indices by 1 - - std::vector pntrb(A.rows() + 1); - for (Size i = 0; i < A.rows() + 1; ++i) { - pntrb[i] = A.outer()[i] + 1; - } - - std::vector indx(A.nonZeros()); - for (Size i = 0; i < A.nonZeros(); ++i) { - indx[i] = A.inner()[i] + 1; - } - - // FIXME: mkl_dcsrmm is deprecated, use mkl_sparse_d_mm instead - // void mkl_dcsrmm(const char *transa, const MKL_INT *m, const MKL_INT *n, const MKL_INT *k, - // const double *alpha, const char *matdescra, - // const double *val, const MKL_INT *indx, const MKL_INT *pntrb, const MKL_INT *pntre, - // const double *b, const MKL_INT *ldb, const double *beta, - // double *c, const MKL_INT *ldc); + const auto description = generalMatrixDescription(); + const auto handle = createSparseMatrixHandle(A); - mkl_dcsrmm("N", &m, &n, &k, &alpha, "G__F", A.data(), indx.data(), pntrb.data(), pntrb.data() + 1, B.data(), &k, - &beta, C.data(), &m); + checkSparseStatus(mkl_sparse_d_mm(SPARSE_OPERATION_NON_TRANSPOSE, alpha, handle, description, + SPARSE_LAYOUT_COLUMN_MAJOR, B.data(), n, k, beta, C.data(), m), + "mkl_sparse_d_mm"); }