diff --git a/src/compiler/analysis/effects.jl b/src/compiler/analysis/effects.jl index 813ec068..0d18c97b 100644 --- a/src/compiler/analysis/effects.jl +++ b/src/compiler/analysis/effects.jl @@ -43,3 +43,40 @@ function is_atomic_intrinsic(func) end return false end + +""" + intrinsic_effects(func) -> Union{CC.Effects, Nothing} + +Declared effects of a cuTile intrinsic, or `nothing` for non-intrinsic callees. +Single source of truth for transform passes that need per-intrinsic effect +information (rewriter flag recomputation, DCE root classification). + +Starts from `EFFECTS_TOTAL` — intrinsic methods are `not_callable()` bodies with +no observable effect — and applies any `efunc` override. Returns `nothing` for +non-intrinsic callees: purity of arbitrary Julia functions isn't ours to claim, +and callers should treat `nothing` as "unknown, be conservative". +""" +function intrinsic_effects(@nospecialize(func)) + func isa Function || return nothing + parentmodule(func) === Intrinsics || return nothing + effects = CC.EFFECTS_TOTAL + override = efunc(func, effects) + override !== nothing && (effects = override) + return effects +end + +""" + inferred_flags(func) -> UInt32 + +IR flags corresponding to `func`'s declared effects, mirroring inference's +`flags_for_effects`. Used by the rewriter to set fresh flags on inserted or +opcode-changed instructions, so downstream gates (CSE, LICM) see the same +information they would have gotten from a fresh inference. + +Returns `IR_FLAG_NULL` for non-intrinsic callees — see `intrinsic_effects`. 
+""" +function inferred_flags(@nospecialize(func)) + effects = intrinsic_effects(func) + effects === nothing && return CC.IR_FLAG_NULL + return CC.flags_for_effects(effects) +end diff --git a/src/compiler/transform/dce.jl b/src/compiler/transform/dce.jl index 8ff9f3dc..a6c38047 100644 --- a/src/compiler/transform/dce.jl +++ b/src/compiler/transform/dce.jl @@ -71,10 +71,10 @@ end Check if a statement is side-effectful and must be kept as a root. -Uses the Julia effects system: each cuTile intrinsic has an `efunc` override -that specifies `effect_free=ALWAYS_FALSE` for side-effectful operations -(stores, atomics, assert). Intrinsics without an efunc override are pure. -Unknown calls are conservatively kept. +Uses the Julia effects system via `intrinsic_effects`: side-effectful cuTile +intrinsics (stores, atomics, assert) carry an `efunc` override that sets +`effect_free=ALWAYS_FALSE`. Intrinsics without an `efunc` override are +treated as pure. Unknown calls are conservatively kept. Mirrors Python cuTile's `_must_keep` (dce.py:205-206) and Julia's compiler `stmt_effect_free` — both classify by per-instruction effect annotations.
@@ -91,16 +91,9 @@ function must_keep(block::Block, @nospecialize(s)) call = resolve_call(block, s) if call !== nothing resolved_func, _ = call - # cuTile intrinsics: use the efunc effects system - if resolved_func isa Function && parentmodule(resolved_func) === Intrinsics - # Query the efunc override for this intrinsic - override = efunc(resolved_func, CC.Effects()) - if override !== nothing - # Has custom effects — keep if not effect-free - return override.effect_free !== CC.ALWAYS_TRUE - end - # No efunc override → pure intrinsic, safe to remove - return false + effects = intrinsic_effects(resolved_func) + if effects !== nothing + return effects.effect_free !== CC.ALWAYS_TRUE end # getfield is pure if s isa Expr diff --git a/src/compiler/transform/pipeline.jl b/src/compiler/transform/pipeline.jl index a7702ba0..0dd8a7a5 100644 --- a/src/compiler/transform/pipeline.jl +++ b/src/compiler/transform/pipeline.jl @@ -95,7 +95,8 @@ function commute_arith_transparent(sci, block, inst, match, driver) # Insert broadcast of the scalar to x's shape and register as constant x_shape = size(xT) bc_type = Tile{eltype(xT), Tuple{x_shape...}} - bc = insert_before!(block, val, Expr(:call, Intrinsics.broadcast, scalar, x_shape), bc_type) + bc = insert_before!(block, val, Expr(:call, Intrinsics.broadcast, scalar, x_shape), bc_type; + flag=inferred_flags(Intrinsics.broadcast)) notify_insert!(driver, block, bc) # Side-inject the freshly synthesized constant into the dataflow result so # downstream pattern matches see it. 
Bypasses tmerge (this is a brand-new @@ -103,14 +104,15 @@ function commute_arith_transparent(sci, block, inst, match, driver) driver.constants[SSAValue(bc)] = convert(eltype(xT), scalar) # Insert op(x, broadcast) with x's type - op = insert_before!(block, val, Expr(:call, root_func, x, SSAValue(bc)), xT) + op = insert_before!(block, val, Expr(:call, root_func, x, SSAValue(bc)), xT; + flag=inferred_flags(root_func)) notify_insert!(driver, block, op) # Replace root with transparent_op(op_result, s). Func changes - # (subi/addi → reshape/broadcast), so clear the flag — the inferred - # effect bits describe the old op, not the transparent one. + # (subi/addi → reshape/broadcast), so recompute the flag from the new + # func's declared effects — the inferred bits describe the OLD op. block[val.id] = (stmt=Expr(:call, transparent_func, SSAValue(op), match.bindings[:s]), - flag=CC.IR_FLAG_NULL) + flag=inferred_flags(transparent_func)) driver.defs[val] = DefEntry(block, val, transparent_func) push!(driver.worklist, val) add_users_to_worklist!(driver, val) diff --git a/src/compiler/transform/rewrite.jl b/src/compiler/transform/rewrite.jl index 1b98c962..fbcd0d56 100644 --- a/src/compiler/transform/rewrite.jl +++ b/src/compiler/transform/rewrite.jl @@ -384,7 +384,8 @@ function resolve_rhs(driver::RewriteDriver, block, ref, op::RCall, bindings, roo typ = CC.widenconst(t) break end - inst = insert_before!(block, ref, Expr(:call, op.func, operands...), typ) + inst = insert_before!(block, ref, Expr(:call, op.func, operands...), typ; + flag=inferred_flags(op.func)) notify_insert!(driver, block, inst) SSAValue(inst) end @@ -401,13 +402,14 @@ function apply_inplace_rewrite!(driver::RewriteDriver, block, val::SSAValue, rul for (op, lhs_op) in zip(rule.rhs.operands, rule.lhs.operands)] new_stmt = Expr(:call, rule.rhs.func, new_operands...) # Same-func rewrites (most common: only operands change) preserve flag - # via the partial-NamedTuple setindex. 
Different-func rewrites clear it to - # IR_FLAG_NULL since the inferred effects describe the OLD op (LLVM - # `copyIRFlags` analogue: don't blanket-inherit when the opcode changes). + # via the partial-NamedTuple setindex. Different-func rewrites recompute + # the flag from the new func's declared effects (`efunc` overrides), + # mirroring inference's `flags_for_effects` — the inferred bits on the + # old call describe the OLD op and don't carry over. if rule.rhs.func === driver.defs[val].func block[val.id] = (stmt=new_stmt,) else - block[val.id] = (stmt=new_stmt, flag=CC.IR_FLAG_NULL) + block[val.id] = (stmt=new_stmt, flag=inferred_flags(rule.rhs.func)) end driver.defs[val] = DefEntry(block, val, rule.rhs.func) push!(driver.worklist, val) @@ -526,13 +528,13 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match end # The substituted func almost always differs from the matched root # (otherwise this would be an inplace rule). Inferred IR_FLAG_* bits - # describe the OLD op; reset to IR_FLAG_NULL on func change so - # downstream effect checks (CSE, LICM) don't act on stale info. + # describe the OLD op; recompute from the new func's `efunc` effects + # so downstream gates (CSE, LICM) see fresh, correct information. new_stmt = Expr(:call, rule.rhs.func, operands...) if rule.rhs.func === driver.defs[val].func block[val.id] = (stmt=new_stmt,) else - block[val.id] = (stmt=new_stmt, flag=CC.IR_FLAG_NULL) + block[val.id] = (stmt=new_stmt, flag=inferred_flags(rule.rhs.func)) end # Update defs, re-add self and users to worklist (statement changed) driver.defs[val] = DefEntry(block, val, rule.rhs.func)