Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 10 additions & 27 deletions src/compiler/analysis/bounds.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ function transfer(a::BoundsAnalysis, r::DataflowResult, @nospecialize(func),
# `arr.strides[i]` are non-negative by construction (Int32 widths
# carry `[0, ∞)` even before launch-value specialisation).
if func === Base.getfield
return getfield_bounds(r, ops, block)
ref = decode_tilearray_field(block, ops)
ref === nothing && return TOP_RANGE
return tilearray_field_bounds(ref)
end

return TOP_RANGE
Expand Down Expand Up @@ -264,32 +266,13 @@ end
getfield bounds
=============================================================================#

# Recognise `getfield(arg::TileArray, :ptr|:sizes|:strides)` (top — they
# return tuples or pointers, no integer range) and the two-step chain
# `getfield(getfield(arg, :sizes|:strides), i)` (non-negative integer).
function getfield_bounds(r::DataflowResult, ops, block::Block)
length(ops) >= 2 || return TOP_RANGE
obj = ops[1]
field = ops[2] isa QuoteNode ? ops[2].value : ops[2]

obj_T = value_type(block, obj)
obj_T = obj_T === nothing ? Any : CC.widenconst(obj_T)
obj_T <: TileArray && return TOP_RANGE # ptr or whole tuple — no scalar range

obj isa SSAValue || return TOP_RANGE
obj_def = lookup_def_call(block, obj)
obj_def === nothing && return TOP_RANGE
obj_func, obj_ops = obj_def
obj_func === Base.getfield || return TOP_RANGE
length(obj_ops) >= 2 || return TOP_RANGE

inner_field = obj_ops[2] isa QuoteNode ? obj_ops[2].value : obj_ops[2]
(inner_field === :sizes || inner_field === :strides) || return TOP_RANGE

inner_T = value_type(block, obj_ops[1])
inner_T = inner_T === nothing ? Any : CC.widenconst(inner_T)
inner_T <: TileArray || return TOP_RANGE

# Bound for a decoded TileArray field read. Only per-axis element reads
# (`sizes[i]` / `strides[i]`) get a range: they are non-negative, since the
# Int32 fields carry `[0, ∞)` even before launch-value specialisation.
# Pointer reads and whole-tuple reads have no scalar range and map to top.
function tilearray_field_bounds(ref::TileArrayFieldRef)
    is_element_read = ref.index !== nothing
    is_axis_field = ref.field === :sizes || ref.field === :strides
    return (is_element_read && is_axis_field) ? nonneg_range() : TOP_RANGE
end

Expand Down
37 changes: 37 additions & 0 deletions src/compiler/analysis/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,46 @@ function transfer(a::ConstantAnalysis, r::DataflowResult, @nospecialize(func),
length(ops) >= 1
return operand_value(a, r, ops[1])
end
# Type-narrowing intrinsics — preserve scalar values across width changes
# so a `1::Int64` field becomes a `1::Int32` after `Int32(stride)` lowers
# to `trunci`. Otherwise downstream `muli(idx, stride_i32)` would lose
# the constant on the convert. Also covers `exti` (widening) and `bitcast`
# (no-op on signless integers).
if (func === Intrinsics.trunci || func === Intrinsics.exti ||
func === Intrinsics.bitcast) && length(ops) >= 1
v = operand_value(a, r, ops[1])
return v isa Number ? v : CONSTANT_TOP
end
# `getfield(getfield(arg::TileArray, :strides), 1)` for an array with
# `ArraySpec` `contiguous=true` is statically `1` (Julia column-major
# convention: the first dimension is unit-stride). Mirrors the
# `make_tensor_view` codegen, which already inlines the literal `1` for
# the contiguous stride. Without this, gather/scatter offset
# computations leave a runtime `muli(idx, stride1)` that the algebra
# rules can't fold.
if func === Base.getfield
ref = decode_tilearray_field(block, ops)
v = ref === nothing ? nothing : tilearray_field_constant(ref)
v !== nothing && return v
end
CONSTANT_TOP
end

# Scalar constant for a decoded TileArray field read, or `nothing` when the
# spec does not pin the value statically. The only case that qualifies today
# is the stride of the first axis of a contiguous array, which is `1`.
# Sizes are dynamic (the spec only encodes divisibility / bounds for them)
# and the pointer is opaque to constant analysis. `ArraySpec`'s inner
# constructor rejects `contiguous=true` combined with a `stride_div_by[1]`
# outside `{0, 1}`, so that consistency is not re-checked here.
function tilearray_field_constant(ref::TileArrayFieldRef)
    pinned = ref.field === :strides && ref.index == 1 && ref.spec.contiguous
    pinned || return nothing
    # Produce the `1` in the element type of the `strides` field, so the
    # constant has the same integer width as the value it stands for.
    stride_T = eltype(fieldtype(ref.T, :strides))
    return stride_T(1)
end

#=============================================================================
Public query API
=============================================================================#
Expand Down
20 changes: 20 additions & 0 deletions src/compiler/analysis/dataflow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,26 @@ the integer itself for a raw `Int` operand.
operand_value(a::ForwardAnalysis, r::DataflowResult, @nospecialize(op)) =
op isa LatticeAnchor ? r[op] : bottom(a)

"""
lookup_def_call(block::Block, val::SSAValue) -> Union{Tuple, Nothing}

Walk parent blocks searching for the def of an SSAValue, returning the
resolved `(func, operands)` tuple if it's a call. Returns `nothing` for
non-call defs or unresolvable values. Used by transfer rules that need
to peek through a one-step IR chain (e.g. `getfield(getfield(arg, :strides), i)`).
"""
function lookup_def_call(block::Block, val::SSAValue)
    # Start at the requesting block and climb the `parent` chain; the first
    # block whose body has an entry for `val.id` owns the definition.
    scope = block
    while scope isa Block
        def = get(scope.body, val.id, nothing)
        # Resolve relative to the block that actually holds the statement.
        def === nothing || return resolve_call(scope, def.stmt)
        scope = scope.parent
    end
    # Ran off the top of the block chain without finding a definition.
    return nothing
end


#=============================================================================
Driver
Expand Down
92 changes: 21 additions & 71 deletions src/compiler/analysis/divisibility.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,68 +154,32 @@ function transfer(a::DivByAnalysis, r::DataflowResult, @nospecialize(func),

# Field access on a TileArray-typed Argument: derive from the ArraySpec.
if func === Base.getfield
return getfield_divby(r, ops, block)
ref = decode_tilearray_field(block, ops)
ref === nothing && return 1
return tilearray_field_divby(ref)
end

return 1
end

# `getfield(obj, field)` — derive divisibility when `obj` traces back to a
# TileArray-typed `Argument`. Handles the two-step chain
# `getfield(getfield(arg, :sizes), i)` and `getfield(getfield(arg, :strides), i)`.
function getfield_divby(r::DataflowResult, ops, block::Block)
length(ops) >= 2 || return 1
obj = ops[1]
field = ops[2] isa QuoteNode ? ops[2].value : ops[2]

obj_T = value_type(block, obj)
obj_T = obj_T === nothing ? Any : CC.widenconst(obj_T)

# First-level: getfield(arg, :ptr | :sizes | :strides)
if obj_T <: TileArray
spec = array_spec(obj_T)
spec === nothing && return 1
if field === :ptr
return spec.alignment > 0 ? spec.alignment : 1
end
# :sizes / :strides return a tuple — element-level facts come from
# the second-level getfield handler below.
return 1
end

# Second-level: getfield(getfield(arg, :sizes|:strides), i)
# Only meaningful when `obj` is itself a getfield SSA defined in this
# block (or a parent), with `obj.field ∈ {:sizes, :strides}` and its
# source object is a TileArray-typed Argument.
obj isa SSAValue || return 1
obj_def = lookup_def_call(block, obj)
obj_def === nothing && return 1
obj_func, obj_ops = obj_def
obj_func === Base.getfield || return 1
length(obj_ops) >= 2 || return 1

inner_field = obj_ops[2] isa QuoteNode ? obj_ops[2].value : obj_ops[2]
inner_field === :sizes || inner_field === :strides || return 1

inner_obj = obj_ops[1]
inner_T = value_type(block, inner_obj)
inner_T = inner_T === nothing ? Any : CC.widenconst(inner_T)
inner_T <: TileArray || return 1
spec = array_spec(inner_T)
spec === nothing && return 1

idx = field isa Integer ? Int(field) : nothing
idx === nothing && return 1
idx >= 1 || return 1
if inner_field === :sizes
idx <= length(spec.shape_div_by) || return 1
d = spec.shape_div_by[idx]
return d > 0 ? d : 1
else # :strides
idx <= length(spec.stride_div_by) || return 1
d = spec.stride_div_by[idx]
return d > 0 ? d : 1
end
# Divisor for a decoded TileArray field read, taken from the `ArraySpec`:
#   - `:ptr`        → `spec.alignment`
#   - `:sizes[i]`   → `spec.shape_div_by[i]`
#   - `:strides[i]` → `spec.stride_div_by[i]`
# Anything else yields the trivial divisor 1: whole-tuple reads
# (`getfield(arg, :sizes)` / `:strides`, whose element facts live one
# getfield deeper), indices past the end of the table, and non-positive
# table entries.
function tilearray_field_divby(ref::TileArrayFieldRef)
    spec = ref.spec
    idx = ref.index

    if idx === nothing
        # Whole-field read: only the pointer carries a fact (its alignment).
        (ref.field === :ptr && spec.alignment > 0) || return 1
        return spec.alignment
    end

    # Element read: pick the per-axis divisor table for the field.
    divisors = if ref.field === :sizes
        spec.shape_div_by
    elseif ref.field === :strides
        spec.stride_div_by
    else
        return 1
    end

    idx <= length(divisors) || return 1
    d = divisors[idx]
    d > 0 && return d
    return 1
end

# Pointee element type of a 0-D pointer base, accepting both `Ptr{T}` and
Expand All @@ -233,20 +197,6 @@ function ptr_pointee(@nospecialize(T))
return nothing
end

# Find the definition of `val` in `block` or, failing that, in its enclosing
# blocks, and hand the defining statement to `resolve_call`. Returns
# `nothing` when no block on the parent chain has an entry for `val.id`.
function lookup_def_call(block::Block, val::SSAValue)
    entry = get(block.body, val.id, nothing)
    entry === nothing || return resolve_call(block, entry.stmt)
    # Not defined here: retry in the enclosing block, if there is one.
    outer = block.parent
    return outer isa Block ? lookup_def_call(outer, val) : nothing
end

#=============================================================================
Public query API
=============================================================================#
Expand Down
98 changes: 98 additions & 0 deletions src/compiler/analysis/tilearray.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Shared TileArray field-access decoder.
#
# `TileArray` carries its static facts (alignment, contiguity, divisibility)
# in the `ArraySpec` type parameter, but the runtime fields (`ptr`, `sizes`,
# `strides`) are opaque to Julia's compiler. The dataflow analyses therefore
# need to recognise the `getfield` chains that read those fields and project
# the spec's facts onto the appropriate lattice.
#
# Three analyses do this — divisibility, bounds, and constant. Divisibility
# and bounds each used to carry their own copy of the chain walker. This
# file collects it in one place:
# `decode_tilearray_field` walks the `getfield(...)` operands and returns a
# `TileArrayFieldRef` describing which field of which `ArraySpec` is being
# read; each analysis then projects to its own lattice.
#
# A future PartialStruct-style refactor would push this further: structured
# lattice values seeded at `init_arg`, generic getfield/tuple transfer rules.
# That generalises to user-written `Core.tuple → getfield` patterns inside
# kernels, but at the cost of recursive lattice values across the framework.
# Until a workload demands that, the per-analysis projection here is the
# pragmatic shape — TileArray is the only type whose facts live in a type
# parameter rather than in the lattice itself.

"""
TileArrayFieldRef

The result of decoding a `getfield` chain rooted at a TileArray-typed
`Argument`. Carries:

- `T::Type` — the rooting `TileArray` type, for projections that need
the runtime field's element type (`fieldtype(T, :strides)` etc.).
- `spec::ArraySpec` — `array_spec(T)`, cached for ergonomic access.
- `field::Symbol` — `:ptr`, `:sizes`, or `:strides`.
- `index::Union{Int,Nothing}` — `nothing` for whole-tuple / pointer reads
(`getfield(arg, :sizes)`, `getfield(arg, :ptr)`); a 1-based positive
integer for element reads (`getfield(getfield(arg, :sizes), i)`).
"""
struct TileArrayFieldRef
# Built by `decode_tilearray_field`; consumed by the per-analysis projections
# (`tilearray_field_bounds`, `tilearray_field_divby`,
# `tilearray_field_constant`). Field order is the positional constructor
# signature used by the decoder.

# Widened type of the TileArray the chain is rooted at; projections use it
# for `fieldtype(T, :strides)`.
T::Type
# Result of `array_spec(T)`; the decoder never builds a ref when that
# lookup returns `nothing`.
spec::ArraySpec
# Name of the TileArray field being read. For element reads this is the
# inner field (`:sizes` / `:strides`), not the element index.
field::Symbol
# `nothing` for a direct field read; for element reads, the index converted
# with `Int(...)` and checked by the decoder to be `>= 1`. The decoder does
# not check an upper bound, so projections must do that themselves.
index::Union{Int, Nothing}
end

"""
decode_tilearray_field(block, ops) -> Union{TileArrayFieldRef, Nothing}

Decode the operands of a `Base.getfield(...)` call as a TileArray field
reference. Caller is expected to have already established that the current
call is `Base.getfield`. Recognises two shapes:

- `getfield(arg::TileArray, :ptr | :sizes | :strides)` — returns the spec
with `index = nothing`.
- `getfield(getfield(arg::TileArray, :sizes | :strides), i)` — returns the
spec with `field` set to the inner `:sizes` / `:strides` and `index = i`.

Returns `nothing` for any other shape (non-TileArray base, missing spec,
non-integer / non-positive element index, opaque inner producer).
"""
function decode_tilearray_field(block::Block, ops)
    # A getfield call needs at least (object, field).
    length(ops) >= 2 || return nothing
    obj = ops[1]
    # The field selector may arrive wrapped in a QuoteNode; unwrap it.
    field = ops[2] isa QuoteNode ? ops[2].value : ops[2]

    # An operand with no recorded type is treated as `Any`, which fails the
    # TileArray test below and falls through to the second-level shape.
    obj_T = value_type(block, obj)
    obj_T = obj_T === nothing ? Any : CC.widenconst(obj_T)

    # First level: getfield(arg::TileArray, :ptr | :sizes | :strides).
    if obj_T <: TileArray
        spec = array_spec(obj_T)
        spec === nothing && return nothing
        # Only the three documented fields produce a ref. Any other selector
        # (an unknown symbol, or a positional index) is "not a recognised
        # shape", so a `TileArrayFieldRef` never carries a field outside
        # `:ptr` / `:sizes` / `:strides`.
        (field === :ptr || field === :sizes || field === :strides) || return nothing
        return TileArrayFieldRef(obj_T, spec, field, nothing)
    end

    # Second level: getfield(getfield(arg::TileArray, :sizes | :strides), i).
    # `obj` must be an SSA value whose defining call is itself a getfield.
    obj isa SSAValue || return nothing
    inner_def = lookup_def_call(block, obj)
    inner_def === nothing && return nothing
    inner_func, inner_ops = inner_def
    inner_func === Base.getfield || return nothing
    length(inner_ops) >= 2 || return nothing

    # Only the tuple-valued fields have elements to index into.
    inner_field = inner_ops[2] isa QuoteNode ? inner_ops[2].value : inner_ops[2]
    (inner_field === :sizes || inner_field === :strides) || return nothing

    # The inner getfield must itself be rooted at a TileArray with a spec.
    inner_T = value_type(block, inner_ops[1])
    inner_T = inner_T === nothing ? Any : CC.widenconst(inner_T)
    inner_T <: TileArray || return nothing
    spec = array_spec(inner_T)
    spec === nothing && return nothing

    # The element index must be a literal positive integer. The upper bound
    # is not checked here; projections compare against their own tables.
    field isa Integer || return nothing
    idx = Int(field)
    idx >= 1 || return nothing

    return TileArrayFieldRef(inner_T, spec, inner_field, idx)
end
1 change: 1 addition & 0 deletions src/cuTile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ include("compiler/utils.jl")
include("compiler/intrinsics.jl")
include("compiler/analysis/dataflow.jl")
include("compiler/analysis/alias.jl")
include("compiler/analysis/tilearray.jl")
include("compiler/analysis/constant.jl")
include("compiler/analysis/effects.jl")
include("compiler/analysis/divisibility.jl")
Expand Down
18 changes: 17 additions & 1 deletion src/language/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,23 @@ Divisibility values enable optimizations:
- stride_div_by[i] = 4 means stride[i] is divisible by 4 (enables vectorized access)
- shape_div_by[i] = 16 means shape[i] is divisible by 16 (no tile boundary handling needed)
"""
struct ArraySpec{N, Alignment, Contiguous, StrideDivBy, ShapeDivBy} end
struct ArraySpec{N, Alignment, Contiguous, StrideDivBy, ShapeDivBy}
    # The struct has no fields: everything it describes lives in the type
    # parameters. The inner constructor is therefore the one place where a
    # contradictory parameter combination can be rejected, and it runs on
    # every instantiation. The check guards against synthetic specs that
    # pair `contiguous=true` (which means `stride[1] == 1`) with a
    # `stride_div_by[1]` that `1` is not divisible by. Only `d ∈ {0, 1}` is
    # compatible, since `1 % d == 0` holds for no other `d`.
    function ArraySpec{N, Alignment, Contiguous, StrideDivBy, ShapeDivBy}() where
            {N, Alignment, Contiguous, StrideDivBy, ShapeDivBy}
        if Contiguous && N >= 1
            lead_div = StrideDivBy[1]
            if lead_div != 0 && lead_div != 1
                msg = "ArraySpec: contiguous=true requires stride_div_by[1] ∈ {0, 1} " *
                      "(stride[1]=1, and 1 % d == 0 only for d ∈ {0, 1}); got $lead_div"
                throw(ArgumentError(msg))
            end
        end
        return new{N, Alignment, Contiguous, StrideDivBy, ShapeDivBy}()
    end
end

# Constructors
function ArraySpec{N}(alignment::Int, contiguous::Bool,
Expand Down
Loading
Loading