Summary
einsum2_into_owned is 3.5x slower than Julia's direct BLAS.gemm! for the GEMM phase (after data is already contiguous), when m is small (m=4).
Benchmark (AMD EPYC 7713P)
Step 408 of tensornetwork_permutation_light_415: m=4, k=256, n=8192, 8 batches.
| GEMM (contiguous data) | 1T | 4T |
|---|---|---|
| Rust einsum2_into_owned (blas) | 63 ms | 62 ms |
| Julia mul! → BLAS.gemm! | 18 ms | 5 ms |
With 4 threads, Rust shows no GEMM speedup (63 ms to 62 ms), while Julia gets a 3.6x speedup (18 ms to 5 ms).
Reproduction
# Rust (measures "einsum2 (contiguous, ~GEMM)" line)
RAYON_NUM_THREADS=1 OMP_NUM_THREADS=1 \
cargo run --release --no-default-features --features blas --bin step408_bench
# Julia (measures "BLAS gemm only" line)
OPENBLAS_NUM_THREADS=1 julia --project=. micro_bench/step408_fair.jl
Analysis: two layers of overhead
Layer 1: einsum2_into_owned wrapper cost (both backends)
On every invocation, einsum2_into_owned performs:
- Einsum2Plan::new: axis classification and permutation computation via linear scans
- validate_dimensions: scans all axis groups against all operand label arrays
- prepare_input_owned: try_fuse_group, REQUIRES_UNIT_STRIDE checks, allocation
This overhead is independent of the GEMM backend (faer or blas) and dominates for small contractions.
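One way to pay the Layer 1 cost only once is to build the plan on first use and reuse it for later calls with the same labels. The following is a minimal sketch under stated assumptions, not the crate's actual API: `Plan`, `PlanCache`, and `classify` are hypothetical stand-ins for `Einsum2Plan` and `Einsum2Plan::new`, and the classification is reduced to the bare batch/left/right/sum grouping.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for `Einsum2Plan`: axis labels grouped by role.
#[derive(Debug, Clone, PartialEq)]
struct Plan {
    batch: Vec<char>, // labels present in A, B, and C
    left: Vec<char>,  // labels in A and C only (the "m" group)
    right: Vec<char>, // labels in B and C only (the "n" group)
    sum: Vec<char>,   // labels in A and B only (the contracted "k" group)
}

/// Classify labels with linear scans: the work we want to do only once.
fn classify(a: &[char], b: &[char], c: &[char]) -> Plan {
    let mut plan = Plan { batch: vec![], left: vec![], right: vec![], sum: vec![] };
    for &l in a {
        match (b.contains(&l), c.contains(&l)) {
            (true, true) => plan.batch.push(l),
            (true, false) => plan.sum.push(l),
            (false, true) => plan.left.push(l),
            (false, false) => {} // label only in A: ignored in this sketch
        }
    }
    for &l in b {
        if !a.contains(&l) && c.contains(&l) {
            plan.right.push(l);
        }
    }
    plan
}

type Key = (Vec<char>, Vec<char>, Vec<char>);

/// Hypothetical cache keyed on the three label lists.
#[derive(Default)]
struct PlanCache {
    plans: HashMap<Key, Plan>,
    builds: usize, // counts how often `classify` actually ran
}

impl PlanCache {
    fn get(&mut self, a: &[char], b: &[char], c: &[char]) -> &Plan {
        let key = (a.to_vec(), b.to_vec(), c.to_vec());
        let builds = &mut self.builds;
        self.plans.entry(key).or_insert_with(|| {
            *builds += 1;
            classify(a, b, c)
        })
    }
}

fn main() {
    let mut cache = PlanCache::default();
    let (a, b, c) = (['b', 'i', 'k'], ['b', 'k', 'j'], ['b', 'i', 'j']);
    for _ in 0..1000 {
        let plan = cache.get(&a, &b, &c);
        assert_eq!(plan.sum, vec!['k']);
    }
    println!("plans built: {}", cache.builds);
}
```

Building the lookup key still allocates on every call, so inside a hot loop it is likely cheaper to hoist the plan into a local variable than to go through a cache at all. The cache form is mainly useful when the same contraction recurs across separate call sites.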
Layer 2: per-batch GEMM dispatch (backend-specific)
The 8 batches are dispatched as 8 separate GEMM calls. Each call has backend-specific overhead.
Approach: Adopt OMEinsum.jl's strategy (both backends)
OMEinsum.jl (BatchedRoutines.jl):
for batch in 1:nb
ccall(dgemm_, ..., ptrA + batch_offset, lda, ptrB + batch_offset, ldb, ...)
end
Key: the plan is built once; the GEMM phase is then a tight loop over raw pointers with stride parameters.
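The same structure can be sketched in Rust. This is an illustration under assumptions, not the existing backend code: `gemm_colmajor` is a naive stand-in for `cblas_dgemm` so the example runs without linking BLAS, `BatchedGemm` is a hypothetical descriptor holding everything computed up front, and slice offsets replace raw pointer arithmetic.

```rust
/// Naive stand-in for `cblas_dgemm` (column-major, no transpose): C = A * B.
/// `lda`, `ldb`, `ldc` are leading dimensions, as in BLAS.
fn gemm_colmajor(
    m: usize, n: usize, k: usize,
    a: &[f64], lda: usize,
    b: &[f64], ldb: usize,
    c: &mut [f64], ldc: usize,
) {
    for j in 0..n {
        for i in 0..m {
            let mut acc = 0.0;
            for p in 0..k {
                acc += a[i + p * lda] * b[p + j * ldb];
            }
            c[i + j * ldc] = acc;
        }
    }
}

/// Hypothetical descriptor: shapes and strides are computed once, up front.
struct BatchedGemm {
    m: usize,
    n: usize,
    k: usize,
    batches: usize,
    lda: usize,
    ldb: usize,
    ldc: usize,
    stride_a: usize, // elements between consecutive A batches
    stride_b: usize, // elements between consecutive B batches
    stride_c: usize, // elements between consecutive C batches
}

impl BatchedGemm {
    /// The hot path: no planning or validation, only offset arithmetic.
    fn run(&self, a: &[f64], b: &[f64], c: &mut [f64]) {
        for batch in 0..self.batches {
            gemm_colmajor(
                self.m, self.n, self.k,
                &a[batch * self.stride_a..], self.lda,
                &b[batch * self.stride_b..], self.ldb,
                &mut c[batch * self.stride_c..], self.ldc,
            );
        }
    }
}

fn main() {
    // Two batches of a (2x3) * (3x2) product, stored back to back.
    let g = BatchedGemm {
        m: 2, n: 2, k: 3, batches: 2,
        lda: 2, ldb: 3, ldc: 2,
        stride_a: 6, stride_b: 6, stride_c: 4,
    };
    let a: Vec<f64> = (1..=12).map(f64::from).collect();
    let b = vec![1.0; 12];
    let mut c = vec![0.0; 8];
    g.run(&a, &b, &mut c);
    println!("{:?}", c); // [9.0, 12.0, 9.0, 12.0, 27.0, 30.0, 27.0, 30.0]
}
```

With a real BLAS the inner call would be replaced by the library routine, and the per-batch work would shrink to three offset computations plus the call itself.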
Proposed fix
- Reduce per-call overhead (both backends): Cache or hoist Einsum2Plan construction and validation outside the batch loop. The current bgemm_contiguous_into already receives pre-prepared operands, but the per-batch dispatch in the backend still has overhead.
- Batch loop with stride passthrough (backend-specific):
  - blas backend: Loop over batches calling cblas_dgemm directly with lda/ldb stride parameters, advancing raw pointers by batch stride.
  - faer backend: Loop over batches calling faer::linalg::matmul::matmul with MatRef/MatMut (which natively carry stride info), advancing by batch stride. faer already handles non-unit strides via its MatRef API.
- Skip multi-threading for small m: When m ≤ threshold, disable BLAS/faer threading for GEMM (the thread dispatch overhead exceeds the benefit).
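The small-m rule in the last item could look roughly like the sketch below. The function name `gemm_threads` and the constant `SMALL_M_THRESHOLD` are hypothetical, and the threshold value is a placeholder that would need to be tuned by benchmarking. How the chosen count is handed to the backend is backend-specific and not shown.

```rust
/// Assumed placeholder value; the right cutoff has to come from measurements.
const SMALL_M_THRESHOLD: usize = 16;

/// Hypothetical heuristic: run small-m GEMMs single-threaded, otherwise use
/// all available threads (at least one).
fn gemm_threads(m: usize, max_threads: usize) -> usize {
    if m <= SMALL_M_THRESHOLD {
        1
    } else {
        max_threads.max(1)
    }
}

fn main() {
    // Step 408 has m = 4, so it would take the single-threaded path here.
    println!("m=4   -> {} thread(s)", gemm_threads(4, 4));
    println!("m=512 -> {} thread(s)", gemm_threads(512, 4));
}
```

Keying only on m follows the proposal as written. A rule based on total work (m·n·k) is another option, and it may be worth comparing given that n is large in this benchmark.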
Context
Related: #114 (copy_into performance), #116 (pre-permutation dim fusion — the highest impact fix).