Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 53 additions & 35 deletions src/eckit/linalg/sparse/LinearAlgebraMKL.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,49 @@ namespace eckit {
namespace linalg {
namespace sparse {

namespace {

struct MKLSparseHandle {
sparse_matrix_t value = nullptr;

~MKLSparseHandle() {
if (value != nullptr) {
mkl_sparse_destroy(value);
}
}

operator sparse_matrix_t() const { return value; }
};

void checkSparseStatus(sparse_status_t status, const char* call) {
if (status != SPARSE_STATUS_SUCCESS) {
throw SeriousBug(std::string(call) + " failed with MKL sparse status " + std::to_string(status), Here());
}
}

matrix_descr generalMatrixDescription() {
matrix_descr description{};
description.type = SPARSE_MATRIX_TYPE_GENERAL;
description.mode = SPARSE_FILL_MODE_FULL;
description.diag = SPARSE_DIAG_NON_UNIT;
return description;
}

MKLSparseHandle createSparseMatrixHandle(const SparseMatrix& A) {
MKLSparseHandle handle;
auto* outer = const_cast<MKL_INT*>(static_cast<const MKL_INT*>(A.outer()));
auto* inner = const_cast<MKL_INT*>(static_cast<const MKL_INT*>(A.inner()));
auto* data = const_cast<double*>(static_cast<const double*>(A.data()));

checkSparseStatus(mkl_sparse_d_create_csr(&handle.value, SPARSE_INDEX_BASE_ZERO, static_cast<MKL_INT>(A.rows()),
static_cast<MKL_INT>(A.cols()), outer, outer + 1, inner, data),
"mkl_sparse_d_create_csr");
checkSparseStatus(mkl_sparse_optimize(handle), "mkl_sparse_optimize");
return handle;
}

} // namespace


static const LinearAlgebraMKL __la("mkl");

Expand All @@ -45,21 +88,12 @@ void LinearAlgebraMKL::spmv(const SparseMatrix& A, const Vector& x, Vector& y) c
// We expect indices to be 0-based
ASSERT(A.outer()[0] == 0);

const auto m = static_cast<MKL_INT>(A.rows());
const auto k = static_cast<MKL_INT>(A.cols());

// FIXME: mkl_dcsrmv is deprecated, use mkl_sparse_d_mv instead
// void mkl_dcsrmv(const char *transa, const MKL_INT *m, const MKL_INT *k,
// const double *alpha, const char *matdescra,
// const double *val, const MKL_INT *indx, const MKL_INT *pntrb, const MKL_INT *pntre,
// const double *x, const double *beta, double *y);

const auto* matrix = static_cast<const double*>(A.data());
const auto* inner = static_cast<const MKL_INT*>(A.inner());
const auto* outer = static_cast<const MKL_INT*>(A.outer());
const auto* vector = static_cast<const double*>(x.data());
const auto description = generalMatrixDescription();
const auto handle = createSparseMatrixHandle(A);

mkl_dcsrmv("N", &m, &k, &alpha, "G__C", matrix, inner, outer, outer + 1, vector, &beta, y.data());
checkSparseStatus(
mkl_sparse_d_mv(SPARSE_OPERATION_NON_TRANSPOSE, alpha, handle, description, x.data(), beta, y.data()),
"mkl_sparse_d_mv");
}


Expand All @@ -75,28 +109,12 @@ void LinearAlgebraMKL::spmm(const SparseMatrix& A, const Matrix& B, Matrix& C) c
const auto n = static_cast<MKL_INT>(C.cols());
const auto k = static_cast<MKL_INT>(A.cols());

// FIXME: with 0-based indexing, MKL assumes row-major ordering for B and C
// We need to use 1-based indexing i.e. offset outer and inner indices by 1

std::vector<MKL_INT> pntrb(A.rows() + 1);
for (Size i = 0; i < A.rows() + 1; ++i) {
pntrb[i] = A.outer()[i] + 1;
}

std::vector<MKL_INT> indx(A.nonZeros());
for (Size i = 0; i < A.nonZeros(); ++i) {
indx[i] = A.inner()[i] + 1;
}

// FIXME: mkl_dcsrmm is deprecated, use mkl_sparse_d_mm instead
// void mkl_dcsrmm(const char *transa, const MKL_INT *m, const MKL_INT *n, const MKL_INT *k,
// const double *alpha, const char *matdescra,
// const double *val, const MKL_INT *indx, const MKL_INT *pntrb, const MKL_INT *pntre,
// const double *b, const MKL_INT *ldb, const double *beta,
// double *c, const MKL_INT *ldc);
const auto description = generalMatrixDescription();
const auto handle = createSparseMatrixHandle(A);

mkl_dcsrmm("N", &m, &n, &k, &alpha, "G__F", A.data(), indx.data(), pntrb.data(), pntrb.data() + 1, B.data(), &k,
&beta, C.data(), &m);
checkSparseStatus(mkl_sparse_d_mm(SPARSE_OPERATION_NON_TRANSPOSE, alpha, handle, description,
SPARSE_LAYOUT_COLUMN_MAJOR, B.data(), n, k, beta, C.data(), m),
"mkl_sparse_d_mm");
}


Expand Down
Loading