
Use cblas_dgemm_batch when available instead of manual batch loop #125

@shinaoka


Summary

strided-einsum2/src/bgemm_blas.rs currently loops over the batch dimensions and calls cblas_dgemm once per slice (the do_batch closure). OpenBLAS 0.3.29+ provides cblas_dgemm_batch (pointer-array variant), which performs this loop internally and may allow BLAS-level optimizations.

Current state

  • cblas_dgemm_batch (pointer-array version): available in OpenBLAS >= 0.3.29
  • cblas_dgemm_batch_strided: NOT in 0.3.29 (added later), but not needed — the pointer-array version is more flexible
  • cblas-sys crate (0.2.0 / 0.3.0): does NOT export cblas_dgemm_batch bindings (it's a BLAS-like extension, not standard CBLAS)
  • cblas-inject feature: provides a runtime-registered CBLAS fallback; currently it only covers single GEMM

Proposal

  1. Add cblas_dgemm_batch / cblas_zgemm_batch via extern "C" declarations in bgemm_blas.rs
  2. Add cblas_dgemm_batch / cblas_zgemm_batch fallback to cblas-inject — implement as a loop over individual GEMM calls, so that the batch API is always available regardless of the underlying BLAS library
  3. Use the batch API uniformly in bgemm_contiguous_into() — no need for feature gates or runtime detection since cblas-inject provides the fallback
  4. For real BLAS libraries (OpenBLAS, MKL), the native cblas_dgemm_batch is called; for cblas-inject, the fallback loop runs
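Step 1 could look roughly like the following Rust declaration, mirroring the C signature given in the next section. This is a sketch under stated assumptions: the enum names `CblasLayout` / `CblasTranspose` are hypothetical (not taken from cblas-sys), their discriminants follow the standard cblas.h values, and `MKL_INT` / OpenBLAS `blasint` is assumed to be a 32-bit `c_int` (LP64 builds; ILP64 builds would need `i64`). Read-only arrays are declared `*const`, which does not change the ABI.

```rust
use std::os::raw::c_int;

/// Layout values as defined by the standard cblas.h header.
/// `#[repr(C)]` gives the enum the size and ABI of a C enum.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CblasLayout {
    RowMajor = 101,
    ColMajor = 102,
}

/// Transpose values as defined by the standard cblas.h header.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CblasTranspose {
    NoTrans = 111,
    Trans = 112,
    ConjTrans = 113,
}

unsafe extern "C" {
    /// Pointer-array batch GEMM (BLAS-like extension, not standard CBLAS).
    /// Assumes LP64 integers (MKL_INT / blasint = c_int).
    pub fn cblas_dgemm_batch(
        layout: CblasLayout,
        transa_array: *const CblasTranspose,
        transb_array: *const CblasTranspose,
        m_array: *const c_int,
        n_array: *const c_int,
        k_array: *const c_int,
        alpha_array: *const f64,
        a_array: *const *const f64,
        lda_array: *const c_int,
        b_array: *const *const f64,
        ldb_array: *const c_int,
        beta_array: *const f64,
        c_array: *const *mut f64,
        ldc_array: *const c_int,
        group_count: c_int,
        group_size: *const c_int,
    );
}
```

Declaring the symbol this way only links when it is actually called, so the fallback in cblas-inject (or a native BLAS) must provide it at link time.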

API signature (OpenBLAS/MKL)

void cblas_dgemm_batch(
    CBLAS_LAYOUT layout,
    CBLAS_TRANSPOSE *transa_array,
    CBLAS_TRANSPOSE *transb_array,
    MKL_INT *m_array, MKL_INT *n_array, MKL_INT *k_array,
    double *alpha_array,
    const double **a_array, MKL_INT *lda_array,
    const double **b_array, MKL_INT *ldb_array,
    double *beta_array,
    double **c_array, MKL_INT *ldc_array,
    MKL_INT group_count,
    MKL_INT *group_size
);

For uniform batches: group_count=1, group_size=[batch_count].
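As an illustration, the arrays for a uniform batch of row-major, non-transposed GEMMs over contiguous, equally strided operands could be assembled as below. `UniformBatch` and `uniform_batch` are hypothetical names, not existing code in bgemm_blas.rs; the sketch assumes `MKL_INT` = `i32` (LP64) and leaves out the layout and transpose arguments.

```rust
/// Hypothetical argument bundle for a uniform `cblas_dgemm_batch` call:
/// group_count = 1, so every per-group array has length 1, while the
/// A/B/C pointer arrays hold one entry per batch element.
#[allow(dead_code)] // fields are consumed by the FFI call, which is not shown
struct UniformBatch {
    m: [i32; 1],
    n: [i32; 1],
    k: [i32; 1],
    lda: [i32; 1],
    ldb: [i32; 1],
    ldc: [i32; 1],
    alpha: [f64; 1],
    beta: [f64; 1],
    group_size: [i32; 1],
    a_ptrs: Vec<*const f64>,
    b_ptrs: Vec<*const f64>,
    c_ptrs: Vec<*mut f64>,
}

/// Builds the arguments for `batch` row-major, non-transposed products
/// C_i = alpha * A_i * B_i + beta * C_i, with A_i (m x k), B_i (k x n)
/// and C_i (m x n) stored back to back in `a`, `b` and `c`.
fn uniform_batch(
    a: &[f64],
    b: &[f64],
    c: &mut [f64],
    (m, n, k): (usize, usize, usize),
    batch: usize,
    alpha: f64,
    beta: f64,
) -> UniformBatch {
    assert!(a.len() >= batch * m * k && b.len() >= batch * k * n && c.len() >= batch * m * n);
    let c_base = c.as_mut_ptr();
    UniformBatch {
        m: [m as i32],
        n: [n as i32],
        k: [k as i32],
        // Row-major without transposition: leading dimension = column count.
        lda: [k as i32],
        ldb: [n as i32],
        ldc: [n as i32],
        alpha: [alpha],
        beta: [beta],
        // A single group containing the whole batch.
        group_size: [batch as i32],
        // One pointer per matrix, offset by the batch stride.
        a_ptrs: (0..batch).map(|i| a.as_ptr().wrapping_add(i * m * k)).collect(),
        b_ptrs: (0..batch).map(|i| b.as_ptr().wrapping_add(i * k * n)).collect(),
        c_ptrs: (0..batch).map(|i| c_base.wrapping_add(i * m * n)).collect(),
    }
}
```

The per-group arrays stay at length 1 regardless of batch size; only the pointer arrays grow with `batch`.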

cblas-inject fallback implementation

// In cblas-inject crate
pub unsafe extern "C" fn cblas_dgemm_batch(
    layout: ..., transa_array: ..., transb_array: ...,
    m_array: ..., n_array: ..., k_array: ...,
    alpha_array: ...,
    a_array: ..., lda_array: ...,
    b_array: ..., ldb_array: ...,
    beta_array: ...,
    c_array: ..., ldc_array: ...,
    group_count: ..., group_size: ...
) {
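    // Loop over groups, then over the matrices in each group, calling
    // cblas_dgemm once per matrix multiply. Per-group parameters
    // (transa, transb, m, n, k, alpha, lda, ldb, beta, ldc) are indexed
    // by `group`; the pointer arrays a_array, b_array and c_array hold one
    // entry per matrix and are indexed by a running counter `idx`.
    let mut idx = 0;
    for group in 0..group_count {
        for _ in 0..group_size[group] {
            cblas_dgemm(layout, transa_array[group], transb_array[group], ...,
                        a_array[idx], ..., b_array[idx], ..., c_array[idx], ...);
            idx += 1;
        }
    }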
}
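To make the group/index bookkeeping concrete, the fallback's semantics can be modelled in safe Rust. This is a sketch, not the cblas-inject API: `dgemm` is a naive stand-in for a single `cblas_dgemm` call (row-major, no transpose, tightly packed operands), and `Group` / `dgemm_batch_fallback` are hypothetical names.

```rust
/// Naive stand-in for one `cblas_dgemm` call (row-major, no transpose,
/// tightly packed): c = alpha * a * b + beta * c, a: m x k, b: k x n.
fn dgemm(m: usize, n: usize, k: usize, alpha: f64, a: &[f64], b: &[f64], beta: f64, c: &mut [f64]) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            // As in BLAS, C is not read when beta == 0.
            let prev = if beta == 0.0 { 0.0 } else { beta * c[i * n + j] };
            c[i * n + j] = alpha * acc + prev;
        }
    }
}

/// Per-group parameters, mirroring the per-group arrays of the C
/// signature (transposes and leading dimensions omitted here).
struct Group {
    m: usize,
    n: usize,
    k: usize,
    alpha: f64,
    beta: f64,
    size: usize, // group_size[g]
}

/// Fallback loop: group parameters are indexed by group, while the
/// matrix slices are indexed by a running counter across all groups.
fn dgemm_batch_fallback(groups: &[Group], a: &[&[f64]], b: &[&[f64]], c: &mut [&mut [f64]]) {
    let total: usize = groups.iter().map(|g| g.size).sum();
    assert!(a.len() == total && b.len() == total && c.len() == total);
    let mut idx = 0;
    for g in groups {
        for _ in 0..g.size {
            dgemm(g.m, g.n, g.k, g.alpha, a[idx], b[idx], g.beta, &mut *c[idx]);
            idx += 1;
        }
    }
}
```

With two groups (two 2x2 products, then one 1x2 by 2x1 product), the third multiply uses the second group's parameters but the third entry of each matrix list, which is exactly why a separate running index is needed.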

Expected impact

Likely small for einsum2 (batch dimensions are typically small). Main benefit is cleaner code and potential BLAS-internal parallelization of the batch loop on capable libraries.

Risk

Low. cblas-inject fallback ensures correctness everywhere; native BLAS libraries get the optimized path.
