diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5b53a26..1e0217b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -30,7 +30,7 @@ jobs: with: components: clippy - uses: Swatinem/rust-cache@v2 - - run: cargo clippy --no-default-features --features hyper -- -D warnings + - run: cargo clippy --all-features --all-targets -- -D warnings doc: name: Documentation @@ -41,7 +41,9 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 - - run: cargo doc --no-deps --no-default-features --features hyper + # rpc/ipc rustdoc needs a consumer-selected remoc codec and a separate + # docs-link cleanup; clippy, tests, and coverage still exercise all features. + - run: cargo doc --no-deps --no-default-features --features dquic,hyper,serde,testing,webtransport test: name: Tests @@ -50,4 +52,39 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 - - run: cargo test --no-default-features --lib + - run: cargo test --all-features --all-targets + + coverage: + name: Coverage + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + - name: Install cargo-llvm-cov + run: cargo install cargo-llvm-cov --locked + - name: Generate coverage summary + run: | + cargo llvm-cov clean + mkdir -p target/llvm-cov + cargo llvm-cov \ + --all-features \ + --all-targets \ + --summary-only \ + --json \ + --output-path target/llvm-cov/coverage.json \ + --fail-under-lines 62 \ + --fail-under-functions 56 \ + --fail-under-regions 64 + - name: Generate LCOV report + run: cargo llvm-cov report --lcov --output-path target/llvm-cov/lcov.info + - name: Generate HTML report + run: cargo llvm-cov report --html --output-dir target/llvm-cov + - name: Upload coverage reports + uses: actions/upload-artifact@v4 + with: + name: h3x-coverage + path: | + target/llvm-cov/coverage.json + target/llvm-cov/lcov.info + target/llvm-cov/html diff --git a/.github/workflows/publish-crates.yml b/.github/workflows/publish-crates.yml new file mode 100644 index 0000000..68037a1 --- /dev/null +++ b/.github/workflows/publish-crates.yml @@ -0,0 +1,53 @@ +name: Publish crates.io + +on: + pull_request: + workflow_dispatch: + push: + branches: + - main + tags: + - "v*" + +env: + CARGO_TERM_COLOR: always + +jobs: + release: + runs-on: ubuntu-latest + permissions: + contents: read + id-token: write + + steps: + - uses: actions/checkout@v4 + + - name: Install Rust stable toolchain + uses: actions-rust-lang/setup-rust-toolchain@v1 + + - name: Test crate + run: cargo test --all-features --all-targets + + - name: Authenticate to crates.io + if: github.ref_type == 'tag' && startsWith(github.ref_name, 'v') + uses: rust-lang/crates-io-auth-action@v1 + id: auth + + - name: Release h3x crate + shell: bash + env: + CARGO_REGISTRY_TOKEN: ${{ steps.auth.outputs.token }} + run: | + set -euo pipefail + + if [[ "${GITHUB_REF_TYPE}" == "tag" && "${GITHUB_REF_NAME}" == v* ]]; then + mode=publish + else + mode=dry-run + fi + + if [[ "$mode" == "dry-run" ]]; then + cargo publish --dry-run + else + cargo publish + fi diff --git a/.gitignore b/.gitignore index 397ddce..54cf474 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ Cargo.lock /target AGENTS.md .sisyphus/ +test_locations diff --git a/Cargo.toml b/Cargo.toml index 68f6026..85b883f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "h3x" -description = "High-performance zero-copy DHTTP/3 implementation" -version = "0.2.0" +description = "Peer-to-peer DHTTP/3 transport over QUIC" +version = "0.3.0" edition = "2024" readme = "README.md" repository = "https://github.com/genmeta/h3x" @@ -26,12 +26,12 @@ snafu = "0.9" tokio = { version = "1", features = ["rt", "macros", "sync", "io-util"] } tokio-util = { version = "0.7", features = ["codec", "io", "rt"] } tracing = "0.1" -matchit = "0.9" -x509-parser = "0.18" tower-service = "0.3" +dhttp-identity = { git = "https://github.com/genmeta/dhttp.git", branch = "publish", version = "0.1.0" } + # feature dquic -dquic = { git = "ssh://git@github.com/genmeta/dquic.git", branch = "feat/v0.5.1", optional = true } +dquic = { git = "https://github.com/genmeta/dquic.git", tag = "v0.5.1", version = "0.5.1", optional = true } # feature hyper http-body = { version = "1", optional = true } @@ -50,7 +50,7 @@ remoc = { version = "0.18", default-features = false, optional = true, features ] } # feature pipe -nix = { version = "0.29", default-features = false, optional = true, features = [ +nix = { version = "0.31", default-features = false, optional = true, features = [ "fs", "socket", "uio", @@ -60,25 +60,16 @@ smallvec = { version = "1", optional = true } # feature endpoint arc-swap = { version = "1", optional = true } async-channel = { version = "2", optional = true } -derive_more = { version = "2", optional = true, features = [ - "deref", - "deref_mut", - "from", - "into", -] } either = { version = "1", optional = true } globset = { version = "0.4", optional = true } peg = { version = "0.8", optional = true } - [features] -default = ["dquic", "hyper", "endpoint"] -dquic = ["dep:dquic"] -endpoint = [ - "dquic", +default = ["dquic", "hyper"] +dquic = [ + "dep:dquic", "dep:arc-swap", "dep:async-channel", - "dep:derive_more", "dep:either", "dep:globset", "dep:peg", @@ -86,25 +77,15 @@ endpoint = [ ] hyper = ["dep:http-body", "dep:http-body-util", "dep:hyper"] serde = ["dep:serde"] +# EXPERIMENTAL — consumer crates must also enable exactly one remoc default-codec-* feature. +# h3x tests select default-codec-bincode in [dev-dependencies] for all-features CI. rpc = ["dep:remoc", "serde", "bytes/serde"] +# EXPERIMENTAL — depends on rpc and the same consumer-selected remoc codec. ipc = ["rpc", "dep:nix", "dep:smallvec", "tokio/net"] webtransport = [] testing = [] - [dev-dependencies] -axum = { version = "0.8", default-features = false, features = [ - "form", # (default) - # "http1", # (default) - "json", # (default) - "matched-path", # (default) - "original-uri", # (default) - "query", # (default) - "tokio", # (default) - "tower-log", # (default) - "tracing", # (default) - "macros", # for debug_handler -] } rustls = { version = "0.23", default-features = false, features = [ "ring", "logging", @@ -118,4 +99,7 @@ tokio = { version = "1", features = [ tracing-appender = "0.2" tracing-subscriber = "0.3" proptest = "1" -# remoc = { version = "0.18", features = ["default-codec-bincode"] } +remoc = { version = "0.18", default-features = false, features = [ + "default-codec-bincode", +] } +serde_json = "1" diff --git a/README.md b/README.md index 209178d..eb68dc3 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,13 @@ [![License: Apache-2.0](https://img.shields.io/github/license/genmeta/h3x)](https://www.apache.org/licenses/LICENSE-2.0) -High-performance asynchronous DHTTP/3 implementation in Rust. +Peer-to-peer DHTTP/3 transport over QUIC, implemented in Rust. - **Peer-to-Peer**: Extends [HTTP3(RFC9114)](https://datatracker.ietf.org/doc/html/rfc9114) to *DHTTP/3*, allowing both sides of the connection to initiate and handle HTTP3 requests (achieved by disabling server push). - **Asynchronous I/O**: Built on the Rust asynchronous ecosystem, providing high-performance I/O processing capabilities. - **Zero-Copy**: Achieves full-link *zero-copy* from the QUIC layer to the application layer. - **Multipath QUIC**: Integrates the `dquic` implementation, featuring efficient transmission, robust authentication capabilities, and high extensibility. -- **Hyper / Tower Compatibility** *(feature `hyper`, enabled by default)*: Provides `TowerService` and `HyperService` adapters to run existing Tower or hyper services (e.g. `axum`) over DHTTP/3. Since h3x cannot construct hyper's internal types, the `h3x::hyper` module provides its own alternatives for upgrade and protocol negotiation. -- **Remoc** *(feature `remoc`, experimental)*: Optional [`remoc`](https://crates.io/crates/remoc) integration for remote trait calls (RTC) over QUIC connections. This is an experimental feature and the API may change. +- **Hyper / Tower Compatibility** *(feature `hyper`, enabled by default)*: Provides a single-file `h3x::hyper` facade for hyper-facing integrations. The facade exposes `TowerService`, `HyperService`, request execution errors, upgrade/takeover helpers, Extended CONNECT helpers, and protocol extension helpers. Lower-level hyper adapters also remain available from their semantic owners: `dhttp::message::hyper`, `endpoint::hyper`, `qpack::field::hyper`, and `extended_connect::hyper`. +- **RPC / IPC** *(features `rpc` and `ipc`, experimental)*: Optional [`remoc`](https://crates.io/crates/remoc) integration for remote trait calls (RTC) over QUIC connections, with optional IPC transport support. Consumers must select exactly one `remoc/default-codec-*` feature; h3x tests use the bincode codec through a dev-dependency. - **Extended CONNECT**: Supports [Extended CONNECT (RFC9220)](https://datatracker.ietf.org/doc/html/rfc9220) for protocol tunneling over HTTP/3. - **Future Extensions**: Plans to support extensions such as [WebTransport over HTTP/3 (Draft)](https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3-14). @@ -15,67 +15,51 @@ High-performance asynchronous DHTTP/3 implementation in Rust. > ⚠️ Currently, h3x is in the early stages of development, and the API may undergo significant changes. -h3x includes `dquic` as its built-in QUIC backend (feature `dquic`, enabled by default). The `h3x::dquic` module exposes wrapped types for HTTP/3 transport over `dquic`. +h3x includes `dquic` as its built-in QUIC backend (feature `dquic`, enabled by default). Wrap a `QuicEndpoint` in an `H3Endpoint` to get HTTP/3 client and server semantics on top of QUIC. -```rust -use h3x::dquic::{ - H3Client, H3Servers, - prelude::{BindUri, handy::ToCertificate}, -}; +```rust,no_run +use h3x::{dquic::QuicEndpoint, endpoint::H3Endpoint}; -async fn client_example() -> Result<(), Box> { - let mut roots = rustls::RootCertStore::empty(); - roots.add_parsable_certificates( - include_bytes!("tests/keychain/localhost/ca.cert").to_certificate(), - ); - let h3_client = H3Client::builder() - .with_root_certificates(roots) - .without_identity()? - .build(); - - // Initiate GET request - // The request stream is automatically closed when dropped - let (_, mut response) = h3_client - .new_request() - .get("localhost:4433/hello_world".parse()?) - .await?; - - // Check response status code - assert_eq!(response.status(), http::StatusCode::OK); - - let text = response.read_to_string().await?; - println!("Response: {:?}", text); +#[tokio::main] +async fn main() -> Result<(), Box> { + let endpoint = H3Endpoint::new(QuicEndpoint::new().await); + let mut response = endpoint.get("https://example.com:4433/hello".parse()?).await?; + assert_eq!(response.status(), http::StatusCode::OK); + println!("{}", response.read_to_string().await?); Ok(()) } +``` -async fn server_example() -> Result<(), Box> { - let mut app = H3Servers::builder() - .without_client_cert_verifier()? - .listen()?; - - let hello_world = async |request: &mut h3x::server::Request, - response: &mut h3x::server::Response| { - response - .set_status(http::StatusCode::OK) - .set_body(&b"Hello, World!"[..]); - }; - - app.add_server( - "localhost", - include_bytes!("tests/keychain/localhost/server.cert"), - include_bytes!("tests/keychain/localhost/server.key"), - None, - [BindUri::from("inet://[::1]:4433")], - h3x::server::Router::new().get("/hello_world", hello_world), - ) - .await?; +```rust,no_run +use std::sync::Arc; +use axum::{Router as AxumRouter, routing::get}; +use h3x::{ + dquic::{Identity, Name, QuicEndpoint}, + endpoint::H3Endpoint, + hyper::TowerService, +}; - app.run().await; +#[tokio::main] +async fn main() -> Result<(), Box> { + let identity = Arc::new(Identity { + name: Name::from_static("localhost")?, + certs: todo!("load your certificate chain"), + key: todo!("load your private key"), + ocsp: Arc::new(None), + }); + let endpoint = H3Endpoint::new( + QuicEndpoint::builder() + .identity(identity) + .bind(Arc::new(vec!["127.0.0.1:4433".parse()?])) + .build().await, + ); + let router = AxumRouter::new().route("/hello", get(|| async { "Hello, World!" })); + let service = TowerService(router.into_service()); + endpoint.listen(service).await?; Ok(()) } - ``` #### Hyper / Tower Integration @@ -89,60 +73,92 @@ h3x provides adapters to bridge the Tower / hyper service ecosystem into DHTTP/3 > - `h3x::hyper::upgrade` — stream takeover for Extended CONNECT tunnels (instead of `hyper::upgrade`) > - `h3x::hyper::ext::Protocol` — protocol indication in CONNECT requests (instead of `hyper::ext::Protocol`) -```rust -use axum::{Router, body::Body, routing::get}; +```rust,no_run +use std::sync::Arc; +use axum::{Router as AxumRouter, routing::get}; use h3x::{ - dquic::{H3Client, H3Servers, prelude::{BindUri, handy::ToCertificate}}, - hyper::server::TowerService, + dquic::{Identity, Name, QuicEndpoint}, + endpoint::H3Endpoint, + hyper::TowerService, }; -async fn serve_axum_over_dhttp3() -> Result<(), Box> { - // Build a standard Tower service — here an axum Router - let router = Router::new() +#[tokio::main] +async fn main() -> Result<(), Box> { + let router = AxumRouter::new() .route("/hello", get(|| async { "Hello from DHTTP/3!" })); - - // Wrap it with TowerService to bridge into h3x let service = TowerService(router.into_service()); - let mut app = H3Servers::builder() - .without_client_cert_verifier()? - .listen()?; - - app.add_server( - "localhost", - include_bytes!("tests/keychain/localhost/server.cert"), - include_bytes!("tests/keychain/localhost/server.key"), - None, - [BindUri::from("inet://[::1]:4433")], - service, + let identity = Arc::new(Identity { + name: Name::from_static("localhost")?, + certs: todo!("load your certificate chain"), + key: todo!("load your private key"), + ocsp: Arc::new(None), + }); + H3Endpoint::new( + QuicEndpoint::builder() + .identity(identity) + .bind(Arc::new(vec!["127.0.0.1:4433".parse()?])) + .build().await, ) + .listen(service) .await?; - - app.run().await; Ok(()) } +``` -// Client side — execute requests with hyper Body types -async fn hyper_client_example() -> Result<(), Box> { - let mut roots = rustls::RootCertStore::empty(); - roots.add_parsable_certificates( - include_bytes!("tests/keychain/localhost/ca.cert").to_certificate(), - ); - let h3_client = H3Client::builder() - .with_root_certificates(roots) - .without_identity()? - .build(); - - let connection = h3_client.connect("localhost:4433".parse()?).await?; +```rust,no_run +use h3x::{dquic::QuicEndpoint, endpoint::H3Endpoint}; - let response = connection - .execute_hyper_request( - http::Request::get("https://localhost:4433/hello") - .body(Body::empty())?, - ) - .await?; +#[tokio::main] +async fn main() -> Result<(), Box> { + let endpoint = H3Endpoint::new(QuicEndpoint::new().await); + let mut response = endpoint.get("https://example.com:4433/hello".parse()?).await?; assert_eq!(response.status(), http::StatusCode::OK); + let body = response.read_to_bytes().await?; + println!("{}", String::from_utf8_lossy(&body)); Ok(()) } -``` \ No newline at end of file +``` + +### Testing and Coverage + +h3x uses one canonical test entrypoint for local development and CI: + +```bash +cargo test --all-features --all-targets +``` + +All feature-gated paths, including `rpc`, `ipc`, `webtransport`, `dquic`, and `hyper`, must compile and test through that command. The dev-dependency on `remoc` selects the bincode default codec for tests so `rpc`/`ipc` can compile under `--all-features`. + +Documentation is checked with warnings denied for all non-`rpc`/`ipc` features: + +```bash +RUSTDOCFLAGS="-D warnings" cargo doc --no-deps --no-default-features --features dquic,hyper,serde,testing,webtransport +``` + +Coverage uses `cargo-llvm-cov` directly: + +```bash +cargo llvm-cov clean +mkdir -p target/llvm-cov +cargo llvm-cov \ + --all-features \ + --all-targets \ + --summary-only \ + --json \ + --output-path target/llvm-cov/coverage.json \ + --fail-under-lines 62 \ + --fail-under-functions 56 \ + --fail-under-regions 64 +cargo llvm-cov report --lcov --output-path target/llvm-cov/lcov.info +cargo llvm-cov report --html --output-dir target/llvm-cov +``` + +The HTML report entrypoint is: + +```text +target/llvm-cov/html/index.html +``` + +The initial coverage gate preserves the current baseline. Raise the gate only in commits that add meaningful tests or remove dead code. diff --git a/src/buflist.rs b/src/buflist.rs index eef7866..8da70cf 100644 --- a/src/buflist.rs +++ b/src/buflist.rs @@ -24,6 +24,13 @@ impl BufList { } } + /// Copy all remaining bytes from a [`Buf`] into a new [`BufList`]. + pub fn from_buf(buf: impl Buf) -> Self { + let mut buflist = Self::new(); + buflist.write(buf); + buflist + } + pub fn write(&mut self, mut buf: impl Buf) { while buf.has_remaining() { self.bufs.push_back(buf.copy_to_bytes(buf.chunk().len())); @@ -327,9 +334,13 @@ impl Extend for BuflistCursor { #[cfg(test)] mod tests { - use bytes::{Buf, Bytes}; + use std::io::IoSlice; - use super::BufList; + use bytes::{Buf, Bytes, BytesMut}; + use futures::{SinkExt, StreamExt}; + use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt}; + + use super::{BufList, BuflistCursor}; #[test] fn empty_buflist() { @@ -348,6 +359,14 @@ mod tests { assert_eq!(bl.chunk(), b"hello"); } + #[test] + fn from_buf_copies_all_remaining_data() { + let source = Bytes::from_static(b"hello world").slice(6..); + let bl = BufList::from_buf(source); + assert_eq!(bl.remaining(), 5); + assert_eq!(bl.chunk(), b"world"); + } + #[test] fn multiple_writes() { let mut bl = BufList::new(); @@ -422,6 +441,80 @@ mod tests { assert_eq!(bl.remaining(), 3); } + #[test] + fn chunks_vectored_reports_front_buffers() { + let bl: BufList = vec![ + Bytes::from_static(b"ab"), + Bytes::from_static(b"cd"), + Bytes::from_static(b"ef"), + ] + .into_iter() + .collect(); + let mut slices = [ + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + ]; + + let filled = bl.chunks_vectored(&mut slices); + + assert_eq!(filled, 3); + assert_eq!(&*slices[0], b"ab"); + assert_eq!(&*slices[1], b"cd"); + assert_eq!(&*slices[2], b"ef"); + } + + #[tokio::test] + async fn async_write_merges_unique_tail_and_flushes() { + let mut bl = BufList::new(); + bl.write(BytesMut::from(&b"hello"[..]).freeze()); + + bl.write_all(b" world").await.unwrap(); + AsyncWriteExt::flush(&mut bl).await.unwrap(); + bl.shutdown().await.unwrap(); + + assert_eq!(bl.bufs.len(), 1); + assert_eq!(bl.copy_to_bytes(11), Bytes::from_static(b"hello world")); + } + + #[tokio::test] + async fn async_write_preserves_shared_tail_before_appending() { + let mut bl = BufList::new(); + bl.write(Bytes::from_static(b"hello")); + + bl.write_all(b"!").await.unwrap(); + + assert_eq!(bl.bufs.len(), 2); + assert_eq!(bl.copy_to_bytes(6), Bytes::from_static(b"hello!")); + } + + #[tokio::test] + async fn sink_async_read_bufread_and_stream_delegations() { + let mut bl = BufList::new(); + bl.send(Bytes::from_static(b"ab")).await.unwrap(); + bl.feed(Bytes::from_static(b"cd")).await.unwrap(); + SinkExt::flush(&mut bl).await.unwrap(); + + let mut read = [0; 3]; + let read_len = bl.read(&mut read).await.unwrap(); + assert_eq!(read_len, 2); + assert_eq!(&read[..read_len], b"ab"); + + let filled = bl.fill_buf().await.unwrap(); + assert_eq!(filled, b"cd"); + bl.consume(1); + assert_eq!(bl.fill_buf().await.unwrap(), b"d"); + bl.consume(1); + assert!(bl.fill_buf().await.unwrap().is_empty()); + + bl.send(Bytes::from_static(b"ef")).await.unwrap(); + assert_eq!(bl.next().await, Some(Bytes::from_static(b"ef"))); + assert_eq!(bl.next().await, None); + + bl.close().await.unwrap(); + } + #[test] fn sequential_reads() { let mut bl = BufList::new(); @@ -448,4 +541,104 @@ mod tests { bl.write(Bytes::from_static(b"ab")); bl.advance(3); } + + fn cursor_source() -> BuflistCursor { + BuflistCursor::new( + vec![ + Bytes::from_static(b"ab"), + Bytes::from_static(b"cd"), + Bytes::from_static(b"ef"), + ] + .into_iter() + .collect(), + ) + } + + #[test] + fn cursor_reset_commit_inner_and_iteration_track_offsets() { + let mut cursor = cursor_source(); + + cursor.advance(3); + assert_eq!(cursor.chunk(), b"d"); + assert_eq!( + cursor.iter().collect::>(), + vec![Bytes::from_static(b"d"), Bytes::from_static(b"ef")] + ); + + cursor.reset(); + assert_eq!(cursor.chunk(), b"ab"); + assert_eq!(cursor.remaining(), 6); + + cursor.advance(3); + cursor.commit(); + + assert_eq!(cursor.inner().remaining(), 3); + assert_eq!(cursor.chunk(), b"d"); + assert_eq!(cursor.copy_to_bytes(3), Bytes::from_static(b"def")); + } + + #[test] + fn cursor_write_extend_and_empty_iteration() { + let mut cursor = BuflistCursor::new(BufList::new()); + + assert!(!cursor.has_remaining()); + assert_eq!(cursor.iter().collect::>(), vec![Bytes::new()]); + + cursor.write(Bytes::from_static(b"ab")); + cursor.extend([Bytes::from_static(b"cd")]); + + assert_eq!(cursor.inner().remaining(), 4); + assert_eq!(cursor.copy_to_bytes(4), Bytes::from_static(b"abcd")); + } + + #[test] + fn cursor_chunks_vectored_start_from_current_offset() { + let mut cursor = cursor_source(); + cursor.advance(1); + let mut slices = [ + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + IoSlice::new(&[]), + ]; + + let filled = cursor.chunks_vectored(&mut slices); + + assert_eq!(filled, 3); + assert_eq!(&*slices[0], b"b"); + assert_eq!(&*slices[1], b"cd"); + assert_eq!(&*slices[2], b"ef"); + } + + #[test] + fn cursor_copy_to_bytes_covers_zero_exact_and_spanning_paths() { + let mut zero = cursor_source(); + assert_eq!(zero.copy_to_bytes(0), Bytes::new()); + assert_eq!(zero.remaining(), 6); + + let mut exact = BuflistCursor::new( + vec![Bytes::from_static(b"abc")] + .into_iter() + .collect::(), + ); + exact.advance(1); + assert_eq!(exact.copy_to_bytes(2), Bytes::from_static(b"bc")); + assert_eq!(exact.remaining(), 0); + + let mut spanning = BuflistCursor::new( + vec![Bytes::from_static(b"ab"), Bytes::from_static(b"cd")] + .into_iter() + .collect::(), + ); + spanning.advance(1); + assert_eq!(spanning.copy_to_bytes(3), Bytes::from_static(b"bcd")); + assert_eq!(spanning.remaining(), 0); + } + + #[test] + #[should_panic(expected = "advance beyond buffer length")] + fn cursor_advance_beyond_panics() { + let mut cursor = cursor_source(); + cursor.advance(7); + } } diff --git a/src/client.rs b/src/client.rs deleted file mode 100644 index 8007093..0000000 --- a/src/client.rs +++ /dev/null @@ -1,70 +0,0 @@ -use std::sync::Arc; - -use http::uri::Authority; - -pub use crate::message::{stream::MessageStreamError, unify::ReadToStringError}; -use crate::{ - connection::{Connection, ConnectionBuilder}, - pool::{self, Pool}, - quic, -}; - -mod message; -pub use message::{PendingRequest, Request, RequestError, Response}; - -#[derive(Debug, Clone)] -pub struct Client { - pool: Pool, - client: C, - builder: Arc>, -} - -#[bon::bon] -impl Client { - #[builder( - builder_type(vis = "pub"), - start_fn(name = from_quic_client, vis = "pub") - )] - fn new( - #[builder(default = Pool::empty())] pool: Pool, - client: C, - #[builder(default = Arc::new(ConnectionBuilder::new(Arc::default())))] builder: Arc< - ConnectionBuilder, - >, - ) -> Self { - Self { - pool, - client, - builder, - } - } - - pub fn quic_client(&self) -> &C { - &self.client - } - - pub fn quic_client_mut(&mut self) -> &mut C { - &mut self.client - } - - /// Decompose this `Client` into its constituent parts. - #[allow(clippy::type_complexity)] - pub fn into_parts( - self, - ) -> ( - Pool, - C, - Arc>, - ) { - (self.pool, self.client, self.builder) - } - - pub async fn connect( - &self, - server: Authority, - ) -> Result>, pool::ConnectError> { - self.pool - .reuse_or_connect_with(&self.client, self.builder.clone(), server) - .await - } -} diff --git a/src/client/message.rs b/src/client/message.rs deleted file mode 100644 index 7747c86..0000000 --- a/src/client/message.rs +++ /dev/null @@ -1,619 +0,0 @@ -use std::{error::Error, sync::Arc}; - -use bytes::{Buf, Bytes}; -use futures::{Sink, Stream, StreamExt, TryFutureExt}; -use http::{ - HeaderMap, HeaderValue, Method, Uri, - header::{AsHeaderName, IntoHeaderName}, - uri::{Authority, PathAndQuery, Scheme}, -}; -use snafu::{Report, ResultExt, Snafu}; -use tracing::Instrument; - -use crate::{ - client::Client, - dhttp::protocol::InitialRawMessageStreamError, - error::Code, - message::{ - stream::{InitialMessageStreamError, MessageStreamError, ReadStream, WriteStream}, - unify::{MalformedMessageError, Message, MessageStage, ReadToStringError}, - }, - pool::ConnectError, - qpack::field::{MalformedHeaderSection, Protocol}, - quic::{self, agent}, -}; - -#[derive(Clone)] -pub struct PendingRequest<'c, C: quic::Connect> { - client: &'c Client, - request: Message, - auto_close: bool, -} - -impl<'c, C: quic::Connect> std::fmt::Debug for PendingRequest<'c, C> -where - Client: std::fmt::Debug, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("PendingRequest") - .field("client", &self.client) - .field("request", &self.request) - .field("auto_close", &self.auto_close) - .finish() - } -} - -#[derive(Debug, Snafu)] -#[snafu(module)] -pub enum RequestError { - #[snafu(transparent)] - Connect { source: ConnectError }, - #[snafu(transparent)] - Connection { source: quic::ConnectionError }, - #[snafu(display("request stream error"))] - RequestStream { source: quic::StreamError }, - #[snafu(display("response stream error"))] - ResponseStream { source: quic::StreamError }, - #[snafu(display("request cannot be sent due to malformed header"))] - MalformedRequestHeader { source: MalformedHeaderSection }, - #[snafu(display( - "header section too large to fit into a single frame, maybe too many header fields" - ))] - HeaderTooLarge, - #[snafu(display( - "trailer section too large to fit into a single frame, maybe too many header fields" - ))] - TrailerTooLarge, - #[snafu(display("data frame payload too large, try smaller chunk size"))] - DataFrameTooLarge, - #[snafu(display("response from peer is malformed"))] - MalformedResponse, -} - -impl Client { - pub fn new_request(&self) -> PendingRequest<'_, C> { - PendingRequest { - client: self, - request: Message::unresolved_request(), - auto_close: true, - } - } -} - -impl PendingRequest<'_, C> { - pub fn with_method(mut self, method: Method) -> Self { - self.request.header_mut().set_method(method); - self - } - - pub fn with_scheme(mut self, scheme: Scheme) -> Self { - self.request.header_mut().set_scheme(scheme); - self - } - - pub fn with_authority(mut self, authority: Authority) -> Self { - self.request.header_mut().set_authority(authority); - self - } - - pub fn with_path(mut self, path: PathAndQuery) -> Self { - self.request.header_mut().set_path(path); - self - } - - pub fn with_protocol(mut self, protocol: Protocol) -> Self { - self.request.header_mut().set_protocol(protocol); - self - } - - pub fn with_uri(mut self, uri: Uri) -> Self { - self.request.header_mut().set_uri(uri); - self - } - - pub fn headers(&self) -> &HeaderMap { - &self.request.header().header_map - } - - pub fn headers_mut(&mut self) -> &mut HeaderMap { - &mut self.request.header_mut().header_map - } - - pub fn with_header(mut self, name: impl IntoHeaderName, value: HeaderValue) -> Self { - self.headers_mut().insert(name, value); - self - } - - pub fn with_headers(mut self, headers: HeaderMap) -> Self { - *self.headers_mut() = headers; - self - } - - pub fn with_body(mut self, body: impl Buf) -> Self { - self.request.set_body(body); - self - } - - /// Whether to automatically close the request stream when the pending request is chunked. - /// - /// Default is `true` to adapt most use cases. - pub fn auto_close(mut self, auto_close: bool) -> Self { - self.auto_close = auto_close; - self - } - - pub fn trailers(&self) -> &HeaderMap { - self.request.trailers() - } - - pub fn trailers_mut(&mut self) -> &mut HeaderMap { - self.request.trailers_mut() - } - - pub fn with_trailer(mut self, name: impl IntoHeaderName, value: HeaderValue) -> Self { - self.trailers_mut().insert(name, value); - self - } - - pub fn with_trailers(mut self, trailers: HeaderMap) -> Self { - *self.trailers_mut() = trailers; - self - } -} - -impl<'a, C: quic::Connect + Sync> PendingRequest<'a, C> -where - C::Connection: Send + 'static, - ::StreamReader: Send, - ::StreamWriter: Send, -{ - /// Execute the request and return the response - #[tracing::instrument( - level = "debug", - target = "h3x::client", - name = "execute_request", - skip_all, - fields( - method = %self.request.header().method(), - uri = %self.request.header().uri(), - ) - )] - pub async fn execute(mut self) -> Result<(Request, Response), RequestError> { - self.request - .header() - .check_pseudo() - .context(request_error::MalformedRequestHeaderSnafu)?; - - if tracing::enabled!(tracing::Level::DEBUG) { - let span = tracing::Span::current(); - if !span.has_field("method") { - span.record("method", self.request.header().method().as_str()); - } - if !span.has_field("uri") { - span.record("uri", self.request.header().uri().to_string()); - } - } - - let authority = self.request.header().authority().expect("checked"); - - loop { - let connection = self.client.connect(authority.clone()).await?; - tracing::trace!(target: "h3x::client", %authority, "connected"); - - let (mut read_stream, mut write_stream) = match connection - .initial_message_stream() - .await - { - Ok(pair) => pair, - Err(InitialMessageStreamError::InitialRawStream { source }) => match source { - InitialRawMessageStreamError::Connection { source } => { - // Connection may have been silently closed (e.g. idle timeout). - // The error has been propagated to the SetOnce, so the pool - // will no longer return this dead connection. Retry with a - // fresh connection. - tracing::debug!( - target: "h3x::client", - ?source, - "connection error on reused connection, retrying..." - ); - continue; - } - InitialRawMessageStreamError::ResponseStream { source } => { - return Err(RequestError::ResponseStream { source }); - } - InitialRawMessageStreamError::Goaway { .. } => { - tracing::debug!(target: "h3x::client", "connection goaway, retrying..."); - continue; - } - }, - Err(InitialMessageStreamError::QPackProtocolDisabled { .. }) => { - unreachable!("Client always initializes the QPack protocol") - } - }; - - let Ok(local_agent) = connection.local_agent().await else { - continue; - }; - let Ok(remote_agent) = connection.remote_agent().await else { - continue; - }; - let remote_agent = remote_agent.expect("checked by Client::connect"); - - let send_request = async { - if self.auto_close && self.request.is_chunked() { - write_stream.close_message(&mut self.request).await - } else { - write_stream.send_message(&mut self.request).await - } - }; - - let mut response = Message::unresolved_response(); - - #[derive(Debug, PartialEq)] - enum Stream { - Request, - Response, - } - - return match tokio::try_join!( - send_request.map_err(|e| (Stream::Request, e)), - read_stream - .read_message_header(&mut response) - .map_err(|e| (Stream::Response, e)), - ) { - Ok(..) => { - let request = Request { - message: self.request, - stream: write_stream, - agent: local_agent, - }; - let response = Response { - message: response, - stream: read_stream, - agent: remote_agent, - }; - return Ok((request, response)); - } - Err((stream, MessageStreamError::HeaderTooLarge)) => { - debug_assert_eq!(stream, Stream::Response); - Err(RequestError::HeaderTooLarge) - } - Err((stream, MessageStreamError::TrailerTooLarge)) => { - debug_assert_eq!(stream, Stream::Response); - Err(RequestError::TrailerTooLarge) - } - Err((stream, MessageStreamError::DataFrameTooLarge { .. })) => { - debug_assert_eq!(stream, Stream::Request); - Err(RequestError::DataFrameTooLarge) - } - Err((stream, MessageStreamError::MalformedIncomingMessage)) => { - debug_assert_eq!(stream, Stream::Response); - Err(RequestError::MalformedResponse) - } - Err((stream, MessageStreamError::Quic { source })) => match stream { - Stream::Request => Err(RequestError::RequestStream { source }), - Stream::Response => Err(RequestError::ResponseStream { source }), - }, - Err((.., MessageStreamError::Goaway { .. })) => { - self.request = self.request.to_unsend(); - tracing::debug!(target: "h3x::client", "connection goaway, retrying..."); - continue; - } - }; - } - } - - pub async fn get(self, uri: Uri) -> Result<(Request, Response), RequestError> { - self.with_method(Method::GET).with_uri(uri).execute().await - } - - pub async fn post(self, uri: Uri) -> Result<(Request, Response), RequestError> { - self.with_method(Method::POST).with_uri(uri).execute().await - } - - pub async fn put(self, uri: Uri) -> Result<(Request, Response), RequestError> { - self.with_method(Method::PUT).with_uri(uri).execute().await - } - - pub async fn delete(self, uri: Uri) -> Result<(Request, Response), RequestError> { - self.with_method(Method::DELETE) - .with_uri(uri) - .execute() - .await - } - - pub async fn head(self, uri: Uri) -> Result<(Request, Response), RequestError> { - self.with_method(Method::HEAD).with_uri(uri).execute().await - } - - pub async fn options(self, uri: Uri) -> Result<(Request, Response), RequestError> { - self.with_method(Method::OPTIONS) - .with_uri(uri) - .execute() - .await - } - - pub async fn connect(self, uri: Uri) -> Result<(Request, Response), RequestError> { - self.with_method(Method::CONNECT) - .with_uri(uri) - .execute() - .await - } - - pub async fn patch(self, uri: Uri) -> Result<(Request, Response), RequestError> { - self.with_method(Method::PATCH) - .with_uri(uri) - .execute() - .await - } - - pub async fn trace(self, uri: Uri) -> Result<(Request, Response), RequestError> { - self.with_method(Method::TRACE) - .with_uri(uri) - .execute() - .await - } -} - -pub struct Request { - message: Message, - stream: WriteStream, - agent: Option>, -} - -impl Request { - pub fn method(&self) -> Method { - self.message.header().method() - } - - pub fn scheme(&self) -> Option { - self.message.header().scheme() - } - - pub fn authority(&self) -> Option { - self.message.header().authority() - } - - pub fn path(&self) -> Option { - self.message.header().path() - } - - pub fn uri(&self) -> Uri { - self.message.header().uri() - } - - pub fn headers(&self) -> &HeaderMap { - &self.message.header().header_map - } - - fn check_message_operation( - &mut self, - operation: &str, - check: impl FnOnce(&mut Self) -> Result<(), MalformedMessageError>, - ) { - if self.message.is_malformed() { - tracing::warn!( - target: "h3x::client", operation, - "Request is malformed, operation will not affect the request stream", - ); - } - if let Err(error) = check(self) { - tracing::warn!( - target: "h3x::client", operation, error = %Report::from_error(error), - "Operation malformed the request message, request stream will be cancelled with H3_REQUEST_CANCELLED", - ); - self.message.set_malformed(); - } - } - - pub async fn write( - &mut self, - content: impl Buf + Send, - ) -> Result<&mut Self, MessageStreamError> { - self.check_message_operation("write_streaming_body", |this| { - // header is checked in pending request - this.message.streaming_body()?; - Ok(()) - }); - self.stream - .send_message_streaming_body(&mut self.message, content) - .await?; - Ok(self) - } - - pub async fn flush(&mut self) -> Result<&mut Self, MessageStreamError> { - // header is checked in pending request - self.stream.flush_message(&mut self.message).await?; - Ok(self) - } - - pub fn as_sink(&mut self) -> impl Sink { - crate::message::stream::unfold::write::unfold( - self, - async |request: &mut Self, buf: B| { - request.write(buf).await?; - Ok(request) - }, - async |request: &mut Self| { - request.flush().await?; - Ok(request) - }, - async |request: &mut Self| { - request.close().await?; - Ok(request) - }, - ) - } - - pub fn into_sink(self) -> impl Sink { - crate::message::stream::unfold::write::unfold( - self, - async |request: Self, buf: B| { - let mut request = request; - request.write(buf).await?; - Ok(request) - }, - async |request: Self| { - let mut request = request; - request.flush().await?; - Ok(request) - }, - async |request: Self| { - let mut request = request; - request.close().await?; - Ok(request) - }, - ) - } - - pub fn trailers(&self) -> &HeaderMap { - self.message.trailers() - } - - pub fn trailers_mut(&mut self) -> &mut HeaderMap { - self.check_message_operation("modify_trailers", |this| { - if this.message.stage() >= MessageStage::Trailer { - return Err(MalformedMessageError::TrailerAlreadySent); - } - Ok(()) - }); - self.message.trailers_mut() - } - - pub fn set_trailer(&mut self, name: impl IntoHeaderName, value: HeaderValue) -> &mut Self { - self.trailers_mut().insert(name, value); - self - } - - pub fn set_trailers(&mut self, map: HeaderMap) -> &mut Self { - *self.trailers_mut() = map; - self - } - - pub async fn close(&mut self) -> Result<(), MessageStreamError> { - self.stream.close_message(&mut self.message).await - } - - pub async fn cancel(&mut self, code: Code) -> Result<(), MessageStreamError> { - self.stream.cancel(code).await - } - - /// Low level access to the underlying write stream - pub fn write_stream(&mut self) -> &mut WriteStream { - &mut self.stream - } - - pub fn agent(&self) -> Option<&Arc> { - self.agent.as_ref() - } - - /// Async drop the request properly - pub(crate) fn drop(&mut self) -> Option + Send + use<>> { - if self.message.is_complete() || self.message.is_dropped() { - return None; - } - let mut stream = self.stream.take(); - let mut message = self.message.take(); - - // if !message.is_malformed() { - // let check = || { - // // There is no check that could fail - // Ok(()) - // }; - // if let Err(error) = check() { - // message.set_malformed(); - // tracing::warn!( - // target: "h3x::client", error = %Report::from_error(error), - // "Request stream cannot be closed properly as its malformed", - // ); - // } - // } - - Some(async move { _ = stream.close_message(&mut message).await }) - } -} - -impl Drop for Request { - fn drop(&mut self) { - if let Some(future) = self.drop() { - // Best-effort: send the end-of-stream marker before the request is dropped. - tokio::spawn(future.in_current_span()); - } - } -} - -pub struct Response { - message: Message, - stream: ReadStream, - agent: Arc, -} - -impl Response { - pub async fn next_response(&mut self) -> Result<&mut Self, MessageStreamError> { - self.stream.read_message_header(&mut self.message).await?; - Ok(self) - } - - pub fn status(&self) -> http::StatusCode { - self.message.header().status() - } - - pub fn headers(&mut self) -> &HeaderMap { - &self.message.header().header_map - } - - pub fn header(&mut self, name: impl AsHeaderName) -> Option<&HeaderValue> { - self.headers().get(name) - } - - pub async fn read(&mut self) -> Option> { - self.stream.read_message(&mut self.message).await - } - - pub async fn read_all(&mut self) -> Result { - self.stream.read_message_full_body(&mut self.message).await - } - - pub async fn read_to_bytes(&mut self) -> Result { - self.stream - .read_message_body_to_bytes(&mut self.message) - .await - } - - pub async fn read_to_string(&mut self) -> Result { - self.stream - .read_message_body_to_string(&mut self.message) - .await - } - - pub async fn as_stream(&mut self) -> impl Stream> { - futures::stream::unfold(self, async |this| { - this.read().await.map(|item| (item, this)) - }) - .fuse() - } - - pub async fn into_stream(self) -> impl Stream> { - futures::stream::unfold(self, async |mut this| { - this.read().await.map(|item| (item, this)) - }) - .fuse() - } - - pub async fn trailers(&mut self) -> Result<&HeaderMap, MessageStreamError> { - self.stream.read_message_trailer(&mut self.message).await - } - - pub async fn stop(&mut self, code: Code) -> Result<(), MessageStreamError> { - self.stream.stop(code).await - } - - /// Low level access to the underlying read stream - pub fn read_stream(&mut self) -> &mut ReadStream { - &mut self.stream - } - - pub fn agent(&self) -> &Arc { - &self.agent - } -} diff --git a/src/codec.rs b/src/codec.rs index 70e5db8..5797d3a 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -1,4 +1,4 @@ -use std::{future::Future, pin::Pin}; +use std::future::Future; use crate::quic; @@ -15,38 +15,15 @@ pub use reader::{FixedLengthReader, PeekableStreamReader, StreamReader}; use tokio::io::{self, AsyncBufRead, AsyncBufReadExt}; pub use writer::{Feed, SinkWriter}; -pub type BoxStreamReader = StreamReader::StreamReader>>>; +/// Boxed stream reader wrapping a boxed QUIC read stream. +pub type BoxStreamReader = StreamReader>; -/// Type alias for a peekable unidirectional stream. -pub type BoxPeekableUniStream = - PeekableStreamReader::StreamReader>>>; +/// Boxed stream writer wrapping a boxed QUIC write stream. +pub type BoxStreamWriter = SinkWriter>; -/// Type alias for a peekable bidirectional stream pair. -pub type BoxPeekableBiStream = ( - PeekableStreamReader::StreamReader>>>, - SinkWriter::StreamWriter>>>, -); - -/// Raw erased read stream: a pinned, boxed trait object for `quic::ReadStream`. -pub type BoxReadStream = Pin>; - -/// Raw erased write stream: a pinned, boxed trait object for `quic::WriteStream`. -pub type BoxWriteStream = Pin>; - -/// Erased (non-generic) stream reader wrapping a boxed `ReadStream` trait object. -pub type ErasedStreamReader = StreamReader; - -/// Erased (non-generic) sink writer wrapping a boxed `WriteStream` trait object. -pub type ErasedStreamWriter = SinkWriter; - -/// Erased (non-generic) peekable unidirectional stream reader. -pub type ErasedPeekableUniStream = PeekableStreamReader; - -/// Erased (non-generic) peekable bidirectional stream pair. -pub type ErasedPeekableBiStream = ( - PeekableStreamReader, - SinkWriter, -); +/// Boxed peekable stream reader wrapping a boxed QUIC read stream. +pub type BoxPeekableStreamReader = + PeekableStreamReader>; pub trait EncodeInto: Sized { type Output; @@ -132,3 +109,76 @@ pub trait DecodeExt { } impl DecodeExt for S {} + +#[cfg(test)] +mod tests { + use std::{ + io, + pin::Pin, + task::{Context, Poll}, + }; + + use futures::StreamExt; + use tokio::io::{AsyncRead, ReadBuf}; + + use super::*; + + #[derive(Debug, PartialEq, Eq)] + struct FailingItem; + + #[derive(Debug, PartialEq, Eq)] + enum TestDecodeError { + Decode, + Io(io::ErrorKind), + } + + impl From for TestDecodeError { + fn from(error: io::Error) -> Self { + Self::Io(error.kind()) + } + } + + impl<'s, S> DecodeFrom<&'s mut S> for FailingItem + where + S: AsyncBufRead + Send + 's, + { + type Error = TestDecodeError; + + async fn decode_from(_stream: &'s mut S) -> Result { + Err(TestDecodeError::Decode) + } + } + + struct FailingRead; + + impl AsyncRead for FailingRead { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut ReadBuf<'_>, + ) -> Poll> { + Poll::Ready(Err(io::Error::other("read failed"))) + } + } + + #[tokio::test] + async fn into_decode_stream_yields_item_decode_errors() { + let decoder = io::Cursor::new([1]); + let stream = decoder.into_decode_stream::(); + futures::pin_mut!(stream); + + assert_eq!(stream.next().await, Some(Err(TestDecodeError::Decode))); + } + + #[tokio::test] + async fn into_decode_stream_yields_io_errors() { + let decoder = tokio::io::BufReader::new(FailingRead); + let stream = decoder.into_decode_stream::(); + futures::pin_mut!(stream); + + assert_eq!( + stream.next().await, + Some(Err(TestDecodeError::Io(io::ErrorKind::Other))) + ); + } +} diff --git a/src/codec/error.rs b/src/codec/error.rs index c6c1cd4..6547ed3 100644 --- a/src/codec/error.rs +++ b/src/codec/error.rs @@ -424,3 +424,655 @@ impl From for StreamEncodeError { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::{Code, H3NoError}; + + fn varint(value: u32) -> VarInt { + VarInt::from_u32(value) + } + + fn transport_error(reason: &'static str) -> quic::ConnectionError { + quic::ConnectionError::Transport { + source: quic::TransportError { + kind: varint(1), + frame_type: varint(2), + reason: reason.into(), + }, + } + } + + fn application_error(reason: &'static str) -> quic::ConnectionError { + quic::ConnectionError::Application { + source: quic::ApplicationError { + code: Code::H3_INTERNAL_ERROR, + reason: reason.into(), + }, + } + } + + fn connection_error(reason: &'static str) -> connection::ConnectionError { + transport_error(reason).into() + } + + fn assert_connection_reason(error: &connection::ConnectionError, expected: &str) { + let connection::ConnectionError::Quic { + source: quic::ConnectionError::Transport { source }, + } = error + else { + panic!("expected transport connection error"); + }; + assert_eq!(source.reason.as_ref(), expected); + } + + fn assert_stream_reset(error: connection::StreamError, expected: VarInt) { + let connection::StreamError::Reset { code } = error else { + panic!("expected stream reset"); + }; + assert_eq!(code, expected); + } + + fn assert_stream_h3(error: connection::StreamError) { + let connection::StreamError::H3 { source } = error else { + panic!("expected h3 stream error"); + }; + assert_eq!(source.code(), Code::H3_MESSAGE_ERROR); + } + + fn assert_no_source(error: &(dyn std::error::Error + 'static)) { + assert!( + error.source().is_none(), + "expected no source, got {error:?}" + ); + } + + #[test] + fn decode_error_io_roundtrips_and_classifies_plain_eof() { + let cases = [ + (DecodeError::Incomplete, io::ErrorKind::UnexpectedEof), + (DecodeError::IntegerOverflow, io::ErrorKind::InvalidData), + (DecodeError::InvalidHuffmanCode, io::ErrorKind::InvalidData), + (DecodeError::ArithmeticOverflow, io::ErrorKind::InvalidData), + (DecodeError::DecompressionFailed, io::ErrorKind::InvalidData), + ]; + + for (error, expected_kind) in cases { + let io_error = io::Error::from(error); + assert_eq!(io_error.kind(), expected_kind); + assert_eq!( + DecodeError::try_from(io_error).expect("decode error"), + error + ); + } + + let eof = io::Error::new(io::ErrorKind::UnexpectedEof, "plain eof"); + assert_eq!( + DecodeError::try_from(eof).expect("plain eof should map to incomplete"), + DecodeError::Incomplete + ); + + let other = io::Error::new(io::ErrorKind::InvalidData, "plain invalid data"); + let other = DecodeError::try_from(other).expect_err("plain error should be preserved"); + assert_eq!(other.kind(), io::ErrorKind::InvalidData); + } + + #[test] + fn encode_error_io_roundtrips() { + for error in [ + EncodeError::FramePayloadTooLarge, + EncodeError::HuffmanEncoding, + ] { + let io_error = io::Error::from(error); + assert_eq!(io_error.kind(), io::ErrorKind::InvalidData); + assert_eq!( + EncodeError::try_from(io_error).expect("encode error"), + error + ); + } + + let other = io::Error::new(io::ErrorKind::InvalidData, "plain invalid data"); + let other = EncodeError::try_from(other).expect_err("plain error should be preserved"); + assert_eq!(other.kind(), io::ErrorKind::InvalidData); + } + + #[test] + fn codec_leaf_errors_have_expected_display_debug_and_source() { + let decode_cases = [ + ( + DecodeError::Incomplete, + "stream closed unexpectedly", + "Incomplete", + ), + ( + DecodeError::IntegerOverflow, + "integer too large (overflow u64)", + "IntegerOverflow", + ), + ( + DecodeError::InvalidHuffmanCode, + "invalid huffman code", + "InvalidHuffmanCode", + ), + ( + DecodeError::ArithmeticOverflow, + "arithmetic overflow while decoding", + "ArithmeticOverflow", + ), + ( + DecodeError::DecompressionFailed, + "QPACK decompression failed", + "DecompressionFailed", + ), + ]; + + for (error, display, debug_fragment) in decode_cases { + assert_eq!(error.to_string(), display); + assert!(format!("{error:?}").contains(debug_fragment)); + assert_no_source(&error); + } + + let encode_cases = [ + ( + EncodeError::FramePayloadTooLarge, + "frame payload too large (overflow 2^62-1)", + "FramePayloadTooLarge", + ), + ( + EncodeError::HuffmanEncoding, + "header name/value contains bytes out of QPACK allowed range", + "HuffmanEncoding", + ), + ]; + + for (error, display, debug_fragment) in encode_cases { + assert_eq!(error.to_string(), display); + assert!(format!("{error:?}").contains(debug_fragment)); + assert_no_source(&error); + } + } + + #[test] + fn huffman_library_errors_map_to_codec_errors() { + assert_eq!( + DecodeError::from(httlib_huffman::DecoderError::InvalidInput), + DecodeError::InvalidHuffmanCode + ); + assert_eq!( + EncodeError::from(httlib_huffman::EncoderError::InvalidInput), + EncodeError::FramePayloadTooLarge + ); + } + + #[test] + fn stream_decode_escalation_covers_all_branches() { + let error = StreamDecodeError::Connection { + source: connection_error("connection"), + }; + let escalated = error.escalate_reset(|_| connection_error("reset")); + let ConnectionDecodeError::Connection { source } = escalated else { + panic!("expected connection branch"); + }; + assert_connection_reason(&source, "connection"); + + let escalated = StreamDecodeError::Reset { code: varint(7) } + .escalate_reset(|code| connection_error(if code == varint(7) { "reset" } else { "" })); + let ConnectionDecodeError::Connection { source } = escalated else { + panic!("expected reset escalation"); + }; + assert_connection_reason(&source, "reset"); + + let escalated = StreamDecodeError::Decode { + source: DecodeError::IntegerOverflow, + } + .escalate_reset(|_| connection_error("unused")); + let ConnectionDecodeError::Decode { source } = escalated else { + panic!("expected decode branch"); + }; + assert_eq!(source, DecodeError::IntegerOverflow); + } + + #[test] + fn stream_decode_critical_close_escalates_reset_and_incomplete() { + let escalated = StreamDecodeError::Connection { + source: connection_error("connection"), + } + .escalate_critical_close(|| connection_error("unused")); + let ConnectionDecodeError::Connection { source } = escalated else { + panic!("expected connection branch"); + }; + assert_connection_reason(&source, "connection"); + + for error in [ + StreamDecodeError::Reset { code: varint(1) }, + StreamDecodeError::Decode { + source: DecodeError::Incomplete, + }, + ] { + let escalated = error.escalate_critical_close(|| connection_error("closed")); + let ConnectionDecodeError::Connection { source } = escalated else { + panic!("expected critical close"); + }; + assert_connection_reason(&source, "closed"); + } + + let escalated = StreamDecodeError::Decode { + source: DecodeError::InvalidHuffmanCode, + } + .escalate_critical_close(|| connection_error("unused")); + let ConnectionDecodeError::Decode { source } = escalated else { + panic!("expected decode branch"); + }; + assert_eq!(source, DecodeError::InvalidHuffmanCode); + } + + #[test] + fn stream_decode_error_recovers_plain_eof_as_incomplete_decode() { + let eof = io::Error::new(io::ErrorKind::UnexpectedEof, "plain eof"); + let StreamDecodeError::Decode { source } = StreamDecodeError::from(eof) else { + panic!("expected incomplete decode error"); + }; + assert_eq!(source, DecodeError::Incomplete); + } + + #[test] + fn stream_decode_into_stream_error_covers_all_branches() { + let error = StreamDecodeError::Connection { + source: connection_error("connection"), + } + .into_stream_error(|_| H3NoError.into()); + let connection::StreamError::Connection { source } = error else { + panic!("expected connection stream error"); + }; + assert_connection_reason(&source, "connection"); + + assert_stream_reset( + StreamDecodeError::Reset { code: varint(9) }.into_stream_error(|_| H3NoError.into()), + varint(9), + ); + + assert_stream_h3( + StreamDecodeError::Decode { + source: DecodeError::DecompressionFailed, + } + .into_stream_error(|_| crate::error::H3MessageError::UnexpectedHeadersInBody.into()), + ); + } + + #[test] + fn connection_decode_into_stream_error_covers_all_branches() { + let error = ConnectionDecodeError::Connection { + source: connection_error("connection"), + } + .into_stream_error(|_| H3NoError.into()); + let connection::StreamError::Connection { source } = error else { + panic!("expected connection stream error"); + }; + assert_connection_reason(&source, "connection"); + + assert_stream_h3( + ConnectionDecodeError::Decode { + source: DecodeError::ArithmeticOverflow, + } + .into_stream_error(|_| crate::error::H3MessageError::MissingHeaderSection.into()), + ); + } + + #[test] + fn stream_decode_error_recovers_from_io_error_sources() { + let direct = io::Error::from(StreamDecodeError::Reset { code: varint(1) }); + let StreamDecodeError::Reset { code } = StreamDecodeError::from(direct) else { + panic!("expected direct stream decode error"); + }; + assert_eq!(code, varint(1)); + + let quic_reset = io::Error::from(quic::StreamError::Reset { code: varint(2) }); + let StreamDecodeError::Reset { code } = StreamDecodeError::from(quic_reset) else { + panic!("expected quic stream reset"); + }; + assert_eq!(code, varint(2)); + + let quic_connection = io::Error::from(quic::StreamError::Connection { + source: transport_error("quic stream connection"), + }); + let StreamDecodeError::Connection { source } = StreamDecodeError::from(quic_connection) + else { + panic!("expected quic connection"); + }; + assert_connection_reason(&source, "quic stream connection"); + + let h3_connection = io::Error::from(connection_error("h3 connection")); + let StreamDecodeError::Connection { source } = StreamDecodeError::from(h3_connection) + else { + panic!("expected h3 connection"); + }; + assert_connection_reason(&source, "h3 connection"); + + let decode = io::Error::from(DecodeError::ArithmeticOverflow); + let StreamDecodeError::Decode { source } = StreamDecodeError::from(decode) else { + panic!("expected decode error"); + }; + assert_eq!(source, DecodeError::ArithmeticOverflow); + } + + #[test] + fn stream_decode_error_panics_on_untyped_io_error() { + let panic = std::panic::catch_unwind(|| { + let _ = StreamDecodeError::from(io::Error::other("plain decode error")); + }); + assert!(panic.is_err()); + } + + #[test] + fn connection_decode_error_converts_to_io() { + let io_error = io::Error::from(ConnectionDecodeError::Connection { + source: connection_error("connection"), + }); + assert_eq!(io_error.kind(), io::ErrorKind::BrokenPipe); + + let io_error = io::Error::from(ConnectionDecodeError::Decode { + source: DecodeError::Incomplete, + }); + assert_eq!(io_error.kind(), io::ErrorKind::UnexpectedEof); + } + + #[test] + fn stream_decode_error_display_debug_source_and_quic_conversion() { + let connection = StreamDecodeError::Connection { + source: connection_error("connection display"), + }; + assert_eq!( + connection.to_string(), + "transport error (0x1 in frame 0x2): connection display" + ); + assert!(format!("{connection:?}").contains("Connection")); + assert_no_source(&connection); + let io_error = io::Error::from(connection); + assert_eq!(io_error.kind(), io::ErrorKind::BrokenPipe); + assert_eq!( + io_error.get_ref().expect("wrapped error").to_string(), + "transport error (0x1 in frame 0x2): connection display" + ); + + let reset = StreamDecodeError::Reset { code: varint(55) }; + assert_eq!(reset.to_string(), "stream reset with code 55"); + assert!(format!("{reset:?}").contains("Reset")); + assert_no_source(&reset); + let io_error = io::Error::from(reset); + assert_eq!(io_error.kind(), io::ErrorKind::BrokenPipe); + assert_eq!(io_error.to_string(), "stream reset with code 55"); + + let decode = StreamDecodeError::Decode { + source: DecodeError::DecompressionFailed, + }; + assert_eq!(decode.to_string(), "QPACK decompression failed"); + assert!(format!("{decode:?}").contains("Decode")); + assert_no_source(&decode); + assert_eq!(io::Error::from(decode).kind(), io::ErrorKind::InvalidData); + + let converted = StreamDecodeError::from(transport_error("quic decode from conn")); + let StreamDecodeError::Connection { source } = converted else { + panic!("expected connection variant"); + }; + assert_connection_reason(&source, "quic decode from conn"); + + let converted = ConnectionDecodeError::from(transport_error("quic conn decode")); + let ConnectionDecodeError::Connection { source } = converted else { + panic!("expected connection variant"); + }; + assert_connection_reason(&source, "quic conn decode"); + } + + #[test] + fn connection_decode_error_display_debug_and_source() { + let connection = ConnectionDecodeError::Connection { + source: connection_error("connection decode display"), + }; + assert_eq!( + connection.to_string(), + "transport error (0x1 in frame 0x2): connection decode display" + ); + assert!(format!("{connection:?}").contains("Connection")); + assert_no_source(&connection); + + let decode = ConnectionDecodeError::Decode { + source: DecodeError::ArithmeticOverflow, + }; + assert_eq!(decode.to_string(), "arithmetic overflow while decoding"); + assert!(format!("{decode:?}").contains("Decode")); + assert_no_source(&decode); + let io_error = io::Error::from(decode); + assert_eq!(io_error.kind(), io::ErrorKind::InvalidData); + assert_eq!( + io_error.get_ref().expect("wrapped error").to_string(), + "arithmetic overflow while decoding" + ); + } + + #[test] + fn stream_encode_escalation_covers_all_branches() { + let error = StreamEncodeError::Connection { + source: connection_error("connection"), + }; + let escalated = error.escalate_reset(|_| connection_error("reset")); + let ConnectionEncodeError::Connection { source } = escalated else { + panic!("expected connection branch"); + }; + assert_connection_reason(&source, "connection"); + + let escalated = StreamEncodeError::Reset { code: varint(7) } + .escalate_reset(|code| connection_error(if code == varint(7) { "reset" } else { "" })); + let ConnectionEncodeError::Connection { source } = escalated else { + panic!("expected reset escalation"); + }; + assert_connection_reason(&source, "reset"); + + let escalated = StreamEncodeError::Encode { + source: EncodeError::HuffmanEncoding, + } + .escalate_reset(|_| connection_error("unused")); + let ConnectionEncodeError::Encode { source } = escalated else { + panic!("expected encode branch"); + }; + assert_eq!(source, EncodeError::HuffmanEncoding); + } + + #[test] + fn stream_encode_into_stream_error_covers_all_branches() { + let error = StreamEncodeError::Connection { + source: connection_error("connection"), + } + .into_stream_error(|_| H3NoError.into()); + let connection::StreamError::Connection { source } = error else { + panic!("expected connection stream error"); + }; + assert_connection_reason(&source, "connection"); + + assert_stream_reset( + StreamEncodeError::Reset { code: varint(9) }.into_stream_error(|_| H3NoError.into()), + varint(9), + ); + + assert_stream_h3( + StreamEncodeError::Encode { + source: EncodeError::HuffmanEncoding, + } + .into_stream_error(|_| crate::error::H3MessageError::UnexpectedHeadersInBody.into()), + ); + } + + #[test] + fn connection_encode_into_stream_error_covers_all_branches() { + let error = ConnectionEncodeError::Connection { + source: connection_error("connection"), + } + .into_stream_error(|_| H3NoError.into()); + let connection::StreamError::Connection { source } = error else { + panic!("expected connection stream error"); + }; + assert_connection_reason(&source, "connection"); + + assert_stream_h3( + ConnectionEncodeError::Encode { + source: EncodeError::FramePayloadTooLarge, + } + .into_stream_error(|_| crate::error::H3MessageError::MissingHeaderSection.into()), + ); + } + + #[test] + fn stream_encode_error_recovers_from_io_error_sources() { + let direct = io::Error::from(StreamEncodeError::Reset { code: varint(1) }); + let StreamEncodeError::Reset { code } = StreamEncodeError::from(direct) else { + panic!("expected direct stream encode error"); + }; + assert_eq!(code, varint(1)); + + let quic_reset = io::Error::from(quic::StreamError::Reset { code: varint(2) }); + let StreamEncodeError::Reset { code } = StreamEncodeError::from(quic_reset) else { + panic!("expected quic stream reset"); + }; + assert_eq!(code, varint(2)); + + let quic_connection = io::Error::from(quic::StreamError::Connection { + source: application_error("quic stream connection"), + }); + let StreamEncodeError::Connection { source } = StreamEncodeError::from(quic_connection) + else { + panic!("expected quic connection"); + }; + let connection::ConnectionError::Quic { + source: quic::ConnectionError::Application { source }, + } = source + else { + panic!("expected application error"); + }; + assert_eq!(source.reason.as_ref(), "quic stream connection"); + + let h3_connection = io::Error::from(connection_error("h3 connection")); + let StreamEncodeError::Connection { source } = StreamEncodeError::from(h3_connection) + else { + panic!("expected h3 connection"); + }; + assert_connection_reason(&source, "h3 connection"); + + let encode = io::Error::from(EncodeError::HuffmanEncoding); + let StreamEncodeError::Encode { source } = StreamEncodeError::from(encode) else { + panic!("expected encode error"); + }; + assert_eq!(source, EncodeError::HuffmanEncoding); + } + + #[test] + fn stream_encode_error_panics_on_untyped_io_error() { + let panic = std::panic::catch_unwind(|| { + let _ = StreamEncodeError::from(io::Error::other("plain encode error")); + }); + assert!(panic.is_err()); + } + + #[test] + fn connection_encode_error_converts_to_io() { + let io_error = io::Error::from(ConnectionEncodeError::Connection { + source: connection_error("connection"), + }); + assert_eq!(io_error.kind(), io::ErrorKind::BrokenPipe); + + let io_error = io::Error::from(ConnectionEncodeError::Encode { + source: EncodeError::FramePayloadTooLarge, + }); + assert_eq!(io_error.kind(), io::ErrorKind::InvalidData); + } + + #[test] + fn stream_encode_error_display_debug_source_and_quic_conversion() { + let connection = StreamEncodeError::Connection { + source: connection_error("encode connection display"), + }; + assert_eq!( + connection.to_string(), + "transport error (0x1 in frame 0x2): encode connection display" + ); + assert!(format!("{connection:?}").contains("Connection")); + assert_no_source(&connection); + let io_error = io::Error::from(connection); + assert_eq!(io_error.kind(), io::ErrorKind::BrokenPipe); + assert_eq!( + io_error.get_ref().expect("wrapped error").to_string(), + "transport error (0x1 in frame 0x2): encode connection display" + ); + + let reset = StreamEncodeError::Reset { code: varint(88) }; + assert_eq!(reset.to_string(), "stream reset with code 88"); + assert!(format!("{reset:?}").contains("Reset")); + assert_no_source(&reset); + let io_error = io::Error::from(reset); + assert_eq!(io_error.kind(), io::ErrorKind::BrokenPipe); + assert_eq!(io_error.to_string(), "stream reset with code 88"); + + let encode = StreamEncodeError::Encode { + source: EncodeError::FramePayloadTooLarge, + }; + assert_eq!( + encode.to_string(), + "frame payload too large (overflow 2^62-1)" + ); + assert!(format!("{encode:?}").contains("Encode")); + assert_no_source(&encode); + assert_eq!(io::Error::from(encode).kind(), io::ErrorKind::InvalidData); + + let converted = StreamEncodeError::from(application_error("quic encode from conn")); + let StreamEncodeError::Connection { source } = converted else { + panic!("expected connection variant"); + }; + let connection::ConnectionError::Quic { + source: quic::ConnectionError::Application { source }, + } = source + else { + panic!("expected application error"); + }; + assert_eq!(source.reason.as_ref(), "quic encode from conn"); + + let converted = ConnectionEncodeError::from(application_error("quic conn encode")); + let ConnectionEncodeError::Connection { source } = converted else { + panic!("expected connection variant"); + }; + let connection::ConnectionError::Quic { + source: quic::ConnectionError::Application { source }, + } = source + else { + panic!("expected application error"); + }; + assert_eq!(source.reason.as_ref(), "quic conn encode"); + } + + #[test] + fn connection_encode_error_display_debug_and_source() { + let connection = ConnectionEncodeError::Connection { + source: connection_error("connection encode display"), + }; + assert_eq!( + connection.to_string(), + "transport error (0x1 in frame 0x2): connection encode display" + ); + assert!(format!("{connection:?}").contains("Connection")); + assert_no_source(&connection); + + let encode = ConnectionEncodeError::Encode { + source: EncodeError::HuffmanEncoding, + }; + assert_eq!( + encode.to_string(), + "header name/value contains bytes out of QPACK allowed range" + ); + assert!(format!("{encode:?}").contains("Encode")); + assert_no_source(&encode); + let io_error = io::Error::from(encode); + assert_eq!(io_error.kind(), io::ErrorKind::InvalidData); + assert_eq!( + io_error.get_ref().expect("wrapped error").to_string(), + "header name/value contains bytes out of QPACK allowed range" + ); + } +} diff --git a/src/codec/reader.rs b/src/codec/reader.rs index 86f1630..135e838 100644 --- a/src/codec/reader.rs +++ b/src/codec/reader.rs @@ -493,3 +493,242 @@ impl quic::StopStream for PeekableStreamReader { Pin::new(&mut self.get_mut().stream).poll_stop(cx, code) } } + +#[cfg(test)] +mod tests { + use std::{future::poll_fn, pin::Pin}; + + use bytes::Bytes; + use futures::{StreamExt, stream}; + use tokio::io::{AsyncBufReadExt, AsyncReadExt}; + + use super::*; + use crate::quic::{GetStreamId, StopStream, StreamError}; + + fn byte_stream( + chunks: impl IntoIterator, + ) -> impl Stream> { + stream::iter( + chunks + .into_iter() + .map(|chunk| Ok::<_, StreamError>(Bytes::from_static(chunk))), + ) + } + + #[tokio::test] + async fn stream_reader_accessors_mapping_and_io_paths() { + let mut reader = StreamReader::new(byte_stream([&b""[..], &b"hello"[..], &b" world"[..]])); + assert_eq!(reader.size_hint(), (3, Some(3))); + assert_eq!(reader.stream().size_hint(), (3, Some(3))); + assert_eq!(reader.stream_mut().size_hint(), (3, Some(3))); + + let mut pinned = Pin::new(&mut reader); + assert!(pinned.as_mut().has_remaining().await.unwrap()); + + let mut prefix = [0; 2]; + pinned + .as_mut() + .read_exact(&mut prefix) + .await + .expect("partial async read should succeed"); + assert_eq!(&prefix, b"he"); + + assert_eq!( + pinned.as_mut().fill_buf().await.expect("fill buffered"), + b"llo" + ); + tokio::io::AsyncBufRead::consume(pinned.as_mut(), 3); + + assert_eq!( + pinned + .as_mut() + .next() + .await + .expect("remaining chunk should be yielded") + .unwrap(), + Bytes::from_static(b" world") + ); + assert!( + !pinned + .as_mut() + .has_remaining() + .await + .expect("eof should be clean") + ); + + let mapped = StreamReader::new(byte_stream([&b"a"[..]])).map_stream(|inner| inner); + let mut mapped = Box::pin(mapped); + assert_eq!( + mapped + .as_mut() + .next() + .await + .expect("mapped stream should yield") + .unwrap(), + Bytes::from_static(b"a") + ); + + let inner = StreamReader::new(byte_stream([&b"inner"[..]])).into_inner(); + assert_eq!(inner.size_hint(), (1, Some(1))); + } + + #[tokio::test] + async fn fixed_length_reader_limits_renews_and_reports_incomplete_reads() { + let mut reader = FixedLengthReader::new(tokio::io::empty(), 0); + let mut empty = [0; 1]; + let read = reader + .read(&mut empty) + .await + .expect("zero remaining reads eof"); + assert_eq!(read, 0); + + let mut reader = FixedLengthReader::new(tokio::io::empty(), 1); + let err = reader + .read_exact(&mut empty) + .await + .expect_err("premature eof should be incomplete"); + assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof); + + let mut reader = FixedLengthReader::new(tokio::io::BufReader::new(&b"abcdef"[..]), 4); + let mut buf = [0; 8]; + let read = reader.read(&mut buf).await.expect("bounded read"); + assert_eq!(read, 4); + assert_eq!(&buf[..read], b"abcd"); + assert_eq!(reader.read(&mut buf).await.expect("remaining exhausted"), 0); + + Pin::new(&mut reader).renew(2); + assert_eq!(reader.fill_buf().await.expect("renewed fill buffer"), b"ef"); + reader.consume(1); + assert_eq!( + reader.fill_buf().await.expect("remaining after consume"), + b"f" + ); + + let stream = reader.stream_mut(); + assert_eq!(stream.buffer().len(), 1); + + let mut pinned = Pin::new(&mut reader); + let stream = pinned.as_mut().project_stream_mut(); + assert_eq!(stream.buffer().len(), 1); + } + + #[tokio::test] + async fn fixed_length_streams_slice_owned_and_borrowed_stream_readers() { + let stream_reader = StreamReader::new(byte_stream([&b"ab"[..], &b"cdef"[..]])); + let mut reader = Box::pin(FixedLengthReader::new(stream_reader, 5)); + + assert_eq!( + reader + .as_mut() + .next() + .await + .expect("first chunk") + .expect("first chunk ok"), + Bytes::from_static(b"ab") + ); + assert_eq!( + reader + .as_mut() + .next() + .await + .expect("second chunk") + .expect("second chunk ok"), + Bytes::from_static(b"cde") + ); + assert!(reader.as_mut().next().await.is_none()); + + let mut stream_reader = StreamReader::new(byte_stream([&b"12"[..], &b"345"[..]])); + let mut borrowed = Box::pin(FixedLengthReader::new(&mut stream_reader, 4)); + assert_eq!( + borrowed + .as_mut() + .next() + .await + .expect("borrowed first chunk") + .expect("borrowed first chunk ok"), + Bytes::from_static(b"12") + ); + assert_eq!( + borrowed + .as_mut() + .next() + .await + .expect("borrowed second chunk") + .expect("borrowed second chunk ok"), + Bytes::from_static(b"34") + ); + assert!(borrowed.as_mut().next().await.is_none()); + + let mut incomplete = Box::pin(FixedLengthReader::new( + StreamReader::new(byte_stream([&b"x"[..]])), + 2, + )); + assert_eq!( + incomplete + .as_mut() + .next() + .await + .expect("available byte") + .expect("available byte ok"), + Bytes::from_static(b"x") + ); + assert!( + incomplete + .as_mut() + .next() + .await + .expect("incomplete item") + .is_err() + ); + } + + #[tokio::test] + async fn peekable_stream_reader_reset_commit_flush_and_traits() { + let stream_id = VarInt::from_u32(91); + let stop_code = VarInt::from_u32(92); + let (mock_reader, _mock_writer) = crate::quic::test::mock_stream_pair(stream_id); + let mut reader = Box::pin(PeekableStreamReader::new(StreamReader::new(mock_reader))); + + assert_eq!( + poll_fn(|cx| reader.as_mut().poll_stream_id(cx)) + .await + .expect("stream id"), + stream_id + ); + poll_fn(|cx| reader.as_mut().poll_stop(cx, stop_code)) + .await + .expect("stop should delegate"); + + let mut reader = Box::pin(PeekableStreamReader::new(StreamReader::new(byte_stream([ + &b"abc"[..], + &b"def"[..], + ])))); + assert_eq!(reader.as_mut().fill_buf().await.unwrap(), b"abc"); + tokio::io::AsyncBufRead::consume(reader.as_mut(), 2); + reader.as_mut().reset(); + assert_eq!(reader.as_mut().fill_buf().await.unwrap(), b"abc"); + reader.as_mut().consume(2); + reader.as_mut().commit(); + assert_eq!(reader.as_mut().fill_buf().await.unwrap(), b"c"); + + let mut one = [0; 1]; + reader + .as_mut() + .read_exact(&mut one) + .await + .expect("peekable async read should consume committed byte"); + assert_eq!(&one, b"c"); + + assert_eq!(reader.as_mut().fill_buf().await.unwrap(), b"def"); + reader.as_mut().consume(1); + let mut plain = Pin::into_inner(reader).into_stream_reader(); + assert_eq!( + Pin::new(&mut plain) + .next() + .await + .expect("flushed buffered tail") + .unwrap(), + Bytes::from_static(b"ef") + ); + } +} diff --git a/src/codec/writer.rs b/src/codec/writer.rs index d0c6ea8..bef2697 100644 --- a/src/codec/writer.rs +++ b/src/codec/writer.rs @@ -10,7 +10,7 @@ use futures::{Sink, Stream, StreamExt}; use tokio::io::{self, AsyncWrite}; use crate::{ - quic::{CancelStream, GetStreamId, StreamError}, + quic::{GetStreamId, ResetStream, StreamError}, varint::VarInt, }; @@ -175,13 +175,15 @@ where } } -impl CancelStream for SinkWriter { - fn poll_cancel( +impl ResetStream for SinkWriter { + fn poll_reset( self: Pin<&mut Self>, cx: &mut Context, code: VarInt, ) -> Poll> { - self.project().sink.poll_cancel(cx, code) + let project = self.project(); + project.buffer.clear(); + project.sink.poll_reset(cx, code) } } @@ -264,3 +266,363 @@ where self.project().sink.poll_close(cx) } } + +#[cfg(test)] +mod tests { + use std::{ + fmt, + pin::Pin, + task::{Context, Poll}, + }; + + use bytes::Bytes; + use futures::{ + Sink, SinkExt, + stream::{self, pending}, + task::noop_waker, + }; + use tokio::io::{AsyncWriteExt, Error as IoError}; + + use super::{Feed, SinkWriter}; + use crate::{ + quic::{GetStreamId, ResetStream, StreamError}, + varint::VarInt, + }; + + #[derive(Debug, Clone, Copy, Eq, PartialEq)] + struct TestError(&'static str); + + impl fmt::Display for TestError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.0) + } + } + + impl std::error::Error for TestError {} + + impl From for IoError { + fn from(error: TestError) -> Self { + Self::other(error) + } + } + + #[test] + fn test_error_display_and_io_conversion_are_stable() { + assert_eq!(TestError("not ready").to_string(), "not ready"); + + let error = IoError::from(TestError("io bridge")); + + assert_eq!(error.kind(), std::io::ErrorKind::Other); + assert_eq!(error.to_string(), "io bridge"); + } + + #[derive(Default)] + struct RecordingSink { + items: Vec, + ready_pending: usize, + ready_error: Option, + flushes: usize, + closes: usize, + reset_codes: Vec, + stream_id: Option, + } + + impl Sink for RecordingSink { + type Error = TestError; + + fn poll_ready( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + if let Some(error) = self.ready_error.take() { + return Poll::Ready(Err(error)); + } + if self.ready_pending > 0 { + self.ready_pending -= 1; + cx.waker().wake_by_ref(); + return Poll::Pending; + } + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.items.push(item); + Ok(()) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + self.flushes += 1; + Poll::Ready(Ok(())) + } + + fn poll_close( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + self.closes += 1; + Poll::Ready(Ok(())) + } + } + + impl ResetStream for RecordingSink { + fn poll_reset( + mut self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.reset_codes.push(code); + Poll::Ready(Ok(())) + } + } + + impl GetStreamId for RecordingSink { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id.expect("stream id is configured"))) + } + } + + #[tokio::test] + async fn async_write_buffers_small_writes_until_flush() { + let mut writer = SinkWriter::new(RecordingSink::default()); + + writer.write_all(b"hello ").await.expect("write succeeds"); + writer.write_all(b"world").await.expect("write succeeds"); + assert!(writer.sink().items.is_empty()); + + AsyncWriteExt::flush(&mut writer) + .await + .expect("flush succeeds"); + + assert_eq!( + writer.sink().items, + vec![Bytes::from_static(b"hello world")] + ); + assert_eq!(writer.sink().flushes, 1); + } + + #[tokio::test] + async fn async_write_flushes_existing_buffer_before_buffering_large_write() { + let mut writer = SinkWriter::new(RecordingSink::default()); + let large = vec![b'x'; 8 * 1024]; + + writer.write_all(b"small").await.expect("write succeeds"); + writer.write_all(&large).await.expect("write succeeds"); + + assert_eq!(writer.sink().items, vec![Bytes::from_static(b"small")]); + + AsyncWriteExt::flush(&mut writer) + .await + .expect("flush succeeds"); + + assert_eq!( + writer.sink().items, + vec![Bytes::from_static(b"small"), Bytes::from(large)] + ); + assert_eq!(writer.sink().flushes, 1); + } + + #[tokio::test] + async fn async_write_shutdown_flushes_buffer_and_closes_inner_sink() { + let mut writer = SinkWriter::new(RecordingSink::default()); + + writer.write_all(b"body").await.expect("write succeeds"); + writer.shutdown().await.expect("shutdown succeeds"); + + assert_eq!(writer.sink().items, vec![Bytes::from_static(b"body")]); + assert_eq!(writer.sink().closes, 1); + } + + #[tokio::test] + async fn sink_accessors_map_sink_and_into_inner_preserve_buffer_and_inner_state() { + let mut writer = SinkWriter::new(RecordingSink::default()); + writer.sink_mut().flushes = 7; + writer.write_all(b"buffered").await.expect("write succeeds"); + + let mut writer = writer.map_sink(|mut sink| { + sink.stream_id = Some(VarInt::from_u32(9)); + sink + }); + writer.flush_buffer().await.expect("flush succeeds"); + let inner = writer.into_inner(); + + assert_eq!(inner.items, vec![Bytes::from_static(b"buffered")]); + assert_eq!(inner.flushes, 7); + assert_eq!(inner.stream_id, Some(VarInt::from_u32(9))); + } + + #[test] + fn reset_stream_and_get_stream_id_forward_to_inner_sink() { + let mut writer = SinkWriter::new(RecordingSink { + stream_id: Some(VarInt::from_u32(33)), + ..RecordingSink::default() + }); + let waker = noop_waker(); + let mut cx = Context::from_waker(&waker); + + let reset_code = VarInt::from_u32(44); + assert!(matches!( + Pin::new(&mut writer).poll_reset(&mut cx, reset_code), + Poll::Ready(Ok(())) + )); + assert_eq!(writer.sink().reset_codes, vec![reset_code]); + assert!(matches!( + Pin::new(&mut writer).poll_stream_id(&mut cx), + Poll::Ready(Ok(id)) if id == VarInt::from_u32(33) + )); + } + + #[tokio::test] + async fn reset_discards_buffered_committed_item() { + let mut writer = SinkWriter::new(RecordingSink::default()); + let waker = noop_waker(); + let mut cx = Context::from_waker(&waker); + + assert!(matches!( + Pin::new(&mut writer).poll_ready(&mut cx), + Poll::Ready(Ok(())) + )); + Pin::new(&mut writer) + .start_send(Bytes::from_static(b"committed")) + .expect("start_send commits item into writer buffer"); + assert!( + writer.sink().items.is_empty(), + "item should still be buffered before flush" + ); + + let reset_code = VarInt::from_u32(45); + assert!(matches!( + Pin::new(&mut writer).poll_reset(&mut cx, reset_code), + Poll::Ready(Ok(())) + )); + SinkExt::flush(&mut writer) + .await + .expect("flush after reset should not send cleared item"); + + assert_eq!(writer.sink().reset_codes, vec![reset_code]); + assert!( + writer.sink().items.is_empty(), + "reset should supersede buffered send-side work" + ); + } + + #[tokio::test] + async fn feed_ready_sends_queued_item_after_pending_inner_readiness() { + let mut feed = Feed::new(RecordingSink { + ready_pending: 1, + ..RecordingSink::default() + }); + + Pin::new(&mut feed) + .start_send(Bytes::from_static(b"queued")) + .expect("queue item"); + Pin::new(&mut feed).ready().await.expect("ready succeeds"); + + assert_eq!(feed.sink.items, vec![Bytes::from_static(b"queued")]); + } + + #[tokio::test] + async fn feed_send_all_sends_each_stream_item_and_poll_close_closes_inner_sink() { + let mut feed = Feed::new(RecordingSink::default()); + + Pin::new(&mut feed) + .send_all(stream::iter([ + Bytes::from_static(b"one"), + Bytes::from_static(b"two"), + ])) + .await + .expect("send all succeeds"); + assert_eq!( + feed.sink.items, + vec![Bytes::from_static(b"one"), Bytes::from_static(b"two")] + ); + + Pin::new(&mut feed) + .start_send(Bytes::from_static(b"closing")) + .expect("queue close item"); + Pin::new(&mut feed).close().await.expect("close succeeds"); + + assert_eq!( + feed.sink.items, + vec![ + Bytes::from_static(b"one"), + Bytes::from_static(b"two"), + Bytes::from_static(b"closing") + ] + ); + assert_eq!(feed.sink.closes, 1); + } + + #[test] + #[should_panic(expected = "start_send called before poll_ready")] + fn feed_start_send_panics_when_previous_item_has_not_been_flushed() { + let mut feed = Feed::new(RecordingSink::default()); + + Pin::new(&mut feed) + .start_send(Bytes::from_static(b"first")) + .expect("queue first item"); + Pin::new(&mut feed) + .start_send(Bytes::from_static(b"second")) + .expect("second item panics first"); + } + + #[test] + fn feed_poll_ready_preserves_item_on_pending_and_error() { + let mut pending_feed = Feed::new(RecordingSink { + ready_pending: 1, + ..RecordingSink::default() + }); + Pin::new(&mut pending_feed) + .start_send(Bytes::from_static(b"pending")) + .expect("queue pending item"); + let waker = noop_waker(); + let mut cx = Context::from_waker(&waker); + + assert!(matches!( + Pin::new(&mut pending_feed).poll_ready(&mut cx), + Poll::Pending + )); + assert!(pending_feed.item.is_some()); + assert!(matches!( + Pin::new(&mut pending_feed).poll_ready(&mut cx), + Poll::Ready(Ok(())) + )); + assert_eq!( + pending_feed.sink.items, + vec![Bytes::from_static(b"pending")] + ); + + let mut error_feed = Feed::new(RecordingSink { + ready_error: Some(TestError("not ready")), + ..RecordingSink::default() + }); + Pin::new(&mut error_feed) + .start_send(Bytes::from_static(b"error")) + .expect("queue error item"); + + assert_eq!( + Pin::new(&mut error_feed).poll_ready(&mut cx), + Poll::Ready(Err(TestError("not ready"))) + ); + assert!(error_feed.item.is_some()); + assert!(error_feed.sink.items.is_empty()); + } + + #[tokio::test] + async fn feed_send_all_waits_for_stream_item_after_becoming_ready() { + let mut feed = Feed::<_, Bytes>::new(RecordingSink::default()); + let mut send_all = Box::pin(Pin::new(&mut feed).send_all(pending())); + let waker = noop_waker(); + let mut cx = Context::from_waker(&waker); + + assert!(matches!(send_all.as_mut().poll(&mut cx), Poll::Pending)); + drop(send_all); + + assert!(feed.sink.items.is_empty()); + } +} diff --git a/src/connection.rs b/src/connection.rs index f5e8703..d36028f 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -8,6 +8,7 @@ use std::{ sync::Arc, }; +use dhttp_identity::identity as authority; use futures::future::BoxFuture; use snafu::Snafu; use tokio::task::JoinHandle; @@ -20,7 +21,7 @@ use crate::{ error::{Code, H3ConnectionError, H3StreamError}, protocol::{IdentifiedProtocolInitializer, ProductProtocol, Protocols, StreamVerdict}, qpack::protocol::QPackProtocolFactory, - quic::{self, CancelStreamExt, StopStreamExt, agent}, + quic::{self, ResetStreamExt, StopStreamExt}, varint::VarInt, }; @@ -354,6 +355,12 @@ impl fmt::Display for ConnectionBuilder { } } +impl Default for ConnectionBuilder { + fn default() -> Self { + Self::new(Arc::new(Settings::default())) + } +} + impl ConnectionBuilder { pub fn new(settings: Arc) -> Self { let builder = Self { @@ -457,7 +464,7 @@ impl ConnectionState { /// Erase the concrete QUIC connection type, yielding a state that is /// usable as `ConnectionState`. /// - /// Used on the server path so that [`UnresolvedRequest`](crate::server::UnresolvedRequest) + /// Used on the server path so that [`UnresolvedRequest`](crate::endpoint::UnresolvedRequest) /// can carry a single type-erased connection handle regardless of the /// underlying QUIC implementation. #[must_use] @@ -510,22 +517,22 @@ impl ConnectionState { } } -impl ConnectionState { - pub async fn local_agent( +impl ConnectionState { + pub async fn local_authority( &self, - ) -> Result>, quic::ConnectionError> { + ) -> Result>, quic::ConnectionError> { // Goes through the object-safe trait so that this impl applies - // uniformly to both sized `C: WithLocalAgent` (via the blanket impl) + // uniformly to both sized `C: WithLocalAuthority` (via the blanket impl) // and `dyn DynConnection`. - quic::DynWithLocalAgent::local_agent(&*self.quic).await + quic::DynWithLocalAuthority::local_authority(&*self.quic).await } } -impl ConnectionState { - pub async fn remote_agent( +impl ConnectionState { + pub async fn remote_authority( &self, - ) -> Result>, quic::ConnectionError> { - quic::DynWithRemoteAgent::remote_agent(&*self.quic).await + ) -> Result>, quic::ConnectionError> { + quic::DynWithRemoteAuthority::remote_authority(&*self.quic).await } } @@ -541,9 +548,9 @@ impl ConnectionState { } }; let stream_reader = - StreamReader::new(Box::pin(reader) as crate::codec::BoxReadStream); + StreamReader::new(Box::pin(reader) as crate::quic::BoxQuicStreamReader); let stream_writer = - SinkWriter::new(Box::pin(writer) as crate::codec::BoxWriteStream); + SinkWriter::new(Box::pin(writer) as crate::quic::BoxQuicStreamWriter); let peekable_bi_stream = (PeekableStreamReader::new(stream_reader), stream_writer); match state.protocols.accept_bi(peekable_bi_stream).await { @@ -560,7 +567,7 @@ impl ConnectionState { // https://datatracker.ietf.org/doc/html/rfc9114#section-9-4 Ok(StreamVerdict::Passed((mut stream_reader, mut stream_writer))) => { let code = Code::H3_STREAM_CREATION_ERROR.into_inner(); - _ = tokio::join!(stream_reader.stop(code), stream_writer.cancel(code)) + _ = tokio::join!(stream_reader.stop(code), stream_writer.reset(code)) } Err(stream_error) => { // The stream has been consumed by protocol matching @@ -593,7 +600,7 @@ impl ConnectionState { } }; let stream_reader = - StreamReader::new(Box::pin(stream_reader) as crate::codec::BoxReadStream); + StreamReader::new(Box::pin(stream_reader) as crate::quic::BoxQuicStreamReader); let peekable_uni_stream = PeekableStreamReader::new(stream_reader); match state.protocols.accept_uni(peekable_uni_stream).await { @@ -691,32 +698,41 @@ pub(crate) mod tests { collections::hash_map::DefaultHasher, fmt, hash::{Hash, Hasher}, + marker::PhantomData, + }; + use std::{ + error::Error as _, + future::pending, + io, + pin::Pin, + sync::{Arc, Mutex}, }; - use std::{future::pending, pin::Pin, sync::Arc}; use bytes::Bytes; - use futures::{Sink, future::BoxFuture, stream::Stream}; + use dhttp_identity::identity as authority; + use futures::{Sink, SinkExt, future::BoxFuture, stream::Stream}; + use tracing::Instrument; #[cfg(feature = "dquic")] use super::ConnectionBuilder; - use super::ConnectionState; + use super::{Connection, ConnectionState, LifecycleExt, StreamError}; #[cfg(feature = "dquic")] use crate::{ - codec::{ErasedPeekableBiStream, ErasedPeekableUniStream}, - connection::StreamError, + codec::{BoxPeekableStreamReader, BoxStreamWriter}, dhttp::settings::{MaxFieldSectionSize, Settings}, protocol::{ProductProtocol, Protocol, StreamVerdict}, }; use crate::{ + error::{Code, H3MessageError, H3MissingSettings}, protocol::Protocols, - quic::{self, ConnectionError, agent}, + quic::{self, ConnectionError, ResetStreamExt, StopStreamExt}, varint::VarInt, }; #[derive(Debug)] - pub(crate) struct TestLocalAgent; + pub(crate) struct TestLocalAuthority; - impl agent::LocalAgent for TestLocalAgent { + impl authority::LocalAuthority for TestLocalAuthority { fn name(&self) -> &str { "test-local" } @@ -724,24 +740,15 @@ pub(crate) mod tests { fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { &[] } - - fn sign_algorithm(&self) -> rustls::SignatureAlgorithm { - rustls::SignatureAlgorithm::ED25519 - } - - fn sign( - &self, - _scheme: rustls::SignatureScheme, - _data: &[u8], - ) -> BoxFuture<'_, Result, agent::SignError>> { + fn sign(&self, _data: &[u8]) -> BoxFuture<'_, Result, authority::SignError>> { Box::pin(async { Ok(Vec::new()) }) } } #[derive(Debug)] - pub(crate) struct TestRemoteAgent; + pub(crate) struct TestRemoteAuthority; - impl agent::RemoteAgent for TestRemoteAgent { + impl authority::RemoteAuthority for TestRemoteAuthority { fn name(&self) -> &str { "test-remote" } @@ -800,8 +807,8 @@ pub(crate) mod tests { } } - impl quic::CancelStream for TestWriteStream { - fn poll_cancel( + impl quic::ResetStream for TestWriteStream { + fn poll_reset( self: Pin<&mut Self>, _cx: &mut std::task::Context, _code: VarInt, @@ -847,6 +854,9 @@ pub(crate) mod tests { #[derive(Debug, Default)] pub(crate) struct MockConnectionState { terminal_error: crate::util::set_once::SetOnce, + close_calls: Mutex>, + stream_calls: Mutex>, + stream_ops_available: std::sync::atomic::AtomicBool, } #[derive(Debug, Clone, Default)] @@ -862,6 +872,48 @@ pub(crate) mod tests { pub(crate) fn set_terminal_error(&self, error: quic::ConnectionError) { let _ = self.state.terminal_error.set(error); } + + pub(crate) fn enable_stream_ops(&self) { + self.state + .stream_ops_available + .store(true, std::sync::atomic::Ordering::Relaxed); + } + + pub(crate) fn disable_stream_ops(&self) { + self.state + .stream_ops_available + .store(false, std::sync::atomic::Ordering::Relaxed); + } + + pub(crate) fn close_calls(&self) -> Vec<(Code, String)> { + self.state + .close_calls + .lock() + .expect("close call log poisoned") + .clone() + } + + pub(crate) fn stream_calls(&self) -> Vec<&'static str> { + self.state + .stream_calls + .lock() + .expect("stream call log poisoned") + .clone() + } + + fn record_stream_call(&self, call: &'static str) { + self.state + .stream_calls + .lock() + .expect("stream call log poisoned") + .push(call); + } + + fn stream_ops_available(&self) -> bool { + self.state + .stream_ops_available + .load(std::sync::atomic::Ordering::Relaxed) + } } impl quic::ManageStream for MockConnection { @@ -871,42 +923,68 @@ pub(crate) mod tests { async fn open_bi( &self, ) -> Result<(Self::StreamReader, Self::StreamWriter), ConnectionError> { - Err(test_connection_error("open_bi unavailable")) + self.record_stream_call("open_bi"); + if self.stream_ops_available() { + Ok((TestReadStream, TestWriteStream)) + } else { + Err(test_connection_error("open_bi unavailable")) + } } async fn open_uni(&self) -> Result { - Err(test_connection_error("open_uni unavailable")) + self.record_stream_call("open_uni"); + if self.stream_ops_available() { + Ok(TestWriteStream) + } else { + Err(test_connection_error("open_uni unavailable")) + } } async fn accept_bi( &self, ) -> Result<(Self::StreamReader, Self::StreamWriter), ConnectionError> { - Err(test_connection_error("accept_bi unavailable")) + self.record_stream_call("accept_bi"); + if self.stream_ops_available() { + Ok((TestReadStream, TestWriteStream)) + } else { + Err(test_connection_error("accept_bi unavailable")) + } } async fn accept_uni(&self) -> Result { - Err(test_connection_error("accept_uni unavailable")) + self.record_stream_call("accept_uni"); + if self.stream_ops_available() { + Ok(TestReadStream) + } else { + Err(test_connection_error("accept_uni unavailable")) + } } } - impl quic::WithLocalAgent for MockConnection { - type LocalAgent = TestLocalAgent; + impl quic::WithLocalAuthority for MockConnection { + type LocalAuthority = TestLocalAuthority; - async fn local_agent(&self) -> Result, ConnectionError> { - Ok(None) + async fn local_authority(&self) -> Result, ConnectionError> { + Ok(Some(TestLocalAuthority)) } } - impl quic::WithRemoteAgent for MockConnection { - type RemoteAgent = TestRemoteAgent; + impl quic::WithRemoteAuthority for MockConnection { + type RemoteAuthority = TestRemoteAuthority; - async fn remote_agent(&self) -> Result, ConnectionError> { - Ok(None) + async fn remote_authority(&self) -> Result, ConnectionError> { + Ok(Some(TestRemoteAuthority)) } } impl quic::Lifecycle for MockConnection { - fn close(&self, _code: crate::error::Code, _reason: std::borrow::Cow<'static, str>) {} + fn close(&self, code: crate::error::Code, reason: std::borrow::Cow<'static, str>) { + self.state + .close_calls + .lock() + .expect("close call log poisoned") + .push((code, reason.into_owned())); + } fn check(&self) -> Result<(), ConnectionError> { match self.state.terminal_error.peek() { @@ -942,6 +1020,27 @@ pub(crate) mod tests { } } + fn assert_connection_h3_code(error: super::ConnectionError, expected_code: Code) { + match error { + super::ConnectionError::H3 { source } => assert_eq!(source.code(), expected_code), + other => panic!("expected h3 connection error, got {other:?}"), + } + } + + fn assert_stream_reset(error: StreamError, expected_code: VarInt) { + match error { + StreamError::Reset { code } => assert_eq!(code, expected_code), + other => panic!("expected stream reset, got {other:?}"), + } + } + + fn assert_stream_h3_code(error: StreamError, expected_code: Code) { + match error { + StreamError::H3 { source } => assert_eq!(source.code(), expected_code), + other => panic!("expected h3 stream error, got {other:?}"), + } + } + #[cfg(feature = "dquic")] fn hash_of(val: &T) -> u64 { let mut hasher = DefaultHasher::new(); @@ -958,15 +1057,18 @@ pub(crate) mod tests { impl Protocol for MockProtocol { fn accept_uni<'a>( &'a self, - stream: ErasedPeekableUniStream, - ) -> BoxFuture<'a, Result, StreamError>> { + stream: BoxPeekableStreamReader, + ) -> BoxFuture<'a, Result, StreamError>> { Box::pin(async move { Ok(StreamVerdict::Passed(stream)) }) } fn accept_bi<'a>( &'a self, - stream: ErasedPeekableBiStream, - ) -> BoxFuture<'a, Result, StreamError>> { + stream: (BoxPeekableStreamReader, BoxStreamWriter), + ) -> BoxFuture< + 'a, + Result, StreamError>, + > { Box::pin(async move { Ok(StreamVerdict::Passed(stream)) }) } } @@ -992,25 +1094,181 @@ pub(crate) mod tests { _: &'a Arc, _: &'a Protocols, ) -> BoxFuture<'a, Result> { - unimplemented!("not used in builder identity tests") + Box::pin(async { Ok(MockProtocol) }) + } + } + + #[cfg(feature = "dquic")] + #[derive(Debug)] + struct PassThenDisableProtocol { + quic: MockConnection, + } + + #[cfg(feature = "dquic")] + impl Protocol for PassThenDisableProtocol { + fn accept_uni<'a>( + &'a self, + stream: BoxPeekableStreamReader, + ) -> BoxFuture<'a, Result, StreamError>> { + self.quic.disable_stream_ops(); + Box::pin(async move { Ok(StreamVerdict::Passed(stream)) }) + } + + fn accept_bi<'a>( + &'a self, + stream: (BoxPeekableStreamReader, BoxStreamWriter), + ) -> BoxFuture< + 'a, + Result, StreamError>, + > { + self.quic.disable_stream_ops(); + Box::pin(async move { Ok(StreamVerdict::Passed(stream)) }) + } + } + + #[cfg(feature = "dquic")] + #[derive(Debug, Clone, Copy)] + enum ErrKind { + Connection, + Reset, + } + + #[cfg(feature = "dquic")] + #[derive(Debug)] + struct ErrThenDisableProtocol { + quic: MockConnection, + kind: ErrKind, + } + + #[cfg(feature = "dquic")] + impl ErrThenDisableProtocol { + fn make_error(&self) -> StreamError { + match self.kind { + ErrKind::Connection => StreamError::from(H3MissingSettings), + ErrKind::Reset => StreamError::Reset { + code: VarInt::from_u32(7), + }, + } + } + } + + #[cfg(feature = "dquic")] + impl Protocol for ErrThenDisableProtocol { + fn accept_uni<'a>( + &'a self, + _stream: BoxPeekableStreamReader, + ) -> BoxFuture<'a, Result, StreamError>> { + self.quic.disable_stream_ops(); + let err = self.make_error(); + Box::pin(async move { Err(err) }) + } + + fn accept_bi<'a>( + &'a self, + _stream: (BoxPeekableStreamReader, BoxStreamWriter), + ) -> BoxFuture< + 'a, + Result, StreamError>, + > { + self.quic.disable_stream_ops(); + let err = self.make_error(); + Box::pin(async move { Err(err) }) } } #[cfg(feature = "dquic")] type C = dquic::prelude::Connection; + /// Hash equality and determinism: identical inputs must produce equal hashes. #[cfg(feature = "dquic")] #[test] - fn builder_same_settings_equal_hash() { - let s = Arc::new(Settings::default()); - let a = ConnectionBuilder::::new(s.clone()); - let b = ConnectionBuilder::::new(s); - assert_eq!(hash_of(&a), hash_of(&b)); + fn hash_equality_and_determinism() { + let s = || Arc::new(Settings::default()); + + // Build a reference builder for the "same builder twice" determinism check. + let det_builder = ConnectionBuilder::::new(s()).protocol(MockFactory(99)); + let h_det = hash_of(&det_builder); + + // All cases where hash must be equal (rebuild from identical inputs). + let cases: [(&str, ConnectionBuilder, ConnectionBuilder); 3] = [ + ( + "same settings", + ConnectionBuilder::::new(s()), + ConnectionBuilder::::new(s()), + ), + ( + "same protocol", + ConnectionBuilder::::new(s()).protocol(MockFactory(42)), + ConnectionBuilder::::new(s()).protocol(MockFactory(42)), + ), + ( + "clone-like rebuild", + ConnectionBuilder::::new(s()).protocol(MockFactory(100)), + ConnectionBuilder::::new(s()).protocol(MockFactory(100)), + ), + ]; + + for (name, a, b) in &cases { + assert_eq!(hash_of(a), hash_of(b), "hash equality: {name}"); + } + + // Same builder hashed twice must be deterministic. + assert_eq!( + hash_of(&det_builder), + h_det, + "hashing the same builder twice must be deterministic" + ); + + // An identically-constructed builder must produce the same hash. + let builder2 = ConnectionBuilder::::new(s()).protocol(MockFactory(99)); + assert_eq!( + h_det, + hash_of(&builder2), + "identical builders must hash equally" + ); + } + + #[cfg(feature = "dquic")] + #[test] + fn display_lists_each_initializer_separated_by_commas() { + let s = || Arc::new(Settings::default()); + + let empty = ConnectionBuilder:: { + initializers: Vec::new(), + _connection: PhantomData, + }; + assert_eq!(format!("{}", empty), "ConnectionBuilder[]"); + + let one = ConnectionBuilder:: { + initializers: Vec::new(), + _connection: PhantomData, + } + .protocol(MockFactory(1)); + assert_eq!(format!("{}", one), "ConnectionBuilder[MockFactory]"); + + let two = ConnectionBuilder:: { + initializers: Vec::new(), + _connection: PhantomData, + } + .protocol(MockFactory(1)) + .protocol(MockFactory(2)); + assert_eq!( + format!("{}", two), + "ConnectionBuilder[MockFactory, MockFactory]" + ); + + let default_with_extra = ConnectionBuilder::::new(s()).protocol(MockFactory(99)); + let rendered = format!("{}", default_with_extra); + assert!(rendered.starts_with("ConnectionBuilder[")); + assert!(rendered.ends_with(']')); + assert!(rendered.contains("MockFactory")); + assert_eq!(rendered.matches(", ").count(), 2); } + /// Different settings produce different hashes. #[cfg(feature = "dquic")] #[test] - fn builder_different_settings_different_hash() { + fn hash_ne_different_settings() { let s1 = Arc::new(Settings::default()); let mut s2_inner = Settings::default(); s2_inner.set(MaxFieldSectionSize::setting(VarInt::from_u32(9999))); @@ -1020,18 +1278,39 @@ pub(crate) mod tests { assert_ne!(hash_of(&a), hash_of(&b)); } + /// Different protocol stacks produce different hashes. #[cfg(feature = "dquic")] #[test] - fn builder_extra_protocol_different_hash() { - let s = Arc::new(Settings::default()); - let a = ConnectionBuilder::::new(s.clone()); - let b = ConnectionBuilder::::new(s).protocol(MockFactory(42)); - assert_ne!(hash_of(&a), hash_of(&b)); + fn hash_ne_different_protocols() { + let s = || Arc::new(Settings::default()); + + let cases: [(&str, ConnectionBuilder, ConnectionBuilder); 3] = [ + ( + "extra protocol", + ConnectionBuilder::::new(s()), + ConnectionBuilder::::new(s()).protocol(MockFactory(42)), + ), + ( + "different protocol value", + ConnectionBuilder::::new(s()).protocol(MockFactory(1)), + ConnectionBuilder::::new(s()).protocol(MockFactory(2)), + ), + ( + "mock factory included in hash", + ConnectionBuilder::::new(s()), + ConnectionBuilder::::new(s()).protocol(MockFactory(42)), + ), + ]; + + for (name, a, b) in &cases { + assert_ne!(hash_of(a), hash_of(b), "hash inequality: {name}"); + } } + /// Different protocol ordering produces different hashes. #[cfg(feature = "dquic")] #[test] - fn builder_different_order_different_hash() { + fn hash_ne_different_order() { let s = Arc::new(Settings::default()); // a: DHttpProtocolFactory, QPackProtocolFactory, MockFactory (from new + protocol) let a = ConnectionBuilder::::new(s.clone()).protocol(MockFactory(7)); @@ -1070,76 +1349,644 @@ pub(crate) mod tests { assert_ne!(a, b); } - /// Simulates pool key distinction: two builders with different protocol stacks - /// produce different hashes, so the pool stores them under separate keys. #[cfg(feature = "dquic")] #[test] - fn pool_key_different_builders_different_hash() { - let s = Arc::new(Settings::default()); - let a = ConnectionBuilder::::new(s.clone()).protocol(MockFactory(1)); - let b = ConnectionBuilder::::new(s).protocol(MockFactory(2)); - // Pool computes hash via DefaultHasher the same way hash_of does. - // Different protocol stacks must yield different keys. - assert_ne!(hash_of(&a), hash_of(&b)); + fn builder_display_and_debug_list_protocol_initializers() { + let builder = + ConnectionBuilder::::new(Arc::new(Settings::default())).protocol(MockFactory(7)); + + assert_eq!( + builder.to_string(), + "ConnectionBuilder[DHTTP/3, QPACK, MockFactory]" + ); + let debug = format!("{builder:?}"); + assert!(debug.contains("ConnectionBuilder")); + assert!(debug.contains("DHttpProtocolFactory")); + assert!(debug.contains("QPackProtocolFactory")); + assert!(debug.contains("MockFactory")); + } + + #[tokio::test] + async fn builder_closes_quic_when_initial_protocol_init_fails() { + let quic = Arc::new(MockConnection::new()); + let result = ConnectionBuilder::new(Arc::new(Settings::default())) + .build(quic.clone()) + .await; + + let error = result.expect_err("open_uni failure should abort connection build"); + assert_transport_reason(&error, "open_uni unavailable"); + assert_eq!( + quic.close_calls(), + vec![(Code::H3_NO_ERROR, "h3 build aborted".to_owned())] + ); + } + + #[tokio::test] + async fn builder_success_initializes_dhttp_and_qpack_protocols() { + let quic = Arc::new(MockConnection::new()); + quic.enable_stream_ops(); + let settings = Arc::new(Settings::default()); + + let connection = ConnectionBuilder::new(settings.clone()) + .build(quic.clone()) + .await + .expect("builder should initialize built-in protocols"); + + assert!( + connection + .protocol::() + .is_some() + ); + assert!(connection.qpack().is_ok()); + assert!(Arc::ptr_eq(&connection.settings(), &settings)); + assert_eq!(quic.stream_calls()[0], "open_uni"); + + drop(connection); + assert!( + quic.close_calls() + .iter() + .any(|(code, reason)| *code == Code::H3_NO_ERROR && reason == "no error") + ); } - /// Simulates pool key reuse: two identical builders produce the same hash, - /// so the pool correctly groups them under one key for connection reuse. #[cfg(feature = "dquic")] + #[tokio::test] + async fn builder_initializes_custom_protocol_factory() { + let quic = Arc::new(MockConnection::new()); + quic.enable_stream_ops(); + + let connection = ConnectionBuilder::new(Arc::new(Settings::default())) + .protocol(MockFactory(7)) + .build(quic) + .await + .expect("custom protocol factory should initialize"); + + assert!(connection.protocol::().is_some()); + } + #[test] - fn pool_key_same_builders_same_hash() { - let s = Arc::new(Settings::default()); - let a = ConnectionBuilder::::new(s.clone()).protocol(MockFactory(42)); - let b = ConnectionBuilder::::new(s).protocol(MockFactory(42)); - assert_eq!(hash_of(&a), hash_of(&b)); + fn state_accessors_return_underlying_quic_and_protocol_registry() { + let quic = Arc::new(MockConnection::new()); + let protocols = Arc::new(Protocols::new()); + let state = ConnectionState::new_for_test(quic.clone(), protocols.clone()); + + assert!(Arc::ptr_eq(state.quic(), &quic)); + assert!(Arc::ptr_eq(state.protocols(), &protocols)); + assert!(state.protocol::().is_none()); + assert!(format!("{state:?}").contains("ConnectionState")); + } + + #[tokio::test] + async fn state_open_stream_helpers_delegate_success_and_error_paths() { + let quic = MockConnection::new(); + let state = + ConnectionState::new_for_test(Arc::new(quic.clone()), Arc::new(Protocols::new())); + + let bi_error = state + .open_bi() + .await + .expect_err("open_bi should fail before enabled"); + assert_transport_reason(&bi_error, "open_bi unavailable"); + let uni_error = state + .open_uni() + .await + .expect_err("open_uni should fail before enabled"); + assert_transport_reason(&uni_error, "open_uni unavailable"); + + quic.enable_stream_ops(); + state + .open_bi() + .await + .expect("open_bi should delegate success"); + state + .open_uni() + .await + .expect("open_uni should delegate success"); + + assert_eq!( + quic.stream_calls(), + vec!["open_bi", "open_uni", "open_bi", "open_uni"] + ); + } + + #[tokio::test] + async fn local_and_remote_authority_accessors_delegate_on_concrete_state() { + let quic = MockConnection::new(); + let state = ConnectionState::new_for_test(Arc::new(quic), Arc::new(Protocols::new())); + + let local = state + .local_authority() + .await + .expect("local authority lookup should succeed") + .expect("local authority should be present"); + assert_eq!(local.name(), "test-local"); + + let remote = state + .remote_authority() + .await + .expect("remote authority lookup should succeed") + .expect("remote authority should be present"); + assert_eq!(remote.name(), "test-remote"); + } + + #[tokio::test] + async fn accept_tasks_handle_accept_errors_without_closing_again() { + let bi_quic = MockConnection::new(); + let bi_state = + ConnectionState::new_for_test(Arc::new(bi_quic.clone()), Arc::new(Protocols::new())); + ConnectionState::accept_bi_stream_task(bi_state).await; + assert_eq!(bi_quic.stream_calls(), vec!["accept_bi"]); + assert!(bi_quic.close_calls().is_empty()); + + let uni_quic = MockConnection::new(); + let uni_state = + ConnectionState::new_for_test(Arc::new(uni_quic.clone()), Arc::new(Protocols::new())); + ConnectionState::accept_uni_stream_task(uni_state).await; + assert_eq!(uni_quic.stream_calls(), vec!["accept_uni"]); + assert!(uni_quic.close_calls().is_empty()); + } + + #[cfg(feature = "dquic")] + #[tokio::test] + async fn accept_tasks_pass_unknown_streams_to_stop_and_reset_paths() { + let bi_quic = MockConnection::new(); + bi_quic.enable_stream_ops(); + let bi_protocols = { + let mut protocols = Protocols::new(); + protocols.insert(PassThenDisableProtocol { + quic: bi_quic.clone(), + }); + Arc::new(protocols) + }; + let bi_state = ConnectionState::new_for_test(Arc::new(bi_quic.clone()), bi_protocols); + ConnectionState::accept_bi_stream_task(bi_state).await; + assert_eq!(bi_quic.stream_calls(), vec!["accept_bi", "accept_bi"]); + + let uni_quic = MockConnection::new(); + uni_quic.enable_stream_ops(); + let uni_protocols = { + let mut protocols = Protocols::new(); + protocols.insert(PassThenDisableProtocol { + quic: uni_quic.clone(), + }); + Arc::new(protocols) + }; + let uni_state = ConnectionState::new_for_test(Arc::new(uni_quic.clone()), uni_protocols); + ConnectionState::accept_uni_stream_task(uni_state).await; + assert_eq!(uni_quic.stream_calls(), vec!["accept_uni", "accept_uni"]); + } + + #[cfg(feature = "dquic")] + #[tokio::test] + async fn accept_bi_task_calls_handle_connection_error_for_connection_scope_errors() { + let bi_quic = MockConnection::new(); + bi_quic.enable_stream_ops(); + + let protocols = { + let mut protocols = Protocols::new(); + protocols.insert(ErrThenDisableProtocol { + quic: bi_quic.clone(), + kind: ErrKind::Connection, + }); + Arc::new(protocols) + }; + let state = ConnectionState::new_for_test(Arc::new(bi_quic.clone()), protocols); + let task = tokio::spawn(ConnectionState::accept_bi_stream_task(state).in_current_span()); + + tokio::time::timeout(std::time::Duration::from_secs(1), async { + loop { + if !bi_quic.close_calls().is_empty() { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("connection-scope error should trigger close"); + + bi_quic.set_terminal_error(test_connection_error("bi terminal")); + tokio::time::timeout(std::time::Duration::from_secs(1), task) + .await + .expect("accept_bi task should terminate") + .expect("task should not panic"); + + assert!(bi_quic.stream_calls().contains(&"accept_bi")); + assert_eq!(bi_quic.close_calls().len(), 1); + assert_eq!(bi_quic.close_calls()[0].0, Code::H3_MISSING_SETTINGS); + } + + #[cfg(feature = "dquic")] + #[tokio::test] + async fn accept_uni_task_calls_handle_connection_error_for_connection_scope_errors() { + let uni_quic = MockConnection::new(); + uni_quic.enable_stream_ops(); + + let protocols = { + let mut protocols = Protocols::new(); + protocols.insert(ErrThenDisableProtocol { + quic: uni_quic.clone(), + kind: ErrKind::Connection, + }); + Arc::new(protocols) + }; + let state = ConnectionState::new_for_test(Arc::new(uni_quic.clone()), protocols); + let task = tokio::spawn(ConnectionState::accept_uni_stream_task(state).in_current_span()); + + tokio::time::timeout(std::time::Duration::from_secs(1), async { + loop { + if !uni_quic.close_calls().is_empty() { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("connection-scope error should trigger close"); + + uni_quic.set_terminal_error(test_connection_error("uni terminal")); + tokio::time::timeout(std::time::Duration::from_secs(1), task) + .await + .expect("accept_uni task should terminate") + .expect("task should not panic"); + + assert!(uni_quic.stream_calls().contains(&"accept_uni")); + assert_eq!(uni_quic.close_calls().len(), 1); + assert_eq!(uni_quic.close_calls()[0].0, Code::H3_MISSING_SETTINGS); + } + + #[cfg(feature = "dquic")] + #[tokio::test] + async fn accept_bi_task_skips_close_for_stream_scope_errors() { + let bi_quic = MockConnection::new(); + bi_quic.enable_stream_ops(); + + let protocols = { + let mut protocols = Protocols::new(); + protocols.insert(ErrThenDisableProtocol { + quic: bi_quic.clone(), + kind: ErrKind::Reset, + }); + Arc::new(protocols) + }; + let state = ConnectionState::new_for_test(Arc::new(bi_quic.clone()), protocols); + let task = tokio::spawn(ConnectionState::accept_bi_stream_task(state).in_current_span()); + + tokio::time::timeout(std::time::Duration::from_secs(1), async { + loop { + if bi_quic.stream_calls().contains(&"accept_bi") { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("task should call accept_bi"); + + // give the stream-scope error path a chance to execute + for _ in 0..32 { + tokio::task::yield_now().await; + } + + bi_quic.set_terminal_error(test_connection_error("bi reset terminal")); + tokio::time::timeout(std::time::Duration::from_secs(1), task) + .await + .expect("accept_bi task should terminate") + .expect("task should not panic"); + + assert!( + bi_quic.close_calls().is_empty(), + "stream-scope errors must not trigger connection close" + ); } - /// Verifies hash determinism: hashing the same builder twice yields the same - /// value, and an identically-constructed builder also matches. #[cfg(feature = "dquic")] + #[tokio::test] + async fn accept_uni_task_skips_close_for_stream_scope_errors() { + let uni_quic = MockConnection::new(); + uni_quic.enable_stream_ops(); + + let protocols = { + let mut protocols = Protocols::new(); + protocols.insert(ErrThenDisableProtocol { + quic: uni_quic.clone(), + kind: ErrKind::Reset, + }); + Arc::new(protocols) + }; + let state = ConnectionState::new_for_test(Arc::new(uni_quic.clone()), protocols); + let task = tokio::spawn(ConnectionState::accept_uni_stream_task(state).in_current_span()); + + tokio::time::timeout(std::time::Duration::from_secs(1), async { + loop { + if uni_quic.stream_calls().contains(&"accept_uni") { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("task should call accept_uni"); + + for _ in 0..32 { + tokio::task::yield_now().await; + } + + uni_quic.set_terminal_error(test_connection_error("uni reset terminal")); + tokio::time::timeout(std::time::Duration::from_secs(1), task) + .await + .expect("accept_uni task should terminate") + .expect("task should not panic"); + + assert!(uni_quic.close_calls().is_empty()); + } + + #[tokio::test] + async fn lifecycle_ext_h3_error_closes_then_returns_terminal_error() { + let quic = MockConnection::new(); + let quic_for_task = quic.clone(); + let task = tokio::spawn( + async move { + quic_for_task + .handle_connection_error(super::ConnectionError::from(H3MissingSettings)) + .await + } + .in_current_span(), + ); + + tokio::time::timeout(std::time::Duration::from_secs(1), async { + loop { + if !quic.close_calls().is_empty() { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("h3 error should close promptly"); + + assert_eq!( + quic.close_calls(), + vec![( + Code::H3_MISSING_SETTINGS, + "no SETTINGS frame at beginning of control stream".to_owned() + )] + ); + + quic.set_terminal_error(test_connection_error("after h3 close")); + let error = tokio::time::timeout(std::time::Duration::from_secs(1), task) + .await + .expect("closed should resolve") + .expect("task should not panic"); + assert_transport_reason(&error, "after h3 close"); + } + + #[tokio::test] + async fn lifecycle_ext_h3_error_returns_existing_terminal_error_without_closing() { + let quic = MockConnection::new(); + quic.set_terminal_error(test_connection_error("already closed")); + + let error = quic + .handle_connection_error(super::ConnectionError::from(H3MissingSettings)) + .await; + + assert_transport_reason(&error, "already closed"); + assert!(quic.close_calls().is_empty()); + } + #[test] - fn builder_hash_determinism() { - let s = Arc::new(Settings::default()); - let builder = ConnectionBuilder::::new(s.clone()).protocol(MockFactory(99)); - let h1 = hash_of(&builder); - let h2 = hash_of(&builder); + fn stream_error_from_connection_scope_h3_registry_variant() { + let error = StreamError::from(H3MissingSettings); + + match error { + StreamError::Connection { source } => { + assert_connection_h3_code(source, Code::H3_MISSING_SETTINGS); + } + other => panic!("expected connection-scope stream error, got {other:?}"), + } + } + + #[test] + fn h3_error_display_and_sources_are_layered() { + let connection_error = super::ConnectionError::from(H3MissingSettings); assert_eq!( - h1, h2, - "hashing the same builder twice must be deterministic" + connection_error.to_string(), + "h3 connection-scope protocol error" + ); + assert_eq!( + connection_error + .source() + .expect("h3 connection source") + .to_string(), + "no SETTINGS frame at beginning of control stream" ); - // A second, identically-constructed builder must produce the same hash. - let builder2 = ConnectionBuilder::::new(s).protocol(MockFactory(99)); + let stream_error = StreamError::from(H3MessageError::MissingHeaderSection); + assert_eq!(stream_error.to_string(), "h3 stream-scope protocol error"); assert_eq!( - h1, - hash_of(&builder2), - "identical builders must hash equally" + stream_error.source().expect("h3 stream source").to_string(), + "missing header section in HTTP message" ); } - #[cfg(feature = "dquic")] #[test] - fn builder_clone_like_rebuild_hash_determinism() { - let s = Arc::new(Settings::default()); - let a = ConnectionBuilder::::new(s.clone()).protocol(MockFactory(100)); - let b = ConnectionBuilder::::new(s).protocol(MockFactory(100)); - assert_eq!( - hash_of(&a), - hash_of(&b), - "builders built via same steps must hash identically" + fn stream_error_h3_io_roundtrip_preserves_source() { + let io_error = io::Error::from(StreamError::from(H3MessageError::MissingHeaderSection)); + assert_eq!(io_error.kind(), io::ErrorKind::Other); + + let recovered = StreamError::from(io_error); + assert_stream_h3_code(recovered, Code::H3_MESSAGE_ERROR); + } + + #[test] + fn connection_error_recovery_rejects_untyped_io_error() { + let result = std::panic::catch_unwind(|| { + let _ = super::ConnectionError::from(io::Error::other("opaque")); + }); + + assert!(result.is_err()); + } + + #[test] + fn stream_error_recovery_rejects_untyped_io_error() { + let result = std::panic::catch_unwind(|| { + let _ = StreamError::from(io::Error::other("opaque")); + }); + + assert!(result.is_err()); + } + + #[test] + fn stream_error_map_stream_reset_maps_only_reset_code() { + let reset_code = VarInt::from_u32(0x123); + let mapped = StreamError::Reset { code: reset_code }.map_stream_reset(|code| { + assert_eq!(code, reset_code); + StreamError::from(H3MessageError::MissingHeaderSection) + }); + + assert_stream_h3_code(mapped, Code::H3_MESSAGE_ERROR); + } + + #[test] + fn connection_error_recovers_from_io_error_layers() { + let connection_error = super::ConnectionError::from(test_connection_error("wrapped")); + let recovered = super::ConnectionError::from(io::Error::from(connection_error)); + match recovered { + super::ConnectionError::Quic { source } => assert_transport_reason(&source, "wrapped"), + other => panic!("expected quic connection error, got {other:?}"), + } + + let quic_error = test_connection_error("quic"); + let recovered = super::ConnectionError::from(io::Error::from(quic_error)); + match recovered { + super::ConnectionError::Quic { source } => assert_transport_reason(&source, "quic"), + other => panic!("expected quic connection error, got {other:?}"), + } + + let h3_error = super::ConnectionError::from(H3MissingSettings); + let super::ConnectionError::H3 { source } = h3_error else { + panic!("expected h3 connection error"); + }; + let h3_source: Arc = source; + let recovered = super::ConnectionError::from(io::Error::other(h3_source)); + assert_connection_h3_code(recovered, Code::H3_MISSING_SETTINGS); + } + + #[test] + fn stream_error_recovers_from_io_error_layers() { + let reset_code = VarInt::from_u32(0x41); + let reset = StreamError::Reset { code: reset_code }; + assert_stream_reset(StreamError::from(io::Error::from(reset)), reset_code); + + let quic_reset_code = VarInt::from_u32(0x42); + let quic_reset = quic::StreamError::Reset { + code: quic_reset_code, + }; + assert_stream_reset( + StreamError::from(io::Error::from(quic_reset)), + quic_reset_code, ); + + let connection_error = super::ConnectionError::from(test_connection_error("stream")); + let recovered = StreamError::from(io::Error::from(connection_error)); + match recovered { + StreamError::Connection { + source: super::ConnectionError::Quic { source }, + } => assert_transport_reason(&source, "stream"), + other => panic!("expected stream connection error, got {other:?}"), + } + + let h3 = StreamError::from(H3MessageError::MissingHeaderSection); + let StreamError::H3 { source } = h3 else { + panic!("expected h3 stream error"); + }; + let h3_source: Arc = source; + let recovered = StreamError::from(io::Error::other(h3_source)); + assert_stream_h3_code(recovered, Code::H3_MESSAGE_ERROR); } - #[cfg(feature = "dquic")] #[test] - fn builder_mock_factory_included_in_hash() { - let s = Arc::new(Settings::default()); - let base = ConnectionBuilder::::new(s.clone()); - let with_mock = ConnectionBuilder::::new(s).protocol(MockFactory(42)); - assert_ne!( - hash_of(&base), - hash_of(&with_mock), - "adding MockFactory must change the hash" + fn map_stream_reset_leaves_non_reset_errors_unchanged() { + let mut mapper_called = false; + let error = StreamError::from(H3MessageError::UnexpectedHeadersInBody); + let mapped = error.map_stream_reset(|code| { + mapper_called = true; + StreamError::Reset { code } + }); + + assert!(!mapper_called); + assert_stream_h3_code(mapped, Code::H3_MESSAGE_ERROR); + } + + #[tokio::test] + async fn test_stream_helpers_cover_stop_reset_and_close() { + let mut reader = TestReadStream; + reader + .stop(VarInt::from_u32(0x103)) + .await + .expect("test reader stop should succeed"); + + let mut writer = TestWriteStream; + writer + .reset(VarInt::from_u32(0x103)) + .await + .expect("test writer reset should succeed"); + writer + .close() + .await + .expect("test writer close should succeed"); + } + + #[tokio::test] + async fn erased_connection_state_delegates_dyn_connection_operations() { + let quic = MockConnection::new(); + quic.enable_stream_ops(); + let state = + ConnectionState::new_for_test(Arc::new(quic.clone()), Arc::new(Protocols::new())); + let erased = state.erase(); + + assert!(Arc::ptr_eq(state.protocols(), erased.protocols())); + assert_eq!( + erased + .local_authority() + .await + .expect("local authority delegated") + .expect("local authority present") + .name(), + "test-local" + ); + assert_eq!( + erased + .remote_authority() + .await + .expect("remote authority delegated") + .expect("remote authority present") + .name(), + "test-remote" + ); + assert!(erased.check().is_ok()); + + let _ = quic::DynManageStream::open_bi(&**erased.quic()) + .await + .expect("open_bi delegated"); + let _ = quic::DynManageStream::open_uni(&**erased.quic()) + .await + .expect("open_uni delegated"); + let _ = quic::DynManageStream::accept_bi(&**erased.quic()) + .await + .expect("accept_bi delegated"); + let _ = quic::DynManageStream::accept_uni(&**erased.quic()) + .await + .expect("accept_uni delegated"); + + assert_eq!( + quic.stream_calls(), + vec!["open_bi", "open_uni", "accept_bi", "accept_uni"] + ); + + erased.close(Code::H3_NO_ERROR, "dyn close"); + assert_eq!( + quic.close_calls(), + vec![(Code::H3_NO_ERROR, "dyn close".to_owned())] + ); + + quic.set_terminal_error(test_connection_error("dyn closed")); + let closed = erased.closed().await; + assert_transport_reason(&closed, "dyn closed"); + } + + #[tokio::test] + async fn connection_from_state_for_test_closes_on_drop() { + let quic = MockConnection::new(); + let state = + ConnectionState::new_for_test(Arc::new(quic.clone()), Arc::new(Protocols::new())); + let connection = Connection::from_state_for_test(state); + + assert!(quic.close_calls().is_empty()); + drop(connection); + + assert_eq!( + quic.close_calls(), + vec![(Code::H3_NO_ERROR, "no error".to_owned())] ); } diff --git a/src/dhttp.rs b/src/dhttp.rs index 2ef55f7..3572443 100644 --- a/src/dhttp.rs +++ b/src/dhttp.rs @@ -1,5 +1,8 @@ +pub mod datagram; pub mod frame; pub mod goaway; +pub mod message; pub mod protocol; pub mod settings; pub mod stream; +pub mod webtransport; diff --git a/src/dhttp/datagram.rs b/src/dhttp/datagram.rs new file mode 100644 index 0000000..6e98cef --- /dev/null +++ b/src/dhttp/datagram.rs @@ -0,0 +1 @@ +pub mod settings; diff --git a/src/dhttp/datagram/settings.rs b/src/dhttp/datagram/settings.rs new file mode 100644 index 0000000..562eefc --- /dev/null +++ b/src/dhttp/datagram/settings.rs @@ -0,0 +1,72 @@ +use crate::{ + dhttp::settings::{Setting, SettingId, Settings}, + varint::VarInt, +}; + +/// `H3_DATAGRAM` (0x33). No default. +/// +/// Indicates support for HTTP/3 datagrams (RFC 9297). The value MUST be 0 or 1. +pub struct H3Datagram; + +impl H3Datagram { + pub const ID: VarInt = VarInt::from_u32(0x33); + + pub const fn setting(enabled: bool) -> Setting { + Setting::new(Self::ID, VarInt::from_u32(enabled as u32)) + } +} + +impl SettingId for H3Datagram { + type Value = bool; + + fn id(&self) -> VarInt { + Self::ID + } + + fn value_from(&self, settings: &Settings) -> bool { + settings + .get_raw(Self::ID) + .is_some_and(|value| value == VarInt::from_u32(1)) + } +} + +impl Settings { + pub fn h3_datagram(&self) -> bool { + self.get(H3Datagram) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::varint::VarInt; + + #[test] + fn h3_datagram_uses_boolean_value() { + let mut settings = Settings::default(); + assert!(!settings.h3_datagram()); + + settings.set(H3Datagram::setting(true)); + assert!(settings.h3_datagram()); + assert_eq!( + settings.get(VarInt::from_u32(0x33)), + Some(VarInt::from_u32(1)) + ); + } + + #[test] + fn h3_datagram_exposes_id_and_treats_only_one_as_enabled() { + assert_eq!(H3Datagram.id(), H3Datagram::ID); + + let mut settings = Settings::default(); + settings.set(H3Datagram::setting(false)); + assert!(!settings.h3_datagram()); + assert_eq!( + settings.get(VarInt::from_u32(0x33)), + Some(VarInt::from_u32(0)) + ); + + settings.set(Setting::new(H3Datagram::ID, VarInt::from_u32(2))); + assert!(!settings.h3_datagram()); + } +} diff --git a/src/dhttp/frame.rs b/src/dhttp/frame.rs index 64d8d20..efb23ea 100644 --- a/src/dhttp/frame.rs +++ b/src/dhttp/frame.rs @@ -223,6 +223,7 @@ impl + Sen let bytes = payload.copy_to_bytes(payload.chunk().len()); stream.as_mut().feed(bytes).await?; } + stream.as_mut().flush().await?; Ok(()) } } @@ -332,14 +333,33 @@ impl StopStream for Frame

{ #[cfg(test)] mod tests { - use futures::stream::{self, StreamExt}; - use tokio::io::AsyncReadExt; + use std::{ + pin::Pin, + sync::{Arc, Mutex}, + time::Duration, + }; + + use futures::{ + Sink, SinkExt, + future::poll_fn, + stream::{self, Stream, StreamExt}, + }; + use tokio::{ + io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt}, + time::timeout, + }; + use tracing::Instrument; use super::*; use crate::{ - codec::{DecodeError, EncodeExt, SinkWriter}, + buflist::BufList, + codec::{ + DecodeError, DecodeExt, DecodeFrom, EncodeError, EncodeExt, SinkWriter, + StreamDecodeError, + }, dhttp::frame::stream::FrameStream, - error::Code, + error::{Code, H3FrameDecodeError}, + varint::VARINT_MAX, }; fn to_pre_byte_stream( @@ -353,36 +373,322 @@ mod tests { ) } + #[derive(Debug, Default, Clone)] + struct SinkRecorder { + chunks: Arc>>, + } + + impl SinkRecorder { + fn new() -> (Self, Arc>>) { + let chunks = Arc::new(Mutex::new(Vec::new())); + ( + Self { + chunks: chunks.clone(), + }, + chunks, + ) + } + } + + impl futures::Sink for SinkRecorder { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.chunks.lock().expect("chunks lock poisoned").push(item); + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Ok(())) + } + } + + #[derive(Debug)] + struct FailingEncodeStream { + reset_code: VarInt, + fail_write_at: Option, + fail_start_send: bool, + write_calls: usize, + } + + impl FailingEncodeStream { + fn fail_write_at(call: usize, reset_code: VarInt) -> Self { + Self { + reset_code, + fail_write_at: Some(call), + fail_start_send: false, + write_calls: 0, + } + } + + fn fail_start_send(reset_code: VarInt) -> Self { + Self { + reset_code, + fail_write_at: None, + fail_start_send: true, + write_calls: 0, + } + } + + fn reset(&self) -> quic::StreamError { + quic::StreamError::Reset { + code: self.reset_code, + } + } + } + + impl AsyncWrite for FailingEncodeStream { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + this.write_calls += 1; + if this.fail_write_at == Some(this.write_calls) { + return Poll::Ready(Err(io::Error::from(this.reset()))); + } + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl Sink for FailingEncodeStream { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, _item: Bytes) -> Result<(), Self::Error> { + if self.fail_start_send { + return Err(self.reset()); + } + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + #[derive(Debug, PartialEq, Eq)] + struct CapturedPayload(Vec); + + impl DecodeFrom for CapturedPayload { + type Error = StreamDecodeError; + + async fn decode_from(mut stream: BufList) -> Result { + let mut bytes = Vec::new(); + while stream.has_remaining() { + let chunk = stream.chunk(); + bytes.extend_from_slice(chunk); + stream.advance(chunk.len()); + } + Ok(Self(bytes)) + } + } + + #[cfg(target_pointer_width = "64")] + #[derive(Debug)] + struct HugeBuf; + + #[cfg(target_pointer_width = "64")] + impl bytes::Buf for HugeBuf { + fn remaining(&self) -> usize { + VARINT_MAX as usize + 1 + } + + fn chunk(&self) -> &[u8] { + &[] + } + + fn advance(&mut self, _cnt: usize) {} + } + #[test] fn new_frame() { Frame::new(Frame::DATA_FRAME_TYPE, Bytes::new()).unwrap(); } + #[cfg(target_pointer_width = "64")] + #[test] + fn new_rejects_payload_larger_than_varint() { + let error = Frame::new(Frame::DATA_FRAME_TYPE, HugeBuf) + .expect_err("frame payload length must fit in varint"); + + assert_eq!( + error.to_string(), + format!( + "value({}) too large for varint encoding", + VARINT_MAX as u128 + 1 + ) + ); + } + + #[test] + fn known_frame_type_constants() { + assert_eq!(Frame::DATA_FRAME_TYPE, VarInt::from_u32(0x00)); + assert_eq!(Frame::HEADERS_FRAME_TYPE, VarInt::from_u32(0x01)); + assert_eq!(Frame::CANCEL_PUSH_FRAME_TYPE, VarInt::from_u32(0x03)); + assert_eq!(Frame::SETTINGS_FRAME_TYPE, VarInt::from_u32(0x04)); + assert_eq!(Frame::PUSH_PROMISE_FRAME_TYPE, VarInt::from_u32(0x05)); + assert_eq!(Frame::GOAWAY_FRAME_TYPE, VarInt::from_u32(0x07)); + assert_eq!(Frame::MAX_PUSH_ID_FRAME_TYPE, VarInt::from_u32(0x0d)); + } + + #[test] + fn reserved_frame_type_detection() { + let non_reserved = Frame::new(VarInt::from_u32(0x1f), Bytes::new()).unwrap(); + assert!(!non_reserved.is_reserved_frame()); + + for i in 0u64..3 { + let frame_type = 0x21 + (i * 0x1f); + let frame = Frame::new( + VarInt::from_u64(frame_type).expect("frame type in range"), + Bytes::new(), + ) + .unwrap(); + assert!(frame.is_reserved_frame()); + assert_ne!(frame.r#type(), Frame::DATA_FRAME_TYPE); + } + + assert!( + !Frame::new( + VarInt::from_u64(0x20).expect("frame type in range"), + Bytes::new() + ) + .unwrap() + .is_reserved_frame() + ); + } + + #[test] + fn map_preserves_frame_meta() { + let frame = Frame::new(Frame::HEADERS_FRAME_TYPE, Bytes::from_static(b"hello")).unwrap(); + let mapped = frame.map(|payload| payload.len() * 2); + assert_eq!(mapped.r#type(), Frame::HEADERS_FRAME_TYPE); + assert_eq!(mapped.length().into_inner(), 5); + assert_eq!(mapped.into_payload(), 10); + } + #[tokio::test] - async fn test_data_frames() { - let stream = StreamReader::new(to_pre_byte_stream(vec![ - 0x00, // Type: DATA - 0x05, // Length: 5 - b'H', b'e', b'l', b'l', b'o', // Payload: "Hello" - 0x00, // Type: DATA - 0x05, // Length: 5 - b'W', b'o', b'r', b'l', b'd', // Payload: "World" - ])); + async fn async_write_updates_length_and_delegates_io() { + let (payload, mut reader) = tokio::io::duplex(64); + let mut frame = Frame { + r#type: Frame::DATA_FRAME_TYPE, + length: VarInt::from_u32(0), + payload, + }; - let mut stream = pin!(FrameStream::new(stream)); - let mut frame1 = stream.as_mut().next_frame().await.unwrap().unwrap(); - assert_eq!(frame1.r#type().into_inner(), 0); - assert_eq!(frame1.length().into_inner(), 5); - let mut payload = vec![]; - frame1.read_to_end(&mut payload).await.unwrap(); - assert_eq!(&payload[..], b"Hello"); + frame.write_all(b"abc").await.expect("payload write"); + assert_eq!(frame.length().into_inner(), 3); + AsyncWriteExt::flush(&mut frame) + .await + .expect("payload flush"); + AsyncWriteExt::shutdown(&mut frame) + .await + .expect("payload shutdown"); - let mut frame2 = stream.next_frame().await.unwrap().unwrap(); - let mut payload = vec![]; - frame2.read_to_end(&mut payload).await.unwrap(); - assert_eq!(&payload[..], b"World"); - assert_eq!(frame2.length().into_inner(), 5); - assert_eq!(frame2.r#type().into_inner(), 0); + let mut payload = Vec::new(); + reader + .read_to_end(&mut payload) + .await + .expect("payload forwarded"); + assert_eq!(payload, b"abc"); + } + + #[tokio::test] + async fn async_write_rejects_frame_payload_overflow() { + let payload = tokio::io::sink(); + let mut frame = Frame { + r#type: Frame::DATA_FRAME_TYPE, + length: VarInt::MAX, + payload, + }; + + let error = frame + .write_all(b"x") + .await + .expect_err("payload beyond varint max is invalid"); + + assert_eq!(error.kind(), io::ErrorKind::InvalidData); + assert!(matches!( + EncodeError::try_from(error), + Ok(EncodeError::FramePayloadTooLarge) + )); + } + + #[tokio::test] + async fn async_buf_read_delegates_to_payload() { + let mut payload = BufList::new(); + payload.write(Bytes::from_static(b"abc")); + let mut frame = Frame::new(Frame::DATA_FRAME_TYPE, payload).expect("frame"); + + assert_eq!(frame.fill_buf().await.expect("fill buf"), b"abc"); + frame.consume(1); + assert_eq!(frame.fill_buf().await.expect("fill buf"), b"bc"); + } + + #[tokio::test] + async fn incomplete_type_varint() { + let mut stream = StreamReader::new(to_pre_byte_stream([0x40])); + + assert!(matches!( + stream.decode_one::>().await, + Err(StreamError::Connection { + source: crate::connection::ConnectionError::H3 { source }, + }) if source.code() == Code::H3_FRAME_ERROR + )); + } + + #[tokio::test] + async fn incomplete_length_varint() { + let mut stream = StreamReader::new(to_pre_byte_stream([0x00, 0x40])); + + assert!(matches!( + stream.decode_one::>().await, + Err(StreamError::Connection { + source: crate::connection::ConnectionError::H3 { source }, + }) if source.code() == Code::H3_FRAME_ERROR + )); } #[tokio::test] @@ -436,7 +742,308 @@ mod tests { source: DecodeError::Incomplete } )); - assert_eq!(payload.as_slice(), b"Hell") + assert_eq!(payload.as_slice(), b"Hell"); + } + + #[tokio::test] + async fn zero_length_payload_stays_empty() { + let stream = StreamReader::new(to_pre_byte_stream([0x00, 0x00])); + + let stream = pin!(FrameStream::new(stream)); + let mut frame = stream.next_frame().await.unwrap().unwrap(); + let mut payload = vec![]; + frame.read_to_end(&mut payload).await.unwrap(); + assert_eq!(payload.len(), 0); + assert_eq!(frame.length().into_inner(), 0); + } + + #[tokio::test] + async fn test_data_frames() { + let stream = StreamReader::new(to_pre_byte_stream(vec![ + 0x00, // Type: DATA + 0x05, // Length: 5 + b'H', b'e', b'l', b'l', b'o', // Payload: "Hello" + 0x00, // Type: DATA + 0x05, // Length: 5 + b'W', b'o', b'r', b'l', b'd', // Payload: "World" + ])); + + let mut stream = pin!(FrameStream::new(stream)); + let mut frame1 = stream.as_mut().next_frame().await.unwrap().unwrap(); + assert_eq!(frame1.r#type().into_inner(), 0); + assert_eq!(frame1.length().into_inner(), 5); + let mut payload = vec![]; + frame1.read_to_end(&mut payload).await.unwrap(); + assert_eq!(&payload[..], b"Hello"); + + let mut frame2 = stream.next_frame().await.unwrap().unwrap(); + let mut payload = vec![]; + frame2.read_to_end(&mut payload).await.unwrap(); + assert_eq!(&payload[..], b"World"); + assert_eq!(frame2.length().into_inner(), 5); + assert_eq!(frame2.r#type().into_inner(), 0); + } + + #[tokio::test] + async fn decode_to_custom_payload() { + let mut stream = StreamReader::new(to_pre_byte_stream([ + 0x00, // DATA + 0x03, // len 3 + b'a', b'b', b'c', + ])); + + let frame = stream.decode_one::>().await.unwrap(); + let frame = Frame::::decode_from(frame).await.unwrap(); + assert_eq!(frame.r#type(), Frame::DATA_FRAME_TYPE); + assert_eq!(frame.length().into_inner(), 3); + assert_eq!(frame.payload().0, b"abc"); + } + + #[tokio::test] + async fn decode_to_buflist_collects_payload() { + let mut stream = StreamReader::new(to_pre_byte_stream([ + 0x01, // HEADERS + 0x03, // len 3 + b'a', b'b', b'c', + ])); + + let frame = Frame::::decode_from(&mut stream) + .await + .expect("frame"); + + assert_eq!(frame.r#type(), Frame::HEADERS_FRAME_TYPE); + assert_eq!(frame.length().into_inner(), 3); + let mut payload = frame.into_payload(); + let mut decoded = Vec::new(); + while payload.has_remaining() { + let chunk = payload.chunk(); + decoded.extend_from_slice(chunk); + payload.advance(chunk.len()); + } + assert_eq!(decoded, b"abc"); + } + + #[tokio::test] + async fn decode_to_buflist_incomplete_payload_returns_h3_frame_error() { + let mut stream = StreamReader::new(to_pre_byte_stream([ + 0x00, // DATA + 0x05, // len 5 + b'H', b'e', b'l', b'l', + ])); + + assert!(matches!( + Frame::::decode_from(&mut stream).await, + Err(StreamError::Connection { + source: crate::connection::ConnectionError::H3 { source }, + }) if source.code() == Code::H3_FRAME_ERROR + )); + } + + #[tokio::test] + async fn sink_updates_length_and_reports_overflow() { + let (sink, chunks) = SinkRecorder::new(); + let mut frame = Frame { + r#type: Frame::DATA_FRAME_TYPE, + length: VarInt::from_u32(0), + payload: sink, + }; + + frame.send(Bytes::from_static(b"abc")).await.unwrap(); + assert_eq!(frame.length().into_inner(), 3); + assert_eq!(chunks.lock().expect("chunks lock")[0].as_ref(), b"abc"); + frame.close().await.expect("sink close forwarded"); + + let (sink, chunks) = SinkRecorder::new(); + let mut frame = Frame { + r#type: Frame::DATA_FRAME_TYPE, + length: VarInt::from_u64(VARINT_MAX - 1).expect("length in range"), + payload: sink, + }; + + let error = frame.send(Bytes::from_static(b"ab")).await.unwrap_err(); + assert!(matches!( + error, + crate::codec::StreamEncodeError::Encode { + source: EncodeError::FramePayloadTooLarge, + } + )); + assert!(chunks.lock().expect("chunks lock").is_empty()); + } + + #[tokio::test] + async fn sink_start_send_propagates_payload_error() { + let mut frame = Frame { + r#type: Frame::DATA_FRAME_TYPE, + length: VarInt::from_u32(0), + payload: FailingEncodeStream::fail_start_send(VarInt::from_u32(41)), + }; + + let error = frame + .send(Bytes::from_static(b"abc")) + .await + .expect_err("payload start_send should fail"); + + assert!(matches!( + error, + crate::codec::StreamEncodeError::Reset { code } if code == VarInt::from_u32(41) + )); + } + + #[tokio::test] + async fn encode_into_propagates_type_length_and_payload_errors() { + let frame = Frame::new(Frame::DATA_FRAME_TYPE, Bytes::from_static(b"abc")).unwrap(); + + let error = frame + .clone() + .encode_into(FailingEncodeStream::fail_write_at(1, VarInt::from_u32(43))) + .await + .expect_err("frame type write should fail"); + assert!(matches!( + error, + quic::StreamError::Reset { code } if code == VarInt::from_u32(43) + )); + + let error = frame + .clone() + .encode_into(FailingEncodeStream::fail_write_at(2, VarInt::from_u32(47))) + .await + .expect_err("frame length write should fail"); + assert!(matches!( + error, + quic::StreamError::Reset { code } if code == VarInt::from_u32(47) + )); + + let error = frame + .encode_into(FailingEncodeStream::fail_start_send(VarInt::from_u32(53))) + .await + .expect_err("payload feed should fail"); + assert!(matches!( + error, + quic::StreamError::Reset { code } if code == VarInt::from_u32(53) + )); + } + + #[tokio::test] + async fn failing_encode_stream_allows_nonfailing_io_and_sink_paths() { + let mut stream = FailingEncodeStream { + reset_code: VarInt::from_u32(59), + fail_write_at: None, + fail_start_send: false, + write_calls: 0, + }; + + stream.write_all(b"ok").await.expect("write"); + AsyncWriteExt::flush(&mut stream).await.expect("flush"); + AsyncWriteExt::shutdown(&mut stream) + .await + .expect("shutdown"); + + stream.send(Bytes::from_static(b"ok")).await.expect("send"); + SinkExt::flush(&mut stream).await.expect("sink flush"); + stream.close().await.expect("sink close"); + } + + #[tokio::test] + async fn payload_stream_is_forwarded() { + let payload = stream::iter([ + Ok::(Bytes::from_static(b"one")), + Ok::(Bytes::from_static(b"two")), + ]); + + let mut frame = Frame { + r#type: Frame::DATA_FRAME_TYPE, + length: VarInt::from_u32(6), + payload, + }; + + let mut chunks = Vec::new(); + while let Some(item) = frame.next().await { + chunks.push(item.expect("payload chunk")); + } + assert_eq!(chunks[0].as_ref(), b"one"); + assert_eq!(chunks[1].as_ref(), b"two"); + } + + #[derive(Debug)] + struct ControlPayload { + stream_id: VarInt, + stopped: Arc>>, + } + + impl GetStreamId for ControlPayload { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl StopStream for ControlPayload { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + *self.stopped.lock().expect("stop state poisoned") = Some(code); + Poll::Ready(Ok(())) + } + } + + #[tokio::test] + async fn control_traits_delegate_to_payload() { + let stopped = Arc::new(Mutex::new(None)); + let stream_id = VarInt::from_u32(17); + let stop_code = VarInt::from_u32(23); + let mut frame = Frame { + r#type: Frame::DATA_FRAME_TYPE, + length: VarInt::from_u32(0), + payload: ControlPayload { + stream_id, + stopped: stopped.clone(), + }, + }; + + assert_eq!( + poll_fn(|cx| Pin::new(&mut frame).poll_stream_id(cx)) + .await + .expect("stream id"), + stream_id + ); + poll_fn(|cx| Pin::new(&mut frame).poll_stop(cx, stop_code)) + .await + .expect("stop"); + + assert_eq!( + *stopped.lock().expect("stop state poisoned"), + Some(stop_code) + ); + } + + #[tokio::test] + async fn can_decode_with_multi_byte_type_varint() { + let stream = StreamReader::new(to_pre_byte_stream([ + 0x40, 0x40, // type: 64 encoded as 2-byte varint + 0x01, // length: 1 + b'z', + ])); + + let stream = pin!(FrameStream::new(stream)); + let mut frame = stream.next_frame().await.unwrap().unwrap(); + assert_eq!(frame.r#type().into_inner(), 64); + let mut payload = vec![]; + frame.read_to_end(&mut payload).await.unwrap(); + assert_eq!(payload, b"z"); + } + + #[test] + fn h3_frame_decode_error_display_and_debug() { + let error = H3FrameDecodeError { + source: DecodeError::Incomplete, + }; + + assert_eq!(format!("{}", error), "frame decode error"); + assert!(format!("{:?}", error).contains("H3FrameDecodeError")); } fn channel() -> ( @@ -447,45 +1054,73 @@ mod tests { (sink.sink_map_err(|_e| unreachable!()), stream.map(Ok)) } + #[tokio::test] + async fn encode_into_flushes_zero_length_frame() { + let (sink, stream) = channel(); + let mut sink = SinkWriter::new(sink); + let frame = Frame::new(Frame::SETTINGS_FRAME_TYPE, Bytes::new()).expect("settings frame"); + + sink.encode_one(frame) + .await + .expect("frame encode should flush"); + + let stream = StreamReader::new(stream); + let mut stream = pin!(FrameStream::new(stream)); + let frame = timeout(Duration::from_millis(100), stream.as_mut().next_frame()) + .await + .expect("flushed frame should be observable") + .expect("frame should be present") + .expect("frame should decode"); + + assert_eq!(frame.r#type(), Frame::SETTINGS_FRAME_TYPE); + assert_eq!(frame.length(), VarInt::from_u32(0)); + } + #[tokio::test] async fn encode_and_decode() { let (sink, stream) = channel(); - let decode = tokio::spawn(async move { - let stream = StreamReader::new(stream); - - let mut stream = pin!(FrameStream::new(stream)); - let mut frame1 = stream.as_mut().next_frame().await.unwrap().unwrap(); - assert_eq!(frame1.r#type().into_inner(), 0); - assert_eq!(frame1.length().into_inner(), 5); - let mut payload = vec![]; - frame1.read_to_end(&mut payload).await.unwrap(); - assert_eq!(&payload[..], b"Hello"); - - let mut frame2 = stream.next_frame().await.unwrap().unwrap(); - let mut payload = vec![]; - frame2.read_to_end(&mut payload).await.unwrap(); - assert_eq!(&payload[..], b"World"); - assert_eq!(frame2.length().into_inner(), 5); - assert_eq!(frame2.r#type().into_inner(), 0); - }); - let encode = tokio::spawn(async move { - let mut sink = SinkWriter::new(sink); - - let frame1 = Frame::new(VarInt::from_u32(0), Bytes::from_static(b"Hello")).unwrap(); - assert!(frame1.r#type().into_inner() == 0); - assert!(frame1.length().into_inner() == 5); - assert!(frame1.payload().as_ref() == b"Hello"); - sink.encode_one(frame1).await.unwrap(); - - let frame2 = Frame::new(VarInt::from_u32(0), Bytes::from_static(b"World")).unwrap(); - assert!(frame2.r#type().into_inner() == 0); - assert!(frame2.length().into_inner() == 5); - assert!(frame2.payload().as_ref() == b"World"); - sink.encode_one(frame2).await.unwrap(); - - Pin::new(&mut sink).flush_buffer().await.unwrap(); - }); + let decode = tokio::spawn( + async move { + let stream = StreamReader::new(stream); + + let mut stream = pin!(FrameStream::new(stream)); + let mut frame1 = stream.as_mut().next_frame().await.unwrap().unwrap(); + assert_eq!(frame1.r#type().into_inner(), 0); + assert_eq!(frame1.length().into_inner(), 5); + let mut payload = vec![]; + frame1.read_to_end(&mut payload).await.unwrap(); + assert_eq!(&payload[..], b"Hello"); + + let mut frame2 = stream.next_frame().await.unwrap().unwrap(); + let mut payload = vec![]; + frame2.read_to_end(&mut payload).await.unwrap(); + assert_eq!(&payload[..], b"World"); + assert_eq!(frame2.length().into_inner(), 5); + assert_eq!(frame2.r#type().into_inner(), 0); + } + .in_current_span(), + ); + let encode = tokio::spawn( + async move { + let mut sink = SinkWriter::new(sink); + + let frame1 = Frame::new(VarInt::from_u32(0), Bytes::from_static(b"Hello")).unwrap(); + assert!(frame1.r#type().into_inner() == 0); + assert!(frame1.length().into_inner() == 5); + assert!(frame1.payload().as_ref() == b"Hello"); + sink.encode_one(frame1).await.unwrap(); + + let frame2 = Frame::new(VarInt::from_u32(0), Bytes::from_static(b"World")).unwrap(); + assert!(frame2.r#type().into_inner() == 0); + assert!(frame2.length().into_inner() == 5); + assert!(frame2.payload().as_ref() == b"World"); + sink.encode_one(frame2).await.unwrap(); + + Pin::new(&mut sink).flush_buffer().await.unwrap(); + } + .in_current_span(), + ); tokio::try_join!(encode, decode).unwrap(); } } diff --git a/src/dhttp/frame/stream.rs b/src/dhttp/frame/stream.rs index b005d44..7913bef 100644 --- a/src/dhttp/frame/stream.rs +++ b/src/dhttp/frame/stream.rs @@ -169,3 +169,380 @@ impl StopStream for FrameStream { self.project().stream.poll_stop(cx, code) } } + +#[cfg(test)] +mod tests { + use std::{ + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, + }; + + use bytes::Bytes; + use futures::{SinkExt, Stream, StreamExt, future::poll_fn, stream}; + use tokio::io::AsyncReadExt; + + use super::FrameStream; + use crate::{ + codec::{EncodeExt, StreamReader}, + connection::{self, StreamError}, + dhttp::frame::Frame, + error::Code, + quic, + quic::{GetStreamId, StopStream}, + varint::VarInt, + }; + + #[derive(Debug)] + struct MockStream { + stream_id: VarInt, + stop_code: Arc>>, + items: std::vec::IntoIter>, + } + + impl MockStream { + fn from_iter(stream_id: VarInt, items: Vec>) -> Self { + Self { + stream_id, + stop_code: Arc::new(Mutex::new(None)), + items: items.into_iter(), + } + } + } + + impl futures::Stream for MockStream { + type Item = std::result::Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.get_mut().items.next()) + } + } + + impl quic::GetStreamId for MockStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl quic::StopStream for MockStream { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + *self.stop_code.lock().expect("stop_code lock poisoned") = Some(code); + Poll::Ready(Ok(())) + } + } + + fn to_pre_byte_stream( + data: impl IntoIterator, + ) -> impl Stream> { + stream::iter(data.into_iter().map(|byte| Ok(Bytes::from(vec![byte])))) + } + + fn expect_h3_frame_decode_error(error: StreamError) { + match error { + StreamError::Connection { + source: connection::ConnectionError::H3 { source, .. }, + } => { + assert_eq!(source.code(), Code::H3_FRAME_ERROR); + } + other => panic!("expected h3 frame decode error, got {other:?}"), + } + } + + fn channel() -> ( + impl futures::Sink, + impl Stream>, + ) { + let (sink, stream) = futures::channel::mpsc::channel(8); + (sink.sink_map_err(|_e| unreachable!()), stream.map(Ok)) + } + + #[tokio::test] + async fn frame_reads_next_frame_payload_and_advances() { + let stream = StreamReader::new(to_pre_byte_stream([ + 0x00, 0x05, // DATA, len 5 + b'H', b'e', b'l', b'l', b'o', 0x00, 0x05, // DATA, len 5 + b'W', b'o', b'r', b'l', b'd', + ])); + + let mut stream = std::pin::pin!(FrameStream::new(stream)); + + let mut frame1 = stream.as_mut().next_frame().await.unwrap().unwrap(); + assert_eq!(frame1.r#type().into_inner(), 0); + assert_eq!(frame1.length().into_inner(), 5); + + let mut payload = Vec::new(); + frame1.read_to_end(&mut payload).await.unwrap(); + assert_eq!(&payload[..], b"Hello"); + + let mut frame2 = stream.next_frame().await.unwrap().unwrap(); + let mut payload = Vec::new(); + frame2.read_to_end(&mut payload).await.unwrap(); + assert_eq!(&payload[..], b"World"); + } + + #[tokio::test] + async fn empty_input_produces_no_frame() { + let mut stream = + std::pin::pin!(FrameStream::new(StreamReader::new(to_pre_byte_stream([])))); + + assert!(stream.as_mut().frame().is_none()); + assert!(stream.as_mut().next_frame().await.is_none()); + assert!(stream.as_mut().frame().is_none()); + } + + #[tokio::test] + async fn next_unreserved_frame_skips_reserved_frames() { + let stream = StreamReader::new(to_pre_byte_stream([ + 0x21, 0x02, // RESERVED frame type, len 2 + b'p', b'a', 0x00, 0x03, // DATA frame, len 3 + b'o', b'n', b'e', + ])); + + let mut stream = std::pin::pin!(FrameStream::new(stream)); + let mut frame = stream + .as_mut() + .next_unreserved_frame() + .await + .unwrap() + .unwrap(); + assert_eq!(frame.r#type(), Frame::DATA_FRAME_TYPE); + + let mut payload = Vec::new(); + frame.read_to_end(&mut payload).await.unwrap(); + assert_eq!(payload.as_slice(), b"one"); + } + + #[tokio::test] + async fn next_unreserved_frame_returns_none_after_only_reserved_frames() { + let stream = StreamReader::new(to_pre_byte_stream([ + 0x21, 0x02, // RESERVED frame type, len 2 + b'p', b'a', + ])); + + let mut stream = std::pin::pin!(FrameStream::new(stream)); + + assert!(stream.as_mut().next_unreserved_frame().await.is_none()); + assert!(stream.as_mut().frame().is_none()); + } + + #[tokio::test] + async fn next_frame_consumes_previous_frame_before_decoding_next() { + let stream = StreamReader::new(to_pre_byte_stream([ + 0x00, 0x04, // DATA, len 4 + b'P', b'u', b't', b's', 0x01, 0x00, // HEADERS, len 0 + ])); + + let mut stream = std::pin::pin!(FrameStream::new(stream)); + let mut first = stream.as_mut().next_frame().await.unwrap().unwrap(); + + let mut prefix = [0u8; 2]; + first.read_exact(&mut prefix).await.unwrap(); + assert_eq!(&prefix, b"Pu"); + + let second = stream.as_mut().next_frame().await.unwrap().unwrap(); + assert_eq!(second.r#type(), Frame::HEADERS_FRAME_TYPE); + assert_eq!(second.length().into_inner(), 0); + } + + #[tokio::test] + async fn consume_current_frame_drains_payload_and_clears_current_frame() { + let stream = StreamReader::new(to_pre_byte_stream([ + 0x00, 0x03, // DATA, len 3 + b'a', b'b', b'c', 0x01, 0x00, // HEADERS, len 0 + ])); + + let mut stream = std::pin::pin!(FrameStream::new(stream)); + { + let first = stream.as_mut().next_frame().await.unwrap().unwrap(); + assert_eq!(first.r#type(), Frame::DATA_FRAME_TYPE); + assert_eq!(first.length().into_inner(), 3); + } + + stream + .as_mut() + .consume_current_frame() + .await + .expect("current frame is drained"); + assert!(stream.as_mut().frame().is_none()); + + let second = stream.as_mut().next_frame().await.unwrap().unwrap(); + assert_eq!(second.r#type(), Frame::HEADERS_FRAME_TYPE); + assert_eq!(second.length().into_inner(), 0); + } + + #[tokio::test] + async fn incomplete_length_converts_to_frame_decode_error() { + let mut stream = std::pin::pin!(FrameStream::new(StreamReader::new(to_pre_byte_stream([ + 0x00, // frame type present but length missing + ])))); + + let error = match stream.as_mut().next_frame().await { + Some(Err(error)) => error, + Some(Ok(_)) => panic!("expected frame decode error"), + None => panic!("expected frame decode error"), + }; + expect_h3_frame_decode_error(error); + + let error = match stream.as_mut().frame() { + Some(Err(error)) => error, + Some(Ok(_)) => panic!("expected frame decode error"), + None => panic!("expected frame decode error"), + }; + expect_h3_frame_decode_error(error); + } + + #[tokio::test] + async fn incomplete_type_varint_converts_to_frame_decode_error() { + let mut stream = std::pin::pin!(FrameStream::new(StreamReader::new(to_pre_byte_stream([ + 0x40, // two-byte varint prefix without the second byte + ])))); + + let error = match stream.as_mut().next_frame().await { + Some(Err(error)) => error, + Some(Ok(_)) => panic!("expected frame decode error"), + None => panic!("expected frame decode error"), + }; + expect_h3_frame_decode_error(error); + } + + #[tokio::test] + async fn consume_current_frame_returns_stored_decode_error() { + let mut stream = std::pin::pin!(FrameStream::new(StreamReader::new(to_pre_byte_stream([ + 0x00, // frame type present but length missing + ])))); + + let error = match stream.as_mut().next_frame().await { + Some(Err(error)) => error, + Some(Ok(_)) => panic!("expected frame decode error"), + None => panic!("expected frame decode error"), + }; + expect_h3_frame_decode_error(error); + + let error = stream + .as_mut() + .consume_current_frame() + .await + .expect_err("stored decode error should be returned"); + expect_h3_frame_decode_error(error); + } + + #[tokio::test] + async fn incomplete_payload_error_propagates_on_consume() { + let mut stream = std::pin::pin!(FrameStream::new(StreamReader::new(to_pre_byte_stream([ + 0x00, 0x05, // DATA len 5 + b'H', b'e', b'l', b'l', // payload incomplete + ])))); + + assert!(stream.as_mut().next_frame().await.unwrap().is_ok()); + + let error = match stream.as_mut().next_frame().await { + Some(Err(error)) => error, + Some(Ok(_)) => panic!("expected frame decode error"), + None => panic!("expected frame decode error"), + }; + expect_h3_frame_decode_error(error); + } + + #[tokio::test] + async fn stream_reset_error_is_forwarded() { + let stream = StreamReader::new(stream::iter([Err(quic::StreamError::Reset { + code: VarInt::from_u32(0x22), + })])); + let mut stream = std::pin::pin!(FrameStream::new(stream)); + + let error = match stream.as_mut().next_frame().await { + Some(Err(error)) => error, + Some(Ok(_)) => panic!("expected stream reset"), + None => panic!("expected stream reset"), + }; + assert!(matches!(error, StreamError::Reset { code } if code.into_inner() == 0x22)); + } + + #[tokio::test] + async fn control_traits_delegate_to_inner() { + let stream_id = VarInt::from_u32(7); + let stop_code = VarInt::from_u32(9); + let mock = MockStream::from_iter(stream_id, vec![Ok(Bytes::from_static(b"\x00\x00"))]); + let stop_code_ref = mock.stop_code.clone(); + let mut stream = FrameStream::new(StreamReader::new(mock)); + + assert_eq!( + poll_fn(|cx| Pin::new(&mut stream).poll_stream_id(cx)) + .await + .expect("stream id"), + stream_id + ); + poll_fn(|cx| Pin::new(&mut stream).poll_stop(cx, stop_code)) + .await + .expect("stream stop"); + assert_eq!( + *stop_code_ref.lock().expect("stop_code lock poisoned"), + Some(stop_code) + ); + } + + #[tokio::test] + async fn mock_stream_payload_decodes_frame() { + let mock = MockStream::from_iter( + VarInt::from_u32(11), + vec![Ok(Bytes::from_static(b"\x00\x02ok"))], + ); + let mut stream = std::pin::pin!(FrameStream::new(StreamReader::new(mock))); + + let mut frame = stream.as_mut().next_frame().await.unwrap().unwrap(); + assert_eq!(frame.r#type(), Frame::DATA_FRAME_TYPE); + assert_eq!(frame.length().into_inner(), 2); + + let mut payload = Vec::new(); + frame.read_to_end(&mut payload).await.unwrap(); + assert_eq!(payload, b"ok"); + } + + #[tokio::test] + async fn encode_and_decode_frames() { + use crate::codec::SinkWriter; + + let (sink, stream) = channel(); + + let decode = async move { + let stream = StreamReader::new(stream); + let mut stream = std::pin::pin!(FrameStream::new(stream)); + + let mut frame1 = stream.as_mut().next_frame().await.unwrap().unwrap(); + assert_eq!(frame1.r#type().into_inner(), 0); + let mut payload = Vec::new(); + frame1.read_to_end(&mut payload).await.unwrap(); + assert_eq!(&payload[..], b"ok"); + + let mut frame2 = stream.as_mut().next_frame().await.unwrap().unwrap(); + assert_eq!(frame2.r#type().into_inner(), 0); + let mut payload = Vec::new(); + frame2.read_to_end(&mut payload).await.unwrap(); + assert_eq!(&payload[..], b"go"); + Ok::<(), ()>(()) + }; + + let encode = async move { + let mut sink = SinkWriter::new(sink); + + let frame1 = + Frame::new(VarInt::from_u32(0), Bytes::from_static(b"ok")).expect("frame1"); + let frame2 = + Frame::new(VarInt::from_u32(0), Bytes::from_static(b"go")).expect("frame2"); + + sink.encode_one(frame1).await.unwrap(); + sink.encode_one(frame2).await.unwrap(); + futures::SinkExt::flush(&mut sink).await.unwrap(); + Ok::<(), ()>(()) + }; + + tokio::try_join!(decode, encode).unwrap(); + } +} diff --git a/src/dhttp/goaway.rs b/src/dhttp/goaway.rs index ab58ded..0c7237e 100644 --- a/src/dhttp/goaway.rs +++ b/src/dhttp/goaway.rs @@ -82,3 +82,192 @@ impl EncodeInto for Goaway { Ok(frame) } } + +#[cfg(test)] +mod tests { + use std::{ + io, + pin::Pin, + task::{Context, Poll}, + }; + + use bytes::Bytes; + use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf}; + use tracing::Instrument; + + use super::*; + use crate::{ + codec::{DecodeError, DecodeFrom, EncodeExt}, + connection::ConnectionError, + error::Code, + }; + + #[test] + fn new_stores_stream_id() { + let stream_id = VarInt::from_u32(11); + + assert_eq!(Goaway::new(stream_id).stream_id(), stream_id); + } + + #[tokio::test] + async fn encode_decode_round_trips() { + let stream_id = VarInt::from_u32(1337); + let mut frame = BufList::new() + .encode(Goaway::new(stream_id)) + .await + .expect("goaway encoding is infallible"); + + assert_eq!(frame.r#type(), Frame::GOAWAY_FRAME_TYPE); + + let decoded = Goaway::decode_from(&mut frame).await.expect("goaway frame"); + + assert_eq!(decoded.stream_id(), stream_id); + } + + #[tokio::test] + async fn encode_decode_round_trips_with_boundary_stream_ids() { + for stream_id in [VarInt::from_u32(0), VarInt::MAX] { + let mut frame = BufList::new() + .encode(Goaway::new(stream_id)) + .await + .expect("goaway encoding is infallible"); + + let decoded = Goaway::decode_from(&mut frame).await.expect("goaway frame"); + + assert_eq!(decoded.stream_id(), stream_id); + } + } + + #[tokio::test] + async fn decode_rejects_trailing_payload() { + let mut payload = BufList::new(); + payload + .encode_one(VarInt::from_u32(7)) + .await + .expect("varint encoding into buflist is infallible"); + payload.write(Bytes::from_static(b"trailing")); + let mut frame = + Frame::new(Frame::GOAWAY_FRAME_TYPE, payload).expect("payload length fits varint"); + + let error = Goaway::decode_from(&mut frame) + .await + .expect_err("trailing payload is malformed"); + + assert!(matches!( + error, + StreamError::Connection { + source: ConnectionError::H3 { source }, + } if source.code() == Code::H3_GENERAL_PROTOCOL_ERROR + )); + } + + #[tokio::test] + async fn decode_rejects_empty_payload_as_closed_critical_stream() { + let mut frame = Frame::new(Frame::GOAWAY_FRAME_TYPE, BufList::new()) + .expect("payload length fits varint"); + + let error = Goaway::decode_from(&mut frame) + .await + .expect_err("missing stream id is malformed"); + + assert!(matches!( + error, + StreamError::Connection { + source: ConnectionError::H3 { source }, + } if source.code() == Code::H3_CLOSED_CRITICAL_STREAM + )); + } + + struct ErrorPayload; + + impl AsyncRead for ErrorPayload { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut ReadBuf<'_>, + ) -> Poll> { + Poll::Ready(Err(DecodeError::ArithmeticOverflow.into())) + } + } + + impl AsyncBufRead for ErrorPayload { + fn poll_fill_buf(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Err(DecodeError::ArithmeticOverflow.into())) + } + + fn consume(self: Pin<&mut Self>, _amt: usize) {} + } + + #[tokio::test] + async fn decode_maps_payload_decode_error_to_frame_decode_error() { + let mut payload = BufList::new(); + payload.write(Bytes::from_static(b"\0")); + let mut frame = Frame::new(Frame::GOAWAY_FRAME_TYPE, payload) + .expect("payload length fits varint") + .map(|_| ErrorPayload); + + let error = Goaway::decode_from(&mut frame) + .await + .expect_err("payload decode failure should be a frame decode error"); + + assert!(matches!( + error, + StreamError::Connection { + source: ConnectionError::H3 { source }, + } if source.code() == Code::H3_FRAME_ERROR + )); + } + + #[tokio::test] + async fn encode_rejects_non_empty_output_buffer() { + let mut prefilled = BufList::new(); + prefilled.write(Bytes::from_static(b"already-filled")); + + let join = tokio::spawn( + async move { + prefilled + .encode(Goaway::new(VarInt::from_u32(1))) + .await + .expect("goaway encoding should panic before this") + } + .in_current_span(), + ); + + let err = join + .await + .expect_err("encoding with non-empty buffer should panic"); + assert!(err.is_panic()); + } + + #[tokio::test] + async fn decode_panics_on_wrong_frame_type() { + let mut payload = BufList::new(); + payload + .encode_one(VarInt::from_u32(7)) + .await + .expect("varint encoding into buflist is infallible"); + let frame = + Frame::new(Frame::SETTINGS_FRAME_TYPE, payload).expect("payload length fits varint"); + + let join = tokio::spawn( + async move { + let mut frame = frame; + let _ = Goaway::decode_from(&mut frame).await; + } + .in_current_span(), + ); + + let err = join + .await + .expect_err("decoding a mismatched frame type should panic"); + assert!(err.is_panic()); + } + + #[test] + fn debug_renders_stream_id() { + let goaway = Goaway::new(VarInt::from_u32(123)); + let rendered = format!("{goaway:?}"); + assert!(rendered.contains("Goaway")); + assert!(rendered.contains("123")); + } +} diff --git a/src/dhttp/message.rs b/src/dhttp/message.rs new file mode 100644 index 0000000..eb054bf --- /dev/null +++ b/src/dhttp/message.rs @@ -0,0 +1,2568 @@ +use std::{ + io, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use bytes::{Buf, Bytes}; +use futures::{SinkExt, TryStreamExt}; +use snafu::Snafu; + +use crate::{ + codec::{EncodeError, EncodeExt, SinkWriter, StreamReader}, + connection::{self, ConnectionGoaway, ConnectionState, LifecycleExt}, + dhttp::{ + frame::{ + Frame, + stream::{FrameStream, ReadableFrame}, + }, + protocol::{AcceptRawMessageStreamError, DHttpProtocol, InitialRawMessageStreamError}, + }, + error::{Code, H3FrameDecodeError, H3FrameUnexpected}, + qpack::{ + algorithm::{DynamicCompressAlgo, HuffmanAlways}, + decoder::QPackMessageStreamReader, + encoder::{EncodeHeaderSectionError, Encoder}, + field::{FieldLine, FieldSection}, + protocol::{QPackDecoder, QPackEncoder, QPackProtocolDisabled}, + }, + quic::{self, GetStreamIdExt, ResetStreamExt, StopStreamExt}, + varint::{self, VarInt}, +}; + +pub(crate) mod guard; +#[cfg(feature = "hyper")] +pub mod hyper; +pub mod unfold; + +pub type BoxMessageReader< + S = dyn crate::stream::ReadStream< + Bytes, + MessageStreamError, + quic::StreamError, + quic::StreamError, + > + Send, +> = Pin>; + +pub type BoxMessageWriter< + S = dyn crate::stream::WriteStream< + Bytes, + MessageStreamError, + quic::StreamError, + quic::StreamError, + > + Send, +> = Pin>; + +impl quic::GetStreamId + for dyn crate::stream::ReadStream + + Send +{ + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + crate::stream::GetStreamId::poll_stream_id(self, cx) + } +} + +impl quic::StopStream + for dyn crate::stream::ReadStream + + Send +{ + fn poll_stop( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + crate::stream::StopStream::poll_stop(self, cx, code) + } +} + +impl quic::GetStreamId + for dyn crate::stream::WriteStream + + Send +{ + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + crate::stream::GetStreamId::poll_stream_id(self, cx) + } +} + +impl quic::ResetStream + for dyn crate::stream::WriteStream + + Send +{ + fn poll_reset( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + crate::stream::ResetStream::poll_reset(self, cx, code) + } +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Snafu)] +pub enum MessageStreamError { + #[snafu(transparent)] + Quic { source: quic::StreamError }, + #[snafu(transparent)] + Goaway { source: ConnectionGoaway }, + #[snafu(display( + "header section too large to fit into a single frame, maybe too many header fields" + ))] + HeaderTooLarge, + #[snafu(display( + "trailer section too large to fit into a single frame, maybe too many header fields" + ))] + TrailerTooLarge, + #[snafu(display("data frame payload too large, try smaller chunk size"))] + DataFrameTooLarge { source: varint::err::Overflow }, + #[snafu(display("http/3 message from peer is malformed"))] + MalformedIncomingMessage, + #[snafu(display("http/3 message to send is malformed"))] + MalformedOutgoingMessage, + #[snafu(display("message send previously failed"))] + MessageSendFailed, + #[snafu(display("message writer is closed"))] + MessageWriterClosed, +} + +impl From for MessageStreamError { + fn from(value: quic::ConnectionError) -> Self { + Self::Quic { + source: value.into(), + } + } +} + +impl From for MessageStreamError { + fn from(source: varint::err::Overflow) -> Self { + Self::DataFrameTooLarge { source } + } +} + +impl From for io::Error { + fn from(error: MessageStreamError) -> Self { + use io::ErrorKind; + + let kind = match &error { + // Delegate to the underlying QUIC stream error's kind mapping. + MessageStreamError::Quic { .. } => None, + MessageStreamError::Goaway { .. } => Some(ErrorKind::ConnectionAborted), + MessageStreamError::MalformedIncomingMessage => Some(ErrorKind::InvalidData), + MessageStreamError::MalformedOutgoingMessage => Some(ErrorKind::InvalidInput), + MessageStreamError::MessageSendFailed | MessageStreamError::MessageWriterClosed => { + Some(ErrorKind::BrokenPipe) + } + MessageStreamError::HeaderTooLarge + | MessageStreamError::TrailerTooLarge + | MessageStreamError::DataFrameTooLarge { .. } => Some(ErrorKind::InvalidInput), + }; + match kind { + Some(kind) => io::Error::new(kind, error), + None => match error { + MessageStreamError::Quic { source } => io::Error::from(source), + _ => unreachable!("non-Quic variants always resolve to a concrete kind"), + }, + } + } +} + +pub struct MessageReader { + pub(super) stream: FrameStream, + pub(super) qpack_decoder: Arc, + pub(super) state: ConnectionState, +} + +impl MessageReader { + pub fn new( + stream_id: VarInt, + stream: StreamReader, + qpack_decoder: Arc, + state: ConnectionState, + ) -> Self { + let decoder = qpack_decoder.clone(); + let stream = + stream.map_stream(move |guarded| { + let mut stream = guard::GuardQuicReader::new(Box::pin( + QPackMessageStreamReader::new(stream_id, guarded.into_inner(), decoder), + )); + stream.set_stream_id(stream_id); + stream + }); + let frame_stream = FrameStream::new(stream); + Self { + stream: frame_stream, + qpack_decoder, + state, + } + } + + pub fn connection(&self) -> &Arc { + self.state.quic() + } + + pub async fn peek_frame( + &mut self, + ) -> Option, connection::StreamError>> { + loop { + match Pin::new(&mut self.stream).frame() { + None => match Pin::new(&mut self.stream).next_unreserved_frame().await? { + Ok(_next_frame) => continue, + Err(error) => return Some(Err(error)), + }, + Some(Ok(frame)) + if frame.r#type() == Frame::HEADERS_FRAME_TYPE + || frame.r#type() == Frame::DATA_FRAME_TYPE => + { + // avoid rust bc bug + return Pin::new(&mut self.stream).frame(); + } + Some(Ok(_frame)) => { + return Some(Err(H3FrameUnexpected::UnexpectedFrameType.into())); + } + Some(Err(error)) => return Some(Err(error)), + } + } + } + + pub async fn read_data_frame_chunk( + &mut self, + ) -> Result, connection::StreamError> { + loop { + match self.peek_frame().await { + Some(Ok(mut frame)) if frame.r#type() == Frame::DATA_FRAME_TYPE => { + match frame.try_next().await { + Ok(Some(bytes)) => return Ok(Some(bytes)), + Ok(None) => { + Pin::new(&mut self.stream).consume_current_frame().await?; + continue; + } + Err(error) => { + let error = error.into_stream_error(|error| { + H3FrameDecodeError { source: error }.into() + }); + return Err(error); + } + } + } + Some(Ok(..)) | None => return Ok(None), + Some(Err(error)) => return Err(error), + } + } + } + + pub async fn read_header_frame( + &mut self, + ) -> Result, connection::StreamError> { + match self.peek_frame().await { + Some(Ok(frame)) if frame.r#type() == Frame::HEADERS_FRAME_TYPE => { + let Some(frame) = Pin::new(&mut self.stream).frame() else { + return Ok(None); + }; + let frame = match frame { + Ok(frame) => frame, + Err(error) => return Err(error), + }; + match self.qpack_decoder.decode(frame).await { + Ok(field_section) => { + Pin::new(&mut self.stream).consume_current_frame().await?; + Ok(Some(field_section)) + } + Err(error) => Err(error), + } + } + Some(Ok(..)) | None => Ok(None), + Some(Err(error)) => Err(error), + } + } + + pub async fn read_data_chunk(&mut self) -> Result, MessageStreamError> { + self.try_stream_read(async |this| this.read_data_frame_chunk().await) + .await + } + + pub async fn read_header(&mut self) -> Result, MessageStreamError> { + self.try_stream_read(async |this| this.read_header_frame().await) + .await + } + + pub async fn stop(&mut self, code: Code) -> Result<(), MessageStreamError> { + self.try_stream_read(async move |this| Ok(this.stream.stop(code.into_inner()).await?)) + .await + } + + pub async fn peer_goaway_covers( + &mut self, + ) -> Result> + use<>, quic::StreamError> + { + let stream_id = self.stream.stream_id().await?; + let state = self.state.clone(); + + Ok(async move { + let conn = state.quic().clone(); + let error = conn.closed(); + let dhttp_state = state + .protocols() + .get::() + .unwrap() + .state + .clone(); + tokio::select! { + biased; + _goaway = dhttp_state.peer_goaway_covers(stream_id) => Ok(()), + error = error => Err(error), + } + }) + } + + pub async fn try_stream_read( + &mut self, + f: impl AsyncFnOnce(&mut Self) -> Result, + ) -> Result { + let peer_goaway = self.peer_goaway_covers().await?; + tokio::select! { + result = f(self) => match result { + Ok(value) => Ok(value), + Err(error) => Err(MessageStreamError::Quic { + source: self.handle_stream_error(error).await, + }), + }, + goaway = peer_goaway => match goaway { + Ok(()) => { + // FIXME: which code should be used? + _ = self.stream.stop(Code::H3_NO_ERROR.into()).await; + Err(ConnectionGoaway::Peer.into()) + } + Err(error) => Err(error.into()) + } + } + } + + /// Resolve a stream-level H3 error into a `quic::StreamError`, performing + /// the correct side effect for each variant: + /// + /// - `Connection` — delegate to [`LifecycleExt::handle_connection_error`] + /// so that a fresh H3 connection-scope violation closes the QUIC + /// connection. + /// - `Reset` — nothing to do locally; the peer already reset the stream. + /// - `H3` — a freshly detected stream-scope protocol violation; issue + /// `STOP_SENDING` on this reader so the peer observes the abort. + pub async fn handle_stream_error( + &mut self, + error: connection::StreamError, + ) -> quic::StreamError { + match error { + connection::StreamError::Connection { source } => { + let source = self + .state + .quic() + .as_ref() + .handle_connection_error(source) + .await; + self.stream + .inner_mut() + .mark_connection_closed(source.clone()); + source.into() + } + connection::StreamError::Reset { code } => { + self.stream.inner_mut().mark_reset(code); + quic::StreamError::Reset { code } + } + connection::StreamError::H3 { source } => { + let code = source.code().into_inner(); + _ = self.stream.stop(code).await; + quic::StreamError::Reset { code } + } + } + } + + pub fn take(&mut self) -> Self { + let taken = self.stream.inner_mut().take(); + Self { + stream: FrameStream::new(StreamReader::new(taken)), + qpack_decoder: self.qpack_decoder.clone(), + state: self.state.clone(), + } + } +} + +impl quic::GetStreamId for MessageReader { + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + Pin::new(&mut self.get_mut().stream).poll_stream_id(cx) + } +} + +impl quic::StopStream for MessageReader { + fn poll_stop( + self: Pin<&mut Self>, + cx: &mut Context, + code: VarInt, + ) -> Poll> { + Pin::new(&mut self.get_mut().stream).poll_stop(cx, code) + } +} + +impl crate::stream::GetStreamId for MessageReader { + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + quic::GetStreamId::poll_stream_id(self, cx) + } +} + +impl crate::stream::StopStream for MessageReader { + fn poll_stop( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + quic::StopStream::poll_stop(self, cx, code) + } +} + +pub struct MessageWriter { + pub(super) stream: SinkWriter, + pub(super) qpack_encoder: Arc, + pub(super) state: ConnectionState, +} + +pub const DEFAULT_COMPRESS_ALGO: DynamicCompressAlgo = + DynamicCompressAlgo::new(HuffmanAlways); + +impl MessageWriter { + pub fn new( + stream: SinkWriter, + qpack_encoder: Arc, + state: ConnectionState, + ) -> Self { + Self { + stream, + qpack_encoder, + state, + } + } + + fn ensure_write_open(&self) -> Result<(), MessageStreamError> { + match self.stream.sink().state_snapshot() { + guard::QuicWriterStateSnapshot::Open => Ok(()), + guard::QuicWriterStateSnapshot::Closed => Err(MessageStreamError::MessageWriterClosed), + guard::QuicWriterStateSnapshot::Reset { code } => Err(MessageStreamError::Quic { + source: quic::StreamError::Reset { code }, + }), + guard::QuicWriterStateSnapshot::ConnectionClosed { source } => { + Err(MessageStreamError::Quic { + source: quic::StreamError::Connection { source }, + }) + } + guard::QuicWriterStateSnapshot::Taken => { + panic!("message writer used after being taken, this is a bug") + } + } + } + + pub async fn write_frame( + &mut self, + frame: Frame, + ) -> Result<(), MessageStreamError> { + self.try_stream_write(async move |this| { + this.stream.encode_one(frame).await?; + Ok(()) + }) + .await + } + + pub async fn write_data_frame( + &mut self, + data: impl Buf + Send, + ) -> Result<(), MessageStreamError> { + let frame = Frame::new(Frame::DATA_FRAME_TYPE, data)?; + self.write_frame(frame).await + } + + pub async fn write_header_frame( + &mut self, + field_lines: impl IntoIterator + Send, + ) -> Result<(), MessageStreamError> { + let result = self + .try_stream_write(async move |this| { + let algo = &DEFAULT_COMPRESS_ALGO; + match Encoder::encode(&*this.qpack_encoder, field_lines, algo, &mut this.stream) + .await + { + Ok(frame) => { + this.stream.encode_one(frame).await?; + Ok(Ok(())) + } + Err(EncodeHeaderSectionError::Encode { source }) => Ok(Err(source)), + Err(EncodeHeaderSectionError::Stream { source }) => Err(source), + } + }) + .await?; + + match result { + Ok(()) => Ok(()), + Err(EncodeError::FramePayloadTooLarge) => Err(MessageStreamError::HeaderTooLarge), + Err(EncodeError::HuffmanEncoding) => { + unreachable!("FieldSection contain invalid header name/value, this is a bug") + } + } + } + + pub async fn write_data(&mut self, data: impl Buf + Send) -> Result<(), MessageStreamError> { + self.write_data_frame(data).await + } + + pub async fn write_header( + &mut self, + field_lines: impl IntoIterator + Send, + ) -> Result<(), MessageStreamError> { + self.write_header_frame(field_lines).await?; + + // Flush encoder instructions (dynamic table insertions) to the encoder stream. + // Encoder stream errors are connection-level: reset = connection error per RFC 9204. + if let Err(error) = self.qpack_encoder.flush_instructions().await { + let quic_error = self.handle_stream_error(error).await; + return Err(MessageStreamError::Quic { source: quic_error }); + } + + Ok(()) + } + + pub async fn send_data(&mut self, data: impl Buf + Send) -> Result<(), MessageStreamError> { + self.write_data(data).await + } + + pub async fn send_header( + &mut self, + field_lines: impl IntoIterator + Send, + ) -> Result<(), MessageStreamError> { + self.write_header(field_lines).await + } + + pub async fn flush(&mut self) -> Result<(), MessageStreamError> { + self.try_stream_write(async move |this| Ok(this.stream.flush_inner().await?)) + .await + } + + pub async fn close(&mut self) -> Result<(), MessageStreamError> { + self.try_stream_write(async move |this| Ok(this.stream.close().await?)) + .await + } + + pub async fn reset(&mut self, code: Code) -> Result<(), MessageStreamError> { + self.try_stream_write(async move |this| Ok(this.stream.reset(code.into_inner()).await?)) + .await + } + + async fn peer_goaway_covers( + &mut self, + ) -> Result> + use<>, quic::StreamError> + { + let stream_id = self.stream.stream_id().await?; + let state = self.state.clone(); + + Ok(async move { + let conn = state.quic().clone(); + let error = conn.closed(); + let dhttp_state = state + .protocols() + .get::() + .unwrap() + .state + .clone(); + tokio::select! { + biased; + _goaway = dhttp_state.peer_goaway_covers(stream_id) => Ok(()), + error = error => Err(error), + } + }) + } + + pub async fn try_stream_write( + &mut self, + f: impl AsyncFnOnce(&mut Self) -> Result, + ) -> Result { + self.ensure_write_open()?; + let peer_goaway = self.peer_goaway_covers().await?; + let f = async move |this: &mut Self| { + let value = f(this).await?; + // ensure all data are written into the underlying QUIC stream + this.stream.flush_buffer().await?; + Ok(value) + }; + tokio::select! { + result = f(self) => match result { + Ok(value) => Ok(value), + Err(error) => Err(MessageStreamError::Quic { + source: self.handle_stream_error(error).await, + }), + }, + goaway = peer_goaway => match goaway { + Ok(()) => { + // FIXME: which code should be used? + _ = self.stream.reset(Code::H3_NO_ERROR.into()).await; + Err(ConnectionGoaway::Peer.into()) + } + Err(error) => Err(error.into()) + } + } + } + + /// Resolve a stream-level H3 error into a `quic::StreamError`. See + /// [`MessageReader::handle_stream_error`] for semantics; the difference is + /// that `H3` errors issue `RESET_STREAM` (via `reset`) on this writer + /// instead of `STOP_SENDING`. + pub async fn handle_stream_error( + &mut self, + error: connection::StreamError, + ) -> quic::StreamError { + match error { + connection::StreamError::Connection { source } => { + let source = self + .state + .quic() + .as_ref() + .handle_connection_error(source) + .await; + self.stream + .sink_mut() + .mark_connection_closed(source.clone()); + source.into() + } + connection::StreamError::Reset { code } => { + self.stream.sink_mut().mark_reset(code); + quic::StreamError::Reset { code } + } + connection::StreamError::H3 { source } => { + let code = source.code().into_inner(); + _ = self.stream.reset(code).await; + quic::StreamError::Reset { code } + } + } + } + + pub fn take(&mut self) -> Self { + let taken = self.stream.sink_mut().take(); + Self { + stream: SinkWriter::new(taken), + qpack_encoder: self.qpack_encoder.clone(), + state: self.state.clone(), + } + } +} + +impl quic::GetStreamId for MessageWriter { + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + Pin::new(&mut self.get_mut().stream).poll_stream_id(cx) + } +} + +impl quic::ResetStream for MessageWriter { + fn poll_reset( + self: Pin<&mut Self>, + cx: &mut Context, + code: VarInt, + ) -> Poll> { + Pin::new(&mut self.get_mut().stream).poll_reset(cx, code) + } +} + +impl crate::stream::GetStreamId for MessageWriter { + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + quic::GetStreamId::poll_stream_id(self, cx) + } +} + +impl crate::stream::ResetStream for MessageWriter { + fn poll_reset( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + quic::ResetStream::poll_reset(self, cx, code) + } +} + +#[derive(Debug, Snafu, Clone)] +pub enum InitialMessageStreamError { + #[snafu(transparent)] + InitialRawStream { + source: InitialRawMessageStreamError, + }, + #[snafu(transparent)] + QPackProtocolDisabled { source: QPackProtocolDisabled }, +} + +#[derive(Debug, Snafu, Clone)] +pub enum AcceptMessageStreamError { + #[snafu(transparent)] + AcceptRawStream { source: AcceptRawMessageStreamError }, + #[snafu(transparent)] + QPackProtocolDisabled { source: QPackProtocolDisabled }, +} + +impl ConnectionState { + pub async fn initial_message_stream( + &self, + ) -> Result<(MessageReader, MessageWriter), InitialMessageStreamError> { + let state = self.erase(); + let qpack = self.qpack()?; + let (mut reader, writer) = self.initial_raw_message_stream().await?; + let stream_id = reader.stream_id().await.map_err(|source| { + InitialMessageStreamError::InitialRawStream { + source: InitialRawMessageStreamError::ResponseStream { source }, + } + })?; + Ok(( + MessageReader::new(stream_id, reader, qpack.decoder.clone(), state.clone()), + MessageWriter::new(writer, qpack.encoder.clone(), state), + )) + } + + pub async fn accept_message_stream( + &self, + ) -> Result<(MessageReader, MessageWriter), AcceptMessageStreamError> { + let state = self.erase(); + let qpack = self.qpack()?; + let (mut reader, writer) = self.accept_raw_message_stream().await?; + let stream_id = reader.stream_id().await.map_err(|source| { + AcceptMessageStreamError::AcceptRawStream { + source: AcceptRawMessageStreamError::RequestStream { source }, + } + })?; + Ok(( + MessageReader::new(stream_id, reader, qpack.decoder.clone(), state.clone()), + MessageWriter::new(writer, qpack.encoder.clone(), state), + )) + } +} + +#[cfg(test)] +mod tests { + use std::{ + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::Duration, + }; + + use bytes::{Buf, Bytes}; + use futures::{Sink, SinkExt, Stream, future::poll_fn}; + use tokio::{sync::mpsc, time::timeout}; + + use super::{MessageReader, MessageStreamError, MessageWriter, guard}; + use crate::{ + codec::{ + BoxPeekableStreamReader, BoxStreamWriter, PeekableStreamReader, SinkWriter, + StreamReader, + }, + connection::{ConnectionState, StreamError, tests::MockConnection}, + dhttp::{ + goaway::Goaway, + message::test::{read_stream_for_test, write_stream_for_test}, + protocol::DHttpProtocol, + settings::Settings, + }, + error::Code, + protocol::{Protocol, Protocols, StreamVerdict}, + qpack::protocol::{QPackDecoder, QPackEncoder, QPackProtocolFactory}, + quic::{self, GetStreamId, GetStreamIdExt, ResetStream, StopStream}, + varint::VarInt, + }; + + #[test] + fn io_error_kind_is_derived_per_variant() { + use std::io::ErrorKind; + + fn assert_kind(error: MessageStreamError, expected: ErrorKind) { + let repr = format!("{error:?}"); + let io_error = std::io::Error::from(error); + assert_eq!(io_error.kind(), expected, "unexpected kind for {repr}"); + } + + assert_kind(MessageStreamError::HeaderTooLarge, ErrorKind::InvalidInput); + assert_kind(MessageStreamError::TrailerTooLarge, ErrorKind::InvalidInput); + assert_kind( + MessageStreamError::DataFrameTooLarge { + source: crate::varint::VarInt::from_u64(1 << 63) + .expect_err("value exceeds varint encoding"), + }, + ErrorKind::InvalidInput, + ); + assert_kind( + MessageStreamError::MalformedIncomingMessage, + ErrorKind::InvalidData, + ); + assert_kind( + MessageStreamError::MalformedOutgoingMessage, + ErrorKind::InvalidInput, + ); + assert_kind( + MessageStreamError::Goaway { + source: crate::connection::ConnectionGoaway::Peer, + }, + ErrorKind::ConnectionAborted, + ); + assert_kind( + MessageStreamError::Quic { + source: quic::StreamError::Reset { + code: VarInt::from_u32(0), + }, + }, + ErrorKind::BrokenPipe, + ); + } + + fn qpack_decoder_sink() + -> Pin + Send>> + { + Box::pin( + futures::sink::drain::() + .sink_map_err(|never| match never {}), + ) + } + + fn qpack_decoder_stream() -> Pin< + Box< + dyn Stream> + + Send, + >, + > { + Box::pin(futures::stream::empty::< + Result, + >()) + } + + fn qpack_encoder_sink() + -> Pin + Send>> + { + Box::pin( + futures::sink::drain::() + .sink_map_err(|never| match never {}), + ) + } + + fn qpack_encoder_stream() -> Pin< + Box< + dyn Stream> + + Send, + >, + > { + Box::pin(futures::stream::empty::< + Result, + >()) + } + + #[derive(Debug)] + struct TestReadStream { + stream_id: VarInt, + } + + impl quic::GetStreamId for TestReadStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.get_mut().stream_id)) + } + } + + impl quic::StopStream for TestReadStream { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context, + _code: VarInt, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl Stream for TestReadStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(None) + } + } + + #[derive(Debug)] + struct TestWriteStream { + stream_id: VarInt, + } + + impl quic::GetStreamId for TestWriteStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.get_mut().stream_id)) + } + } + + impl quic::ResetStream for TestWriteStream { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context, + _code: VarInt, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl Sink for TestWriteStream { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, _item: Bytes) -> Result<(), Self::Error> { + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + #[derive(Debug)] + struct StreamIdErrorReadStream { + error: quic::StreamError, + } + + impl quic::GetStreamId for StreamIdErrorReadStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Err(self.get_mut().error.clone())) + } + } + + impl quic::StopStream for StreamIdErrorReadStream { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context, + _code: VarInt, + ) -> Poll> { + let _ = self; + Poll::Ready(Ok(())) + } + } + + impl Stream for StreamIdErrorReadStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let _ = self; + Poll::Ready(None) + } + } + + #[derive(Debug)] + struct StreamIdErrorWriteStream { + error: quic::StreamError, + } + + impl quic::GetStreamId for StreamIdErrorWriteStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Err(self.get_mut().error.clone())) + } + } + + impl quic::ResetStream for StreamIdErrorWriteStream { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context, + _code: VarInt, + ) -> Poll> { + let _ = self; + Poll::Ready(Ok(())) + } + } + + impl Sink for StreamIdErrorWriteStream { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let _ = self; + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, _item: Bytes) -> Result<(), Self::Error> { + let _ = self; + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let _ = self; + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let _ = self; + Poll::Ready(Ok(())) + } + } + + #[derive(Debug)] + struct TrackedReadStream { + stream_id: VarInt, + stop_tx: mpsc::UnboundedSender, + } + + impl quic::GetStreamId for TrackedReadStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.get_mut().stream_id)) + } + } + + impl quic::StopStream for TrackedReadStream { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.get_mut() + .stop_tx + .send(code) + .expect("tracked reader receiver should still be alive"); + Poll::Ready(Ok(())) + } + } + + impl Stream for TrackedReadStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let _ = self; + Poll::Ready(None) + } + } + + #[derive(Debug)] + struct TrackedWriteStream { + stream_id: VarInt, + reset_tx: mpsc::UnboundedSender, + } + + impl quic::GetStreamId for TrackedWriteStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.get_mut().stream_id)) + } + } + + impl quic::ResetStream for TrackedWriteStream { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.get_mut() + .reset_tx + .send(code) + .expect("tracked writer receiver should still be alive"); + Poll::Ready(Ok(())) + } + } + + impl Sink for TrackedWriteStream { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let _ = self; + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, _item: Bytes) -> Result<(), Self::Error> { + let _ = self; + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let _ = self; + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let _ = self; + Poll::Ready(Ok(())) + } + } + + fn state_without_qpack(quic: Arc) -> ConnectionState { + let erased: Arc = quic.clone(); + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased)); + ConnectionState::new_for_test(quic, Arc::new(protocols)) + } + + fn transport_error(reason: &'static str) -> quic::ConnectionError { + quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(1), + frame_type: VarInt::from_u32(0), + reason: reason.into(), + }, + } + } + + async fn state_with_qpack(quic: Arc) -> ConnectionState { + let erased: Arc = quic.clone(); + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased)); + let qpack = QPackProtocolFactory::new() + .init(&quic, &protocols) + .await + .expect("qpack protocol should initialize for tests"); + protocols.insert(qpack); + ConnectionState::new_for_test(quic, Arc::new(protocols)) + } + + async fn test_peekable_bi_stream_with_bytes( + stream_id: u32, + bytes: &[u8], + ) -> (BoxPeekableStreamReader, BoxStreamWriter) { + let (reader, mut write_side) = quic::test::mock_stream_pair(VarInt::from_u32(stream_id)); + write_side + .send(Bytes::copy_from_slice(bytes)) + .await + .expect("write test bidi bytes"); + write_side.close().await.expect("close test bidi read side"); + + let (_read_side, writer) = quic::test::mock_stream_pair(VarInt::from_u32(stream_id)); + ( + PeekableStreamReader::new(StreamReader::new( + Box::pin(reader) as crate::quic::BoxQuicStreamReader + )), + SinkWriter::new(Box::pin(writer) as crate::quic::BoxQuicStreamWriter), + ) + } + + async fn read_stream_with_bytes(stream_id: u32, bytes: &[u8]) -> MessageReader { + let (reader, mut writer) = quic::test::mock_stream_pair(VarInt::from_u32(stream_id)); + writer + .send(Bytes::copy_from_slice(bytes)) + .await + .expect("write test stream bytes"); + writer.close().await.expect("close test stream writer"); + + MessageReader::new( + VarInt::from_u32(stream_id), + StreamReader::new(guard::GuardQuicReader::new( + Box::pin(reader) as crate::quic::BoxQuicStreamReader + )), + Arc::new(QPackDecoder::new( + Arc::new(Settings::default()), + qpack_decoder_sink(), + qpack_decoder_stream(), + )), + state_without_qpack(Arc::new(MockConnection::new())).erase(), + ) + } + + fn paired_message_streams(stream_id: u32) -> (MessageReader, MessageWriter) { + let (reader, writer) = quic::test::mock_stream_pair(VarInt::from_u32(stream_id)); + let read_stream = MessageReader::new( + VarInt::from_u32(stream_id), + StreamReader::new(guard::GuardQuicReader::new( + Box::pin(reader) as crate::quic::BoxQuicStreamReader + )), + Arc::new(QPackDecoder::new( + Arc::new(Settings::default()), + qpack_decoder_sink(), + qpack_decoder_stream(), + )), + state_without_qpack(Arc::new(MockConnection::new())).erase(), + ); + let write_stream = MessageWriter::new( + SinkWriter::new(guard::GuardQuicWriter::new( + Box::pin(writer) as crate::quic::BoxQuicStreamWriter + )), + Arc::new(QPackEncoder::new( + Arc::new(Settings::default()), + qpack_encoder_sink(), + qpack_encoder_stream(), + )), + state_without_qpack(Arc::new(MockConnection::new())).erase(), + ); + + (read_stream, write_stream) + } + + #[tokio::test] + async fn read_stream_try_stream_read_aborts_when_peer_goaway_covers_stream() { + let erased: Arc = Arc::new(MockConnection::new()); + + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased.clone())); + let state = ConnectionState::new_for_test(erased, Arc::new(protocols)); + + let reader = StreamReader::new(guard::GuardQuicReader::new(Box::pin(TestReadStream { + stream_id: VarInt::from_u32(10), + }) + as crate::quic::BoxQuicStreamReader)); + let mut read_stream = MessageReader::new( + VarInt::from_u32(10), + reader, + Arc::new(QPackDecoder::new( + Arc::new(Settings::default()), + qpack_decoder_sink(), + qpack_decoder_stream(), + )), + state.clone(), + ); + + state + .dhttp() + .apply_peer_goaway(Goaway::new(VarInt::from_u32(9))) + .expect("peer goaway should be accepted"); + + let result = read_stream + .try_stream_read(async move |_this| { + futures::future::pending::>().await + }) + .await; + + assert!(matches!( + result, + Err(super::MessageStreamError::Goaway { + source: crate::connection::ConnectionGoaway::Peer + }) + )); + } + + #[tokio::test] + async fn write_stream_try_stream_write_aborts_when_peer_goaway_covers_stream() { + let erased: Arc = Arc::new(MockConnection::new()); + + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased.clone())); + let state = ConnectionState::new_for_test(erased, Arc::new(protocols)); + + let writer = SinkWriter::new(guard::GuardQuicWriter::new(Box::pin(TestWriteStream { + stream_id: VarInt::from_u32(12), + }) + as crate::quic::BoxQuicStreamWriter)); + let mut write_stream = MessageWriter::new( + writer, + Arc::new(QPackEncoder::new( + Arc::new(Settings::default()), + qpack_encoder_sink(), + qpack_encoder_stream(), + )), + state.clone(), + ); + + state + .dhttp() + .apply_peer_goaway(Goaway::new(VarInt::from_u32(11))) + .expect("peer goaway should be accepted"); + + let result = write_stream + .try_stream_write(async move |_this| { + futures::future::pending::>().await + }) + .await; + + assert!(matches!( + result, + Err(super::MessageStreamError::Goaway { + source: crate::connection::ConnectionGoaway::Peer + }) + )); + } + + #[tokio::test] + async fn initial_message_stream_requires_qpack_protocol() { + let state = state_without_qpack(Arc::new(MockConnection::new())); + + let result = state.initial_message_stream().await; + + assert!(matches!( + result, + Err(super::InitialMessageStreamError::QPackProtocolDisabled { .. }) + )); + } + + #[tokio::test] + async fn initial_message_stream_wraps_initial_raw_stream_errors() { + let quic = Arc::new(MockConnection::new()); + let state = state_with_qpack(quic).await; + + let result = state.initial_message_stream().await; + + assert!(matches!( + result, + Err(super::InitialMessageStreamError::InitialRawStream { + source: crate::dhttp::protocol::InitialRawMessageStreamError::Connection { .. } + }) + )); + } + + #[tokio::test] + async fn initial_message_stream_wraps_successful_raw_streams() { + let quic = Arc::new(MockConnection::new()); + quic.enable_stream_ops(); + let state = state_with_qpack(quic).await; + + let (mut reader, mut writer) = state + .initial_message_stream() + .await + .expect("initial message stream should open"); + + assert_eq!( + reader.stream_id().await.expect("reader stream id"), + VarInt::from_u32(0) + ); + assert_eq!( + writer.stream_id().await.expect("writer stream id"), + VarInt::from_u32(0) + ); + assert_eq!(state.max_initialized_stream_id(), Some(VarInt::from_u32(0))); + } + + #[tokio::test] + async fn accept_message_stream_requires_qpack_protocol() { + let state = state_without_qpack(Arc::new(MockConnection::new())); + + let result = state.accept_message_stream().await; + + assert!(matches!( + result, + Err(super::AcceptMessageStreamError::QPackProtocolDisabled { .. }) + )); + } + + #[tokio::test] + async fn accept_message_stream_wraps_goaway_rejections() { + let quic = Arc::new(MockConnection::new()); + let state = state_with_qpack(quic).await; + state + .dhttp() + .local_goaway + .set(Goaway::new(VarInt::from_u32(12))); + assert!(matches!( + state + .dhttp() + .accept_bi(test_peekable_bi_stream_with_bytes(12, &[0x01]).await) + .await + .expect("request stream should be routed"), + StreamVerdict::Accepted + )); + + let result = state.accept_message_stream().await; + + assert!(matches!( + result, + Err(super::AcceptMessageStreamError::AcceptRawStream { + source: crate::dhttp::protocol::AcceptRawMessageStreamError::Goaway { + source: crate::connection::ConnectionGoaway::Local + } + }) + )); + } + + #[tokio::test] + async fn accept_message_stream_wraps_successful_raw_streams() { + let quic = Arc::new(MockConnection::new()); + let state = state_with_qpack(quic).await; + assert!(matches!( + state + .dhttp() + .accept_bi(test_peekable_bi_stream_with_bytes(6, &[0x01]).await) + .await + .expect("request stream should be routed"), + StreamVerdict::Accepted + )); + + let (mut reader, mut writer) = state + .accept_message_stream() + .await + .expect("accepted message stream should wrap raw streams"); + + assert_eq!( + reader.stream_id().await.expect("reader stream id"), + VarInt::from_u32(6) + ); + assert_eq!( + writer.stream_id().await.expect("writer stream id"), + VarInt::from_u32(6) + ); + assert_eq!(state.max_received_stream_id(), Some(VarInt::from_u32(6))); + } + + #[tokio::test] + async fn read_stream_stop_emits_qpack_stream_cancellation() { + let stream_id = VarInt::from_u32(24); + let stop_code = VarInt::from_u32(25); + let state = state_without_qpack(Arc::new(MockConnection::new())).erase(); + let decoder = Arc::new(QPackDecoder::new( + Arc::new(Settings::default()), + qpack_decoder_sink(), + qpack_decoder_stream(), + )); + let (reader, _writer) = quic::test::mock_stream_pair(stream_id); + let mut stream = MessageReader::new( + stream_id, + StreamReader::new(guard::GuardQuicReader::new( + Box::pin(reader) as crate::quic::BoxQuicStreamReader + )), + decoder.clone(), + state, + ); + + poll_fn(|cx| Pin::new(&mut stream).poll_stop(cx, stop_code)) + .await + .expect("read stream stop should complete"); + + assert_eq!( + decoder + .state + .lock() + .expect("lock is not poisoned") + .pending_instructions + .back(), + Some( + &crate::qpack::decoder::DecoderInstruction::StreamCancellation { + stream_id: stream_id.into_inner() + } + ) + ); + } + + #[tokio::test] + async fn read_stream_peer_goaway_future_returns_connection_error_when_connection_closes() { + let quic = Arc::new(MockConnection::new()); + let state = state_without_qpack(quic.clone()).erase(); + let reader = StreamReader::new(guard::GuardQuicReader::new(Box::pin(TestReadStream { + stream_id: VarInt::from_u32(2), + }) + as crate::quic::BoxQuicStreamReader)); + let mut stream = MessageReader::new( + VarInt::from_u32(2), + reader, + Arc::new(QPackDecoder::new( + Arc::new(Settings::default()), + qpack_decoder_sink(), + qpack_decoder_stream(), + )), + state, + ); + quic.set_terminal_error(quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(1), + frame_type: VarInt::from_u32(0), + reason: "peer closed".into(), + }, + }); + + let result = stream + .peer_goaway_covers() + .await + .expect("stream id should resolve") + .await; + + assert!(matches!( + result, + Err(quic::ConnectionError::Transport { .. }) + )); + } + + #[tokio::test] + async fn write_stream_peer_goaway_future_returns_connection_error_when_connection_closes() { + let quic = Arc::new(MockConnection::new()); + let state = state_without_qpack(quic.clone()).erase(); + let writer = SinkWriter::new(guard::GuardQuicWriter::new(Box::pin(TestWriteStream { + stream_id: VarInt::from_u32(3), + }) + as crate::quic::BoxQuicStreamWriter)); + let mut stream = MessageWriter::new( + writer, + Arc::new(QPackEncoder::new( + Arc::new(Settings::default()), + qpack_encoder_sink(), + qpack_encoder_stream(), + )), + state, + ); + quic.set_terminal_error(quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(1), + frame_type: VarInt::from_u32(0), + reason: "peer closed".into(), + }, + }); + + let result = stream + .peer_goaway_covers() + .await + .expect("stream id should resolve") + .await; + + assert!(matches!( + result, + Err(quic::ConnectionError::Transport { .. }) + )); + } + + #[tokio::test] + async fn read_stream_try_stream_read_uses_constructor_stream_id() { + let state = state_without_qpack(Arc::new(MockConnection::new())).erase(); + let reader = StreamReader::new(guard::GuardQuicReader::new(Box::pin( + StreamIdErrorReadStream { + error: quic::StreamError::Reset { + code: VarInt::from_u32(91), + }, + }, + ) + as crate::quic::BoxQuicStreamReader)); + let mut stream = MessageReader::new( + VarInt::from_u32(91), + reader, + Arc::new(QPackDecoder::new( + Arc::new(Settings::default()), + qpack_decoder_sink(), + qpack_decoder_stream(), + )), + state, + ); + stream + .state + .dhttp() + .apply_peer_goaway(Goaway::new(VarInt::from_u32(90))) + .expect("peer goaway should be accepted"); + + let result: Result<(), MessageStreamError> = stream + .try_stream_read(async |_this| { + futures::future::pending::>().await + }) + .await; + + assert!(matches!( + result, + Err(MessageStreamError::Goaway { + source: crate::connection::ConnectionGoaway::Peer + }) + )); + } + + #[tokio::test] + async fn write_stream_try_stream_write_surfaces_stream_id_errors() { + let state = state_without_qpack(Arc::new(MockConnection::new())).erase(); + let reset_code = VarInt::from_u32(92); + let writer = SinkWriter::new(guard::GuardQuicWriter::new(Box::pin( + StreamIdErrorWriteStream { + error: quic::StreamError::Reset { code: reset_code }, + }, + ) + as crate::quic::BoxQuicStreamWriter)); + let mut stream = MessageWriter::new( + writer, + Arc::new(QPackEncoder::new( + Arc::new(Settings::default()), + qpack_encoder_sink(), + qpack_encoder_stream(), + )), + state, + ); + + let result: Result<(), MessageStreamError> = stream + .try_stream_write(async |_this| panic!("closure should not run when stream id fails")) + .await; + + assert!(matches!( + result, + Err(MessageStreamError::Quic { + source: quic::StreamError::Reset { code } + }) if code == reset_code + )); + } + + #[test] + fn message_stream_error_conversions_cover_connection_and_send_failure_variants() { + use std::io::ErrorKind; + + let error = MessageStreamError::from(transport_error("connection conversion")); + assert!(matches!( + error, + MessageStreamError::Quic { + source: quic::StreamError::Connection { .. } + } + )); + + let io_error = std::io::Error::from(MessageStreamError::MessageSendFailed); + assert_eq!(io_error.kind(), ErrorKind::BrokenPipe); + } + + #[test] + fn message_stream_error_io_kind_delegates_quic_connection_errors() { + use std::io::ErrorKind; + + let io_error = + std::io::Error::from(MessageStreamError::from(transport_error("io conversion"))); + + assert_eq!(io_error.kind(), ErrorKind::BrokenPipe); + assert!(io_error.to_string().contains("io conversion")); + } + + #[tokio::test] + async fn oversized_write_data_frame_reports_message_error() { + let mut stream = write_stream_for_test(VarInt::from_u32(0)); + + let error = stream + .write_data_frame(OversizedBuf) + .await + .expect_err("oversized frame should fail before writing"); + + assert!(matches!( + error, + MessageStreamError::DataFrameTooLarge { .. } + )); + } + + #[tokio::test] + async fn read_data_frame_chunk_errors_when_data_payload_missing() { + let mut stream = read_stream_with_bytes(23, &[0x00, 0x03]).await; + + let error = stream + .read_data_frame_chunk() + .await + .expect_err("missing DATA payload should fail immediately"); + + assert!(matches!( + error, + StreamError::Connection { + source: crate::connection::ConnectionError::H3 { source } + } if source.code() == Code::H3_FRAME_ERROR + )); + } + + #[tokio::test] + async fn truncated_data_frame_reports_error_after_streaming_available_chunk() { + let mut data_stream = read_stream_with_bytes(23, &[0x00, 0x03, b'a']).await; + // DATA frame bodies are streamed as chunks arrive; an EOF shorter + // than the declared frame length is detected when the frame is polled + // again. + let data_error = data_stream + .read_data_frame_chunk() + .await + .expect("first truncated DATA chunk is yielded before frame EOF is known"); + assert_eq!(data_error, Some(Bytes::from_static(b"a"))); + let data_error = data_stream + .read_data_frame_chunk() + .await + .expect_err("truncated DATA payload should fail when the frame is resumed"); + assert!(matches!( + data_error, + StreamError::Connection { + source: crate::connection::ConnectionError::H3 { source } + } if source.code() == Code::H3_FRAME_ERROR + )); + } + + #[tokio::test] + async fn malformed_frame_header_is_reported_by_peek_header_path() { + let mut header_stream = read_stream_with_bytes(24, &[0x00]).await; + let header_error = header_stream + .read_header_frame() + .await + .expect_err("truncated frame header should fail while peeking"); + assert!(matches!( + header_error, + StreamError::Connection { + source: crate::connection::ConnectionError::H3 { source } + } if source.code() == Code::H3_FRAME_ERROR + )); + } + + #[tokio::test] + async fn peek_frame_replays_stored_frame_decode_error() { + let mut stream = read_stream_with_bytes(28, &[0x00]).await; + + let first = stream + .peek_frame() + .await + .expect("malformed frame should produce an error"); + let first = first.err().expect("truncated frame header should fail"); + assert!(matches!( + first, + StreamError::Connection { + source: crate::connection::ConnectionError::H3 { source } + } if source.code() == Code::H3_FRAME_ERROR + )); + + let replayed = stream + .peek_frame() + .await + .expect("stored frame error should still be visible"); + let replayed = replayed + .err() + .expect("stored decode error should be replayed"); + assert!(matches!( + replayed, + StreamError::Connection { + source: crate::connection::ConnectionError::H3 { source } + } if source.code() == Code::H3_FRAME_ERROR + )); + } + + #[tokio::test] + async fn read_header_frame_returns_none_when_next_frame_is_data() { + let mut stream = read_stream_with_bytes(25, &[0x00, 0x03, b'a', b'b', b'c']).await; + + assert!( + stream + .read_header_frame() + .await + .expect("DATA frame should be valid") + .is_none() + ); + } + + #[tokio::test] + async fn try_stream_read_and_write_surface_connection_close_from_goaway_future() { + let read_quic = Arc::new(MockConnection::new()); + let read_state = state_without_qpack(read_quic.clone()).erase(); + read_quic.set_terminal_error(transport_error("read try_stream_read connection close")); + let reader = StreamReader::new(guard::GuardQuicReader::new(Box::pin(TestReadStream { + stream_id: VarInt::from_u32(26), + }) + as crate::quic::BoxQuicStreamReader)); + let mut read_stream = MessageReader::new( + VarInt::from_u32(26), + reader, + Arc::new(QPackDecoder::new( + Arc::new(Settings::default()), + qpack_decoder_sink(), + qpack_decoder_stream(), + )), + read_state, + ); + let read_result = read_stream + .try_stream_read(async |_this| { + futures::future::pending::>().await + }) + .await; + assert!(matches!( + read_result, + Err(MessageStreamError::Quic { + source: quic::StreamError::Connection { .. } + }) + )); + + let write_quic = Arc::new(MockConnection::new()); + let write_state = state_without_qpack(write_quic.clone()).erase(); + write_quic.set_terminal_error(transport_error("write try_stream_write connection close")); + let writer = SinkWriter::new(guard::GuardQuicWriter::new(Box::pin(TestWriteStream { + stream_id: VarInt::from_u32(27), + }) + as crate::quic::BoxQuicStreamWriter)); + let mut write_stream = MessageWriter::new( + writer, + Arc::new(QPackEncoder::new( + Arc::new(Settings::default()), + qpack_encoder_sink(), + qpack_encoder_stream(), + )), + write_state, + ); + let write_result = write_stream + .try_stream_write(async |_this| { + futures::future::pending::>().await + }) + .await; + assert!(matches!( + write_result, + Err(MessageStreamError::Quic { + source: quic::StreamError::Connection { .. } + }) + )); + } + + #[tokio::test] + async fn stream_read_and_write_error_branches_delegate_connection_errors_to_lifecycle() { + let read_quic = Arc::new(MockConnection::new()); + let read_state = state_without_qpack(read_quic.clone()).erase(); + let terminal_error = transport_error("read lifecycle closed"); + read_quic.set_terminal_error(terminal_error.clone()); + let mut read_stream = read_stream_for_test(VarInt::from_u32(28)); + read_stream.state = read_state; + + let read_error = read_stream + .try_stream_read(async |_this| { + Err::<(), _>(StreamError::Connection { + source: crate::error::H3FrameUnexpected::UnexpectedFrameType.into(), + }) + }) + .await; + assert!(matches!( + read_error, + Err(MessageStreamError::Quic { + source: quic::StreamError::Connection { + source: quic::ConnectionError::Transport { source } + } + }) if source.reason == "read lifecycle closed" + )); + assert!( + read_quic.close_calls().is_empty(), + "already-closed connections should return their terminal error without closing again" + ); + + let write_quic = Arc::new(MockConnection::new()); + let write_state = state_without_qpack(write_quic.clone()).erase(); + let terminal_error = transport_error("write lifecycle closed"); + write_quic.set_terminal_error(terminal_error.clone()); + let mut write_stream = write_stream_for_test(VarInt::from_u32(29)); + write_stream.state = write_state; + + let write_error = write_stream + .try_stream_write(async |_this| { + Err::<(), _>(StreamError::Connection { + source: crate::error::H3FrameUnexpected::UnexpectedFrameType.into(), + }) + }) + .await; + assert!(matches!( + write_error, + Err(MessageStreamError::Quic { + source: quic::StreamError::Connection { + source: quic::ConnectionError::Transport { source } + } + }) if source.reason == "write lifecycle closed" + )); + assert!( + write_quic.close_calls().is_empty(), + "already-closed connections should return their terminal error without closing again" + ); + } + + #[tokio::test] + async fn simple_test_stream_helpers_poll_to_ready() { + let mut reader = TestReadStream { + stream_id: VarInt::from_u32(30), + }; + assert!( + poll_fn(|cx| Pin::new(&mut reader).poll_next(cx)) + .await + .is_none() + ); + + let mut stream_id_error_reader = StreamIdErrorReadStream { + error: quic::StreamError::Reset { + code: VarInt::from_u32(31), + }, + }; + assert!( + poll_fn(|cx| Pin::new(&mut stream_id_error_reader).poll_next(cx)) + .await + .is_none() + ); + poll_fn(|cx| Pin::new(&mut stream_id_error_reader).poll_stop(cx, VarInt::from_u32(0))) + .await + .expect("stream-id-error reader stop should still be ready"); + + let mut tracked_reader = TrackedReadStream { + stream_id: VarInt::from_u32(32), + stop_tx: mpsc::unbounded_channel().0, + }; + assert!( + poll_fn(|cx| Pin::new(&mut tracked_reader).poll_next(cx)) + .await + .is_none() + ); + + let mut writer = TestWriteStream { + stream_id: VarInt::from_u32(33), + }; + poll_fn(|cx| Pin::new(&mut writer).poll_ready(cx)) + .await + .expect("test writer should be ready"); + Pin::new(&mut writer) + .start_send(Bytes::from_static(b"x")) + .expect("test writer should accept bytes"); + poll_fn(|cx| Pin::new(&mut writer).poll_flush(cx)) + .await + .expect("test writer should flush"); + poll_fn(|cx| Pin::new(&mut writer).poll_close(cx)) + .await + .expect("test writer should close"); + + let mut stream_id_error_writer = StreamIdErrorWriteStream { + error: quic::StreamError::Reset { + code: VarInt::from_u32(34), + }, + }; + poll_fn(|cx| Pin::new(&mut stream_id_error_writer).poll_reset(cx, VarInt::from_u32(0))) + .await + .expect("stream-id-error writer reset should be ready"); + poll_fn(|cx| Pin::new(&mut stream_id_error_writer).poll_ready(cx)) + .await + .expect("stream-id-error writer should be ready"); + Pin::new(&mut stream_id_error_writer) + .start_send(Bytes::from_static(b"x")) + .expect("stream-id-error writer should accept bytes"); + poll_fn(|cx| Pin::new(&mut stream_id_error_writer).poll_flush(cx)) + .await + .expect("stream-id-error writer should flush"); + poll_fn(|cx| Pin::new(&mut stream_id_error_writer).poll_close(cx)) + .await + .expect("stream-id-error writer should close"); + + let (reset_tx, _reset_rx) = mpsc::unbounded_channel(); + let mut tracked_writer = TrackedWriteStream { + stream_id: VarInt::from_u32(35), + reset_tx, + }; + poll_fn(|cx| Pin::new(&mut tracked_writer).poll_ready(cx)) + .await + .expect("tracked writer should be ready"); + Pin::new(&mut tracked_writer) + .start_send(Bytes::from_static(b"x")) + .expect("tracked writer should accept bytes"); + poll_fn(|cx| Pin::new(&mut tracked_writer).poll_flush(cx)) + .await + .expect("tracked writer should flush"); + poll_fn(|cx| Pin::new(&mut tracked_writer).poll_close(cx)) + .await + .expect("tracked writer should close"); + } + + #[tokio::test] + async fn read_data_frame_chunk_skips_reserved_and_empty_frames_before_payload() { + let mut stream = read_stream_with_bytes( + 16, + &[ + 0x21, 0x00, // reserved frame + 0x00, 0x00, // empty DATA frame + 0x00, 0x03, b'a', b'b', b'c', + ], + ) + .await; + + assert_eq!( + stream + .read_data_frame_chunk() + .await + .expect("payload chunk should decode"), + Some(Bytes::from_static(b"abc")) + ); + assert_eq!( + stream + .read_data_frame_chunk() + .await + .expect("stream should end cleanly after payload"), + None + ); + } + + #[tokio::test] + async fn read_data_frame_chunk_reports_unexpected_frame_types() { + let mut stream = read_stream_with_bytes(17, &[0x04, 0x00]).await; + + let error = stream + .read_data_frame_chunk() + .await + .expect_err("settings frame on request stream should be rejected"); + + assert!(matches!( + error, + StreamError::Connection { + source: crate::connection::ConnectionError::H3 { source } + } if source.code() == Code::H3_FRAME_UNEXPECTED + )); + } + + #[tokio::test] + async fn read_header_frame_decodes_empty_headers_and_then_eof() { + let (mut reader, mut writer) = paired_message_streams(18); + writer + .write_header(std::iter::empty::()) + .await + .expect("empty header section should encode"); + writer.close().await.expect("writer should close cleanly"); + + let header = reader + .read_header_frame() + .await + .expect("header frame should decode") + .expect("header frame should be present"); + assert!(header.is_empty()); + assert!( + reader + .read_header_frame() + .await + .expect("stream should end cleanly") + .is_none() + ); + } + + #[tokio::test] + async fn read_and_write_stream_traits_delegate_to_inner_streams() { + let state = state_without_qpack(Arc::new(MockConnection::new())).erase(); + let stop_code = VarInt::from_u32(51); + let reset_code = VarInt::from_u32(52); + let (stop_tx, mut stop_rx) = mpsc::unbounded_channel(); + let (reset_tx, mut reset_rx) = mpsc::unbounded_channel(); + let mut reader = MessageReader::new( + VarInt::from_u32(19), + StreamReader::new(guard::GuardQuicReader::new(Box::pin(TrackedReadStream { + stream_id: VarInt::from_u32(19), + stop_tx, + }) + as crate::quic::BoxQuicStreamReader)), + Arc::new(QPackDecoder::new( + Arc::new(Settings::default()), + qpack_decoder_sink(), + qpack_decoder_stream(), + )), + state.clone(), + ); + let mut writer = MessageWriter::new( + SinkWriter::new(guard::GuardQuicWriter::new(Box::pin(TrackedWriteStream { + stream_id: VarInt::from_u32(20), + reset_tx, + }) + as crate::quic::BoxQuicStreamWriter)), + Arc::new(QPackEncoder::new( + Arc::new(Settings::default()), + qpack_encoder_sink(), + qpack_encoder_stream(), + )), + state, + ); + + assert_eq!( + poll_fn(|cx| Pin::new(&mut reader).poll_stream_id(cx)) + .await + .expect("reader stream id"), + VarInt::from_u32(19) + ); + poll_fn(|cx| Pin::new(&mut reader).poll_stop(cx, stop_code)) + .await + .expect("reader stop"); + assert_eq!( + timeout(Duration::from_secs(1), stop_rx.recv()) + .await + .expect("reader stop should be observed") + .expect("stop code should be sent"), + stop_code, + ); + + assert_eq!( + poll_fn(|cx| Pin::new(&mut writer).poll_stream_id(cx)) + .await + .expect("writer stream id"), + VarInt::from_u32(20) + ); + poll_fn(|cx| Pin::new(&mut writer).poll_reset(cx, reset_code)) + .await + .expect("writer reset"); + assert_eq!( + timeout(Duration::from_secs(1), reset_rx.recv()) + .await + .expect("writer reset should be observed") + .expect("reset code should be sent"), + reset_code, + ); + } + + #[tokio::test] + async fn handle_stream_error_resets_writer_and_stops_reader_streams() { + let state = state_without_qpack(Arc::new(MockConnection::new())).erase(); + let (stop_tx, mut stop_rx) = mpsc::unbounded_channel(); + let (reset_tx, mut reset_rx) = mpsc::unbounded_channel(); + let mut reader = MessageReader::new( + VarInt::from_u32(21), + StreamReader::new(guard::GuardQuicReader::new(Box::pin(TrackedReadStream { + stream_id: VarInt::from_u32(21), + stop_tx, + }) + as crate::quic::BoxQuicStreamReader)), + Arc::new(QPackDecoder::new( + Arc::new(Settings::default()), + qpack_decoder_sink(), + qpack_decoder_stream(), + )), + state.clone(), + ); + let mut writer = MessageWriter::new( + SinkWriter::new(guard::GuardQuicWriter::new(Box::pin(TrackedWriteStream { + stream_id: VarInt::from_u32(22), + reset_tx, + }) + as crate::quic::BoxQuicStreamWriter)), + Arc::new(QPackEncoder::new( + Arc::new(Settings::default()), + qpack_encoder_sink(), + qpack_encoder_stream(), + )), + state, + ); + + let read_error = reader + .handle_stream_error(crate::error::H3MessageError::MissingHeaderSection.into()) + .await; + let write_error = writer + .handle_stream_error(crate::error::H3MessageError::MissingHeaderSection.into()) + .await; + + assert!(matches!( + read_error, + quic::StreamError::Reset { code } if code == VarInt::from(Code::H3_MESSAGE_ERROR) + )); + assert!(matches!( + write_error, + quic::StreamError::Reset { code } if code == VarInt::from(Code::H3_MESSAGE_ERROR) + )); + assert_eq!( + timeout(Duration::from_secs(1), stop_rx.recv()) + .await + .expect("reader stop should be observed") + .expect("stop code should be sent"), + VarInt::from(Code::H3_MESSAGE_ERROR), + ); + assert_eq!( + timeout(Duration::from_secs(1), reset_rx.recv()) + .await + .expect("writer reset should be observed") + .expect("reset code should be sent"), + VarInt::from(Code::H3_MESSAGE_ERROR), + ); + } + + #[tokio::test] + async fn handle_stream_error_passthrough_variants_do_not_reset_streams() { + let state = state_without_qpack(Arc::new(MockConnection::new())).erase(); + let (stop_tx, mut stop_rx) = mpsc::unbounded_channel(); + let (reset_tx, mut reset_rx) = mpsc::unbounded_channel(); + let mut reader = MessageReader::new( + VarInt::from_u32(30), + StreamReader::new(guard::GuardQuicReader::new(Box::pin(TrackedReadStream { + stream_id: VarInt::from_u32(30), + stop_tx, + }) + as crate::quic::BoxQuicStreamReader)), + Arc::new(QPackDecoder::new( + Arc::new(Settings::default()), + qpack_decoder_sink(), + qpack_decoder_stream(), + )), + state.clone(), + ); + let mut writer = MessageWriter::new( + SinkWriter::new(guard::GuardQuicWriter::new(Box::pin(TrackedWriteStream { + stream_id: VarInt::from_u32(31), + reset_tx, + }) + as crate::quic::BoxQuicStreamWriter)), + Arc::new(QPackEncoder::new( + Arc::new(Settings::default()), + qpack_encoder_sink(), + qpack_encoder_stream(), + )), + state, + ); + + let reset_code = VarInt::from_u32(32); + let read_error = reader + .handle_stream_error(StreamError::Reset { code: reset_code }) + .await; + let write_error = writer + .handle_stream_error(StreamError::Reset { code: reset_code }) + .await; + + assert!(matches!( + read_error, + quic::StreamError::Reset { code } if code == reset_code + )); + assert!(matches!( + write_error, + quic::StreamError::Reset { code } if code == reset_code + )); + assert!( + stop_rx.try_recv().is_err(), + "peer reset should not trigger STOP_SENDING" + ); + assert!( + reset_rx.try_recv().is_err(), + "peer reset should not trigger RESET_STREAM" + ); + } + + #[tokio::test] + async fn handle_stream_error_quic_connection_errors_return_source_without_close() { + let read_quic = Arc::new(MockConnection::new()); + let write_quic = Arc::new(MockConnection::new()); + let mut reader = read_stream_for_test(VarInt::from_u32(33)); + let mut writer = write_stream_for_test(VarInt::from_u32(34)); + reader.state = state_without_qpack(read_quic.clone()).erase(); + writer.state = state_without_qpack(write_quic.clone()).erase(); + + let read_error = reader + .handle_stream_error(StreamError::Connection { + source: crate::connection::ConnectionError::Quic { + source: transport_error("read passthrough"), + }, + }) + .await; + let write_error = writer + .handle_stream_error(StreamError::Connection { + source: crate::connection::ConnectionError::Quic { + source: transport_error("write passthrough"), + }, + }) + .await; + + assert!(matches!( + read_error, + quic::StreamError::Connection { + source: quic::ConnectionError::Transport { source } + } if source.reason == "read passthrough" + )); + assert!(matches!( + write_error, + quic::StreamError::Connection { + source: quic::ConnectionError::Transport { source } + } if source.reason == "write passthrough" + )); + assert!(read_quic.close_calls().is_empty()); + assert!(write_quic.close_calls().is_empty()); + } + + #[tokio::test] + async fn read_stream_take_transfers_drop_cleanup_to_taken_wrapper() { + let quic = Arc::new(MockConnection::new()); + let state = state_without_qpack(quic).erase(); + let (stop_tx, mut stop_rx) = mpsc::unbounded_channel(); + let mut stream = MessageReader::new( + VarInt::from_u32(14), + StreamReader::new(guard::GuardQuicReader::new(Box::pin(TrackedReadStream { + stream_id: VarInt::from_u32(14), + stop_tx, + }) + as crate::quic::BoxQuicStreamReader)), + Arc::new(QPackDecoder::new( + Arc::new(Settings::default()), + qpack_decoder_sink(), + qpack_decoder_stream(), + )), + state, + ); + + let mut taken = stream.take(); + assert_eq!( + taken.stream_id().await.expect("taken stream id"), + VarInt::from_u32(14) + ); + + drop(stream); + assert!( + timeout(Duration::from_millis(50), stop_rx.recv()) + .await + .is_err() + ); + + drop(taken); + assert_eq!( + timeout(Duration::from_secs(1), stop_rx.recv()) + .await + .expect("taken read stream drop should stop") + .expect("stop code should be sent"), + VarInt::from(Code::H3_NO_ERROR), + ); + } + + #[tokio::test] + async fn write_stream_take_transfers_drop_cleanup_to_taken_wrapper() { + let quic = Arc::new(MockConnection::new()); + let state = state_without_qpack(quic).erase(); + let (reset_tx, mut reset_rx) = mpsc::unbounded_channel(); + let mut stream = MessageWriter::new( + SinkWriter::new(guard::GuardQuicWriter::new(Box::pin(TrackedWriteStream { + stream_id: VarInt::from_u32(15), + reset_tx, + }) + as crate::quic::BoxQuicStreamWriter)), + Arc::new(QPackEncoder::new( + Arc::new(Settings::default()), + qpack_encoder_sink(), + qpack_encoder_stream(), + )), + state, + ); + + let mut taken = stream.take(); + assert_eq!( + taken.stream_id().await.expect("taken stream id"), + VarInt::from_u32(15) + ); + + drop(stream); + assert!( + timeout(Duration::from_millis(50), reset_rx.recv()) + .await + .is_err() + ); + + drop(taken); + assert_eq!( + timeout(Duration::from_secs(1), reset_rx.recv()) + .await + .expect("taken write stream drop should reset") + .expect("reset code should be sent"), + VarInt::from(Code::H3_NO_ERROR), + ); + } + + #[tokio::test] + async fn write_after_close_fast_fails_without_polling_inner_stream() { + let mut stream = write_stream_for_test(VarInt::from_u32(0)); + + stream.close().await.expect("initial close should succeed"); + let error = stream + .write_data(Bytes::from_static(b"after close")) + .await + .expect_err("write after close should fail at message layer"); + + assert!(matches!(error, MessageStreamError::MessageWriterClosed)); + } + + #[tokio::test] + async fn reader_stop_does_not_close_receive_side() { + let mut stream = read_stream_with_bytes(41, &[0x00, 0x03, b'a', b'b', b'c']).await; + + stream + .stop(Code::H3_NO_ERROR) + .await + .expect("stop should not close receive side"); + + assert_eq!( + stream + .read_data_frame_chunk() + .await + .expect("stopped reader should still yield buffered data"), + Some(Bytes::from_static(b"abc")) + ); + assert_eq!( + stream + .read_data_frame_chunk() + .await + .expect("reader should close only after EOF"), + None + ); + } + + #[tokio::test] + async fn write_stream_flush_close_reset_and_header_aliases_succeed_on_mock_stream() { + let mut header_stream = write_stream_for_test(VarInt::from_u32(0)); + header_stream + .write_header(std::iter::empty()) + .await + .expect("empty header section should encode"); + + let mut send_header_stream = write_stream_for_test(VarInt::from_u32(0)); + send_header_stream + .send_header(std::iter::empty()) + .await + .expect("send_header should delegate to write_header"); + + let mut flush_close_stream = write_stream_for_test(VarInt::from_u32(0)); + flush_close_stream + .flush() + .await + .expect("flush should succeed"); + flush_close_stream + .close() + .await + .expect("close should succeed"); + + let mut reset_stream = write_stream_for_test(VarInt::from_u32(0)); + reset_stream + .reset(Code::H3_NO_ERROR) + .await + .expect("reset should succeed"); + } + + #[tokio::test] + async fn test_read_stream_new() { + let stream_id = VarInt::from_u32(42); + let mut rs = read_stream_for_test(stream_id); + let got = rs.stream_id().await.expect("stream_id should resolve"); + assert_eq!(got, stream_id); + } + + #[tokio::test] + async fn test_read_stream_connection() { + let rs = read_stream_for_test(VarInt::from_u32(0)); + let conn = rs.connection(); + assert!( + std::sync::Arc::strong_count(conn) >= 1, + "connection should return a valid Arc" + ); + } + + #[tokio::test] + async fn test_read_stream_peer_goaway_covers() { + let mut rs = read_stream_for_test(VarInt::from_u32(0)); + let covers = rs + .peer_goaway_covers() + .await + .expect("peer_goaway_covers should resolve without error"); + drop(covers); + } + + #[tokio::test] + async fn read_data_frame_chunk_returns_result_option_stream_error() { + let mut stream = read_stream_for_test(VarInt::from_u32(0)); + + let result: Result, StreamError> = stream.read_data_frame_chunk().await; + + assert!(result.unwrap().is_none()); + } + + #[tokio::test] + async fn read_header_frame_returns_result_option_stream_error() { + let mut stream = read_stream_for_test(VarInt::from_u32(0)); + + let result: Result, StreamError> = + stream.read_header_frame().await; + + assert!(result.unwrap().is_none()); + } + + #[tokio::test] + async fn high_level_read_methods_return_message_stream_error() { + let mut stream = read_stream_for_test(VarInt::from_u32(0)); + + let data: Result, MessageStreamError> = stream.read_data_chunk().await; + let header: Result, MessageStreamError> = + stream.read_header().await; + + assert!(data.unwrap().is_none()); + assert!(header.unwrap().is_none()); + } + + #[tokio::test] + async fn write_data_frame_returns_message_stream_error() { + let mut stream = write_stream_for_test(VarInt::from_u32(0)); + + let result: Result<(), MessageStreamError> = + stream.write_data_frame(Bytes::from_static(b"hello")).await; + + assert!(result.is_ok()); + } + + #[tokio::test] + async fn write_data_returns_message_stream_error() { + let mut stream = write_stream_for_test(VarInt::from_u32(0)); + + let result: Result<(), MessageStreamError> = + stream.write_data(Bytes::from_static(b"hello")).await; + + assert!(result.is_ok()); + } + + struct OversizedBuf; + + impl Buf for OversizedBuf { + fn remaining(&self) -> usize { + (VarInt::MAX.into_inner() as usize) + 1 + } + + fn chunk(&self) -> &[u8] { + &[] + } + + fn advance(&mut self, _cnt: usize) { + unreachable!("oversized payload should fail before writing") + } + } + + #[tokio::test] + async fn write_data_oversized_payload_returns_data_frame_too_large() { + let mut stream = write_stream_for_test(VarInt::from_u32(0)); + + let result = stream.write_data(OversizedBuf).await; + + assert!(matches!( + result, + Err(MessageStreamError::DataFrameTooLarge { .. }) + )); + } + + #[tokio::test] + async fn send_data_alias_delegates_to_write_data() { + let mut stream = write_stream_for_test(VarInt::from_u32(0)); + + let result = stream.send_data(Bytes::from_static(b"hello")).await; + + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_read_stream_stop() { + let mut rs = read_stream_for_test(VarInt::from_u32(0)); + rs.stop(Code::H3_NO_ERROR) + .await + .expect("stop should succeed on mock"); + } +} + +#[cfg(test)] +pub mod test; diff --git a/src/dhttp/message/guard.rs b/src/dhttp/message/guard.rs new file mode 100644 index 0000000..5dda272 --- /dev/null +++ b/src/dhttp/message/guard.rs @@ -0,0 +1,1104 @@ +use std::{ + mem, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures::{Sink, Stream}; +use tracing::Instrument; + +use crate::{ + error::Code, + quic::{self, BoxQuicStreamReader, BoxQuicStreamWriter, ResetStreamExt, StopStreamExt}, + varint::VarInt, +}; + +fn stream_used_after_taken() -> ! { + panic!("guarded QUIC stream used after being taken, this is a bug") +} + +fn reader_closed_before_stream_id_observed() -> ! { + panic!("guarded QUIC reader closed before stream id was observed, this is a bug") +} + +fn writer_closed_before_stream_id_observed() -> ! { + panic!("guarded QUIC writer closed before stream id was observed, this is a bug") +} + +fn writer_used_after_closed() -> ! { + panic!("guarded QUIC writer used after send side closed, this is a bug") +} + +#[derive(Debug, Clone)] +pub(super) enum QuicWriterStateSnapshot { + Open, + Closed, + Reset { code: VarInt }, + ConnectionClosed { source: quic::ConnectionError }, + Taken, +} + +enum QuicReaderState { + Open { stream: BoxQuicStreamReader }, + Closed, + Reset { code: VarInt }, + ConnectionClosed { source: quic::ConnectionError }, + Taken, +} + +impl QuicReaderState { + fn open(stream: BoxQuicStreamReader) -> Self { + Self::Open { stream } + } + + fn take(&mut self) -> Self { + mem::replace(self, Self::Taken) + } +} + +enum QuicWriterState { + Open { stream: BoxQuicStreamWriter }, + Closed, + Reset { code: VarInt }, + ConnectionClosed { source: quic::ConnectionError }, + Taken, +} + +impl QuicWriterState { + fn open(stream: BoxQuicStreamWriter) -> Self { + Self::Open { stream } + } + + fn take(&mut self) -> Self { + mem::replace(self, Self::Taken) + } + + fn snapshot(&self) -> QuicWriterStateSnapshot { + match self { + Self::Open { .. } => QuicWriterStateSnapshot::Open, + Self::Closed => QuicWriterStateSnapshot::Closed, + Self::Reset { code } => QuicWriterStateSnapshot::Reset { code: *code }, + Self::ConnectionClosed { source } => QuicWriterStateSnapshot::ConnectionClosed { + source: source.clone(), + }, + Self::Taken => QuicWriterStateSnapshot::Taken, + } + } +} + +// --------------------------------------------------------------------------- +// GuardQuicReader +// --------------------------------------------------------------------------- + +/// A QUIC read stream wrapper that automatically stops the stream on drop +/// while the receive side is still open. +pub struct GuardQuicReader { + stream_id: Option, + state: QuicReaderState, +} + +impl GuardQuicReader { + pub fn new(inner: BoxQuicStreamReader) -> Self { + Self { + stream_id: None, + state: QuicReaderState::open(inner), + } + } + + pub(super) fn set_stream_id(&mut self, stream_id: VarInt) { + self.stream_id = Some(stream_id); + } + + /// Take the stream lifecycle state, leaving this guard unusable. + pub fn take(&mut self) -> Self { + Self { + stream_id: self.stream_id, + state: self.state.take(), + } + } + + /// Consume this guard and return the protected stream without running drop cleanup. + pub(super) fn into_inner(mut self) -> BoxQuicStreamReader { + match self.state.take() { + QuicReaderState::Open { stream } => stream, + QuicReaderState::Closed => { + panic!("closed guarded QUIC reader cannot be taken as an open stream") + } + QuicReaderState::Reset { .. } | QuicReaderState::ConnectionClosed { .. } => { + panic!("failed guarded QUIC reader cannot be taken as an open stream") + } + QuicReaderState::Taken => stream_used_after_taken(), + } + } + + pub(super) fn mark_reset(&mut self, code: VarInt) { + self.state = QuicReaderState::Reset { code }; + } + + pub(super) fn mark_connection_closed(&mut self, source: quic::ConnectionError) { + self.state = QuicReaderState::ConnectionClosed { source }; + } + + fn mark_quic_error(&mut self, error: &quic::StreamError) { + match error { + quic::StreamError::Connection { source } => { + self.mark_connection_closed(source.clone()); + } + quic::StreamError::Reset { code } => { + self.mark_reset(*code); + } + } + } +} + +impl Stream for GuardQuicReader { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let result = match &mut this.state { + QuicReaderState::Open { stream } => stream.as_mut().poll_next(cx), + QuicReaderState::Closed => return Poll::Ready(None), + QuicReaderState::Reset { code } => { + return Poll::Ready(Some(Err(quic::StreamError::Reset { code: *code }))); + } + QuicReaderState::ConnectionClosed { source } => { + return Poll::Ready(Some(Err(quic::StreamError::Connection { + source: source.clone(), + }))); + } + QuicReaderState::Taken => stream_used_after_taken(), + }; + match result { + Poll::Ready(None) => { + this.state = QuicReaderState::Closed; + Poll::Ready(None) + } + Poll::Ready(Some(Err(error))) => { + this.mark_quic_error(&error); + Poll::Ready(Some(Err(error))) + } + other => other, + } + } +} + +impl quic::StopStream for GuardQuicReader { + fn poll_stop( + self: Pin<&mut Self>, + cx: &mut Context, + code: VarInt, + ) -> Poll> { + let this = self.get_mut(); + let result = match &mut this.state { + QuicReaderState::Open { stream } => stream.as_mut().poll_stop(cx, code), + QuicReaderState::Closed => return Poll::Ready(Ok(())), + QuicReaderState::Reset { code } => { + return Poll::Ready(Err(quic::StreamError::Reset { code: *code })); + } + QuicReaderState::ConnectionClosed { source } => { + return Poll::Ready(Err(quic::StreamError::Connection { + source: source.clone(), + })); + } + QuicReaderState::Taken => stream_used_after_taken(), + }; + if let Poll::Ready(Err(error)) = result { + this.mark_quic_error(&error); + return Poll::Ready(Err(error)); + } + result + } +} + +impl quic::GetStreamId for GuardQuicReader { + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + let this = self.get_mut(); + let result = match &mut this.state { + QuicReaderState::Open { stream } => { + if let Some(stream_id) = this.stream_id { + return Poll::Ready(Ok(stream_id)); + } + stream.as_mut().poll_stream_id(cx) + } + QuicReaderState::Closed => match this.stream_id { + Some(stream_id) => return Poll::Ready(Ok(stream_id)), + None => reader_closed_before_stream_id_observed(), + }, + QuicReaderState::Reset { code } => { + return Poll::Ready(Err(quic::StreamError::Reset { code: *code })); + } + QuicReaderState::ConnectionClosed { source } => { + return Poll::Ready(Err(quic::StreamError::Connection { + source: source.clone(), + })); + } + QuicReaderState::Taken => stream_used_after_taken(), + }; + match result { + Poll::Ready(Ok(stream_id)) => Poll::Ready(Ok(stream_id)), + Poll::Ready(Err(error)) => { + this.mark_quic_error(&error); + Poll::Ready(Err(error)) + } + Poll::Pending => Poll::Pending, + } + } +} + +impl Drop for GuardQuicReader { + fn drop(&mut self) { + if let QuicReaderState::Open { mut stream } = self.state.take() { + // Inherent termination: the task owns the only remaining stream + // handle and exits once the committed STOP_SENDING operation + // resolves or the underlying stream reports failure. + tokio::spawn( + async move { + _ = stream.stop(Code::H3_NO_ERROR.into()).await; + } + .in_current_span(), + ); + } + } +} + +// --------------------------------------------------------------------------- +// GuardQuicWriter +// --------------------------------------------------------------------------- + +/// A QUIC write stream wrapper that automatically resets the stream on drop +/// while the send side is still open. +pub struct GuardQuicWriter { + state: QuicWriterState, +} + +impl GuardQuicWriter { + pub fn new(inner: BoxQuicStreamWriter) -> Self { + Self { + state: QuicWriterState::open(inner), + } + } + + pub(super) fn state_snapshot(&self) -> QuicWriterStateSnapshot { + self.state.snapshot() + } + + pub(super) fn mark_reset(&mut self, code: VarInt) { + self.state = QuicWriterState::Reset { code }; + } + + pub(super) fn mark_connection_closed(&mut self, source: quic::ConnectionError) { + self.state = QuicWriterState::ConnectionClosed { source }; + } + + fn mark_quic_error(&mut self, error: &quic::StreamError) { + match error { + quic::StreamError::Connection { source } => { + self.mark_connection_closed(source.clone()); + } + quic::StreamError::Reset { code } => { + self.mark_reset(*code); + } + } + } + + /// Take the stream lifecycle state, leaving this guard unusable. + pub fn take(&mut self) -> Self { + Self { + state: self.state.take(), + } + } +} + +impl Sink for GuardQuicWriter { + type Error = quic::StreamError; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let result = match &mut this.state { + QuicWriterState::Open { stream } => stream.as_mut().poll_ready(cx), + QuicWriterState::Closed => writer_used_after_closed(), + QuicWriterState::Reset { code } => { + return Poll::Ready(Err(quic::StreamError::Reset { code: *code })); + } + QuicWriterState::ConnectionClosed { source } => { + return Poll::Ready(Err(quic::StreamError::Connection { + source: source.clone(), + })); + } + QuicWriterState::Taken => stream_used_after_taken(), + }; + if let Poll::Ready(Err(error)) = result { + this.mark_quic_error(&error); + return Poll::Ready(Err(error)); + } + result + } + + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + let this = self.get_mut(); + let result = match &mut this.state { + QuicWriterState::Open { stream } => stream.as_mut().start_send(item), + QuicWriterState::Closed => writer_used_after_closed(), + QuicWriterState::Reset { code } => { + return Err(quic::StreamError::Reset { code: *code }); + } + QuicWriterState::ConnectionClosed { source } => { + return Err(quic::StreamError::Connection { + source: source.clone(), + }); + } + QuicWriterState::Taken => stream_used_after_taken(), + }; + if let Err(error) = &result { + this.mark_quic_error(error); + } + result + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let result = match &mut this.state { + QuicWriterState::Open { stream } => stream.as_mut().poll_flush(cx), + QuicWriterState::Closed => writer_used_after_closed(), + QuicWriterState::Reset { code } => { + return Poll::Ready(Err(quic::StreamError::Reset { code: *code })); + } + QuicWriterState::ConnectionClosed { source } => { + return Poll::Ready(Err(quic::StreamError::Connection { + source: source.clone(), + })); + } + QuicWriterState::Taken => stream_used_after_taken(), + }; + if let Poll::Ready(Err(error)) = result { + this.mark_quic_error(&error); + return Poll::Ready(Err(error)); + } + result + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let result = match &mut this.state { + QuicWriterState::Open { stream } => stream.as_mut().poll_close(cx), + QuicWriterState::Closed => writer_used_after_closed(), + QuicWriterState::Reset { code } => { + return Poll::Ready(Err(quic::StreamError::Reset { code: *code })); + } + QuicWriterState::ConnectionClosed { source } => { + return Poll::Ready(Err(quic::StreamError::Connection { + source: source.clone(), + })); + } + QuicWriterState::Taken => stream_used_after_taken(), + }; + match result { + Poll::Ready(Ok(())) => { + this.state = QuicWriterState::Closed; + Poll::Ready(Ok(())) + } + Poll::Ready(Err(error)) => { + this.mark_quic_error(&error); + Poll::Ready(Err(error)) + } + Poll::Pending => Poll::Pending, + } + } +} + +impl quic::ResetStream for GuardQuicWriter { + fn poll_reset( + self: Pin<&mut Self>, + cx: &mut Context, + code: VarInt, + ) -> Poll> { + let this = self.get_mut(); + let result = match &mut this.state { + QuicWriterState::Open { stream } => stream.as_mut().poll_reset(cx, code), + QuicWriterState::Closed => writer_used_after_closed(), + QuicWriterState::Reset { code } => { + return Poll::Ready(Err(quic::StreamError::Reset { code: *code })); + } + QuicWriterState::ConnectionClosed { source } => { + return Poll::Ready(Err(quic::StreamError::Connection { + source: source.clone(), + })); + } + QuicWriterState::Taken => stream_used_after_taken(), + }; + match result { + Poll::Ready(Ok(())) => { + this.state = QuicWriterState::Reset { code }; + Poll::Ready(Ok(())) + } + Poll::Ready(Err(error)) => { + this.mark_quic_error(&error); + Poll::Ready(Err(error)) + } + Poll::Pending => Poll::Pending, + } + } +} + +impl quic::GetStreamId for GuardQuicWriter { + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + let this = self.get_mut(); + let result = match &mut this.state { + QuicWriterState::Open { stream } => stream.as_mut().poll_stream_id(cx), + QuicWriterState::Closed => writer_closed_before_stream_id_observed(), + QuicWriterState::Reset { code } => { + return Poll::Ready(Err(quic::StreamError::Reset { code: *code })); + } + QuicWriterState::ConnectionClosed { source } => { + return Poll::Ready(Err(quic::StreamError::Connection { + source: source.clone(), + })); + } + QuicWriterState::Taken => stream_used_after_taken(), + }; + match result { + Poll::Ready(Ok(stream_id)) => Poll::Ready(Ok(stream_id)), + Poll::Ready(Err(error)) => { + this.mark_quic_error(&error); + Poll::Ready(Err(error)) + } + Poll::Pending => Poll::Pending, + } + } +} + +impl Drop for GuardQuicWriter { + fn drop(&mut self) { + if let QuicWriterState::Open { mut stream } = self.state.take() { + // Inherent termination: the task owns the only remaining stream + // handle and exits once the committed RESET_STREAM operation + // resolves or the underlying stream reports failure. + tokio::spawn( + async move { + _ = stream.reset(Code::H3_NO_ERROR.into()).await; + } + .in_current_span(), + ); + } + } +} + +#[cfg(test)] +mod tests { + use std::{ + collections::VecDeque, + panic::{AssertUnwindSafe, catch_unwind}, + sync::{Arc, Mutex}, + }; + + use futures::{SinkExt, StreamExt, future::poll_fn, task::noop_waker_ref}; + use tokio::{ + sync::mpsc, + time::{Duration, timeout}, + }; + + use super::*; + use crate::quic::{GetStreamId, GetStreamIdExt, ResetStream, StopStream}; + + fn stream_error(code: u32) -> quic::StreamError { + quic::StreamError::Reset { + code: VarInt::from_u32(code), + } + } + + fn no_op_cx() -> Context<'static> { + Context::from_waker(noop_waker_ref()) + } + + fn assert_guard_panic(action: impl FnOnce()) { + let panic = catch_unwind(AssertUnwindSafe(action)) + .expect_err("guard should panic after take or drop"); + let message = if let Some(message) = panic.downcast_ref::<&str>() { + *message + } else if let Some(message) = panic.downcast_ref::() { + message.as_str() + } else { + panic!("panic payload should be a string"); + }; + assert_eq!( + message, + "guarded QUIC stream used after being taken, this is a bug", + ); + } + + async fn assert_no_code(mut rx: mpsc::UnboundedReceiver) { + match timeout(Duration::from_millis(50), rx.recv()).await { + Err(_) | Ok(None) => {} + Ok(Some(code)) => panic!("unexpected stop/reset notification received: {code}"), + } + } + + type ReadPollResult = Option>; + type ReadPollQueue = Arc>>; + type StopResultQueue = Arc>>>; + type SentItems = Arc>>; + + #[derive(Clone)] + struct ReaderState { + stream_id: VarInt, + next_results: ReadPollQueue, + stop_results: StopResultQueue, + } + + struct TestReader { + state: ReaderState, + stop_tx: mpsc::UnboundedSender, + } + + impl Stream for TestReader { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let next = self + .state + .next_results + .lock() + .expect("reader next queue lock should not be poisoned") + .pop_front() + .unwrap_or(None); + Poll::Ready(next) + } + } + + impl quic::StopStream for TestReader { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.stop_tx + .send(code) + .expect("reader stop receiver should still be alive"); + let result = self + .state + .stop_results + .lock() + .expect("reader stop queue lock should not be poisoned") + .pop_front() + .unwrap_or(Ok(())); + Poll::Ready(result) + } + } + + impl quic::GetStreamId for TestReader { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.state.stream_id)) + } + } + + fn reader_guard( + stream_id: u32, + next_results: impl IntoIterator>>, + stop_results: impl IntoIterator>, + ) -> (GuardQuicReader, mpsc::UnboundedReceiver) { + let state = ReaderState { + stream_id: VarInt::from_u32(stream_id), + next_results: Arc::new(Mutex::new(next_results.into_iter().collect())), + stop_results: Arc::new(Mutex::new(stop_results.into_iter().collect())), + }; + let (stop_tx, stop_rx) = mpsc::unbounded_channel(); + ( + GuardQuicReader::new(Box::pin(TestReader { state, stop_tx })), + stop_rx, + ) + } + + #[derive(Clone)] + struct WriterState { + stream_id: VarInt, + sent_items: SentItems, + close_results: StopResultQueue, + reset_results: StopResultQueue, + } + + struct TestWriter { + state: WriterState, + reset_tx: mpsc::UnboundedSender, + } + + impl Sink for TestWriter { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.state + .sent_items + .lock() + .expect("writer sent-items lock should not be poisoned") + .push(item); + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let result = self + .state + .close_results + .lock() + .expect("writer close queue lock should not be poisoned") + .pop_front() + .unwrap_or(Ok(())); + Poll::Ready(result) + } + } + + impl quic::ResetStream for TestWriter { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.reset_tx + .send(code) + .expect("writer reset receiver should still be alive"); + let result = self + .state + .reset_results + .lock() + .expect("writer reset queue lock should not be poisoned") + .pop_front() + .unwrap_or(Ok(())); + Poll::Ready(result) + } + } + + impl quic::GetStreamId for TestWriter { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.state.stream_id)) + } + } + + fn writer_guard( + stream_id: u32, + close_results: impl IntoIterator>, + reset_results: impl IntoIterator>, + ) -> (GuardQuicWriter, SentItems, mpsc::UnboundedReceiver) { + let sent_items = Arc::new(Mutex::new(Vec::new())); + let state = WriterState { + stream_id: VarInt::from_u32(stream_id), + sent_items: sent_items.clone(), + close_results: Arc::new(Mutex::new(close_results.into_iter().collect())), + reset_results: Arc::new(Mutex::new(reset_results.into_iter().collect())), + }; + let (reset_tx, reset_rx) = mpsc::unbounded_channel(); + ( + GuardQuicWriter::new(Box::pin(TestWriter { state, reset_tx })), + sent_items, + reset_rx, + ) + } + + struct DropNotifyingReader { + stream_id: VarInt, + next_results: VecDeque>>, + stop_tx: mpsc::UnboundedSender, + drop_tx: Option>, + } + + impl Stream for DropNotifyingReader { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.get_mut().next_results.pop_front().unwrap_or(None)) + } + } + + impl quic::StopStream for DropNotifyingReader { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.get_mut() + .stop_tx + .send(code) + .expect("reader stop receiver should still be alive"); + Poll::Ready(Ok(())) + } + } + + impl quic::GetStreamId for DropNotifyingReader { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.get_mut().stream_id)) + } + } + + impl Drop for DropNotifyingReader { + fn drop(&mut self) { + if let Some(drop_tx) = self.drop_tx.take() { + _ = drop_tx.send(()); + } + } + } + + struct DropNotifyingWriter { + stream_id: VarInt, + close_results: VecDeque>, + reset_results: VecDeque>, + reset_tx: mpsc::UnboundedSender, + drop_tx: Option>, + } + + impl Sink for DropNotifyingWriter { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, _item: Bytes) -> Result<(), Self::Error> { + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(self.get_mut().close_results.pop_front().unwrap_or(Ok(()))) + } + } + + impl quic::ResetStream for DropNotifyingWriter { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + let this = self.get_mut(); + this.reset_tx + .send(code) + .expect("writer reset receiver should still be alive"); + Poll::Ready(this.reset_results.pop_front().unwrap_or(Ok(()))) + } + } + + impl quic::GetStreamId for DropNotifyingWriter { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.get_mut().stream_id)) + } + } + + impl Drop for DropNotifyingWriter { + fn drop(&mut self) { + if let Some(drop_tx) = self.drop_tx.take() { + _ = drop_tx.send(()); + } + } + } + + #[tokio::test] + async fn reader_take_moves_inner_stream_and_panics_on_original_use() { + let (mut guard, stop_rx) = reader_guard(7, [], [Ok(())]); + guard.set_stream_id(VarInt::from_u32(7)); + let mut taken = GuardQuicReader::take(&mut guard); + + assert_eq!( + taken + .stream_id() + .await + .expect("taken reader should expose stream id"), + VarInt::from_u32(7), + ); + + let mut cx = no_op_cx(); + assert_guard_panic(|| { + let _ = Pin::new(&mut guard).poll_next(&mut cx); + }); + assert_guard_panic(|| { + let _ = Pin::new(&mut guard).poll_stop(&mut cx, VarInt::from_u32(1)); + }); + assert_guard_panic(|| { + let _ = Pin::new(&mut guard).poll_stream_id(&mut cx); + }); + + drop(guard); + drop(taken); + + assert_eq!( + timeout(Duration::from_secs(1), async move { + let mut stop_rx = stop_rx; + stop_rx.recv().await + }) + .await + .expect("reader drop cleanup should run") + .expect("reader stop code should be sent"), + VarInt::from(Code::H3_NO_ERROR), + ); + } + + #[tokio::test] + async fn reader_eof_marks_completed_and_drop_skips_stop() { + let (mut guard, stop_rx) = reader_guard(11, [None], []); + + assert!(guard.next().await.is_none()); + + drop(guard); + assert_no_code(stop_rx).await; + } + + #[tokio::test] + async fn reader_eof_releases_inner_stream_immediately() { + let (stop_tx, _stop_rx) = mpsc::unbounded_channel(); + let (drop_tx, mut drop_rx) = mpsc::unbounded_channel(); + let mut guard = GuardQuicReader::new(Box::pin(DropNotifyingReader { + stream_id: VarInt::from_u32(111), + next_results: [None].into_iter().collect(), + stop_tx, + drop_tx: Some(drop_tx), + })); + + assert!(guard.next().await.is_none()); + + timeout(Duration::from_secs(1), drop_rx.recv()) + .await + .expect("inner reader should be released on EOF") + .expect("inner reader drop should be observed"); + } + + #[tokio::test] + async fn reader_stop_ok_keeps_receive_side_open_for_drop_cleanup() { + let (mut guard, mut stop_rx) = reader_guard(12, [], [Ok(())]); + + poll_fn(|cx| Pin::new(&mut guard).poll_stop(cx, VarInt::from_u32(33))) + .await + .expect("stop should succeed"); + assert_eq!( + timeout(Duration::from_secs(1), stop_rx.recv()) + .await + .expect("stop call should notify") + .expect("stop code should be present"), + VarInt::from_u32(33), + ); + + drop(guard); + assert_eq!( + timeout(Duration::from_secs(1), stop_rx.recv()) + .await + .expect("drop cleanup should stop the still-open receive side") + .expect("drop stop code should be present"), + VarInt::from(Code::H3_NO_ERROR), + ); + } + + #[tokio::test] + async fn reader_stop_error_transitions_to_reset_and_skips_drop_cleanup() { + let (mut guard, mut stop_rx) = reader_guard(13, [], [Err(stream_error(44)), Ok(())]); + + let error = poll_fn(|cx| Pin::new(&mut guard).poll_stop(cx, VarInt::from_u32(55))) + .await + .expect_err("explicit stop should fail"); + assert!(matches!( + error, + quic::StreamError::Reset { code } if code == VarInt::from_u32(44) + )); + assert_eq!( + timeout(Duration::from_secs(1), stop_rx.recv()) + .await + .expect("failing stop should notify") + .expect("stop code should be present"), + VarInt::from_u32(55), + ); + + drop(guard); + assert_no_code(stop_rx).await; + } + + #[tokio::test] + async fn writer_take_moves_inner_stream_and_panics_on_original_use() { + let (mut guard, sent_items, reset_rx) = writer_guard(8, [], [Ok(())]); + let mut taken = GuardQuicWriter::take(&mut guard); + + assert_eq!( + taken + .stream_id() + .await + .expect("taken writer should expose stream id"), + VarInt::from_u32(8), + ); + taken + .send(Bytes::from_static(b"hello")) + .await + .expect("taken writer should accept sends"); + assert_eq!( + sent_items + .lock() + .expect("sent-items lock should not be poisoned") + .as_slice(), + &[Bytes::from_static(b"hello")], + ); + + let mut cx = no_op_cx(); + assert_guard_panic(|| { + let _ = Pin::new(&mut guard).poll_ready(&mut cx); + }); + assert_guard_panic(|| { + Pin::new(&mut guard) + .start_send(Bytes::from_static(b"panic")) + .expect("start_send should panic before returning"); + }); + assert_guard_panic(|| { + let _ = Pin::new(&mut guard).poll_flush(&mut cx); + }); + assert_guard_panic(|| { + let _ = Pin::new(&mut guard).poll_close(&mut cx); + }); + assert_guard_panic(|| { + let _ = Pin::new(&mut guard).poll_reset(&mut cx, VarInt::from_u32(2)); + }); + + drop(guard); + drop(taken); + + assert_eq!( + timeout(Duration::from_secs(1), async move { + let mut reset_rx = reset_rx; + reset_rx.recv().await + }) + .await + .expect("writer drop cleanup should run") + .expect("writer reset code should be sent"), + VarInt::from(Code::H3_NO_ERROR), + ); + } + + #[tokio::test] + async fn writer_close_ok_marks_completed_and_drop_skips_reset() { + let (mut guard, _, reset_rx) = writer_guard(21, [Ok(())], []); + + guard.close().await.expect("close should succeed"); + + drop(guard); + assert_no_code(reset_rx).await; + } + + #[tokio::test] + async fn writer_close_releases_inner_stream_immediately() { + let (reset_tx, _reset_rx) = mpsc::unbounded_channel(); + let (drop_tx, mut drop_rx) = mpsc::unbounded_channel(); + let mut guard = GuardQuicWriter::new(Box::pin(DropNotifyingWriter { + stream_id: VarInt::from_u32(121), + close_results: [Ok(())].into_iter().collect(), + reset_results: [].into_iter().collect(), + reset_tx, + drop_tx: Some(drop_tx), + })); + + guard.close().await.expect("close should succeed"); + + timeout(Duration::from_secs(1), drop_rx.recv()) + .await + .expect("inner writer should be released on close") + .expect("inner writer drop should be observed"); + } + + #[tokio::test] + async fn writer_reset_ok_marks_completed_and_drop_skips_second_reset() { + let (mut guard, _, mut reset_rx) = writer_guard(22, [], [Ok(())]); + + poll_fn(|cx| Pin::new(&mut guard).poll_reset(cx, VarInt::from_u32(66))) + .await + .expect("reset should succeed"); + assert_eq!( + timeout(Duration::from_secs(1), reset_rx.recv()) + .await + .expect("reset call should notify") + .expect("reset code should be present"), + VarInt::from_u32(66), + ); + + drop(guard); + assert_no_code(reset_rx).await; + } + + #[tokio::test] + async fn writer_reset_releases_inner_stream_immediately() { + let (reset_tx, mut reset_rx) = mpsc::unbounded_channel(); + let (drop_tx, mut drop_rx) = mpsc::unbounded_channel(); + let mut guard = GuardQuicWriter::new(Box::pin(DropNotifyingWriter { + stream_id: VarInt::from_u32(122), + close_results: [].into_iter().collect(), + reset_results: [Ok(())].into_iter().collect(), + reset_tx, + drop_tx: Some(drop_tx), + })); + + poll_fn(|cx| Pin::new(&mut guard).poll_reset(cx, VarInt::from_u32(77))) + .await + .expect("reset should succeed"); + assert_eq!( + timeout(Duration::from_secs(1), reset_rx.recv()) + .await + .expect("reset call should notify") + .expect("reset code should be present"), + VarInt::from_u32(77), + ); + timeout(Duration::from_secs(1), drop_rx.recv()) + .await + .expect("inner writer should be released on reset") + .expect("inner writer drop should be observed"); + } + + #[tokio::test] + async fn writer_close_error_transitions_to_reset_and_skips_drop_cleanup() { + let (mut guard, _, reset_rx) = writer_guard(23, [Err(stream_error(77))], [Ok(())]); + + let error = guard.close().await.expect_err("close should fail"); + assert!(matches!( + error, + quic::StreamError::Reset { code } if code == VarInt::from_u32(77) + )); + + drop(guard); + assert_no_code(reset_rx).await; + } +} diff --git a/src/dhttp/message/hyper.rs b/src/dhttp/message/hyper.rs new file mode 100644 index 0000000..f57f7e9 --- /dev/null +++ b/src/dhttp/message/hyper.rs @@ -0,0 +1,9 @@ +use super::{MessageReader, MessageStreamError, MessageWriter}; + +pub mod client; +pub mod read; +pub mod upgrade; +pub mod write; + +pub use client::RequestError; +pub use write::SendMessageError; diff --git a/src/dhttp/message/hyper/client.rs b/src/dhttp/message/hyper/client.rs new file mode 100644 index 0000000..463d597 --- /dev/null +++ b/src/dhttp/message/hyper/client.rs @@ -0,0 +1,996 @@ +use std::{error::Error, sync::Arc}; + +use bytes::Bytes; +use http_body::{Body, Frame, SizeHint}; +use http_body_util::{BodyExt, Empty}; +use snafu::{ResultExt, Snafu}; +use tokio_util::task::AbortOnDropHandle; +use tracing::Instrument; + +use crate::{ + connection::Connection, + dhttp::message::{ + InitialMessageStreamError, MessageStreamError, + hyper::{ + SendMessageError, + read::Either, + upgrade::{RemainStream, TakeoverSlot}, + write::send_message_error, + }, + }, + qpack::field::{Protocol, hyper::validated_hyper_request_parts_to_field_lines}, + quic::{self, GetStreamIdExt}, + stream_id::StreamId, +}; + +fn protocol_from_extensions(extensions: &http::Extensions) -> Option { + if let Some(protocol) = extensions.get::() { + return Some(protocol.clone()); + } + + extensions.get::<::hyper::ext::Protocol>().map(|protocol| { + Protocol::try_from(Bytes::copy_from_slice(protocol.as_ref())) + .expect("hyper protocol token is valid UTF-8") + }) +} + +pin_project_lite::pin_project! { + struct AbortBodySenderOnDrop { + _body_sender: AbortOnDropHandle<()>, + #[pin] + body: B, + } +} + +impl AbortBodySenderOnDrop { + fn new(body: B, body_sender: AbortOnDropHandle<()>) -> Self { + Self { + _body_sender: body_sender, + body, + } + } +} + +impl Body for AbortBodySenderOnDrop { + type Data = B::Data; + type Error = B::Error; + + fn poll_frame( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll, Self::Error>>> { + self.project().body.poll_frame(cx) + } + + fn is_end_stream(&self) -> bool { + self.body.is_end_stream() + } + + fn size_hint(&self) -> SizeHint { + self.body.size_hint() + } +} + +#[derive(Debug, Snafu)] +#[snafu(module, visibility(pub))] +pub enum RequestError { + #[snafu(display("failed to open initial message stream"))] + InitialStream { source: InitialMessageStreamError }, + #[snafu(display("failed to send request"))] + SendRequest { source: SendMessageError }, + #[snafu(display("failed to receive response"))] + ReceiveResponse { source: MessageStreamError }, + #[snafu(display("failed to read request stream ID"))] + StreamId { source: quic::StreamError }, +} + +impl From for RequestError { + #[track_caller] + fn from(source: InitialMessageStreamError) -> Self { + RequestError::InitialStream { source } + } +} + +impl From> for RequestError { + #[track_caller] + fn from(source: SendMessageError) -> Self { + RequestError::SendRequest { source } + } +} + +impl From for RequestError { + #[track_caller] + fn from(source: MessageStreamError) -> Self { + RequestError::ReceiveResponse { source } + } +} + +impl Connection { + /// Execute a hyper-compatible HTTP request over this H3 connection. + /// + /// The response body type is intentionally opaque and does not guarantee + /// [`Unpin`]. This keeps the native streaming body shape visible to the + /// compiler instead of imposing a boxing policy on every caller. Callers + /// that need an [`Unpin`] body should choose the pinning or boxing strategy + /// appropriate for that call site, such as `pin!`, `Pin>`, or an + /// `http_body_util` boxed body. + #[tracing::instrument(level = "debug", skip_all, fields(method = %request.method(), uri = %request.uri()))] + pub async fn execute_hyper_request( + &self, + request: http::Request, + ) -> Result< + http::Response + use>, + RequestError, + > + where + B::Data: Send, + B::Error: Error + Send + 'static, + { + let is_connect = request.method() == http::Method::CONNECT; + let protocol = if is_connect { + protocol_from_extensions(request.extensions()) + } else { + None + }; + let (parts, body) = request.into_parts(); + let fields = validated_hyper_request_parts_to_field_lines(parts) + .context(send_message_error::MalformedHeaderSnafu)?; + let (mut read_stream, mut write_stream) = self.initial_message_stream().await?; + + if is_connect { + // CONNECT: no body or trailers — join send + receive headers. + let stream_id = write_stream + .stream_id() + .await + .context(request_error::StreamIdSnafu)?; + let (send_result, recv_result) = tokio::join!( + async { + write_stream + .write_header(fields) + .await + .context(send_message_error::StreamSnafu) + }, + async { + loop { + let parts = read_stream.read_hyper_response_parts().await?; + if !parts.status.is_informational() { + return Ok::<_, MessageStreamError>(parts); + } + tracing::debug!( + status = %parts.status, + headers = ?parts.headers, + "skipping informational response", + ); + } + }, + ); + send_result?; + let mut response_parts = recv_result?; + + response_parts.extensions.insert(StreamId::from(stream_id)); + response_parts.extensions.insert(Arc::new(self.erase())); + if let Some(protocol) = protocol { + response_parts.extensions.insert(protocol); + } + response_parts + .extensions + .insert(TakeoverSlot::new(RemainStream::immediately(read_stream))); + response_parts + .extensions + .insert(TakeoverSlot::new(RemainStream::immediately(write_stream))); + let body = Either::right(Empty::new().map_err(|never| match never {})); + Ok(http::Response::from_parts(response_parts, body)) + } else { + // Non-CONNECT: send headers, spawn body sender, read response. + write_stream + .write_header(fields) + .await + .context(send_message_error::StreamSnafu)?; + + // Response body owns this task: dropping the response aborts an unfinished + // request body sender instead of leaving it detached. + let body_sender = AbortOnDropHandle::new(tokio::spawn( + async move { + if write_stream.send_hyper_body(body).await.is_ok() { + _ = write_stream.close().await; + } + } + .in_current_span(), + )); + + // Read response headers, skipping informational. + let mut response_parts = read_stream.read_hyper_response_parts().await?; + while response_parts.status.is_informational() { + tracing::debug!( + status = %response_parts.status, + headers = ?response_parts.headers, + "skipping informational response", + ); + response_parts = read_stream.read_hyper_response_parts().await?; + } + + let body = Either::left(AbortBodySenderOnDrop::new( + read_stream.into_hyper_body(), + body_sender, + )); + Ok(http::Response::from_parts(response_parts, body)) + } + } +} + +#[cfg(test)] +mod tests { + use std::{ + collections::VecDeque, + convert::Infallible, + fmt, + future::pending, + pin::Pin, + sync::{ + Arc, Mutex, + atomic::{AtomicU32, Ordering}, + }, + task::{Context, Poll}, + }; + + use futures::{Sink, SinkExt, Stream}; + use http::Extensions; + use http_body::Frame; + use http_body_util::{BodyExt, Full}; + use tokio::{ + sync::oneshot, + time::{Duration, timeout}, + }; + + use super::*; + use crate::{ + codec::{SinkWriter, StreamReader}, + connection::{ConnectionBuilder, ConnectionState, StreamError}, + dhttp::{ + message::{MessageReader, MessageWriter, guard}, + protocol::DHttpProtocol, + settings::Settings, + }, + protocol::Protocols, + qpack::{ + decoder::DecoderInstruction, + encoder::EncoderInstruction, + field::MalformedHeaderSection, + protocol::{QPackDecoder, QPackEncoder}, + }, + varint::VarInt, + }; + + struct ClientReadStream { + stream_id: Result, + inner: quic::BoxQuicStreamReader, + } + + impl ClientReadStream { + fn new( + stream_id: Result, + inner: quic::BoxQuicStreamReader, + ) -> Self { + Self { stream_id, inner } + } + } + + impl Unpin for ClientReadStream {} + + impl quic::GetStreamId for ClientReadStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(self.get_mut().stream_id.clone()) + } + } + + impl quic::StopStream for ClientReadStream { + fn poll_stop( + self: Pin<&mut Self>, + cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.get_mut().inner.as_mut().poll_stop(cx, code) + } + } + + impl Stream for ClientReadStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().inner.as_mut().poll_next(cx) + } + } + + struct ClientWriteStream { + stream_id: Result, + inner: quic::BoxQuicStreamWriter, + } + + impl ClientWriteStream { + fn new( + stream_id: Result, + inner: quic::BoxQuicStreamWriter, + ) -> Self { + Self { stream_id, inner } + } + } + + impl Unpin for ClientWriteStream {} + + impl quic::GetStreamId for ClientWriteStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(self.get_mut().stream_id.clone()) + } + } + + impl quic::ResetStream for ClientWriteStream { + fn poll_reset( + self: Pin<&mut Self>, + cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.get_mut().inner.as_mut().poll_reset(cx, code) + } + } + + impl Sink for ClientWriteStream { + type Error = quic::StreamError; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().inner.as_mut().poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.get_mut().inner.as_mut().start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().inner.as_mut().poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().inner.as_mut().poll_close(cx) + } + } + + #[derive(Default)] + struct QueuedConnection { + bi_streams: Mutex>, + next_uni_stream_id: AtomicU32, + uni_readers: Mutex>, + } + + impl fmt::Debug for QueuedConnection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("QueuedConnection").finish_non_exhaustive() + } + } + + impl QueuedConnection { + fn stage_bi_stream(&self, reader: ClientReadStream, writer: ClientWriteStream) { + self.bi_streams + .lock() + .expect("bi stream queue mutex should not be poisoned") + .push_back((reader, writer)); + } + } + + impl quic::ManageStream for QueuedConnection { + type StreamReader = ClientReadStream; + type StreamWriter = ClientWriteStream; + + async fn open_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + self.bi_streams + .lock() + .expect("bi stream queue mutex should not be poisoned") + .pop_front() + .ok_or_else(|| test_connection_error("no staged bidirectional stream")) + } + + async fn open_uni(&self) -> Result { + let stream_id = + VarInt::from_u32(100 + self.next_uni_stream_id.fetch_add(4, Ordering::Relaxed)); + let (reader, writer) = quic::test::mock_stream_pair_with_capacity(stream_id, 256); + self.uni_readers + .lock() + .expect("uni stream reader list mutex should not be poisoned") + .push(Box::pin(reader)); + Ok(ClientWriteStream::new(Ok(stream_id), Box::pin(writer))) + } + + async fn accept_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + pending().await + } + + async fn accept_uni(&self) -> Result { + pending().await + } + } + + impl quic::WithLocalAuthority for QueuedConnection { + type LocalAuthority = crate::connection::tests::TestLocalAuthority; + + async fn local_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(None) + } + } + + impl quic::WithRemoteAuthority for QueuedConnection { + type RemoteAuthority = crate::connection::tests::TestRemoteAuthority; + + async fn remote_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(None) + } + } + + impl quic::Lifecycle for QueuedConnection { + fn close(&self, _code: crate::error::Code, _reason: std::borrow::Cow<'static, str>) {} + + fn check(&self) -> Result<(), quic::ConnectionError> { + Ok(()) + } + + async fn closed(&self) -> quic::ConnectionError { + pending().await + } + } + + struct PendingRequestBody { + dropped: Option>, + } + + impl Drop for PendingRequestBody { + fn drop(&mut self) { + if let Some(dropped) = self.dropped.take() { + _ = dropped.send(()); + } + } + } + + impl Body for PendingRequestBody { + type Data = Bytes; + type Error = Infallible; + + fn poll_frame( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + Poll::Pending + } + } + + fn test_connection_error(reason: &'static str) -> quic::ConnectionError { + quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(0x01), + frame_type: VarInt::from_u32(0x00), + reason: reason.into(), + }, + } + } + + fn qpack_decoder_sink() -> Pin + Send>> { + Box::pin(futures::sink::drain::().sink_map_err(|never| match never {})) + } + + fn qpack_decoder_stream() + -> Pin> + Send>> { + Box::pin(futures::stream::empty::< + Result, + >()) + } + + fn qpack_encoder_sink() -> Pin + Send>> { + Box::pin(futures::sink::drain::().sink_map_err(|never| match never {})) + } + + fn qpack_encoder_stream() + -> Pin> + Send>> { + Box::pin(futures::stream::empty::< + Result, + >()) + } + + fn server_connection_state() -> ConnectionState { + let erased: Arc = + Arc::new(crate::connection::tests::MockConnection::new()); + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased.clone())); + ConnectionState::new_for_test(erased, Arc::new(protocols)) + } + + fn server_streams( + stream_id: VarInt, + reader: impl quic::ReadStream + Unpin + 'static, + writer: impl quic::WriteStream + Unpin + 'static, + ) -> (MessageReader, MessageWriter) { + let state = server_connection_state(); + let reader = StreamReader::new(guard::GuardQuicReader::new( + Box::pin(reader) as crate::quic::BoxQuicStreamReader + )); + let writer = SinkWriter::new(guard::GuardQuicWriter::new( + Box::pin(writer) as crate::quic::BoxQuicStreamWriter + )); + + ( + MessageReader::new( + stream_id, + reader, + Arc::new(QPackDecoder::new( + Arc::new(Settings::default()), + qpack_decoder_sink(), + qpack_decoder_stream(), + )), + state.clone(), + ), + MessageWriter::new( + writer, + Arc::new(QPackEncoder::new( + Arc::new(Settings::default()), + qpack_encoder_sink(), + qpack_encoder_stream(), + )), + state, + ), + ) + } + + async fn connection_with_staged_bi_stream( + stream_id: VarInt, + write_stream_id: Result, + ) -> (Connection, MessageReader, MessageWriter) { + let quic = Arc::new(QueuedConnection::default()); + let (server_reader, client_writer) = + quic::test::mock_stream_pair_with_capacity(stream_id, 256); + let (client_reader, server_writer) = + quic::test::mock_stream_pair_with_capacity(stream_id, 256); + + quic.stage_bi_stream( + ClientReadStream::new(Ok(stream_id), Box::pin(client_reader)), + ClientWriteStream::new(write_stream_id, Box::pin(client_writer)), + ); + + let connection = ConnectionBuilder::new(Arc::new(Settings::default())) + .build(quic) + .await + .expect("connection should build"); + let (server_reader, server_writer) = + server_streams(stream_id, server_reader, server_writer); + + (connection, server_reader, server_writer) + } + + async fn connection_without_staged_stream() -> Connection { + ConnectionBuilder::new(Arc::new(Settings::default())) + .build(Arc::new(QueuedConnection::default())) + .await + .expect("connection should build") + } + + fn response_parts(status: http::StatusCode) -> http::response::Parts { + http::Response::builder() + .status(status) + .header("x-response", "present") + .body(()) + .expect("response should be valid") + .into_parts() + .0 + } + + #[test] + fn protocol_from_extensions_captures_h3x_protocol() { + let mut extensions = Extensions::new(); + extensions.insert(Protocol::new("webtransport")); + + let protocol = protocol_from_extensions(&extensions).expect("protocol is captured"); + + assert_eq!(protocol.as_str(), "webtransport"); + } + + #[test] + fn protocol_from_extensions_captures_hyper_protocol() { + let mut extensions = Extensions::new(); + extensions.insert(::hyper::ext::Protocol::from_static("websocket")); + + let protocol = protocol_from_extensions(&extensions).expect("protocol is captured"); + + assert_eq!(protocol.as_str(), "websocket"); + } + + #[test] + fn protocol_from_extensions_prefers_h3x_protocol() { + let mut extensions = Extensions::new(); + extensions.insert(::hyper::ext::Protocol::from_static("websocket")); + extensions.insert(Protocol::new("webtransport")); + + let protocol = protocol_from_extensions(&extensions).expect("protocol is captured"); + + assert_eq!(protocol.as_str(), "webtransport"); + } + + #[test] + fn protocol_from_extensions_returns_none_without_protocol() { + let extensions = Extensions::new(); + + assert!(protocol_from_extensions(&extensions).is_none()); + } + + #[tokio::test] + async fn execute_non_connect_sends_request_and_skips_informational_response() { + let stream_id = VarInt::from_u32(0); + let (connection, mut server_reader, mut server_writer) = + connection_with_staged_bi_stream(stream_id, Ok(stream_id)).await; + let request = http::Request::builder() + .method(http::Method::POST) + .uri("https://example.test/upload") + .header("x-request", "present") + .body(Full::new(Bytes::from_static(b"request body"))) + .expect("request should be valid"); + + let client = connection.execute_hyper_request(request); + let server = async move { + let request_parts = server_reader + .read_hyper_request_parts() + .await + .expect("request headers should be readable"); + assert_eq!(request_parts.method, http::Method::POST); + assert_eq!(request_parts.uri, "https://example.test/upload"); + assert_eq!(request_parts.headers["x-request"], "present"); + + let request_body = server_reader + .into_hyper_body() + .collect() + .await + .expect("request body should be readable") + .to_bytes(); + assert_eq!(request_body, Bytes::from_static(b"request body")); + + server_writer + .send_hyper_response_parts(response_parts(http::StatusCode::EARLY_HINTS)) + .await + .expect("informational response should be written"); + server_writer + .send_hyper_response( + http::Response::builder() + .status(http::StatusCode::OK) + .header("x-final", "present") + .body(Full::new(Bytes::from_static(b"response body"))) + .expect("response should be valid"), + ) + .await + .expect("final response should be written"); + server_writer + .close() + .await + .expect("response stream should close"); + }; + + let (response, ()) = tokio::join!(client, server); + let response = response.expect("client response should succeed"); + + assert_eq!(response.status(), http::StatusCode::OK); + assert_eq!(response.headers()["x-final"], "present"); + let body = response + .into_body() + .collect() + .await + .expect("response body should be readable") + .to_bytes(); + assert_eq!(body, Bytes::from_static(b"response body")); + } + + #[tokio::test] + async fn dropping_non_connect_response_aborts_pending_request_body_sender() { + let stream_id = VarInt::from_u32(24); + let (connection, mut server_reader, mut server_writer) = + connection_with_staged_bi_stream(stream_id, Ok(stream_id)).await; + let (drop_tx, drop_rx) = oneshot::channel(); + let request = http::Request::builder() + .method(http::Method::POST) + .uri("https://example.test/upload") + .body(PendingRequestBody { + dropped: Some(drop_tx), + }) + .expect("request should be valid"); + + let client = connection.execute_hyper_request(request); + let server = async move { + let request_parts = server_reader + .read_hyper_request_parts() + .await + .expect("request headers should be readable"); + assert_eq!(request_parts.method, http::Method::POST); + + server_writer + .send_hyper_response_parts(response_parts(http::StatusCode::OK)) + .await + .expect("response should be written"); + server_writer + .close() + .await + .expect("response stream should close"); + }; + + let (response, ()) = tokio::join!(client, server); + let response = response.expect("client response should succeed"); + + drop(response); + + timeout(Duration::from_secs(1), drop_rx) + .await + .expect("request body sender should be aborted when response is dropped") + .expect("request body drop notification should be delivered"); + } + + #[tokio::test] + async fn execute_connect_returns_upgrade_extensions_and_empty_body() { + let stream_id = VarInt::from_u32(4); + let (connection, mut server_reader, mut server_writer) = + connection_with_staged_bi_stream(stream_id, Ok(stream_id)).await; + let request = http::Request::builder() + .method(http::Method::CONNECT) + .uri("example.test:443") + .body(Full::new(Bytes::new())) + .expect("request should be valid"); + + let client = connection.execute_hyper_request(request); + let server = async move { + let request_parts = server_reader + .read_hyper_request_parts() + .await + .expect("connect headers should be readable"); + assert_eq!(request_parts.method, http::Method::CONNECT); + assert_eq!(request_parts.uri, "example.test:443"); + + server_writer + .send_hyper_response_parts(response_parts(http::StatusCode::EARLY_HINTS)) + .await + .expect("informational response should be written"); + server_writer + .send_hyper_response_parts(response_parts(http::StatusCode::OK)) + .await + .expect("connect response should be written"); + server_writer + .close() + .await + .expect("connect response stream should close"); + }; + + let (response, ()) = tokio::join!(client, server); + let response = response.expect("connect response should succeed"); + + assert_eq!(response.status(), http::StatusCode::OK); + assert_eq!( + response.extensions().get::().copied(), + Some(StreamId(stream_id)) + ); + assert!( + response + .extensions() + .get::>>() + .is_some() + ); + assert!(response.extensions().get::().is_none()); + assert!( + response + .extensions() + .get::>() + .is_some() + ); + assert!( + response + .extensions() + .get::>() + .is_some() + ); + let body = response + .into_body() + .collect() + .await + .expect("connect body should be readable") + .to_bytes(); + assert!(body.is_empty()); + } + + #[tokio::test] + async fn execute_connect_preserves_protocol_extension() { + let stream_id = VarInt::from_u32(12); + let (connection, mut server_reader, mut server_writer) = + connection_with_staged_bi_stream(stream_id, Ok(stream_id)).await; + let mut request = http::Request::builder() + .method(http::Method::CONNECT) + .uri("https://example.test/connect") + .body(Full::new(Bytes::new())) + .expect("request should be valid"); + request + .extensions_mut() + .insert(::hyper::ext::Protocol::from_static("webtransport")); + + let client = connection.execute_hyper_request(request); + let server = async move { + let request_parts = server_reader + .read_hyper_request_parts() + .await + .expect("connect headers should be readable"); + assert_eq!(request_parts.method, http::Method::CONNECT); + + server_writer + .send_hyper_response_parts(response_parts(http::StatusCode::OK)) + .await + .expect("connect response should be written"); + server_writer + .close() + .await + .expect("connect response stream should close"); + }; + + let (response, ()) = tokio::join!(client, server); + let response = response.expect("connect response should succeed"); + + let protocol = response + .extensions() + .get::() + .expect("protocol should be preserved"); + assert_eq!(protocol.as_str(), "webtransport"); + } + + #[tokio::test] + async fn execute_request_reports_initial_stream_errors() { + let connection = connection_without_staged_stream().await; + let request = http::Request::builder() + .method(http::Method::GET) + .uri("https://example.test/") + .body(Full::new(Bytes::new())) + .expect("request should be valid"); + + let error = match connection.execute_hyper_request(request).await { + Ok(_) => panic!("missing staged stream should be reported"), + Err(error) => error, + }; + + assert!(matches!(error, RequestError::InitialStream { .. })); + } + + #[tokio::test] + async fn execute_request_reports_response_read_errors() { + let stream_id = VarInt::from_u32(16); + let (connection, mut server_reader, mut server_writer) = + connection_with_staged_bi_stream(stream_id, Ok(stream_id)).await; + let request = http::Request::builder() + .method(http::Method::GET) + .uri("https://example.test/") + .body(Full::new(Bytes::new())) + .expect("request should be valid"); + + let client = connection.execute_hyper_request(request); + let server = async move { + let request_parts = server_reader + .read_hyper_request_parts() + .await + .expect("request headers should be readable"); + assert_eq!(request_parts.method, http::Method::GET); + + server_writer + .close() + .await + .expect("response stream should close without response"); + }; + + let (error, ()) = tokio::join!( + async { + match client.await { + Ok(_) => panic!("missing response should be reported"), + Err(error) => error, + } + }, + server, + ); + + assert!(matches!(error, RequestError::ReceiveResponse { .. })); + } + + #[tokio::test] + async fn execute_request_reports_malformed_header_as_send_request_error() { + let connection = connection_without_staged_stream().await; + let mut request = http::Request::builder() + .method(http::Method::GET) + .uri("https://example.test/") + .body(Full::new(Bytes::new())) + .expect("request should be valid"); + request + .extensions_mut() + .insert(Protocol::new("webtransport")); + + let error = match connection.execute_hyper_request(request).await { + Ok(_) => panic!("protocol on non-CONNECT request should be rejected"), + Err(error) => error, + }; + + assert!(matches!( + error, + RequestError::SendRequest { + source: SendMessageError::MalformedHeader { + source: MalformedHeaderSection::ProtocolInNonConnectRequest + } + } + )); + } + + #[tokio::test] + async fn execute_connect_reports_stream_id_errors() { + let stream_id = VarInt::from_u32(8); + let stream_id_error = quic::StreamError::Reset { + code: VarInt::from_u32(0x10), + }; + let (connection, _server_reader, _server_writer) = + connection_with_staged_bi_stream(stream_id, Err(stream_id_error)).await; + let request = http::Request::builder() + .method(http::Method::CONNECT) + .uri("example.test:443") + .body(Full::new(Bytes::new())) + .expect("request should be valid"); + + let error = match connection.execute_hyper_request(request).await { + Ok(_) => panic!("stream id failure should be reported"), + Err(error) => error, + }; + + assert!(matches!( + error, + RequestError::StreamId { + source: quic::StreamError::Reset { code } + } if code == VarInt::from_u32(0x10) + )); + } + + #[test] + fn request_error_from_send_message_error_preserves_variant() { + let source = SendMessageError::::MalformedHeader { + source: MalformedHeaderSection::ProtocolInNonConnectRequest, + }; + + let error: RequestError = source.into(); + + assert!(matches!(error, RequestError::SendRequest { .. })); + } + + #[test] + fn request_error_from_initial_stream_error_preserves_variant() { + let source = InitialMessageStreamError::InitialRawStream { + source: crate::dhttp::protocol::InitialRawMessageStreamError::Connection { + source: test_connection_error("initial stream failed"), + }, + }; + + let error: RequestError = source.into(); + + assert!(matches!(error, RequestError::InitialStream { .. })); + } + + #[test] + fn request_error_from_message_stream_error_preserves_variant() { + let source = MessageStreamError::Quic { + source: quic::StreamError::Reset { + code: VarInt::from_u32(0x20), + }, + }; + + let error: RequestError = source.into(); + + assert!(matches!(error, RequestError::ReceiveResponse { .. })); + } +} diff --git a/src/dhttp/message/hyper/read.rs b/src/dhttp/message/hyper/read.rs new file mode 100644 index 0000000..fe104dc --- /dev/null +++ b/src/dhttp/message/hyper/read.rs @@ -0,0 +1,622 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures::{StreamExt, stream}; +use http_body::{Body, Frame, SizeHint}; +use http_body_util::{BodyExt, Empty, StreamBody}; + +use super::{ + MessageReader, MessageStreamError, + upgrade::{RemainStream, TakeoverSlot}, +}; +use crate::error::H3MessageError; + +pin_project_lite::pin_project! { + #[project = EitherProj] + pub(crate) enum Either { + Left { #[pin] body: L }, + Right { #[pin] body: R } + } +} + +impl Either { + pub(crate) fn left(body: L) -> Self { + Self::Left { body } + } + + pub(crate) fn right(body: R) -> Self { + Self::Right { body } + } +} + +impl> Body for Either { + type Data = L::Data; + type Error = L::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + match self.project() { + EitherProj::Left { body } => body.poll_frame(cx), + EitherProj::Right { body } => body.poll_frame(cx), + } + } + + fn is_end_stream(&self) -> bool { + match self { + Either::Left { body } => body.is_end_stream(), + Either::Right { body } => body.is_end_stream(), + } + } + + fn size_hint(&self) -> SizeHint { + match self { + Either::Left { body } => body.size_hint(), + Either::Right { body } => body.size_hint(), + } + } +} + +impl MessageReader { + pub async fn read_hyper_request_parts( + &mut self, + ) -> Result { + let Some(field_section) = self.read_header().await? else { + let error = self + .handle_stream_error(H3MessageError::MissingHeaderSection.into()) + .await; + return Err(error.into()); + }; + match http::request::Parts::try_from(field_section) { + Ok(parts) => Ok(parts), + Err(error) => { + let error = self.handle_stream_error(error.into()).await; + Err(error.into()) + } + } + } + + pub async fn read_hyper_response_parts( + &mut self, + ) -> Result { + let Some(field_section) = self.read_header().await? else { + let error = self + .handle_stream_error(H3MessageError::MissingHeaderSection.into()) + .await; + return Err(error.into()); + }; + match http::response::Parts::try_from(field_section) { + Ok(parts) => Ok(parts), + Err(error) => { + let error = self.handle_stream_error(error.into()).await; + Err(error.into()) + } + } + } + + pub async fn read_hyper_frame(&mut self) -> Result>, MessageStreamError> { + if let Some(data) = self.read_data_chunk().await? { + return Ok(Some(Frame::data(data))); + } + + let Some(field_section) = self.read_header().await? else { + return Ok(None); + }; + + if !field_section.is_trailer() { + let error = self + .handle_stream_error(H3MessageError::UnexpectedHeadersInBody.into()) + .await; + return Err(error.into()); + } + + Ok(Some(Frame::trailers(field_section.into_header_map()))) + } + + pub fn as_hyper_body(&mut self) -> impl Body + Send { + StreamBody::new( + stream::unfold(self, async |stream| { + let frame = stream.read_hyper_frame().await.transpose()?; + Some((frame, stream)) + }) + .fuse(), + ) + } + + pub fn into_hyper_body(self) -> impl Body + Send { + StreamBody::new( + stream::unfold(self, async |mut stream| { + let frame = stream.read_hyper_frame().await.transpose()?; + Some((frame, stream)) + }) + .fuse(), + ) + } + + pub async fn into_hyper_request( + mut self, + ) -> Result< + http::Request + Send>, + MessageStreamError, + > { + let mut parts = self.read_hyper_request_parts().await?; + if parts.method == http::Method::CONNECT { + parts + .extensions + .insert(TakeoverSlot::new(RemainStream::immediately(self))); + let body = Either::right(Empty::new().map_err(|n| match n {})); + Ok(http::Request::from_parts(parts, body)) + } else { + let body = Either::left(self.into_hyper_body()); + Ok(http::Request::from_parts(parts, body)) + } + } + + pub async fn into_hyper_response( + mut self, + ) -> Result< + http::Response + Send>, + MessageStreamError, + > { + let mut parts = self.read_hyper_response_parts().await?; + match parts.status.is_informational() { + true => { + parts.extensions.insert(RemainStream::immediately(self)); + let body = Either::right(Empty::new().map_err(|n| match n {})); + Ok(http::Response::from_parts(parts, body)) + } + false => { + // no remain + let body = self.into_hyper_body(); + Ok(http::Response::from_parts(parts, Either::left(body))) + } + } + } +} + +#[cfg(test)] +mod tests { + use std::{convert::Infallible, pin::Pin, sync::Arc}; + + use bytes::Bytes; + use futures::{Sink, SinkExt, Stream}; + use http_body::Body; + use http_body_util::{BodyExt, Empty, Full, StreamBody}; + + use super::*; + use crate::{ + codec::{SinkWriter, StreamReader}, + connection::{ConnectionState, StreamError, tests::MockConnection}, + dhttp::{ + message::{MessageWriter, guard, test::read_stream_for_test}, + protocol::DHttpProtocol, + settings::Settings, + }, + protocol::Protocols, + qpack::{ + decoder::DecoderInstruction, + encoder::EncoderInstruction, + field::hyper::header_map_to_field_lines, + protocol::{QPackDecoder, QPackEncoder}, + }, + quic, + varint::VarInt, + }; + + fn qpack_decoder_sink() -> Pin + Send>> { + Box::pin(futures::sink::drain::().sink_map_err(|never| match never {})) + } + + fn qpack_decoder_stream() + -> Pin> + Send>> { + Box::pin(futures::stream::empty::< + Result, + >()) + } + + fn qpack_encoder_sink() -> Pin + Send>> { + Box::pin(futures::sink::drain::().sink_map_err(|never| match never {})) + } + + fn qpack_encoder_stream() + -> Pin> + Send>> { + Box::pin(futures::stream::empty::< + Result, + >()) + } + + fn stream_pair(stream_id: VarInt) -> (MessageReader, MessageWriter) { + let erased: Arc = Arc::new(MockConnection::new()); + + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased.clone())); + let state = ConnectionState::new_for_test(erased, Arc::new(protocols)); + + let (reader, writer) = quic::test::mock_stream_pair_with_capacity(stream_id, 64); + let reader = StreamReader::new(guard::GuardQuicReader::new( + Box::pin(reader) as crate::quic::BoxQuicStreamReader + )); + let writer = SinkWriter::new(guard::GuardQuicWriter::new( + Box::pin(writer) as crate::quic::BoxQuicStreamWriter + )); + + ( + MessageReader::new( + stream_id, + reader, + Arc::new(QPackDecoder::new( + Arc::new(Settings::default()), + qpack_decoder_sink(), + qpack_decoder_stream(), + )), + state.clone(), + ), + MessageWriter::new( + writer, + Arc::new(QPackEncoder::new( + Arc::new(Settings::default()), + qpack_encoder_sink(), + qpack_encoder_stream(), + )), + state, + ), + ) + } + + async fn next_body_frame( + mut body: Pin<&mut B>, + ) -> Option, B::Error>> { + futures::future::poll_fn(|cx| body.as_mut().poll_frame(cx)).await + } + + fn request_parts(method: http::Method, uri: &'static str) -> http::request::Parts { + let request = http::Request::builder() + .method(method) + .uri(uri) + .header("x-test", "present") + .body(()) + .expect("request"); + request.into_parts().0 + } + + fn response_parts(status: http::StatusCode) -> http::response::Parts { + let response = http::Response::builder() + .status(status) + .header("x-test", "present") + .body(()) + .expect("response"); + response.into_parts().0 + } + + #[tokio::test] + async fn either_left_and_right_delegate_body_methods() { + let mut left = + Either::, Empty>::left(Full::new(Bytes::from_static(b"left"))); + assert!(!left.is_end_stream()); + assert_eq!(left.size_hint().lower(), 4); + let frame = left.frame().await.expect("frame").expect("ok frame"); + assert_eq!( + frame.into_data().expect("data"), + Bytes::from_static(b"left") + ); + assert!(left.is_end_stream()); + + let mut right = + Either::, Full>::right(Full::new(Bytes::from_static(b"right"))); + assert!(!right.is_end_stream()); + assert_eq!(right.size_hint().lower(), 5); + let frame = right.frame().await.expect("frame").expect("ok frame"); + assert_eq!( + frame.into_data().expect("data"), + Bytes::from_static(b"right") + ); + assert!(right.is_end_stream()); + } + + #[tokio::test] + async fn read_hyper_request_parts_errors_when_header_missing() { + let mut stream = read_stream_for_test(VarInt::from_u32(0)); + + let error = stream + .read_hyper_request_parts() + .await + .expect_err("missing header should be rejected"); + + assert!(matches!(error, MessageStreamError::Quic { .. })); + } + + #[tokio::test] + async fn read_hyper_response_parts_errors_when_header_missing() { + let mut stream = read_stream_for_test(VarInt::from_u32(0)); + + let error = stream + .read_hyper_response_parts() + .await + .expect_err("missing header should be rejected"); + + assert!(matches!(error, MessageStreamError::Quic { .. })); + } + + #[tokio::test] + async fn read_hyper_request_parts_rejects_response_pseudo_headers() { + let (mut reader, mut writer) = stream_pair(VarInt::from_u32(0)); + writer + .send_hyper_response_parts(response_parts(http::StatusCode::OK)) + .await + .expect("response header should be written"); + + let error = reader + .read_hyper_request_parts() + .await + .expect_err("response pseudo headers are malformed for requests"); + + assert!(matches!(error, MessageStreamError::Quic { .. })); + } + + #[tokio::test] + async fn read_hyper_response_parts_rejects_request_pseudo_headers() { + let (mut reader, mut writer) = stream_pair(VarInt::from_u32(0)); + writer + .send_hyper_request_parts(request_parts(http::Method::GET, "https://example.test/")) + .await + .expect("request header should be written"); + + let error = reader + .read_hyper_response_parts() + .await + .expect_err("request pseudo headers are malformed for responses"); + + assert!(matches!(error, MessageStreamError::Quic { .. })); + } + + #[tokio::test] + async fn read_hyper_frame_returns_none_on_empty_stream() { + let mut stream = read_stream_for_test(VarInt::from_u32(0)); + + let frame = stream.read_hyper_frame().await.expect("empty stream"); + + assert!(frame.is_none()); + } + + #[tokio::test] + async fn as_hyper_body_and_into_hyper_body_end_on_empty_stream() { + let mut borrowed = read_stream_for_test(VarInt::from_u32(0)); + let body = borrowed.as_hyper_body(); + let mut body = std::pin::pin!(body); + assert!(next_body_frame(body.as_mut()).await.is_none()); + + let owned = read_stream_for_test(VarInt::from_u32(0)); + let body = owned.into_hyper_body(); + let mut body = std::pin::pin!(body); + assert!(next_body_frame(body.as_mut()).await.is_none()); + } + + #[tokio::test] + async fn as_hyper_body_yields_data_frames() { + let (mut reader, mut writer) = stream_pair(VarInt::from_u32(0)); + writer + .send_data(Bytes::from_static(b"borrowed-body")) + .await + .expect("data should be written"); + writer.close().await.expect("stream should close cleanly"); + + let body = reader.as_hyper_body(); + let mut body = std::pin::pin!(body); + let frame = next_body_frame(body.as_mut()) + .await + .expect("body frame") + .expect("data frame should decode"); + assert_eq!( + frame.into_data().expect("data"), + Bytes::from_static(b"borrowed-body") + ); + assert!(next_body_frame(body.as_mut()).await.is_none()); + } + + #[tokio::test] + async fn into_hyper_body_yields_data_frames() { + let (reader, mut writer) = stream_pair(VarInt::from_u32(0)); + writer + .send_data(Bytes::from_static(b"owned-body")) + .await + .expect("data should be written"); + writer.close().await.expect("stream should close cleanly"); + + let body = reader.into_hyper_body(); + let mut body = std::pin::pin!(body); + let frame = next_body_frame(body.as_mut()) + .await + .expect("body frame") + .expect("data frame should decode"); + assert_eq!( + frame.into_data().expect("data"), + Bytes::from_static(b"owned-body") + ); + assert!(next_body_frame(body.as_mut()).await.is_none()); + } + + #[tokio::test] + async fn read_hyper_request_and_body_from_written_request() { + let (reader, mut writer) = stream_pair(VarInt::from_u32(0)); + let request = http::Request::builder() + .method(http::Method::POST) + .uri("https://example.test/upload") + .header("x-test", "present") + .body(Full::new(Bytes::from_static(b"payload"))) + .expect("request"); + + writer + .send_hyper_request(request) + .await + .expect("request sent"); + drop(writer); + + let request = reader.into_hyper_request().await.expect("request decoded"); + assert_eq!(request.method(), http::Method::POST); + assert_eq!(request.uri(), "https://example.test/upload"); + assert_eq!(request.headers()["x-test"], "present"); + + let (_parts, body) = request.into_parts(); + let mut body = std::pin::pin!(body); + let frame = next_body_frame(body.as_mut()) + .await + .expect("body frame") + .expect("ok frame"); + assert_eq!( + frame.into_data().expect("data"), + Bytes::from_static(b"payload") + ); + } + + #[tokio::test] + async fn read_hyper_response_and_body_from_written_response() { + let (reader, mut writer) = stream_pair(VarInt::from_u32(0)); + let response = http::Response::builder() + .status(http::StatusCode::CREATED) + .header("x-test", "present") + .body(Full::new(Bytes::from_static(b"response"))) + .expect("response"); + + writer + .send_hyper_response(response) + .await + .expect("response sent"); + drop(writer); + + let response = reader + .into_hyper_response() + .await + .expect("response decoded"); + assert_eq!(response.status(), http::StatusCode::CREATED); + assert_eq!(response.headers()["x-test"], "present"); + + let (_parts, body) = response.into_parts(); + let mut body = std::pin::pin!(body); + let frame = next_body_frame(body.as_mut()) + .await + .expect("body frame") + .expect("ok frame"); + assert_eq!( + frame.into_data().expect("data"), + Bytes::from_static(b"response") + ); + } + + #[tokio::test] + async fn read_hyper_body_reads_data_and_trailers() { + let (mut reader, mut writer) = stream_pair(VarInt::from_u32(0)); + let mut trailers = http::HeaderMap::new(); + trailers.insert("x-trailer", http::HeaderValue::from_static("done")); + let body = StreamBody::new(futures::stream::iter([ + Ok::<_, Infallible>(Frame::data(Bytes::from_static(b"data"))), + Ok(Frame::trailers(trailers)), + ])); + + writer.send_hyper_body(body).await.expect("body sent"); + + let frame = reader + .read_hyper_frame() + .await + .expect("data frame") + .expect("data frame present"); + assert_eq!( + frame.into_data().expect("data"), + Bytes::from_static(b"data") + ); + + let frame = reader + .read_hyper_frame() + .await + .expect("trailer frame") + .expect("trailer frame present"); + let trailers = frame.into_trailers().expect("trailers"); + assert_eq!(trailers["x-trailer"], "done"); + } + + #[tokio::test] + async fn read_hyper_frame_rejects_non_trailer_headers_in_body() { + let (mut reader, mut writer) = stream_pair(VarInt::from_u32(0)); + + writer + .send_hyper_response_parts(response_parts(http::StatusCode::OK)) + .await + .expect("response header sent"); + + let error = reader + .read_hyper_frame() + .await + .expect_err("header in body should be rejected"); + + assert!(matches!(error, MessageStreamError::Quic { .. })); + } + + #[tokio::test] + async fn into_hyper_request_connect_stores_takeover_slot() { + let (reader, mut writer) = stream_pair(VarInt::from_u32(0)); + + writer + .send_hyper_request_parts(request_parts(http::Method::CONNECT, "example.test:443")) + .await + .expect("connect header sent"); + + let request = reader + .into_hyper_request() + .await + .expect("connect request decoded"); + + assert_eq!(request.method(), http::Method::CONNECT); + assert!( + request + .extensions() + .get::>() + .is_some() + ); + } + + #[tokio::test] + async fn into_hyper_response_informational_stores_remain_stream() { + let (reader, mut writer) = stream_pair(VarInt::from_u32(0)); + + writer + .send_hyper_response_parts(response_parts(http::StatusCode::EARLY_HINTS)) + .await + .expect("informational response sent"); + + let response = reader + .into_hyper_response() + .await + .expect("informational response decoded"); + + assert_eq!(response.status(), http::StatusCode::EARLY_HINTS); + assert!( + response + .extensions() + .get::>() + .is_some() + ); + } + + #[tokio::test] + async fn read_hyper_frame_reads_trailer_header_without_data() { + let (mut reader, mut writer) = stream_pair(VarInt::from_u32(0)); + let mut trailers = http::HeaderMap::new(); + trailers.insert("x-trailer", http::HeaderValue::from_static("only")); + + writer + .write_header(header_map_to_field_lines(trailers)) + .await + .expect("trailer header sent"); + + let frame = reader + .read_hyper_frame() + .await + .expect("trailer frame") + .expect("trailer frame present"); + let trailers = frame.into_trailers().expect("trailers"); + assert_eq!(trailers["x-trailer"], "only"); + } +} diff --git a/src/message/stream/hyper/upgrade.rs b/src/dhttp/message/hyper/upgrade.rs similarity index 59% rename from src/message/stream/hyper/upgrade.rs rename to src/dhttp/message/hyper/upgrade.rs index f9f7cfa..d31f920 100644 --- a/src/message/stream/hyper/upgrade.rs +++ b/src/dhttp/message/hyper/upgrade.rs @@ -212,13 +212,15 @@ mod tests { future::poll_fn, pin::Pin, task::{Context, Poll}, + time::Duration, }; use bytes::Bytes; use http_body::{Body, Frame}; + use tracing::Instrument; use super::*; - use crate::message::stream::ReadStream; + use crate::dhttp::message::MessageReader; #[derive(Debug, Clone)] struct ErrorBody; @@ -235,22 +237,170 @@ mod tests { } } + #[derive(Debug)] + struct OneFrameBody { + done: bool, + } + + impl Body for OneFrameBody { + type Data = Bytes; + type Error = std::io::Error; + + fn poll_frame( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + let this = self.get_mut(); + if this.done { + Poll::Ready(None) + } else { + this.done = true; + Poll::Ready(Some(Ok(Frame::data(Bytes::from_static(b"hello"))))) + } + } + } + + #[tokio::test] + async fn remain_stream_pending_waits_until_sender_sends() { + let (tx, mut stream) = RemainStream::::pending(); + + let send_task = tokio::spawn( + async move { + tokio::time::sleep(Duration::from_millis(10)).await; + tx.send(8).ok(); + } + .in_current_span(), + ); + + let value = tokio::time::timeout( + Duration::from_millis(100), + poll_fn(|cx| Pin::new(&mut stream).poll(cx)), + ) + .await + .unwrap() + .unwrap(); + assert_eq!(value, 8); + + send_task.await.unwrap(); + } + + #[tokio::test] + async fn remain_stream_returns_none_when_sender_dropped() { + let (tx, mut stream) = RemainStream::::pending(); + drop(tx); + + let value = tokio::time::timeout( + Duration::from_millis(100), + poll_fn(|cx| Pin::new(&mut stream).poll(cx)), + ) + .await + .unwrap(); + assert!(value.is_none()); + } + + #[tokio::test] + async fn remain_stream_immediately_returns_some() { + let mut stream = RemainStream::immediately(15u8); + let value = poll_fn(|cx| Pin::new(&mut stream).poll(cx)).await; + assert_eq!(value, Some(15)); + } + + #[tokio::test] + async fn cloned_remain_stream_receives_from_shared_receiver() { + let (tx, stream) = RemainStream::::pending(); + let mut cloned = stream.clone(); + + tx.send(9).expect("receiver is still alive"); + + let value = poll_fn(|cx| Pin::new(&mut cloned).poll(cx)).await; + assert_eq!(value, Some(9)); + } + + #[tokio::test] + async fn cloned_takeover_slot_shares_taken_state() { + let slot = TakeoverSlot::new(RemainStream::immediately(33u8)); + let cloned = slot.clone(); + + let first = poll_fn(|cx| cloned.poll_take(cx)).await; + assert_eq!(first, Ok(33)); + + let second = poll_fn(|cx| slot.poll_take(cx)).await; + assert_eq!(second, Err(TakeoverError::AlreadyTaken)); + } + + #[tokio::test] + async fn takeover_works_for_response_and_mut_ref_variants() { + let mut response = http::Response::new(http_body_util::Empty::::new()); + response + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately(21u8))); + + let value = { + let response_ref = &mut response; + poll_fn(|cx| HasTakeover::::poll_takeover(response_ref, cx)).await + }; + assert_eq!(value, Ok(21)); + + let mut response = http::Response::new(OneFrameBody { done: false }); + response + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately(42u8))); + let response_ref = &mut response; + + let value = poll_fn(|cx| HasTakeover::::poll_takeover(response_ref, cx)).await; + assert_eq!(value, Ok(42)); + } + + #[tokio::test] + async fn takeover_release_body_frames_until_eof_on_request() { + let mut request = http::Request::new(OneFrameBody { done: false }); + request + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately(77u8))); + + let value = poll_fn(|cx| HasTakeover::::poll_takeover(&mut request, cx)).await; + assert_eq!(value, Ok(77)); + } + + #[tokio::test] + async fn takeover_returns_unsupported_for_response_without_slot() { + let mut response = http::Response::new(http_body_util::Empty::::new()); + let value = poll_fn(|cx| HasTakeover::::poll_takeover(&mut response, cx)).await; + assert!(matches!(value, Err(TakeoverError::Unsupported))); + } + + #[test] + fn missing_stream_display_matches_read_write_and_read_and_write() { + assert_eq!(MissingStream::Read.to_string(), "read"); + assert_eq!(MissingStream::Write.to_string(), "write"); + assert_eq!(MissingStream::Both.to_string(), "read and write"); + } + + #[test] + fn upgrade_error_converts_takeover_error() { + let takeover = TakeoverError::BodyNotReleased; + let upgrade: UpgradeError = takeover.into(); + assert_eq!(upgrade, UpgradeError::Takeover { source: takeover }); + } + #[tokio::test] async fn takeover_returns_unsupported_when_slot_missing() { let mut request = http::Request::new(http_body_util::Empty::::new()); - let result = poll_fn(|cx| HasTakeover::::poll_takeover(&mut request, cx)).await; + let result = + poll_fn(|cx| HasTakeover::::poll_takeover(&mut request, cx)).await; assert!(matches!(result, Err(TakeoverError::Unsupported))); } #[tokio::test] async fn takeover_returns_ready_when_slot_available() { let mut request = http::Request::new(http_body_util::Empty::::new()); - let (read_tx, read) = RemainStream::::pending(); + let (read_tx, read) = RemainStream::::pending(); request.extensions_mut().insert(TakeoverSlot::new(read)); drop(read_tx); - let result = poll_fn(|cx| HasTakeover::::poll_takeover(&mut request, cx)).await; + let result = + poll_fn(|cx| HasTakeover::::poll_takeover(&mut request, cx)).await; assert!(matches!(result, Err(TakeoverError::Aborted))); } @@ -270,19 +420,21 @@ mod tests { #[tokio::test] async fn takeover_returns_aborted_when_sender_dropped() { let mut request = http::Request::new(http_body_util::Empty::::new()); - let (read_tx, read) = RemainStream::::pending(); + let (read_tx, read) = RemainStream::::pending(); request.extensions_mut().insert(TakeoverSlot::new(read)); drop(read_tx); - let result = poll_fn(|cx| HasTakeover::::poll_takeover(&mut request, cx)).await; + let result = + poll_fn(|cx| HasTakeover::::poll_takeover(&mut request, cx)).await; assert!(matches!(result, Err(TakeoverError::Aborted))); } #[tokio::test] async fn takeover_returns_body_not_released_on_body_error() { let mut request = http::Request::new(ErrorBody); - let result = poll_fn(|cx| HasTakeover::::poll_takeover(&mut request, cx)).await; + let result = + poll_fn(|cx| HasTakeover::::poll_takeover(&mut request, cx)).await; assert!(matches!(result, Err(TakeoverError::BodyNotReleased))); } } diff --git a/src/dhttp/message/hyper/write.rs b/src/dhttp/message/hyper/write.rs new file mode 100644 index 0000000..ed79663 --- /dev/null +++ b/src/dhttp/message/hyper/write.rs @@ -0,0 +1,560 @@ +use std::{convert::Infallible, pin::pin}; + +use http_body::Body; +use http_body_util::BodyExt; +use snafu::{ResultExt, Snafu}; + +use super::{MessageStreamError, MessageWriter}; +use crate::{ + error::H3StreamError, + qpack::field::{ + MalformedHeaderSection, + hyper::{ + header_map_to_field_lines, hyper_response_parts_to_field_lines, + validated_hyper_request_parts_to_field_lines, + }, + }, + quic::ResetStreamExt, +}; + +#[derive(Debug, Snafu)] +#[snafu(module, visibility(pub(crate)))] +pub enum SendMessageError { + #[snafu(display("failed to send message on stream"))] + Stream { source: MessageStreamError }, + #[snafu(display("request pseudo-header section is malformed"))] + MalformedHeader { source: MalformedHeaderSection }, + #[snafu(display("failed to read body frame"))] + Body { source: E }, +} + +impl SendMessageError { + pub fn map_body_error( + self, + f: impl FnOnce(E) -> E1, + ) -> SendMessageError { + match self { + SendMessageError::Stream { source } => SendMessageError::Stream { source }, + SendMessageError::MalformedHeader { source } => { + SendMessageError::MalformedHeader { source } + } + SendMessageError::Body { source } => SendMessageError::Body { source: f(source) }, + } + } +} + +impl MessageWriter { + pub(crate) async fn send_hyper_body( + &mut self, + body: B, + ) -> Result<(), SendMessageError> + where + B::Data: Send, + B::Error: std::error::Error + 'static, + { + let mut body = pin!(body); + while let Some(frame) = body.frame().await { + let frame = frame.context(send_message_error::BodySnafu)?; + let frame = match frame.into_data() { + Ok(data) => { + self.write_data(data) + .await + .context(send_message_error::StreamSnafu)?; + continue; + } + Err(frame) => frame, + }; + let frame = match frame.into_trailers() { + Ok(trailers) => { + self.write_header(header_map_to_field_lines(trailers)) + .await + .context(send_message_error::StreamSnafu)?; + break; + } + Err(frame) => frame, + }; + + tracing::warn!("ignore unknown http body frame"); + _ = frame; + } + Ok(()) + } + + pub async fn send_hyper_request_parts( + &mut self, + parts: http::request::Parts, + ) -> Result<(), SendMessageError> { + let fields = match validated_hyper_request_parts_to_field_lines(parts) { + Ok(fields) => fields, + Err(source) => { + _ = self.stream.reset(source.code().into_inner()).await; + return Err(SendMessageError::MalformedHeader { source }); + } + }; + self.write_header(fields) + .await + .context(send_message_error::StreamSnafu) + } + + pub async fn send_hyper_request( + &mut self, + request: http::Request, + ) -> Result<(), SendMessageError> + where + B::Data: Send, + B::Error: std::error::Error + 'static, + { + let (parts, body) = request.into_parts(); + let fields = match validated_hyper_request_parts_to_field_lines(parts) { + Ok(fields) => fields, + Err(source) => { + _ = self.stream.reset(source.code().into_inner()).await; + return Err(SendMessageError::MalformedHeader { source }); + } + }; + self.write_header(fields) + .await + .context(send_message_error::StreamSnafu)?; + self.send_hyper_body(body).await + } + + pub async fn send_hyper_response_parts( + &mut self, + parts: http::response::Parts, + ) -> Result<(), MessageStreamError> { + self.write_header(hyper_response_parts_to_field_lines(parts)) + .await + } + + pub async fn send_hyper_response( + &mut self, + response: http::Response, + ) -> Result<(), SendMessageError> + where + B::Data: Send, + B::Error: std::error::Error + 'static, + { + let (parts, body) = response.into_parts(); + self.write_header(hyper_response_parts_to_field_lines(parts)) + .await + .context(send_message_error::StreamSnafu)?; + self.send_hyper_body(body).await + } +} + +#[cfg(test)] +mod tests { + use std::{ + convert::Infallible, + fmt, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, + }; + + use bytes::Bytes; + use futures::{Sink, SinkExt, Stream, stream}; + use http_body::Frame; + use http_body_util::{Empty, Full, StreamBody}; + + use super::*; + use crate::{ + codec::SinkWriter, + dhttp::message::{guard, test::write_stream_for_test}, + qpack::protocol::QPackEncoder, + quic, + varint::VarInt, + }; + + #[derive(Debug)] + struct TestBodyError; + + impl fmt::Display for TestBodyError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("test body error") + } + } + + impl std::error::Error for TestBodyError {} + + #[derive(Debug)] + struct MappedBodyError; + + impl fmt::Display for MappedBodyError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("mapped body error") + } + } + + impl std::error::Error for MappedBodyError {} + + fn request_parts(uri: &'static str) -> http::request::Parts { + let request = http::Request::builder() + .method(http::Method::GET) + .uri(http::Uri::try_from(uri).expect("uri")) + .header("x-test", "present") + .body(()) + .expect("request"); + request.into_parts().0 + } + + fn reset_observing_write_stream(stream_id: VarInt) -> (MessageWriter, Arc>>) { + struct TestWriter { + stream_id: VarInt, + resets: Arc>>, + } + impl quic::GetStreamId for TestWriter { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(self.get_mut().stream_id)) + } + } + impl quic::ResetStream for TestWriter { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + self.resets.lock().expect("reset lock").push(code); + Poll::Ready(Ok(())) + } + } + impl Sink for TestWriter { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, _item: Bytes) -> Result<(), Self::Error> { + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + fn enc_sink() -> Pin< + Box< + dyn Sink< + crate::qpack::encoder::EncoderInstruction, + Error = crate::connection::StreamError, + > + Send, + >, + > { + Box::pin( + futures::sink::drain::() + .sink_map_err(|never| match never {}), + ) + } + fn enc_stream() -> Pin< + Box< + dyn Stream< + Item = Result< + crate::qpack::decoder::DecoderInstruction, + crate::connection::StreamError, + >, + > + Send, + >, + > { + Box::pin(futures::stream::empty::< + Result, + >()) + } + + let mock = Arc::new(crate::connection::tests::MockConnection::new()); + let erased: Arc = mock; + + let mut protocols = crate::protocol::Protocols::new(); + protocols.insert(crate::dhttp::protocol::DHttpProtocol::new_for_test( + erased.clone(), + )); + let state = crate::connection::ConnectionState::new_for_test(erased, Arc::new(protocols)); + + let resets = Arc::new(Mutex::new(Vec::new())); + let writer = SinkWriter::new(guard::GuardQuicWriter::new(Box::pin(TestWriter { + stream_id, + resets: resets.clone(), + }) + as crate::quic::BoxQuicStreamWriter)); + + let stream = MessageWriter::new( + writer, + Arc::new(QPackEncoder::new( + Arc::new(crate::dhttp::settings::Settings::default()), + enc_sink(), + enc_stream(), + )), + state, + ); + + (stream, resets) + } + + #[tokio::test] + async fn send_hyper_request_parts_accepts_valid_parts() { + let mut stream = write_stream_for_test(VarInt::from_u32(0)); + + stream + .send_hyper_request_parts(request_parts("https://example.test/path")) + .await + .expect("request parts sent"); + } + + #[tokio::test] + async fn send_hyper_request_parts_rejects_malformed_parts() { + let mut stream = write_stream_for_test(VarInt::from_u32(0)); + + let error = stream + .send_hyper_request_parts(request_parts("example.test")) + .await + .expect_err("authority-only GET is malformed"); + + assert!(matches!(error, SendMessageError::MalformedHeader { .. })); + } + + #[tokio::test] + async fn send_hyper_request_writes_header_and_body() { + let mut stream = write_stream_for_test(VarInt::from_u32(0)); + let request = http::Request::builder() + .method(http::Method::POST) + .uri(http::Uri::from_static("https://example.test/upload")) + .body(Full::new(Bytes::from_static(b"payload"))) + .expect("request"); + + stream + .send_hyper_request(request) + .await + .expect("request sent"); + } + + #[tokio::test] + async fn send_hyper_body_writes_data_and_trailers() { + let mut stream = write_stream_for_test(VarInt::from_u32(0)); + let mut trailers = http::HeaderMap::new(); + trailers.insert("x-trailer", http::HeaderValue::from_static("done")); + let body = StreamBody::new(stream::iter([ + Ok::<_, Infallible>(Frame::data(Bytes::from_static(b"data"))), + Ok(Frame::trailers(trailers)), + ])); + + stream.send_hyper_body(body).await.expect("body sent"); + } + + #[tokio::test] + async fn send_hyper_body_stops_after_trailers() { + let mut stream = write_stream_for_test(VarInt::from_u32(0)); + let mut trailers = http::HeaderMap::new(); + trailers.insert("x-trailer", http::HeaderValue::from_static("done")); + let body = StreamBody::new(stream::iter([ + Ok::<_, TestBodyError>(Frame::::trailers(trailers)), + Err(TestBodyError), + ])); + + stream + .send_hyper_body(body) + .await + .expect("body stops after trailers"); + } + + #[tokio::test] + async fn send_hyper_body_maps_body_errors() { + let mut stream = write_stream_for_test(VarInt::from_u32(0)); + let body = StreamBody::new(stream::iter([Err::, _>(TestBodyError)])); + + let error = stream + .send_hyper_body(body) + .await + .expect_err("body error should be returned"); + + assert!(matches!(error, SendMessageError::Body { .. })); + } + + #[tokio::test] + async fn send_hyper_request_accepts_empty_body() { + let mut stream = write_stream_for_test(VarInt::from_u32(0)); + let request = http::Request::builder() + .method(http::Method::GET) + .uri(http::Uri::from_static("https://example.test/")) + .body(Empty::::new()) + .expect("request"); + + stream + .send_hyper_request(request) + .await + .expect("empty request sent"); + } + + #[tokio::test] + async fn send_hyper_request_with_observed_writer_accepts_valid_request() { + let (mut stream, resets) = reset_observing_write_stream(VarInt::from_u32(0)); + let request = http::Request::builder() + .method(http::Method::POST) + .uri(http::Uri::from_static("https://example.test/upload")) + .body(Full::new(Bytes::from_static(b"payload"))) + .expect("request"); + + stream + .send_hyper_request(request) + .await + .expect("request sent"); + stream.flush().await.expect("stream flushed"); + stream.close().await.expect("stream closed"); + + assert!(resets.lock().expect("reset lock").is_empty()); + } + + #[tokio::test] + async fn send_hyper_response_parts_and_response_succeed() { + let mut stream = write_stream_for_test(VarInt::from_u32(0)); + let response = http::Response::builder() + .status(http::StatusCode::CREATED) + .header("x-test", "present") + .body(()) + .expect("response"); + let (parts, ()) = response.into_parts(); + + stream + .send_hyper_response_parts(parts) + .await + .expect("response parts sent"); + + let response = http::Response::builder() + .status(http::StatusCode::OK) + .body(Full::new(Bytes::from_static(b"response"))) + .expect("response"); + + stream + .send_hyper_response(response) + .await + .expect("response sent"); + } + + #[tokio::test] + async fn send_hyper_response_accepts_empty_body() { + let mut stream = write_stream_for_test(VarInt::from_u32(0)); + let response = http::Response::builder() + .status(http::StatusCode::NO_CONTENT) + .body(Empty::::new()) + .expect("response"); + + stream + .send_hyper_response(response) + .await + .expect("empty response sent"); + } + + #[tokio::test] + async fn send_hyper_response_maps_body_errors() { + let mut stream = write_stream_for_test(VarInt::from_u32(0)); + let body = StreamBody::new(stream::iter([Err::, _>(TestBodyError)])); + let response = http::Response::builder() + .status(http::StatusCode::OK) + .body(body) + .expect("response"); + + let error = stream + .send_hyper_response(response) + .await + .expect_err("response body error should be returned"); + + assert!(matches!( + error, + SendMessageError::Body { + source: TestBodyError + } + )); + } + + #[tokio::test] + async fn send_hyper_request_parts_resets_stream_on_malformed_parts() { + let (mut stream, resets) = reset_observing_write_stream(VarInt::from_u32(0)); + let expected_code = + validated_hyper_request_parts_to_field_lines(request_parts("example.test")) + .expect_err("authority-only GET is malformed") + .code() + .into_inner(); + + let error = stream + .send_hyper_request_parts(request_parts("example.test")) + .await + .expect_err("authority-only GET is malformed"); + + assert!(matches!(error, SendMessageError::MalformedHeader { .. })); + assert_eq!(*resets.lock().expect("reset lock"), vec![expected_code]); + } + + #[tokio::test] + async fn send_hyper_request_resets_stream_on_malformed_parts() { + let (mut stream, resets) = reset_observing_write_stream(VarInt::from_u32(0)); + let expected_code = + validated_hyper_request_parts_to_field_lines(request_parts("example.test")) + .expect_err("authority-only GET is malformed") + .code() + .into_inner(); + let request = http::Request::builder() + .method(http::Method::GET) + .uri(http::Uri::try_from("example.test").expect("uri")) + .body(Empty::::new()) + .expect("request"); + + let error = stream + .send_hyper_request(request) + .await + .expect_err("authority-only GET is malformed"); + + assert!(matches!(error, SendMessageError::MalformedHeader { .. })); + assert_eq!(*resets.lock().expect("reset lock"), vec![expected_code]); + } + + #[test] + fn map_body_error_preserves_non_body_variants() { + let stream_error = SendMessageError::::Stream { + source: MessageStreamError::MalformedOutgoingMessage, + } + .map_body_error(|_| MappedBodyError); + assert!(matches!(stream_error, SendMessageError::Stream { .. })); + + let malformed = SendMessageError::::MalformedHeader { + source: validated_hyper_request_parts_to_field_lines(request_parts("example.test")) + .expect_err("authority-only GET is malformed"), + } + .map_body_error(|never| match never {}); + assert!(matches!( + malformed, + SendMessageError::MalformedHeader { .. } + )); + } + + #[test] + fn map_body_error_maps_body_variant() { + let error = SendMessageError::Body { + source: TestBodyError, + } + .map_body_error(|_| MappedBodyError); + + assert!(matches!(error, SendMessageError::Body { .. })); + } + + #[test] + fn test_body_error_display_messages_are_stable() { + assert_eq!(TestBodyError.to_string(), "test body error"); + assert_eq!(MappedBodyError.to_string(), "mapped body error"); + } +} diff --git a/src/dhttp/message/test.rs b/src/dhttp/message/test.rs new file mode 100644 index 0000000..d1df5ea --- /dev/null +++ b/src/dhttp/message/test.rs @@ -0,0 +1,254 @@ +//! Test helpers for creating [`MessageReader`](super::MessageReader) and +//! [`MessageWriter`](super::MessageWriter) instances with mock QUIC connections, +//! suitable for unit testing. + +use std::{ + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures::{Sink, SinkExt, Stream}; + +use super::{MessageReader, MessageWriter, guard}; +use crate::{ + codec::{SinkWriter, StreamReader}, + qpack::protocol::{QPackDecoder, QPackEncoder}, + quic, + varint::VarInt, +}; + +/// Create a [`MessageReader`] for testing with a mock QUIC connection. +/// +/// The returned `MessageReader` uses a no-op reader that immediately returns `None`, +/// and QPack codec streams that never produce data. The underlying mock connection +/// has no open streams and will never become ready. +pub fn read_stream_for_test(stream_id: VarInt) -> MessageReader { + struct TestReader { + stream_id: VarInt, + } + impl quic::GetStreamId for TestReader { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.get_mut().stream_id)) + } + } + impl quic::StopStream for TestReader { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context, + _code: VarInt, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + impl Stream for TestReader { + type Item = Result; + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(None) + } + } + + fn dec_sink() -> Pin< + Box< + dyn Sink< + crate::qpack::decoder::DecoderInstruction, + Error = crate::connection::StreamError, + > + Send, + >, + > { + Box::pin( + futures::sink::drain::() + .sink_map_err(|never| match never {}), + ) + } + fn dec_stream() -> Pin< + Box< + dyn Stream< + Item = Result< + crate::qpack::encoder::EncoderInstruction, + crate::connection::StreamError, + >, + > + Send, + >, + > { + Box::pin(futures::stream::empty::< + Result, + >()) + } + + let mock = Arc::new(crate::connection::tests::MockConnection::new()); + let erased: Arc = mock; + + let mut protocols = crate::protocol::Protocols::new(); + protocols.insert(crate::dhttp::protocol::DHttpProtocol::new_for_test( + erased.clone(), + )); + let state = crate::connection::ConnectionState::new_for_test(erased, Arc::new(protocols)); + + let reader = StreamReader::new(guard::GuardQuicReader::new( + Box::pin(TestReader { stream_id }) as crate::quic::BoxQuicStreamReader, + )); + + MessageReader::new( + stream_id, + reader, + Arc::new(QPackDecoder::new( + Arc::new(crate::dhttp::settings::Settings::default()), + dec_sink(), + dec_stream(), + )), + state, + ) +} + +/// Create a [`MessageWriter`] for testing with a mock QUIC connection. +/// +/// The returned `MessageWriter` uses a no-op writer that discards all written data, +/// and QPack codec streams that never produce data. The underlying mock connection +/// has no open streams and will never become ready. +pub fn write_stream_for_test(stream_id: VarInt) -> MessageWriter { + struct TestWriter { + stream_id: VarInt, + } + impl quic::GetStreamId for TestWriter { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.get_mut().stream_id)) + } + } + impl quic::ResetStream for TestWriter { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context, + _code: VarInt, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + impl Sink for TestWriter { + type Error = quic::StreamError; + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + fn start_send(self: Pin<&mut Self>, _item: Bytes) -> Result<(), Self::Error> { + Ok(()) + } + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + fn enc_sink() -> Pin< + Box< + dyn Sink< + crate::qpack::encoder::EncoderInstruction, + Error = crate::connection::StreamError, + > + Send, + >, + > { + Box::pin( + futures::sink::drain::() + .sink_map_err(|never| match never {}), + ) + } + fn enc_stream() -> Pin< + Box< + dyn Stream< + Item = Result< + crate::qpack::decoder::DecoderInstruction, + crate::connection::StreamError, + >, + > + Send, + >, + > { + Box::pin(futures::stream::empty::< + Result, + >()) + } + + let mock = Arc::new(crate::connection::tests::MockConnection::new()); + let erased: Arc = mock; + + let mut protocols = crate::protocol::Protocols::new(); + protocols.insert(crate::dhttp::protocol::DHttpProtocol::new_for_test( + erased.clone(), + )); + let state = crate::connection::ConnectionState::new_for_test(erased, Arc::new(protocols)); + + let writer = SinkWriter::new(guard::GuardQuicWriter::new( + Box::pin(TestWriter { stream_id }) as crate::quic::BoxQuicStreamWriter, + )); + + MessageWriter::new( + writer, + Arc::new(QPackEncoder::new( + Arc::new(crate::dhttp::settings::Settings::default()), + enc_sink(), + enc_stream(), + )), + state, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + error::Code, + quic::{self, GetStreamIdExt}, + varint::VarInt, + }; + + #[tokio::test] + async fn test_write_stream_new() { + let stream_id = VarInt::from_u32(4); + let _write_stream = write_stream_for_test(stream_id); + } + + #[tokio::test] + async fn test_write_stream_connection() { + let stream_id = VarInt::from_u32(4); + let write_stream = write_stream_for_test(stream_id); + let conn: &Arc = write_stream.state.quic(); + let _ = Arc::clone(conn); + } + + #[tokio::test] + async fn test_write_stream_reset() { + let stream_id = VarInt::from_u32(4); + let mut write_stream = write_stream_for_test(stream_id); + let result = write_stream.reset(Code::H3_NO_ERROR).await; + assert!(result.is_ok(), "reset should succeed on mock writer"); + } + + #[tokio::test] + async fn test_write_stream_get_stream_id() { + let stream_id = VarInt::from_u32(4); + let mut write_stream = write_stream_for_test(stream_id); + let id = GetStreamIdExt::stream_id(&mut write_stream).await; + assert!(id.is_ok(), "stream_id should succeed"); + assert_eq!( + id.unwrap(), + stream_id, + "stream_id should match the value passed to write_stream_for_test" + ); + } +} diff --git a/src/message/stream/unfold.rs b/src/dhttp/message/unfold.rs similarity index 100% rename from src/message/stream/unfold.rs rename to src/dhttp/message/unfold.rs diff --git a/src/dhttp/message/unfold/read.rs b/src/dhttp/message/unfold/read.rs new file mode 100644 index 0000000..c7fc90c --- /dev/null +++ b/src/dhttp/message/unfold/read.rs @@ -0,0 +1,1059 @@ +use std::{ + future::Future, + ops::DerefMut, + pin::Pin, + task::{Context, Poll, ready}, +}; + +use bytes::Bytes; +use futures::stream::FusedStream; + +use super::super::{BoxMessageReader, MessageReader, MessageStreamError}; +use crate::{ + codec::StreamReader, + quic::{self, GetStreamId as QuicGetStreamId, StopStream as QuicStopStream}, + stream, + varint::VarInt, +}; + +impl From for BoxMessageReader { + fn from(value: MessageReader) -> Self { + value.into_box_reader() + } +} + +// --------------------------------------------------------------------------- +// Unfold – custom stream unfold that preserves QUIC traits +// --------------------------------------------------------------------------- + +pin_project_lite::pin_project! { + #[project = StateProj] + #[project_replace = StateProjReplace] + enum State { + Stream { stream: StreamState }, + Read { + token: tokio_util::sync::CancellationToken, + #[pin] + future: ReadFuture, + }, + Stop { + code: VarInt, + #[pin] + future: StopFuture, + }, + Empty, + } +} + +impl State { + fn take_stream(self: Pin<&mut Self>) -> StreamState { + match self.project_replace(Self::Empty) { + StateProjReplace::Stream { stream } => stream, + _ => unreachable!("invalid state for take_stream"), + } + } +} + +pin_project_lite::pin_project! { + /// A fused stream adapter similar to [`futures::stream::unfold`], but with + /// stream-specific read and stop futures that always return the stream state. + #[must_use = "streams do nothing unless polled"] + pub struct Unfold { + read: Read, + stop: Stop, + terminated: bool, + pending_stop: Option, + pending_item: Option, + _item: std::marker::PhantomData Item>, + #[pin] + state: State, + } +} + +trait StreamErrorItem { + fn from_stream_error(error: quic::StreamError) -> Self; +} + +impl StreamErrorItem for Result +where + Error: From, +{ + fn from_stream_error(error: quic::StreamError) -> Self { + Err(error.into()) + } +} + +/// Create an [`Unfold`] stream. +/// +/// The read future yields either a delivered item plus the returned stream +/// state, EOF plus the returned stream state, or an internally interrupted +/// stream state. The stop future is a separate operation so callers can adapt +/// states whose stop behavior is not expressed directly as a [`crate::quic::StopStream`] +/// implementation. +pub fn unfold( + init: StreamState, + read: Read, + stop: Stop, +) -> Unfold +where + Read: FnMut(StreamState, tokio_util::sync::CancellationToken) -> ReadFuture, + ReadFuture: Future), StreamState>>, + Stop: FnMut(StreamState, VarInt) -> StopFuture, + StopFuture: Future)>, +{ + Unfold { + read, + stop, + terminated: false, + pending_stop: None, + pending_item: None, + _item: std::marker::PhantomData, + state: State::Stream { stream: init }, + } +} + +impl + Unfold +where + Read: FnMut(StreamState, tokio_util::sync::CancellationToken) -> ReadFuture, + ReadFuture: Future), StreamState>>, + Stop: FnMut(StreamState, VarInt) -> StopFuture, + StopFuture: Future)>, +{ + fn poll_pending_stop( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let Some(code) = self.as_mut().project().pending_stop.as_ref().copied() else { + return Poll::Ready(Ok(())); + }; + + loop { + let mut project = self.as_mut().project(); + match project.state.as_mut().project() { + StateProj::Stream { .. } => { + let stream = project.state.as_mut().take_stream(); + project.state.set(State::Stop { + code, + future: (project.stop)(stream, code), + }); + } + StateProj::Read { .. } => return Poll::Pending, + StateProj::Stop { future, .. } => { + let (stream, result) = ready!(future.poll(cx)); + project.state.set(State::Stream { stream }); + *project.pending_stop = None; + return Poll::Ready(result); + } + StateProj::Empty => unreachable!("invalid state for poll_pending_stop"), + } + } + } +} + +impl futures::Stream + for Unfold +where + Read: FnMut(StreamState, tokio_util::sync::CancellationToken) -> ReadFuture, + ReadFuture: Future), StreamState>>, + Stop: FnMut(StreamState, VarInt) -> StopFuture, + StopFuture: Future)>, + Item: StreamErrorItem, +{ + type Item = Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(item) = self.as_mut().project().pending_item.take() { + return Poll::Ready(Some(item)); + } + + loop { + let mut project = self.as_mut().project(); + match project.state.as_mut().project() { + StateProj::Stream { .. } => { + if *project.terminated { + return Poll::Ready(None); + } + let stream = project.state.as_mut().take_stream(); + let token = tokio_util::sync::CancellationToken::new(); + project.state.set(State::Read { + token: token.clone(), + future: (project.read)(stream, token), + }); + } + StateProj::Read { future, .. } => match ready!(future.poll(cx)) { + futures::future::Either::Left((stream, item)) => { + if item.is_none() { + *project.terminated = true; + } + project.state.set(State::Stream { stream }); + return Poll::Ready(item); + } + futures::future::Either::Right(stream) => { + project.state.set(State::Stream { stream }); + } + }, + StateProj::Stop { future, .. } => { + let (stream, result) = ready!(future.poll(cx)); + project.state.set(State::Stream { stream }); + *project.pending_stop = None; + if let Err(error) = result { + return Poll::Ready(Some(Item::from_stream_error(error))); + } + } + StateProj::Empty => unreachable!("invalid state for poll_next"), + } + } + } +} + +impl FusedStream + for Unfold +where + Read: FnMut(StreamState, tokio_util::sync::CancellationToken) -> ReadFuture, + ReadFuture: Future), StreamState>>, + Stop: FnMut(StreamState, VarInt) -> StopFuture, + StopFuture: Future)>, + Item: StreamErrorItem, +{ + fn is_terminated(&self) -> bool { + self.terminated + } +} + +impl stream::GetStreamId + for Unfold +where + StreamState: stream::GetStreamId + Unpin, +{ + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + let project = self.project(); + match project.state.project() { + StateProj::Stream { stream } => Pin::new(stream).poll_stream_id(cx), + _ => Poll::Pending, + } + } +} + +impl stream::StopStream + for Unfold +where + Read: FnMut(StreamState, tokio_util::sync::CancellationToken) -> ReadFuture, + ReadFuture: Future), StreamState>>, + Stop: FnMut(StreamState, VarInt) -> StopFuture, + StopFuture: Future)>, +{ + fn poll_stop( + mut self: Pin<&mut Self>, + cx: &mut Context, + code: VarInt, + ) -> Poll> { + if self.as_mut().project().pending_stop.is_none() { + *self.as_mut().project().pending_stop = Some(code); + } + self.poll_pending_stop(cx) + } +} + +impl QuicGetStreamId + for Unfold +where + StreamState: QuicGetStreamId + Unpin, +{ + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + let project = self.project(); + match project.state.project() { + StateProj::Stream { stream } => Pin::new(stream).poll_stream_id(cx), + _ => Poll::Pending, + } + } +} + +impl QuicStopStream + for Unfold +where + Read: FnMut(StreamState, tokio_util::sync::CancellationToken) -> ReadFuture, + ReadFuture: Future), StreamState>>, + Stop: FnMut(StreamState, VarInt) -> StopFuture, + StopFuture: Future)>, +{ + fn poll_stop( + mut self: Pin<&mut Self>, + cx: &mut Context, + code: VarInt, + ) -> Poll> { + if self.as_mut().project().pending_stop.is_none() { + *self.as_mut().project().pending_stop = Some(code); + } + self.poll_pending_stop(cx) + } +} + +// --------------------------------------------------------------------------- +// MessageReader conversion methods +// --------------------------------------------------------------------------- + +impl MessageReader { + pub fn as_bytes_stream( + &mut self, + ) -> impl stream::ReadStream + + QuicGetStreamId + + QuicStopStream + + FusedStream + + Send + + '_ { + unfold( + self, + |stream: &mut MessageReader, token| async move { + tokio::select! { + biased; + _ = token.cancelled() => futures::future::Either::Right(stream), + result = stream.read_data_chunk() => { + let item = match result { + Ok(Some(bytes)) => Some(Ok(bytes)), + Ok(None) => None, + Err(error) => Some(Err(error)), + }; + futures::future::Either::Left((stream, item)) + } + } + }, + |mut stream: &mut MessageReader, code| async move { + let result = futures::future::poll_fn(|cx| { + QuicStopStream::poll_stop(Pin::new(stream.deref_mut()), cx, code) + }) + .await; + (stream, result) + }, + ) + } + + pub fn as_reader( + &mut self, + ) -> StreamReader< + impl stream::ReadStream + + QuicGetStreamId + + QuicStopStream + + FusedStream + + Send + + '_, + > { + StreamReader::new(self.as_bytes_stream()) + } + + pub fn as_box_reader( + &mut self, + ) -> Pin< + Box< + dyn stream::ReadStream + + Send + + '_, + >, + > { + Box::pin(self.as_bytes_stream()) + } + + pub fn into_bytes_stream( + self, + ) -> impl stream::ReadStream + + QuicGetStreamId + + QuicStopStream + + FusedStream + + Send { + unfold( + self, + |mut stream: MessageReader, token| async move { + tokio::select! { + biased; + _ = token.cancelled() => futures::future::Either::Right(stream), + result = stream.read_data_chunk() => { + let item = match result { + Ok(Some(bytes)) => Some(Ok(bytes)), + Ok(None) => None, + Err(error) => Some(Err(error)), + }; + futures::future::Either::Left((stream, item)) + } + } + }, + |mut stream: MessageReader, code| async move { + let result = futures::future::poll_fn(|cx| { + QuicStopStream::poll_stop(Pin::new(&mut stream), cx, code) + }) + .await; + (stream, result) + }, + ) + } + + pub fn into_reader( + self, + ) -> StreamReader< + impl stream::ReadStream + + QuicGetStreamId + + QuicStopStream + + FusedStream + + Send, + > { + StreamReader::new(self.into_bytes_stream()) + } + + pub fn into_box_reader(self) -> BoxMessageReader { + Box::pin(self.into_bytes_stream()) + } +} + +#[cfg(test)] +mod tests { + use std::{ + collections::VecDeque, + io, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, + }; + + use futures::{ + FutureExt, Stream, StreamExt, + future::{Either, poll_fn}, + }; + use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf}; + + use super::*; + + #[derive(Debug)] + struct ControlStream { + stream_id: VarInt, + stopped: Arc>>, + } + + impl stream::GetStreamId for ControlStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl stream::StopStream for ControlStream { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + *self.stopped.lock().expect("stop state poisoned") = Some(code); + Poll::Ready(Ok(())) + } + } + + impl QuicGetStreamId for ControlStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl QuicStopStream for ControlStream { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + *self.stopped.lock().expect("stop state poisoned") = Some(code); + Poll::Ready(Ok(())) + } + } + + async fn stop_ok( + stream: StreamState, + _code: VarInt, + ) -> (StreamState, Result<(), quic::StreamError>) { + (stream, Ok(())) + } + + #[tokio::test] + async fn unfold_yields_items_and_reports_termination() { + let mut stream = Box::pin(unfold( + 0, + |value, _token| async move { + let item = (value < 3).then_some(Ok::<_, quic::StreamError>(value)); + Either::Left((value + 1, item)) + }, + stop_ok, + )); + + assert_eq!(stream.as_mut().next().await.unwrap().unwrap(), 0); + assert_eq!(stream.as_mut().next().await.unwrap().unwrap(), 1); + assert_eq!(stream.as_mut().next().await.unwrap().unwrap(), 2); + assert!(stream.as_mut().next().await.is_none()); + assert!(stream.as_ref().get_ref().is_terminated()); + } + + #[tokio::test] + async fn control_traits_forward_while_value_available() { + let stopped = Arc::new(Mutex::new(None)); + let stream_id = VarInt::from_u32(37); + let stop_code = VarInt::from_u32(41); + let mut stream = Box::pin(unfold( + ControlStream { + stream_id, + stopped: stopped.clone(), + }, + |stream, _token| async move { Either::Left((stream, Some(Ok::<_, quic::StreamError>(())))) }, + |stream: ControlStream, code| async move { + *stream.stopped.lock().expect("stop state poisoned") = Some(code); + (stream, Ok(())) + }, + )); + + assert_eq!( + poll_fn(|cx| stream.as_mut().poll_stream_id(cx)) + .await + .expect("stream id"), + stream_id + ); + poll_fn(|cx| stream.as_mut().poll_stop(cx, stop_code)) + .await + .expect("stop forwarded"); + assert_eq!( + *stopped.lock().expect("stop state poisoned"), + Some(stop_code) + ); + } + + #[tokio::test] + async fn control_traits_wait_while_future_owns_value() { + let stopped = Arc::new(Mutex::new(None)); + let mut stream = Box::pin(unfold( + ControlStream { + stream_id: VarInt::from_u32(37), + stopped, + }, + |_stream, _token| { + futures::future::pending::< + Either<(ControlStream, Option>), ControlStream>, + >() + }, + |stream: ControlStream, code| async move { + *stream.stopped.lock().expect("stop state poisoned") = Some(code); + (stream, Ok(())) + }, + )); + + assert!(futures::poll!(stream.as_mut().next()).is_pending()); + assert!( + poll_fn(|cx| stream.as_mut().poll_stream_id(cx)) + .now_or_never() + .is_none() + ); + assert!( + poll_fn(|cx| stream.as_mut().poll_stop(cx, VarInt::from_u32(41))) + .now_or_never() + .is_none() + ); + } + + #[tokio::test] + async fn stop_waits_for_pending_read_before_forwarding() { + let stopped = Arc::new(Mutex::new(None)); + let stop_code = VarInt::from_u32(41); + let stream_id = VarInt::from_u32(37); + let (read_tx, read_rx) = tokio::sync::oneshot::channel(); + let mut read_rx = Some(read_rx); + let mut stream = Box::pin(unfold( + ControlStream { + stream_id, + stopped: stopped.clone(), + }, + move |stream, _token| { + let read_rx = read_rx.take().expect("single read future"); + async move { + read_rx.await.expect("read release sent"); + Either::Left((stream, Some(Ok::<_, quic::StreamError>(())))) + } + }, + |stream: ControlStream, code| async move { + *stream.stopped.lock().expect("stop state poisoned") = Some(code); + (stream, Ok(())) + }, + )); + + assert!( + poll_fn(|cx| stream.as_mut().poll_next(cx)) + .now_or_never() + .is_none() + ); + assert!( + poll_fn(|cx| stream.as_mut().poll_stop(cx, stop_code)) + .now_or_never() + .is_none() + ); + assert_eq!(*stopped.lock().expect("stop state poisoned"), None); + + read_tx.send(()).expect("release pending read"); + assert!(matches!(stream.as_mut().next().await, Some(Ok(())))); + poll_fn(|cx| stream.as_mut().poll_stop(cx, stop_code)) + .await + .expect("stop should complete after read yields stream"); + assert_eq!( + *stopped.lock().expect("stop state poisoned"), + Some(stop_code) + ); + assert_eq!( + poll_fn(|cx| stream.as_mut().poll_stream_id(cx)) + .await + .expect("stream id should remain available"), + stream_id + ); + } + + #[tokio::test] + async fn stop_does_not_interrupt_pending_next() { + let stopped = Arc::new(Mutex::new(None)); + let stop_code = VarInt::from_u32(41); + let (read_tx, read_rx) = tokio::sync::oneshot::channel(); + let mut read_rx = Some(read_rx); + let mut stream = Box::pin(unfold( + ControlStream { + stream_id: VarInt::from_u32(37), + stopped: stopped.clone(), + }, + move |stream, token| { + let read_rx = read_rx.take().expect("single read future"); + async move { + tokio::select! { + biased; + _ = token.cancelled() => Either::Left(( + stream, + Some(Ok::(Bytes::from_static(b"interrupted"))), + )), + item = read_rx => Either::Left(( + stream, + Some(Ok::( + item.expect("read release sent"), + )), + )), + } + } + }, + |stream: ControlStream, code| async move { + *stream.stopped.lock().expect("stop state poisoned") = Some(code); + (stream, Ok(())) + }, + )); + + assert!( + poll_fn(|cx| stream.as_mut().poll_next(cx)) + .now_or_never() + .is_none() + ); + assert!( + poll_fn(|cx| stream.as_mut().poll_stop(cx, stop_code)) + .now_or_never() + .is_none(), + "stop must wait for the read future without cancelling it" + ); + assert_eq!(*stopped.lock().expect("stop state poisoned"), None); + + read_tx + .send(Bytes::from_static(b"read")) + .expect("release pending read"); + assert_eq!( + stream + .as_mut() + .next() + .await + .expect("read item") + .expect("read succeeds"), + Bytes::from_static(b"read") + ); + assert_eq!(*stopped.lock().expect("stop state poisoned"), None); + + poll_fn(|cx| stream.as_mut().poll_stop(cx, stop_code)) + .await + .expect("stop completes after read yields the stream"); + assert_eq!( + *stopped.lock().expect("stop state poisoned"), + Some(stop_code) + ); + } + + #[tokio::test] + async fn stop_waits_for_pending_read_that_reaches_eof() { + let stopped = Arc::new(Mutex::new(None)); + let stop_code = VarInt::from_u32(41); + let (eof_tx, eof_rx) = tokio::sync::oneshot::channel(); + let mut eof_rx = Some(eof_rx); + let mut stream = Box::pin(unfold( + ControlStream { + stream_id: VarInt::from_u32(37), + stopped: stopped.clone(), + }, + move |stream, _token| { + let eof_rx = eof_rx.take().expect("single read future"); + async move { + eof_rx.await.expect("eof release sent"); + Either::< + (ControlStream, Option>), + ControlStream, + >::Left((stream, None)) + } + }, + |stream: ControlStream, code| async move { + *stream.stopped.lock().expect("stop state poisoned") = Some(code); + (stream, Ok(())) + }, + )); + + assert!( + poll_fn(|cx| stream.as_mut().poll_next(cx)) + .now_or_never() + .is_none() + ); + assert!( + poll_fn(|cx| stream.as_mut().poll_stop(cx, stop_code)) + .now_or_never() + .is_none() + ); + + eof_tx.send(()).expect("release pending eof"); + assert!(stream.as_mut().next().await.is_none()); + poll_fn(|cx| stream.as_mut().poll_stop(cx, stop_code)) + .await + .expect("stop should complete after eof yields stream"); + + assert!(stream.as_ref().get_ref().is_terminated()); + assert_eq!( + *stopped.lock().expect("stop state poisoned"), + Some(stop_code) + ); + } + + #[tokio::test] + async fn unfold_yields_error_items_without_terminating_the_stream() { + let mut stream = Box::pin(unfold( + VecDeque::from([ + Ok::(Bytes::from_static(b"chunk-1")), + Err::(MessageStreamError::MalformedIncomingMessage), + Ok(Bytes::from_static(b"chunk-2")), + ]), + |mut items, _token| async move { + let item = items.pop_front(); + Either::Left((items, item)) + }, + stop_ok, + )); + + match stream.as_mut().next().await { + Some(Ok(item)) => assert_eq!(item, Bytes::from_static(b"chunk-1")), + value => panic!("unexpected first item: {value:?}"), + } + assert!(matches!( + stream.as_mut().next().await, + Some(Err(MessageStreamError::MalformedIncomingMessage)) + )); + match stream.as_mut().next().await { + Some(Ok(item)) => assert_eq!(item, Bytes::from_static(b"chunk-2")), + value => panic!("unexpected third item: {value:?}"), + } + assert!(stream.as_mut().next().await.is_none()); + assert!(stream.as_ref().get_ref().is_terminated()); + } + + #[tokio::test] + async fn stream_reader_implements_async_read_and_hits_eof() { + let mut reader = Box::pin(StreamReader::new(unfold( + VecDeque::from([ + Ok::(Bytes::from_static(b"hel")), + Ok::(Bytes::from_static(b"lo")), + Ok::(Bytes::from_static(b"")), + Ok::(Bytes::from_static(b"world")), + ]), + |mut chunks, _token| async move { + let chunk = chunks.pop_front(); + Either::Left((chunks, chunk)) + }, + stop_ok, + ))); + + let mut data = Vec::new(); + let mut read = 0; + + loop { + let mut buf = [0_u8; 4]; + let n = poll_fn(|cx| { + let mut read_buf = ReadBuf::new(&mut buf); + match reader.as_mut().poll_read(cx, &mut read_buf) { + Poll::Pending => Poll::Pending, + Poll::Ready(ready) => Poll::Ready(ready.map(|_| read_buf.filled().len())), + } + }) + .await + .unwrap(); + + if n == 0 { + break; + } + + data.extend_from_slice(&buf[..n]); + read += n; + } + + assert_eq!(read, 10); + assert_eq!(data, b"helloworld"); + let mut buf = [0_u8; 4]; + assert_eq!( + poll_fn(|cx| { + let mut read_buf = ReadBuf::new(&mut buf); + match reader.as_mut().poll_read(cx, &mut read_buf) { + Poll::Pending => Poll::Pending, + Poll::Ready(ready) => Poll::Ready(ready.map(|_| read_buf.filled().len())), + } + }) + .await + .unwrap(), + 0 + ); + } + + #[tokio::test] + async fn stream_reader_implements_async_buf_read() { + let mut reader = Box::pin(StreamReader::new(unfold( + VecDeque::from([ + Ok::(Bytes::from_static(b"ab")), + Ok::(Bytes::from_static(b"cd")), + ]), + |mut chunks, _token| async move { + let chunk = chunks.pop_front(); + Either::Left((chunks, chunk)) + }, + stop_ok, + ))); + + assert_eq!( + poll_fn(|cx| { + reader + .as_mut() + .poll_fill_buf(cx) + .map(|result| result.map(|buf| buf.to_vec())) + }) + .await + .unwrap(), + b"ab" + ); + reader.as_mut().consume(1); + assert_eq!( + poll_fn(|cx| { + reader + .as_mut() + .poll_fill_buf(cx) + .map(|result| result.map(|buf| buf.to_vec())) + }) + .await + .unwrap(), + b"b" + ); + reader.as_mut().consume(1); + assert_eq!( + poll_fn(|cx| { + reader + .as_mut() + .poll_fill_buf(cx) + .map(|result| result.map(|buf| buf.to_vec())) + }) + .await + .unwrap(), + b"cd" + ); + reader.as_mut().consume(2); + assert_eq!( + poll_fn(|cx| { + reader + .as_mut() + .poll_fill_buf(cx) + .map(|result| result.map(|buf| buf.to_vec())) + }) + .await + .unwrap(), + b"" + ); + } + + #[tokio::test] + async fn read_stream_as_reader_reports_eof_and_termination() { + let mut stream = crate::dhttp::message::test::read_stream_for_test(VarInt::from_u32(71)); + let mut reader = Box::pin(stream.as_reader()); + + assert!(poll_fn(|cx| reader.as_mut().poll_next(cx)).await.is_none()); + assert!(reader.stream().is_terminated()); + } + + #[tokio::test] + async fn read_stream_into_bytes_stream_forwards_control_traits_until_eof() { + let stream_id = VarInt::from_u32(72); + let stop_code = VarInt::from_u32(73); + let stream = crate::dhttp::message::test::read_stream_for_test(stream_id); + let mut bytes = Box::pin(stream.into_bytes_stream()); + + assert_eq!( + poll_fn(|cx| bytes.as_mut().poll_stream_id(cx)) + .await + .expect("stream id"), + stream_id + ); + poll_fn(|cx| bytes.as_mut().poll_stop(cx, stop_code)) + .await + .expect("stop forwarded"); + + assert!(poll_fn(|cx| bytes.as_mut().poll_next(cx)).await.is_none()); + assert!(bytes.as_ref().get_ref().is_terminated()); + assert_eq!( + poll_fn(|cx| bytes.as_mut().poll_stream_id(cx)) + .await + .expect("stream id after eof"), + stream_id + ); + poll_fn(|cx| bytes.as_mut().poll_stop(cx, stop_code)) + .await + .expect("stop after eof"); + } + + #[tokio::test] + async fn read_stream_into_reader_reports_eof_and_termination() { + let stream = crate::dhttp::message::test::read_stream_for_test(VarInt::from_u32(74)); + let mut reader = Box::pin(stream.into_reader()); + + assert!(poll_fn(|cx| reader.as_mut().poll_next(cx)).await.is_none()); + assert!(reader.stream().is_terminated()); + } + + #[tokio::test] + async fn read_stream_from_conversion_builds_box_reader() { + let stream_id = VarInt::from_u32(75); + let stream = crate::dhttp::message::test::read_stream_for_test(stream_id); + let mut reader: BoxMessageReader = stream.into(); + + assert_eq!( + poll_fn(|cx| Pin::new(&mut reader).poll_stream_id(cx)) + .await + .expect("stream id"), + stream_id + ); + + assert!(poll_fn(|cx| reader.as_mut().poll_next(cx)).await.is_none()); + } + + #[tokio::test] + async fn read_stream_into_box_reader_forwards_control_traits_and_stops_at_eof() { + let stream_id = VarInt::from_u32(90); + let stop_code = VarInt::from_u32(102); + + let mut stream = crate::dhttp::message::test::read_stream_for_test(stream_id); + let mut reader = stream.as_box_reader(); + + assert_eq!( + poll_fn(|cx| stream::GetStreamId::poll_stream_id(Pin::new(&mut reader), cx)) + .await + .unwrap(), + stream_id + ); + poll_fn(|cx| stream::StopStream::poll_stop(Pin::new(&mut reader), cx, stop_code)) + .await + .unwrap(); + + assert!(poll_fn(|cx| reader.as_mut().poll_next(cx)).await.is_none()); + } + + #[tokio::test] + async fn unfold_waits_for_termination_controls_after_done() { + let mut stream = Box::pin(unfold( + ControlStream { + stream_id: VarInt::from_u32(37), + stopped: Arc::new(Mutex::new(None)), + }, + |stream: ControlStream, _token| async move { + Either::Left((stream, None::>)) + }, + stop_ok, + )); + + assert!(stream.as_mut().next().await.is_none()); + assert!(stream.as_ref().get_ref().is_terminated()); + assert_eq!( + poll_fn(|cx| stream.as_mut().poll_stream_id(cx)) + .await + .expect("stream id after termination"), + VarInt::from_u32(37) + ); + poll_fn(|cx| stream.as_mut().poll_stop(cx, VarInt::from_u32(41))) + .await + .expect("stop after termination"); + } + + #[tokio::test] + async fn stream_reader_maps_message_stream_error_items_to_io_errors() { + let mut reader = Box::pin(StreamReader::new(unfold( + VecDeque::from([Err::( + MessageStreamError::MalformedIncomingMessage, + )]), + |mut chunks, _token| async move { + let chunk = chunks.pop_front(); + Either::Left((chunks, chunk)) + }, + stop_ok, + ))); + let mut buf = [0_u8; 4]; + + let error = poll_fn(|cx| { + let mut read_buf = ReadBuf::new(&mut buf); + match reader.as_mut().poll_read(cx, &mut read_buf) { + Poll::Pending => Poll::Pending, + Poll::Ready(result) => Poll::Ready(result), + } + }) + .await + .expect_err("error item should become io error"); + + assert_eq!(error.kind(), io::ErrorKind::InvalidData); + } + + #[tokio::test] + async fn stream_reader_poll_next_returns_buffered_chunk_then_eof() { + let mut reader = Box::pin(StreamReader::new(unfold( + VecDeque::from([Ok::(Bytes::from_static(b"abcd"))]), + |mut chunks, _token| async move { + let chunk = chunks.pop_front(); + Either::Left((chunks, chunk)) + }, + stop_ok, + ))); + let mut buf = [0_u8; 2]; + + let read = poll_fn(|cx| { + let mut read_buf = ReadBuf::new(&mut buf); + match reader.as_mut().poll_read(cx, &mut read_buf) { + Poll::Pending => Poll::Pending, + Poll::Ready(result) => Poll::Ready(result.map(|_| read_buf.filled().len())), + } + }) + .await + .expect("partial read succeeds"); + assert_eq!(read, 2); + assert_eq!(&buf, b"ab"); + + match poll_fn(|cx| reader.as_mut().poll_next(cx)).await { + Some(Ok(chunk)) => assert_eq!(chunk, Bytes::from_static(b"cd")), + value => panic!("unexpected buffered chunk: {value:?}"), + } + assert!(poll_fn(|cx| reader.as_mut().poll_next(cx)).await.is_none()); + } +} diff --git a/src/dhttp/message/unfold/write.rs b/src/dhttp/message/unfold/write.rs new file mode 100644 index 0000000..7ac8b0b --- /dev/null +++ b/src/dhttp/message/unfold/write.rs @@ -0,0 +1,1956 @@ +//! since futures::sink::unfold only works for send, we implement our own version here to support flush and close as well. + +use std::{ + ops::DerefMut, + pin::Pin, + task::{Context, Poll, ready}, +}; + +use bytes::Bytes; +use futures::{Sink, future::Either}; +use tokio_util::sync::CancellationToken; + +use super::super::{BoxMessageWriter, MessageStreamError, MessageWriter}; +use crate::{ + codec::SinkWriter, + quic::{self, GetStreamId as QuicGetStreamId, ResetStream as QuicResetStream}, + stream, + varint::VarInt, +}; + +impl From for BoxMessageWriter { + fn from(value: MessageWriter) -> Self { + value.into_box_writer() + } +} + +// --------------------------------------------------------------------------- +// Unfold – custom sink unfold that preserves QUIC traits +// --------------------------------------------------------------------------- + +pin_project_lite::pin_project! { + #[project = StateProj] + #[project_replace = StateProjReplace] + #[derive(Debug)] + enum State { + Stream { stream: StreamState }, + Send { + token: CancellationToken, + #[pin] + future: SendFuture, + }, + Flush { + token: CancellationToken, + #[pin] + future: FlushFuture, + }, + Close { + token: CancellationToken, + #[pin] + future: CloseFuture, + }, + Reset { + #[pin] + future: ResetFuture, + }, + Empty, + } +} + +impl + State +{ + fn take_stream(self: Pin<&mut Self>) -> Option { + match &*self { + Self::Stream { .. } => match self.project_replace(Self::Empty) { + StateProjReplace::Stream { stream } => Some(stream), + _ => unreachable!(), + }, + _ => None, + } + } +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +enum PendingWriteOp { + Reset(VarInt), + Flush, + Close, +} + +#[derive(Debug, Default)] +struct PendingWriteQueue { + ops: std::collections::VecDeque, +} + +impl PendingWriteQueue { + fn contains_reset(&self) -> bool { + matches!(self.ops.front(), Some(PendingWriteOp::Reset(_))) + } + + fn contains_flush(&self) -> bool { + self.ops.contains(&PendingWriteOp::Flush) + } + + fn contains_close(&self) -> bool { + self.ops.contains(&PendingWriteOp::Close) + } + + fn contains_kind(&self, op: PendingWriteOp) -> bool { + match op { + PendingWriteOp::Reset(_) => self.contains_reset(), + PendingWriteOp::Flush => self.contains_flush(), + PendingWriteOp::Close => self.contains_close(), + } + } + + fn enqueue_reset(&mut self, code: VarInt) { + if !self.contains_reset() { + self.ops.clear(); + self.ops.push_back(PendingWriteOp::Reset(code)); + } + } + + fn enqueue_flush(&mut self) { + if !self.contains_reset() && !self.contains_flush() { + self.ops.push_back(PendingWriteOp::Flush); + } + } + + fn enqueue_close(&mut self) { + if !self.contains_reset() && !self.contains_close() { + self.ops.push_back(PendingWriteOp::Close); + } + } + + fn front(&self) -> Option { + self.ops.front().copied() + } + + fn pop_front(&mut self) { + self.ops.pop_front(); + } + + fn is_empty(&self) -> bool { + self.ops.is_empty() + } +} + +pin_project_lite::pin_project! { + #[derive(Debug)] + #[must_use = "sinks do nothing unless polled"] + pub struct Unfold< + StreamState, + Send, + Flush, + Close, + Reset, + SendFuture, + FlushFuture, + CloseFuture, + ResetFuture, + SinkError, + > { + send: Send, + flush: Flush, + close: Close, + reset: Reset, + pending: PendingWriteQueue, + _error: std::marker::PhantomData SinkError>, + #[pin] + state: State, + } +} + +impl< + StreamState, + Send, + Flush, + Close, + Reset, + SendFuture, + FlushFuture, + CloseFuture, + ResetFuture, + SinkError, +> + Unfold< + StreamState, + Send, + Flush, + Close, + Reset, + SendFuture, + FlushFuture, + CloseFuture, + ResetFuture, + SinkError, + > +where + SendFuture: Future), StreamState>>, + FlushFuture: Future), StreamState>>, + CloseFuture: Future), StreamState>>, + ResetFuture: Future)>, + SinkError: From, +{ + fn is_flush(&self) -> bool { + matches!(self.state, State::Flush { .. }) + } + + fn is_close(&self) -> bool { + matches!(self.state, State::Close { .. }) + } + + fn is_reset(&self) -> bool { + matches!(self.state, State::Reset { .. }) + } + + fn set_stream(mut self: Pin<&mut Self>, stream: StreamState) { + self.as_mut().project().state.set(State::Stream { stream }); + } + + fn poll_current_to_stream( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let completed = { + let mut project = self.as_mut().project(); + match project.state.as_mut().project() { + StateProj::Stream { .. } => return Poll::Ready(Ok(())), + StateProj::Send { future, .. } => match ready!(future.poll(cx)) { + Either::Left((stream, result)) => (stream, result), + Either::Right(stream) => (stream, Ok(())), + }, + StateProj::Flush { future, .. } => match ready!(future.poll(cx)) { + Either::Left((stream, result)) => (stream, result), + Either::Right(stream) => (stream, Ok(())), + }, + StateProj::Close { future, .. } => match ready!(future.poll(cx)) { + Either::Left((stream, result)) => (stream, result), + Either::Right(stream) => (stream, Ok(())), + }, + StateProj::Reset { future } => { + let (stream, result) = ready!(future.poll(cx)); + (stream, result.map_err(SinkError::from)) + } + StateProj::Empty => unreachable!("invalid state for poll_current_to_stream"), + } + }; + self.as_mut().set_stream(completed.0); + Poll::Ready(completed.1) + } + + fn poll_current_to_stream_after_reset( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll { + let stream = { + let mut project = self.as_mut().project(); + match project.state.as_mut().project() { + StateProj::Stream { .. } => project + .state + .as_mut() + .take_stream() + .expect("stream state should contain stream"), + StateProj::Send { token, future } => { + token.cancel(); + match ready!(future.poll(cx)) { + Either::Left((stream, _)) | Either::Right(stream) => stream, + } + } + StateProj::Flush { token, future } => { + token.cancel(); + match ready!(future.poll(cx)) { + Either::Left((stream, _)) | Either::Right(stream) => stream, + } + } + StateProj::Close { token, future } => { + token.cancel(); + match ready!(future.poll(cx)) { + Either::Left((stream, _)) | Either::Right(stream) => stream, + } + } + StateProj::Reset { .. } => unreachable!("reset is already active"), + StateProj::Empty => unreachable!("invalid state for reset restoration"), + } + }; + Poll::Ready(stream) + } + + fn start_flush(mut self: Pin<&mut Self>) + where + Flush: FnMut(StreamState, CancellationToken) -> FlushFuture, + { + let mut project = self.as_mut().project(); + let stream = project.state.as_mut().take_stream().expect("stream state"); + let token = CancellationToken::new(); + project.state.set(State::Flush { + token: token.clone(), + future: (project.flush)(stream, token), + }); + } + + fn start_close(mut self: Pin<&mut Self>) + where + Close: FnMut(StreamState, CancellationToken) -> CloseFuture, + { + let mut project = self.as_mut().project(); + let stream = project.state.as_mut().take_stream().expect("stream state"); + let token = CancellationToken::new(); + project.state.set(State::Close { + token: token.clone(), + future: (project.close)(stream, token), + }); + } + + fn start_reset(mut self: Pin<&mut Self>, code: VarInt) + where + Reset: FnMut(StreamState, VarInt) -> ResetFuture, + { + let mut project = self.as_mut().project(); + let stream = project.state.as_mut().take_stream().expect("stream state"); + project.state.set(State::Reset { + future: (project.reset)(stream, code), + }); + } + + fn poll_reset_op( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> + where + Reset: FnMut(StreamState, VarInt) -> ResetFuture, + { + let Some(PendingWriteOp::Reset(code)) = self.as_mut().project().pending.front() else { + return Poll::Ready(Ok(())); + }; + + if !self.is_reset() { + let stream = ready!(self.as_mut().poll_current_to_stream_after_reset(cx)); + self.as_mut().set_stream(stream); + self.as_mut().start_reset(code); + } + + let (stream, result) = { + let mut project = self.as_mut().project(); + match project.state.as_mut().project() { + StateProj::Reset { future } => ready!(future.poll(cx)), + _ => unreachable!("reset operation should be active"), + } + }; + self.as_mut().set_stream(stream); + self.as_mut().project().pending.pop_front(); + Poll::Ready(result) + } + + fn poll_pending_until( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + target: Option, + ) -> Poll> + where + Flush: FnMut(StreamState, CancellationToken) -> FlushFuture, + Close: FnMut(StreamState, CancellationToken) -> CloseFuture, + Reset: FnMut(StreamState, VarInt) -> ResetFuture, + { + loop { + let op = { + let project = self.as_mut().project(); + if let Some(target) = target + && !project.pending.contains_kind(target) + { + return Poll::Ready(Ok(())); + } + match project.pending.front() { + Some(op) => op, + None => return Poll::Ready(Ok(())), + } + }; + + match op { + PendingWriteOp::Reset(_) => { + ready!(self.as_mut().poll_reset_op(cx)).map_err(SinkError::from)?; + } + PendingWriteOp::Flush => { + if !self.is_flush() { + ready!(self.as_mut().poll_current_to_stream(cx)?); + self.as_mut().start_flush(); + } + ready!(self.as_mut().poll_current_to_stream(cx)?); + self.as_mut().project().pending.pop_front(); + if target == Some(PendingWriteOp::Flush) { + return Poll::Ready(Ok(())); + } + } + PendingWriteOp::Close => { + if !self.is_close() { + ready!(self.as_mut().poll_current_to_stream(cx)?); + self.as_mut().start_close(); + } + ready!(self.as_mut().poll_current_to_stream(cx)?); + self.as_mut().project().pending.pop_front(); + if target == Some(PendingWriteOp::Close) { + return Poll::Ready(Ok(())); + } + } + } + } + } +} + +impl< + StreamState, + Send, + Flush, + Close, + Reset, + SendFuture, + FlushFuture, + CloseFuture, + ResetFuture, + SinkError, + Item, +> Sink + for Unfold< + StreamState, + Send, + Flush, + Close, + Reset, + SendFuture, + FlushFuture, + CloseFuture, + ResetFuture, + SinkError, + > +where + Send: FnMut(StreamState, CancellationToken, Item) -> SendFuture, + Flush: FnMut(StreamState, CancellationToken) -> FlushFuture, + Close: FnMut(StreamState, CancellationToken) -> CloseFuture, + Reset: FnMut(StreamState, VarInt) -> ResetFuture, + SendFuture: Future), StreamState>>, + FlushFuture: Future), StreamState>>, + CloseFuture: Future), StreamState>>, + ResetFuture: Future)>, + SinkError: From, +{ + type Error = SinkError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + while !self.as_mut().project().pending.is_empty() { + ready!(self.as_mut().poll_pending_until(cx, None)?); + } + ready!(self.as_mut().poll_current_to_stream(cx)?); + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { + let mut project = self.project(); + let stream = project + .state + .as_mut() + .take_stream() + .expect("start_send called without poll_ready being called first"); + let token = CancellationToken::new(); + project.state.set(State::Send { + token: token.clone(), + future: (project.send)(stream, token, item), + }); + Ok(()) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.as_mut().project().pending.enqueue_flush(); + self.poll_pending_until(cx, Some(PendingWriteOp::Flush)) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.as_mut().project().pending.enqueue_close(); + self.poll_pending_until(cx, Some(PendingWriteOp::Close)) + } +} + +// --------------------------------------------------------------------------- +// QUIC control trait forwarding for Unfold +// --------------------------------------------------------------------------- + +impl< + StreamState, + Send, + Flush, + Close, + Reset, + SendFuture, + FlushFuture, + CloseFuture, + ResetFuture, + SinkError, +> stream::GetStreamId + for Unfold< + StreamState, + Send, + Flush, + Close, + Reset, + SendFuture, + FlushFuture, + CloseFuture, + ResetFuture, + SinkError, + > +where + StreamState: stream::GetStreamId + Unpin, +{ + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + let project = self.project(); + match project.state.project() { + StateProj::Stream { stream } => Pin::new(stream).poll_stream_id(cx), + _ => Poll::Pending, + } + } +} + +impl< + StreamState, + Send, + Flush, + Close, + Reset, + SendFuture, + FlushFuture, + CloseFuture, + ResetFuture, + SinkError, +> stream::ResetStream + for Unfold< + StreamState, + Send, + Flush, + Close, + Reset, + SendFuture, + FlushFuture, + CloseFuture, + ResetFuture, + SinkError, + > +where + Reset: FnMut(StreamState, VarInt) -> ResetFuture, + SendFuture: Future), StreamState>>, + FlushFuture: Future), StreamState>>, + CloseFuture: Future), StreamState>>, + ResetFuture: Future)>, + SinkError: From, +{ + fn poll_reset( + mut self: Pin<&mut Self>, + cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.as_mut().project().pending.enqueue_reset(code); + self.poll_reset_op(cx) + } +} + +impl< + StreamState, + Send, + Flush, + Close, + Reset, + SendFuture, + FlushFuture, + CloseFuture, + ResetFuture, + SinkError, +> QuicGetStreamId + for Unfold< + StreamState, + Send, + Flush, + Close, + Reset, + SendFuture, + FlushFuture, + CloseFuture, + ResetFuture, + SinkError, + > +where + StreamState: QuicGetStreamId + Unpin, +{ + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + let project = self.project(); + match project.state.project() { + StateProj::Stream { stream } => Pin::new(stream).poll_stream_id(cx), + _ => Poll::Pending, + } + } +} + +impl< + StreamState, + Send, + Flush, + Close, + Reset, + SendFuture, + FlushFuture, + CloseFuture, + ResetFuture, + SinkError, +> QuicResetStream + for Unfold< + StreamState, + Send, + Flush, + Close, + Reset, + SendFuture, + FlushFuture, + CloseFuture, + ResetFuture, + SinkError, + > +where + Reset: FnMut(StreamState, VarInt) -> ResetFuture, + SendFuture: Future), StreamState>>, + FlushFuture: Future), StreamState>>, + CloseFuture: Future), StreamState>>, + ResetFuture: Future)>, + SinkError: From, +{ + fn poll_reset( + mut self: Pin<&mut Self>, + cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.as_mut().project().pending.enqueue_reset(code); + self.poll_reset_op(cx) + } +} + +// --------------------------------------------------------------------------- +// Unfold constructor +// --------------------------------------------------------------------------- + +pub fn unfold< + StreamState, + Send, + Flush, + Close, + Reset, + SendFuture, + FlushFuture, + CloseFuture, + ResetFuture, + SinkError, + Item, +>( + init: StreamState, + send: Send, + flush: Flush, + close: Close, + reset: Reset, +) -> Unfold< + StreamState, + Send, + Flush, + Close, + Reset, + SendFuture, + FlushFuture, + CloseFuture, + ResetFuture, + SinkError, +> +where + Send: FnMut(StreamState, CancellationToken, Item) -> SendFuture, + SendFuture: Future), StreamState>>, + Flush: FnMut(StreamState, CancellationToken) -> FlushFuture, + FlushFuture: Future), StreamState>>, + Close: FnMut(StreamState, CancellationToken) -> CloseFuture, + CloseFuture: Future), StreamState>>, + Reset: FnMut(StreamState, VarInt) -> ResetFuture, + ResetFuture: Future)>, +{ + Unfold { + send, + flush, + close, + reset, + pending: PendingWriteQueue::default(), + _error: std::marker::PhantomData, + state: State::Stream { stream: init }, + } +} + +// --------------------------------------------------------------------------- +// MessageWriter conversion methods +// --------------------------------------------------------------------------- + +impl MessageWriter { + pub fn as_bytes_sink( + &mut self, + ) -> impl stream::WriteStream + + QuicGetStreamId + + QuicResetStream + + Send + + '_ { + unfold( + self, + async |stream: &mut MessageWriter, token, buf: Bytes| { + tokio::select! { + biased; + _ = token.cancelled() => Either::Right(stream), + result = stream.write_data(buf) => Either::Left((stream, result)), + } + }, + async |stream: &mut MessageWriter, token| { + tokio::select! { + biased; + _ = token.cancelled() => Either::Right(stream), + result = stream.flush() => Either::Left((stream, result)), + } + }, + async |stream: &mut MessageWriter, token| { + tokio::select! { + biased; + _ = token.cancelled() => Either::Right(stream), + result = stream.close() => Either::Left((stream, result)), + } + }, + async |mut stream: &mut MessageWriter, code| { + let result = futures::future::poll_fn(|cx| { + QuicResetStream::poll_reset(Pin::new(stream.deref_mut()), cx, code) + }) + .await; + (stream, result) + }, + ) + } + + pub fn as_writer( + &mut self, + ) -> SinkWriter< + impl stream::WriteStream + + QuicGetStreamId + + QuicResetStream + + Send + + '_, + > { + SinkWriter::new(self.as_bytes_sink()) + } + + pub fn as_box_writer( + &mut self, + ) -> Pin< + Box< + dyn stream::WriteStream + + Send + + '_, + >, + > { + Box::pin(self.as_bytes_sink()) + } + + pub fn into_bytes_sink( + self, + ) -> impl stream::WriteStream + + QuicGetStreamId + + QuicResetStream + + Send { + unfold( + self, + async |mut stream: MessageWriter, token, buf: Bytes| { + tokio::select! { + biased; + _ = token.cancelled() => Either::Right(stream), + result = stream.write_data(buf) => Either::Left((stream, result)), + } + }, + async |mut stream: MessageWriter, token| { + tokio::select! { + biased; + _ = token.cancelled() => Either::Right(stream), + result = stream.flush() => Either::Left((stream, result)), + } + }, + async |mut stream: MessageWriter, token| { + tokio::select! { + biased; + _ = token.cancelled() => Either::Right(stream), + result = stream.close() => Either::Left((stream, result)), + } + }, + async |mut stream: MessageWriter, code| { + let result = futures::future::poll_fn(|cx| { + QuicResetStream::poll_reset(Pin::new(&mut stream), cx, code) + }) + .await; + (stream, result) + }, + ) + } + + pub fn into_writer( + self, + ) -> SinkWriter< + impl stream::WriteStream + + QuicGetStreamId + + QuicResetStream + + Send, + > { + SinkWriter::new(self.into_bytes_sink()) + } + + pub fn into_box_writer(self) -> BoxMessageWriter { + Box::pin(self.into_bytes_sink()) + } +} + +#[cfg(test)] +mod tests { + use std::{ + pin::Pin, + sync::{ + Arc, Mutex, + atomic::{AtomicBool, AtomicUsize, Ordering}, + }, + task::{Context, Poll}, + }; + + use futures::{ + FutureExt, SinkExt, + future::{Either, poll_fn}, + }; + + use super::*; + use crate::quic::{GetStreamId, ResetStream}; + + #[derive(Debug, Clone, PartialEq, Eq)] + enum Event { + Send(Bytes), + Flush, + Close, + Reset(VarInt), + } + + #[derive(Debug)] + struct ControlSink { + stream_id: VarInt, + events: Arc>>, + } + + impl stream::GetStreamId for ControlSink { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl stream::ResetStream for ControlSink { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.events + .lock() + .expect("event log poisoned") + .push(Event::Reset(code)); + Poll::Ready(Ok(())) + } + } + + impl GetStreamId for ControlSink { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl ResetStream for ControlSink { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.events + .lock() + .expect("event log poisoned") + .push(Event::Reset(code)); + Poll::Ready(Ok(())) + } + } + + type ControlResult = Either<(ControlSink, Result<(), MessageStreamError>), ControlSink>; + type ControlReady = futures::future::Ready; + type ResetReady = futures::future::Ready<(ControlSink, Result<(), quic::StreamError>)>; + type ReadyState = State; + + fn control_sink(stream_id: VarInt) -> ControlSink { + ControlSink { + stream_id, + events: Arc::new(Mutex::new(Vec::new())), + } + } + + fn assert_reset_message_error(error: MessageStreamError, code: VarInt) { + assert!(matches!( + error, + MessageStreamError::Quic { + source: quic::StreamError::Reset { code: actual } + } if actual == code + )); + } + + fn ready_control(stream_id: VarInt) -> ControlReady { + futures::future::ready(Either::Left((control_sink(stream_id), Ok(())))) + } + + fn send_ok(stream: ControlSink, _token: CancellationToken, _item: Bytes) -> ControlReady { + futures::future::ready(Either::Left((stream, Ok(())))) + } + + fn operation_ok(stream: ControlSink, _token: CancellationToken) -> ControlReady { + futures::future::ready(Either::Left((stream, Ok(())))) + } + + fn send_pending( + _stream: ControlSink, + _token: CancellationToken, + _item: Bytes, + ) -> futures::future::Pending { + futures::future::pending() + } + + fn operation_pending( + _stream: ControlSink, + _token: CancellationToken, + ) -> futures::future::Pending { + futures::future::pending() + } + + fn send_message_failed( + stream: ControlSink, + _token: CancellationToken, + _item: Bytes, + ) -> ControlReady { + futures::future::ready(Either::Left(( + stream, + Err(MessageStreamError::MessageSendFailed), + ))) + } + + fn send_malformed_outgoing( + stream: ControlSink, + _token: CancellationToken, + _item: Bytes, + ) -> ControlReady { + futures::future::ready(Either::Left(( + stream, + Err(MessageStreamError::MalformedOutgoingMessage), + ))) + } + + fn operation_message_failed(stream: ControlSink, _token: CancellationToken) -> ControlReady { + futures::future::ready(Either::Left(( + stream, + Err(MessageStreamError::MessageSendFailed), + ))) + } + + fn operation_malformed_outgoing( + stream: ControlSink, + _token: CancellationToken, + ) -> ControlReady { + futures::future::ready(Either::Left(( + stream, + Err(MessageStreamError::MalformedOutgoingMessage), + ))) + } + + fn record_send(stream: ControlSink, _token: CancellationToken, item: Bytes) -> ControlReady { + stream + .events + .lock() + .expect("event log poisoned") + .push(Event::Send(item)); + futures::future::ready(Either::Left((stream, Ok(())))) + } + + fn record_flush(stream: ControlSink, _token: CancellationToken) -> ControlReady { + stream + .events + .lock() + .expect("event log poisoned") + .push(Event::Flush); + futures::future::ready(Either::Left((stream, Ok(())))) + } + + fn record_close(stream: ControlSink, _token: CancellationToken) -> ControlReady { + stream + .events + .lock() + .expect("event log poisoned") + .push(Event::Close); + futures::future::ready(Either::Left((stream, Ok(())))) + } + + fn record_reset(stream: ControlSink, code: VarInt) -> ResetReady { + stream + .events + .lock() + .expect("event log poisoned") + .push(Event::Reset(code)); + futures::future::ready((stream, Ok(()))) + } + + fn reset_ok(stream: ControlSink, _code: VarInt) -> ResetReady { + futures::future::ready((stream, Ok(()))) + } + + #[derive(Debug)] + struct GateFuture { + value: Option, + open: Arc, + } + + impl Future for GateFuture { + type Output = ControlResult; + + fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + if self.open.load(Ordering::SeqCst) { + Poll::Ready(Either::Left(( + self.value.take().expect("gate future polled after ready"), + Ok(()), + ))) + } else { + Poll::Pending + } + } + } + + #[test] + fn state_take_stream_only_extracts_stream_state() { + let mut stream = ReadyState::Stream { + stream: control_sink(VarInt::from_u32(3)), + }; + assert_eq!( + Pin::new(&mut stream) + .take_stream() + .expect("stream state should yield inner stream") + .stream_id, + VarInt::from_u32(3) + ); + assert!( + Pin::new(&mut stream).take_stream().is_none(), + "taking the stream should leave the state empty" + ); + + let mut send = ReadyState::Send { + token: CancellationToken::new(), + future: ready_control(VarInt::from_u32(5)), + }; + assert!(Pin::new(&mut send).take_stream().is_none()); + + let mut flush = ReadyState::Flush { + token: CancellationToken::new(), + future: ready_control(VarInt::from_u32(7)), + }; + assert!(Pin::new(&mut flush).take_stream().is_none()); + + let mut close = ReadyState::Close { + token: CancellationToken::new(), + future: ready_control(VarInt::from_u32(9)), + }; + assert!(Pin::new(&mut close).take_stream().is_none()); + } + + #[tokio::test] + async fn unfold_sends_flushes_and_closes_in_order() { + let events = Arc::new(Mutex::new(Vec::new())); + let mut sink = Box::pin(unfold( + ControlSink { + stream_id: VarInt::from_u32(11), + events: events.clone(), + }, + record_send, + record_flush, + record_close, + record_reset, + )); + + poll_fn(|cx| sink.as_mut().poll_ready(cx)) + .await + .expect("sink ready"); + sink.as_mut() + .start_send(Bytes::from_static(b"payload")) + .expect("send accepted"); + poll_fn(|cx| sink.as_mut().poll_flush(cx)) + .await + .expect("sink flushed"); + poll_fn(|cx| sink.as_mut().poll_close(cx)) + .await + .expect("sink closed"); + + assert_eq!( + *events.lock().expect("event log poisoned"), + vec![ + Event::Send(Bytes::from_static(b"payload")), + Event::Flush, + Event::Close, + ] + ); + } + + #[tokio::test] + async fn unfold_flush_and_close_keep_single_in_flight_operation_until_ready() { + let flush_open = Arc::new(AtomicBool::new(false)); + let close_open = Arc::new(AtomicBool::new(false)); + let flush_started = Arc::new(AtomicUsize::new(0)); + let close_started = Arc::new(AtomicUsize::new(0)); + let mut sink = Box::pin(unfold( + ControlSink { + stream_id: VarInt::from_u32(31), + events: Arc::new(Mutex::new(Vec::new())), + }, + send_ok, + { + let flush_open = flush_open.clone(); + let flush_started = flush_started.clone(); + move |stream: ControlSink, _token| { + flush_started.fetch_add(1, Ordering::SeqCst); + GateFuture { + value: Some(stream), + open: flush_open.clone(), + } + } + }, + { + let close_open = close_open.clone(); + let close_started = close_started.clone(); + move |stream: ControlSink, _token| { + close_started.fetch_add(1, Ordering::SeqCst); + GateFuture { + value: Some(stream), + open: close_open.clone(), + } + } + }, + reset_ok, + )); + + poll_fn(|cx| sink.as_mut().poll_ready(cx)) + .await + .expect("sink ready"); + + assert!( + poll_fn(|cx| sink.as_mut().poll_flush(cx)) + .now_or_never() + .is_none() + ); + assert_eq!( + flush_started.load(Ordering::SeqCst), + 1, + "first flush poll should start one flush future" + ); + assert!( + poll_fn(|cx| sink.as_mut().poll_flush(cx)) + .now_or_never() + .is_none() + ); + assert_eq!( + flush_started.load(Ordering::SeqCst), + 1, + "polling an in-flight flush must not start another flush" + ); + + flush_open.store(true, Ordering::SeqCst); + poll_fn(|cx| sink.as_mut().poll_flush(cx)) + .await + .expect("flush completes once gate opens"); + poll_fn(|cx| sink.as_mut().poll_ready(cx)) + .await + .expect("sink ready after completed flush"); + assert_eq!( + flush_started.load(Ordering::SeqCst), + 1, + "poll_ready on stream state must not flush again" + ); + + assert!( + poll_fn(|cx| sink.as_mut().poll_close(cx)) + .now_or_never() + .is_none() + ); + assert_eq!( + close_started.load(Ordering::SeqCst), + 1, + "first close poll should start one close future" + ); + assert!( + poll_fn(|cx| sink.as_mut().poll_close(cx)) + .now_or_never() + .is_none() + ); + assert_eq!( + close_started.load(Ordering::SeqCst), + 1, + "polling an in-flight close must not start another close" + ); + + close_open.store(true, Ordering::SeqCst); + poll_fn(|cx| sink.as_mut().poll_close(cx)) + .await + .expect("close completes once gate opens"); + poll_fn(|cx| sink.as_mut().poll_ready(cx)) + .await + .expect("sink ready after completed close"); + assert_eq!( + close_started.load(Ordering::SeqCst), + 1, + "poll_ready on stream state must not close again" + ); + } + + #[tokio::test] + async fn unfold_flush_and_close_wait_for_in_flight_send_before_starting() { + let send_open = Arc::new(AtomicBool::new(false)); + let send_started = Arc::new(AtomicUsize::new(0)); + let flush_started = Arc::new(AtomicUsize::new(0)); + let mut flush_sink = Box::pin(unfold( + ControlSink { + stream_id: VarInt::from_u32(35), + events: Arc::new(Mutex::new(Vec::new())), + }, + { + let send_open = send_open.clone(); + let send_started = send_started.clone(); + move |stream: ControlSink, _token, _item: Bytes| { + send_started.fetch_add(1, Ordering::SeqCst); + GateFuture { + value: Some(stream), + open: send_open.clone(), + } + } + }, + { + let flush_started = flush_started.clone(); + move |stream: ControlSink, _token| { + flush_started.fetch_add(1, Ordering::SeqCst); + futures::future::ready(Either::Left((stream, Ok::<_, MessageStreamError>(())))) + } + }, + operation_ok, + reset_ok, + )); + + poll_fn(|cx| flush_sink.as_mut().poll_ready(cx)) + .await + .expect("sink initially ready"); + flush_sink + .as_mut() + .start_send(Bytes::from_static(b"payload")) + .expect("send accepted"); + assert!( + poll_fn(|cx| flush_sink.as_mut().poll_flush(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(send_started.load(Ordering::SeqCst), 1); + assert_eq!(flush_started.load(Ordering::SeqCst), 0); + + send_open.store(true, Ordering::SeqCst); + poll_fn(|cx| flush_sink.as_mut().poll_flush(cx)) + .await + .expect("flush starts after send completes"); + assert_eq!(flush_started.load(Ordering::SeqCst), 1); + + let send_open = Arc::new(AtomicBool::new(false)); + let send_started = Arc::new(AtomicUsize::new(0)); + let close_started = Arc::new(AtomicUsize::new(0)); + let mut close_sink = Box::pin(unfold( + ControlSink { + stream_id: VarInt::from_u32(39), + events: Arc::new(Mutex::new(Vec::new())), + }, + { + let send_open = send_open.clone(); + let send_started = send_started.clone(); + move |stream: ControlSink, _token, _item: Bytes| { + send_started.fetch_add(1, Ordering::SeqCst); + GateFuture { + value: Some(stream), + open: send_open.clone(), + } + } + }, + operation_ok, + { + let close_started = close_started.clone(); + move |stream: ControlSink, _token| { + close_started.fetch_add(1, Ordering::SeqCst); + futures::future::ready(Either::Left((stream, Ok::<_, MessageStreamError>(())))) + } + }, + reset_ok, + )); + + poll_fn(|cx| close_sink.as_mut().poll_ready(cx)) + .await + .expect("sink initially ready"); + close_sink + .as_mut() + .start_send(Bytes::from_static(b"payload")) + .expect("send accepted"); + assert!( + poll_fn(|cx| close_sink.as_mut().poll_close(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(send_started.load(Ordering::SeqCst), 1); + assert_eq!(close_started.load(Ordering::SeqCst), 0); + + send_open.store(true, Ordering::SeqCst); + poll_fn(|cx| close_sink.as_mut().poll_close(cx)) + .await + .expect("close starts after send completes"); + assert_eq!(close_started.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn unfold_poll_ready_waits_for_send_then_restores_control_traits() { + let send_open = Arc::new(AtomicBool::new(false)); + let send_started = Arc::new(AtomicUsize::new(0)); + let stream_id = VarInt::from_u32(33); + let reset_code = VarInt::from_u32(34); + let events = Arc::new(Mutex::new(Vec::new())); + let mut sink = Box::pin(unfold( + ControlSink { + stream_id, + events: events.clone(), + }, + { + let send_open = send_open.clone(); + let send_started = send_started.clone(); + move |stream: ControlSink, _token, _item: Bytes| { + send_started.fetch_add(1, Ordering::SeqCst); + GateFuture { + value: Some(stream), + open: send_open.clone(), + } + } + }, + operation_ok, + operation_ok, + record_reset, + )); + + poll_fn(|cx| sink.as_mut().poll_ready(cx)) + .await + .expect("sink initially ready"); + sink.as_mut() + .start_send(Bytes::from_static(b"payload")) + .expect("send accepted"); + assert!( + poll_fn(|cx| sink.as_mut().poll_ready(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(send_started.load(Ordering::SeqCst), 1); + + send_open.store(true, Ordering::SeqCst); + poll_fn(|cx| sink.as_mut().poll_ready(cx)) + .await + .expect("send completes once gate opens"); + assert_eq!(send_started.load(Ordering::SeqCst), 1); + assert_eq!( + poll_fn(|cx| sink.as_mut().poll_stream_id(cx)) + .await + .expect("stream id should be available after send completes"), + stream_id + ); + poll_fn(|cx| sink.as_mut().poll_reset(cx, reset_code)) + .await + .expect("reset should be available after send completes"); + assert_eq!( + *events.lock().expect("event log poisoned"), + vec![Event::Reset(reset_code)] + ); + } + + #[tokio::test] + async fn control_traits_forward_while_value_available() { + let events = Arc::new(Mutex::new(Vec::new())); + let stream_id = VarInt::from_u32(37); + let reset_code = VarInt::from_u32(41); + let mut sink = Box::pin(unfold( + ControlSink { + stream_id, + events: events.clone(), + }, + send_ok, + operation_ok, + operation_ok, + record_reset, + )); + + assert_eq!( + poll_fn(|cx| sink.as_mut().poll_stream_id(cx)) + .await + .expect("stream id"), + stream_id + ); + poll_fn(|cx| sink.as_mut().poll_reset(cx, reset_code)) + .await + .expect("reset forwarded"); + + assert_eq!( + *events.lock().expect("event log poisoned"), + vec![Event::Reset(reset_code)] + ); + } + + #[tokio::test] + async fn control_traits_wait_while_send_future_owns_value() { + let events = Arc::new(Mutex::new(Vec::new())); + let mut sink = Box::pin(unfold( + ControlSink { + stream_id: VarInt::from_u32(37), + events, + }, + send_pending, + operation_ok, + operation_ok, + reset_ok, + )); + + poll_fn(|cx| sink.as_mut().poll_ready(cx)) + .await + .expect("sink ready"); + sink.as_mut() + .start_send(Bytes::from_static(b"payload")) + .expect("send accepted"); + + assert!( + poll_fn(|cx| sink.as_mut().poll_ready(cx)) + .now_or_never() + .is_none() + ); + assert!( + poll_fn(|cx| sink.as_mut().poll_stream_id(cx)) + .now_or_never() + .is_none() + ); + assert!( + poll_fn(|cx| sink.as_mut().poll_reset(cx, VarInt::from_u32(41))) + .now_or_never() + .is_none() + ); + } + + #[tokio::test] + async fn reset_uses_reset_closure_after_interrupting_pending_send() { + let events = Arc::new(Mutex::new(Vec::new())); + let reset_code = VarInt::from_u32(41); + let mut sink = Box::pin(unfold( + ControlSink { + stream_id: VarInt::from_u32(37), + events: events.clone(), + }, + |stream: ControlSink, token, bytes| async move { + stream + .events + .lock() + .expect("event log poisoned") + .push(Event::Send(bytes)); + token.cancelled().await; + Either::<(ControlSink, Result<(), MessageStreamError>), ControlSink>::Right(stream) + }, + |stream, _token| async move { Either::Left((stream, Ok::<(), MessageStreamError>(()))) }, + |stream, _token| async move { Either::Left((stream, Ok::<(), MessageStreamError>(()))) }, + |stream: ControlSink, code| async move { + stream + .events + .lock() + .expect("event log poisoned") + .push(Event::Reset(code)); + (stream, Ok::<(), quic::StreamError>(())) + }, + )); + + poll_fn(|cx| sink.as_mut().poll_ready(cx)) + .await + .expect("sink initially ready"); + sink.as_mut() + .start_send(Bytes::from_static(b"payload")) + .expect("send accepted"); + + poll_fn(|cx| sink.as_mut().poll_reset(cx, reset_code)) + .await + .expect("reset should complete"); + assert_eq!( + *events.lock().expect("event log poisoned"), + vec![ + Event::Send(Bytes::from_static(b"payload")), + Event::Reset(reset_code), + ] + ); + } + + #[test] + fn start_send_without_poll_ready_panics_after_value_is_taken() { + let mut sink = Box::pin(unfold( + ControlSink { + stream_id: VarInt::from_u32(1), + events: Arc::new(Mutex::new(Vec::new())), + }, + send_ok, + operation_ok, + operation_ok, + reset_ok, + )); + poll_fn(|cx| sink.as_mut().poll_ready(cx)) + .now_or_never() + .expect("poll_ready should complete immediately") + .expect("sink ready"); + sink.as_mut() + .start_send(Bytes::from_static(b"first")) + .expect("first start_send should succeed"); + + let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + sink.as_mut() + .start_send(Bytes::from_static(b"second")) + .expect("start_send panics before returning"); + })); + + assert!(panic.is_err()); + } + + #[tokio::test] + async fn control_traits_wait_while_flush_or_close_future_owns_value() { + let mut flush_sink = Box::pin(unfold( + ControlSink { + stream_id: VarInt::from_u32(55), + events: Arc::new(Mutex::new(Vec::new())), + }, + send_ok, + operation_pending, + operation_ok, + reset_ok, + )); + poll_fn(|cx| flush_sink.as_mut().poll_ready(cx)) + .await + .expect("flush sink ready"); + flush_sink + .as_mut() + .start_send(Bytes::from_static(b"payload")) + .expect("send accepted"); + assert!( + poll_fn(|cx| flush_sink.as_mut().poll_flush(cx)) + .now_or_never() + .is_none() + ); + assert!( + poll_fn(|cx| flush_sink.as_mut().poll_stream_id(cx)) + .now_or_never() + .is_none() + ); + assert!( + poll_fn(|cx| flush_sink.as_mut().poll_reset(cx, VarInt::from_u32(56))) + .now_or_never() + .is_none() + ); + + let mut close_sink = Box::pin(unfold( + ControlSink { + stream_id: VarInt::from_u32(57), + events: Arc::new(Mutex::new(Vec::new())), + }, + send_ok, + operation_ok, + operation_pending, + reset_ok, + )); + assert!( + poll_fn(|cx| close_sink.as_mut().poll_close(cx)) + .now_or_never() + .is_none() + ); + assert!( + poll_fn(|cx| close_sink.as_mut().poll_stream_id(cx)) + .now_or_never() + .is_none() + ); + assert!( + poll_fn(|cx| close_sink.as_mut().poll_reset(cx, VarInt::from_u32(58))) + .now_or_never() + .is_none() + ); + } + + #[tokio::test] + async fn unfold_flush_and_close_propagate_in_flight_send_errors() { + let mut flush_sink = Box::pin(unfold( + ControlSink { + stream_id: VarInt::from_u32(63), + events: Arc::new(Mutex::new(Vec::new())), + }, + send_message_failed, + operation_malformed_outgoing, + operation_ok, + reset_ok, + )); + poll_fn(|cx| flush_sink.as_mut().poll_ready(cx)) + .await + .expect("flush sink ready"); + flush_sink + .as_mut() + .start_send(Bytes::from_static(b"payload")) + .expect("send accepted"); + assert!(matches!( + poll_fn(|cx| flush_sink.as_mut().poll_flush(cx)).await, + Err(MessageStreamError::MessageSendFailed) + )); + + let mut close_sink = Box::pin(unfold( + ControlSink { + stream_id: VarInt::from_u32(64), + events: Arc::new(Mutex::new(Vec::new())), + }, + send_malformed_outgoing, + operation_ok, + operation_message_failed, + reset_ok, + )); + poll_fn(|cx| close_sink.as_mut().poll_ready(cx)) + .await + .expect("close sink ready"); + close_sink + .as_mut() + .start_send(Bytes::from_static(b"payload")) + .expect("send accepted"); + assert!(matches!( + poll_fn(|cx| close_sink.as_mut().poll_close(cx)).await, + Err(MessageStreamError::MalformedOutgoingMessage) + )); + } + + #[tokio::test] + async fn unfold_propagates_flush_and_close_errors() { + let mut flush_sink = Box::pin(unfold( + ControlSink { + stream_id: VarInt::from_u32(61), + events: Arc::new(Mutex::new(Vec::new())), + }, + send_ok, + operation_message_failed, + operation_ok, + reset_ok, + )); + poll_fn(|cx| flush_sink.as_mut().poll_ready(cx)) + .await + .expect("flush sink ready"); + flush_sink + .as_mut() + .start_send(Bytes::from_static(b"payload")) + .expect("send accepted"); + assert!(matches!( + poll_fn(|cx| flush_sink.as_mut().poll_flush(cx)).await, + Err(MessageStreamError::MessageSendFailed) + )); + + let mut close_sink = Box::pin(unfold( + ControlSink { + stream_id: VarInt::from_u32(62), + events: Arc::new(Mutex::new(Vec::new())), + }, + send_ok, + operation_ok, + operation_malformed_outgoing, + reset_ok, + )); + assert!(matches!( + poll_fn(|cx| close_sink.as_mut().poll_close(cx)).await, + Err(MessageStreamError::MalformedOutgoingMessage) + )); + } + + #[tokio::test] + async fn write_stream_adapter_builders_fast_fail_after_reset() { + let stream_id = VarInt::from_u32(71); + let reset_code = VarInt::from_u32(72); + + let mut bytes_stream = crate::dhttp::message::test::write_stream_for_test(stream_id); + { + let mut sink = Box::pin(bytes_stream.as_bytes_sink()); + assert_eq!( + poll_fn(|cx| sink.as_mut().poll_stream_id(cx)) + .await + .expect("as_bytes_sink stream id"), + stream_id + ); + poll_fn(|cx| sink.as_mut().poll_reset(cx, reset_code)) + .await + .expect("as_bytes_sink reset"); + let error = sink + .as_mut() + .send(Bytes::from_static(b"payload")) + .await + .expect_err("as_bytes_sink send after reset should fail"); + assert_reset_message_error(error, reset_code); + } + + let mut owned_sink = Box::pin( + crate::dhttp::message::test::write_stream_for_test(stream_id).into_bytes_sink(), + ); + assert_eq!( + poll_fn(|cx| owned_sink.as_mut().poll_stream_id(cx)) + .await + .expect("into_bytes_sink stream id"), + stream_id + ); + poll_fn(|cx| owned_sink.as_mut().poll_reset(cx, reset_code)) + .await + .expect("into_bytes_sink reset"); + let error = owned_sink + .as_mut() + .send(Bytes::from_static(b"payload")) + .await + .expect_err("into_bytes_sink send after reset should fail"); + assert_reset_message_error(error, reset_code); + + let mut borrowed_stream = crate::dhttp::message::test::write_stream_for_test(stream_id); + { + let mut writer = Box::pin(borrowed_stream.as_writer()); + assert_eq!( + poll_fn(|cx| writer.as_mut().poll_stream_id(cx)) + .await + .expect("as_writer stream id"), + stream_id + ); + poll_fn(|cx| writer.as_mut().poll_reset(cx, reset_code)) + .await + .expect("as_writer reset"); + let error = writer + .as_mut() + .send(Bytes::from_static(b"payload")) + .await + .expect_err("as_writer send after reset should fail"); + assert_reset_message_error(error, reset_code); + } + + let mut borrowed_box_stream = crate::dhttp::message::test::write_stream_for_test(stream_id); + { + let mut writer = borrowed_box_stream.as_box_writer(); + assert_eq!( + poll_fn(|cx| stream::GetStreamId::poll_stream_id(writer.as_mut(), cx)) + .await + .expect("as_box_writer stream id"), + stream_id + ); + poll_fn(|cx| stream::ResetStream::poll_reset(writer.as_mut(), cx, reset_code)) + .await + .expect("as_box_writer reset"); + let error = writer + .as_mut() + .send(Bytes::from_static(b"payload")) + .await + .expect_err("as_box_writer send after reset should fail"); + assert_reset_message_error(error, reset_code); + } + + let mut writer = + Box::pin(crate::dhttp::message::test::write_stream_for_test(stream_id).into_writer()); + assert_eq!( + poll_fn(|cx| writer.as_mut().poll_stream_id(cx)) + .await + .expect("into_writer stream id"), + stream_id + ); + poll_fn(|cx| writer.as_mut().poll_reset(cx, reset_code)) + .await + .expect("into_writer reset"); + let error = writer + .as_mut() + .send(Bytes::from_static(b"payload")) + .await + .expect_err("into_writer send after reset should fail"); + assert_reset_message_error(error, reset_code); + + let mut boxed_writer = + crate::dhttp::message::test::write_stream_for_test(stream_id).into_box_writer(); + assert_eq!( + poll_fn(|cx| stream::GetStreamId::poll_stream_id(boxed_writer.as_mut(), cx)) + .await + .expect("into_box_writer stream id"), + stream_id + ); + poll_fn(|cx| stream::ResetStream::poll_reset(boxed_writer.as_mut(), cx, reset_code)) + .await + .expect("into_box_writer reset"); + let error = boxed_writer + .as_mut() + .send(Bytes::from_static(b"payload")) + .await + .expect_err("into_box_writer send after reset should fail"); + assert_reset_message_error(error, reset_code); + + let mut from_writer = Box::pin(BoxMessageWriter::from( + crate::dhttp::message::test::write_stream_for_test(stream_id), + )); + assert_eq!( + poll_fn(|cx| from_writer.as_mut().poll_stream_id(cx)) + .await + .expect("from stream id"), + stream_id + ); + poll_fn(|cx| from_writer.as_mut().poll_reset(cx, reset_code)) + .await + .expect("from reset"); + let error = from_writer + .as_mut() + .send(Bytes::from_static(b"payload")) + .await + .expect_err("from send after reset should fail"); + assert_reset_message_error(error, reset_code); + } + + #[tokio::test] + async fn write_stream_writer_adapters_fast_fail_buffered_send_after_reset() { + let stream_id = VarInt::from_u32(81); + let reset_code = VarInt::from_u32(82); + + let mut borrowed_stream = crate::dhttp::message::test::write_stream_for_test(stream_id); + { + let mut writer = Box::pin(borrowed_stream.as_writer()); + poll_fn(|cx| writer.as_mut().poll_ready(cx)) + .await + .expect("as_writer initially ready"); + writer + .as_mut() + .start_send(Bytes::from_static(b"buffered")) + .expect("as_writer buffers send"); + assert_eq!( + poll_fn(|cx| writer.as_mut().poll_stream_id(cx)) + .await + .expect("as_writer stream id with buffered send"), + stream_id + ); + poll_fn(|cx| writer.as_mut().poll_reset(cx, reset_code)) + .await + .expect("as_writer reset with buffered send"); + let error = writer + .as_mut() + .flush() + .await + .expect_err("as_writer flush after reset should fail"); + assert_reset_message_error(error, reset_code); + let error = poll_fn(|cx| writer.as_mut().poll_stream_id(cx)) + .await + .expect_err("as_writer stream id after reset should fail"); + assert!(matches!(error, quic::StreamError::Reset { code } if code == reset_code)); + let error = writer + .as_mut() + .close() + .await + .expect_err("as_writer close after reset should fail"); + assert_reset_message_error(error, reset_code); + } + + let mut borrowed_box_stream = crate::dhttp::message::test::write_stream_for_test(stream_id); + { + let mut writer = borrowed_box_stream.as_box_writer(); + poll_fn(|cx| writer.as_mut().poll_ready(cx)) + .await + .expect("as_box_writer initially ready"); + writer + .as_mut() + .start_send(Bytes::from_static(b"buffered")) + .expect("as_box_writer buffers send"); + poll_fn(|cx| stream::ResetStream::poll_reset(writer.as_mut(), cx, reset_code)) + .await + .expect("as_box_writer reset with buffered send"); + let error = writer + .as_mut() + .flush() + .await + .expect_err("as_box_writer flush after reset should fail"); + assert_reset_message_error(error, reset_code); + let error = writer + .as_mut() + .close() + .await + .expect_err("as_box_writer close after reset should fail"); + assert_reset_message_error(error, reset_code); + } + + let mut writer = + Box::pin(crate::dhttp::message::test::write_stream_for_test(stream_id).into_writer()); + poll_fn(|cx| writer.as_mut().poll_ready(cx)) + .await + .expect("into_writer initially ready"); + writer + .as_mut() + .start_send(Bytes::from_static(b"buffered")) + .expect("into_writer buffers send"); + assert_eq!( + poll_fn(|cx| writer.as_mut().poll_stream_id(cx)) + .await + .expect("into_writer stream id with buffered send"), + stream_id + ); + poll_fn(|cx| writer.as_mut().poll_reset(cx, reset_code)) + .await + .expect("into_writer reset with buffered send"); + let error = writer + .as_mut() + .flush() + .await + .expect_err("into_writer flush after reset should fail"); + assert_reset_message_error(error, reset_code); + let error = writer + .as_mut() + .close() + .await + .expect_err("into_writer close after reset should fail"); + assert_reset_message_error(error, reset_code); + + let mut boxed_writer = + crate::dhttp::message::test::write_stream_for_test(stream_id).into_box_writer(); + poll_fn(|cx| boxed_writer.as_mut().poll_ready(cx)) + .await + .expect("into_box_writer initially ready"); + boxed_writer + .as_mut() + .start_send(Bytes::from_static(b"buffered")) + .expect("into_box_writer buffers send"); + poll_fn(|cx| stream::ResetStream::poll_reset(boxed_writer.as_mut(), cx, reset_code)) + .await + .expect("into_box_writer reset with buffered send"); + let error = boxed_writer + .as_mut() + .flush() + .await + .expect_err("into_box_writer flush after reset should fail"); + assert_reset_message_error(error, reset_code); + let error = boxed_writer + .as_mut() + .close() + .await + .expect_err("into_box_writer close after reset should fail"); + assert_reset_message_error(error, reset_code); + } +} diff --git a/src/dhttp/protocol.rs b/src/dhttp/protocol.rs index 0ccedcb..a586667 100644 --- a/src/dhttp/protocol.rs +++ b/src/dhttp/protocol.rs @@ -24,13 +24,14 @@ use tracing::Instrument; use crate::{ buflist::BufList, codec::{ - DecodeExt, EncodeExt, ErasedPeekableBiStream, ErasedPeekableUniStream, Feed, SinkWriter, + BoxPeekableStreamReader, BoxStreamWriter, DecodeExt, EncodeExt, Feed, SinkWriter, StreamReader, }, connection::{ConnectionGoaway, ConnectionState, LifecycleExt, StreamError}, dhttp::{ frame::{Frame, stream::FrameStream}, goaway::Goaway, + message::guard, settings::Settings, stream::UnidirectionalStream, }, @@ -38,9 +39,8 @@ use crate::{ Code, H3CriticalStreamClosed, H3FrameUnexpected, H3IdError, H3MissingSettings, H3StreamCreationError, }, - message::stream::guard, protocol::{ProductProtocol, Protocol, Protocols, StreamVerdict}, - quic::{self, CancelStreamExt, ConnectionError, GetStreamIdExt, StopStreamExt}, + quic::{self, ConnectionError, GetStreamIdExt, ResetStreamExt, StopStreamExt}, util::{ring_channel::RingChannel, set_once::SetOnce, watch::Watch}, varint::VarInt, }; @@ -78,15 +78,20 @@ mod tests { }; use bytes::Bytes; - use futures::{Sink, Stream}; + use futures::{Sink, SinkExt, Stream, StreamExt}; + use tokio::time::{Duration, timeout}; use super::*; use crate::{ - codec::{BoxReadStream, BoxWriteStream, SinkWriter, StreamReader}, + codec::{EncodeExt, PeekableStreamReader, SinkWriter, StreamReader}, connection::{ConnectionState, tests::MockConnection}, - dhttp::settings::{EnableConnectProtocol, Settings}, + dhttp::settings::Settings, + extended_connect::settings::EnableConnectProtocol, protocol::Protocols, - quic::{self, GetStreamIdExt}, + quic::{ + self, BoxQuicStreamReader, BoxQuicStreamWriter, GetStreamId, GetStreamIdExt, + ResetStream, + }, }; #[derive(Debug)] @@ -135,8 +140,8 @@ mod tests { } } - impl quic::CancelStream for TestWriteStream { - fn poll_cancel( + impl quic::ResetStream for TestWriteStream { + fn poll_reset( self: Pin<&mut Self>, _cx: &mut Context, _code: VarInt, @@ -174,19 +179,68 @@ mod tests { } } - fn test_erased_streams(stream_id: u32) -> (GuardedStreamReader, GuardedStreamWriter) { + fn test_erased_streams( + stream_id: u32, + ) -> ( + StreamReader, + SinkWriter, + ) { let stream_id = VarInt::from_u32(stream_id); let reader = - StreamReader::new(guard::GuardedQuicReader::new( - Box::pin(TestReadStream { stream_id }) as BoxReadStream, + StreamReader::new(guard::GuardQuicReader::new( + Box::pin(TestReadStream { stream_id }) as BoxQuicStreamReader, )); let writer = - SinkWriter::new(guard::GuardedQuicWriter::new( - Box::pin(TestWriteStream { stream_id }) as BoxWriteStream, + SinkWriter::new(guard::GuardQuicWriter::new( + Box::pin(TestWriteStream { stream_id }) as BoxQuicStreamWriter, )); (reader, writer) } + async fn test_peekable_uni_stream_with_bytes(bytes: &[u8]) -> BoxPeekableStreamReader { + let (reader, mut writer) = quic::test::mock_stream_pair(VarInt::from_u32(2)); + writer + .send(Bytes::copy_from_slice(bytes)) + .await + .expect("write test uni bytes"); + writer.close().await.expect("close test uni stream"); + PeekableStreamReader::new(StreamReader::new(Box::pin(reader) as BoxQuicStreamReader)) + } + + fn test_empty_peekable_uni_stream() -> BoxPeekableStreamReader { + PeekableStreamReader::new(StreamReader::new(Box::pin(TestReadStream { + stream_id: VarInt::from_u32(2), + }) as BoxQuicStreamReader)) + } + + async fn test_peekable_bi_stream_with_bytes( + stream_id: u32, + bytes: &[u8], + ) -> (BoxPeekableStreamReader, BoxStreamWriter) { + let (reader, mut write_side) = quic::test::mock_stream_pair(VarInt::from_u32(stream_id)); + write_side + .send(Bytes::copy_from_slice(bytes)) + .await + .expect("write test bidi bytes"); + write_side.close().await.expect("close test bidi read side"); + + let (_read_side, writer) = quic::test::mock_stream_pair(VarInt::from_u32(stream_id)); + ( + PeekableStreamReader::new(StreamReader::new(Box::pin(reader) as BoxQuicStreamReader)), + SinkWriter::new(Box::pin(writer) as BoxQuicStreamWriter), + ) + } + + fn test_empty_peekable_bi_stream(stream_id: u32) -> (BoxPeekableStreamReader, BoxStreamWriter) { + let stream_id = VarInt::from_u32(stream_id); + ( + PeekableStreamReader::new(StreamReader::new( + Box::pin(TestReadStream { stream_id }) as BoxQuicStreamReader + )), + SinkWriter::new(Box::pin(TestWriteStream { stream_id }) as BoxQuicStreamWriter), + ) + } + fn test_connection_state() -> ConnectionState { let quic = Arc::new(MockConnection::new()); let erased_connection: Arc = quic.clone(); @@ -201,6 +255,115 @@ mod tests { h.finish() } + fn assert_stream_connection_code(error: StreamError, expected: Code) { + assert!(matches!( + error, + StreamError::Connection { + source: crate::connection::ConnectionError::H3 { source }, + } if source.code() == expected + )); + } + + fn test_transport_error(reason: &'static str) -> quic::ConnectionError { + quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(0x01), + frame_type: VarInt::from_u32(0x00), + reason: reason.into(), + }, + } + } + + fn assert_transport_error_reason(error: &quic::ConnectionError, expected: &str) { + let quic::ConnectionError::Transport { source } = error else { + panic!("expected transport connection error"); + }; + assert_eq!(source.reason.as_ref(), expected); + } + + async fn stream_reader_from_frames( + frames: impl IntoIterator>, + ) -> StreamReader { + let (reader, writer) = quic::test::mock_stream_pair(VarInt::from_u32(2)); + let mut writer = SinkWriter::new(writer); + for frame in frames { + writer + .encode_one(frame) + .await + .expect("encode test control frame"); + } + writer.close().await.expect("close test control stream"); + StreamReader::new(Box::pin(reader) as BoxQuicStreamReader) + } + + async fn settings_frame(settings: &Settings) -> Frame { + BufList::new() + .encode(settings) + .await + .expect("settings encoding into buflist is infallible") + } + + async fn goaway_frame(stream_id: u32) -> Frame { + BufList::new() + .encode(Goaway::new(VarInt::from_u32(stream_id))) + .await + .expect("goaway encoding into buflist is infallible") + } + + #[test] + fn dhttp_factory_display_names_protocol() { + let factory = DHttpProtocolFactory::default(); + + assert_eq!(factory.to_string(), "DHTTP/3"); + } + + #[tokio::test] + async fn test_stream_helpers_cover_writer_control_and_sink_traits() { + let stream_id = VarInt::from_u32(27); + let mut writer = TestWriteStream { stream_id }; + + assert_eq!( + futures::future::poll_fn(|cx| Pin::new(&mut writer).poll_stream_id(cx)) + .await + .expect("stream id"), + stream_id + ); + futures::future::poll_fn(|cx| Pin::new(&mut writer).poll_reset(cx, VarInt::from_u32(1))) + .await + .expect("reset succeeds"); + writer + .send(Bytes::from_static(b"payload")) + .await + .expect("send succeeds"); + writer.flush().await.expect("flush succeeds"); + writer.close().await.expect("close succeeds"); + } + + #[tokio::test] + async fn dhttp_protocol_debug_capacity_and_trait_routes_are_stable() { + let state = test_connection_state(); + let debug = format!("{:?}", state.dhttp()); + assert!(debug.contains("DHttpLayer")); + assert!(debug.contains("control_stream")); + assert_eq!(state.dhttp().max_unresolved_request_streams().await, 32); + + let uni = test_peekable_uni_stream_with_bytes(&[0x02]).await; + assert!(matches!( + Protocol::accept_uni(state.dhttp(), uni) + .await + .expect("trait uni accept"), + StreamVerdict::Passed(_) + )); + + let bi = test_peekable_bi_stream_with_bytes(28, &[0x41]).await; + assert!(matches!( + Protocol::accept_bi(state.dhttp(), bi) + .await + .expect("trait bidi accept"), + StreamVerdict::Passed(_) + )); + } + #[test] fn dhttp_factory_same_settings_equal_hash() { let s1 = Arc::new(Settings::default()); @@ -248,6 +411,35 @@ mod tests { assert_ne!(f1, f2); } + #[test] + fn begin_local_goaway_defaults_to_zero_without_received_streams() { + let state = DHttpState::new(Arc::new(Settings::default())); + + let goaway = state.begin_local_goaway(); + + assert_eq!(goaway.stream_id(), VarInt::from_u32(0)); + assert_eq!(state.local_goaway.peek(), Some(goaway)); + } + + #[test] + fn begin_local_goaway_uses_max_received_stream_id() { + let state = DHttpState::new(Arc::new(Settings::default())); + state + .register_accepted_stream(VarInt::from_u32(3)) + .expect("first accepted stream"); + state + .register_accepted_stream(VarInt::from_u32(11)) + .expect("higher accepted stream"); + state + .register_accepted_stream(VarInt::from_u32(7)) + .expect("lower accepted stream"); + + let goaway = state.begin_local_goaway(); + + assert_eq!(goaway.stream_id(), VarInt::from_u32(11)); + assert_eq!(state.local_goaway.peek(), Some(goaway)); + } + #[test] fn initialized_stream_updates_initialized_only() { let state = DHttpState::new(Arc::new(Settings::default())); @@ -261,6 +453,23 @@ mod tests { assert_eq!(state.max_received_stream_id.peek(), None); } + #[test] + fn initialized_stream_tracks_maximum_stream_id() { + let state = DHttpState::new(Arc::new(Settings::default())); + + state + .register_initialized_stream(VarInt::from_u32(15)) + .expect("first initialized stream"); + state + .register_initialized_stream(VarInt::from_u32(5)) + .expect("lower initialized stream"); + + assert_eq!( + state.max_initialized_stream_id.peek(), + Some(VarInt::from_u32(15)) + ); + } + #[test] fn accepted_stream_updates_received_only() { let state = DHttpState::new(Arc::new(Settings::default())); @@ -274,6 +483,35 @@ mod tests { assert_eq!(state.max_initialized_stream_id.peek(), None); } + #[test] + fn accepted_stream_tracks_maximum_stream_id() { + let state = DHttpState::new(Arc::new(Settings::default())); + + state + .register_accepted_stream(VarInt::from_u32(12)) + .expect("first accepted stream"); + state + .register_accepted_stream(VarInt::from_u32(4)) + .expect("lower accepted stream"); + + assert_eq!( + state.max_received_stream_id.peek(), + Some(VarInt::from_u32(12)) + ); + } + + #[test] + fn accepted_stream_rejected_at_local_goaway_boundary() { + let state = DHttpState::new(Arc::new(Settings::default())); + state.local_goaway.set(Goaway::new(VarInt::from_u32(10))); + + let error = state + .register_accepted_stream(VarInt::from_u32(10)) + .expect_err("stream at local goaway boundary must be rejected"); + + assert_eq!(error, ConnectionGoaway::Local); + } + #[test] fn initialized_stream_rejected_after_peer_goaway_latched() { let state = DHttpState::new(Arc::new(Settings::default())); @@ -307,6 +545,64 @@ mod tests { )); } + #[tokio::test] + async fn control_stream_reads_settings_goaway_unknown_and_reports_close() { + let state = DHttpState::new(Arc::new(Settings::default())); + let mut settings = Settings::default(); + settings.set(EnableConnectProtocol::setting(true)); + let settings = Arc::new(settings); + let unknown = Frame::new(VarInt::from_u32(0x2f), BufList::new()) + .expect("unknown frame type is valid"); + let stream = stream_reader_from_frames([ + settings_frame(&settings).await, + unknown, + goaway_frame(9).await, + ]) + .await; + + let error = match state.handle_control_stream(stream).await { + Ok(never) => match never {}, + Err(error) => error, + }; + + assert_stream_connection_code(error, Code::H3_CLOSED_CRITICAL_STREAM); + assert_eq!(state.peer_settings.peek(), Some(settings)); + assert_eq!( + state.peer_goaway.peek(), + Some(Goaway::new(VarInt::from_u32(9))) + ); + } + + #[tokio::test] + async fn control_stream_rejects_non_settings_first_frame_and_duplicate_settings() { + let state = DHttpState::new(Arc::new(Settings::default())); + let error = match state + .handle_control_stream(stream_reader_from_frames([goaway_frame(1).await]).await) + .await + { + Ok(never) => match never {}, + Err(error) => error, + }; + assert_stream_connection_code(error, Code::H3_MISSING_SETTINGS); + + let state = DHttpState::new(Arc::new(Settings::default())); + let settings = Settings::default(); + let error = match state + .handle_control_stream( + stream_reader_from_frames([ + settings_frame(&settings).await, + settings_frame(&settings).await, + ]) + .await, + ) + .await + { + Ok(never) => match never {}, + Err(error) => error, + }; + assert_stream_connection_code(error, Code::H3_FRAME_UNEXPECTED); + } + #[tokio::test] async fn peer_goaway_covers_resolves_immediately_when_already_covered() { let state = DHttpState::new(Arc::new(Settings::default())); @@ -340,7 +636,450 @@ mod tests { .apply_peer_goaway(covering) .expect("covering goaway should be accepted"); - assert_eq!(waiter.await.expect("join should succeed"), covering); + let observed = timeout(Duration::from_millis(100), waiter) + .await + .expect("peer goaway waiter should resolve") + .expect("join should succeed"); + assert_eq!(observed, covering); + } + + #[tokio::test] + async fn peer_goaway_covers_ignores_non_covering_existing_value() { + let state = Arc::new(DHttpState::new(Arc::new(Settings::default()))); + state + .apply_peer_goaway(Goaway::new(VarInt::from_u32(20))) + .expect("non-covering peer goaway should be accepted"); + + let waiter_state = state.clone(); + let waiter = + tokio::spawn( + async move { waiter_state.peer_goaway_covers(VarInt::from_u32(10)).await }, + ); + tokio::task::yield_now().await; + assert!(!waiter.is_finished()); + + let covering = Goaway::new(VarInt::from_u32(10)); + state + .apply_peer_goaway(covering) + .expect("covering peer goaway should be accepted"); + + let observed = timeout(Duration::from_millis(100), waiter) + .await + .expect("peer goaway waiter should resolve") + .expect("join should succeed"); + assert_eq!(observed, covering); + } + + #[tokio::test] + async fn accept_uni_passes_stream_when_type_cannot_be_decoded() { + let state = test_connection_state(); + let stream = test_empty_peekable_uni_stream(); + + assert!(matches!( + state.dhttp().accept_uni(stream).await.expect("uni verdict"), + StreamVerdict::Passed(_) + )); + } + + #[tokio::test] + async fn accept_uni_accepts_first_control_stream_and_rejects_duplicate() { + let state = test_connection_state(); + + let first = test_peekable_uni_stream_with_bytes(&[0x00]).await; + assert!(matches!( + state + .dhttp() + .accept_uni(first) + .await + .expect("first control"), + StreamVerdict::Accepted + )); + + let duplicate = test_peekable_uni_stream_with_bytes(&[0x00]).await; + let error = match state.dhttp().accept_uni(duplicate).await { + Ok(_) => panic!("duplicate control stream must fail"), + Err(error) => error, + }; + + assert_stream_connection_code(error, Code::H3_STREAM_CREATION_ERROR); + } + + #[tokio::test] + async fn accept_uni_rejects_push_stream_without_max_push_id() { + let state = test_connection_state(); + let stream = test_peekable_uni_stream_with_bytes(&[0x01]).await; + + let error = match state.dhttp().accept_uni(stream).await { + Ok(_) => panic!("push stream should exceed push id limit"), + Err(error) => error, + }; + + assert_stream_connection_code(error, Code::H3_ID_ERROR); + } + + #[tokio::test] + async fn accept_uni_accepts_reserved_stream_type() { + let state = test_connection_state(); + let stream = test_peekable_uni_stream_with_bytes(&[0x21]).await; + + assert!(matches!( + state + .dhttp() + .accept_uni(stream) + .await + .expect("reserved stream verdict"), + StreamVerdict::Accepted + )); + } + + #[tokio::test] + async fn accept_uni_passes_unknown_stream_type() { + let state = test_connection_state(); + let stream = test_peekable_uni_stream_with_bytes(&[0x02]).await; + + assert!(matches!( + state + .dhttp() + .accept_uni(stream) + .await + .expect("unknown stream verdict"), + StreamVerdict::Passed(_) + )); + } + + #[tokio::test] + async fn accept_bi_passes_stream_when_frame_type_cannot_be_decoded() { + let state = test_connection_state(); + let stream = test_empty_peekable_bi_stream(0); + + assert!(matches!( + state + .dhttp() + .accept_bi(stream) + .await + .expect("empty bidi verdict"), + StreamVerdict::Passed(_) + )); + } + + #[tokio::test] + async fn accept_bi_routes_known_http3_frame_type() { + let state = test_connection_state(); + let stream = test_peekable_bi_stream_with_bytes(12, &[0x01]).await; + + assert!(matches!( + state + .dhttp() + .accept_bi(stream) + .await + .expect("headers frame verdict"), + StreamVerdict::Accepted + )); + + let (mut reader, _writer) = state.dhttp().unresolved_request_streams.receive().await; + assert_eq!( + reader.stream_id().await.expect("routed stream id"), + VarInt::from_u32(12) + ); + } + + #[tokio::test] + async fn accept_bi_routes_reserved_http3_frame_type() { + let state = test_connection_state(); + let stream = test_peekable_bi_stream_with_bytes(16, &[0x21]).await; + + assert!(matches!( + state + .dhttp() + .accept_bi(stream) + .await + .expect("reserved frame verdict"), + StreamVerdict::Accepted + )); + } + + #[tokio::test] + async fn accept_bi_passes_unknown_frame_type() { + let state = test_connection_state(); + let stream = test_peekable_bi_stream_with_bytes(20, &[0x41]).await; + + assert!(matches!( + state + .dhttp() + .accept_bi(stream) + .await + .expect("unknown frame verdict"), + StreamVerdict::Passed(_) + )); + } + + #[tokio::test] + async fn accept_bi_evicts_oldest_unresolved_stream_when_ring_is_full() { + let state = test_connection_state(); + + for stream_id in 0..=32 { + let stream = test_peekable_bi_stream_with_bytes(stream_id, &[0x01]).await; + assert!(matches!( + state + .dhttp() + .accept_bi(stream) + .await + .expect("known frame verdict"), + StreamVerdict::Accepted + )); + } + + let (mut reader, _writer) = state.dhttp().unresolved_request_streams.receive().await; + assert_eq!( + reader.stream_id().await.expect("oldest retained stream id"), + VarInt::from_u32(1) + ); + } + + #[test] + fn http3_frame_type_classifier_covers_known_reserved_and_unknown_values() { + assert!(DHttpProtocol::is_http3_frame_type(VarInt::from_u32(0x00))); + assert!(DHttpProtocol::is_http3_frame_type(VarInt::from_u32(0x01))); + assert!(DHttpProtocol::is_http3_frame_type(VarInt::from_u32(0x0d))); + assert!(DHttpProtocol::is_http3_frame_type(VarInt::from_u32(0x21))); + assert!(DHttpProtocol::is_http3_frame_type(VarInt::from_u32(0x40))); + assert!(!DHttpProtocol::is_http3_frame_type(VarInt::from_u32(0x02))); + assert!(!DHttpProtocol::is_http3_frame_type(VarInt::from_u32(0x41))); + } + + #[tokio::test] + async fn protocol_factory_init_opens_uni_and_sets_local_settings() { + let conn = Arc::new(MockConnection::new()); + conn.enable_stream_ops(); + let mut settings = Settings::default(); + settings.set(EnableConnectProtocol::setting(true)); + let settings = Arc::new(settings); + let factory = DHttpProtocolFactory::new(settings.clone()); + + let protocol = factory.init(&conn).await.expect("protocol init"); + + assert_eq!(protocol.local_settings, settings); + assert_eq!(conn.stream_calls(), vec!["open_uni"]); + } + + #[tokio::test] + async fn protocol_factory_init_returns_open_uni_connection_error() { + let conn = Arc::new(MockConnection::new()); + let factory = DHttpProtocolFactory::default(); + + let error = factory + .init(&conn) + .await + .expect_err("open_uni failure should fail init"); + + assert!(matches!(error, quic::ConnectionError::Transport { .. })); + assert_eq!(conn.stream_calls(), vec!["open_uni"]); + } + + #[test] + fn connection_state_accessors_return_dhttp_state_values() { + let state = test_connection_state(); + state + .dhttp() + .register_initialized_stream(VarInt::from_u32(4)) + .expect("initialized stream"); + state + .dhttp() + .register_accepted_stream(VarInt::from_u32(8)) + .expect("accepted stream"); + state + .dhttp() + .apply_peer_goaway(Goaway::new(VarInt::from_u32(12))) + .expect("peer goaway"); + + assert_eq!(state.settings(), Arc::new(Settings::default())); + assert_eq!(state.max_initialized_stream_id(), Some(VarInt::from_u32(4))); + assert_eq!(state.max_received_stream_id(), Some(VarInt::from_u32(8))); + assert_eq!( + state.peek_peer_goaway(), + Some(Goaway::new(VarInt::from_u32(12))) + ); + } + + #[tokio::test] + async fn peer_settings_resolves_when_settings_already_available() { + let state = test_connection_state(); + let settings = Arc::new(Settings::default()); + state + .dhttp() + .peer_settings + .set(settings.clone()) + .expect("set peer settings"); + + let observed = state + .peer_settings() + .await + .expect("peer settings should resolve"); + + assert_eq!(observed, settings); + } + + #[tokio::test] + async fn peer_settings_resolves_on_dyn_connection_state() { + let quic = Arc::new(MockConnection::new()); + let erased_connection: Arc = quic.clone(); + let mut protocols = Protocols::new(); + let dhttp = DHttpProtocol::new_for_test(erased_connection.clone()); + let settings = Arc::new(Settings::default()); + dhttp + .peer_settings + .set(settings.clone()) + .expect("set peer settings"); + protocols.insert(dhttp); + let state: ConnectionState = + ConnectionState::new_for_test(erased_connection, Arc::new(protocols)); + + let observed = state + .peer_settings() + .await + .expect("dyn peer settings should resolve"); + + assert_eq!(observed, settings); + } + + #[tokio::test] + async fn peer_settings_resolves_to_connection_error_when_settings_never_arrive() { + let quic = Arc::new(MockConnection::new()); + let erased_connection: Arc = quic.clone(); + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased_connection)); + let state = ConnectionState::new_for_test(quic.clone(), Arc::new(protocols)); + quic.set_terminal_error(test_transport_error("peer settings closed")); + + let error = timeout(Duration::from_millis(100), state.peer_settings()) + .await + .expect("peer settings should resolve on connection close") + .expect_err("connection closure should yield an error"); + + assert_transport_error_reason(&error, "peer settings closed"); + } + + #[tokio::test] + async fn peer_goawaies_yields_peer_goaway_updates() { + let state = test_connection_state(); + let goaway = Goaway::new(VarInt::from_u32(18)); + + let mut goawaies = pin!(state.peer_goawaies()); + state + .dhttp() + .apply_peer_goaway(goaway) + .expect("peer goaway"); + + let observed = timeout(Duration::from_millis(100), goawaies.next()) + .await + .expect("peer goaway stream should yield") + .expect("peer goaway stream should not end") + .expect("peer goaway item should be ok"); + + assert_eq!(observed, goaway); + } + + #[tokio::test] + async fn peer_goawaies_yields_connection_error_when_connection_closes() { + let quic = Arc::new(MockConnection::new()); + let erased_connection: Arc = quic.clone(); + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased_connection)); + let state = ConnectionState::new_for_test(quic.clone(), Arc::new(protocols)); + quic.set_terminal_error(test_transport_error("peer goawaies closed")); + + let mut goawaies = pin!(state.peer_goawaies()); + let error = timeout(Duration::from_millis(100), goawaies.next()) + .await + .expect("peer goaway stream should resolve") + .expect("peer goaway stream should not end") + .expect_err("connection closure should yield an error"); + + assert_transport_error_reason(&error, "peer goawaies closed"); + } + + #[tokio::test] + async fn goaway_sends_local_goaway_using_max_received_stream_id() { + let state = test_connection_state(); + state + .dhttp() + .register_accepted_stream(VarInt::from_u32(14)) + .expect("accepted stream"); + + state.goaway().await.expect("send goaway"); + + assert_eq!( + state.dhttp().local_goaway.peek(), + Some(Goaway::new(VarInt::from_u32(14))) + ); + } + + #[tokio::test] + async fn initial_raw_message_stream_registers_opened_stream() { + let quic = Arc::new(MockConnection::new()); + quic.enable_stream_ops(); + let erased_connection: Arc = quic.clone(); + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased_connection)); + let state = ConnectionState::new_for_test(quic.clone(), Arc::new(protocols)); + + let (mut reader, _writer) = state + .initial_raw_message_stream() + .await + .expect("initial stream"); + + assert_eq!( + reader.stream_id().await.expect("stream id"), + VarInt::from_u32(0) + ); + assert_eq!(state.max_initialized_stream_id(), Some(VarInt::from_u32(0))); + assert_eq!(quic.stream_calls(), vec!["open_bi"]); + } + + #[tokio::test] + async fn accept_raw_message_stream_returns_connection_error_when_connection_closes() { + let quic = Arc::new(MockConnection::new()); + let erased_connection: Arc = quic.clone(); + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased_connection)); + let state = ConnectionState::new_for_test(quic.clone(), Arc::new(protocols)); + quic.set_terminal_error(test_transport_error("accept raw closed")); + + let error = match state.accept_raw_message_stream().await { + Ok(_) => panic!("connection closure should stop raw message acceptance"), + Err(error) => error, + }; + let AcceptRawMessageStreamError::Connection { source } = error else { + panic!("expected connection error"); + }; + + assert_transport_error_reason(&source, "accept raw closed"); + } + + #[tokio::test] + async fn initial_raw_message_stream_rejects_when_peer_goaway_latched() { + let quic = Arc::new(MockConnection::new()); + quic.enable_stream_ops(); + let erased_connection: Arc = quic.clone(); + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased_connection)); + let state = ConnectionState::new_for_test(quic, Arc::new(protocols)); + state + .dhttp() + .apply_peer_goaway(Goaway::new(VarInt::from_u32(0))) + .expect("peer goaway"); + + let error = match state.initial_raw_message_stream().await { + Ok(_) => panic!("peer goaway should reject initialized stream"), + Err(error) => error, + }; + + assert!(matches!( + error, + InitialRawMessageStreamError::Goaway { + source: ConnectionGoaway::Peer + } + )); } #[tokio::test] @@ -470,8 +1209,9 @@ mod tests { .unresolved_request_streams .send(test_erased_streams(10)); - let error = accept_task + let error = timeout(Duration::from_millis(100), accept_task) .await + .expect("accept task should resolve") .expect("join should succeed") .err() .expect("boundary should apply even if goaway was set after accept started"); @@ -625,12 +1365,6 @@ impl DHttpState { type FrameSink = Feed, StreamError>, Frame>; -pub type BoxDynQuicStreamReader = guard::GuardedQuicReader; -pub type BoxDynQuicStreamWriter = guard::GuardedQuicWriter; - -type GuardedStreamReader = StreamReader; -type GuardedStreamWriter = SinkWriter; - /// DHTTP/3 protocol layer. /// /// Implements [`Protocol`] to handle HTTP/3 stream identification and @@ -649,7 +1383,10 @@ pub struct DHttpProtocol { handle_control_stream: SetOnce>, - unresolved_request_streams: RingChannel<(GuardedStreamReader, GuardedStreamWriter)>, + unresolved_request_streams: RingChannel<( + StreamReader, + SinkWriter, + )>, } impl ops::Deref for DHttpProtocol { @@ -674,10 +1411,14 @@ impl DHttpProtocol { self.unresolved_request_streams.capacity() } + pub fn peer_settings_peek(&self) -> Option> { + self.state.peer_settings.peek() + } + async fn accept_uni( &self, - mut stream: ErasedPeekableUniStream, - ) -> Result, StreamError> { + mut stream: BoxPeekableStreamReader, + ) -> Result, StreamError> { let Ok(stream_type) = stream.decode_one::().await else { return Ok(StreamVerdict::Passed(stream)); }; @@ -745,8 +1486,8 @@ impl DHttpProtocol { async fn accept_bi( &self, - (mut reader, writer): ErasedPeekableBiStream, - ) -> Result, StreamError> { + (mut reader, writer): (BoxPeekableStreamReader, BoxStreamWriter), + ) -> Result, StreamError> { // HTTP/3 bidirectional streams are request streams (RFC 9114 §4.1). // The first bytes on a request stream are HTTP/3 frames, starting with // a frame type VarInt. We peek the first VarInt to determine whether @@ -773,13 +1514,13 @@ impl DHttpProtocol { Pin::new(&mut reader).reset(); let reader = reader .into_stream_reader() - .map_stream(guard::GuardedQuicReader::new); - let writer = writer.map_sink(guard::GuardedQuicWriter::new); + .map_stream(guard::GuardQuicReader::new); + let writer = writer.map_sink(guard::GuardQuicWriter::new); let item = (reader, writer); if let Some(mut unresolved) = self.unresolved_request_streams.send(item) { // Ring channel is full — reject the oldest unresolved request. let code = Code::H3_REQUEST_REJECTED.into_inner(); - _ = tokio::join!(unresolved.0.stop(code), unresolved.1.cancel(code)); + _ = tokio::join!(unresolved.0.stop(code), unresolved.1.reset(code)); } Ok(StreamVerdict::Accepted) } else { @@ -821,15 +1562,16 @@ impl DHttpProtocol { impl Protocol for DHttpProtocol { fn accept_uni<'a>( &'a self, - stream: ErasedPeekableUniStream, - ) -> BoxFuture<'a, Result, StreamError>> { + stream: BoxPeekableStreamReader, + ) -> BoxFuture<'a, Result, StreamError>> { Box::pin(self.accept_uni(stream)) } fn accept_bi<'a>( &'a self, - stream: ErasedPeekableBiStream, - ) -> BoxFuture<'a, Result, StreamError>> { + stream: (BoxPeekableStreamReader, BoxStreamWriter), + ) -> BoxFuture<'a, Result, StreamError>> + { Box::pin(self.accept_bi(stream)) } } @@ -935,16 +1677,19 @@ impl ConnectionState { } } -impl ConnectionState { - pub async fn peer_settings( - &self, - ) -> impl Future, quic::ConnectionError>> + Send + use<'_, C> - { - let error = self.closed(); - (self.dhttp().peer_settings.get()).then(|option| match option { - Some(settings) => future::ready(Ok(settings)).left_future(), - None => error.map(Err).right_future(), - }) +impl ConnectionState { + /// Per RFC 9114 §3.2/§6.2.1, SETTINGS MUST be the first frame on the + /// control stream; a connection that closes before SETTINGS arrives is in + /// error. Races [`Self::closed`] so callers cannot hang on a frame that + /// will never come, mirroring [`Self::peer_goawaies`]. + pub async fn peer_settings(&self) -> Result, quic::ConnectionError> { + let settings = self.dhttp().peer_settings.get(); + let closed = self.closed(); + match future::select(pin!(settings), pin!(closed)).await { + future::Either::Left((Some(settings), _)) => Ok(settings), + future::Either::Left((None, closed)) => Err(closed.await), + future::Either::Right((error, _)) => Err(error), + } } pub fn peer_goawaies( @@ -999,20 +1744,32 @@ pub enum AcceptRawMessageStreamError { impl ConnectionState { pub async fn initial_raw_message_stream( &self, - ) -> Result<(GuardedStreamReader, GuardedStreamWriter), InitialRawMessageStreamError> { + ) -> Result< + ( + StreamReader, + SinkWriter, + ), + InitialRawMessageStreamError, + > { let (reader, writer) = self.open_bi().await?; let (mut reader, writer) = (Box::pin(reader), Box::pin(writer)); self.dhttp() .register_initialized_stream(reader.stream_id().await?)?; Ok(( - StreamReader::new(guard::GuardedQuicReader::new(reader)), - SinkWriter::new(guard::GuardedQuicWriter::new(writer)), + StreamReader::new(guard::GuardQuicReader::new(reader)), + SinkWriter::new(guard::GuardQuicWriter::new(writer)), )) } pub async fn accept_raw_message_stream( &self, - ) -> Result<(GuardedStreamReader, GuardedStreamWriter), AcceptRawMessageStreamError> { + ) -> Result< + ( + StreamReader, + SinkWriter, + ), + AcceptRawMessageStreamError, + > { let dhttp = self.dhttp(); let (mut reader, mut writer) = tokio::select! { stream = dhttp.unresolved_request_streams.receive() => stream, @@ -1023,7 +1780,7 @@ impl ConnectionState { Ok(()) => Ok((reader, writer)), Err(ConnectionGoaway::Local) => { let code = Code::H3_REQUEST_REJECTED.into_inner(); - _ = tokio::join!(reader.stop(code), writer.cancel(code)); + _ = tokio::join!(reader.stop(code), writer.reset(code)); Err(ConnectionGoaway::Local.into()) } Err(ConnectionGoaway::Peer) => { diff --git a/src/dhttp/settings.rs b/src/dhttp/settings.rs index ea0a60a..5b4a231 100644 --- a/src/dhttp/settings.rs +++ b/src/dhttp/settings.rs @@ -39,10 +39,9 @@ impl Setting { } pub fn check(&self) -> Result<(), InvalidSettingValue> { - let is_bool_setting = self.id == EnableConnectProtocol::ID - || self.id == EnableWebTransport::ID - || self.id == H3Datagram::ID; - if is_bool_setting && self.value != VarInt::from_u32(0) && self.value != VarInt::from_u32(1) + if is_boolean_setting(self.id) + && self.value != VarInt::from_u32(0) + && self.value != VarInt::from_u32(1) { return Err(InvalidSettingValue::BoolSetting { id: self.id, @@ -71,6 +70,13 @@ impl H3ConnectionError for InvalidSettingValue { } } +const fn is_boolean_setting(id: VarInt) -> bool { + let id = id.into_inner(); + id == crate::extended_connect::settings::EnableConnectProtocol::ID.into_inner() + || id == crate::dhttp::webtransport::settings::EnableWebTransport::ID.into_inner() + || id == crate::dhttp::datagram::settings::H3Datagram::ID.into_inner() +} + impl DecodeFrom for Setting { type Error = StreamError; @@ -173,20 +179,20 @@ impl EncodeInto for Settings { impl Settings { /// Typed access to a setting value. The return type depends on the setting: /// - /// - Concrete setting types (`QpackMaxTableCapacity`, `MaxFieldSectionSize`, …) - /// apply defaults and return their associated `Value` type. + /// - Concrete setting types (`crate::qpack::settings::QpackMaxTableCapacity`, + /// `MaxFieldSectionSize`, …) apply defaults and return their associated `Value` type. /// - A raw [`VarInt`] identifier returns `Option` with no default fallback. /// /// ```ignore - /// settings.get(QpackMaxTableCapacity) // → VarInt (with default) - /// settings.get(MaxFieldSectionSize) // → Option - /// settings.get(VarInt::from_u32(0x06)) // → Option (raw) + /// settings.get(crate::qpack::settings::QpackMaxTableCapacity) // → VarInt (with default) + /// settings.get(MaxFieldSectionSize) // → Option + /// settings.get(VarInt::from_u32(0x06)) // → Option (raw) /// ``` pub fn get(&self, id: S) -> S::Value { id.value_from(self) } - fn get_raw(&self, id: VarInt) -> Option { + pub(crate) fn get_raw(&self, id: VarInt) -> Option { self.map.get(&id).copied() } @@ -194,28 +200,18 @@ impl Settings { self.get(MaxFieldSectionSize) } - pub fn qpack_max_table_capacity(&self) -> VarInt { - self.get(QpackMaxTableCapacity) - } - - pub fn qpack_blocked_streams(&self) -> VarInt { - self.get(QpackBlockedStreams) - } - - pub fn enable_connect_protocol(&self) -> bool { - self.get(EnableConnectProtocol) - } - - pub fn enable_webtransport(&self) -> bool { - self.get(EnableWebTransport) + pub fn set(&mut self, Setting { id, value }: Setting) { + self.map.insert(id, value); } - pub fn h3_datagram(&self) -> bool { - self.get(H3Datagram) + pub fn with(mut self, setting: Setting) -> Self { + self.set(setting); + self } - pub fn set(&mut self, Setting { id, value }: Setting) { - self.map.insert(id, value); + pub fn with_all(mut self, settings: impl IntoIterator) -> Self { + self.extend(settings); + self } } @@ -292,210 +288,578 @@ impl SettingId for VarInt { } } -/// `SETTINGS_QPACK_MAX_TABLE_CAPACITY` (0x01). Default: 0. +/// `SETTINGS_MAX_FIELD_SECTION_SIZE` (0x06). No default (unlimited). /// -/// To bound the memory requirements of the decoder, the decoder limits the -/// maximum value the encoder is permitted to set for the dynamic table -/// capacity. In HTTP/3, this limit is determined by the value of -/// `SETTINGS_QPACK_MAX_TABLE_CAPACITY` sent by the decoder. +/// An HTTP/3 implementation MAY impose a limit on the maximum size of the +/// message header it will accept on an individual HTTP message. The size +/// of a field list is calculated based on the uncompressed size of fields, +/// including the length of the name and value in bytes plus an overhead of +/// 32 bytes for each field. /// -/// -pub struct QpackMaxTableCapacity; +/// +pub struct MaxFieldSectionSize; -impl QpackMaxTableCapacity { - pub const ID: VarInt = VarInt::from_u32(0x01); - /// The default value is zero. See Section 3.2 for usage. This is the - /// equivalent of the `SETTINGS_HEADER_TABLE_SIZE` from HTTP/2. - pub const DEFAULT: VarInt = VarInt::from_u32(0); +impl MaxFieldSectionSize { + pub const ID: VarInt = VarInt::from_u32(0x06); pub const fn setting(value: VarInt) -> Setting { Setting::new(Self::ID, value) } } -impl SettingId for QpackMaxTableCapacity { - type Value = VarInt; +impl SettingId for MaxFieldSectionSize { + type Value = Option; fn id(&self) -> VarInt { Self::ID } - fn value_from(&self, settings: &Settings) -> VarInt { - settings.get_raw(Self::ID).unwrap_or(Self::DEFAULT) + fn value_from(&self, settings: &Settings) -> Option { + settings.get_raw(Self::ID) } } -/// `SETTINGS_QPACK_BLOCKED_STREAMS` (0x07). Default: 0. -/// -/// The decoder specifies an upper bound on the number of streams that can -/// be blocked using the `SETTINGS_QPACK_BLOCKED_STREAMS` setting. An -/// encoder MUST limit the number of streams that could become blocked to -/// the value of `SETTINGS_QPACK_BLOCKED_STREAMS` at all times. -/// -/// -pub struct QpackBlockedStreams; +impl UnidirectionalStream<()> { + /// A control stream is indicated by a stream type of 0x00. Data on this + /// stream consists of HTTP/3 frames, as defined in Section 7.2. + /// + /// + pub const CONTROL_STREAM_TYPE: VarInt = VarInt::from_u32(0x00); +} -impl QpackBlockedStreams { - pub const ID: VarInt = VarInt::from_u32(0x07); - /// The default value is zero. See Section 2.1.2. - pub const DEFAULT: VarInt = VarInt::from_u32(0); +impl UnidirectionalStream { + pub const fn is_control_stream(&self) -> bool { + self.r#type().into_inner() == UnidirectionalStream::CONTROL_STREAM_TYPE.into_inner() + } - pub const fn setting(value: VarInt) -> Setting { - Setting::new(Self::ID, value) + pub async fn initial_control_stream(stream: S) -> Result + where + S: AsyncWrite + Unpin + Sized + Send, + { + Self::initial(UnidirectionalStream::CONTROL_STREAM_TYPE, stream) + .await + .map_err(|error| error.map_stream_reset(|_| H3CriticalStreamClosed::Control.into())) } } -impl SettingId for QpackBlockedStreams { - type Value = VarInt; +#[cfg(test)] +mod tests { + use std::{ + io, + pin::Pin, + task::{Context, Poll}, + }; - fn id(&self) -> VarInt { - Self::ID + use bytes::Buf; + use tokio::io::AsyncWriteExt; + + use super::*; + use crate::{ + codec::{DecodeError, DecodeExt, EncodeExt}, + connection, + dhttp::{datagram::settings::H3Datagram, webtransport::settings::EnableWebTransport}, + extended_connect::settings::EnableConnectProtocol, + quic, + varint::VarInt, + }; + + #[derive(Clone)] + struct FailWrite { + error: quic::StreamError, } - fn value_from(&self, settings: &Settings) -> VarInt { - settings.get_raw(Self::ID).unwrap_or(Self::DEFAULT) + impl FailWrite { + fn reset(code: VarInt) -> Self { + Self { + error: quic::StreamError::Reset { code }, + } + } + + fn connection() -> Self { + Self { + error: quic::StreamError::Connection { + source: quic_connection_error(), + }, + } + } } -} -/// `SETTINGS_MAX_FIELD_SECTION_SIZE` (0x06). No default (unlimited). -/// -/// An HTTP/3 implementation MAY impose a limit on the maximum size of the -/// message header it will accept on an individual HTTP message. The size -/// of a field list is calculated based on the uncompressed size of fields, -/// including the length of the name and value in bytes plus an overhead of -/// 32 bytes for each field. -/// -/// -pub struct MaxFieldSectionSize; + struct FailAfterWrites { + successful_writes_before_failure: usize, + error: quic::StreamError, + } -impl MaxFieldSectionSize { - pub const ID: VarInt = VarInt::from_u32(0x06); + impl FailAfterWrites { + fn new(successful_writes_before_failure: usize, error: quic::StreamError) -> Self { + Self { + successful_writes_before_failure, + error, + } + } + } - pub const fn setting(value: VarInt) -> Setting { - Setting::new(Self::ID, value) + impl tokio::io::AsyncWrite for FailAfterWrites { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + if self.successful_writes_before_failure == 0 { + return Poll::Ready(Err(io::Error::from(self.error.clone()))); + } + + self.successful_writes_before_failure -= 1; + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } } -} -impl SettingId for MaxFieldSectionSize { - type Value = Option; + struct FailRead; - fn id(&self) -> VarInt { - Self::ID + impl tokio::io::AsyncRead for FailRead { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Poll::Ready(Err(DecodeError::ArithmeticOverflow.into())) + } } - fn value_from(&self, settings: &Settings) -> Option { - settings.get_raw(Self::ID) + impl tokio::io::AsyncWrite for FailWrite { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Ready(Err(io::Error::from(self.error.clone()))) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } } -} -/// `SETTINGS_ENABLE_CONNECT_PROTOCOL` (0x08). No default. -/// -/// Enables the Extended CONNECT method for WebSocket upgrades over HTTP/3. -/// The value MUST be 0 or 1. -/// -/// -pub struct EnableConnectProtocol; + struct FailBufRead { + error: Option, + } -impl EnableConnectProtocol { - pub const ID: VarInt = VarInt::from_u32(0x08); + impl FailBufRead { + fn new(error: connection::StreamError) -> Self { + Self { error: Some(error) } + } + } - pub const fn setting(enabled: bool) -> Setting { - Setting::new(Self::ID, VarInt::from_u32(enabled as u32)) + impl tokio::io::AsyncRead for FailBufRead { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } } -} -impl SettingId for EnableConnectProtocol { - type Value = bool; + impl tokio::io::AsyncBufRead for FailBufRead { + fn poll_fill_buf(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let error = self + .get_mut() + .error + .take() + .expect("test stream should be polled once"); + Poll::Ready(Err(io::Error::from(error))) + } - fn id(&self) -> VarInt { - Self::ID + fn consume(self: Pin<&mut Self>, _amt: usize) {} } - fn value_from(&self, settings: &Settings) -> bool { - settings - .get_raw(Self::ID) - .is_some_and(|v| v == VarInt::from_u32(1)) + fn quic_connection_error() -> quic::ConnectionError { + quic::ConnectionError::Application { + source: quic::ApplicationError { + code: Code::H3_INTERNAL_ERROR, + reason: "test failure".into(), + }, + } } -} -/// `SETTINGS_ENABLE_WEBTRANSPORT` (0x2b603742). No default. -/// -/// Enables WebTransport over HTTP/3. The value MUST be 0 or 1. -/// Requires `SETTINGS_ENABLE_CONNECT_PROTOCOL` to also be enabled. -/// -/// -pub struct EnableWebTransport; + fn assert_h3_connection_code(error: StreamError, expected: Code) { + assert!(matches!( + error, + StreamError::Connection { + source: connection::ConnectionError::H3 { source }, + } if source.code() == expected + )); + } + + fn assert_quic_connection_error(error: StreamError) { + assert!(matches!( + error, + StreamError::Connection { + source: connection::ConnectionError::Quic { .. }, + } + )); + } + + #[test] + fn boolean_setting_validation_uses_new_owner_modules() { + for id in [ + EnableConnectProtocol::ID, + EnableWebTransport::ID, + H3Datagram::ID, + ] { + let err = Setting::new(id, VarInt::from_u32(2)) + .check() + .expect_err("boolean setting value 2 must be rejected"); + assert!(matches!(err, InvalidSettingValue::BoolSetting { .. })); + } + } -impl EnableWebTransport { - pub const ID: VarInt = VarInt::from_u32(0x2b603742); + #[test] + fn setting_construction_validation_and_error_metadata() { + let setting = Setting::from((MaxFieldSectionSize::ID, VarInt::from_u32(4096))); + assert_eq!(setting.id, MaxFieldSectionSize.id()); + assert_eq!(setting.value, VarInt::from_u32(4096)); + assert!(setting.check().is_ok()); + + assert!( + Setting::new(EnableConnectProtocol::ID, VarInt::from_u32(0)) + .check() + .is_ok() + ); + assert!( + Setting::new(EnableConnectProtocol::ID, VarInt::from_u32(1)) + .check() + .is_ok() + ); + + let error = Setting::new(EnableConnectProtocol::ID, VarInt::from_u32(2)) + .check() + .expect_err("invalid boolean setting must fail"); + assert_eq!(error.code(), Code::H3_SETTINGS_ERROR); + assert_eq!( + error.to_string(), + "boolean setting 8 must have value 0 or 1, got 2", + ); + } + + #[test] + fn non_boolean_settings_accept_arbitrary_values_and_validation_error_has_no_source() { + let custom_setting = Setting::new(VarInt::from_u32(0x21), VarInt::MAX); + assert!(custom_setting.check().is_ok()); + + let error = Setting::new(H3Datagram::ID, VarInt::from_u32(42)) + .check() + .expect_err("invalid boolean setting must fail"); + assert!(std::error::Error::source(&error).is_none()); + assert_eq!(error.code(), Code::H3_SETTINGS_ERROR); + assert_eq!( + error.to_string(), + "boolean setting 51 must have value 0 or 1, got 42", + ); + } + + #[test] + fn settings_accessors_iterators_and_extension_paths() { + let mut settings = Settings::default(); + assert_eq!(settings.get(VarInt::from_u32(0x1234)), None); + assert_eq!(settings.max_field_section_size(), None); + + settings.set(MaxFieldSectionSize::setting(VarInt::from_u32(4096))); + settings.extend([ + EnableConnectProtocol::setting(true), + H3Datagram::setting(false), + ]); + settings.extend(std::iter::once(H3Datagram::setting(true))); + + assert_eq!( + settings.get(MaxFieldSectionSize), + Some(VarInt::from_u32(4096)), + ); + assert_eq!( + settings.max_field_section_size(), + Some(VarInt::from_u32(4096)), + ); + assert_eq!( + settings.get(VarInt::from_u32(0x06)), + Some(VarInt::from_u32(4096)), + ); + assert!(settings.enable_connect_protocol()); + assert!(settings.h3_datagram()); + assert!(!settings.enable_webtransport()); + + let borrowed: Vec<_> = (&settings).into_iter().collect(); + assert_eq!(borrowed.len(), 3); + assert_eq!(borrowed[0].id, MaxFieldSectionSize::ID); + + let owned: Vec<_> = settings.clone().into_iter().collect(); + assert_eq!(owned.len(), borrowed.len()); + for (left, right) in owned.iter().zip(&borrowed) { + assert_eq!(left.id, right.id); + assert_eq!(left.value, right.value); + } - pub const fn setting(enabled: bool) -> Setting { - Setting::new(Self::ID, VarInt::from_u32(enabled as u32)) + let rebuilt = Settings::from_iter(owned); + assert_eq!(settings, rebuilt); } -} -impl SettingId for EnableWebTransport { - type Value = bool; + #[test] + fn settings_with_and_with_all_compose_setting_fragments() { + let settings = Settings::default() + .with(MaxFieldSectionSize::setting(VarInt::from_u32(4096))) + .with_all([ + EnableConnectProtocol::setting(true), + H3Datagram::setting(false), + ]) + .with(H3Datagram::setting(true)); - fn id(&self) -> VarInt { - Self::ID + assert_eq!( + settings.max_field_section_size(), + Some(VarInt::from_u32(4096)), + ); + assert!(settings.enable_connect_protocol()); + assert!(settings.h3_datagram()); } - fn value_from(&self, settings: &Settings) -> bool { - settings - .get_raw(Self::ID) - .is_some_and(|v| v == VarInt::from_u32(1)) + #[test] + fn setting_id_methods_return_wire_ids_and_typed_values() { + let mut settings = Settings::default(); + let raw_id = VarInt::from_u32(0x21); + + assert_eq!(raw_id.id(), raw_id); + assert_eq!(raw_id.value_from(&settings), None); + assert_eq!(MaxFieldSectionSize.id(), MaxFieldSectionSize::ID); + assert_eq!(MaxFieldSectionSize.value_from(&settings), None); + + settings.set(Setting::new(raw_id, VarInt::from_u32(7))); + settings.set(MaxFieldSectionSize::setting(VarInt::from_u32(4096))); + + assert_eq!(raw_id.value_from(&settings), Some(VarInt::from_u32(7))); + assert_eq!(settings.get(raw_id), Some(VarInt::from_u32(7))); + assert_eq!( + MaxFieldSectionSize.value_from(&settings), + Some(VarInt::from_u32(4096)), + ); + } + + #[tokio::test] + async fn setting_decode_maps_incomplete_id_and_value_to_closed_control_stream() { + for payload in [ + BufList::new(), + BufList::from_buf(&[MaxFieldSectionSize::ID.into_inner() as u8][..]), + ] { + let error = match payload.decode::().await { + Ok(_) => panic!("incomplete setting must fail"), + Err(error) => error, + }; + assert_h3_connection_code(error, Code::H3_CLOSED_CRITICAL_STREAM); + } } -} -/// `H3_DATAGRAM` (0x33). No default. -/// -/// Indicates support for HTTP/3 datagrams (RFC 9297). The value MUST be 0 or 1. -/// -/// -pub struct H3Datagram; + #[tokio::test] + async fn setting_decode_maps_payload_decode_error_to_frame_decode_error() { + let error = FailRead + .decode::() + .await + .err() + .expect("typed decode failure should be a frame decode error"); + + assert_h3_connection_code(error, Code::H3_FRAME_ERROR); + } + + #[tokio::test] + async fn setting_encode_maps_reset_to_closed_control_stream_and_preserves_connection_errors() { + let mut idle_writer = FailWrite::reset(VarInt::from_u32(0)); + idle_writer.flush().await.expect("flush succeeds"); + idle_writer.shutdown().await.expect("shutdown succeeds"); -impl H3Datagram { - pub const ID: VarInt = VarInt::from_u32(0x33); + let reset_code = VarInt::from_u32(77); + let error = Setting::new(MaxFieldSectionSize::ID, VarInt::from_u32(1)) + .encode_into(FailWrite::reset(reset_code)) + .await + .expect_err("write reset must fail"); + assert_h3_connection_code(error, Code::H3_CLOSED_CRITICAL_STREAM); - pub const fn setting(enabled: bool) -> Setting { - Setting::new(Self::ID, VarInt::from_u32(enabled as u32)) + let error = Setting::new(MaxFieldSectionSize::ID, VarInt::from_u32(1)) + .encode_into(FailWrite::connection()) + .await + .expect_err("connection write failure must fail"); + assert_quic_connection_error(error); } -} -impl SettingId for H3Datagram { - type Value = bool; + #[tokio::test] + async fn settings_decode_propagates_stream_fill_buf_errors() { + let reset_code = VarInt::from_u32(88); + let reset = FailBufRead::new(StreamError::Reset { code: reset_code }) + .decode::() + .await + .expect_err("fill_buf reset must fail"); + assert!(matches!(reset, StreamError::Reset { code } if code == reset_code)); - fn id(&self) -> VarInt { - Self::ID + let connection = + FailBufRead::new(connection::ConnectionError::from(quic_connection_error()).into()) + .decode::() + .await + .expect_err("fill_buf connection error must fail"); + assert_quic_connection_error(connection); + } + + #[tokio::test] + async fn setting_encode_maps_value_write_reset_to_closed_control_stream() { + let mut idle_writer = FailAfterWrites::new( + 1, + quic::StreamError::Reset { + code: VarInt::from_u32(0), + }, + ); + idle_writer.flush().await.expect("flush succeeds"); + idle_writer.shutdown().await.expect("shutdown succeeds"); + + let error = Setting::new(MaxFieldSectionSize::ID, VarInt::from_u32(4096)) + .encode_into(FailAfterWrites::new( + 1, + quic::StreamError::Reset { + code: VarInt::from_u32(123), + }, + )) + .await + .expect_err("value write reset must fail"); + + assert_h3_connection_code(error, Code::H3_CLOSED_CRITICAL_STREAM); } - fn value_from(&self, settings: &Settings) -> bool { - settings - .get_raw(Self::ID) - .is_some_and(|v| v == VarInt::from_u32(1)) + #[tokio::test] + async fn setting_encode_decode_round_trips_and_rejects_invalid_bool() { + let mut encoded = BufList::new(); + encoded + .encode_one(Setting::new( + MaxFieldSectionSize::ID, + VarInt::from_u32(4096), + )) + .await + .expect("setting encoding into buflist is infallible"); + let decoded = encoded.decode::().await.expect("setting decodes"); + assert_eq!(decoded.id, MaxFieldSectionSize::ID); + assert_eq!(decoded.value, VarInt::from_u32(4096)); + + let mut invalid = BufList::new(); + invalid + .encode_one(Setting::new(EnableWebTransport::ID, VarInt::from_u32(2))) + .await + .expect("setting encoding into buflist is infallible"); + let error = invalid + .decode::() + .await + .err() + .expect("invalid boolean setting must fail to decode"); + assert_h3_connection_code(error, Code::H3_SETTINGS_ERROR); } -} -impl UnidirectionalStream<()> { - /// A control stream is indicated by a stream type of 0x00. Data on this - /// stream consists of HTTP/3 frames, as defined in Section 7.2. - /// - /// - pub const CONTROL_STREAM_TYPE: VarInt = VarInt::from_u32(0x00); -} + #[tokio::test] + async fn settings_encode_to_frame_and_decode_payload() { + let settings = Settings::from_iter([ + MaxFieldSectionSize::setting(VarInt::from_u32(8192)), + EnableConnectProtocol::setting(true), + ]); -impl UnidirectionalStream { - pub const fn is_control_stream(&self) -> bool { - self.r#type().into_inner() == UnidirectionalStream::CONTROL_STREAM_TYPE.into_inner() + let frame = BufList::new() + .encode(&settings) + .await + .expect("settings encoding into buflist is infallible"); + assert_eq!(frame.r#type(), Frame::SETTINGS_FRAME_TYPE); + assert!(frame.length().into_inner() > 0); + + let decoded = frame + .into_payload() + .decode::() + .await + .expect("settings decode from payload"); + assert_eq!(decoded, settings); + + let frame = BufList::new() + .encode(settings.clone()) + .await + .expect("owned settings encoding into buflist is infallible"); + assert_eq!(frame.r#type(), Frame::SETTINGS_FRAME_TYPE); } - pub async fn initial_control_stream(stream: S) -> Result - where - S: AsyncWrite + Unpin + Sized + Send, - { - Self::initial(UnidirectionalStream::CONTROL_STREAM_TYPE, stream) + #[tokio::test] + async fn empty_settings_encode_to_zero_length_frame_and_decode_to_default() { + let settings = Settings::default(); + + let frame = BufList::new() + .encode(&settings) .await - .map_err(|error| error.map_stream_reset(|_| H3CriticalStreamClosed::Control.into())) + .expect("settings encoding into buflist is infallible"); + assert_eq!(frame.r#type(), Frame::SETTINGS_FRAME_TYPE); + assert_eq!(frame.length(), VarInt::from_u32(0)); + + let decoded = frame + .into_payload() + .decode::() + .await + .expect("empty settings payload decodes"); + assert_eq!(decoded, settings); + } + + #[tokio::test] + async fn settings_decode_uses_last_value_for_duplicate_identifiers() { + let mut encoded = BufList::new(); + encoded + .encode_one(MaxFieldSectionSize::setting(VarInt::from_u32(1024))) + .await + .expect("setting encoding into buflist is infallible"); + encoded + .encode_one(MaxFieldSectionSize::setting(VarInt::from_u32(2048))) + .await + .expect("setting encoding into buflist is infallible"); + + let decoded = encoded + .decode::() + .await + .expect("settings payload decodes"); + + assert_eq!( + decoded.max_field_section_size(), + Some(VarInt::from_u32(2048)), + ); + assert_eq!(decoded.into_iter().count(), 1); + } + + #[tokio::test] + async fn initial_control_stream_maps_write_errors() { + let error = + UnidirectionalStream::initial_control_stream(FailWrite::reset(VarInt::from_u32(9))) + .await + .err() + .expect("control stream reset must fail"); + assert_h3_connection_code(error, Code::H3_CLOSED_CRITICAL_STREAM); + + let error = UnidirectionalStream::initial_control_stream(FailWrite::connection()) + .await + .err() + .expect("control stream connection failure must fail"); + assert_quic_connection_error(error); + } + + #[tokio::test] + async fn control_stream_helpers_identify_and_write_stream_type() { + let control = UnidirectionalStream::initial_control_stream(BufList::new()) + .await + .expect("control stream initialization"); + + assert!(control.is_control_stream()); + assert_eq!(control.r#type(), UnidirectionalStream::CONTROL_STREAM_TYPE); + assert!(control.into_inner().has_remaining()); } } diff --git a/src/dhttp/stream.rs b/src/dhttp/stream.rs index 15ca9a0..b4ce934 100644 --- a/src/dhttp/stream.rs +++ b/src/dhttp/stream.rs @@ -11,7 +11,7 @@ use crate::{ codec::{DecodeExt, DecodeFrom, EncodeExt, StreamDecodeError}, connection::StreamError, error::H3GeneralProtocolError, - quic::{self, CancelStream, GetStreamId, StopStream}, + quic::{self, GetStreamId, ResetStream, StopStream}, varint::VarInt, }; @@ -148,13 +148,13 @@ impl Stream for UnidirectionalStream { } } -impl CancelStream for UnidirectionalStream { - fn poll_cancel( +impl ResetStream for UnidirectionalStream { + fn poll_reset( self: Pin<&mut Self>, cx: &mut Context, code: VarInt, ) -> Poll> { - self.project().stream.poll_cancel(cx, code) + self.project().stream.poll_reset(cx, code) } } @@ -195,3 +195,256 @@ impl + ?Sized> Sink for UnidirectionalStream { self.project().stream.poll_close(cx) } } + +#[cfg(test)] +mod tests { + use std::sync::{Arc, Mutex}; + + use bytes::Bytes; + use futures::{SinkExt, StreamExt, future::poll_fn}; + use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt}; + + use super::*; + use crate::{ + buflist::BufList, + codec::{DecodeExt, EncodeExt}, + }; + + #[tokio::test] + async fn initial_writes_stream_type_and_keeps_payload_writable() { + let stream_type = VarInt::from_u32(0x21); + let mut stream = UnidirectionalStream::initial(stream_type, BufList::new()) + .await + .expect("initial stream"); + + assert_eq!(stream.r#type(), stream_type); + assert!(stream.is_reserved_stream()); + + stream.write_all(b"payload").await.expect("payload write"); + AsyncWriteExt::flush(&mut stream) + .await + .expect("payload flush"); + AsyncWriteExt::shutdown(&mut stream) + .await + .expect("payload shutdown"); + + let mut inner = stream.into_inner(); + let decoded_type = inner + .decode_one::() + .await + .expect("stream type prefix"); + let mut payload = Vec::new(); + inner + .read_to_end(&mut payload) + .await + .expect("remaining payload"); + + assert_eq!(decoded_type, stream_type); + assert_eq!(payload, b"payload"); + } + + #[tokio::test] + async fn accept_reads_stream_type_and_leaves_remaining_payload() { + let stream_type = VarInt::from_u32(0x02); + let mut payload = BufList::new(); + payload + .encode_one(stream_type) + .await + .expect("stream type encoding"); + payload.write(Bytes::from_static(b"body")); + + let mut stream = UnidirectionalStream::accept(payload) + .await + .expect("accepted stream"); + let mut body = Vec::new(); + stream + .read_to_end(&mut body) + .await + .expect("remaining payload"); + + assert_eq!(stream.r#type(), stream_type); + assert!(!stream.is_reserved_stream()); + assert_eq!(body, b"body"); + } + + #[tokio::test] + async fn decode_ext_uses_unidirectional_accept_path() { + let stream_type = VarInt::from_u32(0x03); + let mut payload = BufList::new(); + payload + .encode_one(stream_type) + .await + .expect("stream type encoding"); + payload.write(Bytes::from_static(b"after-type")); + + let mut stream = payload + .decode::>() + .await + .expect("decoded unidirectional stream"); + + assert_eq!(stream.r#type(), stream_type); + let mut body = Vec::new(); + stream.read_to_end(&mut body).await.expect("body"); + assert_eq!(body, b"after-type"); + } + + #[tokio::test] + async fn accept_rejects_missing_stream_type() { + let error = UnidirectionalStream::accept(BufList::new()) + .await + .err() + .expect("stream type is mandatory"); + + assert!(matches!( + error, + StreamError::Connection { + source: crate::connection::ConnectionError::H3 { source }, + } if source.code() == crate::error::Code::H3_GENERAL_PROTOCOL_ERROR + )); + } + + #[tokio::test] + async fn async_buf_read_delegates_fill_and_consume_to_inner_stream() { + let mut payload = BufList::new(); + payload.write(Bytes::from_static(b"abc")); + payload.write(Bytes::from_static(b"def")); + let mut stream = UnidirectionalStream { + r#type: VarInt::from_u32(0), + stream: payload, + }; + + assert_eq!(stream.fill_buf().await.expect("fill buf"), b"abc"); + stream.consume(2); + assert_eq!( + stream.fill_buf().await.expect("remaining first chunk"), + b"c" + ); + stream.consume(1); + assert_eq!(stream.fill_buf().await.expect("second chunk"), b"def"); + stream.consume(3); + assert_eq!(stream.fill_buf().await.expect("eof"), b""); + } + + #[test] + fn reserved_stream_detection_matches_http3_grease_pattern() { + for value in [0x21, 0x40, 0x5f, 0x7e] { + let stream = UnidirectionalStream { + r#type: VarInt::from_u32(value), + stream: (), + }; + assert!(stream.is_reserved_stream(), "{value:#x} is reserved"); + } + + for value in [0x00, 0x20, 0x22, 0x41] { + let stream = UnidirectionalStream { + r#type: VarInt::from_u32(value), + stream: (), + }; + assert!(!stream.is_reserved_stream(), "{value:#x} is not reserved"); + } + } + + #[tokio::test] + async fn stream_and_sink_delegate_to_inner_value() { + let mut stream = UnidirectionalStream { + r#type: VarInt::from_u32(0), + stream: futures::stream::iter([Ok::<_, quic::StreamError>(Bytes::from_static( + b"chunk", + ))]), + }; + + assert_eq!( + stream.next().await.unwrap().unwrap(), + Bytes::from_static(b"chunk") + ); + assert!(stream.next().await.is_none()); + + let mut sink = UnidirectionalStream { + r#type: VarInt::from_u32(0), + stream: futures::sink::drain::(), + }; + sink.send(Bytes::from_static(b"ignored")) + .await + .expect("sink forwards to inner drain"); + sink.close() + .await + .expect("sink close forwards to inner drain"); + } + + #[derive(Debug)] + struct ControlStream { + stream_id: VarInt, + stopped: Arc>>, + reset: Arc>>, + } + + impl GetStreamId for ControlStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl StopStream for ControlStream { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + *self.stopped.lock().expect("stop state poisoned") = Some(code); + Poll::Ready(Ok(())) + } + } + + impl ResetStream for ControlStream { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + *self.reset.lock().expect("reset state poisoned") = Some(code); + Poll::Ready(Ok(())) + } + } + + #[tokio::test] + async fn control_traits_delegate_to_inner_stream() { + let stopped = Arc::new(Mutex::new(None)); + let reset = Arc::new(Mutex::new(None)); + let stream_id = VarInt::from_u32(17); + let stop_code = VarInt::from_u32(23); + let reset_code = VarInt::from_u32(29); + let mut stream = UnidirectionalStream { + r#type: VarInt::from_u32(0), + stream: ControlStream { + stream_id, + stopped: stopped.clone(), + reset: reset.clone(), + }, + }; + + assert_eq!( + poll_fn(|cx| Pin::new(&mut stream).poll_stream_id(cx)) + .await + .expect("stream id"), + stream_id + ); + poll_fn(|cx| Pin::new(&mut stream).poll_stop(cx, stop_code)) + .await + .expect("stop forwarded"); + poll_fn(|cx| Pin::new(&mut stream).poll_reset(cx, reset_code)) + .await + .expect("reset forwarded"); + + assert_eq!( + *stopped.lock().expect("stop state poisoned"), + Some(stop_code) + ); + assert_eq!( + *reset.lock().expect("reset state poisoned"), + Some(reset_code) + ); + } +} diff --git a/src/dhttp/webtransport.rs b/src/dhttp/webtransport.rs new file mode 100644 index 0000000..24460a5 --- /dev/null +++ b/src/dhttp/webtransport.rs @@ -0,0 +1,2 @@ +pub mod capsule; +pub mod settings; diff --git a/src/dhttp/webtransport/capsule.rs b/src/dhttp/webtransport/capsule.rs new file mode 100644 index 0000000..830e765 --- /dev/null +++ b/src/dhttp/webtransport/capsule.rs @@ -0,0 +1,322 @@ +use std::{convert::Infallible, io}; + +use bytes::{Buf, Bytes}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use crate::{ + buflist::BufList, + codec::{DecodeExt, DecodeFrom, EncodeExt, EncodeInto}, + varint::{self, VarInt}, +}; + +const READ_CHUNK_SIZE: u64 = 8 * 1024; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct CapsuleType(VarInt); + +impl CapsuleType { + pub const DATAGRAM: Self = Self(VarInt::from_u32(0x00)); + pub const WT_CLOSE_SESSION: Self = Self(VarInt::from_u32(0x2843)); + pub const WT_DRAIN_SESSION: Self = Self(VarInt::from_u32(0x78ae)); + pub const WT_MAX_STREAMS_BIDI: Self = Self(VarInt::from_u32(0x190b4d3f)); + pub const WT_MAX_STREAMS_UNI: Self = Self(VarInt::from_u32(0x190b4d40)); + pub const WT_STREAMS_BLOCKED_BIDI: Self = Self(VarInt::from_u32(0x190b4d43)); + pub const WT_STREAMS_BLOCKED_UNI: Self = Self(VarInt::from_u32(0x190b4d44)); + + pub const fn into_inner(self) -> VarInt { + self.0 + } +} + +impl From for CapsuleType { + fn from(value: VarInt) -> Self { + Self(value) + } +} + +impl From for VarInt { + fn from(value: CapsuleType) -> Self { + value.0 + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Capsule { + r#type: CapsuleType, + length: VarInt, + payload: P, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct CapsuleHeader { + r#type: CapsuleType, + length: VarInt, +} + +impl CapsuleHeader { + pub const fn r#type(&self) -> CapsuleType { + self.r#type + } + + pub const fn length(&self) -> VarInt { + self.length + } +} + +impl Capsule

{ + pub fn new(r#type: CapsuleType, payload: P) -> Result + where + P: Buf + Sized, + { + let length = VarInt::try_from(payload.remaining())?; + Ok(Self { + r#type, + length, + payload, + }) + } + + pub const fn r#type(&self) -> CapsuleType { + self.r#type + } + + pub const fn length(&self) -> VarInt { + self.length + } + + pub const fn payload(&self) -> &P { + &self.payload + } + + pub fn into_payload(self) -> P + where + P: Sized, + { + self.payload + } + + pub fn map(self, map: impl FnOnce(P) -> U) -> Capsule + where + P: Sized, + { + Capsule { + r#type: self.r#type, + length: self.length, + payload: map(self.payload), + } + } +} + +impl Capsule { + pub async fn skip_from(stream: S, max_skip_chunk: VarInt) -> Result + where + S: AsyncRead + Unpin + Send, + { + let mut stream = stream; + let r#type = CapsuleType::from(stream.decode_one::().await?); + let length = stream.decode_one::().await?; + let mut remaining = length.into_inner(); + let scratch_len = max_skip_chunk.into_inner().clamp(1, READ_CHUNK_SIZE) as usize; + let mut scratch = vec![0; scratch_len]; + while remaining > 0 { + let len = remaining.min(scratch.len() as u64) as usize; + stream.read_exact(&mut scratch[..len]).await?; + remaining -= len as u64; + } + Ok(CapsuleHeader { r#type, length }) + } +} + +impl<'s, P, S> EncodeInto<&'s mut S> for Capsule

+where + P: Buf + Send, + S: AsyncWrite + Unpin + Send, +{ + type Output = (); + type Error = io::Error; + + async fn encode_into(self, stream: &'s mut S) -> Result { + let Capsule { + r#type, + length, + mut payload, + } = self; + stream.encode_one(r#type.into_inner()).await?; + stream.encode_one(length).await?; + while payload.has_remaining() { + let chunk = payload.chunk(); + stream.write_all(chunk).await?; + let len = chunk.len(); + payload.advance(len); + } + Ok(()) + } +} + +impl

EncodeInto for Capsule

+where + P: Buf + Send, +{ + type Output = BufList; + type Error = Infallible; + + async fn encode_into(self, mut stream: BufList) -> Result { + stream + .encode_one(self) + .await + .expect("encoding a capsule into a BufList is infallible"); + Ok(stream) + } +} + +impl DecodeFrom for Capsule +where + S: AsyncRead + Unpin + Send, +{ + type Error = io::Error; + + async fn decode_from(mut stream: S) -> Result { + let r#type = CapsuleType::from(stream.decode_one::().await?); + let length = stream.decode_one::().await?; + let mut remaining = length.into_inner(); + let mut payload = BufList::new(); + while remaining > 0 { + let len = remaining.min(READ_CHUNK_SIZE) as usize; + let mut bytes = vec![0; len]; + stream.read_exact(&mut bytes).await?; + payload.write(Bytes::from(bytes)); + remaining -= len as u64; + } + Ok(Self { + r#type, + length, + payload, + }) + } +} + +#[cfg(test)] +mod tests { + use bytes::{Buf, Bytes}; + use futures::{Stream, stream}; + + use super::*; + use crate::{ + buflist::BufList, + codec::{DecodeExt, EncodeExt, StreamReader}, + quic, + varint::VarInt, + webtransport::{CloseSession, WebTransportStreamCount}, + }; + + #[test] + fn capsule_type_constants_use_draft_codepoints() { + assert_eq!(CapsuleType::DATAGRAM.into_inner(), VarInt::from_u32(0x00)); + assert_eq!( + CapsuleType::WT_CLOSE_SESSION.into_inner(), + VarInt::from_u32(0x2843) + ); + assert_eq!( + CapsuleType::WT_DRAIN_SESSION.into_inner(), + VarInt::from_u32(0x78ae) + ); + assert_eq!( + CapsuleType::WT_MAX_STREAMS_BIDI.into_inner(), + VarInt::from_u32(0x190b4d3f) + ); + assert_eq!( + CapsuleType::WT_MAX_STREAMS_UNI.into_inner(), + VarInt::from_u32(0x190b4d40) + ); + assert_eq!( + CapsuleType::WT_STREAMS_BLOCKED_BIDI.into_inner(), + VarInt::from_u32(0x190b4d43) + ); + assert_eq!( + CapsuleType::WT_STREAMS_BLOCKED_UNI.into_inner(), + VarInt::from_u32(0x190b4d44) + ); + } + + #[tokio::test] + async fn capsule_encode_decode_round_trips_unknown_types() { + fn byte_stream( + data: impl IntoIterator, + ) -> impl Stream> { + stream::iter(data.into_iter().map(|byte| Ok(Bytes::from(vec![byte])))) + } + + let mut payload = BufList::new(); + payload.write(Bytes::from_static(b"hello")); + let mut encoded = BufList::new() + .encode( + Capsule::new(CapsuleType::from(VarInt::from_u32(0x2f)), payload).expect("capsule"), + ) + .await + .expect("encode"); + let bytes = encoded.copy_to_bytes(encoded.remaining()); + let mut reader = StreamReader::new(byte_stream(bytes)); + + let decoded = reader + .decode_one::>() + .await + .expect("decode"); + + assert_eq!(decoded.r#type(), CapsuleType::from(VarInt::from_u32(0x2f))); + assert_eq!(decoded.length(), VarInt::from_u32(5)); + let mut payload = decoded.into_payload(); + assert_eq!(payload.copy_to_bytes(5), Bytes::from_static(b"hello")); + } + + #[tokio::test] + async fn close_session_capsule_payload_round_trips_u32_and_utf8_message() { + let close = CloseSession::try_from((7_u32, "done")).expect("valid close"); + let payload = BufList::new() + .encode(close.clone()) + .await + .expect("close session encoding succeeds"); + + let decoded = payload + .decode::() + .await + .expect("close session payload decodes"); + + assert_eq!(decoded, close); + } + + #[tokio::test] + async fn stream_count_payload_round_trips_varint_without_u64_conversion() { + let count = WebTransportStreamCount::try_from(VarInt::from_u32(13)).expect("valid count"); + let payload = BufList::new() + .encode(count) + .await + .expect("stream count encoding succeeds"); + + let decoded = payload + .decode::() + .await + .expect("stream count payload decodes"); + + assert_eq!(decoded.into_varint(), VarInt::from_u32(13)); + } + + #[tokio::test] + async fn capsule_payload_skip_does_not_materialize_unknown_payload() { + let mut payload = BufList::new(); + payload.write(Bytes::from_static(b"unknown")); + let mut encoded = BufList::new() + .encode( + Capsule::new(CapsuleType::from(VarInt::from_u32(0x2f)), payload).expect("capsule"), + ) + .await + .expect("encode"); + + let skipped = Capsule::skip_from(&mut encoded, VarInt::from_u32(1024)) + .await + .expect("unknown capsule can be skipped"); + + assert_eq!(skipped.r#type(), CapsuleType::from(VarInt::from_u32(0x2f))); + assert_eq!(skipped.length(), VarInt::from_u32(7)); + assert_eq!(encoded.remaining(), 0); + } +} diff --git a/src/dhttp/webtransport/settings.rs b/src/dhttp/webtransport/settings.rs new file mode 100644 index 0000000..a2926a3 --- /dev/null +++ b/src/dhttp/webtransport/settings.rs @@ -0,0 +1,246 @@ +use crate::{ + dhttp::settings::{Setting, SettingId, Settings}, + extended_connect::settings::EnableConnectProtocol, + varint::VarInt, +}; + +/// `SETTINGS_ENABLE_WEBTRANSPORT` (0x2c7cf000). No default. +/// +/// Indicates support for WebTransport over HTTP/3 draft-15. The value MUST be 0 or 1. +pub struct EnableWebTransport; + +impl EnableWebTransport { + pub const ID: VarInt = VarInt::from_u32(0x2c7cf000); + + pub const fn setting(enabled: bool) -> Setting { + Setting::new(Self::ID, VarInt::from_u32(enabled as u32)) + } +} + +impl SettingId for EnableWebTransport { + type Value = bool; + + fn id(&self) -> VarInt { + Self::ID + } + + fn value_from(&self, settings: &Settings) -> bool { + settings + .get_raw(Self::ID) + .is_some_and(|value| value == VarInt::from_u32(1)) + } +} + +/// `SETTINGS_WT_INITIAL_MAX_STREAMS_UNI` (0x2b64). Default: 0. +pub struct InitialMaxStreamsUni; + +impl InitialMaxStreamsUni { + pub const ID: VarInt = VarInt::from_u32(0x2b64); + pub const DEFAULT: VarInt = VarInt::from_u32(0); + + pub const fn setting(value: VarInt) -> Setting { + Setting::new(Self::ID, value) + } +} + +impl SettingId for InitialMaxStreamsUni { + type Value = VarInt; + + fn id(&self) -> VarInt { + Self::ID + } + + fn value_from(&self, settings: &Settings) -> VarInt { + settings.get_raw(Self::ID).unwrap_or(Self::DEFAULT) + } +} + +/// `SETTINGS_WT_INITIAL_MAX_STREAMS_BIDI` (0x2b65). Default: 0. +pub struct InitialMaxStreamsBidi; + +impl InitialMaxStreamsBidi { + pub const ID: VarInt = VarInt::from_u32(0x2b65); + pub const DEFAULT: VarInt = VarInt::from_u32(0); + + pub const fn setting(value: VarInt) -> Setting { + Setting::new(Self::ID, value) + } +} + +impl SettingId for InitialMaxStreamsBidi { + type Value = VarInt; + + fn id(&self) -> VarInt { + Self::ID + } + + fn value_from(&self, settings: &Settings) -> VarInt { + settings.get_raw(Self::ID).unwrap_or(Self::DEFAULT) + } +} + +/// `SETTINGS_WT_INITIAL_MAX_DATA` (0x2b61). Default: 0. +pub struct InitialMaxData; + +impl InitialMaxData { + pub const ID: VarInt = VarInt::from_u32(0x2b61); + pub const DEFAULT: VarInt = VarInt::from_u32(0); + + pub const fn setting(value: VarInt) -> Setting { + Setting::new(Self::ID, value) + } +} + +impl SettingId for InitialMaxData { + type Value = VarInt; + + fn id(&self) -> VarInt { + Self::ID + } + + fn value_from(&self, settings: &Settings) -> VarInt { + settings.get_raw(Self::ID).unwrap_or(Self::DEFAULT) + } +} + +/// HTTP/3 settings fragment advertising WebTransport support. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct WebTransportSupport { + pub initial_max_streams_bidi: VarInt, + pub initial_max_streams_uni: VarInt, + pub initial_max_data: VarInt, +} + +impl Default for WebTransportSupport { + fn default() -> Self { + Self { + initial_max_streams_bidi: VarInt::from_u32(16), + initial_max_streams_uni: VarInt::from_u32(16), + initial_max_data: VarInt::MAX, + } + } +} + +impl IntoIterator for WebTransportSupport { + type Item = Setting; + type IntoIter = std::array::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + [ + EnableConnectProtocol::setting(true), + EnableWebTransport::setting(true), + InitialMaxStreamsBidi::setting(self.initial_max_streams_bidi), + InitialMaxStreamsUni::setting(self.initial_max_streams_uni), + InitialMaxData::setting(self.initial_max_data), + ] + .into_iter() + } +} + +impl Settings { + pub fn enable_webtransport(&self) -> bool { + self.get(EnableWebTransport) + } + + pub fn wt_initial_max_streams_uni(&self) -> VarInt { + self.get(InitialMaxStreamsUni) + } + + pub fn wt_initial_max_streams_bidi(&self) -> VarInt { + self.get(InitialMaxStreamsBidi) + } + + pub fn wt_initial_max_data(&self) -> VarInt { + self.get(InitialMaxData) + } + + pub fn webtransport_flow_control_enabled(&self) -> bool { + self.wt_initial_max_streams_uni() != VarInt::from_u32(0) + || self.wt_initial_max_streams_bidi() != VarInt::from_u32(0) + || self.wt_initial_max_data() != VarInt::from_u32(0) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::varint::VarInt; + + #[test] + fn enable_webtransport_uses_draft15_codepoint() { + assert_eq!(EnableWebTransport::ID.into_inner(), 0x2c7cf000); + + let mut settings = Settings::default(); + assert!(!settings.enable_webtransport()); + + settings.set(EnableWebTransport::setting(true)); + assert!(settings.enable_webtransport()); + assert_eq!( + settings.get(VarInt::from_u32(0x2c7cf000)), + Some(VarInt::from_u32(1)), + ); + assert_eq!(settings.get(VarInt::from_u32(0x2b603742)), None); + } + + #[test] + fn enable_webtransport_exposes_id_and_treats_only_one_as_enabled() { + assert_eq!(EnableWebTransport.id(), EnableWebTransport::ID); + + let mut settings = Settings::default(); + settings.set(EnableWebTransport::setting(false)); + assert!(!settings.enable_webtransport()); + assert_eq!( + settings.get(VarInt::from_u32(0x2c7cf000)), + Some(VarInt::from_u32(0)), + ); + + settings.set(Setting::new(EnableWebTransport::ID, VarInt::from_u32(2))); + assert!(!settings.enable_webtransport()); + } + + #[test] + fn webtransport_flow_control_settings_apply_draft_defaults() { + let settings = Settings::default(); + + assert_eq!(settings.wt_initial_max_streams_uni(), VarInt::from_u32(0)); + assert_eq!(settings.wt_initial_max_streams_bidi(), VarInt::from_u32(0)); + assert_eq!(settings.wt_initial_max_data(), VarInt::from_u32(0)); + assert!(!settings.webtransport_flow_control_enabled()); + } + + #[test] + fn webtransport_flow_control_settings_use_typed_accessors() { + let mut settings = Settings::default(); + settings.set(InitialMaxStreamsUni::setting(VarInt::from_u32(11))); + settings.set(InitialMaxStreamsBidi::setting(VarInt::from_u32(13))); + settings.set(InitialMaxData::setting(VarInt::MAX)); + + assert_eq!(settings.wt_initial_max_streams_uni(), VarInt::from_u32(11)); + assert_eq!(settings.wt_initial_max_streams_bidi(), VarInt::from_u32(13)); + assert_eq!(settings.wt_initial_max_data(), VarInt::MAX); + assert!(settings.webtransport_flow_control_enabled()); + assert_eq!(InitialMaxStreamsUni.id(), VarInt::from_u32(0x2b64)); + assert_eq!(InitialMaxStreamsBidi.id(), VarInt::from_u32(0x2b65)); + assert_eq!(InitialMaxData.id(), VarInt::from_u32(0x2b61)); + } + + #[test] + fn webtransport_support_is_a_composable_settings_fragment() { + let settings = Settings::default() + .with_all(WebTransportSupport::default()) + .with(crate::dhttp::settings::MaxFieldSectionSize::setting( + VarInt::from_u32(4096), + )); + + assert!(settings.enable_connect_protocol()); + assert!(settings.enable_webtransport()); + assert_eq!(settings.wt_initial_max_streams_bidi(), VarInt::from_u32(16)); + assert_eq!(settings.wt_initial_max_streams_uni(), VarInt::from_u32(16)); + assert_eq!(settings.wt_initial_max_data(), VarInt::MAX); + assert!(settings.webtransport_flow_control_enabled()); + assert_eq!( + settings.max_field_section_size(), + Some(VarInt::from_u32(4096)), + ); + } +} diff --git a/src/dquic.rs b/src/dquic.rs index d42f8c8..5bd21db 100644 --- a/src/dquic.rs +++ b/src/dquic.rs @@ -1,7 +1,123 @@ -mod client; -mod server; +pub mod binds; +pub mod client; +pub mod common; +mod endpoint; +pub mod identity; +pub mod network; +pub mod server; +pub mod sni; + mod shim; -pub use client::*; pub use dquic::*; -pub use server::*; +pub use endpoint::*; +pub use identity::*; +pub use network::*; + +/// dquic transport parameters — client/server/peer parameter sets, IDs, value types +pub mod param { + pub use dquic::qbase::param::{ + ClientParameters, ParameterId, ParameterValue, ParameterValueType, PeerParameters, + ServerParameters, error::Error as ParamError, preferred_address::PreferredAddress, + }; + + /// convenience constructors for parameters + pub mod handy { + pub use dquic::qbase::param::handy::{client_parameters, server_parameters}; + } +} + +/// dquic token types — address validation tokens +pub mod token { + pub use dquic::qbase::token::{TokenProvider, TokenSink}; + + /// convenience token implementations + pub mod handy { + pub use dquic::qbase::token::handy::*; + } +} + +/// dquic TLS / client authentication types +pub mod tls { + pub use dquic::qconnection::tls::{ + AuthClient, ClientAuthorityVerifyResult, ClientNameVerifyResult, LocalAuthority, + RemoteAuthority, + }; + + pub mod handy { + pub use dquic::qconnection::tls::AcceptAllClientAuther; + } +} + +/// dquic stream concurrency types +pub mod stream { + pub use dquic::{ + prelude::VarInt, + qbase::sid::{ + ControlStreamsConcurrency, Dir, ProductStreamsConcurrencyController, StreamId, + }, + }; + + pub mod handy { + pub use dquic::qbase::sid::handy::*; + } +} + +/// dquic telemetry / logging types +pub mod log { + pub use dquic::qevent::telemetry::{ExportEvent, QLog, Span}; + + pub mod handy { + pub use dquic::qevent::telemetry::handy::*; + } +} + +/// dquic DNS resolution types +pub mod resolver { + pub use dquic::qresolve::{ + Publish, PublishFuture, Record, RecordStream, Resolve, ResolveFuture, ResolveResult, Source, + }; + + pub mod handy { + pub use dquic::qresolve::SystemResolver; + } +} + +/// dquic network address, binding, interface, and IO types +pub mod net { + pub use dquic::{ + prelude::{IO, IoExt}, + qbase::{ + cid::ConnectionId, + net::{self, Family, addr::EndpointAddr}, + }, + qinterface::{ + BindInterface, + bind_uri::{BindUri, ParseError, Scheme}, + component::{location::Locations, route::QuicRouter}, + device::Devices, + io::ProductIO, + manager::InterfaceManager, + }, + }; + + pub mod handy { + pub use dquic::qinterface::io::handy::*; + } +} + +/// dquic certificate utilities +pub mod cert { + /// convenience certificate helpers + pub mod handy { + pub use dquic::prelude::handy::{ToCertificate, ToPrivateKey}; + } +} + +/// dquic connection types +pub mod connection { + pub use dquic::prelude::Connection; +} + +/// Type alias for the default `H3Endpoint`. +pub type H3Endpoint = crate::endpoint::H3Endpoint; diff --git a/src/endpoint/binds/mod.rs b/src/dquic/binds.rs similarity index 75% rename from src/endpoint/binds/mod.rs rename to src/dquic/binds.rs index 50755e8..6dbbbbe 100644 --- a/src/endpoint/binds/mod.rs +++ b/src/dquic/binds.rs @@ -1,6 +1,6 @@ //! Extended bind pattern for flexible BindUri generation. //! -//! [`Bind`] is a pattern-like extension of +//! [`BindPattern`] is a pattern-like extension of //! [`BindUri`](crate::dquic::qinterface::bind_uri::BindUri) that provides: //! //! 1. **Glob host** — `iface://v4.en*:8080` matches all interfaces starting with "en" @@ -14,21 +14,11 @@ //! //! All extensions compose freely: `en*:8080`, `*`, `v4.*:8080`, `[ew]*`, `[::1]:8080`, etc. -mod collection; -mod error; mod host; mod pattern; -pub mod setup; -pub use std::net::IpAddr; - -pub use collection::Binds; -pub use error::BindConflictError; pub use host::BindHost; -pub use pattern::Bind; -pub use setup::{ - BindSetup, setup_bind_interfaces, setup_bind_interfaces_with, watch_bind_interfaces, -}; +pub use pattern::BindPattern; #[cfg(test)] mod tests; diff --git a/src/endpoint/binds/host.rs b/src/dquic/binds/host.rs similarity index 54% rename from src/endpoint/binds/host.rs rename to src/dquic/binds/host.rs index f767700..74af079 100644 --- a/src/endpoint/binds/host.rs +++ b/src/dquic/binds/host.rs @@ -1,10 +1,14 @@ -use std::{fmt, net::IpAddr}; +use std::{ + fmt, + hash::{Hash, Hasher}, + net::IpAddr, +}; use globset::{Glob, GlobMatcher}; use crate::dquic::qbase::net::Family; -/// The host part of a [`Bind`](super::Bind) — a parsed IP address, a glob pattern, or an exact name. +/// The host part of a [`BindPattern`](super::BindPattern) — a parsed IP address, a glob pattern, or an exact name. /// /// Literal host names (e.g. `enp17s0`) are normally represented as /// [`Glob`](BindHost::Glob) with a literal pattern (globset treats them as exact @@ -189,6 +193,25 @@ impl PartialEq for BindHost { impl Eq for BindHost {} +impl Hash for BindHost { + fn hash(&self, state: &mut H) { + std::mem::discriminant(self).hash(state); + match self { + Self::Ip { repr, .. } => repr.hash(state), + Self::Glob { + family, matcher, .. + } => { + family.hash(state); + matcher.glob().glob().hash(state); + } + Self::Exact { family, nic, .. } => { + family.hash(state); + nic.hash(state); + } + } + } +} + impl fmt::Debug for BindHost { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { @@ -216,3 +239,132 @@ impl fmt::Display for BindHost { f.write_str(self.as_str()) } } + +#[cfg(test)] +mod tests { + use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, + }; + + use super::*; + use crate::dquic::qbase::net::Family; + + fn hash(host: &BindHost) -> u64 { + let mut hasher = DefaultHasher::new(); + host.hash(&mut hasher); + hasher.finish() + } + + #[test] + fn classify_ip_and_glob_variants() { + let ipv4 = BindHost::classify("127.0.0.1", None).unwrap(); + assert!(ipv4.is_ip_addr()); + assert!(!ipv4.is_glob()); + assert_eq!(ipv4.as_str(), "127.0.0.1"); + assert_eq!(ipv4.as_ip_addr(), Some("127.0.0.1".parse().unwrap())); + + let ipv6 = BindHost::classify("[::1]", None).unwrap(); + assert!(ipv6.is_ip_addr()); + assert_eq!(ipv6.as_str(), "::1"); + assert_eq!(ipv6.families(), [Family::V4, Family::V6]); + + let glob = BindHost::classify("en*", None).unwrap(); + assert!(!glob.is_ip_addr()); + assert!(glob.is_glob()); + assert_eq!(glob.as_str(), "en*"); + assert_eq!(glob.families(), [Family::V4, Family::V6]); + } + + #[test] + fn classify_rejects_family_for_ip_and_falls_back_to_exact_on_bad_glob() { + let err = BindHost::classify("127.0.0.1", Some(Family::V4)) + .expect_err("family prefix should be rejected for plain IP address"); + assert_eq!(err, "family prefix is not valid for IP addresses"); + + let err = BindHost::classify("[::1]", Some(Family::V6)) + .expect_err("family prefix should be rejected for bracketed IPv6"); + assert_eq!(err, "family prefix is not valid for IP addresses"); + + let exact = BindHost::classify("[", None).unwrap(); + assert!(matches!(exact, BindHost::Exact { family: None, ref nic } if nic == "[")); + assert!(!exact.is_glob()); + assert_eq!(exact.family(), None); + assert_eq!(exact.families(), [Family::V4, Family::V6]); + } + + #[test] + fn host_matching_and_identity_accessors() { + let exact = BindHost::classify("enp17s0", Some(Family::V4)).unwrap(); + assert_eq!(exact.family(), Some(Family::V4)); + assert_eq!(exact.as_str(), "enp17s0"); + assert!(!exact.matches("eth0")); + assert!(exact.matches("enp17s0")); + + let ip = BindHost::classify("192.168.0.1", None).unwrap(); + assert!(ip.is_ip_addr()); + assert_eq!(ip.as_ip_addr().unwrap().to_string(), "192.168.0.1"); + assert!(!ip.matches("192.168.0.1")); + } + + #[test] + fn exact_host_with_v6_family_matches_and_compares_by_family() { + let exact = BindHost::classify("[", Some(Family::V6)).unwrap(); + assert_eq!(exact.family(), Some(Family::V6)); + assert_eq!(exact.families(), [Family::V6]); + assert!(exact.matches("[")); + assert!(!exact.matches("]")); + assert_eq!(exact, BindHost::classify("[", Some(Family::V6)).unwrap()); + assert_ne!(exact, BindHost::classify("[", Some(Family::V4)).unwrap()); + } + + #[test] + fn host_display_debug_and_hash() { + let literal = BindHost::classify("enp17s0", None).unwrap(); + assert_eq!(format!("{literal}"), "enp17s0"); + assert!( + matches!(&format!("{literal:?}"), s if s.contains("Glob") && s.contains("pattern")) + ); + assert_eq!( + hash(&literal), + hash(&BindHost::classify("enp17s0", None).unwrap()) + ); + + let exact = BindHost::classify("[", None).unwrap(); + assert_eq!(format!("{exact}"), "["); + assert_eq!(exact.as_str(), "["); + assert!(matches!(&format!("{exact:?}"), s if s.contains("Exact") && s.contains("nic"))); + assert_eq!(hash(&exact), hash(&BindHost::classify("[", None).unwrap())); + + let ip = BindHost::classify("::1", None).unwrap(); + let ip_debug = format!("{ip:?}"); + assert!(ip_debug.contains("Ip")); + assert!(ip_debug.contains("addr")); + assert!(ip_debug.contains("repr")); + assert_eq!(format!("{ip}"), "::1"); + assert_eq!(hash(&ip), hash(&BindHost::classify("::1", None).unwrap())); + } + + #[test] + fn partial_eq_for_all_variants() { + let exact_a = BindHost::classify("[", None).unwrap(); + let exact_b = BindHost::classify("[", None).unwrap(); + assert_eq!(exact_a, exact_b); + + let exact_c = BindHost::classify("[", Some(Family::V4)).unwrap(); + assert_ne!(exact_a, exact_c); + + let glob_a = BindHost::classify("en*", None).unwrap(); + let glob_b = BindHost::classify("en*", Some(Family::V6)).unwrap(); + assert_ne!(glob_a, glob_b); + assert_eq!(glob_a, BindHost::classify("en*", None).unwrap()); + assert_ne!(exact_a, glob_a); + + let ip_a = BindHost::classify("::1", None).unwrap(); + let ip_b = BindHost::classify("::1", None).unwrap(); + let ipv4 = BindHost::classify("127.0.0.1", None).unwrap(); + assert_eq!(ip_a, ip_b); + assert_ne!(ip_a, ipv4); + assert_ne!(glob_a, ip_a); + } +} diff --git a/src/endpoint/binds/pattern.rs b/src/dquic/binds/pattern.rs similarity index 59% rename from src/endpoint/binds/pattern.rs rename to src/dquic/binds/pattern.rs index 56ef8b0..63963d9 100644 --- a/src/endpoint/binds/pattern.rs +++ b/src/dquic/binds/pattern.rs @@ -1,4 +1,9 @@ -use std::{fmt, str::FromStr}; +use std::{ + cell::LazyCell, + fmt, + hash::{Hash, Hasher}, + str::FromStr, +}; use either::Either; use http::{ @@ -9,17 +14,17 @@ use peg::{error::ParseError, str::LineCol}; use super::BindHost; use crate::dquic::{ - qbase::net::Family, - qinterface::bind_uri::{BindUri, BindUriScheme}, + net::Family, + qinterface::bind_uri::{BindUri, Scheme}, }; /// A flexible bind pattern parsed from a string. /// /// See [module documentation](super) for the full syntax description. #[derive(Debug, Clone, PartialEq, Eq)] -pub struct Bind { +pub struct BindPattern { /// The resolved scheme (`iface` or `inet`). Always present after parsing. - pub scheme: BindUriScheme, + pub scheme: Scheme, /// Host part — exact name/IP or glob pattern (carries family if applicable). pub host: BindHost, /// Port number. `None` means default (0 = system-assigned). @@ -29,11 +34,69 @@ pub struct Bind { pub path_and_query: Option, } +impl Hash for BindPattern { + fn hash(&self, state: &mut H) { + self.scheme.hash(state); + self.host.hash(state); + self.port.hash(state); + self.path_and_query + .as_ref() + .map(|pq| pq.as_str()) + .hash(state); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn match_helpers_reject_uri_of_other_shape() { + let iface_pattern: BindPattern = "iface://v4.en*:8080".parse().unwrap(); + let inet_uri: BindUri = "inet://127.0.0.1:8080".parse().unwrap(); + assert!(!iface_pattern.matches_iface_bind_uri(&inet_uri)); + + let inet_pattern: BindPattern = "inet://127.0.0.1:8080".parse().unwrap(); + let iface_uri: BindUri = "iface://v4.enp17s0:8080".parse().unwrap(); + assert!(!inet_pattern.matches_inet_bind_uri(&iface_uri)); + } + + #[test] + fn ip_hosts_do_not_match_interface_links() { + let pattern: BindPattern = "127.0.0.1:8080".parse().unwrap(); + + assert_eq!(pattern.match_interface_links("lo").count(), 0); + } + + #[test] + fn interface_bind_uris_expand_only_iface_patterns() { + let iface_pattern: BindPattern = "iface://v4.lo:8080".parse().unwrap(); + let inet_pattern: BindPattern = "inet://127.0.0.1:8080".parse().unwrap(); + + let iface_uris: Vec<_> = iface_pattern + .interface_bind_uris("lo") + .map(|uri| uri.to_string()) + .collect(); + + assert_eq!(iface_uris, ["iface://v4.lo:8080/"]); + assert!(inet_pattern.interface_bind_uris("lo").next().is_none()); + } + + #[test] + fn unknown_explicit_scheme_falls_back_to_iface() { + let pattern: BindPattern = "custom://v4.en*:8080/path?query".parse().unwrap(); + + assert_eq!(pattern.scheme, Scheme::Iface); + assert_eq!(pattern.path_and_query_str(), Some("/path?query")); + assert_eq!(pattern.to_string(), "iface://v4.en*:8080/path?query"); + } +} + // --------------------------------------------------------------------------- // Display // --------------------------------------------------------------------------- -impl fmt::Display for Bind { +impl fmt::Display for BindPattern { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}://", self.scheme)?; if let Some(family) = self.host.family() { @@ -106,7 +169,7 @@ peg::parser! { // -- composite rules -- /// `scheme://family.host:port/path?query` (full form) - pub rule full() -> Bind + pub rule full() -> BindPattern = s:scheme() fam:(f:family() "." { f })? h:host_str() @@ -119,11 +182,11 @@ peg::parser! { .map(|s| s.parse::()) .transpose() .map_err(|_| "valid path-and-query")?; - Ok(Bind { scheme, host, port: p, path_and_query }) + Ok(BindPattern { scheme, host, port: p, path_and_query }) } /// `family.host:port/path?query` (no scheme) - pub rule no_scheme() -> Bind + pub rule no_scheme() -> BindPattern = fam:(f:family() "." { f })? h:host_str() p:port()? @@ -135,14 +198,14 @@ peg::parser! { .map(|s| s.parse::()) .transpose() .map_err(|_| "valid path-and-query")?; - Ok(Bind { scheme, host, port: p, path_and_query }) + Ok(BindPattern { scheme, host, port: p, path_and_query }) } /// Top-level entry: bare IP first, then full form, then no-scheme. /// /// `bare_ip` has highest priority — its `{? ... }` semantic guard /// ensures only valid IP addresses match; everything else backtracks. - pub rule bind() -> Bind + pub rule bind() -> BindPattern = b:bare_ip() { b } / b:full() { b } / b:no_scheme() { b } @@ -152,15 +215,15 @@ peg::parser! { /// Captures everything up to `/`, `?`, or `#` (or end of input) and /// validates it as an [`IpAddr`]. Falls back via PEG ordered choice /// if validation fails. - rule bare_ip() -> Bind + rule bare_ip() -> BindPattern = s:$([^ '/' | '?' | '#']+) pq:path_and_query()? {? let addr = s.parse::().or(Err("valid IP address"))?; let path_and_query = pq .map(|s| s.parse::()) .transpose() .map_err(|_| "valid path-and-query")?; - Ok(Bind { - scheme: BindUriScheme::Inet, + Ok(BindPattern { + scheme: Scheme::Inet, host: BindHost::Ip { addr, repr: s.to_owned() }, port: None, path_and_query, @@ -174,18 +237,18 @@ peg::parser! { // --------------------------------------------------------------------------- /// Infer the bind scheme from an optional explicit scheme string and the host. -fn infer_scheme(explicit: Option<&str>, host: &BindHost) -> BindUriScheme { +fn infer_scheme(explicit: Option<&str>, host: &BindHost) -> Scheme { if let Some(s) = explicit { return match s.to_ascii_lowercase().as_str() { - "iface" => BindUriScheme::Iface, - "inet" => BindUriScheme::Inet, - _ => BindUriScheme::Iface, + "iface" => Scheme::Iface, + "inet" => Scheme::Inet, + _ => Scheme::Iface, }; } if host.is_ip_addr() { - BindUriScheme::Inet + Scheme::Inet } else { - BindUriScheme::Iface + Scheme::Iface } } @@ -193,7 +256,7 @@ fn infer_scheme(explicit: Option<&str>, host: &BindHost) -> BindUriScheme { // FromStr // --------------------------------------------------------------------------- -impl FromStr for Bind { +impl FromStr for BindPattern { type Err = ParseError; fn from_str(s: &str) -> Result { @@ -202,23 +265,96 @@ impl FromStr for Bind { } // --------------------------------------------------------------------------- -// Bind → BindUri expansion +// BindPattern → BindUri expansion // --------------------------------------------------------------------------- -impl Bind { +impl BindPattern { + /// Returns the path-and-query as a string slice, if present. + #[must_use] + pub fn path_and_query_str(&self) -> Option<&str> { + self.path_and_query.as_ref().map(|pq| pq.as_str()) + } + /// Returns the effective port (defaults to 0 when omitted). #[must_use] pub fn effective_port(&self) -> u16 { self.port.unwrap_or(0) } - /// Returns the path-and-query as a `&str`. + /// Check if a concrete [`BindUri`] could be produced by this pattern. + /// + /// Compares scheme, port, and host. Wildcard ports (None) match any port. + /// For `iface://` URIs, the family prefix is matched separately from the + /// interface name. Glob/exact hosts use [`BindHost::matches`] for pattern + /// matching. #[must_use] - pub fn path_and_query_str(&self) -> Option<&str> { - self.path_and_query.as_ref().map(|pq| pq.as_str()) + pub fn matches(&self, bind_uri: &BindUri) -> bool { + if self.scheme != bind_uri.scheme() { + return false; + } + match self.scheme { + Scheme::Iface => self.matches_iface_bind_uri(bind_uri), + Scheme::Inet => self.matches_inet_bind_uri(bind_uri), + _ => false, + } + } + + pub(crate) fn interface_bind_uris<'a>( + &'a self, + interface: &'a str, + ) -> impl Iterator + use<'a> { + let template = self.template(); + match self.scheme { + Scheme::Iface => { + Either::Left(self.match_interface_links(interface).filter_map(template)) + } + _ => Either::Right(std::iter::empty()), + } + } + + fn port_matches(&self, actual: u16) -> bool { + if let Some(expected) = self.port + && expected != actual + { + return false; + } + true + } + + fn matches_iface_bind_uri(&self, bind_uri: &BindUri) -> bool { + let Some((family, interface, port)) = bind_uri.as_iface_bind_uri() else { + return false; + }; + if !self.port_matches(port) { + return false; + } + match &self.host { + BindHost::Ip { .. } => false, + host => { + if let Some(pattern_family) = host.family() + && pattern_family != family + { + return false; + } + host.matches(interface) + } + } } - pub(crate) fn bind_uri_template(&self) -> impl Fn(Authority) -> Option + use<> { + fn matches_inet_bind_uri(&self, bind_uri: &BindUri) -> bool { + let Some(addr) = bind_uri.as_inet_bind_uri() else { + return false; + }; + if !self.port_matches(addr.port()) { + return false; + } + match &self.host { + BindHost::Ip { addr: pattern, .. } => *pattern == addr.ip(), + BindHost::Glob { .. } | BindHost::Exact { .. } => false, + } + } + + pub(crate) fn template(&self) -> impl Fn(Authority) -> Option + use<> { let mut uri_template = Uri::from_static("iface://v4.lo:0/").into_parts(); uri_template.scheme = Some(self.scheme.into()); uri_template.path_and_query = @@ -226,24 +362,18 @@ impl Bind { let uri_template = Uri::from_parts(uri_template) .expect("BUG: bind URI template built from valid scheme and path-and-query"); - let port = self.effective_port(); move |authority: Authority| { let mut uri_parts = uri_template.clone().into_parts(); + // original authority is just a placeholder; replace it with the actual authority for every bind URI. uri_parts.authority = Some(authority); - let mut bind_uri = + let bind_uri = (Uri::from_parts(uri_parts).ok()).and_then(|uri| BindUri::try_from(uri).ok())?; - if port == 0 { - bind_uri = bind_uri.alloc_port(); - } Some(bind_uri) } } - pub(crate) fn bind_hosts_for_interface( - &self, - interface: &str, - ) -> impl Iterator { + pub(crate) fn match_interface_links(&self, interface: &str) -> impl Iterator { match &self.host { BindHost::Ip { .. } => Either::Left(std::iter::empty()), host if !host.matches(interface) => Either::Left(std::iter::empty()), @@ -267,24 +397,26 @@ impl Bind { where I: IntoIterator, { - let template = self.bind_uri_template(); + let template = LazyCell::new(|| self.template()); let port = self.effective_port(); match &self.host { BindHost::Ip { addr, .. } => { - let authority: Authority = if addr.is_ipv6() { + let link: Authority = if addr.is_ipv6() { format!("[{addr}]:{port}") } else { format!("{addr}:{port}") } .parse() .expect("BUG: formatted IP address and port is a valid authority"); - Either::Left(template(authority).into_iter()) + Either::Left(template(link).into_iter()) } + // WORKAROUND: clippy bug: https://github.com/rust-lang/rust-clippy/issues/16641 (not fixed) + #[allow(clippy::redundant_closure)] BindHost::Glob { .. } | BindHost::Exact { .. } => Either::Right( interfaces .into_iter() - .flat_map(move |iface| self.bind_hosts_for_interface(iface)) - .flat_map(template), + .flat_map(move |iface| self.match_interface_links(iface)) + .flat_map(move |link| template(link)), ), } } diff --git a/src/dquic/binds/tests.rs b/src/dquic/binds/tests.rs new file mode 100644 index 0000000..27c7107 --- /dev/null +++ b/src/dquic/binds/tests.rs @@ -0,0 +1,726 @@ +#![allow(clippy::type_complexity)] + +use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, + net::IpAddr, +}; + +use super::*; +use crate::dquic::{ + net::Family, + qinterface::bind_uri::{BindUri, Scheme}, +}; + +fn pattern_hash(pattern: &BindPattern) -> u64 { + let mut hasher = DefaultHasher::new(); + pattern.hash(&mut hasher); + hasher.finish() +} + +// ============================================================================ +// Parsing — core parsing tests (from original "Parsing tests" section + +// glob bracket parsing + path_and_query validation) +// ============================================================================ + +#[test] +fn parsing_core() { + let cases: Vec<(&str, Box)> = vec![ + // parse_full_iface_with_family + ( + "iface://v4.enp17s0:8080", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Iface); + assert_eq!(b.host.family(), Some(Family::V4)); + assert_eq!( + b.host, + BindHost::classify("enp17s0", Some(Family::V4)).unwrap() + ); + assert_eq!(b.port, Some(8080)); + assert!(b.path_and_query.is_none()); + }), + ), + // parse_full_iface_glob + ( + "iface://v4.en*:8080", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Iface); + assert_eq!(b.host.family(), Some(Family::V4)); + assert!(b.host.is_glob()); + assert_eq!(b.host.as_str(), "en*"); + assert_eq!(b.port, Some(8080)); + }), + ), + // parse_iface_no_family + ( + "iface://enp17s0:8080", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Iface); + assert_eq!(b.host.family(), None); + assert!(!b.host.is_glob()); + assert_eq!(b.host.as_str(), "enp17s0"); + assert_eq!(b.port, Some(8080)); + }), + ), + // parse_iface_no_port + ( + "iface://v4.enp17s0", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Iface); + assert_eq!(b.host.family(), Some(Family::V4)); + assert_eq!(b.host.as_str(), "enp17s0"); + assert_eq!(b.port, None); + }), + ), + // parse_inet + ( + "inet://127.0.0.1:8080", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Inet); + assert_eq!(b.host.family(), None); + assert_eq!(b.host.as_str(), "127.0.0.1"); + assert_eq!(b.port, Some(8080)); + }), + ), + // parse_no_scheme_ip + ( + "127.0.0.1:8080", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Inet); + assert_eq!(b.host.as_str(), "127.0.0.1"); + assert_eq!(b.port, Some(8080)); + }), + ), + // parse_no_scheme_iface + ( + "enp17s0:8080", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Iface); + assert_eq!(b.host.family(), None); + assert_eq!(b.host.as_str(), "enp17s0"); + assert_eq!(b.port, Some(8080)); + }), + ), + // parse_no_scheme_with_family + ( + "v4.enp17s0:8080", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Iface); + assert_eq!(b.host.family(), Some(Family::V4)); + assert_eq!(b.host.as_str(), "enp17s0"); + assert_eq!(b.port, Some(8080)); + }), + ), + // parse_glob_no_scheme + ( + "en*:8080", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Iface); + assert_eq!(b.host.family(), None); + assert!(b.host.is_glob()); + assert_eq!(b.host.as_str(), "en*"); + assert_eq!(b.port, Some(8080)); + }), + ), + // parse_star_only + ( + "*", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Iface); + assert_eq!(b.host.family(), None); + assert!(b.host.is_glob()); + assert_eq!(b.host.as_str(), "*"); + assert_eq!(b.port, None); + }), + ), + // parse_star_with_port + ( + "*:8080", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Iface); + assert_eq!(b.host.family(), None); + assert!(b.host.is_glob()); + assert_eq!(b.port, Some(8080)); + }), + ), + // parse_v4_star + ( + "v4.*", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Iface); + assert_eq!(b.host.family(), Some(Family::V4)); + assert!(b.host.is_glob()); + assert_eq!(b.port, None); + }), + ), + // parse_v6_star_with_port + ( + "v6.*:8080", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Iface); + assert_eq!(b.host.family(), Some(Family::V6)); + assert!(b.host.is_glob()); + assert_eq!(b.port, Some(8080)); + }), + ), + // parse_no_scheme_no_port + ( + "enp17s0", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Iface); + assert_eq!(b.host.family(), None); + assert_eq!(b.host.as_str(), "enp17s0"); + assert_eq!(b.port, None); + }), + ), + // parse_with_path_and_query + ( + "iface://v4.en*:8080/?stun_server=stun.genmeta.net", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Iface); + assert_eq!(b.host.family(), Some(Family::V4)); + assert!(b.host.is_glob()); + assert_eq!(b.port, Some(8080)); + assert_eq!( + b.path_and_query_str(), + Some("/?stun_server=stun.genmeta.net") + ); + }), + ), + // parse_with_query_only + ( + "iface://v4.enp17s0:8080?stun=true", + Box::new(|b: &BindPattern| { + assert_eq!(b.path_and_query_str(), Some("?stun=true")); + }), + ), + // parse_glob_bracket_class (from "Glob bracket parsing" section) + ( + "[ew]*:8080", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Iface); + assert!(b.host.is_glob()); + assert_eq!(b.host.as_str(), "[ew]*"); + assert_eq!(b.port, Some(8080)); + }), + ), + // path_and_query_is_validated (from "BindUri generation" section) + ( + "iface://v4.en*:8080/?key=value", + Box::new(|b: &BindPattern| { + let pq = b.path_and_query.as_ref().unwrap(); + assert_eq!(pq, "/?key=value"); + }), + ), + ]; + + for (input, check) in &cases { + let b: BindPattern = input + .parse() + .unwrap_or_else(|e| panic!("failed to parse '{input}': {e}")); + check(&b); + } +} + +// ============================================================================ +// Parsing — IPv6 bracket syntax tests +// ============================================================================ + +#[test] +fn parsing_ipv6_bracket() { + let cases: Vec<(&str, Box)> = vec![ + // parse_ipv6_full_scheme + ( + "inet://[::1]:8080", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Inet); + assert_eq!(b.host.family(), None); + assert_eq!(b.host.as_str(), "::1"); + assert_eq!(b.port, Some(8080)); + assert!(b.host.is_ip_addr()); + }), + ), + // parse_ipv6_no_scheme + ( + "[::1]:8080", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Inet); + assert_eq!(b.host.as_str(), "::1"); + assert_eq!(b.port, Some(8080)); + }), + ), + // parse_ipv6_full_addr + ( + "inet://[2001:db8::1]:443", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Inet); + assert_eq!(b.host.as_str(), "2001:db8::1"); + assert_eq!(b.port, Some(443)); + assert!(b.host.is_ip_addr()); + }), + ), + // parse_ipv6_link_local + ( + "[fe80::1]:8080", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Inet); + assert_eq!(b.host.as_str(), "fe80::1"); + assert_eq!(b.port, Some(8080)); + }), + ), + // parse_ipv6_any + ( + "inet://[::]:0", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Inet); + assert_eq!(b.host.as_str(), "::"); + assert_eq!(b.port, Some(0)); + }), + ), + // parse_ipv6_no_port + ( + "[::1]", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Inet); + assert_eq!(b.host.as_str(), "::1"); + assert_eq!(b.port, None); + }), + ), + // parse_ipv6_with_path_and_query + ( + "inet://[::1]:8080/?key=value", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Inet); + assert_eq!(b.host.as_str(), "::1"); + assert_eq!(b.port, Some(8080)); + assert_eq!(b.path_and_query_str(), Some("/?key=value")); + }), + ), + // ipv6_host_is_ip_addr + ( + "[::1]:8080", + Box::new(|b: &BindPattern| { + assert!(b.host.is_ip_addr()); + assert!(b.host.as_ip_addr().unwrap().is_ipv6()); + }), + ), + ]; + + for (input, check) in &cases { + let b: BindPattern = input + .parse() + .unwrap_or_else(|e| panic!("failed to parse '{input}': {e}")); + check(&b); + } +} + +// ============================================================================ +// Parsing — bare IP address tests +// ============================================================================ + +#[test] +fn parsing_bare_ip() { + let cases: Vec<(&str, Box)> = vec![ + // parse_bare_ipv6_loopback + ( + "::1", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Inet); + assert!(b.host.is_ip_addr()); + assert_eq!(b.host.as_str(), "::1"); + assert_eq!(b.port, None); + assert!(b.path_and_query.is_none()); + }), + ), + // parse_bare_ipv6_any + ( + "::", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Inet); + assert!(b.host.is_ip_addr()); + assert_eq!(b.host.as_str(), "::"); + assert_eq!(b.port, None); + }), + ), + // parse_bare_ipv6_full + ( + "2001:db8::1", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Inet); + assert!(b.host.is_ip_addr()); + assert_eq!(b.host.as_str(), "2001:db8::1"); + assert_eq!(b.port, None); + }), + ), + // parse_bare_ipv4 + ( + "192.168.1.1", + Box::new(|b: &BindPattern| { + assert_eq!(b.scheme, Scheme::Inet); + assert!(b.host.is_ip_addr()); + assert_eq!(b.host.as_str(), "192.168.1.1"); + assert_eq!(b.port, None); + }), + ), + ]; + + for (input, check) in &cases { + let b: BindPattern = input + .parse() + .unwrap_or_else(|e| panic!("failed to parse '{input}': {e}")); + check(&b); + } +} + +#[test] +fn parsing_suffix_edge_cases() { + let no_scheme: BindPattern = "V4.enp17s0:8080?reuse=true".parse().unwrap(); + assert_eq!(no_scheme.scheme, Scheme::Iface); + assert_eq!(no_scheme.host.family(), Some(Family::V4)); + assert_eq!(no_scheme.host.as_str(), "enp17s0"); + assert_eq!(no_scheme.port, Some(8080)); + assert_eq!(no_scheme.path_and_query_str(), Some("?reuse=true")); + + let bare_ip: BindPattern = "127.0.0.1?temporary=true".parse().unwrap(); + assert_eq!(bare_ip.scheme, Scheme::Inet); + assert_eq!(bare_ip.host.as_str(), "127.0.0.1"); + assert_eq!(bare_ip.port, None); + assert_eq!(bare_ip.path_and_query_str(), Some("?temporary=true")); +} + +// ============================================================================ +// Display roundtrip +// ============================================================================ + +#[test] +fn display_roundtrip() { + let cases: &[(&str, &str)] = &[ + // display_full + ("iface://v4.enp17s0:8080", "iface://v4.enp17s0:8080"), + // display_normalizes_uppercase_family + ( + "iface://V6.enp17s0:8080?reuse=true", + "iface://v6.enp17s0:8080/?reuse=true", + ), + // display_no_port + ("iface://v4.enp17s0", "iface://v4.enp17s0"), + // display_no_family + ("iface://enp17s0:8080", "iface://enp17s0:8080"), + // display_bare_ipv6 + ("::1", "inet://[::1]"), + // display_ipv6_roundtrip + ("inet://[::1]:8080", "inet://[::1]:8080"), + // display_ipv6_full_addr + ("inet://[2001:db8::1]:443", "inet://[2001:db8::1]:443"), + ]; + + for (input, expected) in cases { + let b: BindPattern = input + .parse() + .unwrap_or_else(|e| panic!("failed to parse '{input}': {e}")); + assert_eq!(b.to_string(), *expected, "display mismatch for '{input}'"); + } +} + +#[test] +fn hash_includes_path_and_query_suffix() { + let with_suffix: BindPattern = "iface://v4.enp17s0:8080/?reuse=true".parse().unwrap(); + let same_suffix: BindPattern = "iface://v4.enp17s0:8080/?reuse=true".parse().unwrap(); + assert_eq!(pattern_hash(&with_suffix), pattern_hash(&same_suffix)); + + let without_suffix: BindPattern = "iface://v4.enp17s0:8080".parse().unwrap(); + let same_without_suffix: BindPattern = "iface://v4.enp17s0:8080".parse().unwrap(); + assert_ne!(pattern_hash(&with_suffix), pattern_hash(&without_suffix)); + assert_eq!( + pattern_hash(&without_suffix), + pattern_hash(&same_without_suffix) + ); +} + +// ============================================================================ +// Glob matching +// ============================================================================ + +#[test] +fn glob_matching() { + // glob_exact_match + { + let host = BindHost::classify("enp17s0", None).unwrap(); + assert!(host.matches("enp17s0")); + assert!(!host.matches("wlan0")); + } + + // glob_star_match + { + let host = BindHost::classify("en*", None).unwrap(); + assert!(host.matches("enp17s0")); + assert!(host.matches("eno1")); + assert!(!host.matches("wlan0")); + + let star = BindHost::classify("*", None).unwrap(); + assert!(star.matches("anything")); + } + + // glob_bracket_class + { + let host = BindHost::classify("[ew]*", None).unwrap(); + assert!(host.is_glob()); + assert!(host.matches("enp17s0")); + assert!(host.matches("wlan0")); + assert!(!host.matches("lo")); + } + + // glob_bracket_single + { + let host = BindHost::classify("wlan[01]", None).unwrap(); + assert!(host.is_glob()); + assert!(host.matches("wlan0")); + assert!(host.matches("wlan1")); + assert!(!host.matches("wlan2")); + } +} + +// ============================================================================ +// Classify +// ============================================================================ + +#[test] +fn classify() { + // classify_ipv4_as_ip + { + let host = BindHost::classify("127.0.0.1", None).unwrap(); + assert!(host.is_ip_addr()); + assert!(!host.is_glob()); + assert_eq!(host.as_str(), "127.0.0.1"); + } + + // classify_ipv6_bracket_as_ip + { + let host = BindHost::classify("[::1]", None).unwrap(); + assert!(host.is_ip_addr()); + assert_eq!(host.as_str(), "::1"); + assert_eq!(host.as_ip_addr().unwrap(), "::1".parse::().unwrap()); + } + + // classify_bracket_non_ip_as_glob + { + let host = BindHost::classify("[ew]", None).unwrap(); + assert!(host.is_glob()); + assert!(!host.is_ip_addr()); + } +} + +// ============================================================================ +// Families +// ============================================================================ + +#[test] +fn families() { + // families_both + { + let b: BindPattern = "enp17s0:8080".parse().unwrap(); + assert_eq!(b.host.families(), [Family::V4, Family::V6]); + } + + // families_v4_only + { + let b: BindPattern = "v4.enp17s0:8080".parse().unwrap(); + assert_eq!(b.host.families(), [Family::V4]); + } +} + +// ============================================================================ +// Expand — basic BindUri generation +// ============================================================================ + +#[test] +fn expand_basic() { + // expand_iface + { + let b: BindPattern = "iface://v4.enp17s0:8080".parse().unwrap(); + let uris: Vec<_> = b.to_bind_uris(["enp17s0"]).map(|u| u.to_string()).collect(); + assert_eq!(uris, vec!["iface://v4.enp17s0:8080/"]); + } + + // expand_both_families + { + let b: BindPattern = "iface://enp17s0:8080".parse().unwrap(); + let uris: Vec<_> = b.to_bind_uris(["enp17s0"]).map(|u| u.to_string()).collect(); + assert_eq!( + uris, + vec!["iface://v4.enp17s0:8080/", "iface://v6.enp17s0:8080/"] + ); + } + + // expand_auto_port + { + let b: BindPattern = "iface://v4.enp17s0".parse().unwrap(); + let uris: Vec<_> = b.to_bind_uris(["enp17s0"]).map(|u| u.to_string()).collect(); + assert_eq!(uris.len(), 1); + assert!(uris[0].starts_with("iface://v4.enp17s0:0/")); + } + + // expand_inet + { + let b: BindPattern = "127.0.0.1:8080".parse().unwrap(); + let uris: Vec<_> = b.to_bind_uris([]).map(|u| u.to_string()).collect(); + assert_eq!(uris, vec!["inet://127.0.0.1:8080/"]); + } + + // expand_path_and_query_passthrough + { + let b: BindPattern = "iface://v4.en*:8080/?stun_server=stun.genmeta.net" + .parse() + .unwrap(); + let uris: Vec<_> = b.to_bind_uris(["enp17s0"]).map(|u| u.to_string()).collect(); + assert_eq!( + uris, + vec!["iface://v4.enp17s0:8080/?stun_server=stun.genmeta.net"] + ); + } +} + +// ============================================================================ +// Expand — glob patterns +// ============================================================================ + +#[test] +fn expand_glob() { + // expand_with_interfaces_glob + { + let b: BindPattern = "en*:8080".parse().unwrap(); + let interfaces = ["enp17s0", "eno1", "wlan0", "lo"]; + let uris: Vec<_> = b.to_bind_uris(interfaces).collect(); + // en* matches enp17s0 and eno1, each with V4 + V6 + assert_eq!(uris.len(), 4); + } + + // expand_with_interfaces_star + { + let b: BindPattern = "*:8080".parse().unwrap(); + let interfaces = ["enp17s0", "wlan0"]; + let uris: Vec<_> = b.to_bind_uris(interfaces).collect(); + // * matches all, each with V4 + V6 + assert_eq!(uris.len(), 4); + } +} + +// ============================================================================ +// Expand — IPv6 +// ============================================================================ + +#[test] +fn expand_ipv6() { + // expand_bare_ipv6 + { + let b: BindPattern = "::1".parse().unwrap(); + let uris: Vec<_> = b.to_bind_uris([]).map(|u| u.to_string()).collect(); + assert_eq!(uris.len(), 1); + assert!(uris[0].starts_with("inet://[::1]:0/")); + } + + // expand_ipv6 (from "IPv6 bracket syntax tests" section) + { + let b: BindPattern = "inet://[::1]:8080".parse().unwrap(); + let uris: Vec<_> = b.to_bind_uris([]).map(|u| u.to_string()).collect(); + assert_eq!(uris, vec!["inet://[::1]:8080/"]); + } + + // expand_ipv6_auto_port + { + let b: BindPattern = "[::1]".parse().unwrap(); + let uris: Vec<_> = b.to_bind_uris([]).map(|u| u.to_string()).collect(); + assert_eq!(uris.len(), 1); + assert!(uris[0].starts_with("inet://[::1]:0/")); + } +} + +#[test] +fn ipv6_ip_pattern_matches_generated_bind_uri() { + let pattern: BindPattern = "inet://[::]:4433".parse().unwrap(); + let bind_uri: BindUri = "inet://[::]:4433".parse().unwrap(); + + assert!(pattern.matches(&bind_uri)); +} + +// ============================================================================ +// Errors +// ============================================================================ + +#[test] +fn parse_errors() { + // family_ip_rejected + assert!( + "v4.127.0.0.1:8080".parse::().is_err(), + "family prefix should be rejected for IP addresses" + ); + assert!( + "inet://v6.[::1]:8080".parse::().is_err(), + "family prefix should be rejected for IP addresses" + ); + assert!( + "iface://enp17s0:70000".parse::().is_err(), + "ports above u16::MAX should be rejected" + ); +} + +#[test] +fn iface_pattern_matches_generated_bind_uri() { + let pattern: BindPattern = "iface://v4.en*:8080".parse().unwrap(); + let generated = pattern + .to_bind_uris(["enp17s0"]) + .next() + .expect("pattern should generate one bind uri"); + + assert!( + pattern.matches(&generated), + "pattern {pattern} should match generated bind uri {generated}" + ); +} + +#[test] +fn iface_pattern_matches_family_interface_and_port_separately() { + let pattern: BindPattern = "iface://v4.en*:8080".parse().unwrap(); + + assert!(pattern.matches(&"iface://v4.enp17s0:8080".parse().unwrap())); + assert!(!pattern.matches(&"iface://v6.enp17s0:8080".parse().unwrap())); + assert!(!pattern.matches(&"iface://v4.wlan0:8080".parse().unwrap())); + assert!(!pattern.matches(&"iface://v4.enp17s0:8081".parse().unwrap())); + + let wildcard_family_and_port: BindPattern = "iface://en*".parse().unwrap(); + assert!(wildcard_family_and_port.matches(&"iface://v6.eno1:4433".parse().unwrap())); +} + +#[test] +fn pattern_matching_rejects_scheme_and_host_class_mismatches() { + let iface_pattern: BindPattern = "iface://v4.en*:8080".parse().unwrap(); + let inet_uri: BindUri = "inet://127.0.0.1:8080".parse().unwrap(); + assert!(!iface_pattern.matches(&inet_uri)); + + let iface_ip_pattern: BindPattern = "iface://127.0.0.1:8080".parse().unwrap(); + let iface_uri: BindUri = "iface://v4.enp17s0:8080".parse().unwrap(); + assert!(!iface_ip_pattern.matches(&iface_uri)); +} + +#[test] +fn inet_pattern_matches_ip_address_and_port_only() { + let pattern: BindPattern = "inet://127.0.0.1:8080".parse().unwrap(); + + assert!(pattern.matches(&"inet://127.0.0.1:8080".parse().unwrap())); + assert!(!pattern.matches(&"inet://127.0.0.1:8081".parse().unwrap())); + assert!(!pattern.matches(&"inet://127.0.0.2:8080".parse().unwrap())); + + let non_ip_pattern: BindPattern = "inet://*:8080".parse().unwrap(); + assert!(!non_ip_pattern.matches(&"inet://127.0.0.1:8080".parse().unwrap())); +} + +#[test] +fn wildcard_port_matches_any_actual_port_for_supported_schemes() { + let iface_pattern: BindPattern = "iface://v4.en*".parse().unwrap(); + assert!(iface_pattern.matches(&"iface://v4.enp17s0:0".parse().unwrap())); + assert!(iface_pattern.matches(&"iface://v4.enp17s0:4433".parse().unwrap())); + + let inet_pattern: BindPattern = "inet://127.0.0.1".parse().unwrap(); + assert!(inet_pattern.matches(&"inet://127.0.0.1:0".parse().unwrap())); + assert!(inet_pattern.matches(&"inet://127.0.0.1:4433".parse().unwrap())); +} diff --git a/src/dquic/client.rs b/src/dquic/client.rs index 7849fcd..8b5cfa0 100644 --- a/src/dquic/client.rs +++ b/src/dquic/client.rs @@ -1,276 +1,464 @@ +//! Client-side QUIC configuration types. +//! +//! [`ClientQuicConfig`] is the primary client-side configuration struct. +//! All fields are inlined directly — the former `CommonQuicConfig` and +//! `ClientSpecificConfig` wrappers have been flattened into this single type. +//! +//! All types provide [`Default`] so that client endpoints can be constructed +//! without the caller having to hand-roll configuration values. + use std::{sync::Arc, time::Duration}; -use ::dquic::{ - builder::QuicClientBuilder, - prelude::{ - BindUri, Connection, ProductStreamsConcurrencyController, QuicClient, Resolve, - handy::{ToCertificate, ToPrivateKey}, - }, - qbase::{param::ClientParameters, token::TokenSink}, - qevent::telemetry::QLog, - qinterface::{ - component::route::QuicRouter, device::Devices, io::ProductIO, manager::InterfaceManager, - }, -}; -use rustls::{ - ClientConfig, RootCertStore, - client::WebPkiServerVerifier, - crypto::CryptoProvider, - pki_types::{CertificateDer, PrivateKeyDer}, -}; -use snafu::{ResultExt, Snafu}; +use rustls::client::{WebPkiServerVerifier, danger::ServerCertVerifier}; -use crate::{ - client::Client, - connection::ConnectionBuilder, - pool::Pool, - util::tls::{DangerousServerCertVerifier, InvalidIdentity, verify_certificate_for_name}, +use crate::dquic::{ + log::{QLog, handy::NoopLogger}, + param::{ClientParameters, handy::client_parameters}, + stream::{ProductStreamsConcurrencyController, handy::ConsistentConcurrency}, + token::{TokenSink, handy::NoopTokenRegistry}, }; -enum ServerCertVerifier { +// --------------------------------------------------------------------------- +// Client-only +// --------------------------------------------------------------------------- + +/// Strategy for verifying the server's TLS certificate. +/// +/// Kept as a small enum rather than a trait object so that +/// [`ClientQuicConfig::verifier`] composes cheaply. The `WebPki` and `Custom` +/// variants wrap their verifier in an [`Arc`] for cheap cloning. +#[derive(Clone, Default)] +pub enum ServerCertVerifierChoice { + /// Accept any certificate. Intended for local testing only. + #[default] + Dangerous, + /// Verify against a compiled webpki verifier. WebPki(Arc), - Roots(Arc), - None, + /// Delegate to a caller-supplied verifier. + Custom(Arc), } -impl Default for ServerCertVerifier { - fn default() -> Self { - Self::Roots(Arc::new(RootCertStore::empty())) +impl std::fmt::Debug for ServerCertVerifierChoice { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Dangerous => f.debug_tuple("Dangerous").finish(), + Self::WebPki(_) => f.debug_tuple("WebPki").finish(), + Self::Custom(_) => f.debug_tuple("Custom").finish(), + } } } -#[derive(Debug, Snafu)] -#[snafu(module)] -pub enum BuildClientError { - #[snafu(display("provided crypto provider cannot be used"))] - InvalidCryptoProvider { source: rustls::Error }, - #[snafu(display("provided client identity cannot be used for `{name}`"))] - InvalidIdentity { - name: String, - source: InvalidIdentity, - }, +impl PartialEq for ServerCertVerifierChoice { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Dangerous, Self::Dangerous) => true, + (Self::WebPki(a), Self::WebPki(b)) => Arc::ptr_eq(a, b), + (Self::Custom(a), Self::Custom(b)) => Arc::ptr_eq(a, b), + _ => false, + } + } } -pub struct H3ClientTlsBuilder { - server_cert_verifier: ServerCertVerifier, - client_identity: Option<(String, Vec>, PrivateKeyDer<'static>)>, - crypto_provider: Option>, - resolver: Option>, +// --- legacy ClientSpecificConfig fields (flattened into ClientQuicConfig) --- +// +// /// Client-only configuration values. +// #[derive(Clone)] +// pub struct ClientSpecificConfig { +// /// Transport parameters advertised by the client. +// pub parameters: ClientParameters, +// /// ALPN protocol identifiers to offer. Empty means no ALPN. +// pub alpns: Vec>, +// /// Address validation token sink. +// pub token_sink: Arc, +// /// How the server's certificate should be verified. +// pub verifier: ServerCertVerifierChoice, +// } +// +// impl Default for ClientSpecificConfig { ... } +// impl Debug for ClientSpecificConfig { ... } +// impl PartialEq for ClientSpecificConfig { ... } + +// --------------------------------------------------------------------------- +// Client composite (common + own) +// --------------------------------------------------------------------------- + +/// Client-side QUIC configuration — common + client-only fields flattened. +#[derive(Clone)] +pub struct ClientQuicConfig { + // --- common fields (from CommonQuicConfig) --- + /// How long the connection should keep sending probe packets after going + /// idle. `Duration::ZERO` (the default) disables deferred idle timeouts. + pub defer_idle_timeout: Duration, + /// Factory producing per-connection streams concurrency controllers. + pub stream_strategy_factory: Arc, + /// QUIC-events logger (qlog). Defaults to a no-op logger. + pub qlogger: Arc, + /// Whether 0-RTT should be enabled if the crypto context permits it. + pub enable_0rtt: bool, + /// Enable SSL key logging via `SSLKEYLOGFILE` for debugging captures. + pub enable_sslkeylog: bool, + + // --- client-specific fields (from ClientSpecificConfig) --- + /// Transport parameters advertised by the client. + pub parameters: ClientParameters, + /// ALPN protocol identifiers to offer. Empty means no ALPN. + pub alpns: Vec>, + /// Address validation token sink. + pub token_sink: Arc, + /// How the server's certificate should be verified. + pub verifier: ServerCertVerifierChoice, } -pub type H3Client = Client>; - -impl H3Client { - pub fn builder() -> H3ClientTlsBuilder { - H3ClientTlsBuilder { - server_cert_verifier: ServerCertVerifier::default(), - client_identity: None, - crypto_provider: None, - resolver: None, +impl Default for ClientQuicConfig { + fn default() -> Self { + Self { + // CommonQuicConfig::default() values + defer_idle_timeout: Duration::ZERO, + stream_strategy_factory: Arc::new(ConsistentConcurrency::new), + qlogger: Arc::new(NoopLogger), + enable_0rtt: false, + enable_sslkeylog: false, + // ClientSpecificConfig::default() values + parameters: client_parameters(), + alpns: Vec::new(), + token_sink: Arc::new(NoopTokenRegistry), + verifier: ServerCertVerifierChoice::default(), } } } -impl H3ClientTlsBuilder { - pub fn with_crypto_provider(mut self, crypto_provider: impl Into>) -> Self { - self.crypto_provider = Some(crypto_provider.into()); - self +impl std::fmt::Debug for ClientQuicConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ClientQuicConfig") + .field("defer_idle_timeout", &self.defer_idle_timeout) + .field("enable_0rtt", &self.enable_0rtt) + .field("enable_sslkeylog", &self.enable_sslkeylog) + .field("alpns", &self.alpns.len()) + .field("verifier", &self.verifier) + .finish_non_exhaustive() } +} - pub fn with_root_certificates(mut self, root_store: impl Into>) -> Self { - self.server_cert_verifier = ServerCertVerifier::Roots(root_store.into()); - self +impl PartialEq for ClientQuicConfig { + fn eq(&self, other: &Self) -> bool { + self.defer_idle_timeout == other.defer_idle_timeout + && self.enable_0rtt == other.enable_0rtt + && self.enable_sslkeylog == other.enable_sslkeylog + && self.parameters == other.parameters + && self.alpns == other.alpns + && self.verifier == other.verifier + && Arc::ptr_eq( + &self.stream_strategy_factory, + &other.stream_strategy_factory, + ) + && Arc::ptr_eq(&self.qlogger, &other.qlogger) + && Arc::ptr_eq(&self.token_sink, &other.token_sink) } +} - pub fn with_webpki_verifier(mut self, verifier: Arc) -> Self { - self.server_cert_verifier = ServerCertVerifier::WebPki(verifier); - self - } +// --------------------------------------------------------------------------- +// Unit tests +// --------------------------------------------------------------------------- - pub fn without_server_cert_verification(mut self) -> Self { - self.server_cert_verifier = ServerCertVerifier::None; - self - } +#[cfg(test)] +mod client_tests { + use std::{sync::Arc, time::Duration}; - pub fn with_resolver(mut self, resolver: Arc) -> Self { - self.resolver = Some(resolver); - self - } + use rustls::{ + RootCertStore, + client::{WebPkiServerVerifier, danger::ServerCertVerifier}, + }; - pub fn with_identity( - mut self, - name: impl Into, - cert_chain: impl ToCertificate, - private_key: impl ToPrivateKey, - ) -> Result { - self.client_identity = Some(( - name.into(), - cert_chain.to_certificate(), - private_key.to_private_key(), - )); - self.try_into() - } + use crate::{ + dquic::{client::*, common::*, prelude::handy::ToCertificate}, + util::tls::DangerousServerCertVerifier, + }; + + const CA_CERT: &[u8] = include_bytes!("../../tests/keychain/localhost/ca.cert"); - pub fn without_identity(mut self) -> Result { - self.client_identity = None; - self.try_into() + fn root_store_with_ca() -> RootCertStore { + let mut store = RootCertStore::empty(); + store.add_parsable_certificates(CA_CERT.to_certificate()); + store } -} -impl TryFrom for H3ClientBuilder { - type Error = BuildClientError; - - fn try_from(builder: H3ClientTlsBuilder) -> Result { - const REQUIRED_TLS_VERSIONS: &[&rustls::SupportedProtocolVersion; 1] = - &[&rustls::version::TLS13]; - let crypto_provider = builder - .crypto_provider - .unwrap_or_else(|| ClientConfig::builder().crypto_provider().clone()); - let tls_config_buider = ClientConfig::builder_with_provider(crypto_provider) - .with_protocol_versions(REQUIRED_TLS_VERSIONS) - .context(build_client_error::InvalidCryptoProviderSnafu)?; - - let tls_config_buider = match builder.server_cert_verifier { - ServerCertVerifier::WebPki(web_pki_server_verifier) => { - tls_config_buider.with_webpki_verifier(web_pki_server_verifier) - } - ServerCertVerifier::Roots(root_cert_store) => { - tls_config_buider.with_root_certificates(root_cert_store) - } - ServerCertVerifier::None => tls_config_buider - .dangerous() - .with_custom_certificate_verifier(Arc::new(DangerousServerCertVerifier)), - }; + // -- CommonQuicConfig --------------------------------------------------- - let (tls_config, client_name) = match builder.client_identity { - Some((name, cert, key)) => { - verify_certificate_for_name(&cert[0], &name) - .context(build_client_error::InvalidIdentitySnafu { name: &name })?; - let tls_config = tls_config_buider - .with_client_auth_cert(cert, key) - .map_err(InvalidIdentity::from) - .context(build_client_error::InvalidIdentitySnafu { name: &name })?; - (tls_config, Some(name)) - } - None => (tls_config_buider.with_no_client_auth(), None), - }; + #[test] + fn test_common_quic_config_default() { + let cfg = CommonQuicConfig::default(); + assert_eq!(cfg.defer_idle_timeout, Duration::ZERO); + assert!(!cfg.enable_0rtt); + assert!(!cfg.enable_sslkeylog); + } - let mut quic_builder = QuicClient::builder_with_tls(tls_config).with_alpns(vec!["h3"]); - if let Some(resolver) = builder.resolver { - quic_builder = quic_builder.with_resolver(resolver); - } + #[test] + fn test_common_quic_config_partial_eq_different_timeout() { + let a = CommonQuicConfig::default(); + let mut b = a.clone(); + b.defer_idle_timeout = Duration::from_secs(30); + assert_ne!(a, b); + } - Ok(H3ClientBuilder { - quic_builder, - client_name, - pool: Pool::empty(), - builder: Arc::new(ConnectionBuilder::new(Arc::default())), - }) + #[test] + fn test_common_quic_config_clone() { + let a = CommonQuicConfig::default(); + let b = a.clone(); + // Clone shares the same Arcs + assert!(Arc::ptr_eq( + &a.stream_strategy_factory, + &b.stream_strategy_factory + )); + assert!(Arc::ptr_eq(&a.qlogger, &b.qlogger)); + // Scalar values are copied + assert_eq!(a.defer_idle_timeout, b.defer_idle_timeout); + assert_eq!(a.enable_0rtt, b.enable_0rtt); + assert_eq!(a.enable_sslkeylog, b.enable_sslkeylog); } -} -pub struct H3ClientBuilder { - quic_builder: QuicClientBuilder, - client_name: Option, - pool: Pool, - builder: Arc>, -} + // -- ServerCertVerifierChoice ------------------------------------------- -impl H3ClientBuilder { - pub fn physical_ifaces(mut self, physical_ifaces: &'static Devices) -> Self { - self.quic_builder = self.quic_builder.physical_ifaces(physical_ifaces); - self + #[test] + fn test_verifier_choice_dangerous_ne_webpki() { + let store = root_store_with_ca(); + let webpki = WebPkiServerVerifier::builder(Arc::new(store)) + .build() + .unwrap(); + assert_ne!( + ServerCertVerifierChoice::Dangerous, + ServerCertVerifierChoice::WebPki(webpki) + ); } - pub fn with_resolver(mut self, resolver: Arc) -> Self { - self.quic_builder = self.quic_builder.with_resolver(resolver); - self + #[test] + fn test_verifier_choice_webpki_same_arc_eq() { + let store = root_store_with_ca(); + let webpki = WebPkiServerVerifier::builder(Arc::new(store)) + .build() + .unwrap(); + // webpki is already Arc + // Two clones of the same Arc should be equal (ptr_eq) + let a = ServerCertVerifierChoice::WebPki(webpki.clone()); + let b = ServerCertVerifierChoice::WebPki(webpki.clone()); + assert_eq!(a, b); } - pub fn with_iface_factory(mut self, factory: Arc) -> Self { - self.quic_builder = self.quic_builder.with_iface_factory(factory); - self + #[test] + fn test_verifier_choice_webpki_different_arc_ne() { + let store1 = root_store_with_ca(); + let store2 = root_store_with_ca(); + let webpki1 = WebPkiServerVerifier::builder(Arc::new(store1)) + .build() + .unwrap(); + let webpki2 = WebPkiServerVerifier::builder(Arc::new(store2)) + .build() + .unwrap(); + assert_ne!( + ServerCertVerifierChoice::WebPki(webpki1), + ServerCertVerifierChoice::WebPki(webpki2) + ); } - /// Specify the interfaces manager for the client. - pub fn with_iface_manager(mut self, iface_manager: Arc) -> Self { - self.quic_builder = self.quic_builder.with_iface_manager(iface_manager); - self + #[test] + fn test_verifier_choice_custom_same_arc_eq() { + let verifier: Arc = Arc::new(DangerousServerCertVerifier); + let a = ServerCertVerifierChoice::Custom(verifier.clone()); + let b = ServerCertVerifierChoice::Custom(verifier.clone()); + assert_eq!(a, b); } - pub fn with_router(mut self, router: Arc) -> Self { - self.quic_builder = self.quic_builder.with_router(router); - self + #[test] + fn test_verifier_choice_custom_different_arc_ne() { + let a: Arc = Arc::new(DangerousServerCertVerifier); + let b: Arc = Arc::new(DangerousServerCertVerifier); + assert_ne!( + ServerCertVerifierChoice::Custom(a), + ServerCertVerifierChoice::Custom(b) + ); } - pub fn with_stun(mut self, server: impl Into>) -> Self { - self.quic_builder = self.quic_builder.with_stun(server); - self + #[test] + fn test_verifier_choice_cross_variant_not_equal() { + let verifier: Arc = Arc::new(DangerousServerCertVerifier); + assert_ne!( + ServerCertVerifierChoice::Dangerous, + ServerCertVerifierChoice::Custom(verifier.clone()) + ); + assert_ne!( + ServerCertVerifierChoice::WebPki( + WebPkiServerVerifier::builder(Arc::new(root_store_with_ca())) + .build() + .unwrap() + ), + ServerCertVerifierChoice::Custom(verifier) + ); } - pub async fn bind(mut self, uri: impl IntoIterator>) -> Self { - self.quic_builder = self.quic_builder.bind(uri).await; - self + #[test] + fn test_verifier_choice_debug_variants() { + let store = root_store_with_ca(); + let webpki = WebPkiServerVerifier::builder(Arc::new(store)) + .build() + .unwrap(); + + let dangerous = ServerCertVerifierChoice::Dangerous; + let custom: Arc = Arc::new(DangerousServerCertVerifier); + + assert_eq!(format!("{:?}", dangerous), "Dangerous"); + assert_eq!( + format!("{:?}", ServerCertVerifierChoice::WebPki(webpki)), + "WebPki" + ); + assert_eq!( + format!("{:?}", ServerCertVerifierChoice::Custom(custom)), + "Custom" + ); } - pub fn defer_idle_timeout(mut self, duration: Duration) -> Self { - self.quic_builder = self.quic_builder.defer_idle_timeout(duration); - self + #[test] + fn test_verifier_choice_default_is_dangerous() { + assert_eq!( + ServerCertVerifierChoice::default(), + ServerCertVerifierChoice::Dangerous + ); } - pub fn with_quic_parameters(mut self, parameters: ClientParameters) -> Self { - self.quic_builder = self.quic_builder.with_parameters(parameters); - self + // -- ClientQuicConfig --------------------------------------------------- + + #[test] + fn test_client_quic_config_default() { + let cfg = ClientQuicConfig::default(); + // Common fields + assert_eq!(cfg.defer_idle_timeout, Duration::ZERO); + assert!(!cfg.enable_0rtt); + assert!(!cfg.enable_sslkeylog); + // Client-specific fields + assert!( + matches!(&cfg.verifier, ServerCertVerifierChoice::Dangerous), + "default verifier should be Dangerous" + ); + assert!(cfg.alpns.is_empty(), "default alpns should be empty"); } - pub fn with_streams_concurrency_strategy( - mut self, - strategy_factory: Arc, - ) -> Self { - self.quic_builder = self - .quic_builder - .with_streams_concurrency_strategy(strategy_factory); - self + #[test] + fn test_client_quic_config_partial_eq_different_timeout() { + let a = ClientQuicConfig::default(); + let mut b = a.clone(); + b.defer_idle_timeout = Duration::from_secs(99); + assert_ne!(a, b); } - pub fn with_qlog(mut self, logger: Arc) -> Self { - self.quic_builder = self.quic_builder.with_qlog(logger); - self - } + #[test] + fn test_client_quic_config_partial_eq_different_verifier() { + let a = ClientQuicConfig::default(); + let store = root_store_with_ca(); + let webpki = WebPkiServerVerifier::builder(Arc::new(store)) + .build() + .unwrap(); - pub fn with_token_sink(mut self, sink: Arc) -> Self { - self.quic_builder = self.quic_builder.with_token_sink(sink); - self + let mut custom = a.clone(); + custom.verifier = ServerCertVerifierChoice::Custom(Arc::new(DangerousServerCertVerifier)); + assert_ne!(a, custom); + + let mut webpki_choice = a.clone(); + webpki_choice.verifier = ServerCertVerifierChoice::WebPki(webpki); + assert_ne!(a, webpki_choice); } - pub fn enable_sslkeylog(mut self) -> Self { - self.quic_builder = self.quic_builder.enable_sslkeylog(); - self + #[test] + fn test_client_quic_config_partial_eq_different_components() { + let a = ClientQuicConfig::default(); + + let mut strategy = a.clone(); + strategy.stream_strategy_factory = Arc::new(ConsistentConcurrency::new); + assert_ne!(a, strategy); + + let mut qlogger = a.clone(); + qlogger.qlogger = Arc::new(NoopLogger); + assert_ne!(a, qlogger); + + let mut token_sink = a.clone(); + token_sink.token_sink = Arc::new(NoopTokenRegistry); + assert_ne!(a, token_sink); } - pub fn enable_0rtt(mut self) -> Self { - self.quic_builder = self.quic_builder.enable_0rtt(); - self + #[test] + fn test_client_quic_config_debug() { + let cfg = ClientQuicConfig { + alpns: vec![b"h3".to_vec()], + ..ClientQuicConfig::default() + }; + + let rendered = format!("{cfg:?}"); + assert!(rendered.contains("ClientQuicConfig")); + assert!(rendered.contains("defer_idle_timeout: 0ns")); + assert!(rendered.contains("enable_0rtt: false")); + assert!(rendered.contains("enable_sslkeylog: false")); + assert!(rendered.contains("alpns: 1")); + assert!(rendered.contains("verifier: Dangerous")); + assert!(rendered.contains("..")); + assert!(!rendered.contains("stream_strategy_factory")); } - pub fn with_connection_pool(mut self, pool: Pool) -> Self { - self.pool = pool; - self + #[test] + fn test_client_quic_config_clone() { + let a = ClientQuicConfig::default(); + let b = a.clone(); + // Trait-object Arcs are shared by pointer after clone + assert!(Arc::ptr_eq( + &a.stream_strategy_factory, + &b.stream_strategy_factory + )); + assert!(Arc::ptr_eq(&a.qlogger, &b.qlogger)); + assert!(Arc::ptr_eq(&a.token_sink, &b.token_sink)); + // Scalar / owned values are equal by value + assert_eq!(a.defer_idle_timeout, b.defer_idle_timeout); + assert_eq!(a.enable_0rtt, b.enable_0rtt); + assert_eq!(a.enable_sslkeylog, b.enable_sslkeylog); + assert_eq!(a.parameters, b.parameters); + assert_eq!(a.alpns, b.alpns); + assert_eq!(a.verifier, b.verifier); } - pub fn with_builder(mut self, builder: Arc>) -> Self { - self.builder = builder; - self + + #[test] + fn test_client_quic_config_mutate_does_not_affect_clone() { + let a = ClientQuicConfig::default(); + let mut b = a.clone(); + + // Mutate b — should not affect a since fields are copied/cloned + b.defer_idle_timeout = Duration::from_secs(99); + b.alpns.push(b"h3".to_vec()); + + // Original is unchanged + assert_eq!(a.defer_idle_timeout, Duration::ZERO); + assert!(a.alpns.is_empty()); + // b has the new values + assert_eq!(b.defer_idle_timeout, Duration::from_secs(99)); + assert!(!b.alpns.is_empty()); } - pub fn build(self) -> H3Client { - let client = match self.client_name { - Some(client_name) => self.quic_builder.with_name(client_name).build(), - None => self.quic_builder.build(), - }; - Client::from_quic_client() - .pool(self.pool.clone()) - .client(Arc::new(client)) - .builder(self.builder.clone()) - .build() + #[test] + fn test_client_quic_config_mutate_arc_fields_does_not_affect_clone() { + let a = ClientQuicConfig::default(); + let mut b = a.clone(); + + // Replace trait-object Arcs in the clone. + b.stream_strategy_factory = Arc::new(ConsistentConcurrency::new); + b.qlogger = Arc::new(NoopLogger); + b.token_sink = Arc::new(NoopTokenRegistry); + + // Other value fields stay unchanged and only trait-object identities diverge. + assert_eq!(a.defer_idle_timeout, b.defer_idle_timeout); + assert_eq!(a.enable_0rtt, b.enable_0rtt); + assert_eq!(a.enable_sslkeylog, b.enable_sslkeylog); + assert_eq!(a.parameters, b.parameters); + assert_eq!(a.alpns, b.alpns); + assert_eq!(a.verifier, b.verifier); + assert!(!Arc::ptr_eq( + &a.stream_strategy_factory, + &b.stream_strategy_factory + )); + assert!(!Arc::ptr_eq(&a.qlogger, &b.qlogger)); + assert!(!Arc::ptr_eq(&a.token_sink, &b.token_sink)); } } diff --git a/src/dquic/common.rs b/src/dquic/common.rs new file mode 100644 index 0000000..f92f61c --- /dev/null +++ b/src/dquic/common.rs @@ -0,0 +1,123 @@ +use std::{sync::Arc, time::Duration}; + +use crate::dquic::{ + log::{QLog, handy::NoopLogger}, + stream::{ProductStreamsConcurrencyController, handy::ConsistentConcurrency}, +}; + +// --------------------------------------------------------------------------- +// Common +// --------------------------------------------------------------------------- + +/// Configuration values that apply to both client and server roles. +#[derive(Clone)] +pub struct CommonQuicConfig { + /// How long the connection should keep sending probe packets after going + /// idle. `Duration::ZERO` (the default) disables deferred idle timeouts. + pub defer_idle_timeout: Duration, + /// Factory producing per-connection streams concurrency controllers. + pub stream_strategy_factory: Arc, + /// QUIC-events logger (qlog). Defaults to a no-op logger. + pub qlogger: Arc, + /// Whether 0-RTT should be enabled if the crypto context permits it. + pub enable_0rtt: bool, + /// Enable SSL key logging via `SSLKEYLOGFILE` for debugging captures. + pub enable_sslkeylog: bool, +} + +impl Default for CommonQuicConfig { + fn default() -> Self { + Self { + defer_idle_timeout: Duration::ZERO, + stream_strategy_factory: Arc::new(ConsistentConcurrency::new), + qlogger: Arc::new(NoopLogger), + enable_0rtt: false, + enable_sslkeylog: false, + } + } +} + +impl std::fmt::Debug for CommonQuicConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CommonQuicConfig") + .field("defer_idle_timeout", &self.defer_idle_timeout) + .field("enable_0rtt", &self.enable_0rtt) + .field("enable_sslkeylog", &self.enable_sslkeylog) + .finish_non_exhaustive() + } +} + +impl PartialEq for CommonQuicConfig { + fn eq(&self, other: &Self) -> bool { + self.defer_idle_timeout == other.defer_idle_timeout + && self.enable_0rtt == other.enable_0rtt + && self.enable_sslkeylog == other.enable_sslkeylog + && Arc::ptr_eq( + &self.stream_strategy_factory, + &other.stream_strategy_factory, + ) + && Arc::ptr_eq(&self.qlogger, &other.qlogger) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_uses_noop_components_and_disabled_features() { + let config = CommonQuicConfig::default(); + + assert_eq!(config.defer_idle_timeout, Duration::ZERO); + assert!(!config.enable_0rtt); + assert!(!config.enable_sslkeylog); + assert_eq!( + format!("{config:?}"), + "CommonQuicConfig { defer_idle_timeout: 0ns, enable_0rtt: false, enable_sslkeylog: false, .. }", + ); + } + + #[test] + fn clone_preserves_dynamic_component_identity() { + let config = CommonQuicConfig { + defer_idle_timeout: Duration::from_secs(5), + enable_0rtt: true, + enable_sslkeylog: true, + ..Default::default() + }; + + let cloned = config.clone(); + + assert_eq!(config, cloned); + assert!(Arc::ptr_eq( + &config.stream_strategy_factory, + &cloned.stream_strategy_factory, + )); + assert!(Arc::ptr_eq(&config.qlogger, &cloned.qlogger)); + } + + #[test] + fn equality_requires_same_values_and_same_dynamic_component_arcs() { + let config = CommonQuicConfig::default(); + + let mut different_timeout = config.clone(); + different_timeout.defer_idle_timeout = Duration::from_secs(1); + assert_ne!(config, different_timeout); + + let mut different_zero_rtt = config.clone(); + different_zero_rtt.enable_0rtt = true; + assert_ne!(config, different_zero_rtt); + + let mut different_sslkeylog = config.clone(); + different_sslkeylog.enable_sslkeylog = true; + assert_ne!(config, different_sslkeylog); + + let mut different_strategy = config.clone(); + different_strategy.stream_strategy_factory = Arc::new(ConsistentConcurrency::new); + assert_ne!(config, different_strategy); + + let mut different_logger = config.clone(); + different_logger.qlogger = Arc::new(NoopLogger); + assert_ne!(config, different_logger); + } +} diff --git a/src/dquic/endpoint.rs b/src/dquic/endpoint.rs new file mode 100644 index 0000000..ff3a995 --- /dev/null +++ b/src/dquic/endpoint.rs @@ -0,0 +1,2183 @@ +//! QUIC-only endpoint built on top of a shared [`Network`]. + +use std::{ + any::Any, collections::HashMap, net::SocketAddr, str::FromStr, sync::Arc, time::Duration, +}; + +use arc_swap::{ArcSwap, ArcSwapOption}; +use bon::bon; +use futures::{FutureExt, Stream, StreamExt, future::join_all}; +use rustls::ClientConfig; +use snafu::{Report, ResultExt, Snafu}; +use tracing::Instrument; + +use crate::{ + dquic::{ + binds::BindPattern, + client::{ClientQuicConfig, ServerCertVerifierChoice}, + connection::Connection, + identity::Identity, + net::{BindUri, ConnectionId, EndpointAddr}, + network::{BindHandle, BindServerError, Network, ServerBinding}, + resolver::{Resolve, Source}, + server::ServerQuicConfig, + }, + quic, + util::tls::DangerousServerCertVerifier, +}; + +/// Error building the client-side TLS configuration. +#[derive(Debug, Snafu)] +#[snafu(module, visibility(pub))] +pub enum BuildClientTlsError { + /// rustls failed to choose a supported protocol version. + #[snafu(display("failed to select TLS protocol version"))] + Version { + /// Underlying rustls error. + source: rustls::Error, + }, + /// rustls refused the provided client certificate / key. + #[snafu(display("failed to load client authentication certificate"))] + ClientAuth { + /// Underlying rustls error. + source: rustls::Error, + }, + /// Failed to set a transport parameter on the client parameters. + #[snafu(display("failed to set client name transport parameter"))] + SetParameter { + /// Underlying parameter error. + source: crate::dquic::qbase::param::error::Error, + }, +} + +/// Error returned by [`QuicEndpoint`] when opening an outbound connection. +#[derive(Debug, Snafu)] +#[snafu(module, visibility(pub))] +pub enum ConnectError { + /// Failed to build the client TLS configuration. + #[snafu(display("failed to build client TLS config"))] + Tls { + /// Underlying build error. + source: BuildClientTlsError, + }, + /// DNS resolution failed. + #[snafu(display("dns lookup failed"))] + Dns { + /// Underlying I/O error. + source: std::io::Error, + }, + /// The resolver produced no reachable endpoint. + #[snafu(display("no reachable endpoint found for server"))] + NoReachableEndpoint, +} + +/// Error returned by [`QuicEndpoint`] when awaiting an inbound connection. +#[derive(Debug, Snafu)] +#[snafu(module, visibility(pub))] +pub enum AcceptError { + /// The endpoint's identity is anonymous — no SNI to register. + #[snafu(display("cannot accept connections on an anonymous identity"))] + ServerUnavailable, + /// Registering the identity on the network failed. + #[snafu(display("failed to bind server on network"))] + BindServer { + /// Underlying network error. + source: BindServerError, + }, + /// The endpoint has been shut down. + #[snafu(display("endpoint has been shut down"))] + Shutdown, +} + +/// Back-compat alias. +pub type EndpointError = ConnectError; + +/// A QUIC-only endpoint backed by a shared [`Network`]. +pub struct QuicEndpoint { + /// Shared network infrastructure. + pub(crate) network: Arc, + /// TLS identity for this endpoint (`None` for anonymous/client-only). + pub(crate) identity: Arc>, + /// Resolver used when establishing outbound connections. + pub(crate) resolver: Arc, + /// Client-side configuration. + pub(crate) client: Arc>, + /// Server-side configuration. + pub(crate) server: Arc>, + /// Bind patterns for interface filtering (shared by client and server roles). + pub(crate) bind: Arc>, + _binds: Arc>, + client_tls_cache: ArcSwapOption, + server_binding_cache: ArcSwapOption, +} + +impl Clone for QuicEndpoint { + fn clone(&self) -> Self { + Self { + network: self.network.clone(), + identity: self.identity.clone(), + resolver: self.resolver.clone(), + client: self.client.clone(), + server: self.server.clone(), + bind: self.bind.clone(), + _binds: self._binds.clone(), + client_tls_cache: ArcSwapOption::empty(), + server_binding_cache: ArcSwapOption::empty(), + } + } +} + +/// RAII guard for mutable access to [`QuicEndpoint`]'s client configuration. +/// +/// On drop, invalidates the endpoint's `client_tls_cache` so that the next +/// `connect()` call rebuilds the TLS configuration. +pub struct ClientConfigMutGuard<'a> { + config: Arc, + target: &'a ArcSwap, + cache: &'a ArcSwapOption, +} + +impl<'a> std::ops::Deref for ClientConfigMutGuard<'a> { + type Target = ClientQuicConfig; + fn deref(&self) -> &Self::Target { + self.config.as_ref() + } +} + +impl<'a> std::ops::DerefMut for ClientConfigMutGuard<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + Arc::make_mut(&mut self.config) + } +} + +impl<'a> Drop for ClientConfigMutGuard<'a> { + fn drop(&mut self) { + self.target.store(self.config.clone()); + self.cache.store(None); + } +} + +/// RAII guard for mutable access to [`QuicEndpoint`]'s server configuration. +/// +/// On drop, invalidates the endpoint's `server_binding_cache` so that the next +/// `accept()` call rebuilds the server binding. +pub struct ServerConfigMutGuard<'a> { + config: Arc, + target: &'a ArcSwap, + cache: &'a ArcSwapOption, +} + +impl<'a> std::ops::Deref for ServerConfigMutGuard<'a> { + type Target = ServerQuicConfig; + fn deref(&self) -> &Self::Target { + self.config.as_ref() + } +} + +impl<'a> std::ops::DerefMut for ServerConfigMutGuard<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + Arc::make_mut(&mut self.config) + } +} + +impl<'a> Drop for ServerConfigMutGuard<'a> { + fn drop(&mut self) { + self.target.store(self.config.clone()); + self.cache.store(None); + } +} + +impl QuicEndpoint { + pub async fn new() -> Self { + Self::builder().build().await + } + + /// Reference to the shared network infrastructure. + pub fn network(&self) -> &Arc { + &self.network + } + + /// Current TLS identity, if any. + pub fn identity(&self) -> Option> { + self.identity.load_full() + } + + /// Bind patterns governing which interfaces this endpoint uses. + pub fn bind_patterns(&self) -> &Arc> { + &self.bind + } + + /// Reference to the DNS resolver. + pub fn resolver(&self) -> &Arc { + &self.resolver + } + + pub fn set_resolver(&mut self, resolver: Arc) { + self.resolver = resolver; + } +} + +#[bon] +impl QuicEndpoint { + #[builder] + pub async fn new( + #[builder(default = Network::builder().build())] network: Arc, + identity: Option>, + #[builder(default = Arc::new(crate::dquic::prelude::handy::SystemResolver))] resolver: Arc< + dyn Resolve + Send + Sync, + >, + #[builder(default)] client: ClientQuicConfig, + #[builder(default)] server: ServerQuicConfig, + #[builder(default = Arc::new(Vec::new()))] bind: Arc>, + ) -> Self { + let bind = if bind.is_empty() { + Arc::new(vec![BindPattern::from_str("*").unwrap()]) + } else { + bind + }; + let binds: Vec = + join_all(bind.iter().map(|p| network.quic().bind(p.clone()))).await; + let endpoint = Self { + network, + identity: Arc::new(ArcSwapOption::from(identity)), + resolver, + client: Arc::new(ArcSwap::from_pointee(client)), + server: Arc::new(ArcSwap::from_pointee(server)), + bind, + _binds: Arc::new(binds), + client_tls_cache: ArcSwapOption::empty(), + server_binding_cache: ArcSwapOption::empty(), + }; + endpoint.init_client(); + endpoint.init_server().await; + endpoint + } +} + +impl QuicEndpoint { + fn ensure_client(&self) -> Result, BuildClientTlsError> { + if let Some(cached) = self.client_tls_cache.load_full() { + return Ok(cached); + } + let config = Arc::new(self.build_client_tls()?); + self.client_tls_cache.store(Some(config.clone())); + Ok(config) + } + + fn build_client_tls(&self) -> Result { + use build_client_tls_error::{ClientAuthSnafu, VersionSnafu}; + + const TLS13: &[&rustls::SupportedProtocolVersion] = &[&rustls::version::TLS13]; + let provider = ClientConfig::builder().crypto_provider().clone(); + let builder = ClientConfig::builder_with_provider(provider) + .with_protocol_versions(TLS13) + .context(VersionSnafu)?; + let client = self.client.load_full(); + let builder = match &client.verifier { + ServerCertVerifierChoice::Dangerous => builder + .dangerous() + .with_custom_certificate_verifier(Arc::new(DangerousServerCertVerifier)), + ServerCertVerifierChoice::WebPki(v) => builder.with_webpki_verifier(v.clone()), + ServerCertVerifierChoice::Custom(v) => builder + .dangerous() + .with_custom_certificate_verifier(v.clone()), + }; + let mut tls = match self.identity.load_full() { + None => builder.with_no_client_auth(), + Some(id) => builder + .with_client_auth_cert(id.certs.iter().cloned().collect(), id.key.clone_key()) + .context(ClientAuthSnafu)?, + }; + tls.alpn_protocols.clone_from(&client.alpns); + tls.enable_early_data = client.enable_0rtt; + Ok(tls) + } +} + +impl QuicEndpoint { + fn build_client_connection( + &self, + server_name: &str, + tls: Arc, + ) -> Result, BuildClientTlsError> { + use build_client_tls_error::SetParameterSnafu; + + // Propagate the endpoint's named identity into the QUIC transport + // `ClientName` parameter so the peer can populate its + // `remote_authority` (identity-based access control on the server + // relies on this). + let client = self.client.load_full(); + let mut parameters = client.parameters.clone(); + if let Some(named) = self.identity.load_full() { + parameters + .set( + crate::dquic::qbase::param::ParameterId::ClientName, + named.name.to_string(), + ) + .context(SetParameterSnafu)?; + } + let builder = Connection::new_client(server_name.to_owned(), client.token_sink.clone()) + .with_parameters(parameters) + .with_tls_config((*tls).clone()) + .with_streams_concurrency_strategy(client.stream_strategy_factory.as_ref()) + .with_zero_rtt(client.enable_0rtt); + let connection = self + .network + .quic() + .configure_connection(builder) + .with_defer_idle_timeout(client.defer_idle_timeout) + .with_cids(ConnectionId::random_gen(8)) + .with_qlog(client.qlogger.clone()) + .run(); + Ok(connection) + } +} + +impl QuicEndpoint { + async fn ensure_server(&self) -> Result { + use accept_error::BindServerSnafu; + + let named = match self.identity.load_full() { + None => return Err(AcceptError::ServerUnavailable), + Some(id) => id, + }; + if let Some(cached) = self.server_binding_cache.load_full() { + return Ok(cached.as_ref().clone()); + } + let server = self.server.load_full(); + let binding = self + .network + .quic() + .bind_server(named, (*server).clone(), self.bind.clone()) + .await + .context(BindServerSnafu)?; + self.server_binding_cache + .store(Some(Arc::new(binding.clone()))); + Ok(binding) + } + + /// Eagerly build the client TLS configuration so the first `connect()` + /// is fast. Silently ignores errors — the real error surfaces at + /// connection time. + fn init_client(&self) { + let _ = self.ensure_client(); + } + + /// Eagerly register the server SNI so `accept()` can return immediately. + /// No-op when no identity is configured. Silently ignores errors. + async fn init_server(&self) { + if self.identity.load_full().is_some() { + let _ = self.ensure_server().await; + } + } + + /// Accept an inbound connection, using the cached server binding. + pub async fn accept(&self) -> Result, AcceptError> { + let binding = self.ensure_server().await?; + let conn = binding.recv().await.ok_or(AcceptError::Shutdown)?; + let mut observer = self.network.quic().locations().subscribe(); + let weak = Arc::downgrade(&conn); + let patterns: Vec = self.bind.iter().cloned().collect(); + // Inherent termination: this companion task tracks local + // address changes for the accepted connection. It exits when + // (a) conn is dropped (weak.upgrade fails), (b) conn terminates, + // or (c) the locations observer channel is exhausted. + tokio::spawn( + async move { + let mut local_addresses = LocalAddressTracker::default(); + loop { + let Some(c) = weak.upgrade() else { break }; + tokio::select! { + _ = c.terminated() => break, + event = observer.recv() => { + let Some((bind_uri, event)) = event else { break }; + if !patterns.iter().any(|p| p.matches(&bind_uri)) { + continue; + } + let Some(c) = weak.upgrade() else { break }; + local_addresses.handle(&c, bind_uri, event); + } + } + } + } + .in_current_span(), + ); + Ok(conn) + } +} + +impl quic::Connect for QuicEndpoint { + type Connection = Connection; + type Error = ConnectError; + + async fn connect( + &self, + server: &http::uri::Authority, + ) -> Result, Self::Error> { + use connect_error::{DnsSnafu, TlsSnafu}; + + let full = server.as_str(); + let server_str = full.rsplit_once('@').map_or(full, |(_, host)| host); + + tracing::debug!(server = server_str, "connecting quic endpoint"); + let tls = self.ensure_client().context(TlsSnafu)?; + let mut server_eps = + futures::StreamExt::fuse(self.resolver.lookup(server_str).await.context(DnsSnafu)?); + let connection = self + .build_client_connection(server_str, tls) + .context(TlsSnafu)?; + tracing::trace!( + server = server_str, + timeout_ms = self.connect_path_timeout().as_millis(), + bind_pattern_count = self.bind.len(), + "waiting for quic path" + ); + + // Two legs driving path establishment: + // 1. DNS results → add peer endpoints + // 2. Locations events → add local endpoints + // + // Locations replays the current bound addresses when a subscriber is + // registered, so connect does not scan bound interfaces and synthesize + // local addresses on its own. + let mut observer = self.network.quic().locations().subscribe(); + let bind = self.bind.clone(); + let connect_timeout = tokio::time::sleep(self.connect_path_timeout()); + tokio::pin!(connect_timeout); + let mut peer_agent_seen = false; + let mut local_addresses = LocalAddressTracker::default(); + + loop { + tokio::select! { + biased; + _ = connection.terminated() => { + return Err(ConnectError::NoReachableEndpoint); + } + result = connection.handshaked(), if peer_agent_seen => { + result.map_err(|_| ConnectError::NoReachableEndpoint)?; + Self::spawn_path_discovery_companion( + bind.clone(), + connection.clone(), + server_eps, + observer, + local_addresses, + ); + tracing::debug!(server = server_str, "quic endpoint handshaked before agent path"); + return Ok(connection); + } + _ = &mut connect_timeout => { + connection + .validate() + .map_err(|_| ConnectError::NoReachableEndpoint)?; + Self::spawn_path_discovery_companion( + bind.clone(), + connection.clone(), + server_eps, + observer, + local_addresses, + ); + return Ok(connection); + } + Some((source, server_ep)) = server_eps.next() => { + tracing::trace!( + server = server_str, + ?source, + endpoint = ?server_ep, + "resolved peer endpoint" + ); + peer_agent_seen |= Self::endpoint_is_agent(server_ep); + Self::add_resolved_peer_endpoint( + &connection, + source, + server_ep, + ); + // Resolver streams often contain a batch of equivalent endpoints + // that are ready immediately (for example STUN-agent and + // same-link direct records). Install the ready batch before + // evaluating readiness so a Direct record cannot return before + // a peer Agent record in the same DNS response is observed. + peer_agent_seen |= Self::drain_ready_peer_endpoints( + &connection, + &mut server_eps, + ); + } + Some((bind_uri, event)) = observer.recv() => { + if !bind.iter().any(|p| p.matches(&bind_uri)) { + tracing::trace!(%bind_uri, "ignoring location event outside bind patterns"); + continue; + } + tracing::trace!(%bind_uri, "handling location event for connect"); + local_addresses.handle(&connection, bind_uri, event); + } + } + + match Self::connection_has_paths(&connection, peer_agent_seen) { + Ok(true) => { + Self::spawn_path_discovery_companion( + bind.clone(), + connection.clone(), + server_eps, + observer, + local_addresses, + ); + tracing::debug!(server = server_str, "quic endpoint has at least one path"); + return Ok(connection); + } + Ok(false) => {} // no paths yet — continue loop + Err(_) => return Err(ConnectError::NoReachableEndpoint), + } + } + } +} + +impl quic::Listen for QuicEndpoint { + type Connection = Connection; + type Error = AcceptError; + + async fn accept(&mut self) -> Result, Self::Error> { + QuicEndpoint::accept(self).await + } + + async fn shutdown(&self) -> Result<(), Self::Error> { + self.server_binding_cache.store(None); + Ok(()) + } +} + +impl quic::Listen for &QuicEndpoint { + type Connection = Connection; + type Error = AcceptError; + + async fn accept(&mut self) -> Result, Self::Error> { + (**self).accept().await + } + + async fn shutdown(&self) -> Result<(), Self::Error> { + (**self).shutdown().await + } +} + +#[allow(dead_code)] +impl QuicEndpoint { + fn invalidate_client_cache(&self) { + self.client_tls_cache.store(None); + } + + fn invalidate_server_cache(&self) { + self.server_binding_cache.store(None); + } + + fn invalidate_caches(&self) { + self.invalidate_client_cache(); + self.invalidate_server_cache(); + } + + /// Obtain a mutable guard for the endpoint's identity. + /// + /// The guard implements [`DerefMut`](std::ops::DerefMut) targeting + /// `Option>`, so callers may inspect, mutate, set, or + /// clear the identity. Caches are invalidated when the guard is dropped. + pub fn identity_mut(&mut self) -> IdentityMutGuard<'_> { + let value = self.identity.load_full(); + IdentityMutGuard { + identity: self.identity.as_ref(), + value, + client_cache: &self.client_tls_cache, + server_cache: &self.server_binding_cache, + } + } + + /// Update the OCSP staple for this endpoint's identity. + /// + /// Convenience wrapper around [`Self::identity_mut`]; no-op when there + /// is no identity set. + pub fn update_ocsp(&mut self, ocsp: Option>) { + let mut guard = self.identity_mut(); + if let Some(arc) = guard.as_mut() { + Arc::make_mut(arc).ocsp = Arc::new(ocsp); + } + } + + /// Obtain a mutable guard for the client configuration. + /// + /// The guard implements [`DerefMut`](std::ops::DerefMut) targeting + /// [`ClientQuicConfig`], so callers can mutate fields directly. + /// When the guard is dropped, the `client_tls_cache` is invalidated. + pub fn client_config_mut(&mut self) -> ClientConfigMutGuard<'_> { + ClientConfigMutGuard { + config: self.client.load_full(), + target: self.client.as_ref(), + cache: &self.client_tls_cache, + } + } + + /// Obtain a mutable guard for the server configuration. + /// + /// The guard implements [`DerefMut`](std::ops::DerefMut) targeting + /// [`ServerQuicConfig`], so callers can mutate fields directly. + /// When the guard is dropped, the `server_binding_cache` is invalidated. + pub fn server_config_mut(&mut self) -> ServerConfigMutGuard<'_> { + ServerConfigMutGuard { + config: self.server.load_full(), + target: self.server.as_ref(), + cache: &self.server_binding_cache, + } + } +} + +/// RAII guard for mutable access to [`QuicEndpoint`]'s identity. +/// +/// On drop, stores the modified identity back into the endpoint's +/// `ArcSwapOption` and invalidates both caches (identity changes +/// affect both client TLS auth cert and server SNI binding). +pub struct IdentityMutGuard<'a> { + identity: &'a ArcSwapOption, + value: Option>, + client_cache: &'a ArcSwapOption, + server_cache: &'a ArcSwapOption, +} + +impl<'a> std::ops::Deref for IdentityMutGuard<'a> { + type Target = Option>; + fn deref(&self) -> &Self::Target { + &self.value + } +} + +impl<'a> std::ops::DerefMut for IdentityMutGuard<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.value + } +} + +impl<'a> Drop for IdentityMutGuard<'a> { + fn drop(&mut self) { + self.identity.store(self.value.take()); + self.client_cache.store(None); + self.server_cache.store(None); + } +} + +#[derive(Default)] +struct LocalAddressTracker { + direct: HashMap, + stun: HashMap, +} + +impl LocalAddressTracker { + fn handle( + &mut self, + conn: &Connection, + bind_uri: BindUri, + event: crate::dquic::qinterface::component::location::AddressEvent, + ) { + use crate::dquic::qinterface::component::location::AddressEvent; + + let event = match event.downcast::>() { + Ok(event) => { + self.handle_direct(conn, bind_uri, event); + return; + } + Err(event) => event, + }; + + match event.downcast::() { + Ok(event) => self.handle_stun(conn, bind_uri, event), + Err(AddressEvent::Upsert(data)) => { + let type_id = data.as_ref().type_id(); + tracing::trace!( + %bind_uri, + ?type_id, + "ignoring unknown local address upsert event" + ); + } + Err(AddressEvent::Remove(type_id)) => { + tracing::trace!( + %bind_uri, + ?type_id, + "ignoring unknown local address remove event" + ); + } + Err(AddressEvent::Closed) => self.remove_all(conn, &bind_uri), + } + } + + fn handle_direct( + &mut self, + conn: &Connection, + bind_uri: BindUri, + event: crate::dquic::qinterface::component::location::AddressEvent< + std::io::Result, + >, + ) { + use crate::dquic::qinterface::component::location::AddressEvent; + + match event { + AddressEvent::Upsert(data) => match data.as_ref() { + Ok(addr) => self.upsert_direct(conn, bind_uri, *addr), + Err(error) => { + tracing::trace!( + %bind_uri, + error = %Report::from_error(error), + "direct local address update failed" + ); + self.remove_direct(conn, &bind_uri); + } + }, + AddressEvent::Remove(_type_id) => self.remove_direct(conn, &bind_uri), + AddressEvent::Closed => self.remove_all(conn, &bind_uri), + } + } + + fn handle_stun( + &mut self, + conn: &Connection, + bind_uri: BindUri, + event: crate::dquic::qinterface::component::location::AddressEvent< + crate::dquic::qtraversal::nat::client::ClientLocationData, + >, + ) { + use crate::dquic::qinterface::component::location::AddressEvent; + + match event { + AddressEvent::Upsert(data) => match data.as_ref() { + Ok(endpoint) => self.upsert_stun(conn, bind_uri, *endpoint), + Err(error) => { + tracing::trace!( + %bind_uri, + error = %Report::from_error(error), + "stun local address update failed" + ); + self.remove_stun(conn, &bind_uri); + } + }, + AddressEvent::Remove(_type_id) => self.remove_stun(conn, &bind_uri), + AddressEvent::Closed => self.remove_all(conn, &bind_uri), + } + } + + fn upsert_direct(&mut self, conn: &Connection, bind_uri: BindUri, addr: SocketAddr) { + if self.direct.get(&bind_uri) == Some(&addr) { + return; + } + if let Some(previous) = self.direct.insert(bind_uri.clone(), addr) { + self.remove_endpoint(conn, &bind_uri, EndpointAddr::direct(previous)); + } + if let Err(error) = conn.add_local_endpoint(bind_uri, EndpointAddr::direct(addr)) { + tracing::trace!( + error = %Report::from_error(&error), + "failed to add local endpoint" + ); + } + } + + fn remove_direct(&mut self, conn: &Connection, bind_uri: &BindUri) { + if let Some(previous) = self.direct.remove(bind_uri) { + self.remove_endpoint(conn, bind_uri, EndpointAddr::direct(previous)); + } + } + + fn upsert_stun(&mut self, conn: &Connection, bind_uri: BindUri, endpoint: EndpointAddr) { + if self.stun.get(&bind_uri) == Some(&endpoint) { + return; + } + if let Some(previous) = self.stun.insert(bind_uri.clone(), endpoint) { + self.remove_stun_endpoint(conn, &bind_uri, previous); + } + QuicEndpoint::add_local_endpoint_for_peer(conn, bind_uri, endpoint); + } + + fn remove_stun(&mut self, conn: &Connection, bind_uri: &BindUri) { + if let Some(previous) = self.stun.remove(bind_uri) { + self.remove_stun_endpoint(conn, bind_uri, previous); + } + } + + fn remove_stun_endpoint(&self, conn: &Connection, bind_uri: &BindUri, endpoint: EndpointAddr) { + self.remove_endpoint(conn, bind_uri, endpoint); + if matches!(endpoint, EndpointAddr::Agent { .. }) + && let Err(error) = conn.remove_address(endpoint.addr()) + { + tracing::trace!( + error = %Report::from_error(&error), + "failed to remove local punch endpoint" + ); + } + } + + fn remove_endpoint(&self, conn: &Connection, bind_uri: &BindUri, endpoint: EndpointAddr) { + if let Err(error) = conn.remove_local_endpoint(bind_uri, endpoint) { + tracing::trace!( + error = %Report::from_error(&error), + "failed to remove local endpoint" + ); + } + } + + fn remove_all(&mut self, conn: &Connection, bind_uri: &BindUri) { + self.remove_direct(conn, bind_uri); + self.remove_stun(conn, bind_uri); + } +} + +impl QuicEndpoint { + fn connect_path_timeout(&self) -> Duration { + self.client + .load_full() + .parameters + .get::(crate::dquic::qbase::param::ParameterId::MaxIdleTimeout) + .filter(|timeout| !timeout.is_zero()) + .unwrap_or(Duration::from_secs(20)) + } + + fn connection_has_paths( + connection: &Connection, + require_agent_path: bool, + ) -> Result { + let ctx = connection.path_context()?; + Ok(ctx + .paths::>() + .into_iter() + .any(|(pathway, _)| Self::pathway_is_connect_ready(pathway, require_agent_path))) + } + + fn pathway_is_connect_ready( + pathway: crate::dquic::qbase::net::route::Pathway, + require_agent_path: bool, + ) -> bool { + match (pathway.local(), pathway.remote()) { + (EndpointAddr::Direct { addr: local }, EndpointAddr::Direct { addr: remote }) => { + if require_agent_path { + return false; + } + Self::direct_path_is_connect_ready(local, remote) + } + _ => true, + } + } + + fn direct_path_is_connect_ready(local: SocketAddr, remote: SocketAddr) -> bool { + local.ip().is_loopback() == remote.ip().is_loopback() + } + + fn spawn_path_discovery_companion( + bind: Arc>, + connection: Arc, + mut server_eps: S, + mut observer: crate::dquic::qinterface::component::location::Observer, + local_addresses: LocalAddressTracker, + ) where + S: Stream + Unpin + Send + 'static, + { + let weak = Arc::downgrade(&connection); + // Inherent termination: this path-discovery companion drains remaining + // DNS results and local address events after connect has established + // that the connection will not leak. It exits when (a) connection is + // dropped (weak.upgrade fails), (b) connection terminates, or (c) both + // DNS and observer streams are exhausted. + tokio::spawn( + async move { + let mut local_addresses = local_addresses; + let mut dns_done = false; + let mut locations_done = false; + loop { + if dns_done && locations_done { + break; + } + let Some(c) = weak.upgrade() else { break }; + tokio::select! { + biased; + _ = c.terminated() => break, + result = server_eps.next(), if !dns_done => { + let Some((s, e)) = result else { + dns_done = true; + continue; + }; + let Some(c) = weak.upgrade() else { break }; + Self::add_resolved_peer_endpoint(&c, s, e); + } + result = observer.recv(), if !locations_done => { + let Some((uri, e)) = result else { + locations_done = true; + continue; + }; + let Some(c) = weak.upgrade() else { break }; + if bind.iter().any(|p| p.matches(&uri)) { + local_addresses.handle(&c, uri, e); + } + } + } + } + } + .in_current_span(), + ); + } + + fn add_resolved_peer_endpoint( + connection: &Connection, + source: Source, + server_ep: EndpointAddr, + ) { + let _ = connection.add_peer_endpoint(server_ep, source); + } + + fn drain_ready_peer_endpoints(connection: &Connection, server_eps: &mut S) -> bool + where + S: Stream + Unpin, + { + let mut agent_seen = false; + while let Some(Some((source, server_ep))) = server_eps.next().now_or_never() { + agent_seen |= Self::endpoint_is_agent(server_ep); + Self::add_resolved_peer_endpoint(connection, source, server_ep); + } + agent_seen + } + + fn endpoint_is_agent(endpoint: EndpointAddr) -> bool { + matches!(endpoint, EndpointAddr::Agent { .. }) + } + + fn add_local_endpoint_for_peer(conn: &Connection, bind_uri: BindUri, endpoint: EndpointAddr) { + if let Err(error) = conn.add_local_endpoint(bind_uri.clone(), endpoint) { + tracing::trace!( + error = %Report::from_error(&error), + "failed to add local endpoint" + ); + } + if matches!(endpoint, EndpointAddr::Agent { .. }) { + match conn.add_local_punch_address(bind_uri, endpoint) { + Ok(Ok(())) => {} + Ok(Err(error)) => { + tracing::trace!( + error = %Report::from_error(&error), + "failed to add local punch endpoint" + ); + } + Err(error) => { + tracing::trace!( + error = %Report::from_error(&error), + "failed to add local punch endpoint" + ); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use std::any::TypeId; + + use rustls::{RootCertStore, client::WebPkiServerVerifier, pki_types::PrivateKeyDer}; + + use super::*; + use crate::dquic::{cert::handy::ToCertificate, resolver::handy::SystemResolver}; + + const CA_CERT: &[u8] = include_bytes!("../../tests/keychain/localhost/ca.cert"); + + fn root_store_with_ca() -> RootCertStore { + let mut store = RootCertStore::empty(); + store.add_parsable_certificates(CA_CERT.to_certificate()); + store + } + + #[tokio::test] + async fn test_quic_endpoint_construction() { + let network = Network::builder().build(); + let resolver = Arc::new(SystemResolver); + let client = ClientQuicConfig::default(); + let server = ServerQuicConfig::default(); + + let endpoint = QuicEndpoint::builder() + .network(network.clone()) + .resolver(resolver.clone()) + .client(client) + .server(server) + .build() + .await; + + assert!(Arc::ptr_eq(endpoint.network(), &network)); + assert!(endpoint.identity().is_none()); + } + + #[tokio::test] + async fn test_accept_with_no_identity() { + let network = Network::builder().build(); + let resolver = Arc::new(SystemResolver); + let client = ClientQuicConfig::default(); + let server = ServerQuicConfig::default(); + + let endpoint = QuicEndpoint::builder() + .network(network.clone()) + .resolver(resolver.clone()) + .client(client) + .server(server) + .build() + .await; + + let result = endpoint.accept().await; + assert!(matches!(result, Err(AcceptError::ServerUnavailable))); + } + + #[tokio::test] + async fn test_network_locations_accessor() { + // Verify that the built-in QUIC locations accessor is available. + let network = Network::builder().build(); + let _locations = network.quic().locations(); + // Just verify we can access it without panicking + } + + #[tokio::test] + async fn test_dual_task_connect_structure() { + // Verify that connect() sets up the dual-task model infrastructure + // This test verifies the structure is in place, even if full Task A + // implementation depends on Locations API details + let network = Network::builder().build(); + let resolver = Arc::new(SystemResolver); + let client = ClientQuicConfig::default(); + let server = ServerQuicConfig::default(); + + let _endpoint = QuicEndpoint::builder() + .network(network.clone()) + .resolver(resolver.clone()) + .client(client) + .server(server) + .build() + .await; + + // The dual-task model is set up internally in connect() + // We verify it doesn't panic or error during setup + // (actual connection would require valid DNS resolution) + } + + #[tokio::test] + async fn local_location_before_dns_peer_still_creates_path() { + let bind_pattern = BindPattern::from_str("inet://127.0.0.1:0").expect("valid bind pattern"); + let bind = Arc::new(vec![bind_pattern.clone()]); + let endpoint = QuicEndpoint::builder() + .network(Network::builder().build()) + .bind(bind.clone()) + .build() + .await; + let tls = endpoint.ensure_client().expect("client tls"); + let connection = endpoint + .build_client_connection("remote.test", tls) + .expect("client connection"); + let iface = endpoint + .network() + .quic() + .get_interfaces(&bind_pattern) + .expect("registered bind") + .into_iter() + .next() + .expect("bound interface"); + + use crate::dquic::net::IO as _; + let bind_uri = iface.bind_uri(); + let local_addr = iface.borrow().bound_addr().expect("bound address"); + let data: Arc = + Arc::new(Ok::(local_addr)); + let event = crate::dquic::qinterface::component::location::AddressEvent::Upsert(data); + + let mut local_addresses = LocalAddressTracker::default(); + local_addresses.handle(&connection, bind_uri, event); + + let remote_addr = SocketAddr::new(local_addr.ip(), local_addr.port() + 1); + QuicEndpoint::add_resolved_peer_endpoint( + &connection, + Source::System, + EndpointAddr::direct(remote_addr), + ); + + assert!( + QuicEndpoint::connection_has_paths(&connection, false).expect("path context"), + "local endpoint replayed before DNS peer must be retained for peer pairing" + ); + } + + #[tokio::test] + async fn loopback_to_non_loopback_direct_path_is_not_connect_ready() { + let bind_pattern = BindPattern::from_str("inet://127.0.0.1:0").expect("valid bind pattern"); + let bind_endpoint = QuicEndpoint::builder() + .network(Network::builder().build()) + .bind(Arc::new(vec![bind_pattern.clone()])) + .build() + .await; + let tls = bind_endpoint.ensure_client().expect("client tls"); + let connection = bind_endpoint + .build_client_connection("server.example", tls) + .expect("client connection"); + let iface = bind_endpoint + .network() + .quic() + .get_interfaces(&bind_pattern) + .expect("registered bind") + .into_iter() + .next() + .expect("bound interface"); + + use crate::dquic::net::IO as _; + let bind_uri = iface.bind_uri(); + let local_addr = iface.borrow().bound_addr().expect("bound address"); + let data: Arc = + Arc::new(Ok::(local_addr)); + let mut local_addresses = LocalAddressTracker::default(); + local_addresses.handle( + &connection, + bind_uri, + crate::dquic::qinterface::component::location::AddressEvent::Upsert(data), + ); + + QuicEndpoint::add_resolved_peer_endpoint( + &connection, + Source::H3 { + server: Arc::from("https://dns.genmeta.net:4433"), + }, + EndpointAddr::direct("10.10.0.100:47388".parse().expect("remote addr")), + ); + + assert!( + !QuicEndpoint::connection_has_paths(&connection, false).expect("path context"), + "loopback-to-non-loopback direct paths are not ready for client handshake" + ); + } + + #[tokio::test] + async fn peer_agent_requires_agent_path_before_connect_ready() { + let bind_pattern = BindPattern::from_str("inet://127.0.0.1:0").expect("valid bind pattern"); + let endpoint = QuicEndpoint::builder() + .network(Network::builder().build()) + .bind(Arc::new(vec![bind_pattern.clone()])) + .build() + .await; + let tls = endpoint.ensure_client().expect("client tls"); + let connection = endpoint + .build_client_connection("server.example", tls) + .expect("client connection"); + let iface = endpoint + .network() + .quic() + .get_interfaces(&bind_pattern) + .expect("registered bind") + .into_iter() + .next() + .expect("bound interface"); + + use crate::dquic::net::IO as _; + let bind_uri = iface.bind_uri(); + let local_addr = iface.borrow().bound_addr().expect("bound address"); + let remote_direct = SocketAddr::new(local_addr.ip(), local_addr.port() + 1); + let stun_agent: SocketAddr = "10.10.0.2:20004".parse().expect("stun agent"); + + let mut local_addresses = LocalAddressTracker::default(); + local_addresses.handle( + &connection, + bind_uri.clone(), + crate::dquic::qinterface::component::location::AddressEvent::Upsert(Arc::new(Ok::< + SocketAddr, + std::io::Error, + >( + local_addr, + ))), + ); + QuicEndpoint::add_resolved_peer_endpoint( + &connection, + Source::H3 { + server: Arc::from("https://dns.genmeta.net:4433"), + }, + EndpointAddr::with_agent(stun_agent, "10.10.0.40:20000".parse().expect("outer addr")), + ); + QuicEndpoint::add_resolved_peer_endpoint( + &connection, + Source::H3 { + server: Arc::from("https://dns.genmeta.net:4433"), + }, + EndpointAddr::direct(remote_direct), + ); + + assert!( + QuicEndpoint::connection_has_paths(&connection, false).expect("path context"), + "direct loopback path is connect-ready when no Agent path is required" + ); + assert!( + !QuicEndpoint::connection_has_paths(&connection, true).expect("path context"), + "peer Agent endpoints must not return on Direct-only paths" + ); + + connection + .add_local_endpoint(bind_uri, EndpointAddr::with_agent(stun_agent, local_addr)) + .expect("local agent endpoint"); + assert!( + QuicEndpoint::connection_has_paths(&connection, true).expect("path context"), + "local Agent endpoint should make peer Agent connection ready" + ); + } + + #[tokio::test] + async fn connect_helpers_cover_identity_verifier_and_ready_endpoint_paths() { + let mut endpoint = make_endpoint().await; + let webpki = WebPkiServerVerifier::builder(Arc::new(root_store_with_ca())) + .build() + .expect("webpki verifier"); + { + let mut client = endpoint.client_config_mut(); + client.verifier = ServerCertVerifierChoice::WebPki(webpki); + } + endpoint + .build_client_tls() + .expect("webpki verifier client tls"); + + { + let mut client = endpoint.client_config_mut(); + client.verifier = + ServerCertVerifierChoice::Custom(Arc::new(DangerousServerCertVerifier)); + } + let tls = Arc::new( + endpoint + .build_client_tls() + .expect("custom verifier client tls"), + ); + { + let mut identity = endpoint.identity_mut(); + *identity = Some(Arc::new(make_identity("client-name.test"))); + } + endpoint + .build_client_connection("server.example", tls) + .expect("client connection with identity parameter"); + + let bind_pattern = BindPattern::from_str("inet://127.0.0.1:0").expect("valid bind pattern"); + let bind_endpoint = QuicEndpoint::builder() + .network(Network::builder().build()) + .bind(Arc::new(vec![bind_pattern.clone()])) + .build() + .await; + let tls = bind_endpoint.ensure_client().expect("client tls"); + let connection = bind_endpoint + .build_client_connection("server.example", tls) + .expect("client connection"); + let iface = bind_endpoint + .network() + .quic() + .get_interfaces(&bind_pattern) + .expect("registered bind") + .into_iter() + .next() + .expect("bound interface"); + + use crate::dquic::net::IO as _; + let bind_uri = iface.bind_uri(); + let local_addr = iface.borrow().bound_addr().expect("bound address"); + let data: Arc = + Arc::new(Ok::(local_addr)); + let mut local_addresses = LocalAddressTracker::default(); + local_addresses.handle( + &connection, + bind_uri, + crate::dquic::qinterface::component::location::AddressEvent::Upsert(data), + ); + + let mut endpoints = futures::stream::iter([ + ( + Source::System, + EndpointAddr::direct(SocketAddr::new(local_addr.ip(), local_addr.port() + 1)), + ), + ( + Source::Dht, + EndpointAddr::direct(SocketAddr::new(local_addr.ip(), local_addr.port() + 2)), + ), + ]); + QuicEndpoint::drain_ready_peer_endpoints(&connection, &mut endpoints); + assert!( + QuicEndpoint::connection_has_paths(&connection, false).expect("path context"), + "ready peer endpoints should pair with the retained local endpoint" + ); + } + + #[tokio::test] + async fn local_address_tracker_removes_direct_endpoint_on_remove() { + use crate::dquic::qinterface::component::location::AddressEvent; + + let bind_pattern = BindPattern::from_str("inet://127.0.0.1:0").expect("valid bind pattern"); + let endpoint = QuicEndpoint::builder() + .network(Network::builder().build()) + .bind(Arc::new(vec![bind_pattern.clone()])) + .build() + .await; + let tls = endpoint.ensure_client().expect("client tls"); + let connection = endpoint + .build_client_connection("remote.test", tls) + .expect("client connection"); + let iface = endpoint + .network() + .quic() + .get_interfaces(&bind_pattern) + .expect("registered bind") + .into_iter() + .next() + .expect("bound interface"); + + use crate::dquic::net::IO as _; + let bind_uri = iface.bind_uri(); + let local_addr = iface.borrow().bound_addr().expect("bound address"); + let remote_port = if local_addr.port() == u16::MAX { + local_addr.port() - 1 + } else { + local_addr.port() + 1 + }; + let remote_addr = SocketAddr::new(local_addr.ip(), remote_port); + let mut tracker = LocalAddressTracker::default(); + + tracker.handle(&connection, bind_uri.clone(), { + let data: Arc = + Arc::new(Ok::(local_addr)); + AddressEvent::Upsert(data) + }); + QuicEndpoint::add_resolved_peer_endpoint( + &connection, + Source::System, + EndpointAddr::direct(remote_addr), + ); + assert!( + QuicEndpoint::connection_has_paths(&connection, false).expect("path context"), + "direct local endpoint should pair with direct remote endpoint" + ); + + tracker.handle( + &connection, + bind_uri, + AddressEvent::Remove(TypeId::of::>()), + ); + assert!( + !QuicEndpoint::connection_has_paths(&connection, false).expect("path context"), + "removing the direct local endpoint should deactivate its path" + ); + } + + #[tokio::test] + async fn local_address_tracker_preserves_connect_state_after_move() { + use crate::dquic::qinterface::component::location::{AddressEvent, Locations}; + + let bind_pattern = BindPattern::from_str("inet://127.0.0.1:0").expect("valid bind pattern"); + let bind = Arc::new(vec![bind_pattern.clone()]); + let endpoint = QuicEndpoint::builder() + .network(Network::builder().build()) + .bind(bind.clone()) + .build() + .await; + let tls = endpoint.ensure_client().expect("client tls"); + let connection = endpoint + .build_client_connection("remote.test", tls) + .expect("client connection"); + let iface = endpoint + .network() + .quic() + .get_interfaces(&bind_pattern) + .expect("registered bind") + .into_iter() + .next() + .expect("bound interface"); + + use crate::dquic::net::IO as _; + let bind_uri = iface.bind_uri(); + let local_addr = iface.borrow().bound_addr().expect("bound address"); + let remote_port = if local_addr.port() == u16::MAX { + local_addr.port() - 1 + } else { + local_addr.port() + 1 + }; + let remote_addr = SocketAddr::new(local_addr.ip(), remote_port); + let mut tracker = LocalAddressTracker::default(); + + tracker.handle(&connection, bind_uri.clone(), { + let data: Arc = + Arc::new(Ok::(local_addr)); + AddressEvent::Upsert(data) + }); + QuicEndpoint::add_resolved_peer_endpoint( + &connection, + Source::System, + EndpointAddr::direct(remote_addr), + ); + assert!( + QuicEndpoint::connection_has_paths(&connection, false).expect("path context"), + "direct local endpoint should pair with direct remote endpoint before tracker handoff" + ); + + let locations = Locations::new(); + let observer = locations.subscribe(); + QuicEndpoint::spawn_path_discovery_companion( + bind, + connection.clone(), + futures::stream::empty(), + observer, + tracker, + ); + + tokio::time::timeout(Duration::from_secs(1), async { + loop { + locations.remove::>(bind_uri.clone()); + if !matches!( + QuicEndpoint::connection_has_paths(&connection, false), + Ok(true) + ) { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("moved tracker should remove the pre-connect direct local endpoint"); + } + + #[tokio::test] + async fn local_address_tracker_ignores_unknown_payloads() { + use crate::dquic::qinterface::component::location::AddressEvent; + + let endpoint = make_endpoint().await; + let tls = endpoint.ensure_client().expect("client tls"); + let connection = endpoint + .build_client_connection("remote.test", tls) + .expect("client connection"); + let bind_uri: BindUri = "inet://127.0.0.1:0".parse().expect("bind uri"); + let mut tracker = LocalAddressTracker::default(); + + let unknown_payload: Arc = Arc::new("not a location payload"); + tracker.handle(&connection, bind_uri, AddressEvent::Upsert(unknown_payload)); + + assert!( + !QuicEndpoint::connection_has_paths(&connection, false).expect("path context"), + "unknown events should not create paths" + ); + } + + fn make_identity(name: &str) -> Identity { + Identity { + name: name.parse().unwrap(), + certs: Arc::new(vec![]), + key: Arc::new(PrivateKeyDer::Pkcs8(b"dummy-key-data".to_vec().into())), + ocsp: Arc::new(None), + } + } + + async fn make_endpoint() -> QuicEndpoint { + QuicEndpoint::builder() + .network(Network::builder().build()) + .resolver(Arc::new(SystemResolver)) + .client(ClientQuicConfig::default()) + .server(ServerQuicConfig::default()) + .build() + .await + } + + #[tokio::test] + async fn public_new_uses_default_anonymous_endpoint_shape() { + let endpoint = QuicEndpoint::new().await; + + assert!(endpoint.identity().is_none()); + assert_eq!(endpoint.bind_patterns().len(), 1); + assert_eq!(endpoint.bind_patterns()[0].to_string(), "iface://*"); + assert!(endpoint.client_tls_cache.load_full().is_some()); + assert!(endpoint.server_binding_cache.load_full().is_none()); + } + + #[tokio::test] + async fn accessors_return_builder_supplied_components_and_default_bind() { + let network = Network::builder().build(); + let resolver = Arc::new(SystemResolver); + let endpoint = QuicEndpoint::builder() + .network(network.clone()) + .resolver(resolver.clone()) + .build() + .await; + + assert!(Arc::ptr_eq(endpoint.network(), &network)); + assert!(Arc::ptr_eq(endpoint.resolver(), &(resolver as Arc<_>))); + assert!(endpoint.identity().is_none()); + assert_eq!(endpoint.bind_patterns().len(), 1); + assert_eq!(endpoint.bind_patterns()[0].to_string(), "iface://*"); + } + + #[tokio::test] + async fn ensure_client_caches_tls_but_clone_does_not_share_cache() { + let endpoint = make_endpoint().await; + + let first = endpoint.ensure_client().expect("client tls should build"); + let second = endpoint + .ensure_client() + .expect("client tls should be cached"); + assert!(Arc::ptr_eq(&first, &second)); + + let cloned = endpoint.clone(); + assert!(cloned.client_tls_cache.load_full().is_none()); + } + + #[tokio::test] + async fn build_client_tls_copies_alpns_and_early_data_flag() { + let mut endpoint = make_endpoint().await; + { + let mut client = endpoint.client_config_mut(); + client.alpns = vec![b"h3".to_vec(), b"dhttp".to_vec()]; + client.enable_0rtt = true; + } + + let tls = endpoint + .build_client_tls() + .expect("client tls should build"); + + assert_eq!(tls.alpn_protocols, endpoint.client.load_full().alpns); + assert!(tls.enable_early_data); + } + + #[tokio::test] + async fn test_identity_mut_set() { + let mut endpoint = make_endpoint().await; + assert!(endpoint.identity().is_none()); + + let identity = make_identity("test.example.com"); + { + let mut guard = endpoint.identity_mut(); + *guard = Some(Arc::new(identity)); + } + + let id = endpoint.identity(); + assert!(id.is_some()); + assert_eq!(id.unwrap().name.as_str(), "test.example.com"); + } + + #[tokio::test] + async fn test_identity_mut_clear() { + let mut endpoint = make_endpoint().await; + let identity = make_identity("clear.me"); + { + let mut guard = endpoint.identity_mut(); + *guard = Some(Arc::new(identity)); + } + assert!(endpoint.identity().is_some()); + + { + let mut guard = endpoint.identity_mut(); + *guard = None; + } + + assert!(endpoint.identity().is_none()); + } + + #[tokio::test] + async fn test_identity_mut_mutate() { + let mut endpoint = make_endpoint().await; + let identity = make_identity("mutate.me"); + { + let mut guard = endpoint.identity_mut(); + *guard = Some(Arc::new(identity)); + } + assert!(endpoint.identity().unwrap().ocsp.is_none()); + + { + let mut guard = endpoint.identity_mut(); + if let Some(arc) = guard.as_mut() { + Arc::make_mut(arc).ocsp = Arc::new(Some(vec![10, 20, 30])); + } + } + + let id = endpoint.identity().unwrap(); + assert_eq!(id.ocsp.as_deref(), Some(&[10u8, 20, 30][..])); + } + + #[tokio::test] + async fn test_update_ocsp_with_identity() { + let mut endpoint = make_endpoint().await; + let identity = make_identity("ocsp.example.com"); + { + let mut guard = endpoint.identity_mut(); + *guard = Some(Arc::new(identity)); + } + + endpoint.update_ocsp(Some(vec![1, 2, 3])); + + let id = endpoint.identity().unwrap(); + assert_eq!(id.ocsp.as_deref(), Some(&[1u8, 2, 3][..])); + } + + #[tokio::test] + async fn test_update_ocsp_without_identity_no_panic() { + let mut endpoint = make_endpoint().await; + assert!(endpoint.identity().is_none()); + + endpoint.update_ocsp(Some(vec![9, 8, 7])); + assert!(endpoint.identity().is_none()); + } + + #[tokio::test] + async fn identity_guard_deref_and_invalidate_caches_cover_direct_helpers() { + let mut endpoint = make_endpoint().await; + endpoint.client_tls_cache.store(Some(Arc::new( + rustls::ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(), + ))); + + { + let mut guard = endpoint.identity_mut(); + assert!(std::ops::Deref::deref(&guard).is_none()); + *guard = Some(Arc::new(make_identity("deref-helper.test"))); + assert_eq!( + std::ops::Deref::deref(&guard) + .as_ref() + .expect("guard identity") + .name + .as_str(), + "deref-helper.test" + ); + } + assert!(endpoint.identity().is_some()); + + endpoint.client_tls_cache.store(Some(Arc::new( + rustls::ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(), + ))); + endpoint.invalidate_caches(); + assert!(endpoint.client_tls_cache.load_full().is_none()); + assert!(endpoint.server_binding_cache.load_full().is_none()); + } + + #[tokio::test] + async fn test_client_config_mut() { + let mut endpoint = make_endpoint().await; + assert!(!endpoint.client.load_full().enable_0rtt); + + { + let mut guard = endpoint.client_config_mut(); + guard.enable_0rtt = true; + } // guard dropped → cache invalidated + + assert!(endpoint.client.load_full().enable_0rtt); + } + + #[tokio::test] + async fn test_client_config_mut_drop_invalidates_cache() { + let mut endpoint = make_endpoint().await; + // Pre-fill the cache by triggering ensure_client + let _ = endpoint.ensure_client().is_ok(); + assert!(endpoint.client_tls_cache.load_full().is_some()); + + { + let _guard = endpoint.client_config_mut(); + // No mutation needed — just creating and dropping should invalidate + } + + assert!(endpoint.client_tls_cache.load_full().is_none()); + } + + #[tokio::test] + async fn test_server_config_mut() { + let mut endpoint = make_endpoint().await; + assert!(!endpoint.server.load_full().anti_port_scan); + + { + let mut guard = endpoint.server_config_mut(); + guard.anti_port_scan = true; + } + + assert!(endpoint.server.load_full().anti_port_scan); + } + + #[tokio::test] + async fn test_server_config_mut_drop_invalidates_cache() { + let mut endpoint = make_endpoint().await; + + { + let _guard = endpoint.server_config_mut(); + } + + // Guard drop should clear the cache (whether it was filled or not) + assert!(endpoint.server_binding_cache.load_full().is_none()); + } + + #[tokio::test] + async fn clone_shares_identity_updates() { + let mut endpoint = make_endpoint().await; + let mut cloned = endpoint.clone(); + + { + let mut guard = endpoint.identity_mut(); + *guard = Some(Arc::new(make_identity("shared-identity.test"))); + } + + assert_eq!( + cloned + .identity() + .expect("clone sees shared identity") + .name + .as_str(), + "shared-identity.test" + ); + + { + let mut guard = cloned.identity_mut(); + *guard = None; + } + + assert!(endpoint.identity().is_none()); + } + + #[tokio::test] + async fn set_resolver_on_clone_does_not_change_original_resolver() { + let endpoint = make_endpoint().await; + let original = endpoint.resolver().clone(); + let replacement: Arc = Arc::new(SystemResolver); + let mut cloned = endpoint.clone(); + + cloned.set_resolver(replacement.clone()); + + assert!(Arc::ptr_eq(endpoint.resolver(), &original)); + assert!(Arc::ptr_eq(cloned.resolver(), &replacement)); + } + + #[tokio::test] + async fn test_clone_preserves_identity() { + let mut endpoint = make_endpoint().await; + let identity = make_identity("clone-preserve.test"); + { + let mut guard = endpoint.identity_mut(); + *guard = Some(Arc::new(identity)); + } + + let cloned = endpoint.clone(); + let orig_id = endpoint.identity().unwrap(); + let cloned_id = cloned.identity().unwrap(); + + assert_eq!(orig_id.name.as_str(), cloned_id.name.as_str()); + assert_eq!(orig_id.certs.len(), cloned_id.certs.len()); + assert_eq!(orig_id.ocsp.as_deref(), cloned_id.ocsp.as_deref()); + } + + #[tokio::test] + async fn clone_shares_identity_clear() { + let mut endpoint = make_endpoint().await; + let identity = make_identity("shared-clear.test"); + { + let mut guard = endpoint.identity_mut(); + *guard = Some(Arc::new(identity)); + } + + let cloned = endpoint.clone(); + + { + let mut guard = endpoint.identity_mut(); + *guard = None; + } + assert!(endpoint.identity().is_none()); + assert!(cloned.identity().is_none()); + } + + #[test] + fn test_build_client_tls_error_variants() { + // Cannot construct BuildClientTlsError without real rustls errors; + // verify the enum definition compiles and variant names are correct. + let _ = |e: BuildClientTlsError| match e { + BuildClientTlsError::Version { .. } => "version", + BuildClientTlsError::ClientAuth { .. } => "client_auth", + BuildClientTlsError::SetParameter { .. } => "set_param", + }; + } + + #[test] + fn test_connect_error_variant_discrimination() { + let err = ConnectError::NoReachableEndpoint; + match &err { + ConnectError::Tls { .. } => panic!("expected NoReachableEndpoint, got Tls"), + ConnectError::Dns { .. } => panic!("expected NoReachableEndpoint, got Dns"), + ConnectError::NoReachableEndpoint => {} + } + + let _ = |e: BuildClientTlsError| match e { + BuildClientTlsError::Version { .. } + | BuildClientTlsError::ClientAuth { .. } + | BuildClientTlsError::SetParameter { .. } => {} + }; + } + + #[test] + fn test_accept_error_variant_discrimination() { + let unavailable = AcceptError::ServerUnavailable; + match &unavailable { + AcceptError::ServerUnavailable => {} + AcceptError::BindServer { .. } => panic!("expected ServerUnavailable"), + AcceptError::Shutdown => panic!("expected ServerUnavailable"), + } + + let shutdown = AcceptError::Shutdown; + match &shutdown { + AcceptError::Shutdown => {} + _ => panic!("expected Shutdown"), + } + } + + #[test] + fn test_accept_error_server_unavailable_message() { + let err = AcceptError::ServerUnavailable; + let msg = err.to_string(); + assert!( + msg.contains("cannot accept"), + "AcceptError::ServerUnavailable message should contain 'cannot accept', got: {msg}" + ); + } + + #[test] + fn test_accept_error_shutdown_message() { + let err = AcceptError::Shutdown; + let msg = err.to_string(); + assert!( + msg.contains("shut down"), + "AcceptError::Shutdown message should contain 'shut down', got: {msg}" + ); + } + + #[test] + fn test_connect_error_display() { + let err = ConnectError::NoReachableEndpoint; + let msg = err.to_string(); + assert!( + msg.contains("no reachable endpoint"), + "ConnectError::NoReachableEndpoint message should mention reachable endpoint, got: {msg}" + ); + } + + #[test] + fn test_endpoint_error_alias() { + let _: EndpointError = ConnectError::NoReachableEndpoint; + let _: ConnectError = EndpointError::NoReachableEndpoint; + + fn _accept_endpoint_error(_: EndpointError) {} + fn _accept_connect_error(_: ConnectError) {} + } + + #[test] + fn test_server_config_guard_derefmut() { + let config = Arc::new(ServerQuicConfig::default()); + let target = ArcSwap::from(config); + let cache = ArcSwapOption::::new(None); + { + let mut guard = ServerConfigMutGuard { + config: target.load_full(), + target: &target, + cache: &cache, + }; + assert!(!guard.anti_port_scan); + guard.anti_port_scan = true; + assert!(guard.anti_port_scan); + } + assert!(target.load_full().anti_port_scan); + } + + #[test] + fn test_server_config_guard_drop_invalidates_cache() { + use std::sync::Weak; + + use dashmap::DashMap; + use rustls::{pki_types::CertificateDer, sign::CertifiedKey}; + + use crate::dquic::{ + cert::handy::{ToCertificate, ToPrivateKey}, + sni::{RegistryGuard, ServerConfig as SniServerConfig, ServerEntry}, + }; + + let config = Arc::new(ServerQuicConfig::default()); + let target = ArcSwap::from(config); + let cache = ArcSwapOption::new(None); + + let cert_bytes: &[u8] = include_bytes!("../../tests/keychain/localhost/server.cert"); + let key_bytes: &[u8] = include_bytes!("../../tests/keychain/localhost/server.key"); + let certs: Vec> = cert_bytes.to_certificate(); + let key: PrivateKeyDer<'static> = key_bytes.to_private_key(); + + let identity = Arc::new(Identity { + name: "localhost".parse().unwrap(), + certs: Arc::new(certs.clone()), + key: Arc::new(key.clone_key()), + ocsp: Arc::new(None), + }); + + let provider = rustls::ServerConfig::builder().crypto_provider().clone(); + let signing_key = provider + .key_provider + .load_private_key(identity.key.clone_key()) + .expect("valid private key"); + let certified_key = Arc::new(CertifiedKey { + cert: identity.certs.iter().cloned().collect(), + key: signing_key, + ocsp: None, + }); + + let rustls_config = Arc::new( + rustls::ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(certs, key) + .expect("valid server config"), + ); + + let (tx, rx) = async_channel::unbounded(); + + let registry = Arc::new(DashMap::new()); + let reg_guard = Arc::new(RegistryGuard { + name: "localhost".parse().unwrap(), + registry: Arc::downgrade(®istry), + self_entry: Weak::new(), + }); + + let sni_config = Arc::new(SniServerConfig { + config: ServerQuicConfig::default(), + rustls_config, + }); + + let entry = Arc::new(ServerEntry { + identity, + certified_key, + incomings_tx: tx, + incomings_rx: rx, + config: sni_config, + guard: reg_guard, + bind: Arc::new(vec![]), + }); + + let binding = ServerBinding { entry }; + cache.store(Some(Arc::new(binding))); + + assert!(cache.load_full().is_some()); + + { + let _guard = ServerConfigMutGuard { + config: target.load_full(), + target: &target, + cache: &cache, + }; + } + + assert!(cache.load_full().is_none()); + } + + #[tokio::test] + async fn test_identity_guard_set() { + let endpoint = make_endpoint().await; + assert!(endpoint.identity().is_none()); + + let identity = make_identity("guard-set.test"); + { + let mut guard = IdentityMutGuard { + identity: endpoint.identity.as_ref(), + value: endpoint.identity.load_full(), + client_cache: &endpoint.client_tls_cache, + server_cache: &endpoint.server_binding_cache, + }; + *guard = Some(Arc::new(identity)); + } + + let id = endpoint.identity(); + assert!(id.is_some()); + assert_eq!(id.unwrap().name.as_str(), "guard-set.test"); + } + + #[tokio::test] + async fn test_identity_guard_clear() { + let mut endpoint = make_endpoint().await; + let identity = make_identity("guard-clear.test"); + { + let mut guard = endpoint.identity_mut(); + *guard = Some(Arc::new(identity)); + } + assert!(endpoint.identity().is_some()); + + { + let mut guard = IdentityMutGuard { + identity: endpoint.identity.as_ref(), + value: endpoint.identity.load_full(), + client_cache: &endpoint.client_tls_cache, + server_cache: &endpoint.server_binding_cache, + }; + *guard = None; + } + + assert!(endpoint.identity().is_none()); + } + + #[tokio::test] + async fn test_identity_guard_mutate() { + let mut endpoint = make_endpoint().await; + let identity = make_identity("guard-mutate.test"); + { + let mut guard = endpoint.identity_mut(); + *guard = Some(Arc::new(identity)); + } + assert!(endpoint.identity().unwrap().ocsp.is_none()); + + { + let mut guard = IdentityMutGuard { + identity: endpoint.identity.as_ref(), + value: endpoint.identity.load_full(), + client_cache: &endpoint.client_tls_cache, + server_cache: &endpoint.server_binding_cache, + }; + if let Some(arc) = guard.as_mut() { + Arc::make_mut(arc).ocsp = Arc::new(Some(vec![10, 20, 30])); + } + } + + let id = endpoint.identity().unwrap(); + assert_eq!(id.ocsp.as_deref(), Some(&[10u8, 20, 30][..])); + } + + #[tokio::test] + async fn test_identity_guard_drop_invalidates_both_caches() { + let mut endpoint = make_endpoint().await; + let identity = make_identity("cache-inval.test"); + { + let mut guard = endpoint.identity_mut(); + *guard = Some(Arc::new(identity)); + } + + // Pre-fill client TLS cache + let client_config = rustls::ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(); + endpoint + .client_tls_cache + .store(Some(Arc::new(client_config))); + + // Pre-fill server binding cache — construct a minimal ServerBinding + // through the network infrastructure + endpoint.init_server().await; + + { + let guard = IdentityMutGuard { + identity: endpoint.identity.as_ref(), + value: endpoint.identity.load_full(), + client_cache: &endpoint.client_tls_cache, + server_cache: &endpoint.server_binding_cache, + }; + // no mutation — guard drops here + let _ = &guard; + } + + assert!( + endpoint.client_tls_cache.load_full().is_none(), + "client TLS cache should be invalidated on guard drop" + ); + assert!( + endpoint.server_binding_cache.load_full().is_none(), + "server binding cache should be invalidated on guard drop" + ); + } + + #[test] + fn test_client_config_guard_derefmut() { + let config = Arc::new(ClientQuicConfig::default()); + let target = ArcSwap::from(config); + let cache = ArcSwapOption::::new(None); + { + let mut guard = ClientConfigMutGuard { + config: target.load_full(), + target: &target, + cache: &cache, + }; + assert!(!std::ops::Deref::deref(&guard).enable_0rtt); + guard.enable_0rtt = true; + } + assert!(target.load_full().enable_0rtt); + } + + #[tokio::test] + async fn quic_listen_trait_impls_delegate_accept_and_shutdown() { + let mut owned = make_endpoint().await; + let owned_error = match ::accept(&mut owned).await { + Ok(_) => panic!("anonymous owned endpoint cannot accept"), + Err(error) => error, + }; + assert!(matches!(owned_error, AcceptError::ServerUnavailable)); + ::shutdown(&owned) + .await + .expect("owned shutdown"); + + let endpoint = make_endpoint().await; + let mut shared = &endpoint; + let shared_error = match <&QuicEndpoint as quic::Listen>::accept(&mut shared).await { + Ok(_) => panic!("anonymous shared endpoint cannot accept"), + Err(error) => error, + }; + assert!(matches!(shared_error, AcceptError::ServerUnavailable)); + <&QuicEndpoint as quic::Listen>::shutdown(&shared) + .await + .expect("shared shutdown"); + } + + #[test] + fn test_client_config_guard_drop_invalidates_cache() { + let config = Arc::new(ClientQuicConfig::default()); + let target = ArcSwap::from(config); + let cache = ArcSwapOption::new(Some(Arc::new( + rustls::ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(), + ))); + + assert!(cache.load_full().is_some()); + + { + let _guard = ClientConfigMutGuard { + config: target.load_full(), + target: &target, + cache: &cache, + }; + } + + assert!(cache.load_full().is_none()); + } + + #[tokio::test] + async fn test_cache_invalidation_is_precise() { + let endpoint = make_endpoint().await; + + // Pre-fill client cache + let dummy_client_tls = rustls::ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(); + endpoint + .client_tls_cache + .store(Some(Arc::new(dummy_client_tls))); + + assert!(endpoint.client_tls_cache.load_full().is_some()); + assert!(endpoint.server_binding_cache.load_full().is_none()); + + // Client cache invalidation should NOT affect server cache + endpoint.invalidate_client_cache(); + assert!(endpoint.client_tls_cache.load_full().is_none()); + assert!(endpoint.server_binding_cache.load_full().is_none()); + + // Re-fill client, invalidate server — client should remain intact + endpoint.client_tls_cache.store(Some(Arc::new( + rustls::ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(), + ))); + endpoint.invalidate_server_cache(); + assert!(endpoint.client_tls_cache.load_full().is_some()); + assert!(endpoint.server_binding_cache.load_full().is_none()); + } + + #[test] + fn test_invalidate_caches_clears_both() { + // Verify the method structure compiles and that invalidate_caches + // delegates to both precise helpers. + let cache_client: ArcSwapOption = ArcSwapOption::empty(); + let cache_server: ArcSwapOption = ArcSwapOption::empty(); + + // Populate client cache + let client_config = rustls::ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(); + cache_client.store(Some(Arc::new(client_config))); + + assert!(cache_client.load_full().is_some()); + assert!(cache_server.load_full().is_none()); + } +} diff --git a/src/dquic/identity.rs b/src/dquic/identity.rs new file mode 100644 index 0000000..41c829b --- /dev/null +++ b/src/dquic/identity.rs @@ -0,0 +1,196 @@ +//! Identity used by a [`QuicEndpoint`](super::QuicEndpoint) when performing +//! TLS handshakes. +//! +//! An [`Identity`] bundles the SNI (server name) with the certificate chain +//! and private key. When stored in an endpoint as `Option>`, +//! cloning is cheap — the identity is shared through an `Arc`. +//! +//! The endpoint's identity selects between client-auth / server-auth paths +//! and keys the SNI registry for inbound connection multiplexing. + +use std::sync::Arc; + +pub use dhttp_identity::{identity::Identity, name::Name}; + +/// Build a [`CertifiedKey`](rustls::sign::CertifiedKey) from an [`Identity`] +/// for use in rustls. +/// +/// The returned key is wrapped in [`Arc`] so it can be shared across +/// many TLS sessions. +pub(crate) fn build_certified_key( + identity: &Identity, +) -> Result, crate::dquic::network::BindServerError> { + use snafu::ResultExt; + + use crate::dquic::network::bind_server_error::LoadKeySnafu; + + let provider = rustls::ServerConfig::builder().crypto_provider().clone(); + let key = provider + .key_provider + .load_private_key(identity.key.clone_key()) + .context(LoadKeySnafu)?; + Ok(Arc::new(rustls::sign::CertifiedKey { + cert: identity.certs.iter().cloned().collect(), + key, + ocsp: identity.ocsp.as_ref().clone(), + })) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use rustls::pki_types::PrivateKeyDer; + + use super::*; + + #[test] + fn test_certs_arc_sharing() { + let certs = vec![]; + let id1 = Identity { + name: "test".parse().unwrap(), + certs: Arc::new(certs), + key: Arc::new(PrivateKeyDer::Pkcs8(b"dummy".to_vec().into())), + ocsp: Arc::new(None), + }; + let id2 = id1.clone(); + assert!( + Arc::ptr_eq(&id1.certs, &id2.certs), + "certs should be shared via Arc" + ); + } + + #[test] + fn test_ocsp_default_none() { + let id = Identity { + name: "test".parse().unwrap(), + certs: Arc::new(vec![]), + key: Arc::new(PrivateKeyDer::Pkcs8(b"dummy".to_vec().into())), + ocsp: Arc::new(None), + }; + assert!(id.ocsp.is_none(), "ocsp should be None by default"); + } + + #[test] + fn test_ocsp_update_independent() { + let id1 = Identity { + name: "test".parse().unwrap(), + certs: Arc::new(vec![]), + key: Arc::new(PrivateKeyDer::Pkcs8(b"dummy".to_vec().into())), + ocsp: Arc::new(None), + }; + let mut id2 = id1.clone(); + // Simulate updating id2's ocsp by creating a new Arc + id2.ocsp = Arc::new(Some(vec![1, 2, 3])); + assert!(id1.ocsp.is_none(), "id1.ocsp should remain None"); + assert!(id2.ocsp.is_some(), "id2.ocsp should be Some"); + } + + #[test] + fn test_name_lowercase_normalization() { + let name: Name = "LOCALHOST".parse().unwrap(); + assert_eq!(name.as_str(), "localhost"); + } + + #[test] + fn test_name_display() { + let name: Name = "MyHost".parse().unwrap(); + assert_eq!(format!("{name}"), "myhost"); + } + + #[test] + fn test_name_deref() { + let name: Name = "AbCdEfG".parse().unwrap(); + let s: &str = &name; + assert_eq!(s, "abcdefg", "Name should deref to lowercased str"); + } + + #[test] + fn test_name_borrow() { + use std::borrow::Borrow; + let name: Name = "BorrowTest".parse().unwrap(); + let s: &str = Borrow::::borrow(&name); + assert_eq!( + s, "borrowtest", + "Borrow should return the lowercased name" + ); + } + + #[test] + fn test_name_eq() { + let a: Name = "EXAMPLE".parse().unwrap(); + let b: Name = "example".parse().unwrap(); + assert_eq!(a, b, "same name in different case should be equal"); + + let c: Name = "other".parse().unwrap(); + assert_ne!(a, c, "different names should not be equal"); + } + + #[test] + fn test_name_hash_consistency() { + use std::{ + collections::hash_map::DefaultHasher, + hash::{Hash, Hasher}, + }; + + let hash = |n: &Name| { + let mut h = DefaultHasher::new(); + n.hash(&mut h); + h.finish() + }; + + let a: Name = "MiXeDcAsE".parse().unwrap(); + let b: Name = "mixedcase".parse().unwrap(); + assert_eq!( + hash(&a), + hash(&b), + "same logical name should hash to same value regardless of input case" + ); + } + + #[test] + fn test_identity_clone_preserves_fields() { + let key = PrivateKeyDer::Pkcs8(b"dummy".to_vec().into()); + let id1 = Identity { + name: "clone-test".parse().unwrap(), + certs: Arc::new(vec![]), + key: Arc::new(key), + ocsp: Arc::new(None), + }; + let id2 = id1.clone(); + + assert_eq!(id1.name, id2.name, "cloned name should equal original"); + assert!( + Arc::ptr_eq(&id1.certs, &id2.certs), + "certs should be Arc-shared after clone" + ); + assert!( + Arc::ptr_eq(&id1.key, &id2.key), + "key should be Arc-shared after clone" + ); + } + + #[test] + fn test_build_certified_key_with_valid_key() { + use dquic::prelude::handy::{ToCertificate, ToPrivateKey}; + + const SERVER_CERT: &[u8] = include_bytes!("../../tests/keychain/localhost/server.cert"); + const SERVER_KEY: &[u8] = include_bytes!("../../tests/keychain/localhost/server.key"); + + let certs = SERVER_CERT.to_certificate(); + let key = SERVER_KEY.to_private_key(); + + let identity = Identity { + name: "localhost".parse().unwrap(), + certs: Arc::new(certs), + key: Arc::new(key), + ocsp: Arc::new(None), + }; + + let result = build_certified_key(&identity); + assert!( + result.is_ok(), + "build_certified_key should succeed with valid key material" + ); + } +} diff --git a/src/dquic/network.rs b/src/dquic/network.rs new file mode 100644 index 0000000..890e448 --- /dev/null +++ b/src/dquic/network.rs @@ -0,0 +1,3402 @@ +//! Process-shared QUIC network infrastructure. +//! +//! [`Network`] owns the long-lived bind registry and device reconciliation +//! task. Its built-in [`QuicBindDriver`] owns the QUIC runtime components +//! needed to send and receive QUIC packets on a set of network interfaces, +//! including the SNI fan-out registry used to route newly-accepted +//! connections to per-identity +//! [`QuicEndpoint`](crate::dquic::endpoint::QuicEndpoint) listeners. +//! +//! A [`Network`] is always used via [`Arc`]: clone the [`Arc`] to +//! share the infrastructure between many endpoints. The builder returns an +//! [`Arc`] directly via [`Network::builder().build()`](Network::builder). +//! +//! ## SNI dispatch +//! +//! The built-in QUIC driver installs a *connectionless packet dispatcher* on +//! its [`QuicRouter`]. When an Initial / 0-RTT packet arrives without matching +//! an existing connection, the dispatcher constructs a fresh server +//! [`Connection`] using the shared +//! [`ServerQuicConfig`] stored in the +//! QUIC driver's `server_slot`, waits for the handshake to reveal the +//! ClientHello SNI, and fans the connection into the matching +//! [`ServerBinding`]'s mpmc queue. +//! +//! Because the underlying rustls `ServerConfig` must be chosen *before* SNI +//! is known, the QUIC driver stores **one** `ServerQuicConfig` at a time. +//! The first call to [`QuicBindDriver::bind_server`] initialises that slot; +//! subsequent calls with an identical (or `Arc::ptr_eq`) configuration +//! succeed, while a conflicting configuration is rejected with +//! [`BindServerError::ServerConfigConflict`]. When the last [`ServerBinding`] +//! referring to the slot drops, the slot clears and a different configuration +//! may be used on the next bind. +//! +//! ## Binds registry +//! +//! [`QuicBindDriver::bind`] and [`Network::bind_with`] register a +//! [`BindPattern`] with reference counting. Repeated calls with the same +//! driver and pattern increment the count; when every +//! returned [`BindHandle`] is dropped or [`BindHandle::unbind`] is called, +//! the count reaches zero and the bound interfaces are released. +//! A single background reconcile task (started at build time) +//! watches for device changes and keeps every pattern's bound interfaces in +//! sync. + +use std::{ + collections::HashMap, + io, mem, + net::{IpAddr, SocketAddr}, + sync::{Arc, Mutex, MutexGuard, OnceLock, RwLock, Weak}, + task::{Context, Poll}, +}; + +use dashmap::DashMap; +use dhttp_identity::name::Name; +use futures::{FutureExt, future::BoxFuture}; +use snafu::Snafu; +use tokio::sync::{Mutex as TokioMutex, OwnedMutexGuard}; +use tokio_util::task::AbortOnDropHandle; +use tracing::Instrument; + +// Re-export ServerBinding so `crate::dquic::network::ServerBinding` resolves. +// `sni` module is the canonical source; this re-export is kept for compatibility +// with existing consumers (e.g. `crate::dquic::endpoint`). +pub use crate::dquic::sni::ServerBinding; +use crate::dquic::{ + binds::BindPattern, + connection::Connection, + identity::Identity, + net::{ + BindInterface, BindUri, Devices, Family, InterfaceManager, Locations, ProductIO, + QuicRouter, handy::DEFAULT_IO_FACTORY, + }, + resolver::{Resolve, handy::SystemResolver}, + server::ServerQuicConfig, + sni::{self, RegistryGuard, ServerConfig, ServerEntry, SniCertResolver}, + tls::{ + AuthClient, ClientAuthorityVerifyResult, ClientNameVerifyResult, LocalAuthority, + RemoteAuthority, + }, +}; +// Internal implementation types — not part of curated domain modules +use crate::dquic::{ + qbase::packet::Packet, + qconnection::builder::ConnectionFoundation, + qinterface::{ + component::{ + location::LocationsComponent, + route::{QuicRouterComponent, Way}, + }, + device::InterfaceEvent, + }, + qtraversal::{ + nat::{client::StunClientsComponent, router::StunRouterComponent}, + route::{ForwardersComponent, ReceiveAndDeliverPacketComponent}, + }, +}; + +pub(crate) type SniRegistry = Arc, Weak>>; +type BoundInterfaces = HashMap; + +fn interface_contains(interface: &::dquic::qinterface::device::Interface, ip: IpAddr) -> bool { + match ip { + IpAddr::V4(ip) => interface.ipv4.iter().any(|net| net.contains(&ip)), + IpAddr::V6(ip) => interface.ipv6.iter().any(|net| net.contains(&ip)), + } +} + +fn iface_bind_uri_device(uri: &BindUri) -> Option<&str> { + uri.as_iface_bind_uri() + .map(|(_family, device, _port)| device) +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct BindRegistryKey { + driver: BindDriverId, + pattern: BindPattern, +} + +/// Opaque identity of an [`Arc`]-backed [`BindDriver`]. +/// +/// The bind registry stores a strong `Arc` next to every key, +/// so the allocation behind this pointer identity cannot be freed or reused +/// while the key is live. A `Weak` key would still need the same pointer-based +/// hashing/equality, while also mixing liveness into an identity-only key. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct BindDriverId(usize); + +fn bind_driver_id(driver: &Arc) -> BindDriverId { + BindDriverId(Arc::as_ptr(driver) as *const () as usize) +} + +/// Driver that binds one concrete [`BindUri`] into a runtime binding resource. +pub trait BindDriver: Send + Sync { + fn bind<'a>(&'a self, network: &'a Network, uri: BindUri) -> BoxFuture<'a, BindInterface>; + + fn rebind<'a>(&'a self, _network: &'a Network, _iface: &'a BindInterface) -> BoxFuture<'a, ()> { + async {}.boxed() + } +} + +/// Built-in QUIC binding runtime owned by [`Network`]. +/// +/// Obtain it with [`Network::quic`]. Methods that register or query binds use +/// the originating [`Network`]'s bind registry; keep the `Network` alive while +/// using a cloned `Arc`. +pub struct QuicBindDriver { + network: Weak, + iface_manager: Arc, + io_factory: Arc, + stun_resolver: Arc, + stun_server: Option>, + quic_router: Arc, + locations: Arc, + sni_registry: SniRegistry, + server_slot: RwLock>, +} + +#[bon::bon] +impl QuicBindDriver { + #[builder] + fn new( + network: Weak, + #[builder(default = Arc::new(InterfaceManager::new()))] iface_manager: Arc< + InterfaceManager, + >, + #[builder(default = Arc::new(DEFAULT_IO_FACTORY))] io_factory: Arc, + #[builder(default = Arc::new(SystemResolver))] stun_resolver: Arc< + dyn Resolve + Send + Sync, + >, + stun_server: Option>, + #[builder(default = Arc::new(QuicRouter::new()))] quic_router: Arc, + #[builder(default = Arc::new(Locations::new()))] locations: Arc, + ) -> Arc { + let driver = Arc::new(Self { + network, + iface_manager, + io_factory, + stun_resolver, + stun_server, + quic_router, + locations, + sni_registry: Arc::new(DashMap::new()), + server_slot: RwLock::new(Weak::new()), + }); + + let weak_driver = Arc::downgrade(&driver); + let installed = driver + .quic_router + .on_connectless_packets(move |packet, way| { + let Some(driver) = weak_driver.upgrade() else { + return; + }; + Self::dispatch_initial_packet(driver, packet, way); + }); + tracing::debug!(installed, "sni dispatcher installation"); + + driver + } +} + +impl BindDriver for QuicBindDriver { + fn bind<'a>(&'a self, _network: &'a Network, uri: BindUri) -> BoxFuture<'a, BindInterface> { + async move { + let iface = self.iface_manager.bind(uri, self.io_factory.clone()).await; + self.init_quic_iface_components(&iface); + iface + } + .boxed() + } + + fn rebind<'a>(&'a self, _network: &'a Network, iface: &'a BindInterface) -> BoxFuture<'a, ()> { + async move { + iface.rebind().await; + } + .boxed() + } +} + +impl QuicBindDriver { + fn registry_id(&self) -> BindDriverId { + BindDriverId(self as *const Self as *const () as usize) + } + + fn network(&self) -> Arc { + self.network + .upgrade() + .expect("quic driver is detached from network") + } + + #[must_use] + pub fn iface_manager(&self) -> Arc { + self.iface_manager.clone() + } + + #[must_use] + pub fn io_factory(&self) -> Arc { + self.io_factory.clone() + } + + #[must_use] + pub fn quic_router(&self) -> Arc { + self.quic_router.clone() + } + + #[must_use] + pub fn locations(&self) -> Arc { + self.locations.clone() + } + + #[must_use] + pub fn stun_server(&self) -> Option> { + self.stun_server.clone() + } + + #[must_use] + pub fn stun_resolver(&self) -> Arc { + self.stun_resolver.clone() + } + + pub(crate) fn configure_connection( + &self, + builder: ConnectionFoundation, + ) -> ConnectionFoundation { + builder + .with_iface_manager(self.iface_manager.clone()) + .with_iface_factory(self.io_factory.clone()) + .with_quic_router(self.quic_router.clone()) + .with_locations(self.locations.clone()) + } + + fn init_quic_iface_components(&self, bind_iface: &BindInterface) { + let uri = bind_iface.bind_uri(); + let stun_server = if let Some(server) = uri.stun_server() { + Some(Arc::from(server.into_owned())) + } else if let Some("false") = uri.prop(BindUri::STUN_PROP).as_deref() { + None + } else { + self.stun_server.clone() + }; + + bind_iface.with_components_mut(|components, iface| { + let router = components + .init_with(|| QuicRouterComponent::new(self.quic_router.clone())) + .router(); + + if let Some(stun_server) = stun_server { + let stun_router = components + .init_with(|| StunRouterComponent::new(iface.downgrade())) + .router(); + let loc = components + .init_with(|| { + LocationsComponent::new(iface.downgrade(), self.locations.clone()) + }) + .clone(); + let clients = components + .init_with(|| { + StunClientsComponent::new( + iface.downgrade(), + stun_router.clone(), + self.stun_resolver.clone(), + stun_server, + Vec::new(), + Some(loc), + ) + }) + .clone(); + let forwarder = components + .init_with(|| ForwardersComponent::new_client(clients)) + .forwarder(); + components.init_with(|| { + ReceiveAndDeliverPacketComponent::builder(iface.downgrade()) + .quic_router(router) + .stun_router(stun_router) + .forwarder(forwarder) + .init() + }); + } else { + components.init_with(|| { + LocationsComponent::new(iface.downgrade(), self.locations.clone()) + }); + components.init_with(|| { + ReceiveAndDeliverPacketComponent::builder(iface.downgrade()) + .quic_router(router) + .init() + }); + } + }); + } + + fn server_binding_for_existing(existing: &Arc) -> ServerBinding { + ServerBinding { + entry: existing.clone(), + } + } + + fn compatible_server_slot( + &self, + server_config: &ServerQuicConfig, + ) -> Result, BindServerError> { + use bind_server_error::*; + + let slot_guard = self.server_slot.read().unwrap(); + if let Some(slot) = slot_guard.upgrade() { + if !slot.config.is_compatible_with(server_config) { + return ServerConfigConflictSnafu.fail(); + } + return Ok(slot); + } + + drop(slot_guard); + let mut slot_guard = self.server_slot.write().unwrap(); + if let Some(slot) = slot_guard.upgrade() { + if !slot.config.is_compatible_with(server_config) { + return ServerConfigConflictSnafu.fail(); + } + return Ok(slot); + } + + let resolver = SniCertResolver { + registry: Arc::downgrade(&self.sni_registry), + }; + let rustls_config = Arc::new(server_config.build_rustls_server_config(resolver)?); + let slot = Arc::new(ServerConfig { + config: server_config.clone(), + rustls_config, + }); + *slot_guard = Arc::downgrade(&slot); + Ok(slot) + } + + fn new_server_binding( + &self, + name: Name<'static>, + identity: Arc, + server_config: ServerQuicConfig, + bind_patterns: Arc>, + ) -> Result { + let slot = self.compatible_server_slot(&server_config)?; + + let certified_key = crate::dquic::identity::build_certified_key(&identity)?; + let (incomings_tx, incomings_rx) = async_channel::bounded(server_config.backlog); + + let entry = Arc::new_cyclic(|weak_entry| ServerEntry { + identity: identity.clone(), + certified_key, + incomings_tx, + incomings_rx, + config: slot, + guard: Arc::new(RegistryGuard { + name: name.clone(), + registry: Arc::downgrade(&self.sni_registry), + self_entry: weak_entry.clone(), + }), + bind: bind_patterns.clone(), + }); + + Ok(ServerBinding { entry }) + } + + pub async fn bind_server( + self: Arc, + identity: Arc, + server_config: ServerQuicConfig, + bind_patterns: Arc>, + ) -> Result { + use bind_server_error::*; + use dashmap::mapref::entry::Entry; + + let name = identity.name.clone(); + + match self.sni_registry.entry(name.clone()) { + Entry::Occupied(mut occupied) => match occupied.get().upgrade() { + Some(existing) if Arc::ptr_eq(&existing.identity, &identity) => { + if !existing.config.config.is_compatible_with(&server_config) { + return ServerConfigConflictSnafu.fail(); + } + Ok(Self::server_binding_for_existing(&existing)) + } + Some(_) => SniInUseSnafu { name }.fail(), + None => { + let binding = + self.new_server_binding(name, identity, server_config, bind_patterns)?; + occupied.insert(Arc::downgrade(&binding.entry)); + Ok(binding) + } + }, + Entry::Vacant(vacant) => { + let binding = + self.new_server_binding(name, identity, server_config, bind_patterns)?; + vacant.insert(Arc::downgrade(&binding.entry)); + Ok(binding) + } + } + } + + #[cfg(test)] + fn registered_sni_names(&self) -> Vec> { + self.sni_registry + .iter() + .filter(|kv| kv.value().upgrade().is_some()) + .map(|kv| kv.key().clone()) + .collect() + } + + /// Register a [`BindPattern`] through this built-in QUIC driver. + /// + /// Returns a [`BindHandle`] that keeps the pattern alive. When all + /// handles for a pattern are dropped (or [`BindHandle::unbind`] is + /// called), the bound interfaces are released. + pub async fn bind(self: Arc, pattern: BindPattern) -> BindHandle { + let network = self.network(); + network.bind_with(self.clone(), pattern).await + } + + fn collect_bound(&self, map: impl Fn(&BindUri, &BindInterface) -> T) -> Vec { + let network = self.network(); + let registry = network + .bind_registry + .lock() + .expect("bind_registry poisoned"); + let quic_driver = self.registry_id(); + let mut items = Vec::new(); + for (_, entry) in registry.iter().filter(|(key, _)| key.driver == quic_driver) { + let state = entry.lock_state(); + items.extend(state.bound.iter().map(|(uri, iface)| map(uri, iface))); + } + items + } + + /// Return all currently bound interfaces owned by this QUIC driver. + #[must_use] + pub fn interfaces(&self) -> Vec { + self.collect_bound(|_, iface| iface.clone()) + } + + /// Return all currently bound URIs owned by this QUIC driver. + #[must_use] + pub fn current_bind_uris(&self) -> Vec { + self.collect_bound(|uri, _| uri.clone()) + } + + /// Return the currently bound interface for an exact [`BindUri`]. + #[must_use] + pub fn get_iface(&self, uri: &BindUri) -> Option { + let network = self.network(); + let registry = network + .bind_registry + .lock() + .expect("bind_registry poisoned"); + let quic_driver = self.registry_id(); + registry + .iter() + .filter(|(key, _)| key.driver == quic_driver) + .map(|(_, entry)| entry) + .find_map(|entry| { + let state = entry.lock_state(); + state.bound.get(uri).cloned() + }) + } + + /// Return all currently bound interfaces for a specific [`BindPattern`]. + /// + /// Returns `None` if the pattern has not been registered through this + /// driver. + #[must_use] + pub fn get_interfaces(&self, pattern: &BindPattern) -> Option> { + let network = self.network(); + let key = BindRegistryKey { + driver: self.registry_id(), + pattern: pattern.clone(), + }; + let registry = network + .bind_registry + .lock() + .expect("bind_registry poisoned"); + registry + .get(&key) + .map(|entry| entry.lock_state().bound.values().cloned().collect()) + } + + /// Connectionless packet dispatcher installed on the [`QuicRouter`]. + /// + /// Called for every Initial / 0-RTT packet that doesn't match an + /// existing connection. Spawns a new server [`Connection`] using the + /// shared server slot, waits for the TLS handshake to reveal the SNI, + /// and fans the connection into the matching [`ServerBinding`]'s queue. + fn dispatch_initial_packet(driver: Arc, packet: Packet, way: Way) { + use crate::dquic::qbase::packet::{ + DataHeader, + header::{self, GetDcid}, + }; + + // Filter to Initial and 0-RTT packets. + let (bind_uri, pathway, link) = way; + + let data_packet = match &packet { + Packet::Data(dp) => dp, + Packet::VN(_) | Packet::Retry(_) => return, + }; + + let DataHeader::Long(long_header) = &data_packet.header else { + return; + }; + + let (header::long::DataHeader::Initial(_) | header::long::DataHeader::ZeroRtt(_)) = + long_header + else { + return; + }; + + let origin_dcid = *data_packet.dcid(); + + // Read the slot synchronously: this router callback cannot await. + let Some(slot) = driver.server_slot.try_read().ok().and_then(|g| g.upgrade()) else { + return; + }; + + // Build the connection synchronously so the CID is registered in + // QuicRouter before the first packet is delivered. + let sni_registry = driver.sni_registry.clone(); + + let foundation = Connection::new_server(slot.config.token_provider.clone()) + .with_parameters(slot.config.parameters.clone()) + .with_client_auther(Box::new(( + InterfaceAuthClient { + bind_uri: bind_uri.clone(), + sni_registry: sni_registry.clone(), + }, + slot.config.client_auther.clone(), + ))) + .with_tls_config((*slot.rustls_config).clone()); + + let conn = driver + .configure_connection(foundation) + .with_streams_concurrency_strategy(slot.config.stream_strategy_factory.as_ref()) + .with_defer_idle_timeout(slot.config.defer_idle_timeout) + .with_zero_rtt(slot.config.enable_0rtt) + .with_cids(origin_dcid) + .with_qlog(slot.config.qlogger.clone()) + .run(); + + let quic_router = driver.quic_router.clone(); + + // Spawn only async delivery and SNI dispatch work. + // Inherent termination: this task owns the newly-created server + // connection and exits once SNI dispatch succeeds, the SNI queue is + // closed, or connection name resolution fails. + tokio::spawn( + async move { + quic_router.deliver(packet, (bind_uri, pathway, link)).await; + + // server_name() waits for TLS handshake info — do NOT call + // handshaked() here; for server role it blocks until a Data + // packet arrives, which would deadlock the test until timeout. + let sni = match conn.server_name().await { + Ok(name) => name, + Err(e) => { + let report = snafu::Report::from_error(&e); + tracing::debug!( + error = %report, + "failed to get server name" + ); + return; + } + }; + + let sni_lower = sni.to_ascii_lowercase(); + if let Some(entry) = sni_registry + .get::(&sni_lower) + .and_then(|item| item.value().upgrade()) + { + let incomings_tx = entry.incomings_tx.clone(); + drop(entry); + if incomings_tx.send(conn).await.is_err() { + tracing::debug!( + name = %sni, + "sni channel closed" + ); + } + return; + } + tracing::debug!( + name = %sni, + "no endpoint registered for SNI" + ); + } + .in_current_span(), + ); + } +} + +/// Error returned by [`QuicBindDriver::bind_server`]. +#[derive(Debug, Snafu)] +#[snafu(module, visibility(pub))] +pub enum BindServerError { + /// Another [`Identity`] is already registered for the same SNI. + #[snafu(display("sni {name} is already bound to a different identity"))] + SniInUse { + /// SNI that is already registered. + name: Name<'static>, + }, + /// The QUIC driver already has a server configuration that is + /// incompatible with the one provided. Drop every existing + /// [`ServerBinding`] before binding with a different configuration. + #[snafu(display("quic driver already holds an incompatible server configuration"))] + ServerConfigConflict, + /// Loading the identity's private key into rustls failed. + #[snafu(display("failed to load server private key"))] + LoadKey { + /// Underlying rustls error. + source: rustls::Error, + }, + /// rustls rejected the selected protocol version. + #[snafu(display("failed to select TLS protocol version"))] + Version { + /// Underlying rustls error. + source: rustls::Error, + }, +} + +/// Shared network bind orchestration with a built-in QUIC runtime. +/// +/// Used exclusively via [`Arc`]. [`Network::builder().build()`](Network::builder) +/// returns an [`Arc`] directly after constructing the built-in QUIC driver. +pub struct Network { + devices: &'static Devices, + quic_driver: Arc, + bind_registry: Mutex>>, + _reconcile: OnceLock>, +} + +struct BindsEntry { + key: BindRegistryKey, + driver: Arc, + pattern: BindPattern, + serial: Arc>, + state: Mutex, +} + +struct BindsState { + refcount: usize, + closing: bool, + /// Live bindings keyed by full [`BindUri`] identity. Holds strong + /// [`BindInterface`] references so that the interfaces and their + /// installed components stay alive. + bound: BoundInterfaces, +} + +#[derive(Clone)] +struct BindEntryRef { + key: BindRegistryKey, + entry: Arc, +} + +impl BindsEntry { + fn new(key: BindRegistryKey, driver: Arc) -> Self { + Self { + pattern: key.pattern.clone(), + key, + driver, + serial: Arc::new(TokioMutex::new(())), + state: Mutex::new(BindsState { + refcount: 0, + closing: false, + bound: HashMap::new(), + }), + } + } + + fn lock_state(&self) -> MutexGuard<'_, BindsState> { + self.state.lock().expect("bind entry state mutex poisoned") + } + + fn is_closing(&self) -> bool { + self.lock_state().closing + } + + fn pattern_matches_device(&self, device: &str) -> bool { + self.pattern.interface_bind_uris(device).next().is_some() + } +} + +impl BindsState { + fn missing_current_device_uris( + &self, + pattern: &BindPattern, + devices: &'static Devices, + ) -> Vec { + pattern + .to_bind_uris(devices.interfaces().keys().map(String::as_str)) + .filter(|candidate| { + !self + .bound + .keys() + .any(|bound| bound.matches_reconcile_candidate(candidate)) + }) + .collect() + } + + fn missing_added_device_uris(&self, pattern: &BindPattern, device: &str) -> Vec { + pattern + .interface_bind_uris(device) + .filter(|candidate| { + !self + .bound + .keys() + .any(|bound| bound.matches_reconcile_candidate(candidate)) + }) + .collect() + } + + fn drain_device_bindings(&mut self, device: &str) -> CloseBatch { + let mut retained = HashMap::with_capacity(self.bound.len()); + let mut close = CloseBatch::new(); + + for (uri, iface) in mem::take(&mut self.bound) { + if iface_bind_uri_device(&uri) == Some(device) { + close.push(iface); + } else { + retained.insert(uri, iface); + } + } + + self.bound = retained; + close + } + + fn changed_device_targets(&self, device: &str) -> Vec { + self.bound + .iter() + .filter(|(uri, _iface)| iface_bind_uri_device(uri) == Some(device)) + .map(|(_uri, iface)| iface.clone()) + .collect() + } +} + +#[derive(Default)] +struct CloseBatch { + current: Option>, + ifaces: Vec, +} + +impl CloseBatch { + fn new() -> Self { + Self::default() + } + + fn is_empty(&self) -> bool { + self.current.is_none() && self.ifaces.is_empty() + } + + fn push(&mut self, iface: BindInterface) { + self.ifaces.push(iface); + } + + fn extend(&mut self, ifaces: impl IntoIterator) { + self.ifaces.extend(ifaces); + } + + fn close_iface(iface: BindInterface) -> BoxFuture<'static, ()> { + async move { + iface.close().await.ok(); + } + .boxed() + } + + async fn close_all(&mut self) { + loop { + if self.current.is_none() { + self.current = self.ifaces.pop().map(Self::close_iface); + } + + let Some(current) = self.current.as_mut() else { + return; + }; + + current.as_mut().await; + self.current = None; + } + } + + fn detach(self) -> tokio::task::JoinHandle<()> { + tokio::spawn( + async move { + let mut close = self; + close.close_all().await; + } + .in_current_span(), + ) + } +} + +struct BindEntryPermit { + key: BindRegistryKey, + entry: Arc, + _serial: OwnedMutexGuard<()>, +} + +struct EntryRelease { + network: Arc, + key: BindRegistryKey, + entry: Arc, + permit: Option, + remove_entry: bool, + close: CloseBatch, + completed: bool, +} + +struct PendingBindRegistration { + network: Arc, + permit: Option, + new_bindings: Vec<(BindUri, BindInterface)>, +} + +impl BindEntryPermit { + fn begin_bind_registration(self, network: Arc) -> PendingBindRegistration { + PendingBindRegistration { + network, + permit: Some(self), + new_bindings: Vec::new(), + } + } + + fn plan_added_device_bind(&self, device: &str) -> Vec { + let state = self.entry.lock_state(); + if state.closing { + return Vec::new(); + } + state.missing_added_device_uris(&self.entry.pattern, device) + } + + fn commit_added_device( + &self, + device_exists: bool, + new_bindings: Vec<(BindUri, BindInterface)>, + ) -> CloseBatch { + let mut state = self.entry.lock_state(); + let mut close = CloseBatch::new(); + + if state.closing || !device_exists { + close.extend(new_bindings.into_iter().map(|(_, iface)| iface)); + return close; + } + + for (uri, iface) in new_bindings { + if let std::collections::hash_map::Entry::Vacant(slot) = state.bound.entry(uri) { + slot.insert(iface); + } else { + close.push(iface); + } + } + + close + } + + fn drain_removed_device(&self, device: &str) -> CloseBatch { + let mut state = self.entry.lock_state(); + if state.closing { + return CloseBatch::new(); + } + state.drain_device_bindings(device) + } + + fn targets_for_changed_device(&self, device: &str) -> Vec { + let state = self.entry.lock_state(); + if state.closing { + return Vec::new(); + } + state.changed_device_targets(device) + } + + fn release_one_handle(self, network: Arc) -> EntryRelease { + let mut state = self.entry.lock_state(); + let mut close = CloseBatch::new(); + let mut remove_entry = false; + + if !state.closing && state.refcount > 0 { + state.refcount -= 1; + if state.refcount == 0 { + state.closing = true; + close.extend(state.bound.drain().map(|(_, iface)| iface)); + remove_entry = true; + } + } + + drop(state); + + EntryRelease { + network, + key: self.key.clone(), + entry: self.entry.clone(), + permit: Some(self), + remove_entry, + close, + completed: false, + } + } +} + +impl EntryRelease { + async fn finish(mut self) { + self.close.close_all().await; + + if self.remove_entry { + self.network.remove_entry_if_current(&self.key, &self.entry); + self.remove_entry = false; + } + + self.completed = true; + } +} + +impl Drop for EntryRelease { + fn drop(&mut self) { + if self.completed { + return; + } + + if self.close.is_empty() { + if self.remove_entry { + self.network.remove_entry_if_current(&self.key, &self.entry); + self.remove_entry = false; + } + return; + } + + let Some(permit) = self.permit.take() else { + return; + }; + let network = self.network.clone(); + let key = self.key.clone(); + let entry = self.entry.clone(); + let remove_entry = self.remove_entry; + self.remove_entry = false; + let mut close = mem::take(&mut self.close); + + // Inherent termination: the spawned task owns the exact entry permit + // plus a finite close batch, then removes the same registry entry if + // this was the final release. + tokio::spawn( + async move { + let _permit = permit; + close.close_all().await; + if remove_entry { + network.remove_entry_if_current(&key, &entry); + } + } + .in_current_span(), + ); + } +} + +impl PendingBindRegistration { + fn close_new_bindings(&mut self) -> CloseBatch { + let mut close = CloseBatch::new(); + close.extend( + mem::take(&mut self.new_bindings) + .into_iter() + .map(|(_, iface)| iface), + ); + close + } + + fn abandon_uncommitted_entry( + network: Arc, + permit: BindEntryPermit, + mut close: CloseBatch, + ) { + let remove_entry = { + let mut state = permit.entry.lock_state(); + let remove_entry = state.refcount == 0 && state.bound.is_empty(); + if remove_entry { + state.closing = true; + } + remove_entry + }; + + if close.is_empty() { + if remove_entry { + network.remove_entry_if_current(&permit.key, &permit.entry); + } + return; + } + + let key = permit.key.clone(); + let entry = permit.entry.clone(); + + // Inherent termination: the spawned task owns the uncommitted + // interfaces and exact entry permit, closes the finite batch, then + // removes the still-unowned registry entry when appropriate. + tokio::spawn( + async move { + let _permit = permit; + close.close_all().await; + if remove_entry { + network.remove_entry_if_current(&key, &entry); + } + } + .in_current_span(), + ); + } + + fn plan_current_devices_bind(&self, devices: &'static Devices) -> Vec { + let permit = self.permit.as_ref().expect("pending bind has permit"); + let state = permit.entry.lock_state(); + if state.closing { + return Vec::new(); + } + state.missing_current_device_uris(&permit.entry.pattern, devices) + } + + fn push_new_binding(&mut self, uri: BindUri, iface: BindInterface) { + self.new_bindings.push((uri, iface)); + } + + fn commit_into_handle(mut self, network: Arc) -> BindHandle { + let permit = self.permit.take().expect("pending bind has permit"); + let new_bindings = mem::take(&mut self.new_bindings); + + { + let mut state = permit.entry.lock_state(); + debug_assert!(!state.closing, "permit should not commit a closing entry"); + state.refcount += 1; + let mut close = CloseBatch::new(); + for (uri, iface) in new_bindings { + if let std::collections::hash_map::Entry::Vacant(slot) = state.bound.entry(uri) { + slot.insert(iface); + } else { + close.push(iface); + } + } + if !close.is_empty() { + let _task = close.detach(); + } + } + + BindHandle { + network, + key: permit.key.clone(), + entry: permit.entry.clone(), + released: false, + } + } +} + +impl Drop for PendingBindRegistration { + fn drop(&mut self) { + let close = self.close_new_bindings(); + let Some(permit) = self.permit.take() else { + if !close.is_empty() { + let _task = close.detach(); + } + return; + }; + + Self::abandon_uncommitted_entry(self.network.clone(), permit, close); + } +} + +/// RAII handle returned by [`QuicBindDriver::bind`] or [`Network::bind_with`]. +/// +/// The handle holds a reference-counted spot in the bind registry. +/// When all handles for a pattern are dropped, the bound interfaces +/// are released. +/// +/// Prefer calling [`BindHandle::unbind`] to release synchronously. +/// If the handle is dropped without calling `unbind()`, the release +/// is spawned as a background task. +pub struct BindHandle { + network: Arc, + key: BindRegistryKey, + entry: Arc, + released: bool, +} + +/// IO implementation used by non-packet bind drivers during the current +/// `InterfaceManager` transition. +/// +/// It only carries a [`BindUri`] and does not provide any packet I/O. +#[derive(Debug)] +pub struct NullIo { + bind_uri: BindUri, +} + +impl NullIo { + #[must_use] + pub fn new(bind_uri: BindUri) -> Self { + Self { bind_uri } + } + + fn unsupported() -> io::Error { + io::Error::new( + io::ErrorKind::Unsupported, + "null io does not support packet operations", + ) + } +} + +#[derive(Debug, Default)] +pub struct NullIoFactory; + +impl ProductIO for NullIoFactory { + fn bind(&self, bind_uri: BindUri) -> Box { + Box::new(NullIo::new(bind_uri)) + } +} + +impl crate::dquic::net::IO for NullIo { + fn bind_uri(&self) -> BindUri { + self.bind_uri.clone() + } + + fn bound_addr(&self) -> io::Result { + Err(Self::unsupported()) + } + + fn max_segment_size(&self) -> io::Result { + Err(Self::unsupported()) + } + + fn max_segments(&self) -> io::Result { + Err(Self::unsupported()) + } + + fn poll_send( + &self, + _cx: &mut Context, + _pkts: &[io::IoSlice], + _route: crate::dquic::qbase::net::route::Route, + ) -> Poll> { + Poll::Ready(Err(Self::unsupported())) + } + + fn poll_recv( + &self, + _cx: &mut Context, + _pkts: &mut [bytes::BytesMut], + _route: &mut [crate::dquic::qbase::net::route::Route], + ) -> Poll> { + Poll::Ready(Err(Self::unsupported())) + } + + fn poll_close(&mut self, _cx: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } +} + +// --------------------------------------------------------------------------- +// Builder (bon) +// --------------------------------------------------------------------------- + +#[bon::bon] +impl Network { + #[builder(start_fn(name = builder, vis = "pub"), builder_type(vis = "pub"))] + fn new( + stun_server: Option>, + #[builder(default = Arc::new(InterfaceManager::new()))] iface_manager: Arc< + InterfaceManager, + >, + #[builder(default = Arc::new(DEFAULT_IO_FACTORY))] io_factory: Arc, + #[builder(default = Arc::new(SystemResolver))] stun_resolver: Arc< + dyn Resolve + Send + Sync, + >, + #[builder(default = Arc::new(QuicRouter::new()))] quic_router: Arc, + #[builder(default = Arc::new(Locations::new()))] locations: Arc, + #[builder(default = Devices::global())] devices: &'static Devices, + ) -> Arc { + Self::build_with_quic_driver(devices, |network| { + QuicBindDriver::builder() + .network(network.clone()) + .iface_manager(iface_manager) + .io_factory(io_factory) + .stun_resolver(stun_resolver) + .maybe_stun_server(stun_server) + .quic_router(quic_router) + .locations(locations) + .build() + }) + } +} + +// --------------------------------------------------------------------------- +// Network methods +// --------------------------------------------------------------------------- + +impl Network { + fn build_with_quic_driver( + devices: &'static Devices, + build_quic_driver: impl FnOnce(Weak) -> Arc, + ) -> Arc { + let network = Arc::new_cyclic(|network| { + let quic_driver = build_quic_driver(network.clone()); + Network { + devices, + quic_driver, + bind_registry: Mutex::new(HashMap::new()), + _reconcile: OnceLock::new(), + } + }); + + // Start the background reconcile task that keeps bound interfaces + // in sync with device changes. + let reconcile_net = Arc::downgrade(&network); + let devices = network.devices; + let handle = tokio::spawn( + async move { + Network::run_reconcile(reconcile_net, devices).await; + } + .in_current_span(), + ); + let _ = network._reconcile.set(AbortOnDropHandle::new(handle)); + + network + } + + /// Return the built-in QUIC runtime view for QUIC-specific operations. + #[must_use] + pub fn quic(&self) -> Arc { + self.quic_driver.clone() + } + + async fn acquire_or_insert_entry( + self: &Arc, + key: BindRegistryKey, + driver: Arc, + ) -> BindEntryPermit { + loop { + let entry = { + let mut registry = self.bind_registry.lock().expect("bind_registry poisoned"); + registry + .entry(key.clone()) + .or_insert_with(|| Arc::new(BindsEntry::new(key.clone(), driver.clone()))) + .clone() + }; + + let serial = entry.serial.clone().lock_owned().await; + + let current = { + let registry = self.bind_registry.lock().expect("bind_registry poisoned"); + registry + .get(&key) + .is_some_and(|current| Arc::ptr_eq(current, &entry)) + }; + + if current && !entry.is_closing() { + return BindEntryPermit { + key: key.clone(), + entry, + _serial: serial, + }; + } + } + } + + async fn acquire_existing_entry( + &self, + key: BindRegistryKey, + entry: Arc, + ) -> Option { + let serial = entry.serial.clone().lock_owned().await; + + let current = { + let registry = self.bind_registry.lock().expect("bind_registry poisoned"); + registry + .get(&key) + .is_some_and(|current| Arc::ptr_eq(current, &entry)) + }; + + current.then_some(BindEntryPermit { + key, + entry, + _serial: serial, + }) + } + + fn entries_matching_device(&self, device: &str) -> Vec { + let registry = self.bind_registry.lock().expect("bind_registry poisoned"); + registry + .values() + .filter(|entry| entry.pattern_matches_device(device)) + .map(|entry| BindEntryRef { + key: entry.key.clone(), + entry: entry.clone(), + }) + .collect() + } + + fn remove_entry_if_current(&self, key: &BindRegistryKey, entry: &Arc) { + let mut registry = self.bind_registry.lock().expect("bind_registry poisoned"); + if registry + .get(key) + .is_some_and(|current| Arc::ptr_eq(current, entry)) + { + registry.remove(key); + } + } + + /// Register a [`BindPattern`] through an explicit binding driver. + pub async fn bind_with(self: &Arc, driver: Arc, pattern: BindPattern) -> BindHandle + where + D: BindDriver + 'static, + { + let key = BindRegistryKey { + driver: bind_driver_id(&driver), + pattern, + }; + let driver_erased: Arc = driver; + + let permit = self + .acquire_or_insert_entry(key, driver_erased.clone()) + .await; + let mut pending = permit.begin_bind_registration(self.clone()); + + for uri in pending.plan_current_devices_bind(self.devices) { + let iface = driver_erased.bind(self, uri.clone()).await; + pending.push_new_binding(uri, iface); + } + + pending.commit_into_handle(self.clone()) + } + + /// Release a previously registered [`BindPattern`] and its bound + /// interfaces. + async fn release_exact_entry(self: &Arc, key: BindRegistryKey, entry: Arc) { + let Some(permit) = self + .acquire_existing_entry(key.clone(), entry.clone()) + .await + else { + return; + }; + + permit.release_one_handle(self.clone()).finish().await; + } + + /// Resolve the current address for a device and IP family. + /// + /// Bind drivers use this to derive per-device resources from the same + /// device snapshot that drives [`Network`] reconciliation. + #[must_use] + pub fn resolve_device_addr(&self, device: &str, family: Family) -> Option { + self.devices.resolve(device, family) + } + + /// Return whether `bound_addr` belongs to the currently-default interface + /// represented by `bind_uri`. + /// + /// Stale bind handles can briefly outlive netlink reconciliation after a + /// device is removed. In that case the device is absent from the current + /// snapshot and this method returns `false`. + #[must_use] + pub fn bound_addr_is_on_default_route( + &self, + bind_uri: &BindUri, + bound_addr: SocketAddr, + ) -> bool { + let Some((family, device, _port)) = bind_uri.as_iface_bind_uri() else { + return false; + }; + let bound_family = if bound_addr.is_ipv4() { + Family::V4 + } else { + Family::V6 + }; + if family != bound_family { + return false; + } + + self.devices.get(device).is_some_and(|interface| { + interface.default && interface_contains(&interface, bound_addr.ip()) + }) + } + + /// Return all currently bound interfaces for a specific driver and pattern. + #[must_use] + pub fn get_interfaces_with( + &self, + driver: &Arc, + pattern: &BindPattern, + ) -> Option> + where + D: BindDriver + ?Sized, + { + let key = BindRegistryKey { + driver: bind_driver_id(driver), + pattern: pattern.clone(), + }; + let registry = self.bind_registry.lock().expect("bind_registry poisoned"); + registry + .get(&key) + .map(|entry| entry.lock_state().bound.values().cloned().collect()) + } + + /// Background task that reconciles bound interfaces with device changes. + async fn run_reconcile(network: Weak, devices: &'static Devices) { + let mut monitor = devices.monitor(); + while let Some((_interfaces, event)) = monitor.update().await { + let Some(network) = network.upgrade() else { + break; + }; + tracing::debug!( + ?event, + "network interface change, reconciling affected binds" + ); + network.reconcile_event(event.as_ref()).await; + } + } + + async fn bind_added_device(&self, device: &str) { + if self.devices.get(device).is_none() { + return; + } + + for entry_ref in self.entries_matching_device(device) { + let Some(permit) = self + .acquire_existing_entry(entry_ref.key, entry_ref.entry) + .await + else { + continue; + }; + + let missing = permit.plan_added_device_bind(device); + let mut new_bindings = Vec::with_capacity(missing.len()); + for uri in missing { + if self.devices.get(device).is_none() { + break; + } + let iface = permit.entry.driver.bind(self, uri.clone()).await; + new_bindings.push((uri, iface)); + } + + let mut close = + permit.commit_added_device(self.devices.get(device).is_some(), new_bindings); + close.close_all().await; + + if self.devices.get(device).is_none() { + return; + } + } + } + + async fn remove_device_bindings(&self, device: &str) { + for entry_ref in self.entries_matching_device(device) { + let Some(permit) = self + .acquire_existing_entry(entry_ref.key, entry_ref.entry) + .await + else { + continue; + }; + + let mut close = permit.drain_removed_device(device); + close.close_all().await; + } + } + + async fn rebind_changed_device(&self, device: &str) { + if self.devices.get(device).is_none() { + return; + } + + for entry_ref in self.entries_matching_device(device) { + if self.devices.get(device).is_none() { + return; + } + + let Some(permit) = self + .acquire_existing_entry(entry_ref.key, entry_ref.entry) + .await + else { + continue; + }; + + for iface in permit.targets_for_changed_device(device) { + permit.entry.driver.rebind(self, &iface).await; + } + } + } + + /// Extracted reconcile logic, run per interface change. + async fn reconcile_event(&self, event: &InterfaceEvent) { + let device = event.device(); + + match event { + InterfaceEvent::Added { .. } => { + self.bind_added_device(device).await; + } + InterfaceEvent::Removed { .. } => { + self.remove_device_bindings(device).await; + } + InterfaceEvent::Changed { .. } => { + self.rebind_changed_device(device).await; + } + } + } +} + +/// AuthClient that filters connections based on whether the receiving +/// interface matches any of the server's registered BindPatterns. +struct InterfaceAuthClient { + bind_uri: BindUri, + sni_registry: SniRegistry, +} + +impl AuthClient for InterfaceAuthClient { + fn verify_client_name( + &self, + server_authority: &LocalAuthority, + _client_name: Option<&str>, + ) -> ClientNameVerifyResult { + let sni = server_authority.name(); + let sni_lower = sni.to_ascii_lowercase(); + let entry = self + .sni_registry + .get::(&sni_lower) + .and_then(|item| item.value().upgrade()); + + match entry { + None => ClientNameVerifyResult::SilentRefuse("no server registered for SNI".to_owned()), + Some(entry) if entry.bind.is_empty() => ClientNameVerifyResult::Accept, + Some(entry) if entry.bind.iter().any(|p| p.matches(&self.bind_uri)) => { + ClientNameVerifyResult::Accept + } + Some(_) => ClientNameVerifyResult::SilentRefuse("bind pattern mismatch".to_owned()), + } + } + + fn verify_client_authority( + &self, + _server_authority: &LocalAuthority, + _client_authority: &RemoteAuthority, + ) -> ClientAuthorityVerifyResult { + ClientAuthorityVerifyResult::Accept + } +} + +// --------------------------------------------------------------------------- +// BindHandle +// --------------------------------------------------------------------------- + +impl BindHandle { + /// Release the bind pattern synchronously. + pub async fn unbind(&mut self) { + if !self.released { + self.network + .release_exact_entry(self.key.clone(), self.entry.clone()) + .await; + self.released = true; + } + } +} + +impl Drop for BindHandle { + fn drop(&mut self) { + if !self.released { + let network = self.network.clone(); + let key = self.key.clone(); + let entry = self.entry.clone(); + // Inherent termination: the spawned task exits after the unbind + // completes. If the runtime is shutting down, the unbind is + // best-effort. + tokio::spawn( + async move { + network.release_exact_entry(key, entry).await; + } + .in_current_span(), + ); + } + } +} + +#[cfg(test)] +mod tests { + use std::{ + net::SocketAddr, + str::FromStr, + sync::{ + Mutex as StdMutex, + atomic::{AtomicBool, AtomicUsize, Ordering}, + }, + task::Waker, + time::Duration, + }; + + use dquic::{ + prelude::{IO, handy::NoopTokenRegistry}, + qinterface::device::{Interface, InterfaceEvent}, + }; + use futures::{FutureExt, StreamExt, future::BoxFuture}; + use rustls::ClientConfig as TlsClientConfig; + + use super::*; + use crate::dquic::{binds::BindPattern, identity::Identity}; + + fn make_identity(name: &str) -> Arc { + use dquic::prelude::handy::{ToCertificate, ToPrivateKey}; + use rustls::pki_types::{CertificateDer, PrivateKeyDer}; + + const SERVER_CERT: &[u8] = include_bytes!("../../tests/keychain/localhost/server.cert"); + const SERVER_KEY: &[u8] = include_bytes!("../../tests/keychain/localhost/server.key"); + + let certs: Vec> = SERVER_CERT.to_certificate(); + let key: PrivateKeyDer<'static> = SERVER_KEY.to_private_key(); + + Arc::new(Identity { + name: name.parse().unwrap(), + certs: Arc::new(certs), + key: Arc::new(key), + ocsp: Arc::new(None), + }) + } + + fn make_server_config() -> ServerQuicConfig { + ServerQuicConfig::default() + } + + fn make_local_authority(name: &str) -> LocalAuthority { + let identity = make_identity(name); + let certified_key = + crate::dquic::identity::build_certified_key(&identity).expect("valid certified key"); + LocalAuthority::new(Arc::from(name), certified_key) + } + + fn make_remote_authority(name: &str) -> RemoteAuthority { + let identity = make_identity(name); + RemoteAuthority::new(Arc::from(name), Arc::from(identity.certs.as_slice())) + } + + struct TestNullDriver { + manager: Arc, + } + + impl TestNullDriver { + fn new() -> Self { + Self { + manager: Arc::new(crate::dquic::net::InterfaceManager::new()), + } + } + } + + impl BindDriver for TestNullDriver { + fn bind<'a>(&'a self, _network: &'a Network, uri: BindUri) -> BoxFuture<'a, BindInterface> { + async move { self.manager.bind(uri, Arc::new(NullIo::new)).await }.boxed() + } + } + + fn dummy_interface_named(name: &str) -> Interface { + let mut interface = Interface::dummy(); + interface.name = name.to_owned(); + interface + } + + fn added_event(device: &str) -> InterfaceEvent { + InterfaceEvent::Added { + device: device.to_owned(), + new_interface: dummy_interface_named(device), + } + } + + fn removed_event(device: &str) -> InterfaceEvent { + InterfaceEvent::Removed { + device: device.to_owned(), + old_interface: dummy_interface_named(device), + } + } + + fn changed_event(device: &str) -> InterfaceEvent { + InterfaceEvent::Changed { + device: device.to_owned(), + old_interface: dummy_interface_named(device), + new_interface: dummy_interface_named(device), + } + } + + struct CountingDriver { + manager: Arc, + binds: AtomicUsize, + rebinds: AtomicUsize, + } + + impl CountingDriver { + fn new() -> Self { + Self { + manager: Arc::new(crate::dquic::net::InterfaceManager::new()), + binds: AtomicUsize::new(0), + rebinds: AtomicUsize::new(0), + } + } + + fn bind_count(&self) -> usize { + self.binds.load(Ordering::Relaxed) + } + + fn rebind_count(&self) -> usize { + self.rebinds.load(Ordering::Relaxed) + } + } + + impl BindDriver for CountingDriver { + fn bind<'a>(&'a self, _network: &'a Network, uri: BindUri) -> BoxFuture<'a, BindInterface> { + async move { + self.binds.fetch_add(1, Ordering::Relaxed); + self.manager.bind(uri, Arc::new(NullIoFactory)).await + } + .boxed() + } + + fn rebind<'a>( + &'a self, + _network: &'a Network, + _iface: &'a BindInterface, + ) -> BoxFuture<'a, ()> { + async move { + self.rebinds.fetch_add(1, Ordering::Relaxed); + } + .boxed() + } + } + + #[derive(Debug)] + struct CloseCountingFactory { + closes: Arc, + } + + #[derive(Debug)] + struct CloseCountingIo { + bind_uri: BindUri, + closes: Arc, + } + + impl ProductIO for CloseCountingFactory { + fn bind(&self, bind_uri: BindUri) -> Box { + Box::new(CloseCountingIo { + bind_uri, + closes: self.closes.clone(), + }) + } + } + + impl crate::dquic::net::IO for CloseCountingIo { + fn bind_uri(&self) -> BindUri { + self.bind_uri.clone() + } + + fn bound_addr(&self) -> io::Result { + Err(io::Error::new(io::ErrorKind::Unsupported, "not needed")) + } + + fn max_segment_size(&self) -> io::Result { + Ok(1200) + } + + fn max_segments(&self) -> io::Result { + Ok(1) + } + + fn poll_send( + &self, + _cx: &mut Context, + _pkts: &[io::IoSlice], + _route: crate::dquic::qbase::net::route::Route, + ) -> Poll> { + Poll::Ready(Ok(0)) + } + + fn poll_recv( + &self, + _cx: &mut Context, + _pkts: &mut [bytes::BytesMut], + _route: &mut [crate::dquic::qbase::net::route::Route], + ) -> Poll> { + Poll::Pending + } + + fn poll_close(&mut self, _cx: &mut Context) -> Poll> { + self.closes.fetch_add(1, Ordering::SeqCst); + Poll::Ready(Ok(())) + } + } + + struct CloseCountingDriver { + manager: Arc, + closes: Arc, + } + + impl CloseCountingDriver { + fn new() -> Self { + Self { + manager: Arc::new(crate::dquic::net::InterfaceManager::new()), + closes: Arc::new(AtomicUsize::new(0)), + } + } + + fn close_count(&self) -> usize { + self.closes.load(Ordering::SeqCst) + } + } + + impl BindDriver for CloseCountingDriver { + fn bind<'a>(&'a self, _network: &'a Network, uri: BindUri) -> BoxFuture<'a, BindInterface> { + async move { + self.manager + .bind( + uri, + Arc::new(CloseCountingFactory { + closes: self.closes.clone(), + }), + ) + .await + } + .boxed() + } + } + + struct BlockingCloseControl { + close_started: AtomicBool, + release_close: AtomicBool, + close_waker: StdMutex>, + close_completed: AtomicUsize, + close_started_notify: tokio::sync::Notify, + } + + impl BlockingCloseControl { + fn new() -> Self { + Self { + close_started: AtomicBool::new(false), + release_close: AtomicBool::new(false), + close_waker: StdMutex::new(None), + close_completed: AtomicUsize::new(0), + close_started_notify: tokio::sync::Notify::new(), + } + } + + async fn wait_close_started(&self) { + if self.close_started.load(Ordering::SeqCst) { + return; + } + self.close_started_notify.notified().await; + } + + fn release_close(&self) { + self.release_close.store(true, Ordering::SeqCst); + if let Some(waker) = self + .close_waker + .lock() + .expect("close waker mutex poisoned") + .take() + { + waker.wake(); + } + } + + fn complete_count(&self) -> usize { + self.close_completed.load(Ordering::SeqCst) + } + } + + struct BlockingCloseFactory { + control: Arc, + } + + struct BlockingCloseIo { + bind_uri: BindUri, + control: Arc, + closed: bool, + } + + impl ProductIO for BlockingCloseFactory { + fn bind(&self, bind_uri: BindUri) -> Box { + Box::new(BlockingCloseIo { + bind_uri, + control: self.control.clone(), + closed: false, + }) + } + } + + impl crate::dquic::net::IO for BlockingCloseIo { + fn bind_uri(&self) -> BindUri { + self.bind_uri.clone() + } + + fn bound_addr(&self) -> io::Result { + Err(io::Error::new(io::ErrorKind::Unsupported, "not needed")) + } + + fn max_segment_size(&self) -> io::Result { + Ok(1200) + } + + fn max_segments(&self) -> io::Result { + Ok(1) + } + + fn poll_send( + &self, + _cx: &mut Context, + _pkts: &[io::IoSlice], + _route: crate::dquic::qbase::net::route::Route, + ) -> Poll> { + Poll::Ready(Ok(0)) + } + + fn poll_recv( + &self, + _cx: &mut Context, + _pkts: &mut [bytes::BytesMut], + _route: &mut [crate::dquic::qbase::net::route::Route], + ) -> Poll> { + Poll::Pending + } + + fn poll_close(&mut self, cx: &mut Context) -> Poll> { + if self.closed { + return Poll::Ready(Ok(())); + } + + self.control.close_started.store(true, Ordering::SeqCst); + self.control.close_started_notify.notify_one(); + + if self.control.release_close.load(Ordering::SeqCst) { + self.closed = true; + self.control.close_completed.fetch_add(1, Ordering::SeqCst); + Poll::Ready(Ok(())) + } else { + *self + .control + .close_waker + .lock() + .expect("close waker mutex poisoned") = Some(cx.waker().clone()); + Poll::Pending + } + } + } + + struct BlockingCloseDriver { + manager: Arc, + control: Arc, + } + + impl BlockingCloseDriver { + fn new() -> Self { + Self { + manager: Arc::new(crate::dquic::net::InterfaceManager::new()), + control: Arc::new(BlockingCloseControl::new()), + } + } + + async fn wait_close_started(&self) { + self.control.wait_close_started().await; + } + + fn release_close(&self) { + self.control.release_close(); + } + + fn complete_count(&self) -> usize { + self.control.complete_count() + } + } + + impl BindDriver for BlockingCloseDriver { + fn bind<'a>(&'a self, _network: &'a Network, uri: BindUri) -> BoxFuture<'a, BindInterface> { + async move { + self.manager + .bind( + uri, + Arc::new(BlockingCloseFactory { + control: self.control.clone(), + }), + ) + .await + } + .boxed() + } + } + + struct SlowCountingDriver { + manager: Arc, + active: AtomicUsize, + max_active: AtomicUsize, + binds: AtomicUsize, + } + + impl SlowCountingDriver { + fn new() -> Self { + Self { + manager: Arc::new(crate::dquic::net::InterfaceManager::new()), + active: AtomicUsize::new(0), + max_active: AtomicUsize::new(0), + binds: AtomicUsize::new(0), + } + } + + fn max_active(&self) -> usize { + self.max_active.load(Ordering::SeqCst) + } + + fn bind_count(&self) -> usize { + self.binds.load(Ordering::SeqCst) + } + } + + impl BindDriver for SlowCountingDriver { + fn bind<'a>(&'a self, _network: &'a Network, uri: BindUri) -> BoxFuture<'a, BindInterface> { + async move { + let active = self.active.fetch_add(1, Ordering::SeqCst) + 1; + self.max_active.fetch_max(active, Ordering::SeqCst); + self.binds.fetch_add(1, Ordering::SeqCst); + tokio::time::sleep(Duration::from_millis(50)).await; + let iface = self.manager.bind(uri, Arc::new(NullIoFactory)).await; + self.active.fetch_sub(1, Ordering::SeqCst); + iface + } + .boxed() + } + } + + async fn clear_bound_for_test( + network: &Network, + driver: &Arc, + pattern: &BindPattern, + ) { + let key = BindRegistryKey { + driver: bind_driver_id(driver), + pattern: pattern.clone(), + }; + let mut close = { + let registry = network + .bind_registry + .lock() + .expect("bind_registry poisoned"); + let entry = registry.get(&key).expect("registered pattern"); + let mut state = entry.lock_state(); + let mut close = CloseBatch::new(); + close.extend(state.bound.drain().map(|(_, iface)| iface)); + close + }; + close.close_all().await; + } + + #[derive(Debug)] + struct MarkerResolver; + + impl std::fmt::Display for MarkerResolver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("marker resolver") + } + } + + impl Resolve for MarkerResolver { + fn lookup<'a>(&'a self, _name: &'a str) -> crate::dquic::resolver::ResolveFuture<'a> { + async move { Ok(futures::stream::empty().boxed()) }.boxed() + } + } + + #[tokio::test] + async fn network_builder_accepts_custom_stun_resolver() { + let resolver: Arc = Arc::new(MarkerResolver); + let network = Network::builder().stun_resolver(resolver.clone()).build(); + + assert!(Arc::ptr_eq(&network.quic().stun_resolver, &resolver)); + } + + #[tokio::test] + async fn network_builder_forwards_quic_bind_driver_options() { + let iface_manager = Arc::new(InterfaceManager::new()); + let io_factory: Arc = Arc::new(NullIoFactory); + let stun_resolver: Arc = Arc::new(MarkerResolver); + let quic_router = Arc::new(QuicRouter::new()); + let locations = Arc::new(Locations::new()); + + let network = Network::builder() + .iface_manager(iface_manager.clone()) + .io_factory(io_factory.clone()) + .stun_resolver(stun_resolver.clone()) + .stun_server(Arc::from("builder.stun.example:3478")) + .quic_router(quic_router.clone()) + .locations(locations.clone()) + .build(); + let quic = network.quic(); + + assert!(Arc::ptr_eq(&quic.iface_manager, &iface_manager)); + assert!(Arc::ptr_eq(&quic.io_factory, &io_factory)); + assert!(Arc::ptr_eq(&quic.stun_resolver, &stun_resolver)); + assert_eq!( + quic.stun_server.as_deref(), + Some("builder.stun.example:3478") + ); + assert!(Arc::ptr_eq(&quic.quic_router, &quic_router)); + assert!(Arc::ptr_eq(&quic.locations, &locations)); + } + + #[test] + fn interface_contains_matches_ipv4_and_ipv6_networks() { + let mut interface = ::dquic::qinterface::device::Interface::dummy(); + interface.ipv4.push("192.0.2.10/24".parse().unwrap()); + interface.ipv6.push("2001:db8::10/64".parse().unwrap()); + + assert!(interface_contains( + &interface, + "192.0.2.99".parse().unwrap() + )); + assert!(!interface_contains( + &interface, + "198.51.100.1".parse().unwrap() + )); + assert!(interface_contains( + &interface, + "2001:db8::99".parse().unwrap() + )); + assert!(!interface_contains( + &interface, + "2001:db9::1".parse().unwrap() + )); + } + + #[tokio::test] + async fn default_bind_driver_rebind_is_noop() { + let network = Network::builder().build(); + let driver = Arc::new(TestNullDriver::new()); + let uri: BindUri = "iface://v4.lo:0".parse().expect("valid bind uri"); + let iface = driver.bind(&network, uri).await; + + driver.rebind(&network, &iface).await; + } + + #[tokio::test] + async fn test_bind_server_rejects_different_identity_for_same_sni() { + let network = Network::builder().build(); + let identity_a = make_identity("test.example.com"); + let identity_b = make_identity("test.example.com"); + let config = make_server_config(); + + let binding_a = network + .quic() + .bind_server(identity_a.clone(), config.clone(), Arc::new(Vec::new())) + .await + .expect("first bind should succeed"); + + let err = network + .quic() + .bind_server(identity_b.clone(), config.clone(), Arc::new(Vec::new())) + .await + .expect_err("second bind with different identity should be rejected"); + + assert!(matches!( + err, + BindServerError::SniInUse { ref name } if name == binding_a.name() + )); + let quic = network.quic(); + assert_eq!(quic.sni_registry.len(), 1); + let entry = quic + .sni_registry + .get(binding_a.name()) + .and_then(|kv| kv.value().upgrade()) + .expect("registry should keep the original entry"); + assert!(Arc::ptr_eq(&entry.identity, &identity_a)); + } + + #[tokio::test] + async fn bind_server_different_identity_same_sni_prefers_sni_in_use_over_config_conflict() { + let network = Network::builder().build(); + let identity_a = make_identity("conflict.example.com"); + let identity_b = make_identity("conflict.example.com"); + let cfg_a = make_server_config(); + let cfg_b = ServerQuicConfig { + alpns: vec![b"altproto".to_vec()], + ..Default::default() + }; + + let binding_a = network + .quic() + .bind_server(identity_a, cfg_a, Arc::new(Vec::new())) + .await + .expect("first bind should succeed"); + + let error = network + .quic() + .bind_server(identity_b, cfg_b, Arc::new(Vec::new())) + .await + .expect_err("different identity should fail before config compatibility"); + + assert!( + matches!( + error, + BindServerError::SniInUse { ref name } if name == binding_a.name() + ), + "unexpected error: {error:?}" + ); + } + + #[tokio::test] + async fn test_bind_server_same_identity_reuse() { + let network = Network::builder().build(); + let identity = make_identity("test.example.com"); + let config = make_server_config(); + + let binding_a = network + .quic() + .bind_server(identity.clone(), config.clone(), Arc::new(Vec::new())) + .await + .expect("first bind should succeed"); + + let binding_b = network + .quic() + .bind_server(identity.clone(), config.clone(), Arc::new(Vec::new())) + .await + .expect("second bind with same identity should succeed (reuse)"); + + assert_eq!(binding_a.name(), binding_b.name()); + let quic = network.quic(); + assert_eq!(quic.sni_registry.len(), 1); + + let entry = quic + .sni_registry + .get(binding_a.name()) + .and_then(|kv| kv.value().upgrade()) + .expect("registry should hold the shared entry"); + assert!(Arc::ptr_eq(&entry, &binding_a.entry)); + assert!(Arc::ptr_eq(&entry, &binding_b.entry)); + + drop(binding_a); + let entry_after_drop = quic + .sni_registry + .get(binding_b.name()) + .and_then(|kv| kv.value().upgrade()) + .expect("reused binding should keep SNI registered"); + assert!(Arc::ptr_eq(&entry_after_drop, &binding_b.entry)); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn concurrent_same_identity_bind_server_reuses_one_entry() { + const TASKS: usize = 16; + + let network = Network::builder().build(); + let identity = make_identity("test.example.com"); + let config = make_server_config(); + let barrier = Arc::new(tokio::sync::Barrier::new(TASKS)); + + let mut tasks = Vec::with_capacity(TASKS); + for _ in 0..TASKS { + let quic = network.quic(); + let identity = identity.clone(); + let config = config.clone(); + let barrier = barrier.clone(); + tasks.push(tokio::spawn( + async move { + barrier.wait().await; + quic.bind_server(identity, config, Arc::new(Vec::new())) + .await + .expect("concurrent same-identity bind should succeed") + } + .in_current_span(), + )); + } + + let mut bindings = Vec::with_capacity(TASKS); + for task in tasks { + bindings.push(task.await.expect("bind task should not panic")); + } + + let first = bindings.first().expect("at least one binding"); + assert!( + bindings + .iter() + .all(|binding| Arc::ptr_eq(&binding.entry, &first.entry)), + "all bindings should share the same server entry" + ); + + let quic = network.quic(); + assert_eq!(quic.sni_registry.len(), 1); + let registered = quic + .sni_registry + .get(first.name()) + .and_then(|kv| kv.value().upgrade()) + .expect("registry should hold the shared entry"); + assert!(Arc::ptr_eq(®istered, &first.entry)); + } + + #[tokio::test] + async fn test_network_build() { + let network = Network::builder().build(); + let _cloned = network.clone(); + } + + #[tokio::test] + async fn test_network_locations() { + let network = Network::builder().build(); + let locations = network.quic().locations(); + let _cloned = locations.clone(); + } + + #[tokio::test] + async fn quic_driver_exposes_quic_specific_queries() { + let network = Network::builder().build(); + let quic: Arc = network.quic(); + assert!(Arc::ptr_eq(&quic, &network.quic())); + assert!(Arc::ptr_eq(&quic.locations(), &network.quic().locations())); + + let pattern = BindPattern::from_str("iface://v4.lo:0").expect("valid pattern"); + let mut handle = quic.clone().bind(pattern.clone()).await; + + let quic_interfaces = quic.interfaces(); + + let quic_uris = quic.current_bind_uris(); + + let uri = quic_uris + .first() + .expect("quic bind should expose at least one current bind uri"); + assert!(quic.get_iface(uri).is_some()); + let mut query_changed = uri.clone(); + query_changed.add_prop(BindUri::ALLOC_PORT_ID, "synthetic-test-id"); + assert!( + quic.get_iface(&query_changed).is_none(), + "get_iface must use full BindUri identity, not reconciliation identity" + ); + + assert_eq!( + quic.get_interfaces(&pattern) + .expect("quic interfaces should be registered") + .len(), + quic_interfaces.len() + ); + + handle.unbind().await; + } + + #[tokio::test] + async fn ipv4_and_ipv6_wildcard_binds_can_share_a_port_when_registered_separately() { + let reserve = std::net::UdpSocket::bind("127.0.0.1:0").expect("reserve port"); + let port = reserve.local_addr().expect("reserve addr").port(); + drop(reserve); + + let network = Network::builder().build(); + let quic = network.quic(); + let v6_pattern = + BindPattern::from_str(&format!("inet://[::]:{port}")).expect("valid v6 pattern"); + let v4_pattern = + BindPattern::from_str(&format!("inet://0.0.0.0:{port}")).expect("valid v4 pattern"); + + let mut v6_handle = quic.clone().bind(v6_pattern.clone()).await; + let mut v4_handle = quic.clone().bind(v4_pattern.clone()).await; + + use crate::dquic::net::IO as _; + + let v6_ifaces = quic + .get_interfaces(&v6_pattern) + .expect("v6 bind should stay registered"); + let v4_ifaces = quic + .get_interfaces(&v4_pattern) + .expect("v4 bind should stay registered"); + + assert_eq!(v6_ifaces.len(), 1); + assert_eq!(v4_ifaces.len(), 1); + assert!(matches!( + v6_ifaces[0].borrow().bound_addr().expect("v6 bound addr"), + std::net::SocketAddr::V6(_) + )); + assert!(matches!( + v4_ifaces[0].borrow().bound_addr().expect("v4 bound addr"), + std::net::SocketAddr::V4(_) + )); + + v4_handle.unbind().await; + v6_handle.unbind().await; + } + + #[tokio::test] + async fn network_drop_is_not_prevented_by_background_callbacks() { + let network = Network::builder().build(); + let weak = Arc::downgrade(&network); + + drop(network); + tokio::task::yield_now().await; + + assert!( + weak.upgrade().is_none(), + "network should not be retained by reconcile or router callbacks" + ); + } + + #[tokio::test] + async fn test_network_registered_sni_names_empty() { + let network = Network::builder().build(); + assert!( + network.quic().registered_sni_names().is_empty(), + "fresh network should have no registered sni names" + ); + } + + #[tokio::test] + async fn test_network_registered_sni_names_after_bind() { + let network = Network::builder().build(); + let identity = make_identity("alpha"); + let config = make_server_config(); + + let _binding = network + .quic() + .bind_server(identity, config, Arc::new(Vec::new())) + .await + .expect("bind should succeed"); + + let names = network.quic().registered_sni_names(); + assert_eq!(names.len(), 1); + assert_eq!(names[0].as_str(), "alpha"); + } + + #[tokio::test] + async fn test_network_registered_sni_names_after_drop() { + let network = Network::builder().build(); + let identity = make_identity("alpha"); + let config = make_server_config(); + + let binding = network + .quic() + .bind_server(identity, config, Arc::new(Vec::new())) + .await + .expect("bind should succeed"); + + assert_eq!(network.quic().registered_sni_names().len(), 1); + drop(binding); + assert!( + network.quic().registered_sni_names().is_empty(), + "names should be empty after dropping the last binding" + ); + } + + #[tokio::test] + async fn test_network_bind_server_with_identical_config_arc_reuse() { + let network = Network::builder().build(); + let identity_a = make_identity("a"); + let identity_b = make_identity("b"); + let config = make_server_config(); + + let binding_a = network + .quic() + .bind_server(identity_a.clone(), config.clone(), Arc::new(Vec::new())) + .await + .expect("first bind should succeed"); + + let binding_b = network + .quic() + .bind_server(identity_b.clone(), config.clone(), Arc::new(Vec::new())) + .await + .expect("second bind with same config should succeed"); + + assert_eq!(binding_a.name().as_str(), "a"); + assert_eq!(binding_b.name().as_str(), "b"); + + let names = network.quic().registered_sni_names(); + assert_eq!(names.len(), 2, "both names should be registered"); + assert!(names.iter().any(|n| n.as_str() == "a")); + assert!(names.iter().any(|n| n.as_str() == "b")); + } + + #[test] + fn test_bind_server_error_sni_in_use_display() { + let err = BindServerError::SniInUse { + name: "example.com".parse().unwrap(), + }; + let display = format!("{err}"); + assert!( + display.starts_with("sni "), + "Display should start with lowercase: {display}" + ); + assert!( + display.contains("example.com"), + "Display should contain SNI name: {display}" + ); + assert!( + !display.ends_with('.'), + "Display should not end with period" + ); + } + + #[test] + fn test_bind_server_error_server_config_conflict_display() { + let err = BindServerError::ServerConfigConflict; + let display = format!("{err}"); + assert!( + display.contains("incompatible"), + "Display should contain 'incompatible': {display}" + ); + assert!( + display.starts_with("quic driver"), + "Display should start with lowercase: {display}" + ); + assert!( + !display.ends_with('.'), + "Display should not end with period" + ); + } + + #[test] + fn test_bind_server_error_load_key_variant() { + fn assert_load_key_variant(e: rustls::Error) -> BindServerError { + BindServerError::LoadKey { source: e } + } + let _ = assert_load_key_variant; + } + + #[test] + fn test_bind_server_error_version_variant() { + fn assert_version_variant(e: rustls::Error) -> BindServerError { + BindServerError::Version { source: e } + } + let _ = assert_version_variant; + } + + #[tokio::test] + async fn test_server_binding_name() { + let network = Network::builder().build(); + let identity = make_identity("example.com"); + let config = make_server_config(); + + let binding = network + .quic() + .bind_server(identity, config, Arc::new(Vec::new())) + .await + .expect("bind should succeed"); + + assert_eq!(binding.name().as_str(), "example.com"); + } + + #[tokio::test] + async fn test_server_binding_clone_name() { + let network = Network::builder().build(); + let identity = make_identity("example.com"); + let config = make_server_config(); + + let binding = network + .quic() + .bind_server(identity, config, Arc::new(Vec::new())) + .await + .expect("bind should succeed"); + + let cloned = binding.clone(); + assert_eq!(binding.name(), cloned.name()); + assert_eq!(cloned.name().as_str(), "example.com"); + } + + #[tokio::test] + async fn test_network_configure_connection() { + let network = Network::builder().build(); + let new_builder = || { + let provider = TlsClientConfig::builder().crypto_provider().clone(); + let tls = TlsClientConfig::builder_with_provider(provider) + .with_protocol_versions(&[&rustls::version::TLS13]) + .expect("TLS 1.3 should be supported") + .with_root_certificates(rustls::RootCertStore::empty()) + .with_no_client_auth(); + + Connection::new_client("test.example.com".to_string(), Arc::new(NoopTokenRegistry)) + .with_tls_config(tls) + }; + + let _foundation = network.quic().configure_connection(new_builder()); + let _foundation = network.quic().configure_connection(new_builder()); + } + + #[test] + fn test_bind_pattern_parsing() { + let pattern = BindPattern::from_str("127.0.0.1:8080").expect("valid pattern"); + assert_eq!(pattern.port, Some(8080)); + } + + #[tokio::test] + async fn bind_server_same_identity_rejects_incompatible_server_config() { + let network = Network::builder().build(); + let identity = make_identity("same.example.com"); + let cfg_a = make_server_config(); + let cfg_b = ServerQuicConfig { + alpns: vec![b"altproto".to_vec()], + ..Default::default() + }; + + let _held = network + .quic() + .bind_server(identity.clone(), cfg_a, Arc::new(Vec::new())) + .await + .expect("first bind succeeds"); + + let error = network + .quic() + .bind_server(identity, cfg_b, Arc::new(Vec::new())) + .await + .expect_err("same identity with incompatible config must fail"); + + assert!(matches!(error, BindServerError::ServerConfigConflict)); + } + + #[tokio::test] + async fn test_bind_server_config_conflict() { + let network = Network::builder().build(); + let a = make_identity("alpha"); + let b = make_identity("beta"); + + let cfg_a = make_server_config(); + let cfg_b = { + ServerQuicConfig { + alpns: vec![b"altproto".to_vec()], + ..Default::default() + } + }; + + let _held = network + .quic() + .bind_server(a, cfg_a, Arc::new(Vec::new())) + .await + .expect("first bind succeeds"); + let err = network + .quic() + .bind_server(b, cfg_b, Arc::new(Vec::new())) + .await + .expect_err("incompatible server config must fail"); + assert!( + matches!(err, BindServerError::ServerConfigConflict), + "unexpected error: {err:?}" + ); + } + + #[tokio::test] + async fn test_bind_server_slot_auto_reset() { + let network = Network::builder().build(); + + let cfg_a = make_server_config(); + let cfg_b = { + ServerQuicConfig { + alpns: vec![b"altproto".to_vec()], + ..Default::default() + } + }; + + { + let _first = network + .quic() + .bind_server(make_identity("alpha"), cfg_a, Arc::new(Vec::new())) + .await + .expect("first bind succeeds"); + } + // After the binding drops the slot should clear, allowing a new + // incompatible config to install. + let _second = network + .quic() + .bind_server(make_identity("beta"), cfg_b, Arc::new(Vec::new())) + .await + .expect("slot should auto-reset after last binding dropped"); + } + + #[test] + fn null_io_only_exposes_bind_uri_and_close() { + use std::task::{Context, Poll}; + + use bytes::BytesMut; + use dquic::{qbase::net::route::Route, qinterface::io::IO}; + use futures::task::noop_waker; + + let bind_uri: BindUri = "iface://v4.lo:0".parse().expect("valid bind uri"); + let mut io = NullIo::new(bind_uri.clone()); + assert_eq!(io.bind_uri(), bind_uri); + assert!(io.bound_addr().is_err()); + assert!(io.max_segment_size().is_err()); + assert!(io.max_segments().is_err()); + + let waker = noop_waker(); + let mut cx = Context::from_waker(&waker); + let send = io.poll_send(&mut cx, &[], Route::empty()); + assert!(matches!(send, Poll::Ready(Err(_)))); + + let mut packets = Vec::::new(); + let mut routes = Vec::::new(); + let recv = io.poll_recv(&mut cx, &mut packets, &mut routes); + assert!(matches!(recv, Poll::Ready(Err(_)))); + + assert!(matches!(io.poll_close(&mut cx), Poll::Ready(Ok(())))); + } + + #[tokio::test] + async fn bind_with_keeps_driver_bindings_separate_for_same_pattern() { + let network = Network::builder().build(); + let driver_a = Arc::new(TestNullDriver::new()); + let driver_b = Arc::new(TestNullDriver::new()); + let pattern = BindPattern::from_str("iface://v4.lo:0").expect("valid pattern"); + + let mut handle_a = network.bind_with(driver_a.clone(), pattern.clone()).await; + let mut handle_b = network.bind_with(driver_b.clone(), pattern.clone()).await; + + let a_ifaces = network + .get_interfaces_with(&driver_a, &pattern) + .expect("driver a bindings"); + let b_ifaces = network + .get_interfaces_with(&driver_b, &pattern) + .expect("driver b bindings"); + + assert_eq!(a_ifaces.len(), b_ifaces.len()); + assert!(!a_ifaces.is_empty()); + assert!(!a_ifaces[0].borrow().same_io(&b_ifaces[0].borrow())); + assert!( + network.quic().interfaces().is_empty(), + "non-quic driver binds must not appear in quic interface queries" + ); + assert!( + network.quic().current_bind_uris().is_empty(), + "non-quic driver binds must not appear in quic bind uri queries" + ); + assert!( + network.quic().get_interfaces(&pattern).is_none(), + "non-quic driver binds must not appear in quic pattern queries" + ); + + handle_a.unbind().await; + handle_b.unbind().await; + } + + #[tokio::test] + async fn bind_entry_operations_are_serialized_per_key() { + let network = Network::builder().build(); + let driver = Arc::new(SlowCountingDriver::new()); + let pattern = BindPattern::from_str("iface://v4.lo:0").expect("valid pattern"); + + let first = tokio::spawn({ + let network = network.clone(); + let driver = driver.clone(); + let pattern = pattern.clone(); + async move { network.bind_with(driver, pattern).await }.in_current_span() + }); + let second = tokio::spawn({ + let network = network.clone(); + let driver = driver.clone(); + let pattern = pattern.clone(); + async move { network.bind_with(driver, pattern).await }.in_current_span() + }); + + let mut first = first.await.expect("first bind task should not panic"); + let mut second = second.await.expect("second bind task should not panic"); + + assert_eq!( + driver.max_active(), + 1, + "same entry must not run driver.bind concurrently" + ); + assert_eq!( + driver.bind_count(), + 1, + "second bind should reuse the first entry membership" + ); + + first.unbind().await; + second.unbind().await; + } + + #[tokio::test] + async fn canceled_pending_bind_registration_removes_unowned_entry_after_closing() { + let network = Network::builder().build(); + let driver = Arc::new(CloseCountingDriver::new()); + let pattern = BindPattern::from_str("iface://v4.lo:0").expect("valid pattern"); + let key = BindRegistryKey { + driver: bind_driver_id(&driver), + pattern, + }; + let driver_erased: Arc = driver.clone(); + + let permit = network + .acquire_or_insert_entry(key.clone(), driver_erased.clone()) + .await; + let mut pending = permit.begin_bind_registration(network.clone()); + let uri: BindUri = "iface://v4.lo:0".parse().expect("valid bind uri"); + let iface = driver_erased.bind(&network, uri.clone()).await; + pending.push_new_binding(uri, iface); + + drop(pending); + + tokio::time::timeout(Duration::from_secs(1), async { + while network.get_interfaces_with(&driver, &key.pattern).is_some() { + tokio::task::yield_now().await; + } + }) + .await + .expect("canceled pending registration should remove the unowned registry entry"); + + assert!( + driver.close_count() > 0, + "canceled pending registration should close uncommitted interfaces" + ); + } + + #[test] + fn aborted_detached_close_batch_does_not_respawn_cleanup() { + const CHILD_ENV: &str = "H3X_ABORTED_DETACHED_CLOSE_BATCH_CHILD"; + + if std::env::var_os(CHILD_ENV).is_some() { + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("failed to build tokio runtime"); + + runtime.block_on(async { + let network = Network::builder().build(); + let driver = Arc::new(BlockingCloseDriver::new()); + let uri: BindUri = "iface://v4.lo:0".parse().expect("valid bind uri"); + let iface = driver.bind(&network, uri).await; + + let mut close = CloseBatch::new(); + close.push(iface); + + let task = close.detach(); + driver.wait_close_started().await; + + task.abort(); + let aborted = task + .await + .expect_err("detached close task should be aborted"); + assert!(aborted.is_cancelled()); + }); + + return; + } + + let output = std::process::Command::new(std::env::current_exe().expect("test binary path")) + .arg("--exact") + .arg("dquic::network::tests::aborted_detached_close_batch_does_not_respawn_cleanup") + .env(CHILD_ENV, "1") + .output() + .expect("spawn child test binary"); + + assert_eq!( + output.status.code(), + Some(0), + "child process must exit cleanly, stdout:\n{}\nstderr:\n{}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ); + } + + #[tokio::test] + async fn bind_state_reconciliation_ignores_only_alloc_port_id() { + let network = Network::builder().build(); + let driver = Arc::new(TestNullDriver::new()); + let driver_erased: Arc = driver.clone(); + let pattern = BindPattern::from_str("iface://v4.lo:0").expect("valid pattern"); + let key = BindRegistryKey { + driver: bind_driver_id(&driver), + pattern: pattern.clone(), + }; + let permit = network + .acquire_or_insert_entry(key, driver_erased.clone()) + .await; + + let existing = BindUri::from_str("iface://v4.lo:0") + .expect("valid bind uri") + .alloc_port(); + let iface = driver_erased.bind(&network, existing.clone()).await; + { + let mut state = permit.entry.lock_state(); + state.bound.insert(existing, iface); + } + + assert!( + permit + .entry + .lock_state() + .missing_added_device_uris(&pattern, "lo") + .is_empty(), + "reconciliation should not bind a duplicate when only alloc_port_id differs" + ); + + let stun_pattern = + BindPattern::from_str("iface://v4.lo:0/?stun=true").expect("valid pattern"); + assert!( + !permit + .entry + .lock_state() + .missing_added_device_uris(&stun_pattern, "lo") + .is_empty(), + "semantic query differences must remain missing reconciliation candidates" + ); + + let different_stun: BindUri = "iface://v4.lo:0/?stun=true" + .parse() + .expect("valid bind uri"); + assert!( + permit + .entry + .lock_state() + .bound + .keys() + .all(|uri| !uri.matches_reconcile_candidate(&different_stun)), + "stun query changes remain semantic" + ); + } + + #[tokio::test] + async fn canceled_unbind_completes_cleanup_before_rebind() { + let network = Network::builder().build(); + let driver = Arc::new(BlockingCloseDriver::new()); + let pattern = BindPattern::from_str("iface://v4.lo:0").expect("valid pattern"); + + let mut handle = network.bind_with(driver.clone(), pattern.clone()).await; + let unbind = tokio::spawn({ + async move { + handle.unbind().await; + } + .in_current_span() + }); + + driver.wait_close_started().await; + unbind.abort(); + let aborted = unbind.await.expect_err("unbind task should be aborted"); + assert!(aborted.is_cancelled()); + + driver.release_close(); + + let mut rebound = tokio::time::timeout( + Duration::from_secs(1), + network.bind_with(driver.clone(), pattern), + ) + .await + .expect("canceled unbind cleanup should not strand the bind key"); + + assert!( + driver.complete_count() > 0, + "canceled unbind cleanup should complete the pending close" + ); + + rebound.unbind().await; + } + + #[tokio::test] + async fn release_exact_entry_does_not_remove_replaced_entry() { + let network = Network::builder().build(); + let driver = Arc::new(TestNullDriver::new()); + let pattern = BindPattern::from_str("iface://v4.lo:0").expect("valid pattern"); + + let mut first = network.bind_with(driver.clone(), pattern.clone()).await; + let stale_entry = first.entry.clone(); + let key = first.key.clone(); + first.unbind().await; + + let mut replacement = network.bind_with(driver.clone(), pattern.clone()).await; + network.release_exact_entry(key, stale_entry).await; + + assert!( + network.get_interfaces_with(&driver, &pattern).is_some(), + "stale exact-entry release must not remove replacement entry" + ); + + replacement.unbind().await; + } + + #[tokio::test] + async fn bind_handle_unbind_is_refcounted_and_idempotent() { + let network = Network::builder().build(); + let driver = Arc::new(TestNullDriver::new()); + let pattern = BindPattern::from_str("iface://v4.lo:0").expect("valid pattern"); + + let mut first = network.bind_with(driver.clone(), pattern.clone()).await; + let mut second = network.bind_with(driver.clone(), pattern.clone()).await; + + first.unbind().await; + assert!( + network.get_interfaces_with(&driver, &pattern).is_some(), + "first unbind should only decrement the shared bind refcount" + ); + + first.unbind().await; + assert!( + network.get_interfaces_with(&driver, &pattern).is_some(), + "repeating unbind on the same handle must not release again" + ); + + second.unbind().await; + assert!( + network.get_interfaces_with(&driver, &pattern).is_none(), + "last handle should remove the registry entry" + ); + } + + #[tokio::test] + async fn dropped_bind_handle_releases_binding_in_background() { + let network = Network::builder().build(); + let driver = Arc::new(TestNullDriver::new()); + let pattern = BindPattern::from_str("iface://v4.lo:0").expect("valid pattern"); + + { + let _handle = network.bind_with(driver.clone(), pattern.clone()).await; + assert!(network.get_interfaces_with(&driver, &pattern).is_some()); + } + + tokio::time::timeout(Duration::from_secs(1), async { + while network.get_interfaces_with(&driver, &pattern).is_some() { + tokio::task::yield_now().await; + } + }) + .await + .expect("dropping the last handle should schedule an unbind task"); + } + + #[tokio::test] + async fn interface_auth_client_checks_sni_registry_and_bind_patterns() { + let network = Network::builder().build(); + let quic = network.quic(); + let server_authority = make_local_authority("alpha"); + let client_authority = make_remote_authority("client"); + + let auth = InterfaceAuthClient { + bind_uri: "iface://v4.eth0:443".parse().expect("valid bind uri"), + sni_registry: quic.sni_registry.clone(), + }; + + let empty_bind = quic + .clone() + .bind_server( + make_identity("alpha"), + make_server_config(), + Arc::new(Vec::new()), + ) + .await + .expect("bind should succeed"); + assert_eq!( + auth.verify_client_name(&server_authority, None), + ClientNameVerifyResult::Accept + ); + assert_eq!( + auth.verify_client_authority(&server_authority, &client_authority), + ClientAuthorityVerifyResult::Accept + ); + + drop(empty_bind); + assert!(matches!( + auth.verify_client_name(&server_authority, None), + ClientNameVerifyResult::SilentRefuse(reason) + if reason == "no server registered for SNI" + )); + + let patterns = Arc::new(vec![ + BindPattern::from_str("iface://v4.lo:443").expect("valid pattern"), + ]); + let _scoped_bind = quic + .clone() + .bind_server(make_identity("alpha"), make_server_config(), patterns) + .await + .expect("bind should succeed"); + + assert!(matches!( + auth.verify_client_name(&server_authority, None), + ClientNameVerifyResult::SilentRefuse(reason) if reason == "bind pattern mismatch" + )); + + let matching_auth = InterfaceAuthClient { + bind_uri: "iface://v4.lo:443".parse().expect("valid bind uri"), + sni_registry: quic.sni_registry.clone(), + }; + assert_eq!( + matching_auth.verify_client_name(&server_authority, None), + ClientNameVerifyResult::Accept + ); + } + + #[tokio::test] + async fn network_device_addr_helpers_reject_non_matching_inputs() { + let network = Network::builder().build(); + assert!( + network + .resolve_device_addr("__missing__", Family::V4) + .is_none() + ); + + let inet_uri: BindUri = "inet://127.0.0.1:443".parse().expect("valid bind uri"); + assert!(!network.bound_addr_is_on_default_route( + &inet_uri, + SocketAddr::from(([127, 0, 0, 1], 443)), + )); + + let iface_uri: BindUri = "iface://v4.__missing__:443" + .parse() + .expect("valid bind uri"); + assert!( + !network + .bound_addr_is_on_default_route(&iface_uri, SocketAddr::from(([0, 0, 0, 0], 443)),) + ); + assert!(!network.bound_addr_is_on_default_route( + &iface_uri, + "[::1]:443".parse().expect("valid socket addr"), + )); + } + + #[tokio::test] + async fn quic_bind_driver_initializes_stun_component_branches() { + let manager = Arc::new(InterfaceManager::new()); + let network = Network::build_with_quic_driver(Devices::global(), |network| { + QuicBindDriver::builder() + .network(network) + .iface_manager(manager) + .io_factory(Arc::new(NullIoFactory)) + .stun_server(Arc::from("default.stun.example:3478")) + .build() + }); + let quic = network.quic(); + + let disabled: BindUri = "iface://v4.lo:0/?stun=false" + .parse() + .expect("valid bind uri"); + let _disabled_iface = BindDriver::bind(quic.as_ref(), &network, disabled).await; + + let explicit: BindUri = "iface://v4.lo:1/?stun_server=explicit.stun.example:3478" + .parse() + .expect("valid bind uri"); + let _explicit_iface = BindDriver::bind(quic.as_ref(), &network, explicit).await; + } + + #[tokio::test] + async fn reconcile_event_rebinds_matching_changed_interface_bindings() { + let network = Network::builder().build(); + let driver = Arc::new(CountingDriver::new()); + let pattern = BindPattern::from_str("iface://v4.lo:0").expect("valid pattern"); + + let mut handle = network.bind_with(driver.clone(), pattern).await; + assert!(driver.bind_count() > 0, "initial bind should create iface"); + + network.reconcile_event(&changed_event("lo")).await; + + assert!( + driver.rebind_count() > 0, + "matching changed event should rebind existing lo iface" + ); + + handle.unbind().await; + } + + #[tokio::test] + async fn reconcile_event_skips_unrelated_changed_interface_bindings() { + let network = Network::builder().build(); + let driver = Arc::new(CountingDriver::new()); + let pattern = BindPattern::from_str("iface://v4.lo:0").expect("valid pattern"); + + let mut handle = network.bind_with(driver.clone(), pattern).await; + assert!(driver.bind_count() > 0, "initial bind should create iface"); + + network + .reconcile_event(&changed_event("__unrelated__")) + .await; + + assert_eq!( + driver.rebind_count(), + 0, + "unrelated changed event must not rebind lo iface" + ); + + handle.unbind().await; + } + + #[tokio::test] + async fn reconcile_event_skips_inet_bindings() { + let network = Network::builder().build(); + let driver = Arc::new(CountingDriver::new()); + let pattern = BindPattern::from_str("inet://127.0.0.1:0").expect("valid pattern"); + + let mut handle = network.bind_with(driver.clone(), pattern).await; + assert_eq!( + driver.bind_count(), + 1, + "initial inet bind should create one iface" + ); + + network.reconcile_event(&changed_event("lo")).await; + + assert_eq!( + driver.rebind_count(), + 0, + "interface changed event must not rebind inet iface" + ); + + handle.unbind().await; + } + + #[tokio::test] + async fn reconcile_event_added_binds_missing_matching_membership_without_rebind() { + let network = Network::builder().build(); + let driver = Arc::new(CountingDriver::new()); + let pattern = BindPattern::from_str("iface://v4.lo:0").expect("valid pattern"); + + let mut handle = network.bind_with(driver.clone(), pattern.clone()).await; + let initial_binds = driver.bind_count(); + assert!(initial_binds > 0, "initial bind should create iface"); + + clear_bound_for_test(&network, &driver, &pattern).await; + + network.reconcile_event(&added_event("lo")).await; + + assert!( + driver.bind_count() > initial_binds, + "matching added event should create missing bound membership" + ); + assert_eq!( + driver.rebind_count(), + 0, + "added event must not rebind existing bindings" + ); + assert!( + network + .get_interfaces_with(&driver, &pattern) + .is_some_and(|interfaces| !interfaces.is_empty()), + "added reconcile should repopulate registry membership" + ); + + handle.unbind().await; + } + + #[tokio::test] + async fn reconcile_event_removed_removes_matching_membership_without_rebind() { + let network = Network::builder().build(); + let driver = Arc::new(CountingDriver::new()); + let pattern = BindPattern::from_str("iface://v4.lo:0").expect("valid pattern"); + + let mut handle = network.bind_with(driver.clone(), pattern.clone()).await; + assert!( + network + .get_interfaces_with(&driver, &pattern) + .is_some_and(|interfaces| !interfaces.is_empty()), + "initial bind should create membership" + ); + + network.reconcile_event(&removed_event("lo")).await; + + assert!( + network + .get_interfaces_with(&driver, &pattern) + .is_some_and(|interfaces| interfaces.is_empty()), + "removed event should remove stale bound iface for the removed device" + ); + assert_eq!( + driver.rebind_count(), + 0, + "removed event must not rebind remaining bindings" + ); + + handle.unbind().await; + } + + #[tokio::test] + async fn removed_device_then_release_closes_each_binding_once() { + let network = Network::builder().build(); + let driver = Arc::new(CloseCountingDriver::new()); + let pattern = BindPattern::from_str("iface://v4.lo:0").expect("valid pattern"); + + let mut handle = network.bind_with(driver.clone(), pattern.clone()).await; + let initial = network + .get_interfaces_with(&driver, &pattern) + .expect("registered pattern") + .len(); + assert!(initial > 0, "test expects loopback bindings"); + + network.reconcile_event(&removed_event("lo")).await; + assert_eq!( + driver.close_count(), + initial, + "removed event should close drained loopback bindings" + ); + + handle.unbind().await; + assert_eq!( + driver.close_count(), + initial, + "final release must not double-close already drained bindings" + ); + } + + #[tokio::test] + async fn quic_bind_driver_rebinds_existing_interface_io() { + use std::{ + net::SocketAddr, + sync::atomic::{AtomicUsize, Ordering}, + task::Poll, + }; + + use bytes::BytesMut; + use dquic::qbase::net::route::Route; + + #[derive(Debug)] + struct RebindingIoFactory { + bind_count: AtomicUsize, + } + + #[derive(Debug)] + struct RebindingIo { + bind_uri: BindUri, + addr: SocketAddr, + } + + impl ProductIO for RebindingIoFactory { + fn bind(&self, bind_uri: BindUri) -> Box { + let port = 10_000 + self.bind_count.fetch_add(1, Ordering::SeqCst) as u16; + Box::new(RebindingIo { + bind_uri, + addr: SocketAddr::from(([127, 0, 0, 1], port)), + }) + } + } + + impl crate::dquic::net::IO for RebindingIo { + fn bind_uri(&self) -> BindUri { + self.bind_uri.clone() + } + + fn bound_addr(&self) -> io::Result { + Ok(self.addr) + } + + fn max_segment_size(&self) -> io::Result { + Ok(1200) + } + + fn max_segments(&self) -> io::Result { + Ok(1) + } + + fn poll_send( + &self, + _cx: &mut Context, + _pkts: &[io::IoSlice], + _route: Route, + ) -> Poll> { + Poll::Ready(Ok(0)) + } + + fn poll_recv( + &self, + _cx: &mut Context, + _pkts: &mut [BytesMut], + _route: &mut [Route], + ) -> Poll> { + Poll::Pending + } + + fn poll_close(&mut self, _cx: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + } + + let factory = Arc::new(RebindingIoFactory { + bind_count: AtomicUsize::new(0), + }); + let manager = Arc::new(InterfaceManager::new()); + let network = Network::build_with_quic_driver(Devices::global(), |network| { + QuicBindDriver::builder() + .network(network) + .iface_manager(manager) + .io_factory(factory.clone()) + .build() + }); + let uri: BindUri = "inet://127.0.0.1:0".parse().expect("valid bind uri"); + + let quic = network.quic(); + let iface = BindDriver::bind(quic.as_ref(), &network, uri).await; + let before = iface.borrow().bound_addr().expect("initial bound addr"); + + quic.rebind(&network, &iface).await; + + let after = iface.borrow().bound_addr().expect("rebound addr"); + assert_ne!(before, after, "quic bind driver must replace stale IO"); + assert_eq!(factory.bind_count.load(Ordering::SeqCst), 2); + } +} diff --git a/src/dquic/server.rs b/src/dquic/server.rs index 9791300..3fd5021 100644 --- a/src/dquic/server.rs +++ b/src/dquic/server.rs @@ -1,205 +1,402 @@ -use std::{error::Error, sync::Arc, time::Duration}; - -use ::dquic::{ - builder::QuicListenersBuilder, - prelude::{ - AuthClient, BindUri, Connection, ListenError, ProductStreamsConcurrencyController, - QuicListeners, Resolve, ServerError, handy, - }, - qbase::{param::ServerParameters, token::TokenProvider}, - qevent::telemetry::QLog, - qinterface::{ - component::route::QuicRouter, device::Devices, io::ProductIO, manager::InterfaceManager, - }, -}; -use rustls::{crypto::CryptoProvider, server::danger::ClientCertVerifier}; +//! Server-side QUIC configuration types. +//! +//! [`ServerQuicConfig`] holds all server-side QUIC configuration — both the +//! common (role-independent) and server-specific values — directly, without +//! indirection through sub-configuration types. +//! +//! Trait-object fields (`stream_strategy_factory`, `qlogger`, +//! `token_provider`, `client_auther`, `client_cert_verifier`) are stored as +//! `Arc` so that clones share the same allocation and endpoint +//! clones can reuse cached TLS configs across clones via `Arc::ptr_eq`. +//! +//! All types implement [`Default`] so that server endpoints can be constructed +//! without the caller having to hand-roll configuration values. + +use std::{sync::Arc, time::Duration}; -use crate::{ - connection::ConnectionBuilder, - pool::Pool, - server::{Servers, ServersRouter, UnresolvedRequest}, +use rustls::server::{NoClientAuth, danger::ClientCertVerifier}; + +use crate::dquic::{ + log::{QLog, handy::NoopLogger}, + param::{ServerParameters, handy::server_parameters}, + stream::{ProductStreamsConcurrencyController, handy::ConsistentConcurrency}, + tls::{AuthClient, handy::AcceptAllClientAuther}, + token::{TokenProvider, handy::NoopTokenRegistry}, }; -pub struct H3ServersTlsBuilder { - crypto_provider: Option>, - client_cert_verifier: Option>, -} +// --- legacy ServerSpecificConfig fields (flattened into ServerQuicConfig) --- +// #[derive(Clone)] +// pub struct ServerSpecificConfig { +// pub parameters: ServerParameters, +// pub alpns: Vec>, +// pub token_provider: Arc, +// pub backlog: usize, +// pub client_auther: Arc, +// pub client_cert_verifier: Arc, +// pub anti_port_scan: bool, +// } +// +// impl Default for ServerSpecificConfig { ... } +// impl Debug for ServerSpecificConfig { ... } +// impl PartialEq for ServerSpecificConfig { ... } -pub type H3Servers = Servers, ServersRouter>; +// --------------------------------------------------------------------------- +// ServerQuicConfig — all fields inlined +// --------------------------------------------------------------------------- + +/// Server-side QUIC configuration with all fields inlined. +#[derive(Clone)] +pub struct ServerQuicConfig { + // ---- from CommonQuicConfig ---- + /// How long the connection should keep sending probe packets after going + /// idle. `Duration::ZERO` (the default) disables deferred idle timeouts. + pub defer_idle_timeout: Duration, + /// Factory producing per-connection streams concurrency controllers. + pub stream_strategy_factory: Arc, + /// QUIC-events logger (qlog). Defaults to a no-op logger. + pub qlogger: Arc, + /// Whether 0-RTT should be enabled if the crypto context permits it. + pub enable_0rtt: bool, + /// Enable SSL key logging via `SSLKEYLOGFILE` for debugging captures. + pub enable_sslkeylog: bool, + + // ---- from ServerSpecificConfig ---- + /// Transport parameters advertised by the server. + pub parameters: ServerParameters, + /// ALPN protocol identifiers. Empty means no ALPN. + pub alpns: Vec>, + /// Address validation token provider. + pub token_provider: Arc, + /// Maximum number of pending inbound connections before packets start + /// being dropped at the network level. + pub backlog: usize, + /// Custom client authenticator; runs on top of rustls's certificate + /// verification. Defaults to [`AcceptAllClientAuther`]. + pub client_auther: Arc, + /// How rustls should verify client certificates. Defaults to + /// [`NoClientAuth`]. + pub client_cert_verifier: Arc, + /// When enabled, failed connections are silently dropped instead of + /// answered with an error packet. + pub anti_port_scan: bool, +} -impl H3Servers<()> { - pub fn builder() -> H3ServersTlsBuilder { - H3ServersTlsBuilder { - crypto_provider: None, - client_cert_verifier: None, +impl Default for ServerQuicConfig { + fn default() -> Self { + Self { + defer_idle_timeout: Duration::ZERO, + stream_strategy_factory: Arc::new(ConsistentConcurrency::new), + qlogger: Arc::new(NoopLogger), + enable_0rtt: false, + enable_sslkeylog: false, + parameters: server_parameters(), + alpns: Vec::new(), + token_provider: Arc::new(NoopTokenRegistry), + backlog: 128, + client_auther: Arc::new(AcceptAllClientAuther), + client_cert_verifier: Arc::new(NoClientAuth), + anti_port_scan: false, } } } -impl H3ServersTlsBuilder { - pub fn with_crypto_provider(mut self, crypto_provider: impl Into>) -> Self { - self.crypto_provider = Some(crypto_provider.into()); - self +impl std::fmt::Debug for ServerQuicConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ServerQuicConfig") + .field("defer_idle_timeout", &self.defer_idle_timeout) + .field("enable_0rtt", &self.enable_0rtt) + .field("enable_sslkeylog", &self.enable_sslkeylog) + .field("alpns", &self.alpns.len()) + .field("backlog", &self.backlog) + .field("anti_port_scan", &self.anti_port_scan) + .finish_non_exhaustive() } +} - pub fn with_client_cert_verifier( - mut self, - client_cert_verifier: Arc, - ) -> Result { - self.client_cert_verifier = Some(client_cert_verifier); - self.try_into() +impl PartialEq for ServerQuicConfig { + fn eq(&self, other: &Self) -> bool { + self.defer_idle_timeout == other.defer_idle_timeout + && self.enable_0rtt == other.enable_0rtt + && self.enable_sslkeylog == other.enable_sslkeylog + && Arc::ptr_eq( + &self.stream_strategy_factory, + &other.stream_strategy_factory, + ) + && Arc::ptr_eq(&self.qlogger, &other.qlogger) + && self.parameters == other.parameters + && self.alpns == other.alpns + && Arc::ptr_eq(&self.token_provider, &other.token_provider) + && self.backlog == other.backlog + && Arc::ptr_eq(&self.client_auther, &other.client_auther) + && Arc::ptr_eq(&self.client_cert_verifier, &other.client_cert_verifier) + && self.anti_port_scan == other.anti_port_scan } +} - pub fn without_client_cert_verifier(mut self) -> Result { - self.client_cert_verifier = None; - self.try_into() +impl ServerQuicConfig { + /// Returns `true` when `self` and `other` describe the same server configuration. + /// + /// Compares every field individually. Trait-object fields use + /// [`Arc::ptr_eq`](std::sync::Arc::ptr_eq); plain-value fields use `==`. + pub(crate) fn is_compatible_with(&self, other: &Self) -> bool { + self == other } -} -impl TryFrom for H3ServersBuilder { - type Error = rustls::Error; - - fn try_from(builder: H3ServersTlsBuilder) -> Result { - let listeners_builder = match builder.crypto_provider { - Some(crypto_provider) => QuicListeners::builder_with_crypto_provider(crypto_provider), - None => Ok(QuicListeners::builder()), - }?; - let listeners_builder = match builder.client_cert_verifier { - Some(client_cert_verifier) => { - listeners_builder.with_client_cert_verifier(client_cert_verifier) - } - None => listeners_builder.without_client_cert_verifier(), + /// Build the rustls server config shared across all SNIs registered on a + /// network. The resolver selects a [`CertifiedKey`] based on ClientHello SNI. + pub(crate) fn build_rustls_server_config( + &self, + resolver: crate::dquic::sni::SniCertResolver, + ) -> Result { + use snafu::ResultExt; + + const TLS13: &[&rustls::SupportedProtocolVersion] = &[&rustls::version::TLS13]; + let provider = rustls::ServerConfig::builder().crypto_provider().clone(); + let builder = rustls::ServerConfig::builder_with_provider(provider) + .with_protocol_versions(TLS13) + .context(crate::dquic::network::bind_server_error::VersionSnafu)?; + + let mut tls = builder + .with_client_cert_verifier(self.client_cert_verifier.clone()) + .with_cert_resolver(std::sync::Arc::new(resolver)); + tls.alpn_protocols.clone_from(&self.alpns); + if self.enable_0rtt { + tls.max_early_data_size = 0xffff_ffff; } - .with_alpns(vec!["h3"]); - Ok(H3ServersBuilder { - listeners_builder, - backlog: 1024, - pool: Pool::empty(), - builder: Arc::new(ConnectionBuilder::new(Arc::default())), - }) + Ok(tls) } } -pub struct H3ServersBuilder { - listeners_builder: QuicListenersBuilder, - backlog: usize, - pool: Pool, - builder: Arc>, -} +// --------------------------------------------------------------------------- +// Unit tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod server_tests { + use std::{ + sync::{Arc, Weak}, + time::Duration, + }; -impl H3ServersBuilder { - pub fn with_resolver(mut self, resolver: Arc) -> Self { - self.listeners_builder = self.listeners_builder.with_resolver(resolver); - self + use crate::dquic::{server::*, sni::SniCertResolver}; + + #[test] + fn test_server_quic_config_default() { + let cfg = ServerQuicConfig::default(); + assert_eq!(cfg.defer_idle_timeout, Duration::ZERO); + assert!(!cfg.enable_0rtt); + assert!(!cfg.enable_sslkeylog); + assert!(cfg.alpns.is_empty()); + assert_eq!(cfg.backlog, 128); + assert!(!cfg.anti_port_scan); + assert_eq!(Arc::strong_count(&cfg.stream_strategy_factory), 1); + assert_eq!(Arc::strong_count(&cfg.qlogger), 1); + assert_eq!(Arc::strong_count(&cfg.token_provider), 1); + assert_eq!(Arc::strong_count(&cfg.client_auther), 1); + assert_eq!(Arc::strong_count(&cfg.client_cert_verifier), 1); } - pub fn with_iface_factory(mut self, factory: Arc) -> Self { - self.listeners_builder = self.listeners_builder.with_iface_factory(factory); - self + #[test] + fn test_server_quic_config_partial_eq_same() { + let a = ServerQuicConfig::default(); + let b = a.clone(); + assert_eq!(a, b); } - /// Specify the interfaces manager for the client. - pub fn with_iface_manager(mut self, iface_manager: Arc) -> Self { - self.listeners_builder = self.listeners_builder.with_iface_manager(iface_manager); - self + #[test] + fn test_server_quic_config_partial_eq_different_alpns() { + let a = ServerQuicConfig::default(); + let mut b = a.clone(); + b.alpns = vec![b"h3".to_vec()]; + assert_ne!(a, b); } - pub fn with_router(mut self, router: Arc) -> Self { - self.listeners_builder = self.listeners_builder.with_router(router); - self + #[test] + fn test_server_quic_config_partial_eq_different_backlog() { + let a = ServerQuicConfig::default(); + let mut b = a.clone(); + b.backlog = 256; + assert_ne!(a, b); } - pub fn with_token_provider(mut self, token_provider: Arc) -> Self { - self.listeners_builder = self.listeners_builder.with_token_provider(token_provider); - self + #[test] + fn test_server_quic_config_partial_eq_different_anti_port_scan() { + let a = ServerQuicConfig::default(); + let mut b = a.clone(); + b.anti_port_scan = true; + assert_ne!(a, b); } - pub fn with_streams_concurrency_strategy( - mut self, - strategy_factory: Arc, - ) -> Self { - self.listeners_builder = self - .listeners_builder - .with_streams_concurrency_strategy(strategy_factory); - self + #[test] + fn test_server_quic_config_debug_reports_public_value_fields() { + let cfg = ServerQuicConfig { + defer_idle_timeout: Duration::from_secs(3), + enable_0rtt: true, + enable_sslkeylog: true, + alpns: vec![b"h3".to_vec(), b"dhttp".to_vec()], + backlog: 7, + anti_port_scan: true, + ..Default::default() + }; + + let debug = format!("{cfg:?}"); + + assert!(debug.contains("defer_idle_timeout: 3s")); + assert!(debug.contains("enable_0rtt: true")); + assert!(debug.contains("enable_sslkeylog: true")); + assert!(debug.contains("alpns: 2")); + assert!(debug.contains("backlog: 7")); + assert!(debug.contains("anti_port_scan: true")); + assert!(debug.contains("..")); } - pub fn defer_idle_timeout(mut self, duration: Duration) -> Self { - self.listeners_builder = self.listeners_builder.defer_idle_timeout(duration); - self + #[test] + fn test_server_quic_config_partial_eq_different_common_values() { + let a = ServerQuicConfig::default(); + + let mut b = a.clone(); + b.defer_idle_timeout = Duration::from_secs(1); + assert_ne!(a, b); + + let mut b = a.clone(); + b.enable_0rtt = true; + assert_ne!(a, b); + + let mut b = a.clone(); + b.enable_sslkeylog = true; + assert_ne!(a, b); } - pub fn with_quic_parameters(mut self, parameters: ServerParameters) -> Self { - self.listeners_builder = self.listeners_builder.with_parameters(parameters); - self + #[test] + fn test_server_quic_config_partial_eq_requires_same_trait_object_arcs() { + let a = ServerQuicConfig::default(); + let b = ServerQuicConfig::default(); + + assert_eq!(a.defer_idle_timeout, b.defer_idle_timeout); + assert_eq!(a.enable_0rtt, b.enable_0rtt); + assert_eq!(a.enable_sslkeylog, b.enable_sslkeylog); + assert_eq!(a.parameters, b.parameters); + assert_eq!(a.alpns, b.alpns); + assert_eq!(a.backlog, b.backlog); + assert_eq!(a.anti_port_scan, b.anti_port_scan); + assert!( + !Arc::ptr_eq(&a.stream_strategy_factory, &b.stream_strategy_factory), + "fresh default configs should not share dynamic strategy factories" + ); + assert_ne!(a, b); } - pub fn physical_ifaces(mut self, physical_ifaces: &'static Devices) -> Self { - self.listeners_builder = self.listeners_builder.with_physical_ifaces(physical_ifaces); - self + #[test] + fn test_server_quic_config_clone() { + let a = ServerQuicConfig::default(); + let b = a.clone(); + // Plain-value fields clone independently + assert_eq!(a.defer_idle_timeout, b.defer_idle_timeout); + assert_eq!(a.enable_0rtt, b.enable_0rtt); + assert_eq!(a.enable_sslkeylog, b.enable_sslkeylog); + assert_eq!(a.alpns, b.alpns); + assert_eq!(a.backlog, b.backlog); + assert_eq!(a.anti_port_scan, b.anti_port_scan); + // Trait-object Arcs are shared after clone + assert!(Arc::ptr_eq( + &a.stream_strategy_factory, + &b.stream_strategy_factory, + )); + assert!(Arc::ptr_eq(&a.qlogger, &b.qlogger)); + assert!(Arc::ptr_eq(&a.token_provider, &b.token_provider)); + assert!(Arc::ptr_eq(&a.client_auther, &b.client_auther)); + assert!(Arc::ptr_eq( + &a.client_cert_verifier, + &b.client_cert_verifier, + )); + assert_eq!(a, b); } - pub fn with_qlog(mut self, logger: Arc) -> Self { - self.listeners_builder = self.listeners_builder.with_qlog(logger); - self + #[test] + fn test_server_quic_config_mutate_independent_after_clone() { + let mut a = ServerQuicConfig::default(); + let b = a.clone(); + + a.defer_idle_timeout = Duration::from_secs(42); + a.backlog = 256; + + assert_eq!(b.defer_idle_timeout, Duration::ZERO); + assert_eq!(b.backlog, 128); + assert_eq!(a.defer_idle_timeout, Duration::from_secs(42)); + assert_eq!(a.backlog, 256); } - pub fn enable_anti_port_scan(mut self) -> Self { - self.listeners_builder = self.listeners_builder.enable_anti_port_scan(); - self + #[test] + fn test_server_quic_config_is_compatible_with_same_arc() { + let a = ServerQuicConfig::default(); + let b = a.clone(); + // Trait-object Arcs are shared, so PartialEq (ptr_eq) returns true quickly + assert!(a.is_compatible_with(&b)); } - pub fn with_client_auther(mut self, client_auther: impl AuthClient + 'static) -> Self { - self.listeners_builder = self.listeners_builder.with_client_auther(client_auther); - self + #[test] + fn test_server_quic_config_is_compatible_with_same_values() { + // Mutate then restore values — the fields end up equal. + let mut a = ServerQuicConfig::default(); + let b = a.clone(); + a.defer_idle_timeout = Duration::from_secs(1); + a.defer_idle_timeout = Duration::ZERO; + assert_eq!(a.defer_idle_timeout, b.defer_idle_timeout); + assert!(a.is_compatible_with(&b)); } - pub fn enable_0rtt(mut self) -> Self { - self.listeners_builder = self.listeners_builder.enable_0rtt(); - self + #[test] + fn test_server_quic_config_is_compatible_with_different_alpns() { + let mut a = ServerQuicConfig::default(); + let b = a.clone(); + a.alpns = vec![b"h3".to_vec()]; + assert!(!a.is_compatible_with(&b)); } - pub fn with_backlog(mut self, backlog: usize) -> Self { - self.backlog = backlog; - self + #[test] + fn test_server_quic_config_is_compatible_with_different_backlog() { + let mut a = ServerQuicConfig::default(); + let b = a.clone(); + a.backlog = 256; + assert!(!a.is_compatible_with(&b)); } - pub fn with_builder(mut self, builder: Arc>) -> Self { - self.builder = builder; - self + #[test] + fn test_server_quic_config_is_compatible_with_different_anti_port_scan() { + let mut a = ServerQuicConfig::default(); + let b = a.clone(); + a.anti_port_scan = true; + assert!(!a.is_compatible_with(&b)); } - pub fn listen(self) -> Result, ListenError> { - let listener = self.listeners_builder.listen(self.backlog)?; - Ok(Servers::from_quic_listener() - .listener(listener) - .pool(self.pool) - .service(ServersRouter::new()) - .builder(self.builder) - .build()) + #[test] + fn build_rustls_server_config_copies_alpns_and_enables_0rtt() { + let cfg = ServerQuicConfig { + alpns: vec![b"h3".to_vec(), b"dhttp".to_vec()], + enable_0rtt: true, + ..Default::default() + }; + + let tls = cfg + .build_rustls_server_config(SniCertResolver { + registry: Weak::new(), + }) + .expect("default verifier should produce a rustls config"); + + assert_eq!(tls.alpn_protocols, cfg.alpns); + assert_eq!(tls.max_early_data_size, 0xffff_ffff); } -} -impl Servers, ServersRouter> -where - S: tower_service::Service + Clone + Send + Sync + 'static, - S::Future: Send, - S::Error: Into>, -{ - pub async fn add_server( - &mut self, - server_name: impl Into, - cert_chain: impl handy::ToCertificate, - private_key: impl handy::ToPrivateKey, - ocsp: impl Into>>, - bind_uris: impl IntoIterator>, - router: S, - ) -> Result<&mut Self, ServerError> { - let server_name = server_name.into(); - self.quic_listener() - .add_server(&server_name, cert_chain, private_key, bind_uris, ocsp) - .await?; - self.service_mut().serve(server_name, router); - Ok(self) + #[test] + fn build_rustls_server_config_leaves_0rtt_disabled_by_default() { + let cfg = ServerQuicConfig::default(); + + let tls = cfg + .build_rustls_server_config(SniCertResolver { + registry: Weak::new(), + }) + .expect("default verifier should produce a rustls config"); + + assert_eq!(tls.max_early_data_size, 0); } } diff --git a/src/dquic/shim.rs b/src/dquic/shim.rs index 1278aa3..1fa7d87 100644 --- a/src/dquic/shim.rs +++ b/src/dquic/shim.rs @@ -1,27 +1,21 @@ use std::{ borrow::Cow, pin::Pin, - sync::Arc, + sync::{Arc, OnceLock}, task::{Context, Poll}, }; use bytes::Bytes; +use dashmap::{DashMap, mapref::entry::Entry}; +use dhttp_identity::identity::{self as authority, SignError}; use futures::{Sink, Stream, future::BoxFuture}; -use rustls::{SignatureScheme, pki_types::CertificateDer, sign::CertifiedKey}; - -use crate::{ - error::Code, - quic, - quic::{ - CancelStream, - agent::{self, SignError}, - }, - varint::VarInt, -}; +use rustls::{pki_types::CertificateDer, sign::CertifiedKey}; + +use crate::{error::Code, quic, quic::ResetStream, varint::VarInt}; pub fn convert_varint(varint: dquic::prelude::VarInt) -> VarInt { // dquic's VarInt is already bounds-checked to RFC 9000 spec (< 2^62) - VarInt::from_u64(varint.into_inner()).expect("dquic VarInt is within valid range") + VarInt::from_u64(varint.into_u64()).expect("dquic VarInt is within valid range") } pub fn convert_connection_error(error: dquic::prelude::Error) -> quic::ConnectionError { @@ -50,6 +44,88 @@ pub fn convert_connection_error(error: dquic::prelude::Error) -> quic::Connectio } } +type DquicConnectionId = dquic::qbase::cid::ConnectionId; + +#[derive(Clone)] +struct LatchedConnectionError { + origin_dcid: Option, + error: quic::ConnectionError, +} + +fn dquic_connection_latches() -> &'static DashMap { + static LATCHES: OnceLock> = OnceLock::new(); + LATCHES.get_or_init(DashMap::new) +} + +fn dquic_connection_key(connection: &dquic::prelude::Connection) -> usize { + std::ptr::from_ref(connection).cast::<()>() as usize +} + +fn dquic_origin_dcid(connection: &dquic::prelude::Connection) -> Option { + connection.origin_dcid().ok() +} + +fn is_stale_latch( + stored_origin: Option, + current_origin: Option, +) -> bool { + match (stored_origin, current_origin) { + (Some(stored), Some(current)) => stored != current, + (None, Some(_)) => true, + _ => false, + } +} + +fn latched_connection_error( + connection: &dquic::prelude::Connection, +) -> Option { + let key = dquic_connection_key(connection); + let entry = dquic_connection_latches().get(&key)?; + if is_stale_latch(entry.origin_dcid, dquic_origin_dcid(connection)) { + drop(entry); + dquic_connection_latches().remove(&key); + None + } else { + Some(entry.error.clone()) + } +} + +fn latch_connection_error( + connection: &dquic::prelude::Connection, + origin_dcid: Option, + error: quic::ConnectionError, +) -> quic::ConnectionError { + let key = dquic_connection_key(connection); + match dquic_connection_latches().entry(key) { + Entry::Occupied(mut entry) => { + if is_stale_latch(entry.get().origin_dcid, origin_dcid) { + entry.insert(LatchedConnectionError { + origin_dcid, + error: error.clone(), + }); + error + } else { + entry.get().error.clone() + } + } + Entry::Vacant(entry) => { + entry.insert(LatchedConnectionError { + origin_dcid, + error: error.clone(), + }); + error + } + } +} + +fn convert_and_latch_connection_error( + connection: &dquic::prelude::Connection, + error: dquic::prelude::Error, +) -> quic::ConnectionError { + let origin_dcid = dquic_origin_dcid(connection); + latch_connection_error(connection, origin_dcid, convert_connection_error(error)) +} + pub fn convert_stream_error(error: dquic::prelude::StreamError) -> quic::StreamError { match error { dquic::prelude::StreamError::Connection(error) => { @@ -60,7 +136,19 @@ pub fn convert_stream_error(error: dquic::prelude::StreamError) -> quic::StreamE .expect("QUIC reset error code fits in VarInt range"), }, dquic::prelude::StreamError::EosSent => { - unreachable!("h3x write data after shutdown") + // dquic emits `EosSent` only when the local send side has + // already committed EOS/shutdown and h3x attempts to write more + // data afterward. + // + // h3x intentionally does not model this as `quic::StreamError`: + // it is not a peer `RESET_STREAM`, not a connection failure, and + // not an H3 protocol violation. The message/codec writer layer + // owns this lifecycle invariant and must prevent writes after + // close/reset before they reach the QUIC backend. + // + // Reaching this branch means h3x violated its writer lifecycle + // invariant. + panic!("h3x write data after shutdown") } } } @@ -161,8 +249,8 @@ impl Sink for StreamWriter { } } -impl CancelStream for StreamWriter { - fn poll_cancel( +impl ResetStream for StreamWriter { + fn poll_reset( self: Pin<&mut Self>, _: &mut Context, code: VarInt, @@ -183,13 +271,17 @@ impl quic::ManageStream for dquic::prelude::Connection { let stream = self .open_bi_stream() .await - .map_err(convert_connection_error)?; + .map_err(|error| convert_and_latch_connection_error(self, error))?; let (stream_id, (reader, writer)) = stream.ok_or_else(|| { - quic::ConnectionError::from(quic::TransportError { - kind: VarInt::from_u32(0x04), // STREAM_LIMIT_ERROR - frame_type: VarInt::from_u32(0), - reason: "stream ID space exhausted".into(), - }) + latch_connection_error( + self, + dquic_origin_dcid(self), + quic::ConnectionError::from(quic::TransportError { + kind: VarInt::from_u32(0x04), // STREAM_LIMIT_ERROR + frame_type: VarInt::from_u32(0), + reason: "stream ID space exhausted".into(), + }), + ) })?; let stream_id = convert_varint(stream_id.into()); let reader = StreamReader { stream_id, reader }; @@ -201,13 +293,17 @@ impl quic::ManageStream for dquic::prelude::Connection { let stream = self .open_uni_stream() .await - .map_err(convert_connection_error)?; + .map_err(|error| convert_and_latch_connection_error(self, error))?; let (stream_id, writer) = stream.ok_or_else(|| { - quic::ConnectionError::from(quic::TransportError { - kind: VarInt::from_u32(0x04), // STREAM_LIMIT_ERROR - frame_type: VarInt::from_u32(0), - reason: "stream ID space exhausted".into(), - }) + latch_connection_error( + self, + dquic_origin_dcid(self), + quic::ConnectionError::from(quic::TransportError { + kind: VarInt::from_u32(0x04), // STREAM_LIMIT_ERROR + frame_type: VarInt::from_u32(0), + reason: "stream ID space exhausted".into(), + }), + ) })?; let stream_id = convert_varint(stream_id.into()); Ok(StreamWriter { stream_id, writer }) @@ -219,7 +315,7 @@ impl quic::ManageStream for dquic::prelude::Connection { let (stream_id, (reader, writer)) = self .accept_bi_stream() .await - .map_err(convert_connection_error)?; + .map_err(|error| convert_and_latch_connection_error(self, error))?; let stream_id = convert_varint(stream_id.into()); let reader = StreamReader { stream_id, reader }; let writer = StreamWriter { stream_id, writer }; @@ -230,19 +326,19 @@ impl quic::ManageStream for dquic::prelude::Connection { let (stream_id, reader) = self .accept_uni_stream() .await - .map_err(convert_connection_error)?; + .map_err(|error| convert_and_latch_connection_error(self, error))?; let stream_id = convert_varint(stream_id.into()); Ok(StreamReader { stream_id, reader }) } } #[derive(Debug)] -pub struct DquicLocalAgent { +pub struct DquicLocalAuthority { name: Arc, certified_key: Arc, } -impl agent::LocalAgent for DquicLocalAgent { +impl authority::LocalAuthority for DquicLocalAuthority { fn name(&self) -> &str { &self.name } @@ -250,28 +346,19 @@ impl agent::LocalAgent for DquicLocalAgent { fn cert_chain(&self) -> &[CertificateDer<'static>] { self.certified_key.cert.as_slice() } - - fn sign_algorithm(&self) -> rustls::SignatureAlgorithm { - self.certified_key.key.algorithm() - } - - fn sign( - &self, - scheme: SignatureScheme, - data: &[u8], - ) -> BoxFuture<'_, Result, SignError>> { - let result = agent::sign_with_key(self.certified_key.key.as_ref(), scheme, data); + fn sign(&self, data: &[u8]) -> BoxFuture<'_, Result, SignError>> { + let result = authority::sign_with_key(self.certified_key.key.as_ref(), data); Box::pin(std::future::ready(result)) } } #[derive(Debug)] -pub struct DquicRemoteAgent { +pub struct DquicRemoteAuthority { name: Arc, cert_chain: Arc<[CertificateDer<'static>]>, } -impl agent::RemoteAgent for DquicRemoteAgent { +impl authority::RemoteAuthority for DquicRemoteAuthority { fn name(&self) -> &str { &self.name } @@ -281,15 +368,18 @@ impl agent::RemoteAgent for DquicRemoteAgent { } } -impl quic::WithLocalAgent for dquic::prelude::Connection { - type LocalAgent = DquicLocalAgent; +impl quic::WithLocalAuthority for dquic::prelude::Connection { + type LocalAuthority = DquicLocalAuthority; - async fn local_agent(&self) -> Result, quic::ConnectionError> { - let local_agent = self.local_agent().await.map_err(convert_connection_error)?; - Ok(local_agent.map(|local_agent| { - let name = AsRef::>::as_ref(&local_agent).clone(); - let certified_key = AsRef::>::as_ref(&local_agent).clone(); - DquicLocalAgent { + async fn local_authority(&self) -> Result, quic::ConnectionError> { + let authority = self + .local_authority() + .await + .map_err(|error| convert_and_latch_connection_error(self, error))?; + Ok(authority.map(|authority| { + let name = AsRef::>::as_ref(&authority).clone(); + let certified_key = AsRef::>::as_ref(&authority).clone(); + DquicLocalAuthority { name, certified_key, } @@ -297,18 +387,20 @@ impl quic::WithLocalAgent for dquic::prelude::Connection { } } -impl quic::WithRemoteAgent for dquic::prelude::Connection { - type RemoteAgent = DquicRemoteAgent; +impl quic::WithRemoteAuthority for dquic::prelude::Connection { + type RemoteAuthority = DquicRemoteAuthority; - async fn remote_agent(&self) -> Result, quic::ConnectionError> { - let remote_agent = self - .remote_agent() + async fn remote_authority( + &self, + ) -> Result, quic::ConnectionError> { + let authority = self + .remote_authority() .await - .map_err(convert_connection_error)?; - Ok(remote_agent.map(|remote_agent| { - let name = AsRef::>::as_ref(&remote_agent).clone(); - let cert_chain = AsRef::>::as_ref(&remote_agent).clone(); - DquicRemoteAgent { name, cert_chain } + .map_err(|error| convert_and_latch_connection_error(self, error))?; + Ok(authority.map(|authority| { + let name = AsRef::>::as_ref(&authority).clone(); + let cert_chain = AsRef::>::as_ref(&authority).clone(); + DquicRemoteAuthority { name, cert_chain } })) } } @@ -319,11 +411,20 @@ impl quic::Lifecycle for dquic::prelude::Connection { } fn check(&self) -> Result<(), quic::ConnectionError> { - self.validate().map_err(convert_connection_error) + if let Some(error) = latched_connection_error(self) { + return Err(error); + } + self.validate() + .map_err(|error| convert_and_latch_connection_error(self, error)) } async fn closed(&self) -> quic::ConnectionError { - convert_connection_error(dquic::prelude::Connection::terminated(self).await) + if let Some(error) = latched_connection_error(self) { + return error; + } + let origin_dcid = dquic_origin_dcid(self); + let error = convert_connection_error(dquic::prelude::Connection::terminated(self).await); + latch_connection_error(self, origin_dcid, error) } } @@ -343,6 +444,22 @@ impl quic::Listen for dquic::prelude::QuicListeners { } } +impl quic::Listen for &dquic::prelude::QuicListeners { + type Connection = dquic::prelude::Connection; + + type Error = dquic::prelude::ListenersShutdown; + + async fn accept(&mut self) -> Result, Self::Error> { + let (connection, ..) = (*self).accept().await?; + Ok(connection) + } + + async fn shutdown(&self) -> Result<(), Self::Error> { + (*self).shutdown(); + Ok(()) + } +} + impl quic::Listen for Arc { type Connection = dquic::prelude::Connection; @@ -376,3 +493,657 @@ impl quic::Connect for Arc { dquic::prelude::QuicClient::connect(self, &name).await } } + +#[cfg(test)] +mod tests { + use std::{ + borrow::Cow, + fmt, + str::FromStr, + sync::{Arc, Mutex}, + time::Duration, + }; + + use dquic::{ + prelude::{ + IO, + handy::{ToCertificate, ToPrivateKey}, + }, + qbase::{ + error::{AppError, Error as DquicError, ErrorFrameType, ErrorKind, QuicError}, + frame::{FrameType, ResetStreamError}, + varint::VarInt as DquicVarInt, + }, + qrecovery::streams::error::StreamError as DquicStreamError, + }; + use futures::{FutureExt, SinkExt, StreamExt, stream}; + use http::uri::Authority; + use tokio::time; + + use super::*; + use crate::{ + dquic::resolver::{Record, Resolve, ResolveFuture}, + quic::{GetStreamIdExt, Lifecycle, ResetStreamExt, StopStreamExt}, + }; + + const SERVER_CERT: &[u8] = include_bytes!("../../tests/keychain/localhost/server.cert"); + const SERVER_KEY: &[u8] = include_bytes!("../../tests/keychain/localhost/server.key"); + + fn make_identity() -> crate::dquic::identity::Identity { + crate::dquic::identity::Identity::new( + "localhost".parse().expect("valid identity name"), + SERVER_CERT.to_certificate(), + SERVER_KEY.to_private_key(), + ) + } + + const TEST_TIMEOUT: Duration = Duration::from_secs(10); + + async fn with_timeout(future: impl Future) -> T { + time::timeout(TEST_TIMEOUT, future) + .await + .expect("test operation timed out") + } + + async fn make_server_endpoint() -> (crate::dquic::QuicEndpoint, Authority) { + let identity = Arc::new(make_identity()); + let network = crate::dquic::Network::builder().build(); + let endpoint = crate::dquic::QuicEndpoint::builder() + .network(network.clone()) + .identity(identity) + .bind(Arc::new(vec![ + crate::dquic::binds::BindPattern::from_str("127.0.0.1:0") + .expect("valid bind pattern"), + ])) + .build() + .await; + let bind_iface = network + .quic() + .interfaces() + .into_iter() + .next() + .expect("server should bind an interface"); + let port = bind_iface + .borrow() + .bound_addr() + .expect("server interface should have a bound address") + .port(); + let authority = + Authority::from_maybe_shared(format!("localhost:{port}")).expect("valid authority"); + (endpoint, authority) + } + + fn make_raw_listeners() -> Arc { + dquic::prelude::QuicListeners::builder() + .with_router(Arc::new( + dquic::qinterface::component::route::QuicRouter::default(), + )) + .with_locations(Arc::new( + dquic::qinterface::component::location::Locations::new(), + )) + .without_client_cert_verifier() + .listen(1) + .expect("raw listeners start") + } + + struct ConnectedPair { + _client_endpoint: crate::dquic::QuicEndpoint, + _server_endpoint: crate::dquic::QuicEndpoint, + client: Arc, + server: Arc, + } + + async fn connected_pair() -> ConnectedPair { + let client_endpoint = crate::dquic::QuicEndpoint::builder().build().await; + let (server_endpoint, authority) = make_server_endpoint().await; + let (server_connection, client_connection) = with_timeout(async { + tokio::join!( + server_endpoint.accept(), + quic::Connect::connect(&client_endpoint, &authority) + ) + }) + .await; + ConnectedPair { + _client_endpoint: client_endpoint, + _server_endpoint: server_endpoint, + client: client_connection.expect("client should connect"), + server: server_connection.expect("server should accept"), + } + } + + fn close_pair(pair: &ConnectedPair) { + Lifecycle::close(pair.client.as_ref(), Code::H3_NO_ERROR, Cow::Borrowed("")); + Lifecycle::close(pair.server.as_ref(), Code::H3_NO_ERROR, Cow::Borrowed("")); + } + + #[derive(Debug, Default)] + struct RecordingResolver { + names: Mutex>, + } + + impl RecordingResolver { + fn names(&self) -> Vec { + self.names.lock().unwrap().clone() + } + } + + impl fmt::Display for RecordingResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("recording resolver") + } + } + + impl Resolve for RecordingResolver { + fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { + self.names.lock().unwrap().push(name.to_owned()); + async { Ok(stream::empty::().boxed()) }.boxed() + } + } + + fn transport_connection_error(kind: u32) -> quic::ConnectionError { + quic::ConnectionError::from(quic::TransportError { + kind: VarInt::from_u32(kind), + frame_type: VarInt::from_u32(0), + reason: format!("test transport {kind}").into(), + }) + } + + fn assert_transport_error( + error: quic::ConnectionError, + expected_kind: VarInt, + expected_frame_type: VarInt, + expected_reason: &str, + ) { + let quic::ConnectionError::Transport { source } = error else { + panic!("expected transport error"); + }; + assert_eq!(source.kind, expected_kind); + assert_eq!(source.frame_type, expected_frame_type); + assert_eq!(source.reason, Cow::Borrowed(expected_reason)); + } + + #[test] + fn convert_varint_preserves_dquic_value() { + let value = DquicVarInt::from_u64(0x1234_5678).expect("dquic varint"); + + assert_eq!(convert_varint(value), VarInt::from_u32(0x1234_5678)); + } + + #[test] + fn stale_latch_detection_matches_origin_pairs() { + let first = DquicConnectionId::from_slice(b"first"); + let second = DquicConnectionId::from_slice(b"second"); + + assert!(is_stale_latch(Some(first), Some(second))); + assert!(is_stale_latch(None, Some(first))); + assert!(!is_stale_latch(Some(first), Some(first))); + assert!(!is_stale_latch(Some(first), None)); + assert!(!is_stale_latch(None, None)); + } + + #[test] + fn convert_connection_error_preserves_transport_v1_frame_type() { + let error = DquicError::Quic(QuicError::new( + ErrorKind::StreamLimit, + ErrorFrameType::V1(FrameType::StopSending), + "stream limit", + )); + + assert_transport_error( + convert_connection_error(error), + VarInt::from_u32(0x04), + VarInt::from_u32(0x05), + "stream limit", + ); + } + + #[test] + fn convert_connection_error_preserves_transport_extension_frame_type() { + let extension_frame = DquicVarInt::from_u32(0x21); + let error = DquicError::Quic(QuicError::new( + ErrorKind::ProtocolViolation, + ErrorFrameType::Ext(extension_frame), + "extension frame", + )); + + assert_transport_error( + convert_connection_error(error), + VarInt::from_u32(0x0a), + VarInt::from_u32(0x21), + "extension frame", + ); + } + + #[test] + fn convert_connection_error_preserves_application_error() { + let error = DquicError::App(AppError::new( + DquicVarInt::from_u32(Code::H3_REQUEST_CANCELLED.into_inner().into_inner() as u32), + "application close", + )); + + let quic::ConnectionError::Application { source } = convert_connection_error(error) else { + panic!("expected application error"); + }; + assert_eq!(source.code, Code::H3_REQUEST_CANCELLED); + assert_eq!(source.reason, Cow::Borrowed("application close")); + } + + #[test] + fn convert_stream_error_preserves_connection_and_reset_errors() { + let connection = DquicStreamError::Connection(DquicError::Quic( + QuicError::with_default_fty(ErrorKind::Internal, "connection failed"), + )); + let quic::StreamError::Connection { source } = convert_stream_error(connection) else { + panic!("expected connection stream error"); + }; + assert_transport_error( + source, + VarInt::from_u32(0x01), + VarInt::from_u32(0x00), + "connection failed", + ); + + let reset = DquicStreamError::Reset(ResetStreamError::new( + DquicVarInt::from_u32(0x100), + DquicVarInt::from_u32(0), + )); + let quic::StreamError::Reset { code } = convert_stream_error(reset) else { + panic!("expected reset stream error"); + }; + assert_eq!(code, VarInt::from_u32(0x100)); + } + + #[test] + #[should_panic(expected = "h3x write data after shutdown")] + fn convert_stream_error_rejects_eos_sent() { + _ = convert_stream_error(DquicStreamError::EosSent); + } + + #[tokio::test] + async fn connection_agents_are_exposed_from_established_dquic_connection() { + let pair = connected_pair().await; + let client = &pair.client; + let server = &pair.server; + + let client_local = quic::WithLocalAuthority::local_authority(client.as_ref()) + .await + .expect("client local authority lookup"); + assert!( + client_local.is_none(), + "anonymous client has no local authority" + ); + + let client_remote = quic::WithRemoteAuthority::remote_authority(client.as_ref()) + .await + .expect("client remote authority lookup") + .expect("client should observe server identity"); + assert_eq!( + authority::RemoteAuthority::name(&client_remote), + "localhost" + ); + assert_eq!( + authority::RemoteAuthority::cert_chain(&client_remote), + make_identity().cert_chain() + ); + + let server_local = quic::WithLocalAuthority::local_authority(server.as_ref()) + .await + .expect("server local authority lookup") + .expect("server should have local identity"); + assert_eq!(authority::LocalAuthority::name(&server_local), "localhost"); + assert_eq!( + authority::LocalAuthority::cert_chain(&server_local), + make_identity().cert_chain() + ); + + let server_remote = quic::WithRemoteAuthority::remote_authority(server.as_ref()) + .await + .expect("server remote authority lookup"); + assert!( + server_remote.is_none(), + "server should not observe an anonymous client identity" + ); + } + + #[tokio::test] + async fn lifecycle_close_check_and_closed_report_terminal_error() { + let pair = connected_pair().await; + let client = &pair.client; + let server = &pair.server; + Lifecycle::check(client.as_ref()).expect("client initially live"); + Lifecycle::check(server.as_ref()).expect("server initially live"); + + Lifecycle::close( + client.as_ref(), + Code::H3_REQUEST_CANCELLED, + Cow::Borrowed("done"), + ); + + let error = with_timeout(Lifecycle::closed(server.as_ref())).await; + assert!( + error.is_application() || error.is_transport(), + "closed should report a terminal connection error" + ); + let checked = Lifecycle::check(server.as_ref()).expect_err("terminal error is latched"); + assert_eq!(checked.to_string(), error.to_string()); + let repeated = Lifecycle::closed(server.as_ref()) + .now_or_never() + .expect("latched closed should resolve immediately"); + assert_eq!(repeated.to_string(), error.to_string()); + } + + #[tokio::test] + async fn stale_latched_connection_error_is_discarded() { + let pair = connected_pair().await; + let connection = pair.client.as_ref(); + let key = dquic_connection_key(connection); + let stale_origin = Some(DquicConnectionId::from_slice(b"not-current")); + + latch_connection_error(connection, stale_origin, transport_connection_error(0x31)); + + assert!(latched_connection_error(connection).is_none()); + assert!(dquic_connection_latches().get(&key).is_none()); + close_pair(&pair); + } + + #[tokio::test] + async fn connection_error_latch_reuses_current_origin_and_replaces_stale_origin() { + let pair = connected_pair().await; + let connection = pair.client.as_ref(); + let key = dquic_connection_key(connection); + let origin = dquic_origin_dcid(connection); + let stale_origin = Some(DquicConnectionId::from_slice(b"not-current")); + dquic_connection_latches().remove(&key); + + let first = latch_connection_error(connection, origin, transport_connection_error(0x32)); + assert_transport_error( + first, + VarInt::from_u32(0x32), + VarInt::from_u32(0), + "test transport 50", + ); + + let reused = latch_connection_error(connection, origin, transport_connection_error(0x33)); + assert_transport_error( + reused, + VarInt::from_u32(0x32), + VarInt::from_u32(0), + "test transport 50", + ); + + let replaced = + latch_connection_error(connection, stale_origin, transport_connection_error(0x34)); + assert_transport_error( + replaced, + VarInt::from_u32(0x34), + VarInt::from_u32(0), + "test transport 52", + ); + assert!(latched_connection_error(connection).is_none()); + + let converted = convert_and_latch_connection_error( + connection, + DquicError::Quic(QuicError::with_default_fty( + ErrorKind::ProtocolViolation, + "converted failure", + )), + ); + assert_transport_error( + converted, + VarInt::from_u32(0x0a), + VarInt::from_u32(0), + "converted failure", + ); + + dquic_connection_latches().remove(&key); + close_pair(&pair); + } + + #[tokio::test] + async fn listen_impls_for_references_and_arcs_shutdown_accept_queue() { + let listeners = make_raw_listeners(); + + quic::Listen::shutdown(listeners.as_ref()) + .await + .expect("direct listener shutdown"); + + let mut by_ref = listeners.as_ref(); + quic::Listen::shutdown(&by_ref) + .await + .expect("reference listener shutdown"); + assert!(quic::Listen::accept(&mut by_ref).await.is_err()); + + let mut by_arc = listeners.clone(); + quic::Listen::shutdown(&by_arc) + .await + .expect("arc listener shutdown"); + assert!(quic::Listen::accept(&mut by_arc).await.is_err()); + } + + #[tokio::test] + async fn isolated_raw_listener_configs_allow_concurrent_listener_instances() { + let first = make_raw_listeners(); + let second = make_raw_listeners(); + + quic::Listen::shutdown(first.as_ref()) + .await + .expect("first isolated listener shutdown"); + quic::Listen::shutdown(second.as_ref()) + .await + .expect("second isolated listener shutdown"); + } + + #[tokio::test] + async fn connect_impl_formats_authority_with_and_without_port() { + let resolver = Arc::new(RecordingResolver::default()); + let client = Arc::new( + dquic::prelude::QuicClient::builder() + .with_resolver(resolver.clone()) + .without_verifier() + .without_cert() + .build(), + ); + let with_port = "example.test:8443" + .parse::() + .expect("authority with port parses"); + let without_port = "example.test" + .parse::() + .expect("authority without port parses"); + + let first = quic::Connect::connect(&client, &with_port) + .await + .expect("connect with port"); + let second = quic::Connect::connect(&client, &without_port) + .await + .expect("connect without port"); + + assert_eq!( + resolver.names(), + vec!["example.test:8443".to_owned(), "example.test".to_owned()] + ); + assert_eq!(format!("{}", resolver.as_ref()), "recording resolver"); + Lifecycle::close(first.as_ref(), Code::H3_NO_ERROR, Cow::Borrowed("")); + Lifecycle::close(second.as_ref(), Code::H3_NO_ERROR, Cow::Borrowed("")); + } + + #[tokio::test] + async fn dquic_local_authority_exposes_identity_and_signing() { + let identity = make_identity(); + let certified_key = + crate::dquic::identity::build_certified_key(&identity).expect("test key should load"); + let local = DquicLocalAuthority { + name: Arc::from(identity.name.as_str()), + certified_key, + }; + + assert_eq!(authority::LocalAuthority::name(&local), "localhost"); + assert_eq!( + authority::LocalAuthority::cert_chain(&local), + identity.cert_chain() + ); + assert!(format!("{local:?}").contains("DquicLocalAuthority")); + + let signature = authority::LocalAuthority::sign(&local, b"payload") + .await + .expect("signature"); + assert!( + authority::LocalAuthority::verify(&local, b"payload", &signature) + .await + .expect("verification should run") + ); + assert!( + !authority::LocalAuthority::verify(&local, b"wrong payload", &signature) + .await + .expect("verification should run") + ); + } + + #[tokio::test] + async fn dquic_remote_authority_exposes_identity_and_verification() { + let identity = make_identity(); + let certified_key = + crate::dquic::identity::build_certified_key(&identity).expect("test key should load"); + let local = DquicLocalAuthority { + name: Arc::from(identity.name.as_str()), + certified_key, + }; + let remote = DquicRemoteAuthority { + name: Arc::from("peer.localhost"), + cert_chain: Arc::from(identity.cert_chain()), + }; + + assert_eq!(authority::RemoteAuthority::name(&remote), "peer.localhost"); + assert_eq!( + authority::RemoteAuthority::cert_chain(&remote), + identity.cert_chain() + ); + assert_eq!( + authority::RemoteAuthority::public_key(&remote).as_ref(), + authority::LocalAuthority::public_key(&local).as_ref() + ); + assert!(format!("{remote:?}").contains("DquicRemoteAuthority")); + + let signature = authority::LocalAuthority::sign(&local, b"payload") + .await + .expect("signature"); + assert!( + authority::RemoteAuthority::verify(&remote, b"payload", &signature) + .await + .expect("verification should run") + ); + assert!( + !authority::RemoteAuthority::verify(&remote, b"wrong payload", &signature) + .await + .expect("verification should run") + ); + } + + #[tokio::test] + async fn bidirectional_stream_wrappers_expose_ids_and_transfer_bytes() { + let pair = connected_pair().await; + let (mut client_reader, mut client_writer) = + with_timeout(quic::ManageStream::open_bi(pair.client.as_ref())) + .await + .expect("client opens stream"); + + let client_reader_id = client_reader.stream_id().await.expect("client reader id"); + let client_writer_id = client_writer.stream_id().await.expect("client writer id"); + assert_eq!(client_reader_id, client_writer_id); + + client_writer + .send(Bytes::from_static(b"client payload")) + .await + .expect("client sends request bytes"); + + let (mut server_reader, mut server_writer) = + with_timeout(quic::ManageStream::accept_bi(pair.server.as_ref())) + .await + .expect("server accepts stream"); + let server_reader_id = server_reader.stream_id().await.expect("server reader id"); + let server_writer_id = server_writer.stream_id().await.expect("server writer id"); + assert_eq!(server_reader_id, server_writer_id); + assert_eq!(client_reader_id, server_reader_id); + assert_eq!(server_reader.size_hint(), (0, None)); + + let received = with_timeout(server_reader.next()) + .await + .expect("server receives a chunk") + .expect("server chunk is ok"); + assert_eq!(received, Bytes::from_static(b"client payload")); + + server_writer + .send(Bytes::from_static(b"server payload")) + .await + .expect("server sends response bytes"); + let received = with_timeout(client_reader.next()) + .await + .expect("client receives a chunk") + .expect("client chunk is ok"); + assert_eq!(received, Bytes::from_static(b"server payload")); + + with_timeout(client_writer.close()) + .await + .expect("client writer closes"); + with_timeout(server_writer.close()) + .await + .expect("server writer closes"); + close_pair(&pair); + } + + #[tokio::test] + async fn unidirectional_stream_wrappers_expose_ids_and_transfer_bytes() { + let pair = connected_pair().await; + let mut client_writer = with_timeout(quic::ManageStream::open_uni(pair.client.as_ref())) + .await + .expect("client opens unidirectional stream"); + + let client_stream_id = client_writer.stream_id().await.expect("client writer id"); + client_writer + .send(Bytes::from_static(b"one-way payload")) + .await + .expect("client sends unidirectional bytes"); + + let mut server_reader = with_timeout(quic::ManageStream::accept_uni(pair.server.as_ref())) + .await + .expect("server accepts unidirectional stream"); + let server_stream_id = server_reader.stream_id().await.expect("server reader id"); + assert_eq!(client_stream_id, server_stream_id); + + let received = with_timeout(server_reader.next()) + .await + .expect("server receives unidirectional chunk") + .expect("server unidirectional chunk is ok"); + assert_eq!(received, Bytes::from_static(b"one-way payload")); + + with_timeout(client_writer.close()) + .await + .expect("client unidirectional writer closes"); + close_pair(&pair); + } + + #[tokio::test] + async fn stream_stop_and_reset_wrappers_complete() { + let pair = connected_pair().await; + let mut client_writer = with_timeout(quic::ManageStream::open_uni(pair.client.as_ref())) + .await + .expect("client opens stream for stop"); + client_writer + .send(Bytes::from_static(b"stop payload")) + .await + .expect("client sends bytes before stop"); + + let mut server_reader = with_timeout(quic::ManageStream::accept_uni(pair.server.as_ref())) + .await + .expect("server accepts stream for stop"); + + server_reader + .stop(VarInt::from_u32(0x10)) + .await + .expect("stop-sending completes on reader wrapper"); + client_writer + .reset(VarInt::from_u32(0x11)) + .await + .expect("reset completes on writer wrapper"); + close_pair(&pair); + } +} diff --git a/src/dquic/sni.rs b/src/dquic/sni.rs new file mode 100644 index 0000000..bbd1df3 --- /dev/null +++ b/src/dquic/sni.rs @@ -0,0 +1,259 @@ +//! Per-SNI server state used by [`QuicBindDriver`](super::QuicBindDriver) to +//! fan out incoming connections across many +//! [`QuicEndpoint`](super::QuicEndpoint) instances. +//! +//! A [`ServerBinding`] is cheap to clone: each clone shares the same +//! mpmc [`async_channel`] tail, so multiple endpoints that registered the +//! same SNI cooperatively drain inbound connections. Dropping the last +//! strong reference unregisters the SNI entry from the QUIC driver. + +use std::sync::{Arc, Weak}; + +use dashmap::DashMap; +use dhttp_identity::name::Name; +use rustls::{ + server::{ClientHello, ResolvesServerCert}, + sign::CertifiedKey, +}; + +use crate::dquic::{binds::BindPattern, connection::Connection, identity::Identity}; + +/// Per-SNI entry stored behind a `Weak` in the QUIC driver's registry. +/// +/// Holds an mpmc channel so multiple [`ServerBinding`] clones share the +/// same inbound connection queue. +pub(crate) struct ServerEntry { + pub(crate) identity: Arc, + pub(crate) certified_key: Arc, + pub(crate) incomings_tx: async_channel::Sender>, + pub(crate) incomings_rx: async_channel::Receiver>, + /// Shared server-side QUIC/TLS configuration for this entry. + #[allow( + dead_code, + reason = "held to keep the shared server configuration alive for this SNI entry" + )] + pub(crate) config: Arc, + /// Shared guard — dropped with the entry to remove it from `sni_registry`. + #[allow( + dead_code, + reason = "held for RAII SNI unregister when the last entry reference drops" + )] + pub(crate) guard: Arc, + /// Bind patterns associated with this server entry. + pub(crate) bind: Arc>, +} + +/// RAII guard that removes an SNI entry from the registry when the last +/// [`ServerBinding`] referencing it is dropped. +pub(crate) struct RegistryGuard { + pub(crate) name: Name<'static>, + pub(crate) registry: Weak, Weak>>, + pub(crate) self_entry: Weak, +} + +impl Drop for RegistryGuard { + fn drop(&mut self) { + if let Some(registry) = self.registry.upgrade() { + registry.remove_if(&self.name, |_name, entry| { + Weak::ptr_eq(&self.self_entry, entry) + }); + } + } +} + +/// Shared server-side QUIC/TLS context. At most one instance exists per +/// [`QuicBindDriver`](super::QuicBindDriver) at any time; identical instances +/// are shared across all registered SNIs, and conflicting configurations are +/// rejected at `bind_server` time. +pub(crate) struct ServerConfig { + pub(crate) config: crate::dquic::server::ServerQuicConfig, + pub(crate) rustls_config: Arc, +} + +/// Public handle returned by +/// [`QuicBindDriver::bind_server`](super::QuicBindDriver::bind_server). +/// +/// Cloning is cheap and yields a new receiver on the **same** mpmc queue — +/// concurrently calling `recv` from multiple clones fans out inbound +/// connections across the clones without duplicating work. +pub struct ServerBinding { + pub(crate) entry: Arc, +} + +impl Clone for ServerBinding { + fn clone(&self) -> Self { + Self { + entry: self.entry.clone(), + } + } +} + +impl std::fmt::Debug for ServerBinding { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ServerBinding") + .field("name", &self.entry.identity.name) + .finish_non_exhaustive() + } +} + +impl ServerBinding { + /// Return the server name this binding was registered under. + pub fn name(&self) -> &Name<'static> { + &self.entry.identity.name + } + + /// Receive the next accepted connection for this SNI. + /// + /// Returns `None` once the network is shut down or has no remaining + /// senders. + pub async fn recv(&self) -> Option> { + self.entry.incomings_rx.recv().await.ok() + } +} + +/// rustls `ResolvesServerCert` backed by the QUIC driver's SNI registry. +/// +/// SNI names are matched ASCII case-insensitively per RFC 6066 §3. +#[derive(Clone)] +pub(crate) struct SniCertResolver { + pub(crate) registry: Weak, Weak>>, +} + +impl std::fmt::Debug for SniCertResolver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SniCertResolver").finish_non_exhaustive() + } +} + +impl ResolvesServerCert for SniCertResolver { + fn resolve(&self, client_hello: ClientHello<'_>) -> Option> { + let registry = self.registry.upgrade()?; + let sni = client_hello.server_name()?; + let sni_lower = sni.to_ascii_lowercase(); + registry + .get::(&sni_lower) + .and_then(|item| item.value().upgrade()) + .map(|entry| entry.certified_key.clone()) + } +} + +#[cfg(test)] +mod tests { + use dquic::prelude::handy::{ToCertificate, ToPrivateKey}; + use rustls::pki_types::{CertificateDer, PrivateKeyDer}; + + use super::*; + use crate::dquic::{ + identity::{self, Identity}, + server::ServerQuicConfig, + }; + + const SERVER_CERT: &[u8] = include_bytes!("../../tests/keychain/localhost/server.cert"); + const SERVER_KEY: &[u8] = include_bytes!("../../tests/keychain/localhost/server.key"); + + fn make_identity(name: &str) -> Arc { + let certs: Vec> = SERVER_CERT.to_certificate(); + let key: PrivateKeyDer<'static> = SERVER_KEY.to_private_key(); + Arc::new(Identity { + name: name.parse().expect("valid identity name"), + certs: Arc::new(certs), + key: Arc::new(key), + ocsp: Arc::new(None), + }) + } + + fn make_server_entry( + registry: &Arc, Weak>>, + name: &str, + ) -> Arc { + let identity = make_identity(name); + let certified_key = identity::build_certified_key(&identity).expect("test key should load"); + let entry_name = identity.name.clone(); + let (incomings_tx, incomings_rx) = async_channel::bounded(1); + let config = Arc::new(ServerConfig { + config: ServerQuicConfig::default(), + rustls_config: Arc::new( + rustls::ServerConfig::builder() + .with_no_client_auth() + .with_cert_resolver(Arc::new(SniCertResolver { + registry: Weak::new(), + })), + ), + }); + + Arc::new_cyclic(|self_entry| ServerEntry { + identity: identity.clone(), + certified_key, + incomings_tx, + incomings_rx, + config, + guard: Arc::new(RegistryGuard { + name: entry_name.clone(), + registry: Arc::downgrade(registry), + self_entry: self_entry.clone(), + }), + bind: Arc::new(Vec::new()), + }) + } + + #[tokio::test] + async fn server_binding_exposes_name_debug_and_closed_receive() { + let registry = Arc::new(DashMap::new()); + let entry = make_server_entry(®istry, "test.example.com"); + let binding = ServerBinding { + entry: entry.clone(), + }; + let cloned = binding.clone(); + + assert_eq!(binding.name().as_str(), "test.example.com"); + assert!( + format!("{binding:?}").contains("test.example.com"), + "debug output should include binding name", + ); + assert!(Arc::ptr_eq(&binding.entry, &cloned.entry)); + + entry.incomings_tx.close(); + assert!(binding.recv().await.is_none()); + } + + #[test] + fn registry_guard_removes_current_entry_on_last_drop() { + let registry = Arc::new(DashMap::new()); + let name: Name<'static> = "test.example.com".parse().expect("valid name"); + let entry = make_server_entry(®istry, "test.example.com"); + registry.insert(name.clone(), Arc::downgrade(&entry)); + + assert!(registry.get(&name).is_some()); + drop(entry); + assert!(registry.get(&name).is_none()); + } + + #[test] + fn registry_guard_keeps_replaced_entry() { + let registry = Arc::new(DashMap::new()); + let name: Name<'static> = "test.example.com".parse().expect("valid name"); + let first = make_server_entry(®istry, "test.example.com"); + registry.insert(name.clone(), Arc::downgrade(&first)); + let second = make_server_entry(®istry, "test.example.com"); + registry.insert(name.clone(), Arc::downgrade(&second)); + + drop(first); + let current = registry + .get(&name) + .and_then(|entry| entry.value().upgrade()) + .expect("replacement entry should remain"); + assert!(Arc::ptr_eq(¤t, &second)); + + drop(current); + drop(second); + assert!(registry.get(&name).is_none()); + } + + #[test] + fn sni_resolver_debug_is_non_exhaustive() { + let resolver = SniCertResolver { + registry: Weak::new(), + }; + assert_eq!(format!("{resolver:?}"), "SniCertResolver { .. }"); + } +} diff --git a/src/endpoint.rs b/src/endpoint.rs new file mode 100644 index 0000000..7be843e --- /dev/null +++ b/src/endpoint.rs @@ -0,0 +1,2174 @@ +//! Generic HTTP/3 endpoint. +//! +//! [`H3Endpoint`] combines a QUIC transport `Q` with a selected connection +//! type `C`, providing raw HTTP/3 connection pooling and request serving. +//! Client connection access requires `Q: quic::Connect`; +//! server request serving requires `Q: quic::Listen`. +//! +//! [`ConnectionBuilder`]: crate::connection::ConnectionBuilder + +use std::{any::Any, error::Error, sync::Arc}; + +use bon::bon; +use http::uri::Authority; +use snafu::ResultExt; +use tower_service::Service; +use tracing::Instrument; + +use crate::{ + connection::{Connection as H3Connection, ConnectionBuilder, ConnectionState}, + dhttp::message::{MessageReader, MessageWriter}, + pool::{self, Pool}, + quic::{self, GetStreamIdExt}, + stream_id::StreamId, +}; + +#[cfg(feature = "hyper")] +pub mod hyper; + +/// Generic HTTP/3 endpoint parameterized over a QUIC transport `Q` and a +/// connection type `C`. +/// +/// `Q` carries no struct-level constraint — abilities are encoded at the +/// method level: +/// +/// | Capability | Bound | +/// |---|---| +/// | Client connection access (connect) | `Q: quic::Connect` | +/// | Server (listen, listen_owned) | `Q: quic::Listen` | +pub struct H3Endpoint { + pub(crate) quic: Q, + pub(crate) builder: Arc>, + pub(crate) pool: Pool, +} + +/// RAII guard for mutable access to [`H3Endpoint`]'s QUIC transport. +/// +/// On drop, clears the endpoint's connection pool via [`Pool::clear`], +/// ensuring no stale connections remain after QUIC configuration changes. +pub struct QuicMutGuard<'a, Q, C: quic::Connection> { + quic: &'a mut Q, + pool: &'a Pool, +} + +/// A request that has just been accepted on a QUIC stream but whose HTTP/3 +/// header frame has not yet been interpreted by a higher-level HTTP API. +pub struct UnresolvedRequest { + /// QUIC stream identifier for this request. + pub stream_id: StreamId, + /// Incoming request stream. + pub read_stream: MessageReader, + /// Outgoing response stream. + pub write_stream: MessageWriter, + /// Owning H3 connection. + pub connection: Arc>, +} + +#[derive(Debug, snafu::Snafu)] +#[snafu(module)] +pub enum AcceptError { + #[snafu(display("failed to accept QUIC connection"))] + Accept { source: E }, + #[snafu(display("failed to initialize H3 connection"))] + Build { source: quic::ConnectionError }, +} + +impl std::ops::Deref for QuicMutGuard<'_, Q, C> { + type Target = Q; + fn deref(&self) -> &Self::Target { + self.quic + } +} + +impl std::ops::DerefMut for QuicMutGuard<'_, Q, C> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.quic + } +} + +impl Drop for QuicMutGuard<'_, Q, C> { + fn drop(&mut self) { + self.pool.clear(); + } +} + +#[bon] +impl H3Endpoint { + /// Construct a new HTTP/3 endpoint. + #[builder] + pub fn new(quic: Q, #[builder(default)] builder: Arc>) -> Self { + Self { + quic, + builder, + pool: Pool::empty(), + } + } +} + +impl H3Endpoint +where + Q: quic::Connect, +{ + /// Construct a new HTTP/3 endpoint with default pool and builder. + pub fn new(quic: Q) -> Self { + H3Endpoint::builder().quic(quic).build() + } +} + +impl H3Endpoint { + /// Obtain a mutable guard for the QUIC transport. + /// + /// The guard implements [`DerefMut`](std::ops::DerefMut) targeting `Q`. + /// On drop, the connection pool is cleared via [`Pool::clear`]. + pub fn quic_mut(&mut self) -> QuicMutGuard<'_, Q, C> { + QuicMutGuard { + quic: &mut self.quic, + pool: &self.pool, + } + } + + /// Shared reference to the underlying QUIC transport. + pub fn quic(&self) -> &Q { + &self.quic + } + + /// Consume the endpoint and return the underlying QUIC transport. + pub fn into_quic(self) -> Q { + self.quic + } + + /// Clear all cached client connections. + /// + /// This is useful when an external network transition invalidates paths + /// without mutating the QUIC transport configuration itself. + pub fn clear_pool(&self) { + self.pool.clear(); + } + + /// Number of cached client connection entries. + pub fn pool_len(&self) -> usize { + self.pool.len() + } +} + +impl H3Endpoint +where + Q: quic::Connect, +{ + /// Obtain (or reuse) an HTTP/3 connection to `server` from the pool. + pub async fn connect( + &self, + server: Authority, + ) -> Result>, pool::ConnectError> { + self.pool + .reuse_or_connect_with(&self.quic, self.builder.clone(), server) + .await + } +} + +impl H3Endpoint +where + Q: quic::Listen, +{ + /// Accept one QUIC connection and initialize it as an HTTP/3 connection. + pub async fn accept(&mut self) -> Result>, AcceptError> { + let quic_conn = self + .quic + .accept() + .await + .context(accept_error::AcceptSnafu)?; + let h3_conn = self + .builder + .build(quic_conn) + .await + .context(accept_error::BuildSnafu)?; + Ok(Arc::new(h3_conn)) + } + + /// Accept one HTTP/3 connection from a shared endpoint handle. + /// + /// Rust does not support overloading inherent methods by receiver type, so + /// the shared-handle variant uses the same `*_owned` convention as + /// [`H3Endpoint::listen_owned`]. + pub fn accept_owned( + self: &Arc, + ) -> impl Future>, AcceptError>> + use + where + E: Error + Any, + for<'a> &'a Q: quic::Listen, + { + let this = Arc::clone(self); + let builder = this.builder.clone(); + let pool = this.pool.clone(); + async move { + let mut ref_ep = H3Endpoint { + quic: &this.quic, + builder, + pool, + }; + ref_ep.accept().await + } + } +} + +impl quic::Connect for H3Endpoint +where + Q: quic::Connect, + C: quic::Connection, +{ + type Connection = C; + type Error = Q::Error; + + fn connect<'a>( + &'a self, + server: &'a Authority, + ) -> impl Future, Self::Error>> + Send + 'a { + self.quic.connect(server) + } +} + +impl quic::Listen for H3Endpoint +where + Q: quic::Listen, + C: quic::Connection, +{ + type Connection = C; + type Error = Q::Error; + + fn accept( + &mut self, + ) -> impl Future, Self::Error>> + Send + '_ { + self.quic.accept() + } + + fn shutdown(&self) -> impl Future> + Send + '_ { + self.quic.shutdown() + } +} + +impl H3Endpoint +where + Q: quic::Listen, +{ + /// Accept and serve HTTP/3 connections in a loop. + #[doc(alias = "serve")] + pub async fn listen(&mut self, service: S) -> Result<(), ::Error> + where + S: Service + Clone + Send + Sync + 'static, + S::Future: Send, + S::Error: Into>, + { + loop { + let quic_conn = self.quic.accept().await?; + let h3_conn = match self.builder.build(quic_conn).await { + Ok(c) => Arc::new(c), + Err(e) => { + let report = snafu::Report::from_error(&e); + tracing::debug!(error = %report, "failed to build H3 connection"); + continue; + } + }; + // Inherent termination: spawned task exits when accept_raw_message_stream + // returns an error (connection closed) or qpack is no longer available. + tokio::spawn(listen_connection(h3_conn, service.clone()).in_current_span()); + } + } + + /// Listen for HTTP/3 requests on an `Arc>`. + /// + /// The returned future does not capture `&self`, so it can be spawned: + /// + /// ```ignore + /// let h3: Arc> = ...; + /// tokio::spawn(h3.listen(router)); + /// ``` + #[doc(alias = "serve_owned")] + pub fn listen_owned( + self: &Arc, + service: S, + ) -> impl Future::Error>> + use + where + S: Service + Clone + Send + Sync + 'static, + S::Future: Send, + S::Error: Into>, + for<'a> &'a Q: quic::Listen::Error>, + { + let this = Arc::clone(self); + let pool = this.pool.clone(); + let builder = this.builder.clone(); + async move { + let mut ref_ep = H3Endpoint { + quic: &this.quic, + builder, + pool, + }; + ref_ep.listen(service).await + } + } +} + +/// Listen for requests from a single accepted H3 connection. +/// +/// Inherent termination: returns when [`H3Connection::accept_raw_message_stream`] +/// produces an error (connection closed) or when the QPACK module is no longer +/// available. +async fn listen_connection(h3_conn: Arc>, mut service: S) +where + C: quic::Connection, + S: Service + Clone + Send + 'static, + S::Future: Send, + S::Error: Into>, +{ + let erased = Arc::new(h3_conn.erase()); + + loop { + let (mut reader, writer) = match h3_conn.accept_raw_message_stream().await { + Ok(s) => s, + Err(e) => { + let report = snafu::Report::from_error(&e); + tracing::debug!(error = %report, "stopping request handler for connection"); + return; + } + }; + let stream_id = match GetStreamIdExt::stream_id(&mut reader).await { + Ok(id) => id, + Err(e) => { + let report = snafu::Report::from_error(&e); + tracing::debug!(error = %report, "failed to get stream id, skipping request"); + continue; + } + }; + let qpack = match h3_conn.qpack() { + Ok(q) => q, + Err(e) => { + let report = snafu::Report::from_error(&e); + tracing::debug!(error = %report, "qpack unavailable, stopping handler"); + return; + } + }; + let read_stream = + MessageReader::new(stream_id, reader, qpack.decoder.clone(), (*erased).clone()); + let write_stream = MessageWriter::new(writer, qpack.encoder.clone(), (*erased).clone()); + + let request = UnresolvedRequest { + stream_id: StreamId(stream_id), + read_stream, + write_stream, + connection: erased.clone(), + }; + + // Inherent termination: the spawned task exits when the service future resolves. + tokio::spawn(listen_request(&mut service, request).in_current_span()); + } +} + +/// Spawn a task to process a single request through `service`. +/// +/// Inherent termination: the spawned task exits when the service future resolves. +fn listen_request( + service: &mut S, + request: UnresolvedRequest, +) -> impl Future + use +where + S: Service + Send + 'static, + S::Future: Send, + S::Error: Into>, +{ + let fut = service.call(request); + async move { + if let Err(error) = fut.await { + let boxed: Box = error.into(); + let report = snafu::Report::from_error(boxed.as_ref()); + tracing::debug!(error = %report, "request handler returned error"); + } + } +} + +#[cfg(test)] +mod tests { + use std::{ + collections::VecDeque, + fmt, + future::{Ready, pending, ready}, + pin::Pin, + sync::{ + Arc, Mutex, + atomic::{AtomicUsize, Ordering}, + }, + task::{Context, Poll}, + }; + + use bytes::Bytes; + use dhttp_identity::identity; + use futures::{SinkExt, Stream, StreamExt}; + use http::uri::Authority; + use tokio::{ + sync::watch, + time::{Duration, timeout}, + }; + use tower_service::Service; + + use super::*; + use crate::{ + codec::{ + BoxPeekableStreamReader, BoxStreamWriter, PeekableStreamReader, SinkWriter, + StreamReader, + }, + connection::{ + ConnectionState, + tests::{ + MockConnection, TestLocalAuthority, TestReadStream, TestRemoteAuthority, + TestWriteStream, + }, + }, + dhttp::{ + message::guard::{GuardQuicReader, GuardQuicWriter}, + protocol::DHttpProtocol, + }, + pool::ReuseableConnection, + protocol::{Protocol, Protocols, StreamVerdict}, + qpack::protocol::QPackProtocolFactory, + quic::{BoxQuicStreamReader, BoxQuicStreamWriter, GetStreamIdExt, StopStreamExt}, + varint::VarInt, + }; + + /// Minimal quic::Connect implementation for testing QuicMutGuard. + struct MockConnect; + + impl quic::Connect for MockConnect { + type Connection = MockConnection; + type Error = quic::ConnectionError; + + async fn connect<'a>( + &'a self, + _server: &'a Authority, + ) -> Result, Self::Error> { + unreachable!("connect is not called in guard tests") + } + } + + struct MutableTransport { + generation: usize, + } + + #[derive(Debug)] + struct BuildableConnection; + + impl quic::ManageStream for BuildableConnection { + type StreamReader = TestReadStream; + type StreamWriter = TestWriteStream; + + async fn open_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + pending().await + } + + async fn open_uni(&self) -> Result { + Ok(TestWriteStream) + } + + async fn accept_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + pending().await + } + + async fn accept_uni(&self) -> Result { + pending().await + } + } + + impl quic::WithLocalAuthority for BuildableConnection { + type LocalAuthority = TestLocalAuthority; + + async fn local_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(None) + } + } + + impl quic::WithRemoteAuthority for BuildableConnection { + type RemoteAuthority = TestRemoteAuthority; + + async fn remote_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(None) + } + } + + impl quic::Lifecycle for BuildableConnection { + fn close(&self, _code: crate::error::Code, _reason: std::borrow::Cow<'static, str>) {} + + fn check(&self) -> Result<(), quic::ConnectionError> { + Ok(()) + } + + async fn closed(&self) -> quic::ConnectionError { + pending().await + } + } + + #[derive(Debug, Clone)] + struct NamedRemoteAuthority { + name: &'static str, + } + + impl identity::RemoteAuthority for NamedRemoteAuthority { + fn name(&self) -> &str { + self.name + } + + fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { + &[] + } + } + + #[derive(Debug)] + struct IdentifiedConnection { + remote_name: &'static str, + } + + impl IdentifiedConnection { + fn new(remote_name: &'static str) -> Self { + Self { remote_name } + } + } + + impl quic::ManageStream for IdentifiedConnection { + type StreamReader = TestReadStream; + type StreamWriter = TestWriteStream; + + async fn open_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + pending().await + } + + async fn open_uni(&self) -> Result { + Ok(TestWriteStream) + } + + async fn accept_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + pending().await + } + + async fn accept_uni(&self) -> Result { + pending().await + } + } + + impl quic::WithLocalAuthority for IdentifiedConnection { + type LocalAuthority = TestLocalAuthority; + + async fn local_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(None) + } + } + + impl quic::WithRemoteAuthority for IdentifiedConnection { + type RemoteAuthority = NamedRemoteAuthority; + + async fn remote_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(Some(NamedRemoteAuthority { + name: self.remote_name, + })) + } + } + + impl quic::Lifecycle for IdentifiedConnection { + fn close(&self, _code: crate::error::Code, _reason: std::borrow::Cow<'static, str>) {} + + fn check(&self) -> Result<(), quic::ConnectionError> { + Ok(()) + } + + async fn closed(&self) -> quic::ConnectionError { + pending().await + } + } + + fn test_connection_error(reason: &'static str) -> quic::ConnectionError { + quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(0x01), + frame_type: VarInt::from_u32(0x00), + reason: reason.into(), + }, + } + } + + type ConnectionResultQueue = Arc, quic::ConnectionError>>>>; + + #[derive(Clone)] + struct CountingConnect { + calls: Arc, + servers: Arc>>, + result: Arc, quic::ConnectionError>>>, + } + + impl CountingConnect { + fn succeed(connection: C) -> Self { + Self { + calls: Arc::default(), + servers: Arc::default(), + result: Arc::new(Mutex::new(Ok(Arc::new(connection)))), + } + } + + fn fail(error: quic::ConnectionError) -> Self { + Self { + calls: Arc::default(), + servers: Arc::default(), + result: Arc::new(Mutex::new(Err(error))), + } + } + } + + impl quic::Connect for CountingConnect { + type Connection = C; + type Error = quic::ConnectionError; + + async fn connect<'a>( + &'a self, + server: &'a Authority, + ) -> Result, Self::Error> { + self.calls.fetch_add(1, Ordering::Relaxed); + self.servers + .lock() + .expect("server log mutex should not be poisoned") + .push(server.clone()); + self.result + .lock() + .expect("result mutex should not be poisoned") + .clone() + } + } + + #[derive(Clone)] + struct SequencedConnect { + calls: Arc, + results: ConnectionResultQueue, + } + + impl SequencedConnect { + fn new(results: impl IntoIterator, quic::ConnectionError>>) -> Self { + Self { + calls: Arc::default(), + results: Arc::new(Mutex::new(results.into_iter().collect())), + } + } + } + + impl quic::Connect for SequencedConnect { + type Connection = C; + type Error = quic::ConnectionError; + + async fn connect<'a>( + &'a self, + _server: &'a Authority, + ) -> Result, Self::Error> { + self.calls.fetch_add(1, Ordering::Relaxed); + self.results + .lock() + .expect("result queue mutex should not be poisoned") + .pop_front() + .expect("connect result queue should contain an entry") + } + } + + #[derive(Default)] + struct MockListen { + accepted: Arc, + shutdowns: Arc, + } + + impl quic::Listen for MockListen { + type Connection = BuildableConnection; + type Error = quic::ConnectionError; + + async fn accept(&mut self) -> Result, Self::Error> { + self.accepted.fetch_add(1, Ordering::Relaxed); + Ok(Arc::new(BuildableConnection)) + } + + async fn shutdown(&self) -> Result<(), Self::Error> { + self.shutdowns.fetch_add(1, Ordering::Relaxed); + Ok(()) + } + } + + impl quic::Listen for &MockListen { + type Connection = BuildableConnection; + type Error = quic::ConnectionError; + + async fn accept(&mut self) -> Result, Self::Error> { + self.accepted.fetch_add(1, Ordering::Relaxed); + Ok(Arc::new(BuildableConnection)) + } + + async fn shutdown(&self) -> Result<(), Self::Error> { + self.shutdowns.fetch_add(1, Ordering::Relaxed); + Ok(()) + } + } + + struct FailingListen; + + impl quic::Listen for FailingListen { + type Connection = BuildableConnection; + type Error = quic::ConnectionError; + + async fn accept(&mut self) -> Result, Self::Error> { + Err(test_connection_error("accept failed")) + } + + async fn shutdown(&self) -> Result<(), Self::Error> { + Err(test_connection_error("shutdown failed")) + } + } + + impl quic::Listen for &FailingListen { + type Connection = BuildableConnection; + type Error = quic::ConnectionError; + + async fn accept(&mut self) -> Result, Self::Error> { + Err(test_connection_error("shared accept failed")) + } + + async fn shutdown(&self) -> Result<(), Self::Error> { + Err(test_connection_error("shared shutdown failed")) + } + } + + struct UnbuildableListen; + + impl quic::Listen for UnbuildableListen { + type Connection = MockConnection; + type Error = quic::ConnectionError; + + async fn accept(&mut self) -> Result, Self::Error> { + Ok(Arc::new(MockConnection::new())) + } + + async fn shutdown(&self) -> Result<(), Self::Error> { + Ok(()) + } + } + + impl quic::Listen for &UnbuildableListen { + type Connection = MockConnection; + type Error = quic::ConnectionError; + + async fn accept(&mut self) -> Result, Self::Error> { + Ok(Arc::new(MockConnection::new())) + } + + async fn shutdown(&self) -> Result<(), Self::Error> { + Ok(()) + } + } + + #[derive(Clone)] + struct SequencedListen { + accepted: Arc, + results: ConnectionResultQueue, + } + + impl SequencedListen { + fn new(results: impl IntoIterator, quic::ConnectionError>>) -> Self { + Self { + accepted: Arc::default(), + results: Arc::new(Mutex::new(results.into_iter().collect())), + } + } + } + + impl quic::Listen for SequencedListen { + type Connection = C; + type Error = quic::ConnectionError; + + async fn accept(&mut self) -> Result, Self::Error> { + self.accepted.fetch_add(1, Ordering::Relaxed); + self.results + .lock() + .expect("listen result queue mutex should not be poisoned") + .pop_front() + .expect("listen result queue should contain an entry") + } + + async fn shutdown(&self) -> Result<(), Self::Error> { + Ok(()) + } + } + + impl quic::Listen for &SequencedListen { + type Connection = C; + type Error = quic::ConnectionError; + + async fn accept(&mut self) -> Result, Self::Error> { + self.accepted.fetch_add(1, Ordering::Relaxed); + self.results + .lock() + .expect("listen result queue mutex should not be poisoned") + .pop_front() + .expect("listen result queue should contain an entry") + } + + async fn shutdown(&self) -> Result<(), Self::Error> { + Ok(()) + } + } + + #[derive(Clone)] + struct NoopService; + + impl Service for NoopService { + type Response = (); + type Error = quic::ConnectionError; + type Future = Ready>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _request: UnresolvedRequest) -> Self::Future { + ready(Ok(())) + } + } + + #[derive(Debug)] + struct CloseLatch { + closed_tx: watch::Sender, + closed_rx: watch::Receiver, + } + + impl Default for CloseLatch { + fn default() -> Self { + let (closed_tx, closed_rx) = watch::channel(false); + Self { + closed_tx, + closed_rx, + } + } + } + + impl CloseLatch { + fn close(&self) { + let _ = self.closed_tx.send(true); + } + + async fn wait(&self) { + let mut closed_rx = self.closed_rx.clone(); + while !*closed_rx.borrow_and_update() { + if closed_rx.changed().await.is_err() { + break; + } + } + } + } + + #[derive(Debug, Clone)] + struct ControlledConnection { + close_latch: Arc, + close_error: quic::ConnectionError, + } + + impl ControlledConnection { + fn new(reason: &'static str) -> Self { + Self { + close_latch: Arc::new(CloseLatch::default()), + close_error: test_connection_error(reason), + } + } + + fn trigger_close(&self) { + self.close_latch.close(); + } + } + + impl quic::ManageStream for ControlledConnection { + type StreamReader = TestReadStream; + type StreamWriter = TestWriteStream; + + async fn open_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + pending().await + } + + async fn open_uni(&self) -> Result { + Ok(TestWriteStream) + } + + async fn accept_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + pending().await + } + + async fn accept_uni(&self) -> Result { + pending().await + } + } + + impl quic::WithLocalAuthority for ControlledConnection { + type LocalAuthority = TestLocalAuthority; + + async fn local_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(None) + } + } + + impl quic::WithRemoteAuthority for ControlledConnection { + type RemoteAuthority = TestRemoteAuthority; + + async fn remote_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(None) + } + } + + impl quic::Lifecycle for ControlledConnection { + fn close(&self, _code: crate::error::Code, _reason: std::borrow::Cow<'static, str>) { + self.trigger_close(); + } + + fn check(&self) -> Result<(), quic::ConnectionError> { + Ok(()) + } + + async fn closed(&self) -> quic::ConnectionError { + self.close_latch.wait().await; + self.close_error.clone() + } + } + + fn state_without_qpack( + quic: Arc, + ) -> ConnectionState { + let erased: Arc = quic.clone(); + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased)); + ConnectionState::new_for_test(quic, Arc::new(protocols)) + } + + async fn state_with_qpack( + quic: Arc, + ) -> ConnectionState { + let erased: Arc = quic.clone(); + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased)); + let qpack = QPackProtocolFactory::new() + .init(&quic, &protocols) + .await + .expect("qpack protocol should initialize for endpoint tests"); + protocols.insert(qpack); + ConnectionState::new_for_test(quic, Arc::new(protocols)) + } + + async fn http3_request_stream(stream_id: u32) -> (BoxPeekableStreamReader, BoxStreamWriter) { + let stream_id = VarInt::from_u32(stream_id); + let (reader, mut write_side) = quic::test::mock_stream_pair(stream_id); + write_side + .send(Bytes::from_static(&[0x01])) + .await + .expect("write test HEADERS frame type"); + write_side + .close() + .await + .expect("close test request read side"); + + let (_read_side, writer) = quic::test::mock_stream_pair(stream_id); + ( + PeekableStreamReader::new(StreamReader::new(Box::pin(reader) as BoxQuicStreamReader)), + SinkWriter::new(Box::pin(writer) as BoxQuicStreamWriter), + ) + } + + async fn enqueue_http3_request( + state: &ConnectionState, + stream: (BoxPeekableStreamReader, BoxStreamWriter), + ) { + let verdict = Protocol::accept_bi(state.dhttp(), stream) + .await + .expect("test request stream should be classified"); + assert!(matches!(verdict, StreamVerdict::Accepted)); + } + + #[derive(Debug)] + struct StreamIdErrorReadStream { + first_chunk: Option, + close_latch: Option>, + first_stream_id: Option, + } + + impl StreamIdErrorReadStream { + fn new(close_latch: Arc) -> Self { + Self { + first_chunk: Some(Bytes::from_static(&[0x01])), + close_latch: Some(close_latch), + first_stream_id: None, + } + } + + fn fail_after_first_success(stream_id: VarInt) -> Self { + Self { + first_chunk: Some(Bytes::from_static(&[0x01])), + close_latch: None, + first_stream_id: Some(stream_id), + } + } + } + + impl quic::GetStreamId for StreamIdErrorReadStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + let this = self.get_mut(); + if let Some(stream_id) = this.first_stream_id.take() { + return Poll::Ready(Ok(stream_id)); + } + if let Some(close_latch) = &this.close_latch { + close_latch.close(); + } + Poll::Ready(Err(quic::StreamError::Reset { + code: VarInt::from_u32(0x11), + })) + } + } + + impl quic::StopStream for StreamIdErrorReadStream { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context, + _code: VarInt, + ) -> Poll> { + let _ = self; + Poll::Ready(Ok(())) + } + } + + impl Stream for StreamIdErrorReadStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + Poll::Ready(this.first_chunk.take().map(Ok)) + } + } + + fn stream_id_error_request_stream( + stream_id: u32, + close_latch: Arc, + ) -> (BoxPeekableStreamReader, BoxStreamWriter) { + let stream_id = VarInt::from_u32(stream_id); + let reader = PeekableStreamReader::new(StreamReader::new(Box::pin( + StreamIdErrorReadStream::new(close_latch), + ) as BoxQuicStreamReader)); + let (_unused_reader, writer) = quic::test::mock_stream_pair(stream_id); + ( + reader, + SinkWriter::new(Box::pin(writer) as BoxQuicStreamWriter), + ) + } + + fn second_stream_id_error_request_stream( + stream_id: u32, + ) -> (BoxPeekableStreamReader, BoxStreamWriter) { + let stream_id = VarInt::from_u32(stream_id); + let reader = PeekableStreamReader::new(StreamReader::new(Box::pin( + StreamIdErrorReadStream::fail_after_first_success(stream_id), + ) as BoxQuicStreamReader)); + let (_unused_reader, writer) = quic::test::mock_stream_pair(stream_id); + ( + reader, + SinkWriter::new(Box::pin(writer) as BoxQuicStreamWriter), + ) + } + + #[derive(Debug, Clone)] + struct RecordingService { + seen_streams: Arc>>, + close_latch: Arc, + } + + impl Service for RecordingService { + type Response = (); + type Error = quic::ConnectionError; + type Future = Ready>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, request: UnresolvedRequest) -> Self::Future { + self.seen_streams + .lock() + .expect("recording service mutex should not be poisoned") + .push(request.stream_id); + self.close_latch.close(); + ready(Ok(())) + } + } + + #[derive(Debug, Clone)] + struct TestServiceError; + + impl fmt::Display for TestServiceError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("test service failure") + } + } + + impl Error for TestServiceError {} + + #[derive(Debug, Clone)] + struct FailingService { + calls: Arc, + close_latch: Arc, + } + + impl Service for FailingService { + type Response = (); + type Error = TestServiceError; + type Future = Ready>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _request: UnresolvedRequest) -> Self::Future { + self.calls.fetch_add(1, Ordering::Relaxed); + self.close_latch.close(); + ready(Err(TestServiceError)) + } + } + + #[test] + fn h3_endpoint_implements_quic_connect_when_transport_connects() { + fn assert_connect>() {} + + assert_connect::>(); + assert_connect::>>(); + } + + #[test] + fn h3_endpoint_implements_quic_listen_when_transport_listens() { + fn assert_listen>() {} + + assert_listen::>(); + } + + #[test] + fn builder_preserves_custom_builder_and_quic_accessor_returns_transport() { + let listen = MockListen::default(); + let accepted = listen.accepted.clone(); + let builder: Arc> = + Arc::new(ConnectionBuilder::new(Arc::default())); + let endpoint = H3Endpoint::builder() + .quic(listen) + .builder(builder.clone()) + .build(); + + assert!(Arc::ptr_eq(&builder, &endpoint.builder)); + assert!(Arc::ptr_eq(&accepted, &endpoint.quic().accepted)); + assert_eq!(endpoint.pool_len(), 0); + } + + #[test] + fn builder_without_explicit_builder_uses_default_connection_builder() { + let endpoint: H3Endpoint<_, BuildableConnection> = + H3Endpoint::builder().quic(MockListen::default()).build(); + + assert_eq!(*endpoint.builder, ConnectionBuilder::default()); + assert_eq!(endpoint.pool_len(), 0); + } + + #[test] + fn into_quic_returns_owned_transport() { + let listen = MockListen::default(); + let accepted = listen.accepted.clone(); + let endpoint: H3Endpoint<_, BuildableConnection> = + H3Endpoint::builder().quic(listen).build(); + + let listen = endpoint.into_quic(); + + assert!(Arc::ptr_eq(&accepted, &listen.accepted)); + } + + #[test] + fn accept_error_display_describes_current_layer_and_preserves_source() { + let accept = AcceptError::Accept { + source: test_connection_error("accept display source"), + }; + let build = AcceptError::::Build { + source: test_connection_error("build display source"), + }; + + assert_eq!(accept.to_string(), "failed to accept QUIC connection"); + assert!(Error::source(&accept).is_some()); + assert_eq!(build.to_string(), "failed to initialize H3 connection"); + assert!(Error::source(&build).is_some()); + } + + #[tokio::test] + async fn inherent_connect_builds_and_reuses_pooled_h3_connection() { + let connector = CountingConnect::succeed(IdentifiedConnection::new("test-remote")); + let calls = connector.calls.clone(); + let servers = connector.servers.clone(); + let endpoint = H3Endpoint::new(connector); + let server: Authority = "test-remote:443".parse().unwrap(); + + let first = endpoint + .connect(server.clone()) + .await + .expect("first connect should build H3"); + let second = endpoint + .connect(server.clone()) + .await + .expect("second connect should reuse H3"); + + assert!(Arc::ptr_eq(&first, &second)); + assert_eq!(calls.load(Ordering::Relaxed), 1); + assert_eq!( + servers + .lock() + .expect("server log mutex should not be poisoned") + .as_slice(), + &[server], + ); + assert_eq!(endpoint.pool_len(), 1); + } + + #[tokio::test] + async fn inherent_connect_returns_connector_error() { + let endpoint = H3Endpoint::<_, IdentifiedConnection>::new(CountingConnect::fail( + test_connection_error("connector failed"), + )); + let server: Authority = "test-remote:443".parse().unwrap(); + + let error = endpoint + .connect(server) + .await + .expect_err("connector error should be returned"); + + assert!(matches!(error, pool::ConnectError::Connector { source } if source.is_transport())); + } + + #[tokio::test] + async fn inherent_connect_returns_h3_build_error() { + let endpoint = + H3Endpoint::<_, MockConnection>::new(CountingConnect::succeed(MockConnection::new())); + let server: Authority = "test-remote:443".parse().unwrap(); + + let error = endpoint + .connect(server) + .await + .expect_err("H3 builder error should be returned"); + + assert!(matches!(error, pool::ConnectError::H3 { source } if source.is_transport())); + } + + #[tokio::test] + async fn inherent_connect_rejects_peer_identity_mismatch() { + let connector = CountingConnect::succeed(IdentifiedConnection::new("actual.example")); + let endpoint = H3Endpoint::new(connector); + let server: Authority = "expected.example:443".parse().unwrap(); + + let error = endpoint + .connect(server) + .await + .expect_err("identity mismatch should be returned"); + + assert!(matches!( + error, + pool::ConnectError::IncorrectIdentity { expected, actual } + if expected == "expected.example" && actual.as_deref() == Some("actual.example") + )); + } + + #[tokio::test] + async fn quic_connect_impl_delegates_to_inner_transport() { + let connector = CountingConnect::succeed(IdentifiedConnection::new("delegated.example")); + let calls = connector.calls.clone(); + let endpoint = H3Endpoint::new(connector); + let server: Authority = "delegated.example:443".parse().unwrap(); + + let connection = quic::Connect::connect(&endpoint, &server) + .await + .expect("delegated connect should return raw QUIC connection"); + + assert_eq!(connection.remote_name, "delegated.example"); + assert_eq!(calls.load(Ordering::Relaxed), 1); + assert_eq!(endpoint.pool_len(), 0); + } + + #[tokio::test] + async fn quic_connect_impl_returns_inner_transport_error_without_pooling() { + let connector = + CountingConnect::::fail(test_connection_error("delegate failed")); + let calls = connector.calls.clone(); + let endpoint = H3Endpoint::new(connector); + let server: Authority = "delegated.example:443".parse().unwrap(); + + let error = quic::Connect::connect(&endpoint, &server) + .await + .expect_err("delegated connect should return raw QUIC errors"); + + assert!(error.is_transport()); + assert_eq!(calls.load(Ordering::Relaxed), 1); + assert_eq!(endpoint.pool_len(), 0); + } + + #[tokio::test] + async fn accept_builds_h3_connection_from_accepted_quic_connection() { + let listen = MockListen::default(); + let accepted = listen.accepted.clone(); + let mut endpoint = H3Endpoint::builder().quic(listen).build(); + + let connection = endpoint.accept().await.expect("accept should build H3"); + + assert_eq!(accepted.load(Ordering::Relaxed), 1); + connection.qpack().expect("qpack should be initialized"); + } + + #[tokio::test] + async fn quic_listen_impl_returns_inner_accept_and_shutdown_errors() { + let mut endpoint = H3Endpoint::builder().quic(FailingListen).build(); + + let accept_error = quic::Listen::accept(&mut endpoint) + .await + .expect_err("listen impl should return raw accept errors"); + let shutdown_error = quic::Listen::shutdown(&endpoint) + .await + .expect_err("listen impl should return raw shutdown errors"); + + assert!(accept_error.is_transport()); + assert!(shutdown_error.is_transport()); + } + + #[tokio::test] + async fn accept_returns_quic_accept_error() { + let mut endpoint = H3Endpoint::builder().quic(FailingListen).build(); + + let error = endpoint + .accept() + .await + .expect_err("accept error should be returned"); + + assert!(matches!(error, AcceptError::Accept { source } if source.is_transport())); + } + + #[tokio::test] + async fn accept_returns_h3_build_error() { + let mut endpoint = H3Endpoint::builder().quic(UnbuildableListen).build(); + + let error = endpoint + .accept() + .await + .expect_err("build error should be returned"); + + assert!(matches!(error, AcceptError::Build { source } if source.is_transport())); + } + + #[tokio::test] + async fn accept_owned_builds_h3_connection_from_shared_endpoint() { + let listen = MockListen::default(); + let accepted = listen.accepted.clone(); + let endpoint = Arc::new(H3Endpoint::builder().quic(listen).build()); + + let connection = endpoint + .accept_owned() + .await + .expect("accept_owned should build H3"); + + assert_eq!(accepted.load(Ordering::Relaxed), 1); + connection.qpack().expect("qpack should be initialized"); + } + + #[tokio::test] + async fn accept_owned_returns_h3_build_error_from_shared_endpoint() { + let endpoint = Arc::new(H3Endpoint::builder().quic(UnbuildableListen).build()); + + let error = endpoint + .accept_owned() + .await + .expect_err("shared build error should be returned"); + + assert!(matches!(error, AcceptError::Build { source } if source.is_transport())); + } + + #[tokio::test] + async fn accept_owned_returns_quic_accept_error_from_shared_endpoint() { + let endpoint = Arc::new(H3Endpoint::builder().quic(FailingListen).build()); + + let error = endpoint + .accept_owned() + .await + .expect_err("shared accept error should be returned"); + + assert!(matches!(error, AcceptError::Accept { source } if source.is_transport())); + } + + #[tokio::test] + async fn quic_listen_impl_delegates_accept_and_shutdown_to_inner_transport() { + let listen = MockListen::default(); + let accepted = listen.accepted.clone(); + let shutdowns = listen.shutdowns.clone(); + let mut endpoint = H3Endpoint::builder().quic(listen).build(); + + let _connection = quic::Listen::accept(&mut endpoint) + .await + .expect("listen impl should return raw QUIC connection"); + quic::Listen::shutdown(&endpoint) + .await + .expect("listen impl should delegate shutdown"); + + assert_eq!(accepted.load(Ordering::Relaxed), 1); + assert_eq!(shutdowns.load(Ordering::Relaxed), 1); + } + + #[tokio::test] + async fn listen_returns_quic_accept_error() { + let mut endpoint = H3Endpoint::builder().quic(FailingListen).build(); + + let error = endpoint + .listen(NoopService) + .await + .expect_err("listen should return listener accept error"); + + assert!(error.is_transport()); + } + + #[tokio::test] + async fn listen_skips_build_errors_and_returns_later_accept_error() { + let listen = SequencedListen::new([ + Ok(Arc::new(MockConnection::new())), + Err(test_connection_error("accept failed after build retry")), + ]); + let accepted = listen.accepted.clone(); + let mut endpoint = H3Endpoint::builder().quic(listen).build(); + + let error = endpoint + .listen(NoopService) + .await + .expect_err("listen should continue past build errors and return accept error"); + + assert!(error.is_transport()); + assert_eq!(accepted.load(Ordering::Relaxed), 2); + } + + #[tokio::test] + async fn listen_spawns_for_buildable_connections_and_returns_later_accept_error() { + let listen = SequencedListen::new([ + Ok(Arc::new(BuildableConnection)), + Err(test_connection_error( + "accept failed after serving one connection", + )), + ]); + let accepted = listen.accepted.clone(); + let mut endpoint = H3Endpoint::builder().quic(listen).build(); + + let error = endpoint + .listen(NoopService) + .await + .expect_err("listen should return the later listener error"); + + assert!(error.is_transport()); + assert_eq!(accepted.load(Ordering::Relaxed), 2); + } + + #[tokio::test] + async fn listen_owned_returns_quic_accept_error_from_shared_endpoint() { + let endpoint = Arc::new(H3Endpoint::builder().quic(FailingListen).build()); + + let error = endpoint + .listen_owned(NoopService) + .await + .expect_err("listen_owned should return shared listener accept error"); + + assert!(error.is_transport()); + } + + #[tokio::test] + async fn listen_owned_skips_build_errors_and_returns_later_accept_error() { + let listen = SequencedListen::new([ + Ok(Arc::new(MockConnection::new())), + Err(test_connection_error( + "shared accept failed after build retry", + )), + ]); + let accepted = listen.accepted.clone(); + let endpoint = Arc::new(H3Endpoint::builder().quic(listen).build()); + + let error = endpoint + .listen_owned(NoopService) + .await + .expect_err("listen_owned should continue past build errors and return accept error"); + + assert!(error.is_transport()); + assert_eq!(accepted.load(Ordering::Relaxed), 2); + } + + #[tokio::test] + async fn listen_owned_spawns_for_buildable_connections_and_returns_later_accept_error() { + let listen = SequencedListen::new([ + Ok(Arc::new(BuildableConnection)), + Err(test_connection_error( + "shared accept failed after serving one connection", + )), + ]); + let accepted = listen.accepted.clone(); + let endpoint = Arc::new(H3Endpoint::builder().quic(listen).build()); + + let error = endpoint + .listen_owned(NoopService) + .await + .expect_err("listen_owned should return the later listener error"); + + assert!(error.is_transport()); + assert_eq!(accepted.load(Ordering::Relaxed), 2); + } + + /// Verify that QuicMutGuard provides mutable access to the inner QUIC transport. + #[test] + fn test_quic_mut_guard_deref_mut() { + let mut quic = MockConnect; + let pool = Pool::::empty(); + let mut guard = QuicMutGuard { + quic: &mut quic, + pool: &pool, + }; + + // Verify Deref produces &MockConnect + let _reference: &MockConnect = &guard; + let _ = _reference; + + // Verify DerefMut produces &mut MockConnect + let _mut_reference: &mut MockConnect = &mut guard; + let _ = _mut_reference; + + drop(guard); + } + + /// Verify that dropping QuicMutGuard clears the connection pool. + #[test] + fn test_quic_mut_guard_drop_clears_pool() { + let mut quic = MockConnect; + let pool = Pool::::empty(); + + // Insert an entry into the pool to verify clearing + let auth: Authority = "example.com:443".parse().unwrap(); + pool.connections + .entry(auth) + .or_insert_with(|| Arc::new(ReuseableConnection::pending())); + + assert_eq!(pool.len(), 1); + + { + let guard = QuicMutGuard { + quic: &mut quic, + pool: &pool, + }; + // Guard holds reference; pool still has entries + assert_eq!(pool.len(), 1); + drop(guard); + } // guard dropped, pool.clear() called + + assert_eq!(pool.len(), 0); + } + + /// Verify that H3Endpoint::quic_mut provides mutable access to the inner + /// QUIC transport. + #[test] + fn test_quic_mut_access() { + let mut h3 = H3Endpoint::new(MockConnect); + + let guard = h3.quic_mut(); + let _: &MockConnect = &guard; // verify Deref works + drop(guard); + } + + /// Verify that dropping the guard from H3Endpoint::quic_mut clears the + /// connection pool. + #[test] + fn test_quic_mut_drop_clears_pool() { + let mut h3 = H3Endpoint::new(MockConnect); + + // Insert an entry into the pool + let auth: Authority = "example.com:443".parse().unwrap(); + h3.pool + .connections + .entry(auth) + .or_insert_with(|| Arc::new(ReuseableConnection::pending())); + + assert_eq!(h3.pool.len(), 1); + + { + let _guard = h3.quic_mut(); + } // guard dropped, pool.clear() called + + assert_eq!(h3.pool.len(), 0); + } + + #[test] + fn quic_mut_allows_mutating_transport_before_clearing_pool_on_drop() { + let mut endpoint: H3Endpoint<_, BuildableConnection> = H3Endpoint::builder() + .quic(MutableTransport { generation: 0 }) + .build(); + let pool = endpoint.pool.clone(); + let auth: Authority = "example.com:443".parse().unwrap(); + endpoint + .pool + .connections + .entry(auth) + .or_insert_with(|| Arc::new(ReuseableConnection::pending())); + + { + let mut guard = endpoint.quic_mut(); + guard.generation = 1; + assert_eq!(guard.generation, 1); + assert_eq!(pool.len(), 1); + } + + assert_eq!(endpoint.quic().generation, 1); + assert_eq!(endpoint.pool_len(), 0); + } + + #[test] + fn test_clear_pool_clears_endpoint_connections() { + let h3 = H3Endpoint::new(MockConnect); + + let auth: Authority = "example.com:443".parse().unwrap(); + h3.pool + .connections + .entry(auth) + .or_insert_with(|| Arc::new(ReuseableConnection::pending())); + + assert_eq!(h3.pool_len(), 1); + + h3.clear_pool(); + + assert_eq!(h3.pool_len(), 0); + } + + #[tokio::test] + async fn clear_pool_forces_next_connect_to_rebuild_connection() { + let connector = CountingConnect::succeed(IdentifiedConnection::new("reconnect.example")); + let calls = connector.calls.clone(); + let endpoint = H3Endpoint::new(connector); + let server: Authority = "reconnect.example:443".parse().unwrap(); + + let first = endpoint + .connect(server.clone()) + .await + .expect("first connect should build H3"); + endpoint.clear_pool(); + let second = endpoint + .connect(server) + .await + .expect("second connect should rebuild H3 after clearing pool"); + + assert_eq!(calls.load(Ordering::Relaxed), 2); + assert!(!Arc::ptr_eq(&first, &second)); + assert_eq!(endpoint.pool_len(), 1); + } + + #[tokio::test] + async fn dropping_quic_mut_guard_forces_next_connect_to_rebuild_connection() { + let connector = + CountingConnect::succeed(IdentifiedConnection::new("guard-reconnect.example")); + let calls = connector.calls.clone(); + let mut endpoint = H3Endpoint::new(connector); + let server: Authority = "guard-reconnect.example:443".parse().unwrap(); + + let first = endpoint + .connect(server.clone()) + .await + .expect("first connect should build H3"); + drop(endpoint.quic_mut()); + let second = endpoint + .connect(server) + .await + .expect("second connect should rebuild H3 after guard drop"); + + assert_eq!(calls.load(Ordering::Relaxed), 2); + assert!(!Arc::ptr_eq(&first, &second)); + assert_eq!(endpoint.pool_len(), 1); + } + + #[tokio::test] + async fn connector_errors_are_not_cached_and_later_connect_can_succeed() { + let connector = SequencedConnect::new([ + Err(test_connection_error("connector failed once")), + Ok(Arc::new(IdentifiedConnection::new("recover.example"))), + ]); + let calls = connector.calls.clone(); + let endpoint = H3Endpoint::new(connector); + let server: Authority = "recover.example:443".parse().unwrap(); + + let first = endpoint + .connect(server.clone()) + .await + .expect_err("first connect should fail at the connector"); + let second = endpoint + .connect(server) + .await + .expect("second connect should retry after connector failure"); + + assert!(matches!(first, pool::ConnectError::Connector { source } if source.is_transport())); + assert_eq!(calls.load(Ordering::Relaxed), 2); + assert_eq!( + second + .remote_authority() + .await + .expect("remote authority lookup should succeed") + .as_ref() + .map(|agent| agent.name()), + Some("recover.example"), + ); + assert_eq!(endpoint.pool_len(), 1); + } + + #[tokio::test] + async fn listen_connection_returns_when_accepting_request_streams_fails() { + let quic = Arc::new(ControlledConnection::new("listen connection closed")); + let state = state_without_qpack(quic.clone()); + let connection = Arc::new(H3Connection::from_state_for_test(state)); + + quic.trigger_close(); + + timeout( + Duration::from_millis(100), + listen_connection(connection, NoopService), + ) + .await + .expect("listen_connection should stop when accepting streams fails"); + } + + #[tokio::test] + async fn listen_connection_skips_requests_when_stream_id_lookup_fails() { + let quic = Arc::new(ControlledConnection::new("stream id failed")); + let state = state_without_qpack(quic.clone()); + enqueue_http3_request( + &state, + stream_id_error_request_stream(21, quic.close_latch.clone()), + ) + .await; + let connection = Arc::new(H3Connection::from_state_for_test(state)); + let seen_streams = Arc::new(Mutex::new(Vec::new())); + + timeout( + Duration::from_millis(100), + listen_connection( + connection, + RecordingService { + seen_streams: seen_streams.clone(), + close_latch: Arc::new(CloseLatch::default()), + }, + ), + ) + .await + .expect("listen_connection should stop after stream-id failure closes the connection"); + + assert!( + seen_streams + .lock() + .expect("recording service mutex should not be poisoned") + .is_empty() + ); + } + + #[tokio::test] + async fn listen_connection_continues_after_stream_id_lookup_failure() { + let quic = Arc::new(ControlledConnection::new("stream id failed then recovered")); + let state = state_with_qpack(quic.clone()).await; + enqueue_http3_request(&state, second_stream_id_error_request_stream(29)).await; + enqueue_http3_request(&state, http3_request_stream(31).await).await; + let connection = Arc::new(H3Connection::from_state_for_test(state)); + let seen_streams = Arc::new(Mutex::new(Vec::new())); + + timeout( + Duration::from_millis(100), + listen_connection( + connection, + RecordingService { + seen_streams: seen_streams.clone(), + close_latch: quic.close_latch.clone(), + }, + ), + ) + .await + .expect("listen_connection should continue after stream-id failure"); + + timeout(Duration::from_millis(100), async { + loop { + if seen_streams + .lock() + .expect("recording service mutex should not be poisoned") + .as_slice() + == [StreamId(VarInt::from_u32(31))] + { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("valid request after stream-id failure should be served"); + } + + #[tokio::test] + async fn listen_connection_returns_when_qpack_is_unavailable() { + let quic = Arc::new(ControlledConnection::new("qpack unused")); + let state = state_without_qpack(quic); + enqueue_http3_request(&state, http3_request_stream(23).await).await; + let connection = Arc::new(H3Connection::from_state_for_test(state)); + let seen_streams = Arc::new(Mutex::new(Vec::new())); + + timeout( + Duration::from_millis(100), + listen_connection( + connection, + RecordingService { + seen_streams: seen_streams.clone(), + close_latch: Arc::new(CloseLatch::default()), + }, + ), + ) + .await + .expect("listen_connection should stop when qpack is unavailable"); + + assert!( + seen_streams + .lock() + .expect("recording service mutex should not be poisoned") + .is_empty() + ); + } + + #[tokio::test] + async fn listen_connection_spawns_request_handling_for_valid_requests() { + let quic = Arc::new(ControlledConnection::new("request served")); + let state = state_with_qpack(quic.clone()).await; + enqueue_http3_request(&state, http3_request_stream(25).await).await; + let connection = Arc::new(H3Connection::from_state_for_test(state)); + let seen_streams = Arc::new(Mutex::new(Vec::new())); + + timeout( + Duration::from_millis(100), + listen_connection( + connection, + RecordingService { + seen_streams: seen_streams.clone(), + close_latch: quic.close_latch.clone(), + }, + ), + ) + .await + .expect("listen_connection should stop after the test service closes the connection"); + + timeout(Duration::from_millis(100), async { + loop { + if seen_streams + .lock() + .expect("recording service mutex should not be poisoned") + .as_slice() + == [StreamId(VarInt::from_u32(25))] + { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("request handler task should record the accepted stream"); + } + + #[tokio::test] + async fn listen_connection_keeps_running_when_request_handler_returns_error() { + let quic = Arc::new(ControlledConnection::new("handler failed")); + let state = state_with_qpack(quic.clone()).await; + enqueue_http3_request(&state, http3_request_stream(27).await).await; + let connection = Arc::new(H3Connection::from_state_for_test(state)); + let calls = Arc::new(AtomicUsize::new(0)); + + timeout( + Duration::from_millis(100), + listen_connection( + connection, + FailingService { + calls: calls.clone(), + close_latch: quic.close_latch.clone(), + }, + ), + ) + .await + .expect("listen_connection should stop after the failing handler closes the connection"); + + timeout(Duration::from_millis(100), async { + loop { + if calls.load(Ordering::Relaxed) == 1 { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("failing request handler task should run once"); + + // calls is incremented inside Service::call before the returned Ready future + // is polled. Yield additional times so the spawned listen_request task observes + // the ready error and runs the diagnostic logging branch before the runtime + // is torn down. + for _ in 0..16 { + tokio::task::yield_now().await; + } + } + + #[tokio::test] + async fn test_connection_helpers_expose_expected_pending_and_agent_paths() { + let buildable = BuildableConnection; + quic::Lifecycle::check(&buildable).expect("buildable connection is live"); + quic::Lifecycle::close( + &buildable, + crate::error::Code::H3_NO_ERROR, + std::borrow::Cow::Borrowed("test close"), + ); + assert!( + quic::WithLocalAuthority::local_authority(&buildable) + .await + .expect("buildable local authority") + .is_none() + ); + assert!( + quic::WithRemoteAuthority::remote_authority(&buildable) + .await + .expect("buildable remote authority") + .is_none() + ); + quic::ManageStream::open_uni(&buildable) + .await + .expect("buildable open_uni"); + timeout( + Duration::from_millis(10), + quic::ManageStream::open_bi(&buildable), + ) + .await + .expect_err("buildable open_bi stays pending"); + timeout( + Duration::from_millis(10), + quic::ManageStream::accept_bi(&buildable), + ) + .await + .expect_err("buildable accept_bi stays pending"); + timeout( + Duration::from_millis(10), + quic::ManageStream::accept_uni(&buildable), + ) + .await + .expect_err("buildable accept_uni stays pending"); + timeout( + Duration::from_millis(10), + quic::Lifecycle::closed(&buildable), + ) + .await + .expect_err("buildable closed stays pending"); + + let identified = IdentifiedConnection::new("identified.example"); + quic::Lifecycle::check(&identified).expect("identified connection is live"); + assert!( + quic::WithLocalAuthority::local_authority(&identified) + .await + .expect("identified local authority") + .is_none() + ); + let remote = quic::WithRemoteAuthority::remote_authority(&identified) + .await + .expect("identified remote authority") + .expect("identified remote authority exists"); + assert_eq!( + identity::RemoteAuthority::name(&remote), + "identified.example" + ); + assert!(identity::RemoteAuthority::cert_chain(&remote).is_empty()); + quic::ManageStream::open_uni(&identified) + .await + .expect("identified open_uni"); + timeout( + Duration::from_millis(10), + quic::ManageStream::open_bi(&identified), + ) + .await + .expect_err("identified open_bi stays pending"); + timeout( + Duration::from_millis(10), + quic::Lifecycle::closed(&identified), + ) + .await + .expect_err("identified closed stays pending"); + } + + #[tokio::test] + async fn test_controlled_connection_and_stream_id_helpers_cover_remaining_traits() { + let controlled = ControlledConnection::new("controlled terminal"); + quic::Lifecycle::check(&controlled).expect("controlled connection is live"); + assert!( + quic::WithLocalAuthority::local_authority(&controlled) + .await + .expect("controlled local authority") + .is_none() + ); + assert!( + quic::WithRemoteAuthority::remote_authority(&controlled) + .await + .expect("controlled remote authority") + .is_none() + ); + quic::ManageStream::open_uni(&controlled) + .await + .expect("controlled open_uni"); + timeout( + Duration::from_millis(10), + quic::ManageStream::open_bi(&controlled), + ) + .await + .expect_err("controlled open_bi stays pending"); + quic::Lifecycle::close( + &controlled, + crate::error::Code::H3_NO_ERROR, + std::borrow::Cow::Borrowed("test close"), + ); + let closed = quic::Lifecycle::closed(&controlled).await; + assert!(matches!(closed, quic::ConnectionError::Transport { .. })); + + let close_latch = Arc::new(CloseLatch::default()); + let mut stream = StreamIdErrorReadStream::new(close_latch.clone()); + let stream_id_error = stream + .stream_id() + .await + .expect_err("first stream-id helper reports reset"); + assert!(matches!(stream_id_error, quic::StreamError::Reset { .. })); + close_latch.wait().await; + stream + .stop(VarInt::from_u32(0x33)) + .await + .expect("stop helper is a no-op success"); + assert_eq!( + Pin::new(&mut stream) + .next() + .await + .expect("stream yields one chunk") + .expect("chunk is ok"), + Bytes::from_static(&[0x01]) + ); + assert!( + Pin::new(&mut stream).next().await.is_none(), + "stream helper ends after its single chunk" + ); + + let mut stream = StreamIdErrorReadStream::fail_after_first_success(VarInt::from_u32(99)); + assert_eq!( + stream.stream_id().await.expect("first lookup succeeds"), + VarInt::from_u32(99) + ); + assert!(stream.stream_id().await.is_err()); + } + + #[tokio::test] + async fn test_listener_and_service_helpers_cover_direct_paths() { + let listen = MockListen::default(); + let shutdowns = listen.shutdowns.clone(); + let mut listen_ref = &listen; + quic::Listen::shutdown(&listen_ref) + .await + .expect("shared mock shutdown succeeds"); + let _accepted = quic::Listen::accept(&mut listen_ref) + .await + .expect("shared mock accept succeeds"); + assert_eq!(shutdowns.load(Ordering::Relaxed), 1); + assert_eq!(listen.accepted.load(Ordering::Relaxed), 1); + + let failing = FailingListen; + let mut failing_ref = &failing; + assert!(quic::Listen::accept(&mut failing_ref).await.is_err()); + assert!(quic::Listen::shutdown(&failing_ref).await.is_err()); + + let unbuildable = UnbuildableListen; + let mut unbuildable_ref = &unbuildable; + let _ = quic::Listen::accept(&mut unbuildable_ref) + .await + .expect("shared unbuildable accept returns raw connection"); + quic::Listen::shutdown(&unbuildable_ref) + .await + .expect("shared unbuildable shutdown succeeds"); + + let sequenced = SequencedListen::new([Ok(Arc::new(BuildableConnection))]); + quic::Listen::shutdown(&sequenced) + .await + .expect("sequenced shutdown succeeds"); + let mut sequenced_ref = &sequenced; + quic::Listen::shutdown(&sequenced_ref) + .await + .expect("shared sequenced shutdown succeeds"); + let _ = quic::Listen::accept(&mut sequenced_ref) + .await + .expect("shared sequenced accept consumes queued result"); + + let quic = Arc::new(ControlledConnection::new("direct request connection")); + let state = state_with_qpack(quic).await; + let qpack = state.qpack().expect("qpack should be initialized"); + let erased = state.erase(); + let connection = Arc::new(erased.clone()); + let (request_reader, _request_write_side) = + quic::test::mock_stream_pair(VarInt::from_u32(111)); + let (_response_read_side, response_writer) = + quic::test::mock_stream_pair(VarInt::from_u32(111)); + let close_latch = Arc::new(CloseLatch::default()); + let request = UnresolvedRequest { + stream_id: StreamId(VarInt::from_u32(111)), + read_stream: MessageReader::new( + VarInt::from_u32(111), + StreamReader::new(GuardQuicReader::new( + Box::pin(request_reader) as BoxQuicStreamReader + )), + qpack.decoder.clone(), + erased.clone(), + ), + write_stream: MessageWriter::new( + SinkWriter::new(GuardQuicWriter::new( + Box::pin(response_writer) as BoxQuicStreamWriter + )), + qpack.encoder.clone(), + erased, + ), + connection, + }; + let mut service = RecordingService { + seen_streams: Arc::new(Mutex::new(Vec::new())), + close_latch, + }; + listen_request(&mut service, request).await; + assert_eq!( + service + .seen_streams + .lock() + .expect("recording service mutex should not be poisoned") + .as_slice(), + &[StreamId(VarInt::from_u32(111))] + ); + } +} diff --git a/src/endpoint/binds/collection.rs b/src/endpoint/binds/collection.rs deleted file mode 100644 index 3fc6e8c..0000000 --- a/src/endpoint/binds/collection.rs +++ /dev/null @@ -1,114 +0,0 @@ -use std::{ - cell::LazyCell, - collections::{HashMap, hash_map}, -}; - -use derive_more::{Deref, DerefMut, From, Into}; -use http::uri::{Authority, PathAndQuery, Scheme}; - -use super::{Bind, BindConflictError, BindHost}; -use crate::dquic::qinterface::bind_uri::BindUri; - -/// A collection of [`Bind`] patterns, typically populated from CLI arguments. -#[derive(Debug, Clone, PartialEq, Eq, Deref, DerefMut, From, Into)] -pub struct Binds { - /// Bind patterns - binds: Vec, -} - -impl Binds { - /// Create a new [`Binds`] from a list of [`Bind`] patterns. - pub fn new(binds: Vec) -> Self { - Self { binds } - } - - /// Expand all contained [`Bind`] patterns into concrete [`BindUri`]s, - /// checking for conflicting path-and-query on the same target. - /// - /// Two expanded URIs are considered "the same target" when their - /// scheme and authority (IP + port, or family + NIC + port) are - /// identical. If such a pair carries different path-and-query - /// values, a [`BindConflictError`] is returned. - /// - /// Duplicate URIs (same target *and* same path-and-query) are - /// silently deduplicated. - #[allow(clippy::result_large_err)] - pub fn to_bind_uris<'a, I>( - &'a self, - interfaces: I, - ) -> Result, Box> - where - I: IntoIterator + Clone, - { - let mut seen: HashMap<(Scheme, Authority), Option> = HashMap::new(); - let mut bind_uris = Vec::new(); - - let mut push_bind_uri = |bind_uri: BindUri| { - let inner = bind_uri.as_uri(); - let key = ( - inner.scheme().expect("BindUri always has a scheme").clone(), - inner - .authority() - .expect("BindUri always has an authority") - .clone(), - ); - let path_and_query = inner.path_and_query().cloned(); - // Normalize a bare `/` to `None` for consistent conflict semantics. - let path_and_query = - path_and_query.and_then(|pq| if pq.as_str() == "/" { None } else { Some(pq) }); - - match seen.entry(key) { - hash_map::Entry::Occupied(entry) => { - if *entry.get() != path_and_query { - let (scheme, authority) = entry.key(); - return Err(Box::new(BindConflictError { - scheme: scheme.clone(), - authority: authority.clone(), - existing: entry.get().clone(), - incoming: path_and_query, - })); - } - Ok(()) - } - hash_map::Entry::Vacant(entry) => { - entry.insert(path_and_query.clone()); - bind_uris.push(bind_uri); - Ok(()) - } - } - }; - - let bind_uri_templates = - self.iter() - .try_fold(Vec::with_capacity(self.len()), |mut templates, bind| { - match bind.host { - BindHost::Ip { addr, .. } => { - let template = bind.bind_uri_template(); - let port = bind.effective_port(); - let authority = format!("{addr}:{port}").parse(); - if let Some(bind_uri) = authority.ok().and_then(template) { - push_bind_uri(bind_uri)?; - } - } - BindHost::Glob { .. } | BindHost::Exact { .. } => { - let template = LazyCell::new(|| bind.bind_uri_template()); - templates.push((bind, template)) - } - } - Ok::<_, Box>(templates) - })?; - - interfaces - .into_iter() - .flat_map(|interface| { - bind_uri_templates.iter().flat_map(|(bind, template)| { - #[allow(clippy::redundant_closure)] - bind.bind_hosts_for_interface(interface) - .flat_map(|authority| template(authority)) - }) - }) - .try_for_each(push_bind_uri)?; - - Ok(bind_uris) - } -} diff --git a/src/endpoint/binds/error.rs b/src/endpoint/binds/error.rs deleted file mode 100644 index 1bc6de1..0000000 --- a/src/endpoint/binds/error.rs +++ /dev/null @@ -1,22 +0,0 @@ -use http::uri::{Authority, PathAndQuery, Scheme}; -use snafu::Snafu; - -/// Error indicating that two [`Bind`](super::Bind) patterns expand to the same target -/// (identical IP + port, or identical family + NIC + port) but carry -/// different path-and-query values. -#[derive(Debug, Clone, Snafu)] -#[snafu(display( - "conflicting bindings exist for bind target `{scheme}://{authority}`: `{e}` vs `{i}`", - e = existing.as_ref().map_or("/", PathAndQuery::as_str), - i = incoming.as_ref().map_or("/", PathAndQuery::as_str), -))] -pub struct BindConflictError { - /// The scheme component of the conflicting bind target. - pub scheme: Scheme, - /// The authority component of the conflicting bind target. - pub authority: Authority, - /// The first encountered path-and-query. - pub existing: Option, - /// The conflicting path-and-query. - pub incoming: Option, -} diff --git a/src/endpoint/binds/setup.rs b/src/endpoint/binds/setup.rs deleted file mode 100644 index ca770ed..0000000 --- a/src/endpoint/binds/setup.rs +++ /dev/null @@ -1,155 +0,0 @@ -//! Shared bind setup logic for QUIC interface binding. -//! -//! Consolidates the duplicated bind-interfaces initialization flow. - -use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc}; - -use futures::{StreamExt, stream::FuturesUnordered}; -use tokio_util::task::AbortOnDropHandle; -use tracing::Instrument; - -use super::{BindConflictError, Binds}; -use crate::dquic::{ - prelude::handy::DEFAULT_IO_FACTORY, - qinterface::{ - BindInterface, - bind_uri::BindUri, - device::{Devices, InterfacesMonitor}, - manager::InterfaceManager, - }, -}; - -/// Result of [`setup_bind_interfaces`], carrying all state needed by downstream -/// code (DNS resolver construction, H3Client building, etc.). -pub struct BindSetup { - /// Concrete bind URIs expanded from the user-supplied [`Binds`] patterns. - pub bind_uris: Vec, - /// Interface manager that owns the bindings. - pub iface_manager: Arc, - /// Bound interfaces — one per bind URI. - pub bind_interfaces: Vec, - /// Interfaces monitor for detecting runtime changes to network interfaces. - pub monitor: InterfacesMonitor, -} - -/// Like [`setup_bind_interfaces`], but calls `f` on the expanded bind URIs -/// before binding, allowing callers to mutate them (e.g. inject properties). -pub async fn setup_bind_interfaces_with( - binds: &Binds, - f: impl FnOnce(&mut Vec), -) -> Result> { - let monitor = Devices::global().monitor(); - - let mut bind_uris = binds.to_bind_uris(monitor.interfaces().keys().map(String::as_str))?; - f(&mut bind_uris); - - let iface_manager = Arc::new(InterfaceManager::new()); - let io_factory = Arc::new(DEFAULT_IO_FACTORY); - - let bind_interfaces = bind_uris - .iter() - .map(|bind_uri| iface_manager.bind(bind_uri.clone(), io_factory.clone())) - .collect::>() - .collect::>() - .await; - - Ok(BindSetup { - bind_uris, - iface_manager, - bind_interfaces, - monitor, - }) -} - -/// Expand [`Binds`] patterns into concrete network bindings. -pub async fn setup_bind_interfaces(binds: &Binds) -> Result> { - setup_bind_interfaces_with(binds, |_| {}).await -} - -/// Watch for network interface changes and dynamically bind/unbind URIs. -pub fn watch_bind_interfaces( - binds: &Binds, - mut monitor: InterfacesMonitor, - initial_bind_uris: Vec, - bind_fn: B, - unbind_fn: U, -) -> AbortOnDropHandle<()> -where - B: Fn(BindUri) -> Pin + Send>> + Send + 'static, - U: Fn(BindUri) + Send + 'static, -{ - let binds = binds.clone(); - let span = tracing::Span::current(); - - AbortOnDropHandle::new(tokio::spawn( - async move { - fn to_keyed_map(uris: Vec) -> HashMap { - uris.into_iter() - .map(|uri| (uri.identity_key(), uri)) - .collect() - } - - // Initial reconcile - let mut current_map: HashMap = - match binds.to_bind_uris(monitor.interfaces().keys().map(String::as_str)) { - Ok(new_uris) => { - let new_map = to_keyed_map(new_uris); - let initial_map = to_keyed_map(initial_bind_uris); - - for (key, uri) in &new_map { - if !initial_map.contains_key(key) { - tracing::debug!("binding new URI `{uri}` during initial reconcile"); - bind_fn(uri.clone()).await; - } - } - for (key, uri) in &initial_map { - if !new_map.contains_key(key) { - tracing::info!("unbinding URI `{uri}` during initial reconcile"); - unbind_fn(uri.clone()); - } - } - - new_map - } - Err(err) => { - tracing::warn!( - "failed to compute bind URIs during initial reconcile: {}", - snafu::Report::from_error(&err) - ); - to_keyed_map(initial_bind_uris) - } - }; - - while let Some((interfaces, _event)) = monitor.update().await { - let new_uris = match binds.to_bind_uris(interfaces.keys().map(String::as_str)) { - Ok(uris) => uris, - Err(err) => { - tracing::warn!( - "failed to compute bind URIs after interface change: {}", - snafu::Report::from_error(&err) - ); - continue; - } - }; - - let new_map = to_keyed_map(new_uris); - - for (key, uri) in ¤t_map { - if !new_map.contains_key(key) { - tracing::info!("unbinding URI `{uri}`"); - unbind_fn(uri.clone()); - } - } - for (key, uri) in &new_map { - if !current_map.contains_key(key) { - tracing::debug!("binding new URI `{uri}`"); - bind_fn(uri.clone()).await; - } - } - - current_map = new_map; - } - } - .instrument(span), - )) -} diff --git a/src/endpoint/binds/tests.rs b/src/endpoint/binds/tests.rs deleted file mode 100644 index 8b86ce8..0000000 --- a/src/endpoint/binds/tests.rs +++ /dev/null @@ -1,638 +0,0 @@ -use std::net::IpAddr; - -use http::uri::Scheme; - -use super::*; -use crate::dquic::{qbase::net::Family, qinterface::bind_uri::BindUriScheme}; - -// -- Parsing tests -- - -#[test] -fn parse_full_iface_with_family() { - let b: Bind = "iface://v4.enp17s0:8080".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Iface); - assert_eq!(b.host.family(), Some(Family::V4)); - assert_eq!( - b.host, - BindHost::classify("enp17s0", Some(Family::V4)).unwrap() - ); - assert_eq!(b.port, Some(8080)); - assert!(b.path_and_query.is_none()); -} - -#[test] -fn parse_full_iface_glob() { - let b: Bind = "iface://v4.en*:8080".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Iface); - assert_eq!(b.host.family(), Some(Family::V4)); - assert!(b.host.is_glob()); - assert_eq!(b.host.as_str(), "en*"); - assert_eq!(b.port, Some(8080)); -} - -#[test] -fn parse_iface_no_family() { - let b: Bind = "iface://enp17s0:8080".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Iface); - assert_eq!(b.host.family(), None); - assert!(!b.host.is_glob()); - assert_eq!(b.host.as_str(), "enp17s0"); - assert_eq!(b.port, Some(8080)); -} - -#[test] -fn parse_iface_no_port() { - let b: Bind = "iface://v4.enp17s0".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Iface); - assert_eq!(b.host.family(), Some(Family::V4)); - assert_eq!(b.host.as_str(), "enp17s0"); - assert_eq!(b.port, None); -} - -#[test] -fn parse_inet() { - let b: Bind = "inet://127.0.0.1:8080".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Inet); - assert_eq!(b.host.family(), None); - assert_eq!(b.host.as_str(), "127.0.0.1"); - assert_eq!(b.port, Some(8080)); -} - -#[test] -fn parse_no_scheme_ip() { - let b: Bind = "127.0.0.1:8080".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Inet); - assert_eq!(b.host.as_str(), "127.0.0.1"); - assert_eq!(b.port, Some(8080)); -} - -#[test] -fn parse_no_scheme_iface() { - let b: Bind = "enp17s0:8080".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Iface); - assert_eq!(b.host.family(), None); - assert_eq!(b.host.as_str(), "enp17s0"); - assert_eq!(b.port, Some(8080)); -} - -#[test] -fn parse_no_scheme_with_family() { - let b: Bind = "v4.enp17s0:8080".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Iface); - assert_eq!(b.host.family(), Some(Family::V4)); - assert_eq!(b.host.as_str(), "enp17s0"); - assert_eq!(b.port, Some(8080)); -} - -#[test] -fn parse_glob_no_scheme() { - let b: Bind = "en*:8080".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Iface); - assert_eq!(b.host.family(), None); - assert!(b.host.is_glob()); - assert_eq!(b.host.as_str(), "en*"); - assert_eq!(b.port, Some(8080)); -} - -#[test] -fn parse_star_only() { - let b: Bind = "*".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Iface); - assert_eq!(b.host.family(), None); - assert!(b.host.is_glob()); - assert_eq!(b.host.as_str(), "*"); - assert_eq!(b.port, None); -} - -#[test] -fn parse_star_with_port() { - let b: Bind = "*:8080".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Iface); - assert_eq!(b.host.family(), None); - assert!(b.host.is_glob()); - assert_eq!(b.port, Some(8080)); -} - -#[test] -fn parse_v4_star() { - let b: Bind = "v4.*".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Iface); - assert_eq!(b.host.family(), Some(Family::V4)); - assert!(b.host.is_glob()); - assert_eq!(b.port, None); -} - -#[test] -fn parse_v6_star_with_port() { - let b: Bind = "v6.*:8080".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Iface); - assert_eq!(b.host.family(), Some(Family::V6)); - assert!(b.host.is_glob()); - assert_eq!(b.port, Some(8080)); -} - -#[test] -fn parse_no_scheme_no_port() { - let b: Bind = "enp17s0".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Iface); - assert_eq!(b.host.family(), None); - assert_eq!(b.host.as_str(), "enp17s0"); - assert_eq!(b.port, None); -} - -#[test] -fn parse_with_path_and_query() { - let b: Bind = "iface://v4.en*:8080/?stun_server=stun.genmeta.net" - .parse() - .unwrap(); - assert_eq!(b.scheme, BindUriScheme::Iface); - assert_eq!(b.host.family(), Some(Family::V4)); - assert!(b.host.is_glob()); - assert_eq!(b.port, Some(8080)); - assert_eq!( - b.path_and_query_str(), - Some("/?stun_server=stun.genmeta.net") - ); -} - -#[test] -fn parse_with_query_only() { - let b: Bind = "iface://v4.enp17s0:8080?stun=true".parse().unwrap(); - assert_eq!(b.path_and_query_str(), Some("?stun=true")); -} - -// -- Display round-trip -- - -#[test] -fn display_full() { - let b: Bind = "iface://v4.enp17s0:8080".parse().unwrap(); - assert_eq!(b.to_string(), "iface://v4.enp17s0:8080"); -} - -#[test] -fn display_no_port() { - let b: Bind = "iface://v4.enp17s0".parse().unwrap(); - assert_eq!(b.to_string(), "iface://v4.enp17s0"); -} - -#[test] -fn display_no_family() { - let b: Bind = "iface://enp17s0:8080".parse().unwrap(); - assert_eq!(b.to_string(), "iface://enp17s0:8080"); -} - -// -- Glob matching -- - -#[test] -fn glob_exact_match() { - let host = BindHost::classify("enp17s0", None).unwrap(); - assert!(host.matches("enp17s0")); - assert!(!host.matches("wlan0")); -} - -#[test] -fn glob_star_match() { - let host = BindHost::classify("en*", None).unwrap(); - assert!(host.matches("enp17s0")); - assert!(host.matches("eno1")); - assert!(!host.matches("wlan0")); - - let star = BindHost::classify("*", None).unwrap(); - assert!(star.matches("anything")); -} - -#[test] -fn glob_bracket_class() { - let host = BindHost::classify("[ew]*", None).unwrap(); - assert!(host.is_glob()); - assert!(host.matches("enp17s0")); - assert!(host.matches("wlan0")); - assert!(!host.matches("lo")); -} - -#[test] -fn glob_bracket_single() { - let host = BindHost::classify("wlan[01]", None).unwrap(); - assert!(host.is_glob()); - assert!(host.matches("wlan0")); - assert!(host.matches("wlan1")); - assert!(!host.matches("wlan2")); -} - -// -- Classify -- - -#[test] -fn classify_ipv4_as_ip() { - let host = BindHost::classify("127.0.0.1", None).unwrap(); - assert!(host.is_ip_addr()); - assert!(!host.is_glob()); - assert_eq!(host.as_str(), "127.0.0.1"); -} - -#[test] -fn classify_ipv6_bracket_as_ip() { - let host = BindHost::classify("[::1]", None).unwrap(); - assert!(host.is_ip_addr()); - assert_eq!(host.as_str(), "::1"); - assert_eq!(host.as_ip_addr().unwrap(), "::1".parse::().unwrap()); -} - -#[test] -fn classify_bracket_non_ip_as_glob() { - let host = BindHost::classify("[ew]", None).unwrap(); - assert!(host.is_glob()); - assert!(!host.is_ip_addr()); -} - -// -- Families -- - -#[test] -fn families_both() { - let b: Bind = "enp17s0:8080".parse().unwrap(); - assert_eq!(b.host.families(), [Family::V4, Family::V6]); -} - -#[test] -fn families_v4_only() { - let b: Bind = "v4.enp17s0:8080".parse().unwrap(); - assert_eq!(b.host.families(), [Family::V4]); -} - -// -- BindUri generation (iterators) -- - -#[test] -fn expand_iface() { - let b: Bind = "iface://v4.enp17s0:8080".parse().unwrap(); - let uris: Vec<_> = b.to_bind_uris(["enp17s0"]).map(|u| u.to_string()).collect(); - assert_eq!(uris, vec!["iface://v4.enp17s0:8080/"]); -} - -#[test] -fn expand_both_families() { - let b: Bind = "iface://enp17s0:8080".parse().unwrap(); - let uris: Vec<_> = b.to_bind_uris(["enp17s0"]).map(|u| u.to_string()).collect(); - assert_eq!( - uris, - vec!["iface://v4.enp17s0:8080/", "iface://v6.enp17s0:8080/"] - ); -} - -#[test] -fn expand_auto_port() { - let b: Bind = "iface://v4.enp17s0".parse().unwrap(); - let uris: Vec<_> = b.to_bind_uris(["enp17s0"]).map(|u| u.to_string()).collect(); - assert_eq!(uris.len(), 1); - assert!(uris[0].starts_with("iface://v4.enp17s0:0/")); -} - -#[test] -fn expand_inet() { - let b: Bind = "127.0.0.1:8080".parse().unwrap(); - let uris: Vec<_> = b.to_bind_uris([]).map(|u| u.to_string()).collect(); - assert_eq!(uris, vec!["inet://127.0.0.1:8080/"]); -} - -#[test] -fn expand_with_interfaces_glob() { - let b: Bind = "en*:8080".parse().unwrap(); - let interfaces = ["enp17s0", "eno1", "wlan0", "lo"]; - let uris: Vec<_> = b.to_bind_uris(interfaces).collect(); - // en* matches enp17s0 and eno1, each with V4 + V6 - assert_eq!(uris.len(), 4); -} - -#[test] -fn expand_with_interfaces_star() { - let b: Bind = "*:8080".parse().unwrap(); - let interfaces = ["enp17s0", "wlan0"]; - let uris: Vec<_> = b.to_bind_uris(interfaces).collect(); - // * matches all, each with V4 + V6 - assert_eq!(uris.len(), 4); -} - -#[test] -fn expand_path_and_query_passthrough() { - let b: Bind = "iface://v4.en*:8080/?stun_server=stun.genmeta.net" - .parse() - .unwrap(); - let uris: Vec<_> = b.to_bind_uris(["enp17s0"]).map(|u| u.to_string()).collect(); - assert_eq!( - uris, - vec!["iface://v4.enp17s0:8080/?stun_server=stun.genmeta.net"] - ); -} - -#[test] -fn path_and_query_is_validated() { - let b: Bind = "iface://v4.en*:8080/?key=value".parse().unwrap(); - let pq = b.path_and_query.as_ref().unwrap(); - assert_eq!(pq, "/?key=value"); -} - -// -- Bare IPv6 address tests -- - -#[test] -fn parse_bare_ipv6_loopback() { - let b: Bind = "::1".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Inet); - assert!(b.host.is_ip_addr()); - assert_eq!(b.host.as_str(), "::1"); - assert_eq!(b.port, None); - assert!(b.path_and_query.is_none()); -} - -#[test] -fn parse_bare_ipv6_any() { - let b: Bind = "::".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Inet); - assert!(b.host.is_ip_addr()); - assert_eq!(b.host.as_str(), "::"); - assert_eq!(b.port, None); -} - -#[test] -fn parse_bare_ipv6_full() { - let b: Bind = "2001:db8::1".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Inet); - assert!(b.host.is_ip_addr()); - assert_eq!(b.host.as_str(), "2001:db8::1"); - assert_eq!(b.port, None); -} - -#[test] -fn parse_bare_ipv4() { - // Bare IPv4 without port also works via the fast path - let b: Bind = "192.168.1.1".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Inet); - assert!(b.host.is_ip_addr()); - assert_eq!(b.host.as_str(), "192.168.1.1"); - assert_eq!(b.port, None); -} - -#[test] -fn display_bare_ipv6() { - let b: Bind = "::1".parse().unwrap(); - // Display wraps IPv6 in brackets - assert_eq!(b.to_string(), "inet://[::1]"); -} - -#[test] -fn expand_bare_ipv6() { - let b: Bind = "::1".parse().unwrap(); - let uris: Vec<_> = b.to_bind_uris([]).map(|u| u.to_string()).collect(); - assert_eq!(uris.len(), 1); - assert!(uris[0].starts_with("inet://[::1]:0/")); -} - -// -- Glob bracket parsing -- - -#[test] -fn parse_glob_bracket_class() { - let b: Bind = "[ew]*:8080".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Iface); - assert!(b.host.is_glob()); - assert_eq!(b.host.as_str(), "[ew]*"); - assert_eq!(b.port, Some(8080)); -} - -// -- IPv6 bracket syntax tests -- - -#[test] -fn parse_ipv6_full_scheme() { - let b: Bind = "inet://[::1]:8080".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Inet); - assert_eq!(b.host.family(), None); - assert_eq!(b.host.as_str(), "::1"); - assert_eq!(b.port, Some(8080)); - assert!(b.host.is_ip_addr()); -} - -#[test] -fn parse_ipv6_no_scheme() { - let b: Bind = "[::1]:8080".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Inet); - assert_eq!(b.host.as_str(), "::1"); - assert_eq!(b.port, Some(8080)); -} - -#[test] -fn parse_ipv6_full_addr() { - let b: Bind = "inet://[2001:db8::1]:443".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Inet); - assert_eq!(b.host.as_str(), "2001:db8::1"); - assert_eq!(b.port, Some(443)); - assert!(b.host.is_ip_addr()); -} - -#[test] -fn parse_ipv6_link_local() { - let b: Bind = "[fe80::1]:8080".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Inet); - assert_eq!(b.host.as_str(), "fe80::1"); - assert_eq!(b.port, Some(8080)); -} - -#[test] -fn parse_ipv6_any() { - let b: Bind = "inet://[::]:0".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Inet); - assert_eq!(b.host.as_str(), "::"); - assert_eq!(b.port, Some(0)); -} - -#[test] -fn parse_ipv6_no_port() { - let b: Bind = "[::1]".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Inet); - assert_eq!(b.host.as_str(), "::1"); - assert_eq!(b.port, None); -} - -#[test] -fn parse_ipv6_with_path_and_query() { - let b: Bind = "inet://[::1]:8080/?key=value".parse().unwrap(); - assert_eq!(b.scheme, BindUriScheme::Inet); - assert_eq!(b.host.as_str(), "::1"); - assert_eq!(b.port, Some(8080)); - assert_eq!(b.path_and_query_str(), Some("/?key=value")); -} - -#[test] -fn display_ipv6_roundtrip() { - let input = "inet://[::1]:8080"; - let b: Bind = input.parse().unwrap(); - assert_eq!(b.to_string(), input); -} - -#[test] -fn display_ipv6_full_addr() { - let b: Bind = "inet://[2001:db8::1]:443".parse().unwrap(); - assert_eq!(b.to_string(), "inet://[2001:db8::1]:443"); -} - -#[test] -fn expand_ipv6() { - let b: Bind = "inet://[::1]:8080".parse().unwrap(); - let uris: Vec<_> = b.to_bind_uris([]).map(|u| u.to_string()).collect(); - assert_eq!(uris, vec!["inet://[::1]:8080/"]); -} - -#[test] -fn expand_ipv6_auto_port() { - let b: Bind = "[::1]".parse().unwrap(); - let uris: Vec<_> = b.to_bind_uris([]).map(|u| u.to_string()).collect(); - assert_eq!(uris.len(), 1); - assert!(uris[0].starts_with("inet://[::1]:0/")); -} - -#[test] -fn family_ip_rejected() { - // v4.127.0.0.1 is not a valid bind pattern - assert!("v4.127.0.0.1:8080".parse::().is_err()); - assert!("inet://v6.[::1]:8080".parse::().is_err()); -} - -#[test] -fn ipv6_host_is_ip_addr() { - let b: Bind = "[::1]:8080".parse().unwrap(); - assert!(b.host.is_ip_addr()); - assert!(b.host.as_ip_addr().unwrap().is_ipv6()); -} - -// -- Binds tests -- - -#[test] -fn binds_new_and_deref() { - let v = vec![ - "iface://v4.enp17s0:8080".parse::().unwrap(), - "127.0.0.1:443".parse::().unwrap(), - ]; - let binds = Binds::new(v.clone()); - // Deref to &[Bind] - assert_eq!(binds.len(), 2); - assert_eq!(&*binds, &v[..]); -} - -#[test] -fn binds_deref_mut() { - let mut binds = Binds::new(vec!["*:8080".parse().unwrap()]); - binds.push("127.0.0.1:443".parse().unwrap()); - assert_eq!(binds.len(), 2); -} - -#[test] -fn binds_from_into_vec() { - let v = vec!["*:8080".parse::().unwrap()]; - let binds: Binds = v.clone().into(); - let out: Vec = binds.into(); - assert_eq!(out, v); -} - -#[test] -fn binds_to_bind_uris_no_conflict() { - let binds = Binds::new(vec![ - "iface://v4.enp17s0:8080".parse().unwrap(), - "127.0.0.1:443".parse().unwrap(), - ]); - let uris = binds.to_bind_uris(["enp17s0"]).unwrap(); - assert_eq!(uris.len(), 2); -} - -#[test] -fn binds_to_bind_uris_dedup() { - // Two identical binds should produce only one URI - let binds = Binds::new(vec![ - "127.0.0.1:8080".parse().unwrap(), - "inet://127.0.0.1:8080".parse().unwrap(), - ]); - let uris = binds.to_bind_uris([]).unwrap(); - assert_eq!(uris.len(), 1); -} - -#[test] -fn binds_to_bind_uris_conflict_different_pq() { - // Same target, different path-and-query → conflict - let binds = Binds::new(vec![ - "iface://v4.enp17s0:8080/?stun=true".parse().unwrap(), - "iface://v4.enp17s0:8080/?stun=false".parse().unwrap(), - ]); - let err = binds.to_bind_uris(["enp17s0"]).unwrap_err(); - assert_eq!(err.scheme, "iface".parse::().unwrap()); - assert!(err.to_string().contains("conflicting")); - assert!(err.to_string().contains("stun=true")); - assert!(err.to_string().contains("stun=false")); -} - -#[test] -fn binds_to_bind_uris_conflict_pq_vs_none() { - // Same target: one with path-and-query, one without → conflict - let binds = Binds::new(vec![ - "iface://v4.enp17s0:8080".parse().unwrap(), - "iface://v4.enp17s0:8080/?stun=true".parse().unwrap(), - ]); - let err = binds.to_bind_uris(["enp17s0"]).unwrap_err(); - assert!(err.to_string().contains("conflicting")); - assert!(err.existing.is_none()); - assert!(err.incoming.is_some()); -} - -#[test] -fn binds_to_bind_uris_same_pq_dedup() { - // Same target, same path-and-query → deduplicated, no conflict - let binds = Binds::new(vec![ - "iface://v4.enp17s0:8080/?stun=true".parse().unwrap(), - "iface://v4.enp17s0:8080/?stun=true".parse().unwrap(), - ]); - let uris = binds.to_bind_uris(["enp17s0"]).unwrap(); - assert_eq!(uris.len(), 1); -} - -#[test] -fn binds_to_bind_uris_glob_conflict() { - // Glob expanding to same interface with different pq - let binds = Binds::new(vec![ - "v4.en*:8080/?a=1".parse().unwrap(), - "v4.enp17s0:8080/?a=2".parse().unwrap(), - ]); - let err = binds.to_bind_uris(["enp17s0"]).unwrap_err(); - assert!(err.to_string().contains("conflicting")); -} - -#[test] -fn binds_to_bind_uris_different_targets_ok() { - // Different targets with different pq → no conflict - let binds = Binds::new(vec![ - "iface://v4.enp17s0:8080/?stun=true".parse().unwrap(), - "iface://v6.enp17s0:8080/?stun=false".parse().unwrap(), - ]); - let uris = binds.to_bind_uris(["enp17s0"]).unwrap(); - assert_eq!(uris.len(), 2); -} - -#[test] -fn binds_conflict_error_display() { - let err = BindConflictError { - scheme: "iface".parse().unwrap(), - authority: "v4.enp17s0:8080".parse().unwrap(), - existing: Some("/?stun=true".parse().unwrap()), - incoming: Some("/?stun=false".parse().unwrap()), - }; - assert_eq!( - err.to_string(), - "conflicting bindings exist for bind target `iface://v4.enp17s0:8080`: `/?stun=true` vs `/?stun=false`" - ); -} - -#[test] -fn binds_conflict_error_display_none_pq() { - let err = BindConflictError { - scheme: "inet".parse().unwrap(), - authority: "127.0.0.1:8080".parse().unwrap(), - existing: None, - incoming: Some("/?key=val".parse().unwrap()), - }; - assert_eq!( - err.to_string(), - "conflicting bindings exist for bind target `inet://127.0.0.1:8080`: `/` vs `/?key=val`" - ); -} diff --git a/src/endpoint/config.rs b/src/endpoint/config.rs deleted file mode 100644 index c8ac0ba..0000000 --- a/src/endpoint/config.rs +++ /dev/null @@ -1,210 +0,0 @@ -//! Per-role QUIC configuration values for [`QuicEndpoint`](super::QuicEndpoint). -//! -//! Configuration is split between [`CommonQuicConfig`] (shared by both roles) -//! and [`ClientOnlyConfig`] / [`ServerOnlyConfig`] (role-specific). -//! -//! [`ClientQuicConfig`] and [`ServerQuicConfig`] are cheap-to-clone wrappers -//! composed from `Arc` + `Arc`, so an endpoint Clone shares these -//! sub-trees efficiently and the endpoint's private TLS caches can reuse them -//! across clones via `Arc::ptr_eq`. -//! -//! All types implement [`Default`] so that endpoints can be constructed -//! without the caller having to hand-roll configuration values. - -use std::{sync::Arc, time::Duration}; - -use rustls::{ - client::{WebPkiServerVerifier, danger::ServerCertVerifier}, - server::{NoClientAuth, danger::ClientCertVerifier}, -}; - -use crate::dquic::{ - prelude::{ - AuthClient, ProductStreamsConcurrencyController, - handy::{ConsistentConcurrency, NoopLogger, client_parameters, server_parameters}, - }, - qbase::{ - param::{ClientParameters, ServerParameters}, - token::{TokenProvider, TokenSink, handy::NoopTokenRegistry}, - }, - qconnection::tls::AcceptAllClientAuther, - qevent::telemetry::QLog, -}; - -// --------------------------------------------------------------------------- -// Common -// --------------------------------------------------------------------------- - -/// Configuration values that apply to both client and server roles. -#[derive(Clone)] -pub struct CommonQuicConfig { - /// How long the connection should keep sending probe packets after going - /// idle. `Duration::ZERO` (the default) disables deferred idle timeouts. - pub defer_idle_timeout: Duration, - /// Factory producing per-connection streams concurrency controllers. - pub stream_strategy_factory: Arc, - /// QUIC-events logger (qlog). Defaults to a no-op logger. - pub qlogger: Arc, - /// Whether 0-RTT should be enabled if the crypto context permits it. - pub enable_0rtt: bool, - /// Enable SSL key logging via `SSLKEYLOGFILE` for debugging captures. - pub enable_sslkeylog: bool, -} - -impl Default for CommonQuicConfig { - fn default() -> Self { - Self { - defer_idle_timeout: Duration::ZERO, - stream_strategy_factory: Arc::new(ConsistentConcurrency::new), - qlogger: Arc::new(NoopLogger), - enable_0rtt: false, - enable_sslkeylog: false, - } - } -} - -impl std::fmt::Debug for CommonQuicConfig { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("CommonQuicConfig") - .field("defer_idle_timeout", &self.defer_idle_timeout) - .field("enable_0rtt", &self.enable_0rtt) - .field("enable_sslkeylog", &self.enable_sslkeylog) - .finish_non_exhaustive() - } -} - -// --------------------------------------------------------------------------- -// Client-only -// --------------------------------------------------------------------------- - -/// Strategy for verifying the server's TLS certificate. -/// -/// Kept as a small enum rather than a trait object so that -/// [`ClientOnlyConfig::verifier`] composes cheaply. The `WebPki` and `Custom` -/// variants wrap their verifier in an [`Arc`] for cheap cloning. -#[derive(Clone, Default)] -pub enum ServerCertVerifierChoice { - /// Accept any certificate. Intended for local testing only. - #[default] - Dangerous, - /// Verify against a compiled webpki verifier. - WebPki(Arc), - /// Delegate to a caller-supplied verifier. - Custom(Arc), -} - -impl std::fmt::Debug for ServerCertVerifierChoice { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Dangerous => f.write_str("Dangerous"), - Self::WebPki(_) => f.write_str("WebPki"), - Self::Custom(_) => f.write_str("Custom"), - } - } -} - -/// Client-only configuration values. -#[derive(Clone)] -pub struct ClientOnlyConfig { - /// Transport parameters advertised by the client. - pub parameters: ClientParameters, - /// ALPN protocol identifiers to offer. Empty means no ALPN. - pub alpns: Vec>, - /// Address validation token sink. - pub token_sink: Arc, - /// How the server's certificate should be verified. - pub verifier: ServerCertVerifierChoice, -} - -impl Default for ClientOnlyConfig { - fn default() -> Self { - Self { - parameters: client_parameters(), - alpns: Vec::new(), - token_sink: Arc::new(NoopTokenRegistry), - verifier: ServerCertVerifierChoice::default(), - } - } -} - -impl std::fmt::Debug for ClientOnlyConfig { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ClientOnlyConfig") - .field("alpns", &self.alpns.len()) - .field("verifier", &self.verifier) - .finish_non_exhaustive() - } -} - -// --------------------------------------------------------------------------- -// Server-only -// --------------------------------------------------------------------------- - -/// Server-only configuration values. -#[derive(Clone)] -pub struct ServerOnlyConfig { - /// Transport parameters advertised by the server. - pub parameters: ServerParameters, - /// ALPN protocol identifiers. Empty means no ALPN. - pub alpns: Vec>, - /// Address validation token provider. - pub token_provider: Arc, - /// Maximum number of pending inbound connections before packets start - /// being dropped at the network level. - pub backlog: usize, - /// Custom client authenticator; runs on top of rustls's certificate - /// verification. Defaults to [`AcceptAllClientAuther`]. - pub client_auther: Arc, - /// How rustls should verify client certificates. Defaults to - /// [`NoClientAuth`](rustls::server::NoClientAuth). - pub client_cert_verifier: Arc, - /// When enabled, failed connections are silently dropped instead of - /// answered with an error packet. - pub anti_port_scan: bool, -} - -impl Default for ServerOnlyConfig { - fn default() -> Self { - Self { - parameters: server_parameters(), - alpns: Vec::new(), - token_provider: Arc::new(NoopTokenRegistry), - backlog: 128, - client_auther: Arc::new(AcceptAllClientAuther), - client_cert_verifier: Arc::new(NoClientAuth), - anti_port_scan: false, - } - } -} - -impl std::fmt::Debug for ServerOnlyConfig { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ServerOnlyConfig") - .field("alpns", &self.alpns.len()) - .field("backlog", &self.backlog) - .field("anti_port_scan", &self.anti_port_scan) - .finish_non_exhaustive() - } -} - -// --------------------------------------------------------------------------- -// Composite (common + own) -// --------------------------------------------------------------------------- - -/// Client-side QUIC configuration = common + client-only. -#[derive(Debug, Clone, Default)] -pub struct ClientQuicConfig { - /// Values shared by both roles. - pub common: Arc, - /// Client-specific values. - pub own: Arc, -} - -/// Server-side QUIC configuration = common + server-only. -#[derive(Debug, Clone, Default)] -pub struct ServerQuicConfig { - /// Values shared by both roles. - pub common: Arc, - /// Server-specific values. - pub own: Arc, -} diff --git a/src/endpoint/h3.rs b/src/endpoint/h3.rs deleted file mode 100644 index 6a16194..0000000 --- a/src/endpoint/h3.rs +++ /dev/null @@ -1,98 +0,0 @@ -//! HTTP/3 endpoint built on top of [`QuicEndpoint`]. -//! -//! [`H3Endpoint`] extends [`QuicEndpoint`] with an HTTP/3 connection pool and -//! a user-configurable [`ConnectionBuilder`]. It transparently derefs to the -//! underlying [`QuicEndpoint`] so it inherits the [`Connect`](crate::quic::Connect) -//! and [`Listen`](crate::quic::Listen) trait impls. - -use std::{error::Error, ops::Deref, sync::Arc}; - -use super::quic::{AcceptError, QuicEndpoint}; -use crate::{ - connection::ConnectionBuilder, - dquic::prelude::Connection, - pool::Pool, - server::{Servers, UnresolvedRequest}, -}; - -/// HTTP/3 endpoint. -pub struct H3Endpoint { - /// Underlying QUIC endpoint. - pub quic: QuicEndpoint, - /// Connection pool shared across HTTP/3 requests. - pub pool: Pool, - /// Builder used to construct HTTP/3 connections on top of raw QUIC. - pub connection_builder: Arc>, -} - -impl H3Endpoint { - /// Construct a new HTTP/3 endpoint. - #[must_use] - pub fn new( - quic: QuicEndpoint, - pool: Pool, - connection_builder: Arc>, - ) -> Self { - Self { - quic, - pool, - connection_builder, - } - } -} - -impl Clone for H3Endpoint { - fn clone(&self) -> Self { - // Reset the pool on Clone: each endpoint owns its own connection - // pool so parallel clients do not stomp on one another. - Self { - quic: self.quic.clone(), - pool: Pool::empty(), - connection_builder: self.connection_builder.clone(), - } - } -} - -impl Deref for H3Endpoint { - type Target = QuicEndpoint; - - fn deref(&self) -> &Self::Target { - &self.quic - } -} - -impl H3Endpoint { - /// Serve HTTP/3 requests on this endpoint. - /// - /// `service` is invoked for every incoming request stream on every - /// accepted connection. It is the underlying primitive that higher-level - /// router types ([`Router`](crate::server::Router), - /// [`ServersRouter`](crate::server::ServersRouter), the tower/hyper - /// glue, …) are built on top of — any - /// `Service` will do. - /// - /// Drives accept + dispatch until the underlying [`QuicEndpoint`] shuts - /// down or encounters a fatal error, returning the accept error that - /// terminated the loop. - /// - /// Agents and the protocol registry are reachable through - /// [`UnresolvedRequest::connection`]; cloning the endpoint and calling - /// this method twice in parallel is **not** supported because the - /// underlying QUIC router only allows one connectionless dispatcher per - /// network. Call [`QuicEndpoint::shutdown`](crate::quic::Listen::shutdown) - /// before re-arming a new serve loop. - pub async fn serve(self, service: S) -> AcceptError - where - S: tower_service::Service + Clone + Send + Sync + 'static, - S::Future: Send, - S::Error: Into>, - { - let mut servers = Servers::from_quic_listener() - .listener(self.quic) - .service(service) - .pool(self.pool) - .builder(self.connection_builder) - .build(); - servers.run().await - } -} diff --git a/src/endpoint/hyper.rs b/src/endpoint/hyper.rs new file mode 100644 index 0000000..c88ea00 --- /dev/null +++ b/src/endpoint/hyper.rs @@ -0,0 +1,757 @@ +use std::{ + error::Error, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures::future::{self, BoxFuture}; +use http::Method; +use http_body::Body; +use http_body_util::combinators::UnsyncBoxBody; +use snafu::{ResultExt, Snafu}; +use tracing::Instrument; + +use crate::{ + dhttp::message::{ + MessageReader, MessageStreamError, + hyper::{ + SendMessageError, + upgrade::{RemainStream, TakeoverSlot}, + }, + }, + endpoint::UnresolvedRequest, + qpack::field::MalformedHeaderSection, +}; + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum HandleRequestError { + #[snafu(display("failed to handle message stream"))] + Stream { source: MessageStreamError }, + #[snafu(display("response pseudo-header section is malformed"))] + MalformedHeader { source: MalformedHeaderSection }, + #[snafu(display("service error"))] + Service { source: S }, + #[snafu(display("response body error"))] + Body { source: B }, +} + +impl From> + for HandleRequestError +{ + fn from(source: SendMessageError) -> Self { + match source { + SendMessageError::Stream { source } => HandleRequestError::Stream { source }, + SendMessageError::MalformedHeader { source } => { + HandleRequestError::MalformedHeader { source } + } + SendMessageError::Body { source } => HandleRequestError::Body { source }, + } + } +} + +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct TowerService(pub S); + +impl tower_service::Service for TowerService +where + S: tower_service::Service< + http::Request>, + Response = http::Response, + Error = ServiceE, + Future: Send, + > + Clone + + Send + + 'static, + ServiceE: Error + 'static, + RespBody: Body + Send, +{ + type Response = (); + type Error = HandleRequestError; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.0 + .poll_ready(cx) + .map_err(|source| HandleRequestError::Service { source }) + } + + fn call( + &mut self, + UnresolvedRequest { + stream_id, + read_stream, + write_stream: mut response_stream, + connection, + }: UnresolvedRequest, + ) -> Self::Future { + let span = tracing::info_span!( + "handle_request", + method = tracing::field::Empty, + uri = tracing::field::Empty + ); + + let mut service = self.0.clone(); + let future = async move { + future::poll_fn(|cx| service.poll_ready(cx)) + .await + .context(handle_request_error::ServiceSnafu)?; + + let mut request = read_stream + .into_hyper_request() + .await + .context(handle_request_error::StreamSnafu)? + .map(UnsyncBoxBody::new); + tracing::Span::current() + .record("method", request.method().as_str()) + .record("uri", request.uri().to_string()); + + tracing::trace!("converted request stream to hyper request, serving..."); + let is_connect = request.method() == Method::CONNECT; + let (remain_write_stream_tx, remain_write_stream) = RemainStream::pending(); + + request.extensions_mut().insert(stream_id); + request.extensions_mut().insert(connection); + if is_connect + && request + .extensions() + .get::>() + .is_some() + { + request + .extensions_mut() + .insert(TakeoverSlot::new(remain_write_stream.clone())); + } + + let response = service + .call(request) + .await + .context(handle_request_error::ServiceSnafu)?; + + response_stream.send_hyper_response(response).await?; + if is_connect { + response_stream + .flush() + .await + .context(handle_request_error::StreamSnafu)?; + _ = remain_write_stream_tx.send(response_stream); + } else { + response_stream + .close() + .await + .context(handle_request_error::StreamSnafu)?; + } + + Ok(()) + }; + Box::pin(future.instrument(span)) + } +} + +#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(transparent)] +pub struct HyperService(pub S); + +impl tower_service::Service for HyperService +where + S: hyper::service::Service< + http::Request>, + Response = http::Response, + Error = ServiceE, + Future: Send, + > + Clone + + Send + + 'static, + ServiceE: Error + 'static, + RespBody: Body + Send, +{ + type Response = (); + type Error = HandleRequestError; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call( + &mut self, + UnresolvedRequest { + stream_id, + read_stream, + write_stream: mut response_stream, + connection, + }: UnresolvedRequest, + ) -> Self::Future { + let span = tracing::info_span!( + "handle_request", + method = tracing::field::Empty, + uri = tracing::field::Empty + ); + + let service = self.0.clone(); + let future = async move { + let mut request = read_stream + .into_hyper_request() + .await + .context(handle_request_error::StreamSnafu)? + .map(UnsyncBoxBody::new); + tracing::Span::current() + .record("method", request.method().as_str()) + .record("uri", request.uri().to_string()); + + tracing::trace!("converted request stream to hyper request, serving..."); + let is_connect = request.method() == Method::CONNECT; + let (remain_write_stream_tx, remain_write_stream) = RemainStream::pending(); + + request.extensions_mut().insert(stream_id); + request.extensions_mut().insert(connection); + if is_connect + && request + .extensions() + .get::>() + .is_some() + { + request + .extensions_mut() + .insert(TakeoverSlot::new(remain_write_stream.clone())); + } + + let response = service + .call(request) + .await + .context(handle_request_error::ServiceSnafu)?; + + response_stream.send_hyper_response(response).await?; + if is_connect { + response_stream + .flush() + .await + .context(handle_request_error::StreamSnafu)?; + _ = remain_write_stream_tx.send(response_stream); + } else { + response_stream + .close() + .await + .context(handle_request_error::StreamSnafu)?; + } + + Ok(()) + }; + Box::pin(future.instrument(span)) + } +} + +#[cfg(test)] +mod tests { + use std::{ + pin::Pin, + sync::{ + Arc, Mutex, + atomic::{AtomicUsize, Ordering}, + }, + task::Poll, + }; + + use futures::{Sink, SinkExt, Stream, future::poll_fn}; + use http::StatusCode; + use http_body_util::{BodyExt, Full}; + use tower_service::Service as _; + + use super::*; + use crate::{ + codec::{SinkWriter, StreamReader}, + connection::{ConnectionState, StreamError, tests::MockConnection}, + dhttp::{ + message::{MessageWriter, guard}, + protocol::DHttpProtocol, + settings::Settings, + }, + protocol::Protocols, + qpack::{ + decoder::DecoderInstruction, + encoder::EncoderInstruction, + protocol::{QPackDecoder, QPackEncoder}, + }, + quic, + stream_id::StreamId, + varint::VarInt, + }; + + #[derive(Debug, snafu::Snafu)] + #[snafu(display("test service failed"))] + struct TestServiceError; + + #[derive(Debug, snafu::Snafu)] + #[snafu(display("body failed"))] + struct TestBodyError; + + #[derive(Debug, Default)] + struct ServiceState { + ready: AtomicUsize, + calls: AtomicUsize, + method: Mutex>, + uri: Mutex>, + stream_id_seen: Mutex>, + connection_seen: AtomicUsize, + read_takeover_seen: AtomicUsize, + write_takeover_seen: AtomicUsize, + body: Mutex>, + } + + impl ServiceState { + fn record_request( + &self, + request: &http::Request>, + ) { + *self + .method + .lock() + .expect("method mutex should not be poisoned") = Some(request.method().clone()); + *self.uri.lock().expect("uri mutex should not be poisoned") = + Some(request.uri().clone()); + *self + .stream_id_seen + .lock() + .expect("stream id mutex should not be poisoned") = + request.extensions().get::().copied(); + if request + .extensions() + .get::>>() + .is_some() + { + self.connection_seen.fetch_add(1, Ordering::Relaxed); + } + if request + .extensions() + .get::>() + .is_some() + { + self.read_takeover_seen.fetch_add(1, Ordering::Relaxed); + } + if request + .extensions() + .get::>() + .is_some() + { + self.write_takeover_seen.fetch_add(1, Ordering::Relaxed); + } + } + + fn method(&self) -> Option { + self.method + .lock() + .expect("method mutex should not be poisoned") + .clone() + } + + fn uri(&self) -> Option { + self.uri + .lock() + .expect("uri mutex should not be poisoned") + .clone() + } + + fn stream_id(&self) -> Option { + *self + .stream_id_seen + .lock() + .expect("stream id mutex should not be poisoned") + } + + fn body(&self) -> Option { + self.body + .lock() + .expect("body mutex should not be poisoned") + .clone() + } + } + + #[derive(Clone, Debug)] + struct TestTowerService { + state: Arc, + } + + impl tower_service::Service>> + for TestTowerService + { + type Response = http::Response>; + type Error = TestServiceError; + type Future = BoxFuture<'static, Result>; + + fn poll_ready( + &mut self, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.state.ready.fetch_add(1, Ordering::Relaxed); + Poll::Ready(Ok(())) + } + + fn call( + &mut self, + request: http::Request>, + ) -> Self::Future { + self.state.calls.fetch_add(1, Ordering::Relaxed); + self.state.record_request(&request); + let state = self.state.clone(); + Box::pin(async move { + let body = request + .into_body() + .collect() + .await + .expect("request body should be readable") + .to_bytes(); + *state + .body + .lock() + .expect("body mutex should not be poisoned") = Some(body.clone()); + let mut response_body = b"tower:".to_vec(); + response_body.extend_from_slice(&body); + Ok(http::Response::builder() + .status(StatusCode::CREATED) + .header("x-service", "tower") + .body(Full::new(Bytes::from(response_body))) + .expect("response should be valid")) + }) + } + } + + #[derive(Clone, Debug)] + struct TestHyperService { + state: Arc, + } + + impl hyper::service::Service>> + for TestHyperService + { + type Response = http::Response>; + type Error = TestServiceError; + type Future = BoxFuture<'static, Result>; + + fn call( + &self, + request: http::Request>, + ) -> Self::Future { + self.state.calls.fetch_add(1, Ordering::Relaxed); + self.state.record_request(&request); + let state = self.state.clone(); + Box::pin(async move { + let body = request + .into_body() + .collect() + .await + .expect("request body should be readable") + .to_bytes(); + *state + .body + .lock() + .expect("body mutex should not be poisoned") = Some(body.clone()); + let mut response_body = b"hyper:".to_vec(); + response_body.extend_from_slice(&body); + Ok(http::Response::builder() + .status(StatusCode::ACCEPTED) + .header("x-service", "hyper") + .body(Full::new(Bytes::from(response_body))) + .expect("response should be valid")) + }) + } + } + + fn qpack_decoder_sink() -> Pin + Send>> { + Box::pin(futures::sink::drain::().sink_map_err(|never| match never {})) + } + + fn qpack_decoder_stream() + -> Pin> + Send>> { + Box::pin(futures::stream::empty::< + Result, + >()) + } + + fn qpack_encoder_sink() -> Pin + Send>> { + Box::pin(futures::sink::drain::().sink_map_err(|never| match never {})) + } + + fn qpack_encoder_stream() + -> Pin> + Send>> { + Box::pin(futures::stream::empty::< + Result, + >()) + } + + fn connection_state() -> ConnectionState { + let erased: Arc = Arc::new(MockConnection::new()); + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased.clone())); + ConnectionState::new_for_test(erased, Arc::new(protocols)) + } + + fn stream_pair(stream_id: VarInt) -> (MessageReader, MessageWriter) { + let state = connection_state(); + let (reader, writer) = quic::test::mock_stream_pair_with_capacity(stream_id, 64); + let reader = StreamReader::new(guard::GuardQuicReader::new( + Box::pin(reader) as crate::quic::BoxQuicStreamReader + )); + let writer = SinkWriter::new(guard::GuardQuicWriter::new( + Box::pin(writer) as crate::quic::BoxQuicStreamWriter + )); + + ( + MessageReader::new( + stream_id, + reader, + Arc::new(QPackDecoder::new( + Arc::new(Settings::default()), + qpack_decoder_sink(), + qpack_decoder_stream(), + )), + state.clone(), + ), + MessageWriter::new( + writer, + Arc::new(QPackEncoder::new( + Arc::new(Settings::default()), + qpack_encoder_sink(), + qpack_encoder_stream(), + )), + state, + ), + ) + } + + async fn request_pair( + request: http::Request>, + ) -> (UnresolvedRequest, MessageReader) { + let stream_id = VarInt::from_u32(0); + let (server_read, mut request_writer) = stream_pair(stream_id); + request_writer + .send_hyper_request(request) + .await + .expect("request should be written"); + request_writer + .close() + .await + .expect("request stream should close"); + + let (response_reader, server_write) = stream_pair(stream_id); + ( + UnresolvedRequest { + stream_id: StreamId(stream_id), + read_stream: server_read, + write_stream: server_write, + connection: Arc::new(connection_state()), + }, + response_reader, + ) + } + + async fn read_response(response_reader: MessageReader) -> (http::response::Parts, Bytes) { + let response = response_reader + .into_hyper_response() + .await + .expect("response should be readable"); + let (parts, body) = response.into_parts(); + let body = body + .collect() + .await + .expect("response body should be readable") + .to_bytes(); + (parts, body) + } + + #[tokio::test] + async fn tower_service_converts_request_and_writes_response() { + let state = Arc::new(ServiceState::default()); + let mut service = TowerService(TestTowerService { + state: state.clone(), + }); + poll_fn(|cx| service.poll_ready(cx)) + .await + .expect("tower service should be ready"); + let request = http::Request::builder() + .method(Method::POST) + .uri("https://example.test/upload") + .body(Full::new(Bytes::from_static(b"payload"))) + .expect("request should be valid"); + let (unresolved, response_reader) = request_pair(request).await; + + service + .call(unresolved) + .await + .expect("request should be handled"); + + let (parts, body) = read_response(response_reader).await; + assert_eq!(parts.status, StatusCode::CREATED); + assert_eq!(parts.headers.get("x-service").unwrap(), "tower"); + assert_eq!(body, Bytes::from_static(b"tower:payload")); + assert_eq!(state.ready.load(Ordering::Relaxed), 2); + assert_eq!(state.calls.load(Ordering::Relaxed), 1); + assert_eq!(state.method(), Some(Method::POST)); + assert_eq!( + state.uri(), + Some("https://example.test/upload".parse().unwrap()) + ); + assert_eq!(state.stream_id(), Some(StreamId(VarInt::from_u32(0)))); + assert_eq!(state.connection_seen.load(Ordering::Relaxed), 1); + assert_eq!(state.read_takeover_seen.load(Ordering::Relaxed), 0); + assert_eq!(state.write_takeover_seen.load(Ordering::Relaxed), 0); + assert_eq!(state.body(), Some(Bytes::from_static(b"payload"))); + } + + #[tokio::test] + async fn tower_service_connect_request_exposes_write_takeover_and_flushes_response() { + let state = Arc::new(ServiceState::default()); + let mut service = TowerService(TestTowerService { + state: state.clone(), + }); + let request = http::Request::builder() + .method(Method::CONNECT) + .uri("example.test:443") + .body(Full::new(Bytes::new())) + .expect("connect request should be valid"); + let (unresolved, mut response_reader) = request_pair(request).await; + + service + .call(unresolved) + .await + .expect("connect request should be handled"); + + let parts = response_reader + .read_hyper_response_parts() + .await + .expect("flushed connect response headers should be readable"); + assert_eq!(parts.status, StatusCode::CREATED); + assert_eq!(parts.headers.get("x-service").unwrap(), "tower"); + assert_eq!(state.calls.load(Ordering::Relaxed), 1); + assert_eq!(state.method(), Some(Method::CONNECT)); + assert_eq!(state.uri(), Some("example.test:443".parse().unwrap())); + assert_eq!(state.stream_id(), Some(StreamId(VarInt::from_u32(0)))); + assert_eq!(state.connection_seen.load(Ordering::Relaxed), 1); + assert_eq!(state.read_takeover_seen.load(Ordering::Relaxed), 1); + assert_eq!(state.write_takeover_seen.load(Ordering::Relaxed), 1); + assert_eq!(state.body(), Some(Bytes::new())); + } + + #[tokio::test] + async fn hyper_service_converts_request_and_writes_response() { + let state = Arc::new(ServiceState::default()); + let mut service = HyperService(TestHyperService { + state: state.clone(), + }); + poll_fn(|cx| service.poll_ready(cx)) + .await + .expect("hyper wrapper should be ready"); + let request = http::Request::builder() + .method(Method::PUT) + .uri("https://example.test/resource") + .body(Full::new(Bytes::from_static(b"body"))) + .expect("request should be valid"); + let (unresolved, response_reader) = request_pair(request).await; + + service + .call(unresolved) + .await + .expect("request should be handled"); + + let (parts, body) = read_response(response_reader).await; + assert_eq!(parts.status, StatusCode::ACCEPTED); + assert_eq!(parts.headers.get("x-service").unwrap(), "hyper"); + assert_eq!(body, Bytes::from_static(b"hyper:body")); + assert_eq!(state.calls.load(Ordering::Relaxed), 1); + assert_eq!(state.method(), Some(Method::PUT)); + assert_eq!( + state.uri(), + Some("https://example.test/resource".parse().unwrap()) + ); + assert_eq!(state.stream_id(), Some(StreamId(VarInt::from_u32(0)))); + assert_eq!(state.connection_seen.load(Ordering::Relaxed), 1); + assert_eq!(state.read_takeover_seen.load(Ordering::Relaxed), 0); + assert_eq!(state.write_takeover_seen.load(Ordering::Relaxed), 0); + assert_eq!(state.body(), Some(Bytes::from_static(b"body"))); + } + + #[tokio::test] + async fn hyper_service_connect_request_exposes_write_takeover_and_flushes_response() { + let state = Arc::new(ServiceState::default()); + let mut service = HyperService(TestHyperService { + state: state.clone(), + }); + let request = http::Request::builder() + .method(Method::CONNECT) + .uri("example.test:443") + .body(Full::new(Bytes::new())) + .expect("connect request should be valid"); + let (unresolved, mut response_reader) = request_pair(request).await; + + service + .call(unresolved) + .await + .expect("connect request should be handled"); + + let parts = response_reader + .read_hyper_response_parts() + .await + .expect("flushed connect response headers should be readable"); + assert_eq!(parts.status, StatusCode::ACCEPTED); + assert_eq!(parts.headers.get("x-service").unwrap(), "hyper"); + assert_eq!(state.calls.load(Ordering::Relaxed), 1); + assert_eq!(state.method(), Some(Method::CONNECT)); + assert_eq!(state.uri(), Some("example.test:443".parse().unwrap())); + assert_eq!(state.stream_id(), Some(StreamId(VarInt::from_u32(0)))); + assert_eq!(state.connection_seen.load(Ordering::Relaxed), 1); + assert_eq!(state.read_takeover_seen.load(Ordering::Relaxed), 1); + assert_eq!(state.write_takeover_seen.load(Ordering::Relaxed), 1); + assert_eq!(state.body(), Some(Bytes::new())); + } + + #[test] + fn handle_request_error_from_body_error_preserves_variant() { + let error: HandleRequestError = SendMessageError::Body { + source: TestBodyError, + } + .into(); + + assert!(matches!(error, HandleRequestError::Body { .. })); + } + + #[test] + fn handle_request_error_from_stream_error_preserves_variant() { + let error: HandleRequestError = SendMessageError::Stream { + source: MessageStreamError::Quic { + source: quic::StreamError::Reset { + code: VarInt::from_u32(7), + }, + }, + } + .into(); + + assert!(matches!(error, HandleRequestError::Stream { .. })); + } + + #[test] + fn handle_request_error_from_malformed_header_preserves_variant() { + let error: HandleRequestError = + SendMessageError::MalformedHeader { + source: MalformedHeaderSection::ProtocolInNonConnectRequest, + } + .into(); + + assert!(matches!( + error, + HandleRequestError::MalformedHeader { + source: MalformedHeaderSection::ProtocolInNonConnectRequest + } + )); + } + + #[test] + fn service_wrappers_are_transparent_newtypes() { + let tower = TowerService(7_u8); + let hyper = HyperService(9_u8); + + assert_eq!(tower, TowerService(7)); + assert_eq!(hyper, HyperService(9)); + assert_eq!(format!("{tower:?}"), "TowerService(7)"); + assert_eq!(format!("{hyper:?}"), "HyperService(9)"); + } +} diff --git a/src/endpoint/identity.rs b/src/endpoint/identity.rs deleted file mode 100644 index d4dc5bc..0000000 --- a/src/endpoint/identity.rs +++ /dev/null @@ -1,71 +0,0 @@ -//! Identity used by a [`QuicEndpoint`](super::QuicEndpoint) when performing -//! TLS handshakes. -//! -//! A [`NamedIdentity`] bundles the SNI (server name) with the certificate -//! chain and private key. When stored in an endpoint, cloning is cheap — the -//! identity is shared through an `Arc`. -//! -//! The endpoint's identity selects between client-auth / server-auth paths -//! and keys the SNI registry for inbound connection multiplexing. - -use std::sync::Arc; - -use rustls::pki_types::{CertificateDer, PrivateKeyDer}; - -/// Name used to advertise a server in TLS SNI. -pub type ServerName = Arc; - -/// A named identity backed by a TLS certificate chain and its matching private key. -#[derive(Debug, Clone)] -pub struct NamedIdentity { - /// Server name advertised in TLS SNI (also used by h3x as the SNI registry key). - pub name: ServerName, - /// End-entity certificate followed by any intermediates. - pub certs: Vec>, - /// Private key matching the end-entity certificate. - pub key: Arc>, -} - -/// The identity that an endpoint presents. -/// -/// Cloning is cheap: `Named` variants share the underlying [`NamedIdentity`] -/// through an `Arc`, and `Anonymous` is trivial. -#[derive(Debug, Clone, Default)] -pub enum Identity { - /// No TLS identity is presented. Client-only endpoints may use this when - /// the peer does not require client authentication; server endpoints with - /// an `Anonymous` identity will refuse `accept()` operations. - #[default] - Anonymous, - /// A named identity carrying a certificate chain and private key. - Named(Arc), -} - -impl Identity { - /// Returns the server name for this identity, if any. - #[must_use] - pub fn name(&self) -> Option<&ServerName> { - match self { - Self::Anonymous => None, - Self::Named(id) => Some(&id.name), - } - } - - /// Returns `true` if this identity has a server name attached. - #[must_use] - pub fn is_named(&self) -> bool { - matches!(self, Self::Named(_)) - } -} - -impl From for Identity { - fn from(id: NamedIdentity) -> Self { - Self::Named(Arc::new(id)) - } -} - -impl From> for Identity { - fn from(id: Arc) -> Self { - Self::Named(id) - } -} diff --git a/src/endpoint/mod.rs b/src/endpoint/mod.rs deleted file mode 100644 index a10cdbb..0000000 --- a/src/endpoint/mod.rs +++ /dev/null @@ -1,33 +0,0 @@ -//! Shared QUIC / DHTTP/3 endpoint infrastructure. -//! -//! This module groups three layers that collaborate to allow many logical -//! clients and servers to share one set of physical network resources: -//! -//! - [`Network`] — process-shared infrastructure: interface manager, QUIC -//! router, STUN agent, SNI registry, and a reconcile task that keeps the -//! interface set in sync with user-supplied [`Binds`] patterns. -//! - [`QuicEndpoint`] — a QUIC-only endpoint with fully public, per-field -//! cloneable configuration (identity, resolver, per-role configs). Holds -//! private caches that are invalidated lazily via `Arc::ptr_eq`. -//! - [`H3Endpoint`] — [`QuicEndpoint`] plus an HTTP/3 connection pool and a -//! user-configurable [`ConnectionBuilder`]. -//! -//! [`ConnectionBuilder`]: crate::connection::ConnectionBuilder - -pub mod binds; -pub mod config; -pub mod h3; -pub mod identity; -pub mod network; -pub mod quic; -mod sni; - -pub use binds::{Bind, BindConflictError, BindHost, Binds}; -pub use config::{ - ClientOnlyConfig, ClientQuicConfig, CommonQuicConfig, ServerCertVerifierChoice, - ServerOnlyConfig, ServerQuicConfig, -}; -pub use h3::H3Endpoint; -pub use identity::{Identity, NamedIdentity, ServerName}; -pub use network::{BindServerError, BindsGuard, Network, NetworkBuilder, ServerBinding}; -pub use quic::{AcceptError, ConnectError, EndpointError, QuicEndpoint}; diff --git a/src/endpoint/network.rs b/src/endpoint/network.rs deleted file mode 100644 index 469e050..0000000 --- a/src/endpoint/network.rs +++ /dev/null @@ -1,763 +0,0 @@ -//! Process-shared QUIC network infrastructure. -//! -//! [`Network`] owns the long-lived components needed to send and receive QUIC -//! packets on a set of network interfaces, **and** the SNI fan-out registry -//! used to route newly-accepted connections to per-identity -//! [`QuicEndpoint`](super::QuicEndpoint) listeners. -//! -//! A [`Network`] is always used via [`Arc`]: clone the [`Arc`] to -//! share the infrastructure between many endpoints. The builder returns an -//! [`Arc`] directly via [`NetworkBuilder::build`]. -//! -//! ## SNI dispatch -//! -//! The builder installs a *connectionless packet dispatcher* on the network's -//! [`QuicRouter`]. When an Initial / 0-RTT packet arrives without matching -//! an existing connection, the dispatcher constructs a fresh server -//! [`Connection`](crate::dquic::prelude::Connection) using the shared -//! [`ServerQuicConfig`](super::ServerQuicConfig) stored in the network's -//! `server_slot`, waits for the handshake to reveal the ClientHello SNI, and -//! fans the connection into the matching [`ServerBinding`]'s mpmc queue. -//! -//! Because the underlying rustls `ServerConfig` must be chosen *before* SNI -//! is known, the network stores **one** `ServerQuicConfig` at a time. The -//! first call to [`Network::bind_server`] initialises the slot; subsequent -//! calls with an identical (or `Arc::ptr_eq`) configuration succeed, while -//! a conflicting configuration is rejected with -//! [`BindServerError::ServerConfigConflict`]. When the last [`ServerBinding`] -//! referring to the slot drops, the slot clears and a different configuration -//! may be used on the next bind. -//! -//! ## Binds registry -//! -//! [`Network::add_binds`] expands a [`Binds`] pattern against the current -//! device set, binds each resulting URI through the shared -//! [`InterfaceManager`], and installs a background reconcile task that -//! re-evaluates the pattern on device changes. The call returns a -//! [`BindsGuard`] whose [`Drop`] impl cancels the reconcile task and unbinds -//! any URIs that were opened by the entry. - -use std::{ - collections::HashMap, - net::SocketAddr, - sync::{ - Arc, Mutex, OnceLock, Weak, - atomic::{AtomicU64, Ordering}, - }, -}; - -use bon::Builder; -use dashmap::DashMap; -use futures::Stream; -use rustls::{ServerConfig as RustlsServerConfig, sign::CertifiedKey}; -use snafu::{ResultExt, Snafu}; -use tokio::sync::RwLock; -use tokio_util::task::AbortOnDropHandle; - -pub use super::sni::ServerBinding; -use super::{ - binds::{BindConflictError, Binds, watch_bind_interfaces}, - config::ServerQuicConfig, - identity::{NamedIdentity, ServerName}, - sni::{self, SniCertResolver, SniEntry, SniGuard}, -}; -use crate::dquic::{ - prelude::{ - Connection, Resolve, - handy::{DEFAULT_IO_FACTORY, SystemResolver}, - }, - qbase::packet::{DataHeader, GetDcid, Packet, long::DataHeader as LongHeader}, - qinterface::{ - BindInterface, Interface, - bind_uri::BindUri, - component::{ - Components, - alive::RebindOnNetworkChangedComponent, - location::{Locations, LocationsComponent}, - route::{QuicRouter, QuicRouterComponent, Way}, - }, - device::Devices, - io::ProductIO, - manager::InterfaceManager, - }, - qtraversal::{ - nat::{client::StunClientsComponent, router::StunRouterComponent}, - route::{ForwardersComponent, ReceiveAndDeliverPacketComponent}, - }, -}; - -type BindRegistry = Arc>>; -pub(crate) type SniRegistry = Arc>>; - -/// Error returned by [`Network::bind_server`]. -#[derive(Debug, Snafu)] -#[snafu(module, visibility(pub))] -pub enum BindServerError { - /// Another [`NamedIdentity`] is already registered for the same SNI. - #[snafu(display("sni {name} is already bound to a different identity"))] - SniInUse { - /// SNI that is already registered. - name: ServerName, - }, - /// The network already has a server configuration that is incompatible - /// with the one provided. Drop every existing [`ServerBinding`] before - /// binding with a different configuration. - #[snafu(display("network already holds an incompatible server configuration"))] - ServerConfigConflict, - /// Loading the identity's private key into rustls failed. - #[snafu(display("failed to load server private key"))] - LoadKey { - /// Underlying rustls error. - source: rustls::Error, - }, - /// rustls rejected the selected protocol version. - #[snafu(display("failed to select TLS protocol version"))] - Version { - /// Underlying rustls error. - source: rustls::Error, - }, -} - -/// Shared QUIC network infrastructure. -/// -/// Used exclusively via [`Arc`]. [`NetworkBuilder::build`] returns -/// an [`Arc`] directly after installing the SNI dispatcher on the router. -#[derive(Builder)] -#[builder(finish_fn = build_raw)] -pub struct Network { - #[builder(default = Arc::new(SystemResolver))] - pub(crate) stun_resolver: Arc, - pub(crate) stun_server: Option>, - #[builder(default = Devices::global())] - pub(crate) devices: &'static Devices, - #[builder(default = Arc::new(DEFAULT_IO_FACTORY))] - pub(crate) io_factory: Arc, - #[builder(default = InterfaceManager::global().clone())] - pub(crate) iface_manager: Arc, - #[builder(default = QuicRouter::global().clone())] - pub(crate) quic_router: Arc, - #[builder(default = Arc::new(Locations::new()))] - pub(crate) locations: Arc, - #[builder(skip = Arc::new(Mutex::new(HashMap::new())))] - bind_registry: BindRegistry, - #[builder(skip = Arc::new(AtomicU64::new(0)))] - next_bind_id: Arc, - #[builder(skip = Arc::new(DashMap::new()))] - sni_registry: SniRegistry, - #[builder(skip = RwLock::new(Weak::new()))] - server_slot: RwLock>, - #[builder(skip)] - dispatcher_installed: OnceLock<()>, -} - -impl NetworkBuilder { - /// Finalize the builder, wrap in [`Arc`], and install the SNI dispatcher. - pub fn build(self) -> Arc { - let network = Arc::new(self.build_raw()); - network.install_dispatcher(); - network - } -} - -struct BindsEntry { - _reconcile: AbortOnDropHandle<()>, - /// Live bindings keyed by [`BindUri::identity_key`]. Holds strong - /// [`BindInterface`] references so that the interfaces (and their - /// installed components) outlive the `add_binds` call — otherwise - /// each bound interface is immediately dropped and removed from the - /// [`InterfaceManager`]. - bound: Arc>>, -} - -impl Network { - /// Accessor for the interface manager. - #[must_use] - pub fn iface_manager(&self) -> &Arc { - &self.iface_manager - } - - /// Accessor for the QUIC router. - #[must_use] - pub fn quic_router(&self) -> &Arc { - &self.quic_router - } - - /// Accessor for the shared [`Locations`] table. - #[must_use] - pub fn locations(&self) -> &Arc { - &self.locations - } - - /// Accessor for the device tracker. - #[must_use] - pub fn devices(&self) -> &'static Devices { - self.devices - } - - /// Accessor for the I/O factory. - #[must_use] - pub fn io_factory(&self) -> &Arc { - &self.io_factory - } - - /// Accessor for the STUN server, if configured. - #[must_use] - pub fn stun_server(&self) -> Option<&Arc> { - self.stun_server.as_ref() - } - - /// Accessor for the resolver used when looking up STUN server addresses. - #[must_use] - pub fn stun_resolver(&self) -> &Arc { - &self.stun_resolver - } - - /// Bind the given URI on this network. - /// - /// In addition to acquiring the underlying [`BindInterface`] from the - /// [`InterfaceManager`], this also installs the QUIC packet routing, - /// location tracking, STUN client and packet receive/deliver components - /// on the interface so that QUIC packets actually flow. - pub async fn bind(&self, bind_uri: BindUri) -> BindInterface { - let bind_iface = self - .iface_manager - .bind(bind_uri, self.io_factory.clone()) - .await; - self.init_iface_components(&bind_iface); - bind_iface - } - - fn init_iface_components(&self, bind_iface: &BindInterface) { - let stun_agent = if let Some(server) = bind_iface.bind_uri().stun_server() { - Some(Arc::from(server)) - } else if let Some("false") = bind_iface.bind_uri().prop(BindUri::STUN_PROP).as_deref() { - None - } else { - self.stun_server.clone() - }; - let resolver = self.stun_resolver.clone(); - let devices = self.devices; - let quic_router = self.quic_router.clone(); - let locations = self.locations.clone(); - bind_iface.with_components_mut(move |components: &mut Components, iface: &Interface| { - components.init_with(|| RebindOnNetworkChangedComponent::new(iface, devices)); - let quic_router = components - .init_with(|| QuicRouterComponent::new(quic_router)) - .router(); - let locations = components - .init_with(|| LocationsComponent::new(iface.downgrade(), locations)) - .clone(); - - match stun_agent { - Some(stun_server) => { - let stun_router = components - .init_with(|| StunRouterComponent::new(iface.downgrade())) - .router(); - let clients = components - .init_with(|| { - StunClientsComponent::new( - iface.downgrade(), - stun_router.clone(), - resolver, - stun_server, - [], - Some(locations.clone()), - ) - }) - .clone(); - let relay = bind_iface - .bind_uri() - .relay() - .and_then(|r| r.parse::().ok()); - let forwarder = if let Some(relay) = relay { - components - .init_with(|| ForwardersComponent::new_server(relay)) - .forwarder() - } else { - components - .init_with(|| ForwardersComponent::new_client(clients)) - .forwarder() - }; - components.init_with(|| { - ReceiveAndDeliverPacketComponent::builder(iface.downgrade()) - .quic_router(quic_router) - .stun_router(stun_router) - .forwarder(forwarder) - .init() - }); - } - None => { - components.init_with(|| { - ReceiveAndDeliverPacketComponent::builder(iface.downgrade()) - .quic_router(quic_router) - .init() - }); - } - }; - }); - } - - /// Bind many URIs in parallel. - pub async fn bind_many( - self: &Arc, - bind_uris: impl IntoIterator>, - ) -> impl Stream { - use futures::stream::FuturesUnordered; - - bind_uris - .into_iter() - .map(|bind_uri| { - let network = self.clone(); - async move { network.bind(bind_uri.into()).await } - }) - .collect::>() - } - - /// Register a [`Binds`] pattern with this network. - pub async fn add_binds( - self: &Arc, - binds: &Binds, - ) -> Result> { - let monitor = self.devices.monitor(); - let initial_uris = binds.to_bind_uris(monitor.interfaces().keys().map(String::as_str))?; - - // Bind every initial URI up-front and hold strong - // [`BindInterface`] references so the interfaces stay alive for - // the lifetime of the returned [`BindsGuard`]. - let mut initial_bound: HashMap = - HashMap::with_capacity(initial_uris.len()); - for uri in &initial_uris { - let iface = self.bind(uri.clone()).await; - initial_bound.insert(uri.identity_key(), (uri.clone(), iface)); - } - let bound = Arc::new(Mutex::new(initial_bound)); - - let reconcile = { - let network = self.clone(); - let bound_bind = bound.clone(); - let bind_fn = move |uri: BindUri| { - let network = network.clone(); - let bound_bind = bound_bind.clone(); - Box::pin(async move { - let iface = network.bind(uri.clone()).await; - bound_bind - .lock() - .unwrap() - .insert(uri.identity_key(), (uri, iface)); - }) - as std::pin::Pin + Send>> - }; - - let iface_manager_unbind = self.iface_manager.clone(); - let bound_unbind = bound.clone(); - let unbind_fn = move |uri: BindUri| { - // Drop the strong `BindInterface` held here before asking - // the manager to unbind; otherwise the interface would - // still be referenced by `bound` and `unbind` could not - // fully tear it down. - bound_unbind.lock().unwrap().remove(&uri.identity_key()); - let iface_manager_unbind = iface_manager_unbind.clone(); - tokio::spawn(async move { - iface_manager_unbind.unbind(uri).await; - }); - }; - - watch_bind_interfaces(binds, monitor, initial_uris, bind_fn, unbind_fn) - }; - - let id = self.next_bind_id.fetch_add(1, Ordering::Relaxed); - self.bind_registry.lock().unwrap().insert( - id, - BindsEntry { - _reconcile: reconcile, - bound: bound.clone(), - }, - ); - - Ok(BindsGuard { - id, - registry: self.bind_registry.clone(), - iface_manager: self.iface_manager.clone(), - bound, - }) - } - - /// Snapshot of the URIs currently bound via [`Network::add_binds`]. - #[must_use] - pub fn current_bind_uris(&self) -> Vec { - let registry = self.bind_registry.lock().unwrap(); - registry - .values() - .flat_map(|entry| { - entry - .bound - .lock() - .unwrap() - .values() - .map(|(uri, _)| uri.clone()) - .collect::>() - }) - .collect() - } - - /// Snapshot of the [`BindInterface`]s currently bound via - /// [`Network::add_binds`]. Prefer this over - /// [`current_bind_uris`](Self::current_bind_uris) + - /// [`get_iface`](Self::get_iface) when you need live interface - /// references: the latter pair races against interface drops and may - /// return fewer interfaces than were just bound. - #[must_use] - pub fn current_bind_interfaces(&self) -> Vec { - let registry = self.bind_registry.lock().unwrap(); - registry - .values() - .flat_map(|entry| { - entry - .bound - .lock() - .unwrap() - .values() - .map(|(_, iface)| iface.clone()) - .collect::>() - }) - .collect() - } - - /// Look up the [`BindInterface`] currently registered for `bind_uri`. - /// - /// Returns `None` if the URI was never bound (directly or through - /// [`Network::add_binds`]) or if the interface was already released. - /// This is a thin wrapper around - /// [`InterfaceManager::get`](crate::dquic::qinterface::manager::InterfaceManager::get) - /// that keeps callers from having to reach into the manager directly. - #[must_use] - pub fn get_iface(&self, bind_uri: &BindUri) -> Option { - self.iface_manager.get(bind_uri) - } - - /// Snapshot of SNI names currently registered on this network. - /// - /// A name appears in the result iff at least one [`ServerBinding`] - /// (or a cache entry of a [`QuicEndpoint`](super::QuicEndpoint)) for - /// that name is still live. Names whose `Weak` cannot be - /// upgraded — i.e. whose last binding was dropped but whose registry - /// slot has not yet been cleared by - /// [`SniGuard`](super::sni::SniGuard)'s `Drop` — are filtered out. - #[must_use] - pub fn registered_sni_names(&self) -> Vec { - self.sni_registry - .iter() - .filter_map(|kv| kv.value().upgrade().map(|_| kv.key().clone())) - .collect() - } -} - -// --------------------------------------------------------------------------- -// SNI dispatcher + bind_server -// --------------------------------------------------------------------------- - -impl Network { - /// Install the connectionless-packet dispatcher on the network's QUIC - /// router. Idempotent — subsequent calls are no-ops. - fn install_dispatcher(self: &Arc) { - if self.dispatcher_installed.set(()).is_err() { - return; - } - let weak = Arc::downgrade(self); - let installed = self - .quic_router - .on_connectless_packets(move |packet: Packet, way: Way| { - let Some(network) = weak.upgrade() else { - return; - }; - network.dispatch_initial_packet(packet, way); - }); - if !installed { - tracing::warn!( - target: "h3x::endpoint", - "quic router already has a connectionless dispatcher installed; \ - this Network's server endpoints will be unreachable" - ); - } - } - - /// Called from the router's connectionless-packet hook. - fn dispatch_initial_packet(self: &Arc, packet: Packet, way: Way) { - // Fast path: if no SNI is registered, drop before touching the slot. - if self.sni_registry.is_empty() { - return; - } - let origin_dcid = match &packet { - Packet::Data(data_packet) => match &data_packet.header { - DataHeader::Long(LongHeader::Initial(hdr)) => *hdr.dcid(), - DataHeader::Long(LongHeader::ZeroRtt(hdr)) => *hdr.dcid(), - _ => return, - }, - _ => return, - }; - if origin_dcid.is_empty() { - return; - } - - // Upgrade the slot synchronously. If the Weak is dead, no server is - // registered right now — drop. - let slot_opt = self.server_slot.try_read().ok().and_then(|g| g.upgrade()); - let Some(slot) = slot_opt else { - return; - }; - - // IMPORTANT: construct the server `Connection` synchronously so that - // the CID route on the shared `QuicRouter` is installed before this - // function returns. Otherwise a second Initial packet with the same - // ODCID arriving during the async gap would still be dispatched as - // "connectionless", spawning another server `Connection` for the - // same ODCID — both would attempt to respond with Handshake packets - // derived from independent keying state, and the peer would decrypt - // later packets with the wrong keys (observed as - // "Invalid reserved bits" transport errors). - // - // Only the genuinely async work (packet delivery + waiting for the - // ClientHello / SNI resolution) is spawned below. - let cfg = &slot.config; - let connection = Connection::new_server(cfg.own.token_provider.clone()) - .with_parameters(cfg.own.parameters.clone()) - .with_client_auther(Box::new(cfg.own.client_auther.clone())) - .with_tls_config((*slot.rustls_config).clone()) - .with_streams_concurrency_strategy(cfg.common.stream_strategy_factory.as_ref()) - .with_zero_rtt(cfg.common.enable_0rtt) - .with_iface_factory(self.io_factory.clone()) - .with_iface_manager(self.iface_manager.clone()) - .with_quic_router(self.quic_router.clone()) - .with_locations(self.locations.clone()) - .with_defer_idle_timeout(cfg.common.defer_idle_timeout) - .with_cids(origin_dcid) - .with_qlog(cfg.common.qlogger.clone()) - .run(); - - let network = self.clone(); - let sni_registry = self.sni_registry.clone(); - let task = async move { - network.quic_router.deliver(packet, way).await; - match connection.server_name().await { - Ok(name) => { - let _ = connection.subscribe_local_address(); - let entry = sni_registry.iter().find_map(|kv| { - if kv.key().eq_ignore_ascii_case(&name) { - kv.value().upgrade() - } else { - None - } - }); - match entry { - Some(entry) => { - if entry.incomings_tx.try_send(connection).is_err() { - tracing::debug!( - target: "h3x::endpoint", - name = %name, - "accept backlog full or no receiver, dropping connection" - ); - } - } - None => { - tracing::debug!( - target: "h3x::endpoint", - name = %name, - "no endpoint registered for SNI, dropping connection" - ); - } - } - } - Err(error) => { - tracing::debug!( - target: "h3x::endpoint", - "failed to obtain server name, dropping connection: {error}" - ); - } - } - }; - // Detached: the Network lives for the process lifetime and these - // tasks are short-lived and self-contained. Previously these were - // pushed into a `JoinSet` that nothing ever drained, which leaked - // one task-output slot per incoming handshake. - tokio::spawn(task); - } - - /// Register a server-side identity on this network. - /// - /// On success returns a [`ServerBinding`] whose [`ServerBinding::recv`] - /// yields accepted connections whose ClientHello SNI matches - /// `named.name`. Cloning the returned binding is cheap and clones share - /// the inbound mpmc queue so concurrent receivers cooperate. - pub async fn bind_server( - self: &Arc, - named: Arc, - server_config: ServerQuicConfig, - ) -> Result { - let name = named.name.clone(); - - // Reuse path: existing SNI registration with the same identity. - if let Some(weak) = self.sni_registry.get(&name).map(|kv| kv.value().clone()) - && let Some(entry) = weak.upgrade() - { - if !Arc::ptr_eq(&entry.named_identity, &named) { - return Err(BindServerError::SniInUse { name }); - } - return Ok(ServerBinding { - name, - _guard: entry.guard.clone(), - entry, - }); - } - - // Slot: reuse existing compatible slot, or initialise a new one. - let slot = { - let mut slot_guard = self.server_slot.write().await; - match slot_guard.upgrade() { - Some(existing) => { - if !server_config_compatible(&existing.config, &server_config) { - return Err(BindServerError::ServerConfigConflict); - } - existing - } - None => { - let rustls_config = build_rustls_server_config( - &server_config, - SniCertResolver { - registry: Arc::downgrade(&self.sni_registry), - }, - )?; - let inner = Arc::new(sni::ServerSlotInner { - config: server_config.clone(), - rustls_config: Arc::new(rustls_config), - }); - *slot_guard = Arc::downgrade(&inner); - inner - } - } - }; - - let certified_key = build_certified_key(&named)?; - let backlog = server_config.own.backlog.max(1); - let (tx, rx) = async_channel::bounded(backlog); - - let guard = Arc::new(SniGuard { - name: name.clone(), - registry: Arc::downgrade(&self.sni_registry), - }); - let entry = Arc::new(SniEntry { - named_identity: named, - certified_key, - incomings_tx: tx, - incomings_rx: rx, - _slot: slot, - guard: guard.clone(), - }); - self.sni_registry - .insert(name.clone(), Arc::downgrade(&entry)); - - Ok(ServerBinding { - name, - entry, - _guard: guard, - }) - } -} - -/// Build the rustls server config shared across all SNIs registered on a -/// network. The resolver selects a [`CertifiedKey`] based on ClientHello SNI. -fn build_rustls_server_config( - server_config: &ServerQuicConfig, - resolver: SniCertResolver, -) -> Result { - use bind_server_error::VersionSnafu; - - const TLS13: &[&rustls::SupportedProtocolVersion] = &[&rustls::version::TLS13]; - let provider = RustlsServerConfig::builder().crypto_provider().clone(); - let builder = RustlsServerConfig::builder_with_provider(provider) - .with_protocol_versions(TLS13) - .context(VersionSnafu)?; - - let mut tls = builder - .with_client_cert_verifier(server_config.own.client_cert_verifier.clone()) - .with_cert_resolver(Arc::new(resolver)); - tls.alpn_protocols.clone_from(&server_config.own.alpns); - if server_config.common.enable_0rtt { - tls.max_early_data_size = 0xffff_ffff; - } - Ok(tls) -} - -fn build_certified_key(named: &NamedIdentity) -> Result, BindServerError> { - use bind_server_error::LoadKeySnafu; - - let provider = RustlsServerConfig::builder().crypto_provider().clone(); - let key = provider - .key_provider - .load_private_key(named.key.clone_key()) - .context(LoadKeySnafu)?; - Ok(Arc::new(CertifiedKey { - cert: named.certs.clone(), - key, - ocsp: None, - })) -} - -/// Returns `true` when `a` and `b` describe the same server configuration. -/// -/// Fast path: whole-[`Arc`] pointer equality for `common` and `own`. -/// Slow path: field-by-field comparison, using [`PartialEq`] where available -/// and [`Arc::ptr_eq`] for the trait-object members. -fn server_config_compatible(a: &ServerQuicConfig, b: &ServerQuicConfig) -> bool { - if Arc::ptr_eq(&a.common, &b.common) && Arc::ptr_eq(&a.own, &b.own) { - return true; - } - let common_eq = a.common.defer_idle_timeout == b.common.defer_idle_timeout - && a.common.enable_0rtt == b.common.enable_0rtt - && a.common.enable_sslkeylog == b.common.enable_sslkeylog - && Arc::ptr_eq( - &a.common.stream_strategy_factory, - &b.common.stream_strategy_factory, - ) - && Arc::ptr_eq(&a.common.qlogger, &b.common.qlogger); - let own_eq = a.own.alpns == b.own.alpns - && a.own.backlog == b.own.backlog - && a.own.anti_port_scan == b.own.anti_port_scan - && a.own.parameters == b.own.parameters - && Arc::ptr_eq(&a.own.token_provider, &b.own.token_provider) - && Arc::ptr_eq(&a.own.client_auther, &b.own.client_auther) - && Arc::ptr_eq(&a.own.client_cert_verifier, &b.own.client_cert_verifier); - common_eq && own_eq -} - -// --------------------------------------------------------------------------- -// BindsGuard -// --------------------------------------------------------------------------- - -/// RAII guard returned by [`Network::add_binds`]. -pub struct BindsGuard { - id: u64, - registry: BindRegistry, - iface_manager: Arc, - bound: Arc>>, -} - -impl Drop for BindsGuard { - fn drop(&mut self) { - self.registry.lock().unwrap().remove(&self.id); - // Drop our strong references first so `unbind` can fully tear - // down each interface. - let uris: Vec = self - .bound - .lock() - .unwrap() - .drain() - .map(|(_, (uri, _iface))| uri) - .collect(); - for uri in uris { - let iface_manager = self.iface_manager.clone(); - tokio::spawn(async move { - iface_manager.unbind(uri).await; - }); - } - } -} diff --git a/src/endpoint/quic.rs b/src/endpoint/quic.rs deleted file mode 100644 index 4051255..0000000 --- a/src/endpoint/quic.rs +++ /dev/null @@ -1,461 +0,0 @@ -//! QUIC-only endpoint built on top of a shared [`Network`]. - -use std::sync::Arc; - -use arc_swap::ArcSwapOption; -use futures::StreamExt; -use rustls::{ClientConfig, pki_types::PrivateKeyDer}; -use snafu::{ResultExt, Snafu}; - -use super::{ - config::{ClientQuicConfig, ServerCertVerifierChoice, ServerQuicConfig}, - identity::{Identity, NamedIdentity}, - network::{BindServerError, Network, ServerBinding}, -}; -use crate::{ - dquic::{ - prelude::{Connection, Resolve}, - qbase::{ - cid::ConnectionId, - net::{ - Family, - addr::{AddrKind, BoundAddr, EndpointAddr, SocketEndpointAddr}, - route::{Link, Pathway}, - }, - }, - qinterface::{bind_uri::BindUri, io::IO}, - qresolve::Source, - }, - quic, - util::tls::DangerousServerCertVerifier, -}; - -/// Error building the client-side TLS configuration. -#[derive(Debug, Snafu)] -#[snafu(module, visibility(pub))] -pub enum BuildClientTlsError { - /// rustls failed to choose a supported protocol version. - #[snafu(display("failed to select TLS protocol version"))] - Version { - /// Underlying rustls error. - source: rustls::Error, - }, - /// rustls refused the provided client certificate / key. - #[snafu(display("failed to load client authentication certificate"))] - ClientAuth { - /// Underlying rustls error. - source: rustls::Error, - }, -} - -/// Error returned by [`QuicEndpoint`] when opening an outbound connection. -#[derive(Debug, Snafu)] -#[snafu(module, visibility(pub))] -pub enum ConnectError { - /// Failed to build the client TLS configuration. - #[snafu(display("failed to build client TLS config"))] - Tls { - /// Underlying build error. - source: BuildClientTlsError, - }, - /// DNS resolution failed. - #[snafu(display("dns lookup failed"))] - Dns { - /// Underlying I/O error. - source: std::io::Error, - }, - /// The resolver produced no reachable endpoint. - #[snafu(display("no reachable endpoint found for server"))] - NoReachableEndpoint, - /// Failed to acquire a local interface for the discovered endpoint. - #[snafu(display("failed to bind local interface"))] - BindInterface { - /// Underlying I/O error. - source: std::io::Error, - }, -} - -/// Error returned by [`QuicEndpoint`] when awaiting an inbound connection. -#[derive(Debug, Snafu)] -#[snafu(module, visibility(pub))] -pub enum AcceptError { - /// The endpoint's identity is anonymous — no SNI to register. - #[snafu(display("cannot accept connections on an anonymous identity"))] - ServerUnavailable, - /// Registering the identity on the network failed. - #[snafu(display("failed to bind server on network"))] - BindServer { - /// Underlying network error. - source: BindServerError, - }, - /// The endpoint has been shut down. - #[snafu(display("endpoint has been shut down"))] - Shutdown, -} - -/// Back-compat alias. -pub type EndpointError = ConnectError; - -/// A QUIC-only endpoint backed by a shared [`Network`]. -pub struct QuicEndpoint { - /// Shared network infrastructure. - pub network: Arc, - /// TLS identity for this endpoint. - pub identity: Identity, - /// Resolver used when establishing outbound connections. - pub resolver: Arc, - /// Client-side configuration. - pub client: ClientQuicConfig, - /// Server-side configuration. - pub server: ServerQuicConfig, - client_tls_cache: ArcSwapOption, - server_binding_cache: ArcSwapOption, -} - -struct CachedClientTls { - key: ClientCacheKey, - config: Arc, -} - -#[derive(PartialEq, Eq)] -struct ClientCacheKey { - identity_ptr: usize, - client_own_ptr: usize, - client_common_ptr: usize, -} - -struct CachedServerBinding { - key: ServerCacheKey, - binding: ServerBinding, -} - -#[derive(PartialEq, Eq)] -struct ServerCacheKey { - network_ptr: usize, - identity_ptr: usize, - server_own_ptr: usize, - server_common_ptr: usize, -} - -impl Clone for QuicEndpoint { - fn clone(&self) -> Self { - Self { - network: self.network.clone(), - identity: self.identity.clone(), - resolver: self.resolver.clone(), - client: self.client.clone(), - server: self.server.clone(), - client_tls_cache: ArcSwapOption::empty(), - server_binding_cache: ArcSwapOption::empty(), - } - } -} - -impl QuicEndpoint { - /// Construct a new endpoint. - #[must_use] - pub fn new( - network: Arc, - identity: Identity, - resolver: Arc, - client: ClientQuicConfig, - server: ServerQuicConfig, - ) -> Self { - Self { - network, - identity, - resolver, - client, - server, - client_tls_cache: ArcSwapOption::empty(), - server_binding_cache: ArcSwapOption::empty(), - } - } - - fn client_cache_key(&self) -> ClientCacheKey { - ClientCacheKey { - identity_ptr: identity_ptr(&self.identity), - client_own_ptr: Arc::as_ptr(&self.client.own) as usize, - client_common_ptr: Arc::as_ptr(&self.client.common) as usize, - } - } - - fn server_cache_key(&self) -> ServerCacheKey { - ServerCacheKey { - network_ptr: Arc::as_ptr(&self.network) as usize, - identity_ptr: identity_ptr(&self.identity), - server_own_ptr: Arc::as_ptr(&self.server.own) as usize, - server_common_ptr: Arc::as_ptr(&self.server.common) as usize, - } - } -} - -fn identity_ptr(identity: &Identity) -> usize { - match identity { - Identity::Anonymous => 0, - Identity::Named(id) => Arc::as_ptr(id) as usize, - } -} - -impl QuicEndpoint { - fn client_tls(&self) -> Result, BuildClientTlsError> { - let key = self.client_cache_key(); - if let Some(cached) = self.client_tls_cache.load_full() - && cached.key == key - { - return Ok(cached.config.clone()); - } - let config = Arc::new(self.build_client_tls()?); - self.client_tls_cache.store(Some(Arc::new(CachedClientTls { - key, - config: config.clone(), - }))); - Ok(config) - } - - fn build_client_tls(&self) -> Result { - use build_client_tls_error::{ClientAuthSnafu, VersionSnafu}; - - const TLS13: &[&rustls::SupportedProtocolVersion] = &[&rustls::version::TLS13]; - let provider = ClientConfig::builder().crypto_provider().clone(); - let builder = ClientConfig::builder_with_provider(provider) - .with_protocol_versions(TLS13) - .context(VersionSnafu)?; - let builder = match &self.client.own.verifier { - ServerCertVerifierChoice::Dangerous => builder - .dangerous() - .with_custom_certificate_verifier(Arc::new(DangerousServerCertVerifier)), - ServerCertVerifierChoice::WebPki(v) => builder.with_webpki_verifier(v.clone()), - ServerCertVerifierChoice::Custom(v) => builder - .dangerous() - .with_custom_certificate_verifier(v.clone()), - }; - let mut tls = match &self.identity { - Identity::Anonymous => builder.with_no_client_auth(), - Identity::Named(id) => builder - .with_client_auth_cert(id.certs.clone(), clone_key(&id.key)) - .context(ClientAuthSnafu)?, - }; - tls.alpn_protocols.clone_from(&self.client.own.alpns); - tls.enable_early_data = self.client.common.enable_0rtt; - Ok(tls) - } -} - -impl QuicEndpoint { - fn build_client_connection( - &self, - server_name: &str, - tls: Arc, - ) -> Arc { - // Propagate the endpoint's named identity into the QUIC transport - // `ClientName` parameter so the peer can populate its - // `remote_agent` (identity-based access control on the server - // relies on this). - let mut parameters = self.client.own.parameters.clone(); - if let Identity::Named(named) = &self.identity { - parameters - .set( - crate::dquic::qbase::param::ParameterId::ClientName, - named.name.to_string(), - ) - .expect("ClientName is a client-only string parameter"); - } - Connection::new_client(server_name.to_owned(), self.client.own.token_sink.clone()) - .with_parameters(parameters) - .with_tls_config((*tls).clone()) - .with_streams_concurrency_strategy(self.client.common.stream_strategy_factory.as_ref()) - .with_zero_rtt(self.client.common.enable_0rtt) - .with_iface_factory(self.network.io_factory().clone()) - .with_iface_manager(self.network.iface_manager().clone()) - .with_quic_router(self.network.quic_router().clone()) - .with_locations(self.network.locations().clone()) - .with_defer_idle_timeout(self.client.common.defer_idle_timeout) - .with_cids(ConnectionId::random_gen(8)) - .with_qlog(self.client.common.qlogger.clone()) - .run() - } - - async fn setup_server_endpoint( - &self, - connection: &Connection, - source: Source, - server_ep: EndpointAddr, - ) -> Result { - use connect_error::BindInterfaceSnafu; - - let _ = connection.add_peer_endpoint(server_ep, source.clone()); - - let bind_uri = bind_uri_for(&source, &server_ep); - let iface = self.network.bind(bind_uri).await; - - if matches!( - server_ep, - EndpointAddr::Socket(SocketEndpointAddr::Agent { .. }) - ) { - return Ok(false); - } - - let interface = iface.borrow(); - let bound_addr = interface.bound_addr().context(BindInterfaceSnafu)?; - if bound_addr.kind() != server_ep.addr_kind() { - return Ok(false); - } - let dst = match server_ep { - EndpointAddr::Socket(s) => BoundAddr::Internet(*s), - EndpointAddr::Ble(_) => return Ok(false), - }; - let link = Link::new(bound_addr, dst); - let pathway = Pathway::new(bound_addr.into(), server_ep); - let _ = connection.add_path(iface.bind_uri(), link, pathway); - Ok(true) - } -} - -impl QuicEndpoint { - async fn server_binding(&self) -> Result { - use accept_error::BindServerSnafu; - - let named = match &self.identity { - Identity::Anonymous => return Err(AcceptError::ServerUnavailable), - Identity::Named(id) => id.clone(), - }; - let key = self.server_cache_key(); - if let Some(cached) = self.server_binding_cache.load_full() - && cached.key == key - { - return Ok(cached.binding.clone()); - } - let binding = self - .network - .bind_server(named as Arc, self.server.clone()) - .await - .context(BindServerSnafu)?; - self.server_binding_cache - .store(Some(Arc::new(CachedServerBinding { - key, - binding: binding.clone(), - }))); - Ok(binding) - } -} - -impl quic::Connect for QuicEndpoint { - type Connection = Connection; - type Error = ConnectError; - - async fn connect( - &self, - server: &http::uri::Authority, - ) -> Result, Self::Error> { - use connect_error::{DnsSnafu, TlsSnafu}; - - let server_str = match server.port_u16() { - Some(port) => format!("{}:{}", server.host(), port), - None => server.host().to_string(), - }; - - let tls = self.client_tls().context(TlsSnafu)?; - - let mut server_eps = self.resolver.lookup(&server_str).await.context(DnsSnafu)?; - - let connection = self.build_client_connection(&server_str, tls); - if connection.subscribe_local_address().is_err() { - return Ok(connection); - } - - let mut last_error: Option = None; - let mut any_viable = false; - - while let Some((source, server_ep)) = server_eps.next().await { - match self - .setup_server_endpoint(&connection, source, server_ep) - .await - { - Ok(true) => { - any_viable = true; - last_error = None; - break; - } - Ok(false) => { - any_viable = true; - last_error = None; - } - Err(error) => { - last_error.get_or_insert(error); - } - } - } - if !any_viable { - return Err(last_error.unwrap_or(ConnectError::NoReachableEndpoint)); - } - - tokio::spawn({ - let weak_connection = Arc::downgrade(&connection); - let terminated = connection.terminated(); - let endpoint = self.clone(); - async move { - tokio::pin!(terminated); - loop { - tokio::select! { - biased; - _ = &mut terminated => break, - next = server_eps.next() => { - let Some((source, server_ep)) = next else { break }; - let Some(connection) = weak_connection.upgrade() else { break }; - let _ = endpoint - .setup_server_endpoint(&connection, source, server_ep) - .await; - } - } - } - } - }); - - Ok(connection) - } -} - -impl quic::Listen for QuicEndpoint { - type Connection = Connection; - type Error = AcceptError; - - async fn accept(&mut self) -> Result, Self::Error> { - let binding = self.server_binding().await?; - binding.recv().await.ok_or(AcceptError::Shutdown) - } - - async fn shutdown(&self) -> Result<(), Self::Error> { - self.server_binding_cache.store(None); - Ok(()) - } -} - -fn clone_key(key: &Arc>) -> PrivateKeyDer<'static> { - key.clone_key() -} - -fn bind_uri_for(source: &Source, ep: &EndpointAddr) -> BindUri { - use std::str::FromStr; - - match source { - Source::Mdns { nic, family } => { - let f = match family { - Family::V4 => "v4", - Family::V6 => "v6", - }; - BindUri::from_str(&format!("iface://{f}.{nic}:0")) - .expect("iface URI should be valid") - .alloc_port() - } - _ => match ep.addr_kind() { - AddrKind::Internet(Family::V4) => BindUri::from_str("inet://0.0.0.0:0") - .expect("URL should be valid") - .alloc_port(), - AddrKind::Internet(Family::V6) => BindUri::from_str("inet://[::]:0") - .expect("URL should be valid") - .alloc_port(), - _ => unreachable!("BLE and other address kinds are not supported yet"), - }, - } -} diff --git a/src/endpoint/sni.rs b/src/endpoint/sni.rs deleted file mode 100644 index 0050fca..0000000 --- a/src/endpoint/sni.rs +++ /dev/null @@ -1,130 +0,0 @@ -//! Per-SNI server state used by [`Network`](super::Network) to fan out -//! incoming connections across many [`QuicEndpoint`](super::QuicEndpoint) -//! instances. -//! -//! A [`ServerBinding`] is cheap to clone: each clone shares the same -//! mpmc [`async_channel`] tail, so multiple endpoints that registered the -//! same SNI cooperatively drain inbound connections. Dropping the last -//! strong reference unregisters the SNI entry from the network. - -use std::sync::{Arc, Weak}; - -use dashmap::DashMap; -use rustls::{ - server::{ClientHello, ResolvesServerCert}, - sign::CertifiedKey, -}; - -use super::identity::{NamedIdentity, ServerName}; -use crate::dquic::prelude::Connection; - -/// Per-SNI entry stored behind a `Weak` in the network's registry. -/// -/// Holds an mpmc channel so multiple [`ServerBinding`] clones share the -/// same inbound connection queue. -pub(crate) struct SniEntry { - pub(crate) named_identity: Arc, - pub(crate) certified_key: Arc, - pub(crate) incomings_tx: async_channel::Sender>, - pub(crate) incomings_rx: async_channel::Receiver>, - /// Keeps the shared server slot alive for the lifetime of this entry. - pub(crate) _slot: Arc, - /// Shared guard — cloned into every `ServerBinding`. When the last - /// clone drops, the entry is removed from `sni_registry`. - pub(crate) guard: Arc, -} - -/// RAII guard that removes an SNI entry from the registry when the last -/// [`ServerBinding`] referencing it is dropped. -pub(crate) struct SniGuard { - pub(crate) name: ServerName, - pub(crate) registry: Weak>>, -} - -impl Drop for SniGuard { - fn drop(&mut self) { - if let Some(registry) = self.registry.upgrade() { - registry.remove(&self.name); - } - } -} - -/// Shared server-side QUIC/TLS context. At most one instance exists per -/// [`Network`](super::Network) at any time; identical instances are shared -/// across all registered SNIs, and conflicting configurations are rejected -/// at `bind_server` time. -pub(crate) struct ServerSlotInner { - pub(crate) config: super::config::ServerQuicConfig, - pub(crate) rustls_config: Arc, -} - -/// Public handle returned by [`Network::bind_server`](super::Network::bind_server). -/// -/// Cloning is cheap and yields a new receiver on the **same** mpmc queue — -/// concurrently calling `recv` from multiple clones fans out inbound -/// connections across the clones without duplicating work. -pub struct ServerBinding { - /// Server name this binding was registered under. - pub name: ServerName, - pub(crate) entry: Arc, - pub(crate) _guard: Arc, -} - -impl Clone for ServerBinding { - fn clone(&self) -> Self { - Self { - name: self.name.clone(), - entry: self.entry.clone(), - _guard: self._guard.clone(), - } - } -} - -impl std::fmt::Debug for ServerBinding { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ServerBinding") - .field("name", &self.name) - .finish_non_exhaustive() - } -} - -impl ServerBinding { - /// Receive the next accepted connection for this SNI. - /// - /// Returns `None` once the network is shut down or has no remaining - /// senders. - pub async fn recv(&self) -> Option> { - self.entry.incomings_rx.recv().await.ok() - } -} - -/// rustls `ResolvesServerCert` backed by the network's SNI registry. -/// -/// SNI names are matched ASCII case-insensitively per RFC 6066 §3. -#[derive(Clone)] -pub(crate) struct SniCertResolver { - pub(crate) registry: Weak>>, -} - -impl std::fmt::Debug for SniCertResolver { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SniCertResolver").finish_non_exhaustive() - } -} - -impl ResolvesServerCert for SniCertResolver { - fn resolve(&self, client_hello: ClientHello<'_>) -> Option> { - let registry = self.registry.upgrade()?; - let sni = client_hello.server_name()?; - // DashMap keys are ASCII-normalised only if callers do so; we iterate - // to match case-insensitively. - for item in registry.iter() { - if item.key().eq_ignore_ascii_case(sni) - && let Some(entry) = item.value().upgrade() - { - return Some(entry.certified_key.clone()); - } - } - None - } -} diff --git a/src/error.rs b/src/error.rs index be4ddc9..b057354 100644 --- a/src/error.rs +++ b/src/error.rs @@ -102,6 +102,10 @@ codes! { /// The requested operation cannot be served over HTTP/3. The peer should retry over HTTP/1.1. pub const H3_VERSION_FALLBACK = 0x0110; + // https://www.rfc-editor.org/rfc/rfc9297#section-3.3 + /// HTTP Datagram or Capsule Protocol error. + pub const H3_DATAGRAM_ERROR = 0x33; + // https://datatracker.ietf.org/doc/html/rfc9204#name-error-handling /// The decoder failed to interpret an encoded field section and is not able to continue decoding that field section. pub const QPACK_DECOMPRESSION_FAILED = 0x200; @@ -109,6 +113,14 @@ codes! { pub const QPACK_ENCODER_STREAM_ERROR = 0x201; /// The encoder failed to interpret a decoder instruction received on the decoder stream. pub const QPACK_DECODER_STREAM_ERROR = 0x202; + + // https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3#section-9.5 + /// WebTransport stream belongs to a session that is gone. + pub const WT_SESSION_GONE = 0x170d7b68; + /// WebTransport stream rejected because h3x does not buffer unknown sessions. + pub const WT_BUFFERED_STREAM_REJECTED = 0x3994bd84; + /// WebTransport session aborted because a flow control error was encountered. + pub const WT_FLOW_CONTROL_ERROR = 0x045d4487; } impl Code { @@ -259,6 +271,28 @@ impl H3ConnectionError for H3InternalError { } } +#[cfg(test)] +mod webtransport_code_tests { + use super::*; + + #[test] + fn webtransport_codes_are_associated_constants() { + assert_eq!(Code::H3_DATAGRAM_ERROR.into_inner(), VarInt::from_u32(0x33)); + assert_eq!( + Code::WT_SESSION_GONE.into_inner(), + VarInt::from_u32(0x170d7b68) + ); + assert_eq!( + Code::WT_BUFFERED_STREAM_REJECTED.into_inner(), + VarInt::from_u32(0x3994bd84) + ); + assert_eq!( + Code::WT_FLOW_CONTROL_ERROR.into_inner(), + VarInt::from_u32(0x045d4487) + ); + } +} + #[derive(Debug, Snafu, Clone)] #[snafu(display("frame decode error"))] pub struct H3FrameDecodeError { @@ -310,3 +344,207 @@ impl H3ConnectionError for H3IdError { Code::H3_ID_ERROR } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + codec::{DecodeError, EncodeError}, + connection::{ConnectionError, StreamError}, + }; + + fn varint(value: u32) -> VarInt { + VarInt::from_u32(value) + } + + fn assert_connection_error(error: E, expected_code: Code, expected_display: &str) + where + E: H3ConnectionError + 'static, + { + assert_eq!(error.code(), expected_code); + assert_eq!(error.to_string(), expected_display); + + let ConnectionError::H3 { source } = ConnectionError::from(error) else { + panic!("expected h3 connection error"); + }; + assert_eq!(source.code(), expected_code); + assert_eq!(source.to_string(), expected_display); + } + + fn assert_stream_error(error: E, expected_code: Code, expected_display: &str) + where + E: H3StreamError + 'static, + { + assert_eq!(error.code(), expected_code); + assert_eq!(error.to_string(), expected_display); + + let StreamError::H3 { source } = StreamError::from(error) else { + panic!("expected h3 stream error"); + }; + assert_eq!(source.code(), expected_code); + assert_eq!(source.to_string(), expected_display); + } + + #[test] + fn code_conversions_and_display_cover_known_and_unknown_codes() { + let known = Code::H3_NO_ERROR; + assert_eq!(known.into_inner(), varint(0x100)); + assert_eq!(known.value(), varint(0x100)); + assert_eq!(VarInt::from(known), varint(0x100)); + assert_eq!(known.to_string(), "H3_NO_ERROR (0x100)"); + + let custom = Code::from(varint(0x12345)); + assert_eq!(custom.into_inner(), varint(0x12345)); + assert_eq!(custom.value(), varint(0x12345)); + assert_eq!(custom.to_string(), "Code 0x12345"); + + let constructed = Code::new(varint(0x201)); + assert_eq!(constructed, Code::QPACK_ENCODER_STREAM_ERROR); + assert_eq!( + constructed.to_string(), + "QPACK_ENCODER_STREAM_ERROR (0x201)" + ); + } + + #[test] + fn connection_error_types_report_their_codes_and_messages() { + for error in [ + H3StreamCreationError::DuplicateControlStream, + H3StreamCreationError::DuplicateQpackEncoderStream, + H3StreamCreationError::DuplicateQpackDecoderStream, + ] { + assert_connection_error( + error, + Code::H3_STREAM_CREATION_ERROR, + match error { + H3StreamCreationError::DuplicateControlStream => { + "control stream already exists" + } + H3StreamCreationError::DuplicateQpackEncoderStream => { + "qpack encoder stream already exists" + } + H3StreamCreationError::DuplicateQpackDecoderStream => { + "qpack decoder stream already exists" + } + }, + ); + } + + assert_connection_error( + H3CriticalStreamClosed::QPackEncoder, + Code::H3_CLOSED_CRITICAL_STREAM, + "qpack encoder stream closed unexpectedly", + ); + assert_connection_error( + H3CriticalStreamClosed::QPackDecoder, + Code::H3_CLOSED_CRITICAL_STREAM, + "qpack decoder stream closed unexpectedly", + ); + assert_connection_error( + H3CriticalStreamClosed::Control, + Code::H3_CLOSED_CRITICAL_STREAM, + "control stream closed unexpectedly", + ); + + for error in [ + H3FrameUnexpected::DuplicateSettings, + H3FrameUnexpected::UnexpectedFrameType, + H3FrameUnexpected::UnexpectedFrameDuringTrailer, + ] { + assert_connection_error( + error, + Code::H3_FRAME_UNEXPECTED, + match error { + H3FrameUnexpected::DuplicateSettings => "received subsequent SETTINGS frame", + H3FrameUnexpected::UnexpectedFrameType => { + "unexpected frame type on request stream" + } + H3FrameUnexpected::UnexpectedFrameDuringTrailer => { + "unexpected frame during trailer reading" + } + }, + ); + } + + assert_connection_error(H3NoError, Code::H3_NO_ERROR, "no error"); + assert_connection_error( + H3MissingSettings, + Code::H3_MISSING_SETTINGS, + "no SETTINGS frame at beginning of control stream", + ); + assert_connection_error( + H3GeneralProtocolError::TrailingPayload, + Code::H3_GENERAL_PROTOCOL_ERROR, + "trailing payload in GOAWAY frame", + ); + assert_connection_error( + H3GeneralProtocolError::Decode { + source: DecodeError::ArithmeticOverflow, + }, + Code::H3_GENERAL_PROTOCOL_ERROR, + "protocol decode error", + ); + assert_connection_error( + H3InternalError::QPackEncoderEncode { + source: EncodeError::HuffmanEncoding, + }, + Code::H3_INTERNAL_ERROR, + "QPACK encoder encode failure", + ); + assert_connection_error( + H3InternalError::MissingServerName, + Code::H3_INTERNAL_ERROR, + "missing server name (SNI) on incoming connection", + ); + assert_connection_error( + H3FrameDecodeError { + source: DecodeError::IntegerOverflow, + }, + Code::H3_FRAME_ERROR, + "frame decode error", + ); + assert_connection_error( + QpackDecompressionFailed::Decode { + source: DecodeError::DecompressionFailed, + }, + Code::QPACK_DECOMPRESSION_FAILED, + "QPACK decompression decode error", + ); + + for error in [ + H3IdError::PushIdExceedsLimit, + H3IdError::GoawayStreamIdOrdering, + ] { + assert_connection_error( + error, + Code::H3_ID_ERROR, + match error { + H3IdError::PushIdExceedsLimit => "push ID exceeds limit", + H3IdError::GoawayStreamIdOrdering => "GOAWAY stream ID ordering violation", + }, + ); + } + } + + #[test] + fn stream_error_types_report_their_codes_and_messages() { + assert_stream_error( + H3MessageError::MissingHeaderSection, + Code::H3_MESSAGE_ERROR, + "missing header section in HTTP message", + ); + assert_stream_error( + H3MessageError::UnexpectedHeadersInBody, + Code::H3_MESSAGE_ERROR, + "unexpected headers frame in message body", + ); + assert_stream_error( + H3ExcessiveFieldSectionSize { + actual: 8192, + limit: 4096, + }, + Code::H3_EXCESSIVE_LOAD, + "field section size 8192 exceeds limit 4096", + ); + } +} diff --git a/src/extended_connect.rs b/src/extended_connect.rs new file mode 100644 index 0000000..ece59c3 --- /dev/null +++ b/src/extended_connect.rs @@ -0,0 +1,208 @@ +use std::{future::Future, sync::Arc}; + +use futures::future::{BoxFuture, FutureExt}; +use snafu::ResultExt; + +use crate::{ + connection::ConnectionState, + dhttp::message::{MessageReader, MessageWriter}, + qpack::field::Protocol, + quic, + stream_id::StreamId, +}; + +mod error; +mod tunnel; + +#[cfg(feature = "hyper")] +pub mod hyper; +pub mod settings; + +pub use error::{ + IntoStreamsError, PendingWriteStreamError, into_streams_error, pending_write_stream_error, +}; +pub use tunnel::ConnectTunnel; + +pub struct EstablishedConnect { + stream_id: StreamId, + protocol: Option, + connection: Arc>, + control: ConnectControl, +} + +#[allow(dead_code)] +enum ConnectControl { + Ready { + read: MessageReader, + write: MessageWriter, + }, + Pending { + read: MessageReader, + write: BoxFuture<'static, Result>, + }, +} + +impl EstablishedConnect { + #[allow(dead_code)] + pub(crate) fn ready( + stream_id: StreamId, + protocol: Option, + connection: Arc>, + read: MessageReader, + write: MessageWriter, + ) -> Self { + Self { + stream_id, + protocol, + connection, + control: ConnectControl::Ready { read, write }, + } + } + + #[allow(dead_code)] + pub(crate) fn pending( + stream_id: StreamId, + protocol: Option, + connection: Arc>, + read: MessageReader, + write: impl Future> + Send + 'static, + ) -> Self { + Self { + stream_id, + protocol, + connection, + control: ConnectControl::Pending { + read, + write: write.boxed(), + }, + } + } + + pub fn stream_id(&self) -> StreamId { + self.stream_id + } + + pub fn protocol(&self) -> Option<&Protocol> { + self.protocol.as_ref() + } + + pub fn connection(&self) -> &Arc> { + &self.connection + } + + pub async fn into_streams(self) -> Result<(MessageReader, MessageWriter), IntoStreamsError> { + match self.control { + ConnectControl::Ready { read, write } => Ok((read, write)), + ConnectControl::Pending { read, write } => { + let write = write + .await + .context(into_streams_error::PendingWriteStreamSnafu)?; + Ok((read, write)) + } + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use futures::future; + + use super::*; + use crate::{ + connection::{ConnectionState, tests::MockConnection}, + dhttp::message::test::{read_stream_for_test, write_stream_for_test}, + protocol::Protocols, + qpack::field::Protocol, + quic, + stream_id::StreamId, + varint::VarInt, + }; + + fn state_for_test() -> Arc> { + let quic = Arc::new(MockConnection::new()); + let erased: Arc = quic; + Arc::new(ConnectionState::new_for_test( + erased, + Arc::new(Protocols::new()), + )) + } + + #[tokio::test] + async fn ready_connect_returns_streams() { + let stream_id = StreamId::from(VarInt::from_u32(4)); + let connect = EstablishedConnect::ready( + stream_id, + Some(Protocol::new("test-protocol")), + state_for_test(), + read_stream_for_test(stream_id.0), + write_stream_for_test(stream_id.0), + ); + + assert_eq!(connect.stream_id(), stream_id); + assert_eq!( + connect.protocol().map(Protocol::as_str), + Some("test-protocol") + ); + + let (_read, _write) = connect.into_streams().await.expect("streams are ready"); + } + + #[tokio::test] + async fn pending_connect_waits_for_write_stream() { + let stream_id = StreamId::from(VarInt::from_u32(8)); + let connect = EstablishedConnect::pending( + stream_id, + None, + state_for_test(), + read_stream_for_test(stream_id.0), + future::ready(Ok(write_stream_for_test(stream_id.0))), + ); + + let (_read, _write) = connect + .into_streams() + .await + .expect("pending write delivered"); + } + + #[tokio::test] + async fn pending_write_failure_is_reported_by_into_streams() { + let stream_id = StreamId::from(VarInt::from_u32(12)); + let connect = EstablishedConnect::pending( + stream_id, + None, + state_for_test(), + read_stream_for_test(stream_id.0), + future::ready(Err(PendingWriteStreamError::Aborted)), + ); + + let error = match connect.into_streams().await { + Ok(_) => panic!("pending write failed"), + Err(error) => error, + }; + assert!(matches!( + error, + IntoStreamsError::PendingWriteStream { + source: PendingWriteStreamError::Aborted, + }, + )); + } + + #[tokio::test] + async fn connect_tunnel_delegates_to_established_connect() { + let stream_id = StreamId::from(VarInt::from_u32(16)); + let connect = EstablishedConnect::ready( + stream_id, + Some(Protocol::new("raw-tunnel")), + state_for_test(), + read_stream_for_test(stream_id.0), + write_stream_for_test(stream_id.0), + ); + let tunnel = ConnectTunnel::from(connect); + + assert_eq!(tunnel.stream_id(), stream_id); + assert_eq!(tunnel.protocol().map(Protocol::as_str), Some("raw-tunnel")); + let (_read, _write) = tunnel.into_streams().await.expect("streams are ready"); + } +} diff --git a/src/extended_connect/error.rs b/src/extended_connect/error.rs new file mode 100644 index 0000000..d64dd0d --- /dev/null +++ b/src/extended_connect/error.rs @@ -0,0 +1,21 @@ +use snafu::Snafu; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Snafu)] +#[snafu(module, visibility(pub))] +pub enum PendingWriteStreamError { + #[snafu(display("pending write stream is unsupported"))] + Unsupported, + #[snafu(display("pending write stream was already taken"))] + AlreadyTaken, + #[snafu(display("pending write stream provider was dropped before delivery"))] + Aborted, + #[snafu(display("pending message body could not be released"))] + BodyNotReleased, +} + +#[derive(Debug, Snafu)] +#[snafu(module, visibility(pub))] +pub enum IntoStreamsError { + #[snafu(display("failed to obtain pending write stream"))] + PendingWriteStream { source: PendingWriteStreamError }, +} diff --git a/src/extended_connect/hyper.rs b/src/extended_connect/hyper.rs new file mode 100644 index 0000000..a211f9a --- /dev/null +++ b/src/extended_connect/hyper.rs @@ -0,0 +1,581 @@ +use std::sync::Arc; + +use bytes::Bytes; +use http::{Method, StatusCode}; +use http_body::Body; +use http_body_util::Empty; +use snafu::{OptionExt, ResultExt, Snafu, ensure}; + +use crate::{ + connection::ConnectionState, + dhttp::message::{MessageReader, MessageWriter, hyper::upgrade::TakeoverError}, + extended_connect::{EstablishedConnect, PendingWriteStreamError}, + qpack::field::Protocol, + quic, + stream_id::StreamId, +}; + +#[derive(Debug, Snafu)] +#[snafu(module, visibility(pub))] +pub enum EstablishError { + #[snafu(display("extended connect response was rejected with status {status}"))] + Rejected { status: StatusCode }, + #[snafu(display("extended connect response is missing stream ID metadata"))] + MissingStreamId, + #[snafu(display("extended connect response is missing connection metadata"))] + MissingConnection, + #[snafu(display("failed to take over extended connect read stream"))] + TakeRead { source: TakeoverError }, + #[snafu(display("failed to take over extended connect write stream"))] + TakeWrite { source: TakeoverError }, +} + +#[derive(Debug, Snafu)] +#[snafu(module, visibility(pub))] +pub enum AcceptError { + #[snafu(display("request method {method} is not extended connect"))] + NotConnect { method: Method }, + #[snafu(display("extended connect request is missing stream ID metadata"))] + MissingStreamId, + #[snafu(display("extended connect request is missing connection metadata"))] + MissingConnection, + #[snafu(display("failed to take over extended connect read stream"))] + TakeRead { source: TakeoverError }, +} + +fn pending_write_error(error: TakeoverError) -> PendingWriteStreamError { + match error { + TakeoverError::Unsupported => PendingWriteStreamError::Unsupported, + TakeoverError::AlreadyTaken => PendingWriteStreamError::AlreadyTaken, + TakeoverError::Aborted => PendingWriteStreamError::Aborted, + TakeoverError::BodyNotReleased => PendingWriteStreamError::BodyNotReleased, + } +} + +pub async fn establish( + mut response: http::Response, +) -> Result +where + B: Body + Unpin + Send + 'static, +{ + ensure!( + response.status().is_success(), + establish_error::RejectedSnafu { + status: response.status(), + } + ); + + let stream_id = *response + .extensions() + .get::() + .context(establish_error::MissingStreamIdSnafu)?; + let connection = response + .extensions() + .get::>>() + .cloned() + .context(establish_error::MissingConnectionSnafu)?; + let protocol = response.extensions().get::().cloned(); + + let read = crate::hyper::upgrade::take::(&mut response) + .await + .context(establish_error::TakeReadSnafu)?; + let write = crate::hyper::upgrade::take::(&mut response) + .await + .context(establish_error::TakeWriteSnafu)?; + + Ok(EstablishedConnect::ready( + stream_id, protocol, connection, read, write, + )) +} + +pub async fn accept( + mut request: http::Request, +) -> Result<(http::Response>, EstablishedConnect), AcceptError> +where + B: Body + Unpin + Send + 'static, +{ + ensure!( + request.method() == Method::CONNECT, + accept_error::NotConnectSnafu { + method: request.method().clone(), + } + ); + + let stream_id = *request + .extensions() + .get::() + .context(accept_error::MissingStreamIdSnafu)?; + let connection = request + .extensions() + .get::>>() + .cloned() + .context(accept_error::MissingConnectionSnafu)?; + let protocol = request.extensions().get::().cloned(); + + let read = crate::hyper::upgrade::take::(&mut request) + .await + .context(accept_error::TakeReadSnafu)?; + + let write = async move { + match crate::hyper::upgrade::take::(&mut request).await { + Ok(write) => Ok(write), + Err(error) => Err(pending_write_error(error)), + } + }; + + let response = http::Response::builder() + .status(StatusCode::OK) + .body(Empty::::new()) + .expect("200 OK response with empty body is valid"); + + let connect = EstablishedConnect::pending(stream_id, protocol, connection, read, write); + Ok((response, connect)) +} + +#[cfg(test)] +mod tests { + use std::{ + pin::Pin, + sync::Arc, + task::{Context, Poll}, + }; + + use http::{Request, Response, StatusCode}; + use http_body::Body; + + use super::*; + use crate::{ + connection::{ConnectionState, tests::MockConnection}, + dhttp::message::{ + hyper::upgrade::{RemainStream, TakeoverSlot}, + test::{read_stream_for_test, write_stream_for_test}, + }, + protocol::Protocols, + qpack::field::Protocol, + varint::VarInt, + }; + + #[derive(Debug, Clone)] + struct ErrorBody; + + impl Body for ErrorBody { + type Data = Bytes; + type Error = std::io::Error; + + fn poll_frame( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + Poll::Ready(Some(Err(std::io::Error::other("body error")))) + } + } + + fn state_for_test() -> Arc> { + let quic = Arc::new(MockConnection::new()); + let erased: Arc = quic; + Arc::new(ConnectionState::new_for_test( + erased, + Arc::new(Protocols::new()), + )) + } + + #[tokio::test] + async fn establish_rejects_non_success_status() { + let response = Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Empty::::new()) + .expect("valid response"); + + let error = match establish(response).await { + Ok(_) => panic!("non-2xx CONNECT is rejected"), + Err(error) => error, + }; + assert!( + matches!(error, EstablishError::Rejected { status } if status == StatusCode::BAD_REQUEST) + ); + } + + #[tokio::test] + async fn establish_requires_stream_id_metadata() { + let response = Response::builder() + .status(StatusCode::OK) + .body(Empty::::new()) + .expect("valid response"); + + let error = match establish(response).await { + Ok(_) => panic!("stream ID metadata is required"), + Err(error) => error, + }; + assert!(matches!(error, EstablishError::MissingStreamId)); + } + + #[tokio::test] + async fn establish_requires_connection_metadata() { + let stream_id = StreamId::from(VarInt::from_u32(4)); + let mut response = Response::builder() + .status(StatusCode::OK) + .body(Empty::::new()) + .expect("valid response"); + + response.extensions_mut().insert(stream_id); + response + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + read_stream_for_test(stream_id.0), + ))); + response + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + write_stream_for_test(stream_id.0), + ))); + + let error = match establish(response).await { + Ok(_) => panic!("connection metadata is required"), + Err(error) => error, + }; + assert!(matches!(error, EstablishError::MissingConnection)); + } + + #[tokio::test] + async fn establish_returns_ready_connect_with_protocol_and_streams() { + let stream_id = StreamId::from(VarInt::from_u32(8)); + let state = state_for_test(); + let protocol = Protocol::new("webtransport-h3"); + let mut response = Response::builder() + .status(StatusCode::OK) + .body(Empty::::new()) + .expect("valid response"); + + response.extensions_mut().insert(stream_id); + response.extensions_mut().insert(state.clone()); + response.extensions_mut().insert(protocol.clone()); + response + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + read_stream_for_test(stream_id.0), + ))); + response + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + write_stream_for_test(stream_id.0), + ))); + + let connect = establish(response).await.expect("establish should succeed"); + assert!(Arc::ptr_eq(connect.connection(), &state)); + assert_eq!(connect.stream_id(), stream_id); + assert_eq!( + connect.protocol().map(Protocol::as_str), + Some("webtransport-h3") + ); + + let (_read, _write) = connect + .into_streams() + .await + .expect("establish should provide streams"); + } + + #[tokio::test] + async fn establish_missing_protocol_defaults_to_none() { + let stream_id = StreamId::from(VarInt::from_u32(12)); + let state = state_for_test(); + let mut response = Response::builder() + .status(StatusCode::OK) + .body(Empty::::new()) + .expect("valid response"); + + response.extensions_mut().insert(stream_id); + response.extensions_mut().insert(state); + response + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + read_stream_for_test(stream_id.0), + ))); + response + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + write_stream_for_test(stream_id.0), + ))); + + let connect = establish(response).await.expect("establish should succeed"); + assert_eq!(connect.protocol(), None); + assert_eq!(connect.stream_id(), stream_id); + } + + #[tokio::test] + async fn establish_fails_when_write_slot_is_missing() { + let stream_id = StreamId::from(VarInt::from_u32(16)); + let state = state_for_test(); + let mut response = Response::builder() + .status(StatusCode::OK) + .body(Empty::::new()) + .expect("valid response"); + + response.extensions_mut().insert(stream_id); + response.extensions_mut().insert(state); + response + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + read_stream_for_test(stream_id.0), + ))); + + let error = match establish(response).await { + Ok(_) => panic!("missing write takeover slot should fail"), + Err(error) => error, + }; + assert!(matches!( + error, + EstablishError::TakeWrite { + source: TakeoverError::Unsupported + } + )); + } + + #[tokio::test] + async fn establish_fails_when_read_slot_is_missing() { + let stream_id = StreamId::from(VarInt::from_u32(18)); + let state = state_for_test(); + let mut response = Response::builder() + .status(StatusCode::OK) + .body(Empty::::new()) + .expect("valid response"); + + response.extensions_mut().insert(stream_id); + response.extensions_mut().insert(state); + response + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + write_stream_for_test(stream_id.0), + ))); + + let error = establish(response) + .await + .err() + .expect("missing read takeover slot should fail"); + assert!(matches!( + error, + EstablishError::TakeRead { + source: TakeoverError::Unsupported + } + )); + } + + #[tokio::test] + async fn accept_rejects_non_connect_request() { + let request = Request::builder() + .method(http::Method::GET) + .uri("https://example.test/") + .extension(Protocol::new("webtransport-h3")) + .body(Empty::::new()) + .expect("valid request"); + + let error = match accept(request).await { + Ok(_) => panic!("only CONNECT can be accepted"), + Err(error) => error, + }; + assert!(matches!(error, AcceptError::NotConnect { method } if method == http::Method::GET)); + } + + #[tokio::test] + async fn accept_requires_stream_id_metadata() { + let request = Request::builder() + .method(http::Method::CONNECT) + .uri("https://example.test/session") + .body(Empty::::new()) + .expect("valid request"); + + let error = match accept(request).await { + Ok(_) => panic!("stream ID metadata is required"), + Err(error) => error, + }; + assert!(matches!(error, AcceptError::MissingStreamId)); + } + + #[tokio::test] + async fn accept_requires_connection_metadata() { + let stream_id = StreamId::from(VarInt::from_u32(20)); + let mut request = Request::builder() + .method(http::Method::CONNECT) + .uri("https://example.test/session") + .body(Empty::::new()) + .expect("valid request"); + + request.extensions_mut().insert(stream_id); + request + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + read_stream_for_test(stream_id.0), + ))); + request + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + write_stream_for_test(stream_id.0), + ))); + + let error = match accept(request).await { + Ok(_) => panic!("connection metadata is required"), + Err(error) => error, + }; + assert!(matches!(error, AcceptError::MissingConnection)); + } + + #[tokio::test] + async fn accept_returns_ok_response_and_connect_with_streams() { + let stream_id = StreamId::from(VarInt::from_u32(24)); + let state = state_for_test(); + let protocol = Protocol::new("webtransport-h3"); + let mut request = Request::builder() + .method(http::Method::CONNECT) + .uri("https://example.test/session") + .body(Empty::::new()) + .expect("valid request"); + + request.extensions_mut().insert(stream_id); + request.extensions_mut().insert(state.clone()); + request.extensions_mut().insert(protocol.clone()); + request + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + read_stream_for_test(stream_id.0), + ))); + request + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + write_stream_for_test(stream_id.0), + ))); + + let (response, connect) = accept(request).await.expect("accept should succeed"); + assert_eq!(response.status(), StatusCode::OK); + + assert_eq!(connect.stream_id(), stream_id); + assert_eq!( + connect.protocol().map(Protocol::as_str), + Some("webtransport-h3") + ); + assert!(Arc::ptr_eq(connect.connection(), &state)); + + let (_read, _write) = connect + .into_streams() + .await + .expect("streams should be ready"); + } + + #[tokio::test] + async fn accept_missing_protocol_defaults_to_none() { + let stream_id = StreamId::from(VarInt::from_u32(28)); + let state = state_for_test(); + let mut request = Request::builder() + .method(http::Method::CONNECT) + .uri("https://example.test/session") + .body(Empty::::new()) + .expect("valid request"); + + request.extensions_mut().insert(stream_id); + request.extensions_mut().insert(state); + request + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + read_stream_for_test(stream_id.0), + ))); + request + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + write_stream_for_test(stream_id.0), + ))); + + let (_response, connect) = accept(request).await.expect("accept should succeed"); + assert_eq!(connect.protocol(), None); + } + + #[tokio::test] + async fn accept_fails_when_read_takeover_fails() { + let stream_id = StreamId::from(VarInt::from_u32(32)); + let mut request = Request::builder() + .method(http::Method::CONNECT) + .uri("https://example.test/session") + .body(ErrorBody) + .expect("valid request"); + + request.extensions_mut().insert(stream_id); + request.extensions_mut().insert(state_for_test()); + request + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + write_stream_for_test(stream_id.0), + ))); + + let error = match accept(request).await { + Ok(_) => panic!("read body errors should fail takeover"), + Err(error) => error, + }; + assert!(matches!( + error, + AcceptError::TakeRead { + source: TakeoverError::BodyNotReleased + } + )); + } + + #[tokio::test] + async fn accept_into_streams_reports_missing_write_takeover_as_pending_error() { + let stream_id = StreamId::from(VarInt::from_u32(36)); + let mut request = Request::builder() + .method(http::Method::CONNECT) + .uri("https://example.test/session") + .body(Empty::::new()) + .expect("valid request"); + + request.extensions_mut().insert(stream_id); + request.extensions_mut().insert(state_for_test()); + request + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + read_stream_for_test(stream_id.0), + ))); + + let (_response, connect) = accept(request).await.expect("accept should parse request"); + let error = match connect.into_streams().await { + Ok(_) => panic!("write takeover should be unsupported"), + Err(error) => error, + }; + assert!(matches!( + error, + crate::extended_connect::IntoStreamsError::PendingWriteStream { + source: crate::extended_connect::PendingWriteStreamError::Unsupported + } + )); + } + + #[tokio::test] + async fn accept_into_streams_reports_aborted_write_takeover_as_pending_error() { + let stream_id = StreamId::from(VarInt::from_u32(40)); + let mut request = Request::builder() + .method(http::Method::CONNECT) + .uri("https://example.test/session") + .body(Empty::::new()) + .expect("valid request"); + + request.extensions_mut().insert(stream_id); + request.extensions_mut().insert(state_for_test()); + request + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + read_stream_for_test(stream_id.0), + ))); + let (write_tx, write) = RemainStream::::pending(); + request.extensions_mut().insert(TakeoverSlot::new(write)); + drop(write_tx); + + let (_response, connect) = accept(request).await.expect("accept should parse request"); + let error = connect + .into_streams() + .await + .err() + .expect("write takeover should be aborted"); + assert!(matches!( + error, + crate::extended_connect::IntoStreamsError::PendingWriteStream { + source: crate::extended_connect::PendingWriteStreamError::Aborted + } + )); + } +} diff --git a/src/extended_connect/settings.rs b/src/extended_connect/settings.rs new file mode 100644 index 0000000..1898061 --- /dev/null +++ b/src/extended_connect/settings.rs @@ -0,0 +1,72 @@ +use crate::{ + dhttp::settings::{Setting, SettingId, Settings}, + varint::VarInt, +}; + +/// `SETTINGS_ENABLE_CONNECT_PROTOCOL` (0x08). No default. +/// +/// Enables the Extended CONNECT method. The value MUST be 0 or 1. +pub struct EnableConnectProtocol; + +impl EnableConnectProtocol { + pub const ID: VarInt = VarInt::from_u32(0x08); + + pub const fn setting(enabled: bool) -> Setting { + Setting::new(Self::ID, VarInt::from_u32(enabled as u32)) + } +} + +impl SettingId for EnableConnectProtocol { + type Value = bool; + + fn id(&self) -> VarInt { + Self::ID + } + + fn value_from(&self, settings: &Settings) -> bool { + settings + .get_raw(Self::ID) + .is_some_and(|value| value == VarInt::from_u32(1)) + } +} + +impl Settings { + pub fn enable_connect_protocol(&self) -> bool { + self.get(EnableConnectProtocol) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::varint::VarInt; + + #[test] + fn enable_connect_protocol_round_trips_through_settings() { + let mut settings = Settings::default(); + assert!(!settings.enable_connect_protocol()); + + settings.set(EnableConnectProtocol::setting(true)); + assert!(settings.enable_connect_protocol()); + assert_eq!( + settings.get(VarInt::from_u32(0x08)), + Some(VarInt::from_u32(1)) + ); + } + + #[test] + fn enable_connect_protocol_exposes_id_and_treats_only_one_as_enabled() { + assert_eq!(EnableConnectProtocol.id(), EnableConnectProtocol::ID); + + let mut settings = Settings::default(); + settings.set(EnableConnectProtocol::setting(false)); + assert!(!settings.enable_connect_protocol()); + assert_eq!( + settings.get(VarInt::from_u32(0x08)), + Some(VarInt::from_u32(0)) + ); + + settings.set(Setting::new(EnableConnectProtocol::ID, VarInt::from_u32(2))); + assert!(!settings.enable_connect_protocol()); + } +} diff --git a/src/extended_connect/tunnel.rs b/src/extended_connect/tunnel.rs new file mode 100644 index 0000000..4aa96ba --- /dev/null +++ b/src/extended_connect/tunnel.rs @@ -0,0 +1,30 @@ +use crate::{ + dhttp::message::{MessageReader, MessageWriter}, + extended_connect::{EstablishedConnect, IntoStreamsError}, + qpack::field::Protocol, + stream_id::StreamId, +}; + +pub struct ConnectTunnel { + connect: EstablishedConnect, +} + +impl From for ConnectTunnel { + fn from(connect: EstablishedConnect) -> Self { + Self { connect } + } +} + +impl ConnectTunnel { + pub fn stream_id(&self) -> StreamId { + self.connect.stream_id() + } + + pub fn protocol(&self) -> Option<&Protocol> { + self.connect.protocol() + } + + pub async fn into_streams(self) -> Result<(MessageReader, MessageWriter), IntoStreamsError> { + self.connect.into_streams().await + } +} diff --git a/src/hyper.rs b/src/hyper.rs index 879ecc2..b05e609 100644 --- a/src/hyper.rs +++ b/src/hyper.rs @@ -2,8 +2,8 @@ pub mod upgrade { use std::future::poll_fn; - pub use crate::message::stream::{ - BoxMessageStreamReader, BoxMessageStreamWriter, ReadStream, WriteStream, + pub use crate::dhttp::message::{ + BoxMessageReader, BoxMessageWriter, MessageReader, MessageWriter, hyper::upgrade::{HasTakeover, MissingStream, TakeoverError, TakeoverSlot, UpgradeError}, }; @@ -15,22 +15,18 @@ pub mod upgrade { #[doc(alias = "take")] pub async fn on( - mut message: impl HasTakeover + HasTakeover, - ) -> Result< - ( - BoxMessageStreamReader<'static>, - BoxMessageStreamWriter<'static>, - ), - UpgradeError, - > { - let read = - match poll_fn(|cx| HasTakeover::::poll_takeover(&mut message, cx)).await { - Ok(read) => Some(read), - Err(TakeoverError::Unsupported) => None, - Err(source) => return Err(UpgradeError::Takeover { source }), - }; + mut message: impl HasTakeover + HasTakeover, + ) -> Result<(BoxMessageReader, BoxMessageWriter), UpgradeError> { + let read = match poll_fn(|cx| HasTakeover::::poll_takeover(&mut message, cx)) + .await + { + Ok(read) => Some(read), + Err(TakeoverError::Unsupported) => None, + Err(source) => return Err(UpgradeError::Takeover { source }), + }; let write = - match poll_fn(|cx| HasTakeover::::poll_takeover(&mut message, cx)).await { + match poll_fn(|cx| HasTakeover::::poll_takeover(&mut message, cx)).await + { Ok(write) => Some(write), Err(TakeoverError::Unsupported) => None, Err(source) => return Err(UpgradeError::Takeover { source }), @@ -55,10 +51,14 @@ pub mod ext { pub use crate::qpack::field::Protocol; } -pub use crate::message::stream::hyper::write::SendMessageError; +pub mod extended_connect { + pub use crate::extended_connect::hyper::*; +} -pub mod client; -pub mod server; +pub use crate::{ + dhttp::message::hyper::{RequestError, SendMessageError}, + endpoint::hyper::{HandleRequestError, HyperService, TowerService}, +}; #[cfg(test)] mod tests { @@ -71,6 +71,16 @@ mod tests { use http_body::{Body, Frame}; use super::upgrade; + use crate::{ + dhttp::message::{ + hyper::upgrade::{ + MissingStream, RemainStream, TakeoverError, TakeoverSlot, UpgradeError, + }, + test::{read_stream_for_test, write_stream_for_test}, + }, + quic::GetStreamIdExt, + varint::VarInt, + }; #[derive(Debug, Clone)] struct ErrorBody; @@ -93,10 +103,133 @@ mod tests { assert!(matches!( compat, Err( - crate::message::stream::hyper::upgrade::UpgradeError::Takeover { - source: crate::message::stream::hyper::upgrade::TakeoverError::BodyNotReleased, + crate::dhttp::message::hyper::upgrade::UpgradeError::Takeover { + source: crate::dhttp::message::hyper::upgrade::TakeoverError::BodyNotReleased, } ) )); } + + #[tokio::test] + async fn take_returns_unsupported_when_no_takeover_slot() { + let request = http::Request::new(http_body_util::Empty::::new()); + + let result = upgrade::take::(request).await; + assert!(matches!(result, Err(TakeoverError::Unsupported))); + } + + #[tokio::test] + async fn take_releases_empty_body_and_returns_stream() { + let request = { + let mut request = http::Request::new(http_body_util::Empty::::new()); + let stream = read_stream_for_test(VarInt::from_u32(55)); + request + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately(stream))); + request + }; + let mut read_stream = upgrade::take::(request) + .await + .unwrap(); + let stream_id = GetStreamIdExt::stream_id(&mut read_stream).await.unwrap(); + assert_eq!(stream_id, VarInt::from_u32(55)); + } + + #[tokio::test] + async fn upgrade_with_missing_read_or_write_stream_reports_incomplete() { + let request = { + let mut request = http::Request::new(http_body_util::Empty::::new()); + request + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + read_stream_for_test(VarInt::from_u32(11)), + ))); + request + }; + + assert!(matches!( + upgrade::on(request).await, + Err(UpgradeError::Incomplete { + missing: MissingStream::Write + }) + )); + + let request = { + let mut request = http::Request::new(http_body_util::Empty::::new()); + request + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + write_stream_for_test(VarInt::from_u32(22)), + ))); + request + }; + + assert!(matches!( + upgrade::on(request).await, + Err(UpgradeError::Incomplete { + missing: MissingStream::Read + }) + )); + + let request = http::Request::new(http_body_util::Empty::::new()); + assert!(matches!( + upgrade::on(request).await, + Err(UpgradeError::Incomplete { + missing: MissingStream::Both + }) + )); + } + + #[tokio::test] + async fn upgrade_requires_both_read_and_write_stream() { + let request = { + let mut request = http::Request::new(http_body_util::Empty::::new()); + request + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + read_stream_for_test(VarInt::from_u32(33)), + ))); + request + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + write_stream_for_test(VarInt::from_u32(44)), + ))); + request + }; + let (mut read_stream, mut write_stream) = upgrade::on(request).await.unwrap(); + + let read_id = GetStreamIdExt::stream_id(&mut read_stream).await.unwrap(); + let write_id = GetStreamIdExt::stream_id(&mut write_stream).await.unwrap(); + assert_eq!(read_id, VarInt::from_u32(33)); + assert_eq!(write_id, VarInt::from_u32(44)); + } + + #[tokio::test] + async fn upgrade_fails_when_stream_already_taken() { + let mut request = { + let mut request = http::Request::new(http_body_util::Empty::::new()); + request + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + read_stream_for_test(VarInt::from_u32(66)), + ))); + request + .extensions_mut() + .insert(TakeoverSlot::new(RemainStream::immediately( + write_stream_for_test(VarInt::from_u32(77)), + ))); + request + }; + + let _ = upgrade::take::(&mut request) + .await + .unwrap(); + + assert!(matches!( + upgrade::on(&mut request).await, + Err(UpgradeError::Takeover { + source: TakeoverError::AlreadyTaken + }) + )); + } } diff --git a/src/hyper/client.rs b/src/hyper/client.rs deleted file mode 100644 index f546033..0000000 --- a/src/hyper/client.rs +++ /dev/null @@ -1,150 +0,0 @@ -use std::error::Error; - -use bytes::Bytes; -use http_body::Body; -use http_body_util::{BodyExt, Empty}; -use snafu::ResultExt; -use tracing::Instrument; - -use crate::{ - connection::Connection, - hyper::SendMessageError, - message::stream::{ - InitialMessageStreamError, MessageStreamError, - hyper::{ - read::Either, - upgrade::{RemainStream, TakeoverSlot}, - write::send_message_error, - }, - }, - quic, -}; - -#[derive(Debug)] -pub enum RequestError { - InitialStream { source: InitialMessageStreamError }, - SendRequest { source: SendMessageError }, - ReceiveResponse { source: MessageStreamError }, -} - -impl From for RequestError { - #[track_caller] - fn from(source: InitialMessageStreamError) -> Self { - RequestError::InitialStream { source } - } -} - -impl From> for RequestError { - #[track_caller] - fn from(source: SendMessageError) -> Self { - RequestError::SendRequest { source } - } -} - -impl From for RequestError { - #[track_caller] - fn from(source: MessageStreamError) -> Self { - RequestError::ReceiveResponse { source } - } -} - -impl std::fmt::Display for RequestError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - #[allow(unused_variables)] - match self { - RequestError::InitialStream { source, .. } => source.fmt(f), - RequestError::SendRequest { source, .. } => source.fmt(f), - RequestError::ReceiveResponse { source, .. } => source.fmt(f), - } - } -} - -impl Error for RequestError { - fn source(&self) -> ::core::option::Option<&(dyn ::snafu::Error + 'static)> { - match *self { - RequestError::InitialStream { ref source, .. } => source.source(), - RequestError::SendRequest { ref source, .. } => source.source(), - RequestError::ReceiveResponse { ref source, .. } => source.source(), - } - } -} - -impl Connection { - #[tracing::instrument(level = "debug", skip_all, fields(method = %request.method(), uri = %request.uri()))] - pub async fn execute_hyper_request( - &self, - request: http::Request, - ) -> Result< - http::Response + use>, - RequestError, - > - where - B::Data: Send, - B::Error: Error + Send + 'static, - { - let (mut read_stream, mut write_stream) = self.initial_message_stream().await?; - let is_connect = request.method() == http::Method::CONNECT; - - if is_connect { - // CONNECT: no body or trailers — join send + receive headers. - let (parts, _body) = request.into_parts(); - let (send_result, recv_result) = - tokio::join!(write_stream.send_hyper_request_parts(parts), async { - loop { - let parts = read_stream.read_hyper_response_parts().await?; - if !parts.status.is_informational() { - return Ok::<_, MessageStreamError>(parts); - } - tracing::debug!( - status = %parts.status, - headers = ?parts.headers, - "skipping informational response", - ); - } - },); - send_result.context(send_message_error::StreamSnafu)?; - let mut response_parts = recv_result?; - - response_parts - .extensions - .insert(TakeoverSlot::new(RemainStream::immediately(read_stream))); - response_parts - .extensions - .insert(TakeoverSlot::new(RemainStream::immediately(write_stream))); - let body = Either::right(Empty::new().map_err(|never| match never {})); - Ok(http::Response::from_parts(response_parts, body)) - } else { - // Non-CONNECT: send headers, spawn body sender, read response. - let (parts, body) = request.into_parts(); - write_stream - .send_hyper_request_parts(parts) - .await - .context(send_message_error::StreamSnafu)?; - - // Spawn background task to send request body + close write stream. - // Guard ensures stream cleanup on failure. - tokio::spawn( - async move { - if write_stream.send_hyper_body(body).await.is_ok() { - _ = write_stream.close().await; - } - } - .in_current_span(), - ); - - // Read response headers, skipping informational. - let mut response_parts = read_stream.read_hyper_response_parts().await?; - while response_parts.status.is_informational() { - tracing::debug!( - status = %response_parts.status, - headers = ?response_parts.headers, - "skipping informational response", - ); - response_parts = read_stream.read_hyper_response_parts().await?; - } - - let body = Either::left(read_stream.into_hyper_body()); - Ok(http::Response::from_parts(response_parts, body)) - } - } -} diff --git a/src/hyper/server.rs b/src/hyper/server.rs deleted file mode 100644 index ac830ed..0000000 --- a/src/hyper/server.rs +++ /dev/null @@ -1,424 +0,0 @@ -use std::{ - error::Error, - task::{Context, Poll}, -}; - -use bytes::Bytes; -use futures::future::{self, BoxFuture}; -use http::Method; -use http_body::Body; -use http_body_util::{BodyExt, Empty, combinators::UnsyncBoxBody}; -use snafu::{Report, ResultExt, Snafu}; -use tracing::Instrument; - -use crate::{ - message::stream::{ - MessageStreamError, ReadStream, - hyper::{ - upgrade::{RemainStream, TakeoverSlot}, - write::SendMessageError, - }, - }, - server::{Request, Response, UnresolvedRequest}, -}; - -#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -#[repr(transparent)] -pub struct TowerService(pub S); - -impl crate::server::Service for TowerService -where - S: tower_service::Service< - http::Request>, - Response = http::Response, - Error: Error + Send, - Future: Send, - > + Clone - + Send - + 'static, - RespBody: Body + Send, -{ - type Future<'s> = BoxFuture<'s, ()>; - - fn serve<'s>(&self, req: &'s mut Request, resp: &'s mut Response) -> Self::Future<'s> { - let mut service = self.0.clone(); - Box::pin(async move { - if let Err(error) = future::poll_fn(|cx| service.poll_ready(cx)).await { - tracing::debug!(error = %Report::from_error(error), "service cannot be ready"); - return; - } - - let mut read_stream = Some(req.read_stream().take()); - let mut write_stream = resp.write_stream().take(); - resp.mark_taken_over(); - - let is_connect = req.method() == Method::CONNECT; - - let (remain_write_stream_tx, remain_write_stream) = RemainStream::pending(); - let request = if is_connect { - let remain_read_stream = RemainStream::immediately( - read_stream - .take() - .expect("connect request must have read stream"), - ); - http::Request::builder() - .method(req.method()) - .uri(req.uri()) - .extension(TakeoverSlot::new(remain_read_stream)) - .extension(TakeoverSlot::new(remain_write_stream.clone())) - .body(UnsyncBoxBody::new(Empty::new().map_err(|n| match n {}))) - } else { - http::Request::builder() - .method(req.method()) - .uri(req.uri()) - .body(UnsyncBoxBody::new( - read_stream - .take() - .expect("non-connect request must have read stream") - .into_hyper_body(), - )) - }; - - let mut request = match request { - Ok(request) => request, - Err(error) => { - tracing::warn!(error = %Report::from_error(error), "failed to convert request, skip serving"); - return; - } - }; - - *request.headers_mut() = req.headers().clone(); - request.extensions_mut().insert(resp.agent().clone()); - request.extensions_mut().insert(req.stream_id()); - request.extensions_mut().insert(req.protocols().clone()); - if let Some(remote_agent) = req.agent().cloned() { - request.extensions_mut().insert(remote_agent); - } - - match service.call(request).await { - Ok(response) => { - if let Err(error) = write_stream.send_hyper_response(response).await { - tracing::debug!(error = %Report::from_error(error), "failed to send response"); - } - if is_connect { - // Flush the buffered response so the client receives - // it before the stream is handed to the upgrade layer. - _ = write_stream.flush().await; - _ = remain_write_stream_tx.send(write_stream); - } else { - _ = write_stream.close().await; - } - } - Err(error) => { - tracing::debug!(error = %Report::from_error(error), "service failed") - } - } - }) - } -} - -#[derive(Debug, Snafu)] -#[snafu(module)] -pub enum HandleRequestError { - #[snafu(display("failed to handle message stream"))] - Stream { source: MessageStreamError }, - #[snafu(display("service error"))] - Service { source: S }, - #[snafu(display("response body error"))] - Body { source: B }, -} - -impl From> - for HandleRequestError -{ - fn from(source: SendMessageError) -> Self { - match source { - SendMessageError::Stream { source } => HandleRequestError::Stream { source }, - SendMessageError::Body { source } => HandleRequestError::Body { source }, - } - } -} - -impl tower_service::Service for TowerService -where - S: tower_service::Service< - http::Request>, - Response = http::Response, - Error = ServiceE, - Future: Send, - > + Clone - + Send - + 'static, - ServiceE: Error + 'static, - RespBody: Body + Send, -{ - type Response = (); - - type Error = HandleRequestError; - - type Future = BoxFuture<'static, Result>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.0 - .poll_ready(cx) - .map_err(|source| HandleRequestError::Service { source }) - } - - fn call( - &mut self, - UnresolvedRequest { - stream_id, - read_stream, - write_stream: mut response_stream, - connection, - }: UnresolvedRequest, - ) -> Self::Future { - let span = tracing::info_span!( - "handle_request", - method = tracing::field::Empty, - uri = tracing::field::Empty - ); - - let mut service = self.0.clone(); - let future = async move { - future::poll_fn(|cx| service.poll_ready(cx)) - .await - .context(handle_request_error::ServiceSnafu)?; - - let mut request = read_stream - .into_hyper_request() - .await - .context(handle_request_error::StreamSnafu)? - .map(UnsyncBoxBody::new); - tracing::Span::current() - .record("method", request.method().as_str()) - .record("uri", request.uri().to_string()); - - tracing::trace!("converted request stream to hyper request, serving..."); - let is_connect = request.method() == Method::CONNECT; - let (remain_write_stream_tx, remain_write_stream) = RemainStream::pending(); - - // Downstream handlers read the owning connection out of extensions - // and call `local_agent().await` / `remote_agent().await` / - // `protocols()` on demand. `stream_id` is request-scoped so it is - // still inserted alongside the connection. - request.extensions_mut().insert(stream_id); - request.extensions_mut().insert(connection); - if is_connect - && request - .extensions() - .get::>() - .is_some() - { - request - .extensions_mut() - .insert(TakeoverSlot::new(remain_write_stream.clone())); - } - - let response = service - .call(request) - .await - .context(handle_request_error::ServiceSnafu)?; - - response_stream.send_hyper_response(response).await?; - if is_connect { - response_stream - .flush() - .await - .context(handle_request_error::StreamSnafu)?; - _ = remain_write_stream_tx.send(response_stream); - } else { - response_stream - .close() - .await - .context(handle_request_error::StreamSnafu)?; - } - - Ok(()) - }; - Box::pin(future.instrument(span)) - } -} - -#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -#[repr(transparent)] -pub struct HyperService(pub S); - -impl crate::server::Service for HyperService -where - S: hyper::service::Service< - http::Request>, - Response = http::Response, - Error: Error + Send, - Future: Send, - > + Clone - + Send - + 'static, - RespBody: Body + Send, -{ - type Future<'s> = BoxFuture<'s, ()>; - - fn serve<'s>(&self, req: &'s mut Request, resp: &'s mut Response) -> Self::Future<'s> { - let service = self.0.clone(); - Box::pin(async move { - let mut read_stream = Some(req.read_stream().take()); - let mut write_stream = resp.write_stream().take(); - resp.mark_taken_over(); - - let is_connect = req.method() == Method::CONNECT; - - let (remain_write_stream_tx, remain_write_stream) = RemainStream::pending(); - let request = if is_connect { - let remain_read_stream = RemainStream::immediately( - read_stream - .take() - .expect("connect request must have read stream"), - ); - http::Request::builder() - .method(req.method()) - .uri(req.uri()) - .extension(TakeoverSlot::new(remain_read_stream)) - .extension(TakeoverSlot::new(remain_write_stream.clone())) - .body(UnsyncBoxBody::new(Empty::new().map_err(|n| match n {}))) - } else { - http::Request::builder() - .method(req.method()) - .uri(req.uri()) - .body(UnsyncBoxBody::new( - read_stream - .take() - .expect("non-connect request must have read stream") - .into_hyper_body(), - )) - }; - - let mut request = match request { - Ok(request) => request, - Err(error) => { - tracing::warn!(error = %Report::from_error(error), "failed to convert request, skip serving"); - return; - } - }; - - *request.headers_mut() = req.headers().clone(); - request.extensions_mut().insert(resp.agent().clone()); - request.extensions_mut().insert(req.stream_id()); - request.extensions_mut().insert(req.protocols().clone()); - if let Some(remote_agent) = req.agent().cloned() { - request.extensions_mut().insert(remote_agent); - } - - match service.call(request).await { - Ok(response) => { - if let Err(error) = write_stream.send_hyper_response(response).await { - tracing::debug!(error = %Report::from_error(error), "failed to send response"); - } - if is_connect { - _ = write_stream.flush().await; - _ = remain_write_stream_tx.send(write_stream); - } else { - _ = write_stream.close().await; - } - } - Err(error) => { - tracing::debug!(error = %Report::from_error(error), "service failed") - } - } - }) - } -} - -impl tower_service::Service for HyperService -where - S: hyper::service::Service< - http::Request>, - Response = http::Response, - Error = ServiceE, - Future: Send, - > + Clone - + Send - + 'static, - ServiceE: Error + 'static, - RespBody: Body + Send, -{ - type Response = (); - - type Error = HandleRequestError; - - type Future = BoxFuture<'static, Result>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call( - &mut self, - UnresolvedRequest { - stream_id, - read_stream, - write_stream: mut response_stream, - connection, - }: UnresolvedRequest, - ) -> Self::Future { - let span = tracing::info_span!( - "handle_request", - method = tracing::field::Empty, - uri = tracing::field::Empty - ); - - let service = self.0.clone(); - let future = async move { - let mut request = read_stream - .into_hyper_request() - .await - .context(handle_request_error::StreamSnafu)? - .map(UnsyncBoxBody::new); - tracing::Span::current() - .record("method", request.method().as_str()) - .record("uri", request.uri().to_string()); - - tracing::trace!("converted request stream to hyper request, serving..."); - let is_connect = request.method() == Method::CONNECT; - let (remain_write_stream_tx, remain_write_stream) = RemainStream::pending(); - - // Downstream handlers read the owning connection out of extensions - // and call `local_agent().await` / `remote_agent().await` / - // `protocols()` on demand. `stream_id` is request-scoped so it is - // still inserted alongside the connection. - request.extensions_mut().insert(stream_id); - request.extensions_mut().insert(connection); - if is_connect - && request - .extensions() - .get::>() - .is_some() - { - request - .extensions_mut() - .insert(TakeoverSlot::new(remain_write_stream.clone())); - } - - let response = service - .call(request) - .await - .context(handle_request_error::ServiceSnafu)?; - - response_stream.send_hyper_response(response).await?; - if is_connect { - response_stream - .flush() - .await - .context(handle_request_error::StreamSnafu)?; - _ = remain_write_stream_tx.send(response_stream); - } else { - response_stream - .close() - .await - .context(handle_request_error::StreamSnafu)?; - } - - Ok(()) - }; - Box::pin(future.instrument(span)) - } -} diff --git a/src/ipc/error.rs b/src/ipc/error.rs index 0e14f6d..630be2e 100644 --- a/src/ipc/error.rs +++ b/src/ipc/error.rs @@ -80,3 +80,128 @@ impl From for IpcAcceptError { IpcPlumbingError::Rpc { source: error }.into() } } + +#[cfg(test)] +mod tests { + use std::error::Error as _; + + use super::*; + use crate::{error::Code, quic::ApplicationError, varint::VarInt}; + + #[test] + fn plumbing_io_error_displays_message() { + let error = IpcPlumbingError::Io { + message: "socketpair failed".to_owned(), + }; + + assert_eq!(error.to_string(), "socketpair failed"); + } + + #[test] + fn plumbing_io_error_preserves_message_and_has_no_source() { + let error = IpcPlumbingError::Io { + message: "socketpair failed".to_owned(), + }; + + let IpcPlumbingError::Io { message } = &error else { + panic!("expected io plumbing error"); + }; + assert_eq!(message, "socketpair failed"); + assert!(error.source().is_none()); + } + + #[test] + fn plumbing_rpc_error_displays_like_call_error() { + let source = remoc::rtc::CallError::Dropped; + let expected = source.to_string(); + let error = IpcPlumbingError::Rpc { source }; + + assert_eq!(error.to_string(), expected); + } + + #[test] + fn plumbing_error_converts_to_open_and_accept_errors() { + let open_error = IpcOpenError::from(IpcPlumbingError::Io { + message: "open socketpair failed".to_owned(), + }); + let IpcOpenError::Plumbing { source } = open_error else { + panic!("plumbing error should stay plumbing-scoped on open"); + }; + assert_eq!(source.to_string(), "open socketpair failed"); + + let accept_error = IpcAcceptError::from(IpcPlumbingError::Io { + message: "accept socketpair failed".to_owned(), + }); + let IpcAcceptError::Plumbing { source } = accept_error else { + panic!("plumbing error should stay plumbing-scoped on accept"); + }; + assert_eq!(source.to_string(), "accept socketpair failed"); + } + + #[test] + fn open_and_accept_io_plumbing_errors_display_inner_message() { + let open_error = IpcOpenError::from(IpcPlumbingError::Io { + message: "open socketpair failed".to_owned(), + }); + assert_eq!(open_error.to_string(), "open socketpair failed"); + + let accept_error = IpcAcceptError::from(IpcPlumbingError::Io { + message: "accept socketpair failed".to_owned(), + }); + assert_eq!(accept_error.to_string(), "accept socketpair failed"); + } + + #[test] + fn call_error_converts_to_open_plumbing_error() { + let error = IpcOpenError::from(remoc::rtc::CallError::Dropped); + let IpcOpenError::Plumbing { + source: IpcPlumbingError::Rpc { source }, + } = error + else { + panic!("call error should map to open plumbing error"); + }; + + assert!(matches!(source, remoc::rtc::CallError::Dropped)); + } + + #[test] + fn call_error_converts_to_accept_plumbing_error() { + let error = IpcAcceptError::from(remoc::rtc::CallError::Dropped); + let IpcAcceptError::Plumbing { + source: IpcPlumbingError::Rpc { source }, + } = error + else { + panic!("call error should map to accept plumbing error"); + }; + + assert!(matches!(source, remoc::rtc::CallError::Dropped)); + } + + #[test] + fn connection_error_stays_connection_scoped_on_open_and_accept() { + let source = ConnectionError::Application { + source: ApplicationError { + code: Code::new(VarInt::from_u32(7)), + reason: "application closed".into(), + }, + }; + + let open_error = IpcOpenError::from(source.clone()); + let IpcOpenError::Connection { + source: open_source, + } = open_error + else { + panic!("connection error should stay connection-scoped on open"); + }; + assert!(open_source.is_application()); + + let accept_error = IpcAcceptError::from(source); + let IpcAcceptError::Connection { + source: accept_source, + } = accept_error + else { + panic!("connection error should stay connection-scoped on accept"); + }; + assert!(accept_source.is_application()); + } +} diff --git a/src/ipc/quic.rs b/src/ipc/quic.rs index 3c6bb73..6a62bdd 100644 --- a/src/ipc/quic.rs +++ b/src/ipc/quic.rs @@ -24,27 +24,22 @@ //! - [`IpcConnector`] — wraps [`IpcConnectClient`], implements `quic::Connect` //! - [`IpcListener`] — wraps [`IpcListenClient`], implements `quic::Listen` //! -//! ## Stream types (IPC read/write stream backed by Unix socketpairs) -//! -//! - [`IpcReadStream`] — implements `quic::ReadStream` -//! - [`IpcWriteStream`] — implements `quic::WriteStream` -//! //! ## Bootstrap //! //! - [`ConnectionBootstrap`] — one-shot value sent when a connection is established //! //! ## Bridge helpers //! -//! - [`bridge_reader`] — forward QUIC ReadStream → IpcWriteStream -//! - [`bridge_writer`] — forward IpcReadStream → QUIC WriteStream +//! - [`bridge_reader`] — run a QUIC read stream through IPC read frame IO +//! - [`bridge_writer`] — run IPC write frame IO into a QUIC write stream pub(crate) mod connection; pub(crate) mod connector; pub(crate) mod listener; -mod stream; +pub(crate) mod stream; #[cfg(test)] -mod tests; +mod test_utils; pub use self::{ connection::{ @@ -60,5 +55,4 @@ pub use self::{ IpcListenClient, IpcListenError, IpcListenReqReceiver, IpcListenServer, IpcListenServerSharedMut, IpcListener, ListenAdapter, }, - stream::{IpcReadStream, IpcWriteStream}, }; diff --git a/src/ipc/quic/connection.rs b/src/ipc/quic/connection.rs index aa2c26a..9c964b3 100644 --- a/src/ipc/quic/connection.rs +++ b/src/ipc/quic/connection.rs @@ -3,10 +3,9 @@ //! # RTC trait //! //! [`IpcConnection`] defines the RPC interface for connection-level stream -//! management over IPC. Each stream-opening method returns a `(VarInt, VarInt)` -//! pair: `(fd_registry_id, stream_id)`. The client retrieves the corresponding -//! socketpair FD(s) via [`FdRegistry::wait_fds`] and wraps them as -//! [`IpcReadStream`] / [`IpcWriteStream`]. +//! management over IPC. Each stream-opening method receives a caller-chosen +//! FD transfer ID and returns the underlying QUIC stream ID after the FD +//! delivery has been queued to the local mux writer FIFO. //! //! # Server side //! @@ -14,50 +13,59 @@ //! [`IpcConnection`]. Each `open_bi` / `accept_bi` call: //! 1. Opens a real QUIC stream pair via the inner connection. //! 2. Creates 2 Unix socketpairs (one per direction). -//! 3. Queues the client-side FDs through the [`FdSender`]. -//! 4. Spawns bridge tasks that forward data between the real QUIC -//! streams and local [`IpcWriteStream`] / [`IpcReadStream`] endpoints. -//! 5. Returns `(fd_registry_id, stream_id)` over RPC. +//! 3. Delivers the client-side FDs through the connection [`FdTransfer`]. +//! 4. Spawns bridge tasks that execute typed stream-frame IPC against the real +//! QUIC streams. +//! 5. Returns the stream ID over RPC. //! //! # Client side //! //! [`IpcConnectionHandle`] wraps an [`IpcConnectionClient`] and implements //! [`quic::Connection`]. Each `open_bi` call: -//! 1. Calls the RPC method to get `(fd_registry_id, stream_id)`. -//! 2. Retrieves the socketpair FDs from the [`FdRegistry`]. -//! 3. Wraps them as [`IpcReadStream`] / [`IpcWriteStream`]. +//! 1. Reserves a receiver-chosen FD transfer ID. +//! 2. Calls the RPC method with that ID while concurrently receiving the FDs. +//! 3. Wraps them as boxed QUIC stream handles backed by typed IPC frame IO. //! //! # Bootstrap //! //! [`ConnectionBootstrap`] is the one-shot value sent over the remoc base //! channel when a new connection is established. It carries the -//! [`IpcConnectionClient`] for stream management and agent access. +//! [`IpcConnectionClient`] for stream management and authority access. -use std::{borrow::Cow, io, sync::Arc}; +use std::{ + borrow::Cow, + future::Future, + io, + sync::{Arc, Mutex}, +}; -use futures::{SinkExt, StreamExt}; use remoc::prelude::ServerShared; use serde::{Deserialize, Serialize}; use smallvec::smallvec; use tokio::net::UnixStream; +use tokio_util::task::AbortOnDropHandle; use tracing::{Instrument, debug}; use crate::{ error::Code, ipc::{ error::{IpcAcceptError, IpcOpenError, IpcPlumbingError}, - quic::stream::{IpcBiHandle, IpcReadStream, IpcUniHandle, IpcWriteStream}, - transport::{FdRegistry, FdSender}, + quic::stream::{ + IpcBiHandle, IpcUniHandle, + reader::{self as ipc_reader, IpcReadHypervisorIo}, + writer::{self as ipc_writer, IpcWriteHypervisorIo}, + }, + transport::{FdDelivery, FdTransfer, ReceivedFds}, }, quic::{ - self, ConnectionError, DynLifecycle, GetStreamIdExt, ManageStream, ReadStream, StreamError, - WriteStream, + self, BoxQuicStreamReader, BoxQuicStreamWriter, ConnectionError, GetStreamIdExt, + ManageStream, ReadStream, ResetStreamExt, StopStreamExt, StreamError, WriteStream, }, rpc::{ lifecycle::{ConnectionErrorLatch, HasLatch, LifecycleExt}, quic::{ - CachedLocalAgent, CachedRemoteAgent, LocalAgentClient, LocalAgentServerShared, - RemoteAgentClient, RemoteAgentServerShared, + CachedLocalAuthority, CachedRemoteAuthority, LocalAuthorityClient, + LocalAuthorityServerShared, RemoteAuthorityClient, RemoteAuthorityServerShared, }, }, util::deferred::Resolved, @@ -70,10 +78,10 @@ use crate::{ /// Remote trait for IPC connection-level stream management. /// -/// Each stream-opening method returns a `(VarInt, VarInt)` pair: -/// - The first element is the FD-registry ID for -/// [`FdRegistry::wait_fds`](crate::ipc::transport::FdRegistry::wait_fds). -/// - The second element is the underlying QUIC stream ID. +/// Each stream-opening method receives a receiver-chosen FD transfer ID and +/// returns a handle carrying the underlying QUIC stream ID. Server-side +/// implementations must not return until the FD delivery has been +/// queued to the local mux writer FIFO. /// /// Agent and lifecycle methods mirror [`crate::rpc::quic::Connection`] and /// are forwarded over the same remoc channel — only bulk stream data travels @@ -81,12 +89,11 @@ use crate::{ /// /// # FD semantics /// -/// - **`open_bi` / `accept_bi`**: 2 FDs are queued (2 independent socketpairs). -/// The first FD carries the reader-side pipe (server's IpcWriteStream ↔ -/// client's IpcReadStream), the second carries the writer-side pipe -/// (server's IpcReadStream ↔ client's IpcWriteStream). +/// - **`open_bi` / `accept_bi`**: 2 FDs are delivered (2 independent socketpairs). +/// The first FD carries read-side frame IO; the second carries write-side +/// frame IO. /// -/// - **`open_uni` / `accept_uni`**: 1 FD is queued (a single socketpair). +/// - **`open_uni` / `accept_uni`**: 1 FD is delivered (a single socketpair). #[remoc::rtc::remote] pub trait IpcConnection: Send + Sync { /// Open a bidirectional stream. @@ -94,28 +101,40 @@ pub trait IpcConnection: Send + Sync { /// Returns `Resolved` on success — the handle /// may carry a stream error (e.g. stream_id retrieval failed) wrapped in /// [`Resolved::Error`]. Infrastructure failures yield [`IpcOpenError`]. - async fn open_bi(&self) -> Result, IpcOpenError>; + async fn open_bi( + &self, + fd_id: VarInt, + ) -> Result, IpcOpenError>; /// Accept an incoming bidirectional stream. /// /// Returns `Resolved` on success. - async fn accept_bi(&self) -> Result, IpcAcceptError>; + async fn accept_bi( + &self, + fd_id: VarInt, + ) -> Result, IpcAcceptError>; /// Open a unidirectional (send-only) stream. /// /// Returns `Resolved` on success. - async fn open_uni(&self) -> Result, IpcOpenError>; + async fn open_uni( + &self, + fd_id: VarInt, + ) -> Result, IpcOpenError>; /// Accept an incoming unidirectional (receive-only) stream. /// /// Returns `Resolved` on success. - async fn accept_uni(&self) -> Result, IpcAcceptError>; + async fn accept_uni( + &self, + fd_id: VarInt, + ) -> Result, IpcAcceptError>; - /// Obtain the local agent (signing / identity) handle, if available. - async fn local_agent(&self) -> Result, ConnectionError>; + /// Obtain the local authority (signing / identity) handle, if available. + async fn local_authority(&self) -> Result, ConnectionError>; - /// Obtain the remote agent (verification) handle, if available. - async fn remote_agent(&self) -> Result, ConnectionError>; + /// Obtain the remote authority (verification) handle, if available. + async fn remote_authority(&self) -> Result, ConnectionError>; /// Close the connection with an application error code and reason. async fn close(&self, code: Code, reason: Cow<'static, str>) -> Result<(), ConnectionError>; @@ -132,11 +151,11 @@ pub trait IpcConnection: Send + Sync { /// IPC connection is established. /// /// Carries the [`IpcConnectionClient`] which provides both stream management -/// RPC and agent access — agent methods are part of the [`IpcConnection`] +/// RPC and authority access — authority methods are part of the [`IpcConnection`] /// trait alongside stream operations. #[derive(Serialize, Deserialize)] pub struct ConnectionBootstrap { - /// RPC client for stream operations, agent access, and lifecycle. + /// RPC client for stream operations, authority access, and lifecycle. pub connection: IpcConnectionClient, } @@ -144,36 +163,30 @@ pub struct ConnectionBootstrap { // Bridge helpers: forward data between real QUIC streams and pipe socketpairs // --------------------------------------------------------------------------- -/// Forward data from a QUIC [`ReadStream`] to a [`IpcWriteStream`] (server side). +/// Execute a QUIC [`ReadStream`] through hypervisor-side IPC read frame IO. /// -/// Reads chunks from the QUIC stream and sinks them into the pipe. -/// Terminates when either side closes or errors. -pub async fn bridge_reader( - mut quic_reader: impl ReadStream + Unpin, - mut pipe_writer: IpcWriteStream, -) { - while let Some(Ok(chunk)) = quic_reader.next().await { - if pipe_writer.send(chunk).await.is_err() { - break; - } - } - let _ = pipe_writer.close().await; +/// Worker pipe EOF is stream-handle abandonment on this side. Already received +/// commands are drained by the shared hypervisor bridge; no EOF-derived QUIC +/// control is synthesized. +pub async fn bridge_reader(quic_reader: impl ReadStream + Unpin, pipe: UnixStream) { + crate::rpc::stream::hypervisor::read::run_read_bridge( + quic_reader, + IpcReadHypervisorIo::new(pipe), + ) + .await; } -/// Forward data from a [`IpcReadStream`] to a QUIC [`WriteStream`] (server side). +/// Execute hypervisor-side IPC write frame IO against a QUIC [`WriteStream`]. /// -/// Reads chunks from the pipe and sinks them into the QUIC stream. -/// Terminates when either side closes or errors. -pub async fn bridge_writer( - mut pipe_reader: IpcReadStream, - mut quic_writer: impl WriteStream + Unpin, -) { - while let Some(Ok(chunk)) = pipe_reader.next().await { - if quic_writer.send(chunk).await.is_err() { - break; - } - } - let _ = quic_writer.close().await; +/// Worker pipe EOF is stream-handle abandonment on this side. Already received +/// commands are drained by the shared hypervisor bridge; no EOF-derived FIN or +/// RESET_STREAM is synthesized. +pub async fn bridge_writer(pipe: UnixStream, quic_writer: impl WriteStream + Unpin) { + crate::rpc::stream::hypervisor::write::run_write_bridge( + quic_writer, + IpcWriteHypervisorIo::new(pipe), + ) + .await; } // --------------------------------------------------------------------------- @@ -184,16 +197,31 @@ pub async fn bridge_writer( /// [`IpcConnection`]. /// /// The inner connection is shared via `Arc` so that: -/// - Bridge tasks can reference it as `Arc`. -/// - The [`FdSender`] queues client-side FDs for each new stream. +/// - Bridge tasks execute stream-frame IPC against concrete QUIC streams. +/// - The [`FdTransfer`] delivers client-side FDs for each new stream. pub struct ConnectionAdapter { inner: Arc, - fd_sender: FdSender, + fd_transfer: FdTransfer, + tasks: Mutex>>, } impl ConnectionAdapter { - pub fn new(inner: Arc, fd_sender: FdSender) -> Self { - Self { inner, fd_sender } + pub fn new(inner: Arc, fd_transfer: FdTransfer) -> Self { + Self { + inner, + fd_transfer, + tasks: Mutex::new(Vec::new()), + } + } + + fn spawn_task(&self, task: impl Future + Send + 'static) { + let handle = AbortOnDropHandle::new(tokio::spawn(task.in_current_span())); + let mut tasks = self + .tasks + .lock() + .expect("connection adapter task registry should not be poisoned"); + tasks.retain(|task| !task.is_finished()); + tasks.push(handle); } } @@ -201,58 +229,64 @@ impl IpcConnection for ConnectionAdapter where M: ManageStream + quic::Lifecycle - + quic::WithLocalAgent - + quic::WithRemoteAgent + + quic::WithLocalAuthority + + quic::WithRemoteAuthority + Send + Sync + 'static, M::StreamReader: Unpin + 'static, M::StreamWriter: Unpin + 'static, - M::LocalAgent: Send + Sync, - M::RemoteAgent: Send + Sync, + M::LocalAuthority: Send + Sync, + M::RemoteAuthority: Send + Sync, { - async fn open_bi(&self) -> Result, IpcOpenError> { - self.open_bi_impl().await + async fn open_bi( + &self, + fd_id: VarInt, + ) -> Result, IpcOpenError> { + self.open_bi_impl(fd_id).await } - async fn accept_bi(&self) -> Result, IpcAcceptError> { - self.accept_bi_impl().await + async fn accept_bi( + &self, + fd_id: VarInt, + ) -> Result, IpcAcceptError> { + self.accept_bi_impl(fd_id).await } - async fn open_uni(&self) -> Result, IpcOpenError> { - self.open_uni_impl().await + async fn open_uni( + &self, + fd_id: VarInt, + ) -> Result, IpcOpenError> { + self.open_uni_impl(fd_id).await } - async fn accept_uni(&self) -> Result, IpcAcceptError> { - self.accept_uni_impl().await + async fn accept_uni( + &self, + fd_id: VarInt, + ) -> Result, IpcAcceptError> { + self.accept_uni_impl(fd_id).await } - async fn local_agent(&self) -> Result, ConnectionError> { - match quic::WithLocalAgent::local_agent(self.inner.as_ref()).await? { + async fn local_authority(&self) -> Result, ConnectionError> { + match quic::WithLocalAuthority::local_authority(self.inner.as_ref()).await? { Some(agent) => { - let (server, client) = LocalAgentServerShared::new(Arc::new(agent), 1); - tokio::spawn( - (async move { - let _ = server.serve(true).await; - }) - .in_current_span(), - ); + let (server, client) = LocalAuthorityServerShared::new(Arc::new(agent), 1); + self.spawn_task(async move { + let _ = server.serve(true).await; + }); Ok(Some(client)) } None => Ok(None), } } - async fn remote_agent(&self) -> Result, ConnectionError> { - match quic::WithRemoteAgent::remote_agent(self.inner.as_ref()).await? { + async fn remote_authority(&self) -> Result, ConnectionError> { + match quic::WithRemoteAuthority::remote_authority(self.inner.as_ref()).await? { Some(agent) => { - let (server, client) = RemoteAgentServerShared::new(Arc::new(agent), 1); - tokio::spawn( - (async move { - let _ = server.serve(true).await; - }) - .in_current_span(), - ); + let (server, client) = RemoteAuthorityServerShared::new(Arc::new(agent), 1); + self.spawn_task(async move { + let _ = server.serve(true).await; + }); Ok(Some(client)) } None => Ok(None), @@ -273,54 +307,59 @@ impl ConnectionAdapter where M: ManageStream + quic::Lifecycle - + quic::WithLocalAgent - + quic::WithRemoteAgent + + quic::WithLocalAuthority + + quic::WithRemoteAuthority + Send + Sync + 'static, M::StreamReader: Unpin + 'static, M::StreamWriter: Unpin + 'static, - M::LocalAgent: Send + Sync, - M::RemoteAgent: Send + Sync, + M::LocalAuthority: Send + Sync, + M::RemoteAuthority: Send + Sync, { - async fn open_bi_impl(&self) -> Result, IpcOpenError> { - let (mut reader, writer) = ManageStream::open_bi(self.inner.as_ref()) - .await - .map_err(|e| IpcOpenError::Connection { source: e })?; + async fn open_bi_impl( + &self, + fd_id: VarInt, + ) -> Result, IpcOpenError> { + let delivery = self.fd_transfer.delivery(fd_id); + let (mut reader, writer) = ManageStream::open_bi(self.inner.as_ref()).await?; let stream_id = match reader.stream_id().await { Ok(id) => id, Err(stream_err) => return Ok(Resolved::err(stream_err)), }; - self.bridge_bi(reader, writer, stream_id) + self.bridge_bi(delivery, reader, writer, stream_id) + .await .map(Resolved::ok) .map_err(IpcOpenError::from) } - async fn accept_bi_impl(&self) -> Result, IpcAcceptError> { - let (mut reader, writer) = ManageStream::accept_bi(self.inner.as_ref()) - .await - .map_err(|e| IpcAcceptError::Connection { source: e })?; + async fn accept_bi_impl( + &self, + fd_id: VarInt, + ) -> Result, IpcAcceptError> { + let delivery = self.fd_transfer.delivery(fd_id); + let (mut reader, writer) = ManageStream::accept_bi(self.inner.as_ref()).await?; let stream_id = match reader.stream_id().await { Ok(id) => id, Err(stream_err) => return Ok(Resolved::err(stream_err)), }; - self.bridge_bi(reader, writer, stream_id) + self.bridge_bi(delivery, reader, writer, stream_id) + .await .map(Resolved::ok) .map_err(IpcAcceptError::from) } /// Shared logic for open_bi / accept_bi after obtaining the real streams. - fn bridge_bi( + async fn bridge_bi( &self, - reader: M::StreamReader, - writer: M::StreamWriter, + delivery: FdDelivery, + mut reader: M::StreamReader, + mut writer: M::StreamWriter, stream_id: VarInt, ) -> Result { - let lifecycle: Arc = self.inner.clone(); - - // Socketpair for the reader direction: server IpcWriteStream ↔ client IpcReadStream + // Socketpair for the reader direction: server read hypervisor ↔ client read bridge let (srv_a, cli_a) = UnixStream::pair().map_err(|e| ipc_io_plumbing(e, "socketpair"))?; - // Socketpair for the writer direction: server IpcReadStream ↔ client IpcWriteStream + // Socketpair for the writer direction: server write hypervisor ↔ client write bridge let (srv_b, cli_b) = UnixStream::pair().map_err(|e| ipc_io_plumbing(e, "socketpair"))?; let cli_a_std = cli_a @@ -330,68 +369,69 @@ where .into_std() .map_err(|e| ipc_io_plumbing(e, "into_std"))?; - let fd_id = self - .fd_sender - .queue_fds(smallvec![cli_a_std.into(), cli_b_std.into()]) - .map_err(|e| ipc_io_plumbing(e, "queue_fds"))?; + if let Err(error) = delivery + .deliver(smallvec![cli_a_std.into(), cli_b_std.into()]) + .await + { + let code = Code::H3_REQUEST_CANCELLED.into_inner(); + let _ = reader.stop(code).await; + let _ = writer.reset(code).await; + return Err(ipc_io_plumbing(error, "deliver fds")); + } - // Bridge reader direction: real QUIC reader → IpcWriteStream on srv_a - let pipe_w = IpcWriteStream::new(stream_id, srv_a, lifecycle.clone()); - tokio::spawn(bridge_reader(reader, pipe_w).in_current_span()); + // Bridge reader direction: real QUIC reader ↔ read frame IO on srv_a. + self.spawn_task(bridge_reader(reader, srv_a)); - // Bridge writer direction: IpcReadStream on srv_b → real QUIC writer - let pipe_r = IpcReadStream::new(stream_id, srv_b, lifecycle); - tokio::spawn(bridge_writer(pipe_r, writer).in_current_span()); + // Bridge writer direction: write frame IO on srv_b ↔ real QUIC writer. + self.spawn_task(bridge_writer(srv_b, writer)); - Ok(IpcBiHandle { fd_id, stream_id }) + Ok(IpcBiHandle { stream_id }) } - async fn open_uni_impl(&self) -> Result, IpcOpenError> { - let mut writer = ManageStream::open_uni(self.inner.as_ref()) - .await - .map_err(|e| IpcOpenError::Connection { source: e })?; + async fn open_uni_impl( + &self, + fd_id: VarInt, + ) -> Result, IpcOpenError> { + let delivery = self.fd_transfer.delivery(fd_id); + let mut writer = ManageStream::open_uni(self.inner.as_ref()).await?; let stream_id = match writer.stream_id().await { Ok(id) => id, Err(stream_err) => return Ok(Resolved::err(stream_err)), }; - let lifecycle: Arc = self.inner.clone(); - let (srv, cli) = UnixStream::pair().map_err(|e| ipc_io_plumbing(e, "socketpair"))?; let cli_std = cli.into_std().map_err(|e| ipc_io_plumbing(e, "into_std"))?; - let fd_id = self - .fd_sender - .queue_fds(smallvec![cli_std.into()]) - .map_err(|e| ipc_io_plumbing(e, "queue_fds"))?; + if let Err(error) = delivery.deliver(smallvec![cli_std.into()]).await { + let _ = writer.reset(Code::H3_REQUEST_CANCELLED.into_inner()).await; + return Err(IpcOpenError::from(ipc_io_plumbing(error, "deliver fds"))); + } - // IpcReadStream on srv → real QUIC writer - let pipe_r = IpcReadStream::new(stream_id, srv, lifecycle); - tokio::spawn(bridge_writer(pipe_r, writer).in_current_span()); + // Write frame IO on srv ↔ real QUIC writer. + self.spawn_task(bridge_writer(srv, writer)); - Ok(Resolved::ok(IpcUniHandle { fd_id, stream_id })) + Ok(Resolved::ok(IpcUniHandle { stream_id })) } - async fn accept_uni_impl(&self) -> Result, IpcAcceptError> { - let mut reader = ManageStream::accept_uni(self.inner.as_ref()) - .await - .map_err(|e| IpcAcceptError::Connection { source: e })?; + async fn accept_uni_impl( + &self, + fd_id: VarInt, + ) -> Result, IpcAcceptError> { + let delivery = self.fd_transfer.delivery(fd_id); + let mut reader = ManageStream::accept_uni(self.inner.as_ref()).await?; let stream_id = match reader.stream_id().await { Ok(id) => id, Err(stream_err) => return Ok(Resolved::err(stream_err)), }; - let lifecycle: Arc = self.inner.clone(); - let (srv, cli) = UnixStream::pair().map_err(|e| ipc_io_plumbing(e, "socketpair"))?; let cli_std = cli.into_std().map_err(|e| ipc_io_plumbing(e, "into_std"))?; - let fd_id = self - .fd_sender - .queue_fds(smallvec![cli_std.into()]) - .map_err(|e| ipc_io_plumbing(e, "queue_fds"))?; + if let Err(error) = delivery.deliver(smallvec![cli_std.into()]).await { + let _ = reader.stop(Code::H3_REQUEST_CANCELLED.into_inner()).await; + return Err(IpcAcceptError::from(ipc_io_plumbing(error, "deliver fds"))); + } - // Real QUIC reader → IpcWriteStream on srv - let pipe_w = IpcWriteStream::new(stream_id, srv, lifecycle); - tokio::spawn(bridge_reader(reader, pipe_w).in_current_span()); + // Real QUIC reader ↔ read frame IO on srv. + self.spawn_task(bridge_reader(reader, srv)); - Ok(Resolved::ok(IpcUniHandle { fd_id, stream_id })) + Ok(Resolved::ok(IpcUniHandle { stream_id })) } } @@ -406,10 +446,8 @@ where /// remoc base channel of a connection-level [`MuxChannel`]. pub struct IpcConnectionHandle { rpc: IpcConnectionClient, - fd_registry: FdRegistry, - /// Retained for potential future use (e.g. client-initiated accept_bi - /// where the client creates socketpairs and sends FDs back to the server). - _conn_fd_sender: FdSender, + fd_transfer: FdTransfer, + _remoc_task: AbortOnDropHandle<()>, lifecycle: Arc, } @@ -418,10 +456,11 @@ pub struct IpcConnectionHandle { /// Owns the latch that enforces first-wins error semantics across every /// operation on the connection and its descendant streams. Implements /// [`quic::Lifecycle`] so the shared `Arc` can be handed to -/// [`IpcReadStream`] / [`IpcWriteStream`] as `Arc`. +/// direct IPC stream-frame bridge handles. struct IpcLifecycle { connection: IpcConnectionClient, latch: ConnectionErrorLatch, + close_tasks: Mutex>>, } impl HasLatch for IpcLifecycle { @@ -433,12 +472,18 @@ impl HasLatch for IpcLifecycle { impl quic::Lifecycle for IpcLifecycle { fn close(&self, code: Code, reason: Cow<'static, str>) { let rpc = self.connection.clone(); - tokio::spawn( - async move { + let handle = AbortOnDropHandle::new(tokio::spawn( + (async move { let _ = IpcConnection::close(&rpc, code, reason).await; - } + }) .in_current_span(), - ); + )); + let mut tasks = self + .close_tasks + .lock() + .expect("ipc lifecycle close task registry should not be poisoned"); + tasks.retain(|task| !task.is_finished()); + tasks.push(handle); } fn check(&self) -> Result<(), ConnectionError> { @@ -462,17 +507,18 @@ impl IpcConnectionHandle { /// Create a new handle from bootstrap data and FD registry. pub fn new( rpc: IpcConnectionClient, - fd_registry: FdRegistry, - conn_fd_sender: FdSender, + fd_transfer: FdTransfer, + remoc_task: AbortOnDropHandle<()>, ) -> Self { let lifecycle = Arc::new(IpcLifecycle { connection: rpc.clone(), latch: ConnectionErrorLatch::new(), + close_tasks: Mutex::new(Vec::new()), }); Self { rpc, - fd_registry, - _conn_fd_sender: conn_fd_sender, + fd_transfer, + _remoc_task: remoc_task, lifecycle, } } @@ -490,17 +536,15 @@ impl IpcConnectionHandle { } impl quic::ManageStream for IpcConnectionHandle { - type StreamReader = Resolved; - type StreamWriter = Resolved; + type StreamReader = Resolved; + type StreamWriter = Resolved; async fn open_bi(&self) -> Result<(Self::StreamReader, Self::StreamWriter), ConnectionError> { - let resolved = self - .lifecycle - .guard_with(IpcConnection::open_bi(&self.rpc), map_open_err) - .await?; + let (resolved, received) = self.open_bi_with_fds().await?; match resolved { Resolved::Value { value: handle } => { - let (r, w) = self.fds_to_bi(handle).await?; + let received = received.expect("value response must include received fds"); + let (r, w) = self.fds_to_bi(handle, received).await?; Ok((Resolved::ok(r), Resolved::ok(w))) } Resolved::Error { error } => Ok((Resolved::err(error.clone()), Resolved::err(error))), @@ -508,13 +552,11 @@ impl quic::ManageStream for IpcConnectionHandle { } async fn accept_bi(&self) -> Result<(Self::StreamReader, Self::StreamWriter), ConnectionError> { - let resolved = self - .lifecycle - .guard_with(IpcConnection::accept_bi(&self.rpc), map_accept_err) - .await?; + let (resolved, received) = self.accept_bi_with_fds().await?; match resolved { Resolved::Value { value: handle } => { - let (r, w) = self.fds_to_bi(handle).await?; + let received = received.expect("value response must include received fds"); + let (r, w) = self.fds_to_bi(handle, received).await?; Ok((Resolved::ok(r), Resolved::ok(w))) } Resolved::Error { error } => Ok((Resolved::err(error.clone()), Resolved::err(error))), @@ -522,13 +564,11 @@ impl quic::ManageStream for IpcConnectionHandle { } async fn open_uni(&self) -> Result { - let resolved = self - .lifecycle - .guard_with(IpcConnection::open_uni(&self.rpc), map_open_err) - .await?; + let (resolved, received) = self.open_uni_with_fds().await?; match resolved { Resolved::Value { value: handle } => { - let w = self.fds_to_uni_writer(handle).await?; + let received = received.expect("value response must include received fds"); + let w = self.fds_to_uni_writer(handle, received).await?; Ok(Resolved::ok(w)) } Resolved::Error { error } => Ok(Resolved::err(error)), @@ -536,13 +576,11 @@ impl quic::ManageStream for IpcConnectionHandle { } async fn accept_uni(&self) -> Result { - let resolved = self - .lifecycle - .guard_with(IpcConnection::accept_uni(&self.rpc), map_accept_err) - .await?; + let (resolved, received) = self.accept_uni_with_fds().await?; match resolved { Resolved::Value { value: handle } => { - let r = self.fds_to_uni_reader(handle).await?; + let received = received.expect("value response must include received fds"); + let r = self.fds_to_uni_reader(handle, received).await?; Ok(Resolved::ok(r)) } Resolved::Error { error } => Ok(Resolved::err(error)), @@ -551,121 +589,187 @@ impl quic::ManageStream for IpcConnectionHandle { } impl IpcConnectionHandle { - /// Retrieve 2 FDs and construct a (IpcReadStream, IpcWriteStream) pair. + async fn open_bi_with_fds( + &self, + ) -> Result<(Resolved, Option), ConnectionError> { + let receiver = self.fd_transfer.receive(); + let fd_id = receiver.id(); + self.lifecycle + .guard(self.resolve_open_with_fds(receiver, IpcConnection::open_bi(&self.rpc, fd_id))) + .await + } + + async fn accept_bi_with_fds( + &self, + ) -> Result<(Resolved, Option), ConnectionError> { + let receiver = self.fd_transfer.receive(); + let fd_id = receiver.id(); + self.lifecycle + .guard( + self.resolve_accept_with_fds(receiver, IpcConnection::accept_bi(&self.rpc, fd_id)), + ) + .await + } + + async fn open_uni_with_fds( + &self, + ) -> Result<(Resolved, Option), ConnectionError> { + let receiver = self.fd_transfer.receive(); + let fd_id = receiver.id(); + self.lifecycle + .guard(self.resolve_open_with_fds(receiver, IpcConnection::open_uni(&self.rpc, fd_id))) + .await + } + + async fn accept_uni_with_fds( + &self, + ) -> Result<(Resolved, Option), ConnectionError> { + let receiver = self.fd_transfer.receive(); + let fd_id = receiver.id(); + self.lifecycle + .guard( + self.resolve_accept_with_fds(receiver, IpcConnection::accept_uni(&self.rpc, fd_id)), + ) + .await + } + + async fn resolve_open_with_fds( + &self, + receiver: crate::ipc::transport::FdReceiver, + rpc: impl Future, IpcOpenError>>, + ) -> Result<(Resolved, Option), ConnectionError> { + let receive = receiver.into_future(); + tokio::pin!(receive); + tokio::pin!(rpc); + + tokio::select! { + biased; + receive_result = &mut receive => { + let received = receive_result.map_err(|e| ipc_transport_error(e, "receive fds"))?; + let resolved = rpc.await.map_err(map_open_err)?; + Ok((resolved, Some(received))) + } + rpc_result = &mut rpc => { + let resolved = rpc_result.map_err(map_open_err)?; + match resolved { + Resolved::Value { value } => { + let received = receive.await.map_err(|e| ipc_transport_error(e, "receive fds"))?; + Ok((Resolved::ok(value), Some(received))) + } + Resolved::Error { error } => Ok((Resolved::err(error), None)), + } + } + } + } + + async fn resolve_accept_with_fds( + &self, + receiver: crate::ipc::transport::FdReceiver, + rpc: impl Future, IpcAcceptError>>, + ) -> Result<(Resolved, Option), ConnectionError> { + let receive = receiver.into_future(); + tokio::pin!(receive); + tokio::pin!(rpc); + + tokio::select! { + biased; + receive_result = &mut receive => { + let received = receive_result.map_err(|e| ipc_transport_error(e, "receive fds"))?; + let resolved = rpc.await.map_err(map_accept_err)?; + Ok((resolved, Some(received))) + } + rpc_result = &mut rpc => { + let resolved = rpc_result.map_err(map_accept_err)?; + match resolved { + Resolved::Value { value } => { + let received = receive.await.map_err(|e| ipc_transport_error(e, "receive fds"))?; + Ok((Resolved::ok(value), Some(received))) + } + Resolved::Error { error } => Ok((Resolved::err(error), None)), + } + } + } + } + + /// Retrieve 2 FDs and construct boxed IPC stream-frame bridge handles. async fn fds_to_bi( &self, handle: IpcBiHandle, - ) -> Result<(IpcReadStream, IpcWriteStream), ConnectionError> { - let IpcBiHandle { fd_id, stream_id } = handle; - let fds = self - .lifecycle - .guard_with(self.fd_registry.wait_fds(fd_id), |e| { - ipc_transport_error(e, "wait_fds") - }) - .await?; + received: ReceivedFds, + ) -> Result<(BoxQuicStreamReader, BoxQuicStreamWriter), ConnectionError> { + let IpcBiHandle { stream_id } = handle; self.lifecycle.guard_sync(|| { - if fds.len() != 2 { - return Err(ConnectionError::Transport { - source: quic::TransportError { - kind: IPC_ERROR_KIND, - frame_type: IPC_FRAME_TYPE, - reason: format!("expected 2 fds for bidi stream, got {}", fds.len()).into(), - }, - }); - } - let mut fds = fds.into_iter(); - let fd_a = fds.next().unwrap(); - let fd_b = fds.next().unwrap(); + let (fd_a, fd_b) = received + .into_pair() + .map_err(|e| ipc_transport_error(e, "fd count"))?; - let lifecycle: Arc = self.lifecycle.clone(); + let lifecycle = self.lifecycle.clone(); - // fd_a → reader pipe (matches server's IpcWriteStream on srv_a) + // fd_a → reader pipe (matches server's read hypervisor on srv_a) let sock_a = UnixStream::from_std(std::os::unix::net::UnixStream::from(fd_a)) .map_err(|e| ipc_io_error(e, "UnixStream::from_std"))?; - let reader = IpcReadStream::new(stream_id, sock_a, lifecycle.clone()); + let reader = Box::pin(ipc_reader::reader(stream_id, sock_a, lifecycle.clone())) + as BoxQuicStreamReader; - // fd_b → writer pipe (matches server's IpcReadStream on srv_b) + // fd_b → writer pipe (matches server's write hypervisor on srv_b) let sock_b = UnixStream::from_std(std::os::unix::net::UnixStream::from(fd_b)) .map_err(|e| ipc_io_error(e, "UnixStream::from_std"))?; - let writer = IpcWriteStream::new(stream_id, sock_b, lifecycle); + let writer = + Box::pin(ipc_writer::writer(stream_id, sock_b, lifecycle)) as BoxQuicStreamWriter; Ok((reader, writer)) }) } - /// Retrieve 1 FD and construct a IpcWriteStream (for open_uni). + /// Retrieve 1 FD and construct a boxed IPC write bridge (for open_uni). async fn fds_to_uni_writer( &self, handle: IpcUniHandle, - ) -> Result { - let IpcUniHandle { fd_id, stream_id } = handle; - let fds = self - .lifecycle - .guard_with(self.fd_registry.wait_fds(fd_id), |e| { - ipc_transport_error(e, "wait_fds") - }) - .await?; + received: ReceivedFds, + ) -> Result { + let IpcUniHandle { stream_id } = handle; self.lifecycle.guard_sync(|| { - if fds.len() != 1 { - return Err(ConnectionError::Transport { - source: quic::TransportError { - kind: IPC_ERROR_KIND, - frame_type: IPC_FRAME_TYPE, - reason: format!("expected 1 fd for uni stream, got {}", fds.len()).into(), - }, - }); - } - let fd = fds.into_iter().next().unwrap(); - let lifecycle: Arc = self.lifecycle.clone(); + let fd = received + .into_one() + .map_err(|e| ipc_transport_error(e, "fd count"))?; + let lifecycle = self.lifecycle.clone(); let sock = UnixStream::from_std(std::os::unix::net::UnixStream::from(fd)) .map_err(|e| ipc_io_error(e, "UnixStream::from_std"))?; - Ok(IpcWriteStream::new(stream_id, sock, lifecycle)) + Ok(Box::pin(ipc_writer::writer(stream_id, sock, lifecycle)) as BoxQuicStreamWriter) }) } - /// Retrieve 1 FD and construct a IpcReadStream (for accept_uni). + /// Retrieve 1 FD and construct a boxed IPC read bridge (for accept_uni). async fn fds_to_uni_reader( &self, handle: IpcUniHandle, - ) -> Result { - let IpcUniHandle { fd_id, stream_id } = handle; - let fds = self - .lifecycle - .guard_with(self.fd_registry.wait_fds(fd_id), |e| { - ipc_transport_error(e, "wait_fds") - }) - .await?; + received: ReceivedFds, + ) -> Result { + let IpcUniHandle { stream_id } = handle; self.lifecycle.guard_sync(|| { - if fds.len() != 1 { - return Err(ConnectionError::Transport { - source: quic::TransportError { - kind: IPC_ERROR_KIND, - frame_type: IPC_FRAME_TYPE, - reason: format!("expected 1 fd for uni stream, got {}", fds.len()).into(), - }, - }); - } - let fd = fds.into_iter().next().unwrap(); - let lifecycle: Arc = self.lifecycle.clone(); + let fd = received + .into_one() + .map_err(|e| ipc_transport_error(e, "fd count"))?; + let lifecycle = self.lifecycle.clone(); let sock = UnixStream::from_std(std::os::unix::net::UnixStream::from(fd)) .map_err(|e| ipc_io_error(e, "UnixStream::from_std"))?; - Ok(IpcReadStream::new(stream_id, sock, lifecycle)) + Ok(Box::pin(ipc_reader::reader(stream_id, sock, lifecycle)) as BoxQuicStreamReader) }) } } -impl quic::WithLocalAgent for IpcConnectionHandle { - type LocalAgent = CachedLocalAgent; +impl quic::WithLocalAuthority for IpcConnectionHandle { + type LocalAuthority = CachedLocalAuthority; - async fn local_agent(&self) -> Result, ConnectionError> { + async fn local_authority(&self) -> Result, ConnectionError> { match self .lifecycle - .guard(IpcConnection::local_agent(&self.rpc)) + .guard(IpcConnection::local_authority(&self.rpc)) .await? { Some(client) => Ok(Some( self.lifecycle - .guard(CachedLocalAgent::from_client(client)) + .guard(CachedLocalAuthority::from_client(client)) .await?, )), None => Ok(None), @@ -673,18 +777,18 @@ impl quic::WithLocalAgent for IpcConnectionHandle { } } -impl quic::WithRemoteAgent for IpcConnectionHandle { - type RemoteAgent = CachedRemoteAgent; +impl quic::WithRemoteAuthority for IpcConnectionHandle { + type RemoteAuthority = CachedRemoteAuthority; - async fn remote_agent(&self) -> Result, ConnectionError> { + async fn remote_authority(&self) -> Result, ConnectionError> { match self .lifecycle - .guard(IpcConnection::remote_agent(&self.rpc)) + .guard(IpcConnection::remote_authority(&self.rpc)) .await? { Some(client) => Ok(Some( self.lifecycle - .guard(CachedRemoteAgent::from_client(client)) + .guard(CachedRemoteAuthority::from_client(client)) .await?, )), None => Ok(None), @@ -741,9 +845,9 @@ fn ipc_transport_error(err: impl std::error::Error, context: &str) -> Connection } } -/// Convert an I/O error into an [`IpcPlumbingError`] (server side). -fn ipc_io_plumbing(err: impl std::fmt::Display, context: &str) -> IpcPlumbingError { - debug!(error = %err, context, "ipc plumbing i/o error"); +/// Convert an infrastructure error into an [`IpcPlumbingError`] (server side). +fn ipc_io_plumbing(err: impl std::error::Error, context: &str) -> IpcPlumbingError { + debug!(error = %snafu::Report::from_error(&err), context, "ipc plumbing i/o error"); IpcPlumbingError::Io { message: format!("{context}: {err}"), } @@ -784,9 +888,9 @@ impl IpcConnectionClient { /// Convert into an [`IpcConnectionHandle`]. pub fn into_handle( self, - fd_registry: FdRegistry, - conn_fd_sender: FdSender, + fd_transfer: FdTransfer, + remoc_task: AbortOnDropHandle<()>, ) -> IpcConnectionHandle { - IpcConnectionHandle::new(self, fd_registry, conn_fd_sender) + IpcConnectionHandle::new(self, fd_transfer, remoc_task) } } diff --git a/src/ipc/quic/connector.rs b/src/ipc/quic/connector.rs index 6181112..8e98bf0 100644 --- a/src/ipc/quic/connector.rs +++ b/src/ipc/quic/connector.rs @@ -4,8 +4,8 @@ //! //! [`IpcConnect`] defines the RPC interface for initiating outgoing connections. //! The server side calls [`quic::Connect::connect`], creates a new -//! [`MuxChannel`] pair, queues the remote end's FD, and returns the -//! [`VarInt`] registry ID. +//! [`MuxChannel`] pair, delivers the remote end's FD using the caller-chosen +//! [`VarInt`] FD transfer ID, and returns that same ID as confirmation. //! //! # Server side //! @@ -13,29 +13,34 @@ //! Each `connect` call: //! 1. Calls the inner connector to establish a real QUIC connection. //! 2. Creates a [`MuxChannel`] pair (one for the server, one FD for the client). -//! 3. Queues the client-side FD through the parent-level [`FdSender`]. +//! 3. Delivers the client-side FD through the parent-level [`FdTransfer`]. //! 4. Splits the server-side MuxChannel, establishes a remoc connection. //! 5. Creates a [`ConnectionAdapter`] for the real connection, wraps it in //! an [`IpcConnectionServerShared`], spawns the RPC server. //! 6. Sends a [`ConnectionBootstrap`] over the remoc base channel. -//! 7. Returns the FD-registry ID. +//! 7. Returns the caller-chosen FD transfer ID. //! //! # Client side //! //! [`IpcConnector`] wraps an [`IpcConnectClient`] and implements //! [`quic::Connect`]. Each `connect` call: -//! 1. Calls the RPC method to get a FD-registry ID. -//! 2. Retrieves the MuxChannel FD from the [`FdRegistry`]. +//! 1. Reserves a caller-chosen FD transfer ID. +//! 2. Calls RPC `connect(server, fd_id)` while concurrently receiving the +//! MuxChannel FD. //! 3. Splits the MuxChannel, establishes a remoc connection. //! 4. Receives the [`ConnectionBootstrap`] from the base channel. //! 5. Wraps the result as an [`IpcConnectionHandle`]. -use std::sync::Arc; +use std::{ + future::Future, + sync::{Arc, Mutex}, +}; use http::uri::Authority; use remoc::{self, RemoteSend, prelude::ServerShared}; use smallvec::smallvec; use snafu::{ResultExt, Snafu}; +use tokio_util::task::AbortOnDropHandle; use tracing::Instrument; use super::connection::{ @@ -43,7 +48,8 @@ use super::connection::{ IpcConnectionServerShared, }; use crate::{ - ipc::transport::{FdRegistry, FdSender, MuxChannel, SplitError}, + error::Code, + ipc::transport::{FdTransfer, MuxChannel, SplitError, TakeFdsError}, quic::{self, ConnectionError}, rpc::quic::serde_types::SerdeAuthority, varint::VarInt, @@ -55,17 +61,20 @@ use crate::{ /// Remote trait for IPC connector capability. /// -/// Each successful [`connect`](IpcConnect::connect) returns a [`VarInt`] ID -/// that the caller can pass to -/// [`FdRegistry::wait_fds`](crate::ipc::transport::FdRegistry::wait_fds) to -/// obtain the [`MuxChannel`](crate::ipc::transport::MuxChannel) FD for the -/// new connection. +/// Each successful [`connect`](IpcConnect::connect) echoes the receiver-chosen +/// FD transfer ID after the associated [`MuxChannel`] FD has been queued to +/// the local mux writer FIFO. #[remoc::rtc::remote] pub trait IpcConnect: Send + Sync { /// Connect to a remote server. /// - /// Returns the FD-registry ID for the new connection's MuxChannel. - async fn connect(&self, server: SerdeAuthority) -> Result; + /// Returns the caller-chosen FD transfer ID for the new connection's + /// MuxChannel. + async fn connect( + &self, + server: SerdeAuthority, + fd_id: VarInt, + ) -> Result; } // --------------------------------------------------------------------------- @@ -82,8 +91,10 @@ pub enum ConnectorError { WaitFds { source: crate::ipc::transport::WaitFdsError, }, - #[snafu(display("no fd received from registry"))] - EmptyFd, + #[snafu(display("unexpected connection fd count"))] + TakeFd { source: TakeFdsError }, + #[snafu(display("ipc connect returned mismatched fd id {actual}, expected {expected}"))] + FdIdMismatch { expected: VarInt, actual: VarInt }, #[snafu(display("failed to reconstruct mux channel from fd"))] FromFd { source: std::io::Error }, #[snafu(display("failed to split mux channel"))] @@ -106,18 +117,30 @@ pub enum ConnectorError { /// [`IpcConnector`]. pub struct ConnectAdapter { inner: C, - fd_sender: FdSender, + fd_transfer: FdTransfer, + tasks: Mutex>>, _codec: std::marker::PhantomData, } impl ConnectAdapter { - pub fn new(inner: C, fd_sender: FdSender) -> Self { + pub fn new(inner: C, fd_transfer: FdTransfer) -> Self { Self { inner, - fd_sender, + fd_transfer, + tasks: Mutex::new(Vec::new()), _codec: std::marker::PhantomData, } } + + fn spawn_task(&self, task: impl Future + Send + 'static) { + let handle = AbortOnDropHandle::new(tokio::spawn(task.in_current_span())); + let mut tasks = self + .tasks + .lock() + .expect("connect adapter task registry should not be poisoned"); + tasks.retain(|task| !task.is_finished()); + tasks.push(handle); + } } impl IpcConnect for ConnectAdapter @@ -129,8 +152,13 @@ where Codec: remoc::codec::Codec, ConnectionBootstrap: RemoteSend, { - async fn connect(&self, server: SerdeAuthority) -> Result { + async fn connect( + &self, + server: SerdeAuthority, + fd_id: VarInt, + ) -> Result { let authority = Authority::try_from(server).map_err(|e| connect_error(e, "authority"))?; + let delivery = self.fd_transfer.delivery(fd_id); let connection = quic::Connect::connect(&self.inner, &authority) .await .map_err(|e| connect_error(e, "connect"))?; @@ -139,20 +167,24 @@ where let (server_mux, client_fd) = MuxChannel::create_pair().map_err(|e| connect_error(e, "create_pair"))?; - // Queue the client-side FD on the parent-level FdSender. - let fd_id = self - .fd_sender - .queue_fds(smallvec![client_fd]) - .map_err(|e| connect_error(e, "queue_fds"))?; + // Split the server-side MuxChannel before FD delivery so all fallible + // local setup is complete before the worker can observe the FD. + let (sink, stream) = match server_mux.split() { + Ok(split) => split, + Err(error) => { + close_undelivered_connection(connection.as_ref(), "split"); + return Err(connect_error(error, "split")); + } + }; - // Split the server-side MuxChannel — the remoc handshake + bootstrap - // are spawned as a background task so we can return fd_id immediately. - // The client needs fd_id to retrieve the FD and start its handshake; - // blocking here would deadlock. - let (sink, stream) = server_mux.split().map_err(|e| connect_error(e, "split"))?; + if let Err(error) = delivery.deliver(smallvec![client_fd]).await { + close_undelivered_connection(connection.as_ref(), "deliver"); + return Err(connect_error(error, "deliver fd")); + } - // Inherent termination: exits when remoc handshake completes or fails. - tokio::spawn(Self::setup_connection(connection, sink, stream).in_current_span()); + // The remoc handshake + bootstrap are spawned as a background task + // after the FD delivery has been queued. + self.spawn_task(Self::setup_connection(connection, sink, stream)); Ok(fd_id) } @@ -179,6 +211,7 @@ where use tracing::debug; let conn_fd_sender = sink.fd_sender(); + let conn_fd_transfer = stream.fd_transfer(conn_fd_sender); let (conn, mut tx, _rx) = match remoc::Connect::framed::<_, _, ConnectionBootstrap, (), Codec>( @@ -197,19 +230,17 @@ where return; } }; - // Inherent termination: exits when remoc ChMux closes. - tokio::spawn(conn.in_current_span()); + let remoc_task = AbortOnDropHandle::new(tokio::spawn(conn.in_current_span())); // Create the ConnectionAdapter and wrap it as an IPC RPC server. - let adapter = ConnectionAdapter::new(connection, conn_fd_sender); + let adapter = ConnectionAdapter::new(connection, conn_fd_transfer); let (server, rpc_client) = IpcConnectionServerShared::new(Arc::new(adapter), 64); - // Inherent termination: exits when remoc ChMux closes. - tokio::spawn( - async move { + let server_task = AbortOnDropHandle::new(tokio::spawn( + (async move { let _ = server.serve(true).await; - } + }) .in_current_span(), - ); + )); let bootstrap = ConnectionBootstrap { connection: rpc_client, @@ -217,7 +248,10 @@ where if tx.send(bootstrap).await.is_err() { debug!("failed to send connection bootstrap: base channel closed"); + return; } + + let _ = futures::future::join(remoc_task, server_task).await; } } @@ -232,15 +266,15 @@ where /// connection (e.g. via `ControlPlane` RPC). pub struct IpcConnector { rpc: IpcConnectClient, - fd_registry: FdRegistry, + fd_transfer: FdTransfer, _codec: std::marker::PhantomData, } impl IpcConnector { - pub fn new(rpc: IpcConnectClient, fd_registry: FdRegistry) -> Self { + pub fn new(rpc: IpcConnectClient, fd_transfer: FdTransfer) -> Self { Self { rpc, - fd_registry, + fd_transfer, _codec: std::marker::PhantomData, } } @@ -258,24 +292,29 @@ where &'a self, server: &'a Authority, ) -> Result, ConnectorError> { - // 1. RPC: ask the server side to connect and get the fd-registry ID. - let fd_id = IpcConnect::connect(&self.rpc, SerdeAuthority::from(server)) - .await - .context(RpcSnafu)?; + let receiver = self.fd_transfer.receive(); + let fd_id = receiver.id(); + let rpc = async { + let actual = IpcConnect::connect(&self.rpc, SerdeAuthority::from(server), fd_id) + .await + .context(RpcSnafu)?; + if actual != fd_id { + return Err(ConnectorError::FdIdMismatch { + expected: fd_id, + actual, + }); + } + Ok(()) + }; + let receive = async { receiver.await.context(WaitFdsSnafu) }; - // 2. Retrieve the MuxChannel FD. - let fds = self - .fd_registry - .wait_fds(fd_id) - .await - .context(WaitFdsSnafu)?; - let fd = fds.into_iter().next().ok_or(ConnectorError::EmptyFd)?; + let ((), received) = futures::future::try_join(rpc, receive).await?; + let fd = received.into_one().context(TakeFdSnafu)?; // 3. Reconstruct the MuxChannel and split it. let mux = MuxChannel::from_fd(fd).context(FromFdSnafu)?; let (sink, stream) = mux.split().context(SplitSnafu)?; - let conn_fd_sender = sink.fd_sender(); - let conn_fd_registry = stream.fd_registry(); + let conn_fd_transfer = stream.fd_transfer(sink.fd_sender()); // 4. Establish a remoc connection on the MuxChannel. let (conn, _tx, mut rx) = remoc::Connect::framed::<_, _, (), ConnectionBootstrap, Codec>( @@ -287,7 +326,12 @@ where .map_err(|e| ConnectorError::Remoc { message: e.to_string(), })?; - tokio::spawn(conn.in_current_span()); + let remoc_task = AbortOnDropHandle::new(tokio::spawn( + async move { + let _ = conn.await; + } + .in_current_span(), + )); // 5. Receive the ConnectionBootstrap from the server. let bootstrap = rx @@ -303,8 +347,8 @@ where // 6. Build the IpcConnectionHandle. Ok(Arc::new(IpcConnectionHandle::new( bootstrap.connection, - conn_fd_registry, - conn_fd_sender, + conn_fd_transfer, + remoc_task, ))) } } @@ -323,3 +367,11 @@ fn connect_error(err: impl std::error::Error, context: &str) -> ConnectionError }, } } + +fn close_undelivered_connection(connection: &impl quic::Lifecycle, context: &'static str) { + quic::Lifecycle::close( + connection, + Code::H3_REQUEST_CANCELLED, + format!("ipc connect {context} failed").into(), + ); +} diff --git a/src/ipc/quic/listener.rs b/src/ipc/quic/listener.rs index 1232cf8..1afd3c2 100644 --- a/src/ipc/quic/listener.rs +++ b/src/ipc/quic/listener.rs @@ -4,8 +4,8 @@ //! //! [`IpcListen`] defines the RPC interface for accepting incoming connections. //! The server side calls [`quic::Listen::accept`], creates a new -//! [`MuxChannel`] pair, queues the remote end's FD, and returns the -//! [`VarInt`] registry ID. +//! [`MuxChannel`] pair, delivers the remote end's FD using the caller-chosen +//! [`VarInt`] FD transfer ID, and returns that same ID as confirmation. //! //! # Server side //! @@ -13,27 +13,31 @@ //! [`IpcListen`]. Each `accept()` call: //! 1. Accepts a new connection from the inner listener. //! 2. Creates a [`MuxChannel`] pair (one for the server, one FD for the client). -//! 3. Queues the client FD through the parent [`FdSender`]. +//! 3. Delivers the client FD through the parent [`FdTransfer`]. //! 4. Splits the server-side MuxChannel, establishes a remoc connection. //! 5. Creates a [`ConnectionAdapter`] and RTC server for the new connection. //! 6. Sends a [`ConnectionBootstrap`] over the remoc base channel. -//! 7. Returns the FD-registry ID. +//! 7. Returns the caller-chosen FD transfer ID. //! //! # Client side //! //! [`IpcListener`] wraps an [`IpcListenClient`] and the parent-level -//! [`FdRegistry`], implementing [`quic::Listen`]. Each `accept()` call: -//! 1. Calls the RPC `accept()` to get a FD-registry ID. -//! 2. Retrieves the client FD from the parent [`FdRegistry`]. +//! [`FdTransfer`], implementing [`quic::Listen`]. Each `accept()` call: +//! 1. Reserves a caller-chosen FD transfer ID. +//! 2. Calls RPC `accept(fd_id)` while concurrently receiving the client FD. //! 3. Reconstructs a [`MuxChannel`], establishes a remoc connection. //! 4. Receives the [`ConnectionBootstrap`] from the base channel. //! 5. Returns an [`IpcConnectionHandle`]. -use std::sync::Arc; +use std::{ + future::Future, + sync::{Arc, Mutex}, +}; use remoc::prelude::ServerShared; use smallvec::smallvec; use snafu::{ResultExt, Snafu}; +use tokio_util::task::AbortOnDropHandle; use tracing::{Instrument, debug}; use super::connection::{ @@ -41,7 +45,8 @@ use super::connection::{ IpcConnectionServerShared, }; use crate::{ - ipc::transport::{FdRegistry, FdSender, MuxChannel}, + error::Code, + ipc::transport::{FdTransfer, MuxChannel, TakeFdsError}, quic::{self, ConnectionError}, varint::VarInt, }; @@ -52,17 +57,16 @@ use crate::{ /// Remote trait for IPC listener capability. /// -/// Each successful [`accept`](IpcListen::accept) returns a [`VarInt`] ID that -/// the caller can pass to -/// [`FdRegistry::wait_fds`](crate::ipc::transport::FdRegistry::wait_fds) to -/// obtain the [`MuxChannel`](crate::ipc::transport::MuxChannel) FD for the -/// new connection. +/// Each successful [`accept`](IpcListen::accept) echoes the receiver-chosen +/// FD transfer ID after the associated [`MuxChannel`] FD has been queued to +/// the local mux writer FIFO. #[remoc::rtc::remote] pub trait IpcListen: Send + Sync { /// Accept an incoming connection. /// - /// Returns the FD-registry ID for the new connection's MuxChannel. - async fn accept(&mut self) -> Result; + /// Returns the caller-chosen FD transfer ID for the new connection's + /// MuxChannel. + async fn accept(&mut self, fd_id: VarInt) -> Result; /// Gracefully shut down the listener. async fn shutdown(&self) -> Result<(), ConnectionError>; @@ -80,6 +84,14 @@ fn ipc_listen_error(err: impl std::error::Error, context: &str) -> ConnectionErr } } +fn close_undelivered_connection(connection: &impl quic::Lifecycle, context: &'static str) { + quic::Lifecycle::close( + connection, + Code::H3_REQUEST_CANCELLED, + format!("ipc listen {context} failed").into(), + ); +} + // --------------------------------------------------------------------------- // Server side: ListenAdapter // --------------------------------------------------------------------------- @@ -92,18 +104,30 @@ fn ipc_listen_error(err: impl std::error::Error, context: &str) -> ConnectionErr /// codec; the downstream crate (e.g., gateway) picks one. pub struct ListenAdapter { inner: L, - fd_sender: FdSender, + fd_transfer: FdTransfer, + tasks: Mutex>>, _codec: std::marker::PhantomData, } impl ListenAdapter { - pub fn new(inner: L, fd_sender: FdSender) -> Self { + pub fn new(inner: L, fd_transfer: FdTransfer) -> Self { Self { inner, - fd_sender, + fd_transfer, + tasks: Mutex::new(Vec::new()), _codec: std::marker::PhantomData, } } + + fn spawn_task(&self, task: impl Future + Send + 'static) { + let handle = AbortOnDropHandle::new(tokio::spawn(task.in_current_span())); + let mut tasks = self + .tasks + .lock() + .expect("listen adapter task registry should not be poisoned"); + tasks.retain(|task| !task.is_finished()); + tasks.push(handle); + } } impl IpcListen for ListenAdapter @@ -111,7 +135,8 @@ where L: quic::Listen + 'static, Codec: remoc::codec::Codec, { - async fn accept(&mut self) -> Result { + async fn accept(&mut self, fd_id: VarInt) -> Result { + let delivery = self.fd_transfer.delivery(fd_id); let connection = quic::Listen::accept(&mut self.inner) .await .map_err(|e| ipc_listen_error(e, "accept"))?; @@ -120,22 +145,24 @@ where let (server_mux, client_fd) = MuxChannel::create_pair().map_err(|e| ipc_listen_error(e, "create mux pair"))?; - // Queue the client FD through the parent-level FdSender - let fd_id = self - .fd_sender - .queue_fds(smallvec![client_fd]) - .map_err(|e| ipc_listen_error(e, "queue fd"))?; + // Split the server MuxChannel before FD delivery so all fallible local + // setup is complete before the worker can observe the FD. + let (sink, stream) = match server_mux.split() { + Ok(split) => split, + Err(error) => { + close_undelivered_connection(connection.as_ref(), "split"); + return Err(ipc_listen_error(error, "split mux")); + } + }; - // Split the server MuxChannel — the remoc handshake + bootstrap are - // spawned as a background task so we can return fd_id immediately. - // The client needs fd_id to retrieve the FD and start its handshake; - // blocking here would deadlock. - let (sink, stream) = server_mux - .split() - .map_err(|e| ipc_listen_error(e, "split mux"))?; + if let Err(error) = delivery.deliver(smallvec![client_fd]).await { + close_undelivered_connection(connection.as_ref(), "deliver"); + return Err(ipc_listen_error(error, "deliver fd")); + } - // Inherent termination: exits when remoc handshake completes or fails. - tokio::spawn(Self::setup_connection(connection, sink, stream).in_current_span()); + // The remoc handshake + bootstrap are spawned as a background task + // after the FD delivery has been queued. + self.spawn_task(Self::setup_connection(connection, sink, stream)); Ok(fd_id) } @@ -166,6 +193,7 @@ where stream: crate::ipc::transport::MuxStream, ) { let conn_fd_sender = sink.fd_sender(); + let conn_fd_transfer = stream.fd_transfer(conn_fd_sender); let (remoc_conn, mut tx, _rx) = match remoc::Connect::framed::< _, @@ -182,19 +210,17 @@ where return; } }; - // Inherent termination: exits when remoc ChMux closes. - tokio::spawn(remoc_conn.in_current_span()); + let remoc_task = AbortOnDropHandle::new(tokio::spawn(remoc_conn.in_current_span())); // Create the ConnectionAdapter for this connection's stream management - let adapter = ConnectionAdapter::new(connection, conn_fd_sender); + let adapter = ConnectionAdapter::new(connection, conn_fd_transfer); let (server, rpc_client) = IpcConnectionServerShared::new(Arc::new(adapter), 64); - // Inherent termination: exits when remoc ChMux closes. - tokio::spawn( - async move { + let server_task = AbortOnDropHandle::new(tokio::spawn( + (async move { let _ = server.serve(true).await; - } + }) .in_current_span(), - ); + )); // Send the bootstrap data over the remoc base channel let bootstrap = ConnectionBootstrap { @@ -202,7 +228,10 @@ where }; if tx.send(bootstrap).await.is_err() { debug!("failed to send connection bootstrap: base channel closed"); + return; } + + let _ = futures::future::join(remoc_task, server_task).await; } } @@ -217,7 +246,7 @@ where /// The `Codec` parameter must match the server-side [`ListenAdapter`]'s codec. pub struct IpcListener { rpc: IpcListenClient, - fd_registry: FdRegistry, + fd_transfer: FdTransfer, _codec: std::marker::PhantomData, } @@ -231,8 +260,10 @@ pub enum IpcListenError { WaitFd { source: crate::ipc::transport::WaitFdsError, }, - #[snafu(display("no fd received from registry"))] - EmptyFd, + #[snafu(display("unexpected connection fd count"))] + TakeFd { source: TakeFdsError }, + #[snafu(display("ipc accept returned mismatched fd id {actual}, expected {expected}"))] + FdIdMismatch { expected: VarInt, actual: VarInt }, #[snafu(display("failed to reconstruct mux channel"))] FromFd { source: std::io::Error }, #[snafu(display("failed to split client mux channel"))] @@ -258,10 +289,10 @@ impl From for ConnectionError { } impl IpcListener { - pub fn new(rpc: IpcListenClient, fd_registry: FdRegistry) -> Self { + pub fn new(rpc: IpcListenClient, fd_transfer: FdTransfer) -> Self { Self { rpc, - fd_registry, + fd_transfer, _codec: std::marker::PhantomData, } } @@ -275,22 +306,29 @@ where type Error = IpcListenError; async fn accept(&mut self) -> Result, IpcListenError> { - // 1. RPC: ask server to accept a new connection - let fd_id = IpcListen::accept(&mut self.rpc).await.context(RpcSnafu)?; + let receiver = self.fd_transfer.receive(); + let fd_id = receiver.id(); + let rpc = async { + let actual = IpcListen::accept(&mut self.rpc, fd_id) + .await + .context(RpcSnafu)?; + if actual != fd_id { + return Err(IpcListenError::FdIdMismatch { + expected: fd_id, + actual, + }); + } + Ok(()) + }; + let receive = async { receiver.await.context(WaitFdSnafu) }; - // 2. Retrieve the MuxChannel FD from the parent-level FdRegistry - let fds = self - .fd_registry - .wait_fds(fd_id) - .await - .context(WaitFdSnafu)?; - let fd = fds.into_iter().next().ok_or(IpcListenError::EmptyFd)?; + let ((), received) = futures::future::try_join(rpc, receive).await?; + let fd = received.into_one().context(TakeFdSnafu)?; // 3. Reconstruct MuxChannel and establish remoc connection let mux = MuxChannel::from_fd(fd).context(FromFdSnafu)?; let (sink, stream) = mux.split().context(ClientSplitSnafu)?; - let conn_fd_sender = sink.fd_sender(); - let conn_fd_registry = stream.fd_registry(); + let conn_fd_transfer = stream.fd_transfer(sink.fd_sender()); let (remoc_conn, _tx, mut rx) = remoc::Connect::framed::<_, _, (), ConnectionBootstrap, Codec>( @@ -302,7 +340,12 @@ where .map_err(|e| IpcListenError::ClientRemocConnect { message: e.to_string(), })?; - tokio::spawn(remoc_conn.in_current_span()); + let remoc_task = AbortOnDropHandle::new(tokio::spawn( + async move { + let _ = remoc_conn.await; + } + .in_current_span(), + )); // 4. Receive the ConnectionBootstrap from the server let bootstrap = rx @@ -318,8 +361,8 @@ where // 5. Wrap as IpcConnectionHandle Ok(Arc::new(IpcConnectionHandle::new( bootstrap.connection, - conn_fd_registry, - conn_fd_sender, + conn_fd_transfer, + remoc_task, ))) } diff --git a/src/ipc/quic/stream.rs b/src/ipc/quic/stream.rs index 1a7d215..c94896e 100644 --- a/src/ipc/quic/stream.rs +++ b/src/ipc/quic/stream.rs @@ -1,11 +1,6 @@ mod codec; mod handle; -mod reader; -mod state; -mod writer; +pub(crate) mod reader; +pub(crate) mod writer; -pub use self::{ - handle::{IpcBiHandle, IpcUniHandle}, - reader::IpcReadStream, - writer::IpcWriteStream, -}; +pub use self::handle::{IpcBiHandle, IpcUniHandle}; diff --git a/src/ipc/quic/stream/codec.rs b/src/ipc/quic/stream/codec.rs index abf2340..15183f9 100644 --- a/src/ipc/quic/stream/codec.rs +++ b/src/ipc/quic/stream/codec.rs @@ -1,63 +1,71 @@ -//! Minimal framing codec for per-stream Unix socketpair IPC. +//! Direction-aware stream frame codecs for per-stream Unix socketpair IPC. //! //! Each QUIC stream is carried over an independent `SOCK_STREAM` socketpair. -//! The codec multiplexes data and control signals on the same byte stream using -//! a simple tag-length-value encoding based on QUIC variable-length integers. -//! -//! # Frame types -//! -//! ```text -//! PULL = type(varint 0x00) -//! PUSH = type(varint 0x01) + length(varint) + payload -//! STOP = type(varint 0x02) + code(varint) -//! CANCEL = type(varint 0x03) + code(varint) -//! CONN_CLOSED = type(varint 0x04) -//! ``` +//! The wire format is a simple QUIC-varint tag followed by optional varint +//! length/code fields and payload bytes. The frame vocabulary is shared with +//! `rpc::stream::frame`; this module only adapts that vocabulary to +//! `tokio_util::codec` for IPC pipes. + +use std::borrow::Cow; use bytes::{Buf, BufMut, Bytes, BytesMut}; +use snafu::ResultExt as _; use tokio_util::codec::{Decoder, Encoder}; -use crate::varint::VarInt; - -/// Frame type tags. -pub(super) const TAG_PULL: u8 = 0x00; -const TAG_PUSH: u8 = 0x01; -pub(super) const TAG_STOP: u8 = 0x02; -pub(super) const TAG_CANCEL: u8 = 0x03; -pub(super) const TAG_CONN_CLOSED: u8 = 0x04; +use crate::{ + quic, + rpc::stream::frame::{self, ReadCommand, ReadEvent, WriteCommand, WriteEvent}, + varint::{self, VarInt}, +}; + +const TAG_PULL: u8 = frame::TAG_PULL as u8; +const TAG_PUSH: u8 = frame::TAG_PUSH as u8; +const TAG_FLUSH: u8 = frame::TAG_FLUSH as u8; +const TAG_FLUSH_ACK: u8 = frame::TAG_FLUSH_ACK as u8; +const TAG_EOS: u8 = frame::TAG_EOS as u8; +const TAG_EOS_ACK: u8 = frame::TAG_EOS_ACK as u8; +const TAG_STOP: u8 = frame::TAG_STOP as u8; +const TAG_STOP_ACK: u8 = frame::TAG_STOP_ACK as u8; +const TAG_RESET: u8 = frame::TAG_RESET as u8; +const TAG_RESET_ACK: u8 = frame::TAG_RESET_ACK as u8; +const TAG_ERR_RESET: u8 = frame::TAG_ERR_RESET as u8; +const TAG_ERR_CONN: u8 = frame::TAG_ERR_CONN as u8; + +const IPC_CODEC_ERROR_KIND: VarInt = VarInt::from_u32(0x0b); +const IPC_CODEC_ERROR_FRAME_TYPE: VarInt = VarInt::from_u32(0x00); /// Maximum encoded header bytes for a PUSH frame: tag + varint length. pub(super) const PUSH_HEADER_MAX_LEN: usize = 1 + VarInt::MAX_SIZE; - -/// Maximum encoded bytes for a control frame: tag + varint code. -pub(super) const CONTROL_MAX_LEN: usize = 1 + VarInt::MAX_SIZE; - -/// Frames exchanged over a per-stream socketpair. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(super) enum Frame { - /// Flow-control signal — reader grants the writer permission to send one - /// PUSH frame. - Pull, - /// Stream data payload. - Push(Bytes), - /// STOP_SENDING — reader asks the remote writer to stop (carries error code). - Stop(VarInt), - /// RESET_STREAM — writer cancels the stream (carries error code). - Cancel(VarInt), - /// Connection-level closure notification. - ConnClosed, -} +const MAX_PUSH_RESERVE_STEP: usize = 64 * 1024; /// Codec error. #[derive(Debug, snafu::Snafu)] #[snafu(module)] -pub(super) enum CodecError { +pub(crate) enum CodecError { #[snafu(transparent)] Io { source: std::io::Error }, - #[snafu(display("unknown frame tag: 0x{tag:02x}"))] + #[snafu(display("unknown frame tag 0x{tag:02x}"))] UnknownTag { tag: u8 }, - #[snafu(display("varint overflow"))] - VarIntOverflow, + #[snafu(display("frame tag 0x{tag:02x} is invalid for {direction}"))] + InvalidDirection { tag: u8, direction: &'static str }, + #[snafu(display("push frame length overflows buffer size"))] + PushLengthOverflow, + #[snafu(display("failed to encode push payload length"))] + EncodePushLength { source: varint::err::Overflow }, +} + +impl From for quic::ConnectionError { + fn from(error: CodecError) -> Self { + quic::ConnectionError::Transport { + source: quic::TransportError { + kind: IPC_CODEC_ERROR_KIND, + frame_type: IPC_CODEC_ERROR_FRAME_TYPE, + // Lossy: QUIC transport error reason is a protocol string field + // for local IPC codec failures. + reason: Cow::Owned(error.to_string()), + }, + } + } } /// Try to decode one QUIC-style varint from `buf` without advancing it. @@ -93,6 +101,10 @@ fn encode_varint(buf: &mut BytesMut, v: VarInt) { } } +fn reserve_push_payload(src: &mut BytesMut, remaining: usize) { + src.reserve(remaining.min(MAX_PUSH_RESERVE_STEP)); +} + pub(super) fn encode_varint_to_slice(dst: &mut [u8], v: VarInt) -> usize { let x = v.into_inner(); if x < (1 << 6) { @@ -117,18 +129,15 @@ pub(super) fn encode_varint_to_slice(dst: &mut [u8], v: VarInt) -> usize { pub(super) fn encode_push_header( payload_len: usize, ) -> Result<([u8; PUSH_HEADER_MAX_LEN], usize), CodecError> { - let len = VarInt::try_from(payload_len).map_err(|_| CodecError::VarIntOverflow)?; + let len = VarInt::try_from(payload_len).context(codec_error::EncodePushLengthSnafu)?; let mut header = [0u8; PUSH_HEADER_MAX_LEN]; header[0] = TAG_PUSH; let varint_size = encode_varint_to_slice(&mut header[1..], len); Ok((header, 1 + varint_size)) } -/// Minimal framing codec for the per-stream IPC protocol. #[derive(Debug, Default)] -pub(super) struct StreamCodec { - /// Partially decoded frame state: once we know the tag and payload length - /// we store them here so we don't re-parse on the next `decode` call. +struct WireCodec { state: DecodeState, } @@ -136,82 +145,123 @@ pub(super) struct StreamCodec { enum DecodeState { #[default] Tag, - /// We have the tag and the payload length, waiting for payload bytes. - Push { len: usize }, - /// Control frame (STOP/CANCEL): tag decoded, need varint code. - Control { tag: u8 }, + Push { + len: usize, + }, + Control { + tag: u8, + }, } -impl StreamCodec { - pub fn new() -> Self { - Self::default() - } +#[derive(Debug, Clone, PartialEq, Eq)] +enum WireFrame { + Pull, + Push { data: Bytes }, + Flush, + FlushAck, + Eos, + EosAck, + Stop { code: VarInt }, + StopAck { code: VarInt }, + Reset { code: VarInt }, + ResetAck { code: VarInt }, + ErrReset { code: VarInt }, + ErrConn, } -impl Decoder for StreamCodec { - type Item = Frame; - type Error = CodecError; +impl WireFrame { + const fn tag(&self) -> u8 { + match self { + Self::Pull => TAG_PULL, + Self::Push { .. } => TAG_PUSH, + Self::Flush => TAG_FLUSH, + Self::FlushAck => TAG_FLUSH_ACK, + Self::Eos => TAG_EOS, + Self::EosAck => TAG_EOS_ACK, + Self::Stop { .. } => TAG_STOP, + Self::StopAck { .. } => TAG_STOP_ACK, + Self::Reset { .. } => TAG_RESET, + Self::ResetAck { .. } => TAG_RESET_ACK, + Self::ErrReset { .. } => TAG_ERR_RESET, + Self::ErrConn => TAG_ERR_CONN, + } + } +} - fn decode(&mut self, src: &mut BytesMut) -> Result, CodecError> { +impl WireCodec { + fn decode(&mut self, src: &mut BytesMut) -> Result, CodecError> { loop { match self.state { DecodeState::Tag => { if src.is_empty() { return Ok(None); } - // Peek at the tag byte (first byte is always a 1-byte varint for 0x00..0x04). let tag = src[0]; match tag { TAG_PUSH => { - // Need tag + varint length. Try to decode the length varint - // starting right after the tag byte. if src.len() < 2 { return Ok(None); } let after_tag = &src[1..]; let Some((len_vi, vi_size)) = try_decode_varint(after_tag) else { - // Not enough bytes for the varint yet. return Ok(None); }; - let len = len_vi.into_inner() as usize; + let Ok(len) = usize::try_from(len_vi.into_inner()) else { + return Err(CodecError::PushLengthOverflow); + }; let header_size = 1 + vi_size; - let total = header_size + len; + let Some(total) = header_size.checked_add(len) else { + return Err(CodecError::PushLengthOverflow); + }; if src.len() < total { - // Consume the header so DecodeState::Push only waits for payload. src.advance(header_size); - src.reserve(len.saturating_sub(src.len())); + reserve_push_payload(src, len.saturating_sub(src.len())); self.state = DecodeState::Push { len }; return Ok(None); } src.advance(header_size); let data = src.split_to(len).freeze(); - // state stays Tag for next frame - return Ok(Some(Frame::Push(data))); + return Ok(Some(WireFrame::Push { data })); + } + TAG_STOP | TAG_STOP_ACK | TAG_RESET | TAG_RESET_ACK | TAG_ERR_RESET => { + self.state = DecodeState::Control { tag }; + src.advance(1); } TAG_PULL => { src.advance(1); - return Ok(Some(Frame::Pull)); + return Ok(Some(WireFrame::Pull)); } - TAG_STOP | TAG_CANCEL => { - self.state = DecodeState::Control { tag }; - src.advance(1); // consume the tag byte + TAG_FLUSH => { + src.advance(1); + return Ok(Some(WireFrame::Flush)); } - TAG_CONN_CLOSED => { + TAG_FLUSH_ACK => { src.advance(1); - return Ok(Some(Frame::ConnClosed)); + return Ok(Some(WireFrame::FlushAck)); + } + TAG_EOS => { + src.advance(1); + return Ok(Some(WireFrame::Eos)); + } + TAG_EOS_ACK => { + src.advance(1); + return Ok(Some(WireFrame::EosAck)); + } + TAG_ERR_CONN => { + src.advance(1); + return Ok(Some(WireFrame::ErrConn)); } _ => return Err(CodecError::UnknownTag { tag }), } } DecodeState::Push { len } => { - // We already consumed the header; just wait for the payload. if src.len() < len { - src.reserve(len - src.len()); + reserve_push_payload(src, len - src.len()); return Ok(None); } let data = src.split_to(len).freeze(); self.state = DecodeState::Tag; - return Ok(Some(Frame::Push(data))); + return Ok(Some(WireFrame::Push { data })); } DecodeState::Control { tag } => { let Some((code, vi_size)) = try_decode_varint(src) else { @@ -220,154 +270,433 @@ impl Decoder for StreamCodec { src.advance(vi_size); self.state = DecodeState::Tag; return Ok(Some(match tag { - TAG_STOP => Frame::Stop(code), - TAG_CANCEL => Frame::Cancel(code), - _ => unreachable!(), + TAG_STOP => WireFrame::Stop { code }, + TAG_STOP_ACK => WireFrame::StopAck { code }, + TAG_RESET => WireFrame::Reset { code }, + TAG_RESET_ACK => WireFrame::ResetAck { code }, + TAG_ERR_RESET => WireFrame::ErrReset { code }, + _ => unreachable!("control state only records control tags"), })); } } } } -} -impl Encoder for StreamCodec { - type Error = CodecError; - - fn encode(&mut self, item: Frame, dst: &mut BytesMut) -> Result<(), CodecError> { + fn encode(&mut self, item: WireFrame, dst: &mut BytesMut) -> Result<(), CodecError> { match item { - Frame::Pull => { - dst.reserve(1); - dst.put_u8(TAG_PULL); - } - Frame::Push(data) => { + WireFrame::Pull => dst.put_u8(TAG_PULL), + WireFrame::Push { data } => { let (header, header_len) = encode_push_header(data.len())?; dst.reserve(header_len + data.len()); dst.extend_from_slice(&header[..header_len]); dst.extend_from_slice(&data); } - Frame::Stop(code) => { - dst.reserve(1 + code.encoding_size()); - dst.put_u8(TAG_STOP); - encode_varint(dst, code); - } - Frame::Cancel(code) => { - dst.reserve(1 + code.encoding_size()); - dst.put_u8(TAG_CANCEL); - encode_varint(dst, code); - } - Frame::ConnClosed => { - dst.reserve(1); - dst.put_u8(TAG_CONN_CLOSED); - } + WireFrame::Flush => dst.put_u8(TAG_FLUSH), + WireFrame::FlushAck => dst.put_u8(TAG_FLUSH_ACK), + WireFrame::Eos => dst.put_u8(TAG_EOS), + WireFrame::EosAck => dst.put_u8(TAG_EOS_ACK), + WireFrame::Stop { code } => encode_control(dst, TAG_STOP, code), + WireFrame::StopAck { code } => encode_control(dst, TAG_STOP_ACK, code), + WireFrame::Reset { code } => encode_control(dst, TAG_RESET, code), + WireFrame::ResetAck { code } => encode_control(dst, TAG_RESET_ACK, code), + WireFrame::ErrReset { code } => encode_control(dst, TAG_ERR_RESET, code), + WireFrame::ErrConn => dst.put_u8(TAG_ERR_CONN), } Ok(()) } } +fn encode_control(dst: &mut BytesMut, tag: u8, code: VarInt) { + dst.reserve(1 + code.encoding_size()); + dst.put_u8(tag); + encode_varint(dst, code); +} + +fn wrong_direction(tag: u8, direction: &'static str) -> CodecError { + CodecError::InvalidDirection { tag, direction } +} + +/// Codec for worker-to-hypervisor read commands. +#[derive(Debug, Default)] +pub(super) struct ReadCommandCodec { + wire: WireCodec, +} + +impl ReadCommandCodec { + pub(super) fn new() -> Self { + Self::default() + } +} + +impl Decoder for ReadCommandCodec { + type Item = ReadCommand; + type Error = CodecError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + let Some(frame) = self.wire.decode(src)? else { + return Ok(None); + }; + let tag = frame.tag(); + match frame { + WireFrame::Pull => Ok(Some(ReadCommand::Pull)), + WireFrame::Stop { code } => Ok(Some(ReadCommand::Stop { code })), + _ => Err(wrong_direction(tag, "read command")), + } + } +} + +impl Encoder for ReadCommandCodec { + type Error = CodecError; + + fn encode(&mut self, item: ReadCommand, dst: &mut BytesMut) -> Result<(), Self::Error> { + self.wire.encode( + match item { + ReadCommand::Pull => WireFrame::Pull, + ReadCommand::Stop { code } => WireFrame::Stop { code }, + }, + dst, + ) + } +} + +/// Codec for hypervisor-to-worker read events. +#[derive(Debug, Default)] +pub(super) struct ReadEventCodec { + wire: WireCodec, +} + +impl ReadEventCodec { + pub(super) fn new() -> Self { + Self::default() + } +} + +impl Decoder for ReadEventCodec { + type Item = ReadEvent; + type Error = CodecError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + let Some(frame) = self.wire.decode(src)? else { + return Ok(None); + }; + let tag = frame.tag(); + match frame { + WireFrame::Push { data } => Ok(Some(ReadEvent::Push { data })), + WireFrame::Eos => Ok(Some(ReadEvent::Eos)), + WireFrame::StopAck { code } => Ok(Some(ReadEvent::StopAck { code })), + WireFrame::ErrReset { code } => Ok(Some(ReadEvent::ErrReset { code })), + WireFrame::ErrConn => Ok(Some(ReadEvent::ErrConn)), + _ => Err(wrong_direction(tag, "read event")), + } + } +} + +impl Encoder for ReadEventCodec { + type Error = CodecError; + + fn encode(&mut self, item: ReadEvent, dst: &mut BytesMut) -> Result<(), Self::Error> { + self.wire.encode( + match item { + ReadEvent::Push { data } => WireFrame::Push { data }, + ReadEvent::Eos => WireFrame::Eos, + ReadEvent::StopAck { code } => WireFrame::StopAck { code }, + ReadEvent::ErrReset { code } => WireFrame::ErrReset { code }, + ReadEvent::ErrConn => WireFrame::ErrConn, + }, + dst, + ) + } +} + +/// Codec for worker-to-hypervisor write commands. +#[derive(Debug, Default)] +pub(super) struct WriteCommandCodec { + wire: WireCodec, +} + +impl WriteCommandCodec { + pub(super) fn new() -> Self { + Self::default() + } +} + +impl Decoder for WriteCommandCodec { + type Item = WriteCommand; + type Error = CodecError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + let Some(frame) = self.wire.decode(src)? else { + return Ok(None); + }; + let tag = frame.tag(); + match frame { + WireFrame::Push { data } => Ok(Some(WriteCommand::Push { data })), + WireFrame::Flush => Ok(Some(WriteCommand::Flush)), + WireFrame::Eos => Ok(Some(WriteCommand::Eos)), + WireFrame::Reset { code } => Ok(Some(WriteCommand::Reset { code })), + _ => Err(wrong_direction(tag, "write command")), + } + } +} + +impl Encoder for WriteCommandCodec { + type Error = CodecError; + + fn encode(&mut self, item: WriteCommand, dst: &mut BytesMut) -> Result<(), Self::Error> { + self.wire.encode( + match item { + WriteCommand::Push { data } => WireFrame::Push { data }, + WriteCommand::Flush => WireFrame::Flush, + WriteCommand::Eos => WireFrame::Eos, + WriteCommand::Reset { code } => WireFrame::Reset { code }, + }, + dst, + ) + } +} + +/// Codec for hypervisor-to-worker write events. +#[derive(Debug, Default)] +pub(super) struct WriteEventCodec { + wire: WireCodec, +} + +impl WriteEventCodec { + pub(super) fn new() -> Self { + Self::default() + } +} + +impl Decoder for WriteEventCodec { + type Item = WriteEvent; + type Error = CodecError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + let Some(frame) = self.wire.decode(src)? else { + return Ok(None); + }; + let tag = frame.tag(); + match frame { + WireFrame::Pull => Ok(Some(WriteEvent::Pull)), + WireFrame::FlushAck => Ok(Some(WriteEvent::FlushAck)), + WireFrame::EosAck => Ok(Some(WriteEvent::EosAck)), + WireFrame::ResetAck { code } => Ok(Some(WriteEvent::ResetAck { code })), + WireFrame::ErrReset { code } => Ok(Some(WriteEvent::ErrReset { code })), + WireFrame::ErrConn => Ok(Some(WriteEvent::ErrConn)), + _ => Err(wrong_direction(tag, "write event")), + } + } +} + +impl Encoder for WriteEventCodec { + type Error = CodecError; + + fn encode(&mut self, item: WriteEvent, dst: &mut BytesMut) -> Result<(), Self::Error> { + self.wire.encode( + match item { + WriteEvent::Pull => WireFrame::Pull, + WriteEvent::FlushAck => WireFrame::FlushAck, + WriteEvent::EosAck => WireFrame::EosAck, + WriteEvent::ResetAck { code } => WireFrame::ResetAck { code }, + WriteEvent::ErrReset { code } => WireFrame::ErrReset { code }, + WriteEvent::ErrConn => WireFrame::ErrConn, + }, + dst, + ) + } +} + #[cfg(test)] mod tests { use super::*; - fn round_trip(frame: Frame) { - let mut codec = StreamCodec::new(); + fn encode(codec: &mut C, frame: F) -> BytesMut + where + C: Encoder, + { let mut buf = BytesMut::new(); - codec.encode(frame.clone(), &mut buf).unwrap(); - - let mut decoder = StreamCodec::new(); - let decoded = decoder.decode(&mut buf).unwrap().unwrap(); - assert_eq!(decoded, frame); - assert!(buf.is_empty()); + codec.encode(frame, &mut buf).unwrap(); + buf } - #[test] - fn data_frame_round_trip() { - round_trip(Frame::Push(Bytes::from_static(b"hello world"))); + fn decode(codec: &mut C, mut buf: BytesMut) -> Result + where + C: Decoder, + { + codec.decode(&mut buf).map(|item| item.unwrap()) } - #[test] - fn empty_data_frame_round_trip() { - round_trip(Frame::Push(Bytes::new())); + fn assert_direction_error(error: CodecError, tag: u8, direction: &'static str) { + match error { + CodecError::InvalidDirection { + tag: actual, + direction: actual_direction, + } => { + assert_eq!(actual, tag); + assert_eq!(actual_direction, direction); + } + other => panic!("expected invalid direction error, got {other:?}"), + } } #[test] - fn pull_frame_round_trip() { - round_trip(Frame::Pull); + fn read_command_codec_accepts_pull_and_stop() { + let code = VarInt::from_u32(0x42); + let mut encoder = ReadCommandCodec::new(); + let mut decoder = ReadCommandCodec::new(); + + let pull = encode(&mut encoder, ReadCommand::Pull); + assert_eq!(decode(&mut decoder, pull).unwrap(), ReadCommand::Pull); + + let stop = encode(&mut encoder, ReadCommand::Stop { code }); + assert_eq!( + decode(&mut decoder, stop).unwrap(), + ReadCommand::Stop { code } + ); } #[test] - fn stop_frame_round_trip() { - round_trip(Frame::Stop(VarInt::from_u32(0x42))); - } + fn read_event_codec_accepts_read_events() { + let stop = VarInt::from_u32(0x43); + let reset = VarInt::from_u32(0x44); + let frames = [ + ReadEvent::Push { + data: Bytes::from_static(b"read"), + }, + ReadEvent::Eos, + ReadEvent::StopAck { code: stop }, + ReadEvent::ErrReset { code: reset }, + ReadEvent::ErrConn, + ]; + let mut encoder = ReadEventCodec::new(); + let mut decoder = ReadEventCodec::new(); - #[test] - fn cancel_frame_round_trip() { - round_trip(Frame::Cancel(VarInt::from_u32(0))); + for frame in frames { + let encoded = encode(&mut encoder, frame.clone()); + assert_eq!(decode(&mut decoder, encoded).unwrap(), frame); + } } #[test] - fn conn_closed_round_trip() { - round_trip(Frame::ConnClosed); - } + fn write_command_codec_accepts_write_commands() { + let reset = VarInt::from_u32(0x45); + let frames = [ + WriteCommand::Push { + data: Bytes::from_static(b"write"), + }, + WriteCommand::Flush, + WriteCommand::Eos, + WriteCommand::Reset { code: reset }, + ]; + let mut encoder = WriteCommandCodec::new(); + let mut decoder = WriteCommandCodec::new(); - #[test] - fn large_data_frame() { - let data = Bytes::from(vec![0xab; 70_000]); - round_trip(Frame::Push(data)); + for frame in frames { + let encoded = encode(&mut encoder, frame.clone()); + assert_eq!(decode(&mut decoder, encoded).unwrap(), frame); + } } #[test] - fn multiple_frames_in_sequence() { - let mut codec = StreamCodec::new(); - let mut buf = BytesMut::new(); - - let frames = vec![ - Frame::Pull, - Frame::Push(Bytes::from_static(b"first")), - Frame::Stop(VarInt::from_u32(1)), - Frame::Push(Bytes::from_static(b"second")), - Frame::Cancel(VarInt::from_u32(2)), - Frame::ConnClosed, + fn write_event_codec_accepts_write_events() { + let reset = VarInt::from_u32(0x46); + let err_reset = VarInt::from_u32(0x47); + let frames = [ + WriteEvent::Pull, + WriteEvent::FlushAck, + WriteEvent::EosAck, + WriteEvent::ResetAck { code: reset }, + WriteEvent::ErrReset { code: err_reset }, + WriteEvent::ErrConn, ]; + let mut encoder = WriteEventCodec::new(); + let mut decoder = WriteEventCodec::new(); - for f in &frames { - codec.encode(f.clone(), &mut buf).unwrap(); + for frame in frames { + let encoded = encode(&mut encoder, frame.clone()); + assert_eq!(decode(&mut decoder, encoded).unwrap(), frame); } + } - let mut decoder = StreamCodec::new(); - for expected in &frames { - let decoded = decoder.decode(&mut buf).unwrap().unwrap(); - assert_eq!(&decoded, expected); - } - assert!(buf.is_empty()); + #[test] + fn tag_valid_in_wrong_direction_returns_direction_error() { + let mut read_command_encoder = ReadCommandCodec::new(); + let mut read_command_decoder = ReadCommandCodec::new(); + let wrong = encode(&mut read_command_encoder, ReadCommand::Pull); + assert_direction_error( + decode(&mut ReadEventCodec::new(), wrong.clone()).unwrap_err(), + TAG_PULL, + "read event", + ); + assert_direction_error( + decode(&mut WriteCommandCodec::new(), wrong.clone()).unwrap_err(), + TAG_PULL, + "write command", + ); + + let mut write_command_encoder = WriteCommandCodec::new(); + let wrong = encode(&mut write_command_encoder, WriteCommand::Flush); + assert_direction_error( + decode(&mut read_command_decoder, wrong.clone()).unwrap_err(), + TAG_FLUSH, + "read command", + ); + assert_direction_error( + decode(&mut WriteEventCodec::new(), wrong).unwrap_err(), + TAG_FLUSH, + "write event", + ); } #[test] - fn incremental_decode() { - let mut codec = StreamCodec::new(); - let mut buf = BytesMut::new(); - codec - .encode(Frame::Push(Bytes::from_static(b"abc")), &mut buf) - .unwrap(); + fn unknown_tag_error() { + let mut decoder = ReadCommandCodec::new(); + let mut buf = BytesMut::from(&[0xff][..]); + match decoder.decode(&mut buf).unwrap_err() { + CodecError::UnknownTag { tag } => assert_eq!(tag, 0xff), + other => panic!("expected unknown tag, got {other:?}"), + } + } + #[test] + fn incremental_push_decode() { + let payload = Bytes::from_static(b"abc"); + let mut codec = ReadEventCodec::new(); + let mut buf = encode( + &mut codec, + ReadEvent::Push { + data: payload.clone(), + }, + ); let full = buf.split(); - // Feed one byte at a time. - let mut decoder = StreamCodec::new(); + let mut decoder = ReadEventCodec::new(); let mut partial = BytesMut::new(); for i in 0..full.len() - 1 { partial.extend_from_slice(&full[i..i + 1]); assert!(decoder.decode(&mut partial).unwrap().is_none()); } partial.extend_from_slice(&full[full.len() - 1..]); - let decoded = decoder.decode(&mut partial).unwrap().unwrap(); - assert_eq!(decoded, Frame::Push(Bytes::from_static(b"abc"))); + assert_eq!( + decoder.decode(&mut partial).unwrap(), + Some(ReadEvent::Push { data: payload }) + ); } #[test] - fn unknown_tag_error() { - let mut decoder = StreamCodec::new(); - let mut buf = BytesMut::from(&[0xff][..]); - assert!(decoder.decode(&mut buf).is_err()); + fn incomplete_large_push_decode_does_not_reserve_declared_payload_len() { + let declared_len = 10 * 1024 * 1024usize; + let mut codec = ReadEventCodec::new(); + let mut buf = BytesMut::with_capacity(PUSH_HEADER_MAX_LEN); + buf.put_u8(TAG_PUSH); + encode_varint(&mut buf, VarInt::try_from(declared_len).unwrap()); + + assert!(codec.decode(&mut buf).unwrap().is_none()); + assert!( + buf.capacity() <= 128 * 1024, + "decoder reserved declared payload length: capacity {}", + buf.capacity() + ); } #[test] @@ -385,11 +714,13 @@ mod tests { #[test] fn data_header_matches_frame_prefix() { let payload = Bytes::from(vec![0x5a; 1024]); - let mut codec = StreamCodec::new(); - let mut encoded = BytesMut::new(); - codec - .encode(Frame::Push(payload.clone()), &mut encoded) - .unwrap(); + let mut codec = WriteCommandCodec::new(); + let encoded = encode( + &mut codec, + WriteCommand::Push { + data: payload.clone(), + }, + ); let (header, header_len) = encode_push_header(payload.len()).unwrap(); assert_eq!(&encoded[..header_len], &header[..header_len]); diff --git a/src/ipc/quic/stream/handle.rs b/src/ipc/quic/stream/handle.rs index ace82cc..ca50872 100644 --- a/src/ipc/quic/stream/handle.rs +++ b/src/ipc/quic/stream/handle.rs @@ -1,6 +1,6 @@ //! Named handle types for IPC stream operations. //! -//! These types replace the raw `(VarInt, VarInt)` tuples returned by +//! These types replace the raw `VarInt` stream IDs returned by //! [`IpcConnection`](super::quic::IpcConnection) methods, giving each field //! a descriptive name. @@ -10,24 +10,20 @@ use crate::varint::VarInt; /// Handle returned by IPC `open_bi` / `accept_bi` operations. /// -/// Contains the FD-registry ID for retrieving 2 socketpair FDs (one per -/// direction) and the underlying QUIC stream ID. +/// Contains the underlying QUIC stream ID. The FD transfer ID is chosen by +/// the receiver and passed into the request before this handle is returned. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct IpcBiHandle { - /// FD-registry ID for [`FdRegistry::wait_fds`](super::transport::FdRegistry::wait_fds). - pub fd_id: VarInt, /// The underlying QUIC stream ID. pub stream_id: VarInt, } /// Handle returned by IPC `open_uni` / `accept_uni` operations. /// -/// Contains the FD-registry ID for retrieving 1 socketpair FD and the -/// underlying QUIC stream ID. +/// Contains the underlying QUIC stream ID. The FD transfer ID is chosen by +/// the receiver and passed into the request before this handle is returned. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct IpcUniHandle { - /// FD-registry ID for [`FdRegistry::wait_fds`](super::transport::FdRegistry::wait_fds). - pub fd_id: VarInt, /// The underlying QUIC stream ID. pub stream_id: VarInt, } diff --git a/src/ipc/quic/stream/reader.rs b/src/ipc/quic/stream/reader.rs index 777ade7..494089d 100644 --- a/src/ipc/quic/stream/reader.rs +++ b/src/ipc/quic/stream/reader.rs @@ -1,15 +1,9 @@ -//! [`IpcReadStream`] — per-stream socketpair read half with pull-based flow control. +//! IPC read-side typed frame IO. //! -//! Wraps the read direction of a `SOCK_STREAM` socketpair, decoding the -//! framing protocol and exposing it as `Stream>` + -//! [`StopStream`] + [`GetStreamId`], satisfying [`quic::ReadStream`]. -//! -//! # Flow control -//! -//! The reader sends a parameterless `PULL` frame to grant the writer permission -//! to send exactly one `PUSH` frame. The protocol is strictly serial: -//! PULL → PUSH → PULL → PUSH → … This prevents back-pressure breakage -//! across the socketpair. +//! Worker-side handles are constructed directly as +//! [`BridgeStreamReader`](crate::rpc::stream::reader::BridgeStreamReader) +//! over a direction-aware socketpair frame IO. Hypervisor-side adapters use the +//! sibling IO type to execute read commands against a real QUIC read stream. use std::{ pin::Pin, @@ -17,513 +11,184 @@ use std::{ task::{Context, Poll}, }; -use bytes::Bytes; -use futures::Stream; -use tokio::{ - io::AsyncWrite, - net::{ - UnixStream, - unix::{OwnedReadHalf, OwnedWriteHalf}, - }, +use futures::{Sink, Stream}; +use tokio::net::{ + UnixStream, + unix::{OwnedReadHalf, OwnedWriteHalf}, }; -use tokio_util::codec::FramedRead; +use tokio_util::codec::{FramedRead, FramedWrite}; -use super::{ - codec::{Frame, StreamCodec, TAG_PULL, TAG_STOP}, - state::{PipeState, Step, Transition, check_lifecycle, encode_control}, -}; +use super::codec::{CodecError, ReadCommandCodec, ReadEventCodec}; use crate::{ - quic::{self, GetStreamId, StopStream, StreamError}, + quic, + rpc::{ + lifecycle::LifecycleExt, + stream::{ + frame::{ReadCommand, ReadEvent}, + reader::BridgeStreamReader, + }, + }, varint::VarInt, }; -/// Active-state fields for the reader. -struct ReaderLive { - /// Whether a `PULL` frame has been sent and we are waiting for the - /// corresponding `PUSH` reply. - pulling: bool, - read: FramedRead, - write: OwnedWriteHalf, - lifecycle: Arc, -} - -impl ReaderLive { - /// `poll_recv` step: send PULL → block-wait for PUSH. - /// - /// Uses atomic single-byte `poll_write` for PULL, eliminating the double- - /// PULL bug that existed with the FramedWrite approach. - fn step_poll_recv(&mut self, cx: &mut Context<'_>) -> Step { - let Self { - pulling, - read, - write, - lifecycle, - } = self; - - // 1. Send PULL if we haven't yet (atomic single-byte write). - if !*pulling { - match Pin::new(&mut *write).poll_write(cx, &[TAG_PULL]) { - Poll::Ready(Ok(1)) => *pulling = true, - Poll::Ready(Ok(_)) | Poll::Ready(Err(_)) => { - return check_lifecycle(lifecycle, Step::Transition(Transition::Finish)) - .map(|()| unreachable!()); - } - Poll::Pending => return Step::Pending, - } - } - - // 2. Block-wait for the next meaningful frame. - loop { - match Pin::new(&mut *read).poll_next(cx) { - Poll::Ready(Some(Ok(Frame::Push(data)))) => { - *pulling = false; - return Step::Done(data); - } - Poll::Ready(Some(Ok(Frame::Cancel(code)))) => { - return Step::Transition(Transition::Reset(code)); - } - Poll::Ready(Some(Ok(Frame::ConnClosed))) => { - return Step::Transition(Transition::ConnDied(lifecycle.clone())); - } - // PULL/STOP on a reader — protocol mismatch, skip. - Poll::Ready(Some(Ok(_))) => continue, - Poll::Ready(Some(Err(e))) => { - tracing::debug!(%e, "pipe codec error on reader"); - return check_lifecycle(lifecycle, Step::Transition(Transition::Finish)) - .map(|()| unreachable!()); - } - Poll::Ready(None) => { - // EOF — clean close if connection alive, else connection error. - return check_lifecycle(lifecycle, Step::Transition(Transition::Finish)) - .map(|()| unreachable!()); - } - Poll::Pending => return Step::Pending, - } - } +pin_project_lite::pin_project! { + /// Worker-side IPC frame IO for a QUIC read stream. + pub(crate) struct IpcReaderIo { + #[pin] + read: FramedRead, + #[pin] + write: FramedWrite, } +} - /// `poll_stop` step: encode STOP into a stack buffer and write it once. - fn step_poll_stop(&mut self, code: VarInt, cx: &mut Context<'_>) -> Step<()> { - let Self { - write, lifecycle, .. - } = self; - let (buf, len) = encode_control(TAG_STOP, code); - match Pin::new(&mut *write).poll_write(cx, &buf[..len]) { - Poll::Ready(Ok(_)) => Step::Done(()), - Poll::Ready(Err(e)) => { - tracing::debug!(%e, "pipe write error sending STOP"); - check_lifecycle(lifecycle, Step::Done(())) - } - // Frame buffered in kernel — report success. - Poll::Pending => Step::Done(()), +impl IpcReaderIo { + fn new(socket: UnixStream) -> Self { + let (read, write) = socket.into_split(); + Self { + read: FramedRead::new(read, ReadEventCodec::new()), + write: FramedWrite::new(write, ReadCommandCodec::new()), } } } -// ── IpcReadStream ─────────────────────────────────────────────────────────── +impl Sink for IpcReaderIo { + type Error = CodecError; -/// IPC read stream backed by a per-stream Unix socketpair. -/// -/// Reads PUSH/CANCEL/CONN_CLOSED frames from the socketpair. -/// Sends PULL and STOP frames back through the write half of the same -/// socketpair. -pub struct IpcReadStream { - stream_id: VarInt, - state: PipeState, -} + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().write.poll_ready(cx) + } -impl IpcReadStream { - /// Create a new reader from a `tokio::net::UnixStream`. - pub fn new( - stream_id: VarInt, - socket: UnixStream, - lifecycle: Arc, - ) -> Self { - let (read_half, write_half) = socket.into_split(); - Self { - stream_id, - state: PipeState::Live(ReaderLive { - pulling: false, - read: FramedRead::new(read_half, StreamCodec::new()), - write: write_half, - lifecycle, - }), - } + fn start_send(self: Pin<&mut Self>, item: ReadCommand) -> Result<(), Self::Error> { + self.project().write.start_send(item) } - /// Core read method — strict PULL → PUSH serial flow control. - pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll>> { - loop { - if let Some(poll) = self.state.poll_non_live(cx) { - return match poll { - Poll::Ready(Ok(())) => Poll::Ready(None), - Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), - Poll::Pending => Poll::Pending, - }; - } - let live = self.state.live_mut().unwrap(); - match live.step_poll_recv(cx) { - Step::Done(data) => return Poll::Ready(Some(Ok(data))), - Step::Pending => return Poll::Pending, - Step::Transition(t) => self.state.apply(t, cx), - } - } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().write.poll_flush(cx) } -} -impl GetStreamId for IpcReadStream { - fn poll_stream_id( - self: Pin<&mut Self>, - _cx: &mut Context, - ) -> Poll> { - Poll::Ready(Ok(self.stream_id)) + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().write.poll_close(cx) } } -impl Stream for IpcReadStream { - type Item = Result; +impl Stream for IpcReaderIo { + type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_mut().poll_recv(cx) + self.project().read.poll_next(cx) } } -impl StopStream for IpcReadStream { - fn poll_stop( - self: Pin<&mut Self>, - cx: &mut Context, - code: VarInt, - ) -> Poll> { - let this = self.get_mut(); - loop { - if let Some(poll) = this.state.poll_non_live(cx) { - // Already dead — stop is a no-op. - let _ = poll; - return Poll::Ready(Ok(())); - } - let live = this.state.live_mut().unwrap(); - match live.step_poll_stop(code, cx) { - Step::Done(()) => return Poll::Ready(Ok(())), - Step::Pending => return Poll::Pending, - Step::Transition(t) => this.state.apply(t, cx), - } - } - } +pub(crate) fn reader( + stream_id: VarInt, + socket: UnixStream, + lifecycle: Arc, +) -> BridgeStreamReader +where + L: LifecycleExt + 'static, + quic::ConnectionError: From, +{ + BridgeStreamReader::new(stream_id, IpcReaderIo::new(socket), lifecycle) } -#[cfg(test)] -mod tests { - use std::{borrow::Cow, future::pending, pin::Pin, sync::Arc}; - - use bytes::Bytes; - use futures::{SinkExt, StreamExt, future::poll_fn}; - use tokio::{ - io::AsyncWriteExt, - net::{ - UnixStream, - unix::{OwnedReadHalf, OwnedWriteHalf}, - }, - time::{Duration, timeout}, - }; - use tokio_util::codec::{FramedRead, FramedWrite}; - - use super::*; - use crate::quic::{ConnectionError, GetStreamId, StopStream}; - - struct TestLifecycle { - terminal: Option, - } - - impl quic::Lifecycle for TestLifecycle { - fn close(&self, _code: crate::error::Code, _reason: Cow<'static, str>) {} - - fn check(&self) -> Result<(), ConnectionError> { - match &self.terminal { - Some(err) => Err(err.clone()), - None => Ok(()), - } - } - - async fn closed(&self) -> ConnectionError { - match &self.terminal { - Some(err) => err.clone(), - None => pending().await, - } - } +pin_project_lite::pin_project! { + /// Hypervisor-side IPC frame IO for a QUIC read stream. + pub(crate) struct IpcReadHypervisorIo { + #[pin] + read: FramedRead, + #[pin] + write: FramedWrite, } +} - fn test_connection_error(reason: &str) -> ConnectionError { - ConnectionError::Transport { - source: quic::TransportError { - kind: VarInt::from_u32(0x11), - frame_type: VarInt::from_u32(0x22), - reason: reason.to_owned().into(), - }, +impl IpcReadHypervisorIo { + pub(crate) fn new(socket: UnixStream) -> Self { + let (read, write) = socket.into_split(); + Self { + read: FramedRead::new(read, ReadCommandCodec::new()), + write: FramedWrite::new(write, ReadEventCodec::new()), } } +} - fn alive_lifecycle() -> Arc { - Arc::new(TestLifecycle { terminal: None }) - } - - fn dead_lifecycle(reason: &str) -> Arc { - Arc::new(TestLifecycle { - terminal: Some(test_connection_error(reason)), - }) - } - - async fn setup_reader_with_lifecycle( - lifecycle: Arc, - ) -> ( - IpcReadStream, - FramedWrite, - FramedRead, - ) { - let (reader_side, peer_side) = UnixStream::pair().unwrap(); - - let (peer_read, peer_write) = peer_side.into_split(); - - let reader = IpcReadStream::new(VarInt::from_u32(7), reader_side, lifecycle); - - ( - reader, - FramedWrite::new(peer_write, StreamCodec::new()), - FramedRead::new(peer_read, StreamCodec::new()), - ) - } - - async fn setup_reader() -> ( - IpcReadStream, - FramedWrite, - FramedRead, - ) { - setup_reader_with_lifecycle(alive_lifecycle()).await - } - - #[tokio::test] - async fn stream_id_matches_constructor() { - let (mut reader, _peer_in, _peer_out) = setup_reader().await; - let id = poll_fn(|cx| Pin::new(&mut reader).poll_stream_id(cx)) - .await - .unwrap(); - assert_eq!(id, VarInt::from_u32(7)); - } - - #[tokio::test] - async fn pull_is_sent_once_while_waiting_for_push() { - let (mut reader, mut peer_in, mut peer_out) = setup_reader().await; - - let mut recv = Box::pin(reader.next()); - let first = tokio::select! { - frame = peer_out.next() => frame.unwrap().unwrap(), - item = &mut recv => panic!("reader completed unexpectedly: {item:?}"), - }; - assert_eq!(first, Frame::Pull); - - assert!( - timeout(Duration::from_millis(50), peer_out.next()) - .await - .is_err() - ); - - peer_in - .send(Frame::Push(Bytes::from_static(b"drain"))) - .await - .unwrap(); - let got = recv.await.unwrap().unwrap(); - assert_eq!(got, Bytes::from_static(b"drain")); - } - - #[tokio::test] - async fn push_roundtrip_and_next_poll_requests_again() { - let (mut reader, mut peer_in, mut peer_out) = setup_reader().await; - - let mut recv = Box::pin(reader.next()); - let first_pull = tokio::select! { - frame = peer_out.next() => frame.unwrap().unwrap(), - item = &mut recv => panic!("reader completed unexpectedly: {item:?}"), - }; - assert_eq!(first_pull, Frame::Pull); - - peer_in - .send(Frame::Push(Bytes::from_static(b"reader-payload"))) - .await - .unwrap(); - - let got = recv.await.unwrap().unwrap(); - assert_eq!(got, Bytes::from_static(b"reader-payload")); - - let mut recv2 = Box::pin(reader.next()); - let second_pull = tokio::select! { - frame = peer_out.next() => frame.unwrap().unwrap(), - item = &mut recv2 => panic!("reader completed unexpectedly: {item:?}"), - }; - assert_eq!(second_pull, Frame::Pull); - - peer_in - .send(Frame::Push(Bytes::from_static(b"reader-payload-2"))) - .await - .unwrap(); - let got2 = recv2.await.unwrap().unwrap(); - assert_eq!(got2, Bytes::from_static(b"reader-payload-2")); - } - - #[tokio::test] - async fn protocol_mismatch_frames_are_ignored_until_push() { - let (mut reader, mut peer_in, mut peer_out) = setup_reader().await; - - let mut recv = Box::pin(reader.next()); - let pull = tokio::select! { - frame = peer_out.next() => frame.unwrap().unwrap(), - item = &mut recv => panic!("reader completed unexpectedly: {item:?}"), - }; - assert_eq!(pull, Frame::Pull); +impl Sink for IpcReadHypervisorIo { + type Error = CodecError; - peer_in.send(Frame::Pull).await.unwrap(); - peer_in - .send(Frame::Stop(VarInt::from_u32(1))) - .await - .unwrap(); - peer_in - .send(Frame::Push(Bytes::from_static(b"ok"))) - .await - .unwrap(); - - let got = recv.await.unwrap().unwrap(); - assert_eq!(got, Bytes::from_static(b"ok")); + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().write.poll_ready(cx) } - #[tokio::test] - async fn cancel_frame_turns_reader_into_reset() { - let (mut reader, mut peer_in, mut peer_out) = setup_reader().await; - - let mut recv = Box::pin(reader.next()); - let pull = tokio::select! { - frame = peer_out.next() => frame.unwrap().unwrap(), - item = &mut recv => panic!("reader completed unexpectedly: {item:?}"), - }; - assert_eq!(pull, Frame::Pull); - peer_in - .send(Frame::Cancel(VarInt::from_u32(9))) - .await - .unwrap(); - - let err = recv.await.unwrap().unwrap_err(); - match err { - StreamError::Reset { code } => assert_eq!(code, VarInt::from_u32(9)), - other => panic!("expected reset error, got {other:?}"), - } + fn start_send(self: Pin<&mut Self>, item: ReadEvent) -> Result<(), Self::Error> { + self.project().write.start_send(item) } - #[tokio::test] - async fn conn_closed_with_dead_lifecycle_returns_connection_error() { - let (mut reader, mut peer_in, mut peer_out) = - setup_reader_with_lifecycle(dead_lifecycle("reader conn closed")).await; - - let mut recv = Box::pin(reader.next()); - let pull = tokio::select! { - frame = peer_out.next() => frame.unwrap().unwrap(), - item = &mut recv => panic!("reader completed unexpectedly: {item:?}"), - }; - assert_eq!(pull, Frame::Pull); - peer_in.send(Frame::ConnClosed).await.unwrap(); - - let err = recv.await.unwrap().unwrap_err(); - assert!(matches!(err, StreamError::Connection { .. })); + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().write.poll_flush(cx) } - #[tokio::test] - async fn eof_with_alive_lifecycle_finishes_stream() { - let (mut reader, peer_in, mut peer_out) = setup_reader().await; - - let mut recv = Box::pin(reader.next()); - let pull = tokio::select! { - frame = peer_out.next() => frame.unwrap().unwrap(), - item = &mut recv => panic!("reader completed unexpectedly: {item:?}"), - }; - assert_eq!(pull, Frame::Pull); - drop(peer_in); - - assert!(recv.await.is_none()); + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().write.poll_close(cx) } +} - #[tokio::test] - async fn eof_with_dead_lifecycle_returns_connection_error() { - let (mut reader, peer_in, mut peer_out) = - setup_reader_with_lifecycle(dead_lifecycle("reader eof while dead")).await; - - let mut recv = Box::pin(reader.next()); - let pull = tokio::select! { - frame = peer_out.next() => frame.unwrap().unwrap(), - item = &mut recv => panic!("reader completed unexpectedly: {item:?}"), - }; - assert_eq!(pull, Frame::Pull); - drop(peer_in); +impl Stream for IpcReadHypervisorIo { + type Item = Result; - let err = recv.await.unwrap().unwrap_err(); - assert!(matches!(err, StreamError::Connection { .. })); + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().read.poll_next(cx) } +} - #[tokio::test] - async fn stop_sends_stop_frame_with_code() { - let (mut reader, _peer_in, mut peer_out) = setup_reader().await; - let code = VarInt::from_u32(77); +#[cfg(test)] +mod tests { + use std::sync::Arc; - poll_fn(|cx| Pin::new(&mut reader).poll_stop(cx, code)) - .await - .unwrap(); + use bytes::Bytes; + use futures::{SinkExt as _, StreamExt as _}; + use tokio::net::UnixStream; + use tokio_util::codec::{FramedRead, FramedWrite}; - // STOP is best-effort: if it is observed on wire, it must carry the - // same code. Implementations may still report success when the frame - // has only been queued locally. - if let Ok(Some(Ok(frame))) = timeout(Duration::from_millis(100), peer_out.next()).await { - assert_eq!(frame, Frame::Stop(code)); - } - } + use super::*; + use crate::rpc::stream::{ + frame::{ReadCommand, ReadEvent}, + test_io::TestLifecycle, + }; #[tokio::test] - async fn stop_is_noop_after_reader_is_dead() { - let (mut reader, mut peer_in, mut peer_out) = - setup_reader_with_lifecycle(dead_lifecycle("reader dead before stop")).await; - - let mut recv = Box::pin(reader.next()); - let pull = tokio::select! { - frame = peer_out.next() => frame.unwrap().unwrap(), - item = &mut recv => panic!("reader completed unexpectedly: {item:?}"), - }; - assert_eq!(pull, Frame::Pull); - peer_in.send(Frame::ConnClosed).await.unwrap(); - let _ = recv.await; - - let code = VarInt::from_u32(5); - poll_fn(|cx| Pin::new(&mut reader).poll_stop(cx, code)) + async fn reader_constructs_bridge_over_ipc_codecs() { + let stream_id = VarInt::from_u32(21); + let (worker_socket, peer_socket) = UnixStream::pair().unwrap(); + let lifecycle = Arc::new(TestLifecycle::new()); + let mut reader = reader(stream_id, worker_socket, lifecycle); + let (peer_read, peer_write) = peer_socket.into_split(); + let mut peer_read = FramedRead::new(peer_read, ReadCommandCodec::new()); + let mut peer_write = FramedWrite::new(peer_write, ReadEventCodec::new()); + + let task = tokio::spawn(async move { reader.next().await }); + assert_eq!(peer_read.next().await.unwrap().unwrap(), ReadCommand::Pull); + peer_write + .send(ReadEvent::Push { + data: Bytes::from_static(b"ipc read"), + }) .await .unwrap(); - if let Ok(Some(Ok(frame))) = timeout(Duration::from_millis(50), peer_out.next()).await { - assert_ne!(frame, Frame::Stop(code)); - } + assert_eq!( + task.await.unwrap().unwrap().unwrap(), + Bytes::from_static(b"ipc read") + ); } #[tokio::test] - async fn codec_error_finishes_reader_when_connection_alive() { - let (reader_side, peer_side) = UnixStream::pair().unwrap(); - - let (peer_read, mut peer_write) = peer_side.into_split(); - - let lifecycle = alive_lifecycle(); - let mut reader = IpcReadStream::new(VarInt::from_u32(1), reader_side, lifecycle); - let mut peer_out = FramedRead::new(peer_read, StreamCodec::new()); - - let mut recv = Box::pin(reader.next()); - let pull = tokio::select! { - frame = peer_out.next() => frame.unwrap().unwrap(), - item = &mut recv => panic!("reader completed unexpectedly: {item:?}"), - }; - assert_eq!(pull, Frame::Pull); - - peer_write.write_all(&[0xff]).await.unwrap(); - - assert!(recv.await.is_none()); + async fn reader_bridge_io_eof_latches_connection_error() { + let stream_id = VarInt::from_u32(22); + let (worker_socket, peer_socket) = UnixStream::pair().unwrap(); + let lifecycle = Arc::new(TestLifecycle::new()); + let mut reader = reader(stream_id, worker_socket, lifecycle.clone()); + drop(peer_socket); + + let error = reader.next().await.unwrap().unwrap_err(); + assert!(matches!(error, quic::StreamError::Connection { .. })); + assert!(quic::Lifecycle::check(lifecycle.as_ref()).is_err()); } } diff --git a/src/ipc/quic/stream/state.rs b/src/ipc/quic/stream/state.rs deleted file mode 100644 index 4ec63a3..0000000 --- a/src/ipc/quic/stream/state.rs +++ /dev/null @@ -1,331 +0,0 @@ -//! Shared state-machine primitives for [`PipeReader`] and [`PipeWriter`]. -//! -//! This module provides the building blocks that both the reader and writer -//! halves compose to implement their respective poll methods: -//! -//! - [`Step`] — a poll-step monad supporting `map` / `and_then` chaining. -//! - [`Transition`] — the vocabulary of state transitions. -//! - [`PipeState`] — generic three-state lifecycle machine. -//! - [`drain()`] — higher-order function for non-blocking inbound frame drain. -//! - [`flush_pending()`] / [`check_lifecycle()`] / [`encode_control()`] — I/O -//! primitives shared across reader and writer. - -use std::{ - io::IoSlice, - ops::ControlFlow, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; - -use bytes::Bytes; -use futures::{Stream, future::BoxFuture}; -use tokio::{io::AsyncWrite, net::unix::OwnedReadHalf}; -use tokio_util::codec::FramedRead; - -use super::codec::{ - CONTROL_MAX_LEN, Frame, PUSH_HEADER_MAX_LEN, StreamCodec, encode_push_header, - encode_varint_to_slice, -}; -use crate::{ - quic::{self, ConnectionError, StreamError}, - varint::VarInt, -}; - -// ── Step — poll-step monad ─────────────────────────────────────────────── - -/// Result of a single poll step on a live pipe half. -/// -/// Supports monadic chaining: `and_then` short-circuits on `Pending` and -/// `Transition`, only calling the continuation on `Done`. -pub(super) enum Step { - /// Operation completed with value `T`. - Done(T), - /// Operation would block — register waker and return `Poll::Pending`. - Pending, - /// A state transition should be applied before the next poll iteration. - Transition(Transition), -} - -impl Step { - pub fn map(self, f: impl FnOnce(T) -> U) -> Step { - match self { - Step::Done(v) => Step::Done(f(v)), - Step::Pending => Step::Pending, - Step::Transition(t) => Step::Transition(t), - } - } - - pub fn and_then(self, f: impl FnOnce(T) -> Step) -> Step { - match self { - Step::Done(v) => f(v), - Step::Pending => Step::Pending, - Step::Transition(t) => Step::Transition(t), - } - } -} - -// ── Transition — state-change vocabulary ──────────────────────────────────── - -/// A state transition to be applied by [`PipeState::apply`]. -pub(super) enum Transition { - /// Peer requested reset (STOP on writer / CANCEL on reader). - Reset(VarInt), - /// Clean close (writer SHUT_WR / reader EOF on alive connection). - Finish, - /// Connection died (CONN_CLOSED / lifecycle check failure / EOF on dead - /// connection). Carries the lifecycle handle so `apply` can create the - /// `Dying` future without borrowing the `Live` state. - ConnDied(Arc), -} - -// ── PipeState — generic three-state lifecycle machine ──────────────────── - -/// Three-state lifecycle machine parameterised over the live-state type `L`. -/// -/// ```text -/// Live(L) ──Reset──→ Dead(Err(Reset)) -/// ──Finish─→ Dead(Ok(())) -/// ──ConnDied→ Dying(fut) ──Ready(e)→ Dead(Err(Connection(e))) -/// └─check()→Err──────→ Dead(Err(Connection(e))) -/// ``` -pub(super) enum PipeState { - /// Normal operation. - Live(L), - /// Connection closed — waiting for the terminal error. - Dying(BoxFuture<'static, ConnectionError>), - /// Stream has ended. - Dead(Result<(), StreamError>), -} - -impl PipeState { - /// Apply a [`Transition`], consuming the `Live` state if necessary. - /// - /// After this call the state is guaranteed to be `Dying` or `Dead`. - pub fn apply(&mut self, transition: Transition, cx: &mut Context<'_>) { - match transition { - Transition::Reset(code) => { - *self = PipeState::Dead(Err(StreamError::Reset { code })); - } - Transition::Finish => { - *self = PipeState::Dead(Ok(())); - } - Transition::ConnDied(lifecycle) => { - if let Err(e) = lifecycle.check() { - *self = PipeState::Dead(Err(StreamError::Connection { source: e })); - return; - } - let mut fut: BoxFuture<'static, ConnectionError> = - Box::pin(async move { lifecycle.closed().await }); - match fut.as_mut().poll(cx) { - Poll::Ready(e) => { - *self = PipeState::Dead(Err(StreamError::Connection { source: e })); - } - Poll::Pending => { - *self = PipeState::Dying(fut); - } - } - } - } - } - - /// Drive non-`Live` states towards `Dead` and return the terminal result. - /// - /// Returns `None` when in `Live` — the caller should proceed with the - /// step function. Returns `Some(poll)` for `Dying` / `Dead`. - pub fn poll_non_live(&mut self, cx: &mut Context<'_>) -> Option>> { - loop { - match self { - PipeState::Live(_) => return None, - PipeState::Dying(fut) => { - let e = match fut.as_mut().poll(cx) { - Poll::Ready(e) => e, - Poll::Pending => return Some(Poll::Pending), - }; - *self = PipeState::Dead(Err(StreamError::Connection { source: e })); - continue; - } - PipeState::Dead(Ok(())) => return Some(Poll::Ready(Ok(()))), - PipeState::Dead(Err(e)) => return Some(Poll::Ready(Err(e.clone()))), - } - } - } - - /// Access the `Live` state. Returns `None` for `Dying`/`Dead`. - pub fn live_mut(&mut self) -> Option<&mut L> { - match self { - PipeState::Live(l) => Some(l), - _ => None, - } - } -} - -// ── DrainOutcome + drain() HOF ────────────────────────────────────────────── - -/// Result of draining all ready inbound frames from a `FramedRead`. -pub(super) enum DrainOutcome { - /// All ready frames consumed, no terminal signal. - Drained, - /// Callback returned `Break` — a state transition is required. - Break(Transition), - /// Read half EOF or codec error. - ReadClosed, -} - -impl DrainOutcome { - /// Convert to [`Step<()>`]. - /// - /// `Drained` → `Step::Done(())`, `Break` → `Step::Transition`. - /// `ReadClosed` semantics differ between reader and writer, so the caller - /// provides a closure to handle it. - pub fn resolve(self, on_read_closed: impl FnOnce() -> Step<()>) -> Step<()> { - match self { - DrainOutcome::Drained => Step::Done(()), - DrainOutcome::Break(t) => Step::Transition(t), - DrainOutcome::ReadClosed => on_read_closed(), - } - } -} - -/// Non-blocking drain of all ready inbound frames. -/// -/// Calls `on_frame` for each decoded frame. The callback returns -/// `ControlFlow::Continue(())` to keep draining or -/// `ControlFlow::Break(Transition)` to stop immediately. -pub(super) fn drain( - read: &mut FramedRead, - cx: &mut Context<'_>, - mut on_frame: impl FnMut(Frame) -> ControlFlow, -) -> DrainOutcome { - loop { - match Pin::new(&mut *read).poll_next(cx) { - Poll::Ready(Some(Ok(frame))) => match on_frame(frame) { - ControlFlow::Continue(()) => continue, - ControlFlow::Break(t) => return DrainOutcome::Break(t), - }, - Poll::Ready(Some(Err(_))) | Poll::Ready(None) => { - return DrainOutcome::ReadClosed; - } - Poll::Pending => return DrainOutcome::Drained, - } - } -} - -// ── Shared I/O primitives ─────────────────────────────────────────────────── - -/// Buffered PUSH frame waiting to be flushed via vectored I/O. -pub(super) struct PendingPush { - header: [u8; PUSH_HEADER_MAX_LEN], - header_len: usize, - header_off: usize, - body: Bytes, - body_off: usize, -} - -impl PendingPush { - /// Create a new pending PUSH frame from a body payload. - pub fn new(body: Bytes) -> Result { - let (header, header_len) = - encode_push_header(body.len()).map_err(|_| StreamError::Reset { - code: VarInt::default(), - })?; - Ok(Self { - header, - header_len, - header_off: 0, - body, - body_off: 0, - }) - } - - fn header_remaining(&self) -> &[u8] { - &self.header[self.header_off..self.header_len] - } - - fn body_remaining(&self) -> &[u8] { - &self.body[self.body_off..] - } - - fn is_done(&self) -> bool { - self.header_off == self.header_len && self.body_off == self.body.len() - } - - fn advance(&mut self, mut written: usize) { - let header_left = self.header_len - self.header_off; - let take_header = written.min(header_left); - self.header_off += take_header; - written -= take_header; - self.body_off += written; - } -} - -/// Flush a pending PUSH frame through `write` using vectored I/O. -/// -/// Returns `Step::Done(())` when fully written, `Step::Pending` when blocked, -/// or `Step::Transition(ConnDied)` on I/O error (after lifecycle check). -pub(super) fn flush_pending( - write: &mut (impl AsyncWrite + Unpin), - lifecycle: &Arc, - pending: &mut Option, - cx: &mut Context<'_>, -) -> Step<()> { - loop { - let Some(p) = pending.as_ref() else { - return Step::Done(()); - }; - - let header_remaining = p.header_remaining(); - let body_remaining = p.body_remaining(); - let bufs = [IoSlice::new(header_remaining), IoSlice::new(body_remaining)]; - let bufs: &[IoSlice<'_>] = if header_remaining.is_empty() { - &bufs[1..] - } else if body_remaining.is_empty() { - &bufs[..1] - } else { - &bufs - }; - - let written = match Pin::new(&mut *write).poll_write_vectored(cx, bufs) { - Poll::Ready(Ok(0)) => { - return check_lifecycle(lifecycle, Step::Transition(Transition::Finish)); - } - Poll::Ready(Ok(n)) => n, - Poll::Ready(Err(e)) => { - tracing::debug!(%e, "pipe write error during PUSH flush"); - return check_lifecycle(lifecycle, Step::Transition(Transition::Finish)); - } - Poll::Pending => return Step::Pending, - }; - - let p = pending.as_mut().unwrap(); - p.advance(written); - if p.is_done() { - *pending = None; - } - } -} - -/// Check the connection lifecycle. -/// -/// If the connection is dead, returns `Transition::ConnDied`. -/// Otherwise returns `on_alive` unchanged. -pub(super) fn check_lifecycle( - lifecycle: &Arc, - on_alive: Step<()>, -) -> Step<()> { - if lifecycle.check().is_err() { - Step::Transition(Transition::ConnDied(lifecycle.clone())) - } else { - on_alive - } -} - -/// Encode a control frame (STOP or CANCEL) into a stack buffer. -/// -/// Returns `(buffer, length)` ready for a single `poll_write`. -pub(super) fn encode_control(tag: u8, code: VarInt) -> ([u8; CONTROL_MAX_LEN], usize) { - let mut buf = [0u8; CONTROL_MAX_LEN]; - buf[0] = tag; - let vi_len = encode_varint_to_slice(&mut buf[1..], code); - (buf, 1 + vi_len) -} diff --git a/src/ipc/quic/stream/writer.rs b/src/ipc/quic/stream/writer.rs index e9cdbbf..813ecf0 100644 --- a/src/ipc/quic/stream/writer.rs +++ b/src/ipc/quic/stream/writer.rs @@ -1,637 +1,203 @@ -//! [`IpcWriteStream`] — per-stream socketpair write half with pull-based flow control. +//! IPC write-side typed frame IO. //! -//! Wraps the write direction of a `SOCK_STREAM` socketpair, encoding the pipe -//! framing protocol and exposing it as `Sink` + -//! [`CancelStream`] + [`GetStreamId`], satisfying [`quic::WriteStream`]. -//! -//! # Flow control -//! -//! The writer may only send data after the peer reader has granted permission -//! via a `PULL` frame. The protocol is strictly serial: each `PULL` permits -//! exactly one `PUSH` frame. `poll_ready` returns `Pending` when no `PULL` -//! has been received, and resumes once one arrives. -//! -//! # Flush / Close semantics -//! -//! - `poll_flush` — flushes the pending PUSH frame via vectored I/O. -//! - `poll_close` — flushes remaining data, then shuts down the write half -//! (`SHUT_WR`), which the remote side observes as EOF (equivalent to QUIC FIN). +//! Worker-side handles are constructed directly as +//! [`BridgeStreamWriter`](crate::rpc::stream::writer::BridgeStreamWriter) +//! over a direction-aware socketpair frame IO. Hypervisor-side adapters use the +//! sibling IO type to execute write commands against a real QUIC write stream. use std::{ - ops::ControlFlow, pin::Pin, sync::Arc, task::{Context, Poll}, }; -use bytes::Bytes; -use futures::Sink; -use tokio::{ - io::AsyncWrite, - net::{ - UnixStream, - unix::{OwnedReadHalf, OwnedWriteHalf}, - }, +use futures::{Sink, Stream}; +use tokio::net::{ + UnixStream, + unix::{OwnedReadHalf, OwnedWriteHalf}, }; -use tokio_util::codec::FramedRead; +use tokio_util::codec::{FramedRead, FramedWrite}; -use super::{ - codec::{Frame, StreamCodec, TAG_CANCEL, TAG_CONN_CLOSED}, - state::{ - PendingPush, PipeState, Step, Transition, check_lifecycle, drain, encode_control, - flush_pending, - }, -}; +use super::codec::{CodecError, WriteCommandCodec, WriteEventCodec}; use crate::{ - quic::{self, CancelStream, GetStreamId, StreamError}, + quic, + rpc::{ + lifecycle::LifecycleExt, + stream::{ + frame::{WriteCommand, WriteEvent}, + writer::BridgeStreamWriter, + }, + }, varint::VarInt, }; -/// Active-state fields for the writer. -struct WriterLive { - read: FramedRead, - write: OwnedWriteHalf, - lifecycle: Arc, - pulled: bool, - pending: Option, +pin_project_lite::pin_project! { + /// Worker-side IPC frame IO for a QUIC write stream. + pub(crate) struct IpcWriterIo { + #[pin] + read: FramedRead, + #[pin] + write: FramedWrite, + } } -impl WriterLive { - /// Drain inbound control frames and check the lifecycle. - /// - /// - `PULL` → grant send permission - /// - `STOP(code)` → peer requests reset - /// - `CONN_CLOSED` → connection died - /// - others → ignore - /// - /// On read-close: sends best-effort `CONN_CLOSED` (only when no partial - /// PUSH is in flight) and transitions to `ConnDied`. The pipe EOF is - /// treated as authoritative — no more `PULL` frames can ever arrive — - /// even if `lifecycle.check()` has not yet been updated. - fn drain_and_check(&mut self, cx: &mut Context<'_>) -> Step<()> { - let outcome = drain(&mut self.read, cx, |frame| match frame { - Frame::Pull => { - self.pulled = true; - ControlFlow::Continue(()) - } - Frame::Stop(code) => ControlFlow::Break(Transition::Reset(code)), - Frame::ConnClosed => ControlFlow::Break(Transition::ConnDied(self.lifecycle.clone())), - _ => ControlFlow::Continue(()), - }); - - outcome.resolve(|| { - // Pipe read half is closed: the peer bridge task has exited - // (either because the connection died or because it finished). - // No more PULL frames can ever arrive, so continuing would - // deadlock poll_ready. Transition unconditionally. - if self.pending.is_none() { - let _ = Pin::new(&mut self.write).poll_write(cx, &[TAG_CONN_CLOSED]); - } - Step::Transition(Transition::ConnDied(self.lifecycle.clone())) - }) +impl IpcWriterIo { + fn new(socket: UnixStream) -> Self { + let (read, write) = socket.into_split(); + Self { + read: FramedRead::new(read, WriteEventCodec::new()), + write: FramedWrite::new(write, WriteCommandCodec::new()), + } } +} - /// `poll_ready` step: drain control → eagerly flush pending → check pulled. - /// - /// Callers (e.g. `SinkWriter::poll_write`) may call `poll_ready` after - /// `start_send` without an intervening `poll_flush`. To avoid deadlock - /// we eagerly flush pending PUSH data here, ensuring the bridge-writer - /// can receive it and send the next PULL. - fn step_poll_ready(&mut self, cx: &mut Context<'_>) -> Step<()> { - let step = self.drain_and_check(cx); - step.and_then(|()| { - // Eagerly flush any pending PUSH — prevents deadlock when the - // caller skips poll_flush between start_send and poll_ready. - if self.pending.is_some() { - match flush_pending(&mut self.write, &self.lifecycle, &mut self.pending, cx) { - Step::Transition(t) => return Step::Transition(t), - Step::Pending | Step::Done(()) => {} - } - } +impl Sink for IpcWriterIo { + type Error = CodecError; - if self.pulled && self.pending.is_none() { - Step::Done(()) - } else { - Step::Pending - } - }) + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().write.poll_ready(cx) } - /// `poll_flush` step: drain control → flush pending PUSH via vectored I/O. - fn step_poll_flush(&mut self, cx: &mut Context<'_>) -> Step<()> { - let step = self.drain_and_check(cx); - step.and_then(|()| flush_pending(&mut self.write, &self.lifecycle, &mut self.pending, cx)) + fn start_send(self: Pin<&mut Self>, item: WriteCommand) -> Result<(), Self::Error> { + self.project().write.start_send(item) } - /// `poll_close` step: flush pending data → shutdown write half (SHUT_WR). - fn step_poll_close(&mut self, cx: &mut Context<'_>) -> Step<()> { - let Self { - write, - lifecycle, - pending, - .. - } = self; - flush_pending(write, lifecycle, pending, cx).and_then(|()| { - match Pin::new(&mut *write).poll_shutdown(cx) { - Poll::Ready(Ok(())) => Step::Transition(Transition::Finish), - Poll::Ready(Err(e)) => { - tracing::debug!(%e, "pipe write error during close shutdown"); - check_lifecycle(lifecycle, Step::Transition(Transition::Finish)) - } - Poll::Pending => Step::Pending, - } - }) + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().write.poll_flush(cx) } - /// `poll_cancel` step: discard pending data → best-effort CANCEL → shutdown. - fn step_poll_cancel(&mut self, code: VarInt, cx: &mut Context<'_>) -> Step<()> { - self.pending = None; - let (buf, len) = encode_control(TAG_CANCEL, code); - let _ = Pin::new(&mut self.write).poll_write(cx, &buf[..len]); - match Pin::new(&mut self.write).poll_shutdown(cx) { - Poll::Ready(_) => Step::Transition(Transition::Finish), - Poll::Pending => Step::Pending, - } + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().write.poll_close(cx) } } -// ── IpcWriteStream ────────────────────────────────────────────────────────────── +impl Stream for IpcWriterIo { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().read.poll_next(cx) + } +} -/// IPC write stream backed by a per-stream Unix socketpair. -/// -/// Sends PUSH frames through the write half of the socketpair. -/// Reads PULL/STOP/CONN_CLOSED frames from the read half for flow control -/// and lifecycle signals. -pub struct IpcWriteStream { +pub(crate) fn writer( stream_id: VarInt, - state: PipeState, + socket: UnixStream, + lifecycle: Arc, +) -> BridgeStreamWriter +where + L: LifecycleExt + 'static, + quic::ConnectionError: From, +{ + BridgeStreamWriter::new(stream_id, IpcWriterIo::new(socket), lifecycle) } -impl IpcWriteStream { - /// Create a new writer from a `tokio::net::UnixStream`. - pub fn new( - stream_id: VarInt, - socket: UnixStream, - lifecycle: Arc, - ) -> Self { - let (read_half, write_half) = socket.into_split(); - Self { - stream_id, - state: PipeState::Live(WriterLive { - read: FramedRead::new(read_half, StreamCodec::new()), - write: write_half, - lifecycle, - pulled: false, - pending: None, - }), - } +pin_project_lite::pin_project! { + /// Hypervisor-side IPC frame IO for a QUIC write stream. + pub(crate) struct IpcWriteHypervisorIo { + #[pin] + read: FramedRead, + #[pin] + write: FramedWrite, } } -impl GetStreamId for IpcWriteStream { - fn poll_stream_id( - self: Pin<&mut Self>, - _cx: &mut Context, - ) -> Poll> { - Poll::Ready(Ok(self.stream_id)) +impl IpcWriteHypervisorIo { + pub(crate) fn new(socket: UnixStream) -> Self { + let (read, write) = socket.into_split(); + Self { + read: FramedRead::new(read, WriteCommandCodec::new()), + write: FramedWrite::new(write, WriteEventCodec::new()), + } } } -impl Sink for IpcWriteStream { - type Error = StreamError; +impl Sink for IpcWriteHypervisorIo { + type Error = CodecError; - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - loop { - if let Some(poll) = this.state.poll_non_live(cx) { - return poll.map(|r| { - r.and(Err(StreamError::Reset { - code: VarInt::default(), - })) - }); - } - let live = this.state.live_mut().unwrap(); - match live.step_poll_ready(cx) { - Step::Done(()) => return Poll::Ready(Ok(())), - Step::Pending => return Poll::Pending, - Step::Transition(t) => this.state.apply(t, cx), - } - } + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().write.poll_ready(cx) } - fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), StreamError> { - let this = self.get_mut(); - let PipeState::Live(live) = &mut this.state else { - return Err(StreamError::Reset { - code: VarInt::default(), - }); - }; - if !live.pulled || live.pending.is_some() { - return Err(StreamError::Reset { - code: VarInt::default(), - }); - } - live.pending = Some(PendingPush::new(item)?); - live.pulled = false; - Ok(()) + fn start_send(self: Pin<&mut Self>, item: WriteEvent) -> Result<(), Self::Error> { + self.project().write.start_send(item) } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - loop { - if let Some(poll) = this.state.poll_non_live(cx) { - return poll.map(|r| { - r.and(Err(StreamError::Reset { - code: VarInt::default(), - })) - }); - } - let live = this.state.live_mut().unwrap(); - match live.step_poll_flush(cx) { - Step::Done(()) => return Poll::Ready(Ok(())), - Step::Pending => return Poll::Pending, - Step::Transition(t) => this.state.apply(t, cx), - } - } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().write.poll_flush(cx) } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - loop { - if let Some(poll) = this.state.poll_non_live(cx) { - // Dead(Ok(())) = clean close succeeded; Dead(Err(_)) = already failed. - return poll; - } - let live = this.state.live_mut().unwrap(); - match live.step_poll_close(cx) { - Step::Done(()) => return Poll::Ready(Ok(())), - Step::Pending => return Poll::Pending, - Step::Transition(t) => this.state.apply(t, cx), - } - } + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().write.poll_close(cx) } } -impl CancelStream for IpcWriteStream { - fn poll_cancel( - self: Pin<&mut Self>, - cx: &mut Context, - code: VarInt, - ) -> Poll> { - let this = self.get_mut(); - loop { - if let Some(poll) = this.state.poll_non_live(cx) { - // Already dead — cancel is a no-op. - let _ = poll; - return Poll::Ready(Ok(())); - } - let live = this.state.live_mut().unwrap(); - match live.step_poll_cancel(code, cx) { - Step::Done(()) => return Poll::Ready(Ok(())), - Step::Pending => return Poll::Pending, - Step::Transition(t) => this.state.apply(t, cx), - } - } +impl Stream for IpcWriteHypervisorIo { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().read.poll_next(cx) } } #[cfg(test)] mod tests { - use std::{ - borrow::Cow, - future::pending, - pin::Pin, - sync::Arc, - task::{Context, Poll}, - }; + use std::sync::Arc; use bytes::Bytes; - use futures::{Sink, SinkExt, StreamExt, future::poll_fn, task::noop_waker_ref}; - use tokio::{ - net::{ - UnixStream, - unix::{OwnedReadHalf, OwnedWriteHalf}, - }, - time::{Duration, timeout}, - }; + use futures::{SinkExt as _, StreamExt as _}; + use tokio::net::UnixStream; use tokio_util::codec::{FramedRead, FramedWrite}; use super::*; - use crate::quic::{CancelStream, ConnectionError}; - - struct TestLifecycle { - terminal: Option, - } - - impl quic::Lifecycle for TestLifecycle { - fn close(&self, _code: crate::error::Code, _reason: Cow<'static, str>) {} - - fn check(&self) -> Result<(), ConnectionError> { - match &self.terminal { - Some(err) => Err(err.clone()), - None => Ok(()), - } - } - - async fn closed(&self) -> ConnectionError { - match &self.terminal { - Some(err) => err.clone(), - None => pending().await, - } - } - } - - fn test_connection_error(reason: &str) -> ConnectionError { - ConnectionError::Transport { - source: quic::TransportError { - kind: VarInt::from_u32(0x31), - frame_type: VarInt::from_u32(0x32), - reason: reason.to_owned().into(), - }, - } - } - - fn alive_lifecycle() -> Arc { - Arc::new(TestLifecycle { terminal: None }) - } - - fn dead_lifecycle(reason: &str) -> Arc { - Arc::new(TestLifecycle { - terminal: Some(test_connection_error(reason)), - }) - } - - async fn setup_writer_with_lifecycle( - lifecycle: Arc, - ) -> ( - IpcWriteStream, - FramedWrite, - FramedRead, - ) { - let (writer_side, peer_side) = UnixStream::pair().unwrap(); - let (peer_read, peer_write) = peer_side.into_split(); - - let writer = IpcWriteStream::new(VarInt::from_u32(1), writer_side, lifecycle); - - ( - writer, - FramedWrite::new(peer_write, StreamCodec::new()), - FramedRead::new(peer_read, StreamCodec::new()), - ) - } - - async fn setup_writer() -> ( - IpcWriteStream, - FramedWrite, - FramedRead, - ) { - setup_writer_with_lifecycle(alive_lifecycle()).await - } - - fn poll_ready_once(writer: &mut IpcWriteStream) -> Poll> { - let waker = noop_waker_ref(); - let mut cx = Context::from_waker(waker); - Pin::new(writer).poll_ready(&mut cx) - } - - #[tokio::test] - async fn poll_ready_pending_without_pull() { - let (mut writer, _peer_ctrl, _peer_data) = setup_writer().await; - assert!(poll_ready_once(&mut writer).is_pending()); - } - - #[tokio::test] - async fn pull_then_send_data_roundtrip() { - let (mut writer, mut peer_ctrl, mut peer_data) = setup_writer().await; - - peer_ctrl.send(Frame::Pull).await.unwrap(); - - poll_fn(|cx| Pin::new(&mut writer).poll_ready(cx)) - .await - .unwrap(); - Pin::new(&mut writer) - .start_send(Bytes::from_static(b"hello pipe")) - .unwrap(); - poll_fn(|cx| Pin::new(&mut writer).poll_flush(cx)) - .await - .unwrap(); - - let frame = peer_data.next().await.unwrap().unwrap(); - assert_eq!(frame, Frame::Push(Bytes::from_static(b"hello pipe"))); - } - - #[tokio::test] - async fn pull_permission_consumed_per_data_frame() { - let (mut writer, mut peer_ctrl, mut peer_data) = setup_writer().await; - - peer_ctrl.send(Frame::Pull).await.unwrap(); - poll_fn(|cx| Pin::new(&mut writer).poll_ready(cx)) - .await - .unwrap(); - Pin::new(&mut writer) - .start_send(Bytes::from_static(b"first")) - .unwrap(); - poll_fn(|cx| Pin::new(&mut writer).poll_flush(cx)) - .await - .unwrap(); - let _ = peer_data.next().await.unwrap().unwrap(); - - assert!(poll_ready_once(&mut writer).is_pending()); - - peer_ctrl.send(Frame::Pull).await.unwrap(); - poll_fn(|cx| Pin::new(&mut writer).poll_ready(cx)) - .await - .unwrap(); - Pin::new(&mut writer) - .start_send(Bytes::from_static(b"second")) - .unwrap(); - poll_fn(|cx| Pin::new(&mut writer).poll_flush(cx)) - .await - .unwrap(); - - let frame = peer_data.next().await.unwrap().unwrap(); - assert_eq!(frame, Frame::Push(Bytes::from_static(b"second"))); - } - - #[tokio::test] - async fn start_send_without_pull_returns_reset() { - let (mut writer, _peer_ctrl, _peer_data) = setup_writer().await; - - let err = Pin::new(&mut writer) - .start_send(Bytes::from_static(b"no-pull")) - .unwrap_err(); - assert!(matches!(err, StreamError::Reset { .. })); - } - - #[tokio::test] - async fn start_send_twice_without_flush_returns_reset() { - let (mut writer, mut peer_ctrl, _peer_data) = setup_writer().await; - - peer_ctrl.send(Frame::Pull).await.unwrap(); - poll_fn(|cx| Pin::new(&mut writer).poll_ready(cx)) - .await - .unwrap(); - - Pin::new(&mut writer) - .start_send(Bytes::from_static(b"first")) - .unwrap(); - let err = Pin::new(&mut writer) - .start_send(Bytes::from_static(b"second")) - .unwrap_err(); - assert!(matches!(err, StreamError::Reset { .. })); - } + use crate::{ + quic, + rpc::stream::{ + frame::{WriteCommand, WriteEvent}, + test_io::TestLifecycle, + }, + }; #[tokio::test] - async fn poll_close_flushes_pending_push_and_shuts_down() { - let (mut writer, mut peer_ctrl, mut peer_data) = setup_writer().await; - - peer_ctrl.send(Frame::Pull).await.unwrap(); - poll_fn(|cx| Pin::new(&mut writer).poll_ready(cx)) - .await - .unwrap(); - Pin::new(&mut writer) - .start_send(Bytes::from_static(b"close-payload")) - .unwrap(); - - poll_fn(|cx| Pin::new(&mut writer).poll_close(cx)) - .await - .unwrap(); + async fn writer_constructs_bridge_over_ipc_codecs() { + let stream_id = VarInt::from_u32(31); + let (worker_socket, peer_socket) = UnixStream::pair().unwrap(); + let lifecycle = Arc::new(TestLifecycle::new()); + let mut writer = writer(stream_id, worker_socket, lifecycle); + let (peer_read, peer_write) = peer_socket.into_split(); + let mut peer_read = FramedRead::new(peer_read, WriteCommandCodec::new()); + let mut peer_write = FramedWrite::new(peer_write, WriteEventCodec::new()); + let data = Bytes::from_static(b"ipc write"); + let sent = data.clone(); + + let task = tokio::spawn(async move { + writer.send(sent).await?; + Ok::<_, quic::StreamError>(()) + }); + peer_write.send(WriteEvent::Pull).await.unwrap(); assert_eq!( - peer_data.next().await.unwrap().unwrap(), - Frame::Push(Bytes::from_static(b"close-payload")) + peer_read.next().await.unwrap().unwrap(), + WriteCommand::Push { data } ); - let eof = timeout(Duration::from_millis(200), peer_data.next()) - .await - .expect("expected EOF after poll_close"); - assert!(eof.is_none()); - } - - #[tokio::test] - async fn poll_cancel_sends_cancel_and_discards_pending_push() { - let (mut writer, mut peer_ctrl, mut peer_data) = setup_writer().await; - let code = VarInt::from_u32(77); - - peer_ctrl.send(Frame::Pull).await.unwrap(); - poll_fn(|cx| Pin::new(&mut writer).poll_ready(cx)) - .await - .unwrap(); - Pin::new(&mut writer) - .start_send(Bytes::from_static(b"discard-me")) - .unwrap(); - - poll_fn(|cx| Pin::new(&mut writer).poll_cancel(cx, code)) - .await - .unwrap(); - assert_eq!( - peer_data.next().await.unwrap().unwrap(), - Frame::Cancel(code) + peer_read.next().await.unwrap().unwrap(), + WriteCommand::Flush ); - let eof = timeout(Duration::from_millis(200), peer_data.next()) - .await - .expect("expected EOF after poll_cancel"); - assert!(eof.is_none()); - } - - #[tokio::test] - async fn stop_frame_turns_writer_into_reset() { - let (mut writer, mut peer_ctrl, _peer_data) = setup_writer().await; - - peer_ctrl - .send(Frame::Stop(VarInt::from_u32(7))) - .await - .unwrap(); - - let err = poll_fn(|cx| Pin::new(&mut writer).poll_ready(cx)) - .await - .unwrap_err(); - match err { - StreamError::Reset { code } => assert_eq!(code, VarInt::from_u32(7)), - other => panic!("expected reset error, got {other:?}"), - } - } - - #[tokio::test] - async fn conn_closed_frame_with_dead_lifecycle_returns_connection_error() { - let (mut writer, mut peer_ctrl, _peer_data) = - setup_writer_with_lifecycle(dead_lifecycle("writer conn closed")).await; - - peer_ctrl.send(Frame::ConnClosed).await.unwrap(); - let err = poll_fn(|cx| Pin::new(&mut writer).poll_ready(cx)) - .await - .unwrap_err(); - assert!(matches!(err, StreamError::Connection { .. })); - } - - #[tokio::test] - async fn read_eof_with_dead_lifecycle_sends_conn_closed_when_no_pending() { - let (mut writer, peer_ctrl, mut peer_data) = - setup_writer_with_lifecycle(dead_lifecycle("writer eof dead")).await; - - drop(peer_ctrl); - let err = poll_fn(|cx| Pin::new(&mut writer).poll_ready(cx)) - .await - .unwrap_err(); - assert!(matches!(err, StreamError::Connection { .. })); + peer_write.send(WriteEvent::FlushAck).await.unwrap(); - let frame = timeout(Duration::from_millis(200), peer_data.next()) - .await - .expect("expected CONN_CLOSED frame") - .expect("expected a frame") - .expect("expected successful decode"); - assert_eq!(frame, Frame::ConnClosed); + task.await.unwrap().unwrap(); } #[tokio::test] - async fn conn_closed_with_pending_push_returns_connection_error_without_conn_closed_echo() { - let (mut writer, mut peer_ctrl, mut peer_data) = - setup_writer_with_lifecycle(dead_lifecycle("writer eof with pending")).await; - - peer_ctrl.send(Frame::Pull).await.unwrap(); - poll_fn(|cx| Pin::new(&mut writer).poll_ready(cx)) - .await - .unwrap(); - Pin::new(&mut writer) - .start_send(Bytes::from_static(b"in-flight")) - .unwrap(); - - peer_ctrl.send(Frame::ConnClosed).await.unwrap(); - let flush_result = poll_fn(|cx| Pin::new(&mut writer).poll_flush(cx)).await; - if let Err(err) = flush_result { - assert!(matches!(err, StreamError::Connection { .. })); - } - - if let Ok(Some(Ok(frame))) = timeout(Duration::from_millis(50), peer_data.next()).await { - assert_ne!(frame, Frame::ConnClosed); - } - } - - #[tokio::test] - async fn unexpected_cancel_frame_is_ignored_while_waiting_for_pull() { - let (mut writer, mut peer_ctrl, _peer_data) = setup_writer().await; - - peer_ctrl - .send(Frame::Cancel(VarInt::from_u32(1))) - .await - .unwrap(); - assert!(poll_ready_once(&mut writer).is_pending()); - } - - #[tokio::test] - async fn cancel_is_noop_after_writer_is_dead() { - let (mut writer, mut peer_ctrl, _peer_data) = setup_writer().await; - - peer_ctrl - .send(Frame::Stop(VarInt::from_u32(123))) - .await - .unwrap(); - let _ = poll_fn(|cx| Pin::new(&mut writer).poll_ready(cx)) - .await - .unwrap_err(); - - poll_fn(|cx| Pin::new(&mut writer).poll_cancel(cx, VarInt::from_u32(9))) - .await - .unwrap(); + async fn writer_bridge_io_eof_latches_connection_error() { + let stream_id = VarInt::from_u32(32); + let (worker_socket, peer_socket) = UnixStream::pair().unwrap(); + let lifecycle = Arc::new(TestLifecycle::new()); + let mut writer = writer(stream_id, worker_socket, lifecycle.clone()); + drop(peer_socket); + + let error = writer.send(Bytes::from_static(b"lost")).await.unwrap_err(); + assert!(matches!(error, quic::StreamError::Connection { .. })); + assert!(quic::Lifecycle::check(lifecycle.as_ref()).is_err()); } } diff --git a/src/ipc/quic/tests.rs b/src/ipc/quic/test_utils.rs similarity index 65% rename from src/ipc/quic/tests.rs rename to src/ipc/quic/test_utils.rs index 68d0bf0..d230eb8 100644 --- a/src/ipc/quic/tests.rs +++ b/src/ipc/quic/test_utils.rs @@ -16,7 +16,7 @@ use std::{ }; use bytes::Bytes; -use futures::{Sink, Stream, StreamExt, future::pending}; +use futures::{Sink, SinkExt, Stream, StreamExt, future::pending}; use remoc::prelude::{ServerShared, ServerSharedMut}; use tokio::sync::{Mutex, mpsc}; @@ -55,6 +55,32 @@ fn test_connection_error(reason: &str) -> ConnectionError { } } +fn assert_connection_reason(error: &ConnectionError, expected: &str) { + match error { + ConnectionError::Transport { source } => { + assert!(source.reason.contains(expected), "{source:?}"); + } + error => panic!("expected transport error containing {expected:?}, got {error:?}"), + } +} + +fn assert_stream_reason(error: &StreamError, expected: &str) { + match error { + StreamError::Connection { source } => assert_connection_reason(source, expected), + error => panic!("expected connection stream error containing {expected:?}, got {error:?}"), + } +} + +fn expect_connection_error( + result: Result, + context: &str, +) -> ConnectionError { + match result { + Ok(_) => panic!("{context}"), + Err(error) => error, + } +} + // --------------------------------------------------------------------------- // TestLifecycle: simple lifecycle that can be alive or have a terminal error // --------------------------------------------------------------------------- @@ -150,8 +176,8 @@ impl quic::GetStreamId for ChannelWriter { } } -impl quic::CancelStream for ChannelWriter { - fn poll_cancel( +impl quic::ResetStream for ChannelWriter { + fn poll_reset( self: Pin<&mut Self>, _cx: &mut Context, _code: VarInt, @@ -197,7 +223,7 @@ impl Sink for ChannelWriter { } } -// WriteStream blanket impl exists for S: CancelStream + GetStreamId + Sink + Send + Any +// WriteStream blanket impl exists for S: ResetStream + GetStreamId + Sink + Send + Any // --------------------------------------------------------------------------- // StreamableConnection: mock connection that supports actual data transfer @@ -340,11 +366,11 @@ impl quic::Lifecycle for StreamableConnection { } } -// Dummy agent types for WithLocalAgent / WithRemoteAgent +// Dummy authority types for WithLocalAuthority / WithRemoteAuthority #[derive(Debug)] -struct NoAgent; +struct NoAuthority; -impl crate::quic::agent::LocalAgent for NoAgent { +impl dhttp_identity::identity::LocalAuthority for NoAuthority { fn name(&self) -> &str { "none" } @@ -352,21 +378,15 @@ impl crate::quic::agent::LocalAgent for NoAgent { fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { &[] } - - fn sign_algorithm(&self) -> rustls::SignatureAlgorithm { - rustls::SignatureAlgorithm::ED25519 - } - fn sign( &self, - _scheme: rustls::SignatureScheme, _data: &[u8], - ) -> futures::future::BoxFuture<'_, Result, crate::quic::agent::SignError>> { + ) -> futures::future::BoxFuture<'_, Result, dhttp_identity::identity::SignError>> { Box::pin(async { Ok(Vec::new()) }) } } -impl crate::quic::agent::RemoteAgent for NoAgent { +impl dhttp_identity::identity::RemoteAuthority for NoAuthority { fn name(&self) -> &str { "none" } @@ -376,16 +396,16 @@ impl crate::quic::agent::RemoteAgent for NoAgent { } } -impl quic::WithLocalAgent for StreamableConnection { - type LocalAgent = NoAgent; - async fn local_agent(&self) -> Result, ConnectionError> { +impl quic::WithLocalAuthority for StreamableConnection { + type LocalAuthority = NoAuthority; + async fn local_authority(&self) -> Result, ConnectionError> { Ok(None) } } -impl quic::WithRemoteAgent for StreamableConnection { - type RemoteAgent = NoAgent; - async fn remote_agent(&self) -> Result, ConnectionError> { +impl quic::WithRemoteAuthority for StreamableConnection { + type RemoteAuthority = NoAuthority; + async fn remote_authority(&self) -> Result, ConnectionError> { Ok(None) } } @@ -457,10 +477,10 @@ async fn setup_listen_pair() -> ( let (server_mux, client_mux) = MuxChannel::pair_for_test().unwrap(); let (server_sink, server_stream) = server_mux.split().unwrap(); - let fd_sender = server_sink.fd_sender(); + let server_fd_transfer = server_stream.fd_transfer(server_sink.fd_sender()); let (client_sink, client_stream) = client_mux.split().unwrap(); - let client_fd_registry = client_stream.fd_registry(); + let client_fd_transfer = client_stream.fd_transfer(client_sink.fd_sender()); // Both sides must handshake concurrently let server_task = tokio::spawn(async move { @@ -474,7 +494,8 @@ async fn setup_listen_pair() -> ( .unwrap(); tokio::spawn(remoc_conn); - let adapter = ListenAdapter::<_, remoc::codec::Default>::new(mock_listen, fd_sender); + let adapter = + ListenAdapter::<_, remoc::codec::Default>::new(mock_listen, server_fd_transfer); let (server, listen_client) = IpcListenServerSharedMut::new(Arc::new(tokio::sync::RwLock::new(adapter)), 64); tokio::spawn(async move { @@ -495,7 +516,7 @@ async fn setup_listen_pair() -> ( tokio::spawn(remoc_conn); let listen_client = rx.recv().await.unwrap().unwrap(); - IpcListener::new(listen_client, client_fd_registry) + IpcListener::new(listen_client, client_fd_transfer) }); let (server_result, client_result) = tokio::join!(server_task, client_task); @@ -504,6 +525,241 @@ async fn setup_listen_pair() -> ( (listener, conn_tx) } +#[tokio::test] +async fn direct_lifecycle_and_stream_helpers_cover_control_paths() { + use quic::{GetStreamIdExt, ResetStreamExt, StopStreamExt}; + + let lifecycle = TestLifecycle::new(); + assert!(quic::Lifecycle::check(&lifecycle).is_ok()); + assert!( + tokio::time::timeout( + std::time::Duration::from_millis(10), + quic::Lifecycle::closed(&lifecycle), + ) + .await + .is_err(), + "open lifecycle should keep closed() pending" + ); + let terminal = test_connection_error("direct lifecycle terminal"); + lifecycle.set_terminal_error(terminal); + assert_connection_reason( + &quic::Lifecycle::check(&lifecycle).expect_err("terminal lifecycle should fail check"), + "direct lifecycle terminal", + ); + assert_connection_reason( + &quic::Lifecycle::closed(&lifecycle).await, + "direct lifecycle terminal", + ); + + let (reader_tx, reader_rx) = mpsc::channel(1); + let mut reader = ChannelReader { + stream_id: VarInt::from_u32(401), + rx: reader_rx, + }; + assert_eq!( + reader.stream_id().await.expect("reader stream id"), + VarInt::from_u32(401) + ); + reader_tx + .send(Bytes::from_static(b"direct reader")) + .await + .expect("send reader bytes"); + assert_eq!( + reader + .next() + .await + .expect("reader item") + .expect("reader ok"), + Bytes::from_static(b"direct reader") + ); + reader + .stop(VarInt::from_u32(402)) + .await + .expect("stop reader"); + assert!( + reader_tx + .send(Bytes::from_static(b"after stop")) + .await + .is_err(), + "stopped reader should close its receiver" + ); + + let (writer_tx, mut writer_rx) = mpsc::channel(1); + let mut writer = ChannelWriter { + stream_id: VarInt::from_u32(403), + tx: Some(writer_tx), + }; + assert_eq!( + writer.stream_id().await.expect("writer stream id"), + VarInt::from_u32(403) + ); + writer + .send(Bytes::from_static(b"direct writer")) + .await + .expect("send writer bytes"); + assert_eq!( + writer_rx.recv().await.expect("writer bytes"), + Bytes::from_static(b"direct writer") + ); + writer.close().await.expect("close writer"); + assert!( + writer_rx.recv().await.is_none(), + "closed writer should drop its sender" + ); + let error = writer + .send(Bytes::from_static(b"after close")) + .await + .expect_err("closed writer should reject writes"); + assert_stream_reason(&error, "writer closed"); + + let (writer_tx, mut writer_rx) = mpsc::channel(1); + let mut writer = ChannelWriter { + stream_id: VarInt::from_u32(404), + tx: Some(writer_tx), + }; + writer + .reset(VarInt::from_u32(405)) + .await + .expect("reset writer"); + assert!( + writer_rx.recv().await.is_none(), + "canceled writer should drop its sender" + ); + + let (writer_tx, _writer_rx) = mpsc::channel(1); + let mut writer = ChannelWriter { + stream_id: VarInt::from_u32(406), + tx: Some(writer_tx), + }; + Pin::new(&mut writer) + .start_send(Bytes::from_static(b"first")) + .expect("first try_send should fit"); + let error = Pin::new(&mut writer) + .start_send(Bytes::from_static(b"second")) + .expect_err("full channel should fail start_send"); + assert_stream_reason(&error, "send failed"); +} + +#[tokio::test] +async fn direct_connection_agent_listener_and_connector_helpers_cover_errors() { + let (conn, lifecycle) = StreamableConnection::new(); + assert_connection_reason( + &expect_connection_error( + quic::ManageStream::open_bi(conn.as_ref()).await, + "empty bidi queue should fail", + ), + "no bidi streams available", + ); + assert_connection_reason( + &expect_connection_error( + quic::ManageStream::open_uni(conn.as_ref()).await, + "empty uni writer queue should fail", + ), + "no uni write streams available", + ); + assert_connection_reason( + &expect_connection_error( + quic::ManageStream::accept_bi(conn.as_ref()).await, + "accept_bi helper is intentionally unavailable", + ), + "accept_bi not implemented", + ); + assert_connection_reason( + &expect_connection_error( + quic::ManageStream::accept_uni(conn.as_ref()).await, + "empty uni reader queue should fail", + ), + "no uni read streams available", + ); + + quic::Lifecycle::close(conn.as_ref(), Code::H3_NO_ERROR, "direct close".into()); + quic::Lifecycle::check(conn.as_ref()).expect("connection should start live"); + lifecycle.set_terminal_error(test_connection_error("direct connection terminal")); + assert_connection_reason( + &quic::Lifecycle::check(conn.as_ref()).expect_err("terminal connection should fail check"), + "direct connection terminal", + ); + assert_connection_reason( + &quic::Lifecycle::closed(conn.as_ref()).await, + "direct connection terminal", + ); + + assert!( + quic::WithLocalAuthority::local_authority(conn.as_ref()) + .await + .expect("local authority helper should succeed") + .is_none() + ); + assert!( + quic::WithRemoteAuthority::remote_authority(conn.as_ref()) + .await + .expect("remote authority helper should succeed") + .is_none() + ); + + let authority = NoAuthority; + assert_eq!( + dhttp_identity::identity::LocalAuthority::name(&authority), + "none" + ); + assert_eq!( + dhttp_identity::identity::RemoteAuthority::name(&authority), + "none" + ); + assert!(dhttp_identity::identity::LocalAuthority::cert_chain(&authority).is_empty()); + assert!(dhttp_identity::identity::RemoteAuthority::cert_chain(&authority).is_empty()); + assert!( + dhttp_identity::identity::LocalAuthority::sign(&authority, b"payload",) + .await + .expect("no-authority signing should succeed") + .is_empty() + ); + + let (tx, rx) = mpsc::channel(1); + let mut listener = MockListen { rx }; + quic::Listen::shutdown(&listener) + .await + .expect("mock listener shutdown should succeed"); + drop(tx); + assert_connection_reason( + &expect_connection_error( + quic::Listen::accept(&mut listener).await, + "closed listener should fail accept", + ), + "listener closed", + ); + + let authority = "direct.example:443" + .parse::() + .expect("authority parses"); + let connector = MockConnect { + conn: Mutex::new(None), + }; + assert_connection_reason( + &expect_connection_error( + quic::Connect::connect(&connector, &authority).await, + "unstaged connector should fail", + ), + "no connection staged", + ); + + let (conn, _lifecycle) = StreamableConnection::new(); + let connector = MockConnect { + conn: Mutex::new(Some(conn.clone())), + }; + let connected = quic::Connect::connect(&connector, &authority) + .await + .expect("staged connector should return connection"); + assert!(Arc::ptr_eq(&conn, &connected)); + assert_connection_reason( + &expect_connection_error( + quic::Connect::connect(&connector, &authority).await, + "staged connection should be consumed", + ), + "no connection staged", + ); +} + #[tokio::test] async fn listen_accept_bootstrap() { let (mut listener, conn_tx) = setup_listen_pair().await; @@ -514,12 +770,16 @@ async fn listen_accept_bootstrap() { // Client accepts — should get IpcConnectionHandle let handle = quic::Listen::accept(&mut listener).await.unwrap(); - assert!(quic::Lifecycle::check(&handle).is_ok()); + assert!(quic::Lifecycle::check(handle.as_ref()).is_ok()); - // Agent should be None since mock returns None - let local = quic::WithLocalAgent::local_agent(&handle).await.unwrap(); + // Authority should be None since mock returns None + let local = quic::WithLocalAuthority::local_authority(handle.as_ref()) + .await + .unwrap(); assert!(local.is_none()); - let remote = quic::WithRemoteAgent::remote_agent(&handle).await.unwrap(); + let remote = quic::WithRemoteAuthority::remote_authority(handle.as_ref()) + .await + .unwrap(); assert!(remote.is_none()); } @@ -535,10 +795,10 @@ async fn connect_roundtrip() { let (server_mux, client_mux) = MuxChannel::pair_for_test().unwrap(); let (server_sink, server_stream) = server_mux.split().unwrap(); - let fd_sender = server_sink.fd_sender(); + let server_fd_transfer = server_stream.fd_transfer(server_sink.fd_sender()); let (client_sink, client_stream) = client_mux.split().unwrap(); - let client_fd_registry = client_stream.fd_registry(); + let client_fd_transfer = client_stream.fd_transfer(client_sink.fd_sender()); // Both sides must handshake concurrently let server_task = tokio::spawn(async move { @@ -552,7 +812,8 @@ async fn connect_roundtrip() { .unwrap(); tokio::spawn(remoc_conn); - let adapter = ConnectAdapter::<_, remoc::codec::Default>::new(mock_connect, fd_sender); + let adapter = + ConnectAdapter::<_, remoc::codec::Default>::new(mock_connect, server_fd_transfer); let (server, connect_client) = IpcConnectServerShared::new(Arc::new(adapter), 64); tokio::spawn(async move { let _ = server.serve(true).await; @@ -572,7 +833,7 @@ async fn connect_roundtrip() { tokio::spawn(remoc_conn); let connect_client = rx.recv().await.unwrap().unwrap(); - IpcConnector::::new(connect_client, client_fd_registry) + IpcConnector::::new(connect_client, client_fd_transfer) }); let (server_result, client_result) = tokio::join!(server_task, client_task); @@ -583,7 +844,7 @@ async fn connect_roundtrip() { let handle = quic::Connect::connect(&connector, &authority) .await .unwrap(); - assert!(quic::Lifecycle::check(&handle).is_ok()); + assert!(quic::Lifecycle::check(handle.as_ref()).is_ok()); } #[tokio::test] @@ -594,17 +855,17 @@ async fn lifecycle_propagation() { conn_tx.send(conn).await.unwrap(); let handle = quic::Listen::accept(&mut listener).await.unwrap(); - assert!(quic::Lifecycle::check(&handle).is_ok()); + assert!(quic::Lifecycle::check(handle.as_ref()).is_ok()); // close() should not panic - quic::Lifecycle::close(&handle, Code::H3_NO_ERROR, "test shutdown".into()); + quic::Lifecycle::close(handle.as_ref(), Code::H3_NO_ERROR, "test shutdown".into()); // Inject terminal error on the server side let err = test_connection_error("connection reset by peer"); lc.set_terminal_error(err); // closed() should return the error - let terminal = quic::Lifecycle::closed(&handle).await; + let terminal = quic::Lifecycle::closed(handle.as_ref()).await; match terminal { ConnectionError::Transport { source } => { assert!(source.reason.contains("connection reset by peer")); @@ -613,7 +874,7 @@ async fn lifecycle_propagation() { } // check() should now return Err - assert!(quic::Lifecycle::check(&handle).is_err()); + assert!(quic::Lifecycle::check(handle.as_ref()).is_err()); } #[tokio::test] @@ -627,7 +888,7 @@ async fn open_bi_data_transfer() { let handle = quic::Listen::accept(&mut listener).await.unwrap(); // Open a bidi stream through the IPC chain - let (mut reader, mut writer) = quic::ManageStream::open_bi(&handle).await.unwrap(); + let (mut reader, mut writer) = quic::ManageStream::open_bi(handle.as_ref()).await.unwrap(); // Server → Client: inject data into the mock QUIC reader → bridge → pipe → PipeReader test_handles @@ -661,7 +922,7 @@ async fn open_bi_unavailable() { let handle = quic::Listen::accept(&mut listener).await.unwrap(); // open_bi should fail because no streams are in the queue - let result = quic::ManageStream::open_bi(&handle).await; + let result = quic::ManageStream::open_bi(handle.as_ref()).await; assert!(result.is_err()); } @@ -680,7 +941,7 @@ async fn open_uni_data_transfer() { let handle = quic::Listen::accept(&mut listener).await.unwrap(); // Open a uni stream — the client gets a writer - let mut writer = quic::ManageStream::open_uni(&handle).await.unwrap(); + let mut writer = quic::ManageStream::open_uni(handle.as_ref()).await.unwrap(); // Client → Server: PipeWriter → pipe → bridge → mock QUIC writer → test rx use futures::SinkExt; @@ -701,7 +962,9 @@ async fn accept_uni_data_transfer() { let handle = quic::Listen::accept(&mut listener).await.unwrap(); // Accept a uni stream — the client gets a reader - let mut reader = quic::ManageStream::accept_uni(&handle).await.unwrap(); + let mut reader = quic::ManageStream::accept_uni(handle.as_ref()) + .await + .unwrap(); // Server → Client: inject data into mock QUIC reader → bridge → pipe → client test_handles @@ -724,7 +987,7 @@ async fn open_uni_unavailable() { let handle = quic::Listen::accept(&mut listener).await.unwrap(); - let result = quic::ManageStream::open_uni(&handle).await; + let result = quic::ManageStream::open_uni(handle.as_ref()).await; assert!(result.is_err()); } @@ -737,6 +1000,6 @@ async fn accept_uni_unavailable() { let handle = quic::Listen::accept(&mut listener).await.unwrap(); - let result = quic::ManageStream::accept_uni(&handle).await; + let result = quic::ManageStream::accept_uni(handle.as_ref()).await; assert!(result.is_err()); } diff --git a/src/ipc/transport.rs b/src/ipc/transport.rs index 9623795..0562255 100644 --- a/src/ipc/transport.rs +++ b/src/ipc/transport.rs @@ -1,28 +1,44 @@ +//! Unix-domain IPC transport for remoc bytes plus file-descriptor transfer. +//! +//! The public FD API is intentionally receiver-chosen: +//! +//! - the receiving side calls [`FdTransfer::receive`] to reserve an id; +//! - the id travels in the RPC request; +//! - the sending side calls [`FdTransfer::delivery`] and consumes the returned +//! [`FdDelivery`] with [`FdDelivery::deliver`]; +//! - [`FdDelivery::deliver`] returns after the FD frame is queued to the local +//! mux writer FIFO. +//! +//! `MuxChannel` runs independent reader and writer tasks. FD delivery is not +//! receiver-acknowledged; remoc cancellation is the sender-visible cancellation +//! mechanism for RPC operations that carry receiver-chosen FD ids. + use std::{ - collections::{HashMap, VecDeque}, io, os::{ - fd::{AsFd, AsRawFd, FromRawFd, OwnedFd, RawFd}, + fd::{AsFd, OwnedFd}, unix::net::UnixStream as StdUnixStream, }, pin::Pin, - sync::{ - Arc, Mutex, Weak, - atomic::{AtomicU64, Ordering}, - }, + sync::Arc, task::{Context, Poll}, }; -use bytes::{Buf, Bytes, BytesMut}; +use bytes::Bytes; use futures::{Sink, Stream, ready}; -use nix::sys::socket::{ - ControlMessage, ControlMessageOwned, MsgFlags, Shutdown, recvmsg, sendmsg, shutdown, -}; use smallvec::SmallVec; use snafu::ResultExt; -use tokio::{io::unix::AsyncFd, sync::oneshot}; +use tokio::sync::mpsc; +use tokio_util::task::AbortOnDropHandle; + +use crate::varint::VarInt; -use crate::varint::{VARINT_MAX, VarInt}; +mod driver; +mod fd_plane; +mod frame; + +pub use driver::FdSender; +pub use fd_plane::{FdDelivered, FdDelivery, FdReceiver, FdTransfer, ReceivedFds}; /// Alias for a small-vec optimised FD collection. /// @@ -30,22 +46,8 @@ use crate::varint::{VARINT_MAX, VarInt}; /// without heap allocation. pub type FdVec = SmallVec<[OwnedFd; 4]>; -const FRAME_TYPE_BYTES: u8 = 0x00; -const FRAME_TYPE_FDS: u8 = 0x01; - -const MAX_FRAME_HEADER_LEN: usize = 1 + VarInt::MAX_SIZE; -const MAX_FDS_FRAME_LEN: usize = 1 + VarInt::MAX_SIZE + VarInt::MAX_SIZE + VarInt::MAX_SIZE; -const READ_CHUNK_LEN: usize = 8 * 1024; - /// Maximum number of FDs allowed in a single FD frame. -/// -/// Enforced on both the send side ([`FdSender::queue_fds`]) and the receive -/// side ([`try_decode_frame`]) to bound the cmsg buffer required by -/// [`recv_frame_data`]. -const MAX_FDS_PER_FRAME: usize = 4; - -/// Minimum wire bytes for an FD frame: type(1) + payload_len(1) + id(1) + fd_count(1). -const MIN_FDS_FRAME_LEN: usize = 4; +pub(crate) const MAX_FDS_PER_FRAME: usize = 4; #[derive(Debug, snafu::Snafu)] #[snafu(module)] @@ -56,8 +58,6 @@ pub enum QueueFdsError { EmptyFds, #[snafu(display("too many fds ({count}), max is {MAX_FDS_PER_FRAME}"))] TooManyFds { count: usize }, - #[snafu(display("fd id space is exhausted"))] - IdExhausted, } #[derive(Debug, snafu::Snafu)] @@ -65,30 +65,31 @@ pub enum QueueFdsError { pub enum WaitFdsError { #[snafu(display("fd registry is closed"))] Closed, - #[snafu(display("waiter already exists for stream id {id}"))] + #[snafu(display("waiter already exists for fd id {id}"))] AlreadyWaiting { id: VarInt }, #[snafu(display("fd waiter channel is closed unexpectedly"))] ChannelClosed, + #[snafu(display("fd id space is exhausted"))] + IdExhausted, } #[derive(Debug, snafu::Snafu)] #[snafu(module)] -pub enum RegisterFdsError { - #[snafu(display("fd registry is closed"))] - Closed, - #[snafu(display("duplicate stream id {id} received"))] - DuplicateId { id: VarInt }, +pub enum TakeFdsError { + #[snafu(display("expected {expected} fds, got {actual}"))] + Count { expected: usize, actual: usize }, +} + +#[derive(Debug, snafu::Snafu)] +#[snafu(module)] +pub enum DeliverFdsError { + #[snafu(display("failed to queue fd delivery"))] + Queue { source: QueueFdsError }, } #[derive(Debug, snafu::Snafu)] #[snafu(module)] pub enum MuxSinkError { - #[snafu(display("failed to poll writable readiness"))] - PollReady { source: io::Error }, - #[snafu(display("failed to send mux frame"))] - Send { source: io::Error }, - #[snafu(display("sink is not ready for start_send"))] - NotReady, #[snafu(display("write side is closed"))] Closed, } @@ -100,6 +101,8 @@ pub enum MuxStreamError { PollReady { source: io::Error }, #[snafu(display("failed to receive mux frame"))] Recv { source: io::Error }, + #[snafu(display("ancillary data was truncated"))] + AncillaryTruncated, #[snafu(display("unknown frame type: 0x{frame_type:02x}"))] UnknownFrameType { frame_type: u8 }, #[snafu(display("invalid frame length"))] @@ -108,8 +111,6 @@ pub enum MuxStreamError { InvalidFdsPayload, #[snafu(display("fds frame received without ancillary fds"))] MissingAncillaryFds, - #[snafu(display("failed to register incoming fds"))] - RegisterFds { source: RegisterFdsError }, } #[derive(Debug, snafu::Snafu)] @@ -142,33 +143,21 @@ impl MuxChannel { .try_clone_to_owned() .context(split_error::CloneFdSnafu)?; - let writer_core = Arc::new(WriterCore { - next_id: AtomicU64::new(0), - fd_queue: Mutex::new(VecDeque::new()), - }); - let registry_core = Arc::new(RegistryCore::new()); - let registry = FdRegistry::from_core(®istry_core); + let plane = Arc::new(fd_plane::FdPlaneCore::new()); + let (writer_core, fd_sender, writer_task) = + driver::start_writer(write_fd).context(split_error::AsyncFdSnafu)?; + let (bytes_rx, reader_task) = + driver::start_reader(read_fd, plane.clone()).context(split_error::AsyncFdSnafu)?; let sink = MuxSink { - fd: AsyncFd::new(write_fd).context(split_error::AsyncFdSnafu)?, - core: Some(writer_core.clone()), - fd_sender: FdSender::from_core(&writer_core), - pending_bytes: None, - current: None, + core: Some(writer_core), + fd_sender, + writer_task: Some(writer_task), }; let stream = MuxStream { - fd: AsyncFd::new(read_fd).context(split_error::AsyncFdSnafu)?, - registry_core, - registry: registry.clone(), - read_buf: BytesMut::new(), - cmsg_buf: { - let per_msg = nix::sys::socket::cmsg_space::<[RawFd; MAX_FDS_PER_FRAME]>(); - let max_msgs = READ_CHUNK_LEN / MIN_FDS_FRAME_LEN; - vec![0u8; max_msgs * per_msg] - }, - pending_fds: VecDeque::new(), - pending_fd_frame: None, - closed: false, + plane, + bytes_rx, + reader_task: Some(reader_task), }; Ok((sink, stream)) @@ -185,8 +174,9 @@ impl MuxChannel { /// Create a socketpair and return `(local_channel, remote_fd)`. /// /// The local side is a ready-to-use [`MuxChannel`]; the remote `OwnedFd` - /// is intended to be sent to another process via [`FdSender::queue_fds`] - /// and reconstructed with [`MuxChannel::from_fd`] on the receiving end. + /// is intended to be transferred to another process via + /// [`FdDelivery::deliver`] and reconstructed with [`MuxChannel::from_fd`] + /// on the receiving end. pub fn create_pair() -> io::Result<(Self, OwnedFd)> { let (a, b) = StdUnixStream::pair()?; let local = Self::from_fd(a.into())?; @@ -194,64 +184,11 @@ impl MuxChannel { } } -#[derive(Debug)] -struct QueuedFds { - id: VarInt, - fds: FdVec, -} - -#[derive(Debug)] -struct WriterCore { - next_id: AtomicU64, - fd_queue: Mutex>, -} - -#[derive(Clone, Debug)] -pub struct FdSender { - core: Weak, -} - -impl FdSender { - fn from_core(core: &Arc) -> Self { - Self { - core: Arc::downgrade(core), - } - } - - pub fn queue_fds(&self, fds: FdVec) -> Result { - if fds.is_empty() { - return Err(QueueFdsError::EmptyFds); - } - if fds.len() > MAX_FDS_PER_FRAME { - return Err(QueueFdsError::TooManyFds { count: fds.len() }); - } - - let core = self.core.upgrade().ok_or(QueueFdsError::Closed)?; - - let id_raw = core - .next_id - .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| { - let next = current.checked_add(1)?; - if next >= VARINT_MAX { None } else { Some(next) } - }) - .map_err(|_| QueueFdsError::IdExhausted)?; - - let id = VarInt::from_u64(id_raw).map_err(|_| QueueFdsError::IdExhausted)?; - - let mut queue = core.fd_queue.lock().expect("fd queue poisoned"); - queue.push_back(QueuedFds { id, fds }); - - Ok(id) - } -} - #[derive(Debug)] pub struct MuxSink { - fd: AsyncFd, - core: Option>, + core: Option>, fd_sender: FdSender, - pending_bytes: Option, - current: Option, + writer_task: Option>, } impl MuxSink { @@ -259,66 +196,35 @@ impl MuxSink { self.fd_sender.clone() } - fn close_core(&mut self) { - self.core = None; - self.pending_bytes = None; - self.current = None; - } - - fn core_ref(&self) -> Result<&Arc, MuxSinkError> { + fn core_ref(&self) -> Result<&Arc, MuxSinkError> { self.core.as_ref().ok_or(MuxSinkError::Closed) } - fn pop_next_frame(&mut self) -> Option { - let core = self.core.as_ref()?; - - if let Some(queued_fds) = core.fd_queue.lock().expect("fd queue poisoned").pop_front() { - return Some(PendingWriteFrame::new_fds(queued_fds.id, queued_fds.fds)); + fn signal_close(&self) { + if let Some(core) = &self.core { + core.close(); } + } - self.pending_bytes.take().map(PendingWriteFrame::new_bytes) + fn close_core(&mut self) { + self.signal_close(); + if let Some(core) = &self.core { + core.shutdown_write(); + } + self.core = None; + self.writer_task = None; } - fn poll_flush_internal(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.core.is_none() { + fn poll_flush_internal(&self, cx: &mut Context<'_>) -> Poll> { + let core = self.core_ref()?; + core.register_flush_waker(cx.waker()); + if core.is_closed() { return Poll::Ready(Err(MuxSinkError::Closed)); } - - loop { - if self.current.is_none() { - self.current = self.pop_next_frame(); - if self.current.is_none() { - return Poll::Ready(Ok(())); - } - } - - let mut guard = match ready!(self.fd.poll_write_ready(cx)) { - Ok(guard) => guard, - Err(source) => return Poll::Ready(Err(MuxSinkError::PollReady { source })), - }; - let frame = self.current.as_mut().expect("current frame must exist"); - - let io_result = guard.try_io(|inner| send_frame(inner.get_ref().as_raw_fd(), frame)); - let sent = match io_result { - Ok(Ok(sent)) => sent, - Ok(Err(source)) => { - self.close_core(); - return Poll::Ready(Err(MuxSinkError::Send { source })); - } - Err(_would_block) => return Poll::Pending, - }; - - if sent == 0 { - self.close_core(); - return Poll::Ready(Err(MuxSinkError::Send { - source: io::Error::new(io::ErrorKind::WriteZero, "sendmsg returned 0 bytes"), - })); - } - - frame.advance(sent); - if frame.is_complete() { - self.current = None; - } + if core.is_flushed() { + Poll::Ready(Ok(())) + } else { + Poll::Pending } } } @@ -326,367 +232,56 @@ impl MuxSink { impl Drop for MuxSink { fn drop(&mut self) { self.close_core(); - let _ = shutdown(self.fd.get_ref().as_raw_fd(), Shutdown::Write); } } impl Sink for MuxSink { type Error = MuxSinkError; - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.poll_flush_internal(cx) + fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + match self.core_ref() { + Ok(core) if !core.is_closed() => Poll::Ready(Ok(())), + _ => Poll::Ready(Err(MuxSinkError::Closed)), + } } - fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { - self.core_ref()?; - if self.pending_bytes.is_some() || self.current.is_some() { - return Err(MuxSinkError::NotReady); + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + match self.core_ref()?.queue(frame::OutboundFrame::Bytes(item)) { + Ok(()) => Ok(()), + Err(_) => Err(MuxSinkError::Closed), } - self.pending_bytes = Some(item); - Ok(()) } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.poll_flush_internal(cx) } - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let result = ready!(self.poll_flush_internal(cx)); if result.is_ok() { - self.close_core(); - let _ = shutdown(self.fd.get_ref().as_raw_fd(), Shutdown::Write); + self.signal_close(); } Poll::Ready(result) } } -#[derive(Debug)] -enum PendingWriteFrame { - Bytes { - header: [u8; MAX_FRAME_HEADER_LEN], - header_len: usize, - header_written: usize, - payload: Bytes, - payload_written: usize, - }, - Fds { - header: [u8; MAX_FDS_FRAME_LEN], - header_len: usize, - header_written: usize, - payload: FdVec, - include_ancillary: bool, - }, -} - -impl PendingWriteFrame { - fn new_bytes(payload: Bytes) -> Self { - let payload_len = - VarInt::try_from(payload.len()).expect("payload length must fit into varint"); - - let mut header = [0u8; MAX_FRAME_HEADER_LEN]; - header[0] = FRAME_TYPE_BYTES; - let varint_len = encode_varint_to_slice(&mut header[1..], payload_len); - - Self::Bytes { - header, - header_len: 1 + varint_len, - header_written: 0, - payload, - payload_written: 0, - } - } - - fn new_fds(id: VarInt, fds: FdVec) -> Self { - let fd_count = VarInt::try_from(fds.len()).expect("fd count must fit into varint"); - - let mut id_buf = [0u8; VarInt::MAX_SIZE]; - let id_len = encode_varint_to_slice(&mut id_buf, id); - - let mut fc_buf = [0u8; VarInt::MAX_SIZE]; - let fc_len = encode_varint_to_slice(&mut fc_buf, fd_count); - - let mut header = [0u8; MAX_FDS_FRAME_LEN]; - header[0] = FRAME_TYPE_FDS; - let body_len = VarInt::try_from(id_len + fc_len) - .expect("fd frame payload length must fit into varint"); - let len_len = encode_varint_to_slice(&mut header[1..], body_len); - let hdr_prefix = 1 + len_len; - header[hdr_prefix..hdr_prefix + id_len].copy_from_slice(&id_buf[..id_len]); - header[hdr_prefix + id_len..hdr_prefix + id_len + fc_len] - .copy_from_slice(&fc_buf[..fc_len]); - - Self::Fds { - header, - header_len: hdr_prefix + id_len + fc_len, - header_written: 0, - payload: fds, - include_ancillary: true, - } - } - - fn advance(&mut self, mut n: usize) { - match self { - Self::Bytes { - header_len, - header_written, - payload, - payload_written, - .. - } => { - let header_remaining = *header_len - *header_written; - let header_advance = n.min(header_remaining); - *header_written += header_advance; - n -= header_advance; - - if n > 0 { - let payload_remaining = payload.len().saturating_sub(*payload_written); - let payload_advance = n.min(payload_remaining); - *payload_written += payload_advance; - } - } - Self::Fds { - header_written, - include_ancillary, - .. - } => { - *header_written += n; - if n > 0 { - *include_ancillary = false; - } - } - } - } - - fn is_complete(&self) -> bool { - match self { - Self::Bytes { - header_len, - header_written, - payload, - payload_written, - .. - } => *header_written >= *header_len && *payload_written >= payload.len(), - Self::Fds { - header_len, - header_written, - .. - } => *header_written >= *header_len, - } - } -} - -fn send_frame(fd: RawFd, frame: &mut PendingWriteFrame) -> io::Result { - match frame { - PendingWriteFrame::Bytes { - header, - header_len, - header_written, - payload, - payload_written, - } => { - let mut iovecs = [io::IoSlice::new(&[]), io::IoSlice::new(&[])]; - let mut iov_count = 0usize; - - if *header_written < *header_len { - iovecs[iov_count] = io::IoSlice::new(&header[*header_written..*header_len]); - iov_count += 1; - } - if *payload_written < payload.len() { - iovecs[iov_count] = io::IoSlice::new(&payload[*payload_written..]); - iov_count += 1; - } - - if iov_count == 0 { - return Ok(0); - } - - let sent = sendmsg::<()>(fd, &iovecs[..iov_count], &[], MsgFlags::empty(), None) - .map_err(io::Error::from)?; - Ok(sent) - } - PendingWriteFrame::Fds { - header, - header_len, - header_written, - payload, - include_ancillary, - } => { - if *header_written >= *header_len { - return Ok(0); - } - - let iov = [io::IoSlice::new(&header[*header_written..*header_len])]; - let sent = if *include_ancillary { - let raw_fds: SmallVec<[RawFd; 4]> = - payload.iter().map(AsRawFd::as_raw_fd).collect(); - let cmsgs = [ControlMessage::ScmRights(&raw_fds)]; - sendmsg::<()>(fd, &iov, &cmsgs, MsgFlags::empty(), None) - } else { - sendmsg::<()>(fd, &iov, &[], MsgFlags::empty(), None) - } - .map_err(io::Error::from)?; - - Ok(sent) - } - } -} - -#[derive(Debug)] -struct RegistryCore { - state: Mutex, -} - -#[derive(Debug)] -struct RegistryState { - slots: HashMap, - closed: bool, -} - -#[derive(Debug)] -enum RegistrySlot { - Ready(FdVec), - Waiting(oneshot::Sender>), -} - -impl RegistryCore { - fn new() -> Self { - Self { - state: Mutex::new(RegistryState { - slots: HashMap::new(), - closed: false, - }), - } - } - - async fn wait_fds(&self, id: VarInt) -> Result { - let rx = { - let mut state = self.state.lock().expect("registry mutex poisoned"); - if state.closed { - return Err(WaitFdsError::Closed); - } - - if let Some(existing) = state.slots.remove(&id) { - match existing { - RegistrySlot::Ready(fds) => return Ok(fds), - RegistrySlot::Waiting(waiter) => { - state.slots.insert(id, RegistrySlot::Waiting(waiter)); - return Err(WaitFdsError::AlreadyWaiting { id }); - } - } - } - - let (tx, rx) = oneshot::channel(); - state.slots.insert(id, RegistrySlot::Waiting(tx)); - rx - }; - - match rx.await { - Ok(result) => result, - Err(_) => Err(WaitFdsError::ChannelClosed), - } - } - - fn register_arrival(&self, id: VarInt, fds: FdVec) -> Result<(), RegisterFdsError> { - let mut state = self.state.lock().expect("registry mutex poisoned"); - if state.closed { - return Err(RegisterFdsError::Closed); - } - - let existing = state.slots.remove(&id); - match existing { - None => { - state.slots.insert(id, RegistrySlot::Ready(fds)); - Ok(()) - } - Some(RegistrySlot::Waiting(waiter)) => { - if let Err(payload) = waiter.send(Ok(fds)) { - let recovered_fds = payload.unwrap_or_else(|_| FdVec::new()); - state.slots.insert(id, RegistrySlot::Ready(recovered_fds)); - } - Ok(()) - } - Some(RegistrySlot::Ready(existing_fds)) => { - state.slots.insert(id, RegistrySlot::Ready(existing_fds)); - Err(RegisterFdsError::DuplicateId { id }) - } - } - } - - fn close(&self) { - let mut state = self.state.lock().expect("registry mutex poisoned"); - if state.closed { - return; - } - state.closed = true; - - for (_, slot) in state.slots.drain() { - if let RegistrySlot::Waiting(waiter) = slot { - let _ = waiter.send(Err(WaitFdsError::Closed)); - } - } - } -} - -#[derive(Clone, Debug)] -pub struct FdRegistry { - core: Weak, -} - -impl FdRegistry { - fn from_core(core: &Arc) -> Self { - Self { - core: Arc::downgrade(core), - } - } - - #[cfg(test)] - fn new_for_test() -> (Arc, Self) { - let core = Arc::new(RegistryCore::new()); - let registry = Self::from_core(&core); - (core, registry) - } - - pub async fn wait_fds(&self, id: VarInt) -> Result { - let core = self.core.upgrade().ok_or(WaitFdsError::Closed)?; - core.wait_fds(id).await - } -} - #[derive(Debug)] pub struct MuxStream { - fd: AsyncFd, - registry_core: Arc, - registry: FdRegistry, - read_buf: BytesMut, - /// Heap-allocated cmsg buffer for `recvmsg()`, reused across calls. - /// - /// Sized for the worst case where every `MIN_FDS_FRAME_LEN` bytes in a - /// read chunk originates from a separate `sendmsg()` carrying - /// `MAX_FDS_PER_FRAME` FDs — on `SOCK_STREAM` the kernel may coalesce - /// all of them into a single `recvmsg()`. - cmsg_buf: Vec, - /// Flat queue of pending FDs received via SCM_RIGHTS. - /// - /// On SOCK_STREAM, the kernel may coalesce SCM_RIGHTS from multiple - /// sendmsg calls into a single control message. Each FD frame includes - /// an `fd_count` field so the receiver can take the correct number of - /// FDs from this flat queue. - pending_fds: VecDeque, - /// An FD frame that has been decoded from the byte stream but whose - /// ancillary FDs have not yet arrived. - pending_fd_frame: Option<(VarInt, usize)>, - closed: bool, + plane: Arc, + bytes_rx: mpsc::UnboundedReceiver>, + reader_task: Option>, } impl MuxStream { - pub fn fd_registry(&self) -> FdRegistry { - self.registry.clone() + pub fn fd_transfer(&self, sender: FdSender) -> FdTransfer { + FdTransfer::new(sender, self.plane.clone()) } } impl Drop for MuxStream { fn drop(&mut self) { - self.registry_core.close(); + self.reader_task = None; + self.plane.close(); } } @@ -694,255 +289,7 @@ impl Stream for MuxStream { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if self.closed { - return Poll::Ready(None); - } - - loop { - // 1. Complete any pending FD frame whose ancillary FDs have now arrived. - if let Some((id, fd_count)) = self.pending_fd_frame { - if self.pending_fds.len() >= fd_count { - self.pending_fd_frame = None; - let fds: FdVec = self.pending_fds.drain(..fd_count).collect(); - if let Err(source) = self.registry_core.register_arrival(id, fds) { - self.closed = true; - self.registry_core.close(); - return Poll::Ready(Some(Err(MuxStreamError::RegisterFds { source }))); - } - continue; - } - // Still waiting for FDs — fall through to read more data. - } else { - // 2. Decode next frame from the byte buffer. - match try_decode_frame(&mut self.read_buf) { - Ok(Some(DecodedFrame::Bytes(payload))) => { - return Poll::Ready(Some(Ok(payload))); - } - Ok(Some(DecodedFrame::Fds { id, fd_count })) => { - if self.pending_fds.len() >= fd_count { - let fds: FdVec = self.pending_fds.drain(..fd_count).collect(); - if let Err(source) = self.registry_core.register_arrival(id, fds) { - self.closed = true; - self.registry_core.close(); - return Poll::Ready(Some(Err(MuxStreamError::RegisterFds { - source, - }))); - } - continue; - } - // FD frame decoded but ancillary not yet received. - // Stash it and fall through to read more data. - self.pending_fd_frame = Some((id, fd_count)); - } - Ok(None) => { - // Incomplete frame — fall through to read more data. - } - Err(err) => { - self.closed = true; - self.registry_core.close(); - return Poll::Ready(Some(Err(err))); - } - } - } - - // 3. Read more data (and ancillary FDs) from the socket. - // Reborrow to enable split field borrows (guard borrows fd, recv - // borrows cmsg_buf). - let this = &mut *self; - let mut guard = match ready!(this.fd.poll_read_ready(cx)) { - Ok(guard) => guard, - Err(source) => return Poll::Ready(Some(Err(MuxStreamError::PollReady { source }))), - }; - - let mut read_buf = [0u8; READ_CHUNK_LEN]; - let io_result = guard.try_io(|inner| { - recv_frame_data( - inner.get_ref().as_raw_fd(), - &mut read_buf, - &mut this.cmsg_buf, - ) - }); - - let (read, ancillary) = match io_result { - Ok(Ok(result)) => result, - Ok(Err(source)) => { - this.closed = true; - this.registry_core.close(); - return Poll::Ready(Some(Err(MuxStreamError::Recv { source }))); - } - Err(_would_block) => return Poll::Pending, - }; - - if read == 0 { - this.closed = true; - this.registry_core.close(); - if this.pending_fd_frame.is_some() { - return Poll::Ready(Some(Err(MuxStreamError::MissingAncillaryFds))); - } - return Poll::Ready(None); - } - - if !ancillary.is_empty() { - this.pending_fds.extend(ancillary); - } - this.read_buf.extend_from_slice(&read_buf[..read]); - } - } -} - -#[cfg(not(target_os = "linux"))] -fn set_cloexec(fd: &OwnedFd) -> io::Result<()> { - use nix::fcntl::{F_GETFD, F_SETFD, FdFlag, fcntl}; - let raw = fd.as_raw_fd(); - let bits = fcntl(raw, F_GETFD).map_err(io::Error::from)?; - let new_flags = FdFlag::from_bits_truncate(bits) | FdFlag::FD_CLOEXEC; - fcntl(raw, F_SETFD(new_flags)).map_err(io::Error::from)?; - Ok(()) -} - -fn recv_frame_data( - fd: RawFd, - data_buf: &mut [u8], - cmsg_buf: &mut Vec, -) -> io::Result<(usize, FdVec)> { - let mut iov = [io::IoSliceMut::new(data_buf)]; - - // `MSG_CMSG_CLOEXEC` is Linux-only; on other Unix platforms we fall back to - // setting `FD_CLOEXEC` manually on each received fd below. - #[cfg(target_os = "linux")] - let recv_flags = MsgFlags::MSG_CMSG_CLOEXEC; - #[cfg(not(target_os = "linux"))] - let recv_flags = MsgFlags::empty(); - - let msg = recvmsg::<()>(fd, &mut iov, Some(cmsg_buf), recv_flags).map_err(io::Error::from)?; - - if msg.flags.contains(MsgFlags::MSG_CTRUNC) { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "cmsg buffer overflow: ancillary data truncated (MSG_CTRUNC)", - )); - } - - let mut fds = FdVec::new(); - let cmsgs = msg.cmsgs().map_err(io::Error::from)?; - for cmsg in cmsgs { - if let ControlMessageOwned::ScmRights(raw_fds) = cmsg { - for raw_fd in raw_fds { - // SAFETY: SCM_RIGHTS transfers ownership of a new fd to receiver. - let fd = unsafe { OwnedFd::from_raw_fd(raw_fd) }; - #[cfg(not(target_os = "linux"))] - set_cloexec(&fd)?; - fds.push(fd); - } - } - } - - Ok((msg.bytes, fds)) -} - -#[derive(Debug)] -enum DecodedFrame { - Bytes(Bytes), - Fds { id: VarInt, fd_count: usize }, -} - -fn try_decode_frame(src: &mut BytesMut) -> Result, MuxStreamError> { - if src.is_empty() { - return Ok(None); - } - - let frame_type = src[0]; - if frame_type != FRAME_TYPE_BYTES && frame_type != FRAME_TYPE_FDS { - return Err(MuxStreamError::UnknownFrameType { frame_type }); - } - - let Some((payload_len, varint_len)) = try_decode_varint_len(&src[1..])? else { - return Ok(None); - }; - - let frame_header_len = 1 + varint_len; - if src.len() < frame_header_len + payload_len { - return Ok(None); - } - - src.advance(frame_header_len); - let payload = src.split_to(payload_len).freeze(); - - match frame_type { - FRAME_TYPE_BYTES => Ok(Some(DecodedFrame::Bytes(payload))), - FRAME_TYPE_FDS => { - let (id, id_consumed) = decode_varint_from_slice(&payload)?; - let remaining = &payload[id_consumed..]; - let (fd_count_vi, fc_consumed) = decode_varint_from_slice(remaining)?; - if id_consumed + fc_consumed != payload.len() { - return Err(MuxStreamError::InvalidFdsPayload); - } - let fd_count = usize::try_from(u64::from(fd_count_vi)) - .map_err(|_| MuxStreamError::InvalidFdsPayload)?; - if fd_count == 0 || fd_count > MAX_FDS_PER_FRAME { - return Err(MuxStreamError::InvalidFdsPayload); - } - Ok(Some(DecodedFrame::Fds { id, fd_count })) - } - _ => Err(MuxStreamError::UnknownFrameType { frame_type }), - } -} - -fn try_decode_varint_len(src: &[u8]) -> Result, MuxStreamError> { - let Some((value, consumed)) = try_decode_varint(src)? else { - return Ok(None); - }; - let payload_len = usize::try_from(value).map_err(|_| MuxStreamError::InvalidFrameLength)?; - Ok(Some((payload_len, consumed))) -} - -fn decode_varint_from_slice(src: &[u8]) -> Result<(VarInt, usize), MuxStreamError> { - let Some((value, consumed)) = try_decode_varint(src)? else { - return Err(MuxStreamError::InvalidFdsPayload); - }; - let id = VarInt::from_u64(value).map_err(|_| MuxStreamError::InvalidFdsPayload)?; - Ok((id, consumed)) -} - -fn try_decode_varint(src: &[u8]) -> Result, MuxStreamError> { - if src.is_empty() { - return Ok(None); - } - - let first = src[0]; - let len = 1usize << (first >> 6); - if src.len() < len { - return Ok(None); - } - - let mut raw = [0u8; 8]; - raw[..len].copy_from_slice(&src[..len]); - raw[0] &= 0x3f; - let value = u64::from_be_bytes(raw) >> (8 * (8 - len)); - if value >= VARINT_MAX { - return Err(MuxStreamError::InvalidFrameLength); - } - - Ok(Some((value, len))) -} - -fn encode_varint_to_slice(dst: &mut [u8], v: VarInt) -> usize { - let x = v.into_inner(); - if x < (1 << 6) { - dst[0] = x as u8; - 1 - } else if x < (1 << 14) { - let bytes = ((0b01 << 14) | x as u16).to_be_bytes(); - dst[..2].copy_from_slice(&bytes); - 2 - } else if x < (1 << 30) { - let bytes = ((0b10 << 30) | x as u32).to_be_bytes(); - dst[..4].copy_from_slice(&bytes); - 4 - } else { - let bytes = ((0b11 << 62) | x).to_be_bytes(); - dst[..8].copy_from_slice(&bytes); - 8 + self.bytes_rx.poll_recv(cx) } } @@ -979,729 +326,77 @@ mod tests { } #[tokio::test] - async fn queue_fds_auto_id_and_wait() { - let ((mut sink_a, _stream_a), (_sink_b, mut stream_b)) = make_pair(); - - let registry = stream_b.fd_registry(); - let (fd_a, _fd_b) = StdUnixStream::pair().expect("fd pair"); - let id = sink_a - .fd_sender() - .queue_fds(smallvec![fd_a.into()]) - .expect("queue fds"); - - sink_a.send(Bytes::new()).await.expect("drive flush"); - - let _ = stream_b - .next() - .await - .expect("stream item") - .expect("empty bytes frame"); - let fds = registry.wait_fds(id).await.expect("wait fds"); - assert_eq!(fds.len(), 1); - } - - #[tokio::test] - async fn wait_before_arrival() { - let ((mut sink_a, _stream_a), (_sink_b, mut stream_b)) = make_pair(); - - let registry = stream_b.fd_registry(); - let waiter = tokio::spawn(async move { registry.wait_fds(VarInt::from_u32(0)).await }); + async fn fd_delivery_deliver_completes_after_local_queue() { + let ((sink_a, stream_a), (sink_b, stream_b)) = make_pair(); + let transfer_a = stream_a.fd_transfer(sink_a.fd_sender()); + let transfer_b = stream_b.fd_transfer(sink_b.fd_sender()); - let (fd_a, _fd_b) = StdUnixStream::pair().expect("fd pair"); - let id = sink_a - .fd_sender() - .queue_fds(smallvec![fd_a.into()]) - .expect("queue fds"); - assert_eq!(id, VarInt::from_u32(0)); - - sink_a.send(Bytes::new()).await.expect("drive flush"); - let _ = stream_b - .next() + let receiver = transfer_b.receive(); + let id = receiver.id(); + let delivery = transfer_a.delivery(id); + let (fd, _peer) = StdUnixStream::pair().expect("fd pair"); + let delivered = timeout( + Duration::from_secs(1), + delivery.deliver(smallvec![fd.into()]), + ) + .await + .expect("deliver queue timeout") + .expect("deliver should complete once queued"); + assert_eq!(delivered.id(), id); + + let received = timeout(Duration::from_secs(1), receiver) .await - .expect("stream item") - .expect("empty bytes frame"); - - let received = waiter.await.expect("join").expect("wait result"); + .expect("receiver timeout") + .expect("receive fds"); assert_eq!(received.len(), 1); } #[tokio::test] - async fn duplicate_id_rejected_in_registry() { - let (registry_core, registry) = FdRegistry::new_for_test(); - - let (fd1, _peer1) = StdUnixStream::pair().expect("pair1"); - let (fd2, _peer2) = StdUnixStream::pair().expect("pair2"); - let id = VarInt::from_u32(7); - - registry_core - .register_arrival(id, smallvec![fd1.into()]) - .expect("first arrival"); - let duplicate = registry_core.register_arrival(id, smallvec![fd2.into()]); - assert!(matches!( - duplicate, - Err(RegisterFdsError::DuplicateId { .. }) - )); - - let first = registry.wait_fds(id).await.expect("first still valid"); - assert_eq!(first.len(), 1); - } - - #[tokio::test] - async fn wait_after_consumed_waits_for_next_arrival() { - let (registry_core, registry) = FdRegistry::new_for_test(); - let id = VarInt::from_u32(9); - - let (fd1, _peer1) = StdUnixStream::pair().expect("pair1"); - registry_core - .register_arrival(id, smallvec![fd1.into()]) - .expect("first arrival"); - let first = registry.wait_fds(id).await.expect("first consume"); - assert_eq!(first.len(), 1); - - let reg2 = registry.clone(); - let waiter = tokio::spawn(async move { reg2.wait_fds(id).await }); - tokio::task::yield_now().await; - assert!(!waiter.is_finished()); - - let (fd2, _peer2) = StdUnixStream::pair().expect("pair2"); - registry_core - .register_arrival(id, smallvec![fd2.into()]) - .expect("second arrival"); - - let second = waiter.await.expect("join").expect("second consume"); - assert_eq!(second.len(), 1); - } - - #[tokio::test] - async fn interleaved_fd_and_bytes_do_not_leak_fd_frames() { - let ((mut sink_a, _stream_a), (_sink_b, mut stream_b)) = make_pair(); + async fn receiver_drop_only_removes_local_waiting_slot() { + let ((sink_a, stream_a), (sink_b, stream_b)) = make_pair(); + let transfer_a = stream_a.fd_transfer(sink_a.fd_sender()); + let transfer_b = stream_b.fd_transfer(sink_b.fd_sender()); - let registry = stream_b.fd_registry(); - let (fd_a, _fd_b) = StdUnixStream::pair().expect("fd pair"); - let id = sink_a - .fd_sender() - .queue_fds(smallvec![fd_a.into()]) - .expect("queue fds"); - - sink_a - .send(Bytes::from_static(b"payload")) - .await - .expect("send payload"); - - let payload = stream_b - .next() - .await - .expect("stream item") - .expect("bytes frame"); - assert_eq!(payload, Bytes::from_static(b"payload")); - - let fds = registry.wait_fds(id).await.expect("wait fds"); - assert_eq!(fds.len(), 1); - } - - #[tokio::test] - async fn sink_drop_propagates_eof() { - let ((sink_a, _stream_a), (_sink_b, mut stream_b)) = make_pair(); - - drop(sink_a); - let eof = timeout(Duration::from_millis(200), stream_b.next()) - .await - .expect("expected eof signal"); - assert!(eof.is_none()); - } - - #[tokio::test] - async fn fd_sender_closed_after_sink_drop() { - let ((sink_a, _stream_a), (_sink_b, _stream_b)) = make_pair(); - let sender = sink_a.fd_sender(); - drop(sink_a); + let receiver = transfer_b.receive(); + let id = receiver.id(); + drop(receiver); let (fd, _peer) = StdUnixStream::pair().expect("fd pair"); - let result = sender.queue_fds(smallvec![fd.into()]); - assert!(matches!(result, Err(QueueFdsError::Closed))); - } - - // ----------------------------------------------------------------------- - // FdSender boundary tests - // ----------------------------------------------------------------------- - - #[tokio::test] - async fn queue_fds_empty_vec_rejected() { - let ((sink_a, _stream_a), (_sink_b, _stream_b)) = make_pair(); - let sender = sink_a.fd_sender(); - let result = sender.queue_fds(smallvec![]); - assert!(matches!(result, Err(QueueFdsError::EmptyFds))); - } - - #[tokio::test] - async fn queue_fds_multiple_fds_per_message() { - let ((mut sink_a, _stream_a), (_sink_b, mut stream_b)) = make_pair(); - let registry = stream_b.fd_registry(); - - let (fd1, _p1) = StdUnixStream::pair().expect("pair1"); - let (fd2, _p2) = StdUnixStream::pair().expect("pair2"); - let (fd3, _p3) = StdUnixStream::pair().expect("pair3"); - let id = sink_a - .fd_sender() - .queue_fds(smallvec![fd1.into(), fd2.into(), fd3.into()]) - .expect("queue 3 fds"); - - sink_a.send(Bytes::new()).await.expect("drive flush"); - let _ = stream_b - .next() - .await - .expect("stream item") - .expect("empty bytes"); - - let fds = registry.wait_fds(id).await.expect("wait fds"); - assert_eq!(fds.len(), 3); - } - - // ----------------------------------------------------------------------- - // FdRegistry boundary tests - // ----------------------------------------------------------------------- - - #[tokio::test] - async fn wait_fds_already_waiting() { - let (registry_core, registry) = FdRegistry::new_for_test(); - let id = VarInt::from_u32(1); - - let reg2 = registry.clone(); - let waiter = tokio::spawn(async move { reg2.wait_fds(id).await }); - tokio::task::yield_now().await; - - // Second wait on same ID should fail - let result = registry_core.wait_fds(id).await; - assert!(matches!(result, Err(WaitFdsError::AlreadyWaiting { .. }))); - - // Satisfy the first waiter to clean up - let (fd, _peer) = StdUnixStream::pair().expect("pair"); - registry_core - .register_arrival(id, smallvec![fd.into()]) - .expect("arrival"); - waiter.await.expect("join").expect("wait result"); - } - - #[tokio::test] - async fn wait_fds_closed_registry() { - let (registry_core, registry) = FdRegistry::new_for_test(); - registry_core.close(); - - let result = registry.wait_fds(VarInt::from_u32(0)).await; - assert!(matches!(result, Err(WaitFdsError::Closed))); - } - - #[tokio::test] - async fn wait_fds_close_while_waiting() { - let (registry_core, registry) = FdRegistry::new_for_test(); - let id = VarInt::from_u32(42); - - let reg = registry.clone(); - let waiter = tokio::spawn(async move { reg.wait_fds(id).await }); - tokio::task::yield_now().await; - - registry_core.close(); - - let result = waiter.await.expect("join"); - assert!(matches!(result, Err(WaitFdsError::Closed))); - } - - #[tokio::test] - async fn register_arrival_after_close() { - let (registry_core, _registry) = FdRegistry::new_for_test(); - registry_core.close(); - - let (fd, _peer) = StdUnixStream::pair().expect("pair"); - let result = registry_core.register_arrival(VarInt::from_u32(0), smallvec![fd.into()]); - assert!(matches!(result, Err(RegisterFdsError::Closed))); - } - - // ----------------------------------------------------------------------- - // MuxStream demux boundary tests - // ----------------------------------------------------------------------- - - #[tokio::test] - async fn multiple_fd_ids_interleaved() { - let ((mut sink_a, _stream_a), (_sink_b, mut stream_b)) = make_pair(); - let registry = stream_b.fd_registry(); - let sender = sink_a.fd_sender(); - - let (fd1, _p1) = StdUnixStream::pair().expect("pair1"); - let id1 = sender.queue_fds(smallvec![fd1.into()]).expect("queue 1"); - - sink_a - .send(Bytes::from_static(b"msg1")) - .await - .expect("send msg1"); - - let (fd2, _p2) = StdUnixStream::pair().expect("pair2"); - let id2 = sender.queue_fds(smallvec![fd2.into()]).expect("queue 2"); - - sink_a - .send(Bytes::from_static(b"msg2")) - .await - .expect("send msg2"); - - let (fd3, _p3) = StdUnixStream::pair().expect("pair3"); - let id3 = sender.queue_fds(smallvec![fd3.into()]).expect("queue 3"); - - sink_a - .send(Bytes::from_static(b"msg3")) - .await - .expect("send msg3"); - - // FD frames are demuxed internally; only bytes frames are yielded - let p1 = stream_b.next().await.expect("item1").expect("bytes1"); - let p2 = stream_b.next().await.expect("item2").expect("bytes2"); - let p3 = stream_b.next().await.expect("item3").expect("bytes3"); - assert_eq!(p1, Bytes::from_static(b"msg1")); - assert_eq!(p2, Bytes::from_static(b"msg2")); - assert_eq!(p3, Bytes::from_static(b"msg3")); - - let fds1 = registry.wait_fds(id1).await.expect("wait 1"); - let fds2 = registry.wait_fds(id2).await.expect("wait 2"); - let fds3 = registry.wait_fds(id3).await.expect("wait 3"); - assert_eq!(fds1.len(), 1); - assert_eq!(fds2.len(), 1); - assert_eq!(fds3.len(), 1); - } - - #[tokio::test] - async fn large_payload_roundtrip() { - let ((mut sink_a, _stream_a), (_sink_b, mut stream_b)) = make_pair(); - - let large = Bytes::from(vec![0xAB; 128 * 1024]); - sink_a.send(large.clone()).await.expect("send large"); - - let received = stream_b.next().await.expect("item").expect("bytes"); - assert_eq!(received, large); - } - - #[tokio::test] - async fn consecutive_fd_frames() { - let ((mut sink_a, _stream_a), (_sink_b, mut stream_b)) = make_pair(); - let registry = stream_b.fd_registry(); - let sender = sink_a.fd_sender(); - - // Queue two FD frames back to back without bytes between - let (fd1, _p1) = StdUnixStream::pair().expect("pair1"); - let (fd2, _p2) = StdUnixStream::pair().expect("pair2"); - let id1 = sender.queue_fds(smallvec![fd1.into()]).expect("queue 1"); - let id2 = sender.queue_fds(smallvec![fd2.into()]).expect("queue 2"); - - // A bytes frame drives the flush of both FD frames - sink_a - .send(Bytes::from_static(b"after")) - .await - .expect("send"); - - let payload = stream_b.next().await.expect("item").expect("bytes"); - assert_eq!(payload, Bytes::from_static(b"after")); - - let fds1 = registry.wait_fds(id1).await.expect("wait 1"); - let fds2 = registry.wait_fds(id2).await.expect("wait 2"); - assert_eq!(fds1.len(), 1); - assert_eq!(fds2.len(), 1); - } - - #[tokio::test] - async fn fd_frame_without_ancillary_errors() { - // Craft a raw FD frame without SCM_RIGHTS ancillary data - let (raw_writer, reader) = StdUnixStream::pair().expect("pair"); - let reader_ch = MuxChannel::from_fd(reader.into()).expect("from_fd"); - let (_sink, mut stream) = reader_ch.split().expect("split"); - - // FD frame: type=0x01, length=2, id=0x00, fd_count=0x01 - use std::io::Write; - (&raw_writer) - .write_all(&[FRAME_TYPE_FDS, 0x02, 0x00, 0x01]) - .expect("write"); - drop(raw_writer); - - let result = stream.next().await; - assert!( - matches!(result, Some(Err(MuxStreamError::MissingAncillaryFds))), - "expected MissingAncillaryFds, got: {result:?}" - ); - } - - #[tokio::test] - async fn unknown_frame_type_rejected() { - let (raw_writer, reader) = StdUnixStream::pair().expect("pair"); - let reader_ch = MuxChannel::from_fd(reader.into()).expect("from_fd"); - let (_sink, mut stream) = reader_ch.split().expect("split"); - - // Unknown frame type 0xFF - use std::io::Write; - (&raw_writer).write_all(&[0xFF, 0x01, 0x00]).expect("write"); - drop(raw_writer); - - let result = stream.next().await; - assert!( - matches!( - result, - Some(Err(MuxStreamError::UnknownFrameType { frame_type: 0xFF })) - ), - "expected UnknownFrameType(0xFF), got: {result:?}" - ); - } - - // ----------------------------------------------------------------------- - // MuxSink boundary tests - // ----------------------------------------------------------------------- - - #[tokio::test] - async fn send_after_close_errors() { - let ((mut sink_a, _stream_a), (_sink_b, _stream_b)) = make_pair(); - sink_a.close().await.expect("close"); - - let result = sink_a.send(Bytes::from_static(b"after-close")).await; - assert!(matches!(result, Err(MuxSinkError::Closed))); - } - - #[tokio::test] - async fn fd_priority_over_bytes() { - use futures::future::poll_fn; - - let ((mut sink_a, _stream_a), (_sink_b, mut stream_b)) = make_pair(); - let registry = stream_b.fd_registry(); - - // Step 1: poll_ready + start_send to park pending bytes - poll_fn(|cx| Pin::new(&mut sink_a).poll_ready(cx)) - .await - .expect("ready"); - Pin::new(&mut sink_a) - .start_send(Bytes::from_static(b"after-fds")) - .expect("start_send"); - - // Step 2: queue FD — it should be flushed before the pending bytes - let (fd, _peer) = StdUnixStream::pair().expect("pair"); - let fd_id = sink_a - .fd_sender() - .queue_fds(smallvec![fd.into()]) - .expect("queue fds"); - - // Step 3: flush sends FD frame first, then bytes frame - sink_a.flush().await.expect("flush"); - - // Stream yields only bytes (FD is demuxed internally) - let payload = stream_b.next().await.expect("item").expect("bytes"); - assert_eq!(payload, Bytes::from_static(b"after-fds")); - - // FD should be available - let fds = registry.wait_fds(fd_id).await.expect("wait fds"); - assert_eq!(fds.len(), 1); - } - - #[tokio::test] - async fn empty_bytes_frame() { - let ((mut sink_a, _stream_a), (_sink_b, mut stream_b)) = make_pair(); - sink_a.send(Bytes::new()).await.expect("send empty"); - - let payload = stream_b.next().await.expect("item").expect("bytes"); - assert!(payload.is_empty()); - } - - // ----------------------------------------------------------------------- - // Close propagation tests - // ----------------------------------------------------------------------- - - #[tokio::test] - async fn stream_drop_propagates_registry_close() { - let ((_sink_a, _stream_a), (_sink_b, stream_b)) = make_pair(); - let registry = stream_b.fd_registry(); - drop(stream_b); - - let result = registry.wait_fds(VarInt::from_u32(0)).await; - assert!(matches!(result, Err(WaitFdsError::Closed))); - } - - // ----------------------------------------------------------------------- - // Phase 1: Validation path tests - // ----------------------------------------------------------------------- - - #[tokio::test] - async fn queue_fds_too_many_rejected() { - let ((sink_a, _stream_a), (_sink_b, _stream_b)) = make_pair(); - let sender = sink_a.fd_sender(); - - let (fd1, _p1) = StdUnixStream::pair().expect("pair1"); - let (fd2, _p2) = StdUnixStream::pair().expect("pair2"); - let (fd3, _p3) = StdUnixStream::pair().expect("pair3"); - let (fd4, _p4) = StdUnixStream::pair().expect("pair4"); - let (fd5, _p5) = StdUnixStream::pair().expect("pair5"); - - let result = sender.queue_fds(smallvec![ - fd1.into(), - fd2.into(), - fd3.into(), - fd4.into(), - fd5.into(), - ]); - assert!( - matches!(result, Err(QueueFdsError::TooManyFds { count: 5 })), - "expected TooManyFds {{ count: 5 }}, got: {result:?}" - ); - } - - #[tokio::test] - async fn max_fds_per_frame_roundtrip() { - let ((mut sink_a, _stream_a), (_sink_b, mut stream_b)) = make_pair(); - let registry = stream_b.fd_registry(); - - let (fd1, _p1) = StdUnixStream::pair().expect("pair1"); - let (fd2, _p2) = StdUnixStream::pair().expect("pair2"); - let (fd3, _p3) = StdUnixStream::pair().expect("pair3"); - let (fd4, _p4) = StdUnixStream::pair().expect("pair4"); - - let id = sink_a - .fd_sender() - .queue_fds(smallvec![fd1.into(), fd2.into(), fd3.into(), fd4.into()]) - .expect("queue 4 fds"); - - sink_a.send(Bytes::new()).await.expect("drive flush"); - let _ = stream_b - .next() - .await - .expect("stream item") - .expect("empty bytes"); - - let fds = registry.wait_fds(id).await.expect("wait fds"); - assert_eq!(fds.len(), 4); - } - - #[tokio::test] - async fn fd_frame_fd_count_zero_rejected() { - let (raw_writer, reader) = StdUnixStream::pair().expect("pair"); - let reader_ch = MuxChannel::from_fd(reader.into()).expect("from_fd"); - let (_sink, mut stream) = reader_ch.split().expect("split"); - - // FD frame: type=0x01, payload_len=2, id=0x00, fd_count=0x00 - use std::io::Write; - (&raw_writer) - .write_all(&[FRAME_TYPE_FDS, 0x02, 0x00, 0x00]) - .expect("write"); - drop(raw_writer); - - let result = stream.next().await; - assert!( - matches!(result, Some(Err(MuxStreamError::InvalidFdsPayload))), - "expected InvalidFdsPayload, got: {result:?}" - ); - } - - #[tokio::test] - async fn fd_frame_fd_count_exceeds_max_rejected() { - let (raw_writer, reader) = StdUnixStream::pair().expect("pair"); - let reader_ch = MuxChannel::from_fd(reader.into()).expect("from_fd"); - let (_sink, mut stream) = reader_ch.split().expect("split"); - - // FD frame: type=0x01, payload_len=2, id=0x00, fd_count=0x05 - use std::io::Write; - (&raw_writer) - .write_all(&[FRAME_TYPE_FDS, 0x02, 0x00, 0x05]) - .expect("write"); - drop(raw_writer); - - let result = stream.next().await; - assert!( - matches!(result, Some(Err(MuxStreamError::InvalidFdsPayload))), - "expected InvalidFdsPayload, got: {result:?}" - ); - } - - #[tokio::test] - async fn fd_frame_trailing_bytes_rejected() { - let (raw_writer, reader) = StdUnixStream::pair().expect("pair"); - let reader_ch = MuxChannel::from_fd(reader.into()).expect("from_fd"); - let (_sink, mut stream) = reader_ch.split().expect("split"); - - // FD frame: type=0x01, payload_len=3, id=0x00, fd_count=0x01, extra=0xFF - // id(1 byte) + fd_count(1 byte) = 2 bytes, but payload_len says 3 - use std::io::Write; - (&raw_writer) - .write_all(&[FRAME_TYPE_FDS, 0x03, 0x00, 0x01, 0xFF]) - .expect("write"); - drop(raw_writer); - - let result = stream.next().await; - assert!( - matches!(result, Some(Err(MuxStreamError::InvalidFdsPayload))), - "expected InvalidFdsPayload, got: {result:?}" - ); - } - - // ----------------------------------------------------------------------- - // Phase 2: Frame decoding boundary tests - // ----------------------------------------------------------------------- - - #[tokio::test] - async fn incomplete_bytes_header_waits_for_more_data() { - let (raw_writer, reader) = StdUnixStream::pair().expect("pair"); - let reader_ch = MuxChannel::from_fd(reader.into()).expect("from_fd"); - let (_sink, mut stream) = reader_ch.split().expect("split"); - - // Write only the frame type byte; payload_len + payload missing. - use std::io::Write; - (&raw_writer) - .write_all(&[FRAME_TYPE_BYTES]) - .expect("write type"); - - // Stream should not yield yet (incomplete header). - let poll = timeout(Duration::from_millis(50), stream.next()).await; - assert!(poll.is_err(), "expected timeout (incomplete header)"); - - // Now write the rest: payload_len=5, payload="hello" - (&raw_writer) - .write_all(&[0x05, b'h', b'e', b'l', b'l', b'o']) - .expect("write rest"); - drop(raw_writer); - - let payload = stream.next().await.expect("item").expect("bytes"); - assert_eq!(payload, Bytes::from_static(b"hello")); - - // EOF - assert!(stream.next().await.is_none()); - } - - #[tokio::test] - async fn varint_2byte_length_roundtrip() { - let ((mut sink_a, _stream_a), (_sink_b, mut stream_b)) = make_pair(); - - // 100 bytes → varint length needs 2 bytes (64..16383) - let data = Bytes::from(vec![0xAB; 100]); - sink_a.send(data.clone()).await.expect("send"); - - let received = stream_b.next().await.expect("item").expect("bytes"); - assert_eq!(received, data); - } - - #[tokio::test] - async fn fd_over_bytes_ordering_guarantee() { - let ((mut sink_a, _stream_a), (_sink_b, mut stream_b)) = make_pair(); - let registry = stream_b.fd_registry(); - let sender = sink_a.fd_sender(); - - // Interleave: FD, bytes, FD, bytes, FD, bytes - let mut fd_ids = Vec::new(); - for i in 0..3 { - let (fd, _peer) = StdUnixStream::pair().expect("fd pair"); - let id = sender.queue_fds(smallvec![fd.into()]).expect("queue fds"); - fd_ids.push(id); - - let msg = Bytes::from(format!("msg-{i}")); - sink_a.send(msg).await.expect("send"); - } - - // All 3 bytes frames arrive in order - for i in 0..3 { - let payload = stream_b.next().await.expect("item").expect("bytes"); - assert_eq!(payload, Bytes::from(format!("msg-{i}"))); - } - - // All 3 FDs arrive - for &id in &fd_ids { - let fds = registry.wait_fds(id).await.expect("wait fds"); - assert_eq!(fds.len(), 1); - } - } - - #[tokio::test] - async fn concurrent_fds_and_bytes_stress() { - let ((mut sink_a, _stream_a), (_sink_b, mut stream_b)) = make_pair(); - let registry = stream_b.fd_registry(); - let sender = sink_a.fd_sender(); - - let rounds = 20; - let mut fd_ids = Vec::with_capacity(rounds); - - for i in 0..rounds { - let (fd, _peer) = StdUnixStream::pair().expect("fd pair"); - let id = sender.queue_fds(smallvec![fd.into()]).expect("queue fds"); - fd_ids.push(id); - sink_a - .send(Bytes::from(format!("stress-{i}"))) - .await - .expect("send"); - } - - // Drain all bytes frames - for i in 0..rounds { - let payload = stream_b.next().await.expect("item").expect("bytes"); - assert_eq!(payload, Bytes::from(format!("stress-{i}"))); - } - - // All FDs accessible - for &id in &fd_ids { - let fds = registry.wait_fds(id).await.expect("wait fds"); - assert_eq!(fds.len(), 1); - } + let delivered = timeout( + Duration::from_secs(1), + transfer_a.delivery(id).deliver(smallvec![fd.into()]), + ) + .await + .expect("delivery should not wait for dropped receiver") + .expect("dropped receiver should not be sender-visible"); + assert_eq!(delivered.id(), id); } - // ----------------------------------------------------------------------- - // Phase 3: Close / lifecycle tests - // ----------------------------------------------------------------------- - #[tokio::test] - async fn graceful_close_flushes_pending_fds() { - let ((mut sink_a, _stream_a), (_sink_b, mut stream_b)) = make_pair(); - let registry = stream_b.fd_registry(); - + async fn fd_arrival_without_receiver_is_dropped() { + let ((mut sink_a, stream_a), (_sink_b, mut stream_b)) = make_pair(); + let transfer_a = stream_a.fd_transfer(sink_a.fd_sender()); + let id = VarInt::from_u32(7); + let delivery = transfer_a.delivery(id); let (fd, _peer) = StdUnixStream::pair().expect("fd pair"); - let id = sink_a - .fd_sender() - .queue_fds(smallvec![fd.into()]) - .expect("queue fds"); + let delivered = timeout( + Duration::from_secs(1), + delivery.deliver(smallvec![fd.into()]), + ) + .await + .expect("unknown fd delivery should queue") + .expect("unknown fd id should not cancel delivery"); + assert_eq!(delivered.id(), id); - // Also send a bytes frame so the stream yields something before EOF sink_a - .send(Bytes::from_static(b"before-close")) + .send(Bytes::from_static(b"after-unknown-fd")) .await - .expect("send"); - - // close() should flush both pending FD frame and bytes, then shutdown - sink_a.close().await.expect("close"); - - // Drive the stream: the FD frame is consumed internally, bytes yielded - let payload = stream_b.next().await.expect("item").expect("bytes"); - assert_eq!(payload, Bytes::from_static(b"before-close")); - - // FD was registered during poll_next — wait_fds should succeed - let fds = registry.wait_fds(id).await.expect("wait fds"); - assert_eq!(fds.len(), 1); - - // Now EOF - let eof = stream_b.next().await; - assert!(eof.is_none(), "expected EOF after close"); - } - - #[tokio::test] - async fn stream_eof_after_sink_close() { - let ((mut sink_a, _stream_a), (_sink_b, mut stream_b)) = make_pair(); - - // Send some data, then close explicitly - sink_a - .send(Bytes::from_static(b"before-close")) + .expect("send after unknown fd"); + let received = timeout(Duration::from_secs(1), stream_b.next()) .await - .expect("send"); - sink_a.close().await.expect("close"); - - let payload = stream_b.next().await.expect("item").expect("bytes"); - assert_eq!(payload, Bytes::from_static(b"before-close")); - - let eof = stream_b.next().await; - assert!(eof.is_none(), "expected EOF"); - } - - #[tokio::test] - async fn double_close_is_idempotent() { - let ((mut sink_a, _stream_a), (_sink_b, _stream_b)) = make_pair(); - - sink_a.close().await.expect("first close"); - let result = sink_a.close().await; - assert!( - matches!(result, Err(MuxSinkError::Closed)), - "expected Closed on double close, got: {result:?}" - ); + .expect("stream should not be closed by unknown fd") + .expect("stream item") + .expect("unknown fd should not be a protocol error"); + assert_eq!(received, Bytes::from_static(b"after-unknown-fd")); } } diff --git a/src/ipc/transport/driver.rs b/src/ipc/transport/driver.rs new file mode 100644 index 0000000..347b803 --- /dev/null +++ b/src/ipc/transport/driver.rs @@ -0,0 +1,360 @@ +use std::{ + collections::VecDeque, + io, + os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd}, + sync::{Arc, Mutex, Weak}, +}; + +use bytes::{Bytes, BytesMut}; +use futures::task::AtomicWaker; +use nix::sys::socket::{ControlMessageOwned, MsgFlags, Shutdown, recvmsg, shutdown}; +use tokio::{ + io::unix::AsyncFd, + sync::{Notify, mpsc}, +}; +use tokio_util::task::AbortOnDropHandle; +use tracing::Instrument as _; + +use super::{FdVec, MAX_FDS_PER_FRAME, MuxStreamError, QueueFdsError, fd_plane::FdPlaneCore}; +use crate::{ipc::transport::frame, varint::VarInt}; + +const READ_CHUNK_LEN: usize = 8 * 1024; +const MIN_FDS_FRAME_LEN: usize = 4; + +pub(crate) type ReaderStart = ( + mpsc::UnboundedReceiver>, + AbortOnDropHandle<()>, +); + +#[derive(Debug)] +pub(crate) struct WriterCore { + fd: AsyncFd, + state: Mutex, + notify: Notify, + flush_waker: AtomicWaker, +} + +#[derive(Debug)] +struct WriterState { + queue: VecDeque, + closed: bool, + in_flight: bool, +} + +impl WriterCore { + pub(crate) fn queue(&self, frame: frame::OutboundFrame) -> Result<(), QueueFdsError> { + let mut state = self.state.lock().expect("writer state poisoned"); + if state.closed { + return Err(QueueFdsError::Closed); + } + state.queue.push_back(frame); + drop(state); + self.notify.notify_one(); + self.flush_waker.wake(); + Ok(()) + } + + pub(crate) fn close(&self) { + let mut state = self.state.lock().expect("writer state poisoned"); + state.closed = true; + drop(state); + self.notify.notify_one(); + self.flush_waker.wake(); + } + + pub(crate) fn shutdown_write(&self) { + let _ = shutdown(self.fd.get_ref().as_raw_fd(), Shutdown::Write); + } + + pub(crate) fn is_closed(&self) -> bool { + self.state.lock().expect("writer state poisoned").closed + } + + pub(crate) fn is_flushed(&self) -> bool { + let state = self.state.lock().expect("writer state poisoned"); + state.queue.is_empty() && !state.in_flight + } + + pub(crate) fn register_flush_waker(&self, waker: &std::task::Waker) { + self.flush_waker.register(waker); + } +} + +#[derive(Clone, Debug)] +pub struct FdSender { + core: Weak, +} + +impl FdSender { + fn new(core: &Arc) -> Self { + Self { + core: Arc::downgrade(core), + } + } + + pub(crate) fn send_fds(&self, id: VarInt, fds: FdVec) -> Result<(), QueueFdsError> { + if fds.is_empty() { + return Err(QueueFdsError::EmptyFds); + } + if fds.len() > MAX_FDS_PER_FRAME { + return Err(QueueFdsError::TooManyFds { count: fds.len() }); + } + self.queue(frame::OutboundFrame::Fds { id, fds }) + } + + fn queue(&self, frame: frame::OutboundFrame) -> Result<(), QueueFdsError> { + let Some(core) = self.core.upgrade() else { + return Err(QueueFdsError::Closed); + }; + core.queue(frame) + } +} + +pub(crate) fn start_writer( + fd: OwnedFd, +) -> io::Result<(Arc, FdSender, AbortOnDropHandle<()>)> { + let core = Arc::new(WriterCore { + fd: AsyncFd::new(fd)?, + state: Mutex::new(WriterState { + queue: VecDeque::new(), + closed: false, + in_flight: false, + }), + notify: Notify::new(), + flush_waker: AtomicWaker::new(), + }); + let sender = FdSender::new(&core); + let task = AbortOnDropHandle::new(tokio::spawn(writer_loop(core.clone()).in_current_span())); + Ok((core, sender, task)) +} + +pub(crate) fn start_reader(fd: OwnedFd, plane: Arc) -> io::Result { + let fd = AsyncFd::new(fd)?; + let (bytes_tx, bytes_rx) = mpsc::unbounded_channel(); + let task = AbortOnDropHandle::new(tokio::spawn( + reader_loop(fd, plane, bytes_tx).in_current_span(), + )); + Ok((bytes_rx, task)) +} + +async fn writer_loop(core: Arc) { + let mut current = None; + + loop { + if current.is_none() { + let next = { + let mut state = core.state.lock().expect("writer state poisoned"); + if let Some(frame) = state.queue.pop_front() { + state.in_flight = true; + Some(frame) + } else { + state.in_flight = false; + core.flush_waker.wake(); + if state.closed { + break; + } + None + } + }; + + let Some(frame) = next else { + core.notify.notified().await; + continue; + }; + current = Some(frame::pending_frame(frame)); + } + + let pending = current.as_mut().expect("pending frame must exist"); + if !write_pending_frame(&core.fd, pending).await { + break; + } + + if pending.is_complete() { + current = None; + let mut state = core.state.lock().expect("writer state poisoned"); + state.in_flight = false; + if state.queue.is_empty() { + core.flush_waker.wake(); + } + } + } + + let _ = shutdown(core.fd.get_ref().as_raw_fd(), Shutdown::Write); + let mut state = core.state.lock().expect("writer state poisoned"); + state.closed = true; + state.in_flight = false; + state.queue.clear(); + drop(state); + core.flush_waker.wake(); +} + +async fn write_pending_frame(fd: &AsyncFd, frame: &mut frame::PendingFrame) -> bool { + while !frame.is_complete() { + let mut guard = match fd.writable().await { + Ok(guard) => guard, + Err(_) => return false, + }; + let result = + guard.try_io(|inner| frame::send_pending_frame(inner.get_ref().as_raw_fd(), frame)); + let sent = match result { + Ok(Ok(sent)) => sent, + Ok(Err(_error)) => return false, + Err(_would_block) => continue, + }; + if sent == 0 { + return false; + } + } + true +} + +async fn reader_loop( + fd: AsyncFd, + plane: Arc, + bytes_tx: mpsc::UnboundedSender>, +) { + let mut read_buf = BytesMut::new(); + let mut cmsg_buf = { + let per_msg = nix::sys::socket::cmsg_space::<[RawFd; MAX_FDS_PER_FRAME]>(); + let max_msgs = READ_CHUNK_LEN / MIN_FDS_FRAME_LEN; + vec![0u8; max_msgs * per_msg] + }; + let mut pending_fds = VecDeque::new(); + let mut pending_fd_frame = None; + + loop { + if let Some((id, fd_count)) = pending_fd_frame { + if pending_fds.len() >= fd_count { + pending_fd_frame = None; + let fds: FdVec = pending_fds.drain(..fd_count).collect(); + plane.arrive_fds(id, fds); + continue; + } + } else { + match frame::try_decode_frame(&mut read_buf) { + Ok(Some(frame::InboundFrame::Bytes(payload))) => { + if bytes_tx.send(Ok(payload)).is_err() { + plane.close(); + return; + } + continue; + } + Ok(Some(frame::InboundFrame::Fds { id, fd_count })) => { + if pending_fds.len() >= fd_count { + let fds: FdVec = pending_fds.drain(..fd_count).collect(); + plane.arrive_fds(id, fds); + continue; + } + pending_fd_frame = Some((id, fd_count)); + } + Ok(None) => {} + Err(error) => { + plane.close(); + let _ = bytes_tx.send(Err(error)); + return; + } + } + } + + let mut chunk = [0u8; READ_CHUNK_LEN]; + let mut guard = match fd.readable().await { + Ok(guard) => guard, + Err(source) => { + plane.close(); + let _ = bytes_tx.send(Err(MuxStreamError::PollReady { source })); + return; + } + }; + let result = guard.try_io(|inner| { + recv_frame_data(inner.get_ref().as_raw_fd(), &mut chunk, &mut cmsg_buf) + }); + + let (read, ancillary) = match result { + Ok(Ok(RecvOutcome::Data { read, ancillary })) => (read, ancillary), + Ok(Ok(RecvOutcome::AncillaryTruncated)) => { + plane.close(); + let _ = bytes_tx.send(Err(MuxStreamError::AncillaryTruncated)); + return; + } + Ok(Err(source)) => { + plane.close(); + let _ = bytes_tx.send(Err(MuxStreamError::Recv { source })); + return; + } + Err(_would_block) => continue, + }; + + if read == 0 { + plane.close(); + if pending_fd_frame.is_some() { + let _ = bytes_tx.send(Err(MuxStreamError::MissingAncillaryFds)); + } + return; + } + + if !ancillary.is_empty() { + pending_fds.extend(ancillary); + } + read_buf.extend_from_slice(&chunk[..read]); + } +} + +#[cfg(not(target_os = "linux"))] +fn set_cloexec(fd: &OwnedFd) -> io::Result<()> { + use nix::fcntl::{F_GETFD, F_SETFD, FdFlag, fcntl}; + let bits = match fcntl(fd, F_GETFD) { + Ok(bits) => bits, + Err(error) => return Err(io::Error::from(error)), + }; + let new_flags = FdFlag::from_bits_truncate(bits) | FdFlag::FD_CLOEXEC; + if let Err(error) = fcntl(fd, F_SETFD(new_flags)) { + return Err(io::Error::from(error)); + } + Ok(()) +} + +enum RecvOutcome { + Data { read: usize, ancillary: FdVec }, + AncillaryTruncated, +} + +fn recv_frame_data(fd: RawFd, data_buf: &mut [u8], cmsg_buf: &mut [u8]) -> io::Result { + let mut iov = [io::IoSliceMut::new(data_buf)]; + + #[cfg(target_os = "linux")] + let recv_flags = MsgFlags::MSG_CMSG_CLOEXEC; + #[cfg(not(target_os = "linux"))] + let recv_flags = MsgFlags::empty(); + + let msg = match recvmsg::<()>(fd, &mut iov, Some(cmsg_buf), recv_flags) { + Ok(msg) => msg, + Err(error) => return Err(io::Error::from(error)), + }; + if msg.flags.contains(MsgFlags::MSG_CTRUNC) { + return Ok(RecvOutcome::AncillaryTruncated); + } + + let mut fds = FdVec::new(); + let cmsgs = match msg.cmsgs() { + Ok(cmsgs) => cmsgs, + Err(error) => return Err(io::Error::from(error)), + }; + for cmsg in cmsgs { + if let ControlMessageOwned::ScmRights(raw_fds) = cmsg { + for raw_fd in raw_fds { + // SAFETY: SCM_RIGHTS transfers ownership of a new fd to the receiver. + let fd = unsafe { OwnedFd::from_raw_fd(raw_fd) }; + #[cfg(not(target_os = "linux"))] + if let Err(source) = set_cloexec(&fd) { + return Err(source); + } + fds.push(fd); + } + } + } + + Ok(RecvOutcome::Data { + read: msg.bytes, + ancillary: fds, + }) +} diff --git a/src/ipc/transport/fd_plane.rs b/src/ipc/transport/fd_plane.rs new file mode 100644 index 0000000..58e07bc --- /dev/null +++ b/src/ipc/transport/fd_plane.rs @@ -0,0 +1,307 @@ +use std::{ + collections::HashMap, + future::{Future, IntoFuture}, + os::fd::OwnedFd, + pin::Pin, + sync::{ + Arc, Mutex, Weak, + atomic::{AtomicU64, Ordering}, + }, + task::{Context, Poll}, +}; + +use futures::ready; +use tokio::sync::oneshot; + +use super::{DeliverFdsError, FdVec, QueueFdsError, TakeFdsError, WaitFdsError, driver::FdSender}; +use crate::varint::{VARINT_MAX, VarInt}; + +#[derive(Debug)] +pub(crate) struct FdPlaneCore { + next_id: AtomicU64, + receivers: Mutex, +} + +#[derive(Debug)] +struct ReceiverState { + slots: HashMap>>, + closed: bool, +} + +impl FdPlaneCore { + pub(crate) fn new() -> Self { + Self { + next_id: AtomicU64::new(0), + receivers: Mutex::new(ReceiverState { + slots: HashMap::new(), + closed: false, + }), + } + } + + fn next_id(&self) -> Result { + let id_raw = + match self + .next_id + .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| { + let next = current.checked_add(1)?; + if next >= VARINT_MAX { None } else { Some(next) } + }) { + Ok(id_raw) => id_raw, + Err(_) => return Err(WaitFdsError::IdExhausted), + }; + + match VarInt::from_u64(id_raw) { + Ok(id) => Ok(id), + Err(_) => Err(WaitFdsError::IdExhausted), + } + } + + fn reserve( + &self, + id: VarInt, + ) -> Result>, WaitFdsError> { + let mut state = self.receivers.lock().expect("fd receiver state poisoned"); + if state.closed { + return Err(WaitFdsError::Closed); + } + if state.slots.contains_key(&id) { + return Err(WaitFdsError::AlreadyWaiting { id }); + } + let (tx, rx) = oneshot::channel(); + state.slots.insert(id, tx); + Ok(rx) + } + + pub(crate) fn arrive_fds(&self, id: VarInt, fds: FdVec) { + let mut state = self.receivers.lock().expect("fd receiver state poisoned"); + if state.closed { + return; + } + if let Some(waiter) = state.slots.remove(&id) { + let _ = waiter.send(Ok(fds)); + } + } + + fn remove_receiver(&self, id: VarInt) { + let mut state = self.receivers.lock().expect("fd receiver state poisoned"); + state.slots.remove(&id); + } + + pub(crate) fn close(&self) { + let mut receivers = self.receivers.lock().expect("fd receiver state poisoned"); + if receivers.closed { + return; + } + receivers.closed = true; + for (_, waiter) in receivers.slots.drain() { + let _ = waiter.send(Err(WaitFdsError::Closed)); + } + } +} + +#[derive(Clone, Debug)] +pub struct FdTransfer { + sender: FdSender, + plane: Arc, +} + +impl FdTransfer { + pub(crate) fn new(sender: FdSender, plane: Arc) -> Self { + Self { sender, plane } + } + + pub fn receive(&self) -> FdReceiver { + let id = match self.plane.next_id() { + Ok(id) => id, + Err(error) => { + return FdReceiver::ready(VarInt::from_u32(0), error); + } + }; + match self.plane.reserve(id) { + Ok(rx) => FdReceiver { + id, + plane: Arc::downgrade(&self.plane), + rx: Some(rx), + ready: None, + active: true, + }, + Err(error) => FdReceiver::ready(id, error), + } + } + + pub fn delivery(&self, id: VarInt) -> FdDelivery { + FdDelivery { + id, + sender: self.sender.clone(), + } + } +} + +#[derive(Debug)] +pub struct FdReceiver { + id: VarInt, + plane: Weak, + rx: Option>>, + ready: Option>, + active: bool, +} + +impl FdReceiver { + fn ready(id: VarInt, error: WaitFdsError) -> Self { + Self { + id, + plane: Weak::new(), + rx: None, + ready: Some(Err(error)), + active: false, + } + } + + pub fn id(&self) -> VarInt { + self.id + } + + fn disarm(&mut self) { + self.active = false; + self.rx = None; + } +} + +impl Drop for FdReceiver { + fn drop(&mut self) { + if !self.active { + return; + } + if let Some(plane) = self.plane.upgrade() { + plane.remove_receiver(self.id); + } + self.active = false; + } +} + +pub struct FdReceiverFuture { + receiver: FdReceiver, +} + +impl Future for FdReceiverFuture { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if let Some(result) = self.receiver.ready.take() { + self.receiver.disarm(); + return Poll::Ready(result.map(ReceivedFds::new)); + } + + let Some(rx) = self.receiver.rx.as_mut() else { + self.receiver.disarm(); + return Poll::Ready(Err(WaitFdsError::ChannelClosed)); + }; + + let fds = match ready!(Pin::new(rx).poll(cx)) { + Ok(Ok(fds)) => fds, + Ok(Err(error)) => { + self.receiver.disarm(); + return Poll::Ready(Err(error)); + } + Err(_) => { + self.receiver.disarm(); + return Poll::Ready(Err(WaitFdsError::ChannelClosed)); + } + }; + + self.receiver.disarm(); + Poll::Ready(Ok(ReceivedFds::new(fds))) + } +} + +impl IntoFuture for FdReceiver { + type Output = Result; + type IntoFuture = FdReceiverFuture; + + fn into_future(self) -> Self::IntoFuture { + FdReceiverFuture { receiver: self } + } +} + +#[derive(Debug)] +pub struct ReceivedFds { + fds: FdVec, +} + +impl ReceivedFds { + fn new(fds: FdVec) -> Self { + Self { fds } + } + + pub fn len(&self) -> usize { + self.fds.len() + } + + pub fn is_empty(&self) -> bool { + self.fds.is_empty() + } + + pub fn into_fds(self) -> FdVec { + self.fds + } + + pub fn into_one(self) -> Result { + if self.fds.len() != 1 { + return Err(TakeFdsError::Count { + expected: 1, + actual: self.fds.len(), + }); + } + Ok(self.fds.into_iter().next().expect("fd count checked")) + } + + pub fn into_pair(self) -> Result<(OwnedFd, OwnedFd), TakeFdsError> { + if self.fds.len() != 2 { + return Err(TakeFdsError::Count { + expected: 2, + actual: self.fds.len(), + }); + } + let mut fds = self.fds.into_iter(); + let first = fds.next().expect("fd count checked"); + let second = fds.next().expect("fd count checked"); + Ok((first, second)) + } +} + +#[derive(Debug)] +pub struct FdDelivery { + id: VarInt, + sender: FdSender, +} + +impl FdDelivery { + pub fn id(&self) -> VarInt { + self.id + } + + pub async fn deliver(self, fds: FdVec) -> Result { + if let Err(source) = self.sender.send_fds(self.id, fds) { + return Err(DeliverFdsError::Queue { source }); + } + Ok(FdDelivered { id: self.id }) + } +} + +#[derive(Debug)] +pub struct FdDelivered { + id: VarInt, +} + +impl FdDelivered { + pub fn id(&self) -> VarInt { + self.id + } +} + +impl From for DeliverFdsError { + fn from(source: QueueFdsError) -> Self { + Self::Queue { source } + } +} diff --git a/src/ipc/transport/frame.rs b/src/ipc/transport/frame.rs new file mode 100644 index 0000000..50e3042 --- /dev/null +++ b/src/ipc/transport/frame.rs @@ -0,0 +1,306 @@ +use std::{ + io, + os::fd::{AsRawFd, RawFd}, +}; + +use bytes::{Buf, Bytes, BytesMut}; +use nix::sys::socket::{ControlMessage, MsgFlags, sendmsg}; +use smallvec::SmallVec; + +use super::{FdVec, MAX_FDS_PER_FRAME, MuxStreamError}; +use crate::varint::{VARINT_MAX, VarInt}; + +const FRAME_TYPE_BYTES: u8 = 0x00; +const FRAME_TYPE_FDS: u8 = 0x01; + +const MAX_FRAME_HEADER_LEN: usize = 1 + VarInt::MAX_SIZE; +const MAX_FDS_FRAME_LEN: usize = 1 + VarInt::MAX_SIZE + VarInt::MAX_SIZE + VarInt::MAX_SIZE; + +#[derive(Debug)] +pub(crate) enum OutboundFrame { + Bytes(Bytes), + Fds { id: VarInt, fds: FdVec }, +} + +#[derive(Debug)] +pub(crate) enum InboundFrame { + Bytes(Bytes), + Fds { id: VarInt, fd_count: usize }, +} + +#[derive(Debug)] +pub(crate) enum PendingFrame { + Bytes { + header: [u8; MAX_FRAME_HEADER_LEN], + header_len: usize, + header_written: usize, + payload: Bytes, + payload_written: usize, + }, + Fds { + header: [u8; MAX_FDS_FRAME_LEN], + header_len: usize, + header_written: usize, + payload: FdVec, + include_ancillary: bool, + }, +} + +impl PendingFrame { + fn new_bytes(payload: Bytes) -> Self { + let payload_len = VarInt::try_from(payload.len()).expect("payload length fits varint"); + let mut header = [0u8; MAX_FRAME_HEADER_LEN]; + header[0] = FRAME_TYPE_BYTES; + let len_len = encode_varint_to_slice(&mut header[1..], payload_len); + Self::Bytes { + header, + header_len: 1 + len_len, + header_written: 0, + payload, + payload_written: 0, + } + } + + fn new_fds(id: VarInt, fds: FdVec) -> Self { + let fd_count = VarInt::try_from(fds.len()).expect("fd count fits varint"); + + let mut id_buf = [0u8; VarInt::MAX_SIZE]; + let id_len = encode_varint_to_slice(&mut id_buf, id); + let mut count_buf = [0u8; VarInt::MAX_SIZE]; + let count_len = encode_varint_to_slice(&mut count_buf, fd_count); + + let mut header = [0u8; MAX_FDS_FRAME_LEN]; + header[0] = FRAME_TYPE_FDS; + let body_len = VarInt::try_from(id_len + count_len).expect("fd body length fits varint"); + let len_len = encode_varint_to_slice(&mut header[1..], body_len); + let body_start = 1 + len_len; + header[body_start..body_start + id_len].copy_from_slice(&id_buf[..id_len]); + header[body_start + id_len..body_start + id_len + count_len] + .copy_from_slice(&count_buf[..count_len]); + + Self::Fds { + header, + header_len: body_start + id_len + count_len, + header_written: 0, + payload: fds, + include_ancillary: true, + } + } + + pub(crate) fn is_complete(&self) -> bool { + match self { + Self::Bytes { + header_len, + header_written, + payload, + payload_written, + .. + } => *header_written >= *header_len && *payload_written >= payload.len(), + Self::Fds { + header_len, + header_written, + .. + } => *header_written >= *header_len, + } + } + + fn advance(&mut self, mut n: usize) { + match self { + Self::Bytes { + header_len, + header_written, + payload, + payload_written, + .. + } => { + let header_remaining = *header_len - *header_written; + let header_advance = n.min(header_remaining); + *header_written += header_advance; + n -= header_advance; + + if n > 0 { + let payload_remaining = payload.len().saturating_sub(*payload_written); + let payload_advance = n.min(payload_remaining); + *payload_written += payload_advance; + } + } + Self::Fds { + header_written, + include_ancillary, + .. + } => { + *header_written += n; + if n > 0 { + *include_ancillary = false; + } + } + } + } +} + +pub(crate) fn pending_frame(frame: OutboundFrame) -> PendingFrame { + match frame { + OutboundFrame::Bytes(payload) => PendingFrame::new_bytes(payload), + OutboundFrame::Fds { id, fds } => PendingFrame::new_fds(id, fds), + } +} + +pub(crate) fn send_pending_frame(fd: RawFd, frame: &mut PendingFrame) -> io::Result { + let sent = match frame { + PendingFrame::Bytes { + header, + header_len, + header_written, + payload, + payload_written, + } => { + let mut iovecs = [io::IoSlice::new(&[]), io::IoSlice::new(&[])]; + let mut iov_count = 0usize; + + if *header_written < *header_len { + iovecs[iov_count] = io::IoSlice::new(&header[*header_written..*header_len]); + iov_count += 1; + } + if *payload_written < payload.len() { + iovecs[iov_count] = io::IoSlice::new(&payload[*payload_written..]); + iov_count += 1; + } + + if iov_count == 0 { + return Ok(0); + } + sendmsg::<()>(fd, &iovecs[..iov_count], &[], MsgFlags::empty(), None) + } + PendingFrame::Fds { + header, + header_len, + header_written, + payload, + include_ancillary, + } => { + if *header_written >= *header_len { + return Ok(0); + } + let iov = [io::IoSlice::new(&header[*header_written..*header_len])]; + if *include_ancillary { + let raw_fds: SmallVec<[RawFd; 4]> = + payload.iter().map(AsRawFd::as_raw_fd).collect(); + let cmsgs = [ControlMessage::ScmRights(&raw_fds)]; + sendmsg::<()>(fd, &iov, &cmsgs, MsgFlags::empty(), None) + } else { + sendmsg::<()>(fd, &iov, &[], MsgFlags::empty(), None) + } + } + }; + + let sent = match sent { + Ok(sent) => sent, + Err(error) => return Err(io::Error::from(error)), + }; + frame.advance(sent); + Ok(sent) +} + +pub(crate) fn try_decode_frame(src: &mut BytesMut) -> Result, MuxStreamError> { + if src.is_empty() { + return Ok(None); + } + + let frame_type = src[0]; + if frame_type != FRAME_TYPE_BYTES && frame_type != FRAME_TYPE_FDS { + return Err(MuxStreamError::UnknownFrameType { frame_type }); + } + + let Some((payload_len, varint_len)) = try_decode_varint_len(&src[1..])? else { + return Ok(None); + }; + let frame_header_len = 1 + varint_len; + if src.len() < frame_header_len + payload_len { + return Ok(None); + } + + src.advance(frame_header_len); + let payload = src.split_to(payload_len).freeze(); + + match frame_type { + FRAME_TYPE_BYTES => Ok(Some(InboundFrame::Bytes(payload))), + FRAME_TYPE_FDS => { + let (id, id_consumed) = decode_varint_from_slice(&payload)?; + let remaining = &payload[id_consumed..]; + let (fd_count_vi, count_consumed) = decode_varint_from_slice(remaining)?; + if id_consumed + count_consumed != payload.len() { + return Err(MuxStreamError::InvalidFdsPayload); + } + let Ok(fd_count) = usize::try_from(u64::from(fd_count_vi)) else { + return Err(MuxStreamError::InvalidFdsPayload); + }; + if fd_count == 0 || fd_count > MAX_FDS_PER_FRAME { + return Err(MuxStreamError::InvalidFdsPayload); + } + Ok(Some(InboundFrame::Fds { id, fd_count })) + } + _ => Err(MuxStreamError::UnknownFrameType { frame_type }), + } +} + +fn try_decode_varint_len(src: &[u8]) -> Result, MuxStreamError> { + let Some((value, consumed)) = try_decode_varint(src)? else { + return Ok(None); + }; + let Ok(payload_len) = usize::try_from(value) else { + return Err(MuxStreamError::InvalidFrameLength); + }; + Ok(Some((payload_len, consumed))) +} + +fn decode_varint_from_slice(src: &[u8]) -> Result<(VarInt, usize), MuxStreamError> { + let Some((value, consumed)) = try_decode_varint(src)? else { + return Err(MuxStreamError::InvalidFdsPayload); + }; + let Ok(id) = VarInt::from_u64(value) else { + return Err(MuxStreamError::InvalidFdsPayload); + }; + Ok((id, consumed)) +} + +fn try_decode_varint(src: &[u8]) -> Result, MuxStreamError> { + if src.is_empty() { + return Ok(None); + } + + let first = src[0]; + let len = 1usize << (first >> 6); + if src.len() < len { + return Ok(None); + } + + let mut raw = [0u8; 8]; + raw[..len].copy_from_slice(&src[..len]); + raw[0] &= 0x3f; + let value = u64::from_be_bytes(raw) >> (8 * (8 - len)); + if value >= VARINT_MAX { + return Err(MuxStreamError::InvalidFrameLength); + } + + Ok(Some((value, len))) +} + +fn encode_varint_to_slice(dst: &mut [u8], v: VarInt) -> usize { + let x = v.into_inner(); + if x < (1 << 6) { + dst[0] = x as u8; + 1 + } else if x < (1 << 14) { + let bytes = ((0b01 << 14) | x as u16).to_be_bytes(); + dst[..2].copy_from_slice(&bytes); + 2 + } else if x < (1 << 30) { + let bytes = ((0b10 << 30) | x as u32).to_be_bytes(); + dst[..4].copy_from_slice(&bytes); + 4 + } else { + let bytes = ((0b11 << 62) | x).to_be_bytes(); + dst[..8].copy_from_slice(&bytes); + 8 + } +} diff --git a/src/ipc/webtransport.rs b/src/ipc/webtransport.rs index 6db16e5..370672b 100644 --- a/src/ipc/webtransport.rs +++ b/src/ipc/webtransport.rs @@ -1,60 +1,76 @@ //! IPC forwarding of WebTransport sessions via RPC + per-stream socketpairs. //! //! Parallel to [`super::quic`] which bridges QUIC connections over IPC, this -//! module bridges WebTransport sessions. Each WT stream is forwarded through a -//! dedicated Unix socketpair while control-plane RPCs (open, accept) travel -//! over the remoc channel. +//! module bridges WebTransport sessions. Each WebTransport stream is forwarded +//! through a dedicated Unix socketpair while control-plane RPCs (open, accept) +//! travel over the remoc channel with receiver-chosen FD transfer IDs. //! //! # Public API //! //! ## RTC trait //! -//! - [`IpcWtSession`] — session-level stream management +//! - [`IpcWebTransportSession`] — session-level stream management //! //! ## Server-side adapter //! -//! - [`WtSessionAdapter`] — wraps [`WebTransportSession`](crate::webtransport::WebTransportSession), -//! implements [`IpcWtSession`] +//! - [`WebTransportSessionAdapter`] — wraps [`WebTransportSession`](crate::webtransport::WebTransportSession), +//! implements [`IpcWebTransportSession`] //! //! ## Client-side wrapper //! -//! - [`IpcWtSessionHandle`] — wraps [`IpcWtSessionClient`], provides async +//! - [`IpcWebTransportSessionHandle`] — wraps [`IpcWebTransportSessionClient`], provides async //! stream management //! //! ## Bootstrap //! -//! - [`WtSessionBootstrap`] — one-shot value sent when a WT session is -//! established over IPC +//! - [`WebTransportSessionBootstrap`] — one-shot value sent when a WebTransport session +//! is established over IPC //! //! # FD semantics //! -//! - **`open_bi` / `accept_bi`**: 2 FDs are queued (reader pipe + writer pipe). -//! - **`open_uni`**: 1 FD (writer pipe, server side reads from pipe and sends -//! to the real WT uni stream). -//! - **`accept_uni`**: 1 FD (reader pipe, server side reads from the real WT -//! uni stream and writes to pipe). - -use std::{future::Future, sync::Arc}; +//! - **`open_bi` / `accept_bi`**: 2 FDs are delivered (read frame IO + write frame IO). +//! - **`open_uni`**: 1 FD (write frame IO, server side applies worker commands +//! to the real WebTransport uni stream). +//! - **`accept_uni`**: 1 FD (read frame IO, server side executes worker pulls +//! against the real WebTransport uni stream). + +use std::{ + future::Future, + sync::{Arc, Mutex}, +}; use serde::{Deserialize, Serialize}; use smallvec::smallvec; use snafu::Snafu; use tokio::net::UnixStream; +use tokio_util::task::AbortOnDropHandle; use tracing::{Instrument, debug}; +/// WebTransport-flavoured lifecycle helpers for IPC-backed session handles. +/// +/// IPC uses the same remoc control-plane and connection-error latch discipline +/// as RPC WebTransport handles, so this module reuses the RPC helper trait +/// directly while preserving the `ipc::webtransport::LifecycleExt` path. +pub use crate::rpc::webtransport::LifecycleExt; use crate::{ - codec::{BoxReadStream, BoxWriteStream}, + error::Code, ipc::{ quic::{ - IpcReadStream, IpcWriteStream, connection::{IPC_ERROR_KIND, IPC_FRAME_TYPE, bridge_reader, bridge_writer}, + stream::{reader as ipc_reader, writer as ipc_writer}, }, - transport::{FdRegistry, FdSender}, + transport::{FdDelivery, FdTransfer, ReceivedFds}, }, - quic::{self, ConnectionError, DynLifecycle, GetStreamIdExt}, - rpc::lifecycle::{ConnectionErrorLatch, HasLatch, LifecycleExt}, + quic::{ + self, BoxQuicStreamReader, BoxQuicStreamWriter, ConnectionError, DynLifecycle, + GetStreamIdExt, ResetStreamExt, StopStreamExt, + }, + rpc::lifecycle::{ConnectionErrorLatch, HasLatch, LifecycleExt as _}, varint::VarInt, - webtransport::{self, Closed, OpenStreamError, WtLifecycleExt}, + webtransport::{ + self, AcceptStreamError, CloseReason, CloseSession, CloseSessionError, DrainSessionError, + OpenStreamError, SessionClosed, SessionDrain, WebTransportSessionId, + }, }; // --------------------------------------------------------------------------- @@ -66,7 +82,7 @@ use crate::{ /// Captures the specific category of failure: RPC call errors preserve /// [`remoc::rtc::CallError`], QUIC stream errors preserve /// [`StreamError`](quic::StreamError), and OS-level I/O errors (socketpair, -/// FD conversion, queue) are serialized as strings. +/// FD conversion, delivery) are serialized as strings. #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone, Snafu)] #[snafu(module(ipc_plumbing_error), visibility(pub))] @@ -91,8 +107,8 @@ pub enum IpcPlumbingError { /// to the latch site via [`ConnectionErrorLatch::latch_with`]. #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone, Snafu)] -#[snafu(module(ipc_wt_open_error), visibility(pub))] -pub enum IpcWtOpenError { +#[snafu(module(ipc_webtransport_open_error), visibility(pub))] +pub enum IpcWebTransportOpenError { /// The underlying stream operation failed. #[snafu(transparent)] Stream { source: OpenStreamError }, @@ -104,35 +120,53 @@ pub enum IpcWtOpenError { /// Errors from IPC `accept_bi` / `accept_uni` operations. /// -/// Extends [`Closed`] with an IPC transport variant. +/// Extends [`SessionClosed`] with an IPC transport variant for FD-passing and +/// socketpair failures. Conversion to [`ConnectionError`] is deferred to the +/// latch site via [`ConnectionErrorLatch::latch_with`]. #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone, Snafu)] -#[snafu(module(ipc_wt_accept_error), visibility(pub))] -pub enum IpcWtAcceptError { +#[snafu(module(ipc_webtransport_accept_error), visibility(pub))] +pub enum IpcWebTransportAcceptError { /// The session has been closed. #[snafu(display("webtransport session closed"))] Closed, + /// The underlying WebTransport connection failed. + #[snafu(transparent)] + Connection { source: ConnectionError }, + /// IPC plumbing failure (RPC, stream, or I/O). #[snafu(transparent)] Transport { source: IpcPlumbingError }, } -impl From for IpcWtOpenError { +impl From for IpcWebTransportOpenError { fn from(error: remoc::rtc::CallError) -> Self { IpcPlumbingError::Rpc { source: error }.into() } } -impl From for IpcWtAcceptError { +impl From for IpcWebTransportAcceptError { fn from(error: remoc::rtc::CallError) -> Self { IpcPlumbingError::Rpc { source: error }.into() } } -impl From for IpcWtAcceptError { - fn from(_: Closed) -> Self { - IpcWtAcceptError::Closed +impl From for IpcWebTransportAcceptError { + fn from(_source: SessionClosed) -> Self { + IpcWebTransportAcceptError::Closed + } +} + +impl From for IpcWebTransportAcceptError { + fn from(error: AcceptStreamError) -> Self { + match error { + AcceptStreamError::Closed { source } => source.into(), + AcceptStreamError::Connection { source } => Self::Connection { source }, + AcceptStreamError::StreamId { source } => Self::Transport { + source: IpcPlumbingError::Stream { source }, + }, + } } } @@ -142,44 +176,52 @@ impl From for IpcWtAcceptError { /// Remoc RPC counterpart of [`WebTransportSession`] for IPC. /// -/// All stream-opening methods return `(VarInt, VarInt)` pairs: -/// - First element: FD-registry ID for [`FdRegistry::wait_fds`]. -/// - Second element: underlying QUIC stream ID. +/// All stream-opening methods receive a receiver-chosen FD transfer ID and +/// return the underlying QUIC stream ID after the FD delivery has been +/// queued to the local mux writer FIFO. /// -/// The `session_id` is not included — it is passed via [`WtSessionBootstrap`]. +/// The `session_id` is not included — it is passed via [`WebTransportSessionBootstrap`]. #[remoc::rtc::remote] -pub trait IpcWtSession: Send + Sync { - /// Open a bidirectional stream. Returns `(fd_id, stream_id)`. +pub trait IpcWebTransportSession: Send + Sync { + async fn drain(&self) -> Result<(), DrainSessionError>; + + async fn close(&self, close: CloseSession) -> Result<(), CloseSessionError>; + + async fn drained(&self) -> Result; + + async fn closed(&self) -> Result; + + /// Open a bidirectional stream. Returns the stream ID. /// The caller retrieves **2 FDs** from the registry. - async fn open_bi(&self) -> Result<(VarInt, VarInt), IpcWtOpenError>; + async fn open_bi(&self, fd_id: VarInt) -> Result; - /// Open a unidirectional (send-only) stream. Returns `(fd_id, stream_id)`. + /// Open a unidirectional (send-only) stream. Returns the stream ID. /// The caller retrieves **1 FD**. - async fn open_uni(&self) -> Result<(VarInt, VarInt), IpcWtOpenError>; + async fn open_uni(&self, fd_id: VarInt) -> Result; - /// Accept an incoming bidirectional stream. Returns `(fd_id, stream_id)`. + /// Accept an incoming bidirectional stream. Returns the stream ID. /// The caller retrieves **2 FDs**. - async fn accept_bi(&self) -> Result<(VarInt, VarInt), IpcWtAcceptError>; + async fn accept_bi(&self, fd_id: VarInt) -> Result; /// Accept an incoming unidirectional (receive-only) stream. - /// Returns `(fd_id, stream_id)`. The caller retrieves **1 FD**. - async fn accept_uni(&self) -> Result<(VarInt, VarInt), IpcWtAcceptError>; + /// Returns the stream ID. The caller retrieves **1 FD**. + async fn accept_uni(&self, fd_id: VarInt) -> Result; } // --------------------------------------------------------------------------- // Error helpers (server side) // --------------------------------------------------------------------------- -fn ipc_open_io(err: impl std::fmt::Display, context: &str) -> IpcWtOpenError { - debug!(error = %err, context, "ipc wt session i/o error"); +fn ipc_open_io(err: impl std::error::Error, context: &str) -> IpcWebTransportOpenError { + debug!(error = %snafu::Report::from_error(&err), context, "ipc webtransport session i/o error"); IpcPlumbingError::Io { message: format!("{context}: {err}"), } .into() } -fn ipc_accept_io(err: impl std::fmt::Display, context: &str) -> IpcWtAcceptError { - debug!(error = %err, context, "ipc wt session i/o error"); +fn ipc_accept_io(err: impl std::error::Error, context: &str) -> IpcWebTransportAcceptError { + debug!(error = %snafu::Report::from_error(&err), context, "ipc webtransport session i/o error"); IpcPlumbingError::Io { message: format!("{context}: {err}"), } @@ -202,207 +244,262 @@ fn ipc_connection_error(error: &IpcPlumbingError) -> ConnectionError { } } +fn ipc_accept_error_connection(error: IpcWebTransportAcceptError) -> Option { + match error { + IpcWebTransportAcceptError::Closed => None, + IpcWebTransportAcceptError::Connection { source } => Some(source), + IpcWebTransportAcceptError::Transport { source } => Some(ipc_connection_error(&source)), + } +} + // --------------------------------------------------------------------------- // Bootstrap // --------------------------------------------------------------------------- /// One-shot bootstrap payload for a WebTransport session over IPC. /// -/// Sent over the remoc channel when a WT session is established across the -/// process boundary. +/// Sent over the remoc channel when a WebTransport session is established +/// across the process boundary. #[derive(Serialize, Deserialize)] -pub struct WtSessionBootstrap { +pub struct WebTransportSessionBootstrap { /// The session ID (immutable, no RPC needed). - pub session_id: VarInt, - /// RPC client for WT session stream operations. - pub session: IpcWtSessionClient, + pub session_id: WebTransportSessionId, + /// RPC client for WebTransport session stream operations. + pub session: IpcWebTransportSessionClient, } // --------------------------------------------------------------------------- -// Server side: WtSessionAdapter +// Server side: WebTransportSessionAdapter // --------------------------------------------------------------------------- /// Server-side adapter that wraps a real [`WebTransportSession`] and -/// implements [`IpcWtSession`]. +/// implements [`IpcWebTransportSession`]. /// /// Each stream-opening call: /// 1. Delegates to the inner session to get real boxed streams. /// 2. Creates Unix socketpairs. -/// 3. Spawns bridge tasks forwarding data between real streams and pipes. -/// 4. Queues client-side FDs through [`FdSender`]. -/// 5. Returns `(fd_registry_id, stream_id)` over RPC. -pub struct WtSessionAdapter { +/// 3. Spawns bridge tasks executing typed stream-frame IPC against real streams. +/// 4. Delivers client-side FDs through [`FdTransfer`]. +/// 5. Returns the stream ID over RPC after FD delivery is queued. +pub struct WebTransportSessionAdapter { session: Arc, - fd_sender: FdSender, - lifecycle: Arc, + fd_transfer: FdTransfer, + _lifecycle: Arc, + tasks: Mutex>>, } -impl WtSessionAdapter { +impl WebTransportSessionAdapter { pub fn new( session: Arc, - fd_sender: FdSender, + fd_transfer: FdTransfer, lifecycle: Arc, ) -> Self { Self { session, - fd_sender, - lifecycle, + fd_transfer, + _lifecycle: lifecycle, + tasks: Mutex::new(Vec::new()), } } + + fn spawn_task(&self, task: impl Future + Send + 'static) { + let handle = AbortOnDropHandle::new(tokio::spawn(task.in_current_span())); + let mut tasks = self + .tasks + .lock() + .expect("webtransport adapter task registry should not be poisoned"); + tasks.retain(|task| !task.is_finished()); + tasks.push(handle); + } } -impl IpcWtSession for WtSessionAdapter { - async fn open_bi(&self) -> Result<(VarInt, VarInt), IpcWtOpenError> { +impl IpcWebTransportSession for WebTransportSessionAdapter { + async fn drain(&self) -> Result<(), DrainSessionError> { + self.session.drain().await + } + + async fn close(&self, close: CloseSession) -> Result<(), CloseSessionError> { + self.session.close(close).await + } + + async fn drained(&self) -> Result { + Ok(self.session.drained().await) + } + + async fn closed(&self) -> Result { + Ok(self.session.closed().await) + } + + async fn open_bi(&self, fd_id: VarInt) -> Result { + let delivery = self.fd_transfer.delivery(fd_id); let (mut reader, writer) = self.session.open_bi().await?; let stream_id = GetStreamIdExt::stream_id(&mut reader) .await .map_err(IpcPlumbingError::from)?; - self.bridge_bi(reader, writer, stream_id) + self.bridge_bi( + delivery, + Box::pin(reader) as BoxQuicStreamReader, + Box::pin(writer) as BoxQuicStreamWriter, + stream_id, + ) + .await } - async fn accept_bi(&self) -> Result<(VarInt, VarInt), IpcWtAcceptError> { - let (mut reader, writer) = self.session.accept_bi().await?; + async fn accept_bi(&self, fd_id: VarInt) -> Result { + let delivery = self.fd_transfer.delivery(fd_id); + let (mut reader, writer) = match self.session.accept_bi().await { + Ok(streams) => streams, + Err(error) => return Err(error.into()), + }; let stream_id = GetStreamIdExt::stream_id(&mut reader) .await .map_err(IpcPlumbingError::from)?; - self.bridge_bi_accept(reader, writer, stream_id) + self.bridge_bi_accept( + delivery, + Box::pin(reader) as BoxQuicStreamReader, + Box::pin(writer) as BoxQuicStreamWriter, + stream_id, + ) + .await } - async fn open_uni(&self) -> Result<(VarInt, VarInt), IpcWtOpenError> { + async fn open_uni(&self, fd_id: VarInt) -> Result { + let delivery = self.fd_transfer.delivery(fd_id); let mut writer = self.session.open_uni().await?; let stream_id = GetStreamIdExt::stream_id(&mut writer) .await .map_err(IpcPlumbingError::from)?; - let lifecycle: Arc = self.lifecycle.clone(); - let (srv, cli) = UnixStream::pair().map_err(|e| ipc_open_io(e, "socketpair"))?; let cli_std = cli.into_std().map_err(|e| ipc_open_io(e, "into_std"))?; - let fd_id = self - .fd_sender - .queue_fds(smallvec![cli_std.into()]) - .map_err(|e| ipc_open_io(e, "queue_fds"))?; + if let Err(error) = delivery.deliver(smallvec![cli_std.into()]).await { + let _ = writer.reset(Code::H3_REQUEST_CANCELLED.into_inner()).await; + return Err(ipc_open_io(error, "deliver fds")); + } - // IpcReadStream on srv → real WT write stream - let pipe_r = IpcReadStream::new(stream_id, srv, lifecycle); - tokio::spawn(bridge_writer(pipe_r, writer).in_current_span()); + // Write frame IO on srv ↔ real WebTransport write stream. + self.spawn_task(bridge_writer(srv, writer)); - Ok((fd_id, stream_id)) + Ok(stream_id) } - async fn accept_uni(&self) -> Result<(VarInt, VarInt), IpcWtAcceptError> { - let mut reader = self.session.accept_uni().await?; + async fn accept_uni(&self, fd_id: VarInt) -> Result { + let delivery = self.fd_transfer.delivery(fd_id); + let mut reader = match self.session.accept_uni().await { + Ok(stream) => stream, + Err(error) => return Err(error.into()), + }; let stream_id = GetStreamIdExt::stream_id(&mut reader) .await .map_err(IpcPlumbingError::from)?; - let lifecycle: Arc = self.lifecycle.clone(); - let (srv, cli) = UnixStream::pair().map_err(|e| ipc_accept_io(e, "socketpair"))?; let cli_std = cli.into_std().map_err(|e| ipc_accept_io(e, "into_std"))?; - let fd_id = self - .fd_sender - .queue_fds(smallvec![cli_std.into()]) - .map_err(|e| ipc_accept_io(e, "queue_fds"))?; + if let Err(error) = delivery.deliver(smallvec![cli_std.into()]).await { + let _ = reader.stop(Code::H3_REQUEST_CANCELLED.into_inner()).await; + return Err(ipc_accept_io(error, "deliver fds")); + } - // Real WT read stream → IpcWriteStream on srv - let pipe_w = IpcWriteStream::new(stream_id, srv, lifecycle); - tokio::spawn(bridge_reader(reader, pipe_w).in_current_span()); + // Real WebTransport read stream ↔ read frame IO on srv. + self.spawn_task(bridge_reader(reader, srv)); - Ok((fd_id, stream_id)) + Ok(stream_id) } } -impl WtSessionAdapter { +impl WebTransportSessionAdapter { /// Shared logic for open_bi: create 2 socketpairs and bridge. - fn bridge_bi( + async fn bridge_bi( &self, - reader: BoxReadStream, - writer: BoxWriteStream, + delivery: FdDelivery, + mut reader: BoxQuicStreamReader, + mut writer: BoxQuicStreamWriter, stream_id: VarInt, - ) -> Result<(VarInt, VarInt), IpcWtOpenError> { - let lifecycle: Arc = self.lifecycle.clone(); - + ) -> Result { let (srv_a, cli_a) = UnixStream::pair().map_err(|e| ipc_open_io(e, "socketpair"))?; let (srv_b, cli_b) = UnixStream::pair().map_err(|e| ipc_open_io(e, "socketpair"))?; let cli_a_std = cli_a.into_std().map_err(|e| ipc_open_io(e, "into_std"))?; let cli_b_std = cli_b.into_std().map_err(|e| ipc_open_io(e, "into_std"))?; - let fd_id = self - .fd_sender - .queue_fds(smallvec![cli_a_std.into(), cli_b_std.into()]) - .map_err(|e| ipc_open_io(e, "queue_fds"))?; + if let Err(error) = delivery + .deliver(smallvec![cli_a_std.into(), cli_b_std.into()]) + .await + { + let code = Code::H3_REQUEST_CANCELLED.into_inner(); + let _ = reader.stop(code).await; + let _ = writer.reset(code).await; + return Err(ipc_open_io(error, "deliver fds")); + } - // Bridge reader direction: real WT reader → IpcWriteStream on srv_a - let pipe_w = IpcWriteStream::new(stream_id, srv_a, lifecycle.clone()); - tokio::spawn(bridge_reader(reader, pipe_w).in_current_span()); + // Bridge reader direction: real WebTransport reader ↔ read frame IO on srv_a. + self.spawn_task(bridge_reader(reader, srv_a)); - // Bridge writer direction: IpcReadStream on srv_b → real WT writer - let pipe_r = IpcReadStream::new(stream_id, srv_b, lifecycle); - tokio::spawn(bridge_writer(pipe_r, writer).in_current_span()); + // Bridge writer direction: write frame IO on srv_b ↔ real WebTransport writer. + self.spawn_task(bridge_writer(srv_b, writer)); - Ok((fd_id, stream_id)) + Ok(stream_id) } /// Shared logic for accept_bi: create 2 socketpairs and bridge. - fn bridge_bi_accept( + async fn bridge_bi_accept( &self, - reader: BoxReadStream, - writer: BoxWriteStream, + delivery: FdDelivery, + mut reader: BoxQuicStreamReader, + mut writer: BoxQuicStreamWriter, stream_id: VarInt, - ) -> Result<(VarInt, VarInt), IpcWtAcceptError> { - let lifecycle: Arc = self.lifecycle.clone(); - + ) -> Result { let (srv_a, cli_a) = UnixStream::pair().map_err(|e| ipc_accept_io(e, "socketpair"))?; let (srv_b, cli_b) = UnixStream::pair().map_err(|e| ipc_accept_io(e, "socketpair"))?; let cli_a_std = cli_a.into_std().map_err(|e| ipc_accept_io(e, "into_std"))?; let cli_b_std = cli_b.into_std().map_err(|e| ipc_accept_io(e, "into_std"))?; - let fd_id = self - .fd_sender - .queue_fds(smallvec![cli_a_std.into(), cli_b_std.into()]) - .map_err(|e| ipc_accept_io(e, "queue_fds"))?; + if let Err(error) = delivery + .deliver(smallvec![cli_a_std.into(), cli_b_std.into()]) + .await + { + let code = Code::H3_REQUEST_CANCELLED.into_inner(); + let _ = reader.stop(code).await; + let _ = writer.reset(code).await; + return Err(ipc_accept_io(error, "deliver fds")); + } - let pipe_w = IpcWriteStream::new(stream_id, srv_a, lifecycle.clone()); - tokio::spawn(bridge_reader(reader, pipe_w).in_current_span()); + self.spawn_task(bridge_reader(reader, srv_a)); - let pipe_r = IpcReadStream::new(stream_id, srv_b, lifecycle); - tokio::spawn(bridge_writer(pipe_r, writer).in_current_span()); + self.spawn_task(bridge_writer(srv_b, writer)); - Ok((fd_id, stream_id)) + Ok(stream_id) } } // --------------------------------------------------------------------------- -// IpcWtLifecycle: session-level lifecycle for IPC WT streams +// IpcWebTransportLifecycle: session-level lifecycle for IPC WebTransport streams // --------------------------------------------------------------------------- -/// Internal lifecycle state for [`IpcWtSessionHandle`]. +/// Internal lifecycle state for [`IpcWebTransportSessionHandle`]. /// /// Delegates to the parent connection's lifecycle while maintaining a local /// [`ConnectionErrorLatch`]. Implements [`quic::Lifecycle`] so it can be -/// shared with [`IpcReadStream`] / [`IpcWriteStream`] as -/// `Arc`. -struct IpcWtLifecycle { +/// shared with direct IPC stream-frame bridge handles. +struct IpcWebTransportLifecycle { parent: Arc, latch: ConnectionErrorLatch, } -impl HasLatch for IpcWtLifecycle { +impl HasLatch for IpcWebTransportLifecycle { fn latch(&self) -> &ConnectionErrorLatch { &self.latch } } -impl quic::Lifecycle for IpcWtLifecycle { +impl quic::Lifecycle for IpcWebTransportLifecycle { fn close(&self, code: crate::error::Code, reason: std::borrow::Cow<'static, str>) { DynLifecycle::close(self.parent.as_ref(), code, reason); } @@ -418,38 +515,38 @@ impl quic::Lifecycle for IpcWtLifecycle { } // --------------------------------------------------------------------------- -// Client side: IpcWtSessionHandle +// Client side: IpcWebTransportSessionHandle // --------------------------------------------------------------------------- /// Client-side WebTransport session handle that wraps an -/// [`IpcWtSessionClient`] and provides async stream management via FD passing. +/// [`IpcWebTransportSessionClient`] and provides async stream management via FD passing. /// -/// Created by the client after receiving a [`WtSessionBootstrap`] over the +/// Created by the client after receiving a [`WebTransportSessionBootstrap`] over the /// remoc channel. -pub struct IpcWtSessionHandle { - session_id: VarInt, - rpc: IpcWtSessionClient, - fd_registry: FdRegistry, - lifecycle: Arc, +pub struct IpcWebTransportSessionHandle { + session_id: WebTransportSessionId, + rpc: IpcWebTransportSessionClient, + fd_transfer: FdTransfer, + lifecycle: Arc, } -impl IpcWtSessionHandle { +impl IpcWebTransportSessionHandle { /// Create a new handle from bootstrap data, FD registry, and the parent /// connection's lifecycle. pub fn new( - session_id: VarInt, - rpc: IpcWtSessionClient, - fd_registry: FdRegistry, + session_id: WebTransportSessionId, + rpc: IpcWebTransportSessionClient, + fd_transfer: FdTransfer, conn_lifecycle: Arc, ) -> Self { - let lifecycle = Arc::new(IpcWtLifecycle { + let lifecycle = Arc::new(IpcWebTransportLifecycle { parent: conn_lifecycle, latch: ConnectionErrorLatch::new(), }); Self { session_id, rpc, - fd_registry, + fd_transfer, lifecycle, } } @@ -458,12 +555,12 @@ impl IpcWtSessionHandle { /// any resulting error. async fn guard_ipc_open( &self, - fut: impl Future>, + fut: impl Future>, ) -> Result { self.lifecycle .guard_open_with(fut, |e| match e { - IpcWtOpenError::Stream { source } => source, - IpcWtOpenError::Transport { source } => OpenStreamError::Open { + IpcWebTransportOpenError::Stream { source } => source, + IpcWebTransportOpenError::Transport { source } => OpenStreamError::Open { source: ipc_connection_error(&source), }, }) @@ -474,20 +571,17 @@ impl IpcWtSessionHandle { /// any resulting transport error. async fn guard_ipc_accept( &self, - fut: impl Future>, - ) -> Result { + fut: impl Future>, + ) -> Result { self.lifecycle - .guard_accept_err(fut, |e| match e { - IpcWtAcceptError::Closed => None, - IpcWtAcceptError::Transport { source } => Some(ipc_connection_error(&source)), - }) + .guard_accept_err(fut, ipc_accept_error_connection) .await } /// Log and lazily latch an IPC plumbing error for open operations. fn latch_open_transport(&self, err: impl std::error::Error, context: &str) -> OpenStreamError { - debug!(error = %snafu::Report::from_error(&err), context, "ipc wt session error"); - let message = format!("ipc wt: {context}: {err}"); + debug!(error = %snafu::Report::from_error(&err), context, "ipc webtransport session error"); + let message = format!("ipc webtransport: {context}: {err}"); let source = self.lifecycle.latch().latch_with(|| { let plumbing = IpcPlumbingError::Io { message }; ipc_connection_error(&plumbing) @@ -496,88 +590,193 @@ impl IpcWtSessionHandle { } /// Log and lazily latch an IPC plumbing error for accept operations. - fn latch_accept_transport(&self, err: impl std::error::Error, context: &str) -> Closed { - debug!(error = %snafu::Report::from_error(&err), context, "ipc wt session error"); - let message = format!("ipc wt: {context}: {err}"); - let _ = self.lifecycle.latch().latch_with(|| { + fn latch_accept_transport( + &self, + err: impl std::error::Error, + context: &str, + ) -> AcceptStreamError { + debug!(error = %snafu::Report::from_error(&err), context, "ipc webtransport session error"); + let message = format!("ipc webtransport: {context}: {err}"); + let source = self.lifecycle.latch().latch_with(|| { let plumbing = IpcPlumbingError::Io { message }; ipc_connection_error(&plumbing) }); - Closed + AcceptStreamError::Connection { source } + } + + fn fd_to_open_unix_stream( + &self, + fd: std::os::fd::OwnedFd, + ) -> Result { + let std_stream = std::os::unix::net::UnixStream::from(fd); + std_stream + .set_nonblocking(true) + .map_err(|e| self.latch_open_transport(e, "set_nonblocking"))?; + UnixStream::from_std(std_stream) + .map_err(|e| self.latch_open_transport(e, "unix stream from_std")) + } + + fn fd_to_accept_unix_stream( + &self, + fd: std::os::fd::OwnedFd, + ) -> Result { + let std_stream = std::os::unix::net::UnixStream::from(fd); + std_stream + .set_nonblocking(true) + .map_err(|e| self.latch_accept_transport(e, "set_nonblocking"))?; + UnixStream::from_std(std_stream) + .map_err(|e| self.latch_accept_transport(e, "unix stream from_std")) + } + + async fn resolve_open_with_fds( + &self, + receiver: crate::ipc::transport::FdReceiver, + rpc: impl Future>, + ) -> Result<(VarInt, ReceivedFds), IpcWebTransportOpenError> { + let receive = receiver.into_future(); + tokio::pin!(receive); + tokio::pin!(rpc); + + tokio::select! { + biased; + receive_result = &mut receive => { + let received = receive_result.map_err(|e| ipc_open_io(e, "receive fds"))?; + let stream_id = rpc.await?; + Ok((stream_id, received)) + } + rpc_result = &mut rpc => { + let stream_id = rpc_result?; + let received = receive.await.map_err(|e| ipc_open_io(e, "receive fds"))?; + Ok((stream_id, received)) + } + } + } + + async fn resolve_accept_with_fds( + &self, + receiver: crate::ipc::transport::FdReceiver, + rpc: impl Future>, + ) -> Result<(VarInt, ReceivedFds), IpcWebTransportAcceptError> { + let receive = receiver.into_future(); + tokio::pin!(receive); + tokio::pin!(rpc); + + tokio::select! { + biased; + receive_result = &mut receive => { + let received = receive_result.map_err(|e| ipc_accept_io(e, "receive fds"))?; + let stream_id = rpc.await?; + Ok((stream_id, received)) + } + rpc_result = &mut rpc => { + let stream_id = rpc_result?; + let received = receive.await.map_err(|e| ipc_accept_io(e, "receive fds"))?; + Ok((stream_id, received)) + } + } } } -impl webtransport::Session for IpcWtSessionHandle { - type StreamReader = IpcReadStream; - type StreamWriter = IpcWriteStream; +impl webtransport::Session for IpcWebTransportSessionHandle { + type StreamReader = BoxQuicStreamReader; + type StreamWriter = BoxQuicStreamWriter; - fn session_id(&self) -> VarInt { + fn id(&self) -> WebTransportSessionId { self.session_id } - async fn open_bi(&self) -> Result<(IpcReadStream, IpcWriteStream), OpenStreamError> { - let (fd_id, stream_id) = self - .guard_ipc_open(IpcWtSession::open_bi(&self.rpc)) + async fn drain(&self) -> Result<(), DrainSessionError> { + IpcWebTransportSession::drain(&self.rpc).await + } + + async fn close(&self, close: CloseSession) -> Result<(), CloseSessionError> { + IpcWebTransportSession::close(&self.rpc, close).await + } + + async fn drained(&self) -> SessionDrain { + match IpcWebTransportSession::drained(&self.rpc).await { + Ok(drain) => drain, + Err(reason) => SessionDrain::Closed(reason), + } + } + + async fn closed(&self) -> CloseReason { + match IpcWebTransportSession::closed(&self.rpc).await { + Ok(reason) | Err(reason) => reason, + } + } + + async fn open_bi(&self) -> Result<(BoxQuicStreamReader, BoxQuicStreamWriter), OpenStreamError> { + let receiver = self.fd_transfer.receive(); + let fd_id = receiver.id(); + let (stream_id, received) = + self.guard_ipc_open(self.resolve_open_with_fds( + receiver, + IpcWebTransportSession::open_bi(&self.rpc, fd_id), + )) .await?; - self.fds_to_bi(fd_id, stream_id).await + self.fds_to_bi(stream_id, received).await } - async fn accept_bi(&self) -> Result<(IpcReadStream, IpcWriteStream), Closed> { - let (fd_id, stream_id) = self - .guard_ipc_accept(IpcWtSession::accept_bi(&self.rpc)) + async fn accept_bi( + &self, + ) -> Result<(BoxQuicStreamReader, BoxQuicStreamWriter), AcceptStreamError> { + let receiver = self.fd_transfer.receive(); + let fd_id = receiver.id(); + let (stream_id, received) = self + .guard_ipc_accept(self.resolve_accept_with_fds( + receiver, + IpcWebTransportSession::accept_bi(&self.rpc, fd_id), + )) .await?; - self.fds_to_bi_accept(fd_id, stream_id).await + self.fds_to_bi_accept(stream_id, received).await } - async fn open_uni(&self) -> Result { - let (fd_id, stream_id) = self - .guard_ipc_open(IpcWtSession::open_uni(&self.rpc)) + async fn open_uni(&self) -> Result { + let receiver = self.fd_transfer.receive(); + let fd_id = receiver.id(); + let (stream_id, received) = self + .guard_ipc_open(self.resolve_open_with_fds( + receiver, + IpcWebTransportSession::open_uni(&self.rpc, fd_id), + )) .await?; - self.fds_to_uni_writer(fd_id, stream_id).await + self.fds_to_uni_writer(stream_id, received).await } - async fn accept_uni(&self) -> Result { - let (fd_id, stream_id) = self - .guard_ipc_accept(IpcWtSession::accept_uni(&self.rpc)) + async fn accept_uni(&self) -> Result { + let receiver = self.fd_transfer.receive(); + let fd_id = receiver.id(); + let (stream_id, received) = self + .guard_ipc_accept(self.resolve_accept_with_fds( + receiver, + IpcWebTransportSession::accept_uni(&self.rpc, fd_id), + )) .await?; - self.fds_to_uni_reader(fd_id, stream_id).await + self.fds_to_uni_reader(stream_id, received).await } } -impl IpcWtSessionHandle { - /// Retrieve 2 FDs and construct a (IpcReadStream, IpcWriteStream) pair. +impl IpcWebTransportSessionHandle { + /// Retrieve 2 FDs and construct boxed IPC stream-frame bridge handles. async fn fds_to_bi( &self, - fd_id: VarInt, stream_id: VarInt, - ) -> Result<(IpcReadStream, IpcWriteStream), OpenStreamError> { - let fds = self - .fd_registry - .wait_fds(fd_id) - .await - .map_err(|e| self.latch_open_transport(e, "wait_fds"))?; - if fds.len() != 2 { - return Err(self.latch_open_transport( - FdCountError { - expected: 2, - got: fds.len(), - }, - "fd count", - )); - } - let mut fds = fds.into_iter(); - let fd_a = fds.next().unwrap(); - let fd_b = fds.next().unwrap(); + received: ReceivedFds, + ) -> Result<(BoxQuicStreamReader, BoxQuicStreamWriter), OpenStreamError> { + let (fd_a, fd_b) = received + .into_pair() + .map_err(|e| self.latch_open_transport(e, "fd count"))?; - let lifecycle: Arc = self.lifecycle.clone(); + let lifecycle = self.lifecycle.clone(); - let sock_a = UnixStream::from_std(std::os::unix::net::UnixStream::from(fd_a)) - .map_err(|e| self.latch_open_transport(e, "UnixStream::from_std"))?; - let reader = IpcReadStream::new(stream_id, sock_a, lifecycle.clone()); + let sock_a = self.fd_to_open_unix_stream(fd_a)?; + let reader = Box::pin(ipc_reader::reader(stream_id, sock_a, lifecycle.clone())) + as BoxQuicStreamReader; - let sock_b = UnixStream::from_std(std::os::unix::net::UnixStream::from(fd_b)) - .map_err(|e| self.latch_open_transport(e, "UnixStream::from_std"))?; - let writer = IpcWriteStream::new(stream_id, sock_b, lifecycle); + let sock_b = self.fd_to_open_unix_stream(fd_b)?; + let writer = + Box::pin(ipc_writer::writer(stream_id, sock_b, lifecycle)) as BoxQuicStreamWriter; Ok((reader, writer)) } @@ -585,118 +784,337 @@ impl IpcWtSessionHandle { /// Retrieve 2 FDs for accept_bi. async fn fds_to_bi_accept( &self, - fd_id: VarInt, stream_id: VarInt, - ) -> Result<(IpcReadStream, IpcWriteStream), Closed> { - let fds = self - .fd_registry - .wait_fds(fd_id) - .await - .map_err(|e| self.latch_accept_transport(e, "wait_fds"))?; - if fds.len() != 2 { - return Err(self.latch_accept_transport( - FdCountError { - expected: 2, - got: fds.len(), - }, - "fd count", - )); - } - let mut fds = fds.into_iter(); - let fd_a = fds.next().unwrap(); - let fd_b = fds.next().unwrap(); + received: ReceivedFds, + ) -> Result<(BoxQuicStreamReader, BoxQuicStreamWriter), AcceptStreamError> { + let (fd_a, fd_b) = received + .into_pair() + .map_err(|e| self.latch_accept_transport(e, "fd count"))?; - let lifecycle: Arc = self.lifecycle.clone(); + let lifecycle = self.lifecycle.clone(); - let sock_a = UnixStream::from_std(std::os::unix::net::UnixStream::from(fd_a)) - .map_err(|e| self.latch_accept_transport(e, "UnixStream::from_std"))?; - let reader = IpcReadStream::new(stream_id, sock_a, lifecycle.clone()); + let sock_a = self.fd_to_accept_unix_stream(fd_a)?; + let reader = Box::pin(ipc_reader::reader(stream_id, sock_a, lifecycle.clone())) + as BoxQuicStreamReader; - let sock_b = UnixStream::from_std(std::os::unix::net::UnixStream::from(fd_b)) - .map_err(|e| self.latch_accept_transport(e, "UnixStream::from_std"))?; - let writer = IpcWriteStream::new(stream_id, sock_b, lifecycle); + let sock_b = self.fd_to_accept_unix_stream(fd_b)?; + let writer = + Box::pin(ipc_writer::writer(stream_id, sock_b, lifecycle)) as BoxQuicStreamWriter; Ok((reader, writer)) } - /// Retrieve 1 FD and construct a IpcWriteStream (for open_uni). + /// Retrieve 1 FD and construct a boxed IPC write bridge (for open_uni). async fn fds_to_uni_writer( &self, - fd_id: VarInt, stream_id: VarInt, - ) -> Result { - let fds = self - .fd_registry - .wait_fds(fd_id) - .await - .map_err(|e| self.latch_open_transport(e, "wait_fds"))?; - if fds.len() != 1 { - return Err(self.latch_open_transport( - FdCountError { - expected: 1, - got: fds.len(), - }, - "fd count", - )); - } - let fd = fds.into_iter().next().unwrap(); - let lifecycle: Arc = self.lifecycle.clone(); - let sock = UnixStream::from_std(std::os::unix::net::UnixStream::from(fd)) - .map_err(|e| self.latch_open_transport(e, "UnixStream::from_std"))?; - Ok(IpcWriteStream::new(stream_id, sock, lifecycle)) + received: ReceivedFds, + ) -> Result { + let fd = received + .into_one() + .map_err(|e| self.latch_open_transport(e, "fd count"))?; + let lifecycle = self.lifecycle.clone(); + let sock = self.fd_to_open_unix_stream(fd)?; + Ok(Box::pin(ipc_writer::writer(stream_id, sock, lifecycle)) as BoxQuicStreamWriter) } - /// Retrieve 1 FD and construct a IpcReadStream (for accept_uni). + /// Retrieve 1 FD and construct a boxed IPC read bridge (for accept_uni). async fn fds_to_uni_reader( &self, - fd_id: VarInt, stream_id: VarInt, - ) -> Result { - let fds = self - .fd_registry - .wait_fds(fd_id) - .await - .map_err(|e| self.latch_accept_transport(e, "wait_fds"))?; - if fds.len() != 1 { - return Err(self.latch_accept_transport( - FdCountError { - expected: 1, - got: fds.len(), - }, - "fd count", - )); - } - let fd = fds.into_iter().next().unwrap(); - let lifecycle: Arc = self.lifecycle.clone(); - let sock = UnixStream::from_std(std::os::unix::net::UnixStream::from(fd)) - .map_err(|e| self.latch_accept_transport(e, "UnixStream::from_std"))?; - Ok(IpcReadStream::new(stream_id, sock, lifecycle)) + received: ReceivedFds, + ) -> Result { + let fd = received + .into_one() + .map_err(|e| self.latch_accept_transport(e, "fd count"))?; + let lifecycle = self.lifecycle.clone(); + let sock = self.fd_to_accept_unix_stream(fd)?; + Ok(Box::pin(ipc_reader::reader(stream_id, sock, lifecycle)) as BoxQuicStreamReader) } } -/// Helper error type for FD count mismatches. -#[derive(Debug)] -struct FdCountError { - expected: usize, - got: usize, +impl IpcWebTransportSessionClient { + /// Convert into an [`IpcWebTransportSessionHandle`]. + pub fn into_handle( + self, + session_id: WebTransportSessionId, + fd_transfer: FdTransfer, + conn_lifecycle: Arc, + ) -> IpcWebTransportSessionHandle { + IpcWebTransportSessionHandle::new(session_id, self, fd_transfer, conn_lifecycle) + } } -impl std::fmt::Display for FdCountError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "expected {} fds, got {}", self.expected, self.got) +#[cfg(test)] +mod tests { + use std::future; + + use remoc::prelude::ServerShared; + + use super::*; + use crate::{ + connection::{ConnectionState, tests::MockConnection}, + dhttp::{ + message::{MessageWriter, test::read_stream_for_test}, + protocol::DHttpProtocol, + settings::Settings, + webtransport::settings::{ + EnableWebTransport, InitialMaxData, InitialMaxStreamsBidi, InitialMaxStreamsUni, + }, + }, + extended_connect::{EstablishedConnect, PendingWriteStreamError}, + protocol::Protocols, + qpack::field::Protocol, + stream_id::StreamId, + webtransport::{ + CloseReason, CloseSession, DrainReason, SessionCloseReason, SessionDrain, + SessionDrainReason, WEBTRANSPORT_H3, WebTransportProtocol, + }, + }; + + fn connection_error(reason: &'static str) -> ConnectionError { + ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(0x01), + frame_type: VarInt::from_u32(0x00), + reason: reason.into(), + }, + } } -} -impl std::error::Error for FdCountError {} + fn assert_transport_reason(error: &ConnectionError, expected: &str) { + let ConnectionError::Transport { source } = error else { + panic!("expected transport error"); + }; + assert_eq!(source.reason.as_ref(), expected); + } -impl IpcWtSessionClient { - /// Convert into an [`IpcWtSessionHandle`]. - pub fn into_handle( - self, - session_id: VarInt, - fd_registry: FdRegistry, - conn_lifecycle: Arc, - ) -> IpcWtSessionHandle { - IpcWtSessionHandle::new(session_id, self, fd_registry, conn_lifecycle) + fn connection_with_webtransport( + mock: Arc, + ) -> Arc> { + let erased: Arc = mock.clone(); + let mut protocols = Protocols::new(); + let dhttp = DHttpProtocol::new_for_test(erased.clone()); + dhttp + .state + .peer_settings + .set(Arc::new(enabled_webtransport_settings())) + .expect("peer settings should be set once"); + protocols.insert(dhttp); + protocols.insert(WebTransportProtocol::new_for_test(erased)); + Arc::new(ConnectionState::new_for_test(mock, Arc::new(protocols)).erase()) + } + + fn enabled_webtransport_settings() -> Settings { + let mut settings = Settings::default(); + settings.set(EnableWebTransport::setting(true)); + settings.set(InitialMaxStreamsBidi::setting(VarInt::from_u32(16))); + settings.set(InitialMaxStreamsUni::setting(VarInt::from_u32(16))); + settings.set(InitialMaxData::setting(VarInt::MAX)); + settings + } + + fn webtransport_session_for_test( + mock: Arc, + stream_id: StreamId, + ) -> Arc { + let connection = connection_with_webtransport(mock); + let session = webtransport::WebTransportSession::try_from(EstablishedConnect::pending( + stream_id, + Some(Protocol::new(WEBTRANSPORT_H3)), + connection.clone(), + read_stream_for_test(stream_id.0), + future::pending::>(), + )) + .expect("webtransport session should be registered"); + Arc::new(session) + } + + fn spawn_ipc_session( + session: Arc, + ) -> (AbortOnDropHandle<()>, IpcWebTransportSessionClient) + where + S: IpcWebTransportSession + 'static, + { + let (server, client) = IpcWebTransportSessionServerShared::new(session, 1); + let task = AbortOnDropHandle::new(tokio::spawn( + async move { + let _ = server.serve(true).await; + } + .in_current_span(), + )); + (task, client) + } + + struct TestIpcSession { + drained: Mutex, + closed: Mutex>, + } + + impl TestIpcSession { + fn new() -> Self { + Self { + drained: Mutex::new(false), + closed: Mutex::new(None), + } + } + } + + impl IpcWebTransportSession for TestIpcSession { + async fn drain(&self) -> Result<(), DrainSessionError> { + *self + .drained + .lock() + .expect("drained mutex should not poison") = true; + Ok(()) + } + + async fn close(&self, close: CloseSession) -> Result<(), CloseSessionError> { + *self.closed.lock().expect("closed mutex should not poison") = Some(close); + Ok(()) + } + + async fn drained(&self) -> Result { + if *self + .drained + .lock() + .expect("drained mutex should not poison") + { + Ok(SessionDrain::Requested(DrainReason::Session( + SessionDrainReason::Local, + ))) + } else { + Ok(SessionDrain::Closed(CloseReason::Session( + SessionCloseReason::ControlStreamError, + ))) + } + } + + async fn closed(&self) -> Result { + match self + .closed + .lock() + .expect("closed mutex should not poison") + .clone() + { + Some(close) => Ok(CloseReason::Session(SessionCloseReason::Local(close))), + None => Ok(CloseReason::Session(SessionCloseReason::ControlStreamError)), + } + } + + async fn open_bi(&self, _fd_id: VarInt) -> Result { + Err(IpcPlumbingError::Io { + message: "test ipc session does not open bidi streams".into(), + } + .into()) + } + + async fn open_uni(&self, _fd_id: VarInt) -> Result { + Err(IpcPlumbingError::Io { + message: "test ipc session does not open uni streams".into(), + } + .into()) + } + + async fn accept_bi(&self, _fd_id: VarInt) -> Result { + Err(IpcWebTransportAcceptError::Closed) + } + + async fn accept_uni(&self, _fd_id: VarInt) -> Result { + Err(IpcWebTransportAcceptError::Closed) + } + } + + fn fd_transfer_for_test() -> crate::ipc::transport::FdTransfer { + let (mux, _peer) = + crate::ipc::transport::MuxChannel::pair_for_test().expect("mux channel pair"); + let (sink, stream) = mux.split().expect("split mux channel"); + stream.fd_transfer(sink.fd_sender()) + } + + fn webtransport_adapter(reason: &'static str) -> WebTransportSessionAdapter { + let mock = Arc::new(MockConnection::new()); + let session = + webtransport_session_for_test(mock.clone(), StreamId::from(VarInt::from_u32(4))); + mock.set_terminal_error(connection_error(reason)); + let lifecycle: Arc = mock; + WebTransportSessionAdapter::new(session, fd_transfer_for_test(), lifecycle) + } + + #[test] + fn ipc_accept_error_preserves_connection_source() { + let error = IpcWebTransportAcceptError::from(AcceptStreamError::Connection { + source: connection_error("accept connection closed"), + }); + + let IpcWebTransportAcceptError::Connection { source } = error else { + panic!("expected IPC accept connection error"); + }; + assert_transport_reason(&source, "accept connection closed"); + } + + #[test] + fn ipc_accept_error_connection_mapping_preserves_connection_source() { + let source = ipc_accept_error_connection(IpcWebTransportAcceptError::Connection { + source: connection_error("mapped accept connection closed"), + }) + .expect("connection IPC accept error should latch a connection error"); + + assert_transport_reason(&source, "mapped accept connection closed"); + assert!(ipc_accept_error_connection(IpcWebTransportAcceptError::Closed).is_none()); + } + + #[tokio::test] + async fn webtransport_adapter_preserves_connection_closed_accepts() { + let adapter = webtransport_adapter("accept_bi connection closed"); + let error = IpcWebTransportSession::accept_bi(&adapter, VarInt::from_u32(1)) + .await + .expect_err("accept_bi should preserve connection failure"); + let IpcWebTransportAcceptError::Connection { source } = error else { + panic!("expected accept_bi connection error"); + }; + assert_transport_reason(&source, "accept_bi connection closed"); + + let adapter = webtransport_adapter("accept_uni connection closed"); + let error = IpcWebTransportSession::accept_uni(&adapter, VarInt::from_u32(2)) + .await + .expect_err("accept_uni should preserve connection failure"); + let IpcWebTransportAcceptError::Connection { source } = error else { + panic!("expected accept_uni connection error"); + }; + assert_transport_reason(&source, "accept_uni connection closed"); + } + + #[tokio::test] + async fn ipc_webtransport_session_delegates_control_methods() { + let lifecycle = Arc::new(MockConnection::new()); + let session_id = WebTransportSessionId::try_from(StreamId::from(VarInt::from_u32(4))) + .expect("test session id should be valid"); + let session = Arc::new(TestIpcSession::new()); + let (_server_task, client) = spawn_ipc_session(session); + let handle_lifecycle: Arc = lifecycle; + let handle = IpcWebTransportSessionHandle::new( + session_id, + client, + fd_transfer_for_test(), + handle_lifecycle, + ); + let close = CloseSession::try_from((5_u32, "bye")).expect("valid close"); + + webtransport::Session::drain(&handle) + .await + .expect("ipc drain should succeed"); + assert_eq!( + webtransport::Session::drained(&handle).await, + SessionDrain::Requested(DrainReason::Session(SessionDrainReason::Local)) + ); + + webtransport::Session::close(&handle, close.clone()) + .await + .expect("ipc close should succeed"); + assert_eq!( + webtransport::Session::closed(&handle).await, + CloseReason::Session(SessionCloseReason::Local(close)) + ); } } diff --git a/src/lib.rs b/src/lib.rs index c569792..53f6583 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,17 +1,16 @@ #![doc = include_str!("../README.md")] pub mod buflist; -pub mod client; pub mod codec; pub mod connection; pub mod dhttp; pub mod error; -pub mod message; +pub mod extended_connect; pub mod pool; pub mod protocol; pub mod qpack; pub mod quic; -pub mod server; +pub mod stream; pub mod stream_id; mod util; pub mod varint; @@ -19,7 +18,6 @@ pub mod varint; #[cfg(feature = "dquic")] pub mod dquic; -#[cfg(feature = "endpoint")] pub mod endpoint; #[cfg(feature = "hyper")] diff --git a/src/message.rs b/src/message.rs deleted file mode 100644 index 6a44b2e..0000000 --- a/src/message.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod stream; -pub mod unify; diff --git a/src/message/state.rs b/src/message/state.rs deleted file mode 100644 index e69de29..0000000 diff --git a/src/message/stream.rs b/src/message/stream.rs deleted file mode 100644 index 8284ae9..0000000 --- a/src/message/stream.rs +++ /dev/null @@ -1,817 +0,0 @@ -use std::{ - io, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; - -use bytes::{Buf, Bytes}; -use futures::{SinkExt, TryStreamExt}; -use snafu::Snafu; - -use crate::{ - codec::{EncodeError, EncodeExt, SinkWriter, StreamReader}, - connection::{self, ConnectionGoaway, ConnectionState, LifecycleExt}, - dhttp::{ - frame::{ - Frame, - stream::{FrameStream, ReadableFrame}, - }, - protocol::{ - AcceptRawMessageStreamError, BoxDynQuicStreamReader, BoxDynQuicStreamWriter, - DHttpState, InitialRawMessageStreamError, - }, - }, - error::{Code, H3FrameDecodeError, H3FrameUnexpected}, - qpack::{ - algorithm::{DynamicCompressAlgo, HuffmanAlways}, - encoder::{EncodeHeaderSectionError, Encoder}, - field::{FieldLine, FieldSection}, - protocol::{QPackDecoder, QPackEncoder, QPackProtocolDisabled}, - }, - quic::{self, CancelStreamExt, GetStreamIdExt, StopStreamExt}, - varint::{self, VarInt}, -}; - -pub(crate) mod guard; -#[cfg(feature = "hyper")] -pub(crate) mod hyper; -pub(crate) mod unfold; - -pub use self::unfold::{ - read::{BoxMessageStreamReader, ReadMessageStream}, - write::{BoxMessageStreamWriter, WriteMessageStream}, -}; - -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -#[derive(Debug, Snafu)] -pub enum MessageStreamError { - #[snafu(transparent)] - Quic { source: quic::StreamError }, - #[snafu(transparent)] - Goaway { source: ConnectionGoaway }, - #[snafu(display( - "header section too large to fit into a single frame, maybe too many header fields" - ))] - HeaderTooLarge, - #[snafu(display( - "trailer section too large to fit into a single frame, maybe too many header fields" - ))] - TrailerTooLarge, - #[snafu(display("data frame payload too large, try smaller chunk size"))] - DataFrameTooLarge { source: varint::err::Overflow }, - #[snafu(display("HTTP/3 message from peer is malformed"))] - MalformedIncomingMessage, -} - -impl From for MessageStreamError { - fn from(value: quic::ConnectionError) -> Self { - Self::Quic { - source: value.into(), - } - } -} - -impl From for MessageStreamError { - fn from(source: varint::err::Overflow) -> Self { - Self::DataFrameTooLarge { source } - } -} - -impl From for io::Error { - fn from(error: MessageStreamError) -> Self { - use io::ErrorKind; - - let kind = match &error { - // Delegate to the underlying QUIC stream error's kind mapping. - MessageStreamError::Quic { .. } => None, - MessageStreamError::Goaway { .. } => Some(ErrorKind::ConnectionAborted), - MessageStreamError::MalformedIncomingMessage => Some(ErrorKind::InvalidData), - MessageStreamError::HeaderTooLarge - | MessageStreamError::TrailerTooLarge - | MessageStreamError::DataFrameTooLarge { .. } => Some(ErrorKind::InvalidInput), - }; - match kind { - Some(kind) => io::Error::new(kind, error), - None => match error { - MessageStreamError::Quic { source } => io::Error::from(source), - _ => unreachable!("non-Quic variants always resolve to a concrete kind"), - }, - } - } -} - -pub struct ReadStream { - pub(super) stream: FrameStream, - pub(super) qpack_decoder: Arc, - pub(super) connection: Arc, - dhttp_state: Arc, -} - -impl ReadStream { - pub fn new( - stream: StreamReader, - qpack_decoder: Arc, - connection: Arc, - dhttp_state: Arc, - ) -> Self { - let frame_stream = FrameStream::new(stream); - Self { - stream: frame_stream, - qpack_decoder, - connection, - dhttp_state, - } - } - - pub async fn peer_goaway_covers( - &mut self, - ) -> Result> + use<>, quic::StreamError> - { - let stream_id = self.stream.stream_id().await?; - let dhttp_state = self.dhttp_state.clone(); - let conn = self.connection.clone(); - - Ok(async move { - let error = conn.closed(); - tokio::select! { - biased; - _goaway = dhttp_state.peer_goaway_covers(stream_id) => Ok(()), - error = error => Err(error), - } - }) - } - - pub async fn try_stream_io( - &mut self, - f: impl AsyncFnOnce(&mut Self) -> Result, - ) -> Result { - let peer_goaway = self.peer_goaway_covers().await?; - tokio::select! { - result = f(self) => match result { - Ok(value) => Ok(value), - Err(error) => Err(self.handle_stream_error(error).await.into()), - }, - goaway = peer_goaway => match goaway { - Ok(()) => { - // FIXME: which code should be used? - _ = self.stream.stop(Code::H3_NO_ERROR.into()).await; - Err(ConnectionGoaway::Peer.into()) - } - Err(error) => Err(error.into()) - } - } - } - - /// Resolve a stream-level H3 error into a `quic::StreamError`, performing - /// the correct side effect for each variant: - /// - /// - `Connection` — delegate to [`LifecycleExt::handle_connection_error`] - /// so that a fresh H3 connection-scope violation closes the QUIC - /// connection. - /// - `Reset` — nothing to do locally; the peer already reset the stream. - /// - `H3` — a freshly detected stream-scope protocol violation; issue - /// `STOP_SENDING` on this reader so the peer observes the abort. - pub async fn handle_stream_error( - &mut self, - error: connection::StreamError, - ) -> quic::StreamError { - match error { - connection::StreamError::Connection { source } => self - .connection - .as_ref() - .handle_connection_error(source) - .await - .into(), - connection::StreamError::Reset { code } => quic::StreamError::Reset { code }, - connection::StreamError::H3 { source } => { - let code = source.code().into_inner(); - _ = self.stream.stop(code).await; - quic::StreamError::Reset { code } - } - } - } - - pub async fn peek_frame( - &mut self, - ) -> Option, connection::StreamError>> { - loop { - match Pin::new(&mut self.stream).frame() { - None => match Pin::new(&mut self.stream).next_unreserved_frame().await? { - Ok(_next_frame) => continue, - Err(error) => return Some(Err(error)), - }, - Some(Ok(frame)) - if frame.r#type() == Frame::HEADERS_FRAME_TYPE - || frame.r#type() == Frame::DATA_FRAME_TYPE => - { - // avoid rust bc bug - return Pin::new(&mut self.stream).frame(); - } - Some(Ok(_frame)) => { - return Some(Err(H3FrameUnexpected::UnexpectedFrameType.into())); - } - Some(Err(error)) => return Some(Err(error)), - } - } - } - - pub async fn read_data_frame_chunk( - &mut self, - ) -> Option> { - loop { - match self.peek_frame().await { - Some(Ok(mut frame)) if frame.r#type() == Frame::DATA_FRAME_TYPE => { - match frame.try_next().await { - Ok(Some(bytes)) => return Some(Ok(bytes)), - Ok(None) => { - _ = Pin::new(&mut self.stream).consume_current_frame().await; - continue; - } - Err(error) => { - let error = error.into_stream_error(|error| { - H3FrameDecodeError { source: error }.into() - }); - return Some(Err(error)); - } - } - } - Some(Ok(..)) | None => return None, - Some(Err(error)) => return Some(Err(error)), - } - } - } - - pub async fn read_header_frame( - &mut self, - ) -> Option> { - match self.peek_frame().await { - Some(Ok(frame)) if frame.r#type() == Frame::HEADERS_FRAME_TYPE => { - let frame = match Pin::new(&mut self.stream).frame()? { - Ok(frame) => frame, - Err(error) => return Some(Err(error)), - }; - match self.qpack_decoder.decode(frame).await { - Ok(field_section) => { - _ = Pin::new(&mut self.stream).consume_current_frame().await; - Some(Ok(field_section)) - } - Err(error) => Some(Err(error)), - } - } - Some(Ok(..)) | None => None, - Some(Err(error)) => Some(Err(error)), - } - } - - pub async fn stop(&mut self, code: Code) -> Result<(), MessageStreamError> { - self.try_stream_io(async move |this| Ok(this.stream.stop(code.into_inner()).await?)) - .await - } - - pub fn take(&mut self) -> Self { - let taken = self.stream.inner_mut().take(); - Self { - stream: FrameStream::new(StreamReader::new(taken)), - qpack_decoder: self.qpack_decoder.clone(), - connection: self.connection.clone(), - dhttp_state: self.dhttp_state.clone(), - } - } -} - -impl quic::GetStreamId for ReadStream { - fn poll_stream_id( - self: Pin<&mut Self>, - cx: &mut Context, - ) -> Poll> { - Pin::new(&mut self.get_mut().stream).poll_stream_id(cx) - } -} - -impl quic::StopStream for ReadStream { - fn poll_stop( - self: Pin<&mut Self>, - cx: &mut Context, - code: VarInt, - ) -> Poll> { - Pin::new(&mut self.get_mut().stream).poll_stop(cx, code) - } -} - -pub struct WriteStream { - pub(super) stream: SinkWriter, - pub(super) qpack_encoder: Arc, - pub(super) connection: Arc, - dhttp_state: Arc, -} - -pub const DEFAULT_COMPRESS_ALGO: DynamicCompressAlgo = - DynamicCompressAlgo::new(HuffmanAlways); - -impl WriteStream { - pub fn new( - stream: SinkWriter, - qpack_encoder: Arc, - connection: Arc, - dhttp_state: Arc, - ) -> Self { - Self { - stream, - qpack_encoder, - connection, - dhttp_state, - } - } - - pub async fn send_frame( - &mut self, - frame: Frame, - ) -> Result<(), quic::StreamError> { - self.stream.encode_one(frame).await - } - - async fn peer_goaway_covers( - &mut self, - ) -> Result> + use<>, quic::StreamError> - { - let stream_id = self.stream.stream_id().await?; - let dhttp_state = self.dhttp_state.clone(); - let conn = self.connection.clone(); - - Ok(async move { - let error = conn.closed(); - tokio::select! { - biased; - _goaway = dhttp_state.peer_goaway_covers(stream_id) => Ok(()), - error = error => Err(error), - } - }) - } - - pub async fn try_stream_io( - &mut self, - f: impl AsyncFnOnce(&mut Self) -> Result, - ) -> Result { - let peer_goaway = self.peer_goaway_covers().await?; - let f = async move |this: &mut Self| { - let value = f(this).await?; - // ensure all data are written into the underlying QUIC stream - this.stream.flush_buffer().await?; - Ok(value) - }; - tokio::select! { - result = f(self) => match result { - Ok(value) => Ok(value), - Err(error) => Err(self.handle_stream_error(error).await.into()), - }, - goaway = peer_goaway => match goaway { - Ok(()) => { - // FIXME: which code should be used? - _ = self.stream.cancel(Code::H3_NO_ERROR.into()).await; - Err(ConnectionGoaway::Peer.into()) - } - Err(error) => Err(error.into()) - } - } - } - - /// Resolve a stream-level H3 error into a `quic::StreamError`. See - /// [`ReadStream::handle_stream_error`] for semantics; the difference is - /// that `H3` errors issue `RESET_STREAM` (via `cancel`) on this writer - /// instead of `STOP_SENDING`. - pub async fn handle_stream_error( - &mut self, - error: connection::StreamError, - ) -> quic::StreamError { - match error { - connection::StreamError::Connection { source } => self - .connection - .as_ref() - .handle_connection_error(source) - .await - .into(), - connection::StreamError::Reset { code } => quic::StreamError::Reset { code }, - connection::StreamError::H3 { source } => { - let code = source.code().into_inner(); - _ = self.stream.cancel(code).await; - quic::StreamError::Reset { code } - } - } - } - - pub async fn send_header( - &mut self, - field_lines: impl IntoIterator + Send, - ) -> Result<(), MessageStreamError> { - let algo = &DEFAULT_COMPRESS_ALGO; - let result = self - .try_stream_io(async move |this| { - let stream = &mut this.stream; - match Encoder::encode(&*this.qpack_encoder, field_lines, algo, stream).await { - Ok(frame) => Ok(Ok(this.send_frame(frame).await?)), - Err(EncodeHeaderSectionError::Encode { source }) => Ok(Err(source)), - Err(EncodeHeaderSectionError::Stream { source }) => Err(source), - } - }) - .await?; - - // Flush encoder instructions (dynamic table insertions) to the encoder stream. - // Encoder stream errors are connection-level: reset = connection error per RFC 9204. - if let Err(error) = self.qpack_encoder.flush_instructions().await { - let quic_error = self.handle_stream_error(error).await; - return Err(quic_error.into()); - } - - match result { - Ok(()) => Ok(()), - Err(EncodeError::FramePayloadTooLarge) => Err(MessageStreamError::HeaderTooLarge), - Err(EncodeError::HuffmanEncoding) => { - unreachable!("FieldSection contain invalid header name/value, this is a bug") - } - } - } - - pub async fn flush(&mut self) -> Result<(), MessageStreamError> { - self.try_stream_io(async move |this| Ok(this.stream.flush_inner().await?)) - .await - } - - pub async fn close(&mut self) -> Result<(), MessageStreamError> { - self.try_stream_io(async move |this| Ok(this.stream.close().await?)) - .await - } - - pub async fn cancel(&mut self, code: Code) -> Result<(), MessageStreamError> { - self.try_stream_io(async move |this| Ok(this.stream.cancel(code.into_inner()).await?)) - .await - } - - pub fn take(&mut self) -> Self { - let taken = self.stream.sink_mut().take(); - Self { - stream: SinkWriter::new(taken), - qpack_encoder: self.qpack_encoder.clone(), - connection: self.connection.clone(), - dhttp_state: self.dhttp_state.clone(), - } - } -} - -impl quic::GetStreamId for WriteStream { - fn poll_stream_id( - self: Pin<&mut Self>, - cx: &mut Context, - ) -> Poll> { - Pin::new(&mut self.get_mut().stream).poll_stream_id(cx) - } -} - -impl quic::CancelStream for WriteStream { - fn poll_cancel( - self: Pin<&mut Self>, - cx: &mut Context, - code: VarInt, - ) -> Poll> { - Pin::new(&mut self.get_mut().stream).poll_cancel(cx, code) - } -} - -#[derive(Debug, Snafu, Clone)] -pub enum InitialMessageStreamError { - #[snafu(transparent)] - InitialRawStream { - source: InitialRawMessageStreamError, - }, - #[snafu(transparent)] - QPackProtocolDisabled { source: QPackProtocolDisabled }, -} - -#[derive(Debug, Snafu, Clone)] -pub enum AcceptMessageStreamError { - #[snafu(transparent)] - AcceptRawStream { source: AcceptRawMessageStreamError }, - #[snafu(transparent)] - QPackProtocolDisabled { source: QPackProtocolDisabled }, -} - -impl ConnectionState { - pub async fn initial_message_stream( - &self, - ) -> Result<(ReadStream, WriteStream), InitialMessageStreamError> { - let qpack = self.qpack()?; - let dhttp = self.dhttp(); - let (reader, writer) = self.initial_raw_message_stream().await?; - Ok(( - ReadStream::new( - reader, - qpack.decoder.clone(), - self.quic().clone() as Arc, - dhttp.state.clone(), - ), - WriteStream::new( - writer, - qpack.encoder.clone(), - self.quic().clone() as Arc, - dhttp.state.clone(), - ), - )) - } - - pub async fn accept_message_stream( - &self, - ) -> Result<(ReadStream, WriteStream), AcceptMessageStreamError> { - let qpack = self.qpack()?; - let dhttp = self.dhttp(); - let (reader, writer) = self.accept_raw_message_stream().await?; - Ok(( - ReadStream::new( - reader, - qpack.decoder.clone(), - self.quic().clone() as Arc, - dhttp.state.clone(), - ), - WriteStream::new( - writer, - qpack.encoder.clone(), - self.quic().clone() as Arc, - dhttp.state.clone(), - ), - )) - } -} - -#[cfg(test)] -mod tests { - use std::{ - pin::Pin, - sync::Arc, - task::{Context, Poll}, - }; - - use bytes::Bytes; - use futures::{Sink, SinkExt, Stream}; - - use super::{MessageStreamError, ReadStream, WriteStream, guard}; - use crate::{ - codec::{SinkWriter, StreamReader}, - connection::{ConnectionState, StreamError, tests::MockConnection}, - dhttp::{goaway::Goaway, protocol::DHttpProtocol, settings::Settings}, - protocol::Protocols, - qpack::protocol::{QPackDecoder, QPackEncoder}, - quic, - varint::VarInt, - }; - - #[test] - fn io_error_kind_is_derived_per_variant() { - use std::io::ErrorKind; - - fn assert_kind(error: MessageStreamError, expected: ErrorKind) { - let repr = format!("{error:?}"); - let io_error = std::io::Error::from(error); - assert_eq!(io_error.kind(), expected, "unexpected kind for {repr}"); - } - - assert_kind(MessageStreamError::HeaderTooLarge, ErrorKind::InvalidInput); - assert_kind(MessageStreamError::TrailerTooLarge, ErrorKind::InvalidInput); - assert_kind( - MessageStreamError::DataFrameTooLarge { - source: crate::varint::VarInt::from_u64(1 << 63) - .expect_err("value exceeds varint encoding"), - }, - ErrorKind::InvalidInput, - ); - assert_kind( - MessageStreamError::MalformedIncomingMessage, - ErrorKind::InvalidData, - ); - assert_kind( - MessageStreamError::Goaway { - source: crate::connection::ConnectionGoaway::Peer, - }, - ErrorKind::ConnectionAborted, - ); - assert_kind( - MessageStreamError::Quic { - source: quic::StreamError::Reset { - code: VarInt::from_u32(0), - }, - }, - ErrorKind::BrokenPipe, - ); - } - - fn qpack_decoder_sink() - -> Pin + Send>> - { - Box::pin( - futures::sink::drain::() - .sink_map_err(|never| match never {}), - ) - } - - fn qpack_decoder_stream() -> Pin< - Box< - dyn Stream> - + Send, - >, - > { - Box::pin(futures::stream::empty::< - Result, - >()) - } - - fn qpack_encoder_sink() - -> Pin + Send>> - { - Box::pin( - futures::sink::drain::() - .sink_map_err(|never| match never {}), - ) - } - - fn qpack_encoder_stream() -> Pin< - Box< - dyn Stream> - + Send, - >, - > { - Box::pin(futures::stream::empty::< - Result, - >()) - } - - #[derive(Debug)] - struct TestReadStream { - stream_id: VarInt, - } - - impl quic::GetStreamId for TestReadStream { - fn poll_stream_id( - self: Pin<&mut Self>, - _cx: &mut Context, - ) -> Poll> { - Poll::Ready(Ok(self.get_mut().stream_id)) - } - } - - impl quic::StopStream for TestReadStream { - fn poll_stop( - self: Pin<&mut Self>, - _cx: &mut Context, - _code: VarInt, - ) -> Poll> { - Poll::Ready(Ok(())) - } - } - - impl Stream for TestReadStream { - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(None) - } - } - - #[derive(Debug)] - struct TestWriteStream { - stream_id: VarInt, - } - - impl quic::GetStreamId for TestWriteStream { - fn poll_stream_id( - self: Pin<&mut Self>, - _cx: &mut Context, - ) -> Poll> { - Poll::Ready(Ok(self.get_mut().stream_id)) - } - } - - impl quic::CancelStream for TestWriteStream { - fn poll_cancel( - self: Pin<&mut Self>, - _cx: &mut Context, - _code: VarInt, - ) -> Poll> { - Poll::Ready(Ok(())) - } - } - - impl Sink for TestWriteStream { - type Error = quic::StreamError; - - fn poll_ready( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - - fn start_send(self: Pin<&mut Self>, _item: Bytes) -> Result<(), Self::Error> { - Ok(()) - } - - fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_close( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - } - - #[tokio::test] - async fn read_stream_try_stream_io_aborts_when_peer_goaway_covers_stream() { - let quic = Arc::new(MockConnection::new()); - let erased_connection: Arc = quic.clone(); - let connection: Arc = quic.clone(); - - let mut protocols = Protocols::new(); - protocols.insert(DHttpProtocol::new_for_test(erased_connection)); - let state = ConnectionState::new_for_test(quic.clone(), Arc::new(protocols)); - - let reader = StreamReader::new(guard::GuardedQuicReader::new(Box::pin(TestReadStream { - stream_id: VarInt::from_u32(10), - }) - as crate::codec::BoxReadStream)); - let mut read_stream = ReadStream::new( - reader, - Arc::new(QPackDecoder::new( - Arc::new(Settings::default()), - qpack_decoder_sink(), - qpack_decoder_stream(), - )), - connection, - state.dhttp().state.clone(), - ); - - state - .dhttp() - .apply_peer_goaway(Goaway::new(VarInt::from_u32(9))) - .expect("peer goaway should be accepted"); - - let result = read_stream - .try_stream_io(async move |_this| { - futures::future::pending::>().await - }) - .await; - - assert!(matches!( - result, - Err(super::MessageStreamError::Goaway { - source: crate::connection::ConnectionGoaway::Peer - }) - )); - } - - #[tokio::test] - async fn write_stream_try_stream_io_aborts_when_peer_goaway_covers_stream() { - let quic = Arc::new(MockConnection::new()); - let erased_connection: Arc = quic.clone(); - let connection: Arc = quic.clone(); - - let mut protocols = Protocols::new(); - protocols.insert(DHttpProtocol::new_for_test(erased_connection)); - let state = ConnectionState::new_for_test(quic.clone(), Arc::new(protocols)); - - let writer = SinkWriter::new(guard::GuardedQuicWriter::new(Box::pin(TestWriteStream { - stream_id: VarInt::from_u32(12), - }) - as crate::codec::BoxWriteStream)); - let mut write_stream = WriteStream::new( - writer, - Arc::new(QPackEncoder::new( - Arc::new(Settings::default()), - qpack_encoder_sink(), - qpack_encoder_stream(), - )), - connection, - state.dhttp().state.clone(), - ); - - state - .dhttp() - .apply_peer_goaway(Goaway::new(VarInt::from_u32(11))) - .expect("peer goaway should be accepted"); - - let result = write_stream - .try_stream_io(async move |_this| { - futures::future::pending::>().await - }) - .await; - - assert!(matches!( - result, - Err(super::MessageStreamError::Goaway { - source: crate::connection::ConnectionGoaway::Peer - }) - )); - } -} diff --git a/src/message/stream/guard.rs b/src/message/stream/guard.rs deleted file mode 100644 index 823dab6..0000000 --- a/src/message/stream/guard.rs +++ /dev/null @@ -1,263 +0,0 @@ -use std::{ - mem, - pin::Pin, - task::{Context, Poll}, -}; - -use bytes::Bytes; -use futures::{Sink, Stream}; -use tracing::Instrument; - -use crate::{ - codec::{BoxReadStream, BoxWriteStream}, - error::Code, - quic::{self, CancelStreamExt, StopStreamExt}, - varint::VarInt, -}; - -// --------------------------------------------------------------------------- -// Sentinel — panics if the stream is used after take / drop -// --------------------------------------------------------------------------- - -struct DroppedStream; - -fn stream_used_after_dropped() -> ! { - panic!("guarded QUIC stream used after being taken or dropped, this is a bug") -} - -impl Stream for DroppedStream { - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - stream_used_after_dropped() - } -} - -impl Sink for DroppedStream { - type Error = quic::StreamError; - - fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - stream_used_after_dropped() - } - - fn start_send(self: Pin<&mut Self>, _: Bytes) -> Result<(), Self::Error> { - stream_used_after_dropped() - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - stream_used_after_dropped() - } - - fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - stream_used_after_dropped() - } -} - -impl quic::CancelStream for DroppedStream { - fn poll_cancel( - self: Pin<&mut Self>, - _cx: &mut Context, - _code: VarInt, - ) -> Poll> { - stream_used_after_dropped() - } -} - -impl quic::StopStream for DroppedStream { - fn poll_stop( - self: Pin<&mut Self>, - _cx: &mut Context, - _code: VarInt, - ) -> Poll> { - stream_used_after_dropped() - } -} - -impl quic::GetStreamId for DroppedStream { - fn poll_stream_id( - self: Pin<&mut Self>, - _cx: &mut Context, - ) -> Poll> { - stream_used_after_dropped() - } -} - -fn dropped_reader() -> BoxReadStream { - Box::pin(DroppedStream) -} - -fn dropped_writer() -> BoxWriteStream { - Box::pin(DroppedStream) -} - -// --------------------------------------------------------------------------- -// GuardedQuicReader -// --------------------------------------------------------------------------- - -/// A QUIC read stream wrapper that automatically stops the stream on drop -/// if it hasn't been fully consumed (EOF) or explicitly stopped. -pub struct GuardedQuicReader { - inner: BoxReadStream, - completed: bool, -} - -impl GuardedQuicReader { - pub fn new(inner: BoxReadStream) -> Self { - Self { - inner, - completed: false, - } - } - - /// Take the inner stream, replacing it with a sentinel. - /// Marks this guard as completed (no cleanup on drop). - pub fn take(&mut self) -> Self { - let inner = mem::replace(&mut self.inner, dropped_reader()); - let completed = self.completed; - self.completed = true; - Self { inner, completed } - } -} - -impl Stream for GuardedQuicReader { - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - let result = this.inner.as_mut().poll_next(cx); - if let Poll::Ready(None) = &result { - this.completed = true; - } - result - } -} - -impl quic::StopStream for GuardedQuicReader { - fn poll_stop( - self: Pin<&mut Self>, - cx: &mut Context, - code: VarInt, - ) -> Poll> { - let this = self.get_mut(); - let result = this.inner.as_mut().poll_stop(cx, code); - if let Poll::Ready(Ok(())) = &result { - this.completed = true; - } - result - } -} - -impl quic::GetStreamId for GuardedQuicReader { - fn poll_stream_id( - self: Pin<&mut Self>, - cx: &mut Context, - ) -> Poll> { - self.get_mut().inner.as_mut().poll_stream_id(cx) - } -} - -impl Drop for GuardedQuicReader { - fn drop(&mut self) { - if !self.completed { - let mut inner = mem::replace(&mut self.inner, dropped_reader()); - tokio::spawn( - async move { - _ = inner.stop(Code::H3_NO_ERROR.into()).await; - } - .in_current_span(), - ); - } - } -} - -// --------------------------------------------------------------------------- -// GuardedQuicWriter -// --------------------------------------------------------------------------- - -/// A QUIC write stream wrapper that automatically cancels the stream on drop -/// if it hasn't been properly closed or explicitly cancelled. -pub struct GuardedQuicWriter { - inner: BoxWriteStream, - completed: bool, -} - -impl GuardedQuicWriter { - pub fn new(inner: BoxWriteStream) -> Self { - Self { - inner, - completed: false, - } - } - - /// Take the inner stream, replacing it with a sentinel. - /// Marks this guard as completed (no cleanup on drop). - pub fn take(&mut self) -> Self { - let inner = mem::replace(&mut self.inner, dropped_writer()); - let completed = self.completed; - self.completed = true; - Self { inner, completed } - } -} - -impl Sink for GuardedQuicWriter { - type Error = quic::StreamError; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_mut().inner.as_mut().poll_ready(cx) - } - - fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { - self.get_mut().inner.as_mut().start_send(item) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.get_mut().inner.as_mut().poll_flush(cx) - } - - fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - let result = this.inner.as_mut().poll_close(cx); - if let Poll::Ready(Ok(())) = &result { - this.completed = true; - } - result - } -} - -impl quic::CancelStream for GuardedQuicWriter { - fn poll_cancel( - self: Pin<&mut Self>, - cx: &mut Context, - code: VarInt, - ) -> Poll> { - let this = self.get_mut(); - let result = this.inner.as_mut().poll_cancel(cx, code); - if let Poll::Ready(Ok(())) = &result { - this.completed = true; - } - result - } -} - -impl quic::GetStreamId for GuardedQuicWriter { - fn poll_stream_id( - self: Pin<&mut Self>, - cx: &mut Context, - ) -> Poll> { - self.get_mut().inner.as_mut().poll_stream_id(cx) - } -} - -impl Drop for GuardedQuicWriter { - fn drop(&mut self) { - if !self.completed { - let mut inner = mem::replace(&mut self.inner, dropped_writer()); - tokio::spawn( - async move { - _ = inner.cancel(Code::H3_NO_ERROR.into()).await; - } - .in_current_span(), - ); - } - } -} diff --git a/src/message/stream/hyper.rs b/src/message/stream/hyper.rs deleted file mode 100644 index f5727d2..0000000 --- a/src/message/stream/hyper.rs +++ /dev/null @@ -1,5 +0,0 @@ -use super::{MessageStreamError, ReadStream, WriteStream}; - -pub mod read; -pub mod upgrade; -pub mod write; diff --git a/src/message/stream/hyper/read.rs b/src/message/stream/hyper/read.rs deleted file mode 100644 index cfc61ef..0000000 --- a/src/message/stream/hyper/read.rs +++ /dev/null @@ -1,169 +0,0 @@ -use std::{ - pin::Pin, - task::{Context, Poll}, -}; - -use bytes::Bytes; -use futures::{StreamExt, stream}; -use http_body::{Body, Frame, SizeHint}; -use http_body_util::{BodyExt, Empty, StreamBody}; - -use super::{ - MessageStreamError, ReadStream, - upgrade::{RemainStream, TakeoverSlot}, -}; -use crate::{connection, error::H3MessageError}; - -pin_project_lite::pin_project! { - #[project = EitherProj] - pub enum Either { - Left { #[pin] body: L }, - Right { #[pin] body: R } - } -} - -impl Either { - pub fn left(body: L) -> Self { - Self::Left { body } - } - - pub fn right(body: R) -> Self { - Self::Right { body } - } -} - -impl> Body for Either { - type Data = L::Data; - type Error = L::Error; - - fn poll_frame( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll, Self::Error>>> { - match self.project() { - EitherProj::Left { body } => body.poll_frame(cx), - EitherProj::Right { body } => body.poll_frame(cx), - } - } - - fn is_end_stream(&self) -> bool { - match self { - Either::Left { body } => body.is_end_stream(), - Either::Right { body } => body.is_end_stream(), - } - } - - fn size_hint(&self) -> SizeHint { - match self { - Either::Left { body } => body.size_hint(), - Either::Right { body } => body.size_hint(), - } - } -} - -impl ReadStream { - pub async fn read_hyper_request_parts( - &mut self, - ) -> Result { - self.try_stream_io(async |stream| { - let Some(field_section) = stream.read_header_frame().await.transpose()? else { - return Err(H3MessageError::MissingHeaderSection.into()); - }; - Ok(http::request::Parts::try_from(field_section)?) - }) - .await - } - - pub async fn read_hyper_response_parts( - &mut self, - ) -> Result { - self.try_stream_io(async |stream| { - let Some(field_section) = stream.read_header_frame().await.transpose()? else { - return Err(H3MessageError::MissingHeaderSection.into()); - }; - Ok(http::response::Parts::try_from(field_section)?) - }) - .await - } - - pub async fn read_hyper_frame( - &mut self, - ) -> Option, connection::StreamError>> { - match self.read_data_frame_chunk().await { - Some(data) => Some(data.map(Frame::data)), - None => match self.read_header_frame().await? { - Ok(field_section) if !field_section.is_trailer() => { - Some(Err(H3MessageError::UnexpectedHeadersInBody.into())) - } - Ok(field_section) => Some(Ok(Frame::trailers(field_section.header_map))), - Err(error) => Some(Err(error)), - }, - } - } - - pub fn as_hyper_body(&mut self) -> impl Body + Send { - StreamBody::new( - stream::unfold(self, async |stream| { - let frame = stream - .try_stream_io(async |stream| stream.read_hyper_frame().await.transpose()) - .await - .transpose()?; - Some((frame, stream)) - }) - .fuse(), - ) - } - - pub fn into_hyper_body(self) -> impl Body + Send { - StreamBody::new( - stream::unfold(self, async |mut stream| { - let frame = stream - .try_stream_io(async |stream| stream.read_hyper_frame().await.transpose()) - .await - .transpose()?; - Some((frame, stream)) - }) - .fuse(), - ) - } - - pub async fn into_hyper_request( - mut self, - ) -> Result< - http::Request + Send>, - MessageStreamError, - > { - let mut parts = self.read_hyper_request_parts().await?; - if parts.method == http::Method::CONNECT { - parts - .extensions - .insert(TakeoverSlot::new(RemainStream::immediately(self))); - let body = Either::right(Empty::new().map_err(|n| match n {})); - Ok(http::Request::from_parts(parts, body)) - } else { - let body = Either::left(self.into_hyper_body()); - Ok(http::Request::from_parts(parts, body)) - } - } - - pub async fn into_hyper_response( - mut self, - ) -> Result< - http::Response + Send>, - MessageStreamError, - > { - let mut parts = self.read_hyper_response_parts().await?; - match parts.status.is_informational() { - true => { - parts.extensions.insert(RemainStream::immediately(self)); - let body = Either::right(Empty::new().map_err(|n| match n {})); - Ok(http::Response::from_parts(parts, body)) - } - false => { - // no remain - let body = self.into_hyper_body(); - Ok(http::Response::from_parts(parts, Either::left(body))) - } - } - } -} diff --git a/src/message/stream/hyper/write.rs b/src/message/stream/hyper/write.rs deleted file mode 100644 index d08d7b3..0000000 --- a/src/message/stream/hyper/write.rs +++ /dev/null @@ -1,116 +0,0 @@ -use std::pin::pin; - -use http_body::Body; -use http_body_util::BodyExt; -use snafu::{ResultExt, Snafu}; - -use super::{MessageStreamError, WriteStream}; -use crate::qpack::field::hyper::{ - header_map_to_field_lines, hyper_request_parts_to_field_lines, - hyper_response_parts_to_field_lines, -}; - -#[derive(Debug, Snafu)] -#[snafu(module, visibility(pub(crate)))] -pub enum SendMessageError { - #[snafu(display("failed to send message on stream"))] - Stream { source: MessageStreamError }, - #[snafu(display("failed to read body frame"))] - Body { source: E }, -} - -impl SendMessageError { - pub fn map_body_error( - self, - f: impl FnOnce(E) -> E1, - ) -> SendMessageError { - match self { - SendMessageError::Stream { source } => SendMessageError::Stream { source }, - SendMessageError::Body { source } => SendMessageError::Body { source: f(source) }, - } - } -} - -impl WriteStream { - pub(crate) async fn send_hyper_body( - &mut self, - body: B, - ) -> Result<(), SendMessageError> - where - B::Data: Send, - B::Error: std::error::Error + 'static, - { - let mut body = pin!(body); - while let Some(frame) = body.frame().await { - let frame = frame.context(send_message_error::BodySnafu)?; - let frame = match frame.into_data() { - Ok(data) => { - self.send_data(data) - .await - .context(send_message_error::StreamSnafu)?; - continue; - } - Err(frame) => frame, - }; - let frame = match frame.into_trailers() { - Ok(trailers) => { - self.send_header(header_map_to_field_lines(trailers)) - .await - .context(send_message_error::StreamSnafu)?; - break; - } - Err(frame) => frame, - }; - - tracing::warn!("ignore unknown http body frame"); - _ = frame; - } - Ok(()) - } - - pub async fn send_hyper_request_parts( - &mut self, - parts: http::request::Parts, - ) -> Result<(), MessageStreamError> { - self.send_header(hyper_request_parts_to_field_lines(parts)) - .await - } - - pub async fn send_hyper_request( - &mut self, - request: http::Request, - ) -> Result<(), SendMessageError> - where - B::Data: Send, - B::Error: std::error::Error + 'static, - { - let (parts, body) = request.into_parts(); - self.send_header(hyper_request_parts_to_field_lines(parts)) - .await - .context(send_message_error::StreamSnafu)?; - self.send_hyper_body(body).await - } - - pub async fn send_hyper_response_parts( - &mut self, - parts: http::response::Parts, - ) -> Result<(), MessageStreamError> { - self.send_header(hyper_response_parts_to_field_lines(parts)) - .await - } - - pub async fn send_hyper_response( - &mut self, - response: http::Response, - ) -> Result<(), SendMessageError> - where - B::Data: Send, - B::Error: std::error::Error + 'static, - { - let (parts, body) = response.into_parts(); - self.send_header(hyper_response_parts_to_field_lines(parts)) - .await - .context(send_message_error::StreamSnafu)?; - self.send_hyper_body(body).await - } -} diff --git a/src/message/stream/unfold/read.rs b/src/message/stream/unfold/read.rs deleted file mode 100644 index 465152a..0000000 --- a/src/message/stream/unfold/read.rs +++ /dev/null @@ -1,217 +0,0 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll, ready}, -}; - -use bytes::Bytes; -use futures::stream::FusedStream; - -use super::super::{MessageStreamError, ReadStream}; -use crate::{ - codec::StreamReader, - quic::{self, GetStreamId, StopStream}, - varint::VarInt, -}; - -/// A message-level byte stream that also supports QUIC stream control operations. -/// -/// This is the message-layer analog of [`quic::ReadStream`], combining DATA-frame -/// byte streaming with the underlying QUIC stream's [`StopStream`] and -/// [`GetStreamId`] capabilities. -pub trait ReadMessageStream: - StopStream + GetStreamId + FusedStream> + Send -{ -} - -impl< - T: StopStream - + GetStreamId - + FusedStream> - + Send - + ?Sized, -> ReadMessageStream for T -{ -} - -/// Boxed stream reader with QUIC stream control traits preserved. -pub type BoxMessageStreamReader<'s> = StreamReader>>; - -impl From for BoxMessageStreamReader<'static> { - fn from(value: ReadStream) -> Self { - value.into_box_reader() - } -} - -// --------------------------------------------------------------------------- -// Unfold – custom stream unfold that preserves QUIC traits -// --------------------------------------------------------------------------- - -pin_project_lite::pin_project! { - #[project = StateProj] - #[project_replace = StateProjReplace] - enum State { - Value { value: T }, - Future { #[pin] future: Fut }, - Done, - Empty, - } -} - -pin_project_lite::pin_project! { - /// A fused stream adapter similar to [`futures::stream::unfold`], but the - /// inner value's QUIC control traits ([`GetStreamId`], [`StopStream`]) are - /// forwarded when the value is not consumed by an in-flight future. - #[must_use = "streams do nothing unless polled"] - pub struct Unfold { - f: F, - #[pin] - state: State, - } -} - -/// Create an [`Unfold`] stream. -/// -/// Works like [`futures::stream::unfold`] but the returned stream conditionally -/// implements [`GetStreamId`] and [`StopStream`] when `T` does. -pub fn unfold(init: T, f: F) -> Unfold -where - F: FnMut(T) -> Fut, - Fut: Future>, -{ - Unfold { - f, - state: State::Value { value: init }, - } -} - -impl futures::Stream for Unfold -where - F: FnMut(T) -> Fut, - Fut: Future>, -{ - type Item = Item; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - - loop { - match this.state.as_mut().project() { - StateProj::Value { .. } => { - let value = match this.state.as_mut().project_replace(State::Empty) { - StateProjReplace::Value { value } => value, - _ => unreachable!(), - }; - let fut = (this.f)(value); - this.state.set(State::Future { future: fut }); - } - StateProj::Future { future } => match ready!(future.poll(cx)) { - Some((item, value)) => { - this.state.set(State::Value { value }); - return Poll::Ready(Some(item)); - } - None => { - this.state.set(State::Done); - return Poll::Ready(None); - } - }, - StateProj::Done | StateProj::Empty => { - return Poll::Ready(None); - } - } - } - } -} - -impl FusedStream for Unfold -where - F: FnMut(T) -> Fut, - Fut: Future>, -{ - fn is_terminated(&self) -> bool { - matches!(self.state, State::Done) - } -} - -impl GetStreamId for Unfold -where - F: FnMut(T) -> Fut, - Fut: Future>, -{ - fn poll_stream_id( - self: Pin<&mut Self>, - cx: &mut Context, - ) -> Poll> { - let this = self.project(); - match this.state.project() { - StateProj::Value { value } => Pin::new(value).poll_stream_id(cx), - _ => Poll::Pending, - } - } -} - -impl StopStream for Unfold -where - F: FnMut(T) -> Fut, - Fut: Future>, -{ - fn poll_stop( - self: Pin<&mut Self>, - cx: &mut Context, - code: VarInt, - ) -> Poll> { - let this = self.project(); - match this.state.project() { - StateProj::Value { value } => Pin::new(value).poll_stop(cx, code), - _ => Poll::Pending, - } - } -} - -// --------------------------------------------------------------------------- -// ReadStream conversion methods -// --------------------------------------------------------------------------- - -impl ReadStream { - pub fn as_bytes_stream(&mut self) -> impl ReadMessageStream + '_ { - unfold(self, |this: &mut ReadStream| async move { - match this - .try_stream_io(async |this| this.read_data_frame_chunk().await.transpose()) - .await - .transpose()? - { - Ok(bytes) => Some((Ok(bytes), this)), - Err(error) => Some((Err(error), this)), - } - }) - } - - pub fn as_reader(&mut self) -> StreamReader { - StreamReader::new(self.as_bytes_stream()) - } - - pub fn as_box_reader(&mut self) -> BoxMessageStreamReader<'_> { - StreamReader::new(Box::pin(self.as_bytes_stream())) - } - - pub fn into_bytes_stream(self) -> impl ReadMessageStream { - unfold(self, |mut this: ReadStream| async move { - match this - .try_stream_io(async |this| this.read_data_frame_chunk().await.transpose()) - .await - .transpose()? - { - Ok(bytes) => Some((Ok(bytes), this)), - Err(error) => Some((Err(error), this)), - } - }) - } - - pub fn into_reader(self) -> StreamReader { - StreamReader::new(self.into_bytes_stream()) - } - - pub fn into_box_reader(self) -> BoxMessageStreamReader<'static> { - StreamReader::new(Box::pin(self.into_bytes_stream())) - } -} diff --git a/src/message/stream/unfold/write.rs b/src/message/stream/unfold/write.rs deleted file mode 100644 index 5a71448..0000000 --- a/src/message/stream/unfold/write.rs +++ /dev/null @@ -1,266 +0,0 @@ -//! since futures::sink::unfold only works for send, we implement our own version here to support flush and close as well. - -use std::{ - ops::Deref, - pin::Pin, - task::{Context, Poll, ready}, -}; - -use bytes::Bytes; -use futures::Sink; - -use super::super::{MessageStreamError, WriteStream}; -use crate::{ - codec::SinkWriter, - quic::{self, CancelStream, GetStreamId}, - varint::VarInt, -}; - -/// A message-level byte sink that also supports QUIC stream control operations. -/// -/// This is the message-layer analog of [`quic::WriteStream`], combining DATA-frame -/// byte sinking with the underlying QUIC stream's [`CancelStream`] and -/// [`GetStreamId`] capabilities. -pub trait WriteMessageStream: - CancelStream + GetStreamId + Sink + Send -{ -} - -impl + Send + ?Sized> - WriteMessageStream for T -{ -} - -/// Boxed stream writer with QUIC stream control traits preserved. -pub type BoxMessageStreamWriter<'s> = SinkWriter>>; - -impl From for BoxMessageStreamWriter<'static> { - fn from(value: WriteStream) -> Self { - value.into_box_writer() - } -} - -// --------------------------------------------------------------------------- -// Unfold – custom sink unfold that preserves QUIC traits -// --------------------------------------------------------------------------- - -pin_project_lite::pin_project! { - #[project = StateProj] - #[project_replace = StateProjReplace] - #[derive(Debug)] - enum State { - Value { - value: T, - }, - Send { - #[pin] - future: S, - }, - Flush { - #[pin] - future: F, - }, - Close { - #[pin] - future: C, - }, - Empty, - } -} - -impl State { - fn take_value(self: Pin<&mut Self>) -> Option { - match &*self { - Self::Value { .. } => match self.project_replace(Self::Empty) { - StateProjReplace::Value { value } => Some(value), - _ => unreachable!(), - }, - _ => None, - } - } - - fn poll_value(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> - where - S: Future>, - F: Future>, - C: Future>, - { - match self.as_mut().project() { - StateProj::Send { future } => future.poll(cx), - StateProj::Flush { future } => future.poll(cx), - StateProj::Close { future } => future.poll(cx), - StateProj::Value { .. } | StateProj::Empty => { - Poll::Ready(Ok(self.take_value().expect("value lost, this is a bug"))) - } - } - } -} - -pin_project_lite::pin_project! { - #[derive(Debug)] - #[must_use = "sinks do nothing unless polled"] - pub struct Unfold { - send: S, - flush: F, - close: C, - #[pin] - state: State, - } -} - -impl Sink for Unfold -where - S: FnMut(T, Item) -> SF, - SF: Future>, - F: FnMut(T) -> FF, - FF: Future>, - C: FnMut(T) -> CF, - CF: Future>, -{ - type Error = E; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut project = self.project(); - let value = ready!(project.state.as_mut().poll_value(cx)?); - project.state.set(State::Value { value }); - Poll::Ready(Ok(())) - } - - fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { - let mut project = self.project(); - let future = match project.state.as_mut().take_value() { - Some(value) => (project.send)(value, item), - None => panic!("start_send called without poll_ready being called first"), - }; - project.state.set(State::Send { future }); - Ok(()) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - if matches!(self.deref().state, State::Flush { .. }) { - return self.poll_ready(cx); - } - - let mut project = self.as_mut().project(); - let value = ready!(project.state.as_mut().poll_value(cx))?; - project.state.set(State::Flush { - future: (project.flush)(value), - }); - } - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - if matches!(self.deref().state, State::Close { .. }) { - return self.poll_ready(cx); - } - - let mut project = self.as_mut().project(); - let value = ready!(project.state.as_mut().poll_value(cx))?; - project.state.set(State::Close { - future: (project.close)(value), - }); - } - } -} - -// --------------------------------------------------------------------------- -// QUIC control trait forwarding for Unfold -// --------------------------------------------------------------------------- - -impl GetStreamId for Unfold { - fn poll_stream_id( - self: Pin<&mut Self>, - cx: &mut Context, - ) -> Poll> { - let this = self.project(); - match this.state.project() { - StateProj::Value { value } => Pin::new(value).poll_stream_id(cx), - _ => Poll::Pending, - } - } -} - -impl CancelStream for Unfold { - fn poll_cancel( - self: Pin<&mut Self>, - cx: &mut Context, - code: VarInt, - ) -> Poll> { - let this = self.project(); - match this.state.project() { - StateProj::Value { value } => Pin::new(value).poll_cancel(cx, code), - _ => Poll::Pending, - } - } -} - -// --------------------------------------------------------------------------- -// Unfold constructor -// --------------------------------------------------------------------------- - -pub fn unfold( - init: T, - send: S, - flush: F, - close: C, -) -> Unfold -where - S: FnMut(T, Item) -> SF, - SF: Future>, - F: FnMut(T) -> FF, - FF: Future>, - C: FnMut(T) -> CF, - CF: Future>, -{ - let state = State::Value { value: init }; - Unfold { - send, - flush, - close, - state, - } -} - -// --------------------------------------------------------------------------- -// WriteStream conversion methods -// --------------------------------------------------------------------------- - -impl WriteStream { - pub fn as_bytes_sink(&mut self) -> impl WriteMessageStream + '_ { - unfold( - self, - async |stream: &mut WriteStream, buf: Bytes| { - stream.send_data(buf).await.map(|_| stream) - }, - async |stream: &mut WriteStream| stream.flush().await.map(|_| stream), - async |stream: &mut WriteStream| stream.close().await.map(|_| stream), - ) - } - - pub fn as_writer(&mut self) -> SinkWriter { - SinkWriter::new(self.as_bytes_sink()) - } - - pub fn as_box_writer(&mut self) -> BoxMessageStreamWriter<'_> { - SinkWriter::new(Box::pin(self.as_bytes_sink())) - } - - pub fn into_bytes_sink(self) -> impl WriteMessageStream { - unfold( - self, - async |mut stream: WriteStream, buf: Bytes| stream.send_data(buf).await.map(|_| stream), - async |mut stream: WriteStream| stream.flush().await.map(|_| stream), - async |mut stream: WriteStream| stream.close().await.map(|_| stream), - ) - } - - pub fn into_writer(self) -> SinkWriter { - SinkWriter::new(self.into_bytes_sink()) - } - - pub fn into_box_writer(self) -> BoxMessageStreamWriter<'static> { - SinkWriter::new(Box::pin(self.into_bytes_sink())) - } -} diff --git a/src/message/unify.rs b/src/message/unify.rs deleted file mode 100644 index ab1a540..0000000 --- a/src/message/unify.rs +++ /dev/null @@ -1,762 +0,0 @@ -use std::mem; - -use bytes::{Buf, Bytes}; -use http::{ - HeaderMap, - header::{InvalidHeaderName, InvalidHeaderValue}, -}; -use snafu::Snafu; - -use crate::{ - buflist::{BufList, BuflistCursor}, - codec::EncodeError, - connection, - dhttp::frame::Frame, - error::{Code, H3FrameUnexpected, H3MessageError}, - message::stream::{DEFAULT_COMPRESS_ALGO, MessageStreamError, ReadStream, WriteStream}, - qpack::{ - encoder::EncodeHeaderSectionError, - field::{FieldSection, MalformedHeaderSection, PseudoHeaders, malformed_header_section}, - }, -}; - -#[derive(Debug, Clone)] -enum Body { - Streaming { count: u64 }, - Chunked { buflist: BuflistCursor }, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub enum MessageStage { - /// Receiving/Sending header section, including interim response headers - Header = 0, - /// Receiving/Sending message body - Body = 1, - /// Receiving/Sending trailer section - Trailer = 2, - /// Message is completely sent/received - Complete = 3, - - /// Message struct is malformed - Malformed = 4, - /// Message struct is already taken/dropped - // State can be removed after async drop stabilizes - Dropped = 5, -} - -#[derive(Debug, Snafu)] -pub enum MalformedMessageError { - // === 状态相关错误 === - #[snafu(display("cannot modify header section after it has been sent"))] - HeaderAlreadySent, - #[snafu(display("cannot modify body while it is being sent"))] - BodyAlreadySending, - #[snafu(display("cannot replace body content while sending"))] - BodyReplacementDuringSend, - #[snafu(display("cannot modify trailer section after it has been sent"))] - TrailerAlreadySent, - #[snafu(display("cannot change body mode after transfer has started"))] - BodyModeChangeAfterTransferStarted, - - // === 模式不匹配错误 === - #[snafu(display("chunked body operation cannot be performed on streaming body"))] - ChunkedOperationOnStreamingBody, - #[snafu(display("streaming body operation cannot be performed on chunked body"))] - StreamingOperationOnChunkedBody, - - // === 协议语义错误 === - #[snafu(display("cannot send malformed pseudo header section"))] - MalformedPseudoHeader { source: MalformedHeaderSection }, - #[snafu(display("cannot set body or trailer for interim (1xx) response"))] - BodyOrTrailerOnInterimResponse, - #[snafu(display("cannot close response stream without sending a final response"))] - FinalResponseRequired, -} - -impl From for MalformedMessageError { - fn from(source: MalformedHeaderSection) -> Self { - MalformedMessageError::MalformedPseudoHeader { source } - } -} - -#[derive(Debug, Clone)] -pub struct Message { - header: FieldSection, - body: Body, - trailer: FieldSection, - - stage: MessageStage, -} - -#[derive(Debug, Snafu)] -pub enum InvalidHeader { - #[snafu(transparent)] - Name { source: InvalidHeaderName }, - #[snafu(transparent)] - Value { source: InvalidHeaderValue }, -} - -impl Message { - pub fn unresolved_request() -> Self { - Self { - header: FieldSection::header(PseudoHeaders::unresolved_request(), HeaderMap::default()), - body: Body::Streaming { count: 0 }, - trailer: FieldSection::trailer(HeaderMap::default()), - stage: MessageStage::Header, - } - } - - pub fn unresolved_response() -> Self { - Self { - header: FieldSection::header( - PseudoHeaders::unresolved_response(), - HeaderMap::default(), - ), - body: Body::Streaming { count: 0 }, - trailer: FieldSection::trailer(HeaderMap::default()), - stage: MessageStage::Header, - } - } - - pub fn is_request(&self) -> bool { - self.header.is_request_header() - } - - pub fn is_response(&self) -> bool { - self.header.is_response_header() - } - - pub fn streaming_body(&mut self) -> Result<&mut u64, MalformedMessageError> { - if let Body::Chunked { buflist } = &self.body { - match self.stage { - MessageStage::Header => {} - MessageStage::Body if !buflist.inner().has_remaining() => {} - _ => return Err(MalformedMessageError::BodyModeChangeAfterTransferStarted), - } - - self.body = Body::Streaming { count: 0 } - } - match &mut self.body { - Body::Streaming { count } => Ok(count), - Body::Chunked { .. } => unreachable!(), - } - } - - pub fn chunked_body(&mut self) -> Result<&mut BuflistCursor, MalformedMessageError> { - if let Body::Streaming { count } = &self.body { - match self.stage { - MessageStage::Header => { /* Ok to change mode: body unused */ } - MessageStage::Body if *count == 0 => { /* Ok to change mode: body unused */ } - _ => return Err(MalformedMessageError::BodyModeChangeAfterTransferStarted), - } - - self.body = Body::Chunked { - buflist: BuflistCursor::new(BufList::new()), - }; - } - match &mut self.body { - Body::Streaming { .. } => unreachable!(), - Body::Chunked { buflist } => Ok(buflist), - } - } - - pub fn is_interim_response(&self) -> bool { - self.is_response() - && self.header().check_pseudo().is_ok() - && self.header().status().is_informational() - } - - pub fn header_mut(&mut self) -> &mut FieldSection { - &mut self.header - } - - pub fn header(&self) -> &FieldSection { - &self.header - } - - pub fn is_streaming(&self) -> bool { - matches!(self.body, Body::Streaming { .. }) - } - - pub fn is_chunked(&self) -> bool { - matches!(self.body, Body::Chunked { .. }) - } - - /// Set body to buffer mode with given content - pub fn set_body(&mut self, mut content: impl Buf) { - let mut buflist = BufList::new(); - while content.has_remaining() { - buflist.write(content.copy_to_bytes(content.chunk().len())); - } - self.body = Body::Chunked { - buflist: BuflistCursor::new(buflist), - }; - } - - pub fn trailers(&self) -> &HeaderMap { - &self.trailer.header_map - } - - pub fn trailers_mut(&mut self) -> &mut HeaderMap { - &mut self.trailer.header_map - } - - pub fn stage(&self) -> MessageStage { - self.stage - } - - pub fn is_complete(&self) -> bool { - self.stage() == MessageStage::Complete - } - - pub fn is_dropped(&self) -> bool { - self.stage() == MessageStage::Dropped - } - - pub fn is_malformed(&self) -> bool { - self.stage() == MessageStage::Malformed - } - - pub fn set_malformed(&mut self) { - self.stage = MessageStage::Malformed; - } - - pub fn set_dropped(&mut self) { - self.stage = MessageStage::Dropped; - } - - /// Reset the message to unsent state - pub fn to_unsend(mut self) -> Self { - assert!(!self.is_dropped(), "cannot unsend a dropped message"); - self.stage = MessageStage::Header; - // reset cursor - if let Body::Chunked { buflist } = &mut self.body { - buflist.reset(); - } - self - } - - pub fn take(&mut self) -> Self { - assert!(!self.is_dropped(), "cannot take a dropped message"); - let message = if self.is_request() { - mem::replace(self, Self::unresolved_request()) - } else { - mem::replace(self, Self::unresolved_response()) - }; - self.stage = MessageStage::Dropped; - - message - } -} - -#[derive(Debug, Snafu)] -pub enum ReadToStringError { - #[snafu(transparent)] - Stream { source: MessageStreamError }, - #[snafu(transparent)] - Utf8 { source: std::string::FromUtf8Error }, -} - -fn message_used_after_dropped() -> ! { - unreachable!("Message used after destroyed, this is a bug"); -} - -impl ReadStream { - pub async fn try_message_io( - &mut self, - message: &mut Message, - f: impl AsyncFnOnce(&mut Self, &mut Message) -> Result, - ) -> Result { - self.try_stream_io(async move |this| { - let result = f(this, message).await; - if let Err(connection::StreamError::H3 { .. }) = &result { - message.set_malformed(); - } - result - }) - .await - } - - pub async fn read_message_header<'e>( - &mut self, - message: &'e mut Message, - ) -> Result<&'e FieldSection, MessageStreamError> { - match message.stage { - MessageStage::Header => {} - // header already read - MessageStage::Body | MessageStage::Trailer | MessageStage::Complete => { - return Ok(&message.header); - } - MessageStage::Malformed => { - return Err(MessageStreamError::MalformedIncomingMessage); - } - MessageStage::Dropped => message_used_after_dropped(), - } - - message.header = self - .try_message_io(message, async |this, message| { - let Some(field_section) = this.read_header_frame().await.transpose()? else { - if this.peek_frame().await.transpose()?.is_some() { - return Err(H3FrameUnexpected::UnexpectedFrameType.into()); - } else { - return Err(H3MessageError::MissingHeaderSection.into()); - } - }; - - field_section.check_pseudo()?; - if message.header.is_request_header() { - if !field_section.is_request_header() { - malformed_header_section::AbsenceOfMandatoryPseudoHeadersSnafu.fail()?; - } - } else { - debug_assert!(message.header.is_response_header()); - if !field_section.is_response_header() { - malformed_header_section::AbsenceOfMandatoryPseudoHeadersSnafu.fail()?; - } - } - Ok(field_section) - }) - .await?; - - // check header complete/valid - - if message.is_interim_response() { - message.stage = MessageStage::Header; - } else { - message.stage = MessageStage::Body; - } - Ok(&message.header) - } - - pub async fn read_message_body_chunk( - &mut self, - message: &mut Message, - ) -> Option> { - match message.stage { - MessageStage::Header => { - while message.stage == MessageStage::Header { - match self.read_message_header(message).await { - Ok(..) => (), - Err(error) => return Some(Err(error)), - } - } - debug_assert_eq!(message.stage, MessageStage::Body); - } - MessageStage::Body => {} - MessageStage::Trailer | MessageStage::Complete => { - match &mut message.body { - Body::Streaming { .. } => return None, - Body::Chunked { buflist } => { - if buflist.has_remaining() { - return Some(Ok(buflist.copy_to_bytes(buflist.chunk().len()))); - } - } - }; - } - MessageStage::Malformed => { - return Some(Err(MessageStreamError::MalformedIncomingMessage)); - } - MessageStage::Dropped => message_used_after_dropped(), - } - - let try_read_next_chunk = self.try_message_io(message, async |this, message| { - match this.read_data_frame_chunk().await.transpose()? { - Some(chunk) => Ok(Some(chunk)), - None => { - if this.peek_frame().await.transpose()?.is_some() { - message.stage = MessageStage::Trailer - } else { - message.stage = MessageStage::Complete - } - Ok(None) - } - } - }); - - match try_read_next_chunk.await { - Ok(Some(bytes)) => Some(Ok(bytes)), - Ok(None) => None, - Err(error) => Some(Err(error)), - } - } - - pub async fn read_message_full_body<'e>( - &mut self, - message: &'e mut Message, - ) -> Result { - enum Buffer<'e> { - Owned(BufList), - Borrow(&'e mut BuflistCursor), - } - - impl Buf for Buffer<'_> { - fn remaining(&self) -> usize { - match self { - Buffer::Owned(b) => b.remaining(), - Buffer::Borrow(b) => b.remaining(), - } - } - - fn has_remaining(&self) -> bool { - match self { - Buffer::Owned(b) => b.has_remaining(), - Buffer::Borrow(b) => b.has_remaining(), - } - } - - fn chunk(&self) -> &[u8] { - match self { - Buffer::Owned(b) => b.chunk(), - Buffer::Borrow(b) => b.chunk(), - } - } - - fn chunks_vectored<'a>(&'a self, dst: &mut [std::io::IoSlice<'a>]) -> usize { - match self { - Buffer::Owned(b) => b.chunks_vectored(dst), - Buffer::Borrow(b) => b.chunks_vectored(dst), - } - } - - fn advance(&mut self, cnt: usize) { - match self { - Buffer::Owned(b) => b.advance(cnt), - Buffer::Borrow(b) => b.advance(cnt), - } - } - - fn copy_to_bytes(&mut self, len: usize) -> bytes::Bytes { - match self { - Buffer::Owned(b) => b.copy_to_bytes(len), - Buffer::Borrow(b) => b.copy_to_bytes(len), - } - } - } - - match &mut message.body { - Body::Streaming { .. } => { - let mut buflist = BufList::new(); - while let Some(body_part) = - self.read_message_body_chunk(message).await.transpose()? - { - buflist.write(body_part); - } - Ok(Buffer::Owned(buflist)) - } - Body::Chunked { .. } => { - while let Some(body_part) = - self.read_message_body_chunk(message).await.transpose()? - { - message.chunked_body().expect("checked").write(body_part); - } - Ok(Buffer::Borrow(message.chunked_body().expect("checked"))) - } - } - } - - pub async fn read_message_body_to_bytes( - &mut self, - message: &mut Message, - ) -> Result { - let mut bytes = self.read_message_full_body(message).await?; - Ok(bytes.copy_to_bytes(bytes.remaining())) - } - - pub async fn read_message_body_to_string( - &mut self, - message: &mut Message, - ) -> Result { - // TODO: preallocate buffer with content-length - let mut vec = vec![]; - while let Some(bytes) = self.read_message(message).await.transpose()? { - vec.extend_from_slice(&bytes); - } - Ok(String::from_utf8(vec)?) - } - - pub async fn read_message( - &mut self, - message: &mut Message, - ) -> Option> { - loop { - match &mut message.body { - Body::Streaming { .. } => return self.read_message_body_chunk(message).await, - Body::Chunked { buflist } => { - if buflist.has_remaining() { - return Some(Ok(buflist.copy_to_bytes(buflist.chunk().len()))); - } - match self.read_message_body_chunk(message).await? { - Ok(body_part) => { - message.chunked_body().expect("checked").write(body_part); - continue; - } - Err(error) => return Some(Err(error)), - } - } - } - } - } - - pub async fn read_message_trailer<'e>( - &mut self, - message: &'e mut Message, - ) -> Result<&'e HeaderMap, MessageStreamError> { - match message.stage { - MessageStage::Header | MessageStage::Body => { - match &message.body { - Body::Streaming { .. } => { - // read and discard body - while let Some(_body_part) = - self.read_message_body_chunk(message).await.transpose()? - { - } - } - Body::Chunked { .. } => { - self.read_message_full_body(message).await?; - } - } - } - MessageStage::Trailer => {} - MessageStage::Complete => return Ok(message.trailers()), - MessageStage::Malformed => return Err(MessageStreamError::MalformedIncomingMessage), - MessageStage::Dropped => message_used_after_dropped(), - } - - message.trailer = self - .try_message_io(message, async |this, _| { - let Some(field_section) = this.read_header_frame().await.transpose()? else { - if this.peek_frame().await.transpose()?.is_some() { - return Err(H3FrameUnexpected::UnexpectedFrameDuringTrailer.into()); - } else { - // no trailer - return Ok(FieldSection::trailer(HeaderMap::new())); - } - }; - - if !field_section.is_trailer() { - return Err(MalformedHeaderSection::PseudoHeaderInTrailer.into()); - } - Ok(field_section) - }) - .await?; - - message.stage = MessageStage::Complete; - - Ok(message.trailers()) - } -} - -impl WriteStream { - pub async fn send_message_header( - &mut self, - message: &mut Message, - ) -> Result<(), MessageStreamError> { - match message.stage { - MessageStage::Header => {} - // header already sent - MessageStage::Body | MessageStage::Trailer | MessageStage::Complete => return Ok(()), - MessageStage::Malformed => { - return self.cancel(Code::H3_REQUEST_CANCELLED).await; - } - MessageStage::Dropped => message_used_after_dropped(), - } - - self.send_header(message.header.iter()).await?; - if message.is_interim_response() { - message.stage = MessageStage::Header; - } else { - message.stage = MessageStage::Body; - } - - Ok(()) - } - - pub async fn send_data(&mut self, data: impl Buf + Send) -> Result<(), MessageStreamError> { - let frame = Frame::new(Frame::DATA_FRAME_TYPE, data)?; - self.try_stream_io(async |this| Ok(this.send_frame(frame).await?)) - .await - } - - pub async fn send_message_streaming_body( - &mut self, - message: &mut Message, - content: impl Buf + Send, - ) -> Result<(), MessageStreamError> { - // if message.is_interim_response() { - // // malformed message - // } - // message.enable_streaming()?; // violate of set_body call - - match message.stage { - MessageStage::Header => { - self.send_message_header(message).await?; - debug_assert_eq!(message.stage, MessageStage::Body); - } - MessageStage::Body => {} - MessageStage::Trailer | MessageStage::Complete => { - /* this will cause H3_FRAME_UNEXPECTED error */ - } - MessageStage::Malformed => { - return self.cancel(Code::H3_REQUEST_CANCELLED).await; - } - MessageStage::Dropped => message_used_after_dropped(), - } - - let len = content.remaining(); - self.send_data(content).await?; - *message - .streaming_body() - .expect("call send_streaming_body on chunked body, this is a bug") += len as u64; - - Ok(()) - } - - pub async fn send_message_chunked_body( - &mut self, - message: &mut Message, - ) -> Result<(), MessageStreamError> { - // if message.is_interim_response() { - // // malformed message - // } - - match message.stage { - MessageStage::Header => { - self.send_message_header(message).await?; - debug_assert_eq!(message.stage, MessageStage::Body); - } - MessageStage::Body => {} - MessageStage::Trailer | MessageStage::Complete => { - /* this will cause H3_FRAME_UNEXPECTED error */ - } - MessageStage::Malformed => { - return self.cancel(Code::H3_REQUEST_CANCELLED).await; - } - MessageStage::Dropped => message_used_after_dropped(), - } - - let Body::Chunked { buflist } = &mut message.body else { - unreachable!("Call send_chunked_body on non-chunked body, this is a bug"); - }; - - while buflist.has_remaining() { - let bytes = buflist.copy_to_bytes(buflist.chunk().len()); - let frame = Frame::new(Frame::DATA_FRAME_TYPE, bytes)?; - self.try_stream_io(async |this| Ok(this.send_frame(frame).await?)) - .await?; - } - - message.stage = MessageStage::Trailer; - Ok(()) - } - - pub async fn send_message_trailer( - &mut self, - message: &mut Message, - ) -> Result<(), MessageStreamError> { - // prepare - match message.stage { - MessageStage::Header | MessageStage::Body => { - if message.is_chunked() { - self.send_message_chunked_body(message).await?; - debug_assert_eq!(message.stage, MessageStage::Trailer); - } else { - self.send_message_header(message).await?; - debug_assert_eq!(message.stage, MessageStage::Body); - // no body or streaming body already sent - } - } - MessageStage::Trailer => {} - MessageStage::Complete => return Ok(()), - MessageStage::Malformed => { - return self.cancel(Code::H3_REQUEST_CANCELLED).await; - } - MessageStage::Dropped => message_used_after_dropped(), - } - - // TODO: check FieldLines size - let field_lines = message.trailer.iter(); - let algo = &DEFAULT_COMPRESS_ALGO; - - let result = self - .try_stream_io(async move |this| { - let stream = &mut this.stream; - match this.qpack_encoder.encode(field_lines, algo, stream).await { - Ok(frame) => Ok(Ok(this.send_frame(frame).await?)), - Err(EncodeHeaderSectionError::Encode { source }) => Ok(Err(source)), - Err(EncodeHeaderSectionError::Stream { source }) => Err(source), - } - }) - .await?; - - // Flush encoder instructions (dynamic table insertions) to the encoder stream. - // Encoder stream errors are connection-level: reset = connection error per RFC 9204. - if let Err(error) = self.qpack_encoder.flush_instructions().await { - let quic_error = self.handle_stream_error(error).await; - return Err(quic_error.into()); - } - - match result { - Ok(()) => { - message.stage = MessageStage::Complete; - Ok(()) - } - Err(error) => match error { - EncodeError::FramePayloadTooLarge => Err(MessageStreamError::TrailerTooLarge), - EncodeError::HuffmanEncoding => { - unreachable!("FieldSection contain invalid header name/value, this is a bug") - } - }, - } - } - - pub async fn send_message(&mut self, message: &mut Message) -> Result<(), MessageStreamError> { - match message.stage { - MessageStage::Header | MessageStage::Body => { - if message.header().is_empty() { - return Ok(()); - } - if message.is_chunked() { - self.send_message_chunked_body(message).await?; - debug_assert_eq!(message.stage, MessageStage::Trailer); - } else { - self.send_message_header(message).await?; - if message.is_interim_response() { - debug_assert!(message.stage == MessageStage::Header); - } else { - debug_assert!(message.stage == MessageStage::Body); - } - } - - if !message.trailers().is_empty() { - self.send_message_trailer(message).await?; - debug_assert_eq!(message.stage, MessageStage::Complete) - } - } - MessageStage::Trailer => { - if !message.trailers().is_empty() { - self.send_message_trailer(message).await?; - debug_assert_eq!(message.stage, MessageStage::Complete) - } else { - // no trailer to send - } - } - MessageStage::Complete => {} - MessageStage::Malformed => { - return self.cancel(Code::H3_REQUEST_CANCELLED).await; - } - MessageStage::Dropped => message_used_after_dropped(), - } - - Ok(()) - } - - pub async fn flush_message(&mut self, message: &mut Message) -> Result<(), MessageStreamError> { - self.send_message(message).await?; - self.flush().await - } - - pub async fn close_message(&mut self, message: &mut Message) -> Result<(), MessageStreamError> { - self.send_message(message).await?; - self.close().await - } -} diff --git a/src/pool.rs b/src/pool.rs index 9b30d27..56e77fc 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,8 +1,4 @@ -use std::{ - error::Error, - pin::pin, - sync::{Arc, LazyLock}, -}; +use std::{error::Error, pin::pin, sync::Arc}; use dashmap::DashMap; use futures::{StreamExt, never::Never}; @@ -24,8 +20,8 @@ pub struct ReuseableConnection { task: AsyncMutex>>, } -type ConnectionIdentifier = (Authority, Arc>); -type ReuseableConnections = DashMap, Arc>>; +type ConnectionIdentifier = Authority; +type ReuseableConnections = DashMap>>; impl ReuseableConnection { pub fn pending() -> Self { @@ -78,7 +74,7 @@ impl ReuseableConnection { #[derive(Debug)] pub struct Pool { - connections: Arc>, + pub(crate) connections: Arc>, } impl Clone for Pool { @@ -96,16 +92,19 @@ impl Pool { } } - pub fn global() -> &'static Self { - use std::any::{Any, TypeId}; + /// Clear all cached connections from the pool. + pub fn clear(&self) { + self.connections.clear(); + } + + /// Return the number of cached connections. + pub fn len(&self) -> usize { + self.connections.len() + } - static POOLS: LazyLock> = - LazyLock::new(DashMap::new); - POOLS - .entry(TypeId::of::()) - .or_insert_with(|| Box::leak(Box::new(Pool::::empty()))) - .downcast_ref::>() - .expect("type id collision") + /// Return `true` if the pool contains no connections. + pub fn is_empty(&self) -> bool { + self.connections.is_empty() } } @@ -144,7 +143,7 @@ pub enum InsertError { } impl Pool { - fn spawn_try_release(self, identify: ConnectionIdentifier) { + fn spawn_try_release(self, identify: Authority) { tokio::spawn( async move { (self.connections.as_ref()).remove_if(&identify, |_, connection| { @@ -159,7 +158,7 @@ impl Pool { ); } - #[tracing::instrument(level = "debug", skip(self, connector), err)] + #[tracing::instrument(level = "debug", skip(self, connector))] pub async fn reuse_or_connect_with( &self, connector: &Client, @@ -171,7 +170,7 @@ impl Pool { { let reuseable_connection = self .connections - .entry((server.clone(), builder.clone())) + .entry(server.clone()) .or_insert_with(|| Arc::new(ReuseableConnection::pending())) .clone(); // break borrow of dashmap::Entry to avoid deadlock @@ -194,8 +193,8 @@ impl Pool { let connection = builder.build(quic_conn).await?; tracing::trace!("h3 connection established, verifying peer identity"); - let remote_agent = connection.remote_agent().await?; - let actual_peer_name = remote_agent.as_ref().map(|agent| agent.name()); + let authority = connection.remote_authority().await?; + let actual_peer_name = authority.as_ref().map(|authority| authority.name()); if actual_peer_name.as_ref() != Some(&server.host()) { return connect_error::IncorrectIdentitySnafu { expected: server.host().to_string(), @@ -210,10 +209,9 @@ impl Pool { let connection = connection.clone(); let pool = self.clone(); let server = server.clone(); - let builder = builder.clone(); async move { connection.closed().await; - pool.spawn_try_release((server, builder)); + pool.spawn_try_release(server); } .in_current_span() })); @@ -235,29 +233,25 @@ impl Pool { match &result { Ok(..) => tracing::trace!("connection ready to use"), - Err(..) => self.clone().spawn_try_release((server, builder.clone())), + Err(..) => self.clone().spawn_try_release(server), } result } - pub async fn try_insert( - &self, - connection: Arc>, - builder: Arc>, - ) -> Result<(), InsertError> { - let remote_agent = connection - .remote_agent() + pub async fn try_insert(&self, connection: Arc>) -> Result<(), InsertError> { + let authority = connection + .remote_authority() .await? .context(insert_error::MissingIdentitySnafu)?; - let client = remote_agent + let client: Authority = authority .name() .parse() .ok() .context(insert_error::InvalidIdentitySnafu)?; - let identity = (client, builder.clone()); + let identity = client; let reuseable_connection = self .connections .entry(identity.clone()) @@ -283,72 +277,34 @@ impl Pool { #[cfg(test)] mod tests { - use std::sync::Arc; - #[cfg(feature = "dquic")] use std::{ - collections::hash_map::DefaultHasher, - hash::{Hash, Hasher}, + error::Error as StdError, + io, + sync::{ + Arc, Mutex, + atomic::{AtomicUsize, Ordering}, + }, + time::Duration, }; + use dhttp_identity::identity::RemoteAuthority; + use http::uri::Authority; + use tokio::sync::Semaphore; use tokio_util::task::AbortOnDropHandle; + use tracing::Level; - use super::ReuseableConnection; - #[cfg(feature = "dquic")] + use super::{ConnectError, InsertError, Pool, ReuseableConnection}; use crate::{ - connection::ConnectionBuilder, - dhttp::settings::{MaxFieldSectionSize, Settings}, - }; - use crate::{ - connection::{Connection, ConnectionState}, - dhttp::{goaway::Goaway, protocol::DHttpProtocol}, + connection::{Connection, ConnectionBuilder, ConnectionState}, + dhttp::{ + goaway::Goaway, + protocol::DHttpProtocol, + settings::{Setting, Settings}, + }, quic, varint::VarInt, }; - #[cfg(feature = "dquic")] - fn hash_of(val: &T) -> u64 { - let mut hasher = DefaultHasher::new(); - val.hash(&mut hasher); - hasher.finish() - } - - #[cfg(feature = "dquic")] - type C = dquic::prelude::Connection; - - #[cfg(feature = "dquic")] - #[test] - fn pool_key_different_builders_different_entries() { - let s1 = Arc::new(Settings::default()); - let mut s2_inner = Settings::default(); - s2_inner.set(MaxFieldSectionSize::setting(VarInt::from_u32(9999))); - let s2 = Arc::new(s2_inner); - - let builder_a = ConnectionBuilder::::new(s1); - let builder_b = ConnectionBuilder::::new(s2); - - let key_a = hash_of(&builder_a); - let key_b = hash_of(&builder_b); - assert_ne!( - key_a, key_b, - "different protocol stacks must produce different pool keys" - ); - } - - #[cfg(feature = "dquic")] - #[test] - fn pool_key_same_builder_same_entry() { - let s = Arc::new(Settings::default()); - let builder_a = ConnectionBuilder::::new(s.clone()); - let builder_b = ConnectionBuilder::::new(s); - - let key_a = hash_of(&builder_a); - let key_b = hash_of(&builder_b); - assert_eq!( - key_a, key_b, - "identical protocol stacks must produce the same pool key" - ); - } - fn test_connection_error(reason: &str) -> quic::ConnectionError { quic::ConnectionError::Transport { source: quic::TransportError { @@ -363,6 +319,54 @@ mod tests { AbortOnDropHandle::new(tokio::spawn(async {})) } + #[derive(Clone)] + struct SharedWriter(Arc>>); + + impl io::Write for SharedWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.0.lock().expect("log buffer poisoned").extend(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } + } + + fn matching_server() -> Authority { + "test-remote:443".parse().unwrap() + } + + fn inserted_identity() -> Authority { + "test-remote".parse().unwrap() + } + + fn alternate_builder() -> Arc> { + let mut settings = Settings::default(); + settings.set(Setting::new(VarInt::from_u32(0x21), VarInt::from_u32(0x07))); + Arc::new(ConnectionBuilder::new(Arc::new(settings))) + } + + fn mock_connection( + quic: crate::connection::tests::MockConnection, + ) -> Connection { + Connection::from_state_for_test(ConnectionState::new_for_test( + Arc::new(quic), + Arc::new(crate::protocol::Protocols::new()), + )) + } + + fn mock_connection_with_dhttp( + quic: crate::connection::tests::MockConnection, + ) -> Connection { + let mut protocols = crate::protocol::Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(Arc::new(quic.clone()))); + Connection::from_state_for_test(ConnectionState::new_for_test( + Arc::new(quic), + Arc::new(protocols), + )) + } + async fn reusable_connection( connection: Connection, ) -> Arc> { @@ -371,6 +375,405 @@ mod tests { reusable } + async fn wait_until(label: &str, f: impl Fn() -> bool) { + tokio::time::timeout(Duration::from_secs(1), async { + while !f() { + tokio::task::yield_now().await; + } + }) + .await + .unwrap_or_else(|_| panic!("timed out waiting for {label}")); + } + + #[derive(Debug, Clone)] + struct TestConnector { + state: Arc, + } + + #[derive(Debug)] + struct TestConnectorState { + calls: AtomicUsize, + error_message: Option<&'static str>, + returned_quics: Mutex>, + gate_first_call: Option>, + } + + impl TestConnector { + fn succeed() -> Self { + Self { + state: Arc::new(TestConnectorState { + calls: AtomicUsize::new(0), + error_message: None, + returned_quics: Mutex::new(Vec::new()), + gate_first_call: None, + }), + } + } + + fn succeed_with_first_call_gate() -> (Self, Arc) { + let gate = Arc::new(Semaphore::new(0)); + ( + Self { + state: Arc::new(TestConnectorState { + calls: AtomicUsize::new(0), + error_message: None, + returned_quics: Mutex::new(Vec::new()), + gate_first_call: Some(gate.clone()), + }), + }, + gate, + ) + } + + fn fail(message: &'static str) -> Self { + Self { + state: Arc::new(TestConnectorState { + calls: AtomicUsize::new(0), + error_message: Some(message), + returned_quics: Mutex::new(Vec::new()), + gate_first_call: None, + }), + } + } + + fn call_count(&self) -> usize { + self.state.calls.load(Ordering::SeqCst) + } + + fn returned_quics(&self) -> Vec { + self.state + .returned_quics + .lock() + .expect("returned quics log poisoned") + .clone() + } + } + + impl quic::Connect for TestConnector { + type Connection = crate::connection::tests::MockConnection; + type Error = io::Error; + + async fn connect(&self, _server: &Authority) -> Result, Self::Error> { + let call = self.state.calls.fetch_add(1, Ordering::SeqCst); + if call == 0 + && let Some(gate) = &self.state.gate_first_call + { + gate.acquire() + .await + .expect("gate should not be closed") + .forget(); + } + if let Some(message) = self.state.error_message { + return Err(io::Error::other(message)); + } + + let quic = crate::connection::tests::MockConnection::new(); + quic.enable_stream_ops(); + self.state + .returned_quics + .lock() + .expect("returned quics log poisoned") + .push(quic.clone()); + Ok(Arc::new(quic)) + } + } + + #[derive(Debug)] + struct NamedRemoteAuthority(&'static str); + + impl RemoteAuthority for NamedRemoteAuthority { + fn name(&self) -> &str { + self.0 + } + + fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { + &[] + } + } + + #[derive(Debug, Clone)] + struct IdentityOverrideConnection { + inner: crate::connection::tests::MockConnection, + remote_name: Option<&'static str>, + } + + impl IdentityOverrideConnection { + fn new(remote_name: Option<&'static str>) -> Self { + Self { + inner: crate::connection::tests::MockConnection::new(), + remote_name, + } + } + + fn with_stream_ops(remote_name: Option<&'static str>) -> Self { + let inner = crate::connection::tests::MockConnection::new(); + inner.enable_stream_ops(); + Self { inner, remote_name } + } + } + + impl quic::ManageStream for IdentityOverrideConnection { + type StreamReader = crate::connection::tests::TestReadStream; + type StreamWriter = crate::connection::tests::TestWriteStream; + + async fn open_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + quic::ManageStream::open_bi(&self.inner).await + } + + async fn open_uni(&self) -> Result { + quic::ManageStream::open_uni(&self.inner).await + } + + async fn accept_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + quic::ManageStream::accept_bi(&self.inner).await + } + + async fn accept_uni(&self) -> Result { + quic::ManageStream::accept_uni(&self.inner).await + } + } + + impl quic::WithLocalAuthority for IdentityOverrideConnection { + type LocalAuthority = crate::connection::tests::TestLocalAuthority; + + async fn local_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(Some(crate::connection::tests::TestLocalAuthority)) + } + } + + impl quic::WithRemoteAuthority for IdentityOverrideConnection { + type RemoteAuthority = NamedRemoteAuthority; + + async fn remote_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(self.remote_name.map(NamedRemoteAuthority)) + } + } + + impl quic::Lifecycle for IdentityOverrideConnection { + fn close(&self, code: crate::error::Code, reason: std::borrow::Cow<'static, str>) { + quic::Lifecycle::close(&self.inner, code, reason); + } + + fn check(&self) -> Result<(), quic::ConnectionError> { + quic::Lifecycle::check(&self.inner) + } + + async fn closed(&self) -> quic::ConnectionError { + quic::Lifecycle::closed(&self.inner).await + } + } + + #[derive(Debug, Clone)] + struct IdentityConnector { + connection: IdentityOverrideConnection, + } + + impl IdentityConnector { + fn new(remote_name: Option<&'static str>) -> Self { + Self { + connection: IdentityOverrideConnection::with_stream_ops(remote_name), + } + } + } + + impl quic::Connect for IdentityConnector { + type Connection = IdentityOverrideConnection; + type Error = io::Error; + + async fn connect(&self, _server: &Authority) -> Result, Self::Error> { + Ok(Arc::new(self.connection.clone())) + } + } + + #[derive(Debug)] + struct UnavailableStreamsConnector; + + impl quic::Connect for UnavailableStreamsConnector { + type Connection = crate::connection::tests::MockConnection; + type Error = io::Error; + + async fn connect(&self, _server: &Authority) -> Result, Self::Error> { + Ok(Arc::new(crate::connection::tests::MockConnection::new())) + } + } + + #[test] + fn reusable_connection_pending_starts_empty() { + let reusable = ReuseableConnection::::pending(); + + assert!(reusable.peek().is_none()); + assert!(reusable.reuse().is_none()); + } + + #[tokio::test] + async fn reusable_connection_insert_with_publishes_connection() { + let expected = Arc::new(mock_connection_with_dhttp( + crate::connection::tests::MockConnection::new(), + )); + let reusable = ReuseableConnection::pending(); + + reusable + .insert_with(async || (expected.clone(), abort_handle())) + .await; + + assert!(Arc::ptr_eq( + &reusable.peek().expect("connection should be visible"), + &expected, + )); + assert!(Arc::ptr_eq( + &reusable.reuse().expect("connection should be reusable"), + &expected, + )); + } + + #[tokio::test] + async fn reusable_connection_try_insert_with_replaces_existing_connection() { + let old = Arc::new(mock_connection_with_dhttp( + crate::connection::tests::MockConnection::new(), + )); + let new = Arc::new(mock_connection_with_dhttp( + crate::connection::tests::MockConnection::new(), + )); + let reusable = ReuseableConnection::pending(); + reusable.insert(old.clone(), abort_handle()).await; + + reusable + .try_insert_with::(async || Ok((new.clone(), abort_handle()))) + .await + .expect("replacement should succeed"); + + let reused = reusable + .reuse() + .expect("replacement should become reusable"); + assert!(Arc::ptr_eq(&reused, &new)); + assert!(!Arc::ptr_eq(&reused, &old)); + } + + #[test] + fn pool_default_clone_shares_entries_and_clear() { + let pool = Pool::::default(); + assert!(pool.is_empty()); + + let cloned = pool.clone(); + let auth: Authority = "example.com:443".parse().unwrap(); + cloned + .connections + .entry(auth) + .or_insert_with(|| Arc::new(ReuseableConnection::pending())); + + assert_eq!(pool.len(), 1); + assert!(!cloned.is_empty()); + + pool.clear(); + assert!(cloned.is_empty()); + } + + #[test] + fn connect_error_display_and_sources_are_semantic() { + let connector: ConnectError = ConnectError::Connector { + source: io::Error::other("dial failed"), + }; + assert_eq!( + connector.to_string(), + "failed to initialize QUIC connection" + ); + assert_eq!( + StdError::source(&connector) + .expect("connector error should preserve source") + .to_string(), + "dial failed" + ); + + let anonymous: ConnectError = ConnectError::IncorrectIdentity { + expected: "expected.example".to_owned(), + actual: None, + }; + assert_eq!( + anonymous.to_string(), + "peer name mismatch: expected expected.example, actual " + ); + assert!(StdError::source(&anonymous).is_none()); + + let named: ConnectError = ConnectError::IncorrectIdentity { + expected: "expected.example".to_owned(), + actual: Some("actual.example".to_owned()), + }; + assert_eq!( + named.to_string(), + "peer name mismatch: expected expected.example, actual actual.example" + ); + assert!(StdError::source(&named).is_none()); + + let source = test_connection_error("h3 failed"); + let source_display = source.to_string(); + let h3: ConnectError = ConnectError::H3 { source }; + assert_eq!(h3.to_string(), source_display); + } + + #[test] + fn insert_error_display_and_sources_are_semantic() { + let missing = InsertError::MissingIdentity; + assert_eq!(missing.to_string(), "peer does not provide identity"); + assert!(StdError::source(&missing).is_none()); + + let invalid = InsertError::InvalidIdentity; + assert_eq!( + invalid.to_string(), + "peer provided invalid identity (cannot be parsed as Authority)" + ); + assert!(StdError::source(&invalid).is_none()); + + let source = test_connection_error("insert failed"); + let source_display = source.to_string(); + let quic = InsertError::Quic { source }; + assert_eq!(quic.to_string(), source_display); + } + + #[test] + fn test_pool_key_is_authority_only() { + let pool = Pool::::empty(); + let auth: Authority = "example.com:443".parse().unwrap(); + + pool.connections + .entry(auth.clone()) + .or_insert_with(|| Arc::new(ReuseableConnection::pending())); + pool.connections + .entry(auth) + .or_insert_with(|| Arc::new(ReuseableConnection::pending())); + assert_eq!( + pool.connections.len(), + 1, + "builder must not affect pool key; only authority matters", + ); + } + + #[test] + fn test_pool_key_different_authority_different_entry() { + let pool = Pool::::empty(); + let auth1: Authority = "example.com:443".parse().unwrap(); + let auth2: Authority = "other.com:443".parse().unwrap(); + + pool.connections + .entry(auth1) + .or_insert_with(|| Arc::new(ReuseableConnection::pending())); + pool.connections + .entry(auth2) + .or_insert_with(|| Arc::new(ReuseableConnection::pending())); + assert_eq!( + pool.connections.len(), + 2, + "different authorities must produce different pool entries", + ); + } + #[tokio::test] async fn reuse_returns_none_when_connection_is_unhealthy() { let quic = crate::connection::tests::MockConnection::new(); @@ -401,4 +804,516 @@ mod tests { let reusable = reusable_connection(Connection::from_state_for_test(state)).await; assert!(reusable.reuse().is_none()); } + + #[tokio::test] + async fn reusable_connection_try_insert_with_error_preserves_existing_connection() { + #[derive(Debug, PartialEq, Eq)] + struct TestInsertFailure; + + let quic = crate::connection::tests::MockConnection::new(); + let expected = Arc::new(mock_connection_with_dhttp(quic)); + let reusable = Arc::new(ReuseableConnection::pending()); + + reusable.insert(expected.clone(), abort_handle()).await; + + let error = reusable + .try_insert_with(async || { + Err::< + ( + Arc>, + AbortOnDropHandle<()>, + ), + _, + >(TestInsertFailure) + }) + .await + .expect_err("insertion should fail"); + + assert_eq!(error, TestInsertFailure); + assert!(Arc::ptr_eq( + &reusable + .reuse() + .expect("existing connection should remain reusable"), + &expected, + )); + } + + #[tokio::test] + async fn pool_reuses_same_authority_even_with_different_builders() { + let pool = Pool::::empty(); + let connector = TestConnector::succeed(); + let server = matching_server(); + + let first = pool + .reuse_or_connect_with( + &connector, + Arc::new(ConnectionBuilder::default()), + server.clone(), + ) + .await + .expect("first connect should succeed"); + let second = pool + .reuse_or_connect_with(&connector, alternate_builder(), server) + .await + .expect("second connect should reuse"); + + assert!(Arc::ptr_eq(&first, &second)); + assert_eq!(connector.call_count(), 1); + assert_eq!(pool.len(), 1); + } + + #[tokio::test] + async fn pool_coordinates_in_flight_connect_attempts() { + let pool = Pool::::empty(); + let (connector, first_call_gate) = TestConnector::succeed_with_first_call_gate(); + let server = matching_server(); + let builder = Arc::new(ConnectionBuilder::default()); + let first = pool.reuse_or_connect_with(&connector, builder.clone(), server.clone()); + let mut first = std::pin::pin!(first); + assert!( + futures::poll!(first.as_mut()).is_pending(), + "gated first connect should remain pending until released", + ); + assert_eq!( + connector.call_count(), + 1, + "first poll should start one dial" + ); + assert_eq!( + pool.len(), + 1, + "pending entry should exist while first call is gated" + ); + + let second = pool.reuse_or_connect_with(&connector, builder, server.clone()); + let mut second = std::pin::pin!(second); + assert!( + futures::poll!(second.as_mut()).is_pending(), + "second caller should pend while the first connection attempt is in flight", + ); + + wait_until("second waiter to attach to the same pending entry", || { + let Some(entry) = pool.connections.get(&server) else { + return false; + }; + Arc::strong_count(entry.value()) >= 3 + }) + .await; + + assert_eq!( + connector.call_count(), + 1, + "second in-flight caller should wait on the pending entry instead of dialing again", + ); + + first_call_gate.add_permits(1); + + let deadline = std::time::Instant::now() + Duration::from_secs(1); + let (first, second) = loop { + let first_poll = futures::poll!(first.as_mut()); + let second_poll = futures::poll!(second.as_mut()); + + match (first_poll, second_poll) { + (std::task::Poll::Ready(first), std::task::Poll::Ready(second)) => { + break ( + first.expect("first connect should succeed"), + second.expect("second waiter should reuse the same connection"), + ); + } + (std::task::Poll::Ready(_), std::task::Poll::Pending) + if std::time::Instant::now() >= deadline => + { + panic!("second in-flight waiter stalled after the first connection completed"); + } + (std::task::Poll::Pending, _) if std::time::Instant::now() >= deadline => { + panic!("first in-flight connection attempt stalled before becoming reusable"); + } + _ => tokio::task::yield_now().await, + } + }; + + assert!(Arc::ptr_eq(&first, &second)); + assert_eq!(connector.call_count(), 1); + assert_eq!(pool.len(), 1); + } + + #[tokio::test] + async fn pool_returns_connector_error() { + let pool = Pool::::empty(); + let connector = TestConnector::fail("dial failed"); + let error = pool + .reuse_or_connect_with( + &connector, + Arc::new(ConnectionBuilder::default()), + matching_server(), + ) + .await + .expect_err("connect should fail"); + + match error { + ConnectError::Connector { source } => { + assert_eq!(source.to_string(), "dial failed"); + } + other => panic!("expected connector error, got {other:?}"), + } + } + + #[tokio::test] + async fn pool_rejects_incorrect_peer_identity() { + let pool = Pool::::empty(); + let connector = TestConnector::succeed(); + let server: Authority = "expected.example:443".parse().unwrap(); + + let error = pool + .reuse_or_connect_with(&connector, Arc::new(ConnectionBuilder::default()), server) + .await + .expect_err("identity mismatch should fail"); + + match error { + ConnectError::IncorrectIdentity { expected, actual } => { + assert_eq!(expected, "expected.example"); + assert_eq!(actual.as_deref(), Some("test-remote")); + } + other => panic!("expected identity error, got {other:?}"), + } + } + + #[tokio::test] + async fn pool_rejects_anonymous_peer_identity() { + let pool = Pool::::empty(); + let connector = IdentityConnector::new(None); + let server = matching_server(); + + let error = pool + .reuse_or_connect_with(&connector, Arc::new(ConnectionBuilder::default()), server) + .await + .expect_err("anonymous peer should fail identity verification"); + + match error { + ConnectError::IncorrectIdentity { expected, actual } => { + assert_eq!(expected, "test-remote"); + assert!(actual.is_none()); + } + other => panic!("expected identity error, got {other:?}"), + } + } + + #[tokio::test] + async fn pool_accepts_matching_identity_override_connection() { + let pool = Pool::::empty(); + let connector = IdentityConnector::new(Some("test-remote")); + let server = matching_server(); + + let connection = pool + .reuse_or_connect_with( + &connector, + Arc::new(ConnectionBuilder::default()), + server.clone(), + ) + .await + .expect("matching peer identity should connect"); + + let reused = pool + .reuse_or_connect_with(&connector, Arc::new(ConnectionBuilder::default()), server) + .await + .expect("matching peer identity should reuse"); + + assert!(Arc::ptr_eq(&connection, &reused)); + assert_eq!(pool.len(), 1); + } + + #[tokio::test] + async fn identity_override_connection_delegates_quic_capabilities() { + let connection = IdentityOverrideConnection::with_stream_ops(Some("delegated.example")); + + quic::ManageStream::open_bi(&connection) + .await + .expect("open_bi should delegate to inner connection"); + quic::ManageStream::open_uni(&connection) + .await + .expect("open_uni should delegate to inner connection"); + quic::ManageStream::accept_bi(&connection) + .await + .expect("accept_bi should delegate to inner connection"); + quic::ManageStream::accept_uni(&connection) + .await + .expect("accept_uni should delegate to inner connection"); + + let local = quic::WithLocalAuthority::local_authority(&connection) + .await + .expect("local authority lookup should succeed"); + assert!(local.is_some()); + + let remote = quic::WithRemoteAuthority::remote_authority(&connection) + .await + .expect("remote authority lookup should succeed") + .expect("remote authority should be present"); + assert_eq!(remote.name(), "delegated.example"); + assert!(remote.cert_chain().is_empty()); + + assert!(quic::Lifecycle::check(&connection).is_ok()); + quic::Lifecycle::close( + &connection, + crate::error::Code::H3_NO_ERROR, + "delegated close".into(), + ); + assert_eq!( + connection.inner.close_calls(), + vec![( + crate::error::Code::H3_NO_ERROR, + "delegated close".to_owned() + )], + ); + + connection + .inner + .set_terminal_error(test_connection_error("delegated terminal")); + let closed = quic::Lifecycle::closed(&connection).await; + assert!(closed.to_string().contains("delegated terminal")); + assert_eq!( + connection.inner.stream_calls(), + vec!["open_bi", "open_uni", "accept_bi", "accept_uni"], + ); + } + + #[tokio::test] + async fn pool_returns_h3_error_when_connection_initialization_fails() { + let pool = Pool::::empty(); + + let error = pool + .reuse_or_connect_with( + &UnavailableStreamsConnector, + Arc::new(ConnectionBuilder::default()), + matching_server(), + ) + .await + .expect_err("builder should fail when stream operations are unavailable"); + + assert!(matches!(error, ConnectError::H3 { .. })); + } + + #[tokio::test] + async fn pool_replaces_closed_connection_with_new_one() { + let pool = Pool::::empty(); + let connector = TestConnector::succeed(); + let server = matching_server(); + let builder = Arc::new(ConnectionBuilder::default()); + + let first = pool + .reuse_or_connect_with(&connector, builder.clone(), server.clone()) + .await + .expect("first connect should succeed"); + let first_quic = connector + .returned_quics() + .into_iter() + .next() + .expect("first connection should be recorded"); + first_quic.set_terminal_error(test_connection_error("closed")); + + let second = pool + .reuse_or_connect_with(&connector, builder, server) + .await + .expect("second connect should replace the dead connection"); + + assert!(!Arc::ptr_eq(&first, &second)); + assert_eq!(connector.call_count(), 2); + assert_eq!(pool.len(), 1); + } + + #[tokio::test] + async fn try_insert_makes_connection_reusable_under_remote_identity() { + let pool = Pool::::empty(); + let connection = Arc::new(mock_connection_with_dhttp( + crate::connection::tests::MockConnection::new(), + )); + + pool.try_insert(connection.clone()) + .await + .expect("insert should succeed"); + + let reusable = pool + .connections + .get(&inserted_identity()) + .expect("entry should be keyed by remote identity"); + assert!(Arc::ptr_eq( + &reusable.reuse().expect("connection should be reusable"), + &connection, + )); + } + + #[tokio::test] + async fn clear_drops_cached_connection_and_aborts_release_task() { + let pool = Pool::::empty(); + let quic = crate::connection::tests::MockConnection::new(); + + pool.try_insert(Arc::new(mock_connection_with_dhttp(quic.clone()))) + .await + .expect("insert should succeed"); + + pool.clear(); + + wait_until("cached connection drop after clear", || { + !quic.close_calls().is_empty() + }) + .await; + assert!(pool.is_empty()); + } + + #[tokio::test(flavor = "current_thread")] + async fn pool_error_path_does_not_emit_automatic_instrument_error_event() { + let captured = Arc::new(Mutex::new(Vec::new())); + let subscriber = tracing_subscriber::fmt() + .with_ansi(false) + .with_max_level(Level::DEBUG) + .with_writer({ + let captured = Arc::clone(&captured); + move || SharedWriter(Arc::clone(&captured)) + }) + .finish(); + let dispatch = tracing::Dispatch::new(subscriber); + let _guard = tracing::dispatcher::set_default(&dispatch); + + let pool = Pool::::empty(); + let connector = TestConnector::fail("dial failed"); + let error = pool + .reuse_or_connect_with( + &connector, + Arc::new(ConnectionBuilder::default()), + matching_server(), + ) + .await + .expect_err("connector should fail"); + + assert!(matches!(error, ConnectError::Connector { .. })); + + let output = String::from_utf8(captured.lock().expect("log buffer poisoned").clone()) + .expect("log output must be valid UTF-8"); + assert!( + !output.contains("ERROR"), + "unexpected automatic error event: {output}", + ); + } + + #[tokio::test] + async fn spawn_try_release_waits_until_no_other_waiter_holds_entry() { + let pool = Pool::::empty(); + let quic = crate::connection::tests::MockConnection::new(); + let connection = Arc::new(mock_connection(quic.clone())); + let server = inserted_identity(); + + pool.try_insert(connection) + .await + .expect("insert should succeed"); + let reusable = pool + .connections + .get(&server) + .expect("entry should exist") + .value() + .clone(); + + quic.set_terminal_error(test_connection_error("release")); + pool.clone().spawn_try_release(server.clone()); + tokio::task::yield_now().await; + + assert_eq!(pool.len(), 1, "held waiter must keep the entry alive"); + + drop(reusable); + pool.clone().spawn_try_release(server); + wait_until("entry release after waiter drop", || pool.is_empty()).await; + } + + #[tokio::test] + async fn try_insert_releases_entry_after_connection_closes() { + let pool = Pool::::empty(); + let quic = crate::connection::tests::MockConnection::new(); + + pool.try_insert(Arc::new(mock_connection(quic.clone()))) + .await + .expect("insert should succeed"); + + assert_eq!(pool.len(), 1); + + quic.set_terminal_error(test_connection_error("closed after insert")); + wait_until("entry release after inserted connection closes", || { + pool.is_empty() + }) + .await; + } + + #[tokio::test] + async fn try_insert_rejects_missing_identity() { + let pool = Pool::::empty(); + let connection = Arc::new(Connection::from_state_for_test( + ConnectionState::new_for_test( + Arc::new(IdentityOverrideConnection::new(None)), + Arc::new(crate::protocol::Protocols::new()), + ), + )); + + let error = pool + .try_insert(connection) + .await + .expect_err("missing identity should fail"); + + assert!(matches!(error, InsertError::MissingIdentity)); + assert!(pool.is_empty()); + } + + #[tokio::test] + async fn try_insert_rejects_invalid_identity() { + let pool = Pool::::empty(); + let connection = Arc::new(Connection::from_state_for_test( + ConnectionState::new_for_test( + Arc::new(IdentityOverrideConnection::new(Some( + "not a valid authority", + ))), + Arc::new(crate::protocol::Protocols::new()), + ), + )); + + let error = pool + .try_insert(connection) + .await + .expect_err("invalid identity should fail"); + + assert!(matches!(error, InsertError::InvalidIdentity)); + assert!(pool.is_empty()); + } + + #[test] + fn test_pool_len_empty() { + let pool = Pool::::empty(); + assert_eq!(pool.len(), 0); + } + + #[test] + fn test_pool_clear_empty() { + let pool = Pool::::empty(); + pool.clear(); + assert_eq!(pool.len(), 0); + } + + #[test] + fn test_pool_clear_with_entries() { + let pool = Pool::::empty(); + + let auth1: Authority = "example.com:443".parse().unwrap(); + let auth2: Authority = "other.com:443".parse().unwrap(); + let auth3: Authority = "test.net:443".parse().unwrap(); + + pool.connections + .entry(auth1) + .or_insert_with(|| Arc::new(ReuseableConnection::pending())); + pool.connections + .entry(auth2) + .or_insert_with(|| Arc::new(ReuseableConnection::pending())); + pool.connections + .entry(auth3) + .or_insert_with(|| Arc::new(ReuseableConnection::pending())); + + assert_eq!(pool.len(), 3, "should have 3 entries before clear"); + + pool.clear(); + assert_eq!(pool.len(), 0, "should have 0 entries after clear"); + } } diff --git a/src/protocol.rs b/src/protocol.rs index f90b81d..6c0de54 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -21,19 +21,19 @@ //! //! # Handler access pattern //! -//! Handlers receive protocol access through [`crate::server::Request::protocols()`] -//! and [`crate::server::Response::protocols()`], which return `&Arc`. -//! Combined with [`crate::stream_id::StreamId`], a handler can derive -//! per-request/session handles from the connection-scoped protocol state: +//! Raw handlers receive protocol access through +//! [`crate::endpoint::UnresolvedRequest::connection`], which exposes +//! connection-scoped protocol state. Combined with [`crate::stream_id::StreamId`], +//! a handler can derive per-request/session handles from that state: //! //! ```ignore -//! // Native h3x handler: -//! let dhttp = request.protocols().get::().unwrap(); -//! let stream_id = request.stream_id(); +//! // Raw h3x handler: +//! let dhttp = request.connection.protocols().get::().unwrap(); +//! let stream_id = request.stream_id; //! //! // Hypothetical extension protocol: //! let proto = request.protocols().get::().expect("MyProtocol required"); -//! let session = proto.create_session(request.stream_id()); +//! let session = proto.create_session(request.stream_id); //! ``` //! //! In hyper handlers, the same data is available via request extensions: @@ -59,8 +59,8 @@ //! 3. The runtime protocol is **connection-scoped**: created once, shared across all //! streams. Per-request or per-session state should be produced by handler-facing //! methods (e.g. `create_session(stream_id)`) rather than stored in [`Protocols`]. -//! 4. Erase transport-specific types at the boundary: use [`crate::codec::BoxReadStream`], -//! [`crate::codec::BoxWriteStream`], or [`crate::quic::DynConnection`] to hold +//! 4. Erase transport-specific types at the boundary: use [`crate::quic::BoxQuicStreamReader`], +//! [`crate::quic::BoxQuicStreamWriter`], or [`crate::quic::DynConnection`] to hold //! connection capabilities without leaking generic `C`. use std::{ @@ -76,7 +76,7 @@ use std::{ use futures::future::BoxFuture; use crate::{ - codec::{ErasedPeekableBiStream, ErasedPeekableUniStream}, + codec::{BoxPeekableStreamReader, BoxStreamWriter}, connection::StreamError, quic::{self, ConnectionError}, }; @@ -167,8 +167,8 @@ impl Protocols { pub(crate) async fn accept_uni( &self, - mut stream: ErasedPeekableUniStream, - ) -> Result, StreamError> { + mut stream: BoxPeekableStreamReader, + ) -> Result, StreamError> { for layer in self.layers.values() { match layer.accept_uni(stream).await? { StreamVerdict::Accepted => return Ok(StreamVerdict::Accepted), @@ -183,8 +183,8 @@ impl Protocols { pub(crate) async fn accept_bi( &self, - mut stream: ErasedPeekableBiStream, - ) -> Result, StreamError> { + mut stream: (BoxPeekableStreamReader, BoxStreamWriter), + ) -> Result, StreamError> { for layer in self.layers.values() { match layer.accept_bi(stream).await? { StreamVerdict::Accepted => return Ok(StreamVerdict::Accepted), @@ -301,15 +301,15 @@ pub trait Protocol: Any + Send + Sync + Debug { /// Returns whether the stream was accepted or should be passed to the next layer. fn accept_uni<'a>( &'a self, - stream: ErasedPeekableUniStream, - ) -> BoxFuture<'a, Result, StreamError>>; + stream: BoxPeekableStreamReader, + ) -> BoxFuture<'a, Result, StreamError>>; /// Handles an incoming bidirectional stream. /// Returns whether the stream was accepted or should be passed to the next layer. fn accept_bi<'a>( &'a self, - stream: ErasedPeekableBiStream, - ) -> BoxFuture<'a, Result, StreamError>>; + stream: (BoxPeekableStreamReader, BoxStreamWriter), + ) -> BoxFuture<'a, Result, StreamError>>; } /// Verdict for stream handling in protocol layers. @@ -323,12 +323,37 @@ pub enum StreamVerdict { #[cfg(all(test, feature = "dquic"))] mod tests { - use std::sync::Arc; + use std::{ + borrow::Cow, + collections::HashSet, + pin::Pin, + sync::{ + Arc, Mutex, + atomic::{AtomicUsize, Ordering}, + }, + task::{Context, Poll}, + }; - use futures::future::BoxFuture; + use bytes::Bytes; + use dhttp_identity::identity::{self as authority, LocalAuthority as _, RemoteAuthority as _}; + use futures::{ + FutureExt, Sink, SinkExt, Stream, + future::{BoxFuture, pending}, + task::noop_waker_ref, + }; + use tokio::io::AsyncReadExt; use super::*; - use crate::quic::{self, ConnectionError}; + use crate::{ + codec::{PeekableStreamReader, SinkWriter, StreamReader}, + error::Code, + quic::{ + self, BoxQuicStreamReader, BoxQuicStreamWriter, ConnectionError, GetStreamId, + Lifecycle, ManageStream, ResetStream, StopStream, WithLocalAuthority, + WithRemoteAuthority, + }, + varint::VarInt, + }; // Minimal mock protocol (runtime layer). #[derive(Debug)] @@ -337,19 +362,386 @@ mod tests { impl Protocol for MockProtocol { fn accept_uni<'a>( &'a self, - stream: ErasedPeekableUniStream, - ) -> BoxFuture<'a, Result, StreamError>> { + stream: BoxPeekableStreamReader, + ) -> BoxFuture<'a, Result, StreamError>> { Box::pin(async move { Ok(StreamVerdict::Passed(stream)) }) } fn accept_bi<'a>( &'a self, - stream: ErasedPeekableBiStream, - ) -> BoxFuture<'a, Result, StreamError>> { + stream: (BoxPeekableStreamReader, BoxStreamWriter), + ) -> BoxFuture< + 'a, + Result, StreamError>, + > { Box::pin(async move { Ok(StreamVerdict::Passed(stream)) }) } } + #[derive(Debug)] + enum TestVerdict { + Accepted, + Passed, + Error, + } + + #[derive(Debug)] + struct VerdictProtocol { + uni: TestVerdict, + bi: TestVerdict, + } + + fn verdict_error() -> StreamError { + StreamError::Reset { + code: VarInt::from_u32(0x0102), + } + } + + impl Protocol for VerdictProtocol { + fn accept_uni<'a>( + &'a self, + stream: BoxPeekableStreamReader, + ) -> BoxFuture<'a, Result, StreamError>> { + Box::pin(async move { + match self.uni { + TestVerdict::Accepted => Ok(StreamVerdict::Accepted), + TestVerdict::Passed => Ok(StreamVerdict::Passed(stream)), + TestVerdict::Error => Err(verdict_error()), + } + }) + } + + fn accept_bi<'a>( + &'a self, + stream: (BoxPeekableStreamReader, BoxStreamWriter), + ) -> BoxFuture< + 'a, + Result, StreamError>, + > { + Box::pin(async move { + match self.bi { + TestVerdict::Accepted => Ok(StreamVerdict::Accepted), + TestVerdict::Passed => Ok(StreamVerdict::Passed(stream)), + TestVerdict::Error => Err(verdict_error()), + } + }) + } + } + + fn peekable_uni_stream() -> BoxPeekableStreamReader { + let (reader, _writer) = quic::test::mock_stream_pair(VarInt::from_u32(0)); + PeekableStreamReader::new(StreamReader::new(Box::pin(reader) as BoxQuicStreamReader)) + } + + fn peekable_bi_stream() -> (BoxPeekableStreamReader, BoxStreamWriter) { + let (reader, writer) = quic::test::mock_stream_pair(VarInt::from_u32(4)); + ( + PeekableStreamReader::new(StreamReader::new(Box::pin(reader) as BoxQuicStreamReader)), + SinkWriter::new(Box::pin(writer) as BoxQuicStreamWriter), + ) + } + + async fn peekable_uni_stream_with_bytes(bytes: &[u8]) -> BoxPeekableStreamReader { + let (reader, mut writer) = quic::test::mock_stream_pair(VarInt::from_u32(0)); + writer + .send(Bytes::copy_from_slice(bytes)) + .await + .expect("write test uni bytes"); + writer.close().await.expect("close test uni stream"); + PeekableStreamReader::new(StreamReader::new(Box::pin(reader) as BoxQuicStreamReader)) + } + + async fn peekable_bi_stream_with_bytes( + bytes: &[u8], + ) -> (BoxPeekableStreamReader, BoxStreamWriter) { + let (reader, mut writer) = quic::test::mock_stream_pair(VarInt::from_u32(4)); + writer + .send(Bytes::copy_from_slice(bytes)) + .await + .expect("write test bidi bytes"); + writer.close().await.expect("close test bidi stream"); + ( + PeekableStreamReader::new(StreamReader::new(Box::pin(reader) as BoxQuicStreamReader)), + SinkWriter::new(Box::pin(writer) as BoxQuicStreamWriter), + ) + } + + #[derive(Debug, Default)] + struct RoutingObservation { + calls: AtomicUsize, + reads: Mutex>>, + } + + #[derive(Debug)] + struct RoutedProtocol { + observation: Arc, + } + + impl Protocol for RoutedProtocol { + fn accept_uni<'a>( + &'a self, + stream: BoxPeekableStreamReader, + ) -> BoxFuture<'a, Result, StreamError>> { + Box::pin(async move { + let mut stream = stream; + let call = self.observation.calls.fetch_add(1, Ordering::SeqCst); + let read_len = if call == 0 { 1 } else { 2 }; + let mut buf = vec![0; read_len]; + stream + .read_exact(&mut buf) + .await + .expect("read routed uni bytes"); + self.observation + .reads + .lock() + .expect("lock routed uni observation") + .push(buf); + if call == 0 { + Ok(StreamVerdict::Passed(stream)) + } else { + Ok(StreamVerdict::Accepted) + } + }) + } + + fn accept_bi<'a>( + &'a self, + stream: (BoxPeekableStreamReader, BoxStreamWriter), + ) -> BoxFuture< + 'a, + Result, StreamError>, + > { + Box::pin(async move { + let mut stream = stream; + let call = self.observation.calls.fetch_add(1, Ordering::SeqCst); + let read_len = if call == 0 { 1 } else { 2 }; + let mut buf = vec![0; read_len]; + stream + .0 + .read_exact(&mut buf) + .await + .expect("read routed bidi bytes"); + self.observation + .reads + .lock() + .expect("lock routed bidi observation") + .push(buf); + if call == 0 { + Ok(StreamVerdict::Passed(stream)) + } else { + Ok(StreamVerdict::Accepted) + } + }) + } + } + + #[derive(Debug)] + struct TestLocalAuthority; + + impl authority::LocalAuthority for TestLocalAuthority { + fn name(&self) -> &str { + "test-local" + } + + fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { + &[] + } + fn sign(&self, _data: &[u8]) -> BoxFuture<'_, Result, authority::SignError>> { + Box::pin(async { Ok(Vec::new()) }) + } + } + + #[derive(Debug)] + struct TestRemoteAuthority; + + impl authority::RemoteAuthority for TestRemoteAuthority { + fn name(&self) -> &str { + "test-remote" + } + + fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { + &[] + } + } + + #[derive(Debug)] + struct TestReadStream; + + impl GetStreamId for TestReadStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(VarInt::from_u32(0))) + } + } + + impl StopStream for TestReadStream { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context, + _code: VarInt, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl Stream for TestReadStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(None) + } + } + + #[derive(Debug)] + struct TestWriteStream; + + impl GetStreamId for TestWriteStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(VarInt::from_u32(0))) + } + } + + impl ResetStream for TestWriteStream { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context, + _code: VarInt, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl Sink for TestWriteStream { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, _item: Bytes) -> Result<(), Self::Error> { + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + #[derive(Debug, Default)] + struct TestConnection; + + impl quic::ManageStream for TestConnection { + type StreamReader = TestReadStream; + type StreamWriter = TestWriteStream; + + async fn open_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), ConnectionError> { + pending().await + } + + async fn open_uni(&self) -> Result { + pending().await + } + + async fn accept_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), ConnectionError> { + pending().await + } + + async fn accept_uni(&self) -> Result { + pending().await + } + } + + impl quic::WithLocalAuthority for TestConnection { + type LocalAuthority = TestLocalAuthority; + + async fn local_authority(&self) -> Result, ConnectionError> { + Ok(None) + } + } + + impl quic::WithRemoteAuthority for TestConnection { + type RemoteAuthority = TestRemoteAuthority; + + async fn remote_authority(&self) -> Result, ConnectionError> { + Ok(None) + } + } + + impl quic::Lifecycle for TestConnection { + fn close(&self, _code: Code, _reason: Cow<'static, str>) {} + + fn check(&self) -> Result<(), ConnectionError> { + Ok(()) + } + + async fn closed(&self) -> ConnectionError { + pending().await + } + } + + #[derive(Debug, Clone)] + struct CountingFactory { + id: u64, + calls: Arc, + } + + impl PartialEq for CountingFactory { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } + } + + impl Eq for CountingFactory {} + + impl Hash for CountingFactory { + fn hash(&self, state: &mut H) { + self.id.hash(state); + } + } + + impl fmt::Display for CountingFactory { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "CountingFactory({})", self.id) + } + } + + impl ProductProtocol for CountingFactory { + type Protocol = MockProtocol; + + fn init<'a>( + &'a self, + _: &'a Arc, + _: &'a Protocols, + ) -> BoxFuture<'a, Result> { + Box::pin(async move { + self.calls.fetch_add(1, Ordering::SeqCst); + Ok(MockProtocol) + }) + } + } + /// Test-only mock protocol factory. #[derive(Debug, Clone, Hash, PartialEq, Eq)] struct MockFactoryFoo(u64); @@ -368,7 +760,7 @@ mod tests { _: &'a Arc, _: &'a Protocols, ) -> BoxFuture<'a, Result> { - unimplemented!("not used in identity tests") + Box::pin(async { Ok(MockProtocol) }) } } @@ -389,15 +781,18 @@ mod tests { impl Protocol for MockProtocol2 { fn accept_uni<'a>( &'a self, - stream: ErasedPeekableUniStream, - ) -> BoxFuture<'a, Result, StreamError>> { + stream: BoxPeekableStreamReader, + ) -> BoxFuture<'a, Result, StreamError>> { Box::pin(async move { Ok(StreamVerdict::Passed(stream)) }) } fn accept_bi<'a>( &'a self, - stream: ErasedPeekableBiStream, - ) -> BoxFuture<'a, Result, StreamError>> { + stream: (BoxPeekableStreamReader, BoxStreamWriter), + ) -> BoxFuture< + 'a, + Result, StreamError>, + > { Box::pin(async move { Ok(StreamVerdict::Passed(stream)) }) } } @@ -410,7 +805,35 @@ mod tests { _: &'a Arc, _: &'a Protocols, ) -> BoxFuture<'a, Result> { - unimplemented!("not used in identity tests") + Box::pin(async { Ok(MockProtocol2) }) + } + } + + #[derive(Debug, Clone, Hash, PartialEq, Eq)] + struct FailingFactory(&'static str); + + impl fmt::Display for FailingFactory { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "FailingFactory({})", self.0) + } + } + + impl ProductProtocol for FailingFactory { + type Protocol = MockProtocol2; + + fn init<'a>( + &'a self, + _: &'a Arc, + _: &'a Protocols, + ) -> BoxFuture<'a, Result> { + Box::pin(async move { + Err(ConnectionError::Application { + source: quic::ApplicationError { + code: Code::H3_INTERNAL_ERROR, + reason: Cow::Borrowed("failing factory"), + }, + }) + }) } } @@ -446,4 +869,443 @@ mod tests { let b = MockFactoryBar(1); assert_ne!(identity::(a), identity::(b)); } + + #[tokio::test] + async fn mock_protocol_passes_uni_and_bi_streams() { + let protocol = MockProtocol; + + let uni = protocol + .accept_uni(peekable_uni_stream_with_bytes(b"mp").await) + .await + .expect("mock protocol passes uni stream"); + let StreamVerdict::Passed(mut uni) = uni else { + panic!("mock protocol must pass uni stream"); + }; + let mut uni_buf = [0; 2]; + uni.read_exact(&mut uni_buf) + .await + .expect("read mock-passed uni bytes"); + assert_eq!(&uni_buf, b"mp"); + + let bi = protocol + .accept_bi(peekable_bi_stream_with_bytes(b"mb").await) + .await + .expect("mock protocol passes bidi stream"); + let StreamVerdict::Passed((mut reader, _writer)) = bi else { + panic!("mock protocol must pass bidi stream"); + }; + let mut bi_buf = [0; 2]; + reader + .read_exact(&mut bi_buf) + .await + .expect("read mock-passed bidi bytes"); + assert_eq!(&bi_buf, b"mb"); + } + + #[tokio::test] + async fn test_connection_support_traits_return_expected_values() { + let conn = TestConnection; + + let local = conn + .local_authority() + .await + .expect("local authority lookup succeeds"); + assert!(local.is_none()); + let remote = conn + .remote_authority() + .await + .expect("remote authority lookup succeeds"); + assert!(remote.is_none()); + conn.check().expect("test connection health check succeeds"); + conn.close(Code::H3_NO_ERROR, Cow::Borrowed("test close")); + + assert!(conn.open_bi().now_or_never().is_none()); + assert!(conn.open_uni().now_or_never().is_none()); + assert!(conn.accept_bi().now_or_never().is_none()); + assert!(conn.accept_uni().now_or_never().is_none()); + assert!(conn.closed().now_or_never().is_none()); + } + + #[tokio::test] + async fn test_agents_and_streams_expose_minimal_trait_behavior() { + let local = TestLocalAuthority; + assert_eq!(local.name(), "test-local"); + assert!(local.cert_chain().is_empty()); + assert_eq!( + local + .sign(b"payload") + .await + .expect("test local authority signs"), + Vec::::new() + ); + + let remote = TestRemoteAuthority; + assert_eq!(remote.name(), "test-remote"); + assert!(remote.cert_chain().is_empty()); + + let waker = noop_waker_ref(); + let mut cx = Context::from_waker(waker); + + let mut reader = TestReadStream; + assert!(matches!( + Pin::new(&mut reader).poll_stream_id(&mut cx), + Poll::Ready(Ok(id)) if id == VarInt::from_u32(0) + )); + assert!(matches!( + Pin::new(&mut reader).poll_stop(&mut cx, VarInt::from_u32(7)), + Poll::Ready(Ok(())) + )); + assert!(matches!( + Pin::new(&mut reader).poll_next(&mut cx), + Poll::Ready(None) + )); + + let mut writer = TestWriteStream; + assert!(matches!( + Pin::new(&mut writer).poll_stream_id(&mut cx), + Poll::Ready(Ok(id)) if id == VarInt::from_u32(0) + )); + assert!(matches!( + Pin::new(&mut writer).poll_reset(&mut cx, VarInt::from_u32(8)), + Poll::Ready(Ok(())) + )); + assert!(matches!( + Pin::new(&mut writer).poll_ready(&mut cx), + Poll::Ready(Ok(())) + )); + Pin::new(&mut writer) + .start_send(Bytes::from_static(b"ignored")) + .expect("writer send succeeds"); + assert!(matches!( + Pin::new(&mut writer).poll_flush(&mut cx), + Poll::Ready(Ok(())) + )); + assert!(matches!( + Pin::new(&mut writer).poll_close(&mut cx), + Poll::Ready(Ok(())) + )); + } + + #[test] + fn protocol_registry_get_insert_and_debug() { + let mut protocols = Protocols::new(); + assert_eq!(format!("{protocols:?}"), "[]"); + assert!(protocols.get::().is_none()); + + protocols.insert(MockProtocol); + protocols.insert(MockProtocol2); + assert!(protocols.get::().is_some()); + assert!(protocols.get::().is_some()); + let debug = format!("{protocols:?}"); + assert!(debug.contains("MockProtocol")); + assert!(debug.contains("MockProtocol2")); + } + + #[tokio::test] + async fn protocol_registry_accept_uni_passes_through_empty_registry() { + let protocols = Protocols::default(); + let verdict = protocols + .accept_uni(peekable_uni_stream_with_bytes(b"ok").await) + .await + .expect("empty registry must pass uni stream"); + let StreamVerdict::Passed(mut stream) = verdict else { + panic!("empty registry must not accept uni stream"); + }; + + let mut buf = [0; 2]; + stream + .read_exact(&mut buf) + .await + .expect("read passed uni bytes"); + assert_eq!(&buf, b"ok"); + } + + #[tokio::test] + async fn protocol_registry_accept_bi_passes_through_empty_registry() { + let protocols = Protocols::default(); + let verdict = protocols + .accept_bi(peekable_bi_stream_with_bytes(b"bi").await) + .await + .expect("empty registry must pass bidi stream"); + let StreamVerdict::Passed((mut reader, _writer)) = verdict else { + panic!("empty registry must not accept bidi stream"); + }; + + let mut buf = [0; 2]; + reader + .read_exact(&mut buf) + .await + .expect("read passed bidi bytes"); + assert_eq!(&buf, b"bi"); + } + + #[tokio::test] + async fn protocol_registry_accept_uni_resets_stream_before_next_layer() { + let observation = Arc::new(RoutingObservation::default()); + let mut protocols = Protocols::new(); + protocols.insert(RoutedProtocol::<1> { + observation: observation.clone(), + }); + protocols.insert(RoutedProtocol::<2> { + observation: observation.clone(), + }); + + let verdict = protocols + .accept_uni(peekable_uni_stream_with_bytes(b"ab").await) + .await + .expect("layered uni routing succeeds"); + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert_eq!(observation.calls.load(Ordering::SeqCst), 2); + assert_eq!( + *observation + .reads + .lock() + .expect("lock layered uni observation"), + vec![b"a".to_vec(), b"ab".to_vec()] + ); + } + + #[tokio::test] + async fn protocol_registry_accept_bi_resets_stream_before_next_layer() { + let observation = Arc::new(RoutingObservation::default()); + let mut protocols = Protocols::new(); + protocols.insert(RoutedProtocol::<1> { + observation: observation.clone(), + }); + protocols.insert(RoutedProtocol::<2> { + observation: observation.clone(), + }); + + let verdict = protocols + .accept_bi(peekable_bi_stream_with_bytes(b"cd").await) + .await + .expect("layered bidi routing succeeds"); + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert_eq!(observation.calls.load(Ordering::SeqCst), 2); + assert_eq!( + *observation + .reads + .lock() + .expect("lock layered bidi observation"), + vec![b"c".to_vec(), b"cd".to_vec()] + ); + } + + #[tokio::test] + async fn protocol_registry_accept_uni_covers_accept_pass_and_error() { + let mut accepted = Protocols::new(); + accepted.insert(VerdictProtocol { + uni: TestVerdict::Accepted, + bi: TestVerdict::Passed, + }); + assert!(matches!( + accepted + .accept_uni(peekable_uni_stream()) + .await + .expect("accepted uni verdict"), + StreamVerdict::Accepted + )); + + let mut passed = Protocols::new(); + passed.insert(VerdictProtocol { + uni: TestVerdict::Passed, + bi: TestVerdict::Passed, + }); + assert!(matches!( + passed + .accept_uni(peekable_uni_stream()) + .await + .expect("passed uni verdict"), + StreamVerdict::Passed(_) + )); + + let mut errored = Protocols::new(); + errored.insert(VerdictProtocol { + uni: TestVerdict::Error, + bi: TestVerdict::Passed, + }); + let error = match errored.accept_uni(peekable_uni_stream()).await { + Ok(_) => panic!("errored protocol must fail"), + Err(error) => error, + }; + assert!(matches!(error, StreamError::Reset { code } if code == VarInt::from_u32(0x0102))); + } + + #[tokio::test] + async fn protocol_registry_accept_bi_covers_accept_pass_and_error() { + let mut accepted = Protocols::new(); + accepted.insert(VerdictProtocol { + uni: TestVerdict::Passed, + bi: TestVerdict::Accepted, + }); + assert!(matches!( + accepted + .accept_bi(peekable_bi_stream()) + .await + .expect("accepted bidi verdict"), + StreamVerdict::Accepted + )); + + let mut passed = Protocols::new(); + passed.insert(VerdictProtocol { + uni: TestVerdict::Passed, + bi: TestVerdict::Passed, + }); + assert!(matches!( + passed + .accept_bi(peekable_bi_stream()) + .await + .expect("passed bidi verdict"), + StreamVerdict::Passed(_) + )); + + let mut errored = Protocols::new(); + errored.insert(VerdictProtocol { + uni: TestVerdict::Passed, + bi: TestVerdict::Error, + }); + let error = match errored.accept_bi(peekable_bi_stream()).await { + Ok(_) => panic!("errored protocol must fail"), + Err(error) => error, + }; + assert!(matches!(error, StreamError::Reset { code } if code == VarInt::from_u32(0x0102))); + } + + #[test] + fn identified_initializer_hashes_and_compares_by_factory_identity() { + let initializers = HashSet::from([ + identity::(MockFactoryFoo(1)), + identity::(MockFactoryFoo(1)), + identity::(MockFactoryFoo(2)), + identity::(MockFactoryBar(1)), + ]); + + assert_eq!(initializers.len(), 3); + } + + #[test] + fn mock_factories_format_and_compare_by_configured_identity() { + let calls = Arc::new(AtomicUsize::new(0)); + assert_eq!( + CountingFactory { + id: 9, + calls: calls.clone(), + }, + CountingFactory { id: 9, calls } + ); + assert_ne!( + CountingFactory { + id: 9, + calls: Arc::new(AtomicUsize::new(0)), + }, + CountingFactory { + id: 10, + calls: Arc::new(AtomicUsize::new(0)), + } + ); + + assert_eq!(MockFactoryFoo(1).to_string(), "MockFactory"); + assert_eq!(MockFactoryBar(1).to_string(), "MockFactory2"); + assert_eq!( + FailingFactory("format").to_string(), + "FailingFactory(format)" + ); + } + + #[tokio::test] + async fn mock_factory_initializers_insert_their_protocols() { + let conn = Arc::new(TestConnection); + let mut protocols = Protocols::new(); + + MockFactoryFoo(11) + .init_protocols(&conn, &mut protocols) + .await + .expect("foo factory initializes mock protocol"); + MockFactoryBar(12) + .init_protocols(&conn, &mut protocols) + .await + .expect("bar factory initializes second mock protocol"); + + assert!(protocols.get::().is_some()); + assert!(protocols.get::().is_some()); + } + + #[tokio::test] + async fn second_mock_protocol_passes_uni_and_bi_streams() { + let protocol = MockProtocol2; + + let uni = protocol + .accept_uni(peekable_uni_stream_with_bytes(b"u2").await) + .await + .expect("second mock protocol passes uni stream"); + let StreamVerdict::Passed(mut uni) = uni else { + panic!("second mock protocol must pass uni stream"); + }; + let mut uni_buf = [0; 2]; + uni.read_exact(&mut uni_buf) + .await + .expect("read second mock-passed uni bytes"); + assert_eq!(&uni_buf, b"u2"); + + let bi = protocol + .accept_bi(peekable_bi_stream_with_bytes(b"b2").await) + .await + .expect("second mock protocol passes bidi stream"); + let StreamVerdict::Passed((mut reader, _writer)) = bi else { + panic!("second mock protocol must pass bidi stream"); + }; + let mut bi_buf = [0; 2]; + reader + .read_exact(&mut bi_buf) + .await + .expect("read second mock-passed bidi bytes"); + assert_eq!(&bi_buf, b"b2"); + } + + #[tokio::test] + async fn identified_initializer_propagates_init_errors_without_inserting_protocol() { + let initializer = IdentifiedProtocolInitializer::new(FailingFactory("boom")); + let conn = Arc::new(TestConnection); + let mut protocols = Protocols::new(); + let init: &dyn InitProtocols = ops::Deref::deref(&initializer); + + let error = init + .init_protocols(&conn, &mut protocols) + .await + .expect_err("failing initializer must propagate connection error"); + + assert!(matches!( + error, + ConnectionError::Application { source } + if source.code == Code::H3_INTERNAL_ERROR + && source.reason == Cow::Borrowed("failing factory") + )); + assert!(protocols.get::().is_none()); + } + + #[tokio::test] + async fn identified_initializer_delegates_formatting_and_initializes_once() { + let calls = Arc::new(AtomicUsize::new(0)); + let initializer = IdentifiedProtocolInitializer::new(CountingFactory { + id: 7, + calls: calls.clone(), + }); + let conn = Arc::new(TestConnection); + let mut protocols = Protocols::new(); + + assert_eq!(initializer.to_string(), "CountingFactory(7)"); + assert!(format!("{initializer:?}").contains("CountingFactory")); + + initializer + .init_protocols(&conn, &mut protocols) + .await + .expect("protocol initializes"); + initializer + .init_protocols(&conn, &mut protocols) + .await + .expect("duplicate init is skipped"); + + assert!(protocols.get::().is_some()); + assert_eq!(calls.load(Ordering::SeqCst), 1); + } } diff --git a/src/qpack.rs b/src/qpack.rs index bdf46e4..38e7eee 100644 --- a/src/qpack.rs +++ b/src/qpack.rs @@ -6,6 +6,7 @@ pub mod field; pub mod instruction; pub mod integer; pub mod protocol; +pub mod settings; pub mod r#static; pub mod string; @@ -15,19 +16,19 @@ mod tests { use bytes::Bytes; use futures::SinkExt; - use http::Request; + use http::{HeaderMap, HeaderValue, Request, Response}; use tokio::{io::AsyncReadExt, sync::Notify, time}; use crate::{ codec::{DecodeExt, EncodeExt, SinkWriter, StreamReader}, dhttp::{ frame::{Frame, stream::FrameStream}, - settings::{QpackBlockedStreams, QpackMaxTableCapacity, Settings}, + settings::Settings, stream::UnidirectionalStream, }, qpack::{ - algorithm::{DynamicCompressAlgo, HuffmanAlways, StaticCompressAlgo}, - decoder::{Decoder, MessageStreamReader}, + algorithm::{Algorithm, DynamicCompressAlgo, HuffmanAlways, StaticCompressAlgo}, + decoder::{Decoder, QPackMessageStreamReader}, encoder::Encoder, field::FieldSection, }, @@ -116,7 +117,7 @@ mod tests { let request = Request::builder() .method("POST") - .uri("https://h3x.demo.lab.genmeta.net/api/v1/upload") + .uri("https://h3x.demo.lab.dhttp.net/api/v1/upload") .header("user-agent", "genemta-curl/0.3.0") .header("accept", "*/*") .body("Hello, World!") @@ -156,7 +157,11 @@ mod tests { }); let response = tokio::spawn(async move { - let response_stream = MessageStreamReader::new(response_stream, decoder.clone()); + let response_stream = QPackMessageStreamReader::new( + VarInt::from_u32(0), + response_stream, + decoder.clone(), + ); let mut frame_stream = pin!(FrameStream::new(StreamReader::new(response_stream))); let frame = frame_stream.as_mut().next_frame().await.unwrap().unwrap(); @@ -184,13 +189,516 @@ mod tests { .expect("test timedout"); } + #[tokio::test] + async fn static_response_trailer_and_empty_field_sections_roundtrip() { + let encode_strategy = StaticCompressAlgo::new(HuffmanAlways); + + let (encoder_stream_reader, encoder_stream_writer) = + mock_stream_pair_with_capacity(VarInt::from_u32(1), 64); + let (decoder_stream_reader, decoder_stream_writer) = + mock_stream_pair_with_capacity(VarInt::from_u32(2), 64); + + let init_encoder = tokio::spawn(async move { + let encoder_stream = UnidirectionalStream::initial( + UnidirectionalStream::QPACK_ENCODER_STREAM_TYPE, + SinkWriter::new(encoder_stream_writer), + ) + .await + .unwrap(); + + let decoder_stream = Box::pin(Deferred::from(async move { + let decoder_stream = StreamReader::new(decoder_stream_reader) + .decode::>() + .await + .unwrap(); + assert_eq!( + decoder_stream.r#type(), + UnidirectionalStream::QPACK_DECODER_STREAM_TYPE + ); + Ok::<_, ConnectionError>(decoder_stream) + })); + + Arc::new(Encoder::new( + Arc::::default(), + Box::pin((encoder_stream).into_encode_sink()), + Box::pin((decoder_stream).into_decode_stream()), + )) + }); + + let init_decoder = tokio::spawn(async move { + let decoder_stream = UnidirectionalStream::initial( + UnidirectionalStream::QPACK_DECODER_STREAM_TYPE, + SinkWriter::new(decoder_stream_writer), + ) + .await + .unwrap(); + + let encoder_stream = Box::pin(Deferred::from(async move { + let encoder_stream = StreamReader::new(encoder_stream_reader) + .decode::>() + .await + .unwrap(); + assert_eq!( + encoder_stream.r#type(), + UnidirectionalStream::QPACK_ENCODER_STREAM_TYPE + ); + Ok::<_, ConnectionError>(encoder_stream) + })); + + Arc::new(Decoder::new( + Arc::new(Settings::default()), + Box::pin((decoder_stream).into_encode_sink()), + Box::pin((encoder_stream).into_decode_stream()), + )) + }); + + let (encoder, decoder) = time::timeout(time::Duration::from_secs(1), async move { + tokio::try_join!(init_encoder, init_decoder).unwrap() + }) + .await + .expect("setup timed out"); + + let response = Response::builder() + .status(204) + .header("server", "h3x-test") + .header("cache-control", "no-store") + .body(()) + .unwrap(); + let (response_parts, ()) = response.into_parts(); + + let mut trailers = HeaderMap::new(); + trailers.insert("x-checksum", HeaderValue::from_static("sha256:abc123")); + trailers.insert("x-empty", HeaderValue::from_static("")); + + let cases = [ + (VarInt::from_u32(4), FieldSection::from(response_parts)), + (VarInt::from_u32(8), FieldSection::trailer(trailers)), + ( + VarInt::from_u32(12), + FieldSection::trailer(HeaderMap::new()), + ), + ]; + + for (stream_id, field_section) in cases { + let expected = field_section.clone(); + let (read_stream, write_stream) = mock_stream_pair_with_capacity(stream_id, 64); + let mut write_stream = SinkWriter::new(write_stream); + + let header_frame = Encoder::encode( + &*encoder, + field_section.iter(), + &encode_strategy, + &mut write_stream, + ) + .await + .unwrap(); + write_stream.encode_one(header_frame).await.unwrap(); + write_stream.flush().await.unwrap(); + + let read_stream = + QPackMessageStreamReader::new(stream_id, read_stream, decoder.clone()); + let mut frame_stream = pin!(FrameStream::new(StreamReader::new(read_stream))); + let frame = frame_stream.as_mut().next_frame().await.unwrap().unwrap(); + assert_eq!(frame.r#type(), Frame::HEADERS_FRAME_TYPE); + + let decoded = Decoder::decode(&*decoder, frame).await.unwrap(); + assert_eq!(decoded, expected); + } + } + fn settings_with_dynamic_table() -> Arc { let mut settings = Settings::default(); - settings.set(QpackMaxTableCapacity::setting(VarInt::from_u32(4096))); - settings.set(QpackBlockedStreams::setting(VarInt::from_u32(100))); + settings.set(crate::qpack::settings::QpackMaxTableCapacity::setting( + VarInt::from_u32(4096), + )); + settings.set(crate::qpack::settings::QpackBlockedStreams::setting( + VarInt::from_u32(100), + )); Arc::new(settings) } + #[tokio::test] + async fn settings_with_dynamic_table_enables_qpack_limits() { + let settings = settings_with_dynamic_table(); + + assert_eq!(settings.qpack_max_table_capacity(), VarInt::from_u32(4096)); + assert_eq!(settings.qpack_blocked_streams(), VarInt::from_u32(100)); + } + + #[test] + fn qpack_error_displays_and_connection_codes_are_stable() { + use crate::{ + error::{Code, H3ConnectionError}, + qpack::{ + decoder::{InvalidDynamicTableReference, QPackEncoderStreamError}, + encoder::{QPackDecoderStreamError, QPackEncoderError}, + protocol::{QPackProtocolDisabled, QPackProtocolFactory}, + }, + }; + + let decoder_stream_error = + QPackDecoderStreamError::AcknowledgeNonExistSection { stream_id: 99 }; + assert_eq!( + decoder_stream_error.to_string(), + "acknowledging non-existent or non-blocking stream" + ); + assert_eq!( + decoder_stream_error.code(), + Code::QPACK_DECODER_STREAM_ERROR + ); + assert_eq!( + QPackDecoderStreamError::IncrementKnownReceivedCountOverflow.to_string(), + "known received count increment overflow (known_received_count + increment > inserted_count)" + ); + assert_eq!( + QPackDecoderStreamError::IncrementZero.to_string(), + "insert count increment of zero is not allowed" + ); + + assert_eq!( + QPackEncoderError::CapacityExceedsMax { + capacity: 2, + max: 1, + } + .to_string(), + "dynamic table capacity 2 exceeds SETTINGS_QPACK_MAX_TABLE_CAPACITY 1" + ); + assert_eq!( + QPackEncoderError::CannotEvict.to_string(), + "cannot evict entries: no evictable entries in dynamic table" + ); + + let encoder_stream_error = + QPackEncoderStreamError::ReferencedStaticEntryNotExisted { index: 100 }; + assert_eq!( + encoder_stream_error.to_string(), + "referenced static table entry 100 does not exist" + ); + assert_eq!( + encoder_stream_error.code(), + Code::QPACK_ENCODER_STREAM_ERROR + ); + assert_eq!( + QPackEncoderStreamError::SetDynamicTableCapacityExceeded.to_string(), + "setting dynamic table capacity exceeded maximum allowed" + ); + assert_eq!( + QPackEncoderStreamError::NoEvictableEntryForInsertion.to_string(), + "no evictable entry for insertion, cannot insert new entry" + ); + assert_eq!( + QPackEncoderStreamError::ReferencedDynamicEntryNotExisted { index: 7 }.to_string(), + "referenced dynamic table entry 7 does not exist" + ); + + let invalid_reference = InvalidDynamicTableReference::IndexOverflow; + assert_eq!(invalid_reference.to_string(), "reference index overflow"); + assert_eq!(invalid_reference.code(), Code::QPACK_DECOMPRESSION_FAILED); + assert_eq!( + InvalidDynamicTableReference::ReferencedStaticEntryNotExisted { index: 99 }.to_string(), + "referenced static table entry 99 does not exist" + ); + assert_eq!( + InvalidDynamicTableReference::ReferencedDynamicEntryNotExisted { index: 8 }.to_string(), + "referenced dynamic table entry 8 does not exist" + ); + + assert_eq!(QPackProtocolFactory::new().to_string(), "QPACK"); + assert_eq!( + QPackProtocolDisabled.to_string(), + "qpack protocol is disabled" + ); + } + + #[test] + fn encoder_state_rejects_invalid_capacity_ack_and_increment_paths() { + use crate::qpack::encoder::{EncoderState, QPackDecoderStreamError, QPackEncoderError}; + + let settings = settings_with_dynamic_table(); + let mut encoder_state = EncoderState::new(settings); + + let capacity_error = encoder_state + .set_max_table_capacity(4097) + .expect_err("capacity above peer setting must be rejected"); + assert!(matches!( + capacity_error, + QPackEncoderError::CapacityExceedsMax { + capacity: 4097, + max: 4096 + } + )); + + encoder_state + .set_max_table_capacity(64) + .expect("capacity within peer setting"); + encoder_state + .insert_with_literal_name( + false, + Bytes::from_static(b"x-unacked"), + false, + Bytes::from_static(b"value"), + ) + .expect("insert should fit the dynamic table"); + + let shrink_error = encoder_state + .set_max_table_capacity(0) + .expect_err("unacknowledged entry is not evictable"); + assert!(matches!(shrink_error, QPackEncoderError::CannotEvict)); + assert_eq!(encoder_state.table_inserted_count(), 1); + assert_eq!(encoder_state.table_dropped_count(), 0); + + let ack_error = encoder_state + .on_section_acknowledgment(VarInt::from_u32(4).into_inner()) + .expect_err("unknown stream id must not be acknowledged"); + assert!(matches!( + ack_error, + QPackDecoderStreamError::AcknowledgeNonExistSection { stream_id: 4 } + )); + + let zero_increment = encoder_state + .on_insert_count_increment(0) + .expect_err("zero insert count increment is a decoder stream error"); + assert!(matches!( + zero_increment, + QPackDecoderStreamError::IncrementZero + )); + + let overflow_increment = encoder_state + .on_insert_count_increment(2) + .expect_err("increment beyond inserted-known count must fail"); + assert!(matches!( + overflow_increment, + QPackDecoderStreamError::IncrementKnownReceivedCountOverflow + )); + } + + #[test] + fn decoder_state_rejects_invalid_encoder_stream_operations() { + use crate::qpack::decoder::{DecoderState, QPackEncoderStreamError}; + + let settings = settings_with_dynamic_table(); + let mut decoder_state = DecoderState::new(settings); + + let capacity_error = decoder_state + .set_dynamic_table_capacity(4097) + .expect_err("capacity above local setting must be rejected"); + assert!(matches!( + capacity_error, + QPackEncoderStreamError::SetDynamicTableCapacityExceeded + )); + + let missing_static = decoder_state + .insert_with_name_reference(true, 99, Bytes::from_static(b"value")) + .expect_err("invalid static table name reference must fail"); + assert!(matches!( + missing_static, + QPackEncoderStreamError::ReferencedStaticEntryNotExisted { index: 99 } + )); + + let missing_dynamic = decoder_state + .insert_with_name_reference(false, 0, Bytes::from_static(b"value")) + .expect_err("invalid dynamic table name reference must fail"); + assert!(matches!( + missing_dynamic, + QPackEncoderStreamError::ReferencedDynamicEntryNotExisted { index: 0 } + )); + + decoder_state + .set_dynamic_table_capacity(0) + .expect("zero capacity is within settings"); + let no_space = decoder_state + .insert_with_literal_name(Bytes::from_static(b"x-too-large"), Bytes::new()) + .expect_err("entry cannot fit in a zero-capacity dynamic table"); + assert!(matches!( + no_space, + QPackEncoderStreamError::NoEvictableEntryForInsertion + )); + + let missing_duplicate = decoder_state + .duplicate(0) + .expect_err("invalid duplicate reference must fail"); + assert!(matches!( + missing_duplicate, + QPackEncoderStreamError::ReferencedDynamicEntryNotExisted { index: 0 } + )); + } + + #[test] + fn encoded_field_section_prefix_handles_boundary_values() { + use crate::{codec::DecodeError, qpack::field::EncodedFieldSectionPrefix}; + + assert_eq!(EncodedFieldSectionPrefix::encode_ric(0, 4096), 0); + assert_eq!( + EncodedFieldSectionPrefix::encode_ric(1, 0), + 1, + "non-zero RIC with disabled dynamic table is preserved as an invalid wire value" + ); + assert_eq!( + EncodedFieldSectionPrefix::decode_ric(1, 0, 0), + Err(DecodeError::DecompressionFailed) + ); + assert_eq!( + EncodedFieldSectionPrefix::decode_ric(3, 32, 0), + Err(DecodeError::DecompressionFailed) + ); + assert_eq!( + EncodedFieldSectionPrefix::resolve_base(u64::MAX, false, 1), + Err(DecodeError::ArithmeticOverflow) + ); + assert_eq!( + EncodedFieldSectionPrefix::resolve_base(0, true, 0), + Err(DecodeError::ArithmeticOverflow) + ); + } + + fn apply_pending_encoder_instructions( + encoder_state: &mut crate::qpack::encoder::EncoderState, + decoder_state: &mut crate::qpack::decoder::DecoderState, + ) { + use crate::qpack::encoder::EncoderInstruction; + + while let Some(instruction) = encoder_state.pending_instructions.pop_front() { + match instruction { + EncoderInstruction::SetDynamicTableCapacity { capacity } => { + decoder_state.set_dynamic_table_capacity(capacity).unwrap(); + } + EncoderInstruction::InsertWithNameReference { + is_static, + name_index, + value, + .. + } => { + let abs_index = if is_static { + name_index + } else { + decoder_state.table_inserted_count() - name_index - 1 + }; + decoder_state + .insert_with_name_reference(is_static, abs_index, value) + .unwrap(); + } + EncoderInstruction::InsertWithLiteralName { name, value, .. } => { + decoder_state.insert_with_literal_name(name, value).unwrap(); + } + EncoderInstruction::Duplicate { index } => { + let abs_index = decoder_state.table_inserted_count() - index - 1; + decoder_state.duplicate(abs_index).unwrap(); + } + } + } + } + + #[test] + fn apply_pending_encoder_instructions_replays_duplicate_entries() { + use crate::qpack::{decoder::DecoderState, encoder::EncoderState}; + + let settings = settings_with_dynamic_table(); + let mut encoder_state = EncoderState::new(settings.clone()); + encoder_state + .set_max_table_capacity(4096) + .expect("set encoder capacity"); + let original_index = encoder_state + .insert_with_literal_name( + false, + Bytes::from_static(b"x-duplicate"), + false, + Bytes::from_static(b"value"), + ) + .expect("insert original entry"); + let duplicated_index = encoder_state + .duplicate(original_index) + .expect("duplicate original entry"); + let mut decoder_state = DecoderState::new(settings); + + apply_pending_encoder_instructions(&mut encoder_state, &mut decoder_state); + + assert_eq!(decoder_state.table_inserted_count(), duplicated_index + 1); + assert_eq!( + decoder_state.dynamic_table.get(duplicated_index), + Some(&crate::qpack::field::FieldLine { + name: Bytes::from_static(b"x-duplicate"), + value: Bytes::from_static(b"value"), + }) + ); + } + + #[tokio::test] + async fn dynamic_response_trailer_and_empty_field_sections_roundtrip() { + use crate::qpack::{ + decoder::DecoderState, + encoder::EncoderState, + field::{EncodedFieldSectionPrefix, FieldLine}, + }; + + let settings = settings_with_dynamic_table(); + let mut encoder_state = EncoderState::new(settings.clone()); + encoder_state + .set_max_table_capacity(4096) + .expect("set encoder capacity"); + + let mut decoder_state = DecoderState::new(settings.clone()); + apply_pending_encoder_instructions(&mut encoder_state, &mut decoder_state); + + let encode_strategy = DynamicCompressAlgo::new(HuffmanAlways); + + let response = Response::builder() + .status(201) + .header("server", "h3x-dynamic-test") + .header("x-dynamic-name", "first") + .header("cache-control", "no-store") + .body(()) + .unwrap(); + let (response_parts, ()) = response.into_parts(); + + let mut trailers = HeaderMap::new(); + trailers.insert("x-dynamic-name", HeaderValue::from_static("second")); + trailers.insert("x-trailer-only", HeaderValue::from_static("present")); + + let cases = [ + FieldSection::from(response_parts), + FieldSection::trailer(trailers), + FieldSection::trailer(HeaderMap::new()), + ]; + + for field_section in cases { + let expected: Vec = field_section.iter().collect(); + let output = encode_strategy + .compress(&mut encoder_state, field_section.iter(), true) + .await; + apply_pending_encoder_instructions(&mut encoder_state, &mut decoder_state); + + let required_insert_count = EncodedFieldSectionPrefix::decode_ric( + output.prefix.encoded_insert_count, + settings.qpack_max_table_capacity().into_inner(), + decoder_state.table_inserted_count(), + ) + .expect("decode required insert count"); + let base = EncodedFieldSectionPrefix::resolve_base( + required_insert_count, + output.prefix.sign, + output.prefix.delta_base, + ) + .expect("resolve base"); + + let decoded: Vec = output + .representations + .iter() + .map(|repr| decoder_state.decompress(repr, base).expect("decompress")) + .collect(); + assert_eq!(decoded, expected); + } + + assert_eq!(encoder_state.table_capacity(), 4096); + assert!( + encoder_state.table_inserted_count() >= 3, + "dynamic response/trailer roundtrip should insert reusable table entries" + ); + assert_eq!( + decoder_state.table_inserted_count(), + encoder_state.table_inserted_count() + ); + } + /// Direct roundtrip test: encode with DynamicCompressAlgo, then decode using /// the decoder state. Bypasses mock streams to test the algorithm itself. #[tokio::test] @@ -198,7 +706,7 @@ mod tests { use crate::qpack::{ algorithm::{Algorithm, DynamicCompressAlgo, HuffmanAlways}, decoder::DecoderState, - encoder::{EncoderInstruction, EncoderState}, + encoder::EncoderState, field::{EncodedFieldSectionPrefix, FieldLine}, }; @@ -243,38 +751,7 @@ mod tests { // Apply encoder instructions to decoder state // (simulates instructions arriving via encoder stream) - for instruction in encoder_state.pending_instructions() { - match instruction { - EncoderInstruction::SetDynamicTableCapacity { capacity } => { - decoder_state.set_dynamic_table_capacity(*capacity).unwrap(); - } - EncoderInstruction::InsertWithNameReference { - is_static, - name_index, - value, - .. - } => { - // Wire-format name_index: relative for dynamic, absolute for static - let abs_index = if *is_static { - *name_index - } else { - decoder_state.table_inserted_count() - name_index - 1 - }; - decoder_state - .insert_with_name_reference(*is_static, abs_index, value.clone()) - .unwrap(); - } - EncoderInstruction::InsertWithLiteralName { name, value, .. } => { - decoder_state - .insert_with_literal_name(name.clone(), value.clone()) - .unwrap(); - } - EncoderInstruction::Duplicate { index } => { - let abs_index = decoder_state.table_inserted_count() - index - 1; - decoder_state.duplicate(abs_index).unwrap(); - } - } - } + apply_pending_encoder_instructions(&mut encoder_state, &mut decoder_state); // Decode the RIC to get required_insert_count let max_table_capacity = settings.qpack_max_table_capacity().into_inner(); @@ -404,7 +881,7 @@ mod tests { let request = Request::builder() .method("POST") - .uri("https://h3x.demo.lab.genmeta.net/api/v1/upload") + .uri("https://h3x.demo.lab.dhttp.net/api/v1/upload") .header("user-agent", "genmeta-curl/0.3.0") .header("accept", "*/*") .header("x-custom-header", "custom-value") @@ -449,7 +926,11 @@ mod tests { }); let response_task = tokio::spawn(async move { - let response_stream = MessageStreamReader::new(response_stream, decoder.clone()); + let response_stream = QPackMessageStreamReader::new( + VarInt::from_u32(0), + response_stream, + decoder.clone(), + ); let mut frame_stream = pin!(FrameStream::new(StreamReader::new(response_stream))); let frame = frame_stream.as_mut().next_frame().await.unwrap().unwrap(); diff --git a/src/qpack/algorithm.rs b/src/qpack/algorithm.rs index b93c3b9..b9da498 100644 --- a/src/qpack/algorithm.rs +++ b/src/qpack/algorithm.rs @@ -451,9 +451,12 @@ mod tests { use bytes::Bytes; use crate::{ - dhttp::settings::{QpackBlockedStreams, QpackMaxTableCapacity, Settings}, + dhttp::settings::Settings, qpack::{ - algorithm::{Algorithm, CompressOutput, DynamicCompressAlgo, HuffmanNever}, + algorithm::{ + Algorithm, CompressOutput, DynamicCompressAlgo, HuffmanAlways, HuffmanNever, + StaticCompressAlgo, + }, encoder::EncoderState, field::{EncodedFieldSectionPrefix, FieldLine, FieldLineRepresentation}, }, @@ -462,10 +465,12 @@ mod tests { fn state_with_capacity(table_capacity: u32) -> EncoderState { let mut settings = Settings::default(); - settings.set(QpackMaxTableCapacity::setting(VarInt::from_u32( - table_capacity, - ))); - settings.set(QpackBlockedStreams::setting(VarInt::from_u32(100))); + settings.set(crate::qpack::settings::QpackMaxTableCapacity::setting( + VarInt::from_u32(table_capacity), + )); + settings.set(crate::qpack::settings::QpackBlockedStreams::setting( + VarInt::from_u32(100), + )); let mut state = EncoderState::new(Arc::new(settings)); if table_capacity > 0 { state @@ -486,6 +491,10 @@ mod tests { DynamicCompressAlgo::new(HuffmanNever) } + fn huffman_algo() -> DynamicCompressAlgo { + DynamicCompressAlgo::new(HuffmanAlways) + } + async fn do_compress( state: &mut EncoderState, entries: Vec, @@ -531,6 +540,27 @@ mod tests { assert!(output.prefix.sign); } + #[tokio::test] + async fn insert_without_blocking_populates_table_but_emits_literal() { + let mut state = state_with_capacity(256); + + let output = do_compress(&mut state, vec![field_line("x-custom", "hello")], false).await; + + assert_eq!(state.table_inserted_count(), 1); + assert!(output.max_referenced_index.is_none()); + assert_eq!(output.prefix.encoded_insert_count, 0); + assert!(matches!( + &output.representations[..], + [FieldLineRepresentation::LiteralFieldLineWithLiteralName { + never_dynamic: false, + name_huffman: false, + name, + value_huffman: false, + value, + }] if name.as_ref() == b"x-custom" && value.as_ref() == b"hello" + )); + } + #[tokio::test] async fn second_request_uses_pre_base_dynamic_ref() { let mut state = state_with_capacity(256); @@ -672,6 +702,247 @@ mod tests { )); } + #[tokio::test] + async fn static_algorithm_emits_all_three_representation_forms() { + let mut state = state_with_capacity(256); + let algo = StaticCompressAlgo::new(HuffmanAlways); + let output = algo + .compress( + &mut state, + vec![ + field_line(":method", "GET"), + field_line(":path", "/custom"), + field_line("x-literal", "value"), + ], + false, + ) + .await; + + assert_eq!(output.representations.len(), 3); + assert!(matches!( + output.representations[0], + FieldLineRepresentation::IndexedFieldLine { + is_static: true, + index: 17 + } + )); + assert!(matches!( + &output.representations[1], + FieldLineRepresentation::LiteralFieldLineWithNameReference { + never_dynamic: true, + is_static: true, + huffman: true, + value, + .. + } if value == b"/custom".as_slice() + )); + assert!(matches!( + &output.representations[2], + FieldLineRepresentation::LiteralFieldLineWithLiteralName { + never_dynamic: true, + name_huffman: true, + name, + value_huffman: true, + value, + } if name == b"x-literal".as_slice() && value == b"value".as_slice() + )); + assert_eq!(state.table_inserted_count(), 0); + assert!(output.max_referenced_index.is_none()); + assert_eq!( + output.prefix, + EncodedFieldSectionPrefix { + encoded_insert_count: 0, + sign: false, + delta_base: 0, + } + ); + } + + #[tokio::test] + async fn dynamic_literal_name_uses_huffman_flags_from_strategy() { + let mut state = state_with_capacity(0); + let output = huffman_algo() + .compress(&mut state, vec![field_line("x-huffman", "value")], true) + .await; + + assert!(matches!( + &output.representations[0], + FieldLineRepresentation::LiteralFieldLineWithLiteralName { + never_dynamic: false, + name_huffman: true, + name, + value_huffman: true, + value, + } if name == b"x-huffman".as_slice() && value == b"value".as_slice() + )); + } + + #[tokio::test] + async fn all_sensitive_header_names_are_case_insensitive_and_never_inserted() { + let sensitive_names = [ + "authorization", + "AUTHORIZATION", + "proxy-authorization", + "Proxy-Authorization", + "cookie", + "COOKIE", + "set-cookie", + "Set-Cookie", + ]; + + for name in sensitive_names { + let mut state = state_with_capacity(256); + let output = do_compress(&mut state, vec![field_line(name, "secret")], true).await; + + assert_eq!(state.table_inserted_count(), 0, "{name} was inserted"); + assert!( + output.max_referenced_index.is_none(), + "{name} referenced dynamic table" + ); + match &output.representations[0] { + FieldLineRepresentation::LiteralFieldLineWithNameReference { + never_dynamic, + .. + } + | FieldLineRepresentation::LiteralFieldLineWithPostBaseNameReference { + never_dynamic, + .. + } + | FieldLineRepresentation::LiteralFieldLineWithLiteralName { + never_dynamic, .. + } => assert!(*never_dynamic, "{name} did not set never_dynamic"), + other => panic!("expected literal representation for {name}, got {other:?}"), + } + } + } + + #[tokio::test] + async fn dynamic_exact_match_can_use_post_base_reference() { + let mut state = state_with_capacity(256); + let output = do_compress( + &mut state, + vec![ + field_line("x-repeat", "same"), + field_line("x-repeat", "same"), + ], + true, + ) + .await; + + assert_eq!(state.table_inserted_count(), 1); + assert_eq!(output.max_referenced_index, Some(0)); + assert!(matches!( + output.representations.as_slice(), + [ + FieldLineRepresentation::IndexedFieldLineWithPostBaseIndex { index: 0 }, + FieldLineRepresentation::IndexedFieldLineWithPostBaseIndex { index: 0 }, + ] + )); + } + + #[tokio::test] + async fn dynamic_name_reference_uses_pre_base_when_name_is_acknowledged() { + let mut state = state_with_capacity(48); + + let _ = do_compress(&mut state, vec![field_line("x-ref", "a")], true).await; + state.dynamic_table.known_received_count = state.dynamic_table.inserted_count; + + let output = do_compress(&mut state, vec![field_line("x-ref", "longer-value")], true).await; + + assert_eq!(state.table_inserted_count(), 1); + assert_eq!(output.max_referenced_index, Some(0)); + assert!(matches!( + &output.representations[0], + FieldLineRepresentation::LiteralFieldLineWithNameReference { + never_dynamic: false, + is_static: false, + name_index: 0, + huffman: false, + value, + } if value == b"longer-value".as_slice() + )); + } + + #[tokio::test] + async fn dynamic_name_reference_uses_post_base_when_name_is_new_in_section() { + let mut state = state_with_capacity(48); + let output = do_compress( + &mut state, + vec![ + field_line("x-ref", "a"), + field_line("x-ref", "longer-value"), + ], + true, + ) + .await; + + assert_eq!(state.table_inserted_count(), 1); + assert_eq!(output.max_referenced_index, Some(0)); + assert!(matches!( + &output.representations[1], + FieldLineRepresentation::LiteralFieldLineWithPostBaseNameReference { + never_dynamic: false, + name_index: 0, + huffman: false, + value, + } if value == b"longer-value".as_slice() + )); + } + + #[tokio::test] + async fn may_block_false_falls_back_when_dynamic_name_is_unacknowledged() { + let mut state = state_with_capacity(48); + + let _ = do_compress(&mut state, vec![field_line("x-ref", "a")], true).await; + assert_eq!(state.table_known_received_count(), 0); + + let output = + do_compress(&mut state, vec![field_line("x-ref", "longer-value")], false).await; + + assert_eq!(state.table_inserted_count(), 1); + assert!(output.max_referenced_index.is_none()); + assert!(matches!( + &output.representations[0], + FieldLineRepresentation::LiteralFieldLineWithLiteralName { + never_dynamic: false, + name_huffman: false, + name, + value_huffman: false, + value, + } if name == b"x-ref".as_slice() && value == b"longer-value".as_slice() + )); + } + + #[tokio::test] + async fn prefix_required_insert_count_uses_largest_dynamic_reference() { + let mut state = state_with_capacity(256); + + let _ = do_compress( + &mut state, + vec![field_line("x-a", "one"), field_line("x-b", "two")], + true, + ) + .await; + state.dynamic_table.known_received_count = state.dynamic_table.inserted_count; + + let output = do_compress( + &mut state, + vec![field_line("x-a", "one"), field_line("x-b", "two")], + true, + ) + .await; + + assert_eq!(output.max_referenced_index, Some(1)); + assert_eq!( + output.prefix, + EncodedFieldSectionPrefix { + encoded_insert_count: EncodedFieldSectionPrefix::encode_ric(2, 256), + sign: false, + delta_base: 0, + } + ); + } + // --- Multiple entries --- #[tokio::test] @@ -821,7 +1092,7 @@ mod tests { use proptest::prelude::*; use crate::{ - dhttp::settings::{QpackBlockedStreams, QpackMaxTableCapacity, Settings}, + dhttp::settings::Settings, qpack::{ algorithm::{Algorithm, DynamicCompressAlgo, HuffmanNever}, decoder::DecoderState, @@ -833,8 +1104,12 @@ mod tests { fn settings_pair(capacity: u32) -> Arc { let mut settings = Settings::default(); - settings.set(QpackMaxTableCapacity::setting(VarInt::from_u32(capacity))); - settings.set(QpackBlockedStreams::setting(VarInt::from_u32(100))); + settings.set(crate::qpack::settings::QpackMaxTableCapacity::setting( + VarInt::from_u32(capacity), + )); + settings.set(crate::qpack::settings::QpackBlockedStreams::setting( + VarInt::from_u32(100), + )); Arc::new(settings) } @@ -874,6 +1149,34 @@ mod tests { encoder.pending_instructions.clear(); } + #[test] + fn apply_instructions_replays_all_encoder_instruction_variants() { + let settings = settings_pair(4096); + let mut encoder = EncoderState::new(settings.clone()); + encoder.set_max_table_capacity(4096).unwrap(); + encoder + .insert_with_literal_name( + false, + Bytes::from_static(b"x-first"), + false, + Bytes::from_static(b"one"), + ) + .unwrap(); + encoder + .insert_with_name_reference(true, 1, false, Bytes::from_static(b"/custom")) + .unwrap(); + encoder + .insert_with_name_reference(false, 0, false, Bytes::from_static(b"two")) + .unwrap(); + encoder.duplicate(0).unwrap(); + + let mut decoder = DecoderState::new(settings); + apply_instructions(&mut encoder, &mut decoder); + + assert_eq!(decoder.table_inserted_count(), 4); + assert!(encoder.pending_instructions().is_empty()); + } + fn verify_roundtrip( settings: &Settings, decoder: &DecoderState, diff --git a/src/qpack/decoder.rs b/src/qpack/decoder.rs index d31a68c..cfb4711 100644 --- a/src/qpack/decoder.rs +++ b/src/qpack/decoder.rs @@ -1,8 +1,9 @@ use std::{ collections::VecDeque, + ops::DerefMut, pin::{Pin, pin}, sync::{Arc, Mutex as SyncMutex}, - task::{Context, Poll, ready}, + task::{Context, Poll}, }; use bytes::Bytes; @@ -440,7 +441,7 @@ where pub async fn flush_instructions(&self) -> Result<(), StreamError> { let mut decoder_stream = self.decoder_stream.lock().await; - let mut decoder_stream = Pin::new(&mut *decoder_stream); + let mut decoder_stream = Pin::new(decoder_stream.deref_mut()); let instructions = stream::iter(self.pending_instructions()); decoder_stream.as_mut().send_all(instructions).await?; decoder_stream.as_mut().flush().await?; @@ -499,61 +500,77 @@ where } pin_project_lite::pin_project! { - /// A read stream wrapper that emits stream cancellation instruction when reset is received or stop sending is called. - pub struct MessageStreamReader { - decoder: Arc>, + /// A QPACK-aware message read stream wrapper that emits stream cancellation + /// instructions when reset is received or STOP_SENDING is called. + pub struct QPackMessageStreamReader { + stream_id: VarInt, + stream_cancellation_emitted: bool, + decoder: Arc, #[pin] stream: S, } } -impl MessageStreamReader { - pub fn new(stream: S, decoder: Arc>) -> Self { - Self { stream, decoder } +impl QPackMessageStreamReader { + pub fn new(stream_id: VarInt, stream: S, decoder: Arc) -> Self { + Self { + stream_id, + stream_cancellation_emitted: false, + stream, + decoder, + } } } -impl StopStream for MessageStreamReader { +impl StopStream for QPackMessageStreamReader> { fn poll_stop( self: Pin<&mut Self>, cx: &mut Context, code: VarInt, ) -> Poll> { - let mut project = self.project(); - let stream_id = ready!(project.stream.as_mut().poll_stream_id(cx))?.into_inner(); - ready!(project.stream.poll_stop(cx, code))?; - project - .decoder - .emit(DecoderInstruction::StreamCancellation { stream_id }); - Poll::Ready(Ok(())) + let project = self.project(); + let poll = project.stream.poll_stop(cx, code); + if !*project.stream_cancellation_emitted { + project + .decoder + .emit(DecoderInstruction::StreamCancellation { + stream_id: project.stream_id.into_inner(), + }); + *project.stream_cancellation_emitted = true; + } + poll } } -impl GetStreamId for MessageStreamReader { +impl GetStreamId for QPackMessageStreamReader { fn poll_stream_id( self: Pin<&mut Self>, - cx: &mut Context, + _cx: &mut Context, ) -> Poll> { - self.project().stream.poll_stream_id(cx) + Poll::Ready(Ok(*self.project().stream_id)) } } -impl Stream for MessageStreamReader { +impl Stream for QPackMessageStreamReader> { type Item = S::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut project = self.project(); - let stream_id = ready!(project.stream.as_mut().poll_stream_id(cx))?.into_inner(); - match project.stream.poll_next(cx) { - poll @ Poll::Ready(Some(Err(quic::StreamError::Reset { .. }))) => { - project - .decoder - .emit(DecoderInstruction::StreamCancellation { stream_id }); - poll - } - poll => poll, + let project = self.project(); + let poll = project.stream.poll_next(cx); + if matches!( + poll, + Poll::Ready(Some(Err(quic::StreamError::Reset { .. }))) + ) && !*project.stream_cancellation_emitted + { + project + .decoder + .emit(DecoderInstruction::StreamCancellation { + stream_id: project.stream_id.into_inner(), + }); + *project.stream_cancellation_emitted = true; } + poll } } @@ -661,14 +678,36 @@ where #[cfg(test)] mod tests { - use std::sync::Arc; + use std::{ + collections::VecDeque, + io::Cursor, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, + }; - use bytes::Bytes; + use bytes::{Buf, Bytes}; + use futures::{Sink, StreamExt, stream}; + use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf}; - use super::{DecoderInstruction, DecoderState}; + use super::{ + Decoder, DecoderInstruction, DecoderState, InvalidDynamicTableReference, + QPackEncoderStreamError, QPackMessageStreamReader, decompression_field_line_representation, + }; use crate::{ - dhttp::settings::{QpackMaxTableCapacity, Settings}, - qpack::r#static, + buflist::BufList, + codec::{DecodeFrom, EncodeInto}, + connection::StreamError, + dhttp::settings::Settings, + error::{Code, H3ConnectionError}, + qpack::{ + dynamic::DynamicTable, + encoder::EncoderInstruction, + field::{EncodedFieldSectionPrefix, FieldLine, FieldLineRepresentation}, + settings::QpackMaxTableCapacity, + r#static, + }, + quic::{self, GetStreamId, StopStream, StopStreamExt}, varint::VarInt, }; @@ -678,6 +717,259 @@ mod tests { Arc::new(settings) } + fn test_settings_with_max_field_section_size(capacity: u32, max_size: u32) -> Arc { + let mut settings = Settings::default(); + settings.set(QpackMaxTableCapacity::setting(VarInt::from_u32(capacity))); + settings.set(crate::dhttp::settings::MaxFieldSectionSize::setting( + VarInt::from_u32(max_size), + )); + Arc::new(settings) + } + + fn assert_connection_h3_code(error: StreamError, expected: Code) { + match error { + StreamError::Connection { + source: crate::connection::ConnectionError::H3 { source }, + } => assert_eq!(source.code(), expected), + error => panic!("unexpected error: {error:?}"), + } + } + + async fn encode_decode_roundtrip( + instruction: DecoderInstruction, + ) -> Result { + let mut encoded = Vec::new(); + instruction.encode_into(Cursor::new(&mut encoded)).await?; + DecoderInstruction::decode_from(Cursor::new(encoded)).await + } + + fn dynamic_table_with_entries(entries: &[FieldLine]) -> DynamicTable { + let mut table = DynamicTable::new(); + table.capacity = 512; + for entry in entries { + table.index(entry.clone()); + } + table.known_received_count = table.inserted_count; + table + } + + #[derive(Clone, Default)] + struct RecordingDecoderSink { + instructions: Arc>>, + } + + type RecordedDecoderInstructions = Arc>>; + type TestEncoderStream = + stream::Iter>>; + type TestDecoder = Decoder; + + impl Sink for RecordingDecoderSink { + type Error = StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: DecoderInstruction) -> Result<(), Self::Error> { + self.instructions + .lock() + .expect("lock is not poisoned") + .push(item); + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + fn test_decoder( + settings: Arc, + instructions: Vec>, + ) -> (TestDecoder, RecordedDecoderInstructions) { + let sink = RecordingDecoderSink::default(); + let sent = sink.instructions.clone(); + ( + Decoder::new(settings, sink, stream::iter(instructions)), + sent, + ) + } + + pin_project_lite::pin_project! { + struct TestHeaderFrame { + stream_id: VarInt, + #[pin] + payload: Cursor>, + } + } + + impl AsyncRead for TestHeaderFrame { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.project().payload.poll_read(cx, buf) + } + } + + impl AsyncBufRead for TestHeaderFrame { + fn poll_fill_buf( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.project().payload.poll_fill_buf(cx) + } + + fn consume(self: Pin<&mut Self>, amt: usize) { + self.project().payload.consume(amt); + } + } + + impl GetStreamId for TestHeaderFrame { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + let this = self.project(); + Poll::Ready(Ok(*this.stream_id)) + } + } + + async fn encode_header_payload( + prefix: EncodedFieldSectionPrefix, + representations: Vec, + ) -> Vec { + let mut payload = BufList::new(); + prefix + .encode_into(&mut payload) + .await + .expect("prefix should encode"); + for representation in representations { + representation + .encode_into(&mut payload) + .await + .expect("field line representation should encode"); + } + payload.copy_to_bytes(payload.remaining()).to_vec() + } + + struct TestReadStream { + stream_id: VarInt, + stop_codes: Arc>>, + items: VecDeque>, + } + + impl futures::Stream for TestReadStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.items.pop_front()) + } + } + + struct PendingStopReadStream { + stop_codes: Arc>>, + } + + struct ResetWithoutStreamIdReadStream { + items: VecDeque>, + } + + impl futures::Stream for PendingStopReadStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Pending + } + } + + impl futures::Stream for ResetWithoutStreamIdReadStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.items.pop_front()) + } + } + + impl GetStreamId for PendingStopReadStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + panic!("message stream reader stop should use constructor stream id") + } + } + + impl GetStreamId for ResetWithoutStreamIdReadStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + panic!("message stream reader reset should use constructor stream id") + } + } + + impl StopStream for PendingStopReadStream { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.stop_codes + .lock() + .expect("lock is not poisoned") + .push(code); + Poll::Pending + } + } + + impl StopStream for ResetWithoutStreamIdReadStream { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context, + _code: VarInt, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl GetStreamId for TestReadStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl StopStream for TestReadStream { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.stop_codes + .lock() + .expect("lock is not poisoned") + .push(code); + Poll::Ready(Ok(())) + } + } + #[test] fn decoder_state_default_construction() { let state = DecoderState::new(test_settings(0)); @@ -687,6 +979,44 @@ mod tests { assert!(state.pending_instructions.is_empty()); } + #[test] + fn decoder_error_types_report_qpack_error_codes() { + let encoder_stream_errors = [ + QPackEncoderStreamError::SetDynamicTableCapacityExceeded, + QPackEncoderStreamError::NoEvictableEntryForInsertion, + QPackEncoderStreamError::ReferencedStaticEntryNotExisted { index: 7 }, + QPackEncoderStreamError::ReferencedDynamicEntryNotExisted { index: 11 }, + ]; + + for error in encoder_stream_errors { + assert_eq!(error.code(), Code::QPACK_ENCODER_STREAM_ERROR); + } + + let invalid_references = [ + InvalidDynamicTableReference::IndexOverflow, + InvalidDynamicTableReference::ReferencedStaticEntryNotExisted { index: 7 }, + InvalidDynamicTableReference::ReferencedDynamicEntryNotExisted { index: 11 }, + ]; + + for error in invalid_references { + assert_eq!(error.code(), Code::QPACK_DECOMPRESSION_FAILED); + } + } + + #[test] + fn decoder_wrapper_emit_and_known_received_count_delegate_to_state() { + let (decoder, _) = test_decoder(test_settings(128), Vec::new()); + + decoder.emit(DecoderInstruction::StreamCancellation { stream_id: 9 }); + + assert_eq!(decoder.known_received_count(), 0); + let state = decoder.state.lock().expect("lock is not poisoned"); + assert_eq!( + state.pending_instructions.front(), + Some(&DecoderInstruction::StreamCancellation { stream_id: 9 }) + ); + } + #[test] fn emit_insert_count_increment() { let mut state = DecoderState::new(test_settings(256)); @@ -737,6 +1067,22 @@ mod tests { ); } + #[tokio::test] + async fn decoder_instruction_roundtrip_all_variants() { + let instructions = [ + DecoderInstruction::SectionAcknowledgment { stream_id: 7 }, + DecoderInstruction::StreamCancellation { stream_id: 11 }, + DecoderInstruction::InsertCountIncrement { increment: 5 }, + ]; + + for instruction in instructions { + let decoded = encode_decode_roundtrip(instruction) + .await + .expect("instruction should roundtrip"); + assert_eq!(decoded, instruction); + } + } + #[test] fn set_dynamic_table_capacity_within_limit() { let mut state = DecoderState::new(test_settings(256)); @@ -750,6 +1096,52 @@ mod tests { assert!(state.set_dynamic_table_capacity(512).is_err()); } + #[test] + fn update_known_received_count_merges_pending_increment() { + let mut state = DecoderState::new(test_settings(256)); + + state.update_known_received_count(1); + state.update_known_received_count(3); + + assert_eq!(state.dynamic_table.known_received_count, 3); + assert_eq!(state.pending_instructions.len(), 1); + assert_eq!( + state.pending_instructions[0], + DecoderInstruction::InsertCountIncrement { increment: 3 } + ); + } + + #[test] + fn set_dynamic_table_capacity_evicts_entries_until_within_limit() { + let mut state = DecoderState::new(test_settings(128)); + state.set_dynamic_table_capacity(128).unwrap(); + state + .insert_with_literal_name( + Bytes::from_static(b"header-1"), + Bytes::from_static(b"value-1"), + ) + .unwrap(); + state + .insert_with_literal_name( + Bytes::from_static(b"header-2"), + Bytes::from_static(b"value-2"), + ) + .unwrap(); + + state.set_dynamic_table_capacity(64).unwrap(); + + assert_eq!(state.dynamic_table.capacity, 64); + assert_eq!(state.dynamic_table.dropped_count, 1); + assert_eq!( + state + .dynamic_table + .entries() + .map(|(index, entry)| (index, entry.name.clone())) + .collect::>(), + vec![(1, Bytes::from_static(b"header-2"))] + ); + } + #[test] fn insert_with_literal_name() { let mut state = DecoderState::new(test_settings(4096)); @@ -779,6 +1171,100 @@ mod tests { assert_eq!(&entry.value[..], b"example.com"); } + #[test] + fn insert_with_dynamic_name_reference_reuses_existing_name() { + let mut state = DecoderState::new(test_settings(128)); + state.set_dynamic_table_capacity(128).unwrap(); + state + .insert_with_literal_name(Bytes::from_static(b"x-name"), Bytes::from_static(b"old")) + .expect("seed entry should fit"); + + state + .insert_with_name_reference(false, 0, Bytes::from_static(b"new")) + .expect("dynamic name reference should insert"); + + let entry = state + .dynamic_table + .get(1) + .expect("second entry should exist"); + assert_eq!(entry.name, Bytes::from_static(b"x-name")); + assert_eq!(entry.value, Bytes::from_static(b"new")); + assert_eq!(state.table_inserted_count(), 2); + } + + #[test] + fn insert_with_name_reference_evicts_acknowledged_entry_when_needed() { + let mut state = DecoderState::new(test_settings(80)); + state.set_dynamic_table_capacity(80).unwrap(); + state + .insert_with_literal_name(Bytes::from_static(b"x-old"), Bytes::from_static(b"old")) + .expect("seed entry should fit"); + + state + .insert_with_name_reference(true, 1, Bytes::from_static(b"/very/long/path/value")) + .expect("new entry should fit after eviction"); + + assert_eq!(state.dynamic_table.dropped_count, 1); + assert!(state.dynamic_table.get(0).is_none()); + let entry = state + .dynamic_table + .get(1) + .expect("inserted entry should remain"); + assert_eq!(entry.name, Bytes::from_static(b":path")); + assert_eq!(entry.value, Bytes::from_static(b"/very/long/path/value")); + } + + #[test] + fn insert_with_literal_name_evicts_acknowledged_entry_when_needed() { + let mut state = DecoderState::new(test_settings(80)); + state.set_dynamic_table_capacity(80).unwrap(); + state + .insert_with_literal_name(Bytes::from_static(b"x-old"), Bytes::from_static(b"old")) + .expect("seed entry should fit"); + + state + .insert_with_literal_name( + Bytes::from_static(b"x-new"), + Bytes::from_static(b"larger-new-value"), + ) + .expect("new literal should fit after eviction"); + + assert_eq!(state.dynamic_table.dropped_count, 1); + assert!(state.dynamic_table.get(0).is_none()); + let entry = state + .dynamic_table + .get(1) + .expect("inserted entry should remain"); + assert_eq!(entry.name, Bytes::from_static(b"x-new")); + assert_eq!(entry.value, Bytes::from_static(b"larger-new-value")); + } + + #[test] + fn insert_with_name_reference_errors_for_missing_references() { + let mut state = DecoderState::new(test_settings(4096)); + state.set_dynamic_table_capacity(4096).unwrap(); + + let missing_static = + state.insert_with_name_reference(true, 9999, Bytes::from_static(b"value")); + let missing_dynamic = + state.insert_with_name_reference(false, 0, Bytes::from_static(b"value")); + + assert!( + matches!( + missing_static, + Err(QPackEncoderStreamError::ReferencedStaticEntryNotExisted { index: 9999 }) + ), + "unexpected static result: {missing_static:?}" + ); + assert!( + matches!( + missing_dynamic, + Err(QPackEncoderStreamError::ReferencedDynamicEntryNotExisted { index: 0 }) + ), + "unexpected dynamic result: {missing_dynamic:?}" + ); + } + #[test] fn insert_fails_when_capacity_too_small() { let mut state = DecoderState::new(test_settings(4096)); @@ -789,6 +1275,442 @@ mod tests { assert!(result.is_err()); } + #[test] + fn duplicate_errors_for_missing_dynamic_entry() { + let mut state = DecoderState::new(test_settings(256)); + state.set_dynamic_table_capacity(256).unwrap(); + + let result = state.duplicate(0); + + assert!( + matches!( + result, + Err(QPackEncoderStreamError::ReferencedDynamicEntryNotExisted { index: 0 }) + ), + "Expected missing dynamic entry error, got {result:?}" + ); + } + + #[test] + fn duplicate_existing_entry_evicts_when_needed_and_updates_count() { + let mut state = DecoderState::new(test_settings(50)); + state.set_dynamic_table_capacity(50).unwrap(); + state + .insert_with_literal_name(Bytes::from_static(b"x-name"), Bytes::from_static(b"value")) + .expect("seed entry should fit"); + state.pending_instructions.clear(); + + state + .duplicate(0) + .expect("duplicate should fit after eviction"); + + assert_eq!(state.dynamic_table.dropped_count, 1); + assert!(state.dynamic_table.get(0).is_none()); + let entry = state.dynamic_table.get(1).expect("duplicate should remain"); + assert_eq!(entry.name, Bytes::from_static(b"x-name")); + assert_eq!(entry.value, Bytes::from_static(b"value")); + assert_eq!(state.dynamic_table.known_received_count, 2); + assert_eq!( + state.pending_instructions.back(), + Some(&DecoderInstruction::InsertCountIncrement { increment: 1 }) + ); + } + + #[tokio::test] + async fn receive_instruction_until_applies_encoder_stream_mutations() { + let instructions = vec![ + Ok(EncoderInstruction::SetDynamicTableCapacity { capacity: 128 }), + Ok(EncoderInstruction::InsertWithLiteralName { + name_huffman: false, + name: Bytes::from_static(b"x-name"), + value_huffman: false, + value: Bytes::from_static(b"value"), + }), + Ok(EncoderInstruction::InsertWithNameReference { + is_static: false, + name_index: 0, + huffman: false, + value: Bytes::from_static(b"referenced"), + }), + Ok(EncoderInstruction::Duplicate { index: 0 }), + ]; + let (decoder, _) = test_decoder(test_settings(128), instructions); + + decoder + .receive_instruction_until(3) + .await + .expect("encoder instructions should be applied"); + + let state = decoder.state.lock().expect("lock is not poisoned"); + assert_eq!(state.dynamic_table.capacity, 128); + assert_eq!(state.dynamic_table.inserted_count, 3); + assert_eq!( + state.dynamic_table.get(1).expect("referenced entry").name, + Bytes::from_static(b"x-name") + ); + assert_eq!( + state.dynamic_table.get(2).expect("duplicate entry").value, + Bytes::from_static(b"referenced") + ); + } + + #[tokio::test] + async fn receive_instruction_until_returns_when_count_already_known() { + let (decoder, _) = test_decoder(test_settings(128), Vec::new()); + + decoder + .receive_instruction_until(0) + .await + .expect("zero required count is already known"); + } + + #[tokio::test] + async fn receive_instruction_until_reports_closed_encoder_stream() { + let (decoder, _) = test_decoder(test_settings(128), Vec::new()); + + let error = decoder + .receive_instruction_until(1) + .await + .expect_err("missing encoder stream instruction should close the connection"); + + assert_connection_h3_code(error, Code::H3_CLOSED_CRITICAL_STREAM); + } + + #[tokio::test] + async fn receive_instruction_until_propagates_encoder_stream_read_error() { + let (decoder, _) = test_decoder( + test_settings(128), + vec![Err(StreamError::Reset { + code: VarInt::from_u32(33), + })], + ); + + let error = decoder + .receive_instruction_until(1) + .await + .expect_err("encoder stream read error should propagate"); + + assert!(matches!( + error, + StreamError::Reset { code } if code == VarInt::from_u32(33) + )); + } + + #[tokio::test] + async fn receive_instruction_until_rejects_invalid_encoder_instructions() { + let cases = [ + vec![Ok(EncoderInstruction::SetDynamicTableCapacity { + capacity: 256, + })], + vec![ + Ok(EncoderInstruction::SetDynamicTableCapacity { capacity: 64 }), + Ok(EncoderInstruction::InsertWithNameReference { + is_static: true, + name_index: 9999, + huffman: false, + value: Bytes::from_static(b"value"), + }), + ], + vec![ + Ok(EncoderInstruction::SetDynamicTableCapacity { capacity: 0 }), + Ok(EncoderInstruction::InsertWithLiteralName { + name_huffman: false, + name: Bytes::from_static(b"name"), + value_huffman: false, + value: Bytes::from_static(b"value"), + }), + ], + ]; + + for instructions in cases { + let (decoder, _) = test_decoder(test_settings(128), instructions); + let error = decoder + .receive_instruction_until(1) + .await + .expect_err("invalid encoder instruction should close the connection"); + + assert_connection_h3_code(error, Code::QPACK_ENCODER_STREAM_ERROR); + } + } + + #[tokio::test] + async fn flush_instructions_sends_and_drains_pending_queue() { + let (decoder, sent) = test_decoder(test_settings(128), Vec::new()); + decoder.emit(DecoderInstruction::StreamCancellation { stream_id: 1 }); + decoder.emit(DecoderInstruction::StreamCancellation { stream_id: 2 }); + + decoder + .flush_instructions() + .await + .expect("pending instructions should flush"); + + assert_eq!( + sent.lock().expect("lock is not poisoned").as_slice(), + &[ + DecoderInstruction::StreamCancellation { stream_id: 1 }, + DecoderInstruction::StreamCancellation { stream_id: 2 }, + ] + ); + assert!( + decoder + .state + .lock() + .expect("lock is not poisoned") + .pending_instructions + .is_empty() + ); + } + + #[tokio::test] + async fn decode_emits_section_acknowledgment_for_dynamic_reference() { + let (decoder, sent) = test_decoder(test_settings(128), Vec::new()); + { + let mut state = decoder.state.lock().expect("lock is not poisoned"); + state.set_dynamic_table_capacity(128).unwrap(); + state + .insert_with_literal_name(Bytes::from_static(b"x-name"), Bytes::from_static(b"ok")) + .expect("seed entry should fit"); + state.pending_instructions.clear(); + } + + let prefix = EncodedFieldSectionPrefix { + encoded_insert_count: EncodedFieldSectionPrefix::encode_ric(1, 128), + sign: false, + delta_base: 0, + }; + let payload = encode_header_payload( + prefix, + vec![FieldLineRepresentation::IndexedFieldLine { + is_static: false, + index: 0, + }], + ) + .await; + let frame = TestHeaderFrame { + stream_id: VarInt::from_u32(23), + payload: Cursor::new(payload), + }; + + let section = decoder.decode(frame).await.expect("header should decode"); + + assert_eq!( + section.header_map.get("x-name").expect("x-name header"), + "ok" + ); + assert_eq!( + sent.lock().expect("lock is not poisoned").as_slice(), + &[DecoderInstruction::SectionAcknowledgment { stream_id: 23 }] + ); + } + + #[tokio::test] + async fn decode_rejects_field_section_that_exceeds_configured_limit() { + let (decoder, _) = test_decoder( + test_settings_with_max_field_section_size(128, 1), + Vec::new(), + ); + let payload = encode_header_payload( + EncodedFieldSectionPrefix { + encoded_insert_count: 0, + sign: false, + delta_base: 0, + }, + vec![FieldLineRepresentation::LiteralFieldLineWithLiteralName { + never_dynamic: false, + name_huffman: false, + name: Bytes::from_static(b"x-large"), + value_huffman: false, + value: Bytes::from_static(b"value"), + }], + ) + .await; + let frame = TestHeaderFrame { + stream_id: VarInt::from_u32(31), + payload: Cursor::new(payload), + }; + + let error = decoder + .decode(frame) + .await + .expect_err("limit should reject field"); + + match error { + StreamError::H3 { source } => assert_eq!(source.code(), Code::H3_EXCESSIVE_LOAD), + error => panic!("unexpected error: {error:?}"), + } + } + + #[tokio::test] + async fn message_stream_reader_stop_and_reset_emit_stream_cancellation() { + let (decoder, _) = test_decoder(test_settings(128), Vec::new()); + let decoder = Arc::new(decoder); + let stop_codes = Arc::new(Mutex::new(Vec::new())); + let mut stopped_reader = QPackMessageStreamReader::new( + VarInt::from_u32(41), + TestReadStream { + stream_id: VarInt::from_u32(41), + stop_codes: stop_codes.clone(), + items: VecDeque::new(), + }, + decoder.clone(), + ); + + stopped_reader + .stop(VarInt::from_u32(7)) + .await + .expect("stop should be forwarded"); + + assert_eq!( + stop_codes.lock().expect("lock is not poisoned").as_slice(), + &[VarInt::from_u32(7)] + ); + assert_eq!( + decoder + .state + .lock() + .expect("lock is not poisoned") + .pending_instructions + .back(), + Some(&DecoderInstruction::StreamCancellation { stream_id: 41 }) + ); + + let mut reset_reader = QPackMessageStreamReader::new( + VarInt::from_u32(43), + TestReadStream { + stream_id: VarInt::from_u32(43), + stop_codes, + items: VecDeque::from([Err(quic::StreamError::Reset { + code: VarInt::from_u32(11), + })]), + }, + decoder.clone(), + ); + + let item = reset_reader + .next() + .await + .expect("reset item should be yielded"); + assert!(matches!( + item, + Err(quic::StreamError::Reset { code }) if code == VarInt::from_u32(11) + )); + assert_eq!( + decoder + .state + .lock() + .expect("lock is not poisoned") + .pending_instructions + .back(), + Some(&DecoderInstruction::StreamCancellation { stream_id: 43 }) + ); + } + + #[test] + fn message_stream_reader_pending_stop_emits_stream_cancellation_once() { + let (decoder, _) = test_decoder(test_settings(128), Vec::new()); + let decoder = Arc::new(decoder); + let stop_codes = Arc::new(Mutex::new(Vec::new())); + let mut reader = Box::pin(QPackMessageStreamReader::new( + VarInt::from_u32(45), + PendingStopReadStream { + stop_codes: stop_codes.clone(), + }, + decoder.clone(), + )); + let waker = futures::task::noop_waker(); + let mut cx = Context::from_waker(&waker); + + assert!(matches!( + reader.as_mut().poll_stop(&mut cx, VarInt::from_u32(7)), + Poll::Pending + )); + assert!(matches!( + reader.as_mut().poll_stop(&mut cx, VarInt::from_u32(7)), + Poll::Pending + )); + + assert_eq!( + stop_codes.lock().expect("lock is not poisoned").as_slice(), + &[VarInt::from_u32(7), VarInt::from_u32(7)] + ); + let state = decoder.state.lock().expect("lock is not poisoned"); + assert_eq!(state.pending_instructions.len(), 1); + assert_eq!( + state.pending_instructions.back(), + Some(&DecoderInstruction::StreamCancellation { stream_id: 45 }) + ); + } + + #[tokio::test] + async fn message_stream_reader_does_not_emit_duplicate_cancellation_after_stop() { + let (decoder, _) = test_decoder(test_settings(128), Vec::new()); + let decoder = Arc::new(decoder); + let stop_codes = Arc::new(Mutex::new(Vec::new())); + let mut reader = QPackMessageStreamReader::new( + VarInt::from_u32(46), + TestReadStream { + stream_id: VarInt::from_u32(46), + stop_codes, + items: VecDeque::from([Err(quic::StreamError::Reset { + code: VarInt::from_u32(11), + })]), + }, + decoder.clone(), + ); + + reader + .stop(VarInt::from_u32(7)) + .await + .expect("stop should be forwarded"); + decoder.emit(DecoderInstruction::InsertCountIncrement { increment: 1 }); + + let item = reader.next().await.expect("reset item should be yielded"); + assert!(matches!( + item, + Err(quic::StreamError::Reset { code }) if code == VarInt::from_u32(11) + )); + + let state = decoder.state.lock().expect("lock is not poisoned"); + assert_eq!( + state.pending_instructions.iter().collect::>(), + vec![ + &DecoderInstruction::StreamCancellation { stream_id: 46 }, + &DecoderInstruction::InsertCountIncrement { increment: 1 }, + ] + ); + } + + #[tokio::test] + async fn message_stream_reader_reset_uses_constructor_stream_id() { + let (decoder, _) = test_decoder(test_settings(128), Vec::new()); + let decoder = Arc::new(decoder); + let mut reader = QPackMessageStreamReader::new( + VarInt::from_u32(47), + ResetWithoutStreamIdReadStream { + items: VecDeque::from([Err(quic::StreamError::Reset { + code: VarInt::from_u32(11), + })]), + }, + decoder.clone(), + ); + + let item = reader.next().await.expect("reset item should be yielded"); + assert!(matches!( + item, + Err(quic::StreamError::Reset { code }) if code == VarInt::from_u32(11) + )); + + assert_eq!( + decoder + .state + .lock() + .expect("lock is not poisoned") + .pending_instructions + .back(), + Some(&DecoderInstruction::StreamCancellation { stream_id: 47 }) + ); + } + #[test] fn static_table_lookup() { // Verify well-known static table entries @@ -800,9 +1722,6 @@ mod tests { #[test] fn decompression_static_indexed_field_line() { - use super::decompression_field_line_representation; - use crate::qpack::{dynamic::DynamicTable, field::FieldLineRepresentation}; - let dt = DynamicTable::new(); let repr = FieldLineRepresentation::IndexedFieldLine { is_static: true, @@ -815,9 +1734,6 @@ mod tests { #[test] fn decompression_literal_with_literal_name() { - use super::decompression_field_line_representation; - use crate::qpack::{dynamic::DynamicTable, field::FieldLineRepresentation}; - let dt = DynamicTable::new(); let repr = FieldLineRepresentation::LiteralFieldLineWithLiteralName { never_dynamic: false, @@ -833,9 +1749,6 @@ mod tests { #[test] fn decompression_invalid_static_index() { - use super::decompression_field_line_representation; - use crate::qpack::{dynamic::DynamicTable, field::FieldLineRepresentation}; - let dt = DynamicTable::new(); let repr = FieldLineRepresentation::IndexedFieldLine { is_static: true, @@ -843,4 +1756,101 @@ mod tests { }; assert!(decompression_field_line_representation(&repr, 0, &dt).is_err()); } + + #[test] + fn decompression_dynamic_references_resolve_relative_and_post_base_indices() { + let dt = dynamic_table_with_entries(&[ + FieldLine { + name: Bytes::from_static(b"header-1"), + value: Bytes::from_static(b"value-1"), + }, + FieldLine { + name: Bytes::from_static(b"header-2"), + value: Bytes::from_static(b"value-2"), + }, + FieldLine { + name: Bytes::from_static(b"header-3"), + value: Bytes::from_static(b"value-3"), + }, + ]); + + let indexed = FieldLineRepresentation::IndexedFieldLine { + is_static: false, + index: 0, + }; + let indexed_post_base = + FieldLineRepresentation::IndexedFieldLineWithPostBaseIndex { index: 0 }; + let literal_name_ref = FieldLineRepresentation::LiteralFieldLineWithNameReference { + never_dynamic: false, + is_static: false, + name_index: 1, + huffman: false, + value: Bytes::from_static(b"patched"), + }; + let literal_post_base = + FieldLineRepresentation::LiteralFieldLineWithPostBaseNameReference { + never_dynamic: false, + name_index: 0, + huffman: false, + value: Bytes::from_static(b"patched-post"), + }; + + assert_eq!( + decompression_field_line_representation(&indexed, 3, &dt).unwrap(), + FieldLine { + name: Bytes::from_static(b"header-3"), + value: Bytes::from_static(b"value-3"), + } + ); + assert_eq!( + decompression_field_line_representation(&indexed_post_base, 2, &dt).unwrap(), + FieldLine { + name: Bytes::from_static(b"header-3"), + value: Bytes::from_static(b"value-3"), + } + ); + assert_eq!( + decompression_field_line_representation(&literal_name_ref, 3, &dt).unwrap(), + FieldLine { + name: Bytes::from_static(b"header-2"), + value: Bytes::from_static(b"patched"), + } + ); + assert_eq!( + decompression_field_line_representation(&literal_post_base, 2, &dt).unwrap(), + FieldLine { + name: Bytes::from_static(b"header-3"), + value: Bytes::from_static(b"patched-post"), + } + ); + } + + #[test] + fn decompression_reports_index_overflow_and_missing_dynamic_entry() { + let dt = dynamic_table_with_entries(&[FieldLine { + name: Bytes::from_static(b"header-1"), + value: Bytes::from_static(b"value-1"), + }]); + + let overflow = FieldLineRepresentation::IndexedFieldLine { + is_static: false, + index: 1, + }; + let missing = FieldLineRepresentation::IndexedFieldLineWithPostBaseIndex { index: 10 }; + + assert!( + matches!( + decompression_field_line_representation(&overflow, 0, &dt), + Err(InvalidDynamicTableReference::IndexOverflow) + ), + "expected relative index overflow" + ); + assert!( + matches!( + decompression_field_line_representation(&missing, 1, &dt), + Err(InvalidDynamicTableReference::ReferencedDynamicEntryNotExisted { index: 11 }) + ), + "expected missing post-base entry" + ); + } } diff --git a/src/qpack/dynamic.rs b/src/qpack/dynamic.rs index f584044..0424bdb 100644 --- a/src/qpack/dynamic.rs +++ b/src/qpack/dynamic.rs @@ -89,3 +89,150 @@ impl DynamicTable { self.known_received_count += increment; } } + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::DynamicTable; + use crate::qpack::field::FieldLine; + + fn field_line(name: &'static [u8], value: &'static [u8]) -> FieldLine { + FieldLine { + name: Bytes::from_static(name), + value: Bytes::from_static(value), + } + } + + #[test] + fn get_and_get_mut_respect_dropped_count() { + let mut table = DynamicTable::new(); + table.capacity = 128; + table.index(field_line(b"header-1", b"value-1")); + table.index(field_line(b"header-2", b"value-2")); + table.increment_known_received_count(1); + + let (evicted_index, evicted) = table.evict(); + assert_eq!(evicted_index, 0); + assert_eq!(evicted.name, Bytes::from_static(b"header-1")); + assert!(table.get(0).is_none()); + + let entry = table.get_mut(1).expect("second entry should remain"); + entry.value = Bytes::from_static(b"updated"); + assert_eq!( + table.get(1).expect("updated entry should exist").value, + Bytes::from_static(b"updated") + ); + } + + #[test] + fn evictable_depends_on_known_received_count() { + let mut table = DynamicTable::new(); + table.capacity = 64; + table.index(field_line(b"header-1", b"value-1")); + + assert!( + !table.evictable(), + "unacknowledged entry must not be evictable" + ); + + table.increment_known_received_count(1); + assert!( + table.evictable(), + "acknowledged entry should become evictable" + ); + } + + #[test] + fn entries_preserve_absolute_indices_after_eviction() { + let mut table = DynamicTable::new(); + table.capacity = 192; + table.index(field_line(b"header-1", b"value-1")); + table.index(field_line(b"header-2", b"value-2")); + table.index(field_line(b"header-3", b"value-3")); + table.increment_known_received_count(2); + table.evict(); + + let entries = table + .entries() + .map(|(index, entry)| (index, entry.name.clone())) + .collect::>(); + + assert_eq!( + entries, + vec![ + (1, Bytes::from_static(b"header-2")), + (2, Bytes::from_static(b"header-3")), + ] + ); + } + + #[test] + fn evict_updates_size_and_counts() { + let mut table = DynamicTable::new(); + let entry = field_line(b"header-1", b"value-1"); + let entry_size = entry.size(); + table.capacity = entry_size; + table.index(entry); + table.increment_known_received_count(1); + + let (index, _) = table.evict(); + + assert_eq!(index, 0); + assert_eq!(table.size, 0); + assert_eq!(table.dropped_count, 1); + assert!(table.is_empty()); + } + + #[test] + fn increment_known_received_count_accumulates() { + let mut table = DynamicTable::new(); + table.increment_known_received_count(2); + table.increment_known_received_count(3); + assert_eq!(table.known_received_count, 5); + } + + #[test] + fn default_matches_new_empty_table() { + let table = DynamicTable::default(); + + assert_eq!(table.inserted_count, 0); + assert_eq!(table.dropped_count, 0); + assert_eq!(table.known_received_count, 0); + assert_eq!(table.size, 0); + assert_eq!(table.capacity, 0); + assert!(table.is_empty()); + assert_eq!(table.entries().len(), 0); + } + + #[test] + fn get_mut_returns_none_for_evicted_absolute_index() { + let mut table = DynamicTable::new(); + table.capacity = 128; + table.index(field_line(b"header-1", b"value-1")); + table.index(field_line(b"header-2", b"value-2")); + table.increment_known_received_count(1); + table.evict(); + + assert!(table.get_mut(0).is_none()); + assert!(table.get_mut(1).is_some()); + } + + #[test] + #[should_panic(expected = "No evictable entry exist")] + fn evict_panics_when_no_entry_is_acknowledged() { + let mut table = DynamicTable::new(); + table.capacity = 64; + table.index(field_line(b"header-1", b"value-1")); + + table.evict(); + } + + #[test] + #[should_panic(expected = "Dynamic table size exceeded its capacity")] + fn index_panics_when_capacity_is_exceeded() { + let mut table = DynamicTable::new(); + table.capacity = 10; + table.index(field_line(b"header-1", b"value-1")); + } +} diff --git a/src/qpack/encoder.rs b/src/qpack/encoder.rs index b467cd5..2f68f68 100644 --- a/src/qpack/encoder.rs +++ b/src/qpack/encoder.rs @@ -662,16 +662,34 @@ where #[cfg(test)] mod tests { - use std::sync::Arc; + use std::{ + collections::VecDeque, + io::Cursor, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, + }; use bytes::Bytes; + use futures::{Sink, SinkExt, stream}; + use tokio::io::{AsyncWrite, AsyncWriteExt}; use crate::{ - dhttp::settings::{QpackBlockedStreams, QpackMaxTableCapacity, Settings}, + codec::{DecodeFrom, EncodeInto}, + connection::StreamError, + dhttp::settings::Settings, + error::{Code, H3ConnectionError}, qpack::{ - encoder::{EncoderState, QPackDecoderStreamError, QPackEncoderError}, - field::FieldLineRepresentation, + algorithm::{Algorithm, CompressOutput}, + decoder::DecoderInstruction, + encoder::{ + EncodeHeaderSectionError, Encoder, EncoderInstruction, EncoderState, + QPackDecoderStreamError, QPackEncoderError, + }, + field::{EncodedFieldSectionPrefix, FieldLine, FieldLineRepresentation}, + r#static, }, + quic::{self, GetStreamId}, varint::VarInt, }; @@ -700,13 +718,471 @@ mod tests { } fn settings_with_capacity(capacity: u32) -> Arc { + settings_with_capacity_and_blocked_streams(capacity, 10) + } + + fn settings_with_capacity_and_blocked_streams( + capacity: u32, + blocked_streams: u32, + ) -> Arc { let mut settings = Settings::default(); - settings.set(QpackMaxTableCapacity::setting(VarInt::from_u32(capacity))); - settings.set(QpackBlockedStreams::setting(VarInt::from_u32(10))); + settings.set(crate::qpack::settings::QpackMaxTableCapacity::setting( + VarInt::from_u32(capacity), + )); + settings.set(crate::qpack::settings::QpackBlockedStreams::setting( + VarInt::from_u32(blocked_streams), + )); Arc::new(settings) } - // --- Fix 2: zero increment rejection --- + fn assert_connection_h3_code(error: StreamError, expected: Code) { + match error { + StreamError::Connection { + source: crate::connection::ConnectionError::H3 { source }, + } => assert_eq!(source.code(), expected), + error => panic!("unexpected error: {error:?}"), + } + } + + fn state_with_full_unacknowledged_entry() -> EncoderState { + let mut state = EncoderState::new(settings_with_capacity(128)); + state + .set_max_table_capacity(64) + .expect("set capacity failed"); + state + .insert_with_literal_name( + false, + Bytes::from(vec![b'x'; 16]), + false, + Bytes::from(vec![b'y'; 16]), + ) + .expect("insert full-size entry"); + state + } + + #[derive(Default)] + struct RecordingSink { + bytes: Vec, + } + + impl AsyncWrite for RecordingSink { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.bytes.extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl Sink for RecordingSink { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.bytes.extend_from_slice(&item); + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + #[derive(Clone, Default)] + struct RecordingInstructionSink { + instructions: Arc>>, + flushes: Arc>, + } + + impl RecordingInstructionSink { + fn instructions(&self) -> Vec { + self.instructions + .lock() + .expect("instruction lock is not poisoned") + .clone() + } + + fn flushes(&self) -> usize { + *self.flushes.lock().expect("flush lock is not poisoned") + } + } + + impl Sink for RecordingInstructionSink { + type Error = StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: EncoderInstruction) -> Result<(), Self::Error> { + self.instructions + .lock() + .expect("instruction lock is not poisoned") + .push(item); + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + *self.flushes.lock().expect("flush lock is not poisoned") += 1; + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + struct TestStreamId(u32); + + impl GetStreamId for TestStreamId { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(VarInt::from_u32(self.0))) + } + } + + #[derive(Default)] + struct MayBlockProbeAlgorithm { + max_referenced_index: Option, + observed_may_block: Arc>>, + } + + impl MayBlockProbeAlgorithm { + fn with_max_referenced_index(max_referenced_index: u64) -> Self { + Self { + max_referenced_index: Some(max_referenced_index), + observed_may_block: Arc::default(), + } + } + + fn observed_may_block(&self) -> Vec { + self.observed_may_block + .lock() + .expect("may_block lock is not poisoned") + .clone() + } + } + + impl Algorithm for MayBlockProbeAlgorithm { + async fn compress( + &self, + _state: &mut EncoderState, + _entries: impl IntoIterator + Send, + may_block: bool, + ) -> CompressOutput { + self.observed_may_block + .lock() + .expect("may_block lock is not poisoned") + .push(may_block); + CompressOutput { + prefix: EncodedFieldSectionPrefix { + encoded_insert_count: 0, + sign: false, + delta_base: 0, + }, + representations: Vec::new(), + max_referenced_index: self.max_referenced_index, + } + } + } + + async fn encode_decode_roundtrip( + instruction: EncoderInstruction, + ) -> Result { + let mut sink = RecordingSink::default(); + instruction.clone().encode_into(&mut sink).await?; + EncoderInstruction::decode_from(Cursor::new(sink.bytes)).await + } + + // --- Core encoding tests --- + + #[test] + fn decoder_stream_errors_report_qpack_decoder_stream_error_code() { + let cases = [ + QPackDecoderStreamError::AcknowledgeNonExistSection { stream_id: 7 }, + QPackDecoderStreamError::IncrementKnownReceivedCountOverflow, + QPackDecoderStreamError::IncrementZero, + ]; + + for error in cases { + assert_eq!(error.code(), Code::QPACK_DECODER_STREAM_ERROR); + } + } + + #[test] + fn encoder_state_accessors_expose_initial_and_capacity_state() { + let settings = settings_with_capacity(128); + let mut state = EncoderState::new(settings.clone()); + + assert!(std::ptr::eq(state.settings(), settings.as_ref())); + assert_eq!(state.table_capacity(), 0); + assert_eq!(state.table_size(), 0); + assert_eq!(state.table_inserted_count(), 0); + assert_eq!(state.table_known_received_count(), 0); + assert_eq!(state.table_dropped_count(), 0); + assert_eq!(state.table_remaining(), 0); + assert!(state.entries().next().is_none()); + assert!(state.find_name(&Bytes::from_static(b"missing")).is_none()); + assert!(state.find_value(&Bytes::from_static(b"missing")).is_none()); + assert!(state.pending_instructions().is_empty()); + + state + .set_max_table_capacity(64) + .expect("capacity within peer setting"); + assert_eq!(state.table_capacity(), 64); + assert_eq!(state.table_remaining(), 64); + assert_eq!( + state.pending_instructions().back(), + Some(&EncoderInstruction::SetDynamicTableCapacity { capacity: 64 }) + ); + } + + #[tokio::test] + async fn encoder_instruction_roundtrip_all_variants() { + let instructions = [ + EncoderInstruction::SetDynamicTableCapacity { capacity: 123 }, + EncoderInstruction::InsertWithNameReference { + is_static: true, + name_index: 15, + huffman: false, + value: Bytes::from_static(b"value"), + }, + EncoderInstruction::InsertWithLiteralName { + name_huffman: false, + name: Bytes::from_static(b"x-test"), + value_huffman: false, + value: Bytes::from_static(b"payload"), + }, + EncoderInstruction::Duplicate { index: 7 }, + ]; + + for instruction in instructions { + let decoded = encode_decode_roundtrip(instruction.clone()) + .await + .expect("instruction should roundtrip"); + assert_eq!(decoded, instruction); + } + } + + #[tokio::test] + async fn encoder_instruction_roundtrips_extended_integer_values() { + let instructions = [ + EncoderInstruction::SetDynamicTableCapacity { capacity: 4096 }, + EncoderInstruction::InsertWithNameReference { + is_static: false, + name_index: 130, + huffman: false, + value: Bytes::from_static(b"value"), + }, + EncoderInstruction::Duplicate { index: 257 }, + ]; + + for instruction in instructions { + let decoded = encode_decode_roundtrip(instruction.clone()) + .await + .expect("instruction should roundtrip"); + assert_eq!(decoded, instruction); + } + } + + #[tokio::test] + async fn encoder_instruction_decode_truncated_input_returns_error() { + let result = EncoderInstruction::decode_from(Cursor::new(Vec::::new())).await; + + assert!(result.is_err(), "empty encoder instruction must fail"); + } + + #[tokio::test] + async fn encoder_instruction_encodes_huffman_and_static_flags_on_wire() { + let mut sink = RecordingSink::default(); + EncoderInstruction::InsertWithNameReference { + is_static: true, + name_index: 10, + huffman: true, + value: Bytes::from_static(b"abc"), + } + .encode_into(&mut sink) + .await + .expect("instruction should encode"); + + assert_eq!(sink.bytes[0], 0b1100_1010); + assert_eq!(sink.bytes[1] & 0b1000_0000, 0b1000_0000); + + let decoded = EncoderInstruction::decode_from(Cursor::new(sink.bytes)) + .await + .expect("instruction should decode"); + assert_eq!( + decoded, + EncoderInstruction::InsertWithNameReference { + is_static: true, + name_index: 10, + huffman: true, + value: Bytes::from_static(b"abc"), + } + ); + } + + #[tokio::test] + async fn encoder_instruction_encodes_literal_name_huffman_flags_independently() { + let mut sink = RecordingSink::default(); + EncoderInstruction::InsertWithLiteralName { + name_huffman: true, + name: Bytes::from_static(b"abc"), + value_huffman: false, + value: Bytes::from_static(b"xyz"), + } + .encode_into(&mut sink) + .await + .expect("instruction should encode"); + + assert_eq!(sink.bytes[0] & 0b1110_0000, 0b0110_0000); + assert_eq!(sink.bytes[3] & 0b1000_0000, 0); + + let decoded = EncoderInstruction::decode_from(Cursor::new(sink.bytes)) + .await + .expect("instruction should decode"); + assert_eq!( + decoded, + EncoderInstruction::InsertWithLiteralName { + name_huffman: true, + name: Bytes::from_static(b"abc"), + value_huffman: false, + value: Bytes::from_static(b"xyz"), + } + ); + } + + #[tokio::test] + async fn recording_sinks_exercise_async_write_and_sink_completion_paths() { + let mut byte_sink = RecordingSink::default(); + byte_sink + .write_all(b"async") + .await + .expect("async write should record bytes"); + AsyncWriteExt::flush(&mut byte_sink) + .await + .expect("async flush should succeed"); + byte_sink + .shutdown() + .await + .expect("async shutdown should succeed"); + byte_sink + .send(Bytes::from_static(b"-sink")) + .await + .expect("sink send should record bytes"); + byte_sink.close().await.expect("sink close should succeed"); + assert_eq!(byte_sink.bytes, b"async-sink"); + + let mut instruction_sink = RecordingInstructionSink::default(); + instruction_sink + .send(EncoderInstruction::Duplicate { index: 9 }) + .await + .expect("instruction sink send should succeed"); + instruction_sink + .close() + .await + .expect("instruction sink close should succeed"); + assert_eq!( + instruction_sink.instructions(), + vec![EncoderInstruction::Duplicate { index: 9 }] + ); + assert_eq!(instruction_sink.flushes(), 1); + } + + #[test] + fn evict_entry_removes_unique_name_and_value_indices() { + let mut state = EncoderState::new(settings_with_capacity(128)); + state + .set_max_table_capacity(128) + .expect("set capacity failed"); + let index = state + .insert_with_literal_name( + false, + Bytes::from_static(b"unique-name"), + false, + Bytes::from_static(b"unique-value"), + ) + .expect("insert failed"); + state.dynamic_table.known_received_count = state.table_inserted_count(); + + let (evicted_index, evicted_entry) = state.evict_entry(); + + assert_eq!(evicted_index, index); + assert_eq!( + evicted_entry, + FieldLine { + name: Bytes::from_static(b"unique-name"), + value: Bytes::from_static(b"unique-value"), + } + ); + assert!( + state + .find_name(&Bytes::from_static(b"unique-name")) + .is_none() + ); + assert!( + state + .find_value(&Bytes::from_static(b"unique-value")) + .is_none() + ); + } + + #[test] + fn encode_header_section_error_from_quic_stream_error_preserves_reset_code() { + let error = EncodeHeaderSectionError::from(quic::StreamError::Reset { + code: VarInt::from_u32(77), + }); + + match error { + EncodeHeaderSectionError::Stream { + source: StreamError::Reset { code }, + } => assert_eq!(code, VarInt::from_u32(77)), + error => panic!("unexpected encode header error: {error:?}"), + } + } #[test] fn on_insert_count_increment_zero_returns_err() { @@ -738,7 +1214,112 @@ mod tests { assert!(result.is_ok(), "Expected Ok, got {result:?}"); } - // --- Fix 3: panic→error for capacity --- + #[test] + fn on_insert_count_increment_overflow_returns_err() { + let mut state = EncoderState::new(settings_with_capacity(256)); + state + .set_max_table_capacity(128) + .expect("set capacity failed"); + state + .insert_with_literal_name( + false, + Bytes::from_static(b"x-custom"), + false, + Bytes::from_static(b"value"), + ) + .expect("insert failed"); + + let result = state.on_insert_count_increment(2); + + assert!( + matches!( + result, + Err(QPackDecoderStreamError::IncrementKnownReceivedCountOverflow) + ), + "Expected IncrementKnownReceivedCountOverflow, got {result:?}" + ); + } + + #[test] + fn on_section_acknowledgment_missing_stream_returns_err() { + let mut state = EncoderState::new(settings_with_capacity(256)); + + let result = state.on_section_acknowledgment(42); + + assert!( + matches!( + result, + Err(QPackDecoderStreamError::AcknowledgeNonExistSection { stream_id: 42 }) + ), + "Expected AcknowledgeNonExistSection, got {result:?}" + ); + } + + #[test] + fn on_section_acknowledgment_updates_known_received_count_and_clears_stream() { + let mut state = EncoderState::new(settings_with_capacity(256)); + state.dynamic_table.known_received_count = 1; + state.blocking_streams.insert(7, VecDeque::from([3])); + + state + .on_section_acknowledgment(7) + .expect("section acknowledgment should succeed"); + + assert_eq!(state.dynamic_table.known_received_count, 4); + assert!(!state.blocking_streams.contains_key(&7)); + } + + #[test] + fn on_section_acknowledgment_removes_only_acknowledged_section() { + let mut state = EncoderState::new(settings_with_capacity(256)); + state.blocking_streams.insert(7, VecDeque::from([1, 3])); + + state + .on_section_acknowledgment(7) + .expect("section acknowledgment should succeed"); + + assert_eq!(state.blocking_streams.get(&7), Some(&VecDeque::from([3]))); + assert_eq!(state.table_known_received_count(), 2); + } + + // --- Table management tests --- + + #[test] + fn full_unacknowledged_table_blocks_capacity_reference_and_duplicate_mutations() { + let mut shrink = state_with_full_unacknowledged_entry(); + assert!(matches!( + shrink.set_max_table_capacity(32), + Err(QPackEncoderError::CannotEvict) + )); + + let mut static_reference = state_with_full_unacknowledged_entry(); + assert!(matches!( + static_reference.insert_with_name_reference( + true, + 0, + false, + Bytes::from_static(b"value"), + ), + Err(QPackEncoderError::CannotEvict) + )); + + let mut dynamic_reference = state_with_full_unacknowledged_entry(); + assert!(matches!( + dynamic_reference.insert_with_name_reference( + false, + 0, + false, + Bytes::from_static(b"value"), + ), + Err(QPackEncoderError::CannotEvict) + )); + + let mut duplicate = state_with_full_unacknowledged_entry(); + assert!(matches!( + duplicate.duplicate(0), + Err(QPackEncoderError::CannotEvict) + )); + } #[test] fn set_max_table_capacity_exceeds_max_returns_err() { @@ -776,7 +1357,49 @@ mod tests { ); } - // --- Fix 4: eviction guard --- + #[test] + fn set_max_table_capacity_shrinks_by_evicting_acknowledged_entries() { + let mut state = EncoderState::new(settings_with_capacity(128)); + state + .set_max_table_capacity(128) + .expect("set capacity failed"); + state + .insert_with_literal_name( + false, + Bytes::from_static(b"header-1"), + false, + Bytes::from_static(b"value-1"), + ) + .expect("first insert failed"); + state + .insert_with_literal_name( + false, + Bytes::from_static(b"header-2"), + false, + Bytes::from_static(b"value-2"), + ) + .expect("second insert failed"); + state.dynamic_table.known_received_count = state.dynamic_table.inserted_count; + + state + .set_max_table_capacity(64) + .expect("shrink should evict acknowledged entries"); + + assert_eq!(state.table_capacity(), 64); + assert_eq!(state.table_dropped_count(), 1); + assert_eq!(state.table_inserted_count(), 2); + assert_eq!( + state + .entries() + .map(|(index, entry)| (index, entry.name.clone())) + .collect::>(), + vec![(1, Bytes::from_static(b"header-2"))] + ); + assert!(state.find_name(&Bytes::from_static(b"header-1")).is_none()); + assert!(state.find_value(&Bytes::from_static(b"value-1")).is_none()); + } + + // --- Eviction tests --- #[test] fn insert_with_literal_name_cannot_evict_when_unacknowledged() { @@ -810,7 +1433,500 @@ mod tests { ); } - // --- Fix 7: blocked streams — get_dynamic_references --- + #[test] + fn insert_with_literal_name_indexes_entry_and_emits_instruction() { + let mut state = EncoderState::new(settings_with_capacity(128)); + state + .set_max_table_capacity(128) + .expect("set capacity failed"); + + let index = state + .insert_with_literal_name( + true, + Bytes::from_static(b"x-name"), + false, + Bytes::from_static(b"value"), + ) + .expect("insert failed"); + + assert_eq!(index, 0); + assert_eq!( + state.get_entry(index), + Some(&FieldLine { + name: Bytes::from_static(b"x-name"), + value: Bytes::from_static(b"value"), + }) + ); + assert!( + state + .find_name(&Bytes::from_static(b"x-name")) + .unwrap() + .contains(&index) + ); + assert!( + state + .find_value(&Bytes::from_static(b"value")) + .unwrap() + .contains(&index) + ); + assert_eq!( + state.pending_instructions().back(), + Some(&EncoderInstruction::InsertWithLiteralName { + name_huffman: true, + name: Bytes::from_static(b"x-name"), + value_huffman: false, + value: Bytes::from_static(b"value"), + }) + ); + } + + #[test] + fn insert_with_name_reference_uses_static_name_and_emits_static_reference() { + let mut state = EncoderState::new(settings_with_capacity(128)); + state + .set_max_table_capacity(128) + .expect("set capacity failed"); + let static_name_index = 17; + let expected_name = Bytes::from_static( + r#static::get_name(static_name_index) + .expect("static name exists") + .as_bytes(), + ); + + let index = state + .insert_with_name_reference(true, static_name_index, true, Bytes::from_static(b"PATCH")) + .expect("static name reference insert failed"); + + assert_eq!( + state.get_entry(index), + Some(&FieldLine { + name: expected_name, + value: Bytes::from_static(b"PATCH"), + }) + ); + assert_eq!( + state.pending_instructions().back(), + Some(&EncoderInstruction::InsertWithNameReference { + is_static: true, + name_index: static_name_index, + huffman: true, + value: Bytes::from_static(b"PATCH"), + }) + ); + } + + #[test] + fn insert_with_name_reference_uses_dynamic_relative_index() { + let mut state = EncoderState::new(settings_with_capacity(256)); + state + .set_max_table_capacity(256) + .expect("set capacity failed"); + let referenced_index = state + .insert_with_literal_name( + false, + Bytes::from_static(b"x-name"), + false, + Bytes::from_static(b"first"), + ) + .expect("first insert failed"); + + let new_index = state + .insert_with_name_reference( + false, + referenced_index, + false, + Bytes::from_static(b"second"), + ) + .expect("name reference insert failed"); + + assert_eq!(new_index, 1); + assert_eq!( + state.pending_instructions().back(), + Some(&EncoderInstruction::InsertWithNameReference { + is_static: false, + name_index: 0, + huffman: false, + value: Bytes::from_static(b"second"), + }) + ); + } + + #[test] + fn insertion_evicts_acknowledged_entries_to_make_room() { + let mut state = EncoderState::new(settings_with_capacity(96)); + state + .set_max_table_capacity(64) + .expect("set capacity failed"); + state + .insert_with_literal_name( + false, + Bytes::from_static(b"old-name"), + false, + Bytes::from_static(b"old-value"), + ) + .expect("initial insert failed"); + state.dynamic_table.known_received_count = state.table_inserted_count(); + + let index = state + .insert_with_literal_name( + false, + Bytes::from_static(b"new-name"), + false, + Bytes::from_static(b"new-value"), + ) + .expect("acknowledged old entry should be evictable"); + + assert_eq!(index, 1); + assert_eq!(state.table_dropped_count(), 1); + assert!(state.find_name(&Bytes::from_static(b"old-name")).is_none()); + assert_eq!( + state + .entries() + .map(|(index, entry)| (index, entry.name.clone())) + .collect::>(), + vec![(1, Bytes::from_static(b"new-name"))] + ); + } + + #[test] + fn duplicate_uses_relative_index_and_preserves_entry() { + let mut state = EncoderState::new(settings_with_capacity(256)); + state + .set_max_table_capacity(256) + .expect("set capacity failed"); + let original_index = state + .insert_with_literal_name( + false, + Bytes::from_static(b"x-name"), + false, + Bytes::from_static(b"value"), + ) + .expect("insert failed"); + + let duplicated_index = state.duplicate(original_index).expect("duplicate failed"); + + assert_eq!(duplicated_index, 1); + assert_eq!( + state.pending_instructions().back(), + Some(&EncoderInstruction::Duplicate { index: 0 }) + ); + assert_eq!( + state.get_entry(duplicated_index), + Some(&FieldLine { + name: Bytes::from_static(b"x-name"), + value: Bytes::from_static(b"value"), + }) + ); + } + + #[test] + fn duplicate_evicts_acknowledged_entries_to_make_room() { + let mut state = EncoderState::new(settings_with_capacity(96)); + state + .set_max_table_capacity(64) + .expect("set capacity failed"); + let original_index = state + .insert_with_literal_name( + false, + Bytes::from_static(b"x-name"), + false, + Bytes::from_static(b"value"), + ) + .expect("insert failed"); + state.dynamic_table.known_received_count = state.table_inserted_count(); + + let duplicated_index = state + .duplicate(original_index) + .expect("acknowledged original entry should be evictable"); + + assert_eq!(duplicated_index, 1); + assert_eq!(state.table_dropped_count(), 1); + assert_eq!( + state + .find_name(&Bytes::from_static(b"x-name")) + .unwrap() + .iter() + .copied() + .collect::>(), + vec![duplicated_index] + ); + } + + #[test] + fn evict_entry_removes_only_the_evicted_duplicate_index() { + let mut state = EncoderState::new(settings_with_capacity(256)); + state + .set_max_table_capacity(256) + .expect("set capacity failed"); + let first_index = state + .insert_with_literal_name( + false, + Bytes::from_static(b"x-name"), + false, + Bytes::from_static(b"value"), + ) + .expect("insert failed"); + let duplicated_index = state.duplicate(first_index).expect("duplicate failed"); + state.dynamic_table.known_received_count = 1; + + let (evicted_index, evicted_entry) = state.evict_entry(); + + assert_eq!(evicted_index, first_index); + assert_eq!( + evicted_entry, + FieldLine { + name: Bytes::from_static(b"x-name"), + value: Bytes::from_static(b"value"), + } + ); + assert_eq!( + state + .find_name(&Bytes::from_static(b"x-name")) + .unwrap() + .len(), + 1 + ); + assert!( + state + .find_name(&Bytes::from_static(b"x-name")) + .unwrap() + .contains(&duplicated_index) + ); + assert_eq!( + state + .find_value(&Bytes::from_static(b"value")) + .unwrap() + .iter() + .copied() + .collect::>(), + vec![duplicated_index] + ); + } + + #[tokio::test] + async fn apply_settings_initializes_capacity_and_flushes_instruction() { + let encoder_sink = RecordingInstructionSink::default(); + let recorded = encoder_sink.clone(); + let encoder = Encoder::new( + settings_with_capacity(0), + encoder_sink, + stream::empty::>(), + ); + + encoder.apply_settings(settings_with_capacity(96)).await; + + assert_eq!(encoder.state.lock().await.table_capacity(), 96); + assert_eq!( + recorded.instructions(), + vec![EncoderInstruction::SetDynamicTableCapacity { capacity: 96 }] + ); + assert_eq!(recorded.flushes(), 1); + assert!(encoder.state.lock().await.pending_instructions().is_empty()); + } + + #[tokio::test] + async fn apply_settings_does_not_emit_capacity_when_peer_capacity_is_zero() { + let encoder_sink = RecordingInstructionSink::default(); + let recorded = encoder_sink.clone(); + let encoder = Encoder::new( + settings_with_capacity(0), + encoder_sink, + stream::empty::>(), + ); + + encoder.apply_settings(settings_with_capacity(0)).await; + + assert_eq!(encoder.state.lock().await.table_capacity(), 0); + assert!(recorded.instructions().is_empty()); + assert_eq!(recorded.flushes(), 1); + } + + #[tokio::test] + async fn flush_instructions_sends_all_pending_instructions_in_order_once() { + let encoder_sink = RecordingInstructionSink::default(); + let recorded = encoder_sink.clone(); + let encoder = Encoder::new( + settings_with_capacity(128), + encoder_sink, + stream::empty::>(), + ); + { + let mut state = encoder.state.lock().await; + state.emit(EncoderInstruction::SetDynamicTableCapacity { capacity: 64 }); + state.emit(EncoderInstruction::Duplicate { index: 3 }); + } + + encoder + .flush_instructions() + .await + .expect("flush should succeed"); + encoder + .flush_instructions() + .await + .expect("second flush should succeed without resending"); + + assert_eq!( + recorded.instructions(), + vec![ + EncoderInstruction::SetDynamicTableCapacity { capacity: 64 }, + EncoderInstruction::Duplicate { index: 3 }, + ] + ); + assert_eq!(recorded.flushes(), 2); + assert!(encoder.state.lock().await.pending_instructions().is_empty()); + } + + #[tokio::test] + async fn encode_tracks_blocking_sections_from_algorithm_output() { + let encoder = Encoder::new( + settings_with_capacity_and_blocked_streams(128, 1), + RecordingInstructionSink::default(), + stream::empty::>(), + ); + let algorithm = MayBlockProbeAlgorithm::with_max_referenced_index(4); + + encoder + .encode(Vec::new(), &algorithm, TestStreamId(11)) + .await + .expect("header section should encode"); + + assert_eq!(algorithm.observed_may_block(), vec![true]); + assert_eq!( + encoder.state.lock().await.blocking_streams.get(&11), + Some(&VecDeque::from([4])) + ); + } + + #[tokio::test] + async fn encode_disallows_new_blocking_streams_after_peer_limit() { + let encoder = Encoder::new( + settings_with_capacity_and_blocked_streams(128, 1), + RecordingInstructionSink::default(), + stream::empty::>(), + ); + { + let mut state = encoder.state.lock().await; + state.blocking_streams.insert(1, VecDeque::from([0])); + } + let algorithm = MayBlockProbeAlgorithm::default(); + + encoder + .encode(Vec::new(), &algorithm, TestStreamId(2)) + .await + .expect("header section should encode"); + encoder + .encode(Vec::new(), &algorithm, TestStreamId(1)) + .await + .expect("header section should encode"); + + assert_eq!(algorithm.observed_may_block(), vec![false, true]); + } + + #[tokio::test] + async fn receive_instruction_applies_decoder_stream_state_changes() { + let encoder = Encoder::new( + settings_with_capacity(128), + RecordingInstructionSink::default(), + stream::iter([ + Ok(DecoderInstruction::SectionAcknowledgment { stream_id: 7 }), + Ok(DecoderInstruction::StreamCancellation { stream_id: 8 }), + Ok(DecoderInstruction::InsertCountIncrement { increment: 1 }), + ]), + ); + { + let mut state = encoder.state.lock().await; + state + .set_max_table_capacity(128) + .expect("set capacity failed"); + state + .insert_with_literal_name( + false, + Bytes::from_static(b"x-name"), + false, + Bytes::from_static(b"value"), + ) + .expect("insert failed"); + state + .insert_with_literal_name( + false, + Bytes::from_static(b"x-name-2"), + false, + Bytes::from_static(b"value-2"), + ) + .expect("insert failed"); + state.blocking_streams.insert(7, VecDeque::from([0])); + state.blocking_streams.insert(8, VecDeque::from([0])); + } + + encoder + .receive_instruction() + .await + .expect("section acknowledgment should be accepted"); + encoder + .receive_instruction() + .await + .expect("stream cancellation should be accepted"); + encoder + .receive_instruction() + .await + .expect("insert count increment should be accepted"); + + let state = encoder.state.lock().await; + assert!(!state.blocking_streams.contains_key(&7)); + assert!(!state.blocking_streams.contains_key(&8)); + assert_eq!(state.table_known_received_count(), 2); + } + + #[tokio::test] + async fn receive_instruction_reports_closed_decoder_stream() { + let encoder = Encoder::new( + settings_with_capacity(128), + RecordingInstructionSink::default(), + stream::empty::>(), + ); + + let result = encoder.receive_instruction().await; + + assert_connection_h3_code( + result.expect_err("closed decoder stream must be reported"), + Code::H3_CLOSED_CRITICAL_STREAM, + ); + } + + #[tokio::test] + async fn receive_instruction_reports_decoder_instruction_state_errors() { + let missing_section = Encoder::new( + settings_with_capacity(128), + RecordingInstructionSink::default(), + stream::iter([Ok(DecoderInstruction::SectionAcknowledgment { + stream_id: 42, + })]), + ); + assert_connection_h3_code( + missing_section + .receive_instruction() + .await + .expect_err("missing section acknowledgment must be rejected"), + Code::QPACK_DECODER_STREAM_ERROR, + ); + + let zero_increment = Encoder::new( + settings_with_capacity(128), + RecordingInstructionSink::default(), + stream::iter([Ok(DecoderInstruction::InsertCountIncrement { + increment: 0, + })]), + ); + assert_connection_h3_code( + zero_increment + .receive_instruction() + .await + .expect_err("zero insert count increment must be rejected"), + Code::QPACK_DECODER_STREAM_ERROR, + ); + } + + // --- Edge case tests --- #[test] fn get_dynamic_references_finds_dynamic_indexed_field() { @@ -845,6 +1961,20 @@ mod tests { assert_eq!(refs, vec![7]); } + #[test] + fn get_dynamic_references_finds_post_base_name_reference() { + let reprs = vec![ + FieldLineRepresentation::LiteralFieldLineWithPostBaseNameReference { + never_dynamic: false, + name_index: 11, + huffman: false, + value: Bytes::from_static(b"val"), + }, + ]; + let refs: Vec = get_dynamic_references(&reprs).collect(); + assert_eq!(refs, vec![11]); + } + #[test] fn get_dynamic_references_empty_for_literal_name() { let reprs = vec![FieldLineRepresentation::LiteralFieldLineWithLiteralName { diff --git a/src/qpack/field.rs b/src/qpack/field.rs index 560530e..56f6b34 100644 --- a/src/qpack/field.rs +++ b/src/qpack/field.rs @@ -13,7 +13,7 @@ mod section; pub use section::{FieldSection, Iter, MalformedHeaderSection, malformed_header_section}; #[cfg(feature = "hyper")] -pub(crate) mod hyper; +pub mod hyper; #[derive(Debug, Clone, PartialEq, Eq)] pub struct Protocol { @@ -149,3 +149,89 @@ impl From for FieldLine { } } } + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use http::{ + Method, StatusCode, + header::{HeaderName, HeaderValue}, + uri::{Authority, PathAndQuery, Scheme}, + }; + + use super::{FieldLine, Protocol, PseudoHeaders}; + + #[test] + fn protocol_validates_utf8_and_exposes_original_token() { + let protocol = Protocol::try_from(Bytes::from_static(b"webtransport")) + .expect("ASCII protocol token should be valid"); + + assert_eq!(protocol.as_bytes(), b"webtransport"); + assert_eq!(protocol.as_str(), "webtransport"); + assert_eq!(AsRef::<[u8]>::as_ref(&protocol), b"webtransport"); + assert_eq!(AsRef::::as_ref(&protocol), "webtransport"); + + let error = Protocol::try_from(Bytes::from_static(b"\xff")).expect_err("invalid UTF-8"); + assert_eq!(error.valid_up_to(), 0); + } + + #[test] + fn field_line_from_header_parts_preserves_name_value_and_size() { + let field = FieldLine::from(( + HeaderName::from_static("content-type"), + HeaderValue::from_static("application/dhttp"), + )); + + assert_eq!(field.name, Bytes::from_static(b"content-type")); + assert_eq!(field.value, Bytes::from_static(b"application/dhttp")); + assert_eq!( + field.size(), + "content-type".len() as u64 + "application/dhttp".len() as u64 + 32 + ); + } + + #[test] + fn field_line_from_pseudo_source_types_uses_expected_names_and_values() { + let method = FieldLine::from(Method::POST); + assert_eq!( + method.name, + Bytes::from_static(PseudoHeaders::METHOD.as_bytes()) + ); + assert_eq!(method.value, Bytes::from_static(b"POST")); + + let scheme = FieldLine::from(Scheme::HTTPS); + assert_eq!( + scheme.name, + Bytes::from_static(PseudoHeaders::SCHEME.as_bytes()) + ); + assert_eq!(scheme.value, Bytes::from_static(b"https")); + + let authority = FieldLine::from(Authority::from_static("example.com:443")); + assert_eq!( + authority.name, + Bytes::from_static(PseudoHeaders::AUTHORITY.as_bytes()) + ); + assert_eq!(authority.value, Bytes::from_static(b"example.com:443")); + + let path = FieldLine::from(PathAndQuery::from_static("/resource?q=1")); + assert_eq!( + path.name, + Bytes::from_static(PseudoHeaders::PATH.as_bytes()) + ); + assert_eq!(path.value, Bytes::from_static(b"/resource?q=1")); + + let protocol = FieldLine::from(Protocol::new("webtransport")); + assert_eq!( + protocol.name, + Bytes::from_static(PseudoHeaders::PROTOOCL.as_bytes()) + ); + assert_eq!(protocol.value, Bytes::from_static(b"webtransport")); + + let status = FieldLine::from(StatusCode::CREATED); + assert_eq!( + status.name, + Bytes::from_static(PseudoHeaders::STATUS.as_bytes()) + ); + assert_eq!(status.value, Bytes::from_static(b"201")); + } +} diff --git a/src/qpack/field/hyper.rs b/src/qpack/field/hyper.rs index 6eda40e..35c4ded 100644 --- a/src/qpack/field/hyper.rs +++ b/src/qpack/field/hyper.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use bytes::Bytes; use http::{HeaderName, Request, Response, Uri, Version, request, response}; use snafu::OptionExt; @@ -16,7 +18,9 @@ impl From for FieldLine { } } -pub fn header_map_to_field_lines(headers: http::HeaderMap) -> impl Iterator { +pub(crate) fn header_map_to_field_lines( + headers: http::HeaderMap, +) -> impl Iterator { headers .into_iter() .scan(None::, |last_name, (name, value)| { @@ -39,36 +43,15 @@ pub fn header_map_to_field_lines(headers: http::HeaderMap) -> impl Iterator impl Iterator { - let uri_parts = parts.uri.into_parts(); - let pseudo_headers = [ - Some(parts.method.into()), - uri_parts.scheme.map(FieldLine::from), - uri_parts.authority.map(FieldLine::from), - uri_parts.path_and_query.map(FieldLine::from), - ]; - - let protocol = parts - .extensions - .remove::() - .map(FieldLine::from) - .or_else(|| { - parts - .extensions - .remove::<::hyper::ext::Protocol>() - .map(FieldLine::from) - }); - - pseudo_headers - .into_iter() - .flatten() - .chain(protocol) - .chain(header_map_to_field_lines(parts.headers)) +pub fn validated_hyper_request_parts_to_field_lines( + parts: http::request::Parts, +) -> Result, MalformedHeaderSection> { + let section = FieldSection::from(parts); + section.check_pseudo()?; + Ok(section.iter().collect()) } -pub fn hyper_response_parts_to_field_lines( +pub(crate) fn hyper_response_parts_to_field_lines( parts: http::response::Parts, ) -> impl Iterator { let pseudo_headers = [Some(FieldLine::from(parts.status))]; @@ -96,7 +79,7 @@ impl From for FieldSection { Self { pseudo_headers: Some(pseudo), - header_map: request.headers, + header_map: Arc::new(request.headers), } } } @@ -105,14 +88,19 @@ impl TryFrom for request::Parts { type Error = MalformedHeaderSection; fn try_from(value: FieldSection) -> Result { + value.check_pseudo()?; + + let FieldSection { + pseudo_headers, + header_map, + } = value; let PseudoHeaders::Request { method, scheme, authority, path, protocol, - } = value - .pseudo_headers + } = pseudo_headers .context(malformed_header_section::AbsenceOfMandatoryPseudoHeadersSnafu)? else { return Err(MalformedHeaderSection::ResponsePseudoHeaderInRequest); @@ -137,7 +125,8 @@ impl TryFrom for request::Parts { .method(method) .version(Version::HTTP_3) .body(())?; - *request.headers_mut() = value.header_map; + *request.headers_mut() = + Arc::try_unwrap(header_map).unwrap_or_else(|header_map| (*header_map).clone()); if let Some(protocol) = protocol { request.extensions_mut().insert(protocol); @@ -151,7 +140,7 @@ impl From for FieldSection { fn from(response: response::Parts) -> Self { Self { pseudo_headers: Some(PseudoHeaders::response(response.status)), - header_map: response.headers, + header_map: Arc::new(response.headers), } } } @@ -160,8 +149,11 @@ impl TryFrom for response::Parts { type Error = MalformedHeaderSection; fn try_from(value: FieldSection) -> Result { - let PseudoHeaders::Response { status } = value - .pseudo_headers + let FieldSection { + pseudo_headers, + header_map, + } = value; + let PseudoHeaders::Response { status } = pseudo_headers .context(malformed_header_section::AbsenceOfMandatoryPseudoHeadersSnafu)? else { return Err(MalformedHeaderSection::RequestPseudoHeaderInResponse); @@ -173,7 +165,329 @@ impl TryFrom for response::Parts { .status(status) .version(Version::HTTP_3) .body(())?; - *response.headers_mut() = value.header_map; + *response.headers_mut() = + Arc::try_unwrap(header_map).unwrap_or_else(|header_map| (*header_map).clone()); Ok(response.into_parts().0) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn hyper_protocol_and_repeated_headers_convert_to_field_lines() { + let protocol_line = FieldLine::from(::hyper::ext::Protocol::from_static("websocket")); + assert_eq!(protocol_line.name, Bytes::from_static(b":protocol")); + assert_eq!(protocol_line.value, Bytes::from_static(b"websocket")); + + let mut headers = http::HeaderMap::new(); + headers.append("x-repeat", http::HeaderValue::from_static("one")); + headers.append("x-repeat", http::HeaderValue::from_static("two")); + headers.insert("x-single", http::HeaderValue::from_static("solo")); + + let lines = header_map_to_field_lines(headers).collect::>(); + + assert_eq!( + lines + .iter() + .filter(|line| line.name == Bytes::from_static(b"x-repeat")) + .map(|line| line.value.clone()) + .collect::>(), + vec![Bytes::from_static(b"one"), Bytes::from_static(b"two")] + ); + assert!(lines.iter().any(|line| { + line.name == Bytes::from_static(b"x-single") + && line.value == Bytes::from_static(b"solo") + })); + } + + #[test] + fn request_parts_conversion_preserves_hyper_protocol_extension() { + let mut request = http::Request::builder() + .method(http::Method::CONNECT) + .uri("https://example.test/session") + .body(()) + .unwrap(); + request + .extensions_mut() + .insert(::hyper::ext::Protocol::from_static("websocket")); + let (parts, ()) = request.into_parts(); + + let section = FieldSection::from(parts); + + let PseudoHeaders::Request { + protocol: Some(protocol), + .. + } = section.pseudo_headers.expect("request pseudo headers") + else { + panic!("expected request protocol pseudo header"); + }; + assert_eq!(protocol, Protocol::new("websocket")); + } + + #[test] + fn request_parts_try_from_preserves_authority_protocol_and_headers() { + let mut headers = http::HeaderMap::new(); + headers.insert("x-test", http::HeaderValue::from_static("ok")); + let section = FieldSection::header( + PseudoHeaders::Request { + method: Some(http::Method::CONNECT), + scheme: Some(http::uri::Scheme::HTTPS), + authority: Some("example.test".parse().unwrap()), + path: Some("/session".parse().unwrap()), + protocol: Some(Protocol::new("webtransport")), + }, + headers, + ); + + let parts = http::request::Parts::try_from(section).unwrap(); + + assert_eq!(parts.method, http::Method::CONNECT); + assert_eq!(parts.uri, "https://example.test/session"); + assert_eq!(parts.version, http::Version::HTTP_3); + assert_eq!(parts.headers.get("x-test").unwrap(), "ok"); + assert_eq!( + parts.extensions.get::(), + Some(&Protocol::new("webtransport")) + ); + } + + #[test] + fn request_parts_try_from_clones_shared_header_map_and_rejects_trailers() { + let mut headers = http::HeaderMap::new(); + headers.insert("x-shared", http::HeaderValue::from_static("ok")); + let section = FieldSection::header( + PseudoHeaders::Request { + method: Some(http::Method::GET), + scheme: Some(http::uri::Scheme::HTTPS), + authority: Some("example.test".parse().unwrap()), + path: Some("/".parse().unwrap()), + protocol: None, + }, + headers, + ); + let _shared = section.clone(); + + let parts = http::request::Parts::try_from(section).unwrap(); + + assert_eq!(parts.headers.get("x-shared").unwrap(), "ok"); + + let error = http::request::Parts::try_from(FieldSection::trailer(http::HeaderMap::new())) + .unwrap_err(); + assert!(matches!( + error, + MalformedHeaderSection::AbsenceOfMandatoryPseudoHeaders { .. } + )); + } + + #[test] + fn request_parts_reject_response_pseudo_headers() { + let section = FieldSection::header( + PseudoHeaders::response(http::StatusCode::OK), + http::HeaderMap::new(), + ); + + let error = http::request::Parts::try_from(section).unwrap_err(); + + assert!(matches!( + error, + MalformedHeaderSection::ResponsePseudoHeaderInRequest + )); + } + + #[test] + fn response_parts_roundtrip_preserves_status_headers_and_rejects_request_pseudo_headers() { + let response = http::Response::builder() + .status(http::StatusCode::CREATED) + .header("x-test", "ok") + .body(()) + .unwrap(); + let (parts, ()) = response.into_parts(); + + let section = FieldSection::from(parts); + let parts = http::response::Parts::try_from(section).unwrap(); + + assert_eq!(parts.status, http::StatusCode::CREATED); + assert_eq!(parts.version, http::Version::HTTP_3); + assert_eq!(parts.headers.get("x-test").unwrap(), "ok"); + + let request_section = FieldSection::header( + PseudoHeaders::Request { + method: Some(http::Method::GET), + scheme: Some(http::uri::Scheme::HTTPS), + authority: Some("example.test".parse().unwrap()), + path: Some("/".parse().unwrap()), + protocol: None, + }, + http::HeaderMap::new(), + ); + + let error = http::response::Parts::try_from(request_section).unwrap_err(); + + assert!(matches!( + error, + MalformedHeaderSection::RequestPseudoHeaderInResponse + )); + } + + #[test] + fn response_parts_try_from_clones_shared_header_map_and_rejects_trailers() { + let mut headers = http::HeaderMap::new(); + headers.insert("x-shared", http::HeaderValue::from_static("ok")); + let section = FieldSection::header(PseudoHeaders::response(http::StatusCode::OK), headers); + let _shared = section.clone(); + + let parts = http::response::Parts::try_from(section).unwrap(); + + assert_eq!(parts.headers.get("x-shared").unwrap(), "ok"); + + let error = http::response::Parts::try_from(FieldSection::trailer(http::HeaderMap::new())) + .unwrap_err(); + assert!(matches!( + error, + MalformedHeaderSection::AbsenceOfMandatoryPseudoHeaders { .. } + )); + } + + #[test] + fn validated_request_parts_success_returns_pseudo_and_regular_fields() { + let request = http::Request::builder() + .method(http::Method::GET) + .uri("https://example.test/") + .header("x-test", "ok") + .body(()) + .unwrap(); + let (parts, ()) = request.into_parts(); + + let lines = validated_hyper_request_parts_to_field_lines(parts).unwrap(); + + assert!(lines.iter().any(|line| { + line.name == Bytes::from_static(b":method") && line.value == Bytes::from_static(b"GET") + })); + assert!(lines.iter().any(|line| { + line.name == Bytes::from_static(b"x-test") && line.value == Bytes::from_static(b"ok") + })); + } + + #[test] + fn hyper_response_parts_to_field_lines_emits_status_and_headers() { + let response = http::Response::builder() + .status(http::StatusCode::NO_CONTENT) + .header("x-test", "ok") + .body(()) + .unwrap(); + let (parts, ()) = response.into_parts(); + + let lines = hyper_response_parts_to_field_lines(parts).collect::>(); + + assert_eq!( + lines[0], + FieldLine { + name: Bytes::from_static(b":status"), + value: Bytes::from_static(b"204"), + } + ); + assert!(lines.iter().any(|line| { + line.name == Bytes::from_static(b"x-test") && line.value == Bytes::from_static(b"ok") + })); + } + + #[test] + fn request_parts_rejects_authority_only_get() { + let section = FieldSection::header( + PseudoHeaders::Request { + method: Some(http::Method::GET), + scheme: None, + authority: Some("reimu.pilot.dhttp.net".parse().unwrap()), + path: None, + protocol: None, + }, + http::HeaderMap::new(), + ); + + let error = http::request::Parts::try_from(section).unwrap_err(); + + assert!(matches!( + error, + MalformedHeaderSection::AbsenceOfMandatoryPseudoHeaders { .. } + )); + } + + #[test] + fn request_parts_allows_connect_authority_form() { + let section = FieldSection::header( + PseudoHeaders::Request { + method: Some(http::Method::CONNECT), + scheme: None, + authority: Some("example.com:443".parse().unwrap()), + path: None, + protocol: None, + }, + http::HeaderMap::new(), + ); + + let parts = http::request::Parts::try_from(section).unwrap(); + + assert_eq!(parts.method, http::Method::CONNECT); + assert_eq!(parts.uri.authority().unwrap().as_str(), "example.com:443"); + } + + #[test] + fn validated_request_parts_rejects_authority_only_get() { + let request = http::Request::builder() + .method(http::Method::GET) + .uri("reimu.pilot.dhttp.net") + .body(()) + .unwrap(); + let (parts, ()) = request.into_parts(); + + let error = validated_hyper_request_parts_to_field_lines(parts).unwrap_err(); + + assert!(matches!( + error, + MalformedHeaderSection::AbsenceOfMandatoryPseudoHeaders { .. } + )); + } + + #[test] + fn request_parts_rejects_protocol_on_non_connect() { + let section = FieldSection::header( + PseudoHeaders::Request { + method: Some(http::Method::GET), + scheme: Some(http::uri::Scheme::HTTPS), + authority: Some("reimu.pilot.dhttp.net".parse().unwrap()), + path: Some("/".parse().unwrap()), + protocol: Some(Protocol::new("webtransport")), + }, + http::HeaderMap::new(), + ); + + let error = http::request::Parts::try_from(section).unwrap_err(); + + assert!(matches!( + error, + MalformedHeaderSection::ProtocolInNonConnectRequest + )); + } + + #[test] + fn validated_request_parts_rejects_protocol_on_get() { + let mut request = http::Request::builder() + .method(http::Method::GET) + .uri("https://reimu.pilot.dhttp.net/") + .body(()) + .unwrap(); + request + .extensions_mut() + .insert(Protocol::new("webtransport")); + let (parts, ()) = request.into_parts(); + + let error = validated_hyper_request_parts_to_field_lines(parts).unwrap_err(); + + assert!(matches!( + error, + MalformedHeaderSection::ProtocolInNonConnectRequest + )); + } +} diff --git a/src/qpack/field/pseudo.rs b/src/qpack/field/pseudo.rs index a0c0304..4cf7814 100644 --- a/src/qpack/field/pseudo.rs +++ b/src/qpack/field/pseudo.rs @@ -19,6 +19,33 @@ pub enum PseudoHeaders { Response { status: Option }, } +fn has_missing_path_component(path_and_query: &PathAndQuery) -> bool { + path_and_query.as_str().is_empty() || path_and_query.as_str().starts_with('?') +} + +fn normalize_http_path(path_and_query: PathAndQuery) -> PathAndQuery { + if path_and_query.as_str().is_empty() { + return PathAndQuery::from_static("/"); + } + + if path_and_query.as_str().starts_with('?') { + let mut normalized = String::with_capacity(path_and_query.as_str().len() + 1); + normalized.push('/'); + normalized.push_str(path_and_query.as_str()); + return PathAndQuery::try_from(normalized) + .expect("path-and-query with a prefixed slash remains valid"); + } + + path_and_query +} + +pub(super) fn asterisk_path() -> PathAndQuery { + Uri::from_static("*") + .into_parts() + .path_and_query + .expect("asterisk URI carries a path-and-query") +} + impl PseudoHeaders { pub const METHOD: &str = ":method"; pub const SCHEME: &str = ":scheme"; @@ -30,20 +57,22 @@ impl PseudoHeaders { pub fn request(method: Method, uri: Uri) -> Self { let uri = uri.into_parts(); + let is_http = uri.scheme == Some(Scheme::HTTP) || uri.scheme == Some(Scheme::HTTPS); + // RFC 9114 Section 4.3.1 defines :path as path-absolute plus an + // optional query. For http/https URIs with no path component, use "/" + // (or "/?query" for query-only targets). OPTIONS requests with no path + // component use "*". let path = match uri.path_and_query { - Some(path_and_query) => Some(path_and_query), - // This pseudo-header field MUST NOT be empty for "http" or "https" - // URIs; "http" or "https" URIs that do not contain a path component - // MUST include a value of / (ASCII 0x2f). An OPTIONS request that - // does not include a path component includes the value * (ASCII - // 0x2a) for the :path pseudo-header field; see Section 7.1 of - // [HTTP]. - // - // https://datatracker.ietf.org/doc/html/rfc9114#section-4.3.1-2.16.1 - None if uri.scheme == Some(Scheme::HTTP) || uri.scheme == Some(Scheme::HTTPS) => { - Some(PathAndQuery::from_static("/")) + Some(path_and_query) + if method == http::Method::OPTIONS + && has_missing_path_component(&path_and_query) => + { + Some(asterisk_path()) } - None if method == http::Method::OPTIONS => Some(PathAndQuery::from_static("/")), + Some(path_and_query) if is_http => Some(normalize_http_path(path_and_query)), + Some(path_and_query) => Some(path_and_query), + None if method == http::Method::OPTIONS => Some(asterisk_path()), + None if is_http => Some(PathAndQuery::from_static("/")), None => None, }; PseudoHeaders::Request { @@ -94,3 +123,209 @@ impl PseudoHeaders { } } } + +#[cfg(test)] +mod tests { + use http::{ + Method, StatusCode, Uri, + uri::{PathAndQuery, Scheme}, + }; + + use super::{PseudoHeaders, asterisk_path, has_missing_path_component}; + + #[test] + fn pseudo_header_names_match_http3_wire_names() { + assert_eq!(PseudoHeaders::METHOD, ":method"); + assert_eq!(PseudoHeaders::SCHEME, ":scheme"); + assert_eq!(PseudoHeaders::AUTHORITY, ":authority"); + assert_eq!(PseudoHeaders::PATH, ":path"); + assert_eq!(PseudoHeaders::PROTOOCL, ":protocol"); + assert_eq!(PseudoHeaders::STATUS, ":status"); + } + + #[test] + fn path_helpers_identify_query_only_and_asterisk_forms() { + let query_only = PathAndQuery::from_static("?x=1"); + let absolute = PathAndQuery::from_static("/x?y=1"); + + assert!(has_missing_path_component(&query_only)); + assert!(!has_missing_path_component(&absolute)); + assert_eq!(asterisk_path().as_str(), "*"); + } + + #[test] + fn request_extracts_pseudo_headers_from_absolute_uri() { + let pseudo = PseudoHeaders::request( + Method::GET, + Uri::from_static("https://example.com:443/resource?q=1"), + ); + + assert_eq!( + pseudo, + PseudoHeaders::Request { + method: Some(Method::GET), + scheme: Some(http::uri::Scheme::HTTPS), + authority: Some(http::uri::Authority::from_static("example.com:443")), + path: Some(PathAndQuery::from_static("/resource?q=1")), + protocol: None, + } + ); + assert!(!pseudo.is_empty()); + } + + #[test] + fn request_defaults_http_and_https_missing_path_to_slash() { + let cases = [ + ( + Uri::from_static("https://example.com"), + Scheme::HTTPS, + "example.com", + ), + ( + Uri::from_static("http://example.com"), + Scheme::HTTP, + "example.com", + ), + ]; + + for (uri, expected_scheme, expected_authority) in cases { + assert!(matches!( + PseudoHeaders::request(Method::GET, uri), + PseudoHeaders::Request { + method: Some(Method::GET), + scheme: Some(ref scheme), + authority: Some(ref authority), + path: Some(ref path), + protocol: None, + } + if scheme == &expected_scheme + && authority.as_str() == expected_authority + && path.as_str() == "/" + )); + } + } + + #[test] + fn request_normalizes_query_only_http_and_https_uri_to_absolute_path() { + let cases = [ + ( + Uri::from_static("https://example.com?x=1"), + Scheme::HTTPS, + "example.com", + ), + ( + Uri::from_static("http://example.com?x=1"), + Scheme::HTTP, + "example.com", + ), + ]; + + for (uri, expected_scheme, expected_authority) in cases { + assert!(matches!( + PseudoHeaders::request(Method::GET, uri), + PseudoHeaders::Request { + method: Some(Method::GET), + scheme: Some(ref scheme), + authority: Some(ref authority), + path: Some(ref path), + protocol: None, + } + if scheme == &expected_scheme + && authority.as_str() == expected_authority + && path.as_str() == "/?x=1" + )); + } + } + + #[test] + fn options_asterisk_form_and_authority_form_use_asterisk_but_explicit_paths_are_preserved() { + let asterisk_form = PseudoHeaders::request(Method::OPTIONS, Uri::from_static("*")); + let authority_form = + PseudoHeaders::request(Method::OPTIONS, Uri::from_static("example.com")); + let query_only = + PseudoHeaders::request(Method::OPTIONS, Uri::from_static("https://example.com?x=1")); + let explicit_slash = + PseudoHeaders::request(Method::OPTIONS, Uri::from_static("https://example.com/")); + let explicit_slash_with_query = PseudoHeaders::request( + Method::OPTIONS, + Uri::from_static("https://example.com/?x=1"), + ); + + assert!(matches!( + asterisk_form, + PseudoHeaders::Request { + method: Some(Method::OPTIONS), + scheme: None, + authority: None, + path: Some(ref path), + protocol: None, + } + if path.as_str() == "*" + )); + assert!(matches!( + authority_form, + PseudoHeaders::Request { + method: Some(Method::OPTIONS), + scheme: None, + authority: Some(ref authority), + path: Some(ref path), + protocol: None, + } + if authority.as_str() == "example.com" && path.as_str() == "*" + )); + assert!(matches!( + query_only, + PseudoHeaders::Request { + method: Some(Method::OPTIONS), + scheme: Some(ref scheme), + authority: Some(ref authority), + path: Some(ref path), + protocol: None, + } + if scheme == &Scheme::HTTPS + && authority.as_str() == "example.com" + && path.as_str() == "*" + )); + assert!(matches!( + explicit_slash, + PseudoHeaders::Request { + method: Some(Method::OPTIONS), + scheme: Some(ref scheme), + authority: Some(ref authority), + path: Some(ref path), + protocol: None, + } + if scheme == &Scheme::HTTPS + && authority.as_str() == "example.com" + && path.as_str() == "/" + )); + assert!(matches!( + explicit_slash_with_query, + PseudoHeaders::Request { + method: Some(Method::OPTIONS), + scheme: Some(ref scheme), + authority: Some(ref authority), + path: Some(ref path), + protocol: None, + } + if scheme == &Scheme::HTTPS + && authority.as_str() == "example.com" + && path.as_str() == "/?x=1" + )); + } + + #[test] + fn response_and_unresolved_headers_report_empty_state() { + let response = PseudoHeaders::response(StatusCode::NO_CONTENT); + assert_eq!( + response, + PseudoHeaders::Response { + status: Some(StatusCode::NO_CONTENT), + } + ); + assert!(!response.is_empty()); + + assert!(PseudoHeaders::unresolved_request().is_empty()); + assert!(PseudoHeaders::unresolved_response().is_empty()); + } +} diff --git a/src/qpack/field/repr.rs b/src/qpack/field/repr.rs index cb9213d..601ae89 100644 --- a/src/qpack/field/repr.rs +++ b/src/qpack/field/repr.rs @@ -437,56 +437,128 @@ impl EncodeInto for (EncodedFieldSectionPrefix, Vec Vec { + let mut buffer = Vec::new(); + let mut writer = Cursor::new(&mut buffer); + prefix + .encode_into(&mut writer) + .await + .unwrap_or_else(|error| { + panic!("prefix encode_into failed: {error:?}"); + }); + buffer + } + + fn assert_h3_error(error: StreamError) { + match error { + StreamError::H3 { source } => { + assert_eq!(source.code(), Code::H3_FRAME_ERROR); + } + StreamError::Connection { .. } | StreamError::Reset { .. } => { + // Keep these as fail-fast for malformed streams while allowing + // transport/feature-gated differences across test configurations. + } + }; + } + + async fn roundtrip_field_line(original: FieldLineRepresentation) { + let mut encoded = BufList::new(); + original + .clone() + .encode_into(&mut encoded) + .await + .unwrap_or_else(|error| { + panic!("field line encode_into failed: {error:?}"); + }); + + let encoded = encoded.copy_to_bytes(encoded.remaining()); + let mut cursor = Cursor::new(encoded.as_ref()); + let decoded = FieldLineRepresentation::decode_from(&mut cursor) + .await + .unwrap_or_else(|error| { + panic!("field line decode_from failed: {error:?}"); + }); + assert_eq!(decoded, original); + } + + async fn encode_field_line_to_bytes(field_line: FieldLineRepresentation) -> Vec { + let mut encoded = BufList::new(); + field_line + .encode_into(&mut encoded) + .await + .unwrap_or_else(|error| { + panic!("field line encode_into failed: {error:?}"); + }); + encoded.copy_to_bytes(encoded.remaining()).to_vec() + } + + // --- Field section prefix tests --- #[test] fn test_encode_ric_zero() { - // RFC 9204 §4.5.1.1: RIC == 0 encodes as 0 assert_eq!(EncodedFieldSectionPrefix::encode_ric(0, 256), 0); } #[test] fn test_encode_ric_nonzero() { - // max_table_capacity=256, MaxEntries=256/32=8, FullRange=16 - // encode_ric(4, 256) = (4 % 16) + 1 = 5 assert_eq!(EncodedFieldSectionPrefix::encode_ric(4, 256), 5); } + #[test] + fn test_encode_ric_disabled_table() { + // max_table_capacity/32 == 0 implies dynamic table disabled. + // encode_ric should return 1 for non-zero input in this edge case. + assert_eq!(EncodedFieldSectionPrefix::encode_ric(1, 8), 1); + assert_eq!(EncodedFieldSectionPrefix::encode_ric(7, 31), 1); + } + #[test] fn test_decode_ric_zero() { - // encoded_insert_count == 0 → RIC == 0 (no dynamic references) let result = EncodedFieldSectionPrefix::decode_ric(0, 256, 10); assert_eq!(result, Ok(0)); } + #[test] + fn test_decode_ric_disabled_table() { + // Non-zero encoded value is invalid when max_table_capacity/32 == 0. + let result = EncodedFieldSectionPrefix::decode_ric(1, 8, 10); + assert_eq!(result, Err(DecodeError::DecompressionFailed)); + } + #[test] fn test_decode_ric_nonzero() { - // decode_ric(5, 256, 10) should return 4 (reverse of encode_ric(4, 256) == 5) - // MaxEntries=8, FullRange=16, max_value=10+8=18, max_wrapped=(18/16)*16=16 - // ric = 16 + 5 - 1 = 20 > 18 → ric -= 16 → ric = 4 let result = EncodedFieldSectionPrefix::decode_ric(5, 256, 10); assert_eq!(result, Ok(4)); } #[test] fn test_ric_roundtrip() { - // For RIC values 1, 4, 8, 15: encode then decode should recover original - // total_inserts must satisfy: ric <= total_inserts < ric + MaxEntries - // Use total_inserts = ric so constraint is met for all test values let max_table_capacity = 256; for ric in [1u64, 4, 8, 15] { - let total_inserts = ric; // ric <= total_inserts < ric + 8 satisfied + let total_inserts = ric; let encoded = EncodedFieldSectionPrefix::encode_ric(ric, max_table_capacity); let decoded = EncodedFieldSectionPrefix::decode_ric( encoded, max_table_capacity, total_inserts, ) - .unwrap_or_else(|e| { - panic!("decode_ric({encoded}, {max_table_capacity}, {total_inserts}) failed: {e:?}") + .unwrap_or_else(|error| { + panic!( + "decode_ric({encoded}, {max_table_capacity}, {total_inserts}) failed: {error:?}" + ) }); assert_eq!(decoded, ric, "roundtrip failed for ric={ric}"); } @@ -494,33 +566,484 @@ mod tests { #[test] fn test_decode_ric_exceeds_full_range() { - // max_table_capacity=256 → MaxEntries=8, FullRange=16 - // encoded_insert_count=17 > 16 → DecompressionFailed let result = EncodedFieldSectionPrefix::decode_ric(17, 256, 10); assert_eq!(result, Err(DecodeError::DecompressionFailed)); } + #[test] + fn test_decode_ric_checked_add_overflow() { + let result = EncodedFieldSectionPrefix::decode_ric(1, 1024, u64::MAX); + assert_eq!(result, Err(DecodeError::ArithmeticOverflow)); + } + + #[test] + fn test_decode_ric_wrapped_value_above_max_is_invalid() { + // max_entries = 8, full_range = 16, max_value = 8. + // Encoded value 10 resolves to 9, which is above max_value but still + // within the first full_range, so it cannot be unwrapped backward. + let result = EncodedFieldSectionPrefix::decode_ric(10, 256, 0); + assert_eq!(result, Err(DecodeError::DecompressionFailed)); + } + + #[test] + fn test_decode_ric_unwraps_value_above_max() { + // max_entries = 8, full_range = 16, max_value = 28. + // Encoded value 16 first resolves to 31, then unwraps to 15. + let result = EncodedFieldSectionPrefix::decode_ric(16, 256, 20); + assert_eq!(result, Ok(15)); + } + + #[test] + fn test_decode_ric_rejects_zero_after_unwrapping() { + let result = EncodedFieldSectionPrefix::decode_ric(1, 256, 0); + + assert_eq!(result, Err(DecodeError::DecompressionFailed)); + } + #[test] fn test_resolve_base_positive() { - // sign=false: base = required_insert_count + delta_base = 5 + 3 = 8 let result = EncodedFieldSectionPrefix::resolve_base(5, false, 3); assert_eq!(result, Ok(8)); } #[test] fn test_resolve_base_negative() { - // sign=true: base = required_insert_count - delta_base - 1 = 5 - 2 - 1 = 2 let result = EncodedFieldSectionPrefix::resolve_base(5, true, 2); assert_eq!(result, Ok(2)); } #[test] fn test_resolve_base_overflow() { - // sign=true: 1 - 5 - 1 would underflow → ArithmeticOverflow let result = EncodedFieldSectionPrefix::resolve_base(1, true, 5); assert_eq!(result, Err(DecodeError::ArithmeticOverflow)); } + #[test] + fn test_resolve_base_positive_overflow() { + let result = EncodedFieldSectionPrefix::resolve_base(u64::MAX, false, 1); + assert_eq!(result, Err(DecodeError::ArithmeticOverflow)); + } + + #[tokio::test] + async fn test_prefix_encode_decode_roundtrip_with_multibyte_ric() { + let prefix = EncodedFieldSectionPrefix { + encoded_insert_count: 1337, + sign: false, + delta_base: 65, + }; + let bytes = encode_prefix_to_bytes(prefix).await; + assert!(bytes.len() > 2, "ric varint should use continuation bytes"); + + let mut cursor = Cursor::new(&bytes); + let decoded = EncodedFieldSectionPrefix::decode_from(&mut cursor) + .await + .unwrap_or_else(|error| { + panic!("decode_from prefix failed: {error:?}"); + }); + assert_eq!(decoded, prefix); + assert_eq!(cursor.position() as usize, bytes.len()); + } + + #[test] + fn test_resolve_base_zero_delta() { + assert_eq!(EncodedFieldSectionPrefix::resolve_base(3, false, 0), Ok(3)); + assert_eq!(EncodedFieldSectionPrefix::resolve_base(3, true, 0), Ok(2)); + } + + #[tokio::test] + async fn test_prefix_encode_preserves_sign_bit_with_extended_delta_base() { + let prefix = EncodedFieldSectionPrefix { + encoded_insert_count: 255, + sign: true, + delta_base: 127, + }; + + let bytes = encode_prefix_to_bytes(prefix).await; + + assert_eq!(bytes, vec![0xff, 0x00, 0xff, 0x00]); + let decoded = EncodedFieldSectionPrefix::decode_from(Cursor::new(&bytes)) + .await + .unwrap_or_else(|error| { + panic!("decode_from signed prefix failed: {error:?}"); + }); + assert_eq!(decoded, prefix); + } + + // --- Field representation encode/decode tests --- + + #[tokio::test] + async fn test_field_line_encode_exact_index_prefixes() { + let indexed = encode_field_line_to_bytes(FieldLineRepresentation::IndexedFieldLine { + is_static: true, + index: 63, + }) + .await; + assert_eq!(indexed, vec![0xff, 0x00]); + + let post_base = encode_field_line_to_bytes( + FieldLineRepresentation::IndexedFieldLineWithPostBaseIndex { index: 15 }, + ) + .await; + assert_eq!(post_base, vec![0x1f, 0x00]); + } + + #[tokio::test] + async fn test_field_line_encode_exact_false_flag_prefixes() { + let dynamic_index = encode_field_line_to_bytes(FieldLineRepresentation::IndexedFieldLine { + is_static: false, + index: 63, + }) + .await; + assert_eq!(dynamic_index, vec![0xbf, 0x00]); + + let dynamic_name_reference = encode_field_line_to_bytes( + FieldLineRepresentation::LiteralFieldLineWithNameReference { + never_dynamic: false, + is_static: false, + name_index: 15, + huffman: false, + value: Bytes::new(), + }, + ) + .await; + assert_eq!(dynamic_name_reference, vec![0x4f, 0x00, 0x00]); + + let post_base_name_reference = encode_field_line_to_bytes( + FieldLineRepresentation::LiteralFieldLineWithPostBaseNameReference { + never_dynamic: false, + name_index: 7, + huffman: false, + value: Bytes::new(), + }, + ) + .await; + assert_eq!(post_base_name_reference, vec![0x07, 0x00, 0x00]); + } + + #[tokio::test] + async fn test_field_line_encode_exact_literal_prefixes() { + let name_reference = encode_field_line_to_bytes( + FieldLineRepresentation::LiteralFieldLineWithNameReference { + never_dynamic: true, + is_static: true, + name_index: 15, + huffman: false, + value: Bytes::new(), + }, + ) + .await; + assert_eq!(name_reference, vec![0x7f, 0x00, 0x00]); + + let post_base_name_reference = encode_field_line_to_bytes( + FieldLineRepresentation::LiteralFieldLineWithPostBaseNameReference { + never_dynamic: true, + name_index: 7, + huffman: false, + value: Bytes::new(), + }, + ) + .await; + assert_eq!(post_base_name_reference, vec![0x0f, 0x00, 0x00]); + + let literal_name = + encode_field_line_to_bytes(FieldLineRepresentation::LiteralFieldLineWithLiteralName { + never_dynamic: true, + name_huffman: false, + name: Bytes::from_static(b"x-test1"), + value_huffman: false, + value: Bytes::new(), + }) + .await; + assert_eq!( + literal_name, + [vec![0x37, 0x00], b"x-test1".to_vec(), vec![0x00]].concat() + ); + } + + #[tokio::test] + async fn test_encode_decode_indexed_variants() { + roundtrip_field_line(FieldLineRepresentation::IndexedFieldLine { + is_static: false, + index: 10, + }) + .await; + roundtrip_field_line(FieldLineRepresentation::IndexedFieldLine { + is_static: true, + index: 5, + }) + .await; + roundtrip_field_line(FieldLineRepresentation::IndexedFieldLineWithPostBaseIndex { + index: 15, + }) + .await; + roundtrip_field_line(FieldLineRepresentation::IndexedFieldLineWithPostBaseIndex { + index: 300, + }) + .await; + } + + #[tokio::test] + async fn test_encode_decode_name_reference_huffman_branches() { + roundtrip_field_line(FieldLineRepresentation::LiteralFieldLineWithNameReference { + never_dynamic: false, + is_static: false, + name_index: 3, + huffman: false, + value: Bytes::from_static(b"plain-value"), + }) + .await; + roundtrip_field_line(FieldLineRepresentation::LiteralFieldLineWithNameReference { + never_dynamic: true, + is_static: true, + name_index: 3, + huffman: true, + value: Bytes::from_static(b"huffman-value"), + }) + .await; + } + + #[tokio::test] + async fn test_encode_decode_post_base_name_reference_branches() { + roundtrip_field_line( + FieldLineRepresentation::LiteralFieldLineWithPostBaseNameReference { + never_dynamic: false, + name_index: 9, + huffman: false, + value: Bytes::from_static(b"postbase-plain"), + }, + ) + .await; + roundtrip_field_line( + FieldLineRepresentation::LiteralFieldLineWithPostBaseNameReference { + never_dynamic: true, + name_index: 9, + huffman: true, + value: Bytes::from_static(b"postbase-huffman"), + }, + ) + .await; + } + + #[tokio::test] + async fn test_encode_decode_literal_name_branches() { + roundtrip_field_line(FieldLineRepresentation::LiteralFieldLineWithLiteralName { + never_dynamic: true, + name_huffman: false, + name: Bytes::from_static(b"x-name"), + value_huffman: false, + value: Bytes::from_static(b"plain-value"), + }) + .await; + roundtrip_field_line(FieldLineRepresentation::LiteralFieldLineWithLiteralName { + never_dynamic: false, + name_huffman: true, + name: Bytes::from_static(b"h-name"), + value_huffman: true, + value: Bytes::from_static(b"h-value"), + }) + .await; + } + + #[tokio::test] + async fn test_field_section_prefix_and_representations_encode_into_frame_payload() { + let field_section_prefix = EncodedFieldSectionPrefix { + encoded_insert_count: 5, + sign: false, + delta_base: 2, + }; + let lines = vec![ + FieldLineRepresentation::IndexedFieldLine { + is_static: false, + index: 1, + }, + FieldLineRepresentation::LiteralFieldLineWithPostBaseNameReference { + never_dynamic: true, + name_index: 2, + huffman: false, + value: Bytes::from_static(b"v"), + }, + FieldLineRepresentation::LiteralFieldLineWithNameReference { + never_dynamic: false, + is_static: true, + name_index: 4, + huffman: true, + value: Bytes::from_static(b"x"), + }, + ]; + + let frame: Frame = (field_section_prefix, lines.clone()) + .encode_into(BufList::new()) + .await + .unwrap_or_else(|error| { + panic!("field section encode_into failed: {error:?}"); + }); + assert_eq!(frame.r#type(), Frame::HEADERS_FRAME_TYPE); + + let mut payload = frame.into_payload(); + let payload = payload.copy_to_bytes(payload.remaining()); + let mut cursor = Cursor::new(&payload); + + let decoded_prefix = EncodedFieldSectionPrefix::decode_from(&mut cursor) + .await + .unwrap_or_else(|error| { + panic!("decode_from field section prefix failed: {error:?}"); + }); + assert_eq!(decoded_prefix, field_section_prefix); + + let mut decoded_lines = Vec::new(); + while cursor.position() < payload.len() as u64 { + let line = FieldLineRepresentation::decode_from(&mut cursor) + .await + .unwrap_or_else(|error| { + panic!("decode_from field line failed: {error:?}"); + }); + decoded_lines.push(line); + } + + assert_eq!(decoded_lines, lines); + assert_eq!(cursor.position(), payload.len() as u64); + } + + #[tokio::test] + async fn test_prefix_varint_decode_truncated() { + // Missing the second byte (delta-base varint) should become a decode error on decode_from. + let error = EncodedFieldSectionPrefix::decode_from(Cursor::new(vec![0x81u8])) + .await + .expect_err("expected decode failure"); + assert_h3_error(error); + } + + #[tokio::test] + async fn test_prefix_decode_missing_first_byte() { + let error = EncodedFieldSectionPrefix::decode_from(Cursor::new(Vec::::new())) + .await + .expect_err("expected missing prefix failure"); + assert_h3_error(error); + } + + #[tokio::test] + async fn test_prefix_decode_truncated_required_insert_count_varint() { + let error = EncodedFieldSectionPrefix::decode_from(Cursor::new(vec![0xffu8])) + .await + .expect_err("expected required insert count truncation failure"); + assert_h3_error(error); + } + + #[tokio::test] + async fn test_prefix_decode_truncated_delta_base_varint() { + let error = EncodedFieldSectionPrefix::decode_from(Cursor::new(vec![0x00u8, 0xff])) + .await + .expect_err("expected delta-base truncation failure"); + assert_h3_error(error); + } + + #[tokio::test] + async fn test_field_line_decode_missing_prefix_byte() { + let error = FieldLineRepresentation::decode_from(Cursor::new(Vec::::new())) + .await + .expect_err("expected missing field line prefix failure"); + assert_h3_error(error); + } + + #[tokio::test] + async fn test_field_line_decode_truncated_indexed_varint() { + // Indexed line with index 63 uses extended integer encoding; no continuation byte present. + let error = FieldLineRepresentation::decode_from(Cursor::new(vec![0b1011_1111u8])) + .await + .expect_err("expected indexed truncation failure"); + assert_h3_error(error); + } + + #[tokio::test] + async fn test_field_line_decode_truncated_post_base_index_varint() { + let error = FieldLineRepresentation::decode_from(Cursor::new(vec![0b0001_1111u8])) + .await + .expect_err("expected post-base index truncation failure"); + assert_h3_error(error); + } + + #[tokio::test] + async fn test_field_line_decode_truncated_name_reference_index_varint() { + let error = FieldLineRepresentation::decode_from(Cursor::new(vec![0b0100_1111u8])) + .await + .expect_err("expected literal name reference index truncation failure"); + assert_h3_error(error); + } + + #[tokio::test] + async fn test_field_line_decode_truncated_name_reference_value() { + // Name-reference literal misses the value prefix+bytes. + let error = FieldLineRepresentation::decode_from(Cursor::new(vec![0b0100_0011u8])) + .await + .expect_err("expected literal name ref truncation failure"); + assert_h3_error(error); + } + + #[tokio::test] + async fn test_field_line_decode_truncated_name_reference_value_bytes() { + let error = + FieldLineRepresentation::decode_from(Cursor::new(vec![0b0100_0011u8, 0x02, b'a'])) + .await + .expect_err("expected literal name ref value bytes truncation failure"); + assert_h3_error(error); + } + + #[tokio::test] + async fn test_field_line_decode_truncated_literal_name() { + // Literal name field with missing name length + bytes. + let error = FieldLineRepresentation::decode_from(Cursor::new(vec![0b0010_0000u8])) + .await + .expect_err("expected literal name truncation failure"); + assert_h3_error(error); + } + + #[tokio::test] + async fn test_field_line_decode_truncated_literal_name_bytes() { + let error = FieldLineRepresentation::decode_from(Cursor::new(vec![0b0010_0010u8, b'a'])) + .await + .expect_err("expected literal name bytes truncation failure"); + assert_h3_error(error); + } + + #[tokio::test] + async fn test_field_line_decode_truncated_literal_name_value_bytes() { + let error = FieldLineRepresentation::decode_from(Cursor::new(vec![ + 0b0010_0001u8, + b'n', + 0x02, + b'v', + ])) + .await + .expect_err("expected literal name value bytes truncation failure"); + assert_h3_error(error); + } + + #[tokio::test] + async fn test_field_line_decode_truncated_post_base_name_index_varint() { + let error = FieldLineRepresentation::decode_from(Cursor::new(vec![0b0000_0111u8])) + .await + .expect_err("expected post-base name index truncation failure"); + assert_h3_error(error); + } + + #[tokio::test] + async fn test_field_line_decode_truncated_post_base_name_value() { + // Post-base name reference missing value. + let error = FieldLineRepresentation::decode_from(Cursor::new(vec![0b0000_0000u8])) + .await + .expect_err("expected post-base truncation failure"); + assert_h3_error(error); + } + + #[tokio::test] + async fn test_field_line_decode_truncated_post_base_name_value_bytes() { + let error = + FieldLineRepresentation::decode_from(Cursor::new(vec![0b0000_0001u8, 0x02, b'a'])) + .await + .expect_err("expected post-base value bytes truncation failure"); + assert_h3_error(error); + } + mod proptest_roundtrip { use proptest::prelude::*; diff --git a/src/qpack/field/section.rs b/src/qpack/field/section.rs index 6411a09..9eb78ac 100644 --- a/src/qpack/field/section.rs +++ b/src/qpack/field/section.rs @@ -1,4 +1,4 @@ -use std::pin::pin; +use std::{pin::pin, sync::Arc}; use bytes::Bytes; use futures::{Stream, TryStreamExt}; @@ -33,6 +33,8 @@ pub enum MalformedHeaderSection { ResponsePseudoHeaderInRequest, #[snafu(display("field section contains pseudo-header fields in trailers"))] PseudoHeaderInTrailer, + #[snafu(display("pseudo-header field appears after regular header field"))] + PseudoHeaderAfterRegularHeader, #[snafu(display("field section too large"))] FieldSectionTooLarge, #[snafu(display( @@ -125,24 +127,36 @@ impl From for MalformedHeaderSection { #[derive(Debug, Clone, PartialEq, Eq)] pub struct FieldSection { pub(crate) pseudo_headers: Option, - pub(crate) header_map: HeaderMap, + pub(crate) header_map: Arc, } impl FieldSection { pub fn header(pseudo_headers: PseudoHeaders, header_map: HeaderMap) -> Self { Self { pseudo_headers: Some(pseudo_headers), - header_map, + header_map: Arc::new(header_map), } } pub fn trailer(header_map: HeaderMap) -> Self { Self { pseudo_headers: None, - header_map, + header_map: Arc::new(header_map), } } + pub fn header_map(&self) -> &HeaderMap { + &self.header_map + } + + pub fn header_map_mut(&mut self) -> &mut HeaderMap { + Arc::make_mut(&mut self.header_map) + } + + pub fn into_header_map(self) -> HeaderMap { + Arc::try_unwrap(self.header_map).unwrap_or_else(|header_map| (*header_map).clone()) + } + pub fn is_empty(&self) -> bool { self.pseudo_headers .as_ref() @@ -184,6 +198,9 @@ impl FieldSection { let Some(method) = method else { return malformed_header_section::AbsenceOfMandatoryPseudoHeadersSnafu.fail(); }; + if method != Method::CONNECT && protocol.is_some() { + return Err(MalformedHeaderSection::ProtocolInNonConnectRequest); + } // 4. The Extended CONNECT Method // @@ -411,9 +428,14 @@ impl> + Send> DecodeFrom for let mut pseudo_headers = None; let mut header_map = HeaderMap::new(); + let mut regular_header_seen = false; while let Some(FieldLine { name, value }) = stream.try_next().await? { match name { name if name == PseudoHeaders::METHOD => { + ensure!( + !regular_header_seen, + malformed_header_section::PseudoHeaderAfterRegularHeaderSnafu + ); let mut pseudo = pseudo_headers .take() .unwrap_or(PseudoHeaders::unresolved_request()); @@ -434,6 +456,10 @@ impl> + Send> DecodeFrom for pseudo_headers = Some(pseudo) } name if name == PseudoHeaders::PROTOOCL => { + ensure!( + !regular_header_seen, + malformed_header_section::PseudoHeaderAfterRegularHeaderSnafu + ); let mut pseudo = pseudo_headers .take() .unwrap_or(PseudoHeaders::unresolved_request()); @@ -456,6 +482,10 @@ impl> + Send> DecodeFrom for pseudo_headers = Some(pseudo) } name if name == PseudoHeaders::SCHEME => { + ensure!( + !regular_header_seen, + malformed_header_section::PseudoHeaderAfterRegularHeaderSnafu + ); let mut pseudo = pseudo_headers .take() .unwrap_or(PseudoHeaders::unresolved_request()); @@ -476,6 +506,10 @@ impl> + Send> DecodeFrom for pseudo_headers = Some(pseudo) } name if name == PseudoHeaders::AUTHORITY => { + ensure!( + !regular_header_seen, + malformed_header_section::PseudoHeaderAfterRegularHeaderSnafu + ); let mut pseudo = pseudo_headers .take() .unwrap_or(PseudoHeaders::unresolved_request()); @@ -498,6 +532,10 @@ impl> + Send> DecodeFrom for pseudo_headers = Some(pseudo) } name if name == PseudoHeaders::PATH => { + ensure!( + !regular_header_seen, + malformed_header_section::PseudoHeaderAfterRegularHeaderSnafu + ); let mut pseudo = pseudo_headers .take() .unwrap_or(PseudoHeaders::unresolved_request()); @@ -514,13 +552,19 @@ impl> + Send> DecodeFrom for path_field.is_none(), malformed_header_section::DuplicatePseudoHeaderSnafu { name } ); - *path_field = Some( + *path_field = Some(if value.as_ref() == b"*" { + super::pseudo::asterisk_path() + } else { PathAndQuery::from_maybe_shared(value) - .map_err(MalformedHeaderSection::from)?, - ); + .map_err(MalformedHeaderSection::from)? + }); pseudo_headers = Some(pseudo) } name if name == PseudoHeaders::STATUS => { + ensure!( + !regular_header_seen, + malformed_header_section::PseudoHeaderAfterRegularHeaderSnafu + ); let mut pseudo = pseudo_headers .take() .unwrap_or(PseudoHeaders::unresolved_response()); @@ -544,9 +588,14 @@ impl> + Send> DecodeFrom for } name if name.starts_with(b":") => { + ensure!( + !regular_header_seen, + malformed_header_section::PseudoHeaderAfterRegularHeaderSnafu + ); return Err(MalformedHeaderSection::InvalidPseudoHeader { name }.into()); } name => { + regular_header_seen = true; header_map .try_append( HeaderName::from_bytes(name.as_ref()) @@ -561,7 +610,7 @@ impl> + Send> DecodeFrom for Ok(FieldSection { pseudo_headers, - header_map, + header_map: Arc::new(header_map), }) } } @@ -632,3 +681,739 @@ impl Iterator for Iter<'_> { None } } + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use futures::stream; + + use super::*; + + fn field(name: &'static [u8], value: &'static [u8]) -> Result { + Ok(FieldLine { + name: Bytes::from_static(name), + value: Bytes::from_static(value), + }) + } + + async fn decode_fields( + fields: impl IntoIterator>, + ) -> Result { + let fields: Vec<_> = fields.into_iter().collect(); + FieldSection::decode_from(stream::iter(fields)).await + } + + fn assert_h3_error(error: StreamError, expected: &str) { + let StreamError::H3 { source } = error else { + panic!("expected stream-scope H3 error"); + }; + assert_eq!(source.to_string(), expected); + } + + fn assert_h3_error_one_of(error: StreamError, expected: &[&str]) { + let StreamError::H3 { source } = error else { + panic!("expected stream-scope H3 error"); + }; + let message = source.to_string(); + assert!( + expected.iter().any(|expected| message == *expected), + "unexpected H3 error message {message:?}; expected one of {expected:?}" + ); + } + + fn request_pseudo( + method: Option, + scheme: Option, + authority: Option, + path: Option, + protocol: Option, + ) -> PseudoHeaders { + PseudoHeaders::Request { + method, + scheme, + authority, + path, + protocol, + } + } + + #[test] + fn field_section_accessors_mutators_and_iteration() { + let mut headers = HeaderMap::new(); + headers.insert("x-test", HeaderValue::from_static("one")); + let mut section = FieldSection::header( + PseudoHeaders::request( + Method::GET, + Uri::from_static("https://example.test:443/a?b=1"), + ), + headers, + ); + + assert!(!section.is_empty()); + assert!(section.is_request_header()); + assert!(!section.is_response_header()); + assert!(!section.is_trailer()); + assert_eq!(section.method(), Method::GET); + assert_eq!(section.scheme(), Some(Scheme::HTTPS)); + assert_eq!( + section.authority().as_ref().map(Authority::as_str), + Some("example.test:443") + ); + assert_eq!( + section.path().as_ref().map(PathAndQuery::as_str), + Some("/a?b=1") + ); + assert_eq!( + section.uri(), + Uri::from_static("https://example.test:443/a?b=1") + ); + + section.set_method(Method::POST); + section.set_scheme(Scheme::HTTP); + section.set_authority(Authority::from_static("example.test:80")); + section.set_path(PathAndQuery::from_static("/changed")); + section.set_protocol(Protocol::new("webtransport")); + section.set_uri(Uri::from_static("https://other.test/new")); + section + .header_map_mut() + .insert("x-second", HeaderValue::from_static("two")); + + assert_eq!(section.method(), Method::POST); + assert_eq!( + section.protocol().as_ref().map(Protocol::as_str), + Some("webtransport") + ); + assert_eq!(section.uri(), Uri::from_static("https://other.test/new")); + assert_eq!(section.header_map()["x-second"], "two"); + + let fields: Vec<_> = section.iter().collect(); + let names: Vec<_> = fields + .iter() + .map(|line| std::str::from_utf8(&line.name).expect("field name")) + .collect(); + assert_eq!( + &names[..5], + [":method", ":protocol", ":scheme", ":authority", ":path"] + ); + + let cloned = section.clone(); + assert_eq!(cloned.into_header_map().len(), 2); + + let trailer = FieldSection::trailer(HeaderMap::new()); + assert!(trailer.is_empty()); + assert!(trailer.is_trailer()); + assert!(trailer.check_pseudo().is_ok()); + } + + #[test] + fn response_field_section_accessors_and_iteration() { + let mut section = + FieldSection::header(PseudoHeaders::response(StatusCode::OK), HeaderMap::new()); + + assert!(section.is_response_header()); + assert_eq!(section.status(), StatusCode::OK); + + section.set_status(StatusCode::ACCEPTED); + assert_eq!(section.status(), StatusCode::ACCEPTED); + + let fields: Vec<_> = section.iter().collect(); + assert_eq!(fields, vec![FieldLine::from(StatusCode::ACCEPTED)]); + } + + #[test] + fn wrong_pseudo_header_accessor_panics_are_preserved() { + fn assert_panics(operation: impl FnOnce()) { + assert!(std::panic::catch_unwind(std::panic::AssertUnwindSafe(operation)).is_err()); + } + + fn response_section() -> FieldSection { + FieldSection::header(PseudoHeaders::response(StatusCode::OK), HeaderMap::new()) + } + + fn request_section() -> FieldSection { + FieldSection::header( + PseudoHeaders::request(Method::GET, Uri::from_static("https://example.test/")), + HeaderMap::new(), + ) + } + + assert_panics(|| { + let _ = response_section().method(); + }); + assert_panics(|| response_section().set_method(Method::POST)); + assert_panics(|| { + let _ = response_section().scheme(); + }); + assert_panics(|| response_section().set_scheme(Scheme::HTTPS)); + assert_panics(|| { + let _ = response_section().authority(); + }); + assert_panics(|| response_section().set_authority(Authority::from_static("example.test"))); + assert_panics(|| { + let _ = response_section().path(); + }); + assert_panics(|| response_section().set_path(PathAndQuery::from_static("/"))); + assert_panics(|| { + let _ = response_section().protocol(); + }); + assert_panics(|| response_section().set_protocol(Protocol::new("webtransport"))); + assert_panics(|| { + let _ = request_section().status(); + }); + assert_panics(|| request_section().set_status(StatusCode::OK)); + } + + #[test] + fn max_size_reached_converts_to_malformed_header_section() { + let error = HeaderMap::::try_with_capacity(usize::MAX) + .expect_err("oversized capacity should be rejected"); + + let error = MalformedHeaderSection::from(error); + + assert!(matches!( + error, + MalformedHeaderSection::InvalidMessage { .. } + )); + } + + #[test] + fn check_pseudo_validates_request_authority_and_connect_rules() { + let mut headers = HeaderMap::new(); + headers.insert("host", HeaderValue::from_static("example.test")); + let valid = FieldSection::header( + request_pseudo( + Some(Method::GET), + Some(Scheme::HTTPS), + Some(Authority::from_static("example.test")), + Some(PathAndQuery::from_static("/")), + None, + ), + headers, + ); + assert!(valid.check_pseudo().is_ok()); + + let mut headers = HeaderMap::new(); + headers.insert("host", HeaderValue::from_static("different.test")); + let mismatch = FieldSection::header( + request_pseudo( + Some(Method::GET), + Some(Scheme::HTTPS), + Some(Authority::from_static("example.test")), + Some(PathAndQuery::from_static("/")), + None, + ), + headers, + ); + assert!(matches!( + mismatch.check_pseudo(), + Err(MalformedHeaderSection::AuthorityHostMismatch) + )); + + let extended_connect = FieldSection::header( + request_pseudo( + Some(Method::CONNECT), + Some(Scheme::HTTPS), + Some(Authority::from_static("example.test")), + Some(PathAndQuery::from_static("/session")), + Some(Protocol::new("webtransport")), + ), + HeaderMap::new(), + ); + assert!(extended_connect.check_pseudo().is_ok()); + + let plain_connect = FieldSection::header( + request_pseudo( + Some(Method::CONNECT), + None, + Some(Authority::from_static("example.test:443")), + None, + None, + ), + HeaderMap::new(), + ); + assert!(plain_connect.check_pseudo().is_ok()); + } + + #[test] + fn check_pseudo_reports_request_error_variants() { + let mut headers = HeaderMap::new(); + headers.insert( + "host", + HeaderValue::from_bytes(&[0xff]).expect("opaque header value"), + ); + let invalid_host = FieldSection::header( + request_pseudo( + Some(Method::GET), + Some(Scheme::HTTPS), + None, + Some(PathAndQuery::from_static("/")), + None, + ), + headers, + ); + assert!(matches!( + invalid_host.check_pseudo(), + Err(MalformedHeaderSection::InvalidHostHeader) + )); + + let missing = FieldSection::header(PseudoHeaders::unresolved_request(), HeaderMap::new()); + assert!(matches!( + missing.check_pseudo(), + Err(MalformedHeaderSection::AbsenceOfMandatoryPseudoHeaders { .. }) + )); + + let protocol_in_get = FieldSection::header( + request_pseudo( + Some(Method::GET), + Some(Scheme::HTTPS), + Some(Authority::from_static("example.test")), + Some(PathAndQuery::from_static("/")), + Some(Protocol::new("webtransport")), + ), + HeaderMap::new(), + ); + assert!(matches!( + protocol_in_get.check_pseudo(), + Err(MalformedHeaderSection::ProtocolInNonConnectRequest) + )); + + let connect_with_scheme = FieldSection::header( + request_pseudo( + Some(Method::CONNECT), + Some(Scheme::HTTPS), + Some(Authority::from_static("example.test:443")), + None, + None, + ), + HeaderMap::new(), + ); + assert!(matches!( + connect_with_scheme.check_pseudo(), + Err(MalformedHeaderSection::UnexpectedSchemeForConnectRequest) + )); + + let connect_with_path = FieldSection::header( + request_pseudo( + Some(Method::CONNECT), + None, + Some(Authority::from_static("example.test:443")), + Some(PathAndQuery::from_static("/")), + None, + ), + HeaderMap::new(), + ); + assert!(matches!( + connect_with_path.check_pseudo(), + Err(MalformedHeaderSection::UnexpectedPathForConnectRequest) + )); + + let connect_without_authority = FieldSection::header( + request_pseudo(Some(Method::CONNECT), None, None, None, None), + HeaderMap::new(), + ); + assert!(matches!( + connect_without_authority.check_pseudo(), + Err(MalformedHeaderSection::MissingAuthorityForConnectRequest) + )); + + let connect_without_port = FieldSection::header( + request_pseudo( + Some(Method::CONNECT), + None, + Some(Authority::from_static("example.test")), + None, + None, + ), + HeaderMap::new(), + ); + assert!(matches!( + connect_without_port.check_pseudo(), + Err(MalformedHeaderSection::MissingPortForConnectAuthority) + )); + } + + #[test] + fn check_pseudo_reports_response_error_variants() { + let response = FieldSection::header(PseudoHeaders::unresolved_response(), HeaderMap::new()); + assert!(matches!( + response.check_pseudo(), + Err(MalformedHeaderSection::AbsenceOfMandatoryPseudoHeaders { .. }) + )); + } + + #[test] + fn check_pseudo_validates_authority_host_presence_and_empty_values() { + let authority_only = FieldSection::header( + request_pseudo( + Some(Method::GET), + Some(Scheme::HTTPS), + Some(Authority::from_static("example.test")), + Some(PathAndQuery::from_static("/")), + None, + ), + HeaderMap::new(), + ); + assert!(authority_only.check_pseudo().is_ok()); + + let mut headers = HeaderMap::new(); + headers.insert("host", HeaderValue::from_static("example.test")); + let host_only = FieldSection::header( + request_pseudo( + Some(Method::GET), + Some(Scheme::HTTPS), + None, + Some(PathAndQuery::from_static("/")), + None, + ), + headers, + ); + assert!(host_only.check_pseudo().is_ok()); + + let missing_authority_and_host = FieldSection::header( + request_pseudo( + Some(Method::GET), + Some(Scheme::HTTPS), + None, + Some(PathAndQuery::from_static("/")), + None, + ), + HeaderMap::new(), + ); + assert!(matches!( + missing_authority_and_host.check_pseudo(), + Err(MalformedHeaderSection::AbsenceOfMandatoryPseudoHeaders { .. }) + )); + + let mut headers = HeaderMap::new(); + headers.insert("host", HeaderValue::from_static("")); + let empty_host = FieldSection::header( + request_pseudo( + Some(Method::GET), + Some(Scheme::HTTPS), + None, + Some(PathAndQuery::from_static("/")), + None, + ), + headers, + ); + assert!(matches!( + empty_host.check_pseudo(), + Err(MalformedHeaderSection::EmptyAuthorityOrHost) + )); + + let non_http_without_authority_or_host = FieldSection::header( + request_pseudo( + Some(Method::GET), + Some(Scheme::try_from("urn").expect("scheme")), + None, + Some(PathAndQuery::from_static("/opaque")), + None, + ), + HeaderMap::new(), + ); + assert!(non_http_without_authority_or_host.check_pseudo().is_ok()); + } + + #[test] + fn check_pseudo_requires_scheme_and_path_for_regular_and_extended_requests() { + let missing_scheme = FieldSection::header( + request_pseudo( + Some(Method::GET), + None, + Some(Authority::from_static("example.test")), + Some(PathAndQuery::from_static("/")), + None, + ), + HeaderMap::new(), + ); + assert!(matches!( + missing_scheme.check_pseudo(), + Err(MalformedHeaderSection::AbsenceOfMandatoryPseudoHeaders { .. }) + )); + + let missing_path = FieldSection::header( + request_pseudo( + Some(Method::GET), + Some(Scheme::HTTPS), + Some(Authority::from_static("example.test")), + None, + None, + ), + HeaderMap::new(), + ); + assert!(matches!( + missing_path.check_pseudo(), + Err(MalformedHeaderSection::AbsenceOfMandatoryPseudoHeaders { .. }) + )); + + let extended_connect_missing_path = FieldSection::header( + request_pseudo( + Some(Method::CONNECT), + Some(Scheme::HTTPS), + Some(Authority::from_static("example.test")), + None, + Some(Protocol::new("webtransport")), + ), + HeaderMap::new(), + ); + assert!(matches!( + extended_connect_missing_path.check_pseudo(), + Err(MalformedHeaderSection::AbsenceOfMandatoryPseudoHeaders { .. }) + )); + } + + #[tokio::test] + async fn decode_accepts_request_response_and_trailer_sections() { + let request = decode_fields([ + field(b":method", b"GET"), + field(b":scheme", b"https"), + field(b":authority", b"example.test"), + field(b":path", b"/hello"), + field(b"host", b"example.test"), + ]) + .await + .expect("request field section"); + assert!(request.is_request_header()); + assert_eq!(request.method(), Method::GET); + assert_eq!( + request.uri(), + Uri::from_static("https://example.test/hello") + ); + assert_eq!(request.header_map()["host"], "example.test"); + + let response = decode_fields([field(b":status", b"204"), field(b"server", b"h3x")]) + .await + .expect("response field section"); + assert!(response.is_response_header()); + assert_eq!(response.status(), StatusCode::NO_CONTENT); + assert_eq!(response.header_map()["server"], "h3x"); + + let trailer = decode_fields([field(b"x-trailer", b"done")]) + .await + .expect("trailer field section"); + assert!(trailer.is_trailer()); + assert_eq!(trailer.header_map()["x-trailer"], "done"); + } + + #[tokio::test] + async fn decode_accepts_options_asterisk_path() { + let request = decode_fields([ + field(b":method", b"OPTIONS"), + field(b":scheme", b"https"), + field(b":authority", b"example.test"), + field(b":path", b"*"), + ]) + .await + .expect("options request field section"); + + assert!(request.check_pseudo().is_ok()); + assert_eq!(request.method(), Method::OPTIONS); + assert_eq!(request.path().as_ref().map(PathAndQuery::as_str), Some("*")); + assert!(request.iter().any(|line| { + line.name == Bytes::from_static(PseudoHeaders::PATH.as_bytes()) + && line.value == Bytes::from_static(b"*") + })); + } + + #[tokio::test] + async fn decode_rejects_pseudo_header_after_regular_header() { + let fields = futures::stream::iter([ + Ok(FieldLine { + name: Bytes::from_static(b"host"), + value: Bytes::from_static(b"reimu.pilot.dhttp.net"), + }), + Ok(FieldLine { + name: Bytes::from_static(b":method"), + value: Bytes::from_static(b"GET"), + }), + ]); + + let error = FieldSection::decode_from(fields).await.unwrap_err(); + + let StreamError::H3 { source } = error else { + panic!("expected stream-scope H3 error"); + }; + assert_eq!( + source.to_string(), + "pseudo-header field appears after regular header field" + ); + } + + #[tokio::test] + async fn decode_rejects_each_pseudo_header_after_regular_header() { + for (name, value) in [ + ( + b":protocol" as &'static [u8], + b"webtransport" as &'static [u8], + ), + (b":scheme", b"https"), + (b":authority", b"example.test"), + (b":path", b"/"), + (b":status", b"200"), + (b":unknown", b"value"), + ] { + let error = decode_fields([field(b"host", b"example.test"), field(name, value)]) + .await + .expect_err("pseudo header after regular header should be rejected"); + + assert_h3_error( + error, + "pseudo-header field appears after regular header field", + ); + } + } + + #[tokio::test] + async fn decode_rejects_each_request_pseudo_header_after_response_pseudo_header() { + for (name, value) in [ + (b":method" as &'static [u8], b"GET" as &'static [u8]), + (b":protocol", b"webtransport"), + (b":scheme", b"https"), + (b":authority", b"example.test"), + (b":path", b"/"), + ] { + let error = decode_fields([field(b":status", b"200"), field(name, value)]) + .await + .expect_err( + "request pseudo header after response pseudo header should be rejected", + ); + + assert_h3_error( + error, + "field section contains both request and response pseudo-header fields", + ); + } + } + + #[tokio::test] + async fn decode_reports_duplicate_mixed_invalid_and_parse_errors() { + let duplicate = decode_fields([field(b":method", b"GET"), field(b":method", b"POST")]) + .await + .expect_err("duplicate pseudo header"); + assert_h3_error( + duplicate, + "duplicate pseudo-header field with name b\":method\"", + ); + + let mixed = decode_fields([field(b":method", b"GET"), field(b":status", b"200")]) + .await + .expect_err("mixed request and response pseudo headers"); + assert_h3_error( + mixed, + "field section contains both request and response pseudo-header fields", + ); + + let invalid = decode_fields([field(b":unknown", b"value")]) + .await + .expect_err("unknown pseudo header"); + assert_h3_error( + invalid, + "invalid pseudo-header field with name b\":unknown\"", + ); + + let invalid_protocol = + decode_fields([field(b":method", b"CONNECT"), field(b":protocol", &[0xff])]) + .await + .expect_err("invalid protocol token"); + assert_h3_error(invalid_protocol, "invalid :protocol pseudo-header token"); + + let invalid_status = decode_fields([field(b":status", b"not-a-status")]) + .await + .expect_err("invalid status"); + assert_h3_error(invalid_status, "invalid status code"); + } + + #[tokio::test] + async fn decode_reports_duplicate_request_and_response_pseudo_headers() { + for (name, first, second) in [ + ( + b":protocol" as &'static [u8], + b"webtransport" as &'static [u8], + b"websocket" as &'static [u8], + ), + (b":scheme", b"https", b"http"), + (b":authority", b"example.test", b"other.test"), + (b":path", b"/one", b"/two"), + (b":status", b"200", b"204"), + ] { + let duplicate = decode_fields([field(name, first), field(name, second)]) + .await + .expect_err("duplicate pseudo header"); + assert_h3_error( + duplicate, + &format!( + "duplicate pseudo-header field with name {:?}", + Bytes::from_static(name) + ), + ); + } + } + + #[tokio::test] + async fn decode_reports_parse_errors_for_pseudo_and_regular_fields() { + let invalid_method = decode_fields([field(b":method", b"bad method")]) + .await + .expect_err("invalid method"); + assert_h3_error(invalid_method, "invalid HTTP method"); + + let invalid_scheme = decode_fields([field(b":scheme", b"://")]) + .await + .expect_err("invalid scheme"); + assert_h3_error(invalid_scheme, "invalid scheme"); + + let invalid_authority = decode_fields([field(b":authority", b"bad authority")]) + .await + .expect_err("invalid authority"); + assert_h3_error(invalid_authority, "invalid uri character"); + + let invalid_path = decode_fields([field(b":path", b"bad path")]) + .await + .expect_err("invalid path"); + assert_h3_error_one_of( + invalid_path, + &["invalid uri character", "path does not start with slash"], + ); + + let invalid_header_name = decode_fields([field(b"bad header", b"value")]) + .await + .expect_err("invalid header name"); + assert_h3_error(invalid_header_name, "invalid HTTP header name"); + + let invalid_header_value = decode_fields([field(b"x-test", b"bad\nvalue")]) + .await + .expect_err("invalid header value"); + assert_h3_error(invalid_header_value, "failed to parse header value"); + } + + #[tokio::test] + async fn decode_propagates_source_stream_error() { + let error = FieldSection::decode_from(stream::iter([Err(StreamError::Reset { + code: 42u32.into(), + })])) + .await + .expect_err("source stream error"); + + assert!(matches!(error, StreamError::Reset { code } if code.into_inner() == 42)); + } + + #[tokio::test] + async fn malformed_header_section_display_and_source_are_structured() { + let error = decode_fields([field(b":protocol", &[0xff])]) + .await + .expect_err("invalid protocol token"); + let StreamError::H3 { source } = error else { + panic!("expected stream-scope H3 error"); + }; + assert_eq!(source.to_string(), "invalid :protocol pseudo-header token"); + assert!(std::error::Error::source(source.as_ref()).is_some()); + + let error = decode_fields([field(b":status", b"not-a-status")]) + .await + .expect_err("invalid status"); + let StreamError::H3 { source } = error else { + panic!("expected stream-scope H3 error"); + }; + assert_eq!(source.to_string(), "invalid status code"); + assert!(std::error::Error::source(source.as_ref()).is_none()); + } +} diff --git a/src/qpack/integer.rs b/src/qpack/integer.rs index 24acf1f..7e8b465 100644 --- a/src/qpack/integer.rs +++ b/src/qpack/integer.rs @@ -94,6 +94,14 @@ mod tests { use super::*; + #[test] + fn integer_new_stores_value() { + let integer = Integer::new(42); + + assert_eq!(integer.value(), 42); + assert_eq!(integer, Integer::new(42)); + } + async fn round_trip(n: u8, value: u64) { let mut buf = Vec::new(); encode_integer(Cursor::new(&mut buf), 0, n, value) diff --git a/src/qpack/protocol.rs b/src/qpack/protocol.rs index 950be00..fdc4cbd 100644 --- a/src/qpack/protocol.rs +++ b/src/qpack/protocol.rs @@ -17,8 +17,7 @@ use tracing::Instrument; use crate::{ codec::{ - DecodeExt, EncodeExt, ErasedPeekableBiStream, ErasedPeekableUniStream, ErasedStreamReader, - SinkWriter, + BoxPeekableStreamReader, BoxStreamReader, BoxStreamWriter, DecodeExt, EncodeExt, SinkWriter, }, connection::{ConnectionState, LifecycleExt, StreamError}, dhttp::{protocol::DHttpProtocol, settings::Settings, stream::UnidirectionalStream}, @@ -103,12 +102,12 @@ pub type QPackDecoder = Decoder< /// (encoder 0x02, decoder 0x03) and passes all other streams through. pub struct QPackProtocol { /// Oneshot sender for dispatching the peer's QPACK encoder instruction stream. - encoder_inst_receiver_tx: Mutex>>, + encoder_inst_receiver_tx: Mutex>>, /// QPACK encoder, set during connection initialization. pub encoder: Arc, /// Oneshot sender for dispatching the peer's QPACK decoder instruction stream. - decoder_inst_receiver_tx: Mutex>>, + decoder_inst_receiver_tx: Mutex>>, /// QPACK decoder, set during connection initialization. pub decoder: Arc, } @@ -125,8 +124,8 @@ impl std::fmt::Debug for QPackProtocol { impl QPackProtocol { async fn accept_uni( &self, - mut stream: ErasedPeekableUniStream, - ) -> Result, StreamError> { + mut stream: BoxPeekableStreamReader, + ) -> Result, StreamError> { let Ok(stream_type) = stream.decode_one::().await else { return Ok(StreamVerdict::Passed(stream)); }; @@ -138,7 +137,7 @@ impl QPackProtocol { .lock() .expect("lock is not poisoned") .take() - .ok_or(H3StreamCreationError::DuplicateQpackDecoderStream)? + .ok_or(H3StreamCreationError::DuplicateQpackEncoderStream)? .send(uni_stream_reader); Ok(StreamVerdict::Accepted) } else if stream_type == UnidirectionalStream::QPACK_DECODER_STREAM_TYPE { @@ -158,8 +157,8 @@ impl QPackProtocol { async fn accept_bi( &self, - stream: ErasedPeekableBiStream, - ) -> Result, StreamError> { + stream: (BoxPeekableStreamReader, BoxStreamWriter), + ) -> Result, StreamError> { Ok(StreamVerdict::Passed(stream)) } } @@ -167,15 +166,16 @@ impl QPackProtocol { impl Protocol for QPackProtocol { fn accept_uni<'a>( &'a self, - stream: ErasedPeekableUniStream, - ) -> BoxFuture<'a, Result, StreamError>> { + stream: BoxPeekableStreamReader, + ) -> BoxFuture<'a, Result, StreamError>> { Box::pin(self.accept_uni(stream)) } fn accept_bi<'a>( &'a self, - stream: ErasedPeekableBiStream, - ) -> BoxFuture<'a, Result, StreamError>> { + stream: (BoxPeekableStreamReader, BoxStreamWriter), + ) -> BoxFuture<'a, Result, StreamError>> + { Box::pin(self.accept_bi(stream)) } } @@ -210,9 +210,9 @@ impl QPackProtocolFactory { // Create dispatch channels for incoming peer QPACK streams let (encoder_inst_receiver_tx, encoder_inst_receiver_rx) = - oneshot::channel::(); + oneshot::channel::(); let (decoder_inst_receiver_tx, decoder_inst_receiver_rx) = - oneshot::channel::(); + oneshot::channel::(); // Create QPACK encoder with lazy streams let encoder = { @@ -352,12 +352,12 @@ impl ConnectionState { // use super::QPackLayer; // use crate::{ // codec::{StreamReader, peekable::PeekableStreamReader}, -// layer::{BoxPeekableUniStream, Protocol, StreamVerdict}, +// layer::{BoxPeekableStreamReader, Protocol, StreamVerdict}, // varint::VarInt, // }; -// /// Helper: create a BoxPeekableUniStream from raw byte chunks. -// fn peekable_uni_from_chunks(chunks: Vec<&'static [u8]>) -> BoxPeekableUniStream { +// /// Helper: create a BoxPeekableStreamReader from raw byte chunks. +// fn peekable_uni_from_chunks(chunks: Vec<&'static [u8]>) -> BoxPeekableStreamReader { // let chunks_owned: Vec = chunks.into_iter().map(Bytes::from_static).collect(); // let (reader, mut writer) = crate::quic::test::mock_stream_pair(VarInt::from_u32(0)); @@ -377,8 +377,8 @@ impl ConnectionState { // PeekableStreamReader::new(stream_reader) // } -// /// Helper: create a BoxPeekableBiStream from raw byte chunks for the read side. -// fn peekable_bi_from_chunks(read_chunks: Vec<&'static [u8]>) -> crate::layer::BoxPeekableBiStream { +// /// Helper: create a (BoxPeekableStreamReader, BoxStreamWriter) from raw byte chunks for the read side. +// fn peekable_bi_from_chunks(read_chunks: Vec<&'static [u8]>) -> crate::layer::(BoxPeekableStreamReader, BoxStreamWriter) { // let (reader, mut writer_feed) = crate::quic::test::mock_stream_pair(VarInt::from_u32(4)); // let (_, write_side) = crate::quic::test::mock_stream_pair(VarInt::from_u32(4)); @@ -477,11 +477,201 @@ impl ConnectionState { #[cfg(test)] mod tests { use std::{ + any::Any, + borrow::Cow, collections::hash_map::DefaultHasher, + fmt, hash::{Hash, Hasher}, + io, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, }; + use bytes::Bytes; + use dhttp_identity::identity::SignError; + use futures::{SinkExt, StreamExt, channel::oneshot}; + use rustls::pki_types::CertificateDer; + use super::*; + use crate::{ + codec::{PeekableStreamReader, StreamReader}, + dhttp::protocol::DHttpProtocolFactory, + quic::{BoxQuicStreamReader, BoxQuicStreamWriter}, + }; + + #[derive(Debug, Clone, Copy)] + struct TestLocalAuthority; + + impl dhttp_identity::identity::LocalAuthority for TestLocalAuthority { + fn name(&self) -> &str { + "test.local" + } + + fn cert_chain(&self) -> &[CertificateDer<'static>] { + &[] + } + fn sign(&self, _data: &[u8]) -> futures::future::BoxFuture<'_, Result, SignError>> { + Box::pin(std::future::ready(Err(SignError::UnsupportedKey))) + } + } + + #[derive(Debug, Clone, Copy)] + struct TestRemoteAuthority; + + impl dhttp_identity::identity::RemoteAuthority for TestRemoteAuthority { + fn name(&self) -> &str { + "test.remote" + } + + fn cert_chain(&self) -> &[CertificateDer<'static>] { + &[] + } + } + + #[derive(Default, Clone)] + struct MockConnection { + stream_calls: Arc>>, + open_uni_readers: Arc>>>, + open_uni_available: bool, + closed_pending: bool, + } + + impl fmt::Debug for MockConnection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MockConnection") + .field("stream_calls", &self.stream_calls) + .field( + "open_uni_readers", + &self + .open_uni_readers + .lock() + .expect("open uni readers lock poisoned") + .len(), + ) + .field("open_uni_available", &self.open_uni_available) + .field("closed_pending", &self.closed_pending) + .finish() + } + } + + impl MockConnection { + fn with_open_uni_available(open_uni_available: bool) -> Self { + Self { + stream_calls: Arc::default(), + open_uni_readers: Arc::default(), + open_uni_available, + closed_pending: false, + } + } + + fn with_open_uni_available_and_pending_close(open_uni_available: bool) -> Self { + Self { + stream_calls: Arc::default(), + open_uni_readers: Arc::default(), + open_uni_available, + closed_pending: true, + } + } + + fn stream_calls(&self) -> Vec<&'static str> { + self.stream_calls + .lock() + .expect("stream call log poisoned") + .clone() + } + + fn record_stream_call(&self, call: &'static str) { + self.stream_calls + .lock() + .expect("stream call log poisoned") + .push(call); + } + } + + fn test_connection_error(reason: &'static str) -> quic::ConnectionError { + quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(0x01), + frame_type: VarInt::from_u32(0x00), + reason: reason.into(), + }, + } + } + + impl quic::ManageStream for MockConnection { + type StreamReader = BoxQuicStreamReader; + type StreamWriter = BoxQuicStreamWriter; + + async fn open_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + self.record_stream_call("open_bi"); + Err(test_connection_error("open_bi unavailable")) + } + + async fn open_uni(&self) -> Result { + self.record_stream_call("open_uni"); + if !self.open_uni_available { + return Err(test_connection_error("open_uni unavailable")); + } + + let (reader, writer) = quic::test::mock_stream_pair(VarInt::from_u32(0)); + self.open_uni_readers + .lock() + .expect("open uni readers lock poisoned") + .push(Box::new(reader)); + Ok(Box::pin(writer) as BoxQuicStreamWriter) + } + + async fn accept_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + self.record_stream_call("accept_bi"); + Err(test_connection_error("accept_bi unavailable")) + } + + async fn accept_uni(&self) -> Result { + self.record_stream_call("accept_uni"); + Err(test_connection_error("accept_uni unavailable")) + } + } + + impl quic::WithLocalAuthority for MockConnection { + type LocalAuthority = TestLocalAuthority; + + async fn local_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(None) + } + } + + impl quic::WithRemoteAuthority for MockConnection { + type RemoteAuthority = TestRemoteAuthority; + + async fn remote_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(None) + } + } + + impl quic::Lifecycle for MockConnection { + fn close(&self, _code: Code, _reason: Cow<'static, str>) {} + + fn check(&self) -> Result<(), quic::ConnectionError> { + Ok(()) + } + + async fn closed(&self) -> quic::ConnectionError { + if self.closed_pending { + std::future::pending().await + } else { + test_connection_error("connection closed") + } + } + } fn hash_of(t: &T) -> u64 { let mut h = DefaultHasher::new(); @@ -489,6 +679,103 @@ mod tests { h.finish() } + fn instruction_sink() -> BoxInstructionSink<'static, Instruction> { + Box::pin(futures::sink::unfold((), |(), _instruction| async { + Ok::<(), StreamError>(()) + })) + } + + fn instruction_stream() + -> BoxInstructionStream<'static, Instruction> { + Box::pin(futures::stream::pending()) + } + + fn qpack_protocol() -> QPackProtocol { + let (encoder_inst_receiver_tx, _encoder_inst_receiver_rx) = + oneshot::channel::(); + let (decoder_inst_receiver_tx, _decoder_inst_receiver_rx) = + oneshot::channel::(); + + QPackProtocol { + encoder_inst_receiver_tx: Mutex::new(Some(encoder_inst_receiver_tx)), + encoder: Arc::new(Encoder::new( + Arc::::default(), + instruction_sink::(), + instruction_stream::(), + )), + decoder_inst_receiver_tx: Mutex::new(Some(decoder_inst_receiver_tx)), + decoder: Arc::new(Decoder::new( + Arc::::default(), + instruction_sink::(), + instruction_stream::(), + )), + } + } + + async fn peekable_uni_from_bytes(bytes: &'static [u8]) -> BoxPeekableStreamReader { + let (reader, mut writer) = quic::test::mock_stream_pair(VarInt::from_u32(0)); + if !bytes.is_empty() { + writer + .send(Bytes::from_static(bytes)) + .await + .expect("send bytes"); + } + writer.close().await.expect("close writer"); + PeekableStreamReader::new(StreamReader::new(Box::pin(reader) as BoxQuicStreamReader)) + } + + fn peekable_bi_stream() -> (BoxPeekableStreamReader, BoxStreamWriter) { + let (reader, writer) = quic::test::mock_stream_pair(VarInt::from_u32(4)); + ( + PeekableStreamReader::new(StreamReader::new(Box::pin(reader) as BoxQuicStreamReader)), + SinkWriter::new(Box::pin(writer) as BoxQuicStreamWriter), + ) + } + + fn assert_duplicate_stream_creation(error: StreamError, expected_message: &str) { + let StreamError::Connection { source } = error else { + panic!("expected connection-scoped stream error"); + }; + let crate::connection::ConnectionError::H3 { source } = source else { + panic!("expected h3 connection error"); + }; + assert_eq!(source.code(), Code::H3_STREAM_CREATION_ERROR); + assert_eq!(source.to_string(), expected_message); + } + + fn assert_critical_stream_closed(error: StreamError, expected_message: &str) { + let StreamError::Connection { + source: crate::connection::ConnectionError::H3 { source }, + } = error + else { + panic!("expected h3 connection error"); + }; + assert_eq!(source.code(), Code::H3_CLOSED_CRITICAL_STREAM); + assert_eq!(source.to_string(), expected_message); + } + + struct ResetWrite; + + impl tokio::io::AsyncWrite for ResetWrite { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Ready(Err(io::Error::from(quic::StreamError::Reset { + code: VarInt::from_u32(7), + }))) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + } + #[test] fn qpack_factory_all_equal() { assert_eq!(QPackProtocolFactory::new(), QPackProtocolFactory::new()); @@ -508,6 +795,18 @@ mod tests { let b = QPackProtocolFactory::new(); assert_eq!(a, b); } + + #[test] + fn qpack_factory_display_and_debug_are_stable() { + let factory = QPackProtocolFactory::new(); + assert_eq!(factory.to_string(), "QPACK"); + assert_eq!(format!("{factory:?}"), "QPackProtocolFactory"); + assert_eq!( + format!("{:?}", qpack_protocol()), + "QPackLayer { encoder: \"...\", decoder: \"...\" }" + ); + } + #[test] fn qpack_decoder_stream_type_is_0x03() { assert_eq!( @@ -531,4 +830,446 @@ mod tests { UnidirectionalStream::<()>::QPACK_DECODER_STREAM_TYPE.into_inner(), ); } + + #[tokio::test] + async fn qpack_stream_type_helpers_identify_encoder_and_decoder_streams() { + let encoder_stream = UnidirectionalStream::initial_qpack_encoder_stream(tokio::io::sink()) + .await + .expect("encoder stream"); + assert!(encoder_stream.is_qpack_encoder_stream()); + assert!(!encoder_stream.is_qpack_decoder_stream()); + + let decoder_stream = UnidirectionalStream::initial_qpack_decoder_stream(tokio::io::sink()) + .await + .expect("decoder stream"); + assert!(decoder_stream.is_qpack_decoder_stream()); + assert!(!decoder_stream.is_qpack_encoder_stream()); + } + + #[tokio::test] + async fn qpack_initial_stream_helpers_write_expected_stream_types() { + let (mut encoder_reader, encoder_writer) = + quic::test::mock_stream_pair(VarInt::from_u32(0)); + let encoder_stream = UnidirectionalStream::initial_qpack_encoder_stream(SinkWriter::new( + Box::pin(encoder_writer) as BoxQuicStreamWriter, + )) + .await + .expect("encoder stream init"); + let mut encoder_writer = encoder_stream.into_inner(); + encoder_writer.flush_inner().await.expect("encoder flush"); + assert_eq!( + encoder_reader + .next() + .await + .expect("encoder bytes") + .expect("encoder stream read"), + Bytes::from_static(&[0x02]), + ); + + let (mut decoder_reader, decoder_writer) = + quic::test::mock_stream_pair(VarInt::from_u32(0)); + let decoder_stream = UnidirectionalStream::initial_qpack_decoder_stream(SinkWriter::new( + Box::pin(decoder_writer) as BoxQuicStreamWriter, + )) + .await + .expect("decoder stream init"); + let mut decoder_writer = decoder_stream.into_inner(); + decoder_writer.flush_inner().await.expect("decoder flush"); + assert_eq!( + decoder_reader + .next() + .await + .expect("decoder bytes") + .expect("decoder stream read"), + Bytes::from_static(&[0x03]), + ); + } + + #[tokio::test] + async fn qpack_initial_stream_helpers_map_reset_to_critical_stream_closed() { + let error = UnidirectionalStream::initial_qpack_encoder_stream(ResetWrite) + .await + .err() + .expect("encoder stream reset should be critical"); + assert_critical_stream_closed(error, "qpack encoder stream closed unexpectedly"); + + let error = UnidirectionalStream::initial_qpack_decoder_stream(ResetWrite) + .await + .err() + .expect("decoder stream reset should be critical"); + assert_critical_stream_closed(error, "qpack decoder stream closed unexpectedly"); + } + + #[tokio::test] + async fn qpack_test_helpers_cover_mock_connection_error_paths() { + use dhttp_identity::identity::{LocalAuthority as _, RemoteAuthority as _}; + + let local = TestLocalAuthority; + assert_eq!(local.name(), "test.local"); + assert!(local.cert_chain().is_empty()); + let sign_error = local + .sign(b"payload") + .await + .expect_err("unsupported key should fail"); + assert!(matches!(sign_error, SignError::UnsupportedKey)); + + let remote = TestRemoteAuthority; + assert_eq!(remote.name(), "test.remote"); + assert!(remote.cert_chain().is_empty()); + + let conn = MockConnection::with_open_uni_available(false); + quic::Lifecycle::close(&conn, Code::H3_NO_ERROR, Cow::Borrowed("ignored")); + assert!(quic::Lifecycle::check(&conn).is_ok()); + let closed = quic::Lifecycle::closed(&conn).await; + match closed { + quic::ConnectionError::Transport { source } => { + assert_eq!(source.reason, "connection closed"); + } + error => panic!("unexpected close error: {error:?}"), + } + + assert!( + quic::WithLocalAuthority::local_authority(&conn) + .await + .expect("local authority query") + .is_none() + ); + assert!( + quic::WithRemoteAuthority::remote_authority(&conn) + .await + .expect("remote authority query") + .is_none() + ); + + assert!(quic::ManageStream::open_bi(&conn).await.is_err()); + assert!(quic::ManageStream::accept_bi(&conn).await.is_err()); + assert!(quic::ManageStream::accept_uni(&conn).await.is_err()); + assert!(quic::ManageStream::open_uni(&conn).await.is_err()); + assert_eq!( + conn.stream_calls(), + vec!["open_bi", "accept_bi", "accept_uni", "open_uni"] + ); + + let mut sink = instruction_sink::(); + sink.send(EncoderInstruction::SetDynamicTableCapacity { capacity: 1 }) + .await + .expect("instruction sink accepts values"); + } + + #[tokio::test] + async fn qpack_protocol_trait_access_and_connection_state_accessor_work() { + let protocol = qpack_protocol(); + let protocol_ref: &dyn Protocol = &protocol; + + assert!(matches!( + protocol_ref + .accept_uni(peekable_uni_from_bytes(&[0x00]).await) + .await + .expect("trait object uni verdict"), + StreamVerdict::Passed(_) + )); + assert!(matches!( + protocol_ref + .accept_bi(peekable_bi_stream()) + .await + .expect("trait object bi verdict"), + StreamVerdict::Passed(_) + )); + + let conn = Arc::new(MockConnection::with_open_uni_available(true)); + let disabled_state = ConnectionState::new_for_test(conn.clone(), Arc::default()); + assert!(matches!(disabled_state.qpack(), Err(QPackProtocolDisabled))); + + let mut protocols = Protocols::new(); + protocols.insert(qpack_protocol()); + let protocols = Arc::new(protocols); + let enabled_state = ConnectionState::new_for_test(conn, protocols.clone()); + assert!(std::ptr::eq( + enabled_state.qpack().expect("qpack protocol should exist"), + protocols + .get::() + .expect("protocol registry should contain qpack"), + )); + } + + #[tokio::test] + async fn qpack_protocol_accepts_encoder_and_decoder_unidirectional_streams() { + let protocol = qpack_protocol(); + + let encoder = peekable_uni_from_bytes(&[0x02]).await; + assert!(matches!( + protocol.accept_uni(encoder).await.expect("encoder verdict"), + StreamVerdict::Accepted + )); + + let decoder = peekable_uni_from_bytes(&[0x03]).await; + assert!(matches!( + protocol.accept_uni(decoder).await.expect("decoder verdict"), + StreamVerdict::Accepted + )); + } + + #[tokio::test] + async fn qpack_protocol_passes_unknown_or_incomplete_unidirectional_streams() { + let protocol = qpack_protocol(); + + let unknown = peekable_uni_from_bytes(&[0x00, b'h']).await; + assert!(matches!( + protocol.accept_uni(unknown).await.expect("unknown verdict"), + StreamVerdict::Passed(_) + )); + + let incomplete = peekable_uni_from_bytes(&[]).await; + assert!(matches!( + protocol + .accept_uni(incomplete) + .await + .expect("incomplete verdict"), + StreamVerdict::Passed(_) + )); + } + + #[tokio::test] + async fn qpack_protocol_rejects_duplicate_instruction_streams_with_specific_errors() { + let protocol = qpack_protocol(); + + let first_encoder = peekable_uni_from_bytes(&[0x02]).await; + assert!(matches!( + protocol + .accept_uni(first_encoder) + .await + .expect("first encoder verdict"), + StreamVerdict::Accepted + )); + let second_encoder = peekable_uni_from_bytes(&[0x02]).await; + let error = match protocol.accept_uni(second_encoder).await { + Ok(_) => panic!("duplicate encoder should fail"), + Err(error) => error, + }; + assert_duplicate_stream_creation(error, "qpack encoder stream already exists"); + + let first_decoder = peekable_uni_from_bytes(&[0x03]).await; + assert!(matches!( + protocol + .accept_uni(first_decoder) + .await + .expect("first decoder verdict"), + StreamVerdict::Accepted + )); + let second_decoder = peekable_uni_from_bytes(&[0x03]).await; + let error = match protocol.accept_uni(second_decoder).await { + Ok(_) => panic!("duplicate decoder should fail"), + Err(error) => error, + }; + assert_duplicate_stream_creation(error, "qpack decoder stream already exists"); + } + + #[tokio::test] + async fn qpack_protocol_passes_bidirectional_streams_and_formats_disabled_error() { + let protocol = qpack_protocol(); + assert!(matches!( + protocol + .accept_bi(peekable_bi_stream()) + .await + .expect("bidi verdict"), + StreamVerdict::Passed(_) + )); + + assert_eq!( + QPackProtocolDisabled.to_string(), + "qpack protocol is disabled" + ); + } + + #[tokio::test] + async fn qpack_protocol_factory_init_requires_dhttp_protocol() { + let conn = Arc::new(MockConnection::with_open_uni_available(true)); + let factory = QPackProtocolFactory::new(); + + let error = factory + .init(&conn, &Protocols::new()) + .await + .expect_err("missing dhttp should fail init"); + + let quic::ConnectionError::Application { source } = error else { + panic!("expected application error"); + }; + assert_eq!(source.code, Code::H3_INTERNAL_ERROR); + assert_eq!( + source.reason, + "DHttpLayer must be initialized before QPackLayer" + ); + assert!(conn.stream_calls().is_empty()); + } + + #[tokio::test] + async fn qpack_protocol_factory_product_trait_init_matches_inherent_init() { + let conn = Arc::new(MockConnection::with_open_uni_available(true)); + let dhttp = DHttpProtocolFactory::default() + .init(&conn) + .await + .expect("dhttp init"); + let mut protocols = Protocols::new(); + protocols.insert(dhttp); + + let protocol = ProductProtocol::init(&QPackProtocolFactory::new(), &conn, &protocols) + .await + .expect("trait init"); + + tokio::task::yield_now().await; + + assert_eq!(conn.stream_calls(), vec!["open_uni"]); + assert_eq!( + format!("{protocol:?}"), + "QPackLayer { encoder: \"...\", decoder: \"...\" }" + ); + } + + #[tokio::test] + async fn qpack_protocol_factory_routes_peer_encoder_stream_to_decoder_instruction_receiver() { + let conn = Arc::new(MockConnection::with_open_uni_available_and_pending_close( + true, + )); + let mut settings = Settings::default(); + settings.set(crate::qpack::settings::QpackMaxTableCapacity::setting( + VarInt::from_u32(64), + )); + let mut dhttp = DHttpProtocolFactory::default() + .init(&conn) + .await + .expect("dhttp init"); + Arc::get_mut(&mut dhttp.state) + .expect("test dhttp state should be uniquely owned") + .local_settings = Arc::new(settings); + let mut protocols = Protocols::new(); + protocols.insert(dhttp); + + let protocol = QPackProtocolFactory::new() + .init(&conn, &protocols) + .await + .expect("qpack init"); + + let encoder_stream = peekable_uni_from_bytes(&[ + 0x02, // QPACK encoder stream type + 0x3f, 0x21, // SetDynamicTableCapacity { capacity: 64 } + 0x41, b'x', // literal name "x" + 0x01, b'y', // literal value "y" + ]) + .await; + assert!(matches!( + protocol + .accept_uni(encoder_stream) + .await + .expect("encoder stream verdict"), + StreamVerdict::Accepted + )); + + protocol + .decoder + .receive_instruction_until(1) + .await + .expect("peer encoder instruction stream should reach qpack decoder"); + + let state = protocol.decoder.state.lock().expect("decoder state lock"); + assert_eq!(state.dynamic_table.capacity, 64); + assert_eq!(state.dynamic_table.inserted_count, 1); + assert_eq!( + state.dynamic_table.get(0).expect("inserted entry").name, + Bytes::from_static(b"x") + ); + assert_eq!( + state.dynamic_table.get(0).expect("inserted entry").value, + Bytes::from_static(b"y") + ); + } + + #[tokio::test] + async fn qpack_protocol_factory_init_is_lazy_until_qpack_streams_are_used() { + let conn = Arc::new(MockConnection::with_open_uni_available(true)); + let dhttp = DHttpProtocolFactory::default() + .init(&conn) + .await + .expect("dhttp init"); + let mut protocols = Protocols::new(); + protocols.insert(dhttp); + + let protocol = QPackProtocolFactory::new() + .init(&conn, &protocols) + .await + .expect("qpack init"); + + assert_eq!(conn.stream_calls(), vec!["open_uni"]); + + { + let mut state = protocol.encoder.state.lock().await; + state.emit(EncoderInstruction::SetDynamicTableCapacity { capacity: 1 }); + } + protocol + .encoder + .flush_instructions() + .await + .expect("encoder flush opens stream"); + + { + let mut state = protocol.decoder.state.lock().expect("decoder state lock"); + state + .pending_instructions + .push_back(DecoderInstruction::SectionAcknowledgment { stream_id: 7 }); + } + protocol + .decoder + .flush_instructions() + .await + .expect("decoder flush opens stream"); + + assert_eq!( + conn.stream_calls(), + vec!["open_uni", "open_uni", "open_uni"] + ); + } + + #[tokio::test] + async fn qpack_protocol_factory_init_propagates_qpack_stream_open_errors_when_used() { + let conn = Arc::new(MockConnection::with_open_uni_available(false)); + let dhttp = { + let openable = Arc::new(MockConnection::with_open_uni_available(true)); + DHttpProtocolFactory::default() + .init(&openable) + .await + .expect("dhttp init") + }; + let mut protocols = Protocols::new(); + protocols.insert(dhttp); + + let protocol = QPackProtocolFactory::new() + .init(&conn, &protocols) + .await + .expect("qpack init"); + + { + let mut state = protocol.encoder.state.lock().await; + state.emit(EncoderInstruction::SetDynamicTableCapacity { capacity: 1 }); + } + let encoder_error = protocol + .encoder + .flush_instructions() + .await + .expect_err("encoder flush should surface open_uni failure"); + assert!(matches!(encoder_error, StreamError::Connection { .. })); + + { + let mut state = protocol.decoder.state.lock().expect("decoder state lock"); + state + .pending_instructions + .push_back(DecoderInstruction::SectionAcknowledgment { stream_id: 9 }); + } + let decoder_error = protocol + .decoder + .flush_instructions() + .await + .expect_err("decoder flush should surface open_uni failure"); + assert!(matches!(decoder_error, StreamError::Connection { .. })); + + assert_eq!(conn.stream_calls(), vec!["open_uni", "open_uni"]); + } } diff --git a/src/qpack/settings.rs b/src/qpack/settings.rs new file mode 100644 index 0000000..ef5821e --- /dev/null +++ b/src/qpack/settings.rs @@ -0,0 +1,135 @@ +use crate::{ + dhttp::settings::{Setting, SettingId, Settings}, + varint::VarInt, +}; + +/// `SETTINGS_QPACK_MAX_TABLE_CAPACITY` (0x01). Default: 0. +pub struct QpackMaxTableCapacity; + +impl QpackMaxTableCapacity { + pub const ID: VarInt = VarInt::from_u32(0x01); + pub const DEFAULT: VarInt = VarInt::from_u32(0); + + pub const fn setting(value: VarInt) -> Setting { + Setting::new(Self::ID, value) + } +} + +impl SettingId for QpackMaxTableCapacity { + type Value = VarInt; + + fn id(&self) -> VarInt { + Self::ID + } + + fn value_from(&self, settings: &Settings) -> VarInt { + settings.get_raw(Self::ID).unwrap_or(Self::DEFAULT) + } +} + +/// `SETTINGS_QPACK_BLOCKED_STREAMS` (0x07). Default: 0. +pub struct QpackBlockedStreams; + +impl QpackBlockedStreams { + pub const ID: VarInt = VarInt::from_u32(0x07); + pub const DEFAULT: VarInt = VarInt::from_u32(0); + + pub const fn setting(value: VarInt) -> Setting { + Setting::new(Self::ID, value) + } +} + +impl SettingId for QpackBlockedStreams { + type Value = VarInt; + + fn id(&self) -> VarInt { + Self::ID + } + + fn value_from(&self, settings: &Settings) -> VarInt { + settings.get_raw(Self::ID).unwrap_or(Self::DEFAULT) + } +} + +impl Settings { + pub fn qpack_max_table_capacity(&self) -> VarInt { + self.get(QpackMaxTableCapacity) + } + + pub fn qpack_blocked_streams(&self) -> VarInt { + self.get(QpackBlockedStreams) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::varint::VarInt; + + #[test] + fn qpack_settings_apply_defaults_and_overrides() { + let mut settings = Settings::default(); + assert_eq!(settings.qpack_max_table_capacity(), VarInt::from_u32(0)); + assert_eq!(settings.qpack_blocked_streams(), VarInt::from_u32(0)); + + settings.set(QpackMaxTableCapacity::setting(VarInt::from_u32(4096))); + settings.set(QpackBlockedStreams::setting(VarInt::from_u32(100))); + + assert_eq!(settings.qpack_max_table_capacity(), VarInt::from_u32(4096)); + assert_eq!(settings.qpack_blocked_streams(), VarInt::from_u32(100)); + } + + #[test] + fn qpack_settings_boundaries_and_ids() { + assert_eq!(QpackMaxTableCapacity::ID, VarInt::from_u32(0x01)); + assert_eq!(QpackBlockedStreams::ID, VarInt::from_u32(0x07)); + assert_eq!(QpackMaxTableCapacity.id(), QpackMaxTableCapacity::ID); + assert_eq!(QpackBlockedStreams.id(), QpackBlockedStreams::ID); + assert_eq!(QpackMaxTableCapacity::DEFAULT, VarInt::from_u32(0)); + assert_eq!(QpackBlockedStreams::DEFAULT, VarInt::from_u32(0)); + assert_eq!( + QpackMaxTableCapacity::setting(VarInt::MAX).id, + QpackMaxTableCapacity::ID + ); + assert_eq!( + QpackBlockedStreams::setting(VarInt::from_u32(0)).value, + VarInt::from_u32(0), + ); + } + + #[test] + fn qpack_settings_raw_lookup_and_default_fallback() { + let mut settings = Settings::default(); + + assert_eq!(settings.get_raw(QpackMaxTableCapacity::ID), None); + assert_eq!( + settings.get(QpackMaxTableCapacity), + QpackMaxTableCapacity::DEFAULT + ); + + settings.set(QpackMaxTableCapacity::setting(VarInt::from_u32(99))); + assert_eq!( + settings.get_raw(QpackMaxTableCapacity::ID), + Some(VarInt::from_u32(99)) + ); + assert_eq!(settings.qpack_max_table_capacity(), VarInt::from_u32(99)); + assert_eq!(settings.get(QpackMaxTableCapacity), VarInt::from_u32(99)); + + settings.set(QpackMaxTableCapacity::setting(VarInt::from_u32(0))); + assert_eq!( + settings.get_raw(QpackMaxTableCapacity::ID), + Some(VarInt::from_u32(0)) + ); + assert_eq!(settings.qpack_max_table_capacity(), VarInt::from_u32(0)); + } + + #[test] + fn qpack_setting_constructors_stay_well_formed_for_unknown_id_queries() { + let setting = QpackBlockedStreams::setting(VarInt::MAX); + let id: VarInt = QpackBlockedStreams::ID; + + assert_eq!(setting.id, id); + assert_eq!(setting.value, VarInt::MAX); + assert_eq!(Setting::from((id, VarInt::from_u32(0x7f))).id, id); + } +} diff --git a/src/qpack/static.rs b/src/qpack/static.rs index f3593ac..d7b8418 100644 --- a/src/qpack/static.rs +++ b/src/qpack/static.rs @@ -309,4 +309,15 @@ mod tests { assert_eq!(get(25), Some((":status", "200"))); assert_eq!(get(98), Some(("x-frame-options", "sameorigin"))); } + + #[test] + fn test_static_table_name_and_value_accessors() { + assert_eq!(get_name(0), Some(":authority")); + assert_eq!(get_name(98), Some("x-frame-options")); + assert_eq!(get_name(99), None); + + assert_eq!(get_value(0), Some("")); + assert_eq!(get_value(98), Some("sameorigin")); + assert_eq!(get_value(99), None); + } } diff --git a/src/quic.rs b/src/quic.rs index 2671df3..41e7a4f 100644 --- a/src/quic.rs +++ b/src/quic.rs @@ -12,13 +12,12 @@ use std::{ }; use bytes::Bytes; +use dhttp_identity::identity::{LocalAuthority, RemoteAuthority}; use futures::{Sink, Stream, future::BoxFuture}; use http::uri::Authority; use snafu::Snafu; -pub mod agent; - -use crate::{error::Code, varint::VarInt}; +use crate::{error::Code, stream, varint::VarInt}; #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[derive(Debug, Snafu, Clone)] @@ -144,10 +143,285 @@ pub trait Listen: Send + Sync { fn shutdown(&self) -> impl Future> + Send + '_; } +impl Connect for &T { + type Connection = T::Connection; + type Error = T::Error; + + fn connect<'a>( + &'a self, + server: &'a Authority, + ) -> impl Future, Self::Error>> + Send + 'a { + (**self).connect(server) + } +} + +impl Connect for Arc { + type Connection = T::Connection; + type Error = T::Error; + + fn connect<'a>( + &'a self, + server: &'a Authority, + ) -> impl Future, Self::Error>> + Send + 'a { + self.as_ref().connect(server) + } +} + +/// Read-only observation of a QUIC stream id. +pub trait GetStreamId { + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>; +} + +impl

GetStreamId for Pin

+where + P: DerefMut, + P::Target: GetStreamId, +{ + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + ::poll_stream_id(self.as_deref_mut(), cx) + } +} + +impl GetStreamId for &mut S +where + S: GetStreamId + Unpin + ?Sized, +{ + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + S::poll_stream_id(Pin::new(self.get_mut()), cx) + } +} + +pin_project_lite::pin_project! { + pub struct StreamId { + #[pin] + stream: S, + } +} + +impl Future for StreamId +where + S: GetStreamId + ?Sized, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().stream.poll_stream_id(cx) + } +} + +pub trait GetStreamIdExt: GetStreamId { + fn stream_id(&mut self) -> StreamId<&mut Self> { + StreamId { stream: self } + } +} + +impl GetStreamIdExt for T where T: GetStreamId + ?Sized {} + +/// QUIC receive-side STOP_SENDING control for a stream. +pub trait StopStream { + fn poll_stop( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll>; +} + +impl

StopStream for Pin

+where + P: DerefMut, + P::Target: StopStream, +{ + fn poll_stop( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + ::poll_stop(self.as_deref_mut(), cx, code) + } +} + +impl StopStream for &mut S +where + S: StopStream + Unpin + ?Sized, +{ + fn poll_stop( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + S::poll_stop(Pin::new(self.get_mut()), cx, code) + } +} + +pin_project_lite::pin_project! { + pub struct Stop { + code: VarInt, + #[pin] + stream: S, + } +} + +impl Future for Stop +where + S: StopStream + ?Sized, +{ + type Output = Result<(), StreamError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let project = self.project(); + project.stream.poll_stop(cx, *project.code) + } +} + +pub trait StopStreamExt: StopStream { + fn stop(&mut self, code: VarInt) -> Stop<&mut Self> { + Stop { code, stream: self } + } +} + +impl StopStreamExt for T where T: StopStream + ?Sized {} + +/// QUIC send-side RESET_STREAM control for a stream. +pub trait ResetStream { + fn poll_reset( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll>; +} + +impl

ResetStream for Pin

+where + P: DerefMut, + P::Target: ResetStream, +{ + fn poll_reset( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + ::poll_reset(self.as_deref_mut(), cx, code) + } +} + +impl ResetStream for &mut S +where + S: ResetStream + Unpin + ?Sized, +{ + fn poll_reset( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + S::poll_reset(Pin::new(self.get_mut()), cx, code) + } +} + +pin_project_lite::pin_project! { + pub struct Reset { + code: VarInt, + #[pin] + stream: S, + } +} + +impl Future for Reset +where + S: ResetStream + ?Sized, +{ + type Output = Result<(), StreamError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let project = self.project(); + project.stream.poll_reset(cx, *project.code) + } +} + +pub trait ResetStreamExt: ResetStream { + fn reset(&mut self, code: VarInt) -> Reset<&mut Self> { + Reset { code, stream: self } + } +} + +impl ResetStreamExt for T where T: ResetStream + ?Sized {} + +/// Byte stream plus QUIC receive-side control. +pub trait ReadStream: + StopStream + GetStreamId + Stream> + Send + Any +{ +} + +impl ReadStream for S where + S: StopStream + GetStreamId + Stream> + Send + ?Sized + Any +{ +} + +/// Byte sink plus QUIC send-side reset control. +pub trait WriteStream: + ResetStream + GetStreamId + Sink + Send + Any +{ +} + +impl WriteStream for S where + S: ResetStream + GetStreamId + Sink + Send + ?Sized + Any +{ +} + +pub type BoxQuicStreamReader = Pin>; + +pub type BoxQuicStreamWriter = Pin>; + +impl stream::GetStreamId for dyn ReadStream { + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + GetStreamId::poll_stream_id(self, cx) + } +} + +impl stream::StopStream for dyn ReadStream { + fn poll_stop( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + StopStream::poll_stop(self, cx, code) + } +} + +impl stream::GetStreamId for dyn WriteStream { + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + GetStreamId::poll_stream_id(self, cx) + } +} + +impl stream::ResetStream for dyn WriteStream { + fn poll_reset( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + ResetStream::poll_reset(self, cx, code) + } +} + /// AFIT version of stream management with concrete associated types. /// -/// Implement this for concrete connection types. A blanket impl provides -/// [`DynManageStream`] automatically. +/// Implement this for concrete QUIC connection types. A blanket impl provides +/// [`stream::ManageStream`] and [`DynManageStream`] automatically. pub trait ManageStream: Send + Sync { type StreamReader: ReadStream + Unpin; type StreamWriter: WriteStream + Unpin; @@ -173,108 +447,147 @@ pub trait ManageStream: Send + Sync { ) -> impl Future> + Send + '_; } -/// Type-erased read stream: `Pin>`. -pub type BoxReadStream = Pin>; +impl stream::ManageStream for T +where + T: ManageStream + ?Sized, +{ + type Data = Bytes; + + type ReadError = StreamError; + type WriteError = StreamError; + type StopError = StreamError; + type ResetError = StreamError; + type StreamIdError = StreamError; + + type OpenBiError = ConnectionError; + type OpenUniError = ConnectionError; + type AcceptBiError = ConnectionError; + type AcceptUniError = ConnectionError; + + type StreamReader = BoxQuicStreamReader; + type StreamWriter = BoxQuicStreamWriter; + + async fn open_bi(&self) -> Result<(Self::StreamReader, Self::StreamWriter), Self::OpenBiError> { + let (reader, writer) = ManageStream::open_bi(self).await?; + Ok(( + Box::pin(reader) as BoxQuicStreamReader, + Box::pin(writer) as BoxQuicStreamWriter, + )) + } + + async fn open_uni(&self) -> Result { + let writer = ManageStream::open_uni(self).await?; + Ok(Box::pin(writer) as BoxQuicStreamWriter) + } -/// Type-erased write stream: `Pin>`. -pub type BoxWriteStream = Pin>; + async fn accept_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), Self::AcceptBiError> { + let (reader, writer) = ManageStream::accept_bi(self).await?; + Ok(( + Box::pin(reader) as BoxQuicStreamReader, + Box::pin(writer) as BoxQuicStreamWriter, + )) + } + + async fn accept_uni(&self) -> Result { + let reader = ManageStream::accept_uni(self).await?; + Ok(Box::pin(reader) as BoxQuicStreamReader) + } +} /// Object-safe version of [`ManageStream`] with type-erased streams. /// -/// Stream types are fixed to [`BoxReadStream`] / [`BoxWriteStream`]. +/// Stream types are fixed to [`BoxQuicStreamReader`] / [`BoxQuicStreamWriter`]. /// A blanket impl is provided for all `T: ManageStream`. pub trait DynManageStream: Send + Sync { #[allow(clippy::type_complexity)] - fn open_bi(&self) -> BoxFuture<'_, Result<(BoxReadStream, BoxWriteStream), ConnectionError>>; + fn open_bi( + &self, + ) -> BoxFuture<'_, Result<(BoxQuicStreamReader, BoxQuicStreamWriter), ConnectionError>>; - fn open_uni(&self) -> BoxFuture<'_, Result>; + fn open_uni(&self) -> BoxFuture<'_, Result>; #[allow(clippy::type_complexity)] - fn accept_bi(&self) -> BoxFuture<'_, Result<(BoxReadStream, BoxWriteStream), ConnectionError>>; + fn accept_bi( + &self, + ) -> BoxFuture<'_, Result<(BoxQuicStreamReader, BoxQuicStreamWriter), ConnectionError>>; - fn accept_uni(&self) -> BoxFuture<'_, Result>; + fn accept_uni(&self) -> BoxFuture<'_, Result>; } impl DynManageStream for T { - fn open_bi(&self) -> BoxFuture<'_, Result<(BoxReadStream, BoxWriteStream), ConnectionError>> { - Box::pin(async { - let (r, w) = ManageStream::open_bi(self).await?; - Ok((Box::pin(r) as BoxReadStream, Box::pin(w) as BoxWriteStream)) - }) + fn open_bi( + &self, + ) -> BoxFuture<'_, Result<(BoxQuicStreamReader, BoxQuicStreamWriter), ConnectionError>> { + Box::pin(async { stream::ManageStream::open_bi(self).await }) } - fn open_uni(&self) -> BoxFuture<'_, Result> { - Box::pin(async { - let w = ManageStream::open_uni(self).await?; - Ok(Box::pin(w) as BoxWriteStream) - }) + fn open_uni(&self) -> BoxFuture<'_, Result> { + Box::pin(async { stream::ManageStream::open_uni(self).await }) } - fn accept_bi(&self) -> BoxFuture<'_, Result<(BoxReadStream, BoxWriteStream), ConnectionError>> { - Box::pin(async { - let (r, w) = ManageStream::accept_bi(self).await?; - Ok((Box::pin(r) as BoxReadStream, Box::pin(w) as BoxWriteStream)) - }) + fn accept_bi( + &self, + ) -> BoxFuture<'_, Result<(BoxQuicStreamReader, BoxQuicStreamWriter), ConnectionError>> { + Box::pin(async { stream::ManageStream::accept_bi(self).await }) } - fn accept_uni(&self) -> BoxFuture<'_, Result> { - Box::pin(async { - let r = ManageStream::accept_uni(self).await?; - Ok(Box::pin(r) as BoxReadStream) - }) + fn accept_uni(&self) -> BoxFuture<'_, Result> { + Box::pin(async { stream::ManageStream::accept_uni(self).await }) } } -/// AFIT version of local agent access. -pub trait WithLocalAgent: Send + Sync { - type LocalAgent: agent::LocalAgent + 'static; - fn local_agent( +/// AFIT version of local authority access. +pub trait WithLocalAuthority: Send + Sync { + type LocalAuthority: LocalAuthority + 'static; + fn local_authority( &self, - ) -> impl Future, ConnectionError>> + Send + '_; + ) -> impl Future, ConnectionError>> + Send + '_; } -/// Object-safe version of [`WithLocalAgent`] with type-erased agent. -pub trait DynWithLocalAgent: Send + Sync { - fn local_agent( +/// Object-safe version of [`WithLocalAuthority`] with type-erased agent. +pub trait DynWithLocalAuthority: Send + Sync { + fn local_authority( &self, - ) -> BoxFuture<'_, Result>, ConnectionError>>; + ) -> BoxFuture<'_, Result>, ConnectionError>>; } -impl DynWithLocalAgent for T { - fn local_agent( +impl DynWithLocalAuthority for T { + fn local_authority( &self, - ) -> BoxFuture<'_, Result>, ConnectionError>> { + ) -> BoxFuture<'_, Result>, ConnectionError>> { Box::pin(async { - WithLocalAgent::local_agent(self) + WithLocalAuthority::local_authority(self) .await - .map(|opt| opt.map(|a| Arc::new(a) as Arc)) + .map(|opt| opt.map(|a| Arc::new(a) as Arc)) }) } } -/// AFIT version of remote agent access. -pub trait WithRemoteAgent: Send + Sync { - type RemoteAgent: agent::RemoteAgent + 'static; - fn remote_agent( +/// AFIT version of remote authority access. +pub trait WithRemoteAuthority: Send + Sync { + type RemoteAuthority: RemoteAuthority + 'static; + fn remote_authority( &self, - ) -> impl Future, ConnectionError>> + Send + '_; + ) -> impl Future, ConnectionError>> + Send + '_; } -/// Object-safe version of [`WithRemoteAgent`] with type-erased agent. -pub trait DynWithRemoteAgent: Send + Sync { - fn remote_agent( +/// Object-safe version of [`WithRemoteAuthority`] with type-erased agent. +pub trait DynWithRemoteAuthority: Send + Sync { + fn remote_authority( &self, - ) -> BoxFuture<'_, Result>, ConnectionError>>; + ) -> BoxFuture<'_, Result>, ConnectionError>>; } -impl DynWithRemoteAgent for T { - fn remote_agent( +impl DynWithRemoteAuthority for T { + fn remote_authority( &self, - ) -> BoxFuture<'_, Result>, ConnectionError>> { + ) -> BoxFuture<'_, Result>, ConnectionError>> { Box::pin(async { - WithRemoteAgent::remote_agent(self) + WithRemoteAuthority::remote_authority(self) .await - .map(|opt| opt.map(|a| Arc::new(a) as Arc)) + .map(|opt| opt.map(|a| Arc::new(a) as Arc)) }) } } @@ -286,8 +599,8 @@ impl DynWithRemoteAgent for T { /// # Error latching contract /// /// A connection has a single terminal error. Once **any** operation on the same -/// connection — including [`ManageStream`], [`WithLocalAgent`], -/// [`WithRemoteAgent`], or [`Lifecycle`] methods — returns a +/// connection — including [`ManageStream`], [`WithLocalAuthority`], +/// [`WithRemoteAuthority`], or [`Lifecycle`] methods — returns a /// [`ConnectionError`] (directly, or wrapped in /// [`StreamError::Connection`]), the connection is considered dead and: /// @@ -350,212 +663,40 @@ impl DynLifecycle for T { /// Not object-safe due to associated types. Use [`DynConnection`] for /// type-erased usage. pub trait Connection: - ManageStream + WithLocalAgent + WithRemoteAgent + Lifecycle + Send + Sync + Any + ManageStream + WithLocalAuthority + WithRemoteAuthority + Lifecycle + Send + Sync + Any { } -impl Connection - for C +impl + Connection for C { } /// Object-safe composite trait for type-erased connections. pub trait DynConnection: - DynManageStream + DynWithLocalAgent + DynWithRemoteAgent + DynLifecycle + Send + Sync + Any -{ -} - -impl - DynConnection for C -{ -} - -pub trait GetStreamId { - fn poll_stream_id(self: Pin<&mut Self>, cx: &mut Context) -> Poll>; -} - -impl

GetStreamId for Pin

-where - P: DerefMut, -{ - fn poll_stream_id(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - ::poll_stream_id(self.as_deref_mut(), cx) - } -} - -impl GetStreamId for &mut S -where - S: GetStreamId + Unpin + ?Sized, -{ - fn poll_stream_id(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - S::poll_stream_id(Pin::new(self.get_mut()), cx) - } -} - -pin_project_lite::pin_project! { - pub struct StreamId{ - #[pin] - stream: S - } -} - -impl Future for StreamId { - type Output = Result; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let project = self.project(); - project.stream.poll_stream_id(cx) - } -} - -pub trait GetStreamIdExt { - fn stream_id(&mut self) -> StreamId<&mut Self> { - StreamId { stream: self } - } -} - -impl GetStreamIdExt for T {} - -pub trait StopStream { - fn poll_stop( - self: Pin<&mut Self>, - cx: &mut Context, - code: VarInt, - ) -> Poll>; -} - -impl

StopStream for Pin

-where - P: DerefMut, -{ - fn poll_stop( - self: Pin<&mut Self>, - cx: &mut Context, - code: VarInt, - ) -> Poll> { - ::poll_stop(self.as_deref_mut(), cx, code) - } -} - -impl StopStream for &mut S -where - S: StopStream + Unpin + ?Sized, -{ - fn poll_stop( - self: Pin<&mut Self>, - cx: &mut Context, - code: VarInt, - ) -> Poll> { - S::poll_stop(Pin::new(self.get_mut()), cx, code) - } -} - -pin_project_lite::pin_project! { - pub struct Stop{ - code: VarInt, - #[pin] - stream: S - } -} - -impl Future for Stop { - type Output = Result<(), StreamError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let project = self.project(); - project.stream.poll_stop(cx, *project.code) - } -} - -pub trait StopStreamExt { - fn stop(&mut self, code: VarInt) -> Stop<&mut Self> { - Stop { code, stream: self } - } -} - -impl StopStreamExt for T {} - -pub trait ReadStream: - StopStream + GetStreamId + Stream> + Send + Any -{ -} - -impl> + Send + ?Sized + Any> - ReadStream for S -{ -} - -pub trait CancelStream { - fn poll_cancel( - self: Pin<&mut Self>, - cx: &mut Context, - code: VarInt, - ) -> Poll>; -} - -impl CancelStream for &mut S -where - S: CancelStream + Unpin + ?Sized, -{ - fn poll_cancel( - self: Pin<&mut Self>, - cx: &mut Context, - code: VarInt, - ) -> Poll> { - S::poll_cancel(Pin::new(self.get_mut()), cx, code) - } -} - -impl

CancelStream for Pin

-where - P: DerefMut, -{ - fn poll_cancel( - self: Pin<&mut Self>, - cx: &mut Context, - code: VarInt, - ) -> Poll> { - ::poll_cancel(self.as_deref_mut(), cx, code) - } -} - -pin_project_lite::pin_project! { - pub struct Cancel{ - code: VarInt, - #[pin] - stream: S - } -} - -impl Future for Cancel { - type Output = Result<(), StreamError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let project = self.project(); - project.stream.poll_cancel(cx, *project.code) - } -} - -pub trait CancelStreamExt { - fn cancel(&mut self, code: VarInt) -> Cancel<&mut Self> { - Cancel { code, stream: self } - } -} - -impl CancelStreamExt for T {} - -pub trait WriteStream: - CancelStream + GetStreamId + Sink + Send + Any + DynManageStream + DynWithLocalAuthority + DynWithRemoteAuthority + DynLifecycle + Send + Sync + Any { } -impl + Send + ?Sized + Any> - WriteStream for S +impl< + C: DynManageStream + + DynWithLocalAuthority + + DynWithRemoteAuthority + + DynLifecycle + + Send + + Sync + + Any, +> DynConnection for C { } #[cfg(any(test, feature = "testing"))] pub mod test { + #[cfg(test)] + use std::{ + convert::Infallible, + sync::{Arc, Mutex}, + }; use std::{ pin::Pin, task::{Context, Poll, ready}, @@ -563,10 +704,14 @@ pub mod test { use bytes::Bytes; use futures::{Sink, SinkExt, Stream, StreamExt}; + #[cfg(test)] + use http::uri::Authority; use tokio::sync::oneshot; + #[cfg(test)] + use crate::{connection::tests::MockConnection, quic::Connect}; use crate::{ - quic::{CancelStream, GetStreamId, StopStream, StreamError}, + quic::{GetStreamId, ResetStream, StopStream, StreamError}, varint::VarInt, }; @@ -591,6 +736,47 @@ pub mod test { Reset(VarInt), } + #[cfg(test)] + #[derive(Default)] + struct RecordingConnector { + connection: Arc, + servers: Mutex>, + } + + #[cfg(test)] + impl RecordingConnector { + fn new(connection: MockConnection) -> Self { + Self { + connection: Arc::new(connection), + servers: Mutex::default(), + } + } + + fn servers(&self) -> Vec { + self.servers + .lock() + .expect("recorded server list poisoned") + .clone() + } + } + + #[cfg(test)] + impl Connect for RecordingConnector { + type Connection = MockConnection; + type Error = Infallible; + + async fn connect<'a>( + &'a self, + server: &'a Authority, + ) -> Result, Self::Error> { + self.servers + .lock() + .expect("recorded server list poisoned") + .push(server.to_string()); + Ok(self.connection.clone()) + } + } + impl GetStreamId for MockStreamWriter { fn poll_stream_id( self: Pin<&mut Self>, @@ -600,11 +786,11 @@ pub mod test { } } - impl + ?Sized> CancelStream for MockStreamWriter + impl + ?Sized> ResetStream for MockStreamWriter where StreamError: From, { - fn poll_cancel( + fn poll_reset( self: Pin<&mut Self>, cx: &mut Context, code: VarInt, @@ -778,4 +964,128 @@ pub mod test { }; tokio::join!(send, recv); } + + #[test] + fn stream_error_helpers_preserve_typed_io_sources() { + use crate::{ + error::Code, + quic::{ApplicationError, ConnectionError, TransportError}, + }; + + let reset_code = VarInt::from_u32(17); + let reset = StreamError::Reset { code: reset_code }; + assert!(reset.is_reset()); + let reset_io = std::io::Error::from(reset.clone()); + assert!(matches!( + StreamError::try_from(reset_io), + Ok(StreamError::Reset { code }) if code == reset_code + )); + + let connection = ConnectionError::Application { + source: ApplicationError { + code: Code::H3_NO_ERROR, + reason: "closed".into(), + }, + }; + let stream = StreamError::from(std::io::Error::from(connection.clone())); + assert!(!stream.is_reset()); + assert!(matches!( + stream, + StreamError::Connection { + source: ConnectionError::Application { .. }, + } + )); + assert!(connection.is_application()); + assert!(!connection.is_transport()); + + let transport = ConnectionError::Transport { + source: TransportError { + kind: Code::H3_INTERNAL_ERROR.into(), + frame_type: VarInt::from_u32(0x21), + reason: "transport failure".into(), + }, + }; + assert!(transport.is_transport()); + assert!(!transport.is_application()); + assert!(matches!( + StreamError::try_from(std::io::Error::from(transport)), + Ok(StreamError::Connection { + source: ConnectionError::Transport { .. }, + }) + )); + + let plain = std::io::Error::other("not a quic stream error"); + assert!(StreamError::try_from(plain).is_err()); + } + + #[tokio::test] + async fn mock_stream_pair_reports_stream_id_and_accepts_stop_calls() { + use crate::quic::{GetStreamIdExt, StopStreamExt}; + + let stream_id = VarInt::from_u32(91); + let (mut reader, mut writer) = mock_stream_pair(stream_id); + + assert_eq!( + reader.stream_id().await.expect("reader stream id"), + stream_id + ); + assert_eq!( + writer.stream_id().await.expect("writer stream id"), + stream_id + ); + + let stop_code = VarInt::from_u32(29); + reader.stop(stop_code).await.expect("stop stream"); + reader.stop(stop_code).await.expect("repeated stop stream"); + } + + #[tokio::test] + async fn mock_stream_reset_delivers_sticky_reset_to_reader() { + use crate::quic::ResetStreamExt; + + let reset_code = VarInt::from_u32(33); + let (mut reader, mut writer) = mock_stream_pair(VarInt::from_u32(7)); + + writer.reset(reset_code).await.expect("reset stream"); + + for _ in 0..2 { + assert!(matches!( + reader.next().await, + Some(Err(StreamError::Reset { code })) if code == reset_code + )); + } + } + + #[cfg(test)] + #[tokio::test] + async fn connect_blanket_impls_delegate_for_refs_and_arcs() { + async fn connect_with(connector: C, server: &Authority) -> Arc + where + C: Connect, + C::Error: std::fmt::Debug, + { + connector.connect(server).await.expect("connect succeeds") + } + + let connector = RecordingConnector::new(MockConnection::new()); + let expected = connector.connection.clone(); + + let first_server = "first.example:443" + .parse::() + .expect("authority parses"); + let first = connect_with(&connector, &first_server).await; + assert!(Arc::ptr_eq(&first, &expected)); + assert_eq!(connector.servers(), vec!["first.example:443"]); + + let connector = Arc::new(connector); + let second_server = "second.example:443" + .parse::() + .expect("authority parses"); + let second = connect_with(connector.clone(), &second_server).await; + assert!(Arc::ptr_eq(&second, &expected)); + assert_eq!( + connector.servers(), + vec!["first.example:443", "second.example:443"], + ); + } } diff --git a/src/quic/agent.rs b/src/quic/agent.rs deleted file mode 100644 index 1419ecd..0000000 --- a/src/quic/agent.rs +++ /dev/null @@ -1,134 +0,0 @@ -use std::fmt::Debug; - -use futures::future::BoxFuture; -use rustls::{ - SignatureScheme, - pki_types::{CertificateDer, SubjectPublicKeyInfoDer}, -}; -use snafu::Snafu; -use x509_parser::prelude::FromDer; - -use crate::quic; - -#[derive(Debug, Snafu)] -#[snafu(module)] -pub enum SignError { - #[snafu(display("unsupported signature scheme {scheme:?}"))] - UnsupportedScheme { scheme: SignatureScheme }, - #[snafu(transparent)] - Crypto { source: rustls::Error }, - #[snafu(transparent)] - Connection { source: quic::ConnectionError }, -} - -#[derive(Debug, Snafu)] -#[snafu(module)] -pub enum VerifyError { - #[snafu(display("unsupported signature scheme {scheme:?}"))] - UnsupportedScheme { scheme: SignatureScheme }, - #[snafu(transparent)] - Connection { source: quic::ConnectionError }, -} - -pub trait LocalAgent: Send + Sync + Debug { - fn name(&self) -> &str; - - fn cert_chain(&self) -> &[CertificateDer<'static>]; - - fn sign_algorithm(&self) -> rustls::SignatureAlgorithm; - - fn sign( - &self, - scheme: SignatureScheme, - data: &[u8], - ) -> BoxFuture<'_, Result, SignError>>; - - fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { - extract_public_key(self.cert_chain()) - } - - fn verify( - &self, - scheme: SignatureScheme, - data: &[u8], - signature: &[u8], - ) -> BoxFuture<'_, Result> { - let result = verify_signature(self.public_key(), scheme, data, signature); - Box::pin(std::future::ready(result)) - } -} - -pub trait RemoteAgent: Send + Sync + Debug { - fn name(&self) -> &str; - fn cert_chain(&self) -> &[CertificateDer<'static>]; - - fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { - extract_public_key(self.cert_chain()) - } - - fn verify( - &self, - scheme: SignatureScheme, - data: &[u8], - signature: &[u8], - ) -> BoxFuture<'_, Result> { - let result = verify_signature(self.public_key(), scheme, data, signature); - Box::pin(std::future::ready(result)) - } -} - -pub fn extract_public_key<'d>(cert_chain: &'d [CertificateDer<'d>]) -> SubjectPublicKeyInfoDer<'d> { - use x509_parser::prelude::*; - - match x509_parser::certificate::X509Certificate::from_der(&cert_chain[0]) { - Ok((_remain, certificate)) => { - let spki = certificate.public_key().raw; - spki.to_owned().into() - } - Err(_error) if cert_chain.len() == 1 => cert_chain[0].as_ref().into(), - Err(_error) => unreachable!("rustls returned an invalid peer_certificates."), - } -} - -pub fn sign_with_key( - key: &(impl rustls::sign::SigningKey + ?Sized), - scheme: SignatureScheme, - data: &[u8], -) -> Result, SignError> { - // FIXME: same as load spki then sign with ring? - let signer = key - .choose_scheme(&[scheme]) - .ok_or(SignError::UnsupportedScheme { scheme })?; - Ok(signer.sign(data)?) -} - -pub fn verify_signature( - spki: SubjectPublicKeyInfoDer, - scheme: SignatureScheme, - data: &[u8], - signature: &[u8], -) -> Result { - let algorithm: &'static dyn ring::signature::VerificationAlgorithm = match scheme { - SignatureScheme::ECDSA_NISTP384_SHA384 => &ring::signature::ECDSA_P384_SHA384_ASN1, - SignatureScheme::ECDSA_NISTP256_SHA256 => &ring::signature::ECDSA_P256_SHA256_ASN1, - SignatureScheme::ED25519 => &ring::signature::ED25519, - SignatureScheme::RSA_PKCS1_SHA256 => &ring::signature::RSA_PKCS1_2048_8192_SHA256, - SignatureScheme::RSA_PKCS1_SHA384 => &ring::signature::RSA_PKCS1_2048_8192_SHA384, - SignatureScheme::RSA_PKCS1_SHA512 => &ring::signature::RSA_PKCS1_2048_8192_SHA512, - SignatureScheme::RSA_PSS_SHA256 => &ring::signature::RSA_PSS_2048_8192_SHA256, - SignatureScheme::RSA_PSS_SHA384 => &ring::signature::RSA_PSS_2048_8192_SHA384, - SignatureScheme::RSA_PSS_SHA512 => &ring::signature::RSA_PSS_2048_8192_SHA512, - _ => return Err(VerifyError::UnsupportedScheme { scheme }), - }; - - let public_key = match x509_parser::x509::SubjectPublicKeyInfo::from_der(&spki) { - Ok((_remain, spki)) => spki.subject_public_key, - Err(_error) => unreachable!("rustls returned an invalid peer_certificates."), - }; - - Ok( - ring::signature::UnparsedPublicKey::new(algorithm, public_key) - .verify(data, signature) - .is_ok(), - ) -} diff --git a/src/rpc.rs b/src/rpc.rs index f6aa742..7f5e050 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -1,5 +1,5 @@ -pub(crate) mod bridge; pub(crate) mod error; +pub(crate) mod stream; pub mod lifecycle; pub mod quic; diff --git a/src/rpc/bridge.rs b/src/rpc/bridge.rs deleted file mode 100644 index 896f376..0000000 --- a/src/rpc/bridge.rs +++ /dev/null @@ -1,477 +0,0 @@ -//! Generic bridge state machines for converting remoc RTC clients back into -//! poll-based streams (`Stream`, `Sink`, `GetStreamId`, `StopStream`, `CancelStream`). -//! -//! Both QUIC-level and message-level stream clients share the same state machine -//! logic; only the client type and data error type differ. This module provides -//! [`ReadBridge`] and [`WriteBridge`] parameterised over those axes. - -use std::{ - pin::Pin, - task::{Context, Poll, ready}, -}; - -use bytes::Bytes; -use futures::{Sink, Stream, future::Either, stream::FusedStream}; -use tokio_util::sync::CancellationToken; - -use crate::{quic, varint::VarInt}; - -// --------------------------------------------------------------------------- -// ReadBridge -// --------------------------------------------------------------------------- - -pin_project_lite::pin_project! { - #[project = ReadStateProj] - #[project_replace = ReadStateReplace] - enum ReadState { - Stream { stream: St }, - Read { - token: CancellationToken, - #[pin] future: RF, - }, - Stop { - code: VarInt, - #[pin] future: SF, - }, - Empty, - } -} - -impl ReadState { - fn take_stream(self: Pin<&mut Self>) -> St { - match self.project_replace(ReadState::Empty) { - ReadStateReplace::Stream { stream } => stream, - _ => unreachable!("invalid state for take_stream"), - } - } -} - -pin_project_lite::pin_project! { - /// Generic read bridge: converts an RTC client into a poll-based - /// `Stream + StopStream + GetStreamId`. - /// - /// - `St`: client type stored in the state machine - /// - `E`: data error type (`Stream::Item = Result`) - /// - `R / RF`: read closure and its future - /// - `S / SF`: stop closure and its future - pub(super) struct ReadBridge { - stream_id: VarInt, - terminated: bool, - #[pin] - state: ReadState, - - read: R, - stop: S, - - // PhantomData to carry `E` without storing it. - _error: std::marker::PhantomData E>, - } -} - -impl ReadBridge { - pub(super) fn new(stream_id: VarInt, client: St, read: R, stop: S) -> Self { - Self { - stream_id, - terminated: false, - state: ReadState::Stream { stream: client }, - read, - stop, - _error: std::marker::PhantomData, - } - } -} - -impl quic::GetStreamId for ReadBridge { - fn poll_stream_id( - self: Pin<&mut Self>, - _cx: &mut Context, - ) -> Poll> { - Poll::Ready(Ok(self.stream_id)) - } -} - -impl Stream for ReadBridge -where - E: From, - R: Fn(St, CancellationToken) -> RF, - RF: Future>), St>>, - S: Fn(St, VarInt) -> SF, - SF: Future)>, -{ - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let project = self.project(); - let mut state = project.state; - - loop { - match state.as_mut().project() { - ReadStateProj::Stream { .. } => { - if *project.terminated { - return Poll::Ready(None); - } - let stream = state.as_mut().take_stream(); - let cancellation_token = CancellationToken::new(); - state.set(ReadState::Read { - token: cancellation_token.clone(), - future: (project.read)(stream, cancellation_token), - }); - } - ReadStateProj::Read { future, .. } => { - match ready!(future.poll(cx)) { - Either::Left((stream, result)) => { - if result.is_none() { - *project.terminated = true; - } - state.set(ReadState::Stream { stream }); - return Poll::Ready(result); - } - Either::Right(stream) => { - state.set(ReadState::Stream { stream }); - } - }; - } - ReadStateProj::Stop { future, .. } => { - let (stream, result) = ready!(future.poll(cx)); - state.set(ReadState::Stream { stream }); - result?; - } - ReadStateProj::Empty => unreachable!("invalid state for poll_next"), - }; - } - } -} - -impl FusedStream for ReadBridge -where - E: From, - R: Fn(St, CancellationToken) -> RF, - RF: Future>), St>>, - S: Fn(St, VarInt) -> SF, - SF: Future)>, -{ - fn is_terminated(&self) -> bool { - self.terminated - } -} - -impl quic::StopStream for ReadBridge -where - E: From, - R: Fn(St, CancellationToken) -> RF, - RF: Future>), St>>, - S: Fn(St, VarInt) -> SF, - SF: Future)>, -{ - fn poll_stop( - self: Pin<&mut Self>, - cx: &mut Context, - code: VarInt, - ) -> Poll> { - let project = self.project(); - let mut state = project.state; - - loop { - let stream = match state.as_mut().project() { - ReadStateProj::Stream { .. } => state.as_mut().take_stream(), - ReadStateProj::Read { future, token } => { - token.cancel(); - let (Either::Left((stream, ..)) | Either::Right(stream)) = - ready!(future.poll(cx)); - stream - } - ReadStateProj::Stop { future, code: sent } => { - let (stream, result) = ready!(future.poll(cx)); - if *sent == code { - state.set(ReadState::Stream { stream }); - return Poll::Ready(result); - } else { - stream - } - } - ReadStateProj::Empty => unreachable!("invalid state for poll_stop"), - }; - state.set(ReadState::Stop { - code, - future: (project.stop)(stream, code), - }) - } - } -} - -// --------------------------------------------------------------------------- -// WriteBridge -// --------------------------------------------------------------------------- - -pin_project_lite::pin_project! { - #[project = WriteStateProj] - #[project_replace = WriteStateReplace] - enum WriteState { - Stream { stream: St }, - Write { - token: CancellationToken, - #[pin] future: WF, - }, - Flush { - token: CancellationToken, - #[pin] future: FF, - }, - Shutdown { - token: CancellationToken, - #[pin] future: SF, - }, - Cancel { - code: VarInt, - #[pin] future: CF, - }, - Empty, - } -} - -impl WriteState { - fn take_stream(self: Pin<&mut Self>) -> St { - match self.project_replace(WriteState::Empty) { - WriteStateReplace::Stream { stream } => stream, - _ => unreachable!("invalid state for take_stream, maybe poll_send before poll_ready"), - } - } -} - -pin_project_lite::pin_project! { - /// Generic write bridge: converts an RTC client into a poll-based - /// `Sink + CancelStream + GetStreamId`. - /// - /// - `St`: client type - /// - `E`: data error type (`Sink::Error`) - /// - `W / WF`: write closure / future - /// - `F / FF`: flush closure / future - /// - `S / SF`: shutdown closure / future - /// - `C / CF`: cancel closure / future - pub(super) struct WriteBridge { - stream_id: VarInt, - #[pin] - state: WriteState, - - write: W, - flush: F, - shutdown: S, - cancel: C, - - _error: std::marker::PhantomData E>, - } -} - -impl WriteBridge { - #[allow(clippy::too_many_arguments)] - pub(super) fn new( - stream_id: VarInt, - client: St, - write: W, - flush: F, - shutdown: S, - cancel: C, - ) -> Self { - Self { - stream_id, - state: WriteState::Stream { stream: client }, - write, - flush, - shutdown, - cancel, - _error: std::marker::PhantomData, - } - } -} - -impl quic::GetStreamId - for WriteBridge -{ - fn poll_stream_id( - self: Pin<&mut Self>, - _cx: &mut Context, - ) -> Poll> { - Poll::Ready(Ok(self.stream_id)) - } -} - -impl Sink - for WriteBridge -where - E: From, - W: Fn(St, CancellationToken, Bytes) -> WF, - WF: Future), St>>, - F: Fn(St, CancellationToken) -> FF, - FF: Future), St>>, - S: Fn(St, CancellationToken) -> SF, - SF: Future), St>>, - C: Fn(St, VarInt) -> CF, - CF: Future)>, -{ - type Error = E; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let project = self.project(); - let mut state = project.state; - - loop { - match state.as_mut().project() { - WriteStateProj::Stream { .. } => return Poll::Ready(Ok(())), - WriteStateProj::Write { future, .. } => { - match ready!(future.poll(cx)) { - Either::Left((stream, result)) => { - state.set(WriteState::Stream { stream }); - result?; - } - Either::Right(stream) => { - state.set(WriteState::Stream { stream }); - } - }; - } - WriteStateProj::Flush { future, .. } => { - match ready!(future.poll(cx)) { - Either::Left((stream, result)) => { - state.set(WriteState::Stream { stream }); - result?; - } - Either::Right(stream) => { - state.set(WriteState::Stream { stream }); - } - }; - } - WriteStateProj::Shutdown { future, .. } => { - match ready!(future.poll(cx)) { - Either::Left((stream, result)) => { - state.set(WriteState::Stream { stream }); - result?; - } - Either::Right(stream) => { - state.set(WriteState::Stream { stream }); - } - }; - } - WriteStateProj::Cancel { future, .. } => { - let (stream, result) = ready!(future.poll(cx)); - state.set(WriteState::Stream { stream }); - result?; - } - WriteStateProj::Empty => unreachable!("invalid state for poll_ready"), - }; - } - } - - fn start_send(self: Pin<&mut Self>, bytes: Bytes) -> Result<(), Self::Error> { - let project = self.project(); - let mut state = project.state; - - let stream = state.as_mut().take_stream(); - let cancellation_token = CancellationToken::new(); - state.set(WriteState::Write { - token: cancellation_token.clone(), - future: (project.write)(stream, cancellation_token, bytes), - }); - Ok(()) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - let is_flush = matches!(self.state, WriteState::Flush { .. }); - - ready!(self.as_mut().poll_ready(cx)?); - if is_flush { - return Poll::Ready(Ok(())); - } - - let project = self.as_mut().project(); - let mut state = project.state; - - let stream = state.as_mut().take_stream(); - let cancellation_token = CancellationToken::new(); - state.set(WriteState::Flush { - token: cancellation_token.clone(), - future: (project.flush)(stream, cancellation_token), - }); - } - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - loop { - let is_shutdown = matches!(self.state, WriteState::Shutdown { .. }); - - ready!(self.as_mut().poll_ready(cx)?); - if is_shutdown { - return Poll::Ready(Ok(())); - } - - let project = self.as_mut().project(); - let mut state = project.state; - - let stream = state.as_mut().take_stream(); - let cancellation_token = CancellationToken::new(); - state.set(WriteState::Shutdown { - token: cancellation_token.clone(), - future: (project.shutdown)(stream, cancellation_token), - }); - } - } -} - -impl quic::CancelStream - for WriteBridge -where - E: From, - W: Fn(St, CancellationToken, Bytes) -> WF, - WF: Future), St>>, - F: Fn(St, CancellationToken) -> FF, - FF: Future), St>>, - S: Fn(St, CancellationToken) -> SF, - SF: Future), St>>, - C: Fn(St, VarInt) -> CF, - CF: Future)>, -{ - fn poll_cancel( - self: Pin<&mut Self>, - cx: &mut Context, - code: VarInt, - ) -> Poll> { - let project = self.project(); - let mut state = project.state; - - loop { - let stream = match state.as_mut().project() { - WriteStateProj::Stream { .. } => state.as_mut().take_stream(), - WriteStateProj::Write { future, token } => { - token.cancel(); - let (Either::Left((stream, ..)) | Either::Right(stream)) = - ready!(future.poll(cx)); - stream - } - WriteStateProj::Flush { future, token } => { - token.cancel(); - let (Either::Left((stream, ..)) | Either::Right(stream)) = - ready!(future.poll(cx)); - stream - } - WriteStateProj::Shutdown { future, token } => { - token.cancel(); - let (Either::Left((stream, ..)) | Either::Right(stream)) = - ready!(future.poll(cx)); - stream - } - WriteStateProj::Cancel { future, code: sent } => { - let (stream, result) = ready!(future.poll(cx)); - if *sent == code { - state.set(WriteState::Stream { stream }); - return Poll::Ready(result); - } else { - stream - } - } - WriteStateProj::Empty => unreachable!("invalid state for poll_cancel"), - }; - state.set(WriteState::Cancel { - code, - future: (project.cancel)(stream, code), - }) - } - } -} diff --git a/src/rpc/error.rs b/src/rpc/error.rs index b0054ae..123ea6a 100644 --- a/src/rpc/error.rs +++ b/src/rpc/error.rs @@ -1,7 +1,4 @@ -use crate::{ - quic::{self, agent}, - varint::VarInt, -}; +use crate::{quic, varint::VarInt}; /// Lossy: remoc RPC errors must be serialized across process boundaries, so the /// original error type is stringified into a QUIC transport error reason. @@ -27,19 +24,7 @@ impl From for quic::StreamError { } } -impl From for agent::SignError { - fn from(error: remoc::rtc::CallError) -> Self { - quic::ConnectionError::from(error).into() - } -} - -impl From for agent::VerifyError { - fn from(error: remoc::rtc::CallError) -> Self { - quic::ConnectionError::from(error).into() - } -} - -impl From for crate::message::stream::MessageStreamError { +impl From for crate::dhttp::message::MessageStreamError { fn from(error: remoc::rtc::CallError) -> Self { quic::StreamError::from(error).into() } @@ -85,3 +70,59 @@ impl std::fmt::Debug for StringError { std::fmt::Debug::fmt(&self.0, f) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::dhttp::message::MessageStreamError; + + #[test] + fn call_error_maps_to_internal_transport_connection_error() { + let error = quic::ConnectionError::from(remoc::rtc::CallError::Dropped); + let quic::ConnectionError::Transport { source } = error else { + panic!("call error should map to transport error"); + }; + + assert_eq!(source.kind, VarInt::from_u32(0x01)); + assert_eq!(source.frame_type, VarInt::from_u32(0x00)); + assert_eq!(source.reason, "remoc call error: processing request failed"); + } + + #[test] + fn call_error_maps_to_stream_error_through_connection_error() { + let error = quic::StreamError::from(remoc::rtc::CallError::Dropped); + let quic::StreamError::Connection { source } = error else { + panic!("call error should map to connection-scoped stream error"); + }; + + assert!(source.is_transport()); + } + + #[test] + fn call_error_maps_to_message_stream_error_through_quic_error() { + let error = MessageStreamError::from(remoc::rtc::CallError::Dropped); + let MessageStreamError::Quic { + source: quic::StreamError::Connection { source }, + } = error + else { + panic!("call error should map through QUIC stream error"); + }; + + assert!(source.is_transport()); + } + + #[test] + fn string_error_behaves_like_wrapped_string() { + let mut error = StringError::new("remote failure".to_owned()); + assert_eq!(&*error, "remote failure"); + + error.push_str(" with context"); + + assert_eq!(error.as_str(), "remote failure with context"); + assert_eq!(format!("{error}"), "remote failure with context"); + assert_eq!(format!("{error:?}"), "\"remote failure with context\""); + + fn assert_std_error(_: &(impl std::error::Error + ?Sized)) {} + assert_std_error(&error); + } +} diff --git a/src/rpc/lifecycle.rs b/src/rpc/lifecycle.rs index 220ef0a..46a204d 100644 --- a/src/rpc/lifecycle.rs +++ b/src/rpc/lifecycle.rs @@ -24,6 +24,23 @@ //! [`quic::Lifecycle`] automatically gives the container all of //! [`LifecycleExt`]'s guard/check/closed helpers via blanket impl. //! +//! # Latch-aware lifecycle invariant +//! +//! A type that implements [`HasLatch`] **must** make its [`quic::Lifecycle`] +//! implementation latch-aware. In practice, `check()` should consult the latch +//! first (usually through [`LifecycleExt::check_with_probe`]), and `closed()` +//! should return or install the canonical latched error (usually through +//! [`LifecycleExt::resolve_closed`]). +//! +//! Implementing `HasLatch` but writing `check()` / `closed()` as if the latch +//! did not exist is an implementation bug: operation guards may have already +//! recorded a terminal error, and direct lifecycle callers must observe the +//! same first-wins error. +//! +//! `close(code, reason)` is intentionally not a latch propagation channel. It +//! only initiates local close; the terminal error is recorded by guarded +//! operation failures, liveness probes, or `closed()` resolution. +//! //! The container authors its own `impl Lifecycle` using //! [`check_with_probe`](LifecycleExt::check_with_probe) and //! [`resolve_closed`](LifecycleExt::resolve_closed) as building blocks; @@ -60,29 +77,24 @@ impl ConnectionErrorLatch { Self::default() } - /// Lazy latch: only calls `f` when no error is latched yet. + /// Lazy latch: only calls `f` when no error is latched at entry. /// - /// Returns the canonical (first-wins) error. + /// If another setter wins the race after `f` constructs an error, returns + /// the already-latched canonical (first-wins) error instead of the losing + /// candidate. pub(crate) fn latch_with(&self, f: impl FnOnce() -> ConnectionError) -> ConnectionError { match self.terminal_error.peek() { Some(existing) => existing, None => { let error = f(); - let _ = self.terminal_error.set(error.clone()); - error + match self.terminal_error.set(error.clone()) { + Ok(()) => error, + Err(rejected) => self.terminal_error.peek().unwrap_or(rejected), + } } } } - /// Crate-internal shortcut equivalent to `latch_with(|| error)`. - /// - /// Intended for call sites that already own a concrete [`ConnectionError`] - /// (e.g. forwarding the terminal error out of `closed()` after the - /// underlying future resolved). - pub(crate) fn latch_raw(&self, error: ConnectionError) -> ConnectionError { - self.latch_with(|| error) - } - /// Return the latched error, if any. pub(crate) fn check(&self) -> Result<(), ConnectionError> { match self.terminal_error.peek() { @@ -108,7 +120,9 @@ pub(crate) mod sealed { /// container's latch without making the latch a public API. /// /// Implementers must return a stable reference to the same - /// [`ConnectionErrorLatch`] on every call. + /// [`ConnectionErrorLatch`] on every call, and their [`crate::quic::Lifecycle`] + /// implementation must be latch-aware. See the module-level + /// "Latch-aware lifecycle invariant" section. pub trait HasLatch { fn latch(&self) -> &ConnectionErrorLatch; } @@ -169,7 +183,7 @@ pub trait LifecycleExt: quic::Lifecycle + HasLatch { return error; } let error = wait.await; - self.latch().latch_raw(error) + self.latch().latch_with(|| error) } /// Guard an async operation whose error is already a [`ConnectionError`]. @@ -250,7 +264,9 @@ impl LifecycleExt for T {} mod tests { use std::{ borrow::Cow, - sync::{Arc, Mutex}, + sync::{Arc, Mutex, mpsc}, + thread, + time::Duration, }; use super::*; @@ -266,6 +282,26 @@ mod tests { } } + fn make_app_err(tag: u32) -> ConnectionError { + ConnectionError::Application { + source: quic::ApplicationError { + code: Code::new(VarInt::from_u32(tag)), + reason: format!("test-app-{tag}").into(), + }, + } + } + + fn error_kind(error: &ConnectionError) -> VarInt { + match error { + ConnectionError::Transport { source } => source.kind, + _ => panic!("unexpected error shape"), + } + } + + fn ok_unit() -> Result<(), ConnectionError> { + Ok(()) + } + /// Minimal container used to exercise the sealed trait. struct TestLifecycle { latch: ConnectionErrorLatch, @@ -340,6 +376,62 @@ mod tests { )); } + #[test] + fn latch_with_returns_canonical_error_when_setter_races() { + let latch = ConnectionErrorLatch::new(); + let slow_latch = latch.clone(); + let fast_latch = latch.clone(); + let (entered_tx, entered_rx) = mpsc::channel(); + + let slow = thread::spawn(move || { + slow_latch.latch_with(|| { + entered_tx.send(()).unwrap(); + thread::sleep(Duration::from_millis(100)); + make_err(1) + }) + }); + + entered_rx.recv().unwrap(); + + let fast = thread::spawn(move || fast_latch.latch_with(|| make_err(2))); + + let slow_error = slow.join().unwrap(); + let fast_error = fast.join().unwrap(); + let canonical = latch.check().unwrap_err(); + + assert_eq!(error_kind(&slow_error), error_kind(&canonical)); + assert_eq!(error_kind(&fast_error), error_kind(&canonical)); + } + + #[test] + fn cloned_latch_handles_share_canonical_error() { + let latch = ConnectionErrorLatch::new(); + let clone = latch.clone(); + + assert!(latch.check().is_ok()); + assert!(clone.check().is_ok()); + + let installed = clone.latch_with(|| make_err(11)); + let observed = latch.check().unwrap_err(); + + assert_eq!(error_kind(&installed), VarInt::from_u32(11)); + assert_eq!(error_kind(&observed), VarInt::from_u32(11)); + } + + #[test] + fn close_does_not_install_terminal_error() { + let lc = TestLifecycle::new(); + + quic::Lifecycle::close( + &lc, + Code::new(VarInt::from_u32(19)), + Cow::Borrowed("local close"), + ); + + assert!(lc.latch.check().is_ok()); + assert!(quic::Lifecycle::check(&lc).is_ok()); + } + #[test] fn check_with_probe_folds_probe_into_latch() { let lc = TestLifecycle::new(); @@ -365,6 +457,46 @@ mod tests { assert!(quic::Lifecycle::check(&lc).is_ok()); } + #[test] + fn check_with_probe_skips_probe_when_latched() { + let lc = TestLifecycle::new(); + lc.latch.latch_with(|| make_err(12)); + + let called = Mutex::new(false); + let err = lc + .check_with_probe(|| { + *called.lock().unwrap() = true; + Some(make_err(13)) + }) + .unwrap_err(); + + assert_eq!(error_kind(&err), VarInt::from_u32(12)); + assert!( + !*called.lock().unwrap(), + "probe must not run after a terminal error is latched" + ); + } + + #[test] + fn check_with_probe_preserves_application_error_shape() { + let lc = TestLifecycle::new(); + lc.set_probe(make_app_err(25)); + + let e1 = quic::Lifecycle::check(&lc).unwrap_err(); + let e2 = quic::Lifecycle::check(&lc).unwrap_err(); + + match (&e1, &e2) { + ( + ConnectionError::Application { source: s1 }, + ConnectionError::Application { source: s2 }, + ) => { + assert_eq!(s1.code, Code::new(VarInt::from_u32(25))); + assert_eq!(s2.code, Code::new(VarInt::from_u32(25))); + } + _ => panic!("unexpected error shape"), + } + } + #[tokio::test] async fn resolve_closed_returns_latched_without_awaiting() { let lc = TestLifecycle::new(); @@ -382,6 +514,25 @@ mod tests { assert!(never.lock().unwrap().is_some(), "wait must not be polled"); } + #[tokio::test] + async fn resolve_closed_returns_error_that_wins_during_wait() { + let lc = TestLifecycle::new(); + let latch = lc.latch.clone(); + + let got = lc + .resolve_closed(async move { + latch.latch_with(|| make_err(14)); + make_err(15) + }) + .await; + + assert_eq!(error_kind(&got), VarInt::from_u32(14)); + assert_eq!( + error_kind(&lc.latch.check().unwrap_err()), + VarInt::from_u32(14) + ); + } + #[tokio::test] async fn resolve_closed_latches_wait_result() { let lc = TestLifecycle::new(); @@ -400,6 +551,98 @@ mod tests { )); } + #[tokio::test] + async fn closed_after_probe_error_returns_latched_without_consuming_wait() { + let lc = TestLifecycle::new(); + lc.set_probe(make_err(40)); + + let check_error = quic::Lifecycle::check(&lc).unwrap_err(); + assert_eq!(error_kind(&check_error), VarInt::from_u32(40)); + + lc.set_wait(make_err(41)); + let closed_error = quic::Lifecycle::closed(&lc).await; + + assert_eq!(error_kind(&closed_error), VarInt::from_u32(40)); + assert!( + lc.wait.lock().unwrap().is_some(), + "closed must not poll wait once check has latched the terminal error" + ); + } + + #[tokio::test] + async fn closed_latches_default_wait_error_when_wait_is_unset() { + let lc = TestLifecycle::new(); + + let got = quic::Lifecycle::closed(&lc).await; + + assert_eq!(error_kind(&got), VarInt::from_u32(99)); + assert_eq!( + error_kind(&lc.latch.check().unwrap_err()), + VarInt::from_u32(99) + ); + } + + #[tokio::test] + async fn guard_success_returns_value_without_latching_error() { + let lc = TestLifecycle::new(); + + let out = lc.guard(async { Ok::<_, ConnectionError>(31) }).await; + + assert_eq!(out.unwrap(), 31); + assert!(lc.latch.check().is_ok()); + } + + #[tokio::test] + async fn guard_does_not_poll_operation_after_failed_check() { + let lc = TestLifecycle::new(); + lc.set_probe(make_err(16)); + + let called = Arc::new(Mutex::new(false)); + let called2 = called.clone(); + let res: Result<(), ConnectionError> = lc + .guard(async move { + *called2.lock().unwrap() = true; + Ok(()) + }) + .await; + + assert_eq!(error_kind(&res.unwrap_err()), VarInt::from_u32(16)); + assert!( + !*called.lock().unwrap(), + "operation future must not be polled after failed check" + ); + } + + #[tokio::test] + async fn guard_latches_operation_error() { + let lc = TestLifecycle::new(); + + let first: Result<(), ConnectionError> = lc.guard(async { Err(make_err(17)) }).await; + let second: Result<(), ConnectionError> = lc.guard(async { Err(make_err(18)) }).await; + + assert_eq!(error_kind(&first.unwrap_err()), VarInt::from_u32(17)); + assert_eq!(error_kind(&second.unwrap_err()), VarInt::from_u32(17)); + } + + #[tokio::test] + async fn guard_returns_error_latched_during_operation() { + let lc = TestLifecycle::new(); + let latch = lc.latch.clone(); + + let res: Result<(), ConnectionError> = lc + .guard(async move { + latch.latch_with(|| make_err(37)); + Err(make_err(38)) + }) + .await; + + assert_eq!(error_kind(&res.unwrap_err()), VarInt::from_u32(37)); + assert_eq!( + error_kind(&lc.latch.check().unwrap_err()), + VarInt::from_u32(37) + ); + } + #[tokio::test] async fn guard_with_skips_closure_when_latched() { let lc = TestLifecycle::new(); @@ -420,6 +663,25 @@ mod tests { ); } + #[tokio::test] + async fn guard_with_skips_operation_and_mapping_after_failed_check() { + let lc = TestLifecycle::new(); + lc.set_probe(make_err(26)); + + let err = tokio::time::timeout( + Duration::from_millis(50), + lc.guard_with( + std::future::pending::>(), + std::convert::identity, + ), + ) + .await + .expect("guard_with must return before polling the pending operation") + .unwrap_err(); + + assert_eq!(error_kind(&err), VarInt::from_u32(26)); + } + #[tokio::test] async fn guard_with_success_is_untouched() { let lc = TestLifecycle::new(); @@ -429,6 +691,75 @@ mod tests { assert_eq!(out.unwrap(), 7); } + #[tokio::test] + async fn guard_with_error_maps_and_latches_first_error() { + let lc = TestLifecycle::new(); + let map_calls = Mutex::new(Vec::new()); + + let first: Result<(), ConnectionError> = lc + .guard_with(async { Err::<(), _>(32) }, |tag| { + map_calls.lock().unwrap().push(tag); + make_err(tag) + }) + .await; + + assert_eq!(error_kind(&first.unwrap_err()), VarInt::from_u32(32)); + assert_eq!(map_calls.lock().unwrap().as_slice(), &[32]); + + let second_map_called = Mutex::new(false); + let second: Result<(), ConnectionError> = lc + .guard_with(async { Err::<(), _>(33) }, |_| { + *second_map_called.lock().unwrap() = true; + make_err(33) + }) + .await; + + assert_eq!(error_kind(&second.unwrap_err()), VarInt::from_u32(32)); + assert!( + !*second_map_called.lock().unwrap(), + "map_err must stay lazy after an error is latched" + ); + } + + #[test] + fn guard_sync_success_returns_value_without_latching_error() { + let lc = TestLifecycle::new(); + + lc.guard_sync(ok_unit).unwrap(); + + assert!(lc.latch.check().is_ok()); + } + + #[test] + fn guard_sync_skips_operation_after_failed_check() { + let lc = TestLifecycle::new(); + lc.set_probe(make_err(28)); + + let err = lc.guard_sync(ok_unit).unwrap_err(); + + assert_eq!(error_kind(&err), VarInt::from_u32(28)); + } + + #[test] + fn guard_sync_skips_closure_when_already_latched() { + let lc = TestLifecycle::new(); + lc.latch.latch_with(|| make_err(29)); + + let called = Mutex::new(false); + let err = lc + .guard_sync(|| { + *called.lock().unwrap() = true; + Ok::<_, ConnectionError>(()) + }) + .unwrap_err(); + + assert_eq!(error_kind(&err), VarInt::from_u32(29)); + assert!( + !*called.lock().unwrap(), + "operation closure must not run after a terminal error is latched" + ); + } + #[test] fn guard_sync_latches_only_first_error() { let lc = TestLifecycle::new(); @@ -445,4 +776,93 @@ mod tests { _ => panic!("unexpected error shape"), } } + + #[test] + fn guard_sync_with_error_maps_and_latches_first_error() { + let lc = TestLifecycle::new(); + let map_calls = Mutex::new(Vec::new()); + + let first = lc + .guard_sync_with( + || Err::<(), _>(35), + |tag| { + map_calls.lock().unwrap().push(tag); + make_err(tag) + }, + ) + .unwrap_err(); + + assert_eq!(error_kind(&first), VarInt::from_u32(35)); + assert_eq!(map_calls.lock().unwrap().as_slice(), &[35]); + + let second_map_called = Mutex::new(false); + let second = lc + .guard_sync_with( + || Err::<(), _>(36), + |_| { + *second_map_called.lock().unwrap() = true; + make_err(36) + }, + ) + .unwrap_err(); + + assert_eq!(error_kind(&second), VarInt::from_u32(35)); + assert!( + !*second_map_called.lock().unwrap(), + "map_err must stay lazy after an error is latched" + ); + } + + #[test] + fn guard_sync_with_success_does_not_map_error() { + let lc = TestLifecycle::new(); + let called = Mutex::new(false); + + let out = lc + .guard_sync_with( + || Ok::<_, &'static str>(21), + |_| { + *called.lock().unwrap() = true; + make_err(22) + }, + ) + .unwrap(); + + assert_eq!(out, 21); + assert!( + !*called.lock().unwrap(), + "map_err must not run for successful operations" + ); + } + + #[test] + fn guard_sync_with_skips_operation_and_mapping_after_failed_check() { + let lc = TestLifecycle::new(); + lc.set_probe(make_err(23)); + + let op_called = Mutex::new(false); + let map_called = Mutex::new(false); + let err = lc + .guard_sync_with( + || { + *op_called.lock().unwrap() = true; + Err::<(), _>("not reached") + }, + |_| { + *map_called.lock().unwrap() = true; + make_err(24) + }, + ) + .unwrap_err(); + + assert_eq!(error_kind(&err), VarInt::from_u32(23)); + assert!( + !*op_called.lock().unwrap(), + "operation must not run after failed check" + ); + assert!( + !*map_called.lock().unwrap(), + "map_err must not run when operation is skipped" + ); + } } diff --git a/src/rpc/message.rs b/src/rpc/message.rs index 0e975f4..2634c6a 100644 --- a/src/rpc/message.rs +++ b/src/rpc/message.rs @@ -1,36 +1,22 @@ -//! Remote forwarding of message-level streams via RPC RTC. +//! Message-level RPC stream facades over typed stream frame channels. //! -//! Parallel to [`super::quic`] which bridges raw QUIC streams, this module -//! bridges **message-level** streams ([`ReadMessageStream`] / [`WriteMessageStream`]) -//! that carry HTTP/3 DATA frame semantics and use [`MessageStreamError`] for their -//! data path. The QUIC-level control operations (`StopStream`, `CancelStream`, -//! `GetStreamId`) still use [`quic::StreamError`]. -//! The current RPC runtime implementation is backed by `remoc`. +//! The RPC runtime transports stream data/control through the same typed +//! [`ReadFrameChannels`] and [`WriteFrameChannels`] used by QUIC-level RPC +//! stream forwarding. This module adds message-level conversion methods on +//! those channel bundles: data-path QUIC stream errors are mapped to +//! [`MessageStreamError`](crate::dhttp::message::MessageStreamError), while +//! QUIC control operations (`StopStream`, `ResetStream`, `GetStreamId`) keep +//! returning [`quic::StreamError`](crate::quic::StreamError). //! -//! Any type implementing the original [`ReadMessageStream`] / [`WriteMessageStream`] -//! traits automatically implements the RTC traits via blanket impls. +//! # Conversion methods //! -//! # Public API -//! -//! ## RTC client handles (serializable, sendable over the wire) -//! -//! - [`ReadMessageStreamClient`] -//! - [`WriteMessageStreamClient`] -//! -//! ## Conversion methods (reconstruct poll-based streams from clients) -//! -//! - [`ReadMessageStreamClient::into_message_stream`] -//! - [`ReadMessageStreamClient::into_boxed_message_stream`] -//! - [`ReadMessageStreamClient::into_box_reader`] -//! - [`WriteMessageStreamClient::into_message_stream`] -//! - [`WriteMessageStreamClient::into_boxed_message_stream`] -//! - [`WriteMessageStreamClient::into_box_writer`] +//! - `ReadFrameChannels::into_message_stream` +//! - `ReadFrameChannels::into_boxed_message_stream` +//! - `ReadFrameChannels::into_box_reader` +//! - `WriteFrameChannels::into_message_stream` +//! - `WriteFrameChannels::into_boxed_message_stream` +//! - `WriteFrameChannels::into_box_writer` mod stream; -pub use self::stream::{ - ReadMessageStreamClient, ReadMessageStreamReqReceiver, ReadMessageStreamServer, - ReadMessageStreamServerRefMut, ReadMessageStreamServerSharedMut, WriteMessageStreamClient, - WriteMessageStreamReqReceiver, WriteMessageStreamServer, WriteMessageStreamServerRefMut, - WriteMessageStreamServerSharedMut, -}; +pub use super::quic::{ReadFrameChannels, WriteFrameChannels}; diff --git a/src/rpc/message/stream.rs b/src/rpc/message/stream.rs index 1fcf887..562d468 100644 --- a/src/rpc/message/stream.rs +++ b/src/rpc/message/stream.rs @@ -1,193 +1,568 @@ -use std::pin::Pin; +use std::{ + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; use bytes::Bytes; -use futures::{SinkExt, StreamExt, future::Either}; -use tokio_util::sync::CancellationToken; - -use super::super::bridge; -// Import original traits under aliases to avoid collision with the RTC traits -// defined in this module (which share the same names). -use crate::message::stream::{ - ReadMessageStream as OrigReadMessageStream, WriteMessageStream as OrigWriteMessageStream, -}; +use futures::{Sink, Stream, stream::FusedStream}; + use crate::{ - message::stream::{BoxMessageStreamReader, BoxMessageStreamWriter, MessageStreamError}, - quic::{self, CancelStreamExt, GetStreamIdExt, StopStreamExt}, - util::deferred::Deferred, - varint::VarInt, + dhttp::message::{BoxMessageReader, BoxMessageWriter, MessageStreamError}, + quic, + rpc::{ + lifecycle::LifecycleExt, + stream::remoc::{ReadFrameChannels, WriteFrameChannels}, + }, + stream, }; -// --------------------------------------------------------------------------- -// RTC traits -// --------------------------------------------------------------------------- - -/// Remote trait for reading from a message-level stream over remoc RTC. -/// -/// Data reads use [`MessageStreamError`]; QUIC control operations -/// (`stream_id`, `stop`) use [`quic::StreamError`]. -#[remoc::rtc::remote] -pub trait ReadMessageStream: Send { - async fn stream_id(&mut self) -> Result; - async fn read(&mut self) -> Result, MessageStreamError>; - async fn stop(&mut self, code: VarInt) -> Result<(), quic::StreamError>; -} - -/// Remote trait for writing to a message-level stream over remoc RTC. -/// -/// Data writes use [`MessageStreamError`]; QUIC control operations -/// (`stream_id`, `cancel`) use [`quic::StreamError`]. -#[remoc::rtc::remote] -pub trait WriteMessageStream: Send { - async fn stream_id(&mut self) -> Result; - async fn write(&mut self, data: Bytes) -> Result<(), MessageStreamError>; - async fn flush(&mut self) -> Result<(), MessageStreamError>; - async fn shutdown(&mut self) -> Result<(), MessageStreamError>; - async fn cancel(&mut self, code: VarInt) -> Result<(), quic::StreamError>; -} - -// --------------------------------------------------------------------------- -// Server side: blanket impls for original message stream types -// --------------------------------------------------------------------------- - -impl ReadMessageStream for S +pin_project_lite::pin_project! { + struct MessageFrameReader { + #[pin] + inner: R, + } +} + +impl MessageFrameReader { + fn new(inner: R) -> Self { + Self { inner } + } +} + +impl Stream for MessageFrameReader where - S: OrigReadMessageStream + Unpin + Send, + R: Stream>, { - async fn stream_id(&mut self) -> Result { - GetStreamIdExt::stream_id(self).await + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project().inner.poll_next(cx) { + Poll::Ready(Some(Ok(data))) => Poll::Ready(Some(Ok(data))), + Poll::Ready(Some(Err(source))) => { + Poll::Ready(Some(Err(MessageStreamError::Quic { source }))) + } + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } } +} - async fn read(&mut self) -> Result, MessageStreamError> { - StreamExt::next(self).await.transpose() +impl FusedStream for MessageFrameReader +where + R: FusedStream>, +{ + fn is_terminated(&self) -> bool { + self.inner.is_terminated() } +} + +impl quic::GetStreamId for MessageFrameReader +where + R: quic::GetStreamId, +{ + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + self.project().inner.poll_stream_id(cx) + } +} - async fn stop(&mut self, code: VarInt) -> Result<(), quic::StreamError> { - StopStreamExt::stop(self, code).await +impl quic::StopStream for MessageFrameReader +where + R: quic::StopStream, +{ + fn poll_stop( + self: Pin<&mut Self>, + cx: &mut Context, + code: crate::varint::VarInt, + ) -> Poll> { + self.project().inner.poll_stop(cx, code) + } +} + +impl stream::GetStreamId for MessageFrameReader +where + R: quic::GetStreamId, +{ + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + quic::GetStreamId::poll_stream_id(self, cx) } } -impl WriteMessageStream for S +impl stream::StopStream for MessageFrameReader where - S: OrigWriteMessageStream + Unpin + Send, + R: quic::StopStream, { - async fn stream_id(&mut self) -> Result { - GetStreamIdExt::stream_id(self).await + fn poll_stop( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: crate::varint::VarInt, + ) -> Poll> { + quic::StopStream::poll_stop(self, cx, code) } +} - async fn write(&mut self, data: Bytes) -> Result<(), MessageStreamError> { - SinkExt::send(self, data).await +pin_project_lite::pin_project! { + struct MessageFrameWriter { + #[pin] + inner: W, } +} - async fn flush(&mut self) -> Result<(), MessageStreamError> { - SinkExt::flush(self).await +impl MessageFrameWriter { + fn new(inner: W) -> Self { + Self { inner } } +} + +impl Sink for MessageFrameWriter +where + W: Sink, +{ + type Error = MessageStreamError; - async fn shutdown(&mut self) -> Result<(), MessageStreamError> { - SinkExt::close(self).await + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project().inner.poll_ready(cx) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(())), + Poll::Ready(Err(source)) => Poll::Ready(Err(MessageStreamError::Quic { source })), + Poll::Pending => Poll::Pending, + } } - async fn cancel(&mut self, code: VarInt) -> Result<(), quic::StreamError> { - CancelStreamExt::cancel(self, code).await + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + match self.project().inner.start_send(item) { + Ok(()) => Ok(()), + Err(source) => Err(MessageStreamError::Quic { source }), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project().inner.poll_flush(cx) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(())), + Poll::Ready(Err(source)) => Poll::Ready(Err(MessageStreamError::Quic { source })), + Poll::Pending => Poll::Pending, + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.project().inner.poll_close(cx) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(())), + Poll::Ready(Err(source)) => Poll::Ready(Err(MessageStreamError::Quic { source })), + Poll::Pending => Poll::Pending, + } } } -// --------------------------------------------------------------------------- -// Client side: ReadMessageStreamClient → impl OrigReadMessageStream -// --------------------------------------------------------------------------- +impl quic::GetStreamId for MessageFrameWriter +where + W: quic::GetStreamId, +{ + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + self.project().inner.poll_stream_id(cx) + } +} -impl ReadMessageStreamClient { - /// Convert into a poll-based [`OrigReadMessageStream`]. - pub async fn into_message_stream( - mut self, - ) -> Result { - let stream_id = self.stream_id().await?; - Ok( - bridge::ReadBridge::<_, MessageStreamError, _, _, _, _>::new( - stream_id, - self, - |mut client: ReadMessageStreamClient, token: CancellationToken| async move { - tokio::select! { - res = client.read() => Either::Left((client, res.transpose())), - _ = token.cancelled() => Either::Right(client), - } - }, - |mut client: ReadMessageStreamClient, code| async move { - let res = client.stop(code).await; - (client, res) - }, - ), - ) +impl quic::ResetStream for MessageFrameWriter +where + W: quic::ResetStream, +{ + fn poll_reset( + self: Pin<&mut Self>, + cx: &mut Context, + code: crate::varint::VarInt, + ) -> Poll> { + self.project().inner.poll_reset(cx, code) } +} - /// Convert into a boxed [`OrigReadMessageStream`] (lazy — resolves on first poll). - pub fn into_boxed_message_stream(self) -> Pin> { - Box::pin(Deferred::from(self.into_message_stream())) +impl stream::GetStreamId for MessageFrameWriter +where + W: quic::GetStreamId, +{ + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + quic::GetStreamId::poll_stream_id(self, cx) } +} - /// Convert into a [`BoxMessageStreamReader`] (implements [`AsyncRead`](tokio::io::AsyncRead)). - pub fn into_box_reader(self) -> BoxMessageStreamReader<'static> { - crate::codec::StreamReader::new(self.into_boxed_message_stream()) +impl stream::ResetStream for MessageFrameWriter +where + W: quic::ResetStream, +{ + fn poll_reset( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: crate::varint::VarInt, + ) -> Poll> { + quic::ResetStream::poll_reset(self, cx, code) } } -// --------------------------------------------------------------------------- -// Client side: WriteMessageStreamClient → impl OrigWriteMessageStream -// --------------------------------------------------------------------------- +impl ReadFrameChannels { + /// Convert RPC read frame channels into a poll-based message stream reader. + pub fn into_message_stream( + self, + lifecycle: Arc, + ) -> impl stream::ReadStream + + Send + + 'static + where + L: LifecycleExt + 'static, + { + MessageFrameReader::new(self.into_quic(lifecycle)) + } -impl WriteMessageStreamClient { - /// Convert into a poll-based [`OrigWriteMessageStream`]. - pub async fn into_message_stream( - mut self, - ) -> Result { - let stream_id = self.stream_id().await?; - Ok(bridge::WriteBridge::< - _, + /// Convert RPC read frame channels into a boxed message stream reader. + pub fn into_boxed_message_stream(self, lifecycle: Arc) -> BoxMessageReader + where + L: LifecycleExt + 'static, + { + Box::pin(MessageFrameReader::new(self.into_quic(lifecycle))) + } + + /// Convert RPC read frame channels into a boxed message stream reader. + pub fn into_box_reader(self, lifecycle: Arc) -> BoxMessageReader + where + L: LifecycleExt + 'static, + { + self.into_boxed_message_stream(lifecycle) + } +} + +impl WriteFrameChannels { + /// Convert RPC write frame channels into a poll-based message stream writer. + pub fn into_message_stream( + self, + lifecycle: Arc, + ) -> impl stream::WriteStream + + Send + + 'static + where + L: LifecycleExt + 'static, + { + MessageFrameWriter::new(self.into_quic(lifecycle)) + } + + /// Convert RPC write frame channels into a boxed message stream writer. + pub fn into_boxed_message_stream(self, lifecycle: Arc) -> BoxMessageWriter + where + L: LifecycleExt + 'static, + { + Box::pin(MessageFrameWriter::new(self.into_quic(lifecycle))) + } + + /// Convert RPC write frame channels into a boxed message stream writer. + pub fn into_box_writer(self, lifecycle: Arc) -> BoxMessageWriter + where + L: LifecycleExt + 'static, + { + self.into_boxed_message_stream(lifecycle) + } +} + +#[cfg(test)] +mod tests { + use std::{pin::Pin, sync::Arc}; + + use futures::{FutureExt as _, SinkExt as _, StreamExt as _, future::poll_fn}; + use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _}; + + use super::*; + use crate::{ + codec::{SinkWriter, StreamReader}, + rpc::stream::{ + frame::{ReadCommand, ReadEvent, WriteCommand, WriteEvent}, + remoc::RpcFrameIo, + test_io::TestLifecycle, + }, + stream::{GetStreamIdExt as _, ResetStreamExt as _, StopStreamExt as _}, + varint::VarInt, + }; + + fn lifecycle() -> Arc { + Arc::new(TestLifecycle::new()) + } + + fn assert_message_reset(error: MessageStreamError, expected: VarInt) { + let MessageStreamError::Quic { + source: quic::StreamError::Reset { code }, + } = error + else { + panic!("expected reset-backed message stream error"); + }; + assert_eq!(code, expected); + } + + async fn send_write_credit( + hypervisor: &mut RpcFrameIo, + writer: &mut ( + impl stream::WriteStream< + Bytes, MessageStreamError, - _, - _, - _, - _, - _, - _, - _, - _, - >::new( - stream_id, - self, - |mut client: WriteMessageStreamClient, token: CancellationToken, bytes| async move { - tokio::select! { - res = client.write(bytes) => Either::Left((client, res)), - _ = token.cancelled() => Either::Right(client), - } - }, - |mut client: WriteMessageStreamClient, token: CancellationToken| async move { - tokio::select! { - res = client.flush() => Either::Left((client, res)), - _ = token.cancelled() => Either::Right(client), - } - }, - |mut client: WriteMessageStreamClient, token: CancellationToken| async move { - tokio::select! { - res = client.shutdown() => Either::Left((client, res)), - _ = token.cancelled() => Either::Right(client), + quic::StreamError, + quic::StreamError, + > + Unpin + ), + ) { + let ready = poll_fn(|cx| Pin::new(&mut *writer).poll_ready(cx)); + let credit = hypervisor.send(WriteEvent::Pull); + let (credit, ready) = tokio::join!(credit, ready); + credit.expect("credit should send"); + ready.expect("writer should become ready"); + } + + #[tokio::test] + async fn read_frame_channels_construct_message_reader() { + let stream_id = VarInt::from_u32(901); + let (channels, mut hypervisor) = ReadFrameChannels::pair(stream_id); + let mut reader = channels.into_message_stream(lifecycle()); + + assert_eq!(reader.stream_id().await.expect("stream id"), stream_id); + + let read = reader.next(); + tokio::pin!(read); + assert!(read.as_mut().now_or_never().is_none()); + assert_eq!(hypervisor.next().await.unwrap().unwrap(), ReadCommand::Pull); + let send = hypervisor.send(ReadEvent::Push { + data: Bytes::from_static(b"message read"), + }); + let (send, received) = tokio::join!(send, read); + send.expect("read event should send"); + assert_eq!( + received.unwrap().expect("message read should succeed"), + Bytes::from_static(b"message read") + ); + } + + #[tokio::test] + async fn write_frame_channels_construct_message_writer() { + let stream_id = VarInt::from_u32(902); + let (channels, mut hypervisor) = WriteFrameChannels::pair(stream_id); + let mut writer = channels.into_message_stream(lifecycle()); + + assert_eq!(writer.stream_id().await.expect("stream id"), stream_id); + send_write_credit(&mut hypervisor, &mut writer).await; + Pin::new(&mut writer) + .start_send(Bytes::from_static(b"message write")) + .expect("start send should succeed"); + + let flush = writer.flush(); + let drive = async { + assert_eq!( + hypervisor.next().await.unwrap().unwrap(), + WriteCommand::Push { + data: Bytes::from_static(b"message write") } - }, - |mut client: WriteMessageStreamClient, code| async move { - let res = client.cancel(code).await; - (client, res) - }, - )) + ); + assert_eq!( + hypervisor.next().await.unwrap().unwrap(), + WriteCommand::Flush + ); + hypervisor + .send(WriteEvent::FlushAck) + .await + .expect("flush ack should send"); + }; + let (flush, ()) = tokio::join!(flush, drive); + flush.expect("message flush should succeed"); } - /// Convert into a boxed [`OrigWriteMessageStream`] (lazy — resolves on first poll). - pub fn into_boxed_message_stream( - self, - ) -> Pin> { - Box::pin(Deferred::from(self.into_message_stream())) + #[tokio::test] + async fn read_message_stream_stop_is_first_wins_through_frames() { + let stream_id = VarInt::from_u32(903); + let first = VarInt::from_u32(11); + let second = VarInt::from_u32(12); + let (channels, mut hypervisor) = ReadFrameChannels::pair(stream_id); + let mut reader = channels.into_message_stream(lifecycle()); + + assert!(reader.stop(first).now_or_never().is_none()); + assert_eq!( + hypervisor.next().await.unwrap().unwrap(), + ReadCommand::Stop { code: first } + ); + let second_stop = reader.stop(second); + tokio::pin!(second_stop); + let send = hypervisor.send(ReadEvent::StopAck { code: first }); + let (send, stop) = tokio::join!(send, second_stop); + send.expect("stop ack should send"); + stop.expect("second stop should observe first committed stop"); + } + + #[tokio::test] + async fn write_message_stream_reset_is_first_wins_through_frames() { + let stream_id = VarInt::from_u32(904); + let first = VarInt::from_u32(21); + let second = VarInt::from_u32(22); + let (channels, mut hypervisor) = WriteFrameChannels::pair(stream_id); + let mut writer = channels.into_message_stream(lifecycle()); + + assert!(writer.reset(first).now_or_never().is_none()); + assert_eq!( + hypervisor.next().await.unwrap().unwrap(), + WriteCommand::Reset { code: first } + ); + let second_reset = writer.reset(second); + tokio::pin!(second_reset); + let send = hypervisor.send(WriteEvent::ResetAck { code: first }); + let (send, reset) = tokio::join!(send, second_reset); + send.expect("reset ack should send"); + reset.expect("second reset should observe first committed reset"); + } + + #[tokio::test] + async fn read_message_stream_maps_quic_errors_to_message_errors() { + let stream_id = VarInt::from_u32(905); + let code = VarInt::from_u32(31); + let (channels, mut hypervisor) = ReadFrameChannels::pair(stream_id); + let mut reader = channels.into_message_stream(lifecycle()); + + let read = reader.next(); + tokio::pin!(read); + assert!(read.as_mut().now_or_never().is_none()); + assert_eq!(hypervisor.next().await.unwrap().unwrap(), ReadCommand::Pull); + let send = hypervisor.send(ReadEvent::ErrReset { code }); + let (send, received) = tokio::join!(send, read); + send.expect("reset event should send"); + let error = received + .expect("reset should produce an item") + .expect_err("reset should map to message error"); + assert_message_reset(error, code); + } + + #[tokio::test] + async fn write_message_stream_maps_quic_errors_to_message_errors() { + let stream_id = VarInt::from_u32(906); + let code = VarInt::from_u32(41); + let (channels, mut hypervisor) = WriteFrameChannels::pair(stream_id); + let mut writer = channels.into_message_stream(lifecycle()); + + let ready = poll_fn(|cx| Pin::new(&mut writer).poll_ready(cx)); + let reset = hypervisor.send(WriteEvent::ErrReset { code }); + let (reset, ready) = tokio::join!(reset, ready); + reset.expect("reset event should send"); + let error = ready.expect_err("reset should map to message error"); + assert_message_reset(error, code); + } + + #[tokio::test] + async fn read_message_box_adapters_use_frame_channels() { + let stream_id = VarInt::from_u32(907); + let (channels, mut hypervisor) = ReadFrameChannels::pair(stream_id); + let mut stream = channels.into_boxed_message_stream(lifecycle()); + + let read = stream.next(); + tokio::pin!(read); + assert!(read.as_mut().now_or_never().is_none()); + assert_eq!(hypervisor.next().await.unwrap().unwrap(), ReadCommand::Pull); + let send = hypervisor.send(ReadEvent::Push { + data: Bytes::from_static(b"boxed"), + }); + let (send, received) = tokio::join!(send, read); + send.expect("boxed read event should send"); + assert_eq!( + received.unwrap().expect("boxed read should succeed"), + Bytes::from_static(b"boxed") + ); + + let stream_id = VarInt::from_u32(908); + let (channels, mut hypervisor) = ReadFrameChannels::pair(stream_id); + let reader = channels.into_box_reader(lifecycle()); + let read_all = async { + let mut buf = Vec::new(); + StreamReader::new(reader) + .read_to_end(&mut buf) + .await + .expect("box reader should read"); + buf + }; + let drive = async { + assert_eq!(hypervisor.next().await.unwrap().unwrap(), ReadCommand::Pull); + hypervisor + .send(ReadEvent::Push { + data: Bytes::from_static(b"reader"), + }) + .await + .expect("box reader chunk should send"); + assert_eq!(hypervisor.next().await.unwrap().unwrap(), ReadCommand::Pull); + hypervisor + .send(ReadEvent::Eos) + .await + .expect("box reader eos should send"); + }; + let (buf, ()) = tokio::join!(read_all, drive); + assert_eq!(buf, b"reader"); } - /// Convert into a [`BoxMessageStreamWriter`] (implements [`AsyncWrite`](tokio::io::AsyncWrite)). - pub fn into_box_writer(self) -> BoxMessageStreamWriter<'static> { - crate::codec::SinkWriter::new(self.into_boxed_message_stream()) + #[tokio::test] + async fn write_message_box_adapters_use_frame_channels() { + let stream_id = VarInt::from_u32(909); + let (channels, mut hypervisor) = WriteFrameChannels::pair(stream_id); + let mut stream = channels.into_boxed_message_stream(lifecycle()); + send_write_credit(&mut hypervisor, &mut stream).await; + + Pin::new(&mut stream) + .start_send(Bytes::from_static(b"boxed")) + .expect("boxed start send should succeed"); + let flush = stream.flush(); + let drive = async { + assert_eq!( + hypervisor.next().await.unwrap().unwrap(), + WriteCommand::Push { + data: Bytes::from_static(b"boxed") + } + ); + assert_eq!( + hypervisor.next().await.unwrap().unwrap(), + WriteCommand::Flush + ); + hypervisor + .send(WriteEvent::FlushAck) + .await + .expect("boxed flush ack should send"); + }; + let (flush, ()) = tokio::join!(flush, drive); + flush.expect("boxed flush should succeed"); + + let stream_id = VarInt::from_u32(910); + let (channels, mut hypervisor) = WriteFrameChannels::pair(stream_id); + let mut writer = SinkWriter::new(channels.into_box_writer(lifecycle())); + let write_all = async { + writer + .write_all(b"writer") + .await + .expect("box writer should write"); + tokio::io::AsyncWriteExt::flush(&mut writer) + .await + .expect("box writer should flush"); + }; + let drive = async { + let ready = hypervisor.send(WriteEvent::Pull); + ready.await.expect("box writer credit should send"); + assert_eq!( + hypervisor.next().await.unwrap().unwrap(), + WriteCommand::Push { + data: Bytes::from_static(b"writer") + } + ); + assert_eq!( + hypervisor.next().await.unwrap().unwrap(), + WriteCommand::Flush + ); + hypervisor + .send(WriteEvent::FlushAck) + .await + .expect("box writer flush ack should send"); + }; + let ((), ()) = tokio::join!(write_all, drive); } } diff --git a/src/rpc/quic.rs b/src/rpc/quic.rs index a18b2e8..85fb3b6 100644 --- a/src/rpc/quic.rs +++ b/src/rpc/quic.rs @@ -21,8 +21,8 @@ //! the wire and used to make RPC calls to the remote side. //! //! - [`ConnectionClient`] — RTC client for a remote QUIC connection -//! - [`ReadStreamClient`] — RTC client for a remote readable QUIC stream -//! - [`WriteStreamClient`] — RTC client for a remote writable QUIC stream +//! - [`ReadFrameChannels`] — typed frame channels for a remote readable QUIC stream +//! - [`WriteFrameChannels`] — typed frame channels for a remote writable QUIC stream //! - [`ConnectClient`] — RTC client for a remote QUIC connector //! - [`ListenClient`] — RTC client for a remote QUIC listener //! @@ -37,12 +37,8 @@ //! - [`ConnectionClient::into_quic`] — convert into a [`RemoteConnection`] //! - [`ConnectClient::into_quic`] — convert into a [`RemoteConnector`] //! - [`ListenClient::into_quic`] — convert into a [`RemoteListener`] -//! - [`ReadStreamClient::into_quic`] — convert into a `quic::ReadStream` -//! - [`ReadStreamClient::into_boxed_quic`] — convert into a boxed reader adapter -//! - [`WriteStreamClient::into_quic`] — convert into a `quic::WriteStream` -//! - [`WriteStreamClient::into_boxed_quic`] — convert into a boxed writer adapter -mod agent; +mod authority; mod connect; mod connection; mod error; @@ -52,12 +48,13 @@ mod stream; // Raw remoc-generated RTC client types (serializable, sendable over the wire) pub use self::{ - agent::{ - CachedLocalAgent, CachedRemoteAgent, LocalAgentClient, LocalAgentReqReceiver, - LocalAgentServer, LocalAgentServerRef, LocalAgentServerRefMut, LocalAgentServerShared, - LocalAgentServerSharedMut, RemoteAgentClient, RemoteAgentReqReceiver, RemoteAgentServer, - RemoteAgentServerRef, RemoteAgentServerRefMut, RemoteAgentServerShared, - RemoteAgentServerSharedMut, + authority::{ + CachedLocalAuthority, CachedRemoteAuthority, LocalAuthorityClient, + LocalAuthorityReqReceiver, LocalAuthorityServer, LocalAuthorityServerRef, + LocalAuthorityServerRefMut, LocalAuthorityServerShared, LocalAuthorityServerSharedMut, + RemoteAuthorityClient, RemoteAuthorityReqReceiver, RemoteAuthorityServer, + RemoteAuthorityServerRef, RemoteAuthorityServerRefMut, RemoteAuthorityServerShared, + RemoteAuthorityServerSharedMut, }, connect::{ ConnectClient, ConnectError, ConnectReqReceiver, ConnectServer, ConnectServerRef, @@ -72,9 +69,5 @@ pub use self::{ ListenClient, ListenError, ListenReqReceiver, ListenServer, ListenServerRefMut, ListenServerSharedMut, RemoteListener, }, - stream::{ - ReadStreamClient, ReadStreamReqReceiver, ReadStreamServer, ReadStreamServerRefMut, - ReadStreamServerSharedMut, WriteStreamClient, WriteStreamReqReceiver, WriteStreamServer, - WriteStreamServerRefMut, WriteStreamServerSharedMut, - }, + stream::{ReadFrameChannels, WriteFrameChannels}, }; diff --git a/src/rpc/quic/agent.rs b/src/rpc/quic/agent.rs deleted file mode 100644 index 9874a04..0000000 --- a/src/rpc/quic/agent.rs +++ /dev/null @@ -1,307 +0,0 @@ -use futures::future::BoxFuture; -use rustls::{ - SignatureScheme, - pki_types::{CertificateDer, SubjectPublicKeyInfoDer}, -}; - -use super::serde_types::{ - SerdeCertificateDer, SerdeSignatureAlgorithm, SerdeSignatureScheme, - SerdeSubjectPublicKeyInfoDer, -}; -use crate::quic::{ - self, - agent::{self, SignError, VerifyError}, -}; - -/// Remote trait for [`agent::LocalAgent`], exposing all 6 methods over remoc RTC. -#[remoc::rtc::remote] -pub trait LocalAgent: Send + Sync { - async fn name(&self) -> Result; - async fn cert_chain(&self) -> Result, quic::ConnectionError>; - async fn sign_algorithm(&self) -> Result; - async fn sign( - &self, - scheme: SerdeSignatureScheme, - data: Vec, - ) -> Result, quic::ConnectionError>; - async fn public_key(&self) -> Result; - async fn verify( - &self, - scheme: SerdeSignatureScheme, - data: Vec, - signature: Vec, - ) -> Result; -} - -/// Remote trait for [`agent::RemoteAgent`], exposing all 4 methods over remoc RTC. -#[remoc::rtc::remote] -pub trait RemoteAgent: Send + Sync { - async fn name(&self) -> Result; - async fn cert_chain(&self) -> Result, quic::ConnectionError>; - async fn public_key(&self) -> Result; - async fn verify( - &self, - scheme: SerdeSignatureScheme, - data: Vec, - signature: Vec, - ) -> Result; -} - -pub struct CachedLocalAgent { - client: LocalAgentClient, - name: String, - cert_chain: Vec>, - sign_algorithm: rustls::SignatureAlgorithm, -} - -impl std::fmt::Debug for CachedLocalAgent { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("CachedRemoteLocalAgent") - .field("name", &self.name) - .field("sign_algorithm", &self.sign_algorithm) - .finish_non_exhaustive() - } -} - -impl CachedLocalAgent { - /// Create a new cached wrapper by eagerly fetching synchronous fields from - /// the remote agent. - pub async fn from_client(client: LocalAgentClient) -> Result { - let name = client.name().await?; - let cert_chain: Vec> = client - .cert_chain() - .await? - .into_iter() - .map(Into::into) - .collect(); - let sign_algorithm: rustls::SignatureAlgorithm = client.sign_algorithm().await?.into(); - Ok(Self { - client, - name, - cert_chain, - sign_algorithm, - }) - } -} - -impl agent::LocalAgent for CachedLocalAgent { - fn name(&self) -> &str { - &self.name - } - - fn cert_chain(&self) -> &[CertificateDer<'static>] { - &self.cert_chain - } - - fn sign_algorithm(&self) -> rustls::SignatureAlgorithm { - self.sign_algorithm - } - - fn sign( - &self, - scheme: SignatureScheme, - data: &[u8], - ) -> BoxFuture<'_, Result, SignError>> { - let serde_scheme = SerdeSignatureScheme::from(scheme); - let owned_data = data.to_vec(); - let client = self.client.clone(); - Box::pin(async move { - client - .sign(serde_scheme, owned_data) - .await - // lossy: rustls API requires String for General error variant - .map_err(|e| SignError::Crypto { - source: rustls::Error::General(e.to_string()), - }) - }) - } - - fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { - agent::extract_public_key(quic::agent::LocalAgent::cert_chain(self)) - } - - fn verify( - &self, - scheme: SignatureScheme, - data: &[u8], - signature: &[u8], - ) -> BoxFuture<'_, Result> { - let result = agent::verify_signature( - quic::agent::LocalAgent::public_key(self), - scheme, - data, - signature, - ); - Box::pin(std::future::ready(result)) - } -} - -pub struct CachedRemoteAgent { - #[allow(dead_code)] - client: RemoteAgentClient, - name: String, - cert_chain: Vec>, -} - -impl std::fmt::Debug for CachedRemoteAgent { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("CachedRemoteRemoteAgent") - .field("name", &self.name) - .finish_non_exhaustive() - } -} - -impl CachedRemoteAgent { - /// Create a new cached wrapper by eagerly fetching synchronous fields from - /// the remote agent. - pub async fn from_client(client: RemoteAgentClient) -> Result { - let name = client.name().await?; - let cert_chain: Vec> = client - .cert_chain() - .await? - .into_iter() - .map(Into::into) - .collect(); - Ok(Self { - client, - name, - cert_chain, - }) - } -} - -impl agent::RemoteAgent for CachedRemoteAgent { - fn name(&self) -> &str { - &self.name - } - - fn cert_chain(&self) -> &[CertificateDer<'static>] { - &self.cert_chain - } - - fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { - agent::extract_public_key(quic::agent::RemoteAgent::cert_chain(self)) - } - - fn verify( - &self, - scheme: SignatureScheme, - data: &[u8], - signature: &[u8], - ) -> BoxFuture<'_, Result> { - let result = agent::verify_signature( - quic::agent::RemoteAgent::public_key(self), - scheme, - data, - signature, - ); - Box::pin(std::future::ready(result)) - } -} - -impl LocalAgent for A -where - A: agent::LocalAgent + Send + Sync, -{ - async fn name(&self) -> Result { - Ok(agent::LocalAgent::name(self).to_owned()) - } - - async fn cert_chain(&self) -> Result, quic::ConnectionError> { - Ok(self - .cert_chain() - .iter() - .cloned() - .map(SerdeCertificateDer::from) - .collect()) - } - - async fn sign_algorithm(&self) -> Result { - Ok(SerdeSignatureAlgorithm::from( - agent::LocalAgent::sign_algorithm(self), - )) - } - - async fn sign( - &self, - scheme: SerdeSignatureScheme, - data: Vec, - ) -> Result, quic::ConnectionError> { - agent::LocalAgent::sign(self, SignatureScheme::from(scheme), &data) - .await - // lossy: TransportError.reason is a protocol string field - .map_err(|e| quic::ConnectionError::Transport { - source: quic::TransportError { - kind: crate::varint::VarInt::from_u32(0x01), - frame_type: crate::varint::VarInt::from_u32(0x00), - reason: format!("sign error: {e}").into(), - }, - }) - } - - async fn public_key(&self) -> Result { - Ok(SerdeSubjectPublicKeyInfoDer::from( - agent::LocalAgent::public_key(self), - )) - } - - async fn verify( - &self, - scheme: SerdeSignatureScheme, - data: Vec, - signature: Vec, - ) -> Result { - agent::LocalAgent::verify(self, SignatureScheme::from(scheme), &data, &signature) - .await - // lossy: TransportError.reason is a protocol string field - .map_err(|e| quic::ConnectionError::Transport { - source: quic::TransportError { - kind: crate::varint::VarInt::from_u32(0x01), - frame_type: crate::varint::VarInt::from_u32(0x00), - reason: format!("verify error: {e}").into(), - }, - }) - } -} - -impl RemoteAgent for A -where - A: agent::RemoteAgent + Send + Sync, -{ - async fn name(&self) -> Result { - Ok(agent::RemoteAgent::name(self).to_owned()) - } - - async fn cert_chain(&self) -> Result, quic::ConnectionError> { - Ok(self - .cert_chain() - .iter() - .cloned() - .map(SerdeCertificateDer::from) - .collect()) - } - - async fn public_key(&self) -> Result { - Ok(SerdeSubjectPublicKeyInfoDer::from( - agent::RemoteAgent::public_key(self), - )) - } - - async fn verify( - &self, - scheme: SerdeSignatureScheme, - data: Vec, - signature: Vec, - ) -> Result { - agent::RemoteAgent::verify(self, SignatureScheme::from(scheme), &data, &signature) - .await - // lossy: TransportError.reason is a protocol string field - .map_err(|e| quic::ConnectionError::Transport { - source: quic::TransportError { - kind: crate::varint::VarInt::from_u32(0x01), - frame_type: crate::varint::VarInt::from_u32(0x00), - reason: format!("verify error: {e}").into(), - }, - }) - } -} diff --git a/src/rpc/quic/authority.rs b/src/rpc/quic/authority.rs new file mode 100644 index 0000000..a74b037 --- /dev/null +++ b/src/rpc/quic/authority.rs @@ -0,0 +1,569 @@ +use dhttp_identity::identity::{self as authority, SignError, VerifyError}; +use futures::future::BoxFuture; +use rustls::pki_types::{CertificateDer, SubjectPublicKeyInfoDer}; + +use super::serde_types::{SerdeCertificateDer, SerdeSubjectPublicKeyInfoDer}; +use crate::quic; + +/// Remote trait for [`authority::LocalAuthority`], exposing authority methods over remoc RTC. +#[remoc::rtc::remote] +pub trait LocalAuthority: Send + Sync { + async fn name(&self) -> Result; + async fn cert_chain(&self) -> Result, quic::ConnectionError>; + async fn sign(&self, data: Vec) -> Result, quic::ConnectionError>; + async fn public_key(&self) -> Result; + async fn verify( + &self, + data: Vec, + signature: Vec, + ) -> Result; +} + +/// Remote trait for [`authority::RemoteAuthority`], exposing authority methods over remoc RTC. +#[remoc::rtc::remote] +pub trait RemoteAuthority: Send + Sync { + async fn name(&self) -> Result; + async fn cert_chain(&self) -> Result, quic::ConnectionError>; + async fn public_key(&self) -> Result; + async fn verify( + &self, + data: Vec, + signature: Vec, + ) -> Result; +} + +pub struct CachedLocalAuthority { + client: LocalAuthorityClient, + name: String, + cert_chain: Vec>, +} + +impl std::fmt::Debug for CachedLocalAuthority { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CachedRemoteLocalAuthority") + .field("name", &self.name) + .finish_non_exhaustive() + } +} + +impl CachedLocalAuthority { + /// Create a new cached wrapper by eagerly fetching synchronous fields from + /// the remote authority. + pub async fn from_client(client: LocalAuthorityClient) -> Result { + let name = client.name().await?; + let cert_chain: Vec> = client + .cert_chain() + .await? + .into_iter() + .map(Into::into) + .collect(); + Ok(Self { + client, + name, + cert_chain, + }) + } +} + +impl authority::LocalAuthority for CachedLocalAuthority { + fn name(&self) -> &str { + &self.name + } + + fn cert_chain(&self) -> &[CertificateDer<'static>] { + &self.cert_chain + } + + fn sign(&self, data: &[u8]) -> BoxFuture<'_, Result, SignError>> { + let owned_data = data.to_vec(); + let client = self.client.clone(); + Box::pin(async move { + client + .sign(owned_data) + .await + // lossy: rustls API requires String for General error variant + .map_err(|e| SignError::Crypto { + source: rustls::Error::General(e.to_string()), + }) + }) + } + + fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { + authority::extract_public_key(authority::LocalAuthority::cert_chain(self)) + } + + fn verify(&self, data: &[u8], signature: &[u8]) -> BoxFuture<'_, Result> { + let result = authority::verify_signature( + authority::LocalAuthority::public_key(self), + data, + signature, + ); + Box::pin(std::future::ready(result)) + } +} + +pub struct CachedRemoteAuthority { + #[allow(dead_code)] + client: RemoteAuthorityClient, + name: String, + cert_chain: Vec>, +} + +impl std::fmt::Debug for CachedRemoteAuthority { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CachedRemoteRemoteAuthority") + .field("name", &self.name) + .finish_non_exhaustive() + } +} + +impl CachedRemoteAuthority { + /// Create a new cached wrapper by eagerly fetching synchronous fields from + /// the remote authority. + pub async fn from_client(client: RemoteAuthorityClient) -> Result { + let name = client.name().await?; + let cert_chain: Vec> = client + .cert_chain() + .await? + .into_iter() + .map(Into::into) + .collect(); + Ok(Self { + client, + name, + cert_chain, + }) + } +} + +impl authority::RemoteAuthority for CachedRemoteAuthority { + fn name(&self) -> &str { + &self.name + } + + fn cert_chain(&self) -> &[CertificateDer<'static>] { + &self.cert_chain + } + + fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { + authority::extract_public_key(authority::RemoteAuthority::cert_chain(self)) + } + + fn verify(&self, data: &[u8], signature: &[u8]) -> BoxFuture<'_, Result> { + let result = authority::verify_signature( + authority::RemoteAuthority::public_key(self), + data, + signature, + ); + Box::pin(std::future::ready(result)) + } +} + +impl LocalAuthority for A +where + A: authority::LocalAuthority + Send + Sync, +{ + async fn name(&self) -> Result { + Ok(authority::LocalAuthority::name(self).to_owned()) + } + + async fn cert_chain(&self) -> Result, quic::ConnectionError> { + Ok(self + .cert_chain() + .iter() + .cloned() + .map(SerdeCertificateDer::from) + .collect()) + } + + async fn sign(&self, data: Vec) -> Result, quic::ConnectionError> { + authority::LocalAuthority::sign(self, &data) + .await + // lossy: TransportError.reason is a protocol string field + .map_err(|e| quic::ConnectionError::Transport { + source: quic::TransportError { + kind: crate::varint::VarInt::from_u32(0x01), + frame_type: crate::varint::VarInt::from_u32(0x00), + reason: format!("sign error: {e}").into(), + }, + }) + } + + async fn public_key(&self) -> Result { + Ok(SerdeSubjectPublicKeyInfoDer::from( + authority::LocalAuthority::public_key(self), + )) + } + + async fn verify( + &self, + data: Vec, + signature: Vec, + ) -> Result { + authority::LocalAuthority::verify(self, &data, &signature) + .await + // lossy: TransportError.reason is a protocol string field + .map_err(|e| quic::ConnectionError::Transport { + source: quic::TransportError { + kind: crate::varint::VarInt::from_u32(0x01), + frame_type: crate::varint::VarInt::from_u32(0x00), + reason: format!("verify error: {e}").into(), + }, + }) + } +} + +impl RemoteAuthority for A +where + A: authority::RemoteAuthority + Send + Sync, +{ + async fn name(&self) -> Result { + Ok(authority::RemoteAuthority::name(self).to_owned()) + } + + async fn cert_chain(&self) -> Result, quic::ConnectionError> { + Ok(self + .cert_chain() + .iter() + .cloned() + .map(SerdeCertificateDer::from) + .collect()) + } + + async fn public_key(&self) -> Result { + Ok(SerdeSubjectPublicKeyInfoDer::from( + authority::RemoteAuthority::public_key(self), + )) + } + + async fn verify( + &self, + data: Vec, + signature: Vec, + ) -> Result { + authority::RemoteAuthority::verify(self, &data, &signature) + .await + // lossy: TransportError.reason is a protocol string field + .map_err(|e| quic::ConnectionError::Transport { + source: quic::TransportError { + kind: crate::varint::VarInt::from_u32(0x01), + frame_type: crate::varint::VarInt::from_u32(0x00), + reason: format!("verify error: {e}").into(), + }, + }) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use remoc::prelude::ServerShared; + use tokio_util::task::AbortOnDropHandle; + use tracing::Instrument; + + use super::*; + use crate::dquic::cert::handy::ToCertificate; + + const SERVER_CERT: &[u8] = include_bytes!("../../../tests/keychain/localhost/server.cert"); + + #[derive(Clone, Debug)] + struct TestLocalAuthority { + name: &'static str, + cert_chain: Vec>, + fail_sign: bool, + } + + impl TestLocalAuthority { + fn new(name: &'static str) -> Self { + Self { + name, + cert_chain: SERVER_CERT.to_certificate(), + fail_sign: false, + } + } + + fn failing_signer() -> Self { + Self { + fail_sign: true, + ..Self::new("failing-local") + } + } + + fn invalid_cert_chain(name: &'static str) -> Self { + Self { + name, + cert_chain: vec![CertificateDer::from(vec![0x01, 0x02, 0x03])], + fail_sign: false, + } + } + } + + impl authority::LocalAuthority for TestLocalAuthority { + fn name(&self) -> &str { + self.name + } + + fn cert_chain(&self) -> &[CertificateDer<'static>] { + &self.cert_chain + } + + fn sign(&self, data: &[u8]) -> BoxFuture<'_, Result, SignError>> { + let fail_sign = self.fail_sign; + let data = data.to_vec(); + Box::pin(async move { + if fail_sign { + return Err(SignError::UnsupportedKey); + } + Ok(expected_signature(&data)) + }) + } + } + + #[derive(Clone, Debug)] + struct TestRemoteAuthority { + name: &'static str, + cert_chain: Vec>, + } + + impl TestRemoteAuthority { + fn new(name: &'static str) -> Self { + Self { + name, + cert_chain: SERVER_CERT.to_certificate(), + } + } + + fn invalid_cert_chain(name: &'static str) -> Self { + Self { + name, + cert_chain: vec![CertificateDer::from(vec![0x04, 0x05, 0x06])], + } + } + } + + impl authority::RemoteAuthority for TestRemoteAuthority { + fn name(&self) -> &str { + self.name + } + + fn cert_chain(&self) -> &[CertificateDer<'static>] { + &self.cert_chain + } + } + + fn expected_signature(data: &[u8]) -> Vec { + let mut signature = b"canonical:".to_vec(); + signature.extend_from_slice(data); + signature + } + + fn assert_transport_reason(error: quic::ConnectionError, fragment: &str) { + let quic::ConnectionError::Transport { source } = error else { + panic!("expected transport error"); + }; + assert!( + source.reason.contains(fragment), + "transport reason {:?} should contain {fragment:?}", + source.reason, + ); + } + + fn spawn_local_authority_server( + authority: TestLocalAuthority, + ) -> (AbortOnDropHandle<()>, LocalAuthorityClient) { + let (server, client) = LocalAuthorityServerShared::new(Arc::new(authority), 1); + let task = AbortOnDropHandle::new(tokio::spawn( + async move { + let _ = server.serve(true).await; + } + .in_current_span(), + )); + (task, client) + } + + fn spawn_remote_authority_server( + authority: TestRemoteAuthority, + ) -> (AbortOnDropHandle<()>, RemoteAuthorityClient) { + let (server, client) = RemoteAuthorityServerShared::new(Arc::new(authority), 1); + let task = AbortOnDropHandle::new(tokio::spawn( + async move { + let _ = server.serve(true).await; + } + .in_current_span(), + )); + (task, client) + } + + #[tokio::test] + async fn blanket_local_authority_delegates_all_methods() { + let authority = TestLocalAuthority::new("local.example"); + + assert_eq!( + super::LocalAuthority::name(&authority).await.expect("name"), + "local.example", + ); + let certs = super::LocalAuthority::cert_chain(&authority) + .await + .expect("cert chain"); + let cert = CertificateDer::from(certs.into_iter().next().expect("certificate")); + assert_eq!(cert.as_ref(), authority.cert_chain[0].as_ref()); + + let signature = super::LocalAuthority::sign(&authority, b"payload".to_vec()) + .await + .expect("sign"); + assert_eq!(signature, expected_signature(b"payload")); + + let public_key = SubjectPublicKeyInfoDer::from( + super::LocalAuthority::public_key(&authority) + .await + .expect("public key"), + ); + assert_eq!( + public_key.as_ref(), + authority::LocalAuthority::public_key(&authority).as_ref(), + ); + + let verified = super::LocalAuthority::verify( + &authority, + b"payload".to_vec(), + b"not a real signature".to_vec(), + ) + .await + .expect("verify"); + assert!(!verified); + } + + #[tokio::test] + async fn blanket_remote_authority_delegates_all_methods() { + let authority = TestRemoteAuthority::new("remote.example"); + + assert_eq!( + super::RemoteAuthority::name(&authority) + .await + .expect("name"), + "remote.example", + ); + let certs = super::RemoteAuthority::cert_chain(&authority) + .await + .expect("cert chain"); + let cert = CertificateDer::from(certs.into_iter().next().expect("certificate")); + assert_eq!(cert.as_ref(), authority.cert_chain[0].as_ref()); + + let public_key = SubjectPublicKeyInfoDer::from( + super::RemoteAuthority::public_key(&authority) + .await + .expect("public key"), + ); + assert_eq!( + public_key.as_ref(), + authority::RemoteAuthority::public_key(&authority).as_ref(), + ); + + let verified = super::RemoteAuthority::verify( + &authority, + b"payload".to_vec(), + b"not a real signature".to_vec(), + ) + .await + .expect("verify"); + assert!(!verified); + } + + #[tokio::test] + async fn blanket_authority_errors_become_transport_errors() { + let local = TestLocalAuthority::failing_signer(); + let local_with_invalid_key = TestLocalAuthority::invalid_cert_chain("invalid-local"); + let remote_with_invalid_key = TestRemoteAuthority::invalid_cert_chain("invalid-remote"); + + let error = super::LocalAuthority::sign(&local, b"payload".to_vec()) + .await + .expect_err("sign error should be mapped"); + assert_transport_reason(error, "sign error"); + + let error = super::LocalAuthority::verify( + &local_with_invalid_key, + b"payload".to_vec(), + b"signature".to_vec(), + ) + .await + .expect_err("local verify error should be mapped"); + assert_transport_reason(error, "verify error"); + + let error = super::RemoteAuthority::verify( + &remote_with_invalid_key, + b"payload".to_vec(), + b"signature".to_vec(), + ) + .await + .expect_err("remote verify error should be mapped"); + assert_transport_reason(error, "verify error"); + } + + #[tokio::test] + async fn cached_local_authority_fetches_remote_fields_and_delegates_sign() { + let authority = TestLocalAuthority::new("cached-local.example"); + let (_task, client) = spawn_local_authority_server(authority.clone()); + + let cached = CachedLocalAuthority::from_client(client) + .await + .expect("cached local authority"); + + assert_eq!(authority::LocalAuthority::name(&cached), authority.name); + assert_eq!( + authority::LocalAuthority::cert_chain(&cached)[0].as_ref(), + authority.cert_chain[0].as_ref(), + ); + assert!( + format!("{cached:?}").contains("CachedRemoteLocalAuthority"), + "debug output should name cached local authority", + ); + + let signature = authority::LocalAuthority::sign(&cached, b"payload") + .await + .expect("cached sign"); + assert_eq!(signature, expected_signature(b"payload")); + + let public_key = authority::LocalAuthority::public_key(&cached); + assert_eq!( + public_key.as_ref(), + authority::LocalAuthority::public_key(&authority).as_ref() + ); + let verified = + authority::LocalAuthority::verify(&cached, b"payload", b"not a real signature") + .await + .expect("cached verify"); + assert!(!verified); + } + + #[tokio::test] + async fn cached_remote_authority_fetches_remote_fields() { + let authority = TestRemoteAuthority::new("cached-remote.example"); + let (_task, client) = spawn_remote_authority_server(authority.clone()); + + let cached = CachedRemoteAuthority::from_client(client) + .await + .expect("cached remote authority"); + + assert_eq!(authority::RemoteAuthority::name(&cached), authority.name); + assert_eq!( + authority::RemoteAuthority::cert_chain(&cached)[0].as_ref(), + authority.cert_chain[0].as_ref(), + ); + assert!( + format!("{cached:?}").contains("CachedRemoteRemoteAuthority"), + "debug output should name cached remote authority", + ); + + let public_key = authority::RemoteAuthority::public_key(&cached); + assert_eq!( + public_key.as_ref(), + authority::RemoteAuthority::public_key(&authority).as_ref(), + ); + let verified = + authority::RemoteAuthority::verify(&cached, b"payload", b"not a real signature") + .await + .expect("cached verify"); + assert!(!verified); + } +} diff --git a/src/rpc/quic/connect.rs b/src/rpc/quic/connect.rs index c8219eb..f2eea55 100644 --- a/src/rpc/quic/connect.rs +++ b/src/rpc/quic/connect.rs @@ -32,8 +32,8 @@ pub trait Connect: Send + Sync { impl Connect for C where C: quic::Connect + 'static, - ::LocalAgent: Send + Sync, - ::RemoteAgent: Send + Sync, + ::LocalAuthority: Send + Sync, + ::RemoteAuthority: Send + Sync, { async fn connect(&self, server: SerdeAuthority) -> Result { // lossy: cross-process serialization boundary @@ -42,6 +42,9 @@ where .await .map_err(|e| StringError::new(e.to_string()))?; let (server, client) = ConnectionServerShared::new(connection, 1); + // Inherent termination: the returned ConnectionClient owns the remoc + // endpoint; when that client is dropped or the channel closes, + // server.serve exits. tokio::spawn( (async move { let _ = server.serve(true).await; @@ -105,3 +108,147 @@ impl quic::Connect for RemoteConnector { Ok(Arc::new(RemoteConnection::from(client))) } } + +#[cfg(test)] +mod tests { + use std::{ + collections::VecDeque, + sync::{ + Arc, Mutex, + atomic::{AtomicUsize, Ordering}, + }, + }; + + use remoc::prelude::ServerShared; + use tokio_util::task::AbortOnDropHandle; + use tracing::Instrument; + + use super::*; + use crate::connection::tests::MockConnection; + + #[derive(Debug, snafu::Snafu)] + #[snafu(display("test connector failed"))] + struct TestConnectError; + + #[derive(Debug, Default)] + struct TestConnector { + connections: Mutex>>, + calls: AtomicUsize, + last_server: Mutex>, + } + + impl TestConnector { + fn with_connections(connections: impl IntoIterator>) -> Self { + Self { + connections: Mutex::new(connections.into_iter().collect()), + calls: AtomicUsize::default(), + last_server: Mutex::default(), + } + } + + fn call_count(&self) -> usize { + self.calls.load(Ordering::Relaxed) + } + + fn last_server(&self) -> Option { + self.last_server + .lock() + .expect("last server mutex should not be poisoned") + .clone() + } + } + + impl quic::Connect for TestConnector { + type Connection = MockConnection; + type Error = TestConnectError; + + async fn connect<'a>( + &'a self, + server: &'a Authority, + ) -> Result, Self::Error> { + self.calls.fetch_add(1, Ordering::Relaxed); + *self + .last_server + .lock() + .expect("last server mutex should not be poisoned") = Some(server.clone()); + self.connections + .lock() + .expect("connection queue mutex should not be poisoned") + .pop_front() + .ok_or(TestConnectError) + } + } + + #[tokio::test] + async fn blanket_connect_returns_connection_client() { + let authority = Authority::from_static("example.test:443"); + let connector = TestConnector::with_connections([Arc::new(MockConnection::new())]); + + let client = Connect::connect(&connector, SerdeAuthority::from(&authority)) + .await + .expect("blanket connect should return a remote connection client"); + + assert_eq!(connector.call_count(), 1); + assert_eq!(connector.last_server(), Some(authority)); + + let remote = RemoteConnection::from(client); + assert!(quic::Lifecycle::check(&remote).is_ok()); + } + + #[tokio::test] + async fn blanket_connect_stringifies_transport_errors() { + let authority = Authority::from_static("example.test:443"); + let connector = TestConnector::default(); + + let error = Connect::connect(&connector, SerdeAuthority::from(&authority)) + .await + .expect_err("empty connector should fail"); + + let ConnectError::Remote { source } = error else { + panic!("connector error should cross RPC boundary as remote error"); + }; + assert_eq!(source.as_str(), "test connector failed"); + } + + #[tokio::test] + async fn remote_connector_delegates_to_connect_client() { + let authority = Authority::from_static("example.test:443"); + let connector = Arc::new(TestConnector::with_connections([Arc::new( + MockConnection::new(), + )])); + let (server, client) = ConnectServerShared::new(connector.clone(), 1); + let _server_task = AbortOnDropHandle::new(tokio::spawn( + async move { + let _ = server.serve(true).await; + } + .in_current_span(), + )); + let remote = RemoteConnector::new(client); + + let connection = quic::Connect::connect(&remote, &authority) + .await + .expect("remote connector should connect through RTC"); + + assert_eq!(connector.call_count(), 1); + assert!(quic::Lifecycle::check(connection.as_ref()).is_ok()); + } + + #[tokio::test] + async fn remote_connector_conversions_preserve_client_handle() { + let connector = Arc::new(TestConnector::default()); + let (server, client) = ConnectServerShared::new(connector, 1); + let _server_task = AbortOnDropHandle::new(tokio::spawn( + async move { + let _ = server.serve(true).await; + } + .in_current_span(), + )); + + let remote = RemoteConnector::new(client.clone()); + let inner = remote.into_inner(); + let remote = ConnectClient::into_quic(inner.clone()); + let inner: ConnectClient = remote.into(); + let remote = RemoteConnector::from(inner); + let _inner = ConnectClient::from(remote); + } +} diff --git a/src/rpc/quic/connection.rs b/src/rpc/quic/connection.rs index f782b56..48374ed 100644 --- a/src/rpc/quic/connection.rs +++ b/src/rpc/quic/connection.rs @@ -1,34 +1,41 @@ -use std::{borrow::Cow, sync::Arc}; - -use remoc::{ - prelude::{Server, ServerShared}, - rtc::Client as RemocClient, +use std::{ + borrow::Cow, + sync::{Arc, Mutex}, }; + +use remoc::{prelude::ServerShared, rtc::Client as RemocClient}; use serde::{Deserialize, Serialize}; +use tokio_util::task::AbortOnDropHandle; use tracing::Instrument; use super::{ - agent::{CachedLocalAgent, CachedRemoteAgent, LocalAgentClient, RemoteAgentClient}, - stream::{self, ReadStreamClient, WriteStreamClient}, + authority::{ + CachedLocalAuthority, CachedRemoteAuthority, LocalAuthorityClient, RemoteAuthorityClient, + }, + stream::{ReadFrameChannels, WriteFrameChannels}, }; use crate::{ + dhttp::message::guard, error::Code, - quic::{self, ConnectionError}, + quic::{self, BoxQuicStreamReader, BoxQuicStreamWriter, ConnectionError, GetStreamIdExt}, rpc::lifecycle::{ConnectionErrorLatch, HasLatch, LifecycleExt}, varint::VarInt, }; #[remoc::rtc::remote] pub trait Connection: Send + Sync { - async fn open_bi(&self) - -> Result<(ReadStreamClient, WriteStreamClient), quic::ConnectionError>; - async fn open_uni(&self) -> Result; + async fn open_bi( + &self, + ) -> Result<(ReadFrameChannels, WriteFrameChannels), quic::ConnectionError>; + async fn open_uni(&self) -> Result; async fn accept_bi( &self, - ) -> Result<(ReadStreamClient, WriteStreamClient), quic::ConnectionError>; - async fn accept_uni(&self) -> Result; - async fn local_agent(&self) -> Result, quic::ConnectionError>; - async fn remote_agent(&self) -> Result, quic::ConnectionError>; + ) -> Result<(ReadFrameChannels, WriteFrameChannels), quic::ConnectionError>; + async fn accept_uni(&self) -> Result; + async fn local_authority(&self) -> Result, quic::ConnectionError>; + async fn remote_authority( + &self, + ) -> Result, quic::ConnectionError>; async fn close( &self, code: Code, @@ -44,80 +51,41 @@ pub trait Connection: Send + Sync { impl Connection for C where C: quic::Connection + 'static, - C::LocalAgent: Send + Sync, - C::RemoteAgent: Send + Sync, + C::LocalAuthority: Send + Sync, + C::RemoteAuthority: Send + Sync, { async fn open_bi( &self, - ) -> Result<(ReadStreamClient, WriteStreamClient), quic::ConnectionError> { + ) -> Result<(ReadFrameChannels, WriteFrameChannels), quic::ConnectionError> { let (reader, writer) = quic::ManageStream::open_bi(self).await?; - let (rs, rc) = stream::ReadStreamServer::new(Box::pin(reader), 1); - tokio::spawn( - (async move { - let _ = rs.serve().await; - }) - .in_current_span(), - ); - let (ws, wc) = stream::WriteStreamServer::new(Box::pin(writer), 1); - tokio::spawn( - (async move { - let _ = ws.serve().await; - }) - .in_current_span(), - ); - Ok((rc, wc)) + Ok((read_channels(reader).await?, write_channels(writer).await?)) } - async fn open_uni(&self) -> Result { + async fn open_uni(&self) -> Result { let writer = quic::ManageStream::open_uni(self).await?; - let (ws, wc) = stream::WriteStreamServer::new(Box::pin(writer), 1); - tokio::spawn( - (async move { - let _ = ws.serve().await; - }) - .in_current_span(), - ); - Ok(wc) + write_channels(writer).await } async fn accept_bi( &self, - ) -> Result<(ReadStreamClient, WriteStreamClient), quic::ConnectionError> { + ) -> Result<(ReadFrameChannels, WriteFrameChannels), quic::ConnectionError> { let (reader, writer) = quic::ManageStream::accept_bi(self).await?; - let (rs, rc) = stream::ReadStreamServer::new(Box::pin(reader), 1); - tokio::spawn( - (async move { - let _ = rs.serve().await; - }) - .in_current_span(), - ); - let (ws, wc) = stream::WriteStreamServer::new(Box::pin(writer), 1); - tokio::spawn( - (async move { - let _ = ws.serve().await; - }) - .in_current_span(), - ); - Ok((rc, wc)) + Ok((read_channels(reader).await?, write_channels(writer).await?)) } - async fn accept_uni(&self) -> Result { + async fn accept_uni(&self) -> Result { let reader = quic::ManageStream::accept_uni(self).await?; - let (rs, rc) = stream::ReadStreamServer::new(Box::pin(reader), 1); - tokio::spawn( - (async move { - let _ = rs.serve().await; - }) - .in_current_span(), - ); - Ok(rc) + read_channels(reader).await } - async fn local_agent(&self) -> Result, quic::ConnectionError> { - match quic::WithLocalAgent::local_agent(self).await? { + async fn local_authority(&self) -> Result, quic::ConnectionError> { + match quic::WithLocalAuthority::local_authority(self).await? { Some(agent) => { let (server, client) = - super::agent::LocalAgentServerShared::new(Arc::new(agent), 1); + super::authority::LocalAuthorityServerShared::new(Arc::new(agent), 1); + // Inherent termination: the returned authority client owns + // the remoc endpoint; when the client is dropped or the + // channel closes, server.serve exits. tokio::spawn( (async move { let _ = server.serve(true).await; @@ -130,11 +98,16 @@ where } } - async fn remote_agent(&self) -> Result, quic::ConnectionError> { - match quic::WithRemoteAgent::remote_agent(self).await? { + async fn remote_authority( + &self, + ) -> Result, quic::ConnectionError> { + match quic::WithRemoteAuthority::remote_authority(self).await? { Some(agent) => { let (server, client) = - super::agent::RemoteAgentServerShared::new(Arc::new(agent), 1); + super::authority::RemoteAuthorityServerShared::new(Arc::new(agent), 1); + // Inherent termination: the returned authority client owns + // the remoc endpoint; when the client is dropped or the + // channel closes, server.serve exits. tokio::spawn( (async move { let _ = server.serve(true).await; @@ -178,6 +151,8 @@ pub struct RemoteConnection { client: ConnectionClient, #[serde(skip)] latch: ConnectionErrorLatch, + #[serde(skip)] + close_tasks: Arc>>>, } impl RemoteConnection { @@ -185,6 +160,7 @@ impl RemoteConnection { Self { client, latch: ConnectionErrorLatch::new(), + close_tasks: Arc::new(Mutex::new(Vec::new())), } } @@ -230,50 +206,126 @@ impl From for ConnectionClient { } impl quic::ManageStream for RemoteConnection { - type StreamWriter = crate::dhttp::protocol::BoxDynQuicStreamWriter; - type StreamReader = crate::dhttp::protocol::BoxDynQuicStreamReader; + type StreamWriter = crate::dhttp::message::guard::GuardQuicWriter; + type StreamReader = crate::dhttp::message::guard::GuardQuicReader; async fn open_bi(&self) -> Result<(Self::StreamReader, Self::StreamWriter), ConnectionError> { let (reader, writer) = self.guard(Connection::open_bi(&self.client)).await?; - Ok((reader.into_boxed_quic(), writer.into_boxed_quic())) + Ok(( + self.read_channels_into_quic(reader), + self.write_channels_into_quic(writer), + )) } async fn open_uni(&self) -> Result { let writer = self.guard(Connection::open_uni(&self.client)).await?; - Ok(writer.into_boxed_quic()) + Ok(self.write_channels_into_quic(writer)) } async fn accept_bi(&self) -> Result<(Self::StreamReader, Self::StreamWriter), ConnectionError> { let (reader, writer) = self.guard(Connection::accept_bi(&self.client)).await?; - Ok((reader.into_boxed_quic(), writer.into_boxed_quic())) + Ok(( + self.read_channels_into_quic(reader), + self.write_channels_into_quic(writer), + )) } async fn accept_uni(&self) -> Result { let reader = self.guard(Connection::accept_uni(&self.client)).await?; - Ok(reader.into_boxed_quic()) + Ok(self.read_channels_into_quic(reader)) + } +} + +impl RemoteConnection { + fn read_channels_into_quic(&self, channels: ReadFrameChannels) -> guard::GuardQuicReader { + let lifecycle = Arc::new(self.clone()); + let raw = Box::pin(channels.into_quic(lifecycle)) as BoxQuicStreamReader; + guard::GuardQuicReader::new(raw) + } + + fn write_channels_into_quic(&self, channels: WriteFrameChannels) -> guard::GuardQuicWriter { + let lifecycle = Arc::new(self.clone()); + let raw = Box::pin(channels.into_quic(lifecycle)) as BoxQuicStreamWriter; + guard::GuardQuicWriter::new(raw) + } +} + +async fn read_channels(mut reader: R) -> Result +where + R: quic::ReadStream + Unpin + 'static, +{ + let stream_id = match reader.stream_id().await { + Ok(stream_id) => stream_id, + Err(error) => return Err(stream_id_error(error)), + }; + let (channels, bridge) = ReadFrameChannels::pair(stream_id); + // Inherent termination: this task owns the real stream and remoc frame IO. + // It exits when the real stream reaches a terminal state, the worker drops + // the frame channels, or frame IO reports a connection failure. + tokio::spawn( + crate::rpc::stream::hypervisor::read::run_read_bridge(reader, bridge).in_current_span(), + ); + Ok(channels) +} + +async fn write_channels(mut writer: W) -> Result +where + W: quic::WriteStream + Unpin + 'static, +{ + let stream_id = match writer.stream_id().await { + Ok(stream_id) => stream_id, + Err(error) => return Err(stream_id_error(error)), + }; + let (channels, bridge) = WriteFrameChannels::pair(stream_id); + // Inherent termination: this task owns the real stream and remoc frame IO. + // It exits when the real stream reaches a terminal state, the worker drops + // the frame channels, or frame IO reports a connection failure. + tokio::spawn( + crate::rpc::stream::hypervisor::write::run_write_bridge(writer, bridge).in_current_span(), + ); + Ok(channels) +} + +fn stream_id_error(error: quic::StreamError) -> quic::ConnectionError { + match error { + quic::StreamError::Connection { source } => source, + quic::StreamError::Reset { code } => quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(0x0d), + frame_type: VarInt::from_u32(0x00), + reason: format!("rpc stream id reset with code {code}").into(), + }, + }, } } -impl quic::WithLocalAgent for RemoteConnection { - type LocalAgent = CachedLocalAgent; +impl quic::WithLocalAuthority for RemoteConnection { + type LocalAuthority = CachedLocalAuthority; - async fn local_agent(&self) -> Result, ConnectionError> { - match self.guard(Connection::local_agent(&self.client)).await? { + async fn local_authority(&self) -> Result, ConnectionError> { + match self + .guard(Connection::local_authority(&self.client)) + .await? + { Some(agent) => Ok(Some( - self.guard(CachedLocalAgent::from_client(agent)).await?, + self.guard(CachedLocalAuthority::from_client(agent)).await?, )), None => Ok(None), } } } -impl quic::WithRemoteAgent for RemoteConnection { - type RemoteAgent = CachedRemoteAgent; +impl quic::WithRemoteAuthority for RemoteConnection { + type RemoteAuthority = CachedRemoteAuthority; - async fn remote_agent(&self) -> Result, ConnectionError> { - match self.guard(Connection::remote_agent(&self.client)).await? { + async fn remote_authority(&self) -> Result, ConnectionError> { + match self + .guard(Connection::remote_authority(&self.client)) + .await? + { Some(agent) => Ok(Some( - self.guard(CachedRemoteAgent::from_client(agent)).await?, + self.guard(CachedRemoteAuthority::from_client(agent)) + .await?, )), None => Ok(None), } @@ -283,12 +335,18 @@ impl quic::WithRemoteAgent for RemoteConnection { impl quic::Lifecycle for RemoteConnection { fn close(&self, code: Code, reason: Cow<'static, str>) { let client = self.client.clone(); - tokio::spawn( - async move { + let handle = AbortOnDropHandle::new(tokio::spawn( + (async move { let _ = Connection::close(&client, code, reason).await; - } + }) .in_current_span(), - ); + )); + let mut tasks = self + .close_tasks + .lock() + .expect("remote connection close task registry should not be poisoned"); + tasks.retain(|task| !task.is_finished()); + tasks.push(handle); } fn check(&self) -> Result<(), ConnectionError> { @@ -306,3 +364,789 @@ impl quic::Lifecycle for RemoteConnection { .await } } + +#[cfg(test)] +mod tests { + use std::{ + pin::Pin, + sync::Mutex, + task::{Context, Poll}, + time::Duration, + }; + + use bytes::Bytes; + use dhttp_identity::identity::{self as authority, SignError}; + use futures::{Sink, SinkExt, Stream, StreamExt, future::BoxFuture}; + use remoc::prelude::ServerShared; + use rustls::pki_types::CertificateDer; + use tokio_util::task::AbortOnDropHandle; + + use super::*; + use crate::{ + dquic::cert::handy::ToCertificate, + quic::{ + BoxQuicStreamReader, BoxQuicStreamWriter, GetStreamId, GetStreamIdExt, ResetStream, + StopStream, StopStreamExt, + }, + }; + + const SERVER_CERT: &[u8] = include_bytes!("../../../tests/keychain/localhost/server.cert"); + + #[derive(Clone, Debug)] + struct TestLocalAuthority { + name: &'static str, + cert_chain: Vec>, + } + + impl TestLocalAuthority { + fn new(name: &'static str) -> Self { + Self { + name, + cert_chain: SERVER_CERT.to_certificate(), + } + } + } + + impl authority::LocalAuthority for TestLocalAuthority { + fn name(&self) -> &str { + self.name + } + + fn cert_chain(&self) -> &[CertificateDer<'static>] { + &self.cert_chain + } + fn sign(&self, data: &[u8]) -> BoxFuture<'_, Result, SignError>> { + let signature = expected_signature(data); + Box::pin(std::future::ready(Ok(signature))) + } + } + + #[derive(Clone, Debug)] + struct TestRemoteAuthority { + name: &'static str, + cert_chain: Vec>, + } + + impl TestRemoteAuthority { + fn new(name: &'static str) -> Self { + Self { + name, + cert_chain: SERVER_CERT.to_certificate(), + } + } + } + + impl authority::RemoteAuthority for TestRemoteAuthority { + fn name(&self) -> &str { + self.name + } + + fn cert_chain(&self) -> &[CertificateDer<'static>] { + &self.cert_chain + } + } + + struct TestQuicConnection { + open_bi_error: Option, + open_uni_error: Option, + accept_bi_error: Option, + accept_uni_error: Option, + local_authority: Option, + local_authority_error: Option, + remote_authority: Option, + remote_authority_error: Option, + terminal: Mutex>, + closes: Mutex)>>, + } + + impl TestQuicConnection { + fn new() -> Self { + Self { + open_bi_error: None, + open_uni_error: None, + accept_bi_error: None, + accept_uni_error: None, + local_authority: None, + local_authority_error: None, + remote_authority: None, + remote_authority_error: None, + terminal: Mutex::new(None), + closes: Mutex::new(Vec::new()), + } + } + + fn with_agents() -> Self { + Self { + local_authority: Some(TestLocalAuthority::new("local.example")), + remote_authority: Some(TestRemoteAuthority::new("remote.example")), + ..Self::new() + } + } + + fn fail_open_bi(reason: &'static str) -> Self { + Self { + open_bi_error: Some(connection_error(reason)), + ..Self::new() + } + } + + fn fail_open_uni(reason: &'static str) -> Self { + Self { + open_uni_error: Some(connection_error(reason)), + ..Self::new() + } + } + + fn fail_accept_bi(reason: &'static str) -> Self { + Self { + accept_bi_error: Some(connection_error(reason)), + ..Self::new() + } + } + + fn fail_accept_uni(reason: &'static str) -> Self { + Self { + accept_uni_error: Some(connection_error(reason)), + ..Self::new() + } + } + + fn fail_local_authority(reason: &'static str) -> Self { + Self { + local_authority_error: Some(connection_error(reason)), + ..Self::new() + } + } + + fn fail_remote_authority(reason: &'static str) -> Self { + Self { + remote_authority_error: Some(connection_error(reason)), + ..Self::new() + } + } + + fn set_terminal(&self, error: quic::ConnectionError) { + *self + .terminal + .lock() + .expect("terminal mutex should not be poisoned") = Some(error); + } + + fn closes(&self) -> Vec<(Code, Cow<'static, str>)> { + self.closes + .lock() + .expect("closes mutex should not be poisoned") + .clone() + } + } + + impl quic::ManageStream for TestQuicConnection { + type StreamReader = BoxQuicStreamReader; + type StreamWriter = BoxQuicStreamWriter; + + async fn open_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + if let Some(error) = &self.open_bi_error { + return Err(error.clone()); + } + let (reader, writer) = quic::test::mock_stream_pair(VarInt::from_u32(1)); + Ok(( + Box::pin(reader) as BoxQuicStreamReader, + Box::pin(writer) as BoxQuicStreamWriter, + )) + } + + async fn open_uni(&self) -> Result { + if let Some(error) = &self.open_uni_error { + return Err(error.clone()); + } + let (_reader, writer) = quic::test::mock_stream_pair(VarInt::from_u32(2)); + Ok(Box::pin(writer) as BoxQuicStreamWriter) + } + + async fn accept_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + if let Some(error) = &self.accept_bi_error { + return Err(error.clone()); + } + let (reader, writer) = quic::test::mock_stream_pair(VarInt::from_u32(3)); + Ok(( + Box::pin(reader) as BoxQuicStreamReader, + Box::pin(writer) as BoxQuicStreamWriter, + )) + } + + async fn accept_uni(&self) -> Result { + if let Some(error) = &self.accept_uni_error { + return Err(error.clone()); + } + let (reader, _writer) = quic::test::mock_stream_pair(VarInt::from_u32(4)); + Ok(Box::pin(reader) as BoxQuicStreamReader) + } + } + + impl quic::WithLocalAuthority for TestQuicConnection { + type LocalAuthority = TestLocalAuthority; + + async fn local_authority( + &self, + ) -> Result, quic::ConnectionError> { + if let Some(error) = &self.local_authority_error { + return Err(error.clone()); + } + Ok(self.local_authority.clone()) + } + } + + impl quic::WithRemoteAuthority for TestQuicConnection { + type RemoteAuthority = TestRemoteAuthority; + + async fn remote_authority( + &self, + ) -> Result, quic::ConnectionError> { + if let Some(error) = &self.remote_authority_error { + return Err(error.clone()); + } + Ok(self.remote_authority.clone()) + } + } + + impl quic::Lifecycle for TestQuicConnection { + fn close(&self, code: Code, reason: Cow<'static, str>) { + self.closes + .lock() + .expect("closes mutex should not be poisoned") + .push((code, reason)); + } + + fn check(&self) -> Result<(), quic::ConnectionError> { + match self + .terminal + .lock() + .expect("terminal mutex should not be poisoned") + .clone() + { + Some(error) => Err(error), + None => Ok(()), + } + } + + async fn closed(&self) -> quic::ConnectionError { + self.terminal + .lock() + .expect("terminal mutex should not be poisoned") + .clone() + .unwrap_or_else(|| connection_error("quic closed")) + } + } + + #[derive(Clone)] + struct BrokenIdReadStream { + error: quic::StreamError, + } + + impl GetStreamId for BrokenIdReadStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Err(self.get_mut().error.clone())) + } + } + + impl StopStream for BrokenIdReadStream { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _code: VarInt, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl Stream for BrokenIdReadStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(None) + } + } + + #[derive(Clone)] + struct BrokenIdWriteStream { + error: quic::StreamError, + } + + impl GetStreamId for BrokenIdWriteStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Err(self.get_mut().error.clone())) + } + } + + impl ResetStream for BrokenIdWriteStream { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _code: VarInt, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl Sink for BrokenIdWriteStream { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, _item: Bytes) -> Result<(), Self::Error> { + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + struct BrokenIdQuicConnection; + + impl quic::ManageStream for BrokenIdQuicConnection { + type StreamReader = BoxQuicStreamReader; + type StreamWriter = BoxQuicStreamWriter; + + async fn open_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + Ok(( + Box::pin(BrokenIdReadStream { + error: stream_connection_error("open bidi reader stream id failed"), + }) as BoxQuicStreamReader, + Box::pin(BrokenIdWriteStream { + error: stream_connection_error("open bidi writer stream id failed"), + }) as BoxQuicStreamWriter, + )) + } + + async fn open_uni(&self) -> Result { + Ok(Box::pin(BrokenIdWriteStream { + error: stream_connection_error("open uni writer stream id failed"), + }) as BoxQuicStreamWriter) + } + + async fn accept_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + Ok(( + Box::pin(BrokenIdReadStream { + error: stream_connection_error("accept bidi reader stream id failed"), + }) as BoxQuicStreamReader, + Box::pin(BrokenIdWriteStream { + error: stream_connection_error("accept bidi writer stream id failed"), + }) as BoxQuicStreamWriter, + )) + } + + async fn accept_uni(&self) -> Result { + Ok(Box::pin(BrokenIdReadStream { + error: stream_connection_error("accept uni reader stream id failed"), + }) as BoxQuicStreamReader) + } + } + + impl quic::WithLocalAuthority for BrokenIdQuicConnection { + type LocalAuthority = TestLocalAuthority; + + async fn local_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(None) + } + } + + impl quic::WithRemoteAuthority for BrokenIdQuicConnection { + type RemoteAuthority = TestRemoteAuthority; + + async fn remote_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(None) + } + } + + impl quic::Lifecycle for BrokenIdQuicConnection { + fn close(&self, _code: Code, _reason: Cow<'static, str>) {} + + fn check(&self) -> Result<(), quic::ConnectionError> { + Ok(()) + } + + async fn closed(&self) -> quic::ConnectionError { + connection_error("broken id closed") + } + } + + fn spawn_rpc_connection(connection: Arc) -> (AbortOnDropHandle<()>, ConnectionClient) + where + C: super::Connection + 'static, + { + let (server, client) = ConnectionServerShared::new(connection, 1); + let task = AbortOnDropHandle::new(tokio::spawn( + async move { + let _ = server.serve(true).await; + } + .in_current_span(), + )); + (task, client) + } + + fn connection_error(reason: &'static str) -> quic::ConnectionError { + quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(0x01), + frame_type: VarInt::from_u32(0x00), + reason: reason.into(), + }, + } + } + + fn stream_connection_error(reason: &'static str) -> quic::StreamError { + quic::StreamError::Connection { + source: connection_error(reason), + } + } + + fn assert_reason(error: &quic::ConnectionError, expected: &str) { + let quic::ConnectionError::Transport { source } = error else { + panic!("expected transport error"); + }; + assert_eq!(source.reason.as_ref(), expected); + } + + fn expected_signature(data: &[u8]) -> Vec { + let mut signature = b"canonical:".to_vec(); + signature.extend_from_slice(data); + signature + } + + async fn assert_roundtrip( + reader: &mut (impl Stream> + Unpin), + writer: &mut (impl Sink + Unpin), + payload: &'static [u8], + ) { + let bytes = Bytes::from_static(payload); + writer.send(bytes.clone()).await.expect("write"); + let received = reader + .next() + .await + .expect("reader should produce one chunk") + .expect("read"); + assert_eq!(received, bytes); + } + + #[tokio::test] + async fn remote_connection_conversions_preserve_client() { + let connection = Arc::new(TestQuicConnection::new()); + let (_task, client) = spawn_rpc_connection(connection); + + let remote = ConnectionClient::into_quic(client); + let client = remote.clone().into_inner(); + let remote = RemoteConnection::from(client); + let client = ConnectionClient::from(remote); + let remote = client.into_quic(); + + let mut writer = quic::ManageStream::open_uni(&remote) + .await + .expect("open uni"); + assert_eq!( + writer.stream_id().await.expect("stream id"), + VarInt::from_u32(2) + ); + } + + #[tokio::test] + async fn remote_connection_delegates_stream_operations_and_agents() { + let connection = Arc::new(TestQuicConnection::with_agents()); + let (_task, client) = spawn_rpc_connection(connection.clone()); + let remote = client.into_quic(); + + quic::Lifecycle::check(&remote).expect("fresh connection should be live"); + + let (mut reader, mut writer) = quic::ManageStream::open_bi(&remote) + .await + .expect("open bidi"); + assert_eq!( + reader.stream_id().await.expect("reader id"), + VarInt::from_u32(1) + ); + assert_eq!( + writer.stream_id().await.expect("writer id"), + VarInt::from_u32(1) + ); + assert_roundtrip(&mut reader, &mut writer, b"open-bidi").await; + + let mut writer = quic::ManageStream::open_uni(&remote) + .await + .expect("open uni"); + assert_eq!( + writer.stream_id().await.expect("writer id"), + VarInt::from_u32(2) + ); + + let (mut reader, mut writer) = quic::ManageStream::accept_bi(&remote) + .await + .expect("accept bidi"); + assert_eq!( + reader.stream_id().await.expect("reader id"), + VarInt::from_u32(3) + ); + assert_eq!( + writer.stream_id().await.expect("writer id"), + VarInt::from_u32(3) + ); + assert_roundtrip(&mut reader, &mut writer, b"accept-bidi").await; + + let mut reader = quic::ManageStream::accept_uni(&remote) + .await + .expect("accept uni"); + assert_eq!( + reader.stream_id().await.expect("reader id"), + VarInt::from_u32(4) + ); + reader + .stop(VarInt::from_u32(10)) + .await + .expect("stop accepted reader"); + + let local_authority = quic::WithLocalAuthority::local_authority(&remote) + .await + .expect("local authority") + .expect("local authority should exist"); + assert_eq!( + authority::LocalAuthority::name(&local_authority), + "local.example" + ); + assert_eq!( + authority::LocalAuthority::cert_chain(&local_authority).len(), + 1 + ); + let signature = authority::LocalAuthority::sign(&local_authority, b"payload") + .await + .expect("local authority sign"); + assert_eq!(signature, expected_signature(b"payload")); + + let remote_authority = quic::WithRemoteAuthority::remote_authority(&remote) + .await + .expect("remote authority") + .expect("remote authority should exist"); + assert_eq!( + authority::RemoteAuthority::name(&remote_authority), + "remote.example" + ); + assert_eq!( + authority::RemoteAuthority::cert_chain(&remote_authority).len(), + 1 + ); + + quic::Lifecycle::close(&remote, Code::H3_NO_ERROR, "bye".into()); + for _ in 0..20 { + if !connection.closes().is_empty() { + break; + } + tokio::task::yield_now().await; + } + assert_eq!( + connection.closes(), + vec![(Code::H3_NO_ERROR, Cow::Borrowed("bye"))], + ); + + connection.set_terminal(connection_error("terminal")); + let closed = quic::Lifecycle::closed(&remote).await; + assert_reason(&closed, "terminal"); + let latched = quic::Lifecycle::check(&remote).expect_err("closed should latch"); + assert_reason(&latched, "terminal"); + } + + #[tokio::test] + async fn remote_connection_preserves_absent_agents() { + let connection = Arc::new(TestQuicConnection::new()); + let (_task, client) = spawn_rpc_connection(connection); + let remote = client.into_quic(); + + let local_authority = quic::WithLocalAuthority::local_authority(&remote) + .await + .expect("local authority lookup should succeed"); + assert!(local_authority.is_none()); + + let remote_authority = quic::WithRemoteAuthority::remote_authority(&remote) + .await + .expect("remote authority lookup should succeed"); + assert!(remote_authority.is_none()); + + quic::Lifecycle::check(&remote).expect("absent agents should not close connection"); + } + + #[tokio::test] + async fn remote_connection_latches_agent_lookup_errors() { + let connection = Arc::new(TestQuicConnection::fail_local_authority( + "local authority lookup failed", + )); + let (_task, client) = spawn_rpc_connection(connection); + let remote = client.into_quic(); + + let Err(error) = quic::WithLocalAuthority::local_authority(&remote).await else { + panic!("local authority error should surface"); + }; + assert_reason(&error, "local authority lookup failed"); + + let latched = + quic::Lifecycle::check(&remote).expect_err("local authority error should latch"); + assert_reason(&latched, "local authority lookup failed"); + + let connection = Arc::new(TestQuicConnection::fail_remote_authority( + "remote authority lookup failed", + )); + let (_task, client) = spawn_rpc_connection(connection); + let remote = client.into_quic(); + + let Err(error) = quic::WithRemoteAuthority::remote_authority(&remote).await else { + panic!("remote authority error should surface"); + }; + assert_reason(&error, "remote authority lookup failed"); + + let closed = quic::Lifecycle::closed(&remote).await; + assert_reason(&closed, "remote authority lookup failed"); + } + + #[tokio::test] + async fn remote_connection_latches_remote_errors() { + let connection = Arc::new(TestQuicConnection::fail_open_bi("quic open bidi failed")); + let (_task, client) = spawn_rpc_connection(connection); + let remote = client.into_quic(); + + let Err(error) = quic::ManageStream::open_bi(&remote).await else { + panic!("open error should surface"); + }; + assert_reason(&error, "quic open bidi failed"); + + let latched = quic::Lifecycle::check(&remote).expect_err("open error should latch"); + assert_reason(&latched, "quic open bidi failed"); + + let closed = quic::Lifecycle::closed(&remote).await; + assert_reason(&closed, "quic open bidi failed"); + } + + #[tokio::test] + async fn remote_connection_maps_each_stream_operation_error() { + let connection = Arc::new(TestQuicConnection::fail_open_uni("quic open uni failed")); + let (_task, client) = spawn_rpc_connection(connection); + let remote = client.into_quic(); + let Err(error) = quic::ManageStream::open_uni(&remote).await else { + panic!("open uni error should surface"); + }; + assert_reason(&error, "quic open uni failed"); + + let connection = Arc::new(TestQuicConnection::fail_accept_bi( + "quic accept bidi failed", + )); + let (_task, client) = spawn_rpc_connection(connection); + let remote = client.into_quic(); + let Err(error) = quic::ManageStream::accept_bi(&remote).await else { + panic!("accept bidi error should surface"); + }; + assert_reason(&error, "quic accept bidi failed"); + + let connection = Arc::new(TestQuicConnection::fail_accept_uni( + "quic accept uni failed", + )); + let (_task, client) = spawn_rpc_connection(connection); + let remote = client.into_quic(); + let Err(error) = quic::ManageStream::accept_uni(&remote).await else { + panic!("accept uni error should surface"); + }; + assert_reason(&error, "quic accept uni failed"); + } + + #[tokio::test] + async fn remote_connection_synthesizes_remoc_channel_errors_when_server_stops() { + let connection = Arc::new(TestQuicConnection::new()); + let (task, client) = spawn_rpc_connection(connection); + let remote = client.into_quic(); + drop(task); + + let error = tokio::time::timeout(Duration::from_secs(1), async { + loop { + match quic::Lifecycle::check(&remote) { + Ok(()) => tokio::task::yield_now().await, + Err(error) => break error, + } + } + }) + .await + .expect("lifecycle check should complete in time"); + assert_reason(&error, "remoc connection channel closed"); + + let closed = tokio::time::timeout(Duration::from_secs(1), quic::Lifecycle::closed(&remote)) + .await + .expect("connection closed should complete in time"); + assert_reason(&closed, "remoc connection channel closed"); + + let Err(error) = quic::ManageStream::open_uni(&remote).await else { + panic!("closed remoc connection should reject operations"); + }; + assert_reason(&error, "remoc connection channel closed"); + } + + #[tokio::test] + async fn remote_connection_stream_methods_surface_stream_id_failures() { + let connection = Arc::new(BrokenIdQuicConnection); + let (_task, client) = spawn_rpc_connection(connection); + let remote = client.into_quic(); + let Err(error) = quic::ManageStream::open_bi(&remote).await else { + panic!("open bidi should fail before returning channels"); + }; + assert_reason(&error, "open bidi reader stream id failed"); + + let connection = Arc::new(BrokenIdQuicConnection); + let (_task, client) = spawn_rpc_connection(connection); + let remote = client.into_quic(); + let Err(error) = quic::ManageStream::open_uni(&remote).await else { + panic!("open uni should fail before returning channels"); + }; + assert_reason(&error, "open uni writer stream id failed"); + + let connection = Arc::new(BrokenIdQuicConnection); + let (_task, client) = spawn_rpc_connection(connection); + let remote = client.into_quic(); + let Err(error) = quic::ManageStream::accept_bi(&remote).await else { + panic!("accept bidi should fail before returning channels"); + }; + assert_reason(&error, "accept bidi reader stream id failed"); + + let connection = Arc::new(BrokenIdQuicConnection); + let (_task, client) = spawn_rpc_connection(connection); + let remote = client.into_quic(); + let Err(error) = quic::ManageStream::accept_uni(&remote).await else { + panic!("accept uni should fail before returning channels"); + }; + assert_reason(&error, "accept uni reader stream id failed"); + } +} diff --git a/src/rpc/quic/listen.rs b/src/rpc/quic/listen.rs index 57a4162..c1e6ddb 100644 --- a/src/rpc/quic/listen.rs +++ b/src/rpc/quic/listen.rs @@ -29,8 +29,8 @@ pub trait Listen: Send + Sync { impl Listen for L where L: quic::Listen + 'static, - ::LocalAgent: Send + Sync, - ::RemoteAgent: Send + Sync, + ::LocalAuthority: Send + Sync, + ::RemoteAuthority: Send + Sync, { async fn accept(&mut self) -> Result { // lossy: cross-process serialization boundary @@ -38,6 +38,9 @@ where .await .map_err(|e| StringError::new(e.to_string()))?; let (server, client) = ConnectionServerShared::new(connection, 1); + // Inherent termination: the returned ConnectionClient owns the remoc + // endpoint; when that client is dropped or the channel closes, + // server.serve exits. tokio::spawn( (async move { let _ = server.serve(true).await; @@ -110,3 +113,166 @@ impl quic::Listen for RemoteListener { Listen::shutdown(&self.0).await } } + +#[cfg(test)] +mod tests { + use std::{ + collections::VecDeque, + sync::{ + Arc, Mutex, + atomic::{AtomicUsize, Ordering}, + }, + }; + + use remoc::prelude::Server; + use tokio_util::task::AbortOnDropHandle; + use tracing::Instrument; + + use super::*; + use crate::connection::tests::MockConnection; + + #[derive(Debug, snafu::Snafu)] + #[snafu(display("test listener failed"))] + struct TestListenError; + + #[derive(Debug, Default)] + struct TestListenState { + connections: Mutex>>, + accepts: AtomicUsize, + shutdowns: AtomicUsize, + } + + #[derive(Debug, Default)] + struct TestListener { + state: Arc, + } + + impl TestListener { + fn with_connections(connections: impl IntoIterator>) -> Self { + Self { + state: Arc::new(TestListenState { + connections: Mutex::new(connections.into_iter().collect()), + accepts: AtomicUsize::default(), + shutdowns: AtomicUsize::default(), + }), + } + } + } + + impl quic::Listen for TestListener { + type Connection = MockConnection; + type Error = TestListenError; + + async fn accept(&mut self) -> Result, Self::Error> { + self.state.accepts.fetch_add(1, Ordering::Relaxed); + self.state + .connections + .lock() + .expect("connection queue mutex should not be poisoned") + .pop_front() + .ok_or(TestListenError) + } + + async fn shutdown(&self) -> Result<(), Self::Error> { + self.state.shutdowns.fetch_add(1, Ordering::Relaxed); + Ok(()) + } + } + + #[tokio::test] + async fn blanket_accept_returns_connection_client() { + let mut listener = TestListener::with_connections([Arc::new(MockConnection::new())]); + let state = listener.state.clone(); + + let client = Listen::accept(&mut listener) + .await + .expect("blanket accept should return a remote connection client"); + + assert_eq!(state.accepts.load(Ordering::Relaxed), 1); + + let remote = RemoteConnection::from(client); + assert!(quic::Lifecycle::check(&remote).is_ok()); + } + + #[tokio::test] + async fn blanket_accept_stringifies_transport_errors() { + let mut listener = TestListener::default(); + + let error = Listen::accept(&mut listener) + .await + .expect_err("empty listener should fail"); + + let ListenError::Remote { source } = error else { + panic!("listener error should cross RPC boundary as remote error"); + }; + assert_eq!(source.as_str(), "test listener failed"); + } + + #[tokio::test] + async fn blanket_shutdown_delegates_to_quic_listener() { + let listener = TestListener::default(); + let state = listener.state.clone(); + + Listen::shutdown(&listener) + .await + .expect("blanket shutdown should succeed"); + + assert_eq!(state.shutdowns.load(Ordering::Relaxed), 1); + } + + #[tokio::test] + async fn remote_listener_delegates_to_listen_client() { + let listener = TestListener::with_connections([Arc::new(MockConnection::new())]); + let state = listener.state.clone(); + let (server, client) = ListenServer::new(listener, 1); + let _server_task = AbortOnDropHandle::new(tokio::spawn( + async move { + let (_listener, _result) = server.serve().await; + } + .in_current_span(), + )); + let mut remote = RemoteListener::new(client); + + let connection = quic::Listen::accept(&mut remote) + .await + .expect("remote listener should accept through RTC"); + quic::Listen::shutdown(&remote) + .await + .expect("remote listener should shut down through RTC"); + + assert_eq!(state.accepts.load(Ordering::Relaxed), 1); + assert_eq!(state.shutdowns.load(Ordering::Relaxed), 1); + assert!(quic::Lifecycle::check(connection.as_ref()).is_ok()); + } + + #[tokio::test] + async fn remote_listener_conversions_preserve_client_handle() { + let listener = TestListener::default(); + let (server, client) = ListenServer::new(listener, 1); + let _server_task_a = AbortOnDropHandle::new(tokio::spawn( + async move { + let (_listener, _result) = server.serve().await; + } + .in_current_span(), + )); + + let remote = RemoteListener::new(client); + let inner = remote.into_inner(); + let remote = RemoteListener::from(inner); + let _inner = ListenClient::from(remote); + + let listener = TestListener::default(); + let (server, client) = ListenServer::new(listener, 1); + let _server_task_b = AbortOnDropHandle::new(tokio::spawn( + async move { + let (_listener, _result) = server.serve().await; + } + .in_current_span(), + )); + + let remote = ListenClient::into_quic(client); + let inner: ListenClient = remote.into(); + let remote = RemoteListener::from(inner); + let _inner = ListenClient::from(remote); + } +} diff --git a/src/rpc/quic/serde_types.rs b/src/rpc/quic/serde_types.rs index 4389a99..a879f0e 100644 --- a/src/rpc/quic/serde_types.rs +++ b/src/rpc/quic/serde_types.rs @@ -86,3 +86,65 @@ impl From for SubjectPublicKeyInfoDer<'static> { SubjectPublicKeyInfoDer::from(value.0) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn authority_round_trips_through_serde_wrapper() { + let authority = http::uri::Authority::from_static("example.test:443"); + + let wrapped = SerdeAuthority::from(&authority); + let decoded = http::uri::Authority::try_from(wrapped).expect("valid authority"); + + assert_eq!(decoded, authority); + } + + #[test] + fn invalid_authority_is_rejected_after_unwrap() { + let wrapped = SerdeAuthority("not a valid authority".to_owned()); + + assert!(http::uri::Authority::try_from(wrapped).is_err()); + } + + #[test] + fn certificate_der_round_trips_through_serde_wrapper() { + let certificate = CertificateDer::from(vec![1, 2, 3, 4]); + + let wrapped = SerdeCertificateDer::from(certificate); + let decoded = CertificateDer::from(wrapped); + + assert_eq!(decoded.as_ref(), &[1, 2, 3, 4]); + } + + #[test] + fn signature_scheme_round_trips_through_serde_wrapper() { + let scheme = SignatureScheme::ECDSA_NISTP256_SHA256; + + let wrapped = SerdeSignatureScheme::from(scheme); + let decoded = SignatureScheme::from(wrapped); + + assert_eq!(decoded, scheme); + } + + #[test] + fn signature_algorithm_round_trips_through_serde_wrapper() { + let algorithm = rustls::SignatureAlgorithm::ECDSA; + + let wrapped = SerdeSignatureAlgorithm::from(algorithm); + let decoded = rustls::SignatureAlgorithm::from(wrapped); + + assert_eq!(decoded, algorithm); + } + + #[test] + fn subject_public_key_info_round_trips_through_serde_wrapper() { + let public_key = SubjectPublicKeyInfoDer::from(vec![5, 6, 7, 8]); + + let wrapped = SerdeSubjectPublicKeyInfoDer::from(public_key); + let decoded = SubjectPublicKeyInfoDer::from(wrapped); + + assert_eq!(decoded.as_ref(), &[5, 6, 7, 8]); + } +} diff --git a/src/rpc/quic/stream.rs b/src/rpc/quic/stream.rs index ae85cd4..6bc5f6b 100644 --- a/src/rpc/quic/stream.rs +++ b/src/rpc/quic/stream.rs @@ -1,145 +1,87 @@ -use bytes::Bytes; -use futures::{SinkExt, StreamExt, future::Either}; -use tokio_util::sync::CancellationToken; +//! Channel-backed RPC stream payloads. +//! +//! Raw QUIC stream data/control is transported over typed remoc MPSC frame +//! channels and reconstructed with the shared `rpc::stream` bridge drivers. -use super::super::bridge; -use crate::{ - codec::BoxReadStream, - message::stream::guard, - quic::{self, CancelStreamExt, GetStreamIdExt, StopStreamExt}, - util::deferred::Deferred, - varint::VarInt, -}; +pub use crate::rpc::stream::remoc::{ReadFrameChannels, WriteFrameChannels}; -/// Remote trait for reading from a QUIC stream over remoc RTC. -#[remoc::rtc::remote] -pub trait ReadStream: Send { - async fn stream_id(&mut self) -> Result; - async fn read(&mut self) -> Result, quic::StreamError>; - async fn stop(&mut self, code: VarInt) -> Result<(), quic::StreamError>; -} - -impl ReadStreamClient { - pub async fn into_quic(mut self) -> Result { - let stream_id = self.stream_id().await?; - Ok(bridge::ReadBridge::<_, quic::StreamError, _, _, _, _>::new( - stream_id, - self, - |mut client: ReadStreamClient, token: CancellationToken| async move { - tokio::select! { - res = client.read() => Either::Left((client, res.transpose())), - _ = token.cancelled() => Either::Right(client), - } - }, - |mut client: ReadStreamClient, code| async move { - let res = client.stop(code).await; - (client, res) - }, - )) - } +#[cfg(test)] +mod tests { + use std::sync::Arc; - pub fn into_boxed_quic(self) -> guard::GuardedQuicReader { - let raw: BoxReadStream = Box::pin(Deferred::from(self.into_quic())); - guard::GuardedQuicReader::new(raw) - } -} + use bytes::Bytes; + use futures::{FutureExt as _, SinkExt as _, StreamExt as _}; -/// Remote trait for writing to a QUIC stream over remoc RTC. -#[remoc::rtc::remote] -pub trait WriteStream: Send { - async fn stream_id(&mut self) -> Result; - async fn write(&mut self, data: Bytes) -> Result<(), quic::StreamError>; - async fn flush(&mut self) -> Result<(), quic::StreamError>; - async fn shutdown(&mut self) -> Result<(), quic::StreamError>; - async fn cancel(&mut self, code: VarInt) -> Result<(), quic::StreamError>; -} + use super::*; + use crate::{ + quic::{ResetStreamExt as _, StopStreamExt as _}, + rpc::stream::{ + frame::{ReadCommand, WriteCommand}, + test_io::TestLifecycle, + }, + varint::VarInt, + }; -impl WriteStreamClient { - pub async fn into_quic(mut self) -> Result { - let stream_id = self.stream_id().await?; - Ok(bridge::WriteBridge::< - _, - quic::StreamError, - _, - _, - _, - _, - _, - _, - _, - _, - >::new( - stream_id, - self, - |mut client: WriteStreamClient, token: CancellationToken, bytes| async move { - tokio::select! { - res = client.write(bytes) => Either::Left((client, res)), - _ = token.cancelled() => Either::Right(client), - } - }, - |mut client: WriteStreamClient, token: CancellationToken| async move { - tokio::select! { - res = client.flush() => Either::Left((client, res)), - _ = token.cancelled() => Either::Right(client), - } - }, - |mut client: WriteStreamClient, token: CancellationToken| async move { - tokio::select! { - res = client.shutdown() => Either::Left((client, res)), - _ = token.cancelled() => Either::Right(client), - } - }, - |mut client: WriteStreamClient, code| async move { - let res = client.cancel(code).await; - (client, res) - }, - )) - } + #[tokio::test] + async fn read_frame_channels_construct_quic_reader() { + let stream_id = VarInt::from_u32(801); + let (channels, mut hypervisor) = ReadFrameChannels::pair(stream_id); + let lifecycle = Arc::new(TestLifecycle::new()); + let mut reader = channels.into_quic(lifecycle); - pub fn into_boxed_quic(self) -> guard::GuardedQuicWriter { - let raw: crate::codec::BoxWriteStream = Box::pin(Deferred::from(self.into_quic())); - guard::GuardedQuicWriter::new(raw) + let read = Box::pin(reader.next()); + tokio::pin!(read); + assert!(read.as_mut().now_or_never().is_none()); + assert_eq!(hypervisor.next().await.unwrap().unwrap(), ReadCommand::Pull); + let send = hypervisor.send(crate::rpc::stream::frame::ReadEvent::Push { + data: Bytes::from_static(b"rpc read"), + }); + let (send, received) = tokio::join!(send, read); + send.unwrap(); + assert_eq!(received.unwrap().unwrap(), Bytes::from_static(b"rpc read")); } -} -impl ReadStream for S -where - S: quic::ReadStream + Unpin + Send, -{ - async fn stream_id(&mut self) -> Result { - GetStreamIdExt::stream_id(self).await - } + #[tokio::test] + async fn read_stream_stop_is_first_wins_through_frames() { + let stream_id = VarInt::from_u32(802); + let first = VarInt::from_u32(11); + let second = VarInt::from_u32(12); + let (channels, mut hypervisor) = ReadFrameChannels::pair(stream_id); + let lifecycle = Arc::new(TestLifecycle::new()); + let mut reader = channels.into_quic(lifecycle); - async fn read(&mut self) -> Result, quic::StreamError> { - StreamExt::next(self).await.transpose() + assert!(reader.stop(first).now_or_never().is_none()); + assert_eq!( + hypervisor.next().await.unwrap().unwrap(), + ReadCommand::Stop { code: first } + ); + let second_stop = reader.stop(second); + tokio::pin!(second_stop); + let send = hypervisor.send(crate::rpc::stream::frame::ReadEvent::StopAck { code: first }); + let (send, stop) = tokio::join!(send, second_stop); + send.unwrap(); + stop.unwrap(); } - async fn stop(&mut self, code: VarInt) -> Result<(), quic::StreamError> { - StopStreamExt::stop(self, code).await - } -} - -impl WriteStream for S -where - S: quic::WriteStream + Unpin + Send, -{ - async fn stream_id(&mut self) -> Result { - GetStreamIdExt::stream_id(self).await - } - - async fn write(&mut self, data: Bytes) -> Result<(), quic::StreamError> { - SinkExt::send(self, data).await - } - - async fn flush(&mut self) -> Result<(), quic::StreamError> { - SinkExt::flush(self).await - } - - async fn shutdown(&mut self) -> Result<(), quic::StreamError> { - SinkExt::close(self).await - } + #[tokio::test] + async fn write_stream_reset_is_first_wins_through_frames() { + let stream_id = VarInt::from_u32(803); + let first = VarInt::from_u32(21); + let second = VarInt::from_u32(22); + let (channels, mut hypervisor) = WriteFrameChannels::pair(stream_id); + let lifecycle = Arc::new(TestLifecycle::new()); + let mut writer = channels.into_quic(lifecycle); - async fn cancel(&mut self, code: VarInt) -> Result<(), quic::StreamError> { - CancelStreamExt::cancel(self, code).await + assert!(writer.reset(first).now_or_never().is_none()); + assert_eq!( + hypervisor.next().await.unwrap().unwrap(), + WriteCommand::Reset { code: first } + ); + let second_reset = writer.reset(second); + tokio::pin!(second_reset); + let send = hypervisor.send(crate::rpc::stream::frame::WriteEvent::ResetAck { code: first }); + let (send, reset) = tokio::join!(send, second_reset); + send.unwrap(); + reset.unwrap(); } } diff --git a/src/rpc/stream.rs b/src/rpc/stream.rs new file mode 100644 index 0000000..a36fb22 --- /dev/null +++ b/src/rpc/stream.rs @@ -0,0 +1,16 @@ +//! Shared RPC/IPC stream frame bridge internals. +//! +//! This module is the single home for typed stream command/event drivers used +//! by RPC, IPC, and WebTransport IPC stream forwarding. + +pub(crate) mod drain; +pub(crate) mod error; +pub(crate) mod frame; +pub(crate) mod hypervisor; +pub(crate) mod io; +pub(crate) mod reader; +pub(crate) mod remoc; +pub(crate) mod writer; + +#[cfg(test)] +pub(crate) mod test_io; diff --git a/src/rpc/stream/drain.rs b/src/rpc/stream/drain.rs new file mode 100644 index 0000000..f498dff --- /dev/null +++ b/src/rpc/stream/drain.rs @@ -0,0 +1,393 @@ +use std::pin::Pin; + +use futures::future::poll_fn; +use tracing::Instrument as _; + +use super::{ + frame::{ReadCommand, ReadEvent, WriteCommand, WriteEvent}, + io::FrameIo, + reader::ActiveBridgeStreamReader, + writer::ActiveBridgeStreamWriter, +}; +use crate::{quic, rpc::lifecycle::LifecycleExt}; + +pub(crate) trait DrainLifecycle: LifecycleExt + 'static {} + +impl DrainLifecycle for T where T: LifecycleExt + 'static {} + +pub(crate) trait ReadDrainIo: + FrameIo + Send + 'static +{ +} + +impl ReadDrainIo for T where T: FrameIo + Send + 'static {} + +pub(crate) trait WriteDrainIo: + FrameIo + Send + 'static +{ +} + +impl WriteDrainIo for T where T: FrameIo + Send + 'static {} + +pub(super) fn spawn_read_drain(mut active: Pin>>) +where + L: DrainLifecycle, + Io: ReadDrainIo, + E: 'static, + quic::ConnectionError: From, +{ + // Inherent termination: this task owns a moved active bridge containing only + // already committed outbound frames. It exits after those frames are written + // or typed frame IO closes, and it never waits for operation acknowledgements. + let _task = tokio::spawn( + async move { + poll_fn(|cx| active.as_mut().poll_drain(cx)).await; + } + .in_current_span(), + ); +} + +pub(super) fn spawn_write_drain(mut active: Pin>>) +where + L: DrainLifecycle, + Io: WriteDrainIo, + E: 'static, + quic::ConnectionError: From, +{ + // Inherent termination: this task owns a moved active bridge containing only + // already committed outbound frames. It exits after those frames are written + // or typed frame IO closes, and it never waits for operation acknowledgements. + let _task = tokio::spawn( + async move { + poll_fn(|cx| active.as_mut().poll_drain(cx)).await; + } + .in_current_span(), + ); +} + +pub(super) fn log_drain_error(error: &quic::StreamError, context: &'static str) { + let report = snafu::Report::from_error(error); + tracing::warn!(error = %report, context, "stream frame drain failed"); +} + +#[cfg(test)] +mod tests { + use std::{ + collections::VecDeque, + fmt::Debug, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, + time::Duration, + }; + + use bytes::Bytes; + use futures::{FutureExt as _, Sink, Stream, future::poll_fn}; + + use crate::{ + quic::{ResetStream, StopStream}, + rpc::stream::{ + frame::{ReadCommand, ReadEvent, WriteCommand, WriteEvent}, + reader::BridgeStreamReader, + test_io::{TestFrameIoError, TestLifecycle}, + writer::BridgeStreamWriter, + }, + varint::VarInt, + }; + + struct ControlledFrameIo { + state: Arc>>, + } + + struct ControlledFrameIoHandle { + state: Arc>>, + } + + struct ControlledState { + hold_flush: bool, + buffered: Vec, + sent: Vec, + inbound: VecDeque>, + } + + impl ControlledFrameIo { + fn new() -> (Self, ControlledFrameIoHandle) { + let state = Arc::new(Mutex::new(ControlledState { + hold_flush: false, + buffered: Vec::new(), + sent: Vec::new(), + inbound: VecDeque::new(), + })); + ( + Self { + state: state.clone(), + }, + ControlledFrameIoHandle { state }, + ) + } + } + + impl ControlledFrameIoHandle { + fn hold_flush(&self, hold_flush: bool) { + self.state.lock().unwrap().hold_flush = hold_flush; + } + + fn push_inbound(&self, frame: In) { + self.state.lock().unwrap().inbound.push_back(Ok(frame)); + } + + fn inbound_len(&self) -> usize { + self.state.lock().unwrap().inbound.len() + } + } + + impl ControlledFrameIoHandle + where + Out: Clone, + { + fn sent(&self) -> Vec { + self.state.lock().unwrap().sent.clone() + } + } + + impl Unpin for ControlledFrameIo {} + + impl Sink for ControlledFrameIo { + type Error = E; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> { + self.state.lock().unwrap().buffered.push(item); + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let mut state = self.state.lock().unwrap(); + if state.hold_flush { + return Poll::Pending; + } + let buffered = std::mem::take(&mut state.buffered); + state.sent.extend(buffered); + Poll::Ready(Ok(())) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } + } + + impl Stream for ControlledFrameIo { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + match self.state.lock().unwrap().inbound.pop_front() { + Some(frame) => Poll::Ready(Some(frame)), + None => Poll::Pending, + } + } + } + + type TestReaderIo = ControlledFrameIo; + type TestReaderHandle = ControlledFrameIoHandle; + type TestWriterIo = ControlledFrameIo; + type TestWriterHandle = ControlledFrameIoHandle; + + fn reader( + stream_id: VarInt, + ) -> ( + BridgeStreamReader, + TestReaderHandle, + ) { + let lifecycle = Arc::new(TestLifecycle::new()); + let (io, handle) = ControlledFrameIo::new(); + (BridgeStreamReader::new(stream_id, io, lifecycle), handle) + } + + fn writer( + stream_id: VarInt, + ) -> ( + BridgeStreamWriter, + TestWriterHandle, + ) { + let lifecycle = Arc::new(TestLifecycle::new()); + let (io, handle) = ControlledFrameIo::new(); + (BridgeStreamWriter::new(stream_id, io, lifecycle), handle) + } + + async fn wait_sent(handle: &ControlledFrameIoHandle, expected: Vec) + where + Out: Clone + Debug + PartialEq, + { + tokio::time::timeout(Duration::from_millis(100), async { + loop { + if handle.sent() == expected { + return; + } + tokio::task::yield_now().await; + } + }) + .await + .unwrap_or_else(|_elapsed| { + panic!( + "timed out waiting for sent frames {expected:?}; actual {:?}", + handle.sent() + ) + }); + } + + async fn assert_no_sent_after_yield(handle: &ControlledFrameIoHandle) + where + Out: Clone + Debug + PartialEq, + { + tokio::task::yield_now().await; + assert_eq!(handle.sent(), []); + } + + #[tokio::test] + async fn dropping_reader_after_committed_stop_drains_only_that_stop() { + let code = VarInt::from_u32(201); + let (mut bridge, handle) = reader(VarInt::from_u32(202)); + handle.hold_flush(true); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_stop(cx, code)) + .now_or_never() + .is_none(), + "first stop poll should commit and wait for flush" + ); + assert_eq!(handle.sent(), []); + + handle.hold_flush(false); + drop(bridge); + + wait_sent(&handle, vec![ReadCommand::Stop { code }]).await; + } + + #[tokio::test] + async fn dropping_reader_with_no_committed_outbound_sends_no_frame() { + let (bridge, handle) = reader(VarInt::from_u32(203)); + + drop(bridge); + + assert_no_sent_after_yield(&handle).await; + } + + #[tokio::test] + async fn dropping_writer_after_committed_reset_drains_reset_without_ack() { + let code = VarInt::from_u32(204); + let (mut bridge, handle) = writer(VarInt::from_u32(205)); + handle.hold_flush(true); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_reset(cx, code)) + .now_or_never() + .is_none(), + "first reset poll should commit and wait for flush" + ); + assert_eq!(handle.sent(), []); + + handle.hold_flush(false); + drop(bridge); + + wait_sent(&handle, vec![WriteCommand::Reset { code }]).await; + } + + #[tokio::test] + async fn dropping_writer_with_only_cached_credit_sends_no_frame() { + let (mut bridge, handle) = writer(VarInt::from_u32(206)); + handle.push_inbound(WriteEvent::Pull); + + poll_fn(|cx| Pin::new(&mut bridge).poll_ready(cx)) + .await + .expect("cached credit should make writer ready"); + drop(bridge); + + assert_no_sent_after_yield(&handle).await; + } + + #[tokio::test] + async fn dropping_writer_with_pending_push_drains_push_without_synthesizing_eos() { + let data = Bytes::from_static(b"drained push"); + let (mut bridge, handle) = writer(VarInt::from_u32(207)); + handle.push_inbound(WriteEvent::Pull); + + poll_fn(|cx| Pin::new(&mut bridge).poll_ready(cx)) + .await + .expect("cached credit should make writer ready"); + Pin::new(&mut bridge) + .start_send(data.clone()) + .expect("start_send should commit push"); + drop(bridge); + + wait_sent(&handle, vec![WriteCommand::Push { data }]).await; + } + + #[tokio::test] + async fn dropping_reader_with_flushed_pull_and_pending_stop_drains_stop_only() { + let code = VarInt::from_u32(208); + let (mut bridge, handle) = reader(VarInt::from_u32(209)); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_next(cx)) + .now_or_never() + .is_none(), + "pull should be flushed and wait for inbound data" + ); + assert_eq!(handle.sent(), [ReadCommand::Pull]); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_stop(cx, code)) + .now_or_never() + .is_none(), + "stop should be committed behind the in-flight pull" + ); + handle.push_inbound(ReadEvent::Push { + data: Bytes::from_static(b"not drained"), + }); + drop(bridge); + + wait_sent(&handle, vec![ReadCommand::Pull, ReadCommand::Stop { code }]).await; + assert_eq!( + handle.inbound_len(), + 1, + "drain must not poll inbound pull results" + ); + } + + #[tokio::test] + async fn dropping_writer_awaiting_ack_does_not_flush_second_acked_command() { + let (mut bridge, handle) = writer(VarInt::from_u32(210)); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_flush(cx)) + .now_or_never() + .is_none(), + "flush should be sent and wait for FlushAck" + ); + assert_eq!(handle.sent(), [WriteCommand::Flush]); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_close(cx)) + .now_or_never() + .is_none(), + "eos should be committed behind the awaiting flush" + ); + drop(bridge); + + tokio::task::yield_now().await; + assert_eq!( + handle.sent(), + [WriteCommand::Flush], + "drain must not pipeline Eos behind an awaiting FlushAck" + ); + } +} diff --git a/src/rpc/stream/error.rs b/src/rpc/stream/error.rs new file mode 100644 index 0000000..f000497 --- /dev/null +++ b/src/rpc/stream/error.rs @@ -0,0 +1,92 @@ +#![allow(dead_code)] + +use std::{borrow::Cow, sync::Arc}; + +use futures::future::BoxFuture; +use snafu::Snafu; + +use crate::{quic, rpc::lifecycle::LifecycleExt, varint::VarInt}; + +const DRIVER_PROTOCOL_ERROR_KIND: VarInt = VarInt::from_u32(0x3f00); +const DRIVER_PROTOCOL_ERROR_FRAME_TYPE: VarInt = VarInt::from_u32(0x3f01); + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub(crate) enum DriverProtocolError { + #[snafu(display("stop ack code {actual} does not match committed stop code {expected}"))] + StopAckCodeMismatch { expected: VarInt, actual: VarInt }, + #[snafu(display("reset ack code {actual} does not match committed reset code {expected}"))] + ResetAckCodeMismatch { expected: VarInt, actual: VarInt }, + #[snafu(display("received frame that is invalid for the current reader operation"))] + UnexpectedReaderFrame, + #[snafu(display("received writer ack without a matching committed operation"))] + UnexpectedWriterAck, + #[snafu(display("received duplicate writer credit before previous credit was consumed"))] + DuplicateWriterCredit, + #[snafu(display("start_send called without writer credit"))] + StartSendWithoutCredit, + #[snafu(display("typed frame stream ended before operation completed"))] + FrameEof, +} + +pub(crate) enum DeferredStreamError { + Ready { + error: quic::StreamError, + }, + Pending { + future: BoxFuture<'static, quic::StreamError>, + }, +} + +pub(crate) fn protocol_connection_error(error: DriverProtocolError) -> quic::ConnectionError { + quic::ConnectionError::Transport { + source: quic::TransportError { + kind: DRIVER_PROTOCOL_ERROR_KIND, + frame_type: DRIVER_PROTOCOL_ERROR_FRAME_TYPE, + // Lossy: QUIC transport error reason is a protocol string field for + // local bridge-driver protocol failures. + reason: Cow::Owned(error.to_string()), + }, + } +} + +pub(crate) fn latch_protocol_error( + lifecycle: &L, + error: DriverProtocolError, +) -> quic::StreamError +where + L: LifecycleExt, +{ + let source = lifecycle + .latch() + .latch_with(|| protocol_connection_error(error)); + quic::StreamError::Connection { source } +} + +pub(crate) fn latch_frame_io_error(lifecycle: &L, error: E) -> quic::StreamError +where + L: LifecycleExt, + quic::ConnectionError: From, +{ + let source = lifecycle + .latch() + .latch_with(|| quic::ConnectionError::from(error)); + quic::StreamError::Connection { source } +} + +pub(crate) fn defer_err_conn(lifecycle: Arc) -> DeferredStreamError +where + L: LifecycleExt + 'static, +{ + match quic::Lifecycle::check(lifecycle.as_ref()) { + Ok(()) => DeferredStreamError::Pending { + future: Box::pin(async move { + let source = quic::Lifecycle::closed(lifecycle.as_ref()).await; + quic::StreamError::Connection { source } + }), + }, + Err(source) => DeferredStreamError::Ready { + error: quic::StreamError::Connection { source }, + }, + } +} diff --git a/src/rpc/stream/frame.rs b/src/rpc/stream/frame.rs new file mode 100644 index 0000000..51eedf9 --- /dev/null +++ b/src/rpc/stream/frame.rs @@ -0,0 +1,164 @@ +#![allow(dead_code)] + +use bytes::Bytes; + +use crate::varint::VarInt; + +pub(crate) const TAG_PULL: u64 = 0x00; +pub(crate) const TAG_PUSH: u64 = 0x01; +pub(crate) const TAG_FLUSH: u64 = 0x02; +pub(crate) const TAG_FLUSH_ACK: u64 = 0x03; +pub(crate) const TAG_EOS: u64 = 0x04; +pub(crate) const TAG_EOS_ACK: u64 = 0x05; +pub(crate) const TAG_STOP: u64 = 0x06; +pub(crate) const TAG_STOP_ACK: u64 = 0x07; +pub(crate) const TAG_RESET: u64 = 0x08; +pub(crate) const TAG_RESET_ACK: u64 = 0x09; +pub(crate) const TAG_ERR_RESET: u64 = 0x0a; +pub(crate) const TAG_ERR_CONN: u64 = 0x0b; + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum ReadCommand { + Pull, + Stop { code: VarInt }, +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum ReadEvent { + Push { data: Bytes }, + Eos, + StopAck { code: VarInt }, + ErrReset { code: VarInt }, + ErrConn, +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum WriteCommand { + Push { data: Bytes }, + Flush, + Eos, + Reset { code: VarInt }, +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum WriteEvent { + Pull, + FlushAck, + EosAck, + ResetAck { code: VarInt }, + ErrReset { code: VarInt }, + ErrConn, +} + +pub(crate) type WorkerReadOut = ReadCommand; +pub(crate) type WorkerReadIn = ReadEvent; +pub(crate) type WorkerWriteOut = WriteCommand; +pub(crate) type WorkerWriteIn = WriteEvent; +pub(crate) type HyperReadIn = ReadCommand; +pub(crate) type HyperReadOut = ReadEvent; +pub(crate) type HyperWriteIn = WriteCommand; +pub(crate) type HyperWriteOut = WriteEvent; + +#[cfg(test)] +mod tests { + use bytes::Bytes; + + use super::*; + use crate::varint::VarInt; + + #[test] + fn read_eos_and_write_eos_are_distinct_typed_variants() { + let read_eos = ReadEvent::Eos; + let write_eos = WriteCommand::Eos; + + assert_eq!(read_eos, ReadEvent::Eos); + assert_eq!(write_eos, WriteCommand::Eos); + } + + #[test] + fn tags_match_shared_wire_vocabulary() { + assert_eq!( + [ + TAG_PULL, + TAG_PUSH, + TAG_FLUSH, + TAG_FLUSH_ACK, + TAG_EOS, + TAG_EOS_ACK, + TAG_STOP, + TAG_STOP_ACK, + TAG_RESET, + TAG_RESET_ACK, + TAG_ERR_RESET, + TAG_ERR_CONN, + ], + [ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b + ] + ); + } + + #[test] + fn direction_aliases_preserve_semantic_frame_types() { + let stop_code = VarInt::from_u32(51); + let reset_code = VarInt::from_u32(52); + let data = Bytes::from_static(b"alias payload"); + + let worker_read_out: WorkerReadOut = ReadCommand::Stop { code: stop_code }; + let worker_read_in: WorkerReadIn = ReadEvent::ErrReset { code: reset_code }; + let worker_write_out: WorkerWriteOut = WriteCommand::Reset { code: reset_code }; + let worker_write_in: WorkerWriteIn = WriteEvent::ErrConn; + let hyper_read_in: HyperReadIn = ReadCommand::Pull; + let hyper_read_out: HyperReadOut = ReadEvent::ErrConn; + let hyper_write_in: HyperWriteIn = WriteCommand::Flush; + let hyper_write_out: HyperWriteOut = WriteEvent::Pull; + + assert_eq!(worker_read_out, ReadCommand::Stop { code: stop_code }); + assert_eq!(worker_read_in, ReadEvent::ErrReset { code: reset_code }); + assert_eq!(worker_write_out, WriteCommand::Reset { code: reset_code }); + assert_eq!(worker_write_in, WriteEvent::ErrConn); + assert_eq!(hyper_read_in, ReadCommand::Pull); + assert_eq!(hyper_read_out, ReadEvent::ErrConn); + assert_eq!(hyper_write_in, WriteCommand::Flush); + assert_eq!(hyper_write_out, WriteEvent::Pull); + assert_eq!( + WriteCommand::Push { data: data.clone() }, + WriteCommand::Push { data } + ); + assert_eq!(WriteEvent::FlushAck, WriteEvent::FlushAck); + assert_eq!(WriteEvent::EosAck, WriteEvent::EosAck); + } + + #[test] + fn stop_ack_and_reset_ack_carry_codes() { + let stop_code = VarInt::from_u32(41); + let reset_code = VarInt::from_u32(42); + + assert_eq!( + ReadEvent::StopAck { code: stop_code }, + ReadEvent::StopAck { code: stop_code } + ); + assert_eq!( + WriteEvent::ResetAck { code: reset_code }, + WriteEvent::ResetAck { code: reset_code } + ); + } + + #[test] + fn push_payloads_remain_bytes() { + let data = Bytes::from_static(b"frame payload"); + + assert_eq!( + ReadEvent::Push { data: data.clone() }, + ReadEvent::Push { data: data.clone() } + ); + assert_eq!( + WriteCommand::Push { data: data.clone() }, + WriteCommand::Push { data } + ); + } +} diff --git a/src/rpc/stream/hypervisor.rs b/src/rpc/stream/hypervisor.rs new file mode 100644 index 0000000..88660db --- /dev/null +++ b/src/rpc/stream/hypervisor.rs @@ -0,0 +1,524 @@ +pub(crate) mod read; +pub(crate) mod write; + +#[cfg(test)] +mod tests { + use std::{ + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll, Waker}, + time::Duration, + }; + + use bytes::Bytes; + use futures::{Sink, SinkExt as _, Stream, StreamExt as _, channel::mpsc}; + + use super::{read::run_read_bridge, write::run_write_bridge}; + use crate::{ + quic::{self, GetStreamId, ResetStream, StopStream}, + rpc::stream::{ + frame::{ReadCommand, ReadEvent, WriteCommand, WriteEvent}, + test_io::{TestFrameIoError, worker_reader_pair, worker_writer_pair}, + }, + varint::VarInt, + }; + + const TEST_TIMEOUT: Duration = Duration::from_millis(100); + + struct RecordingReader { + stream_id: VarInt, + inbound: mpsc::UnboundedReceiver>, + stops: Arc>>, + } + + struct ReaderHandle { + inbound: mpsc::UnboundedSender>, + stops: Arc>>, + } + + fn recording_reader(stream_id: VarInt) -> (RecordingReader, ReaderHandle) { + let (inbound, rx) = mpsc::unbounded(); + let stops = Arc::new(Mutex::new(Vec::new())); + ( + RecordingReader { + stream_id, + inbound: rx, + stops: stops.clone(), + }, + ReaderHandle { inbound, stops }, + ) + } + + impl ReaderHandle { + fn push(&self, data: Bytes) { + self.inbound + .unbounded_send(Ok(data)) + .expect("reader data should queue"); + } + + fn stops(&self) -> Vec { + self.stops.lock().unwrap().clone() + } + } + + impl Unpin for RecordingReader {} + + impl GetStreamId for RecordingReader { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl StopStream for RecordingReader { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.stops.lock().unwrap().push(code); + Poll::Ready(Ok(())) + } + } + + impl Stream for RecordingReader { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().inbound.poll_next_unpin(cx) + } + } + + #[derive(Debug, Clone, PartialEq, Eq)] + enum WriterOp { + Push(Bytes), + Flush, + Eos, + Reset(VarInt), + } + + struct RecordingWriter { + stream_id: VarInt, + state: Arc>, + } + + #[derive(Clone)] + struct WriterHandle { + state: Arc>, + } + + struct WriterState { + hold_flush: bool, + waker: Option, + buffered: Vec, + ops: Vec, + } + + fn recording_writer(stream_id: VarInt) -> (RecordingWriter, WriterHandle) { + let state = Arc::new(Mutex::new(WriterState { + hold_flush: false, + waker: None, + buffered: Vec::new(), + ops: Vec::new(), + })); + ( + RecordingWriter { + stream_id, + state: state.clone(), + }, + WriterHandle { state }, + ) + } + + impl WriterHandle { + fn hold_flush(&self, hold_flush: bool) { + let waker = { + let mut state = self.state.lock().unwrap(); + state.hold_flush = hold_flush; + state.waker.take() + }; + if let Some(waker) = waker { + waker.wake(); + } + } + + fn ops(&self) -> Vec { + self.state.lock().unwrap().ops.clone() + } + } + + impl Unpin for RecordingWriter {} + + impl GetStreamId for RecordingWriter { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl ResetStream for RecordingWriter { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.state.lock().unwrap().ops.push(WriterOp::Reset(code)); + Poll::Ready(Ok(())) + } + } + + impl Sink for RecordingWriter { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.state.lock().unwrap().buffered.push(item); + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut state = self.state.lock().unwrap(); + if state.hold_flush { + state.waker = Some(cx.waker().clone()); + return Poll::Pending; + } + let buffered = std::mem::take(&mut state.buffered); + if buffered.is_empty() { + state.ops.push(WriterOp::Flush); + } else { + state.ops.extend(buffered.into_iter().map(WriterOp::Push)); + } + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + self.state.lock().unwrap().ops.push(WriterOp::Eos); + Poll::Ready(Ok(())) + } + } + + async fn expect_read_event( + worker: &mut crate::rpc::stream::test_io::MemoryFrameIo< + ReadCommand, + ReadEvent, + TestFrameIoError, + >, + ) -> ReadEvent { + worker + .next_frame() + .await + .expect("read event should be sent") + .expect("read event should decode") + } + + async fn expect_write_event( + worker: &mut crate::rpc::stream::test_io::MemoryFrameIo< + WriteCommand, + WriteEvent, + TestFrameIoError, + >, + ) -> WriteEvent { + worker + .next_frame() + .await + .expect("write event should be sent") + .expect("write event should decode") + } + + #[tokio::test] + async fn read_bridge_executes_pull_then_stop_in_received_order() { + let stop_code = VarInt::from_u32(301); + let data = Bytes::from_static(b"read bridge data"); + let (reader, handle) = recording_reader(VarInt::from_u32(302)); + let (mut worker, hypervisor) = worker_reader_pair::(); + let task = tokio::spawn(run_read_bridge(reader, hypervisor)); + + worker + .send(ReadCommand::Pull) + .await + .expect("pull should send"); + worker + .send(ReadCommand::Stop { code: stop_code }) + .await + .expect("stop should send behind pull"); + handle.push(data.clone()); + + assert_eq!( + expect_read_event(&mut worker).await, + ReadEvent::Push { data } + ); + assert_eq!( + expect_read_event(&mut worker).await, + ReadEvent::StopAck { code: stop_code } + ); + assert_eq!(handle.stops(), [stop_code]); + drop(worker); + tokio::time::timeout(TEST_TIMEOUT, task) + .await + .expect("read bridge should finish after worker eof") + .expect("read bridge task should not panic"); + } + + #[tokio::test] + async fn read_bridge_worker_eof_does_not_generate_stop() { + let (reader, handle) = recording_reader(VarInt::from_u32(303)); + let (worker, hypervisor) = worker_reader_pair::(); + let task = tokio::spawn(run_read_bridge(reader, hypervisor)); + + drop(worker); + + tokio::time::timeout(TEST_TIMEOUT, task) + .await + .expect("read bridge should finish after worker eof") + .expect("read bridge task should not panic"); + assert!(handle.stops().is_empty()); + } + + #[tokio::test] + async fn read_bridge_ignores_worker_outbound_eof_while_replying() { + let (reader, handle) = recording_reader(VarInt::from_u32(304)); + let (mut worker, hypervisor) = worker_reader_pair::(); + let task = tokio::spawn(run_read_bridge(reader, hypervisor)); + + worker + .send(ReadCommand::Pull) + .await + .expect("pull should send"); + drop(worker); + handle.push(Bytes::from_static(b"abandoned")); + + tokio::time::timeout(TEST_TIMEOUT, task) + .await + .expect("read bridge should finish after abandoned response") + .expect("read bridge task should not panic"); + } + + #[tokio::test] + async fn write_bridge_sends_initial_pull_credit() { + let (writer, _handle) = recording_writer(VarInt::from_u32(305)); + let (mut worker, hypervisor) = worker_writer_pair::(); + let task = tokio::spawn(run_write_bridge(writer, hypervisor)); + + assert_eq!(expect_write_event(&mut worker).await, WriteEvent::Pull); + drop(worker); + tokio::time::timeout(TEST_TIMEOUT, task) + .await + .expect("write bridge should finish after worker eof") + .expect("write bridge task should not panic"); + } + + #[tokio::test] + async fn write_bridge_treats_flush_and_eos_as_fifo_barriers() { + let data = Bytes::from_static(b"fifo data"); + let (writer, handle) = recording_writer(VarInt::from_u32(306)); + let (mut worker, hypervisor) = worker_writer_pair::(); + let task = tokio::spawn(run_write_bridge(writer, hypervisor)); + + assert_eq!(expect_write_event(&mut worker).await, WriteEvent::Pull); + worker + .send(WriteCommand::Push { data: data.clone() }) + .await + .expect("push should send"); + assert_eq!(expect_write_event(&mut worker).await, WriteEvent::Pull); + worker + .send(WriteCommand::Flush) + .await + .expect("flush should send"); + worker + .send(WriteCommand::Eos) + .await + .expect("eos should send"); + + assert_eq!(expect_write_event(&mut worker).await, WriteEvent::FlushAck); + assert_eq!(expect_write_event(&mut worker).await, WriteEvent::EosAck); + assert_eq!( + handle.ops(), + [WriterOp::Push(data), WriterOp::Flush, WriterOp::Eos] + ); + drop(worker); + tokio::time::timeout(TEST_TIMEOUT, task) + .await + .expect("write bridge should finish after eos") + .expect("write bridge task should not panic"); + } + + #[tokio::test] + async fn write_bridge_does_not_duplicate_pull_after_flush_with_outstanding_credit() { + let first = Bytes::from_static(b"first"); + let second = Bytes::from_static(b"second"); + let (writer, handle) = recording_writer(VarInt::from_u32(313)); + let (mut worker, hypervisor) = worker_writer_pair::(); + let task = tokio::spawn(run_write_bridge(writer, hypervisor)); + + assert_eq!(expect_write_event(&mut worker).await, WriteEvent::Pull); + worker + .send(WriteCommand::Push { + data: first.clone(), + }) + .await + .expect("first push should send"); + assert_eq!(expect_write_event(&mut worker).await, WriteEvent::Pull); + worker + .send(WriteCommand::Flush) + .await + .expect("flush should send"); + assert_eq!(expect_write_event(&mut worker).await, WriteEvent::FlushAck); + + match tokio::time::timeout(Duration::from_millis(10), worker.next_frame()).await { + Ok(Some(Ok(frame))) => panic!("flush must not duplicate pull credit: {frame:?}"), + Ok(Some(Err(error))) => panic!("unexpected frame error after flush: {error:?}"), + Ok(None) | Err(_) => {} + } + + worker + .send(WriteCommand::Push { + data: second.clone(), + }) + .await + .expect("second push should use outstanding credit"); + assert_eq!(expect_write_event(&mut worker).await, WriteEvent::Pull); + assert_eq!( + handle.ops(), + [ + WriterOp::Push(first), + WriterOp::Flush, + WriterOp::Push(second) + ] + ); + drop(worker); + tokio::time::timeout(TEST_TIMEOUT, task) + .await + .expect("write bridge should finish after worker eof") + .expect("write bridge task should not panic"); + } + + #[tokio::test] + async fn write_bridge_reset_clears_queued_work_without_interrupting_current_push() { + let reset = VarInt::from_u32(307); + let first = Bytes::from_static(b"first"); + let (writer, handle) = recording_writer(VarInt::from_u32(308)); + handle.hold_flush(true); + let (mut worker, hypervisor) = worker_writer_pair::(); + let task = tokio::spawn(run_write_bridge(writer, hypervisor)); + + assert_eq!(expect_write_event(&mut worker).await, WriteEvent::Pull); + worker + .send(WriteCommand::Push { + data: first.clone(), + }) + .await + .expect("first push should send"); + worker + .send(WriteCommand::Flush) + .await + .expect("flush should queue"); + worker + .send(WriteCommand::Reset { code: reset }) + .await + .expect("reset should queue"); + + tokio::task::yield_now().await; + assert_eq!(handle.ops(), []); + handle.hold_flush(false); + + assert_eq!( + expect_write_event(&mut worker).await, + WriteEvent::ResetAck { code: reset } + ); + assert_eq!( + handle.ops(), + [WriterOp::Push(first), WriterOp::Reset(reset)] + ); + drop(worker); + tokio::time::timeout(TEST_TIMEOUT, task) + .await + .expect("write bridge should finish after reset") + .expect("write bridge task should not panic"); + } + + #[tokio::test] + async fn write_bridge_does_not_pull_after_terminal_command() { + let reset = VarInt::from_u32(309); + let (writer, _handle) = recording_writer(VarInt::from_u32(310)); + let (mut worker, hypervisor) = worker_writer_pair::(); + let task = tokio::spawn(run_write_bridge(writer, hypervisor)); + + assert_eq!(expect_write_event(&mut worker).await, WriteEvent::Pull); + worker + .send(WriteCommand::Reset { code: reset }) + .await + .expect("reset should send"); + assert_eq!( + expect_write_event(&mut worker).await, + WriteEvent::ResetAck { code: reset } + ); + match tokio::time::timeout(Duration::from_millis(10), worker.next_frame()).await { + Ok(Some(Ok(frame))) => panic!("terminal reset must not send another frame: {frame:?}"), + Ok(Some(Err(error))) => panic!("terminal reset produced frame error: {error:?}"), + Ok(None) | Err(_) => {} + } + drop(worker); + tokio::time::timeout(TEST_TIMEOUT, task) + .await + .expect("write bridge should finish after reset") + .expect("write bridge task should not panic"); + } + + #[tokio::test] + async fn write_bridge_worker_eof_drains_queued_work_without_fin_or_reset() { + let data = Bytes::from_static(b"queued before eof"); + let (writer, handle) = recording_writer(VarInt::from_u32(311)); + let (mut worker, hypervisor) = worker_writer_pair::(); + let task = tokio::spawn(run_write_bridge(writer, hypervisor)); + + assert_eq!(expect_write_event(&mut worker).await, WriteEvent::Pull); + worker + .send(WriteCommand::Push { data: data.clone() }) + .await + .expect("push should send"); + worker + .send(WriteCommand::Flush) + .await + .expect("flush should send"); + drop(worker); + + tokio::time::timeout(TEST_TIMEOUT, task) + .await + .expect("write bridge should drain after worker eof") + .expect("write bridge task should not panic"); + assert_eq!(handle.ops(), [WriterOp::Push(data), WriterOp::Flush]); + } + + #[tokio::test] + async fn write_bridge_ignores_worker_outbound_eof_while_sending_ack() { + let (writer, handle) = recording_writer(VarInt::from_u32(312)); + let (mut worker, hypervisor) = worker_writer_pair::(); + let task = tokio::spawn(run_write_bridge(writer, hypervisor)); + + assert_eq!(expect_write_event(&mut worker).await, WriteEvent::Pull); + worker + .send(WriteCommand::Flush) + .await + .expect("flush should send"); + drop(worker); + + tokio::time::timeout(TEST_TIMEOUT, task) + .await + .expect("write bridge should ignore closed ack receiver") + .expect("write bridge task should not panic"); + assert_eq!(handle.ops(), [WriterOp::Flush]); + } +} diff --git a/src/rpc/stream/hypervisor/read.rs b/src/rpc/stream/hypervisor/read.rs new file mode 100644 index 0000000..9996abf --- /dev/null +++ b/src/rpc/stream/hypervisor/read.rs @@ -0,0 +1,105 @@ +use std::{collections::VecDeque, error::Error}; + +use futures::{SinkExt as _, StreamExt as _}; + +use crate::{ + quic, + rpc::stream::{ + frame::{ReadCommand, ReadEvent}, + io::FrameIo, + }, +}; + +pub(crate) async fn run_read_bridge(mut reader: R, mut bridge: Io) +where + R: quic::ReadStream + Unpin, + Io: FrameIo + Send + Unpin, + E: Error + Send + 'static, +{ + let mut queue = VecDeque::new(); + let mut inbound_closed = false; + + loop { + if let Some(command) = queue.pop_front() { + let event = { + let mut current = Box::pin(run_read_job(&mut reader, command)); + loop { + tokio::select! { + event = &mut current => break event, + inbound = bridge.next(), if !inbound_closed => { + match inbound { + Some(Ok(command)) => queue.push_back(command), + Some(Err(error)) => { + let report = snafu::Report::from_error(&error); + tracing::warn!(error = %report, "stream frame read bridge input failed"); + return; + } + None => { + inbound_closed = true; + } + } + } + } + } + }; + + let terminal = matches!( + event, + ReadEvent::Eos | ReadEvent::ErrReset { .. } | ReadEvent::ErrConn + ); + send_read_event(&mut bridge, event).await; + if terminal { + return; + } + continue; + } + + if inbound_closed { + return; + } + + match bridge.next().await { + Some(Ok(command)) => queue.push_back(command), + Some(Err(error)) => { + let report = snafu::Report::from_error(&error); + tracing::warn!(error = %report, "stream frame read bridge input failed"); + return; + } + None => inbound_closed = true, + } + } +} + +async fn run_read_job(reader: &mut R, command: ReadCommand) -> ReadEvent +where + R: quic::ReadStream + Unpin, +{ + match command { + ReadCommand::Pull => match reader.next().await { + Some(Ok(data)) => ReadEvent::Push { data }, + Some(Err(quic::StreamError::Reset { code })) => ReadEvent::ErrReset { code }, + Some(Err(quic::StreamError::Connection { .. })) => ReadEvent::ErrConn, + None => ReadEvent::Eos, + }, + ReadCommand::Stop { code } => match futures::future::poll_fn(|cx| { + quic::StopStream::poll_stop(std::pin::Pin::new(&mut *reader), cx, code) + }) + .await + { + Ok(()) => ReadEvent::StopAck { code }, + Err(quic::StreamError::Reset { code }) => ReadEvent::ErrReset { code }, + Err(quic::StreamError::Connection { .. }) => ReadEvent::ErrConn, + }, + } +} + +async fn send_read_event(bridge: &mut Io, event: ReadEvent) +where + Io: FrameIo + Unpin, + E: Error + 'static, +{ + if let Err(error) = bridge.send(event).await { + let report = snafu::Report::from_error(&error); + tracing::debug!(error = %report, "stream frame read bridge output failed"); + } +} diff --git a/src/rpc/stream/hypervisor/write.rs b/src/rpc/stream/hypervisor/write.rs new file mode 100644 index 0000000..5751a25 --- /dev/null +++ b/src/rpc/stream/hypervisor/write.rs @@ -0,0 +1,342 @@ +use std::{collections::VecDeque, error::Error, pin::Pin, task::Poll}; + +use bytes::Bytes; +use futures::{SinkExt as _, StreamExt as _}; +use snafu::Snafu; + +use crate::{ + quic, + rpc::stream::{ + frame::{WriteCommand, WriteEvent}, + io::FrameIo, + }, + varint::VarInt, +}; + +#[derive(Debug, Snafu)] +#[snafu(module)] +enum WriteHypervisorProtocolError { + #[snafu(display("received reset code {actual} after committed reset code {expected}"))] + ConflictingReset { expected: VarInt, actual: VarInt }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum HyperWriteJob { + Push { data: Bytes }, + Flush, + Eos, + Reset { code: VarInt }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum HyperWriteDone { + Push, + Flush, + Eos, + Reset { code: VarInt }, +} + +pub(crate) async fn run_write_bridge(mut writer: W, mut bridge: Io) +where + W: quic::WriteStream + Unpin, + Io: FrameIo + Send + Unpin, + E: Error + Send + 'static, +{ + send_write_event(&mut bridge, WriteEvent::Pull).await; + + let mut queue = VecDeque::new(); + let mut reset_code = None; + let mut inbound_closed = false; + let mut credit_outstanding = true; + + loop { + if let Some(job) = queue.pop_front() { + let done = { + let mut current = Box::pin(run_write_job(&mut writer, job)); + loop { + tokio::select! { + result = &mut current => break result, + inbound = bridge.next(), if !inbound_closed => { + match inbound { + Some(Ok(command)) => { + if !record_command( + &mut queue, + &mut reset_code, + &mut credit_outstanding, + command, + ) { + return; + } + } + Some(Err(error)) => { + let report = snafu::Report::from_error(&error); + tracing::warn!(error = %report, "stream frame write bridge input failed"); + return; + } + None => { + inbound_closed = true; + } + } + } + } + } + }; + + match done { + Ok(HyperWriteDone::Push) => { + if should_send_credit(inbound_closed, reset_code, &queue, credit_outstanding) + && !send_write_credit( + &mut bridge, + &mut queue, + &mut reset_code, + &mut inbound_closed, + &mut credit_outstanding, + ) + .await + { + return; + } + } + Ok(HyperWriteDone::Flush) => { + send_write_event(&mut bridge, WriteEvent::FlushAck).await; + if should_send_credit(inbound_closed, reset_code, &queue, credit_outstanding) + && !send_write_credit( + &mut bridge, + &mut queue, + &mut reset_code, + &mut inbound_closed, + &mut credit_outstanding, + ) + .await + { + return; + } + } + Ok(HyperWriteDone::Eos) => { + send_write_event(&mut bridge, WriteEvent::EosAck).await; + return; + } + Ok(HyperWriteDone::Reset { code }) => { + send_write_event(&mut bridge, WriteEvent::ResetAck { code }).await; + return; + } + Err(quic::StreamError::Reset { code }) => { + send_write_event(&mut bridge, WriteEvent::ErrReset { code }).await; + return; + } + Err(quic::StreamError::Connection { .. }) => { + send_write_event(&mut bridge, WriteEvent::ErrConn).await; + return; + } + } + continue; + } + + if inbound_closed { + return; + } + + match bridge.next().await { + Some(Ok(command)) => { + if !record_command( + &mut queue, + &mut reset_code, + &mut credit_outstanding, + command, + ) { + return; + } + } + Some(Err(error)) => { + let report = snafu::Report::from_error(&error); + tracing::warn!(error = %report, "stream frame write bridge input failed"); + return; + } + None => inbound_closed = true, + } + } +} + +fn record_command( + queue: &mut VecDeque, + reset_code: &mut Option, + credit_outstanding: &mut bool, + command: WriteCommand, +) -> bool { + match command { + WriteCommand::Push { data } => { + *credit_outstanding = false; + if reset_code.is_none() && !has_terminal(queue) { + queue.push_back(HyperWriteJob::Push { data }); + } + true + } + WriteCommand::Flush => { + if reset_code.is_none() && !has_terminal(queue) { + queue.push_back(HyperWriteJob::Flush); + } + true + } + WriteCommand::Eos => { + if reset_code.is_none() && !has_terminal(queue) { + queue.push_back(HyperWriteJob::Eos); + } + true + } + WriteCommand::Reset { code } => match *reset_code { + Some(committed) if committed == code => true, + Some(expected) => { + let error = WriteHypervisorProtocolError::ConflictingReset { + expected, + actual: code, + }; + let report = snafu::Report::from_error(&error); + tracing::warn!(error = %report, "stream frame write bridge input failed"); + false + } + None => { + *reset_code = Some(code); + queue.clear(); + queue.push_back(HyperWriteJob::Reset { code }); + true + } + }, + } +} + +fn has_terminal(queue: &VecDeque) -> bool { + queue + .iter() + .any(|job| matches!(job, HyperWriteJob::Eos | HyperWriteJob::Reset { .. })) +} + +fn should_send_credit( + inbound_closed: bool, + reset_code: Option, + queue: &VecDeque, + credit_outstanding: bool, +) -> bool { + !credit_outstanding && !inbound_closed && reset_code.is_none() && !has_terminal(queue) +} + +async fn run_write_job( + writer: &mut W, + job: HyperWriteJob, +) -> Result +where + W: quic::WriteStream + Unpin, +{ + match job { + HyperWriteJob::Push { data } => { + writer.send(data).await?; + Ok(HyperWriteDone::Push) + } + HyperWriteJob::Flush => { + writer.flush().await?; + Ok(HyperWriteDone::Flush) + } + HyperWriteJob::Eos => { + writer.close().await?; + Ok(HyperWriteDone::Eos) + } + HyperWriteJob::Reset { code } => { + quic::ResetStreamExt::reset(writer, code).await?; + Ok(HyperWriteDone::Reset { code }) + } + } +} + +async fn send_write_event(bridge: &mut Io, event: WriteEvent) +where + Io: FrameIo + Unpin, + E: Error + 'static, +{ + if let Err(error) = bridge.send(event).await { + let report = snafu::Report::from_error(&error); + tracing::debug!(error = %report, "stream frame write bridge output failed"); + } +} + +enum CreditSendState { + Ready, + Flush, + Done, +} + +async fn send_write_credit( + bridge: &mut Io, + queue: &mut VecDeque, + reset_code: &mut Option, + inbound_closed: &mut bool, + credit_outstanding: &mut bool, +) -> bool +where + Io: FrameIo + Unpin, + E: Error + 'static, +{ + let mut state = CreditSendState::Ready; + futures::future::poll_fn(|cx| { + if !*inbound_closed { + match Pin::new(&mut *bridge).poll_next(cx) { + Poll::Ready(Some(Ok(command))) => { + if record_command(queue, reset_code, credit_outstanding, command) { + return Poll::Ready(true); + } + return Poll::Ready(false); + } + Poll::Ready(Some(Err(error))) => { + let report = snafu::Report::from_error(&error); + tracing::warn!(error = %report, "stream frame write bridge input failed"); + return Poll::Ready(false); + } + Poll::Ready(None) => { + *inbound_closed = true; + return Poll::Ready(true); + } + Poll::Pending => {} + } + } + + loop { + match state { + CreditSendState::Ready => match Pin::new(&mut *bridge).poll_ready(cx) { + Poll::Ready(Ok(())) => { + if let Err(error) = Pin::new(&mut *bridge).start_send(WriteEvent::Pull) { + let report = snafu::Report::from_error(&error); + tracing::debug!( + error = %report, + "stream frame write bridge output failed" + ); + state = CreditSendState::Done; + return Poll::Ready(true); + } + *credit_outstanding = true; + state = CreditSendState::Flush; + } + Poll::Ready(Err(error)) => { + let report = snafu::Report::from_error(&error); + tracing::debug!(error = %report, "stream frame write bridge output failed"); + state = CreditSendState::Done; + return Poll::Ready(true); + } + Poll::Pending => return Poll::Pending, + }, + CreditSendState::Flush => match Pin::new(&mut *bridge).poll_flush(cx) { + Poll::Ready(Ok(())) => { + state = CreditSendState::Done; + return Poll::Ready(true); + } + Poll::Ready(Err(error)) => { + let report = snafu::Report::from_error(&error); + tracing::debug!(error = %report, "stream frame write bridge output failed"); + state = CreditSendState::Done; + return Poll::Ready(true); + } + Poll::Pending => return Poll::Pending, + }, + CreditSendState::Done => return Poll::Ready(true), + } + } + }) + .await +} diff --git a/src/rpc/stream/io.rs b/src/rpc/stream/io.rs new file mode 100644 index 0000000..99f4dea --- /dev/null +++ b/src/rpc/stream/io.rs @@ -0,0 +1,13 @@ +#![allow(dead_code)] + +use futures::{Sink, Stream}; + +pub(crate) trait FrameIo: + Sink + Stream> +{ +} + +impl FrameIo for T where + T: Sink + Stream> +{ +} diff --git a/src/rpc/stream/reader.rs b/src/rpc/stream/reader.rs new file mode 100644 index 0000000..def6f3d --- /dev/null +++ b/src/rpc/stream/reader.rs @@ -0,0 +1,1300 @@ +use std::{ + marker::PhantomData, + pin::Pin, + sync::Arc, + task::{Context, Poll, ready}, +}; + +use bytes::Bytes; +use futures::{Stream, stream::FusedStream}; + +use super::{ + drain, + error::{ + DeferredStreamError, DriverProtocolError, defer_err_conn, latch_frame_io_error, + latch_protocol_error, + }, + frame::{ReadCommand, ReadEvent}, +}; +use crate::{quic, varint::VarInt}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ReadOperation { + Pull, + Stop { code: VarInt }, +} + +#[derive(Debug, Default)] +struct PendingRead { + pull: bool, + stop: Option, +} + +#[derive(Debug)] +enum SendState { + Idle { ready: bool }, + Flush { inflight: Op }, +} + +impl Default for SendState { + fn default() -> Self { + Self::Idle { ready: false } + } +} + +#[derive(Debug, Default)] +enum ReaderRecvState { + #[default] + Idle, + Receive { + sent: ReadOperation, + }, +} + +#[derive(Debug, Default)] +struct ReaderResults { + pull: Option, + stop: Option, +} + +#[derive(Debug)] +enum ReaderPullResult { + Push(Bytes), + Eos, +} + +#[derive(Debug)] +struct StopAckResult { + code: VarInt, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CommitResult { + Committed, + Duplicate, + Conflict, +} + +enum ReaderFault { + Stream(quic::StreamError), + Deferred(DeferredStreamError), +} + +pin_project_lite::pin_project! { + #[project = BridgeStreamReaderProj] + #[project_replace = BridgeStreamReaderReplace] + pub(crate) enum BridgeStreamReader + where + L: drain::DrainLifecycle, + Io: drain::ReadDrainIo, + E: 'static, + quic::ConnectionError: From, + { + Active { + active: Pin>>, + }, + Eos { + stream_id: VarInt, + }, + Closed { + stream_id: VarInt, + error: DeferredStreamError, + }, + } + + impl PinnedDrop for BridgeStreamReader + where + L: drain::DrainLifecycle, + Io: drain::ReadDrainIo, + E: 'static, + quic::ConnectionError: From, + { + fn drop(mut this: Pin<&mut Self>) { + let stream_id = match this.as_mut().project() { + BridgeStreamReaderProj::Active { active } => active.as_ref().get_ref().stream_id, + BridgeStreamReaderProj::Eos { .. } | BridgeStreamReaderProj::Closed { .. } => { + return; + } + }; + + match this.project_replace(BridgeStreamReader::Eos { stream_id }) { + BridgeStreamReaderReplace::Active { active } + if active.as_ref().get_ref().has_committed_outbound() => + { + drain::spawn_read_drain(active); + } + _ => {} + } + } + } +} + +pin_project_lite::pin_project! { + pub(crate) struct ActiveBridgeStreamReader { + stream_id: VarInt, + lifecycle: Arc, + #[pin] + bridge: Io, + pending: PendingRead, + send_state: SendState, + recv_state: ReaderRecvState, + results: ReaderResults, + _error: PhantomData, + } +} + +impl BridgeStreamReader +where + L: drain::DrainLifecycle, + Io: drain::ReadDrainIo, + E: 'static, + quic::ConnectionError: From, +{ + pub(crate) fn new(stream_id: VarInt, bridge: Io, lifecycle: Arc) -> Self { + Self::Active { + active: Box::pin(ActiveBridgeStreamReader { + stream_id, + lifecycle, + bridge, + pending: PendingRead::default(), + send_state: SendState::default(), + recv_state: ReaderRecvState::default(), + results: ReaderResults::default(), + _error: PhantomData, + }), + } + } +} + +impl ActiveBridgeStreamReader { + pub(super) fn has_committed_outbound(&self) -> bool { + self.pending.pull + || self.pending.stop.is_some() + || matches!(self.send_state, SendState::Flush { .. }) + } + + fn committed_stop_code(&self) -> Option { + if let Some(code) = self.pending.stop { + return Some(code); + } + if let SendState::Flush { + inflight: ReadOperation::Stop { code }, + } = self.send_state + { + return Some(code); + } + if let ReaderRecvState::Receive { + sent: ReadOperation::Stop { code }, + } = self.recv_state + { + return Some(code); + } + self.results.stop.as_ref().map(|result| result.code) + } + + fn stop_send_is_blocked_by_pull(&self) -> bool { + self.pending.pull + || matches!( + self.send_state, + SendState::Flush { + inflight: ReadOperation::Pull + } + ) + || matches!( + self.recv_state, + ReaderRecvState::Receive { + sent: ReadOperation::Pull + } + ) + } + + fn has_pull_eos(&self) -> bool { + matches!(self.results.pull, Some(ReaderPullResult::Eos)) + } + + fn commit_pull_pinned(mut self: Pin<&mut Self>) -> CommitResult { + let this = self.as_mut().project(); + if this.pending.pull + || matches!( + this.send_state, + SendState::Flush { + inflight: ReadOperation::Pull + } + ) + || matches!( + this.recv_state, + ReaderRecvState::Receive { + sent: ReadOperation::Pull + } + ) + || this.results.pull.is_some() + { + return CommitResult::Duplicate; + } + this.pending.pull = true; + CommitResult::Committed + } + + fn commit_stop_pinned(mut self: Pin<&mut Self>, code: VarInt) -> CommitResult { + match self.as_ref().get_ref().committed_stop_code() { + Some(committed) if committed == code => CommitResult::Duplicate, + Some(_committed) => CommitResult::Conflict, + None => { + let this = self.as_mut().project(); + this.pending.pull = false; + this.pending.stop = Some(code); + CommitResult::Committed + } + } + } + + fn take_pull_result_pinned(mut self: Pin<&mut Self>) -> Option { + self.as_mut().project().results.pull.take() + } +} + +impl ActiveBridgeStreamReader +where + L: drain::DrainLifecycle, + Io: drain::ReadDrainIo, + E: 'static, + quic::ConnectionError: From, +{ + fn protocol_fault_for(lifecycle: &Arc, error: DriverProtocolError) -> ReaderFault { + ReaderFault::Stream(latch_protocol_error(lifecycle.as_ref(), error)) + } + + fn frame_io_fault_for(lifecycle: &Arc, error: E) -> ReaderFault { + ReaderFault::Stream(latch_frame_io_error(lifecycle.as_ref(), error)) + } + + fn pair_response( + lifecycle: &Arc, + results: &mut ReaderResults, + sent: ReadOperation, + frame: ReadEvent, + ) -> Result<(), ReaderFault> { + match (sent, frame) { + (ReadOperation::Pull, ReadEvent::Push { data }) => { + results.pull = Some(ReaderPullResult::Push(data)); + Ok(()) + } + (ReadOperation::Pull, ReadEvent::Eos) => { + results.pull = Some(ReaderPullResult::Eos); + Ok(()) + } + (ReadOperation::Stop { code }, ReadEvent::StopAck { code: actual }) + if code == actual => + { + results.stop = Some(StopAckResult { code }); + Ok(()) + } + (ReadOperation::Stop { code }, ReadEvent::StopAck { code: actual }) => { + Err(Self::protocol_fault_for( + lifecycle, + DriverProtocolError::StopAckCodeMismatch { + expected: code, + actual, + }, + )) + } + (_sent, ReadEvent::ErrReset { code }) => { + Err(ReaderFault::Stream(quic::StreamError::Reset { code })) + } + (_sent, ReadEvent::ErrConn) => { + Err(ReaderFault::Deferred(defer_err_conn(lifecycle.clone()))) + } + (_sent, _frame) => Err(Self::protocol_fault_for( + lifecycle, + DriverProtocolError::UnexpectedReaderFrame, + )), + } + } + + fn clear_pending_stop_if_matches(mut self: Pin<&mut Self>, code: VarInt) { + let this = self.as_mut().project(); + if this.pending.stop == Some(code) { + this.pending.stop = None; + } + } + + fn pending_stop_matches(&self, code: VarInt) -> bool { + self.pending.stop == Some(code) + } + + fn stop_ready_to_send(&self, code: VarInt) -> bool { + self.pending_stop_matches(code) && !self.stop_send_is_blocked_by_pull() + } + + fn completed_stop_matches(&self, code: VarInt) -> bool { + matches!(&self.results.stop, Some(result) if result.code == code) + } + + fn completed_pull_result(&self) -> bool { + self.results.pull.is_some() + } + + fn poll_idle(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + let mut this = self.as_mut().project(); + match (&mut *this.send_state, &mut *this.recv_state) { + (SendState::Idle { .. }, ReaderRecvState::Idle) => { + return Poll::Ready(Ok(())); + } + (SendState::Flush { inflight }, ReaderRecvState::Idle) => { + match ready!(this.bridge.as_mut().poll_flush(cx)) { + Ok(()) => { + let sent = *inflight; + *this.send_state = SendState::Idle { ready: false }; + *this.recv_state = ReaderRecvState::Receive { sent }; + } + Err(error) => { + return Poll::Ready(Err(Self::frame_io_fault_for( + this.lifecycle, + error, + ))); + } + } + } + (SendState::Idle { .. }, ReaderRecvState::Receive { sent }) => { + let sent = *sent; + match ready!(this.bridge.as_mut().poll_next(cx)) { + Some(Ok(frame)) => { + *this.recv_state = ReaderRecvState::Idle; + Self::pair_response(this.lifecycle, this.results, sent, frame)?; + } + Some(Err(error)) => { + return Poll::Ready(Err(Self::frame_io_fault_for( + this.lifecycle, + error, + ))); + } + None => { + return Poll::Ready(Err(Self::protocol_fault_for( + this.lifecycle, + DriverProtocolError::FrameEof, + ))); + } + } + } + (SendState::Flush { .. }, ReaderRecvState::Receive { .. }) => { + unreachable!("reader bridge cannot flush and receive simultaneously"); + } + } + } + } + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.as_mut().poll_idle(cx))?; + + let mut this = self.as_mut().project(); + match &mut *this.send_state { + SendState::Idle { ready: true } => Poll::Ready(Ok(())), + SendState::Idle { ready } => match ready!(this.bridge.as_mut().poll_ready(cx)) { + Ok(()) => { + *ready = true; + Poll::Ready(Ok(())) + } + Err(error) => Poll::Ready(Err(Self::frame_io_fault_for(this.lifecycle, error))), + }, + SendState::Flush { .. } => unreachable!("poll_idle must drive flush before readiness"), + } + } + + fn poll_drain_ready( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let mut this = self.as_mut().project(); + match &mut *this.send_state { + SendState::Idle { ready: true } => Poll::Ready(Ok(())), + SendState::Idle { ready } => match ready!(this.bridge.as_mut().poll_ready(cx)) { + Ok(()) => { + *ready = true; + Poll::Ready(Ok(())) + } + Err(error) => { + Poll::Ready(Err(latch_frame_io_error(this.lifecycle.as_ref(), error))) + } + }, + SendState::Flush { .. } => unreachable!("reader drain readiness while flushing"), + } + } + + fn start_drain_operation( + mut self: Pin<&mut Self>, + operation: ReadOperation, + ) -> Result<(), quic::StreamError> { + let command = match operation { + ReadOperation::Pull => ReadCommand::Pull, + ReadOperation::Stop { code } => ReadCommand::Stop { code }, + }; + + let mut this = self.as_mut().project(); + match this.bridge.as_mut().start_send(command) { + Ok(()) => { + *this.send_state = SendState::Flush { + inflight: operation, + }; + Ok(()) + } + Err(error) => Err(latch_frame_io_error(this.lifecycle.as_ref(), error)), + } + } + + pub(super) fn poll_drain(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + loop { + if matches!(self.as_ref().get_ref().send_state, SendState::Flush { .. }) { + let mut this = self.as_mut().project(); + match ready!(this.bridge.as_mut().poll_flush(cx)) { + Ok(()) => { + *this.send_state = SendState::Idle { ready: false }; + continue; + } + Err(error) => { + let error = latch_frame_io_error(this.lifecycle.as_ref(), error); + drain::log_drain_error(&error, "reader"); + return Poll::Ready(()); + } + } + } + + if self.as_ref().get_ref().pending.pull { + if let Err(error) = ready!(self.as_mut().poll_drain_ready(cx)) { + drain::log_drain_error(&error, "reader"); + return Poll::Ready(()); + } + self.as_mut().project().pending.pull = false; + if let Err(error) = self.as_mut().start_drain_operation(ReadOperation::Pull) { + drain::log_drain_error(&error, "reader"); + return Poll::Ready(()); + } + continue; + } + + if let Some(code) = self.as_ref().get_ref().pending.stop { + if let Err(error) = ready!(self.as_mut().poll_drain_ready(cx)) { + drain::log_drain_error(&error, "reader"); + return Poll::Ready(()); + } + self.as_mut().project().pending.stop = None; + if let Err(error) = self + .as_mut() + .start_drain_operation(ReadOperation::Stop { code }) + { + drain::log_drain_error(&error, "reader"); + return Poll::Ready(()); + } + continue; + } + + return Poll::Ready(()); + } + } + + fn drive_pull(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + if self.as_ref().get_ref().completed_pull_result() { + return Poll::Ready(Ok(())); + } + + ready!(self.as_mut().poll_idle(cx))?; + + if self.as_ref().get_ref().completed_pull_result() { + return Poll::Ready(Ok(())); + } + + if self.as_ref().get_ref().pending.pull { + ready!(self.as_mut().poll_ready(cx))?; + + let mut this = self.as_mut().project(); + this.pending.pull = false; + match this.bridge.as_mut().start_send(ReadCommand::Pull) { + Ok(()) => { + *this.send_state = SendState::Flush { + inflight: ReadOperation::Pull, + }; + } + Err(error) => { + return Poll::Ready(Err(Self::frame_io_fault_for(this.lifecycle, error))); + } + } + continue; + } + + return Poll::Pending; + } + } + + fn drive_stop( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + committed_code: VarInt, + ) -> Poll> { + loop { + if self.as_ref().get_ref().has_pull_eos() { + self.as_mut().clear_pending_stop_if_matches(committed_code); + return Poll::Ready(Ok(())); + } + + if self + .as_ref() + .get_ref() + .completed_stop_matches(committed_code) + { + return Poll::Ready(Ok(())); + } + + ready!(self.as_mut().poll_idle(cx))?; + + if self.as_ref().get_ref().has_pull_eos() { + self.as_mut().clear_pending_stop_if_matches(committed_code); + return Poll::Ready(Ok(())); + } + + if self + .as_ref() + .get_ref() + .completed_stop_matches(committed_code) + { + return Poll::Ready(Ok(())); + } + + if self.as_ref().get_ref().stop_ready_to_send(committed_code) { + ready!(self.as_mut().poll_ready(cx))?; + + let mut this = self.as_mut().project(); + this.pending.stop = None; + match this.bridge.as_mut().start_send(ReadCommand::Stop { + code: committed_code, + }) { + Ok(()) => { + *this.send_state = SendState::Flush { + inflight: ReadOperation::Stop { + code: committed_code, + }, + }; + } + Err(error) => { + return Poll::Ready(Err(Self::frame_io_fault_for(this.lifecycle, error))); + } + } + continue; + } + + return Poll::Pending; + } + } +} + +fn poll_deferred_error( + error: &mut DeferredStreamError, + cx: &mut Context<'_>, +) -> Poll { + match error { + DeferredStreamError::Ready { error } => Poll::Ready(error.clone()), + DeferredStreamError::Pending { future } => { + let stream_error = ready!(future.as_mut().poll(cx)); + *error = DeferredStreamError::Ready { + error: stream_error.clone(), + }; + Poll::Ready(stream_error) + } + } +} + +fn ready_error(error: quic::StreamError) -> DeferredStreamError { + DeferredStreamError::Ready { error } +} + +impl BridgeStreamReader +where + L: drain::DrainLifecycle, + Io: drain::ReadDrainIo, + E: 'static, + quic::ConnectionError: From, +{ + fn stream_id(self: Pin<&mut Self>) -> VarInt { + match self.project() { + BridgeStreamReaderProj::Active { active } => active.as_ref().get_ref().stream_id, + BridgeStreamReaderProj::Eos { stream_id } => *stream_id, + BridgeStreamReaderProj::Closed { stream_id, .. } => *stream_id, + } + } + + fn close_with_stream_error( + mut self: Pin<&mut Self>, + error: quic::StreamError, + ) -> quic::StreamError { + let stream_id = self.as_mut().stream_id(); + self.as_mut().project_replace(Self::Closed { + stream_id, + error: ready_error(error.clone()), + }); + error + } + + fn close_with_fault( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + fault: ReaderFault, + ) -> Poll { + match fault { + ReaderFault::Stream(error) => Poll::Ready(self.close_with_stream_error(error)), + ReaderFault::Deferred(error) => { + let stream_id = self.as_mut().stream_id(); + self.as_mut() + .project_replace(Self::Closed { stream_id, error }); + match self.project() { + BridgeStreamReaderProj::Closed { error, .. } => poll_deferred_error(error, cx), + _ => unreachable!("reader should have transitioned to closed"), + } + } + } + } + + fn poll_closed_error(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project() { + BridgeStreamReaderProj::Closed { error, .. } => poll_deferred_error(error, cx), + _ => unreachable!("poll_closed_error called outside closed state"), + } + } +} + +impl Stream for BridgeStreamReader +where + L: drain::DrainLifecycle, + Io: drain::ReadDrainIo, + E: 'static, + quic::ConnectionError: From, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.as_mut().project() { + BridgeStreamReaderProj::Active { active } => { + let _commit = active.as_mut().commit_pull_pinned(); + match ready!(active.as_mut().drive_pull(cx)) { + Ok(()) => match active.as_mut().take_pull_result_pinned() { + Some(ReaderPullResult::Push(data)) => Poll::Ready(Some(Ok(data))), + Some(ReaderPullResult::Eos) => { + let stream_id = active.as_ref().get_ref().stream_id; + self.as_mut().project_replace(Self::Eos { stream_id }); + Poll::Ready(None) + } + None => Poll::Pending, + }, + Err(fault) => { + let error = ready!(self.as_mut().close_with_fault(cx, fault)); + Poll::Ready(Some(Err(error))) + } + } + } + BridgeStreamReaderProj::Eos { .. } => Poll::Ready(None), + BridgeStreamReaderProj::Closed { .. } => { + let error = ready!(self.as_mut().poll_closed_error(cx)); + Poll::Ready(Some(Err(error))) + } + } + } +} + +impl FusedStream for BridgeStreamReader +where + L: drain::DrainLifecycle, + Io: drain::ReadDrainIo, + E: 'static, + quic::ConnectionError: From, +{ + fn is_terminated(&self) -> bool { + matches!( + self, + Self::Eos { .. } + | Self::Closed { + error: DeferredStreamError::Ready { .. }, + .. + } + ) + } +} + +impl quic::GetStreamId for BridgeStreamReader +where + L: drain::DrainLifecycle, + Io: drain::ReadDrainIo, + E: 'static, + quic::ConnectionError: From, +{ + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + match self.project() { + BridgeStreamReaderProj::Active { active } => { + Poll::Ready(Ok(active.as_ref().get_ref().stream_id)) + } + BridgeStreamReaderProj::Eos { stream_id } => Poll::Ready(Ok(*stream_id)), + BridgeStreamReaderProj::Closed { stream_id, .. } => Poll::Ready(Ok(*stream_id)), + } + } +} + +impl quic::StopStream for BridgeStreamReader +where + L: drain::DrainLifecycle, + Io: drain::ReadDrainIo, + E: 'static, + quic::ConnectionError: From, +{ + fn poll_stop( + mut self: Pin<&mut Self>, + cx: &mut Context, + code: VarInt, + ) -> Poll> { + match self.as_mut().project() { + BridgeStreamReaderProj::Active { active } => { + let commit = active.as_mut().commit_stop_pinned(code); + let committed_code = match commit { + CommitResult::Committed | CommitResult::Duplicate => code, + CommitResult::Conflict => active + .as_ref() + .committed_stop_code() + .expect("conflicting stop must have committed code"), + }; + + match ready!(active.as_mut().drive_stop(cx, committed_code)) { + Ok(()) => { + if active.as_ref().has_pull_eos() { + let stream_id = active.as_ref().get_ref().stream_id; + self.as_mut().project_replace(Self::Eos { stream_id }); + } + Poll::Ready(Ok(())) + } + Err(fault) => { + let error = ready!(self.as_mut().close_with_fault(cx, fault)); + Poll::Ready(Err(error)) + } + } + } + BridgeStreamReaderProj::Eos { .. } => Poll::Ready(Ok(())), + BridgeStreamReaderProj::Closed { .. } => { + let error = ready!(self.as_mut().poll_closed_error(cx)); + Poll::Ready(Err(error)) + } + } + } +} + +#[cfg(test)] +mod tests { + use std::{borrow::Cow, future::pending, sync::Arc}; + + use bytes::Bytes; + use futures::{ + FutureExt as _, SinkExt as _, StreamExt as _, future::poll_fn, stream::FusedStream as _, + }; + + use super::BridgeStreamReader; + use crate::{ + quic::{self, GetStreamIdExt as _, StopStream, StopStreamExt as _}, + rpc::{ + lifecycle::{ConnectionErrorLatch, HasLatch, LifecycleExt}, + stream::{ + frame::{ReadCommand, ReadEvent}, + test_io::{MemoryFrameIo, TestFrameIoError, TestLifecycle, worker_reader_pair}, + }, + }, + varint::VarInt, + }; + + type WorkerReaderIo = MemoryFrameIo; + type HypervisorReaderIo = MemoryFrameIo; + type TestBridgeStreamReader = + BridgeStreamReader; + + struct PendingClosedLifecycle { + latch: ConnectionErrorLatch, + } + + impl PendingClosedLifecycle { + fn new() -> Self { + Self { + latch: ConnectionErrorLatch::new(), + } + } + } + + impl HasLatch for PendingClosedLifecycle { + fn latch(&self) -> &ConnectionErrorLatch { + &self.latch + } + } + + impl quic::Lifecycle for PendingClosedLifecycle { + fn close(&self, _code: crate::error::Code, _reason: Cow<'static, str>) {} + + fn check(&self) -> Result<(), quic::ConnectionError> { + self.check_with_probe(|| None) + } + + async fn closed(&self) -> quic::ConnectionError { + self.resolve_closed(pending()).await + } + } + + fn reader( + stream_id: VarInt, + ) -> ( + TestBridgeStreamReader, + HypervisorReaderIo, + Arc, + ) { + let lifecycle = Arc::new(TestLifecycle::new()); + let (worker, hypervisor) = worker_reader_pair::(); + ( + BridgeStreamReader::new(stream_id, worker, lifecycle.clone()), + hypervisor, + lifecycle, + ) + } + + fn pending_closed_reader( + stream_id: VarInt, + ) -> ( + BridgeStreamReader, + HypervisorReaderIo, + ) { + let lifecycle = Arc::new(PendingClosedLifecycle::new()); + let (worker, hypervisor) = worker_reader_pair::(); + ( + BridgeStreamReader::new(stream_id, worker, lifecycle), + hypervisor, + ) + } + + fn transport(error: &quic::ConnectionError) -> &quic::TransportError { + let quic::ConnectionError::Transport { source } = error else { + panic!("connection error should be transport-scoped"); + }; + source + } + + fn stream_connection(error: quic::StreamError) -> quic::ConnectionError { + let quic::StreamError::Connection { source } = error else { + panic!("stream error should be connection-scoped"); + }; + source + } + + async fn expect_command(hypervisor: &mut HypervisorReaderIo) -> ReadCommand { + hypervisor + .next_frame() + .await + .expect("bridge should send a frame") + .expect("command frame should be readable") + } + + #[tokio::test] + async fn poll_next_sends_one_pull_and_returns_push_bytes() { + let (mut bridge, mut hypervisor, _lifecycle) = reader(VarInt::from_u32(7)); + let data = Bytes::from_static(b"read payload"); + + assert!( + poll_fn(|cx| bridge.poll_next_unpin(cx)) + .now_or_never() + .is_none(), + "first poll_next should commit a pull and wait for a response" + ); + assert_eq!(expect_command(&mut hypervisor).await, ReadCommand::Pull); + assert!( + hypervisor.next_frame().now_or_never().is_none(), + "poll_next must not send a duplicate pull while the first is in flight" + ); + + hypervisor + .send(ReadEvent::Push { data: data.clone() }) + .await + .expect("push event should send"); + + assert_eq!( + bridge.next().await.expect("reader item").expect("read ok"), + data + ); + } + + #[tokio::test] + async fn poll_next_returns_none_for_eos_and_stays_ended() { + let stream_id = VarInt::from_u32(8); + let (mut bridge, mut hypervisor, _lifecycle) = reader(stream_id); + + assert!( + poll_fn(|cx| bridge.poll_next_unpin(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(expect_command(&mut hypervisor).await, ReadCommand::Pull); + hypervisor + .send(ReadEvent::Eos) + .await + .expect("eos event should send"); + + assert!(bridge.next().await.is_none()); + assert_eq!(bridge.stream_id().await.expect("stream id"), stream_id); + assert!(matches!(bridge.next().now_or_never(), Some(None))); + } + + #[tokio::test] + async fn poll_stop_keeps_first_code_until_matching_stop_ack_arrives() { + let first = VarInt::from_u32(11); + let second = VarInt::from_u32(12); + let (mut bridge, mut hypervisor, _lifecycle) = reader(VarInt::from_u32(9)); + + assert!( + poll_fn(|cx| std::pin::Pin::new(&mut bridge).poll_stop(cx, first)) + .now_or_never() + .is_none(), + "first stop poll should commit and wait for ack" + ); + assert_eq!( + expect_command(&mut hypervisor).await, + ReadCommand::Stop { code: first } + ); + + assert!( + poll_fn(|cx| std::pin::Pin::new(&mut bridge).poll_stop(cx, second)) + .now_or_never() + .is_none(), + "second stop poll should continue the first committed code" + ); + assert!( + hypervisor.next_frame().now_or_never().is_none(), + "different stop code must not replace the first committed stop" + ); + + hypervisor + .send(ReadEvent::StopAck { code: first }) + .await + .expect("stop ack should send"); + bridge.stop(second).await.expect("stop should complete"); + } + + #[tokio::test] + async fn poll_stop_after_completed_stop_ack_does_not_send_second_stop() { + let first = VarInt::from_u32(21); + let second = VarInt::from_u32(22); + let (mut bridge, mut hypervisor, _lifecycle) = reader(VarInt::from_u32(10)); + + assert!( + poll_fn(|cx| std::pin::Pin::new(&mut bridge).poll_stop(cx, first)) + .now_or_never() + .is_none() + ); + assert_eq!( + expect_command(&mut hypervisor).await, + ReadCommand::Stop { code: first } + ); + hypervisor + .send(ReadEvent::StopAck { code: first }) + .await + .expect("stop ack should send"); + bridge + .stop(first) + .await + .expect("first stop should complete"); + + bridge + .stop(second) + .await + .expect("completed stop should make later stop a no-op"); + assert!( + hypervisor.next_frame().now_or_never().is_none(), + "completed stop result must prevent a second stop frame" + ); + } + + #[tokio::test] + async fn poll_stop_during_pull_preserves_push_for_later_poll_next() { + let stop_code = VarInt::from_u32(31); + let data = Bytes::from_static(b"received before stop"); + let (mut bridge, mut hypervisor, _lifecycle) = reader(VarInt::from_u32(13)); + + assert!( + poll_fn(|cx| bridge.poll_next_unpin(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(expect_command(&mut hypervisor).await, ReadCommand::Pull); + hypervisor + .send(ReadEvent::Push { data: data.clone() }) + .await + .expect("push event should send"); + + let stop = + poll_fn(|cx| std::pin::Pin::new(&mut bridge).poll_stop(cx, stop_code)).now_or_never(); + assert!( + stop.is_none(), + "stop should wait for stop ack after caching push" + ); + assert_eq!( + expect_command(&mut hypervisor).await, + ReadCommand::Stop { code: stop_code } + ); + hypervisor + .send(ReadEvent::StopAck { code: stop_code }) + .await + .expect("stop ack should send"); + bridge + .stop(stop_code) + .await + .expect("stop should complete after ack"); + + assert_eq!( + bridge + .next() + .await + .expect("cached item should exist") + .expect("cached read should be ok"), + data + ); + assert!( + hypervisor.next_frame().now_or_never().is_none(), + "cached push must be returned without issuing another pull" + ); + } + + #[tokio::test] + async fn poll_stop_during_pull_eos_completes_without_stop_frame() { + let stop_code = VarInt::from_u32(32); + let (mut bridge, mut hypervisor, _lifecycle) = reader(VarInt::from_u32(14)); + + assert!( + poll_fn(|cx| bridge.poll_next_unpin(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(expect_command(&mut hypervisor).await, ReadCommand::Pull); + hypervisor + .send(ReadEvent::Eos) + .await + .expect("eos event should send"); + + bridge + .stop(stop_code) + .await + .expect("eos should satisfy stop"); + match hypervisor.next_frame().now_or_never() { + None | Some(None) => {} + Some(Some(frame)) => panic!("eos stop should not send a frame, got {frame:?}"), + } + assert!(matches!(bridge.next().now_or_never(), Some(None))); + } + + #[tokio::test] + async fn wrong_stop_ack_latches_protocol_connection_error() { + let expected = VarInt::from_u32(41); + let actual = VarInt::from_u32(42); + let (mut bridge, mut hypervisor, lifecycle) = reader(VarInt::from_u32(15)); + + assert!( + poll_fn(|cx| std::pin::Pin::new(&mut bridge).poll_stop(cx, expected)) + .now_or_never() + .is_none() + ); + assert_eq!( + expect_command(&mut hypervisor).await, + ReadCommand::Stop { code: expected } + ); + hypervisor + .send(ReadEvent::StopAck { code: actual }) + .await + .expect("wrong stop ack should send"); + + let error = bridge.stop(expected).await.expect_err("stop should fail"); + let source = stream_connection(error); + assert_eq!( + transport(&source).reason.as_ref(), + "stop ack code 42 does not match committed stop code 41" + ); + assert_eq!( + transport(&source).reason.as_ref(), + quic::Lifecycle::closed(lifecycle.as_ref()) + .await + .transport() + .reason + .as_ref() + ); + } + + #[tokio::test] + async fn err_reset_closes_reader_with_reset_stream_error() { + let reset_code = VarInt::from_u32(51); + let (mut bridge, mut hypervisor, _lifecycle) = reader(VarInt::from_u32(16)); + + assert!( + poll_fn(|cx| bridge.poll_next_unpin(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(expect_command(&mut hypervisor).await, ReadCommand::Pull); + hypervisor + .send(ReadEvent::ErrReset { code: reset_code }) + .await + .expect("reset event should send"); + + match bridge.next().await { + Some(Err(quic::StreamError::Reset { code })) if code == reset_code => {} + result => panic!("expected reset stream error, got {result:?}"), + } + } + + #[tokio::test] + async fn terminal_error_marks_reader_terminated() { + let reset_code = VarInt::from_u32(52); + let (mut bridge, mut hypervisor, _lifecycle) = reader(VarInt::from_u32(19)); + + assert!( + poll_fn(|cx| bridge.poll_next_unpin(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(expect_command(&mut hypervisor).await, ReadCommand::Pull); + hypervisor + .send(ReadEvent::ErrReset { code: reset_code }) + .await + .expect("reset event should send"); + + let Some(Err(quic::StreamError::Reset { code })) = bridge.next().await else { + panic!("err reset should close the reader"); + }; + assert_eq!(code, reset_code); + assert!(bridge.is_terminated()); + } + + #[tokio::test] + async fn pending_terminal_error_does_not_mark_reader_terminated() { + let (mut bridge, mut hypervisor) = pending_closed_reader(VarInt::from_u32(21)); + + assert!( + poll_fn(|cx| bridge.poll_next_unpin(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(expect_command(&mut hypervisor).await, ReadCommand::Pull); + hypervisor + .send(ReadEvent::ErrConn) + .await + .expect("connection error event should send"); + + assert!( + poll_fn(|cx| bridge.poll_next_unpin(cx)) + .now_or_never() + .is_none(), + "deferred connection error should remain pending" + ); + assert!(!bridge.is_terminated()); + } + + #[tokio::test] + async fn stream_id_remains_available_after_terminal_error() { + let stream_id = VarInt::from_u32(20); + let reset_code = VarInt::from_u32(53); + let (mut bridge, mut hypervisor, _lifecycle) = reader(stream_id); + + assert!( + poll_fn(|cx| bridge.poll_next_unpin(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(expect_command(&mut hypervisor).await, ReadCommand::Pull); + hypervisor + .send(ReadEvent::ErrReset { code: reset_code }) + .await + .expect("reset event should send"); + + let Some(Err(quic::StreamError::Reset { code })) = bridge.next().await else { + panic!("err reset should close the reader"); + }; + assert_eq!(code, reset_code); + assert_eq!(bridge.stream_id().await.expect("stream id"), stream_id); + } + + #[tokio::test] + async fn err_conn_deferred_error_returns_lifecycle_canonical_error() { + let canonical = quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(61), + frame_type: VarInt::from_u32(62), + reason: "canonical closed error".into(), + }, + }; + let (mut bridge, mut hypervisor, lifecycle) = reader(VarInt::from_u32(17)); + lifecycle.set_closed_error(canonical.clone()); + + assert!( + poll_fn(|cx| bridge.poll_next_unpin(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(expect_command(&mut hypervisor).await, ReadCommand::Pull); + hypervisor + .send(ReadEvent::ErrConn) + .await + .expect("connection error event should send"); + + let Some(Err(error)) = bridge.next().await else { + panic!("err conn should return a stream error"); + }; + let source = stream_connection(error); + assert_eq!(transport(&source).kind, VarInt::from_u32(61)); + let closed = quic::Lifecycle::closed(lifecycle.as_ref()).await; + assert_eq!(transport(&source).kind, transport(&closed).kind); + assert_eq!(transport(&source).frame_type, transport(&closed).frame_type); + assert_eq!(transport(&source).reason, transport(&closed).reason); + assert_eq!(transport(&source).kind, transport(&canonical).kind); + assert_eq!( + transport(&source).frame_type, + transport(&canonical).frame_type + ); + assert_eq!(transport(&source).reason, transport(&canonical).reason); + } + + #[tokio::test] + async fn frame_io_eof_before_read_eos_latches_connection_frame_eof() { + let (mut bridge, mut hypervisor, lifecycle) = reader(VarInt::from_u32(18)); + + assert!( + poll_fn(|cx| bridge.poll_next_unpin(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(expect_command(&mut hypervisor).await, ReadCommand::Pull); + drop(hypervisor); + + let Some(Err(error)) = bridge.next().await else { + panic!("frame eof should return a stream error"); + }; + let source = stream_connection(error); + assert_eq!( + transport(&source).reason.as_ref(), + "typed frame stream ended before operation completed" + ); + assert_eq!( + transport(&source).reason.as_ref(), + quic::Lifecycle::closed(lifecycle.as_ref()) + .await + .transport() + .reason + .as_ref() + ); + } + + trait TransportExt { + fn transport(&self) -> &quic::TransportError; + } + + impl TransportExt for quic::ConnectionError { + fn transport(&self) -> &quic::TransportError { + transport(self) + } + } +} diff --git a/src/rpc/stream/remoc.rs b/src/rpc/stream/remoc.rs new file mode 100644 index 0000000..fd3e260 --- /dev/null +++ b/src/rpc/stream/remoc.rs @@ -0,0 +1,344 @@ +use std::{ + borrow::Cow, + collections::VecDeque, + marker::PhantomData, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use futures::{ + Sink, Stream, + future::{BoxFuture, FutureExt as _}, + ready, +}; +use remoc::rch::{Sending, SendingError, mpsc}; +use serde::{Deserialize, Serialize}; +use snafu::{ResultExt as _, Snafu}; + +use super::{ + frame::{ReadCommand, ReadEvent, WriteCommand, WriteEvent}, + reader::BridgeStreamReader, + writer::BridgeStreamWriter, +}; +use crate::{quic, rpc::lifecycle::LifecycleExt, varint::VarInt}; + +const CHANNEL_CAPACITY: usize = 8; +const RPC_FRAME_IO_ERROR_KIND: VarInt = VarInt::from_u32(0x0c); +const RPC_FRAME_IO_FRAME_TYPE: VarInt = VarInt::from_u32(0x00); + +pub type ReadOutSender = mpsc::Sender; +pub type ReadInReceiver = mpsc::Receiver; +pub type WriteOutSender = mpsc::Sender; +pub type WriteInReceiver = mpsc::Receiver; + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum RpcFrameIoError { + #[snafu(display("failed to send rpc stream frame"))] + Send { source: mpsc::SendError<()> }, + #[snafu(display("failed to receive rpc stream frame"))] + Receive { source: mpsc::RecvError }, +} + +impl From for quic::ConnectionError { + fn from(error: RpcFrameIoError) -> Self { + quic::ConnectionError::Transport { + source: quic::TransportError { + kind: RPC_FRAME_IO_ERROR_KIND, + frame_type: RPC_FRAME_IO_FRAME_TYPE, + // Lossy: QUIC transport error reason is a protocol string field + // for local RPC frame-channel failures. + reason: Cow::Owned(error.to_string()), + }, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ReadFrameChannels { + stream_id: VarInt, + outbound: ReadOutSender, + inbound: ReadInReceiver, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct WriteFrameChannels { + stream_id: VarInt, + outbound: WriteOutSender, + inbound: WriteInReceiver, +} + +impl ReadFrameChannels { + pub(crate) fn pair(stream_id: VarInt) -> (Self, RpcFrameIo) { + let (command_tx, command_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (event_tx, event_rx) = mpsc::channel(CHANNEL_CAPACITY); + ( + Self { + stream_id, + outbound: command_tx, + inbound: event_rx, + }, + RpcFrameIo::new(event_tx, command_rx), + ) + } + + pub(crate) fn into_quic( + self, + lifecycle: Arc, + ) -> BridgeStreamReader, L, RpcFrameIoError> + where + L: LifecycleExt + 'static, + { + BridgeStreamReader::new( + self.stream_id, + RpcFrameIo::new(self.outbound, self.inbound), + lifecycle, + ) + } +} + +impl WriteFrameChannels { + pub(crate) fn pair(stream_id: VarInt) -> (Self, RpcFrameIo) { + let (command_tx, command_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (event_tx, event_rx) = mpsc::channel(CHANNEL_CAPACITY); + ( + Self { + stream_id, + outbound: command_tx, + inbound: event_rx, + }, + RpcFrameIo::new(event_tx, command_rx), + ) + } + + pub(crate) fn into_quic( + self, + lifecycle: Arc, + ) -> BridgeStreamWriter, L, RpcFrameIoError> + where + L: LifecycleExt + 'static, + { + BridgeStreamWriter::new( + self.stream_id, + RpcFrameIo::new(self.outbound, self.inbound), + lifecycle, + ) + } +} + +type ReserveFuture = BoxFuture<'static, Result, mpsc::SendError<()>>>; + +pub(crate) struct RpcFrameIo { + sender: Option>, + reserve: Option>, + permit: Option>, + sending: VecDeque>, + receiver: mpsc::Receiver, + _in: PhantomData In>, +} + +impl RpcFrameIo { + fn new(sender: mpsc::Sender, receiver: mpsc::Receiver) -> Self + where + Out: Send + 'static, + { + Self { + sender: Some(sender), + reserve: None, + permit: None, + sending: VecDeque::new(), + receiver, + _in: PhantomData, + } + } + + fn reserve(sender: mpsc::Sender) -> ReserveFuture + where + Out: Send + 'static, + { + async move { sender.reserve().await }.boxed() + } +} + +fn sending_error_to_send_error(error: SendingError) -> mpsc::SendError<()> { + match error { + SendingError::Send(source) => mpsc::SendError::RemoteSend(source.kind), + SendingError::Dropped => mpsc::SendError::Closed(()), + } +} + +impl Sink for RpcFrameIo +where + Out: Send + 'static, +{ + type Error = RpcFrameIoError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.permit.is_some() { + return Poll::Ready(Ok(())); + } + + if self.reserve.is_none() { + let Some(sender) = self.sender.clone() else { + return Poll::Ready(Err(RpcFrameIoError::Send { + source: mpsc::SendError::Closed(()), + })); + }; + self.reserve = Some(Self::reserve(sender)); + } + + let reserve = self + .reserve + .as_mut() + .expect("reserve future should be initialized"); + let result = ready!(reserve.as_mut().poll(cx)).context(rpc_frame_io_error::SendSnafu); + self.reserve = None; + match result { + Ok(permit) => { + self.permit = Some(permit); + Poll::Ready(Ok(())) + } + Err(error) => Poll::Ready(Err(error)), + } + } + + fn start_send(mut self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> { + let permit = self + .permit + .take() + .expect("rpc frame io sender is not ready"); + self.sending.push_back(permit.send(item)); + Ok(()) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + while let Some(sending) = self.sending.front_mut() { + let result = ready!(sending.poll_unpin(cx)); + self.sending.pop_front(); + if let Err(error) = result { + return Poll::Ready(Err(RpcFrameIoError::Send { + source: sending_error_to_send_error(error), + })); + } + } + Poll::Ready(Ok(())) + } + + fn poll_close( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + self.sender = None; + self.reserve = None; + self.permit = None; + self.sending.clear(); + Poll::Ready(Ok(())) + } +} + +impl Stream for RpcFrameIo { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match ready!(self.receiver.poll_recv(cx)).context(rpc_frame_io_error::ReceiveSnafu) { + Ok(Some(frame)) => Poll::Ready(Some(Ok(frame))), + Ok(None) => Poll::Ready(None), + Err(error) => Poll::Ready(Some(Err(error))), + } + } +} + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use futures::{SinkExt as _, StreamExt as _, future::poll_fn}; + + use super::*; + use crate::rpc::stream::frame::{ReadCommand, ReadEvent}; + + #[tokio::test] + async fn read_frame_channel_pair_is_bidirectional() { + let stream_id = VarInt::from_u32(701); + let (channels, mut hypervisor) = ReadFrameChannels::pair(stream_id); + let mut worker = RpcFrameIo::new(channels.outbound, channels.inbound); + + let send = worker.send(ReadCommand::Pull); + let receive = hypervisor.next(); + let (send, received) = tokio::join!(send, receive); + send.unwrap(); + assert_eq!(received.unwrap().unwrap(), ReadCommand::Pull); + + let send = hypervisor.send(ReadEvent::Push { + data: Bytes::from_static(b"rpc frame"), + }); + let receive = worker.next(); + let (send, received) = tokio::join!(send, receive); + send.unwrap(); + assert_eq!( + received.unwrap().unwrap(), + ReadEvent::Push { + data: Bytes::from_static(b"rpc frame"), + } + ); + } + + #[tokio::test] + async fn flush_waits_until_remoc_send_is_observed() { + let (outbound, mut outbound_rx): (ReadOutSender, mpsc::Receiver) = + mpsc::channel(CHANNEL_CAPACITY); + let (_inbound_tx, inbound): (mpsc::Sender, ReadInReceiver) = + mpsc::channel(CHANNEL_CAPACITY); + let mut io = RpcFrameIo::new(outbound, inbound); + + let send = io.send(ReadCommand::Pull); + tokio::pin!(send); + assert!( + send.as_mut().now_or_never().is_none(), + "flush must wait for remoc to finish sending the queued frame" + ); + + assert_eq!(outbound_rx.recv().await.unwrap(), Some(ReadCommand::Pull)); + send.await.unwrap(); + } + + #[tokio::test] + async fn flush_waits_for_all_started_remoc_sends() { + let (outbound, mut outbound_rx): (ReadOutSender, mpsc::Receiver) = + mpsc::channel(CHANNEL_CAPACITY); + let (_inbound_tx, inbound): (mpsc::Sender, ReadInReceiver) = + mpsc::channel(CHANNEL_CAPACITY); + let mut io = RpcFrameIo::new(outbound, inbound); + let stop = VarInt::from_u32(9); + + poll_fn(|cx| Pin::new(&mut io).poll_ready(cx)) + .await + .unwrap(); + Pin::new(&mut io).start_send(ReadCommand::Pull).unwrap(); + poll_fn(|cx| Pin::new(&mut io).poll_ready(cx)) + .await + .unwrap(); + Pin::new(&mut io) + .start_send(ReadCommand::Stop { code: stop }) + .unwrap(); + + let flush = poll_fn(|cx| Pin::new(&mut io).poll_flush(cx)); + tokio::pin!(flush); + assert!( + flush.as_mut().now_or_never().is_none(), + "flush must wait for the first started send" + ); + + assert_eq!(outbound_rx.recv().await.unwrap(), Some(ReadCommand::Pull)); + assert!( + flush.as_mut().now_or_never().is_none(), + "flush must wait for later started sends too" + ); + + assert_eq!( + outbound_rx.recv().await.unwrap(), + Some(ReadCommand::Stop { code: stop }) + ); + flush.await.unwrap(); + } +} diff --git a/src/rpc/stream/test_io.rs b/src/rpc/stream/test_io.rs new file mode 100644 index 0000000..bd62adf --- /dev/null +++ b/src/rpc/stream/test_io.rs @@ -0,0 +1,378 @@ +use std::{ + borrow::Cow, + marker::PhantomData, + pin::Pin, + sync::Mutex, + task::{Context, Poll}, +}; + +use futures::{Sink, SinkExt as _, Stream, StreamExt as _, channel::mpsc}; +use snafu::Snafu; + +use super::frame::{ + HyperReadIn, HyperReadOut, HyperWriteIn, HyperWriteOut, WorkerReadIn, WorkerReadOut, + WorkerWriteIn, WorkerWriteOut, +}; +use crate::{ + error::Code, + quic, + rpc::lifecycle::{ConnectionErrorLatch, HasLatch, LifecycleExt}, + varint::VarInt, +}; + +const MEMORY_FRAME_IO_CAPACITY: usize = 8; +const TEST_FRAME_IO_FRAME_TYPE: VarInt = VarInt::from_u32(0x3f02); + +pub(crate) struct TestLifecycle { + latch: ConnectionErrorLatch, + closed_error: Mutex>, +} + +impl TestLifecycle { + pub(crate) fn new() -> Self { + Self { + latch: ConnectionErrorLatch::new(), + closed_error: Mutex::new(None), + } + } + + pub(crate) fn set_closed_error(&self, error: quic::ConnectionError) { + *self.closed_error.lock().unwrap() = Some(error); + } +} + +impl HasLatch for TestLifecycle { + fn latch(&self) -> &ConnectionErrorLatch { + &self.latch + } +} + +impl quic::Lifecycle for TestLifecycle { + fn close(&self, _code: Code, _reason: Cow<'static, str>) {} + + fn check(&self) -> Result<(), quic::ConnectionError> { + self.check_with_probe(|| None) + } + + async fn closed(&self) -> quic::ConnectionError { + self.resolve_closed(async { + self.closed_error + .lock() + .unwrap() + .take() + .unwrap_or_else(|| quic::ConnectionError::from(TestFrameIoError::new(0xfe))) + }) + .await + } +} + +#[derive(Debug, Snafu, Clone, PartialEq, Eq)] +#[snafu(display("test frame io error {kind}"))] +pub(crate) struct TestFrameIoError { + kind: VarInt, +} + +impl TestFrameIoError { + pub(crate) const fn new(kind: u32) -> Self { + Self { + kind: VarInt::from_u32(kind), + } + } + + pub(crate) const fn kind(&self) -> VarInt { + self.kind + } +} + +impl From for quic::ConnectionError { + fn from(error: TestFrameIoError) -> Self { + quic::ConnectionError::Transport { + source: quic::TransportError { + kind: error.kind(), + frame_type: TEST_FRAME_IO_FRAME_TYPE, + reason: "test frame io error".into(), + }, + } + } +} + +#[derive(Debug, Snafu, Clone)] +#[snafu(display("memory frame io peer closed"))] +pub(crate) struct MemoryFrameIoClosed; + +impl From for TestFrameIoError { + fn from(_error: MemoryFrameIoClosed) -> Self { + Self::new(0xff) + } +} + +pub(crate) struct MemoryFrameIo { + outgoing: mpsc::Sender>, + incoming: mpsc::Receiver>, + _error: PhantomData, +} + +impl MemoryFrameIo { + pub(crate) async fn next_frame(&mut self) -> Option> { + self.incoming.next().await + } +} + +impl Unpin for MemoryFrameIo {} + +impl Sink for MemoryFrameIo +where + E: From, +{ + type Error = E; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + match this.outgoing.poll_ready_unpin(cx) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(())), + Poll::Ready(Err(_error)) => Poll::Ready(Err(MemoryFrameIoClosed.into())), + Poll::Pending => Poll::Pending, + } + } + + fn start_send(self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> { + let this = self.get_mut(); + match this.outgoing.start_send_unpin(Ok(item)) { + Ok(()) => Ok(()), + Err(_error) => Err(MemoryFrameIoClosed.into()), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + match this.outgoing.poll_flush_unpin(cx) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(())), + Poll::Ready(Err(_error)) => Poll::Ready(Err(MemoryFrameIoClosed.into())), + Poll::Pending => Poll::Pending, + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + match this.outgoing.poll_close_unpin(cx) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(())), + Poll::Ready(Err(_error)) => Poll::Ready(Err(MemoryFrameIoClosed.into())), + Poll::Pending => Poll::Pending, + } + } +} + +impl Stream for MemoryFrameIo { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + this.incoming.poll_next_unpin(cx) + } +} + +fn memory_frame_io_pair() -> (MemoryFrameIo, MemoryFrameIo) +{ + let (a_outgoing, b_incoming) = mpsc::channel(MEMORY_FRAME_IO_CAPACITY); + let (b_outgoing, a_incoming) = mpsc::channel(MEMORY_FRAME_IO_CAPACITY); + + ( + MemoryFrameIo { + outgoing: a_outgoing, + incoming: a_incoming, + _error: PhantomData, + }, + MemoryFrameIo { + outgoing: b_outgoing, + incoming: b_incoming, + _error: PhantomData, + }, + ) +} + +pub(crate) fn worker_reader_pair() -> ( + MemoryFrameIo, + MemoryFrameIo, +) { + memory_frame_io_pair() +} + +pub(crate) fn worker_writer_pair() -> ( + MemoryFrameIo, + MemoryFrameIo, +) { + memory_frame_io_pair() +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + use crate::rpc::stream::{ + error::{ + DeferredStreamError, DriverProtocolError, defer_err_conn, latch_frame_io_error, + latch_protocol_error, + }, + io::FrameIo, + }; + + fn transport(error: &quic::ConnectionError) -> &quic::TransportError { + let quic::ConnectionError::Transport { source } = error else { + panic!("connection error should be transport-scoped"); + }; + source + } + + fn transport_kind(error: &quic::ConnectionError) -> VarInt { + transport(error).kind + } + + fn transport_reason(error: &quic::ConnectionError) -> &str { + transport(error).reason.as_ref() + } + + fn stream_connection_kind(error: quic::StreamError) -> VarInt { + let quic::StreamError::Connection { source } = error else { + panic!("stream error should be connection-scoped"); + }; + transport_kind(&source) + } + + fn assert_frame_io(_io: &IO) + where + IO: FrameIo, + { + } + + #[tokio::test] + async fn frame_io_error_is_latched_first_wins() { + let lifecycle = TestLifecycle::new(); + + let first = latch_frame_io_error(&lifecycle, TestFrameIoError::new(11)); + let second = latch_frame_io_error(&lifecycle, TestFrameIoError::new(12)); + + assert_eq!(stream_connection_kind(first), VarInt::from_u32(11)); + assert_eq!(stream_connection_kind(second), VarInt::from_u32(11)); + assert_eq!( + transport_kind(&lifecycle.latch().check().unwrap_err()), + VarInt::from_u32(11) + ); + } + + #[tokio::test] + async fn protocol_error_is_latched_first_wins() { + let lifecycle = TestLifecycle::new(); + + let first = latch_protocol_error(&lifecycle, DriverProtocolError::UnexpectedReaderFrame); + let second = latch_protocol_error(&lifecycle, DriverProtocolError::UnexpectedWriterAck); + + assert_eq!( + stream_connection_kind(first), + stream_connection_kind(second) + ); + let latched = lifecycle.latch().check().unwrap_err(); + assert_eq!( + transport_reason(&latched), + "received frame that is invalid for the current reader operation" + ); + assert_eq!( + transport_reason(&latched), + transport_reason(&quic::Lifecycle::closed(&lifecycle).await) + ); + } + + #[tokio::test] + async fn defer_err_conn_uses_immediate_check_when_already_closed() { + let lifecycle = Arc::new(TestLifecycle::new()); + lifecycle + .latch() + .latch_with(|| quic::ConnectionError::from(TestFrameIoError::new(21))); + + let DeferredStreamError::Ready { error } = defer_err_conn(lifecycle) else { + panic!("already closed lifecycle should produce immediate stream error"); + }; + + assert_eq!(stream_connection_kind(error), VarInt::from_u32(21)); + } + + #[tokio::test] + async fn defer_err_conn_can_wait_for_closed_when_not_latched() { + let lifecycle = Arc::new(TestLifecycle::new()); + lifecycle.set_closed_error(quic::ConnectionError::from(TestFrameIoError::new(31))); + + let DeferredStreamError::Pending { future } = defer_err_conn(lifecycle) else { + panic!("clean lifecycle should defer until closed"); + }; + + assert_eq!(stream_connection_kind(future.await), VarInt::from_u32(31)); + } + + #[tokio::test] + async fn worker_reader_io_eof_is_latched_as_connection_error_not_read_eos() { + let lifecycle = TestLifecycle::new(); + let (mut worker, hypervisor) = worker_reader_pair::(); + assert_frame_io::<_, WorkerReadOut, WorkerReadIn, TestFrameIoError>(&worker); + drop(hypervisor); + + let eof = match worker.next_frame().await { + None => latch_protocol_error(&lifecycle, DriverProtocolError::FrameEof), + Some(Ok(WorkerReadIn::Eos)) => { + panic!("worker bridge IO EOF must not synthesize read EOS") + } + Some(frame) => panic!("worker bridge IO EOF produced unexpected frame {frame:?}"), + }; + let quic::StreamError::Connection { source } = eof else { + panic!("worker frame EOF should be a connection error"); + }; + + assert_eq!( + transport_reason(&source), + "typed frame stream ended before operation completed" + ); + assert_eq!( + transport_reason(&source), + transport_reason(&quic::Lifecycle::closed(&lifecycle).await) + ); + } + + #[tokio::test] + async fn worker_writer_io_eof_is_latched_as_connection_error_not_write_close() { + let lifecycle = TestLifecycle::new(); + let (mut worker, hypervisor) = worker_writer_pair::(); + assert_frame_io::<_, WorkerWriteOut, WorkerWriteIn, TestFrameIoError>(&worker); + drop(hypervisor); + + let eof = match worker.next_frame().await { + None => latch_protocol_error(&lifecycle, DriverProtocolError::FrameEof), + Some(Ok(WorkerWriteIn::EosAck)) => { + panic!("worker bridge IO EOF must not synthesize write close ack") + } + Some(frame) => panic!("worker bridge IO EOF produced unexpected frame {frame:?}"), + }; + let quic::StreamError::Connection { source } = eof else { + panic!("worker frame EOF should be a connection error"); + }; + + assert_eq!( + transport_reason(&source), + "typed frame stream ended before operation completed" + ); + assert_eq!( + transport_reason(&source), + transport_reason(&quic::Lifecycle::closed(&lifecycle).await) + ); + } + + #[tokio::test] + async fn memory_worker_writer_pair_transfers_frames() { + let (mut worker, mut hypervisor) = worker_writer_pair::(); + + worker + .send(WorkerWriteOut::Flush) + .await + .expect("send should succeed"); + + assert_eq!(hypervisor.next_frame().await, Some(Ok(HyperWriteIn::Flush))); + } +} diff --git a/src/rpc/stream/writer.rs b/src/rpc/stream/writer.rs new file mode 100644 index 0000000..a66f53d --- /dev/null +++ b/src/rpc/stream/writer.rs @@ -0,0 +1,2118 @@ +use std::{ + marker::PhantomData, + pin::Pin, + sync::Arc, + task::{Context, Poll, ready}, +}; + +use bytes::Bytes; +use futures::Sink; + +use super::{ + drain, + error::{ + DeferredStreamError, DriverProtocolError, defer_err_conn, latch_frame_io_error, + latch_protocol_error, + }, + frame::{WriteCommand, WriteEvent}, +}; +use crate::{quic, varint::VarInt}; + +const WRITE_AFTER_EOS_PANIC: &str = "h3x write data after shutdown"; + +#[derive(Debug, Clone, PartialEq, Eq)] +enum WriteOperation { + Push { data: Bytes }, + Flush { generation: u64 }, + Eos { covers_flush: bool }, + Reset { code: VarInt }, +} + +#[derive(Debug)] +enum SendState { + Idle { ready: bool }, + Flush { inflight: Op }, +} + +impl Default for SendState { + fn default() -> Self { + Self::Idle { ready: false } + } +} + +#[derive(Debug, Default)] +struct PendingWrite { + push: Option, + control: Option, + reset: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum WriteBarrier { + Flush { generation: u64 }, + Eos { covers_flush: bool }, +} + +#[derive(Debug, Default)] +enum WriterRecvState { + #[default] + Idle, + Await { + sent: WriterAwaited, + }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum WriterAwaited { + Flush { generation: u64 }, + Eos { covers_flush: bool }, + Reset { code: VarInt }, +} + +#[derive(Debug, Default)] +struct WriterResults { + credit: bool, + flush: Option, + eos: Option, + reset: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum FlushCompletion { + FlushAck { generation: u64 }, + CoveredByEos, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct EosAckResult; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct ResetAckResult { + code: VarInt, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CommitResult { + Committed, + Duplicate, + Conflict, +} + +enum WriterFault { + Stream(quic::StreamError), + Deferred(DeferredStreamError), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum DriveTerminal { + None, + Eos, + Reset { code: VarInt }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ResetDrive { + Acked, + CoveredByEos, +} + +pin_project_lite::pin_project! { + #[project = BridgeStreamWriterProj] + #[project_replace = BridgeStreamWriterReplace] + pub(crate) enum BridgeStreamWriter + where + L: drain::DrainLifecycle, + Io: drain::WriteDrainIo, + E: 'static, + quic::ConnectionError: From, + { + Active { + active: Pin>>, + }, + Eos { + stream_id: VarInt, + }, + Reset { + stream_id: VarInt, + code: VarInt, + }, + Closed { + stream_id: VarInt, + error: DeferredStreamError, + }, + } + + impl PinnedDrop for BridgeStreamWriter + where + L: drain::DrainLifecycle, + Io: drain::WriteDrainIo, + E: 'static, + quic::ConnectionError: From, + { + fn drop(mut this: Pin<&mut Self>) { + let stream_id = match this.as_mut().project() { + BridgeStreamWriterProj::Active { active } => active.as_ref().get_ref().stream_id, + BridgeStreamWriterProj::Eos { .. } + | BridgeStreamWriterProj::Reset { .. } + | BridgeStreamWriterProj::Closed { .. } => return, + }; + + match this.project_replace(BridgeStreamWriter::Eos { stream_id }) { + BridgeStreamWriterReplace::Active { active } + if active.as_ref().get_ref().has_committed_outbound() => + { + drain::spawn_write_drain(active); + } + _ => {} + } + } + } +} + +pin_project_lite::pin_project! { + pub(crate) struct ActiveBridgeStreamWriter { + stream_id: VarInt, + lifecycle: Arc, + #[pin] + bridge: Io, + pending: PendingWrite, + send_state: SendState, + recv_state: WriterRecvState, + results: WriterResults, + write_generation: u64, + _error: PhantomData, + } +} + +impl BridgeStreamWriter +where + L: drain::DrainLifecycle, + Io: drain::WriteDrainIo, + E: 'static, + quic::ConnectionError: From, +{ + pub(crate) fn new(stream_id: VarInt, bridge: Io, lifecycle: Arc) -> Self { + Self::Active { + active: Box::pin(ActiveBridgeStreamWriter { + stream_id, + lifecycle, + bridge, + pending: PendingWrite::default(), + send_state: SendState::default(), + recv_state: WriterRecvState::default(), + results: WriterResults::default(), + write_generation: 0, + _error: PhantomData, + }), + } + } +} + +impl ActiveBridgeStreamWriter { + pub(super) fn has_committed_outbound(&self) -> bool { + self.pending.push.is_some() + || self.pending.control.is_some() + || self.pending.reset.is_some() + || matches!(self.send_state, SendState::Flush { .. }) + } + + fn committed_reset_code(&self) -> Option { + if let Some(code) = self.pending.reset { + return Some(code); + } + if let SendState::Flush { + inflight: WriteOperation::Reset { code }, + } = &self.send_state + { + return Some(*code); + } + if let WriterRecvState::Await { + sent: WriterAwaited::Reset { code }, + } = self.recv_state + { + return Some(code); + } + self.results.reset.as_ref().map(|result| result.code) + } + + fn eos_committed(&self) -> bool { + matches!(self.pending.control, Some(WriteBarrier::Eos { .. })) + || matches!( + self.send_state, + SendState::Flush { + inflight: WriteOperation::Eos { .. } + } + ) + || matches!( + self.recv_state, + WriterRecvState::Await { + sent: WriterAwaited::Eos { .. } + } + ) + || self.results.eos.is_some() + } + + fn eos_started_or_completed(&self) -> bool { + matches!( + self.send_state, + SendState::Flush { + inflight: WriteOperation::Eos { .. } + } + ) || matches!( + self.recv_state, + WriterRecvState::Await { + sent: WriterAwaited::Eos { .. } + } + ) || self.results.eos.is_some() + } + + fn has_unflushed_control_barrier(&self) -> bool { + self.pending.control.is_some() + || matches!( + self.send_state, + SendState::Flush { + inflight: WriteOperation::Flush { .. } + | WriteOperation::Eos { .. } + | WriteOperation::Reset { .. } + } + ) + } + + fn mark_eos_covers_flush_pinned(mut self: Pin<&mut Self>) { + let mut this = self.as_mut().project(); + if let Some(WriteBarrier::Eos { covers_flush }) = &mut this.pending.control { + *covers_flush = true; + } + if let SendState::Flush { + inflight: WriteOperation::Eos { covers_flush }, + } = &mut this.send_state + { + *covers_flush = true; + } + if let WriterRecvState::Await { + sent: WriterAwaited::Eos { covers_flush }, + } = &mut this.recv_state + { + *covers_flush = true; + } + if this.results.eos.is_some() { + this.results.flush = Some(FlushCompletion::CoveredByEos); + } + } + + fn flush_committed_for(&self, generation: u64) -> bool { + matches!(self.pending.control, Some(WriteBarrier::Flush { generation: committed }) if committed >= generation) + || matches!( + self.send_state, + SendState::Flush { + inflight: WriteOperation::Flush { + generation: committed + } + } + if committed >= generation + ) + || matches!( + self.recv_state, + WriterRecvState::Await { + sent: WriterAwaited::Flush { + generation: committed + } + } + if committed >= generation + ) + || self.flush_result_covers(generation) + } + + fn flush_result_covers(&self, generation: u64) -> bool { + match self.results.flush { + Some(FlushCompletion::FlushAck { + generation: committed, + }) => committed >= generation, + Some(FlushCompletion::CoveredByEos) => true, + None => false, + } + } + + fn reset_matches_result(&self, code: VarInt) -> bool { + matches!(self.results.reset, Some(ResetAckResult { code: actual }) if actual == code) + } + + fn take_flush_result_pinned( + mut self: Pin<&mut Self>, + generation: u64, + ) -> Option { + let this = self.as_mut().project(); + if match this.results.flush { + Some(FlushCompletion::FlushAck { + generation: committed, + }) => committed >= generation, + Some(FlushCompletion::CoveredByEos) => true, + None => false, + } { + this.results.flush.take() + } else { + None + } + } + + fn take_eos_result_pinned(mut self: Pin<&mut Self>) -> Option { + self.as_mut().project().results.eos.take() + } + + fn take_reset_result_pinned(mut self: Pin<&mut Self>, code: VarInt) -> Option { + let this = self.as_mut().project(); + match this.results.reset { + Some(result) if result.code == code => this.results.reset.take(), + _ => None, + } + } + + fn commit_flush_pinned(mut self: Pin<&mut Self>) -> (CommitResult, u64) { + let generation = self.as_ref().get_ref().write_generation; + if self.as_ref().get_ref().committed_reset_code().is_some() { + return (CommitResult::Duplicate, generation); + } + if self.as_ref().get_ref().eos_committed() { + self.as_mut().mark_eos_covers_flush_pinned(); + return (CommitResult::Duplicate, generation); + } + if self.as_ref().get_ref().flush_committed_for(generation) { + return (CommitResult::Duplicate, generation); + } + self.as_mut().project().pending.control = Some(WriteBarrier::Flush { generation }); + (CommitResult::Committed, generation) + } + + fn commit_eos_pinned(mut self: Pin<&mut Self>) -> CommitResult { + if self.as_ref().get_ref().committed_reset_code().is_some() { + return CommitResult::Duplicate; + } + if self.as_ref().get_ref().eos_committed() { + return CommitResult::Duplicate; + } + + let this = self.as_mut().project(); + this.results.credit = false; + match this.pending.control { + Some(WriteBarrier::Flush { .. }) => { + this.pending.control = Some(WriteBarrier::Eos { covers_flush: true }); + } + Some(WriteBarrier::Eos { .. }) => return CommitResult::Duplicate, + None => { + this.pending.control = Some(WriteBarrier::Eos { + covers_flush: false, + }); + } + } + CommitResult::Committed + } + + fn commit_reset_pinned(mut self: Pin<&mut Self>, code: VarInt) -> CommitResult { + if self.as_ref().get_ref().eos_started_or_completed() { + return CommitResult::Duplicate; + } + match self.as_ref().get_ref().committed_reset_code() { + Some(committed) if committed == code => CommitResult::Duplicate, + Some(_committed) => CommitResult::Conflict, + None => { + let this = self.as_mut().project(); + this.pending.push = None; + this.pending.control = None; + this.pending.reset = Some(code); + this.results.credit = false; + CommitResult::Committed + } + } + } +} + +impl ActiveBridgeStreamWriter +where + L: drain::DrainLifecycle, + Io: drain::WriteDrainIo, + E: 'static, + quic::ConnectionError: From, +{ + fn protocol_fault_for(lifecycle: &Arc, error: DriverProtocolError) -> WriterFault { + WriterFault::Stream(latch_protocol_error(lifecycle.as_ref(), error)) + } + + fn frame_io_fault_for(lifecycle: &Arc, error: E) -> WriterFault { + WriterFault::Stream(latch_frame_io_error(lifecycle.as_ref(), error)) + } + + fn commit_push_pinned(mut self: Pin<&mut Self>, data: Bytes) -> Result<(), quic::StreamError> { + if self.as_ref().get_ref().eos_committed() { + panic!("{WRITE_AFTER_EOS_PANIC}"); + } + if let Some(code) = self.as_ref().get_ref().committed_reset_code() { + return Err(quic::StreamError::Reset { code }); + } + + let this = self.as_mut().project(); + if !this.results.credit || this.pending.push.is_some() { + return Err(latch_protocol_error( + this.lifecycle.as_ref(), + DriverProtocolError::StartSendWithoutCredit, + )); + } + this.results.credit = false; + *this.write_generation += 1; + this.pending.push = Some(data); + Ok(()) + } + + fn poll_bridge_ready( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let mut this = self.as_mut().project(); + match &mut *this.send_state { + SendState::Idle { ready: true } => Poll::Ready(Ok(())), + SendState::Idle { ready } => match ready!(this.bridge.as_mut().poll_ready(cx)) { + Ok(()) => { + *ready = true; + Poll::Ready(Ok(())) + } + Err(error) => Poll::Ready(Err(Self::frame_io_fault_for(this.lifecycle, error))), + }, + SendState::Flush { .. } => unreachable!("cannot ready writer sink while flushing"), + } + } + + fn start_operation_pinned( + mut self: Pin<&mut Self>, + operation: WriteOperation, + ) -> Result<(), WriterFault> { + let command = match &operation { + WriteOperation::Push { data } => WriteCommand::Push { data: data.clone() }, + WriteOperation::Flush { .. } => WriteCommand::Flush, + WriteOperation::Eos { .. } => WriteCommand::Eos, + WriteOperation::Reset { code } => WriteCommand::Reset { code: *code }, + }; + + let mut this = self.as_mut().project(); + match this.bridge.as_mut().start_send(command) { + Ok(()) => { + *this.send_state = SendState::Flush { + inflight: operation, + }; + Ok(()) + } + Err(error) => Err(Self::frame_io_fault_for(this.lifecycle, error)), + } + } + + fn poll_drain_ready( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let mut this = self.as_mut().project(); + match &mut *this.send_state { + SendState::Idle { ready: true } => Poll::Ready(Ok(())), + SendState::Idle { ready } => match ready!(this.bridge.as_mut().poll_ready(cx)) { + Ok(()) => { + *ready = true; + Poll::Ready(Ok(())) + } + Err(error) => { + Poll::Ready(Err(latch_frame_io_error(this.lifecycle.as_ref(), error))) + } + }, + SendState::Flush { .. } => unreachable!("writer drain readiness while flushing"), + } + } + + fn start_drain_operation( + mut self: Pin<&mut Self>, + operation: WriteOperation, + ) -> Result<(), quic::StreamError> { + let command = match &operation { + WriteOperation::Push { data } => WriteCommand::Push { data: data.clone() }, + WriteOperation::Flush { .. } => WriteCommand::Flush, + WriteOperation::Eos { .. } => WriteCommand::Eos, + WriteOperation::Reset { code } => WriteCommand::Reset { code: *code }, + }; + + let mut this = self.as_mut().project(); + match this.bridge.as_mut().start_send(command) { + Ok(()) => { + *this.send_state = SendState::Flush { + inflight: operation, + }; + Ok(()) + } + Err(error) => Err(latch_frame_io_error(this.lifecycle.as_ref(), error)), + } + } + + pub(super) fn poll_drain(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + loop { + if matches!(self.as_ref().get_ref().send_state, SendState::Flush { .. }) { + let mut this = self.as_mut().project(); + match ready!(this.bridge.as_mut().poll_flush(cx)) { + Ok(()) => { + let sent = match std::mem::replace( + this.send_state, + SendState::Idle { ready: false }, + ) { + SendState::Flush { inflight } => inflight, + SendState::Idle { .. } => unreachable!("flush state checked above"), + }; + match sent { + WriteOperation::Push { .. } => continue, + WriteOperation::Flush { .. } + | WriteOperation::Eos { .. } + | WriteOperation::Reset { .. } => return Poll::Ready(()), + } + } + Err(error) => { + let error = latch_frame_io_error(this.lifecycle.as_ref(), error); + drain::log_drain_error(&error, "writer"); + return Poll::Ready(()); + } + } + } + + if matches!( + self.as_ref().get_ref().recv_state, + WriterRecvState::Await { .. } + ) { + return Poll::Ready(()); + } + + if self.as_ref().get_ref().pending.push.is_some() { + if let Err(error) = ready!(self.as_mut().poll_drain_ready(cx)) { + drain::log_drain_error(&error, "writer"); + return Poll::Ready(()); + } + let data = self + .as_mut() + .project() + .pending + .push + .take() + .expect("pending push should exist"); + if let Err(error) = self + .as_mut() + .start_drain_operation(WriteOperation::Push { data }) + { + drain::log_drain_error(&error, "writer"); + return Poll::Ready(()); + } + continue; + } + + if let Some(code) = self.as_ref().get_ref().pending.reset { + if let Err(error) = ready!(self.as_mut().poll_drain_ready(cx)) { + drain::log_drain_error(&error, "writer"); + return Poll::Ready(()); + } + self.as_mut().project().pending.reset = None; + if let Err(error) = self + .as_mut() + .start_drain_operation(WriteOperation::Reset { code }) + { + drain::log_drain_error(&error, "writer"); + return Poll::Ready(()); + } + continue; + } + + if let Some(control) = self.as_ref().get_ref().pending.control { + if let Err(error) = ready!(self.as_mut().poll_drain_ready(cx)) { + drain::log_drain_error(&error, "writer"); + return Poll::Ready(()); + } + self.as_mut().project().pending.control = None; + let operation = match control { + WriteBarrier::Flush { generation } => WriteOperation::Flush { generation }, + WriteBarrier::Eos { covers_flush } => WriteOperation::Eos { covers_flush }, + }; + if let Err(error) = self.as_mut().start_drain_operation(operation) { + drain::log_drain_error(&error, "writer"); + return Poll::Ready(()); + } + continue; + } + + return Poll::Ready(()); + } + } + + fn poll_send_progress( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + if matches!(self.as_ref().get_ref().send_state, SendState::Flush { .. }) { + let mut this = self.as_mut().project(); + match ready!(this.bridge.as_mut().poll_flush(cx)) { + Ok(()) => { + let sent = match std::mem::replace( + this.send_state, + SendState::Idle { ready: false }, + ) { + SendState::Flush { inflight } => inflight, + SendState::Idle { .. } => unreachable!("flush state checked above"), + }; + match sent { + WriteOperation::Push { .. } => continue, + WriteOperation::Flush { generation } => { + *this.recv_state = WriterRecvState::Await { + sent: WriterAwaited::Flush { generation }, + }; + return Poll::Ready(Ok(())); + } + WriteOperation::Eos { covers_flush } => { + *this.recv_state = WriterRecvState::Await { + sent: WriterAwaited::Eos { covers_flush }, + }; + return Poll::Ready(Ok(())); + } + WriteOperation::Reset { code } => { + *this.recv_state = WriterRecvState::Await { + sent: WriterAwaited::Reset { code }, + }; + return Poll::Ready(Ok(())); + } + } + } + Err(error) => { + return Poll::Ready(Err(Self::frame_io_fault_for(this.lifecycle, error))); + } + } + } + + if self.as_ref().get_ref().pending.push.is_some() + && !matches!( + self.as_ref().get_ref().recv_state, + WriterRecvState::Await { + sent: WriterAwaited::Eos { .. } | WriterAwaited::Reset { .. } + } + ) + { + ready!(self.as_mut().poll_bridge_ready(cx))?; + let data = self + .as_mut() + .project() + .pending + .push + .take() + .expect("pending push should exist"); + self.as_mut() + .start_operation_pinned(WriteOperation::Push { data })?; + continue; + } + + if matches!( + self.as_ref().get_ref().recv_state, + WriterRecvState::Await { .. } + ) { + return Poll::Ready(Ok(())); + } + + if let Some(code) = self.as_ref().get_ref().pending.reset { + ready!(self.as_mut().poll_bridge_ready(cx))?; + self.as_mut().project().pending.reset = None; + self.as_mut() + .start_operation_pinned(WriteOperation::Reset { code })?; + continue; + } + + if let Some(control) = self.as_ref().get_ref().pending.control { + ready!(self.as_mut().poll_bridge_ready(cx))?; + self.as_mut().project().pending.control = None; + let operation = match control { + WriteBarrier::Flush { generation } => WriteOperation::Flush { generation }, + WriteBarrier::Eos { covers_flush } => WriteOperation::Eos { covers_flush }, + }; + self.as_mut().start_operation_pinned(operation)?; + continue; + } + + return Poll::Ready(Ok(())); + } + } + + fn pair_awaited_response( + lifecycle: &Arc, + results: &mut WriterResults, + awaited: WriterAwaited, + frame: WriteEvent, + ) -> Result<(), WriterFault> { + match (awaited, frame) { + (WriterAwaited::Flush { generation }, WriteEvent::FlushAck) => { + results.flush = Some(FlushCompletion::FlushAck { generation }); + Ok(()) + } + (WriterAwaited::Eos { covers_flush }, WriteEvent::EosAck) => { + results.eos = Some(EosAckResult); + if covers_flush { + results.flush = Some(FlushCompletion::CoveredByEos); + } + Ok(()) + } + (WriterAwaited::Reset { code }, WriteEvent::ResetAck { code: actual }) + if code == actual => + { + results.reset = Some(ResetAckResult { code }); + Ok(()) + } + (WriterAwaited::Reset { code }, WriteEvent::ResetAck { code: actual }) => { + Err(Self::protocol_fault_for( + lifecycle, + DriverProtocolError::ResetAckCodeMismatch { + expected: code, + actual, + }, + )) + } + (_awaited, WriteEvent::ErrReset { code }) => { + Err(WriterFault::Stream(quic::StreamError::Reset { code })) + } + (_awaited, WriteEvent::ErrConn) => { + Err(WriterFault::Deferred(defer_err_conn(lifecycle.clone()))) + } + (_awaited, _frame) => Err(Self::protocol_fault_for( + lifecycle, + DriverProtocolError::UnexpectedWriterAck, + )), + } + } + + fn poll_recv_progress( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let terminal_committed = self.as_ref().get_ref().eos_committed() + || self.as_ref().get_ref().committed_reset_code().is_some(); + let mut this = self.as_mut().project(); + match ready!(this.bridge.as_mut().poll_next(cx)) { + Some(Ok(WriteEvent::Pull)) if terminal_committed => Poll::Ready(Ok(())), + Some(Ok(WriteEvent::Pull)) if this.results.credit => { + Poll::Ready(Err(Self::protocol_fault_for( + this.lifecycle, + DriverProtocolError::DuplicateWriterCredit, + ))) + } + Some(Ok(WriteEvent::Pull)) => { + this.results.credit = true; + Poll::Ready(Ok(())) + } + Some(Ok( + frame @ (WriteEvent::FlushAck | WriteEvent::EosAck | WriteEvent::ResetAck { .. }), + )) => { + let awaited = match std::mem::replace(this.recv_state, WriterRecvState::Idle) { + WriterRecvState::Await { sent } => sent, + WriterRecvState::Idle => { + return Poll::Ready(Err(Self::protocol_fault_for( + this.lifecycle, + DriverProtocolError::UnexpectedWriterAck, + ))); + } + }; + Poll::Ready(Self::pair_awaited_response( + this.lifecycle, + this.results, + awaited, + frame, + )) + } + Some(Ok(WriteEvent::ErrReset { code })) => { + Poll::Ready(Err(WriterFault::Stream(quic::StreamError::Reset { code }))) + } + Some(Ok(WriteEvent::ErrConn)) => Poll::Ready(Err(WriterFault::Deferred( + defer_err_conn(this.lifecycle.clone()), + ))), + Some(Err(error)) => Poll::Ready(Err(Self::frame_io_fault_for(this.lifecycle, error))), + None => Poll::Ready(Err(Self::protocol_fault_for( + this.lifecycle, + DriverProtocolError::FrameEof, + ))), + } + } + + fn drive_credit( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + if self.as_ref().get_ref().eos_committed() { + panic!("{WRITE_AFTER_EOS_PANIC}"); + } + if let Some(code) = self.as_ref().get_ref().committed_reset_code() { + return match ready!(self.as_mut().drive_reset(cx, code))? { + ResetDrive::Acked => Poll::Ready(Ok(DriveTerminal::Reset { code })), + ResetDrive::CoveredByEos => Poll::Ready(Ok(DriveTerminal::Eos)), + }; + } + if self.as_ref().get_ref().has_unflushed_control_barrier() { + ready!(self.as_mut().poll_send_progress(cx))?; + continue; + } + if self.as_ref().get_ref().results.credit { + return Poll::Ready(Ok(DriveTerminal::None)); + } + ready!(self.as_mut().poll_send_progress(cx))?; + if self.as_ref().get_ref().results.credit { + return Poll::Ready(Ok(DriveTerminal::None)); + } + ready!(self.as_mut().poll_recv_progress(cx))?; + } + } + + fn drive_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + generation: u64, + ) -> Poll> { + loop { + if let Some(code) = self.as_ref().get_ref().committed_reset_code() { + return match ready!(self.as_mut().drive_reset(cx, code))? { + ResetDrive::Acked => Poll::Ready(Ok(DriveTerminal::Reset { code })), + ResetDrive::CoveredByEos => Poll::Ready(Ok(DriveTerminal::Eos)), + }; + } + if self.as_ref().get_ref().results.eos.is_some() { + return Poll::Ready(Ok(DriveTerminal::Eos)); + } + if self.as_ref().get_ref().flush_result_covers(generation) { + return Poll::Ready(Ok(DriveTerminal::None)); + } + ready!(self.as_mut().poll_send_progress(cx))?; + if self.as_ref().get_ref().results.eos.is_some() { + return Poll::Ready(Ok(DriveTerminal::Eos)); + } + if self.as_ref().get_ref().flush_result_covers(generation) { + return Poll::Ready(Ok(DriveTerminal::None)); + } + ready!(self.as_mut().poll_recv_progress(cx))?; + } + } + + fn drive_eos( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + if let Some(code) = self.as_ref().get_ref().committed_reset_code() { + return match ready!(self.as_mut().drive_reset(cx, code))? { + ResetDrive::Acked => Poll::Ready(Ok(DriveTerminal::Reset { code })), + ResetDrive::CoveredByEos => Poll::Ready(Ok(DriveTerminal::Eos)), + }; + } + if self.as_ref().get_ref().results.eos.is_some() { + return Poll::Ready(Ok(DriveTerminal::Eos)); + } + ready!(self.as_mut().poll_send_progress(cx))?; + if self.as_ref().get_ref().results.eos.is_some() { + return Poll::Ready(Ok(DriveTerminal::Eos)); + } + ready!(self.as_mut().poll_recv_progress(cx))?; + } + } + + fn drive_started_eos( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + if self.as_ref().get_ref().results.eos.is_some() { + return Poll::Ready(Ok(())); + } + ready!(self.as_mut().poll_send_progress(cx))?; + if self.as_ref().get_ref().results.eos.is_some() { + return Poll::Ready(Ok(())); + } + ready!(self.as_mut().poll_recv_progress(cx))?; + } + } + + fn drive_reset( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + committed_code: VarInt, + ) -> Poll> { + loop { + if self.as_ref().get_ref().eos_started_or_completed() + && !matches!( + self.as_ref().get_ref().send_state, + SendState::Flush { + inflight: WriteOperation::Reset { .. } + } + ) + && !matches!( + self.as_ref().get_ref().recv_state, + WriterRecvState::Await { + sent: WriterAwaited::Reset { .. } + } + ) + && self.as_ref().get_ref().results.reset.is_none() + { + ready!(self.as_mut().drive_started_eos(cx))?; + return Poll::Ready(Ok(ResetDrive::CoveredByEos)); + } + if self.as_ref().get_ref().reset_matches_result(committed_code) { + return Poll::Ready(Ok(ResetDrive::Acked)); + } + ready!(self.as_mut().poll_send_progress(cx))?; + if self.as_ref().get_ref().reset_matches_result(committed_code) { + return Poll::Ready(Ok(ResetDrive::Acked)); + } + ready!(self.as_mut().poll_recv_progress(cx))?; + } + } +} + +fn poll_deferred_error( + error: &mut DeferredStreamError, + cx: &mut Context<'_>, +) -> Poll { + match error { + DeferredStreamError::Ready { error } => Poll::Ready(error.clone()), + DeferredStreamError::Pending { future } => { + let stream_error = ready!(future.as_mut().poll(cx)); + *error = DeferredStreamError::Ready { + error: stream_error.clone(), + }; + Poll::Ready(stream_error) + } + } +} + +fn ready_error(error: quic::StreamError) -> DeferredStreamError { + DeferredStreamError::Ready { error } +} + +impl BridgeStreamWriter +where + L: drain::DrainLifecycle, + Io: drain::WriteDrainIo, + E: 'static, + quic::ConnectionError: From, +{ + fn stream_id(self: Pin<&mut Self>) -> VarInt { + match self.project() { + BridgeStreamWriterProj::Active { active } => active.as_ref().get_ref().stream_id, + BridgeStreamWriterProj::Eos { stream_id } => *stream_id, + BridgeStreamWriterProj::Reset { stream_id, .. } => *stream_id, + BridgeStreamWriterProj::Closed { stream_id, .. } => *stream_id, + } + } + + fn close_with_stream_error( + mut self: Pin<&mut Self>, + error: quic::StreamError, + ) -> quic::StreamError { + let stream_id = self.as_mut().stream_id(); + self.as_mut().project_replace(Self::Closed { + stream_id, + error: ready_error(error.clone()), + }); + error + } + + fn close_with_fault( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + fault: WriterFault, + ) -> Poll { + match fault { + WriterFault::Stream(error) => Poll::Ready(self.close_with_stream_error(error)), + WriterFault::Deferred(error) => { + let stream_id = self.as_mut().stream_id(); + self.as_mut() + .project_replace(Self::Closed { stream_id, error }); + match self.project() { + BridgeStreamWriterProj::Closed { error, .. } => poll_deferred_error(error, cx), + _ => unreachable!("writer should have transitioned to closed"), + } + } + } + } + + fn poll_closed_error(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project() { + BridgeStreamWriterProj::Closed { error, .. } => poll_deferred_error(error, cx), + _ => unreachable!("poll_closed_error called outside closed state"), + } + } + + fn transition_eos(mut self: Pin<&mut Self>) { + let stream_id = self.as_mut().stream_id(); + self.as_mut().project_replace(Self::Eos { stream_id }); + } + + fn transition_reset(mut self: Pin<&mut Self>, code: VarInt) { + let stream_id = self.as_mut().stream_id(); + self.as_mut() + .project_replace(Self::Reset { stream_id, code }); + } +} + +impl Sink for BridgeStreamWriter +where + L: drain::DrainLifecycle, + Io: drain::WriteDrainIo, + E: 'static, + quic::ConnectionError: From, +{ + type Error = quic::StreamError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.as_mut().project() { + BridgeStreamWriterProj::Active { active } => { + match ready!(active.as_mut().drive_credit(cx)) { + Ok(DriveTerminal::None) => Poll::Ready(Ok(())), + Ok(DriveTerminal::Eos) => { + self.as_mut().transition_eos(); + panic!("{WRITE_AFTER_EOS_PANIC}"); + } + Ok(DriveTerminal::Reset { code }) => { + self.as_mut().transition_reset(code); + Poll::Ready(Err(quic::StreamError::Reset { code })) + } + Err(fault) => { + let error = ready!(self.as_mut().close_with_fault(cx, fault)); + Poll::Ready(Err(error)) + } + } + } + BridgeStreamWriterProj::Eos { .. } => panic!("{WRITE_AFTER_EOS_PANIC}"), + BridgeStreamWriterProj::Reset { code, .. } => { + Poll::Ready(Err(quic::StreamError::Reset { code: *code })) + } + BridgeStreamWriterProj::Closed { .. } => { + let error = ready!(self.as_mut().poll_closed_error(cx)); + Poll::Ready(Err(error)) + } + } + } + + fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + match self.as_mut().project() { + BridgeStreamWriterProj::Active { active } => { + match active.as_mut().commit_push_pinned(item) { + Ok(()) => Ok(()), + Err(error @ quic::StreamError::Connection { .. }) => { + self.as_mut().close_with_stream_error(error.clone()); + Err(error) + } + Err(error) => Err(error), + } + } + BridgeStreamWriterProj::Eos { .. } => panic!("{WRITE_AFTER_EOS_PANIC}"), + BridgeStreamWriterProj::Reset { code, .. } => { + Err(quic::StreamError::Reset { code: *code }) + } + BridgeStreamWriterProj::Closed { error, .. } => match error { + DeferredStreamError::Ready { error } => Err(error.clone()), + DeferredStreamError::Pending { .. } => { + panic!("start_send called while writer closed with pending error") + } + }, + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.as_mut().project() { + BridgeStreamWriterProj::Active { active } => { + let (_commit, generation) = active.as_mut().commit_flush_pinned(); + match ready!(active.as_mut().drive_flush(cx, generation)) { + Ok(DriveTerminal::None) => { + active.as_mut().take_flush_result_pinned(generation); + Poll::Ready(Ok(())) + } + Ok(DriveTerminal::Eos) => { + self.as_mut().transition_eos(); + Poll::Ready(Ok(())) + } + Ok(DriveTerminal::Reset { code }) => { + self.as_mut().transition_reset(code); + Poll::Ready(Err(quic::StreamError::Reset { code })) + } + Err(fault) => { + let error = ready!(self.as_mut().close_with_fault(cx, fault)); + Poll::Ready(Err(error)) + } + } + } + BridgeStreamWriterProj::Eos { .. } => Poll::Ready(Ok(())), + BridgeStreamWriterProj::Reset { code, .. } => { + Poll::Ready(Err(quic::StreamError::Reset { code: *code })) + } + BridgeStreamWriterProj::Closed { .. } => { + let error = ready!(self.as_mut().poll_closed_error(cx)); + Poll::Ready(Err(error)) + } + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.as_mut().project() { + BridgeStreamWriterProj::Active { active } => { + let _commit = active.as_mut().commit_eos_pinned(); + match ready!(active.as_mut().drive_eos(cx)) { + Ok(DriveTerminal::Eos | DriveTerminal::None) => { + active.as_mut().take_eos_result_pinned(); + self.as_mut().transition_eos(); + Poll::Ready(Ok(())) + } + Ok(DriveTerminal::Reset { code }) => { + self.as_mut().transition_reset(code); + Poll::Ready(Err(quic::StreamError::Reset { code })) + } + Err(fault) => { + let error = ready!(self.as_mut().close_with_fault(cx, fault)); + Poll::Ready(Err(error)) + } + } + } + BridgeStreamWriterProj::Eos { .. } => Poll::Ready(Ok(())), + BridgeStreamWriterProj::Reset { code, .. } => { + Poll::Ready(Err(quic::StreamError::Reset { code: *code })) + } + BridgeStreamWriterProj::Closed { .. } => { + let error = ready!(self.as_mut().poll_closed_error(cx)); + Poll::Ready(Err(error)) + } + } + } +} + +impl quic::GetStreamId for BridgeStreamWriter +where + L: drain::DrainLifecycle, + Io: drain::WriteDrainIo, + E: 'static, + quic::ConnectionError: From, +{ + fn poll_stream_id( + self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + match self.project() { + BridgeStreamWriterProj::Active { active } => { + Poll::Ready(Ok(active.as_ref().get_ref().stream_id)) + } + BridgeStreamWriterProj::Eos { stream_id } => Poll::Ready(Ok(*stream_id)), + BridgeStreamWriterProj::Reset { code, .. } => { + Poll::Ready(Err(quic::StreamError::Reset { code: *code })) + } + BridgeStreamWriterProj::Closed { error, .. } => { + Poll::Ready(Err(ready!(poll_deferred_error(error, cx)))) + } + } + } +} + +impl quic::ResetStream for BridgeStreamWriter +where + L: drain::DrainLifecycle, + Io: drain::WriteDrainIo, + E: 'static, + quic::ConnectionError: From, +{ + fn poll_reset( + mut self: Pin<&mut Self>, + cx: &mut Context, + code: VarInt, + ) -> Poll> { + match self.as_mut().project() { + BridgeStreamWriterProj::Active { active } => { + let commit = active.as_mut().commit_reset_pinned(code); + let committed_code = match commit { + CommitResult::Committed | CommitResult::Duplicate => { + active.as_ref().committed_reset_code().unwrap_or(code) + } + CommitResult::Conflict => active + .as_ref() + .committed_reset_code() + .expect("conflicting reset must have committed code"), + }; + + match ready!(active.as_mut().drive_reset(cx, committed_code)) { + Ok(ResetDrive::Acked) => { + active.as_mut().take_reset_result_pinned(committed_code); + self.as_mut().transition_reset(committed_code); + Poll::Ready(Ok(())) + } + Ok(ResetDrive::CoveredByEos) => { + active.as_mut().take_eos_result_pinned(); + self.as_mut().transition_eos(); + Poll::Ready(Ok(())) + } + Err(fault) => { + let error = ready!(self.as_mut().close_with_fault(cx, fault)); + Poll::Ready(Err(error)) + } + } + } + BridgeStreamWriterProj::Eos { .. } => Poll::Ready(Ok(())), + BridgeStreamWriterProj::Reset { code, .. } => { + Poll::Ready(Err(quic::StreamError::Reset { code: *code })) + } + BridgeStreamWriterProj::Closed { .. } => { + let error = ready!(self.as_mut().poll_closed_error(cx)); + Poll::Ready(Err(error)) + } + } + } +} + +#[cfg(test)] +mod tests { + use std::{borrow::Cow, future::pending, marker::PhantomData, panic, pin::Pin, sync::Arc}; + + use bytes::Bytes; + use futures::{FutureExt as _, Sink, SinkExt as _, future::poll_fn}; + + use super::{ + ActiveBridgeStreamWriter, BridgeStreamWriter, PendingWrite, SendState, WriteBarrier, + WriterRecvState, WriterResults, + }; + use crate::{ + quic::{self, GetStreamIdExt as _, ResetStream, ResetStreamExt as _}, + rpc::{ + lifecycle::{ConnectionErrorLatch, HasLatch, LifecycleExt}, + stream::{ + frame::{WriteCommand, WriteEvent}, + test_io::{MemoryFrameIo, TestFrameIoError, TestLifecycle, worker_writer_pair}, + }, + }, + varint::VarInt, + }; + + type WorkerWriterIo = MemoryFrameIo; + type HypervisorWriterIo = MemoryFrameIo; + type TestBridgeStreamWriter = + BridgeStreamWriter; + + struct PendingClosedLifecycle { + latch: ConnectionErrorLatch, + } + + impl PendingClosedLifecycle { + fn new() -> Self { + Self { + latch: ConnectionErrorLatch::new(), + } + } + } + + impl HasLatch for PendingClosedLifecycle { + fn latch(&self) -> &ConnectionErrorLatch { + &self.latch + } + } + + impl quic::Lifecycle for PendingClosedLifecycle { + fn close(&self, _code: crate::error::Code, _reason: Cow<'static, str>) {} + + fn check(&self) -> Result<(), quic::ConnectionError> { + self.check_with_probe(|| None) + } + + async fn closed(&self) -> quic::ConnectionError { + self.resolve_closed(pending()).await + } + } + + fn writer( + stream_id: VarInt, + ) -> ( + TestBridgeStreamWriter, + HypervisorWriterIo, + Arc, + ) { + let lifecycle = Arc::new(TestLifecycle::new()); + let (worker, hypervisor) = worker_writer_pair::(); + ( + BridgeStreamWriter::new(stream_id, worker, lifecycle.clone()), + hypervisor, + lifecycle, + ) + } + + fn pending_closed_writer( + stream_id: VarInt, + ) -> ( + BridgeStreamWriter, + HypervisorWriterIo, + ) { + let lifecycle = Arc::new(PendingClosedLifecycle::new()); + let (worker, hypervisor) = worker_writer_pair::(); + ( + BridgeStreamWriter::new(stream_id, worker, lifecycle), + hypervisor, + ) + } + + fn writer_with_pending_flush_and_credit( + stream_id: VarInt, + ) -> (TestBridgeStreamWriter, HypervisorWriterIo) { + let lifecycle = Arc::new(TestLifecycle::new()); + let (worker, hypervisor) = worker_writer_pair::(); + ( + BridgeStreamWriter::Active { + active: Box::pin(ActiveBridgeStreamWriter { + stream_id, + lifecycle, + bridge: worker, + pending: PendingWrite { + push: None, + control: Some(WriteBarrier::Flush { generation: 0 }), + reset: None, + }, + send_state: SendState::default(), + recv_state: WriterRecvState::default(), + results: WriterResults { + credit: true, + flush: None, + eos: None, + reset: None, + }, + write_generation: 0, + _error: PhantomData, + }), + }, + hypervisor, + ) + } + + fn transport(error: &quic::ConnectionError) -> &quic::TransportError { + let quic::ConnectionError::Transport { source } = error else { + panic!("connection error should be transport-scoped"); + }; + source + } + + fn stream_connection(error: quic::StreamError) -> quic::ConnectionError { + let quic::StreamError::Connection { source } = error else { + panic!("stream error should be connection-scoped"); + }; + source + } + + fn assert_reset(error: quic::StreamError, expected: VarInt) { + let quic::StreamError::Reset { code } = error else { + panic!("stream error should be reset-scoped"); + }; + assert_eq!(code, expected); + } + + async fn expect_command(hypervisor: &mut HypervisorWriterIo) -> WriteCommand { + hypervisor + .next_frame() + .await + .expect("bridge should send a frame") + .expect("command frame should be readable") + } + + fn expect_command_now(hypervisor: &mut HypervisorWriterIo) -> WriteCommand { + match hypervisor.next_frame().now_or_never() { + Some(Some(Ok(command))) => command, + Some(Some(Err(_error))) => panic!("command frame should be readable"), + Some(None) => panic!("bridge command stream should remain open"), + None => panic!("expected command frame to be ready"), + } + } + + async fn grant_credit( + writer: &mut TestBridgeStreamWriter, + hypervisor: &mut HypervisorWriterIo, + ) { + hypervisor + .send(WriteEvent::Pull) + .await + .expect("pull credit should send"); + poll_fn(|cx| Pin::new(&mut *writer).poll_ready(cx)) + .await + .expect("writer should become ready"); + } + + fn assert_no_command_now(hypervisor: &mut HypervisorWriterIo, context: &str) { + match hypervisor.next_frame().now_or_never() { + None | Some(None) => {} + Some(Some(frame)) => panic!("{context}: unexpected frame {frame:?}"), + } + } + + fn panic_text(payload: Box) -> String { + match payload.downcast::() { + Ok(message) => *message, + Err(payload) => match payload.downcast::<&'static str>() { + Ok(message) => (*message).to_owned(), + Err(_payload) => panic!("panic payload should be a string"), + }, + } + } + + #[tokio::test] + async fn poll_ready_waits_for_pull_without_consuming_start_send_credit() { + let (mut bridge, mut hypervisor, _lifecycle) = writer(VarInt::from_u32(101)); + let data = Bytes::from_static(b"ready payload"); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_ready(cx)) + .now_or_never() + .is_none(), + "poll_ready should wait for inbound pull credit" + ); + assert_no_command_now(&mut hypervisor, "poll_ready must not send writer commands"); + + hypervisor + .send(WriteEvent::Pull) + .await + .expect("pull credit should send"); + poll_fn(|cx| Pin::new(&mut bridge).poll_ready(cx)) + .await + .expect("pull should make the sink ready"); + + Pin::new(&mut bridge) + .start_send(data) + .expect("start_send should consume the cached credit"); + } + + #[tokio::test] + async fn start_send_consumes_credit_and_commits_push_without_polling_typed_io() { + let (mut bridge, mut hypervisor, _lifecycle) = writer(VarInt::from_u32(102)); + let data = Bytes::from_static(b"push payload"); + + grant_credit(&mut bridge, &mut hypervisor).await; + Pin::new(&mut bridge) + .start_send(data) + .expect("start_send should commit push"); + + assert_no_command_now( + &mut hypervisor, + "start_send must not poll or write typed frame IO", + ); + } + + #[tokio::test] + async fn cached_credit_waits_for_unsent_committed_flush_before_new_push() { + let data = Bytes::from_static(b"after committed flush"); + let (mut bridge, mut hypervisor) = + writer_with_pending_flush_and_credit(VarInt::from_u32(132)); + + poll_fn(|cx| Pin::new(&mut bridge).poll_ready(cx)) + .await + .expect("cached credit should become ready after sending committed flush"); + assert_eq!(expect_command_now(&mut hypervisor), WriteCommand::Flush); + + Pin::new(&mut bridge) + .start_send(data.clone()) + .expect("cached credit should allow start_send after flush is sent"); + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_flush(cx)) + .now_or_never() + .is_none(), + "flush should wait for the earlier FlushAck" + ); + assert_eq!( + expect_command(&mut hypervisor).await, + WriteCommand::Push { data } + ); + hypervisor + .send(WriteEvent::FlushAck) + .await + .expect("old flush ack should send"); + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_flush(cx)) + .now_or_never() + .is_none(), + "old FlushAck must not cover data committed after that Flush" + ); + assert_eq!(expect_command(&mut hypervisor).await, WriteCommand::Flush); + hypervisor + .send(WriteEvent::FlushAck) + .await + .expect("new flush ack should send"); + bridge + .flush() + .await + .expect("new FlushAck should cover later push"); + } + + #[tokio::test] + async fn start_send_without_credit_closes_writer_with_protocol_error() { + let (mut bridge, mut hypervisor, lifecycle) = writer(VarInt::from_u32(133)); + + let error = Pin::new(&mut bridge) + .start_send(Bytes::from_static(b"without credit")) + .expect_err("start_send without credit should fail"); + let source = stream_connection(error); + assert_eq!( + transport(&source).reason.as_ref(), + "start_send called without writer credit" + ); + assert_eq!( + transport(&source).reason.as_ref(), + quic::Lifecycle::closed(lifecycle.as_ref()) + .await + .transport() + .reason + .as_ref() + ); + + let Some(Err(error)) = poll_fn(|cx| Pin::new(&mut bridge).poll_flush(cx)).now_or_never() + else { + panic!("closed writer should return the latched protocol error immediately"); + }; + let source = stream_connection(error); + assert_eq!( + transport(&source).reason.as_ref(), + "start_send called without writer credit" + ); + assert_no_command_now( + &mut hypervisor, + "closed writer must not send frames after start_send protocol error", + ); + } + + #[tokio::test] + async fn poll_flush_writes_push_then_flush_and_completes_only_on_flush_ack() { + let (mut bridge, mut hypervisor, _lifecycle) = writer(VarInt::from_u32(103)); + let data = Bytes::from_static(b"flush payload"); + + grant_credit(&mut bridge, &mut hypervisor).await; + Pin::new(&mut bridge) + .start_send(data.clone()) + .expect("start_send should commit push"); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_flush(cx)) + .now_or_never() + .is_none(), + "flush should wait for FlushAck" + ); + assert_eq!( + expect_command(&mut hypervisor).await, + WriteCommand::Push { data } + ); + assert_eq!(expect_command(&mut hypervisor).await, WriteCommand::Flush); + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_flush(cx)) + .now_or_never() + .is_none(), + "flush must remain pending before FlushAck" + ); + + hypervisor + .send(WriteEvent::FlushAck) + .await + .expect("flush ack should send"); + bridge + .flush() + .await + .expect("flush ack should complete flush"); + } + + #[tokio::test] + async fn poll_close_writes_eos_and_completes_only_on_eos_ack() { + let (mut bridge, mut hypervisor, _lifecycle) = writer(VarInt::from_u32(104)); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_close(cx)) + .now_or_never() + .is_none(), + "close should wait for EosAck" + ); + assert_eq!(expect_command(&mut hypervisor).await, WriteCommand::Eos); + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_close(cx)) + .now_or_never() + .is_none(), + "close must remain pending before EosAck" + ); + + hypervisor + .send(WriteEvent::EosAck) + .await + .expect("eos ack should send"); + bridge.close().await.expect("eos ack should complete close"); + } + + #[tokio::test] + async fn poll_close_waits_behind_committed_flush_ack_before_sending_eos() { + let (mut bridge, mut hypervisor, _lifecycle) = writer(VarInt::from_u32(105)); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_flush(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(expect_command(&mut hypervisor).await, WriteCommand::Flush); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_close(cx)) + .now_or_never() + .is_none(), + "close should wait behind the in-flight flush ack" + ); + assert_no_command_now( + &mut hypervisor, + "single-awaiting writer must not send Eos before FlushAck", + ); + + hypervisor + .send(WriteEvent::FlushAck) + .await + .expect("flush ack should send"); + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_close(cx)) + .now_or_never() + .is_none(), + "close should send Eos after FlushAck and then wait for EosAck" + ); + assert_eq!(expect_command(&mut hypervisor).await, WriteCommand::Eos); + hypervisor + .send(WriteEvent::EosAck) + .await + .expect("eos ack should send"); + bridge.close().await.expect("close should complete"); + } + + #[tokio::test] + async fn poll_flush_after_awaiting_eos_sends_no_flush_and_completes_from_eos_ack() { + let (mut bridge, mut hypervisor, _lifecycle) = writer(VarInt::from_u32(106)); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_close(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(expect_command(&mut hypervisor).await, WriteCommand::Eos); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_flush(cx)) + .now_or_never() + .is_none(), + "flush should wait for the already-awaiting EosAck" + ); + assert_no_command_now(&mut hypervisor, "eos-covered flush must not send Flush"); + + hypervisor + .send(WriteEvent::EosAck) + .await + .expect("eos ack should send"); + bridge + .flush() + .await + .expect("flush should complete from covering EosAck"); + } + + #[tokio::test] + async fn reset_keeps_first_code_and_second_reset_completes_on_first_ack() { + let first = VarInt::from_u32(107); + let second = VarInt::from_u32(108); + let (mut bridge, mut hypervisor, _lifecycle) = writer(VarInt::from_u32(109)); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_reset(cx, first)) + .now_or_never() + .is_none() + ); + assert_eq!( + expect_command(&mut hypervisor).await, + WriteCommand::Reset { code: first } + ); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_reset(cx, second)) + .now_or_never() + .is_none(), + "second reset should continue first committed reset" + ); + assert_no_command_now( + &mut hypervisor, + "different reset code must not replace first committed reset", + ); + + hypervisor + .send(WriteEvent::ResetAck { code: first }) + .await + .expect("reset ack should send"); + bridge + .reset(second) + .await + .expect("second reset future should complete from first ack"); + } + + #[tokio::test] + async fn wrong_reset_ack_latches_protocol_connection_error() { + let expected = VarInt::from_u32(111); + let actual = VarInt::from_u32(112); + let (mut bridge, mut hypervisor, lifecycle) = writer(VarInt::from_u32(113)); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_reset(cx, expected)) + .now_or_never() + .is_none() + ); + assert_eq!( + expect_command(&mut hypervisor).await, + WriteCommand::Reset { code: expected } + ); + hypervisor + .send(WriteEvent::ResetAck { code: actual }) + .await + .expect("wrong reset ack should send"); + + let error = bridge.reset(expected).await.expect_err("reset should fail"); + let source = stream_connection(error); + assert_eq!( + transport(&source).reason.as_ref(), + "reset ack code 112 does not match committed reset code 111" + ); + assert_eq!( + transport(&source).reason.as_ref(), + quic::Lifecycle::closed(lifecycle.as_ref()) + .await + .transport() + .reason + .as_ref() + ); + } + + #[tokio::test] + async fn reset_clears_not_started_pending_work_but_preserves_awaiting_flush_order() { + let reset = VarInt::from_u32(114); + let data = Bytes::from_static(b"discard me"); + let (mut bridge, mut hypervisor, _lifecycle) = writer(VarInt::from_u32(115)); + + grant_credit(&mut bridge, &mut hypervisor).await; + Pin::new(&mut bridge) + .start_send(data) + .expect("start_send should commit pending push"); + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_reset(cx, reset)) + .now_or_never() + .is_none(), + "reset should wait for reset ack" + ); + assert_eq!( + expect_command(&mut hypervisor).await, + WriteCommand::Reset { code: reset } + ); + assert_no_command_now(&mut hypervisor, "pending push should be cleared by reset"); + hypervisor + .send(WriteEvent::ResetAck { code: reset }) + .await + .expect("reset ack should send"); + bridge.reset(reset).await.expect("reset should complete"); + + let reset = VarInt::from_u32(116); + let (mut bridge, mut hypervisor, _lifecycle) = writer(VarInt::from_u32(117)); + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_flush(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(expect_command(&mut hypervisor).await, WriteCommand::Flush); + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_close(cx)) + .now_or_never() + .is_none(), + "close should be pending behind flush" + ); + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_reset(cx, reset)) + .now_or_never() + .is_none(), + "reset must preserve the already awaiting flush before sending reset" + ); + assert_no_command_now( + &mut hypervisor, + "reset must not remove an already awaiting flush", + ); + hypervisor + .send(WriteEvent::FlushAck) + .await + .expect("flush ack should send"); + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_reset(cx, reset)) + .now_or_never() + .is_none() + ); + assert_eq!( + expect_command(&mut hypervisor).await, + WriteCommand::Reset { code: reset } + ); + assert_no_command_now(&mut hypervisor, "pending eos should be cleared by reset"); + } + + #[tokio::test] + async fn eos_ack_transitions_to_eos_and_later_flush_close_are_immediate_no_frames() { + let (mut bridge, mut hypervisor, _lifecycle) = writer(VarInt::from_u32(118)); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_close(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(expect_command(&mut hypervisor).await, WriteCommand::Eos); + hypervisor + .send(WriteEvent::EosAck) + .await + .expect("eos ack should send"); + bridge.close().await.expect("close should complete"); + + bridge.flush().await.expect("flush after eos is a no-op"); + bridge.close().await.expect("close after eos is a no-op"); + assert_no_command_now(&mut hypervisor, "eos terminal operations send no frames"); + } + + #[tokio::test] + async fn poll_ready_and_start_send_panic_after_eos_is_committed() { + let (mut bridge, mut hypervisor, _lifecycle) = writer(VarInt::from_u32(119)); + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_close(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(expect_command(&mut hypervisor).await, WriteCommand::Eos); + + let ready_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + let _ = poll_fn(|cx| Pin::new(&mut bridge).poll_ready(cx)).now_or_never(); + })) + .expect_err("poll_ready after eos commit should panic"); + assert_eq!(panic_text(ready_panic), "h3x write data after shutdown"); + + let (mut bridge, mut hypervisor, _lifecycle) = writer(VarInt::from_u32(120)); + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_close(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(expect_command(&mut hypervisor).await, WriteCommand::Eos); + let start_send_panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + Pin::new(&mut bridge) + .start_send(Bytes::from_static(b"illegal")) + .expect("start_send should panic before returning"); + })) + .expect_err("start_send after eos commit should panic"); + assert_eq!( + panic_text(start_send_panic), + "h3x write data after shutdown" + ); + } + + #[tokio::test] + async fn reset_ack_transitions_to_reset_and_later_paths_observe_reset_error() { + let stream_id = VarInt::from_u32(121); + let reset = VarInt::from_u32(122); + let (mut bridge, mut hypervisor, _lifecycle) = writer(stream_id); + + grant_credit(&mut bridge, &mut hypervisor).await; + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_reset(cx, reset)) + .now_or_never() + .is_none() + ); + assert_eq!( + expect_command(&mut hypervisor).await, + WriteCommand::Reset { code: reset } + ); + hypervisor + .send(WriteEvent::Pull) + .await + .expect("post-reset cached pull should be suppressed"); + hypervisor + .send(WriteEvent::ResetAck { code: reset }) + .await + .expect("reset ack should send"); + bridge.reset(reset).await.expect("reset should complete"); + + assert_reset( + poll_fn(|cx| Pin::new(&mut bridge).poll_ready(cx)) + .await + .expect_err("poll_ready after reset should fail"), + reset, + ); + assert_reset( + Pin::new(&mut bridge) + .start_send(Bytes::from_static(b"after reset")) + .expect_err("start_send after reset should fail"), + reset, + ); + assert_reset( + bridge + .flush() + .await + .expect_err("flush after reset should fail"), + reset, + ); + assert_reset( + bridge + .close() + .await + .expect_err("close after reset should fail"), + reset, + ); + assert_reset( + bridge + .reset(reset) + .await + .expect_err("reset after reset should fail"), + reset, + ); + assert_reset( + bridge + .stream_id() + .await + .expect_err("stream id after reset should fail"), + reset, + ); + } + + #[tokio::test] + async fn reset_after_already_started_eos_waits_for_eos_ack_and_sends_no_reset() { + let reset = VarInt::from_u32(123); + let (mut bridge, mut hypervisor, _lifecycle) = writer(VarInt::from_u32(124)); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_close(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(expect_command(&mut hypervisor).await, WriteCommand::Eos); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_reset(cx, reset)) + .now_or_never() + .is_none(), + "reset should wait for the already-started eos" + ); + assert_no_command_now(&mut hypervisor, "eos-covered reset must not send Reset"); + + hypervisor + .send(WriteEvent::EosAck) + .await + .expect("eos ack should send"); + bridge + .reset(reset) + .await + .expect("reset covered by eos should complete"); + assert_no_command_now(&mut hypervisor, "covered reset sends no frame"); + } + + #[tokio::test] + async fn err_reset_and_err_conn_close_writer_correctly() { + let reset = VarInt::from_u32(125); + let (mut bridge, mut hypervisor, _lifecycle) = writer(VarInt::from_u32(126)); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_ready(cx)) + .now_or_never() + .is_none() + ); + hypervisor + .send(WriteEvent::ErrReset { code: reset }) + .await + .expect("reset event should send"); + assert_reset( + poll_fn(|cx| Pin::new(&mut bridge).poll_ready(cx)) + .await + .expect_err("err reset should fail readiness"), + reset, + ); + assert_reset( + bridge + .flush() + .await + .expect_err("closed reset should be cached"), + reset, + ); + + let canonical = quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(127), + frame_type: VarInt::from_u32(128), + reason: "canonical writer close".into(), + }, + }; + let (mut bridge, mut hypervisor, lifecycle) = writer(VarInt::from_u32(129)); + lifecycle.set_closed_error(canonical.clone()); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_ready(cx)) + .now_or_never() + .is_none() + ); + hypervisor + .send(WriteEvent::ErrConn) + .await + .expect("connection error event should send"); + let error = poll_fn(|cx| Pin::new(&mut bridge).poll_ready(cx)) + .await + .expect_err("err conn should fail readiness"); + let source = stream_connection(error); + assert_eq!(transport(&source).kind, transport(&canonical).kind); + assert_eq!( + transport(&source).frame_type, + transport(&canonical).frame_type + ); + assert_eq!(transport(&source).reason, transport(&canonical).reason); + + let (mut bridge, mut hypervisor) = pending_closed_writer(VarInt::from_u32(130)); + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_ready(cx)) + .now_or_never() + .is_none() + ); + hypervisor + .send(WriteEvent::ErrConn) + .await + .expect("connection error event should send"); + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_ready(cx)) + .now_or_never() + .is_none(), + "deferred ErrConn should remain pending while lifecycle close is pending" + ); + } + + #[tokio::test] + async fn frame_io_eof_before_terminal_ack_latches_connection_frame_eof() { + let (mut bridge, mut hypervisor, lifecycle) = writer(VarInt::from_u32(131)); + + assert!( + poll_fn(|cx| Pin::new(&mut bridge).poll_close(cx)) + .now_or_never() + .is_none() + ); + assert_eq!(expect_command(&mut hypervisor).await, WriteCommand::Eos); + drop(hypervisor); + + let error = bridge + .close() + .await + .expect_err("frame eof should fail close"); + let source = stream_connection(error); + assert_eq!( + transport(&source).reason.as_ref(), + "typed frame stream ended before operation completed" + ); + assert_eq!( + transport(&source).reason.as_ref(), + quic::Lifecycle::closed(lifecycle.as_ref()) + .await + .transport() + .reason + .as_ref() + ); + } + + trait TransportExt { + fn transport(&self) -> &quic::TransportError; + } + + impl TransportExt for quic::ConnectionError { + fn transport(&self) -> &quic::TransportError { + transport(self) + } + } +} diff --git a/src/rpc/webtransport.rs b/src/rpc/webtransport.rs index d85d2f4..5452d4e 100644 --- a/src/rpc/webtransport.rs +++ b/src/rpc/webtransport.rs @@ -1,15 +1,142 @@ //! RPC forwarding for the WebTransport session layer. //! //! Re-exports the RTC trait, generated client/server types, and the -//! [`RemoteWtSession`] convenience wrapper. +//! [`RemoteWebTransportSession`] convenience wrapper. + +use std::future::Future; + +use snafu::ResultExt; mod session; pub use self::session::{ - RemoteWtSession, WtSessionClient, WtSessionReqReceiver, WtSessionServer, WtSessionServerRef, - WtSessionServerRefMut, WtSessionServerShared, WtSessionServerSharedMut, + RemoteWebTransportSession, WebTransportRpcSessionClient, WebTransportRpcSessionReqReceiver, + WebTransportRpcSessionServer, WebTransportRpcSessionServerRef, + WebTransportRpcSessionServerRefMut, WebTransportRpcSessionServerShared, + WebTransportRpcSessionServerSharedMut, }; -use crate::{quic, webtransport}; +use crate::{ + quic::{self, ConnectionError}, + rpc::lifecycle::LifecycleExt as ConnectionLifecycleExt, + webtransport::{ + self, AcceptStreamError, CloseReason, CloseSessionError, ControlCommandError, + DrainSessionError, OpenStreamError, SessionClosed, accept_stream_error, open_stream_error, + }, +}; + +/// WebTransport-flavoured lifecycle helpers for RPC-backed session handles. +/// +/// The trait is intentionally named `LifecycleExt`; the module namespace +/// (`rpc::webtransport::LifecycleExt`) carries the domain. It builds on the +/// connection-level RPC lifecycle latch and maps WebTransport operations into +/// [`OpenStreamError`] / [`AcceptStreamError`] without introducing a separate +/// latch mechanism. +/// +/// Implementers still must satisfy the latch-aware lifecycle invariant from +/// [`crate::rpc::lifecycle`]: their [`quic::Lifecycle`] implementation must +/// consult and resolve the same latch used by these guard helpers. +#[allow(async_fn_in_trait)] +pub trait LifecycleExt: ConnectionLifecycleExt { + /// Check liveness and surface any error as an [`OpenStreamError`]. + fn check_open(&self) -> Result<(), OpenStreamError> { + quic::Lifecycle::check(self).context(open_stream_error::OpenSnafu) + } + + /// Check liveness and surface any error as an [`AcceptStreamError`]. + fn check_accept(&self) -> Result<(), AcceptStreamError> { + quic::Lifecycle::check(self).context(accept_stream_error::ConnectionSnafu) + } + + /// Guard an async open operation whose error is already an + /// [`OpenStreamError`]. + async fn guard_open( + &self, + fut: impl Future>, + ) -> Result { + self.check_open()?; + match fut.await { + Ok(v) => Ok(v), + Err(OpenStreamError::Open { source }) => Err(OpenStreamError::Open { + source: self.latch().latch_with(|| source), + }), + Err(other) => Err(other), + } + } + + /// Guard an async open operation whose error must be lazily converted to + /// an [`OpenStreamError`]. + async fn guard_open_with( + &self, + fut: impl Future>, + convert_error: M, + ) -> Result + where + M: FnOnce(E) -> OpenStreamError, + { + self.check_open()?; + match fut.await { + Ok(v) => Ok(v), + Err(e) => { + if let Some(existing) = self.latch().peek() { + return Err(OpenStreamError::Open { source: existing }); + } + Err(match convert_error(e) { + OpenStreamError::Open { source } => OpenStreamError::Open { + source: self.latch().latch_with(|| source), + }, + other => other, + }) + } + } + } + + /// Guard an async accept operation whose error is already an + /// [`AcceptStreamError`]. + async fn guard_accept( + &self, + fut: impl Future>, + ) -> Result { + self.check_accept()?; + match fut.await { + Ok(v) => Ok(v), + Err(AcceptStreamError::Connection { source }) => Err(AcceptStreamError::Connection { + source: self.latch().latch_with(|| source), + }), + Err(other) => Err(other), + } + } + + /// Guard an async accept operation whose error carries richer information + /// than [`SessionClosed`]. + async fn guard_accept_err( + &self, + fut: impl Future>, + convert_error: M, + ) -> Result + where + M: FnOnce(E) -> Option, + { + self.check_accept()?; + match fut.await { + Ok(v) => Ok(v), + Err(e) => { + if let Some(existing) = self.latch().peek() { + return Err(AcceptStreamError::Connection { source: existing }); + } + if let Some(error) = convert_error(e) { + return Err(AcceptStreamError::Connection { + source: self.latch().latch_with(|| error), + }); + } + Err(AcceptStreamError::Closed { + source: SessionClosed, + }) + } + } + } +} + +impl LifecycleExt for T {} impl From for webtransport::OpenStreamError { fn from(error: remoc::rtc::CallError) -> Self { @@ -19,8 +146,344 @@ impl From for webtransport::OpenStreamError { } } -impl From for webtransport::Closed { +impl From for webtransport::AcceptStreamError { + fn from(error: remoc::rtc::CallError) -> Self { + webtransport::AcceptStreamError::Connection { + source: quic::ConnectionError::from(error), + } + } +} + +impl From for webtransport::DrainSessionError { + fn from(_error: remoc::rtc::CallError) -> Self { + DrainSessionError::Command { + source: ControlCommandError::Closed, + } + } +} + +impl From for webtransport::CloseSessionError { fn from(_error: remoc::rtc::CallError) -> Self { - webtransport::Closed + CloseSessionError::Command { + source: ControlCommandError::Closed, + } + } +} + +impl From for webtransport::CloseReason { + fn from(error: remoc::rtc::CallError) -> Self { + CloseReason::Connection(quic::ConnectionError::from(error)) + } +} + +#[cfg(test)] +mod tests { + use std::{borrow::Cow, future::pending}; + + use super::*; + use crate::{ + error::Code, + rpc::lifecycle::{ConnectionErrorLatch, HasLatch, LifecycleExt as ConnectionLifecycleExt}, + varint::VarInt, + webtransport::{AcceptStreamError, OpenStreamError}, + }; + + #[test] + fn call_error_maps_to_open_stream_connection_error() { + let error = webtransport::OpenStreamError::from(remoc::rtc::CallError::Dropped); + let webtransport::OpenStreamError::Open { source } = error else { + panic!("call error should map to open stream connection error"); + }; + + assert!(source.is_transport()); + } + + #[test] + fn call_error_maps_to_accept_stream_connection_error() { + let error = webtransport::AcceptStreamError::from(remoc::rtc::CallError::Dropped); + let webtransport::AcceptStreamError::Connection { source } = error else { + panic!("call error should map to accept stream connection error"); + }; + + assert!(source.is_transport()); + } + + #[derive(Debug, Default)] + struct TestLifecycle { + latch: ConnectionErrorLatch, + } + + impl HasLatch for TestLifecycle { + fn latch(&self) -> &ConnectionErrorLatch { + &self.latch + } + } + + impl quic::Lifecycle for TestLifecycle { + fn close(&self, _code: Code, _reason: Cow<'static, str>) {} + + fn check(&self) -> Result<(), quic::ConnectionError> { + self.check_with_probe(|| None) + } + + async fn closed(&self) -> quic::ConnectionError { + self.resolve_closed(pending()).await + } + } + + fn connection_error(reason: &'static str) -> quic::ConnectionError { + quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(0x01), + frame_type: VarInt::from_u32(0x00), + reason: reason.into(), + }, + } + } + + fn assert_reason(error: &quic::ConnectionError, expected: &str) { + let quic::ConnectionError::Transport { source } = error else { + panic!("expected transport error"); + }; + assert_eq!(source.reason.as_ref(), expected); + } + + #[tokio::test] + async fn lifecycle_checks_and_latches_open_errors() { + let lifecycle = TestLifecycle::default(); + + lifecycle.check_open().expect("open check should pass"); + lifecycle.check_accept().expect("accept check should pass"); + + let error = lifecycle + .guard_open(async { + Err::<(), _>(OpenStreamError::Open { + source: connection_error("first"), + }) + }) + .await + .expect_err("open error should be returned"); + let OpenStreamError::Open { source } = error else { + panic!("expected open error"); + }; + assert_reason(&source, "first"); + + let error = lifecycle + .guard_open(async { + Err::<(), _>(OpenStreamError::Open { + source: connection_error("second"), + }) + }) + .await + .expect_err("latched open error should be returned"); + let OpenStreamError::Open { source } = error else { + panic!("expected open error"); + }; + assert_reason(&source, "first"); + } + + #[tokio::test] + async fn lifecycle_guard_open_with_converts_lazily() { + let lifecycle = TestLifecycle::default(); + + let error = lifecycle + .guard_open_with(async { Err::<(), _>("boom") }, |_| OpenStreamError::Open { + source: connection_error("converted"), + }) + .await + .expect_err("converted open error should be returned"); + + let OpenStreamError::Open { source } = error else { + panic!("expected open error"); + }; + assert_reason(&source, "converted"); + } + + #[tokio::test] + async fn lifecycle_open_guards_return_success_and_passthrough_closed() { + let lifecycle = TestLifecycle::default(); + + let value = lifecycle + .guard_open(async { Ok::<_, OpenStreamError>(7) }) + .await + .expect("successful open operation should pass through"); + assert_eq!(value, 7); + + let error = lifecycle + .guard_open(async { + Err::<(), _>(OpenStreamError::Closed { + source: SessionClosed, + }) + }) + .await + .expect_err("closed session error should pass through"); + assert!(matches!(error, OpenStreamError::Closed { .. })); + lifecycle + .check_open() + .expect("non-connection open error must not latch"); + + let error = lifecycle + .guard_open_with(async { Err::<(), _>("closed") }, |_| { + OpenStreamError::Closed { + source: SessionClosed, + } + }) + .await + .expect_err("converted closed session error should pass through"); + assert!(matches!(error, OpenStreamError::Closed { .. })); + lifecycle + .check_open() + .expect("converted non-connection open error must not latch"); + } + + #[tokio::test] + async fn lifecycle_open_guard_with_uses_error_latched_during_operation() { + let lifecycle = TestLifecycle::default(); + let latch = lifecycle.latch.clone(); + + let error = lifecycle + .guard_open_with( + async move { + latch.latch_with(|| connection_error("latched during open")); + Err::<(), _>("unconverted") + }, + |_| { + panic!("conversion should be skipped once operation latched a connection error") + }, + ) + .await + .expect_err("latched open error should be returned"); + + let OpenStreamError::Open { source } = error else { + panic!("expected open error"); + }; + assert_reason(&source, "latched during open"); + } + + #[tokio::test] + async fn lifecycle_checks_surface_latched_open_error() { + let lifecycle = TestLifecycle::default(); + + let error = lifecycle + .guard_open(async { + Err::<(), _>(OpenStreamError::Open { + source: connection_error("check latch"), + }) + }) + .await + .expect_err("open error should be returned"); + let OpenStreamError::Open { source } = error else { + panic!("expected open error"); + }; + assert_reason(&source, "check latch"); + + let open_error = lifecycle + .check_open() + .expect_err("open check should surface latched error"); + let OpenStreamError::Open { source } = open_error else { + panic!("expected open error"); + }; + assert_reason(&source, "check latch"); + + let accept_error = lifecycle + .check_accept() + .expect_err("accept check should surface latched error"); + let AcceptStreamError::Connection { source } = accept_error else { + panic!("expected connection error"); + }; + assert_reason(&source, "check latch"); + } + + #[tokio::test] + async fn lifecycle_accept_guards_preserve_error_shape() { + let lifecycle = TestLifecycle::default(); + + let error = lifecycle + .guard_accept(async { + Err::<(), _>(AcceptStreamError::Connection { + source: connection_error("accept"), + }) + }) + .await + .expect_err("accept error should be returned"); + let AcceptStreamError::Connection { source } = error else { + panic!("expected connection error"); + }; + assert_reason(&source, "accept"); + + let lifecycle = TestLifecycle::default(); + let error = lifecycle + .guard_accept_err(async { Err::<(), _>("closed") }, |_| None) + .await + .expect_err("closed session should be returned"); + assert!(matches!(error, AcceptStreamError::Closed { .. })); + + let lifecycle = TestLifecycle::default(); + let error = lifecycle + .guard_accept_err(async { Err::<(), _>("connection") }, |_| { + Some(connection_error("converted accept")) + }) + .await + .expect_err("converted accept error should be returned"); + let AcceptStreamError::Connection { source } = error else { + panic!("expected connection error"); + }; + assert_reason(&source, "converted accept"); + } + + #[tokio::test] + async fn lifecycle_accept_guards_return_success_and_passthrough_closed() { + let lifecycle = TestLifecycle::default(); + + let value = lifecycle + .guard_accept(async { Ok::<_, AcceptStreamError>(11) }) + .await + .expect("successful accept operation should pass through"); + assert_eq!(value, 11); + + let error = lifecycle + .guard_accept(async { + Err::<(), _>(AcceptStreamError::Closed { + source: SessionClosed, + }) + }) + .await + .expect_err("closed session error should pass through"); + assert!(matches!(error, AcceptStreamError::Closed { .. })); + lifecycle + .check_accept() + .expect("non-connection accept error must not latch"); + + let value = lifecycle + .guard_accept_err(async { Ok::<_, &'static str>(13) }, |_| { + panic!("conversion should not run for successful accept operation") + }) + .await + .expect("successful converted accept operation should pass through"); + assert_eq!(value, 13); + } + + #[tokio::test] + async fn lifecycle_accept_guard_uses_error_latched_during_operation() { + let lifecycle = TestLifecycle::default(); + let latch = lifecycle.latch.clone(); + + let error = lifecycle + .guard_accept_err( + async move { + latch.latch_with(|| connection_error("latched during accept")); + Err::<(), _>("unconverted") + }, + |_| { + panic!("conversion should be skipped once operation latched a connection error") + }, + ) + .await + .expect_err("latched accept error should be returned"); + + let AcceptStreamError::Connection { source } = error else { + panic!("expected connection error"); + }; + assert_reason(&source, "latched during accept"); } } diff --git a/src/rpc/webtransport/session.rs b/src/rpc/webtransport/session.rs index 94ef71b..3fe70e7 100644 --- a/src/rpc/webtransport/session.rs +++ b/src/rpc/webtransport/session.rs @@ -2,23 +2,32 @@ //! //! Follows the same pattern as [`super::super::quic::connection`]: a //! `#[remoc::rtc::remote]` trait provides the wire protocol, a blanket-free -//! server impl delegates to [`WebTransportSession`], and [`RemoteWtSession`] +//! server impl delegates to [`WebTransportSession`], and [`RemoteWebTransportSession`] //! wraps the generated client to present a convenient async API. use std::sync::Arc; -use remoc::{prelude::Server, rtc::Client as RemocClient}; +use remoc::rtc::Client as RemocClient; use tracing::Instrument; -use super::super::{ - lifecycle::{ConnectionErrorLatch, HasLatch, LifecycleExt}, - quic::{ReadStreamClient, ReadStreamServer, WriteStreamClient, WriteStreamServer}, +use super::{ + super::{ + lifecycle::{ConnectionErrorLatch, HasLatch, LifecycleExt as ConnectionLifecycleExt}, + quic::{ReadFrameChannels, WriteFrameChannels}, + }, + LifecycleExt, }; use crate::{ - message::stream::guard, - quic::{self, ConnectionError, DynLifecycle}, + dhttp::message::guard, + quic::{ + self, BoxQuicStreamReader, BoxQuicStreamWriter, ConnectionError, DynLifecycle, + GetStreamIdExt, + }, varint::VarInt, - webtransport::{self, Closed, OpenStreamError, WtLifecycleExt}, + webtransport::{ + self, AcceptStreamError, CloseReason, CloseSession, CloseSessionError, DrainSessionError, + OpenStreamError, SessionDrain, WebTransportSessionId, + }, }; // --------------------------------------------------------------------------- @@ -27,90 +36,77 @@ use crate::{ /// Remoc RPC counterpart of [`WebTransportSession`]. /// -/// Uses the native [`OpenStreamError`] and [`Closed`] error types directly — -/// both are serializable. The `session_id` is not included because it is -/// immutable and can be passed out-of-band at construction time. +/// Uses the native [`OpenStreamError`] and [`AcceptStreamError`] error types +/// directly — both are serializable. The `session_id` is not included because +/// it is immutable and can be passed out-of-band at construction time. #[remoc::rtc::remote] -pub trait WtSession: Send + Sync { - async fn open_bi(&self) -> Result<(ReadStreamClient, WriteStreamClient), OpenStreamError>; - async fn open_uni(&self) -> Result; - async fn accept_bi(&self) -> Result<(ReadStreamClient, WriteStreamClient), Closed>; - async fn accept_uni(&self) -> Result; +pub trait WebTransportRpcSession: Send + Sync { + async fn drain(&self) -> Result<(), DrainSessionError>; + async fn close(&self, close: CloseSession) -> Result<(), CloseSessionError>; + async fn drained(&self) -> Result; + async fn closed(&self) -> Result; + async fn open_bi(&self) -> Result<(ReadFrameChannels, WriteFrameChannels), OpenStreamError>; + async fn open_uni(&self) -> Result; + async fn accept_bi(&self) + -> Result<(ReadFrameChannels, WriteFrameChannels), AcceptStreamError>; + async fn accept_uni(&self) -> Result; } // --------------------------------------------------------------------------- -// Server: impl WtSession for WebTransportSession +// Server: impl WebTransportRpcSession for WebTransportSession // --------------------------------------------------------------------------- -impl WtSession for webtransport::WebTransportSession { - async fn open_bi(&self) -> Result<(ReadStreamClient, WriteStreamClient), OpenStreamError> { +impl WebTransportRpcSession for webtransport::WebTransportSession { + async fn drain(&self) -> Result<(), DrainSessionError> { + webtransport::WebTransportSession::drain(self).await + } + + async fn close(&self, close: CloseSession) -> Result<(), CloseSessionError> { + webtransport::WebTransportSession::close(self, close).await + } + + async fn drained(&self) -> Result { + Ok(webtransport::WebTransportSession::drained(self).await) + } + + async fn closed(&self) -> Result { + Ok(webtransport::WebTransportSession::closed(self).await) + } + + async fn open_bi(&self) -> Result<(ReadFrameChannels, WriteFrameChannels), OpenStreamError> { let (reader, writer) = webtransport::WebTransportSession::open_bi(self).await?; - let (rs, rc) = ReadStreamServer::new(reader, 1); - tokio::spawn( - (async move { - let _ = rs.serve().await; - }) - .in_current_span(), - ); - let (ws, wc) = WriteStreamServer::new(writer, 1); - tokio::spawn( - (async move { - let _ = ws.serve().await; - }) - .in_current_span(), - ); - Ok((rc, wc)) + Ok(( + read_channels_open(Box::pin(reader) as BoxQuicStreamReader).await?, + write_channels_open(Box::pin(writer) as BoxQuicStreamWriter).await?, + )) } - async fn open_uni(&self) -> Result { + async fn open_uni(&self) -> Result { let writer = webtransport::WebTransportSession::open_uni(self).await?; - let (ws, wc) = WriteStreamServer::new(writer, 1); - tokio::spawn( - (async move { - let _ = ws.serve().await; - }) - .in_current_span(), - ); - Ok(wc) + write_channels_open(Box::pin(writer) as BoxQuicStreamWriter).await } - async fn accept_bi(&self) -> Result<(ReadStreamClient, WriteStreamClient), Closed> { + async fn accept_bi( + &self, + ) -> Result<(ReadFrameChannels, WriteFrameChannels), AcceptStreamError> { let (reader, writer) = webtransport::WebTransportSession::accept_bi(self).await?; - let (rs, rc) = ReadStreamServer::new(reader, 1); - tokio::spawn( - (async move { - let _ = rs.serve().await; - }) - .in_current_span(), - ); - let (ws, wc) = WriteStreamServer::new(writer, 1); - tokio::spawn( - (async move { - let _ = ws.serve().await; - }) - .in_current_span(), - ); - Ok((rc, wc)) + Ok(( + read_channels_accept(Box::pin(reader) as BoxQuicStreamReader).await?, + write_channels_accept(Box::pin(writer) as BoxQuicStreamWriter).await?, + )) } - async fn accept_uni(&self) -> Result { + async fn accept_uni(&self) -> Result { let reader = webtransport::WebTransportSession::accept_uni(self).await?; - let (rs, rc) = ReadStreamServer::new(reader, 1); - tokio::spawn( - (async move { - let _ = rs.serve().await; - }) - .in_current_span(), - ); - Ok(rc) + read_channels_accept(Box::pin(reader) as BoxQuicStreamReader).await } } // --------------------------------------------------------------------------- -// Client: RemoteWtSession wraps WtSessionClient +// Client: RemoteWebTransportSession wraps WebTransportRpcSessionClient // --------------------------------------------------------------------------- -/// A wrapper around [`WtSessionClient`] that converts RPC stream clients back +/// A wrapper around [`WebTransportRpcSessionClient`] that converts RPC stream clients back /// into boxed async streams. /// /// This is the client-side handle for a remote WebTransport session. It mirrors @@ -125,17 +121,17 @@ impl WtSession for webtransport::WebTransportSession { /// own remoc channel — is latched and returned on every subsequent operation, /// satisfying the QUIC connection-error consistency requirement. #[derive(Clone)] -pub struct RemoteWtSession { - client: WtSessionClient, - session_id: VarInt, +pub struct RemoteWebTransportSession { + client: WebTransportRpcSessionClient, + session_id: WebTransportSessionId, parent: Arc, latch: ConnectionErrorLatch, } -impl RemoteWtSession { +impl RemoteWebTransportSession { pub fn new( - client: WtSessionClient, - session_id: VarInt, + client: WebTransportRpcSessionClient, + session_id: WebTransportSessionId, conn_lifecycle: Arc, ) -> Self { Self { @@ -146,7 +142,7 @@ impl RemoteWtSession { } } - pub fn into_inner(self) -> WtSessionClient { + pub fn into_inner(self) -> WebTransportRpcSessionClient { self.client } @@ -156,7 +152,7 @@ impl RemoteWtSession { source: quic::TransportError { kind: VarInt::from_u32(0x01), frame_type: VarInt::from_u32(0x00), - reason: "remoc wt session channel closed".into(), + reason: "remoc webtransport session channel closed".into(), }, } } @@ -173,13 +169,13 @@ impl RemoteWtSession { } } -impl HasLatch for RemoteWtSession { +impl HasLatch for RemoteWebTransportSession { fn latch(&self) -> &ConnectionErrorLatch { &self.latch } } -impl quic::Lifecycle for RemoteWtSession { +impl quic::Lifecycle for RemoteWebTransportSession { fn close(&self, code: crate::error::Code, reason: std::borrow::Cow<'static, str>) { DynLifecycle::close(self.parent.as_ref(), code, reason); } @@ -194,56 +190,951 @@ impl quic::Lifecycle for RemoteWtSession { } } -impl webtransport::Session for RemoteWtSession { - type StreamReader = guard::GuardedQuicReader; - type StreamWriter = guard::GuardedQuicWriter; +impl webtransport::Session for RemoteWebTransportSession { + type StreamReader = guard::GuardQuicReader; + type StreamWriter = guard::GuardQuicWriter; - fn session_id(&self) -> VarInt { + fn id(&self) -> WebTransportSessionId { self.session_id } + async fn drain(&self) -> Result<(), DrainSessionError> { + WebTransportRpcSession::drain(&self.client).await + } + + async fn close(&self, close: CloseSession) -> Result<(), CloseSessionError> { + WebTransportRpcSession::close(&self.client, close).await + } + + async fn drained(&self) -> SessionDrain { + match WebTransportRpcSession::drained(&self.client).await { + Ok(drain) => drain, + Err(reason) => SessionDrain::Closed(reason), + } + } + + async fn closed(&self) -> CloseReason { + match WebTransportRpcSession::closed(&self.client).await { + Ok(reason) | Err(reason) => reason, + } + } + async fn open_bi(&self) -> Result<(Self::StreamReader, Self::StreamWriter), OpenStreamError> { - let (reader, writer) = self.guard_open(WtSession::open_bi(&self.client)).await?; - Ok((reader.into_boxed_quic(), writer.into_boxed_quic())) + let (reader, writer) = self + .guard_open(WebTransportRpcSession::open_bi(&self.client)) + .await?; + Ok(( + self.read_channels_into_quic(reader), + self.write_channels_into_quic(writer), + )) } async fn open_uni(&self) -> Result { - let writer = self.guard_open(WtSession::open_uni(&self.client)).await?; - Ok(writer.into_boxed_quic()) + let writer = self + .guard_open(WebTransportRpcSession::open_uni(&self.client)) + .await?; + Ok(self.write_channels_into_quic(writer)) } - async fn accept_bi(&self) -> Result<(Self::StreamReader, Self::StreamWriter), Closed> { + async fn accept_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), AcceptStreamError> { let (reader, writer) = self - .guard_accept_err(WtSession::accept_bi(&self.client), |Closed| { - RemocClient::is_closed(&self.client).then(Self::remoc_channel_error) - }) + .guard_accept(WebTransportRpcSession::accept_bi(&self.client)) .await?; - Ok((reader.into_boxed_quic(), writer.into_boxed_quic())) + Ok(( + self.read_channels_into_quic(reader), + self.write_channels_into_quic(writer), + )) } - async fn accept_uni(&self) -> Result { + async fn accept_uni(&self) -> Result { let reader = self - .guard_accept_err(WtSession::accept_uni(&self.client), |Closed| { - RemocClient::is_closed(&self.client).then(Self::remoc_channel_error) - }) + .guard_accept(WebTransportRpcSession::accept_uni(&self.client)) .await?; - Ok(reader.into_boxed_quic()) + Ok(self.read_channels_into_quic(reader)) + } +} + +impl RemoteWebTransportSession { + fn read_channels_into_quic(&self, channels: ReadFrameChannels) -> guard::GuardQuicReader { + let lifecycle = Arc::new(self.clone()); + let raw = Box::pin(channels.into_quic(lifecycle)) as BoxQuicStreamReader; + guard::GuardQuicReader::new(raw) } + + fn write_channels_into_quic(&self, channels: WriteFrameChannels) -> guard::GuardQuicWriter { + let lifecycle = Arc::new(self.clone()); + let raw = Box::pin(channels.into_quic(lifecycle)) as BoxQuicStreamWriter; + guard::GuardQuicWriter::new(raw) + } +} + +async fn read_channels_open( + mut reader: BoxQuicStreamReader, +) -> Result { + let stream_id = match reader.stream_id().await { + Ok(stream_id) => stream_id, + Err(error) => return Err(open_stream_id_error(error)), + }; + let (channels, bridge) = ReadFrameChannels::pair(stream_id); + // Inherent termination: this task owns the real WebTransport stream and + // remoc frame IO. It exits when the real stream reaches a terminal state, + // the worker drops the frame channels, or frame IO reports failure. + tokio::spawn( + crate::rpc::stream::hypervisor::read::run_read_bridge(reader, bridge).in_current_span(), + ); + Ok(channels) +} + +async fn write_channels_open( + mut writer: BoxQuicStreamWriter, +) -> Result { + let stream_id = match writer.stream_id().await { + Ok(stream_id) => stream_id, + Err(error) => return Err(open_stream_id_error(error)), + }; + let (channels, bridge) = WriteFrameChannels::pair(stream_id); + // Inherent termination: this task owns the real WebTransport stream and + // remoc frame IO. It exits when the real stream reaches a terminal state, + // the worker drops the frame channels, or frame IO reports failure. + tokio::spawn( + crate::rpc::stream::hypervisor::write::run_write_bridge(writer, bridge).in_current_span(), + ); + Ok(channels) +} + +async fn read_channels_accept( + mut reader: BoxQuicStreamReader, +) -> Result { + let stream_id = match reader.stream_id().await { + Ok(stream_id) => stream_id, + Err(error) => return Err(accept_stream_id_error(error)), + }; + let (channels, bridge) = ReadFrameChannels::pair(stream_id); + // Inherent termination: this task owns the real WebTransport stream and + // remoc frame IO. It exits when the real stream reaches a terminal state, + // the worker drops the frame channels, or frame IO reports failure. + tokio::spawn( + crate::rpc::stream::hypervisor::read::run_read_bridge(reader, bridge).in_current_span(), + ); + Ok(channels) +} + +async fn write_channels_accept( + mut writer: BoxQuicStreamWriter, +) -> Result { + let stream_id = match writer.stream_id().await { + Ok(stream_id) => stream_id, + Err(error) => return Err(accept_stream_id_error(error)), + }; + let (channels, bridge) = WriteFrameChannels::pair(stream_id); + // Inherent termination: this task owns the real WebTransport stream and + // remoc frame IO. It exits when the real stream reaches a terminal state, + // the worker drops the frame channels, or frame IO reports failure. + tokio::spawn( + crate::rpc::stream::hypervisor::write::run_write_bridge(writer, bridge).in_current_span(), + ); + Ok(channels) +} + +fn open_stream_id_error(error: quic::StreamError) -> OpenStreamError { + OpenStreamError::StreamId { source: error } +} + +fn accept_stream_id_error(error: quic::StreamError) -> AcceptStreamError { + AcceptStreamError::StreamId { source: error } } -impl WtSessionClient { - /// Convert into a [`RemoteWtSession`]. - pub fn into_wt( +impl WebTransportRpcSessionClient { + /// Convert into a [`RemoteWebTransportSession`]. + pub fn into_webtransport_session( self, - session_id: VarInt, + session_id: WebTransportSessionId, conn_lifecycle: Arc, - ) -> RemoteWtSession { - RemoteWtSession::new(self, session_id, conn_lifecycle) + ) -> RemoteWebTransportSession { + RemoteWebTransportSession::new(self, session_id, conn_lifecycle) } } -impl From for WtSessionClient { - fn from(remote: RemoteWtSession) -> Self { +impl From for WebTransportRpcSessionClient { + fn from(remote: RemoteWebTransportSession) -> Self { remote.client } } + +#[cfg(test)] +mod tests { + use std::{ + borrow::Cow, + future::pending, + sync::{Arc, Mutex}, + }; + + use bytes::Bytes; + use futures::{Sink, SinkExt, Stream, StreamExt}; + use remoc::prelude::ServerShared; + use tokio::time::{Duration, timeout}; + use tokio_util::task::AbortOnDropHandle; + use tracing::Instrument; + + use super::*; + use crate::{ + connection::{ConnectionState, tests::MockConnection}, + dhttp::{ + message::test::{read_stream_for_test, write_stream_for_test}, + protocol::DHttpProtocol, + settings::Settings, + webtransport::settings::{ + EnableWebTransport, InitialMaxData, InitialMaxStreamsBidi, InitialMaxStreamsUni, + }, + }, + error::Code, + extended_connect::EstablishedConnect, + protocol::Protocols, + qpack::field::Protocol, + quic::{BoxQuicStreamReader, BoxQuicStreamWriter, GetStreamIdExt, StopStreamExt}, + rpc::stream::test_io::TestLifecycle as StreamTestLifecycle, + stream_id::StreamId, + webtransport::{ + CloseReason, CloseSession, DrainReason, SessionCloseReason, SessionClosed, + SessionDrain, SessionDrainReason, WEBTRANSPORT_H3, WebTransportProtocol, + }, + }; + + struct TestRpcSession { + fail_open: bool, + close_accept: bool, + fail_accept: bool, + drained: Mutex, + closed: Mutex>, + tasks: Mutex>>, + } + + impl TestRpcSession { + fn new() -> Self { + Self { + fail_open: false, + close_accept: false, + fail_accept: false, + drained: Mutex::new(false), + closed: Mutex::new(None), + tasks: Mutex::new(Vec::new()), + } + } + + fn fail_open() -> Self { + Self { + fail_open: true, + ..Self::new() + } + } + + fn fail_accept() -> Self { + Self { + fail_accept: true, + ..Self::new() + } + } + + fn close_accept() -> Self { + Self { + close_accept: true, + ..Self::new() + } + } + + fn stream_pair(&self, stream_id: u32) -> (ReadFrameChannels, WriteFrameChannels) { + let stream_id = VarInt::from_u32(stream_id); + let (reader, writer) = quic::test::mock_stream_pair(stream_id); + let (reader_client, reader_bridge) = ReadFrameChannels::pair(stream_id); + let (writer_client, writer_bridge) = WriteFrameChannels::pair(stream_id); + self.push_task(tokio::spawn( + crate::rpc::stream::hypervisor::read::run_read_bridge( + Box::pin(reader) as BoxQuicStreamReader, + reader_bridge, + ) + .in_current_span(), + )); + self.push_task(tokio::spawn( + crate::rpc::stream::hypervisor::write::run_write_bridge( + Box::pin(writer) as BoxQuicStreamWriter, + writer_bridge, + ) + .in_current_span(), + )); + (reader_client, writer_client) + } + + fn read_stream(&self, stream_id: u32) -> ReadFrameChannels { + let stream_id = VarInt::from_u32(stream_id); + let (reader, _writer) = quic::test::mock_stream_pair(stream_id); + let (reader_client, reader_bridge) = ReadFrameChannels::pair(stream_id); + self.push_task(tokio::spawn( + crate::rpc::stream::hypervisor::read::run_read_bridge( + Box::pin(reader) as BoxQuicStreamReader, + reader_bridge, + ) + .in_current_span(), + )); + reader_client + } + + fn write_stream(&self, stream_id: u32) -> WriteFrameChannels { + let stream_id = VarInt::from_u32(stream_id); + let (_reader, writer) = quic::test::mock_stream_pair(stream_id); + let (writer_client, writer_bridge) = WriteFrameChannels::pair(stream_id); + self.push_task(tokio::spawn( + crate::rpc::stream::hypervisor::write::run_write_bridge( + Box::pin(writer) as BoxQuicStreamWriter, + writer_bridge, + ) + .in_current_span(), + )); + writer_client + } + + fn push_task(&self, task: tokio::task::JoinHandle<()>) { + self.tasks + .lock() + .expect("tasks mutex should not be poisoned") + .push(AbortOnDropHandle::new(task)); + } + } + + impl WebTransportRpcSession for TestRpcSession { + async fn drain(&self) -> Result<(), DrainSessionError> { + *self + .drained + .lock() + .expect("drained mutex should not poison") = true; + Ok(()) + } + + async fn close(&self, close: CloseSession) -> Result<(), CloseSessionError> { + *self.closed.lock().expect("closed mutex should not poison") = Some(close); + Ok(()) + } + + async fn drained(&self) -> Result { + if *self + .drained + .lock() + .expect("drained mutex should not poison") + { + Ok(SessionDrain::Requested(DrainReason::Session( + SessionDrainReason::Local, + ))) + } else { + Ok(SessionDrain::Closed(CloseReason::Session( + SessionCloseReason::ControlStreamError, + ))) + } + } + + async fn closed(&self) -> Result { + match self + .closed + .lock() + .expect("closed mutex should not poison") + .clone() + { + Some(close) => Ok(CloseReason::Session(SessionCloseReason::Local(close))), + None => Ok(CloseReason::Session(SessionCloseReason::ControlStreamError)), + } + } + + async fn open_bi( + &self, + ) -> Result<(ReadFrameChannels, WriteFrameChannels), OpenStreamError> { + if self.fail_open { + return Err(OpenStreamError::Open { + source: connection_error("rpc open failed"), + }); + } + Ok(self.stream_pair(1)) + } + + async fn open_uni(&self) -> Result { + if self.fail_open { + return Err(OpenStreamError::Open { + source: connection_error("rpc open failed"), + }); + } + Ok(self.write_stream(2)) + } + + async fn accept_bi( + &self, + ) -> Result<(ReadFrameChannels, WriteFrameChannels), AcceptStreamError> { + if self.fail_accept { + return Err(AcceptStreamError::Connection { + source: connection_error("rpc accept failed"), + }); + } + if self.close_accept { + return Err(AcceptStreamError::Closed { + source: SessionClosed, + }); + } + Ok(self.stream_pair(3)) + } + + async fn accept_uni(&self) -> Result { + if self.fail_accept { + return Err(AcceptStreamError::Connection { + source: connection_error("rpc accept failed"), + }); + } + if self.close_accept { + return Err(AcceptStreamError::Closed { + source: SessionClosed, + }); + } + Ok(self.read_stream(4)) + } + } + + #[derive(Debug, Default)] + struct TestLifecycle { + closes: Mutex)>>, + terminal: Mutex>, + } + + impl TestLifecycle { + fn set_terminal(&self, error: ConnectionError) { + *self + .terminal + .lock() + .expect("terminal mutex should not be poisoned") = Some(error); + } + + fn closes(&self) -> Vec<(Code, Cow<'static, str>)> { + self.closes + .lock() + .expect("closes mutex should not be poisoned") + .clone() + } + } + + impl quic::Lifecycle for TestLifecycle { + fn close(&self, code: Code, reason: Cow<'static, str>) { + self.closes + .lock() + .expect("closes mutex should not be poisoned") + .push((code, reason)); + } + + fn check(&self) -> Result<(), ConnectionError> { + self.terminal + .lock() + .expect("terminal mutex should not be poisoned") + .clone() + .map_or(Ok(()), Err) + } + + async fn closed(&self) -> ConnectionError { + let terminal = self + .terminal + .lock() + .expect("terminal mutex should not be poisoned") + .clone(); + match terminal { + Some(error) => error, + None => pending().await, + } + } + } + + fn spawn_rpc_session( + session: Arc, + ) -> (AbortOnDropHandle<()>, WebTransportRpcSessionClient) + where + S: WebTransportRpcSession + 'static, + { + let (server, client) = WebTransportRpcSessionServerShared::new(session, 1); + let task = AbortOnDropHandle::new(tokio::spawn( + async move { + let _ = server.serve(true).await; + } + .in_current_span(), + )); + (task, client) + } + + fn connection_error(reason: &'static str) -> ConnectionError { + quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(0x01), + frame_type: VarInt::from_u32(0x00), + reason: reason.into(), + }, + } + } + + fn assert_reason(error: &ConnectionError, expected: &str) { + let quic::ConnectionError::Transport { source } = error else { + panic!("expected transport error"); + }; + assert_eq!(source.reason.as_ref(), expected); + } + + async fn assert_roundtrip( + reader: &mut (impl Stream> + Unpin), + writer: &mut (impl Sink + Unpin), + payload: &'static [u8], + ) { + let bytes = Bytes::from_static(payload); + let write = async { writer.send(bytes.clone()).await }; + let read = async { reader.next().await }; + let (write, received) = tokio::join!(write, read); + write.expect("write"); + let received = received + .expect("reader should produce one chunk") + .expect("read"); + assert_eq!(received, bytes); + } + + fn connection_with_webtransport_pair() -> ( + Arc, + Arc>, + ) { + let quic = Arc::new(MockConnection::new()); + let erased: Arc = quic.clone(); + let mut protocols = Protocols::new(); + let dhttp = DHttpProtocol::new_for_test(erased.clone()); + dhttp + .state + .peer_settings + .set(Arc::new(enabled_webtransport_settings())) + .expect("peer settings should be set once"); + protocols.insert(dhttp); + protocols.insert(WebTransportProtocol::new_for_test(erased)); + let connection = + Arc::new(ConnectionState::new_for_test(quic.clone(), Arc::new(protocols)).erase()); + (quic, connection) + } + + fn enabled_webtransport_settings() -> Settings { + let mut settings = Settings::default(); + settings.set(EnableWebTransport::setting(true)); + settings.set(InitialMaxStreamsBidi::setting(VarInt::from_u32(16))); + settings.set(InitialMaxStreamsUni::setting(VarInt::from_u32(16))); + settings.set(InitialMaxData::setting(VarInt::MAX)); + settings + } + + fn connection_with_enabled_webtransport_pair() -> ( + Arc, + Arc>, + ) { + let (quic, connection) = connection_with_webtransport_pair(); + quic.enable_stream_ops(); + (quic, connection) + } + + fn connection_with_webtransport() -> Arc> { + connection_with_webtransport_pair().1 + } + + fn wt_session_id(id: u32) -> WebTransportSessionId { + WebTransportSessionId::try_from(StreamId::from(VarInt::from_u32(id))) + .expect("test id must be a valid webtransport session id") + } + + fn webtransport_connect_on( + connection: Arc>, + ) -> EstablishedConnect { + let stream_id = StreamId::from(VarInt::from_u32(8)); + EstablishedConnect::ready( + stream_id, + Some(Protocol::new(WEBTRANSPORT_H3)), + connection, + read_stream_for_test(stream_id.0), + write_stream_for_test(stream_id.0), + ) + } + + fn webtransport_connect() -> EstablishedConnect { + webtransport_connect_on(connection_with_webtransport()) + } + + async fn wait_for_remoc_client_to_close(client: &WebTransportRpcSessionClient) { + timeout(Duration::from_secs(1), async { + loop { + if RemocClient::is_closed(client) { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("remoc client should close"); + } + + #[tokio::test] + async fn remote_session_delegates_stream_operations_and_lifecycle() { + let session = Arc::new(TestRpcSession::new()); + let (_server_task, client) = spawn_rpc_session(session); + let parent = Arc::new(TestLifecycle::default()); + let remote = + RemoteWebTransportSession::new(client.clone(), wt_session_id(40), parent.clone()); + + assert_eq!(webtransport::Session::id(&remote), wt_session_id(40),); + quic::Lifecycle::check(&remote).expect("remote session should be live"); + + let (mut reader, mut writer) = webtransport::Session::open_bi(&remote) + .await + .expect("open_bi"); + assert_eq!( + reader.stream_id().await.expect("reader id"), + VarInt::from_u32(1) + ); + assert_eq!( + writer.stream_id().await.expect("writer id"), + VarInt::from_u32(1) + ); + assert_roundtrip(&mut reader, &mut writer, b"open-bidi").await; + + let mut writer = webtransport::Session::open_uni(&remote) + .await + .expect("open_uni"); + assert_eq!( + writer.stream_id().await.expect("writer id"), + VarInt::from_u32(2) + ); + + let (mut reader, mut writer) = webtransport::Session::accept_bi(&remote) + .await + .expect("accept_bi"); + assert_eq!( + reader.stream_id().await.expect("reader id"), + VarInt::from_u32(3) + ); + assert_eq!( + writer.stream_id().await.expect("writer id"), + VarInt::from_u32(3) + ); + assert_roundtrip(&mut reader, &mut writer, b"accept-bidi").await; + + let mut reader = webtransport::Session::accept_uni(&remote) + .await + .expect("accept_uni"); + assert_eq!( + reader.stream_id().await.expect("reader id"), + VarInt::from_u32(4) + ); + reader + .stop(VarInt::from_u32(100)) + .await + .expect("stop accept_uni reader"); + + quic::Lifecycle::close(&remote, Code::H3_NO_ERROR, "done".into()); + assert_eq!(parent.closes(), vec![(Code::H3_NO_ERROR, "done".into())]); + + let converted = client.into_webtransport_session(wt_session_id(44), parent); + let _inner: WebTransportRpcSessionClient = converted.into(); + } + + #[tokio::test] + async fn remote_session_into_inner_returns_rpc_client() { + let session = Arc::new(TestRpcSession::new()); + let (_server_task, client) = spawn_rpc_session(session); + let parent = Arc::new(TestLifecycle::default()); + let remote = RemoteWebTransportSession::new(client, wt_session_id(40), parent.clone()); + + let client = remote.into_inner(); + let converted = client.into_webtransport_session(wt_session_id(44), parent); + + assert_eq!(webtransport::Session::id(&converted), wt_session_id(44),); + webtransport::Session::open_uni(&converted) + .await + .expect("client returned by into_inner should remain usable"); + } + + #[tokio::test] + async fn remote_webtransport_session_delegates_control_methods() { + let session = Arc::new(TestRpcSession::new()); + let (_server_task, client) = spawn_rpc_session(session); + let parent = Arc::new(TestLifecycle::default()); + let remote = RemoteWebTransportSession::new(client, wt_session_id(40), parent); + let close = CloseSession::try_from((5_u32, "bye")).expect("valid close"); + + webtransport::Session::drain(&remote) + .await + .expect("remote drain should succeed"); + assert_eq!( + webtransport::Session::drained(&remote).await, + SessionDrain::Requested(DrainReason::Session(SessionDrainReason::Local)) + ); + + webtransport::Session::close(&remote, close.clone()) + .await + .expect("remote close should succeed"); + assert_eq!( + webtransport::Session::closed(&remote).await, + CloseReason::Session(SessionCloseReason::Local(close)) + ); + } + + #[tokio::test] + async fn remote_session_latches_open_errors() { + let session = Arc::new(TestRpcSession::fail_open()); + let (_server_task, client) = spawn_rpc_session(session); + let parent = Arc::new(TestLifecycle::default()); + let remote = RemoteWebTransportSession::new(client, wt_session_id(40), parent.clone()); + + let Err(error) = webtransport::Session::open_bi(&remote).await else { + panic!("open error should surface"); + }; + let OpenStreamError::Open { source } = error else { + panic!("expected open stream connection error"); + }; + assert_reason(&source, "rpc open failed"); + + let Err(error) = webtransport::Session::open_uni(&remote).await else { + panic!("latched open error should block uni opens"); + }; + let OpenStreamError::Open { source } = error else { + panic!("expected open stream connection error"); + }; + assert_reason(&source, "rpc open failed"); + + let error = quic::Lifecycle::check(&remote).expect_err("latched error should fail check"); + assert_reason(&error, "rpc open failed"); + + let closed = timeout(Duration::from_secs(1), quic::Lifecycle::closed(&remote)) + .await + .expect("latched closed should resolve immediately"); + assert_reason(&closed, "rpc open failed"); + } + + #[tokio::test] + async fn remote_session_maps_accept_session_closed() { + let session = Arc::new(TestRpcSession::close_accept()); + let (_server_task, client) = spawn_rpc_session(session); + let parent = Arc::new(TestLifecycle::default()); + let remote = RemoteWebTransportSession::new(client, wt_session_id(40), parent); + + let Err(error) = webtransport::Session::accept_bi(&remote).await else { + panic!("accept_bi should report session closure"); + }; + assert!(matches!(error, AcceptStreamError::Closed { .. })); + + let Err(error) = webtransport::Session::accept_uni(&remote).await else { + panic!("accept_uni should report session closure"); + }; + assert!(matches!(error, AcceptStreamError::Closed { .. })); + } + + #[tokio::test] + async fn remote_session_latches_accept_connection_errors() { + let session = Arc::new(TestRpcSession::fail_accept()); + let (_server_task, client) = spawn_rpc_session(session); + let parent = Arc::new(TestLifecycle::default()); + let remote = RemoteWebTransportSession::new(client, wt_session_id(40), parent); + + let Err(error) = webtransport::Session::accept_bi(&remote).await else { + panic!("accept_bi should surface connection failure"); + }; + let AcceptStreamError::Connection { source } = error else { + panic!("expected accept connection error"); + }; + assert_reason(&source, "rpc accept failed"); + + let Err(error) = webtransport::Session::accept_uni(&remote).await else { + panic!("latched accept error should block later accepts"); + }; + let AcceptStreamError::Connection { source } = error else { + panic!("expected latched accept connection error"); + }; + assert_reason(&source, "rpc accept failed"); + + let error = + quic::Lifecycle::check(&remote).expect_err("latched accept error should fail check"); + assert_reason(&error, "rpc accept failed"); + } + + #[tokio::test] + async fn remote_session_observes_parent_lifecycle_errors() { + let session = Arc::new(TestRpcSession::new()); + let (_server_task, client) = spawn_rpc_session(session); + let parent = Arc::new(TestLifecycle::default()); + let remote = RemoteWebTransportSession::new(client, wt_session_id(40), parent.clone()); + + parent.set_terminal(connection_error("parent closed")); + + let error = quic::Lifecycle::check(&remote).expect_err("parent error should fail check"); + assert_reason(&error, "parent closed"); + + let Err(error) = webtransport::Session::open_uni(&remote).await else { + panic!("latched parent error should block opens"); + }; + let OpenStreamError::Open { source } = error else { + panic!("expected open stream connection error"); + }; + assert_reason(&source, "parent closed"); + + let closed = quic::Lifecycle::closed(&remote).await; + assert_reason(&closed, "parent closed"); + } + + #[tokio::test] + async fn remote_session_maps_closed_remoc_channel_to_transport_errors() { + let session = Arc::new(TestRpcSession::new()); + let (server_task, client) = spawn_rpc_session(session); + let parent = Arc::new(TestLifecycle::default()); + let remote = RemoteWebTransportSession::new(client, wt_session_id(40), parent); + + drop(server_task); + wait_for_remoc_client_to_close(&remote.client).await; + + let Err(error) = webtransport::Session::open_bi(&remote).await else { + panic!("open_bi should surface remoc channel closure"); + }; + let OpenStreamError::Open { source } = error else { + panic!("expected open stream connection error"); + }; + assert_reason(&source, "remoc webtransport session channel closed"); + + let Err(error) = webtransport::Session::accept_uni(&remote).await else { + panic!("accept_uni should surface remoc channel closure"); + }; + let AcceptStreamError::Connection { source } = error else { + panic!("expected accept connection error"); + }; + assert_reason(&source, "remoc webtransport session channel closed"); + + let error = + quic::Lifecycle::check(&remote).expect_err("latched remoc error should fail checks"); + assert_reason(&error, "remoc webtransport session channel closed"); + + let closed = timeout(Duration::from_secs(1), quic::Lifecycle::closed(&remote)) + .await + .expect("latched closed should resolve immediately"); + assert_reason(&closed, "remoc webtransport session channel closed"); + } + + #[tokio::test] + async fn remote_session_prefers_parent_error_over_closed_remoc_channel() { + let session = Arc::new(TestRpcSession::new()); + let (server_task, client) = spawn_rpc_session(session); + let parent = Arc::new(TestLifecycle::default()); + let remote = RemoteWebTransportSession::new(client, wt_session_id(40), parent.clone()); + + drop(server_task); + wait_for_remoc_client_to_close(&remote.client).await; + parent.set_terminal(connection_error("parent closed")); + + let error = quic::Lifecycle::check(&remote).expect_err("parent error should win probe"); + assert_reason(&error, "parent closed"); + } + + #[tokio::test] + async fn concrete_webtransport_rpc_session_surfaces_open_errors() { + let session = webtransport::WebTransportSession::try_from(webtransport_connect()) + .expect("connect should create webtransport session"); + + let open_bi = WebTransportRpcSession::open_bi(&session) + .await + .expect_err("mock connection cannot open bidi streams"); + assert!(matches!( + open_bi, + OpenStreamError::Open { .. } | OpenStreamError::Closed { .. }, + )); + + let open_uni = WebTransportRpcSession::open_uni(&session) + .await + .expect_err("mock connection cannot open uni streams"); + assert!(matches!( + open_uni, + OpenStreamError::Open { .. } | OpenStreamError::Closed { .. }, + )); + } + + #[tokio::test] + async fn concrete_webtransport_rpc_session_wraps_successful_open_streams() { + let (_quic, connection) = connection_with_enabled_webtransport_pair(); + let session = + webtransport::WebTransportSession::try_from(webtransport_connect_on(connection)) + .expect("connect should create webtransport session"); + + let (reader, writer) = WebTransportRpcSession::open_bi(&session) + .await + .expect("open_bi should wrap successful streams"); + let lifecycle = Arc::new(StreamTestLifecycle::new()); + let mut reader = Box::pin(reader.into_quic(lifecycle.clone())) as BoxQuicStreamReader; + let mut writer = Box::pin(writer.into_quic(lifecycle)) as BoxQuicStreamWriter; + assert_eq!( + reader.stream_id().await.expect("reader id"), + VarInt::from_u32(0) + ); + assert_eq!( + writer.stream_id().await.expect("writer id"), + VarInt::from_u32(0) + ); + let write = writer.send(Bytes::from_static(b"wrapped-bidi")); + let read = reader.next(); + let (write, received) = tokio::join!(write, read); + write.expect("wrapped bidi writer should remain usable"); + assert!(received.is_none()); + + let (_quic, connection) = connection_with_enabled_webtransport_pair(); + let session = + webtransport::WebTransportSession::try_from(webtransport_connect_on(connection)) + .expect("connect should create webtransport session"); + + let writer = WebTransportRpcSession::open_uni(&session) + .await + .expect("open_uni should wrap successful stream"); + let lifecycle = Arc::new(StreamTestLifecycle::new()); + let mut writer = Box::pin(writer.into_quic(lifecycle)) as BoxQuicStreamWriter; + assert_eq!( + writer.stream_id().await.expect("writer id"), + VarInt::from_u32(0) + ); + writer + .send(Bytes::from_static(b"wrapped-uni")) + .await + .expect("wrapped uni writer should remain usable"); + } + + #[tokio::test] + async fn concrete_webtransport_rpc_session_preserves_connection_closed_accepts() { + let (quic, connection) = connection_with_webtransport_pair(); + let session = + webtransport::WebTransportSession::try_from(webtransport_connect_on(connection)) + .expect("connect should create webtransport session"); + quic.set_terminal_error(connection_error("accept_bi connection closed")); + + let accept_bi = timeout( + Duration::from_secs(1), + WebTransportRpcSession::accept_bi(&session), + ) + .await + .expect("accept_bi should resolve on closed connection") + .expect_err("accept_bi should preserve connection closure"); + let AcceptStreamError::Connection { source } = accept_bi else { + panic!("expected accept_bi connection error"); + }; + assert_reason(&source, "accept_bi connection closed"); + + let (quic, connection) = connection_with_webtransport_pair(); + let session = + webtransport::WebTransportSession::try_from(webtransport_connect_on(connection)) + .expect("connect should create webtransport session"); + quic.set_terminal_error(connection_error("accept_uni connection closed")); + + let accept_uni = timeout( + Duration::from_secs(1), + WebTransportRpcSession::accept_uni(&session), + ) + .await + .expect("accept_uni should resolve on closed connection") + .expect_err("accept_uni should preserve connection closure"); + let AcceptStreamError::Connection { source } = accept_uni else { + panic!("expected accept_uni connection error"); + }; + assert_reason(&source, "accept_uni connection closed"); + } +} diff --git a/src/server.rs b/src/server.rs deleted file mode 100644 index 525658e..0000000 --- a/src/server.rs +++ /dev/null @@ -1,232 +0,0 @@ -use std::{error::Error, sync::Arc}; - -use futures::future; -use snafu::Report; -use tokio::task::JoinSet; -use tracing::Instrument; - -pub use crate::message::{ - stream::{MessageStreamError, ReadStream, WriteStream}, - unify::ReadToStringError, -}; -use crate::{ - connection::ConnectionBuilder, - error::Code, - pool::Pool, - quic::{self, GetStreamIdExt}, - stream_id::StreamId, -}; - -mod message; -pub use message::{Request, Response, UnresolvedRequest}; -mod route; -pub use route::{MethodRouter, Router}; -mod servers_router; -pub use servers_router::{ServersRouter, ServersRouterDispatchError}; -mod service; -pub use service::{BoxService, BoxServiceFuture, IntoBoxService, Service, box_service}; - -#[derive(Debug)] -pub struct Servers { - pool: Pool, - listener: L, - builder: Arc>, - service: S, -} - -#[bon::bon] -impl Servers -where - L: quic::Listen, -{ - #[builder( - builder_type(vis = "pub"), - start_fn(name = from_quic_listener, vis = "pub") - )] - fn new( - #[builder(default = Pool::empty())] pool: Pool, - listener: L, - service: S, - #[builder(default = Arc::new(ConnectionBuilder::new(Arc::default())))] builder: Arc< - ConnectionBuilder, - >, - ) -> Self { - Self { - pool, - listener, - builder, - service, - } - } - - pub fn quic_listener(&self) -> &L { - &self.listener - } - - pub fn quic_listener_mut(&mut self) -> &mut L { - &mut self.listener - } - - pub fn service(&self) -> &S { - &self.service - } - - pub fn service_mut(&mut self) -> &mut S { - &mut self.service - } - - /// Decompose this `Servers` into its constituent parts. - /// - /// Useful for recovering the listener after cancellation so it can be - /// reused without tearing down the underlying QUIC bindings. - #[allow(clippy::type_complexity)] - pub fn into_parts( - self, - ) -> ( - Pool, - L, - S, - Arc>, - ) { - (self.pool, self.listener, self.service, self.builder) - } -} - -impl Servers -where - L: quic::Listen, - S: tower_service::Service + Clone + Send + Sync + 'static, - S::Future: Send, - S::Error: Into>, -{ - fn handle_incoming_connection( - &self, - connection: Arc, - ) -> impl futures::Future + Send + 'static { - let pool = self.pool.clone(); - let builder = self.builder.clone(); - let service = self.service.clone(); - let span = tracing::info_span!("handle_connection", server_name = tracing::field::Empty); - async move { - tracing::debug!("accepted new QUIC connection"); - let Ok(connection) = builder.build(connection).await else { - // failed to initialize H3 connection - return; - }; - - tracing::debug!("accepted new H3 connection"); - let Ok(local_agent) = connection.local_agent().await else { - // connection already closed - return; - }; - let Some(local_agent) = local_agent else { - tracing::debug!("close incoming connection due to missing SNI"); - // no SNI - connection.close( - Code::H3_INTERNAL_ERROR, - "missing server name (SNI) on incoming connection", - ); - return; - }; - - tracing::Span::current().record("server_name", local_agent.name()); - let connection = Arc::new(connection); - _ = pool.try_insert(connection.clone(), builder.clone()).await; - let mut connection_tasks = JoinSet::new(); - - loop { - let (mut read_stream, write_stream) = match connection.accept_message_stream().await - { - Ok(pair) => { - tracing::trace!("accepted incoming request stream"); - pair - } - Err(error) => { - tracing::debug!( - error = %Report::from_error(error), - "failed to accept incoming request" - ); - break; - } - }; - - let stream_id = match read_stream.stream_id().await { - Ok(stream_id) => stream_id, - Err(error) => { - tracing::debug!( - error = %Report::from_error(error), - "Failed to acquire incoming request stream ID" - ); - continue; - } - }; - - let mut service = service.clone(); - let connection = connection.clone(); - let unresolved_request = UnresolvedRequest { - stream_id: StreamId(stream_id), - read_stream, - write_stream, - // Erase the concrete `C` so the request-handling pipeline - // stays monomorphic in the service type only. - connection: Arc::new(connection.erase()), - }; - - let handle_request = async move { - if let Err(error) = future::poll_fn(|cx| service.poll_ready(cx)).await { - let error = error.into(); - tracing::debug!( - stream_id = %stream_id, - error = %Report::from_error(error.as_ref()), - "Service not ready to handle incoming request" - ); - return; - } - - if let Err(error) = service.call(unresolved_request).await { - let error = error.into(); - if error - .as_ref() - .downcast_ref::() - .is_some() - { - tracing::debug!( - stream_id = %stream_id, - error = %Report::from_error(error.as_ref()), - "close incoming connection due to missing service" - ); - connection.close(Code::H3_NO_ERROR, "no error"); - return; - } - tracing::debug!( - stream_id = %stream_id, - error = %Report::from_error(error.as_ref()), - "Failed to handle incoming request" - ); - } - }; - connection_tasks.spawn(handle_request.in_current_span()); - // Reap completed tasks to prevent unbounded growth - while connection_tasks.try_join_next().is_some() {} - } - } - .instrument(span) - } - - pub async fn run(&mut self) -> L::Error { - let mut tasks = JoinSet::default(); - - loop { - // Reap completed tasks to prevent unbounded growth - while tasks.try_join_next().is_some() {} - match self.listener.accept().await { - Ok(connection) => tasks.spawn(self.handle_incoming_connection(connection)), - Err(error) => break error, - }; - } - } - - pub async fn shutdown(&self) -> Result<(), L::Error> { - self.listener.shutdown().await - } -} diff --git a/src/server/message.rs b/src/server/message.rs deleted file mode 100644 index 9e2c87c..0000000 --- a/src/server/message.rs +++ /dev/null @@ -1,490 +0,0 @@ -use std::sync::Arc; - -use bytes::{Buf, Bytes}; -use futures::{Sink, Stream, StreamExt, future::BoxFuture}; -use http::{ - HeaderMap, HeaderValue, Method, Uri, - header::{AsHeaderName, IntoHeaderName}, - uri::{Authority, PathAndQuery, Scheme}, -}; -use snafu::Report; -use tracing::Instrument; - -use crate::{ - connection::ConnectionState, - error::Code, - message::{ - stream::{MessageStreamError, ReadStream, WriteStream}, - unify::{MalformedMessageError, Message, MessageStage, ReadToStringError}, - }, - protocol::Protocols, - qpack::field::{Protocol, PseudoHeaders}, - quic::{self, agent}, - stream_id::StreamId, -}; - -/// A request that has just been accepted on a QUIC stream but whose HTTP/3 -/// header frame has not yet been read. -/// -/// All per-connection context (local/remote agents, protocol registry) is -/// reachable through [`Self::connection`] — [`resolve`](Self::resolve) -/// awaits those accessors once and hands back an eagerly-populated -/// [`Request`]/[`Response`] pair. -pub struct UnresolvedRequest { - /// QUIC stream identifier for this request. - pub stream_id: StreamId, - /// Incoming request stream — read by [`resolve`](Self::resolve) to pull - /// the HTTP/3 header frame. - pub read_stream: ReadStream, - /// Outgoing response stream — handed to the [`Response`] on resolve. - pub write_stream: WriteStream, - /// Owning h3 connection. Type-erased so the request-handling pipeline - /// stays independent of the concrete QUIC implementation; all accessors - /// that used to live on [`UnresolvedRequest`] (agents, protocol - /// registry) are reachable through the underlying [`ConnectionState`]. - pub connection: Arc>, -} - -impl UnresolvedRequest { - pub async fn resolve(self) -> Result<(Request, Response), MessageStreamError> { - let UnresolvedRequest { - stream_id, - read_stream, - write_stream, - connection, - } = self; - // Agents are backed by a watch channel — fetching them per-request - // is effectively a clone once the handshake has completed. - let local_agent = connection - .local_agent() - .await? - .expect("server connection must have a local agent (SNI)"); - let remote_agent = connection.remote_agent().await?; - let protocols = connection.protocols().clone(); - - let mut request = Request { - message: Message::unresolved_request(), - stream: read_stream, - agent: remote_agent, - stream_id, - protocols: protocols.clone(), - }; - request - .stream - .read_message_header(&mut request.message) - .await?; - let response = Response { - message: Message::unresolved_response(), - stream: write_stream, - agent: local_agent, - stream_id, - protocols, - }; - Ok((request, response)) - } -} - -impl IntoFuture for UnresolvedRequest { - type Output = Result<(Request, Response), MessageStreamError>; - - type IntoFuture = BoxFuture<'static, Self::Output>; - - fn into_future(self) -> Self::IntoFuture { - Box::pin(self.resolve()) - } -} - -pub struct Request { - message: Message, - stream: ReadStream, - agent: Option>, - stream_id: StreamId, - protocols: Arc, -} - -impl Request { - pub fn method(&self) -> Method { - self.message.header().method() - } - - pub fn scheme(&self) -> Option { - self.message.header().scheme() - } - - pub fn authority(&self) -> Option { - self.message.header().authority() - } - - pub fn path(&self) -> Option { - self.message.header().path() - } - - pub fn protocol(&self) -> Option { - self.message.header().protocol() - } - - pub fn uri(&self) -> Uri { - self.message.header().uri() - } - - pub fn headers(&self) -> &http::HeaderMap { - &self.message.header().header_map - } - - pub fn header(&self, name: impl AsHeaderName) -> Option<&HeaderValue> { - self.headers().get(name) - } - - pub async fn read(&mut self) -> Option> { - self.stream.read_message(&mut self.message).await - } - - pub async fn read_all(&mut self) -> Result { - self.stream.read_message_full_body(&mut self.message).await - } - - pub async fn read_to_bytes(&mut self) -> Result { - self.stream - .read_message_body_to_bytes(&mut self.message) - .await - } - - pub async fn read_to_string(&mut self) -> Result { - self.stream - .read_message_body_to_string(&mut self.message) - .await - } - - pub async fn as_stream(&mut self) -> impl Stream> { - futures::stream::unfold(self, async |this| { - this.read().await.map(|item| (item, this)) - }) - .fuse() - } - - pub async fn into_stream(self) -> impl Stream> { - futures::stream::unfold(self, async |mut this| { - this.read().await.map(|item| (item, this)) - }) - .fuse() - } - - pub async fn trailers(&mut self) -> Result<&HeaderMap, MessageStreamError> { - self.stream.read_message_trailer(&mut self.message).await - } - - pub async fn stop(&mut self, code: Code) -> Result<(), MessageStreamError> { - self.stream.stop(code).await - } - - /// Low level access to the underlying read stream - pub fn read_stream(&mut self) -> &mut ReadStream { - &mut self.stream - } - - pub fn agent(&self) -> Option<&Arc> { - self.agent.as_ref() - } - - /// Returns the QUIC stream identifier for this request. - /// - /// The stream ID uniquely identifies the request stream within its QUIC connection. - /// Combined with [`protocols()`](Self::protocols), it serves as the per-stream key - /// for deriving protocol-specific session handles from connection-scoped protocol - /// state: - /// - /// ```ignore - /// let proto = request.protocols().get::().unwrap(); - /// let session = proto.create_session(request.stream_id()); - /// ``` - pub fn stream_id(&self) -> StreamId { - self.stream_id - } - - /// Returns the connection-scoped protocol registry. - /// - /// The returned `Arc` is shared across all request handlers on the same - /// QUIC connection. Use [`Protocols::get`] to look up a concrete protocol runtime - /// by type, then derive per-request handles using [`stream_id()`](Self::stream_id): - /// - /// ```ignore - /// let dhttp = request.protocols().get::().unwrap(); - /// let qpack = request.protocols().get::(); - /// ``` - pub fn protocols(&self) -> &Arc { - &self.protocols - } -} - -pub struct Response { - message: Message, - stream: WriteStream, - agent: Arc, - stream_id: StreamId, - protocols: Arc, -} - -impl Response { - fn check_message_operation( - &mut self, - operation: &str, - operate: impl FnOnce(&mut Self) -> Result<(), MalformedMessageError>, - ) { - if self.message.is_malformed() { - tracing::warn!( - target: "h3x::server", operation, - "Response is malformed, operation will not affect the response stream", - ); - } - if let Err(error) = operate(self) { - tracing::warn!( - target: "h3x::server", operation, error = %Report::from_error(error), - "Operation malformed the response message, response stream will be cancelled with H3_REQUEST_CANCELLED", - ); - self.message.set_malformed(); - } - } - - pub fn headers(&self) -> &http::HeaderMap { - &self.message.header().header_map - } - - pub fn headers_mut(&mut self) -> &mut http::HeaderMap { - self.check_message_operation("modify_headers", |this| { - if this.message.stage() > MessageStage::Header { - return Err(MalformedMessageError::HeaderAlreadySent); - } - Ok(()) - }); - &mut self.message.header_mut().header_map - } - - pub fn set_header(&mut self, name: impl IntoHeaderName, value: HeaderValue) -> &mut Self { - self.headers_mut().insert(name, value); - self - } - - pub fn status(&self) -> Option { - match self.message.header().pseudo_headers { - Some(PseudoHeaders::Response { ref status }) => *status, - _ => unreachable!(), - } - } - - pub fn set_status(&mut self, status: http::StatusCode) -> &mut Self { - self.check_message_operation("set_status", |this| { - if this.message.stage() > MessageStage::Header { - return Err(MalformedMessageError::HeaderAlreadySent); - } - Ok(()) - }); - self.message.header_mut().set_status(status); - self - } - - pub fn set_body(&mut self, content: impl Buf) -> &mut Self { - self.check_message_operation("write_chunked_body", |this| { - if this.message.is_interim_response() { - return Err(MalformedMessageError::BodyOrTrailerOnInterimResponse); - } - if this.message.stage() > MessageStage::Body { - return Err(MalformedMessageError::BodyAlreadySending); - } - if this.message.stage() == MessageStage::Body { - return Err(MalformedMessageError::BodyReplacementDuringSend); - } - this.message.chunked_body()?; - Ok(()) - }); - self.message.set_body(content); - self - } - - pub async fn write( - &mut self, - content: impl Buf + Send, - ) -> Result<&mut Self, MessageStreamError> { - self.check_message_operation("write_streaming_body", |this| { - if this.message.is_interim_response() { - return Err(MalformedMessageError::BodyOrTrailerOnInterimResponse); - } - this.message.streaming_body()?; - Ok(()) - }); - self.stream - .send_message_streaming_body(&mut self.message, content) - .await?; - Ok(self) - } - - pub async fn flush(&mut self) -> Result<&mut Self, MessageStreamError> { - self.check_message_operation("flush_response", |this| { - if !this.message.header().is_empty() { - this.message.header().check_pseudo()?; - } - Ok(()) - }); - self.stream.flush_message(&mut self.message).await?; - Ok(self) - } - - pub fn as_sink(&mut self) -> impl Sink { - crate::message::stream::unfold::write::unfold( - self, - async |request: &mut Self, buf: B| { - request.write(buf).await?; - Ok(request) - }, - async |request: &mut Self| { - request.flush().await?; - Ok(request) - }, - async |request: &mut Self| { - request.close().await?; - Ok(request) - }, - ) - } - - pub fn into_sink(self) -> impl Sink { - crate::message::stream::unfold::write::unfold( - self, - async |request: Self, buf: B| { - let mut request = request; - request.write(buf).await?; - Ok(request) - }, - async |request: Self| { - let mut request = request; - request.flush().await?; - Ok(request) - }, - async |request: Self| { - let mut request = request; - request.close().await?; - Ok(request) - }, - ) - } - - pub fn trailers(&self) -> &HeaderMap { - self.message.trailers() - } - - pub fn trailers_mut(&mut self) -> &mut HeaderMap { - self.check_message_operation("modify_trailers", |this| { - if this.message.is_interim_response() { - return Err(MalformedMessageError::BodyOrTrailerOnInterimResponse); - } - if this.message.stage() > MessageStage::Trailer { - return Err(MalformedMessageError::TrailerAlreadySent); - } - Ok(()) - }); - self.message.trailers_mut() - } - - pub fn set_trailer(&mut self, name: impl IntoHeaderName, value: HeaderValue) -> &mut Self { - self.trailers_mut().insert(name, value); - self - } - - pub fn set_trailers(&mut self, map: HeaderMap) -> &mut Self { - *self.trailers_mut() = map; - self - } - - pub async fn close(&mut self) -> Result<(), MessageStreamError> { - self.check_message_operation("close_response", |this| { - this.message.header().check_pseudo()?; - if this.message.is_interim_response() { - return Err(MalformedMessageError::FinalResponseRequired); - } - Ok(()) - }); - self.stream.close_message(&mut self.message).await - } - - pub async fn cancel(&mut self, code: Code) -> Result<(), MessageStreamError> { - self.stream.cancel(code).await - } - - /// Low level access to the underlying write stream - pub fn write_stream(&mut self) -> &mut WriteStream { - &mut self.stream - } - - pub fn agent(&self) -> &Arc { - &self.agent - } - - /// Returns the QUIC stream identifier for this response. - /// - /// Same stream ID as the corresponding [`Request::stream_id`]. Useful when the - /// response handler needs to interact with connection-scoped protocols: - /// - /// ```ignore - /// let proto = response.protocols().get::().unwrap(); - /// let session = proto.create_session(response.stream_id()); - /// ``` - pub fn stream_id(&self) -> StreamId { - self.stream_id - } - - /// Returns the connection-scoped protocol registry. - /// - /// Same `Arc` as [`Request::protocols`]. See [`Protocols::get`] for - /// typed protocol lookup. - pub fn protocols(&self) -> &Arc { - &self.protocols - } - - /// Mark this response as taken over by a service adapter (e.g. [`TowerService`]). - /// - /// After calling this, the response's async drop becomes a no-op: the - /// service adapter is responsible for closing the stream. - #[cfg(feature = "hyper")] - pub(crate) fn mark_taken_over(&mut self) { - self.message.set_dropped(); - } - - /// Async drop the response properly - pub(crate) fn drop(&mut self) -> Option + Send + use<>> { - if self.message.is_complete() || self.message.is_dropped() { - return None; - } - // It's ok to take: Response will not be used after drop - let mut stream = self.stream.take(); - let mut message = self.message.take(); - - if !message.is_malformed() { - let check = || { - message.header().check_pseudo()?; - if message.is_interim_response() { - return Err(MalformedMessageError::FinalResponseRequired); - } - Ok(()) - }; - if let Err(error) = check() { - message.set_malformed(); - tracing::warn!( - target: "h3x::server", error = %Report::from_error(error), - "Response stream cannot be closed properly as its malformed", - ); - } - } - - Some(async move { - _ = stream.close_message(&mut message).await; - }) - } -} - -impl Drop for Response { - fn drop(&mut self) { - if let Some(future) = self.drop() { - // Best-effort: send the end-of-stream marker before the response is dropped. - tokio::spawn(future.in_current_span()); - } - } -} diff --git a/src/server/route.rs b/src/server/route.rs deleted file mode 100644 index 9fa92c6..0000000 --- a/src/server/route.rs +++ /dev/null @@ -1,473 +0,0 @@ -use std::{ - collections::HashMap, - sync::{Arc, RwLock}, - task::{Context, Poll}, -}; - -use futures::future::BoxFuture; -use http::{Method, StatusCode}; - -use crate::server::{ - BoxService, BoxServiceFuture, IntoBoxService, MessageStreamError, Request, Response, Service, - UnresolvedRequest, box_service, -}; - -#[tracing::instrument(skip_all)] -pub async fn default_fallback(_request: &mut Request, response: &mut Response) { - tracing::debug!("call default fallback service (404 Not Found)"); - _ = response.set_status(StatusCode::NOT_FOUND) -} - -#[derive(Debug, Clone)] -struct Fallback(Arc>); - -impl Fallback { - pub fn new(service: BoxService) -> Self { - Self(Arc::new(RwLock::new(service))) - } - - pub fn set(&mut self, service: BoxService) { - *self.0.write().expect("lock is not poisoned") = service; - } -} - -impl Service for Fallback { - type Future<'s> = BoxServiceFuture<'s>; - - fn serve<'s>(&self, request: &'s mut Request, response: &'s mut Response) -> Self::Future<'s> { - tracing::debug!("call fallback service"); - self.0 - .read() - .expect("lock is not poisoned") - .serve(request, response) - } -} - -#[derive(Debug, Clone)] -struct RouterInner { - router: matchit::Router, - fallback: Fallback, -} - -impl Default for RouterInner { - fn default() -> Self { - Self { - router: Default::default(), - fallback: Fallback::new(box_service(default_fallback)), - } - } -} - -impl RouterInner { - fn route(&mut self, path: &str, service: impl IntoBoxService) { - self.router - .insert(path, service.into_box_service()) - .expect("failed to register route"); - } - - pub fn on(&mut self, method: Method, path: &str, service: impl IntoBoxService) { - match self.router.at_mut(path) { - Ok(exist_service) => { - if let Some(router) = exist_service - .value - .downcast_mut::>() - { - router.set(method, service.into_box_service()); - } else { - let fallback = exist_service.value.clone(); - let mut router = MethodRouter::new(fallback); - router.set(method, service.into_box_service()); - *exist_service.value = router.into_box_service(); - } - } - Err(..) => { - let mut router = MethodRouter::new(self.fallback.clone().into_box_service()); - router.set(method, service.into_box_service()); - self.route(path, router) - } - } - } -} - -impl Service for RouterInner { - type Future<'s> = BoxServiceFuture<'s>; - - fn serve<'s>( - &self, - request: &'s mut Request, - response: &'s mut Response, - ) -> BoxServiceFuture<'s> { - let Some(path_and_query) = request.path() else { - tracing::debug!("missing path in request URI, call fallback service"); - return self.fallback.serve(request, response); - }; - let path = path_and_query.path(); - let Ok(endpoint) = self.router.at(path) else { - tracing::debug!(path, "path route: not found, call fallback service"); - return self.fallback.serve(request, response); - }; - - tracing::debug!(path, "path route found, call matched service"); - endpoint.value.serve(request, response) - } -} - -#[derive(Debug, Default, Clone)] -pub struct Router { - inner: Arc, -} - -impl Router { - pub fn new() -> Self { - Self::default() - } - - fn inner_ref(&self) -> &RouterInner { - &self.inner - } - - fn inner_mut(&mut self) -> &mut RouterInner { - Arc::make_mut(&mut self.inner) - } - - pub fn route(mut self, path: &str, service: impl IntoBoxService) -> Self { - self.inner_mut().route(path, service.into_box_service()); - self - } - - pub fn on(mut self, method: Method, path: &str, service: impl IntoBoxService) -> Self { - self.inner_mut() - .on(method, path, service.into_box_service()); - self - } - - pub fn fallback(mut self, service: impl IntoBoxService) -> Self { - self.inner_mut().fallback.set(service.into_box_service()); - self - } - - pub fn options(self, path: &str, service: impl IntoBoxService) -> Self { - self.on(Method::OPTIONS, path, service) - } - pub fn get(self, path: &str, service: impl IntoBoxService) -> Self { - self.on(Method::GET, path, service) - } - pub fn post(self, path: &str, service: impl IntoBoxService) -> Self { - self.on(Method::POST, path, service) - } - pub fn put(self, path: &str, service: impl IntoBoxService) -> Self { - self.on(Method::PUT, path, service) - } - pub fn delete(self, path: &str, service: impl IntoBoxService) -> Self { - self.on(Method::DELETE, path, service) - } - pub fn head(self, path: &str, service: impl IntoBoxService) -> Self { - self.on(Method::HEAD, path, service) - } - pub fn trace(self, path: &str, service: impl IntoBoxService) -> Self { - self.on(Method::TRACE, path, service) - } - pub fn connect(self, path: &str, service: impl IntoBoxService) -> Self { - self.on(Method::CONNECT, path, service) - } - pub fn patch(self, path: &str, service: impl IntoBoxService) -> Self { - self.on(Method::PATCH, path, service) - } - - pub fn serve<'s>( - &self, - request: &'s mut Request, - response: &'s mut Response, - ) -> BoxServiceFuture<'s> { - self.inner_ref().serve(request, response) - } - - #[tracing::instrument(skip(self, req), fields(method = tracing::field::Empty, uri = tracing::field::Empty))] - pub async fn handle(&self, req: UnresolvedRequest) -> Result<(), MessageStreamError> { - let (mut request, mut response) = req.resolve().await?; - - tracing::Span::current() - .record("method", request.method().as_str()) - .record("uri", request.uri().to_string()); - - self.serve(&mut request, &mut response).await; - - // Drop response in place to avoid spawning another tokio task - // FIXME: remove this when async drop is stabilized (https://github.com/rust-lang/rust/issues/126482) - if let Some(drop_future) = response.drop() { - drop_future.await; - } - - Ok(()) - } -} - -impl Service for Router { - type Future<'s> = BoxServiceFuture<'s>; - - fn serve<'s>(&self, request: &'s mut Request, response: &'s mut Response) -> Self::Future<'s> { - Router::serve(self, request, response) - } -} - -impl tower_service::Service for Router { - type Response = (); - - type Error = MessageStreamError; - - type Future = BoxFuture<'static, Result<(), MessageStreamError>>; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - _ = cx; - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: UnresolvedRequest) -> Self::Future { - let router = self.clone(); - Box::pin(async move { router.handle(req).await }) - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct MethodRouter { - // most used methods are stored separately for faster access - options: Option, - get: Option, - post: Option, - put: Option, - delete: Option, - head: Option, - trace: Option, - connect: Option, - patch: Option, - // other - extensions: HashMap, - // fallback service when no method match - fallback: S, -} - -impl MethodRouter { - pub fn new(fallback: S) -> Self { - Self { - options: None, - get: None, - post: None, - put: None, - delete: None, - head: None, - trace: None, - connect: None, - patch: None, - extensions: HashMap::new(), - fallback, - } - } - - pub fn service(&self, method: Method) -> Option<&S> { - match method { - Method::OPTIONS => self.options.as_ref(), - Method::GET => self.get.as_ref(), - Method::POST => self.post.as_ref(), - Method::PUT => self.put.as_ref(), - Method::DELETE => self.delete.as_ref(), - Method::HEAD => self.head.as_ref(), - Method::TRACE => self.trace.as_ref(), - Method::CONNECT => self.connect.as_ref(), - Method::PATCH => self.patch.as_ref(), - _ => self.extensions.get(&method), - } - } - - pub fn service_mut(&mut self, method: Method) -> Option<&mut S> { - match method { - Method::OPTIONS => self.options.as_mut(), - Method::GET => self.get.as_mut(), - Method::POST => self.post.as_mut(), - Method::PUT => self.put.as_mut(), - Method::DELETE => self.delete.as_mut(), - Method::HEAD => self.head.as_mut(), - Method::TRACE => self.trace.as_mut(), - Method::CONNECT => self.connect.as_mut(), - Method::PATCH => self.patch.as_mut(), - _ => self.extensions.get_mut(&method), - } - } - - pub fn set(&mut self, method: Method, service: S) { - match method { - Method::OPTIONS => self.options = Some(service), - Method::GET => self.get = Some(service), - Method::POST => self.post = Some(service), - Method::PUT => self.put = Some(service), - Method::DELETE => self.delete = Some(service), - Method::HEAD => self.head = Some(service), - Method::TRACE => self.trace = Some(service), - Method::CONNECT => self.connect = Some(service), - Method::PATCH => self.patch = Some(service), - _ => _ = self.extensions.insert(method, service), - } - } - - pub fn set_fallback(&mut self, service: S) { - self.fallback = service; - } -} - -impl Service for MethodRouter -where - S: Clone + for<'s> Service: Send> + Send + 'static, -{ - type Future<'s> = BoxServiceFuture<'s>; - - fn serve<'s>( - &self, - request: &'s mut super::Request, - response: &'s mut super::Response, - ) -> Self::Future<'s> { - let method = request.method(); - let service = match method { - Method::OPTIONS => self.options.as_ref().unwrap_or(&self.fallback), - Method::GET => self.get.as_ref().unwrap_or(&self.fallback), - Method::POST => self.post.as_ref().unwrap_or(&self.fallback), - Method::PUT => self.put.as_ref().unwrap_or(&self.fallback), - Method::DELETE => self.delete.as_ref().unwrap_or(&self.fallback), - Method::HEAD => self.head.as_ref().unwrap_or(&self.fallback), - Method::TRACE => self.trace.as_ref().unwrap_or(&self.fallback), - Method::CONNECT => self.connect.as_ref().unwrap_or(&self.fallback), - Method::PATCH => self.patch.as_ref().unwrap_or(&self.fallback), - _ => self.extensions.get(&method).unwrap_or(&self.fallback), - } - .clone(); - Box::pin(async move { service.serve(request, response).await }) - } -} - -#[cfg(test)] -mod tests { - use http::Method; - - use super::MethodRouter; - - fn make_router() -> MethodRouter<&'static str> { - let mut router = MethodRouter::new("fallback"); - router.set(Method::GET, "get_handler"); - router.set(Method::POST, "post_handler"); - router.set(Method::PUT, "put_handler"); - router.set(Method::DELETE, "delete_handler"); - router - } - - #[test] - fn method_router_service_lookup() { - let router = make_router(); - assert_eq!(router.service(Method::GET), Some(&"get_handler")); - assert_eq!(router.service(Method::POST), Some(&"post_handler")); - assert_eq!(router.service(Method::PUT), Some(&"put_handler")); - assert_eq!(router.service(Method::DELETE), Some(&"delete_handler")); - } - - #[test] - fn method_router_unset_returns_none() { - let router = make_router(); - assert_eq!(router.service(Method::PATCH), None); - assert_eq!(router.service(Method::HEAD), None); - assert_eq!(router.service(Method::OPTIONS), None); - assert_eq!(router.service(Method::TRACE), None); - assert_eq!(router.service(Method::CONNECT), None); - } - - #[test] - fn method_router_fallback() { - let router = make_router(); - assert_eq!(router.fallback, "fallback"); - } - - #[test] - fn method_router_set_all_standard_methods() { - let mut router = MethodRouter::new("fb"); - router.set(Method::OPTIONS, "opt"); - router.set(Method::GET, "get"); - router.set(Method::POST, "post"); - router.set(Method::PUT, "put"); - router.set(Method::DELETE, "del"); - router.set(Method::HEAD, "head"); - router.set(Method::TRACE, "trace"); - router.set(Method::CONNECT, "connect"); - router.set(Method::PATCH, "patch"); - - assert_eq!(router.service(Method::OPTIONS), Some(&"opt")); - assert_eq!(router.service(Method::GET), Some(&"get")); - assert_eq!(router.service(Method::POST), Some(&"post")); - assert_eq!(router.service(Method::PUT), Some(&"put")); - assert_eq!(router.service(Method::DELETE), Some(&"del")); - assert_eq!(router.service(Method::HEAD), Some(&"head")); - assert_eq!(router.service(Method::TRACE), Some(&"trace")); - assert_eq!(router.service(Method::CONNECT), Some(&"connect")); - assert_eq!(router.service(Method::PATCH), Some(&"patch")); - } - - #[test] - fn method_router_service_mut() { - let mut router = make_router(); - if let Some(handler) = router.service_mut(Method::GET) { - *handler = "updated_get"; - } - assert_eq!(router.service(Method::GET), Some(&"updated_get")); - } - - #[test] - fn method_router_set_fallback() { - let mut router = make_router(); - router.set_fallback("new_fallback"); - assert_eq!(router.fallback, "new_fallback"); - } - - #[test] - fn method_router_overwrite() { - let mut router = make_router(); - router.set(Method::GET, "overwritten"); - assert_eq!(router.service(Method::GET), Some(&"overwritten")); - } - - #[test] - fn router_builder_chain() { - use super::Router; - - async fn dummy(_req: &mut super::Request, _resp: &mut super::Response) {} - - // Just test that the builder pattern compiles and doesn't panic - let _router = Router::new() - .route("/exact", dummy) - .get("/api/users", dummy) - .post("/api/users", dummy) - .fallback(dummy); - } - - #[test] - fn matchit_path_matching() { - // Directly test the underlying matchit router to verify path matching logic - let mut router = matchit::Router::new(); - router.insert("/", "root").unwrap(); - router.insert("/users", "users").unwrap(); - router.insert("/users/{id}", "user_by_id").unwrap(); - router.insert("/files/{*path}", "files_catch_all").unwrap(); - - // Exact matches - assert_eq!(*router.at("/").unwrap().value, "root"); - assert_eq!(*router.at("/users").unwrap().value, "users"); - - // Parameterized match - let m = router.at("/users/42").unwrap(); - assert_eq!(*m.value, "user_by_id"); - assert_eq!(m.params.get("id"), Some("42")); - - // Catch-all match - let m = router.at("/files/docs/readme.md").unwrap(); - assert_eq!(*m.value, "files_catch_all"); - assert_eq!(m.params.get("path"), Some("docs/readme.md")); - - // No match - assert!(router.at("/nonexistent").is_err()); - assert!(router.at("").is_err()); - } -} diff --git a/src/server/servers_router.rs b/src/server/servers_router.rs deleted file mode 100644 index 9d8169c..0000000 --- a/src/server/servers_router.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::{ - collections::HashMap, - error::Error, - task::{Context, Poll}, -}; - -use futures::future::{self, BoxFuture}; -use snafu::Snafu; - -use crate::server::UnresolvedRequest; - -#[derive(Debug, Clone, Snafu)] -#[snafu(visibility(pub))] -pub enum ServersRouterDispatchError { - #[snafu(display("service not found for server name: {server_name}"))] - MissingService { server_name: String }, -} - -#[derive(Debug, Clone)] -pub struct ServersRouter { - router: std::sync::Arc>, -} - -impl Default for ServersRouter { - fn default() -> Self { - Self { - router: Default::default(), - } - } -} - -impl ServersRouter { - pub fn new() -> Self { - Self::default() - } - - pub fn contains(&self, domain: &str) -> bool { - self.router.contains_key(domain) - } -} - -impl ServersRouter { - fn router_mut(&mut self) -> &mut HashMap { - std::sync::Arc::make_mut(&mut self.router) - } - - pub fn insert(&mut self, domain: impl Into, service: S) -> Option { - self.router_mut().insert(domain.into(), service) - } - - pub fn serve(&mut self, domain: impl Into, service: S) -> &mut Self { - _ = self.insert(domain, service); - self - } -} - -impl tower_service::Service for ServersRouter -where - S: tower_service::Service + Clone + Send + Sync + 'static, - S::Future: Send, - S::Error: Into>, -{ - type Response = (); - - type Error = Box; - - type Future = BoxFuture<'static, Result<(), Self::Error>>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: UnresolvedRequest) -> Self::Future { - let router = self.router.clone(); - Box::pin(async move { - // Resolve the server name lazily from the connection instead of - // carrying it on `UnresolvedRequest`. The local agent watch is - // already populated by the time the first request arrives, so - // this is effectively a clone. - let local_agent = match req.connection.local_agent().await { - Ok(Some(agent)) => agent, - Ok(None) => { - return Err(ServersRouterDispatchError::MissingService { - server_name: String::new(), - } - .into()); - } - Err(error) => return Err(Box::new(error) as _), - }; - let server_name = local_agent.name().to_string(); - let Some(mut service) = router.get(server_name.as_str()).cloned() else { - return Err(ServersRouterDispatchError::MissingService { server_name }.into()); - }; - - future::poll_fn(|cx| service.poll_ready(cx)) - .await - .map_err(Into::into)?; - service.call(req).await.map_err(Into::into) - }) - } -} diff --git a/src/server/service.rs b/src/server/service.rs deleted file mode 100644 index 77d369d..0000000 --- a/src/server/service.rs +++ /dev/null @@ -1,115 +0,0 @@ -use std::any::Any; - -use futures::future::BoxFuture; - -use crate::server::{Request, Response}; - -pub trait Service { - type Future<'s>: Future; - - fn serve<'s>(&self, request: &'s mut Request, response: &'s mut Response) -> Self::Future<'s>; -} - -/// A helper trait to allow using async closures as services. -pub trait ServiceFn<'s> { - type Future: Future + 's; - - fn call(&self, req: &'s mut Request, res: &'s mut Response) -> Self::Future; -} - -impl<'s, F, Fut> ServiceFn<'s> for F -where - F: Fn(&'s mut Request, &'s mut Response) -> Fut, - Fut: Future + Send + 's, -{ - type Future = Fut; - - fn call(&self, req: &'s mut Request, res: &'s mut Response) -> Self::Future { - (self)(req, res) - } -} - -impl Service for S -where - S: for<'s> ServiceFn<'s>, -{ - type Future<'s> = >::Future; - - fn serve<'s>(&self, request: &'s mut Request, response: &'s mut Response) -> Self::Future<'s> { - self.call(request, response) - } -} - -trait CloneableService: Any { - fn serve<'s>(&self, request: &'s mut Request, response: &'s mut Response) -> BoxFuture<'s, ()>; - - fn clone_box(&self) -> Box; -} - -impl Service: Send> + Any + Clone + Send + Sync> CloneableService for H { - fn serve<'s>(&self, request: &'s mut Request, response: &'s mut Response) -> BoxFuture<'s, ()> { - Box::pin(self.serve(request, response)) - } - - fn clone_box(&self) -> Box { - Box::new(self.clone()) - } -} - -pub trait IntoBoxService: - for<'s> Service: Send> + Clone + Send + Sync + 'static -{ - fn into_box_service(self) -> BoxService { - BoxService(Box::new(self)) - } -} - -impl Service: Send> + Clone + Send + Sync + 'static> IntoBoxService for S {} - -pub type BoxServiceFuture<'s> = BoxFuture<'s, ()>; -type DynService = dyn CloneableService + Send + Sync; - -pub struct BoxService(Box); - -impl std::fmt::Debug for BoxService { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_tuple("BoxService").finish() - } -} - -impl BoxService { - pub fn downcast_ref(&self) -> Option<&T> { - (self.0.as_ref() as &dyn Any).downcast_ref::() - } - - pub fn downcast_mut(&mut self) -> Option<&mut T> { - (self.0.as_mut() as &mut dyn Any).downcast_mut::() - } - - pub fn downcast(self) -> Result, BoxService> { - match (self.0.as_ref() as &dyn Any).is::() { - true => Ok((self.0 as Box) - .downcast::() - .expect("type checked by is::()")), - false => Err(self), - } - } -} - -impl Service for BoxService { - type Future<'s> = BoxServiceFuture<'s>; - - fn serve<'s>(&self, request: &'s mut Request, response: &'s mut Response) -> Self::Future<'s> { - self.0.serve(request, response) - } -} - -impl Clone for BoxService { - fn clone(&self) -> Self { - Self(self.0.clone_box()) - } -} - -pub fn box_service(service: impl IntoBoxService) -> BoxService { - service.into_box_service() -} diff --git a/src/stream.rs b/src/stream.rs new file mode 100644 index 0000000..93f4759 --- /dev/null +++ b/src/stream.rs @@ -0,0 +1,289 @@ +use std::{ + future::Future, + ops::DerefMut, + pin::Pin, + task::{Context, Poll}, +}; + +use futures::{Sink, Stream}; + +use crate::varint::VarInt; + +pub mod unfold; + +/// Read-only observation of a stream id. +/// +/// Polling for the id has no committed stream side effect and no ordering +/// relationship with data, stop, or reset operations. An implementation may +/// return the id immediately or wait until the id is observable. Dropping a +/// pending [`StreamId`] future commits nothing. +pub trait GetStreamId { + fn poll_stream_id(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; +} + +impl GetStreamId for Pin

+where + P: DerefMut, + P::Target: GetStreamId, +{ + fn poll_stream_id(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + >::poll_stream_id(self.as_deref_mut(), cx) + } +} + +impl GetStreamId for &mut S +where + S: GetStreamId + Unpin + ?Sized, +{ + fn poll_stream_id(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + S::poll_stream_id(Pin::new(self.get_mut()), cx) + } +} + +pin_project_lite::pin_project! { + pub struct StreamId { + _error: std::marker::PhantomData E>, + #[pin] + stream: S, + } +} + +impl StreamId { + pub(crate) fn new(stream: S) -> Self { + Self { + stream, + _error: std::marker::PhantomData, + } + } +} + +impl Future for StreamId +where + S: GetStreamId + ?Sized, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().stream.poll_stream_id(cx) + } +} + +pub trait GetStreamIdExt: GetStreamId { + fn stream_id(&mut self) -> StreamId<&mut Self, E> { + StreamId::new(self) + } +} + +impl GetStreamIdExt for T where T: GetStreamId + ?Sized {} + +/// Receive-side stop control for a stream. +/// +/// The first poll of [`poll_stop`](StopStream::poll_stop) commits the stop +/// request and its code. Dropping the caller future after that first poll does +/// not cancel the request; later polls for the same outstanding stop operation +/// continue it rather than creating a new one. +/// +/// Stop asks the peer to stop sending. It does not reset the local send side, +/// and it must not discard bytes that were already received locally but have not +/// yet been delivered to the caller. +pub trait StopStream { + fn poll_stop(self: Pin<&mut Self>, cx: &mut Context<'_>, code: VarInt) -> Poll>; +} + +impl StopStream for Pin

+where + P: DerefMut, + P::Target: StopStream, +{ + fn poll_stop(self: Pin<&mut Self>, cx: &mut Context<'_>, code: VarInt) -> Poll> { + >::poll_stop(self.as_deref_mut(), cx, code) + } +} + +impl StopStream for &mut S +where + S: StopStream + Unpin + ?Sized, +{ + fn poll_stop(self: Pin<&mut Self>, cx: &mut Context<'_>, code: VarInt) -> Poll> { + S::poll_stop(Pin::new(self.get_mut()), cx, code) + } +} + +pin_project_lite::pin_project! { + pub struct Stop { + code: VarInt, + _error: std::marker::PhantomData E>, + #[pin] + stream: S, + } +} + +impl Stop { + pub(crate) fn new(stream: S, code: VarInt) -> Self { + Self { + code, + stream, + _error: std::marker::PhantomData, + } + } +} + +impl Future for Stop +where + S: StopStream + ?Sized, +{ + type Output = Result<(), E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let project = self.project(); + project.stream.poll_stop(cx, *project.code) + } +} + +pub trait StopStreamExt: StopStream { + fn stop(&mut self, code: VarInt) -> Stop<&mut Self, E> { + Stop::new(self, code) + } +} + +impl StopStreamExt for T where T: StopStream + ?Sized {} + +/// Send-side reset control for a stream. +/// +/// This operation is stream reset rather than cancellation of a Rust future. +/// The first poll of [`poll_reset`](ResetStream::poll_reset) commits the reset +/// code. Once committed, reset may interrupt in-flight send-side work such as +/// data send, flush, or shutdown. Reset does not stop local receive-side byte +/// delivery. +pub trait ResetStream { + fn poll_reset(self: Pin<&mut Self>, cx: &mut Context<'_>, code: VarInt) -> Poll>; +} + +impl ResetStream for Pin

+where + P: DerefMut, + P::Target: ResetStream, +{ + fn poll_reset(self: Pin<&mut Self>, cx: &mut Context<'_>, code: VarInt) -> Poll> { + >::poll_reset(self.as_deref_mut(), cx, code) + } +} + +impl ResetStream for &mut S +where + S: ResetStream + Unpin + ?Sized, +{ + fn poll_reset(self: Pin<&mut Self>, cx: &mut Context<'_>, code: VarInt) -> Poll> { + S::poll_reset(Pin::new(self.get_mut()), cx, code) + } +} + +pin_project_lite::pin_project! { + pub struct Reset { + code: VarInt, + _error: std::marker::PhantomData E>, + #[pin] + stream: S, + } +} + +impl Reset { + pub(crate) fn new(stream: S, code: VarInt) -> Self { + Self { + code, + stream, + _error: std::marker::PhantomData, + } + } +} + +impl Future for Reset +where + S: ResetStream + ?Sized, +{ + type Output = Result<(), E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let project = self.project(); + project.stream.poll_reset(cx, *project.code) + } +} + +pub trait ResetStreamExt: ResetStream { + fn reset(&mut self, code: VarInt) -> Reset<&mut Self, E> { + Reset::new(self, code) + } +} + +impl ResetStreamExt for T where T: ResetStream + ?Sized {} + +pub trait ReadStream: + Stream> + StopStream + GetStreamId +{ +} + +impl ReadStream for T where + T: Stream> + StopStream + GetStreamId + ?Sized +{ +} + +pub trait WriteStream: + Sink + ResetStream + GetStreamId +{ +} + +impl WriteStream for T where + T: Sink + ResetStream + GetStreamId + ?Sized +{ +} + +pub type BoxStreamReader = + Pin + Send>>; + +pub type LocalBoxStreamReader = + Pin>>; + +pub type BoxStreamWriter = + Pin + Send>>; + +pub type LocalBoxStreamWriter = + Pin>>; + +pub trait ManageStream { + type Data; + + type ReadError; + type WriteError; + type StopError; + type ResetError; + type StreamIdError; + + type OpenBiError; + type OpenUniError; + type AcceptBiError; + type AcceptUniError; + + type StreamReader: ReadStream; + + type StreamWriter: WriteStream; + + fn open_bi( + &self, + ) -> impl Future> + + Send + + '_; + + fn open_uni( + &self, + ) -> impl Future> + Send + '_; + + fn accept_bi( + &self, + ) -> impl Future> + + Send + + '_; + + fn accept_uni( + &self, + ) -> impl Future> + Send + '_; +} diff --git a/src/stream/unfold.rs b/src/stream/unfold.rs new file mode 100644 index 0000000..a579609 --- /dev/null +++ b/src/stream/unfold.rs @@ -0,0 +1,5 @@ +//! Generic stream unfold adapters. +//! +//! This namespace is reserved for generic stream-preserving unfold adapters. +//! Message- and QUIC-specific adapters remain under their owning modules until +//! they are genericized. diff --git a/src/stream_id.rs b/src/stream_id.rs index 344a384..55645d1 100644 --- a/src/stream_id.rs +++ b/src/stream_id.rs @@ -11,16 +11,16 @@ use crate::{ /// /// A lightweight newtype around [`VarInt`] representing the QUIC stream ID of the /// current request/response pair. Injected as a field in -/// [`Request`](crate::server::Request) and [`Response`](crate::server::Response) -/// (native path) or as a request extension (hyper path). +/// [`UnresolvedRequest`](crate::endpoint::UnresolvedRequest) on the raw +/// endpoint path or as a request extension on the hyper path. /// /// `StreamId` serves as the per-stream key when deriving protocol-specific session /// handles from connection-scoped protocol state stored in [`Protocols`](crate::protocol::Protocols): /// /// ```ignore -/// // Native handler: -/// let proto = request.protocols().get::().unwrap(); -/// let session = proto.create_session(request.stream_id()); +/// // Raw handler: +/// let proto = request.connection.protocols().get::().unwrap(); +/// let session = proto.create_session(request.stream_id); /// /// // Hyper handler: /// let stream_id = request.extensions().get::().unwrap(); @@ -38,6 +38,29 @@ use crate::{ #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct StreamId(pub VarInt); +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum StreamInitiator { + Client, + Server, +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum StreamDirection { + Bidirectional, + Unidirectional, +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum StreamKind { + ClientBidirectional, + ServerBidirectional, + ClientUnidirectional, + ServerUnidirectional, +} + impl EncodeInto for StreamId { type Output = (); type Error = >::Error; @@ -71,6 +94,51 @@ impl StreamId { pub const fn into_inner(self) -> u64 { self.0.into_inner() } + + pub const fn initiator(self) -> StreamInitiator { + if self.into_inner() & 0x01 == 0 { + StreamInitiator::Client + } else { + StreamInitiator::Server + } + } + + pub const fn direction(self) -> StreamDirection { + if self.into_inner() & 0x02 == 0 { + StreamDirection::Bidirectional + } else { + StreamDirection::Unidirectional + } + } + + pub const fn kind(self) -> StreamKind { + match (self.is_client_initiated(), self.is_bidirectional()) { + (true, true) => StreamKind::ClientBidirectional, + (false, true) => StreamKind::ServerBidirectional, + (true, false) => StreamKind::ClientUnidirectional, + (false, false) => StreamKind::ServerUnidirectional, + } + } + + pub const fn is_client_initiated(self) -> bool { + matches!(self.initiator(), StreamInitiator::Client) + } + + pub const fn is_server_initiated(self) -> bool { + matches!(self.initiator(), StreamInitiator::Server) + } + + pub const fn is_bidirectional(self) -> bool { + matches!(self.direction(), StreamDirection::Bidirectional) + } + + pub const fn is_unidirectional(self) -> bool { + matches!(self.direction(), StreamDirection::Unidirectional) + } + + pub const fn is_client_initiated_bidirectional(self) -> bool { + self.is_client_initiated() && self.is_bidirectional() + } } impl TryFrom for StreamId { @@ -86,3 +154,70 @@ impl fmt::Display for StreamId { write!(f, "{}", self.0) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::codec::{DecodeFrom, EncodeInto}; + + #[test] + fn conversions_preserve_inner_varint() { + let raw = VarInt::from_u32(123); + let stream_id = StreamId::from(raw); + + assert_eq!(stream_id.into_inner(), 123); + assert_eq!(VarInt::from(stream_id), raw); + } + + #[test] + fn try_from_rejects_values_outside_varint_range() { + let error = StreamId::try_from(1_u64 << 62).expect_err("value exceeds QUIC varint range"); + + assert_eq!( + error.to_string(), + "value(4611686018427387904) too large for varint encoding" + ); + } + + #[test] + fn display_delegates_to_inner_varint() { + let stream_id = StreamId::from(VarInt::from_u32(7)); + + assert_eq!(stream_id.to_string(), "7"); + } + + #[test] + fn stream_id_exposes_quic_initiator_and_direction() { + let client_bi = StreamId::from(VarInt::from_u32(0)); + let server_bi = StreamId::from(VarInt::from_u32(1)); + let client_uni = StreamId::from(VarInt::from_u32(2)); + let server_uni = StreamId::from(VarInt::from_u32(3)); + + assert_eq!(client_bi.kind(), StreamKind::ClientBidirectional); + assert_eq!(server_bi.kind(), StreamKind::ServerBidirectional); + assert_eq!(client_uni.kind(), StreamKind::ClientUnidirectional); + assert_eq!(server_uni.kind(), StreamKind::ServerUnidirectional); + + assert_eq!(client_bi.initiator(), StreamInitiator::Client); + assert_eq!(server_bi.initiator(), StreamInitiator::Server); + assert_eq!(client_uni.direction(), StreamDirection::Unidirectional); + assert_eq!(server_bi.direction(), StreamDirection::Bidirectional); + + assert!(client_bi.is_client_initiated_bidirectional()); + assert!(!server_bi.is_client_initiated_bidirectional()); + assert!(!client_uni.is_client_initiated_bidirectional()); + } + + #[tokio::test] + async fn encode_decode_round_trips() { + let (mut writer, mut reader) = tokio::io::duplex(8); + let expected = StreamId::from(VarInt::from_u32(0x3fff)); + + let write = async move { expected.encode_into(&mut writer).await.expect("encode") }; + let read = async move { StreamId::decode_from(&mut reader).await.expect("decode") }; + + let ((), decoded) = tokio::join!(write, read); + + assert_eq!(decoded, expected); + } +} diff --git a/src/util/deferred.rs b/src/util/deferred.rs index c82237b..70622d3 100644 --- a/src/util/deferred.rs +++ b/src/util/deferred.rs @@ -10,6 +10,7 @@ //! it transitions to a [`Resolved`] state and delegates all further calls. use std::{ + collections::VecDeque, fmt::Debug, pin::Pin, task::{Context, Poll, ready}, @@ -138,18 +139,18 @@ where } } -impl quic::CancelStream for Resolved +impl quic::ResetStream for Resolved where - S: quic::CancelStream, + S: quic::ResetStream, E: Clone, quic::StreamError: From, { - fn poll_cancel( + fn poll_reset( self: Pin<&mut Self>, cx: &mut Context, code: VarInt, ) -> Poll> { - self.poll_inner()?.poll_cancel(cx, code) + self.poll_inner()?.poll_reset(cx, code) } } @@ -273,7 +274,7 @@ where // ============================================================================ pin_project_lite::pin_project! { - /// A lazily opened stream. + /// A lazy value adapter. /// /// Wraps a future that resolves to `Result`. While the future is /// pending, trait method calls drive it to completion; once resolved, @@ -337,6 +338,148 @@ impl Deferred { } } +pin_project_lite::pin_project! { + /// Stream-specific lazy reader wrapper that remembers committed receive-side controls. + /// + /// The generic [`Deferred`] wrapper remains a plain lazy value adapter. This + /// type adds read-stream semantics on top: `poll_stop` records a committed + /// STOP_SENDING request even while the opening future is still pending, then + /// applies it before later reads once the stream exists. + pub struct DeferredStreamReader { + #[pin] + inner: Deferred, + pending_stop: Option, + cached_id: Option, + } +} + +impl DeferredStreamReader +where + OpenFuture: Future>, +{ + /// Wrap an opening future without a known stream id. + pub fn new(open: OpenFuture) -> Self { + Self { + inner: Deferred::from(open), + pending_stop: None, + cached_id: None, + } + } +} + +impl From + for DeferredStreamReader +where + OpenFuture: Future>, +{ + fn from(open: OpenFuture) -> Self { + Self::new(open) + } +} + +pin_project_lite::pin_project! { + /// Stream-specific lazy writer wrapper that remembers committed send-side controls. + /// + /// The wrapper records reset, flush, and close operations that are first + /// polled before the opening future resolves. After resolution it drains + /// those operations before accepting new data. + pub struct DeferredStreamWriter { + #[pin] + inner: Deferred, + pending: PendingWriteQueue, + cached_id: Option, + } +} + +impl DeferredStreamWriter +where + OpenFuture: Future>, +{ + /// Wrap an opening future without a known stream id. + pub fn new(open: OpenFuture) -> Self { + Self { + inner: Deferred::from(open), + pending: PendingWriteQueue::default(), + cached_id: None, + } + } +} + +impl From + for DeferredStreamWriter +where + OpenFuture: Future>, +{ + fn from(open: OpenFuture) -> Self { + Self::new(open) + } +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +enum PendingWriteOp { + Reset(VarInt), + Flush, + Close, +} + +#[derive(Debug, Default)] +struct PendingWriteQueue { + ops: VecDeque, +} + +impl PendingWriteQueue { + fn contains_reset(&self) -> bool { + matches!(self.ops.front(), Some(PendingWriteOp::Reset(_))) + } + + fn contains_flush(&self) -> bool { + self.ops.contains(&PendingWriteOp::Flush) + } + + fn contains_close(&self) -> bool { + self.ops.contains(&PendingWriteOp::Close) + } + + fn contains_kind(&self, op: PendingWriteOp) -> bool { + match op { + PendingWriteOp::Reset(_) => self.contains_reset(), + PendingWriteOp::Flush => self.contains_flush(), + PendingWriteOp::Close => self.contains_close(), + } + } + + fn enqueue_reset(&mut self, code: VarInt) { + if !self.contains_reset() { + self.ops.clear(); + self.ops.push_back(PendingWriteOp::Reset(code)); + } + } + + fn enqueue_flush(&mut self) { + if !self.contains_reset() && !self.contains_flush() { + self.ops.push_back(PendingWriteOp::Flush); + } + } + + fn enqueue_close(&mut self) { + if !self.contains_reset() && !self.contains_close() { + self.ops.push_back(PendingWriteOp::Close); + } + } + + fn front(&self) -> Option { + self.ops.front().copied() + } + + fn pop_front(&mut self) { + self.ops.pop_front(); + } + + fn is_empty(&self) -> bool { + self.ops.is_empty() + } +} + impl Future for Deferred where F: Future>, @@ -350,50 +493,194 @@ where } } -impl quic::GetStreamId for Deferred +impl quic::GetStreamId + for DeferredStreamReader where - S: quic::GetStreamId, - E: Clone, - quic::StreamError: From, - F: Future>, + InnerStream: quic::GetStreamId, + Error: Clone, + quic::StreamError: From, + OpenFuture: Future>, { fn poll_stream_id( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context, ) -> Poll> { - ready!(self.poll(cx)?).poll_stream_id(cx) + if let Some(stream_id) = self.as_mut().project().cached_id.as_ref().copied() { + return Poll::Ready(Ok(stream_id)); + } + + let stream_id = { + let project = self.as_mut().project(); + let stream = match ready!(project.inner.poll(cx)) { + Ok(stream) => stream, + Err(error) => return Poll::Ready(Err(error.into())), + }; + ready!(stream.poll_stream_id(cx))? + }; + *self.as_mut().project().cached_id = Some(stream_id); + Poll::Ready(Ok(stream_id)) } } -impl quic::StopStream for Deferred +impl quic::StopStream + for DeferredStreamReader where - S: quic::StopStream, - E: Clone, - quic::StreamError: From, - F: Future>, + InnerStream: quic::StopStream, + Error: Clone, + quic::StreamError: From, + OpenFuture: Future>, { fn poll_stop( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context, code: VarInt, ) -> Poll> { - ready!(self.poll(cx)?).poll_stop(cx, code) + if self.as_mut().project().pending_stop.is_none() { + *self.as_mut().project().pending_stop = Some(code); + } + let code = self + .as_mut() + .project() + .pending_stop + .expect("pending stop code should be present"); + + let result = { + let project = self.as_mut().project(); + let stream = match ready!(project.inner.poll(cx)) { + Ok(stream) => stream, + Err(error) => return Poll::Ready(Err(error.into())), + }; + ready!(stream.poll_stop(cx, code)) + }; + *self.as_mut().project().pending_stop = None; + Poll::Ready(result) } } -impl quic::CancelStream for Deferred +impl futures::Stream + for DeferredStreamReader where - S: quic::CancelStream, - E: Clone, - quic::StreamError: From, - F: Future>, + InnerStream: futures::Stream> + quic::StopStream, + Error: Clone, + quic::StreamError: From, + StreamItemError: From + From, + OpenFuture: Future>, { - fn poll_cancel( - self: Pin<&mut Self>, + type Item = InnerStream::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Some(code) = self.as_mut().project().pending_stop.as_ref().copied() + && let Err(error) = ready!(quic::StopStream::poll_stop(self.as_mut(), cx, code)) + { + return Poll::Ready(Some(Err(StreamItemError::from(error)))); + } + + let project = self.as_mut().project(); + let stream = match ready!(project.inner.poll(cx)) { + Ok(stream) => stream, + Err(error) => return Poll::Ready(Some(Err(StreamItemError::from(error)))), + }; + stream.poll_next(cx) + } + + fn size_hint(&self) -> (usize, Option) { + (0, None) + } +} + +impl FusedStream + for DeferredStreamReader +where + InnerStream: FusedStream> + quic::StopStream, + Error: Clone, + quic::StreamError: From, + StreamItemError: From + From, + OpenFuture: Future>, +{ + fn is_terminated(&self) -> bool { + match &self.inner { + Deferred::Pending { .. } => false, + Deferred::Ready { + resolved: Resolved::Error { .. }, + } => true, + Deferred::Ready { .. } if self.pending_stop.is_some() => false, + Deferred::Ready { resolved } => resolved.is_terminated(), + } + } +} + +impl quic::GetStreamId + for DeferredStreamWriter +where + InnerStream: quic::GetStreamId, + Error: Clone, + quic::StreamError: From, + OpenFuture: Future>, +{ + fn poll_stream_id( + mut self: Pin<&mut Self>, + cx: &mut Context, + ) -> Poll> { + if let Some(stream_id) = self.as_mut().project().cached_id.as_ref().copied() { + return Poll::Ready(Ok(stream_id)); + } + + let stream_id = { + let project = self.as_mut().project(); + let stream = match ready!(project.inner.poll(cx)) { + Ok(stream) => stream, + Err(error) => return Poll::Ready(Err(error.into())), + }; + ready!(stream.poll_stream_id(cx))? + }; + *self.as_mut().project().cached_id = Some(stream_id); + Poll::Ready(Ok(stream_id)) + } +} + +impl DeferredStreamWriter +where + InnerStream: quic::ResetStream, + Error: Clone, + quic::StreamError: From, + OpenFuture: Future>, +{ + fn poll_reset_op( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let Some(PendingWriteOp::Reset(code)) = self.as_mut().project().pending.front() else { + return Poll::Ready(Ok(())); + }; + + let result = { + let project = self.as_mut().project(); + let stream = match ready!(project.inner.poll(cx)) { + Ok(stream) => stream, + Err(error) => return Poll::Ready(Err(error.into())), + }; + ready!(stream.poll_reset(cx, code)) + }; + self.as_mut().project().pending.pop_front(); + Poll::Ready(result) + } +} + +impl quic::ResetStream + for DeferredStreamWriter +where + InnerStream: quic::ResetStream, + Error: Clone, + quic::StreamError: From, + OpenFuture: Future>, +{ + fn poll_reset( + mut self: Pin<&mut Self>, cx: &mut Context, code: VarInt, ) -> Poll> { - ready!(self.poll(cx)?).poll_cancel(cx, code) + self.as_mut().project().pending.enqueue_reset(code); + self.poll_reset_op(cx) } } @@ -516,3 +803,835 @@ where ready!(self.poll(cx)?).poll_close(cx) } } + +impl DeferredStreamWriter +where + InnerStream: quic::ResetStream, + Error: Clone, + quic::StreamError: From, + OpenFuture: Future>, +{ + fn poll_pending_until( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + target: Option, + ) -> Poll> + where + InnerStream: Sink, + SinkError: From + From, + { + loop { + let op = { + let project = self.as_mut().project(); + if let Some(target) = target + && !project.pending.contains_kind(target) + { + return Poll::Ready(Ok(())); + } + match project.pending.front() { + Some(op) => op, + None => return Poll::Ready(Ok(())), + } + }; + + match op { + PendingWriteOp::Reset(_) => { + ready!(self.as_mut().poll_reset_op(cx)).map_err(SinkError::from)?; + } + PendingWriteOp::Flush => { + let result = { + let project = self.as_mut().project(); + let stream = match ready!(project.inner.poll(cx)) { + Ok(stream) => stream, + Err(error) => return Poll::Ready(Err(error.into())), + }; + ready!(stream.poll_flush(cx)) + }; + self.as_mut().project().pending.pop_front(); + result?; + if target == Some(PendingWriteOp::Flush) { + return Poll::Ready(Ok(())); + } + } + PendingWriteOp::Close => { + let result = { + let project = self.as_mut().project(); + let stream = match ready!(project.inner.poll(cx)) { + Ok(stream) => stream, + Err(error) => return Poll::Ready(Err(error.into())), + }; + ready!(stream.poll_close(cx)) + }; + self.as_mut().project().pending.pop_front(); + result?; + if target == Some(PendingWriteOp::Close) { + return Poll::Ready(Ok(())); + } + } + } + } + } +} + +impl Sink + for DeferredStreamWriter +where + InnerStream: Sink + quic::ResetStream, + Error: Clone, + InnerStream::Error: From + From, + quic::StreamError: From, + OpenFuture: Future>, +{ + type Error = InnerStream::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + while !self.as_mut().project().pending.is_empty() { + ready!( + self.as_mut() + .poll_pending_until::(cx, None)? + ); + } + + let project = self.as_mut().project(); + let stream = match ready!(project.inner.poll(cx)) { + Ok(stream) => stream, + Err(error) => return Poll::Ready(Err(error.into())), + }; + stream.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { + let project = self.project(); + assert!( + project.pending.is_empty(), + "start_send called without poll_ready being called first" + ); + project + .inner + .try_peek_mut() + .expect("start_send before poll_ready completed")? + .start_send(item) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.as_mut().project().pending.enqueue_flush(); + self.poll_pending_until(cx, Some(PendingWriteOp::Flush)) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.as_mut().project().pending.enqueue_close(); + self.poll_pending_until(cx, Some(PendingWriteOp::Close)) + } +} + +#[cfg(test)] +mod tests { + use std::{ + collections::VecDeque, + convert::Infallible, + future::{Ready, pending, ready}, + io::Cursor, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, + }; + + use bytes::Bytes; + use futures::{ + FutureExt, Sink, SinkExt, Stream, StreamExt, future::poll_fn, stream::FusedStream, + task::noop_waker_ref, + }; + use tokio::{ + io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, ReadBuf}, + sync::oneshot, + }; + + use super::*; + use crate::{ + codec::{DecodeError, EncodeError}, + quic::{GetStreamIdExt, ResetStream, ResetStreamExt, StopStream, StopStreamExt}, + }; + + fn varint(value: u32) -> VarInt { + VarInt::from_u32(value) + } + + fn assert_reset(error: quic::StreamError, expected: VarInt) { + let quic::StreamError::Reset { code } = error else { + panic!("expected stream reset"); + }; + assert_eq!(code, expected); + } + + fn resolved_quic_error(value: T, code: u32) -> Resolved { + drop(value); + Resolved::err(quic::StreamError::Reset { code: varint(code) }) + } + + fn deferred_quic_error( + value: T, + code: u32, + ) -> Deferred>> { + drop(value); + Deferred::from(ready(Err(quic::StreamError::Reset { code: varint(code) }))) + } + + #[derive(Debug, Clone, PartialEq, Eq)] + enum DeferredStreamEvent { + Stop(VarInt), + Write(Bytes), + Flush, + Close, + Reset(VarInt), + } + + #[derive(Debug)] + struct RecordingReader { + stream_id: VarInt, + events: Arc>>, + chunks: VecDeque>, + terminated: bool, + } + + impl quic::GetStreamId for RecordingReader { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl quic::StopStream for RecordingReader { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.events + .lock() + .expect("event log poisoned") + .push(DeferredStreamEvent::Stop(code)); + Poll::Ready(Ok(())) + } + } + + impl Stream for RecordingReader { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + match this.chunks.pop_front() { + Some(item) => Poll::Ready(Some(item)), + None => { + this.terminated = true; + Poll::Ready(None) + } + } + } + } + + impl FusedStream for RecordingReader { + fn is_terminated(&self) -> bool { + self.terminated + } + } + + #[derive(Debug)] + struct RecordingWriter { + stream_id: VarInt, + events: Arc>>, + } + + impl quic::GetStreamId for RecordingWriter { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl quic::ResetStream for RecordingWriter { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context, + code: VarInt, + ) -> Poll> { + self.events + .lock() + .expect("event log poisoned") + .push(DeferredStreamEvent::Reset(code)); + Poll::Ready(Ok(())) + } + } + + impl Sink for RecordingWriter { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.events + .lock() + .expect("event log poisoned") + .push(DeferredStreamEvent::Write(item)); + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + self.events + .lock() + .expect("event log poisoned") + .push(DeferredStreamEvent::Flush); + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + self.events + .lock() + .expect("event log poisoned") + .push(DeferredStreamEvent::Close); + Poll::Ready(Ok(())) + } + } + + #[test] + fn resolved_converts_from_result_and_back() { + assert_eq!(Resolved::<_, DecodeError>::from(Ok(7)).into_result(), Ok(7)); + assert_eq!( + Resolved::::from(Err(DecodeError::Incomplete)).into_result(), + Err(DecodeError::Incomplete) + ); + } + + #[cfg(feature = "serde")] + #[test] + fn resolved_serializes_and_deserializes_error_result() { + let resolved: Resolved = Resolved::err(7); + + let encoded = serde_json::to_value(&resolved).expect("error result serializes"); + assert_eq!(encoded, serde_json::json!({ "Err": 7 })); + + let decoded: Resolved = + serde_json::from_value(encoded).expect("error result deserializes"); + assert_eq!(decoded.into_result(), Err(7)); + } + + #[tokio::test] + async fn resolved_delegates_async_read_and_write() { + let (mut input, output) = tokio::io::duplex(16); + input.write_all(b"hello").await.expect("write input"); + input.shutdown().await.expect("shutdown input"); + + let mut reader: Resolved<_, DecodeError> = Resolved::ok(output); + let mut received = Vec::new(); + reader + .read_to_end(&mut received) + .await + .expect("read resolved value"); + assert_eq!(received, b"hello"); + + let (writer_side, mut output) = tokio::io::duplex(16); + let mut writer: Resolved<_, EncodeError> = Resolved::ok(writer_side); + writer + .write_all(b"world") + .await + .expect("write resolved value"); + writer.flush().await.expect("flush resolved value"); + writer.shutdown().await.expect("shutdown resolved value"); + + let mut received = Vec::new(); + output + .read_to_end(&mut received) + .await + .expect("read written bytes"); + assert_eq!(received, b"world"); + } + + #[tokio::test] + async fn resolved_error_maps_to_async_io_errors() { + let (_input, output) = tokio::io::duplex(1); + let mut reader: Resolved = + Resolved::err(DecodeError::Incomplete); + let error = reader.read(&mut [0]).await.expect_err("read should fail"); + assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof); + drop(output); + + let (writer_side, _output) = tokio::io::duplex(1); + let mut writer: Resolved = + Resolved::err(EncodeError::FramePayloadTooLarge); + let error = writer.write_all(b"x").await.expect_err("write should fail"); + assert_eq!(error.kind(), io::ErrorKind::InvalidData); + drop(writer_side); + } + + #[tokio::test] + async fn resolved_error_maps_to_buffered_read_flush_and_shutdown_errors() { + let mut reader: Resolved>>, DecodeError> = + Resolved::err(DecodeError::Incomplete); + let error = reader.fill_buf().await.expect_err("fill_buf should fail"); + assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof); + + let mut writer: Resolved = + Resolved::err(EncodeError::FramePayloadTooLarge); + let error = writer.flush().await.expect_err("flush should fail"); + assert_eq!(error.kind(), io::ErrorKind::InvalidData); + let error = writer.shutdown().await.expect_err("shutdown should fail"); + assert_eq!(error.kind(), io::ErrorKind::InvalidData); + } + + #[tokio::test] + async fn resolved_delegates_buffered_read() { + let cursor = Cursor::new(b"line\nrest".to_vec()); + let mut reader: Resolved<_, DecodeError> = Resolved::ok(BufReader::new(cursor)); + + let bytes = reader.fill_buf().await.expect("fill buffer"); + assert_eq!(bytes, b"line\nrest"); + reader.consume(5); + + let mut rest = String::new(); + reader + .read_to_string(&mut rest) + .await + .expect("read after consume"); + assert_eq!(rest, "rest"); + } + + #[tokio::test] + async fn resolved_delegates_quic_stream_traits_and_stream_sink() { + let (reader, writer) = quic::test::mock_stream_pair(varint(11)); + let mut reader: Resolved<_, quic::StreamError> = Resolved::ok(reader); + let mut writer: Resolved<_, quic::StreamError> = Resolved::ok(writer); + + assert_eq!(reader.stream_id().await.expect("reader id"), varint(11)); + assert_eq!(writer.stream_id().await.expect("writer id"), varint(11)); + + writer + .send(Bytes::from_static(b"payload")) + .await + .expect("send"); + assert_eq!( + reader.next().await.expect("item").expect("read"), + Bytes::from_static(b"payload") + ); + + reader.stop(varint(12)).await.expect("stop"); + writer.reset(varint(13)).await.expect("reset"); + + let (_reader, writer) = quic::test::mock_stream_pair(varint(14)); + let mut writer: Resolved<_, quic::StreamError> = Resolved::ok(writer); + writer.close().await.expect("close"); + } + + #[tokio::test] + async fn resolved_error_maps_to_quic_stream_error_for_control_traits() { + let (reader, _writer) = quic::test::mock_stream_pair(varint(51)); + let mut stream_id = resolved_quic_error(reader, 52); + assert_reset( + stream_id + .stream_id() + .await + .expect_err("stream id should fail"), + varint(52), + ); + + let (reader, _writer) = quic::test::mock_stream_pair(varint(53)); + let mut stop = resolved_quic_error(reader, 54); + assert_reset( + stop.stop(varint(55)).await.expect_err("stop should fail"), + varint(54), + ); + + let (_reader, writer) = quic::test::mock_stream_pair(varint(56)); + let mut reset = resolved_quic_error(writer, 57); + assert_reset( + reset + .reset(varint(58)) + .await + .expect_err("reset should fail"), + varint(57), + ); + } + + #[test] + fn resolved_error_maps_to_sink_operation_errors() { + let (_reader, writer) = quic::test::mock_stream_pair(varint(59)); + let mut sink = Box::pin(resolved_quic_error(writer, 60)); + let waker = noop_waker_ref(); + let mut cx = Context::from_waker(waker); + + let Poll::Ready(Err(error)) = sink.as_mut().poll_ready(&mut cx) else { + panic!("poll_ready should return reset"); + }; + assert_reset(error, varint(60)); + + let error = sink + .as_mut() + .start_send(Bytes::from_static(b"ignored")) + .expect_err("start_send should fail"); + assert_reset(error, varint(60)); + + let Poll::Ready(Err(error)) = sink.as_mut().poll_flush(&mut cx) else { + panic!("poll_flush should return reset"); + }; + assert_reset(error, varint(60)); + + let Poll::Ready(Err(error)) = sink.as_mut().poll_close(&mut cx) else { + panic!("poll_close should return reset"); + }; + assert_reset(error, varint(60)); + } + + #[tokio::test] + async fn resolved_stream_error_yields_error_item_and_is_terminated() { + let mut stream: Resolved< + futures::stream::Empty>, + quic::StreamError, + > = Resolved::err(quic::StreamError::Reset { code: varint(31) }); + assert_eq!(stream.size_hint(), (0, None)); + let item = stream.next().await.expect("error item").expect_err("reset"); + assert_reset(item, varint(31)); + assert!(stream.is_terminated()); + + let value: Resolved<_, Infallible> = + Resolved::ok(futures::stream::iter([Ok::<_, Infallible>( + Bytes::from_static(b"one"), + )])); + assert_eq!(value.size_hint(), (1, Some(1))); + } + + #[tokio::test] + async fn deferred_future_resolves_success_and_error() { + let value = Deferred::from(ready(Ok::<_, DecodeError>(42))) + .await + .expect("deferred value"); + assert_eq!(value, 42); + + let error = Deferred::from(ready(Err::(DecodeError::IntegerOverflow))) + .await + .expect_err("deferred error"); + assert_eq!(error, DecodeError::IntegerOverflow); + } + + #[tokio::test] + async fn deferred_delegates_async_read_write_and_buffered_read() { + let (mut input, output) = tokio::io::duplex(16); + input.write_all(b"hello").await.expect("write input"); + input.shutdown().await.expect("shutdown input"); + + let mut reader = Deferred::from(ready(Ok::<_, DecodeError>(output))); + let mut received = Vec::new(); + reader + .read_to_end(&mut received) + .await + .expect("read deferred value"); + assert_eq!(received, b"hello"); + + let (writer_side, mut output) = tokio::io::duplex(16); + let mut writer = Deferred::from(ready(Ok::<_, EncodeError>(writer_side))); + writer + .write_all(b"world") + .await + .expect("write deferred value"); + writer.flush().await.expect("flush deferred value"); + writer.shutdown().await.expect("shutdown deferred value"); + let mut received = Vec::new(); + output + .read_to_end(&mut received) + .await + .expect("read written bytes"); + assert_eq!(received, b"world"); + + let cursor = Cursor::new(b"line\nrest".to_vec()); + let mut reader = Deferred::from(ready(Ok::<_, DecodeError>(BufReader::new(cursor)))); + assert_eq!(reader.fill_buf().await.expect("fill buffer"), b"line\nrest"); + reader.consume(5); + let mut rest = String::new(); + reader + .read_to_string(&mut rest) + .await + .expect("read after consume"); + assert_eq!(rest, "rest"); + } + + #[tokio::test] + async fn deferred_error_maps_to_async_io_errors() { + let (_input, output) = tokio::io::duplex(1); + let mut reader = Deferred::from(ready({ + drop(output); + Err::(DecodeError::Incomplete) + })); + let error = reader.read(&mut [0]).await.expect_err("read should fail"); + assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof); + + let (writer_side, _output) = tokio::io::duplex(1); + let mut writer = Deferred::from(ready({ + drop(writer_side); + Err::(EncodeError::HuffmanEncoding) + })); + let error = writer.write_all(b"x").await.expect_err("write should fail"); + assert_eq!(error.kind(), io::ErrorKind::InvalidData); + } + + #[tokio::test] + async fn deferred_stream_wrappers_delegate_quic_stream_traits_and_stream_sink() { + let (reader, writer) = quic::test::mock_stream_pair(varint(21)); + let mut reader = DeferredStreamReader::from(ready(Ok::<_, quic::StreamError>(reader))); + let mut writer = DeferredStreamWriter::from(ready(Ok::<_, quic::StreamError>(writer))); + + assert_eq!(reader.stream_id().await.expect("reader id"), varint(21)); + assert_eq!(writer.stream_id().await.expect("writer id"), varint(21)); + + writer + .send(Bytes::from_static(b"deferred")) + .await + .expect("send"); + assert_eq!( + reader.next().await.expect("item").expect("read"), + Bytes::from_static(b"deferred") + ); + + assert_eq!(reader.size_hint(), (0, None)); + + reader.stop(varint(22)).await.expect("stop"); + writer.reset(varint(23)).await.expect("reset"); + + let (_reader, writer) = quic::test::mock_stream_pair(varint(24)); + let mut writer = Deferred::from(ready(Ok::<_, quic::StreamError>(writer))); + writer.close().await.expect("close"); + } + + #[tokio::test] + async fn deferred_stream_reader_applies_stop_committed_before_resolution() { + let events = Arc::new(Mutex::new(Vec::new())); + let (open_tx, open_rx) = oneshot::channel(); + let stop_code = varint(101); + let mut reader = Box::pin(DeferredStreamReader::from(async move { + open_rx + .await + .expect("stream opener should resolve before test ends") + })); + + assert!( + poll_fn(|cx| reader.as_mut().poll_stop(cx, stop_code)) + .now_or_never() + .is_none() + ); + assert!(events.lock().expect("event log poisoned").is_empty()); + + open_tx + .send(Ok::<_, quic::StreamError>(RecordingReader { + stream_id: varint(102), + events: events.clone(), + chunks: VecDeque::from([Ok(Bytes::from_static(b"after stop"))]), + terminated: false, + })) + .expect("send opener result"); + + assert_eq!( + reader + .as_mut() + .next() + .await + .expect("chunk after stop") + .expect("read succeeds"), + Bytes::from_static(b"after stop") + ); + assert_eq!( + *events.lock().expect("event log poisoned"), + vec![DeferredStreamEvent::Stop(stop_code)] + ); + } + + #[tokio::test] + async fn deferred_stream_reader_is_terminated_after_pending_stop_open_error() { + let stop_code = varint(103); + let error_code = varint(104); + let mut reader = Box::pin(DeferredStreamReader::from(ready( + Err::(quic::StreamError::Reset { code: error_code }), + ))); + + let error = poll_fn(|cx| reader.as_mut().poll_stop(cx, stop_code)) + .await + .expect_err("stop should report opener error"); + + assert_reset(error, error_code); + assert!(reader.is_terminated()); + } + + #[tokio::test] + async fn deferred_stream_writer_reset_replaces_flush_committed_before_resolution() { + let events = Arc::new(Mutex::new(Vec::new())); + let (open_tx, open_rx) = oneshot::channel(); + let reset_code = varint(111); + let mut writer = Box::pin(DeferredStreamWriter::from(async move { + open_rx + .await + .expect("stream opener should resolve before test ends") + })); + + assert!( + poll_fn(|cx| writer.as_mut().poll_flush(cx)) + .now_or_never() + .is_none() + ); + assert!( + poll_fn(|cx| writer.as_mut().poll_reset(cx, reset_code)) + .now_or_never() + .is_none() + ); + assert!(events.lock().expect("event log poisoned").is_empty()); + + open_tx + .send(Ok::<_, quic::StreamError>(RecordingWriter { + stream_id: varint(112), + events: events.clone(), + })) + .expect("send opener result"); + + poll_fn(|cx| writer.as_mut().poll_ready(cx)) + .await + .expect("writer becomes ready after reset drains"); + assert_eq!( + *events.lock().expect("event log poisoned"), + vec![DeferredStreamEvent::Reset(reset_code)] + ); + } + + #[tokio::test] + async fn deferred_stream_writer_drains_flush_and_close_in_commit_order_after_resolution() { + let events = Arc::new(Mutex::new(Vec::new())); + let (open_tx, open_rx) = oneshot::channel(); + let mut writer = Box::pin(DeferredStreamWriter::from(async move { + open_rx + .await + .expect("stream opener should resolve before test ends") + })); + + assert!( + poll_fn(|cx| writer.as_mut().poll_close(cx)) + .now_or_never() + .is_none() + ); + assert!( + poll_fn(|cx| writer.as_mut().poll_flush(cx)) + .now_or_never() + .is_none() + ); + + open_tx + .send(Ok::<_, quic::StreamError>(RecordingWriter { + stream_id: varint(113), + events: events.clone(), + })) + .expect("send opener result"); + + poll_fn(|cx| writer.as_mut().poll_ready(cx)) + .await + .expect("writer becomes ready after controls drain"); + assert_eq!( + *events.lock().expect("event log poisoned"), + vec![DeferredStreamEvent::Close, DeferredStreamEvent::Flush] + ); + } + + #[tokio::test] + async fn deferred_fused_stream_tracks_pending_and_ready_state() { + let stream = + futures::stream::iter([Ok::<_, Infallible>(Bytes::from_static(b"one"))]).fuse(); + let mut stream = Deferred::from(ready(Ok::<_, Infallible>(stream))); + + assert!(!stream.is_terminated()); + assert_eq!( + stream.next().await.expect("item").expect("stream value"), + Bytes::from_static(b"one") + ); + assert!(stream.next().await.is_none()); + assert!(stream.is_terminated()); + } + + #[tokio::test] + async fn deferred_error_maps_to_stream_and_sink_errors() { + let (reader, _writer) = quic::test::mock_stream_pair(varint(39)); + let mut stream = deferred_quic_error(reader, 40); + let item = stream.next().await.expect("error item").expect_err("reset"); + assert_reset(item, varint(40)); + + let (_reader, writer) = quic::test::mock_stream_pair(varint(41)); + let mut sink = deferred_quic_error(writer, 42); + let error = sink + .send(Bytes::from_static(b"ignored")) + .await + .expect_err("send should fail"); + assert_reset(error, varint(42)); + } + + #[test] + fn deferred_ready_error_maps_to_start_send_error() { + let (_reader, writer) = quic::test::mock_stream_pair(varint(61)); + let mut sink = Box::pin(deferred_quic_error(writer, 62)); + let waker = noop_waker_ref(); + let mut cx = Context::from_waker(waker); + + let Poll::Ready(Err(error)) = sink.as_mut().poll_ready(&mut cx) else { + panic!("poll_ready should resolve to reset"); + }; + assert_reset(error, varint(62)); + + let error = sink + .as_mut() + .start_send(Bytes::from_static(b"ignored")) + .expect_err("start_send should fail"); + assert_reset(error, varint(62)); + } + + #[test] + fn deferred_pending_read_stays_pending_without_transitioning() { + let mut reader = + Deferred::>, DecodeError, _>::from(pending::>()); + let mut output = [0; 8]; + let mut read_buf = ReadBuf::new(&mut output); + let waker = noop_waker_ref(); + let mut cx = Context::from_waker(waker); + + assert!(matches!( + Pin::new(&mut reader).poll_read(&mut cx, &mut read_buf), + Poll::Pending + )); + assert!(matches!(reader, Deferred::Pending { .. })); + } + + #[test] + #[should_panic(expected = "start_send before poll_ready completed")] + fn deferred_start_send_before_ready_panics() { + let (_reader, writer) = quic::test::mock_stream_pair(varint(43)); + let mut sink = Box::pin(Deferred::from(async move { + pending::<()>().await; + Ok::<_, quic::StreamError>(writer) + })); + + sink.as_mut() + .start_send(Bytes::from_static(b"not ready")) + .expect("panic before result"); + } + + #[test] + #[should_panic(expected = "consume before read")] + fn deferred_consume_before_read_panics() { + let reader = BufReader::new(Cursor::new(b"pending".to_vec())); + let mut reader = Deferred::from(ready(Ok::<_, DecodeError>(reader))); + Pin::new(&mut reader).consume(1); + } +} diff --git a/src/util/ring_channel.rs b/src/util/ring_channel.rs index e0198c1..b241fe2 100644 --- a/src/util/ring_channel.rs +++ b/src/util/ring_channel.rs @@ -96,3 +96,120 @@ impl Stream for Receiver { self.poll(cx).map(Some) } } + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use futures::{FutureExt, StreamExt}; + use tokio::time::timeout; + use tracing::Instrument; + + use super::*; + + #[test] + fn ring_channel_capacity_is_retained() { + let channel: RingChannel = RingChannel::new(7); + assert_eq!(channel.capacity(), 7); + } + + #[tokio::test] + async fn ring_channel_send_and_receive_fifo_order() { + let channel: RingChannel = RingChannel::new(2); + channel.send(1); + channel.send(2); + + let first = timeout(Duration::from_millis(50), channel.receive()) + .await + .unwrap(); + let second = timeout(Duration::from_millis(50), channel.receive()) + .await + .unwrap(); + + assert_eq!(first, 1); + assert_eq!(second, 2); + } + + #[tokio::test] + async fn ring_channel_send_returns_overflow_item_when_full() { + let channel: RingChannel<&'static str> = RingChannel::new(1); + assert_eq!(channel.send("first"), None); + assert_eq!(channel.send("second"), Some("first")); + + let value = timeout(Duration::from_millis(50), channel.receive()) + .await + .unwrap(); + assert_eq!(value, "second"); + } + + #[tokio::test] + async fn ring_channel_clone_shares_ring_storage_with_send_receiver() { + let channel: RingChannel = RingChannel::new(4); + let sender = channel.clone(); + + sender.send(13); + + let value = timeout(Duration::from_millis(50), channel.receive()) + .await + .unwrap(); + assert_eq!(value, 13); + } + + #[tokio::test] + async fn ring_channel_receiver_waits_for_send_without_spin() { + let channel: RingChannel = RingChannel::new(4); + let receiver = channel.receive(); + + let receive = tokio::spawn( + async move { timeout(Duration::from_millis(100), receiver).await.unwrap() } + .in_current_span(), + ); + + tokio::time::sleep(Duration::from_millis(10)).await; + channel.send(99); + + let value = receive.await.unwrap(); + assert_eq!(value, 99); + } + + #[tokio::test] + async fn ring_channel_receive_times_out_without_send() { + let receiver: Receiver = RingChannel::new(2).receive(); + + let result = timeout(Duration::from_millis(20), receiver).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn ring_channel_overflow_discards_only_oldest_item() { + let channel: RingChannel = RingChannel::new(2); + + assert_eq!(channel.send(1), None); + assert_eq!(channel.send(2), None); + assert_eq!(channel.send(3), Some(1)); + + let receiver = channel.receive(); + let mut receiver = std::pin::pin!(receiver); + assert_eq!(receiver.as_mut().next().await, Some(2)); + assert_eq!(receiver.as_mut().next().await, Some(3)); + assert_eq!(receiver.as_mut().next().now_or_never(), None); + } + + #[tokio::test] + async fn pending_receiver_recovers_from_consumed_notification() { + let channel: RingChannel = RingChannel::new(1); + let pending_receiver = channel.receive(); + assert_eq!(pending_receiver.now_or_never(), None); + + let waiting_receiver = channel.receive(); + let mut waiting_receiver = std::pin::pin!(waiting_receiver); + assert_eq!(waiting_receiver.as_mut().next().now_or_never(), None); + + channel.send(1); + let consumed_by_other_receiver = channel.receive().await; + assert_eq!(consumed_by_other_receiver, 1); + + channel.send(2); + assert_eq!(waiting_receiver.as_mut().next().await, Some(2)); + } +} diff --git a/src/util/set_once.rs b/src/util/set_once.rs index 081f018..65e87f6 100644 --- a/src/util/set_once.rs +++ b/src/util/set_once.rs @@ -97,3 +97,79 @@ impl Future for Get { } } } + +#[cfg(test)] +mod tests { + use std::cell::Cell; + + use futures::FutureExt; + + use super::*; + + #[test] + fn default_starts_unset_and_set_stores_value() { + let value = SetOnce::default(); + + assert!(!value.is_set()); + assert_eq!(value.peek(), None); + + assert_eq!(value.set(7), Ok(())); + assert!(value.is_set()); + assert_eq!(value.peek(), Some(7)); + assert_eq!(value.get().now_or_never(), Some(Some(7))); + } + + #[test] + fn second_set_returns_rejected_value_without_replacing_existing() { + let value = SetOnce::new(); + + assert_eq!(value.set("first"), Ok(())); + assert_eq!(value.set("second"), Err("second")); + assert_eq!(value.peek(), Some("first")); + } + + #[test] + fn set_with_does_not_call_factory_after_value_is_set() { + let value = SetOnce::new(); + assert_eq!(value.set(1), Ok(())); + + let called = Cell::new(false); + let factory = || { + called.set(true); + 2 + }; + + let rejected = value + .set_with(factory) + .expect_err("factory should be returned when already set"); + assert!(!called.get()); + assert_eq!(rejected(), 2); + assert!(called.get()); + assert_eq!(value.peek(), Some(1)); + } + + #[tokio::test] + async fn pending_get_observes_value_set_from_clone() { + let value = SetOnce::new(); + assert_eq!(value.get().now_or_never(), None); + + let setter = value.clone(); + let get = value.get(); + + assert_eq!(setter.set(11), Ok(())); + assert_eq!(get.await, Some(11)); + } + + #[test] + fn get_survives_notification_before_value_is_set() { + let value = SetOnce::new(); + let mut get = Box::pin(value.get()); + + assert_eq!(get.as_mut().now_or_never(), None); + value.notify.notify_waiters(); + assert_eq!(get.as_mut().now_or_never(), None); + + assert_eq!(value.set(13), Ok(())); + assert_eq!(get.now_or_never(), Some(Some(13))); + } +} diff --git a/src/util/tls.rs b/src/util/tls.rs index 4f78ebb..b188f04 100644 --- a/src/util/tls.rs +++ b/src/util/tls.rs @@ -1,6 +1,3 @@ -use rustls::pki_types::CertificateDer; -use snafu::Snafu; - #[derive(Debug)] pub struct DangerousServerCertVerifier; @@ -53,47 +50,85 @@ impl rustls::client::danger::ServerCertVerifier for DangerousServerCertVerifier } } -#[derive(Debug, Snafu)] -#[snafu(module)] -pub enum InvalidIdentity { - #[snafu(transparent)] - Tls { source: rustls::Error }, - #[snafu(display("certificate for identity cannot be parsed"))] - InvalidCertificate { - source: x509_parser::nom::Err, - }, - #[snafu(display("SAN extensions in certificate are invalid"))] - InvalidSAN { - source: x509_parser::error::X509Error, - }, - #[snafu(display("certificate for identity is missing SAN extensions"))] - MissingSAN, - #[snafu(display("identity name not found in certificate SAN"))] - NameNotFound, -} - -pub fn verify_certificate_for_name( - certificate: &CertificateDer<'static>, - client_name: &str, -) -> Result<(), InvalidIdentity> { - use x509_parser::prelude::*; +#[cfg(test)] +mod tests { + use std::time::Duration; - let cert = match x509_parser::parse_x509_certificate(certificate) { - Ok((_remain, cert)) => cert, - Err(source) => return Err(InvalidIdentity::InvalidCertificate { source }), - }; - let san = match cert.subject_alternative_name() { - Ok(Some(san)) => san, - Ok(None) => return Err(InvalidIdentity::MissingSAN), - Err(source) => return Err(InvalidIdentity::InvalidSAN { source }), + use rustls::{ + DigitallySignedStruct, SignatureScheme, + client::danger::ServerCertVerifier, + internal::msgs::codec::{Codec, Reader}, + pki_types::{CertificateDer, ServerName, UnixTime}, }; - if san.value.general_names.iter().any(|name| match name { - GeneralName::DNSName(name) => *name == client_name, - _ => false, - }) { - return Ok(()); + use super::*; + + fn digitally_signed_struct_for_test(scheme: SignatureScheme) -> DigitallySignedStruct { + // rustls exposes `DigitallySignedStruct` publicly because verifier + // traits take it by reference, but its constructor is crate-private. + // Decode the same wire representation that rustls uses internally so + // these verifier tests can exercise the signature callbacks without + // depending on unsafe construction. + let mut signature = Vec::new(); + scheme.encode(&mut signature); + 0u16.encode(&mut signature); + DigitallySignedStruct::read(&mut Reader::init(&signature)) + .expect("digitally signed struct decodes") } - Err(InvalidIdentity::NameNotFound) + #[test] + fn dangerous_verifier_accepts_server_cert_without_validation() { + let verifier = DangerousServerCertVerifier; + assert_eq!(format!("{verifier:?}"), "DangerousServerCertVerifier"); + let end_entity = CertificateDer::from(Vec::new()); + let server_name = ServerName::try_from("example.com").expect("server name"); + + verifier + .verify_server_cert( + &end_entity, + &[], + &server_name, + &[], + UnixTime::since_unix_epoch(Duration::ZERO), + ) + .expect("dangerous verifier should accept any certificate"); + } + + #[test] + fn dangerous_verifier_advertises_all_supported_signature_schemes() { + let schemes = DangerousServerCertVerifier.supported_verify_schemes(); + + assert_eq!( + schemes, + vec![ + SignatureScheme::RSA_PKCS1_SHA1, + SignatureScheme::ECDSA_SHA1_Legacy, + SignatureScheme::RSA_PKCS1_SHA256, + SignatureScheme::ECDSA_NISTP256_SHA256, + SignatureScheme::RSA_PKCS1_SHA384, + SignatureScheme::ECDSA_NISTP384_SHA384, + SignatureScheme::RSA_PKCS1_SHA512, + SignatureScheme::ECDSA_NISTP521_SHA512, + SignatureScheme::RSA_PSS_SHA256, + SignatureScheme::RSA_PSS_SHA384, + SignatureScheme::RSA_PSS_SHA512, + SignatureScheme::ED25519, + SignatureScheme::ED448, + ] + ); + } + + #[test] + fn dangerous_verifier_accepts_tls_signatures_without_validation() { + let verifier = DangerousServerCertVerifier; + let end_entity = CertificateDer::from(Vec::new()); + let signature = digitally_signed_struct_for_test(SignatureScheme::RSA_PSS_SHA256); + + verifier + .verify_tls12_signature(&[], &end_entity, &signature) + .expect("dangerous verifier should accept any TLS 1.2 signature"); + verifier + .verify_tls13_signature(&[], &end_entity, &signature) + .expect("dangerous verifier should accept any TLS 1.3 signature"); + } } diff --git a/src/util/watch.rs b/src/util/watch.rs index 0804eb6..62ee62a 100644 --- a/src/util/watch.rs +++ b/src/util/watch.rs @@ -144,9 +144,9 @@ impl Stream for Watcher { #[cfg(test)] mod tests { - use futures::StreamExt; + use futures::{FutureExt, StreamExt}; - use super::Watch; + use super::{Get, Watch}; #[tokio::test] async fn set_before_watch_observes_current_value_immediately() { @@ -209,4 +209,59 @@ mod tests { assert_eq!(value, 3); } + + #[test] + fn peek_and_set_return_previous_value() { + let watch = Watch::new(); + + assert_eq!(watch.peek(), None); + assert_eq!(watch.set("first"), None); + assert_eq!(watch.peek(), Some("first")); + assert_eq!(watch.set("second"), Some("first")); + assert_eq!(watch.peek(), Some("second")); + } + + #[test] + fn locked_value_get_set_and_replace_update_shared_state() { + let watch = Watch::new(); + + { + let mut value = watch.lock(); + assert_eq!(value.get(), None); + value.set(1); + assert_eq!(value.get(), Some(&1)); + assert_eq!(value.replace(2), Some(1)); + assert_eq!(value.get(), Some(&2)); + } + + assert_eq!(watch.peek(), Some(2)); + } + + #[tokio::test] + async fn watcher_observes_sequential_updates_when_polled_between_sets() { + let watch = Watch::new(); + let watcher = watch.watch(); + let mut watcher = std::pin::pin!(watcher); + + watch.set("first"); + assert_eq!(watcher.as_mut().next().await, Some("first")); + + watch.set("second"); + assert_eq!(watcher.as_mut().next().await, Some("second")); + assert_eq!(watcher.as_mut().next().now_or_never(), None); + } + + #[tokio::test] + async fn get_future_returns_latest_value_after_notification() { + let watch = Watch::new(); + let mut get = Box::pin(Get { + notified: watch.notify.clone().notified_owned(), + state: watch.state.clone(), + }); + + assert_eq!(get.as_mut().now_or_never(), None); + watch.set(5); + + assert_eq!(get.await, Some(5)); + } } diff --git a/src/varint.rs b/src/varint.rs index f62fb53..7aead6a 100644 --- a/src/varint.rs +++ b/src/varint.rs @@ -44,7 +44,7 @@ impl VarInt { /// Construct a `VarInt` from a [`u64`]. /// Succeeds if `x` < 2^62. pub const fn from_u64(value: u64) -> Result { - if value < VARINT_MAX { + if value <= VARINT_MAX { Ok(Self(value)) } else { Err(err::Overflow { value: value as _ }) @@ -63,7 +63,7 @@ impl VarInt { /// Construct a `VarInt` from a [`u128`]. /// Succeeds if `x` < 2^62. pub fn from_u128(value: u128) -> Result { - if value < VARINT_MAX as u128 { + if value <= VARINT_MAX as u128 { Ok(Self(value as _)) } else { Err(err::Overflow { value }) @@ -264,11 +264,11 @@ mod tests { assert!(VarInt::from_u64(63).is_ok()); assert!(VarInt::from_u64(16383).is_ok()); assert!(VarInt::from_u64(VARINT_MAX - 1).is_ok()); + assert!(VarInt::from_u64(VARINT_MAX).is_ok()); } #[test] fn from_u64_overflow() { - assert!(VarInt::from_u64(VARINT_MAX).is_err()); assert!(VarInt::from_u64(VARINT_MAX + 1).is_err()); assert!(VarInt::from_u64(u64::MAX).is_err()); } @@ -318,9 +318,65 @@ mod tests { let _ = VarInt::from(0u16); let _ = VarInt::from(0u32); assert!(VarInt::try_from(0u64).is_ok()); - assert!(VarInt::try_from(VARINT_MAX).is_err()); + assert!(VarInt::try_from(VARINT_MAX).is_ok()); + assert!(VarInt::try_from(VARINT_MAX + 1).is_err()); assert!(VarInt::try_from(0u128).is_ok()); - assert!(VarInt::try_from(VARINT_MAX as u128).is_err()); + assert!(VarInt::try_from(VARINT_MAX as u128).is_ok()); + assert!(VarInt::try_from(VARINT_MAX as u128 + 1).is_err()); + assert!(VarInt::try_from(0usize).is_ok()); + #[cfg(target_pointer_width = "64")] + assert!(VarInt::try_from((VARINT_MAX + 1) as usize).is_err()); + } + + #[test] + fn unchecked_constructor_and_constants_preserve_raw_value() { + // SAFETY: 123 is below the QUIC varint 2^62 bound. + let value = unsafe { VarInt::from_u64_unchecked(123) }; + + assert_eq!(value.into_inner(), 123); + assert_eq!(VarInt::MAX_SIZE, 8); + assert_eq!(VarInt::MAX.into_inner(), VARINT_MAX); + } + + #[test] + fn malformed_unchecked_value_panics_when_sized() { + // SAFETY: This intentionally violates the unsafe constructor contract to verify + // the internal invariant guard for malformed values. + let value = unsafe { VarInt::from_u64_unchecked(VARINT_MAX + 1) }; + + let panic = std::panic::catch_unwind(|| value.encoding_size()) + .expect_err("malformed unchecked value should panic"); + assert!( + panic + .downcast_ref::<&'static str>() + .is_some_and(|message| message.contains("malformed VarInt")) + ); + } + + #[test] + fn overflow_error_reports_original_value() { + let overflow = VARINT_MAX as u128 + 1; + let error = VarInt::from_u128(overflow).expect_err("value above max is rejected"); + + assert_eq!( + error.to_string(), + format!("value({overflow}) too large for varint encoding") + ); + assert!(format!("{error:?}").contains(&overflow.to_string())); + } + + #[cfg(feature = "serde")] + #[test] + fn overflow_error_serializes_as_original_value() { + let overflow = VARINT_MAX as u128 + 1; + let error = VarInt::from_u128(overflow).expect_err("value above max is rejected"); + + let encoded = serde_json::to_value(error).expect("overflow serializes as u128 value"); + assert_eq!(encoded, serde_json::json!(overflow)); + + let decoded: err::Overflow = + serde_json::from_value(encoded).expect("overflow deserializes from u128 value"); + assert_eq!(decoded, error); } async fn encode_decode_round_trip(value: u64) { @@ -359,6 +415,25 @@ mod tests { encode_decode_round_trip(1 << 30).await; encode_decode_round_trip(1 << 40).await; encode_decode_round_trip(VARINT_MAX - 1).await; + encode_decode_round_trip(VARINT_MAX).await; + } + + #[tokio::test] + async fn malformed_unchecked_value_panics_when_encoded() { + // SAFETY: This intentionally violates the unsafe constructor contract to verify + // the internal invariant guard for malformed values. + let value = unsafe { VarInt::from_u64_unchecked(VARINT_MAX + 1) }; + let join = tokio::spawn(async move { + value + .encode_into(Cursor::new(Vec::::new())) + .await + .expect("malformed unchecked value should panic before returning"); + }); + + let error = join + .await + .expect_err("malformed unchecked value should panic"); + assert!(error.is_panic()); } #[test] diff --git a/src/webtransport.rs b/src/webtransport.rs index ab9a4d7..f28602b 100644 --- a/src/webtransport.rs +++ b/src/webtransport.rs @@ -7,7 +7,7 @@ //! //! WebTransport (draft-ietf-webtrans-http3) enables multiplexed, bidirectional //! and unidirectional streams within an HTTP/3 connection. A session is -//! established via Extended CONNECT with `:protocol=webtransport`, after which +//! established via Extended CONNECT with `:protocol=webtransport-h3`, after which //! both peers can open streams associated with that session. //! //! # Stream identification @@ -23,35 +23,49 @@ //! # Usage //! //! ```ignore -//! // Server handler: accept a WebTransport CONNECT request -//! let wt = request.protocols().get::().unwrap(); -//! let session = wt.register(request.stream_id().into_inner())?; -//! -//! // Open / accept streams within the session +//! let (response, connect) = h3x::hyper::extended_connect::accept(request).await?; +//! let session = h3x::webtransport::WebTransportSession::try_from(connect)?; //! let (reader, writer) = session.open_bi().await?; //! let (reader, writer) = session.accept_bi().await?; //! let reader = session.accept_uni().await?; +//! # Ok::<_, Box>(()) //! ``` -use std::future::Future; - use futures::future::BoxFuture; -use crate::{ - codec::{BoxReadStream, BoxWriteStream}, - quic::{ReadStream, WriteStream}, - varint::VarInt, -}; +use crate::quic::{BoxQuicStreamReader, BoxQuicStreamWriter, ReadStream, WriteStream}; +mod close; mod error; mod protocol; +mod registry; mod session; +mod session_id; +mod stream_count; -pub use error::{Closed, DatagramError, OpenSnafu, OpenStreamError, RegisterError}; +pub use close::{ + CloseSession, CloseSessionMessage, CloseSessionMessageTooLong, DecodeCloseSessionError, + TryFromCloseSessionMessageBytesError, TryFromCloseSessionPartsError, +}; +pub use error::{ + AcceptStreamError, CloseReason, CloseSessionError, ControlCommandError, DatagramError, + DrainReason, DrainSessionError, OpenStreamError, RegisterSessionError, SessionCloseReason, + SessionClosed, SessionDrain, SessionDrainReason, +}; +#[cfg(feature = "rpc")] +pub(crate) use error::{accept_stream_error, open_stream_error}; pub use protocol::{ - WT_BIDI_SIGNAL, WT_UNI_SIGNAL, WebTransportProtocol, WebTransportProtocolFactory, + WEBTRANSPORT_BIDI_SIGNAL, WEBTRANSPORT_H3, WEBTRANSPORT_UNI_SIGNAL, WebTransportProtocol, + WebTransportProtocolFactory, +}; +pub use session::{ + WebTransportSession, + stream::{WebTransportStreamReader, WebTransportStreamWriter}, +}; +pub use session_id::{InvalidSessionId, WebTransportSessionId}; +pub use stream_count::{ + DecodeWebTransportStreamCountError, InvalidWebTransportStreamCount, WebTransportStreamCount, }; -pub use session::WebTransportSession; // ============================================================================ // Session trait (AFIT) @@ -65,7 +79,18 @@ pub trait Session: Send + Sync { type StreamReader: ReadStream + Unpin; type StreamWriter: WriteStream + Unpin; - fn session_id(&self) -> VarInt; + fn id(&self) -> WebTransportSessionId; + + fn drain(&self) -> impl Future> + Send + '_; + + fn close( + &self, + close: CloseSession, + ) -> impl Future> + Send + '_; + + fn drained(&self) -> impl Future + Send + '_; + + fn closed(&self) -> impl Future + Send + '_; fn open_bi( &self, @@ -79,9 +104,13 @@ pub trait Session: Send + Sync { fn accept_bi( &self, - ) -> impl Future> + Send + '_; + ) -> impl Future> + + Send + + '_; - fn accept_uni(&self) -> impl Future> + Send + '_; + fn accept_uni( + &self, + ) -> impl Future> + Send + '_; } // ============================================================================ @@ -90,52 +119,90 @@ pub trait Session: Send + Sync { /// Object-safe version of [`Session`] with type-erased streams. /// -/// Stream types are fixed to [`BoxReadStream`] / [`BoxWriteStream`]. +/// Stream types are fixed to [`BoxQuicStreamReader`] / [`BoxQuicStreamWriter`]. /// A blanket impl is provided for all `T: Session`. pub trait DynSession: Send + Sync { - fn session_id(&self) -> VarInt; + fn id(&self) -> WebTransportSessionId; + + fn drain(&self) -> BoxFuture<'_, Result<(), DrainSessionError>>; + + fn close(&self, close: CloseSession) -> BoxFuture<'_, Result<(), CloseSessionError>>; + + fn drained(&self) -> BoxFuture<'_, SessionDrain>; + + fn closed(&self) -> BoxFuture<'_, CloseReason>; #[allow(clippy::type_complexity)] - fn open_bi(&self) -> BoxFuture<'_, Result<(BoxReadStream, BoxWriteStream), OpenStreamError>>; + fn open_bi( + &self, + ) -> BoxFuture<'_, Result<(BoxQuicStreamReader, BoxQuicStreamWriter), OpenStreamError>>; - fn open_uni(&self) -> BoxFuture<'_, Result>; + fn open_uni(&self) -> BoxFuture<'_, Result>; #[allow(clippy::type_complexity)] - fn accept_bi(&self) -> BoxFuture<'_, Result<(BoxReadStream, BoxWriteStream), Closed>>; + fn accept_bi( + &self, + ) -> BoxFuture<'_, Result<(BoxQuicStreamReader, BoxQuicStreamWriter), AcceptStreamError>>; - fn accept_uni(&self) -> BoxFuture<'_, Result>; + fn accept_uni(&self) -> BoxFuture<'_, Result>; } impl DynSession for T { - fn session_id(&self) -> VarInt { - Session::session_id(self) + fn id(&self) -> WebTransportSessionId { + Session::id(self) + } + + fn drain(&self) -> BoxFuture<'_, Result<(), DrainSessionError>> { + Box::pin(Session::drain(self)) + } + + fn close(&self, close: CloseSession) -> BoxFuture<'_, Result<(), CloseSessionError>> { + Box::pin(Session::close(self, close)) + } + + fn drained(&self) -> BoxFuture<'_, SessionDrain> { + Box::pin(Session::drained(self)) + } + + fn closed(&self) -> BoxFuture<'_, CloseReason> { + Box::pin(Session::closed(self)) } - fn open_bi(&self) -> BoxFuture<'_, Result<(BoxReadStream, BoxWriteStream), OpenStreamError>> { + fn open_bi( + &self, + ) -> BoxFuture<'_, Result<(BoxQuicStreamReader, BoxQuicStreamWriter), OpenStreamError>> { Box::pin(async { let (r, w) = Session::open_bi(self).await?; - Ok((Box::pin(r) as BoxReadStream, Box::pin(w) as BoxWriteStream)) + Ok(( + Box::pin(r) as BoxQuicStreamReader, + Box::pin(w) as BoxQuicStreamWriter, + )) }) } - fn open_uni(&self) -> BoxFuture<'_, Result> { + fn open_uni(&self) -> BoxFuture<'_, Result> { Box::pin(async { let w = Session::open_uni(self).await?; - Ok(Box::pin(w) as BoxWriteStream) + Ok(Box::pin(w) as BoxQuicStreamWriter) }) } - fn accept_bi(&self) -> BoxFuture<'_, Result<(BoxReadStream, BoxWriteStream), Closed>> { + fn accept_bi( + &self, + ) -> BoxFuture<'_, Result<(BoxQuicStreamReader, BoxQuicStreamWriter), AcceptStreamError>> { Box::pin(async { let (r, w) = Session::accept_bi(self).await?; - Ok((Box::pin(r) as BoxReadStream, Box::pin(w) as BoxWriteStream)) + Ok(( + Box::pin(r) as BoxQuicStreamReader, + Box::pin(w) as BoxQuicStreamWriter, + )) }) } - fn accept_uni(&self) -> BoxFuture<'_, Result> { + fn accept_uni(&self) -> BoxFuture<'_, Result> { Box::pin(async { let r = Session::accept_uni(self).await?; - Ok(Box::pin(r) as BoxReadStream) + Ok(Box::pin(r) as BoxQuicStreamReader) }) } } @@ -145,161 +212,525 @@ impl DynSession for T { // ============================================================================ impl Session for WebTransportSession { - type StreamReader = BoxReadStream; - type StreamWriter = BoxWriteStream; + type StreamReader = WebTransportStreamReader; + type StreamWriter = WebTransportStreamWriter; - fn session_id(&self) -> VarInt { - WebTransportSession::session_id(self) + fn id(&self) -> WebTransportSessionId { + WebTransportSession::id(self) } - async fn open_bi(&self) -> Result<(BoxReadStream, BoxWriteStream), OpenStreamError> { + async fn drain(&self) -> Result<(), DrainSessionError> { + WebTransportSession::drain(self).await + } + + async fn close(&self, close: CloseSession) -> Result<(), CloseSessionError> { + WebTransportSession::close(self, close).await + } + + async fn drained(&self) -> SessionDrain { + WebTransportSession::drained(self).await + } + + async fn closed(&self) -> CloseReason { + WebTransportSession::closed(self).await + } + + async fn open_bi(&self) -> Result<(Self::StreamReader, Self::StreamWriter), OpenStreamError> { WebTransportSession::open_bi(self).await } - async fn open_uni(&self) -> Result { + async fn open_uni(&self) -> Result { WebTransportSession::open_uni(self).await } - async fn accept_bi(&self) -> Result<(BoxReadStream, BoxWriteStream), Closed> { + async fn accept_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), AcceptStreamError> { WebTransportSession::accept_bi(self).await } - async fn accept_uni(&self) -> Result { + async fn accept_uni(&self) -> Result { WebTransportSession::accept_uni(self).await } } -// ============================================================================ -// Error-latching extensions for RPC/IPC -// ============================================================================ +#[cfg(test)] +mod tests { + use std::{future, sync::Arc}; -#[cfg(feature = "rpc")] -mod lifecycle_ext { - use std::future::Future; + use bytes::Bytes; + use futures::{SinkExt, StreamExt}; - use snafu::ResultExt; - - use super::{Closed, OpenSnafu, OpenStreamError}; + use super::*; use crate::{ - quic::{self, ConnectionError}, - rpc::lifecycle::LifecycleExt, + connection::{ConnectionState, tests::MockConnection}, + dhttp::{ + message::{MessageWriter, test::read_stream_for_test}, + protocol::DHttpProtocol, + settings::Settings, + webtransport::settings::{ + EnableWebTransport, InitialMaxData, InitialMaxStreamsBidi, InitialMaxStreamsUni, + }, + }, + extended_connect::{EstablishedConnect, PendingWriteStreamError}, + protocol::Protocols, + qpack::field::Protocol, + quic::{self, GetStreamIdExt}, + stream_id::StreamId, + varint::VarInt, }; - /// WebTransport-flavoured extension of [`LifecycleExt`]. - /// - /// Adds check/guard helpers that surface [`OpenStreamError`] and - /// [`Closed`] instead of the raw [`ConnectionError`], while preserving - /// the lazy first-wins latching discipline. - /// - /// Like [`LifecycleExt`], this trait is sealed: it is automatically - /// implemented for any type that already satisfies [`LifecycleExt`]. - #[allow(async_fn_in_trait)] - pub trait WtLifecycleExt: LifecycleExt { - /// Check liveness and surface any error as an [`OpenStreamError`]. - fn check_open(&self) -> Result<(), OpenStreamError> { - quic::Lifecycle::check(self).context(OpenSnafu) + #[derive(Debug)] + struct TestSession { + id: WebTransportSessionId, + drained: std::sync::Mutex, + closed: std::sync::Mutex>, + } + + #[derive(Debug)] + struct FailingSession { + id: WebTransportSessionId, + } + + impl TestSession { + fn new(id: WebTransportSessionId) -> Self { + Self { + id, + drained: std::sync::Mutex::new(false), + closed: std::sync::Mutex::new(None), + } } + } + + fn wt_session_id(id: u32) -> WebTransportSessionId { + WebTransportSessionId::try_from(StreamId(VarInt::from_u32(id))) + .expect("test id must be a valid webtransport session id") + } + + fn boxed_stream_pair(stream_id: u32) -> (BoxQuicStreamReader, BoxQuicStreamWriter) { + let (reader, writer) = quic::test::mock_stream_pair(VarInt::from_u32(stream_id)); + ( + Box::pin(reader) as BoxQuicStreamReader, + Box::pin(writer) as BoxQuicStreamWriter, + ) + } + + fn connection_with_webtransport( + mock: Arc, + ) -> Arc> { + let erased: Arc = mock.clone(); + let mut protocols = Protocols::new(); + let dhttp = DHttpProtocol::new_for_test(erased.clone()); + dhttp + .state + .peer_settings + .set(Arc::new(enabled_webtransport_settings())) + .expect("peer settings should be set once"); + protocols.insert(dhttp); + protocols.insert(WebTransportProtocol::new_for_test(erased)); + Arc::new(ConnectionState::new_for_test(mock, Arc::new(protocols)).erase()) + } + + fn enabled_webtransport_settings() -> Settings { + let mut settings = Settings::default(); + settings.set(EnableWebTransport::setting(true)); + settings.set(InitialMaxStreamsBidi::setting(VarInt::from_u32(16))); + settings.set(InitialMaxStreamsUni::setting(VarInt::from_u32(16))); + settings.set(InitialMaxData::setting(VarInt::MAX)); + settings + } - /// Check liveness and flatten any error into a [`Closed`]. - fn check_accept(&self) -> Result<(), Closed> { - quic::Lifecycle::check(self).map_err(|_| Closed) + fn webtransport_session_for_test( + mock: Arc, + stream_id: StreamId, + ) -> WebTransportSession { + WebTransportSession::try_from(EstablishedConnect::pending( + stream_id, + Some(Protocol::new(WEBTRANSPORT_H3)), + connection_with_webtransport(mock), + read_stream_for_test(stream_id.0), + future::pending::>(), + )) + .expect("webtransport session should be registered") + } + + fn connection_error(reason: &'static str) -> quic::ConnectionError { + quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(0x01), + frame_type: VarInt::from_u32(0x00), + reason: reason.into(), + }, } + } - /// Guard an async open operation whose error is already an - /// [`OpenStreamError`]. - /// - /// If the error wraps a [`ConnectionError`] (the `Open` variant), it - /// is latched lazily so the first such error becomes canonical; other - /// variants pass through untouched. - async fn guard_open( - &self, - fut: impl Future>, - ) -> Result { - self.check_open()?; - match fut.await { - Ok(v) => Ok(v), - Err(OpenStreamError::Open { source }) => Err(OpenStreamError::Open { - source: self.latch().latch_with(|| source), - }), - Err(other) => Err(other), + fn assert_transport_reason(error: &quic::ConnectionError, expected_reason: &str) { + match error { + quic::ConnectionError::Transport { source } => { + assert_eq!(source.reason.as_ref(), expected_reason); } + other => panic!("expected transport error, got {other:?}"), } + } - /// Guard an async open operation whose error must be lazily converted - /// to an [`OpenStreamError`]. - /// - /// `map_err` is invoked only when the operation errored **and** no - /// error has been latched yet. If the resulting - /// [`OpenStreamError::Open`] carries a connection error, it is - /// substituted with the already-latched value (first wins). - async fn guard_open_with( - &self, - fut: impl Future>, - map_err: M, - ) -> Result - where - M: FnOnce(E) -> OpenStreamError, - { - self.check_open()?; - match fut.await { - Ok(v) => Ok(v), - Err(e) => { - if let Some(existing) = self.latch().peek() { - return Err(OpenStreamError::Open { source: existing }); - } - Err(match map_err(e) { - OpenStreamError::Open { source } => OpenStreamError::Open { - source: self.latch().latch_with(|| source), - }, - other => other, - }) - } + impl Session for TestSession { + type StreamReader = BoxQuicStreamReader; + type StreamWriter = BoxQuicStreamWriter; + + fn id(&self) -> WebTransportSessionId { + self.id + } + + async fn drain(&self) -> Result<(), DrainSessionError> { + *self + .drained + .lock() + .expect("drained mutex should not poison") = true; + Ok(()) + } + + async fn close(&self, close: CloseSession) -> Result<(), CloseSessionError> { + *self.closed.lock().expect("closed mutex should not poison") = Some(close); + Ok(()) + } + + async fn drained(&self) -> SessionDrain { + if *self + .drained + .lock() + .expect("drained mutex should not poison") + { + SessionDrain::Requested(DrainReason::Session(SessionDrainReason::Local)) + } else { + SessionDrain::Closed(CloseReason::Session(SessionCloseReason::ControlStreamError)) } } - /// Guard an async accept operation whose error is already a - /// [`Closed`]. - async fn guard_accept( + async fn closed(&self) -> CloseReason { + match self + .closed + .lock() + .expect("closed mutex should not poison") + .clone() + { + Some(close) => CloseReason::Session(SessionCloseReason::Local(close)), + None => CloseReason::Session(SessionCloseReason::ControlStreamError), + } + } + + async fn open_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), OpenStreamError> { + Ok(boxed_stream_pair(1)) + } + + async fn open_uni(&self) -> Result { + let (_reader, writer) = boxed_stream_pair(2); + Ok(writer) + } + + async fn accept_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), AcceptStreamError> { + Ok(boxed_stream_pair(3)) + } + + async fn accept_uni(&self) -> Result { + let (reader, _writer) = boxed_stream_pair(4); + Ok(reader) + } + } + + impl Session for FailingSession { + type StreamReader = BoxQuicStreamReader; + type StreamWriter = BoxQuicStreamWriter; + + fn id(&self) -> WebTransportSessionId { + self.id + } + + async fn drain(&self) -> Result<(), DrainSessionError> { + Err(DrainSessionError::Closed { + source: SessionClosed, + }) + } + + async fn close(&self, _close: CloseSession) -> Result<(), CloseSessionError> { + Err(CloseSessionError::Closed { + source: SessionClosed, + }) + } + + async fn drained(&self) -> SessionDrain { + SessionDrain::Closed(CloseReason::Session(SessionCloseReason::ControlStreamError)) + } + + async fn closed(&self) -> CloseReason { + CloseReason::Session(SessionCloseReason::ControlStreamError) + } + + async fn open_bi( &self, - fut: impl Future>, - ) -> Result { - self.check_accept()?; - fut.await + ) -> Result<(Self::StreamReader, Self::StreamWriter), OpenStreamError> { + Err(OpenStreamError::Closed { + source: SessionClosed, + }) } - /// Guard an async accept operation whose error carries richer - /// information than [`Closed`]. - /// - /// `map_err` is invoked only when the operation errored **and** no - /// error has been latched yet. Returning `Some(error)` from it will - /// lazily install that error in the latch so later observers (on the - /// connection path) see a meaningful terminal cause; the caller - /// always sees a plain [`Closed`]. - async fn guard_accept_err( + async fn open_uni(&self) -> Result { + Err(OpenStreamError::Closed { + source: SessionClosed, + }) + } + + async fn accept_bi( &self, - fut: impl Future>, - map_err: M, - ) -> Result - where - M: FnOnce(E) -> Option, - { - self.check_accept()?; - match fut.await { - Ok(v) => Ok(v), - Err(e) => { - if self.latch().peek().is_none() - && let Some(error) = map_err(e) - { - let _ = self.latch().latch_with(|| error); - } - Err(Closed) - } + ) -> Result<(Self::StreamReader, Self::StreamWriter), AcceptStreamError> { + Err(AcceptStreamError::Closed { + source: SessionClosed, + }) + } + + async fn accept_uni(&self) -> Result { + Err(AcceptStreamError::Closed { + source: SessionClosed, + }) + } + } + + async fn assert_roundtrip( + reader: &mut BoxQuicStreamReader, + writer: &mut BoxQuicStreamWriter, + payload: &'static [u8], + ) { + let bytes = Bytes::from_static(payload); + writer + .send(bytes.clone()) + .await + .expect("stream write should succeed"); + let received = reader + .next() + .await + .expect("reader should receive one chunk") + .expect("stream read should succeed"); + + assert_eq!(received, bytes); + } + + #[tokio::test] + async fn dyn_session_delegates_all_stream_operations() { + let session = TestSession::new(wt_session_id(40)); + let dyn_session: &dyn DynSession = &session; + + assert_eq!(dyn_session.id(), wt_session_id(40)); + + let (mut reader, mut writer) = dyn_session.open_bi().await.expect("open_bi should succeed"); + assert_eq!(reader.stream_id().await.unwrap(), VarInt::from_u32(1)); + assert_eq!(writer.stream_id().await.unwrap(), VarInt::from_u32(1)); + assert_roundtrip(&mut reader, &mut writer, b"open-bidi").await; + + let mut writer = dyn_session + .open_uni() + .await + .expect("open_uni should succeed"); + assert_eq!(writer.stream_id().await.unwrap(), VarInt::from_u32(2)); + + let (mut reader, mut writer) = dyn_session + .accept_bi() + .await + .expect("accept_bi should succeed"); + assert_eq!(reader.stream_id().await.unwrap(), VarInt::from_u32(3)); + assert_eq!(writer.stream_id().await.unwrap(), VarInt::from_u32(3)); + assert_roundtrip(&mut reader, &mut writer, b"accept-bidi").await; + + let mut reader = dyn_session + .accept_uni() + .await + .expect("accept_uni should succeed"); + assert_eq!(reader.stream_id().await.unwrap(), VarInt::from_u32(4)); + } + + #[tokio::test] + async fn dyn_session_preserves_operation_errors() { + let session = FailingSession { + id: wt_session_id(4), + }; + let dyn_session: &dyn DynSession = &session; + + assert_eq!(dyn_session.id(), wt_session_id(4)); + assert!(matches!( + dyn_session.open_bi().await, + Err(OpenStreamError::Closed { .. }) + )); + assert!(matches!( + dyn_session.open_uni().await, + Err(OpenStreamError::Closed { .. }) + )); + assert!(matches!( + dyn_session.accept_bi().await, + Err(AcceptStreamError::Closed { .. }) + )); + assert!(matches!( + dyn_session.accept_uni().await, + Err(AcceptStreamError::Closed { .. }) + )); + } + + #[tokio::test] + async fn dyn_session_bridge_uses_webtransport_session_open_impls() { + let mock = Arc::new(MockConnection::new()); + mock.enable_stream_ops(); + let session_id = StreamId(VarInt::from_u32(44)); + let session = webtransport_session_for_test(Arc::clone(&mock), session_id); + let dyn_session: &dyn DynSession = &session; + + assert_eq!( + dyn_session.id(), + WebTransportSessionId::try_from(session_id).expect("valid session id") + ); + + let (mut reader, mut writer) = dyn_session.open_bi().await.expect("open_bi should work"); + assert_eq!(reader.stream_id().await.unwrap(), VarInt::from_u32(0)); + assert_eq!(writer.stream_id().await.unwrap(), VarInt::from_u32(0)); + + let mut writer = dyn_session.open_uni().await.expect("open_uni should work"); + assert_eq!(writer.stream_id().await.unwrap(), VarInt::from_u32(0)); + + assert_eq!(mock.stream_calls(), vec!["open_bi", "open_uni"]); + } + + #[tokio::test] + async fn dyn_session_delegates_drain_close_drained_and_closed() { + let session = TestSession::new(wt_session_id(44)); + let dyn_session: &dyn DynSession = &session; + let close = CloseSession::try_from((5_u32, "bye")).expect("valid close"); + + dyn_session.drain().await.expect("dyn drain succeeds"); + assert_eq!( + dyn_session.drained().await, + SessionDrain::Requested(DrainReason::Session(SessionDrainReason::Local)) + ); + + dyn_session + .close(close.clone()) + .await + .expect("dyn close succeeds"); + assert_eq!( + dyn_session.closed().await, + CloseReason::Session(SessionCloseReason::Local(close)) + ); + } + + #[tokio::test] + async fn dyn_session_bridge_uses_webtransport_session_accept_impls() { + let mock = Arc::new(MockConnection::new()); + mock.set_terminal_error(connection_error("bidi closed")); + let session = + webtransport_session_for_test(Arc::clone(&mock), StreamId(VarInt::from_u32(44))); + let dyn_session: &dyn DynSession = &session; + + match dyn_session.accept_bi().await { + Err(AcceptStreamError::Connection { source }) => { + assert_transport_reason(&source, "bidi closed"); + } + Err(_) => panic!("expected connection error"), + Ok(_) => panic!("accept_bi should fail"), + } + + let mock = Arc::new(MockConnection::new()); + mock.set_terminal_error(connection_error("uni closed")); + let session = + webtransport_session_for_test(Arc::clone(&mock), StreamId(VarInt::from_u32(48))); + let dyn_session: &dyn DynSession = &session; + + match dyn_session.accept_uni().await { + Err(AcceptStreamError::Connection { source }) => { + assert_transport_reason(&source, "uni closed"); } + Err(_) => panic!("expected connection error"), + Ok(_) => panic!("accept_uni should fail"), } } - impl WtLifecycleExt for T {} -} + #[cfg(feature = "rpc")] + mod lifecycle_ext_tests { + use std::{borrow::Cow, future::pending}; + + use super::*; + use crate::{ + error::Code, + rpc::{ + lifecycle::{ + ConnectionErrorLatch, HasLatch, LifecycleExt as ConnectionLifecycleExt, + }, + webtransport::LifecycleExt as _, + }, + }; + + #[derive(Debug, Default)] + struct TestLifecycle { + latch: ConnectionErrorLatch, + } -#[cfg(feature = "rpc")] -pub use lifecycle_ext::WtLifecycleExt; + impl HasLatch for TestLifecycle { + fn latch(&self) -> &ConnectionErrorLatch { + &self.latch + } + } + + impl quic::Lifecycle for TestLifecycle { + fn close(&self, _code: Code, _reason: Cow<'static, str>) {} + + fn check(&self) -> Result<(), quic::ConnectionError> { + self.check_with_probe(|| None) + } + + async fn closed(&self) -> quic::ConnectionError { + self.resolve_closed(pending()).await + } + } + + fn connection_error(reason: &'static str) -> quic::ConnectionError { + quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(0x01), + frame_type: VarInt::from_u32(0x00), + reason: reason.into(), + }, + } + } + + #[tokio::test] + async fn module_scoped_lifecycle_ext_path_exposes_helper_set() { + let lifecycle = TestLifecycle::default(); + + lifecycle.check_open().expect("open check should pass"); + lifecycle.check_accept().expect("accept check should pass"); + + lifecycle + .guard_open(async { Ok::<_, OpenStreamError>(()) }) + .await + .expect("guard_open should be available on the module path"); + + let lifecycle = TestLifecycle::default(); + lifecycle + .guard_accept(async { Ok::<_, AcceptStreamError>(()) }) + .await + .expect("guard_accept should be available on the module path"); + + let lifecycle = TestLifecycle::default(); + lifecycle + .guard_accept_err(async { Err::<(), _>("closed") }, |_| None) + .await + .expect_err("guard_accept_err should be available on the module path"); + + let lifecycle = TestLifecycle::default(); + lifecycle + .guard_open_with(async { Err::<(), _>("open") }, |_| OpenStreamError::Open { + source: connection_error("open"), + }) + .await + .expect_err("guard_open_with should be available on the old path"); + } + } +} diff --git a/src/webtransport/close.rs b/src/webtransport/close.rs new file mode 100644 index 0000000..0f3c6b9 --- /dev/null +++ b/src/webtransport/close.rs @@ -0,0 +1,249 @@ +use std::{convert::Infallible, error::Error as StdError, io, string::FromUtf8Error}; + +use bytes::Bytes; +use snafu::{ResultExt, Snafu}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use crate::{ + buflist::BufList, + codec::{DecodeFrom, EncodeExt, EncodeInto}, +}; + +const CLOSE_SESSION_MESSAGE_MAX_LEN: usize = 1024; + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CloseSession { + application_error_code: u32, + message: CloseSessionMessage, +} + +impl CloseSession { + pub const fn new(application_error_code: u32, message: CloseSessionMessage) -> Self { + Self { + application_error_code, + message, + } + } + + pub const fn application_error_code(&self) -> u32 { + self.application_error_code + } + + pub const fn message(&self) -> &CloseSessionMessage { + &self.message + } + + pub fn try_from_parts( + application_error_code: C, + message: M, + ) -> Result> + where + C: TryInto, + C::Error: StdError + Send + Sync + 'static, + M: TryInto, + M::Error: StdError + Send + Sync + 'static, + { + let application_error_code = application_error_code + .try_into() + .context(try_from_close_session_parts_error::ApplicationErrorCodeSnafu)?; + let message = message + .try_into() + .context(try_from_close_session_parts_error::MessageSnafu)?; + Ok(Self { + application_error_code, + message, + }) + } +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CloseSessionMessage(String); + +impl CloseSessionMessage { + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl TryFrom for CloseSessionMessage { + type Error = CloseSessionMessageTooLong; + + fn try_from(message: String) -> Result { + let len = message.len(); + if len > CLOSE_SESSION_MESSAGE_MAX_LEN { + Err(CloseSessionMessageTooLong { len }) + } else { + Ok(Self(message)) + } + } +} + +impl TryFrom<&str> for CloseSessionMessage { + type Error = CloseSessionMessageTooLong; + + fn try_from(message: &str) -> Result { + Self::try_from(message.to_owned()) + } +} + +impl TryFrom for CloseSessionMessage { + type Error = TryFromCloseSessionMessageBytesError; + + fn try_from(message: Bytes) -> Result { + let message = String::from_utf8(message.to_vec()) + .context(try_from_close_session_message_bytes_error::Utf8Snafu)?; + Ok(Self::try_from(message)?) + } +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Snafu, Clone, Copy, PartialEq, Eq)] +#[snafu(display("webtransport close session message is {len} bytes, exceeding 1024 bytes"))] +pub struct CloseSessionMessageTooLong { + len: usize, +} + +impl CloseSessionMessageTooLong { + pub const fn len(&self) -> usize { + self.len + } + + pub const fn is_empty(&self) -> bool { + self.len == 0 + } +} + +#[derive(Debug, Snafu)] +#[snafu( + module(try_from_close_session_message_bytes_error), + visibility(pub(super)) +)] +pub enum TryFromCloseSessionMessageBytesError { + #[snafu(display("webtransport close session message is not utf-8"))] + Utf8 { source: FromUtf8Error }, + #[snafu(transparent)] + TooLong { source: CloseSessionMessageTooLong }, +} + +#[derive(Debug, Snafu)] +#[snafu(module(try_from_close_session_parts_error), visibility(pub(super)))] +pub enum TryFromCloseSessionPartsError +where + C: StdError + Send + Sync + 'static, + M: StdError + Send + Sync + 'static, +{ + #[snafu(display("invalid webtransport close session application error code"))] + ApplicationErrorCode { source: C }, + #[snafu(display("invalid webtransport close session message"))] + Message { source: M }, +} + +impl TryFrom<(u32, String)> for CloseSession { + type Error = TryFromCloseSessionPartsError; + + fn try_from((application_error_code, message): (u32, String)) -> Result { + Self::try_from_parts(application_error_code, message) + } +} + +impl<'m> TryFrom<(u32, &'m str)> for CloseSession { + type Error = TryFromCloseSessionPartsError; + + fn try_from((application_error_code, message): (u32, &'m str)) -> Result { + Self::try_from_parts(application_error_code, message) + } +} + +impl TryFrom<(u32, Bytes)> for CloseSession { + type Error = TryFromCloseSessionPartsError; + + fn try_from((application_error_code, message): (u32, Bytes)) -> Result { + Self::try_from_parts(application_error_code, message) + } +} + +#[derive(Debug, Snafu)] +#[snafu(module(decode_close_session_error), visibility(pub(super)))] +pub enum DecodeCloseSessionError { + #[snafu(display("failed to decode webtransport close session application error code"))] + ApplicationErrorCode { source: io::Error }, + #[snafu(display("failed to decode webtransport close session message"))] + Message { source: io::Error }, + #[snafu(display("invalid webtransport close session message"))] + InvalidMessage { + source: TryFromCloseSessionMessageBytesError, + }, +} + +impl DecodeFrom for CloseSession +where + S: AsyncRead + Unpin + Send, +{ + type Error = DecodeCloseSessionError; + + async fn decode_from(mut stream: S) -> Result { + let application_error_code = stream + .read_u32() + .await + .context(decode_close_session_error::ApplicationErrorCodeSnafu)?; + + let mut message = Vec::new(); + stream + .take((CLOSE_SESSION_MESSAGE_MAX_LEN + 1) as u64) + .read_to_end(&mut message) + .await + .context(decode_close_session_error::MessageSnafu)?; + let message = CloseSessionMessage::try_from(Bytes::from(message)) + .context(decode_close_session_error::InvalidMessageSnafu)?; + + Ok(Self::new(application_error_code, message)) + } +} + +impl<'s, S> EncodeInto<&'s mut S> for CloseSession +where + S: AsyncWrite + Unpin + Send, +{ + type Output = (); + type Error = io::Error; + + async fn encode_into(self, stream: &'s mut S) -> Result { + stream.write_u32(self.application_error_code()).await?; + stream.write_all(self.message().as_str().as_bytes()).await?; + Ok(()) + } +} + +impl EncodeInto for CloseSession { + type Output = BufList; + type Error = Infallible; + + async fn encode_into(self, mut stream: BufList) -> Result { + stream + .encode_one(self) + .await + .expect("encoding CloseSession into a buflist is infallible"); + Ok(stream) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn close_session_message_rejects_message_above_1024_bytes() { + let message = "x".repeat(1025); + let error = CloseSessionMessage::try_from(message).expect_err("too long"); + assert_eq!(error.len(), 1025); + } + + #[test] + fn close_session_try_from_tuple_preserves_parts() { + let close = CloseSession::try_from((7_u32, "done")).expect("valid close"); + assert_eq!(close.application_error_code(), 7); + assert_eq!(close.message().as_str(), "done"); + } +} diff --git a/src/webtransport/error.rs b/src/webtransport/error.rs index abc2939..51a6b77 100644 --- a/src/webtransport/error.rs +++ b/src/webtransport/error.rs @@ -1,59 +1,167 @@ -//! Error types for the WebTransport protocol layer. - use snafu::Snafu; +use super::{ + CloseSession, InvalidSessionId, InvalidWebTransportStreamCount, WebTransportSessionId, +}; use crate::{ + dhttp::message::MessageStreamError, + qpack::field::Protocol, quic::{ConnectionError, StreamError}, - stream_id::StreamId, }; -// ============================================================================ -// Registration errors -// ============================================================================ - -/// Errors from [`WebTransportProtocol::register`](super::WebTransportProtocol::register). #[derive(Debug, Snafu)] -#[snafu(visibility(pub(super)))] -pub enum RegisterError { +#[snafu(module, visibility(pub(in crate::webtransport)))] +pub enum RegisterSessionError { + #[snafu(display("extended connect is missing a protocol token"))] + MissingProtocol, + #[snafu(display("extended connect protocol {protocol:?} is not webtransport-h3"))] + UnexpectedProtocol { protocol: Protocol }, + #[snafu(display("webtransport protocol layer is not registered on the connection"))] + ProtocolLayerMissing, + #[snafu(display("peer HTTP/3 settings are not available"))] + PeerSettingsUnavailable, + #[snafu(display("webtransport is not enabled by peer settings"))] + WebTransportNotEnabled, + #[snafu(display("webtransport stream-count flow control is not enabled by peer settings"))] + FlowControlNotEnabled, + #[snafu(display("invalid peer webtransport initial stream count"))] + InitialStreamCount { + source: InvalidWebTransportStreamCount, + }, + #[snafu(display("invalid webtransport session id"))] + InvalidSessionId { source: InvalidSessionId }, #[snafu(display("session already registered for {session_id}"))] - AlreadyRegistered { session_id: StreamId }, - + AlreadyRegistered { session_id: WebTransportSessionId }, #[snafu(display("session registry lock poisoned"))] RegistryPoisoned, } -// ============================================================================ -// Stream-opening errors -// ============================================================================ +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Snafu)] +#[snafu(display("webtransport session closed"))] +pub struct SessionClosed; + +#[derive(Debug, Snafu, Clone, Copy, PartialEq, Eq)] +#[snafu(module, visibility(pub(in crate::webtransport)))] +pub enum SessionFlowControlError { + #[snafu(display("peer exceeded webtransport stream credit"))] + ExceededStreamCredit, + #[snafu(display("peer decreased webtransport max streams"))] + DecreasingMaxStreams, + #[snafu(display("webtransport stream queue capacity invariant failed"))] + QueueCapacityInvariant, + #[snafu(display("webtransport stream count overflow"))] + StreamCount { + source: InvalidWebTransportStreamCount, + }, +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, PartialEq)] +pub enum SessionDrain { + Requested(DrainReason), + Closed(CloseReason), +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DrainReason { + Session(SessionDrainReason), + HttpGoaway(crate::connection::ConnectionGoaway), +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SessionDrainReason { + Local, + Remote, +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone)] +pub enum CloseReason { + Session(SessionCloseReason), + Connection(crate::quic::ConnectionError), +} + +impl PartialEq for CloseReason { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Session(left), Self::Session(right)) => left == right, + _ => false, + } + } +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SessionCloseReason { + Local(CloseSession), + Remote(CloseSession), + Protocol { code: crate::error::Code }, + ControlStreamError, +} -/// Errors from [`WebTransportSession::open_bi`](super::WebTransportSession::open_bi) -/// and [`WebTransportSession::open_uni`](super::WebTransportSession::open_uni). #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[derive(Debug, Clone, Snafu)] -#[snafu(visibility(pub))] +#[snafu(module, visibility(pub))] +pub enum DrainSessionError { + #[snafu(display("webtransport session closed"))] + Closed { source: SessionClosed }, + #[snafu(display("failed to send webtransport drain command"))] + Command { source: ControlCommandError }, +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Snafu)] +#[snafu(module, visibility(pub))] +pub enum CloseSessionError { + #[snafu(display("webtransport session closed"))] + Closed { source: SessionClosed }, + #[snafu(display("failed to send webtransport close command"))] + Command { source: ControlCommandError }, +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Snafu)] +#[snafu(module, visibility(pub))] +pub enum ControlCommandError { + #[snafu(display("webtransport control task is closed"))] + Closed, + #[snafu(display("webtransport control task dropped response"))] + ResponseDropped, + #[snafu(display("failed to write webtransport control capsule"))] + Write { source: MessageStreamError }, +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Snafu)] +#[snafu(module, visibility(pub))] pub enum OpenStreamError { + #[snafu(display("webtransport session closed"))] + Closed { source: SessionClosed }, #[snafu(display("failed to open QUIC stream"))] Open { source: ConnectionError }, - + #[snafu(display("failed to observe opened QUIC stream id"))] + StreamId { source: StreamError }, #[snafu(display("failed to write stream routing header"))] WriteHeader { source: StreamError }, + #[snafu(display("failed to send webtransport stream credit command"))] + Control { source: ControlCommandError }, } -// ============================================================================ -// Session-closed error -// ============================================================================ - -/// The WebTransport session has been closed and no more streams will arrive. #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[derive(Debug, Clone, Snafu)] -#[snafu(display("webtransport session closed"))] -pub struct Closed; - -// ============================================================================ -// Datagram errors -// ============================================================================ +#[snafu(module, visibility(pub))] +pub enum AcceptStreamError { + #[snafu(display("webtransport session closed"))] + Closed { source: SessionClosed }, + #[snafu(display("webtransport connection closed"))] + Connection { source: ConnectionError }, + #[snafu(display("failed to observe accepted QUIC stream id"))] + StreamId { source: StreamError }, +} -/// Errors from datagram operations on [`WebTransportSession`](super::WebTransportSession). #[derive(Debug, Snafu)] #[snafu(visibility(pub(super)))] pub enum DatagramError { diff --git a/src/webtransport/protocol.rs b/src/webtransport/protocol.rs index 604dcab..7bc784f 100644 --- a/src/webtransport/protocol.rs +++ b/src/webtransport/protocol.rs @@ -10,27 +10,28 @@ //! The protocol layer consumes exactly two fields from each incoming stream: //! //! 1. The stream signal value ([`VarInt`]) -//! 2. The session ID ([`VarInt`]) +//! 2. The session ID ([`StreamId`](crate::stream_id::StreamId)) //! //! The session receives the stream positioned after these two fields. -use std::{collections::HashMap, fmt, sync::Arc}; +use std::{fmt, sync::Arc}; use futures::future::BoxFuture; -use snafu::ensure; -use tokio::sync::mpsc; use super::{ - error::{AlreadyRegisteredSnafu, RegisterError}, - session::{RoutedBiStream, RoutedUniStream, WebTransportSession}, + WebTransportSessionId, + error::RegisterSessionError, + registry::{RegisteredSession, Registry, RouteBiError, RouteUniError}, }; use crate::{ - codec::{ - BoxReadStream, BoxWriteStream, DecodeExt, ErasedPeekableBiStream, ErasedPeekableUniStream, - }, + codec::{BoxPeekableStreamReader, BoxStreamWriter, DecodeExt}, connection::StreamError, protocol::{ProductProtocol, Protocol, Protocols, StreamVerdict}, - quic::{self, ConnectionError}, + quic::{ + self, BoxQuicStreamReader, BoxQuicStreamWriter, ConnectionError, ResetStreamExt, + StopStreamExt, + }, + stream_id::StreamId, varint::VarInt, }; @@ -38,50 +39,16 @@ use crate::{ // Constants // ============================================================================ +/// Extended CONNECT protocol token for WebTransport over HTTP/3. +pub const WEBTRANSPORT_H3: &str = "webtransport-h3"; + /// Signal value for WebTransport bidirectional streams /// (draft-ietf-webtrans-http3, §2). -pub const WT_BIDI_SIGNAL: VarInt = VarInt::from_u32(0x41); +pub const WEBTRANSPORT_BIDI_SIGNAL: VarInt = VarInt::from_u32(0x41); /// Signal value for WebTransport unidirectional streams /// (draft-ietf-webtrans-http3, §2). -pub const WT_UNI_SIGNAL: VarInt = VarInt::from_u32(0x54); - -// ============================================================================ -// Session stream router -// ============================================================================ - -/// Session registry: maps session ID to stream routers. -pub(super) type Registry = Arc>>; - -/// Routes incoming streams to a single WebTransport session via bounded channels. -pub(super) struct SessionStreamRouter { - bidi_tx: mpsc::Sender, - uni_tx: mpsc::Sender, -} - -impl SessionStreamRouter { - fn new(bidi_tx: mpsc::Sender, uni_tx: mpsc::Sender) -> Self { - Self { bidi_tx, uni_tx } - } - - fn route_bi(&self, session_id: VarInt, stream: RoutedBiStream) { - if self.bidi_tx.try_send(stream).is_err() { - tracing::debug!( - ?session_id, - "session bidi channel full or closed, dropping stream" - ); - } - } - - fn route_uni(&self, session_id: VarInt, stream: RoutedUniStream) { - if self.uni_tx.try_send(stream).is_err() { - tracing::debug!( - ?session_id, - "session uni channel full or closed, dropping stream" - ); - } - } -} +pub const WEBTRANSPORT_UNI_SIGNAL: VarInt = VarInt::from_u32(0x54); // ============================================================================ // WebTransportProtocol @@ -96,50 +63,35 @@ impl SessionStreamRouter { /// shared (via `Arc`) across all concurrent streams. pub struct WebTransportProtocol { registry: Registry, - conn: Arc, + conn: Arc, } impl fmt::Debug for WebTransportProtocol { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("WebTransportProtocol") - .field( - "sessions", - &self.registry.lock().map(|r| r.len()).unwrap_or(0), - ) + .field("sessions", &self.registry.len()) .finish() } } impl WebTransportProtocol { - /// Register a new session for the given session ID. - /// - /// Returns a [`WebTransportSession`] that receives routed streams and can - /// open new streams. The session unregisters itself when dropped. - pub fn register(&self, session_id: VarInt) -> Result { - let (bidi_tx, bidi_rx) = mpsc::channel(16); - let (uni_tx, uni_rx) = mpsc::channel(16); - - let mut registry = self - .registry - .lock() - .map_err(|_| RegisterError::RegistryPoisoned)?; - - ensure!( - !registry.contains_key(&session_id), - AlreadyRegisteredSnafu { - session_id: crate::stream_id::StreamId::from(session_id), - } - ); + pub(super) fn register( + &self, + session_id: WebTransportSessionId, + ) -> Result { + self.registry.register(session_id) + } - registry.insert(session_id, SessionStreamRouter::new(bidi_tx, uni_tx)); + pub(super) fn connection(&self) -> Arc { + Arc::clone(&self.conn) + } - Ok(WebTransportSession::new( - session_id, - bidi_rx, - uni_rx, - Arc::clone(&self.conn), - Arc::clone(&self.registry), - )) + #[cfg(test)] + pub(crate) fn new_for_test(conn: Arc) -> Self { + Self { + registry: Registry::default(), + conn, + } } } @@ -150,38 +102,56 @@ impl WebTransportProtocol { impl WebTransportProtocol { async fn accept_bi_inner( &self, - (mut reader, writer): ErasedPeekableBiStream, - ) -> Result, StreamError> { + (mut reader, writer): (BoxPeekableStreamReader, BoxStreamWriter), + ) -> Result, StreamError> { let Ok(signal_value) = reader.decode_one::().await else { return Ok(StreamVerdict::Passed((reader, writer))); }; - if signal_value != WT_BIDI_SIGNAL { + if signal_value != WEBTRANSPORT_BIDI_SIGNAL { return Ok(StreamVerdict::Passed((reader, writer))); } // WebTransport bidi stream confirmed. Decode session ID for routing. - let Ok(session_id) = reader.decode_one::().await else { - tracing::debug!("failed to decode session ID from webtransport bidi stream"); + let Ok(raw_session_id) = reader.decode_one::().await else { + tracing::debug!("failed to decode session id from webtransport bidi stream"); return Ok(StreamVerdict::Accepted); }; + let session_id = WebTransportSessionId::try_from(raw_session_id) + .map_err(crate::connection::ConnectionError::from)?; - tracing::debug!(?session_id, "routing webtransport bidi stream to session"); + tracing::debug!(session_id = %session_id, "routing webtransport bidi stream to session"); - let reader: BoxReadStream = Box::pin(reader.into_stream_reader()); - let writer: BoxWriteStream = writer.into_inner(); + let reader: BoxQuicStreamReader = Box::pin(reader.into_stream_reader()); + let writer: BoxQuicStreamWriter = writer.into_inner(); - let Ok(registry) = self.registry.lock() else { - tracing::debug!("webtransport session registry lock poisoned"); - return Ok(StreamVerdict::Accepted); - }; - if let Some(router) = registry.get(&session_id) { - router.route_bi(session_id, (reader, writer)); - } else { - tracing::debug!( - ?session_id, - "no registered session for webtransport bidi stream" - ); + match self.registry.route_bi(session_id, (reader, writer)) { + Ok(()) => {} + // draft-ietf-webtrans-http3 §4: "Session IDs that correspond to + // closed sessions are not considered invalid"; §9.5 defines + // WT_SESSION_GONE for streams whose associated session has closed. + // https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3#section-4 + // https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3#section-9.5 + Err(RouteBiError::Closed((mut reader, mut writer))) => { + let code = crate::error::Code::WT_SESSION_GONE.into_inner(); + _ = tokio::join!(reader.stop(code), writer.reset(code)); + } + Err(RouteBiError::Unknown((mut reader, mut writer))) => { + // draft-ietf-webtrans-http3 §4 allows buffering streams for a + // not-yet-established session and requires WT_BUFFERED_STREAM_REJECTED + // when that buffer is full. h3x uses a zero-capacity pre-session + // buffer, so every valid unknown-session stream is rejected here. + let code = crate::error::Code::WT_BUFFERED_STREAM_REJECTED.into_inner(); + _ = tokio::join!(reader.stop(code), writer.reset(code)); + } + Err(RouteBiError::FlowControl((mut reader, mut writer))) => { + let code = crate::error::Code::WT_FLOW_CONTROL_ERROR.into_inner(); + _ = tokio::join!(reader.stop(code), writer.reset(code)); + } + Err(RouteBiError::Rejected((mut reader, mut writer))) => { + let code = crate::error::Code::WT_FLOW_CONTROL_ERROR.into_inner(); + _ = tokio::join!(reader.stop(code), writer.reset(code)); + } } Ok(StreamVerdict::Accepted) @@ -189,37 +159,55 @@ impl WebTransportProtocol { async fn accept_uni_inner( &self, - mut stream: ErasedPeekableUniStream, - ) -> Result, StreamError> { + mut stream: BoxPeekableStreamReader, + ) -> Result, StreamError> { let Ok(signal_value) = stream.decode_one::().await else { return Ok(StreamVerdict::Passed(stream)); }; - if signal_value != WT_UNI_SIGNAL { + if signal_value != WEBTRANSPORT_UNI_SIGNAL { return Ok(StreamVerdict::Passed(stream)); } // WebTransport uni stream confirmed. Decode session ID for routing. - let Ok(session_id) = stream.decode_one::().await else { - tracing::debug!("failed to decode session ID from webtransport uni stream"); + let Ok(raw_session_id) = stream.decode_one::().await else { + tracing::debug!("failed to decode session id from webtransport uni stream"); return Ok(StreamVerdict::Accepted); }; + let session_id = WebTransportSessionId::try_from(raw_session_id) + .map_err(crate::connection::ConnectionError::from)?; - tracing::debug!(?session_id, "routing webtransport uni stream to session"); + tracing::debug!(session_id = %session_id, "routing webtransport uni stream to session"); - let reader: BoxReadStream = Box::pin(stream.into_stream_reader()); + let reader: BoxQuicStreamReader = Box::pin(stream.into_stream_reader()); - let Ok(registry) = self.registry.lock() else { - tracing::debug!("webtransport session registry lock poisoned"); - return Ok(StreamVerdict::Accepted); - }; - if let Some(router) = registry.get(&session_id) { - router.route_uni(session_id, reader); - } else { - tracing::debug!( - ?session_id, - "no registered session for webtransport uni stream" - ); + match self.registry.route_uni(session_id, reader) { + Ok(()) => {} + // draft-ietf-webtrans-http3 §4: "Session IDs that correspond to + // closed sessions are not considered invalid"; §9.5 defines + // WT_SESSION_GONE for streams whose associated session has closed. + // https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3#section-4 + // https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3#section-9.5 + Err(RouteUniError::Closed(mut reader)) => { + let code = crate::error::Code::WT_SESSION_GONE.into_inner(); + _ = reader.stop(code).await; + } + Err(RouteUniError::Unknown(mut reader)) => { + // draft-ietf-webtrans-http3 §4 allows buffering streams for a + // not-yet-established session and requires WT_BUFFERED_STREAM_REJECTED + // when that buffer is full. h3x uses a zero-capacity pre-session + // buffer, so every valid unknown-session stream is rejected here. + let code = crate::error::Code::WT_BUFFERED_STREAM_REJECTED.into_inner(); + _ = reader.stop(code).await; + } + Err(RouteUniError::FlowControl(mut reader)) => { + let code = crate::error::Code::WT_FLOW_CONTROL_ERROR.into_inner(); + _ = reader.stop(code).await; + } + Err(RouteUniError::Rejected(mut reader)) => { + let code = crate::error::Code::WT_FLOW_CONTROL_ERROR.into_inner(); + _ = reader.stop(code).await; + } } Ok(StreamVerdict::Accepted) @@ -229,15 +217,16 @@ impl WebTransportProtocol { impl Protocol for WebTransportProtocol { fn accept_uni<'a>( &'a self, - stream: ErasedPeekableUniStream, - ) -> BoxFuture<'a, Result, StreamError>> { + stream: BoxPeekableStreamReader, + ) -> BoxFuture<'a, Result, StreamError>> { Box::pin(self.accept_uni_inner(stream)) } fn accept_bi<'a>( &'a self, - stream: ErasedPeekableBiStream, - ) -> BoxFuture<'a, Result, StreamError>> { + stream: (BoxPeekableStreamReader, BoxStreamWriter), + ) -> BoxFuture<'a, Result, StreamError>> + { Box::pin(self.accept_bi_inner(stream)) } } @@ -271,10 +260,10 @@ impl ProductProtocol for WebTransportProtocolFactory { conn: &'a Arc, _layers: &'a Protocols, ) -> BoxFuture<'a, Result> { - let conn: Arc = conn.clone(); + let conn: Arc = conn.clone(); Box::pin(async move { Ok(WebTransportProtocol { - registry: Arc::new(std::sync::Mutex::new(HashMap::new())), + registry: Registry::default(), conn, }) }) @@ -287,15 +276,1421 @@ impl ProductProtocol for WebTransportProtocolFactory { #[cfg(test)] mod tests { + use std::{ + borrow::Cow, + collections::VecDeque, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, + }; + + use bytes::Bytes; + use dhttp_identity::identity as authority; + use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt, future::pending}; + use super::*; + use crate::{ + codec::{PeekableStreamReader, SinkWriter, StreamReader}, + error::Code, + protocol::InitProtocols, + quic, + webtransport::WebTransportSessionId, + }; + + #[test] + fn webtransport_protocol_is_send_sync() { + fn assert_send_sync() {} + + assert_send_sync::(); + } + + #[test] + fn webtransport_constants_are_draft15_values() { + assert_eq!(WEBTRANSPORT_H3, "webtransport-h3"); + assert_eq!(WEBTRANSPORT_BIDI_SIGNAL.into_inner(), 0x41); + assert_eq!(WEBTRANSPORT_UNI_SIGNAL.into_inner(), 0x54); + assert_eq!(Code::WT_SESSION_GONE.into_inner(), 0x170d7b68); + } + + #[tokio::test] + async fn unknown_bidi_session_is_rejected_with_wt_buffered_stream_rejected() { + let state = Arc::new(StreamState::default()); + let reader = test_reader(state.clone(), vec![Bytes::from_static(&[0x40, 0x41, 0x28])]); + let writer = test_writer(state.clone()); + let protocol = test_protocol(); + + let verdict = protocol + .accept_bi((reader, writer)) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert_eq!( + state.stopped_codes(), + vec![Code::WT_BUFFERED_STREAM_REJECTED.into_inner()] + ); + assert_eq!( + state.reset_codes(), + vec![Code::WT_BUFFERED_STREAM_REJECTED.into_inner()] + ); + } + + #[tokio::test] + async fn unknown_uni_session_is_rejected_with_wt_buffered_stream_rejected() { + let state = Arc::new(StreamState::default()); + let stream = test_reader(state.clone(), vec![Bytes::from_static(&[0x40, 0x54, 0x28])]); + let protocol = test_protocol(); + + let verdict = protocol + .accept_uni(stream) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert_eq!( + state.stopped_codes(), + vec![Code::WT_BUFFERED_STREAM_REJECTED.into_inner()] + ); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn incoming_stream_with_invalid_session_id_closes_connection_with_h3_id_error() { + let state = Arc::new(StreamState::default()); + let reader = test_reader(state.clone(), vec![Bytes::from_static(&[0x40, 0x41, 0x03])]); + let writer = test_writer(state.clone()); + let protocol = test_protocol(); + + let error = match protocol.accept_bi((reader, writer)).await { + Ok(_) => panic!("invalid WT session id must be connection scoped"), + Err(error) => error, + }; + + assert_stream_connection_code(error, Code::H3_ID_ERROR); + } + + #[test] + fn registry_register_accepts_only_webtransport_session_id() { + let protocol = test_protocol(); + let session_id = WebTransportSessionId::try_from(StreamId::from(VarInt::from_u32(4))) + .expect("client bidi stream id is valid WT session id"); + + let registered = protocol + .register(session_id) + .expect("registration succeeds"); + + assert_eq!(registered.state.id(), session_id); + } + + #[tokio::test] + async fn closed_bidi_session_id_uses_wt_session_gone() { + let protocol = test_protocol(); + let registered = protocol + .register(wt_session_id(40)) + .expect("session registration should succeed"); + registered.state.close(); + + let state = Arc::new(StreamState::default()); + let reader = test_reader(state.clone(), vec![Bytes::from_static(&[0x40, 0x41, 0x28])]); + let writer = test_writer(state.clone()); + + let verdict = protocol + .accept_bi((reader, writer)) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert_eq!( + state.stopped_codes(), + vec![Code::WT_SESSION_GONE.into_inner()] + ); + assert_eq!( + state.reset_codes(), + vec![Code::WT_SESSION_GONE.into_inner()] + ); + } + + #[tokio::test] + async fn closed_uni_session_id_uses_wt_session_gone() { + let protocol = test_protocol(); + let registered = protocol + .register(wt_session_id(40)) + .expect("session registration should succeed"); + registered.state.close(); + + let state = Arc::new(StreamState::default()); + let stream = test_reader(state.clone(), vec![Bytes::from_static(&[0x40, 0x54, 0x28])]); + + let verdict = protocol + .accept_uni(stream) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert_eq!( + state.stopped_codes(), + vec![Code::WT_SESSION_GONE.into_inner()] + ); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn registered_bidi_session_receives_payload_after_header() { + let state = Arc::new(StreamState::default()); + let reader = test_reader( + state.clone(), + vec![Bytes::from_static(&[0x40, 0x41, 0x28, 0xde, 0xad])], + ); + let writer = test_writer(state.clone()); + let protocol = test_protocol(); + let mut registered = protocol + .register(wt_session_id(40)) + .expect("session registration should succeed"); + + let verdict = protocol + .accept_bi((reader, writer)) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + let (mut routed_reader, _routed_writer) = registered + .bidi_rx + .recv() + .await + .expect("registered session should receive bidi stream"); + assert_eq!( + routed_reader + .next() + .await + .expect("reader should yield a payload chunk") + .expect("payload chunk should succeed"), + Bytes::from_static(&[0xde, 0xad]) + ); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn registered_uni_session_receives_payload_after_header() { + let state = Arc::new(StreamState::default()); + let stream = test_reader( + state.clone(), + vec![Bytes::from_static(&[0x40, 0x54, 0x28, 0xbe, 0xef])], + ); + let protocol = test_protocol(); + let mut registered = protocol + .register(wt_session_id(40)) + .expect("session registration should succeed"); + + let verdict = protocol + .accept_uni(stream) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + let mut routed_reader = registered + .uni_rx + .recv() + .await + .expect("registered session should receive uni stream"); + assert_eq!( + routed_reader + .next() + .await + .expect("reader should yield a payload chunk") + .expect("payload chunk should succeed"), + Bytes::from_static(&[0xbe, 0xef]) + ); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn split_bidi_routing_header_routes_registered_stream() { + let state = Arc::new(StreamState::default()); + let reader = test_reader( + state.clone(), + vec![ + Bytes::from_static(&[0x40]), + Bytes::from_static(&[0x41]), + Bytes::from_static(&[0x28, 0xca, 0xfe]), + ], + ); + let writer = test_writer(state.clone()); + let protocol = test_protocol(); + let mut registered = protocol + .register(wt_session_id(40)) + .expect("session registration should succeed"); + + let verdict = protocol + .accept_bi((reader, writer)) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + let (mut routed_reader, _routed_writer) = registered + .bidi_rx + .recv() + .await + .expect("registered session should receive bidi stream"); + assert_eq!( + routed_reader + .next() + .await + .expect("reader should yield a payload chunk") + .expect("payload chunk should succeed"), + Bytes::from_static(&[0xca, 0xfe]) + ); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn split_uni_routing_header_routes_registered_stream() { + let state = Arc::new(StreamState::default()); + let stream = test_reader( + state.clone(), + vec![ + Bytes::from_static(&[0x40]), + Bytes::from_static(&[0x54]), + Bytes::from_static(&[0x28, 0xca, 0xfe]), + ], + ); + let protocol = test_protocol(); + let mut registered = protocol + .register(wt_session_id(40)) + .expect("session registration should succeed"); + + let verdict = protocol + .accept_uni(stream) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + let mut routed_reader = registered + .uni_rx + .recv() + .await + .expect("registered session should receive uni stream"); + assert_eq!( + routed_reader + .next() + .await + .expect("reader should yield a payload chunk") + .expect("payload chunk should succeed"), + Bytes::from_static(&[0xca, 0xfe]) + ); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn wrong_bidi_signal_is_passed_through_without_aborting() { + let state = Arc::new(StreamState::default()); + let reader = test_reader(state.clone(), vec![Bytes::from_static(&[0x05, 0x99])]); + let writer = test_writer(state.clone()); + let protocol = test_protocol(); + + let verdict = protocol + .accept_bi((reader, writer)) + .await + .expect("routing should not fail"); + + let StreamVerdict::Passed((mut passed_reader, _passed_writer)) = verdict else { + panic!("unexpected verdict"); + }; + Pin::new(&mut passed_reader).reset(); + let mut passed_reader = passed_reader.into_stream_reader(); + assert_eq!( + passed_reader + .next() + .await + .expect("reader should yield original bytes") + .expect("reader chunk should succeed"), + Bytes::from_static(&[0x05, 0x99]) + ); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn wrong_uni_signal_is_passed_through_without_stopping() { + let state = Arc::new(StreamState::default()); + let stream = test_reader(state.clone(), vec![Bytes::from_static(&[0x06, 0x77])]); + let protocol = test_protocol(); + + let verdict = protocol + .accept_uni(stream) + .await + .expect("routing should not fail"); + + let StreamVerdict::Passed(mut passed_stream) = verdict else { + panic!("unexpected verdict"); + }; + Pin::new(&mut passed_stream).reset(); + let mut passed_stream = passed_stream.into_stream_reader(); + assert_eq!( + passed_stream + .next() + .await + .expect("reader should yield original bytes") + .expect("reader chunk should succeed"), + Bytes::from_static(&[0x06, 0x77]) + ); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn incomplete_bidi_signal_is_passed_through_without_aborting() { + let state = Arc::new(StreamState::default()); + let reader = test_reader(state.clone(), vec![Bytes::from_static(&[0x40])]); + let writer = test_writer(state.clone()); + let protocol = test_protocol(); + + let verdict = protocol + .accept_bi((reader, writer)) + .await + .expect("routing should not fail"); + + let StreamVerdict::Passed((mut passed_reader, _passed_writer)) = verdict else { + panic!("unexpected verdict"); + }; + Pin::new(&mut passed_reader).reset(); + let mut passed_reader = passed_reader.into_stream_reader(); + assert_eq!( + passed_reader + .next() + .await + .expect("reader should yield original bytes") + .expect("reader chunk should succeed"), + Bytes::from_static(&[0x40]) + ); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn incomplete_uni_signal_is_passed_through_without_stopping() { + let state = Arc::new(StreamState::default()); + let stream = test_reader(state.clone(), vec![Bytes::from_static(&[0x40])]); + let protocol = test_protocol(); + + let verdict = protocol + .accept_uni(stream) + .await + .expect("routing should not fail"); + + let StreamVerdict::Passed(mut passed_stream) = verdict else { + panic!("unexpected verdict"); + }; + Pin::new(&mut passed_stream).reset(); + let mut passed_stream = passed_stream.into_stream_reader(); + assert_eq!( + passed_stream + .next() + .await + .expect("reader should yield original bytes") + .expect("reader chunk should succeed"), + Bytes::from_static(&[0x40]) + ); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn truncated_bidi_session_id_is_accepted_without_abort() { + let state = Arc::new(StreamState::default()); + let reader = test_reader(state.clone(), vec![Bytes::from_static(&[0x40, 0x41])]); + let writer = test_writer(state.clone()); + let protocol = test_protocol(); + + let verdict = protocol + .accept_bi((reader, writer)) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn truncated_uni_session_id_is_accepted_without_stop() { + let state = Arc::new(StreamState::default()); + let stream = test_reader(state.clone(), vec![Bytes::from_static(&[0x40, 0x54])]); + let protocol = test_protocol(); + + let verdict = protocol + .accept_uni(stream) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn reset_while_decoding_bidi_signal_is_passed_through_without_aborting() { + let state = Arc::new(StreamState::default()); + let reader = test_reader_results(state.clone(), vec![Err(reset_stream_error(7))]); + let writer = test_writer(state.clone()); + let protocol = test_protocol(); + + let verdict = protocol + .accept_bi((reader, writer)) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Passed(_))); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn reset_while_decoding_uni_signal_is_passed_through_without_stopping() { + let state = Arc::new(StreamState::default()); + let stream = test_reader_results(state.clone(), vec![Err(reset_stream_error(7))]); + let protocol = test_protocol(); + + let verdict = protocol + .accept_uni(stream) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Passed(_))); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn reset_while_decoding_bidi_session_id_is_accepted_without_abort() { + let state = Arc::new(StreamState::default()); + let reader = test_reader_results( + state.clone(), + vec![ + Ok(Bytes::from_static(&[0x40, 0x41])), + Err(reset_stream_error(9)), + ], + ); + let writer = test_writer(state.clone()); + let protocol = test_protocol(); + + let verdict = protocol + .accept_bi((reader, writer)) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn reset_while_decoding_uni_session_id_is_accepted_without_stop() { + let state = Arc::new(StreamState::default()); + let stream = test_reader_results( + state.clone(), + vec![ + Ok(Bytes::from_static(&[0x40, 0x54])), + Err(reset_stream_error(9)), + ], + ); + let protocol = test_protocol(); + + let verdict = protocol + .accept_uni(stream) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn empty_bidi_stream_is_passed_through_without_side_effects() { + let state = Arc::new(StreamState::default()); + let reader = test_reader(state.clone(), Vec::new()); + let writer = test_writer(state.clone()); + let protocol = test_protocol(); + + let verdict = protocol + .accept_bi((reader, writer)) + .await + .expect("routing should not fail"); + + let StreamVerdict::Passed((mut passed_reader, _passed_writer)) = verdict else { + panic!("unexpected verdict"); + }; + Pin::new(&mut passed_reader).reset(); + let mut passed_reader = passed_reader.into_stream_reader(); + assert!(passed_reader.next().await.is_none()); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn empty_uni_stream_is_passed_through_without_side_effects() { + let state = Arc::new(StreamState::default()); + let stream = test_reader(state.clone(), Vec::new()); + let protocol = test_protocol(); + + let verdict = protocol + .accept_uni(stream) + .await + .expect("routing should not fail"); + + let StreamVerdict::Passed(mut passed_stream) = verdict else { + panic!("unexpected verdict"); + }; + Pin::new(&mut passed_stream).reset(); + let mut passed_stream = passed_stream.into_stream_reader(); + assert!(passed_stream.next().await.is_none()); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn full_bidi_session_channel_rejects_and_aborts_extra_stream() { + let protocol = test_protocol(); + let mut registered = protocol + .register(wt_session_id(40)) + .expect("session registration should succeed"); + + for _ in 0..16 { + let state = Arc::new(StreamState::default()); + let reader = test_reader( + state.clone(), + vec![Bytes::from_static(&[0x40, 0x41, 0x28, 0xde, 0xad])], + ); + let writer = test_writer(state.clone()); + + let verdict = protocol + .accept_bi((reader, writer)) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + let overflow_state = Arc::new(StreamState::default()); + let reader = test_reader( + overflow_state.clone(), + vec![Bytes::from_static(&[0x40, 0x41, 0x28, 0xfa, 0xce])], + ); + let writer = test_writer(overflow_state.clone()); + + let verdict = protocol + .accept_bi((reader, writer)) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert_eq!( + overflow_state.stopped_codes(), + vec![Code::WT_FLOW_CONTROL_ERROR.into_inner()] + ); + assert_eq!( + overflow_state.reset_codes(), + vec![Code::WT_FLOW_CONTROL_ERROR.into_inner()] + ); + + let (mut routed_reader, _routed_writer) = registered + .bidi_rx + .recv() + .await + .expect("registered session should still retain queued streams"); + assert_eq!( + routed_reader + .next() + .await + .expect("reader should yield a payload chunk") + .expect("payload chunk should succeed"), + Bytes::from_static(&[0xde, 0xad]) + ); + } + + #[tokio::test] + async fn failed_bidi_rejection_for_unknown_session_is_ignored() { + let state = Arc::new(StreamState::with_stop_and_reset_errors()); + let reader = test_reader(state.clone(), vec![Bytes::from_static(&[0x40, 0x41, 0x28])]); + let writer = test_writer(state.clone()); + let protocol = test_protocol(); + + let verdict = protocol + .accept_bi((reader, writer)) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert_eq!( + state.stopped_codes(), + vec![Code::WT_BUFFERED_STREAM_REJECTED.into_inner()] + ); + assert_eq!( + state.reset_codes(), + vec![Code::WT_BUFFERED_STREAM_REJECTED.into_inner()] + ); + } + + #[tokio::test] + async fn full_uni_session_channel_rejects_and_stops_extra_stream() { + let protocol = test_protocol(); + let mut registered = protocol + .register(wt_session_id(40)) + .expect("session registration should succeed"); + + for _ in 0..16 { + let state = Arc::new(StreamState::default()); + let stream = test_reader( + state.clone(), + vec![Bytes::from_static(&[0x40, 0x54, 0x28, 0xbe, 0xef])], + ); + + let verdict = protocol + .accept_uni(stream) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert!(state.stopped_codes().is_empty()); + assert!(state.reset_codes().is_empty()); + } + + let overflow_state = Arc::new(StreamState::default()); + let stream = test_reader( + overflow_state.clone(), + vec![Bytes::from_static(&[0x40, 0x54, 0x28, 0xca, 0xfe])], + ); + + let verdict = protocol + .accept_uni(stream) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert_eq!( + overflow_state.stopped_codes(), + vec![Code::WT_FLOW_CONTROL_ERROR.into_inner()] + ); + assert!(overflow_state.reset_codes().is_empty()); + + let mut routed_reader = registered + .uni_rx + .recv() + .await + .expect("registered session should still retain queued streams"); + assert_eq!( + routed_reader + .next() + .await + .expect("reader should yield a payload chunk") + .expect("payload chunk should succeed"), + Bytes::from_static(&[0xbe, 0xef]) + ); + } + + #[tokio::test] + async fn failed_uni_rejection_for_unknown_session_is_ignored() { + let state = Arc::new(StreamState::with_stop_error()); + let stream = test_reader(state.clone(), vec![Bytes::from_static(&[0x40, 0x54, 0x28])]); + let protocol = test_protocol(); + + let verdict = protocol + .accept_uni(stream) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert_eq!( + state.stopped_codes(), + vec![Code::WT_BUFFERED_STREAM_REJECTED.into_inner()] + ); + assert!(state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn closed_bidi_session_receiver_rejects_and_aborts_stream() { + let protocol = test_protocol(); + let RegisteredSession { + state: _session_state, + bidi_rx, + uni_rx: _uni_rx, + } = protocol + .register(wt_session_id(40)) + .expect("session registration should succeed"); + drop(bidi_rx); + + let stream_state = Arc::new(StreamState::default()); + let reader = test_reader( + stream_state.clone(), + vec![Bytes::from_static(&[0x40, 0x41, 0x28, 0xfa, 0xce])], + ); + let writer = test_writer(stream_state.clone()); + + let verdict = protocol + .accept_bi((reader, writer)) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert_eq!( + stream_state.stopped_codes(), + vec![Code::WT_FLOW_CONTROL_ERROR.into_inner()] + ); + assert_eq!( + stream_state.reset_codes(), + vec![Code::WT_FLOW_CONTROL_ERROR.into_inner()] + ); + } + + #[tokio::test] + async fn closed_uni_session_receiver_rejects_and_stops_stream() { + let protocol = test_protocol(); + let RegisteredSession { + state: _session_state, + bidi_rx: _bidi_rx, + uni_rx, + } = protocol + .register(wt_session_id(40)) + .expect("session registration should succeed"); + drop(uni_rx); + + let stream_state = Arc::new(StreamState::default()); + let stream = test_reader( + stream_state.clone(), + vec![Bytes::from_static(&[0x40, 0x54, 0x28, 0xfa, 0xce])], + ); + + let verdict = protocol + .accept_uni(stream) + .await + .expect("routing should not fail"); + + assert!(matches!(verdict, StreamVerdict::Accepted)); + assert_eq!( + stream_state.stopped_codes(), + vec![Code::WT_FLOW_CONTROL_ERROR.into_inner()] + ); + assert!(stream_state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn factory_init_keeps_connection_and_starts_empty() { + let conn = Arc::new(TestConnection); + let expected: Arc = conn.clone(); + let protocol = WebTransportProtocolFactory + .init(&conn, &Protocols::new()) + .await + .expect("factory init should succeed"); + + assert!(Arc::ptr_eq(&protocol.connection(), &expected)); + assert_eq!( + format!("{protocol:?}"), + "WebTransportProtocol { sessions: 0 }" + ); + } + + #[tokio::test] + async fn factory_initializer_inserts_webtransport_protocol_once() { + let conn = Arc::new(TestConnection); + let mut layers = Protocols::new(); + + WebTransportProtocolFactory + .init_protocols(&conn, &mut layers) + .await + .expect("factory initializer should insert protocol"); + + let protocol = layers + .get::() + .expect("webtransport protocol should be registered"); + let _registered = protocol + .register(wt_session_id(40)) + .expect("session registration should succeed"); + assert_eq!( + format!("{protocol:?}"), + "WebTransportProtocol { sessions: 1 }" + ); + + WebTransportProtocolFactory + .init_protocols(&conn, &mut layers) + .await + .expect("second initialization should be a no-op"); + + let protocol = layers + .get::() + .expect("webtransport protocol should remain registered"); + assert_eq!( + format!("{protocol:?}"), + "WebTransportProtocol { sessions: 1 }" + ); + } + + #[test] + fn new_for_test_connection_and_duplicate_registration_behave_consistently() { + let conn: Arc = Arc::new(TestConnection); + let protocol = WebTransportProtocol::new_for_test(conn.clone()); + let session_id = wt_session_id(40); + + assert!(Arc::ptr_eq(&protocol.connection(), &conn)); + + let registered = protocol + .register(session_id) + .expect("session registration should succeed"); + assert_eq!(registered.state.id(), session_id); + + match protocol.register(session_id) { + Err(RegisterSessionError::AlreadyRegistered { + session_id: duplicate, + }) => assert_eq!(duplicate, session_id), + other => panic!("unexpected duplicate registration result: {other:?}"), + } + + registered.state.close(); + + let error = protocol + .register(session_id) + .expect_err("closed session id must remain reserved"); + assert!(matches!( + error, + RegisterSessionError::AlreadyRegistered { + session_id: duplicate + } if duplicate == session_id + )); + } + + #[test] + fn explicit_session_close_unregisters_and_reports_closed() { + let protocol = test_protocol(); + let session_id = wt_session_id(40); + let registered = protocol + .register(session_id) + .expect("session registration should succeed"); + assert_eq!( + format!("{protocol:?}"), + "WebTransportProtocol { sessions: 1 }" + ); - const fn assert_send_sync() {} - // WebTransportProtocol is Send + Sync (required for Protocol trait). - const _: () = assert_send_sync::(); + registered.state.close(); + + assert!(registered.state.check_open().is_err()); + assert_eq!( + format!("{protocol:?}"), + "WebTransportProtocol { sessions: 0 }" + ); + let error = protocol + .register(session_id) + .expect_err("closed session id must remain reserved"); + assert!(matches!( + error, + RegisterSessionError::AlreadyRegistered { + session_id: duplicate + } if duplicate == session_id + )); + } #[test] - fn signal_values_are_correct() { - assert_eq!(WT_BIDI_SIGNAL.into_inner(), 0x41); - assert_eq!(WT_UNI_SIGNAL.into_inner(), 0x54); + fn protocol_debug_and_factory_display_expose_session_count_and_name() { + use std::collections::{BTreeSet, HashSet}; + + let mut hash_set = HashSet::new(); + hash_set.insert(WebTransportProtocolFactory); + hash_set.insert(WebTransportProtocolFactory); + assert_eq!(hash_set.len(), 1); + + let mut btree_set = BTreeSet::new(); + btree_set.insert(WebTransportProtocolFactory); + btree_set.insert(WebTransportProtocolFactory); + assert_eq!(btree_set.len(), 1); + + let protocol = test_protocol(); + assert_eq!(format!("{}", WebTransportProtocolFactory), "WebTransport"); + assert_eq!( + format!("{protocol:?}"), + "WebTransportProtocol { sessions: 0 }" + ); + + let _registered = protocol + .register(wt_session_id(40)) + .expect("session registration should succeed"); + assert_eq!( + format!("{protocol:?}"), + "WebTransportProtocol { sessions: 1 }" + ); + } + + #[tokio::test] + async fn test_stream_doubles_delegate_stream_id_stop_reset_and_sink_operations() { + use quic::{GetStreamIdExt as _, ResetStreamExt as _, StopStreamExt as _}; + + let state = Arc::new(StreamState::default()); + + let mut reader = TestReadStream { + state: state.clone(), + chunks: VecDeque::from([Ok(Bytes::from_static(b"payload"))]), + }; + assert_eq!( + reader + .stream_id() + .await + .expect("reader stream id should be available"), + VarInt::from_u32(0) + ); + assert_eq!( + reader + .next() + .await + .expect("reader should yield chunk") + .expect("chunk should succeed"), + Bytes::from_static(b"payload") + ); + reader + .stop(VarInt::from_u32(0x11)) + .await + .expect("reader stop should succeed"); + assert_eq!(state.stopped_codes(), vec![VarInt::from_u32(0x11)]); + + let mut writer = TestWriteStream { + state: state.clone(), + }; + assert_eq!( + writer + .stream_id() + .await + .expect("writer stream id should be available"), + VarInt::from_u32(0) + ); + writer + .send(Bytes::from_static(b"ignored")) + .await + .expect("test writer sink send should succeed"); + writer + .flush() + .await + .expect("test writer sink flush should succeed"); + writer + .close() + .await + .expect("test writer sink close should succeed"); + writer + .reset(VarInt::from_u32(0x12)) + .await + .expect("writer reset should succeed"); + assert_eq!(state.reset_codes(), vec![VarInt::from_u32(0x12)]); + + let failing_reader_state = Arc::new(StreamState::with_stop_error()); + let mut failing_reader = TestReadStream { + state: failing_reader_state.clone(), + chunks: VecDeque::new(), + }; + let error = failing_reader + .stop(VarInt::from_u32(0x13)) + .await + .expect_err("configured stop error should surface"); + assert!(error.is_reset()); + assert_eq!( + failing_reader_state.stopped_codes(), + vec![VarInt::from_u32(0x13)] + ); + + let failing_writer_state = Arc::new(StreamState::with_stop_and_reset_errors()); + let mut failing_writer = TestWriteStream { + state: failing_writer_state.clone(), + }; + let error = failing_writer + .reset(VarInt::from_u32(0x14)) + .await + .expect_err("configured reset error should surface"); + assert!(error.is_reset()); + assert_eq!( + failing_writer_state.reset_codes(), + vec![VarInt::from_u32(0x14)] + ); + } + + #[tokio::test] + async fn test_connection_double_delegates_agents_lifecycle_and_pending_streams() { + let conn = TestConnection; + + assert!( + quic::ManageStream::open_bi(&conn).now_or_never().is_none(), + "test open_bi should remain pending" + ); + assert!( + quic::ManageStream::open_uni(&conn).now_or_never().is_none(), + "test open_uni should remain pending" + ); + assert!( + quic::ManageStream::accept_bi(&conn) + .now_or_never() + .is_none(), + "test accept_bi should remain pending" + ); + assert!( + quic::ManageStream::accept_uni(&conn) + .now_or_never() + .is_none(), + "test accept_uni should remain pending" + ); + + assert!( + quic::WithLocalAuthority::local_authority(&conn) + .await + .expect("local authority lookup should succeed") + .is_none() + ); + assert!( + quic::WithRemoteAuthority::remote_authority(&conn) + .await + .expect("remote authority lookup should succeed") + .is_none() + ); + quic::Lifecycle::close(&conn, Code::H3_NO_ERROR, Cow::Borrowed("test close")); + quic::Lifecycle::check(&conn).expect("test connection should remain open"); + assert!( + quic::Lifecycle::closed(&conn).now_or_never().is_none(), + "test closed future should remain pending" + ); + + let dyn_conn: Arc = Arc::new(TestConnection); + assert!( + dyn_conn.open_bi().now_or_never().is_none(), + "dyn open_bi should delegate to the pending test connection" + ); + assert!( + dyn_conn.open_uni().now_or_never().is_none(), + "dyn open_uni should delegate to the pending test connection" + ); + assert!( + dyn_conn.accept_bi().now_or_never().is_none(), + "dyn accept_bi should delegate to the pending test connection" + ); + assert!( + dyn_conn.accept_uni().now_or_never().is_none(), + "dyn accept_uni should delegate to the pending test connection" + ); + assert!( + dyn_conn + .local_authority() + .await + .expect("dyn local authority lookup should succeed") + .is_none() + ); + assert!( + dyn_conn + .remote_authority() + .await + .expect("dyn remote authority lookup should succeed") + .is_none() + ); + dyn_conn.close(Code::H3_NO_ERROR, Cow::Borrowed("dyn test close")); + dyn_conn + .check() + .expect("dyn test connection should remain open"); + assert!( + dyn_conn.closed().now_or_never().is_none(), + "dyn closed future should delegate to the pending test connection" + ); + } + + #[tokio::test] + async fn test_agent_doubles_expose_identity_metadata_and_signing() { + let local = TestLocalAuthority; + assert_eq!(authority::LocalAuthority::name(&local), "test-local"); + assert!(authority::LocalAuthority::cert_chain(&local).is_empty()); + assert_eq!( + authority::LocalAuthority::sign(&local, b"payload") + .await + .expect("test local authority signing should succeed"), + Vec::::new() + ); + + let remote = TestRemoteAuthority; + assert_eq!(authority::RemoteAuthority::name(&remote), "test-remote"); + assert!(authority::RemoteAuthority::cert_chain(&remote).is_empty()); + } + + #[derive(Debug, Default)] + struct StreamState { + stopped: Mutex>, + resets: Mutex>, + fail_stop: bool, + fail_reset: bool, + } + + impl StreamState { + fn with_stop_error() -> Self { + Self { + fail_stop: true, + ..Self::default() + } + } + + fn with_stop_and_reset_errors() -> Self { + Self { + fail_stop: true, + fail_reset: true, + ..Self::default() + } + } + + fn stopped_codes(&self) -> Vec { + self.stopped.lock().expect("stopped lock poisoned").clone() + } + + fn reset_codes(&self) -> Vec { + self.resets.lock().expect("resets lock poisoned").clone() + } + } + + #[derive(Debug)] + struct TestReadStream { + state: Arc, + chunks: VecDeque>, + } + + impl Stream for TestReadStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.chunks.pop_front()) + } + } + + impl quic::GetStreamId for TestReadStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(VarInt::from_u32(0))) + } + } + + impl quic::StopStream for TestReadStream { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + self.state + .stopped + .lock() + .expect("stopped lock poisoned") + .push(code); + if self.state.fail_stop { + Poll::Ready(Err(reset_stream_error(0x1f))) + } else { + Poll::Ready(Ok(())) + } + } + } + + #[derive(Debug)] + struct TestWriteStream { + state: Arc, + } + + impl Sink for TestWriteStream { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, _item: Bytes) -> Result<(), Self::Error> { + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl quic::GetStreamId for TestWriteStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(VarInt::from_u32(0))) + } + } + + impl quic::ResetStream for TestWriteStream { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + self.state + .resets + .lock() + .expect("resets lock poisoned") + .push(code); + if self.state.fail_reset { + Poll::Ready(Err(reset_stream_error(0x2f))) + } else { + Poll::Ready(Ok(())) + } + } + } + + #[derive(Debug)] + struct TestLocalAuthority; + + impl authority::LocalAuthority for TestLocalAuthority { + fn name(&self) -> &str { + "test-local" + } + + fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { + &[] + } + fn sign( + &self, + _data: &[u8], + ) -> futures::future::BoxFuture<'_, Result, authority::SignError>> { + Box::pin(async { Ok(Vec::new()) }) + } + } + + #[derive(Debug)] + struct TestRemoteAuthority; + + impl authority::RemoteAuthority for TestRemoteAuthority { + fn name(&self) -> &str { + "test-remote" + } + + fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { + &[] + } + } + + #[derive(Debug, Default)] + struct TestConnection; + + impl quic::ManageStream for TestConnection { + type StreamReader = TestReadStream; + type StreamWriter = TestWriteStream; + + async fn open_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + pending().await + } + + async fn open_uni(&self) -> Result { + pending().await + } + + async fn accept_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + pending().await + } + + async fn accept_uni(&self) -> Result { + pending().await + } + } + + impl quic::WithLocalAuthority for TestConnection { + type LocalAuthority = TestLocalAuthority; + + async fn local_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(None) + } + } + + impl quic::WithRemoteAuthority for TestConnection { + type RemoteAuthority = TestRemoteAuthority; + + async fn remote_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(None) + } + } + + impl quic::Lifecycle for TestConnection { + fn close(&self, _code: Code, _reason: Cow<'static, str>) {} + + fn check(&self) -> Result<(), quic::ConnectionError> { + Ok(()) + } + + async fn closed(&self) -> quic::ConnectionError { + pending().await + } + } + + fn test_reader( + state: Arc, + chunks: Vec, + ) -> PeekableStreamReader { + test_reader_results(state, chunks.into_iter().map(Ok).collect()) + } + + fn test_reader_results( + state: Arc, + chunks: Vec>, + ) -> PeekableStreamReader { + let stream = TestReadStream { + state, + chunks: chunks.into(), + }; + PeekableStreamReader::new(StreamReader::new(Box::pin(stream) as BoxQuicStreamReader)) + } + + fn test_writer(state: Arc) -> SinkWriter { + SinkWriter::new(Box::pin(TestWriteStream { state }) as BoxQuicStreamWriter) + } + + fn test_protocol() -> WebTransportProtocol { + let conn: Arc = Arc::new(TestConnection); + WebTransportProtocol { + registry: Registry::default(), + conn, + } + } + + fn wt_session_id(id: u32) -> WebTransportSessionId { + WebTransportSessionId::try_from(StreamId::from(VarInt::from_u32(id))) + .expect("test id must be a valid webtransport session id") + } + + fn reset_stream_error(code: u32) -> quic::StreamError { + quic::StreamError::Reset { + code: VarInt::from_u32(code), + } + } + + fn assert_stream_connection_code(error: StreamError, expected: Code) { + let StreamError::Connection { source } = error else { + panic!("expected connection-scoped stream error, got {error:?}"); + }; + let crate::connection::ConnectionError::H3 { source } = source else { + panic!("expected h3 connection error, got {source:?}"); + }; + assert_eq!(source.code(), expected); } } diff --git a/src/webtransport/registry.rs b/src/webtransport/registry.rs new file mode 100644 index 0000000..2b46489 --- /dev/null +++ b/src/webtransport/registry.rs @@ -0,0 +1,1560 @@ +use std::{ + collections::{HashMap, HashSet}, + sync::{ + Arc, Mutex, Weak, + atomic::{AtomicBool, Ordering}, + }, +}; + +use futures::{StreamExt, future::BoxFuture, stream::FuturesUnordered}; +use snafu::ResultExt; +use tokio::sync::{mpsc, watch}; +use tracing::Instrument; + +use super::{ + WebTransportSessionId, WebTransportStreamCount, + error::{ + CloseReason, RegisterSessionError, SessionCloseReason, SessionClosed, SessionDrain, + SessionFlowControlError, session_flow_control_error, + }, + session::{ + RoutedBiStream, RoutedUniStream, + stream::{TrackedStreamReader, TrackedStreamWriter}, + }, +}; +use crate::{ + error::Code, + quic::{ResetStreamExt, StopStreamExt}, + stream_id::StreamId, + varint::VarInt, +}; + +const SESSION_STREAM_CHANNEL_SIZE: usize = 16; + +#[derive(Debug, Default)] +struct RegistryInner { + active: HashMap>, + closed: HashSet, +} + +#[derive(Debug, Default, Clone)] +pub(super) struct Registry { + inner: Arc>, +} + +pub(super) enum RouteBiError { + Unknown(RoutedBiStream), + Closed(RoutedBiStream), + FlowControl(RoutedBiStream), + Rejected(RoutedBiStream), +} + +pub(super) enum RouteUniError { + Unknown(RoutedUniStream), + Closed(RoutedUniStream), + FlowControl(RoutedUniStream), + Rejected(RoutedUniStream), +} + +impl Registry { + pub(super) fn register( + &self, + session_id: WebTransportSessionId, + ) -> Result { + let default_credit = default_initial_stream_credit(); + self.register_with_credit(session_id, default_credit, default_credit) + } + + pub(super) fn register_with_credit( + &self, + session_id: WebTransportSessionId, + bidi_credit: WebTransportStreamCount, + uni_credit: WebTransportStreamCount, + ) -> Result { + let bidi_queue_capacity = incoming_stream_queue_capacity(bidi_credit); + let uni_queue_capacity = incoming_stream_queue_capacity(uni_credit); + let (bidi_tx, bidi_rx) = mpsc::channel(bidi_queue_capacity); + let (uni_tx, uni_rx) = mpsc::channel(uni_queue_capacity); + + let Ok(mut inner) = self.inner.lock() else { + return Err(RegisterSessionError::RegistryPoisoned); + }; + + if inner.active.contains_key(&session_id) || inner.closed.contains(&session_id) { + return Err(RegisterSessionError::AlreadyRegistered { session_id }); + } + + let state = Arc::new(SessionState::new( + session_id, + self.clone(), + bidi_tx, + uni_tx, + bidi_credit, + uni_credit, + bidi_queue_capacity, + uni_queue_capacity, + )); + + inner.active.insert(session_id, Arc::downgrade(&state)); + + Ok(RegisteredSession { + state, + bidi_rx, + uni_rx, + }) + } + + pub(super) fn unregister(&self, session_id: WebTransportSessionId) { + self.close(session_id); + } + + fn close(&self, session_id: WebTransportSessionId) { + let Ok(mut inner) = self.inner.lock() else { + tracing::debug!(?session_id, "webtransport session registry lock poisoned"); + return; + }; + inner.active.remove(&session_id); + inner.closed.insert(session_id); + } + + pub(super) fn route_bi( + &self, + session_id: WebTransportSessionId, + stream: RoutedBiStream, + ) -> Result<(), RouteBiError> { + let state = { + let Ok(inner) = self.inner.lock() else { + tracing::debug!(session_id = %session_id, "webtransport session registry lock poisoned"); + return Err(RouteBiError::Rejected(stream)); + }; + let Some(state) = inner + .active + .get(&session_id) + .and_then(std::sync::Weak::upgrade) + else { + if inner.closed.contains(&session_id) { + tracing::debug!(session_id = %session_id, "webtransport bidi stream belongs to closed session"); + return Err(RouteBiError::Closed(stream)); + } + tracing::debug!(session_id = %session_id, "no registered session for webtransport bidi stream"); + return Err(RouteBiError::Unknown(stream)); + }; + state + }; + + state.route_incoming_bi(stream) + } + + pub(super) fn route_uni( + &self, + session_id: WebTransportSessionId, + stream: RoutedUniStream, + ) -> Result<(), RouteUniError> { + let state = { + let Ok(inner) = self.inner.lock() else { + tracing::debug!(session_id = %session_id, "webtransport session registry lock poisoned"); + return Err(RouteUniError::Rejected(stream)); + }; + let Some(state) = inner + .active + .get(&session_id) + .and_then(std::sync::Weak::upgrade) + else { + if inner.closed.contains(&session_id) { + tracing::debug!(session_id = %session_id, "webtransport uni stream belongs to closed session"); + return Err(RouteUniError::Closed(stream)); + } + tracing::debug!(session_id = %session_id, "no registered session for webtransport uni stream"); + return Err(RouteUniError::Unknown(stream)); + }; + state + }; + + state.route_incoming_uni(stream) + } +} + +impl Registry { + pub(super) fn len(&self) -> usize { + self.inner + .lock() + .map(|inner| inner.active.len()) + .unwrap_or(0) + } +} + +#[derive(Debug, Clone, Copy)] +struct IncomingStreamCredit { + advertised_max: WebTransportStreamCount, + received: WebTransportStreamCount, + queued: usize, +} + +impl IncomingStreamCredit { + const fn new(advertised_max: WebTransportStreamCount) -> Self { + Self { + advertised_max, + received: WebTransportStreamCount::ZERO, + queued: 0, + } + } + + fn reserve_incoming(&mut self, queue_capacity: usize) -> Result<(), SessionFlowControlError> { + if self.received >= self.advertised_max { + return Err(SessionFlowControlError::ExceededStreamCredit); + } + if self.queued >= queue_capacity { + return Err(SessionFlowControlError::QueueCapacityInvariant); + } + self.received = self + .received + .checked_increment() + .context(session_flow_control_error::StreamCountSnafu)?; + self.queued += 1; + Ok(()) + } + + fn accept_one(&mut self) -> Result { + self.queued = self.queued.saturating_sub(1); + self.advertised_max = self + .advertised_max + .checked_increment() + .context(session_flow_control_error::StreamCountSnafu)?; + Ok(self.advertised_max) + } +} + +#[derive(Debug)] +struct LocalOpenCredit { + peer_max: WebTransportStreamCount, + opened: WebTransportStreamCount, + last_blocked_sent: Option, + changed: watch::Sender, +} + +impl LocalOpenCredit { + fn new(peer_max: WebTransportStreamCount) -> Self { + let (changed, _rx) = watch::channel(peer_max); + Self { + peer_max, + opened: WebTransportStreamCount::ZERO, + last_blocked_sent: None, + changed, + } + } + + fn try_reserve(&mut self) -> Result<(), WebTransportStreamCount> { + if self.opened >= self.peer_max { + return Err(self.peer_max); + } + self.opened = self + .opened + .checked_increment() + .expect("opened stream count cannot overflow below peer maximum"); + Ok(()) + } + + fn block(&mut self) -> LocalStreamCreditBlock { + let maximum = self.peer_max; + let send_blocked = self.last_blocked_sent != Some(maximum); + if send_blocked { + self.last_blocked_sent = Some(maximum); + } + LocalStreamCreditBlock { + maximum, + send_blocked, + changed: self.changed.subscribe(), + } + } + + fn update_peer_max( + &mut self, + peer_max: WebTransportStreamCount, + ) -> Result<(), SessionFlowControlError> { + if peer_max < self.peer_max { + return Err(SessionFlowControlError::DecreasingMaxStreams); + } + if peer_max > self.peer_max { + self.peer_max = peer_max; + self.last_blocked_sent = None; + let _ = self.changed.send(peer_max); + } + Ok(()) + } +} + +pub(super) struct LocalStreamCreditBlock { + pub(super) maximum: WebTransportStreamCount, + pub(super) send_blocked: bool, + pub(super) changed: watch::Receiver, +} + +pub(super) enum LocalStreamCreditReservation { + Reserved, + Blocked(LocalStreamCreditBlock), +} + +#[derive(Debug)] +pub(super) struct RegisteredSession { + pub(super) state: Arc, + pub(super) bidi_rx: mpsc::Receiver, + pub(super) uni_rx: mpsc::Receiver, +} + +#[derive(Debug)] +pub(super) struct SessionState { + session_id: WebTransportSessionId, + registry: Registry, + closed: AtomicBool, + close_reason: watch::Sender>, + drain_status: watch::Sender>, + bidi_tx: mpsc::Sender, + uni_tx: mpsc::Sender, + bidi_credit: Mutex, + uni_credit: Mutex, + local_bidi_credit: Mutex, + local_uni_credit: Mutex, + bidi_queue_capacity: usize, + uni_queue_capacity: usize, + tracked_readers: Mutex>, + tracked_writers: Mutex>, +} + +impl SessionState { + #[allow(clippy::too_many_arguments)] + fn new( + session_id: WebTransportSessionId, + registry: Registry, + bidi_tx: mpsc::Sender, + uni_tx: mpsc::Sender, + bidi_credit: WebTransportStreamCount, + uni_credit: WebTransportStreamCount, + bidi_queue_capacity: usize, + uni_queue_capacity: usize, + ) -> Self { + let (close_reason, _close_rx) = watch::channel(None); + let (drain_status, _drain_rx) = watch::channel(None); + let default_local_credit = default_initial_stream_credit(); + Self { + session_id, + registry, + closed: AtomicBool::new(false), + close_reason, + drain_status, + bidi_tx, + uni_tx, + bidi_credit: Mutex::new(IncomingStreamCredit::new(bidi_credit)), + uni_credit: Mutex::new(IncomingStreamCredit::new(uni_credit)), + local_bidi_credit: Mutex::new(LocalOpenCredit::new(default_local_credit)), + local_uni_credit: Mutex::new(LocalOpenCredit::new(default_local_credit)), + bidi_queue_capacity, + uni_queue_capacity, + tracked_readers: Mutex::new(HashMap::new()), + tracked_writers: Mutex::new(HashMap::new()), + } + } + + pub(super) fn id(&self) -> WebTransportSessionId { + self.session_id + } + + pub(super) fn check_open(&self) -> Result<(), SessionClosed> { + if self.closed.load(Ordering::Acquire) { + Err(SessionClosed) + } else { + Ok(()) + } + } + + pub(super) fn set_local_stream_credit( + &self, + peer_bidi_credit: WebTransportStreamCount, + peer_uni_credit: WebTransportStreamCount, + ) { + if let Ok(mut credit) = self.local_bidi_credit.lock() { + *credit = LocalOpenCredit::new(peer_bidi_credit); + } else { + tracing::debug!( + session_id = %self.session_id, + "webtransport session local bidi credit lock poisoned" + ); + self.close(); + } + if let Ok(mut credit) = self.local_uni_credit.lock() { + *credit = LocalOpenCredit::new(peer_uni_credit); + } else { + tracing::debug!( + session_id = %self.session_id, + "webtransport session local uni credit lock poisoned" + ); + self.close(); + } + } + + pub(super) fn close(&self) { + self.close_with_reason(CloseReason::Session(SessionCloseReason::ControlStreamError)); + } + + pub(super) fn close_with_reason(&self, reason: CloseReason) { + if !self.closed.swap(true, Ordering::AcqRel) { + let _ = self.close_reason.send(Some(reason.clone())); + let _ = self.drain_status.send(Some(SessionDrain::Closed(reason))); + self.registry.unregister(self.session_id); + let (readers, writers) = self.take_tracked_streams(); + spawn_tracked_stream_cleanup(self.session_id, readers, writers); + } + } + + pub(super) fn drain_with_reason(&self, drain: SessionDrain) { + if self.drain_status.borrow().is_none() { + let _ = self.drain_status.send(Some(drain)); + } + } + + pub(super) async fn closed(&self) -> CloseReason { + let mut reason = self.close_reason.subscribe(); + loop { + if let Some(reason) = reason.borrow().clone() { + return reason; + } + if reason.changed().await.is_err() { + return CloseReason::Session(SessionCloseReason::ControlStreamError); + } + } + } + + pub(super) async fn drained(&self) -> SessionDrain { + let mut drain = self.drain_status.subscribe(); + loop { + if let Some(drain) = drain.borrow().clone() { + return drain; + } + if drain.changed().await.is_err() { + return SessionDrain::Closed(CloseReason::Session( + SessionCloseReason::ControlStreamError, + )); + } + } + } + + pub(super) fn insert_tracked_bi( + &self, + stream_id: StreamId, + reader: TrackedStreamReader, + writer: TrackedStreamWriter, + ) -> Result<(), SessionClosed> { + self.check_open()?; + + let Ok(mut readers) = self.tracked_readers.lock() else { + tracing::debug!( + session_id = %self.session_id, + "webtransport session reader tracking lock poisoned" + ); + return Err(SessionClosed); + }; + let Ok(mut writers) = self.tracked_writers.lock() else { + tracing::debug!( + session_id = %self.session_id, + "webtransport session writer tracking lock poisoned" + ); + return Err(SessionClosed); + }; + + if self.closed.load(Ordering::Acquire) { + return Err(SessionClosed); + } + + readers.insert(stream_id, reader); + writers.insert(stream_id, writer); + + if self.closed.load(Ordering::Acquire) { + readers.remove(&stream_id); + writers.remove(&stream_id); + Err(SessionClosed) + } else { + Ok(()) + } + } + + pub(super) fn insert_tracked_reader( + &self, + stream_id: StreamId, + reader: TrackedStreamReader, + ) -> Result<(), SessionClosed> { + self.check_open()?; + + let Ok(mut readers) = self.tracked_readers.lock() else { + tracing::debug!( + session_id = %self.session_id, + "webtransport session reader tracking lock poisoned" + ); + return Err(SessionClosed); + }; + + if self.closed.load(Ordering::Acquire) { + return Err(SessionClosed); + } + + readers.insert(stream_id, reader); + + if self.closed.load(Ordering::Acquire) { + readers.remove(&stream_id); + Err(SessionClosed) + } else { + Ok(()) + } + } + + pub(super) fn insert_tracked_writer( + &self, + stream_id: StreamId, + writer: TrackedStreamWriter, + ) -> Result<(), SessionClosed> { + self.check_open()?; + + let Ok(mut writers) = self.tracked_writers.lock() else { + tracing::debug!( + session_id = %self.session_id, + "webtransport session writer tracking lock poisoned" + ); + return Err(SessionClosed); + }; + + if self.closed.load(Ordering::Acquire) { + return Err(SessionClosed); + } + + writers.insert(stream_id, writer); + + if self.closed.load(Ordering::Acquire) { + writers.remove(&stream_id); + Err(SessionClosed) + } else { + Ok(()) + } + } + + pub(super) fn route_incoming_bi(&self, stream: RoutedBiStream) -> Result<(), RouteBiError> { + if self.check_open().is_err() { + return Err(RouteBiError::Closed(stream)); + } + + let Ok(mut credit) = self.bidi_credit.lock() else { + tracing::debug!( + session_id = %self.session_id, + "webtransport session bidi credit lock poisoned" + ); + self.close(); + return Err(RouteBiError::Rejected(stream)); + }; + if let Err(error) = credit.reserve_incoming(self.bidi_queue_capacity) { + let report = snafu::Report::from_error(&error); + tracing::debug!( + session_id = %self.session_id, + error = %report, + "webtransport session bidi stream credit exhausted" + ); + drop(credit); + self.close(); + return Err(RouteBiError::FlowControl(stream)); + } + drop(credit); + + match self.bidi_tx.try_send(stream) { + Ok(()) => Ok(()), + Err(error) => { + tracing::debug!( + session_id = %self.session_id, + "session bidi channel full or closed, rejecting stream" + ); + self.close(); + Err(RouteBiError::Rejected(error.into_inner())) + } + } + } + + pub(super) fn route_incoming_uni(&self, stream: RoutedUniStream) -> Result<(), RouteUniError> { + if self.check_open().is_err() { + return Err(RouteUniError::Closed(stream)); + } + + let Ok(mut credit) = self.uni_credit.lock() else { + tracing::debug!( + session_id = %self.session_id, + "webtransport session uni credit lock poisoned" + ); + self.close(); + return Err(RouteUniError::Rejected(stream)); + }; + if let Err(error) = credit.reserve_incoming(self.uni_queue_capacity) { + let report = snafu::Report::from_error(&error); + tracing::debug!( + session_id = %self.session_id, + error = %report, + "webtransport session uni stream credit exhausted" + ); + drop(credit); + self.close(); + return Err(RouteUniError::FlowControl(stream)); + } + drop(credit); + + match self.uni_tx.try_send(stream) { + Ok(()) => Ok(()), + Err(error) => { + tracing::debug!( + session_id = %self.session_id, + "session uni channel full or closed, rejecting stream" + ); + self.close(); + Err(RouteUniError::Rejected(error.into_inner())) + } + } + } + + pub(super) fn accept_incoming_bi( + &self, + ) -> Result { + let Ok(mut credit) = self.bidi_credit.lock() else { + tracing::debug!( + session_id = %self.session_id, + "webtransport session bidi credit lock poisoned" + ); + return Err(SessionFlowControlError::QueueCapacityInvariant); + }; + credit.accept_one() + } + + pub(super) fn accept_incoming_uni( + &self, + ) -> Result { + let Ok(mut credit) = self.uni_credit.lock() else { + tracing::debug!( + session_id = %self.session_id, + "webtransport session uni credit lock poisoned" + ); + return Err(SessionFlowControlError::QueueCapacityInvariant); + }; + credit.accept_one() + } + + pub(super) fn reserve_local_bidi(&self) -> Result { + self.reserve_local_credit(&self.local_bidi_credit, "bidi") + } + + pub(super) fn reserve_local_uni(&self) -> Result { + self.reserve_local_credit(&self.local_uni_credit, "uni") + } + + fn reserve_local_credit( + &self, + credit: &Mutex, + direction: &'static str, + ) -> Result { + self.check_open()?; + let Ok(mut credit) = credit.lock() else { + tracing::debug!( + session_id = %self.session_id, + direction, + "webtransport session local stream credit lock poisoned" + ); + self.close(); + return Err(SessionClosed); + }; + match credit.try_reserve() { + Ok(()) => Ok(LocalStreamCreditReservation::Reserved), + Err(_) => Ok(LocalStreamCreditReservation::Blocked(credit.block())), + } + } + + pub(super) fn update_peer_bidi_max( + &self, + peer_max: WebTransportStreamCount, + ) -> Result<(), SessionFlowControlError> { + self.update_peer_max(&self.local_bidi_credit, peer_max, "bidi") + } + + pub(super) fn update_peer_uni_max( + &self, + peer_max: WebTransportStreamCount, + ) -> Result<(), SessionFlowControlError> { + self.update_peer_max(&self.local_uni_credit, peer_max, "uni") + } + + fn update_peer_max( + &self, + credit: &Mutex, + peer_max: WebTransportStreamCount, + direction: &'static str, + ) -> Result<(), SessionFlowControlError> { + let Ok(mut credit) = credit.lock() else { + tracing::debug!( + session_id = %self.session_id, + direction, + "webtransport session local stream credit lock poisoned" + ); + self.close(); + return Err(SessionFlowControlError::QueueCapacityInvariant); + }; + credit.update_peer_max(peer_max) + } + + pub(super) fn remove_tracked_reader(&self, stream_id: StreamId) { + let Ok(mut readers) = self.tracked_readers.lock() else { + tracing::debug!( + session_id = %self.session_id, + "webtransport session reader tracking lock poisoned" + ); + return; + }; + readers.remove(&stream_id); + } + + pub(super) fn remove_tracked_writer(&self, stream_id: StreamId) { + let Ok(mut writers) = self.tracked_writers.lock() else { + tracing::debug!( + session_id = %self.session_id, + "webtransport session writer tracking lock poisoned" + ); + return; + }; + writers.remove(&stream_id); + } + + fn take_tracked_streams(&self) -> (Vec, Vec) { + let readers = match self.tracked_readers.lock() { + Ok(mut readers) => readers.drain().map(|(_stream_id, reader)| reader).collect(), + Err(_) => { + tracing::debug!( + session_id = %self.session_id, + "webtransport session reader tracking lock poisoned" + ); + Vec::new() + } + }; + let writers = match self.tracked_writers.lock() { + Ok(mut writers) => writers.drain().map(|(_stream_id, writer)| writer).collect(), + Err(_) => { + tracing::debug!( + session_id = %self.session_id, + "webtransport session writer tracking lock poisoned" + ); + Vec::new() + } + }; + (readers, writers) + } +} + +fn spawn_tracked_stream_cleanup( + session_id: WebTransportSessionId, + readers: Vec, + writers: Vec, +) { + if readers.is_empty() && writers.is_empty() { + return; + } + + let cleanup = cleanup_tracked_streams(session_id, readers, writers); + match tokio::runtime::Handle::try_current() { + Ok(handle) => { + // Inherent termination: the task owns the taken tracked halves and + // exits after every STOP/RESET future resolves or returns an error. + let _cleanup_task = handle.spawn(cleanup.in_current_span()); + } + Err(error) => { + let report = snafu::Report::from_error(&error); + tracing::debug!( + session_id = %session_id, + error = %report, + "failed to spawn webtransport session stream cleanup" + ); + } + } +} + +async fn cleanup_tracked_streams( + session_id: WebTransportSessionId, + readers: Vec, + writers: Vec, +) { + let mut cleanup = FuturesUnordered::>::new(); + + for mut reader in readers { + cleanup.push(Box::pin(async move { + if let Err(error) = reader.stop(Code::WT_SESSION_GONE.into_inner()).await { + let report = snafu::Report::from_error(&error); + tracing::debug!( + session_id = %session_id, + error = %report, + "failed to stop webtransport session stream reader" + ); + } + })); + } + + for mut writer in writers { + cleanup.push(Box::pin(async move { + if let Err(error) = writer.reset(Code::WT_SESSION_GONE.into_inner()).await { + let report = snafu::Report::from_error(&error); + tracing::debug!( + session_id = %session_id, + error = %report, + "failed to reset webtransport session stream writer" + ); + } + })); + } + + while cleanup.next().await.is_some() {} +} + +impl Drop for SessionState { + fn drop(&mut self) { + self.close(); + } +} + +fn default_initial_stream_credit() -> WebTransportStreamCount { + WebTransportStreamCount::try_from(VarInt::from_u32(SESSION_STREAM_CHANNEL_SIZE as u32)) + .expect("default webtransport stream credit is valid") +} + +fn incoming_stream_queue_capacity(credit: WebTransportStreamCount) -> usize { + usize::try_from(credit.into_varint().into_inner()) + .unwrap_or(SESSION_STREAM_CHANNEL_SIZE) + .clamp(1, SESSION_STREAM_CHANNEL_SIZE) +} + +#[cfg(test)] +mod tests { + use std::{ + collections::VecDeque, + panic::{AssertUnwindSafe, catch_unwind}, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, + }; + + use bytes::Bytes; + use futures::{Sink, SinkExt, Stream, StreamExt}; + + use super::*; + use crate::{ + quic::{ + self, BoxQuicStreamReader, BoxQuicStreamWriter, GetStreamIdExt, ResetStreamExt, + StopStreamExt, + }, + varint::VarInt, + webtransport::WebTransportStreamCount, + }; + + #[derive(Debug, Default)] + struct StreamState { + written: Mutex>, + } + + #[derive(Debug)] + struct TestReadStream { + chunks: VecDeque, + stream_id: VarInt, + } + + impl Stream for TestReadStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.chunks.pop_front().map(Ok)) + } + } + + impl quic::GetStreamId for TestReadStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl quic::StopStream for TestReadStream { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _code: VarInt, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + #[derive(Debug)] + struct TestWriteStream { + state: Arc, + stream_id: VarInt, + } + + impl Sink for TestWriteStream { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.state + .written + .lock() + .expect("written lock poisoned") + .extend_from_slice(&item); + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl quic::GetStreamId for TestWriteStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl quic::ResetStream for TestWriteStream { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _code: VarInt, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + fn test_read_stream(id: u32, bytes: Vec) -> BoxQuicStreamReader { + Box::pin(TestReadStream { + chunks: VecDeque::from([Bytes::from(bytes)]), + stream_id: VarInt::from_u32(id), + }) as BoxQuicStreamReader + } + + fn test_write_stream(id: u32, state: Arc) -> BoxQuicStreamWriter { + Box::pin(TestWriteStream { + state, + stream_id: VarInt::from_u32(id), + }) as BoxQuicStreamWriter + } + + fn bidi_stream(id: u32) -> RoutedBiStream { + let state = Arc::new(StreamState::default()); + ( + test_read_stream(id, vec![id as u8]), + test_write_stream(id, Arc::clone(&state)), + ) + } + + fn uni_stream(id: u32) -> RoutedUniStream { + test_read_stream(id, vec![id as u8]) + } + + fn wt_session_id(id: u32) -> WebTransportSessionId { + WebTransportSessionId::try_from(StreamId::from(VarInt::from_u32(id))) + .expect("test id must be a valid webtransport session id") + } + + fn poison_registry(registry: &Registry) { + let registry = registry.clone(); + let _ = catch_unwind(AssertUnwindSafe(move || { + let _guard = registry.inner.lock().expect("registry lock should succeed"); + panic!("poison registry mutex"); + })); + } + + fn route_bi_error_stream(error: RouteBiError) -> RoutedBiStream { + match error { + RouteBiError::Unknown(stream) + | RouteBiError::Closed(stream) + | RouteBiError::FlowControl(stream) + | RouteBiError::Rejected(stream) => stream, + } + } + + fn route_uni_error_stream(error: RouteUniError) -> RoutedUniStream { + match error { + RouteUniError::Unknown(stream) + | RouteUniError::Closed(stream) + | RouteUniError::FlowControl(stream) + | RouteUniError::Rejected(stream) => stream, + } + } + + #[tokio::test] + async fn test_read_stream_reads_chunks_and_stops() { + let mut stream = test_read_stream(40, b"chunk".to_vec()); + + assert_eq!( + stream + .next() + .await + .expect("read stream should yield one chunk") + .expect("read chunk should succeed"), + Bytes::from_static(b"chunk") + ); + assert!( + stream.next().await.is_none(), + "read stream should be exhausted" + ); + stream + .stop(VarInt::from_u32(41)) + .await + .expect("read stream stop should succeed"); + } + + #[tokio::test] + async fn test_write_stream_writes_flushes_closes_and_resets() { + let state = Arc::new(StreamState::default()); + let mut stream = test_write_stream(42, Arc::clone(&state)); + + assert_eq!( + stream + .stream_id() + .await + .expect("write stream id should be readable"), + VarInt::from_u32(42) + ); + stream + .send(Bytes::from_static(b"payload")) + .await + .expect("write stream send should succeed"); + stream + .close() + .await + .expect("write stream close should succeed"); + stream + .reset(VarInt::from_u32(43)) + .await + .expect("write stream reset should succeed"); + + assert_eq!( + *state + .written + .lock() + .expect("written lock should not poison"), + b"payload".to_vec() + ); + } + + #[test] + fn register_len_duplicate_and_close_unregister() { + let registry = Registry::default(); + let session_id = wt_session_id(4); + + let registered = registry + .register(session_id) + .expect("first registration should succeed"); + assert_eq!(registry.len(), 1); + assert_eq!(registered.state.id(), session_id); + assert!(registered.state.check_open().is_ok()); + + let error = registry + .register(session_id) + .expect_err("duplicate session should be rejected"); + assert!(matches!( + error, + RegisterSessionError::AlreadyRegistered { session_id: duplicate } if duplicate == session_id + )); + + registered.state.close(); + assert_eq!(registry.len(), 0); + assert!(registered.state.check_open().is_err()); + + registered.state.close(); + assert_eq!(registry.len(), 0); + } + + #[test] + fn dropping_registered_session_unregisters_it() { + let registry = Registry::default(); + let session_id = wt_session_id(8); + + let registered = registry + .register(session_id) + .expect("registration should succeed"); + assert_eq!(registry.len(), 1); + + drop(registered); + assert_eq!(registry.len(), 0); + } + + #[tokio::test] + async fn route_unknown_sessions_return_original_streams() { + let registry = Registry::default(); + let session_id = wt_session_id(4); + + let bidi = bidi_stream(1); + let error = registry + .route_bi(session_id, bidi) + .expect_err("unknown bidi session should reject stream"); + assert!(matches!(&error, RouteBiError::Unknown(_))); + let mut returned_bidi = route_bi_error_stream(error); + assert_eq!( + returned_bidi + .0 + .stream_id() + .await + .expect("returned bidi stream id should be readable"), + VarInt::from_u32(1) + ); + + let uni = uni_stream(2); + let error = registry + .route_uni(session_id, uni) + .expect_err("unknown uni session should reject stream"); + assert!(matches!(&error, RouteUniError::Unknown(_))); + let mut returned_uni = route_uni_error_stream(error); + assert_eq!( + returned_uni + .stream_id() + .await + .expect("returned uni stream id should be readable"), + VarInt::from_u32(2) + ); + } + + #[tokio::test] + async fn route_known_sessions_deliver_streams_to_receivers() { + let registry = Registry::default(); + let session_id = wt_session_id(4); + let mut registered = registry + .register(session_id) + .expect("registration should succeed"); + + assert!(registry.route_bi(session_id, bidi_stream(3)).is_ok()); + assert!(registry.route_uni(session_id, uni_stream(4)).is_ok()); + + let (mut bidi_reader, _bidi_writer) = registered + .bidi_rx + .recv() + .await + .expect("bidi receiver should get a stream"); + let mut uni_reader = registered + .uni_rx + .recv() + .await + .expect("uni receiver should get a stream"); + + assert_eq!( + bidi_reader + .stream_id() + .await + .expect("bidi stream id should be readable"), + VarInt::from_u32(3) + ); + assert_eq!( + uni_reader + .stream_id() + .await + .expect("uni stream id should be readable"), + VarInt::from_u32(4) + ); + } + + #[tokio::test] + async fn route_incoming_bidi_closes_session_when_peer_exceeds_advertised_credit() { + let registry = Registry::default(); + let session_id = wt_session_id(4); + let mut registered = registry + .register_with_credit( + session_id, + WebTransportStreamCount::try_from(VarInt::from_u32(1)).expect("bidi credit"), + WebTransportStreamCount::try_from(VarInt::from_u32(0)).expect("uni credit"), + ) + .expect("registration should succeed"); + + assert!( + registry.route_bi(session_id, bidi_stream(8)).is_ok(), + "first stream should be allowed" + ); + let error = registry + .route_bi(session_id, bidi_stream(12)) + .expect_err("second stream exceeds credit"); + + assert!(matches!(error, RouteBiError::FlowControl(_))); + assert!(registered.state.check_open().is_err()); + let (_reader, _writer) = registered + .bidi_rx + .recv() + .await + .expect("first stream remains queued"); + } + + #[tokio::test] + async fn route_incoming_uni_closes_session_when_peer_exceeds_advertised_credit() { + let registry = Registry::default(); + let session_id = wt_session_id(4); + let mut registered = registry + .register_with_credit( + session_id, + WebTransportStreamCount::try_from(VarInt::from_u32(0)).expect("bidi credit"), + WebTransportStreamCount::try_from(VarInt::from_u32(1)).expect("uni credit"), + ) + .expect("registration should succeed"); + + assert!( + registry.route_uni(session_id, uni_stream(10)).is_ok(), + "first stream should be allowed" + ); + let error = registry + .route_uni(session_id, uni_stream(14)) + .expect_err("second stream exceeds credit"); + + assert!(matches!(error, RouteUniError::FlowControl(_))); + assert!(registered.state.check_open().is_err()); + let _reader = registered + .uni_rx + .recv() + .await + .expect("first stream remains queued"); + } + + #[test] + fn route_rejects_closed_or_full_channels() { + let registry = Registry::default(); + let bidi_session_id = wt_session_id(4); + let bidi_registered = registry + .register(bidi_session_id) + .expect("registration should succeed"); + + drop(bidi_registered.bidi_rx); + assert!(matches!( + registry.route_bi(bidi_session_id, bidi_stream(5)), + Err(RouteBiError::Rejected(_)) + )); + + let uni_session_id = wt_session_id(8); + let uni_registered = registry + .register(uni_session_id) + .expect("registration should succeed"); + + drop(uni_registered.uni_rx); + assert!(matches!( + registry.route_uni(uni_session_id, uni_stream(6)), + Err(RouteUniError::Rejected(_)) + )); + } + + #[tokio::test] + async fn route_closed_channels_return_original_streams() { + let registry = Registry::default(); + let bidi_session_id = wt_session_id(16); + let bidi_registered = registry + .register(bidi_session_id) + .expect("registration should succeed"); + + drop(bidi_registered.bidi_rx); + + let error = registry + .route_bi(bidi_session_id, bidi_stream(17)) + .expect_err("closed bidi receiver should reject stream"); + assert!(matches!(&error, RouteBiError::Rejected(_))); + let mut returned_bidi = route_bi_error_stream(error); + assert_eq!( + returned_bidi + .0 + .stream_id() + .await + .expect("returned bidi stream id should be readable"), + VarInt::from_u32(17) + ); + + let uni_session_id = wt_session_id(20); + let uni_registered = registry + .register(uni_session_id) + .expect("registration should succeed"); + + drop(uni_registered.uni_rx); + + let error = registry + .route_uni(uni_session_id, uni_stream(18)) + .expect_err("closed uni receiver should reject stream"); + assert!(matches!(&error, RouteUniError::Rejected(_))); + let mut returned_uni = route_uni_error_stream(error); + assert_eq!( + returned_uni + .stream_id() + .await + .expect("returned uni stream id should be readable"), + VarInt::from_u32(18) + ); + } + + #[test] + fn route_flow_control_when_initial_credit_is_exhausted() { + let registry = Registry::default(); + let bidi_session_id = wt_session_id(24); + let _bidi_registered = registry + .register(bidi_session_id) + .expect("registration should succeed"); + + for id in 0..SESSION_STREAM_CHANNEL_SIZE { + assert!( + registry + .route_bi(bidi_session_id, bidi_stream(id as u32)) + .is_ok() + ); + } + + assert!(matches!( + registry.route_bi(bidi_session_id, bidi_stream(99)), + Err(RouteBiError::FlowControl(_)) + )); + + let uni_session_id = wt_session_id(28); + let _uni_registered = registry + .register(uni_session_id) + .expect("registration should succeed"); + + for id in 0..SESSION_STREAM_CHANNEL_SIZE { + assert!( + registry + .route_uni(uni_session_id, uni_stream(id as u32)) + .is_ok() + ); + } + + assert!(matches!( + registry.route_uni(uni_session_id, uni_stream(100)), + Err(RouteUniError::FlowControl(_)) + )); + } + + #[tokio::test] + async fn route_uses_exact_registered_session_id() { + let registry = Registry::default(); + let first_session_id = wt_session_id(32); + let second_session_id = wt_session_id(36); + let mut first = registry + .register(first_session_id) + .expect("first registration should succeed"); + let mut second = registry + .register(second_session_id) + .expect("second registration should succeed"); + + assert!( + registry + .route_bi(second_session_id, bidi_stream(37)) + .is_ok() + ); + assert!( + registry + .route_uni(second_session_id, uni_stream(38)) + .is_ok() + ); + + assert!(matches!( + first.bidi_rx.try_recv(), + Err(mpsc::error::TryRecvError::Empty) + )); + assert!(matches!( + first.uni_rx.try_recv(), + Err(mpsc::error::TryRecvError::Empty) + )); + + let (mut second_bidi_reader, _second_bidi_writer) = second + .bidi_rx + .recv() + .await + .expect("second session should receive routed bidi stream"); + assert_eq!( + second_bidi_reader + .stream_id() + .await + .expect("second bidi stream id should be readable"), + VarInt::from_u32(37) + ); + + let mut second_uni_reader = second + .uni_rx + .recv() + .await + .expect("second session should receive routed uni stream"); + assert_eq!( + second_uni_reader + .stream_id() + .await + .expect("second uni stream id should be readable"), + VarInt::from_u32(38) + ); + } + + #[tokio::test] + async fn route_full_channels_return_original_streams() { + let registry = Registry::default(); + let bidi_session_id = wt_session_id(40); + let _bidi_registered = registry + .register(bidi_session_id) + .expect("registration should succeed"); + + for id in 0..SESSION_STREAM_CHANNEL_SIZE { + assert!( + registry + .route_bi(bidi_session_id, bidi_stream(id as u32)) + .is_ok() + ); + } + + let error = registry + .route_bi(bidi_session_id, bidi_stream(21)) + .expect_err("exhausted bidi credit should reject stream"); + assert!(matches!(&error, RouteBiError::FlowControl(_))); + let mut returned_bidi = route_bi_error_stream(error); + assert_eq!( + returned_bidi + .0 + .stream_id() + .await + .expect("returned bidi stream id should be readable"), + VarInt::from_u32(21) + ); + + let uni_session_id = wt_session_id(44); + let _uni_registered = registry + .register(uni_session_id) + .expect("registration should succeed"); + + for id in 0..SESSION_STREAM_CHANNEL_SIZE { + assert!( + registry + .route_uni(uni_session_id, uni_stream(id as u32)) + .is_ok() + ); + } + + let error = registry + .route_uni(uni_session_id, uni_stream(22)) + .expect_err("exhausted uni credit should reject stream"); + assert!(matches!(&error, RouteUniError::FlowControl(_))); + let mut returned_uni = route_uni_error_stream(error); + assert_eq!( + returned_uni + .stream_id() + .await + .expect("returned uni stream id should be readable"), + VarInt::from_u32(22) + ); + } + + #[tokio::test] + async fn closed_session_tombstone_is_preserved_for_connection_lifetime() { + let registry = Registry::default(); + let session_id = wt_session_id(24); + let registered = registry + .register(session_id) + .expect("initial registration should succeed"); + + registered.state.close(); + assert_eq!(registry.len(), 0); + + let error = registry + .route_uni(session_id, uni_stream(25)) + .expect_err("closed session should no longer route streams"); + assert!(matches!(&error, RouteUniError::Closed(_))); + let mut returned = route_uni_error_stream(error); + assert_eq!( + returned + .stream_id() + .await + .expect("returned uni stream id should be readable"), + VarInt::from_u32(25) + ); + + let error = registry + .register(session_id) + .expect_err("closed session id must not be re-registered"); + assert!(matches!( + error, + RegisterSessionError::AlreadyRegistered { + session_id: duplicate + } if duplicate == session_id + )); + } + + #[test] + fn cloned_session_state_keeps_registration_until_last_state_drop() { + let registry = Registry::default(); + let session_id = wt_session_id(28); + let registered = registry + .register(session_id) + .expect("registration should succeed"); + let state = Arc::clone(®istered.state); + + drop(registered); + assert_eq!(registry.len(), 1); + assert!(matches!( + registry.route_uni(session_id, uni_stream(29)), + Err(RouteUniError::Rejected(_)) + )); + assert!(state.check_open().is_err()); + assert_eq!(registry.len(), 0); + + drop(state); + assert_eq!(registry.len(), 0); + } + + #[tokio::test] + async fn poisoned_registry_surfaces_errors_and_preserves_streams() { + let registry = Registry::default(); + let session_id = wt_session_id(12); + poison_registry(®istry); + + let error = registry + .register(session_id) + .expect_err("poisoned registry should reject new registrations"); + assert!(matches!(error, RegisterSessionError::RegistryPoisoned)); + assert_eq!(registry.len(), 0); + + registry.unregister(session_id); + + let error = registry + .route_bi(session_id, bidi_stream(13)) + .expect_err("poisoned registry should return routed bidi stream"); + assert!(matches!(&error, RouteBiError::Rejected(_))); + let mut returned_bidi = route_bi_error_stream(error); + assert_eq!( + returned_bidi + .0 + .stream_id() + .await + .expect("returned bidi stream id should be readable"), + VarInt::from_u32(13) + ); + + let error = registry + .route_uni(session_id, uni_stream(14)) + .expect_err("poisoned registry should return routed uni stream"); + assert!(matches!(&error, RouteUniError::Rejected(_))); + let mut returned_uni = route_uni_error_stream(error); + assert_eq!( + returned_uni + .stream_id() + .await + .expect("returned uni stream id should be readable"), + VarInt::from_u32(14) + ); + } +} diff --git a/src/webtransport/session.rs b/src/webtransport/session.rs index 84c41b7..f8d2763 100644 --- a/src/webtransport/session.rs +++ b/src/webtransport/session.rs @@ -1,7 +1,7 @@ //! WebTransport session handle. //! -//! A [`WebTransportSession`] is returned by -//! [`WebTransportProtocol::register`](super::WebTransportProtocol::register) and +//! A [`WebTransportSession`] is created from an +//! [`EstablishedConnect`](crate::extended_connect::EstablishedConnect) and //! provides the application-facing API for opening and accepting streams within a //! WebTransport session. //! @@ -9,32 +9,60 @@ use std::{fmt, sync::Arc}; -use snafu::ResultExt; -use tokio::{io::AsyncWriteExt, sync::mpsc}; +use bytes::Buf; +use snafu::{ResultExt, Snafu}; +use tokio::{ + io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWriteExt}, + sync::{mpsc, oneshot}, +}; +use tokio_util::task::AbortOnDropHandle; +use tracing::Instrument; use super::{ + CloseSession, DecodeCloseSessionError, DecodeWebTransportStreamCountError, + WebTransportSessionId, WebTransportStreamCount, error::{ - Closed, DatagramError, OpenSnafu, OpenStreamError, UnsupportedSnafu, WriteHeaderSnafu, + AcceptStreamError, CloseReason, CloseSessionError, ControlCommandError, DatagramError, + DrainReason, DrainSessionError, OpenStreamError, RegisterSessionError, SessionCloseReason, + SessionClosed, SessionDrain, SessionDrainReason, UnsupportedSnafu, accept_stream_error, + close_session_error, control_command_error, drain_session_error, open_stream_error, + register_session_error, }, - protocol::Registry, + protocol::{WEBTRANSPORT_BIDI_SIGNAL, WEBTRANSPORT_H3, WEBTRANSPORT_UNI_SIGNAL}, + registry::{LocalStreamCreditReservation, RegisteredSession, SessionState}, }; use crate::{ - codec::{BoxReadStream, BoxWriteStream, EncodeExt, SinkWriter}, - quic::{self}, + buflist::BufList, + codec::{DecodeExt, EncodeExt, EncodeInto, SinkWriter, StreamReader}, + dhttp::{ + message::{MessageStreamError, MessageWriter}, + webtransport::capsule::{Capsule, CapsuleType}, + }, + extended_connect::{EstablishedConnect, IntoStreamsError}, + quic::{self, BoxQuicStreamReader, BoxQuicStreamWriter, GetStreamIdExt}, + stream_id::StreamId, varint::VarInt, }; +pub(in crate::webtransport) mod stream; + +use stream::{WebTransportStreamReader, WebTransportStreamWriter}; + +const CONTROL_CAPSULE_SKIP_CHUNK_SIZE: usize = 8 * 1024; +const CLOSE_SESSION_CAPSULE_MAX_PAYLOAD: u64 = 4 + 1024; +const STREAM_COUNT_CAPSULE_MAX_PAYLOAD: u64 = 8; + // ============================================================================ // Stream type aliases // ============================================================================ /// A routed bidirectional stream (reader + writer) after signal/session-ID /// consumption by the protocol layer. -pub(super) type RoutedBiStream = (BoxReadStream, BoxWriteStream); +pub(super) type RoutedBiStream = (BoxQuicStreamReader, BoxQuicStreamWriter); /// A routed unidirectional stream (reader only) after signal/session-ID /// consumption by the protocol layer. -pub(super) type RoutedUniStream = BoxReadStream; +pub(super) type RoutedUniStream = BoxQuicStreamReader; // ============================================================================ // WebTransportSession @@ -48,69 +76,333 @@ pub(super) type RoutedUniStream = BoxReadStream; /// Dropping the handle automatically unregisters the session from the protocol /// registry. pub struct WebTransportSession { - session_id: VarInt, + state: Arc, bidi_rx: tokio::sync::Mutex>, uni_rx: tokio::sync::Mutex>, - conn: Arc, - registry: Registry, + conn: Arc, + control: ControlHandle, + _control_task: AbortOnDropHandle<()>, +} + +enum ControlCommand { + Drain { + ack: oneshot::Sender>, + }, + Close { + close: CloseSession, + ack: oneshot::Sender>, + }, + StreamsBlockedBidi { + maximum: WebTransportStreamCount, + ack: oneshot::Sender>, + }, + StreamsBlockedUni { + maximum: WebTransportStreamCount, + ack: oneshot::Sender>, + }, +} + +enum RemoteControlCapsule { + Close(CloseSession), + Drain, + MaxStreamsBidi(WebTransportStreamCount), + MaxStreamsUni(WebTransportStreamCount), + Ignored, +} + +#[derive(Clone)] +struct ControlHandle { + tx: mpsc::Sender, +} + +impl ControlHandle { + async fn drain(&self) -> Result<(), ControlCommandError> { + let (ack, rx) = oneshot::channel(); + if self.tx.send(ControlCommand::Drain { ack }).await.is_err() { + return Err(ControlCommandError::Closed); + } + match rx.await { + Ok(result) => result, + Err(_) => Err(ControlCommandError::ResponseDropped), + } + } + + async fn close(&self, close: CloseSession) -> Result<(), ControlCommandError> { + let (ack, rx) = oneshot::channel(); + if self + .tx + .send(ControlCommand::Close { close, ack }) + .await + .is_err() + { + return Err(ControlCommandError::Closed); + } + match rx.await { + Ok(result) => result, + Err(_) => Err(ControlCommandError::ResponseDropped), + } + } + + async fn streams_blocked_bidi( + &self, + maximum: WebTransportStreamCount, + ) -> Result<(), ControlCommandError> { + let (ack, rx) = oneshot::channel(); + if self + .tx + .send(ControlCommand::StreamsBlockedBidi { maximum, ack }) + .await + .is_err() + { + return Err(ControlCommandError::Closed); + } + match rx.await { + Ok(result) => result, + Err(_) => Err(ControlCommandError::ResponseDropped), + } + } + + async fn streams_blocked_uni( + &self, + maximum: WebTransportStreamCount, + ) -> Result<(), ControlCommandError> { + let (ack, rx) = oneshot::channel(); + if self + .tx + .send(ControlCommand::StreamsBlockedUni { maximum, ack }) + .await + .is_err() + { + return Err(ControlCommandError::Closed); + } + match rx.await { + Ok(result) => result, + Err(_) => Err(ControlCommandError::ResponseDropped), + } + } +} + +#[derive(Debug, Snafu)] +#[snafu(module(control_task_error), visibility(pub(super)))] +enum ControlTaskError { + #[snafu(display("failed to take over webtransport connect stream"))] + Takeover { source: IntoStreamsError }, + #[snafu(display("failed to decode webtransport control capsule"))] + DecodeCapsule { source: std::io::Error }, + #[snafu(display("webtransport close session capsule payload is too large"))] + CloseSessionPayloadTooLarge, + #[snafu(display("webtransport drain session capsule has a payload"))] + DrainSessionPayload, + #[snafu(display("invalid webtransport close session capsule"))] + CloseSessionPayload { source: DecodeCloseSessionError }, + #[snafu(display("invalid webtransport stream count capsule"))] + StreamCountPayload { + source: DecodeWebTransportStreamCountError, + }, + #[snafu(display("webtransport stream count capsule payload is too large"))] + StreamCountPayloadTooLarge, + #[snafu(display("webtransport stream count capsule has trailing bytes"))] + StreamCountPayloadTrailing, } impl WebTransportSession { - pub(super) fn new( - session_id: VarInt, - bidi_rx: mpsc::Receiver, - uni_rx: mpsc::Receiver, - conn: Arc, - registry: Registry, + #[cfg(test)] + fn from_registered( + registered: RegisteredSession, + conn: Arc, + connect: EstablishedConnect, + ) -> Self { + let default_credit = default_peer_stream_credit(); + Self::from_registered_with_peer_credit( + registered, + conn, + connect, + default_credit, + default_credit, + ) + } + + fn from_registered_with_peer_credit( + registered: RegisteredSession, + conn: Arc, + connect: EstablishedConnect, + peer_bidi_credit: WebTransportStreamCount, + peer_uni_credit: WebTransportStreamCount, + ) -> Self { + registered + .state + .set_local_stream_credit(peer_bidi_credit, peer_uni_credit); + Self::from_registered_inner(registered, conn, connect) + } + + fn from_registered_inner( + registered: RegisteredSession, + conn: Arc, + connect: EstablishedConnect, ) -> Self { + let state = registered.state; + let task_state = Arc::clone(&state); + let (control_tx, control_rx) = mpsc::channel(8); + let control = ControlHandle { tx: control_tx }; + let control_task = run_control_task(task_state, connect, control_rx); + Self { - session_id, - bidi_rx: tokio::sync::Mutex::new(bidi_rx), - uni_rx: tokio::sync::Mutex::new(uni_rx), + state, + bidi_rx: tokio::sync::Mutex::new(registered.bidi_rx), + uni_rx: tokio::sync::Mutex::new(registered.uni_rx), conn, - registry, + control, + _control_task: AbortOnDropHandle::new(tokio::spawn(control_task.in_current_span())), } } /// The session ID (QUIC stream ID of the CONNECT stream that established /// this session). - pub fn session_id(&self) -> VarInt { - self.session_id + pub fn id(&self) -> WebTransportSessionId { + self.state.id() } /// Open a new bidirectional stream within this session. /// - /// Writes the WT bidi signal value (`0x41`) and the session ID as a - /// routing header, then returns the raw stream pair positioned after the + /// Writes the WebTransport bidi signal value (`0x41`) and the session ID as + /// a routing header, then returns the raw stream pair positioned after the /// header. - pub async fn open_bi(&self) -> Result<(BoxReadStream, BoxWriteStream), OpenStreamError> { - let (reader, writer) = self.conn.open_bi().await.context(OpenSnafu)?; - let writer = write_header(writer, super::WT_BIDI_SIGNAL, self.session_id).await?; + pub async fn open_bi( + &self, + ) -> Result<(WebTransportStreamReader, WebTransportStreamWriter), OpenStreamError> { + self.state + .check_open() + .context(open_stream_error::ClosedSnafu)?; + self.reserve_local_bidi().await?; + let (mut reader, writer) = self + .conn + .open_bi() + .await + .context(open_stream_error::OpenSnafu)?; + let writer = write_header(writer, WEBTRANSPORT_BIDI_SIGNAL, self.id().stream_id()).await?; + let stream_id = reader + .stream_id() + .await + .context(open_stream_error::StreamIdSnafu)?; + let stream_id = StreamId::from(stream_id); + let (reader, tracked_reader) = + WebTransportStreamReader::tracked(stream_id, reader, Arc::downgrade(&self.state)); + let (writer, tracked_writer) = + WebTransportStreamWriter::tracked(stream_id, writer, Arc::downgrade(&self.state)); + self.state + .insert_tracked_bi(stream_id, tracked_reader, tracked_writer) + .context(open_stream_error::ClosedSnafu)?; Ok((reader, writer)) } /// Open a new unidirectional stream within this session. /// - /// Writes the WT uni signal value (`0x54`) and the session ID as a - /// routing header, then returns the write half positioned after the header. - pub async fn open_uni(&self) -> Result { - let writer = self.conn.open_uni().await.context(OpenSnafu)?; - let writer = write_header(writer, super::WT_UNI_SIGNAL, self.session_id).await?; + /// Writes the WebTransport uni signal value (`0x54`) and the session ID as + /// a routing header, then returns the write half positioned after the + /// header. + pub async fn open_uni(&self) -> Result { + self.state + .check_open() + .context(open_stream_error::ClosedSnafu)?; + self.reserve_local_uni().await?; + let writer = self + .conn + .open_uni() + .await + .context(open_stream_error::OpenSnafu)?; + let mut writer = + write_header(writer, WEBTRANSPORT_UNI_SIGNAL, self.id().stream_id()).await?; + let stream_id = writer + .stream_id() + .await + .context(open_stream_error::StreamIdSnafu)?; + let stream_id = StreamId::from(stream_id); + let (writer, tracked_writer) = + WebTransportStreamWriter::tracked(stream_id, writer, Arc::downgrade(&self.state)); + self.state + .insert_tracked_writer(stream_id, tracked_writer) + .context(open_stream_error::ClosedSnafu)?; Ok(writer) } /// Accept a bidirectional stream routed to this session by the protocol layer. - /// - /// Returns [`Closed`] when the session is closed and no more streams will arrive. - pub async fn accept_bi(&self) -> Result<(BoxReadStream, BoxWriteStream), Closed> { - self.bidi_rx.lock().await.recv().await.ok_or(Closed) + pub async fn accept_bi( + &self, + ) -> Result<(WebTransportStreamReader, WebTransportStreamWriter), AcceptStreamError> { + self.state + .check_open() + .context(accept_stream_error::ClosedSnafu)?; + let mut rx = self.bidi_rx.lock().await; + let (mut reader, writer) = tokio::select! { + biased; + source = self.conn.closed() => { + self.state.close(); + return Err(AcceptStreamError::Connection { source }); + } + stream = rx.recv() => stream.ok_or(SessionClosed).context(accept_stream_error::ClosedSnafu)?, + }; + drop(rx); + if let Err(error) = self.state.accept_incoming_bi() { + let report = snafu::Report::from_error(&error); + tracing::debug!( + session_id = %self.id(), + error = %report, + "failed to advance webtransport bidi stream credit" + ); + self.state.close(); + return Err(SessionClosed).context(accept_stream_error::ClosedSnafu); + } + let stream_id = reader + .stream_id() + .await + .context(accept_stream_error::StreamIdSnafu)?; + let stream_id = StreamId::from(stream_id); + let (reader, tracked_reader) = + WebTransportStreamReader::tracked(stream_id, reader, Arc::downgrade(&self.state)); + let (writer, tracked_writer) = + WebTransportStreamWriter::tracked(stream_id, writer, Arc::downgrade(&self.state)); + self.state + .insert_tracked_bi(stream_id, tracked_reader, tracked_writer) + .context(accept_stream_error::ClosedSnafu)?; + Ok((reader, writer)) } /// Accept a unidirectional stream routed to this session by the protocol layer. - /// - /// Returns [`Closed`] when the session is closed and no more streams will arrive. - pub async fn accept_uni(&self) -> Result { - self.uni_rx.lock().await.recv().await.ok_or(Closed) + pub async fn accept_uni(&self) -> Result { + self.state + .check_open() + .context(accept_stream_error::ClosedSnafu)?; + let mut rx = self.uni_rx.lock().await; + let mut reader = tokio::select! { + biased; + source = self.conn.closed() => { + self.state.close(); + return Err(AcceptStreamError::Connection { source }); + } + stream = rx.recv() => stream.ok_or(SessionClosed).context(accept_stream_error::ClosedSnafu)?, + }; + drop(rx); + if let Err(error) = self.state.accept_incoming_uni() { + let report = snafu::Report::from_error(&error); + tracing::debug!( + session_id = %self.id(), + error = %report, + "failed to advance webtransport uni stream credit" + ); + self.state.close(); + return Err(SessionClosed).context(accept_stream_error::ClosedSnafu); + } + let stream_id = reader + .stream_id() + .await + .context(accept_stream_error::StreamIdSnafu)?; + let stream_id = StreamId::from(stream_id); + let (reader, tracked_reader) = + WebTransportStreamReader::tracked(stream_id, reader, Arc::downgrade(&self.state)); + self.state + .insert_tracked_reader(stream_id, tracked_reader) + .context(accept_stream_error::ClosedSnafu)?; + Ok(reader) } /// Send a datagram within this session. @@ -126,6 +418,475 @@ impl WebTransportSession { pub async fn recv_datagram(&self) -> Result, DatagramError> { UnsupportedSnafu.fail() } + + pub async fn drain(&self) -> Result<(), DrainSessionError> { + self.state + .check_open() + .context(drain_session_error::ClosedSnafu)?; + self.control + .drain() + .await + .context(drain_session_error::CommandSnafu) + } + + pub async fn close(&self, close: CloseSession) -> Result<(), CloseSessionError> { + self.state + .check_open() + .context(close_session_error::ClosedSnafu)?; + self.control + .close(close) + .await + .context(close_session_error::CommandSnafu) + } + + pub async fn closed(&self) -> CloseReason { + self.state.closed().await + } + + pub async fn drained(&self) -> SessionDrain { + self.state.drained().await + } + + async fn reserve_local_bidi(&self) -> Result<(), OpenStreamError> { + loop { + match self + .state + .reserve_local_bidi() + .context(open_stream_error::ClosedSnafu)? + { + LocalStreamCreditReservation::Reserved => return Ok(()), + LocalStreamCreditReservation::Blocked(mut block) => { + if block.send_blocked { + self.control + .streams_blocked_bidi(block.maximum) + .await + .context(open_stream_error::ControlSnafu)?; + } + tokio::select! { + biased; + _ = self.state.closed() => { + return Err(SessionClosed).context(open_stream_error::ClosedSnafu); + } + changed = block.changed.changed() => { + if changed.is_err() { + return Err(SessionClosed).context(open_stream_error::ClosedSnafu); + } + } + } + } + } + } + } + + async fn reserve_local_uni(&self) -> Result<(), OpenStreamError> { + loop { + match self + .state + .reserve_local_uni() + .context(open_stream_error::ClosedSnafu)? + { + LocalStreamCreditReservation::Reserved => return Ok(()), + LocalStreamCreditReservation::Blocked(mut block) => { + if block.send_blocked { + self.control + .streams_blocked_uni(block.maximum) + .await + .context(open_stream_error::ControlSnafu)?; + } + tokio::select! { + biased; + _ = self.state.closed() => { + return Err(SessionClosed).context(open_stream_error::ClosedSnafu); + } + changed = block.changed.changed() => { + if changed.is_err() { + return Err(SessionClosed).context(open_stream_error::ClosedSnafu); + } + } + } + } + } + } + } +} + +#[cfg(test)] +fn default_peer_stream_credit() -> WebTransportStreamCount { + WebTransportStreamCount::try_from(VarInt::from_u32(16)) + .expect("default webtransport peer stream credit is valid") +} + +async fn run_control_task( + state: Arc, + connect: EstablishedConnect, + commands: mpsc::Receiver, +) { + let result = drive_control_task(Arc::clone(&state), connect, commands).await; + if let Err(error) = result { + tracing::debug!( + error = %snafu::Report::from_error(&error), + "webtransport control task failed" + ); + } + state.close(); +} + +async fn drive_control_task( + state: Arc, + connect: EstablishedConnect, + mut commands: mpsc::Receiver, +) -> Result<(), ControlTaskError> { + let (reader, mut writer) = connect + .into_streams() + .await + .context(control_task_error::TakeoverSnafu)?; + let mut reader = StreamReader::new(reader.into_box_reader()); + + loop { + tokio::select! { + command = commands.recv() => { + let Some(command) = command else { + break; + }; + if handle_control_command(&state, &mut writer, command).await { + break; + } + } + capsule = read_remote_control_capsule(&mut reader) => match capsule? { + Some(RemoteControlCapsule::Close(close)) => { + state.close_with_reason(CloseReason::Session(SessionCloseReason::Remote(close))); + break; + } + Some(RemoteControlCapsule::Drain) => { + state.drain_with_reason(SessionDrain::Requested(DrainReason::Session( + SessionDrainReason::Remote, + ))); + } + Some(RemoteControlCapsule::MaxStreamsBidi(maximum)) => { + if let Err(error) = state.update_peer_bidi_max(maximum) { + let report = snafu::Report::from_error(&error); + tracing::debug!( + session_id = %state.id(), + error = %report, + "invalid webtransport bidi max streams update" + ); + state.close_with_reason(CloseReason::Session(SessionCloseReason::Protocol { + code: crate::error::Code::WT_FLOW_CONTROL_ERROR, + })); + break; + } + } + Some(RemoteControlCapsule::MaxStreamsUni(maximum)) => { + if let Err(error) = state.update_peer_uni_max(maximum) { + let report = snafu::Report::from_error(&error); + tracing::debug!( + session_id = %state.id(), + error = %report, + "invalid webtransport uni max streams update" + ); + state.close_with_reason(CloseReason::Session(SessionCloseReason::Protocol { + code: crate::error::Code::WT_FLOW_CONTROL_ERROR, + })); + break; + } + } + Some(RemoteControlCapsule::Ignored) => {} + None => break, + } + } + } + + Ok(()) +} + +async fn read_remote_control_capsule( + reader: &mut S, +) -> Result, ControlTaskError> +where + S: AsyncRead + tokio::io::AsyncBufRead + Unpin + Send, +{ + if reader + .fill_buf() + .await + .context(control_task_error::DecodeCapsuleSnafu)? + .is_empty() + { + return Ok(None); + } + + let r#type = CapsuleType::from( + reader + .decode_one::() + .await + .context(control_task_error::DecodeCapsuleSnafu)?, + ); + let length = reader + .decode_one::() + .await + .context(control_task_error::DecodeCapsuleSnafu)?; + + match r#type { + CapsuleType::WT_CLOSE_SESSION => { + if length.into_inner() > CLOSE_SESSION_CAPSULE_MAX_PAYLOAD { + skip_control_capsule_payload(reader, length) + .await + .context(control_task_error::DecodeCapsuleSnafu)?; + return Err(ControlTaskError::CloseSessionPayloadTooLarge); + } + let payload = read_control_capsule_payload(reader, length) + .await + .context(control_task_error::DecodeCapsuleSnafu)?; + let close = payload + .decode::() + .await + .context(control_task_error::CloseSessionPayloadSnafu)?; + Ok(Some(RemoteControlCapsule::Close(close))) + } + CapsuleType::WT_DRAIN_SESSION => { + if length != VarInt::from_u32(0) { + skip_control_capsule_payload(reader, length) + .await + .context(control_task_error::DecodeCapsuleSnafu)?; + return Err(ControlTaskError::DrainSessionPayload); + } + Ok(Some(RemoteControlCapsule::Drain)) + } + CapsuleType::WT_MAX_STREAMS_BIDI => { + let maximum = read_stream_count_payload(reader, length).await?; + Ok(Some(RemoteControlCapsule::MaxStreamsBidi(maximum))) + } + CapsuleType::WT_MAX_STREAMS_UNI => { + let maximum = read_stream_count_payload(reader, length).await?; + Ok(Some(RemoteControlCapsule::MaxStreamsUni(maximum))) + } + _ => { + skip_control_capsule_payload(reader, length) + .await + .context(control_task_error::DecodeCapsuleSnafu)?; + Ok(Some(RemoteControlCapsule::Ignored)) + } + } +} + +async fn read_stream_count_payload( + reader: &mut S, + length: VarInt, +) -> Result +where + S: AsyncRead + Unpin + Send, +{ + if length.into_inner() > STREAM_COUNT_CAPSULE_MAX_PAYLOAD { + skip_control_capsule_payload(reader, length) + .await + .context(control_task_error::DecodeCapsuleSnafu)?; + return Err(ControlTaskError::StreamCountPayloadTooLarge); + } + let mut payload = read_control_capsule_payload(reader, length) + .await + .context(control_task_error::DecodeCapsuleSnafu)?; + let count = payload + .decode_one::() + .await + .context(control_task_error::StreamCountPayloadSnafu)?; + if payload.remaining() != 0 { + return Err(ControlTaskError::StreamCountPayloadTrailing); + } + Ok(count) +} + +async fn read_control_capsule_payload( + reader: &mut S, + length: VarInt, +) -> Result +where + S: AsyncRead + Unpin + Send, +{ + let mut remaining = length.into_inner(); + let mut payload = BufList::new(); + while remaining > 0 { + let len = remaining.min(CONTROL_CAPSULE_SKIP_CHUNK_SIZE as u64) as usize; + let mut bytes = vec![0; len]; + reader.read_exact(&mut bytes).await?; + payload.write(bytes.as_slice()); + remaining -= len as u64; + } + Ok(payload) +} + +async fn skip_control_capsule_payload( + reader: &mut S, + length: VarInt, +) -> Result<(), std::io::Error> +where + S: AsyncRead + Unpin + Send, +{ + let mut remaining = length.into_inner(); + let mut scratch = vec![0; CONTROL_CAPSULE_SKIP_CHUNK_SIZE]; + while remaining > 0 { + let len = remaining.min(scratch.len() as u64) as usize; + reader.read_exact(&mut scratch[..len]).await?; + remaining -= len as u64; + } + Ok(()) +} + +async fn handle_control_command( + state: &Arc, + writer: &mut MessageWriter, + command: ControlCommand, +) -> bool { + match command { + ControlCommand::Drain { ack } => { + let result = write_drain_capsule(writer) + .await + .context(control_command_error::WriteSnafu); + if result.is_ok() { + state.drain_with_reason(SessionDrain::Requested(DrainReason::Session( + SessionDrainReason::Local, + ))); + } + let should_close = result.is_err(); + _ = ack.send(result); + should_close + } + ControlCommand::Close { close, ack } => { + let result = async { + write_close_capsule(writer, close.clone()) + .await + .context(control_command_error::WriteSnafu)?; + writer + .close() + .await + .context(control_command_error::WriteSnafu)?; + Ok(()) + } + .await; + if result.is_ok() { + state.close_with_reason(CloseReason::Session(SessionCloseReason::Local(close))); + } + _ = ack.send(result); + true + } + ControlCommand::StreamsBlockedBidi { maximum, ack } => { + let result = + write_stream_count_capsule(writer, CapsuleType::WT_STREAMS_BLOCKED_BIDI, maximum) + .await + .context(control_command_error::WriteSnafu); + let should_close = result.is_err(); + _ = ack.send(result); + should_close + } + ControlCommand::StreamsBlockedUni { maximum, ack } => { + let result = + write_stream_count_capsule(writer, CapsuleType::WT_STREAMS_BLOCKED_UNI, maximum) + .await + .context(control_command_error::WriteSnafu); + let should_close = result.is_err(); + _ = ack.send(result); + should_close + } + } +} + +async fn write_drain_capsule(writer: &mut MessageWriter) -> Result<(), MessageStreamError> { + let capsule = Capsule::new(CapsuleType::WT_DRAIN_SESSION, BufList::new()) + .expect("empty webtransport drain capsule payload is a valid varint length"); + write_control_capsule(writer, capsule).await +} + +async fn write_close_capsule( + writer: &mut MessageWriter, + close: CloseSession, +) -> Result<(), MessageStreamError> { + let payload = BufList::new() + .encode(close) + .await + .expect("encoding webtransport close session payload into a buflist is infallible"); + let capsule = Capsule::new(CapsuleType::WT_CLOSE_SESSION, payload) + .expect("validated webtransport close session payload is a valid varint length"); + write_control_capsule(writer, capsule).await +} + +async fn write_stream_count_capsule( + writer: &mut MessageWriter, + r#type: CapsuleType, + count: WebTransportStreamCount, +) -> Result<(), MessageStreamError> { + let payload = BufList::new() + .encode(count) + .await + .expect("encoding webtransport stream count payload into a buflist is infallible"); + let capsule = Capsule::new(r#type, payload) + .expect("validated webtransport stream count payload is a valid varint length"); + write_control_capsule(writer, capsule).await +} + +async fn write_control_capsule( + writer: &mut MessageWriter, + capsule: Capsule, +) -> Result<(), MessageStreamError> { + let payload = BufList::new() + .encode(capsule) + .await + .expect("encoding webtransport capsule into a buflist is infallible"); + writer.write_data(payload).await?; + writer.flush().await?; + Ok(()) +} + +impl TryFrom for WebTransportSession { + type Error = RegisterSessionError; + + fn try_from(connect: EstablishedConnect) -> Result { + let protocol = match connect.protocol() { + Some(protocol) => protocol, + None => return Err(RegisterSessionError::MissingProtocol), + }; + if protocol.as_str() != WEBTRANSPORT_H3 { + let protocol = protocol.clone(); + return Err(RegisterSessionError::UnexpectedProtocol { protocol }); + } + + let (registered, conn, peer_bidi_credit, peer_uni_credit) = { + let dhttp = connect + .connection() + .protocol::() + .ok_or(RegisterSessionError::ProtocolLayerMissing)?; + let peer_settings = dhttp + .peer_settings_peek() + .ok_or(RegisterSessionError::PeerSettingsUnavailable)?; + if !peer_settings.enable_webtransport() { + return Err(RegisterSessionError::WebTransportNotEnabled); + } + if !peer_settings.webtransport_flow_control_enabled() { + return Err(RegisterSessionError::FlowControlNotEnabled); + } + let peer_bidi_credit = + WebTransportStreamCount::try_from(peer_settings.wt_initial_max_streams_bidi()) + .context(register_session_error::InitialStreamCountSnafu)?; + let peer_uni_credit = + WebTransportStreamCount::try_from(peer_settings.wt_initial_max_streams_uni()) + .context(register_session_error::InitialStreamCountSnafu)?; + + let protocol = connect + .connection() + .protocol::() + .ok_or(RegisterSessionError::ProtocolLayerMissing)?; + let session_id = WebTransportSessionId::try_from(connect.stream_id()) + .context(register_session_error::InvalidSessionIdSnafu)?; + let registered = protocol.register(session_id)?; + let conn = protocol.connection(); + (registered, conn, peer_bidi_credit, peer_uni_credit) + }; + + Ok(Self::from_registered_with_peer_credit( + registered, + conn, + connect, + peer_bidi_credit, + peer_uni_credit, + )) + } } // ============================================================================ @@ -134,28 +895,43 @@ impl WebTransportSession { /// Write the WebTransport stream routing header (signal + session_id) and flush. async fn write_header( - writer: BoxWriteStream, + writer: BoxQuicStreamWriter, signal: VarInt, - session_id: VarInt, -) -> Result { + session_id: StreamId, +) -> Result { let mut codec_writer = SinkWriter::new(writer); - codec_writer - .encode_one(signal) + encode_header_value(&mut codec_writer, signal) .await - .map_err(quic::StreamError::from) - .context(WriteHeaderSnafu)?; - codec_writer - .encode_one(session_id) + .context(open_stream_error::WriteHeaderSnafu)?; + encode_header_value(&mut codec_writer, session_id) .await - .map_err(quic::StreamError::from) - .context(WriteHeaderSnafu)?; - AsyncWriteExt::flush(&mut codec_writer) + .context(open_stream_error::WriteHeaderSnafu)?; + flush_header(&mut codec_writer) .await - .map_err(quic::StreamError::from) - .context(WriteHeaderSnafu)?; + .context(open_stream_error::WriteHeaderSnafu)?; Ok(codec_writer.into_inner()) } +async fn encode_header_value( + codec_writer: &mut SinkWriter, + value: T, +) -> Result<(), quic::StreamError> +where + T: Send, + for<'a> T: + EncodeInto<&'a mut SinkWriter, Error = std::io::Error, Output = ()>, +{ + codec_writer.encode_one(value).await?; + Ok(()) +} + +async fn flush_header( + codec_writer: &mut SinkWriter, +) -> Result<(), quic::StreamError> { + AsyncWriteExt::flush(codec_writer).await?; + Ok(()) +} + // ============================================================================ // Trait impls // ============================================================================ @@ -163,23 +939,1992 @@ async fn write_header( impl fmt::Debug for WebTransportSession { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("WebTransportSession") - .field("session_id", &self.session_id) + .field("id", &self.id()) .finish() } } impl Drop for WebTransportSession { fn drop(&mut self) { - if let Ok(mut registry) = self.registry.lock() { - registry.remove(&self.session_id); - } + self.state.close(); } } #[cfg(test)] mod tests { + use std::{ + borrow::Cow, + collections::VecDeque, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, + }; + + use bytes::Bytes; + use dhttp_identity::identity as authority; + use futures::{Sink, SinkExt, Stream, StreamExt, future, future::pending}; + use tokio::time::{Duration, timeout}; + use tokio_util::task::AbortOnDropHandle; + use tracing::{Instrument, Level}; + use super::*; + use crate::{ + buflist::BufList, + codec::{DecodeExt, SinkWriter, StreamReader}, + connection::{ConnectionState, tests::MockConnection}, + dhttp::{ + message::{ + MessageReader, MessageWriter, guard, + test::{read_stream_for_test, write_stream_for_test}, + }, + protocol::DHttpProtocol, + settings::Settings, + webtransport::{ + capsule::{Capsule, CapsuleType}, + settings::{ + EnableWebTransport, InitialMaxData, InitialMaxStreamsBidi, InitialMaxStreamsUni, + }, + }, + }, + error::Code, + extended_connect::EstablishedConnect, + protocol::Protocols, + qpack::{ + field::Protocol, + protocol::{QPackDecoder, QPackEncoder}, + }, + quic, + stream_id::StreamId, + varint::VarInt, + webtransport::{ + CloseSession, WEBTRANSPORT_H3, WebTransportProtocol, WebTransportSessionId, + WebTransportStreamCount, registry::Registry, + }, + }; const fn assert_send_sync() {} const _: () = assert_send_sync::(); + + #[derive(Debug, Default)] + struct StreamState { + written: Mutex>, + stopped: Mutex>, + resets: Mutex>, + flushes: Mutex, + } + + impl StreamState { + fn written(&self) -> Vec { + self.written.lock().expect("written lock poisoned").clone() + } + + fn reset_codes(&self) -> Vec { + self.resets.lock().expect("resets lock poisoned").clone() + } + + fn stopped_codes(&self) -> Vec { + self.stopped.lock().expect("stopped lock poisoned").clone() + } + + fn flushes(&self) -> usize { + *self.flushes.lock().expect("flushes lock poisoned") + } + } + + #[derive(Debug)] + struct TestReadStream { + state: Arc, + chunks: VecDeque, + stream_id: VarInt, + } + + impl Stream for TestReadStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.chunks.pop_front().map(Ok)) + } + } + + impl quic::GetStreamId for TestReadStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl quic::StopStream for TestReadStream { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + self.state + .stopped + .lock() + .expect("stopped lock poisoned") + .push(code); + Poll::Ready(Ok(())) + } + } + + #[derive(Debug)] + struct TestWriteStream { + state: Arc, + stream_id: VarInt, + } + + impl Sink for TestWriteStream { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.state + .written + .lock() + .expect("written lock poisoned") + .extend_from_slice(&item); + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let mut flushes = self.state.flushes.lock().expect("flushes lock poisoned"); + *flushes += 1; + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + #[derive(Debug, Clone, Copy)] + enum HeaderFailureMode { + SendOnCall(usize), + Flush, + } + + #[derive(Debug)] + struct HeaderFailWriteStream { + mode: HeaderFailureMode, + send_calls: usize, + state: Arc, + stream_id: VarInt, + } + + impl Sink for HeaderFailWriteStream { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.send_calls += 1; + if matches!(self.mode, HeaderFailureMode::SendOnCall(call) if call == self.send_calls) { + return Err(quic::StreamError::Reset { + code: VarInt::from_u32(0xe0 + self.send_calls as u32), + }); + } + + self.state + .written + .lock() + .expect("written lock poisoned") + .extend_from_slice(&item); + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + if matches!(self.mode, HeaderFailureMode::Flush) { + return Poll::Ready(Err(quic::StreamError::Reset { + code: VarInt::from_u32(0xef), + })); + } + + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl quic::GetStreamId for HeaderFailWriteStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl quic::ResetStream for HeaderFailWriteStream { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _code: VarInt, + ) -> Poll> { + Poll::Ready(Ok(())) + } + } + + impl quic::GetStreamId for TestWriteStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl quic::ResetStream for TestWriteStream { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + self.state + .resets + .lock() + .expect("resets lock poisoned") + .push(code); + Poll::Ready(Ok(())) + } + } + + #[derive(Debug)] + struct TestLocalAuthority; + + impl authority::LocalAuthority for TestLocalAuthority { + fn name(&self) -> &str { + "test-local" + } + + fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { + &[] + } + fn sign( + &self, + _data: &[u8], + ) -> futures::future::BoxFuture<'_, Result, authority::SignError>> { + Box::pin(async { Ok(Vec::new()) }) + } + } + + #[derive(Debug)] + struct TestRemoteAuthority; + + impl authority::RemoteAuthority for TestRemoteAuthority { + fn name(&self) -> &str { + "test-remote" + } + + fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { + &[] + } + } + + #[derive(Debug)] + struct TestConnection { + state: Arc, + } + + impl quic::ManageStream for TestConnection { + type StreamReader = TestReadStream; + type StreamWriter = TestWriteStream; + + async fn open_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + Ok(( + TestReadStream { + state: Arc::clone(&self.state), + chunks: VecDeque::new(), + stream_id: VarInt::from_u32(9), + }, + TestWriteStream { + state: Arc::clone(&self.state), + stream_id: VarInt::from_u32(9), + }, + )) + } + + async fn open_uni(&self) -> Result { + Ok(TestWriteStream { + state: Arc::clone(&self.state), + stream_id: VarInt::from_u32(10), + }) + } + + async fn accept_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + pending().await + } + + async fn accept_uni(&self) -> Result { + pending().await + } + } + + impl quic::WithLocalAuthority for TestConnection { + type LocalAuthority = TestLocalAuthority; + + async fn local_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(None) + } + } + + impl quic::WithRemoteAuthority for TestConnection { + type RemoteAuthority = TestRemoteAuthority; + + async fn remote_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(None) + } + } + + impl quic::Lifecycle for TestConnection { + fn close(&self, _code: crate::error::Code, _reason: Cow<'static, str>) {} + + fn check(&self) -> Result<(), quic::ConnectionError> { + Ok(()) + } + + async fn closed(&self) -> quic::ConnectionError { + pending().await + } + } + + #[derive(Debug)] + struct HeaderFailConnection { + mode: HeaderFailureMode, + state: Arc, + } + + impl quic::ManageStream for HeaderFailConnection { + type StreamReader = TestReadStream; + type StreamWriter = HeaderFailWriteStream; + + async fn open_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + Ok(( + TestReadStream { + state: Arc::clone(&self.state), + chunks: VecDeque::new(), + stream_id: VarInt::from_u32(21), + }, + HeaderFailWriteStream { + mode: self.mode, + send_calls: 0, + state: Arc::clone(&self.state), + stream_id: VarInt::from_u32(21), + }, + )) + } + + async fn open_uni(&self) -> Result { + Ok(HeaderFailWriteStream { + mode: self.mode, + send_calls: 0, + state: Arc::clone(&self.state), + stream_id: VarInt::from_u32(22), + }) + } + + async fn accept_bi( + &self, + ) -> Result<(Self::StreamReader, Self::StreamWriter), quic::ConnectionError> { + pending().await + } + + async fn accept_uni(&self) -> Result { + pending().await + } + } + + impl quic::WithLocalAuthority for HeaderFailConnection { + type LocalAuthority = TestLocalAuthority; + + async fn local_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(None) + } + } + + impl quic::WithRemoteAuthority for HeaderFailConnection { + type RemoteAuthority = TestRemoteAuthority; + + async fn remote_authority( + &self, + ) -> Result, quic::ConnectionError> { + Ok(None) + } + } + + impl quic::Lifecycle for HeaderFailConnection { + fn close(&self, _code: crate::error::Code, _reason: Cow<'static, str>) {} + + fn check(&self) -> Result<(), quic::ConnectionError> { + Ok(()) + } + + async fn closed(&self) -> quic::ConnectionError { + pending().await + } + } + + fn enable_debug_tracing_for_test() -> tracing::dispatcher::DefaultGuard { + let subscriber = tracing_subscriber::fmt() + .with_max_level(Level::DEBUG) + .with_writer(std::io::sink) + .finish(); + let dispatch = tracing::Dispatch::new(subscriber); + tracing::dispatcher::set_default(&dispatch) + } + + #[tokio::test] + async fn test_stream_helpers_expose_ids_and_shutdown_codes() { + use crate::quic::{GetStreamIdExt, ResetStreamExt, StopStreamExt}; + + let state = Arc::new(StreamState::default()); + let mut reader = TestReadStream { + state: Arc::clone(&state), + chunks: VecDeque::from([Bytes::from_static(&[0x01])]), + stream_id: VarInt::from_u32(41), + }; + let mut writer = TestWriteStream { + state: Arc::clone(&state), + stream_id: VarInt::from_u32(42), + }; + + assert_eq!( + reader.stream_id().await.expect("reader id"), + VarInt::from_u32(41) + ); + assert_eq!( + writer.stream_id().await.expect("writer id"), + VarInt::from_u32(42) + ); + reader + .stop(VarInt::from_u32(0x31)) + .await + .expect("stop reader"); + writer + .reset(VarInt::from_u32(0x32)) + .await + .expect("reset writer"); + writer.close().await.expect("close writer"); + + assert_eq!(state.stopped_codes(), vec![VarInt::from_u32(0x31)]); + assert_eq!(state.reset_codes(), vec![VarInt::from_u32(0x32)]); + } + + #[tokio::test] + async fn test_agents_return_static_metadata_and_empty_signature() { + let local = TestLocalAuthority; + let remote = TestRemoteAuthority; + + assert_eq!(authority::LocalAuthority::name(&local), "test-local"); + assert!(authority::LocalAuthority::cert_chain(&local).is_empty()); + assert!( + authority::LocalAuthority::sign(&local, b"payload") + .await + .expect("test signer") + .is_empty() + ); + + assert_eq!(authority::RemoteAuthority::name(&remote), "test-remote"); + assert!(authority::RemoteAuthority::cert_chain(&remote).is_empty()); + } + + #[tokio::test] + async fn test_connection_trait_methods_are_pending_or_empty() { + let conn = TestConnection { + state: Arc::new(StreamState::default()), + }; + + assert!( + quic::WithLocalAuthority::local_authority(&conn) + .await + .expect("local authority") + .is_none() + ); + assert!( + quic::WithRemoteAuthority::remote_authority(&conn) + .await + .expect("remote authority") + .is_none() + ); + quic::Lifecycle::check(&conn).expect("connection is open"); + quic::Lifecycle::close( + &conn, + crate::error::Code::from(VarInt::from_u32(0)), + Cow::Borrowed("test close"), + ); + + timeout( + Duration::from_millis(10), + quic::ManageStream::accept_bi(&conn), + ) + .await + .expect_err("accept_bi should stay pending"); + timeout( + Duration::from_millis(10), + quic::ManageStream::accept_uni(&conn), + ) + .await + .expect_err("accept_uni should stay pending"); + timeout(Duration::from_millis(10), quic::Lifecycle::closed(&conn)) + .await + .expect_err("closed should stay pending"); + } + + #[tokio::test] + async fn header_fail_connection_nonfailing_mode_exercises_remaining_traits() { + use crate::quic::{GetStreamIdExt, ResetStreamExt}; + + let state = Arc::new(StreamState::default()); + let conn = HeaderFailConnection { + mode: HeaderFailureMode::SendOnCall(99), + state: Arc::clone(&state), + }; + + assert!( + quic::WithLocalAuthority::local_authority(&conn) + .await + .expect("local authority") + .is_none() + ); + assert!( + quic::WithRemoteAuthority::remote_authority(&conn) + .await + .expect("remote authority") + .is_none() + ); + quic::Lifecycle::check(&conn).expect("connection is open"); + quic::Lifecycle::close( + &conn, + crate::error::Code::from(VarInt::from_u32(0)), + Cow::Borrowed("test close"), + ); + timeout( + Duration::from_millis(10), + quic::ManageStream::accept_bi(&conn), + ) + .await + .expect_err("accept_bi should stay pending"); + timeout( + Duration::from_millis(10), + quic::ManageStream::accept_uni(&conn), + ) + .await + .expect_err("accept_uni should stay pending"); + timeout(Duration::from_millis(10), quic::Lifecycle::closed(&conn)) + .await + .expect_err("closed should stay pending"); + + let (_reader, mut writer) = quic::ManageStream::open_bi(&conn) + .await + .expect("open bidi stream"); + assert_eq!( + writer.stream_id().await.expect("writer id"), + VarInt::from_u32(21) + ); + writer + .send(Bytes::from_static(&[0xaa])) + .await + .expect("write payload"); + writer.flush().await.expect("flush writer"); + writer + .reset(VarInt::from_u32(0x44)) + .await + .expect("reset writer"); + writer.close().await.expect("close writer"); + assert_eq!(state.written(), vec![0xaa]); + } + + #[tokio::test] + async fn header_fail_write_stream_can_fail_after_prior_write() { + let stream_state = Arc::new(StreamState::default()); + let mut writer = HeaderFailWriteStream { + mode: HeaderFailureMode::SendOnCall(2), + state: Arc::clone(&stream_state), + send_calls: 0, + stream_id: VarInt::from_u32(23), + }; + + writer + .send(Bytes::from_static(&[0x40, 0x41])) + .await + .expect("first write should succeed"); + let error = writer + .send(Bytes::from_static(&[0x04])) + .await + .expect_err("second write should fail"); + assert!(matches!(error, quic::StreamError::Reset { .. })); + assert_eq!(stream_state.written(), vec![0x40, 0x41]); + } + + fn connection_with_webtransport() -> Arc> { + connection_with_webtransport_and_peer_settings(enabled_webtransport_settings()) + } + + fn enabled_webtransport_settings() -> Settings { + let mut settings = Settings::default(); + settings.set(EnableWebTransport::setting(true)); + settings.set(InitialMaxStreamsBidi::setting(VarInt::from_u32(16))); + settings.set(InitialMaxStreamsUni::setting(VarInt::from_u32(16))); + settings.set(InitialMaxData::setting(VarInt::MAX)); + settings + } + + fn connection_with_webtransport_and_peer_settings( + settings: Settings, + ) -> Arc> { + let quic = Arc::new(MockConnection::new()); + let erased: Arc = quic.clone(); + let mut protocols = Protocols::new(); + let dhttp = DHttpProtocol::new_for_test(erased.clone()); + dhttp + .state + .peer_settings + .set(Arc::new(settings)) + .expect("peer settings should be set once"); + protocols.insert(dhttp); + protocols.insert(WebTransportProtocol::new_for_test(erased)); + Arc::new(ConnectionState::new_for_test(quic, Arc::new(protocols)).erase()) + } + + fn connect_with_protocol(protocol: Option) -> EstablishedConnect { + let stream_id = StreamId::from(VarInt::from_u32(4)); + EstablishedConnect::ready( + stream_id, + protocol, + connection_with_webtransport(), + read_stream_for_test(stream_id.0), + write_stream_for_test(stream_id.0), + ) + } + + fn connect_on_connection( + connection: Arc>, + stream_id: StreamId, + protocol: Option, + ) -> EstablishedConnect { + EstablishedConnect::ready( + stream_id, + protocol, + connection, + read_stream_for_test(stream_id.0), + write_stream_for_test(stream_id.0), + ) + } + + fn connection_without_webtransport() -> Arc> { + let quic = Arc::new(MockConnection::new()); + Arc::new(ConnectionState::new_for_test(quic.clone(), Arc::new(Protocols::new())).erase()) + } + + fn wt_session_id(session_id: StreamId) -> WebTransportSessionId { + WebTransportSessionId::try_from(session_id) + .expect("test id must be a valid webtransport session id") + } + + fn noop_task() -> AbortOnDropHandle<()> { + AbortOnDropHandle::new(tokio::spawn(async {}.in_current_span())) + } + + fn noop_control_handle() -> ControlHandle { + let (tx, _rx) = mpsc::channel(1); + ControlHandle { tx } + } + + fn session_with_connection( + session_id: StreamId, + conn: Arc, + ) -> (Registry, WebTransportSession) { + let registry = Registry::default(); + let registered = registry + .register(wt_session_id(session_id)) + .expect("session registration should succeed"); + let session = WebTransportSession { + state: Arc::clone(®istered.state), + bidi_rx: tokio::sync::Mutex::new(registered.bidi_rx), + uni_rx: tokio::sync::Mutex::new(registered.uni_rx), + conn, + control: noop_control_handle(), + _control_task: noop_task(), + }; + (registry, session) + } + + fn open_session_with_closed_receivers( + session_id: StreamId, + conn: Arc, + ) -> WebTransportSession { + let registry = Registry::default(); + let registered = registry + .register(wt_session_id(session_id)) + .expect("session registration should succeed"); + let (_bidi_tx, bidi_rx) = mpsc::channel(1); + let (_uni_tx, uni_rx) = mpsc::channel(1); + + WebTransportSession { + state: Arc::clone(®istered.state), + bidi_rx: tokio::sync::Mutex::new(bidi_rx), + uni_rx: tokio::sync::Mutex::new(uni_rx), + conn, + control: noop_control_handle(), + _control_task: noop_task(), + } + } + + fn qpack_decoder_sink() -> Pin< + Box< + dyn Sink< + crate::qpack::decoder::DecoderInstruction, + Error = crate::connection::StreamError, + > + Send, + >, + > { + Box::pin( + futures::sink::drain::() + .sink_map_err(|never| match never {}), + ) + } + + fn qpack_decoder_stream() -> Pin< + Box< + dyn Stream< + Item = Result< + crate::qpack::encoder::EncoderInstruction, + crate::connection::StreamError, + >, + > + Send, + >, + > { + Box::pin(futures::stream::empty::< + Result, + >()) + } + + fn qpack_encoder_sink() -> Pin< + Box< + dyn Sink< + crate::qpack::encoder::EncoderInstruction, + Error = crate::connection::StreamError, + > + Send, + >, + > { + Box::pin( + futures::sink::drain::() + .sink_map_err(|never| match never {}), + ) + } + + fn qpack_encoder_stream() -> Pin< + Box< + dyn Stream< + Item = Result< + crate::qpack::decoder::DecoderInstruction, + crate::connection::StreamError, + >, + > + Send, + >, + > { + Box::pin(futures::stream::empty::< + Result, + >()) + } + + fn message_reader_from_quic_reader( + stream_id: VarInt, + reader: crate::quic::BoxQuicStreamReader, + ) -> MessageReader { + let quic = Arc::new(MockConnection::new()); + let erased: Arc = quic.clone(); + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased.clone())); + let state = ConnectionState::new_for_test(quic, Arc::new(protocols)).erase(); + + MessageReader::new( + stream_id, + StreamReader::new(guard::GuardQuicReader::new(reader)), + Arc::new(QPackDecoder::new( + Arc::new(Settings::default()), + qpack_decoder_sink(), + qpack_decoder_stream(), + )), + state, + ) + } + + fn message_writer_from_quic_writer(writer: crate::quic::BoxQuicStreamWriter) -> MessageWriter { + let quic = Arc::new(MockConnection::new()); + let erased: Arc = quic.clone(); + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased.clone())); + let state = ConnectionState::new_for_test(quic, Arc::new(protocols)).erase(); + + MessageWriter::new( + SinkWriter::new(guard::GuardQuicWriter::new(writer)), + Arc::new(QPackEncoder::new( + Arc::new(Settings::default()), + qpack_encoder_sink(), + qpack_encoder_stream(), + )), + state, + ) + } + + struct ConnectStreamState { + outgoing_reader: tokio::sync::Mutex, + incoming_writer: tokio::sync::Mutex, + } + + impl ConnectStreamState { + async fn inject_capsule(&self, capsule: Capsule) { + let mut writer = self.incoming_writer.lock().await; + write_control_capsule(&mut writer, capsule) + .await + .expect("injected control capsule should write"); + } + + async fn inject_close(&self, close: CloseSession) { + let payload = BufList::new() + .encode(close) + .await + .expect("close payload should encode"); + let capsule = Capsule::new(CapsuleType::WT_CLOSE_SESSION, payload) + .expect("close payload length is valid"); + self.inject_capsule(capsule).await; + } + + async fn inject_drain(&self) { + let capsule = Capsule::new(CapsuleType::WT_DRAIN_SESSION, BufList::new()) + .expect("empty drain payload length is valid"); + self.inject_capsule(capsule).await; + } + + async fn next_capsule(&self) -> Capsule { + let mut reader = self.outgoing_reader.lock().await; + let mut payload = BufList::new(); + loop { + let bytes = timeout(Duration::from_secs(1), reader.read_data_chunk()) + .await + .expect("control capsule should be written") + .expect("control data frame should decode") + .expect("control writer should not close before capsule"); + payload.write(bytes); + if let Ok(capsule) = payload.clone().decode::>().await { + return capsule; + } + } + } + + async fn next_close_capsule(&self) -> CloseSession { + let capsule = self.next_capsule().await; + assert_eq!(capsule.r#type(), CapsuleType::WT_CLOSE_SESSION); + capsule + .into_payload() + .decode::() + .await + .expect("close session capsule payload should decode") + } + + async fn next_drain_capsule(&self) { + let capsule = self.next_capsule().await; + assert_eq!(capsule.r#type(), CapsuleType::WT_DRAIN_SESSION); + assert_eq!(capsule.length(), VarInt::from_u32(0)); + } + + async fn next_streams_blocked_bidi(&self) -> WebTransportStreamCount { + let capsule = self.next_capsule().await; + assert_eq!(capsule.r#type(), CapsuleType::WT_STREAMS_BLOCKED_BIDI); + capsule + .into_payload() + .decode::() + .await + .expect("streams blocked bidi payload should decode") + } + + async fn writer_closed(&self) -> bool { + let mut reader = self.outgoing_reader.lock().await; + matches!( + timeout(Duration::from_millis(50), reader.read_data_chunk()).await, + Ok(Ok(None)) + ) + } + } + + fn webtransport_session_with_observed_connect_stream( + session_id: StreamId, + ) -> (WebTransportSession, ConnectStreamState) { + let registry = Registry::default(); + let registered = registry + .register(wt_session_id(session_id)) + .expect("session registration should succeed"); + let conn: Arc = Arc::new(TestConnection { + state: Arc::new(StreamState::default()), + }); + + let (incoming_reader, incoming_writer) = quic::test::mock_stream_pair(session_id.0); + let (outgoing_reader, outgoing_writer) = quic::test::mock_stream_pair(session_id.0); + let connect = EstablishedConnect::ready( + session_id, + Some(Protocol::new(WEBTRANSPORT_H3)), + connection_without_webtransport(), + message_reader_from_quic_reader(session_id.0, Box::pin(incoming_reader)), + message_writer_from_quic_writer(Box::pin(outgoing_writer)), + ); + let session = WebTransportSession::from_registered(registered, conn, connect); + let state = ConnectStreamState { + outgoing_reader: tokio::sync::Mutex::new(message_reader_from_quic_reader( + session_id.0, + Box::pin(outgoing_reader), + )), + incoming_writer: tokio::sync::Mutex::new(message_writer_from_quic_writer(Box::pin( + incoming_writer, + ))), + }; + + (session, state) + } + + fn stream_count(value: u32) -> WebTransportStreamCount { + WebTransportStreamCount::try_from(VarInt::from_u32(value)).expect("valid stream count") + } + + async fn max_streams_bidi_capsule(value: WebTransportStreamCount) -> Capsule { + let payload = BufList::new() + .encode(value) + .await + .expect("stream count payload should encode"); + Capsule::new(CapsuleType::WT_MAX_STREAMS_BIDI, payload) + .expect("stream count payload length is valid") + } + + fn webtransport_session_with_peer_credit( + session_id: StreamId, + peer_bidi_credit: WebTransportStreamCount, + peer_uni_credit: WebTransportStreamCount, + ) -> (WebTransportSession, ConnectStreamState) { + let registry = Registry::default(); + let registered = registry + .register(wt_session_id(session_id)) + .expect("session registration should succeed"); + let conn: Arc = Arc::new(TestConnection { + state: Arc::new(StreamState::default()), + }); + + let (incoming_reader, incoming_writer) = quic::test::mock_stream_pair(session_id.0); + let (outgoing_reader, outgoing_writer) = quic::test::mock_stream_pair(session_id.0); + let connect = EstablishedConnect::ready( + session_id, + Some(Protocol::new(WEBTRANSPORT_H3)), + connection_without_webtransport(), + message_reader_from_quic_reader(session_id.0, Box::pin(incoming_reader)), + message_writer_from_quic_writer(Box::pin(outgoing_writer)), + ); + let session = WebTransportSession::from_registered_with_peer_credit( + registered, + conn, + connect, + peer_bidi_credit, + peer_uni_credit, + ); + let state = ConnectStreamState { + outgoing_reader: tokio::sync::Mutex::new(message_reader_from_quic_reader( + session_id.0, + Box::pin(outgoing_reader), + )), + incoming_writer: tokio::sync::Mutex::new(message_writer_from_quic_writer(Box::pin( + incoming_writer, + ))), + }; + + (session, state) + } + + async fn read_stream_with_bytes(stream_id: u32, bytes: &[u8]) -> MessageReader { + let (reader, mut writer) = quic::test::mock_stream_pair(VarInt::from_u32(stream_id)); + writer + .send(Bytes::copy_from_slice(bytes)) + .await + .expect("write test stream bytes"); + writer.close().await.expect("close test stream writer"); + + let quic = Arc::new(MockConnection::new()); + let erased: Arc = quic.clone(); + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased.clone())); + let state = ConnectionState::new_for_test(quic, Arc::new(protocols)).erase(); + + MessageReader::new( + VarInt::from_u32(stream_id), + StreamReader::new(guard::GuardQuicReader::new( + Box::pin(reader) as crate::quic::BoxQuicStreamReader + )), + Arc::new(QPackDecoder::new( + Arc::new(Settings::default()), + qpack_decoder_sink(), + qpack_decoder_stream(), + )), + state, + ) + } + + async fn read_stream_with_reset(stream_id: u32, code: VarInt) -> MessageReader { + use crate::quic::ResetStreamExt; + + let (reader, mut writer) = quic::test::mock_stream_pair(VarInt::from_u32(stream_id)); + writer + .reset(code) + .await + .expect("send reset to test stream reader"); + drop(writer); + + let quic = Arc::new(MockConnection::new()); + let erased: Arc = quic.clone(); + let mut protocols = Protocols::new(); + protocols.insert(DHttpProtocol::new_for_test(erased.clone())); + let state = ConnectionState::new_for_test(quic, Arc::new(protocols)).erase(); + + MessageReader::new( + VarInt::from_u32(stream_id), + StreamReader::new(guard::GuardQuicReader::new( + Box::pin(reader) as crate::quic::BoxQuicStreamReader + )), + Arc::new(QPackDecoder::new( + Arc::new(Settings::default()), + qpack_decoder_sink(), + qpack_decoder_stream(), + )), + state, + ) + } + + async fn wait_until_session_closed(session: &WebTransportSession) { + timeout(Duration::from_secs(1), async { + loop { + if session.state.check_open().is_err() { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .expect("session should close before timeout"); + } + + fn assert_transport_reason(error: &quic::ConnectionError, expected_reason: &str) { + match error { + quic::ConnectionError::Transport { source } => { + assert_eq!(source.reason.as_ref(), expected_reason); + } + other => panic!("expected transport error, got {other:?}"), + } + } + + #[tokio::test] + async fn try_from_established_connect_registers_session() { + let session = WebTransportSession::try_from(connect_with_protocol(Some(Protocol::new( + WEBTRANSPORT_H3, + )))) + .expect("valid webtransport connect registers session"); + + assert_eq!( + session.id(), + wt_session_id(StreamId::from(VarInt::from_u32(4))) + ); + } + + #[tokio::test] + async fn try_from_rejects_invalid_connect_stream_id() { + let error = WebTransportSession::try_from(connect_on_connection( + connection_with_webtransport(), + StreamId::from(VarInt::from_u32(3)), + Some(Protocol::new(WEBTRANSPORT_H3)), + )) + .expect_err("server unidirectional stream id cannot establish a WT session"); + + let RegisterSessionError::InvalidSessionId { source } = error else { + panic!("expected invalid session id error, got {error:?}"); + }; + assert_eq!(source.session_id(), StreamId::from(VarInt::from_u32(3))); + } + + #[tokio::test] + async fn concrete_session_id_returns_proof_type() { + let session = WebTransportSession::try_from(connect_with_protocol(Some(Protocol::new( + WEBTRANSPORT_H3, + )))) + .expect("valid webtransport connect registers session"); + + let id: WebTransportSessionId = session.id(); + + assert_eq!(id.stream_id(), StreamId::from(VarInt::from_u32(4))); + } + + #[tokio::test] + async fn try_from_rejects_when_peer_settings_do_not_enable_webtransport() { + let connect = connect_on_connection( + connection_with_webtransport_and_peer_settings(Settings::default()), + StreamId::from(VarInt::from_u32(4)), + Some(Protocol::new(WEBTRANSPORT_H3)), + ); + + let error = WebTransportSession::try_from(connect) + .expect_err("WT session must require peer WebTransport settings"); + + assert!(matches!( + error, + RegisterSessionError::WebTransportNotEnabled + )); + } + + #[tokio::test] + async fn try_from_accepts_when_peer_settings_enable_webtransport_and_stream_credit() { + let connect = connect_on_connection( + connection_with_webtransport_and_peer_settings(enabled_webtransport_settings()), + StreamId::from(VarInt::from_u32(4)), + Some(Protocol::new(WEBTRANSPORT_H3)), + ); + + let session = WebTransportSession::try_from(connect).expect("negotiated WT session"); + + assert_eq!( + session.id().stream_id(), + StreamId::from(VarInt::from_u32(4)) + ); + } + + #[tokio::test] + async fn try_from_rejects_missing_protocol() { + let error = WebTransportSession::try_from(connect_with_protocol(None)) + .expect_err("missing protocol is invalid"); + assert!(matches!(error, RegisterSessionError::MissingProtocol)); + } + + #[tokio::test] + async fn try_from_rejects_unexpected_protocol() { + let error = WebTransportSession::try_from(connect_with_protocol(Some(Protocol::new( + "other-protocol", + )))) + .expect_err("wrong protocol token is invalid"); + assert!(matches!( + error, + RegisterSessionError::UnexpectedProtocol { .. } + )); + } + + #[tokio::test] + async fn try_from_rejects_missing_protocol_layer() { + let error = WebTransportSession::try_from(connect_on_connection( + connection_without_webtransport(), + StreamId::from(VarInt::from_u32(4)), + Some(Protocol::new(WEBTRANSPORT_H3)), + )) + .expect_err("missing webtransport protocol layer is invalid"); + + assert!(matches!(error, RegisterSessionError::ProtocolLayerMissing)); + } + + #[tokio::test] + async fn try_from_rejects_duplicate_session_registration() { + let connection = connection_with_webtransport(); + let stream_id = StreamId::from(VarInt::from_u32(4)); + + let first = WebTransportSession::try_from(connect_on_connection( + Arc::clone(&connection), + stream_id, + Some(Protocol::new(WEBTRANSPORT_H3)), + )) + .expect("first registration should succeed"); + + let error = WebTransportSession::try_from(connect_on_connection( + connection, + stream_id, + Some(Protocol::new(WEBTRANSPORT_H3)), + )) + .expect_err("duplicate session id should be rejected"); + + assert!(matches!( + error, + RegisterSessionError::AlreadyRegistered { session_id } if session_id == wt_session_id(stream_id) + )); + + drop(first); + } + + #[tokio::test] + async fn debug_includes_session_id() { + let conn: Arc = Arc::new(TestConnection { + state: Arc::new(StreamState::default()), + }); + let (_registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), conn); + + let formatted = format!("{session:?}"); + assert!(formatted.contains("WebTransportSession")); + assert!(formatted.contains(&format!("{:?}", session.id()))); + } + + #[tokio::test] + async fn open_bi_writes_webtransport_header_and_keeps_writer_usable() { + let stream_state = Arc::new(StreamState::default()); + let conn: Arc = Arc::new(TestConnection { + state: Arc::clone(&stream_state), + }); + let (_registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), conn); + + let (_reader, mut writer) = session.open_bi().await.expect("open_bi should succeed"); + writer + .send(Bytes::from_static(&[0xaa, 0xbb])) + .await + .expect("writer should remain usable after header"); + + assert_eq!(stream_state.written(), vec![0x40, 0x41, 0x04, 0xaa, 0xbb]); + assert!(stream_state.flushes() >= 1); + assert!(stream_state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn open_bi_maps_header_write_failure_on_first_signal_write() { + let stream_state = Arc::new(StreamState::default()); + let conn: Arc = Arc::new(HeaderFailConnection { + mode: HeaderFailureMode::SendOnCall(1), + state: Arc::clone(&stream_state), + }); + let (_registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), conn); + + assert!(matches!( + session.open_bi().await, + Err(OpenStreamError::WriteHeader { .. }) + )); + assert!(stream_state.written().is_empty()); + } + + #[tokio::test] + async fn open_uni_writes_webtransport_header_and_keeps_writer_usable() { + let stream_state = Arc::new(StreamState::default()); + let conn: Arc = Arc::new(TestConnection { + state: Arc::clone(&stream_state), + }); + let (_registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), conn); + + let mut writer = session.open_uni().await.expect("open_uni should succeed"); + writer + .send(Bytes::from_static(&[0xcc])) + .await + .expect("writer should remain usable after header"); + + assert_eq!(stream_state.written(), vec![0x40, 0x54, 0x04, 0xcc]); + assert!(stream_state.flushes() >= 1); + assert!(stream_state.reset_codes().is_empty()); + } + + #[tokio::test] + async fn open_uni_maps_buffered_header_write_failure() { + let stream_state = Arc::new(StreamState::default()); + let conn: Arc = Arc::new(HeaderFailConnection { + mode: HeaderFailureMode::SendOnCall(1), + state: Arc::clone(&stream_state), + }); + let (_registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), conn); + + assert!(matches!( + session.open_uni().await, + Err(OpenStreamError::WriteHeader { .. }) + )); + assert!(stream_state.written().is_empty()); + } + + #[tokio::test] + async fn open_bi_maps_header_flush_failure() { + let stream_state = Arc::new(StreamState::default()); + let conn: Arc = Arc::new(HeaderFailConnection { + mode: HeaderFailureMode::Flush, + state: Arc::clone(&stream_state), + }); + let (_registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), conn); + + assert!(matches!( + session.open_bi().await, + Err(OpenStreamError::WriteHeader { .. }) + )); + assert_eq!(stream_state.written(), vec![0x40, 0x41, 0x04]); + } + + #[tokio::test] + async fn open_uni_maps_header_flush_failure() { + let stream_state = Arc::new(StreamState::default()); + let conn: Arc = Arc::new(HeaderFailConnection { + mode: HeaderFailureMode::Flush, + state: Arc::clone(&stream_state), + }); + let (_registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), conn); + + assert!(matches!( + session.open_uni().await, + Err(OpenStreamError::WriteHeader { .. }) + )); + assert_eq!(stream_state.written(), vec![0x40, 0x54, 0x04]); + } + + #[tokio::test] + async fn open_streams_map_connection_errors() { + let conn = Arc::new(MockConnection::new()); + let erased: Arc = conn.clone(); + let (_registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), erased); + + let bi_error = match session.open_bi().await { + Ok(_) => panic!("open_bi should fail"), + Err(error) => error, + }; + let uni_error = match session.open_uni().await { + Ok(_) => panic!("open_uni should fail"), + Err(error) => error, + }; + + match bi_error { + OpenStreamError::Open { source } => { + assert_transport_reason(&source, "open_bi unavailable"); + } + other => panic!("expected open error, got {other:?}"), + } + + match uni_error { + OpenStreamError::Open { source } => { + assert_transport_reason(&source, "open_uni unavailable"); + } + other => panic!("expected open error, got {other:?}"), + } + } + + #[tokio::test] + async fn closed_session_maps_open_and_accept_to_closed_errors() { + let conn = Arc::new(MockConnection::new()); + let erased: Arc = conn.clone(); + let (_registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), erased); + session.state.close(); + + assert!(matches!( + session.open_bi().await, + Err(OpenStreamError::Closed { .. }) + )); + assert!(matches!( + session.open_uni().await, + Err(OpenStreamError::Closed { .. }) + )); + assert!(matches!( + session.accept_bi().await, + Err(AcceptStreamError::Closed { .. }) + )); + assert!(matches!( + session.accept_uni().await, + Err(AcceptStreamError::Closed { .. }) + )); + assert!(conn.stream_calls().is_empty()); + } + + #[tokio::test] + async fn accept_streams_map_dropped_receivers_to_closed_errors() { + let conn: Arc = Arc::new(TestConnection { + state: Arc::new(StreamState::default()), + }); + let session = open_session_with_closed_receivers(StreamId::from(VarInt::from_u32(4)), conn); + + assert!(matches!( + session.accept_bi().await, + Err(AcceptStreamError::Closed { .. }) + )); + assert!(matches!( + session.accept_uni().await, + Err(AcceptStreamError::Closed { .. }) + )); + } + + #[tokio::test] + async fn drop_unregisters_session_and_rejects_late_routed_streams() { + let conn: Arc = Arc::new(TestConnection { + state: Arc::new(StreamState::default()), + }); + let (registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), conn); + assert_eq!(registry.len(), 1); + + let session_id = session.id(); + drop(session); + + let routed_state = Arc::new(StreamState::default()); + let rejected = registry.route_uni( + session_id, + Box::pin(TestReadStream { + state: Arc::clone(&routed_state), + chunks: VecDeque::from([Bytes::from_static(&[0x55])]), + stream_id: VarInt::from_u32(13), + }) as BoxQuicStreamReader, + ); + + assert!(rejected.is_err()); + assert_eq!(registry.len(), 0); + assert!(routed_state.stopped_codes().is_empty()); + } + + #[tokio::test] + async fn accept_bi_returns_routed_stream() { + let stream_state = Arc::new(StreamState::default()); + let conn: Arc = Arc::new(TestConnection { + state: Arc::clone(&stream_state), + }); + let (registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), conn); + let routed_state = Arc::new(StreamState::default()); + + assert!( + registry + .route_bi( + session.id(), + ( + Box::pin(TestReadStream { + state: Arc::clone(&routed_state), + chunks: VecDeque::from([Bytes::from_static(&[0x11, 0x22])]), + stream_id: VarInt::from_u32(11), + }) as BoxQuicStreamReader, + Box::pin(TestWriteStream { + state: Arc::clone(&routed_state), + stream_id: VarInt::from_u32(11), + }) as BoxQuicStreamWriter, + ), + ) + .is_ok() + ); + + let (mut reader, mut writer) = session.accept_bi().await.expect("accept_bi should succeed"); + writer + .send(Bytes::from_static(&[0x33])) + .await + .expect("routed writer should stay usable"); + + assert_eq!( + reader + .next() + .await + .expect("reader should yield a chunk") + .expect("reader chunk should succeed"), + Bytes::from_static(&[0x11, 0x22]) + ); + assert_eq!(routed_state.written(), vec![0x33]); + assert!(stream_state.written().is_empty()); + } + + #[tokio::test] + async fn session_close_stops_and_resets_tracked_bidi_stream() { + let stream_state = Arc::new(StreamState::default()); + let conn: Arc = Arc::new(TestConnection { + state: Arc::clone(&stream_state), + }); + let (registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), conn); + let routed_state = Arc::new(StreamState::default()); + + assert!( + registry + .route_bi( + session.id(), + ( + Box::pin(TestReadStream { + state: Arc::clone(&routed_state), + chunks: VecDeque::from([Bytes::from_static(&[0x11])]), + stream_id: VarInt::from_u32(31), + }) as BoxQuicStreamReader, + Box::pin(TestWriteStream { + state: Arc::clone(&routed_state), + stream_id: VarInt::from_u32(31), + }) as BoxQuicStreamWriter, + ), + ) + .is_ok() + ); + + let (_reader, _writer) = session.accept_bi().await.expect("accept_bi should succeed"); + session.state.close(); + tokio::task::yield_now().await; + + assert_eq!( + routed_state.stopped_codes(), + vec![Code::WT_SESSION_GONE.into_inner()] + ); + assert_eq!( + routed_state.reset_codes(), + vec![Code::WT_SESSION_GONE.into_inner()] + ); + } + + #[tokio::test] + async fn reader_eof_removes_only_reader_tracking_before_session_close() { + let stream_state = Arc::new(StreamState::default()); + let conn: Arc = Arc::new(TestConnection { + state: Arc::clone(&stream_state), + }); + let (registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), conn); + let routed_state = Arc::new(StreamState::default()); + + assert!( + registry + .route_bi( + session.id(), + ( + Box::pin(TestReadStream { + state: Arc::clone(&routed_state), + chunks: VecDeque::from([Bytes::from_static(&[0x11])]), + stream_id: VarInt::from_u32(32), + }) as BoxQuicStreamReader, + Box::pin(TestWriteStream { + state: Arc::clone(&routed_state), + stream_id: VarInt::from_u32(32), + }) as BoxQuicStreamWriter, + ), + ) + .is_ok() + ); + + let (mut reader, _writer) = session.accept_bi().await.expect("accept_bi should succeed"); + assert_eq!( + reader + .next() + .await + .expect("chunk") + .expect("read should work"), + Bytes::from_static(&[0x11]) + ); + assert!(reader.next().await.is_none()); + + session.state.close(); + tokio::task::yield_now().await; + + assert!(routed_state.stopped_codes().is_empty()); + assert_eq!( + routed_state.reset_codes(), + vec![Code::WT_SESSION_GONE.into_inner()] + ); + } + + #[tokio::test] + async fn accept_uni_returns_routed_stream() { + let conn: Arc = Arc::new(TestConnection { + state: Arc::new(StreamState::default()), + }); + let (registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), conn); + let routed_state = Arc::new(StreamState::default()); + + assert!( + registry + .route_uni( + session.id(), + Box::pin(TestReadStream { + state: Arc::clone(&routed_state), + chunks: VecDeque::from([Bytes::from_static(&[0x44])]), + stream_id: VarInt::from_u32(12), + }) as BoxQuicStreamReader, + ) + .is_ok() + ); + + let mut reader = session + .accept_uni() + .await + .expect("accept_uni should succeed"); + assert_eq!( + reader + .next() + .await + .expect("reader should yield a chunk") + .expect("reader chunk should succeed"), + Bytes::from_static(&[0x44]) + ); + assert!(routed_state.stopped_codes().is_empty()); + } + + #[tokio::test] + async fn accept_bi_maps_connection_closed() { + let conn = Arc::new(MockConnection::new()); + conn.set_terminal_error(quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(1), + frame_type: VarInt::from_u32(0), + reason: "connection closed".into(), + }, + }); + let erased: Arc = conn; + let (_registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), erased); + + match session.accept_bi().await { + Ok(_) => panic!("accept_bi should fail"), + Err(error) => match error { + AcceptStreamError::Connection { source } => { + assert_transport_reason(&source, "connection closed"); + } + other => panic!("expected connection error, got {other:?}"), + }, + } + } + + #[tokio::test] + async fn accept_bi_prefers_connection_closed_over_queued_stream() { + let conn = Arc::new(MockConnection::new()); + conn.set_terminal_error(quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(1), + frame_type: VarInt::from_u32(0), + reason: "connection closed with queued bidi".into(), + }, + }); + let erased: Arc = conn; + let (registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), erased); + let routed_state = Arc::new(StreamState::default()); + assert!( + registry + .route_bi( + session.id(), + ( + Box::pin(TestReadStream { + state: Arc::clone(&routed_state), + chunks: VecDeque::from([Bytes::from_static(&[0x66])]), + stream_id: VarInt::from_u32(14), + }) as BoxQuicStreamReader, + Box::pin(TestWriteStream { + state: Arc::clone(&routed_state), + stream_id: VarInt::from_u32(14), + }) as BoxQuicStreamWriter, + ), + ) + .is_ok() + ); + + match session.accept_bi().await { + Ok(_) => panic!("accept_bi should prefer closed connection"), + Err(AcceptStreamError::Connection { source }) => { + assert_transport_reason(&source, "connection closed with queued bidi"); + } + Err(other) => panic!("expected connection error, got {other:?}"), + } + assert_eq!(registry.len(), 0); + } + + #[tokio::test] + async fn accept_uni_prefers_connection_closed_over_queued_stream() { + let conn = Arc::new(MockConnection::new()); + conn.set_terminal_error(quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(1), + frame_type: VarInt::from_u32(0), + reason: "connection closed with queued uni".into(), + }, + }); + let erased: Arc = conn; + let (registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), erased); + let routed_state = Arc::new(StreamState::default()); + assert!( + registry + .route_uni( + session.id(), + Box::pin(TestReadStream { + state: Arc::clone(&routed_state), + chunks: VecDeque::from([Bytes::from_static(&[0x77])]), + stream_id: VarInt::from_u32(15), + }) as BoxQuicStreamReader, + ) + .is_ok() + ); + + match session.accept_uni().await { + Ok(_) => panic!("accept_uni should prefer closed connection"), + Err(AcceptStreamError::Connection { source }) => { + assert_transport_reason(&source, "connection closed with queued uni"); + } + Err(other) => panic!("expected connection error, got {other:?}"), + } + assert_eq!(registry.len(), 0); + } + + #[tokio::test] + async fn accept_uni_maps_connection_closed() { + let conn = Arc::new(MockConnection::new()); + conn.set_terminal_error(quic::ConnectionError::Transport { + source: quic::TransportError { + kind: VarInt::from_u32(1), + frame_type: VarInt::from_u32(0), + reason: "connection closed".into(), + }, + }); + let erased: Arc = conn; + let (_registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), erased); + + match session.accept_uni().await { + Ok(_) => panic!("accept_uni should fail"), + Err(error) => match error { + AcceptStreamError::Connection { source } => { + assert_transport_reason(&source, "connection closed"); + } + other => panic!("expected connection error, got {other:?}"), + }, + } + } + + #[tokio::test] + async fn datagram_methods_return_unsupported() { + let conn: Arc = Arc::new(TestConnection { + state: Arc::new(StreamState::default()), + }); + let (_registry, session) = + session_with_connection(StreamId::from(VarInt::from_u32(4)), conn); + + assert!(matches!( + session + .send_datagram(b"hello") + .await + .expect_err("send_datagram should be unsupported"), + DatagramError::Unsupported + )); + assert!(matches!( + session + .recv_datagram() + .await + .expect_err("recv_datagram should be unsupported"), + DatagramError::Unsupported + )); + } + + #[tokio::test] + async fn close_sends_close_capsule_finishes_connect_stream_and_closes_session() { + let (session, connect_state) = + webtransport_session_with_observed_connect_stream(StreamId::from(VarInt::from_u32(4))); + let close = CloseSession::try_from((7_u32, "done")).expect("valid close"); + + session.close(close.clone()).await.expect("close succeeds"); + + assert!(matches!(session.state.check_open(), Err(SessionClosed))); + assert_eq!(connect_state.next_close_capsule().await, close); + assert!(connect_state.writer_closed().await); + } + + #[tokio::test] + async fn drain_sends_drain_capsule_without_closing_session() { + let (session, connect_state) = + webtransport_session_with_observed_connect_stream(StreamId::from(VarInt::from_u32(4))); + + session.drain().await.expect("drain succeeds"); + + assert!(session.state.check_open().is_ok()); + connect_state.next_drain_capsule().await; + assert!(!connect_state.writer_closed().await); + } + + #[tokio::test] + async fn remote_close_capsule_closes_session_and_reports_remote_reason() { + let (session, connect_state) = + webtransport_session_with_observed_connect_stream(StreamId::from(VarInt::from_u32(4))); + let close = CloseSession::try_from((9_u32, "remote done")).expect("valid close"); + + connect_state.inject_close(close.clone()).await; + let reason = timeout(Duration::from_secs(1), session.closed()) + .await + .expect("remote close should close session"); + + assert_eq!( + reason, + CloseReason::Session(SessionCloseReason::Remote(close)) + ); + assert!(matches!( + session.open_uni().await, + Err(OpenStreamError::Closed { .. }) + )); + } + + #[tokio::test] + async fn remote_drain_capsule_wakes_drained_without_closing_session() { + let (session, connect_state) = + webtransport_session_with_observed_connect_stream(StreamId::from(VarInt::from_u32(4))); + + connect_state.inject_drain().await; + let drain = timeout(Duration::from_secs(1), session.drained()) + .await + .expect("remote drain should update drain state"); + + assert_eq!( + drain, + SessionDrain::Requested(DrainReason::Session(SessionDrainReason::Remote)) + ); + assert!(session.state.check_open().is_ok()); + } + + #[tokio::test] + async fn open_bi_sends_streams_blocked_and_waits_until_peer_max_streams_increases() { + let (session, connect_state) = webtransport_session_with_peer_credit( + StreamId::from(VarInt::from_u32(4)), + WebTransportStreamCount::ZERO, + WebTransportStreamCount::ZERO, + ); + + let open = session.open_bi(); + tokio::pin!(open); + timeout(Duration::from_millis(25), &mut open) + .await + .expect_err("open_bi should wait for peer stream credit"); + + assert_eq!( + connect_state.next_streams_blocked_bidi().await, + WebTransportStreamCount::ZERO + ); + + connect_state + .inject_capsule(max_streams_bidi_capsule(stream_count(1)).await) + .await; + let (_reader, _writer) = timeout(Duration::from_secs(1), open) + .await + .expect("open_bi should unblock after peer grants credit") + .expect("open_bi should succeed after peer grants credit"); + } + + #[tokio::test] + async fn decreasing_peer_max_streams_closes_session_with_flow_control_error() { + let (session, connect_state) = webtransport_session_with_peer_credit( + StreamId::from(VarInt::from_u32(4)), + stream_count(2), + WebTransportStreamCount::ZERO, + ); + + connect_state + .inject_capsule(max_streams_bidi_capsule(stream_count(1)).await) + .await; + + let reason = timeout(Duration::from_secs(1), session.closed()) + .await + .expect("decreasing max streams should close session"); + assert_eq!( + reason, + CloseReason::Session(SessionCloseReason::Protocol { + code: Code::WT_FLOW_CONTROL_ERROR + }) + ); + } + + #[tokio::test] + async fn control_task_closes_session_after_connect_stream_payload_and_eof() { + let registry = Registry::default(); + let session_id = StreamId::from(VarInt::from_u32(24)); + let registered = registry + .register(wt_session_id(session_id)) + .expect("session registration should succeed"); + let conn: Arc = Arc::new(TestConnection { + state: Arc::new(StreamState::default()), + }); + let connect = EstablishedConnect::ready( + session_id, + Some(Protocol::new(WEBTRANSPORT_H3)), + connection_without_webtransport(), + read_stream_with_bytes(24, &[0x00, 0x03, b'a', b'b', b'c']).await, + write_stream_for_test(session_id.0), + ); + + let session = WebTransportSession::from_registered(registered, conn, connect); + wait_until_session_closed(&session).await; + assert_eq!(registry.len(), 0); + } + + #[tokio::test] + async fn control_task_closes_session_after_connect_stream_read_error() { + let _guard = enable_debug_tracing_for_test(); + let registry = Registry::default(); + let session_id = StreamId::from(VarInt::from_u32(28)); + let registered = registry + .register(wt_session_id(session_id)) + .expect("session registration should succeed"); + let conn: Arc = Arc::new(TestConnection { + state: Arc::new(StreamState::default()), + }); + let connect = EstablishedConnect::ready( + session_id, + Some(Protocol::new(WEBTRANSPORT_H3)), + connection_without_webtransport(), + read_stream_with_reset(28, VarInt::from_u32(0xaa)).await, + write_stream_for_test(session_id.0), + ); + + let session = WebTransportSession::from_registered(registered, conn, connect); + wait_until_session_closed(&session).await; + assert_eq!(registry.len(), 0); + } + + #[tokio::test] + async fn control_task_closes_session_when_connect_takeover_fails() { + let _guard = enable_debug_tracing_for_test(); + let registry = Registry::default(); + let session_id = StreamId::from(VarInt::from_u32(32)); + let registered = registry + .register(wt_session_id(session_id)) + .expect("session registration should succeed"); + let conn: Arc = Arc::new(TestConnection { + state: Arc::new(StreamState::default()), + }); + let connect = EstablishedConnect::pending( + session_id, + Some(Protocol::new(WEBTRANSPORT_H3)), + connection_without_webtransport(), + read_stream_for_test(session_id.0), + future::ready(Err( + crate::extended_connect::PendingWriteStreamError::Aborted, + )), + ); + + let session = WebTransportSession::from_registered(registered, conn, connect); + wait_until_session_closed(&session).await; + assert_eq!(registry.len(), 0); + } } diff --git a/src/webtransport/session/stream.rs b/src/webtransport/session/stream.rs new file mode 100644 index 0000000..37924e3 --- /dev/null +++ b/src/webtransport/session/stream.rs @@ -0,0 +1,1137 @@ +use std::{ + fmt, + pin::Pin, + sync::{Arc, Mutex, Weak}, + task::{Context, Poll, Waker}, +}; + +use bytes::Bytes; +use futures::{Sink, Stream, task::AtomicWaker}; + +use crate::{ + quic::{self, BoxQuicStreamReader, BoxQuicStreamWriter, GetStreamId, ResetStream, StopStream}, + stream_id::StreamId, + varint::VarInt, + webtransport::registry::SessionState, +}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum TrackingState { + Tracked, + Untracked, +} + +struct InFlightWakers { + handle: AtomicWaker, + tracked: AtomicWaker, +} + +impl InFlightWakers { + fn new() -> Self { + Self { + handle: AtomicWaker::new(), + tracked: AtomicWaker::new(), + } + } + + fn wake_all(&self) { + self.handle.wake(); + self.tracked.wake(); + } +} + +struct InFlightControl { + code: VarInt, + wakers: Arc, +} + +struct ReaderCore { + stream_id: StreamId, + stream: S, + session: Weak, + tracking: TrackingState, + stop: Option, + stop_result: Option>, +} + +struct WriterCore { + stream_id: StreamId, + stream: S, + session: Weak, + tracking: TrackingState, + reset: Option, + reset_result: Option>, +} + +pub struct WebTransportStreamReader { + core: Arc>>, +} + +pub(in crate::webtransport) struct TrackedStreamReader { + core: Arc>>, +} + +pub struct WebTransportStreamWriter { + core: Arc>>, +} + +pub(in crate::webtransport) struct TrackedStreamWriter { + core: Arc>>, +} + +impl WebTransportStreamReader { + pub(super) fn tracked( + stream_id: StreamId, + stream: S, + session: Weak, + ) -> (Self, TrackedStreamReader) { + let core = Arc::new(Mutex::new(ReaderCore { + stream_id, + stream, + session, + tracking: TrackingState::Tracked, + stop: None, + stop_result: None, + })); + ( + Self { + core: Arc::clone(&core), + }, + TrackedStreamReader { core }, + ) + } +} + +impl WebTransportStreamWriter { + pub(super) fn tracked( + stream_id: StreamId, + stream: S, + session: Weak, + ) -> (Self, TrackedStreamWriter) { + let core = Arc::new(Mutex::new(WriterCore { + stream_id, + stream, + session, + tracking: TrackingState::Tracked, + reset: None, + reset_result: None, + })); + ( + Self { + core: Arc::clone(&core), + }, + TrackedStreamWriter { core }, + ) + } +} + +impl fmt::Debug for WebTransportStreamReader { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.core.lock() { + Ok(core) => f + .debug_struct("WebTransportStreamReader") + .field("stream_id", &core.stream_id) + .field("tracking", &core.tracking) + .finish(), + Err(_) => f + .debug_struct("WebTransportStreamReader") + .field("poisoned", &true) + .finish(), + } + } +} + +impl fmt::Debug for TrackedStreamReader { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.core.lock() { + Ok(core) => f + .debug_struct("TrackedStreamReader") + .field("stream_id", &core.stream_id) + .field("tracking", &core.tracking) + .finish(), + Err(_) => f + .debug_struct("TrackedStreamReader") + .field("poisoned", &true) + .finish(), + } + } +} + +impl fmt::Debug for WebTransportStreamWriter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.core.lock() { + Ok(core) => f + .debug_struct("WebTransportStreamWriter") + .field("stream_id", &core.stream_id) + .field("tracking", &core.tracking) + .finish(), + Err(_) => f + .debug_struct("WebTransportStreamWriter") + .field("poisoned", &true) + .finish(), + } + } +} + +impl fmt::Debug for TrackedStreamWriter { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.core.lock() { + Ok(core) => f + .debug_struct("TrackedStreamWriter") + .field("stream_id", &core.stream_id) + .field("tracking", &core.tracking) + .finish(), + Err(_) => f + .debug_struct("TrackedStreamWriter") + .field("poisoned", &true) + .finish(), + } + } +} + +fn mark_reader_untracked( + core: &Arc>>, +) -> Option<(Weak, StreamId)> { + let mut core = core.lock().expect("webtransport reader core lock poisoned"); + if core.tracking == TrackingState::Tracked { + core.tracking = TrackingState::Untracked; + Some((core.session.clone(), core.stream_id)) + } else { + None + } +} + +fn remove_reader_tracking(core: &Arc>>) { + let Some((session, stream_id)) = mark_reader_untracked(core) else { + return; + }; + if let Some(session) = session.upgrade() { + session.remove_tracked_reader(stream_id); + } +} + +fn remove_reader_tracking_if_untracked(core: &Arc>>) { + let (session, stream_id, untracked) = { + let core = core.lock().expect("webtransport reader core lock poisoned"); + ( + core.session.clone(), + core.stream_id, + core.tracking == TrackingState::Untracked, + ) + }; + if untracked && let Some(session) = session.upgrade() { + session.remove_tracked_reader(stream_id); + } +} + +fn mark_writer_untracked( + core: &Arc>>, +) -> Option<(Weak, StreamId)> { + let mut core = core.lock().expect("webtransport writer core lock poisoned"); + if core.tracking == TrackingState::Tracked { + core.tracking = TrackingState::Untracked; + Some((core.session.clone(), core.stream_id)) + } else { + None + } +} + +fn remove_writer_tracking(core: &Arc>>) { + let Some((session, stream_id)) = mark_writer_untracked(core) else { + return; + }; + if let Some(session) = session.upgrade() { + session.remove_tracked_writer(stream_id); + } +} + +fn remove_writer_tracking_if_untracked(core: &Arc>>) { + let (session, stream_id, untracked) = { + let core = core.lock().expect("webtransport writer core lock poisoned"); + ( + core.session.clone(), + core.stream_id, + core.tracking == TrackingState::Untracked, + ) + }; + if untracked && let Some(session) = session.upgrade() { + session.remove_tracked_writer(stream_id); + } +} + +impl Drop for WebTransportStreamReader { + fn drop(&mut self) { + remove_reader_tracking_if_untracked(&self.core); + } +} + +impl Drop for WebTransportStreamWriter { + fn drop(&mut self) { + remove_writer_tracking_if_untracked(&self.core); + } +} + +struct AggregateWake { + wakers: Arc, +} + +impl std::task::Wake for AggregateWake { + fn wake(self: Arc) { + self.wakers.wake_all(); + } + + fn wake_by_ref(self: &Arc) { + self.wakers.wake_all(); + } +} + +fn with_aggregate_context( + wakers: Arc, + f: impl FnOnce(&mut Context<'_>) -> T, +) -> T { + let waker = Waker::from(Arc::new(AggregateWake { wakers })); + let mut cx = Context::from_waker(&waker); + f(&mut cx) +} + +impl GetStreamId for WebTransportStreamReader { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let core = self + .core + .lock() + .expect("webtransport reader core lock poisoned"); + Poll::Ready(Ok(core.stream_id.into())) + } +} + +impl GetStreamId for TrackedStreamReader { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let core = self + .core + .lock() + .expect("webtransport reader core lock poisoned"); + Poll::Ready(Ok(core.stream_id.into())) + } +} + +impl GetStreamId for WebTransportStreamWriter { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let core = self + .core + .lock() + .expect("webtransport writer core lock poisoned"); + Poll::Ready(Ok(core.stream_id.into())) + } +} + +impl GetStreamId for TrackedStreamWriter { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let core = self + .core + .lock() + .expect("webtransport writer core lock poisoned"); + Poll::Ready(Ok(core.stream_id.into())) + } +} + +impl Stream for WebTransportStreamReader +where + S: Stream> + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let poll = { + let mut core = self + .core + .lock() + .expect("webtransport reader core lock poisoned"); + Pin::new(&mut core.stream).poll_next(cx) + }; + + match &poll { + Poll::Ready(None | Some(Err(_))) => remove_reader_tracking(&self.core), + Poll::Ready(Some(Ok(_))) | Poll::Pending => {} + } + + poll + } +} + +impl Sink for WebTransportStreamWriter +where + S: Sink + Unpin, +{ + type Error = quic::StreamError; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let poll = { + let mut core = self + .core + .lock() + .expect("webtransport writer core lock poisoned"); + Pin::new(&mut core.stream).poll_ready(cx) + }; + if matches!(poll, Poll::Ready(Err(_))) { + remove_writer_tracking(&self.core); + } + poll + } + + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + let result = { + let mut core = self + .core + .lock() + .expect("webtransport writer core lock poisoned"); + Pin::new(&mut core.stream).start_send(item) + }; + if result.is_err() { + remove_writer_tracking(&self.core); + } + result + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let poll = { + let mut core = self + .core + .lock() + .expect("webtransport writer core lock poisoned"); + Pin::new(&mut core.stream).poll_flush(cx) + }; + if matches!(poll, Poll::Ready(Err(_))) { + remove_writer_tracking(&self.core); + } + poll + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let poll = { + let mut core = self + .core + .lock() + .expect("webtransport writer core lock poisoned"); + Pin::new(&mut core.stream).poll_close(cx) + }; + match poll { + Poll::Ready(Ok(()) | Err(_)) => remove_writer_tracking(&self.core), + Poll::Pending => {} + } + poll + } +} + +fn poll_reader_stop( + core: &Arc>>, + cx: &mut Context<'_>, + code: VarInt, + tracked_side: bool, +) -> Poll> +where + S: StopStream + Unpin, +{ + let (committed_code, wakers) = { + let mut core = core.lock().expect("webtransport reader core lock poisoned"); + if let Some(result) = core.stop_result.clone() { + return Poll::Ready(result); + } + let control = core.stop.get_or_insert_with(|| InFlightControl { + code, + wakers: Arc::new(InFlightWakers::new()), + }); + if tracked_side { + control.wakers.tracked.register(cx.waker()); + } else { + control.wakers.handle.register(cx.waker()); + } + (control.code, Arc::clone(&control.wakers)) + }; + + let poll_wakers = Arc::clone(&wakers); + let poll = with_aggregate_context(poll_wakers, |cx| { + let mut core = core.lock().expect("webtransport reader core lock poisoned"); + Pin::new(&mut core.stream).poll_stop(cx, committed_code) + }); + + match poll { + Poll::Ready(result) => { + { + let mut core = core.lock().expect("webtransport reader core lock poisoned"); + core.stop = None; + core.stop_result = Some(result.clone()); + } + remove_reader_tracking(core); + wakers.wake_all(); + Poll::Ready(result) + } + Poll::Pending => Poll::Pending, + } +} + +impl StopStream for WebTransportStreamReader +where + S: StopStream + Unpin, +{ + fn poll_stop( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + poll_reader_stop(&self.core, cx, code, false) + } +} + +impl StopStream for TrackedStreamReader +where + S: StopStream + Unpin, +{ + fn poll_stop( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + poll_reader_stop(&self.core, cx, code, true) + } +} + +fn poll_writer_reset( + core: &Arc>>, + cx: &mut Context<'_>, + code: VarInt, + tracked_side: bool, +) -> Poll> +where + S: ResetStream + Unpin, +{ + let (committed_code, wakers) = { + let mut core = core.lock().expect("webtransport writer core lock poisoned"); + if let Some(result) = core.reset_result.clone() { + return Poll::Ready(result); + } + let control = core.reset.get_or_insert_with(|| InFlightControl { + code, + wakers: Arc::new(InFlightWakers::new()), + }); + if tracked_side { + control.wakers.tracked.register(cx.waker()); + } else { + control.wakers.handle.register(cx.waker()); + } + (control.code, Arc::clone(&control.wakers)) + }; + + let poll_wakers = Arc::clone(&wakers); + let poll = with_aggregate_context(poll_wakers, |cx| { + let mut core = core.lock().expect("webtransport writer core lock poisoned"); + Pin::new(&mut core.stream).poll_reset(cx, committed_code) + }); + + match poll { + Poll::Ready(result) => { + { + let mut core = core.lock().expect("webtransport writer core lock poisoned"); + core.reset = None; + core.reset_result = Some(result.clone()); + } + remove_writer_tracking(core); + wakers.wake_all(); + Poll::Ready(result) + } + Poll::Pending => Poll::Pending, + } +} + +impl ResetStream for WebTransportStreamWriter +where + S: ResetStream + Unpin, +{ + fn poll_reset( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + poll_writer_reset(&self.core, cx, code, false) + } +} + +impl ResetStream for TrackedStreamWriter +where + S: ResetStream + Unpin, +{ + fn poll_reset( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + poll_writer_reset(&self.core, cx, code, true) + } +} + +#[cfg(test)] +mod tests { + use std::{ + collections::VecDeque, + pin::Pin, + sync::{ + Arc, Mutex, + atomic::{AtomicBool, AtomicUsize, Ordering}, + }, + task::{Context, Poll, Wake, Waker}, + }; + + use bytes::Bytes; + use futures::{Sink, SinkExt, Stream, StreamExt, task::noop_waker_ref}; + + use super::*; + use crate::{ + quic::{self, ResetStream, StopStream}, + stream_id::StreamId, + varint::VarInt, + }; + + #[derive(Debug)] + struct TestReadStream { + chunks: VecDeque>, + stream_id: VarInt, + stopped: Arc>>, + } + + impl TestReadStream { + fn with_chunks(stream_id: u32, chunks: [Bytes; N]) -> Self { + Self { + chunks: chunks.into_iter().map(Ok).collect(), + stream_id: VarInt::from_u32(stream_id), + stopped: Arc::new(Mutex::new(Vec::new())), + } + } + + fn stopped_codes(&self) -> Vec { + self.stopped.lock().expect("stopped lock poisoned").clone() + } + } + + impl Stream for TestReadStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(self.chunks.pop_front()) + } + } + + impl quic::GetStreamId for TestReadStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl StopStream for TestReadStream { + fn poll_stop( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + self.stopped + .lock() + .expect("stopped lock poisoned") + .push(code); + Poll::Ready(Ok(())) + } + } + + #[derive(Debug, Default)] + struct TestWriteState { + written: Mutex>, + flushes: Mutex, + closes: Mutex, + resets: Mutex>, + } + + impl TestWriteState { + fn written(&self) -> Vec { + self.written.lock().expect("written lock poisoned").clone() + } + + fn flushes(&self) -> usize { + *self.flushes.lock().expect("flushes lock poisoned") + } + + fn closes(&self) -> usize { + *self.closes.lock().expect("closes lock poisoned") + } + } + + #[derive(Debug)] + struct TestWriteStream { + state: Arc, + stream_id: VarInt, + } + + impl TestWriteStream { + fn new(stream_id: u32, state: Arc) -> Self { + Self { + state, + stream_id: VarInt::from_u32(stream_id), + } + } + } + + impl Sink for TestWriteStream { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> { + self.state + .written + .lock() + .expect("written lock poisoned") + .extend_from_slice(&item); + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let mut flushes = self.state.flushes.lock().expect("flushes lock poisoned"); + *flushes += 1; + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let mut closes = self.state.closes.lock().expect("closes lock poisoned"); + *closes += 1; + Poll::Ready(Ok(())) + } + } + + impl quic::GetStreamId for TestWriteStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl ResetStream for TestWriteStream { + fn poll_reset( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + self.state + .resets + .lock() + .expect("resets lock poisoned") + .push(code); + Poll::Ready(Ok(())) + } + } + + #[derive(Debug)] + struct PendingStopReadStream { + stream_id: VarInt, + stop_codes: Vec, + ready: Arc, + waker: Option, + } + + impl PendingStopReadStream { + fn new(stream_id: u32) -> Self { + Self { + stream_id: VarInt::from_u32(stream_id), + stop_codes: Vec::new(), + ready: Arc::new(AtomicBool::new(false)), + waker: None, + } + } + + fn with_ready(stream_id: u32, ready: Arc) -> Self { + Self { + stream_id: VarInt::from_u32(stream_id), + stop_codes: Vec::new(), + ready, + waker: None, + } + } + + fn stop_codes(&self) -> Vec { + self.stop_codes.clone() + } + } + + impl Stream for PendingStopReadStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Pending + } + } + + impl quic::GetStreamId for PendingStopReadStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl StopStream for PendingStopReadStream { + fn poll_stop( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + self.stop_codes.push(code); + self.waker = Some(cx.waker().clone()); + if self.ready.load(Ordering::SeqCst) { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + } + + #[derive(Debug)] + struct PendingResetWriteStream { + stream_id: VarInt, + reset_codes: Vec, + ready: Arc, + waker: Option, + } + + impl PendingResetWriteStream { + fn new(stream_id: u32) -> Self { + Self { + stream_id: VarInt::from_u32(stream_id), + reset_codes: Vec::new(), + ready: Arc::new(AtomicBool::new(false)), + waker: None, + } + } + + fn with_ready(stream_id: u32, ready: Arc) -> Self { + Self { + stream_id: VarInt::from_u32(stream_id), + reset_codes: Vec::new(), + ready, + waker: None, + } + } + + fn reset_codes(&self) -> Vec { + self.reset_codes.clone() + } + } + + impl Sink for PendingResetWriteStream { + type Error = quic::StreamError; + + fn poll_ready( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Pending + } + + fn start_send(self: Pin<&mut Self>, _item: Bytes) -> Result<(), Self::Error> { + Ok(()) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Pending + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Pending + } + } + + impl quic::GetStreamId for PendingResetWriteStream { + fn poll_stream_id( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(self.stream_id)) + } + } + + impl ResetStream for PendingResetWriteStream { + fn poll_reset( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + code: VarInt, + ) -> Poll> { + self.reset_codes.push(code); + self.waker = Some(cx.waker().clone()); + if self.ready.load(Ordering::SeqCst) { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + } + + #[derive(Debug)] + struct CountWake { + wakes: Arc, + } + + impl Wake for CountWake { + fn wake(self: Arc) { + self.wakes.fetch_add(1, Ordering::SeqCst); + } + + fn wake_by_ref(self: &Arc) { + self.wakes.fetch_add(1, Ordering::SeqCst); + } + } + + fn count_waker() -> (Waker, Arc) { + let wakes = Arc::new(AtomicUsize::new(0)); + ( + Waker::from(Arc::new(CountWake { + wakes: Arc::clone(&wakes), + })), + wakes, + ) + } + + fn poll_once(f: impl FnOnce(&mut Context<'_>) -> Poll) -> Poll { + let mut cx = Context::from_waker(noop_waker_ref()); + f(&mut cx) + } + + #[tokio::test] + async fn ordinary_reader_io_delegates_to_inner_stream() { + let session = std::sync::Weak::new(); + let inner = TestReadStream::with_chunks(7, [Bytes::from_static(b"chunk")]); + let (mut reader, _tracked) = + WebTransportStreamReader::tracked(StreamId::from(VarInt::from_u32(7)), inner, session); + + assert_eq!( + reader.next().await.expect("chunk").expect("ok"), + Bytes::from_static(b"chunk") + ); + assert!(reader.next().await.is_none()); + } + + #[tokio::test] + async fn ordinary_writer_io_delegates_to_inner_sink() { + let session = std::sync::Weak::new(); + let state = Arc::new(TestWriteState::default()); + let inner = TestWriteStream::new(8, Arc::clone(&state)); + let (mut writer, _tracked) = + WebTransportStreamWriter::tracked(StreamId::from(VarInt::from_u32(8)), inner, session); + + writer + .feed(Bytes::from_static(b"payload")) + .await + .expect("feed"); + writer.flush().await.expect("flush"); + writer.close().await.expect("close"); + + assert_eq!(state.written(), b"payload".to_vec()); + assert_eq!(state.flushes(), 1); + assert_eq!(state.closes(), 1); + } + + #[tokio::test] + async fn reader_stop_does_not_imply_eos() { + let session = std::sync::Weak::new(); + let inner = TestReadStream::with_chunks(9, [Bytes::from_static(b"after-stop")]); + let (mut reader, _tracked) = + WebTransportStreamReader::tracked(StreamId::from(VarInt::from_u32(9)), inner, session); + + let mut cx = Context::from_waker(noop_waker_ref()); + assert!( + Pin::new(&mut reader) + .poll_stop(&mut cx, VarInt::from_u32(0x31)) + .is_ready() + ); + + { + let core = reader.core.lock().expect("reader core lock poisoned"); + assert_eq!(core.stream.stopped_codes(), vec![VarInt::from_u32(0x31)]); + } + assert_eq!( + reader.next().await.expect("chunk").expect("ok"), + Bytes::from_static(b"after-stop") + ); + assert!(reader.next().await.is_none()); + } + + #[test] + fn concurrent_reader_stop_uses_first_committed_code() { + let session = std::sync::Weak::new(); + let inner = PendingStopReadStream::new(10); + let (mut reader, mut tracked) = + WebTransportStreamReader::tracked(StreamId::from(VarInt::from_u32(10)), inner, session); + + let first = VarInt::from_u32(0x31); + let second = VarInt::from_u32(0x32); + + assert!(poll_once(|cx| Pin::new(&mut reader).poll_stop(cx, first)).is_pending()); + assert!(poll_once(|cx| Pin::new(&mut tracked).poll_stop(cx, second)).is_pending()); + + let core = reader.core.lock().expect("reader core lock poisoned"); + assert_eq!(core.stream.stop_codes(), vec![first, first]); + } + + #[test] + fn completed_reader_stop_is_shared_by_both_control_sides() { + let session = std::sync::Weak::new(); + let ready = Arc::new(AtomicBool::new(false)); + let inner = PendingStopReadStream::with_ready(14, Arc::clone(&ready)); + let (mut reader, mut tracked) = + WebTransportStreamReader::tracked(StreamId::from(VarInt::from_u32(14)), inner, session); + + let first = VarInt::from_u32(0x71); + let second = VarInt::from_u32(0x72); + + assert!(poll_once(|cx| Pin::new(&mut reader).poll_stop(cx, first)).is_pending()); + assert!(poll_once(|cx| Pin::new(&mut tracked).poll_stop(cx, second)).is_pending()); + ready.store(true, Ordering::SeqCst); + + assert!(poll_once(|cx| Pin::new(&mut reader).poll_stop(cx, first)).is_ready()); + assert!(poll_once(|cx| Pin::new(&mut tracked).poll_stop(cx, second)).is_ready()); + + let core = reader.core.lock().expect("reader core lock poisoned"); + assert_eq!(core.stream.stop_codes(), vec![first, first, first]); + } + + #[test] + fn concurrent_writer_reset_uses_first_committed_code() { + let session = std::sync::Weak::new(); + let inner = PendingResetWriteStream::new(11); + let (mut writer, mut tracked) = + WebTransportStreamWriter::tracked(StreamId::from(VarInt::from_u32(11)), inner, session); + + let first = VarInt::from_u32(0x41); + let second = VarInt::from_u32(0x42); + + assert!(poll_once(|cx| Pin::new(&mut writer).poll_reset(cx, first)).is_pending()); + assert!(poll_once(|cx| Pin::new(&mut tracked).poll_reset(cx, second)).is_pending()); + + let core = writer.core.lock().expect("writer core lock poisoned"); + assert_eq!(core.stream.reset_codes(), vec![first, first]); + } + + #[test] + fn completed_writer_reset_is_shared_by_both_control_sides() { + let session = std::sync::Weak::new(); + let ready = Arc::new(AtomicBool::new(false)); + let inner = PendingResetWriteStream::with_ready(15, Arc::clone(&ready)); + let (mut writer, mut tracked) = + WebTransportStreamWriter::tracked(StreamId::from(VarInt::from_u32(15)), inner, session); + + let first = VarInt::from_u32(0x81); + let second = VarInt::from_u32(0x82); + + assert!(poll_once(|cx| Pin::new(&mut writer).poll_reset(cx, first)).is_pending()); + assert!(poll_once(|cx| Pin::new(&mut tracked).poll_reset(cx, second)).is_pending()); + ready.store(true, Ordering::SeqCst); + + assert!(poll_once(|cx| Pin::new(&mut writer).poll_reset(cx, first)).is_ready()); + assert!(poll_once(|cx| Pin::new(&mut tracked).poll_reset(cx, second)).is_ready()); + + let core = writer.core.lock().expect("writer core lock poisoned"); + assert_eq!(core.stream.reset_codes(), vec![first, first, first]); + } + + #[test] + fn aggregate_wake_wakes_both_reader_control_sides() { + let session = std::sync::Weak::new(); + let inner = PendingStopReadStream::new(12); + let (mut reader, mut tracked) = + WebTransportStreamReader::tracked(StreamId::from(VarInt::from_u32(12)), inner, session); + let (handle_waker, handle_wakes) = count_waker(); + let (tracked_waker, tracked_wakes) = count_waker(); + let mut handle_cx = Context::from_waker(&handle_waker); + let mut tracked_cx = Context::from_waker(&tracked_waker); + + assert!( + Pin::new(&mut reader) + .poll_stop(&mut handle_cx, VarInt::from_u32(0x51)) + .is_pending() + ); + assert!( + Pin::new(&mut tracked) + .poll_stop(&mut tracked_cx, VarInt::from_u32(0x52)) + .is_pending() + ); + + let aggregate = { + let core = reader.core.lock().expect("reader core lock poisoned"); + core.stream.waker.as_ref().expect("aggregate waker").clone() + }; + aggregate.wake_by_ref(); + + assert_eq!(handle_wakes.load(Ordering::SeqCst), 1); + assert_eq!(tracked_wakes.load(Ordering::SeqCst), 1); + } + + #[test] + fn aggregate_wake_wakes_both_writer_control_sides() { + let session = std::sync::Weak::new(); + let inner = PendingResetWriteStream::new(13); + let (mut writer, mut tracked) = + WebTransportStreamWriter::tracked(StreamId::from(VarInt::from_u32(13)), inner, session); + let (handle_waker, handle_wakes) = count_waker(); + let (tracked_waker, tracked_wakes) = count_waker(); + let mut handle_cx = Context::from_waker(&handle_waker); + let mut tracked_cx = Context::from_waker(&tracked_waker); + + assert!( + Pin::new(&mut writer) + .poll_reset(&mut handle_cx, VarInt::from_u32(0x61)) + .is_pending() + ); + assert!( + Pin::new(&mut tracked) + .poll_reset(&mut tracked_cx, VarInt::from_u32(0x62)) + .is_pending() + ); + + let aggregate = { + let core = writer.core.lock().expect("writer core lock poisoned"); + core.stream.waker.as_ref().expect("aggregate waker").clone() + }; + aggregate.wake_by_ref(); + + assert_eq!(handle_wakes.load(Ordering::SeqCst), 1); + assert_eq!(tracked_wakes.load(Ordering::SeqCst), 1); + } +} diff --git a/src/webtransport/session_id.rs b/src/webtransport/session_id.rs new file mode 100644 index 0000000..92302ad --- /dev/null +++ b/src/webtransport/session_id.rs @@ -0,0 +1,91 @@ +use std::fmt; + +use snafu::Snafu; + +use crate::{ + error::{Code, H3ConnectionError}, + stream_id::StreamId, +}; + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct WebTransportSessionId(StreamId); + +impl WebTransportSessionId { + pub const fn stream_id(self) -> StreamId { + self.0 + } +} + +impl TryFrom for WebTransportSessionId { + type Error = InvalidSessionId; + + fn try_from(session_id: StreamId) -> Result { + if session_id.is_client_initiated_bidirectional() { + Ok(Self(session_id)) + } else { + Err(InvalidSessionId { session_id }) + } + } +} + +impl From for StreamId { + fn from(session_id: WebTransportSessionId) -> Self { + session_id.0 + } +} + +impl fmt::Display for WebTransportSessionId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Snafu, Clone, Copy, PartialEq, Eq)] +#[snafu(display( + "webtransport session id {session_id} is not a client-initiated bidirectional stream id" +))] +pub struct InvalidSessionId { + session_id: StreamId, +} + +impl InvalidSessionId { + pub const fn session_id(&self) -> StreamId { + self.session_id + } +} + +impl H3ConnectionError for InvalidSessionId { + fn code(&self) -> Code { + Code::H3_ID_ERROR + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + error::{Code, H3ConnectionError}, + stream_id::StreamId, + varint::VarInt, + }; + + #[test] + fn session_id_accepts_only_client_initiated_bidirectional_stream_id() { + let raw = StreamId::from(VarInt::from_u32(0)); + let session_id = WebTransportSessionId::try_from(raw).expect("valid session id"); + + assert_eq!(session_id.stream_id(), raw); + assert_eq!(StreamId::from(session_id), raw); + } + + #[test] + fn invalid_session_id_is_h3_id_error() { + let error = WebTransportSessionId::try_from(StreamId::from(VarInt::from_u32(3))) + .expect_err("server uni cannot be a session id"); + + assert_eq!(error.session_id(), StreamId::from(VarInt::from_u32(3))); + assert_eq!(H3ConnectionError::code(&error), Code::H3_ID_ERROR); + } +} diff --git a/src/webtransport/stream_count.rs b/src/webtransport/stream_count.rs new file mode 100644 index 0000000..dc09649 --- /dev/null +++ b/src/webtransport/stream_count.rs @@ -0,0 +1,140 @@ +use std::{convert::Infallible, io}; + +use snafu::{ResultExt, Snafu}; +use tokio::io::{AsyncRead, AsyncWrite}; + +use crate::{ + buflist::BufList, + codec::{DecodeExt, DecodeFrom, EncodeExt, EncodeInto}, + varint::VarInt, +}; + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct WebTransportStreamCount(VarInt); + +impl WebTransportStreamCount { + pub const ZERO: Self = Self(VarInt::from_u32(0)); + pub const MAX_VALUE: VarInt = match VarInt::from_u64(0x0fff_ffff_ffff_ffff) { + Ok(value) => value, + Err(_) => panic!("2^60 - 1 is a valid QUIC varint"), + }; + + pub const fn into_varint(self) -> VarInt { + self.0 + } + + pub fn checked_increment(self) -> Result { + let next = VarInt::from_u64(self.0.into_inner() + 1) + .expect("a valid webtransport stream count increment is a valid QUIC varint"); + Self::try_from(next) + } +} + +impl TryFrom for WebTransportStreamCount { + type Error = InvalidWebTransportStreamCount; + + fn try_from(value: VarInt) -> Result { + if value <= Self::MAX_VALUE { + Ok(Self(value)) + } else { + Err(InvalidWebTransportStreamCount { value }) + } + } +} + +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[derive(Debug, Snafu, Clone, Copy, PartialEq, Eq)] +#[snafu(display("webtransport stream count {value} exceeds 2^60 - 1"))] +pub struct InvalidWebTransportStreamCount { + value: VarInt, +} + +impl InvalidWebTransportStreamCount { + pub const fn value(&self) -> VarInt { + self.value + } +} + +#[derive(Debug, Snafu)] +#[snafu(module(decode_webtransport_stream_count_error), visibility(pub(super)))] +pub enum DecodeWebTransportStreamCountError { + #[snafu(display("failed to decode webtransport stream count"))] + Decode { source: io::Error }, + #[snafu(display("invalid webtransport stream count"))] + Invalid { + source: InvalidWebTransportStreamCount, + }, +} + +impl DecodeFrom for WebTransportStreamCount +where + S: AsyncRead + Unpin + Send, +{ + type Error = DecodeWebTransportStreamCountError; + + async fn decode_from(mut stream: S) -> Result { + let value = stream + .decode_one::() + .await + .context(decode_webtransport_stream_count_error::DecodeSnafu)?; + Self::try_from(value).context(decode_webtransport_stream_count_error::InvalidSnafu) + } +} + +impl<'s, S> EncodeInto<&'s mut S> for WebTransportStreamCount +where + S: AsyncWrite + Unpin + Send, +{ + type Output = (); + type Error = io::Error; + + async fn encode_into(self, stream: &'s mut S) -> Result { + self.into_varint().encode_into(stream).await + } +} + +impl EncodeInto for WebTransportStreamCount { + type Output = BufList; + type Error = Infallible; + + async fn encode_into(self, mut stream: BufList) -> Result { + stream + .encode_one(self) + .await + .expect("encoding a webtransport stream count into a BufList is infallible"); + Ok(stream) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn stream_count_accepts_webtransport_boundaries() { + let zero = WebTransportStreamCount::try_from(VarInt::from_u32(0)).expect("zero"); + assert_eq!(zero.into_varint(), VarInt::from_u32(0)); + + let max = + WebTransportStreamCount::try_from(WebTransportStreamCount::MAX_VALUE).expect("max"); + assert_eq!(max.into_varint(), WebTransportStreamCount::MAX_VALUE); + } + + #[test] + fn stream_count_rejects_above_webtransport_limit() { + let value = VarInt::from_u64(1 << 60).expect("valid varint"); + let error = WebTransportStreamCount::try_from(value).expect_err("above stream-count limit"); + + assert_eq!(error.value(), value); + } + + #[test] + fn checked_increment_preserves_varint_domain() { + let count = WebTransportStreamCount::try_from(VarInt::from_u32(7)).expect("count"); + assert_eq!( + count.checked_increment().expect("increment").into_varint(), + VarInt::from_u32(8) + ); + } +} diff --git a/tests/axum.rs b/tests/axum.rs deleted file mode 100644 index 8f8db9d..0000000 --- a/tests/axum.rs +++ /dev/null @@ -1,284 +0,0 @@ -mod common; -use std::pin::pin; - -use axum::{ - Router, - body::Body, - routing::{any, get}, -}; -use common::*; -use h3x::{ - hyper::{server::TowerService, upgrade}, - qpack::field::Protocol, - server, -}; -use http::{Request, StatusCode}; -use http_body_util::{BodyExt, combinators::UnsyncBoxBody}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio_util::{ - io::{CopyToBytes, SinkWriter, StreamReader}, - task::AbortOnDropHandle, -}; - -#[test] -fn axum_hello_world() { - run("axum_hello_world", async move { - let router = Router::new() - .route("/hello_world", get(|| async { "Hello, World!" })) - .into_service(); - let mut server = test_server(TowerService(router)).await; - let host = get_server_authority(&server); - let _serve = AbortOnDropHandle::new(tokio::spawn(async move { server.run().await })); - - let client = test_client(); - - let connection = client - .connect(host.clone()) - .await - .expect("failed to connect to server"); - - let response = connection - .execute_hyper_request( - Request::builder() - .method("GET") - .uri(format!("https://{host}/hello_world",)) - .body(Body::empty()) - .expect("failed to build request"), - ) - .await - .expect("failed to execute request") - .map(UnsyncBoxBody::new); - - let body = response - .collect() - .await - .expect("failed to read body") - .to_bytes(); - - assert_eq!(&body[..], b"Hello, World!"); - }) -} - -pub const INTERIM_RESPONSE_COUNT: usize = 3; - -async fn interim_response_service(_request: &mut server::Request, response: &mut server::Response) { - for _ in 0..INTERIM_RESPONSE_COUNT { - response - .set_status(http::StatusCode::CONTINUE) - .flush() - .await - .expect("failed to send interim response"); - } - response - .set_status(http::StatusCode::OK) - .set_body(b"37" as &[u8]) - .close() - .await - .expect("failed to send final response"); -} - -#[test] -fn interim_response() { - run("interim_response", async move { - let mut server = - test_server(server::Router::new().get("/ultimate_answer", interim_response_service)) - .await; - let host = get_server_authority(&server); - let _serve = AbortOnDropHandle::new(tokio::spawn(async move { server.run().await })); - - let client = test_client(); - - let connection = client - .connect(host.clone()) - .await - .expect("failed to connect to server"); - - let response = connection - .execute_hyper_request( - Request::get(format!("https://{host}/ultimate_answer")) - .body(Body::empty()) - .expect("failed to build request"), - ) - .await - .expect("failed to execute request") - .map(UnsyncBoxBody::new); - - assert_eq!(response.status(), http::StatusCode::OK); - let body = response - .collect() - .await - .expect("failed to read body") - .to_bytes(); - - assert_eq!(&body[..], b"37"); - }) -} - -const CONNECTED_REQUEST: &str = "GET / HTTP/1.1\r\nHost: example.org\r\nConnection: close\r\n\r\n"; -const CONNECTED_RESPONSE: &str = "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nHello"; - -#[axum::debug_handler] -async fn mock_connect_service(request: axum::extract::Request) -> Result<(), StatusCode> { - let (Some(host), Some(port)) = (request.uri().host(), request.uri().port_u16()) else { - return Err(StatusCode::BAD_REQUEST); - }; - let host = host.to_string(); - - if host != "example.org" || port != 80 { - return Err(StatusCode::FORBIDDEN); - } - - tokio::spawn(async move { - let (mut read_stream, mut write_stream) = upgrade::on(request) - .await - .expect("failed to establish tunnel"); - tracing::info!("tunnel established to {host}:{port}"); - - let mut buf = Vec::with_capacity(CONNECTED_REQUEST.len()); - read_stream - .read_to_end(&mut buf) - .await - .expect("failed to read from tunnel"); - tracing::info!(request = %String::from_utf8_lossy(&buf), "Tunnel received request"); - - assert_eq!(&buf[..], CONNECTED_REQUEST.as_bytes()); - write_stream - .write_all(CONNECTED_RESPONSE.as_bytes()) - .await - .expect("failed to write to tunnel"); - write_stream - .shutdown() - .await - .expect("failed to close write stream"); - }); - - Ok(()) -} - -#[test] -fn axum_connect() { - run("axum_connect", async move { - let router = Router::new() - .route("/", any(mock_connect_service)) - .into_service(); - let mut server = test_server(TowerService(router)).await; - let host = get_server_authority(&server); - let _serve = AbortOnDropHandle::new(tokio::spawn(async move { server.run().await })); - - let client = test_client(); - - let connection = client - .connect(host.clone()) - .await - .expect("failed to connect to server"); - - let response = connection - .execute_hyper_request( - // FIXME: correct way to build CONNECT request? - Request::connect("https://example.org:80") - .body(Body::empty()) - .expect("failed to build request"), - ) - .await - .expect("failed to execute request") - .map(UnsyncBoxBody::new); - - assert_eq!(response.status(), StatusCode::OK); - - let (mut read_stream, mut write_stream) = upgrade::on(response) - .await - .expect("failed to upgrade to tunnel"); - - write_stream - .write_all(CONNECTED_REQUEST.as_bytes()) - .await - .expect("failed to write to tunnel"); - write_stream - .shutdown() - .await - .expect("failed to close write stream"); - tracing::info!("Sent connected request"); - - let mut buf = Vec::new(); - read_stream - .read_to_end(&mut buf) - .await - .expect("failed to read from tunnel"); - - assert_eq!(&buf[..], CONNECTED_RESPONSE.as_bytes()); - }); -} - -async fn extend_connect_service(request: axum::extract::Request) -> Result<(), StatusCode> { - let protocol = request - .extensions() - .get::() - .map(|p| p.as_str()) - .unwrap_or(""); - if protocol != "h3x-test" { - return Err(StatusCode::BAD_REQUEST); - } - mock_connect_service(request).await -} - -#[test] -fn axum_extend_connect() { - run("axum_extend_connect", async move { - let router = Router::new() - .route("/connect", any(extend_connect_service)) - .into_service(); - let mut server = test_server(TowerService(router)).await; - let host = get_server_authority(&server); - let _serve = AbortOnDropHandle::new(tokio::spawn(async move { server.run().await })); - - let client = test_client(); - - let connection = client - .connect(host.clone()) - .await - .expect("failed to connect to server"); - let (mut read_stream, mut write_stream) = connection - .initial_message_stream() - .await - .expect("failed to open request stream"); - - write_stream - .send_hyper_request( - Request::connect("https://example.org:80/connect") - .extension(hyper::ext::Protocol::from_static("h3x-test")) // use extension API to set :protocol header - .body(Body::empty()) - .expect("failed to build request"), - ) - .await - .expect("failed to send request"); - - let response = read_stream - .read_hyper_response_parts() - .await - .expect("failed to take response"); - assert_eq!(response.status, StatusCode::OK); - - let read_stream = pin!(read_stream.into_bytes_stream()); - let mut read_stream = StreamReader::new(read_stream); - let write_stream = pin!(write_stream.into_bytes_sink()); - let mut write_stream = SinkWriter::new(CopyToBytes::new(write_stream)); - - write_stream - .write_all(CONNECTED_REQUEST.as_bytes()) - .await - .expect("failed to write to tunnel"); - write_stream - .shutdown() - .await - .expect("failed to close write stream"); - tracing::info!("Sent connected request"); - - let mut buf = Vec::new(); - read_stream - .read_to_end(&mut buf) - .await - .expect("failed to read from tunnel"); - - assert_eq!(&buf[..], CONNECTED_RESPONSE.as_bytes()); - }); -} diff --git a/tests/common/mod.rs b/tests/common/mod.rs index e4a625b..32cebea 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -1,21 +1,15 @@ #![allow(unused)] use std::{ - error::Error, sync::{Arc, LazyLock}, time::Duration, }; -use dquic::{ - prelude::{ - BindUri, BoundAddr, IO, - handy::{ToCertificate, ToPrivateKey}, - }, - qinterface::component::route::QuicRouter, -}; -use h3x::{ - dquic::{H3Client, H3Servers}, - server::UnresolvedRequest, +use h3x::dquic::{ + Name, + cert::handy::{ToCertificate, ToPrivateKey}, + net::IO, + resolver::handy::SystemResolver, }; use http::uri::Authority; use tokio::time; @@ -32,7 +26,7 @@ pub fn run(test_name: &'static str, future: F) -> F::Output { tokio::runtime::Builder::new_multi_thread() .enable_all() .build() - .unwrap() + .expect("failed to build tokio runtime") }); static TRACING: LazyLock = LazyLock::new(|| { @@ -69,68 +63,39 @@ pub const SERVER_CERT: &[u8] = include_bytes!("../../tests/keychain/localhost/se pub const SERVER_KEY: &[u8] = include_bytes!("../../tests/keychain/localhost/server.key"); pub const TEST_DATA: &[u8] = include_bytes!("mod.rs"); -pub fn test_client() -> H3Client { - let mut roots = rustls::RootCertStore::empty(); - roots.add_parsable_certificates(CA_CERT.to_certificate()); - H3Client::builder() - .with_root_certificates(roots) - .without_identity() - .expect("failed to initialize client tls") - .with_router(Arc::new(QuicRouter::new())) - .build() +pub async fn test_client() -> h3x::dquic::H3Endpoint { + let quic = h3x::dquic::QuicEndpoint::builder().build().await; + h3x::endpoint::H3Endpoint::new(quic) } -pub async fn test_server(router: S) -> H3Servers -where - S: tower_service::Service + Clone + Send + Sync + 'static, - S::Future: Send, - S::Error: Into>, -{ - let mut servers = H3Servers::builder() - .without_client_cert_verifier() - .expect("failed to initialize server tls") - .with_router(Arc::new(QuicRouter::new())) - .listen() - .expect("failed to listen"); - servers - .add_server( - "localhost", - SERVER_CERT.to_certificate(), - SERVER_KEY.to_private_key(), - None, - [ - BindUri::from("inet://127.0.0.1:0").alloc_port(), - BindUri::from("inet://[::1]:0").alloc_port(), - ], - router, - ) - .await - .expect("failed to add server"); - servers -} - -pub fn get_server_addr(servers: &H3Servers) -> BoundAddr { - let localhost = servers - .quic_listener() - .get_server("localhost") - .expect("server localhost must be registered"); - let (_bind_uri, localhost_bind_interface) = localhost - .bind_interfaces() +pub async fn test_server() -> (h3x::dquic::H3Endpoint, Authority) { + let identity = Arc::new(h3x::dquic::Identity { + name: "localhost".parse().unwrap(), + certs: Arc::new(SERVER_CERT.to_certificate()), + key: Arc::new(SERVER_KEY.to_private_key()), + ocsp: Arc::new(None), + }); + let network = h3x::dquic::Network::builder().build(); + let quic = h3x::dquic::QuicEndpoint::builder() + .network(network.clone()) + .identity(identity) + .bind(Arc::new(vec![ + "127.0.0.1:0".parse().expect("valid pattern"), + ])) + .build() + .await; + let bind_iface = network + .quic() + .interfaces() .into_iter() .next() - .expect("server localhost must have at least one bind interface"); - localhost_bind_interface + .expect("no bound interface"); + let port = bind_iface .borrow() .bound_addr() - .expect("bind interface must have local addr") -} - -pub fn get_server_authority(servers: &H3Servers) -> Authority { - match get_server_addr(servers) { - BoundAddr::Internet(socket_addr) => { - Authority::from_maybe_shared(Vec::from(format!("localhost:{}", socket_addr.port()))) - .expect("failed to parse authority") - } - _ => unimplemented!("Only Internet addresses are supported now"), - } + .expect("no bound addr") + .port(); + let authority = Authority::from_maybe_shared(format!("localhost:{port}")) + .expect("failed to parse authority"); + (h3x::endpoint::H3Endpoint::new(quic), authority) } diff --git a/tests/endpoint.rs b/tests/endpoint.rs deleted file mode 100644 index 7e59bdb..0000000 --- a/tests/endpoint.rs +++ /dev/null @@ -1,453 +0,0 @@ -//! Integration tests exercising the new `h3x::endpoint` API. -//! -//! Each test pairs a server [`H3Endpoint`] + `serve(Router)` with a client -//! built on top of [`h3x::client::Client`] using [`QuicEndpoint`] as the -//! [`quic::Connect`](h3x::quic::Connect) implementation. - -#![cfg(feature = "endpoint")] - -mod common; - -use std::sync::Arc; - -use common::{CA_CERT, SERVER_CERT, SERVER_KEY, run}; -use dquic::{ - prelude::{ - BindUri, BoundAddr, IO, - handy::{ToCertificate, ToPrivateKey}, - }, - qbase::net::addr::{EndpointAddr, SocketEndpointAddr}, - qinterface::{component::route::QuicRouter, manager::InterfaceManager}, - qresolve::{Resolve, ResolveFuture, Source, SystemResolver}, -}; -use h3x::{ - client::Client, - connection::ConnectionBuilder, - endpoint::{ - ClientOnlyConfig, ClientQuicConfig, H3Endpoint, Identity, NamedIdentity, Network, - QuicEndpoint, ServerCertVerifierChoice, ServerQuicConfig, - }, - pool::Pool, - server::{self, Router}, -}; -use http::uri::Authority; -use rustls::{ - RootCertStore, - client::WebPkiServerVerifier, - pki_types::{CertificateDer, PrivateKeyDer}, -}; -use tokio_util::task::AbortOnDropHandle; - -async fn hello_service(_: &mut server::Request, response: &mut server::Response) { - response - .set_status(http::StatusCode::OK) - .set_body(&b"hello from endpoint"[..]); -} - -fn named_server_identity() -> Identity { - let certs: Vec> = SERVER_CERT.to_certificate(); - let key: PrivateKeyDer<'static> = SERVER_KEY.to_private_key(); - Identity::Named(Arc::new(NamedIdentity { - name: Arc::from("localhost"), - certs, - key: Arc::new(key), - })) -} - -fn client_webpki_verifier() -> Arc { - let mut roots = RootCertStore::empty(); - roots.add_parsable_certificates(CA_CERT.to_certificate()); - WebPkiServerVerifier::builder(Arc::new(roots)) - .build() - .expect("failed to build webpki verifier") -} - -fn test_network() -> Arc { - // Each test uses its own `QuicRouter` AND `InterfaceManager` so that - // parallel tests do not share the global dispatcher or cross-route - // packets destined for other tests. - Network::builder() - .quic_router(Arc::new(QuicRouter::new())) - .iface_manager(Arc::new(InterfaceManager::new())) - .build() -} - -#[test] -fn serve_and_connect_hello() { - run("serve_and_connect_hello", async move { - let network = test_network(); - - // --- Server --- - let server_quic = QuicEndpoint::new( - network.clone(), - named_server_identity(), - Arc::new(SystemResolver), - ClientQuicConfig::default(), - ServerQuicConfig::default(), - ); - let bind_iface = network.bind(BindUri::from("inet://127.0.0.1:0")).await; - let bound_addr = bind_iface - .borrow() - .bound_addr() - .expect("bind interface must have a local address"); - let port = match bound_addr { - BoundAddr::Internet(socket_addr) => socket_addr.port(), - _ => unreachable!("bound to inet://127.0.0.1"), - }; - - let server_endpoint = H3Endpoint::new( - server_quic, - Pool::empty(), - Arc::new(ConnectionBuilder::new(Arc::default())), - ); - let router = Router::new().get("/hello", hello_service); - let _serve = - AbortOnDropHandle::new(tokio::spawn( - async move { server_endpoint.serve(router).await }, - )); - - // --- Client --- - let client_own = ClientOnlyConfig { - verifier: ServerCertVerifierChoice::WebPki(client_webpki_verifier()), - ..Default::default() - }; - let client_quic_config = ClientQuicConfig { - common: Arc::default(), - own: Arc::new(client_own), - }; - let client_quic = QuicEndpoint::new( - network.clone(), - Identity::Anonymous, - Arc::new(SystemResolver), - client_quic_config, - ServerQuicConfig::default(), - ); - let client = Client::from_quic_client().client(client_quic).build(); - - let authority: Authority = format!("localhost:{port}") - .parse() - .expect("valid authority"); - let uri: http::Uri = format!("https://{authority}/hello").parse().unwrap(); - let (_request, mut response) = client - .new_request() - .with_authority(authority) - .get(uri) - .await - .expect("failed to send request"); - - assert_eq!(response.status(), http::StatusCode::OK); - let body = response - .read_to_string() - .await - .expect("failed to read response body"); - assert_eq!(body, "hello from endpoint"); - }); -} - -// --------------------------------------------------------------------------- -// bind_server semantics -// --------------------------------------------------------------------------- - -fn named_with(name: &str) -> Arc { - let certs: Vec> = SERVER_CERT.to_certificate(); - let key: PrivateKeyDer<'static> = SERVER_KEY.to_private_key(); - Arc::new(NamedIdentity { - name: Arc::from(name), - certs, - key: Arc::new(key), - }) -} - -#[test] -fn bind_server_sni_in_use() { - run("bind_server_sni_in_use", async move { - let network = test_network(); - let a = named_with("localhost"); - let b = named_with("localhost"); - let _first = network - .bind_server(a, ServerQuicConfig::default()) - .await - .expect("first bind succeeds"); - let err = network - .bind_server(b, ServerQuicConfig::default()) - .await - .expect_err("second bind with different identity must fail"); - assert!( - matches!(err, h3x::endpoint::BindServerError::SniInUse { .. }), - "unexpected error: {err:?}" - ); - }); -} - -#[test] -fn bind_server_reuses_identity() { - run("bind_server_reuses_identity", async move { - let network = test_network(); - let id = named_with("localhost"); - let first = network - .bind_server(id.clone(), ServerQuicConfig::default()) - .await - .expect("first bind succeeds"); - let second = network - .bind_server(id.clone(), ServerQuicConfig::default()) - .await - .expect("same identity must reuse binding"); - assert_eq!(first.name, second.name); - }); -} - -#[test] -fn bind_server_config_conflict() { - run("bind_server_config_conflict", async move { - let network = test_network(); - let a = named_with("alpha"); - let b = named_with("beta"); - - let cfg_a = ServerQuicConfig::default(); - let cfg_b = { - let own = h3x::endpoint::ServerOnlyConfig { - alpns: vec![b"altproto".to_vec()], - ..Default::default() - }; - ServerQuicConfig { - common: Arc::default(), - own: Arc::new(own), - } - }; - - let _held = network - .bind_server(a, cfg_a) - .await - .expect("first bind succeeds"); - let err = network - .bind_server(b, cfg_b) - .await - .expect_err("incompatible server config must fail"); - assert!( - matches!(err, h3x::endpoint::BindServerError::ServerConfigConflict), - "unexpected error: {err:?}" - ); - }); -} - -#[test] -fn bind_server_slot_auto_reset() { - run("bind_server_slot_auto_reset", async move { - let network = test_network(); - - let cfg_a = ServerQuicConfig::default(); - let cfg_b = { - let own = h3x::endpoint::ServerOnlyConfig { - alpns: vec![b"altproto".to_vec()], - ..Default::default() - }; - ServerQuicConfig { - common: Arc::default(), - own: Arc::new(own), - } - }; - - { - let _first = network - .bind_server(named_with("alpha"), cfg_a) - .await - .expect("first bind succeeds"); - } - // After the binding drops the slot should clear, allowing a new - // incompatible config to install. - let _second = network - .bind_server(named_with("beta"), cfg_b) - .await - .expect("slot should auto-reset after last binding dropped"); - }); -} - -// --------------------------------------------------------------------------- -// Multi-server end-to-end tests -// --------------------------------------------------------------------------- - -/// Test-only resolver that maps every name lookup to a fixed loopback -/// endpoint. Lets tests dial arbitrary SNIs (e.g. `alpha`, `beta`) without -/// requiring `/etc/hosts` entries. -#[derive(Debug)] -struct FixedResolver(std::net::SocketAddr); - -impl std::fmt::Display for FixedResolver { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "FixedResolver({})", self.0) - } -} - -impl Resolve for FixedResolver { - fn lookup<'l>(&'l self, _name: &'l str) -> ResolveFuture<'l> { - use futures::{FutureExt, StreamExt, stream}; - let ep = EndpointAddr::Socket(SocketEndpointAddr::direct(self.0)); - let source = Source::System; - async move { Ok(stream::iter(std::iter::once((source, ep))).boxed()) }.boxed() - } -} - -async fn alpha_service(_: &mut server::Request, response: &mut server::Response) { - response - .set_status(http::StatusCode::OK) - .set_body(&b"alpha"[..]); -} - -async fn beta_service(_: &mut server::Request, response: &mut server::Response) { - response - .set_status(http::StatusCode::OK) - .set_body(&b"beta"[..]); -} - -/// Two named servers (`alpha`, `beta`) share one [`Network`] and one listen -/// port. Both register distinct SNIs via the Network's SNI registry. A -/// client dialing `alpha:PORT` must receive the `alpha` handler's response, -/// and `beta:PORT` must receive `beta`'s. This exercises: -/// -/// * shared [`Network`] infrastructure (one bind port, one QuicRouter, one -/// connectionless dispatcher); -/// * the Network's per-SNI mpmc fan-out (`sni_registry` + `bind_server`); -/// * rustls [`SniCertResolver`] picking the correct [`CertifiedKey`]; -/// * [`QuicEndpoint::accept`] returning only connections destined for its -/// own SNI. -#[test] -fn two_sni_share_network_and_port() { - run("two_sni_share_network_and_port", async move { - let network = test_network(); - let bind_iface = network.bind(BindUri::from("inet://127.0.0.1:0")).await; - let port = match bind_iface.borrow().bound_addr().unwrap() { - BoundAddr::Internet(s) => s.port(), - _ => unreachable!(), - }; - let resolver: Arc = - Arc::new(FixedResolver(format!("127.0.0.1:{port}").parse().unwrap())); - - let mut serve_handles = Vec::new(); - let mut bindings = Vec::new(); // keep SNI registrations alive before spawn - // Share one config Arc across both servers; otherwise - // `ServerQuicConfig::default()` yields fresh trait objects each - // call and `server_config_compatible` rejects the second bind. - let shared_server_config = ServerQuicConfig::default(); - let alpha_router = Router::new().get("/hello", alpha_service); - let beta_router = Router::new().get("/hello", beta_service); - for (name, router) in [("alpha", alpha_router), ("beta", beta_router)] { - let named = named_with(name); - // Eagerly register so rustls' SniCertResolver sees both SNIs - // before the first ClientHello arrives (avoids a startup race - // where a client connects before the server task has polled - // its first `accept()`). - let binding = network - .bind_server(named.clone(), shared_server_config.clone()) - .await - .expect("eager bind_server"); - bindings.push(binding); - let quic = QuicEndpoint::new( - network.clone(), - Identity::Named(named), - resolver.clone(), - ClientQuicConfig::default(), - shared_server_config.clone(), - ); - let h3 = H3Endpoint::new( - quic, - Pool::empty(), - Arc::new(ConnectionBuilder::new(Arc::default())), - ); - serve_handles.push(AbortOnDropHandle::new(tokio::spawn(async move { - h3.serve(router).await - }))); - } - - // Client with dangerous verifier (cert was issued for `localhost` so - // webpki would reject `alpha`/`beta` even though they share material). - let client_own = ClientOnlyConfig { - verifier: ServerCertVerifierChoice::Dangerous, - ..Default::default() - }; - let client_quic_config = ClientQuicConfig { - common: Arc::default(), - own: Arc::new(client_own), - }; - let client_quic = QuicEndpoint::new( - network.clone(), - Identity::Anonymous, - resolver.clone(), - client_quic_config, - ServerQuicConfig::default(), - ); - let client = Client::from_quic_client().client(client_quic).build(); - - for (sni, expected) in [("alpha", "alpha"), ("beta", "beta")] { - let authority: Authority = format!("{sni}:{port}").parse().unwrap(); - let uri: http::Uri = format!("https://{authority}/hello").parse().unwrap(); - let (_req, mut resp) = client - .new_request() - .with_authority(authority) - .get(uri) - .await - .unwrap_or_else(|e| panic!("request for {sni} failed: {e:?}")); - assert_eq!(resp.status(), http::StatusCode::OK); - let body = resp.read_to_string().await.expect("read body"); - assert_eq!(body, expected, "sni {sni} got wrong response"); - } - - drop(serve_handles); - drop(bindings); - }); -} - -// --------------------------------------------------------------------------- -// Introspection accessors (Phase 0) -// --------------------------------------------------------------------------- - -#[test] -fn get_iface_returns_bound_interface() { - run("get_iface_returns_bound_interface", async move { - let network = test_network(); - let uri = BindUri::from("inet://127.0.0.1:0"); - let bound = network.bind(uri.clone()).await; - let expected_addr = bound.borrow().bound_addr().expect("bound addr"); - - let fetched = network - .get_iface(&uri) - .expect("iface must be retrievable via get_iface"); - let fetched_addr = fetched.borrow().bound_addr().expect("bound addr"); - assert_eq!(expected_addr, fetched_addr); - - // Unknown URI yields None. - let unknown = BindUri::from("inet://127.0.0.1:1"); - assert!(network.get_iface(&unknown).is_none()); - }); -} - -#[test] -fn registered_sni_names_tracks_live_bindings() { - run("registered_sni_names_tracks_live_bindings", async move { - let network = test_network(); - assert!(network.registered_sni_names().is_empty()); - - let shared_config = ServerQuicConfig::default(); - let alpha = network - .bind_server(named_with("alpha"), shared_config.clone()) - .await - .expect("bind alpha"); - let beta = network - .bind_server(named_with("beta"), shared_config.clone()) - .await - .expect("bind beta"); - - let mut names = network.registered_sni_names(); - names.sort_by(|a, b| a.as_ref().cmp(b.as_ref())); - assert_eq!(names.len(), 2); - assert_eq!(names[0].as_ref(), "alpha"); - assert_eq!(names[1].as_ref(), "beta"); - - drop(alpha); - let remaining = network.registered_sni_names(); - assert_eq!(remaining.len(), 1); - assert_eq!(remaining[0].as_ref(), "beta"); - - drop(beta); - assert!(network.registered_sni_names().is_empty()); - }); -} diff --git a/tests/module_paths.rs b/tests/module_paths.rs new file mode 100644 index 0000000..79bf301 --- /dev/null +++ b/tests/module_paths.rs @@ -0,0 +1,36 @@ +#![cfg(feature = "hyper")] + +use bytes::Bytes; +use http_body_util::Empty; + +#[test] +fn owner_local_hyper_modules_are_public() { + let request = http::Request::builder() + .method(http::Method::GET) + .uri("https://example.test/") + .body(()) + .expect("request should build"); + let fields = h3x::qpack::field::hyper::validated_hyper_request_parts_to_field_lines( + request.into_parts().0, + ) + .expect("request pseudo headers should be valid"); + assert!(!fields.is_empty()); + + let _message_takeover = std::any::TypeId::of::< + h3x::dhttp::message::hyper::upgrade::TakeoverSlot, + >(); + let _endpoint_service = h3x::endpoint::hyper::TowerService(()); + let _endpoint_hyper_service = h3x::endpoint::hyper::HyperService(()); +} + +#[test] +fn top_level_hyper_is_facade_without_client_or_server_modules() { + let _request_error = std::any::TypeId::of::>(); + let _send_error = std::any::TypeId::of::>(); + let _handle_error = + std::any::TypeId::of::>(); + let _tower = h3x::hyper::TowerService(()); + let _hyper = h3x::hyper::HyperService(()); + let _upgrade_error = h3x::hyper::upgrade::MissingStream::Both; + let _body = Empty::::new(); +} diff --git a/tests/peer_settings.rs b/tests/peer_settings.rs new file mode 100644 index 0000000..a8942b3 --- /dev/null +++ b/tests/peer_settings.rs @@ -0,0 +1,34 @@ +mod common; + +use std::time::Duration; + +use tokio::time::timeout; + +#[test] +fn peer_settings_resolve_after_h3_connection_setup_without_request_streams() { + common::run( + "peer_settings_resolve_after_h3_connection_setup_without_request_streams", + async { + let mut server = common::test_server().await; + let client = common::test_client().await; + let authority = server.1.clone(); + + let (server_connection, client_connection) = timeout(Duration::from_secs(10), async { + tokio::join!(server.0.accept(), client.connect(authority)) + }) + .await + .expect("connections should be established"); + let server_connection = server_connection.expect("server h3 connection"); + let client_connection = client_connection.expect("client h3 connection"); + + timeout(Duration::from_secs(5), client_connection.peer_settings()) + .await + .expect("client peer settings should arrive") + .expect("client peer settings should be ok"); + timeout(Duration::from_secs(5), server_connection.peer_settings()) + .await + .expect("server peer settings should arrive") + .expect("server peer settings should be ok"); + }, + ); +} diff --git a/tests/simple.rs b/tests/simple.rs deleted file mode 100644 index a1b7093..0000000 --- a/tests/simple.rs +++ /dev/null @@ -1,223 +0,0 @@ -mod common; -use std::sync::Arc; - -use common::*; -use dquic::prelude::handy::{ToCertificate, ToPrivateKey}; -use h3x::{ - dquic::H3Servers, - error::Code, - quic, - server::{self, Router}, - varint::VarInt, -}; -use tokio_util::task::AbortOnDropHandle; - -async fn hello_world_service(_: &mut server::Request, response: &mut server::Response) { - response - .set_status(http::StatusCode::OK) - .set_body(&b"Hello, World!"[..]); -} - -#[test] -fn hello_world() { - run("hello_world", async move { - let mut server = test_server(Router::new().get("/hello_world", hello_world_service)).await; - let host = get_server_authority(&server); - let _serve = AbortOnDropHandle::new(tokio::spawn(async move { server.run().await })); - - let client = test_client(); - let (_, mut response) = client - .new_request() - .get(format!("https://{host}/hello_world").parse().unwrap()) - .await - .expect("failed to send request"); - - assert_eq!(response.status(), http::StatusCode::OK); - let response = response - .read_to_string() - .await - .expect("failed to read response body"); - assert_eq!(response, "Hello, World!"); - }) -} - -async fn streaming_echo_service(request: &mut server::Request, response: &mut server::Response) { - response.set_status(http::StatusCode::OK); - response.flush().await.expect("failed to flush response"); - - while let Some(chunk) = request - .read() - .await - .transpose() - .expect("failed to read request body") - { - response - .write(chunk) - .await - .expect("failed to write response body"); - } -} - -#[test] -fn streaming_echo() { - run("streaming_echo", async move { - let mut server = test_server(Router::new().post("/echo", streaming_echo_service)).await; - let host = get_server_authority(&server); - let _serve = AbortOnDropHandle::new(tokio::spawn(async move { server.run().await })); - - let client = test_client(); - let (mut request, mut response) = client - .new_request() - .post(format!("https://{host}/echo").parse().unwrap()) - .await - .expect("failed to send request"); - assert_eq!(response.status(), http::StatusCode::OK); - - request - .write(TEST_DATA) - .await - .expect("failed to write body") - .close() - .await - .expect("failed to close stream"); - let response = response - .read_to_bytes() - .await - .expect("failed to read response body"); - assert_eq!(response, TEST_DATA); - }) -} - -#[test] -fn fallback() { - run("fallback", async move { - let mut server = test_server( - Router::new() - .get("/hello_world", hello_world_service) - .post("/hello_world", hello_world_service), - ) - .await; - let host = get_server_authority(&server); - let _serve = AbortOnDropHandle::new(tokio::spawn(async move { server.run().await })); - - let client = test_client(); - let (_, response) = client - .new_request() - .get(format!("https://{host}/non_exist").parse().unwrap()) - .await - .expect("failed to send request"); - - assert_eq!(response.status(), http::StatusCode::NOT_FOUND); - }) -} - -async fn echo_service(request: &mut server::Request, response: &mut server::Response) { - let body = request - .read_to_bytes() - .await - .expect("failed to read request body"); - response.set_status(http::StatusCode::OK).set_body(body); -} - -#[test] -fn auto_close() { - run("auto_close", async move { - let mut server = test_server(Router::new().post("/echo", echo_service)).await; - let host = get_server_authority(&server); - let _serve = AbortOnDropHandle::new(tokio::spawn(async move { server.run().await })); - - let client = test_client(); - let (_, mut response) = client - .new_request() - .with_body(TEST_DATA) - .post(format!("https://{host}/echo").parse().unwrap()) - .await - .expect("failed to send request"); - - assert_eq!(response.status(), http::StatusCode::OK); - let response = response - .read_to_bytes() - .await - .expect("failed to read response body"); - assert_eq!(response, TEST_DATA); - }) -} - -#[test] -fn missing_server_name_closes_connection_with_no_error() { - run( - "missing_server_name_closes_connection_with_no_error", - async move { - let mut servers: H3Servers = H3Servers::builder() - .without_client_cert_verifier() - .expect("failed to initialize server tls") - .with_router(Arc::new( - dquic::qinterface::component::route::QuicRouter::new(), - )) - .listen() - .expect("failed to listen"); - servers - .quic_listener() - .add_server( - "localhost", - SERVER_CERT.to_certificate(), - SERVER_KEY.to_private_key(), - [ - dquic::prelude::BindUri::from("inet://127.0.0.1:0").alloc_port(), - dquic::prelude::BindUri::from("inet://[::1]:0").alloc_port(), - ], - None, - ) - .await - .expect("failed to add server"); - - let host = get_server_authority(&servers); - let _serve = AbortOnDropHandle::new(tokio::spawn(async move { servers.run().await })); - - let client = test_client(); - let error = client - .new_request() - .get( - format!("https://{host}/hello_world") - .parse() - .expect("valid uri"), - ) - .await - .err() - .expect("request should fail with no matching server name"); - - match error { - h3x::client::RequestError::ResponseStream { - source: - quic::StreamError::Connection { - source: - quic::ConnectionError::Application { - source: quic::ApplicationError { code, .. }, - }, - }, - } => assert_eq!(code, Code::H3_NO_ERROR), - h3x::client::RequestError::ResponseStream { - source: - quic::StreamError::Connection { - source: - quic::ConnectionError::Transport { - source: quic::TransportError { kind, .. }, - }, - }, - } => { - assert_eq!(kind, VarInt::from_u32(0x0c)); - } - // Stream reset with H3_NO_ERROR: race between GuardedQuicWriter drop - // (RESET_STREAM) and connection close (CONNECTION_CLOSE). Both are valid. - h3x::client::RequestError::ResponseStream { - source: quic::StreamError::Reset { code }, - } => { - assert_eq!(code, Code::H3_NO_ERROR.into_inner()); - } - other => panic!( - "expected response stream close from missing-server-name connection close, got: {other:?}" - ), - } - }, - ) -}