Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion crates/agentgateway/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ pub async fn run(config: Arc<Config>) -> anyhow::Result<Bound> {
.await
.context("admin server starts")?;
#[cfg(feature = "ui")]
admin_server.set_admin_handler(Arc::new(crate::ui::UiHandler::new(config.clone())));
admin_server.set_admin_handler(Arc::new(crate::ui::UiHandler::new(
config.clone(),
stores.guard_registry.clone(),
)));
#[cfg(feature = "ui")]
info!("serving UI at http://{}/ui", config.admin_addr);

Expand Down
44 changes: 39 additions & 5 deletions crates/agentgateway/src/mcp/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,39 @@ impl Relay {
tracing::info!("Establishing security guard baselines for all upstreams");

for (server_name, upstream) in self.upstreams.iter_named() {
// Evaluate connection phase guards (whitelist, typosquat detection)
let context = crate::mcp::security::GuardContext {
server_name: server_name.to_string(),
identity: None,
metadata: serde_json::Value::Null,
};
match self
.security_guards
.evaluate_connection(&server_name, None, &context)
{
Ok(crate::mcp::security::GuardDecision::Allow) => {
tracing::info!(server = %server_name, "Connection guard: allowed");
},
Ok(crate::mcp::security::GuardDecision::Deny(reason)) => {
tracing::warn!(
server = %server_name,
code = %reason.code,
message = %reason.message,
"Connection guard: BLOCKED server"
);
continue; // Skip this upstream entirely
},
Ok(_) => {},
Err(e) => {
tracing::error!(
server = %server_name,
error = %e,
"Connection guard: error"
);
continue; // Skip on error (fail closed)
},
}

// Create a tools/list request
let request = JsonRpcRequest {
jsonrpc: Default::default(),
Expand Down Expand Up @@ -263,18 +296,19 @@ impl Relay {

// Process each server's tools individually for security guard evaluation
for (server_name, s) in streams.into_iter() {
let context = crate::mcp::security::GuardContext {
server_name: server_name.to_string(),
identity: None,
metadata: serde_json::Value::Null,
};

let tools = match s {
ServerResult::ListToolsResult(ltr) => ltr.tools,
_ => vec![],
};

// Execute security guards on this server's tools list BEFORE merging
// This ensures baselines are stored per-server, not under "merged"
let context = crate::mcp::security::GuardContext {
server_name: server_name.to_string(),
identity: None,
metadata: serde_json::Value::Null,
};

match security_guards.evaluate_tools_list(&tools, &context) {
Ok(crate::mcp::security::GuardDecision::Allow) => {
Expand Down
116 changes: 116 additions & 0 deletions crates/agentgateway/src/mcp/security/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ pub enum McpGuardKind {
#[serde(rename_all = "snake_case")]
#[derive(Default)]
pub enum GuardPhase {
/// Before establishing connection to MCP server
/// Used for server whitelisting, typosquat detection, TLS validation
Connection,

/// Before forwarding client request to MCP server
#[default]
Request,
Expand Down Expand Up @@ -297,6 +301,30 @@ impl GuardExecutorRegistry {
let executors = self.executors.read().expect("registry lock poisoned");
executors.keys().cloned().collect()
}

/// Collect schemas from all WASM guards across all backends.
/// Returns a map of guard_id -> (settings_schema_json, default_config_json).
pub fn collect_wasm_schemas(&self) -> HashMap<String, WasmGuardSchema> {
let executors = self.executors.read().expect("registry lock poisoned");
let mut schemas = HashMap::new();

for (_backend_name, executor) in executors.iter() {
for entry in executor.collect_guard_schemas() {
schemas.insert(entry.0, entry.1);
}
}

schemas
}
}

/// Schema information returned by a WASM guard
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmGuardSchema {
/// JSON Schema describing guard's configurable parameters
pub settings_schema: serde_json::Value,
/// Default configuration values
pub default_config: serde_json::Value,
}

/// Guard executor that manages and executes security guards in priority order
Expand Down Expand Up @@ -390,6 +418,64 @@ impl GuardExecutor {
Ok(())
}

/// Execute guards before establishing connection to an MCP server
/// Used for server whitelisting, typosquat detection, TLS validation
pub fn evaluate_connection(
&self,
server_name: &str,
server_url: Option<&str>,
context: &GuardContext,
) -> GuardResult {
let guards = self.guards.read().expect("guards lock poisoned");
tracing::info!(
guard_count = guards.len(),
server = %server_name,
server_url = ?server_url,
"GuardExecutor::evaluate_connection called"
);
for guard_entry in guards.iter() {
// Only run guards configured for Connection phase
if !guard_entry.config.runs_on.contains(&GuardPhase::Connection) {
continue;
}

// Execute guard with timeout
let result = self.execute_with_timeout(
|| {
guard_entry
.guard
.evaluate_connection(server_name, server_url, context)
},
Duration::from_millis(guard_entry.config.timeout_ms),
&guard_entry.config,
);

// Handle result based on failure mode
match result {
Ok(GuardDecision::Allow) => continue,
Ok(decision) => return Ok(decision),
Err(e) => match guard_entry.config.failure_mode {
FailureMode::FailClosed => {
return Err(GuardError::ExecutionError(format!(
"Guard {} failed: {}",
guard_entry.config.id, e
)));
},
FailureMode::FailOpen => {
tracing::warn!(
"Guard {} failed but continuing due to fail_open: {}",
guard_entry.config.id,
e
);
continue;
},
},
}
}

Ok(GuardDecision::Allow)
}

/// Execute guards on a tools/list response
pub fn evaluate_tools_list(
&self,
Expand Down Expand Up @@ -575,6 +661,36 @@ impl GuardExecutor {
f()
}

/// Collect schemas from guards that support dynamic schema export (WASM guards).
/// Returns a list of (guard_id, WasmGuardSchema) pairs.
pub fn collect_guard_schemas(&self) -> Vec<(String, WasmGuardSchema)> {
let guards = self.guards.read().expect("guards lock poisoned");
let mut schemas = Vec::new();

for guard_entry in guards.iter() {
if let Some(schema_json) = guard_entry.guard.get_settings_schema() {
let settings_schema: serde_json::Value =
serde_json::from_str(&schema_json).unwrap_or(serde_json::Value::Null);

let default_config: serde_json::Value = guard_entry
.guard
.get_default_config()
.and_then(|s| serde_json::from_str(&s).ok())
.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));

schemas.push((
guard_entry.config.id.clone(),
WasmGuardSchema {
settings_schema,
default_config,
},
));
}
}

schemas
}

/// Reset state for a server (called on session re-initialization)
/// This clears any per-server state like baselines in guards.
pub fn reset_server(&self, server_name: &str) {
Expand Down
27 changes: 27 additions & 0 deletions crates/agentgateway/src/mcp/security/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@ use super::{GuardContext, GuardDecision, GuardResult};

/// Common trait for all native guards
pub trait NativeGuard: Send + Sync {
/// Evaluate before establishing connection to an MCP server
/// Used for server whitelisting, typosquat detection, TLS validation
fn evaluate_connection(
&self,
server_name: &str,
server_url: Option<&str>,
context: &GuardContext,
) -> GuardResult {
// Default: allow
let _ = (server_name, server_url, context);
Ok(GuardDecision::Allow)
}

/// Evaluate a tools/list response
fn evaluate_tools_list(&self, tools: &[rmcp::model::Tool], context: &GuardContext)
-> GuardResult;
Expand Down Expand Up @@ -70,6 +83,20 @@ pub trait NativeGuard: Send + Sync {
// Default: no-op (most guards are stateless)
let _ = server_name;
}

/// Get JSON Schema describing this guard's configurable parameters.
/// Returns None for native guards (schemas are embedded in the UI).
/// WASM guards override this to call the guest module's get-settings-schema.
fn get_settings_schema(&self) -> Option<String> {
None
}

/// Get default configuration as JSON.
/// Returns None for native guards.
/// WASM guards override this to call the guest module's get-default-config.
fn get_default_config(&self) -> Option<String> {
None
}
}

/// Helper: Build regex set from patterns
Expand Down
Loading