[WIP] perf: add Gluon MoE kernels for GPT-OSS #314
Adds a new ``Mxfp4Fp8TritonKernelBackend`` that consumes AMD-Quark
``w_mxfp4_a_fp8`` checkpoints (e.g. ``amd/gpt-oss-120b-w-mxfp4-a-fp8``):
fp8 e4m3 activations with per-tensor static scales paired with the
existing mxfp4 weight path through ``triton_kernels.matmul``.
Pieces:
* ``Mxfp4Config`` learns to detect Quark ``w_mxfp4_a_fp8`` and surface
an ``is_w4a8_fp8`` flag plus the excluded layers; ``mxfp4`` selector
preference promotes the new backend ahead of bf16/flashinfer fallbacks.
* New backend allocates per-expert ``w13_input_scale`` /
``w2_input_scale`` parameters, collapses them to per-tensor scales and
builds ``PrecisionConfig(flex_ctx.lhs_data=InFlexData(fp8, scale),
out_dtype=bf16)`` so the GEMM output stays in bf16 for swiglu and the
second quantisation step.
* ``gpt_oss._load_mxfp4_weights`` learns the AMD per-expert checkpoint
layout (``...experts.{e}.{gate_up_proj,down_proj}.{weight,
weight_scale,bias,input_scale}``) and dispatches to a new loader that
shards each expert tensor onto the existing fused ``w13_*`` / ``w2_*``
parameters.
* Fix a pre-existing torch / tokenspeed_triton ABI mismatch that
segfaults ``libtriton.so`` whenever it is loaded after torch by
importing ``tokenspeed_kernel`` first in ``tokenspeed/__init__.py``
and ``tokenspeed.runtime.utils.common`` (and in the affected
test files).
* New unit tests in ``test/runtime/test_mxfp4_fp8_backend.py`` cover
Quark detection, the override path and the per-expert input-scale
loader.
Eval (gpqa_diamond, medium reasoning, TP=2): mean_acc 0.6616,
avg output throughput 79.37 tok/s (vs ~33 tok/s baseline on the same
hardware).
Signed-off-by: kylewng <kylewng@users.noreply.github.com>
Adds the first port of the gfx1250 Gluon MoE example to CDNA4 / gfx950
(MI355). The new kernel sits in
``tokenspeed_kernel.ops.moe.gluon`` and ships a baseline (non-pipelined)
bf16 ragged matmul that uses MFMA v4 + Swizzled Shared layouts. The
launcher transparently falls back to the upstream
``triton_kernels.matmul`` for any unsupported precision (mxfp4 weights,
fp8 activations, fused swiglu, persistent / split-K), keeping the public
op signature stable while we graduate features.
Three kernel specs are registered (matmul_ogs / dispatch_gemm /
gemm_combine) and gated behind a ``TOKENSPEED_MOE_GLUON=1`` env knob:
when set, priority is ``Priority.SPECIALIZED + 1`` (above triton_kernels);
otherwise priority is ``Priority.PORTABLE + 1`` so default behavior is
unchanged.
Includes:
* tokenspeed_kernel/ops/moe/gluon.py
+ tokenspeed_kernel/ops/moe/__init__.py wiring
* tokenspeed-kernel/test/ops/test_moe_gluon.py
- 3 correctness tests vs torch reference (passes within bf16 tol)
- 1 selection test under TOKENSPEED_MOE_GLUON=1
- 1 fallback-for-mxfp4 test
* tokenspeed-kernel/benchmarks/moe_gluon_microbench.py
- block_m/block_n/block_k/num_warps grid sweep against the upstream
triton_kernels baseline
* task-progress-2.md documenting API/layout differences from gfx1250 to
CDNA4, perf vs baseline (~46% on 2880x2880x32E) and the explicit
follow-ups needed to graduate to mxfp4 + software-pipelined paths.
Signed-off-by: Kyle Wang <kwang102@amd.com>
Per `TASKS.md:31-38`, finish the MI355 Gluon MoE kernel work:
* Re-organise `_pipelined_moe_kernel` as a single `gluon.jit` body that
feeds three public launchers via `constexpr` flags
(`HAS_BIAS`, `HAS_GATHER`, `HAS_SCATTER`, `DO_SWIGLU`,
`APPLY_GATE_SCAL`):
- `gluon_bf16_gating_gemm` (dense bf16xbf16 GEMM)
- `gluon_bf16_dispatch_swiglu` (dispatch + 1st GEMM + SwiGLU)
- `gluon_bf16_combine` (2nd GEMM + scatter combine)
* Add a register-staged software pipeline (prefetch the next K tile into
registers while the current MFMA executes) modelled on
`triton-450/.../moe_gfx1250.py::MoEPipelinedProgram`. We initially
tried `buffer_load_to_shared` + `commit_group/wait_group` like
`f16_gemm_gfx950.py`, but Triton/MLIR can't currently legalise the
per-expert dynamic base pointer through the async-DMA op; the
register-staged variant already clears the perf bar so async LDS DMA
is left as a follow-up.
* Add `static_profile` / `assert_no_spills` (port of
`gfx1250_utils.py`) and wire them into both the unit test
(`test_no_register_spill`) and the microbench summary so any future
change that introduces sgpr/vgpr spills fails CI.
* New shape-aware `_autotune_block` heuristic obtained from a
microbench sweep: dense GEMM uses BLOCK_N=64 to keep grid_n high;
SwiGLU paths use 64x64x32 + 4 warps to leave VGPR headroom for the
reduce.
* Update `test_moe_gluon.py` with correctness coverage for all three
launchers + the no-spill assertion, and update
`benchmarks/moe_gluon_microbench.py` to compare each kernel against
its proper baseline (upstream `triton_kernels.matmul`, with
`FusedActivation` for SwiGLU and `scatter_indx` for combine) and
dump GPR profiles.
Microbench (MI355, container `kylewng_triton_dev_mi355_1505_dev`):
  Kernel 1 (gating GEMM)          M=512: 2.39x   M=1024: 2.05x   M=4096: 1.05x
  Kernel 2 (dispatch+SwiGLU)      M=512: 1.16x   M=1024: 1.13x
  Kernel 3 (2nd GEMM + combine)   M=512: 1.42x   M=1024: 1.41x
All compiled variants: sgpr_spill = 0, vgpr_spill = 0, scratch = 0.
Signed-off-by: Kyle Wang <kwang102@amd.com>
Per `TASKS.md:39-41`, retune the MI355 Gluon MoE kernels around the real
gpt-oss-120b MoE GEMM dimensions (`H=I=2880, E=128, topk=4` from the HF
config) for both decode (`B in {1, 32, 64}`) and prefill
(`B in {1024, 4096, 8192}`).
Key changes:
* Active-expert remap (decode-critical). The unified
`_pipelined_moe_kernel` now takes an optional `expert_remap_ptr` +
`HAS_EXPERT_REMAP` constexpr. When the ragged metadata is sparse
(e.g. decode `B=1` activates only 4/128 experts) the launcher builds a
dense i32 index of *active* experts and the kernel does a scalar
`gl.load(expert_remap_ptr + compact_idx)` to recover the real expert
id (mirroring `moe_gfx1250.py::gl.load(XSliceSizes + expt_id)`). The
compact index drives the M-offset; the real expert id only feeds the
W / bias indexing. All-active prefill keeps the zero-overhead
`expert_id = compact_idx` path.
* Shape-aware autotuner. `_autotune_block(M, N, K, *, do_swiglu, ragged)`
now branches on three axes -- dense gating GEMM, fused SwiGLU 1st GEMM
and ragged 2nd GEMM -- with thresholds picked from a microbench sweep
on the gpt-oss shapes. Specifically:
- dense gating: 64/128 BM x 64 BN x 64 BK depending on M
- SwiGLU: 64x128x32 (M<=8192) / 128x128x32 (M>8192), 4 warps
- ragged combine: 64x128x32 (M<=8192) / 128x128x32 (M>8192). 256x256
tiles get 295 TFLOPs but spill 210 VGPRs -> backed off to 128x128
which keeps every variant at 0 spill / scratch.
* Microbench rewrite (`benchmarks/moe_gluon_microbench.py`) sweeps the
six gpt-oss-120b token batches (`1, 32, 64, 1024, 4096, 8192`) for
each kernel against the upstream `triton_kernels.matmul`/+SwiGLU
fused-activation/scatter baseline. Decode batches build a sparse
ragged metadata (`min(B*topk, E)` active experts) so the kernel
exercises the new remap path; prefill batches saturate all 128
experts. The static GPR profile dump now hard-fails the script if any
variant spills.
* Five new unit tests (`test/ops/test_moe_gluon.py`):
- `test_gpt_oss_decode_remap[B=1, 32, 64]` -- numerical equivalence
of the active-expert remap path.
- `test_gpt_oss_no_spill[B=1, 64]` -- compile every kernel at the
real `H=I=2880, E=128` dimensions and assert no sgpr/vgpr spill
and zero scratch.
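The active-expert remap described above can be sketched on the host side as follows. This is a minimal illustration in plain Python, not the launcher's actual code: `build_expert_remap` and its return shape are hypothetical names, and the real launcher builds an i32 device tensor rather than Python lists.

```python
from itertools import accumulate


def build_expert_remap(slice_sizes):
    """Return (remap, m_offsets) for the experts that received tokens.

    remap[compact_idx] is the real expert id; m_offsets[compact_idx] is the
    first row of that expert's slice in the densely packed activations.
    (Hypothetical helper, for illustration only.)
    """
    remap = [e for e, n in enumerate(slice_sizes) if n > 0]
    active_sizes = [slice_sizes[e] for e in remap]
    # Exclusive prefix sum over the active slices only.
    m_offsets = list(accumulate(active_sizes, initial=0))[:-1]
    return remap, m_offsets


# Decode B=1, topk=4, E=128: only four experts are active.
slice_sizes = [0] * 128
for expert in (3, 17, 64, 101):
    slice_sizes[expert] = 1

remap, m_offsets = build_expert_remap(slice_sizes)
for compact_idx, expert_id in enumerate(remap):
    # compact_idx picks the rows; expert_id picks W[expert_id] / bias[expert_id].
    print(compact_idx, expert_id, m_offsets[compact_idx])
```

With this layout the grid only needs one program group per active expert (4 instead of 128 at `B=1`), which is the effect the remap is after.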
Microbench (MI355, container `kylewng_triton_dev_mi355_1505_dev`):
  Kernel 1 (gating GEMM)          B=1..8192   speedup_vs_triton 1.62x .. 2.49x
  Kernel 2 (dispatch+SwiGLU)
    B=1          M_d=  256   speedup 0.78x (was 0.09x before remap)
    B=32..1024   M_d= 8192           0.58x
    B=4096       M_d=16384           0.64x
    B=8192       M_d=32768           0.59x
  Kernel 3 (2nd GEMM + combine)
    B=1          M_d=  256   speedup 0.80x (was 0.12x before remap)
    B=32..1024   M_d= 8192           0.53x
    B=4096       M_d=16384           0.56x
    B=8192       M_d=32768           0.48x
Static GPR / spill: 12/12 variants clean (sgpr<=42, vgpr<=200,
sgpr_spill = vgpr_spill = scratch = 0).
Kernel 1 still beats upstream comfortably on all batches. For Kernels 2/3
the remap closes most of the previous decode hole (B=1: 0.09x -> 0.78x and
0.12x -> 0.80x), but at `M_d >= 8192` they cap at ~0.5-0.6x baseline. We
attribute the residual gap to the missing multi-buffer LDS pipeline
(`buffer_load_to_shared` + `commit_group/wait_group`), which is now the
top follow-up in `task-progress-2.md`.
Signed-off-by: Kyle Wang <kwang102@amd.com>
CDNA4's scaled MFMA op (gl.amd.cdna4.mfma_scaled) has instruction shape
[16, 16, 128], so any kernel that swaps in scaled MFMA (the mxfp4 weight
+ fp8/bf16 activation path used by gpt-oss-120b) must use BLOCK_K that
is a multiple of 128 with a floor of 128.
* _autotune_block(..., scaled_mfma=False): when True, promote BLOCK_K to
>= 128 and to the next multiple of 128. Regular MFMA path (bf16) keeps
the BK=32/64 fast path because a microbench sweep over BK in
{64, 128, 256} showed BK=128 makes the register-staged pipeline 2-3x
slower on bf16 due to VGPR pressure (per-tile registers scale linearly
with BK and our prefetch holds 2 tiles in flight).
* _launch_pipelined(..., scaled_mfma=False): assert BLOCK_K % MFMA_K
== 0 and BLOCK_K >= 128 when scaled, BLOCK_M % 16 == 0 always.
* 6 new UTs (21 total passing) covering autotune contract,
launcher rejection of bad BK, and bf16 kernel compile @ BK=128.
* Microbench now also prints the scaled-MFMA autotune preview after
the static-profile spill check.
bf16 throughput and spill profile are unchanged; this commit only pins
the constraint so it can't silently regress when the mxfp4/fp8 scaled
path lands.
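The BLOCK_K contract above can be sketched as follows. This is a minimal illustration under stated assumptions: `promote_block_k` and `check_launch` are hypothetical stand-ins for the relevant parts of `_autotune_block` and `_launch_pipelined`, whose real signatures take more arguments.

```python
MFMA_SCALED_K = 128  # K extent of the 16x16x128 scaled MFMA instruction


def promote_block_k(block_k, scaled_mfma=False):
    """Round BLOCK_K up to a multiple of 128 (floor 128) on the scaled path."""
    if not scaled_mfma:
        return block_k  # bf16 path keeps its BK=32/64 fast path
    rounded = -(-block_k // MFMA_SCALED_K) * MFMA_SCALED_K  # ceil to multiple
    return max(MFMA_SCALED_K, rounded)


def check_launch(block_m, block_k, scaled_mfma=False):
    """Launcher-side guard mirroring the asserts described above."""
    assert block_m % 16 == 0, "BLOCK_M must be a multiple of 16"
    if scaled_mfma:
        assert block_k >= MFMA_SCALED_K, "scaled MFMA needs BLOCK_K >= 128"
        assert block_k % MFMA_SCALED_K == 0, "BLOCK_K must be a multiple of 128"


for bk in (32, 64, 128, 129, 256):
    print(bk, promote_block_k(bk), promote_block_k(bk, scaled_mfma=True))
```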
Signed-off-by: kwang102 <kwang102@users.noreply.github.com>
Adds a sibling `_pipelined_moe_kernel_scaled` that uses
`gl.amd.cdna4.mfma_scaled` (16x16x128, k_width=16) to support three A/W
dtype combinations through the same body:
* mxfp4 (e2m1) x mxfp4 (e2m1) -- both with e8m0 block scales
* fp8 (e4m3) x mxfp4 (e2m1) -- A with a host-side global scalar scale
* fp8 (e5m2) x mxfp4 (e2m1) -- same as above
The variant is selected via constexpr flags (`A_FORMAT`, `B_FORMAT`,
`HAS_A_BLOCK_SCALE`) so a single kernel covers all paths, matching the
upstream `triton_kernels._matmul.py` pattern. Block scales reuse the
`get_mfma_scale_layout` distributed layout for direct gl.load (same recipe
as `test_amd_mfma_scaled` in triton-450).
Three new public launchers (`gluon_mxfp_gating_gemm`,
`gluon_mxfp_dispatch_swiglu`, `gluon_mxfp_combine`) parallel the existing
bf16 launchers, including ragged metadata / SwiGLU / scatter-combine. The
autotuner already promoted BLOCK_K to 128 for `scaled_mfma=True` in
Update 3, so no further tuning changes were needed.
Along the way fix a latent OOB-write bug in the bf16 kernel: the store mask
used `OUT_BLOCK_N * grid_n` instead of the actual output N, so shapes where
`N % BLOCK_N != 0` (e.g. `N=80, BLOCK_N=64`) wrote into the next row of the
output tensor. A targeted regression test covers it.
Microbench on gpt-oss-120b dims (H=I=2880, E=128, topk=4) shows the new
scaled paths beat the bf16 kernels on all three pieces at prefill scale:
  K1 gating          @ B=8192:  62 -> 132/124 TFLOPs (2.1x / 2.0x)
  K2 dispatch+SwiGLU @ B=8192: 364 -> 436/408 TFLOPs (1.20x / 1.12x)
  K3 combine         @ B=8192: 274 -> 402/378 TFLOPs (1.47x / 1.38x)
All 11 scaled kernel variants compile with zero sgpr / vgpr spills and zero
scratch (occupancy 3-5).
Test suite: 34 passing (13 new -- multi-shape mxfp4xmxfp4, fp8xmxfp4 e4m3 &
e5m2, dispatch+SwiGLU, ragged combine, OOB regression, no-spill profile).
True LDS multi-buffer pipelining stays as the top follow-up (documented in
task-progress-2.md).
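To make the OOB-write bug concrete, the following plain-Python simulation lists the flat offsets a masked row-major store would touch. It is only a sketch: it assumes a contiguous output with row stride `N`, and `stored_offsets` is a hypothetical helper, not kernel code.

```python
def stored_offsets(m, n, block_n, bound):
    """Flat offsets written by a row-major store masked with `col < bound`."""
    grid_n = -(-n // block_n)  # number of column tiles (ceil-div)
    offsets = []
    for row in range(m):
        for pid_n in range(grid_n):
            for col in range(pid_n * block_n, (pid_n + 1) * block_n):
                if col < bound:
                    offsets.append(row * n + col)  # row stride assumed == n
    return offsets


M, N, BLOCK_N = 2, 80, 64
grid_n = -(-N // BLOCK_N)

buggy = stored_offsets(M, N, BLOCK_N, bound=BLOCK_N * grid_n)  # mask uses 128
fixed = stored_offsets(M, N, BLOCK_N, bound=N)                 # mask uses 80

# Row 0 of the buggy variant spills into row 1 (offsets 80..127).
print(sorted(o for o in buggy[:BLOCK_N * grid_n] if o >= N)[:3])
# The last row of the buggy variant runs past the end of the tensor.
print(max(buggy), M * N)
```

With the padded bound, columns 80..127 of each row alias the start of the next row; with the logical `N` every offset is written exactly once.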
Signed-off-by: kwang102 <kwang102@users.noreply.github.com>
Brings in main (~70 commits) plus the fp8 x mxfp4 cleanup pass on top of
our shared base commit. Conflicts in mxfp4.py and gpt_oss.py resolved by
taking the incoming cleanup (drops unused excluded_layers plumbing and adds
is_amd platform guard).
Signed-off-by: Kyle Wang <ec1wng@gmail.com>
Re-rank the Gluon MoE GEMM registration so the selector picks it over
the upstream triton_kernels MoE on MI355 (gfx950) by default. Per the
selection objective semantics this kernel is *not* a portability
fallback -- it is a CDNA4-specific throughput / latency optimisation,
so swap the tags accordingly and bump the priority to
``Priority.SPECIALIZED + 2`` (14, above triton_kernels' 10). The
``TOKENSPEED_MOE_GLUON`` env now acts as a *disable* knob (set to
``0/false/no/off`` to drop the priority below triton_kernels for A/B
comparison without rebuilding).
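The disable-knob semantics can be sketched as below. This is an assumption-laden illustration: `gluon_priority` is a hypothetical helper, the numeric value 12 for `Priority.SPECIALIZED` is inferred from "SPECIALIZED + 2 = 14", the demoted value is arbitrary (anything below 10 would do), and case-insensitive parsing is a guess.

```python
import os

PRIORITY_SPECIALIZED = 12      # inferred from "SPECIALIZED + 2" == 14
TRITON_KERNELS_PRIORITY = 10   # priority quoted for triton_kernels
_DISABLE_VALUES = {"0", "false", "no", "off"}


def gluon_priority(environ=None):
    """Priority for the Gluon MoE registration (hypothetical helper)."""
    env = os.environ if environ is None else environ
    raw = env.get("TOKENSPEED_MOE_GLUON", "1").strip().lower()
    if raw in _DISABLE_VALUES:
        # Demote below triton_kernels so the selector prefers upstream.
        return TRITON_KERNELS_PRIORITY - 1
    return PRIORITY_SPECIALIZED + 2


print(gluon_priority({}))                               # default: Gluon wins
print(gluon_priority({"TOKENSPEED_MOE_GLUON": "off"}))  # A/B: upstream wins
```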
Also harden ``_gluon_bf16_ragged_matmul`` so when it falls back to
``triton_kernels.matmul`` for unsupported precisions (mxfp4 weight
scales, fp8 activation flex data) it forwards extra epilogue kwargs
(``gammas``, ``betas``, ``out_alpha``, ``c``/``c_acc_in``,
``fused_comm``, ``epilogue``, ``b_ragged_metadata``, ...). Previously
``**_unused`` silently dropped them, which would have miscomputed the
gemm+combine path the moment the selector started routing through us.
Verified on the MI355 dev container:
* ``select_kernel("moe", "experts", bf16, features={"ragged_metadata",
"dispatch_gemm"})`` now returns ``triton_kernels_gluon_dispatch_gemm``.
* ``tokenspeed-kernel/test/ops/test_moe_gluon.py``: 51 passed.
* ``tokenspeed-kernel/test/test_callsite_selection.py``: 36 passed,
14 skipped (NV-only fixtures, our gluon kernel filtered by
``CapabilityRequirement(amd, gfx950)``).
* ``test/runtime/test_mxfp4_fp8_backend.py``: 11 passed.
* ``test/runtime/test_mxfp4_weights.py``: 1 passed.
Signed-off-by: Kyle Wang <ec1wng@gmail.com>
Document the steps for ``TASKS.md:68-73`` Update 6:
* The merge of ``origin/kylewng/fp8_mxfp4_moe`` (cleanup pass + ~70 main
commits) and the two conflict resolutions (mxfp4.py, gpt_oss.py).
* The selector ranking (``oracle, objective, priority`` lex sort) and why
we default Gluon to ``Priority.SPECIALIZED + 2 = 14`` (above
triton_kernels' ``PERFORMANT + 2 = 10``) with throughput/latency tags.
* The hidden ``gammas`` passthrough bug that surfaced once Gluon started
winning combine selection, and how the adapter now forwards every
non-tuning kwarg to the upstream fallback.
* Honest assessment that the production fp8 × mxfp4 path on gpt-oss-120b
currently still falls back to ``triton_kernels.matmul`` (because
``Mxfp4TritonKernelBackend.process_weights_after_loading`` emits
``triton_kernels.Tensor``-wrapped weights + an mxfp4 scaled PrecisionConfig
that our adapter cannot consume natively yet).
* Reproduction recipe for
``test/ci/eval/gpt-oss-120b-mxfp4-evalscope-gpqa-diamond.yaml`` plus the
Gluon-on/off A/B knob via ``TOKENSPEED_MOE_GLUON``.
No code changes besides the doc; companion code commits are ``86d22e4``
(merge) and ``f25dd60`` (registration tweak).
Signed-off-by: Kyle Wang <ec1wng@gmail.com>
When the gluon adapter falls back to ``triton_kernels.matmul`` for
unsupported precisions (mxfp4 weight scale + fp8 flex_ctx), it goes
straight to the underlying ``triton_kernels.matmul`` -- bypassing the
``ops.moe.triton_kernels._matmul`` wrapper that is normally registered
for the ``moe.experts`` family. That wrapper performs a critical
post-processing step on the combine path:
```
if scatter_indx is not None and n_expts_act > 1:
    out = out.view(n_tokens, n_expts_act, out.shape[-1]).sum(dim=1)
return out
```
Without this fold-back, the adapter returns
``(n_tokens * n_expts_act, N)`` while the runtime expects
``(n_tokens, N)`` -- a silent shape divergence that surfaced the moment
the priority bump in ``f25dd60`` started routing combine through us.
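For readers without a tensor library at hand, the same fold-back can be sketched on nested lists. This is only an illustration of the shape change; `fold_topk` is a hypothetical name and assumes, like the `view(...)` above, that each token's `n_expts_act` rows are stored consecutively.

```python
def fold_topk(rows, n_tokens, n_expts_act):
    """Sum each token's n_expts_act consecutive rows into a single row.

    Input shape:  (n_tokens * n_expts_act, N)
    Output shape: (n_tokens, N)
    """
    assert len(rows) == n_tokens * n_expts_act
    out = []
    for t in range(n_tokens):
        group = rows[t * n_expts_act:(t + 1) * n_expts_act]
        # zip(*group) walks the N columns; sum adds the per-expert values.
        out.append([sum(col) for col in zip(*group)])
    return out


rows = [[1, 2], [10, 20],   # token 0, experts a and b
        [3, 4], [30, 40]]   # token 1, experts a and b
folded = fold_topk(rows, n_tokens=2, n_expts_act=2)
print(len(rows), "->", len(folded), folded)
```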
Bit-exact A/B verification on a synthetic gpt-oss-shaped fp8 x mxfp4
combine call (H=I=256, E=4, top_k=2, n_tokens=64):
TOKENSPEED_MOE_GLUON=0 -> shape=(64, 256) abs_mean=1248.0
TOKENSPEED_MOE_GLUON=1 -> shape=(64, 256) abs_mean=1248.0
max_abs_diff = 0.0 bit-exact match: True
This indicates the production gpt-oss-120b path does not regress under
the new selector winner.
Tests: 51 + 12 passing (test_moe_gluon.py, test_mxfp4_fp8_backend.py,
test_mxfp4_weights.py).
Signed-off-by: Kyle Wang <ec1wng@gmail.com>
Append the synthetic gpt-oss-shaped fp8 x mxfp4 combine A/B test that
proved ``TOKENSPEED_MOE_GLUON=0`` (triton_kernels direct) and the default
``TOKENSPEED_MOE_GLUON=1`` (selector picks gluon, adapter falls back)
produce bit-exact identical outputs (max_abs_diff = 0.0).
Signed-off-by: Kyle Wang <ec1wng@gmail.com>
…ed slice_size<=16
Empirical sweep on TP=1 prod shape (H=2880 I=2880 E=128 top_k=4) via
check_gluon_decode_perf.py --prod-block-m 32, over 29 (bm,bn,bk,nw)
configs for both dispatch+swiglu and combine. (32,128,256,4) is the
cross-batch + cross-kernel sweet spot:
dispatch+swiglu best per B (us):
                            B=1     B=4     B=8
  new tier (32,128,256,4)   75.0    76.8   107.0
  old pick (32,128,256,8)   77.4    81.9   116.9
  legacy   (64,128,512,8)   79.8   106.0   158.1   (SPILL)
combine best per B (us):
                            B=1     B=4     B=8
  new tier (32,128,256,4)   73.2    75.8    86.8
  old pick (32,128,256,8)   73.2    74.4    91.0
NW=4 beats NW=8 by 5-9us on B=4/8 (lower SGPR/VGPR pressure raises
occupancy from 3 to 4); BK=256 matches the mxfp4 cacheline
(128 B at 0.5 B per e2m1 element = 256 elements); BM=32 caps per-tile
mask waste at 31/32
(vs 63/64 under BM=64).
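The tile arithmetic in this paragraph can be checked with a few lines of Python. The helper names are illustrative only; the 128 B cacheline and the 0.5 B per mxfp4 element come from the text above, and "mask waste" is taken to mean the worst case of a tile holding a single valid row.

```python
from fractions import Fraction

CACHELINE_BYTES = 128
MXFP4_BYTES_PER_ELEM = Fraction(1, 2)  # one e2m1 value is a 4-bit nibble


def elems_per_cacheline():
    """Number of mxfp4 elements that fit in one cacheline."""
    return int(CACHELINE_BYTES / MXFP4_BYTES_PER_ELEM)


def worst_case_mask_waste(block_m):
    """Fraction of masked-off rows when a tile holds a single valid row."""
    return Fraction(block_m - 1, block_m)


print(elems_per_cacheline())        # K elements covered by one cacheline
print(worst_case_mask_waste(32))    # BM=32
print(worst_case_mask_waste(64))    # BM=64
```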
A BM=16 sub-tier for fp8 X + slice_size<=2 was tried (sweep showed
71.9us vs 75us at dispatch B=1) and reverted: production E2E c=1 TPOT
regressed from 7.33ms to 7.56ms across 3 stable runs even though
microbench was a clear win. The bench shape sets slice_sizes[i]=
per_expert_padded while production has slice_sizes[i]=actual, so the
microbench measures only MFMA pipeline efficiency at smaller tile
while production hits a different bottleneck (LDS / gather path /
nbuf). Profile c=1 before re-trying. The disabled-sub-tier comment
points to this trade-off so the experiment isn't repeated blindly.
Implementation: thread per-expert M hint through _autotune_block via
a new slice_size kwarg + host-side helper that reads
RaggedTensorMetadata.expected_slice_size (or the
max(1, M // n_slices) fallback). No D2H sync, graph capture stays
intact.
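A minimal sketch of the host-side hint, assuming a stand-in dataclass for the metadata (the real `RaggedTensorMetadata` has more fields, and `slice_size_hint` is a hypothetical name). Everything here is plain Python on host values, which is why no device-to-host sync is needed.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class RaggedMeta:
    """Hypothetical stand-in for RaggedTensorMetadata."""
    n_slices: int
    expected_slice_size: Optional[int] = None


def slice_size_hint(m, meta):
    """Per-expert M hint: prefer the recorded value, else an even split."""
    if meta.expected_slice_size is not None:
        return meta.expected_slice_size
    return max(1, m // meta.n_slices)


# Decode B=1, topk=4, E=128: M=4 routed rows over 128 slices.
print(slice_size_hint(4, RaggedMeta(n_slices=128)))
# Prefill with a recorded expectation wins over the fallback.
print(slice_size_hint(32768, RaggedMeta(n_slices=128, expected_slice_size=256)))
```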
E2E gpt-oss-120b decode TP=1 (single MI355, in=512/out=128, 100 req,
3-run median; prior 1-run measurements were ±7% noise on c=8):
                  Gluon ON (this CL)   Gluon OFF (triton_kernels)
  c=8 tok/s       779                  776     (parity within noise)
  c=8 TPOT (ms)   9.32                 9.30
  c=1 tok/s       126.5                141.0   (-10%)
  c=1 TPOT (ms)   7.33                 6.44    (+14%)
c=8 closes the previous -13% Gluon-ON gap to parity. c=1 still trails
by ~10% -- the remaining gap is not mask waste (already capped at
31/32) nor compute, and will be addressed by profiling + targeted
follow-ups (multi-tile-per-CTA persistent path, bypass scale-LDS for
tiny tiles, num_buffers tuning).
All existing kernel tests pass (test_moe_gluon.py: 57 passed) and
graph-capture bit-exact checks remain green
(check_gluon_graph_capture.py: dispatch+swiglu and combine both
eager<->graph diff = 0).
Signed-off-by: Kyle Wang <ec1wng@gmail.com>
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: c3b635659e
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
    layer.w2_weight = Parameter(w2_w_padded, requires_grad=False)
    layer.w2_weight_scale = Parameter(w2_s_padded, requires_grad=False)
Pad the W2 bias when padding the output dimension
For GPT-OSS W4A8 on AMD, with_bias=True and hidden_size=2880, so this pads w2_weight/w2_weight_scale to 2944 columns but leaves w2_weight_bias at 2880. The new Gluon combine epilogue derives N from the padded W and masks bias loads with bias_offs < N, so the final tile reads w2_weight_bias[..., 2880:2944] out of bounds before the output is trimmed. Please pad the W2 bias with zeros alongside the weights, or make the bias mask use the logical N.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 3fe6c0e25c
    "tokenspeed.runtime.layers.moe.backends.mxfp4.gluon_kernel",
    "Mxfp4GluonKernelBackend",
Register an actual Gluon runtime backend
On AMD with MXFP4 and the default auto backend, selector.py now tries gluon_kernel before triton_kernel, so this entry makes ensure_backend_family_registered() import tokenspeed.runtime.layers.moe.backends.mxfp4.gluon_kernel. I checked the tree with repo-wide search for Mxfp4GluonKernelBackend/gluon_kernel, and no such module or class exists; the selector only catches KeyError, so this path raises ModuleNotFoundError during MXFP4 MoE initialization instead of falling back to the existing Triton backend.
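One way to make the preference list tolerate a missing module is to treat import failures like a missing registry entry. The sketch below is generic Python, not the repo's `selector.py`: `load_first_available` is a hypothetical helper, and the stdlib module stands in for the existing Triton backend.

```python
import importlib


def load_first_available(candidates):
    """Return the first (module, class) pair that can actually be imported."""
    for module_name, class_name in candidates:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, class_name)
        except (ImportError, AttributeError, KeyError):
            # ModuleNotFoundError is a subclass of ImportError, so a backend
            # whose module does not exist is skipped instead of crashing.
            continue
    raise RuntimeError("no MoE backend could be loaded")


# The first entry does not exist; the second stands in for a real backend.
backend = load_first_available([
    ("no_such_package_for_demo.gluon_kernel", "Mxfp4GluonKernelBackend"),
    ("collections", "OrderedDict"),
])
print(backend.__name__)
```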
    from tokenspeed_triton.experimental import gluon
    from tokenspeed_triton.experimental.gluon import language as gl
    from tokenspeed_triton.language.core import _aggregate as aggregate
Keep tokenspeed_triton imports behind _triton
These direct tokenspeed_triton imports violate the dependency boundary in /workspace/tokenspeed/AGENTS.md that says all direct tokenspeed-triton imports must happen in _triton.py and be re-imported elsewhere. This new module is imported unconditionally by tokenspeed_kernel.ops.moe.__init__, so bypassing the central shim makes the Gluon dependency harder to swap or gate consistently with the rest of tokenspeed-kernel; import gluon/gl through tokenspeed_kernel._triton and add any missing exports there.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f5141d6f51
    if not current_platform().is_amd:
        return False
When MXFP4 W4A8 runs on AMD devices other than gfx950 (for example CDNA3/gfx94), this supports() check accepts the new gluon_kernel backend because current_platform().is_amd is true, and the auto preference now tries it before triton_kernel. However the Gluon kernels themselves are registered with CapabilityRequirement(... min_arch_version=ArchVersion(9, 5), max_arch_version=ArchVersion(9, 5)), so the first moe_fused call on those AMD platforms has no matching gluon_mxfp4_fp8_fused_moe candidate instead of falling back to the Triton backend. Please make the runtime backend support predicate match the kernel capability (or otherwise keep non-gfx950 on triton_kernel).
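A predicate that matches the kernel capability might look like the sketch below. All names here are hypothetical (`Platform`, `gluon_supported`, `pick_backend`); the only facts taken from the comment are the `is_amd` check and the `ArchVersion(9, 5)` bound on the kernel registration.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Platform:
    """Hypothetical stand-in for current_platform()."""
    is_amd: bool
    arch: Tuple[int, int]  # (major, minor), e.g. (9, 5) for gfx950


GLUON_ARCH = (9, 5)  # min and max arch of the Gluon kernel registration


def gluon_supported(platform):
    """Runtime predicate aligned with the kernel's CapabilityRequirement."""
    return platform.is_amd and platform.arch == GLUON_ARCH


def pick_backend(platform):
    """Prefer Gluon only where its kernels can actually be selected."""
    return "gluon_kernel" if gluon_supported(platform) else "triton_kernel"


print(pick_backend(Platform(is_amd=True, arch=(9, 5))))   # MI355 / gfx950
print(pick_backend(Platform(is_amd=True, arch=(9, 4))))   # CDNA3
print(pick_backend(Platform(is_amd=False, arch=(9, 0))))  # non-AMD
```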
Added support for Gluon MoE kernels and applied optimizations including: