
Implement TurboQuant #2885

Draft

waqahmed-amd-fi wants to merge 3 commits into main from waqahmed/turboquant-dev

Conversation


waqahmed-amd-fi commented Apr 23, 2026

Motivation

Transformer inference at long context lengths is increasingly KV-cache bound: the memory required to store keys and values grows linearly with sequence length and becomes the dominant VRAM cost, limiting both maximum context and batch size.

This PR begins the integration of TurboQuant (arXiv:2504.19874) into AIter as an opt-in KV-cache compression backend. TurboQuant is a data-oblivious, training-free online quantizer — it requires no calibration data, no model retraining, and no changes to model weights. It compresses KV cache entries to 2–4 bits per coordinate by applying a random orthogonal rotation (making coordinates near-independent and optimally quantizable) followed by Lloyd-Max scalar quantization. A 1-bit QJL residual sketch corrects the inner-product bias that MSE-only quantization introduces, producing an unbiased attention score estimator.
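For intuition, here is a minimal self-contained PyTorch sketch of that pipeline. A toy uniform codebook stands in for the Lloyd-Max codebook, and all names are illustrative rather than this PR's API; the QJL normalization follows the estimator form quoted later in this description:

```python
import torch

d, bits = 128, 4
torch.manual_seed(0)

# Random orthogonal rotation via QR of a Gaussian matrix
Pi, _ = torch.linalg.qr(torch.randn(d, d, dtype=torch.float64))
Pi = Pi.to(torch.float32)
# Toy uniform codebook standing in for the Lloyd-Max codebook
levels = torch.linspace(-2.5, 2.5, 2 ** bits)

k, q = torch.randn(d), torch.randn(d)
norm = k.norm()
coords = (k / norm) @ Pi.T * d ** 0.5                  # rotate, then sqrt(d)-scale to O(1)
idx = (coords[:, None] - levels).abs().argmin(dim=1)   # nearest codebook level
k_hat = norm * (levels[idx] / d ** 0.5) @ Pi           # decompress: invert scale, then rotation

# 1-bit QJL residual sketch; estimator form as quoted in this PR:
#   <q, k> ~= q . k_hat + ||r|| * (S @ q) . sign(S @ r) / d
S = torch.randn(d, d)        # plain Gaussian, deliberately not orthogonalized
r = k - k_hat
score = q @ k_hat + r.norm() * (S @ q) @ torch.sign(S @ r) / d
```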

Technical Details

Implementation Plan

The full TurboQuant integration is structured across 5 phases. Each phase builds directly on the previous, with the PyTorch reference implementation (Phases 1–2) serving as the correctness oracle for the Triton kernel work (Phase 3).

| Phase | Scope | Status |
| --- | --- | --- |
| 1 — Core Quantization | Codebooks, rotation matrices, MSE/Prod/Value quantizers, bit-packing, unit tests | Ongoing |
| 2 — PyTorch Reference Attention | `turboquant_attention()` wrapper using decompress + standard flash attention; benchmark harness | Next |
| 3 — Fused Triton Kernels | Score directly on compressed KV (no materialization); flash-attention style online softmax | Planned |
| 4 — Integration & Autotuning | Stateful `TurboQuantKVCache` manager for multi-step decode; per-arch config tuning | Planned |
| 5 — Documentation | API docs, usage examples, CI benchmark integration | Planned |

Phase 1 — Core Quantization Infrastructure

Goal: Implement the full compression/decompression stack in pure PyTorch with no Triton kernels, establishing a mathematically correct reference baseline.

```
aiter/ops/triton/attention/turboquant/
├── codebook.py     Lloyd-Max optimal codebooks (Beta(d/2, d/2) distribution)
├── rotation.py     Random orthogonal Π and Gaussian QJL projection S
├── quantizer.py    TurboQuantMSE (Alg 1), TurboQuantProd (Alg 2), ValueQuantizer
├── utils.py        Vectorized 2/3/4-bit pack_indices / unpack_indices
├── __init__.py     Public API exports
└── configs/        Pre-generated codebook .pt files (9 total)
```

Key design decisions made in this phase:

- Codebooks pre-generated as .pt files rather than computed at runtime — Lloyd-Max iteration over 2M samples takes ~30s per config; no reason to repeat it at each startup.
- √d coordinate scaling applied before codebook lookup and inverted on decompression — raw rotated coordinates of unit-norm vectors are $O(1/\sqrt{d})$ while the codebook spans $O(1)$; without this, all coordinates collapse into 1–2 central levels regardless of bit-width.
- Rotation matrix Π generated via torch.linalg.qr in float64, cast to float32, and cached per (head_dim, device, seed) — regenerating within a run would break compress/decompress consistency (see the caching sketch after this list).
- QJL matrix S is plain Gaussian (not orthogonalized) — the paper's QJL construction requires this; orthogonalizing S changes the estimator's bias properties.
- Query projection kept continuous in TurboQuantProd.inner_product_score — the estimator is ‖r‖ · (Sq) · sign(Sr) / d, not sign(Sq) · sign(Sr) / d. Binarizing both sides doubles the variance and introduces bias.
- Standalone package under aiter/ops/triton/attention/turboquant/ rather than modifying any existing files — users explicitly opt in, so there is no risk of regressions.
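A minimal sketch of the per-(head_dim, device, seed) caching rule from the rotation bullet above. The helper name and cache layout are illustrative, not rotation.py's actual code:

```python
import torch

_PI_CACHE: dict[tuple, torch.Tensor] = {}

def get_rotation(head_dim: int, device: str = "cpu", seed: int = 0) -> torch.Tensor:
    # One Pi per (head_dim, device, seed): generated in float64 for QR accuracy,
    # cast to float32, and reused for the rest of the run so compress and
    # decompress always see the same rotation.
    key = (head_dim, device, seed)
    if key not in _PI_CACHE:
        gen = torch.Generator().manual_seed(seed)
        g = torch.randn(head_dim, head_dim, generator=gen, dtype=torch.float64)
        Pi, _ = torch.linalg.qr(g)
        _PI_CACHE[key] = Pi.to(device=device, dtype=torch.float32)
    return _PI_CACHE[key]
```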

Phase 2 — PyTorch Reference Attention

Wraps Phase 1 into a usable turboquant_attention(q, k, v, key_bits, value_bits, ...) function: compress K and V, decompress, call flash_attn_func. Slow (materializes full decompressed KV) but correct — validates end-to-end attention output before any kernel work. Benchmark harness added here to track quality (cos_sim vs bf16 baseline) and compression ratios across sequence lengths.
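In outline, the reference path looks like the sketch below. The quantizer objects and their compress/decompress signatures are assumptions for illustration; the PR's real path calls flash_attn_func, while PyTorch's SDPA keeps this sketch self-contained:

```python
import torch.nn.functional as F

def turboquant_attention_ref(q, k, v, key_quantizer, value_quantizer):
    # Compress and immediately decompress K and V (full materialization),
    # then run standard attention on the reconstruction: slow, but a
    # correctness oracle for the fused kernels.
    k_hat = key_quantizer.decompress(key_quantizer.compress(k))
    v_hat = value_quantizer.decompress(value_quantizer.compress(v))
    return F.scaled_dot_product_attention(q, k_hat, v_hat)
```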

Phase 3 — Fused Triton Kernels

Eliminates the decompress-then-attend bottleneck. Three kernels:

  1. MSE score: score = norm/√d · (q @ Πᵀ) · codebook[idx] — no K materialization
  2. QJL correction: correction = res_norm · (S@q) · sign(S@r) / d
  3. Fused attention: flash-attention style tiled loop operating directly on compressed K/V blocks

Phase 2's PyTorch path remains as the reference oracle for kernel validation.
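In PyTorch terms, kernel 1's math reduces to the following reference sketch (names and shapes are illustrative, not the Triton code):

```python
import torch

def mse_score_ref(q, Pi, codebook, idx, norms):
    # q: (d,), Pi: (d, d), idx: (seq, d), norms: (seq,)
    # Since k_hat = norm * (codebook[idx] / sqrt(d)) @ Pi, it follows that
    # q . k_hat = norm / sqrt(d) * (q @ Pi.T) . codebook[idx],
    # so scores come straight from the packed indices; K is never materialized.
    d = q.shape[-1]
    q_rot = q @ Pi.T                          # rotate the query once per step
    return norms / d ** 0.5 * (codebook[idx] * q_rot).sum(-1)
```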

Phases 4–5 — Integration & Documentation

Phase 4 adds the stateful TurboQuantKVCache (ring buffer of compressed KV for multi-step decode) and per-architecture Triton config tuning. Phase 5 adds docs/turboquant.md, a worked example, and integrates into the CI benchmark suite alongside SAGE and FAv3.
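A rough shape of the Phase 4 cache, purely for illustration (the real TurboQuantKVCache API is future work; the class and field names below are assumptions):

```python
import torch

class RingKVBuffer:
    """Toy ring buffer holding packed (compressed) KV rows across decode steps."""
    def __init__(self, max_tokens: int, packed_bytes: int):
        self.buf = torch.empty(max_tokens, packed_bytes, dtype=torch.uint8)
        self.count = 0                       # tokens written so far

    def append(self, packed_row: torch.Tensor) -> None:
        self.buf[self.count % self.buf.shape[0]] = packed_row  # overwrite oldest slot
        self.count += 1

    def valid(self) -> torch.Tensor:
        return self.buf[: min(self.count, self.buf.shape[0])]
```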

Test Plan


Phase 1 — Core Quantization (this PR)

Test file: op_tests/triton_tests/attention/test_turboquant_core.py (72 tests, all passing)

| Class | What it tests |
| --- | --- |
| TestBitPacking | 2/3/4-bit pack→unpack round-trips for all valid indices; verifies no index is lost or corrupted across boundary bytes |
| TestRotationMatrices | Π is orthogonal (ΠᵀΠ = I within float32 tolerance); S rows are independently Gaussian; both are stable across repeated `get_*` calls (a cache hit returns the same object) |
| TestCodebook | Codebook entries are sorted and strictly increasing; 2/3/4-bit codebooks have 4/8/16 entries; `get_codebook()` returns a tensor from disk identical to one freshly generated |
| TestTurboQuantMSE | compress→decompress reconstructs unit vectors with cos_sim ≥ 0.95 (3-bit) / ≥ 0.99 (4-bit); norm is preserved within 5%; quantization indices lie in the valid range [0, 2^b) |
| TestTurboQuantProd | inner_product_score is unbiased: the absolute bias \|E[⟨q̂,k⟩] − ⟨q,k⟩\| is below 0.05 over 1000 random pairs; the QJL component reduces variance vs the MSE-only estimate |
| TestValueQuantizer | compress→decompress cos_sim ≥ 0.90 (2-bit) / ≥ 0.99 (4-bit); scales and zero-points are finite; group boundaries are respected |
| TestDistortionScaling | MSE distortion decreases by ~4× per additional bit (Theorem 3); validated at head_dim ∈ {64, 128, 256} |
| TestCompressionRatio | compression_ratio() returns values in [1.0, 8.0]; 4-bit keys + 2-bit values at d = 256 yields ≥ 4.0× |

Run locally (CPU, no GPU required):

```bash
pytest op_tests/triton_tests/attention/test_turboquant_core.py -v
```
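For flavor, a hedged sketch of the TestDistortionScaling idea. The TurboQuantMSE import and constructor signature below are guesses based on this PR's description, and the tolerances are illustrative; the real test may differ on all three counts:

```python
import pytest
import torch
from aiter.ops.triton.attention.turboquant import TurboQuantMSE  # assumed export

@pytest.mark.parametrize("head_dim", [64, 128, 256])
def test_distortion_scaling(head_dim):
    x = torch.randn(2048, head_dim)
    x = x / x.norm(dim=-1, keepdim=True)        # unit vectors, as in the MSE tests
    mse = {}
    for bits in (2, 3, 4):
        quant = TurboQuantMSE(head_dim=head_dim, bits=bits)   # signature assumed
        x_hat = quant.decompress(quant.compress(x))
        mse[bits] = ((x - x_hat) ** 2).mean().item()
    # Theorem 3 predicts ~4x lower distortion per extra bit; assert a loose 2x bound.
    assert mse[3] < mse[2] / 2 and mse[4] < mse[3] / 2
```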

Test Result

Submission Checklist

Add TurboQuant KV-cache compression as a new attention backend in AIter. Phase 1 covers the full quantization stack in pure PyTorch (no Triton kernels), providing a correct reference implementation for later phases.

New package: aiter/ops/triton/attention/turboquant/

- codebook.py: Lloyd-Max optimal scalar quantizer for the Beta(d/2, d/2)
  coordinate distribution that arises after random orthogonal rotation.
  Pre-generates .pt codebooks for head_dim ∈ {64,128,256} × bits ∈ {2,3,4}.

- rotation.py: Π (random orthogonal, via QR decomposition of Gaussian matrix)
  and S (plain Gaussian, NOT orthogonalized) projection matrices. Both cached
  per (head_dim, device, seed) for consistency within a run.

- quantizer.py: Three quantizers:
  · TurboQuantMSE (Algorithm 1): normalize → rotate → scale by √d →
    Lloyd-Max codebook lookup → pack. Decompress inverts exactly.
  · TurboQuantProd (Algorithm 2): MSE + 1-bit QJL residual sketch.
    Unbiased inner-product estimator: ||r|| * (S@q)·sign(S@r) / d.
    Query projection kept continuous (not binarized) for lower variance.
  · ValueQuantizer: per-group affine quantization for V tensors (2 or 4 bit).

- utils.py: Vectorized 2/3/4-bit pack_indices / unpack_indices with correct
  sqrt(d) scaling accounting. Includes compression_ratio() for benchmarking.
  (A 2-bit packing sketch follows this list.)

- configs/: Directory for pre-generated codebook .pt files.
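To make the packing concrete, here is a toy 2-bit case. The helper names are illustrative only; the real pack_indices/unpack_indices also cover 3/4-bit widths and the boundary-byte handling exercised by TestBitPacking:

```python
import torch

def pack_2bit(idx: torch.Tensor) -> torch.Tensor:
    # Four 2-bit indices per byte; assumes len(idx) is a multiple of 4.
    idx = idx.to(torch.uint8).reshape(-1, 4)
    return idx[:, 0] | (idx[:, 1] << 2) | (idx[:, 2] << 4) | (idx[:, 3] << 6)

def unpack_2bit(packed: torch.Tensor) -> torch.Tensor:
    shifts = torch.arange(0, 8, 2, dtype=torch.uint8)      # bit offsets 0, 2, 4, 6
    return ((packed[:, None] >> shifts) & 0b11).reshape(-1)

idx = torch.randint(0, 4, (16,))
assert torch.equal(unpack_2bit(pack_2bit(idx)), idx.to(torch.uint8))  # round-trip
```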

Tests: op_tests/triton_tests/attention/test_turboquant_core.py
  72 tests covering bit-packing round-trips, orthogonality, codebook shape/
  symmetry, MSE cos_sim thresholds, IP unbiasedness, distortion scaling
  (Theorem 3), value quantization quality, and compression ratio accounting.
@github-actions (Contributor)

🏷️ CI Guide

Runs automatically on every PR:

- ✅ Pre-checks (submodule verification, code formatting)
- ✅ Aiter op tests (gfx942 + gfx950)
- ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or `gh pr edit 2885 --add-label <label>`
