From 01dc86d87fb3e971dda029948054ee8252b0051e Mon Sep 17 00:00:00 2001 From: michaelfeil <63565275+michaelfeil@users.noreply.github.com> Date: Wed, 18 Mar 2026 20:34:01 -0700 Subject: [PATCH 01/13] optimize performance --- tokenizers/Cargo.toml | 1 + tokenizers/benches/llama3_benchmark.rs | 37 +++++ tokenizers/src/models/bpe/mod.rs | 73 +++++++++ tokenizers/src/models/bpe/model.rs | 9 +- tokenizers/src/models/bpe/serialization.rs | 6 +- tokenizers/src/models/bpe/trainer.rs | 5 +- tokenizers/src/models/bpe/word.rs | 181 +++++++++++---------- tokenizers/src/utils/cache.rs | 181 +++++++++++++++------ 8 files changed, 347 insertions(+), 146 deletions(-) diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml index 40b273ac4a..34f0f0c5a0 100644 --- a/tokenizers/Cargo.toml +++ b/tokenizers/Cargo.toml @@ -75,6 +75,7 @@ getrandom = { version = "0.3" } esaxx-rs = { version = "0.1.10", default-features = false, features=[]} monostate = "0.1.12" ahash = { version = "0.8.11", features = ["serde"] } +rustc-hash = "2" dary_heap = { version = "0.3.6", features = ["serde"] } compact_str = { version = "0.9", features = ["serde"] } diff --git a/tokenizers/benches/llama3_benchmark.rs b/tokenizers/benches/llama3_benchmark.rs index 8bd45396a3..9bd02b952b 100644 --- a/tokenizers/benches/llama3_benchmark.rs +++ b/tokenizers/benches/llama3_benchmark.rs @@ -6,6 +6,7 @@ mod common; use common::{iter_bench_encode, iter_bench_encode_batch, iter_bench_train}; use criterion::{Criterion, Throughput}; use std::hint::black_box; +use std::sync::Arc; use tokenizers::{ models::{bpe::BpeTrainerBuilder, TrainerWrapper}, EncodeInput, Tokenizer, @@ -43,6 +44,42 @@ pub fn llama3(c: &mut Criterion) { group.bench_function("llama3-batch", |b| { b.iter_custom(|iters| iter_bench_encode_batch(iters, &tokenizer, &batches)) }); + + // Cache effectiveness: encode the same medium-length input 1000 times + group.bench_function("llama3-cache-repeated", |b| { + let sample = data.lines().find(|l| l.len() > 200).unwrap_or("hello world"); + b.iter(|| { + for _ in 0..1000 { + let _ = black_box(tokenizer.encode(black_box(sample), false)); + } + }) + }); + + // Concurrent: N threads each encode a medium prompt + let tokenizer_arc = Arc::new(tokenizer.clone()); + for num_threads in [2, 4, 8] { + let tok = tokenizer_arc.clone(); + let sample: String = data.lines().take(3).collect::>().join("\n"); + group.bench_function(format!("llama3-concurrent-{num_threads}t"), move |b| { + b.iter(|| { + std::thread::scope(|s| { + let handles: Vec<_> = (0..num_threads) + .map(|_| { + let tok = &tok; + let sample = &sample; + s.spawn(move || { + let _ = black_box(tok.encode(black_box(sample.as_str()), false)); + }) + }) + .collect(); + for h in handles { + h.join().unwrap(); + } + }); + }) + }); + } + let mut trainer: TrainerWrapper = BpeTrainerBuilder::default() .show_progress(false) .build() diff --git a/tokenizers/src/models/bpe/mod.rs b/tokenizers/src/models/bpe/mod.rs index f0d40b2df6..abf83382ff 100644 --- a/tokenizers/src/models/bpe/mod.rs +++ b/tokenizers/src/models/bpe/mod.rs @@ -1,4 +1,5 @@ //! [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model. +use rustc_hash::FxHashMap; use std::{iter, mem}; mod model; @@ -8,6 +9,78 @@ mod word; type Pair = (u32, u32); +/// Packs a `(u32, u32)` pair into a single `u64` for faster hashing. +#[inline] +fn pack_pair(pair: &Pair) -> u64 { + (pair.0 as u64) << 32 | pair.1 as u64 +} + +/// Unpacks a `u64` back into a `(u32, u32)` pair. 
+#[inline] +fn unpack_pair(packed: u64) -> Pair { + ((packed >> 32) as u32, packed as u32) +} + +/// A merge-lookup map that packs `(u32, u32)` pair keys into single `u64` values +/// for faster hashing (single FxHash multiply instead of hashing two fields). +/// +/// Values are `(rank, new_id)` tuples. +#[derive(Clone, Debug)] +pub(crate) struct MergeMap { + inner: FxHashMap, +} + +impl MergeMap { + #[allow(dead_code)] + pub fn new() -> Self { + MergeMap { + inner: FxHashMap::default(), + } + } + + pub fn with_capacity(cap: usize) -> Self { + MergeMap { + inner: FxHashMap::with_capacity_and_hasher(cap, Default::default()), + } + } + + #[inline] + pub fn get(&self, pair: &Pair) -> Option<&(u32, u32)> { + self.inner.get(&pack_pair(pair)) + } + + pub fn insert(&mut self, pair: Pair, value: (u32, u32)) -> Option<(u32, u32)> { + self.inner.insert(pack_pair(&pair), value) + } + + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Iterate over `(Pair, &(rank, new_id))`. + pub fn iter(&self) -> impl Iterator { + self.inner.iter().map(|(k, v)| (unpack_pair(*k), v)) + } +} + +impl PartialEq for MergeMap { + fn eq(&self, other: &Self) -> bool { + self.inner == other.inner + } +} + +impl std::iter::FromIterator<(Pair, (u32, u32))> for MergeMap { + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let (lo, _) = iter.size_hint(); + let mut map = MergeMap::with_capacity(lo); + for (pair, val) in iter { + map.insert(pair, val); + } + map + } +} + /// Errors that can be encountered while using or constructing a `BPE` model. #[derive(thiserror::Error, Debug)] pub enum Error { diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 2f560b7e3f..9be27dc446 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -1,4 +1,4 @@ -use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, Pair, Word}; +use super::{super::OrderedVocabIter, trainer::BpeTrainer, Error, MergeMap, Pair, Word}; use crate::tokenizer::{Model, Result, Token}; use crate::utils::cache::{Cache, DEFAULT_CACHE_CAPACITY, MAX_LENGTH}; use crate::utils::iter::ResultShunt; @@ -16,7 +16,6 @@ use std::{ pub type Vocab = AHashMap; type VocabR = AHashMap; -pub type MergeMap = AHashMap; pub type Merges = Vec<(String, String)>; struct Config { @@ -553,12 +552,12 @@ impl Model for BPE { .iter() .collect(); let mut merges_file = File::create(&merges_path)?; - let mut merges: Vec<(&Pair, &u32)> = self + let mut merges: Vec<(Pair, u32)> = self .merges .iter() - .map(|(pair, (rank, _))| (pair, rank)) + .map(|(pair, (rank, _))| (pair, *rank)) .collect(); - merges.sort_unstable_by_key(|k| *k.1); + merges.sort_unstable_by_key(|k| k.1); merges_file.write_all(b"#version: 0.2\n")?; merges_file.write_all( &merges diff --git a/tokenizers/src/models/bpe/serialization.rs b/tokenizers/src/models/bpe/serialization.rs index 98cf549445..c28f2b184f 100644 --- a/tokenizers/src/models/bpe/serialization.rs +++ b/tokenizers/src/models/bpe/serialization.rs @@ -24,12 +24,12 @@ impl Serialize for BPE { model.serialize_field("ignore_merges", &self.ignore_merges)?; // Then the large ones - let mut merges: Vec<(&Pair, &u32)> = self + let mut merges: Vec<(Pair, u32)> = self .merges .iter() - .map(|(pair, (rank, _))| (pair, rank)) + .map(|(pair, (rank, _))| (pair, *rank)) .collect(); - merges.sort_unstable_by_key(|k| *k.1); + merges.sort_unstable_by_key(|k| k.1); let merges = merges .into_iter() .map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone())) diff --git 
a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs index cda6aea654..4b9ec0d588 100644 --- a/tokenizers/src/models/bpe/trainer.rs +++ b/tokenizers/src/models/bpe/trainer.rs @@ -640,7 +640,8 @@ impl Trainer for BpeTrainer { #[cfg(test)] mod tests { - use super::{BpeTrainer, Pair, BPE}; + use super::{BpeTrainer, BPE}; + use crate::models::bpe::MergeMap; use ahash::AHashMap; use compact_str::CompactString; @@ -707,7 +708,7 @@ mod tests { // where 'rank' determines the order in which this merge will be applied during // tokenization, and 'id' is the vocab id of the symbol resulting from merging // the pair of symbols in the corresponding key. - let expected_merges: AHashMap = [ + let expected_merges: MergeMap = [ ((17, 11), (0, 22)), // 'r' + 'e' -> 're' ((8, 22), (1, 23)), // 'a' + 're' -> 'are' ((13, 18), (2, 24)), // 'i' + 's' -> 'is' diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs index 7bf2dee566..9e34defe2f 100644 --- a/tokenizers/src/models/bpe/word.rs +++ b/tokenizers/src/models/bpe/word.rs @@ -1,9 +1,14 @@ -use super::Pair; -use ahash::AHashMap; +use super::{MergeMap, Pair}; use dary_heap::QuaternaryHeap; use rand::{rng, Rng}; +use std::cell::RefCell; use std::cmp::Ordering; +thread_local! { + static TL_MERGE_HEAP: RefCell> = RefCell::new(QuaternaryHeap::new()); + static TL_MERGE_SKIP: RefCell> = RefCell::new(Vec::new()); +} + #[derive(Debug, Eq)] struct Merge { pos: usize, @@ -159,91 +164,97 @@ impl Word { changes } - pub(super) fn merge_all(&mut self, merges: &AHashMap, dropout: Option) { - let mut queue = QuaternaryHeap::with_capacity(self.symbols.len()); - let mut skip = Vec::with_capacity(queue.len()); - - queue.extend( - self.symbols - .windows(2) - .enumerate() - .filter_map(|(index, window)| { - let pair = (window[0].c, window[1].c); - merges.get(&pair).map(|m| Merge { - pos: index, - rank: m.0, - new_id: m.1, - }) - }), - ); - - while let Some(top) = queue.pop() { - if dropout.map(|d| rng().random::() < d).unwrap_or(false) { - skip.push(top); - } else { - // Re-insert the skipped elements - queue.extend(skip.drain(..)); - - if self.symbols[top.pos].len == 0 { - continue; - } - // Do nothing if we are the last symbol - if self.symbols[top.pos].next == -1 { - continue; - } - - let next_pos = self.symbols[top.pos].next as usize; - let right = self.symbols[next_pos]; - - // Make sure we are not processing an expired queue entry - let target_new_pair = (self.symbols[top.pos].c, right.c); - if merges - .get(&target_new_pair) - .is_none_or(|(_, new_id)| *new_id != top.new_id) - { - continue; - } - - // Otherwise, let's merge - self.symbols[top.pos].merge_with(&right, top.new_id); - // Tag the right part as removed - self.symbols[next_pos].len = 0; - - // Update `prev` on the new `next` to the current pos - if right.next > -1 && (right.next as usize) < self.symbols.len() { - self.symbols[right.next as usize].prev = top.pos as isize; - } - - // Insert the new pair formed with the previous symbol - let current = &self.symbols[top.pos]; - if current.prev >= 0 { - let prev = current.prev as usize; - let prev_symbol = self.symbols[prev]; - let new_pair = (prev_symbol.c, current.c); - if let Some((rank, new_id)) = merges.get(&new_pair) { - queue.push(Merge { - pos: current.prev as usize, - rank: *rank, - new_id: *new_id, - }); - } - } - - // Insert the new pair formed with the next symbol - let next = current.next as usize; - if next < self.symbols.len() { - let next_symbol = self.symbols[next]; - let new_pair = (current.c, 
next_symbol.c); - if let Some((rank, new_id)) = merges.get(&new_pair) { - queue.push(Merge { - pos: top.pos, - rank: *rank, - new_id: *new_id, - }); + pub(super) fn merge_all(&mut self, merges: &MergeMap, dropout: Option) { + TL_MERGE_HEAP.with(|heap_cell| { + TL_MERGE_SKIP.with(|skip_cell| { + let mut queue = heap_cell.borrow_mut(); + let mut skip = skip_cell.borrow_mut(); + queue.clear(); + skip.clear(); + + queue.extend( + self.symbols + .windows(2) + .enumerate() + .filter_map(|(index, window)| { + let pair = (window[0].c, window[1].c); + merges.get(&pair).map(|m| Merge { + pos: index, + rank: m.0, + new_id: m.1, + }) + }), + ); + + while let Some(top) = queue.pop() { + if dropout.map(|d| rng().random::() < d).unwrap_or(false) { + skip.push(top); + } else { + // Re-insert the skipped elements + queue.extend(skip.drain(..)); + + if self.symbols[top.pos].len == 0 { + continue; + } + // Do nothing if we are the last symbol + if self.symbols[top.pos].next == -1 { + continue; + } + + let next_pos = self.symbols[top.pos].next as usize; + let right = self.symbols[next_pos]; + + // Make sure we are not processing an expired queue entry + let target_new_pair = (self.symbols[top.pos].c, right.c); + if merges + .get(&target_new_pair) + .is_none_or(|(_, new_id)| *new_id != top.new_id) + { + continue; + } + + // Otherwise, let's merge + self.symbols[top.pos].merge_with(&right, top.new_id); + // Tag the right part as removed + self.symbols[next_pos].len = 0; + + // Update `prev` on the new `next` to the current pos + if right.next > -1 && (right.next as usize) < self.symbols.len() { + self.symbols[right.next as usize].prev = top.pos as isize; + } + + // Insert the new pair formed with the previous symbol + let current = &self.symbols[top.pos]; + if current.prev >= 0 { + let prev = current.prev as usize; + let prev_symbol = self.symbols[prev]; + let new_pair = (prev_symbol.c, current.c); + if let Some((rank, new_id)) = merges.get(&new_pair) { + queue.push(Merge { + pos: current.prev as usize, + rank: *rank, + new_id: *new_id, + }); + } + } + + // Insert the new pair formed with the next symbol + let next = current.next as usize; + if next < self.symbols.len() { + let next_symbol = self.symbols[next]; + let new_pair = (current.c, next_symbol.c); + if let Some((rank, new_id)) = merges.get(&new_pair) { + queue.push(Merge { + pos: top.pos, + rank: *rank, + new_id: *new_id, + }); + } + } } } - } - } + }); + }); // Filter out the removed symbols self.symbols.retain(|s| s.len != 0); diff --git a/tokenizers/src/utils/cache.rs b/tokenizers/src/utils/cache.rs index 15c6b65f18..47ecf0e61b 100644 --- a/tokenizers/src/utils/cache.rs +++ b/tokenizers/src/utils/cache.rs @@ -1,6 +1,6 @@ -use ahash::AHashMap; +use rustc_hash::FxHashMap; use std::borrow::Borrow; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; use std::sync::RwLock; /// The default capacity for a `BPE`'s internal cache. @@ -9,25 +9,126 @@ pub static DEFAULT_CACHE_CAPACITY: usize = 10_000; /// Strings that are too long have minimal chances to cache hit anyway pub static MAX_LENGTH: usize = 256; -/// Provides a simple multithread cache to speed up BPE tokenization that will try to read values -/// concurrently but won't block if another thread is writing. -/// The goal is clearly not the accuracy of the content, both get and set -/// are not guaranteed to actually get or set. -#[derive(Debug)] +/// Number of shards in the shared cache. 
+const SHARED_CACHE_SHARDS: usize = 64; + +// --------------------------------------------------------------------------- +// FxHash helper +// --------------------------------------------------------------------------- + +#[inline] +fn fx_hash(key: &K) -> u64 { + let mut h = rustc_hash::FxHasher::default(); + key.hash(&mut h); + h.finish() +} + +// --------------------------------------------------------------------------- +// Sharded cache +// --------------------------------------------------------------------------- + +struct ShardedMap { + shards: Vec>>, + per_shard_capacity: usize, +} + +impl ShardedMap { + fn new(total_capacity: usize) -> Self { + let per_shard = total_capacity.div_ceil(SHARED_CACHE_SHARDS).max(1); + let shards = (0..SHARED_CACHE_SHARDS) + .map(|_| { + RwLock::new(FxHashMap::with_capacity_and_hasher( + per_shard, + Default::default(), + )) + }) + .collect(); + ShardedMap { + shards, + per_shard_capacity: per_shard, + } + } + + #[inline] + fn shard_for(key: &Q) -> usize { + let h = fx_hash(key); + (h >> 48) as usize % SHARED_CACHE_SHARDS + } + + fn get(&self, key: &Q) -> Option + where + K: Borrow, + Q: Hash + Eq + ?Sized, + { + let idx = Self::shard_for(key); + let shard = &self.shards[idx]; + if let Ok(guard) = shard.try_read() { + guard.get(key).cloned() + } else { + None + } + } + + fn set(&self, key: K, value: V) { + let idx = Self::shard_for(&key); + let shard = &self.shards[idx]; + if let Ok(guard) = shard.try_read() { + if guard.len() >= self.per_shard_capacity { + return; + } + } else { + return; + } + if let Ok(mut guard) = shard.try_write() { + if guard.len() < self.per_shard_capacity { + guard.insert(key, value); + } + } + } + + fn clear(&self) { + for shard in &self.shards { + if let Ok(mut guard) = shard.write() { + guard.clear(); + } + } + } +} + +// --------------------------------------------------------------------------- +// Public Cache +// --------------------------------------------------------------------------- + +/// Sharded cache for fast concurrent tokenization lookups. +/// +/// Uses 64 shards with per-shard `RwLock` to minimize lock +/// contention across threads. FxHash provides fast, non-cryptographic hashing +/// suited to the small keys used in tokenization caches. pub(crate) struct Cache where - K: Eq + Hash + Clone, - V: Clone, + K: Eq + Hash + Clone + 'static, + V: Clone + 'static, { - map: RwLock>, + map: ShardedMap, pub capacity: usize, } -// We dont really care about Cache comparison, so let's make them always equal +impl std::fmt::Debug for Cache +where + K: Eq + Hash + Clone + 'static, + V: Clone + 'static, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Cache") + .field("capacity", &self.capacity) + .finish() + } +} + impl PartialEq for Cache where - K: Eq + Hash + Clone, - V: Clone, + K: Eq + Hash + Clone + 'static, + V: Clone + 'static, { fn eq(&self, _other: &Cache) -> bool { true @@ -36,8 +137,8 @@ where impl Default for Cache where - K: Eq + Hash + Clone, - V: Clone, + K: Eq + Hash + Clone + 'static, + V: Clone + 'static, { fn default() -> Self { Self::new(DEFAULT_CACHE_CAPACITY) @@ -46,13 +147,15 @@ where impl Cache where - K: Eq + Hash + Clone, - V: Clone, + K: Eq + Hash + Clone + 'static, + V: Clone + 'static, { /// Create new `Cache` with the given capacity. 
pub(crate) fn new(capacity: usize) -> Self { - let map = RwLock::new(AHashMap::with_capacity(capacity)); - Cache { map, capacity } + Cache { + map: ShardedMap::new(capacity), + capacity, + } } /// Create a fresh `Cache` with the same configuration. @@ -62,7 +165,7 @@ where /// Clear the cache. pub(crate) fn clear(&self) { - self.map.write().unwrap().clear(); + self.map.clear(); } #[allow(dead_code)] @@ -72,11 +175,7 @@ where K: Borrow, Q: Hash + Eq + ?Sized + 'a, { - if let Ok(ref mut cache) = self.map.try_read() { - Some(keys_iter.map(|k| cache.get(k).cloned()).collect()) - } else { - None - } + Some(keys_iter.map(|k| self.get(k)).collect()) } pub(crate) fn get(&self, key: &Q) -> Option @@ -84,45 +183,25 @@ where K: Borrow, Q: Hash + Eq + ?Sized, { - if let Ok(ref mut cache) = self.map.try_read() { - cache.get(key).cloned() - } else { - None - } + self.map.get(key) } + #[allow(dead_code)] pub(crate) fn set_values(&self, entries: I) where I: IntoIterator, { - // Before trying to acquire a write lock, we check if we are already at - // capacity with a read handler. - if let Ok(cache) = self.map.try_read() { - if cache.len() >= self.capacity { - // At capacity, so do nothing. - return; - } - } else { - // If we couldn't acquire a read handle then we probably won't be able to acquire - // a write handle one quadrillionth of a second later. - return; - } - - // Not at capacity, so try acquiring a write handle. - if let Ok(mut cache) = self.map.try_write() { - let free = self.capacity - cache.len(); - cache.extend(entries.into_iter().take(free)); + for (k, v) in entries { + self.map.set(k, v); } } pub(crate) fn set(&self, key: K, value: V) { - self.set_values(std::iter::once((key, value))) + self.map.set(key, value); } pub(crate) fn resize(&mut self, capacity: usize) { self.capacity = capacity; - if let Ok(mut cache) = self.map.try_write() { - cache.shrink_to(capacity); - } + self.map = ShardedMap::new(capacity); } } From e779f54a14c15c1d77c7b7455c3965bc43409522 Mon Sep 17 00:00:00 2001 From: michaelfeil <63565275+michaelfeil@users.noreply.github.com> Date: Wed, 18 Mar 2026 21:12:09 -0700 Subject: [PATCH 02/13] add byte level lookup --- tokenizers/src/pre_tokenizers/byte_level.rs | 23 +++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs index 8bc0f30af0..4295ecef7e 100644 --- a/tokenizers/src/pre_tokenizers/byte_level.rs +++ b/tokenizers/src/pre_tokenizers/byte_level.rs @@ -48,6 +48,17 @@ static BYTES_CHAR: LazyLock> = LazyLock::new(bytes_char); static CHAR_BYTES: LazyLock> = LazyLock::new(|| bytes_char().into_iter().map(|(c, b)| (b, c)).collect()); +/// Flat lookup table: byte value → unicode char. Eliminates HashMap lookup +/// in the byte-level encoding hot path. +static BYTE_TO_CHAR: LazyLock<[char; 256]> = LazyLock::new(|| { + let map = bytes_char(); + let mut table = ['\0'; 256]; + for (b, c) in &map { + table[*b as usize] = *c; + } + table +}); + #[derive(Copy, Clone, Debug, PartialEq, Eq)] /// Provides all the necessary steps to handle the BPE tokenization at the byte-level. 
Takes care /// of all the required processing steps to transform a UTF-8 string as needed before and after the @@ -131,15 +142,15 @@ impl PreTokenizer for ByteLevel { })?; pretokenized.normalize(|normalized| { let s = normalized.get(); + let table = &*BYTE_TO_CHAR; let mut transformations: Vec<(char, isize)> = Vec::with_capacity(s.len()); for (i, cur_char) in s.char_indices() { let size = cur_char.len_utf8(); - transformations.extend( - s.as_bytes()[i..i + size] - .iter() - .enumerate() - .map(|(i, b)| (BYTES_CHAR[b], isize::from(i > 0))), - ); + let bytes = &s.as_bytes()[i..i + size]; + transformations.push((table[bytes[0] as usize], 0)); + for &b in &bytes[1..] { + transformations.push((table[b as usize], 1)); + } } normalized.transform(transformations, 0); Ok(()) From 9c53d924dac2dfc90e5a1927772da44b6ff22fd2 Mon Sep 17 00:00:00 2001 From: michaelfeil <63565275+michaelfeil@users.noreply.github.com> Date: Thu, 19 Mar 2026 08:49:19 -0700 Subject: [PATCH 03/13] sync llama 3 benchmark --- tokenizers/benches/llama3_benchmark.rs | 35 +++++++++++++------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/tokenizers/benches/llama3_benchmark.rs b/tokenizers/benches/llama3_benchmark.rs index 9bd02b952b..ad0dd1c29a 100644 --- a/tokenizers/benches/llama3_benchmark.rs +++ b/tokenizers/benches/llama3_benchmark.rs @@ -45,30 +45,31 @@ pub fn llama3(c: &mut Criterion) { b.iter_custom(|iters| iter_bench_encode_batch(iters, &tokenizer, &batches)) }); - // Cache effectiveness: encode the same medium-length input 1000 times - group.bench_function("llama3-cache-repeated", |b| { - let sample = data.lines().find(|l| l.len() > 200).unwrap_or("hello world"); - b.iter(|| { - for _ in 0..1000 { - let _ = black_box(tokenizer.encode(black_box(sample), false)); - } - }) - }); - - // Concurrent: N threads each encode a medium prompt + // Concurrent long-context: N threads each encode a different ~10KB input + // through a shared tokenizer. Each thread gets 200 unique lines, simulating + // concurrent inference requests. 
+ let all_lines: Vec<&str> = data.lines().collect(); + let lines_per_thread = 200; let tokenizer_arc = Arc::new(tokenizer.clone()); for num_threads in [2, 4, 8] { + let inputs: Vec = (0..num_threads) + .map(|i| { + let start = i * lines_per_thread; + all_lines[start..start + lines_per_thread].join("\n") + }) + .collect(); + let total_bytes: usize = inputs.iter().map(|s| s.len()).sum(); let tok = tokenizer_arc.clone(); - let sample: String = data.lines().take(3).collect::>().join("\n"); - group.bench_function(format!("llama3-concurrent-{num_threads}t"), move |b| { + group.throughput(Throughput::Bytes(total_bytes as u64)); + group.bench_function(format!("llama3-concurrent-long-{num_threads}t"), move |b| { b.iter(|| { std::thread::scope(|s| { - let handles: Vec<_> = (0..num_threads) - .map(|_| { + let handles: Vec<_> = inputs + .iter() + .map(|input| { let tok = &tok; - let sample = &sample; s.spawn(move || { - let _ = black_box(tok.encode(black_box(sample.as_str()), false)); + black_box(tok.encode(black_box(input.as_str()), false).unwrap()) }) }) .collect(); From 473588ef45d8e42337d5dfc6f2aefb86058bf75a Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Thu, 19 Mar 2026 14:13:01 -0700 Subject: [PATCH 04/13] bring back train bench --- tokenizers/benches/llama3_benchmark.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tokenizers/benches/llama3_benchmark.rs b/tokenizers/benches/llama3_benchmark.rs index ad0dd1c29a..fcfce6175e 100644 --- a/tokenizers/benches/llama3_benchmark.rs +++ b/tokenizers/benches/llama3_benchmark.rs @@ -44,14 +44,11 @@ pub fn llama3(c: &mut Criterion) { group.bench_function("llama3-batch", |b| { b.iter_custom(|iters| iter_bench_encode_batch(iters, &tokenizer, &batches)) }); - - // Concurrent long-context: N threads each encode a different ~10KB input - // through a shared tokenizer. Each thread gets 200 unique lines, simulating - // concurrent inference requests. + // Concurrent long-context: N threads each encode a different large input (80k chars) let all_lines: Vec<&str> = data.lines().collect(); - let lines_per_thread = 200; + let lines_per_thread = 1000; let tokenizer_arc = Arc::new(tokenizer.clone()); - for num_threads in [2, 4, 8] { + for num_threads in [1, 2, 4, 8] { let inputs: Vec = (0..num_threads) .map(|i| { let start = i * lines_per_thread; From 43678bfbb27a38206f86d6a51383d0324c3047f4 Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Tue, 24 Mar 2026 13:03:15 -0700 Subject: [PATCH 05/13] fix const Signed-off-by: Michael Feil <63565275+michaelfeil@users.noreply.github.com> --- tokenizers/src/models/bpe/word.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs index 9e34defe2f..8770d0d630 100644 --- a/tokenizers/src/models/bpe/word.rs +++ b/tokenizers/src/models/bpe/word.rs @@ -6,7 +6,7 @@ use std::cmp::Ordering; thread_local! 
{ static TL_MERGE_HEAP: RefCell> = RefCell::new(QuaternaryHeap::new()); - static TL_MERGE_SKIP: RefCell> = RefCell::new(Vec::new()); + static TL_MERGE_SKIP: RefCell> = const { RefCell::new(Vec::new()) }; } #[derive(Debug, Eq)] From 3fc1a8fd3fbaa05491c182d124fb39de99a2216e Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Wed, 25 Mar 2026 15:08:43 -0700 Subject: [PATCH 06/13] Update tokenizers/src/utils/cache.rs Co-authored-by: Luc Georges --- tokenizers/src/utils/cache.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tokenizers/src/utils/cache.rs b/tokenizers/src/utils/cache.rs index 47ecf0e61b..5cef8fd2ed 100644 --- a/tokenizers/src/utils/cache.rs +++ b/tokenizers/src/utils/cache.rs @@ -23,10 +23,6 @@ fn fx_hash(key: &K) -> u64 { h.finish() } -// --------------------------------------------------------------------------- -// Sharded cache -// --------------------------------------------------------------------------- - struct ShardedMap { shards: Vec>>, per_shard_capacity: usize, From b473bcfdf0433b817d2df8be6caf8381943033a0 Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Wed, 25 Mar 2026 15:08:56 -0700 Subject: [PATCH 07/13] Update tokenizers/src/utils/cache.rs Co-authored-by: Luc Georges --- tokenizers/src/utils/cache.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tokenizers/src/utils/cache.rs b/tokenizers/src/utils/cache.rs index 5cef8fd2ed..93a0b0b878 100644 --- a/tokenizers/src/utils/cache.rs +++ b/tokenizers/src/utils/cache.rs @@ -12,10 +12,6 @@ pub static MAX_LENGTH: usize = 256; /// Number of shards in the shared cache. const SHARED_CACHE_SHARDS: usize = 64; -// --------------------------------------------------------------------------- -// FxHash helper -// --------------------------------------------------------------------------- - #[inline] fn fx_hash(key: &K) -> u64 { let mut h = rustc_hash::FxHasher::default(); From 586cbaf96776d15ba70ad24c73f463bae93a6142 Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Sat, 11 Apr 2026 18:27:37 -0700 Subject: [PATCH 08/13] fix typos and comments --- tokenizers/src/models/bpe/mod.rs | 2 ++ tokenizers/src/utils/cache.rs | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tokenizers/src/models/bpe/mod.rs b/tokenizers/src/models/bpe/mod.rs index abf83382ff..668f068261 100644 --- a/tokenizers/src/models/bpe/mod.rs +++ b/tokenizers/src/models/bpe/mod.rs @@ -45,10 +45,12 @@ impl MergeMap { } #[inline] + /// Get `(rank, new_id)` for a given `Pair` in the map. pub fn get(&self, pair: &Pair) -> Option<&(u32, u32)> { self.inner.get(&pack_pair(pair)) } + /// Insert `(rank, new_id)` for a given `Pair` in the map. pub fn insert(&mut self, pair: Pair, value: (u32, u32)) -> Option<(u32, u32)> { self.inner.insert(pack_pair(&pair), value) } diff --git a/tokenizers/src/utils/cache.rs b/tokenizers/src/utils/cache.rs index 93a0b0b878..19e974e1f1 100644 --- a/tokenizers/src/utils/cache.rs +++ b/tokenizers/src/utils/cache.rs @@ -9,7 +9,7 @@ pub static DEFAULT_CACHE_CAPACITY: usize = 10_000; /// Strings that are too long have minimal chances to cache hit anyway pub static MAX_LENGTH: usize = 256; -/// Number of shards in the shared cache. +/// Number of shards in the sharded cache. 
const SHARED_CACHE_SHARDS: usize = 64; #[inline] From 9ddab19e6b4037abcbbd38b85e8b8ba14f3ecde7 Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Sat, 11 Apr 2026 18:58:28 -0700 Subject: [PATCH 09/13] add extend from cache, so we don't have to return a clone in shared cache --- tokenizers/src/models/bpe/model.rs | 7 ++- tokenizers/src/models/bpe/word.rs | 8 ++++ tokenizers/src/models/unigram/model.rs | 22 +++++----- tokenizers/src/utils/cache.rs | 60 ++++++++++++-------------- 4 files changed, 52 insertions(+), 45 deletions(-) diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 9be27dc446..93762007f4 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -481,8 +481,11 @@ impl BPE { )]); } } - if let Some(ref hit) = self.cache.as_ref().and_then(|c| c.get(sequence)) { - return Ok(self.word_to_tokens(hit).collect()); + if let Some(ref cache) = self.cache { + let mut word = Word::new(); + if cache.get_into(sequence, &mut word) { + return Ok(self.word_to_tokens(&word).collect()); + } } let word = self.merge_word(sequence)?; let ret = self.word_to_tokens(&word).collect(); diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs index 8770d0d630..6d7603bcc1 100644 --- a/tokenizers/src/models/bpe/word.rs +++ b/tokenizers/src/models/bpe/word.rs @@ -1,4 +1,5 @@ use super::{MergeMap, Pair}; +use crate::utils::cache::ExtendFromRef; use dary_heap::QuaternaryHeap; use rand::{rng, Rng}; use std::cell::RefCell; @@ -62,6 +63,13 @@ impl Symbol { pub(super) struct Word { symbols: Vec, } + +impl ExtendFromRef for Word { + fn extend_from_ref(&mut self, other: &Self) { + self.symbols.extend_from_slice(&other.symbols); + } +} + impl std::fmt::Debug for Word { fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { fmt.debug_struct("Word") diff --git a/tokenizers/src/models/unigram/model.rs b/tokenizers/src/models/unigram/model.rs index 3a9a6bddbd..89e5da55aa 100644 --- a/tokenizers/src/models/unigram/model.rs +++ b/tokenizers/src/models/unigram/model.rs @@ -233,19 +233,19 @@ impl Unigram { return Ok(vec![]); } if self.alpha.is_none() || self.alpha == Some(0.0) { - if let Some(result) = self.cache.get(sentence) { - Ok(result.to_vec()) + let mut result = Vec::new(); + if self.cache.get_into(sentence, &mut result) { + return Ok(result); + } + let result = if self.is_optimized { + self.encode_optimized(sentence)? } else { - let result = if self.is_optimized { - self.encode_optimized(sentence)? - } else { - self.encode_unoptimized(sentence)? - }; - if sentence.len() < MAX_LENGTH { - self.cache.set(sentence.to_owned(), result.clone()); - } - Ok(result) + self.encode_unoptimized(sentence)? + }; + if sentence.len() < MAX_LENGTH { + self.cache.set(sentence.to_owned(), result.clone()); } + Ok(result) } else { let result = self.encode_unoptimized(sentence)?; Ok(result) diff --git a/tokenizers/src/utils/cache.rs b/tokenizers/src/utils/cache.rs index 19e974e1f1..4fc892fd71 100644 --- a/tokenizers/src/utils/cache.rs +++ b/tokenizers/src/utils/cache.rs @@ -12,6 +12,18 @@ pub static MAX_LENGTH: usize = 256; /// Number of shards in the sharded cache. const SHARED_CACHE_SHARDS: usize = 64; +/// Trait for copying data from a reference into a mutable buffer. +/// Used by the cache to avoid cloning on cache hits. 
+pub trait ExtendFromRef { + fn extend_from_ref(&mut self, other: &Self); +} + +impl ExtendFromRef for Vec { + fn extend_from_ref(&mut self, other: &Self) { + self.extend_from_slice(other); + } +} + #[inline] fn fx_hash(key: &K) -> u64 { let mut h = rustc_hash::FxHasher::default(); @@ -24,7 +36,7 @@ struct ShardedMap { per_shard_capacity: usize, } -impl ShardedMap { +impl ShardedMap { fn new(total_capacity: usize) -> Self { let per_shard = total_capacity.div_ceil(SHARED_CACHE_SHARDS).max(1); let shards = (0..SHARED_CACHE_SHARDS) @@ -47,7 +59,7 @@ impl ShardedMap { (h >> 48) as usize % SHARED_CACHE_SHARDS } - fn get(&self, key: &Q) -> Option + fn get_into(&self, key: &Q, out: &mut V) -> bool where K: Borrow, Q: Hash + Eq + ?Sized, @@ -55,10 +67,12 @@ impl ShardedMap { let idx = Self::shard_for(key); let shard = &self.shards[idx]; if let Ok(guard) = shard.try_read() { - guard.get(key).cloned() - } else { - None + if let Some(value) = guard.get(key) { + out.extend_from_ref(value); + return true; + } } + false } fn set(&self, key: K, value: V) { @@ -99,7 +113,7 @@ impl ShardedMap { pub(crate) struct Cache where K: Eq + Hash + Clone + 'static, - V: Clone + 'static, + V: ExtendFromRef + 'static, { map: ShardedMap, pub capacity: usize, @@ -108,7 +122,7 @@ where impl std::fmt::Debug for Cache where K: Eq + Hash + Clone + 'static, - V: Clone + 'static, + V: ExtendFromRef + 'static, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Cache") @@ -120,7 +134,7 @@ where impl PartialEq for Cache where K: Eq + Hash + Clone + 'static, - V: Clone + 'static, + V: ExtendFromRef + 'static, { fn eq(&self, _other: &Cache) -> bool { true @@ -130,7 +144,7 @@ where impl Default for Cache where K: Eq + Hash + Clone + 'static, - V: Clone + 'static, + V: ExtendFromRef + 'static, { fn default() -> Self { Self::new(DEFAULT_CACHE_CAPACITY) @@ -140,7 +154,7 @@ where impl Cache where K: Eq + Hash + Clone + 'static, - V: Clone + 'static, + V: ExtendFromRef + 'static, { /// Create new `Cache` with the given capacity. pub(crate) fn new(capacity: usize) -> Self { @@ -160,32 +174,14 @@ where self.map.clear(); } - #[allow(dead_code)] - pub(crate) fn get_values<'a, I, Q>(&self, keys_iter: I) -> Option>> - where - I: Iterator, - K: Borrow, - Q: Hash + Eq + ?Sized + 'a, - { - Some(keys_iter.map(|k| self.get(k)).collect()) - } - - pub(crate) fn get(&self, key: &Q) -> Option + /// Get a value from the cache, extending the output buffer. + /// Returns true if the key was found, false otherwise. + pub(crate) fn get_into(&self, key: &Q, out: &mut V) -> bool where K: Borrow, Q: Hash + Eq + ?Sized, { - self.map.get(key) - } - - #[allow(dead_code)] - pub(crate) fn set_values(&self, entries: I) - where - I: IntoIterator, - { - for (k, v) in entries { - self.map.set(k, v); - } + self.map.get_into(key, out) } pub(crate) fn set(&self, key: K, value: V) { From 32d96a8da54baedaff4ef2cd2819f8050b510735 Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Sat, 11 Apr 2026 18:58:45 -0700 Subject: [PATCH 10/13] add extend from cache, so we don't have to return a clone in shared cache --- tokenizers/src/models/bpe/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokenizers/src/models/bpe/mod.rs b/tokenizers/src/models/bpe/mod.rs index 668f068261..603c29c815 100644 --- a/tokenizers/src/models/bpe/mod.rs +++ b/tokenizers/src/models/bpe/mod.rs @@ -45,7 +45,7 @@ impl MergeMap { } #[inline] - /// Get `(rank, new_id)` for a given `Pair` in the map. 
+ /// Get `(rank, new_id)` for a given `Pair` in the map. pub fn get(&self, pair: &Pair) -> Option<&(u32, u32)> { self.inner.get(&pack_pair(pair)) } From d791e82716fb4f8c36422a96586068a93f55a051 Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Sat, 11 Apr 2026 19:00:35 -0700 Subject: [PATCH 11/13] add static --- tokenizers/src/utils/cache.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tokenizers/src/utils/cache.rs b/tokenizers/src/utils/cache.rs index 4fc892fd71..aa8a7067bc 100644 --- a/tokenizers/src/utils/cache.rs +++ b/tokenizers/src/utils/cache.rs @@ -112,8 +112,8 @@ impl ShardedMap { /// suited to the small keys used in tokenization caches. pub(crate) struct Cache where - K: Eq + Hash + Clone + 'static, - V: ExtendFromRef + 'static, + K: Eq + Hash + Clone, + V: ExtendFromRef, { map: ShardedMap, pub capacity: usize, @@ -121,8 +121,8 @@ where impl std::fmt::Debug for Cache where - K: Eq + Hash + Clone + 'static, - V: ExtendFromRef + 'static, + K: Eq + Hash + Clone, + V: ExtendFromRef, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Cache") @@ -133,8 +133,8 @@ where impl PartialEq for Cache where - K: Eq + Hash + Clone + 'static, - V: ExtendFromRef + 'static, + K: Eq + Hash + Clone, + V: ExtendFromRef, { fn eq(&self, _other: &Cache) -> bool { true @@ -143,8 +143,8 @@ where impl Default for Cache where - K: Eq + Hash + Clone + 'static, - V: ExtendFromRef + 'static, + K: Eq + Hash + Clone, + V: ExtendFromRef, { fn default() -> Self { Self::new(DEFAULT_CACHE_CAPACITY) @@ -153,8 +153,8 @@ where impl Cache where - K: Eq + Hash + Clone + 'static, - V: ExtendFromRef + 'static, + K: Eq + Hash + Clone, + V: ExtendFromRef, { /// Create new `Cache` with the given capacity. pub(crate) fn new(capacity: usize) -> Self { From 104b735e2f0fe2aad91fb06979bc8f5d5452a903 Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Sat, 11 Apr 2026 19:33:27 -0700 Subject: [PATCH 12/13] tl word cache --- tokenizers/src/models/bpe/model.rs | 32 +++++++++++++++++++----------- tokenizers/src/models/bpe/word.rs | 4 ++++ 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index 93762007f4..aabffce66b 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -5,6 +5,7 @@ use crate::utils::iter::ResultShunt; use ahash::AHashMap; use serde_json::Value; use std::borrow::Cow; +use std::cell::RefCell; use std::collections::HashMap; use std::{ @@ -14,6 +15,10 @@ use std::{ path::{Path, PathBuf}, }; +thread_local! 
{ + static TL_WORD: RefCell = RefCell::new(Word::with_capacity(MAX_LENGTH)); +} + pub type Vocab = AHashMap; type VocabR = AHashMap; pub type Merges = Vec<(String, String)>; @@ -481,20 +486,23 @@ impl BPE { )]); } } - if let Some(ref cache) = self.cache { - let mut word = Word::new(); - if cache.get_into(sequence, &mut word) { - return Ok(self.word_to_tokens(&word).collect()); + TL_WORD.with(|w| { + let mut word = w.borrow_mut(); + word.clear(); + if let Some(ref cache) = self.cache { + if cache.get_into(sequence, &mut word) { + return Ok(self.word_to_tokens(&word).collect()); + } } - } - let word = self.merge_word(sequence)?; - let ret = self.word_to_tokens(&word).collect(); - if let Some(ref cache) = self.cache { - if sequence.len() < MAX_LENGTH { - cache.set(sequence.to_owned(), word); + let word = self.merge_word(sequence)?; + let ret = self.word_to_tokens(&word).collect(); + if let Some(ref cache) = self.cache { + if sequence.len() < MAX_LENGTH { + cache.set(sequence.to_owned(), word); + } } - } - Ok(ret) + Ok(ret) + }) } } diff --git a/tokenizers/src/models/bpe/word.rs b/tokenizers/src/models/bpe/word.rs index 6d7603bcc1..640ec1da60 100644 --- a/tokenizers/src/models/bpe/word.rs +++ b/tokenizers/src/models/bpe/word.rs @@ -98,6 +98,10 @@ impl Word { } } + pub(super) fn clear(&mut self) { + self.symbols.clear(); + } + pub(super) fn add(&mut self, c: u32, byte_len: usize) { let (prev, next) = { let len = self.symbols.len() as isize; From d4386c6a85c48fc851e12f16a43e00439cc01aea Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Sat, 11 Apr 2026 19:37:26 -0700 Subject: [PATCH 13/13] cache size --- tokenizers/src/models/bpe/model.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs index aabffce66b..6d12be0c86 100644 --- a/tokenizers/src/models/bpe/model.rs +++ b/tokenizers/src/models/bpe/model.rs @@ -16,7 +16,7 @@ use std::{ }; thread_local! { - static TL_WORD: RefCell = RefCell::new(Word::with_capacity(MAX_LENGTH)); + static TL_WORD: RefCell = RefCell::new(Word::with_capacity(64)); } pub type Vocab = AHashMap;
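Not part of the patch series: a minimal illustrative sketch of the two hot-path tricks introduced above, assuming the rustc-hash crate added in the first patch is available. It re-states the pack_pair/unpack_pair bit layout and the ShardedMap::shard_for routing from the diffs; the string key and literal values are arbitrary examples.

    use std::hash::{Hash, Hasher};

    // Shard count matches SHARED_CACHE_SHARDS in the patched cache.rs.
    const SHARED_CACHE_SHARDS: usize = 64;

    fn main() {
        // Pair packing (pack_pair / unpack_pair): the left id occupies the
        // high 32 bits, the right id the low 32 bits, so the round trip is lossless.
        let pair: (u32, u32) = (17, 11);
        let packed: u64 = (pair.0 as u64) << 32 | pair.1 as u64;
        let unpacked = ((packed >> 32) as u32, packed as u32);
        assert_eq!(pair, unpacked);

        // Shard routing (ShardedMap::shard_for): hash the key with FxHash,
        // take the top 16 bits, then reduce modulo the shard count.
        let mut h = rustc_hash::FxHasher::default();
        "hello world".hash(&mut h); // illustrative key, not from the patch
        let shard = (h.finish() >> 48) as usize % SHARED_CACHE_SHARDS;
        assert!(shard < SHARED_CACHE_SHARDS);
    }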