Adapt SDPA diagnostics to OptionalRef #3704
Open
drisspg wants to merge 1 commit into
Pull request overview
This PR updates the XPU mem-efficient SDPA “can_use” gating helper to match upstream PyTorch’s new diagnostics plumbing by threading c10::OptionalRef<sdp::SDPDiagnostics> through backend constraint checks, while preserving the legacy bool debug wrapper overload for compatibility.
Skill files read: .github/skills/xpu-ops-pr-review/SKILL.md.
Changes:
- Add a new `can_use_mem_efficient_attention(sdp_params, c10::OptionalRef<SDPDiagnostics>)` overload and route constraint checks through it.
- Convert local XPU constraints (`check_all_tensors_on_device`, `check_head_dim`) to use diagnostics reporting instead of `TORCH_WARN` gated on `debug`.
- Keep the existing `bool debug` overload as a wrapper that constructs/omits diagnostics.
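The three changes above can be illustrated with a minimal, self-contained sketch. It does not use PyTorch headers: `OptionalRef`, `SDPDiagnostics`, `SdpParams`, `report_failure`, and the head-dimension limit are all hypothetical stand-ins invented for illustration, and the real `c10::OptionalRef` and `sdp::SDPDiagnostics` types may look quite different.

```cpp
#include <array>
#include <cassert>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for c10::OptionalRef<T>: a nullable, non-owning reference.
template <typename T>
class OptionalRef {
 public:
  OptionalRef() = default;
  OptionalRef(T& value) : ptr_(&value) {}
  bool has_value() const { return ptr_ != nullptr; }
  T& get() const {
    assert(ptr_ != nullptr);
    return *ptr_;
  }

 private:
  T* ptr_ = nullptr;
};

// Hypothetical stand-ins for sdp::SDPDiagnostics and sdp::sdp_params.
struct SDPDiagnostics {
  std::vector<std::string> failures;
};
struct SdpParams {
  bool all_on_xpu;
  int head_dim;
};

// Records a message only when the caller supplied a diagnostics sink.
template <typename... Args>
void report_failure(OptionalRef<SDPDiagnostics> diagnostics, const Args&... args) {
  if (!diagnostics.has_value()) return;
  std::ostringstream message;
  // Pack expansion that streams every argument in order (C++11 compatible).
  const int expand[] = {0, ((message << args), 0)...};
  (void)expand;
  diagnostics.get().failures.push_back(message.str());
}

bool check_all_tensors_on_device(const SdpParams& params,
                                 OptionalRef<SDPDiagnostics> diagnostics) {
  if (!params.all_on_xpu) {
    report_failure(diagnostics, "All tensors need to be on the XPU device.");
    return false;
  }
  return true;
}

bool check_head_dim(const SdpParams& params,
                    OptionalRef<SDPDiagnostics> diagnostics) {
  const int kMaxHeadDim = 256;  // illustrative limit, not the real XPU constraint
  if (params.head_dim > kMaxHeadDim) {
    report_failure(diagnostics, "head_dim ", params.head_dim,
                   " exceeds the limit of ", kMaxHeadDim);
    return false;
  }
  return true;
}

// Diagnostics-based overload: every constraint receives the same optional sink.
bool can_use_mem_efficient_attention(const SdpParams& params,
                                     OptionalRef<SDPDiagnostics> diagnostics) {
  using Constraint = bool (*)(const SdpParams&, OptionalRef<SDPDiagnostics>);
  const std::array<Constraint, 2> constraints = {
      {check_all_tensors_on_device, check_head_dim}};
  for (Constraint constraint : constraints) {
    if (!constraint(params, diagnostics)) return false;
  }
  return true;
}

// Legacy overload: builds a sink when debug is set, omits it otherwise.
bool can_use_mem_efficient_attention(const SdpParams& params, bool debug) {
  if (!debug) {
    return can_use_mem_efficient_attention(params, OptionalRef<SDPDiagnostics>());
  }
  SDPDiagnostics diagnostics;
  const bool ok = can_use_mem_efficient_attention(
      params, OptionalRef<SDPDiagnostics>(diagnostics));
  for (const std::string& failure : diagnostics.failures) {
    std::cerr << failure << '\n';
  }
  return ok;
}
```

In this sketch a caller that wants structured failure reasons passes its own `SDPDiagnostics` object, while existing call sites that pass `true` or `false` keep compiling and resolve to the wrapper.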
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| src/ATen/native/transformers/SDPUtils.h | Declares the new diagnostics-based overload alongside the legacy bool debug API. |
| src/ATen/native/transformers/SDPUtils.cpp | Plumbs c10::OptionalRef<SDPDiagnostics> through constraint arrays and adds a debug-wrapper implementation. |
Comment on lines +29 to +33

    "All tensors need to be on cuda device. Got query on device: ",
    params.query.device(),
    ", key on device: ",
    params.key.device(),
    ", value on device: ",

Comment on lines +27 to +34

    report_failure(
        diagnostics,
        "All tensors need to be on cuda device. Got query on device: ",
        params.query.device(),
        ", key on device: ",
        params.key.device(),
        ", value on device: ",
        params.value.device());

    sdp_params const& params,
    c10::OptionalRef<SDPDiagnostics> diagnostics) {
      // Check that all tensors are on the GPU device
      // This should be handled by the stub dispatch, but whe call
920a5a5 to 0a4f51b
Fixes pytorch/pytorch#184307
PyTorch PR #184307 updates SDPA backend diagnostics to pass c10::OptionalRef<sdp::SDPDiagnostics> through backend constraint checks instead of bool debug flags. Stage the XPU mem-efficient SDPA helper behind TORCH_XPU_OPS_USE_SDPA_OPTIONALREF_DIAGNOSTICS so current PyTorch builds keep using the bool-debug path while the PyTorch PR can opt into the new diagnostics plumbing. The bool-debug wrapper remains available for compatibility.
Test: none (build compatibility fix for PyTorch PR #184307; validated both current PyTorch-main bool path and macro-enabled OptionalRef path with compile-only checks against local PyTorch checkouts)
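The macro staging described above might look roughly like the following sketch. Only the macro name TORCH_XPU_OPS_USE_SDPA_OPTIONALREF_DIAGNOSTICS comes from the description. Everything else is an assumption made for illustration: `SdpParams`, `SDPDiagnostics`, the head-dimension limit, and the use of a nullable pointer in place of `c10::OptionalRef` are hypothetical stand-ins, and the actual guard placement in SDPUtils.h/.cpp may differ.

```cpp
#include <cassert>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for sdp::sdp_params.
struct SdpParams {
  int head_dim;
};

const int kMaxHeadDim = 256;  // illustrative limit, not the real XPU constraint

#ifdef TORCH_XPU_OPS_USE_SDPA_OPTIONALREF_DIAGNOSTICS
// Opt-in path: only compiled when the build defines the macro.
// Hypothetical stand-in for sdp::SDPDiagnostics.
struct SDPDiagnostics {
  std::vector<std::string> failures;
};

// A nullable pointer stands in for c10::OptionalRef<SDPDiagnostics> here.
bool can_use_mem_efficient_attention(const SdpParams& params,
                                     SDPDiagnostics* diagnostics) {
  assert(params.head_dim > 0);
  if (params.head_dim > kMaxHeadDim) {
    if (diagnostics != nullptr) {
      diagnostics->failures.push_back("head_dim is too large");
    }
    return false;
  }
  return true;
}

// Compatibility wrapper: same bool-debug signature, new plumbing underneath.
bool can_use_mem_efficient_attention(const SdpParams& params, bool debug) {
  SDPDiagnostics diagnostics;
  const bool ok =
      can_use_mem_efficient_attention(params, debug ? &diagnostics : nullptr);
  for (const std::string& failure : diagnostics.failures) {
    std::cerr << failure << '\n';
  }
  return ok;
}
#else
// Default path: builds without the macro keep the bool-debug behaviour.
bool can_use_mem_efficient_attention(const SdpParams& params, bool debug) {
  assert(params.head_dim > 0);
  if (params.head_dim > kMaxHeadDim) {
    if (debug) {
      std::cerr << "head_dim is too large\n";
    }
    return false;
  }
  return true;
}
#endif
```

Under these assumptions, a build that passes `-DTORCH_XPU_OPS_USE_SDPA_OPTIONALREF_DIAGNOSTICS` to the compiler gets both overloads, while a build without it sees only the bool-debug function, so call sites that pass a bool compile in either configuration.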