diff --git a/prek.toml b/prek.toml index 82c09caa..99598f46 100644 --- a/prek.toml +++ b/prek.toml @@ -54,25 +54,10 @@ args = ["-c", "lychee --offline --no-progress --include-fragments docs/**/*.md R types = ["markdown"] pass_filenames = false -# Tests — expensive, runs only on pre-push -[[repos.hooks]] -id = "cargo-test" -name = "cargo test" -language = "system" -entry = "cargo nextest run --profile ci" -types = ["rust"] -pass_filenames = false -stages = ["pre-push"] - -# Doc tests — not covered by nextest -[[repos.hooks]] -id = "cargo-doctest" -name = "cargo doc tests" -language = "system" -entry = "cargo test --doc" -types = ["rust"] -pass_filenames = false -stages = ["pre-push"] +# NOTE: cargo-test and cargo-doctest were removed from pre-push to keep +# the gate fast-fail. CI re-runs the full suite on every PR (see +# `.github/workflows/ci.yml` `tests` job), so duplicating ~20 min of work +# locally only delays the developer feedback loop without adding signal. # Cargo deny — license & advisory & dependency checks # Only runs when Cargo.toml/lock or Rust files change. diff --git a/src/cli/config/budget.rs b/src/cli/config/budget.rs index 414a2f26..4eaa6e74 100644 --- a/src/cli/config/budget.rs +++ b/src/cli/config/budget.rs @@ -6,6 +6,7 @@ use crate::cli::BudgetUsd; /// Budget configuration #[derive(Debug, Clone, Deserialize, Serialize, Default)] +#[serde(deny_unknown_fields)] pub struct BudgetConfig { /// Global monthly hard cap in USD (0 = unlimited) #[serde(default)] diff --git a/src/cli/config/cache.rs b/src/cli/config/cache.rs index b2e42554..201ddf66 100644 --- a/src/cli/config/cache.rs +++ b/src/cli/config/cache.rs @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize}; /// LLM response cache configuration #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct CacheConfig { /// Enable response caching (only for temperature=0 requests) #[serde(default)] diff --git a/src/cli/config/providers.rs b/src/cli/config/providers.rs index 87c9ab77..88c240f8 100644 --- a/src/cli/config/providers.rs +++ b/src/cli/config/providers.rs @@ -28,6 +28,7 @@ pub enum AuthType { /// Provider configuration from TOML. #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] pub struct ProviderConfig { /// Unique provider name used in routing and logging. pub name: String, @@ -119,7 +120,16 @@ pub struct ProviderConfig { } impl ProviderConfig { - /// Returns `true` if the provider is enabled (defaults to `true`). + /// Returns `true` if the provider is enabled. + /// + /// Semantics: + /// - `enabled = true` → enabled. + /// - `enabled = false` → disabled. + /// - `enabled` absent → enabled (sensible default for newly added blocks). + /// + /// Typo safety: `#[serde(deny_unknown_fields)]` on [`ProviderConfig`] + /// rejects misspelled keys (e.g. `enbaled`) at parse time, so an absent + /// `enabled` field genuinely means "not specified" rather than "typo'd". pub fn is_enabled(&self) -> bool { self.enabled.unwrap_or(true) } diff --git a/src/cli/config/routing.rs b/src/cli/config/routing.rs index d0992450..0538177f 100644 --- a/src/cli/config/routing.rs +++ b/src/cli/config/routing.rs @@ -9,6 +9,7 @@ use super::user::PresetConfig; /// Router configuration #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct RouterConfig { /// Default model for unclassified requests pub default: String, @@ -102,6 +103,7 @@ pub struct FanOutConfig { /// Model configuration with 1:N provider mappings #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct ModelConfig { /// External model name (used in API requests) pub name: String, @@ -203,6 +205,7 @@ pub struct TierMatchCondition { /// When the scoring heuristic classifies a request, the dispatch pipeline /// resolves providers from the matching tier instead of the default model mappings. #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct TierConfig { /// Tier name — must match a `ComplexityTier` variant (case-insensitive). pub name: String, diff --git a/src/cli/config/security.rs b/src/cli/config/security.rs index 27583a9c..c403952e 100644 --- a/src/cli/config/security.rs +++ b/src/cli/config/security.rs @@ -8,6 +8,7 @@ use super::default_true; /// Security configuration (wired into middleware stack) #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct SecurityConfig { /// Master switch for security middleware #[serde(default = "default_true")] diff --git a/src/features/dlp/config.rs b/src/features/dlp/config.rs index f9329651..dc81c20e 100644 --- a/src/features/dlp/config.rs +++ b/src/features/dlp/config.rs @@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize}; /// Top-level DLP configuration, mapped from `[dlp]` in TOML. #[derive(Debug, Clone, Deserialize, Serialize, Default)] +#[serde(deny_unknown_fields)] pub struct DlpConfig { /// Enables the DLP pipeline globally. #[serde(default)] diff --git a/src/features/token_pricing/spend.rs b/src/features/token_pricing/spend.rs index e00ab0b1..eabc6eab 100644 --- a/src/features/token_pricing/spend.rs +++ b/src/features/token_pricing/spend.rs @@ -25,6 +25,10 @@ pub struct BudgetLimits { pub struct BudgetError { /// Human-readable budget exceeded message. pub message: String, + /// Configured budget limit in USD (whichever scope tripped first). + pub limit_usd: f64, + /// Recorded spend in USD at check time. + pub actual_usd: f64, } impl std::fmt::Display for BudgetError { @@ -216,6 +220,8 @@ impl SpendTracker { "Monthly budget for model '{}' reached: ${:.2}/${:.2}", model, spend, limit ), + limit_usd: limit, + actual_usd: spend, }); } } @@ -228,6 +234,8 @@ impl SpendTracker { "Monthly budget for provider '{}' reached: ${:.2}/${:.2}", provider, spend, limit ), + limit_usd: limit, + actual_usd: spend, }); } } @@ -240,6 +248,8 @@ impl SpendTracker { "Monthly global budget reached: ${:.2}/${:.2}", total, global_limit ), + limit_usd: global_limit, + actual_usd: total, }); } } diff --git a/src/models/config.rs b/src/models/config.rs index 63fe5196..1c87c1b5 100644 --- a/src/models/config.rs +++ b/src/models/config.rs @@ -25,6 +25,7 @@ use crate::features::tap::TapConfig; /// Application configuration #[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] pub struct AppConfig { /// Config schema version (for forward compatibility) #[serde(default, skip_serializing_if = "Option::is_none")] diff --git a/src/routing/classify/classify.rs b/src/routing/classify/classify.rs index ebe280eb..38b70b2a 100644 --- a/src/routing/classify/classify.rs +++ b/src/routing/classify/classify.rs @@ -84,6 +84,7 @@ impl Default for ScoringThresholds { /// Scoring configuration combining weights and thresholds. #[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)] +#[serde(deny_unknown_fields)] pub struct ScoringConfig { /// Per-signal weights. pub weights: ScoringWeights, diff --git a/src/security/audit_log.rs b/src/security/audit_log.rs index 04005928..bc1bd154 100644 --- a/src/security/audit_log.rs +++ b/src/security/audit_log.rs @@ -86,6 +86,10 @@ pub enum AuditEvent { HitApproval, /// TEE attestation report generated at startup. TeeAttestation, + /// HTTP request fully processed (emitted by the audit middleware once a + /// response has been produced — covers the entire request lifecycle from + /// authentication through dispatch and error handling). + RequestProcessed, } /// Immutable audit log entry. diff --git a/src/server/budget.rs b/src/server/budget.rs index ef919045..a40eee75 100644 --- a/src/server/budget.rs +++ b/src/server/budget.rs @@ -4,7 +4,7 @@ use crate::providers::AuthType; use std::sync::Arc; use tracing::warn; -use super::{AppError, AppState, ReloadableState}; +use super::{AppState, ReloadableState, RequestError}; /// Maximum retries per provider before falling back to the next mapping. /// NOTE: 2 retries (3 total attempts) balances latency vs resilience — most @@ -62,13 +62,13 @@ pub(crate) fn record_request_metrics(m: &RequestMetrics<'_>) { } } -/// Check budget before a request. Returns Err(AppError::BudgetExceeded) if any limit is hit. +/// Check budget before a request. Returns `Err(RequestError::BudgetExceeded)` if any limit is hit. pub(crate) async fn check_budget( state: &Arc, inner: &Arc, provider_name: &str, model_name: &str, -) -> Result<(), AppError> { +) -> Result<(), RequestError> { let budget_config = &inner.config.budget; let global_limit = budget_config.monthly_limit_usd.value(); @@ -92,7 +92,10 @@ pub(crate) async fn check_budget( provider_limit, model_limit, ) { - return Err(AppError::BudgetExceeded(e.message)); + return Err(RequestError::BudgetExceeded { + limit_usd: e.limit_usd, + actual_usd: e.actual_usd, + }); } if let Some(warning) = tracker.check_warnings( diff --git a/src/server/config_api.rs b/src/server/config_api.rs index eb1cd5cd..03bb7059 100644 --- a/src/server/config_api.rs +++ b/src/server/config_api.rs @@ -10,7 +10,7 @@ use std::sync::Arc; use tracing::{error, info, warn}; use super::config_guard::is_section_or_key_denied; -use super::{AppError, AppState, ReloadableState}; +use super::{AppState, ReloadableState, RequestError}; /// Redact an API key for safe display (show first 4 + last 4 chars) pub(crate) fn redact_api_key(key: &str) -> String { @@ -89,7 +89,7 @@ pub(crate) async fn get_config_json(State(state): State>) -> impl pub(crate) async fn update_config_json( State(state): State>, Json(mut new_config): Json, -) -> Result, AppError> { +) -> Result, RequestError> { // Remove null values (TOML doesn't support null) remove_null_values(&mut new_config); @@ -99,7 +99,7 @@ pub(crate) async fn update_config_json( // Whole-section deny check (providers, dlp). if is_section_or_key_denied(section, "") { warn!(section = %section, "config API: denied write to protected section"); - return Err(AppError::ParseError(format!( + return Err(RequestError::Forbidden(format!( "denied: section '{}' cannot be modified via the config API", section ))); @@ -109,7 +109,7 @@ pub(crate) async fn update_config_json( for key in inner.keys() { if is_section_or_key_denied(section, key) { warn!(section = %section, key = %key, "config API: denied write to protected key"); - return Err(AppError::ParseError(format!( + return Err(RequestError::Forbidden(format!( "denied: {}.{} cannot be modified via the config API", section, key ))); @@ -123,7 +123,7 @@ pub(crate) async fn update_config_json( let config_path = match &state.config_source { crate::cli::ConfigSource::File(p) => p, crate::cli::ConfigSource::Url(_) => { - return Err(AppError::ParseError( + return Err(RequestError::BadRequest( "Cannot save config: loaded from remote URL (read-only)".to_string(), )); } @@ -132,15 +132,15 @@ pub(crate) async fn update_config_json( // Read current config and merge the incoming JSON updates into it. let config_str = tokio::fs::read_to_string(config_path) .await - .map_err(|e| AppError::ParseError(format!("Failed to read config: {e}")))?; + .map_err(|e| RequestError::Internal(anyhow::anyhow!("Failed to read config: {e}")))?; let mut config: toml::Value = toml::from_str(&config_str) - .map_err(|e| AppError::ParseError(format!("Failed to parse config: {e}")))?; + .map_err(|e| RequestError::ParseError(format!("Failed to parse config: {e}")))?; // Update providers section if let Some(providers) = new_config.get("providers") { let providers_toml: toml::Value = serde_json::from_str(&providers.to_string()) - .map_err(|e| AppError::ParseError(format!("Failed to convert providers: {e}")))?; + .map_err(|e| RequestError::ParseError(format!("Failed to convert providers: {e}")))?; if let Some(table) = config.as_table_mut() { table.insert("providers".to_string(), providers_toml); @@ -150,7 +150,7 @@ pub(crate) async fn update_config_json( // Update models section if let Some(models) = new_config.get("models") { let models_toml: toml::Value = serde_json::from_str(&models.to_string()) - .map_err(|e| AppError::ParseError(format!("Failed to convert models: {e}")))?; + .map_err(|e| RequestError::ParseError(format!("Failed to convert models: {e}")))?; if let Some(table) = config.as_table_mut() { table.insert("models".to_string(), models_toml); @@ -192,9 +192,9 @@ pub(crate) async fn update_config_json( // Deserialise the merged TOML into AppConfig so we can validate and reload. let merged_toml_str = toml::to_string_pretty(&config) - .map_err(|e| AppError::ParseError(format!("Failed to serialize config: {e}")))?; + .map_err(|e| RequestError::Internal(anyhow::anyhow!("Failed to serialize config: {e}")))?; let merged_config: crate::models::config::AppConfig = toml::from_str(&merged_toml_str) - .map_err(|e| AppError::ParseError(format!("Invalid config after merge: {e}")))?; + .map_err(|e| RequestError::ParseError(format!("Invalid config after merge: {e}")))?; // Backup, write, and hot-reload via the shared pipeline. super::config_guard::persist_and_reload(&state, &merged_config).await?; diff --git a/src/server/config_guard.rs b/src/server/config_guard.rs index b0a8132f..667d02fc 100644 --- a/src/server/config_guard.rs +++ b/src/server/config_guard.rs @@ -12,29 +12,76 @@ use std::sync::Arc; use tracing::info; /// Top-level TOML sections that are never writable via any config API. -const DENIED_SECTIONS: &[&str] = &["providers", "dlp"]; +/// +/// Each entry is denied because hot-reloading it cannot be done safely +/// at runtime — either the data is sensitive (and must travel through a +/// dedicated secret API), or the code path that consumes it is set up +/// once at process start and not re-initialised on `/api/config/reload`: +/// +/// | Section | Reason | +/// |-------------|-----------------------------------------------------------------------------------------| +/// | `providers` | Contains API keys; mutate via `grob connect` / secret backend, not the config API. | +/// | `dlp` | Security policy must not be weakened by an authenticated control-plane caller. | +/// | `tee` | TEE attestation runs at startup; flipping the mode mid-flight bypasses the gate. | +/// | `fips` | FIPS mode is checked once on init; toggling at runtime gives a false sense of compliance. | +/// +/// To change any of these the operator must edit `~/.grob/config.toml` +/// and restart the daemon. +const DENIED_SECTIONS: &[&str] = &["providers", "dlp", "tee", "fips"]; /// Per-section keys that are never writable via any config API. +/// +/// These are individual fields whose host section is otherwise editable, +/// but the field itself is either credential material or wired into a +/// non-reloadable subsystem: +/// +/// | Section.Key | Reason | +/// |--------------------|---------------------------------------------------------------------------------| +/// | `router.api_key` | Credential material — never round-trip through the config API. | +/// | `budget.api_key` | Same. | +/// | `cache.api_key` | Same. | +/// | `server.tls` | TLS listener is bound at startup; rebuilding it requires a daemon restart. | +/// | `secrets.backend` | The secret backend is constructed once and shared via `Arc`; swapping it at | +/// | | runtime would orphan in-flight readers and change credential resolution semantics. | const DENIED_KEYS: &[(&str, &str)] = &[ ("router", "api_key"), ("budget", "api_key"), ("cache", "api_key"), + ("server", "tls"), + ("secrets", "backend"), ]; /// Checks whether a (section, key) pair is blocked by the deny-list. /// -/// Returns `true` when the write must be rejected: -/// - The entire `providers` section (contains API keys). -/// - The entire `dlp` section (security must not be weakened). -/// - Any `api_key` field in any section. +/// Returns `true` when the write must be rejected. See [`DENIED_SECTIONS`] +/// and [`DENIED_KEYS`] for the rationale behind every entry. A denied +/// attempt is logged at INFO so the operator sees actionable guidance +/// (restart instead of expecting a silent reload to take effect). pub fn is_section_or_key_denied(section: &str, key: &str) -> bool { if DENIED_SECTIONS.contains(§ion) { + info!( + section = %section, + "config hot-reload: section is on the deny-list; restart the daemon to apply changes" + ); return true; } if key == "api_key" { + info!( + section = %section, + key = %key, + "config hot-reload: api_key fields cannot be set via the config API; use `grob connect` or the secret backend" + ); + return true; + } + if DENIED_KEYS.iter().any(|(s, k)| *s == section && *k == key) { + info!( + section = %section, + key = %key, + "config hot-reload: key is on the deny-list; restart the daemon to apply changes" + ); return true; } - DENIED_KEYS.iter().any(|(s, k)| *s == section && *k == key) + false } /// Validates a key update against the deny-list using [`ConfigSection`]. @@ -67,11 +114,11 @@ pub fn is_key_denied(section: &ConfigSection, key: &str) -> bool { pub async fn persist_and_reload( state: &Arc, config: &crate::models::config::AppConfig, -) -> Result<(), super::AppError> { +) -> Result<(), super::RequestError> { let config_path = match &state.config_source { crate::cli::ConfigSource::File(p) => p, crate::cli::ConfigSource::Url(_) => { - return Err(super::AppError::ParseError( + return Err(super::RequestError::BadRequest( "Cannot save config: loaded from remote URL (read-only)".to_string(), )); } @@ -81,15 +128,18 @@ pub async fn persist_and_reload( let backup_path = config_path.with_extension("toml.backup"); tokio::fs::copy(config_path, &backup_path) .await - .map_err(|e| super::AppError::ParseError(format!("Failed to create backup: {e}")))?; + .map_err(|e| { + super::RequestError::Internal(anyhow::anyhow!("Failed to create backup: {e}")) + })?; // 2. Serialise and write - let toml_str = toml::to_string_pretty(config) - .map_err(|e| super::AppError::ParseError(format!("Failed to serialize config: {e}")))?; + let toml_str = toml::to_string_pretty(config).map_err(|e| { + super::RequestError::Internal(anyhow::anyhow!("Failed to serialize config: {e}")) + })?; - tokio::fs::write(config_path, toml_str) - .await - .map_err(|e| super::AppError::ParseError(format!("Failed to write config: {e}")))?; + tokio::fs::write(config_path, toml_str).await.map_err(|e| { + super::RequestError::Internal(anyhow::anyhow!("Failed to write config: {e}")) + })?; // 3. Hot-reload: rebuild router + provider registry from the new config reload_state(state, config.clone(), config_path)?; @@ -109,7 +159,7 @@ fn reload_state( state: &Arc, config: crate::models::config::AppConfig, _config_path: &Path, -) -> Result<(), super::AppError> { +) -> Result<(), super::RequestError> { let new_router = crate::routing::classify::Router::new(config.clone()); let secret_backend = @@ -123,7 +173,7 @@ fn reload_state( &config.server.timeouts, ) .map_err(|e| { - super::AppError::ProviderError(format!("Failed to rebuild provider registry: {e}")) + super::RequestError::Internal(anyhow::anyhow!("Failed to rebuild provider registry: {e}")) })?; let new_inner = Arc::new(super::ReloadableState::new( @@ -177,6 +227,27 @@ mod tests { assert!(!is_section_or_key_denied("cache", "ttl_secs")); } + #[test] + fn deny_static_init_sections() { + // tee and fips are checked once at startup; toggling them at runtime + // would bypass the gate without the operator realising. + assert!(is_section_or_key_denied("tee", "mode")); + assert!(is_section_or_key_denied("tee", "sealed_keys")); + assert!(is_section_or_key_denied("fips", "mode")); + assert!(is_section_or_key_denied("fips", "anything")); + } + + #[test] + fn deny_static_init_keys() { + // The TLS listener and secret backend are constructed once on + // process start; both require a daemon restart to swap. + assert!(is_section_or_key_denied("server", "tls")); + assert!(is_section_or_key_denied("secrets", "backend")); + // Sibling keys in the same sections must remain editable. + assert!(!is_section_or_key_denied("server", "host")); + assert!(!is_section_or_key_denied("server", "port")); + } + #[cfg(feature = "mcp")] mod mcp_compat { use super::*; diff --git a/src/server/dispatch/mod.rs b/src/server/dispatch/mod.rs index ffb8c64e..940cac10 100644 --- a/src/server/dispatch/mod.rs +++ b/src/server/dispatch/mod.rs @@ -23,8 +23,8 @@ use std::sync::Arc; use super::{ calculate_cost, is_provider_subscription, log_audit, record_request_metrics, - resolve_provider_mappings, sanitize_provider_response_reported, AppError, AppState, - AuditCompliance, AuditParams, ReloadableState, RequestMetrics, + resolve_provider_mappings, sanitize_provider_response_reported, AppState, AuditCompliance, + AuditParams, ReloadableState, RequestError, RequestMetrics, }; use crate::features::watch::events::{DlpDirection, WatchEvent}; @@ -46,6 +46,9 @@ pub(crate) struct DispatchContext<'a> { pub headers: &'a HeaderMap, /// Message tracer context. None for OpenAI compat endpoint. pub trace_id: Option, + /// Audit-emitted flag — flipped by `log_audit_if_enabled` so the + /// outer audit middleware can skip writing a duplicate entry. + pub audited: std::sync::Arc, /// Resolved policy for this request (when policies feature is enabled). #[cfg(feature = "policies")] #[allow(dead_code)] @@ -169,6 +172,9 @@ impl DispatchContext<'_> { dlp_had_pii: entry.dlp_had_pii, dlp_had_redact_or_warn: entry.dlp_had_redact_or_warn, }); + // Flag so the outer audit middleware skips a duplicate entry. + self.audited + .store(true, std::sync::atomic::Ordering::Release); } } } @@ -244,7 +250,7 @@ pub(crate) fn resolve_grob_hint( pub(crate) async fn dispatch( ctx: &DispatchContext<'_>, request: &mut CanonicalRequest, -) -> Result { +) -> Result { // ── Step 0: Resolve complexity hint ── // Resolved up-front (borrows `request` immutably) but applied post-routing // so the client-declared tier overrides the algorithmic scorer. @@ -291,7 +297,7 @@ pub(crate) async fn dispatch( .inner .router .route(request) - .map_err(|e| AppError::RoutingError(e.to_string()))?; + .map_err(|e| RequestError::RoutingError(e.to_string()))?; // ── Step 3.5: Apply client-declared complexity hint ── // The hint (header / body metadata / MCP one-shot) overrides whatever tier @@ -398,7 +404,7 @@ async fn check_cache( fn scan_dlp_input( ctx: &DispatchContext<'_>, request: &mut CanonicalRequest, -) -> Result<(), AppError> { +) -> Result<(), RequestError> { let Some(ref dlp_engine) = ctx.dlp else { return Ok(()); }; @@ -476,7 +482,7 @@ fn scan_dlp_input( dlp_had_pii: false, dlp_had_redact_or_warn: false, }); - Err(AppError::DlpBlocked(format!("{}", block_err))) + Err(RequestError::DlpBlocked(format!("{}", block_err))) } } } @@ -488,7 +494,7 @@ async fn dispatch_fan_out( sorted_mappings: &[crate::cli::ModelMapping], fan_out_config: &crate::cli::FanOutConfig, decision: &crate::models::RouteDecision, -) -> Result { +) -> Result { let mut fan_request = request.clone(); ctx.sanitize_input(&mut fan_request); @@ -503,7 +509,11 @@ async fn dispatch_fan_out( Ok((response, provider_info)) => { handle_fan_out_success(ctx, response, &provider_info, decision).await } - Err(e) => Err(AppError::ProviderError(format!("Fan-out failed: {}", e))), + Err(e) => Err(RequestError::ProviderUpstream { + provider: "fan_out".to_string(), + status: 502, + body: Some(format!("Fan-out failed: {}", e)), + }), } } @@ -513,7 +523,7 @@ async fn handle_fan_out_success( mut response: ProviderResponse, provider_info: &[(String, String)], decision: &crate::models::RouteDecision, -) -> Result { +) -> Result { ctx.sanitize_output(&mut response); let latency_ms = ctx.start_time.elapsed().as_millis() as u64; diff --git a/src/server/dispatch/provider_loop.rs b/src/server/dispatch/provider_loop.rs index 9b441d19..773be8fc 100644 --- a/src/server/dispatch/provider_loop.rs +++ b/src/server/dispatch/provider_loop.rs @@ -17,11 +17,11 @@ //! //! After the loop exhausts the list, [`resolver::try_direct_provider_lookup`] //! offers a backward-compat path for unmapped models. A final audit entry -//! is written before returning `AppError::ProviderError`. +//! is written before returning `RequestError::ProviderUpstream`. use super::super::{ check_budget, format_route_type, inject_continuation_text, is_provider_subscription, - should_inject_continuation, AppError, + should_inject_continuation, RequestError, }; use super::resolver::{resolve_provider, try_direct_provider_lookup}; use super::retry::{ @@ -39,7 +39,7 @@ pub(super) async fn dispatch_provider_loop( sorted_mappings: &[crate::cli::ModelMapping], decision: &crate::models::RouteDecision, cache_key: &Option, -) -> Result { +) -> Result { // Re-sort mappings by adaptive score when scorer is enabled let rescored; let effective_mappings: &[crate::cli::ModelMapping] = @@ -156,7 +156,7 @@ pub(super) async fn dispatch_provider_loop( ); // Abort the fallback cascade: this is a user-actionable error, // not a transient provider failure. - return Err(AppError::AuthenticationError(format!( + return Err(RequestError::AuthRevoked(format!( "OAuth token for provider '{}' revoked. Run: grob connect --force-reauth. Details: {}", mapping.provider, msg ))); @@ -189,11 +189,15 @@ pub(super) async fn dispatch_provider_loop( "All provider mappings failed for model: {}", decision.model_name ); - Err(AppError::ProviderError(format!( - "All {} provider mappings failed for model: {}", - effective_mappings.len(), - decision.model_name - ))) + Err(RequestError::ProviderUpstream { + provider: "all".to_string(), + status: 502, + body: Some(format!( + "All {} provider mappings failed for model: {}", + effective_mappings.len(), + decision.model_name + )), + }) } /// Log the dispatch attempt info line (route type, stream mode, model -> provider). diff --git a/src/server/dispatch/resolver.rs b/src/server/dispatch/resolver.rs index 95e0cf54..6a5cfc6e 100644 --- a/src/server/dispatch/resolver.rs +++ b/src/server/dispatch/resolver.rs @@ -8,7 +8,7 @@ use std::sync::Arc; -use super::super::AppError; +use super::super::RequestError; use super::{DispatchContext, DispatchResult}; use tracing::info; @@ -101,7 +101,7 @@ pub(super) async fn try_direct_provider_lookup( ctx: &DispatchContext<'_>, request: &crate::models::CanonicalRequest, model_name: &str, -) -> Result, AppError> { +) -> Result, RequestError> { let Ok(provider) = ctx.inner.provider_registry.provider_for_model(model_name) else { return Ok(None); }; @@ -116,7 +116,7 @@ pub(super) async fn try_direct_provider_lookup( let mut response = provider .send_message(fallback_request) .await - .map_err(|e| AppError::ProviderError(e.to_string()))?; + .map_err(RequestError::from)?; response.model = original_model; Ok(Some(DispatchResult::Complete { diff --git a/src/server/dispatch/retry.rs b/src/server/dispatch/retry.rs index aa7fd690..8bf3df03 100644 --- a/src/server/dispatch/retry.rs +++ b/src/server/dispatch/retry.rs @@ -43,16 +43,31 @@ pub(super) struct ProviderAttempt<'a> { pub is_subscription: bool, } +/// Returns `true` when a provider error reports a 429 rate-limit upstream. +/// +/// Defers to the `RequestError::RateLimited` mapping rules so the +/// classification logic lives in one place: a 429 status code OR a 401 with a +/// `rate_limit_error` payload (Anthropic-style). Callers that need to know +/// specifically whether they hit a 429 (e.g. to rotate a key pool) consult +/// this helper rather than re-implement the matcher. +pub(super) fn is_upstream_rate_limit(e: &crate::providers::error::ProviderError) -> bool { + use crate::providers::error::ProviderError; + match e { + ProviderError::ApiError { status: 429, .. } => true, + ProviderError::ApiError { + status: 401, + message, + } => super::super::budget::is_rate_limit_payload(message), + _ => false, + } +} + /// Emit shared provider-error metrics (rate-limit counter + error counter). fn emit_provider_error_metrics( mapping: &crate::cli::ModelMapping, e: &crate::providers::error::ProviderError, ) { - let is_rate_limit = matches!( - e, - crate::providers::error::ProviderError::ApiError { status: 429, .. } - ); - if is_rate_limit { + if is_upstream_rate_limit(e) { warn!("Provider {} rate limited", mapping.provider); metrics::counter!( "grob_ratelimit_hits_total", @@ -177,11 +192,7 @@ pub(super) async fn dispatch_streaming( if is_auth_revoked_error(&e) { return Err(ProviderLoopAction::AuthRevoked(e.to_string())); } - let is_rate_limit = matches!( - e, - crate::providers::error::ProviderError::ApiError { status: 429, .. } - ); - if is_rate_limit { + if is_upstream_rate_limit(&e) { Err(ProviderLoopAction::RateLimited) } else { Err(ProviderLoopAction::Continue) @@ -297,11 +308,7 @@ pub(super) async fn dispatch_non_streaming( if classify_and_handle_error(ctx, attempt.mapping, &e, retry) { // On 429, try rotating to next pooled key before retrying. - let is_rate_limit = matches!( - e, - crate::providers::error::ProviderError::ApiError { status: 429, .. } - ); - if is_rate_limit && provider.rotate_key_pool() { + if is_upstream_rate_limit(&e) && provider.rotate_key_pool() { info!( "Provider {} rate-limited, rotated to next pooled key", attempt.mapping.provider @@ -315,11 +322,7 @@ pub(super) async fn dispatch_non_streaming( } // Before giving up on this provider, try key rotation for 429. - let is_rate_limit = matches!( - e, - crate::providers::error::ProviderError::ApiError { status: 429, .. } - ); - if is_rate_limit && provider.rotate_key_pool() { + if is_upstream_rate_limit(&e) && provider.rotate_key_pool() { info!( "Provider {} exhausted retries but rotated to next pooled key", attempt.mapping.provider diff --git a/src/server/error.rs b/src/server/error.rs index 2617c99e..d03cf5e3 100644 --- a/src/server/error.rs +++ b/src/server/error.rs @@ -1,86 +1,371 @@ +//! Unified request-level error taxonomy. +//! +//! Single source of truth for HTTP error responses. Replaces the old +//! `AppError` + `ProviderError` split that masked upstream HTTP status codes +//! behind a generic `502 Bad Gateway` body. +//! +//! Every variant maps to a precise HTTP status; the ECMA RFC-9457 inspired +//! body shape is `{ "error": { "type": ..., "message": ..., ...extras } }`. +//! +//! `is_retryable()` is the authoritative classifier for retry/backoff +//! logic — `dispatch/retry.rs` and the provider loop must consult this +//! single method rather than re-implement status-code matching. + use axum::{ http::StatusCode, response::{IntoResponse, Response}, Json, }; -/// Application error types — all variants carry a user-facing message string. +/// Unified error type for the request pipeline. +/// +/// Carries the upstream HTTP status when applicable so the client sees the +/// exact failure mode (`Bad Gateway`, `Service Unavailable`, …) rather than +/// an opaque "Provider error" string. #[derive(Debug)] -pub enum AppError { - /// Indicates no matching route or model for the request. - RoutingError(String), - /// Indicates a malformed or invalid request payload. +pub enum RequestError { + /// Indicates a malformed or invalid request payload (HTTP 400). + BadRequest(String), + /// Indicates the caller failed authentication (HTTP 401). + Unauthorized, + /// Indicates the caller is authenticated but the action is forbidden (HTTP 403). + Forbidden(String), + /// Indicates the requested resource does not exist (HTTP 404). + NotFound, + /// Indicates a JSON parse or schema validation failure (HTTP 400). ParseError(String), - /// Indicates an upstream provider returned an error. - ProviderError(String), - /// Indicates the monthly spend budget has been exceeded. - BudgetExceeded(String), - /// Indicates the DLP pipeline blocked the request. + /// Indicates routing could not resolve a model or provider (HTTP 400). + RoutingError(String), + /// Indicates the upstream provider rate-limited the request (HTTP 429). + RateLimited { + /// Provider that emitted the 429 (or the resolved alias). + provider: String, + /// Server-side hint for when to retry, when known. + retry_after_ms: Option, + }, + /// Indicates the upstream provider returned a non-success status. + /// + /// The original status is forwarded verbatim so the client sees the + /// actual failure mode (502/503/504/etc.) instead of a flattened error. + ProviderUpstream { + /// Provider name (e.g. `"anthropic"`). + provider: String, + /// Verbatim upstream HTTP status code. + status: u16, + /// Optional upstream body excerpt for diagnostics. + body: Option, + }, + /// Indicates a budget cap (global, provider, or model) was exceeded (HTTP 402). + BudgetExceeded { + /// Configured monthly limit in USD. + limit_usd: f64, + /// Actual recorded spend in USD at the time of the check. + actual_usd: f64, + }, + /// Indicates the DLP pipeline blocked the request (HTTP 400). DlpBlocked(String), - /// Indicates an upstream OAuth token is revoked or invalid (401 authentication_error). + /// Indicates an upstream OAuth credential was revoked (HTTP 401). /// - /// Surfaced to the client as a terminal 401 without fallback to sibling providers. - AuthenticationError(String), + /// Surfaces a terminal authentication error — the user must run + /// `grob connect --force-reauth`. Distinct from `Unauthorized`, which + /// covers the inbound caller's credential rather than an upstream's. + AuthRevoked(String), + /// Indicates an internal server failure (HTTP 500). + Internal(anyhow::Error), } -impl IntoResponse for AppError { - fn into_response(self) -> Response { - let (status, error_type, message) = match self { - AppError::RoutingError(msg) => (StatusCode::BAD_REQUEST, "error", msg), - AppError::ParseError(msg) => (StatusCode::BAD_REQUEST, "invalid_request_error", msg), - AppError::ProviderError(msg) => (StatusCode::BAD_GATEWAY, "error", msg), - AppError::BudgetExceeded(msg) => (StatusCode::PAYMENT_REQUIRED, "budget_exceeded", msg), - AppError::DlpBlocked(msg) => (StatusCode::BAD_REQUEST, "dlp_block", msg), - AppError::AuthenticationError(msg) => { - (StatusCode::UNAUTHORIZED, "authentication_error", msg) +impl RequestError { + /// Returns `true` when the error is transient and the dispatch loop should + /// retry (with exponential backoff) before falling back to the next provider. + /// + /// This is the SINGLE source of truth for retry classification — both the + /// retry loop and the rate-limit detector must call this method rather than + /// duplicate the status-code matching logic. + /// + /// Notably **excludes** `AuthRevoked` (a permanent 401 requires operator + /// action) but **includes** `RateLimited` and 5xx upstream failures. + pub fn is_retryable(&self) -> bool { + match self { + RequestError::RateLimited { .. } => true, + RequestError::ProviderUpstream { status, .. } => { + matches!(*status, 429 | 500 | 502 | 503 | 504) } - }; + // Network/transport failures bubble up here too; treat as retryable + // when the underlying error chain wraps a `reqwest::Error`. + RequestError::Internal(err) => err.downcast_ref::().is_some(), + _ => false, + } + } + + /// Returns the HTTP status code, error type tag, and message for this variant. + fn parts(&self) -> (StatusCode, &'static str, String) { + match self { + RequestError::BadRequest(msg) => ( + StatusCode::BAD_REQUEST, + "invalid_request_error", + msg.clone(), + ), + RequestError::Unauthorized => ( + StatusCode::UNAUTHORIZED, + "authentication_error", + "Missing or invalid credentials".to_string(), + ), + RequestError::Forbidden(msg) => { + (StatusCode::FORBIDDEN, "permission_error", msg.clone()) + } + RequestError::NotFound => ( + StatusCode::NOT_FOUND, + "not_found_error", + "Resource not found".to_string(), + ), + RequestError::ParseError(msg) => ( + StatusCode::BAD_REQUEST, + "invalid_request_error", + msg.clone(), + ), + RequestError::RoutingError(msg) => (StatusCode::BAD_REQUEST, "error", msg.clone()), + RequestError::RateLimited { provider, .. } => ( + StatusCode::TOO_MANY_REQUESTS, + "rate_limit_error", + format!("Provider '{}' rate-limited the request", provider), + ), + RequestError::ProviderUpstream { + provider, + status, + body, + } => { + // Forward the upstream status verbatim so the client sees the + // exact failure mode. Default to 502 for unmapped non-success codes. + let status_code = StatusCode::from_u16(*status).unwrap_or(StatusCode::BAD_GATEWAY); + let msg = body + .clone() + .unwrap_or_else(|| format!("Provider '{}' returned HTTP {}", provider, status)); + (status_code, "error", msg) + } + RequestError::BudgetExceeded { + limit_usd, + actual_usd, + } => ( + StatusCode::PAYMENT_REQUIRED, + "budget_exceeded", + format!( + "Budget exceeded: ${:.4} spent of ${:.4} limit", + actual_usd, limit_usd + ), + ), + RequestError::DlpBlocked(msg) => (StatusCode::BAD_REQUEST, "dlp_block", msg.clone()), + RequestError::AuthRevoked(msg) => ( + StatusCode::UNAUTHORIZED, + "authentication_error", + msg.clone(), + ), + RequestError::Internal(err) => { + (StatusCode::INTERNAL_SERVER_ERROR, "error", err.to_string()) + } + } + } + + /// Returns a stable string tag for the error variant — used in audit logs + /// and metrics labels (low-cardinality alternative to the message). + pub fn variant_tag(&self) -> &'static str { + match self { + RequestError::BadRequest(_) => "bad_request", + RequestError::Unauthorized => "unauthorized", + RequestError::Forbidden(_) => "forbidden", + RequestError::NotFound => "not_found", + RequestError::ParseError(_) => "parse_error", + RequestError::RoutingError(_) => "routing_error", + RequestError::RateLimited { .. } => "rate_limited", + RequestError::ProviderUpstream { .. } => "provider_upstream", + RequestError::BudgetExceeded { .. } => "budget_exceeded", + RequestError::DlpBlocked(_) => "dlp_blocked", + RequestError::AuthRevoked(_) => "auth_revoked", + RequestError::Internal(_) => "internal", + } + } +} - let body = Json(serde_json::json!({ +impl IntoResponse for RequestError { + fn into_response(self) -> Response { + let (status, error_type, message) = self.parts(); + + let mut body_obj = serde_json::json!({ "error": { "type": error_type, - "message": message + "message": message, + } + }); + + // Attach variant-specific extras (retry-after for rate limits, + // budget figures for budget overruns, upstream provider for 5xx). + match &self { + RequestError::RateLimited { + provider, + retry_after_ms, + } => { + if let Some(error) = body_obj.get_mut("error").and_then(|v| v.as_object_mut()) { + error.insert( + "provider".to_string(), + serde_json::Value::String(provider.clone()), + ); + if let Some(ms) = retry_after_ms { + error.insert("retry_after_ms".to_string(), serde_json::Value::from(*ms)); + } + } + } + RequestError::ProviderUpstream { + provider, status, .. + } => { + if let Some(error) = body_obj.get_mut("error").and_then(|v| v.as_object_mut()) { + error.insert( + "provider".to_string(), + serde_json::Value::String(provider.clone()), + ); + error.insert( + "upstream_status".to_string(), + serde_json::Value::from(*status), + ); + } + } + RequestError::BudgetExceeded { + limit_usd, + actual_usd, + } => { + if let Some(error) = body_obj.get_mut("error").and_then(|v| v.as_object_mut()) { + error.insert("limit_usd".to_string(), serde_json::Value::from(*limit_usd)); + error.insert( + "actual_usd".to_string(), + serde_json::Value::from(*actual_usd), + ); + } } - })); + _ => {} + } + + let mut response = (status, Json(body_obj)).into_response(); + // Mark response so the audit middleware can pick up the variant + // without re-parsing the body. + response + .extensions_mut() + .insert(ErrorVariantTag(self.variant_tag().to_string())); + + // Forward Retry-After header on 429 when known. + if let RequestError::RateLimited { + retry_after_ms: Some(ms), + .. + } = &self + { + // RFC 7231 Retry-After is in seconds; round up so we never advise + // a delay shorter than the upstream actually requested. + let secs = ms.div_ceil(1000).max(1); + if let Ok(value) = axum::http::HeaderValue::from_str(&secs.to_string()) { + response.headers_mut().insert("retry-after", value); + } + } - (status, body).into_response() + response } } -impl std::fmt::Display for AppError { +/// Marker stored in response extensions so middleware can read the error +/// variant tag without parsing the body. Value is the lowercase variant tag. +#[derive(Clone, Debug)] +pub struct ErrorVariantTag(pub String); + +impl std::fmt::Display for RequestError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let (_, _, message) = self.parts(); + write!(f, "{}: {}", self.variant_tag(), message) + } +} + +impl std::error::Error for RequestError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { match self { - AppError::RoutingError(msg) => write!(f, "Routing error: {}", msg), - AppError::ParseError(msg) => write!(f, "Parse error: {}", msg), - AppError::ProviderError(msg) => write!(f, "Provider error: {}", msg), - AppError::BudgetExceeded(msg) => write!(f, "Budget exceeded: {}", msg), - AppError::DlpBlocked(msg) => write!(f, "DLP blocked: {}", msg), - AppError::AuthenticationError(msg) => write!(f, "Authentication error: {}", msg), + RequestError::Internal(err) => Some(err.as_ref()), + _ => None, } } } -impl std::error::Error for AppError {} +impl From for RequestError { + fn from(err: anyhow::Error) -> Self { + RequestError::Internal(err) + } +} + +impl From for RequestError { + fn from(err: crate::providers::error::ProviderError) -> Self { + use crate::providers::error::ProviderError; + match err { + ProviderError::ApiError { status, message } => match status { + 429 => RequestError::RateLimited { + provider: "upstream".to_string(), + retry_after_ms: None, + }, + 401 => { + if super::budget::is_rate_limit_payload(&message) { + // Anthropic emits `rate_limit_error` with HTTP 401 — treat + // as a transient rate-limit, not a revoked credential. + RequestError::RateLimited { + provider: "upstream".to_string(), + retry_after_ms: None, + } + } else { + RequestError::AuthRevoked(message) + } + } + _ => RequestError::ProviderUpstream { + provider: "upstream".to_string(), + status, + body: Some(message), + }, + }, + ProviderError::HttpError(e) => { + RequestError::Internal(anyhow::Error::new(e).context("HTTP request failed")) + } + ProviderError::SerializationError(e) => { + RequestError::ParseError(format!("serialization failed: {}", e)) + } + ProviderError::ModelNotSupported(model) => RequestError::RoutingError(format!( + "Model '{}' is not configured. Add a [[models]] entry or set pass_through = true on a provider.", + model + )), + ProviderError::ConfigError(msg) => { + RequestError::Internal(anyhow::anyhow!("Provider config error: {}", msg)) + } + ProviderError::AuthError(msg) => RequestError::AuthRevoked(msg), + ProviderError::NoProviderAvailable => { + RequestError::RoutingError("No provider available for this request".to_string()) + } + ProviderError::AllProvidersFailed(msg) => RequestError::ProviderUpstream { + provider: "all".to_string(), + status: 502, + body: Some(msg), + }, + } + } +} #[cfg(test)] mod tests { use super::*; - /// Extracts status code and parsed JSON body from an AppError response. - async fn error_response_parts(error: AppError) -> (StatusCode, serde_json::Value) { + /// Extracts status code and parsed JSON body from a `RequestError` response. + async fn error_response_parts(error: RequestError) -> (StatusCode, serde_json::Value) { let response = error.into_response(); let status = response.status(); let body_bytes = axum::body::to_bytes(response.into_body(), 1024 * 1024) .await .expect("invariant: in-memory body collection cannot fail"); let json: serde_json::Value = serde_json::from_slice(&body_bytes) - .expect("invariant: AppError always produces valid JSON"); + .expect("invariant: RequestError always produces valid JSON"); (status, json) } #[tokio::test] async fn parse_error_returns_400_with_invalid_request_type() { - let err = AppError::ParseError("invalid JSON at line 1".to_string()); + let err = RequestError::ParseError("invalid JSON at line 1".to_string()); let (status, json) = error_response_parts(err).await; assert_eq!(status, StatusCode::BAD_REQUEST); @@ -88,9 +373,40 @@ mod tests { assert_eq!(json["error"]["message"], "invalid JSON at line 1"); } + #[tokio::test] + async fn bad_request_returns_400() { + let err = RequestError::BadRequest("missing required field 'model'".to_string()); + let (status, json) = error_response_parts(err).await; + assert_eq!(status, StatusCode::BAD_REQUEST); + assert_eq!(json["error"]["type"], "invalid_request_error"); + } + + #[tokio::test] + async fn unauthorized_returns_401() { + let err = RequestError::Unauthorized; + let (status, json) = error_response_parts(err).await; + assert_eq!(status, StatusCode::UNAUTHORIZED); + assert_eq!(json["error"]["type"], "authentication_error"); + } + + #[tokio::test] + async fn forbidden_returns_403() { + let err = RequestError::Forbidden("policy denies model access".to_string()); + let (status, json) = error_response_parts(err).await; + assert_eq!(status, StatusCode::FORBIDDEN); + assert_eq!(json["error"]["type"], "permission_error"); + } + + #[tokio::test] + async fn not_found_returns_404() { + let err = RequestError::NotFound; + let (status, _) = error_response_parts(err).await; + assert_eq!(status, StatusCode::NOT_FOUND); + } + #[tokio::test] async fn routing_error_returns_400_with_error_type() { - let err = AppError::RoutingError("no matching model: gpt-unknown".to_string()); + let err = RequestError::RoutingError("no matching model: gpt-unknown".to_string()); let (status, json) = error_response_parts(err).await; assert_eq!(status, StatusCode::BAD_REQUEST); @@ -99,28 +415,72 @@ mod tests { } #[tokio::test] - async fn provider_error_returns_502() { - let err = AppError::ProviderError("upstream timeout".to_string()); - let (status, json) = error_response_parts(err).await; + async fn rate_limited_returns_429_with_retry_after() { + let err = RequestError::RateLimited { + provider: "anthropic".to_string(), + retry_after_ms: Some(2500), + }; + let response = err.into_response(); + let status = response.status(); + let retry_after = response + .headers() + .get("retry-after") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + let body_bytes = axum::body::to_bytes(response.into_body(), 1024 * 1024) + .await + .unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); + + assert_eq!(status, StatusCode::TOO_MANY_REQUESTS); + assert_eq!(json["error"]["type"], "rate_limit_error"); + assert_eq!(json["error"]["provider"], "anthropic"); + assert_eq!(json["error"]["retry_after_ms"], 2500); + // 2500ms rounds up to 3 seconds. + assert_eq!(retry_after.as_deref(), Some("3")); + } + #[tokio::test] + async fn provider_upstream_502_forwards_status() { + let err = RequestError::ProviderUpstream { + provider: "openai".to_string(), + status: 502, + body: Some("upstream gateway timeout".to_string()), + }; + let (status, json) = error_response_parts(err).await; assert_eq!(status, StatusCode::BAD_GATEWAY); - assert_eq!(json["error"]["type"], "error"); - assert_eq!(json["error"]["message"], "upstream timeout"); + assert_eq!(json["error"]["upstream_status"], 502); + assert_eq!(json["error"]["provider"], "openai"); + } + + #[tokio::test] + async fn provider_upstream_503_forwards_status() { + let err = RequestError::ProviderUpstream { + provider: "openai".to_string(), + status: 503, + body: None, + }; + let (status, _) = error_response_parts(err).await; + assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE); } #[tokio::test] - async fn budget_exceeded_returns_402() { - let err = AppError::BudgetExceeded("monthly limit reached".to_string()); + async fn budget_exceeded_returns_402_with_figures() { + let err = RequestError::BudgetExceeded { + limit_usd: 100.0, + actual_usd: 105.5, + }; let (status, json) = error_response_parts(err).await; assert_eq!(status, StatusCode::PAYMENT_REQUIRED); assert_eq!(json["error"]["type"], "budget_exceeded"); - assert_eq!(json["error"]["message"], "monthly limit reached"); + assert_eq!(json["error"]["limit_usd"], 100.0); + assert_eq!(json["error"]["actual_usd"], 105.5); } #[tokio::test] async fn dlp_blocked_returns_400_with_dlp_block_type() { - let err = AppError::DlpBlocked("secret detected in prompt".to_string()); + let err = RequestError::DlpBlocked("secret detected in prompt".to_string()); let (status, json) = error_response_parts(err).await; assert_eq!(status, StatusCode::BAD_REQUEST); @@ -129,13 +489,12 @@ mod tests { } #[tokio::test] - async fn authentication_error_returns_401_with_authentication_error_type() { - let err = AppError::AuthenticationError( + async fn auth_revoked_returns_401() { + let err = RequestError::AuthRevoked( "OAuth token for provider 'anthropic' revoked. Run: grob connect --force-reauth" .to_string(), ); let (status, json) = error_response_parts(err).await; - assert_eq!(status, StatusCode::UNAUTHORIZED); assert_eq!(json["error"]["type"], "authentication_error"); assert!(json["error"]["message"] @@ -144,12 +503,257 @@ mod tests { .contains("grob connect --force-reauth")); } + #[tokio::test] + async fn internal_returns_500() { + let err = RequestError::Internal(anyhow::anyhow!("disk full")); + let (status, _) = error_response_parts(err).await; + assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR); + } + + // ── is_retryable() table-driven tests ── + #[test] - fn display_impl_includes_variant_prefix() { - let err = AppError::ParseError("bad input".to_string()); - assert_eq!(err.to_string(), "Parse error: bad input"); + fn rate_limited_is_retryable() { + let err = RequestError::RateLimited { + provider: "x".to_string(), + retry_after_ms: None, + }; + assert!(err.is_retryable()); + } - let err = AppError::RoutingError("no route".to_string()); - assert_eq!(err.to_string(), "Routing error: no route"); + #[test] + fn upstream_500_is_retryable() { + let err = RequestError::ProviderUpstream { + provider: "x".to_string(), + status: 500, + body: None, + }; + assert!(err.is_retryable()); + } + + #[test] + fn upstream_502_is_retryable() { + let err = RequestError::ProviderUpstream { + provider: "x".to_string(), + status: 502, + body: None, + }; + assert!(err.is_retryable()); + } + + #[test] + fn upstream_503_is_retryable() { + let err = RequestError::ProviderUpstream { + provider: "x".to_string(), + status: 503, + body: None, + }; + assert!(err.is_retryable()); + } + + #[test] + fn upstream_504_is_retryable() { + let err = RequestError::ProviderUpstream { + provider: "x".to_string(), + status: 504, + body: None, + }; + assert!(err.is_retryable()); + } + + #[test] + fn upstream_429_is_retryable() { + let err = RequestError::ProviderUpstream { + provider: "x".to_string(), + status: 429, + body: None, + }; + assert!(err.is_retryable()); + } + + #[test] + fn upstream_400_is_not_retryable() { + let err = RequestError::ProviderUpstream { + provider: "x".to_string(), + status: 400, + body: None, + }; + assert!(!err.is_retryable()); + } + + #[test] + fn upstream_401_is_not_retryable() { + let err = RequestError::ProviderUpstream { + provider: "x".to_string(), + status: 401, + body: None, + }; + assert!(!err.is_retryable()); + } + + #[test] + fn auth_revoked_is_not_retryable() { + let err = RequestError::AuthRevoked("revoked".to_string()); + assert!(!err.is_retryable()); + } + + #[test] + fn unauthorized_is_not_retryable() { + let err = RequestError::Unauthorized; + assert!(!err.is_retryable()); + } + + #[test] + fn forbidden_is_not_retryable() { + let err = RequestError::Forbidden("nope".to_string()); + assert!(!err.is_retryable()); + } + + #[test] + fn parse_error_is_not_retryable() { + let err = RequestError::ParseError("bad".to_string()); + assert!(!err.is_retryable()); + } + + #[test] + fn routing_error_is_not_retryable() { + let err = RequestError::RoutingError("no model".to_string()); + assert!(!err.is_retryable()); + } + + #[test] + fn budget_exceeded_is_not_retryable() { + let err = RequestError::BudgetExceeded { + limit_usd: 100.0, + actual_usd: 200.0, + }; + assert!(!err.is_retryable()); + } + + #[test] + fn dlp_blocked_is_not_retryable() { + let err = RequestError::DlpBlocked("secret".to_string()); + assert!(!err.is_retryable()); + } + + #[test] + fn internal_without_reqwest_is_not_retryable() { + let err = RequestError::Internal(anyhow::anyhow!("logic bug")); + assert!(!err.is_retryable()); + } + + #[test] + fn variant_tags_are_stable() { + assert_eq!( + RequestError::BadRequest("x".to_string()).variant_tag(), + "bad_request" + ); + assert_eq!(RequestError::Unauthorized.variant_tag(), "unauthorized"); + assert_eq!( + RequestError::Forbidden("x".to_string()).variant_tag(), + "forbidden" + ); + assert_eq!(RequestError::NotFound.variant_tag(), "not_found"); + assert_eq!( + RequestError::ParseError("x".to_string()).variant_tag(), + "parse_error" + ); + assert_eq!( + RequestError::RoutingError("x".to_string()).variant_tag(), + "routing_error" + ); + assert_eq!( + RequestError::RateLimited { + provider: "x".to_string(), + retry_after_ms: None + } + .variant_tag(), + "rate_limited" + ); + assert_eq!( + RequestError::ProviderUpstream { + provider: "x".to_string(), + status: 502, + body: None + } + .variant_tag(), + "provider_upstream" + ); + assert_eq!( + RequestError::BudgetExceeded { + limit_usd: 1.0, + actual_usd: 2.0, + } + .variant_tag(), + "budget_exceeded" + ); + assert_eq!( + RequestError::DlpBlocked("x".to_string()).variant_tag(), + "dlp_blocked" + ); + assert_eq!( + RequestError::AuthRevoked("x".to_string()).variant_tag(), + "auth_revoked" + ); + assert_eq!( + RequestError::Internal(anyhow::anyhow!("x")).variant_tag(), + "internal" + ); + } + + #[test] + fn provider_error_429_converts_to_rate_limited() { + let err = crate::providers::error::ProviderError::ApiError { + status: 429, + message: "slow down".to_string(), + }; + let req_err: RequestError = err.into(); + assert!(matches!(req_err, RequestError::RateLimited { .. })); + assert!(req_err.is_retryable()); + } + + #[test] + fn provider_error_500_converts_to_upstream() { + let err = crate::providers::error::ProviderError::ApiError { + status: 500, + message: "boom".to_string(), + }; + let req_err: RequestError = err.into(); + match &req_err { + RequestError::ProviderUpstream { status, .. } => assert_eq!(*status, 500), + other => panic!("unexpected variant: {:?}", other), + } + assert!(req_err.is_retryable()); + } + + #[test] + fn provider_error_401_with_rate_limit_payload_converts_to_rate_limited() { + let err = crate::providers::error::ProviderError::ApiError { + status: 401, + message: r#"{"type":"error","error":{"type":"rate_limit_error","message":"slow"}}"# + .to_string(), + }; + let req_err: RequestError = err.into(); + assert!(matches!(req_err, RequestError::RateLimited { .. })); + } + + #[test] + fn provider_error_401_authentication_converts_to_auth_revoked() { + let err = crate::providers::error::ProviderError::ApiError { + status: 401, + message: r#"{"type":"error","error":{"type":"authentication_error","message":"bad"}}"# + .to_string(), + }; + let req_err: RequestError = err.into(); + assert!(matches!(req_err, RequestError::AuthRevoked(_))); + assert!(!req_err.is_retryable()); + } + + #[test] + fn display_includes_variant_tag() { + let err = RequestError::ParseError("bad input".to_string()); + let s = err.to_string(); + assert!(s.contains("parse_error")); + assert!(s.contains("bad input")); } } diff --git a/src/server/handlers.rs b/src/server/handlers.rs index 679a0a4e..41849abb 100644 --- a/src/server/handlers.rs +++ b/src/server/handlers.rs @@ -10,9 +10,10 @@ use futures::stream::TryStreamExt; use std::sync::Arc; use tracing::{debug, error}; +use super::middleware::AuditedAlready; use super::{ apply_transparency_headers, dispatch, extract_api_credential, extract_client_ip, openai_compat, - responses_compat, should_apply_transparency, AppError, AppState, RequestId, + responses_compat, should_apply_transparency, AppState, RequestError, RequestId, }; /// Extracts tenant_id from VirtualKeyContext (preferred) or GrobClaims. @@ -50,12 +51,12 @@ impl Drop for ActiveRequestGuard { fn build_json_response( body: Vec, transparency: Option<(&str, &str, &str)>, -) -> Result { +) -> Result { let mut resp = Response::builder() .status(200) .header("content-type", "application/json") .body(Body::from(body)) - .map_err(|e| AppError::ProviderError(format!("response builder: {}", e)))?; + .map_err(|e| RequestError::Internal(anyhow::anyhow!("response builder: {}", e)))?; if let Some((provider, actual_model, req_id)) = transparency { apply_transparency_headers(resp.headers_mut(), provider, actual_model, req_id); } @@ -71,11 +72,11 @@ fn build_sse_response() -> axum::http::response::Builder { .header("Connection", "keep-alive") } -/// Serializes a value to JSON bytes, returning an `AppError` on failure. -fn serialize_response(value: &T) -> Result, AppError> { +/// Serializes a value to JSON bytes, returning a `RequestError` on failure. +fn serialize_response(value: &T) -> Result, RequestError> { serde_json::to_vec(value).map_err(|e| { error!("Failed to serialize response: {}", e); - AppError::ProviderError(format!("response serialization failed: {}", e)) + RequestError::Internal(anyhow::anyhow!("response serialization failed: {}", e)) }) } @@ -86,6 +87,18 @@ struct DispatchPrelude { tenant_id: Option, peer_ip: String, transparency_enabled: bool, + /// Set by the dispatch path when it has emitted an audit entry — used + /// by the audit middleware to skip duplicate logging. + audited: Arc, +} + +/// Marks the response with the [`AuditedAlready`] extension when dispatch +/// has emitted its own audit entry, so the outer audit middleware skips a +/// duplicate write. +fn mark_audited_if_set(audited: &Arc, response: &mut Response) { + if audited.load(std::sync::atomic::Ordering::Acquire) { + response.extensions_mut().insert(AuditedAlready); + } } /// Builds the shared pre-dispatch state common to all three handlers. @@ -113,6 +126,7 @@ fn prepare_dispatch( tenant_id, peer_ip, transparency_enabled, + audited: Arc::new(std::sync::atomic::AtomicBool::new(false)), } } @@ -137,7 +151,7 @@ fn finish_dispatch( on_streaming: S, on_complete: C, on_fan_out: F, -) -> Result +) -> Result where S: FnOnce( std::pin::Pin< @@ -148,7 +162,7 @@ where >, >, ) -> Body, - C: FnOnce(crate::providers::ProviderResponse) -> Result, AppError>, + C: FnOnce(crate::providers::ProviderResponse) -> Result, RequestError>, F: FnOnce(crate::providers::ProviderResponse) -> Response, { match result { @@ -179,7 +193,7 @@ where let response = response_builder .body(body) - .map_err(|e| AppError::ProviderError(format!("response builder: {}", e)))?; + .map_err(|e| RequestError::Internal(anyhow::anyhow!("response builder: {}", e)))?; Ok(response) } @@ -227,7 +241,7 @@ pub(crate) async fn handle_openai_chat_completions( axum::Extension(request_id): axum::Extension, headers: HeaderMap, Json(openai_request): Json, -) -> Result { +) -> Result { let _guard = ActiveRequestGuard::new(&state); let model = openai_request.model.clone(); let is_streaming = openai_request.stream == Some(true); @@ -235,14 +249,17 @@ pub(crate) async fn handle_openai_chat_completions( let prelude = prepare_dispatch(&state, &claims, &vk_ctx, &headers); // Transform OpenAI → Anthropic format - let mut request = openai_compat::transform_openai_to_canonical(openai_request) - .map_err(|e| AppError::ParseError(format!("Failed to transform OpenAI request: {}", e)))?; + let mut request = + openai_compat::transform_openai_to_canonical(openai_request).map_err(|e| { + RequestError::ParseError(format!("Failed to transform OpenAI request: {}", e)) + })?; forward_beta_header(&mut request, &headers); let start_time = std::time::Instant::now(); #[cfg(feature = "policies")] let resolved_policy = evaluate_policy_if_configured(&state, prelude.tenant_id.as_deref(), &model, &headers); + let audited_flag = prelude.audited.clone(); let ctx = dispatch::DispatchContext { state: &state, inner: &prelude.inner, @@ -255,15 +272,23 @@ pub(crate) async fn handle_openai_chat_completions( start_time, headers: &headers, trace_id: None, + audited: audited_flag.clone(), #[cfg(feature = "policies")] resolved_policy, }; - let result = dispatch::dispatch(&ctx, &mut request).await?; + let result = match dispatch::dispatch(&ctx, &mut request).await { + Ok(r) => r, + Err(e) => { + let mut response = e.into_response(); + mark_audited_if_set(&audited_flag, &mut response); + return Ok(response); + } + }; let model_for_stream = model.clone(); let model_for_fanout = model.clone(); - finish_dispatch( + let mut response = finish_dispatch( result, prelude.transparency_enabled, &request_id.0, @@ -284,7 +309,9 @@ pub(crate) async fn handle_openai_chat_completions( openai_compat::transform_canonical_to_openai(resp, model_for_fanout); Json(openai_response).into_response() }, - ) + )?; + mark_audited_if_set(&audited_flag, &mut response); + Ok(response) } /// Handle /v1/responses requests (OpenAI Responses API — used by Codex CLI) @@ -296,7 +323,7 @@ pub(crate) async fn handle_responses( axum::Extension(request_id): axum::Extension, headers: HeaderMap, Json(responses_request): Json, -) -> Result { +) -> Result { let _guard = ActiveRequestGuard::new(&state); let model = responses_request.model.clone(); let is_streaming = responses_request.stream == Some(true); @@ -306,7 +333,7 @@ pub(crate) async fn handle_responses( // Transform Responses → canonical format let mut request = responses_compat::transform_responses_to_canonical(responses_request) .map_err(|e| { - AppError::ParseError(format!("Failed to transform Responses request: {}", e)) + RequestError::ParseError(format!("Failed to transform Responses request: {}", e)) })?; forward_beta_header(&mut request, &headers); @@ -314,6 +341,7 @@ pub(crate) async fn handle_responses( #[cfg(feature = "policies")] let resolved_policy = evaluate_policy_if_configured(&state, prelude.tenant_id.as_deref(), &model, &headers); + let audited_flag = prelude.audited.clone(); let ctx = dispatch::DispatchContext { state: &state, inner: &prelude.inner, @@ -326,15 +354,23 @@ pub(crate) async fn handle_responses( start_time, headers: &headers, trace_id: None, + audited: audited_flag.clone(), #[cfg(feature = "policies")] resolved_policy, }; - let result = dispatch::dispatch(&ctx, &mut request).await?; + let result = match dispatch::dispatch(&ctx, &mut request).await { + Ok(r) => r, + Err(e) => { + let mut response = e.into_response(); + mark_audited_if_set(&audited_flag, &mut response); + return Ok(response); + } + }; let model_for_stream = model.clone(); let model_for_fanout = model.clone(); - finish_dispatch( + let mut response = finish_dispatch( result, prelude.transparency_enabled, &request_id.0, @@ -357,7 +393,9 @@ pub(crate) async fn handle_responses( responses_compat::transform_canonical_to_responses(resp, model_for_fanout); Json(responses_response).into_response() }, - ) + )?; + mark_audited_if_set(&audited_flag, &mut response); + Ok(response) } /// Handle /v1/models endpoint (OpenAI-compatible) @@ -391,7 +429,7 @@ pub(crate) async fn handle_messages( axum::Extension(request_id): axum::Extension, headers: HeaderMap, Json(request_json): Json, -) -> Result { +) -> Result { let _guard = ActiveRequestGuard::new(&state); let req_id = &request_id.0; let model: String = request_json @@ -412,7 +450,7 @@ pub(crate) async fn handle_messages( let mut request: CanonicalRequest = serde_json::from_value(request_json).map_err(|e| { tracing::error!("❌ Failed to parse request: {}", e); - AppError::ParseError(format!("Invalid request format: {}", e)) + RequestError::ParseError(format!("Invalid request format: {}", e)) })?; forward_beta_header(&mut request, &headers); @@ -422,6 +460,7 @@ pub(crate) async fn handle_messages( #[cfg(feature = "policies")] let resolved_policy = evaluate_policy_if_configured(&state, prelude.tenant_id.as_deref(), &model, &headers); + let audited_flag = prelude.audited.clone(); let ctx = dispatch::DispatchContext { state: &state, inner: &prelude.inner, @@ -434,13 +473,21 @@ pub(crate) async fn handle_messages( start_time, headers: &headers, trace_id: Some(trace_id), + audited: audited_flag.clone(), #[cfg(feature = "policies")] resolved_policy, }; - let result = dispatch::dispatch(&ctx, &mut request).await?; + let result = match dispatch::dispatch(&ctx, &mut request).await { + Ok(r) => r, + Err(e) => { + let mut response = e.into_response(); + mark_audited_if_set(&audited_flag, &mut response); + return Ok(response); + } + }; - finish_dispatch( + let mut response = finish_dispatch( result, prelude.transparency_enabled, req_id, @@ -454,14 +501,16 @@ pub(crate) async fn handle_messages( }, |resp| serialize_response(&resp), |resp| Json(resp).into_response(), - ) + )?; + mark_audited_if_set(&audited_flag, &mut response); + Ok(response) } /// Handle /v1/messages/count_tokens requests pub(crate) async fn handle_count_tokens( State(state): State>, Json(request_json): Json, -) -> Result { +) -> Result { let model = request_json .get("model") .and_then(|m| m.as_str()) @@ -491,7 +540,7 @@ pub(crate) async fn handle_count_tokens( let decision = inner .router .route(&mut routing_request) - .map_err(|e| AppError::RoutingError(e.to_string()))?; + .map_err(|e| RequestError::RoutingError(e.to_string()))?; debug!( "🧮 Routed count_tokens: {} → {} ({})", @@ -500,8 +549,9 @@ pub(crate) async fn handle_count_tokens( // Deserialize the full count_tokens request (consumes the JSON value — no clone). use crate::models::CountTokensRequest; - let count_request: CountTokensRequest = serde_json::from_value(request_json) - .map_err(|e| AppError::ParseError(format!("Invalid count_tokens request format: {}", e)))?; + let count_request: CountTokensRequest = serde_json::from_value(request_json).map_err(|e| { + RequestError::ParseError(format!("Invalid count_tokens request format: {}", e)) + })?; // Try model mappings with fallback (1:N mapping) if let Some(model_config) = inner.find_model(&decision.model_name) { @@ -525,11 +575,15 @@ pub(crate) async fn handle_count_tokens( } } - Err(AppError::ProviderError(format!( - "All {} provider mappings failed for token counting: {}", - sorted_mappings.len(), - decision.model_name - ))) + Err(RequestError::ProviderUpstream { + provider: "all".to_string(), + status: 502, + body: Some(format!( + "All {} provider mappings failed for token counting: {}", + sorted_mappings.len(), + decision.model_name + )), + }) } else if let Ok(provider) = inner .provider_registry .provider_for_model(&decision.model_name) @@ -539,10 +593,10 @@ pub(crate) async fn handle_count_tokens( let response = provider .count_tokens(req) .await - .map_err(|e| AppError::ProviderError(e.to_string()))?; + .map_err(RequestError::from)?; Ok(Json(response).into_response()) } else { - Err(AppError::ProviderError(format!( + Err(RequestError::RoutingError(format!( "No model mapping or provider found for token counting: {}", decision.model_name ))) diff --git a/src/server/helpers.rs b/src/server/helpers.rs index 55efc317..f0a2f825 100644 --- a/src/server/helpers.rs +++ b/src/server/helpers.rs @@ -4,14 +4,14 @@ use std::borrow::Cow; use std::sync::Arc; use tracing::info; -use super::{AppError, ReloadableState}; +use super::{ReloadableState, RequestError}; /// Resolve and sort provider mappings for a routing decision. pub(crate) fn resolve_provider_mappings( inner: &Arc, headers: &HeaderMap, decision: &crate::models::RouteDecision, -) -> Result, AppError> { +) -> Result, RequestError> { // Tier-based provider selection (opt-in via [[tiers]] config) if let Some(ref tier) = decision.complexity_tier { let tier_name = tier.to_string(); @@ -112,7 +112,7 @@ pub(crate) fn resolve_provider_mappings( if let Some(ref provider_name) = forced_provider { sorted.retain(|m| m.provider == *provider_name); if sorted.is_empty() { - return Err(AppError::RoutingError(format!( + return Err(RequestError::RoutingError(format!( "Provider '{}' not found in mappings for model '{}'", provider_name, decision.model_name ))); @@ -137,7 +137,7 @@ pub(crate) fn resolve_provider_mappings( provider_region == region_filter || provider_region == "global" }); if sorted.is_empty() { - return Err(AppError::RoutingError(format!( + return Err(RequestError::RoutingError(format!( "No providers match region '{}' for model '{}' (GDPR filtering enabled)", region_filter, decision.model_name ))); @@ -201,7 +201,7 @@ pub(crate) fn resolve_provider_mappings( .collect(); if pass_through_mappings.is_empty() { - Err(AppError::RoutingError(format!( + Err(RequestError::RoutingError(format!( "Model '{}' is not configured. Add a [[models]] entry in config.toml or set pass_through = true on a provider.", decision.model_name ))) diff --git a/src/server/middleware.rs b/src/server/middleware.rs index 1a73084e..2a06e92f 100644 --- a/src/server/middleware.rs +++ b/src/server/middleware.rs @@ -78,7 +78,7 @@ pub(crate) fn auth_error_response(message: &str) -> Response { /// Stored in request extensions for correlation #[derive(Clone, Debug)] -pub(crate) struct RequestId(pub String); +pub struct RequestId(pub String); /// Auth middleware: supports three modes: /// - "none" (default): all requests pass @@ -321,6 +321,193 @@ pub(crate) async fn security_headers_response_middleware( apply_security_headers(response, &config) } +/// Marker inserted into response extensions by handlers that already wrote +/// an audit entry. The audit middleware skips logging when present so that +/// the dispatch pipeline (which audits with rich DLP and token-count context) +/// is the source of truth for request-lifecycle entries on the hot path. +/// +/// Endpoints that bypass dispatch entirely (oauth handlers, config API, +/// errors raised in middleware before dispatch) leave this marker absent +/// and are audited centrally by the middleware. +#[derive(Clone, Debug)] +pub struct AuditedAlready; + +/// Inputs captured by the audit middleware before the handler runs. +/// +/// Stored on the request side so post-handler audit emission can rebuild +/// the entry without re-reading consumed request state. +pub struct AuditMiddlewareCapture { + /// HTTP method of the request. + pub method: axum::http::Method, + /// Path component of the request URI. + pub path: String, + /// Correlation ID resolved from the `RequestId` extension. + pub request_id: String, + /// Tenant identifier from JWT / virtual key, or empty. + pub tenant_id: String, + /// Client IP from `X-Forwarded-For` or `"unknown"`. + pub client_ip: String, + /// Wall-clock instant the middleware observed the request. + pub started_at: std::time::Instant, +} + +/// Pulls the captured request context that `audit_log_layer` snapshots +/// before the handler runs. +pub fn capture_audit_input(request: &Request) -> AuditMiddlewareCapture { + let request_id = request + .extensions() + .get::() + .map(|r| r.0.clone()) + .unwrap_or_default(); + + let tenant_id = if let Some(vk) = request + .extensions() + .get::() + { + vk.tenant_id.clone() + } else if let Some(claims) = request.extensions().get::() { + claims.tenant_id().to_string() + } else { + String::new() + }; + + AuditMiddlewareCapture { + method: request.method().clone(), + path: request.uri().path().to_string(), + request_id, + tenant_id, + client_ip: extract_client_ip(request.headers()), + started_at: std::time::Instant::now(), + } +} + +/// Emits an `AuditEvent::RequestProcessed` entry from the captured request +/// context plus the post-handler response. Returns `true` when an entry +/// was written, `false` when the response carried [`AuditedAlready`] (in +/// which case the dispatch pipeline already wrote a richer entry). +/// +/// Extracted from [`audit_log_layer`] so it can be unit-tested without +/// constructing a full `AppState`. +pub fn emit_request_processed( + audit_log: &crate::security::AuditLog, + capture: &AuditMiddlewareCapture, + response: &Response, +) -> bool { + if response.extensions().get::().is_some() { + return false; + } + + let status = response.status(); + let duration_ms = capture.started_at.elapsed().as_millis() as u64; + + let provider = response + .headers() + .get("x-ai-provider") + .and_then(|v| v.to_str().ok()) + .unwrap_or("") + .to_string(); + let model = response + .headers() + .get("x-ai-model") + .and_then(|v| v.to_str().ok()) + .unwrap_or("") + .to_string(); + + let error_variant = response + .extensions() + .get::() + .map(|tag| tag.0.clone()); + + let backend = if !provider.is_empty() { + provider + } else if let Some(ref tag) = error_variant { + format!("ERROR:{}:{}", tag, status.as_u16()) + } else if status.is_success() { + format!("{} {}", capture.method, capture.path) + } else { + format!("STATUS:{}", status.as_u16()) + }; + + let tenant_for_entry = if capture.tenant_id.is_empty() { + capture.client_ip.as_str() + } else { + capture.tenant_id.as_str() + }; + + let mut builder = super::AuditEntryBuilder::new( + tenant_for_entry, + crate::security::audit_log::AuditEvent::RequestProcessed, + &backend, + &capture.client_ip, + duration_ms, + ); + + if !model.is_empty() { + builder = builder.model(model); + } + + // Risk level: low for 2xx, medium for 4xx, high for 5xx — matches + // the EU AI Act Article 14 escalation threshold defaults. + let risk = if status.is_server_error() { + crate::security::audit_log::RiskLevel::High + } else if status.is_client_error() { + crate::security::audit_log::RiskLevel::Medium + } else { + crate::security::audit_log::RiskLevel::Low + }; + builder = builder.risk(risk); + + if let Some(tag) = error_variant { + builder = builder.dlp_rules(vec![format!( + "request_error:{}:status={}", + tag, + status.as_u16() + )]); + } + + if let Err(e) = audit_log.write(builder.build()) { + tracing::error!( + error = %e, + request_id = %capture.request_id, + "audit middleware: write failed" + ); + } + true +} + +/// Audit-log middleware: emits `AuditEvent::RequestProcessed` for every HTTP +/// request that flows through the server. +/// +/// Wraps every endpoint, including the OAuth, config, and health surfaces +/// that previously bypassed audit entirely. Captures request method, path, +/// status, latency, error variant tag (when 4xx/5xx), tenant identifier +/// (from JWT claims or virtual key context), client IP, and the upstream +/// provider name when set on the response by the dispatch pipeline. +/// +/// Skips logging when the dispatch pipeline has already written a richer +/// audit entry (signalled by the [`AuditedAlready`] marker in response +/// extensions). Health and metrics endpoints are excluded to avoid +/// flooding the journal with unauthenticated probe traffic. +pub(crate) async fn audit_log_layer( + State(state): State>, + request: Request, + next: Next, +) -> Response { + let path = request.uri().path(); + if matches!(path, "/health" | "/live" | "/ready" | "/metrics") { + return next.run(request).await; + } + + let capture = capture_audit_input(&request); + let response = next.run(request).await; + + if let Some(ref audit_log) = state.security.audit_log { + emit_request_processed(audit_log, &capture, &response); + } + + response +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/server/mod.rs b/src/server/mod.rs index e815fa3d..792702d4 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -37,7 +37,7 @@ pub(crate) use budget::{ calculate_cost, check_budget, is_auth_revoked_error, is_provider_subscription, is_retryable, record_request_metrics, record_spend, retry_delay, RequestMetrics, MAX_RETRIES, }; -pub use error::AppError; +pub use error::{ErrorVariantTag, RequestError}; pub(crate) use helpers::{ format_route_type, inject_continuation_text, resolve_provider_mappings, sanitize_provider_response_reported, should_inject_continuation, @@ -49,9 +49,12 @@ pub(crate) use init::{ init_provider_scorer, init_security, maybe_preset_sync, spawn_background_tasks, }; pub(crate) use middleware::{ - apply_transparency_headers, auth_middleware, extract_api_credential, extract_client_ip, - rate_limit_check_middleware, request_id_middleware, security_headers_response_middleware, - should_apply_transparency, RequestId, + apply_transparency_headers, audit_log_layer, auth_middleware, extract_api_credential, + extract_client_ip, rate_limit_check_middleware, request_id_middleware, + security_headers_response_middleware, should_apply_transparency, +}; +pub use middleware::{ + capture_audit_input, emit_request_processed, AuditMiddlewareCapture, AuditedAlready, RequestId, }; use crate::auth::TokenStore; @@ -400,6 +403,17 @@ fn build_app_router(config: &AppConfig, state: Arc) -> axum::Router { let app = app.layer(RequestBodyLimitLayer::new( config.security.max_body_size.value(), )); + + // Audit middleware: captures every request lifecycle, including those + // rejected by rate-limit / auth before reaching a handler. Layered + // INSIDE `request_id_middleware` (which is added afterwards and so wraps + // this one) so `RequestId` is set in extensions before the audit logic + // reads it. + let app = app.layer(axum::middleware::from_fn_with_state( + state.clone(), + audit_log_layer, + )); + let app = app.layer(axum::middleware::from_fn(request_id_middleware)); // Tape recorder layer: outermost to capture raw HTTP before any transformation. diff --git a/tests/enterprise/snapshot_error_test.rs b/tests/enterprise/snapshot_error_test.rs index cd7cccd1..482cc0bf 100644 --- a/tests/enterprise/snapshot_error_test.rs +++ b/tests/enterprise/snapshot_error_test.rs @@ -1,23 +1,23 @@ //! Snapshot tests for error response formats. //! -//! Each [`AppError`] variant produces a JSON response with a specific HTTP +//! Each [`RequestError`] variant produces a JSON response with a specific HTTP //! status code and error structure. These snapshots detect unintended changes //! to the error API contract. use axum::response::IntoResponse; -use grob::server::AppError; +use grob::server::RequestError; // ── Helpers ───────────────────────────────────────────────────── -/// Extracts status code and parsed JSON body from an AppError. -async fn error_snapshot(error: AppError) -> String { +/// Extracts status code and parsed JSON body from a `RequestError`. +async fn error_snapshot(error: RequestError) -> String { let response = error.into_response(); let status = response.status(); let body_bytes = axum::body::to_bytes(response.into_body(), 1024 * 1024) .await .expect("invariant: in-memory body collection cannot fail"); let json: serde_json::Value = serde_json::from_slice(&body_bytes) - .expect("invariant: AppError always produces valid JSON"); + .expect("invariant: RequestError always produces valid JSON"); format!("status={} body={}", status.as_u16(), json) } @@ -25,16 +25,17 @@ async fn error_snapshot(error: AppError) -> String { #[tokio::test] async fn snapshot_budget_exceeded_error() { - let snap = error_snapshot(AppError::BudgetExceeded( - "Monthly global budget reached: $50.00/$50.00".to_string(), - )) + let snap = error_snapshot(RequestError::BudgetExceeded { + limit_usd: 50.0, + actual_usd: 50.0, + }) .await; insta::assert_snapshot!(snap); } #[tokio::test] async fn snapshot_routing_error() { - let snap = error_snapshot(AppError::RoutingError( + let snap = error_snapshot(RequestError::RoutingError( "no matching model: gpt-unknown".to_string(), )) .await; @@ -43,16 +44,18 @@ async fn snapshot_routing_error() { #[tokio::test] async fn snapshot_provider_error() { - let snap = error_snapshot(AppError::ProviderError( - "upstream timeout after 30s".to_string(), - )) + let snap = error_snapshot(RequestError::ProviderUpstream { + provider: "openai".to_string(), + status: 502, + body: Some("upstream timeout after 30s".to_string()), + }) .await; insta::assert_snapshot!(snap); } #[tokio::test] async fn snapshot_parse_error() { - let snap = error_snapshot(AppError::ParseError( + let snap = error_snapshot(RequestError::ParseError( "invalid JSON at line 1, column 42".to_string(), )) .await; @@ -61,9 +64,44 @@ async fn snapshot_parse_error() { #[tokio::test] async fn snapshot_dlp_blocked_error() { - let snap = error_snapshot(AppError::DlpBlocked( + let snap = error_snapshot(RequestError::DlpBlocked( "secret detected in prompt: sk-***".to_string(), )) .await; insta::assert_snapshot!(snap); } + +#[tokio::test] +async fn snapshot_rate_limited_error() { + let snap = error_snapshot(RequestError::RateLimited { + provider: "anthropic".to_string(), + retry_after_ms: Some(2500), + }) + .await; + insta::assert_snapshot!(snap); +} + +#[tokio::test] +async fn snapshot_unauthorized_error() { + let snap = error_snapshot(RequestError::Unauthorized).await; + insta::assert_snapshot!(snap); +} + +#[tokio::test] +async fn snapshot_forbidden_error() { + let snap = error_snapshot(RequestError::Forbidden( + "policy denies model 'claude-opus-4-7' for tenant 'public'".to_string(), + )) + .await; + insta::assert_snapshot!(snap); +} + +#[tokio::test] +async fn snapshot_auth_revoked_error() { + let snap = error_snapshot(RequestError::AuthRevoked( + "OAuth token for provider 'anthropic' revoked. Run: grob connect --force-reauth" + .to_string(), + )) + .await; + insta::assert_snapshot!(snap); +} diff --git a/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_auth_revoked_error.snap b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_auth_revoked_error.snap new file mode 100644 index 00000000..342ae8c6 --- /dev/null +++ b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_auth_revoked_error.snap @@ -0,0 +1,5 @@ +--- +source: tests/enterprise/snapshot_error_test.rs +expression: snap +--- +status=401 body={"error":{"type":"authentication_error","message":"OAuth token for provider 'anthropic' revoked. Run: grob connect --force-reauth"}} diff --git a/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_budget_exceeded_error.snap b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_budget_exceeded_error.snap index f23ad3cc..b3570bb4 100644 --- a/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_budget_exceeded_error.snap +++ b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_budget_exceeded_error.snap @@ -2,4 +2,4 @@ source: tests/enterprise/snapshot_error_test.rs expression: snap --- -status=402 body={"error":{"type":"budget_exceeded","message":"Monthly global budget reached: $50.00/$50.00"}} +status=402 body={"error":{"type":"budget_exceeded","message":"Budget exceeded: $50.0000 spent of $50.0000 limit","limit_usd":50.0,"actual_usd":50.0}} diff --git a/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_forbidden_error.snap b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_forbidden_error.snap new file mode 100644 index 00000000..06f2fc34 --- /dev/null +++ b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_forbidden_error.snap @@ -0,0 +1,5 @@ +--- +source: tests/enterprise/snapshot_error_test.rs +expression: snap +--- +status=403 body={"error":{"type":"permission_error","message":"policy denies model 'claude-opus-4-7' for tenant 'public'"}} diff --git a/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_provider_error.snap b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_provider_error.snap index 9435a30c..a4cbed28 100644 --- a/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_provider_error.snap +++ b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_provider_error.snap @@ -2,4 +2,4 @@ source: tests/enterprise/snapshot_error_test.rs expression: snap --- -status=502 body={"error":{"type":"error","message":"upstream timeout after 30s"}} +status=502 body={"error":{"type":"error","message":"upstream timeout after 30s","provider":"openai","upstream_status":502}} diff --git a/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_rate_limited_error.snap b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_rate_limited_error.snap new file mode 100644 index 00000000..b0d62d2d --- /dev/null +++ b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_rate_limited_error.snap @@ -0,0 +1,5 @@ +--- +source: tests/enterprise/snapshot_error_test.rs +expression: snap +--- +status=429 body={"error":{"type":"rate_limit_error","message":"Provider 'anthropic' rate-limited the request","provider":"anthropic","retry_after_ms":2500}} diff --git a/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_unauthorized_error.snap b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_unauthorized_error.snap new file mode 100644 index 00000000..003b1d21 --- /dev/null +++ b/tests/enterprise/snapshots/lib__enterprise__snapshot_error_test__snapshot_unauthorized_error.snap @@ -0,0 +1,5 @@ +--- +source: tests/enterprise/snapshot_error_test.rs +expression: snap +--- +status=401 body={"error":{"type":"authentication_error","message":"Missing or invalid credentials"}} diff --git a/tests/integration/audit_middleware_test.rs b/tests/integration/audit_middleware_test.rs new file mode 100644 index 00000000..09f80da1 --- /dev/null +++ b/tests/integration/audit_middleware_test.rs @@ -0,0 +1,244 @@ +//! Integration tests for the audit-log middleware. +//! +//! Exercises [`emit_request_processed`] over a representative cross-section +//! of HTTP statuses (2xx/4xx/5xx) and verifies: +//! +//! * One audit entry is written per request when the dispatch pipeline did +//! not already log. +//! * No entry is written when the response carries the [`AuditedAlready`] +//! marker (de-duplication invariant). +//! * Provider/model headers and error variant tags are preserved in the +//! audit entry. + +use axum::http::{HeaderValue, Method, Request, Response, StatusCode}; +use grob::security::audit_log::{AuditConfig, AuditEntry, AuditEvent, AuditLog, SigningAlgorithm}; +use grob::server::{ + capture_audit_input, emit_request_processed, AuditMiddlewareCapture, AuditedAlready, +}; +use tempfile::TempDir; + +fn build_audit_log() -> (TempDir, AuditLog) { + let dir = TempDir::new().expect("tempdir"); + let log = AuditLog::new(AuditConfig { + log_dir: dir.path().to_path_buf(), + sign_key_path: None, + signing_algorithm: SigningAlgorithm::default(), + hmac_key_path: None, + batch_size: 1, + flush_interval_ms: 5000, + include_merkle_proof: false, + }) + .expect("audit log"); + (dir, log) +} + +fn read_entries(dir: &TempDir) -> Vec { + let path = dir.path().join("current.jsonl"); + if !path.exists() { + return vec![]; + } + std::fs::read_to_string(&path) + .expect("read jsonl") + .lines() + .filter(|l| !l.trim().is_empty()) + .map(|l| serde_json::from_str(l).expect("parse audit entry")) + .collect() +} + +fn make_capture() -> AuditMiddlewareCapture { + AuditMiddlewareCapture { + method: Method::POST, + path: "/v1/messages".to_string(), + request_id: "req-test-001".to_string(), + tenant_id: "tenant-alpha".to_string(), + client_ip: "10.0.0.1".to_string(), + started_at: std::time::Instant::now(), + } +} + +fn build_response(status: StatusCode) -> Response { + Response::builder() + .status(status) + .body(axum::body::Body::empty()) + .expect("build response") +} + +#[test] +fn audit_middleware_emits_one_entry_per_request_lifecycle() { + let (dir, audit) = build_audit_log(); + + // Simulate five requests with different outcomes. + let cases = [ + StatusCode::OK, // 200 + StatusCode::BAD_REQUEST, // 400 + StatusCode::UNAUTHORIZED, // 401 + StatusCode::TOO_MANY_REQUESTS, // 429 + StatusCode::INTERNAL_SERVER_ERROR, // 500 + ]; + + for status in cases { + let capture = make_capture(); + let response = build_response(status); + let written = emit_request_processed(&audit, &capture, &response); + assert!( + written, + "audit middleware must emit an entry for {}", + status + ); + } + + let entries = read_entries(&dir); + assert_eq!( + entries.len(), + cases.len(), + "expected one audit entry per request" + ); + for entry in &entries { + assert!(matches!(entry.action, AuditEvent::RequestProcessed)); + assert_eq!(entry.tenant_id, "tenant-alpha"); + assert_eq!(entry.ip_source, "10.0.0.1"); + } +} + +#[test] +fn audit_middleware_skips_when_handler_already_audited() { + let (dir, audit) = build_audit_log(); + + let capture = make_capture(); + let mut response = build_response(StatusCode::OK); + response.extensions_mut().insert(AuditedAlready); + + let written = emit_request_processed(&audit, &capture, &response); + assert!( + !written, + "audit middleware must not double-log when handler already logged" + ); + + let entries = read_entries(&dir); + assert!( + entries.is_empty(), + "expected zero audit entries from middleware when AuditedAlready is set, got {}", + entries.len() + ); +} + +#[test] +fn audit_middleware_captures_provider_from_response_header() { + let (dir, audit) = build_audit_log(); + + let capture = make_capture(); + let mut response = build_response(StatusCode::OK); + response + .headers_mut() + .insert("x-ai-provider", HeaderValue::from_static("anthropic")); + response + .headers_mut() + .insert("x-ai-model", HeaderValue::from_static("claude-opus-4-7")); + + let written = emit_request_processed(&audit, &capture, &response); + assert!(written); + + let entries = read_entries(&dir); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].backend_routed, "anthropic"); + assert_eq!( + entries[0].model_name.as_deref(), + Some("claude-opus-4-7"), + "model name should be captured from x-ai-model header" + ); +} + +#[test] +fn audit_middleware_includes_error_variant_tag_in_dlp_rules_field() { + let (dir, audit) = build_audit_log(); + + let capture = make_capture(); + let mut response = build_response(StatusCode::PAYMENT_REQUIRED); + response + .extensions_mut() + .insert(grob::server::ErrorVariantTag("budget_exceeded".to_string())); + + let written = emit_request_processed(&audit, &capture, &response); + assert!(written); + + let entries = read_entries(&dir); + assert_eq!(entries.len(), 1); + assert!(entries[0] + .dlp_rules_triggered + .iter() + .any(|r| r.contains("budget_exceeded") && r.contains("status=402"))); +} + +#[test] +fn audit_middleware_emits_low_risk_for_2xx() { + let (dir, audit) = build_audit_log(); + let capture = make_capture(); + let response = build_response(StatusCode::OK); + emit_request_processed(&audit, &capture, &response); + let entries = read_entries(&dir); + assert_eq!(entries.len(), 1); + assert_eq!( + entries[0].risk_level, + Some(grob::security::audit_log::RiskLevel::Low) + ); +} + +#[test] +fn audit_middleware_emits_medium_risk_for_4xx() { + let (dir, audit) = build_audit_log(); + let capture = make_capture(); + let response = build_response(StatusCode::BAD_REQUEST); + emit_request_processed(&audit, &capture, &response); + let entries = read_entries(&dir); + assert_eq!(entries.len(), 1); + assert_eq!( + entries[0].risk_level, + Some(grob::security::audit_log::RiskLevel::Medium) + ); +} + +#[test] +fn audit_middleware_emits_high_risk_for_5xx() { + let (dir, audit) = build_audit_log(); + let capture = make_capture(); + let response = build_response(StatusCode::INTERNAL_SERVER_ERROR); + emit_request_processed(&audit, &capture, &response); + let entries = read_entries(&dir); + assert_eq!(entries.len(), 1); + assert_eq!( + entries[0].risk_level, + Some(grob::security::audit_log::RiskLevel::High) + ); +} + +#[test] +fn audit_middleware_falls_back_to_client_ip_when_tenant_missing() { + let (dir, audit) = build_audit_log(); + let mut capture = make_capture(); + capture.tenant_id = String::new(); + let response = build_response(StatusCode::OK); + emit_request_processed(&audit, &capture, &response); + let entries = read_entries(&dir); + assert_eq!(entries.len(), 1); + assert_eq!( + entries[0].tenant_id, "10.0.0.1", + "anonymous request should be tagged by client IP" + ); +} + +#[test] +fn capture_audit_input_picks_up_request_id_from_extensions() { + let mut request: Request = Request::builder() + .method(Method::GET) + .uri("/v1/models") + .body(axum::body::Body::empty()) + .unwrap(); + request + .extensions_mut() + .insert(grob::server::RequestId("rid-abc-123".to_string())); + + let capture = capture_audit_input(&request); + assert_eq!(capture.request_id, "rid-abc-123"); + assert_eq!(capture.method, Method::GET); + assert_eq!(capture.path, "/v1/models"); +} diff --git a/tests/integration/mod.rs b/tests/integration/mod.rs index 398f996e..8498d8f4 100644 --- a/tests/integration/mod.rs +++ b/tests/integration/mod.rs @@ -1,4 +1,5 @@ // Integration tests module +mod audit_middleware_test; mod cache_test; mod compliance_test; mod dlp_test;