[water] verify wave constraints match wave per block#830
Conversation
The two are alternative ways of specifying the same things and therefore should match. Arguably, only one way should remain, but both are currently used at the frontend level. Signed-off-by: Alex Zinenko <git@ozinenko.com>
There was a problem hiding this comment.
Pull request overview
This PR adds validation to ensure wave constraints and waves_per_block attributes remain consistent. The validation verifies that wave tile sizes evenly divide workgroup tile sizes, and that computed wave counts match the waves_per_block hardware constraint when both are specified.
Changes:
- Added validation to verify wave constraint tile sizes evenly divide workgroup constraint tile sizes
- Added validation to verify computed wave counts match waves_per_block hardware constraint values
- Added comprehensive test cases for both valid configurations and error cases
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| water/lib/Dialect/Wave/IR/WaveDialect.cpp | Implements validation logic for wave constraints and waves_per_block consistency checks |
| water/test/Dialect/Wave/attr-constraint.mlir | Adds positive test cases for valid wave constraint configurations |
| water/test/Dialect/Wave/attr-constraint-invalid.mlir | Adds negative test cases for invalid wave constraint configurations |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| #hyperparams = #wave.hyperparameters<{M = 1024, BLOCK_M = 128}> | ||
| #wg_constraint = #wave.workgroup_constraint<dim = <"M">, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, workgroup_dim = <x>> | ||
| #wv_constraint = #wave.wave_constraint<dim = <"M">, tile_size = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M floordiv 3)>> | ||
| // expected-error @below {{wave constraint tile size 42 does not evenly divide workgroup constraint tile size 128 for dimension: #wave.symbol<"M">}} |
There was a problem hiding this comment.
The error message expects tile size 42, but BLOCK_M floordiv 3 evaluates to 42 only if BLOCK_M = 128. However, the constraint uses BLOCK_M floordiv 3 which with BLOCK_M = 128 gives floor(128/3) = 42. The test will fail because the actual computed value is 42, but this is correct. The issue is the comment seems inconsistent - verify the expected value matches the actual computation.
There was a problem hiding this comment.
This makes no sense whatsoever!
| #hyperparams = #wave.hyperparameters<{N = 512, BLOCK_N = 64}> | ||
| #wg_constraint = #wave.workgroup_constraint<dim = <"N">, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N)>, workgroup_dim = <y>> | ||
| #wv_constraint = #wave.wave_constraint<dim = <"N">, tile_size = <[#wave.symbol<"BLOCK_N">] -> (BLOCK_N floordiv 5)>> | ||
| // expected-error @below {{wave constraint tile size 12 does not evenly divide workgroup constraint tile size 64 for dimension: #wave.symbol<"N">}} |
There was a problem hiding this comment.
The error message expects tile size 12, but BLOCK_N floordiv 5 with BLOCK_N = 64 gives floor(64/5) = 12. While mathematically correct, verify this is the intended test case value.
The two are alternative ways of specifying the same things and therefore should match. Arguably, only one way should remain, but both are currently used at the frontend level.