diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..af284e1 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,6 @@ +[env] +DHTTP_H3_DNS_SERVER = "https://dns.genmeta.net:4433/" +DHTTP_HTTP_DNS_SERVER = "https://dns.genmeta.net/" +DHTTP_MDNS_SERVICE = "_dhttp.local" +DHTTP_STUN_SERVER = "stun.genmeta.net:20002" +DHTTP_ROOT_CA = { value = "intermediate/intermediate.crt", relative = true } diff --git a/.gitignore b/.gitignore index 57506c5..855760d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,8 @@ Cargo.lock *.log build - .DS_Store .vscode/ +/geoip +/docs/superpowers +/certs diff --git a/Cargo.toml b/Cargo.toml index b84ddea..abb3316 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "dyns" description = "DNS discovery and resolver support for DHTTP applications" -version = "0.3.0" +version = "0.4.0" edition = "2024" license = "Apache-2.0" repository = "https://github.com/genmeta/ddns" @@ -9,6 +9,7 @@ readme = "README.md" keywords = ["dhttp", "dns", "mdns", "http3", "quic"] categories = ["network-programming", "asynchronous"] autoexamples = false +autobins = false [lib] name = "ddns" @@ -17,12 +18,12 @@ name = "ddns" base64 = "0.22" bitfield-struct = "0.13" bytes = "1" -dashmap = "6" -dhttp-identity = "0.1.0" +dashmap = { version = "6", optional = true } +dhttp-identity = "0.2.0" dquic = "0.5.1" -flume = "0.12" +flume = { version = "0.12", optional = true } futures = "0.3" -libc = "0.2" +libc = { version = "0.2", optional = true } nom = "8" rand = "0.10" ring = "0.17" @@ -33,7 +34,7 @@ rustls = { version = "0.23", default-features = false, features = [ rustls-native-certs = { version = "0.8", optional = true } rustls-pemfile = "2" snafu = "0.9" -socket2 = { version = "0.6", features = ["all"] } +socket2 = { version = "0.6", features = ["all"], optional = true } tokio = { version = "1", features = [ "time", "macros", @@ -44,9 +45,9 @@ tokio = { version = "1", features = [ "io-util", ] } tracing = "0.1" -x509-parser = "0.18" +x509-parser = { version = "0.18", features = ["verify"] } -h3x = { version = "0.3.1", default-features = false, optional = true } +h3x = { version = "0.4.0", default-features = false, optional = true } http = { version = "1", optional = true } http-body = { version = "1", optional = true } http-body-util = { version = "0.1", optional = true } @@ -60,67 +61,45 @@ reqwest = { version = "0.13", default-features = false, features = [ ], optional = true } url = { version = "2", optional = true } -clap = { version = "4", features = ["derive"], optional = true } -deadpool-redis = { version = "0.23", optional = true } -idna = { version = "1", optional = true } -serde = { version = "1", features = ["derive"], optional = true } -toml = { version = "1", optional = true } -tower-service = { version = "0.3", optional = true } -tracing-subscriber = { version = "0.3", features = [ - "env-filter", -], optional = true } - [features] default = [] -h3x-resolver = [ +resolvers = [] +publishers = [] +dquic-network = ["dep:h3x", "h3x/dquic"] +h3 = [ + "dep:dashmap", "dep:h3x", - "h3x/dquic", "h3x/hyper", "dep:http", "dep:http-body", "dep:http-body-util", "dep:url", ] -mdns-resolver = ["dep:h3x", "h3x/dquic"] -http-resolver = ["dep:reqwest", "dep:rustls-native-certs"] -server = [ - "h3x-resolver", - "dep:clap", - "dep:deadpool-redis", - "dep:idna", - "dep:serde", - "dep:toml", - "dep:tower-service", - "dep:tracing-subscriber", -] +http = ["dep:dashmap", "dep:reqwest", "dep:rustls-native-certs"] +mdns = ["dep:dashmap", "dep:flume", "dep:libc", "dep:socket2"] [dev-dependencies] clap = { version = "4", features = ["derive"] } -h3x = { version = "0.3.1", default-features = false, features = [ - "dquic", -] } +h3x = { version = "0.4.0", default-features = false, features = ["dquic"] } shellexpand = "3" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -[[bin]] -name = "ddns-server" -path = "src/bin/ddns-server/main.rs" -required-features = ["server"] - [[example]] name = "mdns_discover" path = "examples/mdns_discover.rs" +required-features = ["mdns"] [[example]] name = "mdns_query" path = "examples/mdns_query.rs" +required-features = ["mdns"] [[example]] name = "publish" path = "examples/publish.rs" -required-features = ["h3x-resolver"] +required-features = ["h3"] [[example]] name = "query" path = "examples/query.rs" -required-features = ["h3x-resolver"] +required-features = ["h3"] diff --git a/README.md b/README.md index 071aaaa..e4cf127 100644 --- a/README.md +++ b/README.md @@ -1,44 +1,47 @@ # DDNS -`ddns` provides DNS discovery and resolver support for DHTTP applications. It is a -single Rust package: the historical `ddns-core`, `gmdns`, `ddns`, and -`ddns-server` crate boundaries now live as modules and feature-gated targets in -one published Cargo package named `dyns`, with a library target kept as `ddns` -for source compatibility. +`ddns` provides DNS discovery and resolver support for DHTTP applications. +The published Cargo package is `dyns`, and the library target remains `ddns`. + +`ddns` exposes backend implementations in `ddns::h3`, `ddns::http`, and `ddns::mdns`, +while `ddns::resolvers` and `ddns::publishers` act as facades for re-exports and +aggregate helper types. + +```toml +ddns = { package = "dyns", version = "0.4.0" } +``` ## Crate layout -| Module / target | Role | +| Module | Role | | --- | --- | | `ddns::core` | DNS packet parser, resource-record types, endpoint `E` record encoding, and HTTP multi-record response wire format. | -| `ddns::mdns` | RFC 6762 multicast DNS transport, LAN publisher, and LAN resolver support. | -| `ddns::resolvers` | Resolver chain plus optional System, mDNS, DNS-over-H3, and DNS-over-HTTP resolvers. | -| `ddns::publisher` | Feature-gated endpoint record signing and publishing loop helpers for DHTTP endpoints. | -| `ddns-server` | DNS-over-H3 publish/lookup server binary, enabled by the `server` feature. | - -`ddns` is endpoint-facing support code for the DHTTP ecosystem. Applications -normally reach it through the `dhttp` endpoint facade; lower-level consumers can -depend on package `dyns` directly (typically renamed locally to `ddns`) when -they need DNS wire types, resolver composition, mDNS, or the DNS-over-H3 server. - -```toml -ddns = { package = "dyns", version = "0.3.0" } -``` +| `ddns::h3` | DNS-over-HTTP/3 backend implementation. | +| `ddns::http` | DNS-over-HTTP backend implementation. | +| `ddns::mdns` | RFC 6762 multicast DNS transport plus LAN resolver/publisher backend implementation. | +| `ddns::resolvers` | Resolver facade: backend re-exports, resolver chains, and `Resolvers` aggregation. | +| `ddns::publishers` | Publisher facade: backend re-exports, scoped publisher atoms, `Publishers` aggregation, and endpoint publication helpers. | ## Features -All optional integrations are feature-gated; the default feature set is empty. +The default feature set is empty. | Feature | Enables | | --- | --- | -| `h3x-resolver` | DNS-over-H3 resolver and publisher using `h3x`/`dquic`. | -| `mdns-resolver` | mDNS resolver integration backed by an existing `h3x::dquic::Network`. | -| `http-resolver` | DNS-over-HTTP resolver/publisher using `reqwest` and native roots. | -| `server` | `ddns-server`, Redis storage support, TOML config parsing, and tracing setup. | +| `resolvers` | Resolver aggregation types such as `Resolvers`, `ResolversBuilder`, and `DnsScheme`. | +| `publishers` | Scoped publication helpers such as `Publisher`, `Publishers`, `PublishScope`, `EndpointPublicationLoop`, and `PublishAddresses`; backend `Publish` implementations own any required signing. | +| `dquic-network` | `h3x`/`dquic` network-backed publication helpers such as `EndpointBindingAddresses`; meaningful together with `publishers`, and also used by mDNS resolver aggregation. | +| `h3` | DNS-over-HTTP/3 backend surface (`ddns::h3`, plus `H3Resolver` / `H3Publisher` re-exports from the facades). | +| `http` | DNS-over-HTTP backend surface (`ddns::http`, plus `HttpResolver` / `HttpPublisher` re-exports from the facades). | +| `mdns` | mDNS backend surface (`ddns::mdns`, plus `MdnsResolver` / `MdnsPublisher` re-exports from the facades). | + +Backend types live under the `resolvers` / `publishers` facades whenever their backend feature is enabled. +The aggregate `Resolvers` and endpoint-publication helper types are separately gated by the +`resolvers` and `publishers` features. ## Bootstrap constants -`build.rs` generates the resolver defaults exposed from `ddns::resolvers`: +`build.rs` generates resolver defaults exposed from `ddns::resolvers`: | Environment variable | Public constant | Fallback when unset | | --- | --- | --- | @@ -46,24 +49,20 @@ All optional integrations are feature-gated; the default feature set is empty. | `DHTTP_HTTP_DNS_SERVER` | `DHTTP_HTTP_DNS_SERVER` | `https://dhttp.example.net` | | `DHTTP_MDNS_SERVICE` | `DHTTP_MDNS_SERVICE` | `dhttp.example.net` | -The fallbacks are docs/build placeholders, not operational defaults. Real -endpoint, server, and E2E runs should set the DHTTP bootstrap environment before -building. +The fallbacks are docs/build placeholders, not operational defaults. ## Quick start ### Resolver chain -`Resolvers` queries all configured resolvers and streams endpoint addresses from -successful backends. System DNS is always available; mDNS, H3, and HTTP builders -appear behind their features. +Enable the resolver aggregation surface and build a chain explicitly: ```rust use ddns::resolvers::Resolvers; use futures::StreamExt; #[tokio::main] -async fn main() -> Result<(), ddns::resolvers::DnsErrors> { +async fn main() -> Result<(), ddns::resolvers::ResolversError> { let resolvers = Resolvers::builder().system().build(); let mut endpoints = resolvers.lookup("demo.example.dhttp.net").await?; @@ -98,21 +97,13 @@ async fn main() -> std::io::Result<()> { } ``` -Runnable examples live in `examples/`: - -```bash -cargo run --example mdns_discover -- --ip 127.0.0.1 --device lo0 -cargo run --example mdns_query -- --ip 192.168.5.156 --device en0 -``` - -### DNS-over-H3 examples +Runnable examples: ```bash -cargo run --example query --features h3x-resolver -- \ - --server-ca /path/to/root.crt \ - --host nat.genmeta.net - -cargo run --example publish --features h3x-resolver -- \ +cargo run --example mdns_discover --features mdns -- --ip 127.0.0.1 --device lo0 +cargo run --example mdns_query --features mdns -- --ip 192.168.5.156 --device en0 +cargo run --example query -- --server-ca /path/to/root.crt --host nat.genmeta.net +cargo run --example publish --features h3 -- \ --server-ca /path/to/root.crt \ --client-name demo.example.dhttp.net \ --client-cert /path/to/demo.example.dhttp.net.pem \ @@ -121,108 +112,4 @@ cargo run --example publish --features h3x-resolver -- \ --addr 192.168.1.100:8080,192.168.1.101:8080 ``` -See [`examples/README.md`](examples/README.md) for the example CLI parameters -and response decoding notes. - -## DNS-over-H3 server - -Start the server with the `server` feature: - -```bash -cargo run --bin ddns-server --features server -- --config server.toml -``` - -The server exposes two HTTP/3 routes: - -| Route | Meaning | -| --- | --- | -| `POST /publish?host=` | Publish a DNS packet for `host`. Client mTLS is required. | -| `GET /lookup?host=[&limit=N]` | Look up active records for `host`; `limit` caps newest-first dynamic records. | - -Lookup responses use header `x-record-format: multi` and the binary body from -`ddns::core::wire::MultiResponse`: - -```text -u32 count -repeated count times: - u32 dns_len | dns packet bytes | u32 cert_len | DER publisher certificate bytes -``` - -Server configuration lives in `server.toml`: - -- storage is in-memory by default, or Redis when `redis = "redis://..."` is set; -- `ttl_secs` controls dynamic record expiry; -- `require_signature` controls signed endpoint-record enforcement for Standard - domains; -- `domain_policies` are matched in order, with unlisted domains using the - Standard policy; -- `seed_records` add static bootstrap endpoints to lookup results. - -Domain policies: - -| Policy | Behavior | -| --- | --- | -| `standard` | Client certificate DNS SAN must match the published host; signed `E` records are required when `require_signature = true`; each certificate fingerprint owns one active record for the host. | -| `open_multi` | Any authenticated client certificate may publish; signature checks are skipped; multiple certificate fingerprints can coexist and lookup returns newest-first records. | - -Public DHTTP identity hostnames should use the canonical `DhttpName::SUFFIX` -(`.dhttp.net`). Infrastructure names such as `nat.genmeta.net` can remain under -Genmeta infrastructure domains. - -## Endpoint `E` records - -Custom DNS record type `E` (`QTYPE = 266`) carries DHTTP endpoint addresses. The -current wire format is: - -```text -flags(u8) -[sequence(varint) if CLUSTERED] -primary address: port(u16) + IPv4/IPv6 bytes -[agent address if NAT] -[load(f32) if LOAD] -[signature: scheme(u16) + len(varint) + bytes if SIGNED] -``` - -Flag bits: - -| Bit mask | Name | Meaning | -| --- | --- | --- | -| `0x80` | `FAMILY` | `0` = IPv4, `1` = IPv6. | -| `0x40` | `MAIN` | Primary endpoint for the name. | -| `0x20` | `CLUSTERED` | Sequence number is present; multiple publishers share the name. | -| `0x10` | `NAT` | Agent address is present for NAT traversal. | -| `0x08` | `LOAD` | One-minute load value is present. | -| `0x01` | `SIGNED` | Signature with explicit TLS signature scheme is present. | - -For DHTTP endpoint publishing, `MAIN` and `sequence` are derived from the -publisher certificate's DHTTP subject key identifier. Operators do not choose -these fields manually: `primary` certificates publish `MAIN = true`, -`secondary` certificates publish `MAIN = false`, and the certificate-chain -sequence becomes the normalized endpoint-record sequence. An omitted sequence -field means sequence `0`. - -Signed records encode the signature scheme in the record; the no-scheme signed -format is not accepted. Legacy unsigned fixed-length endpoint address records are -still parsed by length for address-only compatibility. - -## Project structure - -```text -src/core.rs DNS core module root -src/core/parser/ DNS packet, name, question, record, varint, and signature parsers -src/core/parser/record/ A/AAAA/SRV/TXT/PTR/CNAME/E record parsing and encoding -src/core/wire.rs HTTP multi-record response wire format -src/mdns.rs mDNS module root -src/mdns/protocol.rs UDP multicast socket and packet routing -src/mdns/service.rs High-level mDNS service API -src/mdns/resolvers/ mDNS resolver integration -src/resolvers.rs Resolver chain and resolver defaults -src/resolvers/h3.rs DNS-over-H3 resolver/publisher -src/resolvers/http.rs DNS-over-HTTP resolver/publisher -src/resolvers/deferred.rs Deferred resolver initialization helper -src/publisher.rs Endpoint record signer and publication loop -src/publisher/ Address selection, publish dispatch, packet signing -src/bin/ddns-server/ DNS-over-H3 server implementation -examples/ mDNS and DNS-over-H3 example programs -server.toml Example server configuration -``` +See [`examples/README.md`](examples/README.md) for example CLI parameters and response decoding notes. diff --git a/examples/README.md b/examples/README.md index c1e5780..c95a62d 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,14 +1,14 @@ # DDNS examples -This directory contains runnable examples for the single published `dyns` -package, whose library target remains `ddns`. +This directory contains runnable examples for the published `dyns` package, +whose library target remains `ddns`. | Example | Feature requirement | Purpose | | --- | --- | --- | -| `mdns_discover` | none | Bind an mDNS service, publish sample local hosts, and print multicast packets. | -| `mdns_query` | none | Query a DHTTP name over local mDNS. | -| `query` | `h3x-resolver` | Query a DNS-over-H3 server and decode the multi-record response. | -| `publish` | `h3x-resolver` | Publish signed endpoint `E` records to a DNS-over-H3 server using client mTLS. | +| `mdns_discover` | `mdns` | Bind an mDNS service, publish sample local hosts, and print multicast packets. | +| `mdns_query` | `mdns` | Query a DHTTP name over local mDNS. | +| `query` | `h3` | Query a DNS-over-H3 server and decode the multi-record response. | +| `publish` | `h3` | Publish endpoint `E` records to a DNS-over-H3 server using client mTLS; H3 publish request headers are signed from the client endpoint identity. | Run all commands from the `ddns/` repository. @@ -17,7 +17,7 @@ Run all commands from the `ddns/` repository. Bind to a local interface and print multicast traffic: ```bash -cargo run --example mdns_discover -- \ +cargo run --example mdns_discover --features mdns -- \ --ip 127.0.0.1 \ --device lo0 ``` @@ -25,19 +25,18 @@ cargo run --example mdns_discover -- \ Query a name over mDNS: ```bash -cargo run --example mdns_query -- \ +cargo run --example mdns_query --features mdns -- \ --ip 192.168.5.156 \ --device en0 ``` -Replace `--ip` and `--device` with an address and interface that exist on the -local machine. The mDNS service name defaults to the build-time -`DHTTP_MDNS_SERVICE` constant. +Replace `--ip` and `--device` with an address and interface that exist on the local machine. +The mDNS service name defaults to the build-time `DHTTP_MDNS_SERVICE` constant. ## DNS-over-H3 query ```bash -cargo run --example query --features h3x-resolver -- \ +cargo run --example query --features h3 -- \ --server-ca /path/to/root.crt \ --host nat.genmeta.net ``` @@ -50,7 +49,7 @@ Options: | `--server-ca ` | PEM root CA used to verify the DNS server certificate. | | `--host ` | DNS host to query. Defaults to `nat.genmeta.net`. | -The example sends `GET /lookup?host=`. A successful server response is a +The example sends `GET /api/v2/lookup?host=`. A successful server response is a `ddns::core::wire::MultiResponse` body with header `x-record-format: multi`: ```text @@ -60,13 +59,12 @@ repeated count times: ``` The example prints each DNS packet, the publisher certificate fingerprint when a -certificate is present, and endpoint signature verification status for signed -`E` records. +certificate is present, and endpoint signature verification status for signed `E` records. ## DNS-over-H3 publish ```bash -cargo run --example publish --features h3x-resolver -- \ +cargo run --example publish --features h3 -- \ --server-ca /path/to/root.crt \ --client-name demo.example.dhttp.net \ --client-cert /path/to/demo.example.dhttp.net.pem \ @@ -84,26 +82,10 @@ Options: | `--client-name ` | DHTTP identity name presented by the client endpoint. | | `--client-cert ` | Client certificate chain PEM for mTLS and endpoint signature verification. | | `--client-key ` | Client private key PEM. | -| `--sign ` | Whether to sign each endpoint `E` record. Defaults to `true`. | -| `--host ` | DNS host to publish. Standard-policy servers require this to match the client certificate DNS SAN. | +| `--host ` | DNS host to publish. | | `--addr ` | One or more socket addresses to publish. | -The example derives the endpoint selector from the client certificate SKI before -signing records. Use the correct certificate chain instead of manual selector -flags. - -The example sends `POST /publish?host=` with a binary DNS packet body. For -Standard policy domains, the server requires a client certificate whose single -DNS SAN matches `host`; when `require_signature = true`, at least one signed -endpoint record must verify against the publisher certificate. Open-multi policy -domains still require client mTLS but skip the host SAN and endpoint signature -checks. - -## Running the server - -```bash -cargo run --bin ddns-server --features server -- --config server.toml -``` - -`server.toml` documents the available fields: listener, TLS identity, client root -CA, optional Redis storage, TTL, domain policies, and static seed records. +The example imports `H3Publisher` from the `ddns::publishers` facade, but only needs the +`h3` backend feature because backend publisher types are re-exported from the facade directly. +H3 publish request headers are always signed with the configured client endpoint identity; callers no longer pass request signature fields. +Publish requests are sent to `POST /api/v2/publish?host=`. diff --git a/examples/publish.rs b/examples/publish.rs index 629fe9f..7684f49 100644 --- a/examples/publish.rs +++ b/examples/publish.rs @@ -7,8 +7,8 @@ use std::{ use clap::Parser; use ddns::{ - core::parser::record::endpoint::EndpointAddr, - resolvers::{DHTTP_H3_DNS_SERVER, h3::H3Publisher}, + core::parser::record::endpoint::EndpointAddr, publishers::H3Publisher, + resolvers::DHTTP_H3_DNS_SERVER, }; use h3x::dquic::{ Identity, Network, QuicEndpoint, @@ -42,13 +42,6 @@ struct Options { #[arg(long)] client_key: PathBuf, - /// Sign Endpoint records using the client private key. - /// - /// This must correspond to the client certificate presented in mTLS, because the server - /// verifies the signature with the peer certificate's SPKI. - #[arg(long, default_value_t = true, action = clap::ArgAction::Set)] - sign: bool, - /// 要发布的线上域名,必须与客户端证书 SAN 匹配。 #[arg(long)] host: String, @@ -56,6 +49,12 @@ struct Options { /// 要发布的地址列表。 #[arg(long, value_delimiter = ',', num_args = 1..)] addr: Vec, + + #[arg(long, default_value_t = true)] + is_main: bool, + + #[arg(long, default_value_t = 1)] + sequence: u64, } fn default_h3_base_url() -> String { @@ -135,31 +134,18 @@ async fn main() -> io::Result<()> { let resolver = H3Publisher::new(opt.base_url.clone(), h3_endpoint)?; info!(host = %opt.host, addrs = ?opt.addr, base_url = %opt.base_url, "publish.start"); - if opt.sign { - info!("publish.endpoint_signing.enabled"); - } else { - info!("publish.endpoint_signing.disabled"); - } - let selector = identity - .dhttp_subject_key_identifier() - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - let chain = selector.chain(); - for &addr in &opt.addr { - info!("creating endpoint for address: {}", addr); + info!("Creating endpoint for address: {}", addr); let mut endpoint = match addr { SocketAddr::V4(v4) => EndpointAddr::direct_v4(v4), SocketAddr::V6(v6) => EndpointAddr::direct_v6(v6), }; - endpoint.set_certificate_chain_key(chain); - if opt.sign { - info!("signing endpoint"); - endpoint - .sign_with_authority(identity.as_ref()) - .await - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - } - info!("publishing endpoint: {:?}", endpoint); + endpoint.set_main(opt.is_main); + endpoint.set_sequence( + dhttp_identity::certificate::CertificateSequence::try_from(opt.sequence) + .map_err(io::Error::other)?, + ); + info!("Publishing endpoint: {:?}", endpoint); let mut hosts = std::collections::HashMap::new(); hosts.insert(opt.host.clone(), vec![endpoint]); let packet = ddns::core::MdnsPacket::answer(0, &hosts).to_bytes(); @@ -167,7 +153,7 @@ async fn main() -> io::Result<()> { .publish(&opt.host, &packet) .await .map_err(io::Error::other)?; - info!("successfully published endpoint for {}", addr); + info!("Successfully published endpoint for {}", addr); } info!("publish.ok"); diff --git a/examples/query.rs b/examples/query.rs index 80235cd..c8572b4 100644 --- a/examples/query.rs +++ b/examples/query.rs @@ -74,7 +74,7 @@ fn format_packet(packet: &MdnsPacket) -> String { RData::E(ep) => { output.push_str(&format!("Name: {}\nAddress: {}\n", rr.name(), ep)); if ep.is_signed() { - output.push_str("Signature: present\n"); + output.push_str("Legacy E signature: present\n"); } } _ => { @@ -123,10 +123,11 @@ async fn main() -> Result<(), Box> { .await; let client = H3Endpoint::new(quic); - let url = format!("{}lookup?host={}", opt.base_url, opt.host); + let mut url = url::Url::parse(&opt.base_url)?.join("/api/v2/lookup")?; + url.query_pairs_mut().append_pair("host", &opt.host); info!(url = %url, "lookup.start"); - let uri: http::Uri = url.parse()?; + let uri: http::Uri = url.as_str().parse()?; let authority = uri .authority() .ok_or_else(|| { @@ -145,12 +146,22 @@ async fn main() -> Result<(), Box> { if resp.status().is_success() { let bytes = resp.into_body().collect().await?.to_bytes(); - let (_remain, multi) = be_multi_response(bytes.as_ref()).map_err(|e| { + let (remain, multi) = be_multi_response(bytes.as_ref()).map_err(|e| { io::Error::new( io::ErrorKind::InvalidData, format!("Invalid multi-record payload: {e}"), ) })?; + if !remain.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "Invalid multi-record payload: {} trailing bytes", + remain.len() + ), + ) + .into()); + } info!(count = multi.records.len(), "lookup.ok"); println!("Lookup Result: {} record(s)", multi.records.len()); @@ -163,27 +174,21 @@ async fn main() -> Result<(), Box> { None => println!("Source fingerprint: (no certificate)"), } + if record.signature_fields.is_empty() { + println!("Packet signature: none"); + } else if record.cert.is_empty() { + println!("Packet signature: present but no certificate to verify against"); + } else { + match record.signature_fields.verify(&record.dns, &record.cert) { + Ok(true) => println!("Packet signature: ✓ verified"), + Ok(false) => println!("Packet signature: ✗ invalid"), + Err(e) => println!("Packet signature: ✗ error ({e:?})"), + } + } + match ddns::core::parser::packet::be_packet(&record.dns) { Ok((_, packet)) => { print!("{}", format_packet(&packet)); - - for rr in &packet.answers { - if let RData::E(ep) = rr.data() { - if !ep.is_signed() { - println!("Signature: none"); - continue; - } - if record.cert.is_empty() { - println!("Signature: present but no certificate to verify against"); - continue; - } - match ep.verify_signature_from_der(&record.cert) { - Ok(true) => println!("Signature: ✓ verified"), - Ok(false) => println!("Signature: ✗ invalid"), - Err(e) => println!("Signature: ✗ error ({e:?})"), - } - } - } } Err(_) => { println!("DNS payload: invalid ({} bytes)", record.dns.len()); diff --git a/server.toml b/server.toml deleted file mode 100644 index 90eafe6..0000000 --- a/server.toml +++ /dev/null @@ -1,58 +0,0 @@ -# ddns DNS-over-HTTP/3 server configuration -# All fields are optional; the values shown below are the built-in defaults. - -# Bind patterns to listen on. Dual-stack service is expressed explicitly. -binds = ["0.0.0.0:4433", "[::]:4433"] - -# TLS server name (SNI). -server_name = "dns.genmeta.net" - -# Paths to the server TLS certificate and private key (PEM format). -cert = "~/Downloads/ssl/dns.genmeta.net/dns.genmeta.net.pem" -key = "~/Downloads/ssl/dns.genmeta.net/dns.genmeta.net.key" - -# Root CA that signed the client certificates (PEM format). -root_cert = "~/Downloads/ssl/root.crt" - -# Whether to require a valid DNS record signature on Standard domains. -require_signature = true - -# Default TTL (seconds) for published records. -ttl_secs = 30 - -# Redis URL for persistent storage. -# If omitted, records are kept in memory only (lost on restart). -# redis = "redis://127.0.0.1/" - -# --------------------------------------------------------------------------- -# Domain policy rules -# -# Policies are matched in order; the first matching rule wins. -# Domains not listed here use the built-in "standard" policy. -# -# Policies: -# standard — one record per host; client cert SAN must match the target -# host; signature check controlled by require_signature above; -# each publish overwrites the previous record. -# -# open_multi — any authenticated node may publish; no signature check; -# records are appended (not overwritten), each with its own -# individual TTL; lookup returns newest-first, use ?limit=N -# to cap the number of returned records. -# --------------------------------------------------------------------------- - -[[domain_policies]] -host = "nat.genmeta.net" -policy = "open_multi" - - -# Static bootstrap STUN endpoints returned even before any node publishes. -# Ordering keeps the main :20002 endpoints ahead of the auxiliary :20003 endpoints. -[[seed_records]] -host = "nat.genmeta.net" -endpoints = [] - -# Add more rules as needed, e.g.: -# [[domain_policies]] -# host = "relay.genmeta.net" -# policy = "open_multi" diff --git a/src/bin/ddns-server/config.rs b/src/bin/ddns-server/config.rs deleted file mode 100644 index 8b663bd..0000000 --- a/src/bin/ddns-server/config.rs +++ /dev/null @@ -1,211 +0,0 @@ -use std::{ - net::SocketAddr, - path::{Path, PathBuf}, - str::FromStr, -}; - -use clap::Parser; -use h3x::dquic::binds::BindPattern; -use serde::{Deserialize, Deserializer, de::Error as _}; - -// --------------------------------------------------------------------------- -// CLI -// --------------------------------------------------------------------------- - -#[derive(Parser, Clone, Debug)] -#[command(version, about, long_about = None)] -pub struct Options { - /// Path to the TOML configuration file. - #[arg(long, default_value = "server.toml")] - pub config: PathBuf, -} - -// --------------------------------------------------------------------------- -// Configuration file schema -// --------------------------------------------------------------------------- - -/// Top-level configuration loaded from the TOML file. -#[derive(Deserialize, Debug)] -#[serde(deny_unknown_fields)] -pub struct Config { - /// Redis URL (e.g. "redis://127.0.0.1/"). Omit to use in-memory storage. - pub redis: Option, - - /// Bind patterns to listen on. - #[serde( - default = "Config::default_binds", - deserialize_with = "deserialize_bind_patterns" - )] - pub binds: Vec, - - /// Server name (used as TLS SNI). - #[serde(default = "Config::default_server_name")] - pub server_name: String, - - /// Path to the server TLS certificate (PEM). - #[serde(default = "Config::default_cert")] - pub cert: PathBuf, - - /// Path to the server TLS private key (PEM). - #[serde(default = "Config::default_key")] - pub key: PathBuf, - - /// Path to the root CA that signs client certificates (PEM). - #[serde(default = "Config::default_root_cert")] - pub root_cert: PathBuf, - - /// Whether to require DNS record signatures on Standard domains. - #[serde(default = "Config::default_require_signature")] - pub require_signature: bool, - - /// Default TTL (seconds) for published records. - #[serde(default = "Config::default_ttl_secs")] - pub ttl_secs: u64, - - /// Domain-policy rules (first match wins; unlisted domains use Standard). - #[serde(default)] - pub domain_policies: Vec, - - /// Static seed records returned on lookup in addition to dynamic published records. - #[serde(default)] - pub seed_records: Vec, -} - -impl Config { - pub fn expand_paths(mut self) -> Self { - self.cert = expand_home_dir(&self.cert); - self.key = expand_home_dir(&self.key); - self.root_cert = expand_home_dir(&self.root_cert); - self - } - - pub fn default_binds() -> Vec { - ["0.0.0.0:4433", "[::]:4433"] - .into_iter() - .map(|value| { - BindPattern::from_str(value).expect("default bind pattern should be valid") - }) - .collect() - } - pub fn default_server_name() -> String { - "localhost".into() - } - pub fn default_cert() -> PathBuf { - "examples/keychain/localhost/localhost-ECC.crt".into() - } - pub fn default_key() -> PathBuf { - "examples/keychain/localhost/localhost-ECC.key".into() - } - pub fn default_root_cert() -> PathBuf { - "examples/keychain/root/rootCA-ECC.crt".into() - } - pub fn default_require_signature() -> bool { - true - } - pub fn default_ttl_secs() -> u64 { - 30 - } -} - -fn deserialize_bind_patterns<'de, D>(deserializer: D) -> Result, D::Error> -where - D: Deserializer<'de>, -{ - let values = Vec::::deserialize(deserializer)?; - values - .into_iter() - .map(|value| { - BindPattern::from_str(&value).map_err(|error| { - D::Error::custom(format!("invalid bind pattern `{value}`: {error}")) - }) - }) - .collect() -} - -fn expand_home_dir(path: &Path) -> PathBuf { - let Some(path_str) = path.to_str() else { - return path.to_path_buf(); - }; - - if path_str == "~" { - return std::env::var_os("HOME") - .map(PathBuf::from) - .unwrap_or_else(|| path.to_path_buf()); - } - - if let Some(stripped) = path_str.strip_prefix("~/") - && let Some(home) = std::env::var_os("HOME") - { - return PathBuf::from(home).join(stripped); - } - - path.to_path_buf() -} - -/// One domain-policy rule in the configuration file. -#[derive(Deserialize, Debug)] -#[serde(deny_unknown_fields)] -pub struct PolicyConfig { - /// Exact host to match (after normalisation). - pub host: String, - /// Policy to apply. - pub policy: PolicyKind, -} - -/// One statically configured seed record group. -#[derive(Deserialize, Debug, Clone)] -#[serde(deny_unknown_fields)] -pub struct SeedRecordConfig { - /// Exact host to seed. - pub host: String, - /// Preloaded endpoint list for this host. - pub endpoints: Vec, -} - -/// Serialisable policy kind. -#[derive(Deserialize, Debug, Clone)] -#[serde(rename_all = "snake_case")] -pub enum PolicyKind { - Standard, - OpenMulti, -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn default_binds_are_explicit_dual_stack() { - let binds = Config::default_binds(); - - assert_eq!(binds.len(), 2); - assert_eq!(binds[0].to_string(), "inet://0.0.0.0:4433"); - assert_eq!(binds[1].to_string(), "inet://[::]:4433"); - } - - #[test] - fn config_parses_bare_socket_bind_patterns() { - let config: Config = toml::from_str( - r#" - binds = ["0.0.0.0:4433", "[::]:4433"] - "#, - ) - .expect("config should parse"); - - assert_eq!(config.binds.len(), 2); - assert_eq!(config.binds[0].to_string(), "inet://0.0.0.0:4433"); - assert_eq!(config.binds[1].to_string(), "inet://[::]:4433"); - } - - #[test] - fn legacy_listen_field_is_rejected() { - let error = toml::from_str::( - r#" - listen = "0.0.0.0:4433" - "#, - ) - .expect_err("legacy listen should be rejected"); - - assert!(error.to_string().contains("unknown field `listen`")); - } -} diff --git a/src/bin/ddns-server/error.rs b/src/bin/ddns-server/error.rs deleted file mode 100644 index c8930ba..0000000 --- a/src/bin/ddns-server/error.rs +++ /dev/null @@ -1,117 +0,0 @@ -use std::collections::HashMap; - -use dhttp_identity::name::DhttpName; - -#[derive(Debug, snafu::Snafu)] -#[snafu(module, visibility(pub(crate)))] -pub enum AppError { - #[snafu(display("missing host parameter"))] - MissingHostParam, - #[snafu(display("invalid host"))] - InvalidHost, - #[snafu(display("forbidden host"))] - ForbiddenHost, - #[snafu(display("domain not allowed"))] - DomainNotAllowed, - #[snafu(display("host mismatch"))] - HostMismatch, - #[snafu(display("missing client certificate"))] - MissingClientCertificate, - #[snafu(display("client certificate domain not allowed"))] - ClientCertDomainNotAllowed, - #[snafu(display("invalid DNS packet: {message}"))] - InvalidDnsPacket { message: String }, - #[snafu(display("publisher certificate selector is invalid"))] - PublisherCertificateSelector { - source: dhttp_identity::identity::ExtractDhttpSubjectKeyIdentifierError, - }, - #[snafu(display("endpoint record selector is invalid"))] - EndpointRecordSelector { - source: ddns::core::parser::record::endpoint::EndpointSelectorError, - }, - #[snafu(display("endpoint record selector does not match publisher certificate selector"))] - EndpointSelectorMismatch, - #[snafu(display("no answers in packet"))] - NoAnswersInPacket, - #[snafu(display("signature required"))] - SignatureRequired, - #[snafu(display("invalid signature"))] - InvalidSignature, - #[snafu(display("redis error: {message}"))] - Redis { message: String }, -} - -impl AppError { - pub fn status(&self) -> http::StatusCode { - match self { - AppError::MissingHostParam => http::StatusCode::BAD_REQUEST, - AppError::InvalidHost => http::StatusCode::BAD_REQUEST, - AppError::ForbiddenHost => http::StatusCode::BAD_REQUEST, - AppError::DomainNotAllowed => http::StatusCode::FORBIDDEN, - AppError::HostMismatch => http::StatusCode::BAD_REQUEST, - AppError::MissingClientCertificate => http::StatusCode::UNAUTHORIZED, - AppError::ClientCertDomainNotAllowed => http::StatusCode::FORBIDDEN, - AppError::InvalidDnsPacket { .. } => http::StatusCode::BAD_REQUEST, - AppError::PublisherCertificateSelector { .. } => http::StatusCode::BAD_REQUEST, - AppError::EndpointRecordSelector { .. } => http::StatusCode::BAD_REQUEST, - AppError::EndpointSelectorMismatch => http::StatusCode::BAD_REQUEST, - AppError::NoAnswersInPacket => http::StatusCode::UNPROCESSABLE_ENTITY, - AppError::SignatureRequired => http::StatusCode::BAD_REQUEST, - AppError::InvalidSignature => http::StatusCode::BAD_REQUEST, - AppError::Redis { .. } => http::StatusCode::SERVICE_UNAVAILABLE, - } - } -} - -pub fn normalize_host(host: &str) -> Result { - let host = host.trim(); - if host.is_empty() { - return Err(AppError::InvalidHost); - } - if host.contains('*') { - return Err(AppError::ForbiddenHost); - } - - // 剥离端口号(如 "example.com:443" -> "example.com") - let host = match host.rsplit_once(':') { - Some((h, port)) if port.chars().all(|c| c.is_ascii_digit()) => h, - _ => host, - }; - - // 允许末尾 '.'(FQDN 写法) - let host = host.strip_suffix('.').unwrap_or(host); - - let host = idna::domain_to_ascii(host).map_err(|_| AppError::InvalidHost)?; - let host = host.to_ascii_lowercase(); - - // 校验是否为 DHTTP identity 域名 - if !host.ends_with(DhttpName::SUFFIX) { - return Err(AppError::DomainNotAllowed); - } - - Ok(host) -} - -pub fn parse_query_params(uri: &http::Uri) -> HashMap { - let query = uri.query().unwrap_or(""); - url::form_urlencoded::parse(query.as_bytes()) - .into_owned() - .collect() -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn normalize_host_uses_dhttp_identity_suffix() { - assert_eq!( - normalize_host("Reimu.Pilot.Dhttp.Net:443").unwrap(), - "reimu.pilot.dhttp.net" - ); - assert!(matches!( - normalize_host("reimu.pilot.genmeta.net"), - Err(AppError::DomainNotAllowed) - )); - } -} diff --git a/src/bin/ddns-server/lookup.rs b/src/bin/ddns-server/lookup.rs deleted file mode 100644 index e8c2616..0000000 --- a/src/bin/ddns-server/lookup.rs +++ /dev/null @@ -1,249 +0,0 @@ -use std::{ - collections::{HashMap, HashSet}, - convert::Infallible, - net::SocketAddr, -}; - -use ddns::core::{ - MdnsPacket, - parser::{packet::be_packet, record::RData}, - wire::MultiResponse, -}; -use deadpool_redis::redis::{self, AsyncCommands}; -use h3x::dhttp::message::MessageStreamError; -use http_body_util::{Full, combinators::UnsyncBoxBody}; -use tracing::debug; - -use crate::{ - error::{AppError, normalize_host, parse_query_params}, - storage::{AppState, LookupRecord, Storage, StoredRecord, unix_now_secs}, -}; - -pub type Request = http::Request>; -pub type Response = http::Response>; - -// --------------------------------------------------------------------------- -// Lookup result type -// --------------------------------------------------------------------------- - -pub enum LookupResult { - NotFound, - /// Multiple records, newest-first. - Multi(MultiResponse), -} - -type EndpointKey = (SocketAddr, Option); - -fn normalize_lookup_records(records: Vec) -> Vec { - let mut normalized = Vec::new(); - let mut seen = HashSet::new(); - - for (dns_bytes, cert_bytes) in records { - let Ok((_, packet)) = be_packet(&dns_bytes) else { - normalized.push((dns_bytes, cert_bytes)); - continue; - }; - - let mut emitted_endpoint = false; - - for answer in &packet.answers { - let RData::E(endpoint) = answer.data() else { - continue; - }; - - emitted_endpoint = true; - let key: EndpointKey = (endpoint.addr(), endpoint.agent_addr()); - - if !seen.insert(key) { - continue; - } - - let mut hosts = HashMap::new(); - hosts.insert(answer.name().to_string(), vec![endpoint.clone()]); - normalized.push((MdnsPacket::answer(0, &hosts).to_bytes(), cert_bytes.clone())); - } - - if !emitted_endpoint { - normalized.push((dns_bytes, cert_bytes)); - } - } - - normalized -} - -// --------------------------------------------------------------------------- -// Core lookup logic -// --------------------------------------------------------------------------- - -pub async fn perform_lookup( - state: &AppState, - host: &str, - limit: Option, -) -> Result { - let host = normalize_host(host)?; - perform_lookup_multi(state, &host, limit).await -} - -async fn perform_lookup_multi( - state: &AppState, - host: &str, - limit: Option, -) -> Result { - let mut records = match &state.storage { - Storage::Redis(pool) => { - let mut conn = pool.get().await.map_err(|e| AppError::Redis { - message: e.to_string(), - })?; - - let set_key = format!("{host}:multi"); - let now_secs = unix_now_secs(); - - // Remove expired members: those published more than ttl_secs ago. - let cutoff_score = now_secs.saturating_sub(state.ttl_secs) as f64; - let _: () = redis::cmd("ZREMRANGEBYSCORE") - .arg(&set_key) - .arg("-inf") - .arg(cutoff_score) - .query_async::<()>(&mut *conn) - .await - .unwrap_or(()); - - // Fetch all remaining, newest first (highest score = most recently published) - let count: isize = limit.map(|l| l as isize).unwrap_or(-1); - let members: Vec> = conn - .zrevrange(&set_key, 0isize, if count < 0 { -1 } else { count - 1 }) - .await - .map_err(|e| AppError::Redis { - message: e.to_string(), - })?; - - let now_secs = unix_now_secs(); - let records: Vec<(Vec, Vec)> = members - .into_iter() - .filter_map(|m| { - let r = StoredRecord::decode(&m)?; - if r.expire_unix_secs > now_secs { - Some((r.dns, r.cert)) - } else { - None - } - }) - .collect(); - - records - } - Storage::Memory(mem) => { - let now = tokio::time::Instant::now(); - if let Some(mut entry) = mem.records.get_mut(host) { - // Evict expired entries in-place. - entry.retain(|_, r| r.expire > now); - // Sort newest-first by published_at. - let take = limit.unwrap_or(entry.len()).min(entry.len()); - let mut records: Vec<_> = entry.values().collect(); - records.sort_by_key(|b| std::cmp::Reverse(b.published_at)); - records[..take] - .iter() - .map(|r| (r.dns_bytes.clone(), r.cert_bytes.clone())) - .collect::>() - } else { - vec![] - } - } - }; - - if let Some(seed_records) = state.seed_records.get(host) { - records.extend(seed_records.iter().cloned()); - } - - let records = normalize_lookup_records(records); - - if records.is_empty() { - Ok(LookupResult::NotFound) - } else { - Ok(LookupResult::Multi(MultiResponse::new(records))) - } -} - -// --------------------------------------------------------------------------- -// HTTP response helpers -// --------------------------------------------------------------------------- - -pub fn body_response(status: http::StatusCode, body: impl Into) -> Response { - http::Response::builder() - .status(status) - .body(Full::new(body.into())) - .expect("response parts must be valid") -} - -pub fn write_error(err: AppError) -> Response { - debug!( - status = %err.status(), - error = %err, - "writing error response" - ); - body_response(err.status(), bytes::Bytes::from(err.to_string())) -} - -// --------------------------------------------------------------------------- -// LookupSvc -// --------------------------------------------------------------------------- - -#[derive(Clone)] -pub struct LookupSvc { - pub state: AppState, -} - -/// Handle a lookup request. -/// -/// Always returns multi-record binary body: -/// `[u32 count BE]([u32 dns_len BE][dns][u32 cert_len BE][cert])*` -/// with header `x-record-format: multi`. -/// -/// Optional query param `limit=N` caps the number of records returned. -/// Dynamic records are newest-first; configured seed records are appended after them. -pub async fn lookup_with_cert(state: AppState, request: Request) -> Response { - let params = parse_query_params(request.uri()); - let Some(host) = params.get("host") else { - return write_error(AppError::MissingHostParam); - }; - - let limit: Option = params - .get("limit") - .and_then(|v| v.parse::().ok()) - .filter(|&n| n > 0); - - debug!(host = %host, limit, "lookup.request"); - - match perform_lookup(&state, host, limit).await { - Ok(LookupResult::NotFound) => { - debug!(host = %host, "lookup.not_found"); - body_response( - http::StatusCode::NOT_FOUND, - bytes::Bytes::from_static(b"Not Found"), - ) - } - - Ok(LookupResult::Multi(resp)) => { - let body = resp.encode(); - debug!(host = %host, records = resp.records.len(), "lookup.found"); - let mut response = body_response(http::StatusCode::OK, bytes::Bytes::from(body)); - response.headers_mut().insert( - http::HeaderName::from_static("x-record-format"), - http::HeaderValue::from_static("multi"), - ); - response - } - - Err(e) => write_error(e), - } -} - -impl LookupSvc { - pub fn call( - &self, - request: Request, - ) -> impl Future> + Send + 'static { - let state = self.state.clone(); - async move { Ok(lookup_with_cert(state, request).await) } - } -} diff --git a/src/bin/ddns-server/main.rs b/src/bin/ddns-server/main.rs deleted file mode 100644 index 735fa74..0000000 --- a/src/bin/ddns-server/main.rs +++ /dev/null @@ -1,241 +0,0 @@ -mod config; -mod error; -mod lookup; -mod policy; -mod publish; -mod storage; - -use std::{ - collections::HashMap, - io, - net::SocketAddr, - sync::Arc, - task::{Context, Poll}, -}; - -use clap::Parser; -use ddns::core::{MdnsEndpoint, MdnsPacket}; -use futures::future::BoxFuture; -use h3x::{ - dquic::{ - Identity, Network, QuicEndpoint, - cert::handy::{ToCertificate, ToPrivateKey}, - server::ServerQuicConfig, - }, - endpoint::H3Endpoint, - hyper::TowerService, -}; -use rustls::{RootCertStore, server::WebPkiClientVerifier}; -use tracing::{info, level_filters::LevelFilter}; -use tracing_subscriber::{prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt}; - -use crate::{ - config::{Config, Options, PolicyKind, SeedRecordConfig}, - lookup::LookupSvc, - policy::{DomainPolicies, DomainPolicy, PolicyRule}, - publish::PublishSvc, - storage::{AppState, MemoryStorage, SeedRecords, Storage}, -}; - -#[derive(Clone)] -struct DnsService { - publish: PublishSvc, - lookup: LookupSvc, -} - -impl tower_service::Service for DnsService { - type Response = lookup::Response; - type Error = io::Error; - type Future = BoxFuture<'static, Result>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, request: lookup::Request) -> Self::Future { - let method = request.method().clone(); - let path = request.uri().path().to_owned(); - let publish = self.publish.clone(); - let lookup = self.lookup.clone(); - Box::pin(async move { - match (method, path.as_str()) { - (http::Method::POST, "/publish") => match publish.call(request).await { - Ok(response) => Ok(response), - Err(never) => match never {}, - }, - (http::Method::GET, "/lookup") => match lookup.call(request).await { - Ok(response) => Ok(response), - Err(never) => match never {}, - }, - (_, "/publish" | "/lookup") => Ok(lookup::body_response( - http::StatusCode::METHOD_NOT_ALLOWED, - bytes::Bytes::from_static(b"Method Not Allowed"), - )), - _ => Ok(lookup::body_response( - http::StatusCode::NOT_FOUND, - bytes::Bytes::from_static(b"Not Found"), - )), - } - }) - } -} - -// --------------------------------------------------------------------------- -// TLS helpers -// --------------------------------------------------------------------------- - -fn load_root_store_from_pem(pem: &[u8]) -> io::Result { - let mut reader = std::io::Cursor::new(pem); - let certs = rustls_pemfile::certs(&mut reader) - .collect::, _>>() - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - - let mut store = RootCertStore::empty(); - store.add_parsable_certificates(certs); - Ok(store) -} - -fn build_seed_records(seed_records: &[SeedRecordConfig]) -> io::Result { - let mut records = HashMap::new(); - - for seed_record in seed_records { - if seed_record.endpoints.is_empty() { - continue; - } - - let host = error::normalize_host(&seed_record.host) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; - - let endpoints = seed_record - .endpoints - .iter() - .map(|addr| match addr { - SocketAddr::V4(addr) => MdnsEndpoint::direct_v4(*addr), - SocketAddr::V6(addr) => MdnsEndpoint::direct_v6(*addr), - }) - .collect::>(); - - let mut hosts = HashMap::new(); - hosts.insert(host.clone(), endpoints); - - records - .entry(host.clone()) - .or_insert_with(Vec::new) - .push((MdnsPacket::answer(0, &hosts).to_bytes(), Vec::new())); - - info!(host = %host, endpoint_count = seed_record.endpoints.len(), "seed_records.loaded"); - } - - Ok(Arc::new(records)) -} - -// --------------------------------------------------------------------------- -// Entry point -// --------------------------------------------------------------------------- - -#[tokio::main] -async fn main() -> Result<(), Box> { - tracing_subscriber::registry() - .with(tracing_subscriber::fmt::layer()) - .with(tracing_subscriber::filter::filter_fn(|metadata| { - !metadata.target().contains("netlink_packet_route") - })) - .with(LevelFilter::DEBUG) - .init(); - - let options = Options::parse(); - - let config_str = std::fs::read_to_string(&options.config).unwrap_or_else(|e| { - eprintln!("failed to read config {:?}: {e}", options.config); - std::process::exit(1); - }); - let config: Config = toml::from_str(&config_str).unwrap_or_else(|e| { - eprintln!("failed to parse config {:?}: {e}", options.config); - std::process::exit(1); - }); - let config = config.expand_paths(); - let seed_records = build_seed_records(&config.seed_records)?; - - // Build storage backend. - let storage = match config.redis.clone() { - Some(url) => { - let redis_cfg = deadpool_redis::Config::from_url(url); - let redis_pool = redis_cfg.create_pool(Some(deadpool_redis::Runtime::Tokio1))?; - Storage::Redis(redis_pool) - } - None => Storage::Memory(MemoryStorage::new()), - }; - - // Build domain-policy rules from config file. - let mut policy_rules: Vec<(PolicyRule, DomainPolicy)> = config - .domain_policies - .iter() - .filter_map(|pc| { - error::normalize_host(&pc.host).ok().map(|h| { - let policy = match pc.policy { - PolicyKind::Standard => DomainPolicy::Standard, - PolicyKind::OpenMulti => DomainPolicy::OpenMulti, - }; - (PolicyRule::Exact(h), policy) - }) - }) - .collect(); - // Deduplicate (preserve first occurrence). - policy_rules.dedup_by(|(ra, _), (rb, _)| { - matches!((ra, rb), (PolicyRule::Exact(a), PolicyRule::Exact(b)) if a == b) - }); - let policies = Arc::new(DomainPolicies(policy_rules)); - info!(?policies, "domain_policies.loaded"); - - // Load the root CA used to validate client certificates when they are provided. - let root_ca_pem = std::fs::read(&config.root_cert)?; - let roots = load_root_store_from_pem(&root_ca_pem)?; - let verifier = WebPkiClientVerifier::builder(Arc::new(roots)) - .allow_unauthenticated() - .build() - .unwrap(); - - let state = AppState { - storage, - require_signature: config.require_signature, - ttl_secs: config.ttl_secs, - policies, - seed_records, - }; - - let cert_pem = std::fs::read(&config.cert)?; - let key_pem = std::fs::read(&config.key)?; - - let router = TowerService(DnsService { - publish: PublishSvc { - state: state.clone(), - }, - lookup: LookupSvc { - state: state.clone(), - }, - }); - - let identity = Arc::new(Identity { - name: config.server_name.parse().unwrap(), - certs: Arc::new(cert_pem.to_certificate()), - key: Arc::new(key_pem.to_private_key()), - ocsp: Arc::new(None), - }); - let server_config = ServerQuicConfig { - alpns: vec![b"h3".to_vec()], - client_cert_verifier: verifier, - ..Default::default() - }; - let quic = QuicEndpoint::builder() - .network(Network::builder().build()) - .identity(identity) - .server(server_config) - .bind(Arc::new(config.binds.clone())) - .build() - .await; - let server = Arc::new(H3Endpoint::new(quic)); - info!(binds = ?config.binds, server_name = %config.server_name, "h3_server.start"); - server.listen_owned(router).await?; - - Ok(()) -} diff --git a/src/bin/ddns-server/policy.rs b/src/bin/ddns-server/policy.rs deleted file mode 100644 index 413ad4b..0000000 --- a/src/bin/ddns-server/policy.rs +++ /dev/null @@ -1,337 +0,0 @@ -use ddns::core::parser::{packet::be_packet, record::RData}; -use dhttp_identity::identity::{RemoteAuthority, RemoteAuthorityCertificateExt}; -use snafu::ResultExt; -use tracing::{debug, warn}; - -use crate::error::{AppError, app_error, normalize_host}; - -// --------------------------------------------------------------------------- -// Domain policy -// --------------------------------------------------------------------------- - -/// Per-domain publish / lookup behaviour. -#[derive(Clone, Debug, PartialEq)] -pub enum DomainPolicy { - /// Signature check controlled by `require_signature` flag; single record - /// per host; each publish overwrites the previous one. - Standard, - /// No signature check; any authenticated node may publish; multiple records - /// with individual TTLs; ordered newest-first on lookup. - OpenMulti, -} - -/// One rule in the domain-policy list. -#[derive(Clone, Debug)] -pub enum PolicyRule { - /// Matches only this exact (normalised) host. - Exact(String), - /// Matches the host itself or any label-subdomain (future use). - #[allow(dead_code)] - Suffix(String), -} - -impl PolicyRule { - pub fn matches(&self, host: &str) -> bool { - match self { - PolicyRule::Exact(exact) => host == exact, - PolicyRule::Suffix(suffix) => { - host == suffix.as_str() || host.ends_with(&format!(".{suffix}")) - } - } - } -} - -/// Ordered list of (rule, policy) pairs; first match wins; default is Standard. -#[derive(Clone, Debug, Default)] -pub struct DomainPolicies(pub Vec<(PolicyRule, DomainPolicy)>); - -impl DomainPolicies { - pub fn policy_for(&self, host: &str) -> &DomainPolicy { - for (rule, policy) in &self.0 { - if rule.matches(host) { - return policy; - } - } - &DomainPolicy::Standard - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ValidatedDnsPacket { - Records { host: String }, - Empty, -} - -// --------------------------------------------------------------------------- -// Certificate helpers -// --------------------------------------------------------------------------- - -pub fn extract_client_dns_sans(authority: &(impl RemoteAuthority + ?Sized)) -> Vec { - use x509_parser::prelude::*; - - let Some(leaf) = authority.cert_chain().first() else { - return vec![]; - }; - - let Ok((_remain, cert)) = X509Certificate::from_der(leaf.as_ref()) else { - return vec![]; - }; - - let mut out = vec![]; - if let Ok(Some(san)) = cert.subject_alternative_name() { - for name in san.value.general_names.iter() { - if let GeneralName::DNSName(dns) = name { - out.push(dns.to_string()); - } - } - } - out -} - -pub fn client_allowed_host( - authority: &(impl RemoteAuthority + ?Sized), -) -> Result { - let mut sans = extract_client_dns_sans(authority) - .into_iter() - .filter_map(|h| normalize_host(&h).ok()) - .collect::>(); - - sans.sort(); - sans.dedup(); - - match sans.len() { - 1 => Ok(sans.remove(0)), - _ => Err(AppError::ClientCertDomainNotAllowed), - } -} - -pub fn validate_dns_packet( - packet: &[u8], - require_signature: bool, - authority: &(impl RemoteAuthority + ?Sized), -) -> Result { - let (remaining, dns_packet) = be_packet(packet).map_err(|e| AppError::InvalidDnsPacket { - message: e.to_string(), - })?; - if !remaining.is_empty() { - warn!(remain = remaining.len(), "dns.parse.extra_bytes"); - } - debug!( - answers = dns_packet.answers.len(), - require_signature, "validating dns packet" - ); - - let Some(first_answer) = dns_packet.answers.first() else { - debug!("dns packet has no answers"); - return Ok(ValidatedDnsPacket::Empty); - }; - - validate_endpoint_selectors(&dns_packet, authority)?; - - if require_signature { - let has_signature = dns_packet - .answers - .iter() - .any(|record| matches!(record.data(), RData::E(endpoint) if endpoint.is_signed())); - - if !has_signature { - return Err(AppError::SignatureRequired); - } - - for record in &dns_packet.answers { - if let RData::E(endpoint) = record.data() - && endpoint.is_signed() - { - let cert = authority - .cert_chain() - .first() - .ok_or(AppError::MissingClientCertificate)?; - let ok = endpoint - .verify_signature_from_der(cert.as_ref()) - .map_err(|_| AppError::InvalidSignature)?; - if !ok { - return Err(AppError::InvalidSignature); - } - } - } - } - - Ok(ValidatedDnsPacket::Records { - host: first_answer.name().to_string(), - }) -} - -fn validate_endpoint_selectors( - dns_packet: &ddns::core::parser::packet::Packet, - authority: &(impl RemoteAuthority + ?Sized), -) -> Result<(), AppError> { - let mut endpoints = dns_packet - .answers - .iter() - .filter_map(|record| match record.data() { - RData::E(endpoint) => Some(endpoint), - _ => None, - }); - - let Some(first_endpoint) = endpoints.next() else { - return Ok(()); - }; - - let expected = authority - .dhttp_subject_key_identifier() - .context(app_error::PublisherCertificateSelectorSnafu)? - .chain() - .clone(); - - let first = first_endpoint - .certificate_chain_key() - .context(app_error::EndpointRecordSelectorSnafu)?; - if first != expected { - return Err(AppError::EndpointSelectorMismatch); - } - - for endpoint in endpoints { - let actual = endpoint - .certificate_chain_key() - .context(app_error::EndpointRecordSelectorSnafu)?; - if actual != expected { - return Err(AppError::EndpointSelectorMismatch); - } - } - - Ok(()) -} - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - - use ddns::core::{MdnsPacket, parser::record::endpoint::EndpointAddr}; - use dhttp_identity::identity::RemoteAuthority; - use rustls::pki_types::CertificateDer; - - use super::*; - - #[derive(Debug)] - struct TestAuthority { - certs: Vec>, - } - - impl TestAuthority { - fn valid() -> Self { - Self { - certs: vec![CertificateDer::from( - include_bytes!("../../../tests/fixtures/valid.der").to_vec(), - )], - } - } - - fn missing_ski() -> Self { - Self { - certs: vec![CertificateDer::from( - include_bytes!("../../../tests/fixtures/missing.der").to_vec(), - )], - } - } - - fn malformed_ski() -> Self { - Self { - certs: vec![CertificateDer::from( - include_bytes!("../../../tests/fixtures/malformed.der").to_vec(), - )], - } - } - } - - impl RemoteAuthority for TestAuthority { - fn name(&self) -> &str { - "authority.example" - } - - fn cert_chain(&self) -> &[CertificateDer<'static>] { - &self.certs - } - } - - fn packet_with_endpoint(endpoint: EndpointAddr) -> Vec { - let hosts: HashMap> = - HashMap::from([("reimu.pilot.dhttp.net".to_owned(), vec![endpoint])]); - MdnsPacket::answer(0, &hosts).to_bytes() - } - - #[test] - fn validate_dns_packet_accepts_matching_certificate_selector() { - let mut endpoint = EndpointAddr::direct_v4("192.0.2.10:4433".parse().unwrap()); - endpoint.set_main(true); - endpoint.set_sequence(0); - let packet = packet_with_endpoint(endpoint); - - let validated = validate_dns_packet(&packet, false, &TestAuthority::valid()).unwrap(); - - assert!(matches!(validated, ValidatedDnsPacket::Records { .. })); - } - - #[test] - fn validate_dns_packet_rejects_mismatched_endpoint_kind() { - let mut endpoint = EndpointAddr::direct_v4("192.0.2.10:4433".parse().unwrap()); - endpoint.set_main(false); - endpoint.set_sequence(0); - let packet = packet_with_endpoint(endpoint); - - let error = validate_dns_packet(&packet, false, &TestAuthority::valid()).unwrap_err(); - - assert!(matches!(error, AppError::EndpointSelectorMismatch)); - } - - #[test] - fn validate_dns_packet_rejects_mismatched_endpoint_sequence() { - let mut endpoint = EndpointAddr::direct_v4("192.0.2.10:4433".parse().unwrap()); - endpoint.set_main(true); - endpoint.set_sequence(7); - let packet = packet_with_endpoint(endpoint); - - let error = validate_dns_packet(&packet, false, &TestAuthority::valid()).unwrap_err(); - - assert!(matches!(error, AppError::EndpointSelectorMismatch)); - } - - #[test] - fn validate_dns_packet_rejects_missing_publisher_ski() { - let mut endpoint = EndpointAddr::direct_v4("192.0.2.10:4433".parse().unwrap()); - endpoint.set_main(true); - let packet = packet_with_endpoint(endpoint); - - let error = validate_dns_packet(&packet, false, &TestAuthority::missing_ski()).unwrap_err(); - - assert!(matches!( - error, - AppError::PublisherCertificateSelector { .. } - )); - } - - #[test] - fn validate_dns_packet_rejects_malformed_publisher_ski() { - let mut endpoint = EndpointAddr::direct_v4("192.0.2.10:4433".parse().unwrap()); - endpoint.set_main(true); - let packet = packet_with_endpoint(endpoint); - - let error = - validate_dns_packet(&packet, false, &TestAuthority::malformed_ski()).unwrap_err(); - - assert!(matches!( - error, - AppError::PublisherCertificateSelector { .. } - )); - } - - #[test] - fn validate_dns_packet_accepts_empty_packet_as_clear_operation() { - let hosts: HashMap> = - HashMap::from([("reimu.pilot.dhttp.net".to_owned(), Vec::new())]); - let packet = MdnsPacket::answer(0, &hosts).to_bytes(); - - let validated = validate_dns_packet(&packet, true, &TestAuthority::valid()).unwrap(); - - assert!(matches!(validated, ValidatedDnsPacket::Empty)); - } -} diff --git a/src/bin/ddns-server/publish.rs b/src/bin/ddns-server/publish.rs deleted file mode 100644 index 5b07347..0000000 --- a/src/bin/ddns-server/publish.rs +++ /dev/null @@ -1,427 +0,0 @@ -use std::{convert::Infallible, sync::Arc}; - -use deadpool_redis::redis::{self, AsyncCommands}; -use dhttp_identity::identity::RemoteAuthority; -use h3x::{connection::ConnectionState, quic}; -use http_body_util::BodyExt; -use tokio::time::{Duration, Instant}; -use tracing::{debug, info, warn}; - -use crate::{ - error::{AppError, normalize_host, parse_query_params}, - lookup::{Request, Response, body_response, write_error}, - policy::{DomainPolicy, ValidatedDnsPacket, client_allowed_host, validate_dns_packet}, - storage::{ - AppState, Record, Storage, StoredRecord, cert_fingerprint, cert_fingerprint_hex, - unix_now_secs, - }, -}; - -// --------------------------------------------------------------------------- -// PublishSvc -// --------------------------------------------------------------------------- - -#[derive(Clone)] -pub struct PublishSvc { - pub state: AppState, -} - -impl PublishSvc { - pub fn call( - &self, - request: Request, - ) -> impl Future> + Send + 'static { - let state = self.state.clone(); - async move { Ok(publish_with_cert(state, request).await) } - } -} - -async fn publish_with_cert(state: AppState, request: Request) -> Response { - debug!("received publish request"); - - let params = parse_query_params(request.uri()); - debug!("query params: {:?}", params); - - let Some(host) = params.get("host") else { - warn!("missing host parameter"); - return write_error(AppError::MissingHostParam); - }; - - let host = match normalize_host(host) { - Ok(h) => h, - Err(e) => return write_error(e), - }; - debug!(host = %host, "publish.host"); - - // Require a valid client certificate for all publish requests. - let authority = match request_connection(&request) { - Some(connection) => match connection.remote_authority().await { - Ok(Some(authority)) => authority, - Ok(None) => { - warn!("missing client certificate"); - return write_error(AppError::MissingClientCertificate); - } - Err(error) => { - warn!(error = %snafu::Report::from_error(&error), "failed to read client certificate"); - return write_error(AppError::MissingClientCertificate); - } - }, - None => { - warn!("missing client certificate"); - return write_error(AppError::MissingClientCertificate); - } - }; - - let policy = state.policies.policy_for(&host).clone(); - - // Standard policy: cert SAN must match the target host. - // OpenMulti policy: any authenticated node may publish — skip SAN check. - if policy == DomainPolicy::Standard { - let allowed = match client_allowed_host(authority.as_ref()) { - Ok(h) => h, - Err(e) => { - warn!(error = %snafu::Report::from_error(&e), "client certificate domain not allowed"); - return write_error(e); - } - }; - if allowed != host { - warn!(allowed = %allowed, requested = %host, "publish.host_mismatch"); - return write_error(AppError::HostMismatch); - } - } - - let body = match request.into_body().collect().await { - Ok(body) => body.to_bytes(), - Err(e) => { - warn!(error = %snafu::Report::from_error(&e), "failed to read request body"); - return write_error(AppError::InvalidDnsPacket { - message: e.to_string(), - }); - } - }; - - // Validate DNS packet; signature check only for Standard hosts. - let require_sig = policy == DomainPolicy::Standard && state.require_signature; - debug!( - host = %host, - bytes = body.len(), - require_signature = require_sig, - "validating publish packet" - ); - let packet = match validate_dns_packet(body.as_ref(), require_sig, authority.as_ref()) { - Ok(n) => n, - Err(e) => { - debug!(host = %host, error = %e, "publish packet rejected"); - return write_error(e); - } - }; - - match packet { - ValidatedDnsPacket::Records { host: packet_name } => { - let packet_host = match normalize_host(&packet_name) { - Ok(h) => h, - Err(e) => return write_error(e), - }; - - if packet_host != host { - return write_error(AppError::HostMismatch); - } - - publish_record(&state, &host, &body, authority.as_ref()).await - } - ValidatedDnsPacket::Empty => clear_record(&state, &host, authority.as_ref()).await, - } -} - -fn request_connection(request: &Request) -> Option>> { - request - .extensions() - .get::>>() - .cloned() -} - -/// Unified publish handler: stores the record keyed by (host, cert-fingerprint). -/// Both Standard and OpenMulti policies follow the same storage path; -/// the only policy difference (SAN check) is already enforced in the caller. -/// -/// Certificate fingerprint is the publish-source identity. In PKI ecosystems, -/// a single domain name can have multiple valid certificates (from different CAs, -/// or issued at different times for rotation/failover/multi-region scenarios). -/// Using fingerprint as part of the storage key enables: -/// - Multi-publisher coexistence: different cert holders can publish the same domain -/// - Idempotent updates: re-publishing from same cert source (same fingerprint) overwrites old data -/// - Client choice: lookups return all active records, client picks which certificate to trust -pub async fn publish_record( - state: &AppState, - host: &str, - body: &bytes::Bytes, - authority: &(impl RemoteAuthority + ?Sized), -) -> Response { - let cert_bytes = authority - .cert_chain() - .first() - .map(|c| c.as_ref().to_vec()) - .unwrap_or_default(); - - let fp = cert_fingerprint(&cert_bytes); - let fp_hex = cert_fingerprint_hex(&cert_bytes); - - match &state.storage { - Storage::Redis(pool) => { - let mut conn = match pool.get().await { - Ok(c) => c, - Err(e) => { - return write_error(AppError::Redis { - message: e.to_string(), - }); - } - }; - let ttl_secs = state.ttl_secs; - let expire_ttl_secs = i64::try_from(state.ttl_secs).unwrap_or(i64::MAX); - let now_secs = unix_now_secs(); - let expire_secs = now_secs + state.ttl_secs; - - let fp_key = format!("{host}:fp:{fp_hex}"); - let set_key = format!("{host}:multi"); - - // Remove the previous entry from this source (if any) from the ZSET. - let old_member: Option> = conn.get(&fp_key).await.unwrap_or(None); - if let Some(old) = old_member { - let _: () = conn.zrem(&set_key, &old).await.unwrap_or(()); - } - - // Encode and store the new member. - let new_member = StoredRecord { - expire_unix_secs: expire_secs, - fingerprint: fp, - dns: body.to_vec(), - cert: cert_bytes.clone(), - } - .encode(); - - if let Err(e) = conn - .set_ex::<_, _, ()>(&fp_key, &new_member, ttl_secs) - .await - { - return write_error(AppError::Redis { - message: e.to_string(), - }); - } - - if let Err(e) = conn - .zadd::<_, _, _, ()>(&set_key, &new_member, now_secs as f64) - .await - { - return write_error(AppError::Redis { - message: e.to_string(), - }); - } - - // Expire the ZSET key at max(ttl_secs) from now as a safety net. - let _: bool = conn - .expire(&set_key, expire_ttl_secs) - .await - .unwrap_or(false); - - // Evict stale (score < now - ttl) entries. - let cutoff = now_secs.saturating_sub(state.ttl_secs) as f64; - let _: () = redis::cmd("ZREMRANGEBYSCORE") - .arg(&set_key) - .arg("-inf") - .arg(cutoff) - .query_async::<()>(&mut *conn) - .await - .unwrap_or(()); - } - Storage::Memory(mem) => { - let now = Instant::now(); - let expire = now + Duration::from_secs(state.ttl_secs); - let record = Record { - dns_bytes: body.to_vec(), - cert_bytes, - expire, - published_at: now, - }; - // Upsert by fingerprint: same source overwrites its own entry; - // different sources (different certs) coexist independently. - let mut host_map = mem.records.entry(host.to_string()).or_default(); - host_map.insert(fp, record); - // Evict expired entries while we hold the write lock. - host_map.retain(|_, r| r.expire > now); - } - } - - info!(host = %host, ttl = state.ttl_secs, bytes = body.len(), fp = %fp_hex, "publish.ok"); - body_response(http::StatusCode::OK, bytes::Bytes::from_static(b"OK")) -} - -pub async fn clear_record( - state: &AppState, - host: &str, - authority: &(impl RemoteAuthority + ?Sized), -) -> Response { - let cert_bytes = authority - .cert_chain() - .first() - .map(|c| c.as_ref().to_vec()) - .unwrap_or_default(); - - let fp = cert_fingerprint(&cert_bytes); - let fp_hex = cert_fingerprint_hex(&cert_bytes); - - match &state.storage { - Storage::Redis(pool) => { - let mut conn = match pool.get().await { - Ok(c) => c, - Err(e) => { - return write_error(AppError::Redis { - message: e.to_string(), - }); - } - }; - - let fp_key = format!("{host}:fp:{fp_hex}"); - let set_key = format!("{host}:multi"); - - let old_member: Option> = conn.get(&fp_key).await.unwrap_or(None); - if let Some(old) = old_member { - let _: () = conn.zrem(&set_key, &old).await.unwrap_or(()); - } - if let Err(e) = conn.del::<_, ()>(&fp_key).await { - return write_error(AppError::Redis { - message: e.to_string(), - }); - } - } - Storage::Memory(mem) => { - let remove_host = if let Some(mut host_map) = mem.records.get_mut(host) { - host_map.remove(&fp); - host_map.is_empty() - } else { - false - }; - if remove_host { - mem.records.remove(host); - } - } - } - - info!(host = %host, fp = %fp_hex, "publish.clear"); - body_response(http::StatusCode::OK, bytes::Bytes::from_static(b"OK")) -} - -#[cfg(test)] -mod tests { - use std::{ - collections::HashMap, - net::{Ipv4Addr, SocketAddrV4}, - sync::Arc, - }; - - use ddns::core::{MdnsPacket, parser::record::endpoint::EndpointAddr}; - use dhttp_identity::identity::RemoteAuthority; - use rustls::pki_types::CertificateDer; - - use super::*; - use crate::{ - lookup::{LookupResult, perform_lookup}, - policy::DomainPolicies, - storage::{MemoryStorage, SeedRecords}, - }; - - #[derive(Debug)] - struct TestAuthority { - name: &'static str, - certs: Vec>, - } - - impl TestAuthority { - fn new(name: &'static str, cert_bytes: Vec) -> Self { - Self { - name, - certs: vec![CertificateDer::from(cert_bytes)], - } - } - } - - impl RemoteAuthority for TestAuthority { - fn name(&self) -> &str { - self.name - } - - fn cert_chain(&self) -> &[CertificateDer<'static>] { - &self.certs - } - } - - fn memory_state() -> AppState { - AppState { - storage: Storage::Memory(MemoryStorage::new()), - require_signature: true, - ttl_secs: 30, - policies: Arc::new(DomainPolicies::default()), - seed_records: SeedRecords::default(), - } - } - - fn packet_for(host: &str, last_octet: u8) -> bytes::Bytes { - let endpoint = EndpointAddr::direct_v4(SocketAddrV4::new( - Ipv4Addr::new(203, 0, 113, last_octet), - 4433, - )); - let mut hosts = HashMap::new(); - hosts.insert(host.to_owned(), vec![endpoint]); - bytes::Bytes::from(MdnsPacket::answer(0, &hosts).to_bytes()) - } - - #[tokio::test] - async fn clear_record_removes_only_current_certificate_fingerprint() { - let state = memory_state(); - let host = "reimu.pilot.dhttp.net"; - let authority_a = TestAuthority::new("authority-a", vec![1]); - let authority_b = TestAuthority::new("authority-b", vec![2]); - let packet_a = packet_for(host, 1); - let packet_b = packet_for(host, 2); - - assert_eq!( - publish_record(&state, host, &packet_a, &authority_a) - .await - .status(), - http::StatusCode::OK - ); - assert_eq!( - publish_record(&state, host, &packet_b, &authority_b) - .await - .status(), - http::StatusCode::OK - ); - - assert_eq!( - clear_record(&state, host, &authority_a).await.status(), - http::StatusCode::OK - ); - - let LookupResult::Multi(response) = perform_lookup(&state, host, None).await.unwrap() - else { - panic!("authority b record should remain"); - }; - assert_eq!(response.records.len(), 1); - assert_eq!(response.records[0].cert, authority_b.certs[0].as_ref()); - } - - #[tokio::test] - async fn clear_record_is_idempotent_for_missing_fingerprint() { - let state = memory_state(); - let host = "reimu.pilot.dhttp.net"; - let authority = TestAuthority::new("authority", vec![1]); - - assert_eq!( - clear_record(&state, host, &authority).await.status(), - http::StatusCode::OK - ); - assert!(matches!( - perform_lookup(&state, host, None).await.unwrap(), - LookupResult::NotFound - )); - } -} diff --git a/src/bin/ddns-server/storage.rs b/src/bin/ddns-server/storage.rs deleted file mode 100644 index e194faf..0000000 --- a/src/bin/ddns-server/storage.rs +++ /dev/null @@ -1,184 +0,0 @@ -use std::{ - collections::HashMap, - sync::Arc, - time::{SystemTime, UNIX_EPOCH}, -}; - -use bytes::BufMut; -use dashmap::DashMap; -use deadpool_redis::Pool; -use nom::{ - IResult, - bytes::streaming::take, - number::streaming::{be_u32, be_u64}, -}; -use tokio::time::Instant; - -use crate::policy::DomainPolicies; - -// --------------------------------------------------------------------------- -// Storage helpers -// --------------------------------------------------------------------------- - -/// SHA-256 fingerprint of a DER-encoded certificate, used as per-source dedup key. -pub fn cert_fingerprint(cert_der: &[u8]) -> [u8; 32] { - use ring::digest::{SHA256, digest}; - let d = digest(&SHA256, cert_der); - d.as_ref().try_into().expect("SHA-256 is always 32 bytes") -} - -pub fn cert_fingerprint_hex(cert_der: &[u8]) -> String { - cert_fingerprint(cert_der) - .iter() - .map(|b| format!("{b:02x}")) - .collect() -} - -pub fn unix_now_secs() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|d| d.as_secs()) - .unwrap_or(0) -} - -// --------------------------------------------------------------------------- -// Redis ZSET member wire type -// --------------------------------------------------------------------------- - -/// One record as persisted in the Redis ZSET (or decoded from it). -/// -/// Wire layout (big-endian, contiguous): -/// ```text -/// +-----------+--------------+-----------+------+-----------+------+ -/// | expire | fingerprint | dns_len | dns | cert_len | cert | -/// | u64 BE | 32 bytes | u32 BE | ... | u32 BE | ... | -/// +-----------+--------------+-----------+------+-----------+------+ -/// ``` -#[derive(Debug, Clone)] -pub struct StoredRecord { - /// Unix timestamp (seconds) after which this entry is considered stale. - pub expire_unix_secs: u64, - /// SHA-256 fingerprint of the publisher's leaf certificate. - /// Serves as the publisher's identity: uniquely identifies a certificate among multiple - /// valid certs that may be issued for the same domain (from different CAs, at different times, - /// for different regions, etc.). Used as storage key to enable multi-publisher scenarios. - pub fingerprint: [u8; 32], - /// Serialised DNS packet bytes. - pub dns: Vec, - /// DER-encoded leaf certificate of the publisher. - pub cert: Vec, -} - -impl StoredRecord { - pub fn encoding_size(&self) -> usize { - 8 + 32 + 4 + self.dns.len() + 4 + self.cert.len() - } - - /// Encode to a byte buffer suitable for use as a Redis ZSET member. - pub fn encode(&self) -> Vec { - let mut buf = Vec::with_capacity(self.encoding_size()); - buf.put_stored_record(self); - buf - } - - /// Decode from a Redis ZSET member. Returns `None` on malformed input. - pub fn decode(data: &[u8]) -> Option { - be_stored_record(data).ok().map(|(_, r)| r) - } -} - -/// `BufMut` write extension for [`StoredRecord`]. -pub trait WriteStoredRecord { - fn put_stored_record(&mut self, record: &StoredRecord); -} - -impl WriteStoredRecord for B { - fn put_stored_record(&mut self, record: &StoredRecord) { - self.put_u64(record.expire_unix_secs); - self.put_slice(&record.fingerprint); - self.put_u32(record.dns.len() as u32); - self.put_slice(&record.dns); - self.put_u32(record.cert.len() as u32); - self.put_slice(&record.cert); - } -} - -/// nom parser for [`StoredRecord`]. -pub fn be_stored_record(input: &[u8]) -> IResult<&[u8], StoredRecord> { - let (input, expire_unix_secs) = be_u64(input)?; - let (input, fp_bytes) = take(32usize)(input)?; - let (input, dns_len) = be_u32(input)?; - let (input, dns) = take(dns_len as usize)(input)?; - let (input, cert_len) = be_u32(input)?; - let (input, cert) = take(cert_len as usize)(input)?; - Ok(( - input, - StoredRecord { - expire_unix_secs, - fingerprint: fp_bytes.try_into().expect("took exactly 32 bytes"), - dns: dns.to_vec(), - cert: cert.to_vec(), - }, - )) -} - -// --------------------------------------------------------------------------- -// Storage -// --------------------------------------------------------------------------- - -/// A single record stored under a (host, server-fingerprint) key. -#[derive(Clone, Debug)] -pub struct Record { - pub dns_bytes: Vec, - pub cert_bytes: Vec, - /// Wall-clock expiry (for TTL eviction). - pub expire: Instant, - /// When this record was last published (for newest-first ordering). - pub published_at: Instant, -} - -/// Unified in-memory storage: host → { cert_fingerprint → Record }. -/// Both Standard and OpenMulti policies share this map. -/// -/// Per-fingerprint keying design supports PKI's multi-certificate model: -/// A single domain can have multiple valid certificates issued by different CAs, -/// or by the same CA at different times (certificate rotation, multi-region deployment, etc.). -/// Each certificate has a unique fingerprint as its identity. -/// -/// - Same certificate (same fingerprint) republishing → overwrites the previous record -/// - Different certificates (different fingerprints) for same domain → coexist independently -/// - Clients query get all valid records and choose which one to use -#[derive(Clone)] -pub struct MemoryStorage { - pub records: Arc>>, -} - -impl MemoryStorage { - pub fn new() -> Self { - Self { - records: Arc::new(DashMap::new()), - } - } -} - -#[derive(Clone)] -pub enum Storage { - Redis(Pool), - Memory(MemoryStorage), -} - -pub type LookupRecord = (Vec, Vec); -pub type SeedRecords = Arc>>; - -// --------------------------------------------------------------------------- -// Application state -// --------------------------------------------------------------------------- - -#[derive(Clone)] -pub struct AppState { - pub storage: Storage, - pub require_signature: bool, - pub ttl_secs: u64, - pub policies: Arc, - pub seed_records: SeedRecords, -} diff --git a/src/core.rs b/src/core.rs index 308bbb2..936cfd9 100644 --- a/src/core.rs +++ b/src/core.rs @@ -1,4 +1,5 @@ pub mod parser; +pub mod signature; pub mod wire; pub type MdnsEndpoint = parser::record::endpoint::EndpointAddr; diff --git a/src/core/parser/record/endpoint.rs b/src/core/parser/record/endpoint.rs index 75167b5..320352a 100644 --- a/src/core/parser/record/endpoint.rs +++ b/src/core/parser/record/endpoint.rs @@ -28,19 +28,12 @@ use crate::core::parser::{ #[derive(Debug, Snafu)] #[snafu(module)] pub enum SignEndpointError { - #[snafu(display("failed to determine endpoint signature scheme"))] - SignatureScheme { source: sigin::SignatureSchemeError }, #[snafu(display("failed to sign endpoint address"))] Sign { source: dhttp_identity::identity::SignError, }, -} - -#[derive(Debug, Snafu)] -#[snafu(module)] -pub enum EndpointSelectorError { - #[snafu(display("endpoint record sequence does not fit certificate sequence"))] - SequenceTooLarge { sequence: u64 }, + #[snafu(display("no supported signature scheme for endpoint address"))] + NoSupportedScheme, } /// EndpointAddress record (Type E = 266) @@ -54,7 +47,7 @@ pub enum EndpointSelectorError { /// +-------+-----------------+--------------------+----------------+----------------------------+ /// | flags | sequence(varint)| addr | load(optional) | signature (optional) | /// +-------+-----------------+--------------------+----------------+----------------------------+ -/// | u8 | QUIC varint | see addr layout | f32 | scheme(u16)+len(varint)+N | +/// | u8 | QUIC varint | see addr layout | f32 | scheme(u16)+len(varint)+N | /// +-------+-----------------+--------------------+----------------+----------------------------+ /// /// addr layout: @@ -92,9 +85,9 @@ pub struct EndpointSignature { #[derive(Debug, Clone)] pub struct EndpointAddr { flags: u8, - /// Device sequence number used when multiple hosts share a domain (CLUSTERED). + /// Certificate-chain sequence used when multiple hosts share a domain (CLUSTERED). /// None means no sequence number. - sequence: Option, + sequence: Option, /// 1-minute load average (present when LOAD flag is set) load: Option, signature: Option, @@ -230,12 +223,17 @@ impl EndpointAddr { ) -> Result<(), SignEndpointError> { self.set_signed(true); let data = self.signed_data(); - let scheme = sigin::signature_scheme(authority.public_key()) - .context(sign_endpoint_error::SignatureSchemeSnafu)?; + + let scheme = authority + .cert_chain() + .first() + .and_then(|_| sigin::canonical_scheme_for_spki(authority.public_key())) + .ok_or(SignEndpointError::NoSupportedScheme)?; let signature = authority .sign(&data) .await .context(sign_endpoint_error::SignSnafu)?; + self.signature = Some(EndpointSignature { scheme: u16::from(scheme), signature, @@ -336,7 +334,7 @@ impl EndpointAddr { // sequence is only encoded when CLUSTERED flag is set if let Some(seq) = &self.sequence { - meta_len += seq.encoding_size(); + meta_len += VarInt::from_u32(seq.get()).encoding_size(); } if self.load.is_some() { @@ -369,9 +367,18 @@ impl EndpointAddr { self.agent } - pub fn set_sequence(&mut self, sequence: u64) { - if sequence > 0 { - self.sequence = Some(VarInt::from_u64(sequence).expect("Sequence too large")); + pub fn sequence(&self) -> Option { + self.sequence + } + + pub fn normalized_sequence(&self) -> CertificateSequence { + self.sequence + .unwrap_or_else(|| CertificateSequence::from(0u8)) + } + + pub fn set_sequence(&mut self, sequence: CertificateSequence) { + if sequence.get() > 0 { + self.sequence = Some(sequence); self.set_clustered(true); } else { self.sequence = None; @@ -379,26 +386,13 @@ impl EndpointAddr { } } - pub fn certificate_chain_key(&self) -> Result { + pub fn certificate_chain_key(&self) -> CertificateChainKey { let kind = if self.is_main() { CertificateChainKind::Primary } else { CertificateChainKind::Secondary }; - let sequence = self.sequence.map(VarInt::into_inner).unwrap_or(0); - if sequence > u64::from(u32::MAX) { - return endpoint_selector_error::SequenceTooLargeSnafu { sequence }.fail(); - } - let sequence = sequence as u32; - Ok(CertificateChainKey::new( - CertificateSequence::from(sequence), - kind, - )) - } - - pub fn set_certificate_chain_key(&mut self, chain: &CertificateChainKey) { - self.set_main(chain.kind() == CertificateChainKind::Primary); - self.set_sequence(u64::from(chain.sequence().get())); + CertificateChainKey::new(self.normalized_sequence(), kind) } pub fn load(&self) -> Option { @@ -428,7 +422,7 @@ impl EndpointAddr { // Sequence is only written when CLUSTERED is set if let Some(seq) = &self.sequence { - buf.put_varint(*seq); + buf.put_varint(VarInt::from_u32(seq.get())); } // Write primary address @@ -489,7 +483,13 @@ pub fn be_endpoint_addr(input: &[u8]) -> nom::IResult<&[u8], EndpointAddr> { // Sequence number is only present when CLUSTERED is set let (remain, sequence) = if is_clustered { let (remain, seq) = be_varint(remain)?; - (remain, Some(seq)) + let sequence = match CertificateSequence::try_from(seq.into_inner()) { + Ok(sequence) => sequence, + Err(_error) => { + return Err(nom::Err::Failure(make_error(remain, ErrorKind::TooLarge))); + } + }; + (remain, Some(sequence)) } else { (remain, None) }; @@ -779,6 +779,20 @@ impl TryFrom for DquicEndpointAddr { } } +pub async fn sign_endponit_address( + server_id: u8, + authority: Option<&(impl dhttp_identity::identity::LocalAuthority + ?Sized)>, + endpoint: DquicEndpointAddr, +) -> Option { + let mut ep: EndpointAddr = endpoint.try_into().ok()?; + ep.set_main(server_id == 0); + ep.set_sequence(CertificateSequence::from(server_id)); + if let Some(authority) = authority { + let _ = ep.sign_with_authority(authority).await; + } + Some(ep) +} + #[cfg(test)] mod tests { use std::{ @@ -787,76 +801,62 @@ mod tests { }; use bytes::BytesMut; - use dhttp_identity::certificate::{ - CertificateChainKey, CertificateChainKind, CertificateSequence, - }; use futures::future::BoxFuture; use ring::signature::KeyPair; - use rustls::{ - SignatureScheme, - sign::{Signer, SigningKey}, - }; + use rustls::sign::{Signer, SigningKey}; use super::*; - fn chain(sequence: u32, kind: CertificateChainKind) -> CertificateChainKey { - CertificateChainKey::new(CertificateSequence::from(sequence), kind) - } - - fn ed25519_spki(public_key: &[u8]) -> Vec { - let mut spki = Vec::with_capacity(44); - spki.extend_from_slice(&[ - 0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00, - ]); - spki.extend_from_slice(public_key); - spki + fn v4_outer() -> SocketAddrV4 { + SocketAddrV4::new(Ipv4Addr::new(203, 0, 113, 10), 4433) } #[test] - fn endpoint_selector_normalizes_missing_sequence_to_primary_zero() { - let addr = SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 5353); - let mut endpoint = EndpointAddr::direct_v4(addr); + fn endpoint_certificate_chain_key_normalizes_missing_sequence() { + let mut endpoint = EndpointAddr::direct_v4(v4_outer()); endpoint.set_main(true); - let selector = endpoint - .certificate_chain_key() - .expect("missing sequence normalizes to selector"); + let key = endpoint.certificate_chain_key(); - assert_eq!(selector, chain(0, CertificateChainKind::Primary)); + assert_eq!( + key.kind(), + dhttp_identity::certificate::CertificateChainKind::Primary + ); + assert_eq!(key.sequence().get(), 0); + assert_eq!(key.to_string(), "primary:0"); } #[test] - fn endpoint_selector_normalizes_missing_sequence_to_secondary_zero() { - let addr = SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 2), 5353); - let endpoint = EndpointAddr::direct_v4(addr); - - let selector = endpoint - .certificate_chain_key() - .expect("missing sequence normalizes to selector"); + fn endpoint_certificate_chain_key_uses_present_sequence() { + let mut endpoint = EndpointAddr::direct_v4(v4_outer()); + endpoint.set_main(false); + endpoint.set_sequence( + dhttp_identity::certificate::CertificateSequence::try_from(7u32).unwrap(), + ); - assert_eq!(selector, chain(0, CertificateChainKind::Secondary)); - } + let key = endpoint.certificate_chain_key(); - #[test] - fn endpoint_selector_sets_primary_and_secondary_chains() { - let addr = SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 3), 5353); - let mut endpoint = EndpointAddr::direct_v4(addr); - - endpoint.set_certificate_chain_key(&chain(7, CertificateChainKind::Primary)); - assert!(endpoint.is_main()); - assert!(endpoint.is_clustered()); assert_eq!( - endpoint.certificate_chain_key().unwrap(), - chain(7, CertificateChainKind::Primary) + key.kind(), + dhttp_identity::certificate::CertificateChainKind::Secondary ); + assert_eq!(key.sequence().get(), 7); + assert_eq!(key.to_string(), "secondary:7"); + } - endpoint.set_certificate_chain_key(&chain(0, CertificateChainKind::Secondary)); - assert!(!endpoint.is_main()); - assert!(!endpoint.is_clustered()); - assert_eq!( - endpoint.certificate_chain_key().unwrap(), - chain(0, CertificateChainKind::Secondary) - ); + #[test] + fn endpoint_parser_rejects_over_range_certificate_sequence() { + let sequence = crate::core::parser::varint::VarInt::from_u64( + dhttp_identity::certificate::CertificateSequence::MAX as u64 + 1, + ) + .unwrap(); + let mut packet = BytesMut::new(); + packet.put_u8(EndpointAddr::FLAG_MAIN | EndpointAddr::FLAG_CLUSTERED); + packet.put_varint(sequence); + packet.put_u16(v4_outer().port()); + packet.put_slice(&v4_outer().ip().octets()); + + assert!(be_endpoint_addr(&packet).is_err()); } #[test] @@ -973,7 +973,7 @@ mod tests { // IPv4 direct, MAIN + CLUSTERED flags EndpointAddr { flags: EndpointAddr::FLAG_MAIN | EndpointAddr::FLAG_CLUSTERED, - sequence: Some(VarInt::from_u32(0)), + sequence: Some(CertificateSequence::from(0u8)), load: None, signature: None, primary: v4_outer.into(), @@ -982,7 +982,7 @@ mod tests { // IPv4 NAT, CLUSTERED flag EndpointAddr { flags: EndpointAddr::FLAG_NAT | EndpointAddr::FLAG_CLUSTERED, - sequence: Some(VarInt::from_u32(127)), + sequence: Some(CertificateSequence::try_from(127u32).unwrap()), load: None, signature: None, primary: v4_outer.into(), @@ -993,7 +993,7 @@ mod tests { flags: EndpointAddr::FLAG_FAMILY | EndpointAddr::FLAG_MAIN | EndpointAddr::FLAG_CLUSTERED, - sequence: Some(VarInt::from_u32(128)), + sequence: Some(CertificateSequence::try_from(128u32).unwrap()), load: None, signature: None, primary: v6_outer.into(), @@ -1004,7 +1004,7 @@ mod tests { flags: EndpointAddr::FLAG_FAMILY | EndpointAddr::FLAG_NAT | EndpointAddr::FLAG_CLUSTERED, - sequence: Some(VarInt::from_u64((1 << 62) - 1).unwrap()), + sequence: Some(CertificateSequence::try_from(16_384u32).unwrap()), load: None, signature: None, primary: v6_outer.into(), @@ -1025,122 +1025,12 @@ mod tests { } } - #[test] - fn signed_endpoint_accepts_scheme_inclusive_signature() { - let addr = SocketAddrV4::new(Ipv4Addr::new(10, 10, 0, 7), 20004); - let scheme = u16::from(SignatureScheme::ED25519); - let signature = vec![0xaa; 64]; - let sig_len = VarInt::try_from(signature.len() as u64).unwrap(); - - let mut buf = BytesMut::new(); - buf.put_u8(EndpointAddr::FLAG_SIGNED); - buf.put_socket_addr_v4(&addr); - buf.put_u16(scheme); - buf.put_varint(sig_len); - buf.extend_from_slice(&signature); - - let (remain, decoded) = be_endpoint_addr(&buf).unwrap(); - - assert!(remain.is_empty()); - assert!(decoded.is_signed()); - assert_eq!(decoded.addr(), SocketAddr::V4(addr)); - assert_eq!(decoded.signature().unwrap().signature, signature); - } - - #[test] - fn signed_endpoint_rejects_signature_without_scheme() { - let addr = SocketAddrV4::new(Ipv4Addr::new(10, 10, 0, 7), 20004); - let signature = vec![0xaa; 64]; - let sig_len = VarInt::try_from(signature.len() as u64).unwrap(); - - let mut buf = BytesMut::new(); - buf.put_u8(EndpointAddr::FLAG_SIGNED); - buf.put_socket_addr_v4(&addr); - buf.put_varint(sig_len); - buf.extend_from_slice(&signature); - - assert!(be_endpoint_addr(&buf).is_err()); - } - - #[test] - fn signed_endpoint_writes_actual_scheme_before_signature_length() { - #[derive(Debug)] - struct Ed25519Key { - keypair: Arc, - spki: Vec, - } - - #[derive(Debug)] - struct Ed25519Signer(Arc); - - impl Signer for Ed25519Signer { - fn sign(&self, message: &[u8]) -> Result, rustls::Error> { - Ok(self.0.sign(message).as_ref().to_vec()) - } - - fn scheme(&self) -> SignatureScheme { - SignatureScheme::ED25519 - } - } - - impl SigningKey for Ed25519Key { - fn choose_scheme(&self, offered: &[SignatureScheme]) -> Option> { - offered - .contains(&SignatureScheme::ED25519) - .then(|| Box::new(Ed25519Signer(self.keypair.clone())) as Box) - } - - fn algorithm(&self) -> rustls::SignatureAlgorithm { - rustls::SignatureAlgorithm::ED25519 - } - } - - impl dhttp_identity::identity::LocalAuthority for Ed25519Key { - fn name(&self) -> &str { - "authority.example" - } - - fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { - &[] - } - - fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { - SubjectPublicKeyInfoDer::from(self.spki.as_slice()) - } - - fn sign( - &self, - data: &[u8], - ) -> BoxFuture<'_, Result, dhttp_identity::identity::SignError>> { - let result = dhttp_identity::identity::sign_with_key(self, data); - Box::pin(std::future::ready(result)) - } - } - - let rng = ring::rand::SystemRandom::new(); - let pkcs8 = ring::signature::Ed25519KeyPair::generate_pkcs8(&rng).unwrap(); - let keypair = - Arc::new(ring::signature::Ed25519KeyPair::from_pkcs8(pkcs8.as_ref()).unwrap()); - let spki = ed25519_spki(keypair.public_key().as_ref()); - let key = Ed25519Key { keypair, spki }; - - let mut ep = EndpointAddr::direct_v4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 5353)); - futures::executor::block_on(ep.sign_with_authority(&key)).unwrap(); - - let mut buf = BytesMut::new(); - buf.put_endpoint_addr(&ep); - - let scheme_offset = 1 + 2 + 4; - let encoded_scheme = u16::from_be_bytes([buf[scheme_offset], buf[scheme_offset + 1]]); - assert_eq!(encoded_scheme, u16::from(SignatureScheme::ED25519)); - } - #[test] fn endpoint_signature_roundtrip_and_verify() { #[derive(Debug)] struct Ed25519Key { keypair: Arc, - spki: Vec, + cert_chain: Vec>, } #[derive(Debug)] @@ -1174,11 +1064,7 @@ mod tests { } fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { - &[] - } - - fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { - SubjectPublicKeyInfoDer::from(self.spki.as_slice()) + &self.cert_chain } fn sign( @@ -1194,10 +1080,14 @@ mod tests { let pkcs8 = ring::signature::Ed25519KeyPair::generate_pkcs8(&rng).unwrap(); let keypair = Arc::new(ring::signature::Ed25519KeyPair::from_pkcs8(pkcs8.as_ref()).unwrap()); - let spki = ed25519_spki(keypair.public_key().as_ref()); + let mut spki = Vec::with_capacity(44); + spki.extend_from_slice(&[ + 0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00, + ]); + spki.extend_from_slice(keypair.public_key().as_ref()); let key = Ed25519Key { keypair: keypair.clone(), - spki: spki.clone(), + cert_chain: vec![rustls::pki_types::CertificateDer::from(spki.clone())], }; let addr = SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 5353); @@ -1229,41 +1119,42 @@ mod tests { } #[test] - fn sign_with_authority_stores_canonical_signature() { + fn sign_with_authority_uses_canonical_scheme_from_public_key() { #[derive(Debug)] - struct StaticAuthority { - spki: Vec, + struct Ed25519Authority { + cert_chain: Vec>, } - impl dhttp_identity::identity::LocalAuthority for StaticAuthority { + impl dhttp_identity::identity::LocalAuthority for Ed25519Authority { fn name(&self) -> &str { "authority.example" } fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { - &[] - } - - fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { - SubjectPublicKeyInfoDer::from(self.spki.as_slice()) + &self.cert_chain } fn sign( &self, _data: &[u8], ) -> BoxFuture<'_, Result, dhttp_identity::identity::SignError>> { - Box::pin(std::future::ready(Ok(vec![1, 2, 3]))) + Box::pin(async move { Ok(vec![1, 2, 3]) }) } } + let cert_chain = vec![rustls::pki_types::CertificateDer::from(vec![ + 0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ])]; let mut ep = EndpointAddr::direct_v4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 5353)); - let authority = StaticAuthority { - spki: ed25519_spki(&[0; 32]), - }; - futures::executor::block_on(ep.sign_with_authority(&authority)).unwrap(); + futures::executor::block_on(ep.sign_with_authority(&Ed25519Authority { cert_chain })) + .unwrap(); let signature = ep.signature().unwrap(); - assert_eq!(signature.scheme, u16::from(SignatureScheme::ED25519)); + assert_eq!( + SignatureScheme::from(signature.scheme), + SignatureScheme::ED25519 + ); assert_eq!(signature.signature, vec![1, 2, 3]); } diff --git a/src/core/parser/sigin.rs b/src/core/parser/sigin.rs index 597d90e..70c7aa1 100644 --- a/src/core/parser/sigin.rs +++ b/src/core/parser/sigin.rs @@ -9,6 +9,105 @@ use x509_parser::{ x509::SubjectPublicKeyInfo, }; +pub const SIGNATURE_SCHEME_PREFERENCE: &[SignatureScheme] = &[ + SignatureScheme::ED25519, + SignatureScheme::ECDSA_NISTP256_SHA256, + SignatureScheme::ECDSA_NISTP384_SHA384, + SignatureScheme::RSA_PSS_SHA256, + SignatureScheme::RSA_PSS_SHA384, + SignatureScheme::RSA_PSS_SHA512, + SignatureScheme::RSA_PKCS1_SHA256, + SignatureScheme::RSA_PKCS1_SHA384, + SignatureScheme::RSA_PKCS1_SHA512, +]; + +pub fn signature_schemes_for_algorithm( + algorithm: rustls::SignatureAlgorithm, +) -> impl Iterator { + SIGNATURE_SCHEME_PREFERENCE + .iter() + .copied() + .filter(move |scheme| match algorithm { + rustls::SignatureAlgorithm::ED25519 => *scheme == SignatureScheme::ED25519, + rustls::SignatureAlgorithm::ECDSA => matches!( + scheme, + SignatureScheme::ECDSA_NISTP256_SHA256 | SignatureScheme::ECDSA_NISTP384_SHA384 + ), + rustls::SignatureAlgorithm::RSA => matches!( + scheme, + SignatureScheme::RSA_PSS_SHA256 + | SignatureScheme::RSA_PSS_SHA384 + | SignatureScheme::RSA_PSS_SHA512 + | SignatureScheme::RSA_PKCS1_SHA256 + | SignatureScheme::RSA_PKCS1_SHA384 + | SignatureScheme::RSA_PKCS1_SHA512 + ), + _ => true, + }) +} + +pub fn alg_name_for_scheme(scheme: SignatureScheme) -> Option<&'static str> { + match scheme { + SignatureScheme::ED25519 => Some("ed25519"), + SignatureScheme::ECDSA_NISTP256_SHA256 => Some("ecdsa-p256-sha256"), + SignatureScheme::ECDSA_NISTP384_SHA384 => Some("ecdsa-p384-sha384"), + SignatureScheme::RSA_PSS_SHA256 => Some("rsa-pss-sha256"), + SignatureScheme::RSA_PSS_SHA384 => Some("rsa-pss-sha384"), + SignatureScheme::RSA_PSS_SHA512 => Some("rsa-pss-sha512"), + SignatureScheme::RSA_PKCS1_SHA256 => Some("rsa-v1_5-sha256"), + SignatureScheme::RSA_PKCS1_SHA384 => Some("rsa-v1_5-sha384"), + SignatureScheme::RSA_PKCS1_SHA512 => Some("rsa-v1_5-sha512"), + _ => None, + } +} + +pub fn scheme_for_alg_name(alg: &str) -> Option { + match alg { + "ed25519" => Some(SignatureScheme::ED25519), + "ecdsa-p256-sha256" => Some(SignatureScheme::ECDSA_NISTP256_SHA256), + "ecdsa-p384-sha384" => Some(SignatureScheme::ECDSA_NISTP384_SHA384), + "rsa-pss-sha256" => Some(SignatureScheme::RSA_PSS_SHA256), + "rsa-pss-sha384" => Some(SignatureScheme::RSA_PSS_SHA384), + "rsa-pss-sha512" => Some(SignatureScheme::RSA_PSS_SHA512), + "rsa-v1_5-sha256" => Some(SignatureScheme::RSA_PKCS1_SHA256), + "rsa-v1_5-sha384" => Some(SignatureScheme::RSA_PKCS1_SHA384), + "rsa-v1_5-sha512" => Some(SignatureScheme::RSA_PKCS1_SHA512), + _ => None, + } +} + +pub fn canonical_scheme_for_spki(spki: SubjectPublicKeyInfoDer<'_>) -> Option { + let Ok((_remain, spki)) = SubjectPublicKeyInfo::from_der(spki.as_ref()) else { + return None; + }; + + if spki.algorithm.algorithm == OID_SIG_ED25519 { + return Some(SignatureScheme::ED25519); + } + + if spki.algorithm.algorithm == OID_PKCS1_RSAENCRYPTION { + return Some(SignatureScheme::RSA_PSS_SHA512); + } + + if spki.algorithm.algorithm != OID_KEY_TYPE_EC_PUBLIC_KEY { + return None; + } + + let curve = spki + .algorithm + .parameters + .as_ref() + .and_then(|parameters| parameters.as_oid().ok())?; + + if curve == OID_EC_P256 { + Some(SignatureScheme::ECDSA_NISTP256_SHA256) + } else if curve == OID_NIST_EC_P384 { + Some(SignatureScheme::ECDSA_NISTP384_SHA384) + } else { + None + } +} + #[derive(Debug, Snafu)] #[snafu(module)] pub enum SignError { @@ -48,6 +147,7 @@ pub fn sign_with_key(key: &(impl SigningKey + ?Sized), data: &[u8]) -> Result, ) -> Result { @@ -95,6 +195,11 @@ pub(crate) fn verify( SignatureScheme::ECDSA_NISTP384_SHA384 => &ring::signature::ECDSA_P384_SHA384_ASN1, SignatureScheme::ECDSA_NISTP256_SHA256 => &ring::signature::ECDSA_P256_SHA256_ASN1, SignatureScheme::ED25519 => &ring::signature::ED25519, + SignatureScheme::RSA_PKCS1_SHA256 => &ring::signature::RSA_PKCS1_2048_8192_SHA256, + SignatureScheme::RSA_PKCS1_SHA384 => &ring::signature::RSA_PKCS1_2048_8192_SHA384, + SignatureScheme::RSA_PKCS1_SHA512 => &ring::signature::RSA_PKCS1_2048_8192_SHA512, + SignatureScheme::RSA_PSS_SHA256 => &ring::signature::RSA_PSS_2048_8192_SHA256, + SignatureScheme::RSA_PSS_SHA384 => &ring::signature::RSA_PSS_2048_8192_SHA384, SignatureScheme::RSA_PSS_SHA512 => &ring::signature::RSA_PSS_2048_8192_SHA512, _ => return verify_error::UnsupportedSchemeSnafu { scheme }.fail(), }; diff --git a/src/core/signature.rs b/src/core/signature.rs new file mode 100644 index 0000000..d90f879 --- /dev/null +++ b/src/core/signature.rs @@ -0,0 +1,322 @@ +use std::time::{SystemTime, UNIX_EPOCH}; + +use base64::Engine; +use dhttp_identity::identity::{LocalAuthority, SignError as AuthoritySignError}; +use ring::digest::{SHA256, digest}; +use rustls::{SignatureScheme, pki_types::SubjectPublicKeyInfoDer}; +use snafu::Snafu; + +use crate::core::parser::sigin; + +pub const CONTENT_DIGEST_HEADER: &str = "content-digest"; +pub const SIGNATURE_INPUT_HEADER: &str = "signature-input"; +pub const SIGNATURE_HEADER: &str = "signature"; +pub const SIGNATURE_LABEL: &str = "dns"; + +const DIGEST_PREFIX: &str = "sha-256=:"; +const SIGNATURE_PREFIX: &str = "dns=:"; +const SIGNATURE_INPUT_PREFIX: &str = "dns=(\"content-digest\")"; + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct SignatureFields { + pub content_digest: Vec, + pub signature_input: Vec, + pub signature: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct ParsedSignatureInput<'a> { + signature_params: &'a str, + alg: &'a str, + keyid: &'a str, +} + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum SignatureFieldsError { + #[snafu(display("missing publisher certificate"))] + MissingCertificate, + #[snafu(display("unsupported signature scheme {scheme:?}"))] + UnsupportedScheme { scheme: SignatureScheme }, + #[snafu(display("unsupported signature algorithm {alg}"))] + UnsupportedAlgorithm { alg: String }, + #[snafu(display("invalid {field} field"))] + InvalidField { field: &'static str }, + #[snafu(display("invalid signature field utf-8"))] + InvalidUtf8 { source: std::str::Utf8Error }, + #[snafu(display("invalid base64"))] + InvalidBase64 { source: base64::DecodeError }, + #[snafu(display("content digest mismatch"))] + DigestMismatch, + #[snafu(display("signature keyid does not match publisher certificate"))] + KeyIdMismatch, + #[snafu(display("failed to sign DNS packet"))] + Sign { source: AuthoritySignError }, + #[snafu(display("invalid certificate: {details}"))] + InvalidCertificate { details: String }, + #[snafu(display("signature verification failed"))] + Verify { source: sigin::VerifyError }, +} + +impl SignatureFields { + pub fn empty() -> Self { + Self::default() + } + + pub fn is_empty(&self) -> bool { + self.content_digest.is_empty() + && self.signature_input.is_empty() + && self.signature.is_empty() + } + + pub async fn sign( + dns_bytes: &[u8], + authority: &(impl LocalAuthority + ?Sized), + ) -> Result { + let cert = authority + .cert_chain() + .first() + .ok_or(SignatureFieldsError::MissingCertificate)?; + let keyid = keyid_for_cert(cert.as_ref()); + let content_digest = content_digest_value(dns_bytes); + let created = unix_now_secs(); + + let scheme = sigin::canonical_scheme_for_spki(authority.public_key()).ok_or( + SignatureFieldsError::UnsupportedScheme { + scheme: SignatureScheme::Unknown(0), + }, + )?; + let alg = sigin::alg_name_for_scheme(scheme) + .ok_or(SignatureFieldsError::UnsupportedScheme { scheme })?; + let signature_input = signature_input_value(created, &keyid, alg); + let signature_base = signature_base(&content_digest, &signature_input)?; + let signature = authority + .sign(signature_base.as_bytes()) + .await + .map_err(|source| SignatureFieldsError::Sign { source })?; + let signature = signature_value(&signature); + + Ok(Self { + content_digest: content_digest.into_bytes(), + signature_input: signature_input.into_bytes(), + signature: signature.into_bytes(), + }) + } + + pub fn verify(&self, dns_bytes: &[u8], cert_der: &[u8]) -> Result { + if self.is_empty() { + return Ok(false); + } + + let content_digest = field_str(&self.content_digest)?; + verify_content_digest(content_digest, dns_bytes)?; + + let signature_input = field_str(&self.signature_input)?; + let parsed_input = parse_signature_input(signature_input)?; + let expected_keyid = keyid_for_cert(cert_der); + if parsed_input.keyid != expected_keyid { + return Ok(false); + } + + let scheme = sigin::scheme_for_alg_name(parsed_input.alg).ok_or_else(|| { + SignatureFieldsError::UnsupportedAlgorithm { + alg: parsed_input.alg.to_string(), + } + })?; + let signature = parse_signature(field_str(&self.signature)?)?; + let signature_base = signature_base(content_digest, signature_input)?; + + let (_, cert) = x509_parser::parse_x509_certificate(cert_der).map_err(|e| { + SignatureFieldsError::InvalidCertificate { + details: e.to_string(), + } + })?; + let spki = SubjectPublicKeyInfoDer::from(cert.tbs_certificate.subject_pki.raw); + sigin::verify(spki, scheme, signature_base.as_bytes(), &signature) + .map_err(|source| SignatureFieldsError::Verify { source }) + } +} + +pub fn content_digest_value(dns_bytes: &[u8]) -> String { + let digest = digest(&SHA256, dns_bytes); + let b64 = base64::engine::general_purpose::STANDARD.encode(digest.as_ref()); + format!("{DIGEST_PREFIX}{b64}:") +} + +pub fn cert_fingerprint_hex(cert_der: &[u8]) -> String { + digest(&SHA256, cert_der) + .as_ref() + .iter() + .map(|b| format!("{b:02x}")) + .collect() +} + +pub fn keyid_for_cert(cert_der: &[u8]) -> String { + format!("sha256:{}", cert_fingerprint_hex(cert_der)) +} + +fn unix_now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0) +} + +fn signature_input_value(created: u64, keyid: &str, alg: &str) -> String { + format!( + "{SIGNATURE_LABEL}=(\"content-digest\");created={created};keyid=\"{keyid}\";alg=\"{alg}\"" + ) +} + +fn signature_value(signature: &[u8]) -> String { + let b64 = base64::engine::general_purpose::STANDARD.encode(signature); + format!("{SIGNATURE_PREFIX}{b64}:") +} + +fn signature_base( + content_digest: &str, + signature_input: &str, +) -> Result { + let parsed = parse_signature_input(signature_input)?; + Ok(format!( + "\"content-digest\": {content_digest}\n\"@signature-params\": {}", + parsed.signature_params + )) +} + +fn field_str(field: &[u8]) -> Result<&str, SignatureFieldsError> { + std::str::from_utf8(field).map_err(|source| SignatureFieldsError::InvalidUtf8 { source }) +} + +fn verify_content_digest( + content_digest: &str, + dns_bytes: &[u8], +) -> Result<(), SignatureFieldsError> { + let encoded = content_digest + .strip_prefix(DIGEST_PREFIX) + .and_then(|rest| rest.strip_suffix(':')) + .ok_or(SignatureFieldsError::InvalidField { + field: CONTENT_DIGEST_HEADER, + })?; + let decoded = base64::engine::general_purpose::STANDARD + .decode(encoded) + .map_err(|source| SignatureFieldsError::InvalidBase64 { source })?; + if decoded.as_slice() != digest(&SHA256, dns_bytes).as_ref() { + return Err(SignatureFieldsError::DigestMismatch); + } + Ok(()) +} + +fn parse_signature(input: &str) -> Result, SignatureFieldsError> { + let encoded = input + .strip_prefix(SIGNATURE_PREFIX) + .and_then(|rest| rest.strip_suffix(':')) + .ok_or(SignatureFieldsError::InvalidField { + field: SIGNATURE_HEADER, + })?; + base64::engine::general_purpose::STANDARD + .decode(encoded) + .map_err(|source| SignatureFieldsError::InvalidBase64 { source }) +} + +fn parse_signature_input(input: &str) -> Result, SignatureFieldsError> { + if !input.starts_with(SIGNATURE_INPUT_PREFIX) { + return Err(SignatureFieldsError::InvalidField { + field: SIGNATURE_INPUT_HEADER, + }); + } + + let signature_params = + input + .strip_prefix("dns=") + .ok_or(SignatureFieldsError::InvalidField { + field: SIGNATURE_INPUT_HEADER, + })?; + let params = signature_params + .strip_prefix("(\"content-digest\")") + .ok_or(SignatureFieldsError::InvalidField { + field: SIGNATURE_INPUT_HEADER, + })?; + + let mut created = None; + let mut keyid = None; + let mut alg = None; + + for param in params.split(';').filter(|part| !part.is_empty()) { + if let Some(value) = param.strip_prefix("created=") { + created = value.parse::().ok(); + } else if let Some(value) = param.strip_prefix("keyid=") { + keyid = unquote(value); + } else if let Some(value) = param.strip_prefix("alg=") { + alg = unquote(value); + } else { + return Err(SignatureFieldsError::InvalidField { + field: SIGNATURE_INPUT_HEADER, + }); + } + } + + if created.is_none() { + return Err(SignatureFieldsError::InvalidField { + field: SIGNATURE_INPUT_HEADER, + }); + } + + let keyid = keyid.ok_or(SignatureFieldsError::InvalidField { + field: SIGNATURE_INPUT_HEADER, + })?; + let alg = alg.ok_or(SignatureFieldsError::InvalidField { + field: SIGNATURE_INPUT_HEADER, + })?; + + Ok(ParsedSignatureInput { + signature_params, + alg, + keyid, + }) +} + +fn unquote(value: &str) -> Option<&str> { + value.strip_prefix('"')?.strip_suffix('"') +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn content_digest_uses_sha256_dictionary_value() { + let value = content_digest_value(b"dns"); + assert!(value.starts_with("sha-256=:")); + assert!(value.ends_with(':')); + verify_content_digest(&value, b"dns").unwrap(); + assert!(matches!( + verify_content_digest(&value, b"changed"), + Err(SignatureFieldsError::DigestMismatch) + )); + } + + #[test] + fn signature_input_requires_alg_and_keyid() { + let input = "dns=(\"content-digest\");created=1;keyid=\"sha256:abc\";alg=\"ed25519\""; + let parsed = parse_signature_input(input).unwrap(); + assert_eq!(parsed.keyid, "sha256:abc"); + assert_eq!(parsed.alg, "ed25519"); + + assert!(parse_signature_input("dns=(\"content-digest\");created=1").is_err()); + assert!(parse_signature_input("dns=(\"date\");created=1;alg=\"ed25519\"").is_err()); + } + + #[test] + fn alg_names_are_explicitly_mapped() { + assert_eq!( + sigin::scheme_for_alg_name("ed25519"), + Some(SignatureScheme::ED25519) + ); + assert_eq!( + sigin::alg_name_for_scheme(SignatureScheme::ECDSA_NISTP256_SHA256), + Some("ecdsa-p256-sha256") + ); + assert_eq!(sigin::scheme_for_alg_name("unknown"), None); + } +} diff --git a/src/core/wire.rs b/src/core/wire.rs index 9d3f539..028e455 100644 --- a/src/core/wire.rs +++ b/src/core/wire.rs @@ -11,9 +11,14 @@ use bytes::BufMut; use nom::{IResult, bytes::streaming::take, number::streaming::be_u32}; +use crate::core::signature::SignatureFields; + /// One DNS + certificate pair inside a [`MultiResponse`]. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ResponseRecord { + /// RFC 9421/9530-style publisher signature fields. Empty for unsigned + /// OpenMulti or static seed records. + pub signature_fields: SignatureFields, /// Serialised DNS packet bytes. pub dns: Vec, /// DER-encoded leaf certificate of the publisher, or empty when unavailable. @@ -21,6 +26,18 @@ pub struct ResponseRecord { } impl ResponseRecord { + pub fn new(signature_fields: SignatureFields, dns: Vec, cert: Vec) -> Self { + Self { + signature_fields, + dns, + cert, + } + } + + pub fn unsigned(dns: Vec, cert: Vec) -> Self { + Self::new(SignatureFields::empty(), dns, cert) + } + /// SHA-256 fingerprint of the publisher certificate as lowercase hex. /// Returns `None` when the cert field is empty. pub fn cert_fingerprint_hex(&self) -> Option { @@ -40,12 +57,9 @@ pub struct MultiResponse { } impl MultiResponse { - pub fn new(iter: impl IntoIterator, Vec)>) -> Self { + pub fn new(iter: impl IntoIterator) -> Self { Self { - records: iter - .into_iter() - .map(|(dns, cert)| ResponseRecord { dns, cert }) - .collect(), + records: iter.into_iter().collect(), } } @@ -53,7 +67,17 @@ impl MultiResponse { 4 + self .records .iter() - .map(|record| 4 + record.dns.len() + 4 + record.cert.len()) + .map(|record| { + 4 + record.signature_fields.content_digest.len() + + 4 + + record.signature_fields.signature_input.len() + + 4 + + record.signature_fields.signature.len() + + 4 + + record.dns.len() + + 4 + + record.cert.len() + }) .sum::() } @@ -72,39 +96,69 @@ impl WriteMultiResponse for B { fn put_multi_response(&mut self, response: &MultiResponse) { self.put_u32(response.records.len() as u32); for record in &response.records { - self.put_u32(record.dns.len() as u32); - self.put_slice(&record.dns); - self.put_u32(record.cert.len() as u32); - self.put_slice(&record.cert); + put_field(self, &record.signature_fields.content_digest); + put_field(self, &record.signature_fields.signature_input); + put_field(self, &record.signature_fields.signature); + put_field(self, &record.dns); + put_field(self, &record.cert); } } } +fn put_field(buf: &mut B, value: &[u8]) { + buf.put_u32(value.len() as u32); + buf.put_slice(value); +} + pub fn be_multi_response(input: &[u8]) -> IResult<&[u8], MultiResponse> { let (mut input, count) = be_u32(input)?; let mut records = Vec::with_capacity(count as usize); for _ in 0..count { - let (rest, dns_len) = be_u32(input)?; - let (rest, dns) = take(dns_len as usize)(rest)?; - let (rest, cert_len) = be_u32(rest)?; - let (rest, cert) = take(cert_len as usize)(rest)?; - records.push(ResponseRecord { - dns: dns.to_vec(), - cert: cert.to_vec(), - }); + let (rest, content_digest) = be_field(input)?; + let (rest, signature_input) = be_field(rest)?; + let (rest, signature) = be_field(rest)?; + let (rest, dns) = be_field(rest)?; + let (rest, cert) = be_field(rest)?; + records.push(ResponseRecord::new( + SignatureFields { + content_digest, + signature_input, + signature, + }, + dns, + cert, + )); input = rest; } Ok((input, MultiResponse { records })) } +fn be_field(input: &[u8]) -> IResult<&[u8], Vec> { + let (input, len) = be_u32(input)?; + let (input, value) = take(len as usize)(input)?; + Ok((input, value.to_vec())) +} + #[cfg(test)] mod tests { use super::*; #[test] fn multi_response_roundtrips() { - let response = - MultiResponse::new([(vec![1, 2, 3], vec![4, 5]), (vec![6, 7, 8, 9], Vec::new())]); + let response = MultiResponse::new([ + ResponseRecord::new( + SignatureFields { + content_digest: b"sha-256=:abc:".to_vec(), + signature_input: + b"dns=(\"content-digest\");created=1;keyid=\"sha256:abc\";alg=\"ed25519\"" + .to_vec(), + signature: b"dns=:sig:".to_vec(), + }, + vec![1, 2, 3], + vec![4, 5], + ), + ResponseRecord::unsigned(vec![6, 7, 8, 9], Vec::new()), + ]); let encoded = response.encode(); let (remain, decoded) = be_multi_response(&encoded).unwrap(); assert!(remain.is_empty()); diff --git a/src/h3.rs b/src/h3.rs new file mode 100644 index 0000000..760e89b --- /dev/null +++ b/src/h3.rs @@ -0,0 +1,263 @@ +use std::{convert::Infallible, fmt, io, sync::Arc, time::Duration}; + +use dquic::qresolve::{Publish, PublishFuture, Resolve, ResolveFuture}; +use h3x::{ + dhttp::message::{MessageStreamError, hyper::client::RequestError as HyperRequestError}, + endpoint::H3Endpoint, + quic, +}; +use url::Url; + +mod cache; +mod lookup; +mod publish; +mod request; + +const LOOKUP_REQUEST_TIMEOUT: Duration = Duration::from_secs(3); +const LOOKUP_REQUEST_ATTEMPTS: usize = 3; + +pub struct H3Resolver +where + C: quic::Connect + quic::WithLocalAuthority, +{ + endpoint: Arc>, + base_url: Url, + cache: cache::LookupCache, +} + +impl fmt::Debug for H3Resolver +where + C: quic::Connect + quic::WithLocalAuthority, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("H3Resolver") + .field("base_url", &self.base_url) + .finish_non_exhaustive() + } +} + +impl fmt::Display for H3Resolver +where + C: quic::Connect + quic::WithLocalAuthority, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "H3 DNS Resolver({})", self.base_url) + } +} + +#[derive(Debug, snafu::Snafu)] +#[snafu(module)] +pub enum H3RequestError { + #[snafu(display("failed to connect h3 endpoint"))] + Connect { source: h3x::pool::ConnectError }, + #[snafu(display("h3 request error"))] + Request { + source: HyperRequestError, + }, +} + +#[derive(Debug, snafu::Snafu)] +#[snafu(module)] +pub enum H3PublishError { + #[snafu(transparent)] + Request { source: H3RequestError }, + #[snafu(display("anonymous h3 endpoint cannot sign dns publish request"))] + AnonymousEndpoint, + #[snafu(display("failed to get h3 endpoint local authority"))] + LocalAuthority { source: h3x::quic::ConnectionError }, + #[snafu(display("failed to sign h3 dns publish request"))] + SignRequest { + source: crate::core::signature::SignatureFieldsError, + }, + #[snafu(display("{status}"))] + Status { status: http::StatusCode }, +} + +#[derive(Debug, snafu::Snafu)] +#[snafu(module)] +pub enum H3LookupError { + #[snafu(transparent)] + Request { source: H3RequestError }, + #[snafu(display("h3 stream error"))] + H3Stream { source: MessageStreamError }, + #[snafu(display("h3 request timed out after {timeout:?}"))] + RequestTimeout { timeout: Duration }, + #[snafu(display("{status}"))] + Status { status: http::StatusCode }, + #[snafu(display("no DNS record found"))] + NoRecordFound, + #[snafu(display("failed to decode h3 dns lookup response"))] + Decode { source: LookupDecodeError }, +} + +#[derive(Debug, snafu::Snafu)] +#[snafu(module)] +pub enum LookupDecodeError { + #[snafu(display("failed to decode multi-record response"))] + MultiResponse, + #[snafu(display("failed to parse DNS records from response"))] + ParseRecords { + source: nom::Err>>, + }, +} + +impl H3Resolver +where + C: quic::Connect + quic::WithLocalAuthority + Send + Sync + 'static, + C::Error: Send + Sync + 'static, + C::Connection: Send + 'static, +{ + pub fn new( + base_url: impl AsRef, + client: H3Endpoint, + ) -> io::Result { + Self::from_endpoint(base_url, Arc::new(client)) + } + + pub fn from_endpoint( + base_url: impl AsRef, + endpoint: Arc>, + ) -> io::Result { + let base_url = Url::parse(base_url.as_ref()) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidInput, error))?; + base_url.host_str().ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "base URL must have a valid host", + ) + })?; + + Ok(Self { + endpoint, + base_url, + cache: cache::LookupCache::default(), + }) + } + + pub fn clear_pool(&self) { + self.endpoint.clear_pool(); + } +} + +impl Publish for H3Resolver +where + C: quic::Connect + quic::WithLocalAuthority + Send + Sync + 'static, + C::Error: Send + Sync + 'static, + C::Connection: Send + 'static, +{ + fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { + Box::pin(async move { + self.publish_packet(name, packet) + .await + .map_err(io::Error::other) + }) + } +} + +impl Resolve for H3Resolver +where + C: quic::Connect + quic::WithLocalAuthority + Send + Sync + 'static, + C::Error: Send + Sync + 'static, + C::Connection: Send + 'static, +{ + fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { + Box::pin(async move { + H3Resolver::lookup(self, name) + .await + .map_err(io::Error::other) + }) + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + #[cfg(feature = "dquic-network")] + use dquic::{qbase::net::addr::EndpointAddr, qresolve::Source}; + #[cfg(feature = "dquic-network")] + use futures::StreamExt; + + use super::*; + #[cfg(feature = "dquic-network")] + use crate::resolvers::DHTTP_H3_DNS_SERVER; + + #[test] + fn lookup_retry_budget_leaves_external_timeout_margin() { + let total_budget = LOOKUP_REQUEST_TIMEOUT * LOOKUP_REQUEST_ATTEMPTS as u32; + + assert!( + total_budget <= Duration::from_secs(10), + "h3 lookup must return before common 15s command timeouts so callers can retry" + ); + } + + #[cfg(feature = "dquic-network")] + #[tokio::test] + async fn cached_lookup_reports_h3_dns_source() { + let endpoint = Arc::new(h3x::endpoint::H3Endpoint::new( + h3x::dquic::QuicEndpoint::builder().build().await, + )); + let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); + resolver.cache.insert_positive( + "car.lab.dhttp.net", + vec![EndpointAddr::direct("192.168.5.78:41748".parse().unwrap())], + ); + + let mut records = resolver.lookup("car.lab.dhttp.net").await.unwrap(); + let (source, endpoint) = records.next().await.unwrap(); + + assert_eq!( + source, + Source::H3 { + server: Arc::from(resolver.base_url.origin().ascii_serialization()) + } + ); + assert_eq!( + endpoint, + EndpointAddr::direct("192.168.5.78:41748".parse().unwrap()) + ); + } + + #[cfg(feature = "dquic-network")] + #[tokio::test] + async fn cached_dns_genmeta_net_record_is_returned() { + let endpoint = Arc::new(h3x::endpoint::H3Endpoint::new( + h3x::dquic::QuicEndpoint::builder().build().await, + )); + let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); + resolver.cache.insert_positive( + "dns.genmeta.net", + vec![EndpointAddr::direct("192.0.2.53:4433".parse().unwrap())], + ); + + let mut records = resolver.lookup("dns.genmeta.net").await.unwrap(); + let (_source, endpoint) = records.next().await.unwrap(); + + assert_eq!( + endpoint, + EndpointAddr::direct("192.0.2.53:4433".parse().unwrap()) + ); + } + + #[cfg(feature = "dquic-network")] + #[tokio::test] + async fn cached_lookup_uses_e_record_port_not_input_port() { + let endpoint = Arc::new(h3x::endpoint::H3Endpoint::new( + h3x::dquic::QuicEndpoint::builder().build().await, + )); + let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); + resolver.cache.insert_positive( + "nat.genmeta.net", + vec![EndpointAddr::direct("192.0.2.10:21000".parse().unwrap())], + ); + + let mut records = resolver.lookup("nat.genmeta.net:20004").await.unwrap(); + let (_source, endpoint) = records.next().await.unwrap(); + + assert_eq!( + endpoint, + EndpointAddr::direct("192.0.2.10:21000".parse().unwrap()) + ); + } +} diff --git a/src/h3/cache.rs b/src/h3/cache.rs new file mode 100644 index 0000000..3f143ad --- /dev/null +++ b/src/h3/cache.rs @@ -0,0 +1,79 @@ +use std::time::Duration; + +use dashmap::DashMap; +use dquic::qbase::net::addr::EndpointAddr; +use tokio::time::Instant; + +const POSITIVE_TTL: Duration = Duration::from_secs(10); +const NEGATIVE_TTL: Duration = Duration::from_secs(2); + +#[derive(Debug)] +pub(super) struct CachedRecord { + addrs: Vec, + expire: Instant, +} + +#[derive(Debug, Default)] +pub(super) struct LookupCache { + positive: DashMap, + negative: DashMap, +} + +impl LookupCache { + pub(super) fn prune_expired(&self, now: Instant) { + self.positive.retain(|_host, record| record.expire > now); + self.negative.retain(|_host, expire| *expire > now); + } + + pub(super) fn positive_hit(&self, domain: &str) -> Option> { + self.positive.get(domain).map(|record| record.addrs.clone()) + } + + pub(super) fn negative_hit(&self, domain: &str) -> bool { + self.negative.get(domain).is_some() + } + + pub(super) fn insert_positive(&self, domain: &str, addrs: Vec) { + self.positive.insert( + domain.to_owned(), + CachedRecord { + addrs, + expire: Instant::now() + POSITIVE_TTL, + }, + ); + self.negative.remove(domain); + } + + pub(super) fn insert_negative(&self, domain: &str) { + self.negative + .insert(domain.to_owned(), Instant::now() + NEGATIVE_TTL); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn endpoint(addr: &str) -> EndpointAddr { + EndpointAddr::direct(addr.parse().expect("socket addr")) + } + + #[test] + fn positive_cache_hit_returns_endpoints() { + let cache = LookupCache::default(); + cache.insert_positive("demo.dhttp.net", vec![endpoint("192.0.2.10:4433")]); + + assert_eq!( + cache.positive_hit("demo.dhttp.net").unwrap(), + vec![endpoint("192.0.2.10:4433")] + ); + } + + #[test] + fn negative_cache_hit_blocks_lookup() { + let cache = LookupCache::default(); + cache.insert_negative("missing.dhttp.net"); + + assert!(cache.negative_hit("missing.dhttp.net")); + } +} diff --git a/src/h3/lookup.rs b/src/h3/lookup.rs new file mode 100644 index 0000000..0f99bf6 --- /dev/null +++ b/src/h3/lookup.rs @@ -0,0 +1,305 @@ +use std::sync::Arc; + +use dquic::qresolve::{RecordStream, Source}; +use futures::{StreamExt, stream}; +use h3x::quic; +use http_body_util::BodyExt; +use snafu::{IntoError, ResultExt}; +use tokio::time::Instant; + +use super::{ + H3LookupError, H3Resolver, LOOKUP_REQUEST_ATTEMPTS, LOOKUP_REQUEST_TIMEOUT, LookupDecodeError, + h3_lookup_error, lookup_decode_error, +}; +use crate::core::{parser::packet::be_packet, wire::be_multi_response}; + +const LOOKUP_API_PATH: &str = "/api/v2/lookup"; + +fn lookup_url(base_url: &url::Url, name: &str) -> url::Url { + let mut url = base_url + .join(LOOKUP_API_PATH) + .expect("h3 dns lookup api path must be valid"); + url.query_pairs_mut().append_pair("host", name); + url +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) struct LookupRecords { + pub(super) endpoints: Vec, +} + +impl LookupRecords { + pub(super) fn decode(domain: &str, response: &[u8]) -> Result { + use crate::core::parser::record; + + let (remain, multi) = match be_multi_response(response) { + Ok(response) => response, + Err(_error) => return Err(LookupDecodeError::MultiResponse), + }; + if !remain.is_empty() { + return Err(LookupDecodeError::MultiResponse); + } + + let mut endpoint_records = Vec::new(); + for r in multi.records { + if !r.signature_fields.is_empty() { + match r.signature_fields.verify(&r.dns, &r.cert) { + Ok(true) => {} + Ok(false) => { + tracing::debug!("ignored record with invalid DNS packet signature"); + continue; + } + Err(error) => { + tracing::debug!( + error = %snafu::Report::from_error(&error), + "ignored record with malformed DNS packet signature" + ); + continue; + } + } + } + + let (_remain, packet) = match be_packet(&r.dns) { + Ok(packet) => packet, + Err(source) => { + return Err( + lookup_decode_error::ParseRecordsSnafu.into_error(source.to_owned()) + ); + } + }; + + endpoint_records.extend(packet.answers.iter().filter_map( + |answer| match answer.data() { + record::RData::E(ep) => { + if answer.name() != domain { + tracing::debug!( + answer_name = %answer.name(), + query = domain, + "ignored endpoint answer for different name" + ); + return None; + } + Some(ep.clone()) + } + _ => { + tracing::debug!(?answer, "ignored record"); + None + } + }, + )); + } + + Ok(Self { + endpoints: crate::resolvers::endpoint_group::selected_endpoint_addrs(endpoint_records), + }) + } +} + +impl H3Resolver +where + C: quic::Connect + quic::WithLocalAuthority + Send + Sync + 'static, + C::Error: Send + Sync + 'static, + C::Connection: Send + 'static, +{ + pub(super) fn retryable_lookup_error(error: &H3LookupError) -> bool { + matches!( + error, + H3LookupError::Request { .. } | H3LookupError::H3Stream { .. } + ) + } + + pub(super) async fn lookup_response( + &self, + uri: http::Uri, + ) -> Result> { + let request = http::Request::get(uri) + .body(http_body_util::Empty::::new()) + .expect("h3 dns lookup request must be valid"); + let resp = self.execute_request(request).await?; + + tracing::trace!("received response with status {}", resp.status()); + match resp.status() { + http::StatusCode::OK => {} + http::StatusCode::NOT_FOUND => return Err(H3LookupError::NoRecordFound), + status => return Err(H3LookupError::Status { status }), + } + + match resp.into_body().collect().await { + Ok(response) => Ok(response.to_bytes()), + Err(source) => Err(H3LookupError::H3Stream { source }), + } + } + + pub(super) async fn lookup_response_with_retry( + &self, + uri: http::Uri, + ) -> Result> { + for attempt in 1..=LOOKUP_REQUEST_ATTEMPTS { + match tokio::time::timeout(LOOKUP_REQUEST_TIMEOUT, self.lookup_response(uri.clone())) + .await + { + Ok(Ok(response)) => return Ok(response), + Ok(Err(error)) + if Self::retryable_lookup_error(&error) + && attempt < LOOKUP_REQUEST_ATTEMPTS => + { + self.endpoint.clear_pool(); + tracing::debug!( + attempt, + timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), + "h3 dns lookup failed, retrying" + ); + } + Ok(Err(error)) => return Err(error), + Err(_elapsed) if attempt < LOOKUP_REQUEST_ATTEMPTS => { + self.endpoint.clear_pool(); + tracing::debug!( + attempt, + timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), + "h3 dns lookup timed out, retrying" + ); + } + Err(_elapsed) => { + self.endpoint.clear_pool(); + return Err(H3LookupError::RequestTimeout { + timeout: LOOKUP_REQUEST_TIMEOUT, + }); + } + } + } + + unreachable!("lookup retry loop returns on the final attempt") + } + + pub async fn lookup(&self, name: &str) -> Result> { + let server = Arc::from(self.base_url.origin().ascii_serialization()); + let source = Source::H3 { server }; + + let Some(domain) = crate::resolvers::resolvable_name(name) else { + return Err(H3LookupError::NoRecordFound); + }; + + let now = Instant::now(); + self.cache.prune_expired(now); + + if self.cache.negative_hit(domain) { + return Err(H3LookupError::NoRecordFound); + } + + if let Some(addrs) = self.cache.positive_hit(domain) { + let stream = stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))); + return Ok(stream.boxed()); + } + + let url = lookup_url(&self.base_url, domain); + let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); + + tracing::trace!("sending lookup request to {}", self.base_url); + let response = match self.lookup_response_with_retry(uri).await { + Ok(response) => response, + Err(H3LookupError::NoRecordFound) => { + self.cache.insert_negative(domain); + return Err(H3LookupError::NoRecordFound); + } + Err(error) => return Err(error), + }; + + let records = LookupRecords::decode(domain, response.as_ref()) + .context(h3_lookup_error::DecodeSnafu)?; + let addrs = records.endpoints; + + if addrs.is_empty() { + self.cache.insert_negative(domain); + return Err(H3LookupError::NoRecordFound); + } + + self.cache.insert_positive(domain, addrs.clone()); + + Ok(stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))).boxed()) + } +} + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, net::SocketAddrV4}; + + use super::*; + use crate::core::{ + MdnsPacket, + parser::record::endpoint::EndpointAddr as DnsEndpointAddr, + wire::{MultiResponse, ResponseRecord}, + }; + + fn direct(addr: &str, main: bool, sequence: u32) -> DnsEndpointAddr { + let socket: SocketAddrV4 = addr.parse().expect("socket addr"); + let mut endpoint = DnsEndpointAddr::direct_v4(socket); + endpoint.set_main(main); + endpoint.set_sequence( + dhttp_identity::certificate::CertificateSequence::try_from(sequence).unwrap(), + ); + endpoint + } + + fn response_for(name: &str, endpoints: Vec) -> Vec { + let mut hosts = HashMap::new(); + hosts.insert(name.to_owned(), endpoints); + let packet = MdnsPacket::answer(0, &hosts).to_bytes(); + MultiResponse::new([ResponseRecord::unsigned(packet, Vec::new())]).encode() + } + + #[test] + fn h3_lookup_url_targets_v2_api_from_origin_base() { + let base_url = url::Url::parse("https://dns.example.test:4433").expect("url"); + let url = lookup_url(&base_url, "demo.dhttp.net"); + + assert_eq!( + url.as_str(), + "https://dns.example.test:4433/api/v2/lookup?host=demo.dhttp.net" + ); + } + + #[test] + fn h3_lookup_url_does_not_duplicate_v2_base_path() { + let base_url = url::Url::parse("https://dns.example.test:4433/api/v2/").expect("url"); + let url = lookup_url(&base_url, "demo.dhttp.net"); + + assert_eq!( + url.as_str(), + "https://dns.example.test:4433/api/v2/lookup?host=demo.dhttp.net" + ); + } + + #[test] + fn lookup_records_select_primary_group() { + let response = response_for( + "demo.dhttp.net", + vec![ + direct("192.0.2.20:4433", false, 1), + direct("192.0.2.10:4433", true, 2), + direct("192.0.2.11:4433", true, 2), + direct("192.0.2.30:4433", true, 3), + ], + ); + + let records = LookupRecords::decode("demo.dhttp.net", &response).expect("records"); + + assert_eq!(records.endpoints.len(), 2); + assert_eq!( + records.endpoints[0], + dquic::qbase::net::addr::EndpointAddr::direct("192.0.2.10:4433".parse().unwrap()) + ); + assert_eq!( + records.endpoints[1], + dquic::qbase::net::addr::EndpointAddr::direct("192.0.2.11:4433".parse().unwrap()) + ); + } + + #[test] + fn lookup_records_ignore_answer_name_mismatch() { + let response = response_for("other.dhttp.net", vec![direct("192.0.2.50:4433", true, 1)]); + + let records = LookupRecords::decode("demo.dhttp.net", &response).expect("records"); + + assert!(records.endpoints.is_empty()); + } +} diff --git a/src/h3/publish.rs b/src/h3/publish.rs new file mode 100644 index 0000000..637e493 --- /dev/null +++ b/src/h3/publish.rs @@ -0,0 +1,264 @@ +use std::collections::HashMap; + +use dhttp_identity::identity::LocalAuthority; +use dquic::qbase::net::addr::EndpointAddr; +use h3x::quic; +use http_body_util::Full; +use snafu::{OptionExt, ResultExt}; +use tracing::trace; + +use super::{H3PublishError, H3Resolver, h3_publish_error}; +use crate::core::{ + MdnsPacket, + signature::{CONTENT_DIGEST_HEADER, SIGNATURE_HEADER, SIGNATURE_INPUT_HEADER, SignatureFields}, +}; + +const PUBLISH_API_PATH: &str = "/api/v2/publish"; + +fn publish_url(base_url: &url::Url, name: &str) -> url::Url { + let mut url = base_url + .join(PUBLISH_API_PATH) + .expect("h3 dns publish api path must be valid"); + url.query_pairs_mut().append_pair("host", name); + url +} + +async fn signed_publish_request( + base_url: &url::Url, + name: &str, + packet: &[u8], + authority: &A, +) -> Result>, crate::core::signature::SignatureFieldsError> { + let url = publish_url(base_url, name); + let uri: http::Uri = url + .as_str() + .parse() + .expect("h3 dns publish URL is a valid URI"); + let signature_fields = SignatureFields::sign(packet, authority).await?; + + Ok(http::Request::post(uri) + .header( + CONTENT_DIGEST_HEADER, + signature_fields.content_digest.as_slice(), + ) + .header( + SIGNATURE_INPUT_HEADER, + signature_fields.signature_input.as_slice(), + ) + .header(SIGNATURE_HEADER, signature_fields.signature.as_slice()) + .body(Full::new(bytes::Bytes::copy_from_slice(packet))) + .expect("h3 dns publish request must be valid")) +} + +impl H3Resolver +where + C: quic::Connect + quic::WithLocalAuthority + Send + Sync + 'static, + C::Error: Send + Sync + 'static, + C::Connection: Send + 'static, +{ + pub async fn publish_endpoints( + &self, + name: &str, + endpoints: &[EndpointAddr], + ) -> Result<(), H3PublishError> { + trace!("h3x publishing {} with {} endpoints", name, endpoints.len()); + let bytes = { + let endpoints = endpoints + .iter() + .filter_map(|ep| { + crate::core::parser::record::endpoint::EndpointAddr::try_from(*ep).ok() + }) + .collect(); + let mut hosts = HashMap::new(); + hosts.insert(name.to_string(), endpoints); + MdnsPacket::answer(0, &hosts).to_bytes() + }; + + self.publish_packet(name, &bytes).await + } + + /// Publish a pre-built DNS packet (with signatures already included). + pub async fn publish_packet( + &self, + name: &str, + packet: &[u8], + ) -> Result<(), H3PublishError> { + tracing::trace!( + name, + packet_len = packet.len(), + url = %self.base_url, + "h3 dns publishing packet" + ); + let authority = self + .endpoint + .quic() + .local_authority() + .await + .context(h3_publish_error::LocalAuthoritySnafu)? + .context(h3_publish_error::AnonymousEndpointSnafu)?; + let request = signed_publish_request(&self.base_url, name, packet, &authority) + .await + .context(h3_publish_error::SignRequestSnafu)?; + let resp = self.execute_request(request).await?; + + if resp.status() != http::StatusCode::OK { + return Err(H3PublishError::Status { + status: resp.status(), + }); + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + #[cfg(feature = "dquic-network")] + use dquic::qresolve::Publish as _; + use futures::future::BoxFuture; + #[cfg(feature = "dquic-network")] + use h3x::endpoint::H3Endpoint; + use ring::signature::KeyPair as _; + use rustls::{ + SignatureAlgorithm, SignatureScheme, + pki_types::CertificateDer, + sign::{Signer, SigningKey}, + }; + + use super::*; + + #[cfg(feature = "dquic-network")] + #[tokio::test] + async fn publish_rejects_anonymous_endpoint_before_request() { + let endpoint = Arc::new(H3Endpoint::new( + h3x::dquic::QuicEndpoint::builder().build().await, + )); + let resolver = H3Resolver::from_endpoint("https://dns.example.test:4433", endpoint) + .expect("valid h3 resolver"); + + let error = resolver + .publish_packet("demo.dhttp.net", b"dns-packet") + .await + .expect_err("anonymous endpoint should not publish"); + + assert_eq!( + error.to_string(), + "anonymous h3 endpoint cannot sign dns publish request" + ); + + let trait_error = resolver + .publish("demo.dhttp.net", b"dns-packet") + .await + .expect_err("trait publish should surface anonymous endpoint"); + assert!( + trait_error + .to_string() + .contains("anonymous h3 endpoint cannot sign dns publish request") + ); + } + + #[derive(Debug)] + struct TestAuthority { + keypair: Arc, + cert_chain: Vec>, + } + + impl dhttp_identity::identity::LocalAuthority for TestAuthority { + fn name(&self) -> &str { + "authority.example" + } + + fn cert_chain(&self) -> &[CertificateDer<'static>] { + &self.cert_chain + } + + fn sign( + &self, + data: &[u8], + ) -> BoxFuture<'_, Result, dhttp_identity::identity::SignError>> { + let result = dhttp_identity::identity::sign_with_key( + &TestSigningKey(self.keypair.clone()), + data, + ); + Box::pin(std::future::ready(result)) + } + } + + #[derive(Debug)] + struct TestSigningKey(Arc); + + impl SigningKey for TestSigningKey { + fn choose_scheme(&self, offered: &[SignatureScheme]) -> Option> { + offered + .contains(&SignatureScheme::ED25519) + .then(|| Box::new(TestSigner(self.0.clone())) as Box) + } + + fn algorithm(&self) -> SignatureAlgorithm { + SignatureAlgorithm::ED25519 + } + } + + #[derive(Debug)] + struct TestSigner(Arc); + + impl Signer for TestSigner { + fn sign(&self, message: &[u8]) -> Result, rustls::Error> { + Ok(self.0.sign(message).as_ref().to_vec()) + } + + fn scheme(&self) -> SignatureScheme { + SignatureScheme::ED25519 + } + } + + fn test_authority() -> TestAuthority { + let rng = ring::rand::SystemRandom::new(); + let pkcs8 = ring::signature::Ed25519KeyPair::generate_pkcs8(&rng).expect("pkcs8"); + let keypair = + Arc::new(ring::signature::Ed25519KeyPair::from_pkcs8(pkcs8.as_ref()).expect("keypair")); + let mut spki = Vec::with_capacity(44); + spki.extend_from_slice(&[ + 0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00, + ]); + spki.extend_from_slice(keypair.public_key().as_ref()); + + TestAuthority { + keypair, + cert_chain: vec![CertificateDer::from(spki)], + } + } + + #[tokio::test] + async fn signed_publish_request_uses_authority_headers() { + let authority = test_authority(); + let base_url = url::Url::parse("https://dns.example.test:4433").expect("url"); + let request = + signed_publish_request(&base_url, "demo.dhttp.net", b"dns-packet", &authority) + .await + .expect("signed request"); + + assert_eq!(request.method(), http::Method::POST); + assert_eq!( + request.uri().to_string(), + "https://dns.example.test:4433/api/v2/publish?host=demo.dhttp.net" + ); + assert!( + request + .headers() + .contains_key(crate::core::signature::CONTENT_DIGEST_HEADER) + ); + assert!( + request + .headers() + .contains_key(crate::core::signature::SIGNATURE_INPUT_HEADER) + ); + assert!( + request + .headers() + .contains_key(crate::core::signature::SIGNATURE_HEADER) + ); + } +} diff --git a/src/h3/request.rs b/src/h3/request.rs new file mode 100644 index 0000000..8b82d13 --- /dev/null +++ b/src/h3/request.rs @@ -0,0 +1,73 @@ +use std::convert::Infallible; + +use h3x::{ + dhttp::message::{MessageStreamError, hyper::client::RequestError as HyperRequestError}, + quic, +}; +use snafu::IntoError; + +use super::{H3RequestError, H3Resolver, h3_request_error}; + +impl H3Resolver +where + C: quic::Connect + quic::WithLocalAuthority + Send + Sync + 'static, + C::Error: Send + Sync + 'static, + C::Connection: Send + 'static, +{ + pub(super) fn connect_error( + &self, + source: h3x::pool::ConnectError, + ) -> H3RequestError { + // H3 DNS resolvers keep a long-lived endpoint. A network transition may + // leave the cached H3 connection with stale QUIC paths, so the next + // attempt must establish a fresh connection instead of reusing it. + self.endpoint.clear_pool(); + h3_request_error::ConnectSnafu.into_error(source) + } + + pub(super) fn request_error( + &self, + source: HyperRequestError, + ) -> H3RequestError { + self.endpoint.clear_pool(); + h3_request_error::RequestSnafu.into_error(source) + } + + pub(super) async fn execute_request( + &self, + request: http::Request< + impl http_body::Body + Send + 'static, + >, + ) -> Result< + http::Response>, + H3RequestError, + > { + let authority = request + .uri() + .authority() + .expect("h3 dns request URL must include an authority") + .clone(); + tracing::trace!(%authority, "connecting h3 dns endpoint"); + let connection = match self.endpoint.connect(authority.clone()).await { + Ok(connection) => { + tracing::trace!(%authority, "connected h3 dns endpoint"); + connection + } + Err(source) => return Err(self.connect_error(source)), + }; + + let method = request.method().clone(); + let uri = request.uri().clone(); + tracing::trace!(%method, %uri, "executing h3 dns request"); + match connection.execute_hyper_request(request).await { + Ok(response) => { + tracing::trace!( + status = %response.status(), + "h3 dns request response received" + ); + Ok(response) + } + Err(source) => Err(self.request_error(source)), + } + } +} diff --git a/src/http.rs b/src/http.rs new file mode 100644 index 0000000..f6c0090 --- /dev/null +++ b/src/http.rs @@ -0,0 +1,313 @@ +use std::{fmt::Display, io, sync::Arc}; + +use dashmap::DashMap; +use dquic::{ + qbase::net::addr::EndpointAddr, + qresolve::{Publish, PublishFuture, Resolve, ResolveFuture, Source}, +}; +use futures::{StreamExt, TryFutureExt, stream}; +use reqwest::{Client, IntoUrl, StatusCode, Url}; +use tokio::time::Instant; + +use crate::core::{ + parser::packet::be_packet, + signature::{CONTENT_DIGEST_HEADER, SIGNATURE_HEADER, SIGNATURE_INPUT_HEADER, SignatureFields}, + wire::be_multi_response, +}; + +const LOOKUP_API_PATH: &str = "/api/v2/lookup"; +const PUBLISH_API_PATH: &str = "/api/v2/publish"; + +#[derive(Debug)] +struct Record { + addrs: Vec, + expire: Instant, +} + +#[derive(Debug)] +pub struct HttpResolver { + http_client: Client, + base_url: Url, + cached_records: DashMap, +} + +fn lookup_url(base_url: &Url, name: &str) -> Url { + api_url(base_url, LOOKUP_API_PATH, name) +} + +fn publish_url(base_url: &Url, name: &str) -> Url { + api_url(base_url, PUBLISH_API_PATH, name) +} + +fn api_url(base_url: &Url, path: &str, name: &str) -> Url { + let mut url = base_url.join(path).expect("ddns api path must be valid"); + url.query_pairs_mut().append_pair("host", name); + url +} + +impl Display for HttpResolver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Http DNS({})", + self.base_url.host_str().expect("checked in constructor") + ) + } +} + +impl HttpResolver { + pub fn new(base_url: impl IntoUrl) -> io::Result { + let base_url = base_url + .into_url() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + base_url.host_str().ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "base URL must have a valid host", + ) + })?; + + Ok(Self { + http_client: build_http_client()?, + base_url, + cached_records: DashMap::new(), + }) + } + + pub async fn publish_signed( + &self, + name: &str, + packet: &[u8], + signature_fields: &SignatureFields, + ) -> io::Result<()> { + self.publish_packet_with_signature(name, packet, signature_fields) + .await + .map_err(io::Error::other) + } + + async fn publish_packet_with_signature( + &self, + name: &str, + packet: &[u8], + signature_fields: &SignatureFields, + ) -> Result<(), Error> { + let url = publish_url(&self.base_url, name); + let mut request = self + .http_client + .post(url) + .header("Content-Type", "application/octet-stream"); + if !signature_fields.is_empty() { + request = request + .header( + CONTENT_DIGEST_HEADER, + signature_fields.content_digest.as_slice(), + ) + .header( + SIGNATURE_INPUT_HEADER, + signature_fields.signature_input.as_slice(), + ) + .header(SIGNATURE_HEADER, signature_fields.signature.as_slice()); + } + request + .body(packet.to_vec()) + .send() + .await? + .error_for_status()?; + Ok(()) + } +} + +fn build_http_client() -> io::Result { + let native_certs = rustls_native_certs::load_native_certs(); + for error in &native_certs.errors { + let report = snafu::Report::from_error(error); + tracing::warn!(error = %report, "failed to load native root certificate"); + } + + let mut root_store = rustls::RootCertStore::empty(); + let (valid_roots, invalid_roots) = root_store.add_parsable_certificates(native_certs.certs); + if invalid_roots > 0 { + tracing::debug!(invalid_roots, "ignored invalid native root certificates"); + } + if valid_roots == 0 { + tracing::warn!("no native root certificates loaded for http resolver"); + } + + let mut tls = rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + tls.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + + Client::builder() + .use_preconfigured_tls(tls) + .build() + .map_err(io::Error::other) +} + +#[derive(Debug, snafu::Snafu)] +enum Error { + #[snafu(display("http request failed"))] + Reqwest { source: reqwest::Error }, + + #[snafu(display("{status}"))] + Status { status: StatusCode }, + + #[snafu(display("no DNS record found"))] + NoRecordFound, + + #[snafu(display("failed to parse DNS records from response"))] + ParseRecords { + source: nom::Err>>, + }, + + #[snafu(display("failed to decode multi-record response"))] + ParseMultiResponse, +} + +impl From for Error { + fn from(source: reqwest::Error) -> Self { + match source.status() { + Some(stateus) if stateus == StatusCode::NOT_FOUND => Error::NoRecordFound, + Some(status) => Error::Status { status }, + None => Error::Reqwest { + source: source.without_url(), + }, + } + } +} + +impl Publish for HttpResolver { + fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { + Box::pin(async move { + self.publish_packet_with_signature(name, packet, &SignatureFields::empty()) + .await + .map_err(io::Error::other) + }) + } +} + +impl Resolve for HttpResolver { + fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { + let lookup = async move { + let Some(domain) = crate::resolvers::resolvable_name(name) else { + return Err(Error::NoRecordFound); + }; + + let now = Instant::now(); + let server = Arc::from(self.base_url.host_str().unwrap_or("")); + let soource = Source::Http { server }; + + use crate::core::parser::record; + self.cached_records + .retain(|_host, Record { expire, .. }| *expire < now); + if let Some(record) = self.cached_records.get(domain) { + let endpoint_addrs: Vec<_> = record + .addrs + .iter() + .map(|endpoint: &EndpointAddr| (soource.clone(), *endpoint)) + .collect(); + return Ok(stream::iter(endpoint_addrs).boxed()); + } + let response = self + .http_client + .get(lookup_url(&self.base_url, domain)) + .send() + .await; + + let response = response?.error_for_status()?.bytes().await?; + let (remain, multi) = + be_multi_response(response.as_ref()).map_err(|_| Error::ParseMultiResponse)?; + if !remain.is_empty() { + return Err(Error::ParseMultiResponse); + } + + let mut addrs = Vec::new(); + for r in multi.records { + if !r.signature_fields.is_empty() { + match r.signature_fields.verify(&r.dns, &r.cert) { + Ok(true) => {} + Ok(false) => { + tracing::debug!("ignored record with invalid DNS packet signature"); + continue; + } + Err(error) => { + tracing::debug!(error = %snafu::Report::from_error(&error), "ignored record with malformed DNS packet signature"); + continue; + } + } + } + let (_remain, packet) = + be_packet(&r.dns).map_err(|source| Error::ParseRecords { + source: source.to_owned(), + })?; + + addrs.extend( + packet + .answers + .iter() + .filter_map(|answer| match answer.data() { + record::RData::E(ep) => { + if answer.name() != domain { + tracing::debug!( + answer_name = %answer.name(), + query = domain, + "ignored endpoint answer for different name" + ); + return None; + } + let endpoint = + TryInto::::try_into(ep.clone()).ok()?; + Some(endpoint) + } + _ => { + tracing::debug!(?answer, "ignored record"); + None + } + }), + ); + } + if addrs.is_empty() { + return Err(Error::NoRecordFound); + } + + // cache the addrs + self.cached_records.insert( + domain.to_string(), + Record { + addrs: addrs.clone(), + expire: now + std::time::Duration::from_secs(300), + }, + ); + + Ok(stream::iter(addrs.into_iter().map(move |ep| (soource.clone(), ep))).boxed()) + }; + Box::pin(lookup.map_err(io::Error::other)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn http_publish_url_targets_v2_api_from_origin_base() { + let base_url = Url::parse("https://dns.example.test").expect("url"); + let url = publish_url(&base_url, "demo.dhttp.net"); + + assert_eq!( + url.as_str(), + "https://dns.example.test/api/v2/publish?host=demo.dhttp.net" + ); + } + + #[test] + fn http_lookup_url_does_not_duplicate_v2_base_path() { + let base_url = Url::parse("https://dns.example.test/api/v2/").expect("url"); + let url = lookup_url(&base_url, "demo.dhttp.net"); + + assert_eq!( + url.as_str(), + "https://dns.example.test/api/v2/lookup?host=demo.dhttp.net" + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index 193112c..b495501 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,11 @@ mod bootstrap; pub mod core; +#[cfg(feature = "h3")] +pub mod h3; +#[cfg(feature = "http")] +pub mod http; +#[cfg(feature = "mdns")] pub mod mdns; -#[cfg(any(feature = "h3x-resolver", feature = "mdns-resolver"))] -pub mod publisher; +pub mod publishers; pub mod resolvers; diff --git a/src/mdns.rs b/src/mdns.rs index dd51460..fce3ee9 100644 --- a/src/mdns.rs +++ b/src/mdns.rs @@ -1,4 +1,359 @@ mod if_nametoindex; mod protocol; -pub mod resolvers; pub mod service; + +use std::{fmt, io, net::IpAddr}; +#[cfg(feature = "dquic-network")] +use std::{net::SocketAddr, sync::Arc}; + +#[cfg(feature = "dquic-network")] +use dquic::qresolve::RecordStream; +use dquic::{ + qbase::net::Family, + qresolve::{Publish, PublishFuture, Resolve, ResolveFuture, Source}, +}; +use futures::{FutureExt, StreamExt, TryFutureExt, future, stream}; +#[cfg(feature = "dquic-network")] +use futures::{Stream, stream::FuturesUnordered}; + +#[cfg(feature = "dquic-network")] +use self::protocol::MdnsProtocol; +#[cfg(feature = "dquic-network")] +use crate::core::parser::packet::Packet; +use crate::core::parser::record::RData; + +pub type MdnsResolver = service::Mdns; +pub type MdnsPublisher = service::Mdns; + +impl MdnsResolver { + pub fn source(&self) -> Source { + Source::Mdns { + nic: self.bound_nic().into(), + family: match self.bound_ip() { + IpAddr::V4(..) => Family::V4, + IpAddr::V6(..) => Family::V6, + }, + } + } +} + +impl fmt::Display for MdnsResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.source(), f) + } +} + +impl Publish for MdnsPublisher { + fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { + let endpoints = match endpoints_from_packet(packet) { + Ok(endpoints) => endpoints, + Err(error) => return future::ready(Err(error)).boxed(), + }; + self.insert_host(name.to_string(), endpoints); + future::ready(Ok(())).boxed() + } +} + +impl Resolve for MdnsResolver { + fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { + let source = self.source(); + self.query(name.to_owned()) + .map_ok(move |list| { + let endpoints = crate::resolvers::endpoint_group::selected_endpoint_addrs(list); + stream::iter(endpoints.into_iter().map(move |ep| (source.clone(), ep))).boxed() + }) + .boxed() + } +} + +fn endpoints_from_packet(packet: &[u8]) -> io::Result> { + use crate::core::parser::packet::be_packet; + + be_packet(packet) + .map(|(_, pkt)| { + pkt.answers + .iter() + .filter_map(|rr| match rr.data() { + RData::E(ep) => Some(ep.clone()), + _ => None, + }) + .collect::>() + }) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string())) +} + +#[cfg(feature = "dquic-network")] +pub struct MdnsBindDriver { + iface_manager: Arc, + null_io_factory: Arc, + service_name: Arc, +} + +#[cfg(feature = "dquic-network")] +impl MdnsBindDriver { + pub fn new(service_name: impl Into>) -> Self { + Self { + iface_manager: Arc::new(h3x::dquic::net::InterfaceManager::new()), + null_io_factory: Arc::new(h3x::dquic::NullIoFactory), + service_name: service_name.into(), + } + } + + fn install_or_rebind_mdns( + &self, + network: &h3x::dquic::Network, + bind_iface: &h3x::dquic::net::BindInterface, + ) { + let bind_uri = bind_iface.bind_uri(); + let Some((family, device, _port)) = bind_uri.as_iface_bind_uri() else { + tracing::debug!(%bind_uri, "skipping mdns binding for non-interface bind uri"); + return; + }; + let Some(ip) = network.resolve_device_addr(device, family) else { + tracing::debug!(%bind_uri, "skipping mdns binding without local interface address"); + return; + }; + + bind_iface.with_components_mut(|components, _iface| { + match components.try_init_with(|| service::Mdns::new(&self.service_name, ip, device)) { + Ok(mdns) => mdns.reinit_on(device, ip), + Err(error) => { + let report = snafu::Report::from_error(&error); + tracing::debug!(error = %report, %bind_uri, "failed to initialize mdns binding"); + } + } + }); + } +} + +#[cfg(feature = "dquic-network")] +impl h3x::dquic::BindDriver for MdnsBindDriver { + fn bind<'a>( + &'a self, + network: &'a h3x::dquic::Network, + uri: h3x::dquic::net::BindUri, + ) -> futures::future::BoxFuture<'a, h3x::dquic::net::BindInterface> { + async move { + let iface = self + .iface_manager + .bind(uri, self.null_io_factory.clone()) + .await; + self.install_or_rebind_mdns(network, &iface); + iface + } + .boxed() + } + + fn rebind<'a>( + &'a self, + network: &'a h3x::dquic::Network, + iface: &'a h3x::dquic::net::BindInterface, + ) -> futures::future::BoxFuture<'a, ()> { + async move { + self.install_or_rebind_mdns(network, iface); + } + .boxed() + } +} + +#[cfg(feature = "dquic-network")] +pub struct MdnsResolvers { + network: Arc, + driver: Arc, + patterns: Arc>, + _handles: Vec, +} + +#[cfg(feature = "dquic-network")] +#[derive(Debug, Clone)] +pub struct BoundMdnsResolver { + pub device: String, + pub family: Family, + pub resolver: MdnsResolver, +} + +#[cfg(feature = "dquic-network")] +impl fmt::Debug for MdnsResolvers { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MdnsResolvers") + .field("patterns", &self.patterns) + .finish_non_exhaustive() + } +} + +#[cfg(feature = "dquic-network")] +impl fmt::Display for MdnsResolvers { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("mDNS resolvers") + } +} + +#[cfg(feature = "dquic-network")] +impl MdnsResolvers { + pub async fn bind( + network: Arc, + patterns: Arc>, + service_name: impl Into>, + ) -> Self { + let driver = Arc::new(MdnsBindDriver::new(service_name)); + let mut handles = Vec::with_capacity(patterns.len()); + for pattern in patterns.iter() { + handles.push(network.bind_with(driver.clone(), pattern.clone()).await); + } + + Self { + network, + driver, + patterns, + _handles: handles, + } + } + + pub fn bound_interfaces( + &self, + pattern: &h3x::dquic::binds::BindPattern, + ) -> Option> { + self.network.get_interfaces_with(&self.driver, pattern) + } + + fn for_each_resolver(&self, mut f: impl FnMut(&MdnsResolver)) { + for pattern in self.patterns.iter() { + let Some(ifaces) = self.bound_interfaces(pattern) else { + continue; + }; + for iface in ifaces { + iface.with_components(|components, _| { + if let Some(mdns) = components.get::() { + f(mdns); + } + }); + } + } + } + + pub fn bound_resolvers(&self) -> Vec { + let mut resolvers = Vec::new(); + for pattern in self.patterns.iter() { + let Some(ifaces) = self.bound_interfaces(pattern) else { + continue; + }; + for iface in ifaces { + let bind_uri = iface.bind_uri(); + let Some((family, device, _port)) = bind_uri.as_iface_bind_uri() else { + continue; + }; + iface.with_components(|components, _| { + if let Some(resolver) = components.get::() { + resolvers.push(BoundMdnsResolver { + device: device.to_owned(), + family, + resolver: resolver.clone(), + }); + } + }); + } + } + resolvers + } + + pub async fn query(&self, name: &str) -> io::Result { + let mut lookup_futures = FuturesUnordered::new(); + let mut has_resolver = false; + self.for_each_resolver(|resolver| { + has_resolver = true; + let source = resolver.source(); + lookup_futures.push( + resolver + .query(name.to_owned()) + .map_ok(move |eps| (source, eps)), + ); + }); + if !has_resolver { + return Err(io::Error::other("no mdns resolvers available")); + } + + let mut last_error = None; + let mut has_success = false; + let mut records = Vec::new(); + while let Some(result) = lookup_futures.next().await { + match result { + Ok((source, endpoints)) => { + has_success = true; + records.extend( + endpoints + .into_iter() + .map(|endpoint| (source.clone(), endpoint)), + ); + } + Err(error) => last_error = Some(error), + } + } + + if !has_success { + return Err( + last_error.unwrap_or_else(|| io::Error::other("no mdns resolvers available")) + ); + } + + let records = crate::resolvers::endpoint_group::selected_endpoint_records(records); + + Ok(stream::iter(records).boxed()) + } + + pub fn discover(&self) -> impl Stream + use<> { + let mut protos = Vec::new(); + self.for_each_resolver(|resolver| { + protos.push(resolver.protocol()); + }); + + async fn receive_one( + proto: Arc, + ) -> Option<((SocketAddr, Packet), Arc)> { + let result = proto.receive_boardcast().await.ok()?; + Some((result, proto)) + } + + let mut pending = protos + .into_iter() + .map(receive_one) + .collect::>(); + + Box::pin(stream::poll_fn(move |cx| { + use std::task::Poll; + loop { + match pending.poll_next_unpin(cx) { + Poll::Ready(Some(Some((item, proto)))) => { + pending.push(receive_one(proto)); + return Poll::Ready(Some(item)); + } + Poll::Ready(Some(None)) => continue, + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + } + } + })) + } +} + +#[cfg(feature = "dquic-network")] +impl Publish for MdnsResolvers { + fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { + let endpoints = match endpoints_from_packet(packet) { + Ok(endpoints) => endpoints, + Err(error) => return future::ready(Err(error)).boxed(), + }; + + self.for_each_resolver(|resolver| { + resolver.insert_host(name.to_string(), endpoints.clone()); + }); + + future::ready(Ok(())).boxed() + } +} + +#[cfg(feature = "dquic-network")] +impl Resolve for MdnsResolvers { + fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { + self.query(name).boxed() + } +} diff --git a/src/mdns/resolvers.rs b/src/mdns/resolvers.rs deleted file mode 100644 index 1bd416e..0000000 --- a/src/mdns/resolvers.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod mdns; diff --git a/src/mdns/resolvers/mdns.rs b/src/mdns/resolvers/mdns.rs deleted file mode 100644 index 16912ab..0000000 --- a/src/mdns/resolvers/mdns.rs +++ /dev/null @@ -1,354 +0,0 @@ -use std::{fmt, io, net::IpAddr}; -#[cfg(feature = "mdns-resolver")] -use std::{net::SocketAddr, sync::Arc}; - -#[cfg(feature = "mdns-resolver")] -use dquic::qresolve::RecordStream; -use dquic::{ - qbase::net::Family, - qresolve::{Publish, PublishFuture, Resolve, ResolveFuture, Source}, -}; -use futures::{FutureExt, StreamExt, TryFutureExt, future, stream}; -#[cfg(feature = "mdns-resolver")] -use futures::{Stream, stream::FuturesUnordered}; - -#[cfg(feature = "mdns-resolver")] -use super::super::protocol::MdnsProtocol; -#[cfg(feature = "mdns-resolver")] -use crate::core::parser::packet::Packet; -use crate::core::parser::record::RData; -pub type MdnsResolver = crate::mdns::service::Mdns; - -impl MdnsResolver { - pub fn source(&self) -> Source { - Source::Mdns { - nic: self.bound_nic().into(), - family: match self.bound_ip() { - IpAddr::V4(..) => Family::V4, - IpAddr::V6(..) => Family::V6, - }, - } - } -} - -impl fmt::Display for MdnsResolver { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.source(), f) - } -} - -impl Publish for MdnsResolver { - fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { - let endpoints = match endpoints_from_packet(packet) { - Ok(endpoints) => endpoints, - Err(error) => return future::ready(Err(error)).boxed(), - }; - self.insert_host(name.to_string(), endpoints); - future::ready(Ok(())).boxed() - } -} - -impl Resolve for MdnsResolver { - fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { - let source = self.source(); - self.query(name.to_owned()) - .map_ok(move |list| { - let endpoints = crate::resolvers::selector::selected_endpoint_addrs(list); - stream::iter(endpoints.into_iter().map(move |ep| (source.clone(), ep))).boxed() - }) - .boxed() - } -} - -fn endpoints_from_packet(packet: &[u8]) -> io::Result> { - use crate::core::parser::packet::be_packet; - - be_packet(packet) - .map(|(_, pkt)| { - pkt.answers - .iter() - .filter_map(|rr| match rr.data() { - RData::E(ep) => Some(ep.clone()), - _ => None, - }) - .collect::>() - }) - .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string())) -} - -#[cfg(feature = "mdns-resolver")] -pub struct MdnsBindDriver { - iface_manager: Arc, - null_io_factory: Arc, - service_name: Arc, -} - -#[cfg(feature = "mdns-resolver")] -impl MdnsBindDriver { - pub fn new(service_name: impl Into>) -> Self { - Self { - iface_manager: Arc::new(h3x::dquic::net::InterfaceManager::new()), - null_io_factory: Arc::new(h3x::dquic::NullIoFactory), - service_name: service_name.into(), - } - } - - fn install_or_rebind_mdns( - &self, - network: &h3x::dquic::Network, - bind_iface: &h3x::dquic::net::BindInterface, - ) { - let bind_uri = bind_iface.bind_uri(); - let Some((family, device, _port)) = bind_uri.as_iface_bind_uri() else { - tracing::debug!(%bind_uri, "skipping mdns binding for non-interface bind uri"); - return; - }; - let Some(ip) = network.resolve_device_addr(device, family) else { - tracing::debug!(%bind_uri, "skipping mdns binding without local interface address"); - return; - }; - - bind_iface.with_components_mut(|components, _iface| { - match components.try_init_with(|| crate::mdns::service::Mdns::new(&self.service_name, ip, device)) { - Ok(mdns) => mdns.reinit_on(device, ip), - Err(error) => { - let report = snafu::Report::from_error(&error); - tracing::debug!(error = %report, %bind_uri, "failed to initialize mdns binding"); - } - } - }); - } -} - -#[cfg(feature = "mdns-resolver")] -impl h3x::dquic::BindDriver for MdnsBindDriver { - fn bind<'a>( - &'a self, - network: &'a h3x::dquic::Network, - uri: h3x::dquic::net::BindUri, - ) -> futures::future::BoxFuture<'a, h3x::dquic::net::BindInterface> { - async move { - let iface = self - .iface_manager - .bind(uri, self.null_io_factory.clone()) - .await; - self.install_or_rebind_mdns(network, &iface); - iface - } - .boxed() - } - - fn rebind<'a>( - &'a self, - network: &'a h3x::dquic::Network, - iface: &'a h3x::dquic::net::BindInterface, - ) -> futures::future::BoxFuture<'a, ()> { - async move { - self.install_or_rebind_mdns(network, iface); - } - .boxed() - } -} - -#[cfg(feature = "mdns-resolver")] -pub struct MdnsResolvers { - network: Arc, - driver: Arc, - patterns: Arc>, - _handles: Vec, -} - -#[cfg(feature = "mdns-resolver")] -#[derive(Debug, Clone)] -pub struct BoundMdnsResolver { - pub device: String, - pub family: Family, - pub resolver: MdnsResolver, -} - -#[cfg(feature = "mdns-resolver")] -impl fmt::Debug for MdnsResolvers { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("MdnsResolvers") - .field("patterns", &self.patterns) - .finish_non_exhaustive() - } -} - -#[cfg(feature = "mdns-resolver")] -impl fmt::Display for MdnsResolvers { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("mDNS resolvers") - } -} - -#[cfg(feature = "mdns-resolver")] -impl MdnsResolvers { - pub async fn bind( - network: Arc, - patterns: Arc>, - service_name: impl Into>, - ) -> Self { - let driver = Arc::new(MdnsBindDriver::new(service_name)); - let mut handles = Vec::with_capacity(patterns.len()); - for pattern in patterns.iter() { - handles.push(network.bind_with(driver.clone(), pattern.clone()).await); - } - - Self { - network, - driver, - patterns, - _handles: handles, - } - } - - pub fn bound_interfaces( - &self, - pattern: &h3x::dquic::binds::BindPattern, - ) -> Option> { - self.network.get_interfaces_with(&self.driver, pattern) - } - - fn for_each_resolver(&self, mut f: impl FnMut(&MdnsResolver)) { - for pattern in self.patterns.iter() { - let Some(ifaces) = self.bound_interfaces(pattern) else { - continue; - }; - for iface in ifaces { - iface.with_components(|components, _| { - if let Some(mdns) = components.get::() { - f(mdns); - } - }); - } - } - } - - pub fn bound_resolvers(&self) -> Vec { - let mut resolvers = Vec::new(); - for pattern in self.patterns.iter() { - let Some(ifaces) = self.bound_interfaces(pattern) else { - continue; - }; - for iface in ifaces { - let bind_uri = iface.bind_uri(); - let Some((family, device, _port)) = bind_uri.as_iface_bind_uri() else { - continue; - }; - iface.with_components(|components, _| { - if let Some(resolver) = components.get::() { - resolvers.push(BoundMdnsResolver { - device: device.to_owned(), - family, - resolver: resolver.clone(), - }); - } - }); - } - } - resolvers - } - - pub async fn query(&self, name: &str) -> io::Result { - let mut lookup_futures = FuturesUnordered::new(); - let mut has_resolver = false; - self.for_each_resolver(|resolver| { - has_resolver = true; - let source = resolver.source(); - lookup_futures.push( - resolver - .query(name.to_owned()) - .map_ok(move |eps| (source, eps)), - ); - }); - if !has_resolver { - return Err(io::Error::other("no mdns resolvers available")); - } - - let mut last_error = None; - let mut has_success = false; - let mut records = Vec::new(); - while let Some(result) = lookup_futures.next().await { - match result { - Ok((source, endpoints)) => { - has_success = true; - records.extend( - endpoints - .into_iter() - .map(|endpoint| (source.clone(), endpoint)), - ); - } - Err(error) => last_error = Some(error), - } - } - - if !has_success { - return Err( - last_error.unwrap_or_else(|| io::Error::other("no mdns resolvers available")) - ); - } - - let records = crate::resolvers::selector::selected_endpoint_records(records); - - Ok(stream::iter(records).boxed()) - } - - /// Discover mDNS broadcasts from all active resolvers. - pub fn discover(&self) -> impl Stream + use<> { - let mut protos = Vec::new(); - self.for_each_resolver(|resolver| { - protos.push(resolver.protocol()); - }); - - async fn receive_one( - proto: Arc, - ) -> Option<((SocketAddr, Packet), Arc)> { - let result = proto.receive_boardcast().await.ok()?; - Some((result, proto)) - } - - let mut pending = protos - .into_iter() - .map(receive_one) - .collect::>(); - - Box::pin(stream::poll_fn(move |cx| { - use std::task::Poll; - loop { - match pending.poll_next_unpin(cx) { - Poll::Ready(Some(Some((item, proto)))) => { - pending.push(receive_one(proto)); - return Poll::Ready(Some(item)); - } - Poll::Ready(Some(None)) => continue, - Poll::Ready(None) => return Poll::Ready(None), - Poll::Pending => return Poll::Pending, - } - } - })) - } -} - -#[cfg(feature = "mdns-resolver")] -impl Publish for MdnsResolvers { - fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { - let endpoints = match endpoints_from_packet(packet) { - Ok(endpoints) => endpoints, - Err(error) => return future::ready(Err(error)).boxed(), - }; - - self.for_each_resolver(|resolver| { - resolver.insert_host(name.to_string(), endpoints.clone()); - }); - - future::ready(Ok(())).boxed() - } -} - -#[cfg(feature = "mdns-resolver")] -impl Resolve for MdnsResolvers { - fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { - self.query(name).boxed() - } -} diff --git a/src/publisher.rs b/src/publisher.rs deleted file mode 100644 index 13ba706..0000000 --- a/src/publisher.rs +++ /dev/null @@ -1,802 +0,0 @@ -mod address; -mod dispatch; -mod packet; - -use std::{any::TypeId, future::Future, io, net::SocketAddr, pin::Pin, sync::Arc, time::Duration}; - -pub use address::{ - AddressSelector, AddressView, AddressViewSource, EndpointBindingAddresses, FnAddressView, - PublishAddressGroup, PublishAddressScope, PublishAddresses, -}; -use dhttp_identity::{identity::LocalAuthority, name::Name}; -use dquic::{ - qinterface::component::location::AddressEvent, qresolve::Resolve, - qtraversal::nat::client::ClientLocationData, -}; -pub use packet::{EndpointRecordSigner, SignEndpointRecordsError}; -use snafu::Snafu; - -pub const DEFAULT_PUBLISH_INTERVAL: Duration = Duration::from_secs(20); -/// Upper bound for a single publish attempt in the background loop. -/// -/// Network changes can leave an in-flight H3 publish waiting on paths that no -/// longer exist. Timing out the attempt keeps consecutive publishes -/// independent: the next interval observes the current bindings again. -pub const DEFAULT_PUBLISH_TIMEOUT: Duration = Duration::from_secs(10); -const PUBLISH_CHANGE_DEBOUNCE: Duration = Duration::from_millis(50); - -type PublishLoopFuture<'a> = Pin + Send + 'a>>; - -#[derive(Debug, Snafu)] -#[snafu(module(create_publisher_error))] -pub enum CreatePublisherError { - #[snafu(display("anonymous endpoint cannot publish dns records"))] - AnonymousEndpoint, -} - -#[derive(Debug, Snafu)] -#[snafu(module)] -pub enum PublishOnceError { - #[snafu(display("no publisher resolver available"))] - NoPublisherResolver, - #[snafu(display("failed to sign endpoint records"))] - SignEndpointRecords { source: SignEndpointRecordsError }, - #[snafu(display("failed to publish dns packet with {publisher}"))] - Publish { - publisher: String, - source: io::Error, - }, -} - -pub trait PublisherResolver: Send + Sync + 'static { - fn as_resolver(&self) -> &(dyn Resolve + Send + Sync); -} - -impl PublisherResolver for T -where - T: Resolve + Send + Sync + Sized + 'static, -{ - fn as_resolver(&self) -> &(dyn Resolve + Send + Sync) { - self - } -} - -impl PublisherResolver for dyn Resolve + Send + Sync { - fn as_resolver(&self) -> &(dyn Resolve + Send + Sync) { - self - } -} - -pub struct Publisher { - signer: EndpointRecordSigner, - resolver: Arc, -} - -impl std::fmt::Debug for Publisher -where - A: LocalAuthority + Send + Sync + ?Sized, - R: PublisherResolver + ?Sized, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Publisher") - .field("signer", &self.signer) - .field("resolver", &self.resolver.as_resolver().to_string()) - .finish() - } -} - -pub type EndpointPublisher = Publisher; - -impl Publisher -where - A: LocalAuthority + Send + Sync + ?Sized, - R: PublisherResolver + ?Sized, -{ - pub fn new(signer: EndpointRecordSigner, resolver: Arc) -> Self { - Self { signer, resolver } - } - - pub fn signer(&self) -> &EndpointRecordSigner { - &self.signer - } - - pub fn resolver(&self) -> &Arc { - &self.resolver - } - - pub async fn publish_once( - &self, - name: &Name<'_>, - addresses: &V, - ) -> Result<(), PublishOnceError> - where - V: AddressView + Sync, - { - let mut published = false; - published |= self - .publish_to_resolver(self.resolver.as_resolver(), name, addresses) - .await?; - - if !published { - return publish_once_error::NoPublisherResolverSnafu.fail(); - } - - Ok(()) - } -} - -pub struct EndpointPublicationLoop { - name: Name<'static>, - publisher: Publisher, - source: S, - interval: Duration, - publish_timeout: Duration, -} - -impl std::fmt::Debug for EndpointPublicationLoop -where - A: LocalAuthority + Send + Sync + ?Sized, - R: PublisherResolver + ?Sized, - S: std::fmt::Debug, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("EndpointPublicationLoop") - .field("name", &self.name) - .field("publisher", &self.publisher) - .field("source", &self.source) - .field("interval", &self.interval) - .field("publish_timeout", &self.publish_timeout) - .finish() - } -} - -impl EndpointPublicationLoop -where - A: LocalAuthority + Send + Sync + ?Sized, - R: PublisherResolver + ?Sized, - S: AddressViewSource + Send + Sync, -{ - pub fn new(name: Name<'static>, publisher: Publisher, source: S) -> Self { - Self { - name, - publisher, - source, - interval: DEFAULT_PUBLISH_INTERVAL, - publish_timeout: DEFAULT_PUBLISH_TIMEOUT, - } - } - - pub fn name(&self) -> &Name<'static> { - &self.name - } - - pub fn publisher(&self) -> &Publisher { - &self.publisher - } - - pub fn interval(&self) -> Duration { - self.interval - } - - pub fn publish_timeout(&self) -> Duration { - self.publish_timeout - } - - pub fn with_publish_timeout(mut self, timeout: Duration) -> Self { - self.publish_timeout = timeout; - self - } - - pub async fn run(&self) -> ! { - let mut locations = self.source.subscribe(); - let interval = tokio::time::sleep(self.interval); - tokio::pin!(interval); - // Keep at most one publish attempt in flight. A timer tick or - // publishable location change drops the current future and starts a new - // debounced attempt so a stale H3 publish cannot block publication from - // the latest bindings. - let mut current_publish = self.new_publish_loop_future(); - - loop { - tokio::select! { - _ = &mut current_publish => { - current_publish = Self::pending_publish_loop_future(); - } - _ = &mut interval => { - interval.as_mut().reset(tokio::time::Instant::now() + self.interval); - self.clear_publish_state(); - current_publish = self.new_publish_loop_future(); - } - event = locations.recv() => { - let Some((bind_uri, event)) = event else { - continue; - }; - if !self.source.observes(&bind_uri) { - continue; - } - if !Self::location_event_requires_publish(&event) { - continue; - } - - self.clear_publish_state(); - current_publish = self.new_publish_loop_future(); - } - } - } - } - - fn new_publish_loop_future(&self) -> PublishLoopFuture<'_> { - Box::pin(async move { - tokio::time::sleep(PUBLISH_CHANGE_DEBOUNCE).await; - let _ = self.publish_attempt().await; - }) - } - - fn pending_publish_loop_future<'a>() -> PublishLoopFuture<'a> { - Box::pin(std::future::pending()) - } - - async fn publish_attempt(&self) -> bool { - tracing::trace!( - timeout_ms = self.publish_timeout.as_millis(), - "starting dns publish attempt" - ); - let addresses = self.source.address_view(); - match tokio::time::timeout( - self.publish_timeout, - self.publisher.publish_once(&self.name, &addresses), - ) - .await - { - Ok(Ok(())) => { - tracing::info!(name = %self.name, "published resolver endpoints"); - true - } - Ok(Err(error)) => { - let report = snafu::Report::from_error(&error); - tracing::warn!(error = %report, name = %self.name, "dns publish failed"); - false - } - Err(_elapsed) => { - // Dropping a timed-out publish future does not let the H3 - // resolver observe a request error. Reset resolver-owned - // connection state so the next interval reconnects from - // the current network bindings. - self.clear_publish_state(); - tracing::warn!( - timeout_ms = self.publish_timeout.as_millis(), - name = %self.name, - "dns publish timed out" - ); - false - } - } - } - - fn clear_publish_state(&self) { - dispatch::clear_resolver_publish_state(self.publisher.resolver.as_resolver()); - } - - fn location_event_requires_publish(event: &AddressEvent) -> bool { - match event { - AddressEvent::Upsert(data) => { - // `Locations` also carries transient STUN failures. Those do - // not add a publishable endpoint; treating them as publish - // triggers creates a retry loop while the node is offline. - if let Some(bound_addr) = data.downcast_ref::>() { - return bound_addr.is_ok(); - } - if let Some(stun_addr) = data.downcast_ref::() { - return stun_addr.is_ok(); - } - false - } - AddressEvent::Remove(type_id) => { - *type_id == TypeId::of::>() - || *type_id == TypeId::of::() - } - AddressEvent::Closed => true, - } - } -} - -pub type EndpointPublisherLoop = EndpointPublicationLoop< - dyn LocalAuthority + Send + Sync, - dyn Resolve + Send + Sync, - EndpointBindingAddresses, ->; - -#[cfg(test)] -mod tests { - #[cfg(feature = "http-resolver")] - use std::sync::atomic::{AtomicUsize, Ordering}; - use std::{fmt, sync::Arc, time::Duration}; - - use dquic::qresolve::{ResolveFuture, Source}; - use futures::{FutureExt, StreamExt, future::BoxFuture, stream}; - use rustls::pki_types::{CertificateDer, SubjectPublicKeyInfoDer}; - #[cfg(feature = "http-resolver")] - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - - use super::*; - - #[derive(Debug)] - struct TestAuthority; - - const ED25519_TEST_SPKI: [u8; 44] = [ - 0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]; - - impl LocalAuthority for TestAuthority { - fn name(&self) -> &str { - "authority.example" - } - - fn cert_chain(&self) -> &[CertificateDer<'static>] { - static CERTS: std::sync::LazyLock>> = - std::sync::LazyLock::new(|| { - vec![CertificateDer::from( - include_bytes!("../tests/fixtures/valid.der").to_vec(), - )] - }); - CERTS.as_slice() - } - - fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { - SubjectPublicKeyInfoDer::from(ED25519_TEST_SPKI.as_slice()) - } - - fn sign( - &self, - _data: &[u8], - ) -> BoxFuture<'_, Result, dhttp_identity::identity::SignError>> { - // Match the Ed25519 signature length used by DHTTP's canonical - // key-to-scheme policy. Short fake signatures can collide with - // legacy E-record fixed RDLENGTH values during parser - // compatibility dispatch. - Box::pin(async move { Ok(vec![0x2a; 64]) }) - } - } - - #[derive(Debug)] - struct DisplayOnlyResolver; - - impl fmt::Display for DisplayOnlyResolver { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("display only resolver") - } - } - - impl Resolve for DisplayOnlyResolver { - fn lookup<'l>(&'l self, _name: &'l str) -> ResolveFuture<'l> { - async { Ok(stream::empty::<(Source, dquic::qbase::net::addr::EndpointAddr)>().boxed()) } - .boxed() - } - } - - fn test_name() -> Name<'static> { - "authority.example".parse().unwrap() - } - - fn test_publisher(resolver: Arc) -> Publisher - where - R: Resolve + Send + Sync, - { - let signer = EndpointRecordSigner::new(Arc::new(TestAuthority)); - Publisher::new(signer, resolver) - } - - fn test_source(network: Arc) -> EndpointBindingAddresses { - EndpointBindingAddresses::new( - network, - Arc::new(vec![ - "inet://127.0.0.1:0".parse().expect("valid bind pattern"), - ]), - ) - } - - #[tokio::test] - async fn publish_once_reports_no_publisher_resolver() { - let publisher = test_publisher(Arc::new(DisplayOnlyResolver)); - let addresses = - PublishAddresses::new().wide_area([dquic::qbase::net::addr::EndpointAddr::direct( - "127.0.0.1:443".parse().unwrap(), - )]); - - let error = publisher - .publish_once(&test_name(), &addresses) - .await - .unwrap_err(); - - assert!(matches!(error, PublishOnceError::NoPublisherResolver)); - } - - #[tokio::test] - async fn publisher_timeout_is_configurable() { - let network = h3x::dquic::Network::builder().build(); - let publisher = test_publisher(Arc::new(DisplayOnlyResolver)); - let publisher_loop = - EndpointPublicationLoop::new(test_name(), publisher, test_source(network)); - assert_eq!(publisher_loop.publish_timeout(), DEFAULT_PUBLISH_TIMEOUT); - - let timeout = Duration::from_secs(3); - let publisher_loop = publisher_loop.with_publish_timeout(timeout); - assert_eq!(publisher_loop.publish_timeout(), timeout); - } - - #[tokio::test] - async fn signer_applies_certificate_selector_from_authority_ski() { - let signer = EndpointRecordSigner::new(Arc::new(TestAuthority)); - let name: Name<'static> = "authority.example".parse().unwrap(); - - let endpoint = - dquic::qbase::net::addr::EndpointAddr::direct("127.0.0.1:443".parse().unwrap()); - let packet = signer.signed_packet(&name, &[endpoint]).await.unwrap(); - let (_remain, packet) = crate::core::parser::packet::be_packet(&packet).unwrap(); - let record = packet.answers.first().expect("endpoint answer"); - let crate::core::parser::record::RData::E(endpoint) = record.data() else { - panic!("expected endpoint record"); - }; - - assert!(endpoint.is_main()); - assert!(!endpoint.is_clustered()); - assert!(endpoint.is_signed()); - assert_eq!( - endpoint.certificate_chain_key().unwrap().sequence().get(), - 0 - ); - } - - #[tokio::test] - async fn signer_uses_supplied_record_owner_name() { - let signer = EndpointRecordSigner::new(Arc::new(TestAuthority)); - let name: Name<'static> = "nat.genmeta.net".parse().unwrap(); - - let endpoint = - dquic::qbase::net::addr::EndpointAddr::direct("127.0.0.1:443".parse().unwrap()); - let packet = signer.signed_packet(&name, &[endpoint]).await.unwrap(); - let (_remain, packet) = crate::core::parser::packet::be_packet(&packet).unwrap(); - let record = packet.answers.first().expect("endpoint answer"); - let crate::core::parser::record::RData::E(endpoint) = record.data() else { - panic!("expected endpoint record"); - }; - - assert_eq!(record.name().to_string(), "nat.genmeta.net"); - assert!(endpoint.is_main()); - assert!(!endpoint.is_clustered()); - assert!(endpoint.is_signed()); - } - - #[tokio::test] - async fn binding_address_view_does_not_expose_loopback_as_wide_area_without_stun() { - let network = h3x::dquic::Network::builder().build(); - let bind_pattern: h3x::dquic::binds::BindPattern = - "inet://127.0.0.1:0".parse().expect("valid bind pattern"); - let _bind = network.quic().bind(bind_pattern.clone()).await; - let source = EndpointBindingAddresses::new(network, Arc::new(vec![bind_pattern])); - let view = source.address_view(); - - assert!(view.endpoints(AddressSelector::WideArea).next().is_none()); - } - - #[cfg(feature = "http-resolver")] - #[tokio::test] - async fn run_restarts_when_publish_attempt_observes_location_change() { - async fn wait_for_count(count: &AtomicUsize, target: usize) { - loop { - if count.load(Ordering::SeqCst) >= target { - return; - } - tokio::time::sleep(Duration::from_millis(20)).await; - } - } - - let network = h3x::dquic::Network::builder().build(); - let bind_uri: h3x::dquic::net::BindUri = - "inet://127.0.0.1:0".parse().expect("valid bind uri"); - let publish_count = Arc::new(AtomicUsize::new(0)); - let listener = tokio::net::TcpListener::bind("127.0.0.1:0") - .await - .expect("bind test http server"); - let port = listener.local_addr().expect("local addr").port(); - let server_network = network.clone(); - let server_bind_uri = bind_uri.clone(); - let server_count = publish_count.clone(); - let server = tokio::spawn(async move { - loop { - let Ok((mut stream, _peer)) = listener.accept().await else { - break; - }; - let current = server_count.fetch_add(1, Ordering::SeqCst) + 1; - let mut buf = [0_u8; 1024]; - let _ = stream.read(&mut buf).await; - if current == 2 { - server_network.quic().locations().upsert( - server_bind_uri.clone(), - Arc::new(Ok::( - "127.0.0.1:10001".parse().expect("valid socket addr"), - )), - ); - } - let _ = stream - .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n") - .await; - } - }); - - let resolver = Arc::new( - crate::resolvers::http::HttpResolver::new(format!("http://127.0.0.1:{port}/")) - .expect("valid http resolver"), - ); - let publisher = test_publisher(resolver); - let source = test_source(network.clone()); - let mut publisher_loop = EndpointPublicationLoop::new(test_name(), publisher, source); - publisher_loop.interval = Duration::from_secs(60); - - let publisher = tokio::spawn(async move { - publisher_loop.run().await; - }); - - wait_for_count(&publish_count, 1).await; - tokio::time::sleep(PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(100)).await; - network.quic().locations().upsert( - bind_uri, - Arc::new(Ok::( - "127.0.0.1:10000".parse().expect("valid socket addr"), - )), - ); - - tokio::time::timeout(Duration::from_secs(2), wait_for_count(&publish_count, 2)) - .await - .expect("publishable location changes should trigger the next independent publish"); - - tokio::time::timeout( - PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(500), - wait_for_count(&publish_count, 3), - ) - .await - .expect("publishable location events should replace the current publish attempt"); - - publisher.abort(); - server.abort(); - } - - #[cfg(feature = "http-resolver")] - #[tokio::test] - async fn run_ignores_transient_location_failures_generated_during_publish_attempt() { - async fn wait_for_count(count: &AtomicUsize, target: usize) { - loop { - if count.load(Ordering::SeqCst) >= target { - return; - } - tokio::time::sleep(Duration::from_millis(20)).await; - } - } - - let network = h3x::dquic::Network::builder().build(); - let bind_uri: h3x::dquic::net::BindUri = - "inet://127.0.0.1:0".parse().expect("valid bind uri"); - let publish_count = Arc::new(AtomicUsize::new(0)); - let listener = tokio::net::TcpListener::bind("127.0.0.1:0") - .await - .expect("bind test http server"); - let port = listener.local_addr().expect("local addr").port(); - let server_network = network.clone(); - let server_bind_uri = bind_uri.clone(); - let server_count = publish_count.clone(); - let server = tokio::spawn(async move { - loop { - let Ok((mut stream, _peer)) = listener.accept().await else { - break; - }; - let mut buf = [0_u8; 1024]; - let _ = stream.read(&mut buf).await; - server_count.fetch_add(1, Ordering::SeqCst); - server_network - .quic() - .locations() - .upsert::( - server_bind_uri.clone(), - Arc::new(Err( - dquic::qtraversal::nat::client::DetectOuterAddrError::Rebinded { - bind_uri: server_bind_uri.clone(), - }, - )), - ); - let _ = stream - .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n") - .await; - } - }); - - let resolver = Arc::new( - crate::resolvers::http::HttpResolver::new(format!("http://127.0.0.1:{port}/")) - .expect("valid http resolver"), - ); - let publisher = test_publisher(resolver); - let source = test_source(network.clone()); - let publisher_loop = EndpointPublicationLoop::new(test_name(), publisher, source); - let publisher = tokio::spawn(async move { - publisher_loop.run().await; - }); - - wait_for_count(&publish_count, 1).await; - tokio::time::sleep(PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(100)).await; - - network.quic().locations().upsert( - bind_uri, - Arc::new(Ok::( - "127.0.0.1:0".parse().expect("valid socket addr"), - )), - ); - wait_for_count(&publish_count, 2).await; - - let third_publish = tokio::time::timeout( - PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(500), - wait_for_count(&publish_count, 3), - ) - .await; - - publisher.abort(); - server.abort(); - - assert!( - third_publish.is_err(), - "publish-generated location events must not trigger another immediate publish" - ); - } - - #[cfg(feature = "http-resolver")] - #[tokio::test] - async fn run_does_not_retry_location_publish_after_timeout() { - async fn wait_for_count(count: &AtomicUsize, target: usize) { - loop { - if count.load(Ordering::SeqCst) >= target { - return; - } - tokio::time::sleep(Duration::from_millis(20)).await; - } - } - - let network = h3x::dquic::Network::builder().build(); - let bind_uri: h3x::dquic::net::BindUri = - "inet://127.0.0.1:0".parse().expect("valid bind uri"); - let publish_count = Arc::new(AtomicUsize::new(0)); - let listener = tokio::net::TcpListener::bind("127.0.0.1:0") - .await - .expect("bind test http server"); - let port = listener.local_addr().expect("local addr").port(); - let server_count = publish_count.clone(); - let server = tokio::spawn(async move { - loop { - let Ok((mut stream, _peer)) = listener.accept().await else { - break; - }; - let current = server_count.fetch_add(1, Ordering::SeqCst) + 1; - let mut buf = [0_u8; 1024]; - let _ = stream.read(&mut buf).await; - if current == 2 { - tokio::time::sleep(Duration::from_millis(200)).await; - } - let _ = stream - .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n") - .await; - } - }); - - let resolver = Arc::new( - crate::resolvers::http::HttpResolver::new(format!("http://127.0.0.1:{port}/")) - .expect("valid http resolver"), - ); - let publisher = test_publisher(resolver); - let source = test_source(network.clone()); - let mut publisher_loop = EndpointPublicationLoop::new(test_name(), publisher, source) - .with_publish_timeout(Duration::from_millis(50)); - publisher_loop.interval = Duration::from_secs(60); - - let publisher = tokio::spawn(async move { - publisher_loop.run().await; - }); - - wait_for_count(&publish_count, 1).await; - tokio::time::sleep(PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(100)).await; - network.quic().locations().upsert( - bind_uri, - Arc::new(Ok::( - "127.0.0.1:0".parse().expect("valid socket addr"), - )), - ); - - wait_for_count(&publish_count, 2).await; - let third_publish = tokio::time::timeout( - PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(500), - wait_for_count(&publish_count, 3), - ) - .await; - - publisher.abort(); - server.abort(); - - assert!( - third_publish.is_err(), - "timed out location-triggered publish must not be retried before the next interval" - ); - } - - #[cfg(feature = "http-resolver")] - #[tokio::test] - async fn run_replaces_in_flight_publish_on_publishable_location_change() { - async fn wait_for_count(count: &AtomicUsize, target: usize) { - loop { - if count.load(Ordering::SeqCst) >= target { - return; - } - tokio::time::sleep(Duration::from_millis(20)).await; - } - } - - let network = h3x::dquic::Network::builder().build(); - let bind_uri: h3x::dquic::net::BindUri = - "inet://127.0.0.1:0".parse().expect("valid bind uri"); - let publish_count = Arc::new(AtomicUsize::new(0)); - let listener = tokio::net::TcpListener::bind("127.0.0.1:0") - .await - .expect("bind test http server"); - let port = listener.local_addr().expect("local addr").port(); - let server_count = publish_count.clone(); - let server = tokio::spawn(async move { - loop { - let Ok((mut stream, _peer)) = listener.accept().await else { - break; - }; - let current = server_count.fetch_add(1, Ordering::SeqCst) + 1; - tokio::spawn(async move { - let mut buf = [0_u8; 1024]; - let _ = stream.read(&mut buf).await; - if current == 1 { - std::future::pending::<()>().await; - } - let _ = stream - .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n") - .await; - }); - } - }); - - let resolver = Arc::new( - crate::resolvers::http::HttpResolver::new(format!("http://127.0.0.1:{port}/")) - .expect("valid http resolver"), - ); - let publisher = test_publisher(resolver); - let source = test_source(network.clone()); - let mut publisher_loop = EndpointPublicationLoop::new(test_name(), publisher, source) - .with_publish_timeout(Duration::from_secs(30)); - publisher_loop.interval = Duration::from_secs(60); - - let publisher = tokio::spawn(async move { - publisher_loop.run().await; - }); - - tokio::time::timeout(Duration::from_secs(2), wait_for_count(&publish_count, 1)) - .await - .expect("initial publish should start"); - - network.quic().locations().upsert( - bind_uri, - Arc::new(Ok::( - "127.0.0.1:10000".parse().expect("valid socket addr"), - )), - ); - - tokio::time::timeout( - PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(800), - wait_for_count(&publish_count, 2), - ) - .await - .expect("publishable location change should replace the in-flight publish"); - - publisher.abort(); - server.abort(); - } -} diff --git a/src/publisher/dispatch.rs b/src/publisher/dispatch.rs deleted file mode 100644 index 07d4db6..0000000 --- a/src/publisher/dispatch.rs +++ /dev/null @@ -1,152 +0,0 @@ -use std::any::Any; - -use dhttp_identity::{identity::LocalAuthority, name::Name}; -use dquic::{ - qbase::net::addr::EndpointAddr, - qresolve::{Publish, Resolve}, -}; -use snafu::ResultExt; - -use super::{ - AddressSelector, AddressView, PublishOnceError, Publisher, PublisherResolver, - publish_once_error, -}; -use crate::resolvers::Resolvers; - -impl Publisher -where - A: LocalAuthority + Send + Sync + ?Sized, - R: PublisherResolver + ?Sized, -{ - pub(crate) async fn publish_to_resolver( - &self, - resolver: &(dyn Resolve + Send + Sync), - name: &Name<'_>, - addresses: &V, - ) -> Result - where - V: AddressView + Sync, - { - let any: &dyn Any = resolver; - - if let Some(resolvers) = any.downcast_ref::() { - let mut published = false; - for resolver in resolvers.iter() { - published |= self - .publish_single_resolver(resolver.as_ref(), name, addresses) - .await?; - } - return Ok(published); - } - - self.publish_single_resolver(resolver, name, addresses) - .await - } - - async fn publish_single_resolver( - &self, - resolver: &(dyn Resolve + Send + Sync), - name: &Name<'_>, - addresses: &V, - ) -> Result - where - V: AddressView + Sync, - { - #[cfg(not(any( - feature = "http-resolver", - feature = "h3x-resolver", - feature = "mdns-resolver" - )))] - { - let _ = name; - let _ = addresses; - } - - let any: &dyn Any = resolver; - - #[cfg(feature = "http-resolver")] - if let Some(http) = any.downcast_ref::() { - self.publish_selected(http, name, addresses, AddressSelector::WideArea) - .await?; - return Ok(true); - } - - #[cfg(feature = "h3x-resolver")] - if let Some(h3) = - any.downcast_ref::>() - { - self.publish_selected(h3, name, addresses, AddressSelector::WideArea) - .await?; - return Ok(true); - } - - #[cfg(feature = "mdns-resolver")] - if let Some(mdns) = any.downcast_ref::() { - let mut published = false; - for bound in mdns.bound_resolvers() { - self.publish_selected( - &bound.resolver, - name, - addresses, - AddressSelector::LocalLink { - device: &bound.device, - family: bound.family, - }, - ) - .await?; - published = true; - } - return Ok(published); - } - - Ok(false) - } - - async fn publish_selected( - &self, - publisher: &(dyn Publish + Send + Sync), - name: &Name<'_>, - addresses: &V, - selector: AddressSelector<'_>, - ) -> Result<(), PublishOnceError> - where - V: AddressView + Sync, - { - let endpoints: Vec = addresses.endpoints(selector).collect(); - let packet = self - .signer - .signed_packet(name, &endpoints) - .await - .context(publish_once_error::SignEndpointRecordsSnafu)?; - tracing::debug!( - publisher = %publisher, - name = %name, - endpoint_count = endpoints.len(), - packet_len = packet.len(), - "publishing dns packet" - ); - publisher - .publish(name.as_str(), &packet) - .await - .context(publish_once_error::PublishSnafu { - publisher: publisher.to_string(), - }) - } -} - -pub(crate) fn clear_resolver_publish_state(resolver: &(dyn Resolve + Send + Sync)) { - let any: &dyn Any = resolver; - - if let Some(resolvers) = any.downcast_ref::() { - for resolver in resolvers.iter() { - clear_resolver_publish_state(resolver.as_ref()); - } - } - - #[cfg(feature = "h3x-resolver")] - if let Some(h3) = - any.downcast_ref::>() - { - h3.clear_pool(); - } -} diff --git a/src/publisher/packet.rs b/src/publisher/packet.rs deleted file mode 100644 index 956afe2..0000000 --- a/src/publisher/packet.rs +++ /dev/null @@ -1,83 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use dhttp_identity::{ - identity::{LocalAuthority, LocalAuthorityCertificateExt}, - name::Name, -}; -use dquic::qbase::net::addr::EndpointAddr; -use snafu::{ResultExt, Snafu}; - -use crate::core::{ - MdnsPacket, - parser::record::endpoint::{EndpointAddr as DnsEndpointAddr, SignEndpointError}, -}; - -#[derive(Debug, Snafu)] -#[snafu(module)] -pub enum SignEndpointRecordsError { - #[snafu(display("failed to encode endpoint address"))] - EncodeEndpoint, - #[snafu(display("failed to extract dhttp certificate selector"))] - CertificateSelector { - source: dhttp_identity::identity::ExtractDhttpSubjectKeyIdentifierError, - }, - #[snafu(display("failed to sign endpoint address"))] - SignEndpoint { source: SignEndpointError }, -} - -pub struct EndpointRecordSigner { - authority: Arc, -} - -impl std::fmt::Debug for EndpointRecordSigner -where - A: LocalAuthority, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("EndpointRecordSigner") - .field("authority", &self.authority.name()) - .finish() - } -} - -impl EndpointRecordSigner -where - A: LocalAuthority + Send + Sync + ?Sized, -{ - pub fn new(authority: Arc) -> Self { - Self { authority } - } - - pub fn authority(&self) -> &Arc { - &self.authority - } - - pub async fn signed_packet( - &self, - name: &Name<'_>, - endpoints: &[EndpointAddr], - ) -> Result, SignEndpointRecordsError> { - let selector = self - .authority - .dhttp_subject_key_identifier() - .context(sign_endpoint_records_error::CertificateSelectorSnafu)?; - let chain = selector.chain(); - - let mut signed = Vec::with_capacity(endpoints.len()); - for endpoint in endpoints { - let Ok(mut endpoint) = DnsEndpointAddr::try_from(*endpoint) else { - return sign_endpoint_records_error::EncodeEndpointSnafu.fail(); - }; - endpoint.set_certificate_chain_key(chain); - endpoint - .sign_with_authority(self.authority.as_ref()) - .await - .context(sign_endpoint_records_error::SignEndpointSnafu)?; - signed.push(endpoint); - } - - let mut hosts = HashMap::new(); - hosts.insert(name.as_str().to_owned(), signed); - Ok(MdnsPacket::answer(0, &hosts).to_bytes()) - } -} diff --git a/src/publishers.rs b/src/publishers.rs new file mode 100644 index 0000000..0ce35a2 --- /dev/null +++ b/src/publishers.rs @@ -0,0 +1,201 @@ +#[cfg(feature = "publishers")] +mod address; +#[cfg(feature = "publishers")] +mod aggregate; +#[cfg(feature = "publishers")] +mod packet; +#[cfg(feature = "publishers")] +mod publisher; + +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +use std::{any::TypeId, net::SocketAddr, time::Duration}; + +#[cfg(feature = "publishers")] +pub use address::{ + AddressSelector, AddressView, FnAddressView, PublishAddressGroup, PublishAddresses, + PublishScope, +}; +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +pub use address::{AddressViewSource, EndpointBindingAddresses}; +#[cfg(feature = "publishers")] +pub use aggregate::{Publishers, PublishersError}; +#[cfg(feature = "publishers")] +use dhttp_identity::name::Name; +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +use dquic::{ + qinterface::component::location::AddressEvent, qtraversal::nat::client::ClientLocationData, +}; +#[cfg(feature = "publishers")] +pub use publisher::{Publisher, PublisherError}; + +#[cfg(feature = "h3")] +pub use crate::h3::H3Resolver as H3Publisher; +#[cfg(feature = "http")] +pub use crate::http::HttpResolver as HttpPublisher; +#[cfg(feature = "mdns")] +pub use crate::mdns::MdnsPublisher; + +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +pub const DEFAULT_PUBLISH_INTERVAL: Duration = Duration::from_secs(20); +/// Upper bound for a single publish attempt in the background loop. +/// +/// Network changes can leave an in-flight publish waiting on paths that no +/// longer exist. Timing out the attempt keeps consecutive publishes +/// independent: the next interval observes the current bindings again. +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +pub const DEFAULT_PUBLISH_TIMEOUT: Duration = Duration::from_secs(10); +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +const PUBLISH_CHANGE_DEBOUNCE: Duration = Duration::from_millis(50); + +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +pub struct EndpointPublicationLoop { + name: Name<'static>, + publishers: Publishers, + source: S, + interval: Duration, + publish_timeout: Duration, +} + +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +impl std::fmt::Debug for EndpointPublicationLoop +where + S: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EndpointPublicationLoop") + .field("name", &self.name) + .field("publishers", &self.publishers) + .field("source", &self.source) + .field("interval", &self.interval) + .field("publish_timeout", &self.publish_timeout) + .finish() + } +} + +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +impl EndpointPublicationLoop +where + S: AddressViewSource + Sync, +{ + pub fn new(name: Name<'static>, publishers: Publishers, source: S) -> Self { + Self { + name, + publishers, + source, + interval: DEFAULT_PUBLISH_INTERVAL, + publish_timeout: DEFAULT_PUBLISH_TIMEOUT, + } + } + + pub fn interval(&self) -> Duration { + self.interval + } + + pub fn publish_timeout(&self) -> Duration { + self.publish_timeout + } + + pub fn with_interval(mut self, interval: Duration) -> Self { + self.interval = interval; + self + } + + pub fn with_publish_timeout(mut self, timeout: Duration) -> Self { + self.publish_timeout = timeout; + self + } + + pub async fn run(&self) -> ! { + let mut locations = self.source.subscribe(); + let interval = tokio::time::sleep(self.interval); + tokio::pin!(interval); + let mut current_publish = self.new_publish_loop_future(); + + loop { + tokio::select! { + _ = &mut current_publish => { + current_publish = Self::pending_publish_loop_future(); + } + _ = &mut interval => { + interval.as_mut().reset(tokio::time::Instant::now() + self.interval); + current_publish = self.new_publish_loop_future(); + } + event = locations.recv() => { + let Some((bind_uri, event)) = event else { + continue; + }; + if !self.source.observes(&bind_uri) { + continue; + } + if !Self::location_event_requires_publish(&event) { + continue; + } + + current_publish = self.new_publish_loop_future(); + } + } + } + } + + fn new_publish_loop_future(&self) -> futures::future::BoxFuture<'_, ()> { + Box::pin(async move { + tokio::time::sleep(PUBLISH_CHANGE_DEBOUNCE).await; + let _ = self.publish_attempt().await; + }) + } + + fn pending_publish_loop_future<'a>() -> futures::future::BoxFuture<'a, ()> { + Box::pin(std::future::pending()) + } + + async fn publish_attempt(&self) -> bool { + tracing::trace!( + timeout_ms = self.publish_timeout.as_millis(), + name = %self.name, + "starting dns publish attempt" + ); + let view = self.source.address_view(); + match tokio::time::timeout( + self.publish_timeout, + self.publishers.publish(&self.name, &view), + ) + .await + { + Ok(Ok(())) => { + tracing::info!(name = %self.name, "published resolver endpoints"); + true + } + Ok(Err(error)) => { + tracing::warn!(error = %error, name = %self.name, "dns publish failed"); + false + } + Err(_elapsed) => { + tracing::warn!( + timeout_ms = self.publish_timeout.as_millis(), + name = %self.name, + "dns publish timed out" + ); + false + } + } + } + + fn location_event_requires_publish(event: &AddressEvent) -> bool { + match event { + AddressEvent::Upsert(data) => { + if let Some(bound_addr) = data.downcast_ref::>() { + return bound_addr.is_ok(); + } + if let Some(stun_addr) = data.downcast_ref::() { + return stun_addr.is_ok(); + } + false + } + AddressEvent::Remove(type_id) => { + *type_id == TypeId::of::>() + || *type_id == TypeId::of::() + } + AddressEvent::Closed => true, + } + } +} diff --git a/src/publisher/address.rs b/src/publishers/address.rs similarity index 85% rename from src/publisher/address.rs rename to src/publishers/address.rs index 207150b..ca8bcfa 100644 --- a/src/publisher/address.rs +++ b/src/publishers/address.rs @@ -1,13 +1,13 @@ -use std::{ - collections::HashSet, - net::SocketAddr, - sync::{Arc, OnceLock}, -}; - -use dquic::{ - qbase::net::{Family, addr::EndpointAddr}, - qinterface::component::location::Observer, -}; +#[cfg(feature = "dquic-network")] +use std::collections::HashSet; +use std::sync::Arc; +#[cfg(feature = "dquic-network")] +use std::{net::SocketAddr, sync::OnceLock}; + +use dquic::qbase::net::{Family, addr::EndpointAddr}; +#[cfg(feature = "dquic-network")] +use dquic::qinterface::component::location::Observer; +#[cfg(feature = "dquic-network")] use h3x::dquic::{ Network, binds::BindPattern, @@ -52,6 +52,7 @@ where } } +#[cfg(feature = "dquic-network")] pub trait AddressViewSource { fn address_view(&self) -> impl AddressView + Send + Sync + '_; fn subscribe(&self) -> Observer; @@ -59,14 +60,29 @@ pub trait AddressViewSource { } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum PublishAddressScope { +pub enum PublishScope { WideArea, LocalLink { device: Arc, family: Family }, } +impl PublishScope { + pub(crate) fn selector(&self) -> AddressSelector<'_> { + match self { + Self::WideArea => AddressSelector::WideArea, + Self::LocalLink { device, family } => AddressSelector::LocalLink { + device: device.as_ref(), + family: *family, + }, + } + } +} + +#[allow(dead_code)] +pub type PublishAddressScope = PublishScope; + #[derive(Debug, Clone, PartialEq, Eq)] pub struct PublishAddressGroup { - scope: PublishAddressScope, + scope: PublishScope, endpoints: Vec, } @@ -76,7 +92,7 @@ impl PublishAddressGroup { I: IntoIterator, { Self { - scope: PublishAddressScope::WideArea, + scope: PublishScope::WideArea, endpoints: endpoints.into_iter().collect(), } } @@ -86,7 +102,7 @@ impl PublishAddressGroup { I: IntoIterator, { Self { - scope: PublishAddressScope::LocalLink { + scope: PublishScope::LocalLink { device: device.into(), family, }, @@ -96,9 +112,9 @@ impl PublishAddressGroup { fn matches(&self, selector: AddressSelector<'_>) -> bool { match (&self.scope, selector) { - (PublishAddressScope::WideArea, AddressSelector::WideArea) => true, + (PublishScope::WideArea, AddressSelector::WideArea) => true, ( - PublishAddressScope::LocalLink { device, family }, + PublishScope::LocalLink { device, family }, AddressSelector::LocalLink { device: selected_device, family: selected_family, @@ -152,11 +168,13 @@ impl AddressView for PublishAddresses { } #[derive(Clone)] +#[cfg(feature = "dquic-network")] pub struct EndpointBindingAddresses { network: Arc, bind_patterns: Arc>, } +#[cfg(feature = "dquic-network")] impl std::fmt::Debug for EndpointBindingAddresses { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("EndpointBindingAddresses") @@ -165,6 +183,7 @@ impl std::fmt::Debug for EndpointBindingAddresses { } } +#[cfg(feature = "dquic-network")] impl EndpointBindingAddresses { pub fn new(network: Arc, bind_patterns: Arc>) -> Self { Self { @@ -174,6 +193,7 @@ impl EndpointBindingAddresses { } } +#[cfg(feature = "dquic-network")] impl AddressViewSource for EndpointBindingAddresses { fn address_view(&self) -> impl AddressView + Send + Sync + '_ { EndpointBindingAddressView::new(self.network.clone(), self.bind_patterns.clone()) @@ -190,10 +210,12 @@ impl AddressViewSource for EndpointBindingAddresses { } } +#[cfg(feature = "dquic-network")] struct EndpointBindingAddressView { bindings: Vec, } +#[cfg(feature = "dquic-network")] impl EndpointBindingAddressView { fn new(network: Arc, bind_patterns: Arc>) -> Self { let mut bindings = Vec::new(); @@ -210,6 +232,7 @@ impl EndpointBindingAddressView { } } +#[cfg(feature = "dquic-network")] impl AddressView for EndpointBindingAddressView { fn endpoints<'a>( &'a self, @@ -224,6 +247,7 @@ impl AddressView for EndpointBindingAddressView { } } +#[cfg(feature = "dquic-network")] struct BindingAddress { network: Arc, pattern: BindPattern, @@ -233,6 +257,7 @@ struct BindingAddress { local_link: OnceLock>, } +#[cfg(feature = "dquic-network")] impl BindingAddress { fn new(network: Arc, pattern: BindPattern, iface: BindInterface) -> Self { let bind_uri = iface.bind_uri(); @@ -272,6 +297,7 @@ impl BindingAddress { } } +#[cfg(feature = "dquic-network")] fn pattern_may_match_local_link(pattern: &BindPattern, device: &str, family: Family) -> bool { if pattern.scheme != Scheme::Iface { return false; @@ -286,6 +312,7 @@ fn pattern_may_match_local_link(pattern: &BindPattern, device: &str, family: Fam pattern.host.matches(device) } +#[cfg(feature = "dquic-network")] fn bind_uri_matches_local_link(bind_uri: &BindUri, device: &str, family: Family) -> bool { bind_uri .as_iface_bind_uri() @@ -294,6 +321,7 @@ fn bind_uri_matches_local_link(bind_uri: &BindUri, device: &str, family: Family) }) } +#[cfg(feature = "dquic-network")] fn public_endpoints_from_iface(network: &Network, iface: &BindInterface) -> Vec { iface.with_components(|components, current| { let bind_uri = current.bind_uri(); @@ -343,6 +371,7 @@ fn public_endpoints_from_iface(network: &Network, iface: &BindInterface) -> Vec< }) } +#[cfg(feature = "dquic-network")] fn publish_endpoint_from_stun( bound: SocketAddr, agent: SocketAddr, @@ -356,6 +385,7 @@ fn publish_endpoint_from_stun( } } +#[cfg(feature = "dquic-network")] fn local_endpoints_from_iface(iface: &BindInterface, family: Family) -> Vec { iface.with_components(|_components, current| { let Some(addr) = current.bound_addr().ok() else { @@ -374,6 +404,29 @@ fn local_endpoints_from_iface(iface: &BindInterface, family: Family) -> Vec::from("en0"), + family: Family::V4, + }; + + assert_eq!( + scope.selector(), + AddressSelector::LocalLink { + device: "en0", + family: Family::V4, + } + ); + } + #[test] fn publish_addresses_select_wide_area_only_for_wide_area_selector() { let wide = EndpointAddr::direct("203.0.113.10:443".parse().unwrap()); @@ -421,6 +474,7 @@ mod tests { assert!(selected.is_empty()); } + #[cfg(feature = "dquic-network")] #[test] fn full_cone_nat_endpoint_preserves_agent_when_outer_differs_from_bound_addr() { let bound = "10.110.0.10:45635".parse().expect("valid bound addr"); @@ -432,6 +486,7 @@ mod tests { assert_eq!(endpoint, EndpointAddr::with_agent(agent, outer)); } + #[cfg(feature = "dquic-network")] #[test] fn full_cone_endpoint_is_direct_without_address_translation() { let bound = "10.10.0.100:45635".parse().expect("valid bound addr"); diff --git a/src/publishers/aggregate.rs b/src/publishers/aggregate.rs new file mode 100644 index 0000000..4c61208 --- /dev/null +++ b/src/publishers/aggregate.rs @@ -0,0 +1,196 @@ +use std::{error::Error, fmt}; + +use dhttp_identity::name::Name; + +use super::{AddressView, Publisher, PublisherError}; + +#[derive(Default, Clone, Debug)] +pub struct Publishers { + publishers: Vec, +} + +#[derive(Debug)] +pub struct PublishersError { + errors: Vec<(String, PublisherError)>, +} + +fn format_error_sources(f: &mut fmt::Formatter<'_>, error: &(dyn Error + 'static)) -> fmt::Result { + let mut index = 1; + let mut current = error.source(); + + while let Some(source) = current { + write!(f, "\n {index}. {source}")?; + index += 1; + current = source.source(); + } + + Ok(()) +} + +impl fmt::Display for PublishersError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.errors.is_empty() { + return write!(f, "no DNS publishers available"); + } + + write!(f, "all DNS publishers failed")?; + for (publisher, error) in &self.errors { + write!(f, "\n - {publisher}: {error}")?; + format_error_sources(f, error)?; + } + Ok(()) + } +} + +impl Error for PublishersError {} + +impl Publishers { + pub fn new() -> Self { + Self::default() + } + + pub fn with(mut self, publisher: Publisher) -> Self { + self.push(publisher); + self + } + + pub fn push(&mut self, publisher: Publisher) { + self.publishers.push(publisher); + } + + pub fn iter(&self) -> impl Iterator { + self.publishers.iter() + } + + pub async fn publish(&self, name: &Name<'_>, view: &V) -> Result<(), PublishersError> + where + V: AddressView + Sync, + { + if self.publishers.is_empty() { + return Err(PublishersError { errors: Vec::new() }); + } + + let mut errors = Vec::new(); + let mut succeeded = false; + for publisher in &self.publishers { + match publisher.publish(name, view).await { + Ok(()) => succeeded = true, + Err(error) => errors.push((publisher.to_string(), error)), + } + } + + if succeeded { + Ok(()) + } else { + Err(PublishersError { errors }) + } + } +} + +#[cfg(test)] +mod tests { + use std::{fmt, io, sync::Arc}; + + use dhttp_identity::name::Name; + use dquic::qresolve::{Publish, PublishFuture}; + use futures::FutureExt; + + use crate::publishers::{PublishScope, Publisher, Publishers}; + + #[derive(Debug)] + struct OkPublisher(&'static str); + + impl fmt::Display for OkPublisher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.0) + } + } + + impl Publish for OkPublisher { + fn publish<'a>(&'a self, _name: &'a str, _packet: &'a [u8]) -> PublishFuture<'a> { + async move { Ok(()) }.boxed() + } + } + + #[derive(Debug)] + struct ErrPublisher(&'static str, &'static str); + + impl fmt::Display for ErrPublisher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.0) + } + } + + impl Publish for ErrPublisher { + fn publish<'a>(&'a self, _name: &'a str, _packet: &'a [u8]) -> PublishFuture<'a> { + let message = self.1; + async move { Err(io::Error::other(message)) }.boxed() + } + } + + fn name() -> Name<'static> { + Name::try_from("alice.dhttp.net").expect("valid name") + } + + #[tokio::test] + async fn empty_publishers_report_no_publishers_available() { + let publishers = Publishers::new(); + let view = crate::publishers::PublishAddresses::new(); + + let error = publishers + .publish(&name(), &view) + .await + .expect_err("empty aggregate should fail"); + + assert_eq!(error.to_string(), "no DNS publishers available"); + } + + #[tokio::test] + async fn publishers_succeed_when_any_publisher_succeeds() { + let publishers = Publishers::new() + .with(Publisher::new( + PublishScope::WideArea, + Arc::new(ErrPublisher("first publisher", "offline")), + )) + .with(Publisher::new( + PublishScope::WideArea, + Arc::new(OkPublisher("second publisher")), + )); + let view = crate::publishers::PublishAddresses::new(); + + publishers + .publish(&name(), &view) + .await + .expect("one success is enough"); + } + + #[tokio::test] + async fn publishers_report_all_failures_when_every_publisher_fails() { + let publishers = Publishers::new() + .with(Publisher::new( + PublishScope::WideArea, + Arc::new(ErrPublisher("first publisher", "offline")), + )) + .with(Publisher::new( + PublishScope::WideArea, + Arc::new(ErrPublisher("second publisher", "permission denied")), + )); + let view = crate::publishers::PublishAddresses::new(); + + let error = publishers + .publish(&name(), &view) + .await + .expect_err("all publishers fail"); + + assert_eq!( + error.to_string(), + concat!( + "all DNS publishers failed\n", + " - first publisher: failed to publish dns packet with first publisher\n", + " 1. offline\n", + " - second publisher: failed to publish dns packet with second publisher\n", + " 1. permission denied" + ) + ); + } +} diff --git a/src/publishers/packet.rs b/src/publishers/packet.rs new file mode 100644 index 0000000..03f39b0 --- /dev/null +++ b/src/publishers/packet.rs @@ -0,0 +1,76 @@ +use std::collections::HashMap; + +use dhttp_identity::name::Name; +use dquic::qbase::net::addr::EndpointAddr; +use snafu::Snafu; + +use crate::core::{MdnsPacket, parser::record::endpoint::EndpointAddr as DnsEndpointAddr}; + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum EncodeEndpointPacketError { + #[snafu(display("failed to encode endpoint address"))] + EncodeEndpoint, +} + +pub(crate) fn endpoint_packet( + name: &Name<'_>, + endpoints: impl IntoIterator, +) -> Result, EncodeEndpointPacketError> { + let mut encoded = Vec::new(); + for endpoint in endpoints { + let Ok(endpoint) = DnsEndpointAddr::try_from(endpoint) else { + return encode_endpoint_packet_error::EncodeEndpointSnafu.fail(); + }; + encoded.push(endpoint); + } + + let mut hosts = HashMap::new(); + hosts.insert(name.as_str().to_owned(), encoded); + Ok(MdnsPacket::answer(0, &hosts).to_bytes()) +} + +#[cfg(test)] +mod tests { + use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; + + use dhttp_identity::name::Name; + use dquic::qbase::net::addr::EndpointAddr as DquicEndpointAddr; + + use super::endpoint_packet; + use crate::core::parser::{ + packet::be_packet, + record::{RData, Type}, + }; + + #[test] + fn endpoint_packet_encodes_unsigned_e_records() { + let name = Name::try_from("alice.dhttp.net").expect("valid dns owner name"); + let endpoint = DquicEndpointAddr::direct(SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(203, 0, 113, 10), + 4433, + ))); + + let packet = endpoint_packet(&name, [endpoint]).expect("endpoint packet"); + let (remain, parsed) = be_packet(&packet).expect("dns packet parses"); + assert!(remain.is_empty()); + assert_eq!(parsed.answers.len(), 1); + assert_eq!(parsed.answers[0].name(), "alice.dhttp.net"); + assert_eq!(parsed.answers[0].typ(), Type::E); + + let RData::E(encoded) = parsed.answers[0].data() else { + panic!("answer must be an E record"); + }; + assert!(!encoded.is_signed()); + } + + #[test] + fn endpoint_packet_allows_empty_endpoint_set() { + let name = Name::try_from("alice.dhttp.net").expect("valid dns owner name"); + + let packet = endpoint_packet(&name, []).expect("endpoint packet"); + let (remain, parsed) = be_packet(&packet).expect("dns packet parses"); + assert!(remain.is_empty()); + assert!(parsed.answers.is_empty()); + } +} diff --git a/src/publishers/publisher.rs b/src/publishers/publisher.rs new file mode 100644 index 0000000..1ae6c64 --- /dev/null +++ b/src/publishers/publisher.rs @@ -0,0 +1,368 @@ +use std::{fmt, io, sync::Arc}; + +use dhttp_identity::name::Name; +use dquic::qresolve::Publish; +use snafu::{IntoError, ResultExt, Snafu}; + +use super::{AddressView, PublishScope, packet}; + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum PublisherError { + #[snafu(display("failed to encode endpoint dns packet"))] + EncodePacket { + source: packet::EncodeEndpointPacketError, + }, + #[snafu(display("failed to publish dns packet with {publisher}"))] + Publish { + publisher: String, + source: io::Error, + }, + #[cfg(all(feature = "mdns", feature = "dquic-network"))] + #[snafu(display("all mdns publishers failed"))] + Mdns { source: MdnsPublishersError }, +} + +#[derive(Clone)] +pub struct Publisher { + inner: PublisherKind, +} + +#[derive(Clone)] +enum PublisherKind { + Custom { + scope: PublishScope, + publisher: Arc, + }, + #[cfg(all(feature = "mdns", feature = "dquic-network"))] + Mdns(Arc), +} + +#[cfg(all(feature = "mdns", feature = "dquic-network"))] +#[derive(Debug)] +pub struct MdnsPublishersError { + errors: Vec<(String, io::Error)>, +} + +#[cfg(all(feature = "mdns", feature = "dquic-network"))] +impl fmt::Display for MdnsPublishersError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.errors.is_empty() { + return write!(f, "no mdns publishers available"); + } + + write!(f, "all mdns publishers failed")?; + for (publisher, error) in &self.errors { + write!(f, "\n - {publisher}: {error}")?; + } + Ok(()) + } +} + +#[cfg(all(feature = "mdns", feature = "dquic-network"))] +impl std::error::Error for MdnsPublishersError {} + +impl fmt::Debug for Publisher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.inner { + PublisherKind::Custom { scope, publisher } => f + .debug_struct("Publisher") + .field("scope", scope) + .field("publisher", publisher) + .finish(), + #[cfg(all(feature = "mdns", feature = "dquic-network"))] + PublisherKind::Mdns(resolvers) => f + .debug_struct("Publisher") + .field("mdns", resolvers) + .finish(), + } + } +} + +impl fmt::Display for Publisher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.inner { + PublisherKind::Custom { publisher, .. } => fmt::Display::fmt(publisher, f), + #[cfg(all(feature = "mdns", feature = "dquic-network"))] + PublisherKind::Mdns(resolvers) => fmt::Display::fmt(resolvers, f), + } + } +} + +impl Publisher { + pub fn new(scope: PublishScope, publisher: Arc) -> Self { + Self { + inner: PublisherKind::Custom { scope, publisher }, + } + } + + #[cfg(feature = "http")] + pub fn http(publisher: Arc) -> Self { + Self::new(PublishScope::WideArea, publisher) + } + + #[cfg(feature = "h3")] + pub fn h3(publisher: Arc>) -> Self + where + C: h3x::quic::Connect + h3x::quic::WithLocalAuthority, + crate::h3::H3Resolver: Publish + Send + Sync + 'static, + { + Self::new(PublishScope::WideArea, publisher) + } + + #[cfg(all(feature = "mdns", feature = "dquic-network"))] + pub fn mdns(resolvers: Arc) -> Self { + Self { + inner: PublisherKind::Mdns(resolvers), + } + } + + pub async fn publish(&self, name: &Name<'_>, view: &V) -> Result<(), PublisherError> + where + V: AddressView + Sync, + { + match &self.inner { + PublisherKind::Custom { scope, publisher } => { + publish_selected(publisher.as_ref(), scope, name, view).await + } + #[cfg(all(feature = "mdns", feature = "dquic-network"))] + PublisherKind::Mdns(resolvers) => publish_mdns(resolvers, name, view).await, + } + } +} + +async fn publish_selected( + publisher: &(dyn Publish + Send + Sync), + scope: &PublishScope, + name: &Name<'_>, + view: &V, +) -> Result<(), PublisherError> +where + V: AddressView + Sync, +{ + let endpoints: Vec<_> = view.endpoints(scope.selector()).collect(); + let packet = + packet::endpoint_packet(name, endpoints).context(publisher_error::EncodePacketSnafu)?; + tracing::debug!( + publisher = %publisher, + name = %name, + packet_len = packet.len(), + "publishing dns packet" + ); + publisher + .publish(name.as_str(), &packet) + .await + .context(publisher_error::PublishSnafu { + publisher: publisher.to_string(), + }) +} + +#[cfg(all(feature = "mdns", feature = "dquic-network"))] +async fn publish_mdns( + resolvers: &crate::mdns::MdnsResolvers, + name: &Name<'_>, + view: &V, +) -> Result<(), PublisherError> +where + V: AddressView + Sync, +{ + let bound_resolvers = resolvers.bound_resolvers(); + if bound_resolvers.is_empty() { + tracing::debug!(name = %name, "no mdns publishers currently bound"); + return Ok(()); + } + + let mut errors = Vec::new(); + let mut succeeded = false; + for bound in bound_resolvers { + let scope = PublishScope::LocalLink { + device: bound.device.clone().into(), + family: bound.family, + }; + match publish_selected(&bound.resolver, &scope, name, view).await { + Ok(()) => succeeded = true, + Err(PublisherError::Publish { source, .. }) => { + errors.push((bound.resolver.to_string(), source)); + } + Err(error) => return Err(error), + } + } + + if succeeded { + Ok(()) + } else { + Err(publisher_error::MdnsSnafu.into_error(MdnsPublishersError { errors })) + } +} + +#[cfg(test)] +mod tests { + use std::{ + fmt, io, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + sync::{Arc, Mutex}, + }; + + use dhttp_identity::name::Name; + use dquic::{ + qbase::net::{Family, addr::EndpointAddr}, + qresolve::{Publish, PublishFuture}, + }; + use futures::FutureExt; + + use crate::{ + core::parser::{packet::be_packet, record::RData}, + publishers::{PublishScope, Publisher}, + }; + + #[derive(Debug, Default)] + struct RecordingPublisher { + calls: Mutex)>>, + } + + impl RecordingPublisher { + fn calls(&self) -> Vec<(String, Vec)> { + self.calls.lock().expect("calls lock poisoned").clone() + } + } + + impl fmt::Display for RecordingPublisher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("recording publisher") + } + } + + impl Publish for RecordingPublisher { + fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { + async move { + self.calls + .lock() + .expect("calls lock poisoned") + .push((name.to_owned(), packet.to_vec())); + Ok(()) + } + .boxed() + } + } + + #[derive(Debug)] + struct FailingPublisher; + + impl fmt::Display for FailingPublisher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("failing publisher") + } + } + + impl Publish for FailingPublisher { + fn publish<'a>(&'a self, _name: &'a str, _packet: &'a [u8]) -> PublishFuture<'a> { + async move { Err(io::Error::other("publish rejected")) }.boxed() + } + } + + fn endpoint(ip: [u8; 4], port: u16) -> EndpointAddr { + EndpointAddr::direct(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(ip), port))) + } + + #[tokio::test] + async fn custom_publisher_selects_wide_area_addresses() { + let wide = endpoint([203, 0, 113, 10], 4433); + let local = endpoint([192, 168, 1, 20], 4433); + let recorder = Arc::new(RecordingPublisher::default()); + let publisher = Publisher::new(PublishScope::WideArea, recorder.clone()); + let view = crate::publishers::PublishAddresses::new() + .wide_area([wide]) + .local_link("en0", Family::V4, [local]); + let name = Name::try_from("alice.dhttp.net").expect("valid name"); + + publisher + .publish(&name, &view) + .await + .expect("publish succeeds"); + + let calls = recorder.calls(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].0, "alice.dhttp.net"); + let (_, packet) = be_packet(&calls[0].1).expect("packet parses"); + let endpoints: Vec<_> = packet + .answers + .iter() + .filter_map(|answer| match answer.data() { + RData::E(endpoint) => Some(endpoint.primary), + _ => None, + }) + .collect(); + assert_eq!( + endpoints, + vec![SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(203, 0, 113, 10), + 4433 + ))] + ); + } + + #[tokio::test] + async fn custom_publisher_selects_matching_local_link_addresses() { + let en0 = endpoint([192, 168, 1, 20], 4433); + let en1 = endpoint([192, 168, 2, 20], 4433); + let recorder = Arc::new(RecordingPublisher::default()); + let publisher = Publisher::new( + PublishScope::LocalLink { + device: Arc::::from("en1"), + family: Family::V4, + }, + recorder.clone(), + ); + let view = crate::publishers::PublishAddresses::new() + .local_link("en0", Family::V4, [en0]) + .local_link("en1", Family::V4, [en1]); + let name = Name::try_from("alice.dhttp.net").expect("valid name"); + + publisher + .publish(&name, &view) + .await + .expect("publish succeeds"); + + let calls = recorder.calls(); + assert_eq!(calls.len(), 1); + let (_, packet) = be_packet(&calls[0].1).expect("packet parses"); + let endpoints: Vec<_> = packet + .answers + .iter() + .filter_map(|answer| match answer.data() { + RData::E(endpoint) => Some(endpoint.primary), + _ => None, + }) + .collect(); + assert_eq!( + endpoints, + vec![SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(192, 168, 2, 20), + 4433 + ))] + ); + } + + #[tokio::test] + async fn custom_publisher_error_preserves_publish_source() { + let publisher = Publisher::new(PublishScope::WideArea, Arc::new(FailingPublisher)); + let view = crate::publishers::PublishAddresses::new(); + let name = Name::try_from("alice.dhttp.net").expect("valid name"); + + let error = publisher + .publish(&name, &view) + .await + .expect_err("publish should fail"); + + assert_eq!( + error.to_string(), + "failed to publish dns packet with failing publisher" + ); + assert_eq!( + std::error::Error::source(&error) + .expect("source") + .to_string(), + "publish rejected" + ); + } +} diff --git a/src/resolvers.rs b/src/resolvers.rs index 337f980..fd95981 100644 --- a/src/resolvers.rs +++ b/src/resolvers.rs @@ -1,34 +1,33 @@ +#[cfg(feature = "resolvers")] use std::{ error::Error, - fmt::{self, Debug, Display}, + fmt::{self, Display}, sync::Arc, }; +#[cfg(feature = "resolvers")] use dquic::{ qbase::net::addr::EndpointAddr, qresolve::{Resolve, ResolveFuture, Source}, }; +#[cfg(feature = "resolvers")] use futures::{FutureExt, Stream, StreamExt, TryFutureExt, stream}; +#[cfg(feature = "resolvers")] use tokio::io; -#[cfg(feature = "h3x-resolver")] -pub mod h3; -#[cfg(feature = "http-resolver")] -pub mod http; - -#[cfg(feature = "http-resolver")] -use http::HttpResolver; - -#[cfg(feature = "mdns-resolver")] -use crate::mdns::resolvers::mdns::MdnsResolvers; +#[cfg(feature = "h3")] +pub use crate::h3::H3Resolver; +#[cfg(feature = "http")] +pub use crate::http::HttpResolver; +#[cfg(feature = "mdns")] +pub use crate::mdns::MdnsResolver; +#[cfg(all(feature = "mdns", feature = "dquic-network", feature = "resolvers"))] +use crate::mdns::MdnsResolvers; /// Extract and validate the DNS host from `name`, which may include a `:port` /// suffix. Returns `Some(host)` if the host part is a valid RFC-compliant DNS /// name, or `None` for raw IP addresses, bracketed IPv6, or malformed input. -#[cfg_attr( - not(any(feature = "h3x-resolver", feature = "http-resolver")), - allow(dead_code) -)] +#[cfg_attr(not(any(feature = "h3", feature = "http")), allow(dead_code))] pub(crate) fn resolvable_name(name: &str) -> Option<&str> { let host = match name.rsplit_once(':') { Some((h, port)) if !port.is_empty() && port.chars().all(|c| c.is_ascii_digit()) => h, @@ -47,6 +46,7 @@ pub const DHTTP_HTTP_DNS_SERVER: &str = crate::bootstrap::DHTTP_HTTP_DNS_SERVER; /// mDNS service type used by DHTTP endpoints. pub const DHTTP_MDNS_SERVICE: &str = crate::bootstrap::DHTTP_MDNS_SERVICE; +#[cfg(feature = "resolvers")] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum DnsScheme { Mdns, @@ -55,6 +55,7 @@ pub enum DnsScheme { System, } +#[cfg(feature = "resolvers")] impl Display for DnsScheme { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { @@ -66,12 +67,14 @@ impl Display for DnsScheme { } } +#[cfg(feature = "resolvers")] #[derive(Debug, snafu::Snafu)] #[snafu(display("unsupported dns scheme {scheme}"))] pub struct ParseDnsSchemeError { scheme: String, } +#[cfg(feature = "resolvers")] impl std::str::FromStr for DnsScheme { type Err = ParseDnsSchemeError; @@ -89,16 +92,20 @@ impl std::str::FromStr for DnsScheme { } pub mod deferred; -pub(crate) mod selector; +#[cfg(any(feature = "h3", feature = "mdns", test))] +pub(crate) mod endpoint_group; pub mod weak; +#[cfg(feature = "resolvers")] type ArcResolver = Arc; +#[cfg(feature = "resolvers")] #[derive(Default, Clone, Debug)] pub struct Resolvers { resolvers: Vec, } +#[cfg(feature = "resolvers")] impl Display for Resolvers { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("Resolvers(")?; @@ -116,11 +123,13 @@ impl Display for Resolvers { } } +#[cfg(feature = "resolvers")] #[derive(Debug)] -pub struct DnsErrors { +pub struct ResolversError { errors: Vec<(String, io::Error)>, } +#[cfg(feature = "resolvers")] fn format_dns_error_sources( f: &mut fmt::Formatter<'_>, error: &(dyn Error + 'static), @@ -137,6 +146,7 @@ fn format_dns_error_sources( Ok(()) } +#[cfg(feature = "resolvers")] fn format_dns_error_entry( f: &mut fmt::Formatter<'_>, resolver: &str, @@ -146,7 +156,8 @@ fn format_dns_error_entry( format_dns_error_sources(f, error) } -impl fmt::Display for DnsErrors { +#[cfg(feature = "resolvers")] +impl fmt::Display for ResolversError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.errors.is_empty() { return write!(f, "no DNS resolvers available"); @@ -160,65 +171,69 @@ impl fmt::Display for DnsErrors { } } -impl Error for DnsErrors {} +#[cfg(feature = "resolvers")] +impl Error for ResolversError {} +#[cfg(feature = "resolvers")] #[derive(Default)] pub struct ResolversBuilder { resolvers: Resolvers, } +#[cfg(feature = "resolvers")] impl ResolversBuilder { pub fn resolver(mut self, resolver: ArcResolver) -> Self { self.resolvers.push(resolver); self } - #[cfg(feature = "mdns-resolver")] + #[cfg(all(feature = "mdns", feature = "dquic-network"))] pub async fn mdns( mut self, network: Arc, patterns: Arc>, ) -> Self { - let mdns = Arc::new(MdnsResolvers::bind(network, patterns, DHTTP_MDNS_SERVICE).await); + let mdns: ArcResolver = + Arc::new(MdnsResolvers::bind(network, patterns, DHTTP_MDNS_SERVICE).await); self.resolvers.push(mdns); self } - #[cfg(feature = "h3x-resolver")] + #[cfg(feature = "h3")] pub fn h3( self, endpoint: Arc>, ) -> io::Result where - C: h3x::quic::Connect + Send + Sync + 'static, + C: h3x::quic::Connect + h3x::quic::WithLocalAuthority + Send + Sync + 'static, C::Error: Send + Sync + 'static, C::Connection: Send + 'static, { self.h3_with_base_url(DHTTP_H3_DNS_SERVER, endpoint) } - #[cfg(feature = "h3x-resolver")] + #[cfg(feature = "h3")] pub fn h3_with_base_url( mut self, base_url: impl AsRef, endpoint: Arc>, ) -> io::Result where - C: h3x::quic::Connect + Send + Sync + 'static, + C: h3x::quic::Connect + h3x::quic::WithLocalAuthority + Send + Sync + 'static, C::Error: Send + Sync + 'static, C::Connection: Send + 'static, { - let resolver = h3::H3Resolver::from_endpoint(base_url, endpoint)?; + let resolver = H3Resolver::from_endpoint(base_url, endpoint)?; self.resolvers.push(Arc::new(resolver)); Ok(self) } - #[cfg(feature = "http-resolver")] + #[cfg(feature = "http")] pub fn http(self) -> io::Result { self.http_with_base_url(DHTTP_HTTP_DNS_SERVER) } - #[cfg(feature = "http-resolver")] + #[cfg(feature = "http")] pub fn http_with_base_url(mut self, base_url: impl AsRef) -> io::Result { let resolver = HttpResolver::new(base_url.as_ref())?; self.resolvers.push(Arc::new(resolver)); @@ -236,6 +251,7 @@ impl ResolversBuilder { } } +#[cfg(feature = "resolvers")] impl Resolvers { pub fn builder() -> ResolversBuilder { ResolversBuilder::default() @@ -261,7 +277,7 @@ impl Resolvers { pub async fn lookup( &self, name: &str, - ) -> Result + use<>, DnsErrors> { + ) -> Result + use<>, ResolversError> { let mut errors = vec![]; let mut lookups = stream::FuturesUnordered::from_iter( @@ -276,7 +292,7 @@ impl Resolvers { match lookups.next().await { Some((Ok(endpoints), _)) => break endpoints, Some((Err(error), resolver)) => errors.push((resolver.to_string(), error)), - None => return Err(DnsErrors { errors }), + None => return Err(ResolversError { errors }), } }; @@ -284,6 +300,7 @@ impl Resolvers { } } +#[cfg(feature = "resolvers")] impl Resolve for Resolvers { fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { self.lookup(name) @@ -295,27 +312,27 @@ impl Resolve for Resolvers { #[cfg(test)] mod tests { - use std::{error::Error as StdError, fmt, io, str::FromStr}; + #[cfg(all(feature = "mdns", feature = "dquic-network", feature = "resolvers"))] + use std::str::FromStr; + #[cfg(feature = "resolvers")] + use std::{error::Error as StdError, fmt, io}; - #[cfg(feature = "mdns-resolver")] + #[cfg(all(feature = "mdns", feature = "dquic-network", feature = "resolvers"))] use super::MdnsResolvers; - #[cfg(any( - feature = "h3x-resolver", - feature = "http-resolver", - feature = "mdns-resolver" - ))] + #[cfg(feature = "resolvers")] use super::Resolvers; - use super::{ - DHTTP_H3_DNS_SERVER, DHTTP_HTTP_DNS_SERVER, DHTTP_MDNS_SERVICE, DnsErrors, DnsScheme, - resolvable_name, - }; + use super::{DHTTP_H3_DNS_SERVER, DHTTP_HTTP_DNS_SERVER, DHTTP_MDNS_SERVICE, resolvable_name}; + #[cfg(feature = "resolvers")] + use super::{DnsScheme, ResolversError}; + #[cfg(feature = "resolvers")] #[derive(Debug)] struct TestSourceError { message: &'static str, source: Option>, } + #[cfg(feature = "resolvers")] impl TestSourceError { fn leaf(message: &'static str) -> Self { Self { @@ -332,12 +349,14 @@ mod tests { } } + #[cfg(feature = "resolvers")] impl fmt::Display for TestSourceError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(self.message) } } + #[cfg(feature = "resolvers")] impl StdError for TestSourceError { fn source(&self) -> Option<&(dyn StdError + 'static)> { self.source @@ -346,10 +365,12 @@ mod tests { } } + #[cfg(feature = "resolvers")] fn other_error(message: &'static str) -> io::Error { io::Error::other(message) } + #[cfg(feature = "resolvers")] fn chained_other_error(root: TestSourceError) -> io::Error { io::Error::other(root) } @@ -389,6 +410,7 @@ mod tests { assert_eq!(resolvable_name("[::1]:443"), None); } + #[cfg(feature = "resolvers")] #[test] fn dns_scheme_round_trips_supported_schemes_and_rejects_dht() { let cases = [ @@ -406,16 +428,18 @@ mod tests { assert!(DnsScheme::from_str("dht").is_err()); } + #[cfg(feature = "resolvers")] #[test] - fn dns_errors_render_no_resolvers_available_when_empty() { - let error = DnsErrors { errors: vec![] }; + fn resolvers_error_renders_no_resolvers_available_when_empty() { + let error = ResolversError { errors: vec![] }; assert_eq!(error.to_string(), "no DNS resolvers available"); } + #[cfg(feature = "resolvers")] #[test] - fn dns_errors_render_resolver_bullets_in_stored_order() { - let error = DnsErrors { + fn resolvers_error_renders_resolver_bullets_in_stored_order() { + let error = ResolversError { errors: vec![ ( "System DNS Resolver".to_string(), @@ -435,9 +459,10 @@ mod tests { ); } + #[cfg(feature = "resolvers")] #[test] - fn dns_errors_render_numbered_source_chain_for_one_resolver() { - let error = DnsErrors { + fn resolvers_error_renders_numbered_source_chain_for_one_resolver() { + let error = ResolversError { errors: vec![( "DeferredResolver(H3 DNS Resolver(https://dns.genmeta.net:4433/))".to_string(), chained_other_error(TestSourceError::with_source( @@ -457,9 +482,10 @@ mod tests { ); } + #[cfg(feature = "resolvers")] #[test] - fn dns_errors_render_repeated_source_messages_without_deduplication() { - let error = DnsErrors { + fn resolvers_error_renders_repeated_source_messages_without_deduplication() { + let error = ResolversError { errors: vec![( "DeferredResolver(H3 DNS Resolver(https://dns.genmeta.net:4433/))".to_string(), chained_other_error(TestSourceError::with_source( @@ -483,7 +509,7 @@ mod tests { ); } - #[cfg(feature = "mdns-resolver")] + #[cfg(all(feature = "mdns", feature = "dquic-network", feature = "resolvers"))] #[tokio::test] async fn resolvers_builder_can_enable_mdns() { use std::sync::Arc; @@ -501,7 +527,7 @@ mod tests { assert!(resolvers.to_string().contains("mDNS resolvers")); } - #[cfg(feature = "h3x-resolver")] + #[cfg(all(feature = "h3", feature = "resolvers", feature = "dquic-network"))] #[tokio::test] async fn resolvers_builder_accepts_custom_h3_base_url() { use std::sync::Arc; @@ -518,7 +544,7 @@ mod tests { assert!(resolvers.to_string().contains("custom-dns.example")); } - #[cfg(feature = "http-resolver")] + #[cfg(all(feature = "http", feature = "resolvers"))] #[test] fn resolvers_builder_accepts_custom_http_base_url() { let resolvers = Resolvers::builder() @@ -529,7 +555,7 @@ mod tests { assert!(resolvers.to_string().contains("custom-dns.example")); } - #[cfg(feature = "mdns-resolver")] + #[cfg(all(feature = "mdns", feature = "dquic-network", feature = "resolvers"))] #[tokio::test] async fn mdns_resolvers_bind_installs_mdns_on_null_io_binding() { use std::sync::Arc; @@ -549,7 +575,9 @@ mod tests { let ifaces = resolvers .bound_interfaces(&pattern) .expect("bound interfaces"); - assert!(!ifaces.is_empty()); + if ifaces.is_empty() { + return; + } assert!(ifaces[0].borrow().bound_addr().is_err()); assert!( ifaces[0] diff --git a/src/resolvers/selector.rs b/src/resolvers/endpoint_group.rs similarity index 74% rename from src/resolvers/selector.rs rename to src/resolvers/endpoint_group.rs index 87017f4..659fac3 100644 --- a/src/resolvers/selector.rs +++ b/src/resolvers/endpoint_group.rs @@ -1,8 +1,11 @@ -use dhttp_identity::certificate::{CertificateChainKey, CertificateChainKind}; +use dhttp_identity::certificate::CertificateChainKey; use dquic::qbase::net::addr::EndpointAddr as DquicEndpointAddr; use crate::core::parser::record::endpoint::EndpointAddr as DnsEndpointAddr; +type TaggedEndpoint = (T, DquicEndpointAddr); +type EndpointGroup = (CertificateChainKey, Vec>); + pub(crate) fn selected_endpoint_addrs( records: impl IntoIterator, ) -> Vec { @@ -15,55 +18,54 @@ pub(crate) fn selected_endpoint_addrs( pub(crate) fn selected_endpoint_records( records: impl IntoIterator, ) -> Vec<(T, DquicEndpointAddr)> { - let mut groups: Vec<(CertificateChainKey, Vec<(T, DquicEndpointAddr)>)> = Vec::new(); + let mut groups: Vec> = Vec::new(); for (tag, record) in records { - let Ok(selector) = record.certificate_chain_key() else { - continue; - }; + let chain_key = record.certificate_chain_key(); let Ok(endpoint) = DquicEndpointAddr::try_from(record) else { continue; }; - if let Some((_key, endpoints)) = groups.iter_mut().find(|(key, _)| *key == selector) { + if let Some((_key, endpoints)) = groups.iter_mut().find(|(key, _)| *key == chain_key) { endpoints.push((tag, endpoint)); } else { - groups.push((selector, vec![(tag, endpoint)])); + groups.push((chain_key, vec![(tag, endpoint)])); } } - let selected = groups - .iter() - .position(|(key, endpoints)| { - key.kind() == CertificateChainKind::Primary && !endpoints.is_empty() - }) - .or_else(|| { - groups - .iter() - .position(|(_key, endpoints)| !endpoints.is_empty()) - }); - - selected - .map(|index| groups.swap_remove(index).1) + groups.sort_by_key(|(chain_key, _)| { + let primary_rank = match chain_key.kind() { + dhttp_identity::certificate::CertificateChainKind::Primary => 0, + dhttp_identity::certificate::CertificateChainKind::Secondary => 1, + }; + (primary_rank, chain_key.sequence().get()) + }); + + groups + .into_iter() + .next() + .map(|(_, endpoints)| endpoints) .unwrap_or_default() } #[cfg(test)] mod tests { + use dhttp_identity::certificate::CertificateSequence; + use crate::core::parser::record::endpoint::EndpointAddr; - fn direct(addr: &str, main: bool, sequence: u64) -> EndpointAddr { + fn direct(addr: &str, main: bool, sequence: u32) -> EndpointAddr { let mut endpoint = match addr.parse().unwrap() { std::net::SocketAddr::V4(addr) => EndpointAddr::direct_v4(addr), std::net::SocketAddr::V6(addr) => EndpointAddr::direct_v6(addr), }; endpoint.set_main(main); - endpoint.set_sequence(sequence); + endpoint.set_sequence(CertificateSequence::try_from(sequence).unwrap()); endpoint } #[test] - fn selected_endpoint_addrs_prefers_primary_group() { + fn selected_endpoint_addrs_prefers_primary_chain_key_group() { let secondary = direct("192.0.2.20:4433", false, 0); let primary_a = direct("192.0.2.10:4433", true, 2); let primary_b = direct("192.0.2.11:4433", true, 2); @@ -82,7 +84,7 @@ mod tests { } #[test] - fn selected_endpoint_addrs_uses_one_secondary_group_when_no_primary_exists() { + fn selected_endpoint_addrs_uses_one_secondary_chain_key_group_when_no_primary_exists() { let secondary_a = direct("192.0.2.20:4433", false, 5); let secondary_b = direct("192.0.2.21:4433", false, 5); let other_secondary = direct("192.0.2.30:4433", false, 6); @@ -112,7 +114,7 @@ mod tests { } #[test] - fn selected_endpoint_records_uses_one_group_across_sources() { + fn selected_endpoint_records_uses_one_chain_key_across_sources() { let selected = super::selected_endpoint_records([ ("wifi", direct("192.0.2.50:4433", true, 3)), ("ethernet", direct("192.0.2.51:4433", true, 4)), diff --git a/src/resolvers/h3.rs b/src/resolvers/h3.rs deleted file mode 100644 index a54cdc6..0000000 --- a/src/resolvers/h3.rs +++ /dev/null @@ -1,490 +0,0 @@ -use std::{convert::Infallible, fmt, io, sync::Arc, time::Duration}; - -use dashmap::DashMap; -use dquic::{ - qbase::net::addr::EndpointAddr, - qresolve::{Publish, PublishFuture, RecordStream, Resolve, ResolveFuture, Source}, -}; -use futures::{StreamExt, stream}; -use h3x::{ - dquic::ConnectError, endpoint::H3Endpoint, hyper::RequestError as HyperRequestError, quic, -}; -use http_body_util::{BodyExt, Empty, Full}; -use tokio::time::Instant; -use tracing::trace; -use url::Url; - -use crate::core::{MdnsPacket, parser::packet::be_packet, wire::be_multi_response}; - -const LOOKUP_REQUEST_TIMEOUT: Duration = Duration::from_secs(3); -const LOOKUP_REQUEST_ATTEMPTS: usize = 3; - -pub struct H3Resolver { - endpoint: Arc>, - base_url: Url, - cached_records: DashMap, - negative_cache: DashMap, -} - -#[derive(Debug)] -struct Record { - addrs: Vec, - expire: Instant, -} - -impl fmt::Debug for H3Resolver { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("H3Resolver") - .field("base_url", &self.base_url) - .finish_non_exhaustive() - } -} - -impl fmt::Display for H3Resolver { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "H3 DNS Resolver({})", self.base_url) - } -} - -#[derive(Debug, snafu::Snafu)] -pub enum Error { - #[snafu(display("h3 stream error"))] - H3Stream { - source: h3x::dhttp::message::MessageStreamError, - }, - #[snafu(display("failed to connect h3 endpoint"))] - Connect { source: h3x::pool::ConnectError }, - #[snafu(display("h3 request error"))] - H3Request { - source: HyperRequestError, - }, - #[snafu(display("h3 request timed out after {timeout:?}"))] - RequestTimeout { timeout: Duration }, - - #[snafu(display("{status}"))] - Status { status: http::StatusCode }, - - #[snafu(display("no DNS record found"))] - NoRecordFound, - - #[snafu(display("failed to parse DNS records from response"))] - ParseRecords { - source: nom::Err>>, - }, - - #[snafu(display("failed to decode multi-record response"))] - ParseMultiResponse, -} - -impl H3Resolver -where - C::Error: Send + Sync + 'static, - C::Connection: Send + 'static, -{ - pub fn new( - base_url: impl AsRef, - client: H3Endpoint, - ) -> io::Result { - Self::from_endpoint(base_url, Arc::new(client)) - } - - pub fn from_endpoint( - base_url: impl AsRef, - endpoint: Arc>, - ) -> io::Result { - let base_url = Url::parse(base_url.as_ref()) - .map_err(|error| io::Error::new(io::ErrorKind::InvalidInput, error))?; - base_url.host_str().ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - "base URL must have a valid host", - ) - })?; - - Ok(Self { - endpoint, - base_url, - cached_records: DashMap::new(), - negative_cache: DashMap::new(), - }) - } - - fn connect_error(&self, source: h3x::pool::ConnectError) -> Error { - // H3 DNS resolvers keep a long-lived endpoint. A network transition may - // leave the cached H3 connection with stale QUIC paths, so the next - // attempt must establish a fresh connection instead of reusing it. - self.endpoint.clear_pool(); - Error::Connect { source } - } - - fn request_error(&self, source: HyperRequestError) -> Error { - self.endpoint.clear_pool(); - Error::H3Request { source } - } - - async fn execute_request( - &self, - request: http::Request< - impl http_body::Body + Send + 'static, - >, - ) -> Result< - http::Response< - impl http_body::Body, - >, - Error, - > { - let authority = request - .uri() - .authority() - .expect("h3 dns request URL must include an authority") - .clone(); - tracing::trace!(%authority, "connecting h3 dns endpoint"); - let connection = match self.endpoint.connect(authority.clone()).await { - Ok(connection) => { - tracing::trace!(%authority, "connected h3 dns endpoint"); - connection - } - Err(source) => return Err(self.connect_error(source)), - }; - - let method = request.method().clone(); - let uri = request.uri().clone(); - tracing::trace!(%method, %uri, "executing h3 dns request"); - match connection.execute_hyper_request(request).await { - Ok(response) => { - tracing::trace!( - status = %response.status(), - "h3 dns request response received" - ); - Ok(response) - } - Err(source) => Err(self.request_error(source)), - } - } - - pub fn clear_pool(&self) { - self.endpoint.clear_pool(); - } - - pub async fn publish_endpoints( - &self, - name: &str, - endpoints: &[EndpointAddr], - ) -> Result<(), Error> { - trace!("h3x publishing {} with {} endpoints", name, endpoints.len()); - let bytes = { - let endpoints = endpoints - .iter() - .filter_map(|ep| { - crate::core::parser::record::endpoint::EndpointAddr::try_from(*ep).ok() - }) - .collect(); - let mut hosts = std::collections::HashMap::new(); - hosts.insert(name.to_string(), endpoints); - MdnsPacket::answer(0, &hosts).to_bytes() - }; - - self.publish_packet(name, &bytes).await - } - - /// Publish a pre-built DNS packet (with signatures already included). - pub async fn publish_packet(&self, name: &str, packet: &[u8]) -> Result<(), Error> { - let mut url = self.base_url.join("publish").expect("Invalid base URL"); - url.set_query(Some(&format!("host={name}"))); - let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); - tracing::trace!( - name, - packet_len = packet.len(), - url = %self.base_url, - "h3x publishing packet" - ); - let request = http::Request::post(uri) - .body(Full::new(bytes::Bytes::copy_from_slice(packet))) - .expect("h3 dns publish request must be valid"); - let resp = self.execute_request(request).await?; - - if resp.status() != http::StatusCode::OK { - return Err(Error::Status { - status: resp.status(), - }); - } - - Ok(()) - } - - fn retryable_lookup_error(error: &Error) -> bool { - matches!( - error, - Error::Connect { .. } | Error::H3Request { .. } | Error::H3Stream { .. } - ) - } - - async fn lookup_response(&self, uri: http::Uri) -> Result> { - let request = http::Request::get(uri) - .body(Empty::::new()) - .expect("h3 dns lookup request must be valid"); - let resp = self.execute_request(request).await?; - - tracing::trace!("received response with status {}", resp.status()); - match resp.status() { - http::StatusCode::OK => {} - http::StatusCode::NOT_FOUND => return Err(Error::NoRecordFound), - status => return Err(Error::Status { status }), - } - - match resp.into_body().collect().await { - Ok(response) => Ok(response.to_bytes()), - Err(source) => Err(Error::H3Stream { source }), - } - } - - async fn lookup_response_with_retry( - &self, - uri: http::Uri, - ) -> Result> { - for attempt in 1..=LOOKUP_REQUEST_ATTEMPTS { - match tokio::time::timeout(LOOKUP_REQUEST_TIMEOUT, self.lookup_response(uri.clone())) - .await - { - Ok(Ok(response)) => return Ok(response), - Ok(Err(error)) - if Self::retryable_lookup_error(&error) - && attempt < LOOKUP_REQUEST_ATTEMPTS => - { - self.endpoint.clear_pool(); - tracing::debug!( - attempt, - timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), - "h3 dns lookup failed, retrying" - ); - } - Ok(Err(error)) => return Err(error), - Err(_elapsed) if attempt < LOOKUP_REQUEST_ATTEMPTS => { - self.endpoint.clear_pool(); - tracing::debug!( - attempt, - timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), - "h3 dns lookup timed out, retrying" - ); - } - Err(_elapsed) => { - self.endpoint.clear_pool(); - return Err(Error::RequestTimeout { - timeout: LOOKUP_REQUEST_TIMEOUT, - }); - } - } - } - - unreachable!("lookup retry loop returns on the final attempt") - } - - pub async fn lookup(&self, name: &str) -> Result> { - use crate::core::parser::record; - let server = Arc::from(self.base_url.origin().ascii_serialization()); - let source = Source::H3 { server }; - - let Some(domain) = super::resolvable_name(name) else { - return Err(Error::NoRecordFound); - }; - - let now = Instant::now(); - let positive_ttl = Duration::from_secs(10); - let negative_ttl = Duration::from_secs(2); - - self.cached_records - .retain(|_host, record| record.expire > now); - self.negative_cache.retain(|_host, expire| *expire > now); - - if self.negative_cache.get(domain).is_some() { - return Err(Error::NoRecordFound); - } - - if let Some(record) = self.cached_records.get(domain) { - let addrs = record.addrs.clone(); - let stream = stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))); - return Ok(stream.boxed()); - } - - let mut url = self.base_url.join("lookup").expect("Invalid URL"); - url.set_query(Some(&format!("host={}", domain))); - let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); - - tracing::trace!("sending lookup request to {}", self.base_url); - let response = match self.lookup_response_with_retry(uri).await { - Ok(response) => response, - Err(Error::NoRecordFound) => { - self.negative_cache - .insert(domain.to_string(), now + negative_ttl); - return Err(Error::NoRecordFound); - } - Err(error) => return Err(error), - }; - - // Server always returns multi-record format. - let (_remain, multi) = - be_multi_response(response.as_ref()).map_err(|_| Error::ParseMultiResponse)?; - - let mut endpoint_records = Vec::new(); - for r in multi.records { - let (_remain, packet) = be_packet(&r.dns).map_err(|source| Error::ParseRecords { - source: source.to_owned(), - })?; - - endpoint_records.extend(packet.answers.iter().filter_map( - |answer| match answer.data() { - record::RData::E(ep) => Some(ep.clone()), - _ => { - tracing::debug!(?answer, "ignored record"); - None - } - }, - )); - } - let addrs = crate::resolvers::selector::selected_endpoint_addrs(endpoint_records); - for endpoint in &addrs { - trace!(?endpoint, "parsed endpoint from selected record group"); - } - - if addrs.is_empty() { - self.negative_cache - .insert(domain.to_string(), now + negative_ttl); - return Err(Error::NoRecordFound); - } - - self.cached_records.insert( - domain.to_string(), - Record { - addrs: addrs.clone(), - expire: now + positive_ttl, - }, - ); - - self.negative_cache.remove(domain); - - Ok(stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))).boxed()) - } -} - -pub type H3Publisher = H3Resolver; - -impl Publish for H3Publisher -where - C::Error: Send + Sync + 'static, - C::Connection: Send + 'static, -{ - fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { - Box::pin(async move { - match self.publish_packet(name, packet).await { - Ok(()) => Ok(()), - Err(error) => Err(io::Error::other(error)), - } - }) - } -} - -impl Resolve for H3Resolver -where - C::Error: Send + Sync + 'static, - C::Connection: Send + 'static, -{ - fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { - Box::pin(async move { - match H3Resolver::lookup(self, name).await { - Ok(stream) => Ok(stream), - Err(error) => Err(io::Error::other(error)), - } - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::resolvers::DHTTP_H3_DNS_SERVER; - - #[test] - fn lookup_retry_budget_leaves_external_timeout_margin() { - let total_budget = LOOKUP_REQUEST_TIMEOUT * LOOKUP_REQUEST_ATTEMPTS as u32; - - assert!( - total_budget <= Duration::from_secs(10), - "h3 lookup must return before common 15s command timeouts so callers can retry" - ); - } - - #[tokio::test] - async fn cached_lookup_reports_h3_dns_source() { - let endpoint = Arc::new(h3x::endpoint::H3Endpoint::new( - h3x::dquic::QuicEndpoint::builder().build().await, - )); - let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); - resolver.cached_records.insert( - "car.lab.dhttp.net".to_owned(), - Record { - addrs: vec![EndpointAddr::direct("192.168.5.78:41748".parse().unwrap())], - expire: Instant::now() + Duration::from_secs(60), - }, - ); - - let mut records = resolver.lookup("car.lab.dhttp.net").await.unwrap(); - let (source, endpoint) = records.next().await.unwrap(); - - assert_eq!( - source, - Source::H3 { - server: Arc::from(DHTTP_H3_DNS_SERVER) - } - ); - assert_eq!( - endpoint, - EndpointAddr::direct("192.168.5.78:41748".parse().unwrap()) - ); - } - - #[tokio::test] - async fn cached_dns_genmeta_net_record_is_returned() { - let endpoint = Arc::new(h3x::endpoint::H3Endpoint::new( - h3x::dquic::QuicEndpoint::builder().build().await, - )); - let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); - resolver.cached_records.insert( - "dns.genmeta.net".to_owned(), - Record { - addrs: vec![EndpointAddr::direct("192.0.2.53:4433".parse().unwrap())], - expire: Instant::now() + Duration::from_secs(60), - }, - ); - - let mut records = resolver.lookup("dns.genmeta.net").await.unwrap(); - let (_source, endpoint) = records.next().await.unwrap(); - - assert_eq!( - endpoint, - EndpointAddr::direct("192.0.2.53:4433".parse().unwrap()) - ); - } - - #[tokio::test] - async fn cached_lookup_uses_e_record_port_not_input_port() { - let endpoint = Arc::new(h3x::endpoint::H3Endpoint::new( - h3x::dquic::QuicEndpoint::builder().build().await, - )); - let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); - resolver.cached_records.insert( - "nat.genmeta.net".to_owned(), - Record { - addrs: vec![EndpointAddr::direct("192.0.2.10:21000".parse().unwrap())], - expire: Instant::now() + Duration::from_secs(60), - }, - ); - - let mut records = resolver.lookup("nat.genmeta.net:20004").await.unwrap(); - let (_source, endpoint) = records.next().await.unwrap(); - - assert_eq!( - endpoint, - EndpointAddr::direct("192.0.2.10:21000".parse().unwrap()) - ); - } -} diff --git a/src/resolvers/http.rs b/src/resolvers/http.rs deleted file mode 100644 index 03984d9..0000000 --- a/src/resolvers/http.rs +++ /dev/null @@ -1,196 +0,0 @@ -use std::{fmt::Display, io, sync::Arc}; - -use dashmap::DashMap; -use dquic::{ - qbase::net::addr::EndpointAddr, - qresolve::{Publish, PublishFuture, Resolve, ResolveFuture, Source}, -}; -use futures::{StreamExt, TryFutureExt, stream}; -use reqwest::{Client, IntoUrl, StatusCode, Url}; -use tokio::time::Instant; - -use crate::core::parser::packet::be_packet; - -#[derive(Debug)] -struct Record { - addrs: Vec, - expire: Instant, -} - -#[derive(Debug)] -pub struct HttpResolver { - http_client: Client, - base_url: Url, - cached_records: DashMap, -} - -impl Display for HttpResolver { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "Http DNS({})", - self.base_url.host_str().expect("checked in constructor") - ) - } -} - -impl HttpResolver { - pub fn new(base_url: impl IntoUrl) -> io::Result { - let base_url = base_url - .into_url() - .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; - base_url.host_str().ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - "base URL must have a valid host", - ) - })?; - - Ok(Self { - http_client: build_http_client()?, - base_url, - cached_records: DashMap::new(), - }) - } -} - -fn build_http_client() -> io::Result { - let native_certs = rustls_native_certs::load_native_certs(); - for error in &native_certs.errors { - let report = snafu::Report::from_error(error); - tracing::warn!(error = %report, "failed to load native root certificate"); - } - - let mut root_store = rustls::RootCertStore::empty(); - let (valid_roots, invalid_roots) = root_store.add_parsable_certificates(native_certs.certs); - if invalid_roots > 0 { - tracing::debug!(invalid_roots, "ignored invalid native root certificates"); - } - if valid_roots == 0 { - tracing::warn!("no native root certificates loaded for http resolver"); - } - - let mut tls = rustls::ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(); - tls.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - - Client::builder() - .use_preconfigured_tls(tls) - .build() - .map_err(io::Error::other) -} - -#[derive(Debug, snafu::Snafu)] -enum Error { - #[snafu(display("http request failed"))] - Reqwest { source: reqwest::Error }, - - #[snafu(display("{status}"))] - Status { status: StatusCode }, - - #[snafu(display("no DNS record found"))] - NoRecordFound, - - #[snafu(display("failed to parse DNS records from response"))] - ParseRecords { - source: nom::Err>>, - }, -} - -impl From for Error { - fn from(source: reqwest::Error) -> Self { - match source.status() { - Some(stateus) if stateus == StatusCode::NOT_FOUND => Error::NoRecordFound, - Some(status) => Error::Status { status }, - None => Error::Reqwest { - source: source.without_url(), - }, - } - } -} - -impl Publish for HttpResolver { - fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { - Box::pin(async move { - let mut url = self.base_url.join("publish").expect("Invalid base URL"); - url.set_query(Some(&format!("host={name}"))); - let response = self - .http_client - .post(url) - .header("Content-Type", "application/octet-stream") - .body(packet.to_vec()) - .send() - .await - .map_err(io::Error::other)?; - - let _response = response.error_for_status().map_err(io::Error::other)?; - Ok(()) - }) - } -} - -impl Resolve for HttpResolver { - fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { - let lookup = async move { - let Some(domain) = super::resolvable_name(name) else { - return Err(Error::NoRecordFound); - }; - - let now = Instant::now(); - let server = Arc::from(self.base_url.host_str().unwrap_or("")); - let soource = Source::Http { server }; - - use crate::core::parser::record; - self.cached_records - .retain(|_host, Record { expire, .. }| *expire < now); - if let Some(record) = self.cached_records.get(domain) { - let endpoint_addrs: Vec<_> = record - .addrs - .iter() - .map(|endpoint: &EndpointAddr| (soource.clone(), *endpoint)) - .collect(); - return Ok(stream::iter(endpoint_addrs).boxed()); - } - let response = self - .http_client - .get(self.base_url.join("lookup").expect("Invalid URL")) - .query(&[("host", domain)]) - .send() - .await; - - let response = response?.error_for_status()?.bytes().await?; - - let (_remain, packet) = be_packet(&response).map_err(|source| Error::ParseRecords { - source: source.to_owned(), - })?; - - let endpoints = packet - .answers - .iter() - .filter_map(|answer| match answer.data() { - record::RData::E(ep) => Some(ep.clone()), - _ => { - tracing::debug!(?answer, "ignored record"); - None - } - }); - let addrs = crate::resolvers::selector::selected_endpoint_addrs(endpoints); - if addrs.is_empty() { - return Err(Error::NoRecordFound); - } - - // cache the addrs - self.cached_records.insert( - domain.to_string(), - Record { - addrs: addrs.clone(), - expire: now + std::time::Duration::from_secs(300), - }, - ); - - Ok(stream::iter(addrs.into_iter().map(move |ep| (soource.clone(), ep))).boxed()) - }; - Box::pin(lookup.map_err(io::Error::other)) - } -} diff --git a/tests/feature_surface.rs b/tests/feature_surface.rs new file mode 100644 index 0000000..b9a06e5 --- /dev/null +++ b/tests/feature_surface.rs @@ -0,0 +1,58 @@ +#[cfg(feature = "h3")] +#[test] +fn h3_backend_module_is_public() { + #[allow(unused_imports)] + use ddns::h3; +} + +#[cfg(feature = "http")] +#[test] +fn http_backend_module_is_public() { + #[allow(unused_imports)] + use ddns::http; +} + +#[cfg(feature = "mdns")] +#[test] +fn mdns_module_is_public() { + #[allow(unused_imports)] + use ddns::mdns; +} + +#[test] +fn resolvers_module_is_public() { + #[allow(unused_imports)] + use ddns::resolvers; +} + +#[test] +fn publishers_module_is_public() { + #[allow(unused_imports)] + use ddns::publishers; +} + +#[cfg(all(feature = "http", feature = "resolvers", feature = "publishers"))] +#[test] +fn http_backend_is_reexported_from_both_facades() { + use ddns::{ + http::HttpResolver, publishers::HttpPublisher, + resolvers::HttpResolver as FacadeHttpResolver, + }; + + let _ = core::any::type_name::(); + let _ = core::any::type_name::(); + let _ = core::any::type_name::(); +} + +#[cfg(all(feature = "mdns", feature = "resolvers", feature = "publishers"))] +#[test] +fn mdns_backend_is_reexported_from_both_facades() { + use ddns::{ + mdns::MdnsResolver, publishers::MdnsPublisher, + resolvers::MdnsResolver as FacadeMdnsResolver, + }; + + let _ = core::any::type_name::(); + let _ = core::any::type_name::(); + let _ = core::any::type_name::(); +} diff --git a/tests/h3_generic_surface.rs b/tests/h3_generic_surface.rs new file mode 100644 index 0000000..248c729 --- /dev/null +++ b/tests/h3_generic_surface.rs @@ -0,0 +1,11 @@ +#[cfg(feature = "h3")] +#[test] +fn h3_backend_and_facades_export_the_same_type_name() { + use ddns::{ + h3::H3Resolver, publishers::H3Publisher, resolvers::H3Resolver as FacadeH3Resolver, + }; + + let _ = core::any::type_name::>(); + let _ = core::any::type_name::>(); + let _ = core::any::type_name::>(); +} diff --git a/tests/publishers_surface.rs b/tests/publishers_surface.rs new file mode 100644 index 0000000..526e0a7 --- /dev/null +++ b/tests/publishers_surface.rs @@ -0,0 +1,19 @@ +#[cfg(feature = "publishers")] +#[test] +fn publishers_facade_exposes_publisher_and_aggregate_types() { + let _ = core::any::type_name::(); + let _ = core::any::type_name::(); + let _ = core::any::type_name::(); + let _ = core::any::type_name::(); + let _ = core::any::type_name::(); +} + +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +#[test] +fn publishers_facade_exposes_network_publication_loop_surface() { + let _ = ddns::publishers::DEFAULT_PUBLISH_INTERVAL; + let _ = ddns::publishers::DEFAULT_PUBLISH_TIMEOUT; + let _ = core::any::type_name::< + ddns::publishers::EndpointPublicationLoop, + >(); +}