Triton Conv Kernels First Commit to AITER#2886

Draft
saeid-rostami wants to merge 1 commit into ROCm:main from saeid-rostami:conv2d-initial

Conversation

@saeid-rostami
Contributor

Adds a Triton conv2d library targeted at AMD RDNA GPUs, plus a
correctness + benchmark harness that compares against PyTorch / MIOpen.

Motivation

PyTorch on AMD goes through MIOpen, whose hand-tuned solvers cover some
dtype × layout × architecture combinations well and others poorly — bf16 in
particular falls back to direct/GEMM solvers on RDNA4 that are noticeably
slower at large channel counts. Most modern checkpoints (LLMs, diffusion VAEs)
ship in bf16, so the gap matters.

This op takes the opposite approach: a single set of Triton kernels that runs
fp16 and bf16 through the same code path, supports NCHW and NHWC, and gets reasonable performance across
the full matrix without per-architecture hand tuning.

What's added

Library (aiter/ops/triton/conv/):

  • conv2d.py — public API + shape-driven router
  • _launch.py — grid setup + _select_3x3_method heuristic
  • _prepack.py — LRU-cached weight/input repack
  • _utils.py — shape math, dtype/activation enums, tolerance model

Kernels (aiter/ops/triton/_triton_kernels/conv/), five families:

| Family | When it runs |
| --- | --- |
| 1×1 GEMM | R == 1, S == 1 |
| 3×3 cblocked (NCHW) | 3×3, channel-blocked input for coalesced loads |
| 3×3 NHWC | 3×3 with channels-last input — no input repack |
| Winograd F(4×4, 3×3) | 3×3, stride=1, dilation=1, C ≥ 512, K ≥ 512, enough output tiles |
| General | anything else (5×5, 7×7, dilated, strided) |
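The table above amounts to a shape-driven dispatch. A minimal sketch of such a router (the names and thresholds here are illustrative, not the actual `_select_3x3_method` implementation):

```python
# Hypothetical sketch of the shape-driven routing; names and
# thresholds are illustrative, not the library's actual code.
def select_kernel_family(R, S, stride, dilation, C, K, layout="nchw"):
    """Pick a kernel family from filter shape and channel counts."""
    if R == 1 and S == 1:
        return "1x1_gemm"
    if R == 3 and S == 3:
        # Winograd F(4x4, 3x3) only pays off for stride-1, dense
        # filters with enough channels to amortize the transforms.
        if stride == (1, 1) and dilation == (1, 1) and C >= 512 and K >= 512:
            return "winograd_f4x3"
        # NHWC input skips the channel-blocked repack entirely.
        return "3x3_nhwc" if layout == "nhwc" else "3x3_cblocked"
    return "general"  # 5x5, 7x7, dilated, strided, ...
```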

Test/bench harness (op_tests/triton_tests/conv/):

  • cli.py — --test-mode {edge,random,stability,activations,models,all}

  • suite.py — correctness checking + bench accumulation + result tables

  • bench.py — timing + precompute_miopen_solvers (subprocess + MIOPEN_LOG_LEVEL=6
    to label each PyTorch baseline row with the MIOpen solver it picked)

  • test_edge.py / test_fuzz.py / test_models.py — shape sources

  • test_pytest.py — parametrized over fp16/bf16 × nchw/nhwc

  • _registry.py — single source of truth for kernel methods (used by CLI,
    suite, comparison tables, tolerance dispatch)

  • Bench shim (op_tests/op_benchmarks/triton/bench_conv2d.py) — convenience
    entry that injects --benchmark --test-mode models.
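The solver-labeling idea in bench.py can be sketched like this: run the PyTorch baseline in a subprocess with MIOPEN_LOG_LEVEL=6 and grep the solver name out of the debug log. The regex and the exact log format below are assumptions for illustration, not the harness's code:

```python
import os
import re
import subprocess
import sys

# Assumed log shape: MIOpen's debug output mentions the chosen
# solver as "Solution: <id>/<SolverName>"; the real format may vary.
SOLVER_RE = re.compile(r"Solution: \d+/(\w+)")

def parse_solver(log_text):
    """Return the last solver name mentioned in a MIOpen debug log."""
    names = SOLVER_RE.findall(log_text)
    return names[-1] if names else None

def run_and_label(python_snippet):
    """Run a conv snippet in a child process and label it with the
    MIOpen solver it picked (None if none was logged)."""
    env = dict(os.environ, MIOPEN_LOG_LEVEL="6")
    proc = subprocess.run([sys.executable, "-c", python_snippet],
                          env=env, capture_output=True, text=True)
    return parse_solver(proc.stderr)
```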

Docs:

  • aiter/ops/triton/conv/README.md — quick start, headline results, constraints,
    reproducing instructions
  • aiter/ops/triton/conv/DESIGN.md — architecture, per-kernel deep-dive, full
    Winograd F(4,3) derivation (G/Bᵀ/Aᵀ matrices, 361× amplification analysis,
    why Winograd is disabled for C < 4), the routing heuristic, memory layouts,
    numerical model, extension guide
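The "~4× fewer MACs" figure quoted later follows from simple tile arithmetic, which can be checked directly (standard Winograd counting, independent of the DESIGN.md derivation):

```python
# Worked arithmetic behind Winograd F(4x4, 3x3)'s MAC reduction:
# each 6x6 input tile is transformed once, 36 element-wise multiplies
# replace the direct convolution's per-tile MACs, and the result is
# transformed back to a 4x4 output tile.
tile_out, tap = 4, 3                       # F(4, 3) in each dimension
direct_macs = (tile_out * tap) ** 2        # (4*3)^2 = 144 per 2D tile
winograd_muls = (tile_out + tap - 1) ** 2  # 6x6 tile -> 36 multiplies
reduction = direct_macs / winograd_muls    # 4.0x
```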

Performance

See aiter/ops/triton/conv/README.md#headline-results
for the full chart set (resnet50 / SD3.5 VAE / FLUX.2 VAE × fp16/bf16 ×
nchw/nhwc × multiple batch sizes, on RDNA4).

Note on TFLOPS: numbers are direct-convolution-equivalent throughput, applied
identically to both backends. Winograd kernels execute fewer literal hardware
MACs than this denominator counts (~4× fewer for F(4,3)); since both backends
are measured against the same denominator, the comparison is apples-to-apples.
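Concretely, the direct-convolution-equivalent denominator is the standard FLOP count for a dense conv (the helper name below is illustrative):

```python
# Standard direct-conv FLOP count used as the TFLOPS denominator for
# both backends; one MAC counts as 2 FLOPs. Helper name is illustrative.
def conv2d_flops(N, C, K, P, Q, R, S):
    """FLOPs of a direct conv producing an N x K x P x Q output from
    C input channels with an R x S filter."""
    return 2 * N * K * P * Q * C * R * S

# Example: a 3x3 layer, 64 -> 64 channels, 56x56 output, batch 1.
flops = conv2d_flops(N=1, C=64, K=64, P=56, Q=56, R=3, S=3)
```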

Constraints

  • groups must equal 1 — depthwise / grouped not yet implemented.
    Test harness skips grouped layers and prints a banner showing how many were
    skipped (so coverage % is visible).
  • padding_mode must be "zeros". Pad amount is unrestricted; only the pad
    value is constrained: "reflect", "replicate", and "circular" modes are
    out of scope.
  • Inputs must be fp16 or bf16.
  • Forward only (no backward / training).
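The constraints above amount to a small up-front validation. A hypothetical sketch (the real library's checks may differ in wording and placement):

```python
# Hypothetical validation mirroring the stated constraints; not the
# library's actual code.
SUPPORTED_DTYPES = {"float16", "bfloat16"}

def check_supported(groups, padding_mode, dtype_name):
    """Raise early for inputs the kernels do not handle."""
    if groups != 1:
        raise NotImplementedError("grouped/depthwise conv (groups > 1)")
    if padding_mode != "zeros":
        raise NotImplementedError(f"padding_mode={padding_mode!r}")
    if dtype_name not in SUPPORTED_DTYPES:
        raise TypeError(f"dtype {dtype_name} (need fp16 or bf16)")
```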

Testing

All run on ROCm 7.2 / PyTorch 2.9.1 / Triton 3.7 (commit 23f4e522d).

| Run | Result |
| --- | --- |
| cli --test-mode all --layout both --dtype fp16 | 484 / 484 passed |
| cli --test-mode all --layout both --dtype bf16 | 484 / 484 passed |
| pytest test_pytest.py -k test_no_bias (× fp16/bf16 × nchw/nhwc) | 4 / 4 passed |
| bench_conv2d --model-name resnet50 --num-layers 5 --layout both (fp16, bf16) | exit 0 |

Per-method correctness: each kernel family is exercised across 12 edge-case
shapes, 200 random shapes, 4 fused activations (none/relu/relu6/gelu), and the
real per-layer shapes captured by hooking ResNet-50 / SD3.5 VAE / FLUX.2 VAE
forwards.
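Correctness checks like these need dtype-aware tolerances (the "tolerance model" in _utils.py). One plausible shape for such a model, purely as a sketch and not the actual implementation: scale the absolute tolerance with the square root of the reduction depth C·R·S, since low-precision rounding errors accumulate roughly like a random walk over the dot-product length.

```python
import math

# Illustrative tolerance model; the real one in _utils.py may differ.
EPS = {"float16": 2**-10, "bfloat16": 2**-7}  # unit roundoff per dtype

def conv_atol(dtype_name, C, R, S, base=4.0):
    """Absolute tolerance that grows with the accumulation depth."""
    return base * EPS[dtype_name] * math.sqrt(C * R * S)
```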

How to use

import torch
from aiter.ops.triton.conv.conv2d import conv2d

y = conv2d(
    x, w, bias=None,
    stride=(1, 1), padding=(1, 1), dilation=(1, 1),
    activation="relu",          # "none" | "relu" | "relu6" | "gelu"
    out_dtype=torch.float16,
    layout="nchw",              # "nchw" or "nhwc"
)

Drop-in replacement: walk an nn.Module, swap each nn.Conv2d.forward for one
that calls conv2d(...). Numerical agreement on FLUX.2 VAE end-to-end:
max pixel diff 6/255, mean 0.17/255.
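The module walk described above can be sketched as follows; for self-containment, `torch.nn.functional.conv2d` stands in here for the library's `conv2d`, and the helper name is illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Sketch of the drop-in swap: walk the module tree and replace each
# nn.Conv2d's forward with one that routes through a conv2d(...)
# callable. F.conv2d stands in for aiter's conv2d in this sketch.
def swap_conv2d_forwards(model, conv2d_impl):
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            def forward(x, m=m):  # bind each layer's module
                return conv2d_impl(x, m.weight, bias=m.bias,
                                   stride=m.stride, padding=m.padding,
                                   dilation=m.dilation)
            m.forward = forward
    return model
```

Binding `m=m` in the closure's default argument is what keeps each replaced forward pointing at its own layer rather than the last one visited.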

Reproducing the benchmarks

From repo root:

# Correctness (full matrix)
python -m op_tests.triton_tests.conv.cli --test-mode all --layout both --dtype fp16
python -m op_tests.triton_tests.conv.cli --test-mode all --layout both --dtype bf16

# Per-layer TFLOPS table vs PyTorch / MIOpen
python -m op_tests.op_benchmarks.triton.bench_conv2d --model-name resnet50 --num-layers 53
python -m op_tests.op_benchmarks.triton.bench_conv2d --model-name sd35_vae \
    --model-path <path to model>/stable-diffusion-3.5-medium

# 3×3 method comparison (cblocked vs Winograd vs nhwc, side-by-side)
python -m op_tests.op_benchmarks.triton.bench_conv2d --method all --model-name resnet50

# Pytest matrix
pytest op_tests/triton_tests/conv/test_pytest.py

Files

aiter/ops/triton/conv/                        # wrapper + docs
  conv2d.py, _launch.py, _prepack.py, _utils.py, __init__.py
  README.md, DESIGN.md, images/

aiter/ops/triton/_triton_kernels/conv/        # @triton.jit kernels
  conv_1x1.py, conv_3x3.py, conv_3x3_winograd_f4x3.py,
  conv_general.py, helpers.py, __init__.py

op_tests/triton_tests/conv/                   # correctness + bench harness
  cli.py, suite.py, bench.py, _registry.py,
  test_edge.py, test_fuzz.py, test_models.py, test_pytest.py, __init__.py

op_tests/op_benchmarks/triton/bench_conv2d.py # convenience bench shim

@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or `gh pr edit 2886 --add-label <label>`
