Flydsl rmsnorm#2889

Open
kudomcho wants to merge 27 commits into main from flydsl_rmsnorm

@kudomcho

Motivation

This PR improves the performance of the FlyDSL RMSNorm kernel by addressing inefficiencies observed in production-like workloads (e.g., GPT-style shapes such as N=2880).

The previous implementation suffered from:

  • heavy reliance on scalar operations in non-aligned cases
  • high per-row overhead in generic paths
  • inconsistent performance across different shapes

The goal of this PR is to:

  • eliminate scalar bottlenecks
  • maximize vectorized execution
  • improve performance consistency across both aligned and non-aligned workloads

Technical Details

This PR introduces several optimizations to the RMSNorm kernel:

  1. Dual execution paths
  • Tile-fast path (aligned shapes): fully vectorized using buffer_load/store
  • Vector-generic path (arbitrary shapes): vectorized bulk + minimal scalar tail
    This ensures most workloads avoid expensive scalar execution.
  2. Vectorized memory access
  • Standardized on VEC_WIDTH = 8
  • Improved memory coalescing and throughput using vectorized buffer ops
  3. Reduced scalar overhead
  • Scalar operations are now restricted to the final tail only
  • Eliminates full-row scalar fallback in generic cases
  4. Block reduction simplification
  • Replaced multi-buffer reduction with a single shared-memory reduction
  • Reduced synchronization and shared memory traffic
  5. Loop unrolling
  • Added UNROLL = 2 for medium-width bf16/f16 workloads (N <= 4096)
  • Improves instruction-level parallelism and reduces loop overhead
  6. Bounded input caching (generic path)
  • Cached x values during reduction pass to avoid reloading in normalization pass
  • Enabled only for medium-width cases to control register pressure
  7. Register pressure optimization
  • Removed input caching from fast tiled path
  • Avoids excessive register usage and potential spills
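
To make the "vectorized bulk + scalar tail" idea concrete, here is a minimal host-side sketch in plain Python. It is not the FlyDSL kernel: the helper names `split_bulk_tail` and `rmsnorm_row` and the `eps` default are assumptions made for illustration, and the inner chunk loops only stand in for what would be a single vector load/store on the GPU. Only `VEC_WIDTH = 8` and the standard RMSNorm formula come from the description above.

```python
import math

VEC_WIDTH = 8  # vector width named in the PR description


def split_bulk_tail(n, vec_width=VEC_WIDTH):
    """Split a row of length n into (bulk_len, tail_len).

    bulk_len is the largest multiple of vec_width that fits in n;
    tail_len is what is left over for scalar handling.
    (Hypothetical helper, not part of FlyDSL.)
    """
    bulk = (n // vec_width) * vec_width
    return bulk, n - bulk


def rmsnorm_row(x, weight, eps=1e-6, vec_width=VEC_WIDTH):
    """Reference RMSNorm for one row: y = x / sqrt(mean(x^2) + eps) * weight."""
    n = len(x)
    bulk, _tail = split_bulk_tail(n, vec_width)

    # Pass 1: sum of squares. The bulk is consumed vec_width elements at a
    # time (one "vector" per iteration); only the leftover tail is scalar.
    sum_sq = 0.0
    for start in range(0, bulk, vec_width):
        chunk = x[start:start + vec_width]
        sum_sq += sum(v * v for v in chunk)
    for i in range(bulk, n):
        sum_sq += x[i] * x[i]

    inv_rms = 1.0 / math.sqrt(sum_sq / n + eps)

    # Pass 2: normalize and scale, using the same bulk/tail split.
    y = [0.0] * n
    for start in range(0, bulk, vec_width):
        for i in range(start, start + vec_width):
            y[i] = x[i] * inv_rms * weight[i]
    for i in range(bulk, n):
        y[i] = x[i] * inv_rms * weight[i]
    return y
```

For example, `split_bulk_tail(2880)` gives an empty tail because 2880 = 8 × 360, while a row of length 2883 would leave three scalar elements. The alignment condition that selects the tile-fast path is not stated in this description, so the sketch does not try to model it.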

Test Plan

The unit test against the PyTorch reference and the benchmark against AITER are both provided in test_rmsnorm_bench_against_aiter.py. Run it from the flydsl directory with: python test_rmsnorm_bench_against_aiter.py
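
As a rough illustration of the kind of correctness check such a script might perform, the sketch below computes a maximum absolute difference and the percentage of elements within tolerance, matching the `max_delta` and `close` fields in the summary output. The function name `compare` and the `atol`/`rtol` values are assumptions; the actual thresholds used by test_rmsnorm_bench_against_aiter.py are not given here.

```python
def compare(out, ref, atol=2e-2, rtol=2e-2):
    """Return (max_delta, close_percent) for two equal-length sequences.

    An element counts as "close" when |out - ref| <= atol + rtol * |ref|.
    The tolerances are illustrative defaults, not the script's real values.
    """
    if len(out) != len(ref) or not ref:
        raise ValueError("inputs must be non-empty and of equal length")
    deltas = [abs(o - r) for o, r in zip(out, ref)]
    max_delta = max(deltas)
    n_close = sum(d <= atol + rtol * abs(r) for d, r in zip(deltas, ref))
    return max_delta, 100.0 * n_close / len(ref)
```

Reporting both numbers is useful for bf16 outputs: `max_delta` shows the worst single element, while the close percentage shows whether outliers are isolated or widespread.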

Test Result

All RMSNorm cases taken from GPT-OSS 120B dimensions pass. The largest improvement, about 40%, is at dimension 16384; the remaining cases improve by 1-2 us on average.

======================================================================
SUMMARY

PASS flydsl_rmsnorm_M3000_N2880_torch.bfloat16 max_delta= 0.015321 close=100.00% AITER= 12.97us FlyDSL= 29.11us speedup= 0.446x
PASS aiter_rmsnorm_M3000_N2880_torch.bfloat16 max_delta= 0.015321 close=100.00% AITER= 12.97us FlyDSL= 29.11us speedup= 0.446x
PASS flydsl_rmsnorm_M4000_N2880_torch.bfloat16 max_delta= 0.015002 close=100.00% AITER= 12.96us FlyDSL= 27.71us speedup= 0.468x
PASS aiter_rmsnorm_M4000_N2880_torch.bfloat16 max_delta= 0.015002 close=100.00% AITER= 12.96us FlyDSL= 27.71us speedup= 0.468x
PASS flydsl_rmsnorm_M5000_N2880_torch.bfloat16 max_delta= 0.015440 close=100.00% AITER= 15.96us FlyDSL= 27.96us speedup= 0.571x
PASS aiter_rmsnorm_M5000_N2880_torch.bfloat16 max_delta= 0.015440 close=100.00% AITER= 15.96us FlyDSL= 27.96us speedup= 0.571x
PASS flydsl_rmsnorm_M7000_N2880_torch.bfloat16 max_delta= 0.015307 close=100.00% AITER= 14.59us FlyDSL= 27.93us speedup= 0.522x
PASS aiter_rmsnorm_M7000_N2880_torch.bfloat16 max_delta= 0.015307 close=100.00% AITER= 14.59us FlyDSL= 27.93us speedup= 0.522x
PASS flydsl_rmsnorm_M3072_N2880_torch.bfloat16 max_delta= 0.015321 close=100.00% AITER= 12.35us FlyDSL= 27.66us speedup= 0.447x
PASS aiter_rmsnorm_M3072_N2880_torch.bfloat16 max_delta= 0.015321 close=100.00% AITER= 12.35us FlyDSL= 27.66us speedup= 0.447x
PASS flydsl_rmsnorm_M4096_N2880_torch.bfloat16 max_delta= 0.015002 close=100.00% AITER= 12.75us FlyDSL= 27.32us speedup= 0.467x
PASS aiter_rmsnorm_M4096_N2880_torch.bfloat16 max_delta= 0.015002 close=100.00% AITER= 12.75us FlyDSL= 27.32us speedup= 0.467x
PASS flydsl_rmsnorm_M7168_N2880_torch.bfloat16 max_delta= 0.015307 close=100.00% AITER= 15.02us FlyDSL= 27.90us speedup= 0.538x
PASS aiter_rmsnorm_M7168_N2880_torch.bfloat16 max_delta= 0.015307 close=100.00% AITER= 15.02us FlyDSL= 27.90us speedup= 0.538x
PASS flydsl_rmsnorm_M8192_N2880_torch.bfloat16 max_delta= 0.015525 close=100.00% AITER= 16.87us FlyDSL= 27.81us speedup= 0.606x
PASS aiter_rmsnorm_M8192_N2880_torch.bfloat16 max_delta= 0.015525 close=100.00% AITER= 16.87us FlyDSL= 27.81us speedup= 0.606x
PASS flydsl_rmsnorm_M16384_N2880_torch.bfloat16 max_delta= 0.015553 close=100.00% AITER= 30.41us FlyDSL= 29.02us speedup= 1.048x
PASS aiter_rmsnorm_M16384_N2880_torch.bfloat16 max_delta= 0.015553 close=100.00% AITER= 30.41us FlyDSL= 29.02us speedup= 1.048x

18/18 passed
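
For reading the summary lines programmatically, a small parser along the following lines can help. The regular expression is inferred from the lines shown above, and the reading of `speedup` as AITER time divided by FlyDSL time is an inference from the printed numbers (it matches them to within rounding), not something the script documents. Under that reading, values below 1.0 mean the FlyDSL kernel took longer than AITER for that shape. `parse_summary_line` is a hypothetical helper name.

```python
import re

# Pattern inferred from the summary lines in this PR; other statuses or
# fields the script may print are not covered.
LINE_RE = re.compile(
    r"^(?P<status>\S+)\s+(?P<name>\S+)\s+"
    r"max_delta=\s*(?P<max_delta>[\d.]+)\s+"
    r"close=\s*(?P<close>[\d.]+)%\s+"
    r"AITER=\s*(?P<aiter_us>[\d.]+)us\s+"
    r"FlyDSL=\s*(?P<flydsl_us>[\d.]+)us\s+"
    r"speedup=\s*(?P<speedup>[\d.]+)x$"
)


def parse_summary_line(line):
    """Parse one summary line into a dict, or return None if it does not match."""
    m = LINE_RE.match(line.strip())
    if m is None:
        return None
    row = m.groupdict()
    for key in ("max_delta", "close", "aiter_us", "flydsl_us", "speedup"):
        row[key] = float(row[key])
    # Ratio recomputed from the rounded times; may differ from the printed
    # speedup in the last digit.
    row["ratio"] = row["aiter_us"] / row["flydsl_us"]
    return row
```

Usage: `parse_summary_line(line)["ratio"]` for the M16384 line gives a value just above 1.0, while the M3000 line gives a value below 0.5.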

kudomcho requested a review from a team on April 23, 2026 23:02
@github-actions

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label      Tests
ci:sglang  SGLang integration tests
ci:atom    ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm    vLLM benchmark
ci:all     All of the above

Add labels via the sidebar or gh pr edit 2889 --add-label <label>
