From 65a6a1c10924de91ddca71cbf75fe26363ab0501 Mon Sep 17 00:00:00 2001 From: centdix Date: Sat, 10 Jan 2026 16:46:04 +0100 Subject: [PATCH 1/2] feat(auth): add StateStore trait for pluggable OAuth state storage --- crates/rmcp/src/transport.rs | 6 +- crates/rmcp/src/transport/auth.rs | 336 ++++++++++++++++++++++++++++-- 2 files changed, 319 insertions(+), 23 deletions(-) diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index aeb8c795..bf9e7464 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -99,7 +99,11 @@ pub use io::stdio; pub mod auth; #[cfg(feature = "auth")] #[cfg_attr(docsrs, doc(cfg(feature = "auth")))] -pub use auth::{AuthError, AuthorizationManager, AuthorizationSession, AuthorizedHttpClient}; +pub use auth::{ + AuthError, AuthorizationManager, AuthorizationSession, AuthorizedHttpClient, CredentialStore, + InMemoryCredentialStore, InMemoryStateStore, StateStore, StoredAuthorizationState, + StoredCredentials, +}; // #[cfg(feature = "transport-ws")] // #[cfg_attr(docsrs, doc(cfg(feature = "transport-ws")))] diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index 91b8cad0..6a5567f4 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -74,6 +74,93 @@ impl CredentialStore for InMemoryCredentialStore { } } +/// Stored authorization state for OAuth2 PKCE flow +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StoredAuthorizationState { + pub pkce_verifier: String, + pub csrf_token: String, + pub created_at: u64, +} + +impl StoredAuthorizationState { + pub fn new(pkce_verifier: &PkceCodeVerifier, csrf_token: &CsrfToken) -> Self { + Self { + pkce_verifier: pkce_verifier.secret().to_string(), + csrf_token: csrf_token.secret().to_string(), + created_at: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0), + } + } + + pub fn into_pkce_verifier(self) -> PkceCodeVerifier { + PkceCodeVerifier::new(self.pkce_verifier) + } +} + +/// Trait for storing and retrieving OAuth2 authorization state +/// +/// Implementations of this trait can provide custom storage backends +/// for OAuth2 PKCE flow state, such as Redis or database storage. +/// +/// Implementors are responsible for expiring stale states (e.g., abandoned +/// authorization flows). Use [`StoredAuthorizationState::created_at`] for +/// TTL-based expiration. +#[async_trait] +pub trait StateStore: Send + Sync { + async fn save( + &self, + csrf_token: &str, + state: StoredAuthorizationState, + ) -> Result<(), AuthError>; + + async fn load(&self, csrf_token: &str) -> Result, AuthError>; + + async fn delete(&self, csrf_token: &str) -> Result<(), AuthError>; +} + +/// In-memory state store (default implementation) +/// +/// This store keeps authorization state in memory only and does not persist +/// between application restarts or across multiple server instances. +#[derive(Debug, Default, Clone)] +pub struct InMemoryStateStore { + states: Arc>>, +} + +impl InMemoryStateStore { + pub fn new() -> Self { + Self { + states: Arc::new(RwLock::new(HashMap::new())), + } + } +} + +#[async_trait] +impl StateStore for InMemoryStateStore { + async fn save( + &self, + csrf_token: &str, + state: StoredAuthorizationState, + ) -> Result<(), AuthError> { + self.states + .write() + .await + .insert(csrf_token.to_string(), state); + Ok(()) + } + + async fn load(&self, csrf_token: &str) -> Result, AuthError> { + Ok(self.states.read().await.get(csrf_token).cloned()) + } + + async fn delete(&self, csrf_token: &str) -> Result<(), AuthError> { + self.states.write().await.remove(csrf_token); + Ok(()) + } +} + /// HTTP client with OAuth 2.0 authorization #[derive(Clone)] pub struct AuthClient { @@ -210,7 +297,7 @@ pub struct AuthorizationManager { metadata: Option, oauth_client: Option, credential_store: Arc, - state: RwLock>, + state_store: Arc, base_url: Url, } @@ -234,12 +321,6 @@ pub struct ClientRegistrationResponse { pub additional_fields: HashMap, } -#[derive(Debug)] -struct AuthorizationState { - pkce_verifier: PkceCodeVerifier, - csrf_token: CsrfToken, -} - /// SEP-991: URL-based Client IDs /// Validate that the client_id is a valid URL with https scheme and non-root pathname fn is_https_url(value: &str) -> bool { @@ -290,7 +371,7 @@ impl AuthorizationManager { metadata: None, oauth_client: None, credential_store: Arc::new(InMemoryCredentialStore::new()), - state: RwLock::new(None), + state_store: Arc::new(InMemoryStateStore::new()), base_url, }; @@ -306,6 +387,21 @@ impl AuthorizationManager { self.credential_store = Arc::new(store); } + /// Set a custom state store for OAuth2 authorization flow state + /// + /// This should be called before initiating the authorization flow. + pub fn set_state_store(&mut self, store: S) { + self.state_store = Arc::new(store); + } + + /// Set OAuth2 authorization metadata + /// + /// This should be called after discovering metadata via `discover_metadata()` + /// and before creating an `AuthorizationSession`. + pub fn set_metadata(&mut self, metadata: AuthorizationMetadata) { + self.metadata = Some(metadata); + } + /// Initialize from stored credentials if available /// /// This will load credentials from the credential store and configure @@ -527,11 +623,11 @@ impl AuthorizationManager { let (auth_url, csrf_token) = auth_request.url(); - // store pkce verifier for later use - *self.state.write().await = Some(AuthorizationState { - pkce_verifier, - csrf_token, - }); + // store pkce verifier for later use via state store + let stored_state = StoredAuthorizationState::new(&pkce_verifier, &csrf_token); + self.state_store + .save(csrf_token.secret(), stored_state) + .await?; Ok(auth_url.to_string()) } @@ -548,17 +644,17 @@ impl AuthorizationManager { .as_ref() .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; - let AuthorizationState { - pkce_verifier, - csrf_token: expected_csrf_token, - } = - self.state.write().await.take().ok_or_else(|| { + // Load state from state store using CSRF token as key + let stored_state = + self.state_store.load(csrf_token).await?.ok_or_else(|| { AuthError::InternalError("Authorization state not found".to_string()) })?; - if csrf_token != expected_csrf_token.secret() { - return Err(AuthError::InternalError("CSRF token mismatch".to_string())); - } + // Delete state after retrieval (one-time use) + self.state_store.delete(csrf_token).await?; + + // Reconstruct the PKCE verifier + let pkce_verifier = stored_state.into_pkce_verifier(); let http_client = reqwest::ClientBuilder::new() .redirect(reqwest::redirect::Policy::none()) @@ -1353,9 +1449,15 @@ impl OAuthState { #[cfg(test)] mod tests { + use std::sync::Arc; + + use oauth2::{CsrfToken, PkceCodeVerifier}; use url::Url; - use super::{AuthorizationManager, is_https_url}; + use super::{ + AuthError, AuthorizationManager, InMemoryStateStore, StateStore, StoredAuthorizationState, + is_https_url, + }; // SEP-991: URL-based Client IDs // Tests adapted from the TypeScript SDK's isHttpsUrl test suite @@ -1551,4 +1653,194 @@ mod tests { "https://auth.example.com/tenant1/subtenant/.well-known/openid-configuration" ); } + + // StateStore and StoredAuthorizationState tests + + #[tokio::test] + async fn test_in_memory_state_store_save_and_load() { + let store = InMemoryStateStore::new(); + let pkce = PkceCodeVerifier::new("test-verifier".to_string()); + let csrf = CsrfToken::new("test-csrf".to_string()); + let state = StoredAuthorizationState::new(&pkce, &csrf); + + // Save state + store.save("test-csrf", state).await.unwrap(); + + // Load state + let loaded = store.load("test-csrf").await.unwrap(); + assert!(loaded.is_some()); + let loaded = loaded.unwrap(); + assert_eq!(loaded.csrf_token, "test-csrf"); + assert_eq!(loaded.pkce_verifier, "test-verifier"); + } + + #[tokio::test] + async fn test_in_memory_state_store_load_nonexistent() { + let store = InMemoryStateStore::new(); + let result = store.load("nonexistent").await.unwrap(); + assert!(result.is_none()); + } + + #[tokio::test] + async fn test_in_memory_state_store_delete() { + let store = InMemoryStateStore::new(); + let pkce = PkceCodeVerifier::new("verifier".to_string()); + let csrf = CsrfToken::new("csrf".to_string()); + let state = StoredAuthorizationState::new(&pkce, &csrf); + + store.save("csrf", state).await.unwrap(); + store.delete("csrf").await.unwrap(); + + let result = store.load("csrf").await.unwrap(); + assert!(result.is_none()); + } + + #[test] + fn test_stored_authorization_state_serialization() { + let pkce = PkceCodeVerifier::new("my-verifier".to_string()); + let csrf = CsrfToken::new("my-csrf".to_string()); + let state = StoredAuthorizationState::new(&pkce, &csrf); + + // Serialize to JSON + let json = serde_json::to_string(&state).unwrap(); + + // Deserialize back + let deserialized: StoredAuthorizationState = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.pkce_verifier, "my-verifier"); + assert_eq!(deserialized.csrf_token, "my-csrf"); + } + + #[test] + fn test_stored_authorization_state_into_pkce_verifier() { + let pkce = PkceCodeVerifier::new("original-verifier".to_string()); + let csrf = CsrfToken::new("csrf-token".to_string()); + let state = StoredAuthorizationState::new(&pkce, &csrf); + + let recovered = state.into_pkce_verifier(); + assert_eq!(recovered.secret(), "original-verifier"); + } + + #[test] + fn test_stored_authorization_state_created_at() { + let pkce = PkceCodeVerifier::new("verifier".to_string()); + let csrf = CsrfToken::new("csrf".to_string()); + let state = StoredAuthorizationState::new(&pkce, &csrf); + + // created_at should be a reasonable timestamp (after year 2020) + assert!(state.created_at > 1577836800); // Jan 1, 2020 + } + + #[tokio::test] + async fn test_in_memory_state_store_overwrite() { + let store = InMemoryStateStore::new(); + let csrf_key = "same-csrf"; + + // Save first state + let pkce1 = PkceCodeVerifier::new("verifier-1".to_string()); + let csrf1 = CsrfToken::new(csrf_key.to_string()); + let state1 = StoredAuthorizationState::new(&pkce1, &csrf1); + store.save(csrf_key, state1).await.unwrap(); + + // Overwrite with second state + let pkce2 = PkceCodeVerifier::new("verifier-2".to_string()); + let csrf2 = CsrfToken::new(csrf_key.to_string()); + let state2 = StoredAuthorizationState::new(&pkce2, &csrf2); + store.save(csrf_key, state2).await.unwrap(); + + // Should get the second state + let loaded = store.load(csrf_key).await.unwrap().unwrap(); + assert_eq!(loaded.pkce_verifier, "verifier-2"); + } + + #[tokio::test] + async fn test_in_memory_state_store_concurrent_access() { + let store = Arc::new(InMemoryStateStore::new()); + let mut handles = vec![]; + + // Spawn 10 concurrent tasks that each save and load their own state + for i in 0..10 { + let store = Arc::clone(&store); + let handle = tokio::spawn(async move { + let csrf_key = format!("csrf-{}", i); + let verifier = format!("verifier-{}", i); + + let pkce = PkceCodeVerifier::new(verifier.clone()); + let csrf = CsrfToken::new(csrf_key.clone()); + let state = StoredAuthorizationState::new(&pkce, &csrf); + + store.save(&csrf_key, state).await.unwrap(); + let loaded = store.load(&csrf_key).await.unwrap().unwrap(); + assert_eq!(loaded.pkce_verifier, verifier); + + store.delete(&csrf_key).await.unwrap(); + let deleted = store.load(&csrf_key).await.unwrap(); + assert!(deleted.is_none()); + }); + handles.push(handle); + } + + // Wait for all tasks to complete + for handle in handles { + handle.await.unwrap(); + } + } + + #[tokio::test] + async fn test_custom_state_store_with_authorization_manager() { + use std::sync::atomic::{AtomicUsize, Ordering}; + + // Custom state store that tracks calls + #[derive(Debug, Default)] + struct TrackingStateStore { + inner: InMemoryStateStore, + save_count: AtomicUsize, + load_count: AtomicUsize, + delete_count: AtomicUsize, + } + + #[async_trait::async_trait] + impl StateStore for TrackingStateStore { + async fn save( + &self, + csrf_token: &str, + state: StoredAuthorizationState, + ) -> Result<(), AuthError> { + self.save_count.fetch_add(1, Ordering::SeqCst); + self.inner.save(csrf_token, state).await + } + + async fn load( + &self, + csrf_token: &str, + ) -> Result, AuthError> { + self.load_count.fetch_add(1, Ordering::SeqCst); + self.inner.load(csrf_token).await + } + + async fn delete(&self, csrf_token: &str) -> Result<(), AuthError> { + self.delete_count.fetch_add(1, Ordering::SeqCst); + self.inner.delete(csrf_token).await + } + } + + // Verify custom store works standalone + let store = TrackingStateStore::default(); + let pkce = PkceCodeVerifier::new("test-verifier".to_string()); + let csrf = CsrfToken::new("test-csrf".to_string()); + let state = StoredAuthorizationState::new(&pkce, &csrf); + + store.save("test-csrf", state).await.unwrap(); + assert_eq!(store.save_count.load(Ordering::SeqCst), 1); + + let _ = store.load("test-csrf").await.unwrap(); + assert_eq!(store.load_count.load(Ordering::SeqCst), 1); + + store.delete("test-csrf").await.unwrap(); + assert_eq!(store.delete_count.load(Ordering::SeqCst), 1); + + // Verify custom store can be set on AuthorizationManager + let mut manager = AuthorizationManager::new("http://localhost").await.unwrap(); + manager.set_state_store(TrackingStateStore::default()); + } } From 60f939918c4816c5c07d4275c6decb44c4d59d0b Mon Sep 17 00:00:00 2001 From: centdix Date: Sat, 10 Jan 2026 17:16:29 +0100 Subject: [PATCH 2/2] fix(examples): use CLI server_url for transport connection --- examples/clients/src/auth/oauth_client.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/clients/src/auth/oauth_client.rs b/examples/clients/src/auth/oauth_client.rs index 9b131652..ab7867cf 100644 --- a/examples/clients/src/auth/oauth_client.rs +++ b/examples/clients/src/auth/oauth_client.rs @@ -173,7 +173,7 @@ async fn main() -> Result<()> { let client = AuthClient::new(reqwest::Client::default(), am); let transport = StreamableHttpClientTransport::with_client( client, - StreamableHttpClientTransportConfig::with_uri(MCP_SERVER_URL), + StreamableHttpClientTransportConfig::with_uri(server_url.as_str()), ); // Create client and connect to MCP server