diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index a3eff25e80..f785aa2655 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -75,7 +75,14 @@ jobs:
         working-directory: ./tokenizers
         run: make test
 
-      # Skip integration tests for now on Windows
+      - name: Download xnli test dataset on Windows for whitespace equivalence test
+        if: matrix.os == 'windows-latest'
+        shell: bash
+        working-directory: ./tokenizers
+        run: |
+          mkdir -p data
+          curl -L https://huggingface.co/datasets/hf-internal-testing/tokenizers-test-data/resolve/main/xnli.txt -o data/xnli.txt
+
       - name: Run lib Tests on Windows
         if: matrix.os == 'windows-latest'
         uses: actions-rs/cargo@844f36862e911db73fe0815f00a4a2602c279505 # v1
diff --git a/tokenizers/Cargo.toml b/tokenizers/Cargo.toml
index 0e937f3cc5..505a4c237c 100644
--- a/tokenizers/Cargo.toml
+++ b/tokenizers/Cargo.toml
@@ -62,6 +62,10 @@ harness = false
 name = "truncation_benchmark"
 harness = false
 
+[[bench]]
+name = "whitespace_pretok_benchmark"
+harness = false
+
 [[bench]]
 name = "ci_benchmark"
 harness = false
@@ -83,7 +87,9 @@ itertools = "0.14"
 log = "0.4"
 derive_builder = "0.20"
 spm_precompiled = "0.1.3"
-hf-hub = { version = "0.4.1", features = ["ureq"], default-features = false, optional = true }
+hf-hub = { version = "0.4.1", features = [
+    "ureq",
+], default-features = false, optional = true }
 daachorse = "1.0.1"
 paste = "1.0.14"
 macro_rules_attribute = "0.2.0"
@@ -121,5 +127,3 @@ debug = true
 [[example]]
 name = "encode_batch"
 required-features = ["http"]
-
-
diff --git a/tokenizers/Makefile b/tokenizers/Makefile
index e7bd98aada..beab1ed317 100644
--- a/tokenizers/Makefile
+++ b/tokenizers/Makefile
@@ -6,7 +6,7 @@ dir_guard=@mkdir -p $(@D)
 
 SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt $(DATA_DIR)/albert-base-v1-tokenizer.json $(DATA_DIR)/llama-3-tokenizer.json
 BENCHMARK_RESOURCES = $(SHARED_RESOURCES)
-TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt $(DATA_DIR)/roberta.json $(DATA_DIR)/tokenizer-wiki.json $(DATA_DIR)/bert-wiki.json
+TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt $(DATA_DIR)/roberta.json $(DATA_DIR)/tokenizer-wiki.json $(DATA_DIR)/bert-wiki.json $(DATA_DIR)/xnli.txt
 
 .PHONY : build
 build :
@@ -94,3 +94,7 @@ $(DATA_DIR)/bert-wiki.json :
 $(DATA_DIR)/llama-3-tokenizer.json :
 	$(dir_guard)
 	wget $(HF_TEST_DATA)/llama-3-tokenizer.json -O $@
+
+$(DATA_DIR)/xnli.txt :
+	$(dir_guard)
+	wget $(HF_TEST_DATA)/xnli.txt -O $@
diff --git a/tokenizers/benches/whitespace_pretok_benchmark.rs b/tokenizers/benches/whitespace_pretok_benchmark.rs
new file mode 100644
index 0000000000..1695c97218
--- /dev/null
+++ b/tokenizers/benches/whitespace_pretok_benchmark.rs
@@ -0,0 +1,96 @@
+#[macro_use]
+extern crate criterion;
+
+use criterion::{BenchmarkId, Criterion, Throughput};
+use std::hint::black_box;
+use tokenizers::pre_tokenizers::whitespace::{ManualWhitespaceSplit, Whitespace};
+use tokenizers::{PreTokenizedString, PreTokenizer};
+
+fn bench_pretokenizer(c: &mut Criterion) {
+    let data = std::fs::read_to_string("data/big.txt").unwrap();
+    let lines: Vec<&str> = data.lines().collect();
+
+    let mut group = c.benchmark_group("whitespace-pretok");
+    group.throughput(Throughput::Bytes(data.len() as u64));
+
+    // Full corpus as a single string
+    group.bench_function("regex/full-corpus", |b| {
+        let pretok = Whitespace {};
+        b.iter(|| {
+            let mut pre = PreTokenizedString::from(black_box(data.as_str()));
+            pretok.pre_tokenize(&mut pre).unwrap();
+            pre
+        })
+    });
+
+    group.bench_function("manual/full-corpus", |b| {
+        let pretok = ManualWhitespaceSplit {};
+        b.iter(|| {
+            let mut pre = PreTokenizedString::from(black_box(data.as_str()));
+            pretok.pre_tokenize(&mut pre).unwrap();
+            pre
+        })
+    });
+
+    // Line-by-line (many short strings — tests per-call overhead)
+    group.bench_function("regex/line-by-line", |b| {
+        let pretok = Whitespace {};
+        b.iter(|| {
+            for line in &lines {
+                let mut pre = PreTokenizedString::from(black_box(*line));
+                pretok.pre_tokenize(&mut pre).unwrap();
+                black_box(&pre);
+            }
+        })
+    });
+
+    group.bench_function("manual/line-by-line", |b| {
+        let pretok = ManualWhitespaceSplit {};
+        b.iter(|| {
+            for line in &lines {
+                let mut pre = PreTokenizedString::from(black_box(*line));
+                pretok.pre_tokenize(&mut pre).unwrap();
+                black_box(&pre);
+            }
+        })
+    });
+
+    group.finish();
+
+    // --- Scaling with input size ---
+
+    let mut group = c.benchmark_group("whitespace-pretok-scaling");
+
+    for size in [100, 1_000, 10_000, 100_000] {
+        let input: String = data.chars().take(size).collect();
+        group.throughput(Throughput::Bytes(input.len() as u64));
+
+        group.bench_with_input(BenchmarkId::new("regex", size), &input, |b, input| {
+            let pretok = Whitespace {};
+            b.iter(|| {
+                let mut pre = PreTokenizedString::from(black_box(input.as_str()));
+                pretok.pre_tokenize(&mut pre).unwrap();
+                pre
+            })
+        });
+
+        group.bench_with_input(BenchmarkId::new("manual", size), &input, |b, input| {
+            let pretok = ManualWhitespaceSplit {};
+            b.iter(|| {
+                let mut pre = PreTokenizedString::from(black_box(input.as_str()));
+                pretok.pre_tokenize(&mut pre).unwrap();
+                pre
+            })
+        });
+    }
+
+    group.finish();
+}
+
+criterion_group! {
+    name = whitespace_pretok;
+    config = Criterion::default().sample_size(50);
+    targets = bench_pretokenizer
+}
+
+criterion_main!(whitespace_pretok);
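
(Usage note: the benchmark reads data/big.txt relative to tokenizers/; big.txt is part of SHARED_RESOURCES, so the Makefile's resource rules fetch it, e.g. as a side effect of `make test`. With the file in place, `cargo bench --bench whitespace_pretok_benchmark` — the name from the [[bench]] entry added to Cargo.toml above — runs the full-corpus, line-by-line, and scaling comparisons.)
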
diff --git a/tokenizers/src/pre_tokenizers/whitespace.rs b/tokenizers/src/pre_tokenizers/whitespace.rs
index 20cfb65193..47077fc886 100644
--- a/tokenizers/src/pre_tokenizers/whitespace.rs
+++ b/tokenizers/src/pre_tokenizers/whitespace.rs
@@ -2,6 +2,7 @@ use std::sync::LazyLock;
 
 use regex::Regex;
 
+use crate::pattern::Pattern;
 use crate::tokenizer::{
     pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior,
 };
@@ -40,6 +41,85 @@ impl PreTokenizer for WhitespaceSplit {
     }
 }
 
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+#[macro_rules_attribute(impl_serde_type!)]
+pub struct ManualWhitespaceSplit;
+
+impl PreTokenizer for ManualWhitespaceSplit {
+    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
+        pretokenized.split(|_, normalized| {
+            normalized.split(WhiteSpacePattern, SplitDelimiterBehavior::Removed)
+        })
+    }
+}
+
+#[derive(Clone, Copy, Eq, PartialEq)]
+enum CharType {
+    Whitespace,
+    Word,
+    Symbol,
+}
+
+struct WhiteSpacePattern;
+
+impl Pattern for WhiteSpacePattern {
+    fn find_matches(&self, inside: &str) -> Result<Vec<((usize, usize), bool)>> {
+        if inside.is_empty() {
+            return Ok(vec![((0, 0), false)]);
+        }
+
+        let mut matches = Vec::new();
+        let mut span_start = 0;
+        let mut prev_type: Option<CharType> = None;
+
+        for (i, ch) in inside.char_indices() {
+            let ct = classify(ch);
+
+            if let Some(pt) = prev_type {
+                if pt != ct {
+                    // Emit the previous span:
+                    // - whitespace spans are matches (true), i.e. delimiters removed by `Removed`
+                    // - word/symbol spans are non-matches (false) and are kept
+                    matches.push(((span_start, i), pt == CharType::Whitespace));
+                    span_start = i;
+                }
+            }
+            prev_type = Some(ct);
+        }
+
+        // Emit the final span
+        if let Some(pt) = prev_type {
+            matches.push(((span_start, inside.len()), pt == CharType::Whitespace));
+        }
+
+        Ok(matches)
+    }
+}
+
+fn classify(ch: char) -> CharType {
+    if ch.is_whitespace() {
+        CharType::Whitespace
+    } else if is_word_char(ch) {
+        CharType::Word
+    } else {
+        CharType::Symbol
+    }
+}
+
+/// Matches the same characters as the Unicode-aware `\w` regex class:
+/// Alphabetic + Nd (decimal digits) + Pc (connector punctuation) + M (marks)
+/// + Join_Control. (Rust's `is_alphanumeric` would also include No, which `\w` excludes.)
+fn is_word_char(ch: char) -> bool {
+    use unicode_categories::UnicodeCategories;
+
+    ch.is_alphabetic()
+        || ch.is_number_decimal_digit()
+        || ch.is_punctuation_connector()
+        || ch.is_mark()
+        || ch == '\u{200c}' // Zero-Width Non-Joiner
+        || ch == '\u{200d}' // Zero-Width Joiner
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -102,4 +182,123 @@ mod tests {
             );
         }
     }
+
+    #[test]
+    fn assert_equivalent() {
+        let test_cases = vec![
+            "Hello world!",
+            "How are you doing?",
+            "This is a test with numbers 123 and symbols @#$%",
+            "Multiple   spaces",
+            "Tabs\tand\nnewlines",
+            "Unicode: café résumé naïve",
+            "Mixed: Hello123!@# world",
+            "Edge cases: a.b,c;d:e",
+            "Empty string:",
+            "Only spaces:   ",
+            "Only symbols: !@#$%",
+            "Only words: hello world",
+            "Numbers: 123 456 789",
+            "Underscores: hello_world test_case",
+            "Special chars: αβγ δέζ ηθι",
+        ];
+
+        for test_case in test_cases {
+            let mut original = PreTokenizedString::from(test_case);
+            let mut manual = PreTokenizedString::from(test_case);
+
+            let original_pretok = Whitespace {};
+            let manual_pretok = ManualWhitespaceSplit {};
+
+            original_pretok.pre_tokenize(&mut original).unwrap();
+            manual_pretok.pre_tokenize(&mut manual).unwrap();
+
+            let original_splits = original
+                .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>();
+
+            let manual_splits = manual
+                .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>();
+
+            assert_eq!(
+                original_splits, manual_splits,
+                "Mismatch for test case: '{}'",
+                test_case
+            );
+        }
+    }
+
+    #[test]
+    fn manual_edge_cases() {
+        let pretok = ManualWhitespaceSplit {};
+
+        // Test various edge cases
+        let edge_cases = vec![
+            ("", vec![]),
+            (" ", vec![]),
+            ("   ", vec![]),
+            ("a", vec![("a", (0, 1))]),
+            ("!", vec![("!", (0, 1))]),
+            ("a!", vec![("a", (0, 1)), ("!", (1, 2))]),
+            ("!a", vec![("!", (0, 1)), ("a", (1, 2))]),
+            ("a b", vec![("a", (0, 1)), ("b", (2, 3))]),
+            ("a  b", vec![("a", (0, 1)), ("b", (3, 4))]),
+            ("a\tb", vec![("a", (0, 1)), ("b", (2, 3))]),
+            ("a\nb", vec![("a", (0, 1)), ("b", (2, 3))]),
+            ("a\r\nb", vec![("a", (0, 1)), ("b", (3, 4))]),
+        ];
+
+        for (input, expected) in edge_cases {
+            let mut pretokenized = PreTokenizedString::from(input);
+            pretok.pre_tokenize(&mut pretokenized).unwrap();
+            let result = pretokenized
+                .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>();
+            assert_eq!(result, expected, "Failed for input: '{}'", input);
+        }
+    }
+
+    #[test]
+    fn assert_equivalent_xnli() {
+        let Ok(data) = std::fs::read_to_string("data/xnli.txt") else {
+            eprintln!("Could not read data/xnli.txt, skipping test");
+            return;
+        };
+        let original_pretok = Whitespace {};
+        let manual_pretok = ManualWhitespaceSplit {};
+
+        for (i, line) in data.lines().enumerate() {
+            let mut original = PreTokenizedString::from(line);
+            let mut manual = PreTokenizedString::from(line);
+
+            original_pretok.pre_tokenize(&mut original).unwrap();
+            manual_pretok.pre_tokenize(&mut manual).unwrap();
+
+            let original_splits = original
+                .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>();
+            let manual_splits = manual
+                .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                .into_iter()
+                .map(|(s, o, _)| (s, o))
+                .collect::<Vec<_>>();
+
+            assert_eq!(
+                original_splits,
+                manual_splits,
+                "Mismatch on line {}: '{}'",
+                i,
+                &line.chars().take(80).collect::<String>(),
+            );
+        }
+    }
 }
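
For reference, a minimal standalone sketch of the equivalence the tests assert, assuming this diff is applied (`ManualWhitespaceSplit` is introduced above; `Whitespace`, `PreTokenizedString`, `get_splits`, `OffsetReferential`, and `OffsetType` are existing tokenizers APIs; the sample input and expected output are illustrative only):

    use tokenizers::pre_tokenizers::whitespace::{ManualWhitespaceSplit, Whitespace};
    use tokenizers::{OffsetReferential, OffsetType, PreTokenizedString, PreTokenizer};

    // Pre-tokenize `input` and return (slice, byte-offset) pairs, the same
    // view of the result that the equivalence tests above compare.
    fn splits(pretok: &impl PreTokenizer, input: &str) -> Vec<(String, (usize, usize))> {
        let mut pre = PreTokenizedString::from(input);
        pretok.pre_tokenize(&mut pre).unwrap();
        pre.get_splits(OffsetReferential::Original, OffsetType::Byte)
            .into_iter()
            .map(|(s, o, _)| (s.to_owned(), o))
            .collect()
    }

    fn main() {
        let input = "café 123_go !?";
        let regex = splits(&Whitespace {}, input);
        let manual = splits(&ManualWhitespaceSplit {}, input);
        // Both should yield [("café", (0, 5)), ("123_go", (6, 12)), ("!?", (13, 15))]:
        // word and symbol runs become separate splits, whitespace is removed.
        assert_eq!(regex, manual);
        println!("{manual:?}");
    }

The manual character-class state machine avoids the regex engine entirely, which is what the whitespace_pretok benchmark quantifies, while the xnli.txt test (fetched by the Makefile rule and the Windows CI step above) checks equivalence on multilingual data where the `\w` edge cases actually occur.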