Skip to content

Conversation

@ErwinTerpstra
Copy link
Contributor

@ErwinTerpstra ErwinTerpstra commented Jan 8, 2026

Proposed changes

This MR implements batched gemm bias permute for RDNA3/4. In practice, this is a multidimensional contraction operation. The MR contains the following:

  • Profiler and test infrastructure for the batched contraction instances, as this was not implemented yet for XDL versions
  • Device struct for batched contraction using WMMA instructions (device_batched_contraction_multiple_d_wmma_cshuffle_v3)
  • Changes to the GridwiseGemmWmmaCShuffleV3 to allow passing in non-naive grid descriptors

Note that support for different dimensions and D tensor configurations is very limited at the moment. More scaffolding would be needed to add generic support for variable number of dimensions, but with this limited implementation there is at least parity with the XDL versions.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers with understanding the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

@EnricoDeg
Copy link
Contributor

Can you also add an example for wmma?

Copy link
Contributor

@EnricoDeg EnricoDeg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work !

Comment on lines +554 to +580
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return static_cast<long_index_t>(g_idx) * batch_stride_A_;
}

__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return static_cast<long_index_t>(g_idx) * batch_stride_B_;
}

__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
std::array<long_index_t, NumDTensor> ds_offset;

static_for<0, NumDTensor, 1>{}([&](auto i) {
ds_offset[i] = static_cast<long_index_t>(g_idx) *
ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0));
});

return ds_offset;
}

__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
return static_cast<long_index_t>(g_idx) *
e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is confusing for me. Why for A and B the stride is used and for D and E the grid descriptor is used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For A and B there is no 3D grid descriptor created (No GMN, just MN), because that isn't used anywhere. I assumed the D and E must use a grid descriptor because it can be a non-trivial transformation (I think because E is permuted, although that probably doesn't change the batch stride)

But yeah, it's a bit inconsistent. I could make it more consistent if you want.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine. It was just unclear to me looking at the code. Maybe add a comment about it

Comment on lines 780 to 787
if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we define tailNum?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, that was missing. I added it and added some small test cases (where HasMainKBlock == false) to verify that it works.

Comment on lines 47 to 49
DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3< 1, 2, 3, 1, F16, F16, F32, F16, F16_Tuple, F16, PassThrough, PassThrough, Add, GemmSpec, ABSpec, ABSpec, DESpec, 256, 256, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1>>
// clang-format on
>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it's better to have a few more instances to check correctness in the tests

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

@ApoorvaKalyani ApoorvaKalyani left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work!
I also think we need more instances and we need to reverify the tests for those.

@ErwinTerpstra
Copy link
Contributor Author

@EnricoDeg @ApoorvaKalyani Thank you for the reviews. I processed the comments, added an example and added a couple of instances for both v1 and v3 pipelines. Let me know if there's still something you'd like to see changed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants