Add FlyDSL GemmTuner integration #2816

Closed
apicciau wants to merge 2 commits into main from apicciau/flydsl_a16w16_tuner


@apicciau apicciau commented Apr 20, 2026

No description provided.

…x1250

Introduces the FlyDSL A16W16 GEMM kernel for RDNA4 (gfx1250) and integrates
it as a first-class tunable backend in GemmTuner, alongside the existing
splitk_hgemm and ASM paths.

New files:
- aiter/ops/flydsl/kernels/gemm_a16w16_gfx1250.py: WMMA 16x16x32 kernel
  using RDNA4 wave32; handles K-padding and N-stride internally; supports
  fp16/bf16 input, configurable tiling (tile_m/n/k), warp layout (m/n_warp),
  double-buffering (num_buffers), waves_per_eu, and L2 prefetch distance

Changes to existing files:
- aiter/ops/flydsl/gemm_kernels.py: add get_flydsl_a16w16_gfx1250_kernels()
  catalog and get_flydsl_a16w16_gfx1250_kernel_params() lookup; kernel name
  encodes all config parameters for reversible CSV serialisation
- gradlib/gradlib/GemmTuner.py: import the new kernel; add
  run_flydsl_gemm_a16w16() run function; add flydsl_a16w16_gemm_all_sols()
  enumerator; route gfx1250 through the a16w16 path in run_asm_triton_sols()
  while other architectures continue using the existing splitk_hgemm path;
  also restores the ASM SplitK semaphore guard (gdx*gdy <= 1024) that was
  missing on main (also tracked in PR #2721)
- aiter/tuned_gemm.py: add flydsl_a16w16_gemm() dispatch function; update
  the flydsl config lookup to resolve a16w16 kernel names, falling back to
  splitk_hgemm; select the correct call site based on the resolved config
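The "reversible CSV serialisation" point above can be illustrated with a small round trip. This is only a sketch: the prefix, the field order, and the `encode_name`/`decode_name` helpers are hypothetical and are not taken from `gemm_kernels.py`; only the parameter names come from the commit message.

```python
import re
from dataclasses import dataclass


# Hypothetical config record. Field names follow the parameters listed in the
# commit message; the real catalog's layout and naming scheme may differ.
@dataclass(frozen=True)
class A16W16Config:
    dtype: str  # "fp16" or "bf16"
    tile_m: int
    tile_n: int
    tile_k: int
    m_warp: int
    n_warp: int
    num_buffers: int
    waves_per_eu: int
    l2_prefetch: int


_PREFIX = "flydsl_a16w16_gfx1250"  # assumed prefix, for illustration only
_NAME_RE = re.compile(
    rf"^{_PREFIX}_(?P<dtype>fp16|bf16)"
    r"_t(?P<tile_m>\d+)x(?P<tile_n>\d+)x(?P<tile_k>\d+)"
    r"_w(?P<m_warp>\d+)x(?P<n_warp>\d+)"
    r"_b(?P<num_buffers>\d+)_e(?P<waves_per_eu>\d+)_p(?P<l2_prefetch>\d+)$"
)


def encode_name(cfg: A16W16Config) -> str:
    """Pack every tunable into one comma-free token, safe for a CSV cell."""
    return (
        f"{_PREFIX}_{cfg.dtype}_t{cfg.tile_m}x{cfg.tile_n}x{cfg.tile_k}"
        f"_w{cfg.m_warp}x{cfg.n_warp}_b{cfg.num_buffers}"
        f"_e{cfg.waves_per_eu}_p{cfg.l2_prefetch}"
    )


def decode_name(name: str):
    """Recover the config from a name, or return None if it does not parse."""
    match = _NAME_RE.match(name)
    if match is None:
        return None
    fields = match.groupdict()
    dtype = fields.pop("dtype")
    return A16W16Config(dtype=dtype, **{k: int(v) for k, v in fields.items()})
```

Because the name is the only thing stored in the tuned CSV, a scheme like this lets the dispatch side rebuild the exact launch parameters without a second lookup table, at the cost of stored names becoming stale whenever the naming scheme changes.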
@github-actions

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label          Tests
ci:triton-355  Run Triton tests on MI355 in addition to MI325
ci:sglang      SGLang integration tests
ci:atom        ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm        vLLM benchmark
ci:all         All of the above

Add labels via the sidebar or gh pr edit 2816 --add-label <label>

@@ -0,0 +1,857 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
s

⚠️ [ruff] <F821> reported by reviewdog 🐶
Undefined name s

# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
s
import torch

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file

# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
s
import torch
import flydsl.compiler as flyc

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file

s
import torch
import flydsl.compiler as flyc
import flydsl.expr as fx

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file

import torch
import flydsl.compiler as flyc
import flydsl.expr as fx
from flydsl._mlir import ir

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file

from flydsl.compiler.kernel_function import CompilationContext
from flydsl.expr import arith, buffer_ops, gpu, range_constexpr, rocdl, tdm_ops, vector
from flydsl.expr.arith import _to_raw as _raw
from flydsl.expr.typing import T

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file

from flydsl.expr import arith, buffer_ops, gpu, range_constexpr, rocdl, tdm_ops, vector
from flydsl.expr.arith import _to_raw as _raw
from flydsl.expr.typing import T
from flydsl.runtime.device import get_rocm_arch as get_hip_arch

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file

from flydsl.expr.arith import _to_raw as _raw
from flydsl.expr.typing import T
from flydsl.runtime.device import get_rocm_arch as get_hip_arch
from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr, get_op_result_or_value

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file

from flydsl.expr.typing import T
from flydsl.runtime.device import get_rocm_arch as get_hip_arch
from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr, get_op_result_or_value
from flydsl.expr import idx2crd

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file

from flydsl.runtime.device import get_rocm_arch as get_hip_arch
from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr, get_op_result_or_value
from flydsl.expr import idx2crd
from typing import Optional

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file
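All of the reviewdog findings above appear to trace back to the stray `s` on line 3 of the new file: it is an undefined name (F821), and because it is a non-import statement, every import after it is no longer "at the top of the file" (E402). The sketch below reproduces that reasoning with the standard `ast` module. It is a simplified approximation, not ruff's actual rule logic, and `first_stray_statement` is a hypothetical helper.

```python
import ast


def first_stray_statement(source: str):
    """Return (lineno, text) of the first module-level statement that is not an
    import or the module docstring and is followed by at least one import.

    Simplified approximation of why E402 fires; the real rule has more
    exemptions than this sketch models.
    """
    body = ast.parse(source).body
    for i, node in enumerate(body):
        is_import = isinstance(node, (ast.Import, ast.ImportFrom))
        is_docstring = (
            i == 0
            and isinstance(node, ast.Expr)
            and isinstance(node.value, ast.Constant)
            and isinstance(node.value.value, str)
        )
        if is_import or is_docstring:
            continue
        later_import = any(
            isinstance(n, (ast.Import, ast.ImportFrom)) for n in body[i + 1:]
        )
        if later_import:
            return node.lineno, ast.get_source_segment(source, node)
        return None
    return None


# File header as shown in the diff context above (parsed only, never executed).
SAMPLE = """# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
s
import torch
import flydsl.compiler as flyc
"""

print(first_stray_statement(SAMPLE))  # (3, 's')
```

Deleting that one-character line would be expected to clear the F821 finding and the E402 findings together.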

@apicciau apicciau changed the title from "feat: add FlyDSL A16W16 GEMM kernel and GemmTuner integration for gfx1250 (RDNA4)" to "Add FlyDSL A16W16 GEMM kernel and GemmTuner integration" Apr 20, 2026
…le candidates, warn on unresolved kernel names

- flydsl_gemm() now passes stages/async_copy/c_to_lds from the stored
  catalog config to flydsl_hgemm(), matching what was benchmarked at
  tune time
- flydsl_gemm_all_sols() skips tile_m configs larger than max(M, 16),
  reducing the candidate search space for small-M shapes
- get_GEMM_A16W16_config() emits a warning when a stored FlyDSL kernel
  name cannot be resolved against the current catalog, instead of
  silently falling back to torch
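Two of the behaviours in this commit message, pruning candidates by `tile_m` and warning on an unresolved kernel name, can be sketched as follows. The helper names, the dict-based candidate format, and the catalog shape are assumptions made for illustration; they are not the signatures used in `GemmTuner.py` or `tuned_gemm.py`.

```python
import warnings


def filter_tile_m(candidates, M, min_tile_m=16):
    """Drop candidates whose tile_m exceeds max(M, min_tile_m).

    For small M, a tile taller than the problem mostly computes padding,
    so benchmarking it is unlikely to be worthwhile.
    """
    limit = max(M, min_tile_m)
    return [c for c in candidates if c["tile_m"] <= limit]


def resolve_kernel(name, catalog):
    """Look up a stored kernel name, warning instead of failing silently."""
    params = catalog.get(name)
    if params is None:
        warnings.warn(
            f"FlyDSL kernel {name!r} not found in the current catalog; "
            "falling back",
            RuntimeWarning,
            stacklevel=2,
        )
    return params


# Hypothetical candidate list covering four tile heights.
candidates = [{"tile_m": t} for t in (16, 32, 64, 128)]
print([c["tile_m"] for c in filter_tile_m(candidates, M=8)])   # [16]
print([c["tile_m"] for c in filter_tile_m(candidates, M=48)])  # [16, 32]
```

The warning matters mainly after a catalog change: a tuned CSV produced against an older catalog can then be noticed and re-tuned rather than quietly running on the fallback path.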
@apicciau apicciau closed this Apr 20, 2026
@apicciau apicciau changed the title from "Add FlyDSL A16W16 GEMM kernel and GemmTuner integration" to "Add FlyDSL GemmTuner integration" Apr 20, 2026
@apicciau
Contributor Author

This conflicts with another recent PR, so I'm changing strategy.