
Relax GQA seqlens_k shape validation for backward compat with older models #28259

Merged

vraspar merged 5 commits into main from vraspar/fix-gqa-seqlens-k-shape-compat on May 1, 2026
Conversation

@vraspar
Contributor

@vraspar vraspar commented Apr 29, 2026

Problem

PR #28031 fixed an out-of-bounds GEMM security bug triggered by crafted seqlens_k by changing && to || in the shape validation in group_query_attention_helper.h. This correctly enforces the spec (a 1D tensor of shape (batch_size)) but breaks models (e.g. qwen3-0.6b, qwen3-1.7b) whose builder.py emits seqlens_k with shape [1,1] instead of [1].

Fix

Relax the shape check to accept shapes with unit dimensions around the batch axis. The validation rule is:

  1. seqlens_k must be at least 1D (scalars are rejected)
  2. Total element count must equal batch_size
  3. Each dimension must be 1 or batch_size (e.g. accepts [B], [B,1], [1,B] but rejects [2,2] for B=4)

Also fixes the same latent &&/|| bug in the JS/WebGPU EP (group-query-attention.ts).

Security: The per-element value bounds checks in Compute() are unchanged -- the OOB fix from #28031 is fully preserved.

Changes

  • group_query_attention_helper.h -- scalar rejection + element-count shape check (shared by CPU, CUDA, WebGPU EPs)
  • group-query-attention.ts -- same fix for the JS WebGPU path
  • group_query_attention_op_test.cc -- tests for [1,1] compat, multi-batch [2,1] compat, trailing-batch [1,2] compat, scalar rejection, wrong-count rejection, and invalid factored shape rejection

Relax GQA seqlens_k shape validation for backward compat with older models

PR #28031 tightened seqlens_k shape validation (&&->||), correctly
rejecting non-1D tensors per spec. However, older model builders emit
seqlens_k with shape [1,1] instead of [1], breaking HuggingFace LLMs
(qwen3-0.6b, qwen3-1.7b).

Relax shape check to allow unit dimensions around the batch axis: each
dim must be 1 or batch_size (accepts [B], [B,1], [1,1] but rejects
[2,2] for B=4). Also fixes the same latent && bug in JS/WebGPU EP.

Value bounds checks in Compute() are unchanged.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@vraspar vraspar force-pushed the vraspar/fix-gqa-seqlens-k-shape-compat branch from ba7d3a2 to c0b4397 on April 29, 2026 05:38
@vraspar
Contributor Author

vraspar commented Apr 29, 2026

Sorry about the force-push — Copilot CLI rewrote the branch and lost the incremental diff history.

Addressed all 5 comments:

  • group_query_attention_helper.h:267 — Tightened the factored-shape check so each dim must be 1 or batch_size (rejects e.g. [2,2] for B=4). Added SeqlensKInvalidFactoredShape test to cover it.
  • group-query-attention.ts:203 — Aligned error messages between JS and C++ so they match.
  • group-query-attention.ts:197 — Removed [1, 1] from the comment in both C++ and JS. Now just shows [B, 1] instead of [B].
  • group_query_attention_op_test.cc:267 — Added a comment explaining the loose tolerance: these tests validate shape acceptance, not numerical correctness. Agree exact-value tests can be a follow-up.
  • group_query_attention_op_test.cc:237 — Extended RunGQASeqlensKTest with an optional seqlens_k_shape param. All 5 shape tests use the helper now, net -73 lines.

Add JS/WebGPU test for [1,1] seqlens_k shape (the exact qwen3 regression
case) and C++ test for trailing batch dim shape {1,B}.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Contributor

Copilot AI left a comment

Pull request overview

Relaxes seqlens_k shape validation for GroupQueryAttention to restore backward compatibility with older model exporters that emit extra unit dimensions (e.g., [B,1]), while keeping the value-range checks that prevent OOB access.

Changes:

  • Update C++ CheckInputs() validation to accept seqlens_k shapes with batch_size total elements (with additional per-dimension constraints).
  • Apply equivalent validation updates in the JS/WebGPU validateInputs() path.
  • Extend CPU and JS test coverage with legacy-shape acceptance and wrong-shape rejection cases.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

File changes:

  • onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h -- Updates seqlens_k shape validation and error messages in the shared helper.
  • js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts -- Aligns WebGPU input validation with the relaxed seqlens_k shape rules.
  • onnxruntime/test/contrib_ops/group_query_attention_op_test.cc -- Adds regression tests for legacy 2D shapes and invalid element-count/shape cases.
  • js/web/test/data/ops/group-query-attention.jsonc -- Adds a Web test case covering legacy [1,1] seqlens_k shape acceptance.


edgchen1
edgchen1 previously approved these changes Apr 29, 2026
Address review comments:
- Reject rank-0 (scalar) seqlens_k in both C++ and JS validation
- Use std::optional<vector> for test helper seqlens_k_shape param
- Add SeqlensKScalarRejected test case

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@vraspar
Contributor Author

vraspar commented Apr 29, 2026

Addressed remaining comments:

  • helper.h:265 (edgchen1 + Copilot) -- Added a NumDimensions() == 0 rejection so a scalar seqlens_k is no longer silently accepted when batch_size == 1. The same check was added in the JS path (dims.length === 0).
  • test.cc:26 (edgchen1) -- Changed the seqlens_k_shape param to std::optional<std::vector<int64_t>> so an empty {} isn't confused with a scalar shape. All call sites wrapped with an explicit std::vector<int64_t>{...}.
  • helper.h:277 (Copilot) -- Updated the PR description to reflect the full validation rule (at least 1D + element count + per-dim constraint).
  • Added a SeqlensKScalarRejected test to cover the new scalar rejection path.

edgchen1
edgchen1 previously approved these changes Apr 29, 2026
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
@vraspar
Contributor Author

vraspar commented Apr 29, 2026

Validated with https://huggingface.co/schmuell/Qwen3-1.7B

edgchen1
edgchen1 previously approved these changes Apr 29, 2026
@ankitm3k
Contributor

@vraspar your PR #28031 broke this functionality; I have tested with open source models too. FYI intel#1067

@edgchen1
Contributor

looks like the CI build is complaining about JS formatting.

Error: Following source files are not formatted: (did you run "npm run format"?)
js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts

@vraspar
Contributor Author

vraspar commented May 1, 2026

Thanks @edgchen1, fixed the linting issue.

@vraspar vraspar merged commit 60ce9cc into main on May 1, 2026 (89 of 91 checks passed)
@vraspar vraspar deleted the vraspar/fix-gqa-seqlens-k-shape-compat branch on May 1, 2026 22:35