[WebGPU] QKV and MLP fusions for Qwen3 #28280
Open
hariharans29 wants to merge 25 commits into main from
Pull request overview
This PR adds WebGPU-focused fused operators and optimizer passes for decoder-style MatMulNBits patterns (MLP gate/up and QKV projections), along with tests and a microbenchmark to evaluate decode performance/correctness.
Changes:
- Introduces new contrib ops MatMulNBitsMlp and MatMulNBitsQkv (schemas + WebGPU kernels + WGSL templates).
- Adds graph transformers MatMulNBitsMlpFusion / MatMulNBitsQkvFusion and corresponding optimizer tests.
- Improves WebGPU runtime support (graph-capture buffer manager activation, queue-idle wait helper, better shader compilation diagnostics) and adds a decode microbenchmark.
Reviewed changes
Copilot reviewed 33 out of 33 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| onnxruntime/test/optimizer/matmul_nbits_qkv_fusion_test.cc | New unit tests validating QKV fusion and output contracts on WebGPU. |
| onnxruntime/test/optimizer/matmul_nbits_mlp_fusion_test.cc | New unit tests validating MLP fusion (simplified/skip + passthrough) on WebGPU. |
| onnxruntime/test/optimizer/graph_transform_utils_test.cc | Minor formatting-only tweak (blank line). |
| onnxruntime/test/onnx/microbenchmark/webgpu_matmul_nbits_decode.cc | New benchmark harness for fused/unfused decode paths on WebGPU. |
| onnxruntime/test/onnx/microbenchmark/main.cc | Adjusts benchmark env logging severity. |
| onnxruntime/core/session/ort_version_check.h | Makes version parsing consteval-friendly with a macro fallback. |
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.h | Tracks when graph-capture buffer manager is active. |
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc | Lazily creates/activates graph buffer manager for capture; allocator uses dynamic buffer manager getter. |
| onnxruntime/core/providers/webgpu/webgpu_context.h | Adds WaitForQueueIdle() declaration. |
| onnxruntime/core/providers/webgpu/webgpu_context.cc | Implements WaitForQueueIdle() using OnSubmittedWorkDone. |
| onnxruntime/core/providers/webgpu/program_manager.cc | Enhances pipeline build failures with shader compilation diagnostics. |
| onnxruntime/core/providers/webgpu/compute_context.h | Adds FlushAndWait() convenience for flushing + waiting on queue idle. |
| onnxruntime/core/providers/webgpu/allocator.h | Adds allocator ctor that accepts a buffer-manager getter function. |
| onnxruntime/core/providers/webgpu/allocator.cc | Implements getter-based allocator to support switching buffer managers. |
| onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.h | New transformer declaration for QKV fusion. |
| onnxruntime/core/optimizer/matmul_nbits_qkv_fusion.cc | New transformer implementation for QKV fusion. |
| onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.h | New transformer declaration for MLP fusion. |
| onnxruntime/core/optimizer/matmul_nbits_mlp_fusion.cc | New transformer implementation for MLP fusion. |
| onnxruntime/core/optimizer/graph_transformer_utils.cc | Registers the new fusion transformers. |
| onnxruntime/core/graph/contrib_ops/contrib_defs.cc | Adds contrib operator schemas/docs for MatMulNBitsMlp and MatMulNBitsQkv. |
| onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc | Registers WebGPU kernels for the new fused ops. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.wgsl.template | New WGSL template implementing fused QKV decode kernel. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.h | New WebGPU kernel wrapper for MatMulNBitsQkv. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_qkv.cc | New WebGPU kernel implementation for MatMulNBitsQkv. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp_wide_tile_m1.wgsl.template | New WGSL template for an MLP wide-tile variant. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.wgsl.template | New WGSL template implementing fused MLP (optionally with norm/skip/passthrough). |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.h | New WebGPU kernel wrapper for MatMulNBitsMlp. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_mlp.cc | New WebGPU kernel implementation for MatMulNBitsMlp. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.h | Adds declarations for “would apply” dispatch-selection helpers and shared constants. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_common.cc | Implements the new dispatch-selection helpers. |
| onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc | Refactors path selection to use the new “would apply” helpers. |
| onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_mlp.wgsl.template | Adds WGSL template for DP4A MLP path. |
| cmake/onnxruntime_unittests.cmake | Wires the new WebGPU decode benchmark into the benchmark target sources. |
…shader diagnostics

These changes are kept on hari/webgpu_perf_1_full locally. The lazy buffer-mgr fix is being submitted as a separate PR (branch hari/webgpu_graph_capture_buffer_fix) because it is an independent correctness fix for a pre-existing latent bug, exposed but not introduced by these fusions.
This template file was added speculatively but is not referenced by any kernel, include, or build rule. Removing to keep the PR clean.
Pull request overview
Copilot reviewed 20 out of 20 changed files in this pull request and generated 5 comments.
The shared-EP path through TransformerTester triggers a SEH 0xC0000005 in CI when the EP outlives a per-session profiler whose pointer is still cached on the EP. A separate fix to the WebGPU EP's session_profiler_ lifetime is in flight; meanwhile, switch the 8 MatMulNBits MLP and QKV WebGPU fusion-vs-unfused tests to a small RunWebGpuFusionTransformerTest helper that creates a fresh execution provider per session via a factory lambda. Production code is unchanged.
Description
Summary
Adds two WebGPU-only graph fusions and the contrib ops they target, plus a small refactor of the existing MatMulNBits dispatch logic so the new fused kernels can share its predicates.
- MatMulNBitsMlp op + kernel: contrib_ops/webgpu/quantization/matmul_nbits_mlp.{cc,h}, *.wgsl.template (3). Fuses (Skip)SimplifiedLayerNormalization + two MatMulNBits projections (gate, up) + optional biases + Sigmoid/Mul (SiLU) + element-wise Mul. Single dispatch instead of 5–7.
- MatMulNBitsQkv op + kernel: contrib_ops/webgpu/quantization/matmul_nbits_qkv.{cc,h}, *.wgsl.template. Fuses (Skip)SimplifiedLayerNormalization + three MatMulNBits projections (Q, K, V) sharing the same input. Single dispatch instead of 4.
- core/graph/contrib_ops/contrib_defs.cc: MatMulNBitsMlp and MatMulNBitsQkv contrib op schemas (kMSDomain, opset 1).
- core/optimizer/matmul_nbits_{mlp,qkv}_fusion.{cc,h}: fusion transformers, registered in graph_transformer_utils.cc.
- contrib_ops/webgpu/quantization/matmul_nbits_common.{cc,h} + matmul_nbits.cc: shared path-selection predicates for the MatMulNBits path.
- test/optimizer/matmul_nbits_{mlp,qkv}_fusion_test.cc, graph_transform_utils_test.cc: unit tests for the new fusions.

Motivation and Context
~25–30% decode TPS improvement on the WebGPU + D3D backend on Windows. GPU used: RTX 5060 Ti, model: Qwen3-1.7B.
BEFORE (95 decode TPS): main branch
[screenshot: benchmark output on the main branch]

AFTER (120+ decode TPS): PR branch
[screenshot: benchmark output on the PR branch]