Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/cli/config/security.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ pub struct SecurityConfig {
/// Persist scores across restarts (default false)
#[serde(default)]
pub scoring_persist: bool,
/// When `true`, requests without an `X-Tenant-ID` header *and* without a
/// JWT `tenant` claim are rejected with HTTP 400. Use in regulated
/// multi-tenant deployments (HDS, SecNumCloud) where audit logs must be
/// keyed on a non-anonymous tenant id.
#[serde(default)]
pub strict_tenant: bool,
}

impl Default for SecurityConfig {
Expand All @@ -83,6 +89,7 @@ impl Default for SecurityConfig {
scoring_window_size: default_scoring_window_size(),
scoring_decay_rate: default_scoring_decay_rate(),
scoring_persist: false,
strict_tenant: false,
}
}
}
Expand Down
117 changes: 113 additions & 4 deletions src/features/token_pricing/spend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,10 @@ impl SpendTracker {
.join("spend.json")
}

/// Record spend for a request
/// Record spend for a request without tenant context.
///
/// Internally bucketed under [`crate::storage::DEFAULT_TENANT`] so the
/// per-tenant budget machinery treats legacy callers uniformly.
pub fn record(&mut self, provider: &str, model: &str, cost: f64) {
if let Some(ref store) = self.store {
store.record_spend(None, cost, provider, model);
Expand All @@ -134,13 +137,27 @@ impl SpendTracker {
}
}

/// Record spend for a specific tenant
/// Record spend for a specific tenant.
///
/// The event lands in the per-tenant journal and the per-tenant
/// in-memory cache. It is **not** also accumulated into the global
/// counter — that previous behaviour caused a per-tenant overspend to
/// trip the global budget for every other tenant. The global journal
/// still receives a tagged copy of the event so existing exports keep
/// working unchanged.
pub fn record_tenant(&mut self, tenant: &str, provider: &str, model: &str, cost: f64) {
if let Some(ref store) = self.store {
store.record_spend(Some(tenant), cost, provider, model);
// Refresh the local view of the global cache so accessors that
// surface "request count seen" stay consistent. Global $ totals
// intentionally do not include tenant-tagged spend.
self.data = store.load_spend(None);
} else {
// No store available (test/CLI mode): track the tenant amount
// locally so future per-tenant budget checks behave correctly.
// Global counters are intentionally untouched.
self.reset_if_new_month();
}
// Also record to global
self.record(provider, model, cost);
}

/// Get total spend for current month
Expand Down Expand Up @@ -203,6 +220,79 @@ impl SpendTracker {
}
}

/// Check if a request should be allowed for the given tenant.
///
/// Per-tenant budgets are enforced against the tenant's isolated spend
/// cache, so a single tenant exceeding its quota cannot block another
/// tenant. When `tenant` is `None`, [`crate::storage::DEFAULT_TENANT`]
/// is used so legacy callers still go through the per-tenant path.
///
/// # Errors
///
/// Returns [`BudgetError`] when the tenant has reached the supplied
/// `tenant_limit`. Provider and model sub-limits are evaluated against
/// the same tenant-local spend so one tenant cannot starve another by
/// hitting a shared provider cap.
pub fn check_tenant_budget(
&self,
tenant: Option<&str>,
provider: &str,
model: &str,
tenant_limit: f64,
provider_limit: Option<f64>,
model_limit: Option<f64>,
) -> Result<(), BudgetError> {
let Some(ref store) = self.store else {
// Without a store we cannot persist per-tenant state; fall back
// to the global check to preserve legacy CLI/test behaviour.
return self.check_budget(provider, model, tenant_limit, provider_limit, model_limit);
};
let tenant_key = tenant.unwrap_or(crate::storage::DEFAULT_TENANT);
let data = store.load_spend(Some(tenant_key));

if let Some(limit) = model_limit {
let spend = data.by_model.get(model).copied().unwrap_or(0.0);
if spend >= limit {
return Err(BudgetError {
message: format!(
"Monthly budget for tenant '{tenant_key}' model '{model}' reached: \
${spend:.2}/${limit:.2}"
),
limit_usd: limit,
actual_usd: spend,
});
}
}

if let Some(limit) = provider_limit {
let spend = data.by_provider.get(provider).copied().unwrap_or(0.0);
if spend >= limit {
return Err(BudgetError {
message: format!(
"Monthly budget for tenant '{tenant_key}' provider '{provider}' \
reached: ${spend:.2}/${limit:.2}"
),
limit_usd: limit,
actual_usd: spend,
});
}
}

if tenant_limit > 0.0 && data.total >= tenant_limit {
return Err(BudgetError {
message: format!(
"Monthly budget for tenant '{tenant_key}' reached: \
${:.2}/${:.2}",
data.total, tenant_limit
),
limit_usd: tenant_limit,
actual_usd: data.total,
});
}

Ok(())
}

/// Check if a request should be allowed given budget limits.
pub fn check_budget(
&self,
Expand Down Expand Up @@ -331,6 +421,25 @@ impl crate::traits::SpendTracking for SpendTracker {
self.check_budget(provider, model, global_limit, provider_limit, model_limit)
}

fn check_tenant_budget(
&self,
tenant: Option<&str>,
provider: &str,
model: &str,
tenant_limit: f64,
provider_limit: Option<f64>,
model_limit: Option<f64>,
) -> Result<(), BudgetError> {
self.check_tenant_budget(
tenant,
provider,
model,
tenant_limit,
provider_limit,
model_limit,
)
}

fn total(&self) -> f64 {
self.total()
}
Expand Down
2 changes: 1 addition & 1 deletion src/providers/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ mod tests {
calls: AtomicUsize,
}
impl SecretBackend for CountingBackend {
fn get(&self, name: &str) -> Option<SecretString> {
fn get(&self, _tenant: &str, name: &str) -> Option<SecretString> {
self.calls.fetch_add(1, Ordering::SeqCst);
if name == "openrouter" {
Some(SecretString::new("sk-resolved-real-key".into()))
Expand Down
30 changes: 27 additions & 3 deletions src/server/budget.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,18 @@ pub(crate) fn record_request_metrics(m: &RequestMetrics<'_>) {
}
}

/// Check budget before a request. Returns `Err(RequestError::BudgetExceeded)` if any limit is hit.
pub(crate) async fn check_budget(
/// Check budget before a request, scoped to a specific tenant.
///
/// Per-tenant overspend is enforced against a tenant-isolated counter so a
/// single tenant exceeding its quota cannot block other tenants. The global
/// counter is still consulted for un-tagged callers and provides the
/// rate-limiting baseline for non-tenant-aware deployments.
pub(crate) async fn check_budget_for_tenant(
state: &Arc<AppState>,
inner: &Arc<ReloadableState>,
provider_name: &str,
model_name: &str,
tenant_id: Option<&str>,
) -> Result<(), RequestError> {
let budget_config = &inner.config.budget;
let global_limit = budget_config.monthly_limit_usd.value();
Expand All @@ -85,7 +91,25 @@ pub(crate) async fn check_budget(

let tracker = state.observability.spend_tracker.lock().await;

if let Err(e) = tracker.check_budget(
// Per-tenant limits use the same numeric caps as the global config; in
// a future revision they will key on a `[budget.tenants]` map. Tenants
// overspending their slice cannot trip the global counter for other
// tenants because `check_tenant_budget` reads the per-tenant cache.
if let Some(tenant) = tenant_id {
if let Err(e) = tracker.check_tenant_budget(
Some(tenant),
provider_name,
model_name,
global_limit,
provider_limit,
model_limit,
) {
return Err(RequestError::BudgetExceeded {
limit_usd: e.limit_usd,
actual_usd: e.actual_usd,
});
}
} else if let Err(e) = tracker.check_budget(
provider_name,
model_name,
global_limit,
Expand Down
5 changes: 3 additions & 2 deletions src/server/dispatch/provider_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
//! is written before returning `RequestError::ProviderUpstream`.

use super::super::{
check_budget, format_route_type, inject_continuation_text, is_provider_subscription,
check_budget_for_tenant, format_route_type, inject_continuation_text, is_provider_subscription,
should_inject_continuation, RequestError,
};
use super::resolver::{resolve_provider, try_direct_provider_lookup};
Expand Down Expand Up @@ -55,11 +55,12 @@ pub(super) async fn dispatch_provider_loop(
continue;
};

check_budget(
check_budget_for_tenant(
ctx.state,
ctx.inner,
&mapping.provider,
&decision.model_name,
ctx.tenant_id.as_deref(),
)
.await?;

Expand Down
27 changes: 21 additions & 6 deletions src/server/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,30 @@ use super::{
responses_compat, should_apply_transparency, AppState, RequestError, RequestId,
};

/// Extracts tenant_id from VirtualKeyContext (preferred) or GrobClaims.
/// Extracts tenant_id with this priority:
/// 1. VirtualKeyContext (operator-provisioned binding)
/// 2. JWT `tenant` claim
/// 3. `X-Tenant-ID` request header
///
/// JWT and VirtualKey paths cannot be overridden by the client header so a
/// caller cannot impersonate another tenant in authenticated mode. The
/// header path is only consulted when no authenticated tenant exists.
fn extract_tenant_id(
vk_ctx: &Option<axum::Extension<crate::auth::virtual_keys::VirtualKeyContext>>,
claims: &Option<axum::Extension<crate::auth::GrobClaims>>,
headers: &HeaderMap,
) -> Option<String> {
vk_ctx
.as_ref()
.map(|vk| vk.tenant_id.clone())
.or_else(|| claims.as_ref().map(|c| c.tenant_id().to_string()))
if let Some(vk) = vk_ctx.as_ref() {
return Some(vk.tenant_id.clone());
}
if let Some(c) = claims.as_ref() {
return Some(c.tenant_id().to_string());
}
headers
.get("x-tenant-id")
.and_then(|v| v.to_str().ok())
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
}

/// Drop guard that decrements the active request counter
Expand Down Expand Up @@ -108,7 +123,7 @@ fn prepare_dispatch(
vk_ctx: &Option<axum::Extension<crate::auth::virtual_keys::VirtualKeyContext>>,
headers: &HeaderMap,
) -> DispatchPrelude {
let tenant_id = extract_tenant_id(vk_ctx, claims);
let tenant_id = extract_tenant_id(vk_ctx, claims, headers);
let peer_ip = extract_client_ip(headers);
let inner = state.snapshot();
let session_key = tenant_id
Expand Down
60 changes: 60 additions & 0 deletions src/server/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,66 @@ pub(crate) async fn audit_log_layer(
response
}

/// Returns the tenant id derived from authentication context or headers.
///
/// Mirrors `handlers::extract_tenant_id` but uses request extensions /
/// headers directly so middleware can short-circuit before the handler.
fn middleware_tenant_id(request: &Request<Body>) -> Option<String> {
if let Some(vk) = request
.extensions()
.get::<crate::auth::virtual_keys::VirtualKeyContext>()
{
return Some(vk.tenant_id.clone());
}
if let Some(claims) = request.extensions().get::<crate::auth::GrobClaims>() {
return Some(claims.tenant_id().to_string());
}
request
.headers()
.get("x-tenant-id")
.and_then(|v| v.to_str().ok())
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
}

/// Enforces `[security] strict_tenant`.
///
/// When the flag is enabled, requests that fail to resolve a tenant id
/// (no virtual-key binding, no JWT `tenant` claim, no `X-Tenant-ID`
/// header) are rejected with HTTP 400 and a structured JSON body. Health
/// and OAuth endpoints are exempt because they are dispatched before any
/// tenant context exists.
pub(crate) async fn tenant_required_middleware(
State(state): State<Arc<AppState>>,
request: Request<Body>,
next: Next,
) -> Response {
let path = request.uri().path();
if matches!(
path,
"/health" | "/live" | "/ready" | "/metrics" | "/auth/callback" | "/api/oauth/callback"
) {
return next.run(request).await;
}

let inner = state.snapshot();
if !inner.config.security.strict_tenant {
return next.run(request).await;
}

if middleware_tenant_id(&request).is_some() {
return next.run(request).await;
}

let body = Json(serde_json::json!({
"error": {
"type": "missing_tenant",
"message": "X-Tenant-ID header or JWT tenant claim required when [security] strict_tenant=true"
}
}));
(StatusCode::BAD_REQUEST, body).into_response()
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
14 changes: 11 additions & 3 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ mod watch_sse;
pub use audit::AuditEntryBuilder;
pub(crate) use audit::{log_audit, AuditCompliance, AuditParams};
pub(crate) use budget::{
calculate_cost, check_budget, is_auth_revoked_error, is_provider_subscription, is_retryable,
record_request_metrics, record_spend, retry_delay, RequestMetrics, MAX_RETRIES,
calculate_cost, check_budget_for_tenant, is_auth_revoked_error, is_provider_subscription,
is_retryable, record_request_metrics, record_spend, retry_delay, RequestMetrics, MAX_RETRIES,
};
pub use error::{ErrorVariantTag, RequestError};
pub(crate) use helpers::{
Expand All @@ -51,7 +51,7 @@ pub(crate) use init::{
pub(crate) use middleware::{
apply_transparency_headers, audit_log_layer, auth_middleware, extract_api_credential,
extract_client_ip, rate_limit_check_middleware, request_id_middleware,
security_headers_response_middleware, should_apply_transparency,
security_headers_response_middleware, should_apply_transparency, tenant_required_middleware,
};
pub use middleware::{
capture_audit_input, emit_request_processed, AuditMiddlewareCapture, AuditedAlready, RequestId,
Expand Down Expand Up @@ -382,6 +382,14 @@ fn build_app_router(config: &AppConfig, state: Arc<AppState>) -> axum::Router {
app
};

// tenant_required runs *after* auth so the GrobClaims / VirtualKeyContext
// are already populated; in axum the from_fn applied first becomes the
// innermost layer, so it must be added before auth_middleware below.
let app = app.layer(axum::middleware::from_fn_with_state(
state.clone(),
tenant_required_middleware,
));

let app = app.layer(axum::middleware::from_fn_with_state(
state.clone(),
auth_middleware,
Expand Down
Loading
Loading