Summary
strided-einsum2/src/plan.rs classifies axes into 6 groups: batch, lo, ro, sum, left_trace, right_trace. The trace groups add complexity to the plan phase (extra vectors, extra permutation computation).
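For reference, the six-way split amounts to a membership test on each axis label of a binary contraction `left, right -> out`. The following is a minimal sketch, not the code in plan.rs. It assumes that labels are single chars, distinct within each operand, and that a trace axis is a label appearing in only one operand and not in the output. `AxisGroups` and `classify6` are hypothetical names.

```rust
/// Hypothetical six-way grouping of axis labels for `left, right -> out`.
/// The group definitions below are assumptions, not copied from plan.rs.
#[derive(Debug, Default, PartialEq)]
struct AxisGroups {
    batch: Vec<char>,       // in left, right and out
    lo: Vec<char>,          // in left and out only
    ro: Vec<char>,          // in right and out only
    sum: Vec<char>,         // in left and right, not in out
    left_trace: Vec<char>,  // in left only
    right_trace: Vec<char>, // in right only
}

fn classify6(left: &[char], right: &[char], out: &[char]) -> AxisGroups {
    let mut g = AxisGroups::default();
    for &c in left {
        match (right.contains(&c), out.contains(&c)) {
            (true, true) => g.batch.push(c),
            (false, true) => g.lo.push(c),
            (true, false) => g.sum.push(c),
            (false, false) => g.left_trace.push(c),
        }
    }
    for &c in right {
        match (left.contains(&c), out.contains(&c)) {
            (false, true) => g.ro.push(c),
            (false, false) => g.right_trace.push(c),
            // Labels shared with `left` were already classified above.
            _ => {}
        }
    }
    g
}

fn main() {
    // "ijt,jks->ik": t is a left trace axis, s is a right trace axis.
    let left: Vec<char> = "ijt".chars().collect();
    let right: Vec<char> = "jks".chars().collect();
    let out: Vec<char> = "ik".chars().collect();
    println!("{:?}", classify6(&left, &right, &out));
}
```

The two trace vectors exist only to carry labels that never reach the GEMM, which is the extra plan-phase state this proposal targets.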
Proposal
Reduce to 4 groups (batch, lo, ro, sum) and handle trace axes lazily at runtime: detect and reduce them on demand before the GEMM call instead of pre-classifying them in the plan. The trace.rs module already provides reduce_trace_axes(), which can be called inline.
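A sketch of the lazy alternative, under the same assumptions as the grouping above (a trace axis is a label absent from both the other operand and the output). Each operand is checked at call time and its trace axes are summed out, after which only batch, lo, ro and sum remain. For simplicity this sketch works on a contiguous row-major `f64` buffer. `sum_out_axis` and `reduce_trace_lazily` are hypothetical stand-ins; the actual signature and data layout of reduce_trace_axes() are not assumed here.

```rust
/// Sum a contiguous row-major tensor over one axis.
/// Hypothetical helper, not the crate's reduce_trace_axes().
fn sum_out_axis(data: &[f64], shape: &[usize], axis: usize) -> (Vec<f64>, Vec<usize>) {
    let outer: usize = shape[..axis].iter().product();
    let n = shape[axis];
    let inner: usize = shape[axis + 1..].iter().product();
    let mut result = vec![0.0; outer * inner];
    for o in 0..outer {
        for k in 0..n {
            for i in 0..inner {
                result[o * inner + i] += data[(o * n + k) * inner + i];
            }
        }
    }
    let mut new_shape = shape.to_vec();
    new_shape.remove(axis);
    (result, new_shape)
}

/// Detect trace axes of one operand at runtime (labels absent from both the
/// other operand and the output) and sum them out before the GEMM step.
fn reduce_trace_lazily(
    labels: &[char],
    data: &[f64],
    shape: &[usize],
    other: &[char],
    out: &[char],
) -> (Vec<char>, Vec<f64>, Vec<usize>) {
    let mut labels = labels.to_vec();
    let mut data = data.to_vec();
    let mut shape = shape.to_vec();
    // Walk axes from last to first so removing one does not shift pending indices.
    for axis in (0..labels.len()).rev() {
        let c = labels[axis];
        if !other.contains(&c) && !out.contains(&c) {
            let (d, s) = sum_out_axis(&data, &shape, axis);
            data = d;
            shape = s;
            labels.remove(axis);
        }
    }
    (labels, data, shape)
}

fn main() {
    // Left operand of "it,ik->k" with shape [2, 3]; t is a left trace axis.
    let (labels, data, shape) =
        reduce_trace_lazily(&['i', 't'], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], &['i', 'k'], &['k']);
    println!("{:?} {:?} {:?}", labels, shape, data);
}
```

The runtime check costs one label scan per operand per call, which should be cheap given that contractions typically carry only 0-2 trace axes.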
This could also eliminate trace.rs (~40 lines) as a separate module.
Risk
Low. Trace axes are rare (typically 0-2 per contraction) and already handled by a separate reduction step.