Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/assistants-mcp-oauth.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"server": minor
---

Assistants can now authenticate with OAuth-protected MCP servers. When a configured MCP server requires user authentication, the assistant relays the authorization link through an available output tool; once the user completes authentication, the assistant reconnects and continues its task.
56 changes: 55 additions & 1 deletion agents/runner/src/gram_client.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::time::Duration;

use reqwest_middleware::ClientWithMiddleware;
use serde::Serialize;
use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::http_layer::TokenRegistry;
use crate::wire::ThreadBootstrap;

const BOOTSTRAP_PATH: &str = "/rpc/assistants.getThreadBootstrap";
const CREATE_MCP_AUTH_FLOW_PATH: &str = "/rpc/assistantMcpAuth.create";
const BOOTSTRAP_TIMEOUT: Duration = Duration::from_secs(15);

/// Lightweight client used by the runner to pull a per-thread bootstrap
Expand Down Expand Up @@ -44,6 +45,20 @@ struct BootstrapRequest<'a> {
thread_id: &'a str,
}

#[derive(Serialize)]
struct CreateMcpAuthFlowRequest<'a> {
thread_id: &'a str,
server_id: &'a str,
url: &'a str,
}

#[derive(Debug, Deserialize)]
pub struct CreateMcpAuthFlowResponse {
pub server_id: String,
pub mcp_slug: String,
pub auth_url: String,
}

impl GramBootstrapClient {
pub fn new(base_url: String, http: ClientWithMiddleware) -> Self {
Self { base_url, http }
Expand Down Expand Up @@ -80,4 +95,43 @@ impl GramBootstrapClient {
let bootstrap: ThreadBootstrap = serde_json::from_str(&body)?;
Ok(bootstrap)
}

pub async fn create_mcp_auth_flow(
&self,
thread_id: &str,
server_id: &str,
url: &str,
tokens: &TokenRegistry,
) -> Result<CreateMcpAuthFlowResponse, GramClientError> {
let endpoint = format!(
"{}{}",
self.base_url.trim_end_matches('/'),
CREATE_MCP_AUTH_FLOW_PATH
);
let bearer = tokens.current().map_err(|_| GramClientError::Token)?;

let resp = self
.http
.post(&endpoint)
.timeout(BOOTSTRAP_TIMEOUT)
.bearer_auth(&bearer)
.json(&CreateMcpAuthFlowRequest {
thread_id,
server_id,
url,
})
.send()
.await?;

let status = resp.status();
let body = resp.text().await?;
if !status.is_success() {
return Err(GramClientError::Status {
status: status.as_u16(),
body,
});
}
let flow: CreateMcpAuthFlowResponse = serde_json::from_str(&body)?;
Ok(flow)
}
}
70 changes: 54 additions & 16 deletions agents/runner/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use agentkit_loop::{
PromptCacheRetention, SessionConfig,
};
use agentkit_mcp::{
McpServerConfig, McpServerId, McpServerManager, McpTransportBinding,
McpError, McpServerConfig, McpServerId, McpServerManager, McpTransportBinding,
StreamableHttpTransportConfig,
};
use agentkit_provider_openrouter::{OpenRouterConfig, OpenRouterProvider};
Expand Down Expand Up @@ -278,7 +278,8 @@ async fn spawn_thread(
bootstrap: ThreadBootstrap,
tokens: TokenRegistry,
) -> Result<Arc<ConfiguredThread>, RunnerError> {
let (mcp_cmd_tx, mcp_catalog) = build_thread_mcp(host, &bootstrap.mcp_servers, &tokens).await?;
let (mcp_cmd_tx, mcp_catalog, mcp_auth_notices) =
build_thread_mcp(host, &thread_id, &bootstrap.mcp_servers, &tokens).await?;

let chat_id = bootstrap.chat_id.clone();

Expand Down Expand Up @@ -345,6 +346,9 @@ async fn spawn_thread(
transcript.push(Item::text(ItemKind::System, &bootstrap.instructions));
}
transcript.extend(normalize_history(&bootstrap.history)?);
for notice in mcp_auth_notices {
transcript.push(Item::text(ItemKind::User, &notice));
}

let permissions = CompositePermissionChecker::new(PermissionDecision::Allow).with_policy(
PathPolicy::new()
Expand All @@ -358,10 +362,7 @@ async fn spawn_thread(

let mcp_server_ids: Vec<String> = bootstrap.mcp_servers.iter().map(|s| s.id.clone()).collect();
let native_tools = ToolRegistry::new().with(tools::bun_run::bun_run).with(
tools::mcp_force_reconnect::McpForceReconnectTool::new(
Arc::clone(host),
mcp_server_ids,
),
tools::mcp_force_reconnect::McpForceReconnectTool::new(Arc::clone(host), mcp_server_ids),
);

let mcp_source = ClippedToolSource::new(mcp_catalog, host.spill_root.clone());
Expand Down Expand Up @@ -436,29 +437,56 @@ async fn spawn_thread(

async fn build_thread_mcp(
host: &Arc<RuntimeHost>,
thread_id: &str,
servers: &[McpServer],
tokens: &TokenRegistry,
) -> Result<(mpsc::Sender<McpCmd>, CatalogReader), RunnerError> {
) -> Result<(mpsc::Sender<McpCmd>, CatalogReader, Vec<String>), RunnerError> {
let mut manager = McpServerManager::new();
let catalog = manager.source();
let mut auth_notices = Vec::new();

for server in servers {
let config = build_mcp_server_config(server, &host.http_client, tokens)?;
let server_id = McpServerId::new(server.id.clone());
manager.register_server(config);
let _ = connect_and_log(&mut manager, &server_id, "register").await;
if let Err(err) = connect_and_log(&mut manager, &server_id, "register").await
&& err.auth_required
{
match host
.gram_client
.create_mcp_auth_flow(thread_id, &server.id, &server.url, tokens)
.await
{
Ok(flow) => auth_notices.push(format!(
"<message-context>\nEventType: assistant_mcp_auth_required\nMCPServerID: {server_id}\nMCPSlug: {mcp_slug}\nAuthURL: {auth_url}\n</message-context>",
server_id = flow.server_id,
mcp_slug = flow.mcp_slug,
auth_url = flow.auth_url,
)),
Err(flow_err) => tracing::warn!(
server_id = %server_id,
error = %flow_err,
"failed to create assistant mcp auth flow"
),
}
}
}

let (cmd_tx, cmd_rx) = mpsc::channel(MCP_CMD_CAPACITY);
tokio::spawn(run_mcp_actor(manager, cmd_rx));
Ok((cmd_tx, catalog))
Ok((cmd_tx, catalog, auth_notices))
}

struct McpConnectFailure {
message: String,
auth_required: bool,
}

async fn connect_and_log(
manager: &mut McpServerManager,
server_id: &McpServerId,
action: &'static str,
) -> Result<(), String> {
) -> Result<(), McpConnectFailure> {
match manager.connect_server(server_id).await {
Ok(handle) => {
tracing::info!(
Expand All @@ -470,8 +498,12 @@ async fn connect_and_log(
Ok(())
}
Err(e) => {
let auth_required = matches!(e, McpError::AuthRequired(_));
tracing::warn!(server_id = %server_id, error = %e, action, "mcp connect failed");
Err(e.to_string())
Err(McpConnectFailure {
message: e.to_string(),
auth_required,
})
}
}
}
Expand Down Expand Up @@ -517,7 +549,9 @@ async fn run_mcp_actor(mut manager: McpServerManager, mut cmd_rx: mpsc::Receiver
if let Err(e) = manager.disconnect_server(&server_id).await {
tracing::debug!(server_id = %server_id, error = %e, "disconnect during force reconnect");
}
let result = connect_and_log(&mut manager, &server_id, "force_reconnect").await;
let result = connect_and_log(&mut manager, &server_id, "force_reconnect")
.await
.map_err(|err| err.message);
let _ = reply.send(result);
}
}
Expand Down Expand Up @@ -737,10 +771,14 @@ mod tests {
async fn evict_thread_clears_seen_keys_with_prefix() {
let host = empty_host();
insert_thread(&host, "T", Some(Instant::now()));
host.seen
.insert("T:evt-1".to_string(), Arc::new(tokio::sync::Mutex::new(true)));
host.seen
.insert("T:evt-2".to_string(), Arc::new(tokio::sync::Mutex::new(true)));
host.seen.insert(
"T:evt-1".to_string(),
Arc::new(tokio::sync::Mutex::new(true)),
);
host.seen.insert(
"T:evt-2".to_string(),
Arc::new(tokio::sync::Mutex::new(true)),
);
host.seen.insert(
"other:evt-1".to_string(),
Arc::new(tokio::sync::Mutex::new(true)),
Expand Down
2 changes: 1 addition & 1 deletion server/cmd/gram/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ func newStartCommand() *cli.Command {
)
contextWindowResolver := openrouter.NewContextWindowResolver(logger, guardianPolicy, cache.NewRedisCacheAdapter(redisClient))
chatService := chat.NewService(logger, tracerProvider, db, sessionManager, chatSessionsManager, openRouter, chatClient, contextWindowResolver, posthogClient, telemSvc, assetStorage, authzEngine, assistantTokenManager, billingRepo)
assistantsCore := assistants.NewServiceCore(logger, tracerProvider, db, assistantRuntime, slackClient, assistantTokenManager, serverURL, telemLogger, contextWindowResolver)
assistantsCore := assistants.NewServiceCore(logger, tracerProvider, db, guardianPolicy, encryptionClient, assistantRuntime, slackClient, assistantTokenManager, serverURL, telemLogger, contextWindowResolver)
assistantsCore.SetWakeCanceller(triggerApp)
assistantsCore.SetChatMessageWriter(chatWriter)
assistantsSvc := assistants.NewService(logger, tracerProvider, db, sessionManager, authzEngine, assistantsCore, &background.AssistantWorkflowSignaler{TemporalEnv: temporalEnv})
Expand Down
2 changes: 1 addition & 1 deletion server/cmd/gram/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ func newWorkerCommand() *cli.Command {
return err
}
contextWindowResolver := openrouter.NewContextWindowResolver(logger, guardianPolicy, cache.NewRedisCacheAdapter(redisClient))
assistantsCore := assistants.NewServiceCore(logger, tracerProvider, db, assistantRuntime, slackClient, assistantTokenManager, serverURL, telemetryLogger, contextWindowResolver)
assistantsCore := assistants.NewServiceCore(logger, tracerProvider, db, guardianPolicy, encryptionClient, assistantRuntime, slackClient, assistantTokenManager, serverURL, telemetryLogger, contextWindowResolver)
assistantsCore.SetWakeCanceller(triggerApp)
assistantsCore.SetChatMessageWriter(chatWriter)
assistantsSvc := assistants.NewService(logger, tracerProvider, db, sessionManager, authzEngine, assistantsCore, &background.AssistantWorkflowSignaler{TemporalEnv: temporalEnv})
Expand Down
2 changes: 2 additions & 0 deletions server/internal/assistants/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ func Attach(mux goahttp.Muxer, service *Service) {
srv.New(endpoints, mux, goahttp.RequestDecoder, goahttp.ResponseEncoder, nil, nil),
)
o11y.AttachHandler(mux, "POST", "/rpc/assistants.getThreadBootstrap", oops.ErrHandle(service.logger, service.handleGetThreadBootstrap).ServeHTTP)
o11y.AttachHandler(mux, "POST", "/rpc/assistantMcpAuth.create", oops.ErrHandle(service.logger, service.handleCreateMCPAuthFlow).ServeHTTP)
o11y.AttachHandler(mux, "GET", "/rpc/assistantMcpAuth/{id}/oauth/callback", oops.ErrHandle(service.logger, service.handleMCPAuthCallback).ServeHTTP)
}

func (s *Service) APIKeyAuth(ctx context.Context, key string, schema *security.APIKeyScheme) (context.Context, error) {
Expand Down
2 changes: 1 addition & 1 deletion server/internal/assistants/impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ func newRBACServiceWithConn(t *testing.T, dbName string) (*Service, context.Cont
logger: logger,
auth: nil,
authz: authzEngine,
core: NewServiceCore(logger, testenv.NewTracerProvider(t), conn, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(logger), nil),
core: NewServiceCore(logger, testenv.NewTracerProvider(t), conn, nil, nil, testRuntimeBackend{backend: runtimeBackendFlyIO, runTurnErr: nil}, nil, nil, nil, telemetry.NewStub(logger), nil),
signaler: nil,
}

Expand Down
Loading
Loading