[WIP] perf: add Gluon MoE kernels for GPT-OSS #314

Open
knwng wants to merge 45 commits into main from kylewng/gluon_moe

Conversation

@knwng
Contributor

@knwng knwng commented May 29, 2026

Added support for Gluon MoE kernels and applied optimizations including:

  • sliceMN for prefill shapes
  • plain pipelining for decode shapes for now
  • weight preshuffling on host

kylewng and others added 30 commits May 10, 2026 14:29
Adds a new ``Mxfp4Fp8TritonKernelBackend`` that consumes AMD-Quark
``w_mxfp4_a_fp8`` checkpoints (e.g. ``amd/gpt-oss-120b-w-mxfp4-a-fp8``):
fp8 e4m3 activations with per-tensor static scales paired with the
existing mxfp4 weight path through ``triton_kernels.matmul``.

Pieces:

* ``Mxfp4Config`` learns to detect Quark ``w_mxfp4_a_fp8`` and surface
  an ``is_w4a8_fp8`` flag plus the excluded layers; ``mxfp4`` selector
  preference promotes the new backend ahead of bf16/flashinfer fallbacks.
* New backend allocates per-expert ``w13_input_scale`` /
  ``w2_input_scale`` parameters, collapses them to per-tensor scales and
  builds ``PrecisionConfig(flex_ctx.lhs_data=InFlexData(fp8, scale),
  out_dtype=bf16)`` so the GEMM output stays in bf16 for swiglu and the
  second quantisation step.
* ``gpt_oss._load_mxfp4_weights`` learns the AMD per-expert checkpoint
  layout (``...experts.{e}.{gate_up_proj,down_proj}.{weight,
  weight_scale,bias,input_scale}``) and dispatches to a new loader that
  shards each expert tensor onto the existing fused ``w13_*`` / ``w2_*``
  parameters.
* Fix a pre-existing torch / tokenspeed_triton ABI mismatch that
  segfaults ``libtriton.so`` whenever it is loaded after torch by
  importing ``tokenspeed_kernel`` first in ``tokenspeed/__init__.py``
  and ``tokenspeed.runtime.utils.common`` (and in the affected
  test files).
* New unit tests in ``test/runtime/test_mxfp4_fp8_backend.py`` cover
  Quark detection, the override path and the per-expert input-scale
  loader.
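The per-expert to per-tensor collapse described in the second bullet can be sketched as follows. This is only an illustration: taking the maximum over experts is an assumption (the commit message does not say which reduction the backend uses), and `collapse_input_scales` is a hypothetical helper name.

```python
from typing import Sequence


def collapse_input_scales(per_expert_scales: Sequence[float]) -> float:
    """Collapse per-expert static activation scales to one per-tensor scale.

    Assumption: the largest scale is kept so that no expert's activations
    overflow the fp8 range; the real backend may reduce differently.
    """
    if not per_expert_scales:
        raise ValueError("need at least one expert scale")
    return max(per_expert_scales)


# One entry per expert (made-up values for illustration).
w13_input_scale = [0.020, 0.035, 0.031, 0.027]
per_tensor = collapse_input_scales(w13_input_scale)
```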

Eval (gpqa_diamond, medium reasoning, TP=2): mean_acc 0.6616,
avg output throughput 79.37 tok/s (vs ~33 tok/s baseline on the same
hardware).

Signed-off-by: kylewng <kylewng@users.noreply.github.com>
Adds the first port of the gfx1250 Gluon MoE example to CDNA4 / gfx950
(MI355). The new kernel sits in
``tokenspeed_kernel.ops.moe.gluon`` and ships a baseline (non-pipelined)
bf16 ragged matmul that uses MFMA v4 + Swizzled Shared layouts. The
launcher transparently falls back to the upstream
``triton_kernels.matmul`` for any unsupported precision (mxfp4 weights,
fp8 activations, fused swiglu, persistent / split-K), keeping the public
op signature stable while we graduate features.

Three kernel specs are registered (matmul_ogs / dispatch_gemm /
gemm_combine) and gated behind a ``TOKENSPEED_MOE_GLUON=1`` env knob:
when set, priority is ``Priority.SPECIALIZED + 1`` (above triton_kernels);
otherwise priority is ``Priority.PORTABLE + 1`` so default behavior is
unchanged.

Includes:

* tokenspeed_kernel/ops/moe/gluon.py
  + tokenspeed_kernel/ops/moe/__init__.py wiring
* tokenspeed-kernel/test/ops/test_moe_gluon.py
  - 3 correctness tests vs torch reference (passes within bf16 tol)
  - 1 selection test under TOKENSPEED_MOE_GLUON=1
  - 1 fallback-for-mxfp4 test
* tokenspeed-kernel/benchmarks/moe_gluon_microbench.py
  - block_m/block_n/block_k/num_warps grid sweep against the upstream
    triton_kernels baseline
* task-progress-2.md documenting API/layout differences from gfx1250 to
  CDNA4, perf vs baseline (~46% on 2880x2880x32E) and the explicit
  follow-ups needed to graduate to mxfp4 + software-pipelined paths.

Signed-off-by: Kyle Wang <kwang102@amd.com>
Per `TASKS.md:31-38`, finish the MI355 Gluon MoE kernel work:

* Re-organise `_pipelined_moe_kernel` as a single `gluon.jit` body that
  feeds three public launchers via `constexpr` flags
  (`HAS_BIAS`, `HAS_GATHER`, `HAS_SCATTER`, `DO_SWIGLU`,
  `APPLY_GATE_SCAL`):
    - `gluon_bf16_gating_gemm`         (dense bf16xbf16 GEMM)
    - `gluon_bf16_dispatch_swiglu`     (dispatch + 1st GEMM + SwiGLU)
    - `gluon_bf16_combine`             (2nd GEMM + scatter combine)

* Add a register-staged software pipeline (prefetch the next K tile into
  registers while the current MFMA executes) modelled on
  `triton-450/.../moe_gfx1250.py::MoEPipelinedProgram`.  We initially
  tried `buffer_load_to_shared` + `commit_group/wait_group` like
  `f16_gemm_gfx950.py`, but Triton/MLIR can't currently legalise the
  per-expert dynamic base pointer through the async-DMA op; the
  register-staged variant already clears the perf bar so async LDS DMA
  is left as a follow-up.

* Add `static_profile` / `assert_no_spills` (port of
  `gfx1250_utils.py`) and wire them into both the unit test
  (`test_no_register_spill`) and the microbench summary so any future
  change that introduces sgpr/vgpr spills fails CI.

* New shape-aware `_autotune_block` heuristic obtained from a
  microbench sweep: dense GEMM uses BLOCK_N=64 to keep grid_n high;
  SwiGLU paths use 64x64x32 + 4 warps to leave VGPR headroom for the
  reduce.

* Update `test_moe_gluon.py` with correctness coverage for all three
  launchers + the no-spill assertion, and update
  `benchmarks/moe_gluon_microbench.py` to compare each kernel against
  its proper baseline (upstream `triton_kernels.matmul`, with
  `FusedActivation` for SwiGLU and `scatter_indx` for combine) and
  dump GPR profiles.
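The register-staged software pipeline from the second bullet can be modelled on the host as below. This is a plain-Python sketch of the control flow only (prologue load, then "issue next load, consume current tile"); `load_tile`, `mma` and `pipelined_matmul` are hypothetical names, and nothing here reflects the actual Gluon kernel body.

```python
from typing import List

Matrix = List[List[float]]


def load_tile(a: Matrix, b: Matrix, k0: int, block_k: int):
    """Stand-in for a global->register load of one K tile of A and B."""
    k1 = min(k0 + block_k, len(b))
    a_tile = [row[k0:k1] for row in a]
    b_tile = b[k0:k1]
    return a_tile, b_tile


def mma(acc: Matrix, a_tile: Matrix, b_tile: Matrix) -> None:
    """Stand-in for the MFMA: acc += a_tile @ b_tile."""
    for i, a_row in enumerate(a_tile):
        for kk, a_val in enumerate(a_row):
            for j, b_val in enumerate(b_tile[kk]):
                acc[i][j] += a_val * b_val


def pipelined_matmul(a: Matrix, b: Matrix, block_k: int) -> Matrix:
    m, k, n = len(a), len(b), len(b[0])
    acc = [[0.0] * n for _ in range(m)]
    current = load_tile(a, b, 0, block_k)  # prologue: first tile
    for k0 in range(0, k, block_k):
        nxt = k0 + block_k
        # Issue the next tile's load before consuming the current one, so on
        # hardware the load latency can overlap with the MFMA below.
        prefetched = load_tile(a, b, nxt, block_k) if nxt < k else None
        mma(acc, *current)
        current = prefetched
    return acc
```

Holding `current` and `prefetched` at the same time is what the later commits refer to as two tiles in flight, which is why register pressure grows with BLOCK_K.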

Microbench (MI355, container `kylewng_triton_dev_mi355_1505_dev`):

  Kernel 1 (gating GEMM)   M=512  : 2.39x    M=1024 : 2.05x   M=4096 : 1.05x
  Kernel 2 (dispatch+SwiGLU) M=512: 1.16x    M=1024: 1.13x
  Kernel 3 (2nd GEMM + combine) M=512: 1.42x  M=1024: 1.41x

All compiled variants: sgpr_spill = 0, vgpr_spill = 0, scratch = 0.

Signed-off-by: Kyle Wang <kwang102@amd.com>
Per `TASKS.md:39-41`, retune the MI355 Gluon MoE kernels around the real
gpt-oss-120b MoE GEMM dimensions (`H=I=2880, E=128, topk=4` from the HF
config) for both decode (`B in {1, 32, 64}`) and prefill
(`B in {1024, 4096, 8192}`).

Key changes:

* Active-expert remap (decode-critical). The unified
  `_pipelined_moe_kernel` now takes an optional `expert_remap_ptr` +
  `HAS_EXPERT_REMAP` constexpr. When the ragged metadata is sparse
  (e.g. decode `B=1` activates only 4/128 experts) the launcher builds a
  dense i32 index of *active* experts and the kernel does a scalar
  `gl.load(expert_remap_ptr + compact_idx)` to recover the real expert
  id (mirroring `moe_gfx1250.py::gl.load(XSliceSizes + expt_id)`). The
  compact index drives the M-offset; the real expert id only feeds the
  W / bias indexing. All-active prefill keeps the zero-overhead
  `expert_id = compact_idx` path.

* Shape-aware autotuner. `_autotune_block(M, N, K, *, do_swiglu, ragged)`
  now branches on three axes -- dense gating GEMM, fused SwiGLU 1st GEMM
  and ragged 2nd GEMM -- with thresholds picked from a microbench sweep
  on the gpt-oss shapes. Specifically:
    - dense gating: 64/128 BM x 64 BN x 64 BK depending on M
    - SwiGLU: 64x128x32 (M<=8192) / 128x128x32 (M>8192), 4 warps
    - ragged combine: 64x128x32 (M<=8192) / 128x128x32 (M>8192). 256x256
      tiles get 295 TFLOPs but spill 210 VGPRs -> backed off to 128x128
      which keeps every variant at 0 spill / scratch.

* Microbench rewrite (`benchmarks/moe_gluon_microbench.py`) sweeps the
  six gpt-oss-120b token batches (`1, 32, 64, 1024, 4096, 8192`) for
  each kernel against the upstream `triton_kernels.matmul`/+SwiGLU
  fused-activation/scatter baseline. Decode batches build a sparse
  ragged metadata (`min(B*topk, E)` active experts) so the kernel
  exercises the new remap path; prefill batches saturate all 128
  experts. The static GPR profile dump now hard-fails the script if any
  variant spills.

* Five new unit tests (`test/ops/test_moe_gluon.py`):
    - `test_gpt_oss_decode_remap[B=1, 32, 64]` -- numerical equivalence
      of the active-expert remap path.
    - `test_gpt_oss_no_spill[B=1, 64]` -- compile every kernel at the
      real `H=I=2880, E=128` dimensions and assert no sgpr/vgpr spill
      and zero scratch.
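The active-expert remap from the first bullet can be illustrated on the host side as below. This is a sketch under assumptions: it assumes token rows are packed contiguously per active expert, and `build_expert_remap` / `tile_assignments` are hypothetical helper names, not functions from the PR.

```python
from typing import List, Tuple


def build_expert_remap(slice_sizes: List[int]) -> List[int]:
    """Dense index of experts that received at least one token."""
    return [e for e, size in enumerate(slice_sizes) if size > 0]


def tile_assignments(slice_sizes: List[int], block_m: int) -> List[Tuple[int, int]]:
    """Return (real_expert_id, m_offset) for each M tile that would be launched.

    compact_idx walks only the active experts; the remap lookup recovers the
    real expert id that feeds the W / bias indexing.
    """
    remap = build_expert_remap(slice_sizes)
    tiles = []
    m_offset = 0
    for compact_idx in range(len(remap)):
        expert_id = remap[compact_idx]  # the scalar load in the kernel
        rows = slice_sizes[expert_id]
        for start in range(0, rows, block_m):
            tiles.append((expert_id, m_offset + start))
        m_offset += rows
    return tiles
```

With 8 experts of which only experts 1 and 5 are active, tiles are generated for 2 experts instead of 8, which is the effect the decode `B=1` numbers rely on.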

Microbench (MI355, container `kylewng_triton_dev_mi355_1505_dev`):

  Kernel 1 (gating GEMM)   B=1..8192   speedup_vs_triton 1.62x .. 2.49x
  Kernel 2 (dispatch+SwiGLU)
                B=1   M_d= 256  speedup 0.78x   (was 0.09x before remap)
                B=32..1024 M_d=8192   0.58x
                B=4096      M_d=16384  0.64x
                B=8192      M_d=32768  0.59x
  Kernel 3 (2nd GEMM + combine)
                B=1   M_d= 256  speedup 0.80x   (was 0.12x before remap)
                B=32..1024 M_d=8192   0.53x
                B=4096      M_d=16384  0.56x
                B=8192      M_d=32768  0.48x

  Static GPR / spill: 12/12 variants clean (sgpr<=42, vgpr<=200,
                      sgpr_spill = vgpr_spill = scratch = 0).

Kernel 1 still beats upstream comfortably on all batches. For Kernels 2/3
the remap closes most of the previous decode hole (B=1: 0.09x/0.12x ->
0.78x/0.80x), but batches with `M_d >= 8192` cap at roughly 0.5-0.6x of
baseline. The residual gap is attributed to the missing multi-buffer LDS
pipeline (`buffer_load_to_shared` + `commit_group/wait_group`), which is
now the top follow-up in `task-progress-2.md`.

Signed-off-by: Kyle Wang <kwang102@amd.com>
CDNA4's scaled MFMA op (gl.amd.cdna4.mfma_scaled) has instruction shape
[16, 16, 128], so any kernel that swaps in scaled MFMA (the mxfp4 weight
+ fp8/bf16 activation path used by gpt-oss-120b) must use BLOCK_K that
is a multiple of 128 with a floor of 128.

* _autotune_block(..., scaled_mfma=False): when True, promote BLOCK_K to
  >= 128 and to the next multiple of 128. Regular MFMA path (bf16) keeps
  the BK=32/64 fast path because a microbench sweep over BK in
  {64, 128, 256} showed BK=128 makes the register-staged pipeline 2-3x
  slower on bf16 due to VGPR pressure (per-tile registers scale linearly
  with BK and our prefetch holds 2 tiles in flight).
* _launch_pipelined(..., scaled_mfma=False): assert BLOCK_K % MFMA_K
  == 0 and BLOCK_K >= 128 when scaled, BLOCK_M % 16 == 0 always.
* 6 new UTs (21 total passing) covering autotune contract,
  launcher rejection of bad BK, and bf16 kernel compile @ BK=128.
* Microbench now also prints the scaled-MFMA autotune preview after
  the static-profile spill check.
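The BLOCK_K contract described above can be written down as a small sketch. `promote_block_k` and `check_launch` are hypothetical stand-ins for the relevant parts of `_autotune_block` and `_launch_pipelined`; only the 128-wide K extent and the divisibility rules are taken from the commit message.

```python
MFMA_K = 128  # K extent of the scaled MFMA instruction shape [16, 16, 128]


def promote_block_k(block_k: int, scaled_mfma: bool = False) -> int:
    """Round BLOCK_K up to a multiple of MFMA_K (floor of MFMA_K) when scaled."""
    if not scaled_mfma:
        return block_k  # bf16 path keeps its BK=32/64 fast path
    return max(MFMA_K, -(-block_k // MFMA_K) * MFMA_K)


def check_launch(block_m: int, block_k: int, scaled_mfma: bool = False) -> None:
    """Launcher-side contract: reject tile shapes the MFMA cannot consume."""
    assert block_m % 16 == 0, "BLOCK_M must be a multiple of 16"
    if scaled_mfma:
        assert block_k >= MFMA_K and block_k % MFMA_K == 0, "bad BLOCK_K for scaled MFMA"
```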

bf16 throughput and spill profile are unchanged; this commit only pins
the constraint so it can't silently regress when the mxfp4/fp8 scaled
path lands.

Signed-off-by: kwang102 <kwang102@users.noreply.github.com>
Adds a sibling `_pipelined_moe_kernel_scaled` that uses
`gl.amd.cdna4.mfma_scaled` (16x16x128, k_width=16) to support three
A/W dtype combinations through the same body:

* mxfp4 (e2m1) x mxfp4 (e2m1)  -- both with e8m0 block scales
* fp8   (e4m3) x mxfp4 (e2m1)  -- A with a host-side global scalar scale
* fp8   (e5m2) x mxfp4 (e2m1)  -- same as above

The variant is selected via constexpr flags (`A_FORMAT`, `B_FORMAT`,
`HAS_A_BLOCK_SCALE`) so a single kernel covers all paths, matching the
upstream `triton_kernels._matmul.py` pattern. Block scales reuse the
`get_mfma_scale_layout` distributed layout for direct gl.load (same
recipe as `test_amd_mfma_scaled` in triton-450).
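For readers unfamiliar with the formats, the following sketch dequantises mxfp4 (e2m1 values with one e8m0 scale per 32-element block, as in the OCP MX formats) on the host. It is a reference illustration, not the kernel's code path; the nibble order within a byte (low nibble first) is an assumption, and all function names are hypothetical.

```python
from typing import List

# Magnitudes representable by fp4 e2m1 (sign is the top bit of the nibble).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
MX_BLOCK = 32  # elements sharing one e8m0 scale


def decode_e2m1(nibble: int) -> float:
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def decode_e8m0(byte: int) -> float:
    """e8m0 stores only a biased exponent: value = 2 ** (byte - 127)."""
    if byte == 0xFF:
        return float("nan")
    return 2.0 ** (byte - 127)


def dequant_mxfp4(packed: bytes, scales: bytes) -> List[float]:
    """Dequantise packed mxfp4 data; assumes the low nibble is the first element."""
    out = []
    for i, byte in enumerate(packed):
        for elem, nibble in ((2 * i, byte & 0xF), (2 * i + 1, byte >> 4)):
            out.append(decode_e2m1(nibble) * decode_e8m0(scales[elem // MX_BLOCK]))
    return out
```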

Three new public launchers (`gluon_mxfp_gating_gemm`,
`gluon_mxfp_dispatch_swiglu`, `gluon_mxfp_combine`) parallel the
existing bf16 launchers, including ragged metadata / SwiGLU /
scatter-combine. The autotuner already promoted BLOCK_K to 128 for
`scaled_mfma=True` in Update 3, so no further tuning changes were
needed.

Along the way fix a latent OOB-write bug in the bf16 kernel: the
store mask used `OUT_BLOCK_N * grid_n` instead of the actual output
N, so shapes where `N % BLOCK_N != 0` (e.g. `N=80, BLOCK_N=64`) wrote
into the next row of the output tensor. A targeted regression test
covers it.
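The OOB write can be reproduced with a flat row-major buffer on the host. The sketch below uses the `N=80, BLOCK_N=64` shape from the text; `store_tile` is a hypothetical stand-in for the kernel's masked store.

```python
def store_tile(out, n, row, col0, block_n, n_limit):
    """Write 1.0 into a flat row-major buffer for columns passing `col < n_limit`."""
    for col in range(col0, col0 + block_n):
        if col < n_limit:  # the store mask
            out[row * n + col] = 1.0


N, BLOCK_N, ROWS = 80, 64, 2
grid_n = -(-N // BLOCK_N)  # 2 tiles cover 128 >= 80 columns

buggy = [0.0] * (N * ROWS)
fixed = [0.0] * (N * ROWS)
for pid_n in range(grid_n):
    # Old mask: bounded by the tiled width, not by the real output width.
    store_tile(buggy, N, 0, pid_n * BLOCK_N, BLOCK_N, BLOCK_N * grid_n)
    # Fixed mask: bounded by the actual output N.
    store_tile(fixed, N, 0, pid_n * BLOCK_N, BLOCK_N, N)

spilled = sum(buggy[N:])  # elements of row 1 overwritten by row 0's tiles
```

With the old mask, 48 elements of row 1 are overwritten while processing row 0; with the mask on the actual `N`, none are.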

Microbench on gpt-oss-120b dims (H=I=2880, E=128, topk=4) shows the
new scaled paths beat the bf16 kernels on all three pieces at
prefill scale:

  K1 gating @ B=8192:        62 -> 132/124 TFLOPs (2.1x / 2.0x)
  K2 dispatch+SwiGLU @ B=8192: 364 -> 436/408 TFLOPs (1.20x / 1.12x)
  K3 combine @ B=8192:       274 -> 402/378 TFLOPs (1.47x / 1.38x)

All 11 scaled kernel variants compile with zero sgpr / vgpr spills and
zero scratch (occupancy 3-5).

Test suite: 34 passing (13 new -- multi-shape mxfp4xmxfp4, fp8xmxfp4
e4m3 & e5m2, dispatch+SwiGLU, ragged combine, OOB regression, no-spill
profile). True LDS multi-buffer pipelining stays as the top follow-up
(documented in task-progress-2.md).

Signed-off-by: kwang102 <kwang102@users.noreply.github.com>
Adds a new ``Mxfp4Fp8TritonKernelBackend`` that consumes AMD-Quark
``w_mxfp4_a_fp8`` checkpoints (e.g. ``amd/gpt-oss-120b-w-mxfp4-a-fp8``):
fp8 e4m3 activations with per-tensor static scales paired with the
existing mxfp4 weight path through ``triton_kernels.matmul``.

Pieces:

* ``Mxfp4Config`` learns to detect Quark ``w_mxfp4_a_fp8`` and surface
  an ``is_w4a8_fp8`` flag plus the excluded layers; ``mxfp4`` selector
  preference promotes the new backend ahead of bf16/flashinfer fallbacks.
* New backend allocates per-expert ``w13_input_scale`` /
  ``w2_input_scale`` parameters, collapses them to per-tensor scales and
  builds ``PrecisionConfig(flex_ctx.lhs_data=InFlexData(fp8, scale),
  out_dtype=bf16)`` so the GEMM output stays in bf16 for swiglu and the
  second quantisation step.
* ``gpt_oss._load_mxfp4_weights`` learns the AMD per-expert checkpoint
  layout (``...experts.{e}.{gate_up_proj,down_proj}.{weight,
  weight_scale,bias,input_scale}``) and dispatches to a new loader that
  shards each expert tensor onto the existing fused ``w13_*`` / ``w2_*``
  parameters.
* Fix a pre-existing torch / tokenspeed_triton ABI mismatch that
  segfaults ``libtriton.so`` whenever it is loaded after torch by
  importing ``tokenspeed_kernel`` first in ``tokenspeed/__init__.py``
  and ``tokenspeed.runtime.utils.common`` (and in the affected
  test files).
* New unit tests in ``test/runtime/test_mxfp4_fp8_backend.py`` cover
  Quark detection, the override path and the per-expert input-scale
  loader.

Eval (gpqa_diamond, medium reasoning, TP=2): mean_acc 0.6616,
avg output throughput 79.37 tok/s (vs ~33 tok/s baseline on the same
hardware).

Signed-off-by: Kyle Wang <ec1wng@gmail.com>
Brings in main (~70 commits) plus the fp8 x mxfp4 cleanup pass on top of
our shared base commit. Conflicts in mxfp4.py and gpt_oss.py resolved by
taking the incoming cleanup (drops unused excluded_layers plumbing and
adds is_amd platform guard).

Signed-off-by: Kyle Wang <ec1wng@gmail.com>
Re-rank the Gluon MoE GEMM registration so the selector picks it over
the upstream triton_kernels MoE on MI355 (gfx950) by default. Per the
selection objective semantics this kernel is *not* a portability
fallback -- it is a CDNA4-specific throughput / latency optimisation,
so swap the tags accordingly and bump the priority to
``Priority.SPECIALIZED + 2`` (14, above triton_kernels' 10). The
``TOKENSPEED_MOE_GLUON`` env now acts as a *disable* knob (set to
``0/false/no/off`` to drop the priority below triton_kernels for A/B
comparison without rebuilding).
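A minimal sketch of the disable-knob semantics follows. The values 14 and 10 come from the paragraph above; the disabled priority, the case-insensitive matching and the function name `gluon_priority` are assumptions made for illustration.

```python
import os
from typing import Mapping, Optional

GLUON_ENABLED_PRIORITY = 14   # Priority.SPECIALIZED + 2, per the text above
TRITON_KERNELS_PRIORITY = 10  # triton_kernels' priority, per the text above
GLUON_DISABLED_PRIORITY = TRITON_KERNELS_PRIORITY - 1  # assumption: any value below 10

_FALSEY = {"0", "false", "no", "off"}


def gluon_priority(env: Optional[Mapping[str, str]] = None) -> int:
    """Return the registration priority implied by TOKENSPEED_MOE_GLUON."""
    env = os.environ if env is None else env
    # Unset means enabled; matching is case-insensitive here (an assumption).
    value = env.get("TOKENSPEED_MOE_GLUON", "1").strip().lower()
    return GLUON_DISABLED_PRIORITY if value in _FALSEY else GLUON_ENABLED_PRIORITY
```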

Also harden ``_gluon_bf16_ragged_matmul`` so when it falls back to
``triton_kernels.matmul`` for unsupported precisions (mxfp4 weight
scales, fp8 activation flex data) it forwards extra epilogue kwargs
(``gammas``, ``betas``, ``out_alpha``, ``c``/``c_acc_in``,
``fused_comm``, ``epilogue``, ``b_ragged_metadata``, ...). Previously
``**_unused`` silently dropped them, which would have miscomputed the
gemm+combine path the moment the selector started routing through us.

Verified on the MI355 dev container:

* ``select_kernel("moe", "experts", bf16, features={"ragged_metadata",
  "dispatch_gemm"})`` now returns ``triton_kernels_gluon_dispatch_gemm``.
* ``tokenspeed-kernel/test/ops/test_moe_gluon.py``: 51 passed.
* ``tokenspeed-kernel/test/test_callsite_selection.py``: 36 passed,
  14 skipped (NV-only fixtures, our gluon kernel filtered by
  ``CapabilityRequirement(amd, gfx950)``).
* ``test/runtime/test_mxfp4_fp8_backend.py``: 11 passed.
* ``test/runtime/test_mxfp4_weights.py``: 1 passed.

Signed-off-by: Kyle Wang <ec1wng@gmail.com>
Document the steps for ``TASKS.md:68-73`` Update 6:

* The merge of ``origin/kylewng/fp8_mxfp4_moe`` (cleanup pass +
  ~70 main commits) and the two conflict resolutions (mxfp4.py,
  gpt_oss.py).
* The selector ranking (``oracle, objective, priority`` lex sort) and
  why we default Gluon to ``Priority.SPECIALIZED + 2 = 14`` (above
  triton_kernels' ``PERFORMANT + 2 = 10``) with throughput/latency
  tags.
* The hidden ``gammas`` passthrough bug that surfaced once Gluon
  started winning combine selection, and how the adapter now forwards
  every non-tuning kwarg to the upstream fallback.
* Honest assessment that the production fp8 × mxfp4 path on
  gpt-oss-120b currently still falls back to ``triton_kernels.matmul``
  (because ``Mxfp4TritonKernelBackend.process_weights_after_loading``
  emits ``triton_kernels.Tensor``-wrapped weights + an mxfp4 scaled
  PrecisionConfig that our adapter cannot consume natively yet).
* Reproduction recipe for ``test/ci/eval/gpt-oss-120b-mxfp4-evalscope-
  gpqa-diamond.yaml`` plus the Gluon-on/off A/B knob via
  ``TOKENSPEED_MOE_GLUON``.

No code changes besides the doc; companion code commits are
``86d22e4`` (merge) and ``f25dd60`` (registration tweak).

Signed-off-by: Kyle Wang <ec1wng@gmail.com>
When the gluon adapter falls back to ``triton_kernels.matmul`` for
unsupported precisions (mxfp4 weight scale + fp8 flex_ctx), it goes
straight to the underlying ``triton_kernels.matmul`` -- bypassing the
``ops.moe.triton_kernels._matmul`` wrapper that is normally registered
for the ``moe.experts`` family. That wrapper performs a critical
post-processing step on the combine path:

```
if scatter_indx is not None and n_expts_act > 1:
    out = out.view(n_tokens, n_expts_act, out.shape[-1]).sum(dim=1)
return out
```

Without this fold-back, the adapter returns
``(n_tokens * n_expts_act, N)`` while the runtime expects
``(n_tokens, N)`` -- a silent shape divergence that surfaced the moment
the priority bump in ``f25dd60`` started routing combine through us.
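The shape change performed by the fold-back can be checked without torch. The sketch below is a plain-Python equivalent of `out.view(n_tokens, n_expts_act, N).sum(dim=1)`, assuming the `n_expts_act` rows of each token are stored consecutively; `fold_topk` is a hypothetical name.

```python
from typing import List


def fold_topk(out: List[List[float]], n_tokens: int, n_expts_act: int) -> List[List[float]]:
    """Sum the n_expts_act consecutive rows belonging to each token."""
    assert len(out) == n_tokens * n_expts_act
    width = len(out[0])
    folded = []
    for t in range(n_tokens):
        rows = out[t * n_expts_act:(t + 1) * n_expts_act]
        folded.append([sum(r[c] for r in rows) for c in range(width)])
    return folded
```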

Bit-exact A/B verification on a synthetic gpt-oss-shaped fp8 x mxfp4
combine call (H=I=256, E=4, top_k=2, n_tokens=64):

  TOKENSPEED_MOE_GLUON=0  -> shape=(64, 256)  abs_mean=1248.0
  TOKENSPEED_MOE_GLUON=1  -> shape=(64, 256)  abs_mean=1248.0
  max_abs_diff = 0.0    bit-exact match: True

This indicates the production gpt-oss-120b path does not regress under
the new selector winner.

Tests: 51 + 12 passing (test_moe_gluon.py, test_mxfp4_fp8_backend.py,
test_mxfp4_weights.py).

Signed-off-by: Kyle Wang <ec1wng@gmail.com>
Append the synthetic gpt-oss-shaped fp8 x mxfp4 combine A/B test that
proved ``TOKENSPEED_MOE_GLUON=0`` (triton_kernels direct) and the
default ``TOKENSPEED_MOE_GLUON=1`` (selector picks gluon, adapter
falls back) produce bit-exact identical outputs (max_abs_diff = 0.0).

Signed-off-by: Kyle Wang <ec1wng@gmail.com>
knwng added 13 commits May 15, 2026 17:41
…ed slice_size<=16

Empirical sweep on TP=1 prod shape (H=2880 I=2880 E=128 top_k=4) via
check_gluon_decode_perf.py --prod-block-m 32, over 29 (bm,bn,bk,nw)
configs for both dispatch+swiglu and combine. (32,128,256,4) is the
cross-batch + cross-kernel sweet spot:

  dispatch+swiglu best per B (us):
                          B=1     B=4     B=8
    new tier (32,128,256,4) 75.0   76.8   107.0
    old pick (32,128,256,8) 77.4   81.9   116.9
    legacy   (64,128,512,8) 79.8  106.0   158.1   (SPILL)

  combine best per B (us):
                          B=1     B=4     B=8
    new tier (32,128,256,4) 73.2   75.8    86.8
    old pick (32,128,256,8) 73.2   74.4    91.0

NW=4 beats NW=8 by 5-9us on B=4/8 (lower SGPR/VGPR pressure raises
occupancy from 3 to 4); BK=256 matches the mxfp4 cacheline
(128 B / 0.5 B per nibble); BM=32 caps per-tile mask waste at 31/32
(vs 63/64 under BM=64).
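The two figures in the paragraph above follow from simple arithmetic, spelled out here. The worst case for mask waste is taken to be an expert that receives a single token; `worst_case_mask_waste` is a hypothetical helper.

```python
CACHELINE_BYTES = 128
MXFP4_BYTES_PER_ELEM = 0.5  # one 4-bit nibble per element

# Number of mxfp4 elements that fit in one cacheline.
block_k = int(CACHELINE_BYTES / MXFP4_BYTES_PER_ELEM)


def worst_case_mask_waste(block_m: int) -> float:
    """Fraction of a tile's rows masked off when an expert has a single token."""
    return (block_m - 1) / block_m
```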

A BM=16 sub-tier for fp8 X + slice_size<=2 was tried (sweep showed
71.9us vs 75us at dispatch B=1) and reverted: production E2E c=1 TPOT
regressed from 7.33ms to 7.56ms across 3 stable runs even though
microbench was a clear win. The bench shape sets slice_sizes[i]=
per_expert_padded while production has slice_sizes[i]=actual, so the
microbench measures only MFMA pipeline efficiency at smaller tile
while production hits a different bottleneck (LDS / gather path /
nbuf). Profile c=1 before re-trying. The disabled-sub-tier comment
points to this trade-off so the experiment isn't repeated blindly.

Implementation: thread per-expert M hint through _autotune_block via
a new slice_size kwarg + host-side helper that reads
RaggedTensorMetadata.expected_slice_size (or the
max(1, M // n_slices) fallback). No D2H sync, graph capture stays
intact.
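A sketch of the host-side hint follows. `RaggedMetadataStub` is a hypothetical stand-in that models only the two fields the paragraph mentions; the real `RaggedTensorMetadata` is not reproduced here, and `slice_size_hint` is an illustrative name.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class RaggedMetadataStub:
    """Hypothetical stand-in for RaggedTensorMetadata (only the fields used here)."""
    n_slices: int
    expected_slice_size: Optional[int] = None


def slice_size_hint(metadata: RaggedMetadataStub, m: int) -> int:
    """Per-expert M hint computed from host-side values only (no device read)."""
    if metadata.expected_slice_size is not None:
        return metadata.expected_slice_size
    return max(1, m // metadata.n_slices)
```

Because both branches use values already on the host, the hint does not force a device-to-host synchronisation, which is what keeps graph capture intact.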

E2E gpt-oss-120b decode TP=1 (single MI355, in=512/out=128, 100 req,
3-run median; prior 1-run measurements were ±7% noise on c=8):

                Gluon ON (this CL)    Gluon OFF (triton_kernels)
  c=8 tok/s     779                   776   (parity within noise)
  c=8 TPOT (ms)   9.32                  9.30
  c=1 tok/s     126.5                 141.0 (-10%)
  c=1 TPOT (ms)   7.33                  6.44 (+14%)

c=8 closes the previous -13% Gluon-ON gap to parity. c=1 still trails
by ~10% -- the remaining gap is not mask waste (already capped at
31/32) nor compute, and will be addressed by profiling + targeted
follow-ups (multi-tile-per-CTA persistent path, bypass scale-LDS for
tiny tiles, num_buffers tuning).

All existing kernel tests pass (test_moe_gluon.py: 57 passed) and
graph-capture bit-exact checks remain green
(check_gluon_graph_capture.py: dispatch+swiglu and combine both
eager<->graph diff = 0).

Signed-off-by: Kyle Wang <ec1wng@gmail.com>
@knwng knwng requested a review from a team as a code owner May 29, 2026 17:59

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c3b635659e

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +133 to +134
layer.w2_weight = Parameter(w2_w_padded, requires_grad=False)
layer.w2_weight_scale = Parameter(w2_s_padded, requires_grad=False)

P1: Pad the W2 bias when padding the output dimension

For GPT-OSS W4A8 on AMD, with_bias=True and hidden_size=2880, so this pads w2_weight/w2_weight_scale to 2944 columns but leaves w2_weight_bias at 2880. The new Gluon combine epilogue derives N from the padded W and masks bias loads with bias_offs < N, so the final tile reads w2_weight_bias[..., 2880:2944] out of bounds before the output is trimmed. Please pad the W2 bias with zeros alongside the weights, or make the bias mask use the logical N.
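The first suggested fix (zero-padding the bias alongside the weights) can be sketched as below. Rounding 2880 up to a multiple of 128 reproduces the 2944 figure quoted in the comment, but the alignment actually used by the PR is an assumption here, and `round_up` / `pad_bias` are hypothetical helpers operating on plain lists rather than tensors.

```python
from typing import List


def round_up(value: int, multiple: int) -> int:
    return -(-value // multiple) * multiple


def pad_bias(bias: List[float], padded_n: int) -> List[float]:
    """Zero-pad a 1-D bias so loads masked with `offs < padded_n` stay in bounds."""
    if len(bias) > padded_n:
        raise ValueError("bias is longer than the padded width")
    return bias + [0.0] * (padded_n - len(bias))


hidden_size = 2880
padded_n = round_up(hidden_size, 128)  # 2944, assuming 128-column alignment
w2_bias = pad_bias([1.0] * hidden_size, padded_n)
```

Zero padding keeps the extra columns numerically inert, so trimming the output back to the logical width afterwards gives the same result as an unpadded bias.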

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 3fe6c0e25c

Comment on lines +73 to +74
"tokenspeed.runtime.layers.moe.backends.mxfp4.gluon_kernel",
"Mxfp4GluonKernelBackend",

P1: Register an actual Gluon runtime backend

On AMD with MXFP4 and the default auto backend, selector.py now tries gluon_kernel before triton_kernel, so this entry makes ensure_backend_family_registered() import tokenspeed.runtime.layers.moe.backends.mxfp4.gluon_kernel. I checked the tree with repo-wide search for Mxfp4GluonKernelBackend/gluon_kernel, and no such module or class exists; the selector only catches KeyError, so this path raises ModuleNotFoundError during MXFP4 MoE initialization instead of falling back to the existing Triton backend.
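One way to get the fallback behaviour the comment asks for is to treat an import failure like a missing registry entry. The sketch below is generic and not the repository's selector: `load_first_available` is a hypothetical helper, and the candidate names in the demo are stand-ins chosen so the example runs anywhere.

```python
import importlib
from typing import Sequence, Tuple


def load_first_available(candidates: Sequence[Tuple[str, str]]):
    """Return the first (module, class) pair that imports cleanly.

    ModuleNotFoundError is a subclass of ImportError, so catching ImportError
    covers a missing backend module; AttributeError covers a missing class.
    """
    errors = []
    for module_name, class_name in candidates:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, class_name)
        except (ImportError, AttributeError) as exc:
            errors.append(f"{module_name}.{class_name}: {exc!r}")
    raise RuntimeError("no backend could be loaded: " + "; ".join(errors))


# Stand-in names: the first entry does not exist, the second is a stdlib class.
backend = load_first_available([
    ("no_such_package_for_demo.gluon_kernel", "Mxfp4GluonKernelBackend"),
    ("collections", "OrderedDict"),
])
```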

Comment on lines +50 to +52
from tokenspeed_triton.experimental import gluon
from tokenspeed_triton.experimental.gluon import language as gl
from tokenspeed_triton.language.core import _aggregate as aggregate

P1: Keep tokenspeed_triton imports behind _triton

These direct tokenspeed_triton imports violate the dependency boundary in /workspace/tokenspeed/AGENTS.md that says all direct tokenspeed-triton imports must happen in _triton.py and be re-imported elsewhere. This new module is imported unconditionally by tokenspeed_kernel.ops.moe.__init__, so bypassing the central shim makes the Gluon dependency harder to swap or gate consistently with the rest of tokenspeed-kernel; import gluon/gl through tokenspeed_kernel._triton and add any missing exports there.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f5141d6f51

Comment on lines +137 to +138
if not current_platform().is_amd:
    return False

P1: Gate Gluon backend to gfx950

When MXFP4 W4A8 runs on AMD devices other than gfx950 (for example CDNA3/gfx94), this supports() check accepts the new gluon_kernel backend because current_platform().is_amd is true, and the auto preference now tries it before triton_kernel. However the Gluon kernels themselves are registered with CapabilityRequirement(... min_arch_version=ArchVersion(9, 5), max_arch_version=ArchVersion(9, 5)), so the first moe_fused call on those AMD platforms has no matching gluon_mxfp4_fp8_fused_moe candidate instead of falling back to the Triton backend. Please make the runtime backend support predicate match the kernel capability (or otherwise keep non-gfx950 on triton_kernel).
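A predicate that mirrors the kernel-side requirement might look like the sketch below. The `(9, 5)` version is taken from the `CapabilityRequirement` quoted in the comment; representing the arch as a plain tuple and the name `gluon_backend_supported` are assumptions, since the platform API is not shown in this thread.

```python
from typing import Tuple

GLUON_ARCH: Tuple[int, int] = (9, 5)  # gfx950, min and max arch of the Gluon kernels


def gluon_backend_supported(is_amd: bool, arch_version: Tuple[int, int]) -> bool:
    """Runtime predicate that mirrors the kernel's min/max arch requirement."""
    return is_amd and arch_version == GLUON_ARCH
```

With this check, other AMD architectures would be rejected by the runtime backend and could stay on the `triton_kernel` path instead of reaching a selector with no matching candidate.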
