diff --git a/bindings/python/py_src/tokenizers/__init__.pyi b/bindings/python/py_src/tokenizers/__init__.pyi
index 44f19b8a4..c7123ed23 100644
--- a/bindings/python/py_src/tokenizers/__init__.pyi
+++ b/bindings/python/py_src/tokenizers/__init__.pyi
@@ -1501,6 +1501,28 @@ class Tokenizer:
         """
         pass
 
+    def get_special_tokens(self) -> List[str]:
+        """
+        Get the list of special tokens
+
+        Returns:
+            :obj:`List[str]`: The list of special tokens
+        """
+        pass
+
+    def is_special_token(self, token: str) -> bool:
+        """
+        Check if a token is a special token
+
+        Args:
+            token (:obj:`str`):
+                The token to check
+
+        Returns:
+            :obj:`bool`: Whether the token is a special token
+        """
+        pass
+
     def id_to_token(self, id):
         """
         Convert the given id to its corresponding token if it exists
diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 0cd06985c..b4bdca27a 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -815,6 +815,28 @@ impl PyTokenizer {
         self.tokenizer.get_vocab_size(with_added_tokens)
     }
 
+    /// Get the list of special tokens
+    ///
+    /// Returns:
+    ///     :obj:`List[str]`: The list of special tokens
+    #[pyo3(text_signature = "(self)")]
+    fn get_special_tokens(&self) -> Vec<String> {
+        self.tokenizer.get_special_tokens()
+    }
+
+    /// Check if a token is a special token
+    ///
+    /// Args:
+    ///     token (:obj:`str`):
+    ///         The token to check
+    ///
+    /// Returns:
+    ///     :obj:`bool`: Whether the token is a special token
+    #[pyo3(text_signature = "(self, token)")]
+    fn is_special_token(&self, token: &str) -> bool {
+        self.tokenizer.is_special_token(token)
+    }
+
     /// Enable truncation
     ///
     /// Args:
diff --git a/bindings/python/tests/bindings/test_tokenizer.py b/bindings/python/tests/bindings/test_tokenizer.py
index 96b75b24d..6c3f46296 100644
--- a/bindings/python/tests/bindings/test_tokenizer.py
+++ b/bindings/python/tests/bindings/test_tokenizer.py
@@ -125,6 +125,37 @@ def test_add_special_tokens(self):
         assert tokens[0].normalized == False
         assert tokens[1].normalized == True
 
+    def test_get_special_tokens(self):
+        tokenizer = Tokenizer(BPE())
+
+        # Initially no special tokens
+        special_tokens = tokenizer.get_special_tokens()
+        assert special_tokens == []
+
+        # Add special tokens
+        tokenizer.add_special_tokens(["[CLS]", "[SEP]", "[PAD]"])
+
+        # Check get_special_tokens returns them
+        special_tokens = tokenizer.get_special_tokens()
+        assert set(special_tokens) == {"[CLS]", "[SEP]", "[PAD]"}
+
+        # Check is_special_token
+        assert tokenizer.is_special_token("[CLS]") == True
+        assert tokenizer.is_special_token("[SEP]") == True
+        assert tokenizer.is_special_token("[PAD]") == True
+        assert tokenizer.is_special_token("[UNK]") == False
+
+        # Add regular tokens (not special)
+        tokenizer.add_tokens(["hello", "world"])
+
+        # Regular tokens should not be special
+        assert tokenizer.is_special_token("hello") == False
+        assert tokenizer.is_special_token("world") == False
+
+        # Special tokens list should not change
+        special_tokens = tokenizer.get_special_tokens()
+        assert set(special_tokens) == {"[CLS]", "[SEP]", "[PAD]"}
+
     def test_encode(self):
         tokenizer = Tokenizer(BPE())
         tokenizer.add_tokens(["my", "name", "is", "john", "pair"])
diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs
index ff1a592a6..0cce94f7e 100644
--- a/tokenizers/src/tokenizer/added_vocabulary.rs
+++ b/tokenizers/src/tokenizer/added_vocabulary.rs
@@ -244,6 +244,11 @@ impl AddedVocabulary {
         self.special_tokens_set.contains(token)
     }
 
+    /// Get the set of special tokens
+    pub fn get_special_tokens(&self) -> &AHashSet<String> {
+        &self.special_tokens_set
+    }
+
     /// Add some special tokens to the vocabulary
     pub fn add_special_tokens(
         &mut self,
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index cedabeebc..d32c9c087 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -714,6 +714,20 @@ where
         self.added_vocabulary.get_encode_special_tokens()
     }
 
+    /// Check if a token is a special token
+    pub fn is_special_token(&self, token: &str) -> bool {
+        self.added_vocabulary.is_special_token(token)
+    }
+
+    /// Get the set of special tokens
+    pub fn get_special_tokens(&self) -> Vec<String> {
+        self.added_vocabulary
+            .get_special_tokens()
+            .iter()
+            .cloned()
+            .collect()
+    }
+
     /// Encode a single sequence
     fn encode_single_sequence(
        &self,
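For reference, a minimal usage sketch of the two new bindings, assuming the patch above is applied. `Tokenizer`, `BPE`, and `add_special_tokens` already exist in the library (the new test uses them the same way); only `get_special_tokens` and `is_special_token` are introduced by this diff:

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE())
tokenizer.add_special_tokens(["[CLS]", "[SEP]"])

# get_special_tokens collects out of the underlying AHashSet on the Rust
# side, so the order of the returned list is not guaranteed
print(sorted(tokenizer.get_special_tokens()))  # ['[CLS]', '[SEP]']

print(tokenizer.is_special_token("[CLS]"))   # True
print(tokenizer.is_special_token("hello"))   # False
```

Because the core `get_special_tokens` clones from an `AHashSet<String>`, callers needing a stable order should sort the result themselves; the test's `set(...)` comparisons reflect the same point.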