diff --git a/crates/agentgateway/src/app.rs b/crates/agentgateway/src/app.rs index 655d33fd3..876fd3947 100644 --- a/crates/agentgateway/src/app.rs +++ b/crates/agentgateway/src/app.rs @@ -109,7 +109,10 @@ pub async fn run(config: Arc) -> anyhow::Result { .await .context("admin server starts")?; #[cfg(feature = "ui")] - admin_server.set_admin_handler(Arc::new(crate::ui::UiHandler::new(config.clone()))); + admin_server.set_admin_handler(Arc::new(crate::ui::UiHandler::new( + config.clone(), + stores.guard_registry.clone(), + ))); #[cfg(feature = "ui")] info!("serving UI at http://{}/ui", config.admin_addr); diff --git a/crates/agentgateway/src/mcp/handler.rs b/crates/agentgateway/src/mcp/handler.rs index af5c89630..c5088568e 100644 --- a/crates/agentgateway/src/mcp/handler.rs +++ b/crates/agentgateway/src/mcp/handler.rs @@ -171,6 +171,39 @@ impl Relay { tracing::info!("Establishing security guard baselines for all upstreams"); for (server_name, upstream) in self.upstreams.iter_named() { + // Evaluate connection phase guards (whitelist, typosquat detection) + let context = crate::mcp::security::GuardContext { + server_name: server_name.to_string(), + identity: None, + metadata: serde_json::Value::Null, + }; + match self + .security_guards + .evaluate_connection(&server_name, None, &context) + { + Ok(crate::mcp::security::GuardDecision::Allow) => { + tracing::info!(server = %server_name, "Connection guard: allowed"); + }, + Ok(crate::mcp::security::GuardDecision::Deny(reason)) => { + tracing::warn!( + server = %server_name, + code = %reason.code, + message = %reason.message, + "Connection guard: BLOCKED server" + ); + continue; // Skip this upstream entirely + }, + Ok(_) => {}, + Err(e) => { + tracing::error!( + server = %server_name, + error = %e, + "Connection guard: error" + ); + continue; // Skip on error (fail closed) + }, + } + // Create a tools/list request let request = JsonRpcRequest { jsonrpc: Default::default(), @@ -263,6 +296,12 @@ impl Relay { // Process each server's tools individually for security guard evaluation for (server_name, s) in streams.into_iter() { + let context = crate::mcp::security::GuardContext { + server_name: server_name.to_string(), + identity: None, + metadata: serde_json::Value::Null, + }; + let tools = match s { ServerResult::ListToolsResult(ltr) => ltr.tools, _ => vec![], @@ -270,11 +309,6 @@ impl Relay { // Execute security guards on this server's tools list BEFORE merging // This ensures baselines are stored per-server, not under "merged" - let context = crate::mcp::security::GuardContext { - server_name: server_name.to_string(), - identity: None, - metadata: serde_json::Value::Null, - }; match security_guards.evaluate_tools_list(&tools, &context) { Ok(crate::mcp::security::GuardDecision::Allow) => { diff --git a/crates/agentgateway/src/mcp/security/mod.rs b/crates/agentgateway/src/mcp/security/mod.rs index 27f1b2b89..f5d84e144 100644 --- a/crates/agentgateway/src/mcp/security/mod.rs +++ b/crates/agentgateway/src/mcp/security/mod.rs @@ -98,6 +98,10 @@ pub enum McpGuardKind { #[serde(rename_all = "snake_case")] #[derive(Default)] pub enum GuardPhase { + /// Before establishing connection to MCP server + /// Used for server whitelisting, typosquat detection, TLS validation + Connection, + /// Before forwarding client request to MCP server #[default] Request, @@ -297,6 +301,30 @@ impl GuardExecutorRegistry { let executors = self.executors.read().expect("registry lock poisoned"); executors.keys().cloned().collect() } + + /// Collect schemas from all WASM guards across all backends. + /// Returns a map of guard_id -> (settings_schema_json, default_config_json). + pub fn collect_wasm_schemas(&self) -> HashMap { + let executors = self.executors.read().expect("registry lock poisoned"); + let mut schemas = HashMap::new(); + + for (_backend_name, executor) in executors.iter() { + for entry in executor.collect_guard_schemas() { + schemas.insert(entry.0, entry.1); + } + } + + schemas + } +} + +/// Schema information returned by a WASM guard +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WasmGuardSchema { + /// JSON Schema describing guard's configurable parameters + pub settings_schema: serde_json::Value, + /// Default configuration values + pub default_config: serde_json::Value, } /// Guard executor that manages and executes security guards in priority order @@ -390,6 +418,64 @@ impl GuardExecutor { Ok(()) } + /// Execute guards before establishing connection to an MCP server + /// Used for server whitelisting, typosquat detection, TLS validation + pub fn evaluate_connection( + &self, + server_name: &str, + server_url: Option<&str>, + context: &GuardContext, + ) -> GuardResult { + let guards = self.guards.read().expect("guards lock poisoned"); + tracing::info!( + guard_count = guards.len(), + server = %server_name, + server_url = ?server_url, + "GuardExecutor::evaluate_connection called" + ); + for guard_entry in guards.iter() { + // Only run guards configured for Connection phase + if !guard_entry.config.runs_on.contains(&GuardPhase::Connection) { + continue; + } + + // Execute guard with timeout + let result = self.execute_with_timeout( + || { + guard_entry + .guard + .evaluate_connection(server_name, server_url, context) + }, + Duration::from_millis(guard_entry.config.timeout_ms), + &guard_entry.config, + ); + + // Handle result based on failure mode + match result { + Ok(GuardDecision::Allow) => continue, + Ok(decision) => return Ok(decision), + Err(e) => match guard_entry.config.failure_mode { + FailureMode::FailClosed => { + return Err(GuardError::ExecutionError(format!( + "Guard {} failed: {}", + guard_entry.config.id, e + ))); + }, + FailureMode::FailOpen => { + tracing::warn!( + "Guard {} failed but continuing due to fail_open: {}", + guard_entry.config.id, + e + ); + continue; + }, + }, + } + } + + Ok(GuardDecision::Allow) + } + /// Execute guards on a tools/list response pub fn evaluate_tools_list( &self, @@ -575,6 +661,36 @@ impl GuardExecutor { f() } + /// Collect schemas from guards that support dynamic schema export (WASM guards). + /// Returns a list of (guard_id, WasmGuardSchema) pairs. + pub fn collect_guard_schemas(&self) -> Vec<(String, WasmGuardSchema)> { + let guards = self.guards.read().expect("guards lock poisoned"); + let mut schemas = Vec::new(); + + for guard_entry in guards.iter() { + if let Some(schema_json) = guard_entry.guard.get_settings_schema() { + let settings_schema: serde_json::Value = + serde_json::from_str(&schema_json).unwrap_or(serde_json::Value::Null); + + let default_config: serde_json::Value = guard_entry + .guard + .get_default_config() + .and_then(|s| serde_json::from_str(&s).ok()) + .unwrap_or(serde_json::Value::Object(serde_json::Map::new())); + + schemas.push(( + guard_entry.config.id.clone(), + WasmGuardSchema { + settings_schema, + default_config, + }, + )); + } + } + + schemas + } + /// Reset state for a server (called on session re-initialization) /// This clears any per-server state like baselines in guards. pub fn reset_server(&self, server_name: &str) { diff --git a/crates/agentgateway/src/mcp/security/native/mod.rs b/crates/agentgateway/src/mcp/security/native/mod.rs index c44d57d8a..687c565c2 100644 --- a/crates/agentgateway/src/mcp/security/native/mod.rs +++ b/crates/agentgateway/src/mcp/security/native/mod.rs @@ -21,6 +21,19 @@ use super::{GuardContext, GuardDecision, GuardResult}; /// Common trait for all native guards pub trait NativeGuard: Send + Sync { + /// Evaluate before establishing connection to an MCP server + /// Used for server whitelisting, typosquat detection, TLS validation + fn evaluate_connection( + &self, + server_name: &str, + server_url: Option<&str>, + context: &GuardContext, + ) -> GuardResult { + // Default: allow + let _ = (server_name, server_url, context); + Ok(GuardDecision::Allow) + } + /// Evaluate a tools/list response fn evaluate_tools_list(&self, tools: &[rmcp::model::Tool], context: &GuardContext) -> GuardResult; @@ -70,6 +83,20 @@ pub trait NativeGuard: Send + Sync { // Default: no-op (most guards are stateless) let _ = server_name; } + + /// Get JSON Schema describing this guard's configurable parameters. + /// Returns None for native guards (schemas are embedded in the UI). + /// WASM guards override this to call the guest module's get-settings-schema. + fn get_settings_schema(&self) -> Option { + None + } + + /// Get default configuration as JSON. + /// Returns None for native guards. + /// WASM guards override this to call the guest module's get-default-config. + fn get_default_config(&self) -> Option { + None + } } /// Helper: Build regex set from patterns diff --git a/crates/agentgateway/src/mcp/security/wasm.rs b/crates/agentgateway/src/mcp/security/wasm.rs index 16244a12c..8e5717767 100644 --- a/crates/agentgateway/src/mcp/security/wasm.rs +++ b/crates/agentgateway/src/mcp/security/wasm.rs @@ -32,6 +32,13 @@ pub struct WasmGuardConfig { #[serde(default = "default_max_memory")] pub max_memory: usize, + /// Maximum WebAssembly stack size (bytes). + /// Python WASM components require significantly more stack space (2-4 MB) + /// due to the embedded Python interpreter. + /// Default: 2 MB (sufficient for most Python guards) + #[serde(default = "default_max_wasm_stack")] + pub max_wasm_stack: usize, + /// Timeout for guard execution (milliseconds) #[serde(default = "default_timeout_ms")] pub timeout_ms: u64, @@ -45,10 +52,36 @@ fn default_max_memory() -> usize { 10 * 1024 * 1024 // 10 MB } +fn default_max_wasm_stack() -> usize { + 2 * 1024 * 1024 // 2 MB - sufficient for Python WASM guards +} + fn default_timeout_ms() -> u64 { 100 } +/// Run a closure on a thread with a large stack. +/// Python WASM components require significant native stack space that exceeds +/// the default thread stack size, especially on Windows where the main thread +/// stack cannot be grown dynamically. +/// Uses scoped threads to avoid 'static lifetime requirements. +#[cfg(feature = "wasm-guards")] +fn run_with_large_stack(stack_size: usize, f: F) -> T +where + F: FnOnce() -> T + Send, + T: Send, +{ + std::thread::scope(|scope| { + scope + .spawn(|| { + // Grow the stack on this thread before executing + stacker::grow(stack_size, f) + }) + .join() + .expect("WASM thread panicked") + }) +} + /// State stored in the wasmtime Store for host functions #[cfg(feature = "wasm-guards")] struct WasmState { @@ -121,13 +154,23 @@ impl WasmGuard { // Configure wasmtime engine let mut engine_config = Config::new(); engine_config.wasm_component_model(true); + // Set maximum WASM stack size - Python WASM components require larger stacks + // due to the embedded interpreter + engine_config.max_wasm_stack(config.max_wasm_stack); let engine = Engine::new(&engine_config) .map_err(|e| GuardError::WasmError(format!("Failed to create wasmtime engine: {}", e)))?; // Load and compile the WASM component - let component = Component::from_file(&engine, expanded_path.as_ref()) - .map_err(|e| GuardError::WasmError(format!("Failed to load WASM component: {}", e)))?; + // Python WASM components require significant native stack space during compilation + // due to the embedded interpreter. On Windows, the main thread stack cannot be grown, + // so we spawn a dedicated thread with a large stack (8MB) for compilation. + let path_for_thread = expanded_path.to_string(); + let engine_clone = engine.clone(); + let component = run_with_large_stack(8 * 1024 * 1024, move || { + Component::from_file(&engine_clone, &path_for_thread) + }) + .map_err(|e| GuardError::WasmError(format!("Failed to load WASM component: {}", e)))?; tracing::info!( guard_id = %guard_id, @@ -274,6 +317,20 @@ impl WasmGuard { ))) } }, + "warn" => { + // Warn means allow but log the warnings + if let Some(Val::List(warnings)) = payload.as_deref() { + for warning in warnings { + if let Val::String(msg) = warning { + tracing::warn!( + warning = %msg, + "WASM guard returned warning" + ); + } + } + } + Ok(GuardDecision::Allow) + }, _ => Err(GuardError::WasmError(format!( "Unknown decision variant: {}", name @@ -330,7 +387,7 @@ impl WasmGuard { } } - /// Execute the guard with timeout protection + /// Execute the guard with timeout protection and sufficient stack space fn execute_with_timeout(&self, f: F) -> GuardResult where F: FnOnce() -> GuardResult, @@ -338,7 +395,10 @@ impl WasmGuard { // For synchronous execution, we use a simple approach // In production, this could be enhanced with proper async timeout let start = std::time::Instant::now(); - let result = f(); + // Python WASM components require significant native stack space due to the + // embedded interpreter. Use stacker to grow the native stack when needed. + // Use stacker::grow to force allocation of a large stack segment (8MB). + let result = stacker::grow(8 * 1024 * 1024, f); let elapsed = start.elapsed(); if elapsed.as_millis() as u64 > self.config.timeout_ms { @@ -352,6 +412,69 @@ impl WasmGuard { result } + + /// Call a no-argument WASM function that returns a string. + /// Used for get-settings-schema and get-default-config. + fn call_string_func(&self, func_name: &str) -> Result { + stacker::grow(8 * 1024 * 1024, || { + let linker = self.create_linker()?; + let state = WasmState::new(self.config.config.clone()); + let mut store = Store::new(&self.engine, state); + + let instance = linker + .instantiate(&mut store, &self.component) + .map_err(|e| GuardError::WasmError(format!("Failed to instantiate component: {}", e)))?; + + let guard_export_idx = instance + .get_export(&mut store, None, "mcp:security-guard/guard@0.1.0") + .ok_or_else(|| { + GuardError::WasmError("Guard interface not found in component exports".to_string()) + })?; + + let func_export_idx = instance + .get_export(&mut store, Some(&guard_export_idx), func_name) + .ok_or_else(|| { + GuardError::WasmError(format!( + "Function {} not found in guard interface", + func_name + )) + })?; + + let func = instance + .get_func(&mut store, &func_export_idx) + .ok_or_else(|| { + GuardError::WasmError("Could not get function from export index".to_string()) + })?; + + let mut results = vec![Val::Bool(false)]; // Placeholder + func + .call(&mut store, &[], &mut results) + .map_err(|e| GuardError::WasmError(format!("WASM function call failed: {}", e)))?; + + func + .post_return(&mut store) + .map_err(|e| GuardError::WasmError(format!("WASM post-return failed: {}", e)))?; + + match &results[0] { + Val::String(s) => Ok(s.to_string()), + other => Err(GuardError::WasmError(format!( + "Expected string from {}, got: {:?}", + func_name, other + ))), + } + }) + } + + /// Get the JSON Schema describing this guard's configurable parameters. + /// Returns JSON-serialized JSON Schema (Draft 2020-12). + pub fn get_settings_schema(&self) -> Result { + self.call_string_func("get-settings-schema") + } + + /// Get the default configuration as JSON. + pub fn get_default_config(&self) -> Result { + self.call_string_func("get-default-config") + } } #[cfg(feature = "wasm-guards")] @@ -440,6 +563,7 @@ impl NativeGuard for WasmGuard { "server-name".into(), Val::String(context.server_name.clone().into()), ), + ("server-url".into(), Val::Option(None)), // Not applicable for tools_list evaluation ( "identity".into(), match &context.identity { @@ -478,27 +602,267 @@ impl NativeGuard for WasmGuard { arguments: &serde_json::Value, context: &GuardContext, ) -> GuardResult { - // Default implementation - WASM guards primarily target tools_list evaluation - // This can be extended if the WIT interface is updated to support tool invocation - tracing::debug!( - guard_id = %self.guard_id, - tool_name = %tool_name, - server = %context.server_name, - "WASM guard evaluate_tool_invoke called (default allow)" - ); - let _ = (tool_name, arguments, context); - Ok(GuardDecision::Allow) + self.execute_with_timeout(|| { + tracing::debug!( + guard_id = %self.guard_id, + tool_name = %tool_name, + server = %context.server_name, + "Evaluating tool invocation with WASM guard" + ); + + let linker = self.create_linker()?; + let state = WasmState::new(self.config.config.clone()); + let mut store = Store::new(&self.engine, state); + + let instance = linker + .instantiate(&mut store, &self.component) + .map_err(|e| GuardError::WasmError(format!("Failed to instantiate component: {}", e)))?; + + let guard_export_idx = instance + .get_export(&mut store, None, "mcp:security-guard/guard@0.1.0") + .ok_or_else(|| { + GuardError::WasmError("Guard interface not found in component exports".to_string()) + })?; + + // Try to get evaluate-tool-invoke — if not exported, allow (backward compat) + let func_export_idx = + match instance.get_export(&mut store, Some(&guard_export_idx), "evaluate-tool-invoke") { + Some(idx) => idx, + None => { + tracing::debug!( + guard_id = %self.guard_id, + "WASM guard does not export evaluate-tool-invoke, allowing" + ); + return Ok(GuardDecision::Allow); + }, + }; + + let func = instance + .get_func(&mut store, &func_export_idx) + .ok_or_else(|| { + GuardError::WasmError("Could not get function from export index".to_string()) + })?; + + let tool_name_val = Val::String(tool_name.to_string().into()); + let arguments_val = Val::String( + serde_json::to_string(arguments) + .unwrap_or_else(|_| "{}".to_string()) + .into(), + ); + let context_record = Val::Record(vec![ + ( + "server-name".into(), + Val::String(context.server_name.clone().into()), + ), + ("server-url".into(), Val::Option(None)), + ( + "identity".into(), + match &context.identity { + Some(id) => Val::Option(Some(Box::new(Val::String(id.clone().into())))), + None => Val::Option(None), + }, + ), + ( + "metadata".into(), + Val::String( + serde_json::to_string(&context.metadata) + .unwrap_or_else(|_| "{}".to_string()) + .into(), + ), + ), + ]); + + let mut results = vec![Val::Bool(false)]; + func + .call( + &mut store, + &[tool_name_val, arguments_val, context_record], + &mut results, + ) + .map_err(|e| GuardError::WasmError(format!("WASM function call failed: {}", e)))?; + + func + .post_return(&mut store) + .map_err(|e| GuardError::WasmError(format!("WASM post-return failed: {}", e)))?; + + Self::parse_decision(&results) + }) } fn evaluate_response(&self, response: &serde_json::Value, context: &GuardContext) -> GuardResult { - // Default implementation - can be extended if WIT interface supports response evaluation - tracing::debug!( - guard_id = %self.guard_id, - server = %context.server_name, - "WASM guard evaluate_response called (default allow)" - ); - let _ = (response, context); - Ok(GuardDecision::Allow) + self.execute_with_timeout(|| { + tracing::debug!( + guard_id = %self.guard_id, + server = %context.server_name, + "Evaluating response with WASM guard" + ); + + let linker = self.create_linker()?; + let state = WasmState::new(self.config.config.clone()); + let mut store = Store::new(&self.engine, state); + + let instance = linker + .instantiate(&mut store, &self.component) + .map_err(|e| GuardError::WasmError(format!("Failed to instantiate component: {}", e)))?; + + let guard_export_idx = instance + .get_export(&mut store, None, "mcp:security-guard/guard@0.1.0") + .ok_or_else(|| { + GuardError::WasmError("Guard interface not found in component exports".to_string()) + })?; + + // Try to get evaluate-response — if not exported, allow (backward compat) + let func_export_idx = + match instance.get_export(&mut store, Some(&guard_export_idx), "evaluate-response") { + Some(idx) => idx, + None => { + tracing::debug!( + guard_id = %self.guard_id, + "WASM guard does not export evaluate-response, allowing" + ); + return Ok(GuardDecision::Allow); + }, + }; + + let func = instance + .get_func(&mut store, &func_export_idx) + .ok_or_else(|| { + GuardError::WasmError("Could not get function from export index".to_string()) + })?; + + let response_val = Val::String( + serde_json::to_string(response) + .unwrap_or_else(|_| "{}".to_string()) + .into(), + ); + let context_record = Val::Record(vec![ + ( + "server-name".into(), + Val::String(context.server_name.clone().into()), + ), + ("server-url".into(), Val::Option(None)), + ( + "identity".into(), + match &context.identity { + Some(id) => Val::Option(Some(Box::new(Val::String(id.clone().into())))), + None => Val::Option(None), + }, + ), + ( + "metadata".into(), + Val::String( + serde_json::to_string(&context.metadata) + .unwrap_or_else(|_| "{}".to_string()) + .into(), + ), + ), + ]); + + let mut results = vec![Val::Bool(false)]; + func + .call(&mut store, &[response_val, context_record], &mut results) + .map_err(|e| GuardError::WasmError(format!("WASM function call failed: {}", e)))?; + + func + .post_return(&mut store) + .map_err(|e| GuardError::WasmError(format!("WASM post-return failed: {}", e)))?; + + Self::parse_decision(&results) + }) + } + + fn evaluate_connection( + &self, + server_name: &str, + server_url: Option<&str>, + context: &GuardContext, + ) -> GuardResult { + self.execute_with_timeout(|| { + tracing::debug!( + guard_id = %self.guard_id, + server = %server_name, + server_url = ?server_url, + "Evaluating connection with WASM guard" + ); + + let linker = self.create_linker()?; + let state = WasmState::new(self.config.config.clone()); + let mut store = Store::new(&self.engine, state); + + // Instantiate the component + let instance = linker + .instantiate(&mut store, &self.component) + .map_err(|e| GuardError::WasmError(format!("Failed to instantiate component: {}", e)))?; + + // Get the exported function from the guard interface + let guard_export_idx = instance + .get_export(&mut store, None, "mcp:security-guard/guard@0.1.0") + .ok_or_else(|| { + GuardError::WasmError("Guard interface not found in component exports".to_string()) + })?; + + // Get the evaluate-server-connection function + let func_export_idx = instance + .get_export( + &mut store, + Some(&guard_export_idx), + "evaluate-server-connection", + ) + .ok_or_else(|| { + GuardError::WasmError( + "Function evaluate-server-connection not found in guard interface".to_string(), + ) + })?; + + let func = instance + .get_func(&mut store, &func_export_idx) + .ok_or_else(|| { + GuardError::WasmError("Could not get function from export index".to_string()) + })?; + + // Build context as WIT record with server_url + let context_record = Val::Record(vec![ + ( + "server-name".into(), + Val::String(context.server_name.clone().into()), + ), + ( + "server-url".into(), + match server_url { + Some(url) => Val::Option(Some(Box::new(Val::String(url.to_string().into())))), + None => Val::Option(None), + }, + ), + ( + "identity".into(), + match &context.identity { + Some(id) => Val::Option(Some(Box::new(Val::String(id.clone().into())))), + None => Val::Option(None), + }, + ), + ( + "metadata".into(), + Val::String( + serde_json::to_string(&context.metadata) + .unwrap_or_else(|_| "{}".to_string()) + .into(), + ), + ), + ]); + + // Call the function + let mut results = vec![Val::Bool(false)]; // Placeholder for result + func + .call(&mut store, &[context_record], &mut results) + .map_err(|e| GuardError::WasmError(format!("WASM function call failed: {}", e)))?; + + // Post-call cleanup + func + .post_return(&mut store) + .map_err(|e| GuardError::WasmError(format!("WASM post-return failed: {}", e)))?; + + Self::parse_decision(&results) + }) } fn reset_server(&self, server_name: &str) { @@ -509,6 +873,34 @@ impl NativeGuard for WasmGuard { "WASM guard reset_server called (no-op)" ); } + + fn get_settings_schema(&self) -> Option { + match self.call_string_func("get-settings-schema") { + Ok(schema) => Some(schema), + Err(e) => { + tracing::warn!( + guard_id = %self.guard_id, + error = %e, + "Failed to get settings schema from WASM guard" + ); + None + }, + } + } + + fn get_default_config(&self) -> Option { + match self.call_string_func("get-default-config") { + Ok(config) => Some(config), + Err(e) => { + tracing::warn!( + guard_id = %self.guard_id, + error = %e, + "Failed to get default config from WASM guard" + ); + None + }, + } + } } // Non-wasm-guards feature: provide stub implementation @@ -533,6 +925,7 @@ mod tests { let invalid_config = WasmGuardConfig { module_path: String::new(), max_memory: 1024 * 1024, + max_wasm_stack: default_max_wasm_stack(), timeout_ms: 100, config: HashMap::new(), }; @@ -546,6 +939,7 @@ mod tests { let valid_config = WasmGuardConfig { module_path: "/path/to/probe.wasm".to_string(), max_memory: 10 * 1024 * 1024, + max_wasm_stack: default_max_wasm_stack(), timeout_ms: 100, config: HashMap::new(), }; @@ -567,6 +961,7 @@ mod tests { #[test] fn test_default_config_values() { assert_eq!(default_max_memory(), 10 * 1024 * 1024); + assert_eq!(default_max_wasm_stack(), 2 * 1024 * 1024); assert_eq!(default_timeout_ms(), 100); } @@ -655,6 +1050,7 @@ module_path: ./guards/test.wasm let config = WasmGuardConfig { module_path: wasm_path.to_str().unwrap().to_string(), max_memory: 10 * 1024 * 1024, + max_wasm_stack: default_max_wasm_stack(), timeout_ms: 1000, config: HashMap::new(), // Use default patterns }; diff --git a/crates/agentgateway/src/mcp/security/wit/guard.wit b/crates/agentgateway/src/mcp/security/wit/guard.wit index 5b06c5d46..24504ec3e 100644 --- a/crates/agentgateway/src/mcp/security/wit/guard.wit +++ b/crates/agentgateway/src/mcp/security/wit/guard.wit @@ -17,6 +17,7 @@ interface guard { /// Context provided by the host record guard-context { server-name: string, + server-url: option, identity: option, metadata: string, /// JSON-serialized metadata } @@ -25,7 +26,8 @@ interface guard { variant decision { allow, deny(deny-reason), - modify(string), /// JSON-serialized modification + modify(string), /// JSON-serialized modification (for PII masking, etc.) + warn(list), /// Allow with warnings } /// Reason for denying a request @@ -35,9 +37,29 @@ interface guard { details: option, /// JSON-serialized details } - /// Main guard evaluation function - /// Called by the host for every tools/list response + /// Evaluate server connection before establishing + evaluate-server-connection: func(context: guard-context) -> result; + + /// Evaluate tools/list response evaluate-tools-list: func(tools: list, context: guard-context) -> result; + + /// Evaluate a tool invocation (tools/call) before execution + /// tool-name: name of the tool being invoked + /// arguments: JSON-serialized tool arguments + evaluate-tool-invoke: func(tool-name: string, arguments: string, context: guard-context) -> result; + + /// Evaluate a response from the MCP server + /// response: JSON-serialized response content + evaluate-response: func(response: string, context: guard-context) -> result; + + /// Get JSON Schema describing guard's configurable parameters + /// Returns JSON-serialized JSON Schema (Draft 2020-12) + /// Schema describes guard-specific settings only (not common settings like priority, timeout) + get-settings-schema: func() -> string; + + /// Get default configuration as JSON + /// Returns JSON-serialized default config values + get-default-config: func() -> string; } /// Host interface - functions provided by AgentGateway to the WASM module diff --git a/crates/agentgateway/src/state_manager.rs b/crates/agentgateway/src/state_manager.rs index 8f186c5d5..92eb77d0b 100644 --- a/crates/agentgateway/src/state_manager.rs +++ b/crates/agentgateway/src/state_manager.rs @@ -195,12 +195,42 @@ impl LocalClient { .await?; info!("loaded config from {:?}", self.cfg); + // Extract MCP backend guard configs before sync_local consumes backends + let mcp_guard_configs: Vec<_> = config + .backends + .iter() + .filter_map(|bwp| { + if let crate::types::agent::Backend::MCP(_, mcp) = &bwp.backend { + let backend_name = bwp.backend.name().to_string(); + Some((backend_name, mcp.security_guards.clone())) + } else { + None + } + }) + .collect(); + // Sync the state let next_binds = self .stores .binds .sync_local(config.binds, config.policies, config.backends, prev.binds); + + // Hot-reload security guards for MCP backends + for (backend_name, guards) in mcp_guard_configs { + if let Err(e) = self + .stores + .guard_registry + .update_backend(&backend_name, guards) + { + warn!( + backend = %backend_name, + error = %e, + "Failed to hot-reload security guards" + ); + } + } + let next_discovery = self .stores diff --git a/crates/agentgateway/src/ui.rs b/crates/agentgateway/src/ui.rs index 614186350..a5e3f2e22 100644 --- a/crates/agentgateway/src/ui.rs +++ b/crates/agentgateway/src/ui.rs @@ -17,7 +17,9 @@ use tower_http::cors::CorsLayer; use tower_serve_static::ServeDir; use crate::management::admin::{AdminFallback, AdminResponse}; +use crate::mcp::security::{GuardExecutorRegistry, McpGuardKind, McpSecurityGuard}; use crate::{Config, ConfigSource, client, yamlviajson}; + pub struct UiHandler { router: Router, } @@ -26,6 +28,7 @@ pub struct UiHandler { struct App { state: Arc, client: client::Client, + guard_registry: GuardExecutorRegistry, } impl App { @@ -44,17 +47,19 @@ lazy_static::lazy_static! { } impl UiHandler { - pub fn new(cfg: Arc) -> Self { + pub fn new(cfg: Arc, guard_registry: GuardExecutorRegistry) -> Self { let ui_service = ServeDir::new(&ASSETS_DIR); let router = Router::new() // Redirect to the UI .route("/config", get(get_config).post(write_config)) + .route("/api/v1/guards/schemas", get(get_guard_schemas)) .nest_service("/ui", ui_service) .route("/", get(|| async { Redirect::permanent("/ui") })) .layer(add_cors_layer()) .with_state(App { state: cfg.clone(), client: client::Client::new(&cfg.dns, None, Default::default(), None), + guard_registry, }); Self { router } } @@ -128,6 +133,142 @@ async fn write_config( )) } +/// GET /api/v1/guards/schemas +/// Returns JSON Schemas for all guards. +/// Primary source: already-loaded guards from the registry (works regardless of file paths). +/// Fallback: re-instantiate WASM modules from config (for when guards haven't connected yet). +async fn get_guard_schemas(State(app): State) -> Result, ErrorResponse> { + let mut schemas = serde_json::Map::new(); + + // Primary: query already-loaded guards from the registry + let registry_schemas = app.guard_registry.collect_wasm_schemas(); + for (guard_id, wasm_schema) in ®istry_schemas { + // Use x-guard-meta.guardType as key, fall back to guard id + let schema_key = wasm_schema + .settings_schema + .get("x-guard-meta") + .and_then(|m| m.get("guardType")) + .and_then(|v| v.as_str()) + .unwrap_or(guard_id) + .to_string(); + schemas.insert(schema_key, wasm_schema.settings_schema.clone()); + } + + // Fallback: if registry had no schemas, try loading from config + // (handles case where no MCP clients have connected yet) + if schemas.is_empty() { + if let Ok(cfg_source) = app.cfg() { + if let Ok(yaml_str) = cfg_source.read_to_string().await { + if let Ok(config_val) = yamlviajson::from_str::(&yaml_str) { + collect_wasm_schemas_from_config(&config_val, &mut schemas); + } + } + } + } + + Ok(Json(serde_json::json!({ + "schemas": schemas, + }))) +} + +/// Walk the config JSON to find WASM guard entries and extract their schemas. +/// Returns schemas keyed by x-guard-meta.guardType (or guard id as fallback), +/// matching the GuardSchemasResponse format expected by the frontend. +#[allow(unused_variables)] +fn collect_wasm_schemas_from_config(config: &Value, schemas: &mut serde_json::Map) { + // Navigate: binds[] -> listeners[] -> routes[] -> backends[] -> mcp -> securityGuards[] + let Some(binds) = config.get("binds").and_then(|v| v.as_array()) else { + return; + }; + + for bind in binds { + let Some(listeners) = bind.get("listeners").and_then(|v| v.as_array()) else { + continue; + }; + for listener in listeners { + let Some(routes) = listener.get("routes").and_then(|v| v.as_array()) else { + continue; + }; + for route in routes { + let Some(backends) = route.get("backends").and_then(|v| v.as_array()) else { + continue; + }; + for backend in backends { + let Some(mcp) = backend.get("mcp") else { + continue; + }; + let Some(guards) = mcp.get("securityGuards").and_then(|v| v.as_array()) else { + continue; + }; + collect_wasm_schemas_from_guards(guards, schemas); + } + } + } + } +} + +/// Extract schemas from a list of guard config values. +#[allow(unused_variables)] +fn collect_wasm_schemas_from_guards( + guards: &[Value], + schemas: &mut serde_json::Map, +) { + for guard_val in guards { + let Some(guard_type) = guard_val.get("type").and_then(|v| v.as_str()) else { + continue; + }; + if guard_type != "wasm" { + continue; + } + + let Some(guard_id) = guard_val.get("id").and_then(|v| v.as_str()) else { + continue; + }; + + // Deserialize as full McpSecurityGuard (which flattens McpGuardKind), + // then extract .kind to get WasmGuardConfig + if let Ok(guard) = serde_json::from_value::(guard_val.clone()) { + let kind = guard.kind; + #[cfg(feature = "wasm-guards")] + if let McpGuardKind::Wasm(wasm_cfg) = kind { + match crate::mcp::security::wasm::WasmGuard::new(guard_id.to_string(), wasm_cfg) { + Ok(wasm_guard) => { + if let Ok(schema_str) = wasm_guard.get_settings_schema() { + if let Ok(schema_val) = serde_json::from_str::(&schema_str) { + // Use x-guard-meta.guardType as key, fall back to guard id + let schema_key = schema_val + .get("x-guard-meta") + .and_then(|m| m.get("guardType")) + .and_then(|v| v.as_str()) + .unwrap_or(guard_id) + .to_string(); + + schemas.insert(schema_key, schema_val); + } + } + }, + Err(e) => { + tracing::warn!( + guard_id = guard_id, + error = %e, + "Failed to load WASM guard for schema extraction" + ); + }, + } + } + + #[cfg(not(feature = "wasm-guards"))] + { + let _ = kind; + tracing::debug!( + guard_id = guard_id, + "WASM guards feature not enabled, skipping schema extraction" + ); + } + } + } +} + pub fn add_cors_layer() -> CorsLayer { CorsLayer::new() .allow_origin( diff --git a/examples/wasm-guards/simple-pattern-guard/src/lib.rs b/examples/wasm-guards/simple-pattern-guard/src/lib.rs index 67f13464b..7b3d56190 100644 --- a/examples/wasm-guards/simple-pattern-guard/src/lib.rs +++ b/examples/wasm-guards/simple-pattern-guard/src/lib.rs @@ -88,6 +88,84 @@ impl Guest for SimplePatternGuard { log_info("All tools passed pattern check"); Ok(Decision::Allow) } + + fn evaluate_server_connection(context: GuardContext) -> Result { + // This guard focuses on tool patterns, so we allow all connections + log_info(&format!( + "Allowing connection to server '{}'", + context.server_name + )); + Ok(Decision::Allow) + } + + fn evaluate_tool_invoke( + _tool_name: String, + _arguments: String, + _context: GuardContext, + ) -> Result { + Ok(Decision::Allow) + } + + fn evaluate_response( + _response: String, + _context: GuardContext, + ) -> Result { + Ok(Decision::Allow) + } + + fn get_settings_schema() -> String { + serde_json::json!({ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "agentgateway://guards/simple-pattern/v1", + "title": "Simple Pattern Guard", + "description": "Blocks tools matching configurable patterns in name or description", + "type": "object", + "properties": { + "blocked_patterns": { + "type": "array", + "title": "Blocked Patterns", + "description": "List of substrings to block (case-insensitive)", + "items": { "type": "string" }, + "default": ["delete", "rm -rf", "drop table", "eval", "exec"], + "x-ui": { + "component": "tags", + "placeholder": "Enter pattern and press Enter", + "order": 1 + } + }, + "scan_descriptions": { + "type": "boolean", + "title": "Scan Descriptions", + "description": "Also check tool descriptions for blocked patterns", + "default": true, + "x-ui": { "order": 2 } + }, + "max_tool_count": { + "type": "integer", + "title": "Max Tool Count", + "description": "Maximum allowed tools per server (0 = unlimited)", + "default": 0, + "minimum": 0, + "x-ui": { "order": 3, "advanced": true } + } + }, + "x-guard-meta": { + "guardType": "simple_pattern", + "version": "1.0.0", + "category": "detection", + "defaultRunsOn": ["tools_list"], + "icon": "filter" + } + }).to_string() + } + + fn get_default_config() -> String { + serde_json::json!({ + "blocked_patterns": ["delete", "rm -rf", "drop table", "eval", "exec"], + "scan_descriptions": true, + "max_tool_count": 0 + }).to_string() + } } // Helper: Get blocked patterns from config or use defaults diff --git a/examples/wasm-guards/simple-pattern-guard/wit/guard.wit b/examples/wasm-guards/simple-pattern-guard/wit/guard.wit index 2fc28bc13..cd879fd3e 100644 --- a/examples/wasm-guards/simple-pattern-guard/wit/guard.wit +++ b/examples/wasm-guards/simple-pattern-guard/wit/guard.wit @@ -20,6 +20,7 @@ interface guard { /// Context provided by the host record guard-context { server-name: string, + server-url: option, /// URL of the MCP server (for connection validation) identity: option, metadata: string, /// JSON-serialized metadata } @@ -28,7 +29,8 @@ interface guard { variant decision { allow, deny(deny-reason), - modify(string), /// JSON-serialized modification + modify(string), /// JSON-serialized modification (for PII masking, etc.) + warn(list), /// Allow with warnings } /// Reason for denying a request @@ -41,6 +43,25 @@ interface guard { /// Main guard evaluation function /// Called by the host for every tools/list response evaluate-tools-list: func(tools: list, context: guard-context) -> result; + + /// Evaluate server connection before establishing + /// Called by the host before connecting to an MCP server + evaluate-server-connection: func(context: guard-context) -> result; + + /// Evaluate a tool invocation (tools/call) before execution + evaluate-tool-invoke: func(tool-name: string, arguments: string, context: guard-context) -> result; + + /// Evaluate a response from the MCP server + evaluate-response: func(response: string, context: guard-context) -> result; + + /// Get JSON Schema describing guard's configurable parameters + /// Returns JSON-serialized JSON Schema (Draft 2020-12) + /// Schema describes guard-specific settings only (not common settings like priority, timeout) + get-settings-schema: func() -> string; + + /// Get default configuration as JSON + /// Returns JSON-serialized default config values + get-default-config: func() -> string; } /// Host interface - functions provided by AgentGateway to the WASM module diff --git a/schema/config.json b/schema/config.json index 04b60d010..bd6932cf2 100644 --- a/schema/config.json +++ b/schema/config.json @@ -7219,6 +7219,11 @@ "items": { "description": "Execution phase for guards", "oneOf": [ + { + "description": "Before establishing connection to MCP server\nUsed for server whitelisting, typosquat detection, TLS validation", + "type": "string", + "const": "connection" + }, { "description": "Before forwarding client request to MCP server", "type": "string", diff --git a/ui/src/components/backend/backend-components.tsx b/ui/src/components/backend/backend-components.tsx index d853fc9fa..c690326ca 100644 --- a/ui/src/components/backend/backend-components.tsx +++ b/ui/src/components/backend/backend-components.tsx @@ -67,6 +67,8 @@ import { PII_ACTIONS, SCAN_FIELDS, } from "@/lib/backend-constants"; +import { SchemaForm } from "@/components/schema-form"; +import { useGuardSchemas } from "@/hooks/useGuardSchemas"; import type { SecurityGuard, SecurityGuardType } from "@/lib/types"; import { getBackendType, @@ -718,6 +720,7 @@ const SecurityGuardsSection: React.FC = ({ }) => { const [isExpanded, setIsExpanded] = React.useState(guards.length > 0); const [expandedGuards, setExpandedGuards] = React.useState>(new Set()); + const { schemas } = useGuardSchemas(); const toggleGuardExpanded = (index: number) => { setExpandedGuards((prev) => { @@ -1227,11 +1230,11 @@ const SecurityGuardsSection: React.FC = ({ type="number" min={1} max={1024} - value={Math.round(guard.max_memory_bytes / (1024 * 1024))} + value={Math.round(guard.max_memory / (1024 * 1024))} onChange={(e) => updateSecurityGuardField( index, - "max_memory_bytes", + "max_memory", (parseInt(e.target.value) || 10) * 1024 * 1024 ) } @@ -1239,53 +1242,88 @@ const SecurityGuardsSection: React.FC = ({ />
- + updateSecurityGuardField( index, - "max_fuel", - parseInt(e.target.value) || 1000000 + "max_wasm_stack", + (parseInt(e.target.value) || 2) * 1024 * 1024 ) } className="h-8 text-sm" />
-
- -