diff --git a/tokenizers/src/processors/roberta.rs b/tokenizers/src/processors/roberta.rs
index f2a47a9d38..1c345801aa 100644
--- a/tokenizers/src/processors/roberta.rs
+++ b/tokenizers/src/processors/roberta.rs
@@ -79,9 +79,12 @@ impl PostProcessor for RobertaProcessing {
         }
 
         // Roberta is weird, and every encoding is type_id=0.
-        encodings
-            .iter_mut()
-            .for_each(|encoding| encoding.set_type_ids(vec![0; encoding.len()]));
+        encodings.iter_mut().for_each(|encoding| {
+            encoding.set_type_ids(vec![0; encoding.len()]);
+            for overflow in encoding.get_overflowing_mut() {
+                overflow.set_type_ids(vec![0; overflow.len()]);
+            }
+        });
 
         if !add_special_tokens {
             return Ok(encodings);
diff --git a/tokenizers/src/processors/template.rs b/tokenizers/src/processors/template.rs
index 50fac99dfc..bdf2fde459 100644
--- a/tokenizers/src/processors/template.rs
+++ b/tokenizers/src/processors/template.rs
@@ -556,6 +556,10 @@ impl TemplateProcessing {
                     let encoding = &mut encodings[i];
                     encoding.set_type_ids(vec![*type_id; encoding.len()]);
                     encoding.set_sequence_id(i);
+                    for overflow in encoding.get_overflowing_mut() {
+                        overflow.set_type_ids(vec![*type_id; overflow.len()]);
+                        overflow.set_sequence_id(i);
+                    }
                     Some(encoding.clone())
                 }
                 Piece::SpecialToken { id, type_id } => {
@@ -1048,7 +1052,7 @@ mod tests {
             vec![1, 1, 1, 1, 1, 1],
             vec![Encoding::new(
                 vec![1, 13, 0, 17, 0],
-                vec![0, 0, 0, 0, 1],
+                vec![0, 0, 0, 1, 1],
                 vec![
                     "[CLS]".into(),
                     "you".into(),
@@ -1067,7 +1071,7 @@
                 ),
                 Encoding::new(
                     vec![1, 13, 0, 17, 0],
-                    vec![0, 0, 0, 0, 1],
+                    vec![0, 0, 0, 1, 1],
                     vec![
                         "[CLS]".into(),
                         "you".into(),
@@ -1084,7 +1088,7 @@
                 ),
                 Encoding::new(
                     vec![1, 12, 14, 0, 17, 0],
-                    vec![0, 0, 0, 0, 0, 1],
+                    vec![0, 0, 0, 0, 1, 1],
                     vec![
                         "[CLS]".into(),
                         "Hello".into(),
@@ -1099,7 +1103,7 @@
             vec![1, 1, 1, 1, 1, 1],
             vec![Encoding::new(
                 vec![1, 13, 0, 17, 0],
-                vec![0, 0, 0, 0, 1],
+                vec![0, 0, 0, 1, 1],
                 vec![
                     "[CLS]".into(),
                     "you".into(),
@@ -1126,6 +1130,87 @@ mod tests {
         assert_eq!(pair_encoding.token_to_sequence(5), Some(1));
         assert_eq!(pair_encoding.token_to_sequence(6), None);
     }
+
+    #[test]
+    fn template_processing_overflow_type_ids() {
+        // Regression test for https://github.com/huggingface/tokenizers/issues/1908
+        // Verifies that type_ids are correctly applied to overflow encodings,
+        // not just the main encoding.
+        let processor = TemplateProcessing::builder()
+            .try_single("[CLS]:0 $0 [SEP]:0")
+            .unwrap()
+            .try_pair("[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1")
+            .unwrap()
+            .special_tokens(vec![("[CLS]", 1), ("[SEP]", 0)])
+            .build()
+            .unwrap();
+
+        use crate::Token;
+
+        // Sequence A with one overflow
+        let mut encoding_a = Encoding::from_tokens(
+            vec![Token::new(10, "hello".into(), (0, 5))],
+            0,
+        );
+        let overflow_a = Encoding::from_tokens(
+            vec![Token::new(11, "world".into(), (6, 11))],
+            0,
+        );
+        encoding_a.set_overflowing(vec![overflow_a]);
+
+        // Sequence B with one overflow
+        let mut encoding_b = Encoding::from_tokens(
+            vec![Token::new(20, "foo".into(), (0, 3))],
+            0,
+        );
+        let overflow_b = Encoding::from_tokens(
+            vec![Token::new(21, "bar".into(), (4, 7))],
+            0,
+        );
+        encoding_b.set_overflowing(vec![overflow_b]);
+
+        let result = processor
+            .process(encoding_a, Some(encoding_b), true)
+            .unwrap();
+
+        // Main encoding: [CLS]:0 hello:0 [SEP]:0 foo:1 [SEP]:1
+        assert_eq!(result.get_type_ids(), &[0, 0, 0, 1, 1]);
+
+        // Every overflow encoding must also have correct type_ids
+        for overflow in result.get_overflowing() {
+            let type_ids = overflow.get_type_ids();
+            let tokens = overflow.get_tokens();
+            for (i, (tid, tok)) in type_ids.iter().zip(tokens.iter()).enumerate() {
+                // Tokens from sequence B (or special tokens after B) should have type_id=1
+                // Tokens from sequence A (or special tokens before/around A) should have type_id=0
+                match tok.as_str() {
+                    "[CLS]" => assert_eq!(*tid, 0, "overflow token {i} '[CLS]' should have type_id=0, got {tid}"),
+                    "[SEP]" => {
+                        // [SEP] after A has type_id=0, [SEP] after B has type_id=1
+                        // We just check it's 0 or 1
+                        assert!(*tid <= 1, "overflow token {} '[SEP]' has unexpected type_id={}", i, tid);
+                    }
+                    "foo" | "bar" => assert_eq!(*tid, 1, "overflow token {i} '{tok}' (from seq B) should have type_id=1, got {tid}"),
+                    "hello" | "world" => assert_eq!(*tid, 0, "overflow token {i} '{tok}' (from seq A) should have type_id=0, got {tid}"),
+                    _ => {}
+                }
+            }
+
+            // Also check nested overflows
+            for nested in overflow.get_overflowing() {
+                let type_ids = nested.get_type_ids();
+                let tokens = nested.get_tokens();
+                for (i, (tid, tok)) in type_ids.iter().zip(tokens.iter()).enumerate() {
+                    match tok.as_str() {
+                        "foo" | "bar" => assert_eq!(*tid, 1, "nested overflow token {i} '{tok}' (from seq B) should have type_id=1, got {tid}"),
+                        "hello" | "world" => assert_eq!(*tid, 0, "nested overflow token {i} '{tok}' (from seq A) should have type_id=0, got {tid}"),
+                        _ => {}
+                    }
+                }
+            }
+        }
+    }
+
     #[test]
     fn pair_must_use_both_sequences() {
         let processor = TemplateProcessing::builder()
diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs
index 6e88c5e023..d5076ab35f 100644
--- a/tokenizers/src/tokenizer/mod.rs
+++ b/tokenizers/src/tokenizer/mod.rs
@@ -106,11 +106,14 @@ pub trait PostProcessor {
         };
         encodings.iter_mut().enumerate().for_each(|(i, encoding)| {
             encoding.set_sequence_id(i);
+            encoding.set_type_ids(vec![i as u32; encoding.len()]);
             encoding
                 .get_overflowing_mut()
                 .iter_mut()
-                .for_each(|encoding| encoding.set_sequence_id(i));
-            encoding.set_type_ids(vec![i as u32; encoding.len()]);
+                .for_each(|encoding| {
+                    encoding.set_sequence_id(i);
+                    encoding.set_type_ids(vec![i as u32; encoding.len()]);
+                });
         });
 
         let encodings = self.process_encodings(encodings, add_special_tokens)?;