
perf: avoid finalize copy in einsum2 by deferring output permutation (lazy permute) #128

@shinaoka


Summary

einsum2_dispatch always materializes the output via finalize_into when the output permutation makes the C dimensions non-fusable. This causes a full copy of the output tensor at the end of every contraction step whose output permutation is non-fusable.

A "lazy permute" approach (as implemented in tenferro-einsum) can avoid this copy for intermediate steps by returning a non-contiguous view instead of a contiguous array. The next step's prepare_input_view or prepare_input_owned then handles the non-contiguous input. If that input happens to be fusable for the next GEMM, the copy is eliminated completely; otherwise the permutation cost is deferred to that step rather than paid at finalize time.
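"Fusable" here roughly means that a group of C dimensions (for example the n_lo left-output dims that become one GEMM matrix dimension) can be merged into a single strided dimension. A minimal sketch of such a check, assuming column-major layout and element-unit strides; is_fusable is a hypothetical helper, not necessarily the exact rule used in strided-einsum2 (which, for instance, may also skip size-1 dims):

```rust
/// Returns true if all dimensions of a strided tensor can be fused into one
/// strided run, assuming column-major order: dim k must start exactly where
/// dim k-1 ends, i.e. strides[k] == strides[k-1] * shape[k-1].
/// (Simplified sketch: size-1 dims are not special-cased.)
fn is_fusable(shape: &[usize], strides: &[usize]) -> bool {
    shape.len() == strides.len()
        && (1..shape.len()).all(|k| strides[k] == strides[k - 1] * shape[k - 1])
}

fn main() {
    // A 3x4 block with column-major strides (1, 3): fusable into 12 elements.
    assert!(is_fusable(&[3, 4], &[1, 3]));

    // The same buffer with the two axes swapped (a lazy transpose):
    // shape (4, 3), strides (3, 1). Not fusable, so a GEMM cannot treat
    // it as one contiguous dimension without a copy.
    assert!(!is_fusable(&[4, 3], &[3, 1]));

    // A permutation can still leave a sub-group fusable: (2, 3, 4) with
    // strides (1, 2, 6) permuted to (4, 2, 3) has strides (6, 1, 2), and the
    // trailing group (2, 3) with strides (1, 2) is still fusable.
    assert!(is_fusable(&[2, 3], &[1, 2]));
    println!("fusability checks passed");
}
```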

Current behavior

In einsum2_dispatch (lib.rs):

// 3. Prepare output
let c_op = prepare_output_view(&mut c_perm, n_lo, n_ro, beta, ...)?;
// 4. GEMM
B::bgemm_contiguous_into(&mut c_op, &a_op, &b_op, ...)?;
// 5. Finalize — copies temp → c_perm when output was non-fusable
c_op.finalize_into(&mut c_perm)?;

When prepare_output_view detects non-fusable strides, it allocates a temporary contiguous buffer for the GEMM, and finalize_into copies the result back into c_perm. For intermediate steps (beta = 0), the copy-in is skipped, but the copy-back always happens.
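To make the cost concrete, here is a simplified model of that path. prepare_temp and finalize_into are hypothetical free functions standing in for prepare_output_view and ContiguousOperandMut::finalize_into, and the temporary is assumed to be column-major. The point is that the copy-back scatters every element of the output, whatever beta is:

```rust
/// Maps a linear index of a contiguous column-major array with `shape`
/// to an element offset in a strided array with `strides`.
fn strided_offset(mut lin: usize, shape: &[usize], strides: &[usize]) -> usize {
    let mut off = 0;
    for (&n, &s) in shape.iter().zip(strides) {
        off += (lin % n) * s;
        lin /= n;
    }
    off
}

/// Hypothetical model of prepare_output_view for a non-fusable destination:
/// allocate a contiguous temporary, copying the old contents in only if beta != 0.
fn prepare_temp(dst: &[f64], shape: &[usize], strides: &[usize], beta: f64) -> Vec<f64> {
    let len: usize = shape.iter().product();
    if beta == 0.0 {
        vec![0.0; len] // intermediate step: copy-in skipped
    } else {
        (0..len).map(|lin| dst[strided_offset(lin, shape, strides)]).collect()
    }
}

/// Hypothetical model of finalize_into: scatter the contiguous GEMM result
/// into the strided destination. Returns the number of elements copied,
/// which is always the full tensor.
fn finalize_into(temp: &[f64], dst: &mut [f64], shape: &[usize], strides: &[usize]) -> usize {
    for (lin, &v) in temp.iter().enumerate() {
        dst[strided_offset(lin, shape, strides)] = v;
    }
    temp.len()
}

fn main() {
    // A 2x3 logical output stored transposed in c_perm: strides (3, 1).
    let shape = [2, 3];
    let strides = [3, 1];
    let mut dst = vec![0.0; 6];
    let mut temp = prepare_temp(&dst, &shape, &strides, 0.0);
    // Stand-in for the GEMM writing its result contiguously.
    for (i, x) in temp.iter_mut().enumerate() {
        *x = i as f64;
    }
    let copied = finalize_into(&temp, &mut dst, &shape, &strides);
    assert_eq!(copied, 6);
    // Element (i, j) = temp[i + 2*j] lands at dst[3*i + j].
    assert_eq!(dst, vec![0.0, 2.0, 4.0, 1.0, 3.0, 5.0]);
}
```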

Proposed behavior

For intermediate contraction steps (where the output is consumed by a subsequent step), skip finalize_into and instead return the GEMM output as a non-contiguous view with rearranged strides (lazy permute). The downstream step's prepare_input_owned already handles non-contiguous inputs.
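The mechanism behind lazy permute is that permuting a strided tensor only reorders its shape and strides; no element moves. A minimal sketch with a hypothetical View type (not the strided-rs or tenferro-einsum API), assuming element-unit strides and a column-major GEMM output:

```rust
/// Hypothetical minimal strided view: a borrowed buffer plus per-dimension
/// shape and strides (in elements).
#[derive(Clone)]
struct View<'a> {
    data: &'a [f64],
    shape: Vec<usize>,
    strides: Vec<usize>,
}

impl<'a> View<'a> {
    /// Lazy permute: output axis k is input axis perm[k].
    /// Only metadata is rearranged; the buffer is shared, not copied.
    fn permute(&self, perm: &[usize]) -> View<'a> {
        View {
            data: self.data,
            shape: perm.iter().map(|&p| self.shape[p]).collect(),
            strides: perm.iter().map(|&p| self.strides[p]).collect(),
        }
    }

    /// Reads one element through the strides.
    fn get(&self, idx: &[usize]) -> f64 {
        let off: usize = idx.iter().zip(&self.strides).map(|(i, s)| i * s).sum();
        self.data[off]
    }
}

fn main() {
    // GEMM output for a 2x3 result, column-major: strides (1, 2).
    let buf: Vec<f64> = (0..6).map(|x| x as f64).collect();
    let c = View { data: &buf, shape: vec![2, 3], strides: vec![1, 2] };

    // The einsum output wants the axes in order (1, 0): shape (3, 2).
    let out = c.permute(&[1, 0]);
    assert_eq!(out.shape, vec![3, 2]);
    assert_eq!(out.strides, vec![2, 1]); // non-contiguous view

    // out[j, i] == c[i, j], read through the same buffer.
    for i in 0..2 {
        for j in 0..3 {
            assert_eq!(out.get(&[j, i]), c.get(&[i, j]));
        }
    }
    assert!(std::ptr::eq(out.data.as_ptr(), buf.as_ptr()));
}
```

Whether the next step can consume such a view without copying then depends on whether its strides are fusable for that step's GEMM.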

Benchmark evidence

Comparing strided-rs vs tenferro-einsum (which implements lazy permute) on gm_queen5_5_3.wcsp:

| Strategy | strided-rs | tenferro-einsum | Ratio (tenferro / strided-rs) |
|---|---|---|---|
| opt_flops (148 steps, large intermediates) | 8116 ms | 7083 ms | 0.87x |
| opt_size (159 steps, small intermediates) | 2426 ms | 2753 ms | 1.13x |
  • opt_flops: tenferro takes 13% less time. Large intermediates make the finalize copy expensive, and lazy permute avoids on the order of gigabytes of total copies across the 148 steps.
  • opt_size: tenferro takes 13% more time (strided-rs is about 12% faster). Small intermediates make finalize copies cheap, so tenferro's multi-layer dispatch overhead (~2 ms/step) dominates.
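The ratios and the overhead explanation can be checked directly from the reported wall times: 7083/8116 ≈ 0.87, 2753/2426 ≈ 1.13, and 159 steps at ~2 ms/step gives about 318 ms, close to the observed 327 ms gap in opt_size. A small program that recomputes these numbers (the ~2 ms/step figure is taken from the text above, not measured here):

```rust
/// Ratio of tenferro-einsum time to strided-rs time (below 1.0: tenferro is faster).
fn ratio(tenferro_ms: f64, strided_ms: f64) -> f64 {
    tenferro_ms / strided_ms
}

fn main() {
    // Wall times (ms) from the gm_queen5_5_3.wcsp table.
    let flops = ratio(7083.0, 8116.0);
    let size = ratio(2753.0, 2426.0);
    println!("opt_flops ratio: {:.2}x, opt_size ratio: {:.2}x", flops, size);

    // Rough dispatch-overhead check for opt_size: 159 steps at ~2 ms/step,
    // compared with the observed gap between the two implementations.
    let overhead_ms = 159.0 * 2.0;
    let gap_ms = 2753.0 - 2426.0;
    println!("estimated overhead: {} ms, observed gap: {} ms", overhead_ms, gap_ms);
}
```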

The opt_flops result suggests that eliminating finalize copies can yield significant speedups for workloads with large intermediate tensors.

Affected code

  • strided-einsum2/src/lib.rs: einsum2_dispatch — the finalize_into call
  • strided-einsum2/src/contiguous.rs: ContiguousOperandMut::finalize_into
  • strided-opteinsum/src/expr.rs: eval_pair_alloc — would need to support returning non-contiguous results
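For eval_pair_alloc, one possible shape of the change is to return either a contiguous array or a lazily permuted view, and to materialize only at the final step. A hypothetical sketch (PairResult and finish_step are illustrative names, not existing code), assuming the GEMM output is column-major and perm[k] names the GEMM axis that becomes output axis k:

```rust
/// Hypothetical result of one pairwise contraction step.
enum PairResult {
    /// Final output: contiguous, column-major.
    Contiguous { data: Vec<f64>, shape: Vec<usize> },
    /// Intermediate output: the GEMM buffer with permuted shape/strides.
    Lazy { data: Vec<f64>, shape: Vec<usize>, strides: Vec<usize> },
}

/// Applies the output permutation after the GEMM. Intermediate steps return a
/// view (no copy); only the final step pays for materialization.
fn finish_step(buf: Vec<f64>, gemm_shape: &[usize], perm: &[usize], is_final: bool) -> PairResult {
    // Column-major strides of the GEMM output buffer.
    let mut gemm_strides = vec![1; gemm_shape.len()];
    for k in 1..gemm_shape.len() {
        gemm_strides[k] = gemm_strides[k - 1] * gemm_shape[k - 1];
    }
    let shape: Vec<usize> = perm.iter().map(|&p| gemm_shape[p]).collect();
    let strides: Vec<usize> = perm.iter().map(|&p| gemm_strides[p]).collect();

    if !is_final {
        // Lazy permute: hand the buffer to the next step as-is.
        return PairResult::Lazy { data: buf, shape, strides };
    }

    // Final step: gather once into a contiguous column-major array.
    let mut out = vec![0.0; buf.len()];
    for (lin, slot) in out.iter_mut().enumerate() {
        let (mut rem, mut off) = (lin, 0);
        for (&n, &s) in shape.iter().zip(&strides) {
            off += (rem % n) * s;
            rem /= n;
        }
        *slot = buf[off];
    }
    PairResult::Contiguous { data: out, shape }
}

fn main() {
    let gemm_out: Vec<f64> = (0..6).map(|x| x as f64).collect();

    // Intermediate step: same buffer, permuted metadata.
    if let PairResult::Lazy { data, shape, strides } =
        finish_step(gemm_out.clone(), &[2, 3], &[1, 0], false)
    {
        assert_eq!(data, gemm_out);
        assert_eq!(shape, vec![3, 2]);
        assert_eq!(strides, vec![2, 1]);
    } else {
        panic!("expected a lazy view");
    }

    // Final step: materialized transpose.
    if let PairResult::Contiguous { data, shape } = finish_step(gemm_out, &[2, 3], &[1, 0], true) {
        assert_eq!(shape, vec![3, 2]);
        assert_eq!(data, vec![0.0, 2.0, 4.0, 1.0, 3.0, 5.0]);
    } else {
        panic!("expected a contiguous array");
    }
}
```

The design choice here is that the decision lives at the step boundary: einsum2_dispatch would still need to know whether its output is consumed by a later step, which is why eval_pair_alloc appears in the list above.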
