Add TopK Gating Softmax Kernel#426

Open: amd-wsung102 wants to merge 5 commits into ROCm:main from amd-wsung102:topkgatingsoftmax

amd-wsung102 commented Apr 22, 2026

Motivation

FlyDSL does not yet have a TopK Gating Softmax kernel, although models such as GPT-OSS 120B rely on this operation.

Relevant Files

kernels/topk_gating_softmax_kernel.py - topk gating softmax kernel
tests/kernels/test_topk_gating_softmax.py - unit test for the kernel
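
For readers who want the operation spelled out, here is a minimal pure-Python sketch of what a top-k gating softmax computes for a single token. It assumes the common formulation (softmax over all expert logits, then keep the `k` largest probabilities, optionally renormalized). The function name `topk_gating_softmax` and the `renormalize` flag are hypothetical and are not taken from the kernel or its unit test.

```python
import math

def topk_gating_softmax(logits, k, renormalize=False):
    """Hypothetical reference for one token: softmax over experts, keep top k."""
    # Numerically stable softmax: subtract the row maximum before exponentiating.
    row_max = max(logits)
    exps = [math.exp(x - row_max) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Expert ids ordered by descending probability; ties go to the lower id.
    order = sorted(range(len(probs)), key=lambda i: (-probs[i], i))
    ids = order[:k]
    weights = [probs[i] for i in ids]

    # Optional step (an assumption): rescale the kept weights to sum to 1.
    if renormalize:
        kept = sum(weights)
        weights = [w / kept for w in weights]
    return weights, ids
```

Calling `topk_gating_softmax([1.0, 3.0, 2.0, 0.0], k=2)` selects experts 1 and 2, in that order. A unit test could compare a GPU kernel's output against a reference of this shape within a bf16-appropriate tolerance.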

Optimizations

  • Token packing: each block processes 32 tokens, similar to vLLM's multi-token-per-block style
  • Sub-warp reductions: every reduction is a butterfly shuffle of width 8 (THREADS_PER_TOKEN), eliminating the need for shared memory and barriers
  • Wide vector loads: each thread issues one 128-bit BufferCopy that pulls all bf16 experts at once
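
The butterfly reduction in the second bullet can be illustrated on the host. The sketch below simulates an XOR-butterfly all-reduce across groups of 8 lanes in plain Python. It is not FlyDSL code: `butterfly_reduce` is a hypothetical helper, and the list stands in for per-lane registers. With a width of 8, the loop runs three rounds (offsets 4, 2, 1), and every lane in a group ends up holding the group's result, which is why no shared memory or barrier is needed.

```python
THREADS_PER_TOKEN = 8  # sub-warp width named in the PR description

def butterfly_reduce(lane_values, op, width=THREADS_PER_TOKEN):
    """Simulate an XOR-butterfly all-reduce within each `width`-lane group."""
    n = len(lane_values)
    # The width must be a power of two and must tile the lanes evenly.
    assert width > 0 and width & (width - 1) == 0 and n % width == 0
    vals = list(lane_values)
    offset = width // 2
    while offset >= 1:
        # Each lane combines its value with its partner's (lane XOR offset).
        # Because offset < width, the partner is always in the same group.
        vals = [op(vals[lane], vals[lane ^ offset]) for lane in range(n)]
        offset //= 2
    return vals
```

For example, `butterfly_reduce(list(range(16)), max)` returns eight 7s followed by eight 15s: two independent 8-lane groups, each reduced to its own maximum. On the load side, a 128-bit access carries 128 / 16 = 8 bf16 values per thread.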

Test Result - Kernel Level

Tested on MI350; the unit test passed.
In median latency, the FlyDSL kernel is roughly 1.3x-1.55x faster than the vLLM TopKGatingSoftmax kernel and roughly 1.1x-1.6x faster than the current AITER version.

| Implementation | Number of Blocks | Number of Tokens | Median (us) | p99 (us) | Speedup over vLLM | Speedup over AITER |
|---|---|---|---|---|---|---|
| vLLM | 1 | T ≤ 16 | 4.96 | 5.56 | baseline | |
| vLLM | 2 | T = 32 | 5.4 | 5.96 | baseline | |
| vLLM | 4 | T = 64 | 5.56 | 6.16 | baseline | |
| vLLM | 8 | T = 128 | 6 | 7.08 | baseline | |
| vLLM | 256 | T = 4096 | 6.2 | 7.28 | baseline | |
| vLLM | 1024 | T = 16384 | 7.76 | 9.84 | baseline | |
| AITER | 1 | T = 32 | 5.58 | 5.68 | | baseline |
| AITER | 1 | T = 64 | 5.67 | 5.84 | | baseline |
| AITER | 2 | T = 128 | 5.88 | 5.98 | | baseline |
| AITER | 64 | T = 4096 | 5.88 | 6.14 | | baseline |
| AITER | 256 | T = 16384 | 6.18 | 6.89 | | baseline |
| FlyDSL | 1 | T ≤ 32 | 3.48 | 3.68 | 1.55x | 1.60x |
| FlyDSL | 2 | T = 64 | 3.92 | 4.12 | 1.42x | 1.45x |
| FlyDSL | 4 | T = 128 | 4.12 | 4.4 | 1.46x | 1.43x |
| FlyDSL | 128 | T = 4096 | 4.72 | 4.96 | 1.31x | 1.25x |
| FlyDSL | 512 | T = 16384 | 5.4 | 5.64 | 1.44x | 1.14x |
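
The speedup columns appear to be ratios of median latencies (baseline median divided by FlyDSL median at the same token count), and the block counts are consistent with a ceiling division of the token count by a fixed tokens-per-block value. The short sketch below reproduces both. The median values are copied from the table. The 32 tokens per block for FlyDSL comes from the description, while the 16 (vLLM) and 64 (AITER) values are inferred from the reported block counts and are assumptions.

```python
import math

# Median latencies in microseconds, copied from the table above.
median_us = {
    "vLLM":   {32: 5.4, 64: 5.56, 128: 6.0, 4096: 6.2, 16384: 7.76},
    "AITER":  {32: 5.58, 64: 5.67, 128: 5.88, 4096: 5.88, 16384: 6.18},
    "FlyDSL": {32: 3.48, 64: 3.92, 128: 4.12, 4096: 4.72, 16384: 5.4},
}

# FlyDSL's 32 is stated in the PR; 16 and 64 are inferred from block counts.
tokens_per_block = {"vLLM": 16, "AITER": 64, "FlyDSL": 32}

def speedup(baseline, tokens):
    """Baseline median divided by FlyDSL median at the same token count."""
    return median_us[baseline][tokens] / median_us["FlyDSL"][tokens]

def num_blocks(impl, tokens):
    """Grid size if each block handles a fixed number of tokens."""
    return math.ceil(tokens / tokens_per_block[impl])

for t in sorted(median_us["FlyDSL"]):
    print(t, num_blocks("FlyDSL", t),
          round(speedup("vLLM", t), 2), round(speedup("AITER", t), 2))
```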

Test Result - E2E GPT-OSS 120B

Config: 1k8k / TP=8 / BS = 4

| Implementation | Kernel | Mean (us) | Median (us) |
|---|---|---|---|
| FlyDSL | `topk_gating_softmax_kernel_0` | 4.64 | 4.64 |
| vLLM | `void vllm::moe::topkGatingSoftmax<std::bfloat16_t, 16, 128, 8, 32, true, 0, (vllm::moe::SharedExpertScoringFunc)0>(std::bfloat16_t const*, bool const*, float*, int, int*, int*, int, int, int, int, ...` | 4.64 | 4.64 |

Submission Checklist

@coderfeli
Collaborator

CI failed. @amd-wsung102
