diff --git a/crates/config/src/lib.rs b/crates/config/src/lib.rs index 7135300d0..5529dfb55 100644 --- a/crates/config/src/lib.rs +++ b/crates/config/src/lib.rs @@ -332,6 +332,11 @@ pub struct ConfigToml { pub tools: Option, #[serde(default)] pub providers: ProvidersToml, + /// Provider fallback chain (#2574). When the active provider returns a + /// retryable error (429, 5xx, timeout), CodeWhale tries the next provider + /// in this list without user intervention. + #[serde(default)] + pub fallback_providers: Vec, /// Per-domain network policy (#135). When absent, network tools fall back /// to a permissive default that mirrors pre-v0.7.0 behavior. #[serde(default)] @@ -357,6 +362,63 @@ pub struct ConfigToml { pub extras: BTreeMap, } +// ── Provider Fallback Chain (#2574) ───────────────────────────────── + +/// Represents a position within the fallback chain during a session. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ProviderChain { + /// The full fallback chain: [active, fallback_1, fallback_2, ...]. + pub providers: Vec, + /// Current position in the chain (0 = active provider). + pub position: usize, +} + +impl ProviderChain { + /// Build a chain from the active provider and optional fallbacks. + /// The active provider is always at position 0. Duplicates are removed. 
+ #[must_use] + pub fn new(active: ProviderKind, fallbacks: &[ProviderKind]) -> Self { + let mut providers = vec![active]; + for fb in fallbacks { + if *fb != active && !providers.contains(fb) { + providers.push(*fb); + } + } + Self { + providers, + position: 0, + } + } + + pub fn current(&self) -> ProviderKind { + self.providers + .get(self.position) + .copied() + .unwrap_or(self.providers[0]) + } + + pub fn has_next(&self) -> bool { + self.position + 1 < self.providers.len() + } + + pub fn advance(&mut self) -> Option { + if self.has_next() { + self.position += 1; + Some(self.current()) + } else { + None + } + } + + pub fn is_fallback_active(&self) -> bool { + self.position > 0 + } + + pub fn remaining(&self) -> usize { + self.providers.len().saturating_sub(self.position) + } +} + /// On-disk schema for the `[hook_sinks]` table. #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct HookSinksToml { @@ -5087,4 +5149,77 @@ model = "mimo-v2.5-pro" assert_eq!(resolved.api_key.as_deref(), Some("cli-key")); assert_eq!(resolved.api_key_source, Some(RuntimeApiKeySource::Cli)); } + + // ── ProviderChain tests (#2574) ───────────────────────────── + + #[test] + fn provider_chain_initial_current_is_active() { + let chain = ProviderChain::new( + ProviderKind::NvidiaNim, + &[ProviderKind::Deepseek, ProviderKind::Openrouter], + ); + assert_eq!(chain.current(), ProviderKind::NvidiaNim); + assert_eq!(chain.position, 0); + assert!(!chain.is_fallback_active()); + } + + #[test] + fn provider_chain_advance_switches_to_fallback() { + let mut chain = ProviderChain::new( + ProviderKind::NvidiaNim, + &[ProviderKind::Deepseek, ProviderKind::Openrouter], + ); + assert!(chain.has_next()); + let next = chain.advance(); + assert_eq!(next, Some(ProviderKind::Deepseek)); + assert_eq!(chain.current(), ProviderKind::Deepseek); + assert!(chain.is_fallback_active()); + } + + #[test] + fn provider_chain_exhausts_returns_none() { + let mut chain = 
ProviderChain::new(ProviderKind::Deepseek, &[ProviderKind::Openrouter]); + assert!(chain.advance().is_some()); // -> Openrouter + assert!(!chain.has_next()); + assert_eq!(chain.advance(), None); + } + + #[test] + fn provider_chain_skips_duplicates() { + let chain = ProviderChain::new( + ProviderKind::Deepseek, + &[ + ProviderKind::Deepseek, + ProviderKind::NvidiaNim, + ProviderKind::Deepseek, + ], + ); + assert_eq!(chain.providers.len(), 2); + assert_eq!( + chain.providers, + vec![ProviderKind::Deepseek, ProviderKind::NvidiaNim] + ); + } + + #[test] + fn provider_chain_remaining_counts_correctly() { + let chain = ProviderChain::new( + ProviderKind::Deepseek, + &[ProviderKind::NvidiaNim, ProviderKind::Openrouter], + ); + assert_eq!(chain.remaining(), 3); + } + + #[test] + fn config_toml_parses_fallback_providers() { + let toml_str = r#" +provider = "nvidia-nim" +fallback_providers = ["deepseek", "openrouter"] +"#; + let config: ConfigToml = toml::from_str(toml_str).unwrap(); + assert_eq!(config.provider, ProviderKind::NvidiaNim); + assert_eq!(config.fallback_providers.len(), 2); + assert_eq!(config.fallback_providers[0], ProviderKind::Deepseek); + assert_eq!(config.fallback_providers[1], ProviderKind::Openrouter); + } } diff --git a/crates/tui/src/commands/provider.rs b/crates/tui/src/commands/provider.rs index 911e6299b..5fba37907 100644 --- a/crates/tui/src/commands/provider.rs +++ b/crates/tui/src/commands/provider.rs @@ -28,6 +28,11 @@ pub fn provider(app: &mut App, args: Option<&str>) -> CommandResult { let name = parts.next().unwrap_or(""); let model_arg = parts.next(); + // `/provider fallback` — show or reset the fallback chain. + if name == "fallback" { + return provider_fallback(app, model_arg); + } + let Some(target) = ApiProvider::parse(name) else { return CommandResult::error(format!( "Unknown provider '{name}'. 
Expected: {}.", @@ -70,6 +75,40 @@ pub fn provider(app: &mut App, args: Option<&str>) -> CommandResult { }) } +/// `/provider fallback` — shows the current fallback chain and status, +/// or resets it with `/provider fallback reset`. +fn provider_fallback(app: &mut App, sub: Option<&str>) -> CommandResult { + match sub { + Some("reset") => { + app.reset_fallback(); + CommandResult::message("Fallback chain reset to primary provider.") + } + _ => { + if app.fallback_providers.is_empty() { + return CommandResult::message( + "No fallback providers configured. Add `fallback_providers` to your config.", + ); + } + let active = app.api_provider.as_str().to_string(); + let current_fallback = app.fallback_depth; + let mut lines = vec![format!("Active: {active}")]; + for (i, name) in app.fallback_providers.iter().enumerate() { + let marker = if current_fallback == Some(i) { + " ◀ current" + } else { + "" + }; + lines.push(format!(" [{i}] {name}{marker}")); + } + if let Some(ref reason) = app.last_fallback_reason { + lines.push(format!("Last fallback: {reason}")); + } + lines.push("Use `/provider fallback reset` to return to the primary provider.".into()); + CommandResult::message(lines.join("\n")) + } + } +} + fn expand_model_alias_for_provider(provider: ApiProvider, name: &str) -> String { let lower = name.trim().to_ascii_lowercase(); if matches!(provider, ApiProvider::XiaomiMimo) { diff --git a/crates/tui/src/tui/app.rs b/crates/tui/src/tui/app.rs index 3eb494f16..1fe956442 100644 --- a/crates/tui/src/tui/app.rs +++ b/crates/tui/src/tui/app.rs @@ -1212,6 +1212,14 @@ pub struct App { /// Updated by `/provider` switches so the UI/commands can read the /// active backend without re-deriving it from the live config. pub api_provider: ApiProvider, + /// Provider fallback providers in route-name form (#2574). + /// e.g. `["deepseek", "openrouter"]`. Empty when no fallbacks configured. + pub fallback_providers: Vec, + /// Current position in the fallback chain. 
`Some(i)` is the index + /// of the active entry in `fallback_providers`. `None` when fallback is not active. + pub fallback_depth: Option, + /// Human-readable description of the last fallback event (for UI display). + pub last_fallback_reason: Option, /// True when the active provider/base URL accepts arbitrary model IDs /// verbatim rather than DeepSeek-only aliases. pub model_ids_passthrough: bool, @@ -2002,6 +2010,9 @@ impl App { auto_model, last_effective_model: None, api_provider: provider, + fallback_providers: Vec::new(), + fallback_depth: None, + last_fallback_reason: None, model_ids_passthrough, reasoning_effort, last_effective_reasoning_effort: None, @@ -4938,6 +4949,65 @@ pub enum AppAction { }, } +// ── Provider Fallback helpers (#2574) ──────────────────────────── + +impl App { + /// Advance to the next provider in the fallback chain. Call this when + /// a retryable error (429, 5xx, timeout) exhausts per-request retries. + /// Returns `true` if fallback executed. + #[allow(dead_code)] // Called by runtime integration (follow-up PR) + pub fn advance_fallback(&mut self, reason: impl Into) -> bool { + if self.fallback_providers.is_empty() { + return false; + } + // When fallback_depth is None, the primary provider is active. + // The first fallback goes to index 0, not index 1. 
+        let next_depth = match self.fallback_depth {
+            None => 0,
+            Some(d) => d + 1,
+        };
+        if next_depth >= self.fallback_providers.len() {
+            self.last_fallback_reason = Some(format!(
+                "Fallback chain exhausted after {}: {}",
+                self.fallback_providers.len(),
+                reason.into()
+            ));
+            return false;
+        }
+        let next_name = &self.fallback_providers[next_depth];
+        if let Some(next_provider) = ApiProvider::parse(next_name) {
+            self.fallback_depth = Some(next_depth);
+            self.last_fallback_reason =
+                Some(format!("Fell back to {}: {}", next_name, reason.into()));
+            self.api_provider = next_provider;
+            true
+        } else {
+            self.last_fallback_reason = Some(format!("Unknown fallback provider: {next_name}"));
+            false
+        }
+    }
+
+    /// Clear the fallback position and reason. NOTE(review): does not restore `api_provider`.
+    pub fn reset_fallback(&mut self) {
+        self.fallback_depth = None;
+        self.last_fallback_reason = None;
+    }
+
+    /// Whether a fallback provider is active: any `Some(_)` depth, including index 0.
+    pub fn is_fallback_active(&self) -> bool {
+        self.fallback_depth.is_some()
+    }
+
+    /// Initialize fallback providers from the on-disk ConfigToml.
+ #[allow(dead_code)] // Called at startup (follow-up PR) + pub fn load_fallback_from_toml(&mut self, raw_providers: &[codewhale_config::ProviderKind]) { + self.fallback_providers = raw_providers + .iter() + .map(|k| k.as_str().to_string()) + .collect(); + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum ShellJobAction { List, diff --git a/crates/tui/src/tui/footer_ui.rs b/crates/tui/src/tui/footer_ui.rs index 02dc8ce55..2ae91682a 100644 --- a/crates/tui/src/tui/footer_ui.rs +++ b/crates/tui/src/tui/footer_ui.rs @@ -858,6 +858,9 @@ pub(crate) fn footer_status_line_spans(app: &App, max_width: usize) -> Vec (&'static str, ratatui::style::Color) { + if app.is_fallback_active() { + return ("fallback \u{2192}", app.ui_theme.status_warning); + } if app.is_compacting { return ("compacting \u{238B}", app.ui_theme.status_warning); } diff --git a/crates/tui/src/tui/ui.rs b/crates/tui/src/tui/ui.rs index b23f4fadf..6c44b6a74 100644 --- a/crates/tui/src/tui/ui.rs +++ b/crates/tui/src/tui/ui.rs @@ -4502,6 +4502,30 @@ pub(crate) fn apply_engine_error_to_app( ); return; } + // Provider fallback: when the error is recoverable (429, 5xx, timeout, + // network) and fallback providers are configured, advance the chain + // instead of going offline. The user can re-submit manually or via + // undo (Ctrl+Z). + if recoverable + && matches!( + envelope.category, + crate::error_taxonomy::ErrorCategory::Network + | crate::error_taxonomy::ErrorCategory::RateLimit + | crate::error_taxonomy::ErrorCategory::Timeout + ) + && !app.fallback_providers.is_empty() + { + let advanced = app.advance_fallback(&message); + if advanced { + app.status_message = Some(format!( + "Switched to {} (fallback {}/{}): {message}", + app.api_provider.as_str(), + app.fallback_depth.map_or(0, |d| d + 1), + app.fallback_providers.len() + )); + return; + } + } if !recoverable { app.offline_mode = true; }