Conversation
…include gfx1250 rmsnorm
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
| sharedLayoutWeights: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, order=[0]) | ||
|
|
||
| # create a swizzled shared layout for the output | ||
| gl.SwizzledSharedLayout(1, 1, 1, order=[1, 0]) |
There was a problem hiding this comment.
This isn't assigned to anything but is probably also not needed since the output is not TDM stored so you don't need a shared layout for it?
|
|
||
| # Loop through the rows of the input tensor by NUM_PROG blocks | ||
| for row_idx in range(row_start, n_rows, NUM_PROG): | ||
| input_ptr + (row_idx * input_row_stride) |
There was a problem hiding this comment.
This isn't assigned to anything?
| rms_norm = a * norm_factor * weights | ||
| # store rms norm and the norm factor | ||
| gl.store( | ||
| rsigma_ptr + row_start, norm_factor.to(rsigma_ptr.dtype.element_ty) |
| USE_BLOCK = COL > BLOCK_SIZE | ||
| NUM_PROG = min(ROW, get_num_sms()) | ||
|
|
||
| grid = (NUM_PROG,) |
There was a problem hiding this comment.
I think you can put min(ROW, get_num_sms()) here.
| output = torch.empty_like(input, device=input.device) | ||
| rsigma = torch.empty((ROW,), device=input.device, dtype=input.dtype) | ||
|
|
||
| MAX_FUSED_SIZE = 65536 // input.element_size() |
There was a problem hiding this comment.
Comment for the magic number?
| @@ -0,0 +1,39 @@ | |||
| # SPDX-License-Identifier: MIT | |||
There was a problem hiding this comment.
I don't think we want two different files. We want a single API and the wrapper decides whether to call triton (gfx950 and earlier) or gluon (gfx1250, if a gluon kernel exists).
Motivation
Create rmsnorm kernel in gluon for gfx1250
Technical Details
Translated existing triton implementation into a gluon equivalent.
Test Plan
Added a test reference in existing test_rmsnorm.py for gluon implementation.
Test Result
Passed all test condition