
Split subgroup topk kernel per-K for parallel AOT compilation #3683

Open
jianyizh wants to merge 3 commits into main from jianyi/sbtopk-split-compile

Conversation

@jianyizh
Contributor

Summary

Split TensorTopKSbtopkKernel.cpp into per-K compilation units so each generates a separate .so that can be AOT-compiled in parallel, reducing build time. Follows the same pattern as #3652 (FlashAttention per-headdim split).

Changes

  • TensorTopKSbtopkKernelImpl.h (new) — shared header with SubgroupTopKFunctor, sbtopk_launch_impl, and sbtopk_launch_vec_dispatch templates
  • TensorTopKSbtopkKernel_k{1,2,4,8,16}.cpp (new) — per-K instantiations, each producing its own .so with AOT kernel code
  • TensorTopKSbtopkKernel.cpp (modified) — dispatch-only, routes to per-K functions; no kernel code, tiny .so
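
The layout described above can be sketched in a single file. This is only an illustration of the pattern, not the PR's code: every `*_sketch` name is hypothetical, only K=1 and K=2 are spelled out, and `constexpr` is used purely so the routing can be checked at compile time (the real launchers would submit SYCL kernels from separate translation units).

```cpp
// Single-file sketch; comments mark the file each piece would live in.
// All *_sketch names are hypothetical stand-ins for the real functions.

// --- TensorTopKSbtopkKernelImpl.h: shared template, included by per-K files ---
template <int K>
constexpr int sbtopk_launch_impl_sketch(int nelements) {
  static_assert(K >= 1, "K must be positive");
  // Stand-in for the kernel launch: report how many results would be written.
  return K < nelements ? K : nelements;
}

// --- TensorTopKSbtopkKernel_k1.cpp: the only file that instantiates K=1 ---
constexpr int sbtopk_k1_launch_sketch(int nelements) {
  return sbtopk_launch_impl_sketch<1>(nelements);
}

// --- TensorTopKSbtopkKernel_k2.cpp: the only file that instantiates K=2 ---
constexpr int sbtopk_k2_launch_sketch(int nelements) {
  return sbtopk_launch_impl_sketch<2>(nelements);
}

// --- TensorTopKSbtopkKernel.cpp: dispatch only, no template instantiation ---
constexpr int sbtopk_dispatch_sketch(int k, int nelements) {
  switch (k) {
    case 1:
      return sbtopk_k1_launch_sketch(nelements);
    case 2:
      return sbtopk_k2_launch_sketch(nelements);
    default:
      // Not handled in this sketch; other K values would add cases of the same shape.
      return -1;
  }
}
```

Because the dispatcher only calls non-template entry points, the heavy template code is compiled once per K file, and those files can be built independently of each other.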

Motivation

The subgroup topk kernel has a large number of template instantiations: K(5) × scalar_t(~6) × IndexT(2) × Largest(2) × VEC_SIZE(up to 4). Keeping all instantiations in a single .cpp creates one monolithic AOT compilation that cannot be parallelized. Splitting by K allows 5 independent AOT compilations to run in parallel.
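
As a rough worked estimate, the factors above can be multiplied out. The numbers below take the description at face value; reading "~6" as exactly 6 and "up to 4" as at most four VEC_SIZE variants are both assumptions, so the results are upper bounds rather than measured counts.

```cpp
// Upper-bound estimate of template instantiations, using the factors quoted
// in the PR description. The constant names are hypothetical.
constexpr int kNumK = 5;            // K in {1, 2, 4, 8, 16}
constexpr int kNumScalarTypes = 6;  // "~6" in the description, taken as 6
constexpr int kNumIndexTypes = 2;   // IndexT
constexpr int kNumLargest = 2;      // Largest = true / false
constexpr int kMaxVecSizes = 4;     // assumption: at most four VEC_SIZE variants

// Everything in one translation unit.
constexpr int kTotalUpperBound =
    kNumK * kNumScalarTypes * kNumIndexTypes * kNumLargest * kMaxVecSizes;

// Share handled by each per-K translation unit after the split.
constexpr int kPerKUpperBound = kTotalUpperBound / kNumK;

static_assert(kTotalUpperBound == 480, "5 * 6 * 2 * 2 * 4");
static_assert(kPerKUpperBound == 96, "480 / 5");
```

Under these assumptions, each per-K unit carries at most about a fifth of the original instantiations.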

Testing

  • Pure refactoring: no functional changes to kernel logic or dispatch behavior
  • Verified 180/180 accuracy tests pass on B580 with the equivalent split on the single-workgroup topk branch

Copilot AI review requested due to automatic review settings May 15, 2026 04:14
Contributor

Copilot AI left a comment


Pull request overview

Read .github/skills/xpu-ops-pr-review/SKILL.md and its referenced review notes/checklists.

This PR refactors subgroup top-k XPU kernel instantiations into per-K compilation units to enable more parallel AOT compilation while keeping the original dispatch path in TensorTopKSbtopkKernel.cpp.

Changes:

  • Moves shared subgroup top-k functor and launch templates into TensorTopKSbtopkKernelImpl.h.
  • Adds K-specific launch files for K=1,2,4,8,16.
  • Leaves TensorTopKSbtopkKernel.cpp as dispatch logic that routes to the per-K launchers.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.

Show a summary per file
| File | Description |
| --- | --- |
| src/ATen/native/xpu/sycl/TensorTopKSbtopkKernelImpl.h | Shared subgroup top-k template implementation and per-K launcher declarations. |
| src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp | Dispatch-only routing to K-specific launch units. |
| src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel_k1.cpp | K=1 typed dispatch and launch instantiations. |
| src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel_k2.cpp | K=2 typed dispatch and launch instantiations. |
| src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel_k4.cpp | K=4 typed dispatch and launch instantiations. |
| src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel_k8.cpp | K=8 typed dispatch and launch instantiations. |
| src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel_k16.cpp | K=16 typed dispatch and launch instantiations. |

Comment thread src/ATen/native/xpu/sycl/TensorTopKSbtopkKernelImpl.h
Comment thread src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel_k1.cpp
Comment thread src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel_k2.cpp
Comment thread src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel_k4.cpp
Comment thread src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel_k8.cpp
Comment thread src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel_k16.cpp Outdated
Copilot AI review requested due to automatic review settings May 15, 2026 04:30
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.

Comment on lines +36 to +43
sbtopk_k1_launch(
    self,
    nsegments,
    nelements,
    static_cast<int>(k),
    largest,
    values,
    indices);
@jianyizh jianyizh requested review from CuiYifeng and chuanqi129 May 15, 2026 05:02
@chuanqi129
Contributor

@copilot New UT failures detected in op_ut tests (24 total). Please check these new failures and analyze whether they are caused by the PR changes.

Important: Do NOT update the plan in the PR description directly. Use reply comments to update the status.

Job log: https://github.com/intel/torch-xpu-ops/actions/runs/25900306416
Full new failure report: https://github.com/intel/torch-xpu-ops/actions/runs/25900306416/artifacts/7032558883

New failures:

| Category | Class name | Test name | Status | Message |
| --- | --- | --- | --- | --- |
| op_ut | third_party.torch-xpu-ops.test.xpu.dynamo.test_misc_xpu.MiscTests | test_const_dict_variable_python_type | failed | AttributeError: '_thread._local' object has no attribute 'current_tx' |
| op_ut | third_party.torch-xpu-ops.test.xpu.dynamo.test_regional_inductor_xpu.RegionalInductorInvokeSubgraphTests | test_flex_attention_serialize_False | failed | AssertionError: "def [724 chars]n alias = torch.ops.aten.alias.default(geti[349 chars]s_3)" != "def [724 chars]n return (getitem, primals_0, primals_1, pr[90 chars]m_1)" |
| op_ut | third_party.torch-xpu-ops.test.xpu.profiler.test_profiler_xpu.TestProfiler | test_disable_external_correlation | failed | AssertionError: False is not true |
| op_ut | third_party.torch-xpu-ops.test.xpu.profiler.test_profiler_xpu.TestProfiler | test_dynamic_toggle | failed | AssertionError: False is not true |
| op_ut | third_party.torch-xpu-ops.test.xpu.test_custom_ops_xpu.TestCustomOpAPI | test_no_grad_skips_autograd | failed | AssertionError: Scalars are not equal! |

... and 19 more failures. See full list in artifact.

Instructions:

  1. Download the full new failure report artifact for complete error messages
  2. Analyze whether the failures are caused by PR changes
  3. If caused by PR changes, fix those failures on this PR branch directly
  4. Add failing test cases with pytest commands in the PR description for the reproducer job to use
  5. Use reply comments to report analysis status - do NOT modify the PR description plan directly

Contributor

Copilot AI commented May 16, 2026

@copilot New UT failures detected in op_ut tests (24 total). Please check these new failures and analyze whether they are caused by the PR changes.

**Important: Do NOT update the plan in the PR d...

Analyzed run 25900306416 and downloaded both artifacts (New-UT-Failures-3683-op_ut-25900306416-1 and Inductor-XPU-UT-Data-3683-op_ut-25900306416-1) at commit 9050577.

Conclusion: these new failures are not caused by this PR’s subgroup topk split. The failures are in Dynamo/Profiler/AOTAutogradCache/custom-ops paths (e.g. current_tx, cudagraphs, profiler expectation diffs, device-properties attr mismatch), and none of the failure traces/messages reference TensorTopKSbtopk, sbtopk, or topk kernel loading/linkage.

Reproducer pytest commands from failure XML:

  • PYTORCH_TEST_WITH_SLOW=1 python test/xpu/dynamo/test_misc_xpu.py MiscTests.test_const_dict_variable_python_type
  • PYTORCH_TEST_WITH_SLOW=1 python test/xpu/dynamo/test_regional_inductor_xpu.py RegionalInductorInvokeSubgraphTests.test_flex_attention_serialize_False
  • PYTORCH_TEST_WITH_SLOW=1 python test/xpu/profiler/test_profiler_xpu.py TestProfiler.test_disable_external_correlation
  • PYTORCH_TEST_WITH_SLOW=1 python test/xpu/test_custom_ops_xpu.py TestCustomOpAPI.test_no_grad_skips_autograd
  • PYTORCH_TEST_WITH_SLOW=1 python test/xpu/dynamo/test_aot_autograd_cache_xpu.py AOTAutogradCachePicklerTests.test_identical_graphs_and_configs

No code change is applied on this PR branch for this comment because no causal link was found to the subgroup topk per-K refactor.

@github-actions

Performance outliers, please check!

  • 🔴 [-1, 80%), should be regression
| Category | Model | Target vs. Baseline [Eager] | Target vs. Baseline [Inductor] |
| --- | --- | --- | --- |
| timm_models_bfloat16_training | ghostnet_100 | 0.863207 | 0.568140 |
| timm_models_bfloat16_training | mobilenetv3_large_100 | 0.783723 | 0.569339 |
| timm_models_bfloat16_training | mobilevit_s | 0.711398 | 0.575314 |
| timm_models_bfloat16_training | mobilenetv2_100 | 0.827169 | 0.577861 |
| timm_models_bfloat16_training | tf_efficientnet_b0 | 0.811818 | 0.624784 |
| timm_models_bfloat16_training | vit_base_patch16_siglip_256 | 0.692765 | 0.743556 |
| timm_models_bfloat16_training | deit_base_distilled_patch16_224 | 0.734606 | 0.745640 |
| timm_models_bfloat16_training | nfnet_l0 | 0.730126 | 0.764982 |
| timm_models_bfloat16_training | dm_nfnet_f0 | 0.636235 | 0.778963 |
| torchbench_bfloat16_training | mobilenet_v2 | 1.057078 | 0.782480 |
| timm_models_bfloat16_training | adv_inception_v3 | 0.787123 | 0.786642 |
| timm_models_bfloat16_training | beit_base_patch16_224 | 0.735746 | 0.806704 |
| timm_models_bfloat16_training | visformer_small | 0.723941 | 0.833770 |
  • 🟡 [80%, 90%), may be fluctuations
| Category | Model | Target vs. Baseline [Eager] | Target vs. Baseline [Inductor] |
| --- | --- | --- | --- |
| timm_models_bfloat16_training | swin_base_patch4_window7_224 | 0.858626 | 0.828126 |
| timm_models_bfloat16_training | inception_v3 | 0.841590 | 0.833262 |
| torchbench_bfloat16_training | mnasnet1_0 | 1.045637 | 0.835848 |
| timm_models_bfloat16_training | convnextv2_nano.fcmae_ft_in22k_in1k | 0.806854 | 0.848945 |
| timm_models_bfloat16_training | repvgg_a2 | 0.844474 | 0.860741 |
| timm_models_bfloat16_training | deit_tiny_patch16_224.fb_in1k | 0.879007 | 0.884117 |

Comment on lines +46 to +53
sbtopk_k2_launch(
    self,
    nsegments,
    nelements,
    static_cast<int>(k),
    largest,
    values,
    indices);
Contributor


@jianyizh Would you like to use a macro to simplify the switch?

#define SBTOPK_LAUNCH(KVAL)              \
  sbtopk_k##KVAL##_launch(               \
      self,                              \
      nsegments,                         \
      nelements,                         \
      static_cast<int>(k),               \
      largest,                           \
      values,                            \
      indices)
#undef SBTOPK_LAUNCH
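
For illustration, here is one way the suggested macro could sit inside the dispatch switch. This is a sketch under assumptions, not the code from commit 7d5d247: `FakeTensor`, `DEFINE_FAKE_LAUNCH`, and `sbtopk_dispatch_demo` are hypothetical, the per-K launchers are stubbed to return their own K so the routing is observable, pointer parameters stand in for whatever tensor types the real functions take, and only K=1, 2, 4, 8 are wired up.

```cpp
#include <cstdint>

struct FakeTensor {};  // hypothetical placeholder for the real tensor type

// Hypothetical stub launchers; each returns its K instead of launching a kernel.
#define DEFINE_FAKE_LAUNCH(KVAL)                                   \
  constexpr int sbtopk_k##KVAL##_launch(                           \
      const FakeTensor*, std::int64_t, std::int64_t, int, bool,    \
      FakeTensor*, FakeTensor*) {                                  \
    return KVAL;                                                   \
  }
DEFINE_FAKE_LAUNCH(1)
DEFINE_FAKE_LAUNCH(2)
DEFINE_FAKE_LAUNCH(4)
DEFINE_FAKE_LAUNCH(8)
#undef DEFINE_FAKE_LAUNCH

// Dispatch-only function: the macro pastes KVAL into the launcher name and
// forwards the same argument list for every case.
constexpr int sbtopk_dispatch_demo(
    const FakeTensor* self,
    std::int64_t nsegments,
    std::int64_t nelements,
    std::int64_t k,
    bool largest,
    FakeTensor* values,
    FakeTensor* indices) {
#define SBTOPK_LAUNCH(KVAL)              \
  sbtopk_k##KVAL##_launch(               \
      self,                              \
      nsegments,                         \
      nelements,                         \
      static_cast<int>(k),               \
      largest,                           \
      values,                            \
      indices)
  switch (k) {
    case 1:
      return SBTOPK_LAUNCH(1);
    case 2:
      return SBTOPK_LAUNCH(2);
    case 4:
      return SBTOPK_LAUNCH(4);
    case 8:
      return SBTOPK_LAUNCH(8);
    default:
      return -1;  // no per-K launcher for this k in the sketch
  }
#undef SBTOPK_LAUNCH
}
```

Defining and undefining the macro around the switch keeps it local to the dispatcher, which matches the `#define` / `#undef` pairing in the suggestion.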

Contributor Author


Done in 7d5d247. Thanks for the suggestion!

@chuanqi129
Contributor

@copilot The Clang format check failed. Please fix the lint errors in this PR.

Important: Do NOT update the plan in the PR description directly. Use reply comments to update the status.

Lint job log: https://github.com/intel/torch-xpu-ops/actions/runs/26010568540

Instructions:

  1. Review the lint failure log via the link above
  2. Fix all lint errors in the relevant source files
  3. Ensure both Python lint and Clang format checks pass
  4. Use reply comments to report progress - do NOT modify the PR description plan directly

Contributor

Copilot AI commented May 18, 2026

@copilot The Clang format check failed. Please fix the lint errors in this PR.

Important: Do NOT update the plan in the PR description directly. Use reply comments to update the status....

Fixed the lint issue locally in commit 65106b7 by applying clang-format spacing to SBTOPK_LAUNCH in src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp.

Validation run results:

  • Python lint command (same as workflow Run lint check): passed
  • Clang format command (--take CLANGFORMAT --all-files): passed

I cannot update the PR branch from this environment because push is currently rejected by repository protection for jianyi/sbtopk-split-compile (GH013: Cannot update this protected ref). Once push permission is available, this commit should clear the clang format failure.

UI screenshot: N/A (no UI changes).

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated no new comments.

@github-actions

Performance outliers, please check!

  • 🔴 [-1, 80%), should be regression
| Category | Model | Target vs. Baseline [Eager] | Target vs. Baseline [Inductor] |
| --- | --- | --- | --- |
| timm_models_bfloat16_training | ghostnet_100 | 0.857159 | 0.564379 |
| timm_models_bfloat16_training | mobilenetv3_large_100 | 0.784481 | 0.569545 |
| timm_models_bfloat16_training | mobilevit_s | 0.719862 | 0.580902 |
| timm_models_bfloat16_training | mobilenetv2_100 | 0.837105 | 0.586458 |
| timm_models_bfloat16_training | tf_efficientnet_b0 | 0.807176 | 0.620196 |
| timm_models_bfloat16_training | vit_base_patch16_siglip_256 | 0.693813 | 0.748589 |
| timm_models_bfloat16_training | deit_base_distilled_patch16_224 | 0.752537 | 0.771474 |
| timm_models_bfloat16_training | dm_nfnet_f0 | 0.637595 | 0.781003 |
| timm_models_bfloat16_training | nfnet_l0 | 0.740710 | 0.781783 |
| torchbench_bfloat16_training | mobilenet_v2 | 1.062802 | 0.796713 |
| timm_models_bfloat16_training | adv_inception_v3 | 0.797618 | 0.800682 |
| timm_models_bfloat16_training | beit_base_patch16_224 | 0.740789 | 0.817135 |
| timm_models_bfloat16_training | visformer_small | 0.720272 | 0.822191 |
  • 🟡 [80%, 90%), may be fluctuations
| Category | Model | Target vs. Baseline [Eager] | Target vs. Baseline [Inductor] |
| --- | --- | --- | --- |
| timm_models_bfloat16_training | swin_base_patch4_window7_224 | 0.826522 | 0.807364 |
| timm_models_bfloat16_training | inception_v3 | 0.837505 | 0.830375 |
| timm_models_bfloat16_training | repvgg_a2 | 0.824818 | 0.835189 |
| timm_models_bfloat16_training | convnextv2_nano.fcmae_ft_in22k_in1k | 0.811714 | 0.856834 |
| torchbench_bfloat16_training | resnext50_32x4d | 0.963424 | 0.861732 |
| torchbench_bfloat16_training | mnasnet1_0 | 1.071650 | 0.867946 |
| timm_models_bfloat16_training | deit_tiny_patch16_224.fb_in1k | 0.883640 | 0.887389 |

@chuanqi129
Contributor

@copilot New UT failures detected in op_ut tests (10 total). Please check these new failures and analyze whether they are caused by the PR changes.

Important: Do NOT update the plan in the PR description directly. Use reply comments to update the status.

Job log: https://github.com/intel/torch-xpu-ops/actions/runs/26010919927
Full new failure report: https://github.com/intel/torch-xpu-ops/actions/runs/26010919927/artifacts/7052438807

New failures:

| Category | Class name | Test name | Status | Message |
| --- | --- | --- | --- | --- |
| op_ut | third_party.torch-xpu-ops.test.xpu.dynamo.test_misc_xpu.MiscTests | test_const_dict_variable_python_type | failed | AttributeError: '_thread._local' object has no attribute 'current_tx' |
| op_ut | third_party.torch-xpu-ops.test.xpu.export.test_converter_xpu.TestConverter | test_ts2ep_convert_quantized_model_with_opcontext | failed | AssertionError: False is not true |
| op_ut | third_party.torch-xpu-ops.test.xpu.functorch.test_vmap_xpu.TestVmapOperatorsOpInfoXPU | test_op_has_batch_rule_count_nonzero_xpu_float32 | failed | Unexpected success |
| op_ut | third_party.torch-xpu-ops.test.xpu.profiler.test_profiler_xpu.TestProfiler | test_disable_external_correlation | failed | AssertionError: False is not true |
| op_ut | third_party.torch-xpu-ops.test.xpu.profiler.test_profiler_xpu.TestProfiler | test_dynamic_toggle | failed | AssertionError: False is not true |

... and 5 more failures. See full list in artifact.

Instructions:

  1. Download the full new failure report artifact for complete error messages
  2. Analyze whether the failures are caused by PR changes
  3. If caused by PR changes, fix those failures on this PR branch directly
  4. Add failing test cases with pytest commands in the PR description for the reproducer job to use
  5. Use reply comments to report analysis status - do NOT modify the PR description plan directly

Contributor

Copilot AI commented May 18, 2026

@copilot New UT failures detected in op_ut tests (10 total). Please check these new failures and analyze whether they are caused by the PR changes.

**Important: Do NOT update the plan in the PR d...

Checked run 26010919927 artifacts (New-UT-Failures-3683-op_ut-26010919927-1 and Inductor-XPU-UT-Data-3683-op_ut-26010919927-1) at 8f486eb. These op_ut new failures are not caused by the sbtopk per-K split/dispatch formatting change in this PR; failure signatures are in Dynamo/Export/Profiler/Autograd/OpInfo paths and contain no TensorTopKSbtopk/sbtopk symbols.

Reproducer commands from the artifact include:

  • PYTORCH_TEST_WITH_SLOW=1 python test/xpu/dynamo/test_misc_xpu.py MiscTests.test_const_dict_variable_python_type
  • PYTORCH_TEST_WITH_SLOW=1 python test/xpu/export/test_converter_xpu.py TestConverter.test_ts2ep_convert_quantized_model_with_opcontext
  • PYTORCH_TEST_WITH_SLOW=1 python test/xpu/functorch/test_vmap_xpu.py TestVmapOperatorsOpInfoXPU.test_op_has_batch_rule_count_nonzero_xpu_float32
  • PYTORCH_TEST_WITH_SLOW=1 python test/xpu/profiler/test_profiler_xpu.py TestProfiler.test_disable_external_correlation
  • PYTORCH_TEST_WITH_SLOW=1 python test/xpu/profiler/test_profiler_xpu.py TestProfiler.test_dynamic_toggle

UI screenshot: N/A (no UI changes).

@jianyizh jianyizh force-pushed the jianyi/sbtopk-split-compile branch from 8f486eb to 97be719 Compare May 21, 2026 02:47
@github-actions github-actions Bot added the disable_e2e label (Disable all e2e test jobs for the PR) May 21, 2026
@github-actions github-actions Bot added the disable_distributed label (Disable distributed UT test jobs for the PR) May 21, 2026
@chuanqi129 chuanqi129 marked this pull request as draft May 21, 2026 02:48
@chuanqi129 chuanqi129 marked this pull request as ready for review May 21, 2026 02:48
@jianyizh jianyizh force-pushed the jianyi/sbtopk-split-compile branch from 97be719 to 597e1cb Compare May 21, 2026 03:00
jianyizh added 3 commits May 21, 2026 11:06
Split the monolithic TensorTopKSbtopkKernel.cpp into 4 per-K files
(k1, k2, k4, k8) to enable parallel AOT compilation and avoid
CD build timeout that caused the original PR #3371 to be reverted.

K=16 is excluded for now as it alone causes compilation timeout;
it can be re-added once incremental build improvements land.

- TensorTopKSbtopkKernel.h: public API (SbtopkResult enum + dispatch)
- TensorTopKSbtopkKernelImpl.h: shared functor + launch templates
- TensorTopKSbtopkKernel_k{1,2,4,8}.cpp: per-K instantiations
- TensorTopKSbtopkKernel.cpp: dispatch-only (routes to per-K units)
- TensorTopKKernel.cpp: integrate sbtopk_try_launch fallback
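
The fallback mentioned in the last item might look roughly like the sketch below. Only the names `SbtopkResult` and `sbtopk_try_launch` come from the commit message; the enumerators, the `TopKPath` type, the `*_sketch` functions, and the rule that only an exact K of 1, 2, 4, or 8 is accepted are all assumptions made for illustration.

```cpp
// Hypothetical result type: the real SbtopkResult enumerators are not shown
// in the commit message.
enum class SbtopkResult { Launched, Unsupported };

// Assumption: the subgroup path exists only for the K values that have a
// per-K translation unit in this commit (K=16 is excluded).
constexpr bool sbtopk_has_kernel_sketch(int k) {
  return k == 1 || k == 2 || k == 4 || k == 8;
}

// Stand-in for sbtopk_try_launch: report whether the subgroup kernel ran.
constexpr SbtopkResult sbtopk_try_launch_sketch(int k) {
  return sbtopk_has_kernel_sketch(k) ? SbtopkResult::Launched
                                     : SbtopkResult::Unsupported;
}

// Hypothetical caller-side view of which implementation handled the request.
enum class TopKPath { Subgroup, Generic };

// Caller in the style of TensorTopKKernel.cpp: try the subgroup path first,
// and fall back to the existing generic path when it declines.
constexpr TopKPath topk_select_path_sketch(int k) {
  return sbtopk_try_launch_sketch(k) == SbtopkResult::Launched
      ? TopKPath::Subgroup
      : TopKPath::Generic;
}
```

With this shape, dropping K=16 from the build does not change results for callers; such requests simply take the generic path.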

Labels

disable_distributed: Disable distributed UT test jobs for the PR
disable_e2e: Disable all e2e test jobs for the PR


5 participants