22 changes: 22 additions & 0 deletions bindings/python/py_src/tokenizers/__init__.pyi
@@ -1501,6 +1501,28 @@ class Tokenizer:
"""
pass

def get_special_tokens(self) -> List[str]:
"""
Get the list of special tokens

Returns:
:obj:`List[str]`: The list of special tokens
"""
pass

def is_special_token(self, token: str) -> bool:
"""
Check if a token is a special token

Args:
token (:obj:`str`):
The token to check

Returns:
:obj:`bool`: Whether the token is a special token
"""
pass

def id_to_token(self, id):
"""
Convert the given id to its corresponding token if it exists
22 changes: 22 additions & 0 deletions bindings/python/src/tokenizer.rs
@@ -815,6 +815,28 @@ impl PyTokenizer {
        self.tokenizer.get_vocab_size(with_added_tokens)
    }

    /// Get the list of special tokens
    ///
    /// Returns:
    ///     :obj:`List[str]`: The list of special tokens
    #[pyo3(text_signature = "(self)")]
    fn get_special_tokens(&self) -> Vec<String> {
        self.tokenizer.get_special_tokens()
    }

    /// Check if a token is a special token
    ///
    /// Args:
    ///     token (:obj:`str`):
    ///         The token to check
    ///
    /// Returns:
    ///     :obj:`bool`: Whether the token is a special token
    #[pyo3(text_signature = "(self, token)")]
    fn is_special_token(&self, token: &str) -> bool {
        self.tokenizer.is_special_token(token)
    }

    /// Enable truncation
    ///
    /// Args:
31 changes: 31 additions & 0 deletions bindings/python/tests/bindings/test_tokenizer.py
@@ -125,6 +125,37 @@ def test_add_special_tokens(self):
        assert tokens[0].normalized == False
        assert tokens[1].normalized == True

    def test_get_special_tokens(self):
        tokenizer = Tokenizer(BPE())

        # Initially no special tokens
        special_tokens = tokenizer.get_special_tokens()
        assert special_tokens == []

        # Add special tokens
        tokenizer.add_special_tokens(["[CLS]", "[SEP]", "[PAD]"])

        # Check get_special_tokens returns them
        special_tokens = tokenizer.get_special_tokens()
        assert set(special_tokens) == {"[CLS]", "[SEP]", "[PAD]"}

        # Check is_special_token
        assert tokenizer.is_special_token("[CLS]") == True
        assert tokenizer.is_special_token("[SEP]") == True
        assert tokenizer.is_special_token("[PAD]") == True
        assert tokenizer.is_special_token("[UNK]") == False

        # Add regular tokens (not special)
        tokenizer.add_tokens(["hello", "world"])

        # Regular tokens should not be special
        assert tokenizer.is_special_token("hello") == False
        assert tokenizer.is_special_token("world") == False

        # Special tokens list should not change
        special_tokens = tokenizer.get_special_tokens()
        assert set(special_tokens) == {"[CLS]", "[SEP]", "[PAD]"}

    def test_encode(self):
        tokenizer = Tokenizer(BPE())
        tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
5 changes: 5 additions & 0 deletions tokenizers/src/tokenizer/added_vocabulary.rs
@@ -244,6 +244,11 @@ impl AddedVocabulary {
        self.special_tokens_set.contains(token)
    }

    /// Get the set of special tokens
    pub fn get_special_tokens(&self) -> &AHashSet<String> {
        &self.special_tokens_set
    }

    /// Add some special tokens to the vocabulary
    pub fn add_special_tokens<N: Normalizer>(
        &mut self,
14 changes: 14 additions & 0 deletions tokenizers/src/tokenizer/mod.rs
@@ -714,6 +714,20 @@ where
        self.added_vocabulary.get_encode_special_tokens()
    }

    /// Check if a token is a special token
    pub fn is_special_token(&self, token: &str) -> bool {
        self.added_vocabulary.is_special_token(token)
    }

    /// Get the list of special tokens
    pub fn get_special_tokens(&self) -> Vec<String> {
        self.added_vocabulary
            .get_special_tokens()
            .iter()
            .cloned()
            .collect()
    }

    /// Encode a single sequence
    fn encode_single_sequence(
        &self,
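For reference, here is a minimal Rust-side sketch of how the new crate-level accessors could be exercised, mirroring the Python test above. It is illustrative only and not part of the diff; it assumes the crate's existing Tokenizer, BPE, and AddedToken exports plus the is_special_token / get_special_tokens methods added in this PR.

use tokenizers::models::bpe::BPE;
use tokenizers::{AddedToken, Tokenizer};

fn main() {
    // Hypothetical usage sketch: build an empty BPE tokenizer and register
    // two special tokens, then query them through the new accessors.
    let mut tokenizer = Tokenizer::new(BPE::default());
    tokenizer.add_special_tokens(&[
        AddedToken::from("[CLS]", true),
        AddedToken::from("[SEP]", true),
    ]);

    // Only tokens registered as special are reported.
    assert!(tokenizer.is_special_token("[CLS]"));
    assert!(!tokenizer.is_special_token("hello"));

    // get_special_tokens() collects from a set, so sort before comparing.
    let mut specials = tokenizer.get_special_tokens();
    specials.sort();
    assert_eq!(specials, vec!["[CLS]".to_string(), "[SEP]".to_string()]);
}

Because the underlying storage is an AHashSet, the returned Vec<String> has no guaranteed order, which is why the Python test compares against a set and the sketch above sorts before asserting.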