Add TurboQuant KV-cache compression as a new attention backend in AIter. Phase 1 covers the full quantization stack in pure PyTorch (no Triton kernels), providing a correct reference implementation for later phases.
New package: aiter/ops/triton/attention/turboquant/
- codebook.py: Lloyd-Max optimal scalar quantizer for the Beta(d/2, d/2)
coordinate distribution that arises after random orthogonal rotation.
Pre-generates .pt codebooks for head_dim ∈ {64,128,256} × bits ∈ {2,3,4}
(see the Lloyd-Max sketch after this list).
- rotation.py: Π (random orthogonal, via QR decomposition of Gaussian matrix)
and S (plain Gaussian, NOT orthogonalized) projection matrices. Both cached
per (head_dim, device, seed) for consistency within a run.
- quantizer.py: Three quantizers:
· TurboQuantMSE (Algorithm 1): normalize → rotate → scale by √d →
Lloyd-Max codebook lookup → pack. Decompress inverts exactly.
· TurboQuantProd (Algorithm 2): MSE + 1-bit QJL residual sketch.
Unbiased inner-product estimator: ||r|| * (S@q)·sign(S@r) / d.
Query projection kept continuous (not binarized) for lower variance.
· ValueQuantizer: per-group affine quantization for V tensors (2 or 4 bit).
- utils.py: Vectorized 2/3/4-bit pack_indices / unpack_indices with correct
sqrt(d) scaling accounting. Includes compression_ratio() for benchmarking.
- configs/: Directory for pre-generated codebook .pt files.
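For reference, a minimal sketch of the Lloyd-Max fitting loop that codebook.py performs, assuming the √d-scaled coordinate distribution described above. Function and variable names here are illustrative, not the actual codebook.py API:

```python
import torch

def lloyd_max(samples: torch.Tensor, bits: int, iters: int = 25) -> torch.Tensor:
    """Fit 2**bits scalar levels that minimize MSE over a 1-D sample tensor."""
    k = 2 ** bits
    # Start from evenly spaced quantiles of the empirical distribution.
    levels = torch.quantile(samples, torch.linspace(0, 1, k + 2)[1:-1])
    for _ in range(iters):
        # Assignment step: nearest level per sample.
        idx = (samples[:, None] - levels[None, :]).abs().argmin(dim=1)
        # Update step: each level moves to the mean of its assigned samples.
        for j in range(k):
            sel = samples[idx == j]
            if sel.numel() > 0:
                levels[j] = sel.mean()
    return levels.sort().values

# Coordinates of randomly rotated unit vectors: a normalized Gaussian vector
# is already uniform on the sphere, so no explicit rotation is needed here.
d = 128
x = torch.nn.functional.normalize(torch.randn(16_384, d), dim=-1)
coords = (x * d ** 0.5).flatten()      # the √d scaling described above
codebook = lloyd_max(coords, bits=4)   # 16 levels for the 4-bit config
```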
Tests: op_tests/triton_tests/attention/test_turboquant_core.py
72 tests covering bit-packing round-trips, orthogonality, codebook shape/
symmetry, MSE cos_sim thresholds, IP unbiasedness, distortion scaling
(Theorem 3), value quantization quality, and compression ratio accounting.
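The bit-packing round-trips are the easiest invariant to picture. A hedged sketch of the 4-bit case (the actual pack_indices/unpack_indices in utils.py also cover the 2- and 3-bit layouts):

```python
import torch

def pack_4bit(idx: torch.Tensor) -> torch.Tensor:
    """Pack uint8 indices in [0, 16) two per byte; last dim must be even."""
    lo, hi = idx[..., 0::2], idx[..., 1::2]
    return lo | (hi << 4)

def unpack_4bit(packed: torch.Tensor) -> torch.Tensor:
    """Invert pack_4bit, recovering the original index tensor."""
    lo = packed & 0xF
    hi = (packed >> 4) & 0xF
    return torch.stack((lo, hi), dim=-1).flatten(-2)

idx = torch.randint(0, 16, (4, 128), dtype=torch.uint8)
assert torch.equal(unpack_4bit(pack_4bit(idx)), idx)  # lossless round-trip
```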
Motivation
Transformer inference at long context lengths is increasingly KV-cache bound: the memory required to store keys and values grows linearly with sequence length and becomes the dominant VRAM cost, limiting both maximum context and batch size.
This PR begins the integration of TurboQuant (arXiv:2504.19874) into AIter as an opt-in KV-cache compression backend. TurboQuant is a data-oblivious, training-free online quantizer — it requires no calibration data, no model retraining, and no changes to model weights. It compresses KV cache entries to 2–4 bits per coordinate by applying a random orthogonal rotation (making coordinates near-independent and optimally quantizable) followed by Lloyd-Max scalar quantization. A 1-bit QJL residual sketch corrects the inner-product bias that MSE-only quantization introduces, producing an unbiased attention score estimator.
Technical Details
Implementation Plan
The full TurboQuant integration is structured across 5 phases. Each phase builds directly on the previous, with the PyTorch reference implementation (Phases 1–2) serving as the correctness oracle for the Triton kernel work (Phase 3).
- Phase 1 — core quantization infrastructure in pure PyTorch (this PR)
- Phase 2 — turboquant_attention() wrapper using decompress + standard flash attention; benchmark harness
- Phase 3 — fused Triton kernels
- Phase 4 — TurboQuantKVCache manager for multi-step decode; per-arch config tuning
- Phase 5 — documentation and CI benchmark integration

Phase 1 — Core Quantization Infrastructure
Goal: Implement the full compression/decompression stack in pure PyTorch with no Triton kernels, establishing a mathematically correct reference baseline.
Key design decisions made in this phase:
- Codebooks are shipped as pre-generated .pt files rather than computed at runtime — Lloyd-Max iteration over 2M samples takes ~30s per config; no reason to repeat it each startup.
- √d coordinate scaling is applied before codebook lookup and inverted on decompression — raw rotated coordinates of unit-norm vectors are O(1/√d), so they are rescaled to match the codebook's O(1) support.
- Rotation matrices are generated with torch.linalg.qr in float64, then cast to float32, and cached per (head_dim, device, seed) — regenerating within a run would break compress/decompress consistency.
- Only the key side is binarized in TurboQuantProd.inner_product_score — the estimator is ‖r‖ · (S@q) · sign(S@r) / d, not sign(S@q) · sign(S@r) / d. Binarizing both sides doubles variance and introduces bias.
- Everything lives in the new package aiter/ops/triton/attention/turboquant/ rather than modifying any existing files — users explicitly opt in; no risk of regressions.
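To make the estimator decision concrete, a minimal sketch of its shape, taking d as the number of sketch rows (names and the standalone-function form are hypothetical; the real implementation is TurboQuantProd.inner_product_score):

```python
import torch

def inner_product_correction(q: torch.Tensor, r: torch.Tensor, S: torch.Tensor) -> torch.Tensor:
    """Estimator form described above: ||r|| * (S@q) · sign(S@r) / d.

    Only the residual r is binarized; the query projection S @ q stays
    continuous — the lower-variance choice noted in the bullet above.
    """
    d = S.shape[0]  # number of sketch rows (an assumption about the formula's d)
    return r.norm() * torch.dot(S @ q, torch.sign(S @ r)) / d

head_dim = 128
S = torch.randn(head_dim, head_dim)  # plain Gaussian, not orthogonalized
q = torch.randn(head_dim)
r = torch.randn(head_dim)            # residual left over after MSE quantization
correction = inner_product_correction(q, r, S)
```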
Phase 2 — PyTorch Reference Attention

Wraps Phase 1 into a usable turboquant_attention(q, k, v, key_bits, value_bits, ...) function: compress K and V, decompress, call flash_attn_func. Slow (materializes the full decompressed KV) but correct — validates end-to-end attention output before any kernel work. A benchmark harness is added here to track quality (cos_sim vs bf16 baseline) and compression ratios across sequence lengths.
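A sketch of that reference path under stated assumptions — the compress/decompress method names are illustrative, and torch's SDPA stands in for flash_attn_func:

```python
import torch.nn.functional as F

def turboquant_attention_reference(q, k, v, key_quant, value_quant):
    """Compress K and V, immediately decompress, then run exact attention.

    Materializes the full decompressed KV — slow, but it isolates
    quantization error so the output can be diffed against a bf16 baseline.
    """
    k_hat = key_quant.decompress(key_quant.compress(k))    # illustrative API
    v_hat = value_quant.decompress(value_quant.compress(v))
    return F.scaled_dot_product_attention(q, k_hat, v_hat)
```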
Phase 3 — Fused Triton Kernels

Eliminates the decompress-then-attend bottleneck. Three kernels:
- score = norm/√d · (q @ Πᵀ) · codebook[idx] — no K materialization
- correction = res_norm · (S@q) · sign(S@r) / d

Phase 2's PyTorch path remains as the reference oracle for kernel validation.
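The score kernel's math, written out in plain PyTorch for one query and one head (illustrative only — the real kernel is Triton, and it reads codebook values inline rather than materializing k_levels):

```python
import torch

def key_scores(q, idx, norms, codebook, Pi):
    """score = norm/√d · (q @ Πᵀ) · codebook[idx], per cached key.

    q:        [d]         query for one head
    idx:      [seq, d]    unpacked codebook indices per key
    norms:    [seq]       stored per-key norms
    codebook: [2**bits]   Lloyd-Max levels
    Pi:       [d, d]      random orthogonal rotation
    """
    d = q.shape[-1]
    q_rot = q @ Pi.T                # rotate the query once per call
    k_levels = codebook[idx]        # [seq, d] dequantized coordinates
    return (norms / d ** 0.5) * (k_levels @ q_rot)
```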
Phases 4–5 — Integration & Documentation
Phase 4 adds the stateful TurboQuantKVCache (a ring buffer of compressed KV for multi-step decode) and per-architecture Triton config tuning. Phase 5 adds docs/turboquant.md, a worked example, and integration into the CI benchmark suite alongside SAGE and FAv3.

Test Plan
Phase 1 — Core Quantization (this PR)
Test file: op_tests/triton_tests/attention/test_turboquant_core.py (72 tests, all passing)

- TestBitPacking — pack/unpack round-trips for the 2/3/4-bit layouts
- TestRotationMatrices — Π^T Π = I within float32 tolerance; S rows are independently Gaussian; both are stable across repeated get_* calls (a cache hit returns the same object)
- TestCodebook — get_codebook() returns an identical tensor from disk as when freshly generated
- TestTurboQuantMSE — indices stay in [0, 2^b); cos_sim thresholds hold
- TestTurboQuantProd — inner_product_score is unbiased: mean error |E[<q̂,k>] - <q,k>| < 0.05 over 1000 random pairs; the QJL component reduces variance vs the MSE-only estimate
- TestValueQuantizer — value quantization quality
- TestDistortionScaling — distortion scaling (Theorem 3)
- TestCompressionRatio — compression_ratio() returns values in [1.0, 8.0]; 4-bit keys + 2-bit values at d=256 yields ≥ 4.0×

Run locally (CPU, no GPU required):
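The invocation is presumably the standard pytest call against the file named above:

```
pytest op_tests/triton_tests/attention/test_turboquant_core.py -v
```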
Test Result
Submission Checklist