diff --git a/.env.example b/.env.example index 95db3ebb..3bd9906d 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,17 @@ # Grok API Configuration GROK_API_KEY=your_grok_api_key_here +# Optional: Use Google Cloud Vertex AI instead of a native xAI API key. +# Requires Application Default Credentials: +# gcloud auth application-default login +# GROK_USE_VERTEX=1 +# GROK_VERTEX_PROJECT_ID=your-gcp-project-id +# GROK_VERTEX_LOCATION=us-central1 +# Advanced/custom environments only; default is the global Vertex host below. +# GROK_VERTEX_BASE_URL=https://aiplatform.googleapis.com +# Emergency fallback only: disable Vertex tool/function declarations. +# GROK_VERTEX_DISABLE_TOOLS=1 + # Optional: Custom API base URL (default: https://api.x.ai/v1) # GROK_BASE_URL=https://api.x.ai/v1 diff --git a/.npmignore b/.npmignore index 11ba487f..1fcddf39 100644 --- a/.npmignore +++ b/.npmignore @@ -1,6 +1,9 @@ # Source files (only include built dist/) src/ tsconfig.json +vitest.config.ts +biome.json +test-vertex-integration.ts # Lock files from other package managers yarn.lock @@ -50,6 +53,9 @@ tests/ *.test.ts *.spec.js *.spec.ts +dist/**/*.test.* +dist/**/*.spec.* +dist/grok-standalone* coverage/ .nyc_output/ @@ -62,6 +68,15 @@ docs/ .git/ .gitignore +# Local development, agent, and editor runtime state +.cursor/ +.omx/ +.codex/ +.agents/ +.claude/ +.grok/ +.husky/ + # CI/CD files .github/ .gitlab-ci.yml @@ -77,4 +92,4 @@ appveyor.yml # Include only what's needed for the package # The dist/ folder should be included (not ignored) -# package.json and README.md should be included \ No newline at end of file +# package.json and README.md should be included diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d10530c..ea8d72e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- Native Google Cloud Vertex AI support for Grok chat completions via `GROK_USE_VERTEX=1`, Application Default Credentials, and a Vertex payload/stream adapter. +- Vertex mode now forwards local CLI tools through sanitized Vertex function declarations, preserving directory/file/bash access while avoiding raw AI SDK schema fields that Vertex rejects. +- Vertex requests map native xAI model IDs such as `grok-4-1-fast-reasoning` to the corresponding Vertex publisher ID such as `grok-4.1-fast-reasoning`. +- Vertex ADC refresh failures such as Google `invalid_rapt` reauthentication errors now produce actionable `gcloud auth application-default` recovery guidance instead of raw auth JSON. +- Interactive Vertex AI authentication setup in the TUI, saving Grok-specific `GROK_VERTEX_PROJECT_ID`, `GROK_VERTEX_LOCATION`, and `GROK_VERTEX_BASE_URL` settings without relying on broad global GCP env names. - Dedicated grep tool powered by npm ripgrep WASM (#263) - `/btw` command for side questions (#264) @@ -151,4 +156,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [1.0.0-rc1] - 2026-03-20 -Initial release. \ No newline at end of file +Initial release. diff --git a/README.md b/README.md index 8450be5a..0311ae06 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ grok uninstall --dry-run grok uninstall --keep-config ``` -**Prerequisites:** a **Grok API key** from [x.ai](https://x.ai) and a modern terminal emulator for the interactive OpenTUI experience. Headless `--prompt` mode does not depend on terminal UI support. If you want host desktop automation via the built-in computer sub-agent, also enable **Accessibility** permission for your terminal app on macOS. +**Prerequisites:** either a **Grok API key** from [x.ai](https://x.ai) or **Google Cloud Vertex AI** access with Application Default Credentials. You also need a modern terminal emulator for the interactive OpenTUI experience. Headless `--prompt` mode does not depend on terminal UI support. If you want host desktop automation via the built-in computer sub-agent, also enable **Accessibility** permission for your terminal app on macOS. --- @@ -77,7 +77,8 @@ grok --verify `--batch-api` uses xAI's Batch API for lower-cost unattended runs. It is a good fit for scripts, CI, schedules, and other non-interactive workflows where a -delayed result is fine. +delayed result is fine. Batch mode is a native xAI endpoint and is not available +when `GROK_USE_VERTEX=1`. **Continue a saved session:** @@ -195,7 +196,11 @@ You keep using a text model for the session, and Grok saves generated media unde --- -## API key (pick one) +## Authentication + +Pick one model authentication path. + +### Native xAI API key **Environment (good for CI):** @@ -239,6 +244,36 @@ Names cannot be `general`, `explore`, `vision`, `verify`, or `computer` because Optional: `**GROK_BASE_URL**` (default `https://api.x.ai/v1`), `**GROK_MODEL**`, `**GROK_MAX_TOKENS**`. +### Google Cloud Vertex AI + +Vertex mode uses Google Application Default Credentials instead of an xAI API key: + +```bash +gcloud auth application-default login +export GROK_USE_VERTEX=1 +export GROK_VERTEX_PROJECT_ID=your-gcp-project-id +export GROK_VERTEX_LOCATION=us-central1 +grok --prompt "hello from Vertex" +``` + +`GROK_USE_VERTEX=1` bypasses `GROK_API_KEY` validation for chat completions and fetches a short-lived OAuth access token with the `cloud-platform` scope. + +In the interactive TUI, choose the **Vertex AI** tab in the authentication modal to save `GROK_VERTEX_PROJECT_ID`, `GROK_VERTEX_LOCATION`, and `GROK_VERTEX_BASE_URL` as Grok-specific user settings in `~/.grok/user-settings.json`. ADC still comes from Google Cloud's normal `gcloud auth application-default login` flow. + +Vertex Grok uses the global API host but a normal location path. By default requests go to: + +```text +https://aiplatform.googleapis.com/v1/projects/$GROK_VERTEX_PROJECT_ID/locations/us-central1/publishers/xai/models/$GROK_MODEL:generateContent +``` + +The adapter maps native xAI model IDs to Vertex publisher model IDs where they differ, for example `grok-4-1-fast-reasoning` becomes `grok-4.1-fast-reasoning` on the Vertex request path. + +Set `GROK_VERTEX_LOCATION` to choose the location path, for example `us-central1` or `europe-west1`. The host defaults to `https://aiplatform.googleapis.com`; `GROK_VERTEX_BASE_URL` is available only for advanced/custom environments. The broader `GCP_PROJECT_ID`, `GCP_REGION`, and `GCP_VERTEX_*` variables remain compatibility fallbacks, but the `GROK_VERTEX_*` names take precedence to avoid clashing with other Google Cloud tools. + +Vertex mode forwards local CLI tools such as `bash`, `read_file`, `grep`, `write_file`, and `edit_file` through Vertex function declarations. The adapter sanitizes the AI SDK JSON schemas into the OpenAPI-style schema subset accepted by Vertex AI. If a specific Google-side model rollout rejects tool declarations, set `GROK_VERTEX_DISABLE_TOOLS=1` as an emergency fallback; that disables local CLI tool access for Vertex mode. + +Native xAI-only endpoints remain native xAI-only in Vertex mode: `--batch-api`, live X/web search tools, image/video generation, and Telegram audio transcription require `GROK_USE_VERTEX` unset plus a configured `GROK_API_KEY`. + --- ## Telegram (remote control) — short version @@ -254,7 +289,7 @@ Send a voice note or audio attachment in Telegram and Grok will transcribe it wi #### Prerequisites -- A valid `GROK_API_KEY` (the same key used for the agent). Transcription reuses the CLI's `apiKey` / `baseURL` resolution, so if the agent can reach xAI, transcription will too. +- A valid `GROK_API_KEY` (the same key used for the agent). Transcription reuses the CLI's `apiKey` / `baseURL` resolution, so if the agent can reach xAI, transcription will too. Vertex ADC alone does not support the native xAI STT endpoint. #### Configure in `~/.grok/user-settings.json` @@ -401,6 +436,31 @@ grok -k your_key_here Get your API key from [x.ai](https://x.ai). +**Vertex project or ADC error** + +For Vertex mode, verify all three are true: + +```bash +export GROK_USE_VERTEX=1 +export GROK_VERTEX_PROJECT_ID=your-gcp-project-id +gcloud auth application-default login +``` + +Use `GROK_VERTEX_LOCATION` for the resource path location. Keep the default host `https://aiplatform.googleapis.com` unless you are testing a custom endpoint. + +If Google returns an ADC reauthentication error such as `invalid_rapt`, refresh the local ADC token: + +```bash +gcloud auth application-default login +``` + +If the token cache is stuck, reset it: + +```bash +gcloud auth application-default revoke +gcloud auth application-default login +``` + ### Terminal UI issues **UI doesn't render correctly** @@ -426,6 +486,7 @@ Ensure your terminal supports true color and Unicode. Update your terminal emula **Voice messages not transcribing** - Verify `GROK_API_KEY` is set (transcription uses the same key) +- If `GROK_USE_VERTEX=1`, unset it for Telegram audio transcription or disable `telegram.audioInput.enabled` - Check `~/.grok/user-settings.json` has `telegram.audioInput.enabled: true` ### Sandbox mode diff --git a/bun.lock b/bun.lock index a77cbfb4..423c2bf8 100644 --- a/bun.lock +++ b/bun.lock @@ -18,6 +18,7 @@ "commander": "^12.1.0", "diff": "^8.0.3", "dotenv": "^16.6.1", + "google-auth-library": "^10.6.2", "grammy": "^1.41.1", "react": "^19.2.4", "ripgrep": "^0.3.1", @@ -795,6 +796,8 @@ "buffer": ["buffer@6.0.3", "", { "dependencies": { "base64-js": "^1.3.1", "ieee754": "^1.2.1" } }, "sha512-FTiCpNxtwiZZHEZbcbTIcZjERVICn9yq/pDFkTl95/AxzD1naBctN7YO68riM/gLSDY7sdrMby8hofADYuuqOA=="], + "buffer-equal-constant-time": ["buffer-equal-constant-time@1.0.1", "", {}, "sha512-zRpUiDwd/xk6ADqPMATG8vc9VPrkck7T07OIx0gnjmJAnHnTVXNQG3vfvWNuiZIkwu9KrKdA1iJKfsfTVxE6NA=="], + "buffer-reverse": ["buffer-reverse@1.0.1", "", {}, "sha512-M87YIUBsZ6N924W57vDwT/aOu8hw7ZgdByz6ijksLjmHJELBASmYTTlNHRgjE+pTsT9oJXGaDSgqqwfdHotDUg=="], "bufferutil": ["bufferutil@4.1.0", "", { "dependencies": { "node-gyp-build": "^4.3.0" } }, "sha512-ZMANVnAixE6AWWnPzlW2KpUrxhm9woycYvPOo67jWHyFowASTEd9s+QN1EIMsSDtwhIxN4sWE1jotpuDUIgyIw=="], @@ -905,6 +908,8 @@ "csstype": ["csstype@3.2.3", "", {}, "sha512-z1HGKcYy2xA8AGQfwrn0PAy+PB7X/GSj3UVJW9qKyn43xWa+gl5nXmU4qqLMRzWVLFC8KusUX8T/0kCiOYpAIQ=="], + "data-uri-to-buffer": ["data-uri-to-buffer@4.0.1", "", {}, "sha512-0R9ikRb668HB7QDxT1vkpuUBtqc53YyAwMwGeUFKRojY/NWKvdZ+9UYtRfGmhqNbRkTSVpMbmyhXipFFv2cb/A=="], + "date-fns": ["date-fns@3.3.1", "", {}, "sha512-y8e109LYGgoQDveiEBD3DYXKba1jWf5BA8YU1FL5Tvm0BTdEfy54WLCwnuYWZNnzzvALy/QQ4Hov+Q9RVRv+Zw=="], "dayjs": ["dayjs@1.11.13", "", {}, "sha512-oaMBel6gjolK862uaPQOVTA7q3TZhuSvuMQAAglQDOWYO9A91IrAOUJEyKVlqJlHE0vq5p5UXxzdPfMH/x6xNg=="], @@ -947,6 +952,8 @@ "duplexify": ["duplexify@4.1.3", "", { "dependencies": { "end-of-stream": "^1.4.1", "inherits": "^2.0.3", "readable-stream": "^3.1.1", "stream-shift": "^1.0.2" } }, "sha512-M3BmBhwJRZsSx38lZyhE53Csddgzl5R7xGJNk7CVddZD6CcmwMCH8J+7AprIrQKH7TonKxaCjcv27Qmf+sQ+oA=="], + "ecdsa-sig-formatter": ["ecdsa-sig-formatter@1.0.11", "", { "dependencies": { "safe-buffer": "^5.0.1" } }, "sha512-nagl3RYrbNv6kQkeJIpt6NJZy8twLB/2vtz6yN9Z4vRKHN4/QZJIEbqohALSgwKdnksuY3k5Addp5lg8sVoVcQ=="], + "eciesjs": ["eciesjs@0.4.18", "", { "dependencies": { "@ecies/ciphers": "^0.2.5", "@noble/ciphers": "^1.3.0", "@noble/curves": "^1.9.7", "@noble/hashes": "^1.8.0" } }, "sha512-wG99Zcfcys9fZux7Cft8BAX/YrOJLJSZ3jyYPfhZHqN2E+Ffx+QXBDsv3gubEgPtV6dTzJMSQUwk1H98/t/0wQ=="], "ed2curve": ["ed2curve@0.3.0", "", { "dependencies": { "tweetnacl": "1.x.x" } }, "sha512-8w2fmmq3hv9rCrcI7g9hms2pMunQr1JINfcjwR9tAyZqhtyaMN991lF/ZfHfr5tzZQ8c7y7aBgZbjfbd0fjFwQ=="], @@ -1033,6 +1040,8 @@ "express-rate-limit": ["express-rate-limit@8.3.1", "", { "dependencies": { "ip-address": "10.1.0" }, "peerDependencies": { "express": ">= 4.11" } }, "sha512-D1dKN+cmyPWuvB+G2SREQDzPY1agpBIcTa9sJxOPMCNeH3gwzhqJRDWCXW3gg0y//+LQ/8j52JbMROWyrKdMdw=="], + "extend": ["extend@3.0.2", "", {}, "sha512-fjquC59cD7CyW6urNXK0FBufkZcoiGG80wTuPujX590cB5Ttln20E2UB4S/WARVqhXffZl2LNgS+gQdPIIim/g=="], + "extension-port-stream": ["extension-port-stream@3.0.0", "", { "dependencies": { "readable-stream": "^3.6.2 || ^4.4.2", "webextension-polyfill": ">=0.10.0 <1.0" } }, "sha512-an2S5quJMiy5bnZKEf6AkfH/7r8CzHvhchU40gxN+OM6HPhe7Z9T1FUychcf2M9PpPOO0Hf7BAEfJkw2TDIBDw=="], "eyes": ["eyes@0.1.8", "", {}, "sha512-GipyPsXO1anza0AOZdy69Im7hGFCNB7Y/NGjDlZGJ3GJJLtwNSb2vrzYrTYJRrRloVx7pl+bhUaTB8yiccPvFQ=="], @@ -1053,6 +1062,8 @@ "fdir": ["fdir@6.5.0", "", { "peerDependencies": { "picomatch": "^3 || ^4" }, "optionalPeers": ["picomatch"] }, "sha512-tIbYtZbucOs0BRGqPJkshJUYdL+SDH7dVM8gjy+ERp3WAUjLEFJE+02kanyHtwjWOnwrKYBiwAmM0p4kLJAnXg=="], + "fetch-blob": ["fetch-blob@3.2.0", "", { "dependencies": { "node-domexception": "^1.0.0", "web-streams-polyfill": "^3.0.3" } }, "sha512-7yAQpD2UMJzLi1Dqv7qFYnPbaPx7ZfFK6PiIxQ4PfkGPyNyl2Ugx+a/umUonmKqjhM4DnfbMvdX6otXq83soQQ=="], + "figures": ["figures@3.2.0", "", { "dependencies": { "escape-string-regexp": "^1.0.5" } }, "sha512-yaduQFRKLXYOGgEn6AZau90j3ggSOyiqXU0F9JZfeXYhNa+Jk4X+s45A2zg5jns87GAFa34BBm2kXw4XpNcbdg=="], "file-type": ["file-type@16.5.4", "", { "dependencies": { "readable-web-to-node-stream": "^3.0.0", "strtok3": "^6.2.4", "token-types": "^4.1.1" } }, "sha512-/yFHK0aGjFEgDJjEKP0pWCplsPFPhwyfwevf/pVxiN0tmE4L9LmwWxWukdJSHdoCli4VgQLehjJtwQBnqmsKcw=="], @@ -1071,6 +1082,8 @@ "form-data": ["form-data@4.0.5", "", { "dependencies": { "asynckit": "^0.4.0", "combined-stream": "^1.0.8", "es-set-tostringtag": "^2.1.0", "hasown": "^2.0.2", "mime-types": "^2.1.12" } }, "sha512-8RipRLol37bNs2bhoV67fiTEvdTrbMUYcFTiy3+wuuOnUog2QBHCZWXDRijWQfAkhBj2Uf5UnVaiWwA5vdd82w=="], + "formdata-polyfill": ["formdata-polyfill@4.0.10", "", { "dependencies": { "fetch-blob": "^3.1.2" } }, "sha512-buewHzMvYL29jdeQTVILecSaZKnt/RJWjoZCF5OW60Z67/GmSLBkOFM7qh1PI3zFNtJbaZL5eQu1vLfazOwj4g=="], + "forwarded": ["forwarded@0.2.0", "", {}, "sha512-buRG0fpBtRHSTCOASe6hD258tEubFoRLb4ZNA6NxMVHNw2gOcwHo9wyablzMzOA5z9xA9L1KNjk/Nt6MT9aYow=="], "fresh": ["fresh@2.0.0", "", {}, "sha512-Rx/WycZ60HOaqLKAi6cHRKKI7zxWbJ31MhntmtwMoaTeF7XFH9hhBp8vITaMidfljRQ6eYWCKkaTK+ykVJHP2A=="], @@ -1081,6 +1094,10 @@ "function-bind": ["function-bind@1.1.2", "", {}, "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA=="], + "gaxios": ["gaxios@7.1.4", "", { "dependencies": { "extend": "^3.0.2", "https-proxy-agent": "^7.0.1", "node-fetch": "^3.3.2" } }, "sha512-bTIgTsM2bWn3XklZISBTQX7ZSddGW+IO3bMdGaemHZ3tbqExMENHLx6kKZ/KlejgrMtj8q7wBItt51yegqalrA=="], + + "gcp-metadata": ["gcp-metadata@8.1.2", "", { "dependencies": { "gaxios": "^7.0.0", "google-logging-utils": "^1.0.0", "json-bigint": "^1.0.0" } }, "sha512-zV/5HKTfCeKWnxG0Dmrw51hEWFGfcF2xiXqcA3+J90WDuP0SvoiSO5ORvcBsifmx/FoIjgQN3oNOGaQ5PhLFkg=="], + "generator-function": ["generator-function@2.0.1", "", {}, "sha512-SFdFmIJi+ybC0vjlHN0ZGVGHc3lgE0DxPAT0djjVg+kjOnSqclqmj0KQ7ykTOLP6YxoqOvuAODGdcHJn+43q3g=="], "get-caller-file": ["get-caller-file@2.0.5", "", {}, "sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg=="], @@ -1095,6 +1112,10 @@ "glob": ["glob@13.0.6", "", { "dependencies": { "minimatch": "^10.2.2", "minipass": "^7.1.3", "path-scurry": "^2.0.2" } }, "sha512-Wjlyrolmm8uDpm/ogGyXZXb1Z+Ca2B8NbJwqBVg0axK9GbBeoS7yGV6vjXnYdGm6X53iehEuxxbyiKp8QmN4Vw=="], + "google-auth-library": ["google-auth-library@10.6.2", "", { "dependencies": { "base64-js": "^1.3.0", "ecdsa-sig-formatter": "^1.0.11", "gaxios": "^7.1.4", "gcp-metadata": "8.1.2", "google-logging-utils": "1.1.3", "jws": "^4.0.0" } }, "sha512-e27Z6EThmVNNvtYASwQxose/G57rkRuaRbQyxM2bvYLLX/GqWZ5chWq2EBoUchJbCc57eC9ArzO5wMsEmWftCw=="], + + "google-logging-utils": ["google-logging-utils@1.1.3", "", {}, "sha512-eAmLkjDjAFCVXg7A1unxHsLf961m6y17QFqXqAXGj/gVkKFrEICfStRfwUlGNfeCEjNRa32JEWOUTlYXPyyKvA=="], + "gopd": ["gopd@1.2.0", "", {}, "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg=="], "graceful-fs": ["graceful-fs@4.2.11", "", {}, "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ=="], @@ -1205,6 +1226,8 @@ "js-sha3": ["js-sha3@0.8.0", "", {}, "sha512-gF1cRrHhIzNfToc802P800N8PpXS+evLLXfsVpowqmAFR9uwbi89WvXg2QspOmXL8QL86J4T1EpFu+yUkwJY3Q=="], + "json-bigint": ["json-bigint@1.0.0", "", { "dependencies": { "bignumber.js": "^9.0.0" } }, "sha512-SiPv/8VpZuWbvLSMtTDU8hEfrZWg/mH/nV/b4o0CYbSxu1UIQPLdwKOCIyLQX+VIPO5vrLX3i8qtqFyhdPSUSQ=="], + "json-parse-even-better-errors": ["json-parse-even-better-errors@5.0.0", "", {}, "sha512-ZF1nxZ28VhQouRWhUcVlUIN3qwSgPuswK05s/HIaoetAoE/9tngVmCHjSxmSQPav1nd+lPtTL0YZ/2AFdR/iYQ=="], "json-rpc-engine": ["json-rpc-engine@6.1.0", "", { "dependencies": { "@metamask/safe-event-emitter": "^2.0.0", "eth-rpc-errors": "^4.0.2" } }, "sha512-NEdLrtrq1jUZyfjkr9OCz9EzCNhnRyWtt1PAnvnhwy6e8XETS0Dtc+ZNCO2gvuAoKsIn2+vCSowXTYE4CkgnAQ=="], @@ -1227,6 +1250,10 @@ "just-diff-apply": ["just-diff-apply@5.5.0", "", {}, "sha512-OYTthRfSh55WOItVqwpefPtNt2VdKsq5AnAK6apdtR6yCH8pr0CmSr710J0Mf+WdQy7K/OzMy7K2MgAfdQURDw=="], + "jwa": ["jwa@2.0.1", "", { "dependencies": { "buffer-equal-constant-time": "^1.0.1", "ecdsa-sig-formatter": "1.0.11", "safe-buffer": "^5.0.1" } }, "sha512-hRF04fqJIP8Abbkq5NKGN0Bbr3JxlQ+qhZufXVr0DvujKy93ZCbXZMHDL4EOtodSbCWxOqR8MS1tXA5hwqCXDg=="], + + "jws": ["jws@4.0.1", "", { "dependencies": { "jwa": "^2.0.1", "safe-buffer": "^5.0.1" } }, "sha512-EKI/M/yqPncGUUh44xz0PxSidXFr/+r0pA70+gIYhjv+et7yxM+s29Y+VGDkovRofQem0fs7Uvf4+YmAdyRduA=="], + "keccak": ["keccak@3.0.4", "", { "dependencies": { "node-addon-api": "^2.0.0", "node-gyp-build": "^4.2.0", "readable-stream": "^3.6.0" } }, "sha512-3vKuW0jV8J3XNTzvfyicFR5qvxrSAGl7KIhvgOu5cmWwM7tZRj3fMbj/pfIf4be7aznbc+prBWGjywox/g2Y6Q=="], "keyvaluestorage-interface": ["keyvaluestorage-interface@1.0.0", "", {}, "sha512-8t6Q3TclQ4uZynJY9IGr2+SsIGwK9JHcO6ootkHCGA0CrQCRy+VkouYNO2xicET6b9al7QKzpebNow+gkpCL8g=="], @@ -1341,6 +1368,8 @@ "node-addon-api": ["node-addon-api@5.1.0", "", {}, "sha512-eh0GgfEkpnoWDq+VY8OyvYhFEzBk6jIYbRKdIlyTiAXIVJ8PyBaKb0rp7oDtoddbdoHWhq8wwr+XZ81F1rpNdA=="], + "node-domexception": ["node-domexception@1.0.0", "", {}, "sha512-/jKZoMpw0F8GRwl4/eLROPA3cfcXtLApP0QzLmUT/HuPCZWyB7IY9ZrMeKw2O/nFIqPQB3PVM9aYm0F312AXDQ=="], + "node-fetch": ["node-fetch@2.7.0", "", { "dependencies": { "whatwg-url": "^5.0.0" }, "peerDependencies": { "encoding": "^0.1.0" }, "optionalPeers": ["encoding"] }, "sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A=="], "node-fetch-native": ["node-fetch-native@1.6.7", "", {}, "sha512-g9yhqoedzIUm0nTnTqAQvueMPVOuIY16bqgAJJC8XOOubYFNwz6IER9qs0Gq2Xd0+CecCKFjtdDTMA4u4xG06Q=="], @@ -1773,6 +1802,8 @@ "wcwidth": ["wcwidth@1.0.1", "", { "dependencies": { "defaults": "^1.0.3" } }, "sha512-XHPEwS0q6TaxcvG85+8EYkbiCux2XtWG2mkc47Ng2A77BQu9+DqIOJldST4HgPkuea7dvKSj5VgX3P1d4rW8Tg=="], + "web-streams-polyfill": ["web-streams-polyfill@3.3.3", "", {}, "sha512-d2JWLCivmZYTSIoge9MsgFCZrt571BikcWGYkjC1khllbTeDlGqZ2D8vD8E/lJa8WGWbb7Plm8/XJYV7IJHZZw=="], + "web-tree-sitter": ["web-tree-sitter@0.25.10", "", { "peerDependencies": { "@types/emscripten": "^1.40.0" }, "optionalPeers": ["@types/emscripten"] }, "sha512-Y09sF44/13XvgVKgO2cNDw5rGk6s26MgoZPXLESvMXeefBf7i6/73eFurre0IsTW6E14Y0ArIzhUMmjoc7xyzA=="], "web3-utils": ["web3-utils@1.10.4", "", { "dependencies": { "@ethereumjs/util": "^8.1.0", "bn.js": "^5.2.1", "ethereum-bloom-filters": "^1.0.6", "ethereum-cryptography": "^2.1.2", "ethjs-unit": "0.1.6", "number-to-bn": "1.7.0", "randombytes": "^2.1.0", "utf8": "3.0.0" } }, "sha512-tsu8FiKJLk2PzhDl9fXbGUWTkkVXYhtTA+SmEFkKft+9BgwLxfCRpU96sWv7ICC8zixBNd3JURVoiR3dUXgP8A=="], @@ -2167,6 +2198,8 @@ "form-data/mime-types": ["mime-types@2.1.35", "", { "dependencies": { "mime-db": "1.52.0" } }, "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw=="], + "gaxios/node-fetch": ["node-fetch@3.3.2", "", { "dependencies": { "data-uri-to-buffer": "^4.0.0", "fetch-blob": "^3.1.4", "formdata-polyfill": "^4.0.10" } }, "sha512-dRB78srN/l6gqWulah9SrxeYnxeddIG30+GOqK/9OlLVyLg3HPnr6SqOWTWOXKRwC2eGYCkZ59NNuSgvSrpgOA=="], + "hash-base/readable-stream": ["readable-stream@2.3.8", "", { "dependencies": { "core-util-is": "~1.0.0", "inherits": "~2.0.3", "isarray": "~1.0.0", "process-nextick-args": "~2.0.0", "safe-buffer": "~5.1.1", "string_decoder": "~1.1.1", "util-deprecate": "~1.0.1" } }, "sha512-8p0AUk4XODgIewSi0l8Epjs+EVnWiK7NoDIEGU0HhE7+ZyY8D1IMY7odu5lRrFXGg71L15KG8QrPmum45RTtdA=="], "image-q/@types/node": ["@types/node@16.9.1", "", {}, "sha512-QpLcX9ZSsq3YYUUnD3nFDY8H7wctAhQj/TFKL8Ya8v5fMm3CFXxo8zStsLAl780ltoYoo1WvKUVGBQK+1ifr7g=="], diff --git a/package.json b/package.json index d4226879..4a505aa8 100644 --- a/package.json +++ b/package.json @@ -15,11 +15,14 @@ }, "scripts": { "dev": "bun run src/index.ts", - "build": "tsc", + "clean:dist": "bun -e \"import { rmSync } from 'node:fs'; rmSync('dist', { recursive: true, force: true })\"", + "build": "bun run clean:dist && tsc && bun -e \"import { chmodSync } from 'node:fs'; chmodSync('dist/index.js', 0o755)\"", "build:binary": "bun build --compile --outfile dist/grok-standalone ./src/index.ts", + "prepack": "bun run build", "start": "bun run dist/index.js", "typecheck": "tsc --noEmit", "test": "bunx vitest run", + "test:vertex": "bun run test-vertex-integration.ts", "test:watch": "bunx vitest", "lint": "biome check src/", "format": "biome format src/", @@ -56,6 +59,7 @@ "commander": "^12.1.0", "diff": "^8.0.3", "dotenv": "^16.6.1", + "google-auth-library": "^10.6.2", "grammy": "^1.41.1", "react": "^19.2.4", "ripgrep": "^0.3.1", diff --git a/src/agent/agent.ts b/src/agent/agent.ts index 2c76420a..3d26dba1 100644 --- a/src/agent/agent.ts +++ b/src/agent/agent.ts @@ -67,11 +67,15 @@ import { loadCustomInstructions } from "../utils/instructions"; import { type CustomSubagentConfig, getCurrentModel, + getModelAuthStatus, getModeSpecificModel, + hasModelAuthConfigured, + isVertexModeEnabled, loadMcpServers, loadValidSubAgents, type SandboxMode, type SandboxSettings, + VERTEX_API_KEY_PLACEHOLDER, } from "../utils/settings"; import { runSideQuestion, type SideQuestionResult } from "../utils/side-question"; import { discoverSkills, formatSkillsForPrompt } from "../utils/skills"; @@ -556,8 +560,9 @@ export class Agent { options: AgentOptions = {}, ) { this.baseURL = baseURL || null; - if (apiKey) { - this.setApiKey(apiKey, baseURL); + const authStatus = getModelAuthStatus(); + if (apiKey || (authStatus.activeMode === "vertex" && authStatus.configured)) { + this.setApiKey(apiKey || VERTEX_API_KEY_PLACEHOLDER, baseURL); } this.bash = new BashTool(process.cwd(), { sandboxMode: options.sandboxMode ?? "off", @@ -644,7 +649,7 @@ export class Agent { } hasApiKey(): boolean { - return !!this.apiKey; + return hasModelAuthConfigured() && !!this.provider; } setApiKey(apiKey: string, baseURL = this.baseURL ?? undefined): void { @@ -885,6 +890,12 @@ export class Agent { } private getBatchClientOptions(signal?: AbortSignal): BatchClientOptions { + if (isVertexModeEnabled()) { + throw new Error( + "xAI Batch API is not available when GROK_USE_VERTEX=1. Use normal streaming/headless mode with Vertex AI, or unset GROK_USE_VERTEX and configure GROK_API_KEY for native xAI Batch API.", + ); + } + if (!this.apiKey) { throw new Error("API key required. Add an API key to continue."); } @@ -2136,6 +2147,12 @@ export class Agent { private requireProvider(): XaiProvider { if (!this.provider) { + const authStatus = getModelAuthStatus(); + if (authStatus.activeMode === "vertex") { + throw new Error( + "Vertex AI authentication is not configured. Set GROK_VERTEX_PROJECT_ID, then run `gcloud auth application-default login`.", + ); + } throw new Error("API key required. Add an API key to continue."); } diff --git a/src/agent/auth.test.ts b/src/agent/auth.test.ts new file mode 100644 index 00000000..763bd6b0 --- /dev/null +++ b/src/agent/auth.test.ts @@ -0,0 +1,95 @@ +import { rmSync } from "node:fs"; +import os from "node:os"; +import path from "node:path"; +import { afterEach, describe, expect, it, vi } from "vitest"; + +const originalEnv = { + GROK_API_KEY: process.env.GROK_API_KEY, + GROK_USE_VERTEX: process.env.GROK_USE_VERTEX, + GROK_USER_SETTINGS_PATH: process.env.GROK_USER_SETTINGS_PATH, + GROK_VERTEX_PROJECT_ID: process.env.GROK_VERTEX_PROJECT_ID, + GCP_PROJECT_ID: process.env.GCP_PROJECT_ID, + GOOGLE_CLOUD_PROJECT: process.env.GOOGLE_CLOUD_PROJECT, + GCLOUD_PROJECT: process.env.GCLOUD_PROJECT, +}; + +const testUserSettingsPath = path.join(os.tmpdir(), `grok-agent-auth-${process.pid}.json`); + +function restoreEnv(): void { + rmSync(testUserSettingsPath, { force: true }); + for (const [key, value] of Object.entries(originalEnv)) { + if (value === undefined) { + delete process.env[key]; + } else { + process.env[key] = value; + } + } +} + +async function importAgentWithStorageMock() { + vi.resetModules(); + vi.doMock("../storage/index", () => ({ + appendCompaction: vi.fn(), + appendMessages: vi.fn(() => []), + appendSystemMessage: vi.fn(() => 0), + buildChatEntries: vi.fn(() => []), + getNextMessageSequence: vi.fn(() => 0), + getSessionTotalTokens: vi.fn(() => 0), + loadTranscript: vi.fn(() => []), + loadTranscriptState: vi.fn(() => ({ messages: [], seqs: [] })), + recordUsageEvent: vi.fn(), + SessionStore: class { + getWorkspace() { + return null; + } + openSession() { + return null; + } + createSession() { + return null; + } + setModel() {} + getRequiredSession() { + return null; + } + setMode() {} + touchSession() {} + }, + })); + + return import("./agent"); +} + +describe("Agent auth state", () => { + afterEach(() => { + restoreEnv(); + vi.restoreAllMocks(); + vi.resetModules(); + vi.doUnmock("../storage/index"); + }); + + it("does not treat incomplete Vertex environment as configured auth", async () => { + const { Agent } = await importAgentWithStorageMock(); + process.env.GROK_USER_SETTINGS_PATH = testUserSettingsPath; + process.env.GROK_USE_VERTEX = "1"; + delete process.env.GROK_VERTEX_PROJECT_ID; + delete process.env.GCP_PROJECT_ID; + delete process.env.GOOGLE_CLOUD_PROJECT; + delete process.env.GCLOUD_PROJECT; + + const agent = new Agent(undefined, undefined, undefined, undefined, { persistSession: false }); + + expect(agent.hasApiKey()).toBe(false); + }); + + it("treats complete Vertex environment as configured auth without an xAI key", async () => { + const { Agent } = await importAgentWithStorageMock(); + process.env.GROK_USER_SETTINGS_PATH = testUserSettingsPath; + process.env.GROK_USE_VERTEX = "1"; + process.env.GROK_VERTEX_PROJECT_ID = "project-1"; + + const agent = new Agent(undefined, undefined, undefined, undefined, { persistSession: false }); + + expect(agent.hasApiKey()).toBe(true); + }); +}); diff --git a/src/audio/stt/engine.ts b/src/audio/stt/engine.ts index 26ba9f81..107bfd5a 100644 --- a/src/audio/stt/engine.ts +++ b/src/audio/stt/engine.ts @@ -1,5 +1,5 @@ import type { TelegramSettings } from "../../utils/settings"; -import { getApiKey, getBaseURL, resolveTelegramAudioInputSettings } from "../../utils/settings"; +import { getApiKey, getBaseURL, isVertexModeEnabled, resolveTelegramAudioInputSettings } from "../../utils/settings"; import { GrokSttEngine, type GrokSttTranscriptionResult } from "./grok-stt"; export interface AudioTranscriptionInput { @@ -20,6 +20,12 @@ export function createTelegramAudioInputEngine( const resolved = resolveTelegramAudioInputSettings(telegramSettings); const apiKey = getApiKey(); if (!apiKey) { + if (isVertexModeEnabled()) { + throw new Error( + "Telegram audio transcription uses the native xAI STT endpoint and is not available with Vertex ADC alone. Disable Telegram audio input, or unset GROK_USE_VERTEX and configure GROK_API_KEY.", + ); + } + throw new Error( "Grok STT requires an API key. Set GROK_API_KEY or configure apiKey in ~/.grok/user-settings.json.", ); diff --git a/src/grok/client.test.ts b/src/grok/client.test.ts index c2d656e1..0ba70231 100644 --- a/src/grok/client.test.ts +++ b/src/grok/client.test.ts @@ -46,6 +46,24 @@ describe("client", () => { expect(runtime.modelId).toBe("grok-3"); expect(runtime.providerOptions).toBeUndefined(); }); + + it("keeps chat model tool support in Vertex mode", () => { + const original = process.env.GROK_USE_VERTEX; + process.env.GROK_USE_VERTEX = "1"; + + try { + const runtime = resolveModelRuntime(mockProvider, "grok-4-1-fast-reasoning"); + + expect(runtime.modelId).toBe("grok-4-1-fast-reasoning"); + expect(runtime.modelInfo?.supportsClientTools).not.toBe(false); + } finally { + if (original === undefined) { + delete process.env.GROK_USE_VERTEX; + } else { + process.env.GROK_USE_VERTEX = original; + } + } + }); }); describe("with configured reasoning effort", () => { diff --git a/src/grok/client.ts b/src/grok/client.ts index abbbcfe5..ab8d0f82 100644 --- a/src/grok/client.ts +++ b/src/grok/client.ts @@ -1,8 +1,9 @@ import { createXai } from "@ai-sdk/xai"; import { generateText } from "ai"; import type { ModelInfo, ReasoningEffort } from "../types/index"; -import { getReasoningEffortForModel } from "../utils/settings"; +import { getReasoningEffortForModel, isVertexModeEnabled, VERTEX_API_KEY_PLACEHOLDER } from "../utils/settings"; import { getEffectiveReasoningEffort, getModelInfo, normalizeModelId } from "./models"; +import { createVertexFetch } from "./vertex-adapter"; export type XaiProvider = ReturnType; export type XaiChatModel = ReturnType; @@ -33,6 +34,14 @@ export interface ResolvedModelRuntime { } export function createProvider(apiKey: string, baseURL?: string): XaiProvider { + if (isVertexModeEnabled()) { + return createXai({ + apiKey: apiKey || VERTEX_API_KEY_PLACEHOLDER, + baseURL: "https://api.x.ai/v1", + fetch: createVertexFetch(), + }); + } + return createXai({ apiKey, baseURL: baseURL || process.env.GROK_BASE_URL || "https://api.x.ai/v1", @@ -45,7 +54,7 @@ export function resolveModelRuntime(provider: XaiProvider, requestedModelId: str const reasoningEffort = getEffectiveReasoningEffort(modelId, getReasoningEffortForModel(modelId)); return { - model: modelInfo?.responsesOnly ? provider.responses(modelId) : provider(modelId), + model: !isVertexModeEnabled() && modelInfo?.responsesOnly ? provider.responses(modelId) : provider(modelId), modelId, modelInfo, providerOptions: reasoningEffort diff --git a/src/grok/tools.ts b/src/grok/tools.ts index bb014bb2..d0948254 100644 --- a/src/grok/tools.ts +++ b/src/grok/tools.ts @@ -22,7 +22,12 @@ import { editFile, readFile, writeFile } from "../tools/file"; import { executeGrep } from "../tools/grep"; import type { ScheduleDaemonStatus, ScheduleManager, StoredSchedule } from "../tools/schedule"; import type { AgentMode, TaskRequest, ToolResult } from "../types/index"; -import { type CustomSubagentConfig, loadPaymentSettings, loadValidSubAgents } from "../utils/settings"; +import { + type CustomSubagentConfig, + isVertexModeEnabled, + loadPaymentSettings, + loadValidSubAgents, +} from "../utils/settings"; import type { XaiProvider } from "./client"; import { type GenerateImageToolInput, @@ -61,6 +66,10 @@ export function createTools( toolName: "web_search" | "x_search", abortSignal?: AbortSignal, ): Promise<{ success: boolean; output: string }> => { + if (isVertexModeEnabled()) { + return vertexUnsupportedTool(toolName === "web_search" ? "Web search" : "X search"); + } + try { const { text } = await generateText({ model: provider.responses(RESPONSES_SEARCH_MODEL), @@ -242,6 +251,9 @@ export function createTools( .describe("Optional file path for the generated image. For multiple images, numbered suffixes are added."), }), execute: async (input: GenerateImageToolInput, { abortSignal }) => { + if (isVertexModeEnabled()) { + return vertexUnsupportedTool("Image generation"); + } return generateImageTool(provider, input, cwd(), abortSignal); }, }), @@ -279,6 +291,9 @@ export function createTools( .describe("Optional timeout in milliseconds while waiting for video generation"), }), execute: async (input: GenerateVideoToolInput, { abortSignal }) => { + if (isVertexModeEnabled()) { + return vertexUnsupportedTool("Video generation"); + } return generateVideoTool(provider, input, cwd(), abortSignal); }, }), @@ -972,6 +987,13 @@ export function createTools( return tools; } +function vertexUnsupportedTool(label: string): { success: false; output: string } { + return { + success: false, + output: `${label} is not available through Vertex AI Grok chat completions. Unset GROK_USE_VERTEX and configure GROK_API_KEY to use native xAI-only endpoints.`, + }; +} + function formatScheduleList(schedules: StoredSchedule[], daemonStatus: ScheduleDaemonStatus): string { const lines = [ `Daemon: ${daemonStatus.running ? `running${daemonStatus.pid ? ` (pid ${daemonStatus.pid})` : ""}` : "not running"}`, diff --git a/src/grok/vertex-adapter.test.ts b/src/grok/vertex-adapter.test.ts new file mode 100644 index 00000000..be711746 --- /dev/null +++ b/src/grok/vertex-adapter.test.ts @@ -0,0 +1,528 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; +import { + buildVertexModelUrl, + convertMessagesToVertexContents, + convertVertexGenerateResponseToOpenAI, + convertVertexStreamResponseToOpenAIChunks, + convertXaiChatRequestToVertex, + createVertexFetch, + createVertexSseStream, + getVertexModelId, + sanitizeVertexSchema, +} from "./vertex-adapter"; +import { getVertexAccessToken } from "./vertex-auth"; + +const getVertexAccessTokenMock = vi.mocked(getVertexAccessToken); + +vi.mock("./vertex-auth", () => ({ + getVertexAccessToken: vi.fn(async () => "adc-token"), +})); + +const originalEnv = { + GROK_VERTEX_PROJECT_ID: process.env.GROK_VERTEX_PROJECT_ID, + GROK_VERTEX_LOCATION: process.env.GROK_VERTEX_LOCATION, + GROK_VERTEX_BASE_URL: process.env.GROK_VERTEX_BASE_URL, + GROK_VERTEX_DISABLE_TOOLS: process.env.GROK_VERTEX_DISABLE_TOOLS, + GCP_PROJECT_ID: process.env.GCP_PROJECT_ID, + GCP_REGION: process.env.GCP_REGION, + GCP_VERTEX_LOCATION: process.env.GCP_VERTEX_LOCATION, + GCP_VERTEX_BASE_URL: process.env.GCP_VERTEX_BASE_URL, +}; + +function restoreVertexEnv(): void { + for (const [key, value] of Object.entries(originalEnv)) { + if (value === undefined) { + delete process.env[key]; + } else { + process.env[key] = value; + } + } +} + +describe("Vertex Grok adapter", () => { + afterEach(() => { + restoreVertexEnv(); + vi.clearAllMocks(); + getVertexAccessTokenMock.mockResolvedValue("adc-token"); + }); + + it("uses the global Vertex host with a configurable location path", () => { + expect( + buildVertexModelUrl( + { + projectId: "project-1", + location: "europe-west1", + baseURL: "https://aiplatform.googleapis.com", + }, + "grok-4-1-fast-reasoning", + false, + ), + ).toBe( + "https://aiplatform.googleapis.com/v1/projects/project-1/locations/europe-west1/publishers/xai/models/grok-4.1-fast-reasoning:generateContent", + ); + }); + + it("requests SSE output from Vertex streaming endpoints", () => { + expect( + buildVertexModelUrl( + { + projectId: "project-1", + location: "europe-west1", + baseURL: "https://aiplatform.googleapis.com/", + }, + "grok-4-1-fast-reasoning", + true, + ), + ).toBe( + "https://aiplatform.googleapis.com/v1/projects/project-1/locations/europe-west1/publishers/xai/models/grok-4.1-fast-reasoning:streamGenerateContent?alt=sse", + ); + }); + + it("maps native xAI model IDs to Vertex xAI publisher IDs", () => { + expect(getVertexModelId("grok-4-1-fast-reasoning")).toBe("grok-4.1-fast-reasoning"); + expect(getVertexModelId("grok-4-1-fast-non-reasoning")).toBe("grok-4.1-fast-non-reasoning"); + expect(getVertexModelId("grok-4.20-0309-reasoning")).toBe("grok-4.20-reasoning"); + expect(getVertexModelId("custom-model")).toBe("custom-model"); + }); + + it("maps OpenAI-style chat messages to Vertex contents", () => { + const contents = convertMessagesToVertexContents([ + { role: "system", content: "Follow policy." }, + { role: "user", content: "Hello" }, + { + role: "assistant", + content: "I can help.", + tool_calls: [ + { + id: "call-1", + type: "function", + function: { name: "lookup", arguments: '{"query":"docs"}' }, + }, + ], + }, + { role: "tool", tool_call_id: "call-1", content: '{"ok":true}' }, + ]); + + expect(contents).toEqual([ + { role: "user", parts: [{ text: "Hello" }] }, + { + role: "model", + parts: [{ text: "I can help." }, { functionCall: { name: "lookup", args: { query: "docs" } } }], + }, + { role: "user", parts: [{ functionResponse: { name: "lookup", response: { ok: true } } }] }, + ]); + }); + + it("sends system prompts through Vertex systemInstruction", () => { + const request = convertXaiChatRequestToVertex({ + model: "grok-4-1-fast-reasoning", + messages: [ + { role: "system", content: "Follow policy." }, + { role: "user", content: "Hello" }, + ], + }); + + expect(request).toMatchObject({ + contents: [{ role: "user", parts: [{ text: "Hello" }] }], + systemInstruction: { parts: [{ text: "Follow policy." }] }, + }); + }); + + it("maps and sanitizes function declarations by default", () => { + const request = convertXaiChatRequestToVertex({ + model: "grok-4-1-fast-reasoning", + messages: [{ role: "user", content: "Search docs" }], + tools: [ + { + type: "function", + function: { + name: "search", + description: "Search docs", + parameters: { + type: "object", + additionalProperties: false, + properties: { + query: { type: "string", minLength: 1 }, + limit: { anyOf: [{ type: "integer" }, { type: "null" }], description: "Result count" }, + }, + required: ["query"], + }, + }, + }, + ], + tool_choice: { type: "function", function: { name: "search" } }, + }); + + expect(request.tools).toEqual([ + { + functionDeclarations: [ + { + name: "search", + description: "Search docs", + parameters: { + type: "OBJECT", + properties: { + query: { type: "STRING" }, + limit: { type: "INTEGER", description: "Result count", nullable: true }, + }, + required: ["query"], + }, + }, + ], + }, + ]); + expect(request.toolConfig).toEqual({ + functionCallingConfig: { mode: "ANY", allowedFunctionNames: ["search"] }, + }); + }); + + it("can omit function declarations with the emergency Vertex tool disable flag", () => { + process.env.GROK_VERTEX_DISABLE_TOOLS = "1"; + const request = convertXaiChatRequestToVertex({ + model: "grok-4-1-fast-reasoning", + messages: [{ role: "user", content: "Search docs" }], + max_completion_tokens: 512, + temperature: 0.2, + top_p: 0.9, + tools: [ + { + type: "function", + function: { + name: "search", + description: "Search docs", + parameters: { type: "object", properties: { query: { type: "string" } } }, + }, + }, + ], + tool_choice: { type: "function", function: { name: "search" } }, + }); + + expect(request.generationConfig).toEqual({ maxOutputTokens: 512, temperature: 0.2, topP: 0.9 }); + expect(request.tools).toBeUndefined(); + expect(request.toolConfig).toBeUndefined(); + }); + + it("drops invalid function names instead of sending declarations Vertex will reject", () => { + const request = convertXaiChatRequestToVertex({ + model: "grok-4-1-fast-reasoning", + messages: [{ role: "user", content: "Search docs" }], + tools: [ + { + type: "function", + function: { + name: "not allowed", + description: "Invalid Vertex function name", + parameters: { type: "object" }, + }, + }, + ], + tool_choice: { type: "function", function: { name: "not allowed" } }, + }); + + expect(request.tools).toBeUndefined(); + expect(request.toolConfig).toBeUndefined(); + }); + + it("converts JSON schema to Vertex's function schema subset", () => { + expect( + sanitizeVertexSchema({ + type: "object", + $schema: "https://json-schema.org/draft/2020-12/schema", + additionalProperties: false, + properties: { + path: { type: "string", title: "Path" }, + count: { type: ["integer", "null"], default: 10, nullable: true }, + mode: { enum: ["read", "write"] }, + tags: { type: "array", items: { type: "string", minLength: 1 } }, + }, + required: ["path"], + }), + ).toEqual({ + type: "OBJECT", + properties: { + path: { type: "STRING" }, + count: { type: "INTEGER", nullable: true }, + mode: { type: "STRING", enum: ["read", "write"] }, + tags: { type: "ARRAY", items: { type: "STRING" } }, + }, + required: ["path"], + }); + }); + + it("preserves nullable unions when sanitizing function schemas", () => { + expect( + sanitizeVertexSchema({ + type: "object", + properties: { + maybeText: { oneOf: [{ type: "null" }, { type: "string", description: "Optional text" }] }, + maybeCount: { anyOf: [{ type: "integer" }, { type: "null" }] }, + }, + }), + ).toEqual({ + type: "OBJECT", + properties: { + maybeText: { type: "STRING", description: "Optional text", nullable: true }, + maybeCount: { type: "INTEGER", nullable: true }, + }, + }); + }); + + it("maps Vertex generateContent responses back to OpenAI chat completions", () => { + const converted = convertVertexGenerateResponseToOpenAI( + { + candidates: [ + { + index: 0, + finishReason: "STOP", + content: { + parts: [{ text: "done" }, { functionCall: { name: "save", args: { path: "file.txt" } } }], + }, + }, + ], + usageMetadata: { promptTokenCount: 3, candidatesTokenCount: 4, totalTokenCount: 7 }, + }, + { id: "chatcmpl-test", model: "grok-4-1-fast-reasoning", created: 123 }, + ); + + expect(converted).toMatchObject({ + id: "chatcmpl-test", + object: "chat.completion", + choices: [ + { + index: 0, + message: { + role: "assistant", + content: "done", + tool_calls: [ + { + type: "function", + function: { name: "save", arguments: '{"path":"file.txt"}' }, + }, + ], + }, + finish_reason: "tool_calls", + }, + ], + usage: { prompt_tokens: 3, completion_tokens: 4, total_tokens: 7 }, + }); + }); + + it("wraps Vertex JSON stream chunks as OpenAI SSE events", async () => { + const encoder = new TextEncoder(); + const vertexBody = new ReadableStream({ + start(controller) { + controller.enqueue( + encoder.encode( + '[{"candidates":[{"content":{"parts":[{"text":"hel"}]}}]},{"candidates":[{"content":{"parts":[{"text":"lo"}]},"finishReason":"STOP"}]}]', + ), + ); + controller.close(); + }, + }); + + const text = await new Response( + createVertexSseStream(vertexBody, { id: "chatcmpl-stream", model: "grok-4-1-fast-reasoning", created: 123 }), + ).text(); + + expect(text).toContain('"object":"chat.completion.chunk"'); + expect(text).toContain('"content":"hel"'); + expect(text).toContain('"content":"lo"'); + expect(text).toContain('"finish_reason":"stop"'); + expect(text.trim().endsWith("data: [DONE]")).toBe(true); + }); + + it("forwards Vertex streaming error chunks as OpenAI SSE error events", async () => { + const encoder = new TextEncoder(); + const vertexBody = new ReadableStream({ + start(controller) { + controller.enqueue( + encoder.encode('{"error":{"message":"quota exceeded","status":"RESOURCE_EXHAUSTED","code":8}}'), + ); + controller.close(); + }, + }); + + const text = await new Response( + createVertexSseStream(vertexBody, { id: "chatcmpl-stream", model: "grok-4-1-fast-reasoning", created: 123 }), + ).text(); + + expect(text).toContain('"error":{"message":"quota exceeded","status":"RESOURCE_EXHAUSTED","code":8}'); + expect(text.trim().endsWith("data: [DONE]")).toBe(true); + }); + + it("includes OpenAI stream indexes for Vertex function-call chunks", () => { + const chunks = convertVertexStreamResponseToOpenAIChunks( + { + candidates: [ + { + index: 0, + content: { + parts: [{ functionCall: { name: "read_file", args: { path: "README.md" } } }], + }, + }, + ], + }, + { id: "chatcmpl-tool", model: "grok-4-1-fast-reasoning", created: 123 }, + ); + + expect(chunks).toMatchObject([ + { + choices: [ + { + delta: { + tool_calls: [ + { + index: 0, + id: "call_chatcmpl-tool_0_0", + type: "function", + function: { name: "read_file", arguments: '{"path":"README.md"}' }, + }, + ], + }, + }, + ], + }, + { + choices: [{ finish_reason: "tool_calls" }], + }, + ]); + }); + + it("numbers stream function calls by function-call order, not part position", () => { + const chunks = convertVertexStreamResponseToOpenAIChunks( + { + candidates: [ + { + index: 0, + content: { + parts: [ + { text: "I will use tools." }, + { functionCall: { name: "read_file", args: { path: "README.md" } } }, + { text: " Then another." }, + { functionCall: { name: "grep", args: { pattern: "Vertex" } } }, + ], + }, + }, + ], + }, + { id: "chatcmpl-tool-mixed", model: "grok-4-1-fast-reasoning", created: 123 }, + ); + + const toolChunk = chunks.find((chunk) => { + const record = chunk as { choices?: Array<{ delta?: { tool_calls?: unknown[] } }> }; + return Boolean(record.choices?.[0]?.delta?.tool_calls); + }) as { choices: Array<{ delta: { tool_calls: Array<{ index: number; id: string }> } }> }; + + expect(toolChunk.choices[0].delta.tool_calls).toMatchObject([ + { index: 0, id: "call_chatcmpl-tool-mixed_0_0" }, + { index: 1, id: "call_chatcmpl-tool-mixed_0_1" }, + ]); + }); + + it("fetches Vertex with ADC bearer auth and returns translated chat JSON", async () => { + process.env.GROK_VERTEX_PROJECT_ID = "project-1"; + process.env.GROK_VERTEX_LOCATION = "europe-west1"; + + const baseFetch = vi.fn(async (url, init) => { + expect(String(url)).toBe( + "https://aiplatform.googleapis.com/v1/projects/project-1/locations/europe-west1/publishers/xai/models/grok-4.1-fast-reasoning:generateContent", + ); + expect((init?.headers as Record).Authorization).toBe("Bearer adc-token"); + expect(JSON.parse(String(init?.body))).toMatchObject({ + contents: [{ role: "user", parts: [{ text: "Hi" }] }], + }); + + return new Response( + JSON.stringify({ + candidates: [{ content: { parts: [{ text: "Hello" }] }, finishReason: "STOP" }], + }), + { status: 200, headers: { "Content-Type": "application/json" } }, + ); + }); + + const response = await createVertexFetch(baseFetch)("https://api.x.ai/v1/chat/completions", { + method: "POST", + body: JSON.stringify({ + model: "grok-4-1-fast-reasoning", + messages: [{ role: "user", content: "Hi" }], + }), + }); + + expect(response.status).toBe(200); + await expect(response.json()).resolves.toMatchObject({ + object: "chat.completion", + choices: [{ message: { role: "assistant", content: "Hello" }, finish_reason: "stop" }], + }); + }); + + it("returns an actionable Vertex auth response when ADC token refresh fails", async () => { + process.env.GROK_VERTEX_PROJECT_ID = "project-1"; + getVertexAccessTokenMock.mockRejectedValueOnce( + new Error( + "Google Application Default Credentials need reauthentication.\n\nRun `gcloud auth application-default login`.", + ), + ); + const baseFetch = vi.fn(); + + const response = await createVertexFetch(baseFetch)("https://api.x.ai/v1/chat/completions", { + method: "POST", + body: JSON.stringify({ + model: "grok-4-1-fast-reasoning", + messages: [{ role: "user", content: "Hi" }], + }), + }); + + expect(response.status).toBe(401); + expect(baseFetch).not.toHaveBeenCalled(); + await expect(response.json()).resolves.toMatchObject({ + error: { + message: expect.stringContaining("Google Application Default Credentials need reauthentication."), + code: "vertex_auth_failed", + }, + }); + }); + + it("returns structured errors for unsupported xAI request shapes", async () => { + process.env.GROK_VERTEX_PROJECT_ID = "project-1"; + const baseFetch = vi.fn(); + + const response = await createVertexFetch(baseFetch)("https://api.x.ai/v1/chat/completions", { + method: "POST", + body: JSON.stringify({ + model: "grok-4-1-fast-reasoning", + messages: [{ role: "user", content: [{ type: "image_url", image_url: { url: "https://example.com/a.png" } }] }], + }), + }); + + expect(response.status).toBe(400); + expect(baseFetch).not.toHaveBeenCalled(); + await expect(response.json()).resolves.toMatchObject({ + error: { + message: expect.stringContaining("image_url message parts are not supported"), + code: "vertex_request_invalid", + }, + }); + }); + + it("returns structured errors for empty Vertex conversations", async () => { + process.env.GROK_VERTEX_PROJECT_ID = "project-1"; + const baseFetch = vi.fn(); + + const response = await createVertexFetch(baseFetch)("https://api.x.ai/v1/chat/completions", { + method: "POST", + body: JSON.stringify({ + model: "grok-4-1-fast-reasoning", + messages: [], + }), + }); + + expect(response.status).toBe(400); + expect(baseFetch).not.toHaveBeenCalled(); + await expect(response.json()).resolves.toMatchObject({ + error: { + message: "Cannot send an empty conversation to Vertex AI.", + code: "vertex_request_invalid", + }, + }); + }); +}); diff --git a/src/grok/vertex-adapter.ts b/src/grok/vertex-adapter.ts new file mode 100644 index 00000000..10cfd340 --- /dev/null +++ b/src/grok/vertex-adapter.ts @@ -0,0 +1,967 @@ +import { isTruthyEnv, requireVertexSettings, type VertexSettings } from "../utils/settings"; +import { getVertexAccessToken } from "./vertex-auth"; + +type JsonRecord = Record; + +interface XaiToolCall { + id: string; + type: "function"; + function: { + name: string; + arguments: string; + }; +} + +interface XaiMessage { + role: "system" | "user" | "assistant" | "tool"; + content?: unknown; + tool_calls?: XaiToolCall[]; + tool_call_id?: string; +} + +interface XaiChatRequest { + model: string; + messages?: XaiMessage[]; + stream?: boolean; + temperature?: number; + top_p?: number; + seed?: number; + max_completion_tokens?: number; + tools?: Array<{ + type: "function"; + function: { + name: string; + description?: string; + parameters?: unknown; + }; + }>; + tool_choice?: + | "auto" + | "none" + | "required" + | { + type: "function"; + function: { + name: string; + }; + }; +} + +interface VertexPart { + text?: string; + functionCall?: { + name: string; + args?: JsonRecord; + }; + functionResponse?: { + name: string; + response: JsonRecord; + }; +} + +interface VertexContent { + role: "user" | "model"; + parts: VertexPart[]; +} + +interface VertexRequest { + contents: VertexContent[]; + systemInstruction?: { + parts: VertexPart[]; + }; + generationConfig?: JsonRecord; + tools?: Array<{ + functionDeclarations: Array<{ + name: string; + description?: string; + parameters?: unknown; + }>; + }>; + toolConfig?: JsonRecord; +} + +interface VertexCandidate { + index?: number; + content?: { + role?: string; + parts?: VertexPart[]; + }; + finishReason?: string; +} + +interface VertexResponse { + candidates?: VertexCandidate[]; + usageMetadata?: { + promptTokenCount?: number; + candidatesTokenCount?: number; + totalTokenCount?: number; + }; + error?: { + message?: string; + status?: string; + code?: number; + }; +} + +interface OpenAIContext { + id: string; + model: string; + created: number; +} + +const VERTEX_MODEL_IDS: Record = { + "grok-4-1-fast-reasoning": "grok-4.1-fast-reasoning", + "grok-4-1-fast-non-reasoning": "grok-4.1-fast-non-reasoning", + "grok-4.20-0309-reasoning": "grok-4.20-reasoning", + "grok-4.20-0309-non-reasoning": "grok-4.20-non-reasoning", +}; + +const VERTEX_FUNCTION_NAME_PATTERN = /^[A-Za-z_][A-Za-z0-9_.-]{0,63}$/; +const VERTEX_SCHEMA_TYPES = new Set(["STRING", "INTEGER", "BOOLEAN", "NUMBER", "ARRAY", "OBJECT"]); +const UNSUPPORTED_SCHEMA_KEYS = new Set([ + "$defs", + "$id", + "$schema", + "additionalItems", + "additionalProperties", + "allOf", + "anyOf", + "const", + "contains", + "default", + "definitions", + "dependencies", + "dependentRequired", + "dependentSchemas", + "else", + "examples", + "exclusiveMaximum", + "exclusiveMinimum", + "format", + "if", + "maxItems", + "maxLength", + "maximum", + "minItems", + "minLength", + "minimum", + "multipleOf", + "not", + "oneOf", + "pattern", + "patternProperties", + "prefixItems", + "propertyNames", + "readOnly", + "strict", + "then", + "title", + "unevaluatedProperties", + "uniqueItems", + "writeOnly", +]); + +export function createVertexFetch(baseFetch: typeof fetch = globalThis.fetch): typeof fetch { + return async (input, init) => { + const url = getRequestUrl(input); + + if (!url.pathname.endsWith("/chat/completions")) { + return unsupportedVertexEndpointResponse(url); + } + + let xaiRequest: XaiChatRequest; + let vertexSettings: VertexSettings; + let vertexRequest: VertexRequest; + try { + xaiRequest = (await readJsonRequest(input, init)) as XaiChatRequest; + vertexSettings = requireVertexSettings(); + vertexRequest = convertXaiChatRequestToVertex(xaiRequest); + } catch (err: unknown) { + const message = err instanceof Error ? err.message : String(err); + return vertexErrorResponse(message, 400, "vertex_request_invalid"); + } + + const isStreaming = xaiRequest.stream === true; + let accessToken: string; + try { + accessToken = await getVertexAccessToken(); + } catch (err: unknown) { + const message = err instanceof Error ? err.message : String(err); + return vertexErrorResponse(message, 401, "vertex_auth_failed"); + } + const vertexUrl = buildVertexModelUrl(vertexSettings, xaiRequest.model, isStreaming); + const response = await baseFetch(vertexUrl, { + method: "POST", + headers: { + Authorization: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(vertexRequest), + signal: init?.signal ?? (input instanceof Request ? input.signal : undefined), + }); + + if (!response.ok) { + return translateVertexError(response, vertexUrl); + } + + const context = createOpenAIContext(xaiRequest.model); + if (isStreaming) { + if (!response.body) { + return vertexErrorResponse("Vertex AI returned a streaming response without a readable body.", 502); + } + return new Response(createVertexSseStream(response.body, context), { + status: 200, + headers: { + "Content-Type": "text/event-stream; charset=utf-8", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }, + }); + } + + const payload = await response.json(); + return new Response(JSON.stringify(convertVertexGenerateResponseToOpenAI(payload, context)), { + status: 200, + headers: { + "Content-Type": "application/json", + }, + }); + }; +} + +export function buildVertexModelUrl(settings: VertexSettings, modelId: string, isStreaming: boolean): string { + const method = isStreaming ? "streamGenerateContent" : "generateContent"; + const baseURL = settings.baseURL.replace(/\/+$/, ""); + const vertexModelId = getVertexModelId(modelId); + const url = `${baseURL}/v1/projects/${encodeURIComponent(settings.projectId)}/locations/${encodeURIComponent( + settings.location, + )}/publishers/xai/models/${encodeURIComponent(vertexModelId)}:${method}`; + return isStreaming ? `${url}?alt=sse` : url; +} + +export function getVertexModelId(modelId: string): string { + return VERTEX_MODEL_IDS[modelId] ?? modelId; +} + +export function convertXaiChatRequestToVertex(request: XaiChatRequest): VertexRequest { + const conversation = convertMessagesToVertexConversation(request.messages ?? []); + const generationConfig = removeUndefined({ + maxOutputTokens: request.max_completion_tokens, + temperature: request.temperature, + topP: request.top_p, + seed: request.seed, + }); + const functionDeclarations = shouldForwardVertexTools() + ? (request.tools ?? []) + .filter((tool) => tool.type === "function" && isValidVertexFunctionName(tool.function.name)) + .map((tool) => + removeUndefined({ + name: tool.function.name, + description: tool.function.description, + parameters: sanitizeVertexSchema(tool.function.parameters), + }), + ) + : []; + + return removeUndefined({ + contents: conversation.contents, + systemInstruction: conversation.systemInstruction, + generationConfig: Object.keys(generationConfig).length > 0 ? generationConfig : undefined, + tools: functionDeclarations.length > 0 ? [{ functionDeclarations }] : undefined, + toolConfig: + functionDeclarations.length > 0 + ? convertToolChoiceToVertexToolConfig( + request.tool_choice, + new Set(functionDeclarations.map((declaration) => declaration.name)), + ) + : undefined, + }) as VertexRequest; +} + +function shouldForwardVertexTools(): boolean { + return !isTruthyEnv(process.env.GROK_VERTEX_DISABLE_TOOLS); +} + +function isValidVertexFunctionName(name: string): boolean { + return VERTEX_FUNCTION_NAME_PATTERN.test(name); +} + +export function sanitizeVertexSchema(schema: unknown): unknown { + const normalized = sanitizeVertexSchemaValue(schema); + if (!isRecord(normalized)) { + return { type: "OBJECT", properties: {} }; + } + if (!normalized.type) { + if (isRecord(normalized.properties)) { + normalized.type = "OBJECT"; + } else if (normalized.items !== undefined) { + normalized.type = "ARRAY"; + } else { + normalized.type = "OBJECT"; + } + } + if (normalized.type === "OBJECT" && !isRecord(normalized.properties)) { + normalized.properties = {}; + } + return normalized; +} + +function sanitizeVertexSchemaValue(schema: unknown): unknown { + if (Array.isArray(schema)) { + return schema.map((item) => sanitizeVertexSchemaValue(item)); + } + + if (!isRecord(schema)) { + return undefined; + } + + const unionSchema = pickUnionSchema(schema); + if (unionSchema && unionSchema !== schema) { + const unionResult = sanitizeVertexSchemaValue(unionSchema); + if (isRecord(unionResult)) { + return copySchemaMetadata(schema, unionResult); + } + } + + const result: JsonRecord = {}; + const type = normalizeVertexSchemaType(schema.type); + if (type) { + result.type = type; + } + if (schemaAllowsNull(schema.type)) { + result.nullable = true; + } + + for (const [key, value] of Object.entries(schema)) { + if (key === "type" || UNSUPPORTED_SCHEMA_KEYS.has(key)) continue; + + switch (key) { + case "description": + if (typeof value === "string" && value.trim()) { + result.description = value; + } + break; + case "nullable": + if (typeof value === "boolean") { + result.nullable = value; + } + break; + case "enum": { + const enumValues = Array.isArray(value) + ? value.filter((entry): entry is string => typeof entry === "string") + : []; + if (enumValues.length > 0) { + result.enum = enumValues; + } + break; + } + case "required": { + const required = Array.isArray(value) + ? value.filter((entry): entry is string => typeof entry === "string") + : []; + if (required.length > 0) { + result.required = required; + } + break; + } + case "properties": { + if (!isRecord(value)) break; + const properties = Object.fromEntries( + Object.entries(value) + .map(([propertyName, propertySchema]) => [propertyName, sanitizeVertexSchemaValue(propertySchema)]) + .filter((entry): entry is [string, unknown] => entry[1] !== undefined), + ); + result.properties = properties; + if (!result.type) { + result.type = "OBJECT"; + } + break; + } + case "items": { + if (Array.isArray(value)) { + const firstItem = value.find((item) => item !== undefined); + const itemSchema = sanitizeVertexSchemaValue(firstItem); + if (itemSchema !== undefined) { + result.items = itemSchema; + } + } else { + const itemSchema = sanitizeVertexSchemaValue(value); + if (itemSchema !== undefined) { + result.items = itemSchema; + } + } + if (!result.type) { + result.type = "ARRAY"; + } + break; + } + } + } + + if (!result.type && result.enum) { + result.type = "STRING"; + } + if (result.type === "OBJECT" && !isRecord(result.properties)) { + result.properties = {}; + } + if (Array.isArray(result.required) && isRecord(result.properties)) { + const propertyNames = new Set(Object.keys(result.properties)); + const required = result.required.filter( + (entry): entry is string => typeof entry === "string" && propertyNames.has(entry), + ); + if (required.length > 0) { + result.required = required; + } else { + delete result.required; + } + } + if (result.type === "ARRAY" && result.items === undefined) { + result.items = { type: "STRING" }; + } + + return Object.keys(result).length > 0 ? result : undefined; +} + +function pickUnionSchema(schema: JsonRecord): unknown { + for (const key of ["anyOf", "oneOf", "allOf"] as const) { + const variants = schema[key]; + if (!Array.isArray(variants)) continue; + const nonNull = variants.find((variant) => !isNullSchema(variant)); + if (nonNull !== undefined) { + const picked = isRecord(nonNull) ? { ...nonNull } : nonNull; + if (isRecord(picked) && variants.some((variant) => isNullSchema(variant))) { + picked.nullable = true; + } + return picked; + } + } + return undefined; +} + +function copySchemaMetadata(source: JsonRecord, target: JsonRecord): JsonRecord { + if (typeof source.description === "string" && !target.description) { + target.description = source.description; + } + if (source.nullable === true && target.nullable === undefined) { + target.nullable = true; + } + return target; +} + +function normalizeVertexSchemaType(value: unknown): string | undefined { + if (typeof value === "string") { + const normalized = value.toUpperCase(); + return VERTEX_SCHEMA_TYPES.has(normalized) ? normalized : undefined; + } + if (Array.isArray(value)) { + const first = value.find((entry) => typeof entry === "string" && entry.toUpperCase() !== "NULL"); + return normalizeVertexSchemaType(first); + } + return undefined; +} + +function schemaAllowsNull(value: unknown): boolean { + return Array.isArray(value) && value.some((entry) => typeof entry === "string" && entry.toUpperCase() === "NULL"); +} + +function isNullSchema(value: unknown): boolean { + return isRecord(value) && typeof value.type === "string" && value.type.toUpperCase() === "NULL"; +} + +export function convertMessagesToVertexContents(messages: XaiMessage[]): VertexContent[] { + return convertMessagesToVertexConversation(messages).contents; +} + +function convertMessagesToVertexConversation(messages: XaiMessage[]): { + contents: VertexContent[]; + systemInstruction?: { parts: VertexPart[] }; +} { + const contents: VertexContent[] = []; + const toolNamesById = new Map(); + const systemParts = messages + .filter((message) => message.role === "system") + .flatMap((message) => textPartsFromContent(message.content)); + + const append = (role: VertexContent["role"], parts: VertexPart[]) => { + const cleanParts = parts.filter((part) => hasVertexPartValue(part)); + if (cleanParts.length === 0) return; + + const last = contents[contents.length - 1]; + if (last?.role === role) { + last.parts.push(...cleanParts); + return; + } + + contents.push({ role, parts: cleanParts }); + }; + + for (const message of messages) { + switch (message.role) { + case "system": + break; + case "user": + append("user", textPartsFromContent(message.content)); + break; + case "assistant": { + const parts = textPartsFromContent(message.content); + for (const toolCall of message.tool_calls ?? []) { + toolNamesById.set(toolCall.id, toolCall.function.name); + parts.push({ + functionCall: { + name: toolCall.function.name, + args: parseJsonObject(toolCall.function.arguments), + }, + }); + } + append("model", parts); + break; + } + case "tool": { + const toolName = (message.tool_call_id ? toolNamesById.get(message.tool_call_id) : undefined) ?? "tool_result"; + append("user", [ + { + functionResponse: { + name: toolName, + response: responseObjectFromToolContent(message.content), + }, + }, + ]); + break; + } + } + } + + if (contents.length === 0) { + if (systemParts.length === 0) { + throw new Error("Cannot send an empty conversation to Vertex AI."); + } + contents.push({ role: "user", parts: [{ text: "Continue." }] }); + } + + if (contents[0]?.role === "model") { + contents.unshift({ role: "user", parts: [{ text: "Continue." }] }); + } + + return removeUndefined({ + contents, + systemInstruction: systemParts.length > 0 ? { parts: systemParts } : undefined, + }) as { + contents: VertexContent[]; + systemInstruction?: { parts: VertexPart[] }; + }; +} + +export function convertVertexGenerateResponseToOpenAI(payload: unknown, context: OpenAIContext) { + const response = normalizeVertexResponse(payload); + if (response.error?.message) { + return { + error: response.error, + }; + } + + const candidates = response.candidates?.length ? response.candidates : [{}]; + return { + id: context.id, + object: "chat.completion", + created: context.created, + model: context.model, + choices: candidates.map((candidate, index) => { + const toolCalls = extractFunctionCalls(candidate, context.id, index, false); + const content = extractTextFromVertexCandidate(candidate); + return { + index: candidate.index ?? index, + message: { + role: "assistant", + content: content || null, + ...(toolCalls.length > 0 ? { tool_calls: toolCalls } : {}), + }, + finish_reason: toolCalls.length > 0 ? "tool_calls" : mapVertexFinishReason(candidate.finishReason), + }; + }), + usage: convertVertexUsage(response.usageMetadata), + }; +} + +export function createVertexSseStream( + vertexBody: ReadableStream, + context: OpenAIContext, +): ReadableStream { + const encoder = new TextEncoder(); + const decoder = new TextDecoder(); + let currentObject = ""; + let depth = 0; + let inString = false; + let escaped = false; + + const enqueueSse = (controller: TransformStreamDefaultController, value: unknown) => { + controller.enqueue(encoder.encode(`data: ${JSON.stringify(value)}\n\n`)); + }; + + const processText = (text: string, controller: TransformStreamDefaultController) => { + for (const char of text) { + if (depth === 0) { + if (char === "{") { + currentObject = char; + depth = 1; + inString = false; + escaped = false; + } + continue; + } + + currentObject += char; + + if (escaped) { + escaped = false; + continue; + } + + if (inString && char === "\\") { + escaped = true; + continue; + } + + if (char === '"') { + inString = !inString; + continue; + } + + if (inString) continue; + + if (char === "{") { + depth += 1; + continue; + } + + if (char === "}") { + depth -= 1; + if (depth === 0) { + const parsed = JSON.parse(currentObject) as VertexResponse; + currentObject = ""; + for (const chunk of convertVertexStreamResponseToOpenAIChunks(parsed, context)) { + enqueueSse(controller, chunk); + } + } + } + } + }; + + return vertexBody.pipeThrough( + new TransformStream({ + transform(chunk, controller) { + processText(decoder.decode(chunk, { stream: true }), controller); + }, + flush(controller) { + const tail = decoder.decode(); + if (tail) { + processText(tail, controller); + } + if (depth !== 0) { + throw new Error("Vertex AI returned an incomplete JSON stream."); + } + controller.enqueue(encoder.encode("data: [DONE]\n\n")); + }, + }), + ); +} + +export function convertVertexStreamResponseToOpenAIChunks(payload: unknown, context: OpenAIContext): JsonRecord[] { + const response = normalizeVertexResponse(payload); + if (response.error?.message || response.error?.status || response.error?.code !== undefined) { + return [{ error: response.error }]; + } + + const usage = convertVertexUsage(response.usageMetadata); + const chunks: JsonRecord[] = []; + + for (const candidate of response.candidates ?? []) { + const index = candidate.index ?? 0; + const text = extractTextFromVertexCandidate(candidate); + const toolCalls = extractFunctionCalls(candidate, context.id, index, true); + const finishReason = toolCalls.length > 0 ? "tool_calls" : mapVertexFinishReason(candidate.finishReason); + + if (text) { + chunks.push({ + id: context.id, + object: "chat.completion.chunk", + created: context.created, + model: context.model, + choices: [{ index, delta: { content: text }, finish_reason: null }], + }); + } + + if (toolCalls.length > 0) { + chunks.push({ + id: context.id, + object: "chat.completion.chunk", + created: context.created, + model: context.model, + choices: [{ index, delta: { tool_calls: toolCalls }, finish_reason: null }], + }); + } + + if (finishReason) { + chunks.push({ + id: context.id, + object: "chat.completion.chunk", + created: context.created, + model: context.model, + choices: [{ index, delta: {}, finish_reason: finishReason }], + ...(usage ? { usage } : {}), + }); + } + } + + if (chunks.length === 0 && usage) { + chunks.push({ + id: context.id, + object: "chat.completion.chunk", + created: context.created, + model: context.model, + choices: [{ index: 0, delta: {}, finish_reason: null }], + usage, + }); + } + + return chunks; +} + +function createOpenAIContext(model: string): OpenAIContext { + return { + id: `chatcmpl-vertex-${Date.now().toString(36)}-${Math.random().toString(36).slice(2, 10)}`, + model, + created: Math.floor(Date.now() / 1000), + }; +} + +function getRequestUrl(input: Request | URL | string): URL { + if (input instanceof Request) return new URL(input.url); + return new URL(String(input)); +} + +async function readJsonRequest(input: Request | URL | string, init: RequestInit | undefined): Promise { + if (init?.body !== undefined && init.body !== null) { + return JSON.parse(await readBodyAsText(init.body)); + } + + if (input instanceof Request) { + return input.json(); + } + + throw new Error("Vertex adapter received a chat request without a JSON body."); +} + +async function readBodyAsText(body: NonNullable): Promise { + if (typeof body === "string") return body; + if (body instanceof URLSearchParams) return body.toString(); + if (body instanceof Blob) return body.text(); + if (body instanceof ArrayBuffer) return new TextDecoder().decode(body); + if (ArrayBuffer.isView(body)) return new TextDecoder().decode(body); + if (body instanceof ReadableStream) { + const response = new Response(body); + return response.text(); + } + + throw new Error("Vertex adapter expected a JSON request body."); +} + +function unsupportedVertexEndpointResponse(url: URL): Response { + return vertexErrorResponse( + `GROK_USE_VERTEX=1 supports Vertex AI chat completions only. The xAI endpoint "${url.pathname}" is not available through Vertex AI; native xAI-only features such as Responses API search, image/video generation, STT, and Batch API require GROK_API_KEY with GROK_USE_VERTEX unset.`, + 400, + "vertex_unsupported_endpoint", + ); +} + +async function translateVertexError(response: Response, vertexUrl: string): Promise { + const body = await response.text(); + const detail = extractVertexErrorMessage(body) || response.statusText || `HTTP ${response.status}`; + return vertexErrorResponse( + `Vertex AI request failed (${response.status}) for ${vertexUrl}: ${detail}`, + response.status, + ); +} + +function vertexErrorResponse(message: string, status: number, code = "vertex_request_failed"): Response { + return new Response( + JSON.stringify({ + error: { + message, + type: "vertex_ai_error", + code, + }, + }), + { + status, + headers: { + "Content-Type": "application/json", + }, + }, + ); +} + +function extractVertexErrorMessage(body: string): string | undefined { + try { + const parsed = JSON.parse(body) as { error?: { message?: string; status?: string } }; + return parsed.error?.message || parsed.error?.status; + } catch { + return body.trim() || undefined; + } +} + +function normalizeVertexResponse(payload: unknown): VertexResponse { + if (Array.isArray(payload)) { + return (payload[payload.length - 1] ?? {}) as VertexResponse; + } + return (payload ?? {}) as VertexResponse; +} + +function convertToolChoiceToVertexToolConfig( + toolChoice: XaiChatRequest["tool_choice"], + availableFunctionNames: Set, +): JsonRecord | undefined { + if (!toolChoice || toolChoice === "auto") { + return { functionCallingConfig: { mode: "AUTO" } }; + } + if (toolChoice === "none") { + return { functionCallingConfig: { mode: "NONE" } }; + } + if (toolChoice === "required") { + return { functionCallingConfig: { mode: "ANY" } }; + } + if (toolChoice.type === "function" && availableFunctionNames.has(toolChoice.function.name)) { + return { + functionCallingConfig: { + mode: "ANY", + allowedFunctionNames: [toolChoice.function.name], + }, + }; + } + return { functionCallingConfig: { mode: "AUTO" } }; +} + +function textPartsFromContent(content: unknown): VertexPart[] { + const text = extractTextContent(content); + return text ? [{ text }] : []; +} + +function extractTextContent(content: unknown): string { + if (content == null) return ""; + if (typeof content === "string") return content; + if (!Array.isArray(content)) return String(content); + + const parts: string[] = []; + for (const part of content) { + if (!part || typeof part !== "object") continue; + const record = part as JsonRecord; + if (record.type === "text" && typeof record.text === "string") { + parts.push(record.text); + continue; + } + if (record.type === "image_url") { + throw new Error( + "Vertex Grok adapter supports text chat and tool payloads; image_url message parts are not supported.", + ); + } + } + return parts.join("\n"); +} + +function hasVertexPartValue(part: VertexPart): boolean { + return Boolean(part.text || part.functionCall || part.functionResponse); +} + +function parseJsonObject(value: string): JsonRecord { + if (!value.trim()) return {}; + try { + const parsed = JSON.parse(value) as unknown; + return isRecord(parsed) ? parsed : { value: parsed }; + } catch { + return { value }; + } +} + +function responseObjectFromToolContent(content: unknown): JsonRecord { + const text = extractTextContent(content); + if (!text.trim()) return {}; + try { + const parsed = JSON.parse(text) as unknown; + return isRecord(parsed) ? parsed : { result: parsed }; + } catch { + return { result: text }; + } +} + +function extractTextFromVertexCandidate(candidate: VertexCandidate): string { + return (candidate.content?.parts ?? []) + .map((part) => (typeof part.text === "string" ? part.text : "")) + .filter(Boolean) + .join(""); +} + +function extractFunctionCalls( + candidate: VertexCandidate, + responseId: string, + candidateIndex: number, + includeDeltaIndex: boolean, +) { + const toolCalls = []; + let toolCallIndex = 0; + + for (const part of candidate.content?.parts ?? []) { + if (!part.functionCall) continue; + const index = toolCallIndex++; + toolCalls.push({ + ...(includeDeltaIndex ? { index } : {}), + id: `call_${responseId}_${candidateIndex}_${index}`, + type: "function", + function: { + name: part.functionCall.name, + arguments: JSON.stringify(part.functionCall.args ?? {}), + }, + }); + } + + return toolCalls; +} + +function mapVertexFinishReason(reason: string | undefined): string | null { + switch (reason) { + case undefined: + case "": + return null; + case "STOP": + return "stop"; + case "MAX_TOKENS": + return "length"; + case "MALFORMED_FUNCTION_CALL": + return "tool_calls"; + case "SAFETY": + case "RECITATION": + case "BLOCKLIST": + case "PROHIBITED_CONTENT": + case "SPII": + return "content_filter"; + default: + return "stop"; + } +} + +function convertVertexUsage(usage: VertexResponse["usageMetadata"]) { + if (!usage) return undefined; + const promptTokens = usage.promptTokenCount ?? 0; + const completionTokens = usage.candidatesTokenCount ?? 0; + return { + prompt_tokens: promptTokens, + completion_tokens: completionTokens, + total_tokens: usage.totalTokenCount ?? promptTokens + completionTokens, + }; +} + +function removeUndefined(value: T): T { + return Object.fromEntries(Object.entries(value).filter((entry) => entry[1] !== undefined)) as T; +} + +function isRecord(value: unknown): value is JsonRecord { + return typeof value === "object" && value !== null && !Array.isArray(value); +} diff --git a/src/grok/vertex-auth.test.ts b/src/grok/vertex-auth.test.ts new file mode 100644 index 00000000..ba83da67 --- /dev/null +++ b/src/grok/vertex-auth.test.ts @@ -0,0 +1,75 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; + +const googleAuthConstructor = vi.hoisted(() => vi.fn()); +const getAccessTokenMock = vi.hoisted(() => vi.fn(async () => "adc-token")); + +vi.mock("google-auth-library", () => ({ + GoogleAuth: class { + constructor(options: unknown) { + googleAuthConstructor(options); + } + + async getAccessToken() { + return getAccessTokenMock(); + } + }, +})); + +describe("Vertex auth", () => { + beforeEach(() => { + vi.resetModules(); + googleAuthConstructor.mockClear(); + getAccessTokenMock.mockReset(); + getAccessTokenMock.mockResolvedValue("adc-token"); + }); + + it("passes an explicit fetch implementation to google-auth-library", async () => { + const { getVertexAccessToken } = await import("./vertex-auth"); + + await expect(getVertexAccessToken()).resolves.toBe("adc-token"); + + expect(googleAuthConstructor).toHaveBeenCalledWith({ + scopes: ["https://www.googleapis.com/auth/cloud-platform"], + clientOptions: { + transporterOptions: { + fetchImplementation: expect.any(Function), + }, + }, + }); + }); + + it("turns invalid_rapt refresh failures into an actionable ADC reauth message", async () => { + getAccessTokenMock.mockRejectedValueOnce({ + response: { + data: { + error: "invalid_grant", + error_description: "reauth related error (invalid_rapt)", + error_subtype: "invalid_rapt", + }, + }, + }); + const { getVertexAccessToken } = await import("./vertex-auth"); + + let message = ""; + try { + await getVertexAccessToken(); + } catch (caught: unknown) { + message = caught instanceof Error ? caught.message : String(caught); + } + + expect(message).toContain("Google Application Default Credentials need reauthentication."); + expect(message).toContain("gcloud auth application-default login"); + expect(message).toContain("gcloud auth application-default revoke"); + expect(message).not.toContain('{"error"'); + }); + + it("reuses the GoogleAuth client so token caching can work", async () => { + const { getVertexAccessToken } = await import("./vertex-auth"); + + await expect(getVertexAccessToken()).resolves.toBe("adc-token"); + await expect(getVertexAccessToken()).resolves.toBe("adc-token"); + + expect(googleAuthConstructor).toHaveBeenCalledTimes(1); + expect(getAccessTokenMock).toHaveBeenCalledTimes(2); + }); +}); diff --git a/src/grok/vertex-auth.ts b/src/grok/vertex-auth.ts new file mode 100644 index 00000000..ab51c15f --- /dev/null +++ b/src/grok/vertex-auth.ts @@ -0,0 +1,120 @@ +import { GoogleAuth } from "google-auth-library"; + +const VERTEX_AUTH_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]; +const fetchImplementation: typeof fetch = (input, init) => globalThis.fetch(input, init); +let vertexAuth: GoogleAuth | undefined; + +export async function getVertexAccessToken(): Promise { + if (!vertexAuth) { + vertexAuth = new GoogleAuth({ + scopes: VERTEX_AUTH_SCOPES, + clientOptions: { + transporterOptions: { + fetchImplementation, + }, + }, + }); + } + + let token: string | null | undefined; + try { + token = await vertexAuth.getAccessToken(); + } catch (err: unknown) { + throw new Error(formatVertexAuthErrorMessage(err)); + } + + if (!token) { + throw new Error( + "Could not obtain a Google Cloud access token from Application Default Credentials. Run `gcloud auth application-default login` or configure ADC for this environment.", + ); + } + + return token; +} + +export function formatVertexAuthErrorMessage(err: unknown): string { + const detail = extractGoogleAuthDetail(err); + if (isReauthError(detail)) { + return [ + "Google Application Default Credentials need reauthentication.", + "", + "Run `gcloud auth application-default login` in a terminal, then restart `grok`.", + "If that still fails, run `gcloud auth application-default revoke` and then `gcloud auth application-default login` again.", + "For SSH/headless environments, use `gcloud auth application-default login --no-launch-browser`.", + ].join("\n"); + } + + return [ + "Could not obtain a Google Cloud access token from Application Default Credentials.", + detail ? `Google auth error: ${detail}` : "Google auth returned an unknown error.", + "", + "Run `gcloud auth application-default login` or configure ADC for this environment.", + ].join("\n"); +} + +function isReauthError(detail: string): boolean { + return /invalid_rapt|invalid_grant|reauth|application-default login|cannot prompt/i.test(detail); +} + +function extractGoogleAuthDetail(err: unknown): string { + const fields = extractGoogleAuthFields(err); + const joined = [fields.error, fields.errorDescription, fields.errorSubtype, fields.message] + .filter(Boolean) + .join(": "); + if (joined) return joined; + if (err instanceof Error) return err.message; + return typeof err === "string" ? err : ""; +} + +function extractGoogleAuthFields(value: unknown): { + error?: string; + errorDescription?: string; + errorSubtype?: string; + message?: string; +} { + if (typeof value === "string") { + return parseGoogleAuthJson(value) ?? { message: value }; + } + if (!value || typeof value !== "object") return {}; + + const record = value as Record; + const data = getNestedRecord(record, "response", "data") ?? getNestedRecord(record, "data"); + const parsedMessage = typeof record.message === "string" ? parseGoogleAuthJson(record.message) : undefined; + const source = data ?? parsedMessage ?? record; + + return { + error: stringField(source, "error"), + errorDescription: stringField(source, "error_description") ?? stringField(source, "errorDescription"), + errorSubtype: stringField(source, "error_subtype") ?? stringField(source, "errorSubtype"), + message: typeof record.message === "string" && !parsedMessage ? record.message : stringField(source, "message"), + }; +} + +function parseGoogleAuthJson(value: string): Record | undefined { + const trimmed = value.trim(); + if (!trimmed.startsWith("{")) return undefined; + try { + const parsed = JSON.parse(trimmed) as unknown; + return parsed && typeof parsed === "object" && !Array.isArray(parsed) + ? (parsed as Record) + : undefined; + } catch { + return undefined; + } +} + +function getNestedRecord(record: Record, ...path: string[]): Record | undefined { + let current: unknown = record; + for (const part of path) { + if (!current || typeof current !== "object" || Array.isArray(current)) return undefined; + current = (current as Record)[part]; + } + return current && typeof current === "object" && !Array.isArray(current) + ? (current as Record) + : undefined; +} + +function stringField(record: Record, key: string): string | undefined { + const value = record[key]; + return typeof value === "string" && value.trim() ? value.trim() : undefined; +} diff --git a/src/index.ts b/src/index.ts index d56630d1..fbbe6cea 100755 --- a/src/index.ts +++ b/src/index.ts @@ -21,8 +21,10 @@ import { getBaseURL, getCurrentSandboxMode, getCurrentSandboxSettings, + isVertexModeEnabled, loadPaymentSettings, mergeSandboxSettings, + requireVertexSettings, type SandboxMode, type SandboxSettings, savePaymentSettings, @@ -102,7 +104,7 @@ async function startInteractive( async function runHeadless( prompt: string, - apiKey: string, + apiKey: string | undefined, baseURL: string, model: string | undefined, maxToolRounds: number, @@ -185,8 +187,12 @@ async function runBackgroundDelegation(jobPath: string, options: CliOptions) { try { const delegation = await loadDelegation(jobPath); const apiKey = stringOption(options.apiKey) || getApiKey(); - if (!apiKey) { - throw new Error("API key required. Set GROK_API_KEY, use --api-key, or save it to ~/.grok/user-settings.json."); + requireModelAuth(apiKey, { exitOnError: false }); + const useBatchApi = Boolean(delegation.batchApi ?? options.batchApi === true); + if (isVertexModeEnabled() && useBatchApi) { + throw new Error( + "xAI Batch API is not available when GROK_USE_VERTEX=1. Use normal Vertex AI streaming/headless mode, or unset GROK_USE_VERTEX and configure GROK_API_KEY.", + ); } const baseURL = stringOption(options.baseUrl) || getBaseURL(); @@ -200,7 +206,7 @@ async function runBackgroundDelegation(jobPath: string, options: CliOptions) { persistSession: false, sandboxMode, sandboxSettings, - batchApi: Boolean(delegation.batchApi ?? options.batchApi === true), + batchApi: useBatchApi, }); const result = await agent.runTaskRequest({ agent: delegation.agent, @@ -250,17 +256,33 @@ function resolveConfig(options: CliOptions) { } const sandboxSettings = mergeSandboxSettings(getCurrentSandboxSettings(), cliOverrides); - if (typeof options.apiKey === "string") saveUserSettings({ apiKey: options.apiKey }); + if (typeof options.apiKey === "string" && !isVertexModeEnabled()) saveUserSettings({ apiKey: options.apiKey }); if (typeof options.model === "string") saveUserSettings({ defaultModel: normalizeModelId(options.model) }); return { apiKey, baseURL, model, maxToolRounds, sandboxMode, sandboxSettings }; } -function requireApiKey(apiKey: string | undefined): string { +function requireModelAuth(apiKey: string | undefined, options: { exitOnError?: boolean } = {}): string | undefined { + if (isVertexModeEnabled()) { + try { + requireVertexSettings(); + } catch (err: unknown) { + const message = err instanceof Error ? err.message : String(err); + if (options.exitOnError === false) { + throw new Error(message); + } + console.error(`Error: ${message}`); + process.exit(1); + } + return apiKey; + } + if (!apiKey) { - console.error( - "Error: API key required. Set GROK_API_KEY env var, use --api-key, or save to ~/.grok/user-settings.json", - ); + const message = "API key required. Set GROK_API_KEY env var, use --api-key, or save to ~/.grok/user-settings.json"; + if (options.exitOnError === false) { + throw new Error(message); + } + console.error(`Error: ${message}`); process.exit(1); } @@ -313,6 +335,17 @@ program } const config = resolveConfig(options); + if (isVertexModeEnabled() && options.batchApi === true) { + console.error( + "Error: xAI Batch API is not available when GROK_USE_VERTEX=1. Use normal Vertex AI mode, or unset GROK_USE_VERTEX and configure GROK_API_KEY.", + ); + process.exit(1); + } + + const isInteractiveRun = !options.verify && !options.prompt; + if (isVertexModeEnabled() && !isInteractiveRun) { + requireModelAuth(config.apiKey); + } if (options.verify) { const verifyError = getVerifyCliError({ hasPrompt: Boolean(options.prompt), hasMessageArgs: message.length > 0 }); @@ -323,7 +356,7 @@ program await runHeadless( buildVerifyPrompt(process.cwd()), - requireApiKey(config.apiKey), + requireModelAuth(config.apiKey), config.baseURL, config.model, config.maxToolRounds, @@ -339,7 +372,7 @@ program if (options.prompt) { await runHeadless( options.prompt, - requireApiKey(config.apiKey), + requireModelAuth(config.apiKey), config.baseURL, config.model, config.maxToolRounds, @@ -385,7 +418,7 @@ program process.off("SIGTERM", exitCleanlyOnSigterm); try { await runTelegramHeadlessBridge({ - apiKey: requireApiKey(config.apiKey), + apiKey: requireModelAuth(config.apiKey), baseURL: config.baseURL, model: config.model, maxToolRounds: config.maxToolRounds, diff --git a/src/telegram/headless-bridge.test.ts b/src/telegram/headless-bridge.test.ts index bd52a1a7..b68410bc 100644 --- a/src/telegram/headless-bridge.test.ts +++ b/src/telegram/headless-bridge.test.ts @@ -1,7 +1,22 @@ import * as path from "node:path"; -import { describe, expect, it } from "vitest"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import { hasTelegramModelAuth } from "./headless-bridge"; import { resolveTelegramHeadlessBridgePaths } from "./headless-bridge-paths"; +vi.mock("../agent/agent", () => ({ + Agent: class {}, +})); + +const originalApiKey = process.env.GROK_API_KEY; + +afterEach(() => { + if (originalApiKey === undefined) { + delete process.env.GROK_API_KEY; + } else { + process.env.GROK_API_KEY = originalApiKey; + } +}); + describe("resolveTelegramHeadlessBridgePaths", () => { it("uses default files in the provided cwd", () => { const cwd = path.resolve("fixture-workspace"); @@ -26,3 +41,11 @@ describe("resolveTelegramHeadlessBridgePaths", () => { }); }); }); + +describe("hasTelegramModelAuth", () => { + it("accepts an explicit CLI api key even when saved auth is absent", () => { + delete process.env.GROK_API_KEY; + + expect(hasTelegramModelAuth("cli-key")).toBe(true); + }); +}); diff --git a/src/telegram/headless-bridge.ts b/src/telegram/headless-bridge.ts index 756cf21d..c964bf19 100644 --- a/src/telegram/headless-bridge.ts +++ b/src/telegram/headless-bridge.ts @@ -9,6 +9,7 @@ import { getCurrentSandboxMode, getCurrentSandboxSettings, getTelegramBotToken, + hasModelAuthConfigured, loadUserSettings, type SandboxMode, type SandboxSettings, @@ -36,7 +37,7 @@ export interface TelegramHeadlessBridgeOptions { } interface TelegramHeadlessStartupConfig { - apiKey: string; + apiKey?: string; baseURL: string; model: string; sandboxMode: SandboxMode; @@ -121,8 +122,10 @@ export async function runTelegramHeadlessBridge(options: TelegramHeadlessBridgeO } const apiKey = options.apiKey ?? getApiKey(); - if (!apiKey) { - throw new Error("Missing Grok API key."); + if (!hasTelegramModelAuth(apiKey)) { + throw new Error( + "Missing model authentication. Set GROK_API_KEY, or configure Vertex with GROK_VERTEX_PROJECT_ID and Application Default Credentials.", + ); } const startupConfig: TelegramHeadlessStartupConfig = { @@ -270,3 +273,7 @@ export async function runTelegramHeadlessBridge(options: TelegramHeadlessBridgeO await shutdownComplete; } } + +export function hasTelegramModelAuth(apiKey: string | undefined): boolean { + return Boolean(apiKey) || hasModelAuthConfigured(); +} diff --git a/src/test/setup.ts b/src/test/setup.ts new file mode 100644 index 00000000..53bdd635 --- /dev/null +++ b/src/test/setup.ts @@ -0,0 +1,11 @@ +import { rmSync } from "node:fs"; +import os from "node:os"; +import path from "node:path"; +import { beforeEach } from "vitest"; + +const testUserSettingsPath = path.join(os.tmpdir(), `grok-cli-vitest-user-settings-${process.pid}.json`); + +beforeEach(() => { + process.env.GROK_USER_SETTINGS_PATH = testUserSettingsPath; + rmSync(testUserSettingsPath, { force: true }); +}); diff --git a/src/ui/app-auth-modal.test.ts b/src/ui/app-auth-modal.test.ts new file mode 100644 index 00000000..7d4989b3 --- /dev/null +++ b/src/ui/app-auth-modal.test.ts @@ -0,0 +1,14 @@ +import { describe, expect, it } from "vitest"; +import { getNextAuthModalError } from "./auth-modal-state"; + +describe("getNextAuthModalError", () => { + it("clears stale auth errors by default", () => { + expect(getNextAuthModalError("old error")).toBeNull(); + }); + + it("preserves the current auth error for error-driven modal opens", () => { + expect(getNextAuthModalError("Vertex authentication failed", { preserveError: true })).toBe( + "Vertex authentication failed", + ); + }); +}); diff --git a/src/ui/app.tsx b/src/ui/app.tsx index 3996bc30..97181d55 100644 --- a/src/ui/app.tsx +++ b/src/ui/app.tsx @@ -38,9 +38,16 @@ import { FileIndex } from "../utils/file-index.js"; import { copyTextToHostClipboard } from "../utils/host-clipboard"; import { type CustomSubagentConfig, + DEFAULT_VERTEX_BASE_URL, + DEFAULT_VERTEX_LOCATION, getApiKey, + getModelAuthStatus, getTelegramBotToken, + getVertexSettings, + hasModelAuthConfigured, isReservedSubagentName, + isTruthyEnv, + isVertexModeEnabled, loadMcpServers, loadPaymentSettings, loadUserSettings, @@ -56,6 +63,7 @@ import { savePaymentSettings, saveProjectSettings, saveUserSettings, + VERTEX_API_KEY_PLACEHOLDER, } from "../utils/settings"; import { discoverSkills, formatSkillsForChat } from "../utils/skills"; import { formatSubagentName } from "../utils/subagent-display"; @@ -68,6 +76,7 @@ import { SubagentEditorModal, SubagentsBrowserModal, } from "./agents-modal"; +import { type AuthModalOpenOptions, getNextAuthModalError } from "./auth-modal-state"; import { BtwOverlay, type BtwState } from "./components/btw-overlay.js"; import { SuggestionOverlay } from "./components/SuggestionOverlay.js"; import { type TypeaheadState, useTypeahead } from "./hooks/useTypeahead.js"; @@ -610,6 +619,53 @@ interface ActiveTurnState { flushedAssistantChars: number; } +type AuthModalTab = "xai" | "vertex"; +const VERTEX_AUTH_FIELDS = ["projectId", "location", "baseURL"] as const; +type VertexAuthField = (typeof VERTEX_AUTH_FIELDS)[number]; + +interface VertexAuthDraft { + projectId: string; + location: string; + baseURL: string; +} + +function getDefaultAuthModalTab(): AuthModalTab { + return getModelAuthStatus().activeMode === "vertex" ? "vertex" : "xai"; +} + +function getVertexAuthDraft(): VertexAuthDraft { + const settings = getVertexSettings(); + return { + projectId: settings.projectId, + location: settings.location || DEFAULT_VERTEX_LOCATION, + baseURL: settings.baseURL || DEFAULT_VERTEX_BASE_URL, + }; +} + +function normalizeVertexAuthBaseUrl(value: string): string { + return (value.trim() || DEFAULT_VERTEX_BASE_URL).replace(/\/+$/, ""); +} + +function isInvalidVertexLocation(value: string): boolean { + return value.trim().toLowerCase() === "global"; +} + +function isHttpBaseUrl(value: string): boolean { + try { + const parsed = new URL(value); + return parsed.protocol === "https:" || parsed.protocol === "http:"; + } catch { + return false; + } +} + +function syncTextareaRef(ref: React.RefObject, value: string): void { + ref.current?.clear(); + if (value) { + ref.current?.insertText(value); + } +} + export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) { const t = dark; const renderer = useRenderer(); @@ -655,6 +711,10 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) const [sessionTitle, setSessionTitle] = useState(() => agent.getSessionTitle()); const [sessionId, setSessionId] = useState(() => agent.getSessionId()); const [showApiKeyModal, setShowApiKeyModal] = useState(() => !initialHasApiKey); + const [authModalTab, setAuthModalTab] = useState(() => getDefaultAuthModalTab()); + const [vertexAuthField, setVertexAuthField] = useState("projectId"); + const [vertexAuthDraft, setVertexAuthDraft] = useState(() => getVertexAuthDraft()); + const [vertexAuthSyncKey, setVertexAuthSyncKey] = useState(0); const [apiKeyError, setApiKeyError] = useState(null); const [showSlashMenu, setShowSlashMenu] = useState(false); const [slashMenuIndex, setSlashMenuIndex] = useState(0); @@ -680,6 +740,9 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) const pasteCounterRef = useRef(0); const pasteBlocksRef = useRef([]); const apiKeyInputRef = useRef(null); + const vertexProjectInputRef = useRef(null); + const vertexLocationInputRef = useRef(null); + const vertexBaseUrlInputRef = useRef(null); const inputRef = useRef(null); const scrollRef = useRef(null); const { width, height } = useTerminalDimensions(); @@ -689,6 +752,8 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) const isProcessingRef = useRef(false); const hasApiKeyRef = useRef(initialHasApiKey); const showApiKeyModalRef = useRef(!initialHasApiKey); + const authModalTabRef = useRef(getDefaultAuthModalTab()); + const vertexAuthFieldRef = useRef("projectId"); const queuedMessagesRef = useRef([]); const processMessageRef = useRef<(text: string, displayText?: string) => Promise | void>(() => {}); const [queuedMessages, setQueuedMessages] = useState([]); @@ -1598,8 +1663,10 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) } const apiKey = getApiKey(); - if (!apiKey) { - throw new Error("Grok API key required. Add it in the CLI or set GROK_API_KEY."); + if (!hasModelAuthConfigured()) { + throw new Error( + "Model authentication required. Set GROK_API_KEY, or configure Vertex with GROK_VERTEX_PROJECT_ID and Application Default Credentials.", + ); } const u = loadUserSettings(); @@ -1706,7 +1773,7 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) const startTelegramBridge = useCallback(() => { const token = getTelegramBotToken(); - if (!token || !getApiKey()) return; + if (!token || !hasModelAuthConfigured()) return; if (bridgeRef.current) return; const bridge = createTelegramBridge({ @@ -1772,12 +1839,53 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) return () => clearTimeout(id); }, [copyFlashId]); - const openApiKeyModal = useCallback(() => { - showApiKeyModalRef.current = true; - setApiKeyError(null); - setShowApiKeyModal(true); + const refreshVertexAuthDraft = useCallback(() => { + setVertexAuthDraft(getVertexAuthDraft()); + setVertexAuthSyncKey((n) => n + 1); }, []); + const selectVertexAuthField = useCallback((field: VertexAuthField) => { + vertexAuthFieldRef.current = field; + setVertexAuthField(field); + }, []); + + const moveVertexAuthField = useCallback( + (delta: number) => { + const index = VERTEX_AUTH_FIELDS.indexOf(vertexAuthFieldRef.current); + const nextIndex = (index + delta + VERTEX_AUTH_FIELDS.length) % VERTEX_AUTH_FIELDS.length; + selectVertexAuthField(VERTEX_AUTH_FIELDS[nextIndex]); + }, + [selectVertexAuthField], + ); + + const selectAuthModalTab = useCallback( + (tab: AuthModalTab, options: AuthModalOpenOptions = {}) => { + authModalTabRef.current = tab; + setAuthModalTab(tab); + setApiKeyError((currentError) => getNextAuthModalError(currentError, options)); + if (tab === "vertex") { + selectVertexAuthField("projectId"); + refreshVertexAuthDraft(); + } + }, + [refreshVertexAuthDraft, selectVertexAuthField], + ); + + const openApiKeyModal = useCallback( + (tab: AuthModalTab = getDefaultAuthModalTab(), options: AuthModalOpenOptions = {}) => { + authModalTabRef.current = tab; + setAuthModalTab(tab); + showApiKeyModalRef.current = true; + setApiKeyError((currentError) => getNextAuthModalError(currentError, options)); + if (tab === "vertex") { + selectVertexAuthField("projectId"); + refreshVertexAuthDraft(); + } + setShowApiKeyModal(true); + }, + [refreshVertexAuthDraft, selectVertexAuthField], + ); + const closeApiKeyModal = useCallback(() => { showApiKeyModalRef.current = false; setApiKeyError(null); @@ -1785,6 +1893,15 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) }, []); const submitApiKey = useCallback(() => { + const shouldDisableSavedVertex = isVertexModeEnabled(); + if (shouldDisableSavedVertex) { + if (isTruthyEnv(process.env.GROK_USE_VERTEX)) { + setApiKeyError("GROK_USE_VERTEX is set in this shell. Unset it before using a native xAI API key."); + selectAuthModalTab("vertex", { preserveError: true }); + return; + } + } + const apiKey = (apiKeyInputRef.current?.plainText || "").trim(); if (!apiKey) { setApiKeyError("Enter an API key to continue."); @@ -1795,6 +1912,11 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) return; } + if (shouldDisableSavedVertex) { + const current = loadUserSettings(); + saveUserSettings({ vertex: { ...current.vertex, enabled: false } }); + } + saveUserSettings({ apiKey }); agent.setApiKey(apiKey); hasApiKeyRef.current = true; @@ -1806,7 +1928,53 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) if (getTelegramBotToken()) { startTelegramBridge(); } - }, [agent, startTelegramBridge]); + }, [agent, selectAuthModalTab, startTelegramBridge]); + + const submitVertexSettings = useCallback(() => { + const projectId = (vertexProjectInputRef.current?.plainText || "").trim(); + const location = (vertexLocationInputRef.current?.plainText || "").trim() || DEFAULT_VERTEX_LOCATION; + const baseURL = normalizeVertexAuthBaseUrl(vertexBaseUrlInputRef.current?.plainText || ""); + + if (!projectId) { + setApiKeyError("Enter GROK_VERTEX_PROJECT_ID for the Google Cloud project with the xAI Vertex model enabled."); + selectVertexAuthField("projectId"); + return; + } + if (!location) { + setApiKeyError("Enter GROK_VERTEX_LOCATION, for example us-central1 or europe-west1."); + selectVertexAuthField("location"); + return; + } + if (isInvalidVertexLocation(location)) { + setApiKeyError("Use us-central1 or europe-west1 for GROK_VERTEX_LOCATION. The host is global, not the location."); + selectVertexAuthField("location"); + return; + } + if (!isHttpBaseUrl(baseURL)) { + setApiKeyError("Enter a valid GROK_VERTEX_BASE_URL, or leave it as https://aiplatform.googleapis.com."); + selectVertexAuthField("baseURL"); + return; + } + + saveUserSettings({ + vertex: { + enabled: true, + projectId, + location, + baseURL, + }, + }); + agent.setApiKey(VERTEX_API_KEY_PLACEHOLDER); + hasApiKeyRef.current = true; + showApiKeyModalRef.current = false; + setHasApiKey(true); + setApiKeyError(null); + setShowApiKeyModal(false); + refreshVertexAuthDraft(); + if (getTelegramBotToken()) { + startTelegramBridge(); + } + }, [agent, refreshVertexAuthDraft, selectVertexAuthField, startTelegramBridge]); useEffect(() => { hasApiKeyRef.current = hasApiKey; @@ -1869,8 +2037,8 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) setTelegramTokenError("Paste your bot token from @BotFather."); return; } - if (!getApiKey()) { - setTelegramTokenError("Add a Grok API key first."); + if (!hasModelAuthConfigured()) { + setTelegramTokenError("Configure GROK_API_KEY, or configure Vertex with GROK_VERTEX_PROJECT_ID first."); return; } const u = loadUserSettings(); @@ -1924,9 +2092,16 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) const beginTelegramFromConnect = useCallback(() => { setShowConnectModal(false); - if (!getApiKey()) { - setMessages((p) => [...p, { type: "assistant", content: "Add a Grok API key first.", timestamp: new Date() }]); - openApiKeyModal(); + if (!hasModelAuthConfigured()) { + setMessages((p) => [ + ...p, + { + type: "assistant", + content: "Configure GROK_API_KEY, or configure Vertex with GROK_VERTEX_PROJECT_ID first.", + timestamp: new Date(), + }, + ]); + openApiKeyModal(isVertexModeEnabled() ? "vertex" : "xai"); return; } if (!getTelegramBotToken()) { @@ -2135,9 +2310,13 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) } if (turnHadAuthError) { - setApiKeyError("Your API key is invalid or expired. Please enter a new key."); - setShowApiKeyModal(true); - showApiKeyModalRef.current = true; + if (isVertexModeEnabled()) { + setApiKeyError("Vertex authentication failed. Check ADC, GROK_VERTEX_PROJECT_ID, and Vertex model access."); + openApiKeyModal("vertex", { preserveError: true }); + } else { + setApiKeyError("Your API key is invalid or expired. Please enter a new key."); + openApiKeyModal("xai", { preserveError: true }); + } } if (!isStale()) { @@ -2154,6 +2333,7 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) applyLocalAssistantDelta, beginLiveTurn, finalizeActiveTurn, + openApiKeyModal, scrollToBottom, sessionTitle, showLiveToolCalls, @@ -2901,8 +3081,24 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) closeApiKeyModal(); return; } + if (key.name === "tab") { + selectAuthModalTab(authModalTabRef.current === "xai" ? "vertex" : "xai"); + return; + } + if (authModalTabRef.current === "vertex" && key.name === "down") { + moveVertexAuthField(1); + return; + } + if (authModalTabRef.current === "vertex" && key.name === "up") { + moveVertexAuthField(-1); + return; + } if (key.name === "return") { - submitApiKey(); + if (authModalTabRef.current === "xai") { + submitApiKey(); + } else { + submitVertexSettings(); + } } return; } @@ -3257,6 +3453,7 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) mcpEditorField, mcpEditorFields, mcpModalIndex, + moveVertexAuthField, mcpRows, modelPickerIndex, openApiKeyModal, @@ -3267,6 +3464,7 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) removeSchedule, scheduleModalIndex, scheduleRows, + selectAuthModalTab, showScheduleDetails, submitTelegramPair, submitTelegramToken, @@ -3296,6 +3494,7 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) showSlashMenu, slashMenuIndex, submitApiKey, + submitVertexSettings, submitPlanAnswers, copyTuiSelectionToHost, toggleSavedMcp, @@ -3550,13 +3749,22 @@ export function App({ agent, startupConfig, initialMessage, onExit }: AppProps) )} {showApiKeyModal && ( - )} {showUpdateModal && updateInfo && ( @@ -4056,25 +4264,54 @@ function CopyFlashBanner({ t, width }: { t: Theme; width: number }) { ); } -function ApiKeyModal({ +function AuthModal({ t, width, height, inputRef, + selectedTab, + vertexProjectRef, + vertexLocationRef, + vertexBaseUrlRef, + vertexDraft, + vertexSyncKey, + activeVertexField, + authStatus, error, onSubmit, + onSubmitVertex, }: { t: Theme; width: number; height: number; inputRef: React.RefObject; + selectedTab: AuthModalTab; + vertexProjectRef: React.RefObject; + vertexLocationRef: React.RefObject; + vertexBaseUrlRef: React.RefObject; + vertexDraft: VertexAuthDraft; + vertexSyncKey: number; + activeVertexField: VertexAuthField; + authStatus: ReturnType; error: string | null; onSubmit: () => void; + onSubmitVertex: () => void; }) { + useEffect(() => { + void vertexSyncKey; + syncTextareaRef(vertexProjectRef, vertexDraft.projectId); + syncTextareaRef(vertexLocationRef, vertexDraft.location); + syncTextareaRef(vertexBaseUrlRef, vertexDraft.baseURL); + }, [vertexBaseUrlRef, vertexDraft, vertexLocationRef, vertexProjectRef, vertexSyncKey]); + const overlayBg = "#000000cc" as string; - const panelWidth = Math.min(68, width - 6); - const panelHeight = 13; + const panelWidth = Math.min(82, width - 6); + const panelHeight = selectedTab === "xai" ? 16 : 25; const top = bottomAlignedModalTop(height, panelHeight); + const xaiSelected = selectedTab === "xai"; + const vertexSelected = selectedTab === "vertex"; + const vertex = authStatus.vertex; + const missingVertex = vertex.missing.join(", "); return ( - {"Add API key"} + {"Choose authentication"} {"esc"} - - {"Paste your xAI API key to unlock chat. You can hide this prompt with esc."} - - - -