Skip to content

Feature/prefill fmha esimd#425

Draft
liu-shaojun wants to merge 4 commits into
mainfrom
feature/prefill-fmha-esimd
Draft

Feature/prefill fmha esimd#425
liu-shaojun wants to merge 4 commits into
mainfrom
feature/prefill-fmha-esimd

Conversation

@liu-shaojun
Copy link
Copy Markdown
Contributor

No description provided.

liu-shaojun and others added 4 commits May 21, 2026 02:36
Implements Flash Attention for prefill with paged KV cache, GQA, and
causal mask using ESIMD intrinsics. All 8 UT cases pass with max diff
< 0.002 vs IPEX reference.

Current performance: ~23x slower than IPEX (JIT mode, kBr=4, kBc=16,
scalar dot product). Hardware metrics show 0% XMX utilization — next
step is adding DPAS for Q×K^T score computation.

Key design decisions:
- JIT compilation (-fsycl-targets=spir64) to avoid AOT GRF overflow
- kBr=4 Q rows per work group with shared K/V loading
- kBc=16 batch softmax with safe_exp16 (clamp + merge to avoid NaN)
- head_dim=256 hardcoded for Qwen3.5 series models
- Supports: causal mask, paged KV (block_table), GQA (any ratio)

Build: TORCH_XPU_ARCH_LIST=bmg-g21 MAX_JOBS=8 python3 setup.py bdist_wheel
Test:  ZE_AFFINITY_MASK=4 python tests/test_prefill_fmha.py 1 2 3 4 5 6 7 8

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Implements split-KV architecture:
- Sub-kernel: kBr=4 per-token online softmax, processes PARTITION_SIZE
  tokens of KV range, outputs partial_out + row_max + row_sum (float32)
- Reduce kernel: log-sum-exp merge across partitions → final fp16 output
- Host launcher: dispatches all partitions in parallel, then reduce

Performance: 15.1x slower than IPEX (from 18x without split-KV).
Limited improvement because GPU EUs are already saturated with WGs.
Next step: add DPAS to sub-kernel for per-WG compute acceleration.

All 8 UT cases pass (max diff < 0.002).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…ction

Current state:
- kBr=8, kBc=16, DPAS score, SIMD exp, scalar max/sum
- JIT mode (-fsycl-targets=spir64)
- All 8 UT cases pass (max diff < 0.002)
- Performance: 93x (slow due to large kernel body from select ops)

Known JIT bugs:
- reduce<float>(simd, maximum<>()) returns wrong value (simd[0] not max)
- pointer array simd*[N] indirect access produces NaN

Next: try manual tree reduction for max to avoid both reduce bug and select overhead.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Documents all findings, JIT bugs, performance analysis, and next steps.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant