diff --git a/neqo-http3/src/connection.rs b/neqo-http3/src/connection.rs index 581b3d6bfb..b00b58cc10 100644 --- a/neqo-http3/src/connection.rs +++ b/neqo-http3/src/connection.rs @@ -60,6 +60,8 @@ pub struct RequestDescription<'b, T: RequestTarget> { #[derive(Display)] pub enum SessionAcceptAction { Accept, + /// Accept the session and include additional headers in the 200 response. + AcceptWith(Vec
), Reject(Vec
), } @@ -1334,11 +1336,18 @@ impl Http3Connection { } Ok(()) } - (Some(s), Some(_r), SessionAcceptAction::Accept) => { + ( + Some(s), + Some(_r), + SessionAcceptAction::Accept | SessionAcceptAction::AcceptWith(_), + ) => { let mut response_headers = vec![Header::new(":status", "200")]; if connect_type == ExtendedConnectType::ConnectUdp { response_headers.push(Header::new("capsule-protocol", "?1")); } + if let SessionAcceptAction::AcceptWith(extra) = accept_res { + response_headers.extend_from_slice(extra); + } if s.http_stream() .ok_or(Error::InvalidStreamId)? @@ -1400,6 +1409,23 @@ impl Http3Connection { .collect() } + /// Get the negotiated protocol for a WebTransport session. + /// + /// Returns `Ok(None)` if no protocol was negotiated. + /// Returns an error if the session does not exist or is not a WebTransport session. + + pub(crate) fn webtransport_session_protocol( + &self, + session_id: StreamId, + ) -> Res> { + let stream = self + .send_streams + .get(&session_id) + .filter(|s| s.stream_type() == Http3StreamType::ExtendedConnect) + .ok_or(Error::InvalidStreamId)?; + Ok(stream.session_protocol()) + } + pub(crate) fn connect_udp_close_session( &mut self, conn: &mut Connection, diff --git a/neqo-http3/src/connection_client.rs b/neqo-http3/src/connection_client.rs index 8dbc1331c9..1a6e2f5517 100644 --- a/neqo-http3/src/connection_client.rs +++ b/neqo-http3/src/connection_client.rs @@ -1423,6 +1423,22 @@ impl Http3Client { pub const fn webtransport_enabled(&self) -> bool { self.base_handler.webtransport_enabled() } + + /// Get the negotiated subprotocol for a WebTransport session. + /// + /// Returns the raw protocol value from the server's `wt-protocol` response header, or `None` + /// if the server did not include a `wt-protocol` header (or its value was malformed). + /// + /// **Note:** this returns the server's raw value without validating it against the list of + /// protocols offered by the client. Callers are responsible for checking that the returned + /// protocol was among those originally offered. + /// + /// # Errors + /// + /// Returns error if the session ID is invalid. + pub fn webtransport_session_protocol(&self, session_id: StreamId) -> Res> { + self.base_handler.webtransport_session_protocol(session_id) + } } impl EventProvider for Http3Client { diff --git a/neqo-http3/src/features/extended_connect/session.rs b/neqo-http3/src/features/extended_connect/session.rs index 2c03e580e8..c43a784668 100644 --- a/neqo-http3/src/features/extended_connect/session.rs +++ b/neqo-http3/src/features/extended_connect/session.rs @@ -13,6 +13,8 @@ use std::{ time::Instant, }; +use sfv::{BareItem, Item, Parser}; + use neqo_common::{Bytes, Encoder, Header, MessageType, Role, qdebug, qtrace}; use neqo_transport::{AppError, Connection, DatagramTracking, StreamId}; @@ -296,6 +298,22 @@ impl Session { ); State::Done } else { + let negotiated_protocol = headers + .iter() + .find(|h| h.name().eq_ignore_ascii_case("wt-protocol")) + .and_then(|h| Parser::new(h.value()).parse::().ok()) + .and_then(|item| { + if let BareItem::String(s) = item.bare_item { + Some(s.into()) + } else { + None + } + }); + + if let Some(protocol) = negotiated_protocol { + self.protocol.set_protocol(protocol); + } + self.events.session_start( self.protocol.connect_type(), self.id, @@ -459,6 +477,10 @@ impl Stream for Rc> { fn stream_type(&self) -> Http3StreamType { Http3StreamType::ExtendedConnect } + + fn session_protocol(&self) -> Option { + self.borrow().protocol.protocol().map(|s| s.to_string()) + } } impl RecvStream for Rc> { @@ -591,6 +613,15 @@ pub(crate) trait Protocol: Debug + Display { (HashSet::default(), HashSet::default()) } + fn set_protocol(&mut self, _protocol: String) { + // Default implementation does nothing + } + + fn protocol(&self) -> Option<&str> { + // Default implementation returns None + None + } + fn write_datagram_prefix(&self, encoder: &mut Encoder); fn dgram_context_id(&self, datagram: Bytes) -> Result; diff --git a/neqo-http3/src/features/extended_connect/tests/webtransport/sessions.rs b/neqo-http3/src/features/extended_connect/tests/webtransport/sessions.rs index 2fb9429034..0b505a0451 100644 --- a/neqo-http3/src/features/extended_connect/tests/webtransport/sessions.rs +++ b/neqo-http3/src/features/extended_connect/tests/webtransport/sessions.rs @@ -517,3 +517,88 @@ fn wt_goaway_draining_rejected_session() { "expected SessionClosed event for rejected session" ); } + +#[test] +fn wt_session_protocol_none_when_no_header() { + let mut wt = WtTest::new(); + let wt_session = wt.create_wt_session(); + let session_id = wt_session.stream_id(); + assert_eq!( + wt.client.webtransport_session_protocol(session_id).unwrap(), + None + ); +} + +#[test] +fn wt_session_protocol_quoted_value() { + let mut wt = WtTest::new(); + let (session_id, _) = + wt.negotiate_wt_session(&SessionAcceptAction::AcceptWith(vec![Header::new( + "wt-protocol", + r#""myproto""#, + )])); + assert_eq!( + wt.client.webtransport_session_protocol(session_id).unwrap(), + Some("myproto".to_string()) + ); +} + +#[test] +fn wt_session_protocol_parameters_stripped() { + let mut wt = WtTest::new(); + let (session_id, _) = + wt.negotiate_wt_session(&SessionAcceptAction::AcceptWith(vec![Header::new( + "wt-protocol", + r#""myproto"; foo=bar; baz=2"#, + )])); + assert_eq!( + wt.client.webtransport_session_protocol(session_id).unwrap(), + Some("myproto".to_string()) + ); +} + +#[test] +fn wt_session_protocol_malformed_unquoted_rejected() { + let mut wt = WtTest::new(); + let (session_id, _) = + wt.negotiate_wt_session(&SessionAcceptAction::AcceptWith(vec![Header::new( + "wt-protocol", + "myproto", + )])); + assert_eq!( + wt.client.webtransport_session_protocol(session_id).unwrap(), + None + ); +} + +#[test] +fn wt_session_protocol_invalid_stream_id() { + let wt = WtTest::new(); + assert_eq!( + wt.client + .webtransport_session_protocol(StreamId::new(9999)) + .unwrap_err(), + Error::InvalidStreamId + ); +} + +#[test] +fn wt_session_protocol_non_webtransport_session() { + let mut wt = WtTest::new(); + let stream_id = wt + .client + .fetch( + now(), + "GET", + ("https", "something.com", "/"), + &[], + Priority::default(), + ) + .unwrap(); + assert_eq!( + wt.client + .webtransport_session_protocol(stream_id) + .unwrap_err(), + Error::InvalidStreamId + ); +} diff --git a/neqo-http3/src/features/extended_connect/webtransport_session.rs b/neqo-http3/src/features/extended_connect/webtransport_session.rs index b41c93d4b6..2646a76c58 100644 --- a/neqo-http3/src/features/extended_connect/webtransport_session.rs +++ b/neqo-http3/src/features/extended_connect/webtransport_session.rs @@ -34,6 +34,8 @@ pub struct Session { /// /// [`HashSet`] size limited by QUIC connection stream limit. pending_streams: HashSet, + /// The negotiated protocol from server response headers. + negotiated_protocol: Option, } impl Display for Session { @@ -52,6 +54,7 @@ impl Session { recv_streams: HashSet::default(), role, pending_streams: HashSet::default(), + negotiated_protocol: None, } } } @@ -201,6 +204,14 @@ impl Protocol for Session { ) } + fn set_protocol(&mut self, protocol: String) { + self.negotiated_protocol = Some(protocol); + } + + fn protocol(&self) -> Option<&str> { + self.negotiated_protocol.as_deref() + } + fn write_datagram_prefix(&self, _encoder: &mut Encoder) { // WebTransport does not add prefix (i.e. context ID). } diff --git a/neqo-http3/src/lib.rs b/neqo-http3/src/lib.rs index 9c9b0f955e..268494753d 100644 --- a/neqo-http3/src/lib.rs +++ b/neqo-http3/src/lib.rs @@ -433,6 +433,10 @@ enum ReceiveOutput { trait Stream: Debug { fn stream_type(&self) -> Http3StreamType; + + fn session_protocol(&self) -> Option { + None + } } trait RecvStream: Stream {