diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index b4ef695..0000000 Binary files a/.DS_Store and /dev/null differ diff --git a/.github/workflows/publish-crates.yml b/.github/workflows/publish-crates.yml new file mode 100644 index 0000000..bfdffde --- /dev/null +++ b/.github/workflows/publish-crates.yml @@ -0,0 +1,105 @@ +name: Publish crates.io + +on: + pull_request: + workflow_dispatch: + push: + branches: + - main + tags: + - "v*" + +env: + CARGO_TERM_COLOR: always + +jobs: + release: + runs-on: ubuntu-latest + permissions: + contents: read + id-token: write + + steps: + - uses: actions/checkout@v4 + + - name: Install Rust stable toolchain + uses: actions-rust-lang/setup-rust-toolchain@v1 + + - name: Test crate + run: cargo test --all-features --all-targets + + - name: Authenticate to crates.io + if: github.ref_type == 'tag' && startsWith(github.ref_name, 'v') + uses: rust-lang/crates-io-auth-action@v1 + id: auth + + - name: Release dyns crate + shell: bash + env: + CARGO_REGISTRY_TOKEN: ${{ steps.auth.outputs.token }} + run: | + set -euo pipefail + + if [[ "${GITHUB_REF_TYPE}" == "tag" && "${GITHUB_REF_NAME}" == v* ]]; then + mode=publish + else + mode=dry-run + fi + + package_name=dyns + package_version="$(cargo metadata --no-deps --format-version 1 | python3 -c 'import json, sys; print(json.load(sys.stdin)["packages"][0]["version"])')" + + crate_state="$( + python3 - <<'PY' "$package_name" "$package_version" + import sys + import urllib.error + import urllib.request + + name, version = sys.argv[1], sys.argv[2] + headers = {"User-Agent": "genmeta ddns publish workflow"} + version_url = f"https://crates.io/api/v1/crates/{name}/{version}" + version_request = urllib.request.Request(version_url, headers=headers) + try: + with urllib.request.urlopen(version_request, timeout=20) as response: + if response.status == 200: + print("published_version") + else: + raise SystemExit(f"unexpected crates.io status for {name} {version}: {response.status}") + except urllib.error.HTTPError as error: + if error.code == 404: + crate_url = f"https://crates.io/api/v1/crates/{name}" + crate_request = urllib.request.Request(crate_url, headers=headers) + try: + with urllib.request.urlopen(crate_request, timeout=20) as response: + if response.status == 200: + print("missing_version") + else: + raise SystemExit(f"unexpected crates.io crate status for {name}: {response.status}") + except urllib.error.HTTPError as crate_error: + if crate_error.code == 404: + print("missing_crate") + else: + raise + else: + raise + PY + )" + + if [[ "$crate_state" == "published_version" ]]; then + echo "skip $package_name $package_version (already on crates.io)" + exit 0 + fi + + if [[ "$crate_state" == "missing_crate" ]]; then + echo "skip $package_name $package_version (crate not yet initialized on crates.io)" + exit 0 + fi + + if [[ "$mode" == "dry-run" ]]; then + echo "dry-run $package_name $package_version" + cargo publish --dry-run --locked + exit 0 + fi + + echo "publish $package_name $package_version" + cargo publish --locked diff --git a/.gitignore b/.gitignore index 1daef7f..57506c5 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,6 @@ Cargo.lock *.log build + +.DS_Store +.vscode/ diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 2afd034..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "lldb.showDisassembly": "auto", - "lldb.dereferencePointers": true, - "lldb.consoleMode": "commands" -} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 89332b4..b84ddea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,40 +1,39 @@ -[workspace] -members = ["gmdns-server"] -resolver = "2" - [package] -name = "gmdns" -version = "0.2.0" +name = "dyns" +description = "DNS discovery and resolver support for DHTTP applications" +version = "0.3.0" edition = "2024" +license = "Apache-2.0" +repository = "https://github.com/genmeta/ddns" +readme = "README.md" +keywords = ["dhttp", "dns", "mdns", "http3", "quic"] +categories = ["network-programming", "asynchronous"] autoexamples = false +[lib] +name = "ddns" + [dependencies] base64 = "0.22" -bitfield-struct = "0.10" +bitfield-struct = "0.13" bytes = "1" dashmap = "6" +dhttp-identity = "0.1.0" +dquic = "0.5.1" flume = "0.12" futures = "0.3" libc = "0.2" nom = "8" -rand = "0.9" -reqwest = { version = "0.12", default-features = false, features = [ - "charset", - "rustls-tls", - "http2", - "macos-system-configuration", - "json", -] } +rand = "0.10" ring = "0.17" rustls = { version = "0.23", default-features = false, features = [ "logging", "ring", ] } +rustls-native-certs = { version = "0.8", optional = true } rustls-pemfile = "2" -serde = "1" -shellexpand = "3" -snafu = "0.8" -socket2 = { version = "0.5.8", features = ["all"] } +snafu = "0.9" +socket2 = { version = "0.6", features = ["all"] } tokio = { version = "1", features = [ "time", "macros", @@ -42,31 +41,71 @@ tokio = { version = "1", features = [ "sync", "rt", "rt-multi-thread", + "io-util", ] } -tokio-util = { version = "0.7", features = ["rt"] } tracing = "0.1" -url = "2" x509-parser = "0.18" -# Optional HTTP/3 publisher/resolver via h3x -h3x = { git = "https://github.com/genmeta/h3x.git", branch = "main", default-features = false, features = [ - "dquic", -], optional = true } +h3x = { version = "0.3.1", default-features = false, optional = true } http = { version = "1", optional = true } +http-body = { version = "1", optional = true } +http-body-util = { version = "0.1", optional = true } +reqwest = { version = "0.13", default-features = false, features = [ + "charset", + "http2", + "json", + "query", + "rustls-no-provider", + "system-proxy", +], optional = true } +url = { version = "2", optional = true } + +clap = { version = "4", features = ["derive"], optional = true } +deadpool-redis = { version = "0.23", optional = true } +idna = { version = "1", optional = true } +serde = { version = "1", features = ["derive"], optional = true } +toml = { version = "1", optional = true } +tower-service = { version = "0.3", optional = true } +tracing-subscriber = { version = "0.3", features = [ + "env-filter", +], optional = true } [features] -default = ["h3x-resolver"] -h3x-resolver = ["dep:h3x", "dep:http"] +default = [] +h3x-resolver = [ + "dep:h3x", + "h3x/dquic", + "h3x/hyper", + "dep:http", + "dep:http-body", + "dep:http-body-util", + "dep:url", +] +mdns-resolver = ["dep:h3x", "h3x/dquic"] +http-resolver = ["dep:reqwest", "dep:rustls-native-certs"] +server = [ + "h3x-resolver", + "dep:clap", + "dep:deadpool-redis", + "dep:idna", + "dep:serde", + "dep:toml", + "dep:tower-service", + "dep:tracing-subscriber", +] [dev-dependencies] -criterion = "0.5" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } -rustls-pki-types = "1" -tracing-appender = "0.2" -# examples: publish / query clap = { version = "4", features = ["derive"] } -idna = "1" -serde = { version = "1", features = ["derive"] } +h3x = { version = "0.3.1", default-features = false, features = [ + "dquic", +] } +shellexpand = "3" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +[[bin]] +name = "ddns-server" +path = "src/bin/ddns-server/main.rs" +required-features = ["server"] [[example]] name = "mdns_discover" diff --git a/README.md b/README.md index 1f50207..071aaaa 100644 --- a/README.md +++ b/README.md @@ -1,208 +1,228 @@ -# GMDNS +# DDNS -GMDNS is a high-performance mDNS (Multicast DNS) protocol library built with Rust, specifically designed for P2P network discovery and NAT traversal scenarios. It supports the standard RFC 6762 protocol while extending endpoint discovery capabilities through custom resource records, enabling publication and verification of both direct and relay addresses. Additionally, it integrates HTTP/3 support for secure DNS over HTTP/3 (DoH3) interactions with remote DNS servers. +`ddns` provides DNS discovery and resolver support for DHTTP applications. It is a +single Rust package: the historical `ddns-core`, `gmdns`, `ddns`, and +`ddns-server` crate boundaries now live as modules and feature-gated targets in +one published Cargo package named `dyns`, with a library target kept as `ddns` +for source compatibility. -## 🌟 Key Features +## Crate layout -- **Standards Compliant**: Supports standard DNS packet format and mDNS multicast discovery. -- **P2P Enhanced**: Custom `E` record type supporting IPv4/IPv6 direct and relay addresses. -- **Security Verification**: Built-in signature schemes (Ed25519, etc.) ensuring endpoint data authenticity and integrity. -- **High Performance Parsing**: Zero-copy parsing framework based on `nom` for blazing-fast packet processing. -- **Async-Driven**: Fully compatible with `tokio` async runtime for high-concurrency network environments. -- **HTTP/3 Integration**: Supports DNS over HTTP/3 (DoH3) for secure remote DNS queries and publishing. +| Module / target | Role | +| --- | --- | +| `ddns::core` | DNS packet parser, resource-record types, endpoint `E` record encoding, and HTTP multi-record response wire format. | +| `ddns::mdns` | RFC 6762 multicast DNS transport, LAN publisher, and LAN resolver support. | +| `ddns::resolvers` | Resolver chain plus optional System, mDNS, DNS-over-H3, and DNS-over-HTTP resolvers. | +| `ddns::publisher` | Feature-gated endpoint record signing and publishing loop helpers for DHTTP endpoints. | +| `ddns-server` | DNS-over-H3 publish/lookup server binary, enabled by the `server` feature. | -## 🚀 Quick Start - -Add to your `Cargo.toml`: +`ddns` is endpoint-facing support code for the DHTTP ecosystem. Applications +normally reach it through the `dhttp` endpoint facade; lower-level consumers can +depend on package `dyns` directly (typically renamed locally to `ddns`) when +they need DNS wire types, resolver composition, mDNS, or the DNS-over-H3 server. ```toml -[dependencies] -gmdns = { path = "../gmdns" } +ddns = { package = "dyns", version = "0.3.0" } ``` -For HTTP/3 features, enable the `h3x-resolver` feature: +## Features -```toml -[dependencies] -gmdns = { path = "../gmdns", features = ["h3x-resolver"] } -``` +All optional integrations are feature-gated; the default feature set is empty. + +| Feature | Enables | +| --- | --- | +| `h3x-resolver` | DNS-over-H3 resolver and publisher using `h3x`/`dquic`. | +| `mdns-resolver` | mDNS resolver integration backed by an existing `h3x::dquic::Network`. | +| `http-resolver` | DNS-over-HTTP resolver/publisher using `reqwest` and native roots. | +| `server` | `ddns-server`, Redis storage support, TOML config parsing, and tracing setup. | + +## Bootstrap constants + +`build.rs` generates the resolver defaults exposed from `ddns::resolvers`: + +| Environment variable | Public constant | Fallback when unset | +| --- | --- | --- | +| `DHTTP_H3_DNS_SERVER` | `DHTTP_H3_DNS_SERVER` | `https://dhttp.example.net` | +| `DHTTP_HTTP_DNS_SERVER` | `DHTTP_HTTP_DNS_SERVER` | `https://dhttp.example.net` | +| `DHTTP_MDNS_SERVICE` | `DHTTP_MDNS_SERVICE` | `dhttp.example.net` | + +The fallbacks are docs/build placeholders, not operational defaults. Real +endpoint, server, and E2E runs should set the DHTTP bootstrap environment before +building. + +## Quick start -### Simple mDNS Discovery Example +### Resolver chain + +`Resolvers` queries all configured resolvers and streams endpoint addresses from +successful backends. System DNS is always available; mDNS, H3, and HTTP builders +appear behind their features. ```rust -use gmdns::mdns::Mdns; +use ddns::resolvers::Resolvers; use futures::StreamExt; #[tokio::main] -async fn main() -> Result<(), std::io::Error> { - // Create mDNS instance - let mdns = Mdns::new("_genmeta.local", "127.0.0.1".parse().unwrap(), "lo0")?; - - // Listen to discovery stream - let mut stream = mdns.discover(); - while let Some((addr, packet)) = stream.next().await { - println!("Discovered packet from {}: {:?}", addr, packet); +async fn main() -> Result<(), ddns::resolvers::DnsErrors> { + let resolvers = Resolvers::builder().system().build(); + let mut endpoints = resolvers.lookup("demo.example.dhttp.net").await?; + + while let Some((source, endpoint)) = endpoints.next().await { + println!("{source:?}: {endpoint}"); } + Ok(()) } ``` -### HTTP/3 DNS Publishing Example +### mDNS discovery ```rust -use gmdns::{resolver::h3_resolver::H3Resolver, parser::record::endpoint::EndpointAddr}; -use std::path::Path; +use ddns::{mdns::service::Mdns, resolvers::DHTTP_MDNS_SERVICE}; +use futures::StreamExt; -#[tokio::main] -async fn main() -> Result<(), Box> { - let resolver = H3Resolver::new( - "https://localhost:4433/", - Path::new("examples/keychain/localhost/ca.cert"), - "client", - Path::new("examples/keychain/localhost/client.cert"), - Path::new("examples/keychain/localhost/client.key"), +#[tokio::main(flavor = "current_thread")] +async fn main() -> std::io::Result<()> { + let mdns = Mdns::new( + DHTTP_MDNS_SERVICE, + std::net::Ipv4Addr::LOCALHOST.into(), + "lo0", )?; + let mut discoveries = mdns.discover(); + + while let Some((source, packet)) = discoveries.next().await { + println!("received packet from {source}: {packet}"); + } - // Publish a DNS record - let endpoint = EndpointAddr::direct_v4("127.0.0.1:5555".parse()?); - resolver.publish("client.genmeta.net", &[endpoint]).await?; Ok(()) } ``` ---- - -## 🌐 HTTP/3 DNS Server - -GMDNS includes support for DNS over HTTP/3 (DoH3), allowing secure publication and querying of DNS records via HTTP/3 protocol. This is useful for remote networks where multicast mDNS is not feasible. - -### Publishing Services - -Publish DNS service records to an HTTP/3 DNS server: +Runnable examples live in `examples/`: ```bash -cargo run --example publish --features="h3x-resolver" \ - --server-ca /path/to/root.crt \ - --client-name demo.example.genmeta.net \ - --client-cert /path/to/demo.example.genmeta.net.pem \ - --client-key /path/to/demo.example.genmeta.net.key \ - --host demo.example.genmeta.net \ - --addr 192.168.1.100:8080 +cargo run --example mdns_discover -- --ip 127.0.0.1 --device lo0 +cargo run --example mdns_query -- --ip 192.168.5.156 --device en0 ``` -### Querying Services - -Query DNS service records from an HTTP/3 DNS server: +### DNS-over-H3 examples ```bash -cargo run --example query --features="h3x-resolver" \ +cargo run --example query --features h3x-resolver -- \ --server-ca /path/to/root.crt \ - --host stun.genmeta.net + --host nat.genmeta.net + +cargo run --example publish --features h3x-resolver -- \ + --server-ca /path/to/root.crt \ + --client-name demo.example.dhttp.net \ + --client-cert /path/to/demo.example.dhttp.net.pem \ + --client-key /path/to/demo.example.dhttp.net.key \ + --host demo.example.dhttp.net \ + --addr 192.168.1.100:8080,192.168.1.101:8080 ``` -### Running the DNS Server +See [`examples/README.md`](examples/README.md) for the example CLI parameters +and response decoding notes. + +## DNS-over-H3 server -Start an HTTP/3 DNS server: +Start the server with the `server` feature: ```bash -cargo run --example server --features="h3x-resolver" \ - --listen 127.0.0.1:4433 \ - --cert examples/keychain/localhost/server.cert \ - --key examples/keychain/localhost/server.key +cargo run --bin ddns-server --features server -- --config server.toml ``` -For detailed parameters and HTTP packet structures, see [examples/README.md](examples/README.md). - ---- - -## 📖 Protocol Specification +The server exposes two HTTP/3 routes: -### 1. Packet Layout +| Route | Meaning | +| --- | --- | +| `POST /publish?host=` | Publish a DNS packet for `host`. Client mTLS is required. | +| `GET /lookup?host=[&limit=N]` | Look up active records for `host`; `limit` caps newest-first dynamic records. | -DNS packets consist of a fixed header and four variable-length sections: +Lookup responses use header `x-record-format: multi` and the binary body from +`ddns::core::wire::MultiResponse`: ```text -+---------------------+-----------------------+-----------------------+-----------------------+-----------------------+ -| Header (12 bytes) | Question Section | Answer Section | Nameserver Section | Additional Section | -+---------------------+-----------------------+-----------------------+-----------------------+-----------------------+ -| Transaction ID | Query list | Answer RR list | Authority RR list | Additional RR list | -| and Flags | | | | | -+---------------------+-----------------------+-----------------------+-----------------------+-----------------------+ +u32 count +repeated count times: + u32 dns_len | dns packet bytes | u32 cert_len | DER publisher certificate bytes ``` -#### 1.1 Header -Fixed length of 12 bytes. Contains ID, Flags, and counters for subsequent sections (QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT). +Server configuration lives in `server.toml`: -#### 1.2 Resource Record -Answer, Nameserver, and Additional sections all use this format: +- storage is in-memory by default, or Redis when `redis = "redis://..."` is set; +- `ttl_secs` controls dynamic record expiry; +- `require_signature` controls signed endpoint-record enforcement for Standard + domains; +- `domain_policies` are matched in order, with unlisted domains using the + Standard policy; +- `seed_records` add static bootstrap endpoints to lookup results. -- **NAME**: Variable-length domain name, supports RFC 1035 compression. -- **TYPE (u16)**: Record type (e.g., A=1, SRV=33, E=266). -- **CLASS (u16)**: Protocol class. In mDNS, the highest bit (bit 15) is used for cache-flush flag. -- **TTL (u32)**: Cache time-to-live (seconds). -- **RDLEN (u16)**: Length of resource data (RDATA). -- **RDATA**: Specific resource content, format determined by TYPE. +Domain policies: -### 2. Custom Type Definitions (QType) +| Policy | Behavior | +| --- | --- | +| `standard` | Client certificate DNS SAN must match the published host; signed `E` records are required when `require_signature = true`; each certificate fingerprint owns one active record for the host. | +| `open_multi` | Any authenticated client certificate may publish; signature checks are skipped; multiple certificate fingerprints can coexist and lookup returns newest-first records. | -| Type | Value | Description | RDATA Format | -| :------- | :---- | :--------------- | :-------------------------------- | -| **A** | 1 | IPv4 address | 4-byte IP | -| **AAAA** | 28 | IPv6 address | 16-byte IP | -| **SRV** | 33 | Service location | Priority + Weight + Port + Target | -| **E** | 266 | Endpoint address | Flags + Seq + Addr(s) + [Sig] | +Public DHTTP identity hostnames should use the canonical `DhttpName::SUFFIX` +(`.dhttp.net`). Infrastructure names such as `nat.genmeta.net` can remain under +Genmeta infrastructure domains. -### 3. Endpoint Extensions (Type E) +## Endpoint `E` records -#### 3.1 RDATA Wire Format - -##### Packet Format +Custom DNS record type `E` (`QTYPE = 266`) carries DHTTP endpoint addresses. The +current wire format is: ```text -+--------+-----------------+--------------------+----------------------------+ -| flags | sequence(varint)| addr(s) | signature (optional) | -+--------+-----------------+--------------------+----------------------------+ -| u8 | QUIC varint | v4: 2+4 / v6: 2+16 | scheme(u16)+len(varint)+N | -+--------+-----------------+--------------------+----------------------------+ +flags(u8) +[sequence(varint) if CLUSTERED] +primary address: port(u16) + IPv4/IPv6 bytes +[agent address if NAT] +[load(f32) if LOAD] +[signature: scheme(u16) + len(varint) + bytes if SIGNED] ``` -##### flags (u8) Field Definition: -- bit 7 (0x80): **FAMILY** - Address family (0=IPv4, 1=IPv6) -- bit 6 (0x40): **MAIN** - Primary address flag -- bit 5 (0x20): **SEQUENCED** - Sequence number present -- bit 4 (0x10): **FORWARD** - Connection type (0=direct, 1=relay) -- bit 3 (0x08): **SIGNED** - Signature present -- bits 2-0: Reserved - -##### Address Format: -- **Direct**: `port(u16)` + `IP(u32/u128)` -- **Relay**: `outer_port(u16)` + `outer_IP(u32/u128)` + `agent_port(u16)` + `agent_IP(u32/u128)` -- **sequence**: DNS record sequence number. Records with the same sequence are considered from the same machine and can use multipath connections. -- **signature**: When `SIGNED` flag is set, signature field is appended. - -#### 3.2 Flag Bit Masks - -- `0b1000_0000`: **FAMILY** (Address family: 0=IPv4, 1=IPv6) -- `0b0100_0000`: **MAIN** (Primary address flag) -- `0b0010_0000`: **SEQUENCED** (Sequence number present) -- `0b0001_0000`: **FORWARD** (Connection type: 0=direct, 1=relay) -- `0b0000_1000`: **SIGNED** (Signature present) +Flag bits: -#### 3.3 Address Format Details +| Bit mask | Name | Meaning | +| --- | --- | --- | +| `0x80` | `FAMILY` | `0` = IPv4, `1` = IPv6. | +| `0x40` | `MAIN` | Primary endpoint for the name. | +| `0x20` | `CLUSTERED` | Sequence number is present; multiple publishers share the name. | +| `0x10` | `NAT` | Agent address is present for NAT traversal. | +| `0x08` | `LOAD` | One-minute load value is present. | +| `0x01` | `SIGNED` | Signature with explicit TLS signature scheme is present. | -- **Direct**: `Port(u16)` + `IP(u32/u128)` -- **Relay**: `OuterPort(u16)` + `OuterIP(u32/u128)` + `AgentPort(u16)` + `AgentIP(u32/u128)` +For DHTTP endpoint publishing, `MAIN` and `sequence` are derived from the +publisher certificate's DHTTP subject key identifier. Operators do not choose +these fields manually: `primary` certificates publish `MAIN = true`, +`secondary` certificates publish `MAIN = false`, and the certificate-chain +sequence becomes the normalized endpoint-record sequence. An omitted sequence +field means sequence `0`. -#### 3.4 Signature Format +Signed records encode the signature scheme in the record; the no-scheme signed +format is not accepted. Legacy unsigned fixed-length endpoint address records are +still parsed by length for address-only compatibility. -When signature is present: `Scheme (u16)` + `Length (VarInt)` + `Data (N bytes)`. +## Project structure ---- - -## 🛠 Project Structure - -- `src/parser/`: Core protocol parsing implementation (Nom parsers). -- `src/protocol.rs`: UDP multicast and packet routing logic. -- `src/mdns.rs`: High-level mDNS discovery and response API. -- `src/resolver/`: HTTP/3 resolver implementation for DoH3 support. -- `examples/`: Sample code including mDNS discovery/query, and HTTP/3 publishing/querying/server examples. +```text +src/core.rs DNS core module root +src/core/parser/ DNS packet, name, question, record, varint, and signature parsers +src/core/parser/record/ A/AAAA/SRV/TXT/PTR/CNAME/E record parsing and encoding +src/core/wire.rs HTTP multi-record response wire format +src/mdns.rs mDNS module root +src/mdns/protocol.rs UDP multicast socket and packet routing +src/mdns/service.rs High-level mDNS service API +src/mdns/resolvers/ mDNS resolver integration +src/resolvers.rs Resolver chain and resolver defaults +src/resolvers/h3.rs DNS-over-H3 resolver/publisher +src/resolvers/http.rs DNS-over-HTTP resolver/publisher +src/resolvers/deferred.rs Deferred resolver initialization helper +src/publisher.rs Endpoint record signer and publication loop +src/publisher/ Address selection, publish dispatch, packet signing +src/bin/ddns-server/ DNS-over-H3 server implementation +examples/ mDNS and DNS-over-H3 example programs +server.toml Example server configuration +``` diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..8794f25 --- /dev/null +++ b/build.rs @@ -0,0 +1,71 @@ +use std::{env, fs, path::PathBuf}; + +const H3_DNS_SERVER_ENV: &str = "DHTTP_H3_DNS_SERVER"; +const HTTP_DNS_SERVER_ENV: &str = "DHTTP_HTTP_DNS_SERVER"; +const MDNS_SERVICE_ENV: &str = "DHTTP_MDNS_SERVICE"; + +const DEFAULT_H3_DNS_SERVER: &str = "https://dhttp.example.net"; +const DEFAULT_HTTP_DNS_SERVER: &str = "https://dhttp.example.net"; +const DEFAULT_MDNS_SERVICE: &str = "dhttp.example.net"; + +fn main() { + let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR is set by cargo")); + + let h3_dns_server = env_or_default(H3_DNS_SERVER_ENV, DEFAULT_H3_DNS_SERVER); + let http_dns_server = env_or_default(HTTP_DNS_SERVER_ENV, DEFAULT_HTTP_DNS_SERVER); + let mdns_service = env_or_default(MDNS_SERVICE_ENV, DEFAULT_MDNS_SERVICE); + + let bootstrap = format!( + "// @generated by build.rs; do not edit.\n\ + pub const DHTTP_H3_DNS_SERVER: &str = {h3_dns_server:?};\n\ + pub const DHTTP_HTTP_DNS_SERVER: &str = {http_dns_server:?};\n\ + pub const DHTTP_MDNS_SERVICE: &str = {mdns_service:?};\n" + ); + fs::write(out_dir.join("bootstrap.rs"), bootstrap) + .expect("failed to write generated DHTTP DNS bootstrap constants"); + + println!("cargo::rerun-if-env-changed={H3_DNS_SERVER_ENV}"); + println!("cargo::rerun-if-env-changed={HTTP_DNS_SERVER_ENV}"); + println!("cargo::rerun-if-env-changed={MDNS_SERVICE_ENV}"); +} + +fn env_or_default(name: &str, default: &str) -> String { + env::var(name).unwrap_or_else(|_| default.to_owned()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn missing_env_uses_dhttp_example_net_placeholder() { + let name = format!("__DDNS_MISSING_BOOTSTRAP_{}", std::process::id()); + + assert_eq!( + env_or_default(&name, DEFAULT_H3_DNS_SERVER), + "https://dhttp.example.net" + ); + assert_eq!( + env_or_default(&name, DEFAULT_HTTP_DNS_SERVER), + "https://dhttp.example.net" + ); + assert_eq!( + env_or_default(&name, DEFAULT_MDNS_SERVICE), + "dhttp.example.net" + ); + } + + #[test] + fn placeholder_urls_do_not_include_explicit_ports() { + for url in [DEFAULT_H3_DNS_SERVER, DEFAULT_HTTP_DNS_SERVER] { + let authority = url + .strip_prefix("https://") + .expect("placeholder URL uses https") + .split('/') + .next() + .expect("placeholder URL has an authority"); + + assert_eq!(authority, "dhttp.example.net"); + } + } +} diff --git a/examples/README.md b/examples/README.md index da10ec1..c1e5780 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,120 +1,109 @@ -# DNS Server Documentation +# DDNS examples -## Introduction +This directory contains runnable examples for the single published `dyns` +package, whose library target remains `ddns`. -`gmdns` is a Rust-implemented DNS library that supports the mDNS (Multicast DNS) protocol and interacts with DNS servers via the HTTP/3 (H3) protocol for service discovery and publishing in local and remote networks. This document introduces how to use the example programs of `gmdns` to publish and query DNS services, including detailed program parameters and HTTP packet structures. +| Example | Feature requirement | Purpose | +| --- | --- | --- | +| `mdns_discover` | none | Bind an mDNS service, publish sample local hosts, and print multicast packets. | +| `mdns_query` | none | Query a DHTTP name over local mDNS. | +| `query` | `h3x-resolver` | Query a DNS-over-H3 server and decode the multi-record response. | +| `publish` | `h3x-resolver` | Publish signed endpoint `E` records to a DNS-over-H3 server using client mTLS. | -## Building the Project +Run all commands from the `ddns/` repository. -First, ensure you have a Rust environment. Clone or enter the project directory, then build: +## mDNS examples + +Bind to a local interface and print multicast traffic: ```bash -cargo build --features="h3x-resolver" +cargo run --example mdns_discover -- \ + --ip 127.0.0.1 \ + --device lo0 ``` -Note: The example programs require the `h3x-resolver` feature to enable HTTP/3 support. +Query a name over mDNS: -## HTTP Packet Structure Overview +```bash +cargo run --example mdns_query -- \ + --ip 192.168.5.156 \ + --device en0 +``` -`gmdns` uses the HTTP/3 protocol to transmit DNS queries and responses, similar to DNS over HTTPS (DoH) but based on the QUIC protocol. The structure of HTTP requests is as follows: +Replace `--ip` and `--device` with an address and interface that exist on the +local machine. The mDNS service name defaults to the build-time +`DHTTP_MDNS_SERVICE` constant. -### URL Structure -- **Base URL**: Default `https://localhost:4433/`, used to specify the DNS server's address. -- **Path**: For queries, usually the root path `/`, the server parses the DNS query based on the request body. -- **Query Parameters**: Optional, used to specify query type or options. +## DNS-over-H3 query -### HTTP Headers -- **Content-Type**: `application/dns-message` (for DNS message body) or `application/json` (if using JSON format). -- **Accept**: `application/dns-message` or `application/json`. -- **User-Agent**: Client identifier. -- **Authorization**: If authentication is needed, use Bearer token or other mechanisms. +```bash +cargo run --example query --features h3x-resolver -- \ + --server-ca /path/to/root.crt \ + --host nat.genmeta.net +``` -### Request Body (Body) -- DNS queries are sent in binary DNS message format (RFC 1035), containing query name, type (such as A, AAAA, SRV), and class. -- For publishing, the request body contains the DNS record data to be published. +Options: -### Response Body -- The server returns a DNS response message containing query results or confirmation of publishing. +| Option | Meaning | +| --- | --- | +| `--base-url ` | DNS-over-H3 server base URL. Defaults to build-time `DHTTP_H3_DNS_SERVER` with a trailing slash. | +| `--server-ca ` | PEM root CA used to verify the DNS server certificate. | +| `--host ` | DNS host to query. Defaults to `nat.genmeta.net`. | -## Usage Examples +The example sends `GET /lookup?host=`. A successful server response is a +`ddns::core::wire::MultiResponse` body with header `x-record-format: multi`: -### Publishing Services (publish) +```text +u32 count +repeated count times: + u32 dns_len | dns packet bytes | u32 cert_len | DER publisher certificate bytes +``` -Use the `publish` example to publish a DNS service record to the HTTP/3 DNS server. +The example prints each DNS packet, the publisher certificate fingerprint when a +certificate is present, and endpoint signature verification status for signed +`E` records. -#### Program Parameters -- `--base-url `: Base URL of the DNS server (default: `https://dns.genmeta.net:4433/`). -- `--server-ca `: CA certificate PEM file path for verifying the online server certificate. -- `--client-name `: Client identity name used for mTLS. -- `--client-cert `: Client certificate chain PEM file. -- `--client-key `: Client private key PEM file. -- `--sign`: Whether to sign the Endpoint record with the client private key (default: true). -- `--host `: DNS name to publish, must match the SAN in the client certificate. -- `--addr `: List of socket addresses to publish, separated by commas. -- `--is-main`: Whether it is the main record (default: true). +## DNS-over-H3 publish -#### Example Run Command ```bash -cargo run --example publish --features="h3x-resolver" \ +cargo run --example publish --features h3x-resolver -- \ --server-ca /path/to/root.crt \ - --client-name demo.example.genmeta.net \ - --client-cert /path/to/demo.example.genmeta.net.pem \ - --client-key /path/to/demo.example.genmeta.net.key \ - --host demo.example.genmeta.net \ + --client-name demo.example.dhttp.net \ + --client-cert /path/to/demo.example.dhttp.net.pem \ + --client-key /path/to/demo.example.dhttp.net.key \ + --host demo.example.dhttp.net \ --addr 192.168.1.100:8080,192.168.1.101:8080 ``` -This command establishes an HTTP/3 connection to the server, sends a POST request containing DNS records, the server verifies the signature and stores the records. - -### Querying Services (query) - -Use the `query` example to query DNS service records from the HTTP/3 DNS server. +Options: -#### Program Parameters -- `--base-url `: Base URL of the DNS server (default: `https://dns.genmeta.net:4433/`). -- `--server-ca `: CA certificate PEM file path for verifying the online server certificate. -- `--host `: DNS name to query (default: `stun.genmeta.net`). +| Option | Meaning | +| --- | --- | +| `--base-url ` | DNS-over-H3 server base URL. Defaults to build-time `DHTTP_H3_DNS_SERVER` with a trailing slash. | +| `--server-ca ` | PEM root CA used to verify the DNS server certificate. | +| `--client-name ` | DHTTP identity name presented by the client endpoint. | +| `--client-cert ` | Client certificate chain PEM for mTLS and endpoint signature verification. | +| `--client-key ` | Client private key PEM. | +| `--sign ` | Whether to sign each endpoint `E` record. Defaults to `true`. | +| `--host ` | DNS host to publish. Standard-policy servers require this to match the client certificate DNS SAN. | +| `--addr ` | One or more socket addresses to publish. | -#### Example Run Command -```bash -cargo run --example query --features="h3x-resolver" \ - --server-ca /path/to/root.crt \ - --host stun.genmeta.net -``` - -This command sends a GET or POST request to the server, the request body contains the DNS query message, the server returns matching records. - -### Running the DNS Server (server) +The example derives the endpoint selector from the client certificate SKI before +signing records. Use the correct certificate chain instead of manual selector +flags. -Use the `server` example to start an HTTP/3 DNS server. +The example sends `POST /publish?host=` with a binary DNS packet body. For +Standard policy domains, the server requires a client certificate whose single +DNS SAN matches `host`; when `require_signature = true`, at least one signed +endpoint record must verify against the publisher certificate. Open-multi policy +domains still require client mTLS but skip the host SAN and endpoint signature +checks. -#### Program Parameters -- `--redis `: Optional Redis connection URL for persistent storage (default: none, uses in-memory storage). -- `--listen `: Server listen address (default: `127.0.0.1:4433`). -- `--server-name `: Server name (default: `localhost`). -- `--cert `: Server certificate PEM file (default: `examples/keychain/localhost/server.cert`). -- `--key `: Server private key PEM file (default: `examples/keychain/localhost/server.key`). -- `--root-cert `: Root CA certificate PEM file (default: `examples/keychain/localhost/ca.cert`). -- `--require-signature`: Whether to require client-signed records (default: true). -- `--ttl-secs `: TTL time for records in seconds (default: 30). +## Running the server -#### Example Run Command ```bash -cargo run --example server --features="h3x-resolver" \ - --listen 127.0.0.1:4433 \ - --cert examples/keychain/localhost/server.cert \ - --key examples/keychain/localhost/server.key +cargo run --bin ddns-server --features server -- --config server.toml ``` -After the server starts, it listens for HTTP/3 requests and handles publish and query operations. - -## Other Examples - -The project also includes other example programs such as `mdns_discover.rs` and `mdns_query.rs` for pure mDNS discovery and query operations, not involving HTTP/3. Please refer to the source code for more details. - -## Notes - -- Ensure that the local network supports QUIC and HTTP/3. -- Certificate and key files must be configured correctly, otherwise TLS handshake will fail. -- For production environments, use valid certificates and secure key management. -- For more configuration options, please refer to the project's main README.md file. +`server.toml` documents the available fields: listener, TLS identity, client root +CA, optional Redis storage, TTL, domain policies, and static seed records. diff --git a/examples/mdns_discover.rs b/examples/mdns_discover.rs index 8593c73..ce6eb12 100644 --- a/examples/mdns_discover.rs +++ b/examples/mdns_discover.rs @@ -4,10 +4,9 @@ use std::{ }; use clap::Parser; +use ddns::{core::MdnsEndpoint, mdns::service::Mdns, resolvers::DHTTP_MDNS_SERVICE}; use futures::StreamExt; -const SERVICE_NAME: &str = "_genmeta.local"; - #[derive(Parser, Debug)] #[command(version, about, long_about = None)] struct Args { @@ -21,14 +20,14 @@ struct Args { async fn main() -> Result<(), Error> { tracing_subscriber::fmt::init(); let args = Args::parse(); - let mdns = gmdns::mdns::Mdns::new(SERVICE_NAME, args.ip, &args.device)?; + let mdns = Mdns::new(DHTTP_MDNS_SERVICE, args.ip, &args.device)?; mdns.insert_host( - "test.genmeta.net".to_string(), + "test.dhttp.net".to_string(), vec![ { let addr: SocketAddr = "192.168.1.7:7000".parse().unwrap(); if let SocketAddr::V4(v4) = addr { - gmdns::parser::record::endpoint::EndpointAddr::direct_v4(v4) + MdnsEndpoint::direct_v4(v4) } else { panic!("Expected IPv4 address"); } @@ -36,7 +35,7 @@ async fn main() -> Result<(), Error> { { let addr: SocketAddr = "192.168.1.13:7000".parse().unwrap(); if let SocketAddr::V4(v4) = addr { - gmdns::parser::record::endpoint::EndpointAddr::direct_v4(v4) + MdnsEndpoint::direct_v4(v4) } else { panic!("Expected IPv4 address"); } @@ -45,12 +44,12 @@ async fn main() -> Result<(), Error> { ); mdns.insert_host( - "mdns.test.genmeta.net".to_string(), + "mdns.test.dhttp.net".to_string(), vec![ { let addr: SocketAddr = "192.168.1.7:7001".parse().unwrap(); if let SocketAddr::V4(v4) = addr { - gmdns::parser::record::endpoint::EndpointAddr::direct_v4(v4) + MdnsEndpoint::direct_v4(v4) } else { panic!("Expected IPv4 address"); } @@ -58,7 +57,7 @@ async fn main() -> Result<(), Error> { { let addr: SocketAddr = "192.168.1.7:7001".parse().unwrap(); if let SocketAddr::V4(v4) = addr { - gmdns::parser::record::endpoint::EndpointAddr::direct_v4(v4) + MdnsEndpoint::direct_v4(v4) } else { panic!("Expected IPv4 address"); } @@ -66,7 +65,7 @@ async fn main() -> Result<(), Error> { { let addr: SocketAddr = "192.168.1.7:7001".parse().unwrap(); if let SocketAddr::V4(v4) = addr { - gmdns::parser::record::endpoint::EndpointAddr::direct_v4(v4) + MdnsEndpoint::direct_v4(v4) } else { panic!("Expected IPv4 address"); } diff --git a/examples/mdns_query.rs b/examples/mdns_query.rs index def0f93..16fd3a6 100644 --- a/examples/mdns_query.rs +++ b/examples/mdns_query.rs @@ -1,8 +1,7 @@ use std::{io::Error, net::IpAddr}; use clap::Parser; - -const SERVICE_NAME: &str = "_genmeta.local"; +use ddns::{mdns::service::Mdns, resolvers::DHTTP_MDNS_SERVICE}; #[derive(Parser, Debug)] #[command(version, about, long_about = None)] @@ -17,9 +16,9 @@ struct Args { async fn main() -> Result<(), Error> { tracing_subscriber::fmt::init(); let args = Args::parse(); - let mdns = gmdns::mdns::Mdns::new(SERVICE_NAME, args.ip, &args.device)?; + let mdns = Mdns::new(DHTTP_MDNS_SERVICE, args.ip, &args.device)?; - let ret = mdns.query("publish.test.genmeta.net".to_string()).await?; + let ret = mdns.query("publish.test.dhttp.net".to_string()).await?; println!("{ret:?}\n"); Ok(()) } diff --git a/examples/publish.rs b/examples/publish.rs index 0b24d76..629fe9f 100644 --- a/examples/publish.rs +++ b/examples/publish.rs @@ -6,16 +6,24 @@ use std::{ }; use clap::Parser; -use gmdns::{parser::record::endpoint::EndpointAddr, resolvers::H3Publisher}; -use h3x::dquic::{H3Client, qresolve::Publish}; -use rustls::{RootCertStore, SignatureScheme, pki_types::PrivateKeyDer, sign::SigningKey}; +use ddns::{ + core::parser::record::endpoint::EndpointAddr, + resolvers::{DHTTP_H3_DNS_SERVER, h3::H3Publisher}, +}; +use h3x::dquic::{ + Identity, Network, QuicEndpoint, + cert::handy::{ToCertificate, ToPrivateKey}, + client::{ClientQuicConfig, ServerCertVerifierChoice}, + resolver::{Publish, handy::SystemResolver}, +}; +use rustls::{RootCertStore, client::WebPkiServerVerifier}; use tracing::{Level, info}; #[derive(Parser, Debug)] #[command(version, about, long_about = None)] struct Options { /// Base URL of the线上 H3 DNS server. - #[arg(long, default_value = "https://dns.genmeta.net:4433/")] + #[arg(long, default_value_t = default_h3_base_url())] base_url: String, /// 用于校验线上服务端证书的 CA PEM 文件。 @@ -48,12 +56,10 @@ struct Options { /// 要发布的地址列表。 #[arg(long, value_delimiter = ',', num_args = 1..)] addr: Vec, +} - #[arg(long, default_value_t = true)] - is_main: bool, - - #[arg(long, default_value_t = 1)] - sequence: u64, +fn default_h3_base_url() -> String { + format!("{}/", DHTTP_H3_DNS_SERVER.trim_end_matches('/')) } fn load_root_store_from_pem(path: &Path) -> io::Result { @@ -83,47 +89,8 @@ fn expand_tilde(path: &Path) -> io::Result { Ok(PathBuf::from(shellexpand::tilde(path).into_owned())) } -fn load_private_key_from_pem(pem: &[u8]) -> io::Result> { - let mut reader = std::io::Cursor::new(pem); - let key = rustls_pemfile::private_key(&mut reader) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? - .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "No private key found in PEM"))?; - Ok(key) -} - -fn build_signing_key_from_pem(pem: &[u8]) -> io::Result> { - let key = load_private_key_from_pem(pem)?; - rustls::crypto::ring::sign::any_supported_type(&key) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) -} - -fn pick_signature_scheme(key: &dyn SigningKey) -> io::Result { - // Order is preference; choose_scheme picks the first it supports. - let offered = [ - SignatureScheme::ED25519, - SignatureScheme::ECDSA_NISTP256_SHA256, - SignatureScheme::ECDSA_NISTP384_SHA384, - SignatureScheme::RSA_PSS_SHA256, - SignatureScheme::RSA_PSS_SHA384, - SignatureScheme::RSA_PSS_SHA512, - SignatureScheme::RSA_PKCS1_SHA256, - SignatureScheme::RSA_PKCS1_SHA384, - SignatureScheme::RSA_PKCS1_SHA512, - ]; - - let signer = key - .choose_scheme(&offered) - .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Unsupported key type/scheme"))?; - Ok(signer.scheme()) -} - #[tokio::main] async fn main() -> io::Result<()> { - // Install ring crypto provider - rustls::crypto::ring::default_provider() - .install_default() - .expect("Failed to install ring crypto provider"); - tracing_subscriber::fmt() .with_max_level(Level::DEBUG) .init(); @@ -137,53 +104,70 @@ async fn main() -> io::Result<()> { let cert_chain_pem = std::fs::read(&client_cert)?; let private_key_pem = std::fs::read(&client_key)?; - let signer = opt - .sign - .then(|| build_signing_key_from_pem(&private_key_pem)) - .transpose()?; - let signer_scheme = signer.as_deref().map(pick_signature_scheme).transpose()?; - - let client = H3Client::builder() - .with_root_certificates(Arc::new(root_store)) - .with_identity( - opt.client_name, - cert_chain_pem.as_slice(), - private_key_pem.as_slice(), - ) - .map_err(io::Error::other)? - .build(); + // Build WebPki server cert verifier from CA root store + let verifier = WebPkiServerVerifier::builder(Arc::new(root_store)) + .build() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + // Build TLS identity from cert chain and private key PEM + let identity = Arc::new(Identity { + name: opt.client_name.parse().unwrap(), + certs: Arc::new(cert_chain_pem.to_certificate()), + key: Arc::new(private_key_pem.to_private_key()), + ocsp: Arc::new(None), + }); + + // Build network and QuicEndpoint with client mTLS config + let network = Network::builder().build(); + let quic = QuicEndpoint::builder() + .network(network) + .identity(identity.clone()) + .resolver(Arc::new(SystemResolver)) + .client(ClientQuicConfig { + verifier: ServerCertVerifierChoice::WebPki(verifier), + ..Default::default() + }) + .build() + .await; + let h3_endpoint = h3x::dquic::H3Endpoint::new(quic); // Uses H3Resolver which uses dquic internally aka HTTP/3 - let resolver = H3Publisher::new(opt.base_url.clone(), client)?; + let resolver = H3Publisher::new(opt.base_url.clone(), h3_endpoint)?; info!(host = %opt.host, addrs = ?opt.addr, base_url = %opt.base_url, "publish.start"); - if let Some(scheme) = signer_scheme { - info!(?scheme, "publish.endpoint_signing.enabled"); + if opt.sign { + info!("publish.endpoint_signing.enabled"); } else { info!("publish.endpoint_signing.disabled"); } + let selector = identity + .dhttp_subject_key_identifier() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + let chain = selector.chain(); for &addr in &opt.addr { - info!("Creating endpoint for address: {}", addr); + info!("creating endpoint for address: {}", addr); let mut endpoint = match addr { SocketAddr::V4(v4) => EndpointAddr::direct_v4(v4), SocketAddr::V6(v6) => EndpointAddr::direct_v6(v6), }; - endpoint.set_main(opt.is_main); - endpoint.set_sequence(opt.sequence); - if let Some((key, scheme)) = signer.as_deref().zip(signer_scheme) { - info!("Signing endpoint with scheme: {:?}", scheme); - endpoint.sign_with(key, scheme).map_err(io::Error::other)?; + endpoint.set_certificate_chain_key(chain); + if opt.sign { + info!("signing endpoint"); + endpoint + .sign_with_authority(identity.as_ref()) + .await + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; } - info!("Publishing endpoint: {:?}", endpoint); + info!("publishing endpoint: {:?}", endpoint); let mut hosts = std::collections::HashMap::new(); hosts.insert(opt.host.clone(), vec![endpoint]); - let packet = gmdns::MdnsPacket::answer(0, &hosts).to_bytes(); + let packet = ddns::core::MdnsPacket::answer(0, &hosts).to_bytes(); resolver .publish(&opt.host, &packet) .await .map_err(io::Error::other)?; - info!("Successfully published endpoint for {}", addr); + info!("successfully published endpoint for {}", addr); } info!("publish.ok"); diff --git a/examples/query.rs b/examples/query.rs index 5a2408c..80235cd 100644 --- a/examples/query.rs +++ b/examples/query.rs @@ -5,16 +5,27 @@ use std::{ }; use clap::Parser; -use gmdns::{MdnsPacket, parser::record::RData, wire::be_multi_response}; -use h3x::dquic::H3Client; -use rustls::RootCertStore; +use ddns::{ + core::{MdnsPacket, parser::record::RData, wire::be_multi_response}, + resolvers::DHTTP_H3_DNS_SERVER, +}; +use h3x::{ + dquic::{ + Network, QuicEndpoint, + client::{ClientQuicConfig, ServerCertVerifierChoice}, + resolver::handy::SystemResolver, + }, + endpoint::H3Endpoint, +}; +use http_body_util::{BodyExt, Empty}; +use rustls::{RootCertStore, client::WebPkiServerVerifier}; use tracing::{Level, info}; #[derive(Parser, Debug)] #[command(version, about, long_about = None)] struct Options { /// Base URL of the线上 HTTP/3 DNS server. - #[arg(long, default_value = "https://dns.genmeta.net:4433/")] + #[arg(long, default_value_t = default_h3_base_url())] base_url: String, /// 用于校验线上服务端证书的 CA PEM 文件。 @@ -22,10 +33,14 @@ struct Options { server_ca: PathBuf, /// 要查询的线上域名。 - #[arg(long, default_value = "stun.genmeta.net")] + #[arg(long, default_value = "nat.genmeta.net")] host: String, } +fn default_h3_base_url() -> String { + format!("{}/", DHTTP_H3_DNS_SERVER.trim_end_matches('/')) +} + fn load_root_store_from_pem(path: &Path) -> io::Result { let pem = std::fs::read(path)?; let mut store = RootCertStore::empty(); @@ -85,9 +100,6 @@ fn expand_tilde(path: &Path) -> io::Result { #[tokio::main] async fn main() -> Result<(), Box> { - rustls::crypto::ring::default_provider() - .install_default() - .expect("Failed to install ring crypto provider"); tracing_subscriber::fmt() .with_max_level(Level::DEBUG) .init(); @@ -95,20 +107,43 @@ async fn main() -> Result<(), Box> { let opt = Options::parse(); let server_ca = expand_tilde(&opt.server_ca)?; let root_store = load_root_store_from_pem(&server_ca)?; - let client = H3Client::builder() - .with_root_certificates(Arc::new(root_store)) - .without_identity() - .map_err(|e| io::Error::other(e.to_string()))? - .build(); + let verifier = WebPkiServerVerifier::builder(Arc::new(root_store)) + .build() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + let client_config = ClientQuicConfig { + verifier: ServerCertVerifierChoice::WebPki(verifier), + ..Default::default() + }; + let network = Network::builder().build(); + let quic = QuicEndpoint::builder() + .network(network) + .resolver(Arc::new(SystemResolver)) + .client(client_config) + .build() + .await; + let client = H3Endpoint::new(quic); let url = format!("{}lookup?host={}", opt.base_url, opt.host); info!(url = %url, "lookup.start"); let uri: http::Uri = url.parse()?; - let (_req, mut resp) = client.new_request().get(uri).await?; + let authority = uri + .authority() + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "query URL must include authority", + ) + })? + .clone(); + let connection = Arc::new(client).connect(authority).await?; + let request = http::Request::get(uri) + .body(Empty::::new()) + .expect("query request must be valid"); + let resp = connection.execute_hyper_request(request).await?; if resp.status().is_success() { - let bytes = resp.read_to_bytes().await?; + let bytes = resp.into_body().collect().await?.to_bytes(); let (_remain, multi) = be_multi_response(bytes.as_ref()).map_err(|e| { io::Error::new( @@ -128,7 +163,7 @@ async fn main() -> Result<(), Box> { None => println!("Source fingerprint: (no certificate)"), } - match gmdns::parser::packet::be_packet(&record.dns) { + match ddns::core::parser::packet::be_packet(&record.dns) { Ok((_, packet)) => { print!("{}", format_packet(&packet)); diff --git a/gmdns-server/Cargo.toml b/gmdns-server/Cargo.toml deleted file mode 100644 index d18a63a..0000000 --- a/gmdns-server/Cargo.toml +++ /dev/null @@ -1,40 +0,0 @@ -[package] -name = "gmdns-server" -version = "0.2.0" -edition = "2024" - -[[bin]] -name = "gmdns-server" -path = "src/main.rs" - -[dependencies] -gmdns = { path = "..", features = ["h3x-resolver"] } -h3x = { git = "https://github.com/genmeta/h3x.git", branch = "main", features = [ - "dquic", -] } - -# server-specific deps(不再污染核心库) -deadpool-redis = "0.12" -redis = { version = "0.23", features = ["tokio-comp", "aio"] } -serde = { version = "1", features = ["derive"] } -toml = "0.8" -dashmap = "6" -bytes = "1" -base64 = "0.22" -clap = { version = "4", features = ["derive"] } -http = "1" -nom = "8" -rustls = { version = "0.23", default-features = false, features = [ - "logging", - "ring", -] } -rustls-pemfile = "2" -idna = "1" -x509-parser = "0.18" -snafu = "0.8" -tokio = { version = "1", features = ["full"] } -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } -url = "2" -futures = "0.3" -ring = "0.17" diff --git a/gmdns-server/src/policy.rs b/gmdns-server/src/policy.rs deleted file mode 100644 index 96094f8..0000000 --- a/gmdns-server/src/policy.rs +++ /dev/null @@ -1,144 +0,0 @@ -use gmdns::parser::{packet::be_packet, record::RData}; -use h3x::quic::agent::RemoteAgent; -use tracing::warn; - -use crate::error::{AppError, normalize_host}; - -// --------------------------------------------------------------------------- -// Domain policy -// --------------------------------------------------------------------------- - -/// Per-domain publish / lookup behaviour. -#[derive(Clone, Debug, PartialEq)] -pub enum DomainPolicy { - /// Signature check controlled by `require_signature` flag; single record - /// per host; each publish overwrites the previous one. - Standard, - /// No signature check; any authenticated node may publish; multiple records - /// with individual TTLs; ordered newest-first on lookup. - OpenMulti, -} - -/// One rule in the domain-policy list. -#[derive(Clone, Debug)] -pub enum PolicyRule { - /// Matches only this exact (normalised) host. - Exact(String), - /// Matches the host itself or any label-subdomain (future use). - #[allow(dead_code)] - Suffix(String), -} - -impl PolicyRule { - pub fn matches(&self, host: &str) -> bool { - match self { - PolicyRule::Exact(exact) => host == exact, - PolicyRule::Suffix(suffix) => { - host == suffix.as_str() || host.ends_with(&format!(".{suffix}")) - } - } - } -} - -/// Ordered list of (rule, policy) pairs; first match wins; default is Standard. -#[derive(Clone, Debug, Default)] -pub struct DomainPolicies(pub Vec<(PolicyRule, DomainPolicy)>); - -impl DomainPolicies { - pub fn policy_for(&self, host: &str) -> &DomainPolicy { - for (rule, policy) in &self.0 { - if rule.matches(host) { - return policy; - } - } - &DomainPolicy::Standard - } -} - -// --------------------------------------------------------------------------- -// Certificate helpers -// --------------------------------------------------------------------------- - -pub fn extract_client_dns_sans(agent: &(impl RemoteAgent + ?Sized)) -> Vec { - use x509_parser::prelude::*; - - let Some(leaf) = agent.cert_chain().first() else { - return vec![]; - }; - - let Ok((_remain, cert)) = X509Certificate::from_der(leaf.as_ref()) else { - return vec![]; - }; - - let mut out = vec![]; - if let Ok(Some(san)) = cert.subject_alternative_name() { - for name in san.value.general_names.iter() { - if let GeneralName::DNSName(dns) = name { - out.push(dns.to_string()); - } - } - } - out -} - -pub fn client_allowed_host(agent: &(impl RemoteAgent + ?Sized)) -> Result { - let mut sans = extract_client_dns_sans(agent) - .into_iter() - .filter_map(|h| normalize_host(&h).ok()) - .collect::>(); - - sans.sort(); - sans.dedup(); - - match sans.len() { - 1 => Ok(sans.remove(0)), - _ => Err(AppError::ClientCertDomainNotAllowed), - } -} - -pub fn validate_dns_packet( - packet: &[u8], - require_signature: bool, - agent: &(impl RemoteAgent + ?Sized), -) -> Result { - let (remaining, dns_packet) = be_packet(packet).map_err(|e| AppError::InvalidDnsPacket { - message: e.to_string(), - })?; - if !remaining.is_empty() { - warn!(remain = remaining.len(), "dns.parse.extra_bytes"); - } - - if require_signature { - let has_signature = dns_packet - .answers - .iter() - .any(|record| matches!(record.data(), RData::E(endpoint) if endpoint.is_signed())); - - if !has_signature { - return Err(AppError::SignatureRequired); - } - - for record in &dns_packet.answers { - if let RData::E(endpoint) = record.data() - && endpoint.is_signed() - { - let cert = agent - .cert_chain() - .first() - .ok_or(AppError::MissingClientCertificate)?; - let ok = endpoint - .verify_signature_from_der(cert.as_ref()) - .map_err(|_| AppError::InvalidSignature)?; - if !ok { - return Err(AppError::InvalidSignature); - } - } - } - } - - dns_packet - .answers - .first() - .map(|record| record.name().to_string()) - .ok_or(AppError::NoAnswersInPacket) -} diff --git a/gmdns-server/src/publish.rs b/gmdns-server/src/publish.rs deleted file mode 100644 index a7eed21..0000000 --- a/gmdns-server/src/publish.rs +++ /dev/null @@ -1,254 +0,0 @@ -use futures::future::BoxFuture; -use h3x::{ - quic::agent::RemoteAgent, - server::{Request, Response, Service}, -}; -use redis::AsyncCommands; -use tokio::time::{Duration, Instant}; -use tracing::{debug, info, warn}; - -use crate::{ - error::{AppError, normalize_host, parse_query_params}, - lookup::write_error, - policy::{DomainPolicy, client_allowed_host, validate_dns_packet}, - storage::{ - AppState, Record, Storage, StoredRecord, cert_fingerprint, cert_fingerprint_hex, - unix_now_secs, - }, -}; - -// --------------------------------------------------------------------------- -// PublishSvc -// --------------------------------------------------------------------------- - -#[derive(Clone)] -pub struct PublishSvc { - pub state: AppState, -} - -impl Service for PublishSvc { - type Future<'s> = BoxFuture<'s, ()>; - - fn serve<'s>(&self, request: &'s mut Request, response: &'s mut Response) -> Self::Future<'s> { - let state = self.state.clone(); - Box::pin(async move { - debug!("received publish request"); - - let params = parse_query_params(&request.uri()); - debug!("query params: {:?}", params); - - let Some(host) = params.get("host") else { - warn!("missing host parameter"); - write_error(response, AppError::MissingHostParam).await; - return; - }; - - let host = match normalize_host(host) { - Ok(h) => h, - Err(e) => { - write_error(response, e).await; - return; - } - }; - debug!(host = %host, "publish.host"); - - // Require a valid client certificate for all publish requests. - let Some(agent) = request.agent().cloned() else { - warn!("missing client certificate"); - write_error(response, AppError::MissingClientCertificate).await; - return; - }; - - let policy = state.policies.policy_for(&host).clone(); - - // Standard policy: cert SAN must match the target host. - // OpenMulti policy: any authenticated node may publish — skip SAN check. - if policy == DomainPolicy::Standard { - let allowed = match client_allowed_host(agent.as_ref()) { - Ok(h) => h, - Err(e) => { - warn!(error = %snafu::Report::from_error(&e), "client certificate domain not allowed"); - write_error(response, e).await; - return; - } - }; - if allowed != host { - warn!(allowed = %allowed, requested = %host, "publish.host_mismatch"); - write_error(response, AppError::HostMismatch).await; - return; - } - } - - let body = match request.read_to_bytes().await { - Ok(b) => b, - Err(e) => { - warn!(error = %snafu::Report::from_error(&e), "failed to read request body"); - write_error( - response, - AppError::InvalidDnsPacket { - message: e.to_string(), - }, - ) - .await; - return; - } - }; - - // Validate DNS packet; signature check only for Standard hosts. - let require_sig = policy == DomainPolicy::Standard && state.require_signature; - let packet_name = match validate_dns_packet(body.as_ref(), require_sig, agent.as_ref()) - { - Ok(n) => n, - Err(e) => { - write_error(response, e).await; - return; - } - }; - - let packet_host = match normalize_host(&packet_name) { - Ok(h) => h, - Err(e) => { - write_error(response, e).await; - return; - } - }; - - if packet_host != host { - write_error(response, AppError::HostMismatch).await; - return; - } - - publish_record(&state, &host, &body, agent.as_ref(), response).await - }) - } -} - -/// Unified publish handler: stores the record keyed by (host, cert-fingerprint). -/// Both Standard and OpenMulti policies follow the same storage path; -/// the only policy difference (SAN check) is already enforced in the caller. -/// -/// Certificate fingerprint is the publish-source identity. In PKI ecosystems, -/// a single domain name can have multiple valid certificates (from different CAs, -/// or issued at different times for rotation/failover/multi-region scenarios). -/// Using fingerprint as part of the storage key enables: -/// - Multi-publisher coexistence: different cert holders can publish the same domain -/// - Idempotent updates: re-publishing from same cert source (same fingerprint) overwrites old data -/// - Client choice: lookups return all active records, client picks which certificate to trust -pub async fn publish_record( - state: &AppState, - host: &str, - body: &bytes::Bytes, - agent: &(impl RemoteAgent + ?Sized), - response: &mut Response, -) { - let cert_bytes = agent - .cert_chain() - .first() - .map(|c| c.as_ref().to_vec()) - .unwrap_or_default(); - - let fp = cert_fingerprint(&cert_bytes); - let fp_hex = cert_fingerprint_hex(&cert_bytes); - - match &state.storage { - Storage::Redis(pool) => { - let mut conn = match pool.get().await { - Ok(c) => c, - Err(e) => { - write_error( - response, - AppError::Redis { - message: e.to_string(), - }, - ) - .await; - return; - } - }; - let ttl_secs: usize = state.ttl_secs.try_into().unwrap_or(usize::MAX); - let now_secs = unix_now_secs(); - let expire_secs = now_secs + state.ttl_secs; - - let fp_key = format!("{host}:fp:{fp_hex}"); - let set_key = format!("{host}:multi"); - - // Remove the previous entry from this source (if any) from the ZSET. - let old_member: Option> = conn.get(&fp_key).await.unwrap_or(None); - if let Some(old) = old_member { - let _: () = conn.zrem(&set_key, &old).await.unwrap_or(()); - } - - // Encode and store the new member. - let new_member = StoredRecord { - expire_unix_secs: expire_secs, - fingerprint: fp, - dns: body.to_vec(), - cert: cert_bytes.clone(), - } - .encode(); - - if let Err(e) = conn - .set_ex::<_, _, ()>(&fp_key, &new_member, ttl_secs) - .await - { - write_error( - response, - AppError::Redis { - message: e.to_string(), - }, - ) - .await; - return; - } - - if let Err(e) = conn - .zadd::<_, _, _, ()>(&set_key, &new_member, now_secs as f64) - .await - { - write_error( - response, - AppError::Redis { - message: e.to_string(), - }, - ) - .await; - return; - } - - // Expire the ZSET key at max(ttl_secs) from now as a safety net. - let _: () = conn.expire(&set_key, ttl_secs).await.unwrap_or(()); - - // Evict stale (score < now - ttl) entries. - let cutoff = now_secs.saturating_sub(state.ttl_secs) as f64; - let _: () = redis::cmd("ZREMRANGEBYSCORE") - .arg(&set_key) - .arg("-inf") - .arg(cutoff) - .query_async::<_, ()>(&mut *conn) - .await - .unwrap_or(()); - } - Storage::Memory(mem) => { - let now = Instant::now(); - let expire = now + Duration::from_secs(state.ttl_secs); - let record = Record { - dns_bytes: body.to_vec(), - cert_bytes, - expire, - published_at: now, - }; - // Upsert by fingerprint: same source overwrites its own entry; - // different sources (different certs) coexist independently. - let mut host_map = mem.records.entry(host.to_string()).or_default(); - host_map.insert(fp, record); - // Evict expired entries while we hold the write lock. - host_map.retain(|_, r| r.expire > now); - } - } - - info!(host = %host, ttl = state.ttl_secs, bytes = body.len(), fp = %fp_hex, "publish.ok"); - response - .set_status(http::StatusCode::OK) - .set_body(bytes::Bytes::from_static(b"OK")); - let _ = response.flush().await; -} diff --git a/gmdns-server/server.toml b/server.toml similarity index 90% rename from gmdns-server/server.toml rename to server.toml index 79d6b8c..90eafe6 100644 --- a/gmdns-server/server.toml +++ b/server.toml @@ -1,8 +1,8 @@ -# gmdns DNS-over-HTTP/3 Server configuration +# ddns DNS-over-HTTP/3 server configuration # All fields are optional; the values shown below are the built-in defaults. -# Socket address to listen on. -listen = "0.0.0.0:4433" +# Bind patterns to listen on. Dual-stack service is expressed explicitly. +binds = ["0.0.0.0:4433", "[::]:4433"] # TLS server name (SNI). server_name = "dns.genmeta.net" @@ -42,14 +42,14 @@ ttl_secs = 30 # --------------------------------------------------------------------------- [[domain_policies]] -host = "stun.genmeta.net" +host = "nat.genmeta.net" policy = "open_multi" # Static bootstrap STUN endpoints returned even before any node publishes. # Ordering keeps the main :20002 endpoints ahead of the auxiliary :20003 endpoints. [[seed_records]] -host = "stun.genmeta.net" +host = "nat.genmeta.net" endpoints = [] # Add more rules as needed, e.g.: diff --git a/gmdns-server/src/config.rs b/src/bin/ddns-server/config.rs similarity index 65% rename from gmdns-server/src/config.rs rename to src/bin/ddns-server/config.rs index 245c112..8b663bd 100644 --- a/gmdns-server/src/config.rs +++ b/src/bin/ddns-server/config.rs @@ -1,10 +1,12 @@ use std::{ net::SocketAddr, path::{Path, PathBuf}, + str::FromStr, }; use clap::Parser; -use serde::Deserialize; +use h3x::dquic::binds::BindPattern; +use serde::{Deserialize, Deserializer, de::Error as _}; // --------------------------------------------------------------------------- // CLI @@ -29,9 +31,12 @@ pub struct Config { /// Redis URL (e.g. "redis://127.0.0.1/"). Omit to use in-memory storage. pub redis: Option, - /// Socket to listen on. - #[serde(default = "Config::default_listen")] - pub listen: SocketAddr, + /// Bind patterns to listen on. + #[serde( + default = "Config::default_binds", + deserialize_with = "deserialize_bind_patterns" + )] + pub binds: Vec, /// Server name (used as TLS SNI). #[serde(default = "Config::default_server_name")] @@ -74,8 +79,13 @@ impl Config { self } - pub fn default_listen() -> SocketAddr { - "0.0.0.0:4433".parse().unwrap() + pub fn default_binds() -> Vec { + ["0.0.0.0:4433", "[::]:4433"] + .into_iter() + .map(|value| { + BindPattern::from_str(value).expect("default bind pattern should be valid") + }) + .collect() } pub fn default_server_name() -> String { "localhost".into() @@ -97,6 +107,21 @@ impl Config { } } +fn deserialize_bind_patterns<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let values = Vec::::deserialize(deserializer)?; + values + .into_iter() + .map(|value| { + BindPattern::from_str(&value).map_err(|error| { + D::Error::custom(format!("invalid bind pattern `{value}`: {error}")) + }) + }) + .collect() +} + fn expand_home_dir(path: &Path) -> PathBuf { let Some(path_str) = path.to_str() else { return path.to_path_buf(); @@ -144,3 +169,43 @@ pub enum PolicyKind { Standard, OpenMulti, } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_binds_are_explicit_dual_stack() { + let binds = Config::default_binds(); + + assert_eq!(binds.len(), 2); + assert_eq!(binds[0].to_string(), "inet://0.0.0.0:4433"); + assert_eq!(binds[1].to_string(), "inet://[::]:4433"); + } + + #[test] + fn config_parses_bare_socket_bind_patterns() { + let config: Config = toml::from_str( + r#" + binds = ["0.0.0.0:4433", "[::]:4433"] + "#, + ) + .expect("config should parse"); + + assert_eq!(config.binds.len(), 2); + assert_eq!(config.binds[0].to_string(), "inet://0.0.0.0:4433"); + assert_eq!(config.binds[1].to_string(), "inet://[::]:4433"); + } + + #[test] + fn legacy_listen_field_is_rejected() { + let error = toml::from_str::( + r#" + listen = "0.0.0.0:4433" + "#, + ) + .expect_err("legacy listen should be rejected"); + + assert!(error.to_string().contains("unknown field `listen`")); + } +} diff --git a/gmdns-server/src/error.rs b/src/bin/ddns-server/error.rs similarity index 69% rename from gmdns-server/src/error.rs rename to src/bin/ddns-server/error.rs index e242522..c8930ba 100644 --- a/gmdns-server/src/error.rs +++ b/src/bin/ddns-server/error.rs @@ -1,6 +1,9 @@ use std::collections::HashMap; +use dhttp_identity::name::DhttpName; + #[derive(Debug, snafu::Snafu)] +#[snafu(module, visibility(pub(crate)))] pub enum AppError { #[snafu(display("missing host parameter"))] MissingHostParam, @@ -18,6 +21,16 @@ pub enum AppError { ClientCertDomainNotAllowed, #[snafu(display("invalid DNS packet: {message}"))] InvalidDnsPacket { message: String }, + #[snafu(display("publisher certificate selector is invalid"))] + PublisherCertificateSelector { + source: dhttp_identity::identity::ExtractDhttpSubjectKeyIdentifierError, + }, + #[snafu(display("endpoint record selector is invalid"))] + EndpointRecordSelector { + source: ddns::core::parser::record::endpoint::EndpointSelectorError, + }, + #[snafu(display("endpoint record selector does not match publisher certificate selector"))] + EndpointSelectorMismatch, #[snafu(display("no answers in packet"))] NoAnswersInPacket, #[snafu(display("signature required"))] @@ -39,6 +52,9 @@ impl AppError { AppError::MissingClientCertificate => http::StatusCode::UNAUTHORIZED, AppError::ClientCertDomainNotAllowed => http::StatusCode::FORBIDDEN, AppError::InvalidDnsPacket { .. } => http::StatusCode::BAD_REQUEST, + AppError::PublisherCertificateSelector { .. } => http::StatusCode::BAD_REQUEST, + AppError::EndpointRecordSelector { .. } => http::StatusCode::BAD_REQUEST, + AppError::EndpointSelectorMismatch => http::StatusCode::BAD_REQUEST, AppError::NoAnswersInPacket => http::StatusCode::UNPROCESSABLE_ENTITY, AppError::SignatureRequired => http::StatusCode::BAD_REQUEST, AppError::InvalidSignature => http::StatusCode::BAD_REQUEST, @@ -68,8 +84,8 @@ pub fn normalize_host(host: &str) -> Result { let host = idna::domain_to_ascii(host).map_err(|_| AppError::InvalidHost)?; let host = host.to_ascii_lowercase(); - // 校验是否为 genmeta.net 域名 - if !host.ends_with("genmeta.net") { + // 校验是否为 DHTTP identity 域名 + if !host.ends_with(DhttpName::SUFFIX) { return Err(AppError::DomainNotAllowed); } @@ -82,3 +98,20 @@ pub fn parse_query_params(uri: &http::Uri) -> HashMap { .into_owned() .collect() } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn normalize_host_uses_dhttp_identity_suffix() { + assert_eq!( + normalize_host("Reimu.Pilot.Dhttp.Net:443").unwrap(), + "reimu.pilot.dhttp.net" + ); + assert!(matches!( + normalize_host("reimu.pilot.genmeta.net"), + Err(AppError::DomainNotAllowed) + )); + } +} diff --git a/gmdns-server/src/lookup.rs b/src/bin/ddns-server/lookup.rs similarity index 80% rename from gmdns-server/src/lookup.rs rename to src/bin/ddns-server/lookup.rs index 34fd1ba..e8c2616 100644 --- a/gmdns-server/src/lookup.rs +++ b/src/bin/ddns-server/lookup.rs @@ -1,22 +1,27 @@ use std::{ collections::{HashMap, HashSet}, + convert::Infallible, net::SocketAddr, }; -use futures::future::BoxFuture; -use gmdns::{ +use ddns::core::{ MdnsPacket, parser::{packet::be_packet, record::RData}, + wire::MultiResponse, }; -use h3x::server::{Request, Response, Service}; -use redis::AsyncCommands; +use deadpool_redis::redis::{self, AsyncCommands}; +use h3x::dhttp::message::MessageStreamError; +use http_body_util::{Full, combinators::UnsyncBoxBody}; use tracing::debug; use crate::{ error::{AppError, normalize_host, parse_query_params}, - storage::{AppState, LookupRecord, MultiResponse, Storage, StoredRecord, unix_now_secs}, + storage::{AppState, LookupRecord, Storage, StoredRecord, unix_now_secs}, }; +pub type Request = http::Request>; +pub type Response = http::Response>; + // --------------------------------------------------------------------------- // Lookup result type // --------------------------------------------------------------------------- @@ -99,7 +104,7 @@ async fn perform_lookup_multi( .arg(&set_key) .arg("-inf") .arg(cutoff_score) - .query_async::<_, ()>(&mut *conn) + .query_async::<()>(&mut *conn) .await .unwrap_or(()); @@ -163,10 +168,20 @@ async fn perform_lookup_multi( // HTTP response helpers // --------------------------------------------------------------------------- -pub async fn write_error(resp: &mut Response, err: AppError) { - resp.set_status(err.status()) - .set_body(bytes::Bytes::from(format!("{}", err))); - let _ = resp.flush().await; +pub fn body_response(status: http::StatusCode, body: impl Into) -> Response { + http::Response::builder() + .status(status) + .body(Full::new(body.into())) + .expect("response parts must be valid") +} + +pub fn write_error(err: AppError) -> Response { + debug!( + status = %err.status(), + error = %err, + "writing error response" + ); + body_response(err.status(), bytes::Bytes::from(err.to_string())) } // --------------------------------------------------------------------------- @@ -186,11 +201,10 @@ pub struct LookupSvc { /// /// Optional query param `limit=N` caps the number of records returned. /// Dynamic records are newest-first; configured seed records are appended after them. -pub async fn lookup_with_cert(state: AppState, request: &mut Request, response: &mut Response) { - let params = parse_query_params(&request.uri()); +pub async fn lookup_with_cert(state: AppState, request: Request) -> Response { + let params = parse_query_params(request.uri()); let Some(host) = params.get("host") else { - write_error(response, AppError::MissingHostParam).await; - return; + return write_error(AppError::MissingHostParam); }; let limit: Option = params @@ -203,38 +217,33 @@ pub async fn lookup_with_cert(state: AppState, request: &mut Request, response: match perform_lookup(&state, host, limit).await { Ok(LookupResult::NotFound) => { debug!(host = %host, "lookup.not_found"); - response - .set_status(http::StatusCode::NOT_FOUND) - .set_body(bytes::Bytes::from_static(b"Not Found")); - let _ = response.flush().await; + body_response( + http::StatusCode::NOT_FOUND, + bytes::Bytes::from_static(b"Not Found"), + ) } Ok(LookupResult::Multi(resp)) => { let body = resp.encode(); debug!(host = %host, records = resp.records.len(), "lookup.found"); - response - .set_status(http::StatusCode::OK) - .set_body(bytes::Bytes::from(body)); + let mut response = body_response(http::StatusCode::OK, bytes::Bytes::from(body)); response.headers_mut().insert( http::HeaderName::from_static("x-record-format"), http::HeaderValue::from_static("multi"), ); - let _ = response.flush().await; + response } - Err(e) => { - write_error(response, e).await; - } + Err(e) => write_error(e), } } -impl Service for LookupSvc { - type Future<'s> = BoxFuture<'s, ()>; - - fn serve<'s>(&self, request: &'s mut Request, response: &'s mut Response) -> Self::Future<'s> { +impl LookupSvc { + pub fn call( + &self, + request: Request, + ) -> impl Future> + Send + 'static { let state = self.state.clone(); - Box::pin(async move { - lookup_with_cert(state, request, response).await; - }) + async move { Ok(lookup_with_cert(state, request).await) } } } diff --git a/gmdns-server/src/main.rs b/src/bin/ddns-server/main.rs similarity index 63% rename from gmdns-server/src/main.rs rename to src/bin/ddns-server/main.rs index 57cde39..735fa74 100644 --- a/gmdns-server/src/main.rs +++ b/src/bin/ddns-server/main.rs @@ -5,16 +5,25 @@ mod policy; mod publish; mod storage; -use std::{collections::HashMap, io, net::SocketAddr, sync::Arc}; +use std::{ + collections::HashMap, + io, + net::SocketAddr, + sync::Arc, + task::{Context, Poll}, +}; use clap::Parser; -use gmdns::{MdnsEndpoint, MdnsPacket}; +use ddns::core::{MdnsEndpoint, MdnsPacket}; +use futures::future::BoxFuture; use h3x::{ - dquic::prelude::{ - BindUri, - handy::{ToCertificate, ToPrivateKey}, + dquic::{ + Identity, Network, QuicEndpoint, + cert::handy::{ToCertificate, ToPrivateKey}, + server::ServerQuicConfig, }, - server::{Router, Servers}, + endpoint::H3Endpoint, + hyper::TowerService, }; use rustls::{RootCertStore, server::WebPkiClientVerifier}; use tracing::{info, level_filters::LevelFilter}; @@ -28,6 +37,49 @@ use crate::{ storage::{AppState, MemoryStorage, SeedRecords, Storage}, }; +#[derive(Clone)] +struct DnsService { + publish: PublishSvc, + lookup: LookupSvc, +} + +impl tower_service::Service for DnsService { + type Response = lookup::Response; + type Error = io::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, request: lookup::Request) -> Self::Future { + let method = request.method().clone(); + let path = request.uri().path().to_owned(); + let publish = self.publish.clone(); + let lookup = self.lookup.clone(); + Box::pin(async move { + match (method, path.as_str()) { + (http::Method::POST, "/publish") => match publish.call(request).await { + Ok(response) => Ok(response), + Err(never) => match never {}, + }, + (http::Method::GET, "/lookup") => match lookup.call(request).await { + Ok(response) => Ok(response), + Err(never) => match never {}, + }, + (_, "/publish" | "/lookup") => Ok(lookup::body_response( + http::StatusCode::METHOD_NOT_ALLOWED, + bytes::Bytes::from_static(b"Method Not Allowed"), + )), + _ => Ok(lookup::body_response( + http::StatusCode::NOT_FOUND, + bytes::Bytes::from_static(b"Not Found"), + )), + } + }) + } +} + // --------------------------------------------------------------------------- // TLS helpers // --------------------------------------------------------------------------- @@ -52,7 +104,7 @@ fn build_seed_records(seed_records: &[SeedRecordConfig]) -> io::Result io::Result Result<(), Box> { - // Install ring crypto provider - rustls::crypto::ring::default_provider() - .install_default() - .expect("Failed to install ring crypto provider"); - tracing_subscriber::registry() .with(tracing_subscriber::fmt::layer()) .with(tracing_subscriber::filter::filter_fn(|metadata| { @@ -159,46 +206,36 @@ async fn main() -> Result<(), Box> { let cert_pem = std::fs::read(&config.cert)?; let key_pem = std::fs::read(&config.key)?; - let router = Router::new() - .post( - "/publish", - PublishSvc { - state: state.clone(), - }, - ) - .get( - "/lookup", - LookupSvc { - state: state.clone(), - }, - ); - - let bind = { - let base = BindUri::from(format!("inet://{}", config.listen)); - if config.listen.port() == 0 { - base.alloc_port() - } else { - base - } - }; + let router = TowerService(DnsService { + publish: PublishSvc { + state: state.clone(), + }, + lookup: LookupSvc { + state: state.clone(), + }, + }); - let mut servers = Servers::builder() - .with_client_cert_verifier(verifier)? - .listen()?; - - servers - .add_server( - config.server_name.clone(), - cert_pem.to_certificate(), - key_pem.to_private_key(), - None, - [bind], - router, - ) - .await?; - - info!(listen = %config.listen, server_name = %config.server_name, "h3_server.start"); - _ = servers.run().await; + let identity = Arc::new(Identity { + name: config.server_name.parse().unwrap(), + certs: Arc::new(cert_pem.to_certificate()), + key: Arc::new(key_pem.to_private_key()), + ocsp: Arc::new(None), + }); + let server_config = ServerQuicConfig { + alpns: vec![b"h3".to_vec()], + client_cert_verifier: verifier, + ..Default::default() + }; + let quic = QuicEndpoint::builder() + .network(Network::builder().build()) + .identity(identity) + .server(server_config) + .bind(Arc::new(config.binds.clone())) + .build() + .await; + let server = Arc::new(H3Endpoint::new(quic)); + info!(binds = ?config.binds, server_name = %config.server_name, "h3_server.start"); + server.listen_owned(router).await?; Ok(()) } diff --git a/src/bin/ddns-server/policy.rs b/src/bin/ddns-server/policy.rs new file mode 100644 index 0000000..413ad4b --- /dev/null +++ b/src/bin/ddns-server/policy.rs @@ -0,0 +1,337 @@ +use ddns::core::parser::{packet::be_packet, record::RData}; +use dhttp_identity::identity::{RemoteAuthority, RemoteAuthorityCertificateExt}; +use snafu::ResultExt; +use tracing::{debug, warn}; + +use crate::error::{AppError, app_error, normalize_host}; + +// --------------------------------------------------------------------------- +// Domain policy +// --------------------------------------------------------------------------- + +/// Per-domain publish / lookup behaviour. +#[derive(Clone, Debug, PartialEq)] +pub enum DomainPolicy { + /// Signature check controlled by `require_signature` flag; single record + /// per host; each publish overwrites the previous one. + Standard, + /// No signature check; any authenticated node may publish; multiple records + /// with individual TTLs; ordered newest-first on lookup. + OpenMulti, +} + +/// One rule in the domain-policy list. +#[derive(Clone, Debug)] +pub enum PolicyRule { + /// Matches only this exact (normalised) host. + Exact(String), + /// Matches the host itself or any label-subdomain (future use). + #[allow(dead_code)] + Suffix(String), +} + +impl PolicyRule { + pub fn matches(&self, host: &str) -> bool { + match self { + PolicyRule::Exact(exact) => host == exact, + PolicyRule::Suffix(suffix) => { + host == suffix.as_str() || host.ends_with(&format!(".{suffix}")) + } + } + } +} + +/// Ordered list of (rule, policy) pairs; first match wins; default is Standard. +#[derive(Clone, Debug, Default)] +pub struct DomainPolicies(pub Vec<(PolicyRule, DomainPolicy)>); + +impl DomainPolicies { + pub fn policy_for(&self, host: &str) -> &DomainPolicy { + for (rule, policy) in &self.0 { + if rule.matches(host) { + return policy; + } + } + &DomainPolicy::Standard + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ValidatedDnsPacket { + Records { host: String }, + Empty, +} + +// --------------------------------------------------------------------------- +// Certificate helpers +// --------------------------------------------------------------------------- + +pub fn extract_client_dns_sans(authority: &(impl RemoteAuthority + ?Sized)) -> Vec { + use x509_parser::prelude::*; + + let Some(leaf) = authority.cert_chain().first() else { + return vec![]; + }; + + let Ok((_remain, cert)) = X509Certificate::from_der(leaf.as_ref()) else { + return vec![]; + }; + + let mut out = vec![]; + if let Ok(Some(san)) = cert.subject_alternative_name() { + for name in san.value.general_names.iter() { + if let GeneralName::DNSName(dns) = name { + out.push(dns.to_string()); + } + } + } + out +} + +pub fn client_allowed_host( + authority: &(impl RemoteAuthority + ?Sized), +) -> Result { + let mut sans = extract_client_dns_sans(authority) + .into_iter() + .filter_map(|h| normalize_host(&h).ok()) + .collect::>(); + + sans.sort(); + sans.dedup(); + + match sans.len() { + 1 => Ok(sans.remove(0)), + _ => Err(AppError::ClientCertDomainNotAllowed), + } +} + +pub fn validate_dns_packet( + packet: &[u8], + require_signature: bool, + authority: &(impl RemoteAuthority + ?Sized), +) -> Result { + let (remaining, dns_packet) = be_packet(packet).map_err(|e| AppError::InvalidDnsPacket { + message: e.to_string(), + })?; + if !remaining.is_empty() { + warn!(remain = remaining.len(), "dns.parse.extra_bytes"); + } + debug!( + answers = dns_packet.answers.len(), + require_signature, "validating dns packet" + ); + + let Some(first_answer) = dns_packet.answers.first() else { + debug!("dns packet has no answers"); + return Ok(ValidatedDnsPacket::Empty); + }; + + validate_endpoint_selectors(&dns_packet, authority)?; + + if require_signature { + let has_signature = dns_packet + .answers + .iter() + .any(|record| matches!(record.data(), RData::E(endpoint) if endpoint.is_signed())); + + if !has_signature { + return Err(AppError::SignatureRequired); + } + + for record in &dns_packet.answers { + if let RData::E(endpoint) = record.data() + && endpoint.is_signed() + { + let cert = authority + .cert_chain() + .first() + .ok_or(AppError::MissingClientCertificate)?; + let ok = endpoint + .verify_signature_from_der(cert.as_ref()) + .map_err(|_| AppError::InvalidSignature)?; + if !ok { + return Err(AppError::InvalidSignature); + } + } + } + } + + Ok(ValidatedDnsPacket::Records { + host: first_answer.name().to_string(), + }) +} + +fn validate_endpoint_selectors( + dns_packet: &ddns::core::parser::packet::Packet, + authority: &(impl RemoteAuthority + ?Sized), +) -> Result<(), AppError> { + let mut endpoints = dns_packet + .answers + .iter() + .filter_map(|record| match record.data() { + RData::E(endpoint) => Some(endpoint), + _ => None, + }); + + let Some(first_endpoint) = endpoints.next() else { + return Ok(()); + }; + + let expected = authority + .dhttp_subject_key_identifier() + .context(app_error::PublisherCertificateSelectorSnafu)? + .chain() + .clone(); + + let first = first_endpoint + .certificate_chain_key() + .context(app_error::EndpointRecordSelectorSnafu)?; + if first != expected { + return Err(AppError::EndpointSelectorMismatch); + } + + for endpoint in endpoints { + let actual = endpoint + .certificate_chain_key() + .context(app_error::EndpointRecordSelectorSnafu)?; + if actual != expected { + return Err(AppError::EndpointSelectorMismatch); + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use ddns::core::{MdnsPacket, parser::record::endpoint::EndpointAddr}; + use dhttp_identity::identity::RemoteAuthority; + use rustls::pki_types::CertificateDer; + + use super::*; + + #[derive(Debug)] + struct TestAuthority { + certs: Vec>, + } + + impl TestAuthority { + fn valid() -> Self { + Self { + certs: vec![CertificateDer::from( + include_bytes!("../../../tests/fixtures/valid.der").to_vec(), + )], + } + } + + fn missing_ski() -> Self { + Self { + certs: vec![CertificateDer::from( + include_bytes!("../../../tests/fixtures/missing.der").to_vec(), + )], + } + } + + fn malformed_ski() -> Self { + Self { + certs: vec![CertificateDer::from( + include_bytes!("../../../tests/fixtures/malformed.der").to_vec(), + )], + } + } + } + + impl RemoteAuthority for TestAuthority { + fn name(&self) -> &str { + "authority.example" + } + + fn cert_chain(&self) -> &[CertificateDer<'static>] { + &self.certs + } + } + + fn packet_with_endpoint(endpoint: EndpointAddr) -> Vec { + let hosts: HashMap> = + HashMap::from([("reimu.pilot.dhttp.net".to_owned(), vec![endpoint])]); + MdnsPacket::answer(0, &hosts).to_bytes() + } + + #[test] + fn validate_dns_packet_accepts_matching_certificate_selector() { + let mut endpoint = EndpointAddr::direct_v4("192.0.2.10:4433".parse().unwrap()); + endpoint.set_main(true); + endpoint.set_sequence(0); + let packet = packet_with_endpoint(endpoint); + + let validated = validate_dns_packet(&packet, false, &TestAuthority::valid()).unwrap(); + + assert!(matches!(validated, ValidatedDnsPacket::Records { .. })); + } + + #[test] + fn validate_dns_packet_rejects_mismatched_endpoint_kind() { + let mut endpoint = EndpointAddr::direct_v4("192.0.2.10:4433".parse().unwrap()); + endpoint.set_main(false); + endpoint.set_sequence(0); + let packet = packet_with_endpoint(endpoint); + + let error = validate_dns_packet(&packet, false, &TestAuthority::valid()).unwrap_err(); + + assert!(matches!(error, AppError::EndpointSelectorMismatch)); + } + + #[test] + fn validate_dns_packet_rejects_mismatched_endpoint_sequence() { + let mut endpoint = EndpointAddr::direct_v4("192.0.2.10:4433".parse().unwrap()); + endpoint.set_main(true); + endpoint.set_sequence(7); + let packet = packet_with_endpoint(endpoint); + + let error = validate_dns_packet(&packet, false, &TestAuthority::valid()).unwrap_err(); + + assert!(matches!(error, AppError::EndpointSelectorMismatch)); + } + + #[test] + fn validate_dns_packet_rejects_missing_publisher_ski() { + let mut endpoint = EndpointAddr::direct_v4("192.0.2.10:4433".parse().unwrap()); + endpoint.set_main(true); + let packet = packet_with_endpoint(endpoint); + + let error = validate_dns_packet(&packet, false, &TestAuthority::missing_ski()).unwrap_err(); + + assert!(matches!( + error, + AppError::PublisherCertificateSelector { .. } + )); + } + + #[test] + fn validate_dns_packet_rejects_malformed_publisher_ski() { + let mut endpoint = EndpointAddr::direct_v4("192.0.2.10:4433".parse().unwrap()); + endpoint.set_main(true); + let packet = packet_with_endpoint(endpoint); + + let error = + validate_dns_packet(&packet, false, &TestAuthority::malformed_ski()).unwrap_err(); + + assert!(matches!( + error, + AppError::PublisherCertificateSelector { .. } + )); + } + + #[test] + fn validate_dns_packet_accepts_empty_packet_as_clear_operation() { + let hosts: HashMap> = + HashMap::from([("reimu.pilot.dhttp.net".to_owned(), Vec::new())]); + let packet = MdnsPacket::answer(0, &hosts).to_bytes(); + + let validated = validate_dns_packet(&packet, true, &TestAuthority::valid()).unwrap(); + + assert!(matches!(validated, ValidatedDnsPacket::Empty)); + } +} diff --git a/src/bin/ddns-server/publish.rs b/src/bin/ddns-server/publish.rs new file mode 100644 index 0000000..5b07347 --- /dev/null +++ b/src/bin/ddns-server/publish.rs @@ -0,0 +1,427 @@ +use std::{convert::Infallible, sync::Arc}; + +use deadpool_redis::redis::{self, AsyncCommands}; +use dhttp_identity::identity::RemoteAuthority; +use h3x::{connection::ConnectionState, quic}; +use http_body_util::BodyExt; +use tokio::time::{Duration, Instant}; +use tracing::{debug, info, warn}; + +use crate::{ + error::{AppError, normalize_host, parse_query_params}, + lookup::{Request, Response, body_response, write_error}, + policy::{DomainPolicy, ValidatedDnsPacket, client_allowed_host, validate_dns_packet}, + storage::{ + AppState, Record, Storage, StoredRecord, cert_fingerprint, cert_fingerprint_hex, + unix_now_secs, + }, +}; + +// --------------------------------------------------------------------------- +// PublishSvc +// --------------------------------------------------------------------------- + +#[derive(Clone)] +pub struct PublishSvc { + pub state: AppState, +} + +impl PublishSvc { + pub fn call( + &self, + request: Request, + ) -> impl Future> + Send + 'static { + let state = self.state.clone(); + async move { Ok(publish_with_cert(state, request).await) } + } +} + +async fn publish_with_cert(state: AppState, request: Request) -> Response { + debug!("received publish request"); + + let params = parse_query_params(request.uri()); + debug!("query params: {:?}", params); + + let Some(host) = params.get("host") else { + warn!("missing host parameter"); + return write_error(AppError::MissingHostParam); + }; + + let host = match normalize_host(host) { + Ok(h) => h, + Err(e) => return write_error(e), + }; + debug!(host = %host, "publish.host"); + + // Require a valid client certificate for all publish requests. + let authority = match request_connection(&request) { + Some(connection) => match connection.remote_authority().await { + Ok(Some(authority)) => authority, + Ok(None) => { + warn!("missing client certificate"); + return write_error(AppError::MissingClientCertificate); + } + Err(error) => { + warn!(error = %snafu::Report::from_error(&error), "failed to read client certificate"); + return write_error(AppError::MissingClientCertificate); + } + }, + None => { + warn!("missing client certificate"); + return write_error(AppError::MissingClientCertificate); + } + }; + + let policy = state.policies.policy_for(&host).clone(); + + // Standard policy: cert SAN must match the target host. + // OpenMulti policy: any authenticated node may publish — skip SAN check. + if policy == DomainPolicy::Standard { + let allowed = match client_allowed_host(authority.as_ref()) { + Ok(h) => h, + Err(e) => { + warn!(error = %snafu::Report::from_error(&e), "client certificate domain not allowed"); + return write_error(e); + } + }; + if allowed != host { + warn!(allowed = %allowed, requested = %host, "publish.host_mismatch"); + return write_error(AppError::HostMismatch); + } + } + + let body = match request.into_body().collect().await { + Ok(body) => body.to_bytes(), + Err(e) => { + warn!(error = %snafu::Report::from_error(&e), "failed to read request body"); + return write_error(AppError::InvalidDnsPacket { + message: e.to_string(), + }); + } + }; + + // Validate DNS packet; signature check only for Standard hosts. + let require_sig = policy == DomainPolicy::Standard && state.require_signature; + debug!( + host = %host, + bytes = body.len(), + require_signature = require_sig, + "validating publish packet" + ); + let packet = match validate_dns_packet(body.as_ref(), require_sig, authority.as_ref()) { + Ok(n) => n, + Err(e) => { + debug!(host = %host, error = %e, "publish packet rejected"); + return write_error(e); + } + }; + + match packet { + ValidatedDnsPacket::Records { host: packet_name } => { + let packet_host = match normalize_host(&packet_name) { + Ok(h) => h, + Err(e) => return write_error(e), + }; + + if packet_host != host { + return write_error(AppError::HostMismatch); + } + + publish_record(&state, &host, &body, authority.as_ref()).await + } + ValidatedDnsPacket::Empty => clear_record(&state, &host, authority.as_ref()).await, + } +} + +fn request_connection(request: &Request) -> Option>> { + request + .extensions() + .get::>>() + .cloned() +} + +/// Unified publish handler: stores the record keyed by (host, cert-fingerprint). +/// Both Standard and OpenMulti policies follow the same storage path; +/// the only policy difference (SAN check) is already enforced in the caller. +/// +/// Certificate fingerprint is the publish-source identity. In PKI ecosystems, +/// a single domain name can have multiple valid certificates (from different CAs, +/// or issued at different times for rotation/failover/multi-region scenarios). +/// Using fingerprint as part of the storage key enables: +/// - Multi-publisher coexistence: different cert holders can publish the same domain +/// - Idempotent updates: re-publishing from same cert source (same fingerprint) overwrites old data +/// - Client choice: lookups return all active records, client picks which certificate to trust +pub async fn publish_record( + state: &AppState, + host: &str, + body: &bytes::Bytes, + authority: &(impl RemoteAuthority + ?Sized), +) -> Response { + let cert_bytes = authority + .cert_chain() + .first() + .map(|c| c.as_ref().to_vec()) + .unwrap_or_default(); + + let fp = cert_fingerprint(&cert_bytes); + let fp_hex = cert_fingerprint_hex(&cert_bytes); + + match &state.storage { + Storage::Redis(pool) => { + let mut conn = match pool.get().await { + Ok(c) => c, + Err(e) => { + return write_error(AppError::Redis { + message: e.to_string(), + }); + } + }; + let ttl_secs = state.ttl_secs; + let expire_ttl_secs = i64::try_from(state.ttl_secs).unwrap_or(i64::MAX); + let now_secs = unix_now_secs(); + let expire_secs = now_secs + state.ttl_secs; + + let fp_key = format!("{host}:fp:{fp_hex}"); + let set_key = format!("{host}:multi"); + + // Remove the previous entry from this source (if any) from the ZSET. + let old_member: Option> = conn.get(&fp_key).await.unwrap_or(None); + if let Some(old) = old_member { + let _: () = conn.zrem(&set_key, &old).await.unwrap_or(()); + } + + // Encode and store the new member. + let new_member = StoredRecord { + expire_unix_secs: expire_secs, + fingerprint: fp, + dns: body.to_vec(), + cert: cert_bytes.clone(), + } + .encode(); + + if let Err(e) = conn + .set_ex::<_, _, ()>(&fp_key, &new_member, ttl_secs) + .await + { + return write_error(AppError::Redis { + message: e.to_string(), + }); + } + + if let Err(e) = conn + .zadd::<_, _, _, ()>(&set_key, &new_member, now_secs as f64) + .await + { + return write_error(AppError::Redis { + message: e.to_string(), + }); + } + + // Expire the ZSET key at max(ttl_secs) from now as a safety net. + let _: bool = conn + .expire(&set_key, expire_ttl_secs) + .await + .unwrap_or(false); + + // Evict stale (score < now - ttl) entries. + let cutoff = now_secs.saturating_sub(state.ttl_secs) as f64; + let _: () = redis::cmd("ZREMRANGEBYSCORE") + .arg(&set_key) + .arg("-inf") + .arg(cutoff) + .query_async::<()>(&mut *conn) + .await + .unwrap_or(()); + } + Storage::Memory(mem) => { + let now = Instant::now(); + let expire = now + Duration::from_secs(state.ttl_secs); + let record = Record { + dns_bytes: body.to_vec(), + cert_bytes, + expire, + published_at: now, + }; + // Upsert by fingerprint: same source overwrites its own entry; + // different sources (different certs) coexist independently. + let mut host_map = mem.records.entry(host.to_string()).or_default(); + host_map.insert(fp, record); + // Evict expired entries while we hold the write lock. + host_map.retain(|_, r| r.expire > now); + } + } + + info!(host = %host, ttl = state.ttl_secs, bytes = body.len(), fp = %fp_hex, "publish.ok"); + body_response(http::StatusCode::OK, bytes::Bytes::from_static(b"OK")) +} + +pub async fn clear_record( + state: &AppState, + host: &str, + authority: &(impl RemoteAuthority + ?Sized), +) -> Response { + let cert_bytes = authority + .cert_chain() + .first() + .map(|c| c.as_ref().to_vec()) + .unwrap_or_default(); + + let fp = cert_fingerprint(&cert_bytes); + let fp_hex = cert_fingerprint_hex(&cert_bytes); + + match &state.storage { + Storage::Redis(pool) => { + let mut conn = match pool.get().await { + Ok(c) => c, + Err(e) => { + return write_error(AppError::Redis { + message: e.to_string(), + }); + } + }; + + let fp_key = format!("{host}:fp:{fp_hex}"); + let set_key = format!("{host}:multi"); + + let old_member: Option> = conn.get(&fp_key).await.unwrap_or(None); + if let Some(old) = old_member { + let _: () = conn.zrem(&set_key, &old).await.unwrap_or(()); + } + if let Err(e) = conn.del::<_, ()>(&fp_key).await { + return write_error(AppError::Redis { + message: e.to_string(), + }); + } + } + Storage::Memory(mem) => { + let remove_host = if let Some(mut host_map) = mem.records.get_mut(host) { + host_map.remove(&fp); + host_map.is_empty() + } else { + false + }; + if remove_host { + mem.records.remove(host); + } + } + } + + info!(host = %host, fp = %fp_hex, "publish.clear"); + body_response(http::StatusCode::OK, bytes::Bytes::from_static(b"OK")) +} + +#[cfg(test)] +mod tests { + use std::{ + collections::HashMap, + net::{Ipv4Addr, SocketAddrV4}, + sync::Arc, + }; + + use ddns::core::{MdnsPacket, parser::record::endpoint::EndpointAddr}; + use dhttp_identity::identity::RemoteAuthority; + use rustls::pki_types::CertificateDer; + + use super::*; + use crate::{ + lookup::{LookupResult, perform_lookup}, + policy::DomainPolicies, + storage::{MemoryStorage, SeedRecords}, + }; + + #[derive(Debug)] + struct TestAuthority { + name: &'static str, + certs: Vec>, + } + + impl TestAuthority { + fn new(name: &'static str, cert_bytes: Vec) -> Self { + Self { + name, + certs: vec![CertificateDer::from(cert_bytes)], + } + } + } + + impl RemoteAuthority for TestAuthority { + fn name(&self) -> &str { + self.name + } + + fn cert_chain(&self) -> &[CertificateDer<'static>] { + &self.certs + } + } + + fn memory_state() -> AppState { + AppState { + storage: Storage::Memory(MemoryStorage::new()), + require_signature: true, + ttl_secs: 30, + policies: Arc::new(DomainPolicies::default()), + seed_records: SeedRecords::default(), + } + } + + fn packet_for(host: &str, last_octet: u8) -> bytes::Bytes { + let endpoint = EndpointAddr::direct_v4(SocketAddrV4::new( + Ipv4Addr::new(203, 0, 113, last_octet), + 4433, + )); + let mut hosts = HashMap::new(); + hosts.insert(host.to_owned(), vec![endpoint]); + bytes::Bytes::from(MdnsPacket::answer(0, &hosts).to_bytes()) + } + + #[tokio::test] + async fn clear_record_removes_only_current_certificate_fingerprint() { + let state = memory_state(); + let host = "reimu.pilot.dhttp.net"; + let authority_a = TestAuthority::new("authority-a", vec![1]); + let authority_b = TestAuthority::new("authority-b", vec![2]); + let packet_a = packet_for(host, 1); + let packet_b = packet_for(host, 2); + + assert_eq!( + publish_record(&state, host, &packet_a, &authority_a) + .await + .status(), + http::StatusCode::OK + ); + assert_eq!( + publish_record(&state, host, &packet_b, &authority_b) + .await + .status(), + http::StatusCode::OK + ); + + assert_eq!( + clear_record(&state, host, &authority_a).await.status(), + http::StatusCode::OK + ); + + let LookupResult::Multi(response) = perform_lookup(&state, host, None).await.unwrap() + else { + panic!("authority b record should remain"); + }; + assert_eq!(response.records.len(), 1); + assert_eq!(response.records[0].cert, authority_b.certs[0].as_ref()); + } + + #[tokio::test] + async fn clear_record_is_idempotent_for_missing_fingerprint() { + let state = memory_state(); + let host = "reimu.pilot.dhttp.net"; + let authority = TestAuthority::new("authority", vec![1]); + + assert_eq!( + clear_record(&state, host, &authority).await.status(), + http::StatusCode::OK + ); + assert!(matches!( + perform_lookup(&state, host, None).await.unwrap(), + LookupResult::NotFound + )); + } +} diff --git a/gmdns-server/src/storage.rs b/src/bin/ddns-server/storage.rs similarity index 67% rename from gmdns-server/src/storage.rs rename to src/bin/ddns-server/storage.rs index b489387..e194faf 100644 --- a/gmdns-server/src/storage.rs +++ b/src/bin/ddns-server/storage.rs @@ -122,97 +122,6 @@ pub fn be_stored_record(input: &[u8]) -> IResult<&[u8], StoredRecord> { )) } -// --------------------------------------------------------------------------- -// HTTP multi-record response wire type -// --------------------------------------------------------------------------- - -/// One DNS + certificate pair inside a [`MultiResponse`]. -#[derive(Debug, Clone)] -pub struct ResponseRecord { - /// Serialised DNS packet bytes. - pub dns: Vec, - /// DER-encoded leaf certificate of the publisher (may be empty). - pub cert: Vec, -} - -/// HTTP response body carrying zero or more DNS records. -/// -/// Wire layout (big-endian, contiguous): -/// ```text -/// +-----------+ (repeated `count` times) -/// | count | +-----------+------+-----------+------+ -/// | u32 BE | | dns_len | dns | cert_len | cert | -/// +-----------+ | u32 BE | ... | u32 BE | ... | -/// +-----------+------+-----------+------+ -/// ``` -#[derive(Debug, Clone)] -pub struct MultiResponse { - pub records: Vec, -} - -impl MultiResponse { - pub fn new(iter: impl IntoIterator, Vec)>) -> Self { - Self { - records: iter - .into_iter() - .map(|(dns, cert)| ResponseRecord { dns, cert }) - .collect(), - } - } - - pub fn encoding_size(&self) -> usize { - 4 + self - .records - .iter() - .map(|r| 4 + r.dns.len() + 4 + r.cert.len()) - .sum::() - } - - /// Encode to a byte buffer sent as the HTTP response body. - pub fn encode(&self) -> Vec { - let mut buf = Vec::with_capacity(self.encoding_size()); - buf.put_multi_response(self); - buf - } -} - -/// `BufMut` write extension for [`MultiResponse`]. -pub trait WriteMultiResponse { - fn put_multi_response(&mut self, resp: &MultiResponse); -} - -impl WriteMultiResponse for B { - fn put_multi_response(&mut self, resp: &MultiResponse) { - self.put_u32(resp.records.len() as u32); - for r in &resp.records { - self.put_u32(r.dns.len() as u32); - self.put_slice(&r.dns); - self.put_u32(r.cert.len() as u32); - self.put_slice(&r.cert); - } - } -} - -/// nom parser for [`MultiResponse`]. -/// Used by the client-side decoder; provided here to keep the wire format symmetric and testable. -#[allow(dead_code)] -pub fn be_multi_response(input: &[u8]) -> IResult<&[u8], MultiResponse> { - let (mut input, count) = be_u32(input)?; - let mut records = Vec::with_capacity(count as usize); - for _ in 0..count { - let (rest, dns_len) = be_u32(input)?; - let (rest, dns) = take(dns_len as usize)(rest)?; - let (rest, cert_len) = be_u32(rest)?; - let (rest, cert) = take(cert_len as usize)(rest)?; - records.push(ResponseRecord { - dns: dns.to_vec(), - cert: cert.to_vec(), - }); - input = rest; - } - Ok((input, MultiResponse { records })) -} - // --------------------------------------------------------------------------- // Storage // --------------------------------------------------------------------------- diff --git a/src/bootstrap.rs b/src/bootstrap.rs new file mode 100644 index 0000000..aa8c097 --- /dev/null +++ b/src/bootstrap.rs @@ -0,0 +1 @@ +include!(concat!(env!("OUT_DIR"), "/bootstrap.rs")); diff --git a/src/core.rs b/src/core.rs new file mode 100644 index 0000000..308bbb2 --- /dev/null +++ b/src/core.rs @@ -0,0 +1,5 @@ +pub mod parser; +pub mod wire; + +pub type MdnsEndpoint = parser::record::endpoint::EndpointAddr; +pub type MdnsPacket = parser::packet::Packet; diff --git a/src/core/parser.rs b/src/core/parser.rs new file mode 100644 index 0000000..847dca7 --- /dev/null +++ b/src/core/parser.rs @@ -0,0 +1,7 @@ +pub mod header; +pub mod name; +pub mod packet; +pub mod question; +pub mod record; +pub mod sigin; +pub mod varint; diff --git a/src/parser/header.rs b/src/core/parser/header.rs similarity index 100% rename from src/parser/header.rs rename to src/core/parser/header.rs diff --git a/src/parser/name.rs b/src/core/parser/name.rs similarity index 100% rename from src/parser/name.rs rename to src/core/parser/name.rs diff --git a/src/parser/packet.rs b/src/core/parser/packet.rs similarity index 97% rename from src/parser/packet.rs rename to src/core/parser/packet.rs index 345ee8f..fb7bc9b 100644 --- a/src/parser/packet.rs +++ b/src/core/parser/packet.rs @@ -4,14 +4,15 @@ use bytes::BufMut; use super::{ header::{Header, be_header}, - question::{QueryClass, QueryType, Question, be_question}, - record::{Class, RData, ResourceRecord, endpoint::EndpointAddr}, -}; -use crate::parser::{ - header::WriteHeader, name::{NameCompression, put_name}, - record::{Type, be_record, endpoint::WriteEndpointAddr, srv::Srv}, + question::{QueryClass, QueryType, Question, be_question}, + record::{ + Class, RData, ResourceRecord, Type, be_record, + endpoint::{EndpointAddr, WriteEndpointAddr}, + srv::Srv, + }, }; +use crate::core::parser::header::WriteHeader; /// Parsed DNS packet #[derive(Default, Clone)] @@ -72,6 +73,14 @@ impl fmt::Display for Packet { } impl Packet { + pub fn id(&self) -> u16 { + self.header.id + } + + pub fn is_query(&self) -> bool { + self.header.flags.query() + } + pub fn query_with_id(service_name: String) -> Self { let mut packet = Packet::default(); let id: u16 = rand::random(); @@ -275,7 +284,7 @@ mod test { use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use super::*; - use crate::parser::{ + use crate::core::parser::{ self, question::{QueryClass, QueryType}, record::{Class, RData, Type, srv::Srv}, diff --git a/src/parser/question.rs b/src/core/parser/question.rs similarity index 93% rename from src/parser/question.rs rename to src/core/parser/question.rs index e498823..a9126f8 100644 --- a/src/parser/question.rs +++ b/src/core/parser/question.rs @@ -1,5 +1,6 @@ +use std::io; + use nom::number::streaming::be_u16; -use tokio::io; use super::name::{Name, be_name}; @@ -24,6 +25,24 @@ pub struct Question { pub(crate) qclass: QueryClass, } +impl Question { + pub fn name(&self) -> &Name { + &self.name + } + + pub fn prefer_unicast(&self) -> bool { + self.prefer_unicast + } + + pub fn qtype(&self) -> QueryType { + self.qtype + } + + pub fn qclass(&self) -> QueryClass { + self.qclass + } +} + #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum QueryType { /// a host addresss diff --git a/src/parser/record.rs b/src/core/parser/record.rs similarity index 99% rename from src/parser/record.rs rename to src/core/parser/record.rs index cff5bbf..951f055 100644 --- a/src/parser/record.rs +++ b/src/core/parser/record.rs @@ -1,5 +1,6 @@ use std::{ fmt::Display, + io, net::{Ipv4Addr, Ipv6Addr}, }; @@ -13,7 +14,6 @@ use nom::{ }; use ptr::{Ptr, be_ptr}; use srv::{Srv, be_srv}; -use tokio::io; use txt::Txt; use super::name::{Name, be_name}; diff --git a/src/parser/record/endpoint.rs b/src/core/parser/record/endpoint.rs similarity index 72% rename from src/parser/record/endpoint.rs rename to src/core/parser/record/endpoint.rs index ea682aa..75167b5 100644 --- a/src/parser/record/endpoint.rs +++ b/src/core/parser/record/endpoint.rs @@ -8,7 +8,8 @@ use std::{ use base64::Engine; use bytes::BufMut; -use h3x::dquic::qresolve::SocketEndpointAddr; +use dhttp_identity::certificate::{CertificateChainKey, CertificateChainKind, CertificateSequence}; +use dquic::qbase::net::addr::EndpointAddr as DquicEndpointAddr; use nom::{ IResult, Parser, bytes::streaming::take, @@ -16,13 +17,32 @@ use nom::{ error::{ErrorKind, make_error}, number::streaming::{be_u8, be_u16, be_u32, be_u128}, }; -use rustls::{SignatureScheme, pki_types::SubjectPublicKeyInfoDer, sign::SigningKey}; +use rustls::{SignatureScheme, pki_types::SubjectPublicKeyInfoDer}; +use snafu::{ResultExt, Snafu}; -use crate::parser::{ +use crate::core::parser::{ sigin, varint::{VarInt, WriteVarInt, be_varint}, }; +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum SignEndpointError { + #[snafu(display("failed to determine endpoint signature scheme"))] + SignatureScheme { source: sigin::SignatureSchemeError }, + #[snafu(display("failed to sign endpoint address"))] + Sign { + source: dhttp_identity::identity::SignError, + }, +} + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum EndpointSelectorError { + #[snafu(display("endpoint record sequence does not fit certificate sequence"))] + SequenceTooLarge { sequence: u64 }, +} + /// EndpointAddress record (Type E = 266) /// /// Unified endpoint format that encodes address family, routing, clustering and NAT information @@ -34,7 +54,7 @@ use crate::parser::{ /// +-------+-----------------+--------------------+----------------+----------------------------+ /// | flags | sequence(varint)| addr | load(optional) | signature (optional) | /// +-------+-----------------+--------------------+----------------+----------------------------+ -/// | u8 | QUIC varint | see addr layout | f32 | scheme(u16)+len(varint)+N | +/// | u8 | QUIC varint | see addr layout | f32 | scheme(u16)+len(varint)+N | /// +-------+-----------------+--------------------+----------------+----------------------------+ /// /// addr layout: @@ -204,14 +224,18 @@ impl EndpointAddr { } } - pub fn sign_with( + pub async fn sign_with_authority( &mut self, - key: &(impl SigningKey + ?Sized), - scheme: SignatureScheme, - ) -> Result<(), sigin::SignError> { + authority: &(impl dhttp_identity::identity::LocalAuthority + ?Sized), + ) -> Result<(), SignEndpointError> { self.set_signed(true); let data = self.signed_data(); - let signature = sigin::sign(key, scheme, &data)?; + let scheme = sigin::signature_scheme(authority.public_key()) + .context(sign_endpoint_error::SignatureSchemeSnafu)?; + let signature = authority + .sign(&data) + .await + .context(sign_endpoint_error::SignSnafu)?; self.signature = Some(EndpointSignature { scheme: u16::from(scheme), signature, @@ -355,6 +379,28 @@ impl EndpointAddr { } } + pub fn certificate_chain_key(&self) -> Result { + let kind = if self.is_main() { + CertificateChainKind::Primary + } else { + CertificateChainKind::Secondary + }; + let sequence = self.sequence.map(VarInt::into_inner).unwrap_or(0); + if sequence > u64::from(u32::MAX) { + return endpoint_selector_error::SequenceTooLargeSnafu { sequence }.fail(); + } + let sequence = sequence as u32; + Ok(CertificateChainKey::new( + CertificateSequence::from(sequence), + kind, + )) + } + + pub fn set_certificate_chain_key(&mut self, chain: &CertificateChainKey) { + self.set_main(chain.kind() == CertificateChainKind::Primary); + self.set_sequence(u64::from(chain.sequence().get())); + } + pub fn load(&self) -> Option { self.load } @@ -680,22 +726,22 @@ impl Display for EndpointAddr { } } -impl TryFrom for EndpointAddr { +impl TryFrom for EndpointAddr { type Error = (); - fn try_from(value: SocketEndpointAddr) -> Result { + fn try_from(value: DquicEndpointAddr) -> Result { match value { - SocketEndpointAddr::Direct { + DquicEndpointAddr::Direct { addr: SocketAddr::V4(addr), } => Ok(Self::direct_v4(addr)), - SocketEndpointAddr::Direct { + DquicEndpointAddr::Direct { addr: SocketAddr::V6(addr), } => Ok(Self::direct_v6(addr)), - SocketEndpointAddr::Agent { + DquicEndpointAddr::Agent { agent: SocketAddr::V4(agent), outer: SocketAddr::V4(outer), } => Ok(Self::nat_v4(outer, agent)), - SocketEndpointAddr::Agent { + DquicEndpointAddr::Agent { agent: SocketAddr::V6(agent), outer: SocketAddr::V6(outer), } => Ok(Self::nat_v6(outer, agent)), @@ -704,49 +750,35 @@ impl TryFrom for EndpointAddr { } } -impl TryFrom for SocketEndpointAddr { +impl TryFrom for DquicEndpointAddr { type Error = (); fn try_from(value: EndpointAddr) -> Result { if let Some(agent_addr) = value.agent { match (value.primary, agent_addr) { - (SocketAddr::V4(outer), SocketAddr::V4(agent)) => Ok(SocketEndpointAddr::Agent { - outer: outer.into(), - agent: agent.into(), + (SocketAddr::V4(outer), SocketAddr::V4(agent)) => Ok(DquicEndpointAddr::Agent { + outer: SocketAddr::V4(outer), + agent: SocketAddr::V4(agent), }), - (SocketAddr::V6(outer), SocketAddr::V6(agent)) => Ok(SocketEndpointAddr::Agent { - outer: outer.into(), - agent: agent.into(), + (SocketAddr::V6(outer), SocketAddr::V6(agent)) => Ok(DquicEndpointAddr::Agent { + outer: SocketAddr::V6(outer), + agent: SocketAddr::V6(agent), }), _ => Err(()), } } else { match value.primary { - SocketAddr::V4(addr) => Ok(SocketEndpointAddr::Direct { addr: addr.into() }), - SocketAddr::V6(addr) => Ok(SocketEndpointAddr::Direct { addr: addr.into() }), + SocketAddr::V4(addr) => Ok(DquicEndpointAddr::Direct { + addr: SocketAddr::V4(addr), + }), + SocketAddr::V6(addr) => Ok(DquicEndpointAddr::Direct { + addr: SocketAddr::V6(addr), + }), } } } } -#[cfg(feature = "h3x-resolver")] -pub fn sign_endponit_address( - server_id: u8, - key: Option<(&(impl SigningKey + ?Sized), SignatureScheme)>, - endpoint: SocketEndpointAddr, -) -> Option { - let mut ep: EndpointAddr = endpoint.try_into().ok()?; - - ep.set_main(server_id == 0); - ep.set_sequence(server_id as u64); - - if let Some((key, scheme)) = key { - let _ = ep.sign_with(key, scheme); - } - - Some(ep) -} - #[cfg(test)] mod tests { use std::{ @@ -755,11 +787,78 @@ mod tests { }; use bytes::BytesMut; + use dhttp_identity::certificate::{ + CertificateChainKey, CertificateChainKind, CertificateSequence, + }; + use futures::future::BoxFuture; use ring::signature::KeyPair; - use rustls::sign::Signer; + use rustls::{ + SignatureScheme, + sign::{Signer, SigningKey}, + }; use super::*; + fn chain(sequence: u32, kind: CertificateChainKind) -> CertificateChainKey { + CertificateChainKey::new(CertificateSequence::from(sequence), kind) + } + + fn ed25519_spki(public_key: &[u8]) -> Vec { + let mut spki = Vec::with_capacity(44); + spki.extend_from_slice(&[ + 0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00, + ]); + spki.extend_from_slice(public_key); + spki + } + + #[test] + fn endpoint_selector_normalizes_missing_sequence_to_primary_zero() { + let addr = SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 5353); + let mut endpoint = EndpointAddr::direct_v4(addr); + endpoint.set_main(true); + + let selector = endpoint + .certificate_chain_key() + .expect("missing sequence normalizes to selector"); + + assert_eq!(selector, chain(0, CertificateChainKind::Primary)); + } + + #[test] + fn endpoint_selector_normalizes_missing_sequence_to_secondary_zero() { + let addr = SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 2), 5353); + let endpoint = EndpointAddr::direct_v4(addr); + + let selector = endpoint + .certificate_chain_key() + .expect("missing sequence normalizes to selector"); + + assert_eq!(selector, chain(0, CertificateChainKind::Secondary)); + } + + #[test] + fn endpoint_selector_sets_primary_and_secondary_chains() { + let addr = SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 3), 5353); + let mut endpoint = EndpointAddr::direct_v4(addr); + + endpoint.set_certificate_chain_key(&chain(7, CertificateChainKind::Primary)); + assert!(endpoint.is_main()); + assert!(endpoint.is_clustered()); + assert_eq!( + endpoint.certificate_chain_key().unwrap(), + chain(7, CertificateChainKind::Primary) + ); + + endpoint.set_certificate_chain_key(&chain(0, CertificateChainKind::Secondary)); + assert!(!endpoint.is_main()); + assert!(!endpoint.is_clustered()); + assert_eq!( + endpoint.certificate_chain_key().unwrap(), + chain(0, CertificateChainKind::Secondary) + ); + } + #[test] fn legacy_endpoint_v4_direct_without_meta() { let port = 5353u16; @@ -927,9 +1026,49 @@ mod tests { } #[test] - fn endpoint_signature_roundtrip_and_verify() { + fn signed_endpoint_accepts_scheme_inclusive_signature() { + let addr = SocketAddrV4::new(Ipv4Addr::new(10, 10, 0, 7), 20004); + let scheme = u16::from(SignatureScheme::ED25519); + let signature = vec![0xaa; 64]; + let sig_len = VarInt::try_from(signature.len() as u64).unwrap(); + + let mut buf = BytesMut::new(); + buf.put_u8(EndpointAddr::FLAG_SIGNED); + buf.put_socket_addr_v4(&addr); + buf.put_u16(scheme); + buf.put_varint(sig_len); + buf.extend_from_slice(&signature); + + let (remain, decoded) = be_endpoint_addr(&buf).unwrap(); + + assert!(remain.is_empty()); + assert!(decoded.is_signed()); + assert_eq!(decoded.addr(), SocketAddr::V4(addr)); + assert_eq!(decoded.signature().unwrap().signature, signature); + } + + #[test] + fn signed_endpoint_rejects_signature_without_scheme() { + let addr = SocketAddrV4::new(Ipv4Addr::new(10, 10, 0, 7), 20004); + let signature = vec![0xaa; 64]; + let sig_len = VarInt::try_from(signature.len() as u64).unwrap(); + + let mut buf = BytesMut::new(); + buf.put_u8(EndpointAddr::FLAG_SIGNED); + buf.put_socket_addr_v4(&addr); + buf.put_varint(sig_len); + buf.extend_from_slice(&signature); + + assert!(be_endpoint_addr(&buf).is_err()); + } + + #[test] + fn signed_endpoint_writes_actual_scheme_before_signature_length() { #[derive(Debug)] - struct Ed25519Key(Arc); + struct Ed25519Key { + keypair: Arc, + spki: Vec, + } #[derive(Debug)] struct Ed25519Signer(Arc); @@ -948,7 +1087,7 @@ mod tests { fn choose_scheme(&self, offered: &[SignatureScheme]) -> Option> { offered .contains(&SignatureScheme::ED25519) - .then(|| Box::new(Ed25519Signer(self.0.clone())) as Box) + .then(|| Box::new(Ed25519Signer(self.keypair.clone())) as Box) } fn algorithm(&self) -> rustls::SignatureAlgorithm { @@ -956,22 +1095,115 @@ mod tests { } } + impl dhttp_identity::identity::LocalAuthority for Ed25519Key { + fn name(&self) -> &str { + "authority.example" + } + + fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { + &[] + } + + fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { + SubjectPublicKeyInfoDer::from(self.spki.as_slice()) + } + + fn sign( + &self, + data: &[u8], + ) -> BoxFuture<'_, Result, dhttp_identity::identity::SignError>> { + let result = dhttp_identity::identity::sign_with_key(self, data); + Box::pin(std::future::ready(result)) + } + } + let rng = ring::rand::SystemRandom::new(); let pkcs8 = ring::signature::Ed25519KeyPair::generate_pkcs8(&rng).unwrap(); let keypair = Arc::new(ring::signature::Ed25519KeyPair::from_pkcs8(pkcs8.as_ref()).unwrap()); - let key = Ed25519Key(keypair.clone()); + let spki = ed25519_spki(keypair.public_key().as_ref()); + let key = Ed25519Key { keypair, spki }; - let mut spki = Vec::with_capacity(44); - spki.extend_from_slice(&[ - 0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00, - ]); - spki.extend_from_slice(keypair.public_key().as_ref()); + let mut ep = EndpointAddr::direct_v4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 5353)); + futures::executor::block_on(ep.sign_with_authority(&key)).unwrap(); + + let mut buf = BytesMut::new(); + buf.put_endpoint_addr(&ep); + + let scheme_offset = 1 + 2 + 4; + let encoded_scheme = u16::from_be_bytes([buf[scheme_offset], buf[scheme_offset + 1]]); + assert_eq!(encoded_scheme, u16::from(SignatureScheme::ED25519)); + } + + #[test] + fn endpoint_signature_roundtrip_and_verify() { + #[derive(Debug)] + struct Ed25519Key { + keypair: Arc, + spki: Vec, + } + + #[derive(Debug)] + struct Ed25519Signer(Arc); + + impl Signer for Ed25519Signer { + fn sign(&self, message: &[u8]) -> Result, rustls::Error> { + Ok(self.0.sign(message).as_ref().to_vec()) + } + + fn scheme(&self) -> SignatureScheme { + SignatureScheme::ED25519 + } + } + + impl SigningKey for Ed25519Key { + fn choose_scheme(&self, offered: &[SignatureScheme]) -> Option> { + offered + .contains(&SignatureScheme::ED25519) + .then(|| Box::new(Ed25519Signer(self.keypair.clone())) as Box) + } + + fn algorithm(&self) -> rustls::SignatureAlgorithm { + rustls::SignatureAlgorithm::ED25519 + } + } + + impl dhttp_identity::identity::LocalAuthority for Ed25519Key { + fn name(&self) -> &str { + "authority.example" + } + + fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { + &[] + } + + fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { + SubjectPublicKeyInfoDer::from(self.spki.as_slice()) + } + + fn sign( + &self, + data: &[u8], + ) -> BoxFuture<'_, Result, dhttp_identity::identity::SignError>> { + let result = dhttp_identity::identity::sign_with_key(self, data); + Box::pin(std::future::ready(result)) + } + } + + let rng = ring::rand::SystemRandom::new(); + let pkcs8 = ring::signature::Ed25519KeyPair::generate_pkcs8(&rng).unwrap(); + let keypair = + Arc::new(ring::signature::Ed25519KeyPair::from_pkcs8(pkcs8.as_ref()).unwrap()); + let spki = ed25519_spki(keypair.public_key().as_ref()); + let key = Ed25519Key { + keypair: keypair.clone(), + spki: spki.clone(), + }; let addr = SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 5353); let mut ep = EndpointAddr::direct_v4(addr); ep.set_main(true); - ep.sign_with(&key, SignatureScheme::ED25519).unwrap(); + futures::executor::block_on(ep.sign_with_authority(&key)).unwrap(); let mut buf = BytesMut::new(); buf.put_endpoint_addr(&ep); @@ -996,6 +1228,45 @@ mod tests { ); } + #[test] + fn sign_with_authority_stores_canonical_signature() { + #[derive(Debug)] + struct StaticAuthority { + spki: Vec, + } + + impl dhttp_identity::identity::LocalAuthority for StaticAuthority { + fn name(&self) -> &str { + "authority.example" + } + + fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { + &[] + } + + fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { + SubjectPublicKeyInfoDer::from(self.spki.as_slice()) + } + + fn sign( + &self, + _data: &[u8], + ) -> BoxFuture<'_, Result, dhttp_identity::identity::SignError>> { + Box::pin(std::future::ready(Ok(vec![1, 2, 3]))) + } + } + + let mut ep = EndpointAddr::direct_v4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 5353)); + let authority = StaticAuthority { + spki: ed25519_spki(&[0; 32]), + }; + futures::executor::block_on(ep.sign_with_authority(&authority)).unwrap(); + + let signature = ep.signature().unwrap(); + assert_eq!(signature.scheme, u16::from(SignatureScheme::ED25519)); + assert_eq!(signature.signature, vec![1, 2, 3]); + } + #[test] fn optional_fields_flags_follow_values() { let addr = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 5353); diff --git a/src/parser/record/ptr.rs b/src/core/parser/record/ptr.rs similarity index 86% rename from src/parser/record/ptr.rs rename to src/core/parser/record/ptr.rs index ac30188..7a24bbb 100644 --- a/src/parser/record/ptr.rs +++ b/src/core/parser/record/ptr.rs @@ -1,4 +1,4 @@ -use crate::parser::name::{Name, be_name}; +use crate::core::parser::name::{Name, be_name}; #[derive(Debug, PartialEq, Eq, Hash, Clone)] pub struct Ptr(Name); diff --git a/src/parser/record/srv.rs b/src/core/parser/record/srv.rs similarity index 97% rename from src/parser/record/srv.rs rename to src/core/parser/record/srv.rs index 98e2d78..60d0608 100644 --- a/src/parser/record/srv.rs +++ b/src/core/parser/record/srv.rs @@ -1,6 +1,6 @@ use nom::number::streaming::be_u16; -use crate::parser::name::{Name, be_name}; +use crate::core::parser::name::{Name, be_name}; #[derive(Debug, PartialEq, Eq, Hash, Clone)] pub struct Srv { @@ -56,7 +56,7 @@ pub fn be_srv<'a>(input: &'a [u8], origin: &'a [u8]) -> nom::IResult<&'a [u8], S #[cfg(test)] mod test { use super::*; - use crate::parser::{ + use crate::core::parser::{ packet::be_packet, question::{QueryClass, QueryType}, record::{Class, RData}, diff --git a/src/parser/record/txt.rs b/src/core/parser/record/txt.rs similarity index 100% rename from src/parser/record/txt.rs rename to src/core/parser/record/txt.rs diff --git a/src/core/parser/sigin.rs b/src/core/parser/sigin.rs new file mode 100644 index 0000000..597d90e --- /dev/null +++ b/src/core/parser/sigin.rs @@ -0,0 +1,112 @@ +use rustls::{SignatureScheme, pki_types::SubjectPublicKeyInfoDer, sign::SigningKey}; +use snafu::{ResultExt, Snafu}; +use x509_parser::{ + oid_registry::{ + OID_EC_P256, OID_KEY_TYPE_EC_PUBLIC_KEY, OID_NIST_EC_P384, OID_PKCS1_RSAENCRYPTION, + OID_SIG_ED25519, + }, + prelude::FromDer, + x509::SubjectPublicKeyInfo, +}; + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum SignError { + #[snafu(display("failed to sign DHTTP identity data"))] + Identity { + source: dhttp_identity::identity::SignError, + }, +} + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum VerifyError { + #[snafu(display("failed to verify DHTTP identity signature"))] + Identity { + source: dhttp_identity::identity::VerifyError, + }, + #[snafu(display("unsupported signature scheme {scheme:?}"))] + UnsupportedScheme { scheme: SignatureScheme }, + #[snafu(display("invalid certificate: {details}"))] + InvalidCertificate { details: String }, + #[snafu(display("invalid PEM"))] + InvalidPem { source: std::io::Error }, + #[snafu(display("invalid base64"))] + InvalidBase64 { source: base64::DecodeError }, + #[snafu(display("io error"))] + Io { source: std::io::Error }, +} + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum SignatureSchemeError { + #[snafu(display("unsupported public key type"))] + UnsupportedKey, +} + +pub fn sign_with_key(key: &(impl SigningKey + ?Sized), data: &[u8]) -> Result, SignError> { + dhttp_identity::identity::sign_with_key(key, data).context(sign_error::IdentitySnafu) +} + +pub(crate) fn signature_scheme( + spki: SubjectPublicKeyInfoDer<'_>, +) -> Result { + let Ok((_remain, spki)) = SubjectPublicKeyInfo::from_der(spki.as_ref()) else { + return signature_scheme_error::UnsupportedKeySnafu.fail(); + }; + + if spki.algorithm.algorithm == OID_SIG_ED25519 { + return Ok(SignatureScheme::ED25519); + } + + if spki.algorithm.algorithm == OID_PKCS1_RSAENCRYPTION { + return Ok(SignatureScheme::RSA_PSS_SHA512); + } + + if spki.algorithm.algorithm != OID_KEY_TYPE_EC_PUBLIC_KEY { + return signature_scheme_error::UnsupportedKeySnafu.fail(); + } + + let Some(curve) = spki + .algorithm + .parameters + .as_ref() + .and_then(|parameters| parameters.as_oid().ok()) + else { + return signature_scheme_error::UnsupportedKeySnafu.fail(); + }; + + if curve == OID_EC_P256 { + Ok(SignatureScheme::ECDSA_NISTP256_SHA256) + } else if curve == OID_NIST_EC_P384 { + Ok(SignatureScheme::ECDSA_NISTP384_SHA384) + } else { + signature_scheme_error::UnsupportedKeySnafu.fail() + } +} + +pub(crate) fn verify( + spki: SubjectPublicKeyInfoDer, + scheme: SignatureScheme, + data: &[u8], + signature: &[u8], +) -> Result { + let algorithm: &'static dyn ring::signature::VerificationAlgorithm = match scheme { + SignatureScheme::ECDSA_NISTP384_SHA384 => &ring::signature::ECDSA_P384_SHA384_ASN1, + SignatureScheme::ECDSA_NISTP256_SHA256 => &ring::signature::ECDSA_P256_SHA256_ASN1, + SignatureScheme::ED25519 => &ring::signature::ED25519, + SignatureScheme::RSA_PSS_SHA512 => &ring::signature::RSA_PSS_2048_8192_SHA512, + _ => return verify_error::UnsupportedSchemeSnafu { scheme }.fail(), + }; + + let public_key = match SubjectPublicKeyInfo::from_der(spki.as_ref()) { + Ok((_remain, spki)) => spki.subject_public_key, + Err(_) => return Ok(false), + }; + + Ok( + ring::signature::UnparsedPublicKey::new(algorithm, public_key) + .verify(data, signature) + .is_ok(), + ) +} diff --git a/src/parser/varint.rs b/src/core/parser/varint.rs similarity index 100% rename from src/parser/varint.rs rename to src/core/parser/varint.rs diff --git a/src/core/wire.rs b/src/core/wire.rs new file mode 100644 index 0000000..9d3f539 --- /dev/null +++ b/src/core/wire.rs @@ -0,0 +1,113 @@ +/// HTTP multi-record response wire format shared between server and all clients. +/// +/// Wire layout (big-endian, contiguous): +/// ```text +/// +-----------+ (repeated `count` times) +/// | count | +-----------+------+-----------+------+ +/// | u32 BE | | dns_len | dns | cert_len | cert | +/// +-----------+ | u32 BE | ... | u32 BE | ... | +/// +-----------+------+-----------+------+ +/// ``` +use bytes::BufMut; +use nom::{IResult, bytes::streaming::take, number::streaming::be_u32}; + +/// One DNS + certificate pair inside a [`MultiResponse`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ResponseRecord { + /// Serialised DNS packet bytes. + pub dns: Vec, + /// DER-encoded leaf certificate of the publisher, or empty when unavailable. + pub cert: Vec, +} + +impl ResponseRecord { + /// SHA-256 fingerprint of the publisher certificate as lowercase hex. + /// Returns `None` when the cert field is empty. + pub fn cert_fingerprint_hex(&self) -> Option { + if self.cert.is_empty() { + return None; + } + use ring::digest::{SHA256, digest}; + let digest = digest(&SHA256, &self.cert); + Some(digest.as_ref().iter().map(|b| format!("{b:02x}")).collect()) + } +} + +/// HTTP response body carrying zero or more DNS records. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MultiResponse { + pub records: Vec, +} + +impl MultiResponse { + pub fn new(iter: impl IntoIterator, Vec)>) -> Self { + Self { + records: iter + .into_iter() + .map(|(dns, cert)| ResponseRecord { dns, cert }) + .collect(), + } + } + + pub fn encoding_size(&self) -> usize { + 4 + self + .records + .iter() + .map(|record| 4 + record.dns.len() + 4 + record.cert.len()) + .sum::() + } + + pub fn encode(&self) -> Vec { + let mut buf = Vec::with_capacity(self.encoding_size()); + buf.put_multi_response(self); + buf + } +} + +pub trait WriteMultiResponse { + fn put_multi_response(&mut self, response: &MultiResponse); +} + +impl WriteMultiResponse for B { + fn put_multi_response(&mut self, response: &MultiResponse) { + self.put_u32(response.records.len() as u32); + for record in &response.records { + self.put_u32(record.dns.len() as u32); + self.put_slice(&record.dns); + self.put_u32(record.cert.len() as u32); + self.put_slice(&record.cert); + } + } +} + +pub fn be_multi_response(input: &[u8]) -> IResult<&[u8], MultiResponse> { + let (mut input, count) = be_u32(input)?; + let mut records = Vec::with_capacity(count as usize); + for _ in 0..count { + let (rest, dns_len) = be_u32(input)?; + let (rest, dns) = take(dns_len as usize)(rest)?; + let (rest, cert_len) = be_u32(rest)?; + let (rest, cert) = take(cert_len as usize)(rest)?; + records.push(ResponseRecord { + dns: dns.to_vec(), + cert: cert.to_vec(), + }); + input = rest; + } + Ok((input, MultiResponse { records })) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn multi_response_roundtrips() { + let response = + MultiResponse::new([(vec![1, 2, 3], vec![4, 5]), (vec![6, 7, 8, 9], Vec::new())]); + let encoded = response.encode(); + let (remain, decoded) = be_multi_response(&encoded).unwrap(); + assert!(remain.is_empty()); + assert_eq!(decoded, response); + } +} diff --git a/src/lib.rs b/src/lib.rs index 1f0af97..193112c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,7 @@ -mod if_nametoindex; +mod bootstrap; + +pub mod core; pub mod mdns; -pub mod parser; -mod protocol; +#[cfg(any(feature = "h3x-resolver", feature = "mdns-resolver"))] +pub mod publisher; pub mod resolvers; -pub mod wire; - -pub type MdnsEndpoint = crate::parser::record::endpoint::EndpointAddr; -pub type MdnsPacket = crate::parser::packet::Packet; - -#[cfg(feature = "h3x-resolver")] -pub use parser::record::endpoint::sign_endponit_address; diff --git a/src/mdns.rs b/src/mdns.rs index b3fa2e3..dd51460 100644 --- a/src/mdns.rs +++ b/src/mdns.rs @@ -1,285 +1,4 @@ -use std::{ - collections::{HashMap, HashSet}, - fmt, io, - net::{IpAddr, SocketAddr}, - sync::{Arc, Mutex}, - task::{Context, Poll, ready}, - time::Duration, -}; - -use futures::{Stream, stream}; -#[cfg(feature = "h3x-resolver")] -use h3x::dquic::qbase::net::addr::BoundAddr; -use h3x::dquic::qinterface::{Interface, component::Component, io::IO}; -use tokio::{task::JoinSet, time}; -use tracing::Instrument; - -use crate::{ - parser::{packet::Packet, record::endpoint::EndpointAddr}, - protocol::MdnsProtocol, -}; - -#[derive(Clone)] -pub struct Mdns { - service_name: String, - hosts: Arc>>>, - inner: Arc>, -} - -struct MdnsInner { - proto: Arc, - tasks: JoinSet<()>, -} - -impl fmt::Debug for Mdns { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let (local_device, ip) = { - let guard = self.inner.lock().expect("Mdns inner lock poisoned"); - (guard.proto.bound_nic().to_string(), guard.proto.bound_ip()) - }; - f.debug_struct("Mdns") - .field("service_name", &self.service_name) - .field("local_device", &local_device) - .field("ip", &ip) - .finish() - } -} - -impl Mdns { - pub fn new(service_name: &str, ip: IpAddr, device: &str) -> io::Result { - let service_name = service_name.to_string(); - let hosts = Arc::new(Mutex::new(HashMap::>::new())); - let (proto, route) = MdnsProtocol::new(device, ip)?; - let proto = Arc::new(proto); - let mut tasks = JoinSet::new(); - tasks.spawn(route); - Self::spawn_tasks( - &mut tasks, - proto.clone(), - hosts.clone(), - service_name.clone(), - ); - - Ok(Self { - service_name, - hosts, - inner: Arc::new(Mutex::new(MdnsInner { proto, tasks })), - }) - } - - pub fn from_iface(service_name: &str, iface: &(impl IO + ?Sized)) -> io::Result { - let binding = iface.bind_uri(); - let Some((_family, device, _port)) = binding.as_iface_bind_uri() else { - return Err(io::Error::new( - io::ErrorKind::Unsupported, - "interface is not bound to internet address", - )); - }; - let BoundAddr::Internet(bound_addr) = iface.bound_addr()? else { - return Err(io::Error::new( - io::ErrorKind::Unsupported, - "interface is not bound to internet address", - )); - }; - - Self::new(service_name, bound_addr.ip(), device) - } - - pub fn reinit(&self, iface: &(impl IO + ?Sized)) { - // Extract interface info - - let binding = iface.bind_uri(); - let Some((_family, device, _port)) = binding.as_iface_bind_uri() else { - return; - }; - let Ok(BoundAddr::Internet(bound_addr)) = iface.bound_addr() else { - return; - }; - let ip = bound_addr.ip(); - - let mut inner = self.inner.lock().expect("Mdns inner lock poisoned"); - - // Skip if already using same device/IP with active protocol - - if inner.proto.bound_nic() == device && inner.proto.bound_ip() == ip { - return; - } - - let Ok((proto, route)) = MdnsProtocol::new(device, ip) else { - tracing::debug!(target: "mdns", device, %ip, "failed to reinit mdns protocol"); - return; - }; - inner.proto = Arc::new(proto); - - inner.tasks.abort_all(); - while inner.tasks.try_join_next().is_some() {} - - inner.tasks.spawn(route); - let proto = inner.proto.clone(); - Self::spawn_tasks( - &mut inner.tasks, - proto, - self.hosts.clone(), - self.service_name.clone(), - ); - // Update state with new protocol and tasks - } - - fn spawn_tasks( - tasks: &mut JoinSet<()>, - proto: Arc, - hosts: Arc>>>, - service_name: String, - ) { - let span = tracing::info_span!(target: "mdns", "mdns_tasks", service_name, nic = proto.bound_nic(), ip = %proto.bound_ip()); - - // (1) periodic broadcaster - tasks.spawn( - { - let proto = proto.clone(); - let service_name = service_name.clone(); - async move { - let mut interval = time::interval(Duration::from_secs(10)); - interval.set_missed_tick_behavior(time::MissedTickBehavior::Delay); - loop { - interval.tick().await; - let packet = Packet::query(service_name.clone()); - if let Err(e) = proto.broadcast_packet(packet).await { - tracing::debug!(target: "mdns", error = %snafu::Report::from_error(&e), "broadcast packet error"); - } - } - } - } - .instrument(span.clone()), - ); - - // (2) responder - tasks.spawn( - { - let proto = proto.clone(); - let hosts = hosts.clone(); - let service_name = service_name.clone(); - async move { - loop { - let res = proto.receive_query().await; - let Ok((_src, query)) = res else { - break; - }; - - let packet = { - let guard = hosts.lock().unwrap(); - let host_name = guard - .keys() - .cloned() - .map(|h| Self::local_name(service_name.clone(), h)) - .collect::>(); - - query - .questions - .iter() - .any(|q| host_name.iter().any(|h| h.contains(q.name.as_str()))) - .then(|| Packet::answer(query.header.id, &guard)) - }; - - if let Some(packet) = packet - && let Err(e) = proto.broadcast_packet(packet).await - { - tracing::debug!(target: "mdns", error = %snafu::Report::from_error(&e), "send response error"); - } - } - } - } - .instrument(span.clone()), - ); - } - - fn poll_close(&self, cx: &mut Context<'_>) -> Poll<()> { - let mut inner = self.inner.lock().expect("Mdns inner lock poisoned"); - - inner.tasks.abort_all(); - while ready!(inner.tasks.poll_join_next(cx)).is_some() {} - - Poll::Ready(()) - } - - #[inline] - pub fn service_name(&self) -> &str { - &self.service_name - } - - pub fn bound_nic(&self) -> String { - let inner = self.inner.lock().expect("Mdns inner lock poisoned"); - inner.proto.bound_nic().to_string() - } - - pub fn bound_ip(&self) -> IpAddr { - let inner = self.inner.lock().expect("Mdns inner lock poisoned"); - inner.proto.bound_ip() - } - - #[inline] - pub fn insert_host(&self, host_name: String, eps: Vec) { - let local_name = Self::local_name(self.service_name.clone(), host_name.clone()); - let mut guard = self.hosts.lock().unwrap(); - tracing::trace!( - target: "mdns", - %local_name, ?eps, - "adding host with addresses", - ); - guard.insert(local_name, eps); - } - - #[inline] - pub(crate) fn protocol(&self) -> Arc { - self.inner - .lock() - .expect("Mdns inner lock poisoned") - .proto - .clone() - } - - #[inline] - pub fn query( - &self, - domain: String, - ) -> impl Future>> + use<> { - let proto = self.protocol(); - let local_name = Self::local_name(self.service_name.clone(), domain); - async move { - let (src, mut endpoints) = proto.query(local_name).await?; - if let Some(pos) = endpoints.iter().position(|ep| ep.addr().ip() == src.ip()) { - endpoints.swap(0, pos); - } - if endpoints.is_empty() { - return Err(io::Error::other("empty dns result")); - } - Ok(endpoints) - } - } - - #[inline] - pub fn discover(&self) -> impl Stream + use<> { - let proto = self.protocol(); - - Box::pin(stream::unfold(proto, async move |proto| { - Some((proto.receive_boardcast().await.ok()?, proto)) - })) - } - - #[inline] - fn local_name(service_name: String, name: String) -> String { - name.split_once("genmeta.net") - .map(|(prefix, _)| format!("{prefix}{service_name}")) - .unwrap_or_else(|| name) - } -} - -impl Component for Mdns { - fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> { - self.poll_close(cx) - } - - fn reinit(&self, iface: &Interface) { - self.reinit(iface); - } -} +mod if_nametoindex; +mod protocol; +pub mod resolvers; +pub mod service; diff --git a/src/if_nametoindex.rs b/src/mdns/if_nametoindex.rs similarity index 100% rename from src/if_nametoindex.rs rename to src/mdns/if_nametoindex.rs diff --git a/src/protocol.rs b/src/mdns/protocol.rs similarity index 85% rename from src/protocol.rs rename to src/mdns/protocol.rs index 31580ab..3b15fc2 100644 --- a/src/protocol.rs +++ b/src/mdns/protocol.rs @@ -1,6 +1,6 @@ use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, - num::NonZero, + num::{NonZero, NonZeroU32}, pin::Pin, sync::{Arc, Weak}, task::{Context, Poll}, @@ -13,12 +13,10 @@ use snafu::Snafu; use socket2::{Domain, Socket, Type}; use tokio::{io, net::UdpSocket, task::JoinSet, time}; -use crate::{ - if_nametoindex::if_nametoindex, - parser::{ - packet::{Packet, be_packet}, - record::endpoint::EndpointAddr, - }, +use super::if_nametoindex::if_nametoindex; +use crate::core::parser::{ + packet::{Packet, be_packet}, + record::endpoint::EndpointAddr, }; #[derive(Debug)] @@ -52,6 +50,22 @@ impl MdnsSocket { socket.bind(&bind.into())?; #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] socket.bind_device(Some(device.as_bytes()))?; + #[cfg(any( + target_os = "ios", + target_os = "visionos", + target_os = "macos", + target_os = "tvos", + target_os = "watchos", + ))] + { + let ifindex = NonZeroU32::new(if_nametoindex(device)?).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "interface index must be non-zero", + ) + })?; + socket.bind_device_by_index_v4(Some(ifindex))?; + } // Always enable multicast loopback so that mDNS services on the // same host (but in different processes) can communicate. socket.set_multicast_loop_v4(true)?; @@ -77,9 +91,22 @@ impl MdnsSocket { // same host (but in different processes) can communicate. socket.set_multicast_loop_v6(true)?; // TODO: 外面传进来 - let ifindex = if_nametoindex(device)?; - socket.join_multicast_v6(&MULTICAST_ADDR_V6, ifindex)?; - socket.set_multicast_if_v6(ifindex)?; + let ifindex = NonZeroU32::new(if_nametoindex(device)?).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "interface index must be non-zero", + ) + })?; + #[cfg(any( + target_os = "ios", + target_os = "visionos", + target_os = "macos", + target_os = "tvos", + target_os = "watchos", + ))] + socket.bind_device_by_index_v6(Some(ifindex))?; + socket.join_multicast_v6(&MULTICAST_ADDR_V6, ifindex.get())?; + socket.set_multicast_if_v6(ifindex.get())?; socket } @@ -192,7 +219,7 @@ impl PacketRouter { } pub fn deliver(&self, source: SocketAddr, packet: Packet) { - match (packet.header.flags.query(), packet.header.id) { + match (packet.is_query(), packet.id()) { (true, 0) => { if self.responses.0.try_send((source, packet.clone())).is_err() { // Queue is full, remove oldest message (FIFO) @@ -288,7 +315,7 @@ impl MdnsProtocol { let router = self.router.upgrade().ok_or(Disconnected)?; let packet = Packet::query_with_id(local_name.clone()); - let query_id = NonZero::new(packet.header.id).ok_or_else(|| { + let query_id = NonZero::new(packet.id()).ok_or_else(|| { io::Error::new(io::ErrorKind::InvalidInput, "Query id should not be 0") })?; @@ -305,7 +332,7 @@ impl MdnsProtocol { if let Ok(Some((source, packet))) = time::timeout(Duration::from_millis(300), packets.next()).await { - use crate::parser::record::RData::*; + use crate::core::parser::record::RData::*; let endpoints = packet .answers .iter() @@ -313,17 +340,17 @@ impl MdnsProtocol { tracing::debug!(target: "mdns", ?answer, "recv response"); }) .filter(|answer| { - if answer.name != local_name { + if answer.name() != local_name { tracing::debug!( target: "mdns", - answer_name = answer.name, + answer_name = answer.name(), local_name, "ignored answer for different service name", ); } - answer.name == local_name + answer.name() == local_name }) - .filter_map(|answer| match &answer.data { + .filter_map(|answer| match answer.data() { E(e) => Some(e.clone()), _ => { tracing::debug!(target: "mdns", ?answer, "ignored record"); diff --git a/src/mdns/resolvers.rs b/src/mdns/resolvers.rs new file mode 100644 index 0000000..1bd416e --- /dev/null +++ b/src/mdns/resolvers.rs @@ -0,0 +1 @@ +pub mod mdns; diff --git a/src/mdns/resolvers/mdns.rs b/src/mdns/resolvers/mdns.rs new file mode 100644 index 0000000..16912ab --- /dev/null +++ b/src/mdns/resolvers/mdns.rs @@ -0,0 +1,354 @@ +use std::{fmt, io, net::IpAddr}; +#[cfg(feature = "mdns-resolver")] +use std::{net::SocketAddr, sync::Arc}; + +#[cfg(feature = "mdns-resolver")] +use dquic::qresolve::RecordStream; +use dquic::{ + qbase::net::Family, + qresolve::{Publish, PublishFuture, Resolve, ResolveFuture, Source}, +}; +use futures::{FutureExt, StreamExt, TryFutureExt, future, stream}; +#[cfg(feature = "mdns-resolver")] +use futures::{Stream, stream::FuturesUnordered}; + +#[cfg(feature = "mdns-resolver")] +use super::super::protocol::MdnsProtocol; +#[cfg(feature = "mdns-resolver")] +use crate::core::parser::packet::Packet; +use crate::core::parser::record::RData; +pub type MdnsResolver = crate::mdns::service::Mdns; + +impl MdnsResolver { + pub fn source(&self) -> Source { + Source::Mdns { + nic: self.bound_nic().into(), + family: match self.bound_ip() { + IpAddr::V4(..) => Family::V4, + IpAddr::V6(..) => Family::V6, + }, + } + } +} + +impl fmt::Display for MdnsResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.source(), f) + } +} + +impl Publish for MdnsResolver { + fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { + let endpoints = match endpoints_from_packet(packet) { + Ok(endpoints) => endpoints, + Err(error) => return future::ready(Err(error)).boxed(), + }; + self.insert_host(name.to_string(), endpoints); + future::ready(Ok(())).boxed() + } +} + +impl Resolve for MdnsResolver { + fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { + let source = self.source(); + self.query(name.to_owned()) + .map_ok(move |list| { + let endpoints = crate::resolvers::selector::selected_endpoint_addrs(list); + stream::iter(endpoints.into_iter().map(move |ep| (source.clone(), ep))).boxed() + }) + .boxed() + } +} + +fn endpoints_from_packet(packet: &[u8]) -> io::Result> { + use crate::core::parser::packet::be_packet; + + be_packet(packet) + .map(|(_, pkt)| { + pkt.answers + .iter() + .filter_map(|rr| match rr.data() { + RData::E(ep) => Some(ep.clone()), + _ => None, + }) + .collect::>() + }) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string())) +} + +#[cfg(feature = "mdns-resolver")] +pub struct MdnsBindDriver { + iface_manager: Arc, + null_io_factory: Arc, + service_name: Arc, +} + +#[cfg(feature = "mdns-resolver")] +impl MdnsBindDriver { + pub fn new(service_name: impl Into>) -> Self { + Self { + iface_manager: Arc::new(h3x::dquic::net::InterfaceManager::new()), + null_io_factory: Arc::new(h3x::dquic::NullIoFactory), + service_name: service_name.into(), + } + } + + fn install_or_rebind_mdns( + &self, + network: &h3x::dquic::Network, + bind_iface: &h3x::dquic::net::BindInterface, + ) { + let bind_uri = bind_iface.bind_uri(); + let Some((family, device, _port)) = bind_uri.as_iface_bind_uri() else { + tracing::debug!(%bind_uri, "skipping mdns binding for non-interface bind uri"); + return; + }; + let Some(ip) = network.resolve_device_addr(device, family) else { + tracing::debug!(%bind_uri, "skipping mdns binding without local interface address"); + return; + }; + + bind_iface.with_components_mut(|components, _iface| { + match components.try_init_with(|| crate::mdns::service::Mdns::new(&self.service_name, ip, device)) { + Ok(mdns) => mdns.reinit_on(device, ip), + Err(error) => { + let report = snafu::Report::from_error(&error); + tracing::debug!(error = %report, %bind_uri, "failed to initialize mdns binding"); + } + } + }); + } +} + +#[cfg(feature = "mdns-resolver")] +impl h3x::dquic::BindDriver for MdnsBindDriver { + fn bind<'a>( + &'a self, + network: &'a h3x::dquic::Network, + uri: h3x::dquic::net::BindUri, + ) -> futures::future::BoxFuture<'a, h3x::dquic::net::BindInterface> { + async move { + let iface = self + .iface_manager + .bind(uri, self.null_io_factory.clone()) + .await; + self.install_or_rebind_mdns(network, &iface); + iface + } + .boxed() + } + + fn rebind<'a>( + &'a self, + network: &'a h3x::dquic::Network, + iface: &'a h3x::dquic::net::BindInterface, + ) -> futures::future::BoxFuture<'a, ()> { + async move { + self.install_or_rebind_mdns(network, iface); + } + .boxed() + } +} + +#[cfg(feature = "mdns-resolver")] +pub struct MdnsResolvers { + network: Arc, + driver: Arc, + patterns: Arc>, + _handles: Vec, +} + +#[cfg(feature = "mdns-resolver")] +#[derive(Debug, Clone)] +pub struct BoundMdnsResolver { + pub device: String, + pub family: Family, + pub resolver: MdnsResolver, +} + +#[cfg(feature = "mdns-resolver")] +impl fmt::Debug for MdnsResolvers { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MdnsResolvers") + .field("patterns", &self.patterns) + .finish_non_exhaustive() + } +} + +#[cfg(feature = "mdns-resolver")] +impl fmt::Display for MdnsResolvers { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("mDNS resolvers") + } +} + +#[cfg(feature = "mdns-resolver")] +impl MdnsResolvers { + pub async fn bind( + network: Arc, + patterns: Arc>, + service_name: impl Into>, + ) -> Self { + let driver = Arc::new(MdnsBindDriver::new(service_name)); + let mut handles = Vec::with_capacity(patterns.len()); + for pattern in patterns.iter() { + handles.push(network.bind_with(driver.clone(), pattern.clone()).await); + } + + Self { + network, + driver, + patterns, + _handles: handles, + } + } + + pub fn bound_interfaces( + &self, + pattern: &h3x::dquic::binds::BindPattern, + ) -> Option> { + self.network.get_interfaces_with(&self.driver, pattern) + } + + fn for_each_resolver(&self, mut f: impl FnMut(&MdnsResolver)) { + for pattern in self.patterns.iter() { + let Some(ifaces) = self.bound_interfaces(pattern) else { + continue; + }; + for iface in ifaces { + iface.with_components(|components, _| { + if let Some(mdns) = components.get::() { + f(mdns); + } + }); + } + } + } + + pub fn bound_resolvers(&self) -> Vec { + let mut resolvers = Vec::new(); + for pattern in self.patterns.iter() { + let Some(ifaces) = self.bound_interfaces(pattern) else { + continue; + }; + for iface in ifaces { + let bind_uri = iface.bind_uri(); + let Some((family, device, _port)) = bind_uri.as_iface_bind_uri() else { + continue; + }; + iface.with_components(|components, _| { + if let Some(resolver) = components.get::() { + resolvers.push(BoundMdnsResolver { + device: device.to_owned(), + family, + resolver: resolver.clone(), + }); + } + }); + } + } + resolvers + } + + pub async fn query(&self, name: &str) -> io::Result { + let mut lookup_futures = FuturesUnordered::new(); + let mut has_resolver = false; + self.for_each_resolver(|resolver| { + has_resolver = true; + let source = resolver.source(); + lookup_futures.push( + resolver + .query(name.to_owned()) + .map_ok(move |eps| (source, eps)), + ); + }); + if !has_resolver { + return Err(io::Error::other("no mdns resolvers available")); + } + + let mut last_error = None; + let mut has_success = false; + let mut records = Vec::new(); + while let Some(result) = lookup_futures.next().await { + match result { + Ok((source, endpoints)) => { + has_success = true; + records.extend( + endpoints + .into_iter() + .map(|endpoint| (source.clone(), endpoint)), + ); + } + Err(error) => last_error = Some(error), + } + } + + if !has_success { + return Err( + last_error.unwrap_or_else(|| io::Error::other("no mdns resolvers available")) + ); + } + + let records = crate::resolvers::selector::selected_endpoint_records(records); + + Ok(stream::iter(records).boxed()) + } + + /// Discover mDNS broadcasts from all active resolvers. + pub fn discover(&self) -> impl Stream + use<> { + let mut protos = Vec::new(); + self.for_each_resolver(|resolver| { + protos.push(resolver.protocol()); + }); + + async fn receive_one( + proto: Arc, + ) -> Option<((SocketAddr, Packet), Arc)> { + let result = proto.receive_boardcast().await.ok()?; + Some((result, proto)) + } + + let mut pending = protos + .into_iter() + .map(receive_one) + .collect::>(); + + Box::pin(stream::poll_fn(move |cx| { + use std::task::Poll; + loop { + match pending.poll_next_unpin(cx) { + Poll::Ready(Some(Some((item, proto)))) => { + pending.push(receive_one(proto)); + return Poll::Ready(Some(item)); + } + Poll::Ready(Some(None)) => continue, + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + } + } + })) + } +} + +#[cfg(feature = "mdns-resolver")] +impl Publish for MdnsResolvers { + fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { + let endpoints = match endpoints_from_packet(packet) { + Ok(endpoints) => endpoints, + Err(error) => return future::ready(Err(error)).boxed(), + }; + + self.for_each_resolver(|resolver| { + resolver.insert_host(name.to_string(), endpoints.clone()); + }); + + future::ready(Ok(())).boxed() + } +} + +#[cfg(feature = "mdns-resolver")] +impl Resolve for MdnsResolvers { + fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { + self.query(name).boxed() + } +} diff --git a/src/mdns/service.rs b/src/mdns/service.rs new file mode 100644 index 0000000..d6a5d69 --- /dev/null +++ b/src/mdns/service.rs @@ -0,0 +1,291 @@ +use std::{ + collections::{HashMap, HashSet}, + fmt, io, + net::{IpAddr, SocketAddr}, + sync::{Arc, Mutex}, + task::{Context, Poll, ready}, + time::Duration, +}; + +use dhttp_identity::name::DhttpName; +use dquic::qinterface::{Interface, component::Component, io::IO}; +use futures::{Stream, stream}; +use tokio::{task::JoinSet, time}; +use tracing::Instrument; + +use super::protocol::MdnsProtocol; +use crate::core::parser::{packet::Packet, record::endpoint::EndpointAddr}; + +#[derive(Clone)] +pub struct Mdns { + service_name: String, + hosts: Arc>>>, + inner: Arc>, +} + +struct MdnsInner { + proto: Arc, + tasks: JoinSet<()>, +} + +impl fmt::Debug for Mdns { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let (local_device, ip) = { + let guard = self.inner.lock().expect("Mdns inner lock poisoned"); + (guard.proto.bound_nic().to_string(), guard.proto.bound_ip()) + }; + f.debug_struct("Mdns") + .field("service_name", &self.service_name) + .field("local_device", &local_device) + .field("ip", &ip) + .finish() + } +} + +impl Mdns { + pub fn new(service_name: &str, ip: IpAddr, device: &str) -> io::Result { + let service_name = service_name.to_string(); + let hosts = Arc::new(Mutex::new(HashMap::>::new())); + let (proto, route) = MdnsProtocol::new(device, ip)?; + let proto = Arc::new(proto); + let mut tasks = JoinSet::new(); + tasks.spawn(route); + Self::spawn_tasks( + &mut tasks, + proto.clone(), + hosts.clone(), + service_name.clone(), + ); + + Ok(Self { + service_name, + hosts, + inner: Arc::new(Mutex::new(MdnsInner { proto, tasks })), + }) + } + + pub fn from_iface(service_name: &str, iface: &(impl IO + ?Sized)) -> io::Result { + let binding = iface.bind_uri(); + let Some((_family, device, _port)) = binding.as_iface_bind_uri() else { + return Err(io::Error::new( + io::ErrorKind::Unsupported, + "interface is not bound to internet address", + )); + }; + let bound_addr = iface.bound_addr()?; + + Self::new(service_name, bound_addr.ip(), device) + } + + pub fn reinit(&self, iface: &(impl IO + ?Sized)) { + let binding = iface.bind_uri(); + let Some((_family, device, _port)) = binding.as_iface_bind_uri() else { + return; + }; + let Ok(bound_addr) = iface.bound_addr() else { + return; + }; + + self.reinit_on(device, bound_addr.ip()); + } + + pub fn reinit_on(&self, device: &str, ip: IpAddr) { + let mut inner = self.inner.lock().expect("Mdns inner lock poisoned"); + + if inner.proto.bound_nic() == device && inner.proto.bound_ip() == ip { + return; + } + + let Ok((proto, route)) = MdnsProtocol::new(device, ip) else { + tracing::debug!(target: "mdns", device, %ip, "failed to reinit mdns protocol"); + return; + }; + inner.proto = Arc::new(proto); + + inner.tasks.abort_all(); + while inner.tasks.try_join_next().is_some() {} + + inner.tasks.spawn(route); + let proto = inner.proto.clone(); + Self::spawn_tasks( + &mut inner.tasks, + proto, + self.hosts.clone(), + self.service_name.clone(), + ); + } + + fn spawn_tasks( + tasks: &mut JoinSet<()>, + proto: Arc, + hosts: Arc>>>, + service_name: String, + ) { + let span = tracing::info_span!(target: "mdns", "mdns_tasks", service_name, nic = proto.bound_nic(), ip = %proto.bound_ip()); + + // (1) periodic broadcaster + tasks.spawn( + { + let proto = proto.clone(); + let service_name = service_name.clone(); + async move { + let mut interval = time::interval(Duration::from_secs(10)); + interval.set_missed_tick_behavior(time::MissedTickBehavior::Delay); + loop { + interval.tick().await; + let packet = Packet::query(service_name.clone()); + if let Err(e) = proto.broadcast_packet(packet).await { + tracing::debug!(target: "mdns", error = %snafu::Report::from_error(&e), "broadcast packet error"); + } + } + } + } + .instrument(span.clone()), + ); + + // (2) responder + tasks.spawn( + { + let proto = proto.clone(); + let hosts = hosts.clone(); + let service_name = service_name.clone(); + async move { + loop { + let res = proto.receive_query().await; + let Ok((_src, query)) = res else { + break; + }; + + let packet = { + let guard = hosts.lock().unwrap(); + let host_name = guard + .keys() + .cloned() + .map(|h| Self::local_name(service_name.clone(), h)) + .collect::>(); + + query + .questions + .iter() + .any(|q| host_name.iter().any(|h| h.contains(q.name().as_str()))) + .then(|| Packet::answer(query.id(), &guard)) + }; + + if let Some(packet) = packet + && let Err(e) = proto.broadcast_packet(packet).await + { + tracing::debug!(target: "mdns", error = %snafu::Report::from_error(&e), "send response error"); + } + } + } + } + .instrument(span.clone()), + ); + } + + fn poll_close(&self, cx: &mut Context<'_>) -> Poll<()> { + let mut inner = self.inner.lock().expect("Mdns inner lock poisoned"); + + inner.tasks.abort_all(); + while ready!(inner.tasks.poll_join_next(cx)).is_some() {} + + Poll::Ready(()) + } + + #[inline] + pub fn service_name(&self) -> &str { + &self.service_name + } + + pub fn bound_nic(&self) -> String { + let inner = self.inner.lock().expect("Mdns inner lock poisoned"); + inner.proto.bound_nic().to_string() + } + + pub fn bound_ip(&self) -> IpAddr { + let inner = self.inner.lock().expect("Mdns inner lock poisoned"); + inner.proto.bound_ip() + } + + #[inline] + pub fn insert_host(&self, host_name: String, eps: Vec) { + let local_name = Self::local_name(self.service_name.clone(), host_name.clone()); + let mut guard = self.hosts.lock().unwrap(); + tracing::trace!( + target: "mdns", + %local_name, ?eps, + "adding host with addresses", + ); + guard.insert(local_name, eps); + } + + #[inline] + pub(crate) fn protocol(&self) -> Arc { + self.inner + .lock() + .expect("Mdns inner lock poisoned") + .proto + .clone() + } + + #[inline] + pub fn query( + &self, + domain: String, + ) -> impl Future>> + use<> { + let proto = self.protocol(); + let local_name = Self::local_name(self.service_name.clone(), domain); + async move { + let (src, mut endpoints) = proto.query(local_name).await?; + if let Some(pos) = endpoints.iter().position(|ep| ep.addr().ip() == src.ip()) { + endpoints.swap(0, pos); + } + if endpoints.is_empty() { + return Err(io::Error::other("empty dns result")); + } + Ok(endpoints) + } + } + + #[inline] + pub fn discover(&self) -> impl Stream + use<> { + let proto = self.protocol(); + + Box::pin(stream::unfold(proto, async move |proto| { + Some((proto.receive_boardcast().await.ok()?, proto)) + })) + } + + #[inline] + fn local_name(service_name: String, name: String) -> String { + name.strip_suffix(DhttpName::SUFFIX) + .map(|prefix| format!("{prefix}.{service_name}")) + .unwrap_or_else(|| name) + } +} + +impl Component for Mdns { + fn poll_shutdown(&self, cx: &mut Context<'_>) -> Poll<()> { + self.poll_close(cx) + } + + fn reinit(&self, iface: &Interface) { + self.reinit(iface); + } +} + +#[cfg(test)] +mod tests { + use super::Mdns; + + #[test] + fn local_name_uses_dhttp_identity_suffix() { + assert_eq!( + Mdns::local_name( + "_gensokyo.local".to_string(), + "reimu.pilot.dhttp.net".to_string() + ), + "reimu.pilot._gensokyo.local" + ); + } +} diff --git a/src/parser.rs b/src/parser.rs deleted file mode 100644 index 69b002e..0000000 --- a/src/parser.rs +++ /dev/null @@ -1,9 +0,0 @@ -pub(crate) mod header; -pub(crate) mod name; -pub mod packet; -pub(crate) mod question; -pub mod record; -pub mod sigin; -pub mod varint; - -pub use name::{NameCompression, put_name}; diff --git a/src/parser/sigin.rs b/src/parser/sigin.rs deleted file mode 100644 index fa366c1..0000000 --- a/src/parser/sigin.rs +++ /dev/null @@ -1,79 +0,0 @@ -use rustls::{SignatureScheme, pki_types::SubjectPublicKeyInfoDer, sign::SigningKey}; -use snafu::Snafu; -use x509_parser::prelude::FromDer; - -#[derive(Debug, Snafu)] -#[snafu(module)] -pub enum SignError { - #[snafu(display("unsupported signature scheme {scheme:?}"))] - UnsupportedScheme { scheme: SignatureScheme }, - #[snafu(display("crypto error"))] - Crypto { - #[snafu(source(false))] - source: rustls::Error, - }, -} - -impl From for SignError { - fn from(source: rustls::Error) -> Self { - Self::Crypto { source } - } -} - -#[derive(Debug, Snafu)] -#[snafu(module)] -pub enum VerifyError { - #[snafu(display("unsupported signature scheme {scheme:?}"))] - UnsupportedScheme { scheme: SignatureScheme }, - #[snafu(display("invalid certificate: {details}"))] - InvalidCertificate { details: String }, - #[snafu(display("invalid PEM"))] - InvalidPem { source: std::io::Error }, - #[snafu(display("invalid base64"))] - InvalidBase64 { source: base64::DecodeError }, - #[snafu(display("IO error"))] - Io { source: std::io::Error }, -} - -pub(crate) fn sign( - key: &(impl SigningKey + ?Sized), - scheme: SignatureScheme, - data: &[u8], -) -> Result, SignError> { - // FIXME: same as load spki then sign with ring? - let signer = key - .choose_scheme(&[scheme]) - .ok_or(SignError::UnsupportedScheme { scheme })?; - Ok(signer.sign(data)?) -} - -pub(crate) fn verify( - spki: SubjectPublicKeyInfoDer, - scheme: SignatureScheme, - data: &[u8], - signature: &[u8], -) -> Result { - let algorithm: &'static dyn ring::signature::VerificationAlgorithm = match scheme { - SignatureScheme::ECDSA_NISTP384_SHA384 => &ring::signature::ECDSA_P384_SHA384_ASN1, - SignatureScheme::ECDSA_NISTP256_SHA256 => &ring::signature::ECDSA_P256_SHA256_ASN1, - SignatureScheme::ED25519 => &ring::signature::ED25519, - SignatureScheme::RSA_PKCS1_SHA256 => &ring::signature::RSA_PKCS1_2048_8192_SHA256, - SignatureScheme::RSA_PKCS1_SHA384 => &ring::signature::RSA_PKCS1_2048_8192_SHA384, - SignatureScheme::RSA_PKCS1_SHA512 => &ring::signature::RSA_PKCS1_2048_8192_SHA512, - SignatureScheme::RSA_PSS_SHA256 => &ring::signature::RSA_PSS_2048_8192_SHA512, - SignatureScheme::RSA_PSS_SHA384 => &ring::signature::RSA_PSS_2048_8192_SHA384, - SignatureScheme::RSA_PSS_SHA512 => &ring::signature::RSA_PSS_2048_8192_SHA512, - _ => return Err(VerifyError::UnsupportedScheme { scheme }), - }; - - let public_key = match x509_parser::x509::SubjectPublicKeyInfo::from_der(&spki) { - Ok((_remain, spki)) => spki.subject_public_key, - Err(_error) => unreachable!("rustls returned an invalid peer_certificates."), - }; - - Ok( - ring::signature::UnparsedPublicKey::new(algorithm, public_key) - .verify(data, signature) - .is_ok(), - ) -} diff --git a/src/publisher.rs b/src/publisher.rs new file mode 100644 index 0000000..13ba706 --- /dev/null +++ b/src/publisher.rs @@ -0,0 +1,802 @@ +mod address; +mod dispatch; +mod packet; + +use std::{any::TypeId, future::Future, io, net::SocketAddr, pin::Pin, sync::Arc, time::Duration}; + +pub use address::{ + AddressSelector, AddressView, AddressViewSource, EndpointBindingAddresses, FnAddressView, + PublishAddressGroup, PublishAddressScope, PublishAddresses, +}; +use dhttp_identity::{identity::LocalAuthority, name::Name}; +use dquic::{ + qinterface::component::location::AddressEvent, qresolve::Resolve, + qtraversal::nat::client::ClientLocationData, +}; +pub use packet::{EndpointRecordSigner, SignEndpointRecordsError}; +use snafu::Snafu; + +pub const DEFAULT_PUBLISH_INTERVAL: Duration = Duration::from_secs(20); +/// Upper bound for a single publish attempt in the background loop. +/// +/// Network changes can leave an in-flight H3 publish waiting on paths that no +/// longer exist. Timing out the attempt keeps consecutive publishes +/// independent: the next interval observes the current bindings again. +pub const DEFAULT_PUBLISH_TIMEOUT: Duration = Duration::from_secs(10); +const PUBLISH_CHANGE_DEBOUNCE: Duration = Duration::from_millis(50); + +type PublishLoopFuture<'a> = Pin + Send + 'a>>; + +#[derive(Debug, Snafu)] +#[snafu(module(create_publisher_error))] +pub enum CreatePublisherError { + #[snafu(display("anonymous endpoint cannot publish dns records"))] + AnonymousEndpoint, +} + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum PublishOnceError { + #[snafu(display("no publisher resolver available"))] + NoPublisherResolver, + #[snafu(display("failed to sign endpoint records"))] + SignEndpointRecords { source: SignEndpointRecordsError }, + #[snafu(display("failed to publish dns packet with {publisher}"))] + Publish { + publisher: String, + source: io::Error, + }, +} + +pub trait PublisherResolver: Send + Sync + 'static { + fn as_resolver(&self) -> &(dyn Resolve + Send + Sync); +} + +impl PublisherResolver for T +where + T: Resolve + Send + Sync + Sized + 'static, +{ + fn as_resolver(&self) -> &(dyn Resolve + Send + Sync) { + self + } +} + +impl PublisherResolver for dyn Resolve + Send + Sync { + fn as_resolver(&self) -> &(dyn Resolve + Send + Sync) { + self + } +} + +pub struct Publisher { + signer: EndpointRecordSigner, + resolver: Arc, +} + +impl std::fmt::Debug for Publisher +where + A: LocalAuthority + Send + Sync + ?Sized, + R: PublisherResolver + ?Sized, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Publisher") + .field("signer", &self.signer) + .field("resolver", &self.resolver.as_resolver().to_string()) + .finish() + } +} + +pub type EndpointPublisher = Publisher; + +impl Publisher +where + A: LocalAuthority + Send + Sync + ?Sized, + R: PublisherResolver + ?Sized, +{ + pub fn new(signer: EndpointRecordSigner, resolver: Arc) -> Self { + Self { signer, resolver } + } + + pub fn signer(&self) -> &EndpointRecordSigner { + &self.signer + } + + pub fn resolver(&self) -> &Arc { + &self.resolver + } + + pub async fn publish_once( + &self, + name: &Name<'_>, + addresses: &V, + ) -> Result<(), PublishOnceError> + where + V: AddressView + Sync, + { + let mut published = false; + published |= self + .publish_to_resolver(self.resolver.as_resolver(), name, addresses) + .await?; + + if !published { + return publish_once_error::NoPublisherResolverSnafu.fail(); + } + + Ok(()) + } +} + +pub struct EndpointPublicationLoop { + name: Name<'static>, + publisher: Publisher, + source: S, + interval: Duration, + publish_timeout: Duration, +} + +impl std::fmt::Debug for EndpointPublicationLoop +where + A: LocalAuthority + Send + Sync + ?Sized, + R: PublisherResolver + ?Sized, + S: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EndpointPublicationLoop") + .field("name", &self.name) + .field("publisher", &self.publisher) + .field("source", &self.source) + .field("interval", &self.interval) + .field("publish_timeout", &self.publish_timeout) + .finish() + } +} + +impl EndpointPublicationLoop +where + A: LocalAuthority + Send + Sync + ?Sized, + R: PublisherResolver + ?Sized, + S: AddressViewSource + Send + Sync, +{ + pub fn new(name: Name<'static>, publisher: Publisher, source: S) -> Self { + Self { + name, + publisher, + source, + interval: DEFAULT_PUBLISH_INTERVAL, + publish_timeout: DEFAULT_PUBLISH_TIMEOUT, + } + } + + pub fn name(&self) -> &Name<'static> { + &self.name + } + + pub fn publisher(&self) -> &Publisher { + &self.publisher + } + + pub fn interval(&self) -> Duration { + self.interval + } + + pub fn publish_timeout(&self) -> Duration { + self.publish_timeout + } + + pub fn with_publish_timeout(mut self, timeout: Duration) -> Self { + self.publish_timeout = timeout; + self + } + + pub async fn run(&self) -> ! { + let mut locations = self.source.subscribe(); + let interval = tokio::time::sleep(self.interval); + tokio::pin!(interval); + // Keep at most one publish attempt in flight. A timer tick or + // publishable location change drops the current future and starts a new + // debounced attempt so a stale H3 publish cannot block publication from + // the latest bindings. + let mut current_publish = self.new_publish_loop_future(); + + loop { + tokio::select! { + _ = &mut current_publish => { + current_publish = Self::pending_publish_loop_future(); + } + _ = &mut interval => { + interval.as_mut().reset(tokio::time::Instant::now() + self.interval); + self.clear_publish_state(); + current_publish = self.new_publish_loop_future(); + } + event = locations.recv() => { + let Some((bind_uri, event)) = event else { + continue; + }; + if !self.source.observes(&bind_uri) { + continue; + } + if !Self::location_event_requires_publish(&event) { + continue; + } + + self.clear_publish_state(); + current_publish = self.new_publish_loop_future(); + } + } + } + } + + fn new_publish_loop_future(&self) -> PublishLoopFuture<'_> { + Box::pin(async move { + tokio::time::sleep(PUBLISH_CHANGE_DEBOUNCE).await; + let _ = self.publish_attempt().await; + }) + } + + fn pending_publish_loop_future<'a>() -> PublishLoopFuture<'a> { + Box::pin(std::future::pending()) + } + + async fn publish_attempt(&self) -> bool { + tracing::trace!( + timeout_ms = self.publish_timeout.as_millis(), + "starting dns publish attempt" + ); + let addresses = self.source.address_view(); + match tokio::time::timeout( + self.publish_timeout, + self.publisher.publish_once(&self.name, &addresses), + ) + .await + { + Ok(Ok(())) => { + tracing::info!(name = %self.name, "published resolver endpoints"); + true + } + Ok(Err(error)) => { + let report = snafu::Report::from_error(&error); + tracing::warn!(error = %report, name = %self.name, "dns publish failed"); + false + } + Err(_elapsed) => { + // Dropping a timed-out publish future does not let the H3 + // resolver observe a request error. Reset resolver-owned + // connection state so the next interval reconnects from + // the current network bindings. + self.clear_publish_state(); + tracing::warn!( + timeout_ms = self.publish_timeout.as_millis(), + name = %self.name, + "dns publish timed out" + ); + false + } + } + } + + fn clear_publish_state(&self) { + dispatch::clear_resolver_publish_state(self.publisher.resolver.as_resolver()); + } + + fn location_event_requires_publish(event: &AddressEvent) -> bool { + match event { + AddressEvent::Upsert(data) => { + // `Locations` also carries transient STUN failures. Those do + // not add a publishable endpoint; treating them as publish + // triggers creates a retry loop while the node is offline. + if let Some(bound_addr) = data.downcast_ref::>() { + return bound_addr.is_ok(); + } + if let Some(stun_addr) = data.downcast_ref::() { + return stun_addr.is_ok(); + } + false + } + AddressEvent::Remove(type_id) => { + *type_id == TypeId::of::>() + || *type_id == TypeId::of::() + } + AddressEvent::Closed => true, + } + } +} + +pub type EndpointPublisherLoop = EndpointPublicationLoop< + dyn LocalAuthority + Send + Sync, + dyn Resolve + Send + Sync, + EndpointBindingAddresses, +>; + +#[cfg(test)] +mod tests { + #[cfg(feature = "http-resolver")] + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::{fmt, sync::Arc, time::Duration}; + + use dquic::qresolve::{ResolveFuture, Source}; + use futures::{FutureExt, StreamExt, future::BoxFuture, stream}; + use rustls::pki_types::{CertificateDer, SubjectPublicKeyInfoDer}; + #[cfg(feature = "http-resolver")] + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + use super::*; + + #[derive(Debug)] + struct TestAuthority; + + const ED25519_TEST_SPKI: [u8; 44] = [ + 0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]; + + impl LocalAuthority for TestAuthority { + fn name(&self) -> &str { + "authority.example" + } + + fn cert_chain(&self) -> &[CertificateDer<'static>] { + static CERTS: std::sync::LazyLock>> = + std::sync::LazyLock::new(|| { + vec![CertificateDer::from( + include_bytes!("../tests/fixtures/valid.der").to_vec(), + )] + }); + CERTS.as_slice() + } + + fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { + SubjectPublicKeyInfoDer::from(ED25519_TEST_SPKI.as_slice()) + } + + fn sign( + &self, + _data: &[u8], + ) -> BoxFuture<'_, Result, dhttp_identity::identity::SignError>> { + // Match the Ed25519 signature length used by DHTTP's canonical + // key-to-scheme policy. Short fake signatures can collide with + // legacy E-record fixed RDLENGTH values during parser + // compatibility dispatch. + Box::pin(async move { Ok(vec![0x2a; 64]) }) + } + } + + #[derive(Debug)] + struct DisplayOnlyResolver; + + impl fmt::Display for DisplayOnlyResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("display only resolver") + } + } + + impl Resolve for DisplayOnlyResolver { + fn lookup<'l>(&'l self, _name: &'l str) -> ResolveFuture<'l> { + async { Ok(stream::empty::<(Source, dquic::qbase::net::addr::EndpointAddr)>().boxed()) } + .boxed() + } + } + + fn test_name() -> Name<'static> { + "authority.example".parse().unwrap() + } + + fn test_publisher(resolver: Arc) -> Publisher + where + R: Resolve + Send + Sync, + { + let signer = EndpointRecordSigner::new(Arc::new(TestAuthority)); + Publisher::new(signer, resolver) + } + + fn test_source(network: Arc) -> EndpointBindingAddresses { + EndpointBindingAddresses::new( + network, + Arc::new(vec![ + "inet://127.0.0.1:0".parse().expect("valid bind pattern"), + ]), + ) + } + + #[tokio::test] + async fn publish_once_reports_no_publisher_resolver() { + let publisher = test_publisher(Arc::new(DisplayOnlyResolver)); + let addresses = + PublishAddresses::new().wide_area([dquic::qbase::net::addr::EndpointAddr::direct( + "127.0.0.1:443".parse().unwrap(), + )]); + + let error = publisher + .publish_once(&test_name(), &addresses) + .await + .unwrap_err(); + + assert!(matches!(error, PublishOnceError::NoPublisherResolver)); + } + + #[tokio::test] + async fn publisher_timeout_is_configurable() { + let network = h3x::dquic::Network::builder().build(); + let publisher = test_publisher(Arc::new(DisplayOnlyResolver)); + let publisher_loop = + EndpointPublicationLoop::new(test_name(), publisher, test_source(network)); + assert_eq!(publisher_loop.publish_timeout(), DEFAULT_PUBLISH_TIMEOUT); + + let timeout = Duration::from_secs(3); + let publisher_loop = publisher_loop.with_publish_timeout(timeout); + assert_eq!(publisher_loop.publish_timeout(), timeout); + } + + #[tokio::test] + async fn signer_applies_certificate_selector_from_authority_ski() { + let signer = EndpointRecordSigner::new(Arc::new(TestAuthority)); + let name: Name<'static> = "authority.example".parse().unwrap(); + + let endpoint = + dquic::qbase::net::addr::EndpointAddr::direct("127.0.0.1:443".parse().unwrap()); + let packet = signer.signed_packet(&name, &[endpoint]).await.unwrap(); + let (_remain, packet) = crate::core::parser::packet::be_packet(&packet).unwrap(); + let record = packet.answers.first().expect("endpoint answer"); + let crate::core::parser::record::RData::E(endpoint) = record.data() else { + panic!("expected endpoint record"); + }; + + assert!(endpoint.is_main()); + assert!(!endpoint.is_clustered()); + assert!(endpoint.is_signed()); + assert_eq!( + endpoint.certificate_chain_key().unwrap().sequence().get(), + 0 + ); + } + + #[tokio::test] + async fn signer_uses_supplied_record_owner_name() { + let signer = EndpointRecordSigner::new(Arc::new(TestAuthority)); + let name: Name<'static> = "nat.genmeta.net".parse().unwrap(); + + let endpoint = + dquic::qbase::net::addr::EndpointAddr::direct("127.0.0.1:443".parse().unwrap()); + let packet = signer.signed_packet(&name, &[endpoint]).await.unwrap(); + let (_remain, packet) = crate::core::parser::packet::be_packet(&packet).unwrap(); + let record = packet.answers.first().expect("endpoint answer"); + let crate::core::parser::record::RData::E(endpoint) = record.data() else { + panic!("expected endpoint record"); + }; + + assert_eq!(record.name().to_string(), "nat.genmeta.net"); + assert!(endpoint.is_main()); + assert!(!endpoint.is_clustered()); + assert!(endpoint.is_signed()); + } + + #[tokio::test] + async fn binding_address_view_does_not_expose_loopback_as_wide_area_without_stun() { + let network = h3x::dquic::Network::builder().build(); + let bind_pattern: h3x::dquic::binds::BindPattern = + "inet://127.0.0.1:0".parse().expect("valid bind pattern"); + let _bind = network.quic().bind(bind_pattern.clone()).await; + let source = EndpointBindingAddresses::new(network, Arc::new(vec![bind_pattern])); + let view = source.address_view(); + + assert!(view.endpoints(AddressSelector::WideArea).next().is_none()); + } + + #[cfg(feature = "http-resolver")] + #[tokio::test] + async fn run_restarts_when_publish_attempt_observes_location_change() { + async fn wait_for_count(count: &AtomicUsize, target: usize) { + loop { + if count.load(Ordering::SeqCst) >= target { + return; + } + tokio::time::sleep(Duration::from_millis(20)).await; + } + } + + let network = h3x::dquic::Network::builder().build(); + let bind_uri: h3x::dquic::net::BindUri = + "inet://127.0.0.1:0".parse().expect("valid bind uri"); + let publish_count = Arc::new(AtomicUsize::new(0)); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind test http server"); + let port = listener.local_addr().expect("local addr").port(); + let server_network = network.clone(); + let server_bind_uri = bind_uri.clone(); + let server_count = publish_count.clone(); + let server = tokio::spawn(async move { + loop { + let Ok((mut stream, _peer)) = listener.accept().await else { + break; + }; + let current = server_count.fetch_add(1, Ordering::SeqCst) + 1; + let mut buf = [0_u8; 1024]; + let _ = stream.read(&mut buf).await; + if current == 2 { + server_network.quic().locations().upsert( + server_bind_uri.clone(), + Arc::new(Ok::( + "127.0.0.1:10001".parse().expect("valid socket addr"), + )), + ); + } + let _ = stream + .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n") + .await; + } + }); + + let resolver = Arc::new( + crate::resolvers::http::HttpResolver::new(format!("http://127.0.0.1:{port}/")) + .expect("valid http resolver"), + ); + let publisher = test_publisher(resolver); + let source = test_source(network.clone()); + let mut publisher_loop = EndpointPublicationLoop::new(test_name(), publisher, source); + publisher_loop.interval = Duration::from_secs(60); + + let publisher = tokio::spawn(async move { + publisher_loop.run().await; + }); + + wait_for_count(&publish_count, 1).await; + tokio::time::sleep(PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(100)).await; + network.quic().locations().upsert( + bind_uri, + Arc::new(Ok::( + "127.0.0.1:10000".parse().expect("valid socket addr"), + )), + ); + + tokio::time::timeout(Duration::from_secs(2), wait_for_count(&publish_count, 2)) + .await + .expect("publishable location changes should trigger the next independent publish"); + + tokio::time::timeout( + PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(500), + wait_for_count(&publish_count, 3), + ) + .await + .expect("publishable location events should replace the current publish attempt"); + + publisher.abort(); + server.abort(); + } + + #[cfg(feature = "http-resolver")] + #[tokio::test] + async fn run_ignores_transient_location_failures_generated_during_publish_attempt() { + async fn wait_for_count(count: &AtomicUsize, target: usize) { + loop { + if count.load(Ordering::SeqCst) >= target { + return; + } + tokio::time::sleep(Duration::from_millis(20)).await; + } + } + + let network = h3x::dquic::Network::builder().build(); + let bind_uri: h3x::dquic::net::BindUri = + "inet://127.0.0.1:0".parse().expect("valid bind uri"); + let publish_count = Arc::new(AtomicUsize::new(0)); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind test http server"); + let port = listener.local_addr().expect("local addr").port(); + let server_network = network.clone(); + let server_bind_uri = bind_uri.clone(); + let server_count = publish_count.clone(); + let server = tokio::spawn(async move { + loop { + let Ok((mut stream, _peer)) = listener.accept().await else { + break; + }; + let mut buf = [0_u8; 1024]; + let _ = stream.read(&mut buf).await; + server_count.fetch_add(1, Ordering::SeqCst); + server_network + .quic() + .locations() + .upsert::( + server_bind_uri.clone(), + Arc::new(Err( + dquic::qtraversal::nat::client::DetectOuterAddrError::Rebinded { + bind_uri: server_bind_uri.clone(), + }, + )), + ); + let _ = stream + .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n") + .await; + } + }); + + let resolver = Arc::new( + crate::resolvers::http::HttpResolver::new(format!("http://127.0.0.1:{port}/")) + .expect("valid http resolver"), + ); + let publisher = test_publisher(resolver); + let source = test_source(network.clone()); + let publisher_loop = EndpointPublicationLoop::new(test_name(), publisher, source); + let publisher = tokio::spawn(async move { + publisher_loop.run().await; + }); + + wait_for_count(&publish_count, 1).await; + tokio::time::sleep(PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(100)).await; + + network.quic().locations().upsert( + bind_uri, + Arc::new(Ok::( + "127.0.0.1:0".parse().expect("valid socket addr"), + )), + ); + wait_for_count(&publish_count, 2).await; + + let third_publish = tokio::time::timeout( + PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(500), + wait_for_count(&publish_count, 3), + ) + .await; + + publisher.abort(); + server.abort(); + + assert!( + third_publish.is_err(), + "publish-generated location events must not trigger another immediate publish" + ); + } + + #[cfg(feature = "http-resolver")] + #[tokio::test] + async fn run_does_not_retry_location_publish_after_timeout() { + async fn wait_for_count(count: &AtomicUsize, target: usize) { + loop { + if count.load(Ordering::SeqCst) >= target { + return; + } + tokio::time::sleep(Duration::from_millis(20)).await; + } + } + + let network = h3x::dquic::Network::builder().build(); + let bind_uri: h3x::dquic::net::BindUri = + "inet://127.0.0.1:0".parse().expect("valid bind uri"); + let publish_count = Arc::new(AtomicUsize::new(0)); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind test http server"); + let port = listener.local_addr().expect("local addr").port(); + let server_count = publish_count.clone(); + let server = tokio::spawn(async move { + loop { + let Ok((mut stream, _peer)) = listener.accept().await else { + break; + }; + let current = server_count.fetch_add(1, Ordering::SeqCst) + 1; + let mut buf = [0_u8; 1024]; + let _ = stream.read(&mut buf).await; + if current == 2 { + tokio::time::sleep(Duration::from_millis(200)).await; + } + let _ = stream + .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n") + .await; + } + }); + + let resolver = Arc::new( + crate::resolvers::http::HttpResolver::new(format!("http://127.0.0.1:{port}/")) + .expect("valid http resolver"), + ); + let publisher = test_publisher(resolver); + let source = test_source(network.clone()); + let mut publisher_loop = EndpointPublicationLoop::new(test_name(), publisher, source) + .with_publish_timeout(Duration::from_millis(50)); + publisher_loop.interval = Duration::from_secs(60); + + let publisher = tokio::spawn(async move { + publisher_loop.run().await; + }); + + wait_for_count(&publish_count, 1).await; + tokio::time::sleep(PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(100)).await; + network.quic().locations().upsert( + bind_uri, + Arc::new(Ok::( + "127.0.0.1:0".parse().expect("valid socket addr"), + )), + ); + + wait_for_count(&publish_count, 2).await; + let third_publish = tokio::time::timeout( + PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(500), + wait_for_count(&publish_count, 3), + ) + .await; + + publisher.abort(); + server.abort(); + + assert!( + third_publish.is_err(), + "timed out location-triggered publish must not be retried before the next interval" + ); + } + + #[cfg(feature = "http-resolver")] + #[tokio::test] + async fn run_replaces_in_flight_publish_on_publishable_location_change() { + async fn wait_for_count(count: &AtomicUsize, target: usize) { + loop { + if count.load(Ordering::SeqCst) >= target { + return; + } + tokio::time::sleep(Duration::from_millis(20)).await; + } + } + + let network = h3x::dquic::Network::builder().build(); + let bind_uri: h3x::dquic::net::BindUri = + "inet://127.0.0.1:0".parse().expect("valid bind uri"); + let publish_count = Arc::new(AtomicUsize::new(0)); + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind test http server"); + let port = listener.local_addr().expect("local addr").port(); + let server_count = publish_count.clone(); + let server = tokio::spawn(async move { + loop { + let Ok((mut stream, _peer)) = listener.accept().await else { + break; + }; + let current = server_count.fetch_add(1, Ordering::SeqCst) + 1; + tokio::spawn(async move { + let mut buf = [0_u8; 1024]; + let _ = stream.read(&mut buf).await; + if current == 1 { + std::future::pending::<()>().await; + } + let _ = stream + .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n") + .await; + }); + } + }); + + let resolver = Arc::new( + crate::resolvers::http::HttpResolver::new(format!("http://127.0.0.1:{port}/")) + .expect("valid http resolver"), + ); + let publisher = test_publisher(resolver); + let source = test_source(network.clone()); + let mut publisher_loop = EndpointPublicationLoop::new(test_name(), publisher, source) + .with_publish_timeout(Duration::from_secs(30)); + publisher_loop.interval = Duration::from_secs(60); + + let publisher = tokio::spawn(async move { + publisher_loop.run().await; + }); + + tokio::time::timeout(Duration::from_secs(2), wait_for_count(&publish_count, 1)) + .await + .expect("initial publish should start"); + + network.quic().locations().upsert( + bind_uri, + Arc::new(Ok::( + "127.0.0.1:10000".parse().expect("valid socket addr"), + )), + ); + + tokio::time::timeout( + PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(800), + wait_for_count(&publish_count, 2), + ) + .await + .expect("publishable location change should replace the in-flight publish"); + + publisher.abort(); + server.abort(); + } +} diff --git a/src/publisher/address.rs b/src/publisher/address.rs new file mode 100644 index 0000000..207150b --- /dev/null +++ b/src/publisher/address.rs @@ -0,0 +1,444 @@ +use std::{ + collections::HashSet, + net::SocketAddr, + sync::{Arc, OnceLock}, +}; + +use dquic::{ + qbase::net::{Family, addr::EndpointAddr}, + qinterface::component::location::Observer, +}; +use h3x::dquic::{ + Network, + binds::BindPattern, + net::{BindInterface, BindUri, IO, Scheme}, + qtraversal::nat::client::{NatType, StunClientsComponent}, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AddressSelector<'a> { + WideArea, + LocalLink { device: &'a str, family: Family }, +} + +pub trait AddressView { + fn endpoints<'a>( + &'a self, + selector: AddressSelector<'a>, + ) -> impl Iterator + 'a; +} + +pub struct FnAddressView { + f: F, +} + +impl FnAddressView { + pub fn new(f: F) -> Self { + Self { f } + } +} + +impl AddressView for FnAddressView +where + F: for<'a> Fn(AddressSelector<'a>) -> I, + I: IntoIterator, + I::IntoIter: 'static, +{ + fn endpoints<'a>( + &'a self, + selector: AddressSelector<'a>, + ) -> impl Iterator + 'a { + (self.f)(selector).into_iter() + } +} + +pub trait AddressViewSource { + fn address_view(&self) -> impl AddressView + Send + Sync + '_; + fn subscribe(&self) -> Observer; + fn observes(&self, bind_uri: &BindUri) -> bool; +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PublishAddressScope { + WideArea, + LocalLink { device: Arc, family: Family }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PublishAddressGroup { + scope: PublishAddressScope, + endpoints: Vec, +} + +impl PublishAddressGroup { + pub fn wide_area(endpoints: I) -> Self + where + I: IntoIterator, + { + Self { + scope: PublishAddressScope::WideArea, + endpoints: endpoints.into_iter().collect(), + } + } + + pub fn local_link(device: impl Into>, family: Family, endpoints: I) -> Self + where + I: IntoIterator, + { + Self { + scope: PublishAddressScope::LocalLink { + device: device.into(), + family, + }, + endpoints: endpoints.into_iter().collect(), + } + } + + fn matches(&self, selector: AddressSelector<'_>) -> bool { + match (&self.scope, selector) { + (PublishAddressScope::WideArea, AddressSelector::WideArea) => true, + ( + PublishAddressScope::LocalLink { device, family }, + AddressSelector::LocalLink { + device: selected_device, + family: selected_family, + }, + ) => device.as_ref() == selected_device && *family == selected_family, + _ => false, + } + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct PublishAddresses { + groups: Vec, +} + +impl PublishAddresses { + pub fn new() -> Self { + Self::default() + } + + pub fn group(mut self, group: PublishAddressGroup) -> Self { + self.groups.push(group); + self + } + + pub fn wide_area(self, endpoints: I) -> Self + where + I: IntoIterator, + { + self.group(PublishAddressGroup::wide_area(endpoints)) + } + + pub fn local_link(self, device: impl Into>, family: Family, endpoints: I) -> Self + where + I: IntoIterator, + { + self.group(PublishAddressGroup::local_link(device, family, endpoints)) + } +} + +impl AddressView for PublishAddresses { + fn endpoints<'a>( + &'a self, + selector: AddressSelector<'a>, + ) -> impl Iterator + 'a { + self.groups + .iter() + .filter(move |group| group.matches(selector)) + .flat_map(move |group| group.endpoints.iter().copied()) + } +} + +#[derive(Clone)] +pub struct EndpointBindingAddresses { + network: Arc, + bind_patterns: Arc>, +} + +impl std::fmt::Debug for EndpointBindingAddresses { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EndpointBindingAddresses") + .field("bind_patterns", &self.bind_patterns) + .finish_non_exhaustive() + } +} + +impl EndpointBindingAddresses { + pub fn new(network: Arc, bind_patterns: Arc>) -> Self { + Self { + network, + bind_patterns, + } + } +} + +impl AddressViewSource for EndpointBindingAddresses { + fn address_view(&self) -> impl AddressView + Send + Sync + '_ { + EndpointBindingAddressView::new(self.network.clone(), self.bind_patterns.clone()) + } + + fn subscribe(&self) -> Observer { + self.network.quic().locations().subscribe() + } + + fn observes(&self, bind_uri: &BindUri) -> bool { + self.bind_patterns + .iter() + .any(|pattern| pattern.matches(bind_uri)) + } +} + +struct EndpointBindingAddressView { + bindings: Vec, +} + +impl EndpointBindingAddressView { + fn new(network: Arc, bind_patterns: Arc>) -> Self { + let mut bindings = Vec::new(); + for pattern in bind_patterns.iter() { + let Some(ifaces) = network.quic().get_interfaces(pattern) else { + tracing::trace!(?pattern, "no interfaces for bind pattern"); + continue; + }; + for iface in ifaces { + bindings.push(BindingAddress::new(network.clone(), pattern.clone(), iface)); + } + } + Self { bindings } + } +} + +impl AddressView for EndpointBindingAddressView { + fn endpoints<'a>( + &'a self, + selector: AddressSelector<'a>, + ) -> impl Iterator + 'a { + let mut seen = HashSet::new(); + self.bindings + .iter() + .filter(move |binding| binding.may_match(selector)) + .flat_map(move |binding| binding.endpoints(selector)) + .filter(move |endpoint| seen.insert(*endpoint)) + } +} + +struct BindingAddress { + network: Arc, + pattern: BindPattern, + bind_uri: BindUri, + iface: BindInterface, + wide_area: OnceLock>, + local_link: OnceLock>, +} + +impl BindingAddress { + fn new(network: Arc, pattern: BindPattern, iface: BindInterface) -> Self { + let bind_uri = iface.bind_uri(); + Self { + network, + pattern, + bind_uri, + iface, + wide_area: OnceLock::new(), + local_link: OnceLock::new(), + } + } + + fn may_match(&self, selector: AddressSelector<'_>) -> bool { + match selector { + AddressSelector::WideArea => true, + AddressSelector::LocalLink { device, family } => { + pattern_may_match_local_link(&self.pattern, device, family) + && bind_uri_matches_local_link(&self.bind_uri, device, family) + } + } + } + + fn endpoints<'a>( + &'a self, + selector: AddressSelector<'a>, + ) -> impl Iterator + 'a { + let endpoints = match selector { + AddressSelector::WideArea => self + .wide_area + .get_or_init(|| public_endpoints_from_iface(&self.network, &self.iface)), + AddressSelector::LocalLink { family, .. } => self + .local_link + .get_or_init(|| local_endpoints_from_iface(&self.iface, family)), + }; + endpoints.iter().copied() + } +} + +fn pattern_may_match_local_link(pattern: &BindPattern, device: &str, family: Family) -> bool { + if pattern.scheme != Scheme::Iface { + return false; + } + if pattern + .host + .family() + .is_some_and(|pattern_family| pattern_family != family) + { + return false; + } + pattern.host.matches(device) +} + +fn bind_uri_matches_local_link(bind_uri: &BindUri, device: &str, family: Family) -> bool { + bind_uri + .as_iface_bind_uri() + .is_some_and(|(iface_family, iface_device, _port)| { + iface_family == family && iface_device == device + }) +} + +fn public_endpoints_from_iface(network: &Network, iface: &BindInterface) -> Vec { + iface.with_components(|components, current| { + let bind_uri = current.bind_uri(); + let addr = current.bound_addr().ok(); + let mut endpoints: Vec = components + .get::() + .map(|stun| { + stun.with_clients(|clients| { + clients + .values() + .filter_map(|client| { + let outer = client.get_outer_addr()?.ok()?; + let bound = current.bound_addr().ok()?; + match client.get_nat_type() { + Some(Ok(nat_type)) => Some(publish_endpoint_from_stun( + bound, + client.agent_addr(), + outer, + nat_type, + )), + None => Some(EndpointAddr::with_agent(client.agent_addr(), outer)), + Some(Err(_)) => None, + } + }) + .collect() + }) + }) + .unwrap_or_default(); + let stun_endpoint_count = endpoints.len(); + + if let Some(addr) = addr + && network.bound_addr_is_on_default_route(&bind_uri, addr) + { + endpoints.push(EndpointAddr::direct(addr)); + } + + tracing::trace!( + bind_uri = %bind_uri, + bound_addr = ?addr, + stun_endpoint_count, + endpoint_count = endpoints.len(), + endpoints = ?endpoints, + "collected wide-area endpoints from interface" + ); + + endpoints + }) +} + +fn publish_endpoint_from_stun( + bound: SocketAddr, + agent: SocketAddr, + outer: SocketAddr, + nat_type: NatType, +) -> EndpointAddr { + if nat_type == NatType::FullCone && bound == outer { + EndpointAddr::direct(outer) + } else { + EndpointAddr::with_agent(agent, outer) + } +} + +fn local_endpoints_from_iface(iface: &BindInterface, family: Family) -> Vec { + iface.with_components(|_components, current| { + let Some(addr) = current.bound_addr().ok() else { + return Vec::new(); + }; + match (family, addr) { + (Family::V4, SocketAddr::V4(_)) | (Family::V6, SocketAddr::V6(_)) => { + vec![EndpointAddr::direct(addr)] + } + _ => Vec::new(), + } + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn publish_addresses_select_wide_area_only_for_wide_area_selector() { + let wide = EndpointAddr::direct("203.0.113.10:443".parse().unwrap()); + let local = EndpointAddr::direct("192.168.1.20:443".parse().unwrap()); + let addresses = + PublishAddresses::new() + .wide_area([wide]) + .local_link("en0", Family::V4, [local]); + + let selected: Vec<_> = addresses.endpoints(AddressSelector::WideArea).collect(); + + assert_eq!(selected, vec![wide]); + } + + #[test] + fn publish_addresses_select_matching_local_link_group() { + let en0 = EndpointAddr::direct("192.168.1.20:443".parse().unwrap()); + let en1 = EndpointAddr::direct("192.168.2.20:443".parse().unwrap()); + let addresses = PublishAddresses::new() + .local_link("en0", Family::V4, [en0]) + .local_link("en1", Family::V4, [en1]); + + let selected: Vec<_> = addresses + .endpoints(AddressSelector::LocalLink { + device: "en1", + family: Family::V4, + }) + .collect(); + + assert_eq!(selected, vec![en1]); + } + + #[test] + fn publish_addresses_reject_local_link_family_mismatch() { + let endpoint = EndpointAddr::direct("192.168.1.20:443".parse().unwrap()); + let addresses = PublishAddresses::new().local_link("en0", Family::V4, [endpoint]); + + let selected: Vec<_> = addresses + .endpoints(AddressSelector::LocalLink { + device: "en0", + family: Family::V6, + }) + .collect(); + + assert!(selected.is_empty()); + } + + #[test] + fn full_cone_nat_endpoint_preserves_agent_when_outer_differs_from_bound_addr() { + let bound = "10.110.0.10:45635".parse().expect("valid bound addr"); + let agent = "10.10.0.2:20004".parse().expect("valid agent addr"); + let outer = "10.10.0.10:45635".parse().expect("valid outer addr"); + + let endpoint = publish_endpoint_from_stun(bound, agent, outer, NatType::FullCone); + + assert_eq!(endpoint, EndpointAddr::with_agent(agent, outer)); + } + + #[test] + fn full_cone_endpoint_is_direct_without_address_translation() { + let bound = "10.10.0.100:45635".parse().expect("valid bound addr"); + let agent = "10.10.0.2:20004".parse().expect("valid agent addr"); + + let endpoint = publish_endpoint_from_stun(bound, agent, bound, NatType::FullCone); + + assert_eq!(endpoint, EndpointAddr::direct(bound)); + } +} diff --git a/src/publisher/dispatch.rs b/src/publisher/dispatch.rs new file mode 100644 index 0000000..07d4db6 --- /dev/null +++ b/src/publisher/dispatch.rs @@ -0,0 +1,152 @@ +use std::any::Any; + +use dhttp_identity::{identity::LocalAuthority, name::Name}; +use dquic::{ + qbase::net::addr::EndpointAddr, + qresolve::{Publish, Resolve}, +}; +use snafu::ResultExt; + +use super::{ + AddressSelector, AddressView, PublishOnceError, Publisher, PublisherResolver, + publish_once_error, +}; +use crate::resolvers::Resolvers; + +impl Publisher +where + A: LocalAuthority + Send + Sync + ?Sized, + R: PublisherResolver + ?Sized, +{ + pub(crate) async fn publish_to_resolver( + &self, + resolver: &(dyn Resolve + Send + Sync), + name: &Name<'_>, + addresses: &V, + ) -> Result + where + V: AddressView + Sync, + { + let any: &dyn Any = resolver; + + if let Some(resolvers) = any.downcast_ref::() { + let mut published = false; + for resolver in resolvers.iter() { + published |= self + .publish_single_resolver(resolver.as_ref(), name, addresses) + .await?; + } + return Ok(published); + } + + self.publish_single_resolver(resolver, name, addresses) + .await + } + + async fn publish_single_resolver( + &self, + resolver: &(dyn Resolve + Send + Sync), + name: &Name<'_>, + addresses: &V, + ) -> Result + where + V: AddressView + Sync, + { + #[cfg(not(any( + feature = "http-resolver", + feature = "h3x-resolver", + feature = "mdns-resolver" + )))] + { + let _ = name; + let _ = addresses; + } + + let any: &dyn Any = resolver; + + #[cfg(feature = "http-resolver")] + if let Some(http) = any.downcast_ref::() { + self.publish_selected(http, name, addresses, AddressSelector::WideArea) + .await?; + return Ok(true); + } + + #[cfg(feature = "h3x-resolver")] + if let Some(h3) = + any.downcast_ref::>() + { + self.publish_selected(h3, name, addresses, AddressSelector::WideArea) + .await?; + return Ok(true); + } + + #[cfg(feature = "mdns-resolver")] + if let Some(mdns) = any.downcast_ref::() { + let mut published = false; + for bound in mdns.bound_resolvers() { + self.publish_selected( + &bound.resolver, + name, + addresses, + AddressSelector::LocalLink { + device: &bound.device, + family: bound.family, + }, + ) + .await?; + published = true; + } + return Ok(published); + } + + Ok(false) + } + + async fn publish_selected( + &self, + publisher: &(dyn Publish + Send + Sync), + name: &Name<'_>, + addresses: &V, + selector: AddressSelector<'_>, + ) -> Result<(), PublishOnceError> + where + V: AddressView + Sync, + { + let endpoints: Vec = addresses.endpoints(selector).collect(); + let packet = self + .signer + .signed_packet(name, &endpoints) + .await + .context(publish_once_error::SignEndpointRecordsSnafu)?; + tracing::debug!( + publisher = %publisher, + name = %name, + endpoint_count = endpoints.len(), + packet_len = packet.len(), + "publishing dns packet" + ); + publisher + .publish(name.as_str(), &packet) + .await + .context(publish_once_error::PublishSnafu { + publisher: publisher.to_string(), + }) + } +} + +pub(crate) fn clear_resolver_publish_state(resolver: &(dyn Resolve + Send + Sync)) { + let any: &dyn Any = resolver; + + if let Some(resolvers) = any.downcast_ref::() { + for resolver in resolvers.iter() { + clear_resolver_publish_state(resolver.as_ref()); + } + } + + #[cfg(feature = "h3x-resolver")] + if let Some(h3) = + any.downcast_ref::>() + { + h3.clear_pool(); + } +} diff --git a/src/publisher/packet.rs b/src/publisher/packet.rs new file mode 100644 index 0000000..956afe2 --- /dev/null +++ b/src/publisher/packet.rs @@ -0,0 +1,83 @@ +use std::{collections::HashMap, sync::Arc}; + +use dhttp_identity::{ + identity::{LocalAuthority, LocalAuthorityCertificateExt}, + name::Name, +}; +use dquic::qbase::net::addr::EndpointAddr; +use snafu::{ResultExt, Snafu}; + +use crate::core::{ + MdnsPacket, + parser::record::endpoint::{EndpointAddr as DnsEndpointAddr, SignEndpointError}, +}; + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum SignEndpointRecordsError { + #[snafu(display("failed to encode endpoint address"))] + EncodeEndpoint, + #[snafu(display("failed to extract dhttp certificate selector"))] + CertificateSelector { + source: dhttp_identity::identity::ExtractDhttpSubjectKeyIdentifierError, + }, + #[snafu(display("failed to sign endpoint address"))] + SignEndpoint { source: SignEndpointError }, +} + +pub struct EndpointRecordSigner { + authority: Arc, +} + +impl std::fmt::Debug for EndpointRecordSigner +where + A: LocalAuthority, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EndpointRecordSigner") + .field("authority", &self.authority.name()) + .finish() + } +} + +impl EndpointRecordSigner +where + A: LocalAuthority + Send + Sync + ?Sized, +{ + pub fn new(authority: Arc) -> Self { + Self { authority } + } + + pub fn authority(&self) -> &Arc { + &self.authority + } + + pub async fn signed_packet( + &self, + name: &Name<'_>, + endpoints: &[EndpointAddr], + ) -> Result, SignEndpointRecordsError> { + let selector = self + .authority + .dhttp_subject_key_identifier() + .context(sign_endpoint_records_error::CertificateSelectorSnafu)?; + let chain = selector.chain(); + + let mut signed = Vec::with_capacity(endpoints.len()); + for endpoint in endpoints { + let Ok(mut endpoint) = DnsEndpointAddr::try_from(*endpoint) else { + return sign_endpoint_records_error::EncodeEndpointSnafu.fail(); + }; + endpoint.set_certificate_chain_key(chain); + endpoint + .sign_with_authority(self.authority.as_ref()) + .await + .context(sign_endpoint_records_error::SignEndpointSnafu)?; + signed.push(endpoint); + } + + let mut hosts = HashMap::new(); + hosts.insert(name.as_str().to_owned(), signed); + Ok(MdnsPacket::answer(0, &hosts).to_bytes()) + } +} diff --git a/src/resolvers.rs b/src/resolvers.rs index fad5462..337f980 100644 --- a/src/resolvers.rs +++ b/src/resolvers.rs @@ -4,22 +4,31 @@ use std::{ sync::Arc, }; -use futures::{FutureExt, Stream, StreamExt, TryFutureExt, stream}; -use h3x::dquic::{ - qinterface::device::Devices, - qresolve::{EndpointAddr, Family, Publish, Resolve, ResolveFuture, Source}, +use dquic::{ + qbase::net::addr::EndpointAddr, + qresolve::{Resolve, ResolveFuture, Source}, }; -use snafu::Report; +use futures::{FutureExt, Stream, StreamExt, TryFutureExt, stream}; use tokio::io; #[cfg(feature = "h3x-resolver")] -mod h3; -mod http; -mod mdns; +pub mod h3; +#[cfg(feature = "http-resolver")] +pub mod http; + +#[cfg(feature = "http-resolver")] +use http::HttpResolver; + +#[cfg(feature = "mdns-resolver")] +use crate::mdns::resolvers::mdns::MdnsResolvers; /// Extract and validate the DNS host from `name`, which may include a `:port` /// suffix. Returns `Some(host)` if the host part is a valid RFC-compliant DNS /// name, or `None` for raw IP addresses, bracketed IPv6, or malformed input. +#[cfg_attr( + not(any(feature = "h3x-resolver", feature = "http-resolver")), + allow(dead_code) +)] pub(crate) fn resolvable_name(name: &str) -> Option<&str> { let host = match name.rsplit_once(':') { Some((h, port)) if !port.is_empty() && port.chars().all(|c| c.is_ascii_digit()) => h, @@ -29,10 +38,59 @@ pub(crate) fn resolvable_name(name: &str) -> Option<&str> { Some(host) } -#[cfg(feature = "h3x-resolver")] -pub use h3::{H3Publisher, H3Resolver}; -pub use http::HttpResolver; -pub use mdns::{MdnsResolver, MdnsResolvers}; +/// Default DNS-over-H3 server for DHTTP endpoints. +pub const DHTTP_H3_DNS_SERVER: &str = crate::bootstrap::DHTTP_H3_DNS_SERVER; + +/// Default DNS-over-HTTP server for DHTTP endpoints. +pub const DHTTP_HTTP_DNS_SERVER: &str = crate::bootstrap::DHTTP_HTTP_DNS_SERVER; + +/// mDNS service type used by DHTTP endpoints. +pub const DHTTP_MDNS_SERVICE: &str = crate::bootstrap::DHTTP_MDNS_SERVICE; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum DnsScheme { + Mdns, + Http, + H3, + System, +} + +impl Display for DnsScheme { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Self::Mdns => "mdns", + Self::Http => "http", + Self::H3 => "h3", + Self::System => "system", + }) + } +} + +#[derive(Debug, snafu::Snafu)] +#[snafu(display("unsupported dns scheme {scheme}"))] +pub struct ParseDnsSchemeError { + scheme: String, +} + +impl std::str::FromStr for DnsScheme { + type Err = ParseDnsSchemeError; + + fn from_str(s: &str) -> Result { + match s { + "mdns" => Ok(Self::Mdns), + "http" => Ok(Self::Http), + "h3" => Ok(Self::H3), + "system" => Ok(Self::System), + scheme => Err(ParseDnsSchemeError { + scheme: scheme.to_owned(), + }), + } + } +} + +pub mod deferred; +pub(crate) mod selector; +pub mod weak; type ArcResolver = Arc; @@ -63,14 +121,40 @@ pub struct DnsErrors { errors: Vec<(String, io::Error)>, } +fn format_dns_error_sources( + f: &mut fmt::Formatter<'_>, + error: &(dyn Error + 'static), +) -> fmt::Result { + let mut index = 1; + let mut current = error.source(); + + while let Some(source) = current { + write!(f, "\n {index}. {source}")?; + index += 1; + current = source.source(); + } + + Ok(()) +} + +fn format_dns_error_entry( + f: &mut fmt::Formatter<'_>, + resolver: &str, + error: &io::Error, +) -> fmt::Result { + write!(f, "\n - {resolver}: {error}")?; + format_dns_error_sources(f, error) +} + impl fmt::Display for DnsErrors { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.errors.is_empty() { return write!(f, "no DNS resolvers available"); } - writeln!(f, "all DNS resolvers failed")?; - for (resolver, error) in self.errors.iter() { - write!(f, "`{resolver}` failed: {}", Report::from_error(error))?; + + write!(f, "all DNS resolvers failed")?; + for (resolver, error) in &self.errors { + format_dns_error_entry(f, resolver, error)?; } Ok(()) } @@ -78,38 +162,100 @@ impl fmt::Display for DnsErrors { impl Error for DnsErrors {} +#[derive(Default)] +pub struct ResolversBuilder { + resolvers: Resolvers, +} + +impl ResolversBuilder { + pub fn resolver(mut self, resolver: ArcResolver) -> Self { + self.resolvers.push(resolver); + self + } + + #[cfg(feature = "mdns-resolver")] + pub async fn mdns( + mut self, + network: Arc, + patterns: Arc>, + ) -> Self { + let mdns = Arc::new(MdnsResolvers::bind(network, patterns, DHTTP_MDNS_SERVICE).await); + self.resolvers.push(mdns); + self + } + + #[cfg(feature = "h3x-resolver")] + pub fn h3( + self, + endpoint: Arc>, + ) -> io::Result + where + C: h3x::quic::Connect + Send + Sync + 'static, + C::Error: Send + Sync + 'static, + C::Connection: Send + 'static, + { + self.h3_with_base_url(DHTTP_H3_DNS_SERVER, endpoint) + } + + #[cfg(feature = "h3x-resolver")] + pub fn h3_with_base_url( + mut self, + base_url: impl AsRef, + endpoint: Arc>, + ) -> io::Result + where + C: h3x::quic::Connect + Send + Sync + 'static, + C::Error: Send + Sync + 'static, + C::Connection: Send + 'static, + { + let resolver = h3::H3Resolver::from_endpoint(base_url, endpoint)?; + self.resolvers.push(Arc::new(resolver)); + Ok(self) + } + + #[cfg(feature = "http-resolver")] + pub fn http(self) -> io::Result { + self.http_with_base_url(DHTTP_HTTP_DNS_SERVER) + } + + #[cfg(feature = "http-resolver")] + pub fn http_with_base_url(mut self, base_url: impl AsRef) -> io::Result { + let resolver = HttpResolver::new(base_url.as_ref())?; + self.resolvers.push(Arc::new(resolver)); + Ok(self) + } + + pub fn system(mut self) -> Self { + self.resolvers + .push(Arc::new(dquic::qresolve::SystemResolver)); + self + } + + pub fn build(self) -> Resolvers { + self.resolvers + } +} + impl Resolvers { + pub fn builder() -> ResolversBuilder { + ResolversBuilder::default() + } + pub fn new() -> Self { Self::default() } pub fn with(mut self, resolver: ArcResolver) -> Self { - self.resolvers.push(resolver); + self.push(resolver); self } - pub fn with_mdns_resolvers( - mut self, - service_name: &str, - mut filter: impl FnMut(&str, Family) -> bool, - ) -> Self { - let devices = Devices::global(); - self.resolvers.extend( - devices - .interfaces() - .iter() - .flat_map(|(device, iface)| { - Option::into_iter( - (!iface.ipv4.is_empty()).then_some((device.as_str(), Family::V4)), - ) - .chain((!iface.ipv6.is_empty()).then_some((device.as_str(), Family::V6))) - }) - .filter(|(device, family)| filter(device, *family)) - .filter_map(|(device, ip)| Some((device, devices.resolve(device, ip)?))) - .filter_map(|(device, ip)| MdnsResolver::new(service_name, ip, device).ok()) - .map(|resolver| Arc::new(resolver) as ArcResolver), - ); - self + pub fn push(&mut self, resolver: ArcResolver) { + self.resolvers.push(resolver); + } + + pub fn iter(&self) -> impl Iterator { + self.resolvers.iter() } pub async fn lookup( @@ -146,3 +292,268 @@ impl Resolve for Resolvers { .boxed() } } + +#[cfg(test)] +mod tests { + use std::{error::Error as StdError, fmt, io, str::FromStr}; + + #[cfg(feature = "mdns-resolver")] + use super::MdnsResolvers; + #[cfg(any( + feature = "h3x-resolver", + feature = "http-resolver", + feature = "mdns-resolver" + ))] + use super::Resolvers; + use super::{ + DHTTP_H3_DNS_SERVER, DHTTP_HTTP_DNS_SERVER, DHTTP_MDNS_SERVICE, DnsErrors, DnsScheme, + resolvable_name, + }; + + #[derive(Debug)] + struct TestSourceError { + message: &'static str, + source: Option>, + } + + impl TestSourceError { + fn leaf(message: &'static str) -> Self { + Self { + message, + source: None, + } + } + + fn with_source(message: &'static str, source: TestSourceError) -> Self { + Self { + message, + source: Some(Box::new(source)), + } + } + } + + impl fmt::Display for TestSourceError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.message) + } + } + + impl StdError for TestSourceError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.source + .as_deref() + .map(|source| source as &(dyn StdError + 'static)) + } + } + + fn other_error(message: &'static str) -> io::Error { + io::Error::other(message) + } + + fn chained_other_error(root: TestSourceError) -> io::Error { + io::Error::other(root) + } + + #[test] + fn resolver_defaults_come_from_compile_time_environment() { + if let Some(expected) = option_env!("DHTTP_H3_DNS_SERVER") { + assert_eq!(DHTTP_H3_DNS_SERVER, expected); + } + if let Some(expected) = option_env!("DHTTP_HTTP_DNS_SERVER") { + assert_eq!(DHTTP_HTTP_DNS_SERVER, expected); + } + if let Some(expected) = option_env!("DHTTP_MDNS_SERVICE") { + assert_eq!(DHTTP_MDNS_SERVICE, expected); + } + } + + #[test] + fn resolvable_name_accepts_dns_name_with_numeric_port() { + assert_eq!( + resolvable_name("example.dhttp.net:443"), + Some("example.dhttp.net") + ); + } + + #[test] + fn resolvable_name_accepts_stun_authority_with_numeric_port() { + assert_eq!( + resolvable_name("nat.genmeta.net:20004"), + Some("nat.genmeta.net") + ); + } + + #[test] + fn resolvable_name_rejects_ip_literals() { + assert_eq!(resolvable_name("127.0.0.1:443"), None); + assert_eq!(resolvable_name("[::1]:443"), None); + } + + #[test] + fn dns_scheme_round_trips_supported_schemes_and_rejects_dht() { + let cases = [ + ("mdns", DnsScheme::Mdns), + ("http", DnsScheme::Http), + ("h3", DnsScheme::H3), + ("system", DnsScheme::System), + ]; + + for (text, scheme) in cases { + assert_eq!(DnsScheme::from_str(text).expect("supported scheme"), scheme); + assert_eq!(scheme.to_string(), text); + } + + assert!(DnsScheme::from_str("dht").is_err()); + } + + #[test] + fn dns_errors_render_no_resolvers_available_when_empty() { + let error = DnsErrors { errors: vec![] }; + + assert_eq!(error.to_string(), "no DNS resolvers available"); + } + + #[test] + fn dns_errors_render_resolver_bullets_in_stored_order() { + let error = DnsErrors { + errors: vec![ + ( + "System DNS Resolver".to_string(), + other_error("invalid socket address"), + ), + ("mDNS resolvers".to_string(), other_error("timed out")), + ], + }; + + assert_eq!( + error.to_string(), + concat!( + "all DNS resolvers failed\n", + " - System DNS Resolver: invalid socket address\n", + " - mDNS resolvers: timed out" + ) + ); + } + + #[test] + fn dns_errors_render_numbered_source_chain_for_one_resolver() { + let error = DnsErrors { + errors: vec![( + "DeferredResolver(H3 DNS Resolver(https://dns.genmeta.net:4433/))".to_string(), + chained_other_error(TestSourceError::with_source( + "deferred resolver lookup failed", + TestSourceError::leaf("no DNS record found"), + )), + )], + }; + + assert_eq!( + error.to_string(), + concat!( + "all DNS resolvers failed\n", + " - DeferredResolver(H3 DNS Resolver(https://dns.genmeta.net:4433/)): deferred resolver lookup failed\n", + " 1. no DNS record found" + ) + ); + } + + #[test] + fn dns_errors_render_repeated_source_messages_without_deduplication() { + let error = DnsErrors { + errors: vec![( + "DeferredResolver(H3 DNS Resolver(https://dns.genmeta.net:4433/))".to_string(), + chained_other_error(TestSourceError::with_source( + "deferred resolver lookup failed", + TestSourceError::with_source( + "deferred resolver lookup failed", + TestSourceError::leaf("no DNS record found"), + ), + )), + )], + }; + + assert_eq!( + error.to_string(), + concat!( + "all DNS resolvers failed\n", + " - DeferredResolver(H3 DNS Resolver(https://dns.genmeta.net:4433/)): deferred resolver lookup failed\n", + " 1. deferred resolver lookup failed\n", + " 2. no DNS record found" + ) + ); + } + + #[cfg(feature = "mdns-resolver")] + #[tokio::test] + async fn resolvers_builder_can_enable_mdns() { + use std::sync::Arc; + + use h3x::dquic::{Network, binds::BindPattern}; + + let network = Network::builder().build(); + let pattern = BindPattern::from_str("iface://v4.lo:0").expect("valid pattern"); + + let resolvers = Resolvers::builder() + .mdns(network, Arc::new(vec![pattern])) + .await + .build(); + + assert!(resolvers.to_string().contains("mDNS resolvers")); + } + + #[cfg(feature = "h3x-resolver")] + #[tokio::test] + async fn resolvers_builder_accepts_custom_h3_base_url() { + use std::sync::Arc; + + let endpoint = Arc::new(h3x::endpoint::H3Endpoint::new( + h3x::dquic::QuicEndpoint::builder().build().await, + )); + + let resolvers = Resolvers::builder() + .h3_with_base_url("https://custom-dns.example:4433", endpoint) + .expect("valid h3 dns url") + .build(); + + assert!(resolvers.to_string().contains("custom-dns.example")); + } + + #[cfg(feature = "http-resolver")] + #[test] + fn resolvers_builder_accepts_custom_http_base_url() { + let resolvers = Resolvers::builder() + .http_with_base_url("https://custom-dns.example") + .expect("valid http dns url") + .build(); + + assert!(resolvers.to_string().contains("custom-dns.example")); + } + + #[cfg(feature = "mdns-resolver")] + #[tokio::test] + async fn mdns_resolvers_bind_installs_mdns_on_null_io_binding() { + use std::sync::Arc; + + use dquic::qinterface::io::IO; + use h3x::dquic::{Network, binds::BindPattern}; + + let network = Network::builder().build(); + let pattern = BindPattern::from_str("iface://v4.lo:0").expect("valid pattern"); + let resolvers = MdnsResolvers::bind( + network.clone(), + Arc::new(vec![pattern.clone()]), + DHTTP_MDNS_SERVICE, + ) + .await; + + let ifaces = resolvers + .bound_interfaces(&pattern) + .expect("bound interfaces"); + assert!(!ifaces.is_empty()); + assert!(ifaces[0].borrow().bound_addr().is_err()); + assert!( + ifaces[0] + .with_components(|components, _| components.exist::()) + ); + } +} diff --git a/src/resolvers/deferred.rs b/src/resolvers/deferred.rs new file mode 100644 index 0000000..a961ac8 --- /dev/null +++ b/src/resolvers/deferred.rs @@ -0,0 +1,218 @@ +use std::{fmt, io}; + +use dquic::qresolve::{Publish, PublishFuture, RecordStream, Resolve, ResolveFuture}; +use futures::FutureExt; +use snafu::{ResultExt, Snafu}; +use tokio::sync::OnceCell; + +#[derive(Debug, Snafu)] +#[snafu(module, visibility(pub))] +pub enum DeferredLookupError { + #[snafu(display("deferred resolver has not been initialized"))] + Uninitialized, + #[snafu(display("deferred resolver lookup failed"))] + Lookup { source: io::Error }, +} + +#[derive(Debug, Snafu)] +#[snafu(module, visibility(pub))] +pub enum DeferredPublishError { + #[snafu(display("deferred resolver has not been initialized"))] + Uninitialized, + #[snafu(display("deferred resolver publish failed"))] + Publish { source: io::Error }, +} + +#[derive(Debug, Snafu)] +#[snafu(module, visibility(pub))] +pub enum SetDeferredResolverError { + #[snafu(display("deferred resolver has already been initialized"))] + AlreadyInitialized, +} + +pub struct DeferredResolver { + inner: OnceCell, +} + +impl fmt::Debug for DeferredResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DeferredResolver") + .field("initialized", &self.inner.get().is_some()) + .finish() + } +} + +impl Default for DeferredResolver { + fn default() -> Self { + Self::new() + } +} + +impl DeferredResolver { + #[must_use] + pub fn new() -> Self { + Self { + inner: OnceCell::new(), + } + } + + pub fn set(&self, resolver: R) -> Result<(), SetDeferredResolverError> { + if self.inner.set(resolver).is_err() { + return set_deferred_resolver_error::AlreadyInitializedSnafu.fail(); + } + Ok(()) + } + + #[must_use] + pub fn get(&self) -> Option<&R> { + self.inner.get() + } +} + +impl fmt::Display for DeferredResolver +where + R: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.inner.get() { + Some(resolver) => write!(f, "DeferredResolver({resolver})"), + None => f.write_str("DeferredResolver(uninitialized)"), + } + } +} + +impl DeferredResolver +where + R: Resolve + 'static, +{ + pub async fn lookup_typed(&self, name: &str) -> Result { + let Some(resolver) = self.get() else { + return deferred_lookup_error::UninitializedSnafu.fail(); + }; + resolver + .lookup(name) + .await + .context(deferred_lookup_error::LookupSnafu) + } +} + +impl Resolve for DeferredResolver +where + R: Resolve + 'static, +{ + fn lookup<'a>(&'a self, name: &'a str) -> ResolveFuture<'a> { + async move { self.lookup_typed(name).await.map_err(io::Error::other) }.boxed() + } +} + +impl DeferredResolver +where + R: Publish + 'static, +{ + pub async fn publish_typed( + &self, + name: &str, + packet: &[u8], + ) -> Result<(), DeferredPublishError> { + let Some(resolver) = self.get() else { + return deferred_publish_error::UninitializedSnafu.fail(); + }; + resolver + .publish(name, packet) + .await + .context(deferred_publish_error::PublishSnafu) + } +} + +impl Publish for DeferredResolver +where + R: Publish + 'static, +{ + fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { + async move { + self.publish_typed(name, packet) + .await + .map_err(io::Error::other) + } + .boxed() + } +} + +#[cfg(test)] +mod tests { + use std::fmt; + + use dquic::{ + qbase::net::addr::EndpointAddr, + qresolve::{Publish, Resolve, Source}, + }; + use futures::{FutureExt, StreamExt}; + + use super::*; + + #[derive(Debug)] + struct TestResolver; + + impl fmt::Display for TestResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("test resolver") + } + } + + impl Resolve for TestResolver { + fn lookup<'a>(&'a self, _name: &'a str) -> dquic::qresolve::ResolveFuture<'a> { + async move { + let endpoint = EndpointAddr::direct("127.0.0.1:4433".parse().unwrap()); + Ok(futures::stream::iter([(Source::System, endpoint)]).boxed()) + } + .boxed() + } + } + + impl Publish for TestResolver { + fn publish<'a>( + &'a self, + _name: &'a str, + _packet: &'a [u8], + ) -> dquic::qresolve::PublishFuture<'a> { + async move { Ok(()) }.boxed() + } + } + + #[tokio::test] + async fn lookup_before_set_returns_typed_uninitialized_error() { + let resolver: DeferredResolver = DeferredResolver::new(); + + let error = match resolver.lookup_typed("example.test").await { + Ok(_) => panic!("uninitialized resolver must not resolve"), + Err(error) => error, + }; + + assert!(matches!(error, DeferredLookupError::Uninitialized)); + } + + #[tokio::test] + async fn lookup_after_set_forwards_to_inner_resolver() { + let resolver = DeferredResolver::new(); + resolver.set(TestResolver).expect("first set succeeds"); + + let mut stream = resolver.lookup_typed("example.test").await.unwrap(); + let (_source, endpoint) = stream.next().await.expect("forwarded endpoint"); + + assert_eq!( + endpoint, + EndpointAddr::direct("127.0.0.1:4433".parse().unwrap()) + ); + } + + #[tokio::test] + async fn publish_after_set_forwards_to_inner_resolver() { + let resolver = DeferredResolver::new(); + resolver.set(TestResolver).expect("first set succeeds"); + + resolver + .publish_typed("example.test", b"packet") + .await + .unwrap(); + } +} diff --git a/src/resolvers/h3.rs b/src/resolvers/h3.rs index 9b6a329..a54cdc6 100644 --- a/src/resolvers/h3.rs +++ b/src/resolvers/h3.rs @@ -1,27 +1,26 @@ -use std::{fmt, io, sync::Arc, time::Duration}; +use std::{convert::Infallible, fmt, io, sync::Arc, time::Duration}; use dashmap::DashMap; -use futures::{FutureExt, StreamExt, TryFutureExt, stream}; +use dquic::{ + qbase::net::addr::EndpointAddr, + qresolve::{Publish, PublishFuture, RecordStream, Resolve, ResolveFuture, Source}, +}; +use futures::{StreamExt, stream}; use h3x::{ - client::Client, - dquic::{ - prelude::ConnectServerError, - qresolve::{ - EndpointAddr, Publish, PublishFuture, RecordStream, Resolve, ResolveFuture, Source, - }, - }, - quic, + dquic::ConnectError, endpoint::H3Endpoint, hyper::RequestError as HyperRequestError, quic, }; -use reqwest::IntoUrl; +use http_body_util::{BodyExt, Empty, Full}; use tokio::time::Instant; use tracing::trace; use url::Url; -use crate::{MdnsPacket, parser::packet::be_packet, wire::be_multi_response}; +use crate::core::{MdnsPacket, parser::packet::be_packet, wire::be_multi_response}; + +const LOOKUP_REQUEST_TIMEOUT: Duration = Duration::from_secs(3); +const LOOKUP_REQUEST_ATTEMPTS: usize = 3; -// Inner struct that holds the actual H3 client and runs on a dedicated thread pub struct H3Resolver { - client: Client, + endpoint: Arc>, base_url: Url, cached_records: DashMap, negative_cache: DashMap, @@ -29,7 +28,7 @@ pub struct H3Resolver { #[derive(Debug)] struct Record { - addrs: Vec, + addrs: Vec, expire: Instant, } @@ -43,24 +42,24 @@ impl fmt::Debug for H3Resolver { impl fmt::Display for H3Resolver { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "H3 DNS Resolver({})", - self.base_url.host_str().unwrap_or("") - ) + write!(f, "H3 DNS Resolver({})", self.base_url) } } #[derive(Debug, snafu::Snafu)] -pub enum Error { +pub enum Error { #[snafu(display("h3 stream error"))] H3Stream { - source: h3x::client::MessageStreamError, + source: h3x::dhttp::message::MessageStreamError, }, + #[snafu(display("failed to connect h3 endpoint"))] + Connect { source: h3x::pool::ConnectError }, #[snafu(display("h3 request error"))] H3Request { - source: h3x::client::RequestError, + source: HyperRequestError, }, + #[snafu(display("h3 request timed out after {timeout:?}"))] + RequestTimeout { timeout: Duration }, #[snafu(display("{status}"))] Status { status: http::StatusCode }, @@ -77,29 +76,96 @@ pub enum Error H3Resolver +impl H3Resolver where C::Error: Send + Sync + 'static, + C::Connection: Send + 'static, { - pub fn new(base_url: impl IntoUrl, client: Client) -> io::Result { - let base_url = base_url - .into_url() - .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + pub fn new( + base_url: impl AsRef, + client: H3Endpoint, + ) -> io::Result { + Self::from_endpoint(base_url, Arc::new(client)) + } + + pub fn from_endpoint( + base_url: impl AsRef, + endpoint: Arc>, + ) -> io::Result { + let base_url = Url::parse(base_url.as_ref()) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidInput, error))?; base_url.host_str().ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, - "Base URL must have a valid host", + "base URL must have a valid host", ) })?; Ok(Self { - client, + endpoint, base_url, cached_records: DashMap::new(), negative_cache: DashMap::new(), }) } + fn connect_error(&self, source: h3x::pool::ConnectError) -> Error { + // H3 DNS resolvers keep a long-lived endpoint. A network transition may + // leave the cached H3 connection with stale QUIC paths, so the next + // attempt must establish a fresh connection instead of reusing it. + self.endpoint.clear_pool(); + Error::Connect { source } + } + + fn request_error(&self, source: HyperRequestError) -> Error { + self.endpoint.clear_pool(); + Error::H3Request { source } + } + + async fn execute_request( + &self, + request: http::Request< + impl http_body::Body + Send + 'static, + >, + ) -> Result< + http::Response< + impl http_body::Body, + >, + Error, + > { + let authority = request + .uri() + .authority() + .expect("h3 dns request URL must include an authority") + .clone(); + tracing::trace!(%authority, "connecting h3 dns endpoint"); + let connection = match self.endpoint.connect(authority.clone()).await { + Ok(connection) => { + tracing::trace!(%authority, "connected h3 dns endpoint"); + connection + } + Err(source) => return Err(self.connect_error(source)), + }; + + let method = request.method().clone(); + let uri = request.uri().clone(); + tracing::trace!(%method, %uri, "executing h3 dns request"); + match connection.execute_hyper_request(request).await { + Ok(response) => { + tracing::trace!( + status = %response.status(), + "h3 dns request response received" + ); + Ok(response) + } + Err(source) => Err(self.request_error(source)), + } + } + + pub fn clear_pool(&self) { + self.endpoint.clear_pool(); + } + pub async fn publish_endpoints( &self, name: &str, @@ -109,9 +175,8 @@ where let bytes = { let endpoints = endpoints .iter() - .filter_map(|ep| match *ep { - h3x::dquic::qresolve::EndpointAddr::Socket(ep) => ep.try_into().ok(), - h3x::dquic::qresolve::EndpointAddr::Ble(..) => None, + .filter_map(|ep| { + crate::core::parser::record::endpoint::EndpointAddr::try_from(*ep).ok() }) .collect(); let mut hosts = std::collections::HashMap::new(); @@ -127,14 +192,16 @@ where let mut url = self.base_url.join("publish").expect("Invalid base URL"); url.set_query(Some(&format!("host={name}"))); let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); - tracing::trace!("h3x publishing packet for {} to {}", name, self.base_url); - let (_, resp) = self - .client - .new_request() - .with_body(bytes::Bytes::copy_from_slice(packet)) - .post(uri) - .await - .map_err(|source| Error::H3Request { source })?; + tracing::trace!( + name, + packet_len = packet.len(), + url = %self.base_url, + "h3x publishing packet" + ); + let request = http::Request::post(uri) + .body(Full::new(bytes::Bytes::copy_from_slice(packet))) + .expect("h3 dns publish request must be valid"); + let resp = self.execute_request(request).await?; if resp.status() != http::StatusCode::OK { return Err(Error::Status { @@ -145,22 +212,82 @@ where Ok(()) } - pub const EXCLUDED_DOMAINS: [&str; 2] = ["dns.genmeta.net", "download.genmeta.net"]; + fn retryable_lookup_error(error: &Error) -> bool { + matches!( + error, + Error::Connect { .. } | Error::H3Request { .. } | Error::H3Stream { .. } + ) + } + + async fn lookup_response(&self, uri: http::Uri) -> Result> { + let request = http::Request::get(uri) + .body(Empty::::new()) + .expect("h3 dns lookup request must be valid"); + let resp = self.execute_request(request).await?; + + tracing::trace!("received response with status {}", resp.status()); + match resp.status() { + http::StatusCode::OK => {} + http::StatusCode::NOT_FOUND => return Err(Error::NoRecordFound), + status => return Err(Error::Status { status }), + } + + match resp.into_body().collect().await { + Ok(response) => Ok(response.to_bytes()), + Err(source) => Err(Error::H3Stream { source }), + } + } + + async fn lookup_response_with_retry( + &self, + uri: http::Uri, + ) -> Result> { + for attempt in 1..=LOOKUP_REQUEST_ATTEMPTS { + match tokio::time::timeout(LOOKUP_REQUEST_TIMEOUT, self.lookup_response(uri.clone())) + .await + { + Ok(Ok(response)) => return Ok(response), + Ok(Err(error)) + if Self::retryable_lookup_error(&error) + && attempt < LOOKUP_REQUEST_ATTEMPTS => + { + self.endpoint.clear_pool(); + tracing::debug!( + attempt, + timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), + "h3 dns lookup failed, retrying" + ); + } + Ok(Err(error)) => return Err(error), + Err(_elapsed) if attempt < LOOKUP_REQUEST_ATTEMPTS => { + self.endpoint.clear_pool(); + tracing::debug!( + attempt, + timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), + "h3 dns lookup timed out, retrying" + ); + } + Err(_elapsed) => { + self.endpoint.clear_pool(); + return Err(Error::RequestTimeout { + timeout: LOOKUP_REQUEST_TIMEOUT, + }); + } + } + } + + unreachable!("lookup retry loop returns on the final attempt") + } pub async fn lookup(&self, name: &str) -> Result> { - use crate::parser::record; - let server = Arc::from(self.base_url.host_str().unwrap_or("")); - let source = Source::Http { server }; + use crate::core::parser::record; + let server = Arc::from(self.base_url.origin().ascii_serialization()); + let source = Source::H3 { server }; let Some(domain) = super::resolvable_name(name) else { return Err(Error::NoRecordFound); }; - // 1. Exclude certain domains from lookup - if Self::EXCLUDED_DOMAINS.contains(&domain) { - return Err(Error::NoRecordFound); - } - let now = Instant::now(); let positive_ttl = Duration::from_secs(10); let negative_ttl = Duration::from_secs(2); @@ -184,55 +311,39 @@ where let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); tracing::trace!("sending lookup request to {}", self.base_url); - let (_req, mut resp) = self - .client - .new_request() - .get(uri) - .await - .map_err(|source| Error::H3Request { source })?; - - tracing::trace!("received response with status {}", resp.status()); - match resp.status() { - http::StatusCode::OK => {} - http::StatusCode::NOT_FOUND => { + let response = match self.lookup_response_with_retry(uri).await { + Ok(response) => response, + Err(Error::NoRecordFound) => { self.negative_cache .insert(domain.to_string(), now + negative_ttl); return Err(Error::NoRecordFound); } - status => return Err(Error::Status { status }), - } - - let response = resp - .read_to_bytes() - .await - .map_err(|source| Error::H3Stream { source })?; + Err(error) => return Err(error), + }; // Server always returns multi-record format. let (_remain, multi) = be_multi_response(response.as_ref()).map_err(|_| Error::ParseMultiResponse)?; - let mut addrs = Vec::new(); + let mut endpoint_records = Vec::new(); for r in multi.records { let (_remain, packet) = be_packet(&r.dns).map_err(|source| Error::ParseRecords { source: source.to_owned(), })?; - addrs.extend( - packet - .answers - .iter() - .filter_map(|answer| match answer.data() { - record::RData::E(ep) => { - let socket_ep = ep.clone().try_into().ok()?; - trace!(?socket_ep, "parsed endpoint from record"); - Some(h3x::dquic::qresolve::EndpointAddr::Socket(socket_ep)) - } - _ => { - tracing::debug!(?answer, "ignored record"); - None - } - }), - ); + endpoint_records.extend(packet.answers.iter().filter_map( + |answer| match answer.data() { + record::RData::E(ep) => Some(ep.clone()), + _ => { + tracing::debug!(?answer, "ignored record"); + None + } + }, + )); + } + let addrs = crate::resolvers::selector::selected_endpoint_addrs(endpoint_records); + for endpoint in &addrs { + trace!(?endpoint, "parsed endpoint from selected record group"); } if addrs.is_empty() { @@ -257,22 +368,123 @@ where pub type H3Publisher = H3Resolver; -impl Publish for H3Publisher +impl Publish for H3Publisher where C::Error: Send + Sync + 'static, + C::Connection: Send + 'static, { fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { - self.publish_packet(name, packet) - .map_err(io::Error::other) - .boxed() + Box::pin(async move { + match self.publish_packet(name, packet).await { + Ok(()) => Ok(()), + Err(error) => Err(io::Error::other(error)), + } + }) } } -impl Resolve for H3Resolver +impl Resolve for H3Resolver where C::Error: Send + Sync + 'static, + C::Connection: Send + 'static, { fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { - self.lookup(name).map_err(io::Error::other).boxed() + Box::pin(async move { + match H3Resolver::lookup(self, name).await { + Ok(stream) => Ok(stream), + Err(error) => Err(io::Error::other(error)), + } + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::resolvers::DHTTP_H3_DNS_SERVER; + + #[test] + fn lookup_retry_budget_leaves_external_timeout_margin() { + let total_budget = LOOKUP_REQUEST_TIMEOUT * LOOKUP_REQUEST_ATTEMPTS as u32; + + assert!( + total_budget <= Duration::from_secs(10), + "h3 lookup must return before common 15s command timeouts so callers can retry" + ); + } + + #[tokio::test] + async fn cached_lookup_reports_h3_dns_source() { + let endpoint = Arc::new(h3x::endpoint::H3Endpoint::new( + h3x::dquic::QuicEndpoint::builder().build().await, + )); + let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); + resolver.cached_records.insert( + "car.lab.dhttp.net".to_owned(), + Record { + addrs: vec![EndpointAddr::direct("192.168.5.78:41748".parse().unwrap())], + expire: Instant::now() + Duration::from_secs(60), + }, + ); + + let mut records = resolver.lookup("car.lab.dhttp.net").await.unwrap(); + let (source, endpoint) = records.next().await.unwrap(); + + assert_eq!( + source, + Source::H3 { + server: Arc::from(DHTTP_H3_DNS_SERVER) + } + ); + assert_eq!( + endpoint, + EndpointAddr::direct("192.168.5.78:41748".parse().unwrap()) + ); + } + + #[tokio::test] + async fn cached_dns_genmeta_net_record_is_returned() { + let endpoint = Arc::new(h3x::endpoint::H3Endpoint::new( + h3x::dquic::QuicEndpoint::builder().build().await, + )); + let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); + resolver.cached_records.insert( + "dns.genmeta.net".to_owned(), + Record { + addrs: vec![EndpointAddr::direct("192.0.2.53:4433".parse().unwrap())], + expire: Instant::now() + Duration::from_secs(60), + }, + ); + + let mut records = resolver.lookup("dns.genmeta.net").await.unwrap(); + let (_source, endpoint) = records.next().await.unwrap(); + + assert_eq!( + endpoint, + EndpointAddr::direct("192.0.2.53:4433".parse().unwrap()) + ); + } + + #[tokio::test] + async fn cached_lookup_uses_e_record_port_not_input_port() { + let endpoint = Arc::new(h3x::endpoint::H3Endpoint::new( + h3x::dquic::QuicEndpoint::builder().build().await, + )); + let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); + resolver.cached_records.insert( + "nat.genmeta.net".to_owned(), + Record { + addrs: vec![EndpointAddr::direct("192.0.2.10:21000".parse().unwrap())], + expire: Instant::now() + Duration::from_secs(60), + }, + ); + + let mut records = resolver.lookup("nat.genmeta.net:20004").await.unwrap(); + let (_source, endpoint) = records.next().await.unwrap(); + + assert_eq!( + endpoint, + EndpointAddr::direct("192.0.2.10:21000".parse().unwrap()) + ); } } diff --git a/src/resolvers/http.rs b/src/resolvers/http.rs index f2ab92b..03984d9 100644 --- a/src/resolvers/http.rs +++ b/src/resolvers/http.rs @@ -1,20 +1,19 @@ -use std::{ - fmt::Display, - io, - sync::{Arc, LazyLock}, -}; +use std::{fmt::Display, io, sync::Arc}; use dashmap::DashMap; +use dquic::{ + qbase::net::addr::EndpointAddr, + qresolve::{Publish, PublishFuture, Resolve, ResolveFuture, Source}, +}; use futures::{StreamExt, TryFutureExt, stream}; -use h3x::dquic::qresolve::{Publish, PublishFuture, Resolve, ResolveFuture, Source}; use reqwest::{Client, IntoUrl, StatusCode, Url}; use tokio::time::Instant; -use crate::parser::packet::be_packet; +use crate::core::parser::packet::be_packet; #[derive(Debug)] struct Record { - addrs: Vec, + addrs: Vec, expire: Instant, } @@ -30,7 +29,7 @@ impl Display for HttpResolver { write!( f, "Http DNS({})", - self.base_url.host_str().expect("Cheked in constructor") + self.base_url.host_str().expect("checked in constructor") ) } } @@ -43,28 +42,48 @@ impl HttpResolver { base_url.host_str().ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, - "Base URL must have a valid host", + "base URL must have a valid host", ) })?; - static HTTP_CLIENT: LazyLock = LazyLock::new(|| { - Client::builder() - .build() - // with certs? - .expect("Failed to build HTTP client for HttpResolver") - }); - Ok(Self { - http_client: HTTP_CLIENT.clone(), + http_client: build_http_client()?, base_url, cached_records: DashMap::new(), }) } } +fn build_http_client() -> io::Result { + let native_certs = rustls_native_certs::load_native_certs(); + for error in &native_certs.errors { + let report = snafu::Report::from_error(error); + tracing::warn!(error = %report, "failed to load native root certificate"); + } + + let mut root_store = rustls::RootCertStore::empty(); + let (valid_roots, invalid_roots) = root_store.add_parsable_certificates(native_certs.certs); + if invalid_roots > 0 { + tracing::debug!(invalid_roots, "ignored invalid native root certificates"); + } + if valid_roots == 0 { + tracing::warn!("no native root certificates loaded for http resolver"); + } + + let mut tls = rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + tls.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + + Client::builder() + .use_preconfigured_tls(tls) + .build() + .map_err(io::Error::other) +} + #[derive(Debug, snafu::Snafu)] enum Error { - #[snafu(display("HTTP request failed"))] + #[snafu(display("http request failed"))] Reqwest { source: reqwest::Error }, #[snafu(display("{status}"))] @@ -96,18 +115,16 @@ impl Publish for HttpResolver { Box::pin(async move { let mut url = self.base_url.join("publish").expect("Invalid base URL"); url.set_query(Some(&format!("host={name}"))); - let client = reqwest::Client::new(); - let response = client + let response = self + .http_client .post(url) .header("Content-Type", "application/octet-stream") .body(packet.to_vec()) .send() .await - .map_err(|e| io::Error::other(e.to_string()))?; + .map_err(io::Error::other)?; - let _response = response - .error_for_status() - .map_err(|e| io::Error::other(e.to_string()))?; + let _response = response.error_for_status().map_err(io::Error::other)?; Ok(()) }) } @@ -124,14 +141,14 @@ impl Resolve for HttpResolver { let server = Arc::from(self.base_url.host_str().unwrap_or("")); let soource = Source::Http { server }; - use crate::parser::record; + use crate::core::parser::record; self.cached_records .retain(|_host, Record { expire, .. }| *expire < now); if let Some(record) = self.cached_records.get(domain) { let endpoint_addrs: Vec<_> = record .addrs .iter() - .map(|e: &h3x::dquic::qresolve::EndpointAddr| (soource.clone(), *e)) + .map(|endpoint: &EndpointAddr| (soource.clone(), *endpoint)) .collect(); return Ok(stream::iter(endpoint_addrs).boxed()); } @@ -148,20 +165,17 @@ impl Resolve for HttpResolver { source: source.to_owned(), })?; - let addrs = packet + let endpoints = packet .answers .iter() .filter_map(|answer| match answer.data() { - record::RData::E(ep) => { - let socket_ep = ep.clone().try_into().ok()?; - Some(h3x::dquic::qresolve::EndpointAddr::Socket(socket_ep)) - } + record::RData::E(ep) => Some(ep.clone()), _ => { tracing::debug!(?answer, "ignored record"); None } - }) - .collect::>(); + }); + let addrs = crate::resolvers::selector::selected_endpoint_addrs(endpoints); if addrs.is_empty() { return Err(Error::NoRecordFound); } diff --git a/src/resolvers/mdns.rs b/src/resolvers/mdns.rs deleted file mode 100644 index 19404f2..0000000 --- a/src/resolvers/mdns.rs +++ /dev/null @@ -1,198 +0,0 @@ -use std::{ - fmt, io, - net::{IpAddr, SocketAddr}, - sync::Arc, -}; - -use dashmap::DashMap; -use futures::{ - FutureExt, Stream, StreamExt, TryFutureExt, future, - stream::{self, FuturesUnordered}, -}; -use h3x::dquic::{ - qinterface::{BindInterface, WeakInterface, bind_uri::BindUri, io::IO}, - qresolve::{EndpointAddr, Family, RecordStream, ResolveFuture, SocketEndpointAddr, Source}, -}; - -use super::{Publish, Resolve}; -pub use crate::mdns::Mdns as MdnsResolver; -use crate::{parser::packet::Packet, protocol::MdnsProtocol}; - -impl MdnsResolver { - pub fn source(&self) -> Source { - Source::Mdns { - nic: self.bound_nic().into(), - family: match self.bound_ip() { - IpAddr::V4(..) => Family::V4, - IpAddr::V6(..) => Family::V6, - }, - } - } -} - -impl fmt::Display for MdnsResolver { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.source(), f) - } -} - -impl Publish for MdnsResolver { - fn publish<'a>( - &'a self, - name: &'a str, - packet: &'a [u8], - ) -> h3x::dquic::qresolve::PublishFuture<'a> { - use crate::parser::{packet::be_packet, record::RData}; - let endpoints = be_packet(packet) - .map(|(_, pkt)| { - pkt.answers - .iter() - .filter_map(|rr| match rr.data() { - RData::E(ep) => Some(ep.clone()), - _ => None, - }) - .collect::>() - }) - .unwrap_or_default(); - self.insert_host(name.to_string(), endpoints); - Box::pin(future::ready(Ok(()))) - } -} - -impl Resolve for MdnsResolver { - fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { - let source = self.source(); - self.query(name.to_owned()) - .map_ok(move |list| { - stream::iter(list.into_iter().filter_map(move |ep| { - let ep = EndpointAddr::Socket(SocketEndpointAddr::try_from(ep).ok()?); - Some((source.clone(), ep)) - })) - .boxed() - }) - .boxed() - } -} - -#[derive(Default, Clone, Debug)] -pub struct MdnsResolvers { - ifaces: DashMap, -} - -impl fmt::Display for MdnsResolvers { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "MDNS Resolvers") - } -} - -impl MdnsResolvers { - pub fn new() -> Self { - Self::default() - } - - pub fn insert_iface(&self, iface: BindInterface) { - let Some(iface) = iface.with_components(|component, iface| { - component.exist::().then(|| iface.downgrade()) - }) else { - return; - }; - self.ifaces.insert(iface.bind_uri(), iface); - } - - fn for_each_resolver(&self, mut f: impl FnMut(&MdnsResolver)) { - self.ifaces.retain(|_, iface| { - iface - .upgrade() - .ok() - .and_then(|iface| { - iface.bind_interface().with_components(|components, _| { - components.get::().map(&mut f) - }) - }) - .is_some() - }); - } - - pub async fn query(&self, name: &str) -> io::Result { - let mut lookup_futures = FuturesUnordered::new(); - self.for_each_resolver(|resolver| { - let source = resolver.source(); - lookup_futures.push(resolver.query(name.to_owned()).map_ok(move |eps| { - stream::iter(eps.into_iter().filter_map(move |ep| { - let ep = EndpointAddr::Socket(SocketEndpointAddr::try_from(ep).ok()?); - Some((source.clone(), ep)) - })) - })); - }); - - let mut last_error = None; - let no_resolver = || io::Error::other("no mdns resolvers available"); - let stream = loop { - match lookup_futures.next().await { - Some(Ok(stream)) => break stream, - Some(Err(error)) => last_error = Some(error), - None => return Err(last_error.unwrap_or_else(no_resolver)), - } - }; - - Ok(stream - .chain(lookup_futures.flat_map(stream::iter).flatten()) - .boxed()) - } - - pub fn merge(&self, other: &Self) { - other.ifaces.iter().for_each(|entry| { - self.ifaces - .entry(entry.key().clone()) - .or_insert_with(|| entry.value().clone()); - }); - } - - /// Discover mDNS broadcasts from all active resolvers. - /// - /// Returns a stream of `(SocketAddr, Packet)` pairs by polling all - /// underlying protocols concurrently. Unlike per-resolver `discover()`, - /// this uses a single `Box::pin` allocation for the combined stream. - pub fn discover(&self) -> impl Stream + use<> { - let mut protos = Vec::new(); - self.for_each_resolver(|resolver| { - protos.push(resolver.protocol()); - }); - - async fn receive_one( - proto: Arc, - ) -> Option<((SocketAddr, Packet), Arc)> { - let result = proto.receive_boardcast().await.ok()?; - Some((result, proto)) - } - - let mut pending = protos - .into_iter() - .map(receive_one) - .collect::>(); - - Box::pin(stream::poll_fn(move |cx| { - use std::task::Poll; - loop { - match pending.poll_next_unpin(cx) { - Poll::Ready(Some(Some((item, proto)))) => { - pending.push(receive_one(proto)); - return Poll::Ready(Some(item)); - } - Poll::Ready(Some(None)) => { - // This resolver's protocol disconnected, skip it - continue; - } - Poll::Ready(None) => return Poll::Ready(None), - Poll::Pending => return Poll::Pending, - } - } - })) - } -} - -impl Resolve for MdnsResolvers { - fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { - self.query(name).boxed() - } -} diff --git a/src/resolvers/selector.rs b/src/resolvers/selector.rs new file mode 100644 index 0000000..87017f4 --- /dev/null +++ b/src/resolvers/selector.rs @@ -0,0 +1,134 @@ +use dhttp_identity::certificate::{CertificateChainKey, CertificateChainKind}; +use dquic::qbase::net::addr::EndpointAddr as DquicEndpointAddr; + +use crate::core::parser::record::endpoint::EndpointAddr as DnsEndpointAddr; + +pub(crate) fn selected_endpoint_addrs( + records: impl IntoIterator, +) -> Vec { + selected_endpoint_records(records.into_iter().map(|record| ((), record))) + .into_iter() + .map(|((), endpoint)| endpoint) + .collect() +} + +pub(crate) fn selected_endpoint_records( + records: impl IntoIterator, +) -> Vec<(T, DquicEndpointAddr)> { + let mut groups: Vec<(CertificateChainKey, Vec<(T, DquicEndpointAddr)>)> = Vec::new(); + + for (tag, record) in records { + let Ok(selector) = record.certificate_chain_key() else { + continue; + }; + let Ok(endpoint) = DquicEndpointAddr::try_from(record) else { + continue; + }; + + if let Some((_key, endpoints)) = groups.iter_mut().find(|(key, _)| *key == selector) { + endpoints.push((tag, endpoint)); + } else { + groups.push((selector, vec![(tag, endpoint)])); + } + } + + let selected = groups + .iter() + .position(|(key, endpoints)| { + key.kind() == CertificateChainKind::Primary && !endpoints.is_empty() + }) + .or_else(|| { + groups + .iter() + .position(|(_key, endpoints)| !endpoints.is_empty()) + }); + + selected + .map(|index| groups.swap_remove(index).1) + .unwrap_or_default() +} + +#[cfg(test)] +mod tests { + use crate::core::parser::record::endpoint::EndpointAddr; + + fn direct(addr: &str, main: bool, sequence: u64) -> EndpointAddr { + let mut endpoint = match addr.parse().unwrap() { + std::net::SocketAddr::V4(addr) => EndpointAddr::direct_v4(addr), + std::net::SocketAddr::V6(addr) => EndpointAddr::direct_v6(addr), + }; + endpoint.set_main(main); + endpoint.set_sequence(sequence); + endpoint + } + + #[test] + fn selected_endpoint_addrs_prefers_primary_group() { + let secondary = direct("192.0.2.20:4433", false, 0); + let primary_a = direct("192.0.2.10:4433", true, 2); + let primary_b = direct("192.0.2.11:4433", true, 2); + + let selected = super::selected_endpoint_addrs([secondary, primary_a, primary_b]); + + assert_eq!(selected.len(), 2); + assert_eq!( + selected[0], + dquic::qbase::net::addr::EndpointAddr::direct("192.0.2.10:4433".parse().unwrap()) + ); + assert_eq!( + selected[1], + dquic::qbase::net::addr::EndpointAddr::direct("192.0.2.11:4433".parse().unwrap()) + ); + } + + #[test] + fn selected_endpoint_addrs_uses_one_secondary_group_when_no_primary_exists() { + let secondary_a = direct("192.0.2.20:4433", false, 5); + let secondary_b = direct("192.0.2.21:4433", false, 5); + let other_secondary = direct("192.0.2.30:4433", false, 6); + + let selected = super::selected_endpoint_addrs([secondary_a, secondary_b, other_secondary]); + + assert_eq!(selected.len(), 2); + assert_eq!( + selected[0], + dquic::qbase::net::addr::EndpointAddr::direct("192.0.2.20:4433".parse().unwrap()) + ); + assert_eq!( + selected[1], + dquic::qbase::net::addr::EndpointAddr::direct("192.0.2.21:4433".parse().unwrap()) + ); + } + + #[test] + fn selected_endpoint_addrs_treats_missing_sequence_as_zero() { + let mut first = direct("192.0.2.40:4433", true, 0); + first.set_clustered(false); + let second = direct("192.0.2.41:4433", true, 0); + + let selected = super::selected_endpoint_addrs([first, second]); + + assert_eq!(selected.len(), 2); + } + + #[test] + fn selected_endpoint_records_uses_one_group_across_sources() { + let selected = super::selected_endpoint_records([ + ("wifi", direct("192.0.2.50:4433", true, 3)), + ("ethernet", direct("192.0.2.51:4433", true, 4)), + ("wifi", direct("192.0.2.52:4433", true, 3)), + ]); + + assert_eq!(selected.len(), 2); + assert_eq!(selected[0].0, "wifi"); + assert_eq!( + selected[0].1, + dquic::qbase::net::addr::EndpointAddr::direct("192.0.2.50:4433".parse().unwrap()) + ); + assert_eq!(selected[1].0, "wifi"); + assert_eq!( + selected[1].1, + dquic::qbase::net::addr::EndpointAddr::direct("192.0.2.52:4433".parse().unwrap()) + ); + } +} diff --git a/src/resolvers/weak.rs b/src/resolvers/weak.rs new file mode 100644 index 0000000..965df02 --- /dev/null +++ b/src/resolvers/weak.rs @@ -0,0 +1,201 @@ +use std::{ + fmt, io, + sync::{Arc, Weak}, +}; + +use dquic::qresolve::{Publish, PublishFuture, RecordStream, Resolve, ResolveFuture}; +use futures::FutureExt; +use snafu::{ResultExt, Snafu}; + +#[derive(Debug, Snafu)] +#[snafu(module, visibility(pub))] +pub enum WeakLookupError { + #[snafu(display("weak resolver target has been dropped"))] + Dropped, + #[snafu(display("weak resolver lookup failed"))] + Lookup { source: io::Error }, +} + +#[derive(Debug, Snafu)] +#[snafu(module, visibility(pub))] +pub enum WeakPublishError { + #[snafu(display("weak resolver target has been dropped"))] + Dropped, + #[snafu(display("weak resolver publish failed"))] + Publish { source: io::Error }, +} + +pub struct WeakResolver { + inner: Weak, +} + +impl fmt::Debug for WeakResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WeakResolver") + .field("alive", &self.inner.strong_count().gt(&0)) + .finish() + } +} + +impl Clone for WeakResolver { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +impl WeakResolver { + #[must_use] + pub fn new(inner: Weak) -> Self { + Self { inner } + } + + pub fn upgrade(&self) -> Result, WeakLookupError> { + self.inner.upgrade().ok_or(WeakLookupError::Dropped) + } +} + +impl fmt::Display for WeakResolver +where + R: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.inner.upgrade() { + Some(resolver) => write!(f, "WeakResolver({resolver})"), + None => f.write_str("WeakResolver(dropped)"), + } + } +} + +impl WeakResolver +where + R: Resolve + 'static, +{ + pub async fn lookup_typed(&self, name: &str) -> Result { + let resolver = self.upgrade()?; + resolver + .lookup(name) + .await + .context(weak_lookup_error::LookupSnafu) + } +} + +impl Resolve for WeakResolver +where + R: Resolve + 'static, +{ + fn lookup<'a>(&'a self, name: &'a str) -> ResolveFuture<'a> { + async move { self.lookup_typed(name).await.map_err(io::Error::other) }.boxed() + } +} + +impl WeakResolver +where + R: Publish + 'static, +{ + pub async fn publish_typed(&self, name: &str, packet: &[u8]) -> Result<(), WeakPublishError> { + let Some(resolver) = self.inner.upgrade() else { + return weak_publish_error::DroppedSnafu.fail(); + }; + resolver + .publish(name, packet) + .await + .context(weak_publish_error::PublishSnafu) + } +} + +impl Publish for WeakResolver +where + R: Publish + 'static, +{ + fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { + async move { + self.publish_typed(name, packet) + .await + .map_err(io::Error::other) + } + .boxed() + } +} + +#[cfg(test)] +mod tests { + use std::{fmt, sync::Arc}; + + use dquic::{ + qbase::net::addr::EndpointAddr, + qresolve::{Publish, Resolve, Source}, + }; + use futures::{FutureExt, StreamExt}; + + use super::*; + + #[derive(Debug)] + struct TestResolver; + + impl fmt::Display for TestResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("test resolver") + } + } + + impl Resolve for TestResolver { + fn lookup<'a>(&'a self, _name: &'a str) -> dquic::qresolve::ResolveFuture<'a> { + async move { + let endpoint = EndpointAddr::direct("127.0.0.1:4433".parse().unwrap()); + Ok(futures::stream::iter([(Source::System, endpoint)]).boxed()) + } + .boxed() + } + } + + impl Publish for TestResolver { + fn publish<'a>( + &'a self, + _name: &'a str, + _packet: &'a [u8], + ) -> dquic::qresolve::PublishFuture<'a> { + async move { Ok(()) }.boxed() + } + } + + #[tokio::test] + async fn lookup_after_target_drop_returns_typed_error() { + let strong = Arc::new(TestResolver); + let resolver = WeakResolver::new(Arc::downgrade(&strong)); + drop(strong); + + let error = match resolver.lookup_typed("example.test").await { + Ok(_) => panic!("dropped weak resolver must not resolve"), + Err(error) => error, + }; + + assert!(matches!(error, WeakLookupError::Dropped)); + } + + #[tokio::test] + async fn lookup_forwards_while_target_is_alive() { + let strong = Arc::new(TestResolver); + let resolver = WeakResolver::new(Arc::downgrade(&strong)); + + let mut stream = resolver.lookup_typed("example.test").await.unwrap(); + let (_source, endpoint) = stream.next().await.expect("forwarded endpoint"); + + assert_eq!( + endpoint, + EndpointAddr::direct("127.0.0.1:4433".parse().unwrap()) + ); + } + + #[tokio::test] + async fn publish_forwards_while_target_is_alive() { + let strong = Arc::new(TestResolver); + let resolver = WeakResolver::new(Arc::downgrade(&strong)); + + resolver + .publish_typed("example.test", b"packet") + .await + .unwrap(); + } +} diff --git a/src/wire.rs b/src/wire.rs deleted file mode 100644 index 25c9719..0000000 --- a/src/wire.rs +++ /dev/null @@ -1,57 +0,0 @@ -/// HTTP multi-record response wire format shared between server and all clients. -/// -/// Wire layout (big-endian, contiguous): -/// ```text -/// +-----------+ (repeated `count` times) -/// | count | +-----------+------+-----------+------+ -/// | u32 BE | | dns_len | dns | cert_len | cert | -/// +-----------+ | u32 BE | ... | u32 BE | ... | -/// +-----------+------+-----------+------+ -/// ``` -use nom::{IResult, bytes::streaming::take, number::streaming::be_u32}; - -/// One DNS + certificate pair inside a [`MultiResponse`]. -#[derive(Debug, Clone)] -pub struct ResponseRecord { - /// Serialised DNS packet bytes. - pub dns: Vec, - /// DER-encoded leaf certificate of the publisher (may be empty). - pub cert: Vec, -} - -impl ResponseRecord { - /// SHA-256 fingerprint of the publisher certificate, as a lowercase hex string. - /// Returns `None` when the cert field is empty. - pub fn cert_fingerprint_hex(&self) -> Option { - if self.cert.is_empty() { - return None; - } - use ring::digest::{SHA256, digest}; - let d = digest(&SHA256, &self.cert); - Some(d.as_ref().iter().map(|b| format!("{b:02x}")).collect()) - } -} - -/// Decoded HTTP response body carrying one or more DNS records. -#[derive(Debug, Clone)] -pub struct MultiResponse { - pub records: Vec, -} - -/// nom parser for [`MultiResponse`]. -pub fn be_multi_response(input: &[u8]) -> IResult<&[u8], MultiResponse> { - let (mut input, count) = be_u32(input)?; - let mut records = Vec::with_capacity(count as usize); - for _ in 0..count { - let (rest, dns_len) = be_u32(input)?; - let (rest, dns) = take(dns_len as usize)(rest)?; - let (rest, cert_len) = be_u32(rest)?; - let (rest, cert) = take(cert_len as usize)(rest)?; - records.push(ResponseRecord { - dns: dns.to_vec(), - cert: cert.to_vec(), - }); - input = rest; - } - Ok((input, MultiResponse { records })) -} diff --git a/tests/fixtures/malformed.der b/tests/fixtures/malformed.der new file mode 100644 index 0000000..d3f98e0 Binary files /dev/null and b/tests/fixtures/malformed.der differ diff --git a/tests/fixtures/missing.der b/tests/fixtures/missing.der new file mode 100644 index 0000000..18af2d1 Binary files /dev/null and b/tests/fixtures/missing.der differ diff --git a/tests/fixtures/valid.der b/tests/fixtures/valid.der new file mode 100644 index 0000000..d5566dd Binary files /dev/null and b/tests/fixtures/valid.der differ