Triton Conv Kernels First Commit to AITER #2886
Draft
saeid-rostami wants to merge 1 commit into ROCm:main
Adds a Triton conv2d library targeted at AMD RDNA GPUs, plus a
correctness + benchmark harness that compares against PyTorch / MIOpen.
### Motivation
On AMD GPUs, PyTorch convolutions go through MIOpen, whose hand-tuned solvers cover some
dtype × layout × architecture combinations well and others poorly — bf16 in
particular falls back to direct/GEMM solvers on RDNA4 that are noticeably
slower at large channel counts. Most modern checkpoints (LLMs, diffusion VAEs)
ship in bf16, so the gap matters.
This op takes the opposite approach: a single set of Triton kernels that runs
fp16 and bf16 through the same code path, supports both NCHW and NHWC, and delivers
reasonable performance across the full matrix without per-architecture hand tuning.
### What's added
- Library (`aiter/ops/triton/conv/`):
  - `conv2d.py` — public API + shape-driven router
  - `_launch.py` — grid setup + `_select_3x3_method` heuristic
  - `_prepack.py` — LRU-cached weight/input repack
  - `_utils.py` — shape math, dtype/activation enums, tolerance model
- Kernels (`aiter/ops/triton/_triton_kernels/conv/`), five families, selected by conditions such as `R == 1, S == 1` or `C ≥ 512, K ≥ 512` with enough output tiles
- Test/bench harness (`op_tests/triton_tests/conv/`):
  - `cli.py` — `--test-mode {edge,random,stability,activations,models,all}`
  - `suite.py` — correctness checking + bench accumulation + result tables
  - `bench.py` — timing + `precompute_miopen_solvers` (subprocess + `MIOPEN_LOG_LEVEL=6` to label each PyTorch baseline row with the MIOpen solver it picked)
  - `test_edge.py` / `test_fuzz.py` / `test_models.py` — shape sources
  - `test_pytest.py` — parametrized over fp16/bf16 × nchw/nhwc
  - `_registry.py` — single source of truth for kernel methods (used by CLI, suite, comparison tables, tolerance dispatch)
- Bench shim (`op_tests/op_benchmarks/triton/bench_conv2d.py`) — convenience entry that injects `--benchmark --test-mode models`
- Docs:
  - `aiter/ops/triton/conv/README.md` — quick start, headline results, constraints, reproduction instructions
  - `aiter/ops/triton/conv/DESIGN.md` — architecture, per-kernel deep dive, full Winograd F(4,3) derivation (G/Bᵀ/Aᵀ matrices, 361× amplification analysis, why Winograd is disabled for `C < 4`), the routing heuristic, memory layouts, numerical model, extension guide (the identity that derivation builds on is recalled just below)
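
For orientation on the DESIGN.md material: the F(4,3) derivation there builds on the standard 2D Winograd output-tile identity from Lavin & Gray, reproduced below; the tile sizes named after it follow the usual F(4,3) convention and are not quoted from this PR.

$$
Y = A^{\top}\bigl[(G\,g\,G^{\top}) \odot (B^{\top} d\,B)\bigr]A
$$

Here `g` is the 3×3 filter tile, `d` the 6×6 input tile, `Y` the 4×4 output tile, and `G`, `Bᵀ`, `Aᵀ` the transform matrices derived in the doc; the 361× amplification analysis presumably bounds how much element-wise error those transforms can magnify, which is what the tolerance model has to absorb.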
### Performance
See `aiter/ops/triton/conv/README.md#headline-results` for the full chart set
(resnet50 / SD3.5 VAE / FLUX.2 VAE × fp16/bf16 × nchw/nhwc × multiple batch sizes, on RDNA4).
### Constraints
- `groups` must equal 1 — depthwise / grouped convolutions are not yet implemented. The test harness skips grouped layers and prints a banner showing how many were skipped, so coverage % stays visible.
- `padding_mode` must be `"zeros"` — `"reflect"`, `"replicate"`, and `"circular"` are out of scope. The pad amount is unrestricted; only the pad value is constrained.
- Dtype must be `fp16` or `bf16`.
### Testing

All run on ROCm 7.2 / PyTorch 2.9.1 / Triton 3.7 (commit 23f4e522d):

- `cli --test-mode all --layout both --dtype fp16`
- `cli --test-mode all --layout both --dtype bf16`
- `pytest test_pytest.py -k test_no_bias` (× fp16/bf16 × nchw/nhwc)
- `bench_conv2d --model-name resnet50 --num-layers 5 --layout both` (fp16, bf16)

Per-method correctness: each kernel family is exercised across 12 edge-case
shapes, 200 random shapes, 4 fused activations (none/relu/relu6/gelu), and the
real per-layer shapes captured by hooking ResNet-50 / SD3.5 VAE / FLUX.2 VAE
forwards.
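
For reference, "hooking forwards" to capture real per-layer shapes can be done with standard PyTorch forward hooks; a minimal sketch (the record format and the ResNet-50 example are illustrative, not the harness's actual code):

```python
import torch
import torchvision

def capture_conv_shapes(model, example_input):
    """Record (input, weight, stride, padding) for every nn.Conv2d hit
    during one forward pass."""
    shapes, handles = [], []

    def hook(module, inputs, output):
        shapes.append({
            "input": tuple(inputs[0].shape),
            "weight": tuple(module.weight.shape),
            "stride": module.stride,
            "padding": module.padding,
        })

    for m in model.modules():
        if isinstance(m, torch.nn.Conv2d):
            handles.append(m.register_forward_hook(hook))
    with torch.no_grad():
        model(example_input)
    for h in handles:
        h.remove()
    return shapes

# Example: collect ResNet-50's per-layer conv shapes at batch size 1.
resnet = torchvision.models.resnet50().eval()
layer_shapes = capture_conv_shapes(resnet, torch.randn(1, 3, 224, 224))
print(len(layer_shapes), layer_shapes[0])
```

Shapes collected this way are the sort of per-layer workloads the `--test-mode models` path replays against both the Triton kernels and the MIOpen baseline.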
### How to use