Feat/step3p5 moe swiglustep #2887

Draft
LJ-underdog wants to merge 6 commits into main from feat/step3p5-moe-swiglustep

Conversation

@LJ-underdog
Contributor

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

…n support

Four bugs fixed in the BF16 no-quant CK 2-stage MoE path on gfx950 (MI350X):

1. Force block_m=128 to select V3 CK kernel
   V1 kernel (block_m=16/64) produces wrong results for inter_dim>192
   on gfx950 due to tile misalignment. Extend workaround to cover both
   preshuffle_on and preshuffle_off paths.

2. Fix preshuffle mode for swiglustep+no-quant in JIT build
   _infer_preshuffle_modes() now compiles both preshuffle variants for
   the swiglustep activation. Fix _build_moe_variant() to pass
   --preshuffle flag for all kernel types (not just fp4x2).

3. Disable CustomAllreduce on gfx950 tp>=2
   CustomAllreduce produces NaN on gfx950 multi-GPU; disable it in
   parallel_state.py.

Add SwigluStep Python support in fused_moe.py:
- swiglustep(gate, up, limit=7.0): silu(gate).clamp(max=7) * up.clamp(±7)
- torch_moe_act(), torch_moe_stage1(): SwigluStep branches
- Exclude SwigluStep from flydsl path (no kernel implementation)

Verified: cos_sim=0.999989 for T=1,4,32,128,512 on gfx950
(E=288, K=8, model_dim=2048, inter_dim=640, bf16 no-quant, preshuffle_on)

Co-Authored-By: Jun Lin <junlin12@amd.com>
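
As a sanity reference, the swiglustep formula quoted above can be sketched in plain PyTorch. This is a minimal sketch following the one-line description in the commit message (silu(gate).clamp(max=7) * up.clamp(±7)); the actual fused_moe.py implementation and its dtype/broadcasting handling may differ.

```python
import torch
import torch.nn.functional as F

def swiglustep(gate: torch.Tensor, up: torch.Tensor, limit: float = 7.0) -> torch.Tensor:
    # silu(gate) clamped from above at `limit`, multiplied by `up`
    # clamped to [-limit, +limit], per the description above.
    return F.silu(gate).clamp(max=limit) * up.clamp(min=-limit, max=limit)

# Saturation behavior: for large gate and up, both factors clamp to 7,
# so the output saturates at 49.
out = swiglustep(torch.full((2, 2), 100.0), torch.full((2, 2), 100.0))
```

Note that unlike plain SwiGLU, the output here is bounded (|out| <= limit**2), which is what makes the activation safe for the clamped Step-3.5-Flash experts.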
Add SwigluStep (sigmoid-gated with ±7 clamp) as a new activation type
for the CK 2-stage MoE kernel, required by Step-3.5-Flash routed experts.

Changes:
- csrc/include/aiter_enum.h: add SwigluStep=3 to ActivationType enum
- csrc/include/rocm_ops.hpp: expose SwigluStep to Python via pybind11
- csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages.cu:
    replace the boolean !activation hack with an explicit map_activation_to_ck_stage1()
    (the old code negated the enum, so !SwigluStep(=3) = 0 = Gelu, selecting the wrong kernel)
    remove !activation from stage2 (stage2 never runs activation)
- csrc/ck_gemm_moe_2stages_codegen/gemm_moe_ck2stages_common.py:
    ActOP: bool -> int, add ACT_OP_MAP/ACT_OP_NAME dicts
- csrc/ck_gemm_moe_2stages_codegen/gen_instances.py:
    add swiglustep to codegen loop (preshuffle_on + preshuffle_off)
- aiter/ops/quant.py: add SwigluStep single-input fallback
- aiter/utility/dtypes.py: fix str2ActivationType for CamelCase enum names
- 3rdparty/composable_kernel: bump to commit with swiglustep_and_mul
    kernel branches in gridwise_moe_gemm.hpp (4 paths: quant/no-quant x
    MulRoutedWeight on/off)

Verified: cos_sim=0.999989 for T=1,4,32,128,512 (H=2048,I=640,E=288,K=8,
bf16, preshuffle_on) against torch_moe Python reference.

Co-Authored-By: Jun Lin <junlin12@amd.com>
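
The enum bug described above is easy to reproduce in isolation. The sketch below mirrors the fix with a Python IntEnum; only Gelu=0 and SwigluStep=3 (and the swiglustep_and_mul kernel name) are confirmed by the commit message — the Silu value and the other map entries are illustrative placeholders.

```python
from enum import IntEnum

class ActivationType(IntEnum):
    Gelu = 0
    Silu = 1          # placeholder value, not confirmed by the PR
    SwigluStep = 3    # new in this PR

# Old code collapsed the enum with boolean negation:
# any nonzero activation negates to 0, i.e. Gelu.
assert int(not ActivationType.SwigluStep) == ActivationType.Gelu  # the bug

# Fix: an explicit mapping, in the spirit of map_activation_to_ck_stage1()
ACT_OP_MAP = {
    ActivationType.Gelu: "gelu_and_mul",          # illustrative name
    ActivationType.Silu: "silu_and_mul",          # illustrative name
    ActivationType.SwigluStep: "swiglustep_and_mul",
}

def map_activation_to_ck_stage1(act: ActivationType) -> str:
    # KeyError on an unmapped activation is preferable to silently
    # dispatching the wrong kernel.
    return ACT_OP_MAP[act]
```

A dict lookup also fails loudly when a future activation is added without a CK mapping, instead of silently degrading to Gelu.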
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label       Tests
ci:sglang   SGLang integration tests
ci:atom     ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm     vLLM benchmark
ci:all      All of the above

Add labels via the sidebar or gh pr edit 2887 --add-label <label>

LJ-underdog and others added 3 commits April 24, 2026 02:46
Rebase feat/swiglustep-moe-no-quant onto latest CK develop.
The blockscale swiglustep support was already merged into develop;
only the standard no-quant path commit (gridwise_moe_gemm.hpp)
remains as our contribution.

Co-Authored-By: Jun Lin <junlin12@amd.com>
Three related fixes for Step-3.5-Flash tp=2/4/8 on gfx950 (MI350X):

1. communicator_pynccl.py: add AITER_PYNCCL_SKIP=1 env var to skip
   ncclCommInitRank, which hangs in RCCL on gfx950 with world_size=8.
   Falls back to torch.distributed standard collective.

2. parallel_state.py (_all_gather_out_place): add ca_comm None guard.
   When set_custom_all_reduce(False) disables custom all-reduce,
   ca_comm becomes None but _all_gather_out_place still asserts it
   non-None. Add NCCL all_gather fallback matching the existing
   all_gather() NCCL path (L567-576).

3. communication.py (init_dist_env): set_custom_all_reduce(False) for
   gfx950 where IPC-based custom allreduce causes hangs. Add ca_comm
   None guard around signal buffer setup to prevent AttributeError.

Verified: tp=2 inference passes (4 prompts, no crash) on gfx950.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
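
The ca_comm guard in fix 2 follows a common fallback pattern. The sketch below is illustrative only: the function names echo the commit message, but the real _all_gather_out_place signature and the NCCL fallback body in parallel_state.py are not reproduced here.

```python
import os

def pynccl_enabled() -> bool:
    # Fix 1 (sketch): AITER_PYNCCL_SKIP=1 skips ncclCommInitRank, which
    # hangs in RCCL on gfx950 with world_size=8.
    return os.environ.get("AITER_PYNCCL_SKIP", "0") != "1"

def all_gather_out_place(tensor, ca_comm, nccl_all_gather):
    # Fix 2 (sketch): when set_custom_all_reduce(False) disables custom
    # all-reduce, ca_comm is None; the old code asserted it non-None.
    if ca_comm is None:
        return nccl_all_gather(tensor)  # standard collective fallback
    return ca_comm.all_gather(tensor)
```

The key point is that disabling the custom path must degrade to the standard collective, not to an assertion failure.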
The V1 CK kernel correctness workaround at L904 unconditionally forces
block_m=128 for inter_dim>192 on gfx950, but the a8w8blkscale dispatch
(per_1x128/per_1x32) only supports block_m<=64 and is not affected by
the V1 bug. Passing block_m=128 to blockscale dispatch triggers
TORCH_CHECK failure ("Unsupported block_m value for moe heuristic
dispatch: 128"), breaking FP8 weight-quantized model inference.

Add a q_type guard to exclude blockscale paths from the override.
For tp=2 Step-3.5-Flash-FP8 (inter_dim=640, per_1x128): block_m stays
at 64 (set at L895), and 640%128=0 satisfies alignment constraints
without any inter_dim padding.
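
The guard logic above can be condensed into a few lines. This is a hedged sketch of the dispatch condition, not the real heuristic: the quant-type strings follow the commit message, but the surrounding function and its defaults are illustrative.

```python
# Blockscale dispatch (per_1x128 / per_1x32) supports block_m <= 64 only
# and is not affected by the V1 kernel bug, so it must be excluded from
# the gfx950 block_m=128 override.
BLOCKSCALE_QTYPES = {"per_1x128", "per_1x32"}

def pick_block_m(block_m: int, inter_dim: int, q_type: str, is_gfx950: bool) -> int:
    # V1 workaround: force block_m=128 (V3 kernel) on gfx950 when
    # inter_dim > 192 -- but only for non-blockscale quant types.
    if is_gfx950 and inter_dim > 192 and q_type not in BLOCKSCALE_QTYPES:
        return 128
    return block_m
```

For the tp=2 Step-3.5-Flash-FP8 case quoted above (inter_dim=640, per_1x128), this leaves block_m at 64, and 640 % 128 == 0 already satisfies alignment.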
LJ-underdog force-pushed the feat/step3p5-moe-swiglustep branch from 22f03a3 to 0f10156 on April 24, 2026 21:53
…fx950

The tuning entry `M=16384,N=4096,K=2048,bf16,asm,bf16gemm_bf16_tn_256x256`
causes the dispatcher to select `_ZN5aiter24bf16gemm_bf16_tn_256x256E`
for all M in [8193, 16384] via padded_M=16384. The ASM kernel produces
completely wrong outputs (diff ≈ 392 vs ref_max ≈ 247) for non-256-aligned
M values such as 8209-8223, causing silent data corruption.

In practice this broke tp=4 long-sequence prefill (M ≥ 8209) for models
whose attention o_proj has exactly this GEMM shape (e.g. Step-3.5-Flash),
producing all-BOS output tokens. Removing the entry makes all M values
fall back to torch.mm, restoring correctness.

Verified: tgemm.mm diff drops to 0 for all M in {8208,8209,8214,8216,10021};
end-to-end BF16 tp=4 inference on 10021-token input now produces coherent text.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
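
Why one bad tuning entry poisons an entire M range: the dispatcher rounds M up to the nearest tuned M (padded_M), so the M=16384 entry is selected for every M in (8192, 16384]. The sketch below illustrates that rounding; the tuned-M list and lookup code are illustrative, only the M=16384 entry and the affected range come from the commit message.

```python
import bisect

# Illustrative sorted list of tuned M values; only 16384 is from the PR.
TUNED_MS = [256, 1024, 8192, 16384]

def padded_m(m: int) -> int:
    # Round M up to the nearest tuned entry; fall through if none is larger.
    i = bisect.bisect_left(TUNED_MS, m)
    return TUNED_MS[i] if i < len(TUNED_MS) else m

# Every M in (8192, 16384] -- including the broken 8209-8223 range --
# lands on the 16384 entry and thus on bf16gemm_bf16_tn_256x256.
assert padded_m(8209) == 16384
```

Deleting the entry makes the lookup fall through to torch.mm for these shapes, trading some speed for correctness.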