9 changes: 8 additions & 1 deletion .github/workflows/rust.yml
@@ -75,7 +75,14 @@ jobs:
working-directory: ./tokenizers
run: make test

# Skip integration tests for now on Windows
- name: Download xnli test dataset on Windows for whitespace equivalence test
if: matrix.os == 'windows-latest'
shell: bash
working-directory: ./tokenizers
run: |
mkdir -p data
curl -L https://huggingface.co/datasets/hf-internal-testing/tokenizers-test-data/resolve/main/xnli.txt -o data/xnli.txt

- name: Run lib Tests on Windows
if: matrix.os == 'windows-latest'
uses: actions-rs/cargo@844f36862e911db73fe0815f00a4a2602c279505 # v1
10 changes: 7 additions & 3 deletions tokenizers/Cargo.toml
@@ -62,6 +62,10 @@ harness = false
name = "truncation_benchmark"
harness = false

[[bench]]
name = "whitespace_pretok_benchmark"
harness = false
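# Note (not in the diff): harness = false opts this target out of the default
# libtest harness so criterion's own main (from criterion_main!) drives the run.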

[[bench]]
name = "ci_benchmark"
harness = false
@@ -83,7 +87,9 @@ itertools = "0.14"
log = "0.4"
derive_builder = "0.20"
spm_precompiled = "0.1.3"
hf-hub = { version = "0.4.1", features = ["ureq"], default-features = false, optional = true }
hf-hub = { version = "0.4.1", features = [
"ureq",
], default-features = false, optional = true }
daachorse = "1.0.1"
paste = "1.0.14"
macro_rules_attribute = "0.2.0"
@@ -121,5 +127,3 @@ debug = true
[[example]]
name = "encode_batch"
required-features = ["http"]


6 changes: 5 additions & 1 deletion tokenizers/Makefile
@@ -6,7 +6,7 @@ dir_guard=@mkdir -p $(@D)

SHARED_RESOURCES = $(DATA_DIR)/gpt2-vocab.json $(DATA_DIR)/gpt2-merges.txt $(DATA_DIR)/bert-base-uncased-vocab.txt $(DATA_DIR)/big.txt $(DATA_DIR)/small.txt $(DATA_DIR)/albert-base-v1-tokenizer.json $(DATA_DIR)/llama-3-tokenizer.json
BENCHMARK_RESOURCES = $(SHARED_RESOURCES)
TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt $(DATA_DIR)/roberta.json $(DATA_DIR)/tokenizer-wiki.json $(DATA_DIR)/bert-wiki.json
TESTS_RESOURCES = $(SHARED_RESOURCES) $(DATA_DIR)/unigram.json $(DATA_DIR)/unigram_wagahaiwa_nekodearu.txt $(DATA_DIR)/roberta.json $(DATA_DIR)/tokenizer-wiki.json $(DATA_DIR)/bert-wiki.json $(DATA_DIR)/xnli.txt

.PHONY : build
build :
@@ -94,3 +94,7 @@ $(DATA_DIR)/bert-wiki.json :
$(DATA_DIR)/llama-3-tokenizer.json :
$(dir_guard)
wget $(HF_TEST_DATA)/llama-3-tokenizer.json -O $@

$(DATA_DIR)/xnli.txt :
$(dir_guard)
wget $(HF_TEST_DATA)/xnli.txt -O $@
96 changes: 96 additions & 0 deletions tokenizers/benches/whitespace_pretok_benchmark.rs
@@ -0,0 +1,96 @@
#[macro_use]
extern crate criterion;

use criterion::{BenchmarkId, Criterion, Throughput};
use std::hint::black_box;
use tokenizers::pre_tokenizers::whitespace::{ManualWhitespaceSplit, Whitespace};
use tokenizers::{PreTokenizedString, PreTokenizer};

fn bench_pretokenizer(c: &mut Criterion) {
let data = std::fs::read_to_string("data/big.txt").unwrap();
let lines: Vec<&str> = data.lines().collect();

let mut group = c.benchmark_group("whitespace-pretok");
group.throughput(Throughput::Bytes(data.len() as u64));

// Full corpus as a single string
group.bench_function("regex/full-corpus", |b| {
let pretok = Whitespace {};
b.iter(|| {
let mut pre = PreTokenizedString::from(black_box(data.as_str()));
pretok.pre_tokenize(&mut pre).unwrap();
pre
})
});

group.bench_function("manual/full-corpus", |b| {
let pretok = ManualWhitespaceSplit {};
b.iter(|| {
let mut pre = PreTokenizedString::from(black_box(data.as_str()));
pretok.pre_tokenize(&mut pre).unwrap();
pre
})
});

// Line-by-line (many short strings — tests per-call overhead)
group.bench_function("regex/line-by-line", |b| {
let pretok = Whitespace {};
b.iter(|| {
for line in &lines {
let mut pre = PreTokenizedString::from(black_box(*line));
pretok.pre_tokenize(&mut pre).unwrap();
black_box(&pre);
}
})
});

group.bench_function("manual/line-by-line", |b| {
let pretok = ManualWhitespaceSplit {};
b.iter(|| {
for line in &lines {
let mut pre = PreTokenizedString::from(black_box(*line));
pretok.pre_tokenize(&mut pre).unwrap();
black_box(&pre);
}
})
});

group.finish();

// --- Scaling with input size ---

let mut group = c.benchmark_group("whitespace-pretok-scaling");

for size in [100, 1_000, 10_000, 100_000] {
let input: String = data.chars().take(size).collect();
group.throughput(Throughput::Bytes(input.len() as u64));

group.bench_with_input(BenchmarkId::new("regex", size), &input, |b, input| {
let pretok = Whitespace {};
b.iter(|| {
let mut pre = PreTokenizedString::from(black_box(input.as_str()));
pretok.pre_tokenize(&mut pre).unwrap();
pre
})
});

group.bench_with_input(BenchmarkId::new("manual", size), &input, |b, input| {
let pretok = ManualWhitespaceSplit {};
b.iter(|| {
let mut pre = PreTokenizedString::from(black_box(input.as_str()));
pretok.pre_tokenize(&mut pre).unwrap();
pre
})
});
}

group.finish();
}
Collaborator: bench a full encode_batch please

criterion_group! {
name = whitespace_pretok;
config = Criterion::default().sample_size(50);
targets = bench_pretokenizer
}

criterion_main!(whitespace_pretok);
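
Usage note (not part of the diff): the bench reads data/big.txt, one of the Makefile's SHARED_RESOURCES, so once that file has been fetched it can be run on its own with the standard `cargo bench --bench whitespace_pretok_benchmark`.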
199 changes: 199 additions & 0 deletions tokenizers/src/pre_tokenizers/whitespace.rs
@@ -2,6 +2,7 @@ use std::sync::LazyLock;

use regex::Regex;

use crate::pattern::Pattern;
use crate::tokenizer::{
pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior,
};
@@ -40,6 +41,85 @@ impl PreTokenizer for WhitespaceSplit {
}
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct ManualWhitespaceSplit;

impl PreTokenizer for ManualWhitespaceSplit {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
pretokenized.split(|_, normalized| {
normalized.split(WhiteSpacePattern, SplitDelimiterBehavior::Removed)
})
}
}

#[derive(Clone, Copy, Eq, PartialEq)]
enum CharType {
Whitespace,
Word,
Symbol,
}

struct WhiteSpacePattern;

impl Pattern for WhiteSpacePattern {
fn find_matches(&self, inside: &str) -> Result<Vec<(crate::Offsets, bool)>> {
if inside.is_empty() {
return Ok(vec![((0, 0), false)]);
}

let mut matches = Vec::new();
let mut span_start = 0;
let mut prev_type: Option<CharType> = None;

for (i, ch) in inside.char_indices() {
let ct = classify(ch);

if let Some(pt) = prev_type {
if pt != ct {
// Emit the previous span:
// - whitespace spans are matches (true), so
//   `SplitDelimiterBehavior::Removed` drops them
// - word/symbol spans are non-matches (false) and are kept
matches.push(((span_start, i), pt == CharType::Whitespace));
span_start = i;
}
}
prev_type = Some(ct);
}

// Emit the final span
if let Some(pt) = prev_type {
matches.push(((span_start, inside.len()), pt == CharType::Whitespace));
}

Ok(matches)
}
}
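
// Illustrative walk-through (not in the source): for the input "ab  cd!",
// find_matches yields:
//   ((0, 2), false)  "ab"  - word run, kept
//   ((2, 4), true)   "  "  - whitespace run, removed
//   ((4, 6), false)  "cd"  - word run, kept
//   ((6, 7), false)  "!"   - symbol run, kept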

fn classify(ch: char) -> CharType {
if ch.is_whitespace() {
CharType::Whitespace
} else if is_word_char(ch) {
CharType::Word
} else {
CharType::Symbol
}
}

/// Matches the same characters as the `\w` regex class (Unicode-aware).
/// This is: Alphabetic + Nd (decimal digit) + Pc (connector punctuation) +
/// M (marks) + Join_Control — NOT Nl/No (which Rust's is_alphanumeric includes).
fn is_word_char(ch: char) -> bool {
use unicode_categories::UnicodeCategories;

ch.is_alphabetic()
|| ch.is_number_decimal_digit()
|| ch.is_punctuation_connector()
|| ch.is_mark()
|| ch == '\u{200c}' // Zero-Width Non-Joiner
|| ch == '\u{200d}' // Zero-Width Joiner
}
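
// Illustrative examples (not in the source): 'a', '3' (Nd), '_' (Pc), and
// U+0301 (a combining mark) are word chars; ' ', '!', and '€' are not.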

#[cfg(test)]
mod tests {
use super::*;
@@ -102,4 +182,123 @@ mod tests {
);
}
}

#[test]
fn assert_equivalent() {
let test_cases = vec![
"Hello world!",
"How are you doing?",
"This is a test with numbers 123 and symbols @#$%",
"Multiple spaces",
"Tabs\tand\nnewlines",
"Unicode: café résumé naïve",
"Mixed: Hello123!@# world",
"Edge cases: a.b,c;d:e",
"Empty string:",
"Only spaces: ",
"Only symbols: !@#$%",
"Only words: hello world",
"Numbers: 123 456 789",
"Underscores: hello_world test_case",
"Special chars: αβγ δέζ ηθι",
];

for test_case in test_cases {
let mut original = PreTokenizedString::from(test_case);
let mut manual = PreTokenizedString::from(test_case);

let original_pretok = Whitespace {};
let manual_pretok = ManualWhitespaceSplit {};

original_pretok.pre_tokenize(&mut original).unwrap();
manual_pretok.pre_tokenize(&mut manual).unwrap();

let original_splits = original
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>();

let manual_splits = manual
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>();

assert_eq!(
original_splits, manual_splits,
"Mismatch for test case: '{}'",
test_case
);
}
}

#[test]
fn manual_edge_cases() {
let pretok = ManualWhitespaceSplit {};

// Test various edge cases
let edge_cases = vec![
("", vec![]),
(" ", vec![]),
(" ", vec![]),
("a", vec![("a", (0, 1))]),
("!", vec![("!", (0, 1))]),
("a!", vec![("a", (0, 1)), ("!", (1, 2))]),
("!a", vec![("!", (0, 1)), ("a", (1, 2))]),
("a b", vec![("a", (0, 1)), ("b", (2, 3))]),
("a b", vec![("a", (0, 1)), ("b", (3, 4))]),
("a\tb", vec![("a", (0, 1)), ("b", (2, 3))]),
("a\nb", vec![("a", (0, 1)), ("b", (2, 3))]),
("a\r\nb", vec![("a", (0, 1)), ("b", (3, 4))]),
];

for (input, expected) in edge_cases {
let mut pretokenized = PreTokenizedString::from(input);
pretok.pre_tokenize(&mut pretokenized).unwrap();
let result = pretokenized
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>();
assert_eq!(result, expected, "Failed for input: '{}'", input);
}
}

#[test]
fn assert_equivalent_xnli() {
let Ok(data) = std::fs::read_to_string("data/xnli.txt") else {
eprintln!("Could not read data/xnli.txt, skipping test");
return;
};
let original_pretok = Whitespace {};
let manual_pretok = ManualWhitespaceSplit {};

for (i, line) in data.lines().enumerate() {
let mut original = PreTokenizedString::from(line);
let mut manual = PreTokenizedString::from(line);

original_pretok.pre_tokenize(&mut original).unwrap();
manual_pretok.pre_tokenize(&mut manual).unwrap();

let original_splits = original
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>();
let manual_splits = manual
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>();

assert_eq!(
original_splits,
manual_splits,
"Mismatch on line {}: '{}'",
i,
&line.chars().take(80).collect::<String>(),
);
}
}
}
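
A minimal usage sketch (illustrative, not part of the diff; it assumes ManualWhitespaceSplit is publicly reachable via the same paths the benchmark imports use):

use tokenizers::pre_tokenizers::whitespace::ManualWhitespaceSplit;
use tokenizers::{OffsetReferential, OffsetType, PreTokenizedString, PreTokenizer};

fn main() -> tokenizers::Result<()> {
    let pretok = ManualWhitespaceSplit {};
    let mut pre = PreTokenizedString::from("Hello world!");
    pretok.pre_tokenize(&mut pre)?;
    // Expected splits: ("Hello", (0, 5)), ("world", (6, 11)), ("!", (11, 12))
    for (s, offsets, _) in pre.get_splits(OffsetReferential::Original, OffsetType::Byte) {
        println!("{s:?} {offsets:?}");
    }
    Ok(())
}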