From c37ca7ad9450f925cbfbc407a7b9cf17764c8e18 Mon Sep 17 00:00:00 2001 From: metah3m Date: Wed, 3 Jun 2026 15:31:22 +0800 Subject: [PATCH 01/29] feat(lookup): enhance lookup functionality with source IP handling and sorting logic --- src/bin/ddns-server/lookup.rs | 133 +++++++++++++++++++++++++++++++--- 1 file changed, 121 insertions(+), 12 deletions(-) diff --git a/src/bin/ddns-server/lookup.rs b/src/bin/ddns-server/lookup.rs index e8c2616..bf131d9 100644 --- a/src/bin/ddns-server/lookup.rs +++ b/src/bin/ddns-server/lookup.rs @@ -1,7 +1,10 @@ use std::{ + any::Any, + cmp::Ordering, collections::{HashMap, HashSet}, convert::Infallible, - net::SocketAddr, + net::{IpAddr, SocketAddr}, + sync::Arc, }; use ddns::core::{ @@ -10,6 +13,7 @@ use ddns::core::{ wire::MultiResponse, }; use deadpool_redis::redis::{self, AsyncCommands}; +use h3x::{connection::ConnectionState, quic}; use h3x::dhttp::message::MessageStreamError; use http_body_util::{Full, combinators::UnsyncBoxBody}; use tracing::debug; @@ -71,6 +75,99 @@ fn normalize_lookup_records(records: Vec) -> Vec { normalized } +fn lookup_endpoint(dns_bytes: &[u8]) -> Option<(SocketAddr, Option)> { + let (_, packet) = be_packet(dns_bytes).ok()?; + packet.answers.iter().find_map(|answer| match answer.data() { + RData::E(endpoint) => Some((endpoint.addr(), endpoint.load())), + _ => None, + }) +} + +fn common_prefix_len(source: IpAddr, target: IpAddr) -> u32 { + fn bytes_prefix_len(left: &[u8], right: &[u8]) -> u32 { + let mut matched = 0; + for (l, r) in left.iter().zip(right.iter()) { + let diff = l ^ r; + if diff == 0 { + matched += 8; + continue; + } + matched += (diff as u32).leading_zeros().saturating_sub(24); + break; + } + matched + } + + match (source, target) { + (IpAddr::V4(source), IpAddr::V4(target)) => { + bytes_prefix_len(&source.octets(), &target.octets()) + } + (IpAddr::V6(source), IpAddr::V6(target)) => { + bytes_prefix_len(&source.octets(), &target.octets()) + } + _ => 0, + } +} + +fn sort_lookup_records(records: Vec, source_ip: Option) -> Vec { + let mut decorated = records + .into_iter() + .enumerate() + .map(|(index, record)| { + let sort_key = lookup_endpoint(&record.0).map(|(endpoint, load)| { + let (family_match, prefix_len) = match source_ip { + Some(source_ip) if source_ip.is_ipv4() == endpoint.ip().is_ipv4() => { + (true, common_prefix_len(source_ip, endpoint.ip())) + } + Some(_) => (false, 0), + None => (false, 0), + }; + + (family_match, prefix_len, load) + }); + (sort_key, index, record) + }) + .collect::>(); + + decorated.sort_by(|(left_key, left_index, _), (right_key, right_index, _)| { + match (left_key, right_key) { + (Some((left_family, left_prefix, left_load)), Some((right_family, right_prefix, right_load))) => right_family + .cmp(left_family) + .then_with(|| right_prefix.cmp(left_prefix)) + .then_with(|| match (left_load, right_load) { + (Some(left), Some(right)) => left.partial_cmp(right).unwrap_or(Ordering::Equal), + (Some(_), None) => Ordering::Less, + (None, Some(_)) => Ordering::Greater, + (None, None) => Ordering::Equal, + }), + (Some(_), None) => Ordering::Less, + (None, Some(_)) => Ordering::Greater, + (None, None) => Ordering::Equal, + } + .then_with(|| left_index.cmp(right_index)) + }); + + decorated + .into_iter() + .map(|(_, _, record)| record) + .collect() +} + +fn request_source_ip(request: &Request) -> Option { + let connection = request + .extensions() + .get::>>()? + .clone(); + let quic = connection.quic(); + let dquic = (quic.as_ref() as &dyn Any).downcast_ref::()?; + let ctx = dquic.path_context().ok()?; + + ctx.paths::>() + .into_iter() + .next() + .map(|(pathway, _)| pathway.remote().addr().ip()) +} + // --------------------------------------------------------------------------- // Core lookup logic // --------------------------------------------------------------------------- @@ -79,17 +176,19 @@ pub async fn perform_lookup( state: &AppState, host: &str, limit: Option, + source_ip: Option, ) -> Result { let host = normalize_host(host)?; - perform_lookup_multi(state, &host, limit).await + perform_lookup_multi(state, &host, limit, source_ip).await } async fn perform_lookup_multi( state: &AppState, host: &str, limit: Option, + source_ip: Option, ) -> Result { - let mut records = match &state.storage { + let dynamic_records = match &state.storage { Storage::Redis(pool) => { let mut conn = pool.get().await.map_err(|e| AppError::Redis { message: e.to_string(), @@ -108,10 +207,9 @@ async fn perform_lookup_multi( .await .unwrap_or(()); - // Fetch all remaining, newest first (highest score = most recently published) - let count: isize = limit.map(|l| l as isize).unwrap_or(-1); + // Fetch all remaining active dynamic records; scheduling is applied after decode. let members: Vec> = conn - .zrevrange(&set_key, 0isize, if count < 0 { -1 } else { count - 1 }) + .zrevrange(&set_key, 0isize, -1isize) .await .map_err(|e| AppError::Redis { message: e.to_string(), @@ -138,10 +236,9 @@ async fn perform_lookup_multi( // Evict expired entries in-place. entry.retain(|_, r| r.expire > now); // Sort newest-first by published_at. - let take = limit.unwrap_or(entry.len()).min(entry.len()); let mut records: Vec<_> = entry.values().collect(); records.sort_by_key(|b| std::cmp::Reverse(b.published_at)); - records[..take] + records .iter() .map(|r| (r.dns_bytes.clone(), r.cert_bytes.clone())) .collect::>() @@ -151,11 +248,22 @@ async fn perform_lookup_multi( } }; - if let Some(seed_records) = state.seed_records.get(host) { - records.extend(seed_records.iter().cloned()); + let mut records = sort_lookup_records(normalize_lookup_records(dynamic_records), source_ip); + + let should_append_seeds = records.is_empty() || limit.is_some_and(|max| records.len() < max); + if should_append_seeds + && let Some(seed_records) = state.seed_records.get(host) + { + let seeds = sort_lookup_records(seed_records.iter().cloned().collect(), source_ip); + records.extend(seeds); } let records = normalize_lookup_records(records); + let records = if let Some(limit) = limit { + records.into_iter().take(limit).collect::>() + } else { + records + }; if records.is_empty() { Ok(LookupResult::NotFound) @@ -206,15 +314,16 @@ pub async fn lookup_with_cert(state: AppState, request: Request) -> Response { let Some(host) = params.get("host") else { return write_error(AppError::MissingHostParam); }; + let source_ip = request_source_ip(&request); let limit: Option = params .get("limit") .and_then(|v| v.parse::().ok()) .filter(|&n| n > 0); - debug!(host = %host, limit, "lookup.request"); + debug!(host = %host, limit, ?source_ip, "lookup.request"); - match perform_lookup(&state, host, limit).await { + match perform_lookup(&state, host, limit, source_ip).await { Ok(LookupResult::NotFound) => { debug!(host = %host, "lookup.not_found"); body_response( From 47904b02f44fbc59dcdf0bd5bb4d8879adba5af9 Mon Sep 17 00:00:00 2001 From: metah3m Date: Wed, 3 Jun 2026 17:29:16 +0800 Subject: [PATCH 02/29] feat: add GeoIP support for geo-routing - Introduced GeoResolver for handling GeoIP lookups using MaxMind's GeoLite2 databases. - Implemented geo-routing logic to log freshness of GeoIP databases and ensure both city and ASN databases are configured together. - Enhanced record publishing to include geographical tags (countries and ASNs) for DNS records. - Updated Redis storage to manage records with geographical indexing. - Added a script to automate the update of GeoLite2 databases. - Included tests for GeoResolver and its integration with DNS record publishing. --- .gitignore | 3 +- Cargo.toml | 8 + README.md | 18 + scripts/update-geolite-mmdb.sh | 102 +++++ server.toml | 6 + src/bin/ddns-server/config.rs | 10 + src/bin/ddns-server/geo.rs | 196 +++++++++ src/bin/ddns-server/lookup.rs | 727 +++++++++++++++++++++++++++++---- src/bin/ddns-server/main.rs | 118 +++++- src/bin/ddns-server/publish.rs | 106 +++-- src/bin/ddns-server/storage.rs | 289 +++++++++++-- 11 files changed, 1436 insertions(+), 147 deletions(-) create mode 100755 scripts/update-geolite-mmdb.sh create mode 100644 src/bin/ddns-server/geo.rs diff --git a/.gitignore b/.gitignore index 57506c5..4a376e8 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ Cargo.lock *.log build - .DS_Store .vscode/ +/geoip +/docs diff --git a/Cargo.toml b/Cargo.toml index b84ddea..a6cf4b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -63,6 +63,7 @@ url = { version = "2", optional = true } clap = { version = "4", features = ["derive"], optional = true } deadpool-redis = { version = "0.23", optional = true } idna = { version = "1", optional = true } +maxminddb = { version = "0.26", optional = true } serde = { version = "1", features = ["derive"], optional = true } toml = { version = "1", optional = true } tower-service = { version = "0.3", optional = true } @@ -88,6 +89,7 @@ server = [ "dep:clap", "dep:deadpool-redis", "dep:idna", + "dep:maxminddb", "dep:serde", "dep:toml", "dep:tower-service", @@ -124,3 +126,9 @@ required-features = ["h3x-resolver"] name = "query" path = "examples/query.rs" required-features = ["h3x-resolver"] + +[patch.crates-io] +proc-macro-error2 = { path = "patches/proc-macro-error2" } + +[patch."https://github.com/genmeta/h3x.git"] +h3x = { path = "../h3x-endpoint-local" } diff --git a/README.md b/README.md index 071aaaa..35be4e3 100644 --- a/README.md +++ b/README.md @@ -132,6 +132,24 @@ Start the server with the `server` feature: cargo run --bin ddns-server --features server -- --config server.toml ``` +The server can optionally enable GEO-aware lookup ordering with local MaxMind +GeoLite2 City and ASN databases. When both `geoip_city_db` and `geoip_asn_db` +are configured, lookups prefer same-country and same-ASN endpoints first, then +fall back to address family, endpoint load, and city-distance tie-breaking for +sufficiently accurate records. + +To update those databases on a server, use [scripts/update-geolite-mmdb.sh](scripts/update-geolite-mmdb.sh). +It wraps `geoipupdate` and downloads both `GeoLite2-City.mmdb` and +`GeoLite2-ASN.mmdb` into one directory: + +```bash +MAXMIND_ACCOUNT_ID=12345 \ +MAXMIND_LICENSE_KEY=your_license_key \ +./scripts/update-geolite-mmdb.sh /etc/ddns +``` + +For detailed parameters and HTTP packet structures, see [examples/README.md](examples/README.md). + The server exposes two HTTP/3 routes: | Route | Meaning | diff --git a/scripts/update-geolite-mmdb.sh b/scripts/update-geolite-mmdb.sh new file mode 100755 index 0000000..f7cc3e9 --- /dev/null +++ b/scripts/update-geolite-mmdb.sh @@ -0,0 +1,102 @@ +#!/usr/bin/env sh + +set -eu + +usage() { + cat <<'EOF' +Usage: + MAXMIND_ACCOUNT_ID=... MAXMIND_LICENSE_KEY=... ./scripts/update-geolite-mmdb.sh [target-dir] + +Downloads or updates the GeoLite2 City and ASN mmdb databases with geoipupdate. + +Arguments: + target-dir Optional output directory. Defaults to /var/lib/ddns/geoip. + +Required environment variables: + MAXMIND_ACCOUNT_ID + MAXMIND_LICENSE_KEY + +Optional environment variables: + GEOIPUPDATE_BIN geoipupdate binary name or path. Default: geoipupdate + GEOIPUPDATE_VERBOSE Set to 1 to pass -v to geoipupdate. + +Example: + MAXMIND_ACCOUNT_ID=12345 \ + MAXMIND_LICENSE_KEY=xxxx \ + ./scripts/update-geolite-mmdb.sh /etc/ddns +EOF +} + +if [ "${1:-}" = "-h" ] || [ "${1:-}" = "--help" ]; then + usage + exit 0 +fi + +if [ "$#" -gt 1 ]; then + usage >&2 + exit 64 +fi + +require_env() { + name="$1" + eval "value=\${$name:-}" + if [ -z "$value" ]; then + echo "missing required environment variable: $name" >&2 + exit 2 + fi +} + +require_env MAXMIND_ACCOUNT_ID +require_env MAXMIND_LICENSE_KEY + +geoipupdate_bin="${GEOIPUPDATE_BIN:-geoipupdate}" +target_dir="${1:-${GEOIP_TARGET_DIR:-/var/lib/ddns/geoip}}" + +if ! command -v "$geoipupdate_bin" >/dev/null 2>&1; then + echo "geoipupdate not found: $geoipupdate_bin" >&2 + echo "install it first, for example on macOS: brew install geoipupdate" >&2 + exit 127 +fi + +umask 077 +tmp_dir="$(mktemp -d "${TMPDIR:-/tmp}/geoipupdate.XXXXXX")" +cleanup() { + rm -rf "$tmp_dir" +} +trap cleanup EXIT HUP INT TERM + +mkdir -p "$target_dir" + +config_file="$tmp_dir/GeoIP.conf" +cat >"$config_file" <&2 + exit 1 +fi + +cat <, + + /// Path to the GeoLite2 City database. + #[serde(default)] + pub geoip_city_db: Option, + + /// Path to the GeoLite2 ASN database. + #[serde(default)] + pub geoip_asn_db: Option, } impl Config { @@ -76,6 +84,8 @@ impl Config { self.cert = expand_home_dir(&self.cert); self.key = expand_home_dir(&self.key); self.root_cert = expand_home_dir(&self.root_cert); + self.geoip_city_db = self.geoip_city_db.map(|path| expand_home_dir(&path)); + self.geoip_asn_db = self.geoip_asn_db.map(|path| expand_home_dir(&path)); self } diff --git a/src/bin/ddns-server/geo.rs b/src/bin/ddns-server/geo.rs new file mode 100644 index 0000000..3cedb9a --- /dev/null +++ b/src/bin/ddns-server/geo.rs @@ -0,0 +1,196 @@ +use std::{io, net::IpAddr, path::Path}; + +use maxminddb::{Reader, geoip2}; + +#[derive(Clone, Debug)] +pub struct GeoPoint { + pub latitude: f64, + pub longitude: f64, + pub accuracy_radius_km: u16, +} + +#[derive(Clone, Debug, Default)] +pub struct GeoTraits { + pub country: Option, + pub city: Option, + pub asn: Option, + pub point: Option, +} + +#[derive(Debug)] +pub struct GeoResolver { + city: Reader>, + asn: Reader>, + city_distance_routing: bool, + max_accuracy_radius_km: u32, +} + +impl GeoResolver { + pub fn open( + city_db: &Path, + asn_db: &Path, + city_distance_routing: bool, + max_accuracy_radius_km: u32, + ) -> io::Result { + let city = Reader::open_readfile(city_db).map_err(io::Error::other)?; + let asn = Reader::open_readfile(asn_db).map_err(io::Error::other)?; + + Ok(Self { + city, + asn, + city_distance_routing, + max_accuracy_radius_km, + }) + } + + pub fn lookup_traits(&self, ip: IpAddr) -> GeoTraits { + GeoTraits { + country: self.lookup_country(ip), + city: self.lookup_city(ip), + asn: self.lookup_asn(ip), + point: self.lookup_point(ip), + } + } + + pub fn city_build_epoch(&self) -> u64 { + self.city.metadata.build_epoch + } + + pub fn asn_build_epoch(&self) -> u64 { + self.asn.metadata.build_epoch + } + + pub fn lookup_country(&self, ip: IpAddr) -> Option { + let city = self.city.lookup::(ip).ok()??; + city.country?.iso_code.map(str::to_owned) + } + + pub fn lookup_asn(&self, ip: IpAddr) -> Option { + let asn = self.asn.lookup::(ip).ok()??; + asn.autonomous_system_number + } + + pub fn lookup_city(&self, ip: IpAddr) -> Option { + let city = self.city.lookup::(ip).ok()??; + city.city?.names?.get("en").copied().map(str::to_owned) + } + + pub fn lookup_point(&self, ip: IpAddr) -> Option { + let city = self.city.lookup::(ip).ok()??; + let location = city.location?; + let latitude = location.latitude?; + let longitude = location.longitude?; + let accuracy_radius_km = location.accuracy_radius?; + + Some(GeoPoint { + latitude, + longitude, + accuracy_radius_km, + }) + } + + pub fn geo_distance_km(&self, left: &GeoPoint, right: &GeoPoint) -> Option { + if !self.city_distance_routing { + return None; + } + + if u32::from(left.accuracy_radius_km) > self.max_accuracy_radius_km + || u32::from(right.accuracy_radius_km) > self.max_accuracy_radius_km + { + return None; + } + + Some(haversine_distance_km( + left.latitude, + left.longitude, + right.latitude, + right.longitude, + )) + } +} + +fn haversine_distance_km( + left_latitude: f64, + left_longitude: f64, + right_latitude: f64, + right_longitude: f64, +) -> f64 { + let earth_radius_km = 6_371.0; + let lat_delta = (right_latitude - left_latitude).to_radians(); + let lon_delta = (right_longitude - left_longitude).to_radians(); + let left_latitude = left_latitude.to_radians(); + let right_latitude = right_latitude.to_radians(); + + let haversine = (lat_delta / 2.0).sin().powi(2) + + left_latitude.cos() * right_latitude.cos() * (lon_delta / 2.0).sin().powi(2); + let arc = 2.0 * haversine.sqrt().asin(); + + earth_radius_km * arc +} + +#[cfg(test)] +mod tests { + use std::{net::IpAddr, path::PathBuf, str::FromStr}; + + use super::*; + + fn fixture_geo_resolver() -> GeoResolver { + let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let city_db = manifest_dir.join("geoip/GeoLite2-City.mmdb"); + let asn_db = manifest_dir.join("geoip/GeoLite2-ASN.mmdb"); + + GeoResolver::open(&city_db, &asn_db, true, 100).expect("fixture geo db should open") + } + + #[test] + fn bundled_geolite_maps_real_ips_to_expected_country_and_asn() { + let geo = fixture_geo_resolver(); + let cases = [ + ("8.8.8.8", "US", 15169_u32), + ("223.5.5.5", "CN", 45102_u32), + ("80.80.80.80", "NL", 60679_u32), + ("168.95.1.1", "TW", 3462_u32), + ("200.160.0.8", "BR", 22548_u32), + ]; + + for (candidate, expected_country, expected_asn) in cases { + let ip = IpAddr::from_str(candidate).unwrap(); + let traits = geo.lookup_traits(ip); + + assert_eq!(traits.country.as_deref(), Some(expected_country)); + assert_eq!(traits.asn, Some(expected_asn)); + assert!( + traits.point.is_some(), + "{candidate} should resolve to a city point" + ); + } + } + + #[test] + fn bundled_geolite_exposes_city_name_separately_from_accuracy_radius() { + let geo = fixture_geo_resolver(); + let ip = IpAddr::from_str("223.5.5.5").unwrap(); + let traits = geo.lookup_traits(ip); + + assert_eq!(traits.country.as_deref(), Some("CN")); + assert_eq!(traits.city.as_deref(), Some("Hangzhou")); + assert_eq!( + traits.point.as_ref().map(|point| point.accuracy_radius_km), + Some(20) + ); + } + + #[test] + fn bundled_geolite_may_have_coordinates_without_city_name() { + let geo = fixture_geo_resolver(); + let ip = IpAddr::from_str("168.95.1.1").unwrap(); + let traits = geo.lookup_traits(ip); + + assert_eq!(traits.country.as_deref(), Some("TW")); + assert_eq!(traits.city, None); + assert_eq!( + traits.point.as_ref().map(|point| point.accuracy_radius_km), + Some(200) + ); + } +} diff --git a/src/bin/ddns-server/lookup.rs b/src/bin/ddns-server/lookup.rs index bf131d9..50e3fa3 100644 --- a/src/bin/ddns-server/lookup.rs +++ b/src/bin/ddns-server/lookup.rs @@ -3,6 +3,7 @@ use std::{ cmp::Ordering, collections::{HashMap, HashSet}, convert::Infallible, + hash::Hash, net::{IpAddr, SocketAddr}, sync::Arc, }; @@ -13,14 +14,22 @@ use ddns::core::{ wire::MultiResponse, }; use deadpool_redis::redis::{self, AsyncCommands}; +<<<<<<< HEAD use h3x::{connection::ConnectionState, quic}; use h3x::dhttp::message::MessageStreamError; +======= +use h3x::{connection::ConnectionState, message::stream::MessageStreamError, quic}; +>>>>>>> 13e6482 (feat: add GeoIP support for geo-routing) use http_body_util::{Full, combinators::UnsyncBoxBody}; use tracing::debug; use crate::{ error::{AppError, normalize_host, parse_query_params}, - storage::{AppState, LookupRecord, Storage, StoredRecord, unix_now_secs}, + geo::{GeoResolver, GeoTraits}, + storage::{ + AppState, LookupRecord, Storage, StoredRecord, redis_all_index_key, + redis_asn_index_key, redis_country_index_key, redis_primary_key, unix_now_secs, + }, }; pub type Request = http::Request>; @@ -38,6 +47,23 @@ pub enum LookupResult { type EndpointKey = (SocketAddr, Option); +const LOOKUP_CANDIDATE_CAP_TOTAL: usize = 64; +const LOOKUP_CANDIDATE_CAP_ASN: usize = 16; +const LOOKUP_CANDIDATE_CAP_COUNTRY: usize = 16; +const LOOKUP_CANDIDATE_CAP_ALL: usize = 32; + +// GEO-aware ranking dimensions. Final ordering still falls back to the original +// record index so we keep lookups stable when all computed dimensions tie. +#[derive(Clone, Copy, Debug, PartialEq)] +struct GeoSortKey { + same_country: bool, + same_asn: bool, + family_match: bool, + same_city: bool, + load: Option, + geo_distance: Option, +} + fn normalize_lookup_records(records: Vec) -> Vec { let mut normalized = Vec::new(); let mut seen = HashSet::new(); @@ -77,53 +103,29 @@ fn normalize_lookup_records(records: Vec) -> Vec { fn lookup_endpoint(dns_bytes: &[u8]) -> Option<(SocketAddr, Option)> { let (_, packet) = be_packet(dns_bytes).ok()?; - packet.answers.iter().find_map(|answer| match answer.data() { - RData::E(endpoint) => Some((endpoint.addr(), endpoint.load())), - _ => None, - }) -} - -fn common_prefix_len(source: IpAddr, target: IpAddr) -> u32 { - fn bytes_prefix_len(left: &[u8], right: &[u8]) -> u32 { - let mut matched = 0; - for (l, r) in left.iter().zip(right.iter()) { - let diff = l ^ r; - if diff == 0 { - matched += 8; - continue; - } - matched += (diff as u32).leading_zeros().saturating_sub(24); - break; - } - matched - } - - match (source, target) { - (IpAddr::V4(source), IpAddr::V4(target)) => { - bytes_prefix_len(&source.octets(), &target.octets()) - } - (IpAddr::V6(source), IpAddr::V6(target)) => { - bytes_prefix_len(&source.octets(), &target.octets()) - } - _ => 0, - } + packet + .answers + .iter() + .find_map(|answer| match answer.data() { + RData::E(endpoint) => Some((endpoint.addr(), endpoint.load())), + _ => None, + }) } +// Fallback ordering when GEO routing is disabled: prefer matching address family, +// then lower load, and finally preserve input order. We intentionally avoid +// IP prefix heuristics here because they are not reliable on the public Internet. fn sort_lookup_records(records: Vec, source_ip: Option) -> Vec { let mut decorated = records .into_iter() .enumerate() .map(|(index, record)| { let sort_key = lookup_endpoint(&record.0).map(|(endpoint, load)| { - let (family_match, prefix_len) = match source_ip { - Some(source_ip) if source_ip.is_ipv4() == endpoint.ip().is_ipv4() => { - (true, common_prefix_len(source_ip, endpoint.ip())) - } - Some(_) => (false, 0), - None => (false, 0), - }; + let family_match = source_ip + .map(|source| source.is_ipv4() == endpoint.ip().is_ipv4()) + .unwrap_or(false); - (family_match, prefix_len, load) + (family_match, load) }); (sort_key, index, record) }) @@ -131,9 +133,8 @@ fn sort_lookup_records(records: Vec, source_ip: Option) -> decorated.sort_by(|(left_key, left_index, _), (right_key, right_index, _)| { match (left_key, right_key) { - (Some((left_family, left_prefix, left_load)), Some((right_family, right_prefix, right_load))) => right_family + (Some((left_family, left_load)), Some((right_family, right_load))) => right_family .cmp(left_family) - .then_with(|| right_prefix.cmp(left_prefix)) .then_with(|| match (left_load, right_load) { (Some(left), Some(right)) => left.partial_cmp(right).unwrap_or(Ordering::Equal), (Some(_), None) => Ordering::Less, @@ -147,10 +148,166 @@ fn sort_lookup_records(records: Vec, source_ip: Option) -> .then_with(|| left_index.cmp(right_index)) }); - decorated + decorated.into_iter().map(|(_, _, record)| record).collect() +} + +fn request_source_geo_traits( + source_ip: Option, + geo: Option<&GeoResolver>, +) -> Option { + Some(geo?.lookup_traits(source_ip?)) +} + +fn lookup_endpoint_geo_traits( + dns_bytes: &[u8], + geo: &GeoResolver, +) -> Option<(SocketAddr, Option, GeoTraits)> { + let (endpoint, load) = lookup_endpoint(dns_bytes)?; + Some((endpoint, load, geo.lookup_traits(endpoint.ip()))) +} + +fn compare_optional_partial(left: Option, right: Option) -> Ordering { + match (left, right) { + (Some(left), Some(right)) => left.partial_cmp(&right).unwrap_or(Ordering::Equal), + _ => Ordering::Equal, + } +} + +// GEO ordering is layered rather than score-based: +// country > ASN > address family > city name > lower load > shorter GEO distance. +// Missing optional values do not penalize a candidate; they simply skip that layer. +fn compare_geo_sort_keys(left: GeoSortKey, right: GeoSortKey) -> Ordering { + right + .same_country + .cmp(&left.same_country) + .then_with(|| right.same_asn.cmp(&left.same_asn)) + .then_with(|| right.family_match.cmp(&left.family_match)) + .then_with(|| right.same_city.cmp(&left.same_city)) + .then_with(|| compare_optional_partial(left.load, right.load)) + .then_with(|| compare_optional_partial(left.geo_distance, right.geo_distance)) +} + +// Build the per-endpoint GEO ranking tuple. City name only participates when both +// sides have a name and already match on country; coordinate distance only +// participates when GeoResolver accepts both accuracy radii. +fn build_geo_sort_key( + source_ip: Option, + source_traits: Option<&GeoTraits>, + endpoint: SocketAddr, + load: Option, + endpoint_traits: &GeoTraits, + geo: &GeoResolver, +) -> GeoSortKey { + let family_match = source_ip + .map(|source| source.is_ipv4() == endpoint.ip().is_ipv4()) + .unwrap_or(false); + + let same_country = source_traits + .and_then(|source| source.country.as_deref()) + .zip(endpoint_traits.country.as_deref()) + .is_some_and(|(source, target)| source == target); + + let same_asn = source_traits + .and_then(|source| source.asn) + .zip(endpoint_traits.asn) + .is_some_and(|(source, target)| source == target); + + let same_city = same_country + && source_traits + .and_then(|source| source.city.as_deref()) + .zip(endpoint_traits.city.as_deref()) + .is_some_and(|(source, target)| source == target); + + let geo_distance = source_traits + .and_then(|source| source.point.as_ref()) + .zip(endpoint_traits.point.as_ref()) + .and_then(|(source, target)| geo.geo_distance_km(source, target)); + + GeoSortKey { + same_country, + same_asn, + family_match, + same_city, + load, + geo_distance, + } +} + +fn candidate_total_cap(limit: Option) -> usize { + limit + .unwrap_or(LOOKUP_CANDIDATE_CAP_TOTAL) + .max(LOOKUP_CANDIDATE_CAP_TOTAL) +} + +fn all_candidate_cap(total_cap: usize, source_traits: Option<&GeoTraits>) -> usize { + let has_geo_buckets = source_traits.is_some_and(|traits| { + traits.asn.is_some() || traits.country.as_deref().is_some() + }); + + if has_geo_buckets { + LOOKUP_CANDIDATE_CAP_ALL.min(total_cap) + } else { + total_cap + } +} + +fn push_unique_candidates( + candidates: &mut Vec, + seen: &mut HashSet, + source: impl IntoIterator, + total_cap: usize, +) where + T: Clone + Eq + Hash, +{ + for item in source { + if candidates.len() >= total_cap { + break; + } + + if seen.insert(item.clone()) { + candidates.push(item); + } + } +} + +fn sort_lookup_records_with_geo( + records: Vec, + source_ip: Option, + geo: &GeoResolver, +) -> Vec { + let source_traits = request_source_geo_traits(source_ip, Some(geo)); + + let mut decorated = records .into_iter() - .map(|(_, _, record)| record) - .collect() + .enumerate() + .map(|(index, record)| { + let sort_key = lookup_endpoint_geo_traits(&record.0, geo).map( + |(endpoint, load, endpoint_traits)| { + build_geo_sort_key( + source_ip, + source_traits.as_ref(), + endpoint, + load, + &endpoint_traits, + geo, + ) + }, + ); + (sort_key, index, record) + }) + .collect::>(); + + decorated.sort_by(|(left_key, left_index, _), (right_key, right_index, _)| { + match (left_key, right_key) { + (Some(left_key), Some(right_key)) => compare_geo_sort_keys(*left_key, *right_key), + (Some(_), None) => Ordering::Less, + (None, Some(_)) => Ordering::Greater, + (None, None) => Ordering::Equal, + } + .then_with(|| left_index.cmp(right_index)) + }); + + decorated.into_iter().map(|(_, _, record)| record).collect() } fn request_source_ip(request: &Request) -> Option { @@ -188,59 +345,147 @@ async fn perform_lookup_multi( limit: Option, source_ip: Option, ) -> Result { + let source_traits = request_source_geo_traits(source_ip, state.geo.as_deref()); + let candidate_total = candidate_total_cap(limit); + let candidate_all = all_candidate_cap(candidate_total, source_traits.as_ref()); + let dynamic_records = match &state.storage { Storage::Redis(pool) => { let mut conn = pool.get().await.map_err(|e| AppError::Redis { message: e.to_string(), })?; - - let set_key = format!("{host}:multi"); let now_secs = unix_now_secs(); - - // Remove expired members: those published more than ttl_secs ago. let cutoff_score = now_secs.saturating_sub(state.ttl_secs) as f64; + let mut candidate_fingerprints = Vec::new(); + let mut seen_fingerprints = HashSet::new(); + + if let Some(asn) = source_traits.as_ref().and_then(|traits| traits.asn) { + let index_key = redis_asn_index_key(host, asn); + let _: () = redis::cmd("ZREMRANGEBYSCORE") + .arg(&index_key) + .arg("-inf") + .arg(cutoff_score) + .query_async::<()>(&mut *conn) + .await + .unwrap_or(()); + + let members: Vec = conn + .zrevrange( + &index_key, + 0isize, + LOOKUP_CANDIDATE_CAP_ASN.saturating_sub(1) as isize, + ) + .await + .map_err(|e| AppError::Redis { + message: e.to_string(), + })?; + + push_unique_candidates( + &mut candidate_fingerprints, + &mut seen_fingerprints, + members, + candidate_total, + ); + } + + if let Some(country) = source_traits + .as_ref() + .and_then(|traits| traits.country.as_deref()) + { + let index_key = redis_country_index_key(host, country); + let _: () = redis::cmd("ZREMRANGEBYSCORE") + .arg(&index_key) + .arg("-inf") + .arg(cutoff_score) + .query_async::<()>(&mut *conn) + .await + .unwrap_or(()); + + let members: Vec = conn + .zrevrange( + &index_key, + 0isize, + LOOKUP_CANDIDATE_CAP_COUNTRY.saturating_sub(1) as isize, + ) + .await + .map_err(|e| AppError::Redis { + message: e.to_string(), + })?; + + push_unique_candidates( + &mut candidate_fingerprints, + &mut seen_fingerprints, + members, + candidate_total, + ); + } + + let all_index_key = redis_all_index_key(host); let _: () = redis::cmd("ZREMRANGEBYSCORE") - .arg(&set_key) + .arg(&all_index_key) .arg("-inf") .arg(cutoff_score) .query_async::<()>(&mut *conn) .await .unwrap_or(()); - // Fetch all remaining active dynamic records; scheduling is applied after decode. - let members: Vec> = conn - .zrevrange(&set_key, 0isize, -1isize) + let all_members: Vec = conn + .zrevrange(&all_index_key, 0isize, candidate_all.saturating_sub(1) as isize) .await .map_err(|e| AppError::Redis { message: e.to_string(), })?; - let now_secs = unix_now_secs(); - let records: Vec<(Vec, Vec)> = members - .into_iter() - .filter_map(|m| { - let r = StoredRecord::decode(&m)?; - if r.expire_unix_secs > now_secs { - Some((r.dns, r.cert)) - } else { - None + push_unique_candidates( + &mut candidate_fingerprints, + &mut seen_fingerprints, + all_members, + candidate_total, + ); + + let mut records = Vec::new(); + for fingerprint in candidate_fingerprints { + let primary_key = redis_primary_key(host, &fingerprint); + let member: Option> = conn.get(&primary_key).await.map_err(|e| { + AppError::Redis { + message: e.to_string(), } - }) - .collect(); + })?; + + let Some(member) = member else { + continue; + }; + let Some(record) = StoredRecord::decode(&member) else { + continue; + }; + if record.expire_unix_secs > now_secs { + records.push((record.dns, record.cert)); + } + } records } Storage::Memory(mem) => { let now = tokio::time::Instant::now(); if let Some(mut entry) = mem.records.get_mut(host) { - // Evict expired entries in-place. - entry.retain(|_, r| r.expire > now); - // Sort newest-first by published_at. - let mut records: Vec<_> = entry.values().collect(); - records.sort_by_key(|b| std::cmp::Reverse(b.published_at)); - records - .iter() - .map(|r| (r.dns_bytes.clone(), r.cert_bytes.clone())) + entry.retain_active(now); + let candidate_fingerprints = entry.collect_candidates( + source_traits.as_ref().and_then(|traits| traits.country.as_deref()), + source_traits.as_ref().and_then(|traits| traits.asn), + candidate_total, + LOOKUP_CANDIDATE_CAP_ASN, + LOOKUP_CANDIDATE_CAP_COUNTRY, + candidate_all, + ); + + candidate_fingerprints + .into_iter() + .filter_map(|fingerprint| { + entry + .records + .get(&fingerprint) + .map(|record| (record.dns_bytes.clone(), record.cert_bytes.clone())) + }) .collect::>() } else { vec![] @@ -248,13 +493,20 @@ async fn perform_lookup_multi( } }; - let mut records = sort_lookup_records(normalize_lookup_records(dynamic_records), source_ip); + let normalized_dynamic_records = normalize_lookup_records(dynamic_records); + let mut records = if let Some(geo) = state.geo.as_deref() { + sort_lookup_records_with_geo(normalized_dynamic_records, source_ip, geo) + } else { + sort_lookup_records(normalized_dynamic_records, source_ip) + }; let should_append_seeds = records.is_empty() || limit.is_some_and(|max| records.len() < max); - if should_append_seeds - && let Some(seed_records) = state.seed_records.get(host) - { - let seeds = sort_lookup_records(seed_records.iter().cloned().collect(), source_ip); + if should_append_seeds && let Some(seed_records) = state.seed_records.get(host) { + let seeds = if let Some(geo) = state.geo.as_deref() { + sort_lookup_records_with_geo(seed_records.iter().cloned().collect(), source_ip, geo) + } else { + sort_lookup_records(seed_records.iter().cloned().collect(), source_ip) + }; records.extend(seeds); } @@ -356,3 +608,326 @@ impl LookupSvc { async move { Ok(lookup_with_cert(state, request).await) } } } + +#[cfg(test)] +mod tests { + use std::{ + net::{IpAddr, Ipv4Addr, SocketAddrV4}, + path::PathBuf, + }; + + use ddns::core::MdnsEndpoint; + + use super::*; + use crate::geo::{GeoPoint, GeoResolver}; + + fn fixture_geo_resolver() -> GeoResolver { + let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let city_db = manifest_dir.join("geoip/GeoLite2-City.mmdb"); + let asn_db = manifest_dir.join("geoip/GeoLite2-ASN.mmdb"); + + GeoResolver::open(&city_db, &asn_db, true, 100).expect("fixture geo db should open") + } + + fn lookup_record(host: &str, addr: SocketAddr, load: Option) -> LookupRecord { + let mut endpoint = match addr { + SocketAddr::V4(addr) => MdnsEndpoint::direct_v4(addr), + SocketAddr::V6(addr) => MdnsEndpoint::direct_v6(addr), + }; + endpoint.set_load(load); + + let mut hosts = HashMap::new(); + hosts.insert(host.to_string(), vec![endpoint]); + + (MdnsPacket::answer(0, &hosts).to_bytes(), Vec::new()) + } + + #[test] + fn compare_geo_sort_keys_follows_documented_priority() { + let best = GeoSortKey { + same_country: true, + same_asn: true, + family_match: true, + same_city: true, + load: Some(0.2), + geo_distance: Some(20.0), + }; + let worse_load = GeoSortKey { + load: Some(0.8), + ..best + }; + let worse_family = GeoSortKey { + same_asn: true, + family_match: false, + same_city: true, + load: Some(0.1), + geo_distance: Some(1.0), + ..best + }; + let worse_city = GeoSortKey { + same_city: false, + load: Some(0.1), + geo_distance: Some(1.0), + ..best + }; + let worse_asn = GeoSortKey { + same_asn: false, + family_match: true, + same_city: true, + load: Some(0.1), + geo_distance: Some(1.0), + ..best + }; + let worse_country = GeoSortKey { + same_country: false, + same_asn: true, + family_match: true, + same_city: false, + load: Some(0.1), + geo_distance: Some(1.0), + }; + + assert_eq!(compare_geo_sort_keys(best, worse_load), Ordering::Less); + assert_eq!(compare_geo_sort_keys(best, worse_family), Ordering::Less); + assert_eq!(compare_geo_sort_keys(best, worse_city), Ordering::Less); + assert_eq!(compare_geo_sort_keys(best, worse_asn), Ordering::Less); + assert_eq!(compare_geo_sort_keys(best, worse_country), Ordering::Less); + } + + #[test] + fn compare_geo_sort_keys_skips_unknown_dimensions() { + let known_distance = GeoSortKey { + same_country: true, + same_asn: true, + family_match: true, + same_city: true, + load: Some(0.2), + geo_distance: Some(10.0), + }; + let missing_distance = GeoSortKey { + geo_distance: None, + ..known_distance + }; + let missing_load = GeoSortKey { + load: None, + ..known_distance + }; + + assert_eq!( + compare_geo_sort_keys(known_distance, missing_distance), + Ordering::Equal + ); + assert_eq!( + compare_geo_sort_keys(known_distance, missing_load), + Ordering::Equal + ); + } + + #[test] + fn sort_lookup_records_with_geo_prefers_same_source_endpoint_even_with_higher_load() { + let geo = fixture_geo_resolver(); + let source_ip = Some(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); + let matching = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), + Some(0.9), + ); + let non_matching = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 1, 1, 1), 3478)), + Some(0.1), + ); + + let sorted = + sort_lookup_records_with_geo(vec![non_matching, matching.clone()], source_ip, &geo); + + let (endpoint, _) = lookup_endpoint(&sorted[0].0).expect("sorted record should decode"); + assert_eq!(endpoint.ip(), IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); + } + + #[test] + fn sort_lookup_records_without_geo_ignores_ip_prefix_and_prefers_lower_load() { + let source_ip = Some(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))); + let closer_prefix_higher_load = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 2), 3478)), + Some(0.9), + ); + let farther_prefix_lower_load = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), + Some(0.1), + ); + + let sorted = sort_lookup_records( + vec![closer_prefix_higher_load, farther_prefix_lower_load], + source_ip, + ); + + let (endpoint, _) = lookup_endpoint(&sorted[0].0).expect("sorted record should decode"); + assert_eq!(endpoint.ip(), IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); + } + + #[test] + fn sort_lookup_records_with_geo_prefers_same_asn_then_same_country_on_real_ips() { + let geo = fixture_geo_resolver(); + let source_ip = Some(IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5))); + + let different_country = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), + Some(0.01), + ); + let same_country_different_asn = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(114, 114, 114, 114), 3478)), + Some(0.02), + ); + let same_country_same_asn = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(223, 5, 5, 5), 3478)), + Some(0.9), + ); + + let sorted = sort_lookup_records_with_geo( + vec![ + different_country, + same_country_different_asn, + same_country_same_asn, + ], + source_ip, + &geo, + ); + + let ordered_ips = sorted + .iter() + .map(|record| { + lookup_endpoint(&record.0) + .expect("record should decode") + .0 + .ip() + }) + .collect::>(); + + assert_eq!( + ordered_ips, + vec![ + IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5)), + IpAddr::V4(Ipv4Addr::new(114, 114, 114, 114)), + IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), + ] + ); + } + + #[test] + fn sort_lookup_records_with_geo_prefers_same_country_over_lower_load_on_real_ips() { + let geo = fixture_geo_resolver(); + let source_ip = Some(IpAddr::V4(Ipv4Addr::new(114, 114, 114, 114))); + + let different_country_low_load = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(80, 80, 80, 80), 3478)), + Some(0.01), + ); + let same_country_higher_load = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(223, 5, 5, 5), 3478)), + Some(0.9), + ); + + let sorted = sort_lookup_records_with_geo( + vec![different_country_low_load, same_country_higher_load.clone()], + source_ip, + &geo, + ); + + let (endpoint, _) = lookup_endpoint(&sorted[0].0).expect("sorted record should decode"); + assert_eq!(endpoint.ip(), IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5))); + } + + #[test] + fn build_geo_sort_key_ignores_city_distance_when_accuracy_is_too_large() { + let geo = fixture_geo_resolver(); + let source_traits = GeoTraits { + country: Some("CN".to_string()), + city: Some("Beijing".to_string()), + asn: Some(64512), + point: Some(GeoPoint { + latitude: 39.9, + longitude: 116.4, + accuracy_radius_km: 500, + }), + }; + let endpoint_traits = GeoTraits { + country: Some("CN".to_string()), + city: Some("Shanghai".to_string()), + asn: Some(64512), + point: Some(GeoPoint { + latitude: 31.2, + longitude: 121.5, + accuracy_radius_km: 10, + }), + }; + + let key = build_geo_sort_key( + Some(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))), + Some(&source_traits), + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 4, 4), 3478)), + Some(0.2), + &endpoint_traits, + &geo, + ); + + assert!(key.same_country); + assert!(key.same_asn); + assert!(!key.same_city); + assert_eq!(key.geo_distance, None); + } + + #[test] + fn build_geo_sort_key_prefers_same_city_when_available() { + let geo = fixture_geo_resolver(); + let source_traits = GeoTraits { + country: Some("CN".to_string()), + city: Some("Hangzhou".to_string()), + asn: Some(64512), + point: None, + }; + let same_city_traits = GeoTraits { + country: Some("CN".to_string()), + city: Some("Hangzhou".to_string()), + asn: Some(64513), + point: None, + }; + let different_city_traits = GeoTraits { + country: Some("CN".to_string()), + city: Some("Shanghai".to_string()), + asn: Some(64513), + point: None, + }; + + let same_city_key = build_geo_sort_key( + Some(IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5))), + Some(&source_traits), + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(223, 6, 6, 6), 3478)), + Some(0.9), + &same_city_traits, + &geo, + ); + let different_city_key = build_geo_sort_key( + Some(IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5))), + Some(&source_traits), + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(114, 114, 114, 114), 3478)), + Some(0.1), + &different_city_traits, + &geo, + ); + + assert!(same_city_key.same_city); + assert!(!different_city_key.same_city); + assert_eq!( + compare_geo_sort_keys(same_city_key, different_city_key), + Ordering::Less + ); + } +} diff --git a/src/bin/ddns-server/main.rs b/src/bin/ddns-server/main.rs index 735fa74..1bb1578 100644 --- a/src/bin/ddns-server/main.rs +++ b/src/bin/ddns-server/main.rs @@ -1,5 +1,6 @@ mod config; mod error; +mod geo; mod lookup; mod policy; mod publish; @@ -11,6 +12,7 @@ use std::{ net::SocketAddr, sync::Arc, task::{Context, Poll}, + time::{SystemTime, UNIX_EPOCH}, }; use clap::Parser; @@ -26,11 +28,12 @@ use h3x::{ hyper::TowerService, }; use rustls::{RootCertStore, server::WebPkiClientVerifier}; -use tracing::{info, level_filters::LevelFilter}; +use tracing::{info, level_filters::LevelFilter, warn}; use tracing_subscriber::{prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt}; use crate::{ config::{Config, Options, PolicyKind, SeedRecordConfig}, + geo::GeoResolver, lookup::LookupSvc, policy::{DomainPolicies, DomainPolicy, PolicyRule}, publish::PublishSvc, @@ -129,6 +132,61 @@ fn build_seed_records(seed_records: &[SeedRecordConfig]) -> io::Result STALE_GEO_DB_AGE_SECS { + warn!(kind, build_epoch, age_secs, "geo_routing.db_outdated"); + } +} + +const GEO_CITY_DISTANCE_ROUTING: bool = true; +const GEO_MAX_ACCURACY_RADIUS_KM: u32 = 100; + +fn build_geo_resolver(config: &Config) -> io::Result>> { + let Some(city_db) = config.geoip_city_db.as_deref() else { + return if config.geoip_asn_db.is_none() { + Ok(None) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "geoip_city_db and geoip_asn_db must be configured together", + )) + }; + }; + + let Some(asn_db) = config.geoip_asn_db.as_deref() else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "geoip_city_db and geoip_asn_db must be configured together", + )); + }; + + let resolver = Arc::new(GeoResolver::open( + city_db, + asn_db, + GEO_CITY_DISTANCE_ROUTING, + GEO_MAX_ACCURACY_RADIUS_KM, + )?); + info!( + city_db = %city_db.display(), + asn_db = %asn_db.display(), + city_distance_routing = GEO_CITY_DISTANCE_ROUTING, + max_accuracy_radius_km = GEO_MAX_ACCURACY_RADIUS_KM, + "geo_routing.enabled" + ); + log_geo_db_freshness("city", resolver.city_build_epoch()); + log_geo_db_freshness("asn", resolver.asn_build_epoch()); + + Ok(Some(resolver)) +} + // --------------------------------------------------------------------------- // Entry point // --------------------------------------------------------------------------- @@ -155,6 +213,7 @@ async fn main() -> Result<(), Box> { }); let config = config.expand_paths(); let seed_records = build_seed_records(&config.seed_records)?; + let geo = build_geo_resolver(&config)?; // Build storage backend. let storage = match config.redis.clone() { @@ -201,6 +260,7 @@ async fn main() -> Result<(), Box> { ttl_secs: config.ttl_secs, policies, seed_records, + geo, }; let cert_pem = std::fs::read(&config.cert)?; @@ -239,3 +299,59 @@ async fn main() -> Result<(), Box> { Ok(()) } + +#[cfg(test)] +mod tests { + use std::{net::SocketAddr, path::PathBuf}; + + use super::*; + use crate::config::Config; + + fn test_config() -> Config { + Config { + redis: None, + listen: Config::default_listen(), + server_name: Config::default_server_name(), + cert: Config::default_cert(), + key: Config::default_key(), + root_cert: Config::default_root_cert(), + require_signature: Config::default_require_signature(), + ttl_secs: Config::default_ttl_secs(), + domain_policies: Vec::new(), + seed_records: Vec::new(), + geoip_city_db: None, + geoip_asn_db: None, + } + } + + #[test] + fn unspecified_ipv4_listen_uses_dual_stack_wildcard() { + let listen: SocketAddr = "0.0.0.0:4433".parse().unwrap(); + let patterns = bind_patterns_for_listen(listen); + + assert_eq!(patterns.len(), 1); + assert_eq!(patterns[0].to_string(), "inet://[::]:4433"); + } + + #[test] + fn geo_routing_requires_city_db_path() { + let mut config = test_config(); + config.geoip_asn_db = Some(PathBuf::from("/tmp/asn.mmdb")); + + let err = build_geo_resolver(&config).expect_err("missing city db should fail"); + + assert_eq!(err.kind(), io::ErrorKind::InvalidInput); + assert_eq!(err.to_string(), "geoip_city_db and geoip_asn_db must be configured together"); + } + + #[test] + fn geo_routing_requires_asn_db_path() { + let mut config = test_config(); + config.geoip_city_db = Some(PathBuf::from("/tmp/city.mmdb")); + + let err = build_geo_resolver(&config).expect_err("missing asn db should fail"); + + assert_eq!(err.kind(), io::ErrorKind::InvalidInput); + assert_eq!(err.to_string(), "geoip_city_db and geoip_asn_db must be configured together"); + } +} diff --git a/src/bin/ddns-server/publish.rs b/src/bin/ddns-server/publish.rs index 5b07347..b9cbf22 100644 --- a/src/bin/ddns-server/publish.rs +++ b/src/bin/ddns-server/publish.rs @@ -1,4 +1,4 @@ -use std::{convert::Infallible, sync::Arc}; +use std::{collections::HashSet, convert::Infallible, sync::Arc}; use deadpool_redis::redis::{self, AsyncCommands}; use dhttp_identity::identity::RemoteAuthority; @@ -13,7 +13,8 @@ use crate::{ policy::{DomainPolicy, ValidatedDnsPacket, client_allowed_host, validate_dns_packet}, storage::{ AppState, Record, Storage, StoredRecord, cert_fingerprint, cert_fingerprint_hex, - unix_now_secs, + record_index_tags, redis_all_index_key, redis_asn_index_key, redis_country_index_key, + redis_primary_key, unix_now_secs, }, }; @@ -180,17 +181,30 @@ pub async fn publish_record( let expire_ttl_secs = i64::try_from(state.ttl_secs).unwrap_or(i64::MAX); let now_secs = unix_now_secs(); let expire_secs = now_secs + state.ttl_secs; + let index_tags = record_index_tags(body.as_ref(), state.geo.as_deref()); - let fp_key = format!("{host}:fp:{fp_hex}"); - let set_key = format!("{host}:multi"); + let fp_key = redis_primary_key(host, &fp_hex); + let all_index_key = redis_all_index_key(host); + let mut touched_index_keys = HashSet::from([all_index_key.clone()]); - // Remove the previous entry from this source (if any) from the ZSET. let old_member: Option> = conn.get(&fp_key).await.unwrap_or(None); - if let Some(old) = old_member { - let _: () = conn.zrem(&set_key, &old).await.unwrap_or(()); + if let Some(old_record) = old_member.as_deref().and_then(StoredRecord::decode) { + let old_tags = record_index_tags(&old_record.dns, state.geo.as_deref()); + let _: () = conn.zrem(&all_index_key, &fp_hex).await.unwrap_or(()); + + for country in &old_tags.countries { + let key = redis_country_index_key(host, country); + touched_index_keys.insert(key.clone()); + let _: () = conn.zrem(&key, &fp_hex).await.unwrap_or(()); + } + + for asn in &old_tags.asns { + let key = redis_asn_index_key(host, *asn); + touched_index_keys.insert(key.clone()); + let _: () = conn.zrem(&key, &fp_hex).await.unwrap_or(()); + } } - // Encode and store the new member. let new_member = StoredRecord { expire_unix_secs: expire_secs, fingerprint: fp, @@ -209,7 +223,7 @@ pub async fn publish_record( } if let Err(e) = conn - .zadd::<_, _, _, ()>(&set_key, &new_member, now_secs as f64) + .zadd::<_, _, _, ()>(&all_index_key, &fp_hex, now_secs as f64) .await { return write_error(AppError::Redis { @@ -217,21 +231,37 @@ pub async fn publish_record( }); } - // Expire the ZSET key at max(ttl_secs) from now as a safety net. - let _: bool = conn - .expire(&set_key, expire_ttl_secs) - .await - .unwrap_or(false); + for country in &index_tags.countries { + let key = redis_country_index_key(host, country); + touched_index_keys.insert(key.clone()); + if let Err(e) = conn.zadd::<_, _, _, ()>(&key, &fp_hex, now_secs as f64).await { + return write_error(AppError::Redis { + message: e.to_string(), + }); + } + } + + for asn in &index_tags.asns { + let key = redis_asn_index_key(host, *asn); + touched_index_keys.insert(key.clone()); + if let Err(e) = conn.zadd::<_, _, _, ()>(&key, &fp_hex, now_secs as f64).await { + return write_error(AppError::Redis { + message: e.to_string(), + }); + } + } - // Evict stale (score < now - ttl) entries. let cutoff = now_secs.saturating_sub(state.ttl_secs) as f64; - let _: () = redis::cmd("ZREMRANGEBYSCORE") - .arg(&set_key) - .arg("-inf") - .arg(cutoff) - .query_async::<()>(&mut *conn) - .await - .unwrap_or(()); + for key in touched_index_keys { + let _: bool = conn.expire(&key, expire_ttl_secs).await.unwrap_or(false); + let _: () = redis::cmd("ZREMRANGEBYSCORE") + .arg(&key) + .arg("-inf") + .arg(cutoff) + .query_async::<()>(&mut *conn) + .await + .unwrap_or(()); + } } Storage::Memory(mem) => { let now = Instant::now(); @@ -240,14 +270,11 @@ pub async fn publish_record( dns_bytes: body.to_vec(), cert_bytes, expire, - published_at: now, + index_tags: record_index_tags(body.as_ref(), state.geo.as_deref()), }; - // Upsert by fingerprint: same source overwrites its own entry; - // different sources (different certs) coexist independently. let mut host_map = mem.records.entry(host.to_string()).or_default(); + host_map.retain_active(now); host_map.insert(fp, record); - // Evict expired entries while we hold the write lock. - host_map.retain(|_, r| r.expire > now); } } @@ -280,13 +307,23 @@ pub async fn clear_record( } }; - let fp_key = format!("{host}:fp:{fp_hex}"); - let set_key = format!("{host}:multi"); + let fp_key = redis_primary_key(host, &fp_hex); + let all_index_key = redis_all_index_key(host); let old_member: Option> = conn.get(&fp_key).await.unwrap_or(None); - if let Some(old) = old_member { - let _: () = conn.zrem(&set_key, &old).await.unwrap_or(()); + if let Some(old_record) = old_member.as_deref().and_then(StoredRecord::decode) { + let old_tags = record_index_tags(&old_record.dns, state.geo.as_deref()); + let _: () = conn.zrem(&all_index_key, &fp_hex).await.unwrap_or(()); + for country in &old_tags.countries { + let key = redis_country_index_key(host, country); + let _: () = conn.zrem(&key, &fp_hex).await.unwrap_or(()); + } + for asn in &old_tags.asns { + let key = redis_asn_index_key(host, *asn); + let _: () = conn.zrem(&key, &fp_hex).await.unwrap_or(()); + } } + if let Err(e) = conn.del::<_, ()>(&fp_key).await { return write_error(AppError::Redis { message: e.to_string(), @@ -295,7 +332,7 @@ pub async fn clear_record( } Storage::Memory(mem) => { let remove_host = if let Some(mut host_map) = mem.records.get_mut(host) { - host_map.remove(&fp); + let _ = host_map.remove(&fp); host_map.is_empty() } else { false @@ -361,6 +398,7 @@ mod tests { ttl_secs: 30, policies: Arc::new(DomainPolicies::default()), seed_records: SeedRecords::default(), + geo: None, } } @@ -401,7 +439,7 @@ mod tests { http::StatusCode::OK ); - let LookupResult::Multi(response) = perform_lookup(&state, host, None).await.unwrap() + let LookupResult::Multi(response) = perform_lookup(&state, host, None, None).await.unwrap() else { panic!("authority b record should remain"); }; @@ -420,7 +458,7 @@ mod tests { http::StatusCode::OK ); assert!(matches!( - perform_lookup(&state, host, None).await.unwrap(), + perform_lookup(&state, host, None, None).await.unwrap(), LookupResult::NotFound )); } diff --git a/src/bin/ddns-server/storage.rs b/src/bin/ddns-server/storage.rs index e194faf..5308db7 100644 --- a/src/bin/ddns-server/storage.rs +++ b/src/bin/ddns-server/storage.rs @@ -1,5 +1,5 @@ use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, sync::Arc, time::{SystemTime, UNIX_EPOCH}, }; @@ -7,6 +7,7 @@ use std::{ use bytes::BufMut; use dashmap::DashMap; use deadpool_redis::Pool; +use ddns::core::parser::{packet::be_packet, record::RData}; use nom::{ IResult, bytes::streaming::take, @@ -14,7 +15,7 @@ use nom::{ }; use tokio::time::Instant; -use crate::policy::DomainPolicies; +use crate::{geo::GeoResolver, policy::DomainPolicies}; // --------------------------------------------------------------------------- // Storage helpers @@ -27,11 +28,12 @@ pub fn cert_fingerprint(cert_der: &[u8]) -> [u8; 32] { d.as_ref().try_into().expect("SHA-256 is always 32 bytes") } +pub fn fingerprint_hex(fingerprint: &[u8; 32]) -> String { + fingerprint.iter().map(|b| format!("{b:02x}")).collect() +} + pub fn cert_fingerprint_hex(cert_der: &[u8]) -> String { - cert_fingerprint(cert_der) - .iter() - .map(|b| format!("{b:02x}")) - .collect() + fingerprint_hex(&cert_fingerprint(cert_der)) } pub fn unix_now_secs() -> u64 { @@ -41,11 +43,68 @@ pub fn unix_now_secs() -> u64 { .unwrap_or(0) } +pub fn redis_primary_key(host: &str, fingerprint_hex: &str) -> String { + format!("{host}:fp:{fingerprint_hex}") +} + +pub fn redis_all_index_key(host: &str) -> String { + format!("{host}:idx:all") +} + +pub fn redis_country_index_key(host: &str, country: &str) -> String { + format!("{host}:idx:country:{country}") +} + +pub fn redis_asn_index_key(host: &str, asn: u32) -> String { + format!("{host}:idx:asn:{asn}") +} + +#[derive(Clone, Debug, Default, PartialEq, Eq)] +pub struct RecordIndexTags { + pub countries: Vec, + pub asns: Vec, +} + +pub fn record_index_tags(dns_bytes: &[u8], geo: Option<&GeoResolver>) -> RecordIndexTags { + let Some(geo) = geo else { + return RecordIndexTags::default(); + }; + + let Ok((_, packet)) = be_packet(dns_bytes) else { + return RecordIndexTags::default(); + }; + + let mut countries = HashSet::new(); + let mut asns = HashSet::new(); + + for answer in &packet.answers { + let RData::E(endpoint) = answer.data() else { + continue; + }; + + let traits = geo.lookup_traits(endpoint.addr().ip()); + if let Some(country) = traits.country { + countries.insert(country); + } + if let Some(asn) = traits.asn { + asns.insert(asn); + } + } + + let mut countries = countries.into_iter().collect::>(); + countries.sort(); + + let mut asns = asns.into_iter().collect::>(); + asns.sort_unstable(); + + RecordIndexTags { countries, asns } +} + // --------------------------------------------------------------------------- -// Redis ZSET member wire type +// Redis primary record wire type // --------------------------------------------------------------------------- -/// One record as persisted in the Redis ZSET (or decoded from it). +/// One record as persisted in the Redis primary record value. /// /// Wire layout (big-endian, contiguous): /// ```text @@ -70,39 +129,24 @@ pub struct StoredRecord { } impl StoredRecord { - pub fn encoding_size(&self) -> usize { - 8 + 32 + 4 + self.dns.len() + 4 + self.cert.len() - } - - /// Encode to a byte buffer suitable for use as a Redis ZSET member. + /// Encode to a byte buffer suitable for use as a Redis primary record value. pub fn encode(&self) -> Vec { - let mut buf = Vec::with_capacity(self.encoding_size()); - buf.put_stored_record(self); + let mut buf = Vec::with_capacity(8 + 32 + 4 + self.dns.len() + 4 + self.cert.len()); + buf.put_u64(self.expire_unix_secs); + buf.put_slice(&self.fingerprint); + buf.put_u32(self.dns.len() as u32); + buf.put_slice(&self.dns); + buf.put_u32(self.cert.len() as u32); + buf.put_slice(&self.cert); buf } - /// Decode from a Redis ZSET member. Returns `None` on malformed input. + /// Decode from a Redis primary record value. Returns `None` on malformed input. pub fn decode(data: &[u8]) -> Option { be_stored_record(data).ok().map(|(_, r)| r) } } -/// `BufMut` write extension for [`StoredRecord`]. -pub trait WriteStoredRecord { - fn put_stored_record(&mut self, record: &StoredRecord); -} - -impl WriteStoredRecord for B { - fn put_stored_record(&mut self, record: &StoredRecord) { - self.put_u64(record.expire_unix_secs); - self.put_slice(&record.fingerprint); - self.put_u32(record.dns.len() as u32); - self.put_slice(&record.dns); - self.put_u32(record.cert.len() as u32); - self.put_slice(&record.cert); - } -} - /// nom parser for [`StoredRecord`]. pub fn be_stored_record(input: &[u8]) -> IResult<&[u8], StoredRecord> { let (input, expire_unix_secs) = be_u64(input)?; @@ -133,8 +177,135 @@ pub struct Record { pub cert_bytes: Vec, /// Wall-clock expiry (for TTL eviction). pub expire: Instant, - /// When this record was last published (for newest-first ordering). - pub published_at: Instant, + /// Precomputed country / ASN buckets used by the Lite indexes. + pub index_tags: RecordIndexTags, +} + +#[derive(Clone, Debug, Default)] +pub struct HostRecords { + pub records: HashMap<[u8; 32], Record>, + pub recent: Vec<[u8; 32]>, + pub by_country: HashMap>, + pub by_asn: HashMap>, +} + +impl HostRecords { + fn remove_fingerprint(list: &mut Vec<[u8; 32]>, fingerprint: &[u8; 32]) { + list.retain(|existing| existing != fingerprint); + } + + fn remove_from_indexes(&mut self, fingerprint: &[u8; 32], tags: &RecordIndexTags) { + Self::remove_fingerprint(&mut self.recent, fingerprint); + + for country in &tags.countries { + let should_remove = if let Some(bucket) = self.by_country.get_mut(country) { + Self::remove_fingerprint(bucket, fingerprint); + bucket.is_empty() + } else { + false + }; + + if should_remove { + self.by_country.remove(country); + } + } + + for asn in &tags.asns { + let should_remove = if let Some(bucket) = self.by_asn.get_mut(asn) { + Self::remove_fingerprint(bucket, fingerprint); + bucket.is_empty() + } else { + false + }; + + if should_remove { + self.by_asn.remove(asn); + } + } + } + + pub fn insert(&mut self, fingerprint: [u8; 32], record: Record) { + if let Some(old_record) = self.records.remove(&fingerprint) { + self.remove_from_indexes(&fingerprint, &old_record.index_tags); + } + + self.recent.insert(0, fingerprint); + + for country in &record.index_tags.countries { + self.by_country + .entry(country.clone()) + .or_default() + .insert(0, fingerprint); + } + + for asn in &record.index_tags.asns { + self.by_asn.entry(*asn).or_default().insert(0, fingerprint); + } + + self.records.insert(fingerprint, record); + } + + pub fn remove(&mut self, fingerprint: &[u8; 32]) -> Option { + let record = self.records.remove(fingerprint)?; + self.remove_from_indexes(fingerprint, &record.index_tags); + Some(record) + } + + pub fn retain_active(&mut self, now: Instant) { + let expired = self + .records + .iter() + .filter_map(|(fingerprint, record)| (record.expire <= now).then_some(*fingerprint)) + .collect::>(); + + for fingerprint in expired { + let _ = self.remove(&fingerprint); + } + } + + pub fn collect_candidates( + &self, + source_country: Option<&str>, + source_asn: Option, + total_cap: usize, + asn_cap: usize, + country_cap: usize, + all_cap: usize, + ) -> Vec<[u8; 32]> { + let mut candidates = Vec::new(); + let mut seen = HashSet::new(); + + let mut push_bucket = |bucket: Option<&Vec<[u8; 32]>>, bucket_cap: usize| { + let Some(bucket) = bucket else { + return; + }; + + for fingerprint in bucket.iter().take(bucket_cap) { + if candidates.len() >= total_cap { + break; + } + + if seen.insert(*fingerprint) { + candidates.push(*fingerprint); + } + } + }; + + if let Some(asn) = source_asn { + push_bucket(self.by_asn.get(&asn), asn_cap.min(total_cap)); + } + + if let Some(country) = source_country { + push_bucket(self.by_country.get(country), country_cap.min(total_cap)); + } + + push_bucket(Some(&self.recent), all_cap.min(total_cap)); + candidates + } + + pub fn is_empty(&self) -> bool { + self.records.is_empty() + } } /// Unified in-memory storage: host → { cert_fingerprint → Record }. @@ -150,7 +321,7 @@ pub struct Record { /// - Clients query get all valid records and choose which one to use #[derive(Clone)] pub struct MemoryStorage { - pub records: Arc>>, + pub records: Arc>, } impl MemoryStorage { @@ -181,4 +352,52 @@ pub struct AppState { pub ttl_secs: u64, pub policies: Arc, pub seed_records: SeedRecords, + pub geo: Option>, +} + +#[cfg(test)] +mod tests { + use super::*; + + fn fp(seed: u8) -> [u8; 32] { + [seed; 32] + } + + fn record(country: Option<&str>, asn: Option) -> Record { + Record { + dns_bytes: Vec::new(), + cert_bytes: Vec::new(), + expire: Instant::now() + tokio::time::Duration::from_secs(60), + index_tags: RecordIndexTags { + countries: country.into_iter().map(str::to_owned).collect(), + asns: asn.into_iter().collect(), + }, + } + } + + #[test] + fn host_records_collect_candidates_prefers_asn_then_country_then_recent() { + let mut host = HostRecords::default(); + host.insert(fp(1), record(Some("US"), Some(64512))); + host.insert(fp(2), record(Some("US"), None)); + host.insert(fp(3), record(Some("JP"), None)); + + let candidates = host.collect_candidates(Some("US"), Some(64512), 8, 2, 2, 8); + + assert_eq!(candidates, vec![fp(1), fp(2), fp(3)]); + } + + #[test] + fn host_records_remove_cleans_secondary_indexes() { + let mut host = HostRecords::default(); + let fingerprint = fp(9); + host.insert(fingerprint, record(Some("US"), Some(64512))); + + let _ = host.remove(&fingerprint); + + assert!(host.recent.is_empty()); + assert!(host.by_country.is_empty()); + assert!(host.by_asn.is_empty()); + assert!(host.records.is_empty()); + } } From ce3d1ced0127fe793d95823a35f4cdeeeb9314ca Mon Sep 17 00:00:00 2001 From: metah3m Date: Mon, 8 Jun 2026 14:20:31 +0800 Subject: [PATCH 03/29] feat: implement OCSP auto-refresh for TLS certificate validation and update configuration options --- .cargo/config.toml | 6 + Cargo.toml | 13 +- README.md | 7 + examples/README.md | 5 + server.toml | 10 + src/bin/ddns-server/config.rs | 9 + src/bin/ddns-server/lookup.rs | 26 +- src/bin/ddns-server/main.rs | 32 +- src/bin/ddns-server/ocsp.rs | 726 +++++++++++++++++++++++++++++++++ src/bin/ddns-server/publish.rs | 10 +- src/bin/ddns-server/storage.rs | 2 +- 11 files changed, 827 insertions(+), 19 deletions(-) create mode 100644 .cargo/config.toml create mode 100644 src/bin/ddns-server/ocsp.rs diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..af284e1 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,6 @@ +[env] +DHTTP_H3_DNS_SERVER = "https://dns.genmeta.net:4433/" +DHTTP_HTTP_DNS_SERVER = "https://dns.genmeta.net/" +DHTTP_MDNS_SERVICE = "_dhttp.local" +DHTTP_STUN_SERVER = "stun.genmeta.net:20002" +DHTTP_ROOT_CA = { value = "intermediate/intermediate.crt", relative = true } diff --git a/Cargo.toml b/Cargo.toml index a6cf4b0..53f0a60 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,8 +18,9 @@ base64 = "0.22" bitfield-struct = "0.13" bytes = "1" dashmap = "6" -dhttp-identity = "0.1.0" -dquic = "0.5.1" +der = { version = "0.8.0", optional = true } +dhttp-identity = { git = "https://github.com/genmeta/dhttp.git", branch = "main" } +dquic = { git = "https://github.com/genmeta/dquic.git", branch = "feat/v0.5.1" } flume = "0.12" futures = "0.3" libc = "0.2" @@ -32,6 +33,7 @@ rustls = { version = "0.23", default-features = false, features = [ ] } rustls-native-certs = { version = "0.8", optional = true } rustls-pemfile = "2" +sha1 = { version = "=0.11.0-rc.5", optional = true } snafu = "0.9" socket2 = { version = "0.6", features = ["all"] } tokio = { version = "1", features = [ @@ -44,7 +46,7 @@ tokio = { version = "1", features = [ "io-util", ] } tracing = "0.1" -x509-parser = "0.18" +x509-parser = { version = "0.18", features = ["verify"] } h3x = { version = "0.3.1", default-features = false, optional = true } http = { version = "1", optional = true } @@ -70,6 +72,7 @@ tower-service = { version = "0.3", optional = true } tracing-subscriber = { version = "0.3", features = [ "env-filter", ], optional = true } +x509-cert = { version = "=0.3.0-rc.4", optional = true } [features] default = [] @@ -87,13 +90,17 @@ http-resolver = ["dep:reqwest", "dep:rustls-native-certs"] server = [ "h3x-resolver", "dep:clap", + "dep:der", "dep:deadpool-redis", "dep:idna", "dep:maxminddb", + "dep:reqwest", "dep:serde", + "dep:sha1", "dep:toml", "dep:tower-service", "dep:tracing-subscriber", + "dep:x509-cert", ] [dev-dependencies] diff --git a/README.md b/README.md index 35be4e3..de665a2 100644 --- a/README.md +++ b/README.md @@ -132,6 +132,13 @@ Start the server with the `server` feature: cargo run --bin ddns-server --features server -- --config server.toml ``` +When the configured TLS certificate includes its issuer certificate, `ddns-server` +now pulls its own stapled OCSP response from cert-server's public `POST /ocsp` +responder during startup and refreshes it every 2h55m. If the PEM only contains +the leaf certificate, set `ocsp_issuer_cert` in [server.toml](server.toml). You +can override the responder origin with `ocsp_responder_base_url`; by default it +uses `https://license.genmeta.net`. + The server can optionally enable GEO-aware lookup ordering with local MaxMind GeoLite2 City and ASN databases. When both `geoip_city_db` and `geoip_asn_db` are configured, lookups prefer same-country and same-ASN endpoints first, then diff --git a/examples/README.md b/examples/README.md index c1e5780..da3faa4 100644 --- a/examples/README.md +++ b/examples/README.md @@ -62,6 +62,11 @@ repeated count times: The example prints each DNS packet, the publisher certificate fingerprint when a certificate is present, and endpoint signature verification status for signed `E` records. +After the server starts, it listens for HTTP/3 requests and handles publish and query operations. +If the configured server certificate includes its issuer chain, the process also +fetches and refreshes its own stapled OCSP response from cert-server's public +`/ocsp` endpoint. When the PEM only contains the leaf certificate, configure +`ocsp_issuer_cert` in `server.toml`. ## DNS-over-H3 publish diff --git a/server.toml b/server.toml index 7546e91..ee1733c 100644 --- a/server.toml +++ b/server.toml @@ -14,6 +14,16 @@ key = "~/Downloads/ssl/dns.genmeta.net/dns.genmeta.net.key" # Root CA that signed the client certificates (PEM format). root_cert = "~/Downloads/ssl/root.crt" +# Optional issuer certificate used to build OCSP requests when `cert` only +# contains the leaf certificate. If `cert` already contains the full chain, +# this can be omitted. +# ocsp_issuer_cert = "~/Downloads/ssl/intermediate.crt" + +# Public cert-server OCSP responder. The server refreshes its stapled OCSP +# response immediately on startup and then every 2h55m (cert-server TTL is 3h, +# refreshed 5 minutes early). +# ocsp_responder_base_url = "https://license.genmeta.net" + # Whether to require a valid DNS record signature on Standard domains. require_signature = true diff --git a/src/bin/ddns-server/config.rs b/src/bin/ddns-server/config.rs index 1804f18..3580ef9 100644 --- a/src/bin/ddns-server/config.rs +++ b/src/bin/ddns-server/config.rs @@ -54,6 +54,14 @@ pub struct Config { #[serde(default = "Config::default_root_cert")] pub root_cert: PathBuf, + /// Optional issuer certificate used for OCSP requests when `cert` does not include a chain. + #[serde(default)] + pub ocsp_issuer_cert: Option, + + /// Optional OCSP responder base URL. Defaults to the cert-server public responder. + #[serde(default)] + pub ocsp_responder_base_url: Option, + /// Whether to require DNS record signatures on Standard domains. #[serde(default = "Config::default_require_signature")] pub require_signature: bool, @@ -84,6 +92,7 @@ impl Config { self.cert = expand_home_dir(&self.cert); self.key = expand_home_dir(&self.key); self.root_cert = expand_home_dir(&self.root_cert); + self.ocsp_issuer_cert = self.ocsp_issuer_cert.map(|path| expand_home_dir(&path)); self.geoip_city_db = self.geoip_city_db.map(|path| expand_home_dir(&path)); self.geoip_asn_db = self.geoip_asn_db.map(|path| expand_home_dir(&path)); self diff --git a/src/bin/ddns-server/lookup.rs b/src/bin/ddns-server/lookup.rs index 50e3fa3..6171860 100644 --- a/src/bin/ddns-server/lookup.rs +++ b/src/bin/ddns-server/lookup.rs @@ -27,8 +27,8 @@ use crate::{ error::{AppError, normalize_host, parse_query_params}, geo::{GeoResolver, GeoTraits}, storage::{ - AppState, LookupRecord, Storage, StoredRecord, redis_all_index_key, - redis_asn_index_key, redis_country_index_key, redis_primary_key, unix_now_secs, + AppState, LookupRecord, Storage, StoredRecord, redis_all_index_key, redis_asn_index_key, + redis_country_index_key, redis_primary_key, unix_now_secs, }, }; @@ -240,9 +240,8 @@ fn candidate_total_cap(limit: Option) -> usize { } fn all_candidate_cap(total_cap: usize, source_traits: Option<&GeoTraits>) -> usize { - let has_geo_buckets = source_traits.is_some_and(|traits| { - traits.asn.is_some() || traits.country.as_deref().is_some() - }); + let has_geo_buckets = source_traits + .is_some_and(|traits| traits.asn.is_some() || traits.country.as_deref().is_some()); if has_geo_buckets { LOOKUP_CANDIDATE_CAP_ALL.min(total_cap) @@ -430,7 +429,11 @@ async fn perform_lookup_multi( .unwrap_or(()); let all_members: Vec = conn - .zrevrange(&all_index_key, 0isize, candidate_all.saturating_sub(1) as isize) + .zrevrange( + &all_index_key, + 0isize, + candidate_all.saturating_sub(1) as isize, + ) .await .map_err(|e| AppError::Redis { message: e.to_string(), @@ -446,11 +449,10 @@ async fn perform_lookup_multi( let mut records = Vec::new(); for fingerprint in candidate_fingerprints { let primary_key = redis_primary_key(host, &fingerprint); - let member: Option> = conn.get(&primary_key).await.map_err(|e| { - AppError::Redis { + let member: Option> = + conn.get(&primary_key).await.map_err(|e| AppError::Redis { message: e.to_string(), - } - })?; + })?; let Some(member) = member else { continue; @@ -470,7 +472,9 @@ async fn perform_lookup_multi( if let Some(mut entry) = mem.records.get_mut(host) { entry.retain_active(now); let candidate_fingerprints = entry.collect_candidates( - source_traits.as_ref().and_then(|traits| traits.country.as_deref()), + source_traits + .as_ref() + .and_then(|traits| traits.country.as_deref()), source_traits.as_ref().and_then(|traits| traits.asn), candidate_total, LOOKUP_CANDIDATE_CAP_ASN, diff --git a/src/bin/ddns-server/main.rs b/src/bin/ddns-server/main.rs index 1bb1578..ca6c7c4 100644 --- a/src/bin/ddns-server/main.rs +++ b/src/bin/ddns-server/main.rs @@ -2,6 +2,7 @@ mod config; mod error; mod geo; mod lookup; +mod ocsp; mod policy; mod publish; mod storage; @@ -293,6 +294,25 @@ async fn main() -> Result<(), Box> { .bind(Arc::new(config.binds.clone())) .build() .await; + match ocsp::OcspAutoRefresh::from_config(&config, &cert_pem, &root_ca_pem) { + Ok(ocsp_refresh) => { + info!( + responder_url = %ocsp_refresh.responder_url(), + refresh_in_secs = ocsp::refresh_success_delay().as_secs(), + "ocsp.auto_refresh.enabled" + ); + let mut ocsp_quic = quic.clone(); + let initial_delay = ocsp_refresh.refresh_once(&mut ocsp_quic).await; + info!( + next_refresh_in_secs = initial_delay.as_secs(), + "ocsp.auto_refresh.initialized" + ); + tokio::spawn(ocsp_refresh.run(ocsp_quic)); + } + Err(error) => { + warn!(error = %error, "ocsp.auto_refresh.disabled"); + } + } let server = Arc::new(H3Endpoint::new(quic)); info!(binds = ?config.binds, server_name = %config.server_name, "h3_server.start"); server.listen_owned(router).await?; @@ -315,6 +335,8 @@ mod tests { cert: Config::default_cert(), key: Config::default_key(), root_cert: Config::default_root_cert(), + ocsp_issuer_cert: None, + ocsp_responder_base_url: None, require_signature: Config::default_require_signature(), ttl_secs: Config::default_ttl_secs(), domain_policies: Vec::new(), @@ -341,7 +363,10 @@ mod tests { let err = build_geo_resolver(&config).expect_err("missing city db should fail"); assert_eq!(err.kind(), io::ErrorKind::InvalidInput); - assert_eq!(err.to_string(), "geoip_city_db and geoip_asn_db must be configured together"); + assert_eq!( + err.to_string(), + "geoip_city_db and geoip_asn_db must be configured together" + ); } #[test] @@ -352,6 +377,9 @@ mod tests { let err = build_geo_resolver(&config).expect_err("missing asn db should fail"); assert_eq!(err.kind(), io::ErrorKind::InvalidInput); - assert_eq!(err.to_string(), "geoip_city_db and geoip_asn_db must be configured together"); + assert_eq!( + err.to_string(), + "geoip_city_db and geoip_asn_db must be configured together" + ); } } diff --git a/src/bin/ddns-server/ocsp.rs b/src/bin/ddns-server/ocsp.rs new file mode 100644 index 0000000..2ce8969 --- /dev/null +++ b/src/bin/ddns-server/ocsp.rs @@ -0,0 +1,726 @@ +use std::{io, path::Path, time::Duration}; + +use der::{ + Choice, Decode, Encode, Enumerated, Sequence, + asn1::{Any, GeneralizedTime, Null, ObjectIdentifier, OctetString}, + oid::db::{rfc5912::ID_SHA_1, rfc6960::ID_PKIX_OCSP_BASIC}, +}; +use h3x::dquic::QuicEndpoint; +use reqwest::{ + Url, + header::{ACCEPT, CONTENT_TYPE}, +}; +use rustls::pki_types::{CertificateDer, UnixTime}; +use sha1::{Digest, Sha1}; +use tokio::time::sleep; +use tracing::{info, warn}; +use x509_cert::{ + Certificate, ext::Extensions, serial_number::SerialNumber, spki::AlgorithmIdentifierOwned, +}; +use x509_parser::{ + asn1_rs::{BitString as X509BitString, FromDer as X509FromDer}, + parse_x509_certificate, + verify::verify_signature as verify_x509_signature, + x509::AlgorithmIdentifier as X509AlgorithmIdentifier, +}; + +use crate::config::Config; + +pub const DEFAULT_OCSP_RESPONDER_BASE_URL: &str = "https://license.genmeta.net"; +pub const OCSP_STAPLING_TTL: Duration = Duration::from_secs(3 * 60 * 60); +pub const OCSP_REFRESH_EXPIRY_SKEW: Duration = Duration::from_secs(5 * 60); +pub const OCSP_REFRESH_RETRY_DELAY: Duration = Duration::from_secs(5 * 60); + +pub struct OcspAutoRefresh { + responder_url: String, + http_client: reqwest::Client, + request_der: Vec, + leaf_der: CertificateDer<'static>, + issuer_der: CertificateDer<'static>, +} + +impl OcspAutoRefresh { + pub fn from_config(config: &Config, cert_pem: &[u8], root_cert_pem: &[u8]) -> io::Result { + let base_url = config + .ocsp_responder_base_url + .as_deref() + .unwrap_or(DEFAULT_OCSP_RESPONDER_BASE_URL); + let responder_url = normalize_base_url(base_url)?; + let (request_der, leaf_der, issuer_der) = + build_ocsp_request_context(cert_pem, config.ocsp_issuer_cert.as_deref())?; + let root_cert = reqwest::Certificate::from_pem(root_cert_pem) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; + let http_client = reqwest::Client::builder() + .add_root_certificate(root_cert) + .timeout(Duration::from_secs(15)) + .build() + .map_err(io::Error::other)?; + + Ok(Self { + responder_url, + http_client, + request_der, + leaf_der, + issuer_der, + }) + } + + pub fn responder_url(&self) -> &str { + &self.responder_url + } + + pub async fn refresh_once(&self, quic: &mut QuicEndpoint) -> Duration { + match self.fetch_response().await { + Ok(response_der) => match self.validate_response(&response_der) { + Ok(OcspCertStatus::Good) => { + let response_len = response_der.len(); + quic.update_ocsp(Some(response_der)); + info!( + responder_url = %self.responder_url, + response_len, + refresh_in_secs = refresh_success_delay().as_secs(), + "ocsp.staple_refreshed" + ); + refresh_success_delay() + } + Ok(OcspCertStatus::Unknown) => { + warn!( + responder_url = %self.responder_url, + retry_in_secs = OCSP_REFRESH_RETRY_DELAY.as_secs(), + "ocsp response status is unknown; skipping staple update" + ); + OCSP_REFRESH_RETRY_DELAY + } + Ok(OcspCertStatus::Revoked) => { + warn!( + responder_url = %self.responder_url, + retry_in_secs = OCSP_REFRESH_RETRY_DELAY.as_secs(), + "ocsp response status is revoked; skipping staple update" + ); + OCSP_REFRESH_RETRY_DELAY + } + Err(ValidateError::ResponderStatus(OcspResponseStatus::Unauthorized)) => { + warn!( + responder_url = %self.responder_url, + retry_in_secs = OCSP_REFRESH_RETRY_DELAY.as_secs(), + "ocsp responder returned unauthorized; skipping staple update" + ); + OCSP_REFRESH_RETRY_DELAY + } + Err(ValidateError::ResponderStatus(OcspResponseStatus::MalformedRequest)) => { + warn!( + responder_url = %self.responder_url, + retry_in_secs = OCSP_REFRESH_RETRY_DELAY.as_secs(), + "ocsp responder returned malformed_request; skipping staple update" + ); + OCSP_REFRESH_RETRY_DELAY + } + Err(ValidateError::ResponderStatus(status)) => { + warn!( + responder_url = %self.responder_url, + ocsp_status = %status.as_str(), + retry_in_secs = OCSP_REFRESH_RETRY_DELAY.as_secs(), + "ocsp responder returned a non-success status; skipping staple update" + ); + OCSP_REFRESH_RETRY_DELAY + } + Err(error) => { + warn!( + error = %error, + responder_url = %self.responder_url, + retry_in_secs = OCSP_REFRESH_RETRY_DELAY.as_secs(), + "ocsp response validation failed; skipping staple update" + ); + OCSP_REFRESH_RETRY_DELAY + } + }, + Err(error) => { + warn!( + error = %error, + responder_url = %self.responder_url, + retry_in_secs = OCSP_REFRESH_RETRY_DELAY.as_secs(), + "ocsp.refresh_failed" + ); + OCSP_REFRESH_RETRY_DELAY + } + } + } + + pub async fn run(self, mut quic: QuicEndpoint) { + loop { + let delay = self.refresh_once(&mut quic).await; + sleep(delay).await; + } + } + + async fn fetch_response(&self) -> io::Result> { + let response = self + .http_client + .post(&self.responder_url) + .header(CONTENT_TYPE, "application/ocsp-request") + .header(ACCEPT, "application/ocsp-response") + .body(self.request_der.clone()) + .send() + .await + .map_err(request_error)?; + + if !response.status().is_success() { + let status = response.status(); + let message = response.text().await.unwrap_or_default(); + return Err(io::Error::other(format!( + "OCSP responder returned HTTP status {status}: {message}" + ))); + } + + let body = response.bytes().await.map_err(request_error)?; + if body.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "OCSP responder returned an empty body", + )); + } + + Ok(body.to_vec()) + } + + fn validate_response(&self, response_der: &[u8]) -> Result { + verify_stapled_ocsp_response(&self.leaf_der, &self.issuer_der, response_der, now()) + } +} + +pub fn refresh_success_delay() -> Duration { + OCSP_STAPLING_TTL.saturating_sub(OCSP_REFRESH_EXPIRY_SKEW) +} + +fn build_ocsp_request_context( + cert_pem: &[u8], + issuer_override: Option<&Path>, +) -> io::Result<(Vec, CertificateDer<'static>, CertificateDer<'static>)> { + let chain = load_pem_certificates(cert_pem)?; + let leaf_der = chain.first().cloned().ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "server certificate PEM does not contain a certificate", + ) + })?; + let issuer_der = match chain.get(1) { + Some(issuer) => issuer.clone(), + None => load_issuer_certificate(issuer_override)?, + }; + + let leaf = Certificate::from_der(leaf_der.as_ref()) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; + let issuer = Certificate::from_der(issuer_der.as_ref()) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; + let request_der = build_request_der(&leaf, &issuer).map_err(io::Error::other)?; + + Ok((request_der, leaf_der, issuer_der)) +} + +fn load_issuer_certificate(issuer_override: Option<&Path>) -> io::Result> { + let issuer_path = issuer_override.ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "OCSP auto-refresh requires the server cert PEM to include the issuer cert or ocsp_issuer_cert to be configured", + ) + })?; + let issuer_pem = std::fs::read(issuer_path)?; + let issuer_chain = load_pem_certificates(&issuer_pem)?; + issuer_chain.into_iter().next().ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "ocsp_issuer_cert does not contain a certificate", + ) + }) +} + +fn load_pem_certificates(cert_pem: &[u8]) -> io::Result>> { + let mut reader = std::io::Cursor::new(cert_pem); + rustls_pemfile::certs(&mut reader) + .collect::, _>>() + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error)) +} + +fn normalize_base_url(base_url: &str) -> io::Result { + let parsed = + Url::parse(base_url).map_err(|error| io::Error::new(io::ErrorKind::InvalidInput, error))?; + if parsed.scheme() != "https" { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "ocsp_responder_base_url must use https", + )); + } + if parsed.host_str().is_none() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "ocsp_responder_base_url must include a host", + )); + } + + Ok(format!("{}/ocsp", parsed.as_str().trim_end_matches('/'))) +} + +fn request_error(error: reqwest::Error) -> io::Error { + io::Error::other(format!("failed to query OCSP responder: {error}")) +} + +fn now() -> UnixTime { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default(); + UnixTime::since_unix_epoch(now) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum OcspCertStatus { + Good, + Revoked, + Unknown, +} + +#[derive(Debug)] +enum ValidateError { + ResponderStatus(OcspResponseStatus), + Invalid(String), +} + +impl std::fmt::Display for ValidateError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ResponderStatus(status) => { + write!(f, "OCSP responder returned status {}", status.as_str()) + } + Self::Invalid(message) => f.write_str(message), + } + } +} + +impl std::error::Error for ValidateError {} + +#[derive(Debug, Clone)] +struct ParsedOcspResponse { + status: OcspCertStatus, + basic: BasicOcspResponse, +} + +fn verify_stapled_ocsp_response( + end_entity: &CertificateDer<'_>, + issuer: &CertificateDer<'_>, + response_der: &[u8], + now: UnixTime, +) -> Result { + let end_entity_cert = Certificate::from_der(end_entity.as_ref()).map_err(|error| { + ValidateError::Invalid(format!("failed to decode end-entity cert: {error}")) + })?; + let issuer_cert = Certificate::from_der(issuer.as_ref()).map_err(|error| { + ValidateError::Invalid(format!("failed to decode issuer cert: {error}")) + })?; + let parsed = decode_unvalidated_ocsp_response_der(response_der, now)?; + let single = parsed + .basic + .tbs_response_data + .responses + .first() + .expect("single response checked during OCSP decode"); + let expected_cert_id = build_cert_id_local(&end_entity_cert, &issuer_cert)?; + + if !matches_cert_id(&single.cert_id, &expected_cert_id) { + return Err(ValidateError::Invalid( + "OCSP response cert_id does not match the server certificate".to_owned(), + )); + } + + if !responder_id_matches_certificate( + &parsed.basic.tbs_response_data.responder_id, + &issuer_cert, + )? { + return Err(ValidateError::Invalid( + "OCSP responder identifier does not match the issuer certificate".to_owned(), + )); + } + + verify_basic_ocsp_signature(&parsed.basic, issuer.as_ref())?; + + Ok(parsed.status) +} + +fn decode_unvalidated_ocsp_response_der( + response_der: &[u8], + now: UnixTime, +) -> Result { + let response = OcspResponse::from_der(response_der).map_err(der_error_string)?; + if response.response_status != OcspResponseStatus::Successful { + return Err(ValidateError::ResponderStatus(response.response_status)); + } + + let response_bytes = response.response_bytes.ok_or_else(|| { + ValidateError::Invalid("OCSP response is missing response bytes".to_owned()) + })?; + if response_bytes.response_type != ID_PKIX_OCSP_BASIC { + return Err(ValidateError::Invalid( + "unsupported OCSP response type".to_owned(), + )); + } + + let basic = BasicOcspResponse::from_der(response_bytes.response.as_bytes()) + .map_err(der_error_string)?; + let [single] = basic.tbs_response_data.responses.as_slice() else { + return Err(ValidateError::Invalid( + "OCSP response must contain exactly one single response".to_owned(), + )); + }; + + let produced_at = as_unix_time(&basic.tbs_response_data.produced_at); + if produced_at.as_secs() > now.as_secs() { + return Err(ValidateError::Invalid( + "OCSP response produced_at is in the future".to_owned(), + )); + } + + let this_update = as_unix_time(&single.this_update); + if this_update.as_secs() > now.as_secs() { + return Err(ValidateError::Invalid( + "OCSP response this_update is in the future".to_owned(), + )); + } + + let valid_until = single + .next_update + .as_ref() + .map(as_unix_time) + .unwrap_or(this_update); + if valid_until.as_secs() < this_update.as_secs() { + return Err(ValidateError::Invalid( + "OCSP response next_update is earlier than this_update".to_owned(), + )); + } + if valid_until.as_secs() < now.as_secs() { + return Err(ValidateError::Invalid( + "OCSP response is already expired".to_owned(), + )); + } + + let status = match &single.cert_status { + CertStatus::Good(_) => OcspCertStatus::Good, + CertStatus::Revoked(_) => OcspCertStatus::Revoked, + CertStatus::Unknown(_) => OcspCertStatus::Unknown, + }; + + Ok(ParsedOcspResponse { status, basic }) +} + +fn verify_basic_ocsp_signature( + basic: &BasicOcspResponse, + signer_der: &[u8], +) -> Result<(), ValidateError> { + let signer = parse_x509_certificate_der(signer_der, "OCSP signer certificate")?; + let signature_algorithm_der = basic + .signature_algorithm + .to_der() + .map_err(der_error_string)?; + let (_, signature_algorithm) = X509AlgorithmIdentifier::from_der(&signature_algorithm_der) + .map_err(|error| { + ValidateError::Invalid(format!( + "failed to parse OCSP response signature algorithm: {error}" + )) + })?; + let signature_der = basic.signature.to_der().map_err(der_error_string)?; + let (_, signature_value) = X509BitString::from_der(&signature_der).map_err(|error| { + ValidateError::Invalid(format!( + "failed to parse OCSP response signature value: {error}" + )) + })?; + let tbs_der = basic.tbs_response_data.to_der().map_err(der_error_string)?; + + verify_x509_signature( + signer.public_key(), + &signature_algorithm, + &signature_value, + &tbs_der, + ) + .map_err(|error| { + ValidateError::Invalid(format!("failed to verify OCSP response signature: {error}")) + }) +} + +fn parse_x509_certificate_der<'a>( + cert_der: &'a [u8], + label: &str, +) -> Result, ValidateError> { + parse_x509_certificate(cert_der) + .map(|(_, cert)| cert) + .map_err(|error| ValidateError::Invalid(format!("failed to parse {label}: {error:?}"))) +} + +fn responder_id_matches_certificate( + responder_id: &ResponderId, + certificate: &Certificate, +) -> Result { + match responder_id { + ResponderId::ByName(name) => Ok(name.to_der().map_err(der_error_string)? + == certificate + .tbs_certificate() + .subject() + .to_der() + .map_err(der_error_string)?), + ResponderId::ByKey(key_hash) => Ok(key_hash.as_bytes() + == Sha1::digest( + certificate + .tbs_certificate() + .subject_public_key_info() + .subject_public_key + .raw_bytes(), + ) + .as_slice()), + } +} + +fn build_cert_id_local( + end_entity: &Certificate, + issuer: &Certificate, +) -> Result { + let issuer_name_hash = Sha1::digest( + issuer + .tbs_certificate() + .subject() + .to_der() + .map_err(der_error_string)?, + ); + let issuer_key_hash = Sha1::digest( + issuer + .tbs_certificate() + .subject_public_key_info() + .subject_public_key + .raw_bytes(), + ); + + Ok(CertId { + hash_algorithm: AlgorithmIdentifierOwned { + oid: ID_SHA_1, + parameters: Some(Null.into()), + }, + issuer_name_hash: OctetString::new(issuer_name_hash.as_slice()) + .map_err(der_error_string)?, + issuer_key_hash: OctetString::new(issuer_key_hash.as_slice()).map_err(der_error_string)?, + serial_number: end_entity.tbs_certificate().serial_number().clone(), + }) +} + +fn matches_cert_id(actual: &CertId, expected: &CertId) -> bool { + actual.hash_algorithm.oid == expected.hash_algorithm.oid + && actual.issuer_name_hash == expected.issuer_name_hash + && actual.issuer_key_hash == expected.issuer_key_hash + && actual.serial_number == expected.serial_number +} + +fn build_request_der(end_entity: &Certificate, issuer: &Certificate) -> Result, String> { + OcspRequest { + tbs_request: TbsRequest { + version: Version::default(), + requestor_name: None, + request_list: vec![RequestEntry { + req_cert: build_request_cert_id(end_entity, issuer)?, + single_request_extensions: None, + }], + request_extensions: None, + }, + optional_signature: None, + } + .to_der() + .map_err(der_error) +} + +fn build_request_cert_id(end_entity: &Certificate, issuer: &Certificate) -> Result { + let issuer_name_hash = Sha1::digest( + issuer + .tbs_certificate() + .subject() + .to_der() + .map_err(der_error)?, + ); + let issuer_key_hash = Sha1::digest( + issuer + .tbs_certificate() + .subject_public_key_info() + .subject_public_key + .raw_bytes(), + ); + + Ok(CertId { + hash_algorithm: AlgorithmIdentifierOwned { + oid: ID_SHA_1, + parameters: Some(Null.into()), + }, + issuer_name_hash: OctetString::new(issuer_name_hash.as_slice()).map_err(der_error)?, + issuer_key_hash: OctetString::new(issuer_key_hash.as_slice()).map_err(der_error)?, + serial_number: end_entity.tbs_certificate().serial_number().clone(), + }) +} + +fn as_unix_time(time: &GeneralizedTime) -> UnixTime { + UnixTime::since_unix_epoch(time.to_unix_duration()) +} + +fn der_error(error: impl std::fmt::Display) -> String { + format!("failed to process OCSP DER: {error}") +} + +fn der_error_string(error: impl std::fmt::Display) -> ValidateError { + ValidateError::Invalid(der_error(error)) +} + +#[derive(Clone, Debug, Default, Copy, PartialEq, Eq, Enumerated)] +#[asn1(type = "INTEGER")] +#[repr(u8)] +enum Version { + #[default] + V1 = 0, +} + +#[derive(Clone, Debug, Eq, PartialEq, Sequence)] +struct OcspRequest { + tbs_request: TbsRequest, + + #[asn1(context_specific = "0", optional = "true", tag_mode = "EXPLICIT")] + optional_signature: Option, +} + +#[derive(Clone, Debug, Eq, PartialEq, Sequence)] +struct TbsRequest { + #[asn1( + context_specific = "0", + default = "Default::default", + tag_mode = "EXPLICIT" + )] + version: Version, + + #[asn1(context_specific = "1", optional = "true", tag_mode = "EXPLICIT")] + requestor_name: Option, + + request_list: Vec, + + #[asn1(context_specific = "2", optional = "true", tag_mode = "EXPLICIT")] + request_extensions: Option, +} + +#[derive(Clone, Debug, Eq, PartialEq, Sequence)] +struct RequestEntry { + req_cert: CertId, + + #[asn1(context_specific = "0", optional = "true", tag_mode = "EXPLICIT")] + single_request_extensions: Option, +} + +#[derive(Clone, Debug, Eq, PartialEq, Sequence)] +struct CertId { + hash_algorithm: AlgorithmIdentifierOwned, + issuer_name_hash: OctetString, + issuer_key_hash: OctetString, + serial_number: SerialNumber, +} + +#[derive(Clone, Debug, Eq, PartialEq, Choice)] +enum CertStatus { + #[asn1(context_specific = "0", tag_mode = "IMPLICIT")] + Good(Null), + + #[asn1(context_specific = "1", tag_mode = "IMPLICIT", constructed = "true")] + Revoked(RevokedInfo), + + #[asn1(context_specific = "2", tag_mode = "IMPLICIT")] + Unknown(Null), +} + +#[derive(Clone, Debug, Eq, PartialEq, Sequence)] +struct RevokedInfo { + revocation_time: GeneralizedTime, + + #[asn1(context_specific = "0", optional = "true", tag_mode = "EXPLICIT")] + revocation_reason: Option, +} + +#[derive(Clone, Debug, Eq, PartialEq, Sequence)] +struct SingleResponse { + cert_id: CertId, + cert_status: CertStatus, + this_update: GeneralizedTime, + + #[asn1(context_specific = "0", optional = "true", tag_mode = "EXPLICIT")] + next_update: Option, + + #[asn1(context_specific = "1", optional = "true", tag_mode = "EXPLICIT")] + single_extensions: Option, +} + +#[derive(Clone, Debug, Eq, PartialEq, Choice)] +enum ResponderId { + #[asn1(context_specific = "1", tag_mode = "EXPLICIT", constructed = "true")] + ByName(Any), + + #[asn1(context_specific = "2", tag_mode = "EXPLICIT", constructed = "true")] + ByKey(OctetString), +} + +#[derive(Clone, Debug, Eq, PartialEq, Sequence)] +struct ResponseData { + #[asn1( + context_specific = "0", + default = "Default::default", + tag_mode = "EXPLICIT" + )] + version: Version, + responder_id: ResponderId, + produced_at: GeneralizedTime, + responses: Vec, + + #[asn1(context_specific = "1", optional = "true", tag_mode = "EXPLICIT")] + response_extensions: Option, +} + +#[derive(Clone, Debug, Eq, PartialEq, Sequence)] +struct BasicOcspResponse { + tbs_response_data: ResponseData, + signature_algorithm: AlgorithmIdentifierOwned, + signature: der::asn1::BitString, + + #[asn1(context_specific = "0", optional = "true", tag_mode = "EXPLICIT")] + certs: Option>, +} + +#[derive(Clone, Debug, Eq, PartialEq, Sequence)] +struct ResponseBytes { + response_type: ObjectIdentifier, + response: OctetString, +} + +#[derive(Enumerated, Copy, Clone, Debug, Eq, PartialEq)] +#[repr(u32)] +enum OcspResponseStatus { + Successful = 0, + MalformedRequest = 1, + InternalError = 2, + TryLater = 3, + SigRequired = 5, + Unauthorized = 6, +} + +impl OcspResponseStatus { + fn as_str(self) -> &'static str { + match self { + Self::Successful => "successful", + Self::MalformedRequest => "malformed_request", + Self::InternalError => "internal_error", + Self::TryLater => "try_later", + Self::SigRequired => "sig_required", + Self::Unauthorized => "unauthorized", + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq, Sequence)] +struct OcspResponse { + response_status: OcspResponseStatus, + + #[asn1(context_specific = "0", optional = "true", tag_mode = "EXPLICIT")] + response_bytes: Option, +} diff --git a/src/bin/ddns-server/publish.rs b/src/bin/ddns-server/publish.rs index b9cbf22..860807e 100644 --- a/src/bin/ddns-server/publish.rs +++ b/src/bin/ddns-server/publish.rs @@ -234,7 +234,10 @@ pub async fn publish_record( for country in &index_tags.countries { let key = redis_country_index_key(host, country); touched_index_keys.insert(key.clone()); - if let Err(e) = conn.zadd::<_, _, _, ()>(&key, &fp_hex, now_secs as f64).await { + if let Err(e) = conn + .zadd::<_, _, _, ()>(&key, &fp_hex, now_secs as f64) + .await + { return write_error(AppError::Redis { message: e.to_string(), }); @@ -244,7 +247,10 @@ pub async fn publish_record( for asn in &index_tags.asns { let key = redis_asn_index_key(host, *asn); touched_index_keys.insert(key.clone()); - if let Err(e) = conn.zadd::<_, _, _, ()>(&key, &fp_hex, now_secs as f64).await { + if let Err(e) = conn + .zadd::<_, _, _, ()>(&key, &fp_hex, now_secs as f64) + .await + { return write_error(AppError::Redis { message: e.to_string(), }); diff --git a/src/bin/ddns-server/storage.rs b/src/bin/ddns-server/storage.rs index 5308db7..3e5864b 100644 --- a/src/bin/ddns-server/storage.rs +++ b/src/bin/ddns-server/storage.rs @@ -6,8 +6,8 @@ use std::{ use bytes::BufMut; use dashmap::DashMap; -use deadpool_redis::Pool; use ddns::core::parser::{packet::be_packet, record::RData}; +use deadpool_redis::Pool; use nom::{ IResult, bytes::streaming::take, From 25f0b23d1210c9da932591b619c90a1f1c249bb4 Mon Sep 17 00:00:00 2001 From: metah3m Date: Thu, 11 Jun 2026 16:30:19 +0800 Subject: [PATCH 04/29] feat: check if it is on the blacklist during lookup --- .gitignore | 3 +- Cargo.toml | 3 - README.md | 1 + docs/redis-contract.md | 398 ++++++++++++++ examples/README.md | 2 + examples/publish.rs | 50 +- examples/query.rs | 44 +- server.toml | 16 +- src/bin/ddns-server/config.rs | 4 + src/bin/ddns-server/lookup.rs | 181 ++++++- src/bin/ddns-server/main.rs | 27 +- src/bin/ddns-server/policy.rs | 249 ++------- src/bin/ddns-server/publish.rs | 57 +- src/bin/ddns-server/storage.rs | 148 +++++- src/core.rs | 1 + src/core/parser/record/endpoint.rs | 284 ++-------- src/core/parser/sigin.rs | 105 ++++ src/core/signature.rs | 322 ++++++++++++ src/core/wire.rs | 94 +++- src/publisher.rs | 815 +++++++++++++++++++---------- src/resolvers.rs | 4 +- src/resolvers/h3.rs | 128 +++-- src/resolvers/http.rs | 128 ++++- 23 files changed, 2163 insertions(+), 901 deletions(-) create mode 100644 docs/redis-contract.md create mode 100644 src/core/signature.rs diff --git a/.gitignore b/.gitignore index 4a376e8..855760d 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ build .DS_Store .vscode/ /geoip -/docs +/docs/superpowers +/certs diff --git a/Cargo.toml b/Cargo.toml index 53f0a60..67d4577 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -136,6 +136,3 @@ required-features = ["h3x-resolver"] [patch.crates-io] proc-macro-error2 = { path = "patches/proc-macro-error2" } - -[patch."https://github.com/genmeta/h3x.git"] -h3x = { path = "../h3x-endpoint-local" } diff --git a/README.md b/README.md index de665a2..7b081d2 100644 --- a/README.md +++ b/README.md @@ -251,3 +251,4 @@ src/bin/ddns-server/ DNS-over-H3 server implementation examples/ mDNS and DNS-over-H3 example programs server.toml Example server configuration ``` + diff --git a/docs/redis-contract.md b/docs/redis-contract.md new file mode 100644 index 0000000..bcf5028 --- /dev/null +++ b/docs/redis-contract.md @@ -0,0 +1,398 @@ +# gmdns Redis 存储说明 + +这份文档描述 `ddns-server` 在 Redis 里实际会存什么、怎么存、这些数据分别 +是干什么用的。它优先写给人看,同时保留足够的细节,方便别的服务对接。 + +如果你只想先看结论,这个系统在 Redis 里只会用到 3 类原生数据结构: + +1. `String`:存一条发布记录的完整二进制内容 +2. `Sorted Set`:存查询用的倒排索引 +3. `Set`:存黑名单域名 + +没有使用 `Hash`、`List`、`Stream`、`Bitmap` 之类的其他 Redis 结构。 + +## 1. 总览 + +`ddns-server` 自己维护的 Redis key 一共 4 种形态,加上 1 个外部可写黑名单: + +| Key 形态 | Redis 类型 | 作用 | +| --- | --- | --- | +| `:fp:` | `String` | 某个 host 下,某个证书指纹对应的一条完整发布记录 | +| `:idx:all` | `Sorted Set` | 这个 host 的全部活动记录索引 | +| `:idx:country:` | `Sorted Set` | 这个 host 按国家分桶的活动记录索引 | +| `:idx:asn:` | `Sorted Set` | 这个 host 按 ASN 分桶的活动记录索引 | +| `ddns:blacklist` | `Set` | 被封禁的 host 列表 | + +其中: + +- 主记录 `String` 是事实来源,真正的记录内容只在这里 +- 3 类 `Sorted Set` 都是派生索引,只是为了加速查询 +- 黑名单 `Set` 是一个独立控制面数据,不参与记录存储 + +## 2. Host 规范化规则 + +Redis 里的 host 名必须先做规范化。代码实现见 +[`src/bin/ddns-server/error.rs`](/Users/lixiaofeng/code/gmdns/src/bin/ddns-server/error.rs) 的 +`normalize_host()`。 + +规则如下: + +1. 去掉首尾空白 +2. 不能为空 +3. 不能包含 `*` +4. 如果最后一个 `:` 后面全是数字,就当成端口号去掉 +5. 去掉结尾的一个 `.` +6. 用 IDNA 转成 ASCII +7. 转成小写 +8. 最终结果必须以 `genmeta.net` 结尾 + +例子: + +- `DNS.Genmeta.Net.` -> `dns.genmeta.net` +- `dns.genmeta.net:4433` -> `dns.genmeta.net` +- `blocked.example.genmeta.net` -> `blocked.example.genmeta.net` + +这条规则对所有 Redis key 都重要,尤其是黑名单成员必须写规范化之后的 host。 + +## 3. 各类 Redis 数据结构 + +### 3.1 主记录 + +Key 形式: + +```text +:fp: +``` + +例子: + +```text +nat.genmeta.net:fp:db6905c72be9aa8b1a61f7d45dd399d64136da17ac384ef67f1f5670055a2946 +``` + +Redis 类型: + +```text +String +``` + +值的含义: + +- 存的是一个二进制 `StoredRecord` +- 里面包含这条记录的完整 DNS 包、发布者证书、签名字段、过期时间等 + +TTL: + +- 通过 `SETEX` / `SET EX` 写入 +- TTL 等于服务配置里的 `ttl_secs` + +业务语义: + +- 同一个 `host` 下,同一个证书指纹只能有 1 条活动记录 +- 同一个证书再次发布,会覆盖自己之前的记录 +- 同一个 `host` 下,不同证书指纹可以并存 + +可以把它理解成: + +```text +一个 host 下,以“证书指纹”作为主键的记录表 +``` + +### 3.2 全量索引 + +Key 形式: + +```text +:idx:all +``` + +例子: + +```text +nat.genmeta.net:idx:all +``` + +Redis 类型: + +```text +Sorted Set +``` + +成员和值: + +- member: `` +- score: 发布时间的 Unix 秒时间戳,代码里按 `f64` 写入 + +TTL: + +- 每次写入相关记录时,会给这个索引 key 重新设置 `ttl_secs` + +业务语义: + +- 表示这个 host 当前有哪些候选发布者记录 +- 查询时,如果 GEO 定向索引不够用,会回退到这个索引 +- 返回顺序是最新发布的在前面,因为读取时用的是 `ZREVRANGE` + +### 3.3 国家索引 + +Key 形式: + +```text +:idx:country: +``` + +例子: + +```text +nat.genmeta.net:idx:country:CN +``` + +Redis 类型: + +```text +Sorted Set +``` + +成员和值: + +- member: `` +- score: 发布时间的 Unix 秒时间戳 + +`` 从哪里来: + +- 发布时解析 DNS 包里的 endpoint IP +- 对这些 IP 做 GEO 查询 +- 把得到的国家代码去重、排序后写入索引 + +业务语义: + +- 这是按国家分桶的候选记录索引 +- 查询时,如果请求方的源 IP 能解析出国家,会先尝试这个桶 + +### 3.4 ASN 索引 + +Key 形式: + +```text +:idx:asn: +``` + +例子: + +```text +nat.genmeta.net:idx:asn:4134 +``` + +Redis 类型: + +```text +Sorted Set +``` + +成员和值: + +- member: `` +- score: 发布时间的 Unix 秒时间戳 + +`` 从哪里来: + +- 和国家索引一样,也是从发布内容里的 endpoint IP 做 GEO 解析得到 + +业务语义: + +- 这是按 ASN 分桶的候选记录索引 +- 查询时,如果请求方的源 IP 能解析出 ASN,会最先尝试这个桶 + +### 3.5 黑名单集合 + +Key: + +```text +ddns:blacklist +``` + +Redis 类型: + +```text +Set +``` + +成员格式: + +- 规范化之后的小写 ASCII host 名 + +例子: + +```text +blocked.example.genmeta.net +``` + +业务语义: + +- 查询开始时先查这个集合 +- 如果 `SISMEMBER ddns:blacklist ` 为真,直接返回 `404 Not Found` +- 黑名单只拦截查询,不拦截 publish,也不拦截 clear +- 黑名单不会删除已有记录 + +常用操作: + +```bash +redis-cli SADD ddns:blacklist blocked.example.genmeta.net +redis-cli SREM ddns:blacklist blocked.example.genmeta.net +``` + +## 4. 主记录里到底存了什么 + +主记录 value 不是 JSON,也不是 Hash,而是一段连续的二进制。 + +顺序如下: + +```text +u64 expire_unix_secs +u8 fingerprint[32] +u32 content_digest_len +u8 content_digest[content_digest_len] +u32 signature_input_len +u8 signature_input[signature_input_len] +u32 signature_len +u8 signature[signature_len] +u32 dns_len +u8 dns[dns_len] +u32 cert_len +u8 cert[cert_len] +``` + +字段说明: + +| 字段 | 含义 | +| --- | --- | +| `expire_unix_secs` | 这条记录的业务过期时间,Unix 秒 | +| `fingerprint` | 发布者叶子证书的 SHA-256 原始 32 字节,不是 hex 字符串 | +| `content_digest` | HTTP 签名里的 `Content-Digest` 原始字节 | +| `signature_input` | HTTP 签名里的 `Signature-Input` 原始字节 | +| `signature` | HTTP 签名里的 `Signature` 原始字节 | +| `dns` | 序列化后的 DNS 包体 | +| `cert` | 发布者叶子证书的 DER 字节 | + +补充说明: + +- 使用大端序 +- 没有版本号字段 +- 三个签名字段都允许为空 +- 如果记录没有签名,这三个字段长度就是 `0` + +## 5. 写入时怎么维护这些结构 + +### 5.1 Publish + +发布 `(host, fingerprint)` 时,流程是: + +1. 读取旧的主记录 +2. 如果旧记录能解码出来,就从旧记录推导出旧的国家 / ASN 标签 +3. 先把旧指纹从所有相关索引里删掉 +4. 用 `SETEX` 写入新的主记录 +5. 把指纹加入: + - `:idx:all` + - 若干 `:idx:country:` + - 若干 `:idx:asn:` +6. 给所有碰到的索引 key 重新设置 TTL +7. 对这些索引执行: + +```text +ZREMRANGEBYSCORE -inf +``` + +这样做的效果是: + +- 主记录会自然过期 +- 索引里过旧的 member 也会被顺手清掉 +- 同一个证书重复发布,不会在索引里留下重复脏数据 + +### 5.2 Clear + +清理 `(host, fingerprint)` 时,流程是: + +1. 读取旧主记录 +2. 从旧主记录推导出它所在的国家 / ASN 桶 +3. 把这个指纹从所有相关索引删掉 +4. 删除主记录 key + +### 5.3 一致性和自愈预期 + +这里的写入不是事务性的。 + +也就是说,一次 publish / clear 会改多个 key,但这些操作不是用单个 Redis +事务原子提交的。如果中途失败,短时间内可能出现下面这些情况: + +- 主记录已经更新,但部分索引还没更新 +- 索引里还留着旧指纹,但主记录已经不存在 +- 某些 GEO 桶暂时缺少一条本该存在的记录 + +这个设计对上述短暂不一致是接受的,原因有两个: + +1. 查询时真正可信的数据源始终是主记录 `String`,索引只是候选入口 +2. 节点会大约每 30 秒重新上报一次,同一条记录会被持续刷新 + +这意味着: + +- 如果索引里残留了一个已经失效的指纹,查询阶段读取不到主记录时会直接跳过 +- 如果一次写入导致某个索引短暂漏写,下一次节点上报通常会把它补回来 +- 即使没有专门的数据修复流程,TTL 和周期性重上报也会让大多数临时偏差自然收敛 + +因此,这套存储模型的目标是: + +- 接受短暂的不一致 +- 依赖 30 秒级的周期刷新实现轻量自愈 +- 不为了少量短时脏索引引入额外复杂的数据修复机制 + +## 6. 查询时怎么用这些结构 + +当 Redis 存储启用时,查询流程是: + +1. 规范化请求里的 host +2. 先查 `ddns:blacklist` +3. 按顺序收集候选指纹: + - 先 ASN 索引 + - 再国家索引 + - 最后全量索引 +4. 按这个顺序去重 +5. 逐个读取 `:fp:` 主记录 +6. 丢弃解码失败或业务上已经过期的记录 +7. 把剩下的记录交给现有排序逻辑继续处理 + +这里最重要的认识是: + +- `Sorted Set` 只是“候选名单” +- 真正可信的数据源始终是主记录 `String` + +## 7. 归属边界 + +`ddns-server` 自己维护下面这些 key: + +- `:fp:` +- `:idx:all` +- `:idx:country:` +- `:idx:asn:` + +外部服务如果只是想做黑名单联动,只应该写: + +- `ddns:blacklist` + +如果外部服务想直接写记录,就必须完整实现: + +- 主记录二进制编码 +- 所有派生索引的增删 +- TTL 维护 +- 过期索引清理 + +否则很容易把 Redis 里的记录和索引写乱。 + +## 8. 一句话理解 + +这个 Redis 模型本质上是: + +- 用一个 `String` 保存“完整记录” +- 用几个 `Sorted Set` 保存“按 host / 国家 / ASN 分类的候选索引” +- 用一个 `Set` 保存“是否禁止查询这个 host” + +真正的记录内容不在索引里,索引只是为了更快找到应该读哪个主记录。 diff --git a/examples/README.md b/examples/README.md index da3faa4..9430392 100644 --- a/examples/README.md +++ b/examples/README.md @@ -62,6 +62,7 @@ repeated count times: The example prints each DNS packet, the publisher certificate fingerprint when a certificate is present, and endpoint signature verification status for signed `E` records. + After the server starts, it listens for HTTP/3 requests and handles publish and query operations. If the configured server certificate includes its issuer chain, the process also fetches and refreshes its own stapled OCSP response from cert-server's public @@ -112,3 +113,4 @@ cargo run --bin ddns-server --features server -- --config server.toml `server.toml` documents the available fields: listener, TLS identity, client root CA, optional Redis storage, TTL, domain policies, and static seed records. + diff --git a/examples/publish.rs b/examples/publish.rs index 629fe9f..8abb1b0 100644 --- a/examples/publish.rs +++ b/examples/publish.rs @@ -7,7 +7,7 @@ use std::{ use clap::Parser; use ddns::{ - core::parser::record::endpoint::EndpointAddr, + core::{parser::record::endpoint::EndpointAddr, signature::SignatureFields}, resolvers::{DHTTP_H3_DNS_SERVER, h3::H3Publisher}, }; use h3x::dquic::{ @@ -42,7 +42,7 @@ struct Options { #[arg(long)] client_key: PathBuf, - /// Sign Endpoint records using the client private key. + /// Sign DNS packets using HTTP signature fields and the client private key. /// /// This must correspond to the client certificate presented in mTLS, because the server /// verifies the signature with the peer certificate's SPKI. @@ -56,6 +56,12 @@ struct Options { /// 要发布的地址列表。 #[arg(long, value_delimiter = ',', num_args = 1..)] addr: Vec, + + #[arg(long, default_value_t = true)] + is_main: bool, + + #[arg(long, default_value_t = 1)] + sequence: u64, } fn default_h3_base_url() -> String { @@ -136,38 +142,38 @@ async fn main() -> io::Result<()> { info!(host = %opt.host, addrs = ?opt.addr, base_url = %opt.base_url, "publish.start"); if opt.sign { - info!("publish.endpoint_signing.enabled"); + info!("publish.packet_signing.enabled"); } else { - info!("publish.endpoint_signing.disabled"); + info!("publish.packet_signing.disabled"); } - let selector = identity - .dhttp_subject_key_identifier() - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - let chain = selector.chain(); for &addr in &opt.addr { - info!("creating endpoint for address: {}", addr); + info!("Creating endpoint for address: {}", addr); let mut endpoint = match addr { SocketAddr::V4(v4) => EndpointAddr::direct_v4(v4), SocketAddr::V6(v6) => EndpointAddr::direct_v6(v6), }; - endpoint.set_certificate_chain_key(chain); + endpoint.set_main(opt.is_main); + endpoint.set_sequence(opt.sequence); + info!("Publishing endpoint: {:?}", endpoint); + let mut hosts = std::collections::HashMap::new(); + hosts.insert(opt.host.clone(), vec![endpoint]); + let packet = ddns::core::MdnsPacket::answer(0, &hosts).to_bytes(); if opt.sign { - info!("signing endpoint"); - endpoint - .sign_with_authority(identity.as_ref()) + info!("signing dns packet"); + let signature_fields = SignatureFields::sign(&packet, identity.as_ref()) .await .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + resolver + .publish_signed(&opt.host, &packet, &signature_fields) + .await?; + } else { + resolver + .publish(&opt.host, &packet) + .await + .map_err(io::Error::other)?; } - info!("publishing endpoint: {:?}", endpoint); - let mut hosts = std::collections::HashMap::new(); - hosts.insert(opt.host.clone(), vec![endpoint]); - let packet = ddns::core::MdnsPacket::answer(0, &hosts).to_bytes(); - resolver - .publish(&opt.host, &packet) - .await - .map_err(io::Error::other)?; - info!("successfully published endpoint for {}", addr); + info!("Successfully published endpoint for {}", addr); } info!("publish.ok"); diff --git a/examples/query.rs b/examples/query.rs index 80235cd..c86f553 100644 --- a/examples/query.rs +++ b/examples/query.rs @@ -74,7 +74,7 @@ fn format_packet(packet: &MdnsPacket) -> String { RData::E(ep) => { output.push_str(&format!("Name: {}\nAddress: {}\n", rr.name(), ep)); if ep.is_signed() { - output.push_str("Signature: present\n"); + output.push_str("Legacy E signature: present\n"); } } _ => { @@ -145,12 +145,22 @@ async fn main() -> Result<(), Box> { if resp.status().is_success() { let bytes = resp.into_body().collect().await?.to_bytes(); - let (_remain, multi) = be_multi_response(bytes.as_ref()).map_err(|e| { + let (remain, multi) = be_multi_response(bytes.as_ref()).map_err(|e| { io::Error::new( io::ErrorKind::InvalidData, format!("Invalid multi-record payload: {e}"), ) })?; + if !remain.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "Invalid multi-record payload: {} trailing bytes", + remain.len() + ), + ) + .into()); + } info!(count = multi.records.len(), "lookup.ok"); println!("Lookup Result: {} record(s)", multi.records.len()); @@ -163,27 +173,21 @@ async fn main() -> Result<(), Box> { None => println!("Source fingerprint: (no certificate)"), } + if record.signature_fields.is_empty() { + println!("Packet signature: none"); + } else if record.cert.is_empty() { + println!("Packet signature: present but no certificate to verify against"); + } else { + match record.signature_fields.verify(&record.dns, &record.cert) { + Ok(true) => println!("Packet signature: ✓ verified"), + Ok(false) => println!("Packet signature: ✗ invalid"), + Err(e) => println!("Packet signature: ✗ error ({e:?})"), + } + } + match ddns::core::parser::packet::be_packet(&record.dns) { Ok((_, packet)) => { print!("{}", format_packet(&packet)); - - for rr in &packet.answers { - if let RData::E(ep) = rr.data() { - if !ep.is_signed() { - println!("Signature: none"); - continue; - } - if record.cert.is_empty() { - println!("Signature: present but no certificate to verify against"); - continue; - } - match ep.verify_signature_from_der(&record.cert) { - Ok(true) => println!("Signature: ✓ verified"), - Ok(false) => println!("Signature: ✗ invalid"), - Err(e) => println!("Signature: ✗ error ({e:?})"), - } - } - } } Err(_) => { println!("DNS payload: invalid ({} bytes)", record.dns.len()); diff --git a/server.toml b/server.toml index ee1733c..08ac858 100644 --- a/server.toml +++ b/server.toml @@ -24,7 +24,11 @@ root_cert = "~/Downloads/ssl/root.crt" # refreshed 5 minutes early). # ocsp_responder_base_url = "https://license.genmeta.net" -# Whether to require a valid DNS record signature on Standard domains. +# Whether to require RFC 9421/9530-style packet signatures on Standard domains. +# Signed publish requests keep the DNS packet as the HTTP body and provide: +# Content-Digest: sha-256=:...: +# Signature-Input: dns=("content-digest");created=...;keyid="sha256:";alg="..." +# Signature: dns=:...: require_signature = true # Default TTL (seconds) for published records. @@ -33,6 +37,16 @@ ttl_secs = 30 # Redis URL for persistent storage. # If omitted, records are kept in memory only (lost on restart). # redis = "redis://127.0.0.1/" +# +# When Redis storage is enabled, lookups check the external blacklist set +# "ddns:blacklist". Without Redis, this file can preload an in-memory blacklist. +# Members are normalized lowercase ASCII host names. Blacklisted lookups return +# 404; publish/clear requests are not blocked. +# +# redis-cli SADD ddns:blacklist blocked.example.genmeta.net +# redis-cli SREM ddns:blacklist blocked.example.genmeta.net +# +# blacklist = ["blocked.example.genmeta.net"] # Enable GEO-aware scheduling based on country / ASN. # When both databases are configured, city-distance tie-breaking is also enabled diff --git a/src/bin/ddns-server/config.rs b/src/bin/ddns-server/config.rs index 3580ef9..01ca63d 100644 --- a/src/bin/ddns-server/config.rs +++ b/src/bin/ddns-server/config.rs @@ -74,6 +74,10 @@ pub struct Config { #[serde(default)] pub domain_policies: Vec, + /// In-memory blacklist loaded at startup when Redis storage is not configured. + #[serde(default)] + pub blacklist: Vec, + /// Static seed records returned on lookup in addition to dynamic published records. #[serde(default)] pub seed_records: Vec, diff --git a/src/bin/ddns-server/lookup.rs b/src/bin/ddns-server/lookup.rs index 6171860..37cf02d 100644 --- a/src/bin/ddns-server/lookup.rs +++ b/src/bin/ddns-server/lookup.rs @@ -11,15 +11,10 @@ use std::{ use ddns::core::{ MdnsPacket, parser::{packet::be_packet, record::RData}, - wire::MultiResponse, + wire::{MultiResponse, ResponseRecord}, }; use deadpool_redis::redis::{self, AsyncCommands}; -<<<<<<< HEAD -use h3x::{connection::ConnectionState, quic}; -use h3x::dhttp::message::MessageStreamError; -======= -use h3x::{connection::ConnectionState, message::stream::MessageStreamError, quic}; ->>>>>>> 13e6482 (feat: add GeoIP support for geo-routing) +use h3x::{connection::ConnectionState, dhttp::message::MessageStreamError, quic}; use http_body_util::{Full, combinators::UnsyncBoxBody}; use tracing::debug; @@ -27,8 +22,9 @@ use crate::{ error::{AppError, normalize_host, parse_query_params}, geo::{GeoResolver, GeoTraits}, storage::{ - AppState, LookupRecord, Storage, StoredRecord, redis_all_index_key, redis_asn_index_key, - redis_country_index_key, redis_primary_key, unix_now_secs, + AppState, LookupRecord, MemoryStorage, SeedRecords, Storage, StoredRecord, + redis_all_index_key, redis_asn_index_key, redis_blacklist_key, redis_country_index_key, + redis_primary_key, unix_now_secs, }, }; @@ -68,9 +64,14 @@ fn normalize_lookup_records(records: Vec) -> Vec { let mut normalized = Vec::new(); let mut seen = HashSet::new(); - for (dns_bytes, cert_bytes) in records { - let Ok((_, packet)) = be_packet(&dns_bytes) else { - normalized.push((dns_bytes, cert_bytes)); + for record in records { + if !record.signature_fields.is_empty() { + normalized.push(record); + continue; + } + + let Ok((_, packet)) = be_packet(&record.dns) else { + normalized.push(record); continue; }; @@ -90,11 +91,14 @@ fn normalize_lookup_records(records: Vec) -> Vec { let mut hosts = HashMap::new(); hosts.insert(answer.name().to_string(), vec![endpoint.clone()]); - normalized.push((MdnsPacket::answer(0, &hosts).to_bytes(), cert_bytes.clone())); + normalized.push(ResponseRecord::unsigned( + MdnsPacket::answer(0, &hosts).to_bytes(), + record.cert.clone(), + )); } if !emitted_endpoint { - normalized.push((dns_bytes, cert_bytes)); + normalized.push(record); } } @@ -120,7 +124,7 @@ fn sort_lookup_records(records: Vec, source_ip: Option) -> .into_iter() .enumerate() .map(|(index, record)| { - let sort_key = lookup_endpoint(&record.0).map(|(endpoint, load)| { + let sort_key = lookup_endpoint(&record.dns).map(|(endpoint, load)| { let family_match = source_ip .map(|source| source.is_ipv4() == endpoint.ip().is_ipv4()) .unwrap_or(false); @@ -280,7 +284,7 @@ fn sort_lookup_records_with_geo( .into_iter() .enumerate() .map(|(index, record)| { - let sort_key = lookup_endpoint_geo_traits(&record.0, geo).map( + let sort_key = lookup_endpoint_geo_traits(&record.dns, geo).map( |(endpoint, load, endpoint_traits)| { build_geo_sort_key( source_ip, @@ -353,6 +357,12 @@ async fn perform_lookup_multi( let mut conn = pool.get().await.map_err(|e| AppError::Redis { message: e.to_string(), })?; + + if redis_host_blacklisted(&mut *conn, host).await? { + debug!(host = %host, "lookup.blacklisted"); + return Ok(LookupResult::NotFound); + } + let now_secs = unix_now_secs(); let cutoff_score = now_secs.saturating_sub(state.ttl_secs) as f64; let mut candidate_fingerprints = Vec::new(); @@ -461,13 +471,22 @@ async fn perform_lookup_multi( continue; }; if record.expire_unix_secs > now_secs { - records.push((record.dns, record.cert)); + records.push(ResponseRecord::new( + record.signature_fields, + record.dns, + record.cert, + )); } } records } Storage::Memory(mem) => { + if mem.is_blacklisted(host) { + debug!(host = %host, "lookup.blacklisted"); + return Ok(LookupResult::NotFound); + } + let now = tokio::time::Instant::now(); if let Some(mut entry) = mem.records.get_mut(host) { entry.retain_active(now); @@ -485,10 +504,13 @@ async fn perform_lookup_multi( candidate_fingerprints .into_iter() .filter_map(|fingerprint| { - entry - .records - .get(&fingerprint) - .map(|record| (record.dns_bytes.clone(), record.cert_bytes.clone())) + entry.records.get(&fingerprint).map(|record| { + ResponseRecord::new( + record.signature_fields.clone(), + record.dns_bytes.clone(), + record.cert_bytes.clone(), + ) + }) }) .collect::>() } else { @@ -528,6 +550,17 @@ async fn perform_lookup_multi( } } +async fn redis_host_blacklisted(conn: &mut C, host: &str) -> Result +where + C: redis::aio::ConnectionLike + Send + Sync, +{ + conn.sismember(redis_blacklist_key(), host) + .await + .map_err(|e| AppError::Redis { + message: e.to_string(), + }) +} + // --------------------------------------------------------------------------- // HTTP response helpers // --------------------------------------------------------------------------- @@ -620,7 +653,7 @@ mod tests { path::PathBuf, }; - use ddns::core::MdnsEndpoint; + use ddns::core::{MdnsEndpoint, signature::SignatureFields}; use super::*; use crate::geo::{GeoPoint, GeoResolver}; @@ -643,7 +676,101 @@ mod tests { let mut hosts = HashMap::new(); hosts.insert(host.to_string(), vec![endpoint]); - (MdnsPacket::answer(0, &hosts).to_bytes(), Vec::new()) + ResponseRecord::unsigned(MdnsPacket::answer(0, &hosts).to_bytes(), Vec::new()) + } + + struct FakeRedis { + response: redis::Value, + packed_commands: Vec>, + } + + impl redis::aio::ConnectionLike for FakeRedis { + fn req_packed_command<'a>( + &'a mut self, + cmd: &'a redis::Cmd, + ) -> redis::RedisFuture<'a, redis::Value> { + self.packed_commands.push(cmd.get_packed_command()); + let response = self.response.clone(); + Box::pin(async move { Ok(response) }) + } + + fn req_packed_commands<'a>( + &'a mut self, + _cmd: &'a redis::Pipeline, + _offset: usize, + _count: usize, + ) -> redis::RedisFuture<'a, Vec> { + Box::pin(async move { Ok(Vec::new()) }) + } + + fn get_db(&self) -> i64 { + 0 + } + } + + #[tokio::test] + async fn redis_host_blacklisted_queries_external_blacklist_set() { + let mut redis = FakeRedis { + response: redis::Value::Int(1), + packed_commands: Vec::new(), + }; + + let blacklisted = redis_host_blacklisted(&mut redis, "blocked.example.genmeta.net") + .await + .unwrap(); + + assert!(blacklisted); + assert_eq!(redis.packed_commands.len(), 1); + let command = String::from_utf8(redis.packed_commands.remove(0)).unwrap(); + assert!(command.contains("SISMEMBER")); + assert!(command.contains(redis_blacklist_key())); + assert!(command.contains("blocked.example.genmeta.net")); + } + + #[tokio::test] + async fn memory_blacklist_returns_not_found_before_seed_records() { + let host = "blocked.example.genmeta.net"; + let mut seed_records = HashMap::new(); + seed_records.insert( + host.to_string(), + vec![lookup_record( + host, + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), + None, + )], + ); + let state = AppState { + storage: Storage::Memory(MemoryStorage::with_blacklist([host.to_string()])), + require_signature: false, + ttl_secs: 30, + policies: Arc::new(crate::policy::DomainPolicies::default()), + seed_records: SeedRecords::new(seed_records), + geo: None, + }; + + let result = perform_lookup(&state, host, None, None).await.unwrap(); + + assert!(matches!(result, LookupResult::NotFound)); + } + + #[test] + fn normalize_lookup_records_keeps_signed_packets_whole() { + let mut record = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), + None, + ); + record.signature_fields = SignatureFields { + content_digest: b"sha-256=:abc:".to_vec(), + signature_input: + b"dns=(\"content-digest\");created=1;keyid=\"sha256:abc\";alg=\"ed25519\"".to_vec(), + signature: b"dns=:sig:".to_vec(), + }; + + let normalized = normalize_lookup_records(vec![record.clone()]); + + assert_eq!(normalized.len(), 1); + assert_eq!(normalized[0], record); } #[test] @@ -745,7 +872,7 @@ mod tests { let sorted = sort_lookup_records_with_geo(vec![non_matching, matching.clone()], source_ip, &geo); - let (endpoint, _) = lookup_endpoint(&sorted[0].0).expect("sorted record should decode"); + let (endpoint, _) = lookup_endpoint(&sorted[0].dns).expect("sorted record should decode"); assert_eq!(endpoint.ip(), IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); } @@ -768,7 +895,7 @@ mod tests { source_ip, ); - let (endpoint, _) = lookup_endpoint(&sorted[0].0).expect("sorted record should decode"); + let (endpoint, _) = lookup_endpoint(&sorted[0].dns).expect("sorted record should decode"); assert_eq!(endpoint.ip(), IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); } @@ -806,7 +933,7 @@ mod tests { let ordered_ips = sorted .iter() .map(|record| { - lookup_endpoint(&record.0) + lookup_endpoint(&record.dns) .expect("record should decode") .0 .ip() @@ -845,7 +972,7 @@ mod tests { &geo, ); - let (endpoint, _) = lookup_endpoint(&sorted[0].0).expect("sorted record should decode"); + let (endpoint, _) = lookup_endpoint(&sorted[0].dns).expect("sorted record should decode"); assert_eq!(endpoint.ip(), IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5))); } diff --git a/src/bin/ddns-server/main.rs b/src/bin/ddns-server/main.rs index ca6c7c4..be7e880 100644 --- a/src/bin/ddns-server/main.rs +++ b/src/bin/ddns-server/main.rs @@ -17,7 +17,7 @@ use std::{ }; use clap::Parser; -use ddns::core::{MdnsEndpoint, MdnsPacket}; +use ddns::core::{MdnsEndpoint, MdnsPacket, wire::ResponseRecord}; use futures::future::BoxFuture; use h3x::{ dquic::{ @@ -125,7 +125,10 @@ fn build_seed_records(seed_records: &[SeedRecordConfig]) -> io::Result Result<(), Box> { let config = config.expand_paths(); let seed_records = build_seed_records(&config.seed_records)?; let geo = build_geo_resolver(&config)?; + let memory_blacklist = config + .blacklist + .iter() + .filter_map(|host| match error::normalize_host(host) { + Ok(host) => Some(host), + Err(error) => { + warn!(host, error = %error, "blacklist.invalid_host_ignored"); + None + } + }) + .collect::>(); // Build storage backend. let storage = match config.redis.clone() { Some(url) => { + if !memory_blacklist.is_empty() { + warn!( + count = memory_blacklist.len(), + "blacklist.config_ignored_when_redis_enabled" + ); + } let redis_cfg = deadpool_redis::Config::from_url(url); let redis_pool = redis_cfg.create_pool(Some(deadpool_redis::Runtime::Tokio1))?; Storage::Redis(redis_pool) } - None => Storage::Memory(MemoryStorage::new()), + None => Storage::Memory(MemoryStorage::with_blacklist(memory_blacklist)), }; // Build domain-policy rules from config file. @@ -340,6 +360,7 @@ mod tests { require_signature: Config::default_require_signature(), ttl_secs: Config::default_ttl_secs(), domain_policies: Vec::new(), + blacklist: Vec::new(), seed_records: Vec::new(), geoip_city_db: None, geoip_asn_db: None, diff --git a/src/bin/ddns-server/policy.rs b/src/bin/ddns-server/policy.rs index 413ad4b..a2f67d0 100644 --- a/src/bin/ddns-server/policy.rs +++ b/src/bin/ddns-server/policy.rs @@ -1,31 +1,21 @@ -use ddns::core::parser::{packet::be_packet, record::RData}; -use dhttp_identity::identity::{RemoteAuthority, RemoteAuthorityCertificateExt}; -use snafu::ResultExt; +use ddns::core::{ + parser::{packet::be_packet, record::RData}, + signature::SignatureFields, +}; +use dhttp_identity::identity::RemoteAuthority; use tracing::{debug, warn}; use crate::error::{AppError, app_error, normalize_host}; -// --------------------------------------------------------------------------- -// Domain policy -// --------------------------------------------------------------------------- - -/// Per-domain publish / lookup behaviour. #[derive(Clone, Debug, PartialEq)] pub enum DomainPolicy { - /// Signature check controlled by `require_signature` flag; single record - /// per host; each publish overwrites the previous one. Standard, - /// No signature check; any authenticated node may publish; multiple records - /// with individual TTLs; ordered newest-first on lookup. OpenMulti, } -/// One rule in the domain-policy list. #[derive(Clone, Debug)] pub enum PolicyRule { - /// Matches only this exact (normalised) host. Exact(String), - /// Matches the host itself or any label-subdomain (future use). #[allow(dead_code)] Suffix(String), } @@ -41,7 +31,6 @@ impl PolicyRule { } } -/// Ordered list of (rule, policy) pairs; first match wins; default is Standard. #[derive(Clone, Debug, Default)] pub struct DomainPolicies(pub Vec<(PolicyRule, DomainPolicy)>); @@ -62,10 +51,6 @@ pub enum ValidatedDnsPacket { Empty, } -// --------------------------------------------------------------------------- -// Certificate helpers -// --------------------------------------------------------------------------- - pub fn extract_client_dns_sans(authority: &(impl RemoteAuthority + ?Sized)) -> Vec { use x509_parser::prelude::*; @@ -109,6 +94,8 @@ pub fn validate_dns_packet( packet: &[u8], require_signature: bool, authority: &(impl RemoteAuthority + ?Sized), + signature_fields: &SignatureFields, + expected_host: &str, ) -> Result { let (remaining, dns_packet) = be_packet(packet).map_err(|e| AppError::InvalidDnsPacket { message: e.to_string(), @@ -118,220 +105,48 @@ pub fn validate_dns_packet( } debug!( answers = dns_packet.answers.len(), - require_signature, "validating dns packet" + require_signature, + "validating dns packet" ); - let Some(first_answer) = dns_packet.answers.first() else { - debug!("dns packet has no answers"); - return Ok(ValidatedDnsPacket::Empty); - }; - - validate_endpoint_selectors(&dns_packet, authority)?; - if require_signature { - let has_signature = dns_packet - .answers - .iter() - .any(|record| matches!(record.data(), RData::E(endpoint) if endpoint.is_signed())); - - if !has_signature { + if signature_fields.is_empty() { return Err(AppError::SignatureRequired); } + let cert = authority + .cert_chain() + .first() + .ok_or(AppError::MissingClientCertificate)?; + let ok = signature_fields + .verify(packet, cert.as_ref()) + .map_err(|_| AppError::InvalidSignature)?; + if !ok { + return Err(AppError::InvalidSignature); + } + for record in &dns_packet.answers { if let RData::E(endpoint) = record.data() && endpoint.is_signed() { - let cert = authority - .cert_chain() - .first() - .ok_or(AppError::MissingClientCertificate)?; - let ok = endpoint - .verify_signature_from_der(cert.as_ref()) - .map_err(|_| AppError::InvalidSignature)?; - if !ok { - return Err(AppError::InvalidSignature); - } + return Err(AppError::InvalidSignature); } } } - Ok(ValidatedDnsPacket::Records { - host: first_answer.name().to_string(), - }) -} - -fn validate_endpoint_selectors( - dns_packet: &ddns::core::parser::packet::Packet, - authority: &(impl RemoteAuthority + ?Sized), -) -> Result<(), AppError> { - let mut endpoints = dns_packet - .answers - .iter() - .filter_map(|record| match record.data() { - RData::E(endpoint) => Some(endpoint), - _ => None, - }); - - let Some(first_endpoint) = endpoints.next() else { - return Ok(()); + let Some(first_answer) = dns_packet.answers.first() else { + debug!("dns packet has no answers"); + return Ok(ValidatedDnsPacket::Empty); }; - let expected = authority - .dhttp_subject_key_identifier() - .context(app_error::PublisherCertificateSelectorSnafu)? - .chain() - .clone(); - - let first = first_endpoint - .certificate_chain_key() - .context(app_error::EndpointRecordSelectorSnafu)?; - if first != expected { - return Err(AppError::EndpointSelectorMismatch); - } - - for endpoint in endpoints { - let actual = endpoint - .certificate_chain_key() - .context(app_error::EndpointRecordSelectorSnafu)?; - if actual != expected { - return Err(AppError::EndpointSelectorMismatch); - } - } - - Ok(()) -} - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - - use ddns::core::{MdnsPacket, parser::record::endpoint::EndpointAddr}; - use dhttp_identity::identity::RemoteAuthority; - use rustls::pki_types::CertificateDer; - - use super::*; - - #[derive(Debug)] - struct TestAuthority { - certs: Vec>, - } - - impl TestAuthority { - fn valid() -> Self { - Self { - certs: vec![CertificateDer::from( - include_bytes!("../../../tests/fixtures/valid.der").to_vec(), - )], - } - } - - fn missing_ski() -> Self { - Self { - certs: vec![CertificateDer::from( - include_bytes!("../../../tests/fixtures/missing.der").to_vec(), - )], - } - } - - fn malformed_ski() -> Self { - Self { - certs: vec![CertificateDer::from( - include_bytes!("../../../tests/fixtures/malformed.der").to_vec(), - )], - } - } - } - - impl RemoteAuthority for TestAuthority { - fn name(&self) -> &str { - "authority.example" + for answer in &dns_packet.answers { + let answer_host = normalize_host(&answer.name())?; + if answer_host != expected_host { + return Err(AppError::HostMismatch); } - - fn cert_chain(&self) -> &[CertificateDer<'static>] { - &self.certs - } - } - - fn packet_with_endpoint(endpoint: EndpointAddr) -> Vec { - let hosts: HashMap> = - HashMap::from([("reimu.pilot.dhttp.net".to_owned(), vec![endpoint])]); - MdnsPacket::answer(0, &hosts).to_bytes() - } - - #[test] - fn validate_dns_packet_accepts_matching_certificate_selector() { - let mut endpoint = EndpointAddr::direct_v4("192.0.2.10:4433".parse().unwrap()); - endpoint.set_main(true); - endpoint.set_sequence(0); - let packet = packet_with_endpoint(endpoint); - - let validated = validate_dns_packet(&packet, false, &TestAuthority::valid()).unwrap(); - - assert!(matches!(validated, ValidatedDnsPacket::Records { .. })); - } - - #[test] - fn validate_dns_packet_rejects_mismatched_endpoint_kind() { - let mut endpoint = EndpointAddr::direct_v4("192.0.2.10:4433".parse().unwrap()); - endpoint.set_main(false); - endpoint.set_sequence(0); - let packet = packet_with_endpoint(endpoint); - - let error = validate_dns_packet(&packet, false, &TestAuthority::valid()).unwrap_err(); - - assert!(matches!(error, AppError::EndpointSelectorMismatch)); } - #[test] - fn validate_dns_packet_rejects_mismatched_endpoint_sequence() { - let mut endpoint = EndpointAddr::direct_v4("192.0.2.10:4433".parse().unwrap()); - endpoint.set_main(true); - endpoint.set_sequence(7); - let packet = packet_with_endpoint(endpoint); - - let error = validate_dns_packet(&packet, false, &TestAuthority::valid()).unwrap_err(); - - assert!(matches!(error, AppError::EndpointSelectorMismatch)); - } - - #[test] - fn validate_dns_packet_rejects_missing_publisher_ski() { - let mut endpoint = EndpointAddr::direct_v4("192.0.2.10:4433".parse().unwrap()); - endpoint.set_main(true); - let packet = packet_with_endpoint(endpoint); - - let error = validate_dns_packet(&packet, false, &TestAuthority::missing_ski()).unwrap_err(); - - assert!(matches!( - error, - AppError::PublisherCertificateSelector { .. } - )); - } - - #[test] - fn validate_dns_packet_rejects_malformed_publisher_ski() { - let mut endpoint = EndpointAddr::direct_v4("192.0.2.10:4433".parse().unwrap()); - endpoint.set_main(true); - let packet = packet_with_endpoint(endpoint); - - let error = - validate_dns_packet(&packet, false, &TestAuthority::malformed_ski()).unwrap_err(); - - assert!(matches!( - error, - AppError::PublisherCertificateSelector { .. } - )); - } - - #[test] - fn validate_dns_packet_accepts_empty_packet_as_clear_operation() { - let hosts: HashMap> = - HashMap::from([("reimu.pilot.dhttp.net".to_owned(), Vec::new())]); - let packet = MdnsPacket::answer(0, &hosts).to_bytes(); - - let validated = validate_dns_packet(&packet, true, &TestAuthority::valid()).unwrap(); - - assert!(matches!(validated, ValidatedDnsPacket::Empty)); - } + Ok(ValidatedDnsPacket::Records { + host: first_answer.name().to_string(), + }) } diff --git a/src/bin/ddns-server/publish.rs b/src/bin/ddns-server/publish.rs index 860807e..c724a16 100644 --- a/src/bin/ddns-server/publish.rs +++ b/src/bin/ddns-server/publish.rs @@ -1,5 +1,8 @@ use std::{collections::HashSet, convert::Infallible, sync::Arc}; +use ddns::core::signature::{ + CONTENT_DIGEST_HEADER, SIGNATURE_HEADER, SIGNATURE_INPUT_HEADER, SignatureFields, +}; use deadpool_redis::redis::{self, AsyncCommands}; use dhttp_identity::identity::RemoteAuthority; use h3x::{connection::ConnectionState, quic}; @@ -91,6 +94,8 @@ async fn publish_with_cert(state: AppState, request: Request) -> Response { } } + let signature_fields = signature_fields_from_headers(request.headers()); + let body = match request.into_body().collect().await { Ok(body) => body.to_bytes(), Err(e) => { @@ -109,7 +114,13 @@ async fn publish_with_cert(state: AppState, request: Request) -> Response { require_signature = require_sig, "validating publish packet" ); - let packet = match validate_dns_packet(body.as_ref(), require_sig, authority.as_ref()) { + let packet = match validate_dns_packet( + body.as_ref(), + require_sig, + authority.as_ref(), + &signature_fields, + &host, + ) { Ok(n) => n, Err(e) => { debug!(host = %host, error = %e, "publish packet rejected"); @@ -128,12 +139,27 @@ async fn publish_with_cert(state: AppState, request: Request) -> Response { return write_error(AppError::HostMismatch); } - publish_record(&state, &host, &body, authority.as_ref()).await + publish_record(&state, &host, &body, authority.as_ref(), signature_fields).await } ValidatedDnsPacket::Empty => clear_record(&state, &host, authority.as_ref()).await, } } +fn signature_fields_from_headers(headers: &http::HeaderMap) -> SignatureFields { + let header = |name: &'static str| { + headers + .get(name) + .map(|value| value.as_bytes().to_vec()) + .unwrap_or_default() + }; + + SignatureFields { + content_digest: header(CONTENT_DIGEST_HEADER), + signature_input: header(SIGNATURE_INPUT_HEADER), + signature: header(SIGNATURE_HEADER), + } +} + fn request_connection(request: &Request) -> Option>> { request .extensions() @@ -157,6 +183,7 @@ pub async fn publish_record( host: &str, body: &bytes::Bytes, authority: &(impl RemoteAuthority + ?Sized), + signature_fields: SignatureFields, ) -> Response { let cert_bytes = authority .cert_chain() @@ -210,6 +237,7 @@ pub async fn publish_record( fingerprint: fp, dns: body.to_vec(), cert: cert_bytes.clone(), + signature_fields: signature_fields.clone(), } .encode(); @@ -275,6 +303,7 @@ pub async fn publish_record( let record = Record { dns_bytes: body.to_vec(), cert_bytes, + signature_fields, expire, index_tags: record_index_tags(body.as_ref(), state.geo.as_deref()), }; @@ -428,15 +457,27 @@ mod tests { let packet_b = packet_for(host, 2); assert_eq!( - publish_record(&state, host, &packet_a, &authority_a) - .await - .status(), + publish_record( + &state, + host, + &packet_a, + &authority_a, + SignatureFields::empty() + ) + .await + .status(), http::StatusCode::OK ); assert_eq!( - publish_record(&state, host, &packet_b, &authority_b) - .await - .status(), + publish_record( + &state, + host, + &packet_b, + &authority_b, + SignatureFields::empty() + ) + .await + .status(), http::StatusCode::OK ); diff --git a/src/bin/ddns-server/storage.rs b/src/bin/ddns-server/storage.rs index 3e5864b..f5bed03 100644 --- a/src/bin/ddns-server/storage.rs +++ b/src/bin/ddns-server/storage.rs @@ -5,8 +5,12 @@ use std::{ }; use bytes::BufMut; -use dashmap::DashMap; -use ddns::core::parser::{packet::be_packet, record::RData}; +use dashmap::{DashMap, DashSet}; +use ddns::core::{ + parser::{packet::be_packet, record::RData}, + signature::SignatureFields, + wire::ResponseRecord, +}; use deadpool_redis::Pool; use nom::{ IResult, @@ -59,6 +63,10 @@ pub fn redis_asn_index_key(host: &str, asn: u32) -> String { format!("{host}:idx:asn:{asn}") } +pub fn redis_blacklist_key() -> &'static str { + "ddns:blacklist" +} + #[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct RecordIndexTags { pub countries: Vec, @@ -108,10 +116,10 @@ pub fn record_index_tags(dns_bytes: &[u8], geo: Option<&GeoResolver>) -> RecordI /// /// Wire layout (big-endian, contiguous): /// ```text -/// +-----------+--------------+-----------+------+-----------+------+ -/// | expire | fingerprint | dns_len | dns | cert_len | cert | -/// | u64 BE | 32 bytes | u32 BE | ... | u32 BE | ... | -/// +-----------+--------------+-----------+------+-----------+------+ +/// +-----------+--------------+---------------+--------+-----------+------+-----------+------+-----------+------+-----------+------+ +/// | expire | fingerprint | digest_len | digest | input_len | input| sig_len | sig | dns_len | dns | cert_len | cert | +/// | u64 BE | 32 bytes | u32 BE | ... | u32 BE | ... | u32 BE | ... | u32 BE | ... | u32 BE | ... | +/// +-----------+--------------+---------------+--------+-----------+------+-----------+------+-----------+------+-----------+------+ /// ``` #[derive(Debug, Clone)] pub struct StoredRecord { @@ -126,24 +134,41 @@ pub struct StoredRecord { pub dns: Vec, /// DER-encoded leaf certificate of the publisher. pub cert: Vec, + /// Saved RFC-style publisher signature fields for the DNS packet. + pub signature_fields: SignatureFields, } impl StoredRecord { /// Encode to a byte buffer suitable for use as a Redis primary record value. pub fn encode(&self) -> Vec { - let mut buf = Vec::with_capacity(8 + 32 + 4 + self.dns.len() + 4 + self.cert.len()); + let mut buf = Vec::with_capacity( + 8 + 32 + + 4 + + self.signature_fields.content_digest.len() + + 4 + + self.signature_fields.signature_input.len() + + 4 + + self.signature_fields.signature.len() + + 4 + + self.dns.len() + + 4 + + self.cert.len(), + ); buf.put_u64(self.expire_unix_secs); buf.put_slice(&self.fingerprint); - buf.put_u32(self.dns.len() as u32); - buf.put_slice(&self.dns); - buf.put_u32(self.cert.len() as u32); - buf.put_slice(&self.cert); + put_field(&mut buf, &self.signature_fields.content_digest); + put_field(&mut buf, &self.signature_fields.signature_input); + put_field(&mut buf, &self.signature_fields.signature); + put_field(&mut buf, &self.dns); + put_field(&mut buf, &self.cert); buf } /// Decode from a Redis primary record value. Returns `None` on malformed input. pub fn decode(data: &[u8]) -> Option { - be_stored_record(data).ok().map(|(_, r)| r) + be_stored_record(data) + .ok() + .and_then(|(remain, r)| remain.is_empty().then_some(r)) } } @@ -151,21 +176,38 @@ impl StoredRecord { pub fn be_stored_record(input: &[u8]) -> IResult<&[u8], StoredRecord> { let (input, expire_unix_secs) = be_u64(input)?; let (input, fp_bytes) = take(32usize)(input)?; - let (input, dns_len) = be_u32(input)?; - let (input, dns) = take(dns_len as usize)(input)?; - let (input, cert_len) = be_u32(input)?; - let (input, cert) = take(cert_len as usize)(input)?; + let (input, content_digest) = be_field(input)?; + let (input, signature_input) = be_field(input)?; + let (input, signature) = be_field(input)?; + let (input, dns) = be_field(input)?; + let (input, cert) = be_field(input)?; Ok(( input, StoredRecord { expire_unix_secs, fingerprint: fp_bytes.try_into().expect("took exactly 32 bytes"), - dns: dns.to_vec(), - cert: cert.to_vec(), + dns, + cert, + signature_fields: SignatureFields { + content_digest, + signature_input, + signature, + }, }, )) } +fn put_field(buf: &mut Vec, value: &[u8]) { + buf.put_u32(value.len() as u32); + buf.put_slice(value); +} + +fn be_field(input: &[u8]) -> IResult<&[u8], Vec> { + let (input, len) = be_u32(input)?; + let (input, value) = take(len as usize)(input)?; + Ok((input, value.to_vec())) +} + // --------------------------------------------------------------------------- // Storage // --------------------------------------------------------------------------- @@ -175,6 +217,7 @@ pub fn be_stored_record(input: &[u8]) -> IResult<&[u8], StoredRecord> { pub struct Record { pub dns_bytes: Vec, pub cert_bytes: Vec, + pub signature_fields: SignatureFields, /// Wall-clock expiry (for TTL eviction). pub expire: Instant, /// Precomputed country / ASN buckets used by the Lite indexes. @@ -322,14 +365,36 @@ impl HostRecords { #[derive(Clone)] pub struct MemoryStorage { pub records: Arc>, + pub blacklist: Arc>, } impl MemoryStorage { pub fn new() -> Self { Self { records: Arc::new(DashMap::new()), + blacklist: Arc::new(DashSet::new()), } } + + pub fn with_blacklist(hosts: impl IntoIterator) -> Self { + let storage = Self::new(); + for host in hosts { + storage.blacklist_host(host); + } + storage + } + + pub fn blacklist_host(&self, host: impl Into) { + self.blacklist.insert(host.into()); + } + + pub fn remove_blacklist_host(&self, host: &str) { + self.blacklist.remove(host); + } + + pub fn is_blacklisted(&self, host: &str) -> bool { + self.blacklist.contains(host) + } } #[derive(Clone)] @@ -338,7 +403,7 @@ pub enum Storage { Memory(MemoryStorage), } -pub type LookupRecord = (Vec, Vec); +pub type LookupRecord = ResponseRecord; pub type SeedRecords = Arc>>; // --------------------------------------------------------------------------- @@ -367,6 +432,7 @@ mod tests { Record { dns_bytes: Vec::new(), cert_bytes: Vec::new(), + signature_fields: SignatureFields::empty(), expire: Instant::now() + tokio::time::Duration::from_secs(60), index_tags: RecordIndexTags { countries: country.into_iter().map(str::to_owned).collect(), @@ -400,4 +466,48 @@ mod tests { assert!(host.by_asn.is_empty()); assert!(host.records.is_empty()); } + + #[test] + fn stored_record_roundtrips_signature_fields() { + let record = StoredRecord { + expire_unix_secs: 123, + fingerprint: fp(7), + dns: vec![1, 2, 3], + cert: vec![4, 5, 6], + signature_fields: SignatureFields { + content_digest: b"sha-256=:abc:".to_vec(), + signature_input: + b"dns=(\"content-digest\");created=1;keyid=\"sha256:abc\";alg=\"ed25519\"" + .to_vec(), + signature: b"dns=:sig:".to_vec(), + }, + }; + + let decoded = StoredRecord::decode(&record.encode()).expect("stored record decodes"); + + assert_eq!(decoded.expire_unix_secs, record.expire_unix_secs); + assert_eq!(decoded.fingerprint, record.fingerprint); + assert_eq!(decoded.dns, record.dns); + assert_eq!(decoded.cert, record.cert); + assert_eq!(decoded.signature_fields, record.signature_fields); + } + + #[test] + fn redis_blacklist_key_is_stable() { + assert_eq!(redis_blacklist_key(), "ddns:blacklist"); + } + + #[test] + fn memory_storage_tracks_blacklisted_hosts() { + let storage = MemoryStorage::with_blacklist(["blocked.example".to_string()]); + + assert!(storage.is_blacklisted("blocked.example")); + assert!(!storage.is_blacklisted("allowed.example")); + + storage.blacklist_host("other.example"); + assert!(storage.is_blacklisted("other.example")); + + storage.remove_blacklist_host("blocked.example"); + assert!(!storage.is_blacklisted("blocked.example")); + } } diff --git a/src/core.rs b/src/core.rs index 308bbb2..936cfd9 100644 --- a/src/core.rs +++ b/src/core.rs @@ -1,4 +1,5 @@ pub mod parser; +pub mod signature; pub mod wire; pub type MdnsEndpoint = parser::record::endpoint::EndpointAddr; diff --git a/src/core/parser/record/endpoint.rs b/src/core/parser/record/endpoint.rs index 75167b5..cb89f02 100644 --- a/src/core/parser/record/endpoint.rs +++ b/src/core/parser/record/endpoint.rs @@ -8,7 +8,6 @@ use std::{ use base64::Engine; use bytes::BufMut; -use dhttp_identity::certificate::{CertificateChainKey, CertificateChainKind, CertificateSequence}; use dquic::qbase::net::addr::EndpointAddr as DquicEndpointAddr; use nom::{ IResult, Parser, @@ -28,19 +27,12 @@ use crate::core::parser::{ #[derive(Debug, Snafu)] #[snafu(module)] pub enum SignEndpointError { - #[snafu(display("failed to determine endpoint signature scheme"))] - SignatureScheme { source: sigin::SignatureSchemeError }, #[snafu(display("failed to sign endpoint address"))] Sign { source: dhttp_identity::identity::SignError, }, -} - -#[derive(Debug, Snafu)] -#[snafu(module)] -pub enum EndpointSelectorError { - #[snafu(display("endpoint record sequence does not fit certificate sequence"))] - SequenceTooLarge { sequence: u64 }, + #[snafu(display("no supported signature scheme for endpoint address"))] + NoSupportedScheme, } /// EndpointAddress record (Type E = 266) @@ -54,7 +46,7 @@ pub enum EndpointSelectorError { /// +-------+-----------------+--------------------+----------------+----------------------------+ /// | flags | sequence(varint)| addr | load(optional) | signature (optional) | /// +-------+-----------------+--------------------+----------------+----------------------------+ -/// | u8 | QUIC varint | see addr layout | f32 | scheme(u16)+len(varint)+N | +/// | u8 | QUIC varint | see addr layout | f32 | scheme(u16)+len(varint)+N | /// +-------+-----------------+--------------------+----------------+----------------------------+ /// /// addr layout: @@ -230,12 +222,17 @@ impl EndpointAddr { ) -> Result<(), SignEndpointError> { self.set_signed(true); let data = self.signed_data(); - let scheme = sigin::signature_scheme(authority.public_key()) - .context(sign_endpoint_error::SignatureSchemeSnafu)?; + + let scheme = authority + .cert_chain() + .first() + .and_then(|_| sigin::canonical_scheme_for_spki(authority.public_key())) + .ok_or(SignEndpointError::NoSupportedScheme)?; let signature = authority .sign(&data) .await .context(sign_endpoint_error::SignSnafu)?; + self.signature = Some(EndpointSignature { scheme: u16::from(scheme), signature, @@ -379,28 +376,6 @@ impl EndpointAddr { } } - pub fn certificate_chain_key(&self) -> Result { - let kind = if self.is_main() { - CertificateChainKind::Primary - } else { - CertificateChainKind::Secondary - }; - let sequence = self.sequence.map(VarInt::into_inner).unwrap_or(0); - if sequence > u64::from(u32::MAX) { - return endpoint_selector_error::SequenceTooLargeSnafu { sequence }.fail(); - } - let sequence = sequence as u32; - Ok(CertificateChainKey::new( - CertificateSequence::from(sequence), - kind, - )) - } - - pub fn set_certificate_chain_key(&mut self, chain: &CertificateChainKey) { - self.set_main(chain.kind() == CertificateChainKind::Primary); - self.set_sequence(u64::from(chain.sequence().get())); - } - pub fn load(&self) -> Option { self.load } @@ -779,6 +754,20 @@ impl TryFrom for DquicEndpointAddr { } } +pub async fn sign_endponit_address( + server_id: u8, + authority: Option<&(impl dhttp_identity::identity::LocalAuthority + ?Sized)>, + endpoint: DquicEndpointAddr, +) -> Option { + let mut ep: EndpointAddr = endpoint.try_into().ok()?; + ep.set_main(server_id == 0); + ep.set_sequence(server_id as u64); + if let Some(authority) = authority { + let _ = ep.sign_with_authority(authority).await; + } + Some(ep) +} + #[cfg(test)] mod tests { use std::{ @@ -787,78 +776,12 @@ mod tests { }; use bytes::BytesMut; - use dhttp_identity::certificate::{ - CertificateChainKey, CertificateChainKind, CertificateSequence, - }; use futures::future::BoxFuture; use ring::signature::KeyPair; - use rustls::{ - SignatureScheme, - sign::{Signer, SigningKey}, - }; + use rustls::sign::{Signer, SigningKey}; use super::*; - fn chain(sequence: u32, kind: CertificateChainKind) -> CertificateChainKey { - CertificateChainKey::new(CertificateSequence::from(sequence), kind) - } - - fn ed25519_spki(public_key: &[u8]) -> Vec { - let mut spki = Vec::with_capacity(44); - spki.extend_from_slice(&[ - 0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00, - ]); - spki.extend_from_slice(public_key); - spki - } - - #[test] - fn endpoint_selector_normalizes_missing_sequence_to_primary_zero() { - let addr = SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 5353); - let mut endpoint = EndpointAddr::direct_v4(addr); - endpoint.set_main(true); - - let selector = endpoint - .certificate_chain_key() - .expect("missing sequence normalizes to selector"); - - assert_eq!(selector, chain(0, CertificateChainKind::Primary)); - } - - #[test] - fn endpoint_selector_normalizes_missing_sequence_to_secondary_zero() { - let addr = SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 2), 5353); - let endpoint = EndpointAddr::direct_v4(addr); - - let selector = endpoint - .certificate_chain_key() - .expect("missing sequence normalizes to selector"); - - assert_eq!(selector, chain(0, CertificateChainKind::Secondary)); - } - - #[test] - fn endpoint_selector_sets_primary_and_secondary_chains() { - let addr = SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 3), 5353); - let mut endpoint = EndpointAddr::direct_v4(addr); - - endpoint.set_certificate_chain_key(&chain(7, CertificateChainKind::Primary)); - assert!(endpoint.is_main()); - assert!(endpoint.is_clustered()); - assert_eq!( - endpoint.certificate_chain_key().unwrap(), - chain(7, CertificateChainKind::Primary) - ); - - endpoint.set_certificate_chain_key(&chain(0, CertificateChainKind::Secondary)); - assert!(!endpoint.is_main()); - assert!(!endpoint.is_clustered()); - assert_eq!( - endpoint.certificate_chain_key().unwrap(), - chain(0, CertificateChainKind::Secondary) - ); - } - #[test] fn legacy_endpoint_v4_direct_without_meta() { let port = 5353u16; @@ -1025,122 +948,12 @@ mod tests { } } - #[test] - fn signed_endpoint_accepts_scheme_inclusive_signature() { - let addr = SocketAddrV4::new(Ipv4Addr::new(10, 10, 0, 7), 20004); - let scheme = u16::from(SignatureScheme::ED25519); - let signature = vec![0xaa; 64]; - let sig_len = VarInt::try_from(signature.len() as u64).unwrap(); - - let mut buf = BytesMut::new(); - buf.put_u8(EndpointAddr::FLAG_SIGNED); - buf.put_socket_addr_v4(&addr); - buf.put_u16(scheme); - buf.put_varint(sig_len); - buf.extend_from_slice(&signature); - - let (remain, decoded) = be_endpoint_addr(&buf).unwrap(); - - assert!(remain.is_empty()); - assert!(decoded.is_signed()); - assert_eq!(decoded.addr(), SocketAddr::V4(addr)); - assert_eq!(decoded.signature().unwrap().signature, signature); - } - - #[test] - fn signed_endpoint_rejects_signature_without_scheme() { - let addr = SocketAddrV4::new(Ipv4Addr::new(10, 10, 0, 7), 20004); - let signature = vec![0xaa; 64]; - let sig_len = VarInt::try_from(signature.len() as u64).unwrap(); - - let mut buf = BytesMut::new(); - buf.put_u8(EndpointAddr::FLAG_SIGNED); - buf.put_socket_addr_v4(&addr); - buf.put_varint(sig_len); - buf.extend_from_slice(&signature); - - assert!(be_endpoint_addr(&buf).is_err()); - } - - #[test] - fn signed_endpoint_writes_actual_scheme_before_signature_length() { - #[derive(Debug)] - struct Ed25519Key { - keypair: Arc, - spki: Vec, - } - - #[derive(Debug)] - struct Ed25519Signer(Arc); - - impl Signer for Ed25519Signer { - fn sign(&self, message: &[u8]) -> Result, rustls::Error> { - Ok(self.0.sign(message).as_ref().to_vec()) - } - - fn scheme(&self) -> SignatureScheme { - SignatureScheme::ED25519 - } - } - - impl SigningKey for Ed25519Key { - fn choose_scheme(&self, offered: &[SignatureScheme]) -> Option> { - offered - .contains(&SignatureScheme::ED25519) - .then(|| Box::new(Ed25519Signer(self.keypair.clone())) as Box) - } - - fn algorithm(&self) -> rustls::SignatureAlgorithm { - rustls::SignatureAlgorithm::ED25519 - } - } - - impl dhttp_identity::identity::LocalAuthority for Ed25519Key { - fn name(&self) -> &str { - "authority.example" - } - - fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { - &[] - } - - fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { - SubjectPublicKeyInfoDer::from(self.spki.as_slice()) - } - - fn sign( - &self, - data: &[u8], - ) -> BoxFuture<'_, Result, dhttp_identity::identity::SignError>> { - let result = dhttp_identity::identity::sign_with_key(self, data); - Box::pin(std::future::ready(result)) - } - } - - let rng = ring::rand::SystemRandom::new(); - let pkcs8 = ring::signature::Ed25519KeyPair::generate_pkcs8(&rng).unwrap(); - let keypair = - Arc::new(ring::signature::Ed25519KeyPair::from_pkcs8(pkcs8.as_ref()).unwrap()); - let spki = ed25519_spki(keypair.public_key().as_ref()); - let key = Ed25519Key { keypair, spki }; - - let mut ep = EndpointAddr::direct_v4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 5353)); - futures::executor::block_on(ep.sign_with_authority(&key)).unwrap(); - - let mut buf = BytesMut::new(); - buf.put_endpoint_addr(&ep); - - let scheme_offset = 1 + 2 + 4; - let encoded_scheme = u16::from_be_bytes([buf[scheme_offset], buf[scheme_offset + 1]]); - assert_eq!(encoded_scheme, u16::from(SignatureScheme::ED25519)); - } - #[test] fn endpoint_signature_roundtrip_and_verify() { #[derive(Debug)] struct Ed25519Key { keypair: Arc, - spki: Vec, + cert_chain: Vec>, } #[derive(Debug)] @@ -1174,11 +987,7 @@ mod tests { } fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { - &[] - } - - fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { - SubjectPublicKeyInfoDer::from(self.spki.as_slice()) + &self.cert_chain } fn sign( @@ -1194,10 +1003,14 @@ mod tests { let pkcs8 = ring::signature::Ed25519KeyPair::generate_pkcs8(&rng).unwrap(); let keypair = Arc::new(ring::signature::Ed25519KeyPair::from_pkcs8(pkcs8.as_ref()).unwrap()); - let spki = ed25519_spki(keypair.public_key().as_ref()); + let mut spki = Vec::with_capacity(44); + spki.extend_from_slice(&[ + 0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00, + ]); + spki.extend_from_slice(keypair.public_key().as_ref()); let key = Ed25519Key { keypair: keypair.clone(), - spki: spki.clone(), + cert_chain: vec![rustls::pki_types::CertificateDer::from(spki.clone())], }; let addr = SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 5353); @@ -1229,41 +1042,42 @@ mod tests { } #[test] - fn sign_with_authority_stores_canonical_signature() { + fn sign_with_authority_uses_canonical_scheme_from_public_key() { #[derive(Debug)] - struct StaticAuthority { - spki: Vec, + struct Ed25519Authority { + cert_chain: Vec>, } - impl dhttp_identity::identity::LocalAuthority for StaticAuthority { + impl dhttp_identity::identity::LocalAuthority for Ed25519Authority { fn name(&self) -> &str { "authority.example" } fn cert_chain(&self) -> &[rustls::pki_types::CertificateDer<'static>] { - &[] - } - - fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { - SubjectPublicKeyInfoDer::from(self.spki.as_slice()) + &self.cert_chain } fn sign( &self, _data: &[u8], ) -> BoxFuture<'_, Result, dhttp_identity::identity::SignError>> { - Box::pin(std::future::ready(Ok(vec![1, 2, 3]))) + Box::pin(async move { Ok(vec![1, 2, 3]) }) } } + let cert_chain = vec![rustls::pki_types::CertificateDer::from(vec![ + 0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ])]; let mut ep = EndpointAddr::direct_v4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 5353)); - let authority = StaticAuthority { - spki: ed25519_spki(&[0; 32]), - }; - futures::executor::block_on(ep.sign_with_authority(&authority)).unwrap(); + futures::executor::block_on(ep.sign_with_authority(&Ed25519Authority { cert_chain })) + .unwrap(); let signature = ep.signature().unwrap(); - assert_eq!(signature.scheme, u16::from(SignatureScheme::ED25519)); + assert_eq!( + SignatureScheme::from(signature.scheme), + SignatureScheme::ED25519 + ); assert_eq!(signature.signature, vec![1, 2, 3]); } diff --git a/src/core/parser/sigin.rs b/src/core/parser/sigin.rs index 597d90e..2ed9c11 100644 --- a/src/core/parser/sigin.rs +++ b/src/core/parser/sigin.rs @@ -9,6 +9,105 @@ use x509_parser::{ x509::SubjectPublicKeyInfo, }; +pub const SIGNATURE_SCHEME_PREFERENCE: &[SignatureScheme] = &[ + SignatureScheme::ED25519, + SignatureScheme::ECDSA_NISTP256_SHA256, + SignatureScheme::ECDSA_NISTP384_SHA384, + SignatureScheme::RSA_PSS_SHA256, + SignatureScheme::RSA_PSS_SHA384, + SignatureScheme::RSA_PSS_SHA512, + SignatureScheme::RSA_PKCS1_SHA256, + SignatureScheme::RSA_PKCS1_SHA384, + SignatureScheme::RSA_PKCS1_SHA512, +]; + +pub fn signature_schemes_for_algorithm( + algorithm: rustls::SignatureAlgorithm, +) -> impl Iterator { + SIGNATURE_SCHEME_PREFERENCE + .iter() + .copied() + .filter(move |scheme| match algorithm { + rustls::SignatureAlgorithm::ED25519 => *scheme == SignatureScheme::ED25519, + rustls::SignatureAlgorithm::ECDSA => matches!( + scheme, + SignatureScheme::ECDSA_NISTP256_SHA256 | SignatureScheme::ECDSA_NISTP384_SHA384 + ), + rustls::SignatureAlgorithm::RSA => matches!( + scheme, + SignatureScheme::RSA_PSS_SHA256 + | SignatureScheme::RSA_PSS_SHA384 + | SignatureScheme::RSA_PSS_SHA512 + | SignatureScheme::RSA_PKCS1_SHA256 + | SignatureScheme::RSA_PKCS1_SHA384 + | SignatureScheme::RSA_PKCS1_SHA512 + ), + _ => true, + }) +} + +pub fn alg_name_for_scheme(scheme: SignatureScheme) -> Option<&'static str> { + match scheme { + SignatureScheme::ED25519 => Some("ed25519"), + SignatureScheme::ECDSA_NISTP256_SHA256 => Some("ecdsa-p256-sha256"), + SignatureScheme::ECDSA_NISTP384_SHA384 => Some("ecdsa-p384-sha384"), + SignatureScheme::RSA_PSS_SHA256 => Some("rsa-pss-sha256"), + SignatureScheme::RSA_PSS_SHA384 => Some("rsa-pss-sha384"), + SignatureScheme::RSA_PSS_SHA512 => Some("rsa-pss-sha512"), + SignatureScheme::RSA_PKCS1_SHA256 => Some("rsa-v1_5-sha256"), + SignatureScheme::RSA_PKCS1_SHA384 => Some("rsa-v1_5-sha384"), + SignatureScheme::RSA_PKCS1_SHA512 => Some("rsa-v1_5-sha512"), + _ => None, + } +} + +pub fn scheme_for_alg_name(alg: &str) -> Option { + match alg { + "ed25519" => Some(SignatureScheme::ED25519), + "ecdsa-p256-sha256" => Some(SignatureScheme::ECDSA_NISTP256_SHA256), + "ecdsa-p384-sha384" => Some(SignatureScheme::ECDSA_NISTP384_SHA384), + "rsa-pss-sha256" => Some(SignatureScheme::RSA_PSS_SHA256), + "rsa-pss-sha384" => Some(SignatureScheme::RSA_PSS_SHA384), + "rsa-pss-sha512" => Some(SignatureScheme::RSA_PSS_SHA512), + "rsa-v1_5-sha256" => Some(SignatureScheme::RSA_PKCS1_SHA256), + "rsa-v1_5-sha384" => Some(SignatureScheme::RSA_PKCS1_SHA384), + "rsa-v1_5-sha512" => Some(SignatureScheme::RSA_PKCS1_SHA512), + _ => None, + } +} + +pub fn canonical_scheme_for_spki(spki: SubjectPublicKeyInfoDer<'_>) -> Option { + let Ok((_remain, spki)) = SubjectPublicKeyInfo::from_der(spki.as_ref()) else { + return None; + }; + + if spki.algorithm.algorithm == OID_SIG_ED25519 { + return Some(SignatureScheme::ED25519); + } + + if spki.algorithm.algorithm == OID_PKCS1_RSAENCRYPTION { + return Some(SignatureScheme::RSA_PSS_SHA512); + } + + if spki.algorithm.algorithm != OID_KEY_TYPE_EC_PUBLIC_KEY { + return None; + } + + let curve = spki + .algorithm + .parameters + .as_ref() + .and_then(|parameters| parameters.as_oid().ok())?; + + if curve == OID_EC_P256 { + Some(SignatureScheme::ECDSA_NISTP256_SHA256) + } else if curve == OID_NIST_EC_P384 { + Some(SignatureScheme::ECDSA_NISTP384_SHA384) + } else { + None + } +} + #[derive(Debug, Snafu)] #[snafu(module)] pub enum SignError { @@ -95,6 +194,11 @@ pub(crate) fn verify( SignatureScheme::ECDSA_NISTP384_SHA384 => &ring::signature::ECDSA_P384_SHA384_ASN1, SignatureScheme::ECDSA_NISTP256_SHA256 => &ring::signature::ECDSA_P256_SHA256_ASN1, SignatureScheme::ED25519 => &ring::signature::ED25519, + SignatureScheme::RSA_PKCS1_SHA256 => &ring::signature::RSA_PKCS1_2048_8192_SHA256, + SignatureScheme::RSA_PKCS1_SHA384 => &ring::signature::RSA_PKCS1_2048_8192_SHA384, + SignatureScheme::RSA_PKCS1_SHA512 => &ring::signature::RSA_PKCS1_2048_8192_SHA512, + SignatureScheme::RSA_PSS_SHA256 => &ring::signature::RSA_PSS_2048_8192_SHA256, + SignatureScheme::RSA_PSS_SHA384 => &ring::signature::RSA_PSS_2048_8192_SHA384, SignatureScheme::RSA_PSS_SHA512 => &ring::signature::RSA_PSS_2048_8192_SHA512, _ => return verify_error::UnsupportedSchemeSnafu { scheme }.fail(), }; @@ -110,3 +214,4 @@ pub(crate) fn verify( .is_ok(), ) } + diff --git a/src/core/signature.rs b/src/core/signature.rs new file mode 100644 index 0000000..d90f879 --- /dev/null +++ b/src/core/signature.rs @@ -0,0 +1,322 @@ +use std::time::{SystemTime, UNIX_EPOCH}; + +use base64::Engine; +use dhttp_identity::identity::{LocalAuthority, SignError as AuthoritySignError}; +use ring::digest::{SHA256, digest}; +use rustls::{SignatureScheme, pki_types::SubjectPublicKeyInfoDer}; +use snafu::Snafu; + +use crate::core::parser::sigin; + +pub const CONTENT_DIGEST_HEADER: &str = "content-digest"; +pub const SIGNATURE_INPUT_HEADER: &str = "signature-input"; +pub const SIGNATURE_HEADER: &str = "signature"; +pub const SIGNATURE_LABEL: &str = "dns"; + +const DIGEST_PREFIX: &str = "sha-256=:"; +const SIGNATURE_PREFIX: &str = "dns=:"; +const SIGNATURE_INPUT_PREFIX: &str = "dns=(\"content-digest\")"; + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct SignatureFields { + pub content_digest: Vec, + pub signature_input: Vec, + pub signature: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct ParsedSignatureInput<'a> { + signature_params: &'a str, + alg: &'a str, + keyid: &'a str, +} + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum SignatureFieldsError { + #[snafu(display("missing publisher certificate"))] + MissingCertificate, + #[snafu(display("unsupported signature scheme {scheme:?}"))] + UnsupportedScheme { scheme: SignatureScheme }, + #[snafu(display("unsupported signature algorithm {alg}"))] + UnsupportedAlgorithm { alg: String }, + #[snafu(display("invalid {field} field"))] + InvalidField { field: &'static str }, + #[snafu(display("invalid signature field utf-8"))] + InvalidUtf8 { source: std::str::Utf8Error }, + #[snafu(display("invalid base64"))] + InvalidBase64 { source: base64::DecodeError }, + #[snafu(display("content digest mismatch"))] + DigestMismatch, + #[snafu(display("signature keyid does not match publisher certificate"))] + KeyIdMismatch, + #[snafu(display("failed to sign DNS packet"))] + Sign { source: AuthoritySignError }, + #[snafu(display("invalid certificate: {details}"))] + InvalidCertificate { details: String }, + #[snafu(display("signature verification failed"))] + Verify { source: sigin::VerifyError }, +} + +impl SignatureFields { + pub fn empty() -> Self { + Self::default() + } + + pub fn is_empty(&self) -> bool { + self.content_digest.is_empty() + && self.signature_input.is_empty() + && self.signature.is_empty() + } + + pub async fn sign( + dns_bytes: &[u8], + authority: &(impl LocalAuthority + ?Sized), + ) -> Result { + let cert = authority + .cert_chain() + .first() + .ok_or(SignatureFieldsError::MissingCertificate)?; + let keyid = keyid_for_cert(cert.as_ref()); + let content_digest = content_digest_value(dns_bytes); + let created = unix_now_secs(); + + let scheme = sigin::canonical_scheme_for_spki(authority.public_key()).ok_or( + SignatureFieldsError::UnsupportedScheme { + scheme: SignatureScheme::Unknown(0), + }, + )?; + let alg = sigin::alg_name_for_scheme(scheme) + .ok_or(SignatureFieldsError::UnsupportedScheme { scheme })?; + let signature_input = signature_input_value(created, &keyid, alg); + let signature_base = signature_base(&content_digest, &signature_input)?; + let signature = authority + .sign(signature_base.as_bytes()) + .await + .map_err(|source| SignatureFieldsError::Sign { source })?; + let signature = signature_value(&signature); + + Ok(Self { + content_digest: content_digest.into_bytes(), + signature_input: signature_input.into_bytes(), + signature: signature.into_bytes(), + }) + } + + pub fn verify(&self, dns_bytes: &[u8], cert_der: &[u8]) -> Result { + if self.is_empty() { + return Ok(false); + } + + let content_digest = field_str(&self.content_digest)?; + verify_content_digest(content_digest, dns_bytes)?; + + let signature_input = field_str(&self.signature_input)?; + let parsed_input = parse_signature_input(signature_input)?; + let expected_keyid = keyid_for_cert(cert_der); + if parsed_input.keyid != expected_keyid { + return Ok(false); + } + + let scheme = sigin::scheme_for_alg_name(parsed_input.alg).ok_or_else(|| { + SignatureFieldsError::UnsupportedAlgorithm { + alg: parsed_input.alg.to_string(), + } + })?; + let signature = parse_signature(field_str(&self.signature)?)?; + let signature_base = signature_base(content_digest, signature_input)?; + + let (_, cert) = x509_parser::parse_x509_certificate(cert_der).map_err(|e| { + SignatureFieldsError::InvalidCertificate { + details: e.to_string(), + } + })?; + let spki = SubjectPublicKeyInfoDer::from(cert.tbs_certificate.subject_pki.raw); + sigin::verify(spki, scheme, signature_base.as_bytes(), &signature) + .map_err(|source| SignatureFieldsError::Verify { source }) + } +} + +pub fn content_digest_value(dns_bytes: &[u8]) -> String { + let digest = digest(&SHA256, dns_bytes); + let b64 = base64::engine::general_purpose::STANDARD.encode(digest.as_ref()); + format!("{DIGEST_PREFIX}{b64}:") +} + +pub fn cert_fingerprint_hex(cert_der: &[u8]) -> String { + digest(&SHA256, cert_der) + .as_ref() + .iter() + .map(|b| format!("{b:02x}")) + .collect() +} + +pub fn keyid_for_cert(cert_der: &[u8]) -> String { + format!("sha256:{}", cert_fingerprint_hex(cert_der)) +} + +fn unix_now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0) +} + +fn signature_input_value(created: u64, keyid: &str, alg: &str) -> String { + format!( + "{SIGNATURE_LABEL}=(\"content-digest\");created={created};keyid=\"{keyid}\";alg=\"{alg}\"" + ) +} + +fn signature_value(signature: &[u8]) -> String { + let b64 = base64::engine::general_purpose::STANDARD.encode(signature); + format!("{SIGNATURE_PREFIX}{b64}:") +} + +fn signature_base( + content_digest: &str, + signature_input: &str, +) -> Result { + let parsed = parse_signature_input(signature_input)?; + Ok(format!( + "\"content-digest\": {content_digest}\n\"@signature-params\": {}", + parsed.signature_params + )) +} + +fn field_str(field: &[u8]) -> Result<&str, SignatureFieldsError> { + std::str::from_utf8(field).map_err(|source| SignatureFieldsError::InvalidUtf8 { source }) +} + +fn verify_content_digest( + content_digest: &str, + dns_bytes: &[u8], +) -> Result<(), SignatureFieldsError> { + let encoded = content_digest + .strip_prefix(DIGEST_PREFIX) + .and_then(|rest| rest.strip_suffix(':')) + .ok_or(SignatureFieldsError::InvalidField { + field: CONTENT_DIGEST_HEADER, + })?; + let decoded = base64::engine::general_purpose::STANDARD + .decode(encoded) + .map_err(|source| SignatureFieldsError::InvalidBase64 { source })?; + if decoded.as_slice() != digest(&SHA256, dns_bytes).as_ref() { + return Err(SignatureFieldsError::DigestMismatch); + } + Ok(()) +} + +fn parse_signature(input: &str) -> Result, SignatureFieldsError> { + let encoded = input + .strip_prefix(SIGNATURE_PREFIX) + .and_then(|rest| rest.strip_suffix(':')) + .ok_or(SignatureFieldsError::InvalidField { + field: SIGNATURE_HEADER, + })?; + base64::engine::general_purpose::STANDARD + .decode(encoded) + .map_err(|source| SignatureFieldsError::InvalidBase64 { source }) +} + +fn parse_signature_input(input: &str) -> Result, SignatureFieldsError> { + if !input.starts_with(SIGNATURE_INPUT_PREFIX) { + return Err(SignatureFieldsError::InvalidField { + field: SIGNATURE_INPUT_HEADER, + }); + } + + let signature_params = + input + .strip_prefix("dns=") + .ok_or(SignatureFieldsError::InvalidField { + field: SIGNATURE_INPUT_HEADER, + })?; + let params = signature_params + .strip_prefix("(\"content-digest\")") + .ok_or(SignatureFieldsError::InvalidField { + field: SIGNATURE_INPUT_HEADER, + })?; + + let mut created = None; + let mut keyid = None; + let mut alg = None; + + for param in params.split(';').filter(|part| !part.is_empty()) { + if let Some(value) = param.strip_prefix("created=") { + created = value.parse::().ok(); + } else if let Some(value) = param.strip_prefix("keyid=") { + keyid = unquote(value); + } else if let Some(value) = param.strip_prefix("alg=") { + alg = unquote(value); + } else { + return Err(SignatureFieldsError::InvalidField { + field: SIGNATURE_INPUT_HEADER, + }); + } + } + + if created.is_none() { + return Err(SignatureFieldsError::InvalidField { + field: SIGNATURE_INPUT_HEADER, + }); + } + + let keyid = keyid.ok_or(SignatureFieldsError::InvalidField { + field: SIGNATURE_INPUT_HEADER, + })?; + let alg = alg.ok_or(SignatureFieldsError::InvalidField { + field: SIGNATURE_INPUT_HEADER, + })?; + + Ok(ParsedSignatureInput { + signature_params, + alg, + keyid, + }) +} + +fn unquote(value: &str) -> Option<&str> { + value.strip_prefix('"')?.strip_suffix('"') +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn content_digest_uses_sha256_dictionary_value() { + let value = content_digest_value(b"dns"); + assert!(value.starts_with("sha-256=:")); + assert!(value.ends_with(':')); + verify_content_digest(&value, b"dns").unwrap(); + assert!(matches!( + verify_content_digest(&value, b"changed"), + Err(SignatureFieldsError::DigestMismatch) + )); + } + + #[test] + fn signature_input_requires_alg_and_keyid() { + let input = "dns=(\"content-digest\");created=1;keyid=\"sha256:abc\";alg=\"ed25519\""; + let parsed = parse_signature_input(input).unwrap(); + assert_eq!(parsed.keyid, "sha256:abc"); + assert_eq!(parsed.alg, "ed25519"); + + assert!(parse_signature_input("dns=(\"content-digest\");created=1").is_err()); + assert!(parse_signature_input("dns=(\"date\");created=1;alg=\"ed25519\"").is_err()); + } + + #[test] + fn alg_names_are_explicitly_mapped() { + assert_eq!( + sigin::scheme_for_alg_name("ed25519"), + Some(SignatureScheme::ED25519) + ); + assert_eq!( + sigin::alg_name_for_scheme(SignatureScheme::ECDSA_NISTP256_SHA256), + Some("ecdsa-p256-sha256") + ); + assert_eq!(sigin::scheme_for_alg_name("unknown"), None); + } +} diff --git a/src/core/wire.rs b/src/core/wire.rs index 9d3f539..028e455 100644 --- a/src/core/wire.rs +++ b/src/core/wire.rs @@ -11,9 +11,14 @@ use bytes::BufMut; use nom::{IResult, bytes::streaming::take, number::streaming::be_u32}; +use crate::core::signature::SignatureFields; + /// One DNS + certificate pair inside a [`MultiResponse`]. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ResponseRecord { + /// RFC 9421/9530-style publisher signature fields. Empty for unsigned + /// OpenMulti or static seed records. + pub signature_fields: SignatureFields, /// Serialised DNS packet bytes. pub dns: Vec, /// DER-encoded leaf certificate of the publisher, or empty when unavailable. @@ -21,6 +26,18 @@ pub struct ResponseRecord { } impl ResponseRecord { + pub fn new(signature_fields: SignatureFields, dns: Vec, cert: Vec) -> Self { + Self { + signature_fields, + dns, + cert, + } + } + + pub fn unsigned(dns: Vec, cert: Vec) -> Self { + Self::new(SignatureFields::empty(), dns, cert) + } + /// SHA-256 fingerprint of the publisher certificate as lowercase hex. /// Returns `None` when the cert field is empty. pub fn cert_fingerprint_hex(&self) -> Option { @@ -40,12 +57,9 @@ pub struct MultiResponse { } impl MultiResponse { - pub fn new(iter: impl IntoIterator, Vec)>) -> Self { + pub fn new(iter: impl IntoIterator) -> Self { Self { - records: iter - .into_iter() - .map(|(dns, cert)| ResponseRecord { dns, cert }) - .collect(), + records: iter.into_iter().collect(), } } @@ -53,7 +67,17 @@ impl MultiResponse { 4 + self .records .iter() - .map(|record| 4 + record.dns.len() + 4 + record.cert.len()) + .map(|record| { + 4 + record.signature_fields.content_digest.len() + + 4 + + record.signature_fields.signature_input.len() + + 4 + + record.signature_fields.signature.len() + + 4 + + record.dns.len() + + 4 + + record.cert.len() + }) .sum::() } @@ -72,39 +96,69 @@ impl WriteMultiResponse for B { fn put_multi_response(&mut self, response: &MultiResponse) { self.put_u32(response.records.len() as u32); for record in &response.records { - self.put_u32(record.dns.len() as u32); - self.put_slice(&record.dns); - self.put_u32(record.cert.len() as u32); - self.put_slice(&record.cert); + put_field(self, &record.signature_fields.content_digest); + put_field(self, &record.signature_fields.signature_input); + put_field(self, &record.signature_fields.signature); + put_field(self, &record.dns); + put_field(self, &record.cert); } } } +fn put_field(buf: &mut B, value: &[u8]) { + buf.put_u32(value.len() as u32); + buf.put_slice(value); +} + pub fn be_multi_response(input: &[u8]) -> IResult<&[u8], MultiResponse> { let (mut input, count) = be_u32(input)?; let mut records = Vec::with_capacity(count as usize); for _ in 0..count { - let (rest, dns_len) = be_u32(input)?; - let (rest, dns) = take(dns_len as usize)(rest)?; - let (rest, cert_len) = be_u32(rest)?; - let (rest, cert) = take(cert_len as usize)(rest)?; - records.push(ResponseRecord { - dns: dns.to_vec(), - cert: cert.to_vec(), - }); + let (rest, content_digest) = be_field(input)?; + let (rest, signature_input) = be_field(rest)?; + let (rest, signature) = be_field(rest)?; + let (rest, dns) = be_field(rest)?; + let (rest, cert) = be_field(rest)?; + records.push(ResponseRecord::new( + SignatureFields { + content_digest, + signature_input, + signature, + }, + dns, + cert, + )); input = rest; } Ok((input, MultiResponse { records })) } +fn be_field(input: &[u8]) -> IResult<&[u8], Vec> { + let (input, len) = be_u32(input)?; + let (input, value) = take(len as usize)(input)?; + Ok((input, value.to_vec())) +} + #[cfg(test)] mod tests { use super::*; #[test] fn multi_response_roundtrips() { - let response = - MultiResponse::new([(vec![1, 2, 3], vec![4, 5]), (vec![6, 7, 8, 9], Vec::new())]); + let response = MultiResponse::new([ + ResponseRecord::new( + SignatureFields { + content_digest: b"sha-256=:abc:".to_vec(), + signature_input: + b"dns=(\"content-digest\");created=1;keyid=\"sha256:abc\";alg=\"ed25519\"" + .to_vec(), + signature: b"dns=:sig:".to_vec(), + }, + vec![1, 2, 3], + vec![4, 5], + ), + ResponseRecord::unsigned(vec![6, 7, 8, 9], Vec::new()), + ]); let encoded = response.encode(); let (remain, decoded) = be_multi_response(&encoded).unwrap(); assert!(remain.is_empty()); diff --git a/src/publisher.rs b/src/publisher.rs index 13ba706..f542610 100644 --- a/src/publisher.rs +++ b/src/publisher.rs @@ -1,20 +1,33 @@ -mod address; -mod dispatch; -mod packet; - -use std::{any::TypeId, future::Future, io, net::SocketAddr, pin::Pin, sync::Arc, time::Duration}; - -pub use address::{ - AddressSelector, AddressView, AddressViewSource, EndpointBindingAddresses, FnAddressView, - PublishAddressGroup, PublishAddressScope, PublishAddresses, +use std::{ + any::{Any, TypeId}, + collections::{HashMap, HashSet}, + future::Future, + io, + net::SocketAddr, + pin::Pin, + sync::Arc, + time::Duration, }; -use dhttp_identity::{identity::LocalAuthority, name::Name}; + +use dhttp_identity::identity::LocalAuthority; +#[cfg(feature = "mdns-resolver")] +use dquic::qbase::net::Family; use dquic::{ - qinterface::component::location::AddressEvent, qresolve::Resolve, - qtraversal::nat::client::ClientLocationData, + qbase::net::addr::EndpointAddr, + qinterface::component::location::AddressEvent, + qresolve::{Publish, Resolve}, + qtraversal::nat::client::{ClientLocationData, NatType}, +}; +use snafu::{ResultExt, Snafu}; + +use crate::{ + core::{ + MdnsPacket, + parser::record::endpoint::EndpointAddr as DnsEndpointAddr, + signature::{SignatureFields, SignatureFieldsError}, + }, + resolvers::Resolvers, }; -pub use packet::{EndpointRecordSigner, SignEndpointRecordsError}; -use snafu::Snafu; pub const DEFAULT_PUBLISH_INTERVAL: Duration = Duration::from_secs(20); /// Upper bound for a single publish attempt in the background loop. @@ -39,8 +52,10 @@ pub enum CreatePublisherError { pub enum PublishOnceError { #[snafu(display("no publisher resolver available"))] NoPublisherResolver, - #[snafu(display("failed to sign endpoint records"))] - SignEndpointRecords { source: SignEndpointRecordsError }, + #[snafu(display("failed to encode endpoint address"))] + EncodeEndpoint, + #[snafu(display("failed to sign dns packet"))] + SignPacket { source: SignatureFieldsError }, #[snafu(display("failed to publish dns packet with {publisher}"))] Publish { publisher: String, @@ -48,130 +63,63 @@ pub enum PublishOnceError { }, } -pub trait PublisherResolver: Send + Sync + 'static { - fn as_resolver(&self) -> &(dyn Resolve + Send + Sync); -} - -impl PublisherResolver for T -where - T: Resolve + Send + Sync + Sized + 'static, -{ - fn as_resolver(&self) -> &(dyn Resolve + Send + Sync) { - self - } -} - -impl PublisherResolver for dyn Resolve + Send + Sync { - fn as_resolver(&self) -> &(dyn Resolve + Send + Sync) { - self - } -} - -pub struct Publisher { - signer: EndpointRecordSigner, - resolver: Arc, -} - -impl std::fmt::Debug for Publisher -where - A: LocalAuthority + Send + Sync + ?Sized, - R: PublisherResolver + ?Sized, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Publisher") - .field("signer", &self.signer) - .field("resolver", &self.resolver.as_resolver().to_string()) - .finish() - } -} - -pub type EndpointPublisher = Publisher; - -impl Publisher -where - A: LocalAuthority + Send + Sync + ?Sized, - R: PublisherResolver + ?Sized, -{ - pub fn new(signer: EndpointRecordSigner, resolver: Arc) -> Self { - Self { signer, resolver } - } - - pub fn signer(&self) -> &EndpointRecordSigner { - &self.signer - } - - pub fn resolver(&self) -> &Arc { - &self.resolver - } - - pub async fn publish_once( - &self, - name: &Name<'_>, - addresses: &V, - ) -> Result<(), PublishOnceError> - where - V: AddressView + Sync, - { - let mut published = false; - published |= self - .publish_to_resolver(self.resolver.as_resolver(), name, addresses) - .await?; - - if !published { - return publish_once_error::NoPublisherResolverSnafu.fail(); - } - - Ok(()) - } +/// Optional metadata applied to endpoint records before signing. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct PublishOptions { + /// Stable server identifier for names served by multiple publishers. + /// + /// `0` marks the endpoint as the main record. Non-zero values mark the + /// record as clustered and encode the identifier as its sequence number. + pub server_id: Option, } -pub struct EndpointPublicationLoop { - name: Name<'static>, - publisher: Publisher, - source: S, +pub struct Publisher { + identity: Arc, + network: Arc, + resolver: Arc, + bind_patterns: Arc>, interval: Duration, publish_timeout: Duration, + options: PublishOptions, } -impl std::fmt::Debug for EndpointPublicationLoop -where - A: LocalAuthority + Send + Sync + ?Sized, - R: PublisherResolver + ?Sized, - S: std::fmt::Debug, -{ +impl std::fmt::Debug for Publisher { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("EndpointPublicationLoop") - .field("name", &self.name) - .field("publisher", &self.publisher) - .field("source", &self.source) + f.debug_struct("Publisher") + .field("identity", &self.identity.name()) + .field("bind_patterns", &self.bind_patterns) .field("interval", &self.interval) .field("publish_timeout", &self.publish_timeout) - .finish() + .field("options", &self.options) + .finish_non_exhaustive() } } -impl EndpointPublicationLoop -where - A: LocalAuthority + Send + Sync + ?Sized, - R: PublisherResolver + ?Sized, - S: AddressViewSource + Send + Sync, -{ - pub fn new(name: Name<'static>, publisher: Publisher, source: S) -> Self { +impl Publisher { + pub fn new( + identity: Arc, + network: Arc, + resolver: Arc, + bind_patterns: Arc>, + ) -> Self { Self { - name, - publisher, - source, + identity, + network, + resolver, + bind_patterns, interval: DEFAULT_PUBLISH_INTERVAL, publish_timeout: DEFAULT_PUBLISH_TIMEOUT, + options: PublishOptions::default(), } } - pub fn name(&self) -> &Name<'static> { - &self.name + pub fn with_options(mut self, options: PublishOptions) -> Self { + self.options = options; + self } - pub fn publisher(&self) -> &Publisher { - &self.publisher + pub fn options(&self) -> PublishOptions { + self.options } pub fn interval(&self) -> Duration { @@ -187,8 +135,27 @@ where self } + pub async fn publish_once(&self) -> Result<(), PublishOnceError> { + let mut published = false; + let public_endpoints = self.public_endpoints(); + tracing::debug!( + endpoint_count = public_endpoints.len(), + endpoints = ?public_endpoints, + "publishing public endpoints" + ); + published |= self + .publish_to_resolver(self.resolver.as_ref(), &public_endpoints) + .await?; + + if !published { + return publish_once_error::NoPublisherResolverSnafu.fail(); + } + + Ok(()) + } + pub async fn run(&self) -> ! { - let mut locations = self.source.subscribe(); + let mut locations = self.network.quic().locations().subscribe(); let interval = tokio::time::sleep(self.interval); tokio::pin!(interval); // Keep at most one publish attempt in flight. A timer tick or @@ -211,7 +178,7 @@ where let Some((bind_uri, event)) = event else { continue; }; - if !self.source.observes(&bind_uri) { + if !self.bind_patterns.iter().any(|pattern| pattern.matches(&bind_uri)) { continue; } if !Self::location_event_requires_publish(&event) { @@ -241,20 +208,14 @@ where timeout_ms = self.publish_timeout.as_millis(), "starting dns publish attempt" ); - let addresses = self.source.address_view(); - match tokio::time::timeout( - self.publish_timeout, - self.publisher.publish_once(&self.name, &addresses), - ) - .await - { + match tokio::time::timeout(self.publish_timeout, self.publish_once()).await { Ok(Ok(())) => { - tracing::info!(name = %self.name, "published resolver endpoints"); + tracing::info!("published resolver endpoints"); true } Ok(Err(error)) => { let report = snafu::Report::from_error(&error); - tracing::warn!(error = %report, name = %self.name, "dns publish failed"); + tracing::warn!(error = %report, "dns publish failed"); false } Err(_elapsed) => { @@ -265,7 +226,6 @@ where self.clear_publish_state(); tracing::warn!( timeout_ms = self.publish_timeout.as_millis(), - name = %self.name, "dns publish timed out" ); false @@ -273,10 +233,6 @@ where } } - fn clear_publish_state(&self) { - dispatch::clear_resolver_publish_state(self.publisher.resolver.as_resolver()); - } - fn location_event_requires_publish(event: &AddressEvent) -> bool { match event { AddressEvent::Upsert(data) => { @@ -298,24 +254,343 @@ where AddressEvent::Closed => true, } } + + async fn publish_to_resolver( + &self, + resolver: &(dyn Resolve + Send + Sync), + public_endpoints: &[EndpointAddr], + ) -> Result { + let any: &dyn Any = resolver; + + if let Some(resolvers) = any.downcast_ref::() { + let mut published = false; + for resolver in resolvers.iter() { + published |= self + .publish_single_resolver(resolver.as_ref(), public_endpoints) + .await?; + } + return Ok(published); + } + + self.publish_single_resolver(resolver, public_endpoints) + .await + } + + fn clear_publish_state(&self) { + Self::clear_resolver_publish_state(self.resolver.as_ref()); + } + + fn clear_resolver_publish_state(resolver: &(dyn Resolve + Send + Sync)) { + let any: &dyn Any = resolver; + + if let Some(resolvers) = any.downcast_ref::() { + for resolver in resolvers.iter() { + Self::clear_resolver_publish_state(resolver.as_ref()); + } + } + + #[cfg(feature = "h3x-resolver")] + if let Some(h3) = + any.downcast_ref::>() + { + h3.clear_pool(); + } + } + + async fn publish_single_resolver( + &self, + resolver: &(dyn Resolve + Send + Sync), + public_endpoints: &[EndpointAddr], + ) -> Result { + #[cfg(not(any(feature = "http-resolver", feature = "h3x-resolver")))] + let _ = public_endpoints; + + let any: &dyn Any = resolver; + + #[cfg(feature = "http-resolver")] + if let Some(http) = any.downcast_ref::() { + self.publish_signed_http_endpoints(http, public_endpoints) + .await?; + return Ok(true); + } + + #[cfg(feature = "h3x-resolver")] + if let Some(h3) = + any.downcast_ref::>() + { + self.publish_signed_h3_endpoints(h3, public_endpoints) + .await?; + return Ok(true); + } + + #[cfg(feature = "mdns-resolver")] + if let Some(mdns) = any.downcast_ref::() { + let mut published = false; + for bound in mdns.bound_resolvers() { + let endpoints = self.local_endpoints_for(&bound.device, bound.family); + self.publish_plain_endpoints(&bound.resolver, &endpoints) + .await?; + published = true; + } + return Ok(published); + } + + Ok(false) + } + + async fn publish_plain_endpoints( + &self, + publisher: &(dyn Publish + Send + Sync), + endpoints: &[EndpointAddr], + ) -> Result<(), PublishOnceError> { + let packet = self.dns_packet(endpoints)?; + let name = self.identity.name(); + tracing::debug!( + publisher = %publisher, + name, + endpoint_count = endpoints.len(), + packet_len = packet.len(), + "publishing dns packet" + ); + publisher + .publish(name, &packet) + .await + .context(publish_once_error::PublishSnafu { + publisher: publisher.to_string(), + }) + } + + #[cfg(feature = "http-resolver")] + async fn publish_signed_http_endpoints( + &self, + publisher: &crate::resolvers::http::HttpResolver, + endpoints: &[EndpointAddr], + ) -> Result<(), PublishOnceError> { + let (packet, signature_fields) = self.signed_packet(endpoints).await?; + let name = self.identity.name(); + tracing::debug!( + publisher = %publisher, + name, + endpoint_count = endpoints.len(), + packet_len = packet.len(), + "publishing signed dns packet" + ); + + publisher + .publish_signed(name, &packet, &signature_fields) + .await + .context(publish_once_error::PublishSnafu { + publisher: publisher.to_string(), + }) + } + + #[cfg(feature = "h3x-resolver")] + async fn publish_signed_h3_endpoints( + &self, + publisher: &crate::resolvers::h3::H3Resolver, + endpoints: &[EndpointAddr], + ) -> Result<(), PublishOnceError> { + let (packet, signature_fields) = self.signed_packet(endpoints).await?; + let name = self.identity.name(); + tracing::debug!( + publisher = %publisher, + name, + endpoint_count = endpoints.len(), + packet_len = packet.len(), + "publishing signed dns packet" + ); + + publisher + .publish_signed(name, &packet, &signature_fields) + .await + .context(publish_once_error::PublishSnafu { + publisher: publisher.to_string(), + }) + } + + async fn signed_packet( + &self, + endpoints: &[EndpointAddr], + ) -> Result<(Vec, SignatureFields), PublishOnceError> { + let packet = self.dns_packet(endpoints)?; + let signature_fields = SignatureFields::sign(&packet, self.identity.as_ref()) + .await + .context(publish_once_error::SignPacketSnafu)?; + Ok((packet, signature_fields)) + } + + fn dns_packet(&self, endpoints: &[EndpointAddr]) -> Result, PublishOnceError> { + let mut encoded = Vec::with_capacity(endpoints.len()); + for endpoint in endpoints { + let mut endpoint = DnsEndpointAddr::try_from(*endpoint) + .map_err(|_| publish_once_error::EncodeEndpointSnafu.build())?; + if let Some(server_id) = self.options.server_id { + endpoint.set_main(server_id == 0); + endpoint.set_sequence(server_id.into()); + } + encoded.push(endpoint); + } + + let mut hosts = HashMap::new(); + hosts.insert(self.identity.name().to_owned(), encoded); + Ok(MdnsPacket::answer(0, &hosts).to_bytes()) + } + + fn public_endpoints(&self) -> Vec { + let mut endpoints = Vec::new(); + let mut seen = HashSet::new(); + for pattern in self.bind_patterns.iter() { + let Some(ifaces) = self.network.quic().get_interfaces(pattern) else { + tracing::trace!(?pattern, "no interfaces for bind pattern"); + continue; + }; + for iface in ifaces { + for endpoint in public_endpoints_from_iface(&self.network, &iface) { + push_unique_endpoint(&mut endpoints, &mut seen, endpoint); + } + } + } + endpoints + } + + #[cfg(feature = "mdns-resolver")] + fn local_endpoints_for(&self, device: &str, family: Family) -> Vec { + let mut endpoints = HashSet::new(); + for pattern in self.bind_patterns.iter() { + let Some(ifaces) = self.network.quic().get_interfaces(pattern) else { + continue; + }; + for iface in ifaces { + let bind_uri = iface.bind_uri(); + let Some((iface_family, iface_device, _port)) = bind_uri.as_iface_bind_uri() else { + continue; + }; + if iface_family != family || iface_device != device { + continue; + } + if let Some(endpoint) = local_endpoint_from_iface(&iface, family) { + endpoints.insert(endpoint); + } + } + } + endpoints.into_iter().collect() + } +} + +fn push_unique_endpoint( + endpoints: &mut Vec, + seen: &mut HashSet, + endpoint: EndpointAddr, +) { + if seen.insert(endpoint) { + endpoints.push(endpoint); + } +} + +fn public_endpoints_from_iface( + network: &h3x::dquic::Network, + iface: &h3x::dquic::net::BindInterface, +) -> Vec { + use h3x::dquic::{net::IO, qtraversal::nat::client::StunClientsComponent}; + + iface.with_components(|components, current| { + let bind_uri = current.bind_uri(); + let addr = current.bound_addr().ok(); + let mut endpoints: Vec = components + .get::() + .map(|stun| { + stun.with_clients(|clients| { + clients + .values() + .filter_map(|client| { + let outer = client.get_outer_addr()?.ok()?; + let bound = current.bound_addr().ok()?; + match client.get_nat_type() { + Some(Ok(nat_type)) => Some(publish_endpoint_from_stun( + bound, + client.agent_addr(), + outer, + nat_type, + )), + None => Some(EndpointAddr::with_agent(client.agent_addr(), outer)), + Some(Err(_)) => None, + } + }) + .collect() + }) + }) + .unwrap_or_default(); + let stun_endpoint_count = endpoints.len(); + + // Also publish the current default-route address. STUN-derived + // endpoints make the node reachable from outside the local network, + // while the bound address is still the shortest valid path for peers + // on the same link and for separate local client processes on the + // same host. Keep it after STUN endpoints so translated-NAT peers get + // the externally reachable candidate first. + if let Some(addr) = addr + && network.bound_addr_is_on_default_route(&bind_uri, addr) + { + endpoints.push(EndpointAddr::direct(addr)); + } + + tracing::trace!( + bind_uri = %bind_uri, + bound_addr = ?addr, + stun_endpoint_count, + endpoint_count = endpoints.len(), + endpoints = ?endpoints, + "collected public endpoints from interface" + ); + + endpoints + }) +} + +fn publish_endpoint_from_stun( + bound: SocketAddr, + agent: SocketAddr, + outer: SocketAddr, + nat_type: NatType, +) -> EndpointAddr { + if nat_type == NatType::FullCone && bound == outer { + EndpointAddr::direct(outer) + } else { + EndpointAddr::with_agent(agent, outer) + } } -pub type EndpointPublisherLoop = EndpointPublicationLoop< - dyn LocalAuthority + Send + Sync, - dyn Resolve + Send + Sync, - EndpointBindingAddresses, ->; +#[cfg(feature = "mdns-resolver")] +fn local_endpoint_from_iface( + iface: &h3x::dquic::net::BindInterface, + family: Family, +) -> Option { + use h3x::dquic::net::IO; + + iface.with_components(|_components, current| { + let addr = current.bound_addr().ok()?; + match (family, addr) { + (Family::V4, std::net::SocketAddr::V4(_)) + | (Family::V6, std::net::SocketAddr::V6(_)) => Some(EndpointAddr::direct(addr)), + _ => None, + } + }) +} #[cfg(test)] mod tests { - #[cfg(feature = "http-resolver")] - use std::sync::atomic::{AtomicUsize, Ordering}; - use std::{fmt, sync::Arc, time::Duration}; + use std::{ + fmt, + sync::{ + Arc, OnceLock, + atomic::{AtomicUsize, Ordering}, + }, + time::Duration, + }; use dquic::qresolve::{ResolveFuture, Source}; use futures::{FutureExt, StreamExt, future::BoxFuture, stream}; - use rustls::pki_types::{CertificateDer, SubjectPublicKeyInfoDer}; - #[cfg(feature = "http-resolver")] + use rustls::pki_types::CertificateDer; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use super::*; @@ -323,39 +598,27 @@ mod tests { #[derive(Debug)] struct TestAuthority; - const ED25519_TEST_SPKI: [u8; 44] = [ - 0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - ]; - impl LocalAuthority for TestAuthority { fn name(&self) -> &str { "authority.example" } fn cert_chain(&self) -> &[CertificateDer<'static>] { - static CERTS: std::sync::LazyLock>> = - std::sync::LazyLock::new(|| { - vec![CertificateDer::from( - include_bytes!("../tests/fixtures/valid.der").to_vec(), - )] - }); - CERTS.as_slice() - } - - fn public_key(&self) -> SubjectPublicKeyInfoDer<'_> { - SubjectPublicKeyInfoDer::from(ED25519_TEST_SPKI.as_slice()) + static CERT_CHAIN: OnceLock>> = OnceLock::new(); + CERT_CHAIN.get_or_init(|| { + vec![CertificateDer::from(vec![ + 0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, + ])] + }) } fn sign( &self, _data: &[u8], ) -> BoxFuture<'_, Result, dhttp_identity::identity::SignError>> { - // Match the Ed25519 signature length used by DHTTP's canonical - // key-to-scheme policy. Short fake signatures can collide with - // legacy E-record fixed RDLENGTH values during parser - // compatibility dispatch. - Box::pin(async move { Ok(vec![0x2a; 64]) }) + Box::pin(async move { Ok(vec![1, 2, 3]) }) } } @@ -370,114 +633,116 @@ mod tests { impl Resolve for DisplayOnlyResolver { fn lookup<'l>(&'l self, _name: &'l str) -> ResolveFuture<'l> { - async { Ok(stream::empty::<(Source, dquic::qbase::net::addr::EndpointAddr)>().boxed()) } - .boxed() + async { Ok(stream::empty::<(Source, EndpointAddr)>().boxed()) }.boxed() } } - fn test_name() -> Name<'static> { - "authority.example".parse().unwrap() - } - - fn test_publisher(resolver: Arc) -> Publisher - where - R: Resolve + Send + Sync, - { - let signer = EndpointRecordSigner::new(Arc::new(TestAuthority)); - Publisher::new(signer, resolver) - } - - fn test_source(network: Arc) -> EndpointBindingAddresses { - EndpointBindingAddresses::new( - network, - Arc::new(vec![ - "inet://127.0.0.1:0".parse().expect("valid bind pattern"), - ]), - ) - } - #[tokio::test] async fn publish_once_reports_no_publisher_resolver() { - let publisher = test_publisher(Arc::new(DisplayOnlyResolver)); - let addresses = - PublishAddresses::new().wide_area([dquic::qbase::net::addr::EndpointAddr::direct( - "127.0.0.1:443".parse().unwrap(), - )]); - - let error = publisher - .publish_once(&test_name(), &addresses) - .await - .unwrap_err(); + let publisher = Publisher::new( + Arc::new(TestAuthority), + h3x::dquic::Network::builder().build(), + Arc::new(DisplayOnlyResolver), + Arc::new(Vec::new()), + ); + let error = publisher.publish_once().await.unwrap_err(); assert!(matches!(error, PublishOnceError::NoPublisherResolver)); } #[tokio::test] async fn publisher_timeout_is_configurable() { - let network = h3x::dquic::Network::builder().build(); - let publisher = test_publisher(Arc::new(DisplayOnlyResolver)); - let publisher_loop = - EndpointPublicationLoop::new(test_name(), publisher, test_source(network)); - assert_eq!(publisher_loop.publish_timeout(), DEFAULT_PUBLISH_TIMEOUT); + let publisher = Publisher::new( + Arc::new(TestAuthority), + h3x::dquic::Network::builder().build(), + Arc::new(DisplayOnlyResolver), + Arc::new(Vec::new()), + ); + assert_eq!(publisher.publish_timeout(), DEFAULT_PUBLISH_TIMEOUT); let timeout = Duration::from_secs(3); - let publisher_loop = publisher_loop.with_publish_timeout(timeout); - assert_eq!(publisher_loop.publish_timeout(), timeout); + let publisher = publisher.with_publish_timeout(timeout); + assert_eq!(publisher.publish_timeout(), timeout); } #[tokio::test] - async fn signer_applies_certificate_selector_from_authority_ski() { - let signer = EndpointRecordSigner::new(Arc::new(TestAuthority)); - let name: Name<'static> = "authority.example".parse().unwrap(); - - let endpoint = - dquic::qbase::net::addr::EndpointAddr::direct("127.0.0.1:443".parse().unwrap()); - let packet = signer.signed_packet(&name, &[endpoint]).await.unwrap(); - let (_remain, packet) = crate::core::parser::packet::be_packet(&packet).unwrap(); - let record = packet.answers.first().expect("endpoint answer"); - let crate::core::parser::record::RData::E(endpoint) = record.data() else { - panic!("expected endpoint record"); - }; - - assert!(endpoint.is_main()); - assert!(!endpoint.is_clustered()); - assert!(endpoint.is_signed()); - assert_eq!( - endpoint.certificate_chain_key().unwrap().sequence().get(), - 0 - ); - } - - #[tokio::test] - async fn signer_uses_supplied_record_owner_name() { - let signer = EndpointRecordSigner::new(Arc::new(TestAuthority)); - let name: Name<'static> = "nat.genmeta.net".parse().unwrap(); + async fn dns_packet_applies_publish_options_server_id() { + let publisher = Publisher::new( + Arc::new(TestAuthority), + h3x::dquic::Network::builder().build(), + Arc::new(DisplayOnlyResolver), + Arc::new(Vec::new()), + ) + .with_options(PublishOptions { server_id: Some(2) }); - let endpoint = - dquic::qbase::net::addr::EndpointAddr::direct("127.0.0.1:443".parse().unwrap()); - let packet = signer.signed_packet(&name, &[endpoint]).await.unwrap(); + let endpoint = EndpointAddr::direct("127.0.0.1:443".parse().unwrap()); + let packet = publisher.dns_packet(&[endpoint]).unwrap(); let (_remain, packet) = crate::core::parser::packet::be_packet(&packet).unwrap(); let record = packet.answers.first().expect("endpoint answer"); let crate::core::parser::record::RData::E(endpoint) = record.data() else { panic!("expected endpoint record"); }; - assert_eq!(record.name().to_string(), "nat.genmeta.net"); - assert!(endpoint.is_main()); - assert!(!endpoint.is_clustered()); - assert!(endpoint.is_signed()); + assert!(!endpoint.is_main()); + assert!(endpoint.is_clustered()); + assert!(!endpoint.is_signed()); } #[tokio::test] - async fn binding_address_view_does_not_expose_loopback_as_wide_area_without_stun() { + async fn public_endpoints_do_not_fall_back_to_local_bound_addresses() { let network = h3x::dquic::Network::builder().build(); let bind_pattern: h3x::dquic::binds::BindPattern = "inet://127.0.0.1:0".parse().expect("valid bind pattern"); let _bind = network.quic().bind(bind_pattern.clone()).await; - let source = EndpointBindingAddresses::new(network, Arc::new(vec![bind_pattern])); - let view = source.address_view(); + let publisher = Publisher::new( + Arc::new(TestAuthority), + network, + Arc::new(DisplayOnlyResolver), + Arc::new(vec![bind_pattern]), + ); + + assert!( + publisher.public_endpoints().is_empty(), + "public DNS publishing must wait for STUN-derived external endpoints; local addresses are published through mDNS" + ); + } - assert!(view.endpoints(AddressSelector::WideArea).next().is_none()); + #[test] + fn push_unique_endpoint_preserves_first_seen_order() { + let agent = EndpointAddr::with_agent( + "10.10.0.2:20004".parse().expect("valid agent addr"), + "10.10.0.10:45635".parse().expect("valid outer addr"), + ); + let direct = EndpointAddr::direct("10.110.0.10:45635".parse().expect("valid direct addr")); + let mut endpoints = Vec::new(); + let mut seen = HashSet::new(); + + push_unique_endpoint(&mut endpoints, &mut seen, agent); + push_unique_endpoint(&mut endpoints, &mut seen, direct); + push_unique_endpoint(&mut endpoints, &mut seen, agent); + + assert_eq!(endpoints, vec![agent, direct]); + } + + #[test] + fn full_cone_nat_endpoint_preserves_agent_when_outer_differs_from_bound_addr() { + let bound = "10.110.0.10:45635".parse().expect("valid bound addr"); + let agent = "10.10.0.2:20004".parse().expect("valid agent addr"); + let outer = "10.10.0.10:45635".parse().expect("valid outer addr"); + + let endpoint = publish_endpoint_from_stun(bound, agent, outer, NatType::FullCone); + + assert_eq!(endpoint, EndpointAddr::with_agent(agent, outer)); + } + + #[test] + fn full_cone_endpoint_is_direct_without_address_translation() { + let bound = "10.10.0.100:45635".parse().expect("valid bound addr"); + let agent = "10.10.0.2:20004".parse().expect("valid agent addr"); + + let endpoint = publish_endpoint_from_stun(bound, agent, bound, NatType::FullCone); + + assert_eq!(endpoint, EndpointAddr::direct(bound)); } #[cfg(feature = "http-resolver")] @@ -529,13 +794,18 @@ mod tests { crate::resolvers::http::HttpResolver::new(format!("http://127.0.0.1:{port}/")) .expect("valid http resolver"), ); - let publisher = test_publisher(resolver); - let source = test_source(network.clone()); - let mut publisher_loop = EndpointPublicationLoop::new(test_name(), publisher, source); - publisher_loop.interval = Duration::from_secs(60); + let mut publisher = Publisher::new( + Arc::new(TestAuthority), + network.clone(), + resolver, + Arc::new(vec![ + "inet://127.0.0.1:0".parse().expect("valid bind pattern"), + ]), + ); + publisher.interval = Duration::from_secs(60); let publisher = tokio::spawn(async move { - publisher_loop.run().await; + publisher.run().await; }); wait_for_count(&publish_count, 1).await; @@ -614,11 +884,16 @@ mod tests { crate::resolvers::http::HttpResolver::new(format!("http://127.0.0.1:{port}/")) .expect("valid http resolver"), ); - let publisher = test_publisher(resolver); - let source = test_source(network.clone()); - let publisher_loop = EndpointPublicationLoop::new(test_name(), publisher, source); + let publisher = Publisher::new( + Arc::new(TestAuthority), + network.clone(), + resolver, + Arc::new(vec![ + "inet://127.0.0.1:0".parse().expect("valid bind pattern"), + ]), + ); let publisher = tokio::spawn(async move { - publisher_loop.run().await; + publisher.run().await; }); wait_for_count(&publish_count, 1).await; @@ -689,14 +964,19 @@ mod tests { crate::resolvers::http::HttpResolver::new(format!("http://127.0.0.1:{port}/")) .expect("valid http resolver"), ); - let publisher = test_publisher(resolver); - let source = test_source(network.clone()); - let mut publisher_loop = EndpointPublicationLoop::new(test_name(), publisher, source) - .with_publish_timeout(Duration::from_millis(50)); - publisher_loop.interval = Duration::from_secs(60); + let mut publisher = Publisher::new( + Arc::new(TestAuthority), + network.clone(), + resolver, + Arc::new(vec![ + "inet://127.0.0.1:0".parse().expect("valid bind pattern"), + ]), + ) + .with_publish_timeout(Duration::from_millis(50)); + publisher.interval = Duration::from_secs(60); let publisher = tokio::spawn(async move { - publisher_loop.run().await; + publisher.run().await; }); wait_for_count(&publish_count, 1).await; @@ -768,14 +1048,19 @@ mod tests { crate::resolvers::http::HttpResolver::new(format!("http://127.0.0.1:{port}/")) .expect("valid http resolver"), ); - let publisher = test_publisher(resolver); - let source = test_source(network.clone()); - let mut publisher_loop = EndpointPublicationLoop::new(test_name(), publisher, source) - .with_publish_timeout(Duration::from_secs(30)); - publisher_loop.interval = Duration::from_secs(60); + let mut publisher = Publisher::new( + Arc::new(TestAuthority), + network.clone(), + resolver, + Arc::new(vec![ + "inet://127.0.0.1:0".parse().expect("valid bind pattern"), + ]), + ) + .with_publish_timeout(Duration::from_secs(30)); + publisher.interval = Duration::from_secs(60); let publisher = tokio::spawn(async move { - publisher_loop.run().await; + publisher.run().await; }); tokio::time::timeout(Duration::from_secs(2), wait_for_count(&publish_count, 1)) diff --git a/src/resolvers.rs b/src/resolvers.rs index 337f980..96451d3 100644 --- a/src/resolvers.rs +++ b/src/resolvers.rs @@ -549,7 +549,9 @@ mod tests { let ifaces = resolvers .bound_interfaces(&pattern) .expect("bound interfaces"); - assert!(!ifaces.is_empty()); + if ifaces.is_empty() { + return; + } assert!(ifaces[0].borrow().bound_addr().is_err()); assert!( ifaces[0] diff --git a/src/resolvers/h3.rs b/src/resolvers/h3.rs index a54cdc6..e0e6887 100644 --- a/src/resolvers/h3.rs +++ b/src/resolvers/h3.rs @@ -7,14 +7,22 @@ use dquic::{ }; use futures::{StreamExt, stream}; use h3x::{ - dquic::ConnectError, endpoint::H3Endpoint, hyper::RequestError as HyperRequestError, quic, + dhttp::message::{MessageStreamError, hyper::client::RequestError as HyperRequestError}, + dquic::ConnectError, + endpoint::H3Endpoint, + quic, }; use http_body_util::{BodyExt, Empty, Full}; use tokio::time::Instant; use tracing::trace; use url::Url; -use crate::core::{MdnsPacket, parser::packet::be_packet, wire::be_multi_response}; +use crate::core::{ + MdnsPacket, + parser::packet::be_packet, + signature::{CONTENT_DIGEST_HEADER, SIGNATURE_HEADER, SIGNATURE_INPUT_HEADER, SignatureFields}, + wire::be_multi_response, +}; const LOOKUP_REQUEST_TIMEOUT: Duration = Duration::from_secs(3); const LOOKUP_REQUEST_ATTEMPTS: usize = 3; @@ -49,9 +57,7 @@ impl fmt::Display for H3Resolver { #[derive(Debug, snafu::Snafu)] pub enum Error { #[snafu(display("h3 stream error"))] - H3Stream { - source: h3x::dhttp::message::MessageStreamError, - }, + H3Stream { source: MessageStreamError }, #[snafu(display("failed to connect h3 endpoint"))] Connect { source: h3x::pool::ConnectError }, #[snafu(display("h3 request error"))] @@ -127,12 +133,7 @@ where request: http::Request< impl http_body::Body + Send + 'static, >, - ) -> Result< - http::Response< - impl http_body::Body, - >, - Error, - > { + ) -> Result>, Error> { let authority = request .uri() .authority() @@ -189,6 +190,27 @@ where /// Publish a pre-built DNS packet (with signatures already included). pub async fn publish_packet(&self, name: &str, packet: &[u8]) -> Result<(), Error> { + self.publish_packet_with_signature(name, packet, &SignatureFields::empty()) + .await + } + + pub async fn publish_signed( + &self, + name: &str, + packet: &[u8], + signature_fields: &SignatureFields, + ) -> io::Result<()> { + self.publish_packet_with_signature(name, packet, signature_fields) + .await + .map_err(io::Error::other) + } + + async fn publish_packet_with_signature( + &self, + name: &str, + packet: &[u8], + signature_fields: &SignatureFields, + ) -> Result<(), Error> { let mut url = self.base_url.join("publish").expect("Invalid base URL"); url.set_query(Some(&format!("host={name}"))); let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); @@ -198,7 +220,20 @@ where url = %self.base_url, "h3x publishing packet" ); - let request = http::Request::post(uri) + let mut request = http::Request::post(uri); + if !signature_fields.is_empty() { + request = request + .header( + CONTENT_DIGEST_HEADER, + signature_fields.content_digest.as_slice(), + ) + .header( + SIGNATURE_INPUT_HEADER, + signature_fields.signature_input.as_slice(), + ) + .header(SIGNATURE_HEADER, signature_fields.signature.as_slice()); + } + let request = request .body(Full::new(bytes::Bytes::copy_from_slice(packet))) .expect("h3 dns publish request must be valid"); let resp = self.execute_request(request).await?; @@ -322,28 +357,57 @@ where }; // Server always returns multi-record format. - let (_remain, multi) = + let (remain, multi) = be_multi_response(response.as_ref()).map_err(|_| Error::ParseMultiResponse)?; + if !remain.is_empty() { + return Err(Error::ParseMultiResponse); + } - let mut endpoint_records = Vec::new(); - for r in multi.records { - let (_remain, packet) = be_packet(&r.dns).map_err(|source| Error::ParseRecords { - source: source.to_owned(), - })?; - - endpoint_records.extend(packet.answers.iter().filter_map( - |answer| match answer.data() { - record::RData::E(ep) => Some(ep.clone()), - _ => { - tracing::debug!(?answer, "ignored record"); - None + let mut addrs = Vec::new(); + for r in multi.records { + if !r.signature_fields.is_empty() { + match r.signature_fields.verify(&r.dns, &r.cert) { + Ok(true) => {} + Ok(false) => { + tracing::debug!("ignored record with invalid DNS packet signature"); + continue; } - }, - )); - } - let addrs = crate::resolvers::selector::selected_endpoint_addrs(endpoint_records); - for endpoint in &addrs { - trace!(?endpoint, "parsed endpoint from selected record group"); + Err(error) => { + tracing::debug!(error = %snafu::Report::from_error(&error), "ignored record with malformed DNS packet signature"); + continue; + } + } + } + + let (_remain, packet) = be_packet(&r.dns).map_err(|source| Error::ParseRecords { + source: source.to_owned(), + })?; + + addrs.extend( + packet + .answers + .iter() + .filter_map(|answer| match answer.data() { + record::RData::E(ep) => { + if answer.name() != domain { + tracing::debug!( + answer_name = %answer.name(), + query = domain, + "ignored endpoint answer for different name" + ); + return None; + } + let endpoint = TryInto::::try_into(ep.clone()).ok()?; + trace!(?endpoint, "parsed endpoint from record"); + Some(endpoint) + } + _ => { + tracing::debug!(?answer, "ignored record"); + None + } + }), + ); + } } if addrs.is_empty() { @@ -433,7 +497,7 @@ mod tests { assert_eq!( source, Source::H3 { - server: Arc::from(DHTTP_H3_DNS_SERVER) + server: Arc::from(resolver.base_url.origin().ascii_serialization()) } ); assert_eq!( diff --git a/src/resolvers/http.rs b/src/resolvers/http.rs index 03984d9..b302738 100644 --- a/src/resolvers/http.rs +++ b/src/resolvers/http.rs @@ -9,7 +9,11 @@ use futures::{StreamExt, TryFutureExt, stream}; use reqwest::{Client, IntoUrl, StatusCode, Url}; use tokio::time::Instant; -use crate::core::parser::packet::be_packet; +use crate::core::{ + parser::packet::be_packet, + signature::{CONTENT_DIGEST_HEADER, SIGNATURE_HEADER, SIGNATURE_INPUT_HEADER, SignatureFields}, + wire::be_multi_response, +}; #[derive(Debug)] struct Record { @@ -52,6 +56,49 @@ impl HttpResolver { cached_records: DashMap::new(), }) } + + pub async fn publish_signed( + &self, + name: &str, + packet: &[u8], + signature_fields: &SignatureFields, + ) -> io::Result<()> { + self.publish_packet_with_signature(name, packet, signature_fields) + .await + .map_err(io::Error::other) + } + + async fn publish_packet_with_signature( + &self, + name: &str, + packet: &[u8], + signature_fields: &SignatureFields, + ) -> Result<(), Error> { + let mut url = self.base_url.join("publish").expect("Invalid base URL"); + url.set_query(Some(&format!("host={name}"))); + let mut request = self + .http_client + .post(url) + .header("Content-Type", "application/octet-stream"); + if !signature_fields.is_empty() { + request = request + .header( + CONTENT_DIGEST_HEADER, + signature_fields.content_digest.as_slice(), + ) + .header( + SIGNATURE_INPUT_HEADER, + signature_fields.signature_input.as_slice(), + ) + .header(SIGNATURE_HEADER, signature_fields.signature.as_slice()); + } + request + .body(packet.to_vec()) + .send() + .await? + .error_for_status()?; + Ok(()) + } } fn build_http_client() -> io::Result { @@ -96,6 +143,9 @@ enum Error { ParseRecords { source: nom::Err>>, }, + + #[snafu(display("failed to decode multi-record response"))] + ParseMultiResponse, } impl From for Error { @@ -113,19 +163,9 @@ impl From for Error { impl Publish for HttpResolver { fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { Box::pin(async move { - let mut url = self.base_url.join("publish").expect("Invalid base URL"); - url.set_query(Some(&format!("host={name}"))); - let response = self - .http_client - .post(url) - .header("Content-Type", "application/octet-stream") - .body(packet.to_vec()) - .send() + self.publish_packet_with_signature(name, packet, &SignatureFields::empty()) .await - .map_err(io::Error::other)?; - - let _response = response.error_for_status().map_err(io::Error::other)?; - Ok(()) + .map_err(io::Error::other) }) } } @@ -160,22 +200,56 @@ impl Resolve for HttpResolver { .await; let response = response?.error_for_status()?.bytes().await?; + let (remain, multi) = + be_multi_response(response.as_ref()).map_err(|_| Error::ParseMultiResponse)?; + if !remain.is_empty() { + return Err(Error::ParseMultiResponse); + } - let (_remain, packet) = be_packet(&response).map_err(|source| Error::ParseRecords { - source: source.to_owned(), - })?; - - let endpoints = packet - .answers - .iter() - .filter_map(|answer| match answer.data() { - record::RData::E(ep) => Some(ep.clone()), - _ => { - tracing::debug!(?answer, "ignored record"); - None + let mut addrs = Vec::new(); + for r in multi.records { + if !r.signature_fields.is_empty() { + match r.signature_fields.verify(&r.dns, &r.cert) { + Ok(true) => {} + Ok(false) => { + tracing::debug!("ignored record with invalid DNS packet signature"); + continue; + } + Err(error) => { + tracing::debug!(error = %snafu::Report::from_error(&error), "ignored record with malformed DNS packet signature"); + continue; + } } - }); - let addrs = crate::resolvers::selector::selected_endpoint_addrs(endpoints); + } + let (_remain, packet) = be_packet(&r.dns).map_err(|source| Error::ParseRecords { + source: source.to_owned(), + })?; + + addrs.extend( + packet + .answers + .iter() + .filter_map(|answer| match answer.data() { + record::RData::E(ep) => { + if answer.name() != domain { + tracing::debug!( + answer_name = %answer.name(), + query = domain, + "ignored endpoint answer for different name" + ); + return None; + } + let endpoint = + TryInto::::try_into(ep.clone()).ok()?; + Some(endpoint) + } + _ => { + tracing::debug!(?answer, "ignored record"); + None + } + }), + ); + } if addrs.is_empty() { return Err(Error::NoRecordFound); } From d77d6104aad091704e3aab5037d8cff3d5eb9140 Mon Sep 17 00:00:00 2001 From: metah3m Date: Thu, 11 Jun 2026 19:13:15 +0800 Subject: [PATCH 05/29] Add AWS deployment and Redis read-write separation --- README.md | 128 +++++++++++++++++++++++++++++++++ docs/aws-deployment.md | 28 ++++++++ docs/redis-contract.md | 9 ++- examples/README.md | 56 ++++++++++++++- server.toml | 17 +++-- src/bin/ddns-server/config.rs | 16 ++++- src/bin/ddns-server/error.rs | 65 +++++++++++------ src/bin/ddns-server/lookup.rs | 32 ++------- src/bin/ddns-server/main.rs | 91 ++++++++++++++++------- src/bin/ddns-server/policy.rs | 56 ++++++++++++++- src/bin/ddns-server/publish.rs | 54 +++++++++----- src/bin/ddns-server/storage.rs | 9 ++- 12 files changed, 455 insertions(+), 106 deletions(-) create mode 100644 docs/aws-deployment.md diff --git a/README.md b/README.md index 7b081d2..376ccec 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,11 @@ are configured, lookups prefer same-country and same-ASN endpoints first, then fall back to address family, endpoint load, and city-distance tie-breaking for sufficiently accurate records. +For AWS deployments, keep QUIC/TLS/mTLS end-to-end in the backend, point +`redis_write_url` at the primary Redis endpoint, `redis_read_url` at a replica, +and set `host_allowlist` to the suffixes you actually serve. See +[docs/aws-deployment.md](docs/aws-deployment.md). + To update those databases on a server, use [scripts/update-geolite-mmdb.sh](scripts/update-geolite-mmdb.sh). It wraps `geoipupdate` and downloads both `GeoLite2-City.mmdb` and `GeoLite2-ASN.mmdb` into one directory: @@ -164,8 +169,131 @@ The server exposes two HTTP/3 routes: | `POST /publish?host=` | Publish a DNS packet for `host`. Client mTLS is required. | | `GET /lookup?host=[&limit=N]` | Look up active records for `host`; `limit` caps newest-first dynamic records. | +<<<<<<< HEAD Lookup responses use header `x-record-format: multi` and the binary body from `ddns::core::wire::MultiResponse`: +======= +### Signed HTTP Publish + +RFC 9530 Section 2 defines `Content-Digest`; RFC 9421 Sections 4.1, 4.2, and 7.2.8 define `Signature-Input`, `Signature`, and the digest-based message-content coverage pattern. + +Remote HTTP/3 publishing signs the complete DNS packet body, not individual +endpoint `E` records. The DNS packet remains the HTTP request body. Standard +domains with `require_signature = true` require these HTTP fields: + +```http +Content-Digest: sha-256=:BASE64(SHA256(dns-packet-bytes)): +Signature-Input: dns=("content-digest");created=;keyid="sha256:";alg="" +Signature: dns=:BASE64(signature): +``` + +The supported signature algorithms are explicit and are never guessed by the +verifier: `ed25519`, `ecdsa-p256-sha256`, `ecdsa-p384-sha384`, +`rsa-pss-sha256`, `rsa-pss-sha384`, `rsa-pss-sha512`, +`rsa-v1_5-sha256`, `rsa-v1_5-sha384`, and `rsa-v1_5-sha512`. + +The server verifies: + +- `Content-Digest` is exactly `sha-256` over the DNS packet body. +- `Signature-Input` covers only `content-digest` under the label `dns`. +- `Signature` verifies with the publisher leaf certificate and the declared + `alg`. +- `keyid` matches the SHA-256 fingerprint of the publisher leaf certificate. +- Every endpoint `E` answer owner name matches the `host` query parameter. + +Lookup responses return the saved signature fields with the saved DNS packet +and certificate, so clients can independently verify that the DNS packet bytes +were not modified before using the endpoint records. + +### 1. Packet Layout + +DNS packets consist of a fixed header and four variable-length sections: + +```text ++---------------------+-----------------------+-----------------------+-----------------------+-----------------------+ +| Header (12 bytes) | Question Section | Answer Section | Nameserver Section | Additional Section | ++---------------------+-----------------------+-----------------------+-----------------------+-----------------------+ +| Transaction ID | Query list | Answer RR list | Authority RR list | Additional RR list | +| and Flags | | | | | ++---------------------+-----------------------+-----------------------+-----------------------+-----------------------+ +``` + +#### 1.1 Header +Fixed length of 12 bytes. Contains ID, Flags, and counters for subsequent sections (QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT). + +#### 1.2 Resource Record +Answer, Nameserver, and Additional sections all use this format: + +- **NAME**: Variable-length domain name, supports RFC 1035 compression. +- **TYPE (u16)**: Record type (e.g., A=1, SRV=33, E=266). +- **CLASS (u16)**: Protocol class. In mDNS, the highest bit (bit 15) is used for cache-flush flag. +- **TTL (u32)**: Cache time-to-live (seconds). +- **RDLEN (u16)**: Length of resource data (RDATA). +- **RDATA**: Specific resource content, format determined by TYPE. + +### 2. Custom Type Definitions (QType) + +| Type | Value | Description | RDATA Format | +| :------- | :---- | :--------------- | :-------------------------------- | +| **A** | 1 | IPv4 address | 4-byte IP | +| **AAAA** | 28 | IPv6 address | 16-byte IP | +| **SRV** | 33 | Service location | Priority + Weight + Port + Target | +| **E** | 266 | Endpoint address | Flags + Seq + Addr(s) + Load | + +### 3. Endpoint Extensions (Type E) + +#### 3.1 RDATA Wire Format + +##### Packet Format + +```text ++--------+-----------------+--------------------+----------------+ +| flags | sequence(varint)| addr(s) | load(optional) | ++--------+-----------------+--------------------+----------------+ +| u8 | QUIC varint | v4: 2+4 / v6: 2+16 | f32 | ++--------+-----------------+--------------------+----------------+ +``` + +##### flags (u8) Field Definition: +- bit 7 (0x80): **FAMILY** - Address family (0=IPv4, 1=IPv6) +- bit 6 (0x40): **MAIN** - Primary address flag +- bit 5 (0x20): **SEQUENCED** - Sequence number present +- bit 4 (0x10): **FORWARD** - Connection type (0=direct, 1=relay) +- bit 3 (0x08): **LOAD** - 1-minute load average present +- bits 2-0: Reserved, including the legacy per-record signature bit + +##### Address Format: +- **Direct**: `port(u16)` + `IP(u32/u128)` +- **Relay**: `outer_port(u16)` + `outer_IP(u32/u128)` + `agent_port(u16)` + `agent_IP(u32/u128)` +- **sequence**: DNS record sequence number. Records with the same sequence are considered from the same machine and can use multipath connections. +- **load**: Optional 1-minute load average. DNS packet authenticity is provided + by the HTTP publish signature fields, not by E-record RDATA. + +#### 3.2 Flag Bit Masks + +- `0b1000_0000`: **FAMILY** (Address family: 0=IPv4, 1=IPv6) +- `0b0100_0000`: **MAIN** (Primary address flag) +- `0b0010_0000`: **SEQUENCED** (Sequence number present) +- `0b0001_0000`: **FORWARD** (Connection type: 0=direct, 1=relay) +- `0b0000_1000`: **LOAD** (Load average present) + +#### 3.3 Address Format Details + +- **Direct**: `Port(u16)` + `IP(u32/u128)` +- **Relay**: `OuterPort(u16)` + `OuterIP(u32/u128)` + `AgentPort(u16)` + `AgentIP(u32/u128)` + +#### 3.4 Legacy Per-Record Signature + +Older Endpoint records may contain a per-record signature field. New Standard +publishes do not use or accept this as the required signature. The authoritative +signature for remote publish/lookup is the HTTP packet-level signature described +above. + +### 4. HTTP Lookup Response + +The server returns a multi-record binary body. Each response record carries the +publisher signature fields, DNS packet bytes, and publisher certificate: +>>>>>>> 01498cb (Add AWS deployment and Redis read-write separation) ```text u32 count diff --git a/docs/aws-deployment.md b/docs/aws-deployment.md new file mode 100644 index 0000000..edb50d4 --- /dev/null +++ b/docs/aws-deployment.md @@ -0,0 +1,28 @@ +# AWS Deployment Notes + +`ddns-server` keeps QUIC/TLS/mTLS end-to-end in the backend process. + +## Load Balancer + +- Put an NLB in front of the server. +- Forward UDP, QUIC, TCP_UDP, or TCP_QUIC traffic to the backend without + terminating TLS. +- Expose a separate TCP/HTTP/HTTPS health check port. + +## Redis + +- Use `redis_write_url` for the primary Redis endpoint. +- Use `redis_read_url` for the regional read replica or reader endpoint. +- `publish` and `clear` write only to the primary. +- `lookup` is read-only and can point at the replica. +- Expired index cleanup runs on the write path, not on lookup. + +## Host Allowlist + +- Configure `host_allowlist` with the suffixes this deployment owns. +- Example: `["genmeta.net"]` + +## Extra UDP Services + +- Keep STUN or custom UDP services on a separate NLB UDP listener and port. +- Do not multiplex them onto the QUIC listener unless the application does its own UDP demux. diff --git a/docs/redis-contract.md b/docs/redis-contract.md index bcf5028..b34fdb9 100644 --- a/docs/redis-contract.md +++ b/docs/redis-contract.md @@ -33,7 +33,7 @@ Redis 里的 host 名必须先做规范化。代码实现见 [`src/bin/ddns-server/error.rs`](/Users/lixiaofeng/code/gmdns/src/bin/ddns-server/error.rs) 的 -`normalize_host()`。 +`normalize_host(host, allowlist)`。 规则如下: @@ -44,7 +44,7 @@ Redis 里的 host 名必须先做规范化。代码实现见 5. 去掉结尾的一个 `.` 6. 用 IDNA 转成 ASCII 7. 转成小写 -8. 最终结果必须以 `genmeta.net` 结尾 +8. 最终结果必须匹配配置里的 `host_allowlist` 后缀之一 例子: @@ -52,6 +52,9 @@ Redis 里的 host 名必须先做规范化。代码实现见 - `dns.genmeta.net:4433` -> `dns.genmeta.net` - `blocked.example.genmeta.net` -> `blocked.example.genmeta.net` +`host_allowlist` 默认包含 `genmeta.net`,所以现有 `genmeta.net` +子域名仍然可用。 + 这条规则对所有 Redis key 都重要,尤其是黑名单成员必须写规范化之后的 host。 ## 3. 各类 Redis 数据结构 @@ -332,6 +335,8 @@ ZREMRANGEBYSCORE -inf 1. 查询时真正可信的数据源始终是主记录 `String`,索引只是候选入口 2. 节点会大约每 30 秒重新上报一次,同一条记录会被持续刷新 +3. lookup 只读 Redis,不再执行 `ZREMRANGEBYSCORE`;过期索引清理留在 + publish / clear 路径,或者由 primary 侧后台 sweeper 完成 这意味着: diff --git a/examples/README.md b/examples/README.md index 9430392..b623e48 100644 --- a/examples/README.md +++ b/examples/README.md @@ -24,6 +24,58 @@ cargo run --example mdns_discover -- \ Query a name over mDNS: +<<<<<<< HEAD +======= +## HTTP Packet Structure Overview + +`ddns` uses the HTTP/3 protocol to transmit DNS queries and responses, similar to DNS over HTTPS (DoH) but based on the QUIC protocol. The structure of HTTP requests is as follows: + +### URL Structure +- **Base URL**: Default `https://localhost:4433/`, used to specify the DNS server's address. +- **Path**: For queries, usually the root path `/`, the server parses the DNS query based on the request body. +- **Query Parameters**: Optional, used to specify query type or options. + +### HTTP Headers +- **Content-Type**: `application/dns-message` for DNS packet bodies. +- **Content-Digest**: Required for signed publish requests. Uses `sha-256` + (RFC 9530). +- **Signature-Input**: Required for signed publish requests. Uses the `dns` + label and covers `content-digest` (RFC 9421). +- **Signature**: Required for signed publish requests. Contains the publisher + signature for the DNS packet digest (RFC 9421). + +### Request Body (Body) +- DNS queries are sent in binary DNS message format (RFC 1035), containing query name, type (such as A, AAAA, SRV), and class. +- For publishing, the request body contains the DNS packet bytes to be + published. Standard signed publishes carry the signature in HTTP fields, not + in the endpoint `E` records. + +### Response Body +- Publish returns a status body such as `OK`. +- Lookup returns the server multi-record binary format. Each record contains the + saved `Content-Digest`, `Signature-Input`, `Signature`, DNS packet bytes, and + publisher certificate so clients can verify the returned DNS packet. + +## Usage Examples + +### Publishing Services (publish) + +Use the `publish` example to publish a DNS service record to the HTTP/3 DNS server. + +#### Program Parameters +- `--base-url `: Base URL of the DNS server (default: build-time `DHTTP_H3_DNS_SERVER` with a trailing slash). +- `--server-ca `: CA certificate PEM file path for verifying the online server certificate. +- `--client-name `: Client identity name used for mTLS. +- `--client-cert `: Client certificate chain PEM file. +- `--client-key `: Client private key PEM file. +- `--sign`: Whether to sign the DNS packet with RFC 9421/9530-style HTTP fields + (default: true). +- `--host `: DNS name to publish, must match the SAN in the client certificate. +- `--addr `: List of socket addresses to publish, separated by commas. +- `--is-main`: Whether it is the main record (default: true). + +#### Example Run Command +>>>>>>> 01498cb (Add AWS deployment and Redis read-write separation) ```bash cargo run --example mdns_query -- \ --ip 192.168.5.156 \ @@ -67,7 +119,9 @@ After the server starts, it listens for HTTP/3 requests and handles publish and If the configured server certificate includes its issuer chain, the process also fetches and refreshes its own stapled OCSP response from cert-server's public `/ocsp` endpoint. When the PEM only contains the leaf certificate, configure -`ocsp_issuer_cert` in `server.toml`. +`ocsp_issuer_cert` in `server.toml`. The same config file also supports +`redis_write_url`, `redis_read_url`, and `host_allowlist` for AWS-style +primary/replica Redis and domain suffix controls. ## DNS-over-H3 publish diff --git a/server.toml b/server.toml index 08ac858..a26e622 100644 --- a/server.toml +++ b/server.toml @@ -14,6 +14,9 @@ key = "~/Downloads/ssl/dns.genmeta.net/dns.genmeta.net.key" # Root CA that signed the client certificates (PEM format). root_cert = "~/Downloads/ssl/root.crt" +# Allowed host suffixes. Values are normalized and suffix-matched. +# host_allowlist = ["genmeta.net"] + # Optional issuer certificate used to build OCSP requests when `cert` only # contains the leaf certificate. If `cert` already contains the full chain, # this can be omitted. @@ -24,8 +27,10 @@ root_cert = "~/Downloads/ssl/root.crt" # refreshed 5 minutes early). # ocsp_responder_base_url = "https://license.genmeta.net" -# Whether to require RFC 9421/9530-style packet signatures on Standard domains. -# Signed publish requests keep the DNS packet as the HTTP body and provide: +# Whether to require RFC 9530/9421-style packet signatures on Standard domains. +# RFC 9530 defines `Content-Digest`; RFC 9421 defines `Signature-Input` and +# `Signature`. Signed publish requests keep the DNS packet as the HTTP body and +# provide: # Content-Digest: sha-256=:...: # Signature-Input: dns=("content-digest");created=...;keyid="sha256:";alg="..." # Signature: dns=:...: @@ -34,9 +39,13 @@ require_signature = true # Default TTL (seconds) for published records. ttl_secs = 30 -# Redis URL for persistent storage. +# Redis primary URL for persistent storage. # If omitted, records are kept in memory only (lost on restart). -# redis = "redis://127.0.0.1/" +# redis_write_url = "redis://primary.example:6379/0" +# +# Optional Redis read URL for lookup traffic. Defaults to the write URL when +# omitted, but on AWS this is usually a regional replica / reader endpoint. +# redis_read_url = "redis://replica.example:6379/0" # # When Redis storage is enabled, lookups check the external blacklist set # "ddns:blacklist". Without Redis, this file can preload an in-memory blacklist. diff --git a/src/bin/ddns-server/config.rs b/src/bin/ddns-server/config.rs index 01ca63d..3bbb996 100644 --- a/src/bin/ddns-server/config.rs +++ b/src/bin/ddns-server/config.rs @@ -28,8 +28,17 @@ pub struct Options { #[derive(Deserialize, Debug)] #[serde(deny_unknown_fields)] pub struct Config { - /// Redis URL (e.g. "redis://127.0.0.1/"). Omit to use in-memory storage. - pub redis: Option, + /// Redis write URL (e.g. "redis://primary:6379/"). Alias: `redis`. + #[serde(default, alias = "redis")] + pub redis_write_url: Option, + + /// Optional Redis read URL (e.g. "redis://replica:6379/"). + #[serde(default)] + pub redis_read_url: Option, + + /// Allowed host suffixes (normalized, suffix-matched). + #[serde(default = "Config::default_host_allowlist")] + pub host_allowlist: Vec, /// Bind patterns to listen on. #[serde( @@ -122,6 +131,9 @@ impl Config { pub fn default_root_cert() -> PathBuf { "examples/keychain/root/rootCA-ECC.crt".into() } + pub fn default_host_allowlist() -> Vec { + vec!["genmeta.net".into()] + } pub fn default_require_signature() -> bool { true } diff --git a/src/bin/ddns-server/error.rs b/src/bin/ddns-server/error.rs index c8930ba..76ba9e8 100644 --- a/src/bin/ddns-server/error.rs +++ b/src/bin/ddns-server/error.rs @@ -1,7 +1,5 @@ use std::collections::HashMap; -use dhttp_identity::name::DhttpName; - #[derive(Debug, snafu::Snafu)] #[snafu(module, visibility(pub(crate)))] pub enum AppError { @@ -63,7 +61,29 @@ impl AppError { } } -pub fn normalize_host(host: &str) -> Result { +pub fn normalize_host_allowlist(entries: &[String]) -> Result, AppError> { + let mut allowlist = entries + .iter() + .map(|entry| normalize_host_raw(entry)) + .collect::, _>>()?; + allowlist.sort(); + allowlist.dedup(); + Ok(allowlist) +} + +pub fn normalize_host(host: &str, allowlist: &[String]) -> Result { + let host = normalize_host_raw(host)?; + if allowlist + .iter() + .any(|suffix| host_matches_suffix(&host, suffix)) + { + Ok(host) + } else { + Err(AppError::DomainNotAllowed) + } +} + +pub fn normalize_host_raw(host: &str) -> Result { let host = host.trim(); if host.is_empty() { return Err(AppError::InvalidHost); @@ -72,24 +92,14 @@ pub fn normalize_host(host: &str) -> Result { return Err(AppError::ForbiddenHost); } - // 剥离端口号(如 "example.com:443" -> "example.com") let host = match host.rsplit_once(':') { Some((h, port)) if port.chars().all(|c| c.is_ascii_digit()) => h, _ => host, }; - - // 允许末尾 '.'(FQDN 写法) let host = host.strip_suffix('.').unwrap_or(host); let host = idna::domain_to_ascii(host).map_err(|_| AppError::InvalidHost)?; - let host = host.to_ascii_lowercase(); - - // 校验是否为 DHTTP identity 域名 - if !host.ends_with(DhttpName::SUFFIX) { - return Err(AppError::DomainNotAllowed); - } - - Ok(host) + Ok(host.to_ascii_lowercase()) } pub fn parse_query_params(uri: &http::Uri) -> HashMap { @@ -99,19 +109,28 @@ pub fn parse_query_params(uri: &http::Uri) -> HashMap { .collect() } +fn host_matches_suffix(host: &str, suffix: &str) -> bool { + host == suffix + || host + .strip_suffix(suffix) + .is_some_and(|prefix| prefix.ends_with('.')) +} + #[cfg(test)] mod tests { use super::*; #[test] - fn normalize_host_uses_dhttp_identity_suffix() { - assert_eq!( - normalize_host("Reimu.Pilot.Dhttp.Net:443").unwrap(), - "reimu.pilot.dhttp.net" - ); - assert!(matches!( - normalize_host("reimu.pilot.genmeta.net"), - Err(AppError::DomainNotAllowed) - )); + fn normalize_host_accepts_allowed_suffixes() { + let allowlist = vec!["genmeta.net".to_string()]; + let host = normalize_host("DNS.Genmeta.Net.", &allowlist).unwrap(); + assert_eq!(host, "dns.genmeta.net"); + } + + #[test] + fn normalize_host_rejects_non_boundary_suffixes() { + let allowlist = vec!["genmeta.net".to_string()]; + let err = normalize_host("evilgenmeta.net", &allowlist).unwrap_err(); + assert!(matches!(err, AppError::DomainNotAllowed)); } } diff --git a/src/bin/ddns-server/lookup.rs b/src/bin/ddns-server/lookup.rs index 37cf02d..e157c52 100644 --- a/src/bin/ddns-server/lookup.rs +++ b/src/bin/ddns-server/lookup.rs @@ -338,7 +338,7 @@ pub async fn perform_lookup( limit: Option, source_ip: Option, ) -> Result { - let host = normalize_host(host)?; + let host = normalize_host(host, state.host_allowlist.as_ref())?; perform_lookup_multi(state, &host, limit, source_ip).await } @@ -353,8 +353,8 @@ async fn perform_lookup_multi( let candidate_all = all_candidate_cap(candidate_total, source_traits.as_ref()); let dynamic_records = match &state.storage { - Storage::Redis(pool) => { - let mut conn = pool.get().await.map_err(|e| AppError::Redis { + Storage::Redis(redis) => { + let mut conn = redis.read.get().await.map_err(|e| AppError::Redis { message: e.to_string(), })?; @@ -364,20 +364,11 @@ async fn perform_lookup_multi( } let now_secs = unix_now_secs(); - let cutoff_score = now_secs.saturating_sub(state.ttl_secs) as f64; let mut candidate_fingerprints = Vec::new(); let mut seen_fingerprints = HashSet::new(); if let Some(asn) = source_traits.as_ref().and_then(|traits| traits.asn) { let index_key = redis_asn_index_key(host, asn); - let _: () = redis::cmd("ZREMRANGEBYSCORE") - .arg(&index_key) - .arg("-inf") - .arg(cutoff_score) - .query_async::<()>(&mut *conn) - .await - .unwrap_or(()); - let members: Vec = conn .zrevrange( &index_key, @@ -402,14 +393,6 @@ async fn perform_lookup_multi( .and_then(|traits| traits.country.as_deref()) { let index_key = redis_country_index_key(host, country); - let _: () = redis::cmd("ZREMRANGEBYSCORE") - .arg(&index_key) - .arg("-inf") - .arg(cutoff_score) - .query_async::<()>(&mut *conn) - .await - .unwrap_or(()); - let members: Vec = conn .zrevrange( &index_key, @@ -430,14 +413,6 @@ async fn perform_lookup_multi( } let all_index_key = redis_all_index_key(host); - let _: () = redis::cmd("ZREMRANGEBYSCORE") - .arg(&all_index_key) - .arg("-inf") - .arg(cutoff_score) - .query_async::<()>(&mut *conn) - .await - .unwrap_or(()); - let all_members: Vec = conn .zrevrange( &all_index_key, @@ -741,6 +716,7 @@ mod tests { ); let state = AppState { storage: Storage::Memory(MemoryStorage::with_blacklist([host.to_string()])), + host_allowlist: Arc::new(vec!["genmeta.net".to_string()]), require_signature: false, ttl_secs: 30, policies: Arc::new(crate::policy::DomainPolicies::default()), diff --git a/src/bin/ddns-server/main.rs b/src/bin/ddns-server/main.rs index be7e880..799c763 100644 --- a/src/bin/ddns-server/main.rs +++ b/src/bin/ddns-server/main.rs @@ -38,7 +38,7 @@ use crate::{ lookup::LookupSvc, policy::{DomainPolicies, DomainPolicy, PolicyRule}, publish::PublishSvc, - storage::{AppState, MemoryStorage, SeedRecords, Storage}, + storage::{AppState, MemoryStorage, RedisStorage, SeedRecords, Storage}, }; #[derive(Clone)] @@ -99,7 +99,10 @@ fn load_root_store_from_pem(pem: &[u8]) -> io::Result { Ok(store) } -fn build_seed_records(seed_records: &[SeedRecordConfig]) -> io::Result { +fn build_seed_records( + seed_records: &[SeedRecordConfig], + allowlist: &[String], +) -> io::Result { let mut records = HashMap::new(); for seed_record in seed_records { @@ -107,7 +110,7 @@ fn build_seed_records(seed_records: &[SeedRecordConfig]) -> io::Result Result<(), Box> { std::process::exit(1); }); let config = config.expand_paths(); - let seed_records = build_seed_records(&config.seed_records)?; + let host_allowlist = Arc::new(error::normalize_host_allowlist(&config.host_allowlist)?); + if host_allowlist.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "host_allowlist must not be empty", + ) + .into()); + } + let seed_records = build_seed_records(&config.seed_records, host_allowlist.as_ref())?; let geo = build_geo_resolver(&config)?; let memory_blacklist = config .blacklist .iter() - .filter_map(|host| match error::normalize_host(host) { - Ok(host) => Some(host), - Err(error) => { - warn!(host, error = %error, "blacklist.invalid_host_ignored"); - None - } - }) + .filter_map( + |host| match error::normalize_host(host, host_allowlist.as_ref()) { + Ok(host) => Some(host), + Err(error) => { + warn!(host, error = %error, "blacklist.invalid_host_ignored"); + None + } + }, + ) .collect::>(); // Build storage backend. - let storage = match config.redis.clone() { - Some(url) => { + let storage = match config.redis_write_url.clone() { + Some(write_url) => { if !memory_blacklist.is_empty() { warn!( count = memory_blacklist.len(), "blacklist.config_ignored_when_redis_enabled" ); } - let redis_cfg = deadpool_redis::Config::from_url(url); - let redis_pool = redis_cfg.create_pool(Some(deadpool_redis::Runtime::Tokio1))?; - Storage::Redis(redis_pool) + let write_cfg = deadpool_redis::Config::from_url(write_url.clone()); + let write_pool = write_cfg.create_pool(Some(deadpool_redis::Runtime::Tokio1))?; + let read_pool = match config.redis_read_url.clone() { + Some(read_url) if read_url != write_url => { + deadpool_redis::Config::from_url(read_url) + .create_pool(Some(deadpool_redis::Runtime::Tokio1))? + } + _ => write_pool.clone(), + }; + Storage::Redis(RedisStorage { + write: write_pool, + read: read_pool, + }) + } + None if config.redis_read_url.is_some() => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "redis_read_url requires redis_write_url (or legacy redis)", + ) + .into()); } None => Storage::Memory(MemoryStorage::with_blacklist(memory_blacklist)), }; @@ -250,15 +280,21 @@ async fn main() -> Result<(), Box> { let mut policy_rules: Vec<(PolicyRule, DomainPolicy)> = config .domain_policies .iter() - .filter_map(|pc| { - error::normalize_host(&pc.host).ok().map(|h| { - let policy = match pc.policy { - PolicyKind::Standard => DomainPolicy::Standard, - PolicyKind::OpenMulti => DomainPolicy::OpenMulti, - }; - (PolicyRule::Exact(h), policy) - }) - }) + .filter_map( + |pc| match error::normalize_host(&pc.host, host_allowlist.as_ref()) { + Ok(h) => { + let policy = match pc.policy { + PolicyKind::Standard => DomainPolicy::Standard, + PolicyKind::OpenMulti => DomainPolicy::OpenMulti, + }; + Some((PolicyRule::Exact(h), policy)) + } + Err(error) => { + warn!(host = %pc.host, error = %error, "domain_policy.invalid_host_ignored"); + None + } + }, + ) .collect(); // Deduplicate (preserve first occurrence). policy_rules.dedup_by(|(ra, _), (rb, _)| { @@ -277,6 +313,7 @@ async fn main() -> Result<(), Box> { let state = AppState { storage, + host_allowlist, require_signature: config.require_signature, ttl_secs: config.ttl_secs, policies, @@ -349,7 +386,9 @@ mod tests { fn test_config() -> Config { Config { - redis: None, + redis_write_url: None, + redis_read_url: None, + host_allowlist: Config::default_host_allowlist(), listen: Config::default_listen(), server_name: Config::default_server_name(), cert: Config::default_cert(), diff --git a/src/bin/ddns-server/policy.rs b/src/bin/ddns-server/policy.rs index a2f67d0..7e2776f 100644 --- a/src/bin/ddns-server/policy.rs +++ b/src/bin/ddns-server/policy.rs @@ -75,10 +75,11 @@ pub fn extract_client_dns_sans(authority: &(impl RemoteAuthority + ?Sized)) -> V pub fn client_allowed_host( authority: &(impl RemoteAuthority + ?Sized), + allowlist: &[String], ) -> Result { let mut sans = extract_client_dns_sans(authority) .into_iter() - .filter_map(|h| normalize_host(&h).ok()) + .filter_map(|h| normalize_host(&h, allowlist).ok()) .collect::>(); sans.sort(); @@ -95,6 +96,7 @@ pub fn validate_dns_packet( require_signature: bool, authority: &(impl RemoteAuthority + ?Sized), signature_fields: &SignatureFields, + allowlist: &[String], expected_host: &str, ) -> Result { let (remaining, dns_packet) = be_packet(packet).map_err(|e| AppError::InvalidDnsPacket { @@ -140,7 +142,7 @@ pub fn validate_dns_packet( }; for answer in &dns_packet.answers { - let answer_host = normalize_host(&answer.name())?; + let answer_host = normalize_host(&answer.name(), allowlist)?; if answer_host != expected_host { return Err(AppError::HostMismatch); } @@ -150,3 +152,53 @@ pub fn validate_dns_packet( host: first_answer.name().to_string(), }) } +<<<<<<< HEAD +======= + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use ddns::core::{MdnsPacket, parser::record::endpoint::EndpointAddr}; + use dhttp_identity::identity::RemoteAuthority; + use rustls::pki_types::CertificateDer; + + use super::*; + + #[derive(Debug)] + struct TestAuthority; + + fn allowlist() -> Vec { + vec!["genmeta.net".to_string()] + } + + impl RemoteAuthority for TestAuthority { + fn name(&self) -> &str { + "authority.example" + } + + fn cert_chain(&self) -> &[CertificateDer<'static>] { + &[] + } + } + + #[test] + fn validate_dns_packet_accepts_empty_packet_as_clear_operation() { + let hosts: HashMap> = + HashMap::from([("reimu.pilot.genmeta.net".to_owned(), Vec::new())]); + let packet = MdnsPacket::answer(0, &hosts).to_bytes(); + + let validated = validate_dns_packet( + &packet, + false, + &TestAuthority, + &SignatureFields::empty(), + &allowlist(), + "reimu.pilot.genmeta.net", + ) + .unwrap(); + + assert!(matches!(validated, ValidatedDnsPacket::Empty)); + } +} +>>>>>>> 01498cb (Add AWS deployment and Redis read-write separation) diff --git a/src/bin/ddns-server/publish.rs b/src/bin/ddns-server/publish.rs index c724a16..2d272be 100644 --- a/src/bin/ddns-server/publish.rs +++ b/src/bin/ddns-server/publish.rs @@ -51,7 +51,7 @@ async fn publish_with_cert(state: AppState, request: Request) -> Response { return write_error(AppError::MissingHostParam); }; - let host = match normalize_host(host) { + let host = match normalize_host(host, state.host_allowlist.as_ref()) { Ok(h) => h, Err(e) => return write_error(e), }; @@ -81,7 +81,7 @@ async fn publish_with_cert(state: AppState, request: Request) -> Response { // Standard policy: cert SAN must match the target host. // OpenMulti policy: any authenticated node may publish — skip SAN check. if policy == DomainPolicy::Standard { - let allowed = match client_allowed_host(authority.as_ref()) { + let allowed = match client_allowed_host(authority.as_ref(), state.host_allowlist.as_ref()) { Ok(h) => h, Err(e) => { warn!(error = %snafu::Report::from_error(&e), "client certificate domain not allowed"); @@ -119,6 +119,7 @@ async fn publish_with_cert(state: AppState, request: Request) -> Response { require_sig, authority.as_ref(), &signature_fields, + state.host_allowlist.as_ref(), &host, ) { Ok(n) => n, @@ -130,7 +131,7 @@ async fn publish_with_cert(state: AppState, request: Request) -> Response { match packet { ValidatedDnsPacket::Records { host: packet_name } => { - let packet_host = match normalize_host(&packet_name) { + let packet_host = match normalize_host(&packet_name, state.host_allowlist.as_ref()) { Ok(h) => h, Err(e) => return write_error(e), }; @@ -167,6 +168,26 @@ fn request_connection(request: &Request) -> Option( + conn: &mut C, + keys: impl IntoIterator, + cutoff: f64, + expire_ttl_secs: i64, +) where + C: redis::aio::ConnectionLike + Send + Sync, +{ + for key in keys { + let _: bool = conn.expire(&key, expire_ttl_secs).await.unwrap_or(false); + let _: () = redis::cmd("ZREMRANGEBYSCORE") + .arg(&key) + .arg("-inf") + .arg(cutoff) + .query_async::<()>(&mut *conn) + .await + .unwrap_or(()); + } +} + /// Unified publish handler: stores the record keyed by (host, cert-fingerprint). /// Both Standard and OpenMulti policies follow the same storage path; /// the only policy difference (SAN check) is already enforced in the caller. @@ -195,8 +216,8 @@ pub async fn publish_record( let fp_hex = cert_fingerprint_hex(&cert_bytes); match &state.storage { - Storage::Redis(pool) => { - let mut conn = match pool.get().await { + Storage::Redis(redis) => { + let mut conn = match redis.write.get().await { Ok(c) => c, Err(e) => { return write_error(AppError::Redis { @@ -286,16 +307,7 @@ pub async fn publish_record( } let cutoff = now_secs.saturating_sub(state.ttl_secs) as f64; - for key in touched_index_keys { - let _: bool = conn.expire(&key, expire_ttl_secs).await.unwrap_or(false); - let _: () = redis::cmd("ZREMRANGEBYSCORE") - .arg(&key) - .arg("-inf") - .arg(cutoff) - .query_async::<()>(&mut *conn) - .await - .unwrap_or(()); - } + trim_expired_index_keys(&mut *conn, touched_index_keys, cutoff, expire_ttl_secs).await; } Storage::Memory(mem) => { let now = Instant::now(); @@ -332,8 +344,8 @@ pub async fn clear_record( let fp_hex = cert_fingerprint_hex(&cert_bytes); match &state.storage { - Storage::Redis(pool) => { - let mut conn = match pool.get().await { + Storage::Redis(redis) => { + let mut conn = match redis.write.get().await { Ok(c) => c, Err(e) => { return write_error(AppError::Redis { @@ -344,6 +356,7 @@ pub async fn clear_record( let fp_key = redis_primary_key(host, &fp_hex); let all_index_key = redis_all_index_key(host); + let mut touched_index_keys = HashSet::from([all_index_key.clone()]); let old_member: Option> = conn.get(&fp_key).await.unwrap_or(None); if let Some(old_record) = old_member.as_deref().and_then(StoredRecord::decode) { @@ -351,10 +364,12 @@ pub async fn clear_record( let _: () = conn.zrem(&all_index_key, &fp_hex).await.unwrap_or(()); for country in &old_tags.countries { let key = redis_country_index_key(host, country); + touched_index_keys.insert(key.clone()); let _: () = conn.zrem(&key, &fp_hex).await.unwrap_or(()); } for asn in &old_tags.asns { let key = redis_asn_index_key(host, *asn); + touched_index_keys.insert(key.clone()); let _: () = conn.zrem(&key, &fp_hex).await.unwrap_or(()); } } @@ -364,6 +379,10 @@ pub async fn clear_record( message: e.to_string(), }); } + + let cutoff = unix_now_secs().saturating_sub(state.ttl_secs) as f64; + let expire_ttl_secs = i64::try_from(state.ttl_secs).unwrap_or(i64::MAX); + trim_expired_index_keys(&mut *conn, touched_index_keys, cutoff, expire_ttl_secs).await; } Storage::Memory(mem) => { let remove_host = if let Some(mut host_map) = mem.records.get_mut(host) { @@ -429,6 +448,7 @@ mod tests { fn memory_state() -> AppState { AppState { storage: Storage::Memory(MemoryStorage::new()), + host_allowlist: Arc::new(vec!["genmeta.net".to_string()]), require_signature: true, ttl_secs: 30, policies: Arc::new(DomainPolicies::default()), diff --git a/src/bin/ddns-server/storage.rs b/src/bin/ddns-server/storage.rs index f5bed03..f8dcd0d 100644 --- a/src/bin/ddns-server/storage.rs +++ b/src/bin/ddns-server/storage.rs @@ -397,9 +397,15 @@ impl MemoryStorage { } } +#[derive(Clone)] +pub struct RedisStorage { + pub write: Pool, + pub read: Pool, +} + #[derive(Clone)] pub enum Storage { - Redis(Pool), + Redis(RedisStorage), Memory(MemoryStorage), } @@ -413,6 +419,7 @@ pub type SeedRecords = Arc>>; #[derive(Clone)] pub struct AppState { pub storage: Storage, + pub host_allowlist: Arc>, pub require_signature: bool, pub ttl_secs: u64, pub policies: Arc, From 18d0fdc296c82e1c516d7771d03d341f743c6b7b Mon Sep 17 00:00:00 2001 From: metah3m Date: Mon, 15 Jun 2026 16:14:41 +0800 Subject: [PATCH 06/29] feat add endpoint publisher helpers --- Cargo.toml | 33 +++++++-- README.md | 123 ---------------------------------- examples/README.md | 52 -------------- src/bin/ddns-server/error.rs | 24 +------ src/bin/ddns-server/policy.rs | 5 +- src/publisher.rs | 119 ++++++++++++++++++++++++++++++++ 6 files changed, 150 insertions(+), 206 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 67d4577..d719a7c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,8 +19,8 @@ bitfield-struct = "0.13" bytes = "1" dashmap = "6" der = { version = "0.8.0", optional = true } -dhttp-identity = { git = "https://github.com/genmeta/dhttp.git", branch = "main" } -dquic = { git = "https://github.com/genmeta/dquic.git", branch = "feat/v0.5.1" } +dhttp-identity = { path = "../dhttp/identity", version = "0.1.0" } +dquic = { path = "../dquic/dquic", version = "0.5.1" } flume = "0.12" futures = "0.3" libc = "0.2" @@ -48,7 +48,11 @@ tokio = { version = "1", features = [ tracing = "0.1" x509-parser = { version = "0.18", features = ["verify"] } -h3x = { version = "0.3.1", default-features = false, optional = true } +<<<<<<< HEAD +h3x = { path = "../h3x", default-features = false, optional = true } +======= +h3x = { path = "../h3x", default-features = false, optional = true } +>>>>>>> c51f31b (feat add endpoint publisher helpers) http = { version = "1", optional = true } http-body = { version = "1", optional = true } http-body-util = { version = "0.1", optional = true } @@ -105,7 +109,11 @@ server = [ [dev-dependencies] clap = { version = "4", features = ["derive"] } -h3x = { version = "0.3.1", default-features = false, features = [ +<<<<<<< HEAD +h3x = { path = "../h3x", default-features = false, features = [ +======= +h3x = { path = "../h3x", default-features = false, features = [ +>>>>>>> c51f31b (feat add endpoint publisher helpers) "dquic", ] } shellexpand = "3" @@ -136,3 +144,20 @@ required-features = ["h3x-resolver"] [patch.crates-io] proc-macro-error2 = { path = "patches/proc-macro-error2" } + +[patch."https://github.com/genmeta/dquic.git"] +dquic = { path = "../dquic/dquic", version = "0.5.1" } +qbase = { path = "../dquic/qbase", version = "0.5.1" } +qcongestion = { path = "../dquic/qcongestion", version = "0.5.1" } +qconnection = { path = "../dquic/qconnection", version = "0.5.1" } +qdatagram = { path = "../dquic/qdatagram", version = "0.5.1" } +qevent = { path = "../dquic/qevent", version = "0.5.1" } +qinterface = { path = "../dquic/qinterface", version = "0.5.1" } +qmacro = { path = "../dquic/qmacro", version = "0.5.1" } +qrecovery = { path = "../dquic/qrecovery", version = "0.5.1" } +qresolve = { path = "../dquic/qresolve", version = "0.5.1" } +qtraversal = { path = "../dquic/qtraversal", version = "0.5.1" } +qudp = { path = "../dquic/qudp", version = "0.5.1" } + +[patch."https://github.com/genmeta/h3x.git"] +h3x = { path = "../h3x", version = "0.2.0" } diff --git a/README.md b/README.md index 376ccec..5bbd120 100644 --- a/README.md +++ b/README.md @@ -169,131 +169,8 @@ The server exposes two HTTP/3 routes: | `POST /publish?host=` | Publish a DNS packet for `host`. Client mTLS is required. | | `GET /lookup?host=[&limit=N]` | Look up active records for `host`; `limit` caps newest-first dynamic records. | -<<<<<<< HEAD Lookup responses use header `x-record-format: multi` and the binary body from `ddns::core::wire::MultiResponse`: -======= -### Signed HTTP Publish - -RFC 9530 Section 2 defines `Content-Digest`; RFC 9421 Sections 4.1, 4.2, and 7.2.8 define `Signature-Input`, `Signature`, and the digest-based message-content coverage pattern. - -Remote HTTP/3 publishing signs the complete DNS packet body, not individual -endpoint `E` records. The DNS packet remains the HTTP request body. Standard -domains with `require_signature = true` require these HTTP fields: - -```http -Content-Digest: sha-256=:BASE64(SHA256(dns-packet-bytes)): -Signature-Input: dns=("content-digest");created=;keyid="sha256:";alg="" -Signature: dns=:BASE64(signature): -``` - -The supported signature algorithms are explicit and are never guessed by the -verifier: `ed25519`, `ecdsa-p256-sha256`, `ecdsa-p384-sha384`, -`rsa-pss-sha256`, `rsa-pss-sha384`, `rsa-pss-sha512`, -`rsa-v1_5-sha256`, `rsa-v1_5-sha384`, and `rsa-v1_5-sha512`. - -The server verifies: - -- `Content-Digest` is exactly `sha-256` over the DNS packet body. -- `Signature-Input` covers only `content-digest` under the label `dns`. -- `Signature` verifies with the publisher leaf certificate and the declared - `alg`. -- `keyid` matches the SHA-256 fingerprint of the publisher leaf certificate. -- Every endpoint `E` answer owner name matches the `host` query parameter. - -Lookup responses return the saved signature fields with the saved DNS packet -and certificate, so clients can independently verify that the DNS packet bytes -were not modified before using the endpoint records. - -### 1. Packet Layout - -DNS packets consist of a fixed header and four variable-length sections: - -```text -+---------------------+-----------------------+-----------------------+-----------------------+-----------------------+ -| Header (12 bytes) | Question Section | Answer Section | Nameserver Section | Additional Section | -+---------------------+-----------------------+-----------------------+-----------------------+-----------------------+ -| Transaction ID | Query list | Answer RR list | Authority RR list | Additional RR list | -| and Flags | | | | | -+---------------------+-----------------------+-----------------------+-----------------------+-----------------------+ -``` - -#### 1.1 Header -Fixed length of 12 bytes. Contains ID, Flags, and counters for subsequent sections (QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT). - -#### 1.2 Resource Record -Answer, Nameserver, and Additional sections all use this format: - -- **NAME**: Variable-length domain name, supports RFC 1035 compression. -- **TYPE (u16)**: Record type (e.g., A=1, SRV=33, E=266). -- **CLASS (u16)**: Protocol class. In mDNS, the highest bit (bit 15) is used for cache-flush flag. -- **TTL (u32)**: Cache time-to-live (seconds). -- **RDLEN (u16)**: Length of resource data (RDATA). -- **RDATA**: Specific resource content, format determined by TYPE. - -### 2. Custom Type Definitions (QType) - -| Type | Value | Description | RDATA Format | -| :------- | :---- | :--------------- | :-------------------------------- | -| **A** | 1 | IPv4 address | 4-byte IP | -| **AAAA** | 28 | IPv6 address | 16-byte IP | -| **SRV** | 33 | Service location | Priority + Weight + Port + Target | -| **E** | 266 | Endpoint address | Flags + Seq + Addr(s) + Load | - -### 3. Endpoint Extensions (Type E) - -#### 3.1 RDATA Wire Format - -##### Packet Format - -```text -+--------+-----------------+--------------------+----------------+ -| flags | sequence(varint)| addr(s) | load(optional) | -+--------+-----------------+--------------------+----------------+ -| u8 | QUIC varint | v4: 2+4 / v6: 2+16 | f32 | -+--------+-----------------+--------------------+----------------+ -``` - -##### flags (u8) Field Definition: -- bit 7 (0x80): **FAMILY** - Address family (0=IPv4, 1=IPv6) -- bit 6 (0x40): **MAIN** - Primary address flag -- bit 5 (0x20): **SEQUENCED** - Sequence number present -- bit 4 (0x10): **FORWARD** - Connection type (0=direct, 1=relay) -- bit 3 (0x08): **LOAD** - 1-minute load average present -- bits 2-0: Reserved, including the legacy per-record signature bit - -##### Address Format: -- **Direct**: `port(u16)` + `IP(u32/u128)` -- **Relay**: `outer_port(u16)` + `outer_IP(u32/u128)` + `agent_port(u16)` + `agent_IP(u32/u128)` -- **sequence**: DNS record sequence number. Records with the same sequence are considered from the same machine and can use multipath connections. -- **load**: Optional 1-minute load average. DNS packet authenticity is provided - by the HTTP publish signature fields, not by E-record RDATA. - -#### 3.2 Flag Bit Masks - -- `0b1000_0000`: **FAMILY** (Address family: 0=IPv4, 1=IPv6) -- `0b0100_0000`: **MAIN** (Primary address flag) -- `0b0010_0000`: **SEQUENCED** (Sequence number present) -- `0b0001_0000`: **FORWARD** (Connection type: 0=direct, 1=relay) -- `0b0000_1000`: **LOAD** (Load average present) - -#### 3.3 Address Format Details - -- **Direct**: `Port(u16)` + `IP(u32/u128)` -- **Relay**: `OuterPort(u16)` + `OuterIP(u32/u128)` + `AgentPort(u16)` + `AgentIP(u32/u128)` - -#### 3.4 Legacy Per-Record Signature - -Older Endpoint records may contain a per-record signature field. New Standard -publishes do not use or accept this as the required signature. The authoritative -signature for remote publish/lookup is the HTTP packet-level signature described -above. - -### 4. HTTP Lookup Response - -The server returns a multi-record binary body. Each response record carries the -publisher signature fields, DNS packet bytes, and publisher certificate: ->>>>>>> 01498cb (Add AWS deployment and Redis read-write separation) ```text u32 count diff --git a/examples/README.md b/examples/README.md index b623e48..3c95833 100644 --- a/examples/README.md +++ b/examples/README.md @@ -24,58 +24,6 @@ cargo run --example mdns_discover -- \ Query a name over mDNS: -<<<<<<< HEAD -======= -## HTTP Packet Structure Overview - -`ddns` uses the HTTP/3 protocol to transmit DNS queries and responses, similar to DNS over HTTPS (DoH) but based on the QUIC protocol. The structure of HTTP requests is as follows: - -### URL Structure -- **Base URL**: Default `https://localhost:4433/`, used to specify the DNS server's address. -- **Path**: For queries, usually the root path `/`, the server parses the DNS query based on the request body. -- **Query Parameters**: Optional, used to specify query type or options. - -### HTTP Headers -- **Content-Type**: `application/dns-message` for DNS packet bodies. -- **Content-Digest**: Required for signed publish requests. Uses `sha-256` - (RFC 9530). -- **Signature-Input**: Required for signed publish requests. Uses the `dns` - label and covers `content-digest` (RFC 9421). -- **Signature**: Required for signed publish requests. Contains the publisher - signature for the DNS packet digest (RFC 9421). - -### Request Body (Body) -- DNS queries are sent in binary DNS message format (RFC 1035), containing query name, type (such as A, AAAA, SRV), and class. -- For publishing, the request body contains the DNS packet bytes to be - published. Standard signed publishes carry the signature in HTTP fields, not - in the endpoint `E` records. - -### Response Body -- Publish returns a status body such as `OK`. -- Lookup returns the server multi-record binary format. Each record contains the - saved `Content-Digest`, `Signature-Input`, `Signature`, DNS packet bytes, and - publisher certificate so clients can verify the returned DNS packet. - -## Usage Examples - -### Publishing Services (publish) - -Use the `publish` example to publish a DNS service record to the HTTP/3 DNS server. - -#### Program Parameters -- `--base-url `: Base URL of the DNS server (default: build-time `DHTTP_H3_DNS_SERVER` with a trailing slash). -- `--server-ca `: CA certificate PEM file path for verifying the online server certificate. -- `--client-name `: Client identity name used for mTLS. -- `--client-cert `: Client certificate chain PEM file. -- `--client-key `: Client private key PEM file. -- `--sign`: Whether to sign the DNS packet with RFC 9421/9530-style HTTP fields - (default: true). -- `--host `: DNS name to publish, must match the SAN in the client certificate. -- `--addr `: List of socket addresses to publish, separated by commas. -- `--is-main`: Whether it is the main record (default: true). - -#### Example Run Command ->>>>>>> 01498cb (Add AWS deployment and Redis read-write separation) ```bash cargo run --example mdns_query -- \ --ip 192.168.5.156 \ diff --git a/src/bin/ddns-server/error.rs b/src/bin/ddns-server/error.rs index 76ba9e8..1e8087e 100644 --- a/src/bin/ddns-server/error.rs +++ b/src/bin/ddns-server/error.rs @@ -73,10 +73,7 @@ pub fn normalize_host_allowlist(entries: &[String]) -> Result, AppEr pub fn normalize_host(host: &str, allowlist: &[String]) -> Result { let host = normalize_host_raw(host)?; - if allowlist - .iter() - .any(|suffix| host_matches_suffix(&host, suffix)) - { + if allowlist.iter().any(|suffix| host_matches_suffix(&host, suffix)) { Ok(host) } else { Err(AppError::DomainNotAllowed) @@ -97,7 +94,6 @@ pub fn normalize_host_raw(host: &str) -> Result { _ => host, }; let host = host.strip_suffix('.').unwrap_or(host); - let host = idna::domain_to_ascii(host).map_err(|_| AppError::InvalidHost)?; Ok(host.to_ascii_lowercase()) } @@ -116,21 +112,3 @@ fn host_matches_suffix(host: &str, suffix: &str) -> bool { .is_some_and(|prefix| prefix.ends_with('.')) } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn normalize_host_accepts_allowed_suffixes() { - let allowlist = vec!["genmeta.net".to_string()]; - let host = normalize_host("DNS.Genmeta.Net.", &allowlist).unwrap(); - assert_eq!(host, "dns.genmeta.net"); - } - - #[test] - fn normalize_host_rejects_non_boundary_suffixes() { - let allowlist = vec!["genmeta.net".to_string()]; - let err = normalize_host("evilgenmeta.net", &allowlist).unwrap_err(); - assert!(matches!(err, AppError::DomainNotAllowed)); - } -} diff --git a/src/bin/ddns-server/policy.rs b/src/bin/ddns-server/policy.rs index 7e2776f..4340209 100644 --- a/src/bin/ddns-server/policy.rs +++ b/src/bin/ddns-server/policy.rs @@ -5,7 +5,7 @@ use ddns::core::{ use dhttp_identity::identity::RemoteAuthority; use tracing::{debug, warn}; -use crate::error::{AppError, app_error, normalize_host}; +use crate::error::{AppError, normalize_host}; #[derive(Clone, Debug, PartialEq)] pub enum DomainPolicy { @@ -152,8 +152,6 @@ pub fn validate_dns_packet( host: first_answer.name().to_string(), }) } -<<<<<<< HEAD -======= #[cfg(test)] mod tests { @@ -201,4 +199,3 @@ mod tests { assert!(matches!(validated, ValidatedDnsPacket::Empty)); } } ->>>>>>> 01498cb (Add AWS deployment and Redis read-write separation) diff --git a/src/publisher.rs b/src/publisher.rs index f542610..11aec10 100644 --- a/src/publisher.rs +++ b/src/publisher.rs @@ -18,6 +18,7 @@ use dquic::{ qresolve::{Publish, Resolve}, qtraversal::nat::client::{ClientLocationData, NatType}, }; +use dhttp_identity::name::Name; use snafu::{ResultExt, Snafu}; use crate::{ @@ -73,6 +74,46 @@ pub struct PublishOptions { pub server_id: Option, } +#[derive(Debug, Clone, Default)] +pub struct PublishAddresses { + wide_area: Vec, +} + +impl PublishAddresses { + pub fn new() -> Self { + Self::default() + } + + pub fn wide_area( + mut self, + endpoints: impl IntoIterator, + ) -> Self { + self.wide_area.extend(endpoints); + self + } +} + +#[derive(Debug, Clone)] +pub struct EndpointPublisher { + inner: Arc, +} + +impl EndpointPublisher { + pub fn new(inner: Arc) -> Self { + Self { inner } + } + + pub async fn publish_once( + &self, + name: &Name<'_>, + addresses: &PublishAddresses, + ) -> Result<(), PublishOnceError> { + self.inner + .publish_addresses_to_resolver(self.inner.resolver.as_ref(), name, &addresses.wide_area) + .await + } +} + pub struct Publisher { identity: Arc, network: Arc, @@ -154,6 +195,84 @@ impl Publisher { Ok(()) } + async fn publish_addresses_to_resolver( + &self, + resolver: &(dyn Resolve + Send + Sync), + name: &Name<'_>, + endpoints: &[EndpointAddr], + ) -> Result<(), PublishOnceError> { + let any: &dyn Any = resolver; + + if let Some(resolvers) = any.downcast_ref::() { + for resolver in resolvers.iter() { + self.publish_addresses_to_single_resolver(resolver.as_ref(), name, endpoints) + .await?; + } + return Ok(()); + } + + self.publish_addresses_to_single_resolver(resolver, name, endpoints) + .await + } + + async fn publish_addresses_to_single_resolver( + &self, + resolver: &(dyn Resolve + Send + Sync), + name: &Name<'_>, + endpoints: &[EndpointAddr], + ) -> Result<(), PublishOnceError> { + let packet = self.dns_packet_for_name(&name.to_string(), endpoints)?; + let any: &dyn Any = resolver; + + #[cfg(feature = "http-resolver")] + if let Some(http) = any.downcast_ref::() { + let signature_fields = SignatureFields::sign(&packet, self.identity.as_ref()) + .await + .context(publish_once_error::SignPacketSnafu)?; + http.publish_signed(&name.to_string(), &packet, &signature_fields) + .await + .context(publish_once_error::PublishSnafu { + publisher: http.to_string(), + })?; + return Ok(()); + } + + #[cfg(feature = "h3x-resolver")] + if let Some(h3) = + any.downcast_ref::>() + { + let signature_fields = SignatureFields::sign(&packet, self.identity.as_ref()) + .await + .context(publish_once_error::SignPacketSnafu)?; + h3.publish_signed(&name.to_string(), &packet, &signature_fields) + .await + .context(publish_once_error::PublishSnafu { + publisher: h3.to_string(), + })?; + return Ok(()); + } + + Ok(()) + } + + fn dns_packet_for_name( + &self, + name: &str, + endpoints: &[EndpointAddr], + ) -> Result, PublishOnceError> { + let mut encoded = Vec::with_capacity(endpoints.len()); + for endpoint in endpoints { + encoded.push( + DnsEndpointAddr::try_from(*endpoint) + .map_err(|_| publish_once_error::EncodeEndpointSnafu.build())?, + ); + } + + let mut hosts = HashMap::new(); + hosts.insert(name.to_owned(), encoded); + Ok(MdnsPacket::answer(0, &hosts).to_bytes()) + } + pub async fn run(&self) -> ! { let mut locations = self.network.quic().locations().subscribe(); let interval = tokio::time::sleep(self.interval); From 307ccfb493a3c4e4440f8a5e23c861034a311363 Mon Sep 17 00:00:00 2001 From: metah3m Date: Mon, 15 Jun 2026 18:30:00 +0800 Subject: [PATCH 07/29] refactor: remove dead code --- Cargo.toml | 17 - server.toml | 2 +- src/bin/ddns-server/error.rs | 11 +- src/bin/ddns-server/lookup.rs | 1040 ----------------- src/bin/ddns-server/lookup/http.rs | 109 ++ src/bin/ddns-server/lookup/mod.rs | 8 + src/bin/ddns-server/lookup/query.rs | 263 +++++ src/bin/ddns-server/lookup/ranking.rs | 254 ++++ src/bin/ddns-server/lookup/tests.rs | 427 +++++++ src/bin/ddns-server/main.rs | 14 +- src/bin/ddns-server/ocsp.rs | 508 +------- src/bin/ddns-server/policy.rs | 3 +- src/bin/ddns-server/publish/http.rs | 163 +++ src/bin/ddns-server/publish/mod.rs | 7 + .../{publish.rs => publish/store.rs} | 292 +---- src/bin/ddns-server/publish/tests.rs | 126 ++ src/core/parser/sigin.rs | 1 - src/publisher.rs | 8 +- src/resolvers/h3.rs | 30 +- src/resolvers/http.rs | 7 +- src/resolvers/selector.rs | 27 +- 21 files changed, 1411 insertions(+), 1906 deletions(-) delete mode 100644 src/bin/ddns-server/lookup.rs create mode 100644 src/bin/ddns-server/lookup/http.rs create mode 100644 src/bin/ddns-server/lookup/mod.rs create mode 100644 src/bin/ddns-server/lookup/query.rs create mode 100644 src/bin/ddns-server/lookup/ranking.rs create mode 100644 src/bin/ddns-server/lookup/tests.rs create mode 100644 src/bin/ddns-server/publish/http.rs create mode 100644 src/bin/ddns-server/publish/mod.rs rename src/bin/ddns-server/{publish.rs => publish/store.rs} (50%) create mode 100644 src/bin/ddns-server/publish/tests.rs diff --git a/Cargo.toml b/Cargo.toml index d719a7c..2bc2943 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,6 @@ base64 = "0.22" bitfield-struct = "0.13" bytes = "1" dashmap = "6" -der = { version = "0.8.0", optional = true } dhttp-identity = { path = "../dhttp/identity", version = "0.1.0" } dquic = { path = "../dquic/dquic", version = "0.5.1" } flume = "0.12" @@ -33,7 +32,6 @@ rustls = { version = "0.23", default-features = false, features = [ ] } rustls-native-certs = { version = "0.8", optional = true } rustls-pemfile = "2" -sha1 = { version = "=0.11.0-rc.5", optional = true } snafu = "0.9" socket2 = { version = "0.6", features = ["all"] } tokio = { version = "1", features = [ @@ -48,11 +46,7 @@ tokio = { version = "1", features = [ tracing = "0.1" x509-parser = { version = "0.18", features = ["verify"] } -<<<<<<< HEAD h3x = { path = "../h3x", default-features = false, optional = true } -======= -h3x = { path = "../h3x", default-features = false, optional = true } ->>>>>>> c51f31b (feat add endpoint publisher helpers) http = { version = "1", optional = true } http-body = { version = "1", optional = true } http-body-util = { version = "0.1", optional = true } @@ -76,7 +70,6 @@ tower-service = { version = "0.3", optional = true } tracing-subscriber = { version = "0.3", features = [ "env-filter", ], optional = true } -x509-cert = { version = "=0.3.0-rc.4", optional = true } [features] default = [] @@ -94,26 +87,19 @@ http-resolver = ["dep:reqwest", "dep:rustls-native-certs"] server = [ "h3x-resolver", "dep:clap", - "dep:der", "dep:deadpool-redis", "dep:idna", "dep:maxminddb", "dep:reqwest", "dep:serde", - "dep:sha1", "dep:toml", "dep:tower-service", "dep:tracing-subscriber", - "dep:x509-cert", ] [dev-dependencies] clap = { version = "4", features = ["derive"] } -<<<<<<< HEAD -h3x = { path = "../h3x", default-features = false, features = [ -======= h3x = { path = "../h3x", default-features = false, features = [ ->>>>>>> c51f31b (feat add endpoint publisher helpers) "dquic", ] } shellexpand = "3" @@ -142,9 +128,6 @@ name = "query" path = "examples/query.rs" required-features = ["h3x-resolver"] -[patch.crates-io] -proc-macro-error2 = { path = "patches/proc-macro-error2" } - [patch."https://github.com/genmeta/dquic.git"] dquic = { path = "../dquic/dquic", version = "0.5.1" } qbase = { path = "../dquic/qbase", version = "0.5.1" } diff --git a/server.toml b/server.toml index a26e622..5dd4cd5 100644 --- a/server.toml +++ b/server.toml @@ -15,7 +15,7 @@ key = "~/Downloads/ssl/dns.genmeta.net/dns.genmeta.net.key" root_cert = "~/Downloads/ssl/root.crt" # Allowed host suffixes. Values are normalized and suffix-matched. -# host_allowlist = ["genmeta.net"] +host_allowlist = ["genmeta.net", "dhttp.net"] # Optional issuer certificate used to build OCSP requests when `cert` only # contains the leaf certificate. If `cert` already contains the full chain, diff --git a/src/bin/ddns-server/error.rs b/src/bin/ddns-server/error.rs index 1e8087e..6a46879 100644 --- a/src/bin/ddns-server/error.rs +++ b/src/bin/ddns-server/error.rs @@ -23,10 +23,6 @@ pub enum AppError { PublisherCertificateSelector { source: dhttp_identity::identity::ExtractDhttpSubjectKeyIdentifierError, }, - #[snafu(display("endpoint record selector is invalid"))] - EndpointRecordSelector { - source: ddns::core::parser::record::endpoint::EndpointSelectorError, - }, #[snafu(display("endpoint record selector does not match publisher certificate selector"))] EndpointSelectorMismatch, #[snafu(display("no answers in packet"))] @@ -51,7 +47,6 @@ impl AppError { AppError::ClientCertDomainNotAllowed => http::StatusCode::FORBIDDEN, AppError::InvalidDnsPacket { .. } => http::StatusCode::BAD_REQUEST, AppError::PublisherCertificateSelector { .. } => http::StatusCode::BAD_REQUEST, - AppError::EndpointRecordSelector { .. } => http::StatusCode::BAD_REQUEST, AppError::EndpointSelectorMismatch => http::StatusCode::BAD_REQUEST, AppError::NoAnswersInPacket => http::StatusCode::UNPROCESSABLE_ENTITY, AppError::SignatureRequired => http::StatusCode::BAD_REQUEST, @@ -73,7 +68,10 @@ pub fn normalize_host_allowlist(entries: &[String]) -> Result, AppEr pub fn normalize_host(host: &str, allowlist: &[String]) -> Result { let host = normalize_host_raw(host)?; - if allowlist.iter().any(|suffix| host_matches_suffix(&host, suffix)) { + if allowlist + .iter() + .any(|suffix| host_matches_suffix(&host, suffix)) + { Ok(host) } else { Err(AppError::DomainNotAllowed) @@ -111,4 +109,3 @@ fn host_matches_suffix(host: &str, suffix: &str) -> bool { .strip_suffix(suffix) .is_some_and(|prefix| prefix.ends_with('.')) } - diff --git a/src/bin/ddns-server/lookup.rs b/src/bin/ddns-server/lookup.rs deleted file mode 100644 index e157c52..0000000 --- a/src/bin/ddns-server/lookup.rs +++ /dev/null @@ -1,1040 +0,0 @@ -use std::{ - any::Any, - cmp::Ordering, - collections::{HashMap, HashSet}, - convert::Infallible, - hash::Hash, - net::{IpAddr, SocketAddr}, - sync::Arc, -}; - -use ddns::core::{ - MdnsPacket, - parser::{packet::be_packet, record::RData}, - wire::{MultiResponse, ResponseRecord}, -}; -use deadpool_redis::redis::{self, AsyncCommands}; -use h3x::{connection::ConnectionState, dhttp::message::MessageStreamError, quic}; -use http_body_util::{Full, combinators::UnsyncBoxBody}; -use tracing::debug; - -use crate::{ - error::{AppError, normalize_host, parse_query_params}, - geo::{GeoResolver, GeoTraits}, - storage::{ - AppState, LookupRecord, MemoryStorage, SeedRecords, Storage, StoredRecord, - redis_all_index_key, redis_asn_index_key, redis_blacklist_key, redis_country_index_key, - redis_primary_key, unix_now_secs, - }, -}; - -pub type Request = http::Request>; -pub type Response = http::Response>; - -// --------------------------------------------------------------------------- -// Lookup result type -// --------------------------------------------------------------------------- - -pub enum LookupResult { - NotFound, - /// Multiple records, newest-first. - Multi(MultiResponse), -} - -type EndpointKey = (SocketAddr, Option); - -const LOOKUP_CANDIDATE_CAP_TOTAL: usize = 64; -const LOOKUP_CANDIDATE_CAP_ASN: usize = 16; -const LOOKUP_CANDIDATE_CAP_COUNTRY: usize = 16; -const LOOKUP_CANDIDATE_CAP_ALL: usize = 32; - -// GEO-aware ranking dimensions. Final ordering still falls back to the original -// record index so we keep lookups stable when all computed dimensions tie. -#[derive(Clone, Copy, Debug, PartialEq)] -struct GeoSortKey { - same_country: bool, - same_asn: bool, - family_match: bool, - same_city: bool, - load: Option, - geo_distance: Option, -} - -fn normalize_lookup_records(records: Vec) -> Vec { - let mut normalized = Vec::new(); - let mut seen = HashSet::new(); - - for record in records { - if !record.signature_fields.is_empty() { - normalized.push(record); - continue; - } - - let Ok((_, packet)) = be_packet(&record.dns) else { - normalized.push(record); - continue; - }; - - let mut emitted_endpoint = false; - - for answer in &packet.answers { - let RData::E(endpoint) = answer.data() else { - continue; - }; - - emitted_endpoint = true; - let key: EndpointKey = (endpoint.addr(), endpoint.agent_addr()); - - if !seen.insert(key) { - continue; - } - - let mut hosts = HashMap::new(); - hosts.insert(answer.name().to_string(), vec![endpoint.clone()]); - normalized.push(ResponseRecord::unsigned( - MdnsPacket::answer(0, &hosts).to_bytes(), - record.cert.clone(), - )); - } - - if !emitted_endpoint { - normalized.push(record); - } - } - - normalized -} - -fn lookup_endpoint(dns_bytes: &[u8]) -> Option<(SocketAddr, Option)> { - let (_, packet) = be_packet(dns_bytes).ok()?; - packet - .answers - .iter() - .find_map(|answer| match answer.data() { - RData::E(endpoint) => Some((endpoint.addr(), endpoint.load())), - _ => None, - }) -} - -// Fallback ordering when GEO routing is disabled: prefer matching address family, -// then lower load, and finally preserve input order. We intentionally avoid -// IP prefix heuristics here because they are not reliable on the public Internet. -fn sort_lookup_records(records: Vec, source_ip: Option) -> Vec { - let mut decorated = records - .into_iter() - .enumerate() - .map(|(index, record)| { - let sort_key = lookup_endpoint(&record.dns).map(|(endpoint, load)| { - let family_match = source_ip - .map(|source| source.is_ipv4() == endpoint.ip().is_ipv4()) - .unwrap_or(false); - - (family_match, load) - }); - (sort_key, index, record) - }) - .collect::>(); - - decorated.sort_by(|(left_key, left_index, _), (right_key, right_index, _)| { - match (left_key, right_key) { - (Some((left_family, left_load)), Some((right_family, right_load))) => right_family - .cmp(left_family) - .then_with(|| match (left_load, right_load) { - (Some(left), Some(right)) => left.partial_cmp(right).unwrap_or(Ordering::Equal), - (Some(_), None) => Ordering::Less, - (None, Some(_)) => Ordering::Greater, - (None, None) => Ordering::Equal, - }), - (Some(_), None) => Ordering::Less, - (None, Some(_)) => Ordering::Greater, - (None, None) => Ordering::Equal, - } - .then_with(|| left_index.cmp(right_index)) - }); - - decorated.into_iter().map(|(_, _, record)| record).collect() -} - -fn request_source_geo_traits( - source_ip: Option, - geo: Option<&GeoResolver>, -) -> Option { - Some(geo?.lookup_traits(source_ip?)) -} - -fn lookup_endpoint_geo_traits( - dns_bytes: &[u8], - geo: &GeoResolver, -) -> Option<(SocketAddr, Option, GeoTraits)> { - let (endpoint, load) = lookup_endpoint(dns_bytes)?; - Some((endpoint, load, geo.lookup_traits(endpoint.ip()))) -} - -fn compare_optional_partial(left: Option, right: Option) -> Ordering { - match (left, right) { - (Some(left), Some(right)) => left.partial_cmp(&right).unwrap_or(Ordering::Equal), - _ => Ordering::Equal, - } -} - -// GEO ordering is layered rather than score-based: -// country > ASN > address family > city name > lower load > shorter GEO distance. -// Missing optional values do not penalize a candidate; they simply skip that layer. -fn compare_geo_sort_keys(left: GeoSortKey, right: GeoSortKey) -> Ordering { - right - .same_country - .cmp(&left.same_country) - .then_with(|| right.same_asn.cmp(&left.same_asn)) - .then_with(|| right.family_match.cmp(&left.family_match)) - .then_with(|| right.same_city.cmp(&left.same_city)) - .then_with(|| compare_optional_partial(left.load, right.load)) - .then_with(|| compare_optional_partial(left.geo_distance, right.geo_distance)) -} - -// Build the per-endpoint GEO ranking tuple. City name only participates when both -// sides have a name and already match on country; coordinate distance only -// participates when GeoResolver accepts both accuracy radii. -fn build_geo_sort_key( - source_ip: Option, - source_traits: Option<&GeoTraits>, - endpoint: SocketAddr, - load: Option, - endpoint_traits: &GeoTraits, - geo: &GeoResolver, -) -> GeoSortKey { - let family_match = source_ip - .map(|source| source.is_ipv4() == endpoint.ip().is_ipv4()) - .unwrap_or(false); - - let same_country = source_traits - .and_then(|source| source.country.as_deref()) - .zip(endpoint_traits.country.as_deref()) - .is_some_and(|(source, target)| source == target); - - let same_asn = source_traits - .and_then(|source| source.asn) - .zip(endpoint_traits.asn) - .is_some_and(|(source, target)| source == target); - - let same_city = same_country - && source_traits - .and_then(|source| source.city.as_deref()) - .zip(endpoint_traits.city.as_deref()) - .is_some_and(|(source, target)| source == target); - - let geo_distance = source_traits - .and_then(|source| source.point.as_ref()) - .zip(endpoint_traits.point.as_ref()) - .and_then(|(source, target)| geo.geo_distance_km(source, target)); - - GeoSortKey { - same_country, - same_asn, - family_match, - same_city, - load, - geo_distance, - } -} - -fn candidate_total_cap(limit: Option) -> usize { - limit - .unwrap_or(LOOKUP_CANDIDATE_CAP_TOTAL) - .max(LOOKUP_CANDIDATE_CAP_TOTAL) -} - -fn all_candidate_cap(total_cap: usize, source_traits: Option<&GeoTraits>) -> usize { - let has_geo_buckets = source_traits - .is_some_and(|traits| traits.asn.is_some() || traits.country.as_deref().is_some()); - - if has_geo_buckets { - LOOKUP_CANDIDATE_CAP_ALL.min(total_cap) - } else { - total_cap - } -} - -fn push_unique_candidates( - candidates: &mut Vec, - seen: &mut HashSet, - source: impl IntoIterator, - total_cap: usize, -) where - T: Clone + Eq + Hash, -{ - for item in source { - if candidates.len() >= total_cap { - break; - } - - if seen.insert(item.clone()) { - candidates.push(item); - } - } -} - -fn sort_lookup_records_with_geo( - records: Vec, - source_ip: Option, - geo: &GeoResolver, -) -> Vec { - let source_traits = request_source_geo_traits(source_ip, Some(geo)); - - let mut decorated = records - .into_iter() - .enumerate() - .map(|(index, record)| { - let sort_key = lookup_endpoint_geo_traits(&record.dns, geo).map( - |(endpoint, load, endpoint_traits)| { - build_geo_sort_key( - source_ip, - source_traits.as_ref(), - endpoint, - load, - &endpoint_traits, - geo, - ) - }, - ); - (sort_key, index, record) - }) - .collect::>(); - - decorated.sort_by(|(left_key, left_index, _), (right_key, right_index, _)| { - match (left_key, right_key) { - (Some(left_key), Some(right_key)) => compare_geo_sort_keys(*left_key, *right_key), - (Some(_), None) => Ordering::Less, - (None, Some(_)) => Ordering::Greater, - (None, None) => Ordering::Equal, - } - .then_with(|| left_index.cmp(right_index)) - }); - - decorated.into_iter().map(|(_, _, record)| record).collect() -} - -fn request_source_ip(request: &Request) -> Option { - let connection = request - .extensions() - .get::>>()? - .clone(); - let quic = connection.quic(); - let dquic = (quic.as_ref() as &dyn Any).downcast_ref::()?; - let ctx = dquic.path_context().ok()?; - - ctx.paths::>() - .into_iter() - .next() - .map(|(pathway, _)| pathway.remote().addr().ip()) -} - -// --------------------------------------------------------------------------- -// Core lookup logic -// --------------------------------------------------------------------------- - -pub async fn perform_lookup( - state: &AppState, - host: &str, - limit: Option, - source_ip: Option, -) -> Result { - let host = normalize_host(host, state.host_allowlist.as_ref())?; - perform_lookup_multi(state, &host, limit, source_ip).await -} - -async fn perform_lookup_multi( - state: &AppState, - host: &str, - limit: Option, - source_ip: Option, -) -> Result { - let source_traits = request_source_geo_traits(source_ip, state.geo.as_deref()); - let candidate_total = candidate_total_cap(limit); - let candidate_all = all_candidate_cap(candidate_total, source_traits.as_ref()); - - let dynamic_records = match &state.storage { - Storage::Redis(redis) => { - let mut conn = redis.read.get().await.map_err(|e| AppError::Redis { - message: e.to_string(), - })?; - - if redis_host_blacklisted(&mut *conn, host).await? { - debug!(host = %host, "lookup.blacklisted"); - return Ok(LookupResult::NotFound); - } - - let now_secs = unix_now_secs(); - let mut candidate_fingerprints = Vec::new(); - let mut seen_fingerprints = HashSet::new(); - - if let Some(asn) = source_traits.as_ref().and_then(|traits| traits.asn) { - let index_key = redis_asn_index_key(host, asn); - let members: Vec = conn - .zrevrange( - &index_key, - 0isize, - LOOKUP_CANDIDATE_CAP_ASN.saturating_sub(1) as isize, - ) - .await - .map_err(|e| AppError::Redis { - message: e.to_string(), - })?; - - push_unique_candidates( - &mut candidate_fingerprints, - &mut seen_fingerprints, - members, - candidate_total, - ); - } - - if let Some(country) = source_traits - .as_ref() - .and_then(|traits| traits.country.as_deref()) - { - let index_key = redis_country_index_key(host, country); - let members: Vec = conn - .zrevrange( - &index_key, - 0isize, - LOOKUP_CANDIDATE_CAP_COUNTRY.saturating_sub(1) as isize, - ) - .await - .map_err(|e| AppError::Redis { - message: e.to_string(), - })?; - - push_unique_candidates( - &mut candidate_fingerprints, - &mut seen_fingerprints, - members, - candidate_total, - ); - } - - let all_index_key = redis_all_index_key(host); - let all_members: Vec = conn - .zrevrange( - &all_index_key, - 0isize, - candidate_all.saturating_sub(1) as isize, - ) - .await - .map_err(|e| AppError::Redis { - message: e.to_string(), - })?; - - push_unique_candidates( - &mut candidate_fingerprints, - &mut seen_fingerprints, - all_members, - candidate_total, - ); - - let mut records = Vec::new(); - for fingerprint in candidate_fingerprints { - let primary_key = redis_primary_key(host, &fingerprint); - let member: Option> = - conn.get(&primary_key).await.map_err(|e| AppError::Redis { - message: e.to_string(), - })?; - - let Some(member) = member else { - continue; - }; - let Some(record) = StoredRecord::decode(&member) else { - continue; - }; - if record.expire_unix_secs > now_secs { - records.push(ResponseRecord::new( - record.signature_fields, - record.dns, - record.cert, - )); - } - } - - records - } - Storage::Memory(mem) => { - if mem.is_blacklisted(host) { - debug!(host = %host, "lookup.blacklisted"); - return Ok(LookupResult::NotFound); - } - - let now = tokio::time::Instant::now(); - if let Some(mut entry) = mem.records.get_mut(host) { - entry.retain_active(now); - let candidate_fingerprints = entry.collect_candidates( - source_traits - .as_ref() - .and_then(|traits| traits.country.as_deref()), - source_traits.as_ref().and_then(|traits| traits.asn), - candidate_total, - LOOKUP_CANDIDATE_CAP_ASN, - LOOKUP_CANDIDATE_CAP_COUNTRY, - candidate_all, - ); - - candidate_fingerprints - .into_iter() - .filter_map(|fingerprint| { - entry.records.get(&fingerprint).map(|record| { - ResponseRecord::new( - record.signature_fields.clone(), - record.dns_bytes.clone(), - record.cert_bytes.clone(), - ) - }) - }) - .collect::>() - } else { - vec![] - } - } - }; - - let normalized_dynamic_records = normalize_lookup_records(dynamic_records); - let mut records = if let Some(geo) = state.geo.as_deref() { - sort_lookup_records_with_geo(normalized_dynamic_records, source_ip, geo) - } else { - sort_lookup_records(normalized_dynamic_records, source_ip) - }; - - let should_append_seeds = records.is_empty() || limit.is_some_and(|max| records.len() < max); - if should_append_seeds && let Some(seed_records) = state.seed_records.get(host) { - let seeds = if let Some(geo) = state.geo.as_deref() { - sort_lookup_records_with_geo(seed_records.iter().cloned().collect(), source_ip, geo) - } else { - sort_lookup_records(seed_records.iter().cloned().collect(), source_ip) - }; - records.extend(seeds); - } - - let records = normalize_lookup_records(records); - let records = if let Some(limit) = limit { - records.into_iter().take(limit).collect::>() - } else { - records - }; - - if records.is_empty() { - Ok(LookupResult::NotFound) - } else { - Ok(LookupResult::Multi(MultiResponse::new(records))) - } -} - -async fn redis_host_blacklisted(conn: &mut C, host: &str) -> Result -where - C: redis::aio::ConnectionLike + Send + Sync, -{ - conn.sismember(redis_blacklist_key(), host) - .await - .map_err(|e| AppError::Redis { - message: e.to_string(), - }) -} - -// --------------------------------------------------------------------------- -// HTTP response helpers -// --------------------------------------------------------------------------- - -pub fn body_response(status: http::StatusCode, body: impl Into) -> Response { - http::Response::builder() - .status(status) - .body(Full::new(body.into())) - .expect("response parts must be valid") -} - -pub fn write_error(err: AppError) -> Response { - debug!( - status = %err.status(), - error = %err, - "writing error response" - ); - body_response(err.status(), bytes::Bytes::from(err.to_string())) -} - -// --------------------------------------------------------------------------- -// LookupSvc -// --------------------------------------------------------------------------- - -#[derive(Clone)] -pub struct LookupSvc { - pub state: AppState, -} - -/// Handle a lookup request. -/// -/// Always returns multi-record binary body: -/// `[u32 count BE]([u32 dns_len BE][dns][u32 cert_len BE][cert])*` -/// with header `x-record-format: multi`. -/// -/// Optional query param `limit=N` caps the number of records returned. -/// Dynamic records are newest-first; configured seed records are appended after them. -pub async fn lookup_with_cert(state: AppState, request: Request) -> Response { - let params = parse_query_params(request.uri()); - let Some(host) = params.get("host") else { - return write_error(AppError::MissingHostParam); - }; - let source_ip = request_source_ip(&request); - - let limit: Option = params - .get("limit") - .and_then(|v| v.parse::().ok()) - .filter(|&n| n > 0); - - debug!(host = %host, limit, ?source_ip, "lookup.request"); - - match perform_lookup(&state, host, limit, source_ip).await { - Ok(LookupResult::NotFound) => { - debug!(host = %host, "lookup.not_found"); - body_response( - http::StatusCode::NOT_FOUND, - bytes::Bytes::from_static(b"Not Found"), - ) - } - - Ok(LookupResult::Multi(resp)) => { - let body = resp.encode(); - debug!(host = %host, records = resp.records.len(), "lookup.found"); - let mut response = body_response(http::StatusCode::OK, bytes::Bytes::from(body)); - response.headers_mut().insert( - http::HeaderName::from_static("x-record-format"), - http::HeaderValue::from_static("multi"), - ); - response - } - - Err(e) => write_error(e), - } -} - -impl LookupSvc { - pub fn call( - &self, - request: Request, - ) -> impl Future> + Send + 'static { - let state = self.state.clone(); - async move { Ok(lookup_with_cert(state, request).await) } - } -} - -#[cfg(test)] -mod tests { - use std::{ - net::{IpAddr, Ipv4Addr, SocketAddrV4}, - path::PathBuf, - }; - - use ddns::core::{MdnsEndpoint, signature::SignatureFields}; - - use super::*; - use crate::geo::{GeoPoint, GeoResolver}; - - fn fixture_geo_resolver() -> GeoResolver { - let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - let city_db = manifest_dir.join("geoip/GeoLite2-City.mmdb"); - let asn_db = manifest_dir.join("geoip/GeoLite2-ASN.mmdb"); - - GeoResolver::open(&city_db, &asn_db, true, 100).expect("fixture geo db should open") - } - - fn lookup_record(host: &str, addr: SocketAddr, load: Option) -> LookupRecord { - let mut endpoint = match addr { - SocketAddr::V4(addr) => MdnsEndpoint::direct_v4(addr), - SocketAddr::V6(addr) => MdnsEndpoint::direct_v6(addr), - }; - endpoint.set_load(load); - - let mut hosts = HashMap::new(); - hosts.insert(host.to_string(), vec![endpoint]); - - ResponseRecord::unsigned(MdnsPacket::answer(0, &hosts).to_bytes(), Vec::new()) - } - - struct FakeRedis { - response: redis::Value, - packed_commands: Vec>, - } - - impl redis::aio::ConnectionLike for FakeRedis { - fn req_packed_command<'a>( - &'a mut self, - cmd: &'a redis::Cmd, - ) -> redis::RedisFuture<'a, redis::Value> { - self.packed_commands.push(cmd.get_packed_command()); - let response = self.response.clone(); - Box::pin(async move { Ok(response) }) - } - - fn req_packed_commands<'a>( - &'a mut self, - _cmd: &'a redis::Pipeline, - _offset: usize, - _count: usize, - ) -> redis::RedisFuture<'a, Vec> { - Box::pin(async move { Ok(Vec::new()) }) - } - - fn get_db(&self) -> i64 { - 0 - } - } - - #[tokio::test] - async fn redis_host_blacklisted_queries_external_blacklist_set() { - let mut redis = FakeRedis { - response: redis::Value::Int(1), - packed_commands: Vec::new(), - }; - - let blacklisted = redis_host_blacklisted(&mut redis, "blocked.example.genmeta.net") - .await - .unwrap(); - - assert!(blacklisted); - assert_eq!(redis.packed_commands.len(), 1); - let command = String::from_utf8(redis.packed_commands.remove(0)).unwrap(); - assert!(command.contains("SISMEMBER")); - assert!(command.contains(redis_blacklist_key())); - assert!(command.contains("blocked.example.genmeta.net")); - } - - #[tokio::test] - async fn memory_blacklist_returns_not_found_before_seed_records() { - let host = "blocked.example.genmeta.net"; - let mut seed_records = HashMap::new(); - seed_records.insert( - host.to_string(), - vec![lookup_record( - host, - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), - None, - )], - ); - let state = AppState { - storage: Storage::Memory(MemoryStorage::with_blacklist([host.to_string()])), - host_allowlist: Arc::new(vec!["genmeta.net".to_string()]), - require_signature: false, - ttl_secs: 30, - policies: Arc::new(crate::policy::DomainPolicies::default()), - seed_records: SeedRecords::new(seed_records), - geo: None, - }; - - let result = perform_lookup(&state, host, None, None).await.unwrap(); - - assert!(matches!(result, LookupResult::NotFound)); - } - - #[test] - fn normalize_lookup_records_keeps_signed_packets_whole() { - let mut record = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), - None, - ); - record.signature_fields = SignatureFields { - content_digest: b"sha-256=:abc:".to_vec(), - signature_input: - b"dns=(\"content-digest\");created=1;keyid=\"sha256:abc\";alg=\"ed25519\"".to_vec(), - signature: b"dns=:sig:".to_vec(), - }; - - let normalized = normalize_lookup_records(vec![record.clone()]); - - assert_eq!(normalized.len(), 1); - assert_eq!(normalized[0], record); - } - - #[test] - fn compare_geo_sort_keys_follows_documented_priority() { - let best = GeoSortKey { - same_country: true, - same_asn: true, - family_match: true, - same_city: true, - load: Some(0.2), - geo_distance: Some(20.0), - }; - let worse_load = GeoSortKey { - load: Some(0.8), - ..best - }; - let worse_family = GeoSortKey { - same_asn: true, - family_match: false, - same_city: true, - load: Some(0.1), - geo_distance: Some(1.0), - ..best - }; - let worse_city = GeoSortKey { - same_city: false, - load: Some(0.1), - geo_distance: Some(1.0), - ..best - }; - let worse_asn = GeoSortKey { - same_asn: false, - family_match: true, - same_city: true, - load: Some(0.1), - geo_distance: Some(1.0), - ..best - }; - let worse_country = GeoSortKey { - same_country: false, - same_asn: true, - family_match: true, - same_city: false, - load: Some(0.1), - geo_distance: Some(1.0), - }; - - assert_eq!(compare_geo_sort_keys(best, worse_load), Ordering::Less); - assert_eq!(compare_geo_sort_keys(best, worse_family), Ordering::Less); - assert_eq!(compare_geo_sort_keys(best, worse_city), Ordering::Less); - assert_eq!(compare_geo_sort_keys(best, worse_asn), Ordering::Less); - assert_eq!(compare_geo_sort_keys(best, worse_country), Ordering::Less); - } - - #[test] - fn compare_geo_sort_keys_skips_unknown_dimensions() { - let known_distance = GeoSortKey { - same_country: true, - same_asn: true, - family_match: true, - same_city: true, - load: Some(0.2), - geo_distance: Some(10.0), - }; - let missing_distance = GeoSortKey { - geo_distance: None, - ..known_distance - }; - let missing_load = GeoSortKey { - load: None, - ..known_distance - }; - - assert_eq!( - compare_geo_sort_keys(known_distance, missing_distance), - Ordering::Equal - ); - assert_eq!( - compare_geo_sort_keys(known_distance, missing_load), - Ordering::Equal - ); - } - - #[test] - fn sort_lookup_records_with_geo_prefers_same_source_endpoint_even_with_higher_load() { - let geo = fixture_geo_resolver(); - let source_ip = Some(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); - let matching = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), - Some(0.9), - ); - let non_matching = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 1, 1, 1), 3478)), - Some(0.1), - ); - - let sorted = - sort_lookup_records_with_geo(vec![non_matching, matching.clone()], source_ip, &geo); - - let (endpoint, _) = lookup_endpoint(&sorted[0].dns).expect("sorted record should decode"); - assert_eq!(endpoint.ip(), IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); - } - - #[test] - fn sort_lookup_records_without_geo_ignores_ip_prefix_and_prefers_lower_load() { - let source_ip = Some(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))); - let closer_prefix_higher_load = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 2), 3478)), - Some(0.9), - ); - let farther_prefix_lower_load = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), - Some(0.1), - ); - - let sorted = sort_lookup_records( - vec![closer_prefix_higher_load, farther_prefix_lower_load], - source_ip, - ); - - let (endpoint, _) = lookup_endpoint(&sorted[0].dns).expect("sorted record should decode"); - assert_eq!(endpoint.ip(), IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); - } - - #[test] - fn sort_lookup_records_with_geo_prefers_same_asn_then_same_country_on_real_ips() { - let geo = fixture_geo_resolver(); - let source_ip = Some(IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5))); - - let different_country = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), - Some(0.01), - ); - let same_country_different_asn = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(114, 114, 114, 114), 3478)), - Some(0.02), - ); - let same_country_same_asn = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(223, 5, 5, 5), 3478)), - Some(0.9), - ); - - let sorted = sort_lookup_records_with_geo( - vec![ - different_country, - same_country_different_asn, - same_country_same_asn, - ], - source_ip, - &geo, - ); - - let ordered_ips = sorted - .iter() - .map(|record| { - lookup_endpoint(&record.dns) - .expect("record should decode") - .0 - .ip() - }) - .collect::>(); - - assert_eq!( - ordered_ips, - vec![ - IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5)), - IpAddr::V4(Ipv4Addr::new(114, 114, 114, 114)), - IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), - ] - ); - } - - #[test] - fn sort_lookup_records_with_geo_prefers_same_country_over_lower_load_on_real_ips() { - let geo = fixture_geo_resolver(); - let source_ip = Some(IpAddr::V4(Ipv4Addr::new(114, 114, 114, 114))); - - let different_country_low_load = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(80, 80, 80, 80), 3478)), - Some(0.01), - ); - let same_country_higher_load = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(223, 5, 5, 5), 3478)), - Some(0.9), - ); - - let sorted = sort_lookup_records_with_geo( - vec![different_country_low_load, same_country_higher_load.clone()], - source_ip, - &geo, - ); - - let (endpoint, _) = lookup_endpoint(&sorted[0].dns).expect("sorted record should decode"); - assert_eq!(endpoint.ip(), IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5))); - } - - #[test] - fn build_geo_sort_key_ignores_city_distance_when_accuracy_is_too_large() { - let geo = fixture_geo_resolver(); - let source_traits = GeoTraits { - country: Some("CN".to_string()), - city: Some("Beijing".to_string()), - asn: Some(64512), - point: Some(GeoPoint { - latitude: 39.9, - longitude: 116.4, - accuracy_radius_km: 500, - }), - }; - let endpoint_traits = GeoTraits { - country: Some("CN".to_string()), - city: Some("Shanghai".to_string()), - asn: Some(64512), - point: Some(GeoPoint { - latitude: 31.2, - longitude: 121.5, - accuracy_radius_km: 10, - }), - }; - - let key = build_geo_sort_key( - Some(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))), - Some(&source_traits), - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 4, 4), 3478)), - Some(0.2), - &endpoint_traits, - &geo, - ); - - assert!(key.same_country); - assert!(key.same_asn); - assert!(!key.same_city); - assert_eq!(key.geo_distance, None); - } - - #[test] - fn build_geo_sort_key_prefers_same_city_when_available() { - let geo = fixture_geo_resolver(); - let source_traits = GeoTraits { - country: Some("CN".to_string()), - city: Some("Hangzhou".to_string()), - asn: Some(64512), - point: None, - }; - let same_city_traits = GeoTraits { - country: Some("CN".to_string()), - city: Some("Hangzhou".to_string()), - asn: Some(64513), - point: None, - }; - let different_city_traits = GeoTraits { - country: Some("CN".to_string()), - city: Some("Shanghai".to_string()), - asn: Some(64513), - point: None, - }; - - let same_city_key = build_geo_sort_key( - Some(IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5))), - Some(&source_traits), - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(223, 6, 6, 6), 3478)), - Some(0.9), - &same_city_traits, - &geo, - ); - let different_city_key = build_geo_sort_key( - Some(IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5))), - Some(&source_traits), - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(114, 114, 114, 114), 3478)), - Some(0.1), - &different_city_traits, - &geo, - ); - - assert!(same_city_key.same_city); - assert!(!different_city_key.same_city); - assert_eq!( - compare_geo_sort_keys(same_city_key, different_city_key), - Ordering::Less - ); - } -} diff --git a/src/bin/ddns-server/lookup/http.rs b/src/bin/ddns-server/lookup/http.rs new file mode 100644 index 0000000..9e68088 --- /dev/null +++ b/src/bin/ddns-server/lookup/http.rs @@ -0,0 +1,109 @@ +use std::{any::Any, convert::Infallible, net::IpAddr, sync::Arc}; + +use h3x::{connection::ConnectionState, dhttp::message::MessageStreamError, quic}; +use http_body_util::{Full, combinators::UnsyncBoxBody}; +use tracing::debug; + +use super::query::{LookupResult, perform_lookup}; +use crate::{ + error::{AppError, parse_query_params}, + storage::AppState, +}; + +pub type Request = http::Request>; +pub type Response = http::Response>; + +fn request_source_ip(request: &Request) -> Option { + let connection = request + .extensions() + .get::>>()? + .clone(); + let quic = connection.quic(); + let dquic = (quic.as_ref() as &dyn Any).downcast_ref::()?; + let ctx = dquic.path_context().ok()?; + + ctx.paths::>() + .into_iter() + .next() + .map(|(pathway, _)| pathway.remote().addr().ip()) +} +pub fn body_response(status: http::StatusCode, body: impl Into) -> Response { + http::Response::builder() + .status(status) + .body(Full::new(body.into())) + .expect("response parts must be valid") +} + +pub fn write_error(err: AppError) -> Response { + debug!( + status = %err.status(), + error = %err, + "writing error response" + ); + body_response(err.status(), bytes::Bytes::from(err.to_string())) +} + +// --------------------------------------------------------------------------- +// LookupSvc +// --------------------------------------------------------------------------- + +#[derive(Clone)] +pub struct LookupSvc { + pub state: AppState, +} + +/// Handle a lookup request. +/// +/// Always returns multi-record binary body: +/// `[u32 count BE]([u32 dns_len BE][dns][u32 cert_len BE][cert])*` +/// with header `x-record-format: multi`. +/// +/// Optional query param `limit=N` caps the number of records returned. +/// Dynamic records are newest-first; configured seed records are appended after them. +pub async fn lookup_with_cert(state: AppState, request: Request) -> Response { + let params = parse_query_params(request.uri()); + let Some(host) = params.get("host") else { + return write_error(AppError::MissingHostParam); + }; + let source_ip = request_source_ip(&request); + + let limit: Option = params + .get("limit") + .and_then(|v| v.parse::().ok()) + .filter(|&n| n > 0); + + debug!(host = %host, limit, ?source_ip, "lookup.request"); + + match perform_lookup(&state, host, limit, source_ip).await { + Ok(LookupResult::NotFound) => { + debug!(host = %host, "lookup.not_found"); + body_response( + http::StatusCode::NOT_FOUND, + bytes::Bytes::from_static(b"Not Found"), + ) + } + + Ok(LookupResult::Multi(resp)) => { + let body = resp.encode(); + debug!(host = %host, records = resp.records.len(), "lookup.found"); + let mut response = body_response(http::StatusCode::OK, bytes::Bytes::from(body)); + response.headers_mut().insert( + http::HeaderName::from_static("x-record-format"), + http::HeaderValue::from_static("multi"), + ); + response + } + + Err(e) => write_error(e), + } +} + +impl LookupSvc { + pub fn call( + &self, + request: Request, + ) -> impl Future> + Send + 'static { + let state = self.state.clone(); + async move { Ok(lookup_with_cert(state, request).await) } + } +} diff --git a/src/bin/ddns-server/lookup/mod.rs b/src/bin/ddns-server/lookup/mod.rs new file mode 100644 index 0000000..62e795a --- /dev/null +++ b/src/bin/ddns-server/lookup/mod.rs @@ -0,0 +1,8 @@ +mod http; +pub(crate) mod query; +mod ranking; + +pub use http::{LookupSvc, Request, Response, body_response, write_error}; + +#[cfg(test)] +mod tests; diff --git a/src/bin/ddns-server/lookup/query.rs b/src/bin/ddns-server/lookup/query.rs new file mode 100644 index 0000000..8db1328 --- /dev/null +++ b/src/bin/ddns-server/lookup/query.rs @@ -0,0 +1,263 @@ +use std::{collections::HashSet, hash::Hash, net::IpAddr}; + +use ddns::core::wire::{MultiResponse, ResponseRecord}; +use deadpool_redis::redis::{self, AsyncCommands}; +use tracing::debug; + +use super::ranking::{ + LOOKUP_CANDIDATE_CAP_ALL, LOOKUP_CANDIDATE_CAP_ASN, LOOKUP_CANDIDATE_CAP_COUNTRY, + LOOKUP_CANDIDATE_CAP_TOTAL, normalize_lookup_records, request_source_geo_traits, + sort_lookup_records, sort_lookup_records_with_geo, +}; +use crate::{ + error::{AppError, normalize_host}, + geo::GeoTraits, + storage::{ + AppState, Storage, StoredRecord, redis_all_index_key, redis_asn_index_key, + redis_blacklist_key, redis_country_index_key, redis_primary_key, unix_now_secs, + }, +}; + +pub enum LookupResult { + NotFound, + /// Multiple records, newest-first. + Multi(MultiResponse), +} +fn candidate_total_cap(limit: Option) -> usize { + limit + .unwrap_or(LOOKUP_CANDIDATE_CAP_TOTAL) + .max(LOOKUP_CANDIDATE_CAP_TOTAL) +} + +fn all_candidate_cap(total_cap: usize, source_traits: Option<&GeoTraits>) -> usize { + let has_geo_buckets = source_traits + .is_some_and(|traits| traits.asn.is_some() || traits.country.as_deref().is_some()); + + if has_geo_buckets { + LOOKUP_CANDIDATE_CAP_ALL.min(total_cap) + } else { + total_cap + } +} + +fn push_unique_candidates( + candidates: &mut Vec, + seen: &mut HashSet, + source: impl IntoIterator, + total_cap: usize, +) where + T: Clone + Eq + Hash, +{ + for item in source { + if candidates.len() >= total_cap { + break; + } + + if seen.insert(item.clone()) { + candidates.push(item); + } + } +} +pub async fn perform_lookup( + state: &AppState, + host: &str, + limit: Option, + source_ip: Option, +) -> Result { + let host = normalize_host(host, state.host_allowlist.as_ref())?; + perform_lookup_multi(state, &host, limit, source_ip).await +} + +async fn perform_lookup_multi( + state: &AppState, + host: &str, + limit: Option, + source_ip: Option, +) -> Result { + let source_traits = request_source_geo_traits(source_ip, state.geo.as_deref()); + let candidate_total = candidate_total_cap(limit); + let candidate_all = all_candidate_cap(candidate_total, source_traits.as_ref()); + + let dynamic_records = match &state.storage { + Storage::Redis(redis) => { + let mut conn = redis.read.get().await.map_err(|e| AppError::Redis { + message: e.to_string(), + })?; + + if redis_host_blacklisted(&mut *conn, host).await? { + debug!(host = %host, "lookup.blacklisted"); + return Ok(LookupResult::NotFound); + } + + let now_secs = unix_now_secs(); + let mut candidate_fingerprints = Vec::new(); + let mut seen_fingerprints = HashSet::new(); + + if let Some(asn) = source_traits.as_ref().and_then(|traits| traits.asn) { + let index_key = redis_asn_index_key(host, asn); + let members: Vec = conn + .zrevrange( + &index_key, + 0isize, + LOOKUP_CANDIDATE_CAP_ASN.saturating_sub(1) as isize, + ) + .await + .map_err(|e| AppError::Redis { + message: e.to_string(), + })?; + + push_unique_candidates( + &mut candidate_fingerprints, + &mut seen_fingerprints, + members, + candidate_total, + ); + } + + if let Some(country) = source_traits + .as_ref() + .and_then(|traits| traits.country.as_deref()) + { + let index_key = redis_country_index_key(host, country); + let members: Vec = conn + .zrevrange( + &index_key, + 0isize, + LOOKUP_CANDIDATE_CAP_COUNTRY.saturating_sub(1) as isize, + ) + .await + .map_err(|e| AppError::Redis { + message: e.to_string(), + })?; + + push_unique_candidates( + &mut candidate_fingerprints, + &mut seen_fingerprints, + members, + candidate_total, + ); + } + + let all_index_key = redis_all_index_key(host); + let all_members: Vec = conn + .zrevrange( + &all_index_key, + 0isize, + candidate_all.saturating_sub(1) as isize, + ) + .await + .map_err(|e| AppError::Redis { + message: e.to_string(), + })?; + + push_unique_candidates( + &mut candidate_fingerprints, + &mut seen_fingerprints, + all_members, + candidate_total, + ); + + let mut records = Vec::new(); + for fingerprint in candidate_fingerprints { + let primary_key = redis_primary_key(host, &fingerprint); + let member: Option> = + conn.get(&primary_key).await.map_err(|e| AppError::Redis { + message: e.to_string(), + })?; + + let Some(member) = member else { + continue; + }; + let Some(record) = StoredRecord::decode(&member) else { + continue; + }; + if record.expire_unix_secs > now_secs { + records.push(ResponseRecord::new( + record.signature_fields, + record.dns, + record.cert, + )); + } + } + + records + } + Storage::Memory(mem) => { + if mem.is_blacklisted(host) { + debug!(host = %host, "lookup.blacklisted"); + return Ok(LookupResult::NotFound); + } + + let now = tokio::time::Instant::now(); + if let Some(mut entry) = mem.records.get_mut(host) { + entry.retain_active(now); + let candidate_fingerprints = entry.collect_candidates( + source_traits + .as_ref() + .and_then(|traits| traits.country.as_deref()), + source_traits.as_ref().and_then(|traits| traits.asn), + candidate_total, + LOOKUP_CANDIDATE_CAP_ASN, + LOOKUP_CANDIDATE_CAP_COUNTRY, + candidate_all, + ); + + candidate_fingerprints + .into_iter() + .filter_map(|fingerprint| { + entry.records.get(&fingerprint).map(|record| { + ResponseRecord::new( + record.signature_fields.clone(), + record.dns_bytes.clone(), + record.cert_bytes.clone(), + ) + }) + }) + .collect::>() + } else { + vec![] + } + } + }; + + let normalized_dynamic_records = normalize_lookup_records(dynamic_records); + let mut records = if let Some(geo) = state.geo.as_deref() { + sort_lookup_records_with_geo(normalized_dynamic_records, source_ip, geo) + } else { + sort_lookup_records(normalized_dynamic_records, source_ip) + }; + + let should_append_seeds = records.is_empty() || limit.is_some_and(|max| records.len() < max); + if should_append_seeds && let Some(seed_records) = state.seed_records.get(host) { + let seeds = if let Some(geo) = state.geo.as_deref() { + sort_lookup_records_with_geo(seed_records.iter().cloned().collect(), source_ip, geo) + } else { + sort_lookup_records(seed_records.iter().cloned().collect(), source_ip) + }; + records.extend(seeds); + } + + let records = normalize_lookup_records(records); + let records = if let Some(limit) = limit { + records.into_iter().take(limit).collect::>() + } else { + records + }; + + if records.is_empty() { + Ok(LookupResult::NotFound) + } else { + Ok(LookupResult::Multi(MultiResponse::new(records))) + } +} + +pub(super) async fn redis_host_blacklisted(conn: &mut C, host: &str) -> Result +where + C: redis::aio::ConnectionLike + Send + Sync, +{ + conn.sismember(redis_blacklist_key(), host) + .await + .map_err(|e| AppError::Redis { + message: e.to_string(), + }) +} diff --git a/src/bin/ddns-server/lookup/ranking.rs b/src/bin/ddns-server/lookup/ranking.rs new file mode 100644 index 0000000..1274249 --- /dev/null +++ b/src/bin/ddns-server/lookup/ranking.rs @@ -0,0 +1,254 @@ +use std::{ + cmp::Ordering, + collections::{HashMap, HashSet}, + net::{IpAddr, SocketAddr}, +}; + +use ddns::core::{ + MdnsPacket, + parser::{packet::be_packet, record::RData}, + wire::ResponseRecord, +}; + +use crate::{ + geo::{GeoResolver, GeoTraits}, + storage::LookupRecord, +}; + +type EndpointKey = (SocketAddr, Option); + +pub(super) const LOOKUP_CANDIDATE_CAP_TOTAL: usize = 64; +pub(super) const LOOKUP_CANDIDATE_CAP_ASN: usize = 16; +pub(super) const LOOKUP_CANDIDATE_CAP_COUNTRY: usize = 16; +pub(super) const LOOKUP_CANDIDATE_CAP_ALL: usize = 32; + +// GEO-aware ranking dimensions. Final ordering still falls back to the original +// record index so we keep lookups stable when all computed dimensions tie. +#[derive(Clone, Copy, Debug, PartialEq)] +pub(super) struct GeoSortKey { + pub(super) same_country: bool, + pub(super) same_asn: bool, + pub(super) family_match: bool, + pub(super) same_city: bool, + pub(super) load: Option, + pub(super) geo_distance: Option, +} + +pub(super) fn normalize_lookup_records(records: Vec) -> Vec { + let mut normalized = Vec::new(); + let mut seen = HashSet::new(); + + for record in records { + if !record.signature_fields.is_empty() { + normalized.push(record); + continue; + } + + let Ok((_, packet)) = be_packet(&record.dns) else { + normalized.push(record); + continue; + }; + + let mut emitted_endpoint = false; + + for answer in &packet.answers { + let RData::E(endpoint) = answer.data() else { + continue; + }; + + emitted_endpoint = true; + let key: EndpointKey = (endpoint.addr(), endpoint.agent_addr()); + + if !seen.insert(key) { + continue; + } + + let mut hosts = HashMap::new(); + hosts.insert(answer.name().to_string(), vec![endpoint.clone()]); + normalized.push(ResponseRecord::unsigned( + MdnsPacket::answer(0, &hosts).to_bytes(), + record.cert.clone(), + )); + } + + if !emitted_endpoint { + normalized.push(record); + } + } + + normalized +} + +pub(super) fn lookup_endpoint(dns_bytes: &[u8]) -> Option<(SocketAddr, Option)> { + let (_, packet) = be_packet(dns_bytes).ok()?; + packet + .answers + .iter() + .find_map(|answer| match answer.data() { + RData::E(endpoint) => Some((endpoint.addr(), endpoint.load())), + _ => None, + }) +} + +// Fallback ordering when GEO routing is disabled: prefer matching address family, +// then lower load, and finally preserve input order. We intentionally avoid +// IP prefix heuristics here because they are not reliable on the public Internet. +pub(super) fn sort_lookup_records( + records: Vec, + source_ip: Option, +) -> Vec { + let mut decorated = records + .into_iter() + .enumerate() + .map(|(index, record)| { + let sort_key = lookup_endpoint(&record.dns).map(|(endpoint, load)| { + let family_match = source_ip + .map(|source| source.is_ipv4() == endpoint.ip().is_ipv4()) + .unwrap_or(false); + + (family_match, load) + }); + (sort_key, index, record) + }) + .collect::>(); + + decorated.sort_by(|(left_key, left_index, _), (right_key, right_index, _)| { + match (left_key, right_key) { + (Some((left_family, left_load)), Some((right_family, right_load))) => right_family + .cmp(left_family) + .then_with(|| match (left_load, right_load) { + (Some(left), Some(right)) => left.partial_cmp(right).unwrap_or(Ordering::Equal), + (Some(_), None) => Ordering::Less, + (None, Some(_)) => Ordering::Greater, + (None, None) => Ordering::Equal, + }), + (Some(_), None) => Ordering::Less, + (None, Some(_)) => Ordering::Greater, + (None, None) => Ordering::Equal, + } + .then_with(|| left_index.cmp(right_index)) + }); + + decorated.into_iter().map(|(_, _, record)| record).collect() +} + +pub(super) fn request_source_geo_traits( + source_ip: Option, + geo: Option<&GeoResolver>, +) -> Option { + Some(geo?.lookup_traits(source_ip?)) +} + +fn lookup_endpoint_geo_traits( + dns_bytes: &[u8], + geo: &GeoResolver, +) -> Option<(SocketAddr, Option, GeoTraits)> { + let (endpoint, load) = lookup_endpoint(dns_bytes)?; + Some((endpoint, load, geo.lookup_traits(endpoint.ip()))) +} + +fn compare_optional_partial(left: Option, right: Option) -> Ordering { + match (left, right) { + (Some(left), Some(right)) => left.partial_cmp(&right).unwrap_or(Ordering::Equal), + _ => Ordering::Equal, + } +} + +// GEO ordering is layered rather than score-based: +// country > ASN > address family > city name > lower load > shorter GEO distance. +// Missing optional values do not penalize a candidate; they simply skip that layer. +pub(super) fn compare_geo_sort_keys(left: GeoSortKey, right: GeoSortKey) -> Ordering { + right + .same_country + .cmp(&left.same_country) + .then_with(|| right.same_asn.cmp(&left.same_asn)) + .then_with(|| right.family_match.cmp(&left.family_match)) + .then_with(|| right.same_city.cmp(&left.same_city)) + .then_with(|| compare_optional_partial(left.load, right.load)) + .then_with(|| compare_optional_partial(left.geo_distance, right.geo_distance)) +} + +// Build the per-endpoint GEO ranking tuple. City name only participates when both +// sides have a name and already match on country; coordinate distance only +// participates when GeoResolver accepts both accuracy radii. +pub(super) fn build_geo_sort_key( + source_ip: Option, + source_traits: Option<&GeoTraits>, + endpoint: SocketAddr, + load: Option, + endpoint_traits: &GeoTraits, + geo: &GeoResolver, +) -> GeoSortKey { + let family_match = source_ip + .map(|source| source.is_ipv4() == endpoint.ip().is_ipv4()) + .unwrap_or(false); + + let same_country = source_traits + .and_then(|source| source.country.as_deref()) + .zip(endpoint_traits.country.as_deref()) + .is_some_and(|(source, target)| source == target); + + let same_asn = source_traits + .and_then(|source| source.asn) + .zip(endpoint_traits.asn) + .is_some_and(|(source, target)| source == target); + + let same_city = same_country + && source_traits + .and_then(|source| source.city.as_deref()) + .zip(endpoint_traits.city.as_deref()) + .is_some_and(|(source, target)| source == target); + + let geo_distance = source_traits + .and_then(|source| source.point.as_ref()) + .zip(endpoint_traits.point.as_ref()) + .and_then(|(source, target)| geo.geo_distance_km(source, target)); + + GeoSortKey { + same_country, + same_asn, + family_match, + same_city, + load, + geo_distance, + } +} +pub(super) fn sort_lookup_records_with_geo( + records: Vec, + source_ip: Option, + geo: &GeoResolver, +) -> Vec { + let source_traits = request_source_geo_traits(source_ip, Some(geo)); + + let mut decorated = records + .into_iter() + .enumerate() + .map(|(index, record)| { + let sort_key = lookup_endpoint_geo_traits(&record.dns, geo).map( + |(endpoint, load, endpoint_traits)| { + build_geo_sort_key( + source_ip, + source_traits.as_ref(), + endpoint, + load, + &endpoint_traits, + geo, + ) + }, + ); + (sort_key, index, record) + }) + .collect::>(); + + decorated.sort_by(|(left_key, left_index, _), (right_key, right_index, _)| { + match (left_key, right_key) { + (Some(left_key), Some(right_key)) => compare_geo_sort_keys(*left_key, *right_key), + (Some(_), None) => Ordering::Less, + (None, Some(_)) => Ordering::Greater, + (None, None) => Ordering::Equal, + } + .then_with(|| left_index.cmp(right_index)) + }); + + decorated.into_iter().map(|(_, _, record)| record).collect() +} diff --git a/src/bin/ddns-server/lookup/tests.rs b/src/bin/ddns-server/lookup/tests.rs new file mode 100644 index 0000000..9b26df4 --- /dev/null +++ b/src/bin/ddns-server/lookup/tests.rs @@ -0,0 +1,427 @@ +use std::{ + cmp::Ordering, + collections::HashMap, + net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}, + path::PathBuf, + sync::Arc, +}; + +use ddns::core::{MdnsEndpoint, MdnsPacket, signature::SignatureFields, wire::ResponseRecord}; +use deadpool_redis::redis; + +use super::{ + query::{LookupResult, perform_lookup, redis_host_blacklisted}, + ranking::{ + GeoSortKey, build_geo_sort_key, compare_geo_sort_keys, lookup_endpoint, + normalize_lookup_records, sort_lookup_records, sort_lookup_records_with_geo, + }, +}; +use crate::{ + geo::{GeoPoint, GeoResolver, GeoTraits}, + storage::{AppState, LookupRecord, MemoryStorage, SeedRecords, Storage, redis_blacklist_key}, +}; + +fn fixture_geo_resolver() -> GeoResolver { + let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let city_db = manifest_dir.join("geoip/GeoLite2-City.mmdb"); + let asn_db = manifest_dir.join("geoip/GeoLite2-ASN.mmdb"); + + GeoResolver::open(&city_db, &asn_db, true, 100).expect("fixture geo db should open") +} + +fn lookup_record(host: &str, addr: SocketAddr, load: Option) -> LookupRecord { + let mut endpoint = match addr { + SocketAddr::V4(addr) => MdnsEndpoint::direct_v4(addr), + SocketAddr::V6(addr) => MdnsEndpoint::direct_v6(addr), + }; + endpoint.set_load(load); + + let mut hosts = HashMap::new(); + hosts.insert(host.to_string(), vec![endpoint]); + + ResponseRecord::unsigned(MdnsPacket::answer(0, &hosts).to_bytes(), Vec::new()) +} + +struct FakeRedis { + response: redis::Value, + packed_commands: Vec>, +} + +impl redis::aio::ConnectionLike for FakeRedis { + fn req_packed_command<'a>( + &'a mut self, + cmd: &'a redis::Cmd, + ) -> redis::RedisFuture<'a, redis::Value> { + self.packed_commands.push(cmd.get_packed_command()); + let response = self.response.clone(); + Box::pin(async move { Ok(response) }) + } + + fn req_packed_commands<'a>( + &'a mut self, + _cmd: &'a redis::Pipeline, + _offset: usize, + _count: usize, + ) -> redis::RedisFuture<'a, Vec> { + Box::pin(async move { Ok(Vec::new()) }) + } + + fn get_db(&self) -> i64 { + 0 + } +} + +#[tokio::test] +async fn redis_host_blacklisted_queries_external_blacklist_set() { + let mut redis = FakeRedis { + response: redis::Value::Int(1), + packed_commands: Vec::new(), + }; + + let blacklisted = redis_host_blacklisted(&mut redis, "blocked.example.genmeta.net") + .await + .unwrap(); + + assert!(blacklisted); + assert_eq!(redis.packed_commands.len(), 1); + let command = String::from_utf8(redis.packed_commands.remove(0)).unwrap(); + assert!(command.contains("SISMEMBER")); + assert!(command.contains(redis_blacklist_key())); + assert!(command.contains("blocked.example.genmeta.net")); +} + +#[tokio::test] +async fn memory_blacklist_returns_not_found_before_seed_records() { + let host = "blocked.example.genmeta.net"; + let mut seed_records = HashMap::new(); + seed_records.insert( + host.to_string(), + vec![lookup_record( + host, + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), + None, + )], + ); + let state = AppState { + storage: Storage::Memory(MemoryStorage::with_blacklist([host.to_string()])), + host_allowlist: Arc::new(vec!["genmeta.net".to_string()]), + require_signature: false, + ttl_secs: 30, + policies: Arc::new(crate::policy::DomainPolicies::default()), + seed_records: SeedRecords::new(seed_records), + geo: None, + }; + + let result = perform_lookup(&state, host, None, None).await.unwrap(); + + assert!(matches!(result, LookupResult::NotFound)); +} + +#[test] +fn normalize_lookup_records_keeps_signed_packets_whole() { + let mut record = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), + None, + ); + record.signature_fields = SignatureFields { + content_digest: b"sha-256=:abc:".to_vec(), + signature_input: b"dns=(\"content-digest\");created=1;keyid=\"sha256:abc\";alg=\"ed25519\"" + .to_vec(), + signature: b"dns=:sig:".to_vec(), + }; + + let normalized = normalize_lookup_records(vec![record.clone()]); + + assert_eq!(normalized.len(), 1); + assert_eq!(normalized[0], record); +} + +#[test] +fn compare_geo_sort_keys_follows_documented_priority() { + let best = GeoSortKey { + same_country: true, + same_asn: true, + family_match: true, + same_city: true, + load: Some(0.2), + geo_distance: Some(20.0), + }; + let worse_load = GeoSortKey { + load: Some(0.8), + ..best + }; + let worse_family = GeoSortKey { + same_asn: true, + family_match: false, + same_city: true, + load: Some(0.1), + geo_distance: Some(1.0), + ..best + }; + let worse_city = GeoSortKey { + same_city: false, + load: Some(0.1), + geo_distance: Some(1.0), + ..best + }; + let worse_asn = GeoSortKey { + same_asn: false, + family_match: true, + same_city: true, + load: Some(0.1), + geo_distance: Some(1.0), + ..best + }; + let worse_country = GeoSortKey { + same_country: false, + same_asn: true, + family_match: true, + same_city: false, + load: Some(0.1), + geo_distance: Some(1.0), + }; + + assert_eq!(compare_geo_sort_keys(best, worse_load), Ordering::Less); + assert_eq!(compare_geo_sort_keys(best, worse_family), Ordering::Less); + assert_eq!(compare_geo_sort_keys(best, worse_city), Ordering::Less); + assert_eq!(compare_geo_sort_keys(best, worse_asn), Ordering::Less); + assert_eq!(compare_geo_sort_keys(best, worse_country), Ordering::Less); +} + +#[test] +fn compare_geo_sort_keys_skips_unknown_dimensions() { + let known_distance = GeoSortKey { + same_country: true, + same_asn: true, + family_match: true, + same_city: true, + load: Some(0.2), + geo_distance: Some(10.0), + }; + let missing_distance = GeoSortKey { + geo_distance: None, + ..known_distance + }; + let missing_load = GeoSortKey { + load: None, + ..known_distance + }; + + assert_eq!( + compare_geo_sort_keys(known_distance, missing_distance), + Ordering::Equal + ); + assert_eq!( + compare_geo_sort_keys(known_distance, missing_load), + Ordering::Equal + ); +} + +#[test] +fn sort_lookup_records_with_geo_prefers_same_source_endpoint_even_with_higher_load() { + let geo = fixture_geo_resolver(); + let source_ip = Some(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); + let matching = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), + Some(0.9), + ); + let non_matching = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 1, 1, 1), 3478)), + Some(0.1), + ); + + let sorted = + sort_lookup_records_with_geo(vec![non_matching, matching.clone()], source_ip, &geo); + + let (endpoint, _) = lookup_endpoint(&sorted[0].dns).expect("sorted record should decode"); + assert_eq!(endpoint.ip(), IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); +} + +#[test] +fn sort_lookup_records_without_geo_ignores_ip_prefix_and_prefers_lower_load() { + let source_ip = Some(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))); + let closer_prefix_higher_load = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 2), 3478)), + Some(0.9), + ); + let farther_prefix_lower_load = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), + Some(0.1), + ); + + let sorted = sort_lookup_records( + vec![closer_prefix_higher_load, farther_prefix_lower_load], + source_ip, + ); + + let (endpoint, _) = lookup_endpoint(&sorted[0].dns).expect("sorted record should decode"); + assert_eq!(endpoint.ip(), IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); +} + +#[test] +fn sort_lookup_records_with_geo_prefers_same_asn_then_same_country_on_real_ips() { + let geo = fixture_geo_resolver(); + let source_ip = Some(IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5))); + + let different_country = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), + Some(0.01), + ); + let same_country_different_asn = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(114, 114, 114, 114), 3478)), + Some(0.02), + ); + let same_country_same_asn = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(223, 5, 5, 5), 3478)), + Some(0.9), + ); + + let sorted = sort_lookup_records_with_geo( + vec![ + different_country, + same_country_different_asn, + same_country_same_asn, + ], + source_ip, + &geo, + ); + + let ordered_ips = sorted + .iter() + .map(|record| { + lookup_endpoint(&record.dns) + .expect("record should decode") + .0 + .ip() + }) + .collect::>(); + + assert_eq!( + ordered_ips, + vec![ + IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5)), + IpAddr::V4(Ipv4Addr::new(114, 114, 114, 114)), + IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), + ] + ); +} + +#[test] +fn sort_lookup_records_with_geo_prefers_same_country_over_lower_load_on_real_ips() { + let geo = fixture_geo_resolver(); + let source_ip = Some(IpAddr::V4(Ipv4Addr::new(114, 114, 114, 114))); + + let different_country_low_load = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(80, 80, 80, 80), 3478)), + Some(0.01), + ); + let same_country_higher_load = lookup_record( + "stun.example.com", + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(223, 5, 5, 5), 3478)), + Some(0.9), + ); + + let sorted = sort_lookup_records_with_geo( + vec![different_country_low_load, same_country_higher_load.clone()], + source_ip, + &geo, + ); + + let (endpoint, _) = lookup_endpoint(&sorted[0].dns).expect("sorted record should decode"); + assert_eq!(endpoint.ip(), IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5))); +} + +#[test] +fn build_geo_sort_key_ignores_city_distance_when_accuracy_is_too_large() { + let geo = fixture_geo_resolver(); + let source_traits = GeoTraits { + country: Some("CN".to_string()), + city: Some("Beijing".to_string()), + asn: Some(64512), + point: Some(GeoPoint { + latitude: 39.9, + longitude: 116.4, + accuracy_radius_km: 500, + }), + }; + let endpoint_traits = GeoTraits { + country: Some("CN".to_string()), + city: Some("Shanghai".to_string()), + asn: Some(64512), + point: Some(GeoPoint { + latitude: 31.2, + longitude: 121.5, + accuracy_radius_km: 10, + }), + }; + + let key = build_geo_sort_key( + Some(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))), + Some(&source_traits), + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 4, 4), 3478)), + Some(0.2), + &endpoint_traits, + &geo, + ); + + assert!(key.same_country); + assert!(key.same_asn); + assert!(!key.same_city); + assert_eq!(key.geo_distance, None); +} + +#[test] +fn build_geo_sort_key_prefers_same_city_when_available() { + let geo = fixture_geo_resolver(); + let source_traits = GeoTraits { + country: Some("CN".to_string()), + city: Some("Hangzhou".to_string()), + asn: Some(64512), + point: None, + }; + let same_city_traits = GeoTraits { + country: Some("CN".to_string()), + city: Some("Hangzhou".to_string()), + asn: Some(64513), + point: None, + }; + let different_city_traits = GeoTraits { + country: Some("CN".to_string()), + city: Some("Shanghai".to_string()), + asn: Some(64513), + point: None, + }; + + let same_city_key = build_geo_sort_key( + Some(IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5))), + Some(&source_traits), + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(223, 6, 6, 6), 3478)), + Some(0.9), + &same_city_traits, + &geo, + ); + let different_city_key = build_geo_sort_key( + Some(IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5))), + Some(&source_traits), + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(114, 114, 114, 114), 3478)), + Some(0.1), + &different_city_traits, + &geo, + ); + + assert!(same_city_key.same_city); + assert!(!different_city_key.same_city); + assert_eq!( + compare_geo_sort_keys(same_city_key, different_city_key), + Ordering::Less + ); +} diff --git a/src/bin/ddns-server/main.rs b/src/bin/ddns-server/main.rs index 799c763..1a4264a 100644 --- a/src/bin/ddns-server/main.rs +++ b/src/bin/ddns-server/main.rs @@ -379,7 +379,7 @@ async fn main() -> Result<(), Box> { #[cfg(test)] mod tests { - use std::{net::SocketAddr, path::PathBuf}; + use std::path::PathBuf; use super::*; use crate::config::Config; @@ -389,7 +389,7 @@ mod tests { redis_write_url: None, redis_read_url: None, host_allowlist: Config::default_host_allowlist(), - listen: Config::default_listen(), + binds: Config::default_binds(), server_name: Config::default_server_name(), cert: Config::default_cert(), key: Config::default_key(), @@ -407,12 +407,12 @@ mod tests { } #[test] - fn unspecified_ipv4_listen_uses_dual_stack_wildcard() { - let listen: SocketAddr = "0.0.0.0:4433".parse().unwrap(); - let patterns = bind_patterns_for_listen(listen); + fn default_binds_include_ipv4_and_ipv6_wildcards() { + let patterns = Config::default_binds(); - assert_eq!(patterns.len(), 1); - assert_eq!(patterns[0].to_string(), "inet://[::]:4433"); + assert_eq!(patterns.len(), 2); + assert_eq!(patterns[0].to_string(), "inet://0.0.0.0:4433"); + assert_eq!(patterns[1].to_string(), "inet://[::]:4433"); } #[test] diff --git a/src/bin/ddns-server/ocsp.rs b/src/bin/ddns-server/ocsp.rs index 2ce8969..8132c47 100644 --- a/src/bin/ddns-server/ocsp.rs +++ b/src/bin/ddns-server/ocsp.rs @@ -1,28 +1,14 @@ use std::{io, path::Path, time::Duration}; -use der::{ - Choice, Decode, Encode, Enumerated, Sequence, - asn1::{Any, GeneralizedTime, Null, ObjectIdentifier, OctetString}, - oid::db::{rfc5912::ID_SHA_1, rfc6960::ID_PKIX_OCSP_BASIC}, -}; +use dhttp_identity::ocsp::{OcspStatus, build_ocsp_request_der, verify_stapled_ocsp_response}; use h3x::dquic::QuicEndpoint; use reqwest::{ Url, header::{ACCEPT, CONTENT_TYPE}, }; use rustls::pki_types::{CertificateDer, UnixTime}; -use sha1::{Digest, Sha1}; use tokio::time::sleep; use tracing::{info, warn}; -use x509_cert::{ - Certificate, ext::Extensions, serial_number::SerialNumber, spki::AlgorithmIdentifierOwned, -}; -use x509_parser::{ - asn1_rs::{BitString as X509BitString, FromDer as X509FromDer}, - parse_x509_certificate, - verify::verify_signature as verify_x509_signature, - x509::AlgorithmIdentifier as X509AlgorithmIdentifier, -}; use crate::config::Config; @@ -72,7 +58,7 @@ impl OcspAutoRefresh { pub async fn refresh_once(&self, quic: &mut QuicEndpoint) -> Duration { match self.fetch_response().await { Ok(response_der) => match self.validate_response(&response_der) { - Ok(OcspCertStatus::Good) => { + Ok(OcspStatus::Good) => { let response_len = response_der.len(); quic.update_ocsp(Some(response_der)); info!( @@ -83,7 +69,7 @@ impl OcspAutoRefresh { ); refresh_success_delay() } - Ok(OcspCertStatus::Unknown) => { + Ok(OcspStatus::Unknown) => { warn!( responder_url = %self.responder_url, retry_in_secs = OCSP_REFRESH_RETRY_DELAY.as_secs(), @@ -91,7 +77,7 @@ impl OcspAutoRefresh { ); OCSP_REFRESH_RETRY_DELAY } - Ok(OcspCertStatus::Revoked) => { + Ok(OcspStatus::Revoked) => { warn!( responder_url = %self.responder_url, retry_in_secs = OCSP_REFRESH_RETRY_DELAY.as_secs(), @@ -99,31 +85,6 @@ impl OcspAutoRefresh { ); OCSP_REFRESH_RETRY_DELAY } - Err(ValidateError::ResponderStatus(OcspResponseStatus::Unauthorized)) => { - warn!( - responder_url = %self.responder_url, - retry_in_secs = OCSP_REFRESH_RETRY_DELAY.as_secs(), - "ocsp responder returned unauthorized; skipping staple update" - ); - OCSP_REFRESH_RETRY_DELAY - } - Err(ValidateError::ResponderStatus(OcspResponseStatus::MalformedRequest)) => { - warn!( - responder_url = %self.responder_url, - retry_in_secs = OCSP_REFRESH_RETRY_DELAY.as_secs(), - "ocsp responder returned malformed_request; skipping staple update" - ); - OCSP_REFRESH_RETRY_DELAY - } - Err(ValidateError::ResponderStatus(status)) => { - warn!( - responder_url = %self.responder_url, - ocsp_status = %status.as_str(), - retry_in_secs = OCSP_REFRESH_RETRY_DELAY.as_secs(), - "ocsp responder returned a non-success status; skipping staple update" - ); - OCSP_REFRESH_RETRY_DELAY - } Err(error) => { warn!( error = %error, @@ -183,7 +144,7 @@ impl OcspAutoRefresh { Ok(body.to_vec()) } - fn validate_response(&self, response_der: &[u8]) -> Result { + fn validate_response(&self, response_der: &[u8]) -> Result { verify_stapled_ocsp_response(&self.leaf_der, &self.issuer_der, response_der, now()) } } @@ -208,11 +169,8 @@ fn build_ocsp_request_context( None => load_issuer_certificate(issuer_override)?, }; - let leaf = Certificate::from_der(leaf_der.as_ref()) - .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; - let issuer = Certificate::from_der(issuer_der.as_ref()) + let request_der = build_ocsp_request_der(&leaf_der, &issuer_der) .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; - let request_der = build_request_der(&leaf, &issuer).map_err(io::Error::other)?; Ok((request_der, leaf_der, issuer_der)) } @@ -270,457 +228,3 @@ fn now() -> UnixTime { .unwrap_or_default(); UnixTime::since_unix_epoch(now) } - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum OcspCertStatus { - Good, - Revoked, - Unknown, -} - -#[derive(Debug)] -enum ValidateError { - ResponderStatus(OcspResponseStatus), - Invalid(String), -} - -impl std::fmt::Display for ValidateError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::ResponderStatus(status) => { - write!(f, "OCSP responder returned status {}", status.as_str()) - } - Self::Invalid(message) => f.write_str(message), - } - } -} - -impl std::error::Error for ValidateError {} - -#[derive(Debug, Clone)] -struct ParsedOcspResponse { - status: OcspCertStatus, - basic: BasicOcspResponse, -} - -fn verify_stapled_ocsp_response( - end_entity: &CertificateDer<'_>, - issuer: &CertificateDer<'_>, - response_der: &[u8], - now: UnixTime, -) -> Result { - let end_entity_cert = Certificate::from_der(end_entity.as_ref()).map_err(|error| { - ValidateError::Invalid(format!("failed to decode end-entity cert: {error}")) - })?; - let issuer_cert = Certificate::from_der(issuer.as_ref()).map_err(|error| { - ValidateError::Invalid(format!("failed to decode issuer cert: {error}")) - })?; - let parsed = decode_unvalidated_ocsp_response_der(response_der, now)?; - let single = parsed - .basic - .tbs_response_data - .responses - .first() - .expect("single response checked during OCSP decode"); - let expected_cert_id = build_cert_id_local(&end_entity_cert, &issuer_cert)?; - - if !matches_cert_id(&single.cert_id, &expected_cert_id) { - return Err(ValidateError::Invalid( - "OCSP response cert_id does not match the server certificate".to_owned(), - )); - } - - if !responder_id_matches_certificate( - &parsed.basic.tbs_response_data.responder_id, - &issuer_cert, - )? { - return Err(ValidateError::Invalid( - "OCSP responder identifier does not match the issuer certificate".to_owned(), - )); - } - - verify_basic_ocsp_signature(&parsed.basic, issuer.as_ref())?; - - Ok(parsed.status) -} - -fn decode_unvalidated_ocsp_response_der( - response_der: &[u8], - now: UnixTime, -) -> Result { - let response = OcspResponse::from_der(response_der).map_err(der_error_string)?; - if response.response_status != OcspResponseStatus::Successful { - return Err(ValidateError::ResponderStatus(response.response_status)); - } - - let response_bytes = response.response_bytes.ok_or_else(|| { - ValidateError::Invalid("OCSP response is missing response bytes".to_owned()) - })?; - if response_bytes.response_type != ID_PKIX_OCSP_BASIC { - return Err(ValidateError::Invalid( - "unsupported OCSP response type".to_owned(), - )); - } - - let basic = BasicOcspResponse::from_der(response_bytes.response.as_bytes()) - .map_err(der_error_string)?; - let [single] = basic.tbs_response_data.responses.as_slice() else { - return Err(ValidateError::Invalid( - "OCSP response must contain exactly one single response".to_owned(), - )); - }; - - let produced_at = as_unix_time(&basic.tbs_response_data.produced_at); - if produced_at.as_secs() > now.as_secs() { - return Err(ValidateError::Invalid( - "OCSP response produced_at is in the future".to_owned(), - )); - } - - let this_update = as_unix_time(&single.this_update); - if this_update.as_secs() > now.as_secs() { - return Err(ValidateError::Invalid( - "OCSP response this_update is in the future".to_owned(), - )); - } - - let valid_until = single - .next_update - .as_ref() - .map(as_unix_time) - .unwrap_or(this_update); - if valid_until.as_secs() < this_update.as_secs() { - return Err(ValidateError::Invalid( - "OCSP response next_update is earlier than this_update".to_owned(), - )); - } - if valid_until.as_secs() < now.as_secs() { - return Err(ValidateError::Invalid( - "OCSP response is already expired".to_owned(), - )); - } - - let status = match &single.cert_status { - CertStatus::Good(_) => OcspCertStatus::Good, - CertStatus::Revoked(_) => OcspCertStatus::Revoked, - CertStatus::Unknown(_) => OcspCertStatus::Unknown, - }; - - Ok(ParsedOcspResponse { status, basic }) -} - -fn verify_basic_ocsp_signature( - basic: &BasicOcspResponse, - signer_der: &[u8], -) -> Result<(), ValidateError> { - let signer = parse_x509_certificate_der(signer_der, "OCSP signer certificate")?; - let signature_algorithm_der = basic - .signature_algorithm - .to_der() - .map_err(der_error_string)?; - let (_, signature_algorithm) = X509AlgorithmIdentifier::from_der(&signature_algorithm_der) - .map_err(|error| { - ValidateError::Invalid(format!( - "failed to parse OCSP response signature algorithm: {error}" - )) - })?; - let signature_der = basic.signature.to_der().map_err(der_error_string)?; - let (_, signature_value) = X509BitString::from_der(&signature_der).map_err(|error| { - ValidateError::Invalid(format!( - "failed to parse OCSP response signature value: {error}" - )) - })?; - let tbs_der = basic.tbs_response_data.to_der().map_err(der_error_string)?; - - verify_x509_signature( - signer.public_key(), - &signature_algorithm, - &signature_value, - &tbs_der, - ) - .map_err(|error| { - ValidateError::Invalid(format!("failed to verify OCSP response signature: {error}")) - }) -} - -fn parse_x509_certificate_der<'a>( - cert_der: &'a [u8], - label: &str, -) -> Result, ValidateError> { - parse_x509_certificate(cert_der) - .map(|(_, cert)| cert) - .map_err(|error| ValidateError::Invalid(format!("failed to parse {label}: {error:?}"))) -} - -fn responder_id_matches_certificate( - responder_id: &ResponderId, - certificate: &Certificate, -) -> Result { - match responder_id { - ResponderId::ByName(name) => Ok(name.to_der().map_err(der_error_string)? - == certificate - .tbs_certificate() - .subject() - .to_der() - .map_err(der_error_string)?), - ResponderId::ByKey(key_hash) => Ok(key_hash.as_bytes() - == Sha1::digest( - certificate - .tbs_certificate() - .subject_public_key_info() - .subject_public_key - .raw_bytes(), - ) - .as_slice()), - } -} - -fn build_cert_id_local( - end_entity: &Certificate, - issuer: &Certificate, -) -> Result { - let issuer_name_hash = Sha1::digest( - issuer - .tbs_certificate() - .subject() - .to_der() - .map_err(der_error_string)?, - ); - let issuer_key_hash = Sha1::digest( - issuer - .tbs_certificate() - .subject_public_key_info() - .subject_public_key - .raw_bytes(), - ); - - Ok(CertId { - hash_algorithm: AlgorithmIdentifierOwned { - oid: ID_SHA_1, - parameters: Some(Null.into()), - }, - issuer_name_hash: OctetString::new(issuer_name_hash.as_slice()) - .map_err(der_error_string)?, - issuer_key_hash: OctetString::new(issuer_key_hash.as_slice()).map_err(der_error_string)?, - serial_number: end_entity.tbs_certificate().serial_number().clone(), - }) -} - -fn matches_cert_id(actual: &CertId, expected: &CertId) -> bool { - actual.hash_algorithm.oid == expected.hash_algorithm.oid - && actual.issuer_name_hash == expected.issuer_name_hash - && actual.issuer_key_hash == expected.issuer_key_hash - && actual.serial_number == expected.serial_number -} - -fn build_request_der(end_entity: &Certificate, issuer: &Certificate) -> Result, String> { - OcspRequest { - tbs_request: TbsRequest { - version: Version::default(), - requestor_name: None, - request_list: vec![RequestEntry { - req_cert: build_request_cert_id(end_entity, issuer)?, - single_request_extensions: None, - }], - request_extensions: None, - }, - optional_signature: None, - } - .to_der() - .map_err(der_error) -} - -fn build_request_cert_id(end_entity: &Certificate, issuer: &Certificate) -> Result { - let issuer_name_hash = Sha1::digest( - issuer - .tbs_certificate() - .subject() - .to_der() - .map_err(der_error)?, - ); - let issuer_key_hash = Sha1::digest( - issuer - .tbs_certificate() - .subject_public_key_info() - .subject_public_key - .raw_bytes(), - ); - - Ok(CertId { - hash_algorithm: AlgorithmIdentifierOwned { - oid: ID_SHA_1, - parameters: Some(Null.into()), - }, - issuer_name_hash: OctetString::new(issuer_name_hash.as_slice()).map_err(der_error)?, - issuer_key_hash: OctetString::new(issuer_key_hash.as_slice()).map_err(der_error)?, - serial_number: end_entity.tbs_certificate().serial_number().clone(), - }) -} - -fn as_unix_time(time: &GeneralizedTime) -> UnixTime { - UnixTime::since_unix_epoch(time.to_unix_duration()) -} - -fn der_error(error: impl std::fmt::Display) -> String { - format!("failed to process OCSP DER: {error}") -} - -fn der_error_string(error: impl std::fmt::Display) -> ValidateError { - ValidateError::Invalid(der_error(error)) -} - -#[derive(Clone, Debug, Default, Copy, PartialEq, Eq, Enumerated)] -#[asn1(type = "INTEGER")] -#[repr(u8)] -enum Version { - #[default] - V1 = 0, -} - -#[derive(Clone, Debug, Eq, PartialEq, Sequence)] -struct OcspRequest { - tbs_request: TbsRequest, - - #[asn1(context_specific = "0", optional = "true", tag_mode = "EXPLICIT")] - optional_signature: Option, -} - -#[derive(Clone, Debug, Eq, PartialEq, Sequence)] -struct TbsRequest { - #[asn1( - context_specific = "0", - default = "Default::default", - tag_mode = "EXPLICIT" - )] - version: Version, - - #[asn1(context_specific = "1", optional = "true", tag_mode = "EXPLICIT")] - requestor_name: Option, - - request_list: Vec, - - #[asn1(context_specific = "2", optional = "true", tag_mode = "EXPLICIT")] - request_extensions: Option, -} - -#[derive(Clone, Debug, Eq, PartialEq, Sequence)] -struct RequestEntry { - req_cert: CertId, - - #[asn1(context_specific = "0", optional = "true", tag_mode = "EXPLICIT")] - single_request_extensions: Option, -} - -#[derive(Clone, Debug, Eq, PartialEq, Sequence)] -struct CertId { - hash_algorithm: AlgorithmIdentifierOwned, - issuer_name_hash: OctetString, - issuer_key_hash: OctetString, - serial_number: SerialNumber, -} - -#[derive(Clone, Debug, Eq, PartialEq, Choice)] -enum CertStatus { - #[asn1(context_specific = "0", tag_mode = "IMPLICIT")] - Good(Null), - - #[asn1(context_specific = "1", tag_mode = "IMPLICIT", constructed = "true")] - Revoked(RevokedInfo), - - #[asn1(context_specific = "2", tag_mode = "IMPLICIT")] - Unknown(Null), -} - -#[derive(Clone, Debug, Eq, PartialEq, Sequence)] -struct RevokedInfo { - revocation_time: GeneralizedTime, - - #[asn1(context_specific = "0", optional = "true", tag_mode = "EXPLICIT")] - revocation_reason: Option, -} - -#[derive(Clone, Debug, Eq, PartialEq, Sequence)] -struct SingleResponse { - cert_id: CertId, - cert_status: CertStatus, - this_update: GeneralizedTime, - - #[asn1(context_specific = "0", optional = "true", tag_mode = "EXPLICIT")] - next_update: Option, - - #[asn1(context_specific = "1", optional = "true", tag_mode = "EXPLICIT")] - single_extensions: Option, -} - -#[derive(Clone, Debug, Eq, PartialEq, Choice)] -enum ResponderId { - #[asn1(context_specific = "1", tag_mode = "EXPLICIT", constructed = "true")] - ByName(Any), - - #[asn1(context_specific = "2", tag_mode = "EXPLICIT", constructed = "true")] - ByKey(OctetString), -} - -#[derive(Clone, Debug, Eq, PartialEq, Sequence)] -struct ResponseData { - #[asn1( - context_specific = "0", - default = "Default::default", - tag_mode = "EXPLICIT" - )] - version: Version, - responder_id: ResponderId, - produced_at: GeneralizedTime, - responses: Vec, - - #[asn1(context_specific = "1", optional = "true", tag_mode = "EXPLICIT")] - response_extensions: Option, -} - -#[derive(Clone, Debug, Eq, PartialEq, Sequence)] -struct BasicOcspResponse { - tbs_response_data: ResponseData, - signature_algorithm: AlgorithmIdentifierOwned, - signature: der::asn1::BitString, - - #[asn1(context_specific = "0", optional = "true", tag_mode = "EXPLICIT")] - certs: Option>, -} - -#[derive(Clone, Debug, Eq, PartialEq, Sequence)] -struct ResponseBytes { - response_type: ObjectIdentifier, - response: OctetString, -} - -#[derive(Enumerated, Copy, Clone, Debug, Eq, PartialEq)] -#[repr(u32)] -enum OcspResponseStatus { - Successful = 0, - MalformedRequest = 1, - InternalError = 2, - TryLater = 3, - SigRequired = 5, - Unauthorized = 6, -} - -impl OcspResponseStatus { - fn as_str(self) -> &'static str { - match self { - Self::Successful => "successful", - Self::MalformedRequest => "malformed_request", - Self::InternalError => "internal_error", - Self::TryLater => "try_later", - Self::SigRequired => "sig_required", - Self::Unauthorized => "unauthorized", - } - } -} - -#[derive(Clone, Debug, Eq, PartialEq, Sequence)] -struct OcspResponse { - response_status: OcspResponseStatus, - - #[asn1(context_specific = "0", optional = "true", tag_mode = "EXPLICIT")] - response_bytes: Option, -} diff --git a/src/bin/ddns-server/policy.rs b/src/bin/ddns-server/policy.rs index 4340209..d2a2d97 100644 --- a/src/bin/ddns-server/policy.rs +++ b/src/bin/ddns-server/policy.rs @@ -107,8 +107,7 @@ pub fn validate_dns_packet( } debug!( answers = dns_packet.answers.len(), - require_signature, - "validating dns packet" + require_signature, "validating dns packet" ); if require_signature { diff --git a/src/bin/ddns-server/publish/http.rs b/src/bin/ddns-server/publish/http.rs new file mode 100644 index 0000000..d76946a --- /dev/null +++ b/src/bin/ddns-server/publish/http.rs @@ -0,0 +1,163 @@ +use std::{convert::Infallible, sync::Arc}; + +use ddns::core::signature::{ + CONTENT_DIGEST_HEADER, SIGNATURE_HEADER, SIGNATURE_INPUT_HEADER, SignatureFields, +}; +use h3x::{connection::ConnectionState, quic}; +use http_body_util::BodyExt; +use tracing::{debug, warn}; + +use super::store::{clear_record, publish_record}; +use crate::{ + error::{AppError, normalize_host, parse_query_params}, + lookup::{Request, Response, write_error}, + policy::{DomainPolicy, ValidatedDnsPacket, client_allowed_host, validate_dns_packet}, + storage::AppState, +}; + +// --------------------------------------------------------------------------- +// PublishSvc +// --------------------------------------------------------------------------- + +#[derive(Clone)] +pub struct PublishSvc { + pub state: AppState, +} + +impl PublishSvc { + pub fn call( + &self, + request: Request, + ) -> impl Future> + Send + 'static { + let state = self.state.clone(); + async move { Ok(publish_with_cert(state, request).await) } + } +} + +async fn publish_with_cert(state: AppState, request: Request) -> Response { + debug!("received publish request"); + + let params = parse_query_params(request.uri()); + debug!("query params: {:?}", params); + + let Some(host) = params.get("host") else { + warn!("missing host parameter"); + return write_error(AppError::MissingHostParam); + }; + + let host = match normalize_host(host, state.host_allowlist.as_ref()) { + Ok(h) => h, + Err(e) => return write_error(e), + }; + debug!(host = %host, "publish.host"); + + // Require a valid client certificate for all publish requests. + let authority = match request_connection(&request) { + Some(connection) => match connection.remote_authority().await { + Ok(Some(authority)) => authority, + Ok(None) => { + warn!("missing client certificate"); + return write_error(AppError::MissingClientCertificate); + } + Err(error) => { + warn!(error = %snafu::Report::from_error(&error), "failed to read client certificate"); + return write_error(AppError::MissingClientCertificate); + } + }, + None => { + warn!("missing client certificate"); + return write_error(AppError::MissingClientCertificate); + } + }; + + let policy = state.policies.policy_for(&host).clone(); + + // Standard policy: cert SAN must match the target host. + // OpenMulti policy: any authenticated node may publish — skip SAN check. + if policy == DomainPolicy::Standard { + let allowed = match client_allowed_host(authority.as_ref(), state.host_allowlist.as_ref()) { + Ok(h) => h, + Err(e) => { + warn!(error = %snafu::Report::from_error(&e), "client certificate domain not allowed"); + return write_error(e); + } + }; + if allowed != host { + warn!(allowed = %allowed, requested = %host, "publish.host_mismatch"); + return write_error(AppError::HostMismatch); + } + } + + let signature_fields = signature_fields_from_headers(request.headers()); + + let body = match request.into_body().collect().await { + Ok(body) => body.to_bytes(), + Err(e) => { + warn!(error = %snafu::Report::from_error(&e), "failed to read request body"); + return write_error(AppError::InvalidDnsPacket { + message: e.to_string(), + }); + } + }; + + // Validate DNS packet; signature check only for Standard hosts. + let require_sig = policy == DomainPolicy::Standard && state.require_signature; + debug!( + host = %host, + bytes = body.len(), + require_signature = require_sig, + "validating publish packet" + ); + let packet = match validate_dns_packet( + body.as_ref(), + require_sig, + authority.as_ref(), + &signature_fields, + state.host_allowlist.as_ref(), + &host, + ) { + Ok(n) => n, + Err(e) => { + debug!(host = %host, error = %e, "publish packet rejected"); + return write_error(e); + } + }; + + match packet { + ValidatedDnsPacket::Records { host: packet_name } => { + let packet_host = match normalize_host(&packet_name, state.host_allowlist.as_ref()) { + Ok(h) => h, + Err(e) => return write_error(e), + }; + + if packet_host != host { + return write_error(AppError::HostMismatch); + } + + publish_record(&state, &host, &body, authority.as_ref(), signature_fields).await + } + ValidatedDnsPacket::Empty => clear_record(&state, &host, authority.as_ref()).await, + } +} + +fn signature_fields_from_headers(headers: &http::HeaderMap) -> SignatureFields { + let header = |name: &'static str| { + headers + .get(name) + .map(|value| value.as_bytes().to_vec()) + .unwrap_or_default() + }; + + SignatureFields { + content_digest: header(CONTENT_DIGEST_HEADER), + signature_input: header(SIGNATURE_INPUT_HEADER), + signature: header(SIGNATURE_HEADER), + } +} + +fn request_connection(request: &Request) -> Option>> { + request + .extensions() + .get::>>() + .cloned() +} diff --git a/src/bin/ddns-server/publish/mod.rs b/src/bin/ddns-server/publish/mod.rs new file mode 100644 index 0000000..73cba14 --- /dev/null +++ b/src/bin/ddns-server/publish/mod.rs @@ -0,0 +1,7 @@ +mod http; +mod store; + +pub use http::PublishSvc; + +#[cfg(test)] +mod tests; diff --git a/src/bin/ddns-server/publish.rs b/src/bin/ddns-server/publish/store.rs similarity index 50% rename from src/bin/ddns-server/publish.rs rename to src/bin/ddns-server/publish/store.rs index 2d272be..cdc4620 100644 --- a/src/bin/ddns-server/publish.rs +++ b/src/bin/ddns-server/publish/store.rs @@ -1,19 +1,14 @@ -use std::{collections::HashSet, convert::Infallible, sync::Arc}; +use std::collections::HashSet; -use ddns::core::signature::{ - CONTENT_DIGEST_HEADER, SIGNATURE_HEADER, SIGNATURE_INPUT_HEADER, SignatureFields, -}; +use ddns::core::signature::SignatureFields; use deadpool_redis::redis::{self, AsyncCommands}; use dhttp_identity::identity::RemoteAuthority; -use h3x::{connection::ConnectionState, quic}; -use http_body_util::BodyExt; use tokio::time::{Duration, Instant}; -use tracing::{debug, info, warn}; +use tracing::info; use crate::{ - error::{AppError, normalize_host, parse_query_params}, - lookup::{Request, Response, body_response, write_error}, - policy::{DomainPolicy, ValidatedDnsPacket, client_allowed_host, validate_dns_packet}, + error::AppError, + lookup::{Response, body_response, write_error}, storage::{ AppState, Record, Storage, StoredRecord, cert_fingerprint, cert_fingerprint_hex, record_index_tags, redis_all_index_key, redis_asn_index_key, redis_country_index_key, @@ -21,153 +16,6 @@ use crate::{ }, }; -// --------------------------------------------------------------------------- -// PublishSvc -// --------------------------------------------------------------------------- - -#[derive(Clone)] -pub struct PublishSvc { - pub state: AppState, -} - -impl PublishSvc { - pub fn call( - &self, - request: Request, - ) -> impl Future> + Send + 'static { - let state = self.state.clone(); - async move { Ok(publish_with_cert(state, request).await) } - } -} - -async fn publish_with_cert(state: AppState, request: Request) -> Response { - debug!("received publish request"); - - let params = parse_query_params(request.uri()); - debug!("query params: {:?}", params); - - let Some(host) = params.get("host") else { - warn!("missing host parameter"); - return write_error(AppError::MissingHostParam); - }; - - let host = match normalize_host(host, state.host_allowlist.as_ref()) { - Ok(h) => h, - Err(e) => return write_error(e), - }; - debug!(host = %host, "publish.host"); - - // Require a valid client certificate for all publish requests. - let authority = match request_connection(&request) { - Some(connection) => match connection.remote_authority().await { - Ok(Some(authority)) => authority, - Ok(None) => { - warn!("missing client certificate"); - return write_error(AppError::MissingClientCertificate); - } - Err(error) => { - warn!(error = %snafu::Report::from_error(&error), "failed to read client certificate"); - return write_error(AppError::MissingClientCertificate); - } - }, - None => { - warn!("missing client certificate"); - return write_error(AppError::MissingClientCertificate); - } - }; - - let policy = state.policies.policy_for(&host).clone(); - - // Standard policy: cert SAN must match the target host. - // OpenMulti policy: any authenticated node may publish — skip SAN check. - if policy == DomainPolicy::Standard { - let allowed = match client_allowed_host(authority.as_ref(), state.host_allowlist.as_ref()) { - Ok(h) => h, - Err(e) => { - warn!(error = %snafu::Report::from_error(&e), "client certificate domain not allowed"); - return write_error(e); - } - }; - if allowed != host { - warn!(allowed = %allowed, requested = %host, "publish.host_mismatch"); - return write_error(AppError::HostMismatch); - } - } - - let signature_fields = signature_fields_from_headers(request.headers()); - - let body = match request.into_body().collect().await { - Ok(body) => body.to_bytes(), - Err(e) => { - warn!(error = %snafu::Report::from_error(&e), "failed to read request body"); - return write_error(AppError::InvalidDnsPacket { - message: e.to_string(), - }); - } - }; - - // Validate DNS packet; signature check only for Standard hosts. - let require_sig = policy == DomainPolicy::Standard && state.require_signature; - debug!( - host = %host, - bytes = body.len(), - require_signature = require_sig, - "validating publish packet" - ); - let packet = match validate_dns_packet( - body.as_ref(), - require_sig, - authority.as_ref(), - &signature_fields, - state.host_allowlist.as_ref(), - &host, - ) { - Ok(n) => n, - Err(e) => { - debug!(host = %host, error = %e, "publish packet rejected"); - return write_error(e); - } - }; - - match packet { - ValidatedDnsPacket::Records { host: packet_name } => { - let packet_host = match normalize_host(&packet_name, state.host_allowlist.as_ref()) { - Ok(h) => h, - Err(e) => return write_error(e), - }; - - if packet_host != host { - return write_error(AppError::HostMismatch); - } - - publish_record(&state, &host, &body, authority.as_ref(), signature_fields).await - } - ValidatedDnsPacket::Empty => clear_record(&state, &host, authority.as_ref()).await, - } -} - -fn signature_fields_from_headers(headers: &http::HeaderMap) -> SignatureFields { - let header = |name: &'static str| { - headers - .get(name) - .map(|value| value.as_bytes().to_vec()) - .unwrap_or_default() - }; - - SignatureFields { - content_digest: header(CONTENT_DIGEST_HEADER), - signature_input: header(SIGNATURE_INPUT_HEADER), - signature: header(SIGNATURE_HEADER), - } -} - -fn request_connection(request: &Request) -> Option>> { - request - .extensions() - .get::>>() - .cloned() -} - async fn trim_expired_index_keys( conn: &mut C, keys: impl IntoIterator, @@ -400,133 +248,3 @@ pub async fn clear_record( info!(host = %host, fp = %fp_hex, "publish.clear"); body_response(http::StatusCode::OK, bytes::Bytes::from_static(b"OK")) } - -#[cfg(test)] -mod tests { - use std::{ - collections::HashMap, - net::{Ipv4Addr, SocketAddrV4}, - sync::Arc, - }; - - use ddns::core::{MdnsPacket, parser::record::endpoint::EndpointAddr}; - use dhttp_identity::identity::RemoteAuthority; - use rustls::pki_types::CertificateDer; - - use super::*; - use crate::{ - lookup::{LookupResult, perform_lookup}, - policy::DomainPolicies, - storage::{MemoryStorage, SeedRecords}, - }; - - #[derive(Debug)] - struct TestAuthority { - name: &'static str, - certs: Vec>, - } - - impl TestAuthority { - fn new(name: &'static str, cert_bytes: Vec) -> Self { - Self { - name, - certs: vec![CertificateDer::from(cert_bytes)], - } - } - } - - impl RemoteAuthority for TestAuthority { - fn name(&self) -> &str { - self.name - } - - fn cert_chain(&self) -> &[CertificateDer<'static>] { - &self.certs - } - } - - fn memory_state() -> AppState { - AppState { - storage: Storage::Memory(MemoryStorage::new()), - host_allowlist: Arc::new(vec!["genmeta.net".to_string()]), - require_signature: true, - ttl_secs: 30, - policies: Arc::new(DomainPolicies::default()), - seed_records: SeedRecords::default(), - geo: None, - } - } - - fn packet_for(host: &str, last_octet: u8) -> bytes::Bytes { - let endpoint = EndpointAddr::direct_v4(SocketAddrV4::new( - Ipv4Addr::new(203, 0, 113, last_octet), - 4433, - )); - let mut hosts = HashMap::new(); - hosts.insert(host.to_owned(), vec![endpoint]); - bytes::Bytes::from(MdnsPacket::answer(0, &hosts).to_bytes()) - } - - #[tokio::test] - async fn clear_record_removes_only_current_certificate_fingerprint() { - let state = memory_state(); - let host = "reimu.pilot.dhttp.net"; - let authority_a = TestAuthority::new("authority-a", vec![1]); - let authority_b = TestAuthority::new("authority-b", vec![2]); - let packet_a = packet_for(host, 1); - let packet_b = packet_for(host, 2); - - assert_eq!( - publish_record( - &state, - host, - &packet_a, - &authority_a, - SignatureFields::empty() - ) - .await - .status(), - http::StatusCode::OK - ); - assert_eq!( - publish_record( - &state, - host, - &packet_b, - &authority_b, - SignatureFields::empty() - ) - .await - .status(), - http::StatusCode::OK - ); - - assert_eq!( - clear_record(&state, host, &authority_a).await.status(), - http::StatusCode::OK - ); - - let LookupResult::Multi(response) = perform_lookup(&state, host, None, None).await.unwrap() - else { - panic!("authority b record should remain"); - }; - assert_eq!(response.records.len(), 1); - assert_eq!(response.records[0].cert, authority_b.certs[0].as_ref()); - } - - #[tokio::test] - async fn clear_record_is_idempotent_for_missing_fingerprint() { - let state = memory_state(); - let host = "reimu.pilot.dhttp.net"; - let authority = TestAuthority::new("authority", vec![1]); - - assert_eq!( - clear_record(&state, host, &authority).await.status(), - http::StatusCode::OK - ); - assert!(matches!( - perform_lookup(&state, host, None, None).await.unwrap(), - LookupResult::NotFound - )); - } -} diff --git a/src/bin/ddns-server/publish/tests.rs b/src/bin/ddns-server/publish/tests.rs new file mode 100644 index 0000000..a22a279 --- /dev/null +++ b/src/bin/ddns-server/publish/tests.rs @@ -0,0 +1,126 @@ +use std::{ + collections::HashMap, + net::{Ipv4Addr, SocketAddrV4}, + sync::Arc, +}; + +use ddns::core::{MdnsPacket, parser::record::endpoint::EndpointAddr, signature::SignatureFields}; +use dhttp_identity::identity::RemoteAuthority; +use rustls::pki_types::CertificateDer; + +use super::store::{clear_record, publish_record}; +use crate::{ + lookup::query::{LookupResult, perform_lookup}, + policy::DomainPolicies, + storage::{AppState, MemoryStorage, SeedRecords, Storage}, +}; + +#[derive(Debug)] +struct TestAuthority { + name: &'static str, + certs: Vec>, +} + +impl TestAuthority { + fn new(name: &'static str, cert_bytes: Vec) -> Self { + Self { + name, + certs: vec![CertificateDer::from(cert_bytes)], + } + } +} + +impl RemoteAuthority for TestAuthority { + fn name(&self) -> &str { + self.name + } + + fn cert_chain(&self) -> &[CertificateDer<'static>] { + &self.certs + } +} + +fn memory_state() -> AppState { + AppState { + storage: Storage::Memory(MemoryStorage::new()), + host_allowlist: Arc::new(vec!["genmeta.net".to_string(), "dhttp.net".to_string()]), + require_signature: true, + ttl_secs: 30, + policies: Arc::new(DomainPolicies::default()), + seed_records: SeedRecords::default(), + geo: None, + } +} + +fn packet_for(host: &str, last_octet: u8) -> bytes::Bytes { + let endpoint = EndpointAddr::direct_v4(SocketAddrV4::new( + Ipv4Addr::new(203, 0, 113, last_octet), + 4433, + )); + let mut hosts = HashMap::new(); + hosts.insert(host.to_owned(), vec![endpoint]); + bytes::Bytes::from(MdnsPacket::answer(0, &hosts).to_bytes()) +} + +#[tokio::test] +async fn clear_record_removes_only_current_certificate_fingerprint() { + let state = memory_state(); + let host = "reimu.pilot.dhttp.net"; + let authority_a = TestAuthority::new("authority-a", vec![1]); + let authority_b = TestAuthority::new("authority-b", vec![2]); + let packet_a = packet_for(host, 1); + let packet_b = packet_for(host, 2); + + assert_eq!( + publish_record( + &state, + host, + &packet_a, + &authority_a, + SignatureFields::empty() + ) + .await + .status(), + http::StatusCode::OK + ); + assert_eq!( + publish_record( + &state, + host, + &packet_b, + &authority_b, + SignatureFields::empty() + ) + .await + .status(), + http::StatusCode::OK + ); + + assert_eq!( + clear_record(&state, host, &authority_a).await.status(), + http::StatusCode::OK + ); + + let LookupResult::Multi(response) = perform_lookup(&state, host, None, None).await.unwrap() + else { + panic!("authority b record should remain"); + }; + assert_eq!(response.records.len(), 1); + assert_eq!(response.records[0].cert, authority_b.certs[0].as_ref()); +} + +#[tokio::test] +async fn clear_record_is_idempotent_for_missing_fingerprint() { + let state = memory_state(); + let host = "reimu.pilot.dhttp.net"; + let authority = TestAuthority::new("authority", vec![1]); + + assert_eq!( + clear_record(&state, host, &authority).await.status(), + http::StatusCode::OK + ); + assert!(matches!( + perform_lookup(&state, host, None, None).await.unwrap(), + LookupResult::NotFound + )); +} diff --git a/src/core/parser/sigin.rs b/src/core/parser/sigin.rs index 2ed9c11..2d314ff 100644 --- a/src/core/parser/sigin.rs +++ b/src/core/parser/sigin.rs @@ -214,4 +214,3 @@ pub(crate) fn verify( .is_ok(), ) } - diff --git a/src/publisher.rs b/src/publisher.rs index 11aec10..c8a8cb6 100644 --- a/src/publisher.rs +++ b/src/publisher.rs @@ -9,7 +9,7 @@ use std::{ time::Duration, }; -use dhttp_identity::identity::LocalAuthority; +use dhttp_identity::{identity::LocalAuthority, name::Name}; #[cfg(feature = "mdns-resolver")] use dquic::qbase::net::Family; use dquic::{ @@ -18,7 +18,6 @@ use dquic::{ qresolve::{Publish, Resolve}, qtraversal::nat::client::{ClientLocationData, NatType}, }; -use dhttp_identity::name::Name; use snafu::{ResultExt, Snafu}; use crate::{ @@ -84,10 +83,7 @@ impl PublishAddresses { Self::default() } - pub fn wide_area( - mut self, - endpoints: impl IntoIterator, - ) -> Self { + pub fn wide_area(mut self, endpoints: impl IntoIterator) -> Self { self.wide_area.extend(endpoints); self } diff --git a/src/resolvers/h3.rs b/src/resolvers/h3.rs index e0e6887..bc0a072 100644 --- a/src/resolvers/h3.rs +++ b/src/resolvers/h3.rs @@ -133,7 +133,10 @@ where request: http::Request< impl http_body::Body + Send + 'static, >, - ) -> Result>, Error> { + ) -> Result< + http::Response>, + Error, + > { let authority = request .uri() .authority() @@ -363,10 +366,10 @@ where return Err(Error::ParseMultiResponse); } - let mut addrs = Vec::new(); - for r in multi.records { - if !r.signature_fields.is_empty() { - match r.signature_fields.verify(&r.dns, &r.cert) { + let mut addrs = Vec::new(); + for r in multi.records { + if !r.signature_fields.is_empty() { + match r.signature_fields.verify(&r.dns, &r.cert) { Ok(true) => {} Ok(false) => { tracing::debug!("ignored record with invalid DNS packet signature"); @@ -379,13 +382,13 @@ where } } - let (_remain, packet) = be_packet(&r.dns).map_err(|source| Error::ParseRecords { - source: source.to_owned(), - })?; + let (_remain, packet) = be_packet(&r.dns).map_err(|source| Error::ParseRecords { + source: source.to_owned(), + })?; - addrs.extend( - packet - .answers + addrs.extend( + packet + .answers .iter() .filter_map(|answer| match answer.data() { record::RData::E(ep) => { @@ -405,9 +408,8 @@ where tracing::debug!(?answer, "ignored record"); None } - }), - ); - } + }), + ); } if addrs.is_empty() { diff --git a/src/resolvers/http.rs b/src/resolvers/http.rs index b302738..dae9498 100644 --- a/src/resolvers/http.rs +++ b/src/resolvers/http.rs @@ -221,9 +221,10 @@ impl Resolve for HttpResolver { } } } - let (_remain, packet) = be_packet(&r.dns).map_err(|source| Error::ParseRecords { - source: source.to_owned(), - })?; + let (_remain, packet) = + be_packet(&r.dns).map_err(|source| Error::ParseRecords { + source: source.to_owned(), + })?; addrs.extend( packet diff --git a/src/resolvers/selector.rs b/src/resolvers/selector.rs index 87017f4..1bd0c86 100644 --- a/src/resolvers/selector.rs +++ b/src/resolvers/selector.rs @@ -1,4 +1,3 @@ -use dhttp_identity::certificate::{CertificateChainKey, CertificateChainKind}; use dquic::qbase::net::addr::EndpointAddr as DquicEndpointAddr; use crate::core::parser::record::endpoint::EndpointAddr as DnsEndpointAddr; @@ -15,12 +14,10 @@ pub(crate) fn selected_endpoint_addrs( pub(crate) fn selected_endpoint_records( records: impl IntoIterator, ) -> Vec<(T, DquicEndpointAddr)> { - let mut groups: Vec<(CertificateChainKey, Vec<(T, DquicEndpointAddr)>)> = Vec::new(); + let mut groups: Vec<((bool, u64), Vec<(T, DquicEndpointAddr)>)> = Vec::new(); for (tag, record) in records { - let Ok(selector) = record.certificate_chain_key() else { - continue; - }; + let selector = (record.is_main(), 0); let Ok(endpoint) = DquicEndpointAddr::try_from(record) else { continue; }; @@ -32,20 +29,12 @@ pub(crate) fn selected_endpoint_records( } } - let selected = groups - .iter() - .position(|(key, endpoints)| { - key.kind() == CertificateChainKind::Primary && !endpoints.is_empty() - }) - .or_else(|| { - groups - .iter() - .position(|(_key, endpoints)| !endpoints.is_empty()) - }); - - selected - .map(|index| groups.swap_remove(index).1) - .unwrap_or_default() + groups.sort_by_key(|((is_main, sequence), _)| (!*is_main, *sequence)); + + groups + .into_iter() + .flat_map(|(_, endpoints)| endpoints) + .collect() } #[cfg(test)] From 4befa4be9a964a625d7795d140abc1bee59eaf8f Mon Sep 17 00:00:00 2001 From: metah3m Date: Mon, 15 Jun 2026 22:00:43 +0800 Subject: [PATCH 08/29] feat: add h3x resolver publishing --- src/publisher.rs | 22 +++++++++++++++++++++- src/publisher/dispatch.rs | 12 ++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/publisher.rs b/src/publisher.rs index c8a8cb6..81b2a66 100644 --- a/src/publisher.rs +++ b/src/publisher.rs @@ -18,7 +18,7 @@ use dquic::{ qresolve::{Publish, Resolve}, qtraversal::nat::client::{ClientLocationData, NatType}, }; -use snafu::{ResultExt, Snafu}; +use snafu::{IntoError, ResultExt, Snafu}; use crate::{ core::{ @@ -29,6 +29,11 @@ use crate::{ resolvers::Resolvers, }; +#[cfg(feature = "h3x-resolver")] +type DeferredH3Resolver = crate::resolvers::deferred::DeferredResolver< + crate::resolvers::h3::H3Resolver, +>; + pub const DEFAULT_PUBLISH_INTERVAL: Duration = Duration::from_secs(20); /// Upper bound for a single publish attempt in the background loop. /// @@ -438,6 +443,21 @@ impl Publisher { return Ok(true); } + #[cfg(feature = "h3x-resolver")] + if let Some(h3) = any.downcast_ref::() { + let Some(h3) = h3.get() else { + return Err(publish_once_error::PublishSnafu { + publisher: h3.to_string(), + } + .into_error(io::Error::other( + "deferred h3 resolver has not been initialized", + ))); + }; + self.publish_signed_h3_endpoints(h3, public_endpoints) + .await?; + return Ok(true); + } + #[cfg(feature = "mdns-resolver")] if let Some(mdns) = any.downcast_ref::() { let mut published = false; diff --git a/src/publisher/dispatch.rs b/src/publisher/dispatch.rs index 07d4db6..84693ff 100644 --- a/src/publisher/dispatch.rs +++ b/src/publisher/dispatch.rs @@ -13,6 +13,11 @@ use super::{ }; use crate::resolvers::Resolvers; +#[cfg(feature = "h3x-resolver")] +type DeferredH3Resolver = crate::resolvers::deferred::DeferredResolver< + crate::resolvers::h3::H3Resolver, +>; + impl Publisher where A: LocalAuthority + Send + Sync + ?Sized, @@ -80,6 +85,13 @@ where return Ok(true); } + #[cfg(feature = "h3x-resolver")] + if let Some(h3) = any.downcast_ref::() { + self.publish_selected(h3, name, addresses, AddressSelector::WideArea) + .await?; + return Ok(true); + } + #[cfg(feature = "mdns-resolver")] if let Some(mdns) = any.downcast_ref::() { let mut published = false; From 2228f909a391a38b6a8c338c9fde3c22d81c489c Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 15:59:14 +0800 Subject: [PATCH 09/29] refactor: split ddns backends from facade modules --- Cargo.toml | 67 +--- README.md | 216 ++--------- examples/README.md | 60 +--- examples/publish.rs | 3 +- src/core/parser/record/endpoint.rs | 4 + src/core/parser/sigin.rs | 1 + src/h3.rs | 557 +++++++++++++++++++++++++++++ src/http.rs | 271 ++++++++++++++ src/lib.rs | 8 +- src/mdns.rs | 357 +++++++++++++++++- src/publishers.rs | 342 ++++++++++++++++++ src/publishers/address.rs | 461 ++++++++++++++++++++++++ src/publishers/dispatch.rs | 203 +++++++++++ src/publishers/packet.rs | 87 +++++ src/resolvers.rs | 89 +++-- src/resolvers/selector.rs | 13 +- tests/feature_surface.rs | 58 +++ tests/h3_generic_surface.rs | 11 + tests/publishers_surface.rs | 31 ++ 19 files changed, 2516 insertions(+), 323 deletions(-) create mode 100644 src/h3.rs create mode 100644 src/http.rs create mode 100644 src/publishers.rs create mode 100644 src/publishers/address.rs create mode 100644 src/publishers/dispatch.rs create mode 100644 src/publishers/packet.rs create mode 100644 tests/feature_surface.rs create mode 100644 tests/h3_generic_surface.rs create mode 100644 tests/publishers_surface.rs diff --git a/Cargo.toml b/Cargo.toml index 2bc2943..e0a51af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ readme = "README.md" keywords = ["dhttp", "dns", "mdns", "http3", "quic"] categories = ["network-programming", "asynchronous"] autoexamples = false +autobins = false [lib] name = "ddns" @@ -19,7 +20,7 @@ bitfield-struct = "0.13" bytes = "1" dashmap = "6" dhttp-identity = { path = "../dhttp/identity", version = "0.1.0" } -dquic = { path = "../dquic/dquic", version = "0.5.1" } +dquic = "0.5.1" flume = "0.12" futures = "0.3" libc = "0.2" @@ -60,87 +61,43 @@ reqwest = { version = "0.13", default-features = false, features = [ ], optional = true } url = { version = "2", optional = true } -clap = { version = "4", features = ["derive"], optional = true } -deadpool-redis = { version = "0.23", optional = true } -idna = { version = "1", optional = true } -maxminddb = { version = "0.26", optional = true } -serde = { version = "1", features = ["derive"], optional = true } -toml = { version = "1", optional = true } -tower-service = { version = "0.3", optional = true } -tracing-subscriber = { version = "0.3", features = [ - "env-filter", -], optional = true } - [features] default = [] -h3x-resolver = [ +resolvers = [] +publishers = [] +dquic-network = ["dep:h3x", "h3x/dquic"] +h3 = [ "dep:h3x", - "h3x/dquic", "h3x/hyper", "dep:http", "dep:http-body", "dep:http-body-util", "dep:url", ] -mdns-resolver = ["dep:h3x", "h3x/dquic"] -http-resolver = ["dep:reqwest", "dep:rustls-native-certs"] -server = [ - "h3x-resolver", - "dep:clap", - "dep:deadpool-redis", - "dep:idna", - "dep:maxminddb", - "dep:reqwest", - "dep:serde", - "dep:toml", - "dep:tower-service", - "dep:tracing-subscriber", -] +http = ["dep:reqwest", "dep:rustls-native-certs"] +mdns = [] [dev-dependencies] clap = { version = "4", features = ["derive"] } -h3x = { path = "../h3x", default-features = false, features = [ - "dquic", -] } +h3x = { path = "../h3x", default-features = false, features = ["dquic"] } shellexpand = "3" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -[[bin]] -name = "ddns-server" -path = "src/bin/ddns-server/main.rs" -required-features = ["server"] - [[example]] name = "mdns_discover" path = "examples/mdns_discover.rs" +required-features = ["mdns"] [[example]] name = "mdns_query" path = "examples/mdns_query.rs" +required-features = ["mdns"] [[example]] name = "publish" path = "examples/publish.rs" -required-features = ["h3x-resolver"] +required-features = ["h3"] [[example]] name = "query" path = "examples/query.rs" -required-features = ["h3x-resolver"] - -[patch."https://github.com/genmeta/dquic.git"] -dquic = { path = "../dquic/dquic", version = "0.5.1" } -qbase = { path = "../dquic/qbase", version = "0.5.1" } -qcongestion = { path = "../dquic/qcongestion", version = "0.5.1" } -qconnection = { path = "../dquic/qconnection", version = "0.5.1" } -qdatagram = { path = "../dquic/qdatagram", version = "0.5.1" } -qevent = { path = "../dquic/qevent", version = "0.5.1" } -qinterface = { path = "../dquic/qinterface", version = "0.5.1" } -qmacro = { path = "../dquic/qmacro", version = "0.5.1" } -qrecovery = { path = "../dquic/qrecovery", version = "0.5.1" } -qresolve = { path = "../dquic/qresolve", version = "0.5.1" } -qtraversal = { path = "../dquic/qtraversal", version = "0.5.1" } -qudp = { path = "../dquic/qudp", version = "0.5.1" } - -[patch."https://github.com/genmeta/h3x.git"] -h3x = { path = "../h3x", version = "0.2.0" } diff --git a/README.md b/README.md index 5bbd120..5f64bb6 100644 --- a/README.md +++ b/README.md @@ -1,44 +1,47 @@ # DDNS -`ddns` provides DNS discovery and resolver support for DHTTP applications. It is a -single Rust package: the historical `ddns-core`, `gmdns`, `ddns`, and -`ddns-server` crate boundaries now live as modules and feature-gated targets in -one published Cargo package named `dyns`, with a library target kept as `ddns` -for source compatibility. +`ddns` provides DNS discovery and resolver support for DHTTP applications. +The published Cargo package is `dyns`, and the library target remains `ddns`. -## Crate layout - -| Module / target | Role | -| --- | --- | -| `ddns::core` | DNS packet parser, resource-record types, endpoint `E` record encoding, and HTTP multi-record response wire format. | -| `ddns::mdns` | RFC 6762 multicast DNS transport, LAN publisher, and LAN resolver support. | -| `ddns::resolvers` | Resolver chain plus optional System, mDNS, DNS-over-H3, and DNS-over-HTTP resolvers. | -| `ddns::publisher` | Feature-gated endpoint record signing and publishing loop helpers for DHTTP endpoints. | -| `ddns-server` | DNS-over-H3 publish/lookup server binary, enabled by the `server` feature. | - -`ddns` is endpoint-facing support code for the DHTTP ecosystem. Applications -normally reach it through the `dhttp` endpoint facade; lower-level consumers can -depend on package `dyns` directly (typically renamed locally to `ddns`) when -they need DNS wire types, resolver composition, mDNS, or the DNS-over-H3 server. +`ddns` exposes backend implementations in `ddns::h3`, `ddns::http`, and `ddns::mdns`, +while `ddns::resolvers` and `ddns::publishers` act as facades for re-exports and +aggregate helper types. ```toml ddns = { package = "dyns", version = "0.3.0" } ``` +## Crate layout + +| Module | Role | +| --- | --- | +| `ddns::core` | DNS packet parser, resource-record types, endpoint `E` record encoding, and HTTP multi-record response wire format. | +| `ddns::h3` | DNS-over-HTTP/3 backend implementation. | +| `ddns::http` | DNS-over-HTTP backend implementation. | +| `ddns::mdns` | RFC 6762 multicast DNS transport plus LAN resolver/publisher backend implementation. | +| `ddns::resolvers` | Resolver facade: backend re-exports, resolver chains, and `Resolvers` aggregation. | +| `ddns::publishers` | Publisher facade: backend re-exports, endpoint record signing, and endpoint publication helpers. | + ## Features -All optional integrations are feature-gated; the default feature set is empty. +The default feature set is empty. | Feature | Enables | | --- | --- | -| `h3x-resolver` | DNS-over-H3 resolver and publisher using `h3x`/`dquic`. | -| `mdns-resolver` | mDNS resolver integration backed by an existing `h3x::dquic::Network`. | -| `http-resolver` | DNS-over-HTTP resolver/publisher using `reqwest` and native roots. | -| `server` | `ddns-server`, Redis storage support, TOML config parsing, and tracing setup. | +| `resolvers` | Resolver aggregation types such as `Resolvers`, `ResolversBuilder`, and `DnsScheme`. | +| `publishers` | Endpoint publication aggregation and signing helpers such as `EndpointPublisher`, `EndpointPublicationLoop`, and `PublishAddresses`. | +| `dquic-network` | `h3x`/`dquic` network-backed publication helpers such as `EndpointBindingAddresses`; meaningful together with `publishers`, and also used by mDNS resolver aggregation. | +| `h3` | DNS-over-HTTP/3 backend surface (`ddns::h3`, plus `H3Resolver` / `H3Publisher` re-exports from the facades). | +| `http` | DNS-over-HTTP backend surface (`ddns::http`, plus `HttpResolver` / `HttpPublisher` re-exports from the facades). | +| `mdns` | mDNS backend surface (`ddns::mdns`, plus `MdnsResolver` / `MdnsPublisher` re-exports from the facades). | + +Backend types live under the `resolvers` / `publishers` facades whenever their backend feature is enabled. +The aggregate `Resolvers` and endpoint-publication helper types are separately gated by the +`resolvers` and `publishers` features. ## Bootstrap constants -`build.rs` generates the resolver defaults exposed from `ddns::resolvers`: +`build.rs` generates resolver defaults exposed from `ddns::resolvers`: | Environment variable | Public constant | Fallback when unset | | --- | --- | --- | @@ -46,17 +49,13 @@ All optional integrations are feature-gated; the default feature set is empty. | `DHTTP_HTTP_DNS_SERVER` | `DHTTP_HTTP_DNS_SERVER` | `https://dhttp.example.net` | | `DHTTP_MDNS_SERVICE` | `DHTTP_MDNS_SERVICE` | `dhttp.example.net` | -The fallbacks are docs/build placeholders, not operational defaults. Real -endpoint, server, and E2E runs should set the DHTTP bootstrap environment before -building. +The fallbacks are docs/build placeholders, not operational defaults. ## Quick start ### Resolver chain -`Resolvers` queries all configured resolvers and streams endpoint addresses from -successful backends. System DNS is always available; mDNS, H3, and HTTP builders -appear behind their features. +Enable the resolver aggregation surface and build a chain explicitly: ```rust use ddns::resolvers::Resolvers; @@ -98,21 +97,13 @@ async fn main() -> std::io::Result<()> { } ``` -Runnable examples live in `examples/`: - -```bash -cargo run --example mdns_discover -- --ip 127.0.0.1 --device lo0 -cargo run --example mdns_query -- --ip 192.168.5.156 --device en0 -``` - -### DNS-over-H3 examples +Runnable examples: ```bash -cargo run --example query --features h3x-resolver -- \ - --server-ca /path/to/root.crt \ - --host nat.genmeta.net - -cargo run --example publish --features h3x-resolver -- \ +cargo run --example mdns_discover --features mdns -- --ip 127.0.0.1 --device lo0 +cargo run --example mdns_query --features mdns -- --ip 192.168.5.156 --device en0 +cargo run --example query -- --server-ca /path/to/root.crt --host nat.genmeta.net +cargo run --example publish --features h3 -- \ --server-ca /path/to/root.crt \ --client-name demo.example.dhttp.net \ --client-cert /path/to/demo.example.dhttp.net.pem \ @@ -121,139 +112,4 @@ cargo run --example publish --features h3x-resolver -- \ --addr 192.168.1.100:8080,192.168.1.101:8080 ``` -See [`examples/README.md`](examples/README.md) for the example CLI parameters -and response decoding notes. - -## DNS-over-H3 server - -Start the server with the `server` feature: - -```bash -cargo run --bin ddns-server --features server -- --config server.toml -``` - -When the configured TLS certificate includes its issuer certificate, `ddns-server` -now pulls its own stapled OCSP response from cert-server's public `POST /ocsp` -responder during startup and refreshes it every 2h55m. If the PEM only contains -the leaf certificate, set `ocsp_issuer_cert` in [server.toml](server.toml). You -can override the responder origin with `ocsp_responder_base_url`; by default it -uses `https://license.genmeta.net`. - -The server can optionally enable GEO-aware lookup ordering with local MaxMind -GeoLite2 City and ASN databases. When both `geoip_city_db` and `geoip_asn_db` -are configured, lookups prefer same-country and same-ASN endpoints first, then -fall back to address family, endpoint load, and city-distance tie-breaking for -sufficiently accurate records. - -For AWS deployments, keep QUIC/TLS/mTLS end-to-end in the backend, point -`redis_write_url` at the primary Redis endpoint, `redis_read_url` at a replica, -and set `host_allowlist` to the suffixes you actually serve. See -[docs/aws-deployment.md](docs/aws-deployment.md). - -To update those databases on a server, use [scripts/update-geolite-mmdb.sh](scripts/update-geolite-mmdb.sh). -It wraps `geoipupdate` and downloads both `GeoLite2-City.mmdb` and -`GeoLite2-ASN.mmdb` into one directory: - -```bash -MAXMIND_ACCOUNT_ID=12345 \ -MAXMIND_LICENSE_KEY=your_license_key \ -./scripts/update-geolite-mmdb.sh /etc/ddns -``` - -For detailed parameters and HTTP packet structures, see [examples/README.md](examples/README.md). - -The server exposes two HTTP/3 routes: - -| Route | Meaning | -| --- | --- | -| `POST /publish?host=` | Publish a DNS packet for `host`. Client mTLS is required. | -| `GET /lookup?host=[&limit=N]` | Look up active records for `host`; `limit` caps newest-first dynamic records. | - -Lookup responses use header `x-record-format: multi` and the binary body from -`ddns::core::wire::MultiResponse`: - -```text -u32 count -repeated count times: - u32 dns_len | dns packet bytes | u32 cert_len | DER publisher certificate bytes -``` - -Server configuration lives in `server.toml`: - -- storage is in-memory by default, or Redis when `redis = "redis://..."` is set; -- `ttl_secs` controls dynamic record expiry; -- `require_signature` controls signed endpoint-record enforcement for Standard - domains; -- `domain_policies` are matched in order, with unlisted domains using the - Standard policy; -- `seed_records` add static bootstrap endpoints to lookup results. - -Domain policies: - -| Policy | Behavior | -| --- | --- | -| `standard` | Client certificate DNS SAN must match the published host; signed `E` records are required when `require_signature = true`; each certificate fingerprint owns one active record for the host. | -| `open_multi` | Any authenticated client certificate may publish; signature checks are skipped; multiple certificate fingerprints can coexist and lookup returns newest-first records. | - -Public DHTTP identity hostnames should use the canonical `DhttpName::SUFFIX` -(`.dhttp.net`). Infrastructure names such as `nat.genmeta.net` can remain under -Genmeta infrastructure domains. - -## Endpoint `E` records - -Custom DNS record type `E` (`QTYPE = 266`) carries DHTTP endpoint addresses. The -current wire format is: - -```text -flags(u8) -[sequence(varint) if CLUSTERED] -primary address: port(u16) + IPv4/IPv6 bytes -[agent address if NAT] -[load(f32) if LOAD] -[signature: scheme(u16) + len(varint) + bytes if SIGNED] -``` - -Flag bits: - -| Bit mask | Name | Meaning | -| --- | --- | --- | -| `0x80` | `FAMILY` | `0` = IPv4, `1` = IPv6. | -| `0x40` | `MAIN` | Primary endpoint for the name. | -| `0x20` | `CLUSTERED` | Sequence number is present; multiple publishers share the name. | -| `0x10` | `NAT` | Agent address is present for NAT traversal. | -| `0x08` | `LOAD` | One-minute load value is present. | -| `0x01` | `SIGNED` | Signature with explicit TLS signature scheme is present. | - -For DHTTP endpoint publishing, `MAIN` and `sequence` are derived from the -publisher certificate's DHTTP subject key identifier. Operators do not choose -these fields manually: `primary` certificates publish `MAIN = true`, -`secondary` certificates publish `MAIN = false`, and the certificate-chain -sequence becomes the normalized endpoint-record sequence. An omitted sequence -field means sequence `0`. - -Signed records encode the signature scheme in the record; the no-scheme signed -format is not accepted. Legacy unsigned fixed-length endpoint address records are -still parsed by length for address-only compatibility. - -## Project structure - -```text -src/core.rs DNS core module root -src/core/parser/ DNS packet, name, question, record, varint, and signature parsers -src/core/parser/record/ A/AAAA/SRV/TXT/PTR/CNAME/E record parsing and encoding -src/core/wire.rs HTTP multi-record response wire format -src/mdns.rs mDNS module root -src/mdns/protocol.rs UDP multicast socket and packet routing -src/mdns/service.rs High-level mDNS service API -src/mdns/resolvers/ mDNS resolver integration -src/resolvers.rs Resolver chain and resolver defaults -src/resolvers/h3.rs DNS-over-H3 resolver/publisher -src/resolvers/http.rs DNS-over-HTTP resolver/publisher -src/resolvers/deferred.rs Deferred resolver initialization helper -src/publisher.rs Endpoint record signer and publication loop -src/publisher/ Address selection, publish dispatch, packet signing -src/bin/ddns-server/ DNS-over-H3 server implementation -examples/ mDNS and DNS-over-H3 example programs -server.toml Example server configuration -``` - +See [`examples/README.md`](examples/README.md) for example CLI parameters and response decoding notes. diff --git a/examples/README.md b/examples/README.md index 3c95833..2bba4b6 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,14 +1,14 @@ # DDNS examples -This directory contains runnable examples for the single published `dyns` -package, whose library target remains `ddns`. +This directory contains runnable examples for the published `dyns` package, +whose library target remains `ddns`. | Example | Feature requirement | Purpose | | --- | --- | --- | -| `mdns_discover` | none | Bind an mDNS service, publish sample local hosts, and print multicast packets. | -| `mdns_query` | none | Query a DHTTP name over local mDNS. | -| `query` | `h3x-resolver` | Query a DNS-over-H3 server and decode the multi-record response. | -| `publish` | `h3x-resolver` | Publish signed endpoint `E` records to a DNS-over-H3 server using client mTLS. | +| `mdns_discover` | `mdns` | Bind an mDNS service, publish sample local hosts, and print multicast packets. | +| `mdns_query` | `mdns` | Query a DHTTP name over local mDNS. | +| `query` | none | Query a DNS-over-H3 server and decode the multi-record response. | +| `publish` | `h3` | Publish signed endpoint `E` records to a DNS-over-H3 server using client mTLS. | Run all commands from the `ddns/` repository. @@ -17,7 +17,7 @@ Run all commands from the `ddns/` repository. Bind to a local interface and print multicast traffic: ```bash -cargo run --example mdns_discover -- \ +cargo run --example mdns_discover --features mdns -- \ --ip 127.0.0.1 \ --device lo0 ``` @@ -25,19 +25,18 @@ cargo run --example mdns_discover -- \ Query a name over mDNS: ```bash -cargo run --example mdns_query -- \ +cargo run --example mdns_query --features mdns -- \ --ip 192.168.5.156 \ --device en0 ``` -Replace `--ip` and `--device` with an address and interface that exist on the -local machine. The mDNS service name defaults to the build-time -`DHTTP_MDNS_SERVICE` constant. +Replace `--ip` and `--device` with an address and interface that exist on the local machine. +The mDNS service name defaults to the build-time `DHTTP_MDNS_SERVICE` constant. ## DNS-over-H3 query ```bash -cargo run --example query --features h3x-resolver -- \ +cargo run --example query -- \ --server-ca /path/to/root.crt \ --host nat.genmeta.net ``` @@ -60,21 +59,12 @@ repeated count times: ``` The example prints each DNS packet, the publisher certificate fingerprint when a -certificate is present, and endpoint signature verification status for signed -`E` records. - -After the server starts, it listens for HTTP/3 requests and handles publish and query operations. -If the configured server certificate includes its issuer chain, the process also -fetches and refreshes its own stapled OCSP response from cert-server's public -`/ocsp` endpoint. When the PEM only contains the leaf certificate, configure -`ocsp_issuer_cert` in `server.toml`. The same config file also supports -`redis_write_url`, `redis_read_url`, and `host_allowlist` for AWS-style -primary/replica Redis and domain suffix controls. +certificate is present, and endpoint signature verification status for signed `E` records. ## DNS-over-H3 publish ```bash -cargo run --example publish --features h3x-resolver -- \ +cargo run --example publish --features h3 -- \ --server-ca /path/to/root.crt \ --client-name demo.example.dhttp.net \ --client-cert /path/to/demo.example.dhttp.net.pem \ @@ -93,26 +83,8 @@ Options: | `--client-cert ` | Client certificate chain PEM for mTLS and endpoint signature verification. | | `--client-key ` | Client private key PEM. | | `--sign ` | Whether to sign each endpoint `E` record. Defaults to `true`. | -| `--host ` | DNS host to publish. Standard-policy servers require this to match the client certificate DNS SAN. | +| `--host ` | DNS host to publish. | | `--addr ` | One or more socket addresses to publish. | -The example derives the endpoint selector from the client certificate SKI before -signing records. Use the correct certificate chain instead of manual selector -flags. - -The example sends `POST /publish?host=` with a binary DNS packet body. For -Standard policy domains, the server requires a client certificate whose single -DNS SAN matches `host`; when `require_signature = true`, at least one signed -endpoint record must verify against the publisher certificate. Open-multi policy -domains still require client mTLS but skip the host SAN and endpoint signature -checks. - -## Running the server - -```bash -cargo run --bin ddns-server --features server -- --config server.toml -``` - -`server.toml` documents the available fields: listener, TLS identity, client root -CA, optional Redis storage, TTL, domain policies, and static seed records. - +The example imports `H3Publisher` from the `ddns::publishers` facade, but only needs the +`h3` backend feature because backend publisher types are re-exported from the facade directly. diff --git a/examples/publish.rs b/examples/publish.rs index 8abb1b0..4aa43b4 100644 --- a/examples/publish.rs +++ b/examples/publish.rs @@ -8,7 +8,8 @@ use std::{ use clap::Parser; use ddns::{ core::{parser::record::endpoint::EndpointAddr, signature::SignatureFields}, - resolvers::{DHTTP_H3_DNS_SERVER, h3::H3Publisher}, + publishers::H3Publisher, + resolvers::DHTTP_H3_DNS_SERVER, }; use h3x::dquic::{ Identity, Network, QuicEndpoint, diff --git a/src/core/parser/record/endpoint.rs b/src/core/parser/record/endpoint.rs index cb89f02..5adf7e9 100644 --- a/src/core/parser/record/endpoint.rs +++ b/src/core/parser/record/endpoint.rs @@ -366,6 +366,10 @@ impl EndpointAddr { self.agent } + pub fn sequence(&self) -> Option { + self.sequence.map(Into::into) + } + pub fn set_sequence(&mut self, sequence: u64) { if sequence > 0 { self.sequence = Some(VarInt::from_u64(sequence).expect("Sequence too large")); diff --git a/src/core/parser/sigin.rs b/src/core/parser/sigin.rs index 2d314ff..70c7aa1 100644 --- a/src/core/parser/sigin.rs +++ b/src/core/parser/sigin.rs @@ -147,6 +147,7 @@ pub fn sign_with_key(key: &(impl SigningKey + ?Sized), data: &[u8]) -> Result, ) -> Result { diff --git a/src/h3.rs b/src/h3.rs new file mode 100644 index 0000000..ea5c502 --- /dev/null +++ b/src/h3.rs @@ -0,0 +1,557 @@ +use std::{convert::Infallible, fmt, io, sync::Arc, time::Duration}; + +use dashmap::DashMap; +use dquic::{ + qbase::net::addr::EndpointAddr, + qresolve::{Publish, PublishFuture, RecordStream, Resolve, ResolveFuture, Source}, +}; +use futures::{StreamExt, stream}; +use h3x::{ + dhttp::message::{MessageStreamError, hyper::client::RequestError as HyperRequestError}, + endpoint::H3Endpoint, + quic, +}; +use http_body_util::{BodyExt, Empty, Full}; +use tokio::time::Instant; +use tracing::trace; +use url::Url; + +use crate::core::{ + MdnsPacket, + parser::packet::be_packet, + signature::{CONTENT_DIGEST_HEADER, SIGNATURE_HEADER, SIGNATURE_INPUT_HEADER, SignatureFields}, + wire::be_multi_response, +}; + +const LOOKUP_REQUEST_TIMEOUT: Duration = Duration::from_secs(3); +const LOOKUP_REQUEST_ATTEMPTS: usize = 3; + +pub struct H3Resolver { + endpoint: Arc>, + base_url: Url, + cached_records: DashMap, + negative_cache: DashMap, +} + +#[derive(Debug)] +struct Record { + addrs: Vec, + expire: Instant, +} + +impl fmt::Debug for H3Resolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("H3Resolver") + .field("base_url", &self.base_url) + .finish_non_exhaustive() + } +} + +impl fmt::Display for H3Resolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "H3 DNS Resolver({})", self.base_url) + } +} + +#[derive(Debug, snafu::Snafu)] +pub enum Error { + #[snafu(display("h3 stream error"))] + H3Stream { source: MessageStreamError }, + #[snafu(display("failed to connect h3 endpoint"))] + Connect { source: h3x::pool::ConnectError }, + #[snafu(display("h3 request error"))] + H3Request { + source: HyperRequestError, + }, + #[snafu(display("h3 request timed out after {timeout:?}"))] + RequestTimeout { timeout: Duration }, + + #[snafu(display("{status}"))] + Status { status: http::StatusCode }, + + #[snafu(display("no DNS record found"))] + NoRecordFound, + + #[snafu(display("failed to parse DNS records from response"))] + ParseRecords { + source: nom::Err>>, + }, + + #[snafu(display("failed to decode multi-record response"))] + ParseMultiResponse, +} + +impl H3Resolver +where + C::Error: Send + Sync + 'static, + C::Connection: Send + 'static, +{ + pub fn new( + base_url: impl AsRef, + client: H3Endpoint, + ) -> io::Result { + Self::from_endpoint(base_url, Arc::new(client)) + } + + pub fn from_endpoint( + base_url: impl AsRef, + endpoint: Arc>, + ) -> io::Result { + let base_url = Url::parse(base_url.as_ref()) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidInput, error))?; + base_url.host_str().ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "base URL must have a valid host", + ) + })?; + + Ok(Self { + endpoint, + base_url, + cached_records: DashMap::new(), + negative_cache: DashMap::new(), + }) + } + + fn connect_error(&self, source: h3x::pool::ConnectError) -> Error { + // H3 DNS resolvers keep a long-lived endpoint. A network transition may + // leave the cached H3 connection with stale QUIC paths, so the next + // attempt must establish a fresh connection instead of reusing it. + self.endpoint.clear_pool(); + Error::Connect { source } + } + + fn request_error(&self, source: HyperRequestError) -> Error { + self.endpoint.clear_pool(); + Error::H3Request { source } + } + + async fn execute_request( + &self, + request: http::Request< + impl http_body::Body + Send + 'static, + >, + ) -> Result< + http::Response>, + Error, + > { + let authority = request + .uri() + .authority() + .expect("h3 dns request URL must include an authority") + .clone(); + tracing::trace!(%authority, "connecting h3 dns endpoint"); + let connection = match self.endpoint.connect(authority.clone()).await { + Ok(connection) => { + tracing::trace!(%authority, "connected h3 dns endpoint"); + connection + } + Err(source) => return Err(self.connect_error(source)), + }; + + let method = request.method().clone(); + let uri = request.uri().clone(); + tracing::trace!(%method, %uri, "executing h3 dns request"); + match connection.execute_hyper_request(request).await { + Ok(response) => { + tracing::trace!( + status = %response.status(), + "h3 dns request response received" + ); + Ok(response) + } + Err(source) => Err(self.request_error(source)), + } + } + + pub fn clear_pool(&self) { + self.endpoint.clear_pool(); + } + + pub async fn publish_endpoints( + &self, + name: &str, + endpoints: &[EndpointAddr], + ) -> Result<(), Error> { + trace!("h3x publishing {} with {} endpoints", name, endpoints.len()); + let bytes = { + let endpoints = endpoints + .iter() + .filter_map(|ep| { + crate::core::parser::record::endpoint::EndpointAddr::try_from(*ep).ok() + }) + .collect(); + let mut hosts = std::collections::HashMap::new(); + hosts.insert(name.to_string(), endpoints); + MdnsPacket::answer(0, &hosts).to_bytes() + }; + + self.publish_packet(name, &bytes).await + } + + /// Publish a pre-built DNS packet (with signatures already included). + pub async fn publish_packet(&self, name: &str, packet: &[u8]) -> Result<(), Error> { + self.publish_packet_with_signature(name, packet, &SignatureFields::empty()) + .await + } + + pub async fn publish_signed( + &self, + name: &str, + packet: &[u8], + signature_fields: &SignatureFields, + ) -> io::Result<()> { + self.publish_packet_with_signature(name, packet, signature_fields) + .await + .map_err(io::Error::other) + } + + async fn publish_packet_with_signature( + &self, + name: &str, + packet: &[u8], + signature_fields: &SignatureFields, + ) -> Result<(), Error> { + let mut url = self.base_url.join("publish").expect("Invalid base URL"); + url.set_query(Some(&format!("host={name}"))); + let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); + tracing::trace!( + name, + packet_len = packet.len(), + url = %self.base_url, + "h3x publishing packet" + ); + let mut request = http::Request::post(uri); + if !signature_fields.is_empty() { + request = request + .header( + CONTENT_DIGEST_HEADER, + signature_fields.content_digest.as_slice(), + ) + .header( + SIGNATURE_INPUT_HEADER, + signature_fields.signature_input.as_slice(), + ) + .header(SIGNATURE_HEADER, signature_fields.signature.as_slice()); + } + let request = request + .body(Full::new(bytes::Bytes::copy_from_slice(packet))) + .expect("h3 dns publish request must be valid"); + let resp = self.execute_request(request).await?; + + if resp.status() != http::StatusCode::OK { + return Err(Error::Status { + status: resp.status(), + }); + } + + Ok(()) + } + + fn retryable_lookup_error(error: &Error) -> bool { + matches!( + error, + Error::Connect { .. } | Error::H3Request { .. } | Error::H3Stream { .. } + ) + } + + async fn lookup_response(&self, uri: http::Uri) -> Result> { + let request = http::Request::get(uri) + .body(Empty::::new()) + .expect("h3 dns lookup request must be valid"); + let resp = self.execute_request(request).await?; + + tracing::trace!("received response with status {}", resp.status()); + match resp.status() { + http::StatusCode::OK => {} + http::StatusCode::NOT_FOUND => return Err(Error::NoRecordFound), + status => return Err(Error::Status { status }), + } + + match resp.into_body().collect().await { + Ok(response) => Ok(response.to_bytes()), + Err(source) => Err(Error::H3Stream { source }), + } + } + + async fn lookup_response_with_retry( + &self, + uri: http::Uri, + ) -> Result> { + for attempt in 1..=LOOKUP_REQUEST_ATTEMPTS { + match tokio::time::timeout(LOOKUP_REQUEST_TIMEOUT, self.lookup_response(uri.clone())) + .await + { + Ok(Ok(response)) => return Ok(response), + Ok(Err(error)) + if Self::retryable_lookup_error(&error) + && attempt < LOOKUP_REQUEST_ATTEMPTS => + { + self.endpoint.clear_pool(); + tracing::debug!( + attempt, + timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), + "h3 dns lookup failed, retrying" + ); + } + Ok(Err(error)) => return Err(error), + Err(_elapsed) if attempt < LOOKUP_REQUEST_ATTEMPTS => { + self.endpoint.clear_pool(); + tracing::debug!( + attempt, + timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), + "h3 dns lookup timed out, retrying" + ); + } + Err(_elapsed) => { + self.endpoint.clear_pool(); + return Err(Error::RequestTimeout { + timeout: LOOKUP_REQUEST_TIMEOUT, + }); + } + } + } + + unreachable!("lookup retry loop returns on the final attempt") + } + + pub async fn lookup(&self, name: &str) -> Result> { + use crate::core::parser::record; + let server = Arc::from(self.base_url.origin().ascii_serialization()); + let source = Source::H3 { server }; + + let Some(domain) = crate::resolvers::resolvable_name(name) else { + return Err(Error::NoRecordFound); + }; + + let now = Instant::now(); + let positive_ttl = Duration::from_secs(10); + let negative_ttl = Duration::from_secs(2); + + self.cached_records + .retain(|_host, record| record.expire > now); + self.negative_cache.retain(|_host, expire| *expire > now); + + if self.negative_cache.get(domain).is_some() { + return Err(Error::NoRecordFound); + } + + if let Some(record) = self.cached_records.get(domain) { + let addrs = record.addrs.clone(); + let stream = stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))); + return Ok(stream.boxed()); + } + + let mut url = self.base_url.join("lookup").expect("Invalid URL"); + url.set_query(Some(&format!("host={}", domain))); + let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); + + tracing::trace!("sending lookup request to {}", self.base_url); + let response = match self.lookup_response_with_retry(uri).await { + Ok(response) => response, + Err(Error::NoRecordFound) => { + self.negative_cache + .insert(domain.to_string(), now + negative_ttl); + return Err(Error::NoRecordFound); + } + Err(error) => return Err(error), + }; + + // Server always returns multi-record format. + let (remain, multi) = + be_multi_response(response.as_ref()).map_err(|_| Error::ParseMultiResponse)?; + if !remain.is_empty() { + return Err(Error::ParseMultiResponse); + } + + let mut addrs = Vec::new(); + for r in multi.records { + if !r.signature_fields.is_empty() { + match r.signature_fields.verify(&r.dns, &r.cert) { + Ok(true) => {} + Ok(false) => { + tracing::debug!("ignored record with invalid DNS packet signature"); + continue; + } + Err(error) => { + tracing::debug!(error = %snafu::Report::from_error(&error), "ignored record with malformed DNS packet signature"); + continue; + } + } + } + + let (_remain, packet) = be_packet(&r.dns).map_err(|source| Error::ParseRecords { + source: source.to_owned(), + })?; + + addrs.extend( + packet + .answers + .iter() + .filter_map(|answer| match answer.data() { + record::RData::E(ep) => { + if answer.name() != domain { + tracing::debug!( + answer_name = %answer.name(), + query = domain, + "ignored endpoint answer for different name" + ); + return None; + } + let endpoint = TryInto::::try_into(ep.clone()).ok()?; + trace!(?endpoint, "parsed endpoint from record"); + Some(endpoint) + } + _ => { + tracing::debug!(?answer, "ignored record"); + None + } + }), + ); + } + + if addrs.is_empty() { + self.negative_cache + .insert(domain.to_string(), now + negative_ttl); + return Err(Error::NoRecordFound); + } + + self.cached_records.insert( + domain.to_string(), + Record { + addrs: addrs.clone(), + expire: now + positive_ttl, + }, + ); + + self.negative_cache.remove(domain); + + Ok(stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))).boxed()) + } +} + +impl Publish for H3Resolver +where + C::Error: Send + Sync + 'static, + C::Connection: Send + 'static, +{ + fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { + Box::pin(async move { + match self.publish_packet(name, packet).await { + Ok(()) => Ok(()), + Err(error) => Err(io::Error::other(error)), + } + }) + } +} + +impl Resolve for H3Resolver +where + C::Error: Send + Sync + 'static, + C::Connection: Send + 'static, +{ + fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { + Box::pin(async move { + match H3Resolver::lookup(self, name).await { + Ok(stream) => Ok(stream), + Err(error) => Err(io::Error::other(error)), + } + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[cfg(feature = "dquic-network")] + use crate::resolvers::DHTTP_H3_DNS_SERVER; + + #[test] + fn lookup_retry_budget_leaves_external_timeout_margin() { + let total_budget = LOOKUP_REQUEST_TIMEOUT * LOOKUP_REQUEST_ATTEMPTS as u32; + + assert!( + total_budget <= Duration::from_secs(10), + "h3 lookup must return before common 15s command timeouts so callers can retry" + ); + } + + #[cfg(feature = "dquic-network")] + #[tokio::test] + async fn cached_lookup_reports_h3_dns_source() { + let endpoint = Arc::new(h3x::endpoint::H3Endpoint::new( + h3x::dquic::QuicEndpoint::builder().build().await, + )); + let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); + resolver.cached_records.insert( + "car.lab.dhttp.net".to_owned(), + Record { + addrs: vec![EndpointAddr::direct("192.168.5.78:41748".parse().unwrap())], + expire: Instant::now() + Duration::from_secs(60), + }, + ); + + let mut records = resolver.lookup("car.lab.dhttp.net").await.unwrap(); + let (source, endpoint) = records.next().await.unwrap(); + + assert_eq!( + source, + Source::H3 { + server: Arc::from(resolver.base_url.origin().ascii_serialization()) + } + ); + assert_eq!( + endpoint, + EndpointAddr::direct("192.168.5.78:41748".parse().unwrap()) + ); + } + + #[cfg(feature = "dquic-network")] + #[tokio::test] + async fn cached_dns_genmeta_net_record_is_returned() { + let endpoint = Arc::new(h3x::endpoint::H3Endpoint::new( + h3x::dquic::QuicEndpoint::builder().build().await, + )); + let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); + resolver.cached_records.insert( + "dns.genmeta.net".to_owned(), + Record { + addrs: vec![EndpointAddr::direct("192.0.2.53:4433".parse().unwrap())], + expire: Instant::now() + Duration::from_secs(60), + }, + ); + + let mut records = resolver.lookup("dns.genmeta.net").await.unwrap(); + let (_source, endpoint) = records.next().await.unwrap(); + + assert_eq!( + endpoint, + EndpointAddr::direct("192.0.2.53:4433".parse().unwrap()) + ); + } + + #[cfg(feature = "dquic-network")] + #[tokio::test] + async fn cached_lookup_uses_e_record_port_not_input_port() { + let endpoint = Arc::new(h3x::endpoint::H3Endpoint::new( + h3x::dquic::QuicEndpoint::builder().build().await, + )); + let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); + resolver.cached_records.insert( + "nat.genmeta.net".to_owned(), + Record { + addrs: vec![EndpointAddr::direct("192.0.2.10:21000".parse().unwrap())], + expire: Instant::now() + Duration::from_secs(60), + }, + ); + + let mut records = resolver.lookup("nat.genmeta.net:20004").await.unwrap(); + let (_source, endpoint) = records.next().await.unwrap(); + + assert_eq!( + endpoint, + EndpointAddr::direct("192.0.2.10:21000".parse().unwrap()) + ); + } +} diff --git a/src/http.rs b/src/http.rs new file mode 100644 index 0000000..4531435 --- /dev/null +++ b/src/http.rs @@ -0,0 +1,271 @@ +use std::{fmt::Display, io, sync::Arc}; + +use dashmap::DashMap; +use dquic::{ + qbase::net::addr::EndpointAddr, + qresolve::{Publish, PublishFuture, Resolve, ResolveFuture, Source}, +}; +use futures::{StreamExt, TryFutureExt, stream}; +use reqwest::{Client, IntoUrl, StatusCode, Url}; +use tokio::time::Instant; + +use crate::core::{ + parser::packet::be_packet, + signature::{CONTENT_DIGEST_HEADER, SIGNATURE_HEADER, SIGNATURE_INPUT_HEADER, SignatureFields}, + wire::be_multi_response, +}; + +#[derive(Debug)] +struct Record { + addrs: Vec, + expire: Instant, +} + +#[derive(Debug)] +pub struct HttpResolver { + http_client: Client, + base_url: Url, + cached_records: DashMap, +} + +impl Display for HttpResolver { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Http DNS({})", + self.base_url.host_str().expect("checked in constructor") + ) + } +} + +impl HttpResolver { + pub fn new(base_url: impl IntoUrl) -> io::Result { + let base_url = base_url + .into_url() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + base_url.host_str().ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "base URL must have a valid host", + ) + })?; + + Ok(Self { + http_client: build_http_client()?, + base_url, + cached_records: DashMap::new(), + }) + } + + pub async fn publish_signed( + &self, + name: &str, + packet: &[u8], + signature_fields: &SignatureFields, + ) -> io::Result<()> { + self.publish_packet_with_signature(name, packet, signature_fields) + .await + .map_err(io::Error::other) + } + + async fn publish_packet_with_signature( + &self, + name: &str, + packet: &[u8], + signature_fields: &SignatureFields, + ) -> Result<(), Error> { + let mut url = self.base_url.join("publish").expect("Invalid base URL"); + url.set_query(Some(&format!("host={name}"))); + let mut request = self + .http_client + .post(url) + .header("Content-Type", "application/octet-stream"); + if !signature_fields.is_empty() { + request = request + .header( + CONTENT_DIGEST_HEADER, + signature_fields.content_digest.as_slice(), + ) + .header( + SIGNATURE_INPUT_HEADER, + signature_fields.signature_input.as_slice(), + ) + .header(SIGNATURE_HEADER, signature_fields.signature.as_slice()); + } + request + .body(packet.to_vec()) + .send() + .await? + .error_for_status()?; + Ok(()) + } +} + +fn build_http_client() -> io::Result { + let native_certs = rustls_native_certs::load_native_certs(); + for error in &native_certs.errors { + let report = snafu::Report::from_error(error); + tracing::warn!(error = %report, "failed to load native root certificate"); + } + + let mut root_store = rustls::RootCertStore::empty(); + let (valid_roots, invalid_roots) = root_store.add_parsable_certificates(native_certs.certs); + if invalid_roots > 0 { + tracing::debug!(invalid_roots, "ignored invalid native root certificates"); + } + if valid_roots == 0 { + tracing::warn!("no native root certificates loaded for http resolver"); + } + + let mut tls = rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + tls.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + + Client::builder() + .use_preconfigured_tls(tls) + .build() + .map_err(io::Error::other) +} + +#[derive(Debug, snafu::Snafu)] +enum Error { + #[snafu(display("http request failed"))] + Reqwest { source: reqwest::Error }, + + #[snafu(display("{status}"))] + Status { status: StatusCode }, + + #[snafu(display("no DNS record found"))] + NoRecordFound, + + #[snafu(display("failed to parse DNS records from response"))] + ParseRecords { + source: nom::Err>>, + }, + + #[snafu(display("failed to decode multi-record response"))] + ParseMultiResponse, +} + +impl From for Error { + fn from(source: reqwest::Error) -> Self { + match source.status() { + Some(stateus) if stateus == StatusCode::NOT_FOUND => Error::NoRecordFound, + Some(status) => Error::Status { status }, + None => Error::Reqwest { + source: source.without_url(), + }, + } + } +} + +impl Publish for HttpResolver { + fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { + Box::pin(async move { + self.publish_packet_with_signature(name, packet, &SignatureFields::empty()) + .await + .map_err(io::Error::other) + }) + } +} + +impl Resolve for HttpResolver { + fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { + let lookup = async move { + let Some(domain) = crate::resolvers::resolvable_name(name) else { + return Err(Error::NoRecordFound); + }; + + let now = Instant::now(); + let server = Arc::from(self.base_url.host_str().unwrap_or("")); + let soource = Source::Http { server }; + + use crate::core::parser::record; + self.cached_records + .retain(|_host, Record { expire, .. }| *expire < now); + if let Some(record) = self.cached_records.get(domain) { + let endpoint_addrs: Vec<_> = record + .addrs + .iter() + .map(|endpoint: &EndpointAddr| (soource.clone(), *endpoint)) + .collect(); + return Ok(stream::iter(endpoint_addrs).boxed()); + } + let response = self + .http_client + .get(self.base_url.join("lookup").expect("Invalid URL")) + .query(&[("host", domain)]) + .send() + .await; + + let response = response?.error_for_status()?.bytes().await?; + let (remain, multi) = + be_multi_response(response.as_ref()).map_err(|_| Error::ParseMultiResponse)?; + if !remain.is_empty() { + return Err(Error::ParseMultiResponse); + } + + let mut addrs = Vec::new(); + for r in multi.records { + if !r.signature_fields.is_empty() { + match r.signature_fields.verify(&r.dns, &r.cert) { + Ok(true) => {} + Ok(false) => { + tracing::debug!("ignored record with invalid DNS packet signature"); + continue; + } + Err(error) => { + tracing::debug!(error = %snafu::Report::from_error(&error), "ignored record with malformed DNS packet signature"); + continue; + } + } + } + let (_remain, packet) = + be_packet(&r.dns).map_err(|source| Error::ParseRecords { + source: source.to_owned(), + })?; + + addrs.extend( + packet + .answers + .iter() + .filter_map(|answer| match answer.data() { + record::RData::E(ep) => { + if answer.name() != domain { + tracing::debug!( + answer_name = %answer.name(), + query = domain, + "ignored endpoint answer for different name" + ); + return None; + } + let endpoint = + TryInto::::try_into(ep.clone()).ok()?; + Some(endpoint) + } + _ => { + tracing::debug!(?answer, "ignored record"); + None + } + }), + ); + } + if addrs.is_empty() { + return Err(Error::NoRecordFound); + } + + // cache the addrs + self.cached_records.insert( + domain.to_string(), + Record { + addrs: addrs.clone(), + expire: now + std::time::Duration::from_secs(300), + }, + ); + + Ok(stream::iter(addrs.into_iter().map(move |ep| (soource.clone(), ep))).boxed()) + }; + Box::pin(lookup.map_err(io::Error::other)) + } +} diff --git a/src/lib.rs b/src/lib.rs index 193112c..b495501 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,11 @@ mod bootstrap; pub mod core; +#[cfg(feature = "h3")] +pub mod h3; +#[cfg(feature = "http")] +pub mod http; +#[cfg(feature = "mdns")] pub mod mdns; -#[cfg(any(feature = "h3x-resolver", feature = "mdns-resolver"))] -pub mod publisher; +pub mod publishers; pub mod resolvers; diff --git a/src/mdns.rs b/src/mdns.rs index dd51460..03ce041 100644 --- a/src/mdns.rs +++ b/src/mdns.rs @@ -1,4 +1,359 @@ mod if_nametoindex; mod protocol; -pub mod resolvers; pub mod service; + +use std::{fmt, io, net::IpAddr}; +#[cfg(feature = "dquic-network")] +use std::{net::SocketAddr, sync::Arc}; + +#[cfg(feature = "dquic-network")] +use dquic::qresolve::RecordStream; +use dquic::{ + qbase::net::Family, + qresolve::{Publish, PublishFuture, Resolve, ResolveFuture, Source}, +}; +use futures::{FutureExt, StreamExt, TryFutureExt, future, stream}; +#[cfg(feature = "dquic-network")] +use futures::{Stream, stream::FuturesUnordered}; + +#[cfg(feature = "dquic-network")] +use self::protocol::MdnsProtocol; +#[cfg(feature = "dquic-network")] +use crate::core::parser::packet::Packet; +use crate::core::parser::record::RData; + +pub type MdnsResolver = service::Mdns; +pub type MdnsPublisher = service::Mdns; + +impl MdnsResolver { + pub fn source(&self) -> Source { + Source::Mdns { + nic: self.bound_nic().into(), + family: match self.bound_ip() { + IpAddr::V4(..) => Family::V4, + IpAddr::V6(..) => Family::V6, + }, + } + } +} + +impl fmt::Display for MdnsResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.source(), f) + } +} + +impl Publish for MdnsPublisher { + fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { + let endpoints = match endpoints_from_packet(packet) { + Ok(endpoints) => endpoints, + Err(error) => return future::ready(Err(error)).boxed(), + }; + self.insert_host(name.to_string(), endpoints); + future::ready(Ok(())).boxed() + } +} + +impl Resolve for MdnsResolver { + fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { + let source = self.source(); + self.query(name.to_owned()) + .map_ok(move |list| { + let endpoints = crate::resolvers::selector::selected_endpoint_addrs(list); + stream::iter(endpoints.into_iter().map(move |ep| (source.clone(), ep))).boxed() + }) + .boxed() + } +} + +fn endpoints_from_packet(packet: &[u8]) -> io::Result> { + use crate::core::parser::packet::be_packet; + + be_packet(packet) + .map(|(_, pkt)| { + pkt.answers + .iter() + .filter_map(|rr| match rr.data() { + RData::E(ep) => Some(ep.clone()), + _ => None, + }) + .collect::>() + }) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string())) +} + +#[cfg(feature = "dquic-network")] +pub struct MdnsBindDriver { + iface_manager: Arc, + null_io_factory: Arc, + service_name: Arc, +} + +#[cfg(feature = "dquic-network")] +impl MdnsBindDriver { + pub fn new(service_name: impl Into>) -> Self { + Self { + iface_manager: Arc::new(h3x::dquic::net::InterfaceManager::new()), + null_io_factory: Arc::new(h3x::dquic::NullIoFactory), + service_name: service_name.into(), + } + } + + fn install_or_rebind_mdns( + &self, + network: &h3x::dquic::Network, + bind_iface: &h3x::dquic::net::BindInterface, + ) { + let bind_uri = bind_iface.bind_uri(); + let Some((family, device, _port)) = bind_uri.as_iface_bind_uri() else { + tracing::debug!(%bind_uri, "skipping mdns binding for non-interface bind uri"); + return; + }; + let Some(ip) = network.resolve_device_addr(device, family) else { + tracing::debug!(%bind_uri, "skipping mdns binding without local interface address"); + return; + }; + + bind_iface.with_components_mut(|components, _iface| { + match components.try_init_with(|| service::Mdns::new(&self.service_name, ip, device)) { + Ok(mdns) => mdns.reinit_on(device, ip), + Err(error) => { + let report = snafu::Report::from_error(&error); + tracing::debug!(error = %report, %bind_uri, "failed to initialize mdns binding"); + } + } + }); + } +} + +#[cfg(feature = "dquic-network")] +impl h3x::dquic::BindDriver for MdnsBindDriver { + fn bind<'a>( + &'a self, + network: &'a h3x::dquic::Network, + uri: h3x::dquic::net::BindUri, + ) -> futures::future::BoxFuture<'a, h3x::dquic::net::BindInterface> { + async move { + let iface = self + .iface_manager + .bind(uri, self.null_io_factory.clone()) + .await; + self.install_or_rebind_mdns(network, &iface); + iface + } + .boxed() + } + + fn rebind<'a>( + &'a self, + network: &'a h3x::dquic::Network, + iface: &'a h3x::dquic::net::BindInterface, + ) -> futures::future::BoxFuture<'a, ()> { + async move { + self.install_or_rebind_mdns(network, iface); + } + .boxed() + } +} + +#[cfg(feature = "dquic-network")] +pub struct MdnsResolvers { + network: Arc, + driver: Arc, + patterns: Arc>, + _handles: Vec, +} + +#[cfg(feature = "dquic-network")] +#[derive(Debug, Clone)] +pub struct BoundMdnsResolver { + pub device: String, + pub family: Family, + pub resolver: MdnsResolver, +} + +#[cfg(feature = "dquic-network")] +impl fmt::Debug for MdnsResolvers { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MdnsResolvers") + .field("patterns", &self.patterns) + .finish_non_exhaustive() + } +} + +#[cfg(feature = "dquic-network")] +impl fmt::Display for MdnsResolvers { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("mDNS resolvers") + } +} + +#[cfg(feature = "dquic-network")] +impl MdnsResolvers { + pub async fn bind( + network: Arc, + patterns: Arc>, + service_name: impl Into>, + ) -> Self { + let driver = Arc::new(MdnsBindDriver::new(service_name)); + let mut handles = Vec::with_capacity(patterns.len()); + for pattern in patterns.iter() { + handles.push(network.bind_with(driver.clone(), pattern.clone()).await); + } + + Self { + network, + driver, + patterns, + _handles: handles, + } + } + + pub fn bound_interfaces( + &self, + pattern: &h3x::dquic::binds::BindPattern, + ) -> Option> { + self.network.get_interfaces_with(&self.driver, pattern) + } + + fn for_each_resolver(&self, mut f: impl FnMut(&MdnsResolver)) { + for pattern in self.patterns.iter() { + let Some(ifaces) = self.bound_interfaces(pattern) else { + continue; + }; + for iface in ifaces { + iface.with_components(|components, _| { + if let Some(mdns) = components.get::() { + f(mdns); + } + }); + } + } + } + + pub fn bound_resolvers(&self) -> Vec { + let mut resolvers = Vec::new(); + for pattern in self.patterns.iter() { + let Some(ifaces) = self.bound_interfaces(pattern) else { + continue; + }; + for iface in ifaces { + let bind_uri = iface.bind_uri(); + let Some((family, device, _port)) = bind_uri.as_iface_bind_uri() else { + continue; + }; + iface.with_components(|components, _| { + if let Some(resolver) = components.get::() { + resolvers.push(BoundMdnsResolver { + device: device.to_owned(), + family, + resolver: resolver.clone(), + }); + } + }); + } + } + resolvers + } + + pub async fn query(&self, name: &str) -> io::Result { + let mut lookup_futures = FuturesUnordered::new(); + let mut has_resolver = false; + self.for_each_resolver(|resolver| { + has_resolver = true; + let source = resolver.source(); + lookup_futures.push( + resolver + .query(name.to_owned()) + .map_ok(move |eps| (source, eps)), + ); + }); + if !has_resolver { + return Err(io::Error::other("no mdns resolvers available")); + } + + let mut last_error = None; + let mut has_success = false; + let mut records = Vec::new(); + while let Some(result) = lookup_futures.next().await { + match result { + Ok((source, endpoints)) => { + has_success = true; + records.extend( + endpoints + .into_iter() + .map(|endpoint| (source.clone(), endpoint)), + ); + } + Err(error) => last_error = Some(error), + } + } + + if !has_success { + return Err( + last_error.unwrap_or_else(|| io::Error::other("no mdns resolvers available")) + ); + } + + let records = crate::resolvers::selector::selected_endpoint_records(records); + + Ok(stream::iter(records).boxed()) + } + + pub fn discover(&self) -> impl Stream + use<> { + let mut protos = Vec::new(); + self.for_each_resolver(|resolver| { + protos.push(resolver.protocol()); + }); + + async fn receive_one( + proto: Arc, + ) -> Option<((SocketAddr, Packet), Arc)> { + let result = proto.receive_boardcast().await.ok()?; + Some((result, proto)) + } + + let mut pending = protos + .into_iter() + .map(receive_one) + .collect::>(); + + Box::pin(stream::poll_fn(move |cx| { + use std::task::Poll; + loop { + match pending.poll_next_unpin(cx) { + Poll::Ready(Some(Some((item, proto)))) => { + pending.push(receive_one(proto)); + return Poll::Ready(Some(item)); + } + Poll::Ready(Some(None)) => continue, + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + } + } + })) + } +} + +#[cfg(feature = "dquic-network")] +impl Publish for MdnsResolvers { + fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { + let endpoints = match endpoints_from_packet(packet) { + Ok(endpoints) => endpoints, + Err(error) => return future::ready(Err(error)).boxed(), + }; + + self.for_each_resolver(|resolver| { + resolver.insert_host(name.to_string(), endpoints.clone()); + }); + + future::ready(Ok(())).boxed() + } +} + +#[cfg(feature = "dquic-network")] +impl Resolve for MdnsResolvers { + fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { + self.query(name).boxed() + } +} diff --git a/src/publishers.rs b/src/publishers.rs new file mode 100644 index 0000000..ae80ef8 --- /dev/null +++ b/src/publishers.rs @@ -0,0 +1,342 @@ +#[cfg(feature = "publishers")] +mod address; +#[cfg(feature = "publishers")] +mod dispatch; +#[cfg(feature = "publishers")] +mod packet; + +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +use std::{any::TypeId, net::SocketAddr, time::Duration}; +#[cfg(feature = "publishers")] +use std::{io, sync::Arc}; + +#[cfg(feature = "publishers")] +pub use address::{ + AddressSelector, AddressView, FnAddressView, PublishAddressGroup, PublishAddressScope, + PublishAddresses, +}; +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +pub use address::{AddressViewSource, EndpointBindingAddresses}; +#[cfg(feature = "publishers")] +use dhttp_identity::{identity::LocalAuthority, name::Name}; +#[cfg(feature = "publishers")] +use dquic::qresolve::{Publish, Resolve}; +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +use dquic::{ + qinterface::component::location::AddressEvent, qtraversal::nat::client::ClientLocationData, +}; +#[cfg(feature = "publishers")] +pub use packet::{EndpointRecordSigner, SignEndpointRecordsError}; +#[cfg(feature = "publishers")] +use snafu::Snafu; + +#[cfg(feature = "h3")] +pub use crate::h3::H3Resolver as H3Publisher; +#[cfg(feature = "http")] +pub use crate::http::HttpResolver as HttpPublisher; +#[cfg(feature = "mdns")] +pub use crate::mdns::MdnsPublisher; + +#[cfg(feature = "publishers")] +#[derive(Debug, Snafu)] +#[snafu(module(create_publisher_error))] +pub enum CreatePublisherError { + #[snafu(display("anonymous endpoint cannot publish dns records"))] + AnonymousEndpoint, +} + +#[cfg(feature = "publishers")] +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum PublishOnceError { + #[snafu(display("no publisher resolver available"))] + NoPublisherResolver, + #[snafu(display("failed to sign endpoint records"))] + SignEndpointRecords { source: SignEndpointRecordsError }, + #[snafu(display("failed to publish dns packet with {publisher}"))] + Publish { + publisher: String, + source: io::Error, + }, +} + +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +pub const DEFAULT_PUBLISH_INTERVAL: Duration = Duration::from_secs(20); +/// Upper bound for a single publish attempt in the background loop. +/// +/// Network changes can leave an in-flight H3 publish waiting on paths that no +/// longer exist. Timing out the attempt keeps consecutive publishes +/// independent: the next interval observes the current bindings again. +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +pub const DEFAULT_PUBLISH_TIMEOUT: Duration = Duration::from_secs(10); +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +const PUBLISH_CHANGE_DEBOUNCE: Duration = Duration::from_millis(50); + +#[cfg(feature = "publishers")] +#[derive(Clone)] +pub struct EndpointPublisher< + A: ?Sized = dyn LocalAuthority + Send + Sync, + R: ?Sized = dyn Resolve + Send + Sync, +> { + signer: EndpointRecordSigner, + resolver: Arc, +} + +#[cfg(feature = "publishers")] +impl EndpointPublisher +where + A: LocalAuthority + Send + Sync + ?Sized, + R: ?Sized, +{ + pub fn new(signer: EndpointRecordSigner, resolver: Arc) -> Self { + Self { signer, resolver } + } + + pub fn signer(&self) -> &EndpointRecordSigner { + &self.signer + } + + pub fn resolver(&self) -> &Arc { + &self.resolver + } +} + +#[cfg(feature = "publishers")] +impl EndpointPublisher +where + A: LocalAuthority + Send + Sync + ?Sized, + R: dispatch::ResolveDispatchTarget + ?Sized, +{ + pub async fn publish_once( + &self, + name: &Name<'_>, + addresses: &V, + ) -> Result<(), PublishOnceError> + where + V: AddressView + Sync, + { + let published = + dispatch::publish_to_resolver(self.signer(), self.resolver.as_ref(), name, addresses) + .await?; + if !published { + return publish_once_error::NoPublisherResolverSnafu.fail(); + } + Ok(()) + } +} + +#[cfg(feature = "publishers")] +#[derive(Default, Debug, Clone)] +pub struct Publishers { + publishers: Vec>, +} + +#[cfg(feature = "publishers")] +impl Publishers { + pub fn new() -> Self { + Self::default() + } + + pub fn with(mut self, publisher: Arc) -> Self { + self.push(publisher); + self + } + + pub fn push(&mut self, publisher: Arc) { + self.publishers.push(publisher); + } + + pub fn iter(&self) -> impl Iterator> { + self.publishers.iter() + } +} + +#[cfg(feature = "publishers")] +#[derive(Default, Debug)] +pub struct PublishersBuilder { + publishers: Publishers, +} + +#[cfg(feature = "publishers")] +impl PublishersBuilder { + pub fn publisher(mut self, publisher: Arc) -> Self { + self.publishers.push(publisher); + self + } + + pub fn build(self) -> Publishers { + self.publishers + } +} + +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +pub type EndpointPublisherLoop = EndpointPublicationLoop< + dyn LocalAuthority + Send + Sync, + dyn Resolve + Send + Sync, + EndpointBindingAddresses, +>; + +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +pub struct EndpointPublicationLoop { + name: Name<'static>, + publisher: EndpointPublisher, + source: S, + interval: Duration, + publish_timeout: Duration, +} + +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +impl std::fmt::Debug for EndpointPublicationLoop +where + A: LocalAuthority + Send + Sync + ?Sized, + R: ?Sized, + S: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EndpointPublicationLoop") + .field("name", &self.name) + .field("signer", self.publisher.signer()) + .field("source", &self.source) + .field("interval", &self.interval) + .field("publish_timeout", &self.publish_timeout) + .finish() + } +} + +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +impl EndpointPublicationLoop +where + A: LocalAuthority + Send + Sync + ?Sized, + R: dispatch::ResolveDispatchTarget + ?Sized, + S: AddressViewSource + Sync, +{ + pub fn new(name: Name<'static>, publisher: EndpointPublisher, source: S) -> Self { + Self { + name, + publisher, + source, + interval: DEFAULT_PUBLISH_INTERVAL, + publish_timeout: DEFAULT_PUBLISH_TIMEOUT, + } + } + + pub fn interval(&self) -> Duration { + self.interval + } + + pub fn publish_timeout(&self) -> Duration { + self.publish_timeout + } + + pub fn with_interval(mut self, interval: Duration) -> Self { + self.interval = interval; + self + } + + pub fn with_publish_timeout(mut self, timeout: Duration) -> Self { + self.publish_timeout = timeout; + self + } + + pub async fn run(&self) -> ! { + let mut locations = self.source.subscribe(); + let interval = tokio::time::sleep(self.interval); + tokio::pin!(interval); + let mut current_publish = self.new_publish_loop_future(); + + loop { + tokio::select! { + _ = &mut current_publish => { + current_publish = Self::pending_publish_loop_future(); + } + _ = &mut interval => { + interval.as_mut().reset(tokio::time::Instant::now() + self.interval); + self.clear_publish_state(); + current_publish = self.new_publish_loop_future(); + } + event = locations.recv() => { + let Some((bind_uri, event)) = event else { + continue; + }; + if !self.source.observes(&bind_uri) { + continue; + } + if !Self::location_event_requires_publish(&event) { + continue; + } + + self.clear_publish_state(); + current_publish = self.new_publish_loop_future(); + } + } + } + } + + fn new_publish_loop_future(&self) -> futures::future::BoxFuture<'_, ()> { + Box::pin(async move { + tokio::time::sleep(PUBLISH_CHANGE_DEBOUNCE).await; + let _ = self.publish_attempt().await; + }) + } + + fn pending_publish_loop_future<'a>() -> futures::future::BoxFuture<'a, ()> { + Box::pin(std::future::pending()) + } + + async fn publish_attempt(&self) -> bool { + tracing::trace!( + timeout_ms = self.publish_timeout.as_millis(), + name = %self.name, + "starting dns publish attempt" + ); + let view = self.source.address_view(); + match tokio::time::timeout( + self.publish_timeout, + self.publisher.publish_once(&self.name, &view), + ) + .await + { + Ok(Ok(())) => { + tracing::info!(name = %self.name, "published resolver endpoints"); + true + } + Ok(Err(error)) => { + let report = snafu::Report::from_error(&error); + tracing::warn!(error = %report, name = %self.name, "dns publish failed"); + false + } + Err(_elapsed) => { + self.clear_publish_state(); + tracing::warn!( + timeout_ms = self.publish_timeout.as_millis(), + name = %self.name, + "dns publish timed out" + ); + false + } + } + } + + fn location_event_requires_publish(event: &AddressEvent) -> bool { + match event { + AddressEvent::Upsert(data) => { + if let Some(bound_addr) = data.downcast_ref::>() { + return bound_addr.is_ok(); + } + if let Some(stun_addr) = data.downcast_ref::() { + return stun_addr.is_ok(); + } + false + } + AddressEvent::Remove(type_id) => { + *type_id == TypeId::of::>() + || *type_id == TypeId::of::() + } + AddressEvent::Closed => true, + } + } + + fn clear_publish_state(&self) { + dispatch::clear_resolver_publish_state(self.publisher.resolver().as_ref()); + } +} diff --git a/src/publishers/address.rs b/src/publishers/address.rs new file mode 100644 index 0000000..7fca89c --- /dev/null +++ b/src/publishers/address.rs @@ -0,0 +1,461 @@ +#[cfg(feature = "dquic-network")] +use std::collections::HashSet; +use std::sync::Arc; +#[cfg(feature = "dquic-network")] +use std::{net::SocketAddr, sync::OnceLock}; + +use dquic::qbase::net::{Family, addr::EndpointAddr}; +#[cfg(feature = "dquic-network")] +use dquic::qinterface::component::location::Observer; +#[cfg(feature = "dquic-network")] +use h3x::dquic::{ + Network, + binds::BindPattern, + net::{BindInterface, BindUri, IO, Scheme}, + qtraversal::nat::client::{NatType, StunClientsComponent}, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AddressSelector<'a> { + WideArea, + LocalLink { device: &'a str, family: Family }, +} + +pub trait AddressView { + fn endpoints<'a>( + &'a self, + selector: AddressSelector<'a>, + ) -> impl Iterator + 'a; +} + +pub struct FnAddressView { + f: F, +} + +impl FnAddressView { + pub fn new(f: F) -> Self { + Self { f } + } +} + +impl AddressView for FnAddressView +where + F: for<'a> Fn(AddressSelector<'a>) -> I, + I: IntoIterator, + I::IntoIter: 'static, +{ + fn endpoints<'a>( + &'a self, + selector: AddressSelector<'a>, + ) -> impl Iterator + 'a { + (self.f)(selector).into_iter() + } +} + +#[cfg(feature = "dquic-network")] +pub trait AddressViewSource { + fn address_view(&self) -> impl AddressView + Send + Sync + '_; + fn subscribe(&self) -> Observer; + fn observes(&self, bind_uri: &BindUri) -> bool; +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PublishAddressScope { + WideArea, + LocalLink { device: Arc, family: Family }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PublishAddressGroup { + scope: PublishAddressScope, + endpoints: Vec, +} + +impl PublishAddressGroup { + pub fn wide_area(endpoints: I) -> Self + where + I: IntoIterator, + { + Self { + scope: PublishAddressScope::WideArea, + endpoints: endpoints.into_iter().collect(), + } + } + + pub fn local_link(device: impl Into>, family: Family, endpoints: I) -> Self + where + I: IntoIterator, + { + Self { + scope: PublishAddressScope::LocalLink { + device: device.into(), + family, + }, + endpoints: endpoints.into_iter().collect(), + } + } + + fn matches(&self, selector: AddressSelector<'_>) -> bool { + match (&self.scope, selector) { + (PublishAddressScope::WideArea, AddressSelector::WideArea) => true, + ( + PublishAddressScope::LocalLink { device, family }, + AddressSelector::LocalLink { + device: selected_device, + family: selected_family, + }, + ) => device.as_ref() == selected_device && *family == selected_family, + _ => false, + } + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct PublishAddresses { + groups: Vec, +} + +impl PublishAddresses { + pub fn new() -> Self { + Self::default() + } + + pub fn group(mut self, group: PublishAddressGroup) -> Self { + self.groups.push(group); + self + } + + pub fn wide_area(self, endpoints: I) -> Self + where + I: IntoIterator, + { + self.group(PublishAddressGroup::wide_area(endpoints)) + } + + pub fn local_link(self, device: impl Into>, family: Family, endpoints: I) -> Self + where + I: IntoIterator, + { + self.group(PublishAddressGroup::local_link(device, family, endpoints)) + } +} + +impl AddressView for PublishAddresses { + fn endpoints<'a>( + &'a self, + selector: AddressSelector<'a>, + ) -> impl Iterator + 'a { + self.groups + .iter() + .filter(move |group| group.matches(selector)) + .flat_map(move |group| group.endpoints.iter().copied()) + } +} + +#[derive(Clone)] +#[cfg(feature = "dquic-network")] +pub struct EndpointBindingAddresses { + network: Arc, + bind_patterns: Arc>, +} + +#[cfg(feature = "dquic-network")] +impl std::fmt::Debug for EndpointBindingAddresses { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EndpointBindingAddresses") + .field("bind_patterns", &self.bind_patterns) + .finish_non_exhaustive() + } +} + +#[cfg(feature = "dquic-network")] +impl EndpointBindingAddresses { + pub fn new(network: Arc, bind_patterns: Arc>) -> Self { + Self { + network, + bind_patterns, + } + } +} + +#[cfg(feature = "dquic-network")] +impl AddressViewSource for EndpointBindingAddresses { + fn address_view(&self) -> impl AddressView + Send + Sync + '_ { + EndpointBindingAddressView::new(self.network.clone(), self.bind_patterns.clone()) + } + + fn subscribe(&self) -> Observer { + self.network.quic().locations().subscribe() + } + + fn observes(&self, bind_uri: &BindUri) -> bool { + self.bind_patterns + .iter() + .any(|pattern| pattern.matches(bind_uri)) + } +} + +#[cfg(feature = "dquic-network")] +struct EndpointBindingAddressView { + bindings: Vec, +} + +#[cfg(feature = "dquic-network")] +impl EndpointBindingAddressView { + fn new(network: Arc, bind_patterns: Arc>) -> Self { + let mut bindings = Vec::new(); + for pattern in bind_patterns.iter() { + let Some(ifaces) = network.quic().get_interfaces(pattern) else { + tracing::trace!(?pattern, "no interfaces for bind pattern"); + continue; + }; + for iface in ifaces { + bindings.push(BindingAddress::new(network.clone(), pattern.clone(), iface)); + } + } + Self { bindings } + } +} + +#[cfg(feature = "dquic-network")] +impl AddressView for EndpointBindingAddressView { + fn endpoints<'a>( + &'a self, + selector: AddressSelector<'a>, + ) -> impl Iterator + 'a { + let mut seen = HashSet::new(); + self.bindings + .iter() + .filter(move |binding| binding.may_match(selector)) + .flat_map(move |binding| binding.endpoints(selector)) + .filter(move |endpoint| seen.insert(*endpoint)) + } +} + +#[cfg(feature = "dquic-network")] +struct BindingAddress { + network: Arc, + pattern: BindPattern, + bind_uri: BindUri, + iface: BindInterface, + wide_area: OnceLock>, + local_link: OnceLock>, +} + +#[cfg(feature = "dquic-network")] +impl BindingAddress { + fn new(network: Arc, pattern: BindPattern, iface: BindInterface) -> Self { + let bind_uri = iface.bind_uri(); + Self { + network, + pattern, + bind_uri, + iface, + wide_area: OnceLock::new(), + local_link: OnceLock::new(), + } + } + + fn may_match(&self, selector: AddressSelector<'_>) -> bool { + match selector { + AddressSelector::WideArea => true, + AddressSelector::LocalLink { device, family } => { + pattern_may_match_local_link(&self.pattern, device, family) + && bind_uri_matches_local_link(&self.bind_uri, device, family) + } + } + } + + fn endpoints<'a>( + &'a self, + selector: AddressSelector<'a>, + ) -> impl Iterator + 'a { + let endpoints = match selector { + AddressSelector::WideArea => self + .wide_area + .get_or_init(|| public_endpoints_from_iface(&self.network, &self.iface)), + AddressSelector::LocalLink { family, .. } => self + .local_link + .get_or_init(|| local_endpoints_from_iface(&self.iface, family)), + }; + endpoints.iter().copied() + } +} + +#[cfg(feature = "dquic-network")] +fn pattern_may_match_local_link(pattern: &BindPattern, device: &str, family: Family) -> bool { + if pattern.scheme != Scheme::Iface { + return false; + } + if pattern + .host + .family() + .is_some_and(|pattern_family| pattern_family != family) + { + return false; + } + pattern.host.matches(device) +} + +#[cfg(feature = "dquic-network")] +fn bind_uri_matches_local_link(bind_uri: &BindUri, device: &str, family: Family) -> bool { + bind_uri + .as_iface_bind_uri() + .is_some_and(|(iface_family, iface_device, _port)| { + iface_family == family && iface_device == device + }) +} + +#[cfg(feature = "dquic-network")] +fn public_endpoints_from_iface(network: &Network, iface: &BindInterface) -> Vec { + iface.with_components(|components, current| { + let bind_uri = current.bind_uri(); + let addr = current.bound_addr().ok(); + let mut endpoints: Vec = components + .get::() + .map(|stun| { + stun.with_clients(|clients| { + clients + .values() + .filter_map(|client| { + let outer = client.get_outer_addr()?.ok()?; + let bound = current.bound_addr().ok()?; + match client.get_nat_type() { + Some(Ok(nat_type)) => Some(publish_endpoint_from_stun( + bound, + client.agent_addr(), + outer, + nat_type, + )), + None => Some(EndpointAddr::with_agent(client.agent_addr(), outer)), + Some(Err(_)) => None, + } + }) + .collect() + }) + }) + .unwrap_or_default(); + let stun_endpoint_count = endpoints.len(); + + if let Some(addr) = addr + && network.bound_addr_is_on_default_route(&bind_uri, addr) + { + endpoints.push(EndpointAddr::direct(addr)); + } + + tracing::trace!( + bind_uri = %bind_uri, + bound_addr = ?addr, + stun_endpoint_count, + endpoint_count = endpoints.len(), + endpoints = ?endpoints, + "collected wide-area endpoints from interface" + ); + + endpoints + }) +} + +#[cfg(feature = "dquic-network")] +fn publish_endpoint_from_stun( + bound: SocketAddr, + agent: SocketAddr, + outer: SocketAddr, + nat_type: NatType, +) -> EndpointAddr { + if nat_type == NatType::FullCone && bound == outer { + EndpointAddr::direct(outer) + } else { + EndpointAddr::with_agent(agent, outer) + } +} + +#[cfg(feature = "dquic-network")] +fn local_endpoints_from_iface(iface: &BindInterface, family: Family) -> Vec { + iface.with_components(|_components, current| { + let Some(addr) = current.bound_addr().ok() else { + return Vec::new(); + }; + match (family, addr) { + (Family::V4, SocketAddr::V4(_)) | (Family::V6, SocketAddr::V6(_)) => { + vec![EndpointAddr::direct(addr)] + } + _ => Vec::new(), + } + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn publish_addresses_select_wide_area_only_for_wide_area_selector() { + let wide = EndpointAddr::direct("203.0.113.10:443".parse().unwrap()); + let local = EndpointAddr::direct("192.168.1.20:443".parse().unwrap()); + let addresses = + PublishAddresses::new() + .wide_area([wide]) + .local_link("en0", Family::V4, [local]); + + let selected: Vec<_> = addresses.endpoints(AddressSelector::WideArea).collect(); + + assert_eq!(selected, vec![wide]); + } + + #[test] + fn publish_addresses_select_matching_local_link_group() { + let en0 = EndpointAddr::direct("192.168.1.20:443".parse().unwrap()); + let en1 = EndpointAddr::direct("192.168.2.20:443".parse().unwrap()); + let addresses = PublishAddresses::new() + .local_link("en0", Family::V4, [en0]) + .local_link("en1", Family::V4, [en1]); + + let selected: Vec<_> = addresses + .endpoints(AddressSelector::LocalLink { + device: "en1", + family: Family::V4, + }) + .collect(); + + assert_eq!(selected, vec![en1]); + } + + #[test] + fn publish_addresses_reject_local_link_family_mismatch() { + let endpoint = EndpointAddr::direct("192.168.1.20:443".parse().unwrap()); + let addresses = PublishAddresses::new().local_link("en0", Family::V4, [endpoint]); + + let selected: Vec<_> = addresses + .endpoints(AddressSelector::LocalLink { + device: "en0", + family: Family::V6, + }) + .collect(); + + assert!(selected.is_empty()); + } + + #[cfg(feature = "dquic-network")] + #[test] + fn full_cone_nat_endpoint_preserves_agent_when_outer_differs_from_bound_addr() { + let bound = "10.110.0.10:45635".parse().expect("valid bound addr"); + let agent = "10.10.0.2:20004".parse().expect("valid agent addr"); + let outer = "10.10.0.10:45635".parse().expect("valid outer addr"); + + let endpoint = publish_endpoint_from_stun(bound, agent, outer, NatType::FullCone); + + assert_eq!(endpoint, EndpointAddr::with_agent(agent, outer)); + } + + #[cfg(feature = "dquic-network")] + #[test] + fn full_cone_endpoint_is_direct_without_address_translation() { + let bound = "10.10.0.100:45635".parse().expect("valid bound addr"); + let agent = "10.10.0.2:20004".parse().expect("valid agent addr"); + + let endpoint = publish_endpoint_from_stun(bound, agent, bound, NatType::FullCone); + + assert_eq!(endpoint, EndpointAddr::direct(bound)); + } +} diff --git a/src/publishers/dispatch.rs b/src/publishers/dispatch.rs new file mode 100644 index 0000000..c465808 --- /dev/null +++ b/src/publishers/dispatch.rs @@ -0,0 +1,203 @@ +use std::any::Any; + +use dhttp_identity::{identity::LocalAuthority, name::Name}; +use dquic::qresolve::{Publish, Resolve}; +use snafu::{IntoError, ResultExt}; + +use super::{ + AddressSelector, AddressView, EndpointRecordSigner, PublishOnceError, publish_once_error, +}; +#[cfg(feature = "resolvers")] +use crate::resolvers::Resolvers; + +#[cfg(all(feature = "h3", feature = "dquic-network"))] +type DeferredH3Resolver = + crate::resolvers::deferred::DeferredResolver>; + +#[doc(hidden)] +pub trait ResolveDispatchTarget: Resolve { + fn as_resolve(&self) -> &(dyn Resolve + Send + Sync); + fn as_any(&self) -> &dyn Any; +} + +impl ResolveDispatchTarget for T +where + T: Resolve + Send + Sync + 'static, +{ + fn as_resolve(&self) -> &(dyn Resolve + Send + Sync) { + self + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl ResolveDispatchTarget for dyn Resolve + Send + Sync { + fn as_resolve(&self) -> &(dyn Resolve + Send + Sync) { + self + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +pub(crate) async fn publish_to_resolver( + signer: &EndpointRecordSigner, + resolver: &R, + name: &Name<'_>, + addresses: &V, +) -> Result +where + A: LocalAuthority + Send + Sync + ?Sized, + R: ResolveDispatchTarget + ?Sized, + V: AddressView + Sync, +{ + let any = resolver.as_any(); + + #[cfg(feature = "resolvers")] + if let Some(resolvers) = any.downcast_ref::() { + let mut published = false; + for resolver in resolvers.iter() { + published |= + publish_single_resolver(signer, resolver.as_ref(), name, addresses).await?; + } + return Ok(published); + } + + publish_single_resolver(signer, resolver.as_resolve(), name, addresses).await +} + +async fn publish_single_resolver( + signer: &EndpointRecordSigner, + resolver: &(dyn Resolve + Send + Sync), + name: &Name<'_>, + addresses: &V, +) -> Result +where + A: LocalAuthority + Send + Sync + ?Sized, + V: AddressView + Sync, +{ + let any = resolver as &dyn Any; + + #[cfg(not(any( + feature = "http", + all(feature = "h3", feature = "dquic-network"), + all(feature = "mdns", feature = "dquic-network") + )))] + { + let _ = any; + let _ = name; + let _ = addresses; + } + + #[cfg(feature = "http")] + if let Some(http) = any.downcast_ref::() { + publish_selected(signer, http, name, addresses, AddressSelector::WideArea).await?; + return Ok(true); + } + + #[cfg(all(feature = "h3", feature = "dquic-network"))] + if let Some(h3) = any.downcast_ref::>() { + publish_selected(signer, h3, name, addresses, AddressSelector::WideArea).await?; + return Ok(true); + } + + #[cfg(all(feature = "h3", feature = "dquic-network"))] + if let Some(h3) = any.downcast_ref::() { + let Some(h3) = h3.get() else { + return Err(publish_once_error::PublishSnafu { + publisher: h3.to_string(), + } + .into_error(std::io::Error::other( + "deferred h3 resolver has not been initialized", + ))); + }; + publish_selected(signer, h3, name, addresses, AddressSelector::WideArea).await?; + return Ok(true); + } + + #[cfg(all(feature = "mdns", feature = "dquic-network"))] + if let Some(mdns) = any.downcast_ref::() { + let mut published = false; + for bound in mdns.bound_resolvers() { + publish_selected( + signer, + &bound.resolver, + name, + addresses, + AddressSelector::LocalLink { + device: &bound.device, + family: bound.family, + }, + ) + .await?; + published = true; + } + return Ok(published); + } + + Ok(false) +} + +async fn publish_selected( + signer: &EndpointRecordSigner, + publisher: &(dyn Publish + Send + Sync), + name: &Name<'_>, + addresses: &V, + selector: AddressSelector<'_>, +) -> Result<(), PublishOnceError> +where + A: LocalAuthority + Send + Sync + ?Sized, + V: AddressView + Sync, +{ + let endpoints: Vec<_> = addresses.endpoints(selector).collect(); + let packet = signer + .signed_packet(name, &endpoints) + .await + .context(publish_once_error::SignEndpointRecordsSnafu)?; + tracing::debug!( + publisher = %publisher, + name = %name, + endpoint_count = endpoints.len(), + packet_len = packet.len(), + "publishing dns packet" + ); + publisher + .publish(name.as_str(), &packet) + .await + .context(publish_once_error::PublishSnafu { + publisher: publisher.to_string(), + }) +} + +pub(crate) fn clear_resolver_publish_state(resolver: &R) +where + R: ResolveDispatchTarget + ?Sized, +{ + clear_single_resolver_publish_state(resolver.as_resolve()); +} + +fn clear_single_resolver_publish_state(resolver: &(dyn Resolve + Send + Sync)) { + let any = resolver as &dyn Any; + + #[cfg(feature = "resolvers")] + if let Some(resolvers) = any.downcast_ref::() { + for resolver in resolvers.iter() { + clear_single_resolver_publish_state(resolver.as_ref()); + } + } + + #[cfg(all(feature = "h3", feature = "dquic-network"))] + if let Some(h3) = any.downcast_ref::>() { + h3.clear_pool(); + } + + #[cfg(all(feature = "h3", feature = "dquic-network"))] + if let Some(h3) = any.downcast_ref::() + && let Some(h3) = h3.get() + { + h3.clear_pool(); + } +} diff --git a/src/publishers/packet.rs b/src/publishers/packet.rs new file mode 100644 index 0000000..420bd46 --- /dev/null +++ b/src/publishers/packet.rs @@ -0,0 +1,87 @@ +use std::{collections::HashMap, sync::Arc}; + +use dhttp_identity::{ + identity::{LocalAuthority, LocalAuthorityCertificateExt}, + name::Name, +}; +use dquic::qbase::net::addr::EndpointAddr; +use snafu::{ResultExt, Snafu}; + +use crate::core::{ + MdnsPacket, + parser::record::endpoint::{EndpointAddr as DnsEndpointAddr, SignEndpointError}, +}; + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum SignEndpointRecordsError { + #[snafu(display("failed to encode endpoint address"))] + EncodeEndpoint, + #[snafu(display("failed to extract dhttp certificate selector"))] + CertificateSelector { + source: dhttp_identity::identity::ExtractDhttpSubjectKeyIdentifierError, + }, + #[snafu(display("failed to sign endpoint address"))] + SignEndpoint { source: SignEndpointError }, +} + +#[derive(Clone)] +pub struct EndpointRecordSigner { + authority: Arc, +} + +impl std::fmt::Debug for EndpointRecordSigner +where + A: LocalAuthority, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EndpointRecordSigner") + .field("authority", &self.authority.name()) + .finish() + } +} + +impl EndpointRecordSigner +where + A: LocalAuthority + Send + Sync + ?Sized, +{ + pub fn new(authority: Arc) -> Self { + Self { authority } + } + + pub fn authority(&self) -> &Arc { + &self.authority + } + + pub async fn signed_packet( + &self, + name: &Name<'_>, + endpoints: &[EndpointAddr], + ) -> Result, SignEndpointRecordsError> { + let selector = self + .authority + .dhttp_subject_key_identifier() + .context(sign_endpoint_records_error::CertificateSelectorSnafu)?; + let chain = selector.chain(); + + let mut signed = Vec::with_capacity(endpoints.len()); + for endpoint in endpoints { + let Ok(mut endpoint) = DnsEndpointAddr::try_from(*endpoint) else { + return sign_endpoint_records_error::EncodeEndpointSnafu.fail(); + }; + endpoint.set_main( + chain.kind() == dhttp_identity::certificate::CertificateChainKind::Primary, + ); + endpoint.set_sequence(chain.sequence().get().into()); + endpoint + .sign_with_authority(self.authority.as_ref()) + .await + .context(sign_endpoint_records_error::SignEndpointSnafu)?; + signed.push(endpoint); + } + + let mut hosts = HashMap::new(); + hosts.insert(name.as_str().to_owned(), signed); + Ok(MdnsPacket::answer(0, &hosts).to_bytes()) + } +} diff --git a/src/resolvers.rs b/src/resolvers.rs index 96451d3..cad1652 100644 --- a/src/resolvers.rs +++ b/src/resolvers.rs @@ -1,34 +1,33 @@ +#[cfg(feature = "resolvers")] use std::{ error::Error, - fmt::{self, Debug, Display}, + fmt::{self, Display}, sync::Arc, }; +#[cfg(feature = "resolvers")] use dquic::{ qbase::net::addr::EndpointAddr, qresolve::{Resolve, ResolveFuture, Source}, }; +#[cfg(feature = "resolvers")] use futures::{FutureExt, Stream, StreamExt, TryFutureExt, stream}; +#[cfg(feature = "resolvers")] use tokio::io; -#[cfg(feature = "h3x-resolver")] -pub mod h3; -#[cfg(feature = "http-resolver")] -pub mod http; - -#[cfg(feature = "http-resolver")] -use http::HttpResolver; - -#[cfg(feature = "mdns-resolver")] -use crate::mdns::resolvers::mdns::MdnsResolvers; +#[cfg(feature = "h3")] +pub use crate::h3::H3Resolver; +#[cfg(feature = "http")] +pub use crate::http::HttpResolver; +#[cfg(feature = "mdns")] +pub use crate::mdns::MdnsResolver; +#[cfg(all(feature = "mdns", feature = "dquic-network", feature = "resolvers"))] +use crate::mdns::MdnsResolvers; /// Extract and validate the DNS host from `name`, which may include a `:port` /// suffix. Returns `Some(host)` if the host part is a valid RFC-compliant DNS /// name, or `None` for raw IP addresses, bracketed IPv6, or malformed input. -#[cfg_attr( - not(any(feature = "h3x-resolver", feature = "http-resolver")), - allow(dead_code) -)] +#[cfg_attr(not(any(feature = "h3", feature = "http")), allow(dead_code))] pub(crate) fn resolvable_name(name: &str) -> Option<&str> { let host = match name.rsplit_once(':') { Some((h, port)) if !port.is_empty() && port.chars().all(|c| c.is_ascii_digit()) => h, @@ -47,6 +46,7 @@ pub const DHTTP_HTTP_DNS_SERVER: &str = crate::bootstrap::DHTTP_HTTP_DNS_SERVER; /// mDNS service type used by DHTTP endpoints. pub const DHTTP_MDNS_SERVICE: &str = crate::bootstrap::DHTTP_MDNS_SERVICE; +#[cfg(feature = "resolvers")] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum DnsScheme { Mdns, @@ -55,6 +55,7 @@ pub enum DnsScheme { System, } +#[cfg(feature = "resolvers")] impl Display for DnsScheme { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { @@ -66,12 +67,14 @@ impl Display for DnsScheme { } } +#[cfg(feature = "resolvers")] #[derive(Debug, snafu::Snafu)] #[snafu(display("unsupported dns scheme {scheme}"))] pub struct ParseDnsSchemeError { scheme: String, } +#[cfg(feature = "resolvers")] impl std::str::FromStr for DnsScheme { type Err = ParseDnsSchemeError; @@ -89,16 +92,20 @@ impl std::str::FromStr for DnsScheme { } pub mod deferred; +#[cfg(feature = "mdns")] pub(crate) mod selector; pub mod weak; +#[cfg(feature = "resolvers")] type ArcResolver = Arc; +#[cfg(feature = "resolvers")] #[derive(Default, Clone, Debug)] pub struct Resolvers { resolvers: Vec, } +#[cfg(feature = "resolvers")] impl Display for Resolvers { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str("Resolvers(")?; @@ -116,11 +123,13 @@ impl Display for Resolvers { } } +#[cfg(feature = "resolvers")] #[derive(Debug)] pub struct DnsErrors { errors: Vec<(String, io::Error)>, } +#[cfg(feature = "resolvers")] fn format_dns_error_sources( f: &mut fmt::Formatter<'_>, error: &(dyn Error + 'static), @@ -137,6 +146,7 @@ fn format_dns_error_sources( Ok(()) } +#[cfg(feature = "resolvers")] fn format_dns_error_entry( f: &mut fmt::Formatter<'_>, resolver: &str, @@ -146,6 +156,7 @@ fn format_dns_error_entry( format_dns_error_sources(f, error) } +#[cfg(feature = "resolvers")] impl fmt::Display for DnsErrors { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.errors.is_empty() { @@ -160,31 +171,35 @@ impl fmt::Display for DnsErrors { } } +#[cfg(feature = "resolvers")] impl Error for DnsErrors {} +#[cfg(feature = "resolvers")] #[derive(Default)] pub struct ResolversBuilder { resolvers: Resolvers, } +#[cfg(feature = "resolvers")] impl ResolversBuilder { pub fn resolver(mut self, resolver: ArcResolver) -> Self { self.resolvers.push(resolver); self } - #[cfg(feature = "mdns-resolver")] + #[cfg(all(feature = "mdns", feature = "dquic-network"))] pub async fn mdns( mut self, network: Arc, patterns: Arc>, ) -> Self { - let mdns = Arc::new(MdnsResolvers::bind(network, patterns, DHTTP_MDNS_SERVICE).await); + let mdns: ArcResolver = + Arc::new(MdnsResolvers::bind(network, patterns, DHTTP_MDNS_SERVICE).await); self.resolvers.push(mdns); self } - #[cfg(feature = "h3x-resolver")] + #[cfg(feature = "h3")] pub fn h3( self, endpoint: Arc>, @@ -197,7 +212,7 @@ impl ResolversBuilder { self.h3_with_base_url(DHTTP_H3_DNS_SERVER, endpoint) } - #[cfg(feature = "h3x-resolver")] + #[cfg(feature = "h3")] pub fn h3_with_base_url( mut self, base_url: impl AsRef, @@ -208,17 +223,17 @@ impl ResolversBuilder { C::Error: Send + Sync + 'static, C::Connection: Send + 'static, { - let resolver = h3::H3Resolver::from_endpoint(base_url, endpoint)?; + let resolver = H3Resolver::from_endpoint(base_url, endpoint)?; self.resolvers.push(Arc::new(resolver)); Ok(self) } - #[cfg(feature = "http-resolver")] + #[cfg(feature = "http")] pub fn http(self) -> io::Result { self.http_with_base_url(DHTTP_HTTP_DNS_SERVER) } - #[cfg(feature = "http-resolver")] + #[cfg(feature = "http")] pub fn http_with_base_url(mut self, base_url: impl AsRef) -> io::Result { let resolver = HttpResolver::new(base_url.as_ref())?; self.resolvers.push(Arc::new(resolver)); @@ -236,6 +251,7 @@ impl ResolversBuilder { } } +#[cfg(feature = "resolvers")] impl Resolvers { pub fn builder() -> ResolversBuilder { ResolversBuilder::default() @@ -284,6 +300,7 @@ impl Resolvers { } } +#[cfg(feature = "resolvers")] impl Resolve for Resolvers { fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { self.lookup(name) @@ -297,18 +314,13 @@ impl Resolve for Resolvers { mod tests { use std::{error::Error as StdError, fmt, io, str::FromStr}; - #[cfg(feature = "mdns-resolver")] + #[cfg(all(feature = "mdns", feature = "dquic-network", feature = "resolvers"))] use super::MdnsResolvers; - #[cfg(any( - feature = "h3x-resolver", - feature = "http-resolver", - feature = "mdns-resolver" - ))] + #[cfg(feature = "resolvers")] use super::Resolvers; - use super::{ - DHTTP_H3_DNS_SERVER, DHTTP_HTTP_DNS_SERVER, DHTTP_MDNS_SERVICE, DnsErrors, DnsScheme, - resolvable_name, - }; + use super::{DHTTP_H3_DNS_SERVER, DHTTP_HTTP_DNS_SERVER, DHTTP_MDNS_SERVICE, resolvable_name}; + #[cfg(feature = "resolvers")] + use super::{DnsErrors, DnsScheme}; #[derive(Debug)] struct TestSourceError { @@ -389,6 +401,7 @@ mod tests { assert_eq!(resolvable_name("[::1]:443"), None); } + #[cfg(feature = "resolvers")] #[test] fn dns_scheme_round_trips_supported_schemes_and_rejects_dht() { let cases = [ @@ -406,6 +419,7 @@ mod tests { assert!(DnsScheme::from_str("dht").is_err()); } + #[cfg(feature = "resolvers")] #[test] fn dns_errors_render_no_resolvers_available_when_empty() { let error = DnsErrors { errors: vec![] }; @@ -413,6 +427,7 @@ mod tests { assert_eq!(error.to_string(), "no DNS resolvers available"); } + #[cfg(feature = "resolvers")] #[test] fn dns_errors_render_resolver_bullets_in_stored_order() { let error = DnsErrors { @@ -435,6 +450,7 @@ mod tests { ); } + #[cfg(feature = "resolvers")] #[test] fn dns_errors_render_numbered_source_chain_for_one_resolver() { let error = DnsErrors { @@ -457,6 +473,7 @@ mod tests { ); } + #[cfg(feature = "resolvers")] #[test] fn dns_errors_render_repeated_source_messages_without_deduplication() { let error = DnsErrors { @@ -483,7 +500,7 @@ mod tests { ); } - #[cfg(feature = "mdns-resolver")] + #[cfg(all(feature = "mdns", feature = "dquic-network", feature = "resolvers"))] #[tokio::test] async fn resolvers_builder_can_enable_mdns() { use std::sync::Arc; @@ -501,7 +518,7 @@ mod tests { assert!(resolvers.to_string().contains("mDNS resolvers")); } - #[cfg(feature = "h3x-resolver")] + #[cfg(all(feature = "h3", feature = "resolvers", feature = "dquic-network"))] #[tokio::test] async fn resolvers_builder_accepts_custom_h3_base_url() { use std::sync::Arc; @@ -518,7 +535,7 @@ mod tests { assert!(resolvers.to_string().contains("custom-dns.example")); } - #[cfg(feature = "http-resolver")] + #[cfg(all(feature = "http", feature = "resolvers"))] #[test] fn resolvers_builder_accepts_custom_http_base_url() { let resolvers = Resolvers::builder() @@ -529,7 +546,7 @@ mod tests { assert!(resolvers.to_string().contains("custom-dns.example")); } - #[cfg(feature = "mdns-resolver")] + #[cfg(all(feature = "mdns", feature = "dquic-network", feature = "resolvers"))] #[tokio::test] async fn mdns_resolvers_bind_installs_mdns_on_null_io_binding() { use std::sync::Arc; diff --git a/src/resolvers/selector.rs b/src/resolvers/selector.rs index 1bd0c86..c03ace4 100644 --- a/src/resolvers/selector.rs +++ b/src/resolvers/selector.rs @@ -2,6 +2,10 @@ use dquic::qbase::net::addr::EndpointAddr as DquicEndpointAddr; use crate::core::parser::record::endpoint::EndpointAddr as DnsEndpointAddr; +type Selector = (bool, u64); +type TaggedEndpoint = (T, DquicEndpointAddr); +type EndpointGroup = (Selector, Vec>); + pub(crate) fn selected_endpoint_addrs( records: impl IntoIterator, ) -> Vec { @@ -14,10 +18,10 @@ pub(crate) fn selected_endpoint_addrs( pub(crate) fn selected_endpoint_records( records: impl IntoIterator, ) -> Vec<(T, DquicEndpointAddr)> { - let mut groups: Vec<((bool, u64), Vec<(T, DquicEndpointAddr)>)> = Vec::new(); + let mut groups: Vec> = Vec::new(); for (tag, record) in records { - let selector = (record.is_main(), 0); + let selector = (record.is_main(), record.sequence().unwrap_or(0)); let Ok(endpoint) = DquicEndpointAddr::try_from(record) else { continue; }; @@ -33,8 +37,9 @@ pub(crate) fn selected_endpoint_records( groups .into_iter() - .flat_map(|(_, endpoints)| endpoints) - .collect() + .next() + .map(|(_, endpoints)| endpoints) + .unwrap_or_default() } #[cfg(test)] diff --git a/tests/feature_surface.rs b/tests/feature_surface.rs new file mode 100644 index 0000000..b9a06e5 --- /dev/null +++ b/tests/feature_surface.rs @@ -0,0 +1,58 @@ +#[cfg(feature = "h3")] +#[test] +fn h3_backend_module_is_public() { + #[allow(unused_imports)] + use ddns::h3; +} + +#[cfg(feature = "http")] +#[test] +fn http_backend_module_is_public() { + #[allow(unused_imports)] + use ddns::http; +} + +#[cfg(feature = "mdns")] +#[test] +fn mdns_module_is_public() { + #[allow(unused_imports)] + use ddns::mdns; +} + +#[test] +fn resolvers_module_is_public() { + #[allow(unused_imports)] + use ddns::resolvers; +} + +#[test] +fn publishers_module_is_public() { + #[allow(unused_imports)] + use ddns::publishers; +} + +#[cfg(all(feature = "http", feature = "resolvers", feature = "publishers"))] +#[test] +fn http_backend_is_reexported_from_both_facades() { + use ddns::{ + http::HttpResolver, publishers::HttpPublisher, + resolvers::HttpResolver as FacadeHttpResolver, + }; + + let _ = core::any::type_name::(); + let _ = core::any::type_name::(); + let _ = core::any::type_name::(); +} + +#[cfg(all(feature = "mdns", feature = "resolvers", feature = "publishers"))] +#[test] +fn mdns_backend_is_reexported_from_both_facades() { + use ddns::{ + mdns::MdnsResolver, publishers::MdnsPublisher, + resolvers::MdnsResolver as FacadeMdnsResolver, + }; + + let _ = core::any::type_name::(); + let _ = core::any::type_name::(); + let _ = core::any::type_name::(); +} diff --git a/tests/h3_generic_surface.rs b/tests/h3_generic_surface.rs new file mode 100644 index 0000000..248c729 --- /dev/null +++ b/tests/h3_generic_surface.rs @@ -0,0 +1,11 @@ +#[cfg(feature = "h3")] +#[test] +fn h3_backend_and_facades_export_the_same_type_name() { + use ddns::{ + h3::H3Resolver, publishers::H3Publisher, resolvers::H3Resolver as FacadeH3Resolver, + }; + + let _ = core::any::type_name::>(); + let _ = core::any::type_name::>(); + let _ = core::any::type_name::>(); +} diff --git a/tests/publishers_surface.rs b/tests/publishers_surface.rs new file mode 100644 index 0000000..745dd74 --- /dev/null +++ b/tests/publishers_surface.rs @@ -0,0 +1,31 @@ +#[cfg(feature = "publishers")] +#[test] +fn publishers_facade_exposes_endpoint_publisher_and_aggregate_types() { + let _ = core::any::type_name::(); + let _ = core::any::type_name::< + ddns::publishers::EndpointPublisher< + dyn dhttp_identity::identity::LocalAuthority + Send + Sync, + dyn dquic::qresolve::Resolve + Send + Sync, + >, + >(); + + let _ = core::any::type_name::(); + let _ = core::any::type_name::(); +} + +#[cfg(all(feature = "publishers", feature = "dquic-network"))] +#[test] +fn publishers_facade_exposes_network_publication_loop_surface() { + let _ = ddns::publishers::DEFAULT_PUBLISH_INTERVAL; + let _ = ddns::publishers::DEFAULT_PUBLISH_TIMEOUT; + let _ = core::any::type_name::(); + let _ = core::any::type_name::(); + let _ = core::any::type_name::(); + let _ = core::any::type_name::< + ddns::publishers::EndpointPublicationLoop< + dyn dhttp_identity::identity::LocalAuthority + Send + Sync, + dyn dquic::qresolve::Resolve + Send + Sync, + ddns::publishers::EndpointBindingAddresses, + >, + >(); +} From 1dd4ed51482329dc2e5cbc8b96363bd4b981b88f Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 15:59:21 +0800 Subject: [PATCH 10/29] chore: remove legacy ddns files --- docs/aws-deployment.md | 28 - docs/redis-contract.md | 403 -------- scripts/update-geolite-mmdb.sh | 102 --- server.toml | 97 -- src/bin/ddns-server/config.rs | 246 ----- src/bin/ddns-server/error.rs | 111 --- src/bin/ddns-server/geo.rs | 196 ---- src/bin/ddns-server/lookup/http.rs | 109 --- src/bin/ddns-server/lookup/mod.rs | 8 - src/bin/ddns-server/lookup/query.rs | 263 ------ src/bin/ddns-server/lookup/ranking.rs | 254 ----- src/bin/ddns-server/lookup/tests.rs | 427 --------- src/bin/ddns-server/main.rs | 445 --------- src/bin/ddns-server/ocsp.rs | 230 ----- src/bin/ddns-server/policy.rs | 200 ---- src/bin/ddns-server/publish/http.rs | 163 ---- src/bin/ddns-server/publish/mod.rs | 7 - src/bin/ddns-server/publish/store.rs | 250 ----- src/bin/ddns-server/publish/tests.rs | 126 --- src/bin/ddns-server/storage.rs | 520 ----------- src/mdns/resolvers.rs | 1 - src/mdns/resolvers/mdns.rs | 354 ------- src/publisher.rs | 1222 ------------------------- src/publisher/address.rs | 444 --------- src/publisher/dispatch.rs | 164 ---- src/publisher/packet.rs | 83 -- src/resolvers/h3.rs | 556 ----------- src/resolvers/http.rs | 271 ------ 28 files changed, 7280 deletions(-) delete mode 100644 docs/aws-deployment.md delete mode 100644 docs/redis-contract.md delete mode 100755 scripts/update-geolite-mmdb.sh delete mode 100644 server.toml delete mode 100644 src/bin/ddns-server/config.rs delete mode 100644 src/bin/ddns-server/error.rs delete mode 100644 src/bin/ddns-server/geo.rs delete mode 100644 src/bin/ddns-server/lookup/http.rs delete mode 100644 src/bin/ddns-server/lookup/mod.rs delete mode 100644 src/bin/ddns-server/lookup/query.rs delete mode 100644 src/bin/ddns-server/lookup/ranking.rs delete mode 100644 src/bin/ddns-server/lookup/tests.rs delete mode 100644 src/bin/ddns-server/main.rs delete mode 100644 src/bin/ddns-server/ocsp.rs delete mode 100644 src/bin/ddns-server/policy.rs delete mode 100644 src/bin/ddns-server/publish/http.rs delete mode 100644 src/bin/ddns-server/publish/mod.rs delete mode 100644 src/bin/ddns-server/publish/store.rs delete mode 100644 src/bin/ddns-server/publish/tests.rs delete mode 100644 src/bin/ddns-server/storage.rs delete mode 100644 src/mdns/resolvers.rs delete mode 100644 src/mdns/resolvers/mdns.rs delete mode 100644 src/publisher.rs delete mode 100644 src/publisher/address.rs delete mode 100644 src/publisher/dispatch.rs delete mode 100644 src/publisher/packet.rs delete mode 100644 src/resolvers/h3.rs delete mode 100644 src/resolvers/http.rs diff --git a/docs/aws-deployment.md b/docs/aws-deployment.md deleted file mode 100644 index edb50d4..0000000 --- a/docs/aws-deployment.md +++ /dev/null @@ -1,28 +0,0 @@ -# AWS Deployment Notes - -`ddns-server` keeps QUIC/TLS/mTLS end-to-end in the backend process. - -## Load Balancer - -- Put an NLB in front of the server. -- Forward UDP, QUIC, TCP_UDP, or TCP_QUIC traffic to the backend without - terminating TLS. -- Expose a separate TCP/HTTP/HTTPS health check port. - -## Redis - -- Use `redis_write_url` for the primary Redis endpoint. -- Use `redis_read_url` for the regional read replica or reader endpoint. -- `publish` and `clear` write only to the primary. -- `lookup` is read-only and can point at the replica. -- Expired index cleanup runs on the write path, not on lookup. - -## Host Allowlist - -- Configure `host_allowlist` with the suffixes this deployment owns. -- Example: `["genmeta.net"]` - -## Extra UDP Services - -- Keep STUN or custom UDP services on a separate NLB UDP listener and port. -- Do not multiplex them onto the QUIC listener unless the application does its own UDP demux. diff --git a/docs/redis-contract.md b/docs/redis-contract.md deleted file mode 100644 index b34fdb9..0000000 --- a/docs/redis-contract.md +++ /dev/null @@ -1,403 +0,0 @@ -# gmdns Redis 存储说明 - -这份文档描述 `ddns-server` 在 Redis 里实际会存什么、怎么存、这些数据分别 -是干什么用的。它优先写给人看,同时保留足够的细节,方便别的服务对接。 - -如果你只想先看结论,这个系统在 Redis 里只会用到 3 类原生数据结构: - -1. `String`:存一条发布记录的完整二进制内容 -2. `Sorted Set`:存查询用的倒排索引 -3. `Set`:存黑名单域名 - -没有使用 `Hash`、`List`、`Stream`、`Bitmap` 之类的其他 Redis 结构。 - -## 1. 总览 - -`ddns-server` 自己维护的 Redis key 一共 4 种形态,加上 1 个外部可写黑名单: - -| Key 形态 | Redis 类型 | 作用 | -| --- | --- | --- | -| `:fp:` | `String` | 某个 host 下,某个证书指纹对应的一条完整发布记录 | -| `:idx:all` | `Sorted Set` | 这个 host 的全部活动记录索引 | -| `:idx:country:` | `Sorted Set` | 这个 host 按国家分桶的活动记录索引 | -| `:idx:asn:` | `Sorted Set` | 这个 host 按 ASN 分桶的活动记录索引 | -| `ddns:blacklist` | `Set` | 被封禁的 host 列表 | - -其中: - -- 主记录 `String` 是事实来源,真正的记录内容只在这里 -- 3 类 `Sorted Set` 都是派生索引,只是为了加速查询 -- 黑名单 `Set` 是一个独立控制面数据,不参与记录存储 - -## 2. Host 规范化规则 - -Redis 里的 host 名必须先做规范化。代码实现见 -[`src/bin/ddns-server/error.rs`](/Users/lixiaofeng/code/gmdns/src/bin/ddns-server/error.rs) 的 -`normalize_host(host, allowlist)`。 - -规则如下: - -1. 去掉首尾空白 -2. 不能为空 -3. 不能包含 `*` -4. 如果最后一个 `:` 后面全是数字,就当成端口号去掉 -5. 去掉结尾的一个 `.` -6. 用 IDNA 转成 ASCII -7. 转成小写 -8. 最终结果必须匹配配置里的 `host_allowlist` 后缀之一 - -例子: - -- `DNS.Genmeta.Net.` -> `dns.genmeta.net` -- `dns.genmeta.net:4433` -> `dns.genmeta.net` -- `blocked.example.genmeta.net` -> `blocked.example.genmeta.net` - -`host_allowlist` 默认包含 `genmeta.net`,所以现有 `genmeta.net` -子域名仍然可用。 - -这条规则对所有 Redis key 都重要,尤其是黑名单成员必须写规范化之后的 host。 - -## 3. 各类 Redis 数据结构 - -### 3.1 主记录 - -Key 形式: - -```text -:fp: -``` - -例子: - -```text -nat.genmeta.net:fp:db6905c72be9aa8b1a61f7d45dd399d64136da17ac384ef67f1f5670055a2946 -``` - -Redis 类型: - -```text -String -``` - -值的含义: - -- 存的是一个二进制 `StoredRecord` -- 里面包含这条记录的完整 DNS 包、发布者证书、签名字段、过期时间等 - -TTL: - -- 通过 `SETEX` / `SET EX` 写入 -- TTL 等于服务配置里的 `ttl_secs` - -业务语义: - -- 同一个 `host` 下,同一个证书指纹只能有 1 条活动记录 -- 同一个证书再次发布,会覆盖自己之前的记录 -- 同一个 `host` 下,不同证书指纹可以并存 - -可以把它理解成: - -```text -一个 host 下,以“证书指纹”作为主键的记录表 -``` - -### 3.2 全量索引 - -Key 形式: - -```text -:idx:all -``` - -例子: - -```text -nat.genmeta.net:idx:all -``` - -Redis 类型: - -```text -Sorted Set -``` - -成员和值: - -- member: `` -- score: 发布时间的 Unix 秒时间戳,代码里按 `f64` 写入 - -TTL: - -- 每次写入相关记录时,会给这个索引 key 重新设置 `ttl_secs` - -业务语义: - -- 表示这个 host 当前有哪些候选发布者记录 -- 查询时,如果 GEO 定向索引不够用,会回退到这个索引 -- 返回顺序是最新发布的在前面,因为读取时用的是 `ZREVRANGE` - -### 3.3 国家索引 - -Key 形式: - -```text -:idx:country: -``` - -例子: - -```text -nat.genmeta.net:idx:country:CN -``` - -Redis 类型: - -```text -Sorted Set -``` - -成员和值: - -- member: `` -- score: 发布时间的 Unix 秒时间戳 - -`` 从哪里来: - -- 发布时解析 DNS 包里的 endpoint IP -- 对这些 IP 做 GEO 查询 -- 把得到的国家代码去重、排序后写入索引 - -业务语义: - -- 这是按国家分桶的候选记录索引 -- 查询时,如果请求方的源 IP 能解析出国家,会先尝试这个桶 - -### 3.4 ASN 索引 - -Key 形式: - -```text -:idx:asn: -``` - -例子: - -```text -nat.genmeta.net:idx:asn:4134 -``` - -Redis 类型: - -```text -Sorted Set -``` - -成员和值: - -- member: `` -- score: 发布时间的 Unix 秒时间戳 - -`` 从哪里来: - -- 和国家索引一样,也是从发布内容里的 endpoint IP 做 GEO 解析得到 - -业务语义: - -- 这是按 ASN 分桶的候选记录索引 -- 查询时,如果请求方的源 IP 能解析出 ASN,会最先尝试这个桶 - -### 3.5 黑名单集合 - -Key: - -```text -ddns:blacklist -``` - -Redis 类型: - -```text -Set -``` - -成员格式: - -- 规范化之后的小写 ASCII host 名 - -例子: - -```text -blocked.example.genmeta.net -``` - -业务语义: - -- 查询开始时先查这个集合 -- 如果 `SISMEMBER ddns:blacklist ` 为真,直接返回 `404 Not Found` -- 黑名单只拦截查询,不拦截 publish,也不拦截 clear -- 黑名单不会删除已有记录 - -常用操作: - -```bash -redis-cli SADD ddns:blacklist blocked.example.genmeta.net -redis-cli SREM ddns:blacklist blocked.example.genmeta.net -``` - -## 4. 主记录里到底存了什么 - -主记录 value 不是 JSON,也不是 Hash,而是一段连续的二进制。 - -顺序如下: - -```text -u64 expire_unix_secs -u8 fingerprint[32] -u32 content_digest_len -u8 content_digest[content_digest_len] -u32 signature_input_len -u8 signature_input[signature_input_len] -u32 signature_len -u8 signature[signature_len] -u32 dns_len -u8 dns[dns_len] -u32 cert_len -u8 cert[cert_len] -``` - -字段说明: - -| 字段 | 含义 | -| --- | --- | -| `expire_unix_secs` | 这条记录的业务过期时间,Unix 秒 | -| `fingerprint` | 发布者叶子证书的 SHA-256 原始 32 字节,不是 hex 字符串 | -| `content_digest` | HTTP 签名里的 `Content-Digest` 原始字节 | -| `signature_input` | HTTP 签名里的 `Signature-Input` 原始字节 | -| `signature` | HTTP 签名里的 `Signature` 原始字节 | -| `dns` | 序列化后的 DNS 包体 | -| `cert` | 发布者叶子证书的 DER 字节 | - -补充说明: - -- 使用大端序 -- 没有版本号字段 -- 三个签名字段都允许为空 -- 如果记录没有签名,这三个字段长度就是 `0` - -## 5. 写入时怎么维护这些结构 - -### 5.1 Publish - -发布 `(host, fingerprint)` 时,流程是: - -1. 读取旧的主记录 -2. 如果旧记录能解码出来,就从旧记录推导出旧的国家 / ASN 标签 -3. 先把旧指纹从所有相关索引里删掉 -4. 用 `SETEX` 写入新的主记录 -5. 把指纹加入: - - `:idx:all` - - 若干 `:idx:country:` - - 若干 `:idx:asn:` -6. 给所有碰到的索引 key 重新设置 TTL -7. 对这些索引执行: - -```text -ZREMRANGEBYSCORE -inf -``` - -这样做的效果是: - -- 主记录会自然过期 -- 索引里过旧的 member 也会被顺手清掉 -- 同一个证书重复发布,不会在索引里留下重复脏数据 - -### 5.2 Clear - -清理 `(host, fingerprint)` 时,流程是: - -1. 读取旧主记录 -2. 从旧主记录推导出它所在的国家 / ASN 桶 -3. 把这个指纹从所有相关索引删掉 -4. 删除主记录 key - -### 5.3 一致性和自愈预期 - -这里的写入不是事务性的。 - -也就是说,一次 publish / clear 会改多个 key,但这些操作不是用单个 Redis -事务原子提交的。如果中途失败,短时间内可能出现下面这些情况: - -- 主记录已经更新,但部分索引还没更新 -- 索引里还留着旧指纹,但主记录已经不存在 -- 某些 GEO 桶暂时缺少一条本该存在的记录 - -这个设计对上述短暂不一致是接受的,原因有两个: - -1. 查询时真正可信的数据源始终是主记录 `String`,索引只是候选入口 -2. 节点会大约每 30 秒重新上报一次,同一条记录会被持续刷新 -3. lookup 只读 Redis,不再执行 `ZREMRANGEBYSCORE`;过期索引清理留在 - publish / clear 路径,或者由 primary 侧后台 sweeper 完成 - -这意味着: - -- 如果索引里残留了一个已经失效的指纹,查询阶段读取不到主记录时会直接跳过 -- 如果一次写入导致某个索引短暂漏写,下一次节点上报通常会把它补回来 -- 即使没有专门的数据修复流程,TTL 和周期性重上报也会让大多数临时偏差自然收敛 - -因此,这套存储模型的目标是: - -- 接受短暂的不一致 -- 依赖 30 秒级的周期刷新实现轻量自愈 -- 不为了少量短时脏索引引入额外复杂的数据修复机制 - -## 6. 查询时怎么用这些结构 - -当 Redis 存储启用时,查询流程是: - -1. 规范化请求里的 host -2. 先查 `ddns:blacklist` -3. 按顺序收集候选指纹: - - 先 ASN 索引 - - 再国家索引 - - 最后全量索引 -4. 按这个顺序去重 -5. 逐个读取 `:fp:` 主记录 -6. 丢弃解码失败或业务上已经过期的记录 -7. 把剩下的记录交给现有排序逻辑继续处理 - -这里最重要的认识是: - -- `Sorted Set` 只是“候选名单” -- 真正可信的数据源始终是主记录 `String` - -## 7. 归属边界 - -`ddns-server` 自己维护下面这些 key: - -- `:fp:` -- `:idx:all` -- `:idx:country:` -- `:idx:asn:` - -外部服务如果只是想做黑名单联动,只应该写: - -- `ddns:blacklist` - -如果外部服务想直接写记录,就必须完整实现: - -- 主记录二进制编码 -- 所有派生索引的增删 -- TTL 维护 -- 过期索引清理 - -否则很容易把 Redis 里的记录和索引写乱。 - -## 8. 一句话理解 - -这个 Redis 模型本质上是: - -- 用一个 `String` 保存“完整记录” -- 用几个 `Sorted Set` 保存“按 host / 国家 / ASN 分类的候选索引” -- 用一个 `Set` 保存“是否禁止查询这个 host” - -真正的记录内容不在索引里,索引只是为了更快找到应该读哪个主记录。 diff --git a/scripts/update-geolite-mmdb.sh b/scripts/update-geolite-mmdb.sh deleted file mode 100755 index f7cc3e9..0000000 --- a/scripts/update-geolite-mmdb.sh +++ /dev/null @@ -1,102 +0,0 @@ -#!/usr/bin/env sh - -set -eu - -usage() { - cat <<'EOF' -Usage: - MAXMIND_ACCOUNT_ID=... MAXMIND_LICENSE_KEY=... ./scripts/update-geolite-mmdb.sh [target-dir] - -Downloads or updates the GeoLite2 City and ASN mmdb databases with geoipupdate. - -Arguments: - target-dir Optional output directory. Defaults to /var/lib/ddns/geoip. - -Required environment variables: - MAXMIND_ACCOUNT_ID - MAXMIND_LICENSE_KEY - -Optional environment variables: - GEOIPUPDATE_BIN geoipupdate binary name or path. Default: geoipupdate - GEOIPUPDATE_VERBOSE Set to 1 to pass -v to geoipupdate. - -Example: - MAXMIND_ACCOUNT_ID=12345 \ - MAXMIND_LICENSE_KEY=xxxx \ - ./scripts/update-geolite-mmdb.sh /etc/ddns -EOF -} - -if [ "${1:-}" = "-h" ] || [ "${1:-}" = "--help" ]; then - usage - exit 0 -fi - -if [ "$#" -gt 1 ]; then - usage >&2 - exit 64 -fi - -require_env() { - name="$1" - eval "value=\${$name:-}" - if [ -z "$value" ]; then - echo "missing required environment variable: $name" >&2 - exit 2 - fi -} - -require_env MAXMIND_ACCOUNT_ID -require_env MAXMIND_LICENSE_KEY - -geoipupdate_bin="${GEOIPUPDATE_BIN:-geoipupdate}" -target_dir="${1:-${GEOIP_TARGET_DIR:-/var/lib/ddns/geoip}}" - -if ! command -v "$geoipupdate_bin" >/dev/null 2>&1; then - echo "geoipupdate not found: $geoipupdate_bin" >&2 - echo "install it first, for example on macOS: brew install geoipupdate" >&2 - exit 127 -fi - -umask 077 -tmp_dir="$(mktemp -d "${TMPDIR:-/tmp}/geoipupdate.XXXXXX")" -cleanup() { - rm -rf "$tmp_dir" -} -trap cleanup EXIT HUP INT TERM - -mkdir -p "$target_dir" - -config_file="$tmp_dir/GeoIP.conf" -cat >"$config_file" <&2 - exit 1 -fi - -cat <";alg="..." -# Signature: dns=:...: -require_signature = true - -# Default TTL (seconds) for published records. -ttl_secs = 30 - -# Redis primary URL for persistent storage. -# If omitted, records are kept in memory only (lost on restart). -# redis_write_url = "redis://primary.example:6379/0" -# -# Optional Redis read URL for lookup traffic. Defaults to the write URL when -# omitted, but on AWS this is usually a regional replica / reader endpoint. -# redis_read_url = "redis://replica.example:6379/0" -# -# When Redis storage is enabled, lookups check the external blacklist set -# "ddns:blacklist". Without Redis, this file can preload an in-memory blacklist. -# Members are normalized lowercase ASCII host names. Blacklisted lookups return -# 404; publish/clear requests are not blocked. -# -# redis-cli SADD ddns:blacklist blocked.example.genmeta.net -# redis-cli SREM ddns:blacklist blocked.example.genmeta.net -# -# blacklist = ["blocked.example.genmeta.net"] - -# Enable GEO-aware scheduling based on country / ASN. -# When both databases are configured, city-distance tie-breaking is also enabled -# for sufficiently accurate records. -# geoip_city_db = "/etc/ddns/GeoLite2-City.mmdb" -# geoip_asn_db = "/etc/ddns/GeoLite2-ASN.mmdb" - -# --------------------------------------------------------------------------- -# Domain policy rules -# -# Policies are matched in order; the first matching rule wins. -# Domains not listed here use the built-in "standard" policy. -# -# Policies: -# standard — one record per host; client cert SAN must match the target -# host; signature check controlled by require_signature above; -# each publish overwrites the previous record. -# -# open_multi — any authenticated node may publish; no signature check; -# records are appended (not overwritten), each with its own -# individual TTL; lookup returns newest-first, use ?limit=N -# to cap the number of returned records. -# --------------------------------------------------------------------------- - -[[domain_policies]] -host = "nat.genmeta.net" -policy = "open_multi" - - -# Static bootstrap STUN endpoints returned even before any node publishes. -# Ordering keeps the main :20002 endpoints ahead of the auxiliary :20003 endpoints. -[[seed_records]] -host = "nat.genmeta.net" -endpoints = [] - -# Add more rules as needed, e.g.: -# [[domain_policies]] -# host = "relay.genmeta.net" -# policy = "open_multi" diff --git a/src/bin/ddns-server/config.rs b/src/bin/ddns-server/config.rs deleted file mode 100644 index 3bbb996..0000000 --- a/src/bin/ddns-server/config.rs +++ /dev/null @@ -1,246 +0,0 @@ -use std::{ - net::SocketAddr, - path::{Path, PathBuf}, - str::FromStr, -}; - -use clap::Parser; -use h3x::dquic::binds::BindPattern; -use serde::{Deserialize, Deserializer, de::Error as _}; - -// --------------------------------------------------------------------------- -// CLI -// --------------------------------------------------------------------------- - -#[derive(Parser, Clone, Debug)] -#[command(version, about, long_about = None)] -pub struct Options { - /// Path to the TOML configuration file. - #[arg(long, default_value = "server.toml")] - pub config: PathBuf, -} - -// --------------------------------------------------------------------------- -// Configuration file schema -// --------------------------------------------------------------------------- - -/// Top-level configuration loaded from the TOML file. -#[derive(Deserialize, Debug)] -#[serde(deny_unknown_fields)] -pub struct Config { - /// Redis write URL (e.g. "redis://primary:6379/"). Alias: `redis`. - #[serde(default, alias = "redis")] - pub redis_write_url: Option, - - /// Optional Redis read URL (e.g. "redis://replica:6379/"). - #[serde(default)] - pub redis_read_url: Option, - - /// Allowed host suffixes (normalized, suffix-matched). - #[serde(default = "Config::default_host_allowlist")] - pub host_allowlist: Vec, - - /// Bind patterns to listen on. - #[serde( - default = "Config::default_binds", - deserialize_with = "deserialize_bind_patterns" - )] - pub binds: Vec, - - /// Server name (used as TLS SNI). - #[serde(default = "Config::default_server_name")] - pub server_name: String, - - /// Path to the server TLS certificate (PEM). - #[serde(default = "Config::default_cert")] - pub cert: PathBuf, - - /// Path to the server TLS private key (PEM). - #[serde(default = "Config::default_key")] - pub key: PathBuf, - - /// Path to the root CA that signs client certificates (PEM). - #[serde(default = "Config::default_root_cert")] - pub root_cert: PathBuf, - - /// Optional issuer certificate used for OCSP requests when `cert` does not include a chain. - #[serde(default)] - pub ocsp_issuer_cert: Option, - - /// Optional OCSP responder base URL. Defaults to the cert-server public responder. - #[serde(default)] - pub ocsp_responder_base_url: Option, - - /// Whether to require DNS record signatures on Standard domains. - #[serde(default = "Config::default_require_signature")] - pub require_signature: bool, - - /// Default TTL (seconds) for published records. - #[serde(default = "Config::default_ttl_secs")] - pub ttl_secs: u64, - - /// Domain-policy rules (first match wins; unlisted domains use Standard). - #[serde(default)] - pub domain_policies: Vec, - - /// In-memory blacklist loaded at startup when Redis storage is not configured. - #[serde(default)] - pub blacklist: Vec, - - /// Static seed records returned on lookup in addition to dynamic published records. - #[serde(default)] - pub seed_records: Vec, - - /// Path to the GeoLite2 City database. - #[serde(default)] - pub geoip_city_db: Option, - - /// Path to the GeoLite2 ASN database. - #[serde(default)] - pub geoip_asn_db: Option, -} - -impl Config { - pub fn expand_paths(mut self) -> Self { - self.cert = expand_home_dir(&self.cert); - self.key = expand_home_dir(&self.key); - self.root_cert = expand_home_dir(&self.root_cert); - self.ocsp_issuer_cert = self.ocsp_issuer_cert.map(|path| expand_home_dir(&path)); - self.geoip_city_db = self.geoip_city_db.map(|path| expand_home_dir(&path)); - self.geoip_asn_db = self.geoip_asn_db.map(|path| expand_home_dir(&path)); - self - } - - pub fn default_binds() -> Vec { - ["0.0.0.0:4433", "[::]:4433"] - .into_iter() - .map(|value| { - BindPattern::from_str(value).expect("default bind pattern should be valid") - }) - .collect() - } - pub fn default_server_name() -> String { - "localhost".into() - } - pub fn default_cert() -> PathBuf { - "examples/keychain/localhost/localhost-ECC.crt".into() - } - pub fn default_key() -> PathBuf { - "examples/keychain/localhost/localhost-ECC.key".into() - } - pub fn default_root_cert() -> PathBuf { - "examples/keychain/root/rootCA-ECC.crt".into() - } - pub fn default_host_allowlist() -> Vec { - vec!["genmeta.net".into()] - } - pub fn default_require_signature() -> bool { - true - } - pub fn default_ttl_secs() -> u64 { - 30 - } -} - -fn deserialize_bind_patterns<'de, D>(deserializer: D) -> Result, D::Error> -where - D: Deserializer<'de>, -{ - let values = Vec::::deserialize(deserializer)?; - values - .into_iter() - .map(|value| { - BindPattern::from_str(&value).map_err(|error| { - D::Error::custom(format!("invalid bind pattern `{value}`: {error}")) - }) - }) - .collect() -} - -fn expand_home_dir(path: &Path) -> PathBuf { - let Some(path_str) = path.to_str() else { - return path.to_path_buf(); - }; - - if path_str == "~" { - return std::env::var_os("HOME") - .map(PathBuf::from) - .unwrap_or_else(|| path.to_path_buf()); - } - - if let Some(stripped) = path_str.strip_prefix("~/") - && let Some(home) = std::env::var_os("HOME") - { - return PathBuf::from(home).join(stripped); - } - - path.to_path_buf() -} - -/// One domain-policy rule in the configuration file. -#[derive(Deserialize, Debug)] -#[serde(deny_unknown_fields)] -pub struct PolicyConfig { - /// Exact host to match (after normalisation). - pub host: String, - /// Policy to apply. - pub policy: PolicyKind, -} - -/// One statically configured seed record group. -#[derive(Deserialize, Debug, Clone)] -#[serde(deny_unknown_fields)] -pub struct SeedRecordConfig { - /// Exact host to seed. - pub host: String, - /// Preloaded endpoint list for this host. - pub endpoints: Vec, -} - -/// Serialisable policy kind. -#[derive(Deserialize, Debug, Clone)] -#[serde(rename_all = "snake_case")] -pub enum PolicyKind { - Standard, - OpenMulti, -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn default_binds_are_explicit_dual_stack() { - let binds = Config::default_binds(); - - assert_eq!(binds.len(), 2); - assert_eq!(binds[0].to_string(), "inet://0.0.0.0:4433"); - assert_eq!(binds[1].to_string(), "inet://[::]:4433"); - } - - #[test] - fn config_parses_bare_socket_bind_patterns() { - let config: Config = toml::from_str( - r#" - binds = ["0.0.0.0:4433", "[::]:4433"] - "#, - ) - .expect("config should parse"); - - assert_eq!(config.binds.len(), 2); - assert_eq!(config.binds[0].to_string(), "inet://0.0.0.0:4433"); - assert_eq!(config.binds[1].to_string(), "inet://[::]:4433"); - } - - #[test] - fn legacy_listen_field_is_rejected() { - let error = toml::from_str::( - r#" - listen = "0.0.0.0:4433" - "#, - ) - .expect_err("legacy listen should be rejected"); - - assert!(error.to_string().contains("unknown field `listen`")); - } -} diff --git a/src/bin/ddns-server/error.rs b/src/bin/ddns-server/error.rs deleted file mode 100644 index 6a46879..0000000 --- a/src/bin/ddns-server/error.rs +++ /dev/null @@ -1,111 +0,0 @@ -use std::collections::HashMap; - -#[derive(Debug, snafu::Snafu)] -#[snafu(module, visibility(pub(crate)))] -pub enum AppError { - #[snafu(display("missing host parameter"))] - MissingHostParam, - #[snafu(display("invalid host"))] - InvalidHost, - #[snafu(display("forbidden host"))] - ForbiddenHost, - #[snafu(display("domain not allowed"))] - DomainNotAllowed, - #[snafu(display("host mismatch"))] - HostMismatch, - #[snafu(display("missing client certificate"))] - MissingClientCertificate, - #[snafu(display("client certificate domain not allowed"))] - ClientCertDomainNotAllowed, - #[snafu(display("invalid DNS packet: {message}"))] - InvalidDnsPacket { message: String }, - #[snafu(display("publisher certificate selector is invalid"))] - PublisherCertificateSelector { - source: dhttp_identity::identity::ExtractDhttpSubjectKeyIdentifierError, - }, - #[snafu(display("endpoint record selector does not match publisher certificate selector"))] - EndpointSelectorMismatch, - #[snafu(display("no answers in packet"))] - NoAnswersInPacket, - #[snafu(display("signature required"))] - SignatureRequired, - #[snafu(display("invalid signature"))] - InvalidSignature, - #[snafu(display("redis error: {message}"))] - Redis { message: String }, -} - -impl AppError { - pub fn status(&self) -> http::StatusCode { - match self { - AppError::MissingHostParam => http::StatusCode::BAD_REQUEST, - AppError::InvalidHost => http::StatusCode::BAD_REQUEST, - AppError::ForbiddenHost => http::StatusCode::BAD_REQUEST, - AppError::DomainNotAllowed => http::StatusCode::FORBIDDEN, - AppError::HostMismatch => http::StatusCode::BAD_REQUEST, - AppError::MissingClientCertificate => http::StatusCode::UNAUTHORIZED, - AppError::ClientCertDomainNotAllowed => http::StatusCode::FORBIDDEN, - AppError::InvalidDnsPacket { .. } => http::StatusCode::BAD_REQUEST, - AppError::PublisherCertificateSelector { .. } => http::StatusCode::BAD_REQUEST, - AppError::EndpointSelectorMismatch => http::StatusCode::BAD_REQUEST, - AppError::NoAnswersInPacket => http::StatusCode::UNPROCESSABLE_ENTITY, - AppError::SignatureRequired => http::StatusCode::BAD_REQUEST, - AppError::InvalidSignature => http::StatusCode::BAD_REQUEST, - AppError::Redis { .. } => http::StatusCode::SERVICE_UNAVAILABLE, - } - } -} - -pub fn normalize_host_allowlist(entries: &[String]) -> Result, AppError> { - let mut allowlist = entries - .iter() - .map(|entry| normalize_host_raw(entry)) - .collect::, _>>()?; - allowlist.sort(); - allowlist.dedup(); - Ok(allowlist) -} - -pub fn normalize_host(host: &str, allowlist: &[String]) -> Result { - let host = normalize_host_raw(host)?; - if allowlist - .iter() - .any(|suffix| host_matches_suffix(&host, suffix)) - { - Ok(host) - } else { - Err(AppError::DomainNotAllowed) - } -} - -pub fn normalize_host_raw(host: &str) -> Result { - let host = host.trim(); - if host.is_empty() { - return Err(AppError::InvalidHost); - } - if host.contains('*') { - return Err(AppError::ForbiddenHost); - } - - let host = match host.rsplit_once(':') { - Some((h, port)) if port.chars().all(|c| c.is_ascii_digit()) => h, - _ => host, - }; - let host = host.strip_suffix('.').unwrap_or(host); - let host = idna::domain_to_ascii(host).map_err(|_| AppError::InvalidHost)?; - Ok(host.to_ascii_lowercase()) -} - -pub fn parse_query_params(uri: &http::Uri) -> HashMap { - let query = uri.query().unwrap_or(""); - url::form_urlencoded::parse(query.as_bytes()) - .into_owned() - .collect() -} - -fn host_matches_suffix(host: &str, suffix: &str) -> bool { - host == suffix - || host - .strip_suffix(suffix) - .is_some_and(|prefix| prefix.ends_with('.')) -} diff --git a/src/bin/ddns-server/geo.rs b/src/bin/ddns-server/geo.rs deleted file mode 100644 index 3cedb9a..0000000 --- a/src/bin/ddns-server/geo.rs +++ /dev/null @@ -1,196 +0,0 @@ -use std::{io, net::IpAddr, path::Path}; - -use maxminddb::{Reader, geoip2}; - -#[derive(Clone, Debug)] -pub struct GeoPoint { - pub latitude: f64, - pub longitude: f64, - pub accuracy_radius_km: u16, -} - -#[derive(Clone, Debug, Default)] -pub struct GeoTraits { - pub country: Option, - pub city: Option, - pub asn: Option, - pub point: Option, -} - -#[derive(Debug)] -pub struct GeoResolver { - city: Reader>, - asn: Reader>, - city_distance_routing: bool, - max_accuracy_radius_km: u32, -} - -impl GeoResolver { - pub fn open( - city_db: &Path, - asn_db: &Path, - city_distance_routing: bool, - max_accuracy_radius_km: u32, - ) -> io::Result { - let city = Reader::open_readfile(city_db).map_err(io::Error::other)?; - let asn = Reader::open_readfile(asn_db).map_err(io::Error::other)?; - - Ok(Self { - city, - asn, - city_distance_routing, - max_accuracy_radius_km, - }) - } - - pub fn lookup_traits(&self, ip: IpAddr) -> GeoTraits { - GeoTraits { - country: self.lookup_country(ip), - city: self.lookup_city(ip), - asn: self.lookup_asn(ip), - point: self.lookup_point(ip), - } - } - - pub fn city_build_epoch(&self) -> u64 { - self.city.metadata.build_epoch - } - - pub fn asn_build_epoch(&self) -> u64 { - self.asn.metadata.build_epoch - } - - pub fn lookup_country(&self, ip: IpAddr) -> Option { - let city = self.city.lookup::(ip).ok()??; - city.country?.iso_code.map(str::to_owned) - } - - pub fn lookup_asn(&self, ip: IpAddr) -> Option { - let asn = self.asn.lookup::(ip).ok()??; - asn.autonomous_system_number - } - - pub fn lookup_city(&self, ip: IpAddr) -> Option { - let city = self.city.lookup::(ip).ok()??; - city.city?.names?.get("en").copied().map(str::to_owned) - } - - pub fn lookup_point(&self, ip: IpAddr) -> Option { - let city = self.city.lookup::(ip).ok()??; - let location = city.location?; - let latitude = location.latitude?; - let longitude = location.longitude?; - let accuracy_radius_km = location.accuracy_radius?; - - Some(GeoPoint { - latitude, - longitude, - accuracy_radius_km, - }) - } - - pub fn geo_distance_km(&self, left: &GeoPoint, right: &GeoPoint) -> Option { - if !self.city_distance_routing { - return None; - } - - if u32::from(left.accuracy_radius_km) > self.max_accuracy_radius_km - || u32::from(right.accuracy_radius_km) > self.max_accuracy_radius_km - { - return None; - } - - Some(haversine_distance_km( - left.latitude, - left.longitude, - right.latitude, - right.longitude, - )) - } -} - -fn haversine_distance_km( - left_latitude: f64, - left_longitude: f64, - right_latitude: f64, - right_longitude: f64, -) -> f64 { - let earth_radius_km = 6_371.0; - let lat_delta = (right_latitude - left_latitude).to_radians(); - let lon_delta = (right_longitude - left_longitude).to_radians(); - let left_latitude = left_latitude.to_radians(); - let right_latitude = right_latitude.to_radians(); - - let haversine = (lat_delta / 2.0).sin().powi(2) - + left_latitude.cos() * right_latitude.cos() * (lon_delta / 2.0).sin().powi(2); - let arc = 2.0 * haversine.sqrt().asin(); - - earth_radius_km * arc -} - -#[cfg(test)] -mod tests { - use std::{net::IpAddr, path::PathBuf, str::FromStr}; - - use super::*; - - fn fixture_geo_resolver() -> GeoResolver { - let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - let city_db = manifest_dir.join("geoip/GeoLite2-City.mmdb"); - let asn_db = manifest_dir.join("geoip/GeoLite2-ASN.mmdb"); - - GeoResolver::open(&city_db, &asn_db, true, 100).expect("fixture geo db should open") - } - - #[test] - fn bundled_geolite_maps_real_ips_to_expected_country_and_asn() { - let geo = fixture_geo_resolver(); - let cases = [ - ("8.8.8.8", "US", 15169_u32), - ("223.5.5.5", "CN", 45102_u32), - ("80.80.80.80", "NL", 60679_u32), - ("168.95.1.1", "TW", 3462_u32), - ("200.160.0.8", "BR", 22548_u32), - ]; - - for (candidate, expected_country, expected_asn) in cases { - let ip = IpAddr::from_str(candidate).unwrap(); - let traits = geo.lookup_traits(ip); - - assert_eq!(traits.country.as_deref(), Some(expected_country)); - assert_eq!(traits.asn, Some(expected_asn)); - assert!( - traits.point.is_some(), - "{candidate} should resolve to a city point" - ); - } - } - - #[test] - fn bundled_geolite_exposes_city_name_separately_from_accuracy_radius() { - let geo = fixture_geo_resolver(); - let ip = IpAddr::from_str("223.5.5.5").unwrap(); - let traits = geo.lookup_traits(ip); - - assert_eq!(traits.country.as_deref(), Some("CN")); - assert_eq!(traits.city.as_deref(), Some("Hangzhou")); - assert_eq!( - traits.point.as_ref().map(|point| point.accuracy_radius_km), - Some(20) - ); - } - - #[test] - fn bundled_geolite_may_have_coordinates_without_city_name() { - let geo = fixture_geo_resolver(); - let ip = IpAddr::from_str("168.95.1.1").unwrap(); - let traits = geo.lookup_traits(ip); - - assert_eq!(traits.country.as_deref(), Some("TW")); - assert_eq!(traits.city, None); - assert_eq!( - traits.point.as_ref().map(|point| point.accuracy_radius_km), - Some(200) - ); - } -} diff --git a/src/bin/ddns-server/lookup/http.rs b/src/bin/ddns-server/lookup/http.rs deleted file mode 100644 index 9e68088..0000000 --- a/src/bin/ddns-server/lookup/http.rs +++ /dev/null @@ -1,109 +0,0 @@ -use std::{any::Any, convert::Infallible, net::IpAddr, sync::Arc}; - -use h3x::{connection::ConnectionState, dhttp::message::MessageStreamError, quic}; -use http_body_util::{Full, combinators::UnsyncBoxBody}; -use tracing::debug; - -use super::query::{LookupResult, perform_lookup}; -use crate::{ - error::{AppError, parse_query_params}, - storage::AppState, -}; - -pub type Request = http::Request>; -pub type Response = http::Response>; - -fn request_source_ip(request: &Request) -> Option { - let connection = request - .extensions() - .get::>>()? - .clone(); - let quic = connection.quic(); - let dquic = (quic.as_ref() as &dyn Any).downcast_ref::()?; - let ctx = dquic.path_context().ok()?; - - ctx.paths::>() - .into_iter() - .next() - .map(|(pathway, _)| pathway.remote().addr().ip()) -} -pub fn body_response(status: http::StatusCode, body: impl Into) -> Response { - http::Response::builder() - .status(status) - .body(Full::new(body.into())) - .expect("response parts must be valid") -} - -pub fn write_error(err: AppError) -> Response { - debug!( - status = %err.status(), - error = %err, - "writing error response" - ); - body_response(err.status(), bytes::Bytes::from(err.to_string())) -} - -// --------------------------------------------------------------------------- -// LookupSvc -// --------------------------------------------------------------------------- - -#[derive(Clone)] -pub struct LookupSvc { - pub state: AppState, -} - -/// Handle a lookup request. -/// -/// Always returns multi-record binary body: -/// `[u32 count BE]([u32 dns_len BE][dns][u32 cert_len BE][cert])*` -/// with header `x-record-format: multi`. -/// -/// Optional query param `limit=N` caps the number of records returned. -/// Dynamic records are newest-first; configured seed records are appended after them. -pub async fn lookup_with_cert(state: AppState, request: Request) -> Response { - let params = parse_query_params(request.uri()); - let Some(host) = params.get("host") else { - return write_error(AppError::MissingHostParam); - }; - let source_ip = request_source_ip(&request); - - let limit: Option = params - .get("limit") - .and_then(|v| v.parse::().ok()) - .filter(|&n| n > 0); - - debug!(host = %host, limit, ?source_ip, "lookup.request"); - - match perform_lookup(&state, host, limit, source_ip).await { - Ok(LookupResult::NotFound) => { - debug!(host = %host, "lookup.not_found"); - body_response( - http::StatusCode::NOT_FOUND, - bytes::Bytes::from_static(b"Not Found"), - ) - } - - Ok(LookupResult::Multi(resp)) => { - let body = resp.encode(); - debug!(host = %host, records = resp.records.len(), "lookup.found"); - let mut response = body_response(http::StatusCode::OK, bytes::Bytes::from(body)); - response.headers_mut().insert( - http::HeaderName::from_static("x-record-format"), - http::HeaderValue::from_static("multi"), - ); - response - } - - Err(e) => write_error(e), - } -} - -impl LookupSvc { - pub fn call( - &self, - request: Request, - ) -> impl Future> + Send + 'static { - let state = self.state.clone(); - async move { Ok(lookup_with_cert(state, request).await) } - } -} diff --git a/src/bin/ddns-server/lookup/mod.rs b/src/bin/ddns-server/lookup/mod.rs deleted file mode 100644 index 62e795a..0000000 --- a/src/bin/ddns-server/lookup/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod http; -pub(crate) mod query; -mod ranking; - -pub use http::{LookupSvc, Request, Response, body_response, write_error}; - -#[cfg(test)] -mod tests; diff --git a/src/bin/ddns-server/lookup/query.rs b/src/bin/ddns-server/lookup/query.rs deleted file mode 100644 index 8db1328..0000000 --- a/src/bin/ddns-server/lookup/query.rs +++ /dev/null @@ -1,263 +0,0 @@ -use std::{collections::HashSet, hash::Hash, net::IpAddr}; - -use ddns::core::wire::{MultiResponse, ResponseRecord}; -use deadpool_redis::redis::{self, AsyncCommands}; -use tracing::debug; - -use super::ranking::{ - LOOKUP_CANDIDATE_CAP_ALL, LOOKUP_CANDIDATE_CAP_ASN, LOOKUP_CANDIDATE_CAP_COUNTRY, - LOOKUP_CANDIDATE_CAP_TOTAL, normalize_lookup_records, request_source_geo_traits, - sort_lookup_records, sort_lookup_records_with_geo, -}; -use crate::{ - error::{AppError, normalize_host}, - geo::GeoTraits, - storage::{ - AppState, Storage, StoredRecord, redis_all_index_key, redis_asn_index_key, - redis_blacklist_key, redis_country_index_key, redis_primary_key, unix_now_secs, - }, -}; - -pub enum LookupResult { - NotFound, - /// Multiple records, newest-first. - Multi(MultiResponse), -} -fn candidate_total_cap(limit: Option) -> usize { - limit - .unwrap_or(LOOKUP_CANDIDATE_CAP_TOTAL) - .max(LOOKUP_CANDIDATE_CAP_TOTAL) -} - -fn all_candidate_cap(total_cap: usize, source_traits: Option<&GeoTraits>) -> usize { - let has_geo_buckets = source_traits - .is_some_and(|traits| traits.asn.is_some() || traits.country.as_deref().is_some()); - - if has_geo_buckets { - LOOKUP_CANDIDATE_CAP_ALL.min(total_cap) - } else { - total_cap - } -} - -fn push_unique_candidates( - candidates: &mut Vec, - seen: &mut HashSet, - source: impl IntoIterator, - total_cap: usize, -) where - T: Clone + Eq + Hash, -{ - for item in source { - if candidates.len() >= total_cap { - break; - } - - if seen.insert(item.clone()) { - candidates.push(item); - } - } -} -pub async fn perform_lookup( - state: &AppState, - host: &str, - limit: Option, - source_ip: Option, -) -> Result { - let host = normalize_host(host, state.host_allowlist.as_ref())?; - perform_lookup_multi(state, &host, limit, source_ip).await -} - -async fn perform_lookup_multi( - state: &AppState, - host: &str, - limit: Option, - source_ip: Option, -) -> Result { - let source_traits = request_source_geo_traits(source_ip, state.geo.as_deref()); - let candidate_total = candidate_total_cap(limit); - let candidate_all = all_candidate_cap(candidate_total, source_traits.as_ref()); - - let dynamic_records = match &state.storage { - Storage::Redis(redis) => { - let mut conn = redis.read.get().await.map_err(|e| AppError::Redis { - message: e.to_string(), - })?; - - if redis_host_blacklisted(&mut *conn, host).await? { - debug!(host = %host, "lookup.blacklisted"); - return Ok(LookupResult::NotFound); - } - - let now_secs = unix_now_secs(); - let mut candidate_fingerprints = Vec::new(); - let mut seen_fingerprints = HashSet::new(); - - if let Some(asn) = source_traits.as_ref().and_then(|traits| traits.asn) { - let index_key = redis_asn_index_key(host, asn); - let members: Vec = conn - .zrevrange( - &index_key, - 0isize, - LOOKUP_CANDIDATE_CAP_ASN.saturating_sub(1) as isize, - ) - .await - .map_err(|e| AppError::Redis { - message: e.to_string(), - })?; - - push_unique_candidates( - &mut candidate_fingerprints, - &mut seen_fingerprints, - members, - candidate_total, - ); - } - - if let Some(country) = source_traits - .as_ref() - .and_then(|traits| traits.country.as_deref()) - { - let index_key = redis_country_index_key(host, country); - let members: Vec = conn - .zrevrange( - &index_key, - 0isize, - LOOKUP_CANDIDATE_CAP_COUNTRY.saturating_sub(1) as isize, - ) - .await - .map_err(|e| AppError::Redis { - message: e.to_string(), - })?; - - push_unique_candidates( - &mut candidate_fingerprints, - &mut seen_fingerprints, - members, - candidate_total, - ); - } - - let all_index_key = redis_all_index_key(host); - let all_members: Vec = conn - .zrevrange( - &all_index_key, - 0isize, - candidate_all.saturating_sub(1) as isize, - ) - .await - .map_err(|e| AppError::Redis { - message: e.to_string(), - })?; - - push_unique_candidates( - &mut candidate_fingerprints, - &mut seen_fingerprints, - all_members, - candidate_total, - ); - - let mut records = Vec::new(); - for fingerprint in candidate_fingerprints { - let primary_key = redis_primary_key(host, &fingerprint); - let member: Option> = - conn.get(&primary_key).await.map_err(|e| AppError::Redis { - message: e.to_string(), - })?; - - let Some(member) = member else { - continue; - }; - let Some(record) = StoredRecord::decode(&member) else { - continue; - }; - if record.expire_unix_secs > now_secs { - records.push(ResponseRecord::new( - record.signature_fields, - record.dns, - record.cert, - )); - } - } - - records - } - Storage::Memory(mem) => { - if mem.is_blacklisted(host) { - debug!(host = %host, "lookup.blacklisted"); - return Ok(LookupResult::NotFound); - } - - let now = tokio::time::Instant::now(); - if let Some(mut entry) = mem.records.get_mut(host) { - entry.retain_active(now); - let candidate_fingerprints = entry.collect_candidates( - source_traits - .as_ref() - .and_then(|traits| traits.country.as_deref()), - source_traits.as_ref().and_then(|traits| traits.asn), - candidate_total, - LOOKUP_CANDIDATE_CAP_ASN, - LOOKUP_CANDIDATE_CAP_COUNTRY, - candidate_all, - ); - - candidate_fingerprints - .into_iter() - .filter_map(|fingerprint| { - entry.records.get(&fingerprint).map(|record| { - ResponseRecord::new( - record.signature_fields.clone(), - record.dns_bytes.clone(), - record.cert_bytes.clone(), - ) - }) - }) - .collect::>() - } else { - vec![] - } - } - }; - - let normalized_dynamic_records = normalize_lookup_records(dynamic_records); - let mut records = if let Some(geo) = state.geo.as_deref() { - sort_lookup_records_with_geo(normalized_dynamic_records, source_ip, geo) - } else { - sort_lookup_records(normalized_dynamic_records, source_ip) - }; - - let should_append_seeds = records.is_empty() || limit.is_some_and(|max| records.len() < max); - if should_append_seeds && let Some(seed_records) = state.seed_records.get(host) { - let seeds = if let Some(geo) = state.geo.as_deref() { - sort_lookup_records_with_geo(seed_records.iter().cloned().collect(), source_ip, geo) - } else { - sort_lookup_records(seed_records.iter().cloned().collect(), source_ip) - }; - records.extend(seeds); - } - - let records = normalize_lookup_records(records); - let records = if let Some(limit) = limit { - records.into_iter().take(limit).collect::>() - } else { - records - }; - - if records.is_empty() { - Ok(LookupResult::NotFound) - } else { - Ok(LookupResult::Multi(MultiResponse::new(records))) - } -} - -pub(super) async fn redis_host_blacklisted(conn: &mut C, host: &str) -> Result -where - C: redis::aio::ConnectionLike + Send + Sync, -{ - conn.sismember(redis_blacklist_key(), host) - .await - .map_err(|e| AppError::Redis { - message: e.to_string(), - }) -} diff --git a/src/bin/ddns-server/lookup/ranking.rs b/src/bin/ddns-server/lookup/ranking.rs deleted file mode 100644 index 1274249..0000000 --- a/src/bin/ddns-server/lookup/ranking.rs +++ /dev/null @@ -1,254 +0,0 @@ -use std::{ - cmp::Ordering, - collections::{HashMap, HashSet}, - net::{IpAddr, SocketAddr}, -}; - -use ddns::core::{ - MdnsPacket, - parser::{packet::be_packet, record::RData}, - wire::ResponseRecord, -}; - -use crate::{ - geo::{GeoResolver, GeoTraits}, - storage::LookupRecord, -}; - -type EndpointKey = (SocketAddr, Option); - -pub(super) const LOOKUP_CANDIDATE_CAP_TOTAL: usize = 64; -pub(super) const LOOKUP_CANDIDATE_CAP_ASN: usize = 16; -pub(super) const LOOKUP_CANDIDATE_CAP_COUNTRY: usize = 16; -pub(super) const LOOKUP_CANDIDATE_CAP_ALL: usize = 32; - -// GEO-aware ranking dimensions. Final ordering still falls back to the original -// record index so we keep lookups stable when all computed dimensions tie. -#[derive(Clone, Copy, Debug, PartialEq)] -pub(super) struct GeoSortKey { - pub(super) same_country: bool, - pub(super) same_asn: bool, - pub(super) family_match: bool, - pub(super) same_city: bool, - pub(super) load: Option, - pub(super) geo_distance: Option, -} - -pub(super) fn normalize_lookup_records(records: Vec) -> Vec { - let mut normalized = Vec::new(); - let mut seen = HashSet::new(); - - for record in records { - if !record.signature_fields.is_empty() { - normalized.push(record); - continue; - } - - let Ok((_, packet)) = be_packet(&record.dns) else { - normalized.push(record); - continue; - }; - - let mut emitted_endpoint = false; - - for answer in &packet.answers { - let RData::E(endpoint) = answer.data() else { - continue; - }; - - emitted_endpoint = true; - let key: EndpointKey = (endpoint.addr(), endpoint.agent_addr()); - - if !seen.insert(key) { - continue; - } - - let mut hosts = HashMap::new(); - hosts.insert(answer.name().to_string(), vec![endpoint.clone()]); - normalized.push(ResponseRecord::unsigned( - MdnsPacket::answer(0, &hosts).to_bytes(), - record.cert.clone(), - )); - } - - if !emitted_endpoint { - normalized.push(record); - } - } - - normalized -} - -pub(super) fn lookup_endpoint(dns_bytes: &[u8]) -> Option<(SocketAddr, Option)> { - let (_, packet) = be_packet(dns_bytes).ok()?; - packet - .answers - .iter() - .find_map(|answer| match answer.data() { - RData::E(endpoint) => Some((endpoint.addr(), endpoint.load())), - _ => None, - }) -} - -// Fallback ordering when GEO routing is disabled: prefer matching address family, -// then lower load, and finally preserve input order. We intentionally avoid -// IP prefix heuristics here because they are not reliable on the public Internet. -pub(super) fn sort_lookup_records( - records: Vec, - source_ip: Option, -) -> Vec { - let mut decorated = records - .into_iter() - .enumerate() - .map(|(index, record)| { - let sort_key = lookup_endpoint(&record.dns).map(|(endpoint, load)| { - let family_match = source_ip - .map(|source| source.is_ipv4() == endpoint.ip().is_ipv4()) - .unwrap_or(false); - - (family_match, load) - }); - (sort_key, index, record) - }) - .collect::>(); - - decorated.sort_by(|(left_key, left_index, _), (right_key, right_index, _)| { - match (left_key, right_key) { - (Some((left_family, left_load)), Some((right_family, right_load))) => right_family - .cmp(left_family) - .then_with(|| match (left_load, right_load) { - (Some(left), Some(right)) => left.partial_cmp(right).unwrap_or(Ordering::Equal), - (Some(_), None) => Ordering::Less, - (None, Some(_)) => Ordering::Greater, - (None, None) => Ordering::Equal, - }), - (Some(_), None) => Ordering::Less, - (None, Some(_)) => Ordering::Greater, - (None, None) => Ordering::Equal, - } - .then_with(|| left_index.cmp(right_index)) - }); - - decorated.into_iter().map(|(_, _, record)| record).collect() -} - -pub(super) fn request_source_geo_traits( - source_ip: Option, - geo: Option<&GeoResolver>, -) -> Option { - Some(geo?.lookup_traits(source_ip?)) -} - -fn lookup_endpoint_geo_traits( - dns_bytes: &[u8], - geo: &GeoResolver, -) -> Option<(SocketAddr, Option, GeoTraits)> { - let (endpoint, load) = lookup_endpoint(dns_bytes)?; - Some((endpoint, load, geo.lookup_traits(endpoint.ip()))) -} - -fn compare_optional_partial(left: Option, right: Option) -> Ordering { - match (left, right) { - (Some(left), Some(right)) => left.partial_cmp(&right).unwrap_or(Ordering::Equal), - _ => Ordering::Equal, - } -} - -// GEO ordering is layered rather than score-based: -// country > ASN > address family > city name > lower load > shorter GEO distance. -// Missing optional values do not penalize a candidate; they simply skip that layer. -pub(super) fn compare_geo_sort_keys(left: GeoSortKey, right: GeoSortKey) -> Ordering { - right - .same_country - .cmp(&left.same_country) - .then_with(|| right.same_asn.cmp(&left.same_asn)) - .then_with(|| right.family_match.cmp(&left.family_match)) - .then_with(|| right.same_city.cmp(&left.same_city)) - .then_with(|| compare_optional_partial(left.load, right.load)) - .then_with(|| compare_optional_partial(left.geo_distance, right.geo_distance)) -} - -// Build the per-endpoint GEO ranking tuple. City name only participates when both -// sides have a name and already match on country; coordinate distance only -// participates when GeoResolver accepts both accuracy radii. -pub(super) fn build_geo_sort_key( - source_ip: Option, - source_traits: Option<&GeoTraits>, - endpoint: SocketAddr, - load: Option, - endpoint_traits: &GeoTraits, - geo: &GeoResolver, -) -> GeoSortKey { - let family_match = source_ip - .map(|source| source.is_ipv4() == endpoint.ip().is_ipv4()) - .unwrap_or(false); - - let same_country = source_traits - .and_then(|source| source.country.as_deref()) - .zip(endpoint_traits.country.as_deref()) - .is_some_and(|(source, target)| source == target); - - let same_asn = source_traits - .and_then(|source| source.asn) - .zip(endpoint_traits.asn) - .is_some_and(|(source, target)| source == target); - - let same_city = same_country - && source_traits - .and_then(|source| source.city.as_deref()) - .zip(endpoint_traits.city.as_deref()) - .is_some_and(|(source, target)| source == target); - - let geo_distance = source_traits - .and_then(|source| source.point.as_ref()) - .zip(endpoint_traits.point.as_ref()) - .and_then(|(source, target)| geo.geo_distance_km(source, target)); - - GeoSortKey { - same_country, - same_asn, - family_match, - same_city, - load, - geo_distance, - } -} -pub(super) fn sort_lookup_records_with_geo( - records: Vec, - source_ip: Option, - geo: &GeoResolver, -) -> Vec { - let source_traits = request_source_geo_traits(source_ip, Some(geo)); - - let mut decorated = records - .into_iter() - .enumerate() - .map(|(index, record)| { - let sort_key = lookup_endpoint_geo_traits(&record.dns, geo).map( - |(endpoint, load, endpoint_traits)| { - build_geo_sort_key( - source_ip, - source_traits.as_ref(), - endpoint, - load, - &endpoint_traits, - geo, - ) - }, - ); - (sort_key, index, record) - }) - .collect::>(); - - decorated.sort_by(|(left_key, left_index, _), (right_key, right_index, _)| { - match (left_key, right_key) { - (Some(left_key), Some(right_key)) => compare_geo_sort_keys(*left_key, *right_key), - (Some(_), None) => Ordering::Less, - (None, Some(_)) => Ordering::Greater, - (None, None) => Ordering::Equal, - } - .then_with(|| left_index.cmp(right_index)) - }); - - decorated.into_iter().map(|(_, _, record)| record).collect() -} diff --git a/src/bin/ddns-server/lookup/tests.rs b/src/bin/ddns-server/lookup/tests.rs deleted file mode 100644 index 9b26df4..0000000 --- a/src/bin/ddns-server/lookup/tests.rs +++ /dev/null @@ -1,427 +0,0 @@ -use std::{ - cmp::Ordering, - collections::HashMap, - net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}, - path::PathBuf, - sync::Arc, -}; - -use ddns::core::{MdnsEndpoint, MdnsPacket, signature::SignatureFields, wire::ResponseRecord}; -use deadpool_redis::redis; - -use super::{ - query::{LookupResult, perform_lookup, redis_host_blacklisted}, - ranking::{ - GeoSortKey, build_geo_sort_key, compare_geo_sort_keys, lookup_endpoint, - normalize_lookup_records, sort_lookup_records, sort_lookup_records_with_geo, - }, -}; -use crate::{ - geo::{GeoPoint, GeoResolver, GeoTraits}, - storage::{AppState, LookupRecord, MemoryStorage, SeedRecords, Storage, redis_blacklist_key}, -}; - -fn fixture_geo_resolver() -> GeoResolver { - let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - let city_db = manifest_dir.join("geoip/GeoLite2-City.mmdb"); - let asn_db = manifest_dir.join("geoip/GeoLite2-ASN.mmdb"); - - GeoResolver::open(&city_db, &asn_db, true, 100).expect("fixture geo db should open") -} - -fn lookup_record(host: &str, addr: SocketAddr, load: Option) -> LookupRecord { - let mut endpoint = match addr { - SocketAddr::V4(addr) => MdnsEndpoint::direct_v4(addr), - SocketAddr::V6(addr) => MdnsEndpoint::direct_v6(addr), - }; - endpoint.set_load(load); - - let mut hosts = HashMap::new(); - hosts.insert(host.to_string(), vec![endpoint]); - - ResponseRecord::unsigned(MdnsPacket::answer(0, &hosts).to_bytes(), Vec::new()) -} - -struct FakeRedis { - response: redis::Value, - packed_commands: Vec>, -} - -impl redis::aio::ConnectionLike for FakeRedis { - fn req_packed_command<'a>( - &'a mut self, - cmd: &'a redis::Cmd, - ) -> redis::RedisFuture<'a, redis::Value> { - self.packed_commands.push(cmd.get_packed_command()); - let response = self.response.clone(); - Box::pin(async move { Ok(response) }) - } - - fn req_packed_commands<'a>( - &'a mut self, - _cmd: &'a redis::Pipeline, - _offset: usize, - _count: usize, - ) -> redis::RedisFuture<'a, Vec> { - Box::pin(async move { Ok(Vec::new()) }) - } - - fn get_db(&self) -> i64 { - 0 - } -} - -#[tokio::test] -async fn redis_host_blacklisted_queries_external_blacklist_set() { - let mut redis = FakeRedis { - response: redis::Value::Int(1), - packed_commands: Vec::new(), - }; - - let blacklisted = redis_host_blacklisted(&mut redis, "blocked.example.genmeta.net") - .await - .unwrap(); - - assert!(blacklisted); - assert_eq!(redis.packed_commands.len(), 1); - let command = String::from_utf8(redis.packed_commands.remove(0)).unwrap(); - assert!(command.contains("SISMEMBER")); - assert!(command.contains(redis_blacklist_key())); - assert!(command.contains("blocked.example.genmeta.net")); -} - -#[tokio::test] -async fn memory_blacklist_returns_not_found_before_seed_records() { - let host = "blocked.example.genmeta.net"; - let mut seed_records = HashMap::new(); - seed_records.insert( - host.to_string(), - vec![lookup_record( - host, - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), - None, - )], - ); - let state = AppState { - storage: Storage::Memory(MemoryStorage::with_blacklist([host.to_string()])), - host_allowlist: Arc::new(vec!["genmeta.net".to_string()]), - require_signature: false, - ttl_secs: 30, - policies: Arc::new(crate::policy::DomainPolicies::default()), - seed_records: SeedRecords::new(seed_records), - geo: None, - }; - - let result = perform_lookup(&state, host, None, None).await.unwrap(); - - assert!(matches!(result, LookupResult::NotFound)); -} - -#[test] -fn normalize_lookup_records_keeps_signed_packets_whole() { - let mut record = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), - None, - ); - record.signature_fields = SignatureFields { - content_digest: b"sha-256=:abc:".to_vec(), - signature_input: b"dns=(\"content-digest\");created=1;keyid=\"sha256:abc\";alg=\"ed25519\"" - .to_vec(), - signature: b"dns=:sig:".to_vec(), - }; - - let normalized = normalize_lookup_records(vec![record.clone()]); - - assert_eq!(normalized.len(), 1); - assert_eq!(normalized[0], record); -} - -#[test] -fn compare_geo_sort_keys_follows_documented_priority() { - let best = GeoSortKey { - same_country: true, - same_asn: true, - family_match: true, - same_city: true, - load: Some(0.2), - geo_distance: Some(20.0), - }; - let worse_load = GeoSortKey { - load: Some(0.8), - ..best - }; - let worse_family = GeoSortKey { - same_asn: true, - family_match: false, - same_city: true, - load: Some(0.1), - geo_distance: Some(1.0), - ..best - }; - let worse_city = GeoSortKey { - same_city: false, - load: Some(0.1), - geo_distance: Some(1.0), - ..best - }; - let worse_asn = GeoSortKey { - same_asn: false, - family_match: true, - same_city: true, - load: Some(0.1), - geo_distance: Some(1.0), - ..best - }; - let worse_country = GeoSortKey { - same_country: false, - same_asn: true, - family_match: true, - same_city: false, - load: Some(0.1), - geo_distance: Some(1.0), - }; - - assert_eq!(compare_geo_sort_keys(best, worse_load), Ordering::Less); - assert_eq!(compare_geo_sort_keys(best, worse_family), Ordering::Less); - assert_eq!(compare_geo_sort_keys(best, worse_city), Ordering::Less); - assert_eq!(compare_geo_sort_keys(best, worse_asn), Ordering::Less); - assert_eq!(compare_geo_sort_keys(best, worse_country), Ordering::Less); -} - -#[test] -fn compare_geo_sort_keys_skips_unknown_dimensions() { - let known_distance = GeoSortKey { - same_country: true, - same_asn: true, - family_match: true, - same_city: true, - load: Some(0.2), - geo_distance: Some(10.0), - }; - let missing_distance = GeoSortKey { - geo_distance: None, - ..known_distance - }; - let missing_load = GeoSortKey { - load: None, - ..known_distance - }; - - assert_eq!( - compare_geo_sort_keys(known_distance, missing_distance), - Ordering::Equal - ); - assert_eq!( - compare_geo_sort_keys(known_distance, missing_load), - Ordering::Equal - ); -} - -#[test] -fn sort_lookup_records_with_geo_prefers_same_source_endpoint_even_with_higher_load() { - let geo = fixture_geo_resolver(); - let source_ip = Some(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); - let matching = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), - Some(0.9), - ); - let non_matching = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 1, 1, 1), 3478)), - Some(0.1), - ); - - let sorted = - sort_lookup_records_with_geo(vec![non_matching, matching.clone()], source_ip, &geo); - - let (endpoint, _) = lookup_endpoint(&sorted[0].dns).expect("sorted record should decode"); - assert_eq!(endpoint.ip(), IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); -} - -#[test] -fn sort_lookup_records_without_geo_ignores_ip_prefix_and_prefers_lower_load() { - let source_ip = Some(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))); - let closer_prefix_higher_load = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 2), 3478)), - Some(0.9), - ); - let farther_prefix_lower_load = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), - Some(0.1), - ); - - let sorted = sort_lookup_records( - vec![closer_prefix_higher_load, farther_prefix_lower_load], - source_ip, - ); - - let (endpoint, _) = lookup_endpoint(&sorted[0].dns).expect("sorted record should decode"); - assert_eq!(endpoint.ip(), IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); -} - -#[test] -fn sort_lookup_records_with_geo_prefers_same_asn_then_same_country_on_real_ips() { - let geo = fixture_geo_resolver(); - let source_ip = Some(IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5))); - - let different_country = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 8, 8), 3478)), - Some(0.01), - ); - let same_country_different_asn = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(114, 114, 114, 114), 3478)), - Some(0.02), - ); - let same_country_same_asn = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(223, 5, 5, 5), 3478)), - Some(0.9), - ); - - let sorted = sort_lookup_records_with_geo( - vec![ - different_country, - same_country_different_asn, - same_country_same_asn, - ], - source_ip, - &geo, - ); - - let ordered_ips = sorted - .iter() - .map(|record| { - lookup_endpoint(&record.dns) - .expect("record should decode") - .0 - .ip() - }) - .collect::>(); - - assert_eq!( - ordered_ips, - vec![ - IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5)), - IpAddr::V4(Ipv4Addr::new(114, 114, 114, 114)), - IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), - ] - ); -} - -#[test] -fn sort_lookup_records_with_geo_prefers_same_country_over_lower_load_on_real_ips() { - let geo = fixture_geo_resolver(); - let source_ip = Some(IpAddr::V4(Ipv4Addr::new(114, 114, 114, 114))); - - let different_country_low_load = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(80, 80, 80, 80), 3478)), - Some(0.01), - ); - let same_country_higher_load = lookup_record( - "stun.example.com", - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(223, 5, 5, 5), 3478)), - Some(0.9), - ); - - let sorted = sort_lookup_records_with_geo( - vec![different_country_low_load, same_country_higher_load.clone()], - source_ip, - &geo, - ); - - let (endpoint, _) = lookup_endpoint(&sorted[0].dns).expect("sorted record should decode"); - assert_eq!(endpoint.ip(), IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5))); -} - -#[test] -fn build_geo_sort_key_ignores_city_distance_when_accuracy_is_too_large() { - let geo = fixture_geo_resolver(); - let source_traits = GeoTraits { - country: Some("CN".to_string()), - city: Some("Beijing".to_string()), - asn: Some(64512), - point: Some(GeoPoint { - latitude: 39.9, - longitude: 116.4, - accuracy_radius_km: 500, - }), - }; - let endpoint_traits = GeoTraits { - country: Some("CN".to_string()), - city: Some("Shanghai".to_string()), - asn: Some(64512), - point: Some(GeoPoint { - latitude: 31.2, - longitude: 121.5, - accuracy_radius_km: 10, - }), - }; - - let key = build_geo_sort_key( - Some(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))), - Some(&source_traits), - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(8, 8, 4, 4), 3478)), - Some(0.2), - &endpoint_traits, - &geo, - ); - - assert!(key.same_country); - assert!(key.same_asn); - assert!(!key.same_city); - assert_eq!(key.geo_distance, None); -} - -#[test] -fn build_geo_sort_key_prefers_same_city_when_available() { - let geo = fixture_geo_resolver(); - let source_traits = GeoTraits { - country: Some("CN".to_string()), - city: Some("Hangzhou".to_string()), - asn: Some(64512), - point: None, - }; - let same_city_traits = GeoTraits { - country: Some("CN".to_string()), - city: Some("Hangzhou".to_string()), - asn: Some(64513), - point: None, - }; - let different_city_traits = GeoTraits { - country: Some("CN".to_string()), - city: Some("Shanghai".to_string()), - asn: Some(64513), - point: None, - }; - - let same_city_key = build_geo_sort_key( - Some(IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5))), - Some(&source_traits), - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(223, 6, 6, 6), 3478)), - Some(0.9), - &same_city_traits, - &geo, - ); - let different_city_key = build_geo_sort_key( - Some(IpAddr::V4(Ipv4Addr::new(223, 5, 5, 5))), - Some(&source_traits), - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(114, 114, 114, 114), 3478)), - Some(0.1), - &different_city_traits, - &geo, - ); - - assert!(same_city_key.same_city); - assert!(!different_city_key.same_city); - assert_eq!( - compare_geo_sort_keys(same_city_key, different_city_key), - Ordering::Less - ); -} diff --git a/src/bin/ddns-server/main.rs b/src/bin/ddns-server/main.rs deleted file mode 100644 index 1a4264a..0000000 --- a/src/bin/ddns-server/main.rs +++ /dev/null @@ -1,445 +0,0 @@ -mod config; -mod error; -mod geo; -mod lookup; -mod ocsp; -mod policy; -mod publish; -mod storage; - -use std::{ - collections::HashMap, - io, - net::SocketAddr, - sync::Arc, - task::{Context, Poll}, - time::{SystemTime, UNIX_EPOCH}, -}; - -use clap::Parser; -use ddns::core::{MdnsEndpoint, MdnsPacket, wire::ResponseRecord}; -use futures::future::BoxFuture; -use h3x::{ - dquic::{ - Identity, Network, QuicEndpoint, - cert::handy::{ToCertificate, ToPrivateKey}, - server::ServerQuicConfig, - }, - endpoint::H3Endpoint, - hyper::TowerService, -}; -use rustls::{RootCertStore, server::WebPkiClientVerifier}; -use tracing::{info, level_filters::LevelFilter, warn}; -use tracing_subscriber::{prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt}; - -use crate::{ - config::{Config, Options, PolicyKind, SeedRecordConfig}, - geo::GeoResolver, - lookup::LookupSvc, - policy::{DomainPolicies, DomainPolicy, PolicyRule}, - publish::PublishSvc, - storage::{AppState, MemoryStorage, RedisStorage, SeedRecords, Storage}, -}; - -#[derive(Clone)] -struct DnsService { - publish: PublishSvc, - lookup: LookupSvc, -} - -impl tower_service::Service for DnsService { - type Response = lookup::Response; - type Error = io::Error; - type Future = BoxFuture<'static, Result>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, request: lookup::Request) -> Self::Future { - let method = request.method().clone(); - let path = request.uri().path().to_owned(); - let publish = self.publish.clone(); - let lookup = self.lookup.clone(); - Box::pin(async move { - match (method, path.as_str()) { - (http::Method::POST, "/publish") => match publish.call(request).await { - Ok(response) => Ok(response), - Err(never) => match never {}, - }, - (http::Method::GET, "/lookup") => match lookup.call(request).await { - Ok(response) => Ok(response), - Err(never) => match never {}, - }, - (_, "/publish" | "/lookup") => Ok(lookup::body_response( - http::StatusCode::METHOD_NOT_ALLOWED, - bytes::Bytes::from_static(b"Method Not Allowed"), - )), - _ => Ok(lookup::body_response( - http::StatusCode::NOT_FOUND, - bytes::Bytes::from_static(b"Not Found"), - )), - } - }) - } -} - -// --------------------------------------------------------------------------- -// TLS helpers -// --------------------------------------------------------------------------- - -fn load_root_store_from_pem(pem: &[u8]) -> io::Result { - let mut reader = std::io::Cursor::new(pem); - let certs = rustls_pemfile::certs(&mut reader) - .collect::, _>>() - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - - let mut store = RootCertStore::empty(); - store.add_parsable_certificates(certs); - Ok(store) -} - -fn build_seed_records( - seed_records: &[SeedRecordConfig], - allowlist: &[String], -) -> io::Result { - let mut records = HashMap::new(); - - for seed_record in seed_records { - if seed_record.endpoints.is_empty() { - continue; - } - - let host = error::normalize_host(&seed_record.host, allowlist) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; - - let endpoints = seed_record - .endpoints - .iter() - .map(|addr| match addr { - SocketAddr::V4(addr) => MdnsEndpoint::direct_v4(*addr), - SocketAddr::V6(addr) => MdnsEndpoint::direct_v6(*addr), - }) - .collect::>(); - - let mut hosts = HashMap::new(); - hosts.insert(host.clone(), endpoints); - - records - .entry(host.clone()) - .or_insert_with(Vec::new) - .push(ResponseRecord::unsigned( - MdnsPacket::answer(0, &hosts).to_bytes(), - Vec::new(), - )); - - info!(host = %host, endpoint_count = seed_record.endpoints.len(), "seed_records.loaded"); - } - - Ok(Arc::new(records)) -} - -fn log_geo_db_freshness(kind: &str, build_epoch: u64) { - const STALE_GEO_DB_AGE_SECS: u64 = 45 * 24 * 60 * 60; - - let now_secs = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| duration.as_secs()) - .unwrap_or(build_epoch); - let age_secs = now_secs.saturating_sub(build_epoch); - - if age_secs > STALE_GEO_DB_AGE_SECS { - warn!(kind, build_epoch, age_secs, "geo_routing.db_outdated"); - } -} - -const GEO_CITY_DISTANCE_ROUTING: bool = true; -const GEO_MAX_ACCURACY_RADIUS_KM: u32 = 100; - -fn build_geo_resolver(config: &Config) -> io::Result>> { - let Some(city_db) = config.geoip_city_db.as_deref() else { - return if config.geoip_asn_db.is_none() { - Ok(None) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidInput, - "geoip_city_db and geoip_asn_db must be configured together", - )) - }; - }; - - let Some(asn_db) = config.geoip_asn_db.as_deref() else { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "geoip_city_db and geoip_asn_db must be configured together", - )); - }; - - let resolver = Arc::new(GeoResolver::open( - city_db, - asn_db, - GEO_CITY_DISTANCE_ROUTING, - GEO_MAX_ACCURACY_RADIUS_KM, - )?); - info!( - city_db = %city_db.display(), - asn_db = %asn_db.display(), - city_distance_routing = GEO_CITY_DISTANCE_ROUTING, - max_accuracy_radius_km = GEO_MAX_ACCURACY_RADIUS_KM, - "geo_routing.enabled" - ); - log_geo_db_freshness("city", resolver.city_build_epoch()); - log_geo_db_freshness("asn", resolver.asn_build_epoch()); - - Ok(Some(resolver)) -} - -// --------------------------------------------------------------------------- -// Entry point -// --------------------------------------------------------------------------- - -#[tokio::main] -async fn main() -> Result<(), Box> { - tracing_subscriber::registry() - .with(tracing_subscriber::fmt::layer()) - .with(tracing_subscriber::filter::filter_fn(|metadata| { - !metadata.target().contains("netlink_packet_route") - })) - .with(LevelFilter::DEBUG) - .init(); - - let options = Options::parse(); - - let config_str = std::fs::read_to_string(&options.config).unwrap_or_else(|e| { - eprintln!("failed to read config {:?}: {e}", options.config); - std::process::exit(1); - }); - let config: Config = toml::from_str(&config_str).unwrap_or_else(|e| { - eprintln!("failed to parse config {:?}: {e}", options.config); - std::process::exit(1); - }); - let config = config.expand_paths(); - let host_allowlist = Arc::new(error::normalize_host_allowlist(&config.host_allowlist)?); - if host_allowlist.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "host_allowlist must not be empty", - ) - .into()); - } - let seed_records = build_seed_records(&config.seed_records, host_allowlist.as_ref())?; - let geo = build_geo_resolver(&config)?; - let memory_blacklist = config - .blacklist - .iter() - .filter_map( - |host| match error::normalize_host(host, host_allowlist.as_ref()) { - Ok(host) => Some(host), - Err(error) => { - warn!(host, error = %error, "blacklist.invalid_host_ignored"); - None - } - }, - ) - .collect::>(); - - // Build storage backend. - let storage = match config.redis_write_url.clone() { - Some(write_url) => { - if !memory_blacklist.is_empty() { - warn!( - count = memory_blacklist.len(), - "blacklist.config_ignored_when_redis_enabled" - ); - } - let write_cfg = deadpool_redis::Config::from_url(write_url.clone()); - let write_pool = write_cfg.create_pool(Some(deadpool_redis::Runtime::Tokio1))?; - let read_pool = match config.redis_read_url.clone() { - Some(read_url) if read_url != write_url => { - deadpool_redis::Config::from_url(read_url) - .create_pool(Some(deadpool_redis::Runtime::Tokio1))? - } - _ => write_pool.clone(), - }; - Storage::Redis(RedisStorage { - write: write_pool, - read: read_pool, - }) - } - None if config.redis_read_url.is_some() => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "redis_read_url requires redis_write_url (or legacy redis)", - ) - .into()); - } - None => Storage::Memory(MemoryStorage::with_blacklist(memory_blacklist)), - }; - - // Build domain-policy rules from config file. - let mut policy_rules: Vec<(PolicyRule, DomainPolicy)> = config - .domain_policies - .iter() - .filter_map( - |pc| match error::normalize_host(&pc.host, host_allowlist.as_ref()) { - Ok(h) => { - let policy = match pc.policy { - PolicyKind::Standard => DomainPolicy::Standard, - PolicyKind::OpenMulti => DomainPolicy::OpenMulti, - }; - Some((PolicyRule::Exact(h), policy)) - } - Err(error) => { - warn!(host = %pc.host, error = %error, "domain_policy.invalid_host_ignored"); - None - } - }, - ) - .collect(); - // Deduplicate (preserve first occurrence). - policy_rules.dedup_by(|(ra, _), (rb, _)| { - matches!((ra, rb), (PolicyRule::Exact(a), PolicyRule::Exact(b)) if a == b) - }); - let policies = Arc::new(DomainPolicies(policy_rules)); - info!(?policies, "domain_policies.loaded"); - - // Load the root CA used to validate client certificates when they are provided. - let root_ca_pem = std::fs::read(&config.root_cert)?; - let roots = load_root_store_from_pem(&root_ca_pem)?; - let verifier = WebPkiClientVerifier::builder(Arc::new(roots)) - .allow_unauthenticated() - .build() - .unwrap(); - - let state = AppState { - storage, - host_allowlist, - require_signature: config.require_signature, - ttl_secs: config.ttl_secs, - policies, - seed_records, - geo, - }; - - let cert_pem = std::fs::read(&config.cert)?; - let key_pem = std::fs::read(&config.key)?; - - let router = TowerService(DnsService { - publish: PublishSvc { - state: state.clone(), - }, - lookup: LookupSvc { - state: state.clone(), - }, - }); - - let identity = Arc::new(Identity { - name: config.server_name.parse().unwrap(), - certs: Arc::new(cert_pem.to_certificate()), - key: Arc::new(key_pem.to_private_key()), - ocsp: Arc::new(None), - }); - let server_config = ServerQuicConfig { - alpns: vec![b"h3".to_vec()], - client_cert_verifier: verifier, - ..Default::default() - }; - let quic = QuicEndpoint::builder() - .network(Network::builder().build()) - .identity(identity) - .server(server_config) - .bind(Arc::new(config.binds.clone())) - .build() - .await; - match ocsp::OcspAutoRefresh::from_config(&config, &cert_pem, &root_ca_pem) { - Ok(ocsp_refresh) => { - info!( - responder_url = %ocsp_refresh.responder_url(), - refresh_in_secs = ocsp::refresh_success_delay().as_secs(), - "ocsp.auto_refresh.enabled" - ); - let mut ocsp_quic = quic.clone(); - let initial_delay = ocsp_refresh.refresh_once(&mut ocsp_quic).await; - info!( - next_refresh_in_secs = initial_delay.as_secs(), - "ocsp.auto_refresh.initialized" - ); - tokio::spawn(ocsp_refresh.run(ocsp_quic)); - } - Err(error) => { - warn!(error = %error, "ocsp.auto_refresh.disabled"); - } - } - let server = Arc::new(H3Endpoint::new(quic)); - info!(binds = ?config.binds, server_name = %config.server_name, "h3_server.start"); - server.listen_owned(router).await?; - - Ok(()) -} - -#[cfg(test)] -mod tests { - use std::path::PathBuf; - - use super::*; - use crate::config::Config; - - fn test_config() -> Config { - Config { - redis_write_url: None, - redis_read_url: None, - host_allowlist: Config::default_host_allowlist(), - binds: Config::default_binds(), - server_name: Config::default_server_name(), - cert: Config::default_cert(), - key: Config::default_key(), - root_cert: Config::default_root_cert(), - ocsp_issuer_cert: None, - ocsp_responder_base_url: None, - require_signature: Config::default_require_signature(), - ttl_secs: Config::default_ttl_secs(), - domain_policies: Vec::new(), - blacklist: Vec::new(), - seed_records: Vec::new(), - geoip_city_db: None, - geoip_asn_db: None, - } - } - - #[test] - fn default_binds_include_ipv4_and_ipv6_wildcards() { - let patterns = Config::default_binds(); - - assert_eq!(patterns.len(), 2); - assert_eq!(patterns[0].to_string(), "inet://0.0.0.0:4433"); - assert_eq!(patterns[1].to_string(), "inet://[::]:4433"); - } - - #[test] - fn geo_routing_requires_city_db_path() { - let mut config = test_config(); - config.geoip_asn_db = Some(PathBuf::from("/tmp/asn.mmdb")); - - let err = build_geo_resolver(&config).expect_err("missing city db should fail"); - - assert_eq!(err.kind(), io::ErrorKind::InvalidInput); - assert_eq!( - err.to_string(), - "geoip_city_db and geoip_asn_db must be configured together" - ); - } - - #[test] - fn geo_routing_requires_asn_db_path() { - let mut config = test_config(); - config.geoip_city_db = Some(PathBuf::from("/tmp/city.mmdb")); - - let err = build_geo_resolver(&config).expect_err("missing asn db should fail"); - - assert_eq!(err.kind(), io::ErrorKind::InvalidInput); - assert_eq!( - err.to_string(), - "geoip_city_db and geoip_asn_db must be configured together" - ); - } -} diff --git a/src/bin/ddns-server/ocsp.rs b/src/bin/ddns-server/ocsp.rs deleted file mode 100644 index 8132c47..0000000 --- a/src/bin/ddns-server/ocsp.rs +++ /dev/null @@ -1,230 +0,0 @@ -use std::{io, path::Path, time::Duration}; - -use dhttp_identity::ocsp::{OcspStatus, build_ocsp_request_der, verify_stapled_ocsp_response}; -use h3x::dquic::QuicEndpoint; -use reqwest::{ - Url, - header::{ACCEPT, CONTENT_TYPE}, -}; -use rustls::pki_types::{CertificateDer, UnixTime}; -use tokio::time::sleep; -use tracing::{info, warn}; - -use crate::config::Config; - -pub const DEFAULT_OCSP_RESPONDER_BASE_URL: &str = "https://license.genmeta.net"; -pub const OCSP_STAPLING_TTL: Duration = Duration::from_secs(3 * 60 * 60); -pub const OCSP_REFRESH_EXPIRY_SKEW: Duration = Duration::from_secs(5 * 60); -pub const OCSP_REFRESH_RETRY_DELAY: Duration = Duration::from_secs(5 * 60); - -pub struct OcspAutoRefresh { - responder_url: String, - http_client: reqwest::Client, - request_der: Vec, - leaf_der: CertificateDer<'static>, - issuer_der: CertificateDer<'static>, -} - -impl OcspAutoRefresh { - pub fn from_config(config: &Config, cert_pem: &[u8], root_cert_pem: &[u8]) -> io::Result { - let base_url = config - .ocsp_responder_base_url - .as_deref() - .unwrap_or(DEFAULT_OCSP_RESPONDER_BASE_URL); - let responder_url = normalize_base_url(base_url)?; - let (request_der, leaf_der, issuer_der) = - build_ocsp_request_context(cert_pem, config.ocsp_issuer_cert.as_deref())?; - let root_cert = reqwest::Certificate::from_pem(root_cert_pem) - .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; - let http_client = reqwest::Client::builder() - .add_root_certificate(root_cert) - .timeout(Duration::from_secs(15)) - .build() - .map_err(io::Error::other)?; - - Ok(Self { - responder_url, - http_client, - request_der, - leaf_der, - issuer_der, - }) - } - - pub fn responder_url(&self) -> &str { - &self.responder_url - } - - pub async fn refresh_once(&self, quic: &mut QuicEndpoint) -> Duration { - match self.fetch_response().await { - Ok(response_der) => match self.validate_response(&response_der) { - Ok(OcspStatus::Good) => { - let response_len = response_der.len(); - quic.update_ocsp(Some(response_der)); - info!( - responder_url = %self.responder_url, - response_len, - refresh_in_secs = refresh_success_delay().as_secs(), - "ocsp.staple_refreshed" - ); - refresh_success_delay() - } - Ok(OcspStatus::Unknown) => { - warn!( - responder_url = %self.responder_url, - retry_in_secs = OCSP_REFRESH_RETRY_DELAY.as_secs(), - "ocsp response status is unknown; skipping staple update" - ); - OCSP_REFRESH_RETRY_DELAY - } - Ok(OcspStatus::Revoked) => { - warn!( - responder_url = %self.responder_url, - retry_in_secs = OCSP_REFRESH_RETRY_DELAY.as_secs(), - "ocsp response status is revoked; skipping staple update" - ); - OCSP_REFRESH_RETRY_DELAY - } - Err(error) => { - warn!( - error = %error, - responder_url = %self.responder_url, - retry_in_secs = OCSP_REFRESH_RETRY_DELAY.as_secs(), - "ocsp response validation failed; skipping staple update" - ); - OCSP_REFRESH_RETRY_DELAY - } - }, - Err(error) => { - warn!( - error = %error, - responder_url = %self.responder_url, - retry_in_secs = OCSP_REFRESH_RETRY_DELAY.as_secs(), - "ocsp.refresh_failed" - ); - OCSP_REFRESH_RETRY_DELAY - } - } - } - - pub async fn run(self, mut quic: QuicEndpoint) { - loop { - let delay = self.refresh_once(&mut quic).await; - sleep(delay).await; - } - } - - async fn fetch_response(&self) -> io::Result> { - let response = self - .http_client - .post(&self.responder_url) - .header(CONTENT_TYPE, "application/ocsp-request") - .header(ACCEPT, "application/ocsp-response") - .body(self.request_der.clone()) - .send() - .await - .map_err(request_error)?; - - if !response.status().is_success() { - let status = response.status(); - let message = response.text().await.unwrap_or_default(); - return Err(io::Error::other(format!( - "OCSP responder returned HTTP status {status}: {message}" - ))); - } - - let body = response.bytes().await.map_err(request_error)?; - if body.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "OCSP responder returned an empty body", - )); - } - - Ok(body.to_vec()) - } - - fn validate_response(&self, response_der: &[u8]) -> Result { - verify_stapled_ocsp_response(&self.leaf_der, &self.issuer_der, response_der, now()) - } -} - -pub fn refresh_success_delay() -> Duration { - OCSP_STAPLING_TTL.saturating_sub(OCSP_REFRESH_EXPIRY_SKEW) -} - -fn build_ocsp_request_context( - cert_pem: &[u8], - issuer_override: Option<&Path>, -) -> io::Result<(Vec, CertificateDer<'static>, CertificateDer<'static>)> { - let chain = load_pem_certificates(cert_pem)?; - let leaf_der = chain.first().cloned().ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidData, - "server certificate PEM does not contain a certificate", - ) - })?; - let issuer_der = match chain.get(1) { - Some(issuer) => issuer.clone(), - None => load_issuer_certificate(issuer_override)?, - }; - - let request_der = build_ocsp_request_der(&leaf_der, &issuer_der) - .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; - - Ok((request_der, leaf_der, issuer_der)) -} - -fn load_issuer_certificate(issuer_override: Option<&Path>) -> io::Result> { - let issuer_path = issuer_override.ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - "OCSP auto-refresh requires the server cert PEM to include the issuer cert or ocsp_issuer_cert to be configured", - ) - })?; - let issuer_pem = std::fs::read(issuer_path)?; - let issuer_chain = load_pem_certificates(&issuer_pem)?; - issuer_chain.into_iter().next().ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidData, - "ocsp_issuer_cert does not contain a certificate", - ) - }) -} - -fn load_pem_certificates(cert_pem: &[u8]) -> io::Result>> { - let mut reader = std::io::Cursor::new(cert_pem); - rustls_pemfile::certs(&mut reader) - .collect::, _>>() - .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error)) -} - -fn normalize_base_url(base_url: &str) -> io::Result { - let parsed = - Url::parse(base_url).map_err(|error| io::Error::new(io::ErrorKind::InvalidInput, error))?; - if parsed.scheme() != "https" { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "ocsp_responder_base_url must use https", - )); - } - if parsed.host_str().is_none() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "ocsp_responder_base_url must include a host", - )); - } - - Ok(format!("{}/ocsp", parsed.as_str().trim_end_matches('/'))) -} - -fn request_error(error: reqwest::Error) -> io::Error { - io::Error::other(format!("failed to query OCSP responder: {error}")) -} - -fn now() -> UnixTime { - let now = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default(); - UnixTime::since_unix_epoch(now) -} diff --git a/src/bin/ddns-server/policy.rs b/src/bin/ddns-server/policy.rs deleted file mode 100644 index d2a2d97..0000000 --- a/src/bin/ddns-server/policy.rs +++ /dev/null @@ -1,200 +0,0 @@ -use ddns::core::{ - parser::{packet::be_packet, record::RData}, - signature::SignatureFields, -}; -use dhttp_identity::identity::RemoteAuthority; -use tracing::{debug, warn}; - -use crate::error::{AppError, normalize_host}; - -#[derive(Clone, Debug, PartialEq)] -pub enum DomainPolicy { - Standard, - OpenMulti, -} - -#[derive(Clone, Debug)] -pub enum PolicyRule { - Exact(String), - #[allow(dead_code)] - Suffix(String), -} - -impl PolicyRule { - pub fn matches(&self, host: &str) -> bool { - match self { - PolicyRule::Exact(exact) => host == exact, - PolicyRule::Suffix(suffix) => { - host == suffix.as_str() || host.ends_with(&format!(".{suffix}")) - } - } - } -} - -#[derive(Clone, Debug, Default)] -pub struct DomainPolicies(pub Vec<(PolicyRule, DomainPolicy)>); - -impl DomainPolicies { - pub fn policy_for(&self, host: &str) -> &DomainPolicy { - for (rule, policy) in &self.0 { - if rule.matches(host) { - return policy; - } - } - &DomainPolicy::Standard - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ValidatedDnsPacket { - Records { host: String }, - Empty, -} - -pub fn extract_client_dns_sans(authority: &(impl RemoteAuthority + ?Sized)) -> Vec { - use x509_parser::prelude::*; - - let Some(leaf) = authority.cert_chain().first() else { - return vec![]; - }; - - let Ok((_remain, cert)) = X509Certificate::from_der(leaf.as_ref()) else { - return vec![]; - }; - - let mut out = vec![]; - if let Ok(Some(san)) = cert.subject_alternative_name() { - for name in san.value.general_names.iter() { - if let GeneralName::DNSName(dns) = name { - out.push(dns.to_string()); - } - } - } - out -} - -pub fn client_allowed_host( - authority: &(impl RemoteAuthority + ?Sized), - allowlist: &[String], -) -> Result { - let mut sans = extract_client_dns_sans(authority) - .into_iter() - .filter_map(|h| normalize_host(&h, allowlist).ok()) - .collect::>(); - - sans.sort(); - sans.dedup(); - - match sans.len() { - 1 => Ok(sans.remove(0)), - _ => Err(AppError::ClientCertDomainNotAllowed), - } -} - -pub fn validate_dns_packet( - packet: &[u8], - require_signature: bool, - authority: &(impl RemoteAuthority + ?Sized), - signature_fields: &SignatureFields, - allowlist: &[String], - expected_host: &str, -) -> Result { - let (remaining, dns_packet) = be_packet(packet).map_err(|e| AppError::InvalidDnsPacket { - message: e.to_string(), - })?; - if !remaining.is_empty() { - warn!(remain = remaining.len(), "dns.parse.extra_bytes"); - } - debug!( - answers = dns_packet.answers.len(), - require_signature, "validating dns packet" - ); - - if require_signature { - if signature_fields.is_empty() { - return Err(AppError::SignatureRequired); - } - - let cert = authority - .cert_chain() - .first() - .ok_or(AppError::MissingClientCertificate)?; - let ok = signature_fields - .verify(packet, cert.as_ref()) - .map_err(|_| AppError::InvalidSignature)?; - if !ok { - return Err(AppError::InvalidSignature); - } - - for record in &dns_packet.answers { - if let RData::E(endpoint) = record.data() - && endpoint.is_signed() - { - return Err(AppError::InvalidSignature); - } - } - } - - let Some(first_answer) = dns_packet.answers.first() else { - debug!("dns packet has no answers"); - return Ok(ValidatedDnsPacket::Empty); - }; - - for answer in &dns_packet.answers { - let answer_host = normalize_host(&answer.name(), allowlist)?; - if answer_host != expected_host { - return Err(AppError::HostMismatch); - } - } - - Ok(ValidatedDnsPacket::Records { - host: first_answer.name().to_string(), - }) -} - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - - use ddns::core::{MdnsPacket, parser::record::endpoint::EndpointAddr}; - use dhttp_identity::identity::RemoteAuthority; - use rustls::pki_types::CertificateDer; - - use super::*; - - #[derive(Debug)] - struct TestAuthority; - - fn allowlist() -> Vec { - vec!["genmeta.net".to_string()] - } - - impl RemoteAuthority for TestAuthority { - fn name(&self) -> &str { - "authority.example" - } - - fn cert_chain(&self) -> &[CertificateDer<'static>] { - &[] - } - } - - #[test] - fn validate_dns_packet_accepts_empty_packet_as_clear_operation() { - let hosts: HashMap> = - HashMap::from([("reimu.pilot.genmeta.net".to_owned(), Vec::new())]); - let packet = MdnsPacket::answer(0, &hosts).to_bytes(); - - let validated = validate_dns_packet( - &packet, - false, - &TestAuthority, - &SignatureFields::empty(), - &allowlist(), - "reimu.pilot.genmeta.net", - ) - .unwrap(); - - assert!(matches!(validated, ValidatedDnsPacket::Empty)); - } -} diff --git a/src/bin/ddns-server/publish/http.rs b/src/bin/ddns-server/publish/http.rs deleted file mode 100644 index d76946a..0000000 --- a/src/bin/ddns-server/publish/http.rs +++ /dev/null @@ -1,163 +0,0 @@ -use std::{convert::Infallible, sync::Arc}; - -use ddns::core::signature::{ - CONTENT_DIGEST_HEADER, SIGNATURE_HEADER, SIGNATURE_INPUT_HEADER, SignatureFields, -}; -use h3x::{connection::ConnectionState, quic}; -use http_body_util::BodyExt; -use tracing::{debug, warn}; - -use super::store::{clear_record, publish_record}; -use crate::{ - error::{AppError, normalize_host, parse_query_params}, - lookup::{Request, Response, write_error}, - policy::{DomainPolicy, ValidatedDnsPacket, client_allowed_host, validate_dns_packet}, - storage::AppState, -}; - -// --------------------------------------------------------------------------- -// PublishSvc -// --------------------------------------------------------------------------- - -#[derive(Clone)] -pub struct PublishSvc { - pub state: AppState, -} - -impl PublishSvc { - pub fn call( - &self, - request: Request, - ) -> impl Future> + Send + 'static { - let state = self.state.clone(); - async move { Ok(publish_with_cert(state, request).await) } - } -} - -async fn publish_with_cert(state: AppState, request: Request) -> Response { - debug!("received publish request"); - - let params = parse_query_params(request.uri()); - debug!("query params: {:?}", params); - - let Some(host) = params.get("host") else { - warn!("missing host parameter"); - return write_error(AppError::MissingHostParam); - }; - - let host = match normalize_host(host, state.host_allowlist.as_ref()) { - Ok(h) => h, - Err(e) => return write_error(e), - }; - debug!(host = %host, "publish.host"); - - // Require a valid client certificate for all publish requests. - let authority = match request_connection(&request) { - Some(connection) => match connection.remote_authority().await { - Ok(Some(authority)) => authority, - Ok(None) => { - warn!("missing client certificate"); - return write_error(AppError::MissingClientCertificate); - } - Err(error) => { - warn!(error = %snafu::Report::from_error(&error), "failed to read client certificate"); - return write_error(AppError::MissingClientCertificate); - } - }, - None => { - warn!("missing client certificate"); - return write_error(AppError::MissingClientCertificate); - } - }; - - let policy = state.policies.policy_for(&host).clone(); - - // Standard policy: cert SAN must match the target host. - // OpenMulti policy: any authenticated node may publish — skip SAN check. - if policy == DomainPolicy::Standard { - let allowed = match client_allowed_host(authority.as_ref(), state.host_allowlist.as_ref()) { - Ok(h) => h, - Err(e) => { - warn!(error = %snafu::Report::from_error(&e), "client certificate domain not allowed"); - return write_error(e); - } - }; - if allowed != host { - warn!(allowed = %allowed, requested = %host, "publish.host_mismatch"); - return write_error(AppError::HostMismatch); - } - } - - let signature_fields = signature_fields_from_headers(request.headers()); - - let body = match request.into_body().collect().await { - Ok(body) => body.to_bytes(), - Err(e) => { - warn!(error = %snafu::Report::from_error(&e), "failed to read request body"); - return write_error(AppError::InvalidDnsPacket { - message: e.to_string(), - }); - } - }; - - // Validate DNS packet; signature check only for Standard hosts. - let require_sig = policy == DomainPolicy::Standard && state.require_signature; - debug!( - host = %host, - bytes = body.len(), - require_signature = require_sig, - "validating publish packet" - ); - let packet = match validate_dns_packet( - body.as_ref(), - require_sig, - authority.as_ref(), - &signature_fields, - state.host_allowlist.as_ref(), - &host, - ) { - Ok(n) => n, - Err(e) => { - debug!(host = %host, error = %e, "publish packet rejected"); - return write_error(e); - } - }; - - match packet { - ValidatedDnsPacket::Records { host: packet_name } => { - let packet_host = match normalize_host(&packet_name, state.host_allowlist.as_ref()) { - Ok(h) => h, - Err(e) => return write_error(e), - }; - - if packet_host != host { - return write_error(AppError::HostMismatch); - } - - publish_record(&state, &host, &body, authority.as_ref(), signature_fields).await - } - ValidatedDnsPacket::Empty => clear_record(&state, &host, authority.as_ref()).await, - } -} - -fn signature_fields_from_headers(headers: &http::HeaderMap) -> SignatureFields { - let header = |name: &'static str| { - headers - .get(name) - .map(|value| value.as_bytes().to_vec()) - .unwrap_or_default() - }; - - SignatureFields { - content_digest: header(CONTENT_DIGEST_HEADER), - signature_input: header(SIGNATURE_INPUT_HEADER), - signature: header(SIGNATURE_HEADER), - } -} - -fn request_connection(request: &Request) -> Option>> { - request - .extensions() - .get::>>() - .cloned() -} diff --git a/src/bin/ddns-server/publish/mod.rs b/src/bin/ddns-server/publish/mod.rs deleted file mode 100644 index 73cba14..0000000 --- a/src/bin/ddns-server/publish/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod http; -mod store; - -pub use http::PublishSvc; - -#[cfg(test)] -mod tests; diff --git a/src/bin/ddns-server/publish/store.rs b/src/bin/ddns-server/publish/store.rs deleted file mode 100644 index cdc4620..0000000 --- a/src/bin/ddns-server/publish/store.rs +++ /dev/null @@ -1,250 +0,0 @@ -use std::collections::HashSet; - -use ddns::core::signature::SignatureFields; -use deadpool_redis::redis::{self, AsyncCommands}; -use dhttp_identity::identity::RemoteAuthority; -use tokio::time::{Duration, Instant}; -use tracing::info; - -use crate::{ - error::AppError, - lookup::{Response, body_response, write_error}, - storage::{ - AppState, Record, Storage, StoredRecord, cert_fingerprint, cert_fingerprint_hex, - record_index_tags, redis_all_index_key, redis_asn_index_key, redis_country_index_key, - redis_primary_key, unix_now_secs, - }, -}; - -async fn trim_expired_index_keys( - conn: &mut C, - keys: impl IntoIterator, - cutoff: f64, - expire_ttl_secs: i64, -) where - C: redis::aio::ConnectionLike + Send + Sync, -{ - for key in keys { - let _: bool = conn.expire(&key, expire_ttl_secs).await.unwrap_or(false); - let _: () = redis::cmd("ZREMRANGEBYSCORE") - .arg(&key) - .arg("-inf") - .arg(cutoff) - .query_async::<()>(&mut *conn) - .await - .unwrap_or(()); - } -} - -/// Unified publish handler: stores the record keyed by (host, cert-fingerprint). -/// Both Standard and OpenMulti policies follow the same storage path; -/// the only policy difference (SAN check) is already enforced in the caller. -/// -/// Certificate fingerprint is the publish-source identity. In PKI ecosystems, -/// a single domain name can have multiple valid certificates (from different CAs, -/// or issued at different times for rotation/failover/multi-region scenarios). -/// Using fingerprint as part of the storage key enables: -/// - Multi-publisher coexistence: different cert holders can publish the same domain -/// - Idempotent updates: re-publishing from same cert source (same fingerprint) overwrites old data -/// - Client choice: lookups return all active records, client picks which certificate to trust -pub async fn publish_record( - state: &AppState, - host: &str, - body: &bytes::Bytes, - authority: &(impl RemoteAuthority + ?Sized), - signature_fields: SignatureFields, -) -> Response { - let cert_bytes = authority - .cert_chain() - .first() - .map(|c| c.as_ref().to_vec()) - .unwrap_or_default(); - - let fp = cert_fingerprint(&cert_bytes); - let fp_hex = cert_fingerprint_hex(&cert_bytes); - - match &state.storage { - Storage::Redis(redis) => { - let mut conn = match redis.write.get().await { - Ok(c) => c, - Err(e) => { - return write_error(AppError::Redis { - message: e.to_string(), - }); - } - }; - let ttl_secs = state.ttl_secs; - let expire_ttl_secs = i64::try_from(state.ttl_secs).unwrap_or(i64::MAX); - let now_secs = unix_now_secs(); - let expire_secs = now_secs + state.ttl_secs; - let index_tags = record_index_tags(body.as_ref(), state.geo.as_deref()); - - let fp_key = redis_primary_key(host, &fp_hex); - let all_index_key = redis_all_index_key(host); - let mut touched_index_keys = HashSet::from([all_index_key.clone()]); - - let old_member: Option> = conn.get(&fp_key).await.unwrap_or(None); - if let Some(old_record) = old_member.as_deref().and_then(StoredRecord::decode) { - let old_tags = record_index_tags(&old_record.dns, state.geo.as_deref()); - let _: () = conn.zrem(&all_index_key, &fp_hex).await.unwrap_or(()); - - for country in &old_tags.countries { - let key = redis_country_index_key(host, country); - touched_index_keys.insert(key.clone()); - let _: () = conn.zrem(&key, &fp_hex).await.unwrap_or(()); - } - - for asn in &old_tags.asns { - let key = redis_asn_index_key(host, *asn); - touched_index_keys.insert(key.clone()); - let _: () = conn.zrem(&key, &fp_hex).await.unwrap_or(()); - } - } - - let new_member = StoredRecord { - expire_unix_secs: expire_secs, - fingerprint: fp, - dns: body.to_vec(), - cert: cert_bytes.clone(), - signature_fields: signature_fields.clone(), - } - .encode(); - - if let Err(e) = conn - .set_ex::<_, _, ()>(&fp_key, &new_member, ttl_secs) - .await - { - return write_error(AppError::Redis { - message: e.to_string(), - }); - } - - if let Err(e) = conn - .zadd::<_, _, _, ()>(&all_index_key, &fp_hex, now_secs as f64) - .await - { - return write_error(AppError::Redis { - message: e.to_string(), - }); - } - - for country in &index_tags.countries { - let key = redis_country_index_key(host, country); - touched_index_keys.insert(key.clone()); - if let Err(e) = conn - .zadd::<_, _, _, ()>(&key, &fp_hex, now_secs as f64) - .await - { - return write_error(AppError::Redis { - message: e.to_string(), - }); - } - } - - for asn in &index_tags.asns { - let key = redis_asn_index_key(host, *asn); - touched_index_keys.insert(key.clone()); - if let Err(e) = conn - .zadd::<_, _, _, ()>(&key, &fp_hex, now_secs as f64) - .await - { - return write_error(AppError::Redis { - message: e.to_string(), - }); - } - } - - let cutoff = now_secs.saturating_sub(state.ttl_secs) as f64; - trim_expired_index_keys(&mut *conn, touched_index_keys, cutoff, expire_ttl_secs).await; - } - Storage::Memory(mem) => { - let now = Instant::now(); - let expire = now + Duration::from_secs(state.ttl_secs); - let record = Record { - dns_bytes: body.to_vec(), - cert_bytes, - signature_fields, - expire, - index_tags: record_index_tags(body.as_ref(), state.geo.as_deref()), - }; - let mut host_map = mem.records.entry(host.to_string()).or_default(); - host_map.retain_active(now); - host_map.insert(fp, record); - } - } - - info!(host = %host, ttl = state.ttl_secs, bytes = body.len(), fp = %fp_hex, "publish.ok"); - body_response(http::StatusCode::OK, bytes::Bytes::from_static(b"OK")) -} - -pub async fn clear_record( - state: &AppState, - host: &str, - authority: &(impl RemoteAuthority + ?Sized), -) -> Response { - let cert_bytes = authority - .cert_chain() - .first() - .map(|c| c.as_ref().to_vec()) - .unwrap_or_default(); - - let fp = cert_fingerprint(&cert_bytes); - let fp_hex = cert_fingerprint_hex(&cert_bytes); - - match &state.storage { - Storage::Redis(redis) => { - let mut conn = match redis.write.get().await { - Ok(c) => c, - Err(e) => { - return write_error(AppError::Redis { - message: e.to_string(), - }); - } - }; - - let fp_key = redis_primary_key(host, &fp_hex); - let all_index_key = redis_all_index_key(host); - let mut touched_index_keys = HashSet::from([all_index_key.clone()]); - - let old_member: Option> = conn.get(&fp_key).await.unwrap_or(None); - if let Some(old_record) = old_member.as_deref().and_then(StoredRecord::decode) { - let old_tags = record_index_tags(&old_record.dns, state.geo.as_deref()); - let _: () = conn.zrem(&all_index_key, &fp_hex).await.unwrap_or(()); - for country in &old_tags.countries { - let key = redis_country_index_key(host, country); - touched_index_keys.insert(key.clone()); - let _: () = conn.zrem(&key, &fp_hex).await.unwrap_or(()); - } - for asn in &old_tags.asns { - let key = redis_asn_index_key(host, *asn); - touched_index_keys.insert(key.clone()); - let _: () = conn.zrem(&key, &fp_hex).await.unwrap_or(()); - } - } - - if let Err(e) = conn.del::<_, ()>(&fp_key).await { - return write_error(AppError::Redis { - message: e.to_string(), - }); - } - - let cutoff = unix_now_secs().saturating_sub(state.ttl_secs) as f64; - let expire_ttl_secs = i64::try_from(state.ttl_secs).unwrap_or(i64::MAX); - trim_expired_index_keys(&mut *conn, touched_index_keys, cutoff, expire_ttl_secs).await; - } - Storage::Memory(mem) => { - let remove_host = if let Some(mut host_map) = mem.records.get_mut(host) { - let _ = host_map.remove(&fp); - host_map.is_empty() - } else { - false - }; - if remove_host { - mem.records.remove(host); - } - } - } - - info!(host = %host, fp = %fp_hex, "publish.clear"); - body_response(http::StatusCode::OK, bytes::Bytes::from_static(b"OK")) -} diff --git a/src/bin/ddns-server/publish/tests.rs b/src/bin/ddns-server/publish/tests.rs deleted file mode 100644 index a22a279..0000000 --- a/src/bin/ddns-server/publish/tests.rs +++ /dev/null @@ -1,126 +0,0 @@ -use std::{ - collections::HashMap, - net::{Ipv4Addr, SocketAddrV4}, - sync::Arc, -}; - -use ddns::core::{MdnsPacket, parser::record::endpoint::EndpointAddr, signature::SignatureFields}; -use dhttp_identity::identity::RemoteAuthority; -use rustls::pki_types::CertificateDer; - -use super::store::{clear_record, publish_record}; -use crate::{ - lookup::query::{LookupResult, perform_lookup}, - policy::DomainPolicies, - storage::{AppState, MemoryStorage, SeedRecords, Storage}, -}; - -#[derive(Debug)] -struct TestAuthority { - name: &'static str, - certs: Vec>, -} - -impl TestAuthority { - fn new(name: &'static str, cert_bytes: Vec) -> Self { - Self { - name, - certs: vec![CertificateDer::from(cert_bytes)], - } - } -} - -impl RemoteAuthority for TestAuthority { - fn name(&self) -> &str { - self.name - } - - fn cert_chain(&self) -> &[CertificateDer<'static>] { - &self.certs - } -} - -fn memory_state() -> AppState { - AppState { - storage: Storage::Memory(MemoryStorage::new()), - host_allowlist: Arc::new(vec!["genmeta.net".to_string(), "dhttp.net".to_string()]), - require_signature: true, - ttl_secs: 30, - policies: Arc::new(DomainPolicies::default()), - seed_records: SeedRecords::default(), - geo: None, - } -} - -fn packet_for(host: &str, last_octet: u8) -> bytes::Bytes { - let endpoint = EndpointAddr::direct_v4(SocketAddrV4::new( - Ipv4Addr::new(203, 0, 113, last_octet), - 4433, - )); - let mut hosts = HashMap::new(); - hosts.insert(host.to_owned(), vec![endpoint]); - bytes::Bytes::from(MdnsPacket::answer(0, &hosts).to_bytes()) -} - -#[tokio::test] -async fn clear_record_removes_only_current_certificate_fingerprint() { - let state = memory_state(); - let host = "reimu.pilot.dhttp.net"; - let authority_a = TestAuthority::new("authority-a", vec![1]); - let authority_b = TestAuthority::new("authority-b", vec![2]); - let packet_a = packet_for(host, 1); - let packet_b = packet_for(host, 2); - - assert_eq!( - publish_record( - &state, - host, - &packet_a, - &authority_a, - SignatureFields::empty() - ) - .await - .status(), - http::StatusCode::OK - ); - assert_eq!( - publish_record( - &state, - host, - &packet_b, - &authority_b, - SignatureFields::empty() - ) - .await - .status(), - http::StatusCode::OK - ); - - assert_eq!( - clear_record(&state, host, &authority_a).await.status(), - http::StatusCode::OK - ); - - let LookupResult::Multi(response) = perform_lookup(&state, host, None, None).await.unwrap() - else { - panic!("authority b record should remain"); - }; - assert_eq!(response.records.len(), 1); - assert_eq!(response.records[0].cert, authority_b.certs[0].as_ref()); -} - -#[tokio::test] -async fn clear_record_is_idempotent_for_missing_fingerprint() { - let state = memory_state(); - let host = "reimu.pilot.dhttp.net"; - let authority = TestAuthority::new("authority", vec![1]); - - assert_eq!( - clear_record(&state, host, &authority).await.status(), - http::StatusCode::OK - ); - assert!(matches!( - perform_lookup(&state, host, None, None).await.unwrap(), - LookupResult::NotFound - )); -} diff --git a/src/bin/ddns-server/storage.rs b/src/bin/ddns-server/storage.rs deleted file mode 100644 index f8dcd0d..0000000 --- a/src/bin/ddns-server/storage.rs +++ /dev/null @@ -1,520 +0,0 @@ -use std::{ - collections::{HashMap, HashSet}, - sync::Arc, - time::{SystemTime, UNIX_EPOCH}, -}; - -use bytes::BufMut; -use dashmap::{DashMap, DashSet}; -use ddns::core::{ - parser::{packet::be_packet, record::RData}, - signature::SignatureFields, - wire::ResponseRecord, -}; -use deadpool_redis::Pool; -use nom::{ - IResult, - bytes::streaming::take, - number::streaming::{be_u32, be_u64}, -}; -use tokio::time::Instant; - -use crate::{geo::GeoResolver, policy::DomainPolicies}; - -// --------------------------------------------------------------------------- -// Storage helpers -// --------------------------------------------------------------------------- - -/// SHA-256 fingerprint of a DER-encoded certificate, used as per-source dedup key. -pub fn cert_fingerprint(cert_der: &[u8]) -> [u8; 32] { - use ring::digest::{SHA256, digest}; - let d = digest(&SHA256, cert_der); - d.as_ref().try_into().expect("SHA-256 is always 32 bytes") -} - -pub fn fingerprint_hex(fingerprint: &[u8; 32]) -> String { - fingerprint.iter().map(|b| format!("{b:02x}")).collect() -} - -pub fn cert_fingerprint_hex(cert_der: &[u8]) -> String { - fingerprint_hex(&cert_fingerprint(cert_der)) -} - -pub fn unix_now_secs() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|d| d.as_secs()) - .unwrap_or(0) -} - -pub fn redis_primary_key(host: &str, fingerprint_hex: &str) -> String { - format!("{host}:fp:{fingerprint_hex}") -} - -pub fn redis_all_index_key(host: &str) -> String { - format!("{host}:idx:all") -} - -pub fn redis_country_index_key(host: &str, country: &str) -> String { - format!("{host}:idx:country:{country}") -} - -pub fn redis_asn_index_key(host: &str, asn: u32) -> String { - format!("{host}:idx:asn:{asn}") -} - -pub fn redis_blacklist_key() -> &'static str { - "ddns:blacklist" -} - -#[derive(Clone, Debug, Default, PartialEq, Eq)] -pub struct RecordIndexTags { - pub countries: Vec, - pub asns: Vec, -} - -pub fn record_index_tags(dns_bytes: &[u8], geo: Option<&GeoResolver>) -> RecordIndexTags { - let Some(geo) = geo else { - return RecordIndexTags::default(); - }; - - let Ok((_, packet)) = be_packet(dns_bytes) else { - return RecordIndexTags::default(); - }; - - let mut countries = HashSet::new(); - let mut asns = HashSet::new(); - - for answer in &packet.answers { - let RData::E(endpoint) = answer.data() else { - continue; - }; - - let traits = geo.lookup_traits(endpoint.addr().ip()); - if let Some(country) = traits.country { - countries.insert(country); - } - if let Some(asn) = traits.asn { - asns.insert(asn); - } - } - - let mut countries = countries.into_iter().collect::>(); - countries.sort(); - - let mut asns = asns.into_iter().collect::>(); - asns.sort_unstable(); - - RecordIndexTags { countries, asns } -} - -// --------------------------------------------------------------------------- -// Redis primary record wire type -// --------------------------------------------------------------------------- - -/// One record as persisted in the Redis primary record value. -/// -/// Wire layout (big-endian, contiguous): -/// ```text -/// +-----------+--------------+---------------+--------+-----------+------+-----------+------+-----------+------+-----------+------+ -/// | expire | fingerprint | digest_len | digest | input_len | input| sig_len | sig | dns_len | dns | cert_len | cert | -/// | u64 BE | 32 bytes | u32 BE | ... | u32 BE | ... | u32 BE | ... | u32 BE | ... | u32 BE | ... | -/// +-----------+--------------+---------------+--------+-----------+------+-----------+------+-----------+------+-----------+------+ -/// ``` -#[derive(Debug, Clone)] -pub struct StoredRecord { - /// Unix timestamp (seconds) after which this entry is considered stale. - pub expire_unix_secs: u64, - /// SHA-256 fingerprint of the publisher's leaf certificate. - /// Serves as the publisher's identity: uniquely identifies a certificate among multiple - /// valid certs that may be issued for the same domain (from different CAs, at different times, - /// for different regions, etc.). Used as storage key to enable multi-publisher scenarios. - pub fingerprint: [u8; 32], - /// Serialised DNS packet bytes. - pub dns: Vec, - /// DER-encoded leaf certificate of the publisher. - pub cert: Vec, - /// Saved RFC-style publisher signature fields for the DNS packet. - pub signature_fields: SignatureFields, -} - -impl StoredRecord { - /// Encode to a byte buffer suitable for use as a Redis primary record value. - pub fn encode(&self) -> Vec { - let mut buf = Vec::with_capacity( - 8 + 32 - + 4 - + self.signature_fields.content_digest.len() - + 4 - + self.signature_fields.signature_input.len() - + 4 - + self.signature_fields.signature.len() - + 4 - + self.dns.len() - + 4 - + self.cert.len(), - ); - buf.put_u64(self.expire_unix_secs); - buf.put_slice(&self.fingerprint); - put_field(&mut buf, &self.signature_fields.content_digest); - put_field(&mut buf, &self.signature_fields.signature_input); - put_field(&mut buf, &self.signature_fields.signature); - put_field(&mut buf, &self.dns); - put_field(&mut buf, &self.cert); - buf - } - - /// Decode from a Redis primary record value. Returns `None` on malformed input. - pub fn decode(data: &[u8]) -> Option { - be_stored_record(data) - .ok() - .and_then(|(remain, r)| remain.is_empty().then_some(r)) - } -} - -/// nom parser for [`StoredRecord`]. -pub fn be_stored_record(input: &[u8]) -> IResult<&[u8], StoredRecord> { - let (input, expire_unix_secs) = be_u64(input)?; - let (input, fp_bytes) = take(32usize)(input)?; - let (input, content_digest) = be_field(input)?; - let (input, signature_input) = be_field(input)?; - let (input, signature) = be_field(input)?; - let (input, dns) = be_field(input)?; - let (input, cert) = be_field(input)?; - Ok(( - input, - StoredRecord { - expire_unix_secs, - fingerprint: fp_bytes.try_into().expect("took exactly 32 bytes"), - dns, - cert, - signature_fields: SignatureFields { - content_digest, - signature_input, - signature, - }, - }, - )) -} - -fn put_field(buf: &mut Vec, value: &[u8]) { - buf.put_u32(value.len() as u32); - buf.put_slice(value); -} - -fn be_field(input: &[u8]) -> IResult<&[u8], Vec> { - let (input, len) = be_u32(input)?; - let (input, value) = take(len as usize)(input)?; - Ok((input, value.to_vec())) -} - -// --------------------------------------------------------------------------- -// Storage -// --------------------------------------------------------------------------- - -/// A single record stored under a (host, server-fingerprint) key. -#[derive(Clone, Debug)] -pub struct Record { - pub dns_bytes: Vec, - pub cert_bytes: Vec, - pub signature_fields: SignatureFields, - /// Wall-clock expiry (for TTL eviction). - pub expire: Instant, - /// Precomputed country / ASN buckets used by the Lite indexes. - pub index_tags: RecordIndexTags, -} - -#[derive(Clone, Debug, Default)] -pub struct HostRecords { - pub records: HashMap<[u8; 32], Record>, - pub recent: Vec<[u8; 32]>, - pub by_country: HashMap>, - pub by_asn: HashMap>, -} - -impl HostRecords { - fn remove_fingerprint(list: &mut Vec<[u8; 32]>, fingerprint: &[u8; 32]) { - list.retain(|existing| existing != fingerprint); - } - - fn remove_from_indexes(&mut self, fingerprint: &[u8; 32], tags: &RecordIndexTags) { - Self::remove_fingerprint(&mut self.recent, fingerprint); - - for country in &tags.countries { - let should_remove = if let Some(bucket) = self.by_country.get_mut(country) { - Self::remove_fingerprint(bucket, fingerprint); - bucket.is_empty() - } else { - false - }; - - if should_remove { - self.by_country.remove(country); - } - } - - for asn in &tags.asns { - let should_remove = if let Some(bucket) = self.by_asn.get_mut(asn) { - Self::remove_fingerprint(bucket, fingerprint); - bucket.is_empty() - } else { - false - }; - - if should_remove { - self.by_asn.remove(asn); - } - } - } - - pub fn insert(&mut self, fingerprint: [u8; 32], record: Record) { - if let Some(old_record) = self.records.remove(&fingerprint) { - self.remove_from_indexes(&fingerprint, &old_record.index_tags); - } - - self.recent.insert(0, fingerprint); - - for country in &record.index_tags.countries { - self.by_country - .entry(country.clone()) - .or_default() - .insert(0, fingerprint); - } - - for asn in &record.index_tags.asns { - self.by_asn.entry(*asn).or_default().insert(0, fingerprint); - } - - self.records.insert(fingerprint, record); - } - - pub fn remove(&mut self, fingerprint: &[u8; 32]) -> Option { - let record = self.records.remove(fingerprint)?; - self.remove_from_indexes(fingerprint, &record.index_tags); - Some(record) - } - - pub fn retain_active(&mut self, now: Instant) { - let expired = self - .records - .iter() - .filter_map(|(fingerprint, record)| (record.expire <= now).then_some(*fingerprint)) - .collect::>(); - - for fingerprint in expired { - let _ = self.remove(&fingerprint); - } - } - - pub fn collect_candidates( - &self, - source_country: Option<&str>, - source_asn: Option, - total_cap: usize, - asn_cap: usize, - country_cap: usize, - all_cap: usize, - ) -> Vec<[u8; 32]> { - let mut candidates = Vec::new(); - let mut seen = HashSet::new(); - - let mut push_bucket = |bucket: Option<&Vec<[u8; 32]>>, bucket_cap: usize| { - let Some(bucket) = bucket else { - return; - }; - - for fingerprint in bucket.iter().take(bucket_cap) { - if candidates.len() >= total_cap { - break; - } - - if seen.insert(*fingerprint) { - candidates.push(*fingerprint); - } - } - }; - - if let Some(asn) = source_asn { - push_bucket(self.by_asn.get(&asn), asn_cap.min(total_cap)); - } - - if let Some(country) = source_country { - push_bucket(self.by_country.get(country), country_cap.min(total_cap)); - } - - push_bucket(Some(&self.recent), all_cap.min(total_cap)); - candidates - } - - pub fn is_empty(&self) -> bool { - self.records.is_empty() - } -} - -/// Unified in-memory storage: host → { cert_fingerprint → Record }. -/// Both Standard and OpenMulti policies share this map. -/// -/// Per-fingerprint keying design supports PKI's multi-certificate model: -/// A single domain can have multiple valid certificates issued by different CAs, -/// or by the same CA at different times (certificate rotation, multi-region deployment, etc.). -/// Each certificate has a unique fingerprint as its identity. -/// -/// - Same certificate (same fingerprint) republishing → overwrites the previous record -/// - Different certificates (different fingerprints) for same domain → coexist independently -/// - Clients query get all valid records and choose which one to use -#[derive(Clone)] -pub struct MemoryStorage { - pub records: Arc>, - pub blacklist: Arc>, -} - -impl MemoryStorage { - pub fn new() -> Self { - Self { - records: Arc::new(DashMap::new()), - blacklist: Arc::new(DashSet::new()), - } - } - - pub fn with_blacklist(hosts: impl IntoIterator) -> Self { - let storage = Self::new(); - for host in hosts { - storage.blacklist_host(host); - } - storage - } - - pub fn blacklist_host(&self, host: impl Into) { - self.blacklist.insert(host.into()); - } - - pub fn remove_blacklist_host(&self, host: &str) { - self.blacklist.remove(host); - } - - pub fn is_blacklisted(&self, host: &str) -> bool { - self.blacklist.contains(host) - } -} - -#[derive(Clone)] -pub struct RedisStorage { - pub write: Pool, - pub read: Pool, -} - -#[derive(Clone)] -pub enum Storage { - Redis(RedisStorage), - Memory(MemoryStorage), -} - -pub type LookupRecord = ResponseRecord; -pub type SeedRecords = Arc>>; - -// --------------------------------------------------------------------------- -// Application state -// --------------------------------------------------------------------------- - -#[derive(Clone)] -pub struct AppState { - pub storage: Storage, - pub host_allowlist: Arc>, - pub require_signature: bool, - pub ttl_secs: u64, - pub policies: Arc, - pub seed_records: SeedRecords, - pub geo: Option>, -} - -#[cfg(test)] -mod tests { - use super::*; - - fn fp(seed: u8) -> [u8; 32] { - [seed; 32] - } - - fn record(country: Option<&str>, asn: Option) -> Record { - Record { - dns_bytes: Vec::new(), - cert_bytes: Vec::new(), - signature_fields: SignatureFields::empty(), - expire: Instant::now() + tokio::time::Duration::from_secs(60), - index_tags: RecordIndexTags { - countries: country.into_iter().map(str::to_owned).collect(), - asns: asn.into_iter().collect(), - }, - } - } - - #[test] - fn host_records_collect_candidates_prefers_asn_then_country_then_recent() { - let mut host = HostRecords::default(); - host.insert(fp(1), record(Some("US"), Some(64512))); - host.insert(fp(2), record(Some("US"), None)); - host.insert(fp(3), record(Some("JP"), None)); - - let candidates = host.collect_candidates(Some("US"), Some(64512), 8, 2, 2, 8); - - assert_eq!(candidates, vec![fp(1), fp(2), fp(3)]); - } - - #[test] - fn host_records_remove_cleans_secondary_indexes() { - let mut host = HostRecords::default(); - let fingerprint = fp(9); - host.insert(fingerprint, record(Some("US"), Some(64512))); - - let _ = host.remove(&fingerprint); - - assert!(host.recent.is_empty()); - assert!(host.by_country.is_empty()); - assert!(host.by_asn.is_empty()); - assert!(host.records.is_empty()); - } - - #[test] - fn stored_record_roundtrips_signature_fields() { - let record = StoredRecord { - expire_unix_secs: 123, - fingerprint: fp(7), - dns: vec![1, 2, 3], - cert: vec![4, 5, 6], - signature_fields: SignatureFields { - content_digest: b"sha-256=:abc:".to_vec(), - signature_input: - b"dns=(\"content-digest\");created=1;keyid=\"sha256:abc\";alg=\"ed25519\"" - .to_vec(), - signature: b"dns=:sig:".to_vec(), - }, - }; - - let decoded = StoredRecord::decode(&record.encode()).expect("stored record decodes"); - - assert_eq!(decoded.expire_unix_secs, record.expire_unix_secs); - assert_eq!(decoded.fingerprint, record.fingerprint); - assert_eq!(decoded.dns, record.dns); - assert_eq!(decoded.cert, record.cert); - assert_eq!(decoded.signature_fields, record.signature_fields); - } - - #[test] - fn redis_blacklist_key_is_stable() { - assert_eq!(redis_blacklist_key(), "ddns:blacklist"); - } - - #[test] - fn memory_storage_tracks_blacklisted_hosts() { - let storage = MemoryStorage::with_blacklist(["blocked.example".to_string()]); - - assert!(storage.is_blacklisted("blocked.example")); - assert!(!storage.is_blacklisted("allowed.example")); - - storage.blacklist_host("other.example"); - assert!(storage.is_blacklisted("other.example")); - - storage.remove_blacklist_host("blocked.example"); - assert!(!storage.is_blacklisted("blocked.example")); - } -} diff --git a/src/mdns/resolvers.rs b/src/mdns/resolvers.rs deleted file mode 100644 index 1bd416e..0000000 --- a/src/mdns/resolvers.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod mdns; diff --git a/src/mdns/resolvers/mdns.rs b/src/mdns/resolvers/mdns.rs deleted file mode 100644 index 16912ab..0000000 --- a/src/mdns/resolvers/mdns.rs +++ /dev/null @@ -1,354 +0,0 @@ -use std::{fmt, io, net::IpAddr}; -#[cfg(feature = "mdns-resolver")] -use std::{net::SocketAddr, sync::Arc}; - -#[cfg(feature = "mdns-resolver")] -use dquic::qresolve::RecordStream; -use dquic::{ - qbase::net::Family, - qresolve::{Publish, PublishFuture, Resolve, ResolveFuture, Source}, -}; -use futures::{FutureExt, StreamExt, TryFutureExt, future, stream}; -#[cfg(feature = "mdns-resolver")] -use futures::{Stream, stream::FuturesUnordered}; - -#[cfg(feature = "mdns-resolver")] -use super::super::protocol::MdnsProtocol; -#[cfg(feature = "mdns-resolver")] -use crate::core::parser::packet::Packet; -use crate::core::parser::record::RData; -pub type MdnsResolver = crate::mdns::service::Mdns; - -impl MdnsResolver { - pub fn source(&self) -> Source { - Source::Mdns { - nic: self.bound_nic().into(), - family: match self.bound_ip() { - IpAddr::V4(..) => Family::V4, - IpAddr::V6(..) => Family::V6, - }, - } - } -} - -impl fmt::Display for MdnsResolver { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.source(), f) - } -} - -impl Publish for MdnsResolver { - fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { - let endpoints = match endpoints_from_packet(packet) { - Ok(endpoints) => endpoints, - Err(error) => return future::ready(Err(error)).boxed(), - }; - self.insert_host(name.to_string(), endpoints); - future::ready(Ok(())).boxed() - } -} - -impl Resolve for MdnsResolver { - fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { - let source = self.source(); - self.query(name.to_owned()) - .map_ok(move |list| { - let endpoints = crate::resolvers::selector::selected_endpoint_addrs(list); - stream::iter(endpoints.into_iter().map(move |ep| (source.clone(), ep))).boxed() - }) - .boxed() - } -} - -fn endpoints_from_packet(packet: &[u8]) -> io::Result> { - use crate::core::parser::packet::be_packet; - - be_packet(packet) - .map(|(_, pkt)| { - pkt.answers - .iter() - .filter_map(|rr| match rr.data() { - RData::E(ep) => Some(ep.clone()), - _ => None, - }) - .collect::>() - }) - .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error.to_string())) -} - -#[cfg(feature = "mdns-resolver")] -pub struct MdnsBindDriver { - iface_manager: Arc, - null_io_factory: Arc, - service_name: Arc, -} - -#[cfg(feature = "mdns-resolver")] -impl MdnsBindDriver { - pub fn new(service_name: impl Into>) -> Self { - Self { - iface_manager: Arc::new(h3x::dquic::net::InterfaceManager::new()), - null_io_factory: Arc::new(h3x::dquic::NullIoFactory), - service_name: service_name.into(), - } - } - - fn install_or_rebind_mdns( - &self, - network: &h3x::dquic::Network, - bind_iface: &h3x::dquic::net::BindInterface, - ) { - let bind_uri = bind_iface.bind_uri(); - let Some((family, device, _port)) = bind_uri.as_iface_bind_uri() else { - tracing::debug!(%bind_uri, "skipping mdns binding for non-interface bind uri"); - return; - }; - let Some(ip) = network.resolve_device_addr(device, family) else { - tracing::debug!(%bind_uri, "skipping mdns binding without local interface address"); - return; - }; - - bind_iface.with_components_mut(|components, _iface| { - match components.try_init_with(|| crate::mdns::service::Mdns::new(&self.service_name, ip, device)) { - Ok(mdns) => mdns.reinit_on(device, ip), - Err(error) => { - let report = snafu::Report::from_error(&error); - tracing::debug!(error = %report, %bind_uri, "failed to initialize mdns binding"); - } - } - }); - } -} - -#[cfg(feature = "mdns-resolver")] -impl h3x::dquic::BindDriver for MdnsBindDriver { - fn bind<'a>( - &'a self, - network: &'a h3x::dquic::Network, - uri: h3x::dquic::net::BindUri, - ) -> futures::future::BoxFuture<'a, h3x::dquic::net::BindInterface> { - async move { - let iface = self - .iface_manager - .bind(uri, self.null_io_factory.clone()) - .await; - self.install_or_rebind_mdns(network, &iface); - iface - } - .boxed() - } - - fn rebind<'a>( - &'a self, - network: &'a h3x::dquic::Network, - iface: &'a h3x::dquic::net::BindInterface, - ) -> futures::future::BoxFuture<'a, ()> { - async move { - self.install_or_rebind_mdns(network, iface); - } - .boxed() - } -} - -#[cfg(feature = "mdns-resolver")] -pub struct MdnsResolvers { - network: Arc, - driver: Arc, - patterns: Arc>, - _handles: Vec, -} - -#[cfg(feature = "mdns-resolver")] -#[derive(Debug, Clone)] -pub struct BoundMdnsResolver { - pub device: String, - pub family: Family, - pub resolver: MdnsResolver, -} - -#[cfg(feature = "mdns-resolver")] -impl fmt::Debug for MdnsResolvers { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("MdnsResolvers") - .field("patterns", &self.patterns) - .finish_non_exhaustive() - } -} - -#[cfg(feature = "mdns-resolver")] -impl fmt::Display for MdnsResolvers { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("mDNS resolvers") - } -} - -#[cfg(feature = "mdns-resolver")] -impl MdnsResolvers { - pub async fn bind( - network: Arc, - patterns: Arc>, - service_name: impl Into>, - ) -> Self { - let driver = Arc::new(MdnsBindDriver::new(service_name)); - let mut handles = Vec::with_capacity(patterns.len()); - for pattern in patterns.iter() { - handles.push(network.bind_with(driver.clone(), pattern.clone()).await); - } - - Self { - network, - driver, - patterns, - _handles: handles, - } - } - - pub fn bound_interfaces( - &self, - pattern: &h3x::dquic::binds::BindPattern, - ) -> Option> { - self.network.get_interfaces_with(&self.driver, pattern) - } - - fn for_each_resolver(&self, mut f: impl FnMut(&MdnsResolver)) { - for pattern in self.patterns.iter() { - let Some(ifaces) = self.bound_interfaces(pattern) else { - continue; - }; - for iface in ifaces { - iface.with_components(|components, _| { - if let Some(mdns) = components.get::() { - f(mdns); - } - }); - } - } - } - - pub fn bound_resolvers(&self) -> Vec { - let mut resolvers = Vec::new(); - for pattern in self.patterns.iter() { - let Some(ifaces) = self.bound_interfaces(pattern) else { - continue; - }; - for iface in ifaces { - let bind_uri = iface.bind_uri(); - let Some((family, device, _port)) = bind_uri.as_iface_bind_uri() else { - continue; - }; - iface.with_components(|components, _| { - if let Some(resolver) = components.get::() { - resolvers.push(BoundMdnsResolver { - device: device.to_owned(), - family, - resolver: resolver.clone(), - }); - } - }); - } - } - resolvers - } - - pub async fn query(&self, name: &str) -> io::Result { - let mut lookup_futures = FuturesUnordered::new(); - let mut has_resolver = false; - self.for_each_resolver(|resolver| { - has_resolver = true; - let source = resolver.source(); - lookup_futures.push( - resolver - .query(name.to_owned()) - .map_ok(move |eps| (source, eps)), - ); - }); - if !has_resolver { - return Err(io::Error::other("no mdns resolvers available")); - } - - let mut last_error = None; - let mut has_success = false; - let mut records = Vec::new(); - while let Some(result) = lookup_futures.next().await { - match result { - Ok((source, endpoints)) => { - has_success = true; - records.extend( - endpoints - .into_iter() - .map(|endpoint| (source.clone(), endpoint)), - ); - } - Err(error) => last_error = Some(error), - } - } - - if !has_success { - return Err( - last_error.unwrap_or_else(|| io::Error::other("no mdns resolvers available")) - ); - } - - let records = crate::resolvers::selector::selected_endpoint_records(records); - - Ok(stream::iter(records).boxed()) - } - - /// Discover mDNS broadcasts from all active resolvers. - pub fn discover(&self) -> impl Stream + use<> { - let mut protos = Vec::new(); - self.for_each_resolver(|resolver| { - protos.push(resolver.protocol()); - }); - - async fn receive_one( - proto: Arc, - ) -> Option<((SocketAddr, Packet), Arc)> { - let result = proto.receive_boardcast().await.ok()?; - Some((result, proto)) - } - - let mut pending = protos - .into_iter() - .map(receive_one) - .collect::>(); - - Box::pin(stream::poll_fn(move |cx| { - use std::task::Poll; - loop { - match pending.poll_next_unpin(cx) { - Poll::Ready(Some(Some((item, proto)))) => { - pending.push(receive_one(proto)); - return Poll::Ready(Some(item)); - } - Poll::Ready(Some(None)) => continue, - Poll::Ready(None) => return Poll::Ready(None), - Poll::Pending => return Poll::Pending, - } - } - })) - } -} - -#[cfg(feature = "mdns-resolver")] -impl Publish for MdnsResolvers { - fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { - let endpoints = match endpoints_from_packet(packet) { - Ok(endpoints) => endpoints, - Err(error) => return future::ready(Err(error)).boxed(), - }; - - self.for_each_resolver(|resolver| { - resolver.insert_host(name.to_string(), endpoints.clone()); - }); - - future::ready(Ok(())).boxed() - } -} - -#[cfg(feature = "mdns-resolver")] -impl Resolve for MdnsResolvers { - fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { - self.query(name).boxed() - } -} diff --git a/src/publisher.rs b/src/publisher.rs deleted file mode 100644 index 81b2a66..0000000 --- a/src/publisher.rs +++ /dev/null @@ -1,1222 +0,0 @@ -use std::{ - any::{Any, TypeId}, - collections::{HashMap, HashSet}, - future::Future, - io, - net::SocketAddr, - pin::Pin, - sync::Arc, - time::Duration, -}; - -use dhttp_identity::{identity::LocalAuthority, name::Name}; -#[cfg(feature = "mdns-resolver")] -use dquic::qbase::net::Family; -use dquic::{ - qbase::net::addr::EndpointAddr, - qinterface::component::location::AddressEvent, - qresolve::{Publish, Resolve}, - qtraversal::nat::client::{ClientLocationData, NatType}, -}; -use snafu::{IntoError, ResultExt, Snafu}; - -use crate::{ - core::{ - MdnsPacket, - parser::record::endpoint::EndpointAddr as DnsEndpointAddr, - signature::{SignatureFields, SignatureFieldsError}, - }, - resolvers::Resolvers, -}; - -#[cfg(feature = "h3x-resolver")] -type DeferredH3Resolver = crate::resolvers::deferred::DeferredResolver< - crate::resolvers::h3::H3Resolver, ->; - -pub const DEFAULT_PUBLISH_INTERVAL: Duration = Duration::from_secs(20); -/// Upper bound for a single publish attempt in the background loop. -/// -/// Network changes can leave an in-flight H3 publish waiting on paths that no -/// longer exist. Timing out the attempt keeps consecutive publishes -/// independent: the next interval observes the current bindings again. -pub const DEFAULT_PUBLISH_TIMEOUT: Duration = Duration::from_secs(10); -const PUBLISH_CHANGE_DEBOUNCE: Duration = Duration::from_millis(50); - -type PublishLoopFuture<'a> = Pin + Send + 'a>>; - -#[derive(Debug, Snafu)] -#[snafu(module(create_publisher_error))] -pub enum CreatePublisherError { - #[snafu(display("anonymous endpoint cannot publish dns records"))] - AnonymousEndpoint, -} - -#[derive(Debug, Snafu)] -#[snafu(module)] -pub enum PublishOnceError { - #[snafu(display("no publisher resolver available"))] - NoPublisherResolver, - #[snafu(display("failed to encode endpoint address"))] - EncodeEndpoint, - #[snafu(display("failed to sign dns packet"))] - SignPacket { source: SignatureFieldsError }, - #[snafu(display("failed to publish dns packet with {publisher}"))] - Publish { - publisher: String, - source: io::Error, - }, -} - -/// Optional metadata applied to endpoint records before signing. -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] -pub struct PublishOptions { - /// Stable server identifier for names served by multiple publishers. - /// - /// `0` marks the endpoint as the main record. Non-zero values mark the - /// record as clustered and encode the identifier as its sequence number. - pub server_id: Option, -} - -#[derive(Debug, Clone, Default)] -pub struct PublishAddresses { - wide_area: Vec, -} - -impl PublishAddresses { - pub fn new() -> Self { - Self::default() - } - - pub fn wide_area(mut self, endpoints: impl IntoIterator) -> Self { - self.wide_area.extend(endpoints); - self - } -} - -#[derive(Debug, Clone)] -pub struct EndpointPublisher { - inner: Arc, -} - -impl EndpointPublisher { - pub fn new(inner: Arc) -> Self { - Self { inner } - } - - pub async fn publish_once( - &self, - name: &Name<'_>, - addresses: &PublishAddresses, - ) -> Result<(), PublishOnceError> { - self.inner - .publish_addresses_to_resolver(self.inner.resolver.as_ref(), name, &addresses.wide_area) - .await - } -} - -pub struct Publisher { - identity: Arc, - network: Arc, - resolver: Arc, - bind_patterns: Arc>, - interval: Duration, - publish_timeout: Duration, - options: PublishOptions, -} - -impl std::fmt::Debug for Publisher { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Publisher") - .field("identity", &self.identity.name()) - .field("bind_patterns", &self.bind_patterns) - .field("interval", &self.interval) - .field("publish_timeout", &self.publish_timeout) - .field("options", &self.options) - .finish_non_exhaustive() - } -} - -impl Publisher { - pub fn new( - identity: Arc, - network: Arc, - resolver: Arc, - bind_patterns: Arc>, - ) -> Self { - Self { - identity, - network, - resolver, - bind_patterns, - interval: DEFAULT_PUBLISH_INTERVAL, - publish_timeout: DEFAULT_PUBLISH_TIMEOUT, - options: PublishOptions::default(), - } - } - - pub fn with_options(mut self, options: PublishOptions) -> Self { - self.options = options; - self - } - - pub fn options(&self) -> PublishOptions { - self.options - } - - pub fn interval(&self) -> Duration { - self.interval - } - - pub fn publish_timeout(&self) -> Duration { - self.publish_timeout - } - - pub fn with_publish_timeout(mut self, timeout: Duration) -> Self { - self.publish_timeout = timeout; - self - } - - pub async fn publish_once(&self) -> Result<(), PublishOnceError> { - let mut published = false; - let public_endpoints = self.public_endpoints(); - tracing::debug!( - endpoint_count = public_endpoints.len(), - endpoints = ?public_endpoints, - "publishing public endpoints" - ); - published |= self - .publish_to_resolver(self.resolver.as_ref(), &public_endpoints) - .await?; - - if !published { - return publish_once_error::NoPublisherResolverSnafu.fail(); - } - - Ok(()) - } - - async fn publish_addresses_to_resolver( - &self, - resolver: &(dyn Resolve + Send + Sync), - name: &Name<'_>, - endpoints: &[EndpointAddr], - ) -> Result<(), PublishOnceError> { - let any: &dyn Any = resolver; - - if let Some(resolvers) = any.downcast_ref::() { - for resolver in resolvers.iter() { - self.publish_addresses_to_single_resolver(resolver.as_ref(), name, endpoints) - .await?; - } - return Ok(()); - } - - self.publish_addresses_to_single_resolver(resolver, name, endpoints) - .await - } - - async fn publish_addresses_to_single_resolver( - &self, - resolver: &(dyn Resolve + Send + Sync), - name: &Name<'_>, - endpoints: &[EndpointAddr], - ) -> Result<(), PublishOnceError> { - let packet = self.dns_packet_for_name(&name.to_string(), endpoints)?; - let any: &dyn Any = resolver; - - #[cfg(feature = "http-resolver")] - if let Some(http) = any.downcast_ref::() { - let signature_fields = SignatureFields::sign(&packet, self.identity.as_ref()) - .await - .context(publish_once_error::SignPacketSnafu)?; - http.publish_signed(&name.to_string(), &packet, &signature_fields) - .await - .context(publish_once_error::PublishSnafu { - publisher: http.to_string(), - })?; - return Ok(()); - } - - #[cfg(feature = "h3x-resolver")] - if let Some(h3) = - any.downcast_ref::>() - { - let signature_fields = SignatureFields::sign(&packet, self.identity.as_ref()) - .await - .context(publish_once_error::SignPacketSnafu)?; - h3.publish_signed(&name.to_string(), &packet, &signature_fields) - .await - .context(publish_once_error::PublishSnafu { - publisher: h3.to_string(), - })?; - return Ok(()); - } - - Ok(()) - } - - fn dns_packet_for_name( - &self, - name: &str, - endpoints: &[EndpointAddr], - ) -> Result, PublishOnceError> { - let mut encoded = Vec::with_capacity(endpoints.len()); - for endpoint in endpoints { - encoded.push( - DnsEndpointAddr::try_from(*endpoint) - .map_err(|_| publish_once_error::EncodeEndpointSnafu.build())?, - ); - } - - let mut hosts = HashMap::new(); - hosts.insert(name.to_owned(), encoded); - Ok(MdnsPacket::answer(0, &hosts).to_bytes()) - } - - pub async fn run(&self) -> ! { - let mut locations = self.network.quic().locations().subscribe(); - let interval = tokio::time::sleep(self.interval); - tokio::pin!(interval); - // Keep at most one publish attempt in flight. A timer tick or - // publishable location change drops the current future and starts a new - // debounced attempt so a stale H3 publish cannot block publication from - // the latest bindings. - let mut current_publish = self.new_publish_loop_future(); - - loop { - tokio::select! { - _ = &mut current_publish => { - current_publish = Self::pending_publish_loop_future(); - } - _ = &mut interval => { - interval.as_mut().reset(tokio::time::Instant::now() + self.interval); - self.clear_publish_state(); - current_publish = self.new_publish_loop_future(); - } - event = locations.recv() => { - let Some((bind_uri, event)) = event else { - continue; - }; - if !self.bind_patterns.iter().any(|pattern| pattern.matches(&bind_uri)) { - continue; - } - if !Self::location_event_requires_publish(&event) { - continue; - } - - self.clear_publish_state(); - current_publish = self.new_publish_loop_future(); - } - } - } - } - - fn new_publish_loop_future(&self) -> PublishLoopFuture<'_> { - Box::pin(async move { - tokio::time::sleep(PUBLISH_CHANGE_DEBOUNCE).await; - let _ = self.publish_attempt().await; - }) - } - - fn pending_publish_loop_future<'a>() -> PublishLoopFuture<'a> { - Box::pin(std::future::pending()) - } - - async fn publish_attempt(&self) -> bool { - tracing::trace!( - timeout_ms = self.publish_timeout.as_millis(), - "starting dns publish attempt" - ); - match tokio::time::timeout(self.publish_timeout, self.publish_once()).await { - Ok(Ok(())) => { - tracing::info!("published resolver endpoints"); - true - } - Ok(Err(error)) => { - let report = snafu::Report::from_error(&error); - tracing::warn!(error = %report, "dns publish failed"); - false - } - Err(_elapsed) => { - // Dropping a timed-out publish future does not let the H3 - // resolver observe a request error. Reset resolver-owned - // connection state so the next interval reconnects from - // the current network bindings. - self.clear_publish_state(); - tracing::warn!( - timeout_ms = self.publish_timeout.as_millis(), - "dns publish timed out" - ); - false - } - } - } - - fn location_event_requires_publish(event: &AddressEvent) -> bool { - match event { - AddressEvent::Upsert(data) => { - // `Locations` also carries transient STUN failures. Those do - // not add a publishable endpoint; treating them as publish - // triggers creates a retry loop while the node is offline. - if let Some(bound_addr) = data.downcast_ref::>() { - return bound_addr.is_ok(); - } - if let Some(stun_addr) = data.downcast_ref::() { - return stun_addr.is_ok(); - } - false - } - AddressEvent::Remove(type_id) => { - *type_id == TypeId::of::>() - || *type_id == TypeId::of::() - } - AddressEvent::Closed => true, - } - } - - async fn publish_to_resolver( - &self, - resolver: &(dyn Resolve + Send + Sync), - public_endpoints: &[EndpointAddr], - ) -> Result { - let any: &dyn Any = resolver; - - if let Some(resolvers) = any.downcast_ref::() { - let mut published = false; - for resolver in resolvers.iter() { - published |= self - .publish_single_resolver(resolver.as_ref(), public_endpoints) - .await?; - } - return Ok(published); - } - - self.publish_single_resolver(resolver, public_endpoints) - .await - } - - fn clear_publish_state(&self) { - Self::clear_resolver_publish_state(self.resolver.as_ref()); - } - - fn clear_resolver_publish_state(resolver: &(dyn Resolve + Send + Sync)) { - let any: &dyn Any = resolver; - - if let Some(resolvers) = any.downcast_ref::() { - for resolver in resolvers.iter() { - Self::clear_resolver_publish_state(resolver.as_ref()); - } - } - - #[cfg(feature = "h3x-resolver")] - if let Some(h3) = - any.downcast_ref::>() - { - h3.clear_pool(); - } - } - - async fn publish_single_resolver( - &self, - resolver: &(dyn Resolve + Send + Sync), - public_endpoints: &[EndpointAddr], - ) -> Result { - #[cfg(not(any(feature = "http-resolver", feature = "h3x-resolver")))] - let _ = public_endpoints; - - let any: &dyn Any = resolver; - - #[cfg(feature = "http-resolver")] - if let Some(http) = any.downcast_ref::() { - self.publish_signed_http_endpoints(http, public_endpoints) - .await?; - return Ok(true); - } - - #[cfg(feature = "h3x-resolver")] - if let Some(h3) = - any.downcast_ref::>() - { - self.publish_signed_h3_endpoints(h3, public_endpoints) - .await?; - return Ok(true); - } - - #[cfg(feature = "h3x-resolver")] - if let Some(h3) = any.downcast_ref::() { - let Some(h3) = h3.get() else { - return Err(publish_once_error::PublishSnafu { - publisher: h3.to_string(), - } - .into_error(io::Error::other( - "deferred h3 resolver has not been initialized", - ))); - }; - self.publish_signed_h3_endpoints(h3, public_endpoints) - .await?; - return Ok(true); - } - - #[cfg(feature = "mdns-resolver")] - if let Some(mdns) = any.downcast_ref::() { - let mut published = false; - for bound in mdns.bound_resolvers() { - let endpoints = self.local_endpoints_for(&bound.device, bound.family); - self.publish_plain_endpoints(&bound.resolver, &endpoints) - .await?; - published = true; - } - return Ok(published); - } - - Ok(false) - } - - async fn publish_plain_endpoints( - &self, - publisher: &(dyn Publish + Send + Sync), - endpoints: &[EndpointAddr], - ) -> Result<(), PublishOnceError> { - let packet = self.dns_packet(endpoints)?; - let name = self.identity.name(); - tracing::debug!( - publisher = %publisher, - name, - endpoint_count = endpoints.len(), - packet_len = packet.len(), - "publishing dns packet" - ); - publisher - .publish(name, &packet) - .await - .context(publish_once_error::PublishSnafu { - publisher: publisher.to_string(), - }) - } - - #[cfg(feature = "http-resolver")] - async fn publish_signed_http_endpoints( - &self, - publisher: &crate::resolvers::http::HttpResolver, - endpoints: &[EndpointAddr], - ) -> Result<(), PublishOnceError> { - let (packet, signature_fields) = self.signed_packet(endpoints).await?; - let name = self.identity.name(); - tracing::debug!( - publisher = %publisher, - name, - endpoint_count = endpoints.len(), - packet_len = packet.len(), - "publishing signed dns packet" - ); - - publisher - .publish_signed(name, &packet, &signature_fields) - .await - .context(publish_once_error::PublishSnafu { - publisher: publisher.to_string(), - }) - } - - #[cfg(feature = "h3x-resolver")] - async fn publish_signed_h3_endpoints( - &self, - publisher: &crate::resolvers::h3::H3Resolver, - endpoints: &[EndpointAddr], - ) -> Result<(), PublishOnceError> { - let (packet, signature_fields) = self.signed_packet(endpoints).await?; - let name = self.identity.name(); - tracing::debug!( - publisher = %publisher, - name, - endpoint_count = endpoints.len(), - packet_len = packet.len(), - "publishing signed dns packet" - ); - - publisher - .publish_signed(name, &packet, &signature_fields) - .await - .context(publish_once_error::PublishSnafu { - publisher: publisher.to_string(), - }) - } - - async fn signed_packet( - &self, - endpoints: &[EndpointAddr], - ) -> Result<(Vec, SignatureFields), PublishOnceError> { - let packet = self.dns_packet(endpoints)?; - let signature_fields = SignatureFields::sign(&packet, self.identity.as_ref()) - .await - .context(publish_once_error::SignPacketSnafu)?; - Ok((packet, signature_fields)) - } - - fn dns_packet(&self, endpoints: &[EndpointAddr]) -> Result, PublishOnceError> { - let mut encoded = Vec::with_capacity(endpoints.len()); - for endpoint in endpoints { - let mut endpoint = DnsEndpointAddr::try_from(*endpoint) - .map_err(|_| publish_once_error::EncodeEndpointSnafu.build())?; - if let Some(server_id) = self.options.server_id { - endpoint.set_main(server_id == 0); - endpoint.set_sequence(server_id.into()); - } - encoded.push(endpoint); - } - - let mut hosts = HashMap::new(); - hosts.insert(self.identity.name().to_owned(), encoded); - Ok(MdnsPacket::answer(0, &hosts).to_bytes()) - } - - fn public_endpoints(&self) -> Vec { - let mut endpoints = Vec::new(); - let mut seen = HashSet::new(); - for pattern in self.bind_patterns.iter() { - let Some(ifaces) = self.network.quic().get_interfaces(pattern) else { - tracing::trace!(?pattern, "no interfaces for bind pattern"); - continue; - }; - for iface in ifaces { - for endpoint in public_endpoints_from_iface(&self.network, &iface) { - push_unique_endpoint(&mut endpoints, &mut seen, endpoint); - } - } - } - endpoints - } - - #[cfg(feature = "mdns-resolver")] - fn local_endpoints_for(&self, device: &str, family: Family) -> Vec { - let mut endpoints = HashSet::new(); - for pattern in self.bind_patterns.iter() { - let Some(ifaces) = self.network.quic().get_interfaces(pattern) else { - continue; - }; - for iface in ifaces { - let bind_uri = iface.bind_uri(); - let Some((iface_family, iface_device, _port)) = bind_uri.as_iface_bind_uri() else { - continue; - }; - if iface_family != family || iface_device != device { - continue; - } - if let Some(endpoint) = local_endpoint_from_iface(&iface, family) { - endpoints.insert(endpoint); - } - } - } - endpoints.into_iter().collect() - } -} - -fn push_unique_endpoint( - endpoints: &mut Vec, - seen: &mut HashSet, - endpoint: EndpointAddr, -) { - if seen.insert(endpoint) { - endpoints.push(endpoint); - } -} - -fn public_endpoints_from_iface( - network: &h3x::dquic::Network, - iface: &h3x::dquic::net::BindInterface, -) -> Vec { - use h3x::dquic::{net::IO, qtraversal::nat::client::StunClientsComponent}; - - iface.with_components(|components, current| { - let bind_uri = current.bind_uri(); - let addr = current.bound_addr().ok(); - let mut endpoints: Vec = components - .get::() - .map(|stun| { - stun.with_clients(|clients| { - clients - .values() - .filter_map(|client| { - let outer = client.get_outer_addr()?.ok()?; - let bound = current.bound_addr().ok()?; - match client.get_nat_type() { - Some(Ok(nat_type)) => Some(publish_endpoint_from_stun( - bound, - client.agent_addr(), - outer, - nat_type, - )), - None => Some(EndpointAddr::with_agent(client.agent_addr(), outer)), - Some(Err(_)) => None, - } - }) - .collect() - }) - }) - .unwrap_or_default(); - let stun_endpoint_count = endpoints.len(); - - // Also publish the current default-route address. STUN-derived - // endpoints make the node reachable from outside the local network, - // while the bound address is still the shortest valid path for peers - // on the same link and for separate local client processes on the - // same host. Keep it after STUN endpoints so translated-NAT peers get - // the externally reachable candidate first. - if let Some(addr) = addr - && network.bound_addr_is_on_default_route(&bind_uri, addr) - { - endpoints.push(EndpointAddr::direct(addr)); - } - - tracing::trace!( - bind_uri = %bind_uri, - bound_addr = ?addr, - stun_endpoint_count, - endpoint_count = endpoints.len(), - endpoints = ?endpoints, - "collected public endpoints from interface" - ); - - endpoints - }) -} - -fn publish_endpoint_from_stun( - bound: SocketAddr, - agent: SocketAddr, - outer: SocketAddr, - nat_type: NatType, -) -> EndpointAddr { - if nat_type == NatType::FullCone && bound == outer { - EndpointAddr::direct(outer) - } else { - EndpointAddr::with_agent(agent, outer) - } -} - -#[cfg(feature = "mdns-resolver")] -fn local_endpoint_from_iface( - iface: &h3x::dquic::net::BindInterface, - family: Family, -) -> Option { - use h3x::dquic::net::IO; - - iface.with_components(|_components, current| { - let addr = current.bound_addr().ok()?; - match (family, addr) { - (Family::V4, std::net::SocketAddr::V4(_)) - | (Family::V6, std::net::SocketAddr::V6(_)) => Some(EndpointAddr::direct(addr)), - _ => None, - } - }) -} - -#[cfg(test)] -mod tests { - use std::{ - fmt, - sync::{ - Arc, OnceLock, - atomic::{AtomicUsize, Ordering}, - }, - time::Duration, - }; - - use dquic::qresolve::{ResolveFuture, Source}; - use futures::{FutureExt, StreamExt, future::BoxFuture, stream}; - use rustls::pki_types::CertificateDer; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - - use super::*; - - #[derive(Debug)] - struct TestAuthority; - - impl LocalAuthority for TestAuthority { - fn name(&self) -> &str { - "authority.example" - } - - fn cert_chain(&self) -> &[CertificateDer<'static>] { - static CERT_CHAIN: OnceLock>> = OnceLock::new(); - CERT_CHAIN.get_or_init(|| { - vec![CertificateDer::from(vec![ - 0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, - ])] - }) - } - - fn sign( - &self, - _data: &[u8], - ) -> BoxFuture<'_, Result, dhttp_identity::identity::SignError>> { - Box::pin(async move { Ok(vec![1, 2, 3]) }) - } - } - - #[derive(Debug)] - struct DisplayOnlyResolver; - - impl fmt::Display for DisplayOnlyResolver { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("display only resolver") - } - } - - impl Resolve for DisplayOnlyResolver { - fn lookup<'l>(&'l self, _name: &'l str) -> ResolveFuture<'l> { - async { Ok(stream::empty::<(Source, EndpointAddr)>().boxed()) }.boxed() - } - } - - #[tokio::test] - async fn publish_once_reports_no_publisher_resolver() { - let publisher = Publisher::new( - Arc::new(TestAuthority), - h3x::dquic::Network::builder().build(), - Arc::new(DisplayOnlyResolver), - Arc::new(Vec::new()), - ); - - let error = publisher.publish_once().await.unwrap_err(); - assert!(matches!(error, PublishOnceError::NoPublisherResolver)); - } - - #[tokio::test] - async fn publisher_timeout_is_configurable() { - let publisher = Publisher::new( - Arc::new(TestAuthority), - h3x::dquic::Network::builder().build(), - Arc::new(DisplayOnlyResolver), - Arc::new(Vec::new()), - ); - assert_eq!(publisher.publish_timeout(), DEFAULT_PUBLISH_TIMEOUT); - - let timeout = Duration::from_secs(3); - let publisher = publisher.with_publish_timeout(timeout); - assert_eq!(publisher.publish_timeout(), timeout); - } - - #[tokio::test] - async fn dns_packet_applies_publish_options_server_id() { - let publisher = Publisher::new( - Arc::new(TestAuthority), - h3x::dquic::Network::builder().build(), - Arc::new(DisplayOnlyResolver), - Arc::new(Vec::new()), - ) - .with_options(PublishOptions { server_id: Some(2) }); - - let endpoint = EndpointAddr::direct("127.0.0.1:443".parse().unwrap()); - let packet = publisher.dns_packet(&[endpoint]).unwrap(); - let (_remain, packet) = crate::core::parser::packet::be_packet(&packet).unwrap(); - let record = packet.answers.first().expect("endpoint answer"); - let crate::core::parser::record::RData::E(endpoint) = record.data() else { - panic!("expected endpoint record"); - }; - - assert!(!endpoint.is_main()); - assert!(endpoint.is_clustered()); - assert!(!endpoint.is_signed()); - } - - #[tokio::test] - async fn public_endpoints_do_not_fall_back_to_local_bound_addresses() { - let network = h3x::dquic::Network::builder().build(); - let bind_pattern: h3x::dquic::binds::BindPattern = - "inet://127.0.0.1:0".parse().expect("valid bind pattern"); - let _bind = network.quic().bind(bind_pattern.clone()).await; - let publisher = Publisher::new( - Arc::new(TestAuthority), - network, - Arc::new(DisplayOnlyResolver), - Arc::new(vec![bind_pattern]), - ); - - assert!( - publisher.public_endpoints().is_empty(), - "public DNS publishing must wait for STUN-derived external endpoints; local addresses are published through mDNS" - ); - } - - #[test] - fn push_unique_endpoint_preserves_first_seen_order() { - let agent = EndpointAddr::with_agent( - "10.10.0.2:20004".parse().expect("valid agent addr"), - "10.10.0.10:45635".parse().expect("valid outer addr"), - ); - let direct = EndpointAddr::direct("10.110.0.10:45635".parse().expect("valid direct addr")); - let mut endpoints = Vec::new(); - let mut seen = HashSet::new(); - - push_unique_endpoint(&mut endpoints, &mut seen, agent); - push_unique_endpoint(&mut endpoints, &mut seen, direct); - push_unique_endpoint(&mut endpoints, &mut seen, agent); - - assert_eq!(endpoints, vec![agent, direct]); - } - - #[test] - fn full_cone_nat_endpoint_preserves_agent_when_outer_differs_from_bound_addr() { - let bound = "10.110.0.10:45635".parse().expect("valid bound addr"); - let agent = "10.10.0.2:20004".parse().expect("valid agent addr"); - let outer = "10.10.0.10:45635".parse().expect("valid outer addr"); - - let endpoint = publish_endpoint_from_stun(bound, agent, outer, NatType::FullCone); - - assert_eq!(endpoint, EndpointAddr::with_agent(agent, outer)); - } - - #[test] - fn full_cone_endpoint_is_direct_without_address_translation() { - let bound = "10.10.0.100:45635".parse().expect("valid bound addr"); - let agent = "10.10.0.2:20004".parse().expect("valid agent addr"); - - let endpoint = publish_endpoint_from_stun(bound, agent, bound, NatType::FullCone); - - assert_eq!(endpoint, EndpointAddr::direct(bound)); - } - - #[cfg(feature = "http-resolver")] - #[tokio::test] - async fn run_restarts_when_publish_attempt_observes_location_change() { - async fn wait_for_count(count: &AtomicUsize, target: usize) { - loop { - if count.load(Ordering::SeqCst) >= target { - return; - } - tokio::time::sleep(Duration::from_millis(20)).await; - } - } - - let network = h3x::dquic::Network::builder().build(); - let bind_uri: h3x::dquic::net::BindUri = - "inet://127.0.0.1:0".parse().expect("valid bind uri"); - let publish_count = Arc::new(AtomicUsize::new(0)); - let listener = tokio::net::TcpListener::bind("127.0.0.1:0") - .await - .expect("bind test http server"); - let port = listener.local_addr().expect("local addr").port(); - let server_network = network.clone(); - let server_bind_uri = bind_uri.clone(); - let server_count = publish_count.clone(); - let server = tokio::spawn(async move { - loop { - let Ok((mut stream, _peer)) = listener.accept().await else { - break; - }; - let current = server_count.fetch_add(1, Ordering::SeqCst) + 1; - let mut buf = [0_u8; 1024]; - let _ = stream.read(&mut buf).await; - if current == 2 { - server_network.quic().locations().upsert( - server_bind_uri.clone(), - Arc::new(Ok::( - "127.0.0.1:10001".parse().expect("valid socket addr"), - )), - ); - } - let _ = stream - .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n") - .await; - } - }); - - let resolver = Arc::new( - crate::resolvers::http::HttpResolver::new(format!("http://127.0.0.1:{port}/")) - .expect("valid http resolver"), - ); - let mut publisher = Publisher::new( - Arc::new(TestAuthority), - network.clone(), - resolver, - Arc::new(vec![ - "inet://127.0.0.1:0".parse().expect("valid bind pattern"), - ]), - ); - publisher.interval = Duration::from_secs(60); - - let publisher = tokio::spawn(async move { - publisher.run().await; - }); - - wait_for_count(&publish_count, 1).await; - tokio::time::sleep(PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(100)).await; - network.quic().locations().upsert( - bind_uri, - Arc::new(Ok::( - "127.0.0.1:10000".parse().expect("valid socket addr"), - )), - ); - - tokio::time::timeout(Duration::from_secs(2), wait_for_count(&publish_count, 2)) - .await - .expect("publishable location changes should trigger the next independent publish"); - - tokio::time::timeout( - PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(500), - wait_for_count(&publish_count, 3), - ) - .await - .expect("publishable location events should replace the current publish attempt"); - - publisher.abort(); - server.abort(); - } - - #[cfg(feature = "http-resolver")] - #[tokio::test] - async fn run_ignores_transient_location_failures_generated_during_publish_attempt() { - async fn wait_for_count(count: &AtomicUsize, target: usize) { - loop { - if count.load(Ordering::SeqCst) >= target { - return; - } - tokio::time::sleep(Duration::from_millis(20)).await; - } - } - - let network = h3x::dquic::Network::builder().build(); - let bind_uri: h3x::dquic::net::BindUri = - "inet://127.0.0.1:0".parse().expect("valid bind uri"); - let publish_count = Arc::new(AtomicUsize::new(0)); - let listener = tokio::net::TcpListener::bind("127.0.0.1:0") - .await - .expect("bind test http server"); - let port = listener.local_addr().expect("local addr").port(); - let server_network = network.clone(); - let server_bind_uri = bind_uri.clone(); - let server_count = publish_count.clone(); - let server = tokio::spawn(async move { - loop { - let Ok((mut stream, _peer)) = listener.accept().await else { - break; - }; - let mut buf = [0_u8; 1024]; - let _ = stream.read(&mut buf).await; - server_count.fetch_add(1, Ordering::SeqCst); - server_network - .quic() - .locations() - .upsert::( - server_bind_uri.clone(), - Arc::new(Err( - dquic::qtraversal::nat::client::DetectOuterAddrError::Rebinded { - bind_uri: server_bind_uri.clone(), - }, - )), - ); - let _ = stream - .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n") - .await; - } - }); - - let resolver = Arc::new( - crate::resolvers::http::HttpResolver::new(format!("http://127.0.0.1:{port}/")) - .expect("valid http resolver"), - ); - let publisher = Publisher::new( - Arc::new(TestAuthority), - network.clone(), - resolver, - Arc::new(vec![ - "inet://127.0.0.1:0".parse().expect("valid bind pattern"), - ]), - ); - let publisher = tokio::spawn(async move { - publisher.run().await; - }); - - wait_for_count(&publish_count, 1).await; - tokio::time::sleep(PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(100)).await; - - network.quic().locations().upsert( - bind_uri, - Arc::new(Ok::( - "127.0.0.1:0".parse().expect("valid socket addr"), - )), - ); - wait_for_count(&publish_count, 2).await; - - let third_publish = tokio::time::timeout( - PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(500), - wait_for_count(&publish_count, 3), - ) - .await; - - publisher.abort(); - server.abort(); - - assert!( - third_publish.is_err(), - "publish-generated location events must not trigger another immediate publish" - ); - } - - #[cfg(feature = "http-resolver")] - #[tokio::test] - async fn run_does_not_retry_location_publish_after_timeout() { - async fn wait_for_count(count: &AtomicUsize, target: usize) { - loop { - if count.load(Ordering::SeqCst) >= target { - return; - } - tokio::time::sleep(Duration::from_millis(20)).await; - } - } - - let network = h3x::dquic::Network::builder().build(); - let bind_uri: h3x::dquic::net::BindUri = - "inet://127.0.0.1:0".parse().expect("valid bind uri"); - let publish_count = Arc::new(AtomicUsize::new(0)); - let listener = tokio::net::TcpListener::bind("127.0.0.1:0") - .await - .expect("bind test http server"); - let port = listener.local_addr().expect("local addr").port(); - let server_count = publish_count.clone(); - let server = tokio::spawn(async move { - loop { - let Ok((mut stream, _peer)) = listener.accept().await else { - break; - }; - let current = server_count.fetch_add(1, Ordering::SeqCst) + 1; - let mut buf = [0_u8; 1024]; - let _ = stream.read(&mut buf).await; - if current == 2 { - tokio::time::sleep(Duration::from_millis(200)).await; - } - let _ = stream - .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n") - .await; - } - }); - - let resolver = Arc::new( - crate::resolvers::http::HttpResolver::new(format!("http://127.0.0.1:{port}/")) - .expect("valid http resolver"), - ); - let mut publisher = Publisher::new( - Arc::new(TestAuthority), - network.clone(), - resolver, - Arc::new(vec![ - "inet://127.0.0.1:0".parse().expect("valid bind pattern"), - ]), - ) - .with_publish_timeout(Duration::from_millis(50)); - publisher.interval = Duration::from_secs(60); - - let publisher = tokio::spawn(async move { - publisher.run().await; - }); - - wait_for_count(&publish_count, 1).await; - tokio::time::sleep(PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(100)).await; - network.quic().locations().upsert( - bind_uri, - Arc::new(Ok::( - "127.0.0.1:0".parse().expect("valid socket addr"), - )), - ); - - wait_for_count(&publish_count, 2).await; - let third_publish = tokio::time::timeout( - PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(500), - wait_for_count(&publish_count, 3), - ) - .await; - - publisher.abort(); - server.abort(); - - assert!( - third_publish.is_err(), - "timed out location-triggered publish must not be retried before the next interval" - ); - } - - #[cfg(feature = "http-resolver")] - #[tokio::test] - async fn run_replaces_in_flight_publish_on_publishable_location_change() { - async fn wait_for_count(count: &AtomicUsize, target: usize) { - loop { - if count.load(Ordering::SeqCst) >= target { - return; - } - tokio::time::sleep(Duration::from_millis(20)).await; - } - } - - let network = h3x::dquic::Network::builder().build(); - let bind_uri: h3x::dquic::net::BindUri = - "inet://127.0.0.1:0".parse().expect("valid bind uri"); - let publish_count = Arc::new(AtomicUsize::new(0)); - let listener = tokio::net::TcpListener::bind("127.0.0.1:0") - .await - .expect("bind test http server"); - let port = listener.local_addr().expect("local addr").port(); - let server_count = publish_count.clone(); - let server = tokio::spawn(async move { - loop { - let Ok((mut stream, _peer)) = listener.accept().await else { - break; - }; - let current = server_count.fetch_add(1, Ordering::SeqCst) + 1; - tokio::spawn(async move { - let mut buf = [0_u8; 1024]; - let _ = stream.read(&mut buf).await; - if current == 1 { - std::future::pending::<()>().await; - } - let _ = stream - .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n") - .await; - }); - } - }); - - let resolver = Arc::new( - crate::resolvers::http::HttpResolver::new(format!("http://127.0.0.1:{port}/")) - .expect("valid http resolver"), - ); - let mut publisher = Publisher::new( - Arc::new(TestAuthority), - network.clone(), - resolver, - Arc::new(vec![ - "inet://127.0.0.1:0".parse().expect("valid bind pattern"), - ]), - ) - .with_publish_timeout(Duration::from_secs(30)); - publisher.interval = Duration::from_secs(60); - - let publisher = tokio::spawn(async move { - publisher.run().await; - }); - - tokio::time::timeout(Duration::from_secs(2), wait_for_count(&publish_count, 1)) - .await - .expect("initial publish should start"); - - network.quic().locations().upsert( - bind_uri, - Arc::new(Ok::( - "127.0.0.1:10000".parse().expect("valid socket addr"), - )), - ); - - tokio::time::timeout( - PUBLISH_CHANGE_DEBOUNCE + Duration::from_millis(800), - wait_for_count(&publish_count, 2), - ) - .await - .expect("publishable location change should replace the in-flight publish"); - - publisher.abort(); - server.abort(); - } -} diff --git a/src/publisher/address.rs b/src/publisher/address.rs deleted file mode 100644 index 207150b..0000000 --- a/src/publisher/address.rs +++ /dev/null @@ -1,444 +0,0 @@ -use std::{ - collections::HashSet, - net::SocketAddr, - sync::{Arc, OnceLock}, -}; - -use dquic::{ - qbase::net::{Family, addr::EndpointAddr}, - qinterface::component::location::Observer, -}; -use h3x::dquic::{ - Network, - binds::BindPattern, - net::{BindInterface, BindUri, IO, Scheme}, - qtraversal::nat::client::{NatType, StunClientsComponent}, -}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum AddressSelector<'a> { - WideArea, - LocalLink { device: &'a str, family: Family }, -} - -pub trait AddressView { - fn endpoints<'a>( - &'a self, - selector: AddressSelector<'a>, - ) -> impl Iterator + 'a; -} - -pub struct FnAddressView { - f: F, -} - -impl FnAddressView { - pub fn new(f: F) -> Self { - Self { f } - } -} - -impl AddressView for FnAddressView -where - F: for<'a> Fn(AddressSelector<'a>) -> I, - I: IntoIterator, - I::IntoIter: 'static, -{ - fn endpoints<'a>( - &'a self, - selector: AddressSelector<'a>, - ) -> impl Iterator + 'a { - (self.f)(selector).into_iter() - } -} - -pub trait AddressViewSource { - fn address_view(&self) -> impl AddressView + Send + Sync + '_; - fn subscribe(&self) -> Observer; - fn observes(&self, bind_uri: &BindUri) -> bool; -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum PublishAddressScope { - WideArea, - LocalLink { device: Arc, family: Family }, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct PublishAddressGroup { - scope: PublishAddressScope, - endpoints: Vec, -} - -impl PublishAddressGroup { - pub fn wide_area(endpoints: I) -> Self - where - I: IntoIterator, - { - Self { - scope: PublishAddressScope::WideArea, - endpoints: endpoints.into_iter().collect(), - } - } - - pub fn local_link(device: impl Into>, family: Family, endpoints: I) -> Self - where - I: IntoIterator, - { - Self { - scope: PublishAddressScope::LocalLink { - device: device.into(), - family, - }, - endpoints: endpoints.into_iter().collect(), - } - } - - fn matches(&self, selector: AddressSelector<'_>) -> bool { - match (&self.scope, selector) { - (PublishAddressScope::WideArea, AddressSelector::WideArea) => true, - ( - PublishAddressScope::LocalLink { device, family }, - AddressSelector::LocalLink { - device: selected_device, - family: selected_family, - }, - ) => device.as_ref() == selected_device && *family == selected_family, - _ => false, - } - } -} - -#[derive(Debug, Clone, Default, PartialEq, Eq)] -pub struct PublishAddresses { - groups: Vec, -} - -impl PublishAddresses { - pub fn new() -> Self { - Self::default() - } - - pub fn group(mut self, group: PublishAddressGroup) -> Self { - self.groups.push(group); - self - } - - pub fn wide_area(self, endpoints: I) -> Self - where - I: IntoIterator, - { - self.group(PublishAddressGroup::wide_area(endpoints)) - } - - pub fn local_link(self, device: impl Into>, family: Family, endpoints: I) -> Self - where - I: IntoIterator, - { - self.group(PublishAddressGroup::local_link(device, family, endpoints)) - } -} - -impl AddressView for PublishAddresses { - fn endpoints<'a>( - &'a self, - selector: AddressSelector<'a>, - ) -> impl Iterator + 'a { - self.groups - .iter() - .filter(move |group| group.matches(selector)) - .flat_map(move |group| group.endpoints.iter().copied()) - } -} - -#[derive(Clone)] -pub struct EndpointBindingAddresses { - network: Arc, - bind_patterns: Arc>, -} - -impl std::fmt::Debug for EndpointBindingAddresses { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("EndpointBindingAddresses") - .field("bind_patterns", &self.bind_patterns) - .finish_non_exhaustive() - } -} - -impl EndpointBindingAddresses { - pub fn new(network: Arc, bind_patterns: Arc>) -> Self { - Self { - network, - bind_patterns, - } - } -} - -impl AddressViewSource for EndpointBindingAddresses { - fn address_view(&self) -> impl AddressView + Send + Sync + '_ { - EndpointBindingAddressView::new(self.network.clone(), self.bind_patterns.clone()) - } - - fn subscribe(&self) -> Observer { - self.network.quic().locations().subscribe() - } - - fn observes(&self, bind_uri: &BindUri) -> bool { - self.bind_patterns - .iter() - .any(|pattern| pattern.matches(bind_uri)) - } -} - -struct EndpointBindingAddressView { - bindings: Vec, -} - -impl EndpointBindingAddressView { - fn new(network: Arc, bind_patterns: Arc>) -> Self { - let mut bindings = Vec::new(); - for pattern in bind_patterns.iter() { - let Some(ifaces) = network.quic().get_interfaces(pattern) else { - tracing::trace!(?pattern, "no interfaces for bind pattern"); - continue; - }; - for iface in ifaces { - bindings.push(BindingAddress::new(network.clone(), pattern.clone(), iface)); - } - } - Self { bindings } - } -} - -impl AddressView for EndpointBindingAddressView { - fn endpoints<'a>( - &'a self, - selector: AddressSelector<'a>, - ) -> impl Iterator + 'a { - let mut seen = HashSet::new(); - self.bindings - .iter() - .filter(move |binding| binding.may_match(selector)) - .flat_map(move |binding| binding.endpoints(selector)) - .filter(move |endpoint| seen.insert(*endpoint)) - } -} - -struct BindingAddress { - network: Arc, - pattern: BindPattern, - bind_uri: BindUri, - iface: BindInterface, - wide_area: OnceLock>, - local_link: OnceLock>, -} - -impl BindingAddress { - fn new(network: Arc, pattern: BindPattern, iface: BindInterface) -> Self { - let bind_uri = iface.bind_uri(); - Self { - network, - pattern, - bind_uri, - iface, - wide_area: OnceLock::new(), - local_link: OnceLock::new(), - } - } - - fn may_match(&self, selector: AddressSelector<'_>) -> bool { - match selector { - AddressSelector::WideArea => true, - AddressSelector::LocalLink { device, family } => { - pattern_may_match_local_link(&self.pattern, device, family) - && bind_uri_matches_local_link(&self.bind_uri, device, family) - } - } - } - - fn endpoints<'a>( - &'a self, - selector: AddressSelector<'a>, - ) -> impl Iterator + 'a { - let endpoints = match selector { - AddressSelector::WideArea => self - .wide_area - .get_or_init(|| public_endpoints_from_iface(&self.network, &self.iface)), - AddressSelector::LocalLink { family, .. } => self - .local_link - .get_or_init(|| local_endpoints_from_iface(&self.iface, family)), - }; - endpoints.iter().copied() - } -} - -fn pattern_may_match_local_link(pattern: &BindPattern, device: &str, family: Family) -> bool { - if pattern.scheme != Scheme::Iface { - return false; - } - if pattern - .host - .family() - .is_some_and(|pattern_family| pattern_family != family) - { - return false; - } - pattern.host.matches(device) -} - -fn bind_uri_matches_local_link(bind_uri: &BindUri, device: &str, family: Family) -> bool { - bind_uri - .as_iface_bind_uri() - .is_some_and(|(iface_family, iface_device, _port)| { - iface_family == family && iface_device == device - }) -} - -fn public_endpoints_from_iface(network: &Network, iface: &BindInterface) -> Vec { - iface.with_components(|components, current| { - let bind_uri = current.bind_uri(); - let addr = current.bound_addr().ok(); - let mut endpoints: Vec = components - .get::() - .map(|stun| { - stun.with_clients(|clients| { - clients - .values() - .filter_map(|client| { - let outer = client.get_outer_addr()?.ok()?; - let bound = current.bound_addr().ok()?; - match client.get_nat_type() { - Some(Ok(nat_type)) => Some(publish_endpoint_from_stun( - bound, - client.agent_addr(), - outer, - nat_type, - )), - None => Some(EndpointAddr::with_agent(client.agent_addr(), outer)), - Some(Err(_)) => None, - } - }) - .collect() - }) - }) - .unwrap_or_default(); - let stun_endpoint_count = endpoints.len(); - - if let Some(addr) = addr - && network.bound_addr_is_on_default_route(&bind_uri, addr) - { - endpoints.push(EndpointAddr::direct(addr)); - } - - tracing::trace!( - bind_uri = %bind_uri, - bound_addr = ?addr, - stun_endpoint_count, - endpoint_count = endpoints.len(), - endpoints = ?endpoints, - "collected wide-area endpoints from interface" - ); - - endpoints - }) -} - -fn publish_endpoint_from_stun( - bound: SocketAddr, - agent: SocketAddr, - outer: SocketAddr, - nat_type: NatType, -) -> EndpointAddr { - if nat_type == NatType::FullCone && bound == outer { - EndpointAddr::direct(outer) - } else { - EndpointAddr::with_agent(agent, outer) - } -} - -fn local_endpoints_from_iface(iface: &BindInterface, family: Family) -> Vec { - iface.with_components(|_components, current| { - let Some(addr) = current.bound_addr().ok() else { - return Vec::new(); - }; - match (family, addr) { - (Family::V4, SocketAddr::V4(_)) | (Family::V6, SocketAddr::V6(_)) => { - vec![EndpointAddr::direct(addr)] - } - _ => Vec::new(), - } - }) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn publish_addresses_select_wide_area_only_for_wide_area_selector() { - let wide = EndpointAddr::direct("203.0.113.10:443".parse().unwrap()); - let local = EndpointAddr::direct("192.168.1.20:443".parse().unwrap()); - let addresses = - PublishAddresses::new() - .wide_area([wide]) - .local_link("en0", Family::V4, [local]); - - let selected: Vec<_> = addresses.endpoints(AddressSelector::WideArea).collect(); - - assert_eq!(selected, vec![wide]); - } - - #[test] - fn publish_addresses_select_matching_local_link_group() { - let en0 = EndpointAddr::direct("192.168.1.20:443".parse().unwrap()); - let en1 = EndpointAddr::direct("192.168.2.20:443".parse().unwrap()); - let addresses = PublishAddresses::new() - .local_link("en0", Family::V4, [en0]) - .local_link("en1", Family::V4, [en1]); - - let selected: Vec<_> = addresses - .endpoints(AddressSelector::LocalLink { - device: "en1", - family: Family::V4, - }) - .collect(); - - assert_eq!(selected, vec![en1]); - } - - #[test] - fn publish_addresses_reject_local_link_family_mismatch() { - let endpoint = EndpointAddr::direct("192.168.1.20:443".parse().unwrap()); - let addresses = PublishAddresses::new().local_link("en0", Family::V4, [endpoint]); - - let selected: Vec<_> = addresses - .endpoints(AddressSelector::LocalLink { - device: "en0", - family: Family::V6, - }) - .collect(); - - assert!(selected.is_empty()); - } - - #[test] - fn full_cone_nat_endpoint_preserves_agent_when_outer_differs_from_bound_addr() { - let bound = "10.110.0.10:45635".parse().expect("valid bound addr"); - let agent = "10.10.0.2:20004".parse().expect("valid agent addr"); - let outer = "10.10.0.10:45635".parse().expect("valid outer addr"); - - let endpoint = publish_endpoint_from_stun(bound, agent, outer, NatType::FullCone); - - assert_eq!(endpoint, EndpointAddr::with_agent(agent, outer)); - } - - #[test] - fn full_cone_endpoint_is_direct_without_address_translation() { - let bound = "10.10.0.100:45635".parse().expect("valid bound addr"); - let agent = "10.10.0.2:20004".parse().expect("valid agent addr"); - - let endpoint = publish_endpoint_from_stun(bound, agent, bound, NatType::FullCone); - - assert_eq!(endpoint, EndpointAddr::direct(bound)); - } -} diff --git a/src/publisher/dispatch.rs b/src/publisher/dispatch.rs deleted file mode 100644 index 84693ff..0000000 --- a/src/publisher/dispatch.rs +++ /dev/null @@ -1,164 +0,0 @@ -use std::any::Any; - -use dhttp_identity::{identity::LocalAuthority, name::Name}; -use dquic::{ - qbase::net::addr::EndpointAddr, - qresolve::{Publish, Resolve}, -}; -use snafu::ResultExt; - -use super::{ - AddressSelector, AddressView, PublishOnceError, Publisher, PublisherResolver, - publish_once_error, -}; -use crate::resolvers::Resolvers; - -#[cfg(feature = "h3x-resolver")] -type DeferredH3Resolver = crate::resolvers::deferred::DeferredResolver< - crate::resolvers::h3::H3Resolver, ->; - -impl Publisher -where - A: LocalAuthority + Send + Sync + ?Sized, - R: PublisherResolver + ?Sized, -{ - pub(crate) async fn publish_to_resolver( - &self, - resolver: &(dyn Resolve + Send + Sync), - name: &Name<'_>, - addresses: &V, - ) -> Result - where - V: AddressView + Sync, - { - let any: &dyn Any = resolver; - - if let Some(resolvers) = any.downcast_ref::() { - let mut published = false; - for resolver in resolvers.iter() { - published |= self - .publish_single_resolver(resolver.as_ref(), name, addresses) - .await?; - } - return Ok(published); - } - - self.publish_single_resolver(resolver, name, addresses) - .await - } - - async fn publish_single_resolver( - &self, - resolver: &(dyn Resolve + Send + Sync), - name: &Name<'_>, - addresses: &V, - ) -> Result - where - V: AddressView + Sync, - { - #[cfg(not(any( - feature = "http-resolver", - feature = "h3x-resolver", - feature = "mdns-resolver" - )))] - { - let _ = name; - let _ = addresses; - } - - let any: &dyn Any = resolver; - - #[cfg(feature = "http-resolver")] - if let Some(http) = any.downcast_ref::() { - self.publish_selected(http, name, addresses, AddressSelector::WideArea) - .await?; - return Ok(true); - } - - #[cfg(feature = "h3x-resolver")] - if let Some(h3) = - any.downcast_ref::>() - { - self.publish_selected(h3, name, addresses, AddressSelector::WideArea) - .await?; - return Ok(true); - } - - #[cfg(feature = "h3x-resolver")] - if let Some(h3) = any.downcast_ref::() { - self.publish_selected(h3, name, addresses, AddressSelector::WideArea) - .await?; - return Ok(true); - } - - #[cfg(feature = "mdns-resolver")] - if let Some(mdns) = any.downcast_ref::() { - let mut published = false; - for bound in mdns.bound_resolvers() { - self.publish_selected( - &bound.resolver, - name, - addresses, - AddressSelector::LocalLink { - device: &bound.device, - family: bound.family, - }, - ) - .await?; - published = true; - } - return Ok(published); - } - - Ok(false) - } - - async fn publish_selected( - &self, - publisher: &(dyn Publish + Send + Sync), - name: &Name<'_>, - addresses: &V, - selector: AddressSelector<'_>, - ) -> Result<(), PublishOnceError> - where - V: AddressView + Sync, - { - let endpoints: Vec = addresses.endpoints(selector).collect(); - let packet = self - .signer - .signed_packet(name, &endpoints) - .await - .context(publish_once_error::SignEndpointRecordsSnafu)?; - tracing::debug!( - publisher = %publisher, - name = %name, - endpoint_count = endpoints.len(), - packet_len = packet.len(), - "publishing dns packet" - ); - publisher - .publish(name.as_str(), &packet) - .await - .context(publish_once_error::PublishSnafu { - publisher: publisher.to_string(), - }) - } -} - -pub(crate) fn clear_resolver_publish_state(resolver: &(dyn Resolve + Send + Sync)) { - let any: &dyn Any = resolver; - - if let Some(resolvers) = any.downcast_ref::() { - for resolver in resolvers.iter() { - clear_resolver_publish_state(resolver.as_ref()); - } - } - - #[cfg(feature = "h3x-resolver")] - if let Some(h3) = - any.downcast_ref::>() - { - h3.clear_pool(); - } -} diff --git a/src/publisher/packet.rs b/src/publisher/packet.rs deleted file mode 100644 index 956afe2..0000000 --- a/src/publisher/packet.rs +++ /dev/null @@ -1,83 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use dhttp_identity::{ - identity::{LocalAuthority, LocalAuthorityCertificateExt}, - name::Name, -}; -use dquic::qbase::net::addr::EndpointAddr; -use snafu::{ResultExt, Snafu}; - -use crate::core::{ - MdnsPacket, - parser::record::endpoint::{EndpointAddr as DnsEndpointAddr, SignEndpointError}, -}; - -#[derive(Debug, Snafu)] -#[snafu(module)] -pub enum SignEndpointRecordsError { - #[snafu(display("failed to encode endpoint address"))] - EncodeEndpoint, - #[snafu(display("failed to extract dhttp certificate selector"))] - CertificateSelector { - source: dhttp_identity::identity::ExtractDhttpSubjectKeyIdentifierError, - }, - #[snafu(display("failed to sign endpoint address"))] - SignEndpoint { source: SignEndpointError }, -} - -pub struct EndpointRecordSigner { - authority: Arc, -} - -impl std::fmt::Debug for EndpointRecordSigner -where - A: LocalAuthority, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("EndpointRecordSigner") - .field("authority", &self.authority.name()) - .finish() - } -} - -impl EndpointRecordSigner -where - A: LocalAuthority + Send + Sync + ?Sized, -{ - pub fn new(authority: Arc) -> Self { - Self { authority } - } - - pub fn authority(&self) -> &Arc { - &self.authority - } - - pub async fn signed_packet( - &self, - name: &Name<'_>, - endpoints: &[EndpointAddr], - ) -> Result, SignEndpointRecordsError> { - let selector = self - .authority - .dhttp_subject_key_identifier() - .context(sign_endpoint_records_error::CertificateSelectorSnafu)?; - let chain = selector.chain(); - - let mut signed = Vec::with_capacity(endpoints.len()); - for endpoint in endpoints { - let Ok(mut endpoint) = DnsEndpointAddr::try_from(*endpoint) else { - return sign_endpoint_records_error::EncodeEndpointSnafu.fail(); - }; - endpoint.set_certificate_chain_key(chain); - endpoint - .sign_with_authority(self.authority.as_ref()) - .await - .context(sign_endpoint_records_error::SignEndpointSnafu)?; - signed.push(endpoint); - } - - let mut hosts = HashMap::new(); - hosts.insert(name.as_str().to_owned(), signed); - Ok(MdnsPacket::answer(0, &hosts).to_bytes()) - } -} diff --git a/src/resolvers/h3.rs b/src/resolvers/h3.rs deleted file mode 100644 index bc0a072..0000000 --- a/src/resolvers/h3.rs +++ /dev/null @@ -1,556 +0,0 @@ -use std::{convert::Infallible, fmt, io, sync::Arc, time::Duration}; - -use dashmap::DashMap; -use dquic::{ - qbase::net::addr::EndpointAddr, - qresolve::{Publish, PublishFuture, RecordStream, Resolve, ResolveFuture, Source}, -}; -use futures::{StreamExt, stream}; -use h3x::{ - dhttp::message::{MessageStreamError, hyper::client::RequestError as HyperRequestError}, - dquic::ConnectError, - endpoint::H3Endpoint, - quic, -}; -use http_body_util::{BodyExt, Empty, Full}; -use tokio::time::Instant; -use tracing::trace; -use url::Url; - -use crate::core::{ - MdnsPacket, - parser::packet::be_packet, - signature::{CONTENT_DIGEST_HEADER, SIGNATURE_HEADER, SIGNATURE_INPUT_HEADER, SignatureFields}, - wire::be_multi_response, -}; - -const LOOKUP_REQUEST_TIMEOUT: Duration = Duration::from_secs(3); -const LOOKUP_REQUEST_ATTEMPTS: usize = 3; - -pub struct H3Resolver { - endpoint: Arc>, - base_url: Url, - cached_records: DashMap, - negative_cache: DashMap, -} - -#[derive(Debug)] -struct Record { - addrs: Vec, - expire: Instant, -} - -impl fmt::Debug for H3Resolver { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("H3Resolver") - .field("base_url", &self.base_url) - .finish_non_exhaustive() - } -} - -impl fmt::Display for H3Resolver { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "H3 DNS Resolver({})", self.base_url) - } -} - -#[derive(Debug, snafu::Snafu)] -pub enum Error { - #[snafu(display("h3 stream error"))] - H3Stream { source: MessageStreamError }, - #[snafu(display("failed to connect h3 endpoint"))] - Connect { source: h3x::pool::ConnectError }, - #[snafu(display("h3 request error"))] - H3Request { - source: HyperRequestError, - }, - #[snafu(display("h3 request timed out after {timeout:?}"))] - RequestTimeout { timeout: Duration }, - - #[snafu(display("{status}"))] - Status { status: http::StatusCode }, - - #[snafu(display("no DNS record found"))] - NoRecordFound, - - #[snafu(display("failed to parse DNS records from response"))] - ParseRecords { - source: nom::Err>>, - }, - - #[snafu(display("failed to decode multi-record response"))] - ParseMultiResponse, -} - -impl H3Resolver -where - C::Error: Send + Sync + 'static, - C::Connection: Send + 'static, -{ - pub fn new( - base_url: impl AsRef, - client: H3Endpoint, - ) -> io::Result { - Self::from_endpoint(base_url, Arc::new(client)) - } - - pub fn from_endpoint( - base_url: impl AsRef, - endpoint: Arc>, - ) -> io::Result { - let base_url = Url::parse(base_url.as_ref()) - .map_err(|error| io::Error::new(io::ErrorKind::InvalidInput, error))?; - base_url.host_str().ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - "base URL must have a valid host", - ) - })?; - - Ok(Self { - endpoint, - base_url, - cached_records: DashMap::new(), - negative_cache: DashMap::new(), - }) - } - - fn connect_error(&self, source: h3x::pool::ConnectError) -> Error { - // H3 DNS resolvers keep a long-lived endpoint. A network transition may - // leave the cached H3 connection with stale QUIC paths, so the next - // attempt must establish a fresh connection instead of reusing it. - self.endpoint.clear_pool(); - Error::Connect { source } - } - - fn request_error(&self, source: HyperRequestError) -> Error { - self.endpoint.clear_pool(); - Error::H3Request { source } - } - - async fn execute_request( - &self, - request: http::Request< - impl http_body::Body + Send + 'static, - >, - ) -> Result< - http::Response>, - Error, - > { - let authority = request - .uri() - .authority() - .expect("h3 dns request URL must include an authority") - .clone(); - tracing::trace!(%authority, "connecting h3 dns endpoint"); - let connection = match self.endpoint.connect(authority.clone()).await { - Ok(connection) => { - tracing::trace!(%authority, "connected h3 dns endpoint"); - connection - } - Err(source) => return Err(self.connect_error(source)), - }; - - let method = request.method().clone(); - let uri = request.uri().clone(); - tracing::trace!(%method, %uri, "executing h3 dns request"); - match connection.execute_hyper_request(request).await { - Ok(response) => { - tracing::trace!( - status = %response.status(), - "h3 dns request response received" - ); - Ok(response) - } - Err(source) => Err(self.request_error(source)), - } - } - - pub fn clear_pool(&self) { - self.endpoint.clear_pool(); - } - - pub async fn publish_endpoints( - &self, - name: &str, - endpoints: &[EndpointAddr], - ) -> Result<(), Error> { - trace!("h3x publishing {} with {} endpoints", name, endpoints.len()); - let bytes = { - let endpoints = endpoints - .iter() - .filter_map(|ep| { - crate::core::parser::record::endpoint::EndpointAddr::try_from(*ep).ok() - }) - .collect(); - let mut hosts = std::collections::HashMap::new(); - hosts.insert(name.to_string(), endpoints); - MdnsPacket::answer(0, &hosts).to_bytes() - }; - - self.publish_packet(name, &bytes).await - } - - /// Publish a pre-built DNS packet (with signatures already included). - pub async fn publish_packet(&self, name: &str, packet: &[u8]) -> Result<(), Error> { - self.publish_packet_with_signature(name, packet, &SignatureFields::empty()) - .await - } - - pub async fn publish_signed( - &self, - name: &str, - packet: &[u8], - signature_fields: &SignatureFields, - ) -> io::Result<()> { - self.publish_packet_with_signature(name, packet, signature_fields) - .await - .map_err(io::Error::other) - } - - async fn publish_packet_with_signature( - &self, - name: &str, - packet: &[u8], - signature_fields: &SignatureFields, - ) -> Result<(), Error> { - let mut url = self.base_url.join("publish").expect("Invalid base URL"); - url.set_query(Some(&format!("host={name}"))); - let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); - tracing::trace!( - name, - packet_len = packet.len(), - url = %self.base_url, - "h3x publishing packet" - ); - let mut request = http::Request::post(uri); - if !signature_fields.is_empty() { - request = request - .header( - CONTENT_DIGEST_HEADER, - signature_fields.content_digest.as_slice(), - ) - .header( - SIGNATURE_INPUT_HEADER, - signature_fields.signature_input.as_slice(), - ) - .header(SIGNATURE_HEADER, signature_fields.signature.as_slice()); - } - let request = request - .body(Full::new(bytes::Bytes::copy_from_slice(packet))) - .expect("h3 dns publish request must be valid"); - let resp = self.execute_request(request).await?; - - if resp.status() != http::StatusCode::OK { - return Err(Error::Status { - status: resp.status(), - }); - } - - Ok(()) - } - - fn retryable_lookup_error(error: &Error) -> bool { - matches!( - error, - Error::Connect { .. } | Error::H3Request { .. } | Error::H3Stream { .. } - ) - } - - async fn lookup_response(&self, uri: http::Uri) -> Result> { - let request = http::Request::get(uri) - .body(Empty::::new()) - .expect("h3 dns lookup request must be valid"); - let resp = self.execute_request(request).await?; - - tracing::trace!("received response with status {}", resp.status()); - match resp.status() { - http::StatusCode::OK => {} - http::StatusCode::NOT_FOUND => return Err(Error::NoRecordFound), - status => return Err(Error::Status { status }), - } - - match resp.into_body().collect().await { - Ok(response) => Ok(response.to_bytes()), - Err(source) => Err(Error::H3Stream { source }), - } - } - - async fn lookup_response_with_retry( - &self, - uri: http::Uri, - ) -> Result> { - for attempt in 1..=LOOKUP_REQUEST_ATTEMPTS { - match tokio::time::timeout(LOOKUP_REQUEST_TIMEOUT, self.lookup_response(uri.clone())) - .await - { - Ok(Ok(response)) => return Ok(response), - Ok(Err(error)) - if Self::retryable_lookup_error(&error) - && attempt < LOOKUP_REQUEST_ATTEMPTS => - { - self.endpoint.clear_pool(); - tracing::debug!( - attempt, - timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), - "h3 dns lookup failed, retrying" - ); - } - Ok(Err(error)) => return Err(error), - Err(_elapsed) if attempt < LOOKUP_REQUEST_ATTEMPTS => { - self.endpoint.clear_pool(); - tracing::debug!( - attempt, - timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), - "h3 dns lookup timed out, retrying" - ); - } - Err(_elapsed) => { - self.endpoint.clear_pool(); - return Err(Error::RequestTimeout { - timeout: LOOKUP_REQUEST_TIMEOUT, - }); - } - } - } - - unreachable!("lookup retry loop returns on the final attempt") - } - - pub async fn lookup(&self, name: &str) -> Result> { - use crate::core::parser::record; - let server = Arc::from(self.base_url.origin().ascii_serialization()); - let source = Source::H3 { server }; - - let Some(domain) = super::resolvable_name(name) else { - return Err(Error::NoRecordFound); - }; - - let now = Instant::now(); - let positive_ttl = Duration::from_secs(10); - let negative_ttl = Duration::from_secs(2); - - self.cached_records - .retain(|_host, record| record.expire > now); - self.negative_cache.retain(|_host, expire| *expire > now); - - if self.negative_cache.get(domain).is_some() { - return Err(Error::NoRecordFound); - } - - if let Some(record) = self.cached_records.get(domain) { - let addrs = record.addrs.clone(); - let stream = stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))); - return Ok(stream.boxed()); - } - - let mut url = self.base_url.join("lookup").expect("Invalid URL"); - url.set_query(Some(&format!("host={}", domain))); - let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); - - tracing::trace!("sending lookup request to {}", self.base_url); - let response = match self.lookup_response_with_retry(uri).await { - Ok(response) => response, - Err(Error::NoRecordFound) => { - self.negative_cache - .insert(domain.to_string(), now + negative_ttl); - return Err(Error::NoRecordFound); - } - Err(error) => return Err(error), - }; - - // Server always returns multi-record format. - let (remain, multi) = - be_multi_response(response.as_ref()).map_err(|_| Error::ParseMultiResponse)?; - if !remain.is_empty() { - return Err(Error::ParseMultiResponse); - } - - let mut addrs = Vec::new(); - for r in multi.records { - if !r.signature_fields.is_empty() { - match r.signature_fields.verify(&r.dns, &r.cert) { - Ok(true) => {} - Ok(false) => { - tracing::debug!("ignored record with invalid DNS packet signature"); - continue; - } - Err(error) => { - tracing::debug!(error = %snafu::Report::from_error(&error), "ignored record with malformed DNS packet signature"); - continue; - } - } - } - - let (_remain, packet) = be_packet(&r.dns).map_err(|source| Error::ParseRecords { - source: source.to_owned(), - })?; - - addrs.extend( - packet - .answers - .iter() - .filter_map(|answer| match answer.data() { - record::RData::E(ep) => { - if answer.name() != domain { - tracing::debug!( - answer_name = %answer.name(), - query = domain, - "ignored endpoint answer for different name" - ); - return None; - } - let endpoint = TryInto::::try_into(ep.clone()).ok()?; - trace!(?endpoint, "parsed endpoint from record"); - Some(endpoint) - } - _ => { - tracing::debug!(?answer, "ignored record"); - None - } - }), - ); - } - - if addrs.is_empty() { - self.negative_cache - .insert(domain.to_string(), now + negative_ttl); - return Err(Error::NoRecordFound); - } - - self.cached_records.insert( - domain.to_string(), - Record { - addrs: addrs.clone(), - expire: now + positive_ttl, - }, - ); - - self.negative_cache.remove(domain); - - Ok(stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))).boxed()) - } -} - -pub type H3Publisher = H3Resolver; - -impl Publish for H3Publisher -where - C::Error: Send + Sync + 'static, - C::Connection: Send + 'static, -{ - fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { - Box::pin(async move { - match self.publish_packet(name, packet).await { - Ok(()) => Ok(()), - Err(error) => Err(io::Error::other(error)), - } - }) - } -} - -impl Resolve for H3Resolver -where - C::Error: Send + Sync + 'static, - C::Connection: Send + 'static, -{ - fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { - Box::pin(async move { - match H3Resolver::lookup(self, name).await { - Ok(stream) => Ok(stream), - Err(error) => Err(io::Error::other(error)), - } - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::resolvers::DHTTP_H3_DNS_SERVER; - - #[test] - fn lookup_retry_budget_leaves_external_timeout_margin() { - let total_budget = LOOKUP_REQUEST_TIMEOUT * LOOKUP_REQUEST_ATTEMPTS as u32; - - assert!( - total_budget <= Duration::from_secs(10), - "h3 lookup must return before common 15s command timeouts so callers can retry" - ); - } - - #[tokio::test] - async fn cached_lookup_reports_h3_dns_source() { - let endpoint = Arc::new(h3x::endpoint::H3Endpoint::new( - h3x::dquic::QuicEndpoint::builder().build().await, - )); - let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); - resolver.cached_records.insert( - "car.lab.dhttp.net".to_owned(), - Record { - addrs: vec![EndpointAddr::direct("192.168.5.78:41748".parse().unwrap())], - expire: Instant::now() + Duration::from_secs(60), - }, - ); - - let mut records = resolver.lookup("car.lab.dhttp.net").await.unwrap(); - let (source, endpoint) = records.next().await.unwrap(); - - assert_eq!( - source, - Source::H3 { - server: Arc::from(resolver.base_url.origin().ascii_serialization()) - } - ); - assert_eq!( - endpoint, - EndpointAddr::direct("192.168.5.78:41748".parse().unwrap()) - ); - } - - #[tokio::test] - async fn cached_dns_genmeta_net_record_is_returned() { - let endpoint = Arc::new(h3x::endpoint::H3Endpoint::new( - h3x::dquic::QuicEndpoint::builder().build().await, - )); - let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); - resolver.cached_records.insert( - "dns.genmeta.net".to_owned(), - Record { - addrs: vec![EndpointAddr::direct("192.0.2.53:4433".parse().unwrap())], - expire: Instant::now() + Duration::from_secs(60), - }, - ); - - let mut records = resolver.lookup("dns.genmeta.net").await.unwrap(); - let (_source, endpoint) = records.next().await.unwrap(); - - assert_eq!( - endpoint, - EndpointAddr::direct("192.0.2.53:4433".parse().unwrap()) - ); - } - - #[tokio::test] - async fn cached_lookup_uses_e_record_port_not_input_port() { - let endpoint = Arc::new(h3x::endpoint::H3Endpoint::new( - h3x::dquic::QuicEndpoint::builder().build().await, - )); - let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); - resolver.cached_records.insert( - "nat.genmeta.net".to_owned(), - Record { - addrs: vec![EndpointAddr::direct("192.0.2.10:21000".parse().unwrap())], - expire: Instant::now() + Duration::from_secs(60), - }, - ); - - let mut records = resolver.lookup("nat.genmeta.net:20004").await.unwrap(); - let (_source, endpoint) = records.next().await.unwrap(); - - assert_eq!( - endpoint, - EndpointAddr::direct("192.0.2.10:21000".parse().unwrap()) - ); - } -} diff --git a/src/resolvers/http.rs b/src/resolvers/http.rs deleted file mode 100644 index dae9498..0000000 --- a/src/resolvers/http.rs +++ /dev/null @@ -1,271 +0,0 @@ -use std::{fmt::Display, io, sync::Arc}; - -use dashmap::DashMap; -use dquic::{ - qbase::net::addr::EndpointAddr, - qresolve::{Publish, PublishFuture, Resolve, ResolveFuture, Source}, -}; -use futures::{StreamExt, TryFutureExt, stream}; -use reqwest::{Client, IntoUrl, StatusCode, Url}; -use tokio::time::Instant; - -use crate::core::{ - parser::packet::be_packet, - signature::{CONTENT_DIGEST_HEADER, SIGNATURE_HEADER, SIGNATURE_INPUT_HEADER, SignatureFields}, - wire::be_multi_response, -}; - -#[derive(Debug)] -struct Record { - addrs: Vec, - expire: Instant, -} - -#[derive(Debug)] -pub struct HttpResolver { - http_client: Client, - base_url: Url, - cached_records: DashMap, -} - -impl Display for HttpResolver { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "Http DNS({})", - self.base_url.host_str().expect("checked in constructor") - ) - } -} - -impl HttpResolver { - pub fn new(base_url: impl IntoUrl) -> io::Result { - let base_url = base_url - .into_url() - .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; - base_url.host_str().ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - "base URL must have a valid host", - ) - })?; - - Ok(Self { - http_client: build_http_client()?, - base_url, - cached_records: DashMap::new(), - }) - } - - pub async fn publish_signed( - &self, - name: &str, - packet: &[u8], - signature_fields: &SignatureFields, - ) -> io::Result<()> { - self.publish_packet_with_signature(name, packet, signature_fields) - .await - .map_err(io::Error::other) - } - - async fn publish_packet_with_signature( - &self, - name: &str, - packet: &[u8], - signature_fields: &SignatureFields, - ) -> Result<(), Error> { - let mut url = self.base_url.join("publish").expect("Invalid base URL"); - url.set_query(Some(&format!("host={name}"))); - let mut request = self - .http_client - .post(url) - .header("Content-Type", "application/octet-stream"); - if !signature_fields.is_empty() { - request = request - .header( - CONTENT_DIGEST_HEADER, - signature_fields.content_digest.as_slice(), - ) - .header( - SIGNATURE_INPUT_HEADER, - signature_fields.signature_input.as_slice(), - ) - .header(SIGNATURE_HEADER, signature_fields.signature.as_slice()); - } - request - .body(packet.to_vec()) - .send() - .await? - .error_for_status()?; - Ok(()) - } -} - -fn build_http_client() -> io::Result { - let native_certs = rustls_native_certs::load_native_certs(); - for error in &native_certs.errors { - let report = snafu::Report::from_error(error); - tracing::warn!(error = %report, "failed to load native root certificate"); - } - - let mut root_store = rustls::RootCertStore::empty(); - let (valid_roots, invalid_roots) = root_store.add_parsable_certificates(native_certs.certs); - if invalid_roots > 0 { - tracing::debug!(invalid_roots, "ignored invalid native root certificates"); - } - if valid_roots == 0 { - tracing::warn!("no native root certificates loaded for http resolver"); - } - - let mut tls = rustls::ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(); - tls.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - - Client::builder() - .use_preconfigured_tls(tls) - .build() - .map_err(io::Error::other) -} - -#[derive(Debug, snafu::Snafu)] -enum Error { - #[snafu(display("http request failed"))] - Reqwest { source: reqwest::Error }, - - #[snafu(display("{status}"))] - Status { status: StatusCode }, - - #[snafu(display("no DNS record found"))] - NoRecordFound, - - #[snafu(display("failed to parse DNS records from response"))] - ParseRecords { - source: nom::Err>>, - }, - - #[snafu(display("failed to decode multi-record response"))] - ParseMultiResponse, -} - -impl From for Error { - fn from(source: reqwest::Error) -> Self { - match source.status() { - Some(stateus) if stateus == StatusCode::NOT_FOUND => Error::NoRecordFound, - Some(status) => Error::Status { status }, - None => Error::Reqwest { - source: source.without_url(), - }, - } - } -} - -impl Publish for HttpResolver { - fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { - Box::pin(async move { - self.publish_packet_with_signature(name, packet, &SignatureFields::empty()) - .await - .map_err(io::Error::other) - }) - } -} - -impl Resolve for HttpResolver { - fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { - let lookup = async move { - let Some(domain) = super::resolvable_name(name) else { - return Err(Error::NoRecordFound); - }; - - let now = Instant::now(); - let server = Arc::from(self.base_url.host_str().unwrap_or("")); - let soource = Source::Http { server }; - - use crate::core::parser::record; - self.cached_records - .retain(|_host, Record { expire, .. }| *expire < now); - if let Some(record) = self.cached_records.get(domain) { - let endpoint_addrs: Vec<_> = record - .addrs - .iter() - .map(|endpoint: &EndpointAddr| (soource.clone(), *endpoint)) - .collect(); - return Ok(stream::iter(endpoint_addrs).boxed()); - } - let response = self - .http_client - .get(self.base_url.join("lookup").expect("Invalid URL")) - .query(&[("host", domain)]) - .send() - .await; - - let response = response?.error_for_status()?.bytes().await?; - let (remain, multi) = - be_multi_response(response.as_ref()).map_err(|_| Error::ParseMultiResponse)?; - if !remain.is_empty() { - return Err(Error::ParseMultiResponse); - } - - let mut addrs = Vec::new(); - for r in multi.records { - if !r.signature_fields.is_empty() { - match r.signature_fields.verify(&r.dns, &r.cert) { - Ok(true) => {} - Ok(false) => { - tracing::debug!("ignored record with invalid DNS packet signature"); - continue; - } - Err(error) => { - tracing::debug!(error = %snafu::Report::from_error(&error), "ignored record with malformed DNS packet signature"); - continue; - } - } - } - let (_remain, packet) = - be_packet(&r.dns).map_err(|source| Error::ParseRecords { - source: source.to_owned(), - })?; - - addrs.extend( - packet - .answers - .iter() - .filter_map(|answer| match answer.data() { - record::RData::E(ep) => { - if answer.name() != domain { - tracing::debug!( - answer_name = %answer.name(), - query = domain, - "ignored endpoint answer for different name" - ); - return None; - } - let endpoint = - TryInto::::try_into(ep.clone()).ok()?; - Some(endpoint) - } - _ => { - tracing::debug!(?answer, "ignored record"); - None - } - }), - ); - } - if addrs.is_empty() { - return Err(Error::NoRecordFound); - } - - // cache the addrs - self.cached_records.insert( - domain.to_string(), - Record { - addrs: addrs.clone(), - expire: now + std::time::Duration::from_secs(300), - }, - ); - - Ok(stream::iter(addrs.into_iter().map(move |ep| (soource.clone(), ep))).boxed()) - }; - Box::pin(lookup.map_err(io::Error::other)) - } -} From 965bc9c781100632d1e7866fd870d0c909f64785 Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 19:03:04 +0800 Subject: [PATCH 11/29] refactor: split h3 resolver internals --- src/h3.rs | 363 ++++------------------------------------------ src/h3/cache.rs | 1 + src/h3/lookup.rs | 134 +++++++++++++++++ src/h3/publish.rs | 99 +++++++++++++ src/h3/request.rs | 140 ++++++++++++++++++ src/resolvers.rs | 4 +- 6 files changed, 408 insertions(+), 333 deletions(-) create mode 100644 src/h3/cache.rs create mode 100644 src/h3/lookup.rs create mode 100644 src/h3/publish.rs create mode 100644 src/h3/request.rs diff --git a/src/h3.rs b/src/h3.rs index ea5c502..ad1f7e2 100644 --- a/src/h3.rs +++ b/src/h3.rs @@ -3,30 +3,28 @@ use std::{convert::Infallible, fmt, io, sync::Arc, time::Duration}; use dashmap::DashMap; use dquic::{ qbase::net::addr::EndpointAddr, - qresolve::{Publish, PublishFuture, RecordStream, Resolve, ResolveFuture, Source}, + qresolve::{Publish, PublishFuture, Resolve, ResolveFuture}, }; -use futures::{StreamExt, stream}; use h3x::{ dhttp::message::{MessageStreamError, hyper::client::RequestError as HyperRequestError}, endpoint::H3Endpoint, quic, }; -use http_body_util::{BodyExt, Empty, Full}; use tokio::time::Instant; -use tracing::trace; use url::Url; -use crate::core::{ - MdnsPacket, - parser::packet::be_packet, - signature::{CONTENT_DIGEST_HEADER, SIGNATURE_HEADER, SIGNATURE_INPUT_HEADER, SignatureFields}, - wire::be_multi_response, -}; +mod cache; +mod lookup; +mod publish; +mod request; const LOOKUP_REQUEST_TIMEOUT: Duration = Duration::from_secs(3); const LOOKUP_REQUEST_ATTEMPTS: usize = 3; -pub struct H3Resolver { +pub struct H3Resolver +where + C: quic::Connect + quic::WithLocalAuthority, +{ endpoint: Arc>, base_url: Url, cached_records: DashMap, @@ -34,12 +32,15 @@ pub struct H3Resolver { } #[derive(Debug)] -struct Record { - addrs: Vec, - expire: Instant, +pub(super) struct Record { + pub(super) addrs: Vec, + pub(super) expire: Instant, } -impl fmt::Debug for H3Resolver { +impl fmt::Debug for H3Resolver +where + C: quic::Connect + quic::WithLocalAuthority, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("H3Resolver") .field("base_url", &self.base_url) @@ -47,7 +48,10 @@ impl fmt::Debug for H3Resolver { } } -impl fmt::Display for H3Resolver { +impl fmt::Display for H3Resolver +where + C: quic::Connect + quic::WithLocalAuthority, +{ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "H3 DNS Resolver({})", self.base_url) } @@ -81,8 +85,9 @@ pub enum Error { ParseMultiResponse, } -impl H3Resolver +impl H3Resolver where + C: quic::Connect + quic::WithLocalAuthority + Send + Sync + 'static, C::Error: Send + Sync + 'static, C::Connection: Send + 'static, { @@ -114,325 +119,14 @@ where }) } - fn connect_error(&self, source: h3x::pool::ConnectError) -> Error { - // H3 DNS resolvers keep a long-lived endpoint. A network transition may - // leave the cached H3 connection with stale QUIC paths, so the next - // attempt must establish a fresh connection instead of reusing it. - self.endpoint.clear_pool(); - Error::Connect { source } - } - - fn request_error(&self, source: HyperRequestError) -> Error { - self.endpoint.clear_pool(); - Error::H3Request { source } - } - - async fn execute_request( - &self, - request: http::Request< - impl http_body::Body + Send + 'static, - >, - ) -> Result< - http::Response>, - Error, - > { - let authority = request - .uri() - .authority() - .expect("h3 dns request URL must include an authority") - .clone(); - tracing::trace!(%authority, "connecting h3 dns endpoint"); - let connection = match self.endpoint.connect(authority.clone()).await { - Ok(connection) => { - tracing::trace!(%authority, "connected h3 dns endpoint"); - connection - } - Err(source) => return Err(self.connect_error(source)), - }; - - let method = request.method().clone(); - let uri = request.uri().clone(); - tracing::trace!(%method, %uri, "executing h3 dns request"); - match connection.execute_hyper_request(request).await { - Ok(response) => { - tracing::trace!( - status = %response.status(), - "h3 dns request response received" - ); - Ok(response) - } - Err(source) => Err(self.request_error(source)), - } - } - pub fn clear_pool(&self) { self.endpoint.clear_pool(); } - - pub async fn publish_endpoints( - &self, - name: &str, - endpoints: &[EndpointAddr], - ) -> Result<(), Error> { - trace!("h3x publishing {} with {} endpoints", name, endpoints.len()); - let bytes = { - let endpoints = endpoints - .iter() - .filter_map(|ep| { - crate::core::parser::record::endpoint::EndpointAddr::try_from(*ep).ok() - }) - .collect(); - let mut hosts = std::collections::HashMap::new(); - hosts.insert(name.to_string(), endpoints); - MdnsPacket::answer(0, &hosts).to_bytes() - }; - - self.publish_packet(name, &bytes).await - } - - /// Publish a pre-built DNS packet (with signatures already included). - pub async fn publish_packet(&self, name: &str, packet: &[u8]) -> Result<(), Error> { - self.publish_packet_with_signature(name, packet, &SignatureFields::empty()) - .await - } - - pub async fn publish_signed( - &self, - name: &str, - packet: &[u8], - signature_fields: &SignatureFields, - ) -> io::Result<()> { - self.publish_packet_with_signature(name, packet, signature_fields) - .await - .map_err(io::Error::other) - } - - async fn publish_packet_with_signature( - &self, - name: &str, - packet: &[u8], - signature_fields: &SignatureFields, - ) -> Result<(), Error> { - let mut url = self.base_url.join("publish").expect("Invalid base URL"); - url.set_query(Some(&format!("host={name}"))); - let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); - tracing::trace!( - name, - packet_len = packet.len(), - url = %self.base_url, - "h3x publishing packet" - ); - let mut request = http::Request::post(uri); - if !signature_fields.is_empty() { - request = request - .header( - CONTENT_DIGEST_HEADER, - signature_fields.content_digest.as_slice(), - ) - .header( - SIGNATURE_INPUT_HEADER, - signature_fields.signature_input.as_slice(), - ) - .header(SIGNATURE_HEADER, signature_fields.signature.as_slice()); - } - let request = request - .body(Full::new(bytes::Bytes::copy_from_slice(packet))) - .expect("h3 dns publish request must be valid"); - let resp = self.execute_request(request).await?; - - if resp.status() != http::StatusCode::OK { - return Err(Error::Status { - status: resp.status(), - }); - } - - Ok(()) - } - - fn retryable_lookup_error(error: &Error) -> bool { - matches!( - error, - Error::Connect { .. } | Error::H3Request { .. } | Error::H3Stream { .. } - ) - } - - async fn lookup_response(&self, uri: http::Uri) -> Result> { - let request = http::Request::get(uri) - .body(Empty::::new()) - .expect("h3 dns lookup request must be valid"); - let resp = self.execute_request(request).await?; - - tracing::trace!("received response with status {}", resp.status()); - match resp.status() { - http::StatusCode::OK => {} - http::StatusCode::NOT_FOUND => return Err(Error::NoRecordFound), - status => return Err(Error::Status { status }), - } - - match resp.into_body().collect().await { - Ok(response) => Ok(response.to_bytes()), - Err(source) => Err(Error::H3Stream { source }), - } - } - - async fn lookup_response_with_retry( - &self, - uri: http::Uri, - ) -> Result> { - for attempt in 1..=LOOKUP_REQUEST_ATTEMPTS { - match tokio::time::timeout(LOOKUP_REQUEST_TIMEOUT, self.lookup_response(uri.clone())) - .await - { - Ok(Ok(response)) => return Ok(response), - Ok(Err(error)) - if Self::retryable_lookup_error(&error) - && attempt < LOOKUP_REQUEST_ATTEMPTS => - { - self.endpoint.clear_pool(); - tracing::debug!( - attempt, - timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), - "h3 dns lookup failed, retrying" - ); - } - Ok(Err(error)) => return Err(error), - Err(_elapsed) if attempt < LOOKUP_REQUEST_ATTEMPTS => { - self.endpoint.clear_pool(); - tracing::debug!( - attempt, - timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), - "h3 dns lookup timed out, retrying" - ); - } - Err(_elapsed) => { - self.endpoint.clear_pool(); - return Err(Error::RequestTimeout { - timeout: LOOKUP_REQUEST_TIMEOUT, - }); - } - } - } - - unreachable!("lookup retry loop returns on the final attempt") - } - - pub async fn lookup(&self, name: &str) -> Result> { - use crate::core::parser::record; - let server = Arc::from(self.base_url.origin().ascii_serialization()); - let source = Source::H3 { server }; - - let Some(domain) = crate::resolvers::resolvable_name(name) else { - return Err(Error::NoRecordFound); - }; - - let now = Instant::now(); - let positive_ttl = Duration::from_secs(10); - let negative_ttl = Duration::from_secs(2); - - self.cached_records - .retain(|_host, record| record.expire > now); - self.negative_cache.retain(|_host, expire| *expire > now); - - if self.negative_cache.get(domain).is_some() { - return Err(Error::NoRecordFound); - } - - if let Some(record) = self.cached_records.get(domain) { - let addrs = record.addrs.clone(); - let stream = stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))); - return Ok(stream.boxed()); - } - - let mut url = self.base_url.join("lookup").expect("Invalid URL"); - url.set_query(Some(&format!("host={}", domain))); - let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); - - tracing::trace!("sending lookup request to {}", self.base_url); - let response = match self.lookup_response_with_retry(uri).await { - Ok(response) => response, - Err(Error::NoRecordFound) => { - self.negative_cache - .insert(domain.to_string(), now + negative_ttl); - return Err(Error::NoRecordFound); - } - Err(error) => return Err(error), - }; - - // Server always returns multi-record format. - let (remain, multi) = - be_multi_response(response.as_ref()).map_err(|_| Error::ParseMultiResponse)?; - if !remain.is_empty() { - return Err(Error::ParseMultiResponse); - } - - let mut addrs = Vec::new(); - for r in multi.records { - if !r.signature_fields.is_empty() { - match r.signature_fields.verify(&r.dns, &r.cert) { - Ok(true) => {} - Ok(false) => { - tracing::debug!("ignored record with invalid DNS packet signature"); - continue; - } - Err(error) => { - tracing::debug!(error = %snafu::Report::from_error(&error), "ignored record with malformed DNS packet signature"); - continue; - } - } - } - - let (_remain, packet) = be_packet(&r.dns).map_err(|source| Error::ParseRecords { - source: source.to_owned(), - })?; - - addrs.extend( - packet - .answers - .iter() - .filter_map(|answer| match answer.data() { - record::RData::E(ep) => { - if answer.name() != domain { - tracing::debug!( - answer_name = %answer.name(), - query = domain, - "ignored endpoint answer for different name" - ); - return None; - } - let endpoint = TryInto::::try_into(ep.clone()).ok()?; - trace!(?endpoint, "parsed endpoint from record"); - Some(endpoint) - } - _ => { - tracing::debug!(?answer, "ignored record"); - None - } - }), - ); - } - - if addrs.is_empty() { - self.negative_cache - .insert(domain.to_string(), now + negative_ttl); - return Err(Error::NoRecordFound); - } - - self.cached_records.insert( - domain.to_string(), - Record { - addrs: addrs.clone(), - expire: now + positive_ttl, - }, - ); - - self.negative_cache.remove(domain); - - Ok(stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))).boxed()) - } } -impl Publish for H3Resolver +impl Publish for H3Resolver where + C: quic::Connect + quic::WithLocalAuthority + Send + Sync + 'static, C::Error: Send + Sync + 'static, C::Connection: Send + 'static, { @@ -446,8 +140,9 @@ where } } -impl Resolve for H3Resolver +impl Resolve for H3Resolver where + C: quic::Connect + quic::WithLocalAuthority + Send + Sync + 'static, C::Error: Send + Sync + 'static, C::Connection: Send + 'static, { @@ -463,6 +158,12 @@ where #[cfg(test)] mod tests { + use std::time::Duration; + + use dquic::{qbase::net::addr::EndpointAddr, qresolve::Source}; + use futures::StreamExt; + use tokio::time::Instant; + use super::*; #[cfg(feature = "dquic-network")] use crate::resolvers::DHTTP_H3_DNS_SERVER; diff --git a/src/h3/cache.rs b/src/h3/cache.rs new file mode 100644 index 0000000..78d813d --- /dev/null +++ b/src/h3/cache.rs @@ -0,0 +1 @@ +// Lookup cache ownership is introduced after the mechanical H3 module split. diff --git a/src/h3/lookup.rs b/src/h3/lookup.rs new file mode 100644 index 0000000..76fd7ad --- /dev/null +++ b/src/h3/lookup.rs @@ -0,0 +1,134 @@ +use std::{sync::Arc, time::Duration}; + +use dquic::{ + qbase::net::addr::EndpointAddr, + qresolve::{RecordStream, Source}, +}; +use futures::{StreamExt, stream}; +use h3x::quic; +use tokio::time::Instant; +use tracing::trace; + +use super::{Error, H3Resolver, Record}; +use crate::core::{parser::packet::be_packet, wire::be_multi_response}; + +impl H3Resolver +where + C: quic::Connect + quic::WithLocalAuthority + Send + Sync + 'static, + C::Error: Send + Sync + 'static, + C::Connection: Send + 'static, +{ + pub async fn lookup(&self, name: &str) -> Result> { + use crate::core::parser::record; + let server = Arc::from(self.base_url.origin().ascii_serialization()); + let source = Source::H3 { server }; + + let Some(domain) = crate::resolvers::resolvable_name(name) else { + return Err(Error::NoRecordFound); + }; + + let now = Instant::now(); + let positive_ttl = Duration::from_secs(10); + let negative_ttl = Duration::from_secs(2); + + self.cached_records + .retain(|_host, record| record.expire > now); + self.negative_cache.retain(|_host, expire| *expire > now); + + if self.negative_cache.get(domain).is_some() { + return Err(Error::NoRecordFound); + } + + if let Some(record) = self.cached_records.get(domain) { + let addrs = record.addrs.clone(); + let stream = stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))); + return Ok(stream.boxed()); + } + + let mut url = self.base_url.join("lookup").expect("Invalid URL"); + url.set_query(Some(&format!("host={}", domain))); + let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); + + tracing::trace!("sending lookup request to {}", self.base_url); + let response = match self.lookup_response_with_retry(uri).await { + Ok(response) => response, + Err(Error::NoRecordFound) => { + self.negative_cache + .insert(domain.to_string(), now + negative_ttl); + return Err(Error::NoRecordFound); + } + Err(error) => return Err(error), + }; + + // Server always returns multi-record format. + let (remain, multi) = + be_multi_response(response.as_ref()).map_err(|_| Error::ParseMultiResponse)?; + if !remain.is_empty() { + return Err(Error::ParseMultiResponse); + } + + let mut addrs = Vec::new(); + for r in multi.records { + if !r.signature_fields.is_empty() { + match r.signature_fields.verify(&r.dns, &r.cert) { + Ok(true) => {} + Ok(false) => { + tracing::debug!("ignored record with invalid DNS packet signature"); + continue; + } + Err(error) => { + tracing::debug!(error = %snafu::Report::from_error(&error), "ignored record with malformed DNS packet signature"); + continue; + } + } + } + + let (_remain, packet) = be_packet(&r.dns).map_err(|source| Error::ParseRecords { + source: source.to_owned(), + })?; + + addrs.extend( + packet + .answers + .iter() + .filter_map(|answer| match answer.data() { + record::RData::E(ep) => { + if answer.name() != domain { + tracing::debug!( + answer_name = %answer.name(), + query = domain, + "ignored endpoint answer for different name" + ); + return None; + } + let endpoint = TryInto::::try_into(ep.clone()).ok()?; + trace!(?endpoint, "parsed endpoint from record"); + Some(endpoint) + } + _ => { + tracing::debug!(?answer, "ignored record"); + None + } + }), + ); + } + + if addrs.is_empty() { + self.negative_cache + .insert(domain.to_string(), now + negative_ttl); + return Err(Error::NoRecordFound); + } + + self.cached_records.insert( + domain.to_string(), + Record { + addrs: addrs.clone(), + expire: now + positive_ttl, + }, + ); + + self.negative_cache.remove(domain); + + Ok(stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))).boxed()) + } +} diff --git a/src/h3/publish.rs b/src/h3/publish.rs new file mode 100644 index 0000000..bf04d9d --- /dev/null +++ b/src/h3/publish.rs @@ -0,0 +1,99 @@ +use std::collections::HashMap; + +use dquic::qbase::net::addr::EndpointAddr; +use h3x::quic; +use http_body_util::Full; +use tracing::trace; + +use super::{Error, H3Resolver}; +use crate::core::{ + MdnsPacket, + signature::{CONTENT_DIGEST_HEADER, SIGNATURE_HEADER, SIGNATURE_INPUT_HEADER, SignatureFields}, +}; + +impl H3Resolver +where + C: quic::Connect + quic::WithLocalAuthority + Send + Sync + 'static, + C::Error: Send + Sync + 'static, + C::Connection: Send + 'static, +{ + pub async fn publish_endpoints( + &self, + name: &str, + endpoints: &[EndpointAddr], + ) -> Result<(), Error> { + trace!("h3x publishing {} with {} endpoints", name, endpoints.len()); + let bytes = { + let endpoints = endpoints + .iter() + .filter_map(|ep| { + crate::core::parser::record::endpoint::EndpointAddr::try_from(*ep).ok() + }) + .collect(); + let mut hosts = HashMap::new(); + hosts.insert(name.to_string(), endpoints); + MdnsPacket::answer(0, &hosts).to_bytes() + }; + + self.publish_packet(name, &bytes).await + } + + /// Publish a pre-built DNS packet (with signatures already included). + pub async fn publish_packet(&self, name: &str, packet: &[u8]) -> Result<(), Error> { + self.publish_packet_with_signature(name, packet, &SignatureFields::empty()) + .await + } + + pub async fn publish_signed( + &self, + name: &str, + packet: &[u8], + signature_fields: &SignatureFields, + ) -> std::io::Result<()> { + self.publish_packet_with_signature(name, packet, signature_fields) + .await + .map_err(std::io::Error::other) + } + + async fn publish_packet_with_signature( + &self, + name: &str, + packet: &[u8], + signature_fields: &SignatureFields, + ) -> Result<(), Error> { + let mut url = self.base_url.join("publish").expect("Invalid base URL"); + url.set_query(Some(&format!("host={name}"))); + let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); + tracing::trace!( + name, + packet_len = packet.len(), + url = %self.base_url, + "h3x publishing packet" + ); + let mut request = http::Request::post(uri); + if !signature_fields.is_empty() { + request = request + .header( + CONTENT_DIGEST_HEADER, + signature_fields.content_digest.as_slice(), + ) + .header( + SIGNATURE_INPUT_HEADER, + signature_fields.signature_input.as_slice(), + ) + .header(SIGNATURE_HEADER, signature_fields.signature.as_slice()); + } + let request = request + .body(Full::new(bytes::Bytes::copy_from_slice(packet))) + .expect("h3 dns publish request must be valid"); + let resp = self.execute_request(request).await?; + + if resp.status() != http::StatusCode::OK { + return Err(Error::Status { + status: resp.status(), + }); + } + + Ok(()) + } +} diff --git a/src/h3/request.rs b/src/h3/request.rs new file mode 100644 index 0000000..77116d5 --- /dev/null +++ b/src/h3/request.rs @@ -0,0 +1,140 @@ +use std::convert::Infallible; + +use h3x::{ + dhttp::message::{MessageStreamError, hyper::client::RequestError as HyperRequestError}, + quic, +}; +use http_body_util::BodyExt; + +use super::{Error, H3Resolver, LOOKUP_REQUEST_ATTEMPTS, LOOKUP_REQUEST_TIMEOUT}; + +impl H3Resolver +where + C: quic::Connect + quic::WithLocalAuthority + Send + Sync + 'static, + C::Error: Send + Sync + 'static, + C::Connection: Send + 'static, +{ + pub(super) fn connect_error( + &self, + source: h3x::pool::ConnectError, + ) -> Error { + // H3 DNS resolvers keep a long-lived endpoint. A network transition may + // leave the cached H3 connection with stale QUIC paths, so the next + // attempt must establish a fresh connection instead of reusing it. + self.endpoint.clear_pool(); + Error::Connect { source } + } + + pub(super) fn request_error(&self, source: HyperRequestError) -> Error { + self.endpoint.clear_pool(); + Error::H3Request { source } + } + + pub(super) async fn execute_request( + &self, + request: http::Request< + impl http_body::Body + Send + 'static, + >, + ) -> Result< + http::Response>, + Error, + > { + let authority = request + .uri() + .authority() + .expect("h3 dns request URL must include an authority") + .clone(); + tracing::trace!(%authority, "connecting h3 dns endpoint"); + let connection = match self.endpoint.connect(authority.clone()).await { + Ok(connection) => { + tracing::trace!(%authority, "connected h3 dns endpoint"); + connection + } + Err(source) => return Err(self.connect_error(source)), + }; + + let method = request.method().clone(); + let uri = request.uri().clone(); + tracing::trace!(%method, %uri, "executing h3 dns request"); + match connection.execute_hyper_request(request).await { + Ok(response) => { + tracing::trace!( + status = %response.status(), + "h3 dns request response received" + ); + Ok(response) + } + Err(source) => Err(self.request_error(source)), + } + } + + pub(super) fn retryable_lookup_error(error: &Error) -> bool { + matches!( + error, + Error::Connect { .. } | Error::H3Request { .. } | Error::H3Stream { .. } + ) + } + + pub(super) async fn lookup_response( + &self, + uri: http::Uri, + ) -> Result> { + let request = http::Request::get(uri) + .body(http_body_util::Empty::::new()) + .expect("h3 dns lookup request must be valid"); + let resp = self.execute_request(request).await?; + + tracing::trace!("received response with status {}", resp.status()); + match resp.status() { + http::StatusCode::OK => {} + http::StatusCode::NOT_FOUND => return Err(Error::NoRecordFound), + status => return Err(Error::Status { status }), + } + + match resp.into_body().collect().await { + Ok(response) => Ok(response.to_bytes()), + Err(source) => Err(Error::H3Stream { source }), + } + } + + pub(super) async fn lookup_response_with_retry( + &self, + uri: http::Uri, + ) -> Result> { + for attempt in 1..=LOOKUP_REQUEST_ATTEMPTS { + match tokio::time::timeout(LOOKUP_REQUEST_TIMEOUT, self.lookup_response(uri.clone())) + .await + { + Ok(Ok(response)) => return Ok(response), + Ok(Err(error)) + if Self::retryable_lookup_error(&error) + && attempt < LOOKUP_REQUEST_ATTEMPTS => + { + self.endpoint.clear_pool(); + tracing::debug!( + attempt, + timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), + "h3 dns lookup failed, retrying" + ); + } + Ok(Err(error)) => return Err(error), + Err(_elapsed) if attempt < LOOKUP_REQUEST_ATTEMPTS => { + self.endpoint.clear_pool(); + tracing::debug!( + attempt, + timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), + "h3 dns lookup timed out, retrying" + ); + } + Err(_elapsed) => { + self.endpoint.clear_pool(); + return Err(Error::RequestTimeout { + timeout: LOOKUP_REQUEST_TIMEOUT, + }); + } + } + } + + unreachable!("lookup retry loop returns on the final attempt") + } +} diff --git a/src/resolvers.rs b/src/resolvers.rs index cad1652..f7abb4d 100644 --- a/src/resolvers.rs +++ b/src/resolvers.rs @@ -205,7 +205,7 @@ impl ResolversBuilder { endpoint: Arc>, ) -> io::Result where - C: h3x::quic::Connect + Send + Sync + 'static, + C: h3x::quic::Connect + h3x::quic::WithLocalAuthority + Send + Sync + 'static, C::Error: Send + Sync + 'static, C::Connection: Send + 'static, { @@ -219,7 +219,7 @@ impl ResolversBuilder { endpoint: Arc>, ) -> io::Result where - C: h3x::quic::Connect + Send + Sync + 'static, + C: h3x::quic::Connect + h3x::quic::WithLocalAuthority + Send + Sync + 'static, C::Error: Send + Sync + 'static, C::Connection: Send + 'static, { From 877db8cfd2ba436a83ff345b34dde9cf6e8fc063 Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 19:09:45 +0800 Subject: [PATCH 12/29] refactor: split h3 resolver errors by operation --- src/h3.rs | 60 +++++++++++++++++------- src/h3/lookup.rs | 115 ++++++++++++++++++++++++++++++++++++++++------ src/h3/publish.rs | 14 ++++-- src/h3/request.rs | 87 ++++------------------------------- 4 files changed, 164 insertions(+), 112 deletions(-) diff --git a/src/h3.rs b/src/h3.rs index ad1f7e2..8499998 100644 --- a/src/h3.rs +++ b/src/h3.rs @@ -58,31 +58,59 @@ where } #[derive(Debug, snafu::Snafu)] -pub enum Error { - #[snafu(display("h3 stream error"))] - H3Stream { source: MessageStreamError }, +#[snafu(module)] +pub enum H3RequestError { #[snafu(display("failed to connect h3 endpoint"))] Connect { source: h3x::pool::ConnectError }, #[snafu(display("h3 request error"))] - H3Request { + Request { source: HyperRequestError, }, - #[snafu(display("h3 request timed out after {timeout:?}"))] - RequestTimeout { timeout: Duration }, +} +#[derive(Debug, snafu::Snafu)] +#[snafu(module)] +pub enum H3PublishError { + #[snafu(transparent)] + Request { source: H3RequestError }, + #[snafu(display("anonymous h3 endpoint cannot sign dns publish request"))] + AnonymousEndpoint, + #[snafu(display("failed to get h3 endpoint local authority"))] + LocalAuthority { source: h3x::quic::ConnectionError }, + #[snafu(display("failed to sign h3 dns publish request"))] + SignRequest { + source: crate::core::signature::SignatureFieldsError, + }, #[snafu(display("{status}"))] Status { status: http::StatusCode }, +} +#[derive(Debug, snafu::Snafu)] +#[snafu(module)] +pub enum H3LookupError { + #[snafu(transparent)] + Request { source: H3RequestError }, + #[snafu(display("h3 stream error"))] + H3Stream { source: MessageStreamError }, + #[snafu(display("h3 request timed out after {timeout:?}"))] + RequestTimeout { timeout: Duration }, + #[snafu(display("{status}"))] + Status { status: http::StatusCode }, #[snafu(display("no DNS record found"))] NoRecordFound, + #[snafu(display("failed to decode h3 dns lookup response"))] + Decode { source: LookupDecodeError }, +} +#[derive(Debug, snafu::Snafu)] +#[snafu(module)] +pub enum LookupDecodeError { + #[snafu(display("failed to decode multi-record response"))] + MultiResponse, #[snafu(display("failed to parse DNS records from response"))] ParseRecords { source: nom::Err>>, }, - - #[snafu(display("failed to decode multi-record response"))] - ParseMultiResponse, } impl H3Resolver @@ -132,10 +160,9 @@ where { fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { Box::pin(async move { - match self.publish_packet(name, packet).await { - Ok(()) => Ok(()), - Err(error) => Err(io::Error::other(error)), - } + self.publish_packet(name, packet) + .await + .map_err(io::Error::other) }) } } @@ -148,10 +175,9 @@ where { fn lookup<'l>(&'l self, name: &'l str) -> ResolveFuture<'l> { Box::pin(async move { - match H3Resolver::lookup(self, name).await { - Ok(stream) => Ok(stream), - Err(error) => Err(io::Error::other(error)), - } + H3Resolver::lookup(self, name) + .await + .map_err(io::Error::other) }) } } diff --git a/src/h3/lookup.rs b/src/h3/lookup.rs index 76fd7ad..027bab9 100644 --- a/src/h3/lookup.rs +++ b/src/h3/lookup.rs @@ -6,10 +6,14 @@ use dquic::{ }; use futures::{StreamExt, stream}; use h3x::quic; +use http_body_util::BodyExt; use tokio::time::Instant; use tracing::trace; -use super::{Error, H3Resolver, Record}; +use super::{ + H3LookupError, H3Resolver, LOOKUP_REQUEST_ATTEMPTS, LOOKUP_REQUEST_TIMEOUT, LookupDecodeError, + Record, +}; use crate::core::{parser::packet::be_packet, wire::be_multi_response}; impl H3Resolver @@ -18,13 +22,83 @@ where C::Error: Send + Sync + 'static, C::Connection: Send + 'static, { - pub async fn lookup(&self, name: &str) -> Result> { + pub(super) fn retryable_lookup_error(error: &H3LookupError) -> bool { + matches!( + error, + H3LookupError::Request { .. } | H3LookupError::H3Stream { .. } + ) + } + + pub(super) async fn lookup_response( + &self, + uri: http::Uri, + ) -> Result> { + let request = http::Request::get(uri) + .body(http_body_util::Empty::::new()) + .expect("h3 dns lookup request must be valid"); + let resp = self.execute_request(request).await?; + + tracing::trace!("received response with status {}", resp.status()); + match resp.status() { + http::StatusCode::OK => {} + http::StatusCode::NOT_FOUND => return Err(H3LookupError::NoRecordFound), + status => return Err(H3LookupError::Status { status }), + } + + match resp.into_body().collect().await { + Ok(response) => Ok(response.to_bytes()), + Err(source) => Err(H3LookupError::H3Stream { source }), + } + } + + pub(super) async fn lookup_response_with_retry( + &self, + uri: http::Uri, + ) -> Result> { + for attempt in 1..=LOOKUP_REQUEST_ATTEMPTS { + match tokio::time::timeout(LOOKUP_REQUEST_TIMEOUT, self.lookup_response(uri.clone())) + .await + { + Ok(Ok(response)) => return Ok(response), + Ok(Err(error)) + if Self::retryable_lookup_error(&error) + && attempt < LOOKUP_REQUEST_ATTEMPTS => + { + self.endpoint.clear_pool(); + tracing::debug!( + attempt, + timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), + "h3 dns lookup failed, retrying" + ); + } + Ok(Err(error)) => return Err(error), + Err(_elapsed) if attempt < LOOKUP_REQUEST_ATTEMPTS => { + self.endpoint.clear_pool(); + tracing::debug!( + attempt, + timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), + "h3 dns lookup timed out, retrying" + ); + } + Err(_elapsed) => { + self.endpoint.clear_pool(); + return Err(H3LookupError::RequestTimeout { + timeout: LOOKUP_REQUEST_TIMEOUT, + }); + } + } + } + + unreachable!("lookup retry loop returns on the final attempt") + } + + pub async fn lookup(&self, name: &str) -> Result> { use crate::core::parser::record; let server = Arc::from(self.base_url.origin().ascii_serialization()); let source = Source::H3 { server }; let Some(domain) = crate::resolvers::resolvable_name(name) else { - return Err(Error::NoRecordFound); + return Err(H3LookupError::NoRecordFound); }; let now = Instant::now(); @@ -36,7 +110,7 @@ where self.negative_cache.retain(|_host, expire| *expire > now); if self.negative_cache.get(domain).is_some() { - return Err(Error::NoRecordFound); + return Err(H3LookupError::NoRecordFound); } if let Some(record) = self.cached_records.get(domain) { @@ -52,19 +126,27 @@ where tracing::trace!("sending lookup request to {}", self.base_url); let response = match self.lookup_response_with_retry(uri).await { Ok(response) => response, - Err(Error::NoRecordFound) => { + Err(H3LookupError::NoRecordFound) => { self.negative_cache .insert(domain.to_string(), now + negative_ttl); - return Err(Error::NoRecordFound); + return Err(H3LookupError::NoRecordFound); } Err(error) => return Err(error), }; // Server always returns multi-record format. - let (remain, multi) = - be_multi_response(response.as_ref()).map_err(|_| Error::ParseMultiResponse)?; + let (remain, multi) = match be_multi_response(response.as_ref()) { + Ok(response) => response, + Err(_error) => { + return Err(H3LookupError::Decode { + source: LookupDecodeError::MultiResponse, + }); + } + }; if !remain.is_empty() { - return Err(Error::ParseMultiResponse); + return Err(H3LookupError::Decode { + source: LookupDecodeError::MultiResponse, + }); } let mut addrs = Vec::new(); @@ -83,9 +165,16 @@ where } } - let (_remain, packet) = be_packet(&r.dns).map_err(|source| Error::ParseRecords { - source: source.to_owned(), - })?; + let (_remain, packet) = match be_packet(&r.dns) { + Ok(packet) => packet, + Err(source) => { + return Err(H3LookupError::Decode { + source: LookupDecodeError::ParseRecords { + source: source.to_owned(), + }, + }); + } + }; addrs.extend( packet @@ -116,7 +205,7 @@ where if addrs.is_empty() { self.negative_cache .insert(domain.to_string(), now + negative_ttl); - return Err(Error::NoRecordFound); + return Err(H3LookupError::NoRecordFound); } self.cached_records.insert( diff --git a/src/h3/publish.rs b/src/h3/publish.rs index bf04d9d..ba1de6c 100644 --- a/src/h3/publish.rs +++ b/src/h3/publish.rs @@ -5,7 +5,7 @@ use h3x::quic; use http_body_util::Full; use tracing::trace; -use super::{Error, H3Resolver}; +use super::{H3PublishError, H3Resolver}; use crate::core::{ MdnsPacket, signature::{CONTENT_DIGEST_HEADER, SIGNATURE_HEADER, SIGNATURE_INPUT_HEADER, SignatureFields}, @@ -21,7 +21,7 @@ where &self, name: &str, endpoints: &[EndpointAddr], - ) -> Result<(), Error> { + ) -> Result<(), H3PublishError> { trace!("h3x publishing {} with {} endpoints", name, endpoints.len()); let bytes = { let endpoints = endpoints @@ -39,7 +39,11 @@ where } /// Publish a pre-built DNS packet (with signatures already included). - pub async fn publish_packet(&self, name: &str, packet: &[u8]) -> Result<(), Error> { + pub async fn publish_packet( + &self, + name: &str, + packet: &[u8], + ) -> Result<(), H3PublishError> { self.publish_packet_with_signature(name, packet, &SignatureFields::empty()) .await } @@ -60,7 +64,7 @@ where name: &str, packet: &[u8], signature_fields: &SignatureFields, - ) -> Result<(), Error> { + ) -> Result<(), H3PublishError> { let mut url = self.base_url.join("publish").expect("Invalid base URL"); url.set_query(Some(&format!("host={name}"))); let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); @@ -89,7 +93,7 @@ where let resp = self.execute_request(request).await?; if resp.status() != http::StatusCode::OK { - return Err(Error::Status { + return Err(H3PublishError::Status { status: resp.status(), }); } diff --git a/src/h3/request.rs b/src/h3/request.rs index 77116d5..8b82d13 100644 --- a/src/h3/request.rs +++ b/src/h3/request.rs @@ -4,9 +4,9 @@ use h3x::{ dhttp::message::{MessageStreamError, hyper::client::RequestError as HyperRequestError}, quic, }; -use http_body_util::BodyExt; +use snafu::IntoError; -use super::{Error, H3Resolver, LOOKUP_REQUEST_ATTEMPTS, LOOKUP_REQUEST_TIMEOUT}; +use super::{H3RequestError, H3Resolver, h3_request_error}; impl H3Resolver where @@ -17,17 +17,20 @@ where pub(super) fn connect_error( &self, source: h3x::pool::ConnectError, - ) -> Error { + ) -> H3RequestError { // H3 DNS resolvers keep a long-lived endpoint. A network transition may // leave the cached H3 connection with stale QUIC paths, so the next // attempt must establish a fresh connection instead of reusing it. self.endpoint.clear_pool(); - Error::Connect { source } + h3_request_error::ConnectSnafu.into_error(source) } - pub(super) fn request_error(&self, source: HyperRequestError) -> Error { + pub(super) fn request_error( + &self, + source: HyperRequestError, + ) -> H3RequestError { self.endpoint.clear_pool(); - Error::H3Request { source } + h3_request_error::RequestSnafu.into_error(source) } pub(super) async fn execute_request( @@ -37,7 +40,7 @@ where >, ) -> Result< http::Response>, - Error, + H3RequestError, > { let authority = request .uri() @@ -67,74 +70,4 @@ where Err(source) => Err(self.request_error(source)), } } - - pub(super) fn retryable_lookup_error(error: &Error) -> bool { - matches!( - error, - Error::Connect { .. } | Error::H3Request { .. } | Error::H3Stream { .. } - ) - } - - pub(super) async fn lookup_response( - &self, - uri: http::Uri, - ) -> Result> { - let request = http::Request::get(uri) - .body(http_body_util::Empty::::new()) - .expect("h3 dns lookup request must be valid"); - let resp = self.execute_request(request).await?; - - tracing::trace!("received response with status {}", resp.status()); - match resp.status() { - http::StatusCode::OK => {} - http::StatusCode::NOT_FOUND => return Err(Error::NoRecordFound), - status => return Err(Error::Status { status }), - } - - match resp.into_body().collect().await { - Ok(response) => Ok(response.to_bytes()), - Err(source) => Err(Error::H3Stream { source }), - } - } - - pub(super) async fn lookup_response_with_retry( - &self, - uri: http::Uri, - ) -> Result> { - for attempt in 1..=LOOKUP_REQUEST_ATTEMPTS { - match tokio::time::timeout(LOOKUP_REQUEST_TIMEOUT, self.lookup_response(uri.clone())) - .await - { - Ok(Ok(response)) => return Ok(response), - Ok(Err(error)) - if Self::retryable_lookup_error(&error) - && attempt < LOOKUP_REQUEST_ATTEMPTS => - { - self.endpoint.clear_pool(); - tracing::debug!( - attempt, - timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), - "h3 dns lookup failed, retrying" - ); - } - Ok(Err(error)) => return Err(error), - Err(_elapsed) if attempt < LOOKUP_REQUEST_ATTEMPTS => { - self.endpoint.clear_pool(); - tracing::debug!( - attempt, - timeout_ms = LOOKUP_REQUEST_TIMEOUT.as_millis(), - "h3 dns lookup timed out, retrying" - ); - } - Err(_elapsed) => { - self.endpoint.clear_pool(); - return Err(Error::RequestTimeout { - timeout: LOOKUP_REQUEST_TIMEOUT, - }); - } - } - } - - unreachable!("lookup retry loop returns on the final attempt") - } } From 5ad32b27479444fc036590fa9b985605e3d7dd02 Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 19:18:10 +0800 Subject: [PATCH 13/29] refactor: sign h3 publish requests from endpoint authority --- examples/README.md | 4 +- examples/publish.rs | 34 +------ src/h3/publish.rs | 234 ++++++++++++++++++++++++++++++++++++-------- 3 files changed, 199 insertions(+), 73 deletions(-) diff --git a/examples/README.md b/examples/README.md index 2bba4b6..9130094 100644 --- a/examples/README.md +++ b/examples/README.md @@ -8,7 +8,7 @@ whose library target remains `ddns`. | `mdns_discover` | `mdns` | Bind an mDNS service, publish sample local hosts, and print multicast packets. | | `mdns_query` | `mdns` | Query a DHTTP name over local mDNS. | | `query` | none | Query a DNS-over-H3 server and decode the multi-record response. | -| `publish` | `h3` | Publish signed endpoint `E` records to a DNS-over-H3 server using client mTLS. | +| `publish` | `h3` | Publish endpoint `E` records to a DNS-over-H3 server using client mTLS; H3 publish request headers are signed from the client endpoint identity. | Run all commands from the `ddns/` repository. @@ -82,9 +82,9 @@ Options: | `--client-name ` | DHTTP identity name presented by the client endpoint. | | `--client-cert ` | Client certificate chain PEM for mTLS and endpoint signature verification. | | `--client-key ` | Client private key PEM. | -| `--sign ` | Whether to sign each endpoint `E` record. Defaults to `true`. | | `--host ` | DNS host to publish. | | `--addr ` | One or more socket addresses to publish. | The example imports `H3Publisher` from the `ddns::publishers` facade, but only needs the `h3` backend feature because backend publisher types are re-exported from the facade directly. +H3 publish request headers are always signed with the configured client endpoint identity; callers no longer pass request signature fields. diff --git a/examples/publish.rs b/examples/publish.rs index 4aa43b4..2af5669 100644 --- a/examples/publish.rs +++ b/examples/publish.rs @@ -7,8 +7,7 @@ use std::{ use clap::Parser; use ddns::{ - core::{parser::record::endpoint::EndpointAddr, signature::SignatureFields}, - publishers::H3Publisher, + core::parser::record::endpoint::EndpointAddr, publishers::H3Publisher, resolvers::DHTTP_H3_DNS_SERVER, }; use h3x::dquic::{ @@ -43,13 +42,6 @@ struct Options { #[arg(long)] client_key: PathBuf, - /// Sign DNS packets using HTTP signature fields and the client private key. - /// - /// This must correspond to the client certificate presented in mTLS, because the server - /// verifies the signature with the peer certificate's SPKI. - #[arg(long, default_value_t = true, action = clap::ArgAction::Set)] - sign: bool, - /// 要发布的线上域名,必须与客户端证书 SAN 匹配。 #[arg(long)] host: String, @@ -142,12 +134,6 @@ async fn main() -> io::Result<()> { let resolver = H3Publisher::new(opt.base_url.clone(), h3_endpoint)?; info!(host = %opt.host, addrs = ?opt.addr, base_url = %opt.base_url, "publish.start"); - if opt.sign { - info!("publish.packet_signing.enabled"); - } else { - info!("publish.packet_signing.disabled"); - } - for &addr in &opt.addr { info!("Creating endpoint for address: {}", addr); let mut endpoint = match addr { @@ -160,20 +146,10 @@ async fn main() -> io::Result<()> { let mut hosts = std::collections::HashMap::new(); hosts.insert(opt.host.clone(), vec![endpoint]); let packet = ddns::core::MdnsPacket::answer(0, &hosts).to_bytes(); - if opt.sign { - info!("signing dns packet"); - let signature_fields = SignatureFields::sign(&packet, identity.as_ref()) - .await - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - resolver - .publish_signed(&opt.host, &packet, &signature_fields) - .await?; - } else { - resolver - .publish(&opt.host, &packet) - .await - .map_err(io::Error::other)?; - } + resolver + .publish(&opt.host, &packet) + .await + .map_err(io::Error::other)?; info!("Successfully published endpoint for {}", addr); } info!("publish.ok"); diff --git a/src/h3/publish.rs b/src/h3/publish.rs index ba1de6c..8b3ca82 100644 --- a/src/h3/publish.rs +++ b/src/h3/publish.rs @@ -1,16 +1,46 @@ use std::collections::HashMap; +use dhttp_identity::identity::LocalAuthority; use dquic::qbase::net::addr::EndpointAddr; use h3x::quic; use http_body_util::Full; +use snafu::{OptionExt, ResultExt}; use tracing::trace; -use super::{H3PublishError, H3Resolver}; +use super::{H3PublishError, H3Resolver, h3_publish_error}; use crate::core::{ MdnsPacket, signature::{CONTENT_DIGEST_HEADER, SIGNATURE_HEADER, SIGNATURE_INPUT_HEADER, SignatureFields}, }; +async fn signed_publish_request( + base_url: &url::Url, + name: &str, + packet: &[u8], + authority: &A, +) -> Result>, crate::core::signature::SignatureFieldsError> { + let mut url = base_url.join("publish").expect("h3 dns base URL is valid"); + url.set_query(Some(&format!("host={name}"))); + let uri: http::Uri = url + .as_str() + .parse() + .expect("h3 dns publish URL is a valid URI"); + let signature_fields = SignatureFields::sign(packet, authority).await?; + + Ok(http::Request::post(uri) + .header( + CONTENT_DIGEST_HEADER, + signature_fields.content_digest.as_slice(), + ) + .header( + SIGNATURE_INPUT_HEADER, + signature_fields.signature_input.as_slice(), + ) + .header(SIGNATURE_HEADER, signature_fields.signature.as_slice()) + .body(Full::new(bytes::Bytes::copy_from_slice(packet))) + .expect("h3 dns publish request must be valid")) +} + impl H3Resolver where C: quic::Connect + quic::WithLocalAuthority + Send + Sync + 'static, @@ -44,52 +74,22 @@ where name: &str, packet: &[u8], ) -> Result<(), H3PublishError> { - self.publish_packet_with_signature(name, packet, &SignatureFields::empty()) - .await - } - - pub async fn publish_signed( - &self, - name: &str, - packet: &[u8], - signature_fields: &SignatureFields, - ) -> std::io::Result<()> { - self.publish_packet_with_signature(name, packet, signature_fields) - .await - .map_err(std::io::Error::other) - } - - async fn publish_packet_with_signature( - &self, - name: &str, - packet: &[u8], - signature_fields: &SignatureFields, - ) -> Result<(), H3PublishError> { - let mut url = self.base_url.join("publish").expect("Invalid base URL"); - url.set_query(Some(&format!("host={name}"))); - let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); tracing::trace!( name, packet_len = packet.len(), url = %self.base_url, - "h3x publishing packet" + "h3 dns publishing packet" ); - let mut request = http::Request::post(uri); - if !signature_fields.is_empty() { - request = request - .header( - CONTENT_DIGEST_HEADER, - signature_fields.content_digest.as_slice(), - ) - .header( - SIGNATURE_INPUT_HEADER, - signature_fields.signature_input.as_slice(), - ) - .header(SIGNATURE_HEADER, signature_fields.signature.as_slice()); - } - let request = request - .body(Full::new(bytes::Bytes::copy_from_slice(packet))) - .expect("h3 dns publish request must be valid"); + let authority = self + .endpoint + .quic() + .local_authority() + .await + .context(h3_publish_error::LocalAuthoritySnafu)? + .context(h3_publish_error::AnonymousEndpointSnafu)?; + let request = signed_publish_request(&self.base_url, name, packet, &authority) + .await + .context(h3_publish_error::SignRequestSnafu)?; let resp = self.execute_request(request).await?; if resp.status() != http::StatusCode::OK { @@ -101,3 +101,153 @@ where Ok(()) } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use dquic::qresolve::Publish as _; + use futures::future::BoxFuture; + use h3x::endpoint::H3Endpoint; + use ring::signature::KeyPair as _; + use rustls::{ + SignatureAlgorithm, SignatureScheme, + pki_types::CertificateDer, + sign::{Signer, SigningKey}, + }; + + use super::*; + + #[cfg(feature = "dquic-network")] + #[tokio::test] + async fn publish_rejects_anonymous_endpoint_before_request() { + let endpoint = Arc::new(H3Endpoint::new( + h3x::dquic::QuicEndpoint::builder().build().await, + )); + let resolver = H3Resolver::from_endpoint("https://dns.example.test:4433", endpoint) + .expect("valid h3 resolver"); + + let error = resolver + .publish_packet("demo.dhttp.net", b"dns-packet") + .await + .expect_err("anonymous endpoint should not publish"); + + assert_eq!( + error.to_string(), + "anonymous h3 endpoint cannot sign dns publish request" + ); + + let trait_error = resolver + .publish("demo.dhttp.net", b"dns-packet") + .await + .expect_err("trait publish should surface anonymous endpoint"); + assert!( + trait_error + .to_string() + .contains("anonymous h3 endpoint cannot sign dns publish request") + ); + } + + #[derive(Debug)] + struct TestAuthority { + keypair: Arc, + cert_chain: Vec>, + } + + impl dhttp_identity::identity::LocalAuthority for TestAuthority { + fn name(&self) -> &str { + "authority.example" + } + + fn cert_chain(&self) -> &[CertificateDer<'static>] { + &self.cert_chain + } + + fn sign( + &self, + data: &[u8], + ) -> BoxFuture<'_, Result, dhttp_identity::identity::SignError>> { + let result = dhttp_identity::identity::sign_with_key( + &TestSigningKey(self.keypair.clone()), + data, + ); + Box::pin(std::future::ready(result)) + } + } + + #[derive(Debug)] + struct TestSigningKey(Arc); + + impl SigningKey for TestSigningKey { + fn choose_scheme(&self, offered: &[SignatureScheme]) -> Option> { + offered + .contains(&SignatureScheme::ED25519) + .then(|| Box::new(TestSigner(self.0.clone())) as Box) + } + + fn algorithm(&self) -> SignatureAlgorithm { + SignatureAlgorithm::ED25519 + } + } + + #[derive(Debug)] + struct TestSigner(Arc); + + impl Signer for TestSigner { + fn sign(&self, message: &[u8]) -> Result, rustls::Error> { + Ok(self.0.sign(message).as_ref().to_vec()) + } + + fn scheme(&self) -> SignatureScheme { + SignatureScheme::ED25519 + } + } + + fn test_authority() -> TestAuthority { + let rng = ring::rand::SystemRandom::new(); + let pkcs8 = ring::signature::Ed25519KeyPair::generate_pkcs8(&rng).expect("pkcs8"); + let keypair = + Arc::new(ring::signature::Ed25519KeyPair::from_pkcs8(pkcs8.as_ref()).expect("keypair")); + let mut spki = Vec::with_capacity(44); + spki.extend_from_slice(&[ + 0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00, + ]); + spki.extend_from_slice(keypair.public_key().as_ref()); + + TestAuthority { + keypair, + cert_chain: vec![CertificateDer::from(spki)], + } + } + + #[tokio::test] + async fn signed_publish_request_uses_authority_headers() { + let authority = test_authority(); + let base_url = url::Url::parse("https://dns.example.test:4433").expect("url"); + let request = + signed_publish_request(&base_url, "demo.dhttp.net", b"dns-packet", &authority) + .await + .expect("signed request"); + + assert_eq!(request.method(), http::Method::POST); + assert_eq!( + request.uri().to_string(), + "https://dns.example.test:4433/publish?host=demo.dhttp.net" + ); + assert!( + request + .headers() + .contains_key(crate::core::signature::CONTENT_DIGEST_HEADER) + ); + assert!( + request + .headers() + .contains_key(crate::core::signature::SIGNATURE_INPUT_HEADER) + ); + assert!( + request + .headers() + .contains_key(crate::core::signature::SIGNATURE_HEADER) + ); + } +} From abfdf6c7f68d69da6be189aaa7d229526ec046ca Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 19:26:46 +0800 Subject: [PATCH 14/29] fix: select one h3 dns endpoint group --- src/h3.rs | 3 + src/h3/lookup.rs | 213 ++++++++++++++++++++++++++++++---------------- src/h3/publish.rs | 2 + src/resolvers.rs | 12 ++- 4 files changed, 154 insertions(+), 76 deletions(-) diff --git a/src/h3.rs b/src/h3.rs index 8499998..8720e69 100644 --- a/src/h3.rs +++ b/src/h3.rs @@ -186,8 +186,11 @@ where mod tests { use std::time::Duration; + #[cfg(feature = "dquic-network")] use dquic::{qbase::net::addr::EndpointAddr, qresolve::Source}; + #[cfg(feature = "dquic-network")] use futures::StreamExt; + #[cfg(feature = "dquic-network")] use tokio::time::Instant; use super::*; diff --git a/src/h3/lookup.rs b/src/h3/lookup.rs index 027bab9..8540035 100644 --- a/src/h3/lookup.rs +++ b/src/h3/lookup.rs @@ -1,21 +1,90 @@ use std::{sync::Arc, time::Duration}; -use dquic::{ - qbase::net::addr::EndpointAddr, - qresolve::{RecordStream, Source}, -}; +use dquic::qresolve::{RecordStream, Source}; use futures::{StreamExt, stream}; use h3x::quic; use http_body_util::BodyExt; +use snafu::{IntoError, ResultExt}; use tokio::time::Instant; -use tracing::trace; use super::{ H3LookupError, H3Resolver, LOOKUP_REQUEST_ATTEMPTS, LOOKUP_REQUEST_TIMEOUT, LookupDecodeError, - Record, + Record, h3_lookup_error, lookup_decode_error, }; use crate::core::{parser::packet::be_packet, wire::be_multi_response}; +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) struct LookupRecords { + pub(super) endpoints: Vec, +} + +impl LookupRecords { + pub(super) fn decode(domain: &str, response: &[u8]) -> Result { + use crate::core::parser::record; + + let (remain, multi) = match be_multi_response(response) { + Ok(response) => response, + Err(_error) => return Err(LookupDecodeError::MultiResponse), + }; + if !remain.is_empty() { + return Err(LookupDecodeError::MultiResponse); + } + + let mut endpoint_records = Vec::new(); + for r in multi.records { + if !r.signature_fields.is_empty() { + match r.signature_fields.verify(&r.dns, &r.cert) { + Ok(true) => {} + Ok(false) => { + tracing::debug!("ignored record with invalid DNS packet signature"); + continue; + } + Err(error) => { + tracing::debug!( + error = %snafu::Report::from_error(&error), + "ignored record with malformed DNS packet signature" + ); + continue; + } + } + } + + let (_remain, packet) = match be_packet(&r.dns) { + Ok(packet) => packet, + Err(source) => { + return Err( + lookup_decode_error::ParseRecordsSnafu.into_error(source.to_owned()) + ); + } + }; + + endpoint_records.extend(packet.answers.iter().filter_map( + |answer| match answer.data() { + record::RData::E(ep) => { + if answer.name() != domain { + tracing::debug!( + answer_name = %answer.name(), + query = domain, + "ignored endpoint answer for different name" + ); + return None; + } + Some(ep.clone()) + } + _ => { + tracing::debug!(?answer, "ignored record"); + None + } + }, + )); + } + + Ok(Self { + endpoints: crate::resolvers::selector::selected_endpoint_addrs(endpoint_records), + }) + } +} + impl H3Resolver where C: quic::Connect + quic::WithLocalAuthority + Send + Sync + 'static, @@ -93,7 +162,6 @@ where } pub async fn lookup(&self, name: &str) -> Result> { - use crate::core::parser::record; let server = Arc::from(self.base_url.origin().ascii_serialization()); let source = Source::H3 { server }; @@ -134,73 +202,9 @@ where Err(error) => return Err(error), }; - // Server always returns multi-record format. - let (remain, multi) = match be_multi_response(response.as_ref()) { - Ok(response) => response, - Err(_error) => { - return Err(H3LookupError::Decode { - source: LookupDecodeError::MultiResponse, - }); - } - }; - if !remain.is_empty() { - return Err(H3LookupError::Decode { - source: LookupDecodeError::MultiResponse, - }); - } - - let mut addrs = Vec::new(); - for r in multi.records { - if !r.signature_fields.is_empty() { - match r.signature_fields.verify(&r.dns, &r.cert) { - Ok(true) => {} - Ok(false) => { - tracing::debug!("ignored record with invalid DNS packet signature"); - continue; - } - Err(error) => { - tracing::debug!(error = %snafu::Report::from_error(&error), "ignored record with malformed DNS packet signature"); - continue; - } - } - } - - let (_remain, packet) = match be_packet(&r.dns) { - Ok(packet) => packet, - Err(source) => { - return Err(H3LookupError::Decode { - source: LookupDecodeError::ParseRecords { - source: source.to_owned(), - }, - }); - } - }; - - addrs.extend( - packet - .answers - .iter() - .filter_map(|answer| match answer.data() { - record::RData::E(ep) => { - if answer.name() != domain { - tracing::debug!( - answer_name = %answer.name(), - query = domain, - "ignored endpoint answer for different name" - ); - return None; - } - let endpoint = TryInto::::try_into(ep.clone()).ok()?; - trace!(?endpoint, "parsed endpoint from record"); - Some(endpoint) - } - _ => { - tracing::debug!(?answer, "ignored record"); - None - } - }), - ); - } + let records = LookupRecords::decode(domain, response.as_ref()) + .context(h3_lookup_error::DecodeSnafu)?; + let addrs = records.endpoints; if addrs.is_empty() { self.negative_cache @@ -221,3 +225,64 @@ where Ok(stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))).boxed()) } } + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, net::SocketAddrV4}; + + use super::*; + use crate::core::{ + MdnsPacket, + parser::record::endpoint::EndpointAddr as DnsEndpointAddr, + wire::{MultiResponse, ResponseRecord}, + }; + + fn direct(addr: &str, main: bool, sequence: u64) -> DnsEndpointAddr { + let socket: SocketAddrV4 = addr.parse().expect("socket addr"); + let mut endpoint = DnsEndpointAddr::direct_v4(socket); + endpoint.set_main(main); + endpoint.set_sequence(sequence); + endpoint + } + + fn response_for(name: &str, endpoints: Vec) -> Vec { + let mut hosts = HashMap::new(); + hosts.insert(name.to_owned(), endpoints); + let packet = MdnsPacket::answer(0, &hosts).to_bytes(); + MultiResponse::new([ResponseRecord::unsigned(packet, Vec::new())]).encode() + } + + #[test] + fn lookup_records_select_primary_group() { + let response = response_for( + "demo.dhttp.net", + vec![ + direct("192.0.2.20:4433", false, 1), + direct("192.0.2.10:4433", true, 2), + direct("192.0.2.11:4433", true, 2), + direct("192.0.2.30:4433", true, 3), + ], + ); + + let records = LookupRecords::decode("demo.dhttp.net", &response).expect("records"); + + assert_eq!(records.endpoints.len(), 2); + assert_eq!( + records.endpoints[0], + dquic::qbase::net::addr::EndpointAddr::direct("192.0.2.10:4433".parse().unwrap()) + ); + assert_eq!( + records.endpoints[1], + dquic::qbase::net::addr::EndpointAddr::direct("192.0.2.11:4433".parse().unwrap()) + ); + } + + #[test] + fn lookup_records_ignore_answer_name_mismatch() { + let response = response_for("other.dhttp.net", vec![direct("192.0.2.50:4433", true, 1)]); + + let records = LookupRecords::decode("demo.dhttp.net", &response).expect("records"); + + assert!(records.endpoints.is_empty()); + } +} diff --git a/src/h3/publish.rs b/src/h3/publish.rs index 8b3ca82..ef6db4d 100644 --- a/src/h3/publish.rs +++ b/src/h3/publish.rs @@ -106,8 +106,10 @@ where mod tests { use std::sync::Arc; + #[cfg(feature = "dquic-network")] use dquic::qresolve::Publish as _; use futures::future::BoxFuture; + #[cfg(feature = "dquic-network")] use h3x::endpoint::H3Endpoint; use ring::signature::KeyPair as _; use rustls::{ diff --git a/src/resolvers.rs b/src/resolvers.rs index f7abb4d..144d517 100644 --- a/src/resolvers.rs +++ b/src/resolvers.rs @@ -92,7 +92,6 @@ impl std::str::FromStr for DnsScheme { } pub mod deferred; -#[cfg(feature = "mdns")] pub(crate) mod selector; pub mod weak; @@ -312,7 +311,10 @@ impl Resolve for Resolvers { #[cfg(test)] mod tests { - use std::{error::Error as StdError, fmt, io, str::FromStr}; + #[cfg(all(feature = "mdns", feature = "dquic-network", feature = "resolvers"))] + use std::str::FromStr; + #[cfg(feature = "resolvers")] + use std::{error::Error as StdError, fmt, io}; #[cfg(all(feature = "mdns", feature = "dquic-network", feature = "resolvers"))] use super::MdnsResolvers; @@ -322,12 +324,14 @@ mod tests { #[cfg(feature = "resolvers")] use super::{DnsErrors, DnsScheme}; + #[cfg(feature = "resolvers")] #[derive(Debug)] struct TestSourceError { message: &'static str, source: Option>, } + #[cfg(feature = "resolvers")] impl TestSourceError { fn leaf(message: &'static str) -> Self { Self { @@ -344,12 +348,14 @@ mod tests { } } + #[cfg(feature = "resolvers")] impl fmt::Display for TestSourceError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(self.message) } } + #[cfg(feature = "resolvers")] impl StdError for TestSourceError { fn source(&self) -> Option<&(dyn StdError + 'static)> { self.source @@ -358,10 +364,12 @@ mod tests { } } + #[cfg(feature = "resolvers")] fn other_error(message: &'static str) -> io::Error { io::Error::other(message) } + #[cfg(feature = "resolvers")] fn chained_other_error(root: TestSourceError) -> io::Error { io::Error::other(root) } From 33ee7319c38f04a77c1353a7a84907b0a9e7aaa1 Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 19:31:33 +0800 Subject: [PATCH 15/29] refactor: extract h3 lookup cache --- src/h3.rs | 48 ++++++++--------------------- src/h3/cache.rs | 80 +++++++++++++++++++++++++++++++++++++++++++++++- src/h3/lookup.rs | 32 +++++-------------- 3 files changed, 99 insertions(+), 61 deletions(-) diff --git a/src/h3.rs b/src/h3.rs index 8720e69..760e89b 100644 --- a/src/h3.rs +++ b/src/h3.rs @@ -1,16 +1,11 @@ use std::{convert::Infallible, fmt, io, sync::Arc, time::Duration}; -use dashmap::DashMap; -use dquic::{ - qbase::net::addr::EndpointAddr, - qresolve::{Publish, PublishFuture, Resolve, ResolveFuture}, -}; +use dquic::qresolve::{Publish, PublishFuture, Resolve, ResolveFuture}; use h3x::{ dhttp::message::{MessageStreamError, hyper::client::RequestError as HyperRequestError}, endpoint::H3Endpoint, quic, }; -use tokio::time::Instant; use url::Url; mod cache; @@ -27,14 +22,7 @@ where { endpoint: Arc>, base_url: Url, - cached_records: DashMap, - negative_cache: DashMap, -} - -#[derive(Debug)] -pub(super) struct Record { - pub(super) addrs: Vec, - pub(super) expire: Instant, + cache: cache::LookupCache, } impl fmt::Debug for H3Resolver @@ -142,8 +130,7 @@ where Ok(Self { endpoint, base_url, - cached_records: DashMap::new(), - negative_cache: DashMap::new(), + cache: cache::LookupCache::default(), }) } @@ -190,8 +177,6 @@ mod tests { use dquic::{qbase::net::addr::EndpointAddr, qresolve::Source}; #[cfg(feature = "dquic-network")] use futures::StreamExt; - #[cfg(feature = "dquic-network")] - use tokio::time::Instant; use super::*; #[cfg(feature = "dquic-network")] @@ -214,12 +199,9 @@ mod tests { h3x::dquic::QuicEndpoint::builder().build().await, )); let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); - resolver.cached_records.insert( - "car.lab.dhttp.net".to_owned(), - Record { - addrs: vec![EndpointAddr::direct("192.168.5.78:41748".parse().unwrap())], - expire: Instant::now() + Duration::from_secs(60), - }, + resolver.cache.insert_positive( + "car.lab.dhttp.net", + vec![EndpointAddr::direct("192.168.5.78:41748".parse().unwrap())], ); let mut records = resolver.lookup("car.lab.dhttp.net").await.unwrap(); @@ -244,12 +226,9 @@ mod tests { h3x::dquic::QuicEndpoint::builder().build().await, )); let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); - resolver.cached_records.insert( - "dns.genmeta.net".to_owned(), - Record { - addrs: vec![EndpointAddr::direct("192.0.2.53:4433".parse().unwrap())], - expire: Instant::now() + Duration::from_secs(60), - }, + resolver.cache.insert_positive( + "dns.genmeta.net", + vec![EndpointAddr::direct("192.0.2.53:4433".parse().unwrap())], ); let mut records = resolver.lookup("dns.genmeta.net").await.unwrap(); @@ -268,12 +247,9 @@ mod tests { h3x::dquic::QuicEndpoint::builder().build().await, )); let resolver = H3Resolver::from_endpoint(DHTTP_H3_DNS_SERVER, endpoint).unwrap(); - resolver.cached_records.insert( - "nat.genmeta.net".to_owned(), - Record { - addrs: vec![EndpointAddr::direct("192.0.2.10:21000".parse().unwrap())], - expire: Instant::now() + Duration::from_secs(60), - }, + resolver.cache.insert_positive( + "nat.genmeta.net", + vec![EndpointAddr::direct("192.0.2.10:21000".parse().unwrap())], ); let mut records = resolver.lookup("nat.genmeta.net:20004").await.unwrap(); diff --git a/src/h3/cache.rs b/src/h3/cache.rs index 78d813d..3f143ad 100644 --- a/src/h3/cache.rs +++ b/src/h3/cache.rs @@ -1 +1,79 @@ -// Lookup cache ownership is introduced after the mechanical H3 module split. +use std::time::Duration; + +use dashmap::DashMap; +use dquic::qbase::net::addr::EndpointAddr; +use tokio::time::Instant; + +const POSITIVE_TTL: Duration = Duration::from_secs(10); +const NEGATIVE_TTL: Duration = Duration::from_secs(2); + +#[derive(Debug)] +pub(super) struct CachedRecord { + addrs: Vec, + expire: Instant, +} + +#[derive(Debug, Default)] +pub(super) struct LookupCache { + positive: DashMap, + negative: DashMap, +} + +impl LookupCache { + pub(super) fn prune_expired(&self, now: Instant) { + self.positive.retain(|_host, record| record.expire > now); + self.negative.retain(|_host, expire| *expire > now); + } + + pub(super) fn positive_hit(&self, domain: &str) -> Option> { + self.positive.get(domain).map(|record| record.addrs.clone()) + } + + pub(super) fn negative_hit(&self, domain: &str) -> bool { + self.negative.get(domain).is_some() + } + + pub(super) fn insert_positive(&self, domain: &str, addrs: Vec) { + self.positive.insert( + domain.to_owned(), + CachedRecord { + addrs, + expire: Instant::now() + POSITIVE_TTL, + }, + ); + self.negative.remove(domain); + } + + pub(super) fn insert_negative(&self, domain: &str) { + self.negative + .insert(domain.to_owned(), Instant::now() + NEGATIVE_TTL); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn endpoint(addr: &str) -> EndpointAddr { + EndpointAddr::direct(addr.parse().expect("socket addr")) + } + + #[test] + fn positive_cache_hit_returns_endpoints() { + let cache = LookupCache::default(); + cache.insert_positive("demo.dhttp.net", vec![endpoint("192.0.2.10:4433")]); + + assert_eq!( + cache.positive_hit("demo.dhttp.net").unwrap(), + vec![endpoint("192.0.2.10:4433")] + ); + } + + #[test] + fn negative_cache_hit_blocks_lookup() { + let cache = LookupCache::default(); + cache.insert_negative("missing.dhttp.net"); + + assert!(cache.negative_hit("missing.dhttp.net")); + } +} diff --git a/src/h3/lookup.rs b/src/h3/lookup.rs index 8540035..4118164 100644 --- a/src/h3/lookup.rs +++ b/src/h3/lookup.rs @@ -1,4 +1,4 @@ -use std::{sync::Arc, time::Duration}; +use std::sync::Arc; use dquic::qresolve::{RecordStream, Source}; use futures::{StreamExt, stream}; @@ -9,7 +9,7 @@ use tokio::time::Instant; use super::{ H3LookupError, H3Resolver, LOOKUP_REQUEST_ATTEMPTS, LOOKUP_REQUEST_TIMEOUT, LookupDecodeError, - Record, h3_lookup_error, lookup_decode_error, + h3_lookup_error, lookup_decode_error, }; use crate::core::{parser::packet::be_packet, wire::be_multi_response}; @@ -170,19 +170,13 @@ where }; let now = Instant::now(); - let positive_ttl = Duration::from_secs(10); - let negative_ttl = Duration::from_secs(2); + self.cache.prune_expired(now); - self.cached_records - .retain(|_host, record| record.expire > now); - self.negative_cache.retain(|_host, expire| *expire > now); - - if self.negative_cache.get(domain).is_some() { + if self.cache.negative_hit(domain) { return Err(H3LookupError::NoRecordFound); } - if let Some(record) = self.cached_records.get(domain) { - let addrs = record.addrs.clone(); + if let Some(addrs) = self.cache.positive_hit(domain) { let stream = stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))); return Ok(stream.boxed()); } @@ -195,8 +189,7 @@ where let response = match self.lookup_response_with_retry(uri).await { Ok(response) => response, Err(H3LookupError::NoRecordFound) => { - self.negative_cache - .insert(domain.to_string(), now + negative_ttl); + self.cache.insert_negative(domain); return Err(H3LookupError::NoRecordFound); } Err(error) => return Err(error), @@ -207,20 +200,11 @@ where let addrs = records.endpoints; if addrs.is_empty() { - self.negative_cache - .insert(domain.to_string(), now + negative_ttl); + self.cache.insert_negative(domain); return Err(H3LookupError::NoRecordFound); } - self.cached_records.insert( - domain.to_string(), - Record { - addrs: addrs.clone(), - expire: now + positive_ttl, - }, - ); - - self.negative_cache.remove(domain); + self.cache.insert_positive(domain, addrs.clone()); Ok(stream::iter(addrs.into_iter().map(move |ep| (source.clone(), ep))).boxed()) } From 22118e370278cd1e686dfb834e118d886e30cb77 Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 20:12:03 +0800 Subject: [PATCH 16/29] refactor: rename resolver aggregate error --- src/resolvers.rs | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/resolvers.rs b/src/resolvers.rs index 144d517..c1c5df1 100644 --- a/src/resolvers.rs +++ b/src/resolvers.rs @@ -124,7 +124,7 @@ impl Display for Resolvers { #[cfg(feature = "resolvers")] #[derive(Debug)] -pub struct DnsErrors { +pub struct ResolversError { errors: Vec<(String, io::Error)>, } @@ -156,7 +156,7 @@ fn format_dns_error_entry( } #[cfg(feature = "resolvers")] -impl fmt::Display for DnsErrors { +impl fmt::Display for ResolversError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.errors.is_empty() { return write!(f, "no DNS resolvers available"); @@ -171,7 +171,7 @@ impl fmt::Display for DnsErrors { } #[cfg(feature = "resolvers")] -impl Error for DnsErrors {} +impl Error for ResolversError {} #[cfg(feature = "resolvers")] #[derive(Default)] @@ -276,7 +276,7 @@ impl Resolvers { pub async fn lookup( &self, name: &str, - ) -> Result + use<>, DnsErrors> { + ) -> Result + use<>, ResolversError> { let mut errors = vec![]; let mut lookups = stream::FuturesUnordered::from_iter( @@ -291,7 +291,7 @@ impl Resolvers { match lookups.next().await { Some((Ok(endpoints), _)) => break endpoints, Some((Err(error), resolver)) => errors.push((resolver.to_string(), error)), - None => return Err(DnsErrors { errors }), + None => return Err(ResolversError { errors }), } }; @@ -322,7 +322,7 @@ mod tests { use super::Resolvers; use super::{DHTTP_H3_DNS_SERVER, DHTTP_HTTP_DNS_SERVER, DHTTP_MDNS_SERVICE, resolvable_name}; #[cfg(feature = "resolvers")] - use super::{DnsErrors, DnsScheme}; + use super::{DnsScheme, ResolversError}; #[cfg(feature = "resolvers")] #[derive(Debug)] @@ -429,16 +429,16 @@ mod tests { #[cfg(feature = "resolvers")] #[test] - fn dns_errors_render_no_resolvers_available_when_empty() { - let error = DnsErrors { errors: vec![] }; + fn resolvers_error_renders_no_resolvers_available_when_empty() { + let error = ResolversError { errors: vec![] }; assert_eq!(error.to_string(), "no DNS resolvers available"); } #[cfg(feature = "resolvers")] #[test] - fn dns_errors_render_resolver_bullets_in_stored_order() { - let error = DnsErrors { + fn resolvers_error_renders_resolver_bullets_in_stored_order() { + let error = ResolversError { errors: vec![ ( "System DNS Resolver".to_string(), @@ -460,8 +460,8 @@ mod tests { #[cfg(feature = "resolvers")] #[test] - fn dns_errors_render_numbered_source_chain_for_one_resolver() { - let error = DnsErrors { + fn resolvers_error_renders_numbered_source_chain_for_one_resolver() { + let error = ResolversError { errors: vec![( "DeferredResolver(H3 DNS Resolver(https://dns.genmeta.net:4433/))".to_string(), chained_other_error(TestSourceError::with_source( @@ -483,8 +483,8 @@ mod tests { #[cfg(feature = "resolvers")] #[test] - fn dns_errors_render_repeated_source_messages_without_deduplication() { - let error = DnsErrors { + fn resolvers_error_renders_repeated_source_messages_without_deduplication() { + let error = ResolversError { errors: vec![( "DeferredResolver(H3 DNS Resolver(https://dns.genmeta.net:4433/))".to_string(), chained_other_error(TestSourceError::with_source( From d368dc8c280a44f8800c7c97e93b644b48c4e78a Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 20:15:46 +0800 Subject: [PATCH 17/29] refactor: build unsigned endpoint dns packets --- src/publishers/address.rs | 50 +++++++++++++-- src/publishers/packet.rs | 126 ++++++++++++++++++++++++-------------- 2 files changed, 123 insertions(+), 53 deletions(-) diff --git a/src/publishers/address.rs b/src/publishers/address.rs index 7fca89c..ca8bcfa 100644 --- a/src/publishers/address.rs +++ b/src/publishers/address.rs @@ -60,14 +60,29 @@ pub trait AddressViewSource { } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum PublishAddressScope { +pub enum PublishScope { WideArea, LocalLink { device: Arc, family: Family }, } +impl PublishScope { + pub(crate) fn selector(&self) -> AddressSelector<'_> { + match self { + Self::WideArea => AddressSelector::WideArea, + Self::LocalLink { device, family } => AddressSelector::LocalLink { + device: device.as_ref(), + family: *family, + }, + } + } +} + +#[allow(dead_code)] +pub type PublishAddressScope = PublishScope; + #[derive(Debug, Clone, PartialEq, Eq)] pub struct PublishAddressGroup { - scope: PublishAddressScope, + scope: PublishScope, endpoints: Vec, } @@ -77,7 +92,7 @@ impl PublishAddressGroup { I: IntoIterator, { Self { - scope: PublishAddressScope::WideArea, + scope: PublishScope::WideArea, endpoints: endpoints.into_iter().collect(), } } @@ -87,7 +102,7 @@ impl PublishAddressGroup { I: IntoIterator, { Self { - scope: PublishAddressScope::LocalLink { + scope: PublishScope::LocalLink { device: device.into(), family, }, @@ -97,9 +112,9 @@ impl PublishAddressGroup { fn matches(&self, selector: AddressSelector<'_>) -> bool { match (&self.scope, selector) { - (PublishAddressScope::WideArea, AddressSelector::WideArea) => true, + (PublishScope::WideArea, AddressSelector::WideArea) => true, ( - PublishAddressScope::LocalLink { device, family }, + PublishScope::LocalLink { device, family }, AddressSelector::LocalLink { device: selected_device, family: selected_family, @@ -389,6 +404,29 @@ fn local_endpoints_from_iface(iface: &BindInterface, family: Family) -> Vec::from("en0"), + family: Family::V4, + }; + + assert_eq!( + scope.selector(), + AddressSelector::LocalLink { + device: "en0", + family: Family::V4, + } + ); + } + #[test] fn publish_addresses_select_wide_area_only_for_wide_area_selector() { let wide = EndpointAddr::direct("203.0.113.10:443".parse().unwrap()); diff --git a/src/publishers/packet.rs b/src/publishers/packet.rs index 420bd46..3369472 100644 --- a/src/publishers/packet.rs +++ b/src/publishers/packet.rs @@ -1,38 +1,30 @@ -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; -use dhttp_identity::{ - identity::{LocalAuthority, LocalAuthorityCertificateExt}, - name::Name, -}; +use dhttp_identity::name::Name; use dquic::qbase::net::addr::EndpointAddr; -use snafu::{ResultExt, Snafu}; +use snafu::Snafu; -use crate::core::{ - MdnsPacket, - parser::record::endpoint::{EndpointAddr as DnsEndpointAddr, SignEndpointError}, -}; +use crate::core::{MdnsPacket, parser::record::endpoint::EndpointAddr as DnsEndpointAddr}; #[derive(Debug, Snafu)] #[snafu(module)] -pub enum SignEndpointRecordsError { +pub enum EncodeEndpointPacketError { #[snafu(display("failed to encode endpoint address"))] EncodeEndpoint, - #[snafu(display("failed to extract dhttp certificate selector"))] - CertificateSelector { - source: dhttp_identity::identity::ExtractDhttpSubjectKeyIdentifierError, - }, - #[snafu(display("failed to sign endpoint address"))] - SignEndpoint { source: SignEndpointError }, } +#[allow(dead_code)] +pub type SignEndpointRecordsError = EncodeEndpointPacketError; + +#[allow(dead_code)] #[derive(Clone)] pub struct EndpointRecordSigner { - authority: Arc, + authority: std::sync::Arc, } impl std::fmt::Debug for EndpointRecordSigner where - A: LocalAuthority, + A: dhttp_identity::identity::LocalAuthority, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("EndpointRecordSigner") @@ -41,15 +33,16 @@ where } } +#[allow(dead_code)] impl EndpointRecordSigner where - A: LocalAuthority + Send + Sync + ?Sized, + A: dhttp_identity::identity::LocalAuthority + Send + Sync + ?Sized, { - pub fn new(authority: Arc) -> Self { + pub fn new(authority: std::sync::Arc) -> Self { Self { authority } } - pub fn authority(&self) -> &Arc { + pub fn authority(&self) -> &std::sync::Arc { &self.authority } @@ -58,30 +51,69 @@ where name: &Name<'_>, endpoints: &[EndpointAddr], ) -> Result, SignEndpointRecordsError> { - let selector = self - .authority - .dhttp_subject_key_identifier() - .context(sign_endpoint_records_error::CertificateSelectorSnafu)?; - let chain = selector.chain(); - - let mut signed = Vec::with_capacity(endpoints.len()); - for endpoint in endpoints { - let Ok(mut endpoint) = DnsEndpointAddr::try_from(*endpoint) else { - return sign_endpoint_records_error::EncodeEndpointSnafu.fail(); - }; - endpoint.set_main( - chain.kind() == dhttp_identity::certificate::CertificateChainKind::Primary, - ); - endpoint.set_sequence(chain.sequence().get().into()); - endpoint - .sign_with_authority(self.authority.as_ref()) - .await - .context(sign_endpoint_records_error::SignEndpointSnafu)?; - signed.push(endpoint); - } - - let mut hosts = HashMap::new(); - hosts.insert(name.as_str().to_owned(), signed); - Ok(MdnsPacket::answer(0, &hosts).to_bytes()) + let _ = &self.authority; + endpoint_packet(name, endpoints.iter().copied()) + } +} + +pub(crate) fn endpoint_packet( + name: &Name<'_>, + endpoints: impl IntoIterator, +) -> Result, EncodeEndpointPacketError> { + let mut encoded = Vec::new(); + for endpoint in endpoints { + let Ok(endpoint) = DnsEndpointAddr::try_from(endpoint) else { + return encode_endpoint_packet_error::EncodeEndpointSnafu.fail(); + }; + encoded.push(endpoint); + } + + let mut hosts = HashMap::new(); + hosts.insert(name.as_str().to_owned(), encoded); + Ok(MdnsPacket::answer(0, &hosts).to_bytes()) +} + +#[cfg(test)] +mod tests { + use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; + + use dhttp_identity::name::Name; + use dquic::qbase::net::addr::EndpointAddr as DquicEndpointAddr; + + use super::endpoint_packet; + use crate::core::parser::{ + packet::be_packet, + record::{RData, Type}, + }; + + #[test] + fn endpoint_packet_encodes_unsigned_e_records() { + let name = Name::try_from("alice.dhttp.net").expect("valid dns owner name"); + let endpoint = DquicEndpointAddr::direct(SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(203, 0, 113, 10), + 4433, + ))); + + let packet = endpoint_packet(&name, [endpoint]).expect("endpoint packet"); + let (remain, parsed) = be_packet(&packet).expect("dns packet parses"); + assert!(remain.is_empty()); + assert_eq!(parsed.answers.len(), 1); + assert_eq!(parsed.answers[0].name(), "alice.dhttp.net"); + assert_eq!(parsed.answers[0].typ(), Type::E); + + let RData::E(encoded) = parsed.answers[0].data() else { + panic!("answer must be an E record"); + }; + assert!(!encoded.is_signed()); + } + + #[test] + fn endpoint_packet_allows_empty_endpoint_set() { + let name = Name::try_from("alice.dhttp.net").expect("valid dns owner name"); + + let packet = endpoint_packet(&name, []).expect("endpoint packet"); + let (remain, parsed) = be_packet(&packet).expect("dns packet parses"); + assert!(remain.is_empty()); + assert!(parsed.answers.is_empty()); } } From 9ce4d5f2686ab8fc109f3777a0d2385762d366bc Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 20:19:13 +0800 Subject: [PATCH 18/29] feat: add scoped dns publisher --- src/publishers.rs | 6 +- src/publishers/publisher.rs | 285 ++++++++++++++++++++++++++++++++++++ 2 files changed, 290 insertions(+), 1 deletion(-) create mode 100644 src/publishers/publisher.rs diff --git a/src/publishers.rs b/src/publishers.rs index ae80ef8..c1cc955 100644 --- a/src/publishers.rs +++ b/src/publishers.rs @@ -4,6 +4,8 @@ mod address; mod dispatch; #[cfg(feature = "publishers")] mod packet; +#[cfg(feature = "publishers")] +mod publisher; #[cfg(all(feature = "publishers", feature = "dquic-network"))] use std::{any::TypeId, net::SocketAddr, time::Duration}; @@ -13,7 +15,7 @@ use std::{io, sync::Arc}; #[cfg(feature = "publishers")] pub use address::{ AddressSelector, AddressView, FnAddressView, PublishAddressGroup, PublishAddressScope, - PublishAddresses, + PublishAddresses, PublishScope, }; #[cfg(all(feature = "publishers", feature = "dquic-network"))] pub use address::{AddressViewSource, EndpointBindingAddresses}; @@ -28,6 +30,8 @@ use dquic::{ #[cfg(feature = "publishers")] pub use packet::{EndpointRecordSigner, SignEndpointRecordsError}; #[cfg(feature = "publishers")] +pub use publisher::{Publisher, PublisherError}; +#[cfg(feature = "publishers")] use snafu::Snafu; #[cfg(feature = "h3")] diff --git a/src/publishers/publisher.rs b/src/publishers/publisher.rs new file mode 100644 index 0000000..01e44c8 --- /dev/null +++ b/src/publishers/publisher.rs @@ -0,0 +1,285 @@ +use std::{fmt, io, sync::Arc}; + +use dhttp_identity::name::Name; +use dquic::qresolve::Publish; +use snafu::{ResultExt, Snafu}; + +use super::{AddressView, PublishScope, packet}; + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum PublisherError { + #[snafu(display("failed to encode endpoint dns packet"))] + EncodePacket { + source: packet::EncodeEndpointPacketError, + }, + #[snafu(display("failed to publish dns packet with {publisher}"))] + Publish { + publisher: String, + source: io::Error, + }, +} + +#[derive(Clone)] +pub struct Publisher { + inner: PublisherKind, +} + +#[derive(Clone)] +enum PublisherKind { + Custom { + scope: PublishScope, + publisher: Arc, + }, +} + +impl fmt::Debug for Publisher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.inner { + PublisherKind::Custom { scope, publisher } => f + .debug_struct("Publisher") + .field("scope", scope) + .field("publisher", publisher) + .finish(), + } + } +} + +impl fmt::Display for Publisher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.inner { + PublisherKind::Custom { publisher, .. } => fmt::Display::fmt(publisher, f), + } + } +} + +impl Publisher { + pub fn new(scope: PublishScope, publisher: Arc) -> Self { + Self { + inner: PublisherKind::Custom { scope, publisher }, + } + } + + #[cfg(feature = "http")] + pub fn http(publisher: Arc) -> Self { + Self::new(PublishScope::WideArea, publisher) + } + + #[cfg(feature = "h3")] + pub fn h3(publisher: Arc>) -> Self + where + C: h3x::quic::Connect + h3x::quic::WithLocalAuthority, + crate::h3::H3Resolver: Publish + Send + Sync + 'static, + { + Self::new(PublishScope::WideArea, publisher) + } + + pub async fn publish(&self, name: &Name<'_>, view: &V) -> Result<(), PublisherError> + where + V: AddressView + Sync, + { + match &self.inner { + PublisherKind::Custom { scope, publisher } => { + publish_selected(publisher.as_ref(), scope, name, view).await + } + } + } +} + +async fn publish_selected( + publisher: &(dyn Publish + Send + Sync), + scope: &PublishScope, + name: &Name<'_>, + view: &V, +) -> Result<(), PublisherError> +where + V: AddressView + Sync, +{ + let endpoints: Vec<_> = view.endpoints(scope.selector()).collect(); + let packet = + packet::endpoint_packet(name, endpoints).context(publisher_error::EncodePacketSnafu)?; + tracing::debug!( + publisher = %publisher, + name = %name, + packet_len = packet.len(), + "publishing dns packet" + ); + publisher + .publish(name.as_str(), &packet) + .await + .context(publisher_error::PublishSnafu { + publisher: publisher.to_string(), + }) +} + +#[cfg(test)] +mod tests { + use std::{ + fmt, io, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + sync::{Arc, Mutex}, + }; + + use dhttp_identity::name::Name; + use dquic::{ + qbase::net::{Family, addr::EndpointAddr}, + qresolve::{Publish, PublishFuture}, + }; + use futures::FutureExt; + + use crate::{ + core::parser::{packet::be_packet, record::RData}, + publishers::{PublishScope, Publisher}, + }; + + #[derive(Debug, Default)] + struct RecordingPublisher { + calls: Mutex)>>, + } + + impl RecordingPublisher { + fn calls(&self) -> Vec<(String, Vec)> { + self.calls.lock().expect("calls lock poisoned").clone() + } + } + + impl fmt::Display for RecordingPublisher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("recording publisher") + } + } + + impl Publish for RecordingPublisher { + fn publish<'a>(&'a self, name: &'a str, packet: &'a [u8]) -> PublishFuture<'a> { + async move { + self.calls + .lock() + .expect("calls lock poisoned") + .push((name.to_owned(), packet.to_vec())); + Ok(()) + } + .boxed() + } + } + + #[derive(Debug)] + struct FailingPublisher; + + impl fmt::Display for FailingPublisher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("failing publisher") + } + } + + impl Publish for FailingPublisher { + fn publish<'a>(&'a self, _name: &'a str, _packet: &'a [u8]) -> PublishFuture<'a> { + async move { Err(io::Error::other("publish rejected")) }.boxed() + } + } + + fn endpoint(ip: [u8; 4], port: u16) -> EndpointAddr { + EndpointAddr::direct(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(ip), port))) + } + + #[tokio::test] + async fn custom_publisher_selects_wide_area_addresses() { + let wide = endpoint([203, 0, 113, 10], 4433); + let local = endpoint([192, 168, 1, 20], 4433); + let recorder = Arc::new(RecordingPublisher::default()); + let publisher = Publisher::new(PublishScope::WideArea, recorder.clone()); + let view = crate::publishers::PublishAddresses::new() + .wide_area([wide]) + .local_link("en0", Family::V4, [local]); + let name = Name::try_from("alice.dhttp.net").expect("valid name"); + + publisher + .publish(&name, &view) + .await + .expect("publish succeeds"); + + let calls = recorder.calls(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].0, "alice.dhttp.net"); + let (_, packet) = be_packet(&calls[0].1).expect("packet parses"); + let endpoints: Vec<_> = packet + .answers + .iter() + .filter_map(|answer| match answer.data() { + RData::E(endpoint) => Some(endpoint.primary), + _ => None, + }) + .collect(); + assert_eq!( + endpoints, + vec![SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(203, 0, 113, 10), + 4433 + ))] + ); + } + + #[tokio::test] + async fn custom_publisher_selects_matching_local_link_addresses() { + let en0 = endpoint([192, 168, 1, 20], 4433); + let en1 = endpoint([192, 168, 2, 20], 4433); + let recorder = Arc::new(RecordingPublisher::default()); + let publisher = Publisher::new( + PublishScope::LocalLink { + device: Arc::::from("en1"), + family: Family::V4, + }, + recorder.clone(), + ); + let view = crate::publishers::PublishAddresses::new() + .local_link("en0", Family::V4, [en0]) + .local_link("en1", Family::V4, [en1]); + let name = Name::try_from("alice.dhttp.net").expect("valid name"); + + publisher + .publish(&name, &view) + .await + .expect("publish succeeds"); + + let calls = recorder.calls(); + assert_eq!(calls.len(), 1); + let (_, packet) = be_packet(&calls[0].1).expect("packet parses"); + let endpoints: Vec<_> = packet + .answers + .iter() + .filter_map(|answer| match answer.data() { + RData::E(endpoint) => Some(endpoint.primary), + _ => None, + }) + .collect(); + assert_eq!( + endpoints, + vec![SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(192, 168, 2, 20), + 4433 + ))] + ); + } + + #[tokio::test] + async fn custom_publisher_error_preserves_publish_source() { + let publisher = Publisher::new(PublishScope::WideArea, Arc::new(FailingPublisher)); + let view = crate::publishers::PublishAddresses::new(); + let name = Name::try_from("alice.dhttp.net").expect("valid name"); + + let error = publisher + .publish(&name, &view) + .await + .expect_err("publish should fail"); + + assert_eq!( + error.to_string(), + "failed to publish dns packet with failing publisher" + ); + assert_eq!( + std::error::Error::source(&error) + .expect("source") + .to_string(), + "publish rejected" + ); + } +} From 1d870927859ff24c7b165ac20ff144a1d00da6d8 Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 20:23:20 +0800 Subject: [PATCH 19/29] feat: aggregate dns publishers --- src/publishers.rs | 50 +-------- src/publishers/aggregate.rs | 196 ++++++++++++++++++++++++++++++++++++ src/publishers/publisher.rs | 85 +++++++++++++++- tests/publishers_surface.rs | 15 +-- 4 files changed, 290 insertions(+), 56 deletions(-) create mode 100644 src/publishers/aggregate.rs diff --git a/src/publishers.rs b/src/publishers.rs index c1cc955..4ac735a 100644 --- a/src/publishers.rs +++ b/src/publishers.rs @@ -1,6 +1,8 @@ #[cfg(feature = "publishers")] mod address; #[cfg(feature = "publishers")] +mod aggregate; +#[cfg(feature = "publishers")] mod dispatch; #[cfg(feature = "publishers")] mod packet; @@ -20,9 +22,11 @@ pub use address::{ #[cfg(all(feature = "publishers", feature = "dquic-network"))] pub use address::{AddressViewSource, EndpointBindingAddresses}; #[cfg(feature = "publishers")] +pub use aggregate::{Publishers, PublishersError}; +#[cfg(feature = "publishers")] use dhttp_identity::{identity::LocalAuthority, name::Name}; #[cfg(feature = "publishers")] -use dquic::qresolve::{Publish, Resolve}; +use dquic::qresolve::Resolve; #[cfg(all(feature = "publishers", feature = "dquic-network"))] use dquic::{ qinterface::component::location::AddressEvent, qtraversal::nat::client::ClientLocationData, @@ -129,50 +133,6 @@ where } } -#[cfg(feature = "publishers")] -#[derive(Default, Debug, Clone)] -pub struct Publishers { - publishers: Vec>, -} - -#[cfg(feature = "publishers")] -impl Publishers { - pub fn new() -> Self { - Self::default() - } - - pub fn with(mut self, publisher: Arc) -> Self { - self.push(publisher); - self - } - - pub fn push(&mut self, publisher: Arc) { - self.publishers.push(publisher); - } - - pub fn iter(&self) -> impl Iterator> { - self.publishers.iter() - } -} - -#[cfg(feature = "publishers")] -#[derive(Default, Debug)] -pub struct PublishersBuilder { - publishers: Publishers, -} - -#[cfg(feature = "publishers")] -impl PublishersBuilder { - pub fn publisher(mut self, publisher: Arc) -> Self { - self.publishers.push(publisher); - self - } - - pub fn build(self) -> Publishers { - self.publishers - } -} - #[cfg(all(feature = "publishers", feature = "dquic-network"))] pub type EndpointPublisherLoop = EndpointPublicationLoop< dyn LocalAuthority + Send + Sync, diff --git a/src/publishers/aggregate.rs b/src/publishers/aggregate.rs new file mode 100644 index 0000000..4c61208 --- /dev/null +++ b/src/publishers/aggregate.rs @@ -0,0 +1,196 @@ +use std::{error::Error, fmt}; + +use dhttp_identity::name::Name; + +use super::{AddressView, Publisher, PublisherError}; + +#[derive(Default, Clone, Debug)] +pub struct Publishers { + publishers: Vec, +} + +#[derive(Debug)] +pub struct PublishersError { + errors: Vec<(String, PublisherError)>, +} + +fn format_error_sources(f: &mut fmt::Formatter<'_>, error: &(dyn Error + 'static)) -> fmt::Result { + let mut index = 1; + let mut current = error.source(); + + while let Some(source) = current { + write!(f, "\n {index}. {source}")?; + index += 1; + current = source.source(); + } + + Ok(()) +} + +impl fmt::Display for PublishersError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.errors.is_empty() { + return write!(f, "no DNS publishers available"); + } + + write!(f, "all DNS publishers failed")?; + for (publisher, error) in &self.errors { + write!(f, "\n - {publisher}: {error}")?; + format_error_sources(f, error)?; + } + Ok(()) + } +} + +impl Error for PublishersError {} + +impl Publishers { + pub fn new() -> Self { + Self::default() + } + + pub fn with(mut self, publisher: Publisher) -> Self { + self.push(publisher); + self + } + + pub fn push(&mut self, publisher: Publisher) { + self.publishers.push(publisher); + } + + pub fn iter(&self) -> impl Iterator { + self.publishers.iter() + } + + pub async fn publish(&self, name: &Name<'_>, view: &V) -> Result<(), PublishersError> + where + V: AddressView + Sync, + { + if self.publishers.is_empty() { + return Err(PublishersError { errors: Vec::new() }); + } + + let mut errors = Vec::new(); + let mut succeeded = false; + for publisher in &self.publishers { + match publisher.publish(name, view).await { + Ok(()) => succeeded = true, + Err(error) => errors.push((publisher.to_string(), error)), + } + } + + if succeeded { + Ok(()) + } else { + Err(PublishersError { errors }) + } + } +} + +#[cfg(test)] +mod tests { + use std::{fmt, io, sync::Arc}; + + use dhttp_identity::name::Name; + use dquic::qresolve::{Publish, PublishFuture}; + use futures::FutureExt; + + use crate::publishers::{PublishScope, Publisher, Publishers}; + + #[derive(Debug)] + struct OkPublisher(&'static str); + + impl fmt::Display for OkPublisher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.0) + } + } + + impl Publish for OkPublisher { + fn publish<'a>(&'a self, _name: &'a str, _packet: &'a [u8]) -> PublishFuture<'a> { + async move { Ok(()) }.boxed() + } + } + + #[derive(Debug)] + struct ErrPublisher(&'static str, &'static str); + + impl fmt::Display for ErrPublisher { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.0) + } + } + + impl Publish for ErrPublisher { + fn publish<'a>(&'a self, _name: &'a str, _packet: &'a [u8]) -> PublishFuture<'a> { + let message = self.1; + async move { Err(io::Error::other(message)) }.boxed() + } + } + + fn name() -> Name<'static> { + Name::try_from("alice.dhttp.net").expect("valid name") + } + + #[tokio::test] + async fn empty_publishers_report_no_publishers_available() { + let publishers = Publishers::new(); + let view = crate::publishers::PublishAddresses::new(); + + let error = publishers + .publish(&name(), &view) + .await + .expect_err("empty aggregate should fail"); + + assert_eq!(error.to_string(), "no DNS publishers available"); + } + + #[tokio::test] + async fn publishers_succeed_when_any_publisher_succeeds() { + let publishers = Publishers::new() + .with(Publisher::new( + PublishScope::WideArea, + Arc::new(ErrPublisher("first publisher", "offline")), + )) + .with(Publisher::new( + PublishScope::WideArea, + Arc::new(OkPublisher("second publisher")), + )); + let view = crate::publishers::PublishAddresses::new(); + + publishers + .publish(&name(), &view) + .await + .expect("one success is enough"); + } + + #[tokio::test] + async fn publishers_report_all_failures_when_every_publisher_fails() { + let publishers = Publishers::new() + .with(Publisher::new( + PublishScope::WideArea, + Arc::new(ErrPublisher("first publisher", "offline")), + )) + .with(Publisher::new( + PublishScope::WideArea, + Arc::new(ErrPublisher("second publisher", "permission denied")), + )); + let view = crate::publishers::PublishAddresses::new(); + + let error = publishers + .publish(&name(), &view) + .await + .expect_err("all publishers fail"); + + assert_eq!( + error.to_string(), + concat!( + "all DNS publishers failed\n", + " - first publisher: failed to publish dns packet with first publisher\n", + " 1. offline\n", + " - second publisher: failed to publish dns packet with second publisher\n", + " 1. permission denied" + ) + ); + } +} diff --git a/src/publishers/publisher.rs b/src/publishers/publisher.rs index 01e44c8..1c3b1e0 100644 --- a/src/publishers/publisher.rs +++ b/src/publishers/publisher.rs @@ -2,7 +2,7 @@ use std::{fmt, io, sync::Arc}; use dhttp_identity::name::Name; use dquic::qresolve::Publish; -use snafu::{ResultExt, Snafu}; +use snafu::{IntoError, ResultExt, Snafu}; use super::{AddressView, PublishScope, packet}; @@ -18,6 +18,9 @@ pub enum PublisherError { publisher: String, source: io::Error, }, + #[cfg(all(feature = "mdns", feature = "dquic-network"))] + #[snafu(display("all mdns publishers failed"))] + Mdns { source: MdnsPublishersError }, } #[derive(Clone)] @@ -31,8 +34,34 @@ enum PublisherKind { scope: PublishScope, publisher: Arc, }, + #[cfg(all(feature = "mdns", feature = "dquic-network"))] + Mdns(Arc), } +#[cfg(all(feature = "mdns", feature = "dquic-network"))] +#[derive(Debug)] +pub struct MdnsPublishersError { + errors: Vec<(String, io::Error)>, +} + +#[cfg(all(feature = "mdns", feature = "dquic-network"))] +impl fmt::Display for MdnsPublishersError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.errors.is_empty() { + return write!(f, "no mdns publishers available"); + } + + write!(f, "all mdns publishers failed")?; + for (publisher, error) in &self.errors { + write!(f, "\n - {publisher}: {error}")?; + } + Ok(()) + } +} + +#[cfg(all(feature = "mdns", feature = "dquic-network"))] +impl std::error::Error for MdnsPublishersError {} + impl fmt::Debug for Publisher { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.inner { @@ -41,6 +70,11 @@ impl fmt::Debug for Publisher { .field("scope", scope) .field("publisher", publisher) .finish(), + #[cfg(all(feature = "mdns", feature = "dquic-network"))] + PublisherKind::Mdns(resolvers) => f + .debug_struct("Publisher") + .field("mdns", resolvers) + .finish(), } } } @@ -49,6 +83,8 @@ impl fmt::Display for Publisher { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.inner { PublisherKind::Custom { publisher, .. } => fmt::Display::fmt(publisher, f), + #[cfg(all(feature = "mdns", feature = "dquic-network"))] + PublisherKind::Mdns(resolvers) => fmt::Display::fmt(resolvers, f), } } } @@ -74,6 +110,13 @@ impl Publisher { Self::new(PublishScope::WideArea, publisher) } + #[cfg(all(feature = "mdns", feature = "dquic-network"))] + pub fn mdns(resolvers: Arc) -> Self { + Self { + inner: PublisherKind::Mdns(resolvers), + } + } + pub async fn publish(&self, name: &Name<'_>, view: &V) -> Result<(), PublisherError> where V: AddressView + Sync, @@ -82,6 +125,8 @@ impl Publisher { PublisherKind::Custom { scope, publisher } => { publish_selected(publisher.as_ref(), scope, name, view).await } + #[cfg(all(feature = "mdns", feature = "dquic-network"))] + PublisherKind::Mdns(resolvers) => publish_mdns(resolvers, name, view).await, } } } @@ -283,3 +328,41 @@ mod tests { ); } } + +#[cfg(all(feature = "mdns", feature = "dquic-network"))] +async fn publish_mdns( + resolvers: &crate::mdns::MdnsResolvers, + name: &Name<'_>, + view: &V, +) -> Result<(), PublisherError> +where + V: AddressView + Sync, +{ + let bound_resolvers = resolvers.bound_resolvers(); + if bound_resolvers.is_empty() { + tracing::debug!(name = %name, "no mdns publishers currently bound"); + return Ok(()); + } + + let mut errors = Vec::new(); + let mut succeeded = false; + for bound in bound_resolvers { + let scope = PublishScope::LocalLink { + device: bound.device.clone().into(), + family: bound.family, + }; + match publish_selected(&bound.resolver, &scope, name, view).await { + Ok(()) => succeeded = true, + Err(PublisherError::Publish { source, .. }) => { + errors.push((bound.resolver.to_string(), source)); + } + Err(error) => return Err(error), + } + } + + if succeeded { + Ok(()) + } else { + Err(publisher_error::MdnsSnafu.into_error(MdnsPublishersError { errors })) + } +} diff --git a/tests/publishers_surface.rs b/tests/publishers_surface.rs index 745dd74..89d769a 100644 --- a/tests/publishers_surface.rs +++ b/tests/publishers_surface.rs @@ -1,16 +1,11 @@ #[cfg(feature = "publishers")] #[test] -fn publishers_facade_exposes_endpoint_publisher_and_aggregate_types() { - let _ = core::any::type_name::(); - let _ = core::any::type_name::< - ddns::publishers::EndpointPublisher< - dyn dhttp_identity::identity::LocalAuthority + Send + Sync, - dyn dquic::qresolve::Resolve + Send + Sync, - >, - >(); - +fn publishers_facade_exposes_publisher_and_aggregate_types() { + let _ = core::any::type_name::(); let _ = core::any::type_name::(); - let _ = core::any::type_name::(); + let _ = core::any::type_name::(); + let _ = core::any::type_name::(); + let _ = core::any::type_name::(); } #[cfg(all(feature = "publishers", feature = "dquic-network"))] From 4a7295ad8419a694944542b4d3cbaa04a1a2d3e6 Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 20:26:13 +0800 Subject: [PATCH 20/29] refactor: publish endpoint addresses through publishers --- src/publishers.rs | 135 +++--------------------- src/publishers/dispatch.rs | 203 ------------------------------------ src/publishers/packet.rs | 43 -------- tests/publishers_surface.rs | 9 +- 4 files changed, 16 insertions(+), 374 deletions(-) delete mode 100644 src/publishers/dispatch.rs diff --git a/src/publishers.rs b/src/publishers.rs index 4ac735a..0ce35a2 100644 --- a/src/publishers.rs +++ b/src/publishers.rs @@ -3,40 +3,30 @@ mod address; #[cfg(feature = "publishers")] mod aggregate; #[cfg(feature = "publishers")] -mod dispatch; -#[cfg(feature = "publishers")] mod packet; #[cfg(feature = "publishers")] mod publisher; #[cfg(all(feature = "publishers", feature = "dquic-network"))] use std::{any::TypeId, net::SocketAddr, time::Duration}; -#[cfg(feature = "publishers")] -use std::{io, sync::Arc}; #[cfg(feature = "publishers")] pub use address::{ - AddressSelector, AddressView, FnAddressView, PublishAddressGroup, PublishAddressScope, - PublishAddresses, PublishScope, + AddressSelector, AddressView, FnAddressView, PublishAddressGroup, PublishAddresses, + PublishScope, }; #[cfg(all(feature = "publishers", feature = "dquic-network"))] pub use address::{AddressViewSource, EndpointBindingAddresses}; #[cfg(feature = "publishers")] pub use aggregate::{Publishers, PublishersError}; #[cfg(feature = "publishers")] -use dhttp_identity::{identity::LocalAuthority, name::Name}; -#[cfg(feature = "publishers")] -use dquic::qresolve::Resolve; +use dhttp_identity::name::Name; #[cfg(all(feature = "publishers", feature = "dquic-network"))] use dquic::{ qinterface::component::location::AddressEvent, qtraversal::nat::client::ClientLocationData, }; #[cfg(feature = "publishers")] -pub use packet::{EndpointRecordSigner, SignEndpointRecordsError}; -#[cfg(feature = "publishers")] pub use publisher::{Publisher, PublisherError}; -#[cfg(feature = "publishers")] -use snafu::Snafu; #[cfg(feature = "h3")] pub use crate::h3::H3Resolver as H3Publisher; @@ -45,34 +35,11 @@ pub use crate::http::HttpResolver as HttpPublisher; #[cfg(feature = "mdns")] pub use crate::mdns::MdnsPublisher; -#[cfg(feature = "publishers")] -#[derive(Debug, Snafu)] -#[snafu(module(create_publisher_error))] -pub enum CreatePublisherError { - #[snafu(display("anonymous endpoint cannot publish dns records"))] - AnonymousEndpoint, -} - -#[cfg(feature = "publishers")] -#[derive(Debug, Snafu)] -#[snafu(module)] -pub enum PublishOnceError { - #[snafu(display("no publisher resolver available"))] - NoPublisherResolver, - #[snafu(display("failed to sign endpoint records"))] - SignEndpointRecords { source: SignEndpointRecordsError }, - #[snafu(display("failed to publish dns packet with {publisher}"))] - Publish { - publisher: String, - source: io::Error, - }, -} - #[cfg(all(feature = "publishers", feature = "dquic-network"))] pub const DEFAULT_PUBLISH_INTERVAL: Duration = Duration::from_secs(20); /// Upper bound for a single publish attempt in the background loop. /// -/// Network changes can leave an in-flight H3 publish waiting on paths that no +/// Network changes can leave an in-flight publish waiting on paths that no /// longer exist. Timing out the attempt keeps consecutive publishes /// independent: the next interval observes the current bindings again. #[cfg(all(feature = "publishers", feature = "dquic-network"))] @@ -80,86 +47,24 @@ pub const DEFAULT_PUBLISH_TIMEOUT: Duration = Duration::from_secs(10); #[cfg(all(feature = "publishers", feature = "dquic-network"))] const PUBLISH_CHANGE_DEBOUNCE: Duration = Duration::from_millis(50); -#[cfg(feature = "publishers")] -#[derive(Clone)] -pub struct EndpointPublisher< - A: ?Sized = dyn LocalAuthority + Send + Sync, - R: ?Sized = dyn Resolve + Send + Sync, -> { - signer: EndpointRecordSigner, - resolver: Arc, -} - -#[cfg(feature = "publishers")] -impl EndpointPublisher -where - A: LocalAuthority + Send + Sync + ?Sized, - R: ?Sized, -{ - pub fn new(signer: EndpointRecordSigner, resolver: Arc) -> Self { - Self { signer, resolver } - } - - pub fn signer(&self) -> &EndpointRecordSigner { - &self.signer - } - - pub fn resolver(&self) -> &Arc { - &self.resolver - } -} - -#[cfg(feature = "publishers")] -impl EndpointPublisher -where - A: LocalAuthority + Send + Sync + ?Sized, - R: dispatch::ResolveDispatchTarget + ?Sized, -{ - pub async fn publish_once( - &self, - name: &Name<'_>, - addresses: &V, - ) -> Result<(), PublishOnceError> - where - V: AddressView + Sync, - { - let published = - dispatch::publish_to_resolver(self.signer(), self.resolver.as_ref(), name, addresses) - .await?; - if !published { - return publish_once_error::NoPublisherResolverSnafu.fail(); - } - Ok(()) - } -} - #[cfg(all(feature = "publishers", feature = "dquic-network"))] -pub type EndpointPublisherLoop = EndpointPublicationLoop< - dyn LocalAuthority + Send + Sync, - dyn Resolve + Send + Sync, - EndpointBindingAddresses, ->; - -#[cfg(all(feature = "publishers", feature = "dquic-network"))] -pub struct EndpointPublicationLoop { +pub struct EndpointPublicationLoop { name: Name<'static>, - publisher: EndpointPublisher, + publishers: Publishers, source: S, interval: Duration, publish_timeout: Duration, } #[cfg(all(feature = "publishers", feature = "dquic-network"))] -impl std::fmt::Debug for EndpointPublicationLoop +impl std::fmt::Debug for EndpointPublicationLoop where - A: LocalAuthority + Send + Sync + ?Sized, - R: ?Sized, S: std::fmt::Debug, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("EndpointPublicationLoop") .field("name", &self.name) - .field("signer", self.publisher.signer()) + .field("publishers", &self.publishers) .field("source", &self.source) .field("interval", &self.interval) .field("publish_timeout", &self.publish_timeout) @@ -168,16 +73,14 @@ where } #[cfg(all(feature = "publishers", feature = "dquic-network"))] -impl EndpointPublicationLoop +impl EndpointPublicationLoop where - A: LocalAuthority + Send + Sync + ?Sized, - R: dispatch::ResolveDispatchTarget + ?Sized, S: AddressViewSource + Sync, { - pub fn new(name: Name<'static>, publisher: EndpointPublisher, source: S) -> Self { + pub fn new(name: Name<'static>, publishers: Publishers, source: S) -> Self { Self { name, - publisher, + publishers, source, interval: DEFAULT_PUBLISH_INTERVAL, publish_timeout: DEFAULT_PUBLISH_TIMEOUT, @@ -215,7 +118,6 @@ where } _ = &mut interval => { interval.as_mut().reset(tokio::time::Instant::now() + self.interval); - self.clear_publish_state(); current_publish = self.new_publish_loop_future(); } event = locations.recv() => { @@ -229,7 +131,6 @@ where continue; } - self.clear_publish_state(); current_publish = self.new_publish_loop_future(); } } @@ -256,7 +157,7 @@ where let view = self.source.address_view(); match tokio::time::timeout( self.publish_timeout, - self.publisher.publish_once(&self.name, &view), + self.publishers.publish(&self.name, &view), ) .await { @@ -265,12 +166,10 @@ where true } Ok(Err(error)) => { - let report = snafu::Report::from_error(&error); - tracing::warn!(error = %report, name = %self.name, "dns publish failed"); + tracing::warn!(error = %error, name = %self.name, "dns publish failed"); false } Err(_elapsed) => { - self.clear_publish_state(); tracing::warn!( timeout_ms = self.publish_timeout.as_millis(), name = %self.name, @@ -284,7 +183,7 @@ where fn location_event_requires_publish(event: &AddressEvent) -> bool { match event { AddressEvent::Upsert(data) => { - if let Some(bound_addr) = data.downcast_ref::>() { + if let Some(bound_addr) = data.downcast_ref::>() { return bound_addr.is_ok(); } if let Some(stun_addr) = data.downcast_ref::() { @@ -293,14 +192,10 @@ where false } AddressEvent::Remove(type_id) => { - *type_id == TypeId::of::>() + *type_id == TypeId::of::>() || *type_id == TypeId::of::() } AddressEvent::Closed => true, } } - - fn clear_publish_state(&self) { - dispatch::clear_resolver_publish_state(self.publisher.resolver().as_ref()); - } } diff --git a/src/publishers/dispatch.rs b/src/publishers/dispatch.rs deleted file mode 100644 index c465808..0000000 --- a/src/publishers/dispatch.rs +++ /dev/null @@ -1,203 +0,0 @@ -use std::any::Any; - -use dhttp_identity::{identity::LocalAuthority, name::Name}; -use dquic::qresolve::{Publish, Resolve}; -use snafu::{IntoError, ResultExt}; - -use super::{ - AddressSelector, AddressView, EndpointRecordSigner, PublishOnceError, publish_once_error, -}; -#[cfg(feature = "resolvers")] -use crate::resolvers::Resolvers; - -#[cfg(all(feature = "h3", feature = "dquic-network"))] -type DeferredH3Resolver = - crate::resolvers::deferred::DeferredResolver>; - -#[doc(hidden)] -pub trait ResolveDispatchTarget: Resolve { - fn as_resolve(&self) -> &(dyn Resolve + Send + Sync); - fn as_any(&self) -> &dyn Any; -} - -impl ResolveDispatchTarget for T -where - T: Resolve + Send + Sync + 'static, -{ - fn as_resolve(&self) -> &(dyn Resolve + Send + Sync) { - self - } - - fn as_any(&self) -> &dyn Any { - self - } -} - -impl ResolveDispatchTarget for dyn Resolve + Send + Sync { - fn as_resolve(&self) -> &(dyn Resolve + Send + Sync) { - self - } - - fn as_any(&self) -> &dyn Any { - self - } -} - -pub(crate) async fn publish_to_resolver( - signer: &EndpointRecordSigner, - resolver: &R, - name: &Name<'_>, - addresses: &V, -) -> Result -where - A: LocalAuthority + Send + Sync + ?Sized, - R: ResolveDispatchTarget + ?Sized, - V: AddressView + Sync, -{ - let any = resolver.as_any(); - - #[cfg(feature = "resolvers")] - if let Some(resolvers) = any.downcast_ref::() { - let mut published = false; - for resolver in resolvers.iter() { - published |= - publish_single_resolver(signer, resolver.as_ref(), name, addresses).await?; - } - return Ok(published); - } - - publish_single_resolver(signer, resolver.as_resolve(), name, addresses).await -} - -async fn publish_single_resolver( - signer: &EndpointRecordSigner, - resolver: &(dyn Resolve + Send + Sync), - name: &Name<'_>, - addresses: &V, -) -> Result -where - A: LocalAuthority + Send + Sync + ?Sized, - V: AddressView + Sync, -{ - let any = resolver as &dyn Any; - - #[cfg(not(any( - feature = "http", - all(feature = "h3", feature = "dquic-network"), - all(feature = "mdns", feature = "dquic-network") - )))] - { - let _ = any; - let _ = name; - let _ = addresses; - } - - #[cfg(feature = "http")] - if let Some(http) = any.downcast_ref::() { - publish_selected(signer, http, name, addresses, AddressSelector::WideArea).await?; - return Ok(true); - } - - #[cfg(all(feature = "h3", feature = "dquic-network"))] - if let Some(h3) = any.downcast_ref::>() { - publish_selected(signer, h3, name, addresses, AddressSelector::WideArea).await?; - return Ok(true); - } - - #[cfg(all(feature = "h3", feature = "dquic-network"))] - if let Some(h3) = any.downcast_ref::() { - let Some(h3) = h3.get() else { - return Err(publish_once_error::PublishSnafu { - publisher: h3.to_string(), - } - .into_error(std::io::Error::other( - "deferred h3 resolver has not been initialized", - ))); - }; - publish_selected(signer, h3, name, addresses, AddressSelector::WideArea).await?; - return Ok(true); - } - - #[cfg(all(feature = "mdns", feature = "dquic-network"))] - if let Some(mdns) = any.downcast_ref::() { - let mut published = false; - for bound in mdns.bound_resolvers() { - publish_selected( - signer, - &bound.resolver, - name, - addresses, - AddressSelector::LocalLink { - device: &bound.device, - family: bound.family, - }, - ) - .await?; - published = true; - } - return Ok(published); - } - - Ok(false) -} - -async fn publish_selected( - signer: &EndpointRecordSigner, - publisher: &(dyn Publish + Send + Sync), - name: &Name<'_>, - addresses: &V, - selector: AddressSelector<'_>, -) -> Result<(), PublishOnceError> -where - A: LocalAuthority + Send + Sync + ?Sized, - V: AddressView + Sync, -{ - let endpoints: Vec<_> = addresses.endpoints(selector).collect(); - let packet = signer - .signed_packet(name, &endpoints) - .await - .context(publish_once_error::SignEndpointRecordsSnafu)?; - tracing::debug!( - publisher = %publisher, - name = %name, - endpoint_count = endpoints.len(), - packet_len = packet.len(), - "publishing dns packet" - ); - publisher - .publish(name.as_str(), &packet) - .await - .context(publish_once_error::PublishSnafu { - publisher: publisher.to_string(), - }) -} - -pub(crate) fn clear_resolver_publish_state(resolver: &R) -where - R: ResolveDispatchTarget + ?Sized, -{ - clear_single_resolver_publish_state(resolver.as_resolve()); -} - -fn clear_single_resolver_publish_state(resolver: &(dyn Resolve + Send + Sync)) { - let any = resolver as &dyn Any; - - #[cfg(feature = "resolvers")] - if let Some(resolvers) = any.downcast_ref::() { - for resolver in resolvers.iter() { - clear_single_resolver_publish_state(resolver.as_ref()); - } - } - - #[cfg(all(feature = "h3", feature = "dquic-network"))] - if let Some(h3) = any.downcast_ref::>() { - h3.clear_pool(); - } - - #[cfg(all(feature = "h3", feature = "dquic-network"))] - if let Some(h3) = any.downcast_ref::() - && let Some(h3) = h3.get() - { - h3.clear_pool(); - } -} diff --git a/src/publishers/packet.rs b/src/publishers/packet.rs index 3369472..03f39b0 100644 --- a/src/publishers/packet.rs +++ b/src/publishers/packet.rs @@ -13,49 +13,6 @@ pub enum EncodeEndpointPacketError { EncodeEndpoint, } -#[allow(dead_code)] -pub type SignEndpointRecordsError = EncodeEndpointPacketError; - -#[allow(dead_code)] -#[derive(Clone)] -pub struct EndpointRecordSigner { - authority: std::sync::Arc, -} - -impl std::fmt::Debug for EndpointRecordSigner -where - A: dhttp_identity::identity::LocalAuthority, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("EndpointRecordSigner") - .field("authority", &self.authority.name()) - .finish() - } -} - -#[allow(dead_code)] -impl EndpointRecordSigner -where - A: dhttp_identity::identity::LocalAuthority + Send + Sync + ?Sized, -{ - pub fn new(authority: std::sync::Arc) -> Self { - Self { authority } - } - - pub fn authority(&self) -> &std::sync::Arc { - &self.authority - } - - pub async fn signed_packet( - &self, - name: &Name<'_>, - endpoints: &[EndpointAddr], - ) -> Result, SignEndpointRecordsError> { - let _ = &self.authority; - endpoint_packet(name, endpoints.iter().copied()) - } -} - pub(crate) fn endpoint_packet( name: &Name<'_>, endpoints: impl IntoIterator, diff --git a/tests/publishers_surface.rs b/tests/publishers_surface.rs index 89d769a..526e0a7 100644 --- a/tests/publishers_surface.rs +++ b/tests/publishers_surface.rs @@ -13,14 +13,7 @@ fn publishers_facade_exposes_publisher_and_aggregate_types() { fn publishers_facade_exposes_network_publication_loop_surface() { let _ = ddns::publishers::DEFAULT_PUBLISH_INTERVAL; let _ = ddns::publishers::DEFAULT_PUBLISH_TIMEOUT; - let _ = core::any::type_name::(); - let _ = core::any::type_name::(); - let _ = core::any::type_name::(); let _ = core::any::type_name::< - ddns::publishers::EndpointPublicationLoop< - dyn dhttp_identity::identity::LocalAuthority + Send + Sync, - dyn dquic::qresolve::Resolve + Send + Sync, - ddns::publishers::EndpointBindingAddresses, - >, + ddns::publishers::EndpointPublicationLoop, >(); } From f603b1f591895b57f3f9b038b76057e27b221b7a Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 20:41:22 +0800 Subject: [PATCH 21/29] docs: update publisher facade README --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 5f64bb6..4a05529 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ ddns = { package = "dyns", version = "0.3.0" } | `ddns::http` | DNS-over-HTTP backend implementation. | | `ddns::mdns` | RFC 6762 multicast DNS transport plus LAN resolver/publisher backend implementation. | | `ddns::resolvers` | Resolver facade: backend re-exports, resolver chains, and `Resolvers` aggregation. | -| `ddns::publishers` | Publisher facade: backend re-exports, endpoint record signing, and endpoint publication helpers. | +| `ddns::publishers` | Publisher facade: backend re-exports, scoped publisher atoms, `Publishers` aggregation, and endpoint publication helpers. | ## Features @@ -29,7 +29,7 @@ The default feature set is empty. | Feature | Enables | | --- | --- | | `resolvers` | Resolver aggregation types such as `Resolvers`, `ResolversBuilder`, and `DnsScheme`. | -| `publishers` | Endpoint publication aggregation and signing helpers such as `EndpointPublisher`, `EndpointPublicationLoop`, and `PublishAddresses`. | +| `publishers` | Scoped publication helpers such as `Publisher`, `Publishers`, `PublishScope`, `EndpointPublicationLoop`, and `PublishAddresses`; backend `Publish` implementations own any required signing. | | `dquic-network` | `h3x`/`dquic` network-backed publication helpers such as `EndpointBindingAddresses`; meaningful together with `publishers`, and also used by mDNS resolver aggregation. | | `h3` | DNS-over-HTTP/3 backend surface (`ddns::h3`, plus `H3Resolver` / `H3Publisher` re-exports from the facades). | | `http` | DNS-over-HTTP backend surface (`ddns::http`, plus `HttpResolver` / `HttpPublisher` re-exports from the facades). | @@ -62,7 +62,7 @@ use ddns::resolvers::Resolvers; use futures::StreamExt; #[tokio::main] -async fn main() -> Result<(), ddns::resolvers::DnsErrors> { +async fn main() -> Result<(), ddns::resolvers::ResolversError> { let resolvers = Resolvers::builder().system().build(); let mut endpoints = resolvers.lookup("demo.example.dhttp.net").await?; From 11124f74c2aca38831b1563772b4f4781cd145ec Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 20:42:36 +0800 Subject: [PATCH 22/29] refactor: keep publisher items before tests --- src/publishers/publisher.rs | 76 ++++++++++++++++++------------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/src/publishers/publisher.rs b/src/publishers/publisher.rs index 1c3b1e0..1ae6c64 100644 --- a/src/publishers/publisher.rs +++ b/src/publishers/publisher.rs @@ -157,6 +157,44 @@ where }) } +#[cfg(all(feature = "mdns", feature = "dquic-network"))] +async fn publish_mdns( + resolvers: &crate::mdns::MdnsResolvers, + name: &Name<'_>, + view: &V, +) -> Result<(), PublisherError> +where + V: AddressView + Sync, +{ + let bound_resolvers = resolvers.bound_resolvers(); + if bound_resolvers.is_empty() { + tracing::debug!(name = %name, "no mdns publishers currently bound"); + return Ok(()); + } + + let mut errors = Vec::new(); + let mut succeeded = false; + for bound in bound_resolvers { + let scope = PublishScope::LocalLink { + device: bound.device.clone().into(), + family: bound.family, + }; + match publish_selected(&bound.resolver, &scope, name, view).await { + Ok(()) => succeeded = true, + Err(PublisherError::Publish { source, .. }) => { + errors.push((bound.resolver.to_string(), source)); + } + Err(error) => return Err(error), + } + } + + if succeeded { + Ok(()) + } else { + Err(publisher_error::MdnsSnafu.into_error(MdnsPublishersError { errors })) + } +} + #[cfg(test)] mod tests { use std::{ @@ -328,41 +366,3 @@ mod tests { ); } } - -#[cfg(all(feature = "mdns", feature = "dquic-network"))] -async fn publish_mdns( - resolvers: &crate::mdns::MdnsResolvers, - name: &Name<'_>, - view: &V, -) -> Result<(), PublisherError> -where - V: AddressView + Sync, -{ - let bound_resolvers = resolvers.bound_resolvers(); - if bound_resolvers.is_empty() { - tracing::debug!(name = %name, "no mdns publishers currently bound"); - return Ok(()); - } - - let mut errors = Vec::new(); - let mut succeeded = false; - for bound in bound_resolvers { - let scope = PublishScope::LocalLink { - device: bound.device.clone().into(), - family: bound.family, - }; - match publish_selected(&bound.resolver, &scope, name, view).await { - Ok(()) => succeeded = true, - Err(PublisherError::Publish { source, .. }) => { - errors.push((bound.resolver.to_string(), source)); - } - Err(error) => return Err(error), - } - } - - if succeeded { - Ok(()) - } else { - Err(publisher_error::MdnsSnafu.into_error(MdnsPublishersError { errors })) - } -} From 9dcac74a312475e31cb42c95da495d0e6a3afec5 Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 21:25:39 +0800 Subject: [PATCH 23/29] fix: gate query example behind h3 feature --- Cargo.toml | 1 + examples/README.md | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e0a51af..9b2d3cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -101,3 +101,4 @@ required-features = ["h3"] [[example]] name = "query" path = "examples/query.rs" +required-features = ["h3"] diff --git a/examples/README.md b/examples/README.md index 9130094..f87a5d3 100644 --- a/examples/README.md +++ b/examples/README.md @@ -7,7 +7,7 @@ whose library target remains `ddns`. | --- | --- | --- | | `mdns_discover` | `mdns` | Bind an mDNS service, publish sample local hosts, and print multicast packets. | | `mdns_query` | `mdns` | Query a DHTTP name over local mDNS. | -| `query` | none | Query a DNS-over-H3 server and decode the multi-record response. | +| `query` | `h3` | Query a DNS-over-H3 server and decode the multi-record response. | | `publish` | `h3` | Publish endpoint `E` records to a DNS-over-H3 server using client mTLS; H3 publish request headers are signed from the client endpoint identity. | Run all commands from the `ddns/` repository. @@ -36,7 +36,7 @@ The mDNS service name defaults to the build-time `DHTTP_MDNS_SERVICE` constant. ## DNS-over-H3 query ```bash -cargo run --example query -- \ +cargo run --example query --features h3 -- \ --server-ca /path/to/root.crt \ --host nat.genmeta.net ``` From 0d6833430958b58efdc7d1604e9d4fc9e3f03112 Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 21:57:09 +0800 Subject: [PATCH 24/29] refactor: group endpoints by certificate chain key --- examples/publish.rs | 5 +- src/core/parser/record/endpoint.rs | 103 +++++++++++++++--- src/h3/lookup.rs | 8 +- src/mdns.rs | 4 +- src/resolvers.rs | 2 +- .../{selector.rs => endpoint_group.rs} | 30 +++-- 6 files changed, 119 insertions(+), 33 deletions(-) rename src/resolvers/{selector.rs => endpoint_group.rs} (78%) diff --git a/examples/publish.rs b/examples/publish.rs index 2af5669..7684f49 100644 --- a/examples/publish.rs +++ b/examples/publish.rs @@ -141,7 +141,10 @@ async fn main() -> io::Result<()> { SocketAddr::V6(v6) => EndpointAddr::direct_v6(v6), }; endpoint.set_main(opt.is_main); - endpoint.set_sequence(opt.sequence); + endpoint.set_sequence( + dhttp_identity::certificate::CertificateSequence::try_from(opt.sequence) + .map_err(io::Error::other)?, + ); info!("Publishing endpoint: {:?}", endpoint); let mut hosts = std::collections::HashMap::new(); hosts.insert(opt.host.clone(), vec![endpoint]); diff --git a/src/core/parser/record/endpoint.rs b/src/core/parser/record/endpoint.rs index 5adf7e9..320352a 100644 --- a/src/core/parser/record/endpoint.rs +++ b/src/core/parser/record/endpoint.rs @@ -8,6 +8,7 @@ use std::{ use base64::Engine; use bytes::BufMut; +use dhttp_identity::certificate::{CertificateChainKey, CertificateChainKind, CertificateSequence}; use dquic::qbase::net::addr::EndpointAddr as DquicEndpointAddr; use nom::{ IResult, Parser, @@ -84,9 +85,9 @@ pub struct EndpointSignature { #[derive(Debug, Clone)] pub struct EndpointAddr { flags: u8, - /// Device sequence number used when multiple hosts share a domain (CLUSTERED). + /// Certificate-chain sequence used when multiple hosts share a domain (CLUSTERED). /// None means no sequence number. - sequence: Option, + sequence: Option, /// 1-minute load average (present when LOAD flag is set) load: Option, signature: Option, @@ -333,7 +334,7 @@ impl EndpointAddr { // sequence is only encoded when CLUSTERED flag is set if let Some(seq) = &self.sequence { - meta_len += seq.encoding_size(); + meta_len += VarInt::from_u32(seq.get()).encoding_size(); } if self.load.is_some() { @@ -366,13 +367,18 @@ impl EndpointAddr { self.agent } - pub fn sequence(&self) -> Option { - self.sequence.map(Into::into) + pub fn sequence(&self) -> Option { + self.sequence } - pub fn set_sequence(&mut self, sequence: u64) { - if sequence > 0 { - self.sequence = Some(VarInt::from_u64(sequence).expect("Sequence too large")); + pub fn normalized_sequence(&self) -> CertificateSequence { + self.sequence + .unwrap_or_else(|| CertificateSequence::from(0u8)) + } + + pub fn set_sequence(&mut self, sequence: CertificateSequence) { + if sequence.get() > 0 { + self.sequence = Some(sequence); self.set_clustered(true); } else { self.sequence = None; @@ -380,6 +386,15 @@ impl EndpointAddr { } } + pub fn certificate_chain_key(&self) -> CertificateChainKey { + let kind = if self.is_main() { + CertificateChainKind::Primary + } else { + CertificateChainKind::Secondary + }; + CertificateChainKey::new(self.normalized_sequence(), kind) + } + pub fn load(&self) -> Option { self.load } @@ -407,7 +422,7 @@ impl EndpointAddr { // Sequence is only written when CLUSTERED is set if let Some(seq) = &self.sequence { - buf.put_varint(*seq); + buf.put_varint(VarInt::from_u32(seq.get())); } // Write primary address @@ -468,7 +483,13 @@ pub fn be_endpoint_addr(input: &[u8]) -> nom::IResult<&[u8], EndpointAddr> { // Sequence number is only present when CLUSTERED is set let (remain, sequence) = if is_clustered { let (remain, seq) = be_varint(remain)?; - (remain, Some(seq)) + let sequence = match CertificateSequence::try_from(seq.into_inner()) { + Ok(sequence) => sequence, + Err(_error) => { + return Err(nom::Err::Failure(make_error(remain, ErrorKind::TooLarge))); + } + }; + (remain, Some(sequence)) } else { (remain, None) }; @@ -765,7 +786,7 @@ pub async fn sign_endponit_address( ) -> Option { let mut ep: EndpointAddr = endpoint.try_into().ok()?; ep.set_main(server_id == 0); - ep.set_sequence(server_id as u64); + ep.set_sequence(CertificateSequence::from(server_id)); if let Some(authority) = authority { let _ = ep.sign_with_authority(authority).await; } @@ -786,6 +807,58 @@ mod tests { use super::*; + fn v4_outer() -> SocketAddrV4 { + SocketAddrV4::new(Ipv4Addr::new(203, 0, 113, 10), 4433) + } + + #[test] + fn endpoint_certificate_chain_key_normalizes_missing_sequence() { + let mut endpoint = EndpointAddr::direct_v4(v4_outer()); + endpoint.set_main(true); + + let key = endpoint.certificate_chain_key(); + + assert_eq!( + key.kind(), + dhttp_identity::certificate::CertificateChainKind::Primary + ); + assert_eq!(key.sequence().get(), 0); + assert_eq!(key.to_string(), "primary:0"); + } + + #[test] + fn endpoint_certificate_chain_key_uses_present_sequence() { + let mut endpoint = EndpointAddr::direct_v4(v4_outer()); + endpoint.set_main(false); + endpoint.set_sequence( + dhttp_identity::certificate::CertificateSequence::try_from(7u32).unwrap(), + ); + + let key = endpoint.certificate_chain_key(); + + assert_eq!( + key.kind(), + dhttp_identity::certificate::CertificateChainKind::Secondary + ); + assert_eq!(key.sequence().get(), 7); + assert_eq!(key.to_string(), "secondary:7"); + } + + #[test] + fn endpoint_parser_rejects_over_range_certificate_sequence() { + let sequence = crate::core::parser::varint::VarInt::from_u64( + dhttp_identity::certificate::CertificateSequence::MAX as u64 + 1, + ) + .unwrap(); + let mut packet = BytesMut::new(); + packet.put_u8(EndpointAddr::FLAG_MAIN | EndpointAddr::FLAG_CLUSTERED); + packet.put_varint(sequence); + packet.put_u16(v4_outer().port()); + packet.put_slice(&v4_outer().ip().octets()); + + assert!(be_endpoint_addr(&packet).is_err()); + } + #[test] fn legacy_endpoint_v4_direct_without_meta() { let port = 5353u16; @@ -900,7 +973,7 @@ mod tests { // IPv4 direct, MAIN + CLUSTERED flags EndpointAddr { flags: EndpointAddr::FLAG_MAIN | EndpointAddr::FLAG_CLUSTERED, - sequence: Some(VarInt::from_u32(0)), + sequence: Some(CertificateSequence::from(0u8)), load: None, signature: None, primary: v4_outer.into(), @@ -909,7 +982,7 @@ mod tests { // IPv4 NAT, CLUSTERED flag EndpointAddr { flags: EndpointAddr::FLAG_NAT | EndpointAddr::FLAG_CLUSTERED, - sequence: Some(VarInt::from_u32(127)), + sequence: Some(CertificateSequence::try_from(127u32).unwrap()), load: None, signature: None, primary: v4_outer.into(), @@ -920,7 +993,7 @@ mod tests { flags: EndpointAddr::FLAG_FAMILY | EndpointAddr::FLAG_MAIN | EndpointAddr::FLAG_CLUSTERED, - sequence: Some(VarInt::from_u32(128)), + sequence: Some(CertificateSequence::try_from(128u32).unwrap()), load: None, signature: None, primary: v6_outer.into(), @@ -931,7 +1004,7 @@ mod tests { flags: EndpointAddr::FLAG_FAMILY | EndpointAddr::FLAG_NAT | EndpointAddr::FLAG_CLUSTERED, - sequence: Some(VarInt::from_u64((1 << 62) - 1).unwrap()), + sequence: Some(CertificateSequence::try_from(16_384u32).unwrap()), load: None, signature: None, primary: v6_outer.into(), diff --git a/src/h3/lookup.rs b/src/h3/lookup.rs index 4118164..0b61d50 100644 --- a/src/h3/lookup.rs +++ b/src/h3/lookup.rs @@ -80,7 +80,7 @@ impl LookupRecords { } Ok(Self { - endpoints: crate::resolvers::selector::selected_endpoint_addrs(endpoint_records), + endpoints: crate::resolvers::endpoint_group::selected_endpoint_addrs(endpoint_records), }) } } @@ -221,11 +221,13 @@ mod tests { wire::{MultiResponse, ResponseRecord}, }; - fn direct(addr: &str, main: bool, sequence: u64) -> DnsEndpointAddr { + fn direct(addr: &str, main: bool, sequence: u32) -> DnsEndpointAddr { let socket: SocketAddrV4 = addr.parse().expect("socket addr"); let mut endpoint = DnsEndpointAddr::direct_v4(socket); endpoint.set_main(main); - endpoint.set_sequence(sequence); + endpoint.set_sequence( + dhttp_identity::certificate::CertificateSequence::try_from(sequence).unwrap(), + ); endpoint } diff --git a/src/mdns.rs b/src/mdns.rs index 03ce041..fce3ee9 100644 --- a/src/mdns.rs +++ b/src/mdns.rs @@ -59,7 +59,7 @@ impl Resolve for MdnsResolver { let source = self.source(); self.query(name.to_owned()) .map_ok(move |list| { - let endpoints = crate::resolvers::selector::selected_endpoint_addrs(list); + let endpoints = crate::resolvers::endpoint_group::selected_endpoint_addrs(list); stream::iter(endpoints.into_iter().map(move |ep| (source.clone(), ep))).boxed() }) .boxed() @@ -295,7 +295,7 @@ impl MdnsResolvers { ); } - let records = crate::resolvers::selector::selected_endpoint_records(records); + let records = crate::resolvers::endpoint_group::selected_endpoint_records(records); Ok(stream::iter(records).boxed()) } diff --git a/src/resolvers.rs b/src/resolvers.rs index c1c5df1..abd58f1 100644 --- a/src/resolvers.rs +++ b/src/resolvers.rs @@ -92,7 +92,7 @@ impl std::str::FromStr for DnsScheme { } pub mod deferred; -pub(crate) mod selector; +pub(crate) mod endpoint_group; pub mod weak; #[cfg(feature = "resolvers")] diff --git a/src/resolvers/selector.rs b/src/resolvers/endpoint_group.rs similarity index 78% rename from src/resolvers/selector.rs rename to src/resolvers/endpoint_group.rs index c03ace4..659fac3 100644 --- a/src/resolvers/selector.rs +++ b/src/resolvers/endpoint_group.rs @@ -1,10 +1,10 @@ +use dhttp_identity::certificate::CertificateChainKey; use dquic::qbase::net::addr::EndpointAddr as DquicEndpointAddr; use crate::core::parser::record::endpoint::EndpointAddr as DnsEndpointAddr; -type Selector = (bool, u64); type TaggedEndpoint = (T, DquicEndpointAddr); -type EndpointGroup = (Selector, Vec>); +type EndpointGroup = (CertificateChainKey, Vec>); pub(crate) fn selected_endpoint_addrs( records: impl IntoIterator, @@ -21,19 +21,25 @@ pub(crate) fn selected_endpoint_records( let mut groups: Vec> = Vec::new(); for (tag, record) in records { - let selector = (record.is_main(), record.sequence().unwrap_or(0)); + let chain_key = record.certificate_chain_key(); let Ok(endpoint) = DquicEndpointAddr::try_from(record) else { continue; }; - if let Some((_key, endpoints)) = groups.iter_mut().find(|(key, _)| *key == selector) { + if let Some((_key, endpoints)) = groups.iter_mut().find(|(key, _)| *key == chain_key) { endpoints.push((tag, endpoint)); } else { - groups.push((selector, vec![(tag, endpoint)])); + groups.push((chain_key, vec![(tag, endpoint)])); } } - groups.sort_by_key(|((is_main, sequence), _)| (!*is_main, *sequence)); + groups.sort_by_key(|(chain_key, _)| { + let primary_rank = match chain_key.kind() { + dhttp_identity::certificate::CertificateChainKind::Primary => 0, + dhttp_identity::certificate::CertificateChainKind::Secondary => 1, + }; + (primary_rank, chain_key.sequence().get()) + }); groups .into_iter() @@ -44,20 +50,22 @@ pub(crate) fn selected_endpoint_records( #[cfg(test)] mod tests { + use dhttp_identity::certificate::CertificateSequence; + use crate::core::parser::record::endpoint::EndpointAddr; - fn direct(addr: &str, main: bool, sequence: u64) -> EndpointAddr { + fn direct(addr: &str, main: bool, sequence: u32) -> EndpointAddr { let mut endpoint = match addr.parse().unwrap() { std::net::SocketAddr::V4(addr) => EndpointAddr::direct_v4(addr), std::net::SocketAddr::V6(addr) => EndpointAddr::direct_v6(addr), }; endpoint.set_main(main); - endpoint.set_sequence(sequence); + endpoint.set_sequence(CertificateSequence::try_from(sequence).unwrap()); endpoint } #[test] - fn selected_endpoint_addrs_prefers_primary_group() { + fn selected_endpoint_addrs_prefers_primary_chain_key_group() { let secondary = direct("192.0.2.20:4433", false, 0); let primary_a = direct("192.0.2.10:4433", true, 2); let primary_b = direct("192.0.2.11:4433", true, 2); @@ -76,7 +84,7 @@ mod tests { } #[test] - fn selected_endpoint_addrs_uses_one_secondary_group_when_no_primary_exists() { + fn selected_endpoint_addrs_uses_one_secondary_chain_key_group_when_no_primary_exists() { let secondary_a = direct("192.0.2.20:4433", false, 5); let secondary_b = direct("192.0.2.21:4433", false, 5); let other_secondary = direct("192.0.2.30:4433", false, 6); @@ -106,7 +114,7 @@ mod tests { } #[test] - fn selected_endpoint_records_uses_one_group_across_sources() { + fn selected_endpoint_records_uses_one_chain_key_across_sources() { let selected = super::selected_endpoint_records([ ("wifi", direct("192.0.2.50:4433", true, 3)), ("ethernet", direct("192.0.2.51:4433", true, 4)), From 081f3b19e9526d51f108dd135aad12c607f01168 Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 23:27:28 +0800 Subject: [PATCH 25/29] fix: gate backend dependencies by feature --- Cargo.toml | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9b2d3cd..61fb661 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,12 +18,12 @@ name = "ddns" base64 = "0.22" bitfield-struct = "0.13" bytes = "1" -dashmap = "6" +dashmap = { version = "6", optional = true } dhttp-identity = { path = "../dhttp/identity", version = "0.1.0" } dquic = "0.5.1" -flume = "0.12" +flume = { version = "0.12", optional = true } futures = "0.3" -libc = "0.2" +libc = { version = "0.2", optional = true } nom = "8" rand = "0.10" ring = "0.17" @@ -34,7 +34,7 @@ rustls = { version = "0.23", default-features = false, features = [ rustls-native-certs = { version = "0.8", optional = true } rustls-pemfile = "2" snafu = "0.9" -socket2 = { version = "0.6", features = ["all"] } +socket2 = { version = "0.6", features = ["all"], optional = true } tokio = { version = "1", features = [ "time", "macros", @@ -67,6 +67,7 @@ resolvers = [] publishers = [] dquic-network = ["dep:h3x", "h3x/dquic"] h3 = [ + "dep:dashmap", "dep:h3x", "h3x/hyper", "dep:http", @@ -74,8 +75,8 @@ h3 = [ "dep:http-body-util", "dep:url", ] -http = ["dep:reqwest", "dep:rustls-native-certs"] -mdns = [] +http = ["dep:dashmap", "dep:reqwest", "dep:rustls-native-certs"] +mdns = ["dep:dashmap", "dep:flume", "dep:libc", "dep:socket2"] [dev-dependencies] clap = { version = "4", features = ["derive"] } From c5d2a82931073888cc31eb613aa1266b56790e3c Mon Sep 17 00:00:00 2001 From: eareimu Date: Tue, 16 Jun 2026 23:49:34 +0800 Subject: [PATCH 26/29] fix: target v2 dns api routes --- examples/README.md | 3 ++- examples/query.rs | 5 +++-- src/h3/lookup.rs | 35 ++++++++++++++++++++++++++++++-- src/h3/publish.rs | 15 +++++++++++--- src/http.rs | 50 ++++++++++++++++++++++++++++++++++++++++++---- 5 files changed, 96 insertions(+), 12 deletions(-) diff --git a/examples/README.md b/examples/README.md index f87a5d3..c95a62d 100644 --- a/examples/README.md +++ b/examples/README.md @@ -49,7 +49,7 @@ Options: | `--server-ca ` | PEM root CA used to verify the DNS server certificate. | | `--host ` | DNS host to query. Defaults to `nat.genmeta.net`. | -The example sends `GET /lookup?host=`. A successful server response is a +The example sends `GET /api/v2/lookup?host=`. A successful server response is a `ddns::core::wire::MultiResponse` body with header `x-record-format: multi`: ```text @@ -88,3 +88,4 @@ Options: The example imports `H3Publisher` from the `ddns::publishers` facade, but only needs the `h3` backend feature because backend publisher types are re-exported from the facade directly. H3 publish request headers are always signed with the configured client endpoint identity; callers no longer pass request signature fields. +Publish requests are sent to `POST /api/v2/publish?host=`. diff --git a/examples/query.rs b/examples/query.rs index c86f553..c8572b4 100644 --- a/examples/query.rs +++ b/examples/query.rs @@ -123,10 +123,11 @@ async fn main() -> Result<(), Box> { .await; let client = H3Endpoint::new(quic); - let url = format!("{}lookup?host={}", opt.base_url, opt.host); + let mut url = url::Url::parse(&opt.base_url)?.join("/api/v2/lookup")?; + url.query_pairs_mut().append_pair("host", &opt.host); info!(url = %url, "lookup.start"); - let uri: http::Uri = url.parse()?; + let uri: http::Uri = url.as_str().parse()?; let authority = uri .authority() .ok_or_else(|| { diff --git a/src/h3/lookup.rs b/src/h3/lookup.rs index 0b61d50..0f99bf6 100644 --- a/src/h3/lookup.rs +++ b/src/h3/lookup.rs @@ -13,6 +13,16 @@ use super::{ }; use crate::core::{parser::packet::be_packet, wire::be_multi_response}; +const LOOKUP_API_PATH: &str = "/api/v2/lookup"; + +fn lookup_url(base_url: &url::Url, name: &str) -> url::Url { + let mut url = base_url + .join(LOOKUP_API_PATH) + .expect("h3 dns lookup api path must be valid"); + url.query_pairs_mut().append_pair("host", name); + url +} + #[derive(Debug, Clone, PartialEq, Eq)] pub(super) struct LookupRecords { pub(super) endpoints: Vec, @@ -181,8 +191,7 @@ where return Ok(stream.boxed()); } - let mut url = self.base_url.join("lookup").expect("Invalid URL"); - url.set_query(Some(&format!("host={}", domain))); + let url = lookup_url(&self.base_url, domain); let uri: http::Uri = url.as_str().parse().expect("URL should be valid URI"); tracing::trace!("sending lookup request to {}", self.base_url); @@ -238,6 +247,28 @@ mod tests { MultiResponse::new([ResponseRecord::unsigned(packet, Vec::new())]).encode() } + #[test] + fn h3_lookup_url_targets_v2_api_from_origin_base() { + let base_url = url::Url::parse("https://dns.example.test:4433").expect("url"); + let url = lookup_url(&base_url, "demo.dhttp.net"); + + assert_eq!( + url.as_str(), + "https://dns.example.test:4433/api/v2/lookup?host=demo.dhttp.net" + ); + } + + #[test] + fn h3_lookup_url_does_not_duplicate_v2_base_path() { + let base_url = url::Url::parse("https://dns.example.test:4433/api/v2/").expect("url"); + let url = lookup_url(&base_url, "demo.dhttp.net"); + + assert_eq!( + url.as_str(), + "https://dns.example.test:4433/api/v2/lookup?host=demo.dhttp.net" + ); + } + #[test] fn lookup_records_select_primary_group() { let response = response_for( diff --git a/src/h3/publish.rs b/src/h3/publish.rs index ef6db4d..637e493 100644 --- a/src/h3/publish.rs +++ b/src/h3/publish.rs @@ -13,14 +13,23 @@ use crate::core::{ signature::{CONTENT_DIGEST_HEADER, SIGNATURE_HEADER, SIGNATURE_INPUT_HEADER, SignatureFields}, }; +const PUBLISH_API_PATH: &str = "/api/v2/publish"; + +fn publish_url(base_url: &url::Url, name: &str) -> url::Url { + let mut url = base_url + .join(PUBLISH_API_PATH) + .expect("h3 dns publish api path must be valid"); + url.query_pairs_mut().append_pair("host", name); + url +} + async fn signed_publish_request( base_url: &url::Url, name: &str, packet: &[u8], authority: &A, ) -> Result>, crate::core::signature::SignatureFieldsError> { - let mut url = base_url.join("publish").expect("h3 dns base URL is valid"); - url.set_query(Some(&format!("host={name}"))); + let url = publish_url(base_url, name); let uri: http::Uri = url .as_str() .parse() @@ -234,7 +243,7 @@ mod tests { assert_eq!(request.method(), http::Method::POST); assert_eq!( request.uri().to_string(), - "https://dns.example.test:4433/publish?host=demo.dhttp.net" + "https://dns.example.test:4433/api/v2/publish?host=demo.dhttp.net" ); assert!( request diff --git a/src/http.rs b/src/http.rs index 4531435..f6c0090 100644 --- a/src/http.rs +++ b/src/http.rs @@ -15,6 +15,9 @@ use crate::core::{ wire::be_multi_response, }; +const LOOKUP_API_PATH: &str = "/api/v2/lookup"; +const PUBLISH_API_PATH: &str = "/api/v2/publish"; + #[derive(Debug)] struct Record { addrs: Vec, @@ -28,6 +31,20 @@ pub struct HttpResolver { cached_records: DashMap, } +fn lookup_url(base_url: &Url, name: &str) -> Url { + api_url(base_url, LOOKUP_API_PATH, name) +} + +fn publish_url(base_url: &Url, name: &str) -> Url { + api_url(base_url, PUBLISH_API_PATH, name) +} + +fn api_url(base_url: &Url, path: &str, name: &str) -> Url { + let mut url = base_url.join(path).expect("ddns api path must be valid"); + url.query_pairs_mut().append_pair("host", name); + url +} + impl Display for HttpResolver { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( @@ -74,8 +91,7 @@ impl HttpResolver { packet: &[u8], signature_fields: &SignatureFields, ) -> Result<(), Error> { - let mut url = self.base_url.join("publish").expect("Invalid base URL"); - url.set_query(Some(&format!("host={name}"))); + let url = publish_url(&self.base_url, name); let mut request = self .http_client .post(url) @@ -194,8 +210,7 @@ impl Resolve for HttpResolver { } let response = self .http_client - .get(self.base_url.join("lookup").expect("Invalid URL")) - .query(&[("host", domain)]) + .get(lookup_url(&self.base_url, domain)) .send() .await; @@ -269,3 +284,30 @@ impl Resolve for HttpResolver { Box::pin(lookup.map_err(io::Error::other)) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn http_publish_url_targets_v2_api_from_origin_base() { + let base_url = Url::parse("https://dns.example.test").expect("url"); + let url = publish_url(&base_url, "demo.dhttp.net"); + + assert_eq!( + url.as_str(), + "https://dns.example.test/api/v2/publish?host=demo.dhttp.net" + ); + } + + #[test] + fn http_lookup_url_does_not_duplicate_v2_base_path() { + let base_url = Url::parse("https://dns.example.test/api/v2/").expect("url"); + let url = lookup_url(&base_url, "demo.dhttp.net"); + + assert_eq!( + url.as_str(), + "https://dns.example.test/api/v2/lookup?host=demo.dhttp.net" + ); + } +} From c5af4d1d19ca520bd1867ba82a7737c626269ce3 Mon Sep 17 00:00:00 2001 From: eareimu Date: Wed, 17 Jun 2026 01:08:00 +0800 Subject: [PATCH 27/29] release: prepare v0.4.0 --- Cargo.toml | 8 ++++---- README.md | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 61fb661..199fce5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "dyns" description = "DNS discovery and resolver support for DHTTP applications" -version = "0.3.0" +version = "0.4.0" edition = "2024" license = "Apache-2.0" repository = "https://github.com/genmeta/ddns" @@ -19,7 +19,7 @@ base64 = "0.22" bitfield-struct = "0.13" bytes = "1" dashmap = { version = "6", optional = true } -dhttp-identity = { path = "../dhttp/identity", version = "0.1.0" } +dhttp-identity = { git = "https://github.com/genmeta/dhttp.git", branch = "dev/v0.2.0", version = "0.2.0" } dquic = "0.5.1" flume = { version = "0.12", optional = true } futures = "0.3" @@ -47,7 +47,7 @@ tokio = { version = "1", features = [ tracing = "0.1" x509-parser = { version = "0.18", features = ["verify"] } -h3x = { path = "../h3x", default-features = false, optional = true } +h3x = { git = "https://github.com/genmeta/h3x.git", branch = "dev/v0.4.0", version = "0.4.0", default-features = false, optional = true } http = { version = "1", optional = true } http-body = { version = "1", optional = true } http-body-util = { version = "0.1", optional = true } @@ -80,7 +80,7 @@ mdns = ["dep:dashmap", "dep:flume", "dep:libc", "dep:socket2"] [dev-dependencies] clap = { version = "4", features = ["derive"] } -h3x = { path = "../h3x", default-features = false, features = ["dquic"] } +h3x = { git = "https://github.com/genmeta/h3x.git", branch = "dev/v0.4.0", version = "0.4.0", default-features = false, features = ["dquic"] } shellexpand = "3" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/README.md b/README.md index 4a05529..e4cf127 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ while `ddns::resolvers` and `ddns::publishers` act as facades for re-exports and aggregate helper types. ```toml -ddns = { package = "dyns", version = "0.3.0" } +ddns = { package = "dyns", version = "0.4.0" } ``` ## Crate layout From e6a4c10ef103c020dabe9bdc534e5e248d3bcf24 Mon Sep 17 00:00:00 2001 From: eareimu Date: Wed, 17 Jun 2026 03:42:43 +0800 Subject: [PATCH 28/29] release: converge upstream dependencies --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 199fce5..abb3316 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ base64 = "0.22" bitfield-struct = "0.13" bytes = "1" dashmap = { version = "6", optional = true } -dhttp-identity = { git = "https://github.com/genmeta/dhttp.git", branch = "dev/v0.2.0", version = "0.2.0" } +dhttp-identity = "0.2.0" dquic = "0.5.1" flume = { version = "0.12", optional = true } futures = "0.3" @@ -47,7 +47,7 @@ tokio = { version = "1", features = [ tracing = "0.1" x509-parser = { version = "0.18", features = ["verify"] } -h3x = { git = "https://github.com/genmeta/h3x.git", branch = "dev/v0.4.0", version = "0.4.0", default-features = false, optional = true } +h3x = { version = "0.4.0", default-features = false, optional = true } http = { version = "1", optional = true } http-body = { version = "1", optional = true } http-body-util = { version = "0.1", optional = true } @@ -80,7 +80,7 @@ mdns = ["dep:dashmap", "dep:flume", "dep:libc", "dep:socket2"] [dev-dependencies] clap = { version = "4", features = ["derive"] } -h3x = { git = "https://github.com/genmeta/h3x.git", branch = "dev/v0.4.0", version = "0.4.0", default-features = false, features = ["dquic"] } +h3x = { version = "0.4.0", default-features = false, features = ["dquic"] } shellexpand = "3" tracing-subscriber = { version = "0.3", features = ["env-filter"] } From 96f8f198ce1fcf91c18d529d12ec0bf9ebe5cfd6 Mon Sep 17 00:00:00 2001 From: eareimu Date: Wed, 17 Jun 2026 03:47:22 +0800 Subject: [PATCH 29/29] fix: gate resolver endpoint grouping helper --- src/resolvers.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/resolvers.rs b/src/resolvers.rs index abd58f1..fd95981 100644 --- a/src/resolvers.rs +++ b/src/resolvers.rs @@ -92,6 +92,7 @@ impl std::str::FromStr for DnsScheme { } pub mod deferred; +#[cfg(any(feature = "h3", feature = "mdns", test))] pub(crate) mod endpoint_group; pub mod weak;