diff --git a/.changeset/assistants-mcp-oauth.md b/.changeset/assistants-mcp-oauth.md new file mode 100644 index 0000000000..72816bfb81 --- /dev/null +++ b/.changeset/assistants-mcp-oauth.md @@ -0,0 +1,5 @@ +--- +"server": minor +--- + +Assistants can now authenticate with OAuth-protected MCP servers. When a configured MCP server requires user authentication, the assistant relays the authorization link through an available output tool; once the user completes authentication, the assistant reconnects and continues its task. diff --git a/agents/runner/src/gram_client.rs b/agents/runner/src/gram_client.rs index 583350aac6..5d6c2190b9 100644 --- a/agents/runner/src/gram_client.rs +++ b/agents/runner/src/gram_client.rs @@ -1,13 +1,14 @@ use std::time::Duration; use reqwest_middleware::ClientWithMiddleware; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use thiserror::Error; use crate::http_layer::TokenRegistry; use crate::wire::ThreadBootstrap; const BOOTSTRAP_PATH: &str = "/rpc/assistants.getThreadBootstrap"; +const CREATE_MCP_AUTH_FLOW_PATH: &str = "/rpc/assistantMcpAuth.create"; const BOOTSTRAP_TIMEOUT: Duration = Duration::from_secs(15); /// Lightweight client used by the runner to pull a per-thread bootstrap @@ -44,6 +45,20 @@ struct BootstrapRequest<'a> { thread_id: &'a str, } +#[derive(Serialize)] +struct CreateMcpAuthFlowRequest<'a> { + thread_id: &'a str, + server_id: &'a str, + url: &'a str, +} + +#[derive(Debug, Deserialize)] +pub struct CreateMcpAuthFlowResponse { + pub server_id: String, + pub mcp_slug: String, + pub auth_url: String, +} + impl GramBootstrapClient { pub fn new(base_url: String, http: ClientWithMiddleware) -> Self { Self { base_url, http } @@ -80,4 +95,43 @@ impl GramBootstrapClient { let bootstrap: ThreadBootstrap = serde_json::from_str(&body)?; Ok(bootstrap) } + + pub async fn create_mcp_auth_flow( + &self, + thread_id: &str, + server_id: &str, + url: &str, + tokens: &TokenRegistry, + ) -> Result { + let endpoint = format!( + 
"{}{}", + self.base_url.trim_end_matches('/'), + CREATE_MCP_AUTH_FLOW_PATH + ); + let bearer = tokens.current().map_err(|_| GramClientError::Token)?; + + let resp = self + .http + .post(&endpoint) + .timeout(BOOTSTRAP_TIMEOUT) + .bearer_auth(&bearer) + .json(&CreateMcpAuthFlowRequest { + thread_id, + server_id, + url, + }) + .send() + .await?; + + let status = resp.status(); + let body = resp.text().await?; + if !status.is_success() { + return Err(GramClientError::Status { + status: status.as_u16(), + body, + }); + } + let flow: CreateMcpAuthFlowResponse = serde_json::from_str(&body)?; + Ok(flow) + } } diff --git a/agents/runner/src/runtime.rs b/agents/runner/src/runtime.rs index 356ceeb50d..85730cefdf 100644 --- a/agents/runner/src/runtime.rs +++ b/agents/runner/src/runtime.rs @@ -10,7 +10,7 @@ use agentkit_loop::{ PromptCacheRetention, SessionConfig, }; use agentkit_mcp::{ - McpServerConfig, McpServerId, McpServerManager, McpTransportBinding, + McpError, McpServerConfig, McpServerId, McpServerManager, McpTransportBinding, StreamableHttpTransportConfig, }; use agentkit_provider_openrouter::{OpenRouterConfig, OpenRouterProvider}; @@ -278,7 +278,8 @@ async fn spawn_thread( bootstrap: ThreadBootstrap, tokens: TokenRegistry, ) -> Result, RunnerError> { - let (mcp_cmd_tx, mcp_catalog) = build_thread_mcp(host, &bootstrap.mcp_servers, &tokens).await?; + let (mcp_cmd_tx, mcp_catalog, mcp_auth_notices) = + build_thread_mcp(host, &thread_id, &bootstrap.mcp_servers, &tokens).await?; let chat_id = bootstrap.chat_id.clone(); @@ -345,6 +346,9 @@ async fn spawn_thread( transcript.push(Item::text(ItemKind::System, &bootstrap.instructions)); } transcript.extend(normalize_history(&bootstrap.history)?); + for notice in mcp_auth_notices { + transcript.push(Item::text(ItemKind::User, &notice)); + } let permissions = CompositePermissionChecker::new(PermissionDecision::Allow).with_policy( PathPolicy::new() @@ -358,10 +362,7 @@ async fn spawn_thread( let mcp_server_ids: Vec = 
bootstrap.mcp_servers.iter().map(|s| s.id.clone()).collect(); let native_tools = ToolRegistry::new().with(tools::bun_run::bun_run).with( - tools::mcp_force_reconnect::McpForceReconnectTool::new( - Arc::clone(host), - mcp_server_ids, - ), + tools::mcp_force_reconnect::McpForceReconnectTool::new(Arc::clone(host), mcp_server_ids), ); let mcp_source = ClippedToolSource::new(mcp_catalog, host.spill_root.clone()); @@ -436,29 +437,56 @@ async fn spawn_thread( async fn build_thread_mcp( host: &Arc, + thread_id: &str, servers: &[McpServer], tokens: &TokenRegistry, -) -> Result<(mpsc::Sender, CatalogReader), RunnerError> { +) -> Result<(mpsc::Sender, CatalogReader, Vec), RunnerError> { let mut manager = McpServerManager::new(); let catalog = manager.source(); + let mut auth_notices = Vec::new(); for server in servers { let config = build_mcp_server_config(server, &host.http_client, tokens)?; let server_id = McpServerId::new(server.id.clone()); manager.register_server(config); - let _ = connect_and_log(&mut manager, &server_id, "register").await; + if let Err(err) = connect_and_log(&mut manager, &server_id, "register").await + && err.auth_required + { + match host + .gram_client + .create_mcp_auth_flow(thread_id, &server.id, &server.url, tokens) + .await + { + Ok(flow) => auth_notices.push(format!( + "\nEventType: assistant_mcp_auth_required\nMCPServerID: {server_id}\nMCPSlug: {mcp_slug}\nAuthURL: {auth_url}\n", + server_id = flow.server_id, + mcp_slug = flow.mcp_slug, + auth_url = flow.auth_url, + )), + Err(flow_err) => tracing::warn!( + server_id = %server_id, + error = %flow_err, + "failed to create assistant mcp auth flow" + ), + } + } } let (cmd_tx, cmd_rx) = mpsc::channel(MCP_CMD_CAPACITY); tokio::spawn(run_mcp_actor(manager, cmd_rx)); - Ok((cmd_tx, catalog)) + Ok((cmd_tx, catalog, auth_notices)) +} + +struct McpConnectFailure { + message: String, + auth_required: bool, } async fn connect_and_log( manager: &mut McpServerManager, server_id: &McpServerId, action: &'static 
str, -) -> Result<(), String> { +) -> Result<(), McpConnectFailure> { match manager.connect_server(server_id).await { Ok(handle) => { tracing::info!( @@ -470,8 +498,12 @@ async fn connect_and_log( Ok(()) } Err(e) => { + let auth_required = matches!(e, McpError::AuthRequired(_)); tracing::warn!(server_id = %server_id, error = %e, action, "mcp connect failed"); - Err(e.to_string()) + Err(McpConnectFailure { + message: e.to_string(), + auth_required, + }) } } } @@ -517,7 +549,9 @@ async fn run_mcp_actor(mut manager: McpServerManager, mut cmd_rx: mpsc::Receiver if let Err(e) = manager.disconnect_server(&server_id).await { tracing::debug!(server_id = %server_id, error = %e, "disconnect during force reconnect"); } - let result = connect_and_log(&mut manager, &server_id, "force_reconnect").await; + let result = connect_and_log(&mut manager, &server_id, "force_reconnect") + .await + .map_err(|err| err.message); let _ = reply.send(result); } } @@ -737,10 +771,14 @@ mod tests { async fn evict_thread_clears_seen_keys_with_prefix() { let host = empty_host(); insert_thread(&host, "T", Some(Instant::now())); - host.seen - .insert("T:evt-1".to_string(), Arc::new(tokio::sync::Mutex::new(true))); - host.seen - .insert("T:evt-2".to_string(), Arc::new(tokio::sync::Mutex::new(true))); + host.seen.insert( + "T:evt-1".to_string(), + Arc::new(tokio::sync::Mutex::new(true)), + ); + host.seen.insert( + "T:evt-2".to_string(), + Arc::new(tokio::sync::Mutex::new(true)), + ); host.seen.insert( "other:evt-1".to_string(), Arc::new(tokio::sync::Mutex::new(true)), diff --git a/server/cmd/gram/start.go b/server/cmd/gram/start.go index b01e60a8db..a1c9b85d13 100644 --- a/server/cmd/gram/start.go +++ b/server/cmd/gram/start.go @@ -827,7 +827,7 @@ func newStartCommand() *cli.Command { ) contextWindowResolver := openrouter.NewContextWindowResolver(logger, guardianPolicy, cache.NewRedisCacheAdapter(redisClient)) chatService := chat.NewService(logger, tracerProvider, db, sessionManager, 
chatSessionsManager, openRouter, chatClient, contextWindowResolver, posthogClient, telemSvc, assetStorage, authzEngine, assistantTokenManager, billingRepo) - assistantsCore := assistants.NewServiceCore(logger, tracerProvider, db, assistantRuntime, slackClient, assistantTokenManager, serverURL, telemLogger, contextWindowResolver) + assistantsCore := assistants.NewServiceCore(logger, tracerProvider, db, guardianPolicy, encryptionClient, assistantRuntime, slackClient, assistantTokenManager, serverURL, telemLogger, contextWindowResolver) assistantsCore.SetWakeCanceller(triggerApp) assistantsCore.SetChatMessageWriter(chatWriter) assistantsSvc := assistants.NewService(logger, tracerProvider, db, sessionManager, authzEngine, assistantsCore, &background.AssistantWorkflowSignaler{TemporalEnv: temporalEnv}) diff --git a/server/cmd/gram/worker.go b/server/cmd/gram/worker.go index 81617ef8b2..4ae48f1478 100644 --- a/server/cmd/gram/worker.go +++ b/server/cmd/gram/worker.go @@ -676,7 +676,7 @@ func newWorkerCommand() *cli.Command { return err } contextWindowResolver := openrouter.NewContextWindowResolver(logger, guardianPolicy, cache.NewRedisCacheAdapter(redisClient)) - assistantsCore := assistants.NewServiceCore(logger, tracerProvider, db, assistantRuntime, slackClient, assistantTokenManager, serverURL, telemetryLogger, contextWindowResolver) + assistantsCore := assistants.NewServiceCore(logger, tracerProvider, db, guardianPolicy, encryptionClient, assistantRuntime, slackClient, assistantTokenManager, serverURL, telemetryLogger, contextWindowResolver) assistantsCore.SetWakeCanceller(triggerApp) assistantsCore.SetChatMessageWriter(chatWriter) assistantsSvc := assistants.NewService(logger, tracerProvider, db, sessionManager, authzEngine, assistantsCore, &background.AssistantWorkflowSignaler{TemporalEnv: temporalEnv}) diff --git a/server/internal/assistants/impl.go b/server/internal/assistants/impl.go index cabb038a3d..7d74eec0ec 100644 --- a/server/internal/assistants/impl.go 
+++ b/server/internal/assistants/impl.go @@ -72,6 +72,8 @@ func Attach(mux goahttp.Muxer, service *Service) { srv.New(endpoints, mux, goahttp.RequestDecoder, goahttp.ResponseEncoder, nil, nil), ) o11y.AttachHandler(mux, "POST", "/rpc/assistants.getThreadBootstrap", oops.ErrHandle(service.logger, service.handleGetThreadBootstrap).ServeHTTP) + o11y.AttachHandler(mux, "POST", "/rpc/assistantMcpAuth.create", oops.ErrHandle(service.logger, service.handleCreateMCPAuthFlow).ServeHTTP) + o11y.AttachHandler(mux, "GET", "/rpc/assistantMcpAuth/{id}/oauth/callback", oops.ErrHandle(service.logger, service.handleMCPAuthCallback).ServeHTTP) } func (s *Service) APIKeyAuth(ctx context.Context, key string, schema *security.APIKeyScheme) (context.Context, error) { diff --git a/server/internal/assistants/impl_test.go b/server/internal/assistants/impl_test.go index 3e0111b9a6..50eb7e4c61 100644 --- a/server/internal/assistants/impl_test.go +++ b/server/internal/assistants/impl_test.go @@ -260,7 +260,7 @@ func newRBACServiceWithConn(t *testing.T, dbName string) (*Service, context.Cont logger: logger, auth: nil, authz: authzEngine, - core: NewServiceCore(logger, testenv.NewTracerProvider(t), conn, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(logger), nil), + core: NewServiceCore(logger, testenv.NewTracerProvider(t), conn, nil, nil, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(logger), nil), signaler: nil, } diff --git a/server/internal/assistants/mcp_auth_handler.go b/server/internal/assistants/mcp_auth_handler.go new file mode 100644 index 0000000000..a067114eb4 --- /dev/null +++ b/server/internal/assistants/mcp_auth_handler.go @@ -0,0 +1,493 @@ +package assistants + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "mime" + "net/http" + "net/url" + "strings" + "time" + + 
"github.com/go-chi/chi/v5" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + + assistantrepo "github.com/speakeasy-api/gram/server/internal/assistants/repo" + "github.com/speakeasy-api/gram/server/internal/attr" + "github.com/speakeasy-api/gram/server/internal/auth/assistanttokens" + "github.com/speakeasy-api/gram/server/internal/contextvalues" + "github.com/speakeasy-api/gram/server/internal/externalmcp" + "github.com/speakeasy-api/gram/server/internal/guardian" + "github.com/speakeasy-api/gram/server/internal/o11y" + "github.com/speakeasy-api/gram/server/internal/oops" +) + +const ( + mcpAuthFlowMaxBodyBytes = 16 * 1024 + mcpAuthFlowTTL = 15 * time.Minute + mcpAuthEventKind = "assistant_mcp_auth" + + mcpAuthStatusSuccess = "success" + mcpAuthStatusFailed = "failed" +) + +type createMCPAuthFlowRequest struct { + ThreadID string `json:"thread_id"` + ServerID string `json:"server_id"` + URL string `json:"url"` +} + +type createMCPAuthFlowResponse struct { + ServerID string `json:"server_id"` + McpSlug string `json:"mcp_slug"` + AuthURL string `json:"auth_url"` +} + +type mcpAuthClientRegistrationRequest struct { + ClientName string `json:"client_name"` + RedirectURIs []string `json:"redirect_uris"` + GrantTypes []string `json:"grant_types"` + ResponseTypes []string `json:"response_types"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` +} + +type mcpAuthClientRegistrationResponse struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` +} + +type mcpAuthEventPayload struct { + GramEventKind string `json:"gram_event_kind"` + Status string `json:"status"` + ServerID string `json:"mcp_server_id"` + McpSlug string `json:"mcp_slug"` + Error string `json:"error,omitempty"` + ErrorDescription string `json:"error_description,omitempty"` +} + +func (s *Service) handleCreateMCPAuthFlow(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + token := r.Header.Get("Authorization") + if token == "" { + 
return oops.C(oops.CodeUnauthorized) + } + + authedCtx, claims, err := s.core.assistantTokens.Authorize(ctx, token) + if err != nil { + return fmt.Errorf("authorize assistant runtime token: %w", err) + } + ctx = authedCtx + + principal, ok := contextvalues.GetAssistantPrincipal(ctx) + if !ok { + return oops.C(oops.CodeUnauthorized) + } + projectID, err := uuid.Parse(claims.ProjectID) + if err != nil { + return oops.E(oops.CodeUnauthorized, err, "invalid token project") + } + + if ct := r.Header.Get("Content-Type"); ct != "" { + mediaType, _, err := mime.ParseMediaType(ct) + if err != nil || mediaType != "application/json" { + return oops.E(oops.CodeBadRequest, err, "Content-Type must be application/json") + } + } + r.Body = http.MaxBytesReader(w, r.Body, mcpAuthFlowMaxBodyBytes) + var req createMCPAuthFlowRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + return oops.E(oops.CodeBadRequest, err, "request body too large") + } + return oops.E(oops.CodeBadRequest, err, "decode mcp auth flow request") + } + threadID, err := uuid.Parse(req.ThreadID) + if err != nil { + return oops.E(oops.CodeBadRequest, err, "invalid thread_id") + } + if principal.ThreadID != uuid.Nil && principal.ThreadID != threadID { + return oops.E(oops.CodeForbidden, nil, "token thread does not match requested thread") + } + + mcpURL, err := url.Parse(req.URL) + if err != nil || mcpURL.Scheme == "" || mcpURL.Host == "" { + return oops.E(oops.CodeBadRequest, err, "invalid mcp url") + } + mcpSlug, err := mcpSlugFromURL(mcpURL) + if err != nil { + return oops.E(oops.CodeBadRequest, err, "mcp auth flow only supports hosted MCP URLs") + } + + flowID := uuid.NewString() + if s.core.serverURL == nil { + return oops.E(oops.CodeUnexpected, nil, "assistant mcp auth callback base url not configured").Log(ctx, s.logger) + } + redirectURI := s.core.serverURL.JoinPath("rpc", "assistantMcpAuth", flowID, "oauth", 
"callback").String() + codeVerifier, codeChallenge, err := newPKCEPair() + if err != nil { + return oops.E(oops.CodeUnexpected, err, "generate PKCE verifier").Log(ctx, s.logger) + } + encryptedVerifier, err := s.core.encryptionClient.Encrypt([]byte(codeVerifier)) + if err != nil { + return oops.E(oops.CodeUnexpected, err, "encrypt pkce verifier").Log(ctx, s.logger) + } + + metadata, err := externalmcp.DiscoverOAuthMetadata(ctx, s.logger, s.core.guardianPolicy, "", mcpURL.String()) + if err != nil { + return oops.E(oops.CodeUnexpected, err, "discover mcp authorization server metadata").Log(ctx, s.logger) + } + if metadata.AuthorizationEndpoint == "" || metadata.TokenEndpoint == "" || metadata.RegistrationEndpoint == "" { + return oops.E(oops.CodeUnexpected, nil, "mcp authorization server does not advertise RFC 8414 endpoints").Log(ctx, s.logger) + } + + registration, err := s.registerMCPAuthClient(ctx, metadata.RegistrationEndpoint, redirectURI) + if err != nil { + return oops.E(oops.CodeUnexpected, err, "register assistant mcp oauth client").Log(ctx, s.logger) + } + encryptedSecret, err := s.core.encryptionClient.Encrypt([]byte(registration.ClientSecret)) + if err != nil { + return oops.E(oops.CodeUnexpected, err, "encrypt mcp client secret").Log(ctx, s.logger) + } + + state, err := s.core.assistantTokens.GenerateMCPAuthFlow(assistanttokens.MCPAuthFlowInput{ + OrgID: claims.OrgID, + ProjectID: projectID, + UserID: claims.UserID, + AssistantID: principal.AssistantID, + ThreadID: threadID, + FlowID: flowID, + ServerID: req.ServerID, + McpURL: mcpURL.String(), + ClientID: registration.ClientID, + ClientSecret: encryptedSecret, + RedirectURI: redirectURI, + CodeVerifier: encryptedVerifier, + TokenEndpoint: metadata.TokenEndpoint, + TTL: mcpAuthFlowTTL, + }) + if err != nil { + return oops.E(oops.CodeUnexpected, err, "sign mcp auth flow state").Log(ctx, s.logger) + } + + authURL, err := buildMCPAuthURL(metadata.AuthorizationEndpoint, registration.ClientID, redirectURI, 
state, codeChallenge) + if err != nil { + return oops.E(oops.CodeUnexpected, err, "build mcp auth url").Log(ctx, s.logger) + } + + s.logger.InfoContext(ctx, "assistant mcp auth flow created", + attr.SlogAssistantID(principal.AssistantID.String()), + attr.SlogAssistantThreadID(threadID.String()), + attr.SlogProjectID(projectID.String()), + attr.SlogToolsetMCPSlug(mcpSlug), + ) + + return writeJSON(w, http.StatusOK, createMCPAuthFlowResponse{ + ServerID: req.ServerID, + McpSlug: mcpSlug, + AuthURL: authURL, + }) +} + +func (s *Service) handleMCPAuthCallback(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + flowID := chi.URLParam(r, "id") + state := r.URL.Query().Get("state") + claims, err := s.core.assistantTokens.ValidateMCPAuthFlow(state) + if err != nil { + return oops.E(oops.CodeBadRequest, err, "invalid mcp auth callback state").Log(ctx, s.logger) + } + if claims.FlowID != flowID { + return oops.E(oops.CodeBadRequest, nil, "mcp auth callback flow mismatch").Log(ctx, s.logger) + } + + projectID, err := uuid.Parse(claims.ProjectID) + if err != nil { + return oops.E(oops.CodeBadRequest, err, "invalid callback project id").Log(ctx, s.logger) + } + assistantID, err := uuid.Parse(claims.AssistantID) + if err != nil { + return oops.E(oops.CodeBadRequest, err, "invalid callback assistant id").Log(ctx, s.logger) + } + threadID, err := uuid.Parse(claims.ThreadID) + if err != nil { + return oops.E(oops.CodeBadRequest, err, "invalid callback thread id").Log(ctx, s.logger) + } + mcpURL, err := url.Parse(claims.McpURL) + if err != nil { + return oops.E(oops.CodeBadRequest, err, "invalid callback mcp url").Log(ctx, s.logger) + } + mcpSlug, err := mcpSlugFromURL(mcpURL) + if err != nil { + return oops.E(oops.CodeBadRequest, err, "callback mcp url missing slug").Log(ctx, s.logger) + } + + payload := mcpAuthEventPayload{ + GramEventKind: mcpAuthEventKind, + Status: mcpAuthStatusSuccess, + ServerID: claims.ServerID, + McpSlug: mcpSlug, + Error: "", + 
ErrorDescription: "", + } + oauthErr := r.URL.Query().Get("error") + code := r.URL.Query().Get("code") + switch { + case oauthErr != "": + payload.Status = mcpAuthStatusFailed + payload.Error = oauthErr + payload.ErrorDescription = r.URL.Query().Get("error_description") + case code == "": + payload.Status = mcpAuthStatusFailed + payload.Error = "invalid_request" + payload.ErrorDescription = "authorization code missing from callback" + default: + if err := s.consumeMCPAuthGrant(ctx, claims, code); err != nil { + payload.Status = mcpAuthStatusFailed + payload.Error = "invalid_grant" + payload.ErrorDescription = "failed to consume authorization grant" + s.logger.ErrorContext(ctx, "assistant mcp auth grant consumption failed", + attr.SlogAssistantID(assistantID.String()), + attr.SlogAssistantThreadID(threadID.String()), + attr.SlogProjectID(projectID.String()), + attr.SlogError(err), + ) + } + } + + eventCreated, err := s.enqueueMCPAuthEvent(ctx, projectID, assistantID, threadID, flowID, payload) + if err != nil { + return err + } + if eventCreated { + if err := s.signaler.SignalCoordinator(ctx, assistantID); err != nil { + return fmt.Errorf("signal assistant coordinator: %w", err) + } + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, "Authentication complete

Authentication complete. You can close this window.

") + return nil +} + +func (s *Service) enqueueMCPAuthEvent(ctx context.Context, projectID, assistantID, threadID uuid.UUID, flowID string, payload mcpAuthEventPayload) (bool, error) { + repo := assistantrepo.New(s.core.db) + thread, err := repo.ResolveThreadCorrelation(ctx, assistantrepo.ResolveThreadCorrelationParams{ + ThreadID: threadID, + ProjectID: projectID, + }) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return false, oops.E(oops.CodeNotFound, err, "assistant thread not found").Log(ctx, s.logger) + } + return false, oops.E(oops.CodeUnexpected, err, "load assistant thread for mcp auth event").Log(ctx, s.logger) + } + if thread.AssistantID != assistantID { + return false, oops.E(oops.CodeForbidden, nil, "assistant thread assistant mismatch").Log(ctx, s.logger) + } + + body, err := json.Marshal(payload) + if err != nil { + return false, oops.E(oops.CodeUnexpected, err, "marshal mcp auth event").Log(ctx, s.logger) + } + _, err = repo.InsertAssistantThreadEvent(ctx, assistantrepo.InsertAssistantThreadEventParams{ + AssistantThreadID: threadID, + AssistantID: assistantID, + ProjectID: projectID, + TriggerInstanceID: uuid.NullUUID{UUID: uuid.Nil, Valid: false}, + EventID: mcpAuthEventKind + ":" + flowID, + CorrelationID: thread.CorrelationID, + Status: eventStatusPending, + NormalizedPayloadJson: body, + SourcePayloadJson: body, + }) + switch { + case errors.Is(err, pgx.ErrNoRows): + return false, nil + case err != nil: + return false, oops.E(oops.CodeUnexpected, err, "insert mcp auth assistant event").Log(ctx, s.logger) + default: + return true, nil + } +} + +func decodeMCPAuthTurn(ctx context.Context, logger *slog.Logger, event assistantThreadEventRecord) (string, bool) { + var payload mcpAuthEventPayload + if err := json.Unmarshal(event.NormalizedPayloadJSON, &payload); err != nil { + logger.WarnContext(ctx, "skip mcp auth event with undecodable payload", + attr.SlogAssistantEventID(event.EventID), + attr.SlogError(err), + ) + return "", false + } 
+ if payload.GramEventKind != mcpAuthEventKind { + return "", false + } + var b strings.Builder + b.WriteString("\n") + fmt.Fprintf(&b, "EventID: %s\n", event.EventID) + fmt.Fprintf(&b, "EventType: %s\n", mcpAuthEventKind) + if payload.ServerID != "" { + fmt.Fprintf(&b, "MCPServerID: %s\n", payload.ServerID) + } + if payload.McpSlug != "" { + fmt.Fprintf(&b, "MCPSlug: %s\n", payload.McpSlug) + } + if payload.Status != "" { + fmt.Fprintf(&b, "Status: %s\n", payload.Status) + } + if payload.Error != "" { + fmt.Fprintf(&b, "Error: %s\n", payload.Error) + } + if payload.ErrorDescription != "" { + fmt.Fprintf(&b, "ErrorDescription: %s\n", payload.ErrorDescription) + } + b.WriteString("") + return b.String(), true +} + +func (s *Service) registerMCPAuthClient(ctx context.Context, endpoint, redirectURI string) (mcpAuthClientRegistrationResponse, error) { + payload := mcpAuthClientRegistrationRequest{ + ClientName: "Gram Assistant MCP Auth", + RedirectURIs: []string{redirectURI}, + GrantTypes: []string{"authorization_code"}, + ResponseTypes: []string{"code"}, + TokenEndpointAuthMethod: "client_secret_basic", + } + body, err := json.Marshal(payload) + if err != nil { + return mcpAuthClientRegistrationResponse{}, fmt.Errorf("marshal registration request: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return mcpAuthClientRegistrationResponse{}, fmt.Errorf("build registration request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + resp, err := s.core.guardianPolicy.Client(guardian.WithDefaultRetryConfig()).Do(req) + if err != nil { + return mcpAuthClientRegistrationResponse{}, fmt.Errorf("send registration request: %w", err) + } + defer o11y.NoLogDefer(func() error { + _, _ = io.Copy(io.Discard, resp.Body) + return resp.Body.Close() + }) + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) + if err != nil { + return mcpAuthClientRegistrationResponse{}, 
fmt.Errorf("read registration response: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return mcpAuthClientRegistrationResponse{}, fmt.Errorf("registration failed: status=%d body=%s", resp.StatusCode, strings.TrimSpace(string(respBody))) + } + var out mcpAuthClientRegistrationResponse + if err := json.Unmarshal(respBody, &out); err != nil { + return mcpAuthClientRegistrationResponse{}, fmt.Errorf("decode registration response: %w", err) + } + if out.ClientID == "" { + return mcpAuthClientRegistrationResponse{}, fmt.Errorf("registration response missing client_id") + } + if out.ClientSecret == "" { + return mcpAuthClientRegistrationResponse{}, fmt.Errorf("registration response missing client_secret") + } + return out, nil +} + +func (s *Service) consumeMCPAuthGrant(ctx context.Context, claims *assistanttokens.MCPAuthFlowClaims, code string) error { + verifier, err := s.core.encryptionClient.Decrypt(claims.CodeVerifier) + if err != nil { + return fmt.Errorf("decrypt pkce verifier: %w", err) + } + clientSecret, err := s.core.encryptionClient.Decrypt(claims.ClientSecret) + if err != nil { + return fmt.Errorf("decrypt mcp client secret: %w", err) + } + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("code", code) + form.Set("redirect_uri", claims.RedirectURI) + form.Set("code_verifier", verifier) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, claims.TokenEndpoint, strings.NewReader(form.Encode())) + if err != nil { + return fmt.Errorf("build token request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBasicAuth(claims.ClientID, clientSecret) + resp, err := s.core.guardianPolicy.Client(guardian.WithDefaultRetryConfig()).Do(req) + if err != nil { + return fmt.Errorf("send token request: %w", err) + } + defer o11y.NoLogDefer(func() error { + _, _ = io.Copy(io.Discard, resp.Body) + return resp.Body.Close() + }) + body, err := 
io.ReadAll(io.LimitReader(resp.Body, 64*1024)) + if err != nil { + return fmt.Errorf("read token response: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("token request failed: status=%d body=%s", resp.StatusCode, strings.TrimSpace(string(body))) + } + return nil +} + +func mcpSlugFromURL(u *url.URL) (string, error) { + parts := strings.Split(strings.Trim(u.EscapedPath(), "/"), "/") + if len(parts) != 2 || parts[0] != "mcp" { + return "", fmt.Errorf("expected /mcp/{slug}") + } + slug, err := url.PathUnescape(parts[1]) + if err != nil || slug == "" { + return "", fmt.Errorf("invalid mcp slug") + } + return slug, nil +} + +func buildMCPAuthURL(endpoint, clientID, redirectURI, state, codeChallenge string) (string, error) { + u, err := url.Parse(endpoint) + if err != nil { + return "", fmt.Errorf("parse authorize endpoint: %w", err) + } + q := u.Query() + q.Set("response_type", "code") + q.Set("client_id", clientID) + q.Set("redirect_uri", redirectURI) + q.Set("state", state) + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") + u.RawQuery = q.Encode() + return u.String(), nil +} + +func newPKCEPair() (string, string, error) { + raw := make([]byte, 32) + if _, err := rand.Read(raw); err != nil { + return "", "", fmt.Errorf("read random verifier bytes: %w", err) + } + verifier := base64.RawURLEncoding.EncodeToString(raw) + sum := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(sum[:]) + return verifier, challenge, nil +} + +func writeJSON(w http.ResponseWriter, status int, payload any) error { + body, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("marshal response: %w", err) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if _, err := w.Write(body); err != nil { + return fmt.Errorf("write response: %w", err) + } + return nil +} diff --git a/server/internal/assistants/service.go 
b/server/internal/assistants/service.go index 1be8116df2..d4f41254d2 100644 --- a/server/internal/assistants/service.go +++ b/server/internal/assistants/service.go @@ -25,6 +25,8 @@ import ( "github.com/speakeasy-api/gram/server/internal/chat" chatrepo "github.com/speakeasy-api/gram/server/internal/chat/repo" "github.com/speakeasy-api/gram/server/internal/conv" + "github.com/speakeasy-api/gram/server/internal/encryption" + "github.com/speakeasy-api/gram/server/internal/guardian" "github.com/speakeasy-api/gram/server/internal/oops" "github.com/speakeasy-api/gram/server/internal/platformtools" "github.com/speakeasy-api/gram/server/internal/telemetry" @@ -285,23 +287,27 @@ type WakeCanceller interface { } type ServiceCore struct { - logger *slog.Logger - tracer trace.Tracer - db *pgxpool.Pool - runtime RuntimeBackend - slackClient *slackclient.SlackClient - assistantTokens *assistanttokens.Manager - serverURL *url.URL - telemetryLogger *telemetry.Logger - contextWindow *openrouter.ContextWindowResolver - wakeCanceller WakeCanceller - chatWriter *chat.ChatMessageWriter + logger *slog.Logger + tracer trace.Tracer + db *pgxpool.Pool + guardianPolicy *guardian.Policy + encryptionClient *encryption.Client + runtime RuntimeBackend + slackClient *slackclient.SlackClient + assistantTokens *assistanttokens.Manager + serverURL *url.URL + telemetryLogger *telemetry.Logger + contextWindow *openrouter.ContextWindowResolver + wakeCanceller WakeCanceller + chatWriter *chat.ChatMessageWriter } func NewServiceCore( logger *slog.Logger, tracerProvider trace.TracerProvider, db *pgxpool.Pool, + guardianPolicy *guardian.Policy, + encryptionClient *encryption.Client, runtime RuntimeBackend, slackClient *slackclient.SlackClient, assistantTokens *assistanttokens.Manager, @@ -310,17 +316,19 @@ func NewServiceCore( contextWindow *openrouter.ContextWindowResolver, ) *ServiceCore { return &ServiceCore{ - logger: logger, - tracer: 
tracerProvider.Tracer("github.com/speakeasy-api/gram/server/internal/assistants"), - db: db, - runtime: newTelemetryRuntimeBackend(runtime, telemetryLogger), - slackClient: slackClient, - assistantTokens: assistantTokens, - serverURL: serverURL, - telemetryLogger: telemetryLogger, - contextWindow: contextWindow, - wakeCanceller: nil, - chatWriter: nil, + logger: logger, + tracer: tracerProvider.Tracer("github.com/speakeasy-api/gram/server/internal/assistants"), + db: db, + guardianPolicy: guardianPolicy, + encryptionClient: encryptionClient, + runtime: newTelemetryRuntimeBackend(runtime, telemetryLogger), + slackClient: slackClient, + assistantTokens: assistantTokens, + serverURL: serverURL, + telemetryLogger: telemetryLogger, + contextWindow: contextWindow, + wakeCanceller: nil, + chatWriter: nil, } } @@ -1617,6 +1625,17 @@ func (s *ServiceCore) processEventTurn( runtime assistantRuntimeRecord, event assistantThreadEventRecord, ) error { + if prompt, ok := decodeMCPAuthTurn(ctx, s.logger, event); ok { + turnToken, err := s.MintThreadScopedRuntimeToken(assistant, thread.ID) + if err != nil { + return err + } + if err := s.runtime.RunTurn(ctx, runtime, thread.ID, event.ID.String(), turnToken, prompt); err != nil { + return fmt.Errorf("run assistant turn: %w", err) + } + return nil + } + adapter, err := getSourceAdapter(thread.SourceKind) if err != nil { return err @@ -1804,7 +1823,15 @@ const assistantRuntimeTokenTTL = 60 * time.Minute const outputChannelAddendum = `## Output channel -Your text responses are not delivered to the user. To communicate, call a tool (e.g. post a Slack message, send an email). If no suitable tool is available, the user will not see your reply.` +Your text responses are not delivered to the user. To communicate, call a tool (e.g. post a Slack message, send an email). If no suitable tool is available, the user will not see your reply. 
+ +## MCP authentication + +Two MCP authentication events may appear in this thread, each delivered as a block with EventType and field lines. + +- EventType "assistant_mcp_auth_required" carries an AuthURL. Relay AuthURL to the user verbatim through an output tool (do not shorten, summarize, or rewrite it). Reference the MCP server using its MCPSlug rather than MCPServerID. + +- EventType "assistant_mcp_auth" reports the result. When Status is "success" and you still need that server, call mcp_force_reconnect with server_id set to the MCPServerID value, then continue your task. When Status is "failed", inform the user via an output tool and include the ErrorDescription if present.` func composeInstructions(base string, thread assistantThreadRecord) (string, error) { adapter, err := getSourceAdapter(thread.SourceKind) diff --git a/server/internal/assistants/service_self_heal_test.go b/server/internal/assistants/service_self_heal_test.go index be69a5ae35..ae01b8e8fd 100644 --- a/server/internal/assistants/service_self_heal_test.go +++ b/server/internal/assistants/service_self_heal_test.go @@ -70,7 +70,7 @@ func TestServiceCoreSelfHealsHistoryCorruptionOnFirstAttempt(t *testing.T) { runTurnErr: corruption, stopCalls: &stopCalls, } - core := NewServiceCore(logger, testenv.NewTracerProvider(t), conn, backend, nil, tokens, mustParseURLForServiceTest(t, "https://gram.example.com"), telemetry.NewStub(logger), nil) + core := NewServiceCore(logger, testenv.NewTracerProvider(t), conn, nil, nil, backend, nil, tokens, mustParseURLForServiceTest(t, "https://gram.example.com"), telemetry.NewStub(logger), nil) chatWriter, chatWriterShutdown := chat.NewChatMessageWriter(logger, conn, assetstest.NewTestBlobStore(t)) t.Cleanup(func() { _ = chatWriterShutdown(ctx) }) core.SetChatMessageWriter(chatWriter) @@ -168,7 +168,7 @@ func TestServiceCoreSkipsSelfHealAfterFirstRetry(t *testing.T) { runTurnErr: corruption, stopCalls: &stopCalls, } - core := NewServiceCore(logger, 
testenv.NewTracerProvider(t), conn, backend, nil, tokens, mustParseURLForServiceTest(t, "https://gram.example.com"), telemetry.NewStub(logger), nil) + core := NewServiceCore(logger, testenv.NewTracerProvider(t), conn, nil, nil, backend, nil, tokens, mustParseURLForServiceTest(t, "https://gram.example.com"), telemetry.NewStub(logger), nil) chatWriter, chatWriterShutdown := chat.NewChatMessageWriter(logger, conn, assetstest.NewTestBlobStore(t)) t.Cleanup(func() { _ = chatWriterShutdown(ctx) }) core.SetChatMessageWriter(chatWriter) diff --git a/server/internal/assistants/service_test.go b/server/internal/assistants/service_test.go index da4c9fd9e8..73b93b68bb 100644 --- a/server/internal/assistants/service_test.go +++ b/server/internal/assistants/service_test.go @@ -50,7 +50,7 @@ func TestServiceCoreAdmitPendingThreadsUsesFlyBackend(t *testing.T) { projectID, assistantID, _, threadID := insertAssistantFixture(t, conn) - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) admitted, err := core.AdmitPendingThreads(t.Context(), assistantID) require.NoError(t, err) @@ -75,7 +75,7 @@ func TestServiceCoreAdmitPendingThreadsCapsFanOut(t *testing.T) { assistantID, pending := seedAssistantWithPendingThreads(t, conn, "assistants-cap", 2, 3) preActivateV2Runtime(t, conn, assistantID, pending[0]) - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, testRuntimeBackend{backend: runtimeBackendFlyIO}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, testRuntimeBackend{backend: 
runtimeBackendFlyIO}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) admitted, err := core.AdmitPendingThreads(ctx, assistantID) require.NoError(t, err) @@ -93,7 +93,7 @@ func TestServiceCoreAdmitPendingThreadsBlocksWhenActiveAtCap(t *testing.T) { require.NotEmpty(t, pending) preActivateV2Runtime(t, conn, assistantID, active[0]) - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, testRuntimeBackend{backend: runtimeBackendFlyIO}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, testRuntimeBackend{backend: runtimeBackendFlyIO}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) admitted, err := core.AdmitPendingThreads(ctx, assistantID) require.NoError(t, err) @@ -111,7 +111,7 @@ func TestServiceCoreAdmitPendingThreadsReleasesPartialHeadroom(t *testing.T) { assistantID, active, _ := seedAssistantWithActiveAndPending(t, conn, "assistants-partial", 2, 1, 2) preActivateV2Runtime(t, conn, assistantID, active[0]) - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, testRuntimeBackend{backend: runtimeBackendFlyIO}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, testRuntimeBackend{backend: runtimeBackendFlyIO}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) admitted, err := core.AdmitPendingThreads(ctx, assistantID) require.NoError(t, err) @@ -131,7 +131,7 @@ func TestServiceCoreAdmitPendingThreadsBypassesCapForReservedStarter(t *testing. 
assistantID, _, pending := seedAssistantWithActiveAndPending(t, conn, "assistants-cold-bypass", 1, 1, 1) require.NotEmpty(t, pending) - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, testRuntimeBackend{backend: runtimeBackendFlyIO}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, testRuntimeBackend{backend: runtimeBackendFlyIO}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) admitted, err := core.AdmitPendingThreads(ctx, assistantID) require.NoError(t, err) @@ -295,7 +295,7 @@ func TestServiceCoreExpireThreadRuntimeRevertsWhenTurnInFlight(t *testing.T) { statusResult: RuntimeBackendStatus{Configured: true, IdleSeconds: &busyIdle}, stopCalls: &stopCalls, } - core := NewServiceCore(logger, testenv.NewTracerProvider(t), conn, backend, nil, nil, nil, telemetry.NewStub(logger), nil) + core := NewServiceCore(logger, testenv.NewTracerProvider(t), conn, nil, nil, backend, nil, nil, nil, telemetry.NewStub(logger), nil) result, err := core.ExpireThreadRuntime(t.Context(), projectID, threadID, DefaultWarmTTLSeconds) require.NoError(t, err) @@ -354,7 +354,7 @@ func TestServiceCoreExpireThreadRuntimeRetryAfterStopFailureIsIdempotent(t *test stopErr: errors.New("fly delete app blew up"), stopCalls: &stopCalls, } - core := NewServiceCore(logger, testenv.NewTracerProvider(t), conn, failingBackend, nil, nil, nil, telemetry.NewStub(logger), nil) + core := NewServiceCore(logger, testenv.NewTracerProvider(t), conn, nil, nil, failingBackend, nil, nil, nil, telemetry.NewStub(logger), nil) _, err = core.ExpireThreadRuntime(t.Context(), projectID, threadID, DefaultWarmTTLSeconds) require.Error(t, err, "first attempt with failing Stop must surface the error so Temporal retries") @@ -372,7 +372,7 @@ func TestServiceCoreExpireThreadRuntimeRetryAfterStopFailureIsIdempotent(t *test statusResult: RuntimeBackendStatus{Configured: true, 
IdleSeconds: new(uint64(DefaultWarmTTLSeconds + 60))}, stopCalls: &stopCalls, } - core = NewServiceCore(logger, testenv.NewTracerProvider(t), conn, healingBackend, nil, nil, nil, telemetry.NewStub(logger), nil) + core = NewServiceCore(logger, testenv.NewTracerProvider(t), conn, nil, nil, healingBackend, nil, nil, nil, telemetry.NewStub(logger), nil) result, err := core.ExpireThreadRuntime(t.Context(), projectID, threadID, DefaultWarmTTLSeconds) require.NoError(t, err, "retry must drive the existing expiring row to a terminal state") @@ -420,7 +420,7 @@ func TestServiceCoreReapStuckRuntimesCleansUpStuckExpiring(t *testing.T) { }) require.NoError(t, err) - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) result, err := core.ReapStuckRuntimes(t.Context()) require.NoError(t, err) @@ -465,7 +465,7 @@ func TestServiceCoreReapStuckRuntimesLeavesFreshExpiring(t *testing.T) { }) require.NoError(t, err) - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, testRuntimeBackend{backend: runtimeBackendFlyIO}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, testRuntimeBackend{backend: runtimeBackendFlyIO}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) result, err := core.ReapStuckRuntimes(t.Context()) require.NoError(t, err) @@ -517,7 +517,7 @@ func TestServiceCoreReapStuckRuntimesSkipsLiveProcessingLease(t *testing.T) { }) require.NoError(t, err) - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, testRuntimeBackend{backend: 
runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) result, err := core.ReapStuckRuntimes(ctx) require.NoError(t, err) @@ -574,7 +574,7 @@ func TestServiceCoreReapStuckRuntimesReclaimsStaleProcessingLease(t *testing.T) }) require.NoError(t, err) - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) result, err := core.ReapStuckRuntimes(ctx) require.NoError(t, err) @@ -700,7 +700,7 @@ func TestServiceCoreLoadChatHistoryReplaysToolTurns(t *testing.T) { require.NoError(t, err) } - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) history, err := core.loadChatHistory(t.Context(), chatID, projectID) require.NoError(t, err) @@ -782,7 +782,7 @@ func TestServiceCoreLoadChatHistoryReturnsOnlyLatestGeneration(t *testing.T) { require.NoError(t, err) } - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := 
NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) history, err := core.loadChatHistory(t.Context(), chatID, projectID) require.NoError(t, err) @@ -827,7 +827,7 @@ func TestServiceCoreLoadChatHistoryFailsWhenToolRowMissingCallID(t *testing.T) { }) require.NoError(t, err) - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) _, err = core.loadChatHistory(ctx, chatID, projectID) require.ErrorContains(t, err, "tool chat row missing tool_call_id") @@ -843,7 +843,7 @@ func TestServiceCoreProcessThreadEventsCompletesEvent(t *testing.T) { logger := testenv.NewLogger(t) tokens := assistanttokens.New("test-jwt-secret", conn, nil) - core := NewServiceCore(logger, testenv.NewTracerProvider(t), conn, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, tokens, nil, telemetry.NewStub(logger), nil) + core := NewServiceCore(logger, testenv.NewTracerProvider(t), conn, nil, nil, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, tokens, nil, telemetry.NewStub(logger), nil) admitted, err := core.AdmitPendingThreads(t.Context(), assistantID) require.NoError(t, err) @@ -876,7 +876,7 @@ func TestServiceCoreProcessThreadEventsRequeuesOnTurnFailure(t *testing.T) { logger := testenv.NewLogger(t) tokens := assistanttokens.New("test-jwt-secret", conn, nil) backend := testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: errors.New("runtime RunTurn blew up")} - core := NewServiceCore(logger, 
testenv.NewTracerProvider(t), conn, backend, nil, tokens, nil, telemetry.NewStub(logger), nil) + core := NewServiceCore(logger, testenv.NewTracerProvider(t), conn, nil, nil, backend, nil, tokens, nil, telemetry.NewStub(logger), nil) admitted, err := core.AdmitPendingThreads(t.Context(), assistantID) require.NoError(t, err) @@ -912,7 +912,7 @@ func TestServiceCoreProcessThreadEventsMarksRuntimeFailedOnUnhealthyTurn(t *test runTurnErr: ErrRuntimeUnhealthy, stopCalls: &stopCalls, } - core := NewServiceCore(logger, testenv.NewTracerProvider(t), conn, backend, nil, tokens, mustParseURLForServiceTest(t, "https://gram.example.com"), telemetry.NewStub(logger), nil) + core := NewServiceCore(logger, testenv.NewTracerProvider(t), conn, nil, nil, backend, nil, tokens, mustParseURLForServiceTest(t, "https://gram.example.com"), telemetry.NewStub(logger), nil) admitted, err := core.AdmitPendingThreads(t.Context(), assistantID) require.NoError(t, err) @@ -1039,7 +1039,7 @@ func TestServiceCoreDeleteAssistantReapsRuntimes(t *testing.T) { reapCalls := &atomic.Int64{} backend := testRuntimeBackend{backend: runtimeBackendFlyIO, reapCalls: reapCalls} - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, backend, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, backend, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) require.NoError(t, core.DeleteAssistant(t.Context(), projectID, assistantID)) require.EqualValues(t, 1, reapCalls.Load()) @@ -1075,7 +1075,7 @@ func TestServiceCoreDeleteAssistantSucceedsEvenWhenReapErrors(t *testing.T) { reapCalls: reapCalls, reapErr: errors.New("fly api 503"), } - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, backend, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, 
backend, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) require.NoError(t, core.DeleteAssistant(t.Context(), projectID, assistantID)) require.EqualValues(t, 1, reapCalls.Load()) @@ -1101,7 +1101,7 @@ func TestServiceCoreReapAssistantRuntimesCallsBackendAndClearsMetadata(t *testin reapCalls := &atomic.Int64{} backend := testRuntimeBackend{backend: runtimeBackendFlyIO, reapCalls: reapCalls} - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, backend, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, backend, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) result, err := core.ReapAssistantRuntimes(t.Context(), projectID, assistantID) require.NoError(t, err) @@ -1144,7 +1144,7 @@ func TestServiceCoreReapAssistantRuntimesSkipsRowsWithoutMetadata(t *testing.T) reapCalls := &atomic.Int64{} backend := testRuntimeBackend{backend: runtimeBackendFlyIO, reapCalls: reapCalls} - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, backend, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, backend, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) result, err := core.ReapAssistantRuntimes(t.Context(), projectID, assistantID) require.NoError(t, err) @@ -1166,7 +1166,7 @@ func TestServiceCoreReapInactiveAssistantRuntimesCollectsOnlyInactive(t *testing reapCalls := &atomic.Int64{} backend := testRuntimeBackend{backend: runtimeBackendFlyIO, reapCalls: reapCalls} - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, backend, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, backend, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) result, err 
:= core.ReapInactiveAssistantRuntimes(t.Context(), ReapInactiveAssistantRuntimesParams{ InactivityThreshold: 7 * 24 * time.Hour, @@ -1200,7 +1200,7 @@ func TestServiceCoreReapInactiveAssistantRuntimesSkipsAssistantWithRecentActivit reapCalls := &atomic.Int64{} backend := testRuntimeBackend{backend: runtimeBackendFlyIO, reapCalls: reapCalls} - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, backend, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, backend, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) result, err := core.ReapInactiveAssistantRuntimes(t.Context(), ReapInactiveAssistantRuntimesParams{ InactivityThreshold: 7 * 24 * time.Hour, @@ -1250,7 +1250,7 @@ func TestServiceCoreReapInactiveAssistantRuntimesReapsSiblingsAcrossSweeps(t *te reapCalls := &atomic.Int64{} backend := testRuntimeBackend{backend: runtimeBackendFlyIO, reapCalls: reapCalls} - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, backend, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, backend, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) first, err := core.ReapInactiveAssistantRuntimes(t.Context(), ReapInactiveAssistantRuntimesParams{ InactivityThreshold: 7 * 24 * time.Hour, @@ -1283,7 +1283,7 @@ func TestServiceCoreReapInactiveAssistantRuntimesReapsStaleRowsRegardlessOfState reapCalls := &atomic.Int64{} backend := testRuntimeBackend{backend: runtimeBackendFlyIO, reapCalls: reapCalls} - core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, backend, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) + core := NewServiceCore(testenv.NewLogger(t), testenv.NewTracerProvider(t), conn, nil, nil, backend, nil, nil, nil, telemetry.NewStub(testenv.NewLogger(t)), nil) 
result, err := core.ReapInactiveAssistantRuntimes(t.Context(), ReapInactiveAssistantRuntimesParams{ InactivityThreshold: 7 * 24 * time.Hour, @@ -1366,7 +1366,7 @@ func TestServiceCoreEnqueueTriggerTaskSkipsMissingAssistant(t *testing.T) { logger := testenv.NewLogger(t) tokens := assistanttokens.New("test-jwt-secret", conn, nil) - core := NewServiceCore(logger, testenv.NewTracerProvider(t), conn, testRuntimeBackend{backend: runtimeBackendFlyIO}, nil, tokens, nil, telemetry.NewStub(logger), nil) + core := NewServiceCore(logger, testenv.NewTracerProvider(t), conn, nil, nil, testRuntimeBackend{backend: runtimeBackendFlyIO}, nil, tokens, nil, telemetry.NewStub(logger), nil) missing := uuid.New() result, err := core.EnqueueTriggerTask(t.Context(), bgtriggers.Task{ diff --git a/server/internal/auth/assistanttokens/manager.go b/server/internal/auth/assistanttokens/manager.go index 36a4abef78..08e6539d86 100644 --- a/server/internal/auth/assistanttokens/manager.go +++ b/server/internal/auth/assistanttokens/manager.go @@ -46,6 +46,42 @@ type Claims struct { jwt.RegisteredClaims } +const mcpAuthFlowIssuer = "gram-assistants-mcp-auth-flow" + +type MCPAuthFlowInput struct { + OrgID string + ProjectID uuid.UUID + UserID string + AssistantID uuid.UUID + ThreadID uuid.UUID + FlowID string + ServerID string + McpURL string + ClientID string + ClientSecret string + RedirectURI string + CodeVerifier string + TokenEndpoint string + TTL time.Duration +} + +type MCPAuthFlowClaims struct { + OrgID string `json:"org_id"` + ProjectID string `json:"project_id"` + UserID string `json:"user_id"` + AssistantID string `json:"assistant_id"` + ThreadID string `json:"thread_id"` + FlowID string `json:"flow_id"` + ServerID string `json:"server_id"` + McpURL string `json:"mcp_url"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret,omitempty"` + RedirectURI string `json:"redirect_uri"` + CodeVerifier string `json:"code_verifier"` + TokenEndpoint string 
`json:"token_endpoint"` + jwt.RegisteredClaims +} + type GenerateInput struct { OrgID string ProjectID uuid.UUID @@ -114,6 +150,94 @@ func (m *Manager) Generate(input GenerateInput) (string, error) { return signed, nil } +func (m *Manager) GenerateMCPAuthFlow(input MCPAuthFlowInput) (string, error) { + now := time.Now() + ttl := input.TTL + if ttl <= 0 { + ttl = 15 * time.Minute + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, MCPAuthFlowClaims{ + OrgID: input.OrgID, + ProjectID: input.ProjectID.String(), + UserID: input.UserID, + AssistantID: input.AssistantID.String(), + ThreadID: input.ThreadID.String(), + FlowID: input.FlowID, + ServerID: input.ServerID, + McpURL: input.McpURL, + ClientID: input.ClientID, + ClientSecret: input.ClientSecret, + RedirectURI: input.RedirectURI, + CodeVerifier: input.CodeVerifier, + TokenEndpoint: input.TokenEndpoint, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: mcpAuthFlowIssuer, + Subject: input.AssistantID.String(), + Audience: nil, + ExpiresAt: jwt.NewNumericDate(now.Add(ttl)), + NotBefore: nil, + IssuedAt: jwt.NewNumericDate(now), + ID: input.FlowID, + }, + }) + + signed, err := token.SignedString([]byte(m.jwtSecret)) + if err != nil { + return "", fmt.Errorf("sign mcp auth flow token: %w", err) + } + return signed, nil +} + +func (m *Manager) ValidateMCPAuthFlow(tokenString string) (*MCPAuthFlowClaims, error) { + tokenString = strings.TrimSpace(tokenString) + if tokenString == "" { + return nil, oops.C(oops.CodeUnauthorized) + } + + token, err := jwt.ParseWithClaims(tokenString, &MCPAuthFlowClaims{ + OrgID: "", + ProjectID: "", + UserID: "", + AssistantID: "", + ThreadID: "", + FlowID: "", + ServerID: "", + McpURL: "", + ClientID: "", + ClientSecret: "", + RedirectURI: "", + CodeVerifier: "", + TokenEndpoint: "", + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: "", + Subject: "", + Audience: nil, + ExpiresAt: nil, + NotBefore: nil, + IssuedAt: nil, + ID: "", + }, + }, func(token *jwt.Token) (any, error) { + 
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(m.jwtSecret), nil + }) + if err != nil { + return nil, oops.E(oops.CodeUnauthorized, err, "invalid mcp auth flow token") + } + + claims, ok := token.Claims.(*MCPAuthFlowClaims) + if !ok || !token.Valid { + return nil, oops.E(oops.CodeUnauthorized, nil, "invalid mcp auth flow token") + } + if claims.Issuer != mcpAuthFlowIssuer { + return nil, oops.E(oops.CodeUnauthorized, nil, "invalid mcp auth flow token issuer") + } + return claims, nil +} + func (m *Manager) Validate(tokenString string) (*Claims, error) { tokenString = strings.TrimSpace(tokenString) if strings.HasPrefix(strings.ToLower(tokenString), "bearer ") {