diff --git a/src/fragment.rs b/src/fragment.rs index b333e5d..6968a07 100644 --- a/src/fragment.rs +++ b/src/fragment.rs @@ -27,7 +27,7 @@ use tokio::io::AsyncRead; use tokio::io::AsyncWrite; pub enum Fragment { - Text(Option, Vec), + Text(Option, Vec, usize), Binary(Vec), } @@ -35,7 +35,7 @@ impl Fragment { /// Returns the payload of the fragment. fn take_buffer(self) -> Vec { match self { - Fragment::Text(_, buffer) => buffer, + Fragment::Text(_, buffer, _) => buffer, Fragment::Binary(buffer) => buffer, } } @@ -118,7 +118,10 @@ impl<'f, S> FragmentCollector { if is_closed && frame.opcode != OpCode::Close { return Err(WebSocketError::ConnectionClosed); } - if let Some(frame) = self.fragments.accumulate(frame)? { + if let Some(frame) = self + .fragments + .accumulate(frame, self.read_half.max_message_size)? + { return Ok(frame); } } @@ -191,7 +194,10 @@ impl<'f, S> FragmentCollectorRead { let Some(frame) = res? else { continue; }; - if let Some(frame) = self.fragments.accumulate(frame)? { + if let Some(frame) = self + .fragments + .accumulate(frame, self.read_half.max_message_size)? + { return Ok(frame); } } @@ -215,6 +221,7 @@ impl Fragments { pub fn accumulate<'f>( &mut self, frame: Frame<'f>, + max_message_size: usize, ) -> Result>, WebSocketError> { match frame.opcode { OpCode::Text | OpCode::Binary => { @@ -224,15 +231,23 @@ impl Fragments { } return Ok(Some(Frame::new(true, frame.opcode, None, frame.payload))); } else { + if frame.payload.len() >= max_message_size { + return Err(WebSocketError::FrameTooLarge); + } self.fragments = match frame.opcode { OpCode::Text => match utf8::decode(&frame.payload) { - Ok(text) => Some(Fragment::Text(None, text.as_bytes().to_vec())), + Ok(text) => Some(Fragment::Text( + None, + text.as_bytes().to_vec(), + frame.payload.len(), + )), Err(utf8::DecodeError::Incomplete { valid_prefix, incomplete_suffix, }) => Some(Fragment::Text( Some(incomplete_suffix), valid_prefix.as_bytes().to_vec(), + frame.payload.len(), )), Err(utf8::DecodeError::Invalid { .. }) => { return Err(WebSocketError::InvalidUTF8); @@ -248,7 +263,15 @@ impl Fragments { None => { return Err(WebSocketError::InvalidContinuationFrame); } - Some(Fragment::Text(data, input)) => { + Some(Fragment::Text(data, input, message_len)) => { + let new_message_len = message_len + .checked_add(frame.payload.len()) + .ok_or(WebSocketError::FrameTooLarge)?; + if new_message_len >= max_message_size { + return Err(WebSocketError::FrameTooLarge); + } + *message_len = new_message_len; + let mut tail = &frame.payload[..]; if let Some(mut incomplete) = data.take() { if let Some((result, rest)) = @@ -296,6 +319,14 @@ impl Fragments { } } Some(Fragment::Binary(data)) => { + let message_len = data + .len() + .checked_add(frame.payload.len()) + .ok_or(WebSocketError::FrameTooLarge)?; + if message_len >= max_message_size { + return Err(WebSocketError::FrameTooLarge); + } + data.extend_from_slice(&frame.payload); if frame.fin { return Ok(Some(Frame::new( diff --git a/tests/fragment.rs b/tests/fragment.rs new file mode 100644 index 0000000..1217c05 --- /dev/null +++ b/tests/fragment.rs @@ -0,0 +1,81 @@ +use fastwebsockets::FragmentCollector; +#[cfg(feature = "unstable-split")] +use fastwebsockets::FragmentCollectorRead; +use fastwebsockets::Frame; +use fastwebsockets::OpCode; +use fastwebsockets::Role; +use fastwebsockets::WebSocket; +use fastwebsockets::WebSocketError; +use tokio::io::AsyncWriteExt; + +fn encoded_frames(mut frames: Vec>) -> Vec { + let mut out = Vec::new(); + let mut scratch = Vec::new(); + + for frame in &mut frames { + out.extend_from_slice(frame.write(&mut scratch)); + } + + out +} + +fn assert_frame_too_large(result: Result) { + assert!(matches!(result, Err(WebSocketError::FrameTooLarge))); +} + +#[tokio::test] +async fn fragment_collector_rejects_aggregate_binary_over_limit() { + let (mut peer, socket) = tokio::io::duplex(1024); + let mut ws = WebSocket::after_handshake(socket, Role::Client); + ws.set_max_message_size(9); + let mut ws = FragmentCollector::new(ws); + + let frames = encoded_frames(vec![ + Frame::new(false, OpCode::Binary, None, b"12345".to_vec().into()), + Frame::new(true, OpCode::Continuation, None, b"67890".to_vec().into()), + ]); + peer.write_all(&frames).await.unwrap(); + + assert_frame_too_large(ws.read_frame().await); +} + +#[tokio::test] +async fn fragment_collector_rejects_aggregate_text_over_limit() { + let (mut peer, socket) = tokio::io::duplex(1024); + let mut ws = WebSocket::after_handshake(socket, Role::Client); + ws.set_max_message_size(9); + let mut ws = FragmentCollector::new(ws); + + let frames = encoded_frames(vec![ + Frame::new(false, OpCode::Text, None, b"hello".to_vec().into()), + Frame::new(true, OpCode::Continuation, None, b"world".to_vec().into()), + ]); + peer.write_all(&frames).await.unwrap(); + + assert_frame_too_large(ws.read_frame().await); +} + +#[cfg(feature = "unstable-split")] +#[tokio::test] +async fn split_fragment_collector_rejects_aggregate_binary_over_limit() { + let (mut peer, socket) = tokio::io::duplex(1024); + let (read, _write) = tokio::io::split(socket); + let (mut ws_read, _ws_write) = fastwebsockets::after_handshake_split( + read, + tokio::io::sink(), + Role::Client, + ); + ws_read.set_max_message_size(9); + let mut ws = FragmentCollectorRead::new(ws_read); + + let frames = encoded_frames(vec![ + Frame::new(false, OpCode::Binary, None, b"12345".to_vec().into()), + Frame::new(true, OpCode::Continuation, None, b"67890".to_vec().into()), + ]); + peer.write_all(&frames).await.unwrap(); + + assert_frame_too_large( + ws.read_frame(&mut |_| async { Ok::<(), std::io::Error>(()) }) + .await, + ); +}