Performance: encode_batch scales poorly on high-core server CPUs compared to sharded tokenizer instances #1900

Description

@stargazerZJ

TLDR: In a 120-core Kubernetes container (cpu request = cpu limit) on a 128-core machine, manually sharding the tokenizer (32 threads sharing each tokenizer instance) drastically outperforms the default encode_batch implementation on uniform-length synthetic data.

I have also tested this on real text datasets with non-uniform lengths, where my manual worker-group implementation is 7-9x faster. This suggests lock contention, poor memory locality, or some other bottleneck in the current encode_batch implementation when scaling to high core counts.
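For reference, a minimal sketch of the sharding pattern (the function name, chunking, and error handling here are illustrative only; the full reproduction code below is what produced the numbers):

use std::sync::Arc;
use tokenizers::Tokenizer;

// Sketch: split `workers` threads into groups of `threads_per_tokenizer`,
// give each group its own Tokenizer rebuilt from the serialized JSON, and
// let the threads in a group call encode() concurrently on their shared instance.
fn shard_encode(tokenizer: &Tokenizer, texts: &[String], workers: usize, threads_per_tokenizer: usize) {
    let json = tokenizer.to_string(false).unwrap();
    let groups = workers.div_ceil(threads_per_tokenizer);
    let toks: Vec<Arc<Tokenizer>> = (0..groups)
        .map(|_| Arc::new(Tokenizer::from_bytes(json.as_bytes()).unwrap()))
        .collect();

    let chunk_len = texts.len().div_ceil(workers).max(1);
    std::thread::scope(|s| {
        for (w, chunk) in texts.chunks(chunk_len).enumerate() {
            // Map every `threads_per_tokenizer` consecutive workers onto one instance.
            let tok = Arc::clone(&toks[(w / threads_per_tokenizer).min(groups - 1)]);
            s.spawn(move || {
                for t in chunk {
                    let _ = tok.encode(t.as_str(), false).unwrap();
                }
            });
        }
    });
}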

Benchmark Output:

❯ ./bench_data/target/release/tokenizer_bench -i bench_data -n 4 --compare-batch -b 1024 -w 120 --threads-per-tokenizer 32 --synthetic
[INFO] Tokenization Benchmark
[INFO] Workers: 120
[INFO] Batch size: 1024
[INFO] Loading tokenizer: Qwen/Qwen2.5-Coder-32B-Instruct
[INFO] Tokenizer loaded successfully
[INFO] Generating 10000 synthetic strings of 204800 bytes each
[00:00:08] ======================================== 10000/10000 strings
[INFO] Generated 10000 strings, total size: 1953.12 MB

=== Worker Pool Tokenization (120 workers, 32 threads per tokenizer) ===
[INFO] Creating 4 tokenizer groups
[00:00:12] ======================================== 10000/10000 Docs (798.3005/s)
[THROUGHPUT] Docs: 10000 (797.9/s) | Tokens: 1346513955 (107441384.6/s) | Data: 1.91 GB (155.84 MB/s) | Elapsed: 12.5s

=== Batch Tokenization (batch_size=1024) ===
[00:00:40] ======================================== 10000/10000 Docs (244.8588/s)
[THROUGHPUT] Docs: 10000 (244.9/s) | Tokens: 1346513955 (32970451.7/s) | Data: 1.91 GB (47.82 MB/s) | Elapsed: 40.8s

=== Performance Comparison ===
Worker Pool: 12.5s | Batch: 40.8s | Speedup: 0.31x

[INFO] Benchmark complete
❯ ./bench_data/target/release/tokenizer_bench -i bench_data -n 4 --compare-batch -b 10000 -w 120 --threads-per-tokenizer 32 --synthetic
[INFO] Tokenization Benchmark
[INFO] Workers: 120
[INFO] Batch size: 10000
[INFO] Loading tokenizer: Qwen/Qwen2.5-Coder-32B-Instruct
[INFO] Tokenizer loaded successfully
[INFO] Generating 10000 synthetic strings of 204800 bytes each
[00:00:08] ======================================== 10000/10000 strings
[INFO] Generated 10000 strings, total size: 1953.12 MB

=== Worker Pool Tokenization (120 workers, 32 threads per tokenizer) ===
[INFO] Creating 4 tokenizer groups
[00:00:12] ======================================== 10000/10000 Docs (792.242/s)
[THROUGHPUT] Docs: 10000 (792.0/s) | Tokens: 1346505782 (106643927.4/s) | Data: 1.91 GB (154.69 MB/s) | Elapsed: 12.6s

=== Batch Tokenization (batch_size=10000) ===
[00:00:51] ======================================== 10000/10000 Docs (192.5807/s)
[THROUGHPUT] Docs: 10000 (192.5/s) | Tokens: 1346505782 (25926359.3/s) | Data: 1.91 GB (37.61 MB/s) | Elapsed: 51.9s

=== Performance Comparison ===
Worker Pool: 12.6s | Batch: 51.9s | Speedup: 0.24x

[INFO] Benchmark complete

Reproduction Code:

use anyhow::{Context, Result};
use arrow::array::{Array, Int64Array, StringArray};
use arrow::record_batch::RecordBatch;
use clap::Parser;
use indicatif::{ProgressBar, ProgressStyle};
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use rayon::prelude::*;
use std::fs::File;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tokenizers::Tokenizer;

#[derive(Parser, Debug)]
#[command(author, version, about = "Tokenization benchmark for text data", long_about = None)]
struct Args {
    /// Path to the output directory containing parquet files
    #[arg(short, long)]
    input_dir: Option<PathBuf>,

    /// Number of parquet files to read (0 = all)
    #[arg(short = 'n', long, default_value = "0")]
    num_files: usize,

    /// Number of worker threads (0 = num_cpus)
    #[arg(short = 'w', long, default_value = "0")]
    workers: usize,

    /// Tokenizer model path or HuggingFace model ID
    #[arg(short = 't', long, default_value = "Qwen/Qwen2.5-Coder-32B-Instruct")]
    tokenizer_model: String,

    /// Batch size for batch tokenization test
    #[arg(short = 'b', long, default_value = "32")]
    batch_size: usize,

    /// Enable batch tokenization comparison
    #[arg(long)]
    compare_batch: bool,

    /// Number of threads sharing one tokenizer instance (0 = all share one, 1 = one per thread)
    #[arg(long, default_value = "1")]
    threads_per_tokenizer: usize,

    /// Use synthetic data instead of parquet files
    #[arg(long)]
    synthetic: bool,

    /// Number of synthetic strings to generate
    #[arg(long, default_value = "10000")]
    synthetic_count: usize,

    /// Length of each synthetic string in bytes
    #[arg(long, default_value = "204800")]
    synthetic_length: usize,
}

#[allow(dead_code)]
struct Document {
    group_id: i64,
    source_name: String,
    doc_id: i64,
    text: String,
}

struct BenchmarkStats {
    docs_processed: AtomicUsize,
    tokens_generated: AtomicU64,
    bytes_processed: AtomicU64,
    start_time: Instant,
}

impl BenchmarkStats {
    fn new() -> Self {
        Self {
            docs_processed: AtomicUsize::new(0),
            tokens_generated: AtomicU64::new(0),
            bytes_processed: AtomicU64::new(0),
            start_time: Instant::now(),
        }
    }

    fn add_doc(&self, token_count: usize, byte_count: usize) {
        self.docs_processed.fetch_add(1, Ordering::Relaxed);
        self.tokens_generated
            .fetch_add(token_count as u64, Ordering::Relaxed);
        self.bytes_processed
            .fetch_add(byte_count as u64, Ordering::Relaxed);
    }

    fn report(&self) {
        let elapsed = self.start_time.elapsed().as_secs_f64();
        let docs = self.docs_processed.load(Ordering::Relaxed);
        let tokens = self.tokens_generated.load(Ordering::Relaxed);
        let bytes = self.bytes_processed.load(Ordering::Relaxed);

        let doc_rate = docs as f64 / elapsed;
        let token_rate = tokens as f64 / elapsed;
        let mb_rate = (bytes as f64 / (1024.0 * 1024.0)) / elapsed;
        let gb_total = bytes as f64 / (1024.0 * 1024.0 * 1024.0);

        println!(
            "[THROUGHPUT] Docs: {} ({:.1}/s) | Tokens: {} ({:.1}/s) | Data: {:.2} GB ({:.2} MB/s) | Elapsed: {:.1}s",
            docs, doc_rate, tokens, token_rate, gb_total, mb_rate, elapsed
        );
    }
}

fn read_parquet_file(path: &PathBuf) -> Result<Vec<Document>> {
    let file = File::open(path).context("Failed to open parquet file")?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file)
        .context("Failed to create parquet reader")?;
    let reader = builder.build().context("Failed to build reader")?;

    let mut results = Vec::new();

    for batch in reader {
        let batch = batch.context("Failed to read batch")?;
        results.extend(parse_batch(&batch)?);
    }

    Ok(results)
}

fn parse_batch(batch: &RecordBatch) -> Result<Vec<Document>> {
    use arrow::array::BinaryArray;
    
    // Adjusted schema expectation for generic docs
    let group_id = batch
        .column_by_name("group_id") // changed from repo_id
        .context("Missing group_id column")?
        .as_any()
        .downcast_ref::<Int64Array>()
        .context("group_id is not Int64Array")?;

    let source_name = batch
        .column_by_name("source_name") // changed from repo_name
        .context("Missing source_name column")?
        .as_any()
        .downcast_ref::<StringArray>()
        .context("source_name is not StringArray")?;

    let doc_id = batch
        .column_by_name("doc_id") // changed from pr_id
        .context("Missing doc_id column")?
        .as_any()
        .downcast_ref::<Int64Array>()
        .context("doc_id is not Int64Array")?;

    let text_column = batch
        .column_by_name("text")
        .context("Missing text column")?;

    let mut results = Vec::with_capacity(batch.num_rows());
    
    // Try StringArray first, fall back to BinaryArray if UTF-8 validation fails
    if let Some(text) = text_column.as_any().downcast_ref::<StringArray>() {
        for i in 0..batch.num_rows() {
            if !text.is_null(i) {
                results.push(Document {
                    group_id: group_id.value(i),
                    source_name: source_name.value(i).to_string(),
                    doc_id: doc_id.value(i),
                    text: text.value(i).to_string(),
                });
            }
        }
    } else if let Some(text) = text_column.as_any().downcast_ref::<BinaryArray>() {
        // Handle binary data with lossy UTF-8 conversion
        for i in 0..batch.num_rows() {
            if !text.is_null(i) {
                let bytes = text.value(i);
                let text_str = String::from_utf8_lossy(bytes).to_string();
                results.push(Document {
                    group_id: group_id.value(i),
                    source_name: source_name.value(i).to_string(),
                    doc_id: doc_id.value(i),
                    text: text_str,
                });
            }
        }
    } else {
        anyhow::bail!("text column is neither StringArray nor BinaryArray");
    }

    Ok(results)
}

fn load_data(input_dir: &PathBuf, num_files: usize) -> Result<Vec<Document>> {
    let mut parquet_files: Vec<PathBuf> = std::fs::read_dir(input_dir)
        .context("Failed to read input directory")?
        .filter_map(|entry| {
            let entry = entry.ok()?;
            let path = entry.path();
            if path.extension()? == "parquet" {
                Some(path)
            } else {
                None
            }
        })
        .collect();

    parquet_files.sort();

    if num_files > 0 && num_files < parquet_files.len() {
        parquet_files.truncate(num_files);
    }

    println!(
        "[INFO] Loading {} parquet files from {}",
        parquet_files.len(),
        input_dir.display()
    );

    let pb = ProgressBar::new(parquet_files.len() as u64);
    pb.set_style(
        ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} files")
            .unwrap()
            .progress_chars("=>-"),
    );

    let mut all_data = Vec::new();
    for file in parquet_files {
        let data = read_parquet_file(&file)?;
        all_data.extend(data);
        pb.inc(1);
    }
    pb.finish_with_message("Loading complete");

    println!("[INFO] Loaded {} documents", all_data.len());
    Ok(all_data)
}

fn generate_synthetic_data(count: usize, length: usize) -> Vec<Document> {
    use rand::Rng;
    
    println!(
        "[INFO] Generating {} synthetic strings of {} bytes each",
        count, length
    );

    let pb = ProgressBar::new(count as u64);
    pb.set_style(
        ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} strings")
            .unwrap()
            .progress_chars("=>-"),
    );

    let data: Vec<Document> = (0..count)
        .map(|i| {
            let mut rng = rand::thread_rng();
            let text: String = (0..length)
                .map(|_| {
                    let idx = rng.gen_range(0..52);
                    if idx < 26 {
                        (b'a' + idx) as char
                    } else {
                        (b'A' + idx - 26) as char
                    }
                })
                .collect();
            
            pb.inc(1);
            
            Document {
                group_id: i as i64,
                source_name: format!("synthetic/group-{}", i),
                doc_id: i as i64,
                text,
            }
        })
        .collect();

    pb.finish_with_message("Generation complete");
    
    let total_bytes: usize = data.iter().map(|doc| doc.text.len()).sum();
    let total_mb = total_bytes as f64 / (1024.0 * 1024.0);
    println!(
        "[INFO] Generated {} strings, total size: {:.2} MB",
        data.len(),
        total_mb
    );
    
    data
}

#[allow(dead_code)]
fn benchmark_single_threaded(
    tokenizer: &Tokenizer,
    data: &[Document],
    stats: &BenchmarkStats,
) -> Result<()> {
    println!("\n=== Single-threaded Tokenization ===");

    let pb = ProgressBar::new(data.len() as u64);
    pb.set_style(
        ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} Docs ({per_sec})")
            .unwrap()
            .progress_chars("=>-"),
    );

    for doc in data {
        let encoding = tokenizer
            .encode(doc.text.as_str(), false)
            .map_err(|e| anyhow::anyhow!("Failed to encode text: {}", e))?;
        let token_count = encoding.get_ids().len();
        let byte_count = doc.text.len();

        stats.add_doc(token_count, byte_count);
        pb.inc(1);
    }

    pb.finish_with_message("Complete");
    stats.report();

    Ok(())
}

fn benchmark_worker_pool(
    tokenizer: &Tokenizer,
    data: &[Document],
    stats: &BenchmarkStats,
    num_workers: usize,
    threads_per_tokenizer: usize,
) -> Result<()> {
    let sharing_mode = if threads_per_tokenizer == 0 {
        "all share one".to_string()
    } else if threads_per_tokenizer == 1 {
        "one per thread".to_string()
    } else {
        format!("{} threads per tokenizer", threads_per_tokenizer)
    };
    
    println!(
        "\n=== Worker Pool Tokenization ({} workers, {}) ===",
        num_workers, sharing_mode
    );

    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(num_workers)
        .build()
        .context("Failed to build thread pool")?;

    let pb = ProgressBar::new(data.len() as u64);
    pb.set_style(
        ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} Docs ({per_sec})")
            .unwrap()
            .progress_chars("=>-"),
    );

    if threads_per_tokenizer == 0 {
        // Mode 1: All threads share one tokenizer
        pool.install(|| {
            data.par_iter().for_each(|doc| {
                if let Ok(encoding) = tokenizer.encode(doc.text.as_str(), false) {
                    let token_count = encoding.get_ids().len();
                    let byte_count = doc.text.len();
                    stats.add_doc(token_count, byte_count);
                    pb.inc(1);
                }
            });
        });
    } else if threads_per_tokenizer == 1 {
        // Mode 2: One tokenizer per thread
        let tokenizer_json = tokenizer.to_string(false)
            .map_err(|e| anyhow::anyhow!("Failed to serialize tokenizer: {}", e))?;
        let tokenizer_json = Arc::new(tokenizer_json);

        pool.install(|| {
            data.par_iter().for_each(|doc| {
                thread_local! {
                    static TOKENIZER: std::cell::RefCell<Option<Tokenizer>> = std::cell::RefCell::new(None);
                }
                
                TOKENIZER.with(|tok_cell| {
                    let mut tok_opt = tok_cell.borrow_mut();
                    if tok_opt.is_none() {
                        match Tokenizer::from_bytes(tokenizer_json.as_bytes()) {
                            Ok(tok) => *tok_opt = Some(tok),
                            Err(e) => {
                                eprintln!("Failed to create tokenizer for thread: {}", e);
                                return;
                            }
                        }
                    }
                    
                    if let Some(tok) = tok_opt.as_ref() {
                        if let Ok(encoding) = tok.encode(doc.text.as_str(), false) {
                            let token_count = encoding.get_ids().len();
                            let byte_count = doc.text.len();
                            stats.add_doc(token_count, byte_count);
                            pb.inc(1);
                        }
                    }
                });
            });
        });
    } else {
        // Mode 3: N threads share one tokenizer (best performance)
        let tokenizer_json = tokenizer.to_string(false)
            .map_err(|e| anyhow::anyhow!("Failed to serialize tokenizer: {}", e))?;
        let tokenizer_json = Arc::new(tokenizer_json);
        
        // Calculate number of tokenizer groups
        let num_groups = num_workers.div_ceil(threads_per_tokenizer);
        println!("[INFO] Creating {} tokenizer groups", num_groups);
        
        // Pre-create tokenizers for each group
        let tokenizers: Vec<Arc<Tokenizer>> = (0..num_groups)
            .map(|_| {
                Arc::new(
                    Tokenizer::from_bytes(tokenizer_json.as_bytes())
                        .expect("Failed to create tokenizer")
                )
            })
            .collect();
        
        let tokenizers = Arc::new(tokenizers);
        let counter = Arc::new(AtomicUsize::new(0));

        pool.install(|| {
            data.par_iter().for_each(|doc| {
                thread_local! {
                    static THREAD_ID: std::cell::RefCell<Option<usize>> = std::cell::RefCell::new(None);
                }
                
                THREAD_ID.with(|id_cell| {
                    let mut id_opt = id_cell.borrow_mut();
                    if id_opt.is_none() {
                        // Assign this thread to a group
                        let thread_id = counter.fetch_add(1, Ordering::Relaxed);
                        *id_opt = Some(thread_id);
                    }
                    
                    let thread_id = id_opt.unwrap();
                    let group_id = thread_id / threads_per_tokenizer;
                    let tok = &tokenizers[group_id.min(tokenizers.len() - 1)];
                    
                    if let Ok(encoding) = tok.encode(doc.text.as_str(), false) {
                        let token_count = encoding.get_ids().len();
                        let byte_count = doc.text.len();
                        stats.add_doc(token_count, byte_count);
                        pb.inc(1);
                    }
                });
            });
        });
    }

    pb.finish_with_message("Complete");
    stats.report();

    Ok(())
}

fn benchmark_batch_tokenization(
    tokenizer: &Tokenizer,
    data: &[Document],
    stats: &BenchmarkStats,
    batch_size: usize,
) -> Result<()> {
    println!(
        "\n=== Batch Tokenization (batch_size={}) ===",
        batch_size
    );

    let pb = ProgressBar::new(data.len() as u64);
    pb.set_style(
        ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} Docs ({per_sec})")
            .unwrap()
            .progress_chars("=>-"),
    );

    for chunk in data.chunks(batch_size) {
        let texts: Vec<&str> = chunk.iter().map(|doc| doc.text.as_str()).collect();

        let encodings = tokenizer
            .encode_batch(texts, false)
            .map_err(|e| anyhow::anyhow!("Failed to encode batch: {}", e))?;

        for (i, encoding) in encodings.iter().enumerate() {
            let token_count = encoding.get_ids().len();
            let byte_count = chunk[i].text.len();
            stats.add_doc(token_count, byte_count);
            pb.inc(1);
        }
    }

    pb.finish_with_message("Complete");
    stats.report();

    Ok(())
}

fn main() -> Result<()> {
    // Disable tokenizer's internal parallelism to avoid nested parallelism
    // The tokenizer library uses Rayon internally, which conflicts with our worker pool
    // This line is commented out because it doesn't have an obvious effect in my experiments
    // std::env::set_var("TOKENIZERS_PARALLELISM", "false");
    
    let args = Args::parse();

    // Set number of workers
    let num_workers = if args.workers == 0 {
        num_cpus::get()
    } else {
        args.workers
    };

    println!("[INFO] Tokenization Benchmark");
    println!("[INFO] Workers: {}", num_workers);
    println!("[INFO] Batch size: {}", args.batch_size);

    // Load tokenizer
    println!("[INFO] Loading tokenizer: {}", args.tokenizer_model);
    let tokenizer = if std::path::Path::new(&args.tokenizer_model).exists() {
        Tokenizer::from_file(&args.tokenizer_model)
            .map_err(|e| anyhow::anyhow!("Failed to load tokenizer from file: {}", e))?
    } else {
        // Try to load from HuggingFace cache
        let cache_dir = std::env::var("HF_HOME")
            .or_else(|_| std::env::var("HOME").map(|h| format!("{}/.cache/huggingface", h)))
            .unwrap_or_else(|_| ".cache/huggingface".to_string());
        
        let model_path = args.tokenizer_model.replace("/", "--");
        let tokenizer_path = format!("{}/hub/models--{}/snapshots", cache_dir, model_path);
        
        // Find the tokenizer.json in snapshots
        let mut found_tokenizer = None;
        if let Ok(entries) = std::fs::read_dir(&tokenizer_path) {
            for entry in entries.flatten() {
                let snapshot_path = entry.path().join("tokenizer.json");
                if snapshot_path.exists() {
                    found_tokenizer = Some(snapshot_path);
                    break;
                }
            }
        }
        
        if let Some(path) = found_tokenizer {
            Tokenizer::from_file(path)
                .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?
        } else {
            anyhow::bail!("Tokenizer not found. Please download it first using Python or provide a path to tokenizer.json");
        }
    };
    println!("[INFO] Tokenizer loaded successfully");

    // Load data
    let data = if args.synthetic {
        generate_synthetic_data(args.synthetic_count, args.synthetic_length)
    } else {
        let input_dir = args.input_dir.ok_or_else(|| {
            anyhow::anyhow!("--input-dir is required when not using --synthetic mode")
        })?;
        load_data(&input_dir, args.num_files)?
    };

    if data.is_empty() {
        anyhow::bail!("No data loaded");
    }

    // Benchmark 1: Worker pool (main benchmark)
    let stats = BenchmarkStats::new();
    benchmark_worker_pool(&tokenizer, &data, &stats, num_workers, args.threads_per_tokenizer)?;

    // std::env::set_var("TOKENIZERS_PARALLELISM", "true");

    // Benchmark 2: Batch tokenization (if requested)
    if args.compare_batch {
        let worker_elapsed = stats.start_time.elapsed().as_secs_f64();
        let stats_batch = BenchmarkStats::new();
        benchmark_batch_tokenization(&tokenizer, &data, &stats_batch, args.batch_size)?;

        // Compare results
        println!("\n=== Performance Comparison ===");
        let batch_elapsed = stats_batch.start_time.elapsed().as_secs_f64();
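        // This ratio is worker-pool time divided by batch time, so a value
        // below 1.0 means the worker pool finished faster than encode_batch.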
        let speedup = worker_elapsed / batch_elapsed;

        println!(
            "Worker Pool: {:.1}s | Batch: {:.1}s | Speedup: {:.2}x",
            worker_elapsed, batch_elapsed, speedup
        );
    }

    println!("\n[INFO] Benchmark complete");
    Ok(())
}

Environment:

[package]
name = "tokenizer_bench"
version = "0.1.0"
edition = "2021"

[dependencies]
tokenizers = "0.22"
parquet = "57"
arrow = "57"
clap = { version = "4.5", features = ["derive"] }
rayon = "1.10"
indicatif = "0.18"
anyhow = "1.0"
serde = { version = "1.0", features = ["derive"] }
num_cpus = "1.16"
rand = "0.8"

[profile.release]
opt-level = 3
lto = true
codegen-units = 1

System Info (inside container):

❯ rustc --version
rustc 1.91.1 (ed61e7d7e 2025-11-07)
❯ cat /etc/os-release | head -n 1
PRETTY_NAME="Ubuntu 22.04.5 LTS"
❯ uname -r
5.15.0-119-generic
❯ lsmem
RANGE                            SIZE  STATE REMOVABLE  BLOCK
0x0000000000000000-0x000000007fffffff   2G online       yes      0
0x0000000100000000-0x000002007fffffff   2T online       yes 2-1024

Memory block size:         2G
Total online memory:       2T
Total offline memory:      0B
❯ cat /sys/fs/cgroup/cpu.max
12000000 100000
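(cgroup v2 cpu.max is "<quota> <period>" in microseconds: 12000000 / 100000 = a 120-CPU quota, matching the 120 workers above.)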
❯ lscpu
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8462Y+
