Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
319 changes: 123 additions & 196 deletions src/compiler/analysis/assume.jl

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions src/compiler/analysis/bounds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
# precision on loop-carried values that vary per iteration; gains
# precision on literal-constant flow and ForOp induction variables.
#
# Consumed by `analyze_assume_info` (analysis/assume.jl) to emit
# sharper `Bounded(...)` predicates on `make_tensor_view` operands at
# codegen time, and by `no_wrap_pass!` (transform/no_wrap.jl) to
# Consumed at codegen time by `op_predicates` (analysis/assume.jl) to
# emit sharper `Bounded(...)` predicates on `make_tensor_view` size /
# stride operands, and by `no_wrap_pass!` (transform/no_wrap.jl) to
# attach `nsw`/`nuw` flags on integer arithmetic that provably fits in
# its destination width.

Expand Down
11 changes: 6 additions & 5 deletions src/compiler/analysis/divisibility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
# transfer rules for arithmetic, pointer offset, and getfield chains
# rooted at TileArray arguments.
#
# Consumed by `analyze_assume_info` (analysis/assume.jl), which combines
# this with the bounds analysis and the operand TileArray's `ArraySpec`
# into per-`make_tensor_view` predicate bundles. Codegen reads those
# bundles and wraps each operand `Value` with `encode_AssumeOp!` —
# the analysis itself does *not* mutate the SCI.
# Consumed at codegen time by `op_predicates` (analysis/assume.jl),
# which combines this with the bounds analysis and the operand
# TileArray's `ArraySpec` into per-operand `AssumePredicate` chains.
# Codegen reads those chains via `wrap_for` and wraps each consumer's
# operand `Value` with `encode_AssumeOp!` — the analysis itself does
# *not* mutate the SCI.

"""
DivByAnalysis
Expand Down
38 changes: 24 additions & 14 deletions src/compiler/codegen/control_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,32 @@ All SSA values use original Julia SSA indices (no local renumbering).
Values are stored in ctx.values by their original index.
"""
function emit_block!(ctx::CGCtx, block::Block; skip_terminator::Bool=false)
for inst in instructions(block)
# Set debug location for this instruction
if ctx.debug_emitter !== nothing
ln = isempty(ctx.linkage_name) ? nothing : ctx.linkage_name
ctx.cb.cur_debug_attr = resolve_debug_attr!(
ctx.debug_emitter, ctx.sci, inst.ssa_idx; linkage_name=ln)
# Track the current block so consumer-op codegen can perform parent-
# walking queries (e.g. `tuple_element_source` for MTV size/stride
# operands) starting from the right scope. Restored on exit so a
# caller still in an outer block sees its own context.
prev_block = ctx.current_block
ctx.current_block = block
try
for inst in instructions(block)
# Set debug location for this instruction
if ctx.debug_emitter !== nothing
ln = isempty(ctx.linkage_name) ? nothing : ctx.linkage_name
ctx.cb.cur_debug_attr = resolve_debug_attr!(
ctx.debug_emitter, ctx.sci, inst.ssa_idx; linkage_name=ln)
end
s = inst[:stmt]
if s isa ControlFlowOp
emit_control_flow_op!(ctx, s, value_type(inst), inst.ssa_idx)
else
emit_statement!(ctx, s, inst.ssa_idx, value_type(inst))
end
end
s = inst[:stmt]
if s isa ControlFlowOp
emit_control_flow_op!(ctx, s, value_type(inst), inst.ssa_idx)
else
emit_statement!(ctx, s, inst.ssa_idx, value_type(inst))
if !skip_terminator && terminator(block) !== nothing
emit_terminator!(ctx, terminator(block))
end
end
if !skip_terminator && terminator(block) !== nothing
emit_terminator!(ctx, terminator(block))
finally
ctx.current_block = prev_block
end
end

Expand Down
69 changes: 59 additions & 10 deletions src/compiler/codegen/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,23 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
# Set up argument values
arg_values = make_block_args!(cb, length(param_types))

# Hoist early returns BEFORE token ordering — hoist_returns! rewrites
# ReturnNode terminators to YieldOp, which the token pass then extends.
hoist_returns!(ctx.sci.entry)

# Run the pass pipeline (normalize, optimize, token ordering, DCE).
# Returns the dataflow results consumed at consumer codegen sites.
ctx.divby_info, ctx.bounds_info = run_passes!(sci)

# Wrap each TileArray-derived flat kernel-arg `Value` with the
# `AssumeOp` chain its `ArraySpec` justifies, *before* any consumer
# reads it. This puts the spec.alignment proof on the base pointer
# at kernel entry — gather/scatter offset chains downstream
# inherit it, which is what tileiras's vectorizer needs to lower
# to `STG.E.128` / `LDG.E.128` (the post-offset operand alone gives
# only the lane-stride alignment, not the base alignment).
apply_arg_assume_predicates!(ctx, arg_values, param_mapping, param_types, sci)

# Build arg_flat_values map. User args and the trailing KernelState
# pieces land here — they go through the same `param_mapping`-keyed path.
# `kernel_state()` resolves to a lazy arg_ref into this map.
Expand Down Expand Up @@ -164,14 +181,6 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
ctx[Argument(arg_idx)] = tv
end

# Hoist early returns BEFORE token ordering — hoist_returns! rewrites
# ReturnNode terminators to YieldOp, which the token pass then extends.
hoist_returns!(ctx.sci.entry)

# Run the pass pipeline (normalize, optimize, token ordering, DCE).
# Returns the AssumeInfo sidecar consumed by `make_tensor_view` codegen.
ctx.assume_info = run_passes!(sci)

# Cache the token bytecode type for codegen
ctx.token_type = Token(tt)

Expand All @@ -181,6 +190,45 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
finalize_function!(func_buf, cb, writer.debug_info)
end

"""
apply_arg_assume_predicates!(ctx, arg_values, param_mapping, param_types, sci)

Wrap each `TileArray`-derived flat kernel-arg `Value` in `arg_values`
with the `AssumeOp` chain its `ArraySpec` implies. Mutates `arg_values`
in place. The slot path (`[1]` is `ptr`, `[2, i]` is `sizes[i]`,
`[3, i]` is `strides[i]`) maps to a chain via `arg_chain` (analysis/
assume.jl).

After wrapping, the `wrapped` `Value` is recorded as a fixed point in
`ctx.assume_wrapped`: a consumer-side `wrap_for(ctx, wrapped, ...)`
hits the cache and returns `wrapped` unchanged, so a kernel-arg ptr
consumed by both an MTV and a gather is wrapped *exactly once* across
the kernel.
"""
function apply_arg_assume_predicates!(ctx::CGCtx, arg_values::Vector{Value},
param_mapping::Vector{Tuple{Int, Vector{Int}}},
param_types::Vector{TypeId},
sci::StructuredIRCode)
for param_idx in eachindex(arg_values)
arg_idx, path = param_mapping[param_idx]
# Trailing `KernelState` arg has no Julia argtype entry.
arg_idx > length(sci.argtypes) && continue
argT = CC.widenconst(sci.argtypes[arg_idx])
argT <: TileArray || continue
chain = arg_chain(argT, path)
isempty(chain) && continue
original = arg_values[param_idx]
wrapped = original
for p in chain
wrapped = encode_AssumeOp!(ctx.cb, param_types[param_idx], wrapped, p)
end
arg_values[param_idx] = wrapped
# Mark the wrapped `Value` as a fixed point so subsequent
# consumer wraps don't re-emit the same predicates.
ctx.assume_wrapped[wrapped] = wrapped
end
end

"""
flatten_struct_params!(ctx, param_types, param_mapping, arg_idx, T, path)

Expand Down Expand Up @@ -316,14 +364,15 @@ function emit_subprogram!(ctx::CGCtx, func, arg_types::Vector,
end

# 2b. Run the pass pipeline on subprogram IR
sub_assume_info = run_passes!(sci)
sub_divby, sub_bounds = run_passes!(sci)

# 3. Create sub-context (inherits active fpmode from caller)
sub_ctx = CGCtx(; ctx.cb, ctx.tt, sci,
ctx.token_type,
ctx.type_cache, ctx.sm_arch,
ctx.cache)
sub_ctx.assume_info = sub_assume_info
sub_ctx.divby_info = sub_divby
sub_ctx.bounds_info = sub_bounds
append!(sub_ctx.fpmode_stack, ctx.fpmode_stack)

# Inherit kernel-state flat values from the parent. Subprograms compile
Expand Down
8 changes: 6 additions & 2 deletions src/compiler/intrinsics/memory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.load_ptr_tko), args)
# args: (ptrs, latency, mask?, padding?)
ptrs_tv = emit_value!(ctx, args[1])
ptrs_tv === nothing && throw(IRError("load_ptr_tko: cannot resolve pointer tile"))
pointers = ptrs_tv.v
pointers = wrap_for(ctx, ptrs_tv.v::Value, ptrs_tv.type_id::TypeId,
op_predicates(ctx.divby_info, ctx.bounds_info,
args[1], :ptr))
tile_shape = ptrs_tv.shape

ptrs_type = CC.widenconst(ptrs_tv.jltype)
Expand Down Expand Up @@ -99,7 +101,9 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.store_ptr_tko), args)

ptrs_tv = emit_value!(ctx, args[1])
ptrs_tv === nothing && throw(IRError("store_ptr_tko: cannot resolve pointer tile"))
pointers = ptrs_tv.v
pointers = wrap_for(ctx, ptrs_tv.v::Value, ptrs_tv.type_id::TypeId,
op_predicates(ctx.divby_info, ctx.bounds_info,
args[1], :ptr))

values_tv = emit_value!(ctx, args[2])
values_tv === nothing && throw(IRError("store_ptr_tko: cannot resolve values tile"))
Expand Down
14 changes: 8 additions & 6 deletions src/compiler/intrinsics/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,14 @@ end
`SameElements`). Returns its input value — a pure-data annotation,
eliminated if downstream uses vanish.

The make_tensor_view assume bundle is emitted directly to bytecode by
`analyze_assume_info` + `views.jl` codegen and never materialises as
an `Intrinsics.assume` SCI op; this intrinsic exists for hand-written
user annotations and as the lattice-level shape the dataflow analyses
recognise (so a future pass that does insert SCI-level assumes still
composes correctly with divisibility/bounds).
Consumer-driven assumes (`make_tensor_view`, `load_ptr_tko`,
`store_ptr_tko`) are emitted directly to bytecode at codegen time —
the chain comes from `op_predicates` / `arg_chain` (analysis/assume.jl)
and the wrapping happens via `wrap_for` (intrinsics/views.jl), so they
never materialise as `Intrinsics.assume` SCI ops. This intrinsic exists
for hand-written user annotations and as the lattice-level shape the
dataflow analyses recognise (so a future pass that does insert
SCI-level assumes still composes correctly with divisibility/bounds).

cuTile Python uses one IR op class per predicate (`AssumeDivBy`,
`AssumeBounded`, …); we collapse to a single polymorphic intrinsic
Expand Down
115 changes: 85 additions & 30 deletions src/compiler/intrinsics/views.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,13 @@ Constructs a `TensorView` from a destructured `TileArray`; lowers to

The first argument is a compile-time constant `TileArray` type. Its
`ArraySpec` (alignment, contiguity, per-axis divisibility) plus the
divisibility / bounds dataflow analyses are aggregated by
`analyze_assume_info` (analysis/assume.jl) into a per-operand
predicate sidecar; codegen reads the sidecar via `predicates_for` and
wraps each operand `Value` with `encode_AssumeOp!` before feeding the
result to `encode_MakeTensorViewOp!`. `sizes` and `strides` are tuples
in Julia (column-major) order; they are reversed for Tile IR's
row-major layout.
divisibility / bounds dataflow analyses feed `op_predicates`
(analysis/assume.jl) at codegen time to derive an `AssumePredicate`
chain per operand; `wrap_for` consults the per-`Value` cache so each
source `Value` is wrapped at most once across all consumers, then the
wrapped operands are fed to `encode_MakeTensorViewOp!`. `sizes` and
`strides` are tuples in Julia (column-major) order; they are reversed
for Tile IR's row-major layout.
"""
@intrinsic make_tensor_view(::Type{T}, ptr, sizes, strides) where {T}
function tfunc(𝕃, ::typeof(Intrinsics.make_tensor_view),
Expand Down Expand Up @@ -293,26 +293,47 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.make_tensor_view), args
length(stride_tvs) == ndim ||
throw(IRError("make_tensor_view: expected $ndim strides, got $(length(stride_tvs))"))

# Wrap each operand `Value` with the AssumeOp chain the aggregator
# (analysis/assume.jl) computed for this make_tensor_view call. The
# bundle is per-mtv and tuple-shaped: `mtv_preds.ptr` is the ptr
# chain, `mtv_preds.sizes[i]` / `mtv_preds.strides[i]` are per-axis
# chains in Julia (column-major) order. Each `encode_AssumeOp!`
# emits one Tile IR `AssumeOp` and returns a fresh `Value` that we
# thread into the next link of the chain.
mtv_preds = predicates_for(ctx.assume_info, ctx.current_ssa_idx)

base_ptr = wrap_chain!(cb, base_ptr, ptr_tv.type_id::TypeId,
mtv_preds === nothing ? EMPTY_PREDS : mtv_preds.ptr)
# Wrap each operand `Value` with the `AssumeOp` chain derived
# on demand from the divby/bounds dataflow plus this MTV's spec.
# `wrap_for` consults `ctx.assume_wrapped` so a `Value` shared
# with another consumer (e.g. a gather over the same kernel-arg
# ptr) — or with this kernel's entry-time slot wrap — is wrapped
# exactly once. For tuple-typed sizes/strides we walk back via
# `tuple_element_source` to the per-axis source SSA so the
# dataflow query has the right anchor; when the source is opaque
# (wholesale `getfield(arg, :sizes)`) `nothing` falls through to
# spec-only facts.
block = ctx.current_block::Block

# Spec-derived divisor hints (1 = "no info") combine with the
# dataflow via `lcm` inside `op_predicates`, so a missing spec
# collapses cleanly to dataflow-only facts.
align_hint = spec === nothing ? 1 : Int(spec.alignment)
shape_hint(i) = spec === nothing ? 1 : Int(spec.shape_div_by[i])
stride_hint(i) = spec === nothing ? 1 : Int(spec.stride_div_by[i])

base_ptr = wrap_for(ctx, base_ptr, ptr_tv.type_id::TypeId,
op_predicates(ctx.divby_info, ctx.bounds_info,
ptr_arg, :ptr, align_hint))

size_vals = Value[
wrap_chain!(cb, tv.v::Value, tv.type_id::TypeId,
mtv_preds === nothing ? EMPTY_PREDS : mtv_preds.sizes[i])
let elem_op = tuple_element_source(block, sizes_arg, i)
wrap_for(ctx, tv.v::Value, tv.type_id::TypeId,
op_predicates(ctx.divby_info, ctx.bounds_info,
elem_op, :size, shape_hint(i)))
end
for (i, tv) in enumerate(size_tvs)
]
stride_vals = Value[
wrap_chain!(cb, tv.v::Value, tv.type_id::TypeId,
mtv_preds === nothing ? EMPTY_PREDS : mtv_preds.strides[i])
let elem_op = tuple_element_source(block, strides_arg, i),
# Skip the contiguous axis: its stride is statically `1`
# and never enters the bytecode kernel signature
# (`filter_dynamic_strides`).
chain = (spec !== nothing && spec.contiguous && i == 1) ? EMPTY_PREDS :
op_predicates(ctx.divby_info, ctx.bounds_info,
elem_op, :stride, stride_hint(i))
wrap_for(ctx, tv.v::Value, tv.type_id::TypeId, chain)
end
for (i, tv) in enumerate(stride_tvs)
]

Expand All @@ -332,16 +353,50 @@ function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.make_tensor_view), args
return CGVal(tensor_view, tv_type, result_jltype)
end

# Apply a predicate chain to a `Value`: each `AssumeOp` returns a new
# `Value` that becomes the input to the next link. Empty chain returns
# the input unchanged. Caller passes `EMPTY_PREDS` when the assume
# sidecar has no entry for this operand.
@inline function wrap_chain!(cb::CodeBuilder, value::Value, type_id::TypeId,
preds::Vector{AssumePredicate})
"""
wrap_for(ctx, value, type_id, preds) -> Value

Apply an `AssumePredicate` chain to a `Value` at most once across all
consumers. `ctx.assume_wrapped` records the first wrap so subsequent
consumers of the same source `Value` reuse it instead of emitting a
parallel `AssumeOp` chain. Empty chain returns the input unchanged.
Mirrors the role of cuTile Python's `var_map` in
`_passes/propagate_divby.py::_add_assume_divby`.

Cache invariant: the cache keys on `Value` only, *not* on the chain
contents. This is sound only when every consumer-derived chain on a
given `Value` is a subset of the first-seen chain — i.e. the first
wrap establishes an upper bound on the facts that any later consumer
would derive. The pipeline arranges this in two ways:

- **Kernel-arg slots:** `apply_arg_assume_predicates!` runs at kernel
entry and seeds the cache with the spec-tightest chain for each
TileArray-derived flat slot. Consumer-site `op_predicates` calls on
SSAs sourced from the same slot can only re-derive a subset (same
spec hints, equally-tight or looser dataflow), so the cache hit
drops no information.
- **Per-`Value` consistency of structural priors:** `op_predicates`'s
`kind` selector (`:ptr` vs. `:size`/`:stride`) is determined by the
operand's tile type. A single `Value` has one tile type, so all
consumers see the same `kind` and the same structural prior.

If you ever introduce a consumer that derives a *tighter* chain on a
`Value` already wrapped at kernel entry, the cache will silently drop
the extra facts. Either route the new consumer through a fresh `Value`
(common — the post-offset gather ptr already does this) or refine the
cache key.
"""
@inline function wrap_for(ctx::CGCtx, value::Value, type_id::TypeId,
preds::Vector{AssumePredicate})
isempty(preds) && return value
cached = get(ctx.assume_wrapped, value, nothing)
cached !== nothing && return cached
wrapped = value
for p in preds
value = encode_AssumeOp!(cb, type_id, value, p)
wrapped = encode_AssumeOp!(ctx.cb, type_id, wrapped, p)
end
return value
ctx.assume_wrapped[value] = wrapped
return wrapped
end

"""
Expand Down
Loading
Loading