Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion neqo-http3/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ pub struct RequestDescription<'b, T: RequestTarget> {
#[derive(Display)]
pub enum SessionAcceptAction {
Accept,
/// Accept the session and include additional headers in the 200 response.
AcceptWith(Vec<Header>),
Reject(Vec<Header>),
}

Expand Down Expand Up @@ -1334,11 +1336,18 @@ impl Http3Connection {
}
Ok(())
}
(Some(s), Some(_r), SessionAcceptAction::Accept) => {
(
Some(s),
Some(_r),
SessionAcceptAction::Accept | SessionAcceptAction::AcceptWith(_),
) => {
let mut response_headers = vec![Header::new(":status", "200")];
if connect_type == ExtendedConnectType::ConnectUdp {
response_headers.push(Header::new("capsule-protocol", "?1"));
}
if let SessionAcceptAction::AcceptWith(extra) = accept_res {
response_headers.extend_from_slice(extra);
}

if s.http_stream()
.ok_or(Error::InvalidStreamId)?
Expand Down Expand Up @@ -1400,6 +1409,23 @@ impl Http3Connection {
.collect()
}

/// Get the negotiated protocol for a WebTransport session.
///
/// Returns `Ok(None)` if no protocol was negotiated.
/// Returns an error if the session does not exist or is not a WebTransport session.

pub(crate) fn webtransport_session_protocol(
&self,
session_id: StreamId,
) -> Res<Option<String>> {
let stream = self
.send_streams
.get(&session_id)
.filter(|s| s.stream_type() == Http3StreamType::ExtendedConnect)
.ok_or(Error::InvalidStreamId)?;
Ok(stream.session_protocol())
}

pub(crate) fn connect_udp_close_session(
&mut self,
conn: &mut Connection,
Expand Down
16 changes: 16 additions & 0 deletions neqo-http3/src/connection_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1423,6 +1423,22 @@ impl Http3Client {
pub const fn webtransport_enabled(&self) -> bool {
self.base_handler.webtransport_enabled()
}

/// Get the negotiated subprotocol for a WebTransport session.
///
/// Returns the raw protocol value from the server's `wt-protocol` response header, or `None`
/// if the server did not include a `wt-protocol` header (or its value was malformed).
///
/// **Note:** this returns the server's raw value without validating it against the list of
/// protocols offered by the client. Callers are responsible for checking that the returned
/// protocol was among those originally offered.
///
/// # Errors
///
/// Returns error if the session ID is invalid.
pub fn webtransport_session_protocol(&self, session_id: StreamId) -> Res<Option<String>> {
self.base_handler.webtransport_session_protocol(session_id)
}
}

impl EventProvider for Http3Client {
Expand Down
31 changes: 31 additions & 0 deletions neqo-http3/src/features/extended_connect/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use std::{
time::Instant,
};

use sfv::{BareItem, Item, Parser};

use neqo_common::{Bytes, Encoder, Header, MessageType, Role, qdebug, qtrace};
use neqo_transport::{AppError, Connection, DatagramTracking, StreamId};

Expand Down Expand Up @@ -296,6 +298,22 @@ impl Session {
);
State::Done
} else {
let negotiated_protocol = headers
.iter()
.find(|h| h.name().eq_ignore_ascii_case("wt-protocol"))
.and_then(|h| Parser::new(h.value()).parse::<Item>().ok())
.and_then(|item| {
if let BareItem::String(s) = item.bare_item {
Some(s.into())
} else {
None
}
});

if let Some(protocol) = negotiated_protocol {
self.protocol.set_protocol(protocol);
}

self.events.session_start(
self.protocol.connect_type(),
self.id,
Expand Down Expand Up @@ -459,6 +477,10 @@ impl Stream for Rc<RefCell<Session>> {
fn stream_type(&self) -> Http3StreamType {
Http3StreamType::ExtendedConnect
}

fn session_protocol(&self) -> Option<String> {
self.borrow().protocol.protocol().map(|s| s.to_string())
}
}

impl RecvStream for Rc<RefCell<Session>> {
Expand Down Expand Up @@ -591,6 +613,15 @@ pub(crate) trait Protocol: Debug + Display {
(HashSet::default(), HashSet::default())
}

fn set_protocol(&mut self, _protocol: String) {
// Default implementation does nothing
}

fn protocol(&self) -> Option<&str> {
// Default implementation returns None
None
}

fn write_datagram_prefix(&self, encoder: &mut Encoder);

fn dgram_context_id(&self, datagram: Bytes) -> Result<Bytes, DgramContextIdError>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -517,3 +517,88 @@ fn wt_goaway_draining_rejected_session() {
"expected SessionClosed event for rejected session"
);
}

#[test]
fn wt_session_protocol_none_when_no_header() {
let mut wt = WtTest::new();
let wt_session = wt.create_wt_session();
let session_id = wt_session.stream_id();
assert_eq!(
wt.client.webtransport_session_protocol(session_id).unwrap(),
None
);
}

#[test]
fn wt_session_protocol_quoted_value() {
let mut wt = WtTest::new();
let (session_id, _) =
wt.negotiate_wt_session(&SessionAcceptAction::AcceptWith(vec![Header::new(
"wt-protocol",
r#""myproto""#,
)]));
assert_eq!(
wt.client.webtransport_session_protocol(session_id).unwrap(),
Some("myproto".to_string())
);
}

#[test]
fn wt_session_protocol_parameters_stripped() {
let mut wt = WtTest::new();
let (session_id, _) =
wt.negotiate_wt_session(&SessionAcceptAction::AcceptWith(vec![Header::new(
"wt-protocol",
r#""myproto"; foo=bar; baz=2"#,
)]));
assert_eq!(
wt.client.webtransport_session_protocol(session_id).unwrap(),
Some("myproto".to_string())
);
}

#[test]
fn wt_session_protocol_malformed_unquoted_rejected() {
let mut wt = WtTest::new();
let (session_id, _) =
wt.negotiate_wt_session(&SessionAcceptAction::AcceptWith(vec![Header::new(
"wt-protocol",
"myproto",
)]));
assert_eq!(
wt.client.webtransport_session_protocol(session_id).unwrap(),
None
);
}

#[test]
fn wt_session_protocol_invalid_stream_id() {
let wt = WtTest::new();
assert_eq!(
wt.client
.webtransport_session_protocol(StreamId::new(9999))
.unwrap_err(),
Error::InvalidStreamId
);
}

#[test]
fn wt_session_protocol_non_webtransport_session() {
let mut wt = WtTest::new();
let stream_id = wt
.client
.fetch(
now(),
"GET",
("https", "something.com", "/"),
&[],
Priority::default(),
)
.unwrap();
assert_eq!(
wt.client
.webtransport_session_protocol(stream_id)
.unwrap_err(),
Error::InvalidStreamId
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ pub struct Session {
///
/// [`HashSet`] size limited by QUIC connection stream limit.
pending_streams: HashSet<StreamId>,
/// The negotiated protocol from server response headers.
negotiated_protocol: Option<String>,
}

impl Display for Session {
Expand All @@ -52,6 +54,7 @@ impl Session {
recv_streams: HashSet::default(),
role,
pending_streams: HashSet::default(),
negotiated_protocol: None,
}
}
}
Expand Down Expand Up @@ -201,6 +204,14 @@ impl Protocol for Session {
)
}

fn set_protocol(&mut self, protocol: String) {
self.negotiated_protocol = Some(protocol);
}

fn protocol(&self) -> Option<&str> {
self.negotiated_protocol.as_deref()
}

fn write_datagram_prefix(&self, _encoder: &mut Encoder) {
// WebTransport does not add prefix (i.e. context ID).
}
Expand Down
4 changes: 4 additions & 0 deletions neqo-http3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,10 @@ enum ReceiveOutput {

trait Stream: Debug {
fn stream_type(&self) -> Http3StreamType;

fn session_protocol(&self) -> Option<String> {
None
}
}

trait RecvStream: Stream {
Expand Down
Loading