diff --git a/src/frame.rs b/src/frame.rs index 9f7ec4d..500a7f7 100644 --- a/src/frame.rs +++ b/src/frame.rs @@ -365,3 +365,71 @@ repr_u8! { pub fn is_control(opcode: OpCode) -> bool { matches!(opcode, OpCode::Close | OpCode::Ping | OpCode::Pong) } + +#[cfg(test)] +mod tests { + use super::*; + + fn parse_frame_header( + buf: &[u8], + ) -> (bool, OpCode, Option<[u8; 4]>, usize, usize) { + let fin = buf[0] & 0b10000000 != 0; + let opcode = OpCode::try_from(buf[0] & 0b00001111).unwrap(); + let masked = buf[1] & 0b10000000 != 0; + let length_code = buf[1] & 0x7F; + let mut offset = 2; + let payload_len = match length_code { + 126 => { + let len = u16::from_be_bytes([buf[2], buf[3]]) as usize; + offset = 4; + len + } + 127 => { + let len = + u64::from_be_bytes(buf[2..10].try_into().unwrap()) as usize; + offset = 10; + len + } + n => n as usize, + }; + let mask = if masked { + let m = + [buf[offset], buf[offset + 1], buf[offset + 2], buf[offset + 3]]; + offset += 4; + Some(m) + } else { + None + }; + (fin, opcode, mask, payload_len, offset) + } + + #[test] + fn frame_roundtrip_at_125_126_boundary() { + for len in [125usize, 126] { + let original_payload: Vec = + (0..len).map(|i| (i % 256) as u8).collect(); + let mut frame = Frame::text(original_payload.clone().into()); + frame.mask(); + + let mut buf = Vec::new(); + frame.write(&mut buf); + + let (fin, opcode, mask, payload_len, offset) = + parse_frame_header(&buf); + assert!(fin); + assert_eq!(opcode, OpCode::Text); + assert!(mask.is_some()); + + let mut parsed_payload = buf[offset..offset + payload_len].to_vec(); + crate::mask::unmask(&mut parsed_payload, mask.unwrap()); + assert_eq!(parsed_payload, original_payload); + + if len == 125 { + assert_eq!(buf[1] & 0x7F, 125); + } else { + assert_eq!(buf[1] & 0x7F, 126); + assert_eq!(payload_len, 126); + } + } + } +} diff --git a/src/mask.rs b/src/mask.rs index b1b4de3..bfd68b0 100644 --- a/src/mask.rs +++ b/src/mask.rs @@ -169,4 +169,33 @@ mod tests { assert_eq!(payload, expected); } } + + #[test] + fn roundtrip_edge_cases() { + let mask = [0xAB, 0xCD, 0xEF, 0x12]; + for &len in &[0, 1, 3, 7] { + let original: Vec = + (0..len).map(|i| (i as u8).wrapping_add(0x42)).collect(); + let mut payload = original.clone(); + unmask(&mut payload, mask); + unmask(&mut payload, mask); + assert_eq!(payload, original, "roundtrip failed for length {}", len); + } + } + + #[test] + fn mask_key_indexing_rfc6455() { + // RFC 6455 ยง5.3: octet i of the payload is XORed with octet i mod 4 of the mask key + let mask = [0x01, 0x02, 0x03, 0x04]; + let mut payload = vec![0xFF; 10]; + unmask(&mut payload, mask); + for i in 0..10 { + assert_eq!( + payload[i], + 0xFF ^ mask[i % 4], + "mask indexing mismatch at {}", + i + ); + } + } }