From cfad5bd925991b21eba27f0b291a91d536769b10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20LIARD?= Date: Tue, 28 Apr 2026 22:30:30 +0200 Subject: [PATCH 1/5] chore(prek): drop redundant cargo-test and cargo-doctest from pre-push CI re-runs the full test suite (incl. doctests) on every PR via the .github/workflows/ci.yml tests job, so local pre-push duplication adds ~20 min per push without catching anything new. Pre-push hooks should be fast-fail; expensive checks belong on the CI server. Closes audit finding: silent productivity tax (pre-push duplication). --- prek.toml | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/prek.toml b/prek.toml index 82c09caa..99598f46 100644 --- a/prek.toml +++ b/prek.toml @@ -54,25 +54,10 @@ args = ["-c", "lychee --offline --no-progress --include-fragments docs/**/*.md R types = ["markdown"] pass_filenames = false -# Tests — expensive, runs only on pre-push -[[repos.hooks]] -id = "cargo-test" -name = "cargo test" -language = "system" -entry = "cargo nextest run --profile ci" -types = ["rust"] -pass_filenames = false -stages = ["pre-push"] - -# Doc tests — not covered by nextest -[[repos.hooks]] -id = "cargo-doctest" -name = "cargo doc tests" -language = "system" -entry = "cargo test --doc" -types = ["rust"] -pass_filenames = false -stages = ["pre-push"] +# NOTE: cargo-test and cargo-doctest were removed from pre-push to keep +# the gate fast-fail. CI re-runs the full suite on every PR (see +# `.github/workflows/ci.yml` `tests` job), so duplicating ~20 min of work +# locally only delays the developer feedback loop without adding signal. # Cargo deny — license & advisory & dependency checks # Only runs when Cargo.toml/lock or Rust files change. From fafba5e58d3e00deb9d7bfdfb7505093e5357ce1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20LIARD?= Date: Tue, 28 Apr 2026 22:31:23 +0200 Subject: [PATCH 2/5] docs(config): clarify is_enabled semantics and typo-safety contract Documents the three-state intent (true/false/absent) of ProviderConfig.is_enabled and the dependency on deny_unknown_fields (added in the next commit) to reject typos like enbaled = false at parse time. Behaviour is unchanged; this is purely contractual clarity to support the silent-typo-killer audit. Closes audit finding: silent typo killer on provider config. --- src/cli/config/providers.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/cli/config/providers.rs b/src/cli/config/providers.rs index 87c9ab77..586f49ae 100644 --- a/src/cli/config/providers.rs +++ b/src/cli/config/providers.rs @@ -119,7 +119,16 @@ pub struct ProviderConfig { } impl ProviderConfig { - /// Returns `true` if the provider is enabled (defaults to `true`). + /// Returns `true` if the provider is enabled. + /// + /// Semantics: + /// - `enabled = true` → enabled. + /// - `enabled = false` → disabled. + /// - `enabled` absent → enabled (sensible default for newly added blocks). + /// + /// Typo safety: `#[serde(deny_unknown_fields)]` on [`ProviderConfig`] + /// rejects misspelled keys (e.g. `enbaled`) at parse time, so an absent + /// `enabled` field genuinely means "not specified" rather than "typo'd". pub fn is_enabled(&self) -> bool { self.enabled.unwrap_or(true) } From 93401b4a3777ee11e730c9673038aeaa4f63be44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20LIARD?= Date: Tue, 28 Apr 2026 22:35:00 +0200 Subject: [PATCH 3/5] fix(config): reject unknown TOML fields in core config structs Adds #[serde(deny_unknown_fields)] to AppConfig and the major sub-structs (ProviderConfig, ModelConfig, TierConfig, RouterConfig, ScoringConfig, CacheConfig, BudgetConfig, DlpConfig, SecurityConfig). Without this guard, a typo like enbaled = false in a [[providers]] block silently parses (the unknown key is dropped) and the provider remains enabled with the wrong intent. With the guard, parsing fails loudly and the operator gets an actionable error pointing at the offending key. Tested with the full nextest suite (1268 tests) plus all doctests: no fixture, preset or example carries a stale field, so this is a pure tightening with no migration cost. Closes audit finding: silent typo killer on TOML config. --- src/cli/config/budget.rs | 1 + src/cli/config/cache.rs | 1 + src/cli/config/providers.rs | 1 + src/cli/config/routing.rs | 3 +++ src/cli/config/security.rs | 1 + src/features/dlp/config.rs | 1 + src/models/config.rs | 1 + src/routing/classify/classify.rs | 1 + 8 files changed, 10 insertions(+) diff --git a/src/cli/config/budget.rs b/src/cli/config/budget.rs index 414a2f26..4eaa6e74 100644 --- a/src/cli/config/budget.rs +++ b/src/cli/config/budget.rs @@ -6,6 +6,7 @@ use crate::cli::BudgetUsd; /// Budget configuration #[derive(Debug, Clone, Deserialize, Serialize, Default)] +#[serde(deny_unknown_fields)] pub struct BudgetConfig { /// Global monthly hard cap in USD (0 = unlimited) #[serde(default)] diff --git a/src/cli/config/cache.rs b/src/cli/config/cache.rs index b2e42554..201ddf66 100644 --- a/src/cli/config/cache.rs +++ b/src/cli/config/cache.rs @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize}; /// LLM response cache configuration #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct CacheConfig { /// Enable response caching (only for temperature=0 requests) #[serde(default)] diff --git a/src/cli/config/providers.rs b/src/cli/config/providers.rs index 586f49ae..88c240f8 100644 --- a/src/cli/config/providers.rs +++ b/src/cli/config/providers.rs @@ -28,6 +28,7 @@ pub enum AuthType { /// Provider configuration from TOML. #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] pub struct ProviderConfig { /// Unique provider name used in routing and logging. pub name: String, diff --git a/src/cli/config/routing.rs b/src/cli/config/routing.rs index d0992450..0538177f 100644 --- a/src/cli/config/routing.rs +++ b/src/cli/config/routing.rs @@ -9,6 +9,7 @@ use super::user::PresetConfig; /// Router configuration #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct RouterConfig { /// Default model for unclassified requests pub default: String, @@ -102,6 +103,7 @@ pub struct FanOutConfig { /// Model configuration with 1:N provider mappings #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct ModelConfig { /// External model name (used in API requests) pub name: String, @@ -203,6 +205,7 @@ pub struct TierMatchCondition { /// When the scoring heuristic classifies a request, the dispatch pipeline /// resolves providers from the matching tier instead of the default model mappings. #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct TierConfig { /// Tier name — must match a `ComplexityTier` variant (case-insensitive). pub name: String, diff --git a/src/cli/config/security.rs b/src/cli/config/security.rs index 27583a9c..c403952e 100644 --- a/src/cli/config/security.rs +++ b/src/cli/config/security.rs @@ -8,6 +8,7 @@ use super::default_true; /// Security configuration (wired into middleware stack) #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct SecurityConfig { /// Master switch for security middleware #[serde(default = "default_true")] diff --git a/src/features/dlp/config.rs b/src/features/dlp/config.rs index f9329651..dc81c20e 100644 --- a/src/features/dlp/config.rs +++ b/src/features/dlp/config.rs @@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize}; /// Top-level DLP configuration, mapped from `[dlp]` in TOML. #[derive(Debug, Clone, Deserialize, Serialize, Default)] +#[serde(deny_unknown_fields)] pub struct DlpConfig { /// Enables the DLP pipeline globally. #[serde(default)] diff --git a/src/models/config.rs b/src/models/config.rs index 63fe5196..1c87c1b5 100644 --- a/src/models/config.rs +++ b/src/models/config.rs @@ -25,6 +25,7 @@ use crate::features::tap::TapConfig; /// Application configuration #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct AppConfig { /// Config schema version (for forward compatibility) #[serde(default, skip_serializing_if = "Option::is_none")] diff --git a/src/routing/classify/classify.rs b/src/routing/classify/classify.rs index ebe280eb..38b70b2a 100644 --- a/src/routing/classify/classify.rs +++ b/src/routing/classify/classify.rs @@ -84,6 +84,7 @@ impl Default for ScoringThresholds { /// Scoring configuration combining weights and thresholds. #[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)] +#[serde(deny_unknown_fields)] pub struct ScoringConfig { /// Per-signal weights. pub weights: ScoringWeights, From 56675a67fb73de87d5ffd658b1a3e5f80988eb79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20LIARD?= Date: Tue, 28 Apr 2026 22:36:52 +0200 Subject: [PATCH 4/5] docs(config-guard): document deny-list rationale and log denied reloads MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Each entry in DENIED_SECTIONS / DENIED_KEYS now carries a short justification table covering why it can not be hot-reloaded — either because the data is sensitive (credentials, DLP rules) or because the consumer is constructed once at process start (TLS listener, secret backend, TEE attestation, FIPS gate). Adds tee, fips, server.tls and secrets.backend to the deny-list so the documented "static-init" rationale matches actual behaviour. Also emits an INFO log on every denied attempt telling the operator to restart instead of expecting the silent reload to apply. Adds two unit tests covering the new deny entries (tee/fips sections and server.tls / secrets.backend keys) and asserts that sibling keys in the same sections remain editable. Closes audit finding: hot-reload UX (silent ignore of denied edits). --- src/server/config_guard.rs | 80 +++++++++++++++++++++++++++++++++++--- 1 file changed, 74 insertions(+), 6 deletions(-) diff --git a/src/server/config_guard.rs b/src/server/config_guard.rs index b0a8132f..58993720 100644 --- a/src/server/config_guard.rs +++ b/src/server/config_guard.rs @@ -12,29 +12,76 @@ use std::sync::Arc; use tracing::info; /// Top-level TOML sections that are never writable via any config API. -const DENIED_SECTIONS: &[&str] = &["providers", "dlp"]; +/// +/// Each entry is denied because hot-reloading it cannot be done safely +/// at runtime — either the data is sensitive (and must travel through a +/// dedicated secret API), or the code path that consumes it is set up +/// once at process start and not re-initialised on `/api/config/reload`: +/// +/// | Section | Reason | +/// |-------------|-----------------------------------------------------------------------------------------| +/// | `providers` | Contains API keys; mutate via `grob connect` / secret backend, not the config API. | +/// | `dlp` | Security policy must not be weakened by an authenticated control-plane caller. | +/// | `tee` | TEE attestation runs at startup; flipping the mode mid-flight bypasses the gate. | +/// | `fips` | FIPS mode is checked once on init; toggling at runtime gives a false sense of compliance. | +/// +/// To change any of these the operator must edit `~/.grob/config.toml` +/// and restart the daemon. +const DENIED_SECTIONS: &[&str] = &["providers", "dlp", "tee", "fips"]; /// Per-section keys that are never writable via any config API. +/// +/// These are individual fields whose host section is otherwise editable, +/// but the field itself is either credential material or wired into a +/// non-reloadable subsystem: +/// +/// | Section.Key | Reason | +/// |--------------------|---------------------------------------------------------------------------------| +/// | `router.api_key` | Credential material — never round-trip through the config API. | +/// | `budget.api_key` | Same. | +/// | `cache.api_key` | Same. | +/// | `server.tls` | TLS listener is bound at startup; rebuilding it requires a daemon restart. | +/// | `secrets.backend` | The secret backend is constructed once and shared via `Arc`; swapping it at | +/// | | runtime would orphan in-flight readers and change credential resolution semantics. | const DENIED_KEYS: &[(&str, &str)] = &[ ("router", "api_key"), ("budget", "api_key"), ("cache", "api_key"), + ("server", "tls"), + ("secrets", "backend"), ]; /// Checks whether a (section, key) pair is blocked by the deny-list. /// -/// Returns `true` when the write must be rejected: -/// - The entire `providers` section (contains API keys). -/// - The entire `dlp` section (security must not be weakened). -/// - Any `api_key` field in any section. +/// Returns `true` when the write must be rejected. See [`DENIED_SECTIONS`] +/// and [`DENIED_KEYS`] for the rationale behind every entry. A denied +/// attempt is logged at INFO so the operator sees actionable guidance +/// (restart instead of expecting a silent reload to take effect). pub fn is_section_or_key_denied(section: &str, key: &str) -> bool { if DENIED_SECTIONS.contains(§ion) { + info!( + section = %section, + "config hot-reload: section is on the deny-list; restart the daemon to apply changes" + ); return true; } if key == "api_key" { + info!( + section = %section, + key = %key, + "config hot-reload: api_key fields cannot be set via the config API; use `grob connect` or the secret backend" + ); + return true; + } + if DENIED_KEYS.iter().any(|(s, k)| *s == section && *k == key) { + info!( + section = %section, + key = %key, + "config hot-reload: key is on the deny-list; restart the daemon to apply changes" + ); return true; } - DENIED_KEYS.iter().any(|(s, k)| *s == section && *k == key) + false } /// Validates a key update against the deny-list using [`ConfigSection`]. @@ -177,6 +224,27 @@ mod tests { assert!(!is_section_or_key_denied("cache", "ttl_secs")); } + #[test] + fn deny_static_init_sections() { + // tee and fips are checked once at startup; toggling them at runtime + // would bypass the gate without the operator realising. + assert!(is_section_or_key_denied("tee", "mode")); + assert!(is_section_or_key_denied("tee", "sealed_keys")); + assert!(is_section_or_key_denied("fips", "mode")); + assert!(is_section_or_key_denied("fips", "anything")); + } + + #[test] + fn deny_static_init_keys() { + // The TLS listener and secret backend are constructed once on + // process start; both require a daemon restart to swap. + assert!(is_section_or_key_denied("server", "tls")); + assert!(is_section_or_key_denied("secrets", "backend")); + // Sibling keys in the same sections must remain editable. + assert!(!is_section_or_key_denied("server", "host")); + assert!(!is_section_or_key_denied("server", "port")); + } + #[cfg(feature = "mcp")] mod mcp_compat { use super::*; From 8d78543d5cc3aee09b68346d48c1c7c8feacf8ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20LIARD?= Date: Tue, 28 Apr 2026 22:51:42 +0200 Subject: [PATCH 5/5] refactor: unify request errors + add audit log middleware MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the `AppError` + `ProviderError` split with a single `RequestError` enum that carries the upstream HTTP status verbatim (502/503/504) instead of flattening every provider failure to a generic 502 body. Introduces the canonical `RequestError::is_retryable` classifier as the single source of truth for retry/backoff decisions and removes three duplicate 429 detectors from `dispatch/retry.rs`. Also adds an Axum audit-log middleware (`audit_log_layer`) that emits an `AuditEvent::RequestProcessed` entry for every request lifecycle — including OAuth, config, and error paths that previously bypassed audit entirely. Uses an `AuditedAlready` response-extension marker so the dispatch path (which writes a richer DLP/risk/token-aware entry) is not double-logged. Closes the audit gap that allowed silent model enumeration via DLP probes (EU AI Act Article 6 / PCI DSS 3.4). Co-Authored-By: Claude Opus 4.7 (1M context) --- src/features/token_pricing/spend.rs | 10 + src/security/audit_log.rs | 4 + src/server/budget.rs | 11 +- src/server/config_api.rs | 22 +- src/server/config_guard.rs | 23 +- src/server/dispatch/mod.rs | 28 +- src/server/dispatch/provider_loop.rs | 22 +- src/server/dispatch/resolver.rs | 6 +- src/server/dispatch/retry.rs | 43 +- src/server/error.rs | 720 ++++++++++++++++-- src/server/handlers.rs | 126 ++- src/server/helpers.rs | 10 +- src/server/middleware.rs | 189 ++++- src/server/mod.rs | 22 +- tests/enterprise/snapshot_error_test.rs | 66 +- ...ror_test__snapshot_auth_revoked_error.snap | 5 + ..._test__snapshot_budget_exceeded_error.snap | 2 +- ..._error_test__snapshot_forbidden_error.snap | 5 + ...t_error_test__snapshot_provider_error.snap | 2 +- ...ror_test__snapshot_rate_limited_error.snap | 5 + ...ror_test__snapshot_unauthorized_error.snap | 5 + tests/integration/audit_middleware_test.rs | 244 ++++++ tests/integration/mod.rs | 1 + 23 files changed, 1385 insertions(+), 186 deletions(-) create mode 100644 tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_auth_revoked_error.snap create mode 100644 tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_forbidden_error.snap create mode 100644 tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_rate_limited_error.snap create mode 100644 tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_unauthorized_error.snap create mode 100644 tests/integration/audit_middleware_test.rs diff --git a/src/features/token_pricing/spend.rs b/src/features/token_pricing/spend.rs index e00ab0b1..eabc6eab 100644 --- a/src/features/token_pricing/spend.rs +++ b/src/features/token_pricing/spend.rs @@ -25,6 +25,10 @@ pub struct BudgetLimits { pub struct BudgetError { /// Human-readable budget exceeded message. pub message: String, + /// Configured budget limit in USD (whichever scope tripped first). + pub limit_usd: f64, + /// Recorded spend in USD at check time. + pub actual_usd: f64, } impl std::fmt::Display for BudgetError { @@ -216,6 +220,8 @@ impl SpendTracker { "Monthly budget for model '{}' reached: ${:.2}/${:.2}", model, spend, limit ), + limit_usd: limit, + actual_usd: spend, }); } } @@ -228,6 +234,8 @@ impl SpendTracker { "Monthly budget for provider '{}' reached: ${:.2}/${:.2}", provider, spend, limit ), + limit_usd: limit, + actual_usd: spend, }); } } @@ -240,6 +248,8 @@ impl SpendTracker { "Monthly global budget reached: ${:.2}/${:.2}", total, global_limit ), + limit_usd: global_limit, + actual_usd: total, }); } } diff --git a/src/security/audit_log.rs b/src/security/audit_log.rs index 04005928..bc1bd154 100644 --- a/src/security/audit_log.rs +++ b/src/security/audit_log.rs @@ -86,6 +86,10 @@ pub enum AuditEvent { HitApproval, /// TEE attestation report generated at startup. TeeAttestation, + /// HTTP request fully processed (emitted by the audit middleware once a + /// response has been produced — covers the entire request lifecycle from + /// authentication through dispatch and error handling). + RequestProcessed, } /// Immutable audit log entry. diff --git a/src/server/budget.rs b/src/server/budget.rs index ef919045..a40eee75 100644 --- a/src/server/budget.rs +++ b/src/server/budget.rs @@ -4,7 +4,7 @@ use crate::providers::AuthType; use std::sync::Arc; use tracing::warn; -use super::{AppError, AppState, ReloadableState}; +use super::{AppState, ReloadableState, RequestError}; /// Maximum retries per provider before falling back to the next mapping. /// NOTE: 2 retries (3 total attempts) balances latency vs resilience — most @@ -62,13 +62,13 @@ pub(crate) fn record_request_metrics(m: &RequestMetrics<'_>) { } } -/// Check budget before a request. Returns Err(AppError::BudgetExceeded) if any limit is hit. +/// Check budget before a request. Returns `Err(RequestError::BudgetExceeded)` if any limit is hit. pub(crate) async fn check_budget( state: &Arc, inner: &Arc, provider_name: &str, model_name: &str, -) -> Result<(), AppError> { +) -> Result<(), RequestError> { let budget_config = &inner.config.budget; let global_limit = budget_config.monthly_limit_usd.value(); @@ -92,7 +92,10 @@ pub(crate) async fn check_budget( provider_limit, model_limit, ) { - return Err(AppError::BudgetExceeded(e.message)); + return Err(RequestError::BudgetExceeded { + limit_usd: e.limit_usd, + actual_usd: e.actual_usd, + }); } if let Some(warning) = tracker.check_warnings( diff --git a/src/server/config_api.rs b/src/server/config_api.rs index eb1cd5cd..03bb7059 100644 --- a/src/server/config_api.rs +++ b/src/server/config_api.rs @@ -10,7 +10,7 @@ use std::sync::Arc; use tracing::{error, info, warn}; use super::config_guard::is_section_or_key_denied; -use super::{AppError, AppState, ReloadableState}; +use super::{AppState, ReloadableState, RequestError}; /// Redact an API key for safe display (show first 4 + last 4 chars) pub(crate) fn redact_api_key(key: &str) -> String { @@ -89,7 +89,7 @@ pub(crate) async fn get_config_json(State(state): State>) -> impl pub(crate) async fn update_config_json( State(state): State>, Json(mut new_config): Json, -) -> Result, AppError> { +) -> Result, RequestError> { // Remove null values (TOML doesn't support null) remove_null_values(&mut new_config); @@ -99,7 +99,7 @@ pub(crate) async fn update_config_json( // Whole-section deny check (providers, dlp). if is_section_or_key_denied(section, "") { warn!(section = %section, "config API: denied write to protected section"); - return Err(AppError::ParseError(format!( + return Err(RequestError::Forbidden(format!( "denied: section '{}' cannot be modified via the config API", section ))); @@ -109,7 +109,7 @@ pub(crate) async fn update_config_json( for key in inner.keys() { if is_section_or_key_denied(section, key) { warn!(section = %section, key = %key, "config API: denied write to protected key"); - return Err(AppError::ParseError(format!( + return Err(RequestError::Forbidden(format!( "denied: {}.{} cannot be modified via the config API", section, key ))); @@ -123,7 +123,7 @@ pub(crate) async fn update_config_json( let config_path = match &state.config_source { crate::cli::ConfigSource::File(p) => p, crate::cli::ConfigSource::Url(_) => { - return Err(AppError::ParseError( + return Err(RequestError::BadRequest( "Cannot save config: loaded from remote URL (read-only)".to_string(), )); } @@ -132,15 +132,15 @@ pub(crate) async fn update_config_json( // Read current config and merge the incoming JSON updates into it. let config_str = tokio::fs::read_to_string(config_path) .await - .map_err(|e| AppError::ParseError(format!("Failed to read config: {e}")))?; + .map_err(|e| RequestError::Internal(anyhow::anyhow!("Failed to read config: {e}")))?; let mut config: toml::Value = toml::from_str(&config_str) - .map_err(|e| AppError::ParseError(format!("Failed to parse config: {e}")))?; + .map_err(|e| RequestError::ParseError(format!("Failed to parse config: {e}")))?; // Update providers section if let Some(providers) = new_config.get("providers") { let providers_toml: toml::Value = serde_json::from_str(&providers.to_string()) - .map_err(|e| AppError::ParseError(format!("Failed to convert providers: {e}")))?; + .map_err(|e| RequestError::ParseError(format!("Failed to convert providers: {e}")))?; if let Some(table) = config.as_table_mut() { table.insert("providers".to_string(), providers_toml); @@ -150,7 +150,7 @@ pub(crate) async fn update_config_json( // Update models section if let Some(models) = new_config.get("models") { let models_toml: toml::Value = serde_json::from_str(&models.to_string()) - .map_err(|e| AppError::ParseError(format!("Failed to convert models: {e}")))?; + .map_err(|e| RequestError::ParseError(format!("Failed to convert models: {e}")))?; if let Some(table) = config.as_table_mut() { table.insert("models".to_string(), models_toml); @@ -192,9 +192,9 @@ pub(crate) async fn update_config_json( // Deserialise the merged TOML into AppConfig so we can validate and reload. let merged_toml_str = toml::to_string_pretty(&config) - .map_err(|e| AppError::ParseError(format!("Failed to serialize config: {e}")))?; + .map_err(|e| RequestError::Internal(anyhow::anyhow!("Failed to serialize config: {e}")))?; let merged_config: crate::models::config::AppConfig = toml::from_str(&merged_toml_str) - .map_err(|e| AppError::ParseError(format!("Invalid config after merge: {e}")))?; + .map_err(|e| RequestError::ParseError(format!("Invalid config after merge: {e}")))?; // Backup, write, and hot-reload via the shared pipeline. super::config_guard::persist_and_reload(&state, &merged_config).await?; diff --git a/src/server/config_guard.rs b/src/server/config_guard.rs index 58993720..667d02fc 100644 --- a/src/server/config_guard.rs +++ b/src/server/config_guard.rs @@ -114,11 +114,11 @@ pub fn is_key_denied(section: &ConfigSection, key: &str) -> bool { pub async fn persist_and_reload( state: &Arc, config: &crate::models::config::AppConfig, -) -> Result<(), super::AppError> { +) -> Result<(), super::RequestError> { let config_path = match &state.config_source { crate::cli::ConfigSource::File(p) => p, crate::cli::ConfigSource::Url(_) => { - return Err(super::AppError::ParseError( + return Err(super::RequestError::BadRequest( "Cannot save config: loaded from remote URL (read-only)".to_string(), )); } @@ -128,15 +128,18 @@ pub async fn persist_and_reload( let backup_path = config_path.with_extension("toml.backup"); tokio::fs::copy(config_path, &backup_path) .await - .map_err(|e| super::AppError::ParseError(format!("Failed to create backup: {e}")))?; + .map_err(|e| { + super::RequestError::Internal(anyhow::anyhow!("Failed to create backup: {e}")) + })?; // 2. Serialise and write - let toml_str = toml::to_string_pretty(config) - .map_err(|e| super::AppError::ParseError(format!("Failed to serialize config: {e}")))?; + let toml_str = toml::to_string_pretty(config).map_err(|e| { + super::RequestError::Internal(anyhow::anyhow!("Failed to serialize config: {e}")) + })?; - tokio::fs::write(config_path, toml_str) - .await - .map_err(|e| super::AppError::ParseError(format!("Failed to write config: {e}")))?; + tokio::fs::write(config_path, toml_str).await.map_err(|e| { + super::RequestError::Internal(anyhow::anyhow!("Failed to write config: {e}")) + })?; // 3. Hot-reload: rebuild router + provider registry from the new config reload_state(state, config.clone(), config_path)?; @@ -156,7 +159,7 @@ fn reload_state( state: &Arc, config: crate::models::config::AppConfig, _config_path: &Path, -) -> Result<(), super::AppError> { +) -> Result<(), super::RequestError> { let new_router = crate::routing::classify::Router::new(config.clone()); let secret_backend = @@ -170,7 +173,7 @@ fn reload_state( &config.server.timeouts, ) .map_err(|e| { - super::AppError::ProviderError(format!("Failed to rebuild provider registry: {e}")) + super::RequestError::Internal(anyhow::anyhow!("Failed to rebuild provider registry: {e}")) })?; let new_inner = Arc::new(super::ReloadableState::new( diff --git a/src/server/dispatch/mod.rs b/src/server/dispatch/mod.rs index ffb8c64e..940cac10 100644 --- a/src/server/dispatch/mod.rs +++ b/src/server/dispatch/mod.rs @@ -23,8 +23,8 @@ use std::sync::Arc; use super::{ calculate_cost, is_provider_subscription, log_audit, record_request_metrics, - resolve_provider_mappings, sanitize_provider_response_reported, AppError, AppState, - AuditCompliance, AuditParams, ReloadableState, RequestMetrics, + resolve_provider_mappings, sanitize_provider_response_reported, AppState, AuditCompliance, + AuditParams, ReloadableState, RequestError, RequestMetrics, }; use crate::features::watch::events::{DlpDirection, WatchEvent}; @@ -46,6 +46,9 @@ pub(crate) struct DispatchContext<'a> { pub headers: &'a HeaderMap, /// Message tracer context. None for OpenAI compat endpoint. pub trace_id: Option, + /// Audit-emitted flag — flipped by `log_audit_if_enabled` so the + /// outer audit middleware can skip writing a duplicate entry. + pub audited: std::sync::Arc, /// Resolved policy for this request (when policies feature is enabled). #[cfg(feature = "policies")] #[allow(dead_code)] @@ -169,6 +172,9 @@ impl DispatchContext<'_> { dlp_had_pii: entry.dlp_had_pii, dlp_had_redact_or_warn: entry.dlp_had_redact_or_warn, }); + // Flag so the outer audit middleware skips a duplicate entry. + self.audited + .store(true, std::sync::atomic::Ordering::Release); } } } @@ -244,7 +250,7 @@ pub(crate) fn resolve_grob_hint( pub(crate) async fn dispatch( ctx: &DispatchContext<'_>, request: &mut CanonicalRequest, -) -> Result { +) -> Result { // ── Step 0: Resolve complexity hint ── // Resolved up-front (borrows `request` immutably) but applied post-routing // so the client-declared tier overrides the algorithmic scorer. @@ -291,7 +297,7 @@ pub(crate) async fn dispatch( .inner .router .route(request) - .map_err(|e| AppError::RoutingError(e.to_string()))?; + .map_err(|e| RequestError::RoutingError(e.to_string()))?; // ── Step 3.5: Apply client-declared complexity hint ── // The hint (header / body metadata / MCP one-shot) overrides whatever tier @@ -398,7 +404,7 @@ async fn check_cache( fn scan_dlp_input( ctx: &DispatchContext<'_>, request: &mut CanonicalRequest, -) -> Result<(), AppError> { +) -> Result<(), RequestError> { let Some(ref dlp_engine) = ctx.dlp else { return Ok(()); }; @@ -476,7 +482,7 @@ fn scan_dlp_input( dlp_had_pii: false, dlp_had_redact_or_warn: false, }); - Err(AppError::DlpBlocked(format!("{}", block_err))) + Err(RequestError::DlpBlocked(format!("{}", block_err))) } } } @@ -488,7 +494,7 @@ async fn dispatch_fan_out( sorted_mappings: &[crate::cli::ModelMapping], fan_out_config: &crate::cli::FanOutConfig, decision: &crate::models::RouteDecision, -) -> Result { +) -> Result { let mut fan_request = request.clone(); ctx.sanitize_input(&mut fan_request); @@ -503,7 +509,11 @@ async fn dispatch_fan_out( Ok((response, provider_info)) => { handle_fan_out_success(ctx, response, &provider_info, decision).await } - Err(e) => Err(AppError::ProviderError(format!("Fan-out failed: {}", e))), + Err(e) => Err(RequestError::ProviderUpstream { + provider: "fan_out".to_string(), + status: 502, + body: Some(format!("Fan-out failed: {}", e)), + }), } } @@ -513,7 +523,7 @@ async fn handle_fan_out_success( mut response: ProviderResponse, provider_info: &[(String, String)], decision: &crate::models::RouteDecision, -) -> Result { +) -> Result { ctx.sanitize_output(&mut response); let latency_ms = ctx.start_time.elapsed().as_millis() as u64; diff --git a/src/server/dispatch/provider_loop.rs b/src/server/dispatch/provider_loop.rs index 9b441d19..773be8fc 100644 --- a/src/server/dispatch/provider_loop.rs +++ b/src/server/dispatch/provider_loop.rs @@ -17,11 +17,11 @@ //! //! After the loop exhausts the list, [`resolver::try_direct_provider_lookup`] //! offers a backward-compat path for unmapped models. A final audit entry -//! is written before returning `AppError::ProviderError`. +//! is written before returning `RequestError::ProviderUpstream`. use super::super::{ check_budget, format_route_type, inject_continuation_text, is_provider_subscription, - should_inject_continuation, AppError, + should_inject_continuation, RequestError, }; use super::resolver::{resolve_provider, try_direct_provider_lookup}; use super::retry::{ @@ -39,7 +39,7 @@ pub(super) async fn dispatch_provider_loop( sorted_mappings: &[crate::cli::ModelMapping], decision: &crate::models::RouteDecision, cache_key: &Option, -) -> Result { +) -> Result { // Re-sort mappings by adaptive score when scorer is enabled let rescored; let effective_mappings: &[crate::cli::ModelMapping] = @@ -156,7 +156,7 @@ pub(super) async fn dispatch_provider_loop( ); // Abort the fallback cascade: this is a user-actionable error, // not a transient provider failure. - return Err(AppError::AuthenticationError(format!( + return Err(RequestError::AuthRevoked(format!( "OAuth token for provider '{}' revoked. Run: grob connect --force-reauth. Details: {}", mapping.provider, msg ))); @@ -189,11 +189,15 @@ pub(super) async fn dispatch_provider_loop( "All provider mappings failed for model: {}", decision.model_name ); - Err(AppError::ProviderError(format!( - "All {} provider mappings failed for model: {}", - effective_mappings.len(), - decision.model_name - ))) + Err(RequestError::ProviderUpstream { + provider: "all".to_string(), + status: 502, + body: Some(format!( + "All {} provider mappings failed for model: {}", + effective_mappings.len(), + decision.model_name + )), + }) } /// Log the dispatch attempt info line (route type, stream mode, model -> provider). diff --git a/src/server/dispatch/resolver.rs b/src/server/dispatch/resolver.rs index 95e0cf54..6a5cfc6e 100644 --- a/src/server/dispatch/resolver.rs +++ b/src/server/dispatch/resolver.rs @@ -8,7 +8,7 @@ use std::sync::Arc; -use super::super::AppError; +use super::super::RequestError; use super::{DispatchContext, DispatchResult}; use tracing::info; @@ -101,7 +101,7 @@ pub(super) async fn try_direct_provider_lookup( ctx: &DispatchContext<'_>, request: &crate::models::CanonicalRequest, model_name: &str, -) -> Result, AppError> { +) -> Result, RequestError> { let Ok(provider) = ctx.inner.provider_registry.provider_for_model(model_name) else { return Ok(None); }; @@ -116,7 +116,7 @@ pub(super) async fn try_direct_provider_lookup( let mut response = provider .send_message(fallback_request) .await - .map_err(|e| AppError::ProviderError(e.to_string()))?; + .map_err(RequestError::from)?; response.model = original_model; Ok(Some(DispatchResult::Complete { diff --git a/src/server/dispatch/retry.rs b/src/server/dispatch/retry.rs index aa7fd690..8bf3df03 100644 --- a/src/server/dispatch/retry.rs +++ b/src/server/dispatch/retry.rs @@ -43,16 +43,31 @@ pub(super) struct ProviderAttempt<'a> { pub is_subscription: bool, } +/// Returns `true` when a provider error reports a 429 rate-limit upstream. +/// +/// Defers to the `RequestError::RateLimited` mapping rules so the +/// classification logic lives in one place: a 429 status code OR a 401 with a +/// `rate_limit_error` payload (Anthropic-style). Callers that need to know +/// specifically whether they hit a 429 (e.g. to rotate a key pool) consult +/// this helper rather than re-implement the matcher. +pub(super) fn is_upstream_rate_limit(e: &crate::providers::error::ProviderError) -> bool { + use crate::providers::error::ProviderError; + match e { + ProviderError::ApiError { status: 429, .. } => true, + ProviderError::ApiError { + status: 401, + message, + } => super::super::budget::is_rate_limit_payload(message), + _ => false, + } +} + /// Emit shared provider-error metrics (rate-limit counter + error counter). fn emit_provider_error_metrics( mapping: &crate::cli::ModelMapping, e: &crate::providers::error::ProviderError, ) { - let is_rate_limit = matches!( - e, - crate::providers::error::ProviderError::ApiError { status: 429, .. } - ); - if is_rate_limit { + if is_upstream_rate_limit(e) { warn!("Provider {} rate limited", mapping.provider); metrics::counter!( "grob_ratelimit_hits_total", @@ -177,11 +192,7 @@ pub(super) async fn dispatch_streaming( if is_auth_revoked_error(&e) { return Err(ProviderLoopAction::AuthRevoked(e.to_string())); } - let is_rate_limit = matches!( - e, - crate::providers::error::ProviderError::ApiError { status: 429, .. } - ); - if is_rate_limit { + if is_upstream_rate_limit(&e) { Err(ProviderLoopAction::RateLimited) } else { Err(ProviderLoopAction::Continue) @@ -297,11 +308,7 @@ pub(super) async fn dispatch_non_streaming( if classify_and_handle_error(ctx, attempt.mapping, &e, retry) { // On 429, try rotating to next pooled key before retrying. - let is_rate_limit = matches!( - e, - crate::providers::error::ProviderError::ApiError { status: 429, .. } - ); - if is_rate_limit && provider.rotate_key_pool() { + if is_upstream_rate_limit(&e) && provider.rotate_key_pool() { info!( "Provider {} rate-limited, rotated to next pooled key", attempt.mapping.provider @@ -315,11 +322,7 @@ pub(super) async fn dispatch_non_streaming( } // Before giving up on this provider, try key rotation for 429. - let is_rate_limit = matches!( - e, - crate::providers::error::ProviderError::ApiError { status: 429, .. } - ); - if is_rate_limit && provider.rotate_key_pool() { + if is_upstream_rate_limit(&e) && provider.rotate_key_pool() { info!( "Provider {} exhausted retries but rotated to next pooled key", attempt.mapping.provider diff --git a/src/server/error.rs b/src/server/error.rs index 2617c99e..d03cf5e3 100644 --- a/src/server/error.rs +++ b/src/server/error.rs @@ -1,86 +1,371 @@ +//! Unified request-level error taxonomy. +//! +//! Single source of truth for HTTP error responses. Replaces the old +//! `AppError` + `ProviderError` split that masked upstream HTTP status codes +//! behind a generic `502 Bad Gateway` body. +//! +//! Every variant maps to a precise HTTP status; the ECMA RFC-9457 inspired +//! body shape is `{ "error": { "type": ..., "message": ..., ...extras } }`. +//! +//! `is_retryable()` is the authoritative classifier for retry/backoff +//! logic — `dispatch/retry.rs` and the provider loop must consult this +//! single method rather than re-implement status-code matching. + use axum::{ http::StatusCode, response::{IntoResponse, Response}, Json, }; -/// Application error types — all variants carry a user-facing message string. +/// Unified error type for the request pipeline. +/// +/// Carries the upstream HTTP status when applicable so the client sees the +/// exact failure mode (`Bad Gateway`, `Service Unavailable`, …) rather than +/// an opaque "Provider error" string. #[derive(Debug)] -pub enum AppError { - /// Indicates no matching route or model for the request. - RoutingError(String), - /// Indicates a malformed or invalid request payload. +pub enum RequestError { + /// Indicates a malformed or invalid request payload (HTTP 400). + BadRequest(String), + /// Indicates the caller failed authentication (HTTP 401). + Unauthorized, + /// Indicates the caller is authenticated but the action is forbidden (HTTP 403). + Forbidden(String), + /// Indicates the requested resource does not exist (HTTP 404). + NotFound, + /// Indicates a JSON parse or schema validation failure (HTTP 400). ParseError(String), - /// Indicates an upstream provider returned an error. - ProviderError(String), - /// Indicates the monthly spend budget has been exceeded. - BudgetExceeded(String), - /// Indicates the DLP pipeline blocked the request. + /// Indicates routing could not resolve a model or provider (HTTP 400). + RoutingError(String), + /// Indicates the upstream provider rate-limited the request (HTTP 429). + RateLimited { + /// Provider that emitted the 429 (or the resolved alias). + provider: String, + /// Server-side hint for when to retry, when known. + retry_after_ms: Option, + }, + /// Indicates the upstream provider returned a non-success status. + /// + /// The original status is forwarded verbatim so the client sees the + /// actual failure mode (502/503/504/etc.) instead of a flattened error. + ProviderUpstream { + /// Provider name (e.g. `"anthropic"`). + provider: String, + /// Verbatim upstream HTTP status code. + status: u16, + /// Optional upstream body excerpt for diagnostics. + body: Option, + }, + /// Indicates a budget cap (global, provider, or model) was exceeded (HTTP 402). + BudgetExceeded { + /// Configured monthly limit in USD. + limit_usd: f64, + /// Actual recorded spend in USD at the time of the check. + actual_usd: f64, + }, + /// Indicates the DLP pipeline blocked the request (HTTP 400). DlpBlocked(String), - /// Indicates an upstream OAuth token is revoked or invalid (401 authentication_error). + /// Indicates an upstream OAuth credential was revoked (HTTP 401). /// - /// Surfaced to the client as a terminal 401 without fallback to sibling providers. - AuthenticationError(String), + /// Surfaces a terminal authentication error — the user must run + /// `grob connect --force-reauth`. Distinct from `Unauthorized`, which + /// covers the inbound caller's credential rather than an upstream's. + AuthRevoked(String), + /// Indicates an internal server failure (HTTP 500). + Internal(anyhow::Error), } -impl IntoResponse for AppError { - fn into_response(self) -> Response { - let (status, error_type, message) = match self { - AppError::RoutingError(msg) => (StatusCode::BAD_REQUEST, "error", msg), - AppError::ParseError(msg) => (StatusCode::BAD_REQUEST, "invalid_request_error", msg), - AppError::ProviderError(msg) => (StatusCode::BAD_GATEWAY, "error", msg), - AppError::BudgetExceeded(msg) => (StatusCode::PAYMENT_REQUIRED, "budget_exceeded", msg), - AppError::DlpBlocked(msg) => (StatusCode::BAD_REQUEST, "dlp_block", msg), - AppError::AuthenticationError(msg) => { - (StatusCode::UNAUTHORIZED, "authentication_error", msg) +impl RequestError { + /// Returns `true` when the error is transient and the dispatch loop should + /// retry (with exponential backoff) before falling back to the next provider. + /// + /// This is the SINGLE source of truth for retry classification — both the + /// retry loop and the rate-limit detector must call this method rather than + /// duplicate the status-code matching logic. + /// + /// Notably **excludes** `AuthRevoked` (a permanent 401 requires operator + /// action) but **includes** `RateLimited` and 5xx upstream failures. + pub fn is_retryable(&self) -> bool { + match self { + RequestError::RateLimited { .. } => true, + RequestError::ProviderUpstream { status, .. } => { + matches!(*status, 429 | 500 | 502 | 503 | 504) } - }; + // Network/transport failures bubble up here too; treat as retryable + // when the underlying error chain wraps a `reqwest::Error`. + RequestError::Internal(err) => err.downcast_ref::().is_some(), + _ => false, + } + } + + /// Returns the HTTP status code, error type tag, and message for this variant. + fn parts(&self) -> (StatusCode, &'static str, String) { + match self { + RequestError::BadRequest(msg) => ( + StatusCode::BAD_REQUEST, + "invalid_request_error", + msg.clone(), + ), + RequestError::Unauthorized => ( + StatusCode::UNAUTHORIZED, + "authentication_error", + "Missing or invalid credentials".to_string(), + ), + RequestError::Forbidden(msg) => { + (StatusCode::FORBIDDEN, "permission_error", msg.clone()) + } + RequestError::NotFound => ( + StatusCode::NOT_FOUND, + "not_found_error", + "Resource not found".to_string(), + ), + RequestError::ParseError(msg) => ( + StatusCode::BAD_REQUEST, + "invalid_request_error", + msg.clone(), + ), + RequestError::RoutingError(msg) => (StatusCode::BAD_REQUEST, "error", msg.clone()), + RequestError::RateLimited { provider, .. } => ( + StatusCode::TOO_MANY_REQUESTS, + "rate_limit_error", + format!("Provider '{}' rate-limited the request", provider), + ), + RequestError::ProviderUpstream { + provider, + status, + body, + } => { + // Forward the upstream status verbatim so the client sees the + // exact failure mode. Default to 502 for unmapped non-success codes. + let status_code = StatusCode::from_u16(*status).unwrap_or(StatusCode::BAD_GATEWAY); + let msg = body + .clone() + .unwrap_or_else(|| format!("Provider '{}' returned HTTP {}", provider, status)); + (status_code, "error", msg) + } + RequestError::BudgetExceeded { + limit_usd, + actual_usd, + } => ( + StatusCode::PAYMENT_REQUIRED, + "budget_exceeded", + format!( + "Budget exceeded: ${:.4} spent of ${:.4} limit", + actual_usd, limit_usd + ), + ), + RequestError::DlpBlocked(msg) => (StatusCode::BAD_REQUEST, "dlp_block", msg.clone()), + RequestError::AuthRevoked(msg) => ( + StatusCode::UNAUTHORIZED, + "authentication_error", + msg.clone(), + ), + RequestError::Internal(err) => { + (StatusCode::INTERNAL_SERVER_ERROR, "error", err.to_string()) + } + } + } + + /// Returns a stable string tag for the error variant — used in audit logs + /// and metrics labels (low-cardinality alternative to the message). + pub fn variant_tag(&self) -> &'static str { + match self { + RequestError::BadRequest(_) => "bad_request", + RequestError::Unauthorized => "unauthorized", + RequestError::Forbidden(_) => "forbidden", + RequestError::NotFound => "not_found", + RequestError::ParseError(_) => "parse_error", + RequestError::RoutingError(_) => "routing_error", + RequestError::RateLimited { .. } => "rate_limited", + RequestError::ProviderUpstream { .. } => "provider_upstream", + RequestError::BudgetExceeded { .. } => "budget_exceeded", + RequestError::DlpBlocked(_) => "dlp_blocked", + RequestError::AuthRevoked(_) => "auth_revoked", + RequestError::Internal(_) => "internal", + } + } +} - let body = Json(serde_json::json!({ +impl IntoResponse for RequestError { + fn into_response(self) -> Response { + let (status, error_type, message) = self.parts(); + + let mut body_obj = serde_json::json!({ "error": { "type": error_type, - "message": message + "message": message, + } + }); + + // Attach variant-specific extras (retry-after for rate limits, + // budget figures for budget overruns, upstream provider for 5xx). + match &self { + RequestError::RateLimited { + provider, + retry_after_ms, + } => { + if let Some(error) = body_obj.get_mut("error").and_then(|v| v.as_object_mut()) { + error.insert( + "provider".to_string(), + serde_json::Value::String(provider.clone()), + ); + if let Some(ms) = retry_after_ms { + error.insert("retry_after_ms".to_string(), serde_json::Value::from(*ms)); + } + } + } + RequestError::ProviderUpstream { + provider, status, .. + } => { + if let Some(error) = body_obj.get_mut("error").and_then(|v| v.as_object_mut()) { + error.insert( + "provider".to_string(), + serde_json::Value::String(provider.clone()), + ); + error.insert( + "upstream_status".to_string(), + serde_json::Value::from(*status), + ); + } + } + RequestError::BudgetExceeded { + limit_usd, + actual_usd, + } => { + if let Some(error) = body_obj.get_mut("error").and_then(|v| v.as_object_mut()) { + error.insert("limit_usd".to_string(), serde_json::Value::from(*limit_usd)); + error.insert( + "actual_usd".to_string(), + serde_json::Value::from(*actual_usd), + ); + } } - })); + _ => {} + } + + let mut response = (status, Json(body_obj)).into_response(); + // Mark response so the audit middleware can pick up the variant + // without re-parsing the body. + response + .extensions_mut() + .insert(ErrorVariantTag(self.variant_tag().to_string())); + + // Forward Retry-After header on 429 when known. + if let RequestError::RateLimited { + retry_after_ms: Some(ms), + .. + } = &self + { + // RFC 7231 Retry-After is in seconds; round up so we never advise + // a delay shorter than the upstream actually requested. + let secs = ms.div_ceil(1000).max(1); + if let Ok(value) = axum::http::HeaderValue::from_str(&secs.to_string()) { + response.headers_mut().insert("retry-after", value); + } + } - (status, body).into_response() + response } } -impl std::fmt::Display for AppError { +/// Marker stored in response extensions so middleware can read the error +/// variant tag without parsing the body. Value is the lowercase variant tag. +#[derive(Clone, Debug)] +pub struct ErrorVariantTag(pub String); + +impl std::fmt::Display for RequestError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let (_, _, message) = self.parts(); + write!(f, "{}: {}", self.variant_tag(), message) + } +} + +impl std::error::Error for RequestError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { - AppError::RoutingError(msg) => write!(f, "Routing error: {}", msg), - AppError::ParseError(msg) => write!(f, "Parse error: {}", msg), - AppError::ProviderError(msg) => write!(f, "Provider error: {}", msg), - AppError::BudgetExceeded(msg) => write!(f, "Budget exceeded: {}", msg), - AppError::DlpBlocked(msg) => write!(f, "DLP blocked: {}", msg), - AppError::AuthenticationError(msg) => write!(f, "Authentication error: {}", msg), + RequestError::Internal(err) => Some(err.as_ref()), + _ => None, } } } -impl std::error::Error for AppError {} +impl From for RequestError { + fn from(err: anyhow::Error) -> Self { + RequestError::Internal(err) + } +} + +impl From for RequestError { + fn from(err: crate::providers::error::ProviderError) -> Self { + use crate::providers::error::ProviderError; + match err { + ProviderError::ApiError { status, message } => match status { + 429 => RequestError::RateLimited { + provider: "upstream".to_string(), + retry_after_ms: None, + }, + 401 => { + if super::budget::is_rate_limit_payload(&message) { + // Anthropic emits `rate_limit_error` with HTTP 401 — treat + // as a transient rate-limit, not a revoked credential. + RequestError::RateLimited { + provider: "upstream".to_string(), + retry_after_ms: None, + } + } else { + RequestError::AuthRevoked(message) + } + } + _ => RequestError::ProviderUpstream { + provider: "upstream".to_string(), + status, + body: Some(message), + }, + }, + ProviderError::HttpError(e) => { + RequestError::Internal(anyhow::Error::new(e).context("HTTP request failed")) + } + ProviderError::SerializationError(e) => { + RequestError::ParseError(format!("serialization failed: {}", e)) + } + ProviderError::ModelNotSupported(model) => RequestError::RoutingError(format!( + "Model '{}' is not configured. Add a [[models]] entry or set pass_through = true on a provider.", + model + )), + ProviderError::ConfigError(msg) => { + RequestError::Internal(anyhow::anyhow!("Provider config error: {}", msg)) + } + ProviderError::AuthError(msg) => RequestError::AuthRevoked(msg), + ProviderError::NoProviderAvailable => { + RequestError::RoutingError("No provider available for this request".to_string()) + } + ProviderError::AllProvidersFailed(msg) => RequestError::ProviderUpstream { + provider: "all".to_string(), + status: 502, + body: Some(msg), + }, + } + } +} #[cfg(test)] mod tests { use super::*; - /// Extracts status code and parsed JSON body from an AppError response. - async fn error_response_parts(error: AppError) -> (StatusCode, serde_json::Value) { + /// Extracts status code and parsed JSON body from a `RequestError` response. + async fn error_response_parts(error: RequestError) -> (StatusCode, serde_json::Value) { let response = error.into_response(); let status = response.status(); let body_bytes = axum::body::to_bytes(response.into_body(), 1024 * 1024) .await .expect("invariant: in-memory body collection cannot fail"); let json: serde_json::Value = serde_json::from_slice(&body_bytes) - .expect("invariant: AppError always produces valid JSON"); + .expect("invariant: RequestError always produces valid JSON"); (status, json) } #[tokio::test] async fn parse_error_returns_400_with_invalid_request_type() { - let err = AppError::ParseError("invalid JSON at line 1".to_string()); + let err = RequestError::ParseError("invalid JSON at line 1".to_string()); let (status, json) = error_response_parts(err).await; assert_eq!(status, StatusCode::BAD_REQUEST); @@ -88,9 +373,40 @@ mod tests { assert_eq!(json["error"]["message"], "invalid JSON at line 1"); } + #[tokio::test] + async fn bad_request_returns_400() { + let err = RequestError::BadRequest("missing required field 'model'".to_string()); + let (status, json) = error_response_parts(err).await; + assert_eq!(status, StatusCode::BAD_REQUEST); + assert_eq!(json["error"]["type"], "invalid_request_error"); + } + + #[tokio::test] + async fn unauthorized_returns_401() { + let err = RequestError::Unauthorized; + let (status, json) = error_response_parts(err).await; + assert_eq!(status, StatusCode::UNAUTHORIZED); + assert_eq!(json["error"]["type"], "authentication_error"); + } + + #[tokio::test] + async fn forbidden_returns_403() { + let err = RequestError::Forbidden("policy denies model access".to_string()); + let (status, json) = error_response_parts(err).await; + assert_eq!(status, StatusCode::FORBIDDEN); + assert_eq!(json["error"]["type"], "permission_error"); + } + + #[tokio::test] + async fn not_found_returns_404() { + let err = RequestError::NotFound; + let (status, _) = error_response_parts(err).await; + assert_eq!(status, StatusCode::NOT_FOUND); + } + #[tokio::test] async fn routing_error_returns_400_with_error_type() { - let err = AppError::RoutingError("no matching model: gpt-unknown".to_string()); + let err = RequestError::RoutingError("no matching model: gpt-unknown".to_string()); let (status, json) = error_response_parts(err).await; assert_eq!(status, StatusCode::BAD_REQUEST); @@ -99,28 +415,72 @@ mod tests { } #[tokio::test] - async fn provider_error_returns_502() { - let err = AppError::ProviderError("upstream timeout".to_string()); - let (status, json) = error_response_parts(err).await; + async fn rate_limited_returns_429_with_retry_after() { + let err = RequestError::RateLimited { + provider: "anthropic".to_string(), + retry_after_ms: Some(2500), + }; + let response = err.into_response(); + let status = response.status(); + let retry_after = response + .headers() + .get("retry-after") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + let body_bytes = axum::body::to_bytes(response.into_body(), 1024 * 1024) + .await + .unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + + assert_eq!(status, StatusCode::TOO_MANY_REQUESTS); + assert_eq!(json["error"]["type"], "rate_limit_error"); + assert_eq!(json["error"]["provider"], "anthropic"); + assert_eq!(json["error"]["retry_after_ms"], 2500); + // 2500ms rounds up to 3 seconds. + assert_eq!(retry_after.as_deref(), Some("3")); + } + #[tokio::test] + async fn provider_upstream_502_forwards_status() { + let err = RequestError::ProviderUpstream { + provider: "openai".to_string(), + status: 502, + body: Some("upstream gateway timeout".to_string()), + }; + let (status, json) = error_response_parts(err).await; assert_eq!(status, StatusCode::BAD_GATEWAY); - assert_eq!(json["error"]["type"], "error"); - assert_eq!(json["error"]["message"], "upstream timeout"); + assert_eq!(json["error"]["upstream_status"], 502); + assert_eq!(json["error"]["provider"], "openai"); + } + + #[tokio::test] + async fn provider_upstream_503_forwards_status() { + let err = RequestError::ProviderUpstream { + provider: "openai".to_string(), + status: 503, + body: None, + }; + let (status, _) = error_response_parts(err).await; + assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE); } #[tokio::test] - async fn budget_exceeded_returns_402() { - let err = AppError::BudgetExceeded("monthly limit reached".to_string()); + async fn budget_exceeded_returns_402_with_figures() { + let err = RequestError::BudgetExceeded { + limit_usd: 100.0, + actual_usd: 105.5, + }; let (status, json) = error_response_parts(err).await; assert_eq!(status, StatusCode::PAYMENT_REQUIRED); assert_eq!(json["error"]["type"], "budget_exceeded"); - assert_eq!(json["error"]["message"], "monthly limit reached"); + assert_eq!(json["error"]["limit_usd"], 100.0); + assert_eq!(json["error"]["actual_usd"], 105.5); } #[tokio::test] async fn dlp_blocked_returns_400_with_dlp_block_type() { - let err = AppError::DlpBlocked("secret detected in prompt".to_string()); + let err = RequestError::DlpBlocked("secret detected in prompt".to_string()); let (status, json) = error_response_parts(err).await; assert_eq!(status, StatusCode::BAD_REQUEST); @@ -129,13 +489,12 @@ mod tests { } #[tokio::test] - async fn authentication_error_returns_401_with_authentication_error_type() { - let err = AppError::AuthenticationError( + async fn auth_revoked_returns_401() { + let err = RequestError::AuthRevoked( "OAuth token for provider 'anthropic' revoked. Run: grob connect --force-reauth" .to_string(), ); let (status, json) = error_response_parts(err).await; - assert_eq!(status, StatusCode::UNAUTHORIZED); assert_eq!(json["error"]["type"], "authentication_error"); assert!(json["error"]["message"] @@ -144,12 +503,257 @@ mod tests { .contains("grob connect --force-reauth")); } + #[tokio::test] + async fn internal_returns_500() { + let err = RequestError::Internal(anyhow::anyhow!("disk full")); + let (status, _) = error_response_parts(err).await; + assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR); + } + + // ── is_retryable() table-driven tests ── + #[test] - fn display_impl_includes_variant_prefix() { - let err = AppError::ParseError("bad input".to_string()); - assert_eq!(err.to_string(), "Parse error: bad input"); + fn rate_limited_is_retryable() { + let err = RequestError::RateLimited { + provider: "x".to_string(), + retry_after_ms: None, + }; + assert!(err.is_retryable()); + } - let err = AppError::RoutingError("no route".to_string()); - assert_eq!(err.to_string(), "Routing error: no route"); + #[test] + fn upstream_500_is_retryable() { + let err = RequestError::ProviderUpstream { + provider: "x".to_string(), + status: 500, + body: None, + }; + assert!(err.is_retryable()); + } + + #[test] + fn upstream_502_is_retryable() { + let err = RequestError::ProviderUpstream { + provider: "x".to_string(), + status: 502, + body: None, + }; + assert!(err.is_retryable()); + } + + #[test] + fn upstream_503_is_retryable() { + let err = RequestError::ProviderUpstream { + provider: "x".to_string(), + status: 503, + body: None, + }; + assert!(err.is_retryable()); + } + + #[test] + fn upstream_504_is_retryable() { + let err = RequestError::ProviderUpstream { + provider: "x".to_string(), + status: 504, + body: None, + }; + assert!(err.is_retryable()); + } + + #[test] + fn upstream_429_is_retryable() { + let err = RequestError::ProviderUpstream { + provider: "x".to_string(), + status: 429, + body: None, + }; + assert!(err.is_retryable()); + } + + #[test] + fn upstream_400_is_not_retryable() { + let err = RequestError::ProviderUpstream { + provider: "x".to_string(), + status: 400, + body: None, + }; + assert!(!err.is_retryable()); + } + + #[test] + fn upstream_401_is_not_retryable() { + let err = RequestError::ProviderUpstream { + provider: "x".to_string(), + status: 401, + body: None, + }; + assert!(!err.is_retryable()); + } + + #[test] + fn auth_revoked_is_not_retryable() { + let err = RequestError::AuthRevoked("revoked".to_string()); + assert!(!err.is_retryable()); + } + + #[test] + fn unauthorized_is_not_retryable() { + let err = RequestError::Unauthorized; + assert!(!err.is_retryable()); + } + + #[test] + fn forbidden_is_not_retryable() { + let err = RequestError::Forbidden("nope".to_string()); + assert!(!err.is_retryable()); + } + + #[test] + fn parse_error_is_not_retryable() { + let err = RequestError::ParseError("bad".to_string()); + assert!(!err.is_retryable()); + } + + #[test] + fn routing_error_is_not_retryable() { + let err = RequestError::RoutingError("no model".to_string()); + assert!(!err.is_retryable()); + } + + #[test] + fn budget_exceeded_is_not_retryable() { + let err = RequestError::BudgetExceeded { + limit_usd: 100.0, + actual_usd: 200.0, + }; + assert!(!err.is_retryable()); + } + + #[test] + fn dlp_blocked_is_not_retryable() { + let err = RequestError::DlpBlocked("secret".to_string()); + assert!(!err.is_retryable()); + } + + #[test] + fn internal_without_reqwest_is_not_retryable() { + let err = RequestError::Internal(anyhow::anyhow!("logic bug")); + assert!(!err.is_retryable()); + } + + #[test] + fn variant_tags_are_stable() { + assert_eq!( + RequestError::BadRequest("x".to_string()).variant_tag(), + "bad_request" + ); + assert_eq!(RequestError::Unauthorized.variant_tag(), "unauthorized"); + assert_eq!( + RequestError::Forbidden("x".to_string()).variant_tag(), + "forbidden" + ); + assert_eq!(RequestError::NotFound.variant_tag(), "not_found"); + assert_eq!( + RequestError::ParseError("x".to_string()).variant_tag(), + "parse_error" + ); + assert_eq!( + RequestError::RoutingError("x".to_string()).variant_tag(), + "routing_error" + ); + assert_eq!( + RequestError::RateLimited { + provider: "x".to_string(), + retry_after_ms: None + } + .variant_tag(), + "rate_limited" + ); + assert_eq!( + RequestError::ProviderUpstream { + provider: "x".to_string(), + status: 502, + body: None + } + .variant_tag(), + "provider_upstream" + ); + assert_eq!( + RequestError::BudgetExceeded { + limit_usd: 1.0, + actual_usd: 2.0, + } + .variant_tag(), + "budget_exceeded" + ); + assert_eq!( + RequestError::DlpBlocked("x".to_string()).variant_tag(), + "dlp_blocked" + ); + assert_eq!( + RequestError::AuthRevoked("x".to_string()).variant_tag(), + "auth_revoked" + ); + assert_eq!( + RequestError::Internal(anyhow::anyhow!("x")).variant_tag(), + "internal" + ); + } + + #[test] + fn provider_error_429_converts_to_rate_limited() { + let err = crate::providers::error::ProviderError::ApiError { + status: 429, + message: "slow down".to_string(), + }; + let req_err: RequestError = err.into(); + assert!(matches!(req_err, RequestError::RateLimited { .. })); + assert!(req_err.is_retryable()); + } + + #[test] + fn provider_error_500_converts_to_upstream() { + let err = crate::providers::error::ProviderError::ApiError { + status: 500, + message: "boom".to_string(), + }; + let req_err: RequestError = err.into(); + match &req_err { + RequestError::ProviderUpstream { status, .. } => assert_eq!(*status, 500), + other => panic!("unexpected variant: {:?}", other), + } + assert!(req_err.is_retryable()); + } + + #[test] + fn provider_error_401_with_rate_limit_payload_converts_to_rate_limited() { + let err = crate::providers::error::ProviderError::ApiError { + status: 401, + message: r#"{"type":"error","error":{"type":"rate_limit_error","message":"slow"}}"# + .to_string(), + }; + let req_err: RequestError = err.into(); + assert!(matches!(req_err, RequestError::RateLimited { .. })); + } + + #[test] + fn provider_error_401_authentication_converts_to_auth_revoked() { + let err = crate::providers::error::ProviderError::ApiError { + status: 401, + message: r#"{"type":"error","error":{"type":"authentication_error","message":"bad"}}"# + .to_string(), + }; + let req_err: RequestError = err.into(); + assert!(matches!(req_err, RequestError::AuthRevoked(_))); + assert!(!req_err.is_retryable()); + } + + #[test] + fn display_includes_variant_tag() { + let err = RequestError::ParseError("bad input".to_string()); + let s = err.to_string(); + assert!(s.contains("parse_error")); + assert!(s.contains("bad input")); } } diff --git a/src/server/handlers.rs b/src/server/handlers.rs index 679a0a4e..41849abb 100644 --- a/src/server/handlers.rs +++ b/src/server/handlers.rs @@ -10,9 +10,10 @@ use futures::stream::TryStreamExt; use std::sync::Arc; use tracing::{debug, error}; +use super::middleware::AuditedAlready; use super::{ apply_transparency_headers, dispatch, extract_api_credential, extract_client_ip, openai_compat, - responses_compat, should_apply_transparency, AppError, AppState, RequestId, + responses_compat, should_apply_transparency, AppState, RequestError, RequestId, }; /// Extracts tenant_id from VirtualKeyContext (preferred) or GrobClaims. @@ -50,12 +51,12 @@ impl Drop for ActiveRequestGuard { fn build_json_response( body: Vec, transparency: Option<(&str, &str, &str)>, -) -> Result { +) -> Result { let mut resp = Response::builder() .status(200) .header("content-type", "application/json") .body(Body::from(body)) - .map_err(|e| AppError::ProviderError(format!("response builder: {}", e)))?; + .map_err(|e| RequestError::Internal(anyhow::anyhow!("response builder: {}", e)))?; if let Some((provider, actual_model, req_id)) = transparency { apply_transparency_headers(resp.headers_mut(), provider, actual_model, req_id); } @@ -71,11 +72,11 @@ fn build_sse_response() -> axum::http::response::Builder { .header("Connection", "keep-alive") } -/// Serializes a value to JSON bytes, returning an `AppError` on failure. -fn serialize_response(value: &T) -> Result, AppError> { +/// Serializes a value to JSON bytes, returning a `RequestError` on failure. +fn serialize_response(value: &T) -> Result, RequestError> { serde_json::to_vec(value).map_err(|e| { error!("Failed to serialize response: {}", e); - AppError::ProviderError(format!("response serialization failed: {}", e)) + RequestError::Internal(anyhow::anyhow!("response serialization failed: {}", e)) }) } @@ -86,6 +87,18 @@ struct DispatchPrelude { tenant_id: Option, peer_ip: String, transparency_enabled: bool, + /// Set by the dispatch path when it has emitted an audit entry — used + /// by the audit middleware to skip duplicate logging. + audited: Arc, +} + +/// Marks the response with the [`AuditedAlready`] extension when dispatch +/// has emitted its own audit entry, so the outer audit middleware skips a +/// duplicate write. +fn mark_audited_if_set(audited: &Arc, response: &mut Response) { + if audited.load(std::sync::atomic::Ordering::Acquire) { + response.extensions_mut().insert(AuditedAlready); + } } /// Builds the shared pre-dispatch state common to all three handlers. @@ -113,6 +126,7 @@ fn prepare_dispatch( tenant_id, peer_ip, transparency_enabled, + audited: Arc::new(std::sync::atomic::AtomicBool::new(false)), } } @@ -137,7 +151,7 @@ fn finish_dispatch( on_streaming: S, on_complete: C, on_fan_out: F, -) -> Result +) -> Result where S: FnOnce( std::pin::Pin< @@ -148,7 +162,7 @@ where >, >, ) -> Body, - C: FnOnce(crate::providers::ProviderResponse) -> Result, AppError>, + C: FnOnce(crate::providers::ProviderResponse) -> Result, RequestError>, F: FnOnce(crate::providers::ProviderResponse) -> Response, { match result { @@ -179,7 +193,7 @@ where let response = response_builder .body(body) - .map_err(|e| AppError::ProviderError(format!("response builder: {}", e)))?; + .map_err(|e| RequestError::Internal(anyhow::anyhow!("response builder: {}", e)))?; Ok(response) } @@ -227,7 +241,7 @@ pub(crate) async fn handle_openai_chat_completions( axum::Extension(request_id): axum::Extension, headers: HeaderMap, Json(openai_request): Json, -) -> Result { +) -> Result { let _guard = ActiveRequestGuard::new(&state); let model = openai_request.model.clone(); let is_streaming = openai_request.stream == Some(true); @@ -235,14 +249,17 @@ pub(crate) async fn handle_openai_chat_completions( let prelude = prepare_dispatch(&state, &claims, &vk_ctx, &headers); // Transform OpenAI → Anthropic format - let mut request = openai_compat::transform_openai_to_canonical(openai_request) - .map_err(|e| AppError::ParseError(format!("Failed to transform OpenAI request: {}", e)))?; + let mut request = + openai_compat::transform_openai_to_canonical(openai_request).map_err(|e| { + RequestError::ParseError(format!("Failed to transform OpenAI request: {}", e)) + })?; forward_beta_header(&mut request, &headers); let start_time = std::time::Instant::now(); #[cfg(feature = "policies")] let resolved_policy = evaluate_policy_if_configured(&state, prelude.tenant_id.as_deref(), &model, &headers); + let audited_flag = prelude.audited.clone(); let ctx = dispatch::DispatchContext { state: &state, inner: &prelude.inner, @@ -255,15 +272,23 @@ pub(crate) async fn handle_openai_chat_completions( start_time, headers: &headers, trace_id: None, + audited: audited_flag.clone(), #[cfg(feature = "policies")] resolved_policy, }; - let result = dispatch::dispatch(&ctx, &mut request).await?; + let result = match dispatch::dispatch(&ctx, &mut request).await { + Ok(r) => r, + Err(e) => { + let mut response = e.into_response(); + mark_audited_if_set(&audited_flag, &mut response); + return Ok(response); + } + }; let model_for_stream = model.clone(); let model_for_fanout = model.clone(); - finish_dispatch( + let mut response = finish_dispatch( result, prelude.transparency_enabled, &request_id.0, @@ -284,7 +309,9 @@ pub(crate) async fn handle_openai_chat_completions( openai_compat::transform_canonical_to_openai(resp, model_for_fanout); Json(openai_response).into_response() }, - ) + )?; + mark_audited_if_set(&audited_flag, &mut response); + Ok(response) } /// Handle /v1/responses requests (OpenAI Responses API — used by Codex CLI) @@ -296,7 +323,7 @@ pub(crate) async fn handle_responses( axum::Extension(request_id): axum::Extension, headers: HeaderMap, Json(responses_request): Json, -) -> Result { +) -> Result { let _guard = ActiveRequestGuard::new(&state); let model = responses_request.model.clone(); let is_streaming = responses_request.stream == Some(true); @@ -306,7 +333,7 @@ pub(crate) async fn handle_responses( // Transform Responses → canonical format let mut request = responses_compat::transform_responses_to_canonical(responses_request) .map_err(|e| { - AppError::ParseError(format!("Failed to transform Responses request: {}", e)) + RequestError::ParseError(format!("Failed to transform Responses request: {}", e)) })?; forward_beta_header(&mut request, &headers); @@ -314,6 +341,7 @@ pub(crate) async fn handle_responses( #[cfg(feature = "policies")] let resolved_policy = evaluate_policy_if_configured(&state, prelude.tenant_id.as_deref(), &model, &headers); + let audited_flag = prelude.audited.clone(); let ctx = dispatch::DispatchContext { state: &state, inner: &prelude.inner, @@ -326,15 +354,23 @@ pub(crate) async fn handle_responses( start_time, headers: &headers, trace_id: None, + audited: audited_flag.clone(), #[cfg(feature = "policies")] resolved_policy, }; - let result = dispatch::dispatch(&ctx, &mut request).await?; + let result = match dispatch::dispatch(&ctx, &mut request).await { + Ok(r) => r, + Err(e) => { + let mut response = e.into_response(); + mark_audited_if_set(&audited_flag, &mut response); + return Ok(response); + } + }; let model_for_stream = model.clone(); let model_for_fanout = model.clone(); - finish_dispatch( + let mut response = finish_dispatch( result, prelude.transparency_enabled, &request_id.0, @@ -357,7 +393,9 @@ pub(crate) async fn handle_responses( responses_compat::transform_canonical_to_responses(resp, model_for_fanout); Json(responses_response).into_response() }, - ) + )?; + mark_audited_if_set(&audited_flag, &mut response); + Ok(response) } /// Handle /v1/models endpoint (OpenAI-compatible) @@ -391,7 +429,7 @@ pub(crate) async fn handle_messages( axum::Extension(request_id): axum::Extension, headers: HeaderMap, Json(request_json): Json, -) -> Result { +) -> Result { let _guard = ActiveRequestGuard::new(&state); let req_id = &request_id.0; let model: String = request_json @@ -412,7 +450,7 @@ pub(crate) async fn handle_messages( let mut request: CanonicalRequest = serde_json::from_value(request_json).map_err(|e| { tracing::error!("❌ Failed to parse request: {}", e); - AppError::ParseError(format!("Invalid request format: {}", e)) + RequestError::ParseError(format!("Invalid request format: {}", e)) })?; forward_beta_header(&mut request, &headers); @@ -422,6 +460,7 @@ pub(crate) async fn handle_messages( #[cfg(feature = "policies")] let resolved_policy = evaluate_policy_if_configured(&state, prelude.tenant_id.as_deref(), &model, &headers); + let audited_flag = prelude.audited.clone(); let ctx = dispatch::DispatchContext { state: &state, inner: &prelude.inner, @@ -434,13 +473,21 @@ pub(crate) async fn handle_messages( start_time, headers: &headers, trace_id: Some(trace_id), + audited: audited_flag.clone(), #[cfg(feature = "policies")] resolved_policy, }; - let result = dispatch::dispatch(&ctx, &mut request).await?; + let result = match dispatch::dispatch(&ctx, &mut request).await { + Ok(r) => r, + Err(e) => { + let mut response = e.into_response(); + mark_audited_if_set(&audited_flag, &mut response); + return Ok(response); + } + }; - finish_dispatch( + let mut response = finish_dispatch( result, prelude.transparency_enabled, req_id, @@ -454,14 +501,16 @@ pub(crate) async fn handle_messages( }, |resp| serialize_response(&resp), |resp| Json(resp).into_response(), - ) + )?; + mark_audited_if_set(&audited_flag, &mut response); + Ok(response) } /// Handle /v1/messages/count_tokens requests pub(crate) async fn handle_count_tokens( State(state): State>, Json(request_json): Json, -) -> Result { +) -> Result { let model = request_json .get("model") .and_then(|m| m.as_str()) @@ -491,7 +540,7 @@ pub(crate) async fn handle_count_tokens( let decision = inner .router .route(&mut routing_request) - .map_err(|e| AppError::RoutingError(e.to_string()))?; + .map_err(|e| RequestError::RoutingError(e.to_string()))?; debug!( "🧮 Routed count_tokens: {} → {} ({})", @@ -500,8 +549,9 @@ pub(crate) async fn handle_count_tokens( // Deserialize the full count_tokens request (consumes the JSON value — no clone). use crate::models::CountTokensRequest; - let count_request: CountTokensRequest = serde_json::from_value(request_json) - .map_err(|e| AppError::ParseError(format!("Invalid count_tokens request format: {}", e)))?; + let count_request: CountTokensRequest = serde_json::from_value(request_json).map_err(|e| { + RequestError::ParseError(format!("Invalid count_tokens request format: {}", e)) + })?; // Try model mappings with fallback (1:N mapping) if let Some(model_config) = inner.find_model(&decision.model_name) { @@ -525,11 +575,15 @@ pub(crate) async fn handle_count_tokens( } } - Err(AppError::ProviderError(format!( - "All {} provider mappings failed for token counting: {}", - sorted_mappings.len(), - decision.model_name - ))) + Err(RequestError::ProviderUpstream { + provider: "all".to_string(), + status: 502, + body: Some(format!( + "All {} provider mappings failed for token counting: {}", + sorted_mappings.len(), + decision.model_name + )), + }) } else if let Ok(provider) = inner .provider_registry .provider_for_model(&decision.model_name) @@ -539,10 +593,10 @@ pub(crate) async fn handle_count_tokens( let response = provider .count_tokens(req) .await - .map_err(|e| AppError::ProviderError(e.to_string()))?; + .map_err(RequestError::from)?; Ok(Json(response).into_response()) } else { - Err(AppError::ProviderError(format!( + Err(RequestError::RoutingError(format!( "No model mapping or provider found for token counting: {}", decision.model_name ))) diff --git a/src/server/helpers.rs b/src/server/helpers.rs index 55efc317..f0a2f825 100644 --- a/src/server/helpers.rs +++ b/src/server/helpers.rs @@ -4,14 +4,14 @@ use std::borrow::Cow; use std::sync::Arc; use tracing::info; -use super::{AppError, ReloadableState}; +use super::{ReloadableState, RequestError}; /// Resolve and sort provider mappings for a routing decision. pub(crate) fn resolve_provider_mappings( inner: &Arc, headers: &HeaderMap, decision: &crate::models::RouteDecision, -) -> Result, AppError> { +) -> Result, RequestError> { // Tier-based provider selection (opt-in via [[tiers]] config) if let Some(ref tier) = decision.complexity_tier { let tier_name = tier.to_string(); @@ -112,7 +112,7 @@ pub(crate) fn resolve_provider_mappings( if let Some(ref provider_name) = forced_provider { sorted.retain(|m| m.provider == *provider_name); if sorted.is_empty() { - return Err(AppError::RoutingError(format!( + return Err(RequestError::RoutingError(format!( "Provider '{}' not found in mappings for model '{}'", provider_name, decision.model_name ))); @@ -137,7 +137,7 @@ pub(crate) fn resolve_provider_mappings( provider_region == region_filter || provider_region == "global" }); if sorted.is_empty() { - return Err(AppError::RoutingError(format!( + return Err(RequestError::RoutingError(format!( "No providers match region '{}' for model '{}' (GDPR filtering enabled)", region_filter, decision.model_name ))); @@ -201,7 +201,7 @@ pub(crate) fn resolve_provider_mappings( .collect(); if pass_through_mappings.is_empty() { - Err(AppError::RoutingError(format!( + Err(RequestError::RoutingError(format!( "Model '{}' is not configured. Add a [[models]] entry in config.toml or set pass_through = true on a provider.", decision.model_name ))) diff --git a/src/server/middleware.rs b/src/server/middleware.rs index 1a73084e..2a06e92f 100644 --- a/src/server/middleware.rs +++ b/src/server/middleware.rs @@ -78,7 +78,7 @@ pub(crate) fn auth_error_response(message: &str) -> Response { /// Stored in request extensions for correlation #[derive(Clone, Debug)] -pub(crate) struct RequestId(pub String); +pub struct RequestId(pub String); /// Auth middleware: supports three modes: /// - "none" (default): all requests pass @@ -321,6 +321,193 @@ pub(crate) async fn security_headers_response_middleware( apply_security_headers(response, &config) } +/// Marker inserted into response extensions by handlers that already wrote +/// an audit entry. The audit middleware skips logging when present so that +/// the dispatch pipeline (which audits with rich DLP and token-count context) +/// is the source of truth for request-lifecycle entries on the hot path. +/// +/// Endpoints that bypass dispatch entirely (oauth handlers, config API, +/// errors raised in middleware before dispatch) leave this marker absent +/// and are audited centrally by the middleware. +#[derive(Clone, Debug)] +pub struct AuditedAlready; + +/// Inputs captured by the audit middleware before the handler runs. +/// +/// Stored on the request side so post-handler audit emission can rebuild +/// the entry without re-reading consumed request state. +pub struct AuditMiddlewareCapture { + /// HTTP method of the request. + pub method: axum::http::Method, + /// Path component of the request URI. + pub path: String, + /// Correlation ID resolved from the `RequestId` extension. + pub request_id: String, + /// Tenant identifier from JWT / virtual key, or empty. + pub tenant_id: String, + /// Client IP from `X-Forwarded-For` or `"unknown"`. + pub client_ip: String, + /// Wall-clock instant the middleware observed the request. + pub started_at: std::time::Instant, +} + +/// Pulls the captured request context that `audit_log_layer` snapshots +/// before the handler runs. +pub fn capture_audit_input(request: &Request) -> AuditMiddlewareCapture { + let request_id = request + .extensions() + .get::() + .map(|r| r.0.clone()) + .unwrap_or_default(); + + let tenant_id = if let Some(vk) = request + .extensions() + .get::() + { + vk.tenant_id.clone() + } else if let Some(claims) = request.extensions().get::() { + claims.tenant_id().to_string() + } else { + String::new() + }; + + AuditMiddlewareCapture { + method: request.method().clone(), + path: request.uri().path().to_string(), + request_id, + tenant_id, + client_ip: extract_client_ip(request.headers()), + started_at: std::time::Instant::now(), + } +} + +/// Emits an `AuditEvent::RequestProcessed` entry from the captured request +/// context plus the post-handler response. Returns `true` when an entry +/// was written, `false` when the response carried [`AuditedAlready`] (in +/// which case the dispatch pipeline already wrote a richer entry). +/// +/// Extracted from [`audit_log_layer`] so it can be unit-tested without +/// constructing a full `AppState`. +pub fn emit_request_processed( + audit_log: &crate::security::AuditLog, + capture: &AuditMiddlewareCapture, + response: &Response, +) -> bool { + if response.extensions().get::().is_some() { + return false; + } + + let status = response.status(); + let duration_ms = capture.started_at.elapsed().as_millis() as u64; + + let provider = response + .headers() + .get("x-ai-provider") + .and_then(|v| v.to_str().ok()) + .unwrap_or("") + .to_string(); + let model = response + .headers() + .get("x-ai-model") + .and_then(|v| v.to_str().ok()) + .unwrap_or("") + .to_string(); + + let error_variant = response + .extensions() + .get::() + .map(|tag| tag.0.clone()); + + let backend = if !provider.is_empty() { + provider + } else if let Some(ref tag) = error_variant { + format!("ERROR:{}:{}", tag, status.as_u16()) + } else if status.is_success() { + format!("{} {}", capture.method, capture.path) + } else { + format!("STATUS:{}", status.as_u16()) + }; + + let tenant_for_entry = if capture.tenant_id.is_empty() { + capture.client_ip.as_str() + } else { + capture.tenant_id.as_str() + }; + + let mut builder = super::AuditEntryBuilder::new( + tenant_for_entry, + crate::security::audit_log::AuditEvent::RequestProcessed, + &backend, + &capture.client_ip, + duration_ms, + ); + + if !model.is_empty() { + builder = builder.model(model); + } + + // Risk level: low for 2xx, medium for 4xx, high for 5xx — matches + // the EU AI Act Article 14 escalation threshold defaults. + let risk = if status.is_server_error() { + crate::security::audit_log::RiskLevel::High + } else if status.is_client_error() { + crate::security::audit_log::RiskLevel::Medium + } else { + crate::security::audit_log::RiskLevel::Low + }; + builder = builder.risk(risk); + + if let Some(tag) = error_variant { + builder = builder.dlp_rules(vec![format!( + "request_error:{}:status={}", + tag, + status.as_u16() + )]); + } + + if let Err(e) = audit_log.write(builder.build()) { + tracing::error!( + error = %e, + request_id = %capture.request_id, + "audit middleware: write failed" + ); + } + true +} + +/// Audit-log middleware: emits `AuditEvent::RequestProcessed` for every HTTP +/// request that flows through the server. +/// +/// Wraps every endpoint, including the OAuth, config, and health surfaces +/// that previously bypassed audit entirely. Captures request method, path, +/// status, latency, error variant tag (when 4xx/5xx), tenant identifier +/// (from JWT claims or virtual key context), client IP, and the upstream +/// provider name when set on the response by the dispatch pipeline. +/// +/// Skips logging when the dispatch pipeline has already written a richer +/// audit entry (signalled by the [`AuditedAlready`] marker in response +/// extensions). Health and metrics endpoints are excluded to avoid +/// flooding the journal with unauthenticated probe traffic. +pub(crate) async fn audit_log_layer( + State(state): State>, + request: Request, + next: Next, +) -> Response { + let path = request.uri().path(); + if matches!(path, "/health" | "/live" | "/ready" | "/metrics") { + return next.run(request).await; + } + + let capture = capture_audit_input(&request); + let response = next.run(request).await; + + if let Some(ref audit_log) = state.security.audit_log { + emit_request_processed(audit_log, &capture, &response); + } + + response +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/server/mod.rs b/src/server/mod.rs index e815fa3d..792702d4 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -37,7 +37,7 @@ pub(crate) use budget::{ calculate_cost, check_budget, is_auth_revoked_error, is_provider_subscription, is_retryable, record_request_metrics, record_spend, retry_delay, RequestMetrics, MAX_RETRIES, }; -pub use error::AppError; +pub use error::{ErrorVariantTag, RequestError}; pub(crate) use helpers::{ format_route_type, inject_continuation_text, resolve_provider_mappings, sanitize_provider_response_reported, should_inject_continuation, @@ -49,9 +49,12 @@ pub(crate) use init::{ init_provider_scorer, init_security, maybe_preset_sync, spawn_background_tasks, }; pub(crate) use middleware::{ - apply_transparency_headers, auth_middleware, extract_api_credential, extract_client_ip, - rate_limit_check_middleware, request_id_middleware, security_headers_response_middleware, - should_apply_transparency, RequestId, + apply_transparency_headers, audit_log_layer, auth_middleware, extract_api_credential, + extract_client_ip, rate_limit_check_middleware, request_id_middleware, + security_headers_response_middleware, should_apply_transparency, +}; +pub use middleware::{ + capture_audit_input, emit_request_processed, AuditMiddlewareCapture, AuditedAlready, RequestId, }; use crate::auth::TokenStore; @@ -400,6 +403,17 @@ fn build_app_router(config: &AppConfig, state: Arc) -> axum::Router { let app = app.layer(RequestBodyLimitLayer::new( config.security.max_body_size.value(), )); + + // Audit middleware: captures every request lifecycle, including those + // rejected by rate-limit / auth before reaching a handler. Layered + // INSIDE `request_id_middleware` (which is added afterwards and so wraps + // this one) so `RequestId` is set in extensions before the audit logic + // reads it. + let app = app.layer(axum::middleware::from_fn_with_state( + state.clone(), + audit_log_layer, + )); + let app = app.layer(axum::middleware::from_fn(request_id_middleware)); // Tape recorder layer: outermost to capture raw HTTP before any transformation. diff --git a/tests/enterprise/snapshot_error_test.rs b/tests/enterprise/snapshot_error_test.rs index cd7cccd1..482cc0bf 100644 --- a/tests/enterprise/snapshot_error_test.rs +++ b/tests/enterprise/snapshot_error_test.rs @@ -1,23 +1,23 @@ //! Snapshot tests for error response formats. //! -//! Each [`AppError`] variant produces a JSON response with a specific HTTP +//! Each [`RequestError`] variant produces a JSON response with a specific HTTP //! status code and error structure. These snapshots detect unintended changes //! to the error API contract. use axum::response::IntoResponse; -use grob::server::AppError; +use grob::server::RequestError; // ── Helpers ───────────────────────────────────────────────────── -/// Extracts status code and parsed JSON body from an AppError. -async fn error_snapshot(error: AppError) -> String { +/// Extracts status code and parsed JSON body from a `RequestError`. +async fn error_snapshot(error: RequestError) -> String { let response = error.into_response(); let status = response.status(); let body_bytes = axum::body::to_bytes(response.into_body(), 1024 * 1024) .await .expect("invariant: in-memory body collection cannot fail"); let json: serde_json::Value = serde_json::from_slice(&body_bytes) - .expect("invariant: AppError always produces valid JSON"); + .expect("invariant: RequestError always produces valid JSON"); format!("status={} body={}", status.as_u16(), json) } @@ -25,16 +25,17 @@ async fn error_snapshot(error: AppError) -> String { #[tokio::test] async fn snapshot_budget_exceeded_error() { - let snap = error_snapshot(AppError::BudgetExceeded( - "Monthly global budget reached: $50.00/$50.00".to_string(), - )) + let snap = error_snapshot(RequestError::BudgetExceeded { + limit_usd: 50.0, + actual_usd: 50.0, + }) .await; insta::assert_snapshot!(snap); } #[tokio::test] async fn snapshot_routing_error() { - let snap = error_snapshot(AppError::RoutingError( + let snap = error_snapshot(RequestError::RoutingError( "no matching model: gpt-unknown".to_string(), )) .await; @@ -43,16 +44,18 @@ async fn snapshot_routing_error() { #[tokio::test] async fn snapshot_provider_error() { - let snap = error_snapshot(AppError::ProviderError( - "upstream timeout after 30s".to_string(), - )) + let snap = error_snapshot(RequestError::ProviderUpstream { + provider: "openai".to_string(), + status: 502, + body: Some("upstream timeout after 30s".to_string()), + }) .await; insta::assert_snapshot!(snap); } #[tokio::test] async fn snapshot_parse_error() { - let snap = error_snapshot(AppError::ParseError( + let snap = error_snapshot(RequestError::ParseError( "invalid JSON at line 1, column 42".to_string(), )) .await; @@ -61,9 +64,44 @@ async fn snapshot_parse_error() { #[tokio::test] async fn snapshot_dlp_blocked_error() { - let snap = error_snapshot(AppError::DlpBlocked( + let snap = error_snapshot(RequestError::DlpBlocked( "secret detected in prompt: sk-***".to_string(), )) .await; insta::assert_snapshot!(snap); } + +#[tokio::test] +async fn snapshot_rate_limited_error() { + let snap = error_snapshot(RequestError::RateLimited { + provider: "anthropic".to_string(), + retry_after_ms: Some(2500), + }) + .await; + insta::assert_snapshot!(snap); +} + +#[tokio::test] +async fn snapshot_unauthorized_error() { + let snap = error_snapshot(RequestError::Unauthorized).await; + insta::assert_snapshot!(snap); +} + +#[tokio::test] +async fn snapshot_forbidden_error() { + let snap = error_snapshot(RequestError::Forbidden( + "policy denies model 'claude-opus-4-7' for tenant 'public'".to_string(), + )) + .await; + insta::assert_snapshot!(snap); +} + +#[tokio::test] +async fn snapshot_auth_revoked_error() { + let snap = error_snapshot(RequestError::AuthRevoked( + "OAuth token for provider 'anthropic' revoked. Run: grob connect --force-reauth" + .to_string(), + )) + .await; + insta::assert_snapshot!(snap); +} diff --git a/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_auth_revoked_error.snap b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_auth_revoked_error.snap new file mode 100644 index 00000000..342ae8c6 --- /dev/null +++ b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_auth_revoked_error.snap @@ -0,0 +1,5 @@ +--- +source: tests/enterprise/snapshot_error_test.rs +expression: snap +--- +status=401 body={"error":{"type":"authentication_error","message":"OAuth token for provider 'anthropic' revoked. Run: grob connect --force-reauth"}} diff --git a/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_budget_exceeded_error.snap b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_budget_exceeded_error.snap index f23ad3cc..b3570bb4 100644 --- a/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_budget_exceeded_error.snap +++ b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_budget_exceeded_error.snap @@ -2,4 +2,4 @@ source: tests/enterprise/snapshot_error_test.rs expression: snap --- -status=402 body={"error":{"type":"budget_exceeded","message":"Monthly global budget reached: $50.00/$50.00"}} +status=402 body={"error":{"type":"budget_exceeded","message":"Budget exceeded: $50.0000 spent of $50.0000 limit","limit_usd":50.0,"actual_usd":50.0}} diff --git a/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_forbidden_error.snap b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_forbidden_error.snap new file mode 100644 index 00000000..06f2fc34 --- /dev/null +++ b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_forbidden_error.snap @@ -0,0 +1,5 @@ +--- +source: tests/enterprise/snapshot_error_test.rs +expression: snap +--- +status=403 body={"error":{"type":"permission_error","message":"policy denies model 'claude-opus-4-7' for tenant 'public'"}} diff --git a/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_provider_error.snap b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_provider_error.snap index 9435a30c..a4cbed28 100644 --- a/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_provider_error.snap +++ b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_provider_error.snap @@ -2,4 +2,4 @@ source: tests/enterprise/snapshot_error_test.rs expression: snap --- -status=502 body={"error":{"type":"error","message":"upstream timeout after 30s"}} +status=502 body={"error":{"type":"error","message":"upstream timeout after 30s","provider":"openai","upstream_status":502}} diff --git a/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_rate_limited_error.snap b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_rate_limited_error.snap new file mode 100644 index 00000000..b0d62d2d --- /dev/null +++ b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_rate_limited_error.snap @@ -0,0 +1,5 @@ +--- +source: tests/enterprise/snapshot_error_test.rs +expression: snap +--- +status=429 body={"error":{"type":"rate_limit_error","message":"Provider 'anthropic' rate-limited the request","provider":"anthropic","retry_after_ms":2500}} diff --git a/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_unauthorized_error.snap b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_unauthorized_error.snap new file mode 100644 index 00000000..003b1d21 --- /dev/null +++ b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_unauthorized_error.snap @@ -0,0 +1,5 @@ +--- +source: tests/enterprise/snapshot_error_test.rs +expression: snap +--- +status=401 body={"error":{"type":"authentication_error","message":"Missing or invalid credentials"}} diff --git a/tests/integration/audit_middleware_test.rs b/tests/integration/audit_middleware_test.rs new file mode 100644 index 00000000..09f80da1 --- /dev/null +++ b/tests/integration/audit_middleware_test.rs @@ -0,0 +1,244 @@ +//! Integration tests for the audit-log middleware. +//! +//! Exercises [`emit_request_processed`] over a representative cross-section +//! of HTTP statuses (2xx/4xx/5xx) and verifies: +//! +//! * One audit entry is written per request when the dispatch pipeline did +//! not already log. +//! * No entry is written when the response carries the [`AuditedAlready`] +//! marker (de-duplication invariant). +//! * Provider/model headers and error variant tags are preserved in the +//! audit entry. + +use axum::http::{HeaderValue, Method, Request, Response, StatusCode}; +use grob::security::audit_log::{AuditConfig, AuditEntry, AuditEvent, AuditLog, SigningAlgorithm}; +use grob::server::{ + capture_audit_input, emit_request_processed, AuditMiddlewareCapture, AuditedAlready, +}; +use tempfile::TempDir; + +fn build_audit_log() -> (TempDir, AuditLog) { + let dir = TempDir::new().expect("tempdir"); + let log = AuditLog::new(AuditConfig { + log_dir: dir.path().to_path_buf(), + sign_key_path: None, + signing_algorithm: SigningAlgorithm::default(), + hmac_key_path: None, + batch_size: 1, + flush_interval_ms: 5000, + include_merkle_proof: false, + }) + .expect("audit log"); + (dir, log) +} + +fn read_entries(dir: &TempDir) -> Vec { + let path = dir.path().join("current.jsonl"); + if !path.exists() { + return vec![]; + } + std::fs::read_to_string(&path) + .expect("read jsonl") + .lines() + .filter(|l| !l.trim().is_empty()) + .map(|l| serde_json::from_str(l).expect("parse audit entry")) + .collect() +} + +fn make_capture() -> AuditMiddlewareCapture { + AuditMiddlewareCapture { + method: Method::POST, + path: "/v1/messages".to_string(), + request_id: "req-test-001".to_string(), + tenant_id: "tenant-alpha".to_string(), + client_ip: "10.0.0.1".to_string(), + started_at: std::time::Instant::now(), + } +} + +fn build_response(status: StatusCode) -> Response { + Response::builder() + .status(status) + .body(axum::body::Body::empty()) + .expect("build response") +} + +#[test] +fn audit_middleware_emits_one_entry_per_request_lifecycle() { + let (dir, audit) = build_audit_log(); + + // Simulate five requests with different outcomes. + let cases = [ + StatusCode::OK, // 200 + StatusCode::BAD_REQUEST, // 400 + StatusCode::UNAUTHORIZED, // 401 + StatusCode::TOO_MANY_REQUESTS, // 429 + StatusCode::INTERNAL_SERVER_ERROR, // 500 + ]; + + for status in cases { + let capture = make_capture(); + let response = build_response(status); + let written = emit_request_processed(&audit, &capture, &response); + assert!( + written, + "audit middleware must emit an entry for {}", + status + ); + } + + let entries = read_entries(&dir); + assert_eq!( + entries.len(), + cases.len(), + "expected one audit entry per request" + ); + for entry in &entries { + assert!(matches!(entry.action, AuditEvent::RequestProcessed)); + assert_eq!(entry.tenant_id, "tenant-alpha"); + assert_eq!(entry.ip_source, "10.0.0.1"); + } +} + +#[test] +fn audit_middleware_skips_when_handler_already_audited() { + let (dir, audit) = build_audit_log(); + + let capture = make_capture(); + let mut response = build_response(StatusCode::OK); + response.extensions_mut().insert(AuditedAlready); + + let written = emit_request_processed(&audit, &capture, &response); + assert!( + !written, + "audit middleware must not double-log when handler already logged" + ); + + let entries = read_entries(&dir); + assert!( + entries.is_empty(), + "expected zero audit entries from middleware when AuditedAlready is set, got {}", + entries.len() + ); +} + +#[test] +fn audit_middleware_captures_provider_from_response_header() { + let (dir, audit) = build_audit_log(); + + let capture = make_capture(); + let mut response = build_response(StatusCode::OK); + response + .headers_mut() + .insert("x-ai-provider", HeaderValue::from_static("anthropic")); + response + .headers_mut() + .insert("x-ai-model", HeaderValue::from_static("claude-opus-4-7")); + + let written = emit_request_processed(&audit, &capture, &response); + assert!(written); + + let entries = read_entries(&dir); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].backend_routed, "anthropic"); + assert_eq!( + entries[0].model_name.as_deref(), + Some("claude-opus-4-7"), + "model name should be captured from x-ai-model header" + ); +} + +#[test] +fn audit_middleware_includes_error_variant_tag_in_dlp_rules_field() { + let (dir, audit) = build_audit_log(); + + let capture = make_capture(); + let mut response = build_response(StatusCode::PAYMENT_REQUIRED); + response + .extensions_mut() + .insert(grob::server::ErrorVariantTag("budget_exceeded".to_string())); + + let written = emit_request_processed(&audit, &capture, &response); + assert!(written); + + let entries = read_entries(&dir); + assert_eq!(entries.len(), 1); + assert!(entries[0] + .dlp_rules_triggered + .iter() + .any(|r| r.contains("budget_exceeded") && r.contains("status=402"))); +} + +#[test] +fn audit_middleware_emits_low_risk_for_2xx() { + let (dir, audit) = build_audit_log(); + let capture = make_capture(); + let response = build_response(StatusCode::OK); + emit_request_processed(&audit, &capture, &response); + let entries = read_entries(&dir); + assert_eq!(entries.len(), 1); + assert_eq!( + entries[0].risk_level, + Some(grob::security::audit_log::RiskLevel::Low) + ); +} + +#[test] +fn audit_middleware_emits_medium_risk_for_4xx() { + let (dir, audit) = build_audit_log(); + let capture = make_capture(); + let response = build_response(StatusCode::BAD_REQUEST); + emit_request_processed(&audit, &capture, &response); + let entries = read_entries(&dir); + assert_eq!(entries.len(), 1); + assert_eq!( + entries[0].risk_level, + Some(grob::security::audit_log::RiskLevel::Medium) + ); +} + +#[test] +fn audit_middleware_emits_high_risk_for_5xx() { + let (dir, audit) = build_audit_log(); + let capture = make_capture(); + let response = build_response(StatusCode::INTERNAL_SERVER_ERROR); + emit_request_processed(&audit, &capture, &response); + let entries = read_entries(&dir); + assert_eq!(entries.len(), 1); + assert_eq!( + entries[0].risk_level, + Some(grob::security::audit_log::RiskLevel::High) + ); +} + +#[test] +fn audit_middleware_falls_back_to_client_ip_when_tenant_missing() { + let (dir, audit) = build_audit_log(); + let mut capture = make_capture(); + capture.tenant_id = String::new(); + let response = build_response(StatusCode::OK); + emit_request_processed(&audit, &capture, &response); + let entries = read_entries(&dir); + assert_eq!(entries.len(), 1); + assert_eq!( + entries[0].tenant_id, "10.0.0.1", + "anonymous request should be tagged by client IP" + ); +} + +#[test] +fn capture_audit_input_picks_up_request_id_from_extensions() { + let mut request: Request = Request::builder() + .method(Method::GET) + .uri("/v1/models") + .body(axum::body::Body::empty()) + .unwrap(); + request + .extensions_mut() + .insert(grob::server::RequestId("rid-abc-123".to_string())); + + let capture = capture_audit_input(&request); + assert_eq!(capture.request_id, "rid-abc-123"); + assert_eq!(capture.method, Method::GET); + assert_eq!(capture.path, "/v1/models"); +} diff --git a/tests/integration/mod.rs b/tests/integration/mod.rs index 398f996e..8498d8f4 100644 --- a/tests/integration/mod.rs +++ b/tests/integration/mod.rs @@ -1,4 +1,5 @@ // Integration tests module +mod audit_middleware_test; mod cache_test; mod compliance_test; mod dlp_test;