optimized fast path by reducing register use and vectorize on normal …#436

Open
kudomcho wants to merge 3 commits into main from khanin/rmsnorm_opt

kudomcho commented Apr 24, 2026

Motivation

This PR improves the performance of the FlyDSL RMSNorm kernel by addressing inefficiencies observed in production-like workloads (e.g., GPT-style shapes such as N=2880).

The previous implementation suffered from:

  • heavy reliance on scalar operations in non-aligned cases
  • high per-row overhead in generic paths
  • inconsistent performance across different shapes

The goal of this PR is to:

  • eliminate scalar bottlenecks
  • maximize vectorized execution
  • improve performance consistency across both aligned and non-aligned workloads

Technical Details

This PR introduces several optimizations to the RMSNorm kernel:

  1. Dual execution paths
  • Tile-fast path (aligned shapes): fully vectorized using buffer_load/store
  • Vector-generic path (arbitrary shapes): vectorized bulk plus a minimal scalar tail, so most workloads avoid expensive scalar execution
  2. Vectorized memory access
  • Standardized on VEC_WIDTH = 8
  • Improved memory coalescing and throughput using vectorized buffer ops
  3. Reduced scalar overhead
  • Scalar operations are now restricted to the final tail only
  • Eliminates the full-row scalar fallback in generic cases
  4. Block reduction simplification
  • Replaced the multi-buffer reduction with a single shared-memory reduction
  • Reduced synchronization and shared-memory traffic
  5. Loop unrolling
  • Added UNROLL = 2 for medium-width bf16/f16 workloads (N <= 4096)
  • Improves instruction-level parallelism and reduces loop overhead
  6. Bounded input caching (generic path)
  • Caches x values during the reduction pass to avoid reloading them in the normalization pass
  • Enabled only for medium-width cases to control register pressure
  7. Register pressure optimization
  • Removed input caching from the tile-fast path
  • Avoids excessive register usage and potential spills
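
To make the vector-generic path concrete, here is a minimal host-side reference sketch in plain Python. It is not the FlyDSL kernel. It only mirrors the structure described above: a bulk loop over VEC_WIDTH-sized chunks, a scalar tail, and optional caching of x between the reduction and normalization passes. The names `split_row`, `rmsnorm_row`, `gamma`, `eps`, and `CACHE_MAX_N` are hypothetical, and reusing the N <= 4096 bound as the caching cutoff is an assumption.

```python
import math

VEC_WIDTH = 8        # vector width named in the PR
CACHE_MAX_N = 4096   # assumed cutoff for "medium-width" rows

def split_row(n, vec_width=VEC_WIDTH):
    """Return (number of full vectors, scalar tail length) for a row of n elements."""
    return n // vec_width, n % vec_width

def rmsnorm_row(x, gamma, eps=1e-6, vec_width=VEC_WIDTH):
    """Two-pass RMSNorm over one row: vectorized bulk plus scalar tail."""
    n = len(x)
    num_vec, _tail = split_row(n, vec_width)
    bulk_end = num_vec * vec_width
    use_cache = n <= CACHE_MAX_N   # bounded caching to limit "register" use
    cache = []

    # Pass 1: sum of squares. Bulk in vec_width chunks, then the scalar tail.
    sum_sq = 0.0
    for start in range(0, bulk_end, vec_width):
        chunk = x[start:start + vec_width]
        if use_cache:
            cache.append(chunk)    # keep x so pass 2 does not reload it
        sum_sq += sum(v * v for v in chunk)
    for i in range(bulk_end, n):
        sum_sq += x[i] * x[i]

    inv_rms = 1.0 / math.sqrt(sum_sq / n + eps)

    # Pass 2: normalize and scale. Bulk reads from the cache when enabled.
    y = [0.0] * n
    for k, start in enumerate(range(0, bulk_end, vec_width)):
        chunk = cache[k] if use_cache else x[start:start + vec_width]
        for j, v in enumerate(chunk):
            y[start + j] = v * inv_rms * gamma[start + j]
    for i in range(bulk_end, n):
        y[i] = x[i] * inv_rms * gamma[i]
    return y
```

For example, `split_row(2880)` returns `(360, 0)` and `split_row(2883)` returns `(360, 3)`, so only the last three elements of the second row would go through the scalar tail.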

Test Plan

Ran the RMSNorm unit tests in test_rmsnorm.py in the FlyDSL repo with: python -m tests.kernels.test_rmsnorm

Test Result

================
Running RMSNorm Tests

Testing RMSNorm (M=32768, N=8192, dtype=bf16)
Launching kernel...
[W424 19:43:23.196708004 collection.cpp:1133] Warning: ROCTracer produced duplicate flow start: 1 (function operator())
Kernel avg time: 0.2125 ms via run_perftest (warmup=10, iters=100)
Bandwidth: 5052.71 GB/s
Max absolute error: 1.56e-02 (atol=0.02)
PASSED

================================================================================
ALL TESTS PASSED
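
As a sanity check on the reported bandwidth, the short calculation below recomputes it from the logged shape and time. The PR does not show the formula the benchmark uses, so this assumes one read of the input and one write of the output at 2 bytes per bf16 element, ignoring the weight vector.

```python
M, N = 32768, 8192
BYTES_PER_ELEM = 2      # bf16
avg_time_ms = 0.2125    # "Kernel avg time" from the log above

# Assumption: traffic = read x once + write y once.
bytes_moved = 2 * M * N * BYTES_PER_ELEM
bandwidth_gb_s = bytes_moved / (avg_time_ms * 1e-3) / 1e9
print(f"{bandwidth_gb_s:.1f} GB/s")
```

Under that assumption the result is about 5053 GB/s, which agrees with the logged 5052.71 GB/s to within the rounding of the printed average time.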
