diff --git a/Cargo.lock b/Cargo.lock index a5cdb26..69f900f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1073,6 +1073,12 @@ dependencies = [ "windows-link", ] +[[package]] +name = "chunked_transfer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e4de3bc4ea267985becf712dc6d9eed8b04c953b3fcfb339ebc87acd9804901" + [[package]] name = "clang-sys" version = "1.8.1" @@ -3257,19 +3263,68 @@ dependencies = [ "cesu8", "cfg-if", "combine", - "jni-sys", + "jni-sys 0.3.0", "log", "thiserror 1.0.69", "walkdir", "windows-sys 0.45.0", ] +[[package]] +name = "jni" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efd9a482cf3a427f00d6b35f14332adc7902ce91efb778580e180ff90fa3498" +dependencies = [ + "cfg-if", + "combine", + "jni-macros", + "jni-sys 0.4.1", + "log", + "simd_cesu8", + "thiserror 2.0.17", + "walkdir", + "windows-link", +] + +[[package]] +name = "jni-macros" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a00109accc170f0bdb141fed3e393c565b6f5e072365c3bd58f5b062591560a3" +dependencies = [ + "proc-macro2", + "quote", + "rustc_version", + "simd_cesu8", + "syn 2.0.111", +] + [[package]] name = "jni-sys" version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +[[package]] +name = "jni-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2" +dependencies = [ + "jni-sys-macros", +] + +[[package]] +name = "jni-sys-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264" +dependencies = [ + "quote", + "syn 2.0.111", +] + [[package]] name = "jobserver" version = "0.1.34" @@ -3624,6 +3679,12 @@ dependencies = [ "pxfm", ] +[[package]] +name = "ndk-context" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27b02d87554356db9e9a873add8782d4ea6e3e58ea071a9adb9a2e8ddb884a8b" + [[package]] name = "new_debug_unreachable" version = "1.0.6" @@ -4176,6 +4237,7 @@ dependencies = [ "tempfile", "terminal_size", "tikv-jemallocator", + "tiny_http", "tokio", "tower", "tower-http", @@ -4183,6 +4245,7 @@ dependencies = [ "tracing-opentelemetry", "tracing-subscriber", "url", + "webbrowser", ] [[package]] @@ -4255,6 +4318,7 @@ dependencies = [ "anyhow", "base64 0.22.1", "camino", + "getrandom 0.2.16", "http", "indexmap 2.12.1", "keyring", @@ -4264,6 +4328,7 @@ dependencies = [ "rmcp", "serde", "serde_json", + "sha2", "shlex", "thiserror 2.0.17", "tokio", @@ -5375,7 +5440,7 @@ checksum = "1d99feebc72bae7ab76ba994bb5e121b8d83d910ca40b36e0921f53becc41784" dependencies = [ "core-foundation 0.10.1", "core-foundation-sys", - "jni", + "jni 0.21.1", "log", "once_cell", "rustls", @@ -5781,6 +5846,22 @@ version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" +[[package]] +name = "simd_cesu8" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94f90157bb87cddf702797c5dadfa0be7d266cdf49e22da2fcaa32eff75b2c33" +dependencies = [ + "rustc_version", + "simdutf8", +] + +[[package]] +name = "simdutf8" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" + [[package]] name = "similar" version = "2.7.0" @@ -6787,6 +6868,18 @@ dependencies = [ "zoneinfo64", ] +[[package]] +name = "tiny_http" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389915df6413a2e74fb181895f933386023c71110878cd0825588928e64cdc82" +dependencies = [ + "ascii", + "chunked_transfer", + "httpdate", + "log", +] + [[package]] name = "tinystr" version = "0.8.2" @@ -7611,6 +7704,22 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webbrowser" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe985f41e291eecef5e5c0770a18d28390addb03331c043964d9e916453d6f16" +dependencies = [ + "core-foundation 0.10.1", + "jni 0.22.4", + "log", + "ndk-context", + "objc2", + "objc2-foundation", + "url", + "web-sys", +] + [[package]] name = "webpki-root-certs" version = "1.0.6" diff --git a/crates/pctx/Cargo.toml b/crates/pctx/Cargo.toml index 2411c7d..461e74c 100644 --- a/crates/pctx/Cargo.toml +++ b/crates/pctx/Cargo.toml @@ -49,6 +49,8 @@ chrono = { workspace = true, features = ["serde"] } notify = "8" arboard = "3" shlex = { workspace = true } +tiny_http = "0.12" +webbrowser = "1" # Logging and Telemetry tracing = { workspace = true } diff --git a/crates/pctx/src/commands/mcp/add.rs b/crates/pctx/src/commands/mcp/add.rs index a88e4b5..4d6a27d 100644 --- a/crates/pctx/src/commands/mcp/add.rs +++ b/crates/pctx/src/commands/mcp/add.rs @@ -7,14 +7,15 @@ use tracing::info; use crate::{ commands::USER_CANCELLED, utils::{ - prompts, + oauth_flow, prompts, spinner::Spinner, - styles::{fmt_bold, fmt_cyan_bold, fmt_good_check}, + styles::{fmt_bold, fmt_cyan_bold, fmt_dimmed, fmt_good_check}, }, }; use pctx_config::{ Config, auth::{AuthConfig, SecretString}, + oauth2, server::{McpConnectionError, ServerConfig}, }; @@ -54,12 +55,34 @@ pub struct AddCmd { #[arg(long, short = 'H', conflicts_with = "command")] pub header: Option>, + /// Authenticate using OAuth 2.1 (Authorization Code + PKCE). + /// + /// Forces the OAuth browser flow even if the server doesn't advertise + /// OAuth metadata at a well-known endpoint. By default `pctx mcp add` + /// auto-detects OAuth-protected servers via RFC 9728 / RFC 8414 + /// discovery, so you only need this flag to override detection. + #[arg(long, conflicts_with_all = ["bearer", "header", "command"])] + pub oauth: bool, + /// Overrides any existing server under the same name & /// skips testing connection to the MCP server #[arg(long, short)] pub force: bool, } +fn prompt_manual_auth(server_name: &str) -> Result> { + let add_auth = inquire::Confirm::new("Do you want to add authentication interactively?") + .with_default(false) + .with_help_message( + "you can also manually update the auth configuration later in the config", + ); + if add_auth.prompt()? { + Ok(Some(prompts::prompt_auth(server_name)?)) + } else { + Ok(None) + } +} + fn parse_env_var(s: &str) -> Result<(String, String), String> { let (key, value) = s .split_once('=') @@ -100,8 +123,9 @@ impl AddCmd { } } - // apply authentication for HTTP servers only (clap ensures bearer & header are mutually exclusive) - if server.http().is_some() { + // apply authentication for HTTP servers only (clap ensures bearer/header/oauth are mutually exclusive) + if let Some(http_cfg) = server.http() { + let server_url = http_cfg.url.clone(); let auth = if let Some(bearer) = &self.bearer { Some(AuthConfig::Bearer { token: bearer.clone(), @@ -113,18 +137,35 @@ impl AddCmd { .map(|h| (h.name.clone(), h.value.clone())) .collect(), }) + } else if self.oauth { + // Explicit opt-in: run the OAuth flow without discovery gating. + Some(oauth_flow::run_interactive_flow(&server.name, &server_url).await?) + } else if self.force { + None } else { - let add_auth = inquire::Confirm::new( - "Do you want to add authentication interactively?", - ) - .with_default(false) - .with_help_message( - "you can also manually update the auth configuration later in the config", - ); - if !self.force && add_auth.prompt()? { - Some(prompts::prompt_auth(&server.name)?) + // Auto-detect OAuth via RFC 9728 / RFC 8414 / OIDC discovery + // before falling back to the manual bearer/headers prompt. + let mut sp = Spinner::new("Checking for OAuth metadata..."); + let discovered = oauth2::discover(&server_url).await.ok().flatten(); + if discovered.is_some() { + sp.stop_success("Detected OAuth-protected MCP server"); + let use_oauth = inquire::Confirm::new( + "Authorize pctx with this server using OAuth 2.1 (browser flow)?", + ) + .with_default(true) + .with_help_message( + "Tokens will be stored in your system keychain; nothing secret is written to pctx.json", + ) + .prompt()?; + if use_oauth { + Some(oauth_flow::run_interactive_flow(&server.name, &server_url).await?) + } else { + prompt_manual_auth(&server.name)? + } } else { - None + sp.stop_and_persist("·", "No OAuth metadata advertised"); + info!("{}", fmt_dimmed("Falling back to manual auth setup.")); + prompt_manual_auth(&server.name)? } }; server.set_auth(auth); @@ -244,6 +285,7 @@ mod tests { env: vec![], bearer: None, header: None, + oauth: false, force: true, }; @@ -269,6 +311,7 @@ mod tests { env: vec![("NODE_ENV".to_string(), "test".to_string())], bearer: None, header: None, + oauth: false, force: true, }; @@ -295,6 +338,7 @@ mod tests { env: vec![], bearer: None, header: None, + oauth: false, force: true, }; diff --git a/crates/pctx/src/commands/mcp/init.rs b/crates/pctx/src/commands/mcp/init.rs index 9a79862..d591c70 100644 --- a/crates/pctx/src/commands/mcp/init.rs +++ b/crates/pctx/src/commands/mcp/init.rs @@ -95,6 +95,7 @@ impl InitCmd { force: false, bearer: None, header: None, + oauth: false, } } else { // stdio @@ -146,6 +147,7 @@ impl InitCmd { force: false, bearer: None, header: None, + oauth: false, } }; diff --git a/crates/pctx/src/utils/mod.rs b/crates/pctx/src/utils/mod.rs index 23ada70..0ebe57f 100644 --- a/crates/pctx/src/utils/mod.rs +++ b/crates/pctx/src/utils/mod.rs @@ -1,5 +1,6 @@ pub(crate) mod logger; pub(crate) mod metrics; +pub(crate) mod oauth_flow; pub(crate) mod prompts; pub(crate) mod spinner; pub(crate) mod styles; diff --git a/crates/pctx/src/utils/oauth_flow.rs b/crates/pctx/src/utils/oauth_flow.rs new file mode 100644 index 0000000..83da2e0 --- /dev/null +++ b/crates/pctx/src/utils/oauth_flow.rs @@ -0,0 +1,294 @@ +//! Interactive OAuth 2.1 authorization-code flow for `pctx mcp add`. +//! +//! This module owns the user-facing parts of the OAuth flow: spinning up a +//! one-shot localhost callback listener, opening the user's browser, waiting +//! for the redirect, and exchanging the authorization code for tokens. The +//! lower-level protocol bits (discovery, PKCE, token exchange, refresh) live +//! in [`pctx_config::oauth2`] so they can be reused outside the CLI. + +use std::{ + net::{Ipv4Addr, SocketAddrV4, TcpListener as StdTcpListener}, + time::Duration, +}; + +use anyhow::{Context, Result}; +use pctx_config::{ + auth::AuthConfig, + oauth2::{self, OAuthMetadata, Pkce}, +}; +use tracing::{debug, info, warn}; + +use crate::utils::styles::{fmt_cyan_bold, fmt_dimmed, fmt_good_check}; + +/// Default scopes pctx requests when the auth-server metadata advertises +/// `scopes_supported` — we just ask for whatever the server lists. If the +/// server doesn't advertise any, we send no `scope` parameter and let the +/// server use its defaults. +fn pick_scopes(metadata: &OAuthMetadata) -> Vec { + metadata.scopes_supported.clone() +} + +/// Run the full interactive OAuth 2.1 flow for an MCP server, returning a +/// ready-to-persist [`AuthConfig::OAuth`] (with the token bundle already +/// stored in the system keychain). +/// +/// `server_name` is used to derive the keychain `token_ref` and to label the +/// dynamically registered client. +/// +/// # Errors +/// Returns an error if discovery fails, the user cancels, the browser flow +/// times out, or the token endpoint returns a non-success response. +pub(crate) async fn run_interactive_flow( + server_name: &str, + server_url: &url::Url, +) -> Result { + info!( + "{}", + fmt_dimmed(&format!("Discovering OAuth metadata for {server_url}...")) + ); + let metadata = oauth2::discover(server_url) + .await + .context("OAuth discovery request failed")? + .ok_or_else(|| { + anyhow::anyhow!("Server does not advertise OAuth metadata at any well-known endpoint") + })?; + debug!( + "OAuth metadata: authorize={} token={} registration={:?}", + metadata.authorization_endpoint, metadata.token_endpoint, metadata.registration_endpoint + ); + + // Bind a one-shot localhost callback listener. We bind first so that the + // chosen port is part of the redirect_uri we register / authorize with. + let listener = bind_callback_listener()?; + let port = listener.local_addr()?.port(); + let redirect_uri: url::Url = format!("http://127.0.0.1:{port}/callback") + .parse() + .expect("constructed redirect URI is valid"); + + // Try RFC 7591 dynamic client registration if the server supports it; + // fall back to prompting the user for a pre-registered client_id. + let (client_id, client_secret) = + obtain_client_credentials(&metadata, server_name, &redirect_uri).await?; + + let scopes = pick_scopes(&metadata); + let pkce = Pkce::generate()?; + let state = oauth2::random_url_safe(16)?; + let authorize_url = oauth2::build_authorize_url( + &metadata, + &client_id, + &redirect_uri, + &scopes, + &state, + &pkce, + Some(server_url), + ); + + info!( + "{}", + fmt_cyan_bold(&format!( + "Opening browser to authorize pctx with {server_name}..." + )) + ); + info!( + "{}", + fmt_dimmed(&format!( + "If the browser does not open automatically, visit:\n {authorize_url}" + )) + ); + if let Err(e) = webbrowser::open(authorize_url.as_str()) { + warn!("Failed to launch browser automatically: {e}"); + } + + // Block until the redirect arrives (or we time out). + let callback = wait_for_callback(listener, &state)?; + info!("{}", fmt_good_check("Authorization code received")); + + let bundle = oauth2::exchange_code( + &metadata, + &client_id, + client_secret.as_deref(), + &callback.code, + &pkce.verifier, + &redirect_uri, + Some(server_url), + ) + .await + .context("Failed to exchange authorization code for tokens")?; + + let token_ref = AuthConfig::default_oauth_token_ref(server_name); + let token_ref_resolved = token_ref + .resolve() + .await + .context("Failed to resolve OAuth token_ref")?; + bundle + .save(&token_ref_resolved) + .context("Failed to persist OAuth token bundle to keychain")?; + info!( + "{}", + fmt_good_check(&format!( + "OAuth tokens stored in system keychain ({token_ref_resolved})" + )) + ); + + Ok(AuthConfig::OAuth { token_ref, scopes }) +} + +/// Bind a localhost TCP listener on an OS-assigned port. We use the standard +/// blocking listener (consumed once on the calling thread) because the OAuth +/// callback happens exactly once per flow and we want zero async runtime +/// dependencies for `tiny_http`. +fn bind_callback_listener() -> Result { + // Bind to 127.0.0.1:0 — the OS picks a free ephemeral port. The localhost + // requirement matches what most OAuth servers will allow without explicit + // pre-registration when DCR isn't available. + let listener = StdTcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)) + .context("Failed to bind localhost callback listener")?; + listener + .set_nonblocking(false) + .context("Failed to set listener to blocking mode")?; + Ok(listener) +} + +struct CallbackResult { + code: String, +} + +/// Wait for the OAuth provider to redirect the user back to our localhost +/// listener. Validates the `state` parameter and returns the authorization +/// `code`. Times out after 5 minutes. +fn wait_for_callback(listener: StdTcpListener, expected_state: &str) -> Result { + let server = tiny_http::Server::from_listener(listener, None) + .map_err(|e| anyhow::anyhow!("Failed to start callback HTTP server: {e}"))?; + + let timeout = Duration::from_secs(300); + let request = server + .recv_timeout(timeout) + .map_err(|e| anyhow::anyhow!("Callback listener error: {e}"))? + .ok_or_else(|| anyhow::anyhow!("Timed out waiting for OAuth redirect after 5 minutes"))?; + + // Parse query string from request URL like "/callback?code=...&state=..." + let url_str = format!("http://127.0.0.1{}", request.url()); + let parsed = url::Url::parse(&url_str) + .with_context(|| format!("Invalid callback URL from browser: {}", request.url()))?; + let mut code: Option = None; + let mut state: Option = None; + let mut error: Option = None; + for (k, v) in parsed.query_pairs() { + match k.as_ref() { + "code" => code = Some(v.into_owned()), + "state" => state = Some(v.into_owned()), + "error" => error = Some(v.into_owned()), + _ => {} + } + } + + let html_body = |success: bool, message: &str| -> String { + let title = if success { + "pctx: success" + } else { + "pctx: error" + }; + format!( + "{title}\ + \ +

{title}

{message}

\ +

You can close this tab and return to your terminal.

", + color = if success { "#0a7" } else { "#c33" } + ) + }; + + if let Some(err) = error { + let body = html_body(false, &format!("Authorization failed: {err}")); + let _ = request.respond(html_response(&body)); + anyhow::bail!("OAuth provider returned error: {err}"); + } + + let state = state.context("OAuth callback missing 'state' parameter")?; + if state != expected_state { + let body = html_body(false, "State mismatch — possible CSRF attempt."); + let _ = request.respond(html_response(&body)); + anyhow::bail!("OAuth state mismatch: expected '{expected_state}', got '{state}'"); + } + + let code = code.context("OAuth callback missing 'code' parameter")?; + + let body = html_body( + true, + "Authorization successful. pctx has stored your tokens securely.", + ); + let _ = request.respond(html_response(&body)); + + Ok(CallbackResult { code }) +} + +fn html_response(body: &str) -> tiny_http::Response>> { + let bytes = body.as_bytes().to_vec(); + let len = bytes.len(); + tiny_http::Response::new( + tiny_http::StatusCode(200), + vec![ + tiny_http::Header::from_bytes(&b"Content-Type"[..], &b"text/html; charset=utf-8"[..]) + .expect("static header is valid"), + ], + std::io::Cursor::new(bytes), + Some(len), + None, + ) +} + +/// Either dynamically register a new OAuth client (RFC 7591) or fall back to +/// prompting the user for a pre-registered `client_id` / `client_secret`. +async fn obtain_client_credentials( + metadata: &OAuthMetadata, + server_name: &str, + redirect_uri: &url::Url, +) -> Result<(String, Option)> { + if let Some(reg) = &metadata.registration_endpoint { + debug!("Attempting RFC 7591 dynamic client registration at {reg}"); + let scopes = pick_scopes(metadata); + match oauth2::dynamic_register( + reg, + &[redirect_uri.to_string()], + &format!("pctx ({server_name})"), + &scopes, + ) + .await + { + Ok((id, secret)) => { + info!( + "{}", + fmt_good_check("Dynamically registered OAuth client with upstream server") + ); + return Ok((id, secret)); + } + Err(e) => { + warn!("Dynamic client registration failed: {e}; falling back to manual entry"); + } + } + } + + // Manual fallback — prompt the user. + info!( + "{}", + fmt_dimmed( + "This server does not support dynamic client registration. \ + Enter the OAuth client_id you have pre-registered with the server." + ) + ); + info!( + "{}", + fmt_dimmed(&format!("Use redirect URI: {redirect_uri}")) + ); + let client_id = inquire::Text::new("OAuth client_id:") + .with_validator(inquire::min_length!( + 1, + "client_id must be at least 1 character" + )) + .prompt()?; + let client_secret = inquire::Text::new("OAuth client_secret (leave blank for public client):") + .prompt() + .ok() + .filter(|s| !s.is_empty()); + Ok((client_id, client_secret)) +} diff --git a/crates/pctx_config/Cargo.toml b/crates/pctx_config/Cargo.toml index a349a9b..6cdce07 100644 --- a/crates/pctx_config/Cargo.toml +++ b/crates/pctx_config/Cargo.toml @@ -51,6 +51,8 @@ opentelemetry-otlp = { workspace = true, features = [ "trace", ] } base64 = "0.22" +sha2 = "0.10" +getrandom = "0.2" tonic = "0.14" opentelemetry_sdk = { workspace = true } shlex = { workspace = true } diff --git a/crates/pctx_config/src/auth.rs b/crates/pctx_config/src/auth.rs index d0bf6f4..b2ac07f 100644 --- a/crates/pctx_config/src/auth.rs +++ b/crates/pctx_config/src/auth.rs @@ -16,16 +16,32 @@ pub enum AuthConfig { Headers { headers: IndexMap, }, - // TODO: support OAuth client credentials flow? - // /// OAuth 2.1 Client Credentials Flow (machine-to-machine) - // #[serde(rename = "oauth_client_credentials")] - // OAuthClientCredentials { - // client_id: SecretString, - // client_secret: SecretString, - // token_url: url::Url, - // #[serde(skip_serializing_if = "Option::is_none")] - // scope: Option, - // }, + /// OAuth 2.1 Authorization Code + PKCE flow. + /// + /// All actual credentials (access token, refresh token, `client_id` / + /// `client_secret`, token endpoint) live in the system keychain under + /// `token_ref`; only non-secret metadata is persisted in `pctx.json`. + /// `token_ref` is opaque — pctx generates one when the user runs + /// `pctx mcp add` against an OAuth-protected server. + #[serde(rename = "oauth")] + OAuth { + /// Keychain key holding the JSON-serialized [`crate::oauth2::TokenBundle`]. + token_ref: SecretString, + /// Scopes that were granted at authorization time. Persisted so that + /// `pctx mcp add` can re-authorize without prompting again, and so + /// that observers reading `pctx.json` can see what access pctx has. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + scopes: Vec, + }, +} + +impl AuthConfig { + /// Default keychain ref pctx uses for a server's OAuth token bundle. + /// Stable so that re-running `pctx mcp add` for the same name overwrites + /// the same entry instead of leaking old ones. + pub fn default_oauth_token_ref(server_name: &str) -> SecretString { + SecretString::new_plain(&format!("oauth:{server_name}")) + } } /// A string that may contain 0 or more embedded secrets diff --git a/crates/pctx_config/src/lib.rs b/crates/pctx_config/src/lib.rs index 2a5b672..f8b70e6 100644 --- a/crates/pctx_config/src/lib.rs +++ b/crates/pctx_config/src/lib.rs @@ -10,6 +10,7 @@ use crate::{logger::LoggerConfig, server::ServerConfig, telemetry::TelemetryConf pub mod auth; pub(crate) mod defaults; pub mod logger; +pub mod oauth2; pub mod server; pub mod telemetry; diff --git a/crates/pctx_config/src/oauth2.rs b/crates/pctx_config/src/oauth2.rs new file mode 100644 index 0000000..d20814d --- /dev/null +++ b/crates/pctx_config/src/oauth2.rs @@ -0,0 +1,609 @@ +//! OAuth 2.1 / RFC 8414 / RFC 9728 / RFC 7591 support for upstream MCP servers. +//! +//! This module provides the protocol building blocks (discovery, PKCE, token +//! exchange, refresh, and keychain-backed token storage) used by the CLI to +//! drive an interactive browser-based authorization flow and by +//! [`crate::server::ServerConfig::connect`] to apply tokens at request time. +//! +//! The interactive browser/callback orchestration lives in the CLI crate +//! (`pctx::utils::oauth_flow`) — this module deliberately stays +//! transport-agnostic so it can be reused outside the CLI. + +use std::time::{SystemTime, UNIX_EPOCH}; + +use anyhow::{Context, Result}; +use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use tracing::debug; + +/// Authorization-server metadata as defined by RFC 8414 / `OpenID` Connect +/// Discovery. Only the fields pctx actually uses are kept. +#[derive(Debug, Clone, Deserialize)] +pub struct OAuthMetadata { + pub issuer: Option, + pub authorization_endpoint: url::Url, + pub token_endpoint: url::Url, + #[serde(default)] + pub registration_endpoint: Option, + #[serde(default)] + pub scopes_supported: Vec, + #[serde(default)] + pub code_challenge_methods_supported: Vec, +} + +/// RFC 9728 protected-resource metadata. The MCP server publishes this at +/// `/.well-known/oauth-protected-resource` to point clients at the +/// authorization server(s) it trusts. +#[derive(Debug, Clone, Deserialize)] +struct ProtectedResourceMetadata { + #[serde(default)] + authorization_servers: Vec, +} + +/// Persistent OAuth credentials for a single upstream server. Stored as JSON +/// in the system keychain under a single key — never written to `pctx.json`. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenBundle { + pub access_token: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub refresh_token: Option, + #[serde(default = "default_token_type")] + pub token_type: String, + /// Unix-epoch seconds when the access token expires. `0` means unknown + /// (assume valid forever / refresh on 401). + #[serde(default)] + pub expires_at: u64, + pub token_endpoint: url::Url, + pub client_id: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub client_secret: Option, +} + +fn default_token_type() -> String { + "Bearer".into() +} + +impl TokenBundle { + /// Refresh-skew: refresh proactively if token expires within this many + /// seconds. + pub const REFRESH_SKEW_SECS: u64 = 60; + + pub fn is_expired(&self) -> bool { + if self.expires_at == 0 { + return false; + } + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0); + now + Self::REFRESH_SKEW_SECS >= self.expires_at + } + + /// Load a token bundle from the system keychain. + /// + /// # Errors + /// Returns an error if no entry exists for `token_ref`, the keychain is + /// inaccessible, or the stored value cannot be deserialized. + pub fn load(token_ref: &str) -> Result { + let entry = keyring::Entry::new("pctx", token_ref) + .context("Failed to create keychain entry for OAuth token")?; + let json = entry + .get_password() + .with_context(|| format!("No OAuth token in keychain for '{token_ref}'"))?; + serde_json::from_str(&json).context("Failed to parse OAuth token bundle from keychain") + } + + /// Save (or overwrite) the token bundle in the system keychain. + /// + /// # Errors + /// Returns an error if the bundle cannot be serialized or the keychain + /// rejects the write. + pub fn save(&self, token_ref: &str) -> Result<()> { + let entry = keyring::Entry::new("pctx", token_ref) + .context("Failed to create keychain entry for OAuth token")?; + let json = serde_json::to_string(self).context("Failed to serialize OAuth token bundle")?; + entry + .set_password(&json) + .context("Failed to store OAuth token bundle in keychain")?; + debug!("OAuth token bundle saved to keychain ref={token_ref}"); + Ok(()) + } + + /// Delete the token bundle from the keychain. No-op if not present. + /// + /// # Errors + /// Returns an error only on unexpected keychain failures (a missing + /// entry is treated as success). + pub fn delete(token_ref: &str) -> Result<()> { + let entry = keyring::Entry::new("pctx", token_ref) + .context("Failed to create keychain entry for OAuth token")?; + match entry.delete_credential() { + Ok(()) | Err(keyring::Error::NoEntry) => Ok(()), + Err(e) => Err(anyhow::anyhow!(e)), + } + } +} + +/// Try to discover OAuth metadata for an MCP server URL. +/// +/// Tries (in order): +/// 1. RFC 9728 protected-resource metadata at +/// `/.well-known/oauth-protected-resource` — if present, follow +/// its `authorization_servers[0]` and recurse. +/// 2. RFC 8414 authorization-server metadata at +/// `/.well-known/oauth-authorization-server`. +/// 3. `OpenID` Connect discovery at `/.well-known/openid-configuration`. +/// +/// Returns `Ok(None)` if none of those return parseable JSON — i.e. the +/// server is not OAuth-protected. +/// +/// # Errors +/// Returns an error only on unexpected transport failures (DNS, TLS). A +/// 404 / non-JSON response is treated as "not OAuth" and returns `Ok(None)`. +pub async fn discover(server_url: &url::Url) -> Result> { + let client = reqwest::Client::builder() + .build() + .context("Failed to build HTTP client for OAuth discovery")?; + + // 1. RFC 9728 — protected resource metadata + let pr_url = well_known(server_url, "oauth-protected-resource"); + if let Some(meta) = fetch_json::(&client, &pr_url).await? + && let Some(auth_server) = meta.authorization_servers.first() + { + debug!("Discovered protected-resource metadata pointing at {auth_server}"); + if let Some(m) = discover_auth_server(&client, auth_server).await? { + return Ok(Some(m)); + } + } + + // 2/3. Try the resource origin itself as the auth server. + discover_auth_server(&client, server_url).await +} + +async fn discover_auth_server( + client: &reqwest::Client, + base: &url::Url, +) -> Result> { + let as_url = well_known(base, "oauth-authorization-server"); + if let Some(meta) = fetch_json::(client, &as_url).await? { + return Ok(Some(meta)); + } + let oidc_url = well_known(base, "openid-configuration"); + fetch_json::(client, &oidc_url).await +} + +/// Build a `:///.well-known/` URL preserving the origin +/// of `base` and discarding any path/query. +fn well_known(base: &url::Url, name: &str) -> url::Url { + let mut u = base.clone(); + u.set_path(&format!("/.well-known/{name}")); + u.set_query(None); + u.set_fragment(None); + u +} + +async fn fetch_json( + client: &reqwest::Client, + url: &url::Url, +) -> Result> { + debug!("OAuth discovery: GET {url}"); + let resp = match client.get(url.clone()).send().await { + Ok(r) => r, + Err(e) => { + debug!("Discovery fetch failed for {url}: {e}"); + return Ok(None); + } + }; + if !resp.status().is_success() { + return Ok(None); + } + match resp.json::().await { + Ok(t) => Ok(Some(t)), + Err(e) => { + debug!("Discovery body parse failed for {url}: {e}"); + Ok(None) + } + } +} + +// === RFC 7591 Dynamic Client Registration =========================== + +#[derive(Debug, Serialize)] +struct DcrRequest<'a> { + redirect_uris: &'a [String], + client_name: &'a str, + grant_types: [&'a str; 2], + response_types: [&'a str; 1], + token_endpoint_auth_method: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + scope: Option, +} + +#[derive(Debug, Deserialize)] +struct DcrResponse { + client_id: String, + #[serde(default)] + client_secret: Option, +} + +/// Best-effort RFC 7591 dynamic client registration. Returns `(client_id, +/// client_secret)` on success. +/// +/// # Errors +/// Returns an error only on transport / non-success responses. Callers should +/// be prepared to fall back to prompting the user for a pre-registered +/// `client_id` when this fails. +pub async fn dynamic_register( + registration_endpoint: &url::Url, + redirect_uris: &[String], + client_name: &str, + scopes: &[String], +) -> Result<(String, Option)> { + let body = DcrRequest { + redirect_uris, + client_name, + grant_types: ["authorization_code", "refresh_token"], + response_types: ["code"], + // Public client (no client secret); we'll switch to client_secret_basic + // automatically if the server returns one. + token_endpoint_auth_method: "none", + scope: if scopes.is_empty() { + None + } else { + Some(scopes.join(" ")) + }, + }; + + let client = reqwest::Client::new(); + let resp = client + .post(registration_endpoint.clone()) + .json(&body) + .send() + .await + .context("Dynamic client registration request failed")?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + anyhow::bail!("Dynamic client registration failed: {status}: {text}"); + } + + let parsed: DcrResponse = resp + .json() + .await + .context("Dynamic client registration returned invalid JSON")?; + Ok((parsed.client_id, parsed.client_secret)) +} + +// === PKCE helpers ===================================================== + +/// A freshly generated PKCE code verifier + challenge pair. +#[derive(Debug, Clone)] +pub struct Pkce { + pub verifier: String, + pub challenge: String, +} + +impl Pkce { + /// Generate a fresh PKCE pair using SHA-256 (S256). + /// + /// # Errors + /// Returns an error if the OS RNG is unavailable. + pub fn generate() -> Result { + let verifier = random_url_safe(32)?; + let digest = Sha256::digest(verifier.as_bytes()); + let challenge = URL_SAFE_NO_PAD.encode(digest); + Ok(Self { + verifier, + challenge, + }) + } +} + +/// Generate `n` random bytes and return them base64url-encoded (no padding). +/// Suitable for OAuth `state` and PKCE `code_verifier`. +/// +/// # Errors +/// Returns an error if the OS RNG is unavailable. +pub fn random_url_safe(n: usize) -> Result { + let mut buf = vec![0u8; n]; + getrandom::getrandom(&mut buf).map_err(|e| anyhow::anyhow!("OS RNG unavailable: {e}"))?; + Ok(URL_SAFE_NO_PAD.encode(&buf)) +} + +/// URL-encode a list of form pairs (`application/x-www-form-urlencoded`). +/// We do this by hand because the `reqwest` `RequestBuilder::form` helper +/// isn't available with the feature set `pctx_config` enables. +fn url_encoded_body(pairs: &[(&str, &str)]) -> String { + let mut s = url::form_urlencoded::Serializer::new(String::new()); + for (k, v) in pairs { + s.append_pair(k, v); + } + s.finish() +} + +const FORM_CONTENT_TYPE: &str = "application/x-www-form-urlencoded"; + +/// Build the authorization-endpoint URL the user's browser should be sent to. +pub fn build_authorize_url( + metadata: &OAuthMetadata, + client_id: &str, + redirect_uri: &url::Url, + scopes: &[String], + state: &str, + pkce: &Pkce, + resource: Option<&url::Url>, +) -> url::Url { + let mut u = metadata.authorization_endpoint.clone(); + { + let mut q = u.query_pairs_mut(); + q.append_pair("response_type", "code"); + q.append_pair("client_id", client_id); + q.append_pair("redirect_uri", redirect_uri.as_str()); + q.append_pair("state", state); + q.append_pair("code_challenge", &pkce.challenge); + q.append_pair("code_challenge_method", "S256"); + if !scopes.is_empty() { + q.append_pair("scope", &scopes.join(" ")); + } + if let Some(res) = resource { + // RFC 8707 resource indicator + q.append_pair("resource", res.as_str()); + } + } + u +} + +// === Token endpoint exchanges ========================================= + +#[derive(Debug, Deserialize)] +struct TokenResponse { + access_token: String, + #[serde(default)] + refresh_token: Option, + #[serde(default = "default_token_type")] + token_type: String, + #[serde(default)] + expires_in: Option, +} + +fn now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0) +} + +fn token_response_into_bundle( + resp: TokenResponse, + token_endpoint: url::Url, + client_id: String, + client_secret: Option, + fallback_refresh: Option, +) -> TokenBundle { + let expires_at = resp.expires_in.map_or(0, |s| now_secs() + s); + TokenBundle { + access_token: resp.access_token, + refresh_token: resp.refresh_token.or(fallback_refresh), + token_type: resp.token_type, + expires_at, + token_endpoint, + client_id, + client_secret, + } +} + +/// Exchange an authorization code (returned via the redirect URI) for tokens. +/// +/// # Errors +/// Returns an error if the token endpoint is unreachable, returns a +/// non-success status, or returns a body that does not parse as a token +/// response. +pub async fn exchange_code( + metadata: &OAuthMetadata, + client_id: &str, + client_secret: Option<&str>, + code: &str, + code_verifier: &str, + redirect_uri: &url::Url, + resource: Option<&url::Url>, +) -> Result { + let redirect_uri_str = redirect_uri.to_string(); + let resource_str = resource.map(url::Url::to_string); + let mut pairs: Vec<(&str, &str)> = vec![ + ("grant_type", "authorization_code"), + ("code", code), + ("redirect_uri", redirect_uri_str.as_str()), + ("client_id", client_id), + ("code_verifier", code_verifier), + ]; + if let Some(res) = resource_str.as_deref() { + pairs.push(("resource", res)); + } + let body = url_encoded_body(&pairs); + + let client = reqwest::Client::new(); + let mut req = client + .post(metadata.token_endpoint.clone()) + .header(http::header::CONTENT_TYPE, FORM_CONTENT_TYPE) + .body(body); + if let Some(secret) = client_secret { + req = req.basic_auth(client_id, Some(secret)); + } + let resp = req + .send() + .await + .context("OAuth token exchange request failed")?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + anyhow::bail!("OAuth token exchange failed: {status}: {text}"); + } + + let parsed: TokenResponse = resp + .json() + .await + .context("OAuth token endpoint returned invalid JSON")?; + Ok(token_response_into_bundle( + parsed, + metadata.token_endpoint.clone(), + client_id.into(), + client_secret.map(str::to_string), + None, + )) +} + +/// Use a refresh token to obtain a fresh access (and possibly refresh) token. +/// +/// # Errors +/// Returns an error if the bundle has no refresh token, the token endpoint +/// is unreachable, or it returns a non-success response. +pub async fn refresh(bundle: &TokenBundle) -> Result { + let refresh_token = bundle + .refresh_token + .as_deref() + .context("OAuth bundle has no refresh token; re-run `pctx mcp add` to re-authorize")?; + + let pairs = [ + ("grant_type", "refresh_token"), + ("refresh_token", refresh_token), + ("client_id", bundle.client_id.as_str()), + ]; + let body = url_encoded_body(&pairs); + + let client = reqwest::Client::new(); + let mut req = client + .post(bundle.token_endpoint.clone()) + .header(http::header::CONTENT_TYPE, FORM_CONTENT_TYPE) + .body(body); + if let Some(secret) = bundle.client_secret.as_deref() { + req = req.basic_auth(&bundle.client_id, Some(secret)); + } + let resp = req.send().await.context("OAuth refresh request failed")?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + anyhow::bail!("OAuth refresh failed: {status}: {text}"); + } + + let parsed: TokenResponse = resp + .json() + .await + .context("OAuth refresh returned invalid JSON")?; + Ok(token_response_into_bundle( + parsed, + bundle.token_endpoint.clone(), + bundle.client_id.clone(), + bundle.client_secret.clone(), + // Some servers don't return a new refresh token — keep the old one. + bundle.refresh_token.clone(), + )) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn pkce_pair_is_valid_s256() { + let p = Pkce::generate().unwrap(); + // verifier base64url-encoded 32 bytes => 43 chars + assert_eq!(p.verifier.len(), 43); + // challenge is sha256(verifier) base64url + let expected = URL_SAFE_NO_PAD.encode(Sha256::digest(p.verifier.as_bytes())); + assert_eq!(p.challenge, expected); + } + + #[test] + fn well_known_strips_path_and_query() { + let base: url::Url = "https://mcp.example.com/sse?token=foo".parse().unwrap(); + let w = well_known(&base, "oauth-authorization-server"); + assert_eq!( + w.as_str(), + "https://mcp.example.com/.well-known/oauth-authorization-server" + ); + } + + #[test] + fn token_bundle_round_trip() { + let bundle = TokenBundle { + access_token: "at".into(), + refresh_token: Some("rt".into()), + token_type: "Bearer".into(), + expires_at: 1_700_000_000, + token_endpoint: "https://issuer.example.com/token".parse().unwrap(), + client_id: "client123".into(), + client_secret: None, + }; + let json = serde_json::to_string(&bundle).unwrap(); + let parsed: TokenBundle = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.access_token, "at"); + assert_eq!(parsed.client_id, "client123"); + assert_eq!(parsed.expires_at, 1_700_000_000); + } + + #[test] + fn token_bundle_expiry() { + let bundle = TokenBundle { + access_token: "at".into(), + refresh_token: None, + token_type: "Bearer".into(), + expires_at: 1, // ancient + token_endpoint: "https://x/token".parse().unwrap(), + client_id: "c".into(), + client_secret: None, + }; + assert!(bundle.is_expired()); + + let unknown = TokenBundle { + expires_at: 0, + ..bundle.clone() + }; + assert!(!unknown.is_expired()); + + let future = TokenBundle { + expires_at: now_secs() + 3600, + ..bundle + }; + assert!(!future.is_expired()); + } + + #[test] + fn build_authorize_url_includes_pkce_and_state() { + let metadata = OAuthMetadata { + issuer: None, + authorization_endpoint: "https://issuer/authorize".parse().unwrap(), + token_endpoint: "https://issuer/token".parse().unwrap(), + registration_endpoint: None, + scopes_supported: vec![], + code_challenge_methods_supported: vec!["S256".into()], + }; + let pkce = Pkce::generate().unwrap(); + let url = build_authorize_url( + &metadata, + "myclient", + &"http://127.0.0.1:8765/callback".parse().unwrap(), + &["read".into(), "write".into()], + "abc-state", + &pkce, + Some(&"https://mcp.example.com".parse().unwrap()), + ); + let q: std::collections::HashMap<_, _> = url.query_pairs().into_owned().collect(); + assert_eq!(q.get("response_type").map(String::as_str), Some("code")); + assert_eq!(q.get("client_id").map(String::as_str), Some("myclient")); + assert_eq!( + q.get("code_challenge_method").map(String::as_str), + Some("S256") + ); + assert_eq!(q.get("code_challenge"), Some(&pkce.challenge)); + assert_eq!(q.get("scope").map(String::as_str), Some("read write")); + assert_eq!(q.get("state").map(String::as_str), Some("abc-state")); + assert_eq!( + q.get("resource").map(String::as_str), + Some("https://mcp.example.com/") + ); + } +} diff --git a/crates/pctx_config/src/server.rs b/crates/pctx_config/src/server.rs index 68c376b..127d221 100644 --- a/crates/pctx_config/src/server.rs +++ b/crates/pctx_config/src/server.rs @@ -16,7 +16,47 @@ use tokio::process::Command; pub use rmcp::ServiceError; -use super::auth::AuthConfig; +use super::auth::{AuthConfig, SecretString}; +use crate::oauth2::{self, TokenBundle}; + +/// Load the OAuth token bundle from the keychain, refreshing it (and writing +/// the new bundle back) if it's expired or about to expire. +async fn resolve_oauth_access_token(token_ref: &str) -> Result { + let bundle = + TokenBundle::load(token_ref).map_err(|e| McpConnectionError::Failed(e.to_string()))?; + + if !bundle.is_expired() { + return Ok(bundle.access_token); + } + + tracing::debug!("OAuth access token expired, refreshing (token_ref={token_ref})"); + let refreshed = oauth2::refresh(&bundle) + .await + .map_err(|e| McpConnectionError::Failed(format!("OAuth refresh failed: {e}")))?; + refreshed + .save(token_ref) + .map_err(|e| McpConnectionError::Failed(e.to_string()))?; + Ok(refreshed.access_token) +} + +/// Force a refresh of the OAuth token bundle for `token_ref`, regardless of +/// expiry. Used by the connection pool when a cached connection returns an +/// auth error mid-session. +/// +/// # Errors +/// Returns an error if the keychain entry is missing, the refresh request +/// fails, or the new bundle cannot be persisted back to the keychain. +pub async fn force_refresh_oauth_token(token_ref: &str) -> Result<(), McpConnectionError> { + let bundle = + TokenBundle::load(token_ref).map_err(|e| McpConnectionError::Failed(e.to_string()))?; + let refreshed = oauth2::refresh(&bundle) + .await + .map_err(|e| McpConnectionError::Failed(format!("OAuth refresh failed: {e}")))?; + refreshed + .save(token_ref) + .map_err(|e| McpConnectionError::Failed(e.to_string()))?; + Ok(()) +} #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ServerConfig { @@ -97,6 +137,20 @@ impl ServerConfig { } } + /// If this is an HTTP server configured with OAuth, returns the + /// `token_ref` [`SecretString`] that resolves to the keychain key holding + /// its [`crate::oauth2::TokenBundle`]. Returns `None` for stdio servers, + /// unauthenticated HTTP servers, or HTTP servers using bearer / header auth. + pub fn oauth_token_ref(&self) -> Option<&SecretString> { + match &self.transport { + ServerTransport::Http(http_cfg) => match http_cfg.auth.as_ref()? { + AuthConfig::OAuth { token_ref, .. } => Some(token_ref), + _ => None, + }, + ServerTransport::Stdio(_) => None, + } + } + pub fn display_target(&self) -> String { match &self.transport { ServerTransport::Http(cfg) => cfg.url.to_string(), @@ -159,6 +213,18 @@ impl ServerConfig { ); } } + AuthConfig::OAuth { token_ref, .. } => { + let resolved_ref = token_ref + .resolve() + .await + .map_err(|e| McpConnectionError::Failed(e.to_string()))?; + let access = resolve_oauth_access_token(&resolved_ref).await?; + default_headers.append( + http::header::AUTHORIZATION, + HeaderValue::from_str(&format!("Bearer {access}")) + .map_err(|e| McpConnectionError::Failed(e.to_string()))?, + ); + } } } diff --git a/crates/pctx_registry/src/connection_pool.rs b/crates/pctx_registry/src/connection_pool.rs index d501684..2e20bfd 100644 --- a/crates/pctx_registry/src/connection_pool.rs +++ b/crates/pctx_registry/src/connection_pool.rs @@ -2,10 +2,10 @@ use std::{collections::HashMap, sync::Arc}; use rmcp::{RoleClient, model::InitializeRequestParams, service::RunningService}; use tokio::sync::RwLock; -use tracing::debug; +use tracing::{debug, warn}; use crate::error::RegistryError; -use pctx_config::server::ServerConfig; +use pctx_config::server::{self, ServerConfig}; type PooledClient = Arc>; @@ -72,6 +72,54 @@ impl McpConnectionPool { Ok((new_client, false)) } + /// Force an OAuth token refresh for `cfg`, evict any cached connection + /// for it, and re-establish the connection with the new access token. + /// + /// This is the recovery path for OAuth-authed upstream MCPs whose access + /// token has expired mid-session: callers (executor / tool dispatchers) + /// should invoke this once after they observe an auth-shaped failure + /// from a cached connection, then retry the original request with the + /// returned client. For non-OAuth servers this is a no-op that returns + /// `Ok(None)`, signalling that no refresh recovery is possible. + /// + /// Refresh is intentionally targeted (only OAuth servers) so that bearer + /// / header configurations keep their existing failure semantics — we + /// don't want to mask a real misconfiguration as a transient blip. + /// + /// # Errors + /// Returns an error if the OAuth refresh request fails or if the + /// follow-up reconnect fails. + pub async fn refresh_oauth_and_reconnect( + &self, + cfg: &ServerConfig, + ) -> Result, RegistryError> { + let Some(token_ref_secret) = cfg.oauth_token_ref() else { + return Ok(None); + }; + let token_ref = token_ref_secret + .resolve() + .await + .map_err(|e| RegistryError::Connection(e.to_string()))?; + + debug!(server = %cfg.name, "Forcing OAuth token refresh and reconnect"); + if let Err(e) = server::force_refresh_oauth_token(&token_ref).await { + warn!(server = %cfg.name, error = %e, "OAuth refresh failed"); + return Err(RegistryError::from(e)); + } + + // Evict any cached connection so the next get_or_connect rebuilds + // the transport with the freshly-issued access token. + { + let mut connections = self.connections.write().await; + if let Some(prev) = connections.remove(&cfg.name) { + prev.cancellation_token().cancel(); + } + } + + let (client, _cached) = self.get_or_connect(cfg).await?; + Ok(Some(client)) + } + /// Cancels and removes all active upstream connections. /// /// Ongoing in-flight requests will complete or fail as the underlying diff --git a/docs/config.md b/docs/config.md index 35347dc..c76c106 100644 --- a/docs/config.md +++ b/docs/config.md @@ -121,7 +121,7 @@ await slack.sendMessage({ channel: "#general", text: "hi" }); ## Authentication -The `auth` field supports two types of authentication `BearerToken | Custom`: +The `auth` field supports three types of authentication: `bearer`, `headers`, and `oauth`. ### Bearer Token Authentication @@ -162,6 +162,69 @@ This adds an `Authorization: Bearer ` header to all requests. Use this for API key authentication or any custom header requirements. +### OAuth 2.1 Authentication + +`pctx` supports OAuth 2.1 Authorization Code + PKCE for upstream MCP servers +that advertise OAuth metadata via [RFC 9728](https://datatracker.ietf.org/doc/html/rfc9728) +(`/.well-known/oauth-protected-resource`), [RFC 8414](https://datatracker.ietf.org/doc/html/rfc8414) +(`/.well-known/oauth-authorization-server`), or OpenID Connect Discovery +(`/.well-known/openid-configuration`). + +The recommended way to set up OAuth is to let `pctx mcp add` drive the flow: + +```bash +pctx mcp add my-server https://mcp.example.com/sse +``` + +When the URL is reachable, `pctx` automatically tries OAuth discovery. If the +server advertises OAuth metadata, `pctx`: + +1. Performs RFC 7591 dynamic client registration if the auth server supports + it (otherwise prompts you for a pre-registered `client_id`). +2. Spins up a one-shot localhost callback listener and opens your browser to + the authorization endpoint. +3. Exchanges the resulting code (with PKCE) for access + refresh tokens. +4. Stores the token bundle in your **system keychain** under a stable + `token_ref`. Nothing secret is ever written to `pctx.json`. + +You can force the OAuth flow with `--oauth` even if you've already configured +another auth type, and it overrides any auto-detection. + +| Field | Type | Required | Description | +| ----------- | --------------- | -------- | -------------------------------------------------------------------------------------------------------------------------- | +| `type` | `"oauth"` | Yes | Constant designating this object as an OAuth config | +| `token_ref` | `string` | Yes | Opaque keychain entry name where pctx stores the access / refresh token bundle for this server | +| `scopes` | `array[string]` | No | Scopes that were granted at authorization time (informational; pctx writes this so the granted access is visible in-file) | + +**Example `pctx.json` excerpt:** + +```json +{ + "name": "my-server", + "url": "https://mcp.example.com/sse", + "auth": { + "type": "oauth", + "token_ref": "oauth:my-server", + "scopes": ["read", "write"] + } +} +``` + +At connection time, `pctx`: + +- Loads the bundle from the keychain entry named in `token_ref`. +- Refreshes the access token automatically if it has expired (or is within + 60 seconds of expiry), persisting the new bundle back to the keychain. +- Sends the access token as `Authorization: Bearer ` on every request + to the upstream MCP server. + +If a long-lived session sees an auth-shaped failure mid-stream, `pctx` will +force-refresh the token once and retry the connection before surfacing the +error. + +To re-authorize (e.g. after revocation or scope changes), re-run +`pctx mcp add my-server --oauth --force`. + ## Logger Configuration The optional `logger` field controls logging behavior for the pctx server MPC server. This configuration applies