diff --git a/README.md b/README.md index 7636ccc..3b27cca 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ Week 1 and 2 is complete. Week 3 is in progress. | 2.5 | Flash Attention 2 - GPU | ✅ | ✅ | ✅ | | 2.6 | Continuous Batching | ✅ | ✅ | ✅ | | 2.7 | Chunked Prefill | ✅ | ✅ | ✅ | -| 3.1 | Paged Attention - Part 1 | 🚧 | 🚧 | 🚧 | +| 3.1 | Paged Attention - Part 1 | ✅ | ✅ | 🚧 | | 3.2 | Paged Attention - Part 2 | 🚧 | 🚧 | 🚧 | | 3.3 | MoE (Mixture of Experts) | 🚧 | 🚧 | 🚧 | | 3.4 | Speculative Decoding | 🚧 | ✅ | 🚧 | diff --git a/book/src/SUMMARY.md b/book/src/SUMMARY.md index f7fb814..321a003 100644 --- a/book/src/SUMMARY.md +++ b/book/src/SUMMARY.md @@ -19,6 +19,8 @@ - [Flash Attention (2 Days)](./week2-04-flash-attention.md) - [Continuous Batching (2 Days)](./week2-06-prefill-and-batch.md) - [Week 3: Serving]() + - [Paged Attention, Part 1]() + - [Paged Attention, Part 2]() --- diff --git a/book/src/week3-01-paged-attention-part1.md b/book/src/week3-01-paged-attention-part1.md new file mode 100644 index 0000000..d3e5b72 --- /dev/null +++ b/book/src/week3-01-paged-attention-part1.md @@ -0,0 +1,281 @@ +# Week 3 Day 1: Paged Attention, Part 1 + +In this chapter, we will design the **paged KV cache**. This is the storage abstraction behind paged attention. + +By the end of Week 2, our serving stack already supports: + +- per-request KV cache +- chunked prefill +- continuous batching +- FlashAttention + +That gives us a working miniature serving engine, but the memory layout is still too simple. KV for each request is treated as one growing dense tensor, and batching rebuilds dense K/V for all active requests. That approach is easy to teach, but it does not scale well once requests become long and numerous. + +Paged attention starts by fixing the storage layout. + +**📚 Readings** + +- [vLLM Paged Attention Design](https://docs.vllm.ai/en/v0.18.0/design/paged_attention/) +- [Efficient Memory Management for Large Language Model Serving with PagedAttention](https://arxiv.org/abs/2309.06180) + +## Why the Week 2 KV Layout Becomes Expensive + +Right now, the mental model looks like this: + +```plain +request A -> one dense KV tensor +request B -> one dense KV tensor +request C -> one dense KV tensor +``` + +Before attention, the runtime repacks them into: + +```plain +keys: [B, H, S_max, D] +values: [B, H, S_max, D] +mask: [B, 1, L, S_max] +``` + +The trouble is that decode only adds a tiny amount of new information each step, but the dense layout keeps revisiting old KV. + +For example, if a request already has 17 cached tokens and we decode 1 more token: + +```plain +new useful work: append 1 token +dense repack view: rebuild 18 logical positions +``` + +For one request this is fine. For many live requests, the runtime spends more and more time moving previously computed KV instead of doing actual model work. + +## The Page Abstraction + +Instead of storing each layer's KV for a request as one long tensor, we divide storage into fixed-size **pages**: + +```plain +key_pages: pages with up to page_size token slots +value_pages: pages with up to page_size token slots +``` + +Each layer cache keeps a small page table: + +```plain +page_ids = [12, 5, 3] +context_len = 10 +``` + +That means: + +```plain +page 12 -> tokens 0..3 +page 5 -> tokens 4..7 +page 3 -> tokens 8..9 +``` + +The logical sequence is still length 10. The difference is that the runtime is no longer forced to represent it as one contiguous tensor. 
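+
+The bookkeeping behind this mapping is plain integer division over `page_size`. As a quick illustration (a hypothetical helper, not part of the reference solution):
+
+```python
+# Hypothetical helper: map a logical token position to (page_id, slot in page).
+def logical_to_physical(page_ids: list[int], page_size: int, pos: int) -> tuple[int, int]:
+    page_index, slot = divmod(pos, page_size)
+    return page_ids[page_index], slot
+
+page_ids = [12, 5, 3]
+print(logical_to_physical(page_ids, 4, 2))  # (12, 2): token 2 is slot 2 of page 12
+print(logical_to_physical(page_ids, 4, 9))  # (3, 1): token 9 is slot 1 of page 3
+```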
+ +In our Day 1 teaching implementation, those fixed-size pages live in one shared **page pool** owned by the model. Every layer cache receives that same pool, but each layer cache keeps its own `page_ids`, `page_lens`, and `offset`. + +In the reference solution, `page_size` is the physical page capacity. Unused tail slots are not part of the logical sequence; `page_lens` decides which prefix of each page is valid. + +## Why Fixed-Size Pages Help + +The page abstraction gives us two immediate wins: + +1. Appending a token usually updates only the current tail page in the pool. +2. Finished requests can return their pages to a shared free list. + +This is the key memory-management idea behind paged attention systems such as vLLM. + +## Data Structures We Need + +## 1. `PagePool` + +The model should own one pool with a model-wide page allocator and flat K/V page storage: + +```plain +free_pages: available page ids for the whole model +keys[page_id]: physical key page +values[page_id]: physical value page +``` + +Each layer still has distinct K/V contents because each layer cache allocates its own physical pages. In this teaching version, each layer cache also has its own logical page table. That is simpler than nano-vllm's shared block table: layer 0 might own pages `[0, 1]`, while layer 1 owns pages `[2, 3]`, but both page sets came from the same model-owned pool. + +In the reference solution, this becomes `TinyKvPagedPool`. + +## 2. `PagedRequestCache` + +A layer cache for one request should track: + +- `page_ids` +- `page_lens` +- `offset` +- `page_size` + +Derived values: + +- `num_pages = len(page_ids)` +- `context_len = offset` +- `last_page_fill = page_lens[-1]` when at least one page exists + +In the reference solution, this becomes `TinyKvPagedCache`. +It is created with a pool from the model. It should not allocate its own pool, +because that would isolate one request from the shared page allocator. + +The reference solution creates one `TinyKvPagedCache` per transformer layer. Those caches share the pool, but they do not share metadata: each layer cache owns its own `page_ids`, `page_lens`, and `offset`. + +## 3. Tail-Append Logic + +When new K/V arrives for one layer: + +1. look at that layer cache's last page +2. if there is room, append only the new slice into the tail page +3. otherwise allocate a new page and continue writing +4. update cache metadata such as `page_lens` and `offset` + +This replaces the Week 2 pattern of repeatedly concatenating along the sequence dimension. + +## Prefill with Pages + +Suppose `page_size = 4` and one prefill chunk contains 6 tokens: + +```plain +chunk = [t0 t1 t2 t3 t4 t5] +``` + +One possible layout is: + +```plain +page 7 <- [t0 t1 t2 t3] +page 2 <- [t4 t5] # 2 valid tokens, 2 unused slots of capacity +``` + +That layer cache's metadata becomes: + +```plain +page_ids = [7, 2] +context_len = 6 +``` + +The important property is that a later decode token can be appended to page `2` without touching page `7`. + +## Decode with Pages + +During decode, each live request adds one token at a time. + +With paged storage: + +1. compute one-token `k` and `v` +2. check whether the tail page still has space +3. write into that page if possible +4. allocate a new page only when the old one is full + +So if `page_size = 4` and `context_len = 9`: + +```plain +page_ids = [12, 5, 3] +``` + +Appending token 9 only updates the last page instead of rebuilding all earlier KV. 
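+
+As a sketch, the decode-time append decision can be written directly against the pool and cache fields described above. The method names follow the pool interface used by the reference solution (`allocate_page`, `write_page_slice`), but treat this as an illustration rather than the exact reference code:
+
+```python
+# Sketch of the decode-time tail append for one layer cache. `k` and `v` are the
+# one-token K/V slices for this layer; `cache` tracks page_ids/page_lens/offset
+# and `cache.pool` is the shared page pool.
+def append_one_token(cache, k, v):
+    if cache.page_ids and cache.page_lens[-1] < cache.page_size:
+        # The tail page still has a free slot: write only the new token into it.
+        cache.pool.write_page_slice(cache.page_ids[-1], cache.page_lens[-1], k, v)
+        cache.page_lens[-1] += 1
+    else:
+        # The tail page is full (or no page exists yet): allocate a fresh page.
+        page_id = cache.pool.allocate_page()
+        cache.pool.write_page_slice(page_id, 0, k, v)
+        cache.page_ids.append(page_id)
+        cache.page_lens.append(1)
+    cache.offset += 1  # one more logical token; older pages are never rewritten
+```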
+ +## Stage A: Keep Dense Attention + +The cleanest first implementation is **paged storage with dense gather**. + +That means: + +- pages in the shared pool are the source of truth, +- layer caches stop owning one monolithic K/V tensor, +- layer caches only track page metadata, +- attention still receives dense K/V reconstructed from pages. + +This is not the final paged attention runtime yet, but it is a very useful intermediate step: + +- small surface-area change +- easier debugging +- direct correctness comparison against `TinyKvFullCache` + +## How This Maps to `tiny-llm` + +## `src/tiny_llm/kv_cache.py` + +Add: + +- `PagePool` +- `PagedRequestCache` + +Keep `TinyKvFullCache` around as a baseline and test oracle. + +The key Day 1 behavior is: + +1. write new K/V into the layer cache's tail page or newly allocated pages, +2. gather the layer cache's pages back into dense K/V, +3. feed that dense K/V into the old attention path. + +So Day 1 changes the storage model first, not the attention kernel yet. + +## `src/tiny_llm/batch.py` + +Requests should own per-layer cache handles instead of long dense K/V tensors. + +The scheduler should still: + +- perform chunked prefill, +- hold active requests, +- free cache pages when a slot finishes. + +The difference is that freeing a request now means releasing all pages owned by its layer caches back to the pool. + +Day 1 also keeps a small `rewind(n)` lifecycle hook. Rewind is useful for speculative decoding: if some drafted tokens are rejected, the cache must forget their K/V. In the paged cache, rewind frees whole pages that are no longer needed and shortens the valid length of the final remaining page. + +## Design Questions for Day 1 + +Before implementing, make sure the following are clear: + +1. What page size should this repo use for teaching? +2. How do we represent the free-page allocator? +3. How do we prove that paged storage reconstructs the same logical KV as `TinyKvFullCache`? +4. How do layer cache handles share one pool while keeping their own page metadata? +5. When do we materialize page writes to avoid MLX lazy-graph growth? + +## Task 1: Design `PagePool` + +``` +src/tiny_llm/kv_cache.py +``` + +Design a model-owned page pool that: + +- owns the model-wide free-page allocator, +- stores flat fixed-size K/V pages, +- allocates and frees page ids, +- supports writing a chunk into page storage, +- is shared by every layer cache created by the model. + +## Task 2: Design `PagedRequestCache` + +``` +src/tiny_llm/kv_cache.py +``` + +Replace the "one layer cache = one dense KV tensor" model with: + +- `page_ids` +- `context_len` +- append logic over fixed-size pages +- `release()` for returning pages on request completion +- `rewind(n)` for dropping the newest `n` logical tokens + +## Task 3: Add a Dense-Gather Compatibility Path + +``` +src/tiny_llm/kv_cache.py +src/tiny_llm/qwen2_week3.py +``` + +Build a compatibility path that reconstructs dense K/V from pages and compares it against `TinyKvFullCache`. + +This gives us a correctness checkpoint before we change the attention path itself. + +In the next chapter, we will take the next step: instead of gathering dense K/V before attention, we will pass runtime metadata such as `block_table` directly into a paged attention path. 
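+
+As a concrete version of that checkpoint, the comparison can be as small as feeding identical chunks into both caches and asserting that the gathered K/V match. This sketch uses the reference class names and the `[B, H_kv, S, D]` layout; adjust the import to wherever your `PagePool`/`PagedRequestCache` equivalents live:
+
+```python
+# Task 3 checkpoint sketch: paged storage must reconstruct the same logical K/V
+# as the dense baseline. The import path is illustrative.
+import mlx.core as mx
+from tiny_llm.kv_cache import TinyKvFullCache, TinyKvPagedCache, TinyKvPagedPool
+
+pool = TinyKvPagedPool(page_size=4)
+paged, full = TinyKvPagedCache(pool=pool), TinyKvFullCache()
+
+for length in (3, 2, 5):  # chunks of different sizes, like chunked prefill
+    k = mx.random.normal(shape=(1, 2, length, 8))  # [B, H_kv, S, D]
+    v = mx.random.normal(shape=(1, 2, length, 8))
+    paged_k, paged_v, paged_len, _ = paged.update_and_fetch(k, v)
+    full_k, full_v, full_len, _ = full.update_and_fetch(k, v)
+    assert paged_len == full_len
+    assert mx.allclose(paged_k, full_k) and mx.allclose(paged_v, full_v)
+```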
+ +{{#include copyright.md}} diff --git a/book/src/week3-02-paged-attention-part2.md b/book/src/week3-02-paged-attention-part2.md new file mode 100644 index 0000000..8a5044f --- /dev/null +++ b/book/src/week3-02-paged-attention-part2.md @@ -0,0 +1,223 @@ +# Week 3 Day 2: Paged Attention, Part 2 + +In this chapter, we move from **paged KV storage** to the runtime metadata and execution path needed for **real paged attention**. + +Part 1 introduced fixed-size pages, a model-owned page pool shared by layer caches, and per-layer page metadata. That change already improves the storage abstraction, but it does not yet remove the dense gather before attention. To get the full benefit, the attention path itself must understand how to read from pages. + +## Paged KV Cache vs Paged Attention + +These two ideas are related, but they are not the same: + +1. **Paged KV cache** + KV is stored in fixed-size pages. +2. **Paged attention** + The attention path reads KV directly from those pages via metadata such as a page table. + +You can implement the first one without the second one, but the real serving payoff comes when both are present. + +## The Metadata a Paged Runtime Needs + +Once KV is paged, dense `B x H x S x D` tensors are no longer the natural runtime representation. Instead, the runtime should prepare metadata like: + +```plain +block_table: [B, max_pages_per_request] +context_lens: [B] +slot_mapping: [B] or [num_new_tokens] +``` + +For the current layer being executed: + +- `block_table[b, i]` gives the page id for request `b`'s current-layer logical page `i` +- `context_lens[b]` gives the valid token count for request `b` +- `slot_mapping` tells us where newly generated K/V should be written + +This is the bridge between the scheduler and the attention kernel. + +## Why `block_table` Matters + +Suppose one layer cache for request A has: + +```plain +page_ids = [12, 5, 3] +context_len = 10 +page_size = 4 +``` + +Then the logical sequence positions map to physical storage like this: + +```plain +logical 0..3 -> page 12 +logical 4..7 -> page 5 +logical 8..9 -> page 3 +``` + +The attention runtime does not need a fully gathered dense tensor if it already knows: + +- which current-layer page each logical block lives in, +- how long the context is, +- and where the current query positions are. + +That is exactly what `block_table` and `context_lens` encode. + +## The Real Attention API + +At this point, the runtime should grow a new attention entry point: + +```python +paged_attention( + query, + key_pages, + value_pages, + block_table, + context_lens, + scale=None, + mask="causal", +) +``` + +With shapes like: + +```plain +query: B, H_q, L, D +key_pages: P, H_kv, page_size, D +value_pages: P, H_kv, page_size, D +block_table: B, max_pages +context_lens: B +``` + +Compared with the Week 2 dense path, the important difference is that the source length is no longer represented as one contiguous tensor dimension. It is reconstructed logically from the page table. + +## Prefill Metadata + +During prefill, a chunk may span multiple pages. The runtime needs to know: + +- which current-layer pages already existed, +- which new pages were allocated, +- how many valid tokens are in the tail page, +- how to map incoming K/V rows into page storage. + +This is why a paged design usually carries a write-side structure such as `slot_mapping`. + +## Decode Metadata + +During decode, each active request typically writes one token. + +The runtime should be able to: + +1. 
compute the destination slot for that token, +2. write K/V into the correct page slot, +3. update the current layer cache's `context_len`, +4. run attention over the full logical context using `block_table` + +This is the point where decode stops paying the repeated dense-repack cost from Week 2. + +## How This Maps to `tiny-llm` + +## `src/tiny_llm/attention.py` + +Add a new function: + +```python +def paged_attention(...): + ... +``` + +The easiest rollout is: + +1. first implement it as a gather-then-call wrapper around existing attention, +2. later replace that wrapper with a real paged kernel or paged FlashAttention path. + +That preserves correctness while the runtime contracts stabilize. + +## `src/tiny_llm/qwen2_week3.py` + +The attention module should be able to branch on cache capability: + +```python +if cache.supports_paged_attention: + x = paged_attention(...) +else: + x = scaled_dot_product_attention_grouped(...) +``` + +This keeps the model code readable while letting the cache and kernel evolve independently. + +## `src/tiny_llm/batch.py` + +The scheduler now needs to prepare runtime metadata instead of only dense K/V: + +- per-layer page tables for each active request +- padded batch `block_table` +- `context_lens` +- write positions for prefill and decode + +This is where continuous batching and paged attention finally connect. In Week 2, batching worked by repacking tensors. In Week 3, batching should work by reusing page tables and updating only the new slots. + +## Recommended Incremental Rollout + +The safest implementation order is: + +1. paged storage +2. dense gather compatibility path +3. `block_table` / `context_lens` plumbing +4. real paged attention dispatch + +This order matters because it gives us a clean correctness baseline at each step. + +## Correctness Invariants + +These are the invariants worth checking in tests: + +1. `context_len` always equals the number of written logical token positions. +2. `block_table` reconstructs the same logical KV order as the dense baseline. +3. the allocator never hands the same page to two live cache handles unless explicit sharing is implemented. +4. releasing a request returns all pages owned by all of its layer caches exactly once. +5. decode allocates a new page only when the tail page overflows. + +## Task 1: Add Batch Metadata + +``` +src/tiny_llm/kv_cache.py +src/tiny_llm/batch.py +``` + +Extend the batch cache and scheduler so they can prepare: + +- `block_table` +- `context_lens` +- write-slot metadata + +for all active requests. + +## Task 2: Define `paged_attention` + +``` +src/tiny_llm/attention.py +``` + +Add a paged attention interface whose inputs come from the paged runtime rather than a dense reconstructed `S` dimension. + +## Task 3: Dispatch from the Model + +``` +src/tiny_llm/qwen2_week3.py +``` + +Update the model so it can route to paged attention when the cache provides paged runtime metadata. + +## Task 4: Connect It to Continuous Batching + +``` +src/tiny_llm/batch.py +``` + +Update request admission, slot reuse, and request removal so that: + +- finished requests free their pages, +- in this teaching implementation, that means freeing pages from every layer cache, +- new requests allocate from the shared pool, +- active decode steps reuse page metadata instead of rebuilding dense K/V. + +After this chapter, the serving stack has the right structure for a real high-throughput runtime: paging is no longer just a storage trick, but part of the execution model itself. 
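+
+To make the first rollout step concrete, here is a minimal sketch of the gather-then-call form of `paged_attention`. It assumes the shapes listed above, that the Week 2 kernel is importable from `tiny_llm.attention`, and that mask semantics are inherited from the dense path; a real implementation would replace the Python loop and gather with a paged kernel:
+
+```python
+import mlx.core as mx
+
+# Reuse the Week 2 dense kernel; only the gather in front of it is new.
+from tiny_llm.attention import scaled_dot_product_attention_grouped
+
+def paged_attention(query, key_pages, value_pages, block_table, context_lens,
+                    scale=None, mask="causal"):
+    # query: [B, H_q, L, D]; key_pages/value_pages: [P, H_kv, page_size, D]
+    # block_table: [B, max_pages]; context_lens: [B]
+    page_size = key_pages.shape[2]
+    outputs = []
+    for b in range(query.shape[0]):
+        n = int(context_lens[b])                      # valid context for request b
+        num_pages = (n + page_size - 1) // page_size  # pages holding those tokens
+        pages = [int(block_table[b, i]) for i in range(num_pages)]
+        # Gather this request's pages and trim the tail page to its valid prefix.
+        k = mx.concat([key_pages[p] for p in pages], axis=1)[:, :n, :]
+        v = mx.concat([value_pages[p] for p in pages], axis=1)[:, :n, :]
+        outputs.append(
+            scaled_dot_product_attention_grouped(
+                query[b : b + 1], k[None], v[None], scale=scale, mask=mask
+            )
+        )
+    return mx.concat(outputs, axis=0)
+```
+
+Because the gather happens per request, the padded dense batch and its per-request masks disappear from this path; correctness can still be checked by comparing the output against the Week 2 dense attention on the same logical K/V.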
+ +{{#include copyright.md}} diff --git a/book/src/week3-overview.md b/book/src/week3-overview.md index e69de29..5397f3e 100644 --- a/book/src/week3-overview.md +++ b/book/src/week3-overview.md @@ -0,0 +1,20 @@ +# Week 3: Serving + +In Week 3 of the course, we move from the "tiny vLLM" baseline to the next layer of serving-system ideas. Week 2 gave us the core runtime loop: KV cache, quantized kernels, FlashAttention, chunked prefill, and continuous batching. Week 3 is where we start addressing the limitations of that baseline and connect the model runtime to more realistic serving features. + +## What We’ll Cover + +* Paged attention + * Part 1: paged KV cache and the page-table abstraction + * Part 2: block tables, paged runtime metadata, and the real attention path +* Additional serving optimizations + * MoE routing and serving considerations + * speculative decoding + * long-context techniques +* Model interaction with the outside world + * retrieval-augmented generation (RAG) + * tool calling / agent-style execution + +The goal of Week 3 is not just to make the model faster. It is to understand how a serving system evolves once the basic decode loop already works: how memory is managed, how runtime metadata flows into kernels, and how the serving stack coordinates with external systems. + +{{#include copyright.md}} diff --git a/main.py b/main.py index 2fb26a4..b955c9f 100644 --- a/main.py +++ b/main.py @@ -89,6 +89,23 @@ ) else: draft_tiny_llm_model = None + elif args.loader == "week3": + print( + f"Using week3 loader with flash_attn={args.enable_flash_attn} thinking={args.enable_thinking} for {args.model}" + ) + tiny_llm_model = models.dispatch_model( + args.model, mlx_model, week=3, enable_flash_attn=args.enable_flash_attn + ) + if draft_mlx_model is not None: + print(f"Using draft model {args.draft_model}") + draft_tiny_llm_model = models.dispatch_model( + args.draft_model, + draft_mlx_model, + week=3, + enable_flash_attn=args.enable_flash_attn, + ) + else: + draft_tiny_llm_model = None else: raise ValueError(f"Loader {args.loader} not supported") messages = [ @@ -107,7 +124,7 @@ ) if args.loader == "week1": simple_generate(tiny_llm_model, tokenizer, prompt, sampler=sampler) - elif args.loader == "week2": + elif args.loader in ("week2", "week3"): if draft_tiny_llm_model is not None: speculative_generate( draft_tiny_llm_model, diff --git a/src/tiny_llm/models.py b/src/tiny_llm/models.py index 2a58f74..af087b5 100644 --- a/src/tiny_llm/models.py +++ b/src/tiny_llm/models.py @@ -1,5 +1,6 @@ from .qwen2_week1 import Qwen2ModelWeek1 from .qwen2_week2 import Qwen2ModelWeek2 +from .qwen2_week3 import Qwen2ModelWeek3 from .qwen3 import Qwen3Model @@ -29,6 +30,8 @@ def dispatch_model(model_name: str, mlx_model, week: int, **kwargs): return Qwen2ModelWeek1(mlx_model, **kwargs) elif week == 2 and model_name.startswith("Qwen/Qwen2"): return Qwen2ModelWeek2(mlx_model, **kwargs) + elif week == 3 and model_name.startswith("Qwen/Qwen2"): + return Qwen2ModelWeek3(mlx_model, **kwargs) elif week == 2 and model_name.startswith("mlx-community/Qwen3"): return Qwen3Model(mlx_model, **kwargs) else: diff --git a/src/tiny_llm_ref/__init__.py b/src/tiny_llm_ref/__init__.py index bb237dd..6d8a4d9 100644 --- a/src/tiny_llm_ref/__init__.py +++ b/src/tiny_llm_ref/__init__.py @@ -8,6 +8,7 @@ from .kv_cache import * from .qwen2_week1 import Qwen2ModelWeek1 from .qwen2_week2 import Qwen2ModelWeek2 +from .qwen2_week3 import Qwen2ModelWeek3 from .qwen3 import Qwen3Model from .sampler import * from .kv_cache import * 
diff --git a/src/tiny_llm_ref/batch.py b/src/tiny_llm_ref/batch.py index 9b43503..8c5d390 100644 --- a/src/tiny_llm_ref/batch.py +++ b/src/tiny_llm_ref/batch.py @@ -25,7 +25,7 @@ def __init__( prompt_idx: int = 0, ): self.prompt = prompt - self.kv_cache = [TinyKvFullCache() for _ in range(model.num_hidden_layers)] + self.kv_cache = model.create_kv_cache() self.model = model self.detokenizer = tokenizer.detokenizer.__class__(tokenizer._tokenizer) self.prefill_tokens = mx.array( diff --git a/src/tiny_llm_ref/generate.py b/src/tiny_llm_ref/generate.py index 92931d2..e06694b 100644 --- a/src/tiny_llm_ref/generate.py +++ b/src/tiny_llm_ref/generate.py @@ -6,6 +6,11 @@ from typing import Callable +def _release_kv_cache(kv_cache): + for layer in kv_cache: + layer.release() + + def simple_generate( model: Qwen2ModelWeek1, tokenizer: TokenizerWrapper, @@ -42,7 +47,7 @@ def _step(model, y): def simple_generate_with_kv_cache( model: Qwen2ModelWeek2, tokenizer: TokenizerWrapper, prompt: str ) -> str: - kv_cache = [TinyKvFullCache() for _ in range(model.num_hidden_layers)] + kv_cache = model.create_kv_cache() def _step(model, y, offset, kv_cache): logits = model(y[None], offset, kv_cache) @@ -52,23 +57,26 @@ def _step(model, y, offset, kv_cache): y = sampler(logprobs) return y, logprobs.squeeze(0) - # prefill with the prompt - tokens = mx.array(tokenizer.encode(prompt, add_special_tokens=False)) - detokenizer = tokenizer.detokenizer - detokenizer.reset() - offset = 0 - # generate/decode - while True: - token, _ = _step(model, tokens, offset, kv_cache) - mx.eval(token) - if token.item() == tokenizer.eos_token_id: - break - detokenizer.add_token(token.item()) - print(detokenizer.last_segment, end="", flush=True) - # The first iteration of this loop is prefill. We want to add the offset to the prefilled token size. - # Otherwise, we add the decoded token size (which is always 1). - offset += tokens.size - tokens = token + try: + # prefill with the prompt + tokens = mx.array(tokenizer.encode(prompt, add_special_tokens=False)) + detokenizer = tokenizer.detokenizer + detokenizer.reset() + offset = 0 + # generate/decode + while True: + token, _ = _step(model, tokens, offset, kv_cache) + mx.eval(token) + if token.item() == tokenizer.eos_token_id: + break + detokenizer.add_token(token.item()) + print(detokenizer.last_segment, end="", flush=True) + # The first iteration of this loop is prefill. We want to add the offset to the prefilled token size. + # Otherwise, we add the decoded token size (which is always 1). 
+ offset += tokens.size + tokens = token + finally: + _release_kv_cache(kv_cache) def speculative_generate( @@ -78,8 +86,8 @@ def speculative_generate( tokenizer: TokenizerWrapper, prompt: str, ) -> str: - draft_kv_cache = [TinyKvFullCache() for _ in range(draft_model.num_hidden_layers)] - kv_cache = [TinyKvFullCache() for _ in range(model.num_hidden_layers)] + draft_kv_cache = draft_model.create_kv_cache() + kv_cache = model.create_kv_cache() def _step(model, y, offset, kv_cache, n_tokens=1): logits = model(y[None], offset, kv_cache) @@ -103,81 +111,85 @@ def _prefill(model, tokenizer, prompt, kv_cache): offset = prefill_tokens.size return token, offset - draft_token, draft_offset = _prefill( - draft_model, draft_tokenizer, prompt, draft_kv_cache - ) - token, offset = _prefill(model, tokenizer, prompt, kv_cache) - - def _decode_one(token, tokenizer): - if token.item() == tokenizer.eos_token_id: - return False - detokenizer = tokenizer.detokenizer - detokenizer.add_token(token.item()) - return True - - def draft_generate(model, last_token, offset, kv_cache, num_drafts): - tokens = [] - current_offset = offset - for _ in range(num_drafts): - token, _ = _step(model, last_token, current_offset, kv_cache) - mx.eval(token) - tokens.append(token.item()) - last_token = token - current_offset += 1 - return tokens - - num_drafts = 4 - - def _rewind_cache(kv_cache, revert_len): - for layer in kv_cache: - layer.rewind(revert_len) - - def _print_text(text, progress): - newline = '\n' - print(f"+{progress} {text.replace(newline, ' ')[-80:]}") - - # speculative decode - while True: - draft_tokens = draft_generate( - draft_model, token, draft_offset, draft_kv_cache, num_drafts + try: + draft_token, draft_offset = _prefill( + draft_model, draft_tokenizer, prompt, draft_kv_cache ) - draft_offset += num_drafts - # assume both models use the same tokenizer - draft_tokens = mx.concat([token, mx.array(draft_tokens)]) - new_tokens, _ = _step(model, draft_tokens, offset, kv_cache, num_drafts + 1) - new_tokens = new_tokens.tolist()[0] - offset += num_drafts + 1 - last_new_token = new_tokens[-1] - new_tokens = mx.array([token.item()] + new_tokens[:-1]) - assert len(new_tokens) == len(draft_tokens) - accept_all = True - for i in range(len(new_tokens)): - if new_tokens[i] != draft_tokens[i]: - # revert the full draft generation; re-generate next time - # or we matched full, then no rewind and use the last token - assert i >= 1 # first token is always the same - revert_len = len(draft_tokens) - i - _rewind_cache(draft_kv_cache, revert_len - 1) - draft_offset -= revert_len - 1 - _rewind_cache(kv_cache, revert_len) - token = mx.array([new_tokens[i]]) - offset -= revert_len - assert offset == draft_offset - assert offset == kv_cache[0].offset - _print_text(tokenizer._detokenizer.text, i) - accept_all = False - break - if not _decode_one(new_tokens[i], tokenizer): - print(tokenizer._detokenizer.text) - return tokenizer._detokenizer.text - if accept_all: - _print_text(tokenizer._detokenizer.text, len(new_tokens)) - draft_generate( - draft_model, - mx.array(draft_tokens[-1:]), - draft_offset, - draft_kv_cache, - 1, + token, offset = _prefill(model, tokenizer, prompt, kv_cache) + + def _decode_one(token, tokenizer): + if token.item() == tokenizer.eos_token_id: + return False + detokenizer = tokenizer.detokenizer + detokenizer.add_token(token.item()) + return True + + def draft_generate(model, last_token, offset, kv_cache, num_drafts): + tokens = [] + current_offset = offset + for _ in range(num_drafts): + token, _ = 
_step(model, last_token, current_offset, kv_cache) + mx.eval(token) + tokens.append(token.item()) + last_token = token + current_offset += 1 + return tokens + + num_drafts = 4 + + def _rewind_cache(kv_cache, revert_len): + for layer in kv_cache: + layer.rewind(revert_len) + + def _print_text(text, progress): + newline = '\n' + print(f"+{progress} {text.replace(newline, ' ')[-80:]}") + + # speculative decode + while True: + draft_tokens = draft_generate( + draft_model, token, draft_offset, draft_kv_cache, num_drafts ) - token = mx.array([last_new_token]) - draft_offset += 1 + draft_offset += num_drafts + # assume both models use the same tokenizer + draft_tokens = mx.concat([token, mx.array(draft_tokens)]) + new_tokens, _ = _step(model, draft_tokens, offset, kv_cache, num_drafts + 1) + new_tokens = new_tokens.tolist()[0] + offset += num_drafts + 1 + last_new_token = new_tokens[-1] + new_tokens = mx.array([token.item()] + new_tokens[:-1]) + assert len(new_tokens) == len(draft_tokens) + accept_all = True + for i in range(len(new_tokens)): + if new_tokens[i] != draft_tokens[i]: + # revert the full draft generation; re-generate next time + # or we matched full, then no rewind and use the last token + assert i >= 1 # first token is always the same + revert_len = len(draft_tokens) - i + _rewind_cache(draft_kv_cache, revert_len - 1) + draft_offset -= revert_len - 1 + _rewind_cache(kv_cache, revert_len) + token = mx.array([new_tokens[i]]) + offset -= revert_len + assert offset == draft_offset + assert offset == kv_cache[0].offset + _print_text(tokenizer._detokenizer.text, i) + accept_all = False + break + if not _decode_one(new_tokens[i], tokenizer): + print(tokenizer._detokenizer.text) + return tokenizer._detokenizer.text + if accept_all: + _print_text(tokenizer._detokenizer.text, len(new_tokens)) + draft_generate( + draft_model, + mx.array(draft_tokens[-1:]), + draft_offset, + draft_kv_cache, + 1, + ) + token = mx.array([last_new_token]) + draft_offset += 1 + finally: + _release_kv_cache(draft_kv_cache) + _release_kv_cache(kv_cache) diff --git a/src/tiny_llm_ref/kv_cache.py b/src/tiny_llm_ref/kv_cache.py index 144dbd0..8c37b6a 100644 --- a/src/tiny_llm_ref/kv_cache.py +++ b/src/tiny_llm_ref/kv_cache.py @@ -27,6 +27,253 @@ def update_and_fetch( A tuple of the updated key-value cache, the updated value, the sequence length, and the mask. """ + def release(self): + """ + Release all resources owned by this cache. + + Request-scoped caches use this when generation finishes or a batch slot + is removed. Dense caches do not own shared resources, while paged caches + return their physical pages to a shared pool. + """ + return None + + def rewind(self, n: int): + """ + Remove the newest n logical tokens from this cache. + + This is needed by speculative decoding when some draft tokens are + rejected after their K/V has already been written. Implementations may + drop dense suffixes or return whole pages to a page pool. + """ + raise NotImplementedError("This KV cache does not support rewind") + + +class TinyKvPagedPool: + """Model-local physical storage for paged KV. + + The model owns one pool and passes it to every layer cache. The pool gives + out physical page ids from one free list. Because every live page id is + unique, the page id alone is enough to find the physical K/V page. 
+ """ + + def __init__(self, page_size: int = 128): + assert page_size > 0 + self.page_size = page_size + self.key_pages: list[mx.array | None] = [] + self.value_pages: list[mx.array | None] = [] + self.free_page_ids: list[int] = [] + self.used_page_ids: set[int] = set() + + @property + def num_pages(self) -> int: + return len(self.key_pages) + + @property + def num_free_pages(self) -> int: + return len(self.free_page_ids) + + def _check_page_chunk(self, x: mx.array) -> None: + B, H, S, D = x.shape + assert 0 < S <= self.page_size + + def allocate_page(self) -> int: + # The page id is allocated from a model-wide free list. In this teaching + # version, a layer cache owns the page until release/rewind returns it. + if self.free_page_ids: + page_id = self.free_page_ids.pop() + else: + page_id = self.num_pages + self.key_pages.append(None) + self.value_pages.append(None) + self.used_page_ids.add(page_id) + return page_id + + def read_page(self, page_id: int) -> tuple[mx.array, mx.array]: + key = self.key_pages[page_id] + value = self.value_pages[page_id] + if key is None or value is None: + raise ValueError(f"Page {page_id} has no storage") + return key, value + + def _ensure_page_storage( + self, + page_id: int, + key: mx.array, + value: mx.array, + ) -> tuple[mx.array, mx.array]: + key_page = self.key_pages[page_id] + value_page = self.value_pages[page_id] + if key_page is not None and value_page is not None: + return key_page, value_page + + B, H, _, D = key.shape + key_page = mx.zeros((B, H, self.page_size, D), dtype=key.dtype) + value_page = mx.zeros((B, H, self.page_size, D), dtype=value.dtype) + self.key_pages[page_id] = key_page + self.value_pages[page_id] = value_page + return key_page, value_page + + def write_page_slice( + self, + page_id: int, + start: int, + key: mx.array, + value: mx.array, + ) -> None: + assert key.shape == value.shape + self._check_page_chunk(key) + if page_id not in self.used_page_ids: + raise ValueError(f"Page {page_id} is free") + key_page, value_page = self._ensure_page_storage(page_id, key, value) + B, H, capacity, D = key_page.shape + assert value_page.shape == (B, H, capacity, D) + assert capacity == self.page_size + assert key.shape[:2] == (B, H) + assert key.shape[3] == D + end = start + key.shape[2] + assert 0 <= start <= capacity + assert end <= self.page_size + + key_page[:, :, start:end, :] = key + value_page[:, :, start:end, :] = value + self.key_pages[page_id] = key_page + self.value_pages[page_id] = value_page + + def free_page(self, page_id: int) -> None: + if page_id not in self.used_page_ids: + raise ValueError(f"Page {page_id} is already free") + # Keep the page id stable, but drop its old K/V tensors so the id can be + # handed to a future cache. + self.used_page_ids.remove(page_id) + self.key_pages[page_id] = None + self.value_pages[page_id] = None + self.free_page_ids.append(page_id) + + +class TinyKvPagedCache(TinyKvCache): + """Layer-local K/V cache backed by a model-owned page pool. + + Each transformer layer gets its own TinyKvPagedCache and therefore its own + `page_ids`, `page_lens`, and `offset`. The shared part is only the pool, + which lets pages be recycled across requests and layers. 
+ """ + + def __init__(self, pool: TinyKvPagedPool): + self.pool = pool + self.page_size = self.pool.page_size + self.page_ids: list[int] = [] + self.page_lens: list[int] = [] + self.offset = 0 + + @property + def num_pages(self) -> int: + return len(self.page_ids) + + @property + def key_values(self) -> tuple[mx.array, mx.array] | None: + if self.offset == 0: + return None + return self.gather_dense() + + def _append_chunk(self, key: mx.array, value: mx.array) -> None: + assert key.shape == value.shape + B, H, S, D = key.shape + assert B == 1, "Paged request cache only supports one request at a time" + start = 0 + + # First fill the existing tail page if it has free slots. + if self.page_ids and self.page_lens[-1] < self.page_size: + page_id = self.page_ids[-1] + page_start = self.page_lens[-1] + take = min(self.page_size - page_start, S) + self.pool.write_page_slice( + page_id, + page_start, + key[:, :, :take, :], + value[:, :, :take, :], + ) + self.page_lens[-1] += take + start += take + + # Then allocate fresh pages for the remaining chunk. We only write the + # valid prefix; unused tail slots are ignored by page_lens. + while start < S: + end = min(start + self.page_size, S) + page_id = self.pool.allocate_page() + self.pool.write_page_slice( + page_id, + 0, + key[:, :, start:end, :], + value[:, :, start:end, :], + ) + self.page_ids.append(page_id) + self.page_lens.append(end - start) + start = end + + self.offset += S + + def gather_dense(self) -> tuple[mx.array, mx.array]: + assert self.offset > 0 + # Stage A compatibility path: attention still expects dense K/V, so we + # trim each fixed-capacity page to its valid prefix and concatenate + # request pages in logical order. + key_chunks = [] + value_chunks = [] + for page_id, page_len in zip(self.page_ids, self.page_lens): + key_page, value_page = self.pool.read_page(page_id) + assert key_page.shape[2] == self.page_size + assert value_page.shape[2] == self.page_size + key_chunks.append(key_page[:, :, :page_len, :]) + value_chunks.append(value_page[:, :, :page_len, :]) + if len(key_chunks) == 1: + return key_chunks[0], value_chunks[0] + return mx.concat(key_chunks, axis=2), mx.concat(value_chunks, axis=2) + + def update_and_fetch( + self, + key: mx.array, + value: mx.array, + mask_length: int | None = None, + mask: mx.array | str | None = None, + ) -> tuple[mx.array, mx.array, int, Optional[mx.array]]: + assert key.shape == value.shape + self._append_chunk(key, value) + # Day 1 keeps the old attention interface. Day 2 can replace this dense + # gather with block_table/context_lens metadata. + dense_key, dense_value = self.gather_dense() + return dense_key, dense_value, self.offset, mask + + def rewind(self, n: int): + assert 0 <= n <= self.offset + new_offset = self.offset - n + if new_offset == self.offset: + return + if new_offset == 0: + self.release() + return + + target_num_pages = (new_offset + self.page_size - 1) // self.page_size + while len(self.page_ids) > target_num_pages: + # Whole pages beyond the new logical length return to the shared + # allocator. Stale suffix slots in the final page are ignored because + # page_lens defines the valid prefix and future writes overwrite them. + page_id = self.page_ids.pop() + self.page_lens.pop() + self.pool.free_page(page_id) + + last_page_len = new_offset - self.page_size * (target_num_pages - 1) + self.page_lens[-1] = last_page_len + self.offset = new_offset + + def release(self): + # Request completion returns every page owned by this layer cache to the + # model-level allocator. 
Other layer caches release their own pages. + for page_id in self.page_ids: + self.pool.free_page(page_id) + self.page_ids.clear() + self.page_lens.clear() + self.offset = 0 + class BatchingKvCache(TinyKvCache): def __init__(self, max_active_requests: int, max_seq_len: int): @@ -50,7 +297,9 @@ def update_and_fetch( else: assert self.HD == (H, D), f"expect {self.HD} but got {H, D}" assert B == self.max_active_requests - # Step 1: append the result to the cache + # Step 1: append each active row into its request cache. For paged + # caches, this writes into the request's page table and then gathers a + # dense view for the current Week 3 Day 1 attention path. data = [] for b in range(B): if self.kv_caches[b] is None: @@ -70,7 +319,8 @@ def get_seq_len(data): return seq_len seq_len = max(map(get_seq_len, data)) - # Step 3: generate masks and a single array of keys and values + # Step 3: rebuild one dense batch tensor. True paged attention will + # replace this with block_table/context_lens metadata. keys = mx.zeros((self.max_active_requests, H, seq_len, D), dtype=key.dtype) values = mx.zeros((self.max_active_requests, H, seq_len, D), dtype=value.dtype) masks = mx.full( @@ -106,8 +356,9 @@ def add_request(self, prefilled: TinyKvCache, id: int): self.kv_caches[id] = prefilled def remove_request(self, id: int): - if self.kv_caches is None: + if self.kv_caches[id] is None: raise ValueError(f"Request id {id} is not in the cache") + self.kv_caches[id].release() self.kv_caches[id] = None diff --git a/src/tiny_llm_ref/qwen2_week2.py b/src/tiny_llm_ref/qwen2_week2.py index 5f9047f..a65cc12 100644 --- a/src/tiny_llm_ref/qwen2_week2.py +++ b/src/tiny_llm_ref/qwen2_week2.py @@ -263,6 +263,11 @@ def __init__( self.w_lm_head = None self.mlx_model = mlx_model + def create_kv_cache(self) -> list[TinyKvCache]: + from .kv_cache import TinyKvFullCache + + return [TinyKvFullCache() for _ in range(self.num_hidden_layers)] + def __call__( self, inputs: mx.array, diff --git a/src/tiny_llm_ref/qwen2_week3.py b/src/tiny_llm_ref/qwen2_week3.py new file mode 100644 index 0000000..3e248e9 --- /dev/null +++ b/src/tiny_llm_ref/qwen2_week3.py @@ -0,0 +1,292 @@ +import mlx.core as mx + +from .basics import silu +from .attention import ( + scaled_dot_product_attention_grouped, + flash_attention, +) +from .layer_norm import RMSNorm +from .positional_encoding import RoPE +from typing import Any +from .embedding import Embedding +from .quantize import dequantize_linear, QuantizedWeights, quantized_linear +from .kv_cache import TinyKvCache, TinyKvPagedCache, TinyKvPagedPool + + +class Qwen2MultiHeadAttention: + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + wq: QuantizedWeights, + wk: QuantizedWeights, + wv: QuantizedWeights, + wo: QuantizedWeights, + bq: mx.array, + bk: mx.array, + bv: mx.array, + max_seq_len: int = 32768, + theta: int = 1000000, + use_flash_attention: bool = False, + ): + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + assert hidden_size % num_heads == 0, ( + f"hidden_size {hidden_size} must be divisible by num_heads {num_heads}" + ) + assert num_heads % num_kv_heads == 0, ( + f"num_heads {num_heads} must be divisible by num_kv_heads {num_kv_heads}" + ) + self.head_dim = hidden_size // num_heads + self.scale = mx.rsqrt(self.head_dim) + self.wq = wq + self.wk = wk + self.wv = wv + self.wo = wo + self.bq = bq + self.bk = bk + self.bv = bv + self.rope = RoPE(self.head_dim, max_seq_len, theta, traditional=False) + 
self.use_flash_attention = use_flash_attention + + def __call__( + self, + x: mx.array, + offsets: int | list[int] | mx.array, + cache: TinyKvCache, + mask: mx.array | str | None = None, + ) -> mx.array: + B, L, _ = x.shape + projection_q = quantized_linear(x, self.wq, bias=self.bq).reshape( + B, L, self.num_heads, self.head_dim + ) + projection_k = quantized_linear(x, self.wk, bias=self.bk).reshape( + B, L, self.num_kv_heads, self.head_dim + ) + projection_v = quantized_linear(x, self.wv, bias=self.bv).reshape( + B, L, self.num_kv_heads, self.head_dim + ) + if isinstance(offsets, int): + offset_slice = [slice(int(offsets), int(offsets + L))] + else: + offset_slice = [slice(int(i), int(i + L)) for i in offsets] + projection_q = self.rope(projection_q, offset=offset_slice) + projection_k = self.rope(projection_k, offset=offset_slice) + projection_q = projection_q.transpose(0, 2, 1, 3) + projection_k = projection_k.transpose(0, 2, 1, 3) + projection_v = projection_v.transpose(0, 2, 1, 3) + projection_k, projection_v, _, mask = cache.update_and_fetch( + projection_k, projection_v, mask_length=L, mask=mask + ) + if self.use_flash_attention: + x = flash_attention( + projection_q.astype(mx.float32), + projection_k.astype(mx.float32), + projection_v.astype(mx.float32), + scale=self.scale, + mask=mask, + ).astype(x.dtype) + else: + x = scaled_dot_product_attention_grouped( + projection_q.astype(mx.float32), + projection_k.astype(mx.float32), + projection_v.astype(mx.float32), + scale=self.scale, + mask=mask, + ).astype(x.dtype) + x = x.transpose(0, 2, 1, 3).reshape(B, L, self.hidden_size) + return quantized_linear(x, self.wo) + + +class Qwen2MLP: + def __init__( + self, + dim: int, + hidden_dim: int, + w_gate: QuantizedWeights, + w_up: QuantizedWeights, + w_down: QuantizedWeights, + ): + self.dim = dim + self.hidden_dim = hidden_dim + self.w_gate = w_gate + self.w_up = w_up + self.w_down = w_down + + def __call__(self, x: mx.array) -> mx.array: + return quantized_linear( + silu(quantized_linear(x, self.w_gate)) * quantized_linear(x, self.w_up), + self.w_down, + ) + + +class Qwen2TransformerBlock: + def __init__( + self, + num_attention_heads: int, + num_kv_heads: int, + hidden_size: int, + intermediate_size: int, + rms_norm_eps: float, + wq: QuantizedWeights, + wk: QuantizedWeights, + wv: QuantizedWeights, + wo: QuantizedWeights, + bq: mx.array, + bk: mx.array, + bv: mx.array, + w_gate: QuantizedWeights, + w_up: QuantizedWeights, + w_down: QuantizedWeights, + w_input_layernorm: mx.array, + w_post_attention_layernorm: mx.array, + max_seq_len: int = 32768, + theta: int = 1000000, + use_flash_attention: bool = False, + ): + self.num_attention_heads = num_attention_heads + self.hidden_size = hidden_size + self.mlp = Qwen2MLP(hidden_size, intermediate_size, w_gate, w_up, w_down) + self.input_layernorm = RMSNorm(hidden_size, w_input_layernorm, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + hidden_size, w_post_attention_layernorm, eps=rms_norm_eps + ) + self.self_attn = Qwen2MultiHeadAttention( + num_heads=num_attention_heads, + hidden_size=hidden_size, + num_kv_heads=num_kv_heads, + wq=wq, + wk=wk, + wv=wv, + wo=wo, + bq=bq, + bk=bk, + bv=bv, + max_seq_len=max_seq_len, + theta=theta, + use_flash_attention=use_flash_attention, + ) + + def __call__( + self, + x: mx.array, + offset: int | list[int] | mx.array, + cache: TinyKvCache, + mask: mx.array | str | None = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), offset, cache, mask) + h = x + r + r = 
self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + +class Qwen2ModelWeek3: + def __init__( + self, + mlx_model: Any, + enable_flash_attn: bool = False, + page_size: int = 128, + ): + self.num_hidden_layers = mlx_model.args.num_hidden_layers + self.hidden_size = mlx_model.args.hidden_size + self.vocab_size = mlx_model.args.vocab_size + self.page_size = page_size + # One model-level pool is shared by all layer caches. Each layer cache + # still owns its own page table and allocates its own physical pages. + self.page_pool = TinyKvPagedPool(page_size=self.page_size) + precision = mx.float16 + self.precision = precision + + self.embedding = Embedding( + vocab_size=self.vocab_size, + embedding_dim=self.hidden_size, + weight=dequantize_linear(mlx_model.model.embed_tokens).astype(precision), + ) + self.layers_inner = [] + + for i in range(mlx_model.args.num_hidden_layers): + wq = QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].self_attn.q_proj + ) + wk = QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].self_attn.k_proj + ) + wv = QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].self_attn.v_proj + ) + wo = QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].self_attn.o_proj + ) + w_gate = QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.gate_proj + ) + w_up = QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.up_proj + ) + w_down = QuantizedWeights.from_mlx_layer( + mlx_model.model.layers[i].mlp.down_proj + ) + + layer = Qwen2TransformerBlock( + num_attention_heads=mlx_model.args.num_attention_heads, + num_kv_heads=mlx_model.args.num_key_value_heads, + hidden_size=mlx_model.args.hidden_size, + intermediate_size=mlx_model.args.intermediate_size, + rms_norm_eps=mlx_model.args.rms_norm_eps, + wq=wq, + wk=wk, + wv=wv, + wo=wo, + bq=mlx_model.model.layers[i].self_attn.q_proj.bias.astype(precision), + bk=mlx_model.model.layers[i].self_attn.k_proj.bias.astype(precision), + bv=mlx_model.model.layers[i].self_attn.v_proj.bias.astype(precision), + w_gate=w_gate, + w_up=w_up, + w_down=w_down, + w_input_layernorm=mlx_model.model.layers[ + i + ].input_layernorm.weight.astype(precision), + w_post_attention_layernorm=mlx_model.model.layers[ + i + ].post_attention_layernorm.weight.astype(precision), + max_seq_len=mlx_model.args.max_position_embeddings, + theta=mlx_model.args.rope_theta, + use_flash_attention=enable_flash_attn, + ) + self.layers_inner.append(layer) + self.norm = RMSNorm( + mlx_model.args.hidden_size, + weight=mlx_model.model.norm.weight.astype(precision), + eps=mlx_model.args.rms_norm_eps, + ) + if not mlx_model.args.tie_word_embeddings: + self.w_lm_head = QuantizedWeights.from_mlx_layer(mlx_model.lm_head) + else: + self.w_lm_head = None + self.mlx_model = mlx_model + + def create_kv_cache(self) -> list[TinyKvCache]: + # One request gets one cache handle per layer. The handles share the + # model-level pool, but their page_ids/page_lens/offset are independent. 
+ return [ + TinyKvPagedCache(pool=self.page_pool) + for _ in range(self.num_hidden_layers) + ] + + def __call__( + self, + inputs: mx.array, + offset: int | list[int] | mx.array, + cache: list[TinyKvCache], + ) -> mx.array: + h = self.embedding(inputs) + for layer in range(self.num_hidden_layers): + h = self.layers_inner[layer](h, offset, cache[layer], mask="causal") + h = self.norm(h) + if self.w_lm_head is not None: + return quantized_linear(h, self.w_lm_head) + else: + return self.embedding.as_linear(h) diff --git a/tests_refsol/test_week_3_day_1.py b/tests_refsol/test_week_3_day_1.py new file mode 100644 index 0000000..ee2d3b2 --- /dev/null +++ b/tests_refsol/test_week_3_day_1.py @@ -0,0 +1,217 @@ +from types import SimpleNamespace + +import mlx.core as mx + +from .tiny_llm_base import ( + Qwen2ModelWeek2, + Qwen2ModelWeek3, + TinyKvFullCache, + TinyKvPagedCache, + TinyKvPagedPool, +) +from .utils import assert_allclose + + +def _random_chunk(length: int, num_heads: int = 2, head_dim: int = 4) -> tuple[mx.array, mx.array]: + key = mx.random.normal(shape=(1, num_heads, length, head_dim)).astype(mx.float32) + value = mx.random.normal(shape=(1, num_heads, length, head_dim)).astype(mx.float32) + return key, value + + +def _quantized_layer( + out_dim: int, in_dim: int, *, bias: bool = False, group_size: int = 64 +) -> SimpleNamespace: + weight = mx.random.normal(shape=(out_dim, in_dim), dtype=mx.float16) + quantized_weight, scales, biases = mx.quantize(weight, group_size=group_size, bits=4) + layer = SimpleNamespace( + weight=quantized_weight, + scales=scales, + biases=biases, + group_size=group_size, + bits=4, + ) + if bias: + layer.bias = mx.random.normal(shape=(out_dim,), dtype=mx.float16) + return layer + + +def _fake_qwen2_mlx_model() -> SimpleNamespace: + mx.random.seed(0) + args = SimpleNamespace( + num_hidden_layers=2, + hidden_size=64, + vocab_size=128, + num_attention_heads=4, + num_key_value_heads=2, + intermediate_size=128, + rms_norm_eps=1e-5, + max_position_embeddings=128, + rope_theta=10000, + tie_word_embeddings=True, + ) + embed_tokens = _quantized_layer(args.vocab_size, args.hidden_size) + kv_hidden_size = args.hidden_size // args.num_attention_heads * args.num_key_value_heads + layers = [] + for _ in range(args.num_hidden_layers): + layers.append( + SimpleNamespace( + self_attn=SimpleNamespace( + q_proj=_quantized_layer(args.hidden_size, args.hidden_size, bias=True), + k_proj=_quantized_layer(kv_hidden_size, args.hidden_size, bias=True), + v_proj=_quantized_layer(kv_hidden_size, args.hidden_size, bias=True), + o_proj=_quantized_layer(args.hidden_size, args.hidden_size), + ), + mlp=SimpleNamespace( + gate_proj=_quantized_layer(args.intermediate_size, args.hidden_size), + up_proj=_quantized_layer(args.intermediate_size, args.hidden_size), + down_proj=_quantized_layer(args.hidden_size, args.intermediate_size), + ), + input_layernorm=SimpleNamespace( + weight=mx.ones((args.hidden_size,), dtype=mx.float16) + ), + post_attention_layernorm=SimpleNamespace( + weight=mx.ones((args.hidden_size,), dtype=mx.float16) + ), + ) + ) + return SimpleNamespace( + args=args, + model=SimpleNamespace( + embed_tokens=embed_tokens, + layers=layers, + norm=SimpleNamespace(weight=mx.ones((args.hidden_size,), dtype=mx.float16)), + ), + ) + + +def test_task_1_paged_cache_matches_full_cache(): + page_size = 4 + full = TinyKvFullCache() + pool = TinyKvPagedPool(page_size=page_size) + paged = TinyKvPagedCache(pool=pool) + + total_len = 0 + for length in [3, 2, 5]: + key, value = _random_chunk(length) 
+ full_key, full_value, full_len, _ = full.update_and_fetch(key, value) + paged_key, paged_value, paged_len, _ = paged.update_and_fetch(key, value) + total_len += length + assert full_len == paged_len == total_len + assert paged.num_pages == (total_len + page_size - 1) // page_size + physical_page_capacity = [ + paged.pool.read_page(page_id)[0].shape[2] for page_id in paged.page_ids + ] + assert physical_page_capacity == [page_size] * paged.num_pages + assert sum(paged.page_lens) == total_len + assert_allclose(paged_key, full_key, precision=mx.float32) + assert_allclose(paged_value, full_value, precision=mx.float32) + + +def test_task_1_paged_pool_reuses_freed_pages(): + pool = TinyKvPagedPool(page_size=4) + first = TinyKvPagedCache(pool=pool) + second = TinyKvPagedCache(pool=pool) + + key, value = _random_chunk(6) + first.update_and_fetch(key, value) + assert first.page_ids == [0, 1] + assert pool.num_pages == 2 + assert pool.num_free_pages == 0 + + first.release() + assert first.offset == 0 + assert pool.num_pages == 2 + assert pool.num_free_pages == 2 + + second_key, second_value = _random_chunk(5) + gathered_key, gathered_value, seq_len, _ = second.update_and_fetch(second_key, second_value) + assert seq_len == 5 + assert pool.num_pages == 2 + assert pool.num_free_pages == 0 + assert set(second.page_ids) == {0, 1} + assert_allclose(gathered_key, second_key, precision=mx.float32) + assert_allclose(gathered_value, second_value, precision=mx.float32) + + +def test_task_1_paged_cache_rewind(): + page_size = 4 + pool = TinyKvPagedPool(page_size=page_size) + paged = TinyKvPagedCache(pool=pool) + full = TinyKvFullCache() + + for length in [4, 3, 2]: + key, value = _random_chunk(length) + paged.update_and_fetch(key, value) + full.update_and_fetch(key, value) + + assert paged.page_lens == [4, 4, 1] + paged.rewind(3) + full.rewind(3) + + paged_key, paged_value = paged.gather_dense() + full_key, full_value = full.key_values + assert paged.offset == full.offset == 6 + assert paged.page_lens == [4, 2] + assert paged.num_pages == 2 + assert paged.pool.num_pages == 3 + assert paged.pool.num_free_pages == 1 + physical_page_capacity = [ + paged.pool.read_page(page_id)[0].shape[2] for page_id in paged.page_ids + ] + assert physical_page_capacity == [page_size] * paged.num_pages + assert_allclose(paged_key, full_key, precision=mx.float32) + assert_allclose(paged_value, full_value, precision=mx.float32) + + +def test_task_1_model_kv_caches_share_layer_pools(): + mlx_model = _fake_qwen2_mlx_model() + week3_model = Qwen2ModelWeek3(mlx_model, page_size=4) + first_request_cache = week3_model.create_kv_cache() + second_request_cache = week3_model.create_kv_cache() + + assert len(first_request_cache) == week3_model.num_hidden_layers + for layer in range(week3_model.num_hidden_layers): + assert first_request_cache[layer].pool is week3_model.page_pool + assert second_request_cache[layer].pool is week3_model.page_pool + + assert first_request_cache[0].page_ids is not first_request_cache[1].page_ids + assert first_request_cache[0].page_lens is not first_request_cache[1].page_lens + assert first_request_cache[0].pool is first_request_cache[1].pool + + +def test_task_1_model_layer_caches_keep_independent_page_metadata(): + mlx_model = _fake_qwen2_mlx_model() + week3_model = Qwen2ModelWeek3(mlx_model, page_size=4) + cache = week3_model.create_kv_cache() + inputs = mx.array([[1, 5, 7, 3, 9]], dtype=mx.int32) + + week3_model(inputs, 0, cache) + + assert cache[0].page_ids == [0, 1] + assert cache[0].page_lens == [4, 1] + 
owned_page_ids = set(cache[0].page_ids) + for layer in range(1, week3_model.num_hidden_layers): + assert cache[layer].page_lens == cache[0].page_lens + assert owned_page_ids.isdisjoint(cache[layer].page_ids) + owned_page_ids.update(cache[layer].page_ids) + for page_id in cache[layer].page_ids: + key_page, value_page = week3_model.page_pool.read_page(page_id) + assert key_page.shape[2] == week3_model.page_size + assert value_page.shape[2] == week3_model.page_size + + +def test_task_3_incremental_decode_matches_week2(): + mlx_model = _fake_qwen2_mlx_model() + week2_model = Qwen2ModelWeek2(mlx_model) + week3_model = Qwen2ModelWeek3(mlx_model, page_size=4) + inputs = mx.array([[1, 5, 7, 3, 9, 11]], dtype=mx.int32) + week2_cache = week2_model.create_kv_cache() + week3_cache = week3_model.create_kv_cache() + + for offset in range(inputs.shape[1]): + token = inputs[:, offset : offset + 1] + week2_out = week2_model(token, offset, week2_cache) + week3_out = week3_model(token, offset, week3_cache) + week2_out = week2_out - mx.logsumexp(week2_out, keepdims=True) + week3_out = week3_out - mx.logsumexp(week3_out, keepdims=True) + assert_allclose(week3_out, week2_out, precision=mx.float16, rtol=1e-3, atol=1e-3)