Adapt SDPA diagnostics to OptionalRef #3704

Open
drisspg wants to merge 1 commit into intel:main from drisspg:agent/sdpa-diagnostics-optionalref
@drisspg drisspg commented May 19, 2026

Fixes pytorch/pytorch#184307

PyTorch PR #184307 updates SDPA backend diagnostics to pass c10::OptionalRef<sdp::SDPDiagnostics> through backend constraint checks instead of bool debug flags. Stage the XPU mem-efficient SDPA helper behind TORCH_XPU_OPS_USE_SDPA_OPTIONALREF_DIAGNOSTICS so current PyTorch builds keep using the bool-debug path while the PyTorch PR can opt into the new diagnostics plumbing. The bool-debug wrapper remains available for compatibility.
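For illustration, a minimal sketch of this kind of macro staging. Only the macro name comes from this PR. `SdpParams`, `Diagnostics`, `kHypotheticalMaxHeadDim`, and the check body are hypothetical stand-ins, and a nullable pointer stands in for c10::OptionalRef; the actual code in SDPUtils.cpp may be organized differently.

```cpp
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-ins; the real types are sdp::sdp_params and sdp::SDPDiagnostics.
struct SdpParams {
  int head_dim;
};
struct Diagnostics {
  std::vector<std::string> reasons;
};

// Hypothetical limit used only by this sketch, not the real XPU constraint.
constexpr int kHypotheticalMaxHeadDim = 256;

#ifdef TORCH_XPU_OPS_USE_SDPA_OPTIONALREF_DIAGNOSTICS
// Opt-in path: failures are recorded in a diagnostics object when one is supplied.
// A nullable pointer stands in for c10::OptionalRef<SDPDiagnostics> here.
inline bool check_head_dim(const SdpParams& params, Diagnostics* diagnostics) {
  if (params.head_dim <= kHypotheticalMaxHeadDim) {
    return true;
  }
  if (diagnostics != nullptr) {
    diagnostics->reasons.push_back("head_dim exceeds the sketch limit");
  }
  return false;
}
constexpr bool kUsesOptionalRefPath = true;
#else
// Default path: behavior stays on the legacy bool debug flag.
inline bool check_head_dim(const SdpParams& params, bool debug) {
  if (params.head_dim <= kHypotheticalMaxHeadDim) {
    return true;
  }
  if (debug) {
    std::cerr << "warning: head_dim exceeds the sketch limit\n";
  }
  return false;
}
constexpr bool kUsesOptionalRefPath = false;
#endif
```

Without the macro defined, callers keep compiling against the bool signature; defining it at build time (for example with `-DTORCH_XPU_OPS_USE_SDPA_OPTIONALREF_DIAGNOSTICS`) switches the same translation unit to the diagnostics signature.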

Test: none (build compatibility fix for PyTorch PR #184307; validated both current PyTorch-main bool path and macro-enabled OptionalRef path with compile-only checks against local PyTorch checkouts)

Copilot AI review requested due to automatic review settings May 19, 2026 02:49

Copilot AI left a comment

Pull request overview

This PR updates the XPU mem-efficient SDPA “can_use” gating helper to match upstream PyTorch’s new diagnostics plumbing by threading c10::OptionalRef<sdp::SDPDiagnostics> through backend constraint checks, while preserving the legacy bool debug wrapper overload for compatibility.

Skill files read: .github/skills/xpu-ops-pr-review/SKILL.md.

Changes:

  • Add a new can_use_mem_efficient_attention(sdp_params, c10::OptionalRef<SDPDiagnostics>) overload and route constraint checks through it.
  • Convert local XPU constraints (check_all_tensors_on_device, check_head_dim) to use diagnostics reporting instead of TORCH_WARN gated on debug.
  • Keep the existing bool debug overload as a wrapper that constructs/omits diagnostics.
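The overload-plus-wrapper shape described in the list above could look roughly like the following sketch. Everything except the function name `can_use_mem_efficient_attention` is a stand-in: `OptionalRef` here is a tiny local template rather than the real c10::OptionalRef, `SdpParams` and `Diagnostics` are hypothetical, the head-dim limit is invented, and printing collected reasons to stderr in the wrapper is an assumption about how a debug wrapper might surface them.

```cpp
#include <cassert>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-ins for sdp::sdp_params and sdp::SDPDiagnostics.
struct SdpParams {
  bool on_device;
  int head_dim;
};
struct Diagnostics {
  std::vector<std::string> reasons;
};

// Minimal local stand-in for a non-owning, nullable reference.
// It does not claim to mirror the real c10::OptionalRef API.
template <typename T>
class OptionalRef {
 public:
  OptionalRef() = default;
  OptionalRef(T& value) : ptr_(&value) {}
  bool has_value() const { return ptr_ != nullptr; }
  T& value() const {
    assert(ptr_ != nullptr);
    return *ptr_;
  }

 private:
  T* ptr_ = nullptr;
};

using Constraint = bool (*)(const SdpParams&, OptionalRef<Diagnostics>);

bool check_all_tensors_on_device(const SdpParams& params, OptionalRef<Diagnostics> diagnostics) {
  if (params.on_device) return true;
  if (diagnostics.has_value()) {
    diagnostics.value().reasons.push_back("not all tensors are on the expected device");
  }
  return false;
}

bool check_head_dim(const SdpParams& params, OptionalRef<Diagnostics> diagnostics) {
  const int hypothetical_max_head_dim = 256;  // invented for this sketch
  if (params.head_dim <= hypothetical_max_head_dim) return true;
  if (diagnostics.has_value()) {
    diagnostics.value().reasons.push_back("head_dim exceeds the sketch limit");
  }
  return false;
}

// New overload: every constraint receives the same optional diagnostics sink.
bool can_use_mem_efficient_attention(const SdpParams& params, OptionalRef<Diagnostics> diagnostics) {
  const Constraint constraints[] = {check_all_tensors_on_device, check_head_dim};
  for (Constraint constraint : constraints) {
    if (!constraint(params, diagnostics)) return false;
  }
  return true;
}

// Legacy overload: builds a diagnostics object only when debug is requested.
bool can_use_mem_efficient_attention(const SdpParams& params, bool debug) {
  if (!debug) {
    return can_use_mem_efficient_attention(params, OptionalRef<Diagnostics>());
  }
  Diagnostics collected;
  const bool ok = can_use_mem_efficient_attention(params, OptionalRef<Diagnostics>(collected));
  for (const std::string& reason : collected.reasons) {
    std::cerr << "warning: " << reason << "\n";
  }
  return ok;
}
```

Because the constraints take the sink by value as a nullable reference, the non-debug path pays only a null check per failed constraint and never builds a message.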

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| src/ATen/native/transformers/SDPUtils.h | Declares the new diagnostics-based overload alongside the legacy bool debug API. |
| src/ATen/native/transformers/SDPUtils.cpp | Plumbs c10::OptionalRef<SDPDiagnostics> through constraint arrays and adds a debug-wrapper implementation. |

Comment on lines +29 to +33
"All tensors need to be on cuda device. Got query on device: ",
params.query.device(),
", key on device: ",
params.key.device(),
", value on device: ",
Comment on lines +27 to +34
report_failure(
diagnostics,
"All tensors need to be on cuda device. Got query on device: ",
params.query.device(),
", key on device: ",
params.key.device(),
", value on device: ",
params.value.device());
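
As a reading aid for the call quoted above, here is one way a variadic `report_failure` helper could be written. This is a sketch, not the helper from the PR: the `Diagnostics` struct is hypothetical, a nullable pointer stands in for c10::OptionalRef, and the real helper's signature and storage may differ.

```cpp
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for sdp::SDPDiagnostics.
struct Diagnostics {
  std::vector<std::string> reasons;
};

// Streams every argument into one message and records it, but only when a
// diagnostics sink was supplied. With a null sink no message is formatted.
template <typename... Args>
void report_failure(Diagnostics* diagnostics, const Args&... args) {
  if (diagnostics == nullptr) {
    return;
  }
  std::ostringstream message;
  // Pack expansion that works from C++11 on; a C++17 fold expression
  // (message << ... << args) would be equivalent.
  int expand[] = {0, ((message << args), 0)...};
  (void)expand;
  diagnostics->reasons.push_back(message.str());
}
```

A call such as `report_failure(&diag, "query on device: ", "cpu")` appends the single string `query on device: cpu` to `diag.reasons`, while `report_failure(nullptr, ...)` does nothing.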
sdp_params const& params,
c10::OptionalRef<SDPDiagnostics> diagnostics) {
// Check that all tensors are on the GPU device
// This should be handled by the stub dispatch, but whe call
@drisspg drisspg force-pushed the agent/sdpa-diagnostics-optionalref branch from 920a5a5 to 0a4f51b Compare May 19, 2026 02:57