
[water] Add wave.broadcast operation definition #800

Merged
ftynse merged 9 commits into main from users/ftynse/broadcast
Feb 6, 2026

Conversation


ftynse commented Feb 2, 2026

Add BroadcastOp to the Wave dialect for broadcasting a tensor to a larger
shape by replicating values along specified dimensions. The operation takes
a source tensor and a broadcast_dims attribute specifying which dimensions
are being added.

Some design decisions:

  • broadcast_dims attribute explicitly specifies which dimensions are added
    (source_shape + broadcast_dims = result_shape).
  • BroadcastElementsPerThreadOpTrait: EPT propagation depends on broadcast
    dimension. When broadcasting along thread X, EPT is not propagated
    (NoChange) since source has no thread X and result EPT should come from
    downstream users (similar to how PyWave copies index from context).
    For other dims, identity propagation is used.
  • IdentityIndexExprsOpTrait: index expressions for shared dims propagate
    bidirectionally, broadcast dims are filled by backward propagation.
  • Custom type inference: backward propagation can infer source shape from
    result shape minus broadcast_dims.

Revival of #778. Fixes #721.
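For illustration, a minimal IR sketch of the intended usage. This is hypothetical: the functional-type syntax mirrors the wave.read snippets in the tests quoted later in this thread, and the exact assembly format is the one defined in WaveOps.td. Note that, per the review below, the explicit broadcast_dims attribute was ultimately dropped in favor of inferring the dimensions from the type difference, so the sketch shows the attribute-free form.

// Hypothetical sketch (%mem assumed defined elsewhere): broadcast a 1-D
// register tensor along a new @N dimension. The broadcast dimensions are
// the set difference between the result shape and the source shape.
%src = wave.read %mem {elements_per_thread = 8}
    : (!wave.tensor<[@M] of f32, <global>>) -> !wave.tensor<[@M] of f32, <register>>
%bcast = wave.broadcast %src
    : (!wave.tensor<[@M] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>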


Copilot AI left a comment


Pull request overview

This PR adds a wave.broadcast operation to the Wave dialect for broadcasting tensors to larger shapes by replicating values along specified dimensions. The operation infers broadcast dimensions from the difference between source and result tensor shapes.

Changes:

  • Added BroadcastOp definition with type inference, elements-per-thread propagation, and index expression propagation support
  • Implemented custom parsing/printing for wave symbol arrays
  • Added comprehensive test coverage for broadcast operation behavior and error cases

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.

Summary per file:

  • water/include/water/Dialect/Wave/IR/WaveOps.td — Defines BroadcastOp with traits and assembly format
  • water/lib/Dialect/Wave/IR/WaveOps.cpp — Implements BroadcastOp verification, type inference, EPT propagation, and symbol array parsing/printing
  • water/test/Dialect/Wave/ops.mlir — Tests valid broadcast operations with various shapes and explicit dims
  • water/test/Dialect/Wave/ops-invalid.mlir — Tests error cases: missing source dims, type mismatches, and invalid explicit dims
  • water/test/Dialect/Wave/infer-types.mlir — Tests type inference for broadcast operations
  • water/test/Dialect/Wave/infer-index-exprs.mlir — Tests index expression propagation through broadcast operations
  • water/test/Dialect/Wave/propagate-elements-per-thread.mlir — Tests EPT propagation in forward/backward directions and thread X broadcast behavior


// Test broadcast propagates EPT forward (identity case - no thread X broadcast).
normalform.module [#wave.normal_form<full_types>] {
// CHECK-LABEL: @broadcast_propagation_forward_identity
// expected-warning @+1 {{unused hyperparameter: N}}

Copilot AI Feb 2, 2026


The test includes an expected warning for an unused hyperparameter N, but doesn't verify that the broadcast operation correctly handles the case where a hyperparameter is intentionally unused (broadcast along a dimension not in constraints). Consider adding a test that verifies the operation works correctly when N is actually used in constraints.


martin-luecke left a comment


We currently define the broadcast_dims attribute to represent the broadcast dimensions, but we don't use it consistently in the implementation beyond skipping its inference from the result and argument types when it is present.

In my opinion, it does not add much in its current form. It also does not really enable inferring the result type from the argument type plus this attribute.

Additionally, there might be an issue where a correct op is invalidated when all types are properly resolved but the attribute is still present.

Should we keep this attribute or remove it?

Comment on lines 2010 to 2012
// Note: If partial type inference is needed in the future, an optional
// broadcast_dims attribute could be added (mutually exclusive with
// fully-specified types).

martin-luecke Feb 2, 2026


This is already in the PR so we should either adjust the comment or go ahead and implement propagation based on the attribute, if possible.

ftynse (Author)

This is now implemented

Comment on lines 2019 to 2025
if (!sourceType || !resultType || !sourceType.getFullySpecified() ||
    !resultType.getFullySpecified())
  return false;
Contributor

The semantics of this if statement are actually that we may broadcast along threadX; we just don't know for sure yet.
This should probably return true to stop eagerly propagating as long as we do not know that we definitely don't broadcast along X, right?
If I am correct in thinking this, the function should also be renamed to mayBroadcastAlongThreadX.

ftynse (Author)

Good point. Let me try hardening this; arguably we don't want to do EPT propagation with underspecified types or with vectors, so this condition should never fail.
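To make the thread-X case in this thread concrete, here is a hypothetical IR sketch. The assembly format is as sketched earlier, and the assumption that @M is the dimension bound to thread X (which in practice depends on the constraints) is mine, not from the patch.

// Assume the constraints bind @M to thread X. The source has no @M, so
// forward EPT propagation through the broadcast reports NoChange: the
// result's elements_per_thread must be fixed by its downstream users
// rather than copied from the source.
%b = wave.broadcast %src
    : (!wave.tensor<[@N] of f32, <register>>) -> !wave.tensor<[@N, @M] of f32, <register>>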

// CHECK: wave.read {{.*}} -> vector<8xf32>
%reg = wave.read %mem {elements_per_thread = 8} : (!wave.tensor<[@M] of f32, <global>>) -> !wave.tensor<[@M] of f32, <register>>

// Broadcast along @N (not thread X) - identity propagation, EPT stays 8.
Contributor

Doesn't having #wave.workgroup_constraint<dim = <"M">, workgroup_dim = <x>> in the constraints guarantee that N is never bound to wg_dim 0?

ftynse (Author)

Yes, in my understanding, if no dimension is bound to thread x, we run on one thread.

Contributor

Let's also document that fact on the op

ftynse (Author)

I'll add a "developer remark" because this kind of guts shouldn't affect the user and therefore shouldn't be exposed in the user documentation...


tyb0807 commented Feb 2, 2026

> Should we keep this attribute or remove it?

Based on the discussion on #778, we don't want to add the attribute. I thought I'd removed this completely; it turned out there were still some stashed changes in my local repo. Sorry about this, I just pushed it.


ftynse commented Feb 2, 2026

> In my opinion, it does not seem to add much in its current form. It also does not really enable inferring the result type from arg type + this attribute.

This was also my opinion, but @tyb0807 argued strongly that this should be kept for consistency with reduction operations. The unknown point is whether broadcasting is allowed on non-trailing dimensions. If so, the attribute isn't sufficient and must be complemented by "insertion positions" of each dimension. If not, it is sufficient and we can just append these dimensions to the shape during type inference.

I don't think element type inference is meaningful by itself.
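To illustrate the open point with a hypothetical sketch (same assumed assembly format as above): with append-only semantics the dimension set fully determines the result shape, but a non-trailing broadcast is not expressible from the set alone.

// Append-only: source [@M] plus broadcast dims {@N} uniquely yields [@M, @N].
%a = wave.broadcast %src
    : (!wave.tensor<[@M] of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>
// Non-trailing: producing [@N, @M] from [@M] adds the same dimension @N, so
// the set {@N} cannot distinguish this case from the one above; an insertion
// position per dimension would be required.
%b = wave.broadcast %src
    : (!wave.tensor<[@M] of f32, <register>>) -> !wave.tensor<[@N, @M] of f32, <register>>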


ftynse commented Feb 2, 2026

> Sorry about this, I just pushed it.

You did not, nor should you. This is a separate branch where I took over your commits. I will handle this from here.

Comment on lines +1960 to +1971
if (!sourceType.getFullySpecified() || !resultType.getFullySpecified())
  return success();
Contributor

If we allow missing shapes here, I am wondering if we should maybe do the same for wave.permute?

ftynse (Author)

I think we should be consistent across all shape-related operations. Looking at the frontend, broadcast lists all result dimensions, so we can expect a full result type but a potentially underspecified source type (and I'd rather do inference in MLIR dataflow).
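A hypothetical sketch of what that looks like in IR. The `any` spelling for an unspecified shape is a placeholder assumption, not confirmed syntax; the behavior described in the comments follows the verifier excerpt quoted above, which returns success for underspecified types.

// The result type is fully specified but the source shape is not. The
// verifier only checks what it can (e.g. element types); the dataflow-based
// type inference pass is expected to resolve the source shape later.
%b = wave.broadcast %src
    : (!wave.tensor<any of f32, <register>>) -> !wave.tensor<[@M, @N] of f32, <register>>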

Comment on lines 2018 to 2020
return llvm::any_of(broadcastDims, [&](WaveSymbolAttr sym) {
  return sym == init.threadXDimension;
});
Contributor

Suggested change:
- return llvm::any_of(broadcastDims, [&](WaveSymbolAttr sym) {
-   return sym == init.threadXDimension;
- });
+ return llvm::any_of(broadcastDims, llvm::equal_to(init.threadXDimension));

ftynse force-pushed the users/ftynse/broadcast branch 2 times, most recently from e4f6923 to 166bbe8 on February 4, 2026 at 15:10
ftynse requested a review from martin-luecke on February 4, 2026 at 15:10

martin-luecke left a comment


Good to go; this could be improved with tiny additions to the documentation of wave.broadcast.


tyb0807 and others added 9 commits February 6, 2026 10:21
Add BroadcastOp to the Wave dialect for broadcasting a tensor to a larger shape by replicating values along specified dimensions (full design notes in the PR description above).

Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Signed-off-by: tyb0807 <sontuan.vu@amd.com>
* the list of dimensions is optionally present
* drop the verification that is always true
* verify element types
The broadcast dimensions are now always inferred from the difference
between the result shape and source shape.

Signed-off-by: tyb0807 <sontuan.vu@amd.com>
This reverts commit f5bffd16c05ffd63c826236ed06a56a8260f2d64.

Signed-off-by: tyb0807 <sontuan.vu@amd.com>
Signed-off-by: Alex Zinenko <git@ozinenko.com>
Signed-off-by: Alex Zinenko <git@ozinenko.com>
Signed-off-by: Alex Zinenko <git@ozinenko.com>
Signed-off-by: Alex Zinenko <git@ozinenko.com>
ftynse force-pushed the users/ftynse/broadcast branch from 02a3819 to bb1d198 on February 6, 2026 at 09:29
ftynse merged commit d0932b8 into main on Feb 6, 2026
15 checks passed
ftynse deleted the users/ftynse/broadcast branch on February 6, 2026 at 09:54


Development

Successfully merging this pull request may close these issues.

[water] Implement wave.broadcast

4 participants