Implement batched gemm bias permute for RDNA4 #3534
Conversation
…rs for gridwise_gemm_wmma_cshuffle_v3, test setup for odd cases
…_bias_permute-for-rdna4
Can you also add an example for wmma?
EnricoDeg
left a comment
Nice work!
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
    return static_cast<long_index_t>(g_idx) * batch_stride_A_;
}

__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
    return static_cast<long_index_t>(g_idx) * batch_stride_B_;
}

__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
    std::array<long_index_t, NumDTensor> ds_offset;

    static_for<0, NumDTensor, 1>{}([&](auto i) {
        ds_offset[i] = static_cast<long_index_t>(g_idx) *
                       ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0));
    });

    return ds_offset;
}

__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
    return static_cast<long_index_t>(g_idx) *
           e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0));
}
This is confusing to me. Why is the batch stride used for A and B, while the grid descriptor is used for D and E?
For A and B no 3D grid descriptor is created (no GMN, just MN), because it isn't used anywhere. I assumed D and E must use a grid descriptor because the transformation can be non-trivial (I think because E is permuted, although that probably doesn't change the batch stride).
But yes, it's a bit inconsistent. I could make it more consistent if you want.
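For context, the relation between the two conventions can be sketched as follows. The member names mirror the snippet above; the cached batch_stride_E_ member and the rewritten accessor are only a hypothetical illustration of what a unified variant could look like, not what the PR does.

// The batch offset of E is derived from its G/M/N descriptor: the linear offset of
// multi-index (1, 0, 0) is exactly E's batch stride. A hypothetical way to mirror the
// A/B convention would be to cache that value once, e.g. in the constructor:
//
//     batch_stride_E_ = e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0));
//
// after which the accessor would look identical to GetAPtrOffset/GetBPtrOffset:
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
    return static_cast<long_index_t>(g_idx) * batch_stride_E_; // hypothetical member
}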
I think it's fine. It was just unclear to me looking at the code. Maybe add a comment about it.
.../tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp (outdated review thread, resolved)
if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.K))
{
    return launch_kernel(integral_constant<bool, true>{});
}
else
{
    return launch_kernel(integral_constant<bool, false>{});
}
Shouldn't we define tailNum?
Good catch, that was missing. I added it and added some small test cases (where HasMainKBlock == false) to verify that it works.
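For reference, a sketch of what the combined dispatch could look like, assuming the WMMA v3 path mirrors the XDL cshuffle_v3 pattern. CalculateKBlockLoopTailNum, the TailNumber enum, and the extra integral_constant argument to launch_kernel are assumptions from that pattern, not necessarily the exact names used in this PR.

// Sketch only: dispatch on both the main-K-loop flag and the tail number.
// Helper/enum names and the extended launch_kernel signature are assumptions here.
const bool has_main_loop = GridwiseGemm::CalculateHasMainKBlockLoop(arg.K);
const auto tail_num      = GridwiseGemm::CalculateKBlockLoopTailNum(arg.K);

if(has_main_loop)
{
    if(tail_num == TailNumber::Odd)
    {
        return launch_kernel(integral_constant<bool, true>{},
                             integral_constant<TailNumber, TailNumber::Odd>{});
    }
    else
    {
        return launch_kernel(integral_constant<bool, true>{},
                             integral_constant<TailNumber, TailNumber::Even>{});
    }
}
else
{
    // Small-K cases (the new tests with HasMainKBlockLoop == false) take this path.
    return launch_kernel(integral_constant<bool, false>{},
                         integral_constant<TailNumber, TailNumber::Full>{});
}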
.../tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp (outdated review thread, resolved)
.../tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp (review thread, resolved)
DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 256, 256, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1>>
// clang-format on
>;
Maybe it's better to have a few more instances to check correctness in the tests.
Done
ApoorvaKalyani
left a comment
Great work!
I also think we need more instances, and we need to re-verify the tests for those.
…e code between platforms
…tances to the test
…_bias_permute-for-rdna4
@EnricoDeg @ApoorvaKalyani Thank you for the reviews. I processed the comments, added an example, and added a couple of instances for both the v1 and v3 pipelines. Let me know if there's still something you'd like to see changed.
Proposed changes
This MR implements batched gemm bias permute for RDNA3/4. In practice, this is a multidimensional contraction operation. The MR contains the following:
- The new device operation (device_batched_contraction_multiple_d_wmma_cshuffle_v3)
- Changes to GridwiseGemmWmmaCShuffleV3 to allow passing in non-naive grid descriptors

Note that support for different dimensions and D tensor configurations is very limited at the moment. More scaffolding would be needed to add generic support for a variable number of dimensions, but with this limited implementation there is at least parity with the XDL versions.
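To make the scope concrete, below is a reference sketch of what "batched gemm + bias + permuted output" computes. The shapes, layouts, the specific E permutation, and the function name are illustrative only; they are not CK API calls and do not correspond to the exact dimension layout supported by the new device op.

#include <cstddef>
#include <vector>

// Hedged reference semantics: E = permute(A * B + bias), computed per batch g.
// Layouts assumed here: A is [G][M][K], B is [G][K][N], bias D is [N],
// and E is stored with M and N swapped, i.e. [G][N][M], to mimic a permuted output.
void reference_batched_gemm_bias_permute(const std::vector<float>& a,
                                         const std::vector<float>& b,
                                         const std::vector<float>& d,
                                         std::vector<float>& e,
                                         std::size_t G, std::size_t M,
                                         std::size_t N, std::size_t K)
{
    for(std::size_t g = 0; g < G; ++g)
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(std::size_t k = 0; k < K; ++k)
                    acc += a[(g * M + m) * K + k] * b[(g * K + k) * N + n];
                // Bias add (the Add element-op) followed by a permuted store into E.
                e[(g * N + n) * M + m] = acc + d[n];
            }
}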
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- Ran clang-format on all changed files

Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered.