Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions src/compiler/analysis/effects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,40 @@ function is_atomic_intrinsic(func)
end
return false
end

"""
intrinsic_effects(func) -> Union{CC.Effects, Nothing}

Declared effects of a cuTile intrinsic, or `nothing` for non-intrinsic callees.
Single source of truth for transform passes that need per-intrinsic effect
information (rewriter flag recomputation, DCE root classification).

Starts from `EFFECTS_TOTAL` — intrinsic methods are `not_callable()` bodies with
no observable effect — and applies any `efunc` override. Returns `nothing` for
non-intrinsic callees: purity of arbitrary Julia functions isn't ours to claim,
and callers should treat `nothing` as "unknown, be conservative".
"""
function intrinsic_effects(@nospecialize(func))
    # Only functions that live in the `Intrinsics` module are classified here;
    # every other callee is reported as `nothing` (unknown).
    if !(func isa Function) || parentmodule(func) !== Intrinsics
        return nothing
    end
    # Begin with the all-total baseline and let `efunc` replace it when it
    # returns something other than `nothing` for this intrinsic.
    baseline = CC.EFFECTS_TOTAL
    declared = efunc(func, baseline)
    return declared === nothing ? baseline : declared
end

"""
inferred_flags(func) -> UInt32

IR flags corresponding to `func`'s declared effects, mirroring inference's
`flags_for_effects`. Used by the rewriter to set fresh flags on inserted or
opcode-changed instructions, so downstream gates (CSE, LICM) see the same
information they would have gotten from a fresh inference.

Returns `IR_FLAG_NULL` for non-intrinsic callees — see `intrinsic_effects`.
"""
function inferred_flags(@nospecialize(func))
    declared = intrinsic_effects(func)
    # `nothing` means the callee's effects are unknown, so no flag bits are
    # claimed for it.
    if declared === nothing
        return CC.IR_FLAG_NULL
    end
    # Known effects: translate them into the matching IR flag bits.
    return CC.flags_for_effects(declared)
end
21 changes: 7 additions & 14 deletions src/compiler/transform/dce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ end

Check if a statement is side-effectful and must be kept as a root.

Uses the Julia effects system: each cuTile intrinsic has an `efunc` override
that specifies `effect_free=ALWAYS_FALSE` for side-effectful operations
(stores, atomics, assert). Intrinsics without an efunc override are pure.
Unknown calls are conservatively kept.
Uses the Julia effects system via `intrinsic_effects`: each cuTile intrinsic
has an `efunc` override that specifies `effect_free=ALWAYS_FALSE` for
side-effectful operations (stores, atomics, assert). Intrinsics without an
efunc override are pure. Unknown calls are conservatively kept.

Mirrors Python cuTile's `_must_keep` (dce.py:205-206) and Julia's compiler
`stmt_effect_free` — both classify by per-instruction effect annotations.
Expand All @@ -91,16 +91,9 @@ function must_keep(block::Block, @nospecialize(s))
call = resolve_call(block, s)
if call !== nothing
resolved_func, _ = call
# cuTile intrinsics: use the efunc effects system
if resolved_func isa Function && parentmodule(resolved_func) === Intrinsics
# Query the efunc override for this intrinsic
override = efunc(resolved_func, CC.Effects())
if override !== nothing
# Has custom effects — keep if not effect-free
return override.effect_free !== CC.ALWAYS_TRUE
end
# No efunc override → pure intrinsic, safe to remove
return false
effects = intrinsic_effects(resolved_func)
if effects !== nothing
return effects.effect_free !== CC.ALWAYS_TRUE
end
# getfield is pure
if s isa Expr
Expand Down
12 changes: 7 additions & 5 deletions src/compiler/transform/pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,22 +95,24 @@ function commute_arith_transparent(sci, block, inst, match, driver)
# Insert broadcast of the scalar to x's shape and register as constant
x_shape = size(xT)
bc_type = Tile{eltype(xT), Tuple{x_shape...}}
bc = insert_before!(block, val, Expr(:call, Intrinsics.broadcast, scalar, x_shape), bc_type)
bc = insert_before!(block, val, Expr(:call, Intrinsics.broadcast, scalar, x_shape), bc_type;
flag=inferred_flags(Intrinsics.broadcast))
notify_insert!(driver, block, bc)
# Side-inject the freshly synthesized constant into the dataflow result so
# downstream pattern matches see it. Bypasses tmerge (this is a brand-new
# SSA value, not a merge).
driver.constants[SSAValue(bc)] = convert(eltype(xT), scalar)

# Insert op(x, broadcast) with x's type
op = insert_before!(block, val, Expr(:call, root_func, x, SSAValue(bc)), xT)
op = insert_before!(block, val, Expr(:call, root_func, x, SSAValue(bc)), xT;
flag=inferred_flags(root_func))
notify_insert!(driver, block, op)

# Replace root with transparent_op(op_result, s). Func changes
# (subi/addi → reshape/broadcast), so clear the flag the inferred
# effect bits describe the old op, not the transparent one.
# (subi/addi → reshape/broadcast), so recompute the flag from the new
# func's declared effects — the inferred bits describe the OLD op.
block[val.id] = (stmt=Expr(:call, transparent_func, SSAValue(op), match.bindings[:s]),
flag=CC.IR_FLAG_NULL)
flag=inferred_flags(transparent_func))
driver.defs[val] = DefEntry(block, val, transparent_func)
push!(driver.worklist, val)
add_users_to_worklist!(driver, val)
Expand Down
18 changes: 10 additions & 8 deletions src/compiler/transform/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,8 @@ function resolve_rhs(driver::RewriteDriver, block, ref, op::RCall, bindings, roo
typ = CC.widenconst(t)
break
end
inst = insert_before!(block, ref, Expr(:call, op.func, operands...), typ)
inst = insert_before!(block, ref, Expr(:call, op.func, operands...), typ;
flag=inferred_flags(op.func))
notify_insert!(driver, block, inst)
SSAValue(inst)
end
Expand All @@ -401,13 +402,14 @@ function apply_inplace_rewrite!(driver::RewriteDriver, block, val::SSAValue, rul
for (op, lhs_op) in zip(rule.rhs.operands, rule.lhs.operands)]
new_stmt = Expr(:call, rule.rhs.func, new_operands...)
# Same-func rewrites (most common: only operands change) preserve flag
# via the partial-NamedTuple setindex. Different-func rewrites clear it to
# IR_FLAG_NULL since the inferred effects describe the OLD op (LLVM
# `copyIRFlags` analogue: don't blanket-inherit when the opcode changes).
# via the partial-NamedTuple setindex. Different-func rewrites recompute
# the flag from the new func's declared effects (`efunc` overrides),
# mirroring inference's `flags_for_effects` — the inferred bits on the
# old call describe the OLD op and don't carry over.
if rule.rhs.func === driver.defs[val].func
block[val.id] = (stmt=new_stmt,)
else
block[val.id] = (stmt=new_stmt, flag=CC.IR_FLAG_NULL)
block[val.id] = (stmt=new_stmt, flag=inferred_flags(rule.rhs.func))
end
driver.defs[val] = DefEntry(block, val, rule.rhs.func)
push!(driver.worklist, val)
Expand Down Expand Up @@ -526,13 +528,13 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match
end
# The substituted func almost always differs from the matched root
# (otherwise this would be an inplace rule). Inferred IR_FLAG_* bits
# describe the OLD op; reset to IR_FLAG_NULL on func change so
# downstream effect checks (CSE, LICM) don't act on stale info.
# describe the OLD op; recompute from the new func's `efunc` effects
# so downstream gates (CSE, LICM) see fresh, correct information.
new_stmt = Expr(:call, rule.rhs.func, operands...)
if rule.rhs.func === driver.defs[val].func
block[val.id] = (stmt=new_stmt,)
else
block[val.id] = (stmt=new_stmt, flag=CC.IR_FLAG_NULL)
block[val.id] = (stmt=new_stmt, flag=inferred_flags(rule.rhs.func))
end
# Update defs, re-add self and users to worklist (statement changed)
driver.defs[val] = DefEntry(block, val, rule.rhs.func)
Expand Down
Loading