diff --git a/src/compiler/analysis/assume.jl b/src/compiler/analysis/assume.jl index a57b13a9..9d6860e6 100644 --- a/src/compiler/analysis/assume.jl +++ b/src/compiler/analysis/assume.jl @@ -1,223 +1,146 @@ -# Assume Aggregation +# Assume Helpers # -# Read-only aggregator that bundles `DivByAnalysis` and `BoundsAnalysis` -# results with `ArraySpec` lookup, projection (pow2 + bound clamping), -# and tuple-element-source resolution into a per-`make_tensor_view` -# predicate bundle. Codegen consumes one bundle per call: -# `predicates_for(ctx.assume_info, mtv_ssa)` returns an `MTVPredicates` -# struct with `ptr` / `sizes[i]` / `strides[i]` chains ready to wrap -# the corresponding bytecode `Value`s. +# `DivByAnalysis` and `BoundsAnalysis` propagate per-anchor facts; this +# file is the projection layer that turns those facts (plus the +# operand's `ArraySpec`) into the `AssumePredicate` chains that codegen +# wraps `Value`s with. There's no precomputed sidecar — the analyses' +# results live on the `CGCtx` and consumer-op codegen calls +# `op_predicates` / `arg_chain` on demand for each operand it cares +# about. # -# This replaces the prior `assume_pass!` (transform/assume.jl). The -# difference is that the aggregator does *not* mutate the SCI: no -# `Intrinsics.assume` ops are inserted, no `Core.tuple` SSAs are rebuilt, -# no `getfield` SSAs are synthesised. The "where do I attach this fact" -# step happens at bytecode emission, where per-element `Value`s already -# exist as a natural product of `resolve_tuple` (which the -# `make_tensor_view` codegen had to do anyway to feed -# `encode_MakeTensorViewOp!` its flat operands). +# Mirrors cuTile Python's `_passes/propagate_divby.py`: same +# `_OPS_NEED_ASSUME = (MakeTensorView, LoadPointer, StorePointer)` +# consumer set, same per-consumer derivation. 
Where Python mutates the +# IR (inserts `AssumeDivBy` ops with a `var_map` to dedup), we wrap at +# bytecode emission time and the per-`Value` cache on `CGCtx` +# (`assume_wrapped`) plays the role of `var_map`, ensuring a `Value` +# reused across consumers — e.g. a kernel-arg pointer threaded through +# both an MTV and a gather — is wrapped exactly once. # -# Tuple-element-source navigation (the recovery of per-axis SCI handles -# from a tuple-typed operand) lives in `tuple_element_source` and is -# entirely an analysis-internal concern — codegen requests "facts for -# this make_tensor_view" and gets back tuple-shaped chains it can -# consume positionally. -# -# Mirrors cuTile Python's `add_divby_pass` + inline `assume_bounded(0, -# None)` emission, but as a sidecar query rather than an IR-mutation pass. -# Their MakeTensorView has variadic per-axis operands so attaching is a -# per-slot operand swap; ours has tuple-typed operands so the bytecode -# emission is the natural attachment point. +# Pure analysis: does not mutate the SCI. + +const EMPTY_PREDS = AssumePredicate[] #============================================================================= - AssumeInfo + Per-operand chain derivation =============================================================================# """ - MTVPredicates - -Per-operand `AssumePredicate` chains for one `make_tensor_view` call. -- `ptr`: chain to wrap the base-pointer operand. -- `sizes[i]`: chain to wrap the i-th size operand (Julia/column-major order). -- `strides[i]`: chain to wrap the i-th stride operand. - -`sizes` and `strides` always have length `N` (the TileArray rank); -slots that produce no useful fact (literal element, contiguous-axis -static stride) carry an empty chain. Empty chains mean "emit no -`AssumeOp`" — same observable result as omitting the entry. + op_predicates(divby, bounds, op, kind, spec_div=1) -> Vector{AssumePredicate} + +Derive the `AssumePredicate` chain for a consumer-op operand. 
`kind` +selects the structural prior: +- `:ptr` — pointer operand, no `Bounded` (a pointer's range is + meaningless to tileiras's vectorizer); chain is `[DivBy(d)]` when + `d > 1`, else empty. +- `:size` / `:stride` — integer operand; always `Bounded(0, ?)` since + sizes / strides are non-negative, plus `DivBy(d)` when `d > 1`. + +`spec_div` is the consumer-side type-level divisor hint +(`spec.alignment`, `spec.shape_div_by[i]`, `spec.stride_div_by[i]`) +combined with the dataflow result via `lcm` — both inputs are +guarantees, so the value is divisible by their lcm. + +Returns `EMPTY_PREDS` for literal operands; the Tile IR translator +already sees the literal directly. """ -struct MTVPredicates - ptr::Vector{AssumePredicate} - sizes::Vector{Vector{AssumePredicate}} - strides::Vector{Vector{AssumePredicate}} -end - -const EMPTY_PREDS = AssumePredicate[] +function op_predicates(divby_info::Union{DivByInfo, Nothing}, + bounds_info::Union{BoundsInfo, Nothing}, + @nospecialize(op), + kind::Symbol, + spec_div::Int=1) + is_literal_op(op) && return EMPTY_PREDS + + df_div = op === nothing ? 0 : divby_query(divby_info, op) + d = pow2_divisor(combine_divisor(spec_div, df_div)) + + if kind === :ptr + return d > 1 ? AssumePredicate[DivBy(d)] : EMPTY_PREDS + end -""" - AssumeInfo - -Sidecar carrying per-`make_tensor_view` predicate bundles. Built by -`analyze_assume_info` from `DivByInfo` + `BoundsInfo` + the operand -TileArray's `ArraySpec`; queried by codegen via `predicates_for`. - -Each entry collapses ptr / per-axis sizes / per-axis strides for one -make_tensor_view into a single `MTVPredicates` struct, so codegen sees -the tuple-shaped result directly rather than reconstructing it from -flat indexed lookups. The walk from a tuple operand to its per-element -sources lives in `tuple_element_source` — the cost of recovering -"per-field facts on a tuple-valued operand" stays inside the analysis. 
-""" -struct AssumeInfo - predicates::Dict{Int, MTVPredicates} + # :size / :stride — always assert non-negativity, refine with + # dataflow's tighter range when available. + df_bound = op === nothing ? TOP_RANGE : bounds_query(bounds_info, op) + bound = combine_bound(nonneg_range(), df_bound) + chain = AssumePredicate[as_bounded(bound)] + d > 1 && push!(chain, DivBy(d)) + return chain end -AssumeInfo() = AssumeInfo(Dict{Int, MTVPredicates}()) - -""" - predicates_for(info, mtv_ssa) -> Union{MTVPredicates, Nothing} - -Return the predicate bundle for the `make_tensor_view` at SSA index -`mtv_ssa`, or `nothing` if no entry exists (e.g. the analysis didn't -run, or the make_tensor_view's TileArray type was unresolvable). -Codegen treats `nothing` as "no assumes" — same as all-empty chains. -""" -@inline predicates_for(info::AssumeInfo, mtv_ssa::Int) = - get(info.predicates, mtv_ssa, nothing) - -predicates_for(::Nothing, ::Int) = nothing - #============================================================================= - Analysis driver + Kernel-arg flat-slot chain derivation (spec-only) =============================================================================# """ - analyze_assume_info(sci, divby_info, bounds_info) -> AssumeInfo - -Walk every `Intrinsics.make_tensor_view` in `sci`, derive -divisibility / bound facts from the operand TileArray's `ArraySpec` -combined with the optional dataflow analyses, and store the resulting -`AssumePredicate` chains keyed by `(mtv_ssa, kind, slot)`. Pure -analysis: does not mutate `sci`. + arg_chain(T::Type{<:TileArray}, path) -> Vector{AssumePredicate} + +Per-flat-slot chain for a `TileArray` kernel argument. 
Thin +dispatcher over `op_predicates` keyed on the flat slot path +produced by `flatten_struct_params!`: + +- `[1]` → `:ptr` (with `spec.alignment`) +- `[2, i]` → `:size` (with `spec.shape_div_by[i]`) +- `[3, i]` → `:stride` (with `spec.stride_div_by[i]`) + +Dataflow inputs are `nothing` because the kernel-arg slot is the +analysis anchor — there's no upstream IR for the dataflow to refine +against. Consumer-site queries against an SSA derived from the slot +*do* carry dataflow refinement (and combine with the same spec hints +via `lcm`), so the entry-time chain is an upper bound on what any +consumer would derive — important for the `wrap_for` cache invariant +(see its docstring). + +Used by `apply_arg_assume_predicates!` (codegen/kernel.jl) at kernel +entry to wrap each flat kernel-arg `Value` *before* any consumer +reads it. Important for raw `offset` / `load_ptr_tko` / +`store_ptr_tko` access paths (gather/scatter): the assume must attach +to the base pointer, not just to the post-offset operand, for +tileiras's vectorizer to prove the wide-vector address alignment its +STG.E.128 / LDG.E.128 lowering requires. + +Returns `EMPTY_PREDS` when no useful fact exists (no spec on `T`, +unrecognised path, or contiguous-axis stride which is a static `1`). 
""" -function analyze_assume_info(sci::StructuredIRCode, - divby_info::Union{DivByInfo, Nothing}=nothing, - bounds_info::Union{BoundsInfo, Nothing}=nothing) - info = AssumeInfo() - walk_collect!(info, sci.entry, divby_info, bounds_info) - return info -end - -function walk_collect!(info::AssumeInfo, block::Block, - divby_info::Union{DivByInfo, Nothing}, - bounds_info::Union{BoundsInfo, Nothing}) - for inst in instructions(block) - s = inst[:stmt] - if s isa ControlFlowOp - for sub in blocks(s) - walk_collect!(info, sub, divby_info, bounds_info) - end - continue - end - call = resolve_call(block, inst) - call === nothing && continue - func, ops = call - if func === Intrinsics.make_tensor_view - collect_make_tensor_view!(info, block, inst, ops, divby_info, bounds_info) - end - end -end - -function collect_make_tensor_view!(info::AssumeInfo, block::Block, - inst::Instruction, ops, - divby_info::Union{DivByInfo, Nothing}, - bounds_info::Union{BoundsInfo, Nothing}) - length(ops) >= 4 || return - T_arg = ops[1] - ptr_op = ops[2] - sizes_op = ops[3] - strides_op = ops[4] - - T = resolve_tilearray_type(block, T_arg) - T === nothing && return +function arg_chain(::Type{T}, path::Vector{Int}) where {T <: TileArray} spec = array_spec(T) - spec === nothing && return - - N = ndims(T) - mtv_ssa = inst.ssa_idx - - # ---- Pointer --------------------------------------------------------- - ptr_div = pow2_divisor(combine_divisor(Int(spec.alignment), - divby_query(divby_info, ptr_op))) - ptr_chain = ptr_div > 1 ? AssumePredicate[DivBy(ptr_div)] : EMPTY_PREDS - - # ---- Sizes ----------------------------------------------------------- - # Lower bound is structurally `0` (sizes are non-negative). Combine - # with the dataflow result to refine: an exact known size collapses - # to `Bounded(N, N)`; a ForOp-IV-derived size to `Bounded(0, max)`, - # etc. 
- sizes_chains = Vector{Vector{AssumePredicate}}(undef, N) - for i in 1:N - sizes_chains[i] = element_chain(block, sizes_op, i, - Int(spec.shape_div_by[i]), - divby_info, bounds_info) + spec === nothing && return EMPTY_PREDS + + if length(path) == 1 && path[1] == 1 + return op_predicates(nothing, nothing, nothing, :ptr, Int(spec.alignment)) end - # ---- Strides --------------------------------------------------------- - strides_chains = Vector{Vector{AssumePredicate}}(undef, N) - for i in 1:N - # Skip the contiguous axis: its stride is statically `1` and never - # enters the bytecode kernel signature (filter_dynamic_strides). - if spec.contiguous && i == 1 - strides_chains[i] = EMPTY_PREDS - continue + if length(path) == 2 + i = path[2] + 1 <= i <= ndims(T) || return EMPTY_PREDS + if path[1] == 2 # sizes[i] + return op_predicates(nothing, nothing, nothing, :size, Int(spec.shape_div_by[i])) + elseif path[1] == 3 # strides[i] + # Contiguous axis: `make_tensor_view` inlines `1` and the + # `muli(x, 1)` algebra rule folds it out of scatter/gather + # offsets, so this slot never enters the bytecode signature. + spec.contiguous && i == 1 && return EMPTY_PREDS + return op_predicates(nothing, nothing, nothing, :stride, Int(spec.stride_div_by[i])) end - strides_chains[i] = element_chain(block, strides_op, i, - Int(spec.stride_div_by[i]), - divby_info, bounds_info) end - - info.predicates[mtv_ssa] = MTVPredicates(ptr_chain, sizes_chains, strides_chains) - return -end - -# Build the predicate chain for a single tuple element (size or stride). -# Walks back through `tuple_element_source` to recover a per-element SCI -# handle when one exists (`Core.tuple(...)` constructor); falls through -# to spec-only facts (`spec_div`, structural `[0, ∞)`) when the source is -# wholesale (`getfield(arg, :sizes)`) or otherwise opaque. Returns -# `EMPTY_PREDS` for literal elements and for the all-trivial case. 
-function element_chain(block::Block, tuple_op, i::Int, spec_div::Int, - divby_info::Union{DivByInfo, Nothing}, - bounds_info::Union{BoundsInfo, Nothing}) - elem_op = tuple_element_source(block, tuple_op, i) - # Literals — `assume bounded` on `` adds no info - # the Tile IR translator can't see directly. - is_literal_op(elem_op) && return EMPTY_PREDS - - df_div = elem_op === nothing ? 0 : divby_query(divby_info, elem_op) - df_bound = elem_op === nothing ? TOP_RANGE : bounds_query(bounds_info, elem_op) - - d = pow2_divisor(combine_divisor(spec_div, df_div)) - bound = combine_bound(nonneg_range(), df_bound) - - preds = AssumePredicate[as_bounded(bound)] - d > 1 && push!(preds, DivBy(d)) - return preds + return EMPTY_PREDS end #============================================================================= Tuple element source resolution =============================================================================# -# Resolve a tuple-typed operand to its i-th element's SCI handle. -# Recognises: -# - Literal `Tuple` values (`(64, 64)`): returns the i-th literal. -# - `Core.tuple(s1, ..., sN)` SSA: returns the i-th operand. -# - Anything else (e.g. `getfield(arg, :sizes)`): returns `nothing`, -# leaving the caller to use spec-only facts. -# -# The walk-up parent chain mirrors `value_type` / `lookup_def_call`. +""" + tuple_element_source(block, tuple_op, i) -> SSAValue / literal / nothing + +Resolve a tuple-typed operand to its i-th element's SCI handle. +Recognises: +- Literal `Tuple` values (`(64, 64)`): returns the i-th literal. +- `Core.tuple(s1, ..., sN)` SSA: returns the i-th operand. +- Anything else (e.g. `getfield(arg, :sizes)`): returns `nothing`, + leaving the caller to use spec-only facts. + +The walk-up parent chain mirrors `value_type` / `lookup_def_call`. +""" function tuple_element_source(block::Block, @nospecialize(tuple_op), i::Int) if tuple_op isa Tuple return length(tuple_op) >= i ? 
tuple_op[i] : nothing @@ -245,10 +168,14 @@ end Operand-type extraction =============================================================================# -# Extract a `Type{TileArray{...}}` value from an SCI operand. Recognises -# a constant `Type` literal, a `QuoteNode(::Type)`, and an SSA whose -# inferred type is `Const(T)` / `Type{T}`. Returns the unwrapped `T` or -# `nothing`. +""" + resolve_tilearray_type(block, op) -> Union{Type, Nothing} + +Extract a `Type{TileArray{...}}` value from an SCI operand. Recognises +a constant `Type` literal, a `QuoteNode(::Type)`, and an SSA whose +inferred type is `Const(T)` / `Type{T}`. Returns the unwrapped `T` or +`nothing`. +""" function resolve_tilearray_type(block::Block, @nospecialize(op)) if op isa Type op <: TileArray && return op diff --git a/src/compiler/analysis/bounds.jl b/src/compiler/analysis/bounds.jl index f411f3ab..18f73863 100644 --- a/src/compiler/analysis/bounds.jl +++ b/src/compiler/analysis/bounds.jl @@ -14,9 +14,9 @@ # precision on loop-carried values that vary per iteration; gains # precision on literal-constant flow and ForOp induction variables. # -# Consumed by `analyze_assume_info` (analysis/assume.jl) to emit -# sharper `Bounded(...)` predicates on `make_tensor_view` operands at -# codegen time, and by `no_wrap_pass!` (transform/no_wrap.jl) to +# Consumed at codegen time by `op_predicates` (analysis/assume.jl) to +# emit sharper `Bounded(...)` predicates on `make_tensor_view` size / +# stride operands, and by `no_wrap_pass!` (transform/no_wrap.jl) to # attach `nsw`/`nuw` flags on integer arithmetic that provably fits in # its destination width. diff --git a/src/compiler/analysis/divisibility.jl b/src/compiler/analysis/divisibility.jl index 1dd60dc7..93fb12c7 100644 --- a/src/compiler/analysis/divisibility.jl +++ b/src/compiler/analysis/divisibility.jl @@ -6,11 +6,12 @@ # transfer rules for arithmetic, pointer offset, and getfield chains # rooted at TileArray arguments. 
# -# Consumed by `analyze_assume_info` (analysis/assume.jl), which combines -# this with the bounds analysis and the operand TileArray's `ArraySpec` -# into per-`make_tensor_view` predicate bundles. Codegen reads those -# bundles and wraps each operand `Value` with `encode_AssumeOp!` — -# the analysis itself does *not* mutate the SCI. +# Consumed at codegen time by `op_predicates` (analysis/assume.jl), +# which combines this with the bounds analysis and the operand +# TileArray's `ArraySpec` into per-operand `AssumePredicate` chains. +# Codegen reads those chains via `wrap_for` and wraps each consumer's +# operand `Value` with `encode_AssumeOp!` — the analysis itself does +# *not* mutate the SCI. """ DivByAnalysis diff --git a/src/compiler/codegen/control_flow.jl b/src/compiler/codegen/control_flow.jl index 37432224..60626425 100644 --- a/src/compiler/codegen/control_flow.jl +++ b/src/compiler/codegen/control_flow.jl @@ -8,22 +8,32 @@ All SSA values use original Julia SSA indices (no local renumbering). Values are stored in ctx.values by their original index. """ function emit_block!(ctx::CGCtx, block::Block; skip_terminator::Bool=false) - for inst in instructions(block) - # Set debug location for this instruction - if ctx.debug_emitter !== nothing - ln = isempty(ctx.linkage_name) ? nothing : ctx.linkage_name - ctx.cb.cur_debug_attr = resolve_debug_attr!( - ctx.debug_emitter, ctx.sci, inst.ssa_idx; linkage_name=ln) + # Track the current block so consumer-op codegen can perform parent- + # walking queries (e.g. `tuple_element_source` for MTV size/stride + # operands) starting from the right scope. Restored on exit so a + # caller still in an outer block sees its own context. + prev_block = ctx.current_block + ctx.current_block = block + try + for inst in instructions(block) + # Set debug location for this instruction + if ctx.debug_emitter !== nothing + ln = isempty(ctx.linkage_name) ? 
nothing : ctx.linkage_name + ctx.cb.cur_debug_attr = resolve_debug_attr!( + ctx.debug_emitter, ctx.sci, inst.ssa_idx; linkage_name=ln) + end + s = inst[:stmt] + if s isa ControlFlowOp + emit_control_flow_op!(ctx, s, value_type(inst), inst.ssa_idx) + else + emit_statement!(ctx, s, inst.ssa_idx, value_type(inst)) + end end - s = inst[:stmt] - if s isa ControlFlowOp - emit_control_flow_op!(ctx, s, value_type(inst), inst.ssa_idx) - else - emit_statement!(ctx, s, inst.ssa_idx, value_type(inst)) + if !skip_terminator && terminator(block) !== nothing + emit_terminator!(ctx, terminator(block)) end - end - if !skip_terminator && terminator(block) !== nothing - emit_terminator!(ctx, terminator(block)) + finally + ctx.current_block = prev_block end end diff --git a/src/compiler/codegen/kernel.jl b/src/compiler/codegen/kernel.jl index ff9e2465..dd1143ca 100644 --- a/src/compiler/codegen/kernel.jl +++ b/src/compiler/codegen/kernel.jl @@ -99,6 +99,23 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, # Set up argument values arg_values = make_block_args!(cb, length(param_types)) + # Hoist early returns BEFORE token ordering — hoist_returns! rewrites + # ReturnNode terminators to YieldOp, which the token pass then extends. + hoist_returns!(ctx.sci.entry) + + # Run the pass pipeline (normalize, optimize, token ordering, DCE). + # Returns the dataflow results consumed at consumer codegen sites. + ctx.divby_info, ctx.bounds_info = run_passes!(sci) + + # Wrap each TileArray-derived flat kernel-arg `Value` with the + # `AssumeOp` chain its `ArraySpec` justifies, *before* any consumer + # reads it. This puts the spec.alignment proof on the base pointer + # at kernel entry — gather/scatter offset chains downstream + # inherit it, which is what tileiras's vectorizer needs to lower + # to `STG.E.128` / `LDG.E.128` (the post-offset operand alone gives + # only the lane-stride alignment, not the base alignment). 
+ apply_arg_assume_predicates!(ctx, arg_values, param_mapping, param_types, sci) + # Build arg_flat_values map. User args and the trailing KernelState # pieces land here — they go through the same `param_mapping`-keyed path. # `kernel_state()` resolves to a lazy arg_ref into this map. @@ -164,14 +181,6 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, ctx[Argument(arg_idx)] = tv end - # Hoist early returns BEFORE token ordering — hoist_returns! rewrites - # ReturnNode terminators to YieldOp, which the token pass then extends. - hoist_returns!(ctx.sci.entry) - - # Run the pass pipeline (normalize, optimize, token ordering, DCE). - # Returns the AssumeInfo sidecar consumed by `make_tensor_view` codegen. - ctx.assume_info = run_passes!(sci) - # Cache the token bytecode type for codegen ctx.token_type = Token(tt) @@ -181,6 +190,45 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8}, finalize_function!(func_buf, cb, writer.debug_info) end +""" + apply_arg_assume_predicates!(ctx, arg_values, param_mapping, param_types, sci) + +Wrap each `TileArray`-derived flat kernel-arg `Value` in `arg_values` +with the `AssumeOp` chain its `ArraySpec` implies. Mutates `arg_values` +in place. The slot path (`[1]` is `ptr`, `[2, i]` is `sizes[i]`, +`[3, i]` is `strides[i]`) maps to a chain via `arg_chain` (analysis/ +assume.jl). + +After wrapping, the `wrapped` `Value` is recorded as a fixed point in +`ctx.assume_wrapped`: a consumer-side `wrap_for(ctx, wrapped, ...)` +hits the cache and returns `wrapped` unchanged, so a kernel-arg ptr +consumed by both an MTV and a gather is wrapped *exactly once* across +the kernel. 
+""" +function apply_arg_assume_predicates!(ctx::CGCtx, arg_values::Vector{Value}, + param_mapping::Vector{Tuple{Int, Vector{Int}}}, + param_types::Vector{TypeId}, + sci::StructuredIRCode) + for param_idx in eachindex(arg_values) + arg_idx, path = param_mapping[param_idx] + # Trailing `KernelState` arg has no Julia argtype entry. + arg_idx > length(sci.argtypes) && continue + argT = CC.widenconst(sci.argtypes[arg_idx]) + argT <: TileArray || continue + chain = arg_chain(argT, path) + isempty(chain) && continue + original = arg_values[param_idx] + wrapped = original + for p in chain + wrapped = encode_AssumeOp!(ctx.cb, param_types[param_idx], wrapped, p) + end + arg_values[param_idx] = wrapped + # Mark the wrapped `Value` as a fixed point so subsequent + # consumer wraps don't re-emit the same predicates. + ctx.assume_wrapped[wrapped] = wrapped + end +end + """ flatten_struct_params!(ctx, param_types, param_mapping, arg_idx, T, path) @@ -316,14 +364,15 @@ function emit_subprogram!(ctx::CGCtx, func, arg_types::Vector, end # 2b. Run the pass pipeline on subprogram IR - sub_assume_info = run_passes!(sci) + sub_divby, sub_bounds = run_passes!(sci) # 3. Create sub-context (inherits active fpmode from caller) sub_ctx = CGCtx(; ctx.cb, ctx.tt, sci, ctx.token_type, ctx.type_cache, ctx.sm_arch, ctx.cache) - sub_ctx.assume_info = sub_assume_info + sub_ctx.divby_info = sub_divby + sub_ctx.bounds_info = sub_bounds append!(sub_ctx.fpmode_stack, ctx.fpmode_stack) # Inherit kernel-state flat values from the parent. Subprograms compile diff --git a/src/compiler/intrinsics/memory.jl b/src/compiler/intrinsics/memory.jl index fd3084fe..24b30f3f 100644 --- a/src/compiler/intrinsics/memory.jl +++ b/src/compiler/intrinsics/memory.jl @@ -33,7 +33,9 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_ptr_tko), args) # args: (ptrs, latency, mask?, padding?) 
ptrs_tv = emit_value!(ctx, args[1]) ptrs_tv === nothing && throw(IRError("load_ptr_tko: cannot resolve pointer tile")) - pointers = ptrs_tv.v + pointers = wrap_for(ctx, ptrs_tv.v::Value, ptrs_tv.type_id::TypeId, + op_predicates(ctx.divby_info, ctx.bounds_info, + args[1], :ptr)) tile_shape = ptrs_tv.shape ptrs_type = CC.widenconst(ptrs_tv.jltype) @@ -99,7 +101,9 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_ptr_tko), args) ptrs_tv = emit_value!(ctx, args[1]) ptrs_tv === nothing && throw(IRError("store_ptr_tko: cannot resolve pointer tile")) - pointers = ptrs_tv.v + pointers = wrap_for(ctx, ptrs_tv.v::Value, ptrs_tv.type_id::TypeId, + op_predicates(ctx.divby_info, ctx.bounds_info, + args[1], :ptr)) values_tv = emit_value!(ctx, args[2]) values_tv === nothing && throw(IRError("store_ptr_tko: cannot resolve values tile")) diff --git a/src/compiler/intrinsics/misc.jl b/src/compiler/intrinsics/misc.jl index 88d0282a..77d24979 100644 --- a/src/compiler/intrinsics/misc.jl +++ b/src/compiler/intrinsics/misc.jl @@ -32,12 +32,14 @@ end `SameElements`). Returns its input value — a pure-data annotation, eliminated if downstream uses vanish. - The make_tensor_view assume bundle is emitted directly to bytecode by - `analyze_assume_info` + `views.jl` codegen and never materialises as - an `Intrinsics.assume` SCI op; this intrinsic exists for hand-written - user annotations and as the lattice-level shape the dataflow analyses - recognise (so a future pass that does insert SCI-level assumes still - composes correctly with divisibility/bounds). + Consumer-driven assumes (`make_tensor_view`, `load_ptr_tko`, + `store_ptr_tko`) are emitted directly to bytecode at codegen time — + the chain comes from `op_predicates` / `arg_chain` (analysis/assume.jl) + and the wrapping happens via `wrap_for` (intrinsics/views.jl), so they + never materialise as `Intrinsics.assume` SCI ops. 
This intrinsic exists + for hand-written user annotations and as the lattice-level shape the + dataflow analyses recognise (so a future pass that does insert + SCI-level assumes still composes correctly with divisibility/bounds). cuTile Python uses one IR op class per predicate (`AssumeDivBy`, `AssumeBounded`, …); we collapse to a single polymorphic intrinsic diff --git a/src/compiler/intrinsics/views.jl b/src/compiler/intrinsics/views.jl index 5a55c86b..0ba83a63 100644 --- a/src/compiler/intrinsics/views.jl +++ b/src/compiler/intrinsics/views.jl @@ -249,13 +249,13 @@ Constructs a `TensorView` from a destructured `TileArray`; lowers to The first argument is a compile-time constant `TileArray` type. Its `ArraySpec` (alignment, contiguity, per-axis divisibility) plus the -divisibility / bounds dataflow analyses are aggregated by -`analyze_assume_info` (analysis/assume.jl) into a per-operand -predicate sidecar; codegen reads the sidecar via `predicates_for` and -wraps each operand `Value` with `encode_AssumeOp!` before feeding the -result to `encode_MakeTensorViewOp!`. `sizes` and `strides` are tuples -in Julia (column-major) order; they are reversed for Tile IR's -row-major layout. +divisibility / bounds dataflow analyses feed `op_predicates` +(analysis/assume.jl) at codegen time to derive an `AssumePredicate` +chain per operand; `wrap_for` consults the per-`Value` cache so each +source `Value` is wrapped at most once across all consumers, then the +wrapped operands are fed to `encode_MakeTensorViewOp!`. `sizes` and +`strides` are tuples in Julia (column-major) order; they are reversed +for Tile IR's row-major layout. 
""" @intrinsic make_tensor_view(::Type{T}, ptr, sizes, strides) where {T} function tfunc(𝕃, ::typeof(Intrinsics.make_tensor_view), @@ -293,26 +293,47 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.make_tensor_view), args length(stride_tvs) == ndim || throw(IRError("make_tensor_view: expected $ndim strides, got $(length(stride_tvs))")) - # Wrap each operand `Value` with the AssumeOp chain the aggregator - # (analysis/assume.jl) computed for this make_tensor_view call. The - # bundle is per-mtv and tuple-shaped: `mtv_preds.ptr` is the ptr - # chain, `mtv_preds.sizes[i]` / `mtv_preds.strides[i]` are per-axis - # chains in Julia (column-major) order. Each `encode_AssumeOp!` - # emits one Tile IR `AssumeOp` and returns a fresh `Value` that we - # thread into the next link of the chain. - mtv_preds = predicates_for(ctx.assume_info, ctx.current_ssa_idx) - - base_ptr = wrap_chain!(cb, base_ptr, ptr_tv.type_id::TypeId, - mtv_preds === nothing ? EMPTY_PREDS : mtv_preds.ptr) + # Wrap each operand `Value` with the `AssumeOp` chain derived + # on demand from the divby/bounds dataflow plus this MTV's spec. + # `wrap_for` consults `ctx.assume_wrapped` so a `Value` shared + # with another consumer (e.g. a gather over the same kernel-arg + # ptr) — or with this kernel's entry-time slot wrap — is wrapped + # exactly once. For tuple-typed sizes/strides we walk back via + # `tuple_element_source` to the per-axis source SSA so the + # dataflow query has the right anchor; when the source is opaque + # (wholesale `getfield(arg, :sizes)`) `nothing` falls through to + # spec-only facts. + block = ctx.current_block::Block + + # Spec-derived divisor hints (1 = "no info") combine with the + # dataflow via `lcm` inside `op_predicates`, so a missing spec + # collapses cleanly to dataflow-only facts. + align_hint = spec === nothing ? 1 : Int(spec.alignment) + shape_hint(i) = spec === nothing ? 1 : Int(spec.shape_div_by[i]) + stride_hint(i) = spec === nothing ? 
1 : Int(spec.stride_div_by[i]) + + base_ptr = wrap_for(ctx, base_ptr, ptr_tv.type_id::TypeId, + op_predicates(ctx.divby_info, ctx.bounds_info, + ptr_arg, :ptr, align_hint)) size_vals = Value[ - wrap_chain!(cb, tv.v::Value, tv.type_id::TypeId, - mtv_preds === nothing ? EMPTY_PREDS : mtv_preds.sizes[i]) + let elem_op = tuple_element_source(block, sizes_arg, i) + wrap_for(ctx, tv.v::Value, tv.type_id::TypeId, + op_predicates(ctx.divby_info, ctx.bounds_info, + elem_op, :size, shape_hint(i))) + end for (i, tv) in enumerate(size_tvs) ] stride_vals = Value[ - wrap_chain!(cb, tv.v::Value, tv.type_id::TypeId, - mtv_preds === nothing ? EMPTY_PREDS : mtv_preds.strides[i]) + let elem_op = tuple_element_source(block, strides_arg, i), + # Skip the contiguous axis: its stride is statically `1` + # and never enters the bytecode kernel signature + # (`filter_dynamic_strides`). + chain = (spec !== nothing && spec.contiguous && i == 1) ? EMPTY_PREDS : + op_predicates(ctx.divby_info, ctx.bounds_info, + elem_op, :stride, stride_hint(i)) + wrap_for(ctx, tv.v::Value, tv.type_id::TypeId, chain) + end for (i, tv) in enumerate(stride_tvs) ] @@ -332,16 +353,50 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.make_tensor_view), args return CGVal(tensor_view, tv_type, result_jltype) end -# Apply a predicate chain to a `Value`: each `AssumeOp` returns a new -# `Value` that becomes the input to the next link. Empty chain returns -# the input unchanged. Caller passes `EMPTY_PREDS` when the assume -# sidecar has no entry for this operand. -@inline function wrap_chain!(cb::CodeBuilder, value::Value, type_id::TypeId, - preds::Vector{AssumePredicate}) +""" + wrap_for(ctx, value, type_id, preds) -> Value + +Apply an `AssumePredicate` chain to a `Value` at most once across all +consumers. `ctx.assume_wrapped` records the first wrap so subsequent +consumers of the same source `Value` reuse it instead of emitting a +parallel `AssumeOp` chain. Empty chain returns the input unchanged. 
+Mirrors the role of cuTile Python's `var_map` in +`_passes/propagate_divby.py::_add_assume_divby`. + +Cache invariant: the cache keys on `Value` only, *not* on the chain +contents. This is sound only when every consumer-derived chain on a +given `Value` is a subset of the first-seen chain — i.e. the first +wrap establishes an upper bound on the facts that any later consumer +would derive. The pipeline arranges this in two ways: + +- **Kernel-arg slots:** `apply_arg_assume_predicates!` runs at kernel + entry and seeds the cache with the spec-tightest chain for each + TileArray-derived flat slot. Consumer-site `op_predicates` calls on + SSAs sourced from the same slot can only re-derive a subset (same + spec hints, equally-tight or looser dataflow), so the cache hit + drops no information. +- **Per-`Value` consistency of structural priors:** `op_predicates`'s + `kind` selector (`:ptr` vs. `:size`/`:stride`) is determined by the + operand's tile type. A single `Value` has one tile type, so all + consumers see the same `kind` and the same structural prior. + +If you ever introduce a consumer that derives a *tighter* chain on a +`Value` already wrapped at kernel entry, the cache will silently drop +the extra facts. Either route the new consumer through a fresh `Value` +(common — the post-offset gather ptr already does this) or refine the +cache key. 
+""" +@inline function wrap_for(ctx::CGCtx, value::Value, type_id::TypeId, + preds::Vector{AssumePredicate}) + isempty(preds) && return value + cached = get(ctx.assume_wrapped, value, nothing) + cached !== nothing && return cached + wrapped = value for p in preds - value = encode_AssumeOp!(cb, type_id, value, p) + wrapped = encode_AssumeOp!(ctx.cb, type_id, wrapped, p) end - return value + ctx.assume_wrapped[value] = wrapped + return wrapped end """ diff --git a/src/compiler/transform/pipeline.jl b/src/compiler/transform/pipeline.jl index 0dd8a7a5..00df84d9 100644 --- a/src/compiler/transform/pipeline.jl +++ b/src/compiler/transform/pipeline.jl @@ -145,6 +145,22 @@ const ALGEBRA_RULES = RewriteRule[ @rewriter Intrinsics.addi(Intrinsics.reshape(~x, ~s), ~c) => commute_arith_transparent @rewriter Intrinsics.subi(Intrinsics.broadcast(~x, ~s), ~c) => commute_arith_transparent @rewriter Intrinsics.addi(Intrinsics.broadcast(~x, ~s), ~c) => commute_arith_transparent + + # muli identity: x * 1 → x. Drives the contiguous-axis stride fold in + # gather/scatter offset chains: for a `TileArray` with `ArraySpec` + # `contiguous=true`, `getfield(getfield(arg, :strides), 1)` is statically + # `1` (recognised by `analyze_constants`), and constant analysis + # propagates through `broadcast`/`reshape`/`from_scalar`. The fold drops + # the `muli(idx, broadcast(1))` so the surviving offset has the + # `idx + idx_other_axis * stride_other` shape that tileiras's + # auto-vectorizer matches against — without it, the contiguous-axis + # stride is a runtime value and consecutive lanes' addresses differ by + # an unknown scalar, forcing scalar (`STG.E.U16`) stores instead of + # wide vector (`STG.E.128`) stores. Mirrors Python cuTile's + # `_gather_scatter_pointer_and_mask` static-stride skip + # (`if static_stride == 1: offset_delta = ind`). 
+ @rewrite Intrinsics.muli(~x, $(1)) => ~x + @rewrite Intrinsics.muli($(1), ~x) => ~x ] algebra_pass!(sci::StructuredIRCode) = rewrite_patterns!(sci, ALGEBRA_RULES) @@ -236,12 +252,14 @@ const OPTIMIZATION_RULES = RewriteRule[ =============================================================================# """ - run_passes!(sci::StructuredIRCode) -> AssumeInfo + run_passes!(sci::StructuredIRCode) -> (DivByInfo, BoundsInfo) Run the full pass pipeline on a StructuredIRCode. Called for both -kernel and subprogram compilation. Returns the `AssumeInfo` aggregator -that codegen consumes when emitting `make_tensor_view` operands; the -caller stores it on the `CGCtx`. +kernel and subprogram compilation. Returns the divisibility / bounds +dataflow results; the caller stores them on the `CGCtx` so consumer-op +codegen (`make_tensor_view`, `load_ptr_tko`, `store_ptr_tko`) can +derive per-operand `AssumePredicate` chains on demand via +`op_predicates` (analysis/assume.jl). """ function run_passes!(sci::StructuredIRCode) canonicalize!(sci) @@ -257,8 +275,8 @@ function run_passes!(sci::StructuredIRCode) # before alias analysis so the alias map is built over the # deduplicated form. The dedup naturally extends to TileViews and # the `getfield(arg, :ptr|:sizes|:strides)` chains that feed them, - # which is what the downstream `analyze_assume_info` and - # `licm_pass!` benefit from most. + # which is what downstream consumers (codegen assume wraps, LICM) + # benefit from most. cse_pass!(sci) alias_info = analyze_aliases(sci) @@ -270,24 +288,21 @@ function run_passes!(sci::StructuredIRCode) licm_pass!(sci) - # Build the assume sidecar for codegen. Runs after LICM so the - # dataflow analyses see the post-LICM form. Pure analysis: does - # *not* mutate the SCI — `make_tensor_view` codegen reads the - # result and emits `AssumeOp`s on the per-element bytecode `Value`s - # that `resolve_tuple` produces. 
- divby = analyze_divisibility(sci) - bnds = analyze_bounds(sci) - assume_info = analyze_assume_info(sci, divby, bnds) + # Run dataflow analyses. Pure: does *not* mutate the SCI — the + # codegen consumer ops query these on demand to derive each + # operand's `AssumePredicate` chain. + divby_info = analyze_divisibility(sci) + bounds_info = analyze_bounds(sci) # Attach `no_signed_wrap` / `no_unsigned_wrap` flags to integer # arithmetic where the bounds analysis proves the result fits in - # the destination width. Reuses the same `bnds` result; mutates - # `addi`/`subi`/`muli` Exprs in place by appending an + # the destination width. Reuses the same `bounds_info` result; + # mutates `addi`/`subi`/`muli` Exprs in place by appending an # `IntegerOverflow.T` operand that the codegen forwards as the # encoder's overflow kwarg. - no_wrap_pass!(sci, bnds) + no_wrap_pass!(sci, bounds_info) dce_pass!(sci) - return assume_info + return divby_info, bounds_info end diff --git a/src/compiler/utils.jl b/src/compiler/utils.jl index 1e3a3eb6..fb6d5c63 100644 --- a/src/compiler/utils.jl +++ b/src/compiler/utils.jl @@ -302,14 +302,28 @@ mutable struct CGCtx # Kernel linkage name (for debug info subprogram) linkage_name::String - # Per-make_tensor_view assume predicates, populated by `run_passes!` - # via `analyze_assume_info`. `nothing` when no pipeline ran (e.g. - # tests building a CGCtx by hand). Queried by `make_tensor_view`'s - # codegen; see `analysis/assume.jl`. + # Dataflow analyses, populated by `run_passes!`. `nothing` when no + # pipeline ran (e.g. tests building a CGCtx by hand). Queried by + # `op_predicates` (analysis/assume.jl) at consumer sites + # (`make_tensor_view`, `load_ptr_tko`, `store_ptr_tko`) to derive + # per-operand `AssumePredicate` chains on demand. # - # Untyped (vs. `Union{AssumeInfo, Nothing}`) because `AssumeInfo` is - # defined in `analysis/assume.jl`, included after this file. - assume_info::Any + # Untyped (vs. `Union{DivByInfo, Nothing}` etc.) 
because the result + # types are defined in `analysis/`, included after this file. + divby_info::Any + bounds_info::Any + + # Per-`Value` `AssumeOp` wrap cache. The first consumer that wraps a + # given `Value` records the result here; subsequent consumers reuse + # it instead of emitting a parallel `AssumeOp` chain. Mirrors cuTile + # Python's `var_map` dedup in `_passes/propagate_divby.py`. Reset + # per kernel by `CGCtx`'s constructor. + assume_wrapped::Dict{Value, Value} + + # Block currently being emitted. Set by `emit_block!` per region so + # `tuple_element_source` and other parent-walking queries can start + # from the right scope. `nothing` when no block has been entered yet. + current_block::Any end function CGCtx(; cb::CodeBuilder, tt::TypeTable, sci::StructuredIRCode, @@ -338,7 +352,10 @@ function CGCtx(; cb::CodeBuilder, tt::TypeTable, sci::StructuredIRCode, FPMode[], # fpmode_stack debug_emitter, linkage_name, - nothing, # assume_info — set by run_passes! + nothing, # divby_info — set by run_passes! + nothing, # bounds_info — set by run_passes! + Dict{Value, Value}(), # assume_wrapped + nothing, # current_block — set by emit_block! ) end diff --git a/test/codegen/assume.jl b/test/codegen/assume.jl index a157e8f6..07dd5a96 100644 --- a/test/codegen/assume.jl +++ b/test/codegen/assume.jl @@ -1,9 +1,10 @@ -# Codegen tests for the `analyze_assume_info` aggregator + the -# `make_tensor_view` codegen path that wraps each operand `Value` with -# `encode_AssumeOp!`. Facts come from the TileArray-type `ArraySpec` -# plus the divisibility dataflow analysis (analysis/divisibility.jl) -# and bounds dataflow analysis (analysis/bounds.jl), so derived -# TileArrays (slices, permutes, reshapes) get assumes too — recovering +# Codegen tests for the consumer-driven `AssumeOp` emission path +# (`make_tensor_view`, `load_ptr_tko`, `store_ptr_tko`) plus the +# kernel-arg-entry wrap. 
Chains are derived on demand by `op_predicates` +# / `arg_chain` (analysis/assume.jl) from the TileArray-type `ArraySpec` +# combined with the divisibility (analysis/divisibility.jl) and bounds +# (analysis/bounds.jl) dataflow analyses, so derived TileArrays +# (slices, permutes, reshapes) get assumes too — recovering # through-arithmetic facts that the conservative `sliced_arraytype` # etc. drop. @@ -135,3 +136,56 @@ end @check "assume bounded<0, ?>" end end + +@testset "assume — kernel-arg ptr wrap survives offset for gather/scatter" begin + # Pure-gather/scatter kernel: no MTV consumes the kernel-arg ptr, so + # the only path that can attach `spec.alignment` to the base pointer + # is the kernel-arg-entry wrap (`apply_arg_assume_predicates!`). The + # post-offset ptr that reaches `load_ptr_tko` / `store_ptr_tko` then + # carries the base-alignment fact (via the assumed `Value` flowing + # through `reshape` → `broadcast` → `offset`) plus a looser local + # divby chain wrapped at the consumer site. + spec1d = ct.ArraySpec{1}(128, true) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, + ct.TileArray{Float32,1,spec1d}}) do a, b + indices = ct.arange(16) + tile = ct.gather(a, indices) + ct.scatter(b, indices, tile) + return + end + # Base alignment on each kernel-arg ptr (entry wrap). + @check "assume div_by<128>" + @check "assume div_by<128>" + # Local divby chain on the post-offset ptr at each consumer. + @check "assume div_by<" + @check "load_ptr_tko" + @check "assume div_by<" + @check "store_ptr_tko" + end +end + +@testset "assume — shared ptr Value is wrapped once across consumers" begin + # Two MTVs (`ct.load` + `ct.store`) plus a gather all source from + # the same kernel-arg ptr. The entry wrap puts one `assume div_by<128>` + # on it; the per-`Value` cache (`ctx.assume_wrapped`) ensures the + # MTV consumer wraps don't re-emit the same predicate on the same + # source. 
The post-offset gather ptr is a different `Value` and + # gets its own (looser) divby chain. + spec1d = ct.ArraySpec{1}(128, true) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}}) do a + tile = ct.load(a, 1, (16,)) + indices = ct.arange(16) + tile2 = ct.gather(a, indices) + ct.store(a, 1, tile + tile2) + return + end + # Exactly one `assume div_by<128>` despite three consumers of + # the same kernel-arg ptr. + @check "assume div_by<128>" + @check_not "assume div_by<128>" + end +end diff --git a/test/codegen/bounds.jl b/test/codegen/bounds.jl index cd3ecf06..f2197822 100644 --- a/test/codegen/bounds.jl +++ b/test/codegen/bounds.jl @@ -1,9 +1,9 @@ # Codegen tests for `BoundsAnalysis` and the bounded-predicate -# emission path through `analyze_assume_info`. The analysis tracks -# integer-valued SSA values to a closed `[lo, hi]` interval; the -# aggregator intersects the result with the structural lower bound -# (sizes/strides ≥ 0) and emits a sharper `Bounded(lo, hi)` predicate -# where the dataflow has information. +# emission path through `op_predicates` (analysis/assume.jl). The +# analysis tracks integer-valued SSA values to a closed `[lo, hi]` +# interval; `op_predicates` intersects the result with the structural +# lower bound (sizes/strides ≥ 0) and emits a sharper `Bounded(lo, hi)` +# predicate where the dataflow has information. # # Today the dataflow doesn't add information at `make_tensor_view` # operands for typical kernels (sizes come from `getfield(arg, :sizes)` diff --git a/test/codegen/integration.jl b/test/codegen/integration.jl index 84f587b2..7725d0b7 100644 --- a/test/codegen/integration.jl +++ b/test/codegen/integration.jl @@ -564,6 +564,46 @@ end end end end + + @testset "contiguous-axis stride folds out of 2D gather offset" begin + spec_out = ct.ArraySpec{1}(16, true) + + # Contiguous 2D source: stride[1]=1 statically, the muli for axis-1 + # folds away. 
Exactly one muli (for axis-2 stride) survives. + spec2d_c = ct.ArraySpec{2}(16, true) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,2,spec2d_c}, + ct.TileArray{Float32,1,spec_out}}) do a, b + pid = ct.bid(1) + i0 = ct.arange(16) + i1 = ct.arange(16) + tile = ct.gather(a, (i0, i1)) + ct.store(b, pid, tile) + return + end + @check "muli" + @check_not "muli" + end + + # Sibling: non-contiguous keeps both stride multiplies. + # Confirms the fold is gated on `Spec.contiguous`, not unconditional. + spec2d_nc = ct.ArraySpec{2}(16, false) + @test @filecheck begin + @check_label "entry" + code_tiled(Tuple{ct.TileArray{Float32,2,spec2d_nc}, + ct.TileArray{Float32,1,spec_out}}) do a, b + pid = ct.bid(1) + i0 = ct.arange(16) + i1 = ct.arange(16) + tile = ct.gather(a, (i0, i1)) + ct.store(b, pid, tile) + return + end + @check "muli" + @check "muli" + end + end end #=========================================================================