From 9811b633482a360f3a63a2eee84e7d361d1f5f84 Mon Sep 17 00:00:00 2001
From: Sebastian Pop
Date: Wed, 22 Apr 2026 23:05:48 -0500
Subject: [PATCH] Batch encode: lock-free work queue with dynamic window
 sizing

Replace `inputs.into_maybe_par_iter().map(...).collect()` in
`encode_batch`, `encode_batch_char_offsets` and `encode_batch_fast` with
a small helper `TokenizerImpl::run_batch` that:

- Dispatches to a plain `inputs.into_iter().map(...).collect()` serial
  loop when parallelism is disabled or only one thread is available,
  avoiding all rayon involvement for single-threaded callers.

- At higher thread counts, uses a lock-free atomic counter
  (`BatchWorkQueue`) inside one `rayon::scope` with one `s.spawn` per
  worker. Each worker claims windows of item indices via a single
  `AtomicUsize::fetch_add`, takes inputs from per-slot
  `UnsafeCell<Option<EncodeInput>>`, and writes results into per-slot
  `UnsafeCell<Option<Result<Encoding>>>`. No shared mutable state
  outside the counter; no final `collect()` on a parallel iterator.

The lock-free design is motivated by aarch64 LSE atomic cost: every
mutex / condvar operation the previous parallel-iterator path took was
a CAS / LDADD emitted by libpthread, and those dominate small-work
parallel loops at high thread counts on arm64. Replacing that with a
single `fetch_add` per window removes the mutex-backed per-item
signaling entirely.

## Cache-line / loop-tiling rationale

Shared-memory parallel loops are bottlenecked by the cache coherence
protocol when two cores alternate writes to the same cache line: the
line "ping-pongs" between their private L1d caches, each transfer
costing dozens of cycles. To avoid that, every line should be filled by
one producer core, drained (or no longer needed), and only then touched
by a different core. This is the cache-aware equivalent of loop
tiling / blocking: group the iteration space into chunks whose data
footprint is a whole number of cache lines, and give each chunk to a
single core.

The work queue enforces this in three ways:

1. The counter itself lives on its own 64-byte cache line
   (`#[repr(C, align(64))]` on `AlignedCounter`). A worker's
   `fetch_add` does not evict any neighbouring data, and reads of the
   counter do not pull input or result payloads into the core's L1d.

2. Each window is a contiguous run of `window_size` indices, so every
   worker owns a run of adjacent slots for the duration of one window.
   With `MAX_WINDOW_SIZE = 8`, a window covers roughly
   `8 * sizeof(slot)` bytes -- for `Option<EncodeInput>` (~48 B) that
   is ~6 cache lines; for `Option<Result<Encoding>>` (multiple cache
   lines per slot) it is even more. Within one window, a worker writes
   several whole cache lines before any other worker comes near them.

3. Each slot has its own `UnsafeCell`
   (`Vec<UnsafeCell<Option<T>>>`). `UnsafeCell<T>` is
   `#[repr(transparent)]`, so the heap layout is byte-identical to a
   plain `Vec<Option<T>>` (no padding, same alignment, same contiguous
   packing -- zero runtime overhead vs. the "unsafe fast" version that
   reborrows the whole `Vec`). What the per-slot cell buys is that
   `self.0[i].get()` returns a `*mut Option<T>` pointing straight at
   slot `i`, without ever materialising a `&mut Vec<Option<T>>` that
   would alias the enclosing container (which is UB when two threads
   touch any distinct indices concurrently).

At window boundaries a single cache line can be shared between two
successive windows when the slot size does not divide 64 bytes. That is
a sequential handoff (window N finishes writes; window N+1 then
reads/writes), not a concurrent ping-pong, so the cost is at most one
coherence transfer per window-pair.
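As a standalone illustration of point 3 above (not part of the patch;
`u64` stands in for the real payload type), the layout claim and the
per-slot raw-pointer access pattern can be checked in a few lines:

```rust
use std::cell::UnsafeCell;
use std::mem::{align_of, size_of};

// `UnsafeCell<T>` is #[repr(transparent)]: slot size and alignment match
// the plain `Option<u64>`, so the Vec layouts are byte-identical.
const _: () = assert!(size_of::<UnsafeCell<Option<u64>>>() == size_of::<Option<u64>>());
const _: () = assert!(align_of::<UnsafeCell<Option<u64>>>() == align_of::<Option<u64>>());

fn main() {
    let slots: Vec<UnsafeCell<Option<u64>>> =
        (0..4).map(|i| UnsafeCell::new(Some(i))).collect();
    // `get()` goes through a shared `&` and yields a raw pointer to one
    // slot; no `&mut` to the enclosing Vec is ever created.
    let p: *mut Option<u64> = slots[2].get();
    // SAFETY: single-threaded here, and only this access touches slot 2.
    unsafe { *p = None };
    assert!(unsafe { (*slots[2].get()).is_none() });
}
```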
## Window sizing

`window_size = ceil(total / (num_threads * WINDOWS_PER_THREAD))`,
clamped to `[1, MAX_WINDOW_SIZE]`.

- `WINDOWS_PER_THREAD = 4` keeps several windows per thread so a slow
  worker on its last item does not stall the whole batch.
- `MAX_WINDOW_SIZE = 8` caps per-claim atomic latency and keeps the
  per-window memory footprint small enough to fit in L1d.

Examples: 100 items / 16 threads yields `window_size = 2` (50 windows);
10 000 items / 16 threads yields `window_size = 8` (1250 windows). A
worked sketch of this arithmetic appears below, after the profile
tables.

## Tests

7 new unit tests in `utils::batch::tests` cover window sizing, `TakeVec`
and `ResultVec` round-trip, and `test_parallel_distribution` (4 threads
concurrently claiming and writing 100 slots, exercising the `Sync`
bounds under real contention).

`cargo test --lib --features http`: 208 passed, 0 failed.

## Perf evidence

On Vera (88-core Olympus, 176 logical), `bpe_benchmark`/`bpe-encode/BPE
GPT2 encode batch` at 88T, `perf record -g --call-graph fp -F 4999`.

LSE atomic instructions (the direct motivation for the lock-free
counter):

instruction                 before   after
__aarch64_cas4_acq           3.57%   0.61%  (-5.9x)
__aarch64_ldadd8_acq_rel     1.05%   0.08%  (-13x)
__aarch64_swp4_rel           0.21%   0.05%
__aarch64_ldadd8_relax       0.17%   0.24%
__aarch64_swp4_acq           0.12%   0.00%
__aarch64_swp8_acq_rel       0.06%   0.00%
__aarch64_cas8_acq_rel       0.01%   0.01%
total LSE                    ~5.2%   ~1.0%  (-4.2x)

Rayon / crossbeam-epoch:

symbol                                                    before   after
rayon_core::sleep::Sleep::wake_specific_thread             0.57%   0.06%  (-10x)
crossbeam_epoch::internal::Global::try_advance            25.93%  28.38%
crossbeam_epoch::default::with_handle                     21.41%  23.12%
rayon_core::registry::WorkerThread::wait_until_cold        8.40%  10.72%
rayon::iter::plumbing::bridge_producer_consumer::helper    0.20%   0.24%

`bridge_producer_consumer::helper` was not a hotspot on this workload
before the change (0.20%) and does not move; the observable rayon-side
change is `Sleep::wake_specific_thread` dropping ~10x because
`rayon::scope` issues one wake per worker per batch call rather than
streaming wakes per parallel-iterator split. The three remaining
rayon/crossbeam ceiling symbols (`try_advance` + `with_handle` +
`wait_until_cold` = ~62% of cycles) stay similar in percentage because
total cycles decrease; absolute wall-clock per benchmark iteration
drops 35 ms (295 ms -> 260 ms at 88T). Removing that rayon ceiling is a
separate change.
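Returning to the window sizing above, a minimal standalone sketch (a
free function mirroring the `BatchWorkQueue::new` arithmetic in the diff
below; the function name is illustrative) reproduces both worked
examples:

```rust
/// window_size = ceil(total / (num_threads * WINDOWS_PER_THREAD)),
/// clamped to [1, MAX_WINDOW_SIZE].
fn window_size(total: usize, num_threads: usize) -> usize {
    const WINDOWS_PER_THREAD: usize = 4;
    const MAX_WINDOW_SIZE: usize = 8;
    let target_windows = num_threads.saturating_mul(WINDOWS_PER_THREAD).max(1);
    total.div_ceil(target_windows).clamp(1, MAX_WINDOW_SIZE)
}

fn main() {
    // 100 items / 16 threads: 64 target windows, ceil(100/64) = 2.
    assert_eq!(window_size(100, 16), 2); // 50 windows of <=2 items
    // 10 000 items / 16 threads: ceil(10000/64) = 157, capped at 8.
    assert_eq!(window_size(10_000, 16), 8); // 1250 windows of 8 items
}
```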
Throughput on Vera, `bpe-encode/BPE GPT2 encode batch` (data/big.txt,
encode_batch through the full post-processor):

threads  before       after        change
-------  ------       ------       ------
1T        3.98 MiB/s   4.46 MiB/s  +12%
88T      20.97 MiB/s  23.76 MiB/s  +13%
176T     18.83 MiB/s  21.58 MiB/s  +15%
---
 tokenizers/src/tokenizer/mod.rs |  98 +++++++++--
 tokenizers/src/utils/batch.rs   | 299 ++++++++++++++++++++++++++++++++
 tokenizers/src/utils/mod.rs     |   1 +
 3 files changed, 381 insertions(+), 17 deletions(-)
 create mode 100644 tokenizers/src/utils/batch.rs

diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 8e282fba28..288b04aefd 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -20,6 +20,7 @@ use std::{
 use serde::de::DeserializeOwned;
 use serde::{Deserialize, Serialize};
 
+use crate::utils::batch::{BatchWorkQueue, ResultVec, TakeVec};
 use crate::utils::iter::ResultShunt;
 use crate::utils::parallelism::*;
 use crate::utils::progress::{ProgressBar, ProgressStyle};
@@ -1300,7 +1301,11 @@ where
     PP: PostProcessor + Send + Sync,
     D: Decoder + Send + Sync,
 {
-    /// Encode all the sentences in parallel, using multiple threads
+    /// Encode all the sentences in parallel, using multiple threads.
+    ///
+    /// Uses a lock-free work queue with cache-line-sized windows instead of
+    /// rayon's `bridge_producer_consumer`, eliminating its synchronization
+    /// overhead at higher thread counts.
     pub fn encode_batch<'s, E>(
         &self,
         inputs: Vec<E>,
@@ -1309,13 +1314,10 @@ where
         E: Into<EncodeInput<'s>> + Send,
     {
-        let mut encodings = inputs
-            .into_maybe_par_iter()
-            .map(|input| self.encode(input, add_special_tokens))
-            .collect::<Result<Vec<Encoding>>>()?;
+        let mut encodings =
+            self.run_batch(inputs, |this, input| this.encode(input, add_special_tokens))?;
 
         if let Some(params) = &self.padding {
-            // We do the padding here to make sure we handle the batch padding
             pad_encodings(&mut encodings, params)?;
         }
@@ -1332,20 +1334,22 @@ where
         E: Into<EncodeInput<'s>> + Send,
     {
-        let mut encodings = inputs
-            .into_maybe_par_iter()
-            .map(|input| self.encode_char_offsets(input, add_special_tokens))
-            .collect::<Result<Vec<Encoding>>>()?;
+        let mut encodings = self.run_batch(inputs, |this, input| {
+            this.encode_char_offsets(input, add_special_tokens)
+        })?;
 
         if let Some(params) = &self.padding {
-            // We do the padding here to make sure we handle the batch padding
             pad_encodings(&mut encodings, params)?;
         }
 
         Ok(encodings)
     }
 
-    /// Encode all the sentences in parallel, using multiple threads
+    /// Encode all the sentences in parallel, using multiple threads.
+    ///
+    /// Uses a lock-free work queue with cache-line-sized windows instead of
+    /// rayon's `bridge_producer_consumer`, eliminating its synchronization
+    /// overhead at higher thread counts.
     pub fn encode_batch_fast<'s, E>(
         &self,
         inputs: Vec<E>,
@@ -1354,19 +1358,79 @@ where
         E: Into<EncodeInput<'s>> + Send,
     {
-        let mut encodings = inputs
-            .into_maybe_par_iter()
-            .map(|input| self.encode_fast(input, add_special_tokens))
-            .collect::<Result<Vec<Encoding>>>()?;
+        let mut encodings = self.run_batch(inputs, |this, input| {
+            this.encode_fast(input, add_special_tokens)
+        })?;
 
         if let Some(params) = &self.padding {
-            // We do the padding here to make sure we handle the batch padding
             pad_encodings(&mut encodings, params)?;
         }
 
         Ok(encodings)
     }
 
+    /// Shared implementation for all batch encode variants.
+    ///
+    /// Distributes work items across threads using a lock-free atomic counter.
+    /// Each thread claims a dynamically-sized window of items, processes them,
+    /// and writes results directly to pre-allocated slots.
+    ///
+    /// Uses `rayon::scope` to run on the existing rayon thread pool, avoiding
+    /// the cost of creating/destroying OS threads on every call.
+    fn run_batch<'s, E, F>(&self, inputs: Vec<E>, encode_fn: F) -> Result<Vec<Encoding>>
+    where
+        E: Into<EncodeInput<'s>> + Send,
+        F: Fn(&Self, EncodeInput<'s>) -> Result<Encoding> + Sync,
+    {
+        let n = inputs.len();
+        if n == 0 {
+            return Ok(vec![]);
+        }
+
+        let parallelism = get_parallelism();
+        let num_threads = if parallelism {
+            current_num_threads().min(n)
+        } else {
+            1
+        };
+
+        if num_threads <= 1 {
+            return inputs
+                .into_iter()
+                .map(|input| encode_fn(self, input.into()))
+                .collect::<Result<Vec<Encoding>>>();
+        }
+
+        // Lock-free batch distribution: atomic counter hands out
+        // dynamically-sized windows of item indices to worker threads.
+        let inputs = TakeVec::new(
+            inputs
+                .into_iter()
+                .map(|e| e.into())
+                .collect::<Vec<EncodeInput<'s>>>(),
+        );
+        let results: ResultVec<Result<Encoding>> = ResultVec::new(n);
+        let queue = BatchWorkQueue::new(n, num_threads);
+
+        rayon::scope(|s| {
+            for _ in 0..num_threads {
+                s.spawn(|_| {
+                    while let Some((start, end)) = queue.claim_window() {
+                        for i in start..end {
+                            let input = inputs.take(i);
+                            results.set(i, encode_fn(self, input));
+                        }
+                    }
+                });
+            }
+        });
+
+        results
+            .into_vec()
+            .into_iter()
+            .collect::<Result<Vec<Encoding>>>()
+    }
+
     /// Decode all sentences in parallel
     pub fn decode_batch(
         &self,
diff --git a/tokenizers/src/utils/batch.rs b/tokenizers/src/utils/batch.rs
new file mode 100644
index 0000000000..e0a3c0a6b7
--- /dev/null
+++ b/tokenizers/src/utils/batch.rs
@@ -0,0 +1,299 @@
+// SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Lock-free batch work distribution with dynamic window sizing.
+//!
+//! Replaces rayon's parallel iteration for batch encode with a simpler
+//! mechanism: a single atomic counter hands out contiguous windows of
+//! item indices to worker threads running on rayon's persistent thread
+//! pool. The only cross-thread synchronization on the hot path is the
+//! `AtomicUsize::fetch_add` that claims each window.
+//!
+//! ## Cache-line / loop-tiling rationale
+//!
+//! Shared-memory parallel loops are bottlenecked by the cache coherence
+//! protocol when two cores alternate writes to the same cache line: the
+//! line "ping-pongs" between their private L1d caches, each transfer
+//! costing dozens of cycles. To avoid that, every line should be filled
+//! by one producer core, drained (or no longer needed), and only then
+//! touched by a different core. This is the cache-aware equivalent of
+//! loop tiling / blocking.
+//!
+//! The work queue enforces this in three ways:
+//!
+//! 1. The counter itself lives on its own 64-byte cache line
+//!    (`#[repr(C, align(64))]` on `AlignedCounter`). A worker's
+//!    `fetch_add` does not evict any neighbouring data, and reads of the
+//!    counter do not pull input or result payloads into the core's L1d.
+//!
+//! 2. Each window is a contiguous run of `window_size` indices, so every
+//!    worker owns a run of adjacent slots for the duration of one
+//!    window. With `MAX_WINDOW_SIZE = 8`, a window covers roughly
+//!    `8 * sizeof(slot)` bytes -- for `Option<EncodeInput>` (~48 B) that
+//!    is ~6 cache lines; for `Option<Result<Encoding>>` (multiple cache
+//!    lines per slot) it is even more. So within one window, a worker
+//!    writes several whole cache lines before any other worker comes
+//!    near them.
+//!
+//! 3. Each slot has its own `UnsafeCell`
+//!    (`Vec<UnsafeCell<Option<T>>>`). `UnsafeCell<T>` is
+//!    `#[repr(transparent)]` so the heap layout is identical to a plain
+//!    `Vec<Option<T>>` (no padding, no indirection), but concurrent
+//!    accesses to different indices never materialise a shared `&mut`
+//!    reference to the enclosing `Vec` (which would be UB, regardless of
+//!    which element each access ultimately reached).
+//!
+//! At window boundaries a single cache line can be shared between two
+//! successive windows when the slot size does not divide 64 bytes. That
+//! is a *sequential* handoff (window N finishes writes; window N+1 then
+//! reads/writes), not a concurrent ping-pong.
+//!
+//! ## Window sizing
+//!
+//! `window_size = ceil(total / (num_threads * WINDOWS_PER_THREAD))`,
+//! clamped to `[1, MAX_WINDOW_SIZE]`.
+//!
+//! - `WINDOWS_PER_THREAD = 4` keeps several windows per thread so a
+//!   slow worker on its last item does not stall the whole batch.
+//! - `MAX_WINDOW_SIZE = 8` caps per-claim atomic latency and keeps the
+//!   per-window memory footprint small enough to fit comfortably in L1d.
+//!
+//! Example: 100 items / 16 threads yields window_size = 2 (50 windows);
+//! 10000 items / 16 threads yields window_size = 8 (1250 windows).
+
+use std::cell::UnsafeCell;
+use std::sync::atomic::{AtomicUsize, Ordering};
+
+/// Minimum number of windows each thread should get for load balancing.
+const WINDOWS_PER_THREAD: usize = 4;
+
+/// Maximum window size (items per atomic claim). Larger values reduce
+/// atomic contention but worsen tail-latency from uneven last windows.
+const MAX_WINDOW_SIZE: usize = 8;
+
+/// Cache-line-aligned atomic counter.
+/// Ensures the counter does not share a cache line with any other data.
+#[repr(C, align(64))]
+struct AlignedCounter(AtomicUsize);
+
+/// Lock-free work distributor.
+///
+/// Workers atomically claim non-overlapping windows of item indices.
+/// The window size is chosen dynamically based on `total` and
+/// `num_threads` so that every thread gets several windows of work.
+/// The counter is on its own cache line so claiming work does not
+/// contend with result writes.
+pub(crate) struct BatchWorkQueue {
+    next: AlignedCounter,
+    total: usize,
+    window_size: usize,
+}
+
+impl BatchWorkQueue {
+    /// Create a new queue distributing `total` items across `num_threads`.
+    ///
+    /// The window size is chosen to give each thread at least
+    /// `WINDOWS_PER_THREAD` windows, capped at `MAX_WINDOW_SIZE`.
+    pub(crate) fn new(total: usize, num_threads: usize) -> Self {
+        let target_windows = num_threads.saturating_mul(WINDOWS_PER_THREAD).max(1);
+        let window_size = total.div_ceil(target_windows).clamp(1, MAX_WINDOW_SIZE);
+        Self {
+            next: AlignedCounter(AtomicUsize::new(0)),
+            total,
+            window_size,
+        }
+    }
+
+    /// Claim the next window of work items.
+    /// Returns a `Some((start, end))` half-open range, or `None` when all
+    /// items have been claimed.
+    pub(crate) fn claim_window(&self) -> Option<(usize, usize)> {
+        let start = self.next.0.fetch_add(self.window_size, Ordering::Relaxed);
+        if start >= self.total {
+            return None;
+        }
+        Some((start, (start + self.window_size).min(self.total)))
+    }
+}
+
+/// A `Vec` whose elements can each be *taken* exactly once from any thread.
+///
+/// The `BatchWorkQueue` guarantees that no two threads access the same
+/// index, so no synchronization is needed beyond the queue itself.
+///
+/// Layout: each slot has its own `UnsafeCell<Option<T>>`. Because
+/// `UnsafeCell<U>` is `#[repr(transparent)]` over `U`, this heap layout
+/// is byte-identical to a plain `Vec<Option<T>>`: no added padding,
+/// identical slot alignment, identical contiguous packing.
+/// The only difference is that `self.0[i].get()` gives a raw
+/// `*mut Option<T>` pointing straight at slot `i`, without ever
+/// materialising a `&mut Vec<Option<T>>` (which would alias the
+/// enclosing container and be UB when two threads touch any distinct
+/// indices concurrently).
+pub(crate) struct TakeVec<T>(Vec<UnsafeCell<Option<T>>>);

+// SAFETY: callers guarantee each index is accessed by at most one thread;
+// `take` produces a raw pointer to a single slot's `UnsafeCell` without
+// aliasing the surrounding `Vec`.
+unsafe impl<T: Send> Sync for TakeVec<T> {}
+
+impl<T> TakeVec<T> {
+    /// Wrap a `Vec` so items can be taken by index.
+    pub(crate) fn new(items: Vec<T>) -> Self {
+        Self(
+            items
+                .into_iter()
+                .map(|t| UnsafeCell::new(Some(t)))
+                .collect(),
+        )
+    }
+
+    /// Take the item at `index`, leaving `None` in its place.
+    /// Panics if the item was already taken.
+    pub(crate) fn take(&self, index: usize) -> T {
+        // SAFETY: the `BatchWorkQueue` guarantees that each `index` is passed
+        // to `take` by at most one thread. `self.0[index].get()` returns a
+        // raw pointer to that slot's `Option<T>`; reborrowing it as `&mut`
+        // does not alias any sibling slot's data.
+        unsafe {
+            (*self.0[index].get())
+                .take()
+                .expect("batch item already taken")
+        }
+    }
+}
+
+/// A `Vec<Option<T>>` where each slot is written exactly once from any
+/// thread.
+///
+/// The `BatchWorkQueue` guarantees non-overlapping index access.
+///
+/// Layout: same note as `TakeVec`. Each slot is an
+/// `UnsafeCell<Option<T>>` (`#[repr(transparent)]` over `Option<T>`), so
+/// the heap layout is byte-identical to a plain `Vec<Option<T>>`
+/// and `self.0[i].get()` yields a raw `*mut Option<T>` to slot `i`
+/// without materialising a `&mut Vec<Option<T>>`.
+pub(crate) struct ResultVec<T>(Vec<UnsafeCell<Option<T>>>);
+
+// SAFETY: callers guarantee each index is written by at most one thread;
+// `set` produces a raw pointer to a single slot's `UnsafeCell` without
+// aliasing the surrounding `Vec`.
+unsafe impl<T: Send> Sync for ResultVec<T> {}
+
+impl<T> ResultVec<T> {
+    /// Allocate `len` empty result slots.
+    pub(crate) fn new(len: usize) -> Self {
+        Self((0..len).map(|_| UnsafeCell::new(None)).collect())
+    }
+
+    /// Write a result to the slot at `index`.
+    pub(crate) fn set(&self, index: usize, value: T) {
+        // SAFETY: the `BatchWorkQueue` guarantees that each `index` is passed
+        // to `set` by at most one thread, so no other reference to this
+        // slot's `Option<T>` exists concurrently.
+        unsafe {
+            *self.0[index].get() = Some(value);
+        }
+    }
+
+    /// Consume self and return the results in order.
+    /// Panics if any slot was not written.
+    pub(crate) fn into_vec(self) -> Vec<T> {
+        self.0
+            .into_iter()
+            .map(|cell| cell.into_inner().expect("result slot was never written"))
+            .collect()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_batch_work_queue_single_thread() {
+        // 20 items, 1 thread => target 4 windows => window_size = 5.
+        let queue = BatchWorkQueue::new(20, 1);
+        let mut ranges = Vec::new();
+        while let Some(range) = queue.claim_window() {
+            ranges.push(range);
+        }
+        assert_eq!(ranges.len(), 4);
+        assert_eq!(ranges[0], (0, 5));
+        assert_eq!(ranges[1], (5, 10));
+        assert_eq!(ranges[2], (10, 15));
+        assert_eq!(ranges[3], (15, 20));
+    }
+
+    #[test]
+    fn test_batch_work_queue_many_threads() {
+        // 100 items, 16 threads => target 64 windows => window_size = 2.
+        let queue = BatchWorkQueue::new(100, 16);
+        let mut ranges = Vec::new();
+        while let Some(range) = queue.claim_window() {
+            ranges.push(range);
+        }
+        assert_eq!(ranges.len(), 50);
+        assert_eq!(ranges[0], (0, 2));
+        assert_eq!(ranges[49], (98, 100));
+    }
+
+    #[test]
+    fn test_batch_work_queue_window_capped() {
+        // 10000 items, 4 threads => target 16 windows => window_size = 625,
+        // but capped at MAX_WINDOW_SIZE (8).
+        let queue = BatchWorkQueue::new(10000, 4);
+        let mut count = 0;
+        while queue.claim_window().is_some() {
+            count += 1;
+        }
+        // 10000 / 8 = 1250 windows.
+        assert_eq!(count, 1250);
+    }
+
+    #[test]
+    fn test_batch_work_queue_empty() {
+        let queue = BatchWorkQueue::new(0, 4);
+        assert!(queue.claim_window().is_none());
+    }
+
+    #[test]
+    fn test_take_vec() {
+        let tv = TakeVec::new(vec![10, 20, 30]);
+        assert_eq!(tv.take(1), 20);
+        assert_eq!(tv.take(0), 10);
+        assert_eq!(tv.take(2), 30);
+    }
+
+    #[test]
+    fn test_result_vec() {
+        let rv = ResultVec::<usize>::new(3);
+        rv.set(2, 30);
+        rv.set(0, 10);
+        rv.set(1, 20);
+        assert_eq!(rv.into_vec(), vec![10, 20, 30]);
+    }
+
+    #[test]
+    fn test_parallel_distribution() {
+        let n = 100;
+        let num_threads = 4;
+        let queue = BatchWorkQueue::new(n, num_threads);
+        let results = ResultVec::new(n);
+
+        std::thread::scope(|s| {
+            for _ in 0..num_threads {
+                s.spawn(|| {
+                    while let Some((start, end)) = queue.claim_window() {
+                        for i in start..end {
+                            results.set(i, i * 2);
+                        }
+                    }
+                });
+            }
+        });
+
+        let v = results.into_vec();
+        for (i, &item) in v.iter().enumerate() {
+            assert_eq!(item, i * 2);
+        }
+    }
+}
diff --git a/tokenizers/src/utils/mod.rs b/tokenizers/src/utils/mod.rs
index c9450b3222..252a466e64 100644
--- a/tokenizers/src/utils/mod.rs
+++ b/tokenizers/src/utils/mod.rs
@@ -1,3 +1,4 @@
+pub(crate) mod batch;
 pub(crate) mod cache;
 #[cfg(feature = "http")]
 pub(crate) mod from_pretrained;