einsum2 batched GEMM has 3.5x overhead vs Julia for small-m contractions #115

@shinaoka

Summary

einsum2_into_owned is 3.5x slower than Julia's direct BLAS.gemm! for the GEMM phase (after data is already contiguous), when m is small (m=4).

Benchmark (AMD EPYC 7713P)

Step 408 of tensornetwork_permutation_light_415: m=4, k=256, n=8192, 8 batches.

| GEMM (contiguous data) | 1T | 4T |
|---|---|---|
| Rust einsum2_into_owned (blas) | 63 ms | 62 ms |
| Julia mul! → BLAS.gemm! | 18 ms | 5 ms |

With 4 threads, Rust shows no GEMM speedup (63 ms → 62 ms), while Julia gets a 3.6x speedup (18 ms → 5 ms).
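
To put the timings in perspective, the measured wall times can be converted to achieved throughput. The sketch below assumes the usual 2·m·k·n flop count per GEMM; the helper names `batched_gemm_flops` and `gflops` are made up for illustration and are not part of the crate.

```rust
/// Floating-point operations for a batched GEMM, counting one multiply
/// and one add per inner-product term: 2 * m * k * n per batch.
fn batched_gemm_flops(m: u64, k: u64, n: u64, batches: u64) -> u64 {
    2 * m * k * n * batches
}

/// Achieved throughput in GFLOP/s for a wall time given in milliseconds.
fn gflops(flops: u64, millis: f64) -> f64 {
    flops as f64 / (millis * 1e-3) / 1e9
}
```

For step 408 (m=4, k=256, n=8192, 8 batches) this gives 134,217,728 flops, so 63 ms corresponds to roughly 2.1 GFLOP/s, 18 ms to roughly 7.5 GFLOP/s, and 5 ms to roughly 26.8 GFLOP/s.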

Reproduction

# Rust (measures "einsum2 (contiguous, ~GEMM)" line)
RAYON_NUM_THREADS=1 OMP_NUM_THREADS=1 \
  cargo run --release --no-default-features --features blas --bin step408_bench

# Julia (measures "BLAS gemm only" line)
OPENBLAS_NUM_THREADS=1 julia --project=. micro_bench/step408_fair.jl

Analysis: two layers of overhead

Layer 1: einsum2_into_owned wrapper cost (both backends)

On every invocation, einsum2_into_owned performs:

  • Einsum2Plan::new — axis classification and permutation computation via linear scans
  • validate_dimensions — scans all axis groups against all operand label arrays
  • prepare_input_owned → try_fuse_group, REQUIRES_UNIT_STRIDE checks, allocation

This overhead is independent of the GEMM backend (faer or blas) and dominates for small contractions.
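
One way to avoid repeating this work is to memoize the plan by operand labels and shapes. The following is only a sketch of that idea: `PlanKey`, `CachedPlan`, `PlanCache`, and `classify` are hypothetical stand-ins, and the real `Einsum2Plan` may expose a different interface and classify axes differently.

```rust
use std::collections::HashMap;

/// Hypothetical cache key: axis labels of A, B, C plus the input shapes.
#[derive(Clone, PartialEq, Eq, Hash)]
struct PlanKey {
    a_labels: Vec<char>,
    b_labels: Vec<char>,
    c_labels: Vec<char>,
    a_shape: Vec<usize>,
    b_shape: Vec<usize>,
}

/// Hypothetical stand-in for the result of axis classification.
#[derive(Debug, Clone, PartialEq)]
struct CachedPlan {
    batch: Vec<char>,      // in A, B, and C
    contracted: Vec<char>, // in A and B, not in C
    left: Vec<char>,       // in A only
    right: Vec<char>,      // in B only
}

/// Linear-scan classification, done once per distinct key.
fn classify(key: &PlanKey) -> CachedPlan {
    let mut plan = CachedPlan { batch: vec![], contracted: vec![], left: vec![], right: vec![] };
    for &l in &key.a_labels {
        match (key.b_labels.contains(&l), key.c_labels.contains(&l)) {
            (true, true) => plan.batch.push(l),
            (true, false) => plan.contracted.push(l),
            (false, _) => plan.left.push(l),
        }
    }
    for &l in &key.b_labels {
        if !key.a_labels.contains(&l) {
            plan.right.push(l);
        }
    }
    plan
}

struct PlanCache {
    plans: HashMap<PlanKey, CachedPlan>,
    builds: usize, // number of times a plan was actually constructed
}

impl PlanCache {
    fn new() -> Self {
        PlanCache { plans: HashMap::new(), builds: 0 }
    }

    /// Returns the cached plan, building it only on the first request.
    fn get_or_build(&mut self, key: &PlanKey) -> &CachedPlan {
        if !self.plans.contains_key(key) {
            self.builds += 1;
            self.plans.insert(key.clone(), classify(key));
        }
        &self.plans[key]
    }
}
```

Whether hashing the key is cheaper than rebuilding the plan would need to be measured; for a hot loop, hoisting the plan out of the loop by hand avoids the lookup entirely.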

Layer 2: per-batch GEMM dispatch (backend-specific)

The 8 batches are dispatched as 8 separate GEMM calls. Each call has backend-specific overhead.

Approach: Adopt OMEinsum.jl's strategy (both backends)

OMEinsum.jl (BatchedRoutines.jl):

for batch in 1:nb
    ccall(dgemm_, ..., ptrA + batch_offset, lda, ptrB + batch_offset, ldb, ...)
end

Key: plan is built once, then GEMM is a tight loop over raw pointers with stride parameters.
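
The same structure can be sketched in Rust. To keep the example self-contained, `gemm_colmajor` below is a naive stand-in for the real `dgemm` call, and `BatchLayout` is a hypothetical struct holding what a plan would compute once; slices with offsets are used in place of raw pointers.

```rust
/// Layout computed once, before the batch loop (hypothetical struct).
struct BatchLayout {
    m: usize,
    n: usize,
    k: usize,
    lda: usize,
    ldb: usize,
    ldc: usize,
    stride_a: usize, // elements between consecutive batches of A
    stride_b: usize,
    stride_c: usize,
    batches: usize,
}

/// Naive column-major C = A * B with leading dimensions.
/// Stand-in for a BLAS dgemm call; not a performant kernel.
fn gemm_colmajor(
    m: usize, n: usize, k: usize,
    a: &[f64], lda: usize,
    b: &[f64], ldb: usize,
    c: &mut [f64], ldc: usize,
) {
    for j in 0..n {
        for i in 0..m {
            let mut acc = 0.0;
            for p in 0..k {
                acc += a[i + p * lda] * b[p + j * ldb];
            }
            c[i + j * ldc] = acc;
        }
    }
}

/// Tight loop over batches: no per-batch planning, validation, or allocation,
/// only an offset computation and the GEMM call itself.
fn batched_gemm(layout: &BatchLayout, a: &[f64], b: &[f64], c: &mut [f64]) {
    for batch in 0..layout.batches {
        let a_off = batch * layout.stride_a;
        let b_off = batch * layout.stride_b;
        let c_off = batch * layout.stride_c;
        gemm_colmajor(
            layout.m, layout.n, layout.k,
            &a[a_off..], layout.lda,
            &b[b_off..], layout.ldb,
            &mut c[c_off..], layout.ldc,
        );
    }
}
```

In a real backend the inner call would be replaced by the BLAS or faer GEMM, with the leading dimensions and batch strides passed through unchanged.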

Proposed fix

  1. Reduce per-call overhead (both backends): Cache or hoist Einsum2Plan construction and validation outside the batch loop. The current bgemm_contiguous_into already receives pre-prepared operands, but the per-batch dispatch in the backend still has overhead.

  2. Batch loop with stride passthrough (backend-specific):

    • blas backend: Loop over batches calling cblas_dgemm directly with lda/ldb stride parameters, advancing raw pointers by batch stride.
    • faer backend: Loop over batches calling faer::linalg::matmul::matmul with MatRef/MatMut (which natively carry stride info), advancing by batch stride. faer already handles non-unit strides via its MatRef API.
  3. Skip multi-threading for small m: When m ≤ threshold, disable BLAS/faer threading for GEMM (the thread dispatch overhead exceeds the benefit).
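
Item 3 amounts to a small heuristic applied before dispatch. A minimal sketch, assuming a hypothetical `gemm_thread_count` helper and a threshold that would have to be tuned by benchmarking:

```rust
/// Picks the thread count for a GEMM call (hypothetical helper).
/// For m at or below `small_m_threshold` the call runs single-threaded;
/// otherwise all available threads are used (at least one).
fn gemm_thread_count(m: usize, available_threads: usize, small_m_threshold: usize) -> usize {
    if m <= small_m_threshold {
        1
    } else {
        available_threads.max(1)
    }
}
```

How the chosen count is applied depends on the backend and is not shown here.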

Context

Related: #114 (copy_into performance), #116 (pre-permutation dim fusion — the highest impact fix).
