diff --git a/relay-server/src/utils/multipart.rs b/relay-server/src/utils/multipart.rs
index c4d0490944..1c0360880e 100644
--- a/relay-server/src/utils/multipart.rs
+++ b/relay-server/src/utils/multipart.rs
@@ -1,13 +1,14 @@
 use std::future::Future;
 use std::io;
-use std::task::Poll;
 
 use axum::extract::Request;
-use bytes::{Bytes, BytesMut};
-use futures::{StreamExt, TryStreamExt};
+use bytes::Bytes;
+use futures::TryStreamExt;
 use multer::{Constraints, Field, Multipart, SizeLimit};
 use relay_config::Config;
 use serde::{Deserialize, Serialize};
+use tokio::io::AsyncReadExt;
+use tokio_util::io::StreamReader;
 
 use crate::endpoints::common::BadStoreRequest;
 use crate::envelope::{AttachmentType, ContentType, Item, ItemType, Items};
@@ -185,9 +186,23 @@ pub async fn read_attachment_bytes_into_item(
     ignore_size_exceeded: bool,
 ) -> Result<Option<Item>, multer::Error> {
     let content_type = field.content_type().cloned();
-    let field = LimitedField::new(field, config.max_attachment_size());
-    match field.bytes().await {
-        Ok(bytes) => {
+    let field_name = field.name().map(String::from);
+    let limit = config.max_attachment_size();
+    // Extra byte needed to determine if limit was exceeded.
+    let mut take = StreamReader::new(field.map_err(io::Error::other)).take((limit + 1) as u64);
+    let mut buf = Vec::new();
+    match take.read_to_end(&mut buf).await {
+        Ok(_) if buf.len() > limit => {
+            if ignore_size_exceeded {
+                return Ok(None);
+            }
+            Err(multer::Error::FieldSizeExceeded {
+                limit: limit as u64,
+                field_name,
+            })
+        }
+        Ok(_) => {
+            let bytes = Bytes::from(buf);
             if let Some(content_type) = content_type {
                 let ct = content_type
                     .as_ref()
@@ -199,8 +214,7 @@ pub async fn read_attachment_bytes_into_item(
             }
             Ok(Some(item))
         }
-        Err(multer::Error::FieldSizeExceeded { .. }) if ignore_size_exceeded => Ok(None),
-        Err(err) => Err(err),
+        Err(io_err) => Err(multer::Error::StreamReadFailed(Box::new(io_err))),
     }
 }
 
@@ -254,76 +268,6 @@ pub async fn multipart_items(
     Ok(items)
 }
 
-/// Wrapper around `multer::Field` which consumes the entire underlying stream even when the
-/// size limit is exceeded.
-///
-/// The idea being that you can process fields in a multi-part form even if one fields is too large.
-struct LimitedField<'a> {
-    field: Field<'a>,
-    consumed_size: usize,
-    size_limit: usize,
-    inner_finished: bool,
-}
-
-impl<'a> LimitedField<'a> {
-    fn new(field: Field<'a>, limit: usize) -> Self {
-        LimitedField {
-            field,
-            consumed_size: 0,
-            size_limit: limit,
-            inner_finished: false,
-        }
-    }
-
-    async fn bytes(self) -> Result<Bytes, multer::Error> {
-        self.try_fold(BytesMut::new(), |mut acc, x| async move {
-            acc.extend_from_slice(&x);
-            Ok(acc)
-        })
-        .await
-        .map(|x| x.freeze())
-    }
-}
-
-impl futures::Stream for LimitedField<'_> {
-    type Item = Result<Bytes, multer::Error>;
-
-    fn poll_next(
-        mut self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Option<Self::Item>> {
-        if self.inner_finished {
-            return Poll::Ready(None);
-        }
-
-        match self.field.poll_next_unpin(cx) {
-            err @ Poll::Ready(Some(Err(_))) => err,
-            Poll::Ready(Some(Ok(t))) => {
-                self.consumed_size += t.len();
-                match self.consumed_size <= self.size_limit {
-                    true => Poll::Ready(Some(Ok(t))),
-                    false => {
-                        cx.waker().wake_by_ref();
-                        Poll::Pending
-                    }
-                }
-            }
-            Poll::Ready(None) if self.consumed_size > self.size_limit => {
-                self.inner_finished = true;
-                Poll::Ready(Some(Err(multer::Error::FieldSizeExceeded {
-                    limit: self.consumed_size as u64,
-                    field_name: self.field.name().map(Into::into),
-                })))
-            }
-            Poll::Ready(None) => {
-                self.inner_finished = true;
-                Poll::Ready(None)
-            }
-            Poll::Pending => Poll::Pending,
-        }
-    }
-}
-
 pub fn multipart_from_request(
     request: Request,
     stream_size_limit: usize,
diff --git a/tests/integration/test_playstation.py b/tests/integration/test_playstation.py
index 4f3fa456c3..ea9f2f10d6 100644
--- a/tests/integration/test_playstation.py
+++ b/tests/integration/test_playstation.py
@@ -352,7 +352,7 @@ def test_playstation_max_stream_size_exceeded(
     assert response.status_code == 400, "Expected a 400 status code"
     assert (
         response.content.decode("utf-8")
-        == f'{{"detail":"invalid multipart data","causes":["stream size exceeded limit: {stream_size_limit} bytes"]}}'
+        == f'{{"detail":"invalid multipart data","causes":["failed to read stream","stream size exceeded limit: {stream_size_limit} bytes"]}}'
     )
 
     assert len(outcomes_consumer.get_outcomes()) == 0
@@ -419,7 +419,7 @@ def test_playstation_user_data_extraction(
 
 
 @pytest.mark.parametrize("use_pop_relay", [True, False])
-def test_playstation_large_attachments(
+def test_playstation_upload_attachments(
     mini_sentry,
     relay_with_playstation,
     relay_processing_with_playstation,
@@ -485,6 +485,36 @@
     assert chunks[dump_attachment["id"]] == playstation_dump
 
 
+def test_playstation_ignore_large_attachments_when_uploading_disabled(
+    mini_sentry,
+    relay_with_playstation,
+):
+    PROJECT_ID = 42
+    config = playstation_project_config()
+    config["config"]["features"].remove("projects:relay-playstation-uploads")
+    mini_sentry.add_full_project_config(PROJECT_ID, extra=config)
+    playstation_dump = load_dump_file("user_data.prosperodmp")
+    relay = relay_with_playstation(
+        mini_sentry,
+        {
+            "limits": {
+                "max_attachment_size": len(playstation_dump),
+            },
+        },
+    )
+    # Make a dummy video that exceeds max_attachment_size
+    video_content = "1" * 1024 * 1024
+
+    response = relay.send_playstation_request(
+        PROJECT_ID, playstation_dump, video_content
+    )
+
+    assert response.ok
+    assert [
+        item.headers["filename"] for item in mini_sentry.get_captured_envelope().items
+    ] == ["playstation.prosperodmp"]
+
+
 def test_playstation_attachment(
     mini_sentry,
     relay_processing_with_playstation,
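Note on the replacement pattern: instead of hand-rolling a size-limited `Stream` (`LimitedField`), the new code wraps the field in `tokio_util::io::StreamReader` and reads through `AsyncReadExt::take(limit + 1)` — if the buffer ends up longer than `limit`, the extra byte proves the field exceeded the limit. Below is a minimal, self-contained sketch of that pattern; `read_limited` and `chunks` are hypothetical names for illustration, not part of this PR:

```rust
// Illustrative sketch only; assumes crates: bytes, futures,
// tokio (rt + macros), tokio-util (with the "io" feature).
use std::io;

use bytes::Bytes;
use futures::Stream;
use tokio::io::AsyncReadExt;
use tokio_util::io::StreamReader;

/// Reads at most `limit` bytes from `stream`.
/// Returns `Ok(None)` when the stream holds more than `limit` bytes.
async fn read_limited<S>(stream: S, limit: usize) -> io::Result<Option<Vec<u8>>>
where
    S: Stream<Item = io::Result<Bytes>> + Unpin,
{
    // Take one extra byte so `buf.len() > limit` distinguishes
    // "exactly at the limit" from "over the limit".
    let mut take = StreamReader::new(stream).take((limit + 1) as u64);
    let mut buf = Vec::new();
    take.read_to_end(&mut buf).await?;
    Ok((buf.len() <= limit).then_some(buf))
}

/// A hypothetical two-chunk stream totalling 11 bytes.
fn chunks() -> impl Stream<Item = io::Result<Bytes>> + Unpin {
    futures::stream::iter([
        Ok(Bytes::from_static(b"hello ")),
        Ok(Bytes::from_static(b"world")),
    ])
}

#[tokio::main]
async fn main() -> io::Result<()> {
    // 11 bytes exceed a limit of 5...
    assert!(read_limited(chunks(), 5).await?.is_none());
    // ...but fit exactly into a limit of 11.
    assert!(read_limited(chunks(), 11).await?.is_some());
    Ok(())
}
```

A side effect of `take(limit + 1)` is that at most `limit + 1` bytes are ever buffered, even if the incoming field is much larger. Stream errors surface as `io::Error`, which the PR wraps back into `multer::Error::StreamReadFailed` — hence the extra `"failed to read stream"` cause asserted in the updated integration test.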