Open
Conversation
…ck func on kernels_common.py
…ps on normal path
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
This PR improves the performance of the FlyDSL RMSNorm kernel by addressing inefficiencies observed in production-like workloads (e.g., GPT-style shapes such as N=2880).
The previous implementation suffered from:
The goal of this PR is to:
Technical Details
This PR introduces several optimizations to the RMSNorm kernel:
Test Plan
The unit-test against Pytorch reference and Benchmark against AITER script is provided as test_rmsnorm_bench_against_aiter.py. One can run from flydsl dir as python test_rmsnorm_bench_against_aiter.py
Test Result
All cases from dimensions of GPT OSS 120B for RMSnorm operators were passed. Speedup Improvements present on dimension 16384 with 40% improvement. The remaining cases are improved by 1-2 us on average.
======================================================================
SUMMARY
PASS flydsl_rmsnorm_M3000_N2880_torch.bfloat16 max_delta= 0.015321 close=100.00% AITER= 12.97us FlyDSL= 29.11us speedup= 0.446x
PASS aiter_rmsnorm_M3000_N2880_torch.bfloat16 max_delta= 0.015321 close=100.00% AITER= 12.97us FlyDSL= 29.11us speedup= 0.446x
PASS flydsl_rmsnorm_M4000_N2880_torch.bfloat16 max_delta= 0.015002 close=100.00% AITER= 12.96us FlyDSL= 27.71us speedup= 0.468x
PASS aiter_rmsnorm_M4000_N2880_torch.bfloat16 max_delta= 0.015002 close=100.00% AITER= 12.96us FlyDSL= 27.71us speedup= 0.468x
PASS flydsl_rmsnorm_M5000_N2880_torch.bfloat16 max_delta= 0.015440 close=100.00% AITER= 15.96us FlyDSL= 27.96us speedup= 0.571x
PASS aiter_rmsnorm_M5000_N2880_torch.bfloat16 max_delta= 0.015440 close=100.00% AITER= 15.96us FlyDSL= 27.96us speedup= 0.571x
PASS flydsl_rmsnorm_M7000_N2880_torch.bfloat16 max_delta= 0.015307 close=100.00% AITER= 14.59us FlyDSL= 27.93us speedup= 0.522x
PASS aiter_rmsnorm_M7000_N2880_torch.bfloat16 max_delta= 0.015307 close=100.00% AITER= 14.59us FlyDSL= 27.93us speedup= 0.522x
PASS flydsl_rmsnorm_M3072_N2880_torch.bfloat16 max_delta= 0.015321 close=100.00% AITER= 12.35us FlyDSL= 27.66us speedup= 0.447x
PASS aiter_rmsnorm_M3072_N2880_torch.bfloat16 max_delta= 0.015321 close=100.00% AITER= 12.35us FlyDSL= 27.66us speedup= 0.447x
PASS flydsl_rmsnorm_M4096_N2880_torch.bfloat16 max_delta= 0.015002 close=100.00% AITER= 12.75us FlyDSL= 27.32us speedup= 0.467x
PASS aiter_rmsnorm_M4096_N2880_torch.bfloat16 max_delta= 0.015002 close=100.00% AITER= 12.75us FlyDSL= 27.32us speedup= 0.467x
PASS flydsl_rmsnorm_M7168_N2880_torch.bfloat16 max_delta= 0.015307 close=100.00% AITER= 15.02us FlyDSL= 27.90us speedup= 0.538x
PASS aiter_rmsnorm_M7168_N2880_torch.bfloat16 max_delta= 0.015307 close=100.00% AITER= 15.02us FlyDSL= 27.90us speedup= 0.538x
PASS flydsl_rmsnorm_M8192_N2880_torch.bfloat16 max_delta= 0.015525 close=100.00% AITER= 16.87us FlyDSL= 27.81us speedup= 0.606x
PASS aiter_rmsnorm_M8192_N2880_torch.bfloat16 max_delta= 0.015525 close=100.00% AITER= 16.87us FlyDSL= 27.81us speedup= 0.606x
PASS flydsl_rmsnorm_M16384_N2880_torch.bfloat16 max_delta= 0.015553 close=100.00% AITER= 30.41us FlyDSL= 29.02us speedup= 1.048x
PASS aiter_rmsnorm_M16384_N2880_torch.bfloat16 max_delta= 0.015553 close=100.00% AITER= 30.41us FlyDSL= 29.02us speedup= 1.048x
18/18 passed