Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions src/fragment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;

pub enum Fragment {
Text(Option<utf8::Incomplete>, Vec<u8>),
Text(Option<utf8::Incomplete>, Vec<u8>, usize),
Binary(Vec<u8>),
}

impl Fragment {
/// Returns the payload of the fragment.
fn take_buffer(self) -> Vec<u8> {
match self {
Fragment::Text(_, buffer) => buffer,
Fragment::Text(_, buffer, _) => buffer,
Fragment::Binary(buffer) => buffer,
}
}
Expand Down Expand Up @@ -118,7 +118,10 @@ impl<'f, S> FragmentCollector<S> {
if is_closed && frame.opcode != OpCode::Close {
return Err(WebSocketError::ConnectionClosed);
}
if let Some(frame) = self.fragments.accumulate(frame)? {
if let Some(frame) = self
.fragments
.accumulate(frame, self.read_half.max_message_size)?
{
return Ok(frame);
}
}
Expand Down Expand Up @@ -191,7 +194,10 @@ impl<'f, S> FragmentCollectorRead<S> {
let Some(frame) = res? else {
continue;
};
if let Some(frame) = self.fragments.accumulate(frame)? {
if let Some(frame) = self
.fragments
.accumulate(frame, self.read_half.max_message_size)?
{
return Ok(frame);
}
}
Expand All @@ -215,6 +221,7 @@ impl Fragments {
pub fn accumulate<'f>(
&mut self,
frame: Frame<'f>,
max_message_size: usize,
) -> Result<Option<Frame<'f>>, WebSocketError> {
match frame.opcode {
OpCode::Text | OpCode::Binary => {
Expand All @@ -224,15 +231,23 @@ impl Fragments {
}
return Ok(Some(Frame::new(true, frame.opcode, None, frame.payload)));
} else {
if frame.payload.len() >= max_message_size {
return Err(WebSocketError::FrameTooLarge);
}
self.fragments = match frame.opcode {
OpCode::Text => match utf8::decode(&frame.payload) {
Ok(text) => Some(Fragment::Text(None, text.as_bytes().to_vec())),
Ok(text) => Some(Fragment::Text(
None,
text.as_bytes().to_vec(),
frame.payload.len(),
)),
Err(utf8::DecodeError::Incomplete {
valid_prefix,
incomplete_suffix,
}) => Some(Fragment::Text(
Some(incomplete_suffix),
valid_prefix.as_bytes().to_vec(),
frame.payload.len(),
)),
Err(utf8::DecodeError::Invalid { .. }) => {
return Err(WebSocketError::InvalidUTF8);
Expand All @@ -248,7 +263,15 @@ impl Fragments {
None => {
return Err(WebSocketError::InvalidContinuationFrame);
}
Some(Fragment::Text(data, input)) => {
Some(Fragment::Text(data, input, message_len)) => {
let new_message_len = message_len
.checked_add(frame.payload.len())
.ok_or(WebSocketError::FrameTooLarge)?;
if new_message_len >= max_message_size {
return Err(WebSocketError::FrameTooLarge);
}
*message_len = new_message_len;

let mut tail = &frame.payload[..];
if let Some(mut incomplete) = data.take() {
if let Some((result, rest)) =
Expand Down Expand Up @@ -296,6 +319,14 @@ impl Fragments {
}
}
Some(Fragment::Binary(data)) => {
let message_len = data
.len()
.checked_add(frame.payload.len())
.ok_or(WebSocketError::FrameTooLarge)?;
if message_len >= max_message_size {
return Err(WebSocketError::FrameTooLarge);
}

data.extend_from_slice(&frame.payload);
if frame.fin {
return Ok(Some(Frame::new(
Expand Down
81 changes: 81 additions & 0 deletions tests/fragment.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use fastwebsockets::FragmentCollector;
#[cfg(feature = "unstable-split")]
use fastwebsockets::FragmentCollectorRead;
use fastwebsockets::Frame;
use fastwebsockets::OpCode;
use fastwebsockets::Role;
use fastwebsockets::WebSocket;
use fastwebsockets::WebSocketError;
use tokio::io::AsyncWriteExt;

fn encoded_frames(mut frames: Vec<Frame<'static>>) -> Vec<u8> {
let mut out = Vec::new();
let mut scratch = Vec::new();

for frame in &mut frames {
out.extend_from_slice(frame.write(&mut scratch));
}

out
}

fn assert_frame_too_large<T>(result: Result<T, WebSocketError>) {
assert!(matches!(result, Err(WebSocketError::FrameTooLarge)));
}

#[tokio::test]
async fn fragment_collector_rejects_aggregate_binary_over_limit() {
let (mut peer, socket) = tokio::io::duplex(1024);
let mut ws = WebSocket::after_handshake(socket, Role::Client);
ws.set_max_message_size(9);
let mut ws = FragmentCollector::new(ws);

let frames = encoded_frames(vec![
Frame::new(false, OpCode::Binary, None, b"12345".to_vec().into()),
Frame::new(true, OpCode::Continuation, None, b"67890".to_vec().into()),
]);
peer.write_all(&frames).await.unwrap();

assert_frame_too_large(ws.read_frame().await);
}

#[tokio::test]
async fn fragment_collector_rejects_aggregate_text_over_limit() {
let (mut peer, socket) = tokio::io::duplex(1024);
let mut ws = WebSocket::after_handshake(socket, Role::Client);
ws.set_max_message_size(9);
let mut ws = FragmentCollector::new(ws);

let frames = encoded_frames(vec![
Frame::new(false, OpCode::Text, None, b"hello".to_vec().into()),
Frame::new(true, OpCode::Continuation, None, b"world".to_vec().into()),
]);
peer.write_all(&frames).await.unwrap();

assert_frame_too_large(ws.read_frame().await);
}

#[cfg(feature = "unstable-split")]
#[tokio::test]
async fn split_fragment_collector_rejects_aggregate_binary_over_limit() {
let (mut peer, socket) = tokio::io::duplex(1024);
let (read, _write) = tokio::io::split(socket);
let (mut ws_read, _ws_write) = fastwebsockets::after_handshake_split(
read,
tokio::io::sink(),
Role::Client,
);
ws_read.set_max_message_size(9);
let mut ws = FragmentCollectorRead::new(ws_read);

let frames = encoded_frames(vec![
Frame::new(false, OpCode::Binary, None, b"12345".to_vec().into()),
Frame::new(true, OpCode::Continuation, None, b"67890".to_vec().into()),
]);
peer.write_all(&frames).await.unwrap();

assert_frame_too_large(
ws.read_frame(&mut |_| async { Ok::<(), std::io::Error>(()) })
.await,
);
}
Loading