diff --git a/src/compiler/analysis/bounds.jl b/src/compiler/analysis/bounds.jl index 567251fd..f411f3ab 100644 --- a/src/compiler/analysis/bounds.jl +++ b/src/compiler/analysis/bounds.jl @@ -143,7 +143,9 @@ function transfer(a::BoundsAnalysis, r::DataflowResult, @nospecialize(func), # `arr.strides[i]` are non-negative by construction (Int32 widths # carry `[0, ∞)` even before launch-value specialisation). if func === Base.getfield - return getfield_bounds(r, ops, block) + ref = decode_tilearray_field(block, ops) + ref === nothing && return TOP_RANGE + return tilearray_field_bounds(ref) end return TOP_RANGE @@ -264,32 +266,13 @@ end getfield bounds =============================================================================# -# Recognise `getfield(arg::TileArray, :ptr|:sizes|:strides)` (top — they -# return tuples or pointers, no integer range) and the two-step chain -# `getfield(getfield(arg, :sizes|:strides), i)` (non-negative integer). -function getfield_bounds(r::DataflowResult, ops, block::Block) - length(ops) >= 2 || return TOP_RANGE - obj = ops[1] - field = ops[2] isa QuoteNode ? ops[2].value : ops[2] - - obj_T = value_type(block, obj) - obj_T = obj_T === nothing ? Any : CC.widenconst(obj_T) - obj_T <: TileArray && return TOP_RANGE # ptr or whole tuple — no scalar range - - obj isa SSAValue || return TOP_RANGE - obj_def = lookup_def_call(block, obj) - obj_def === nothing && return TOP_RANGE - obj_func, obj_ops = obj_def - obj_func === Base.getfield || return TOP_RANGE - length(obj_ops) >= 2 || return TOP_RANGE - - inner_field = obj_ops[2] isa QuoteNode ? obj_ops[2].value : obj_ops[2] - (inner_field === :sizes || inner_field === :strides) || return TOP_RANGE - - inner_T = value_type(block, obj_ops[1]) - inner_T = inner_T === nothing ? Any : CC.widenconst(inner_T) - inner_T <: TileArray || return TOP_RANGE - +# Project a `TileArrayFieldRef` to its bound: per-axis `sizes[i]` / +# `strides[i]` reads are non-negative (Int32 fields carry `[0, ∞)` even +# before launch-value specialisation); pointer and whole-tuple reads have +# no scalar range. +function tilearray_field_bounds(ref::TileArrayFieldRef) + ref.index === nothing && return TOP_RANGE + (ref.field === :sizes || ref.field === :strides) || return TOP_RANGE return nonneg_range() end diff --git a/src/compiler/analysis/constant.jl b/src/compiler/analysis/constant.jl index 73ef4026..6dad1ad5 100644 --- a/src/compiler/analysis/constant.jl +++ b/src/compiler/analysis/constant.jl @@ -69,9 +69,46 @@ function transfer(a::ConstantAnalysis, r::DataflowResult, @nospecialize(func), length(ops) >= 1 return operand_value(a, r, ops[1]) end + # Type-narrowing intrinsics — preserve scalar values across width changes + # so a `1::Int64` field becomes a `1::Int32` after `Int32(stride)` lowers + # to `trunci`. Otherwise downstream `muli(idx, stride_i32)` would lose + # the constant on the convert. Also covers `exti` (widening) and `bitcast` + # (no-op on signless integers). + if (func === Intrinsics.trunci || func === Intrinsics.exti || + func === Intrinsics.bitcast) && length(ops) >= 1 + v = operand_value(a, r, ops[1]) + return v isa Number ? v : CONSTANT_TOP + end + # `getfield(getfield(arg::TileArray, :strides), 1)` for an array with + # `ArraySpec` `contiguous=true` is statically `1` (Julia column-major + # convention: the first dimension is unit-stride). Mirrors the + # `make_tensor_view` codegen, which already inlines the literal `1` for + # the contiguous stride. Without this, gather/scatter offset + # computations leave a runtime `muli(idx, stride1)` that the algebra + # rules can't fold. + if func === Base.getfield + ref = decode_tilearray_field(block, ops) + v = ref === nothing ? nothing : tilearray_field_constant(ref) + v !== nothing && return v + end CONSTANT_TOP end +# Project a `TileArrayFieldRef` to a scalar constant when the spec pins +# the field's value statically. Currently only the contiguous-axis stride +# (= 1) qualifies; sizes are dynamic (only divisibility / bounds are +# encoded), and the pointer is opaque to constant analysis. The +# `contiguous ⟹ stride_div_by[1] ∈ {0, 1}` consistency is enforced by +# `ArraySpec`'s inner constructor, so no defensive check is needed here. +function tilearray_field_constant(ref::TileArrayFieldRef) + ref.field === :strides || return nothing + ref.index == 1 || return nothing + ref.spec.contiguous || return nothing + # Match the strides field's element type: a contiguous-axis stride + # equals `1` in whatever integer width `TileArray.strides` carries. + return eltype(fieldtype(ref.T, :strides))(1) +end + #============================================================================= Public query API =============================================================================# diff --git a/src/compiler/analysis/dataflow.jl b/src/compiler/analysis/dataflow.jl index 4fd023de..ef26dacf 100644 --- a/src/compiler/analysis/dataflow.jl +++ b/src/compiler/analysis/dataflow.jl @@ -151,6 +151,26 @@ the integer itself for a raw `Int` operand. operand_value(a::ForwardAnalysis, r::DataflowResult, @nospecialize(op)) = op isa LatticeAnchor ? r[op] : bottom(a) +""" + lookup_def_call(block::Block, val::SSAValue) -> Union{Tuple, Nothing} + +Walk parent blocks searching for the def of an SSAValue, returning the +resolved `(func, operands)` tuple if it's a call. Returns `nothing` for +non-call defs or unresolvable values. Used by transfer rules that need +to peek through a one-step IR chain (e.g. `getfield(getfield(arg, :strides), i)`). +""" +function lookup_def_call(block::Block, val::SSAValue) + p = block + while p isa Block + entry = get(p.body, val.id, nothing) + if entry !== nothing + return resolve_call(p, entry.stmt) + end + p = p.parent + end + return nothing +end + #============================================================================= Driver diff --git a/src/compiler/analysis/divisibility.jl b/src/compiler/analysis/divisibility.jl index 3bb7e192..1dd60dc7 100644 --- a/src/compiler/analysis/divisibility.jl +++ b/src/compiler/analysis/divisibility.jl @@ -154,68 +154,32 @@ function transfer(a::DivByAnalysis, r::DataflowResult, @nospecialize(func), # Field access on a TileArray-typed Argument: derive from the ArraySpec. if func === Base.getfield - return getfield_divby(r, ops, block) + ref = decode_tilearray_field(block, ops) + ref === nothing && return 1 + return tilearray_field_divby(ref) end return 1 end -# `getfield(obj, field)` — derive divisibility when `obj` traces back to a -# TileArray-typed `Argument`. Handles the two-step chain -# `getfield(getfield(arg, :sizes), i)` and `getfield(getfield(arg, :strides), i)`. -function getfield_divby(r::DataflowResult, ops, block::Block) - length(ops) >= 2 || return 1 - obj = ops[1] - field = ops[2] isa QuoteNode ? ops[2].value : ops[2] - - obj_T = value_type(block, obj) - obj_T = obj_T === nothing ? Any : CC.widenconst(obj_T) - - # First-level: getfield(arg, :ptr | :sizes | :strides) - if obj_T <: TileArray - spec = array_spec(obj_T) - spec === nothing && return 1 - if field === :ptr - return spec.alignment > 0 ? spec.alignment : 1 - end - # :sizes / :strides return a tuple — element-level facts come from - # the second-level getfield handler below. - return 1 - end - - # Second-level: getfield(getfield(arg, :sizes|:strides), i) - # Only meaningful when `obj` is itself a getfield SSA defined in this - # block (or a parent), with `obj.field ∈ {:sizes, :strides}` and its - # source object is a TileArray-typed Argument. - obj isa SSAValue || return 1 - obj_def = lookup_def_call(block, obj) - obj_def === nothing && return 1 - obj_func, obj_ops = obj_def - obj_func === Base.getfield || return 1 - length(obj_ops) >= 2 || return 1 - - inner_field = obj_ops[2] isa QuoteNode ? obj_ops[2].value : obj_ops[2] - inner_field === :sizes || inner_field === :strides || return 1 - - inner_obj = obj_ops[1] - inner_T = value_type(block, inner_obj) - inner_T = inner_T === nothing ? Any : CC.widenconst(inner_T) - inner_T <: TileArray || return 1 - spec = array_spec(inner_T) - spec === nothing && return 1 - - idx = field isa Integer ? Int(field) : nothing - idx === nothing && return 1 - idx >= 1 || return 1 - if inner_field === :sizes - idx <= length(spec.shape_div_by) || return 1 - d = spec.shape_div_by[idx] - return d > 0 ? d : 1 - else # :strides - idx <= length(spec.stride_div_by) || return 1 - d = spec.stride_div_by[idx] - return d > 0 ? d : 1 - end +# Project a `TileArrayFieldRef` to its divisor: +# - `:ptr` → `spec.alignment` (pointer alignment in bytes) +# - `:sizes[i]` → `spec.shape_div_by[i]` +# - `:strides[i]` → `spec.stride_div_by[i]` +# Whole-tuple reads (`getfield(arg, :sizes)` / `:strides`) return 1: the +# element-level facts live one getfield deeper. +function tilearray_field_divby(ref::TileArrayFieldRef) + spec = ref.spec + if ref.index === nothing + ref.field === :ptr || return 1 + return spec.alignment > 0 ? spec.alignment : 1 + end + table = ref.field === :sizes ? spec.shape_div_by : + ref.field === :strides ? spec.stride_div_by : nothing + table === nothing && return 1 + ref.index <= length(table) || return 1 + d = table[ref.index] + return d > 0 ? d : 1 end # Pointee element type of a 0-D pointer base, accepting both `Ptr{T}` and @@ -233,20 +197,6 @@ function ptr_pointee(@nospecialize(T)) return nothing end -# Walk parent blocks searching for the def of an SSAValue, returning the -# resolved (func, operands) tuple if it's a call. Returns nothing otherwise. -function lookup_def_call(block::Block, val::SSAValue) - p = block - while p isa Block - entry = get(p.body, val.id, nothing) - if entry !== nothing - return resolve_call(p, entry.stmt) - end - p = p.parent - end - return nothing -end - #============================================================================= Public query API =============================================================================# diff --git a/src/compiler/analysis/tilearray.jl b/src/compiler/analysis/tilearray.jl new file mode 100644 index 00000000..07e08f55 --- /dev/null +++ b/src/compiler/analysis/tilearray.jl @@ -0,0 +1,98 @@ +# Shared TileArray field-access decoder. +# +# `TileArray` carries its static facts (alignment, contiguity, divisibility) +# in the `ArraySpec` type parameter, but the runtime fields (`ptr`, `sizes`, +# `strides`) are opaque to Julia's compiler. The dataflow analyses therefore +# need to recognise the `getfield` chains that read those fields and project +# the spec's facts onto the appropriate lattice. +# +# Three analyses do this — divisibility, bounds, constant — each previously +# duplicated the same chain walker. This file collects it in one place: +# `decode_tilearray_field` walks the `getfield(...)` operands and returns a +# `TileArrayFieldRef` describing which field of which `ArraySpec` is being +# read; each analysis then projects to its own lattice. +# +# A future PartialStruct-style refactor would push this further: structured +# lattice values seeded at `init_arg`, generic getfield/tuple transfer rules. +# That generalises to user-written `Core.tuple → getfield` patterns inside +# kernels, but at the cost of recursive lattice values across the framework. +# Until a workload demands that, the per-analysis projection here is the +# pragmatic shape — TileArray is the only type whose facts live in a type +# parameter rather than in the lattice itself. + +""" + TileArrayFieldRef + +The result of decoding a `getfield` chain rooted at a TileArray-typed +`Argument`. Carries: + +- `T::Type` — the rooting `TileArray` type, for projections that need + the runtime field's element type (`fieldtype(T, :strides)` etc.). +- `spec::ArraySpec` — `array_spec(T)`, cached for ergonomic access. +- `field::Symbol` — `:ptr`, `:sizes`, or `:strides`. +- `index::Union{Int,Nothing}` — `nothing` for whole-tuple / pointer reads + (`getfield(arg, :sizes)`, `getfield(arg, :ptr)`); a 1-based positive + integer for element reads (`getfield(getfield(arg, :sizes), i)`). +""" +struct TileArrayFieldRef + T::Type + spec::ArraySpec + field::Symbol + index::Union{Int, Nothing} +end + +""" + decode_tilearray_field(block, ops) -> Union{TileArrayFieldRef, Nothing} + +Decode the operands of a `Base.getfield(...)` call as a TileArray field +reference. Caller is expected to have already established that the current +call is `Base.getfield`. Recognises two shapes: + +- `getfield(arg::TileArray, :ptr | :sizes | :strides)` — returns the spec + with `index = nothing`. +- `getfield(getfield(arg::TileArray, :sizes | :strides), i)` — returns the + spec with `field` set to the inner `:sizes` / `:strides` and `index = i`. + +Returns `nothing` for any other shape (non-TileArray base, missing spec, +non-integer / non-positive element index, opaque inner producer). +""" +function decode_tilearray_field(block::Block, ops) + length(ops) >= 2 || return nothing + field = ops[2] isa QuoteNode ? ops[2].value : ops[2] + obj = ops[1] + + obj_T = value_type(block, obj) + obj_T = obj_T === nothing ? Any : CC.widenconst(obj_T) + + # First-level: getfield(arg::TileArray, :ptr | :sizes | :strides) + if obj_T <: TileArray + spec = array_spec(obj_T) + spec === nothing && return nothing + field isa Symbol || return nothing + return TileArrayFieldRef(obj_T, spec, field, nothing) + end + + # Second-level: getfield(getfield(arg::TileArray, :sizes|:strides), i). + # Walk back through `obj`'s defining call to find the inner getfield. + obj isa SSAValue || return nothing + inner_def = lookup_def_call(block, obj) + inner_def === nothing && return nothing + inner_func, inner_ops = inner_def + inner_func === Base.getfield || return nothing + length(inner_ops) >= 2 || return nothing + + inner_field = inner_ops[2] isa QuoteNode ? inner_ops[2].value : inner_ops[2] + (inner_field === :sizes || inner_field === :strides) || return nothing + + inner_T = value_type(block, inner_ops[1]) + inner_T = inner_T === nothing ? Any : CC.widenconst(inner_T) + inner_T <: TileArray || return nothing + spec = array_spec(inner_T) + spec === nothing && return nothing + + field isa Integer || return nothing + idx = Int(field) + idx >= 1 || return nothing + + return TileArrayFieldRef(inner_T, spec, inner_field, idx) +end diff --git a/src/cuTile.jl b/src/cuTile.jl index da79bef2..2eb8165c 100644 --- a/src/cuTile.jl +++ b/src/cuTile.jl @@ -43,6 +43,7 @@ include("compiler/utils.jl") include("compiler/intrinsics.jl") include("compiler/analysis/dataflow.jl") include("compiler/analysis/alias.jl") +include("compiler/analysis/tilearray.jl") include("compiler/analysis/constant.jl") include("compiler/analysis/effects.jl") include("compiler/analysis/divisibility.jl") diff --git a/src/language/types.jl b/src/language/types.jl index e2dcc3ed..2af0927e 100644 --- a/src/language/types.jl +++ b/src/language/types.jl @@ -23,7 +23,23 @@ Divisibility values enable optimizations: - stride_div_by[i] = 4 means stride[i] is divisible by 4 (enables vectorized access) - shape_div_by[i] = 16 means shape[i] is divisible by 16 (no tile boundary handling needed) """ -struct ArraySpec{N, Alignment, Contiguous, StrideDivBy, ShapeDivBy} end +struct ArraySpec{N, Alignment, Contiguous, StrideDivBy, ShapeDivBy} + # Validate invariants once per concrete spec type (this struct is a + # singleton, so the inner constructor runs on every instantiation but + # the result is then cached as a type parameter). Catches synthetic + # specs that combine `contiguous=true` with a `stride_div_by[1]` that + # contradicts `stride[1] == 1` — `1 % d == 0` only for `d ∈ {0, 1}`. + function ArraySpec{N, Alignment, Contiguous, StrideDivBy, ShapeDivBy}() where + {N, Alignment, Contiguous, StrideDivBy, ShapeDivBy} + if Contiguous && N >= 1 + sdb1 = StrideDivBy[1] + (sdb1 == 0 || sdb1 == 1) || throw(ArgumentError( + "ArraySpec: contiguous=true requires stride_div_by[1] ∈ {0, 1} " * + "(stride[1]=1, and 1 % d == 0 only for d ∈ {0, 1}); got $sdb1")) + end + new{N, Alignment, Contiguous, StrideDivBy, ShapeDivBy}() + end +end # Constructors function ArraySpec{N}(alignment::Int, contiguous::Bool, diff --git a/test/codegen/assume.jl b/test/codegen/assume.jl index d8c39f04..a157e8f6 100644 --- a/test/codegen/assume.jl +++ b/test/codegen/assume.jl @@ -32,7 +32,9 @@ end @testset "assume — per-axis shape divisibility" begin - spec2d = ct.ArraySpec{2}(16, true, (4, 0), (16, 8)) + # stride_div_by left at 0 on the contiguous axis (consistent with + # stride[1]=1); the test only asserts on shape facts. + spec2d = ct.ArraySpec{2}(16, true, (0, 0), (16, 8)) @test @filecheck begin @check_label "entry" code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}}) do a @@ -50,7 +52,7 @@ end @testset "assume — strides skip the contiguous axis" begin # contiguous=true: stride[1]=1 statically; that operand never enters # the bytecode signature, so no assume is emitted for it. - spec2d = ct.ArraySpec{2}(16, true, (4, 0), (16, 8)) + spec2d = ct.ArraySpec{2}(16, true, (0, 0), (16, 8)) @test @filecheck begin @check_label "entry" code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}}) do a @@ -67,7 +69,9 @@ end # slice offset is `start * stride` divisible by 4 elements = 16 bytes). # The slice's TileArray type has alignment=0 (conservative), but the # divisibility dataflow recovers gcd(128, 16) = 16 on the offset ptr. - spec1d = ct.ArraySpec{1}(128, true, (4,), (16,)) + # Uses contiguous=false: `stride_div_by[1]>1` is only physically + # consistent with a non-unit stride. + spec1d = ct.ArraySpec{1}(128, false, (4,), (16,)) @test @filecheck begin @check_label "entry" code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}, Int32, Int32}) do a, i, j @@ -86,8 +90,9 @@ end # `a[1:64]` has start_0 == 0, so offset == 0 bytes; the dataflow # treats the literal `0` as ∞-divisible, so `gcd(spec.alignment, 0)` # == spec.alignment. The slice ptr inherits the source's full - # 128-byte alignment. - spec1d = ct.ArraySpec{1}(128, true, (4,), (16,)) + # 128-byte alignment. Uses contiguous=false: `stride_div_by[1]>1` is + # only physically consistent with a non-unit stride. + spec1d = ct.ArraySpec{1}(128, false, (4,), (16,)) @test @filecheck begin @check_label "entry" code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}}) do a diff --git a/test/codegen/cse.jl b/test/codegen/cse.jl index 9fdea58f..8860c833 100644 --- a/test/codegen/cse.jl +++ b/test/codegen/cse.jl @@ -8,7 +8,7 @@ # Three loads/stores on the same TileArray collapse to one # `make_tensor_view` and one `make_partition_view`. Without CSE, # each `ct.load`/`ct.store` would emit its own getfield+view chain. - spec1d = ct.ArraySpec{1}(16, true, (4,), (16,)) + spec1d = ct.ArraySpec{1}(16, true, (0,), (16,)) @test @filecheck begin @check_label "entry" code_tiled(Tuple{ct.TileArray{Float32,1,spec1d}}) do a diff --git a/test/codegen/integration.jl b/test/codegen/integration.jl index 94151030..84f587b2 100644 --- a/test/codegen/integration.jl +++ b/test/codegen/integration.jl @@ -339,8 +339,10 @@ end TILE_M = 32 TILE_N = 1024 - # Use ArraySpec with shape_div_by to match real CuArray behavior - spec2d = ct.ArraySpec{2}(128, true, (4, 0), (32, 32)) + # Use ArraySpec with shape_div_by to match real CuArray behavior. + # The non-trivial stride_div_by lives on axis 2 (consistent with + # contiguous=true on axis 1). + spec2d = ct.ArraySpec{2}(128, true, (0, 4), (32, 32)) spec1d = ct.ArraySpec{1}(128, true, (0,), (32,)) @test @filecheck begin diff --git a/test/codegen/no_wrap.jl b/test/codegen/no_wrap.jl index fcbbf14f..3361c940 100644 --- a/test/codegen/no_wrap.jl +++ b/test/codegen/no_wrap.jl @@ -9,8 +9,10 @@ # `0 × anything = 0` is provably non-wrapping regardless of the # other operand's range. The slice path emits `muli(start_0, # stride)` where `start_0 == 0` for `a[1:N]`; the resulting muli - # picks up the `no_wrap` flag. - spec = ct.ArraySpec{1}(16, true) + # picks up the `no_wrap` flag. Use a non-contiguous spec so the + # constant analysis doesn't fold `stride[1]` to `1` (which would + # eliminate the muli before the no_wrap pass sees it). + spec = ct.ArraySpec{1}(16, false) @test @filecheck begin @check_label "entry" code_tiled(Tuple{ct.TileArray{Float32,1,spec}}) do a diff --git a/test/codegen/slice.jl b/test/codegen/slice.jl index ce743dba..2cc5cac5 100644 --- a/test/codegen/slice.jl +++ b/test/codegen/slice.jl @@ -6,7 +6,7 @@ # - an `offset` for base + offset # - a follow-up `make_tensor_view` on the derived pointer -spec1d = ct.ArraySpec{1}(16, true) +spec1d = ct.ArraySpec{1}(16, false) # non-contiguous so muli isn't folded spec2d = ct.ArraySpec{2}(16, true) @testset "slice — 1D single axis" begin @@ -48,10 +48,13 @@ spec2d = ct.ArraySpec{2}(16, true) end @testset "slice — 2D single axis" begin - # Slice along axis 1; axis 2 is full (`:`). + # Slice along axis 1 with a non-contiguous spec so the contiguous-stride + # constant fold doesn't remove the `muli`. (For a contiguous spec, + # stride[1] is statically `1` and `start * 1` collapses to `start`.) + spec2d_nc = ct.ArraySpec{2}(16, false) @test @filecheck begin @check_label "entry" - code_tiled(Tuple{ct.TileArray{Float32,2,spec2d}, Int32, Int32}) do a, i, j + code_tiled(Tuple{ct.TileArray{Float32,2,spec2d_nc}, Int32, Int32}) do a, i, j sub = @view a[i:j, :] t = ct.load(sub, (1, 1), (4, 4)) ct.store(sub, (1, 1), t)