[water] Add wave.broadcast operation definition #800
Conversation
Pull request overview
This PR adds a wave.broadcast operation to the Wave dialect for broadcasting tensors to larger shapes by replicating values along specified dimensions. The operation infers broadcast dimensions from the difference between source and result tensor shapes.
Changes:
- Added BroadcastOp definition with type inference, elements-per-thread propagation, and index expression propagation support
- Implemented custom parsing/printing for wave symbol arrays
- Added comprehensive test coverage for broadcast operation behavior and error cases
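For orientation, here is a hedged sketch of what using the operation might look like, based purely on the description above and the tensor syntax used elsewhere in this PR; the exact custom assembly format is defined in WaveOps.td and may differ.

```mlir
// Hypothetical usage sketch (not taken from the PR): broadcasting a 1-D
// tensor over @M to a 2-D tensor over @M and @N. The broadcast dimension
// (@N) is inferred from the difference between result and source shapes.
%b = wave.broadcast %src : (!wave.tensor<[@M] of f32, <register>>)
                         -> !wave.tensor<[@M, @N] of f32, <register>>
```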
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.
Summary per file:
| File | Description |
|---|---|
| water/include/water/Dialect/Wave/IR/WaveOps.td | Defines BroadcastOp with traits and assembly format |
| water/lib/Dialect/Wave/IR/WaveOps.cpp | Implements BroadcastOp verification, type inference, EPT propagation, and symbol array parsing/printing |
| water/test/Dialect/Wave/ops.mlir | Tests valid broadcast operations with various shapes and explicit dims |
| water/test/Dialect/Wave/ops-invalid.mlir | Tests error cases: missing source dims, type mismatches, and invalid explicit dims |
| water/test/Dialect/Wave/infer-types.mlir | Tests type inference for broadcast operations |
| water/test/Dialect/Wave/infer-index-exprs.mlir | Tests index expression propagation through broadcast operations |
| water/test/Dialect/Wave/propagate-elements-per-thread.mlir | Tests EPT propagation in forward/backward directions and thread X broadcast behavior |
```mlir
// Test broadcast propagates EPT forward (identity case - no thread X broadcast).
normalform.module [#wave.normal_form<full_types>] {
// CHECK-LABEL: @broadcast_propagation_forward_identity
// expected-warning @+1 {{unused hyperparameter: N}}
```
The test includes an expected warning for an unused hyperparameter N, but doesn't verify that the broadcast operation correctly handles the case where a hyperparameter is intentionally unused (broadcast along a dimension not in constraints). Consider adding a test that verifies the operation works correctly when N is actually used in constraints.
martin-luecke
left a comment
We currently define the attribute broadcast_dims to represent the broadcast dimensions, but don't use it consistently in the implementation, beyond skipping its inference from the result and argument types when it is present.
In my opinion, it does not seem to add much in its current form. It also does not really enable inferring the result type from arg type + this attribute.
Additionally, I think there might be an issue with invalidating a correct op when all types are properly resolved, and the attribute is still present.
Should we keep this attribute or remove it?
```
// Note: If partial type inference is needed in the future, an optional
// broadcast_dims attribute could be added (mutually exclusive with
// fully-specified types).
```
This is already in the PR so we should either adjust the comment or go ahead and implement propagation based on the attribute, if possible.
This is now implemented
```cpp
if (!sourceType || !resultType || !sourceType.getFullySpecified() ||
    !resultType.getFullySpecified())
  return false;
```
The semantics of this if statement are actually that we may broadcast along threadX, we just don't yet know for sure.
This should probably return true to stop eagerly propagating as long as we do not know that we definitely don't broadcast along X, right?
If I am correct in thinking this, the function should also be renamed to mayBroadcastAlongThreadX.
Good point. Let me try hardening this; arguably, we don't want to do EPT propagation with underspecified types or with vectors, so this condition should never fail.
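For illustration, a conservative variant along the lines discussed above could look roughly like the following; the function signature and the WaveTensorType name are assumptions, only getFullySpecified and the lambda come from the quoted snippets, and the usual llvm/ADT includes are assumed.

```cpp
// Sketch only: report "maybe" whenever a broadcast along thread X cannot be
// ruled out, so eager EPT propagation stops for underspecified types.
static bool mayBroadcastAlongThreadX(WaveTensorType sourceType,
                                     WaveTensorType resultType,
                                     ArrayRef<WaveSymbolAttr> broadcastDims,
                                     WaveSymbolAttr threadXDimension) {
  // Without fully specified types we cannot prove the broadcast avoids the
  // thread X dimension, so conservatively assume it may happen.
  if (!sourceType || !resultType || !sourceType.getFullySpecified() ||
      !resultType.getFullySpecified())
    return true;
  // Otherwise, check whether any broadcast dimension is the thread X one.
  return llvm::any_of(broadcastDims, [&](WaveSymbolAttr sym) {
    return sym == threadXDimension;
  });
}
```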
```mlir
// CHECK: wave.read {{.*}} -> vector<8xf32>
%reg = wave.read %mem {elements_per_thread = 8} : (!wave.tensor<[@M] of f32, <global>>) -> !wave.tensor<[@M] of f32, <register>>

// Broadcast along @N (not thread X) - identity propagation, EPT stays 8.
```
Does not having this in the constraints #wave.workgroup_constraint<dim = <"M">, workgroup_dim = <x>> guarantee that N is never bound to wg_dim 0?
Yes, in my understanding, if no dimension is bound to thread x, we run on one thread.
Let's also document that fact on the op
I'll add a "developer remark" because this kind of internal detail shouldn't affect the user and therefore shouldn't be exposed in the user documentation...
Based on the discussion on #778, we don't want to add the attribute. I thought I'd removed this completely; it turned out there were still some stashed changes in my local repo. Sorry about this, I just pushed it.
This was also my opinion, but @tyb0807 argued strongly that this should be kept for consistency with reduction operations. The unknown point is whether broadcasting is allowed on non-trailing dimensions. If so, the attribute isn't sufficient and must be complemented by "insertion positions" of each dimension. If not, it is sufficient and we can just append these dimensions to the shape during type inference. I don't think element type inference is meaningful by itself.
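If broadcasting is indeed restricted to trailing dimensions, the append-based result-shape inference mentioned above could be as simple as the following sketch; the function and variable names are hypothetical, not the PR's actual code.

```cpp
// Sketch: result shape = source shape followed by the broadcast dimensions,
// assuming broadcast dimensions may only be appended at the end.
static SmallVector<WaveSymbolAttr>
inferResultShape(ArrayRef<WaveSymbolAttr> sourceShape,
                 ArrayRef<WaveSymbolAttr> broadcastDims) {
  SmallVector<WaveSymbolAttr> resultShape(sourceShape.begin(),
                                          sourceShape.end());
  llvm::append_range(resultShape, broadcastDims);
  return resultShape;
}
```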
You did not, nor should you have. This is a separate branch where I took over your commits. I will handle it from here.
```cpp
if (!sourceType.getFullySpecified() || !resultType.getFullySpecified())
  return success();
```
If we allow missing shapes here, I am wondering if we should maybe do the same for wave.permute?
I think we should be consistent across all shape-related operations. Looking at the frontend, it looks like broadcast lists all result dimensions, so we can expect a full result type but a potentially underspecified source type (and I'd rather do inference in MLIR dataflow).
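As a rough illustration of that lenient verification, inside a hypothetical verify() body (getElementType and the exact diagnostic text are assumptions, not the PR's code):

```cpp
// Sketch: always check element types, but defer the shape comparison until
// both shapes are fully specified, leaving the rest to type inference.
if (sourceType.getElementType() != resultType.getElementType())
  return emitOpError() << "source and result element types must match";
if (!sourceType.getFullySpecified() || !resultType.getFullySpecified())
  return success();
// Shape-difference checks would go here once both shapes are known.
```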
```cpp
return llvm::any_of(broadcastDims, [&](WaveSymbolAttr sym) {
  return sym == init.threadXDimension;
});
```
Suggested change:
```diff
-return llvm::any_of(broadcastDims, [&](WaveSymbolAttr sym) {
-  return sym == init.threadXDimension;
-});
+return llvm::any_of(broadcastDims, llvm::equal_to(init.threadXDimension));
```
Force-pushed from e4f6923 to 166bbe8.
martin-luecke
left a comment
Good to go, could be improved with tiny additions to the documentation of wave.broadcast
Add BroadcastOp to the Wave dialect for broadcasting a tensor to a larger shape by replicating values along specified dimensions. The operation takes a source tensor and a broadcast_dims attribute specifying which dimensions are being added.

Some design decisions:
- broadcast_dims attribute explicitly specifies which dimensions are added (source_shape + broadcast_dims = result_shape).
- BroadcastElementsPerThreadOpTrait: EPT propagation depends on broadcast dimension. When broadcasting along thread X, EPT is not propagated (NoChange) since source has no thread X and result EPT should come from downstream users (similar to how PyWave copies index from context). For other dims, identity propagation is used.
- IdentityIndexExprsOpTrait: index expressions for shared dims propagate bidirectionally, broadcast dims are filled by backward propagation.
- Custom type inference: backward propagation can infer source shape from result shape minus broadcast_dims.

Signed-off-by: tyb0807 <sontuan.vu@amd.com>
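For reference, a hedged sketch of the backward shape inference described in this commit message; the helper name and shape representation are assumptions, not the PR's actual implementation.

```cpp
// Sketch: the inferred source shape is the result shape with the broadcast
// dimensions removed (the backward direction of type inference).
static SmallVector<WaveSymbolAttr>
inferSourceShape(ArrayRef<WaveSymbolAttr> resultShape,
                 ArrayRef<WaveSymbolAttr> broadcastDims) {
  SmallVector<WaveSymbolAttr> sourceShape;
  llvm::copy_if(resultShape, std::back_inserter(sourceShape),
                [&](WaveSymbolAttr sym) {
                  return !llvm::is_contained(broadcastDims, sym);
                });
  return sourceShape;
}
```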
Signed-off-by: tyb0807 <sontuan.vu@amd.com>
* the list of dimensions is optionally present
* drop the verification that is always true
* verify element types
The broadcast dimensions are now always inferred from the difference between the result shape and source shape. Signed-off-by: tyb0807 <sontuan.vu@amd.com>
This reverts commit f5bffd16c05ffd63c826236ed06a56a8260f2d64. Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Signed-off-by: Alex Zinenko <git@ozinenko.com>
Signed-off-by: Alex Zinenko <git@ozinenko.com>
Signed-off-by: Alex Zinenko <git@ozinenko.com>
Signed-off-by: Alex Zinenko <git@ozinenko.com>
Force-pushed from 02a3819 to bb1d198.
Revival of #778. Fixes #721.