
Simplify axis classification: reduce 6 groups to 4 + lazy trace #126

@shinaoka


Summary

strided-einsum2/src/plan.rs classifies axes into 6 groups: batch, lo, ro, sum, left_trace, right_trace. The two trace groups add complexity to the plan phase (extra vectors and extra permutation computation).
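To make the six groups concrete, here is a minimal, hypothetical sketch of such a classification. It is not the code in plan.rs: the membership rules for each group are an assumption (batch = in both inputs and the output, lo/ro = in one input and the output, sum = in both inputs but not the output, left_trace/right_trace = in one input only), axes are labelled by a `char`, and each label is assumed to appear at most once per operand.

```rust
// Hypothetical sketch of a 6-group axis classification (not the plan.rs code).
// Assumed rules:
//   batch       : in left, right, and output
//   lo          : in left and output only
//   ro          : in right and output only
//   sum         : in left and right, not in output (contracted by GEMM)
//   left_trace  : in left only (summed out before GEMM)
//   right_trace : in right only (summed out before GEMM)

#[derive(Debug, Default, PartialEq)]
struct SixGroups {
    batch: Vec<char>,
    lo: Vec<char>,
    ro: Vec<char>,
    sum: Vec<char>,
    left_trace: Vec<char>,
    right_trace: Vec<char>,
}

fn classify6(left: &[char], right: &[char], out: &[char]) -> SixGroups {
    let mut g = SixGroups::default();
    for &a in left {
        match (right.contains(&a), out.contains(&a)) {
            (true, true) => g.batch.push(a),
            (false, true) => g.lo.push(a),
            (true, false) => g.sum.push(a),
            (false, false) => g.left_trace.push(a),
        }
    }
    for &b in right {
        if left.contains(&b) {
            continue; // already classified from the left side
        }
        if out.contains(&b) {
            g.ro.push(b);
        } else {
            g.right_trace.push(b);
        }
    }
    g
}

fn main() {
    // "ijt,jks->ik": i is lo, j is sum, k is ro, t and s are trace axes.
    let g = classify6(&['i', 'j', 't'], &['j', 'k', 's'], &['i', 'k']);
    println!("{:?}", g);
}
```

Under these assumed rules, the two trace groups are the only ones that never reach the GEMM, which is what makes them candidates for removal from the plan.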

Proposal

Reduce to 4 groups (batch, lo, ro, sum). Handle trace axes lazily at runtime: detect and reduce them on demand before the GEMM call, rather than pre-classifying them in the plan. The trace.rs module already provides reduce_trace_axes(), which can be called inline.
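A rough sketch of the proposed flow follows, under the same assumed group rules as before. The plan keeps only four groups, and right before the GEMM a runtime check finds the axes of an operand that appear neither in the other operand nor in the output and sums them out. All names here (`classify4`, `trace_axes`, `sum_axis`, `reduce_traces_lazily`) are hypothetical; in particular, `reduce_traces_lazily` is a stand-in and does not reflect the actual signature of reduce_trace_axes(). Tensors are assumed dense and row-major rather than strided.

```rust
// Hypothetical sketch: a 4-group plan plus lazy trace reduction at runtime.

#[derive(Debug, PartialEq)]
struct FourGroups {
    batch: Vec<char>,
    lo: Vec<char>,
    ro: Vec<char>,
    sum: Vec<char>,
}

fn classify4(left: &[char], right: &[char], out: &[char]) -> FourGroups {
    let mut g = FourGroups { batch: vec![], lo: vec![], ro: vec![], sum: vec![] };
    for &a in left {
        match (right.contains(&a), out.contains(&a)) {
            (true, true) => g.batch.push(a),
            (false, true) => g.lo.push(a),
            (true, false) => g.sum.push(a),
            (false, false) => {} // trace axis: not stored in the plan
        }
    }
    for &b in right {
        if !left.contains(&b) && out.contains(&b) {
            g.ro.push(b);
        }
    }
    g
}

/// Positions of axes in `labels` that appear neither in `other` nor in `out`.
fn trace_axes(labels: &[char], other: &[char], out: &[char]) -> Vec<usize> {
    (0..labels.len())
        .filter(|&i| !other.contains(&labels[i]) && !out.contains(&labels[i]))
        .collect()
}

/// Sum a dense row-major tensor over `axis`; returns the new data and shape.
fn sum_axis(data: &[f64], shape: &[usize], axis: usize) -> (Vec<f64>, Vec<usize>) {
    let outer: usize = shape[..axis].iter().product();
    let n = shape[axis];
    let inner: usize = shape[axis + 1..].iter().product();
    let mut res = vec![0.0; outer * inner];
    for o in 0..outer {
        for k in 0..n {
            for i in 0..inner {
                res[o * inner + i] += data[(o * n + k) * inner + i];
            }
        }
    }
    let mut new_shape = shape.to_vec();
    new_shape.remove(axis);
    (res, new_shape)
}

/// Detect and sum out trace axes of one operand just before the GEMM call.
fn reduce_traces_lazily(
    mut data: Vec<f64>,
    mut shape: Vec<usize>,
    mut labels: Vec<char>,
    other: &[char],
    out: &[char],
) -> (Vec<f64>, Vec<usize>, Vec<char>) {
    let axes = trace_axes(&labels, other, out);
    if axes.is_empty() {
        return (data, shape, labels); // common case: nothing to reduce
    }
    // Reduce from the highest axis down so earlier positions stay valid.
    for &ax in axes.iter().rev() {
        let (d, s) = sum_axis(&data, &shape, ax);
        data = d;
        shape = s;
        labels.remove(ax);
    }
    (data, shape, labels)
}

fn main() {
    let (l, r, o) = (['i', 'j', 't'], ['j', 'k'], ['i', 'k']);
    println!("{:?}", classify4(&l, &r, &o));
    // Left operand of shape [2, 2, 3] holding 0..12; t (axis 2) is a trace axis.
    let data: Vec<f64> = (0..12).map(|x| x as f64).collect();
    let (d, s, lab) = reduce_traces_lazily(data, vec![2, 2, 3], l.to_vec(), &r, &o);
    println!("{:?} {:?} {:?}", d, s, lab);
}
```

The design choice this illustrates is that the empty check is cheap, so contractions without trace axes pay almost nothing, while the rare operand with trace axes is reduced to a shape the 4-group plan already expects.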

This could potentially eliminate trace.rs as a separate module (~40 lines).

Risk

Low. Trace axes are rare (typically 0-2 per contraction) and are already handled by a separate reduction step.
