Add rewrite rule to drop contiguous-axis stride in scatter/gather offsets + unified AssumeOp injection#213
Conversation
…gather. For a `TileArray` with `ArraySpec` `contiguous=true`, the constant analysis already recognises `getfield(getfield(arg, :strides), 1)` as the literal `1`, propagating through `broadcast`/`reshape`/`from_scalar`. Adding the matching algebra rewrite drops the `muli(idx, broadcast(1))` in scatter/gather offset chains. The fold mirrors Python cuTile's `_gather_scatter_pointer_and_mask` which uses a structural skip (`if static_stride == 1: offset_delta = ind`). Without it, the contiguous-axis stride is a runtime broadcast and consecutive lanes' addresses differ by an unknown scalar, forcing tileiras to fall back to scalar stores (`STG.E.U16`) instead of wide vector (`STG.E.128`) stores in 2-D scatter kernels (MoE down-projection). Standalone, this rewrite triggers a tileiras crash at -O1+ on the MoE kernel — tileiras's auto-vectorizer enters a code path that needs alignment proofs which cuTile.jl doesn't currently emit for scatter/gather pointer/size/stride args. The follow-up commit extends the assume pass to inject those.
Drops the per-`make_tensor_view` `AssumeInfo` / `MTVPredicates` sidecar
in favor of on-demand chain derivation at consumer sites. The
divisibility / bounds dataflow results live on `CGCtx`; consumer
codegen calls `op_predicates(divby, bounds, op, kind, spec_div)` to
derive each operand's `AssumePredicate` chain, and `wrap_for` consults
a per-`Value` cache (`ctx.assume_wrapped`) so a `Value` reused across
consumers — e.g. a kernel-arg pointer threaded through both an MTV and
a gather — is wrapped exactly once. Mirrors the role of cuTile Python's
`var_map` in `_passes/propagate_divby.py`.
Extends the consumer set from `{make_tensor_view}` to
`{make_tensor_view, load_ptr_tko, store_ptr_tko}` (Python's
`_OPS_NEED_ASSUME`), and adds an entry-time spec-derived wrap on each
`TileArray` kernel-arg flat slot (`apply_arg_assume_predicates!` +
`arg_chain`). The entry wrap is what carries `spec.alignment` to the
base pointer of gather/scatter chains: the post-offset operand at
`load_ptr_tko` only sees the lane-stride alignment, but the assumed
base `Value` flows through `reshape` → `broadcast` → `offset` so
tileiras's vectorizer has both alignments — the proof its STG.E.128 /
LDG.E.128 lowering needs on the MoE down-projection scatter.
`current_block` is tracked on `CGCtx` so consumer-op codegen can run
parent-walking queries (`tuple_element_source` for tuple-typed
sizes/strides operands) starting from the right scope.
scatter/gather offsetsscatter/gather offsets + unified AssumeOp injection
|
Also adding a refactor I was planning in a subsequent PR, to add assumptions for non-TileView operands (since gather/scatter take raw pointers), because withouto it I trigger NVIDIA/cuda-tile#19: Unified AssumeOp injectionThe fold above triggers a tileiras crash at Architecturally it's also a simplification: the per- The entry wrap is what carries |
The two helpers were parallel: both pulled spec.alignment / shape_div_by / stride_div_by and applied the same structural priors (`Bounded(0,?)` for sizes/strides, `DivBy(d)` when `d > 1`). Replace arg_chain's body with a path-keyed dispatch over op_predicates with nothing dataflow inputs — the kernel-arg slot is the dataflow anchor, so there's nothing upstream to refine against. The contiguous-axis stride skip stays in the dispatcher since it has the path context. Drops ~15 lines and removes a pair of trivially-equivalent code paths.
The per-Value cache keys only on Value, not on chain contents — sound because the pipeline arranges that the first-seen chain on a given Value is an upper bound on what any later consumer could derive (kernel-arg-entry wrap seeds the spec-tightest chain; structural prior is tile-type-determined so per-Value consistent). Promote the comment to a docstring spelling out both reasons plus the failure mode if a future consumer ever derives a tighter chain on a pre-wrapped Value.
For a TileArray with
ArraySpec{contiguous=true}, Julia's column-major convention makesstride[1] == 1statically known, and the constant analysi) propagates that 1 through thebroadcast/reshape/from_scalarchain that feeds thegather/scatteroffset compute. This PR adds the matchingmuli(x, 1) → xrewrite so the contiguous-axis stride multiply collapses out of the offset.