diff --git a/approval-gate/iii.worker.yaml b/approval-gate/iii.worker.yaml index 48148c1f..c9983931 100644 --- a/approval-gate/iii.worker.yaml +++ b/approval-gate/iii.worker.yaml @@ -4,7 +4,12 @@ language: rust deploy: binary manifest: Cargo.toml bin: approval-gate -description: Hook subscriber on agent::before_function_call that pauses function calls listed in approval_required until the UI resolves them via approval::resolve. +description: | + Hook subscriber on agent::before_function_call. Decides every LLM-initiated + function call via a layered rules engine. Allow → pass through. Deny → + structured Denial::Policy. Ask → write a Pending record and wait for + approval::resolve. The classifier surface and __from_approval marker are gone; + policy lives entirely in the rules layer. runtime: kind: rust @@ -17,12 +22,34 @@ config: topic: agent::before_function_call approval_state_scope: approvals default_timeout_ms: 300000 - interceptors: - - function_id: shell::exec - classifier: shell::classify_argv - classifier_timeout_ms: 2000 - inject_approval_marker: true - - function_id: shell::exec_bg - classifier: shell::classify_argv - classifier_timeout_ms: 2000 - inject_approval_marker: true + + # Curated default ruleset. `before_function_call` fires for every tool + # call; with no rules and no-match defaulting to Ask, an empty ruleset + # would prompt for every read-only function. The defaults below + # auto-allow safe reads and ask for everything that writes/executes/ + # mutates. Operators stack their own rules on top — last-match wins. + rules: + # Read-only filesystem / introspection + - { permission: "fs::read", pattern: "*", action: allow } + - { permission: "fs::list", pattern: "*", action: allow } + - { permission: "fs::stat", pattern: "*", action: allow } + - { permission: "fs::glob", pattern: "*", action: allow } + - { permission: "fs::grep", pattern: "*", action: allow } + + # Read-only git + - { permission: "shell::exec", pattern: "git status*", action: allow } + - { permission: "shell::exec", pattern: "git log*", action: allow } + - { permission: "shell::exec", pattern: "git diff*", action: allow } + - { permission: "shell::exec", pattern: "git show*", action: allow } + - { permission: "shell::exec", pattern: "git branch*", action: allow } + - { permission: "shell::exec", pattern: "git remote*", action: allow } + + # Approval API — the gate must not gate itself + - { permission: "approval::*", pattern: "*", action: allow } + + # All remaining shell exec calls → ask + - { permission: "shell::exec", pattern: "*", action: ask } + - { permission: "shell::exec_bg", pattern: "*", action: ask } + + # Catch-all: anything else → ask. (Operator overrides go above.) + - { permission: "*", pattern: "*", action: ask } diff --git a/approval-gate/skills/sweep_session.md b/approval-gate/skills/sweep_session.md index 9dc67c66..e791db36 100644 --- a/approval-gate/skills/sweep_session.md +++ b/approval-gate/skills/sweep_session.md @@ -1,6 +1,6 @@ # approval::sweep_session -Sweep all pending approval records for a session to `timed_out` with reason `session_deleted`. +Sweep all pending approval records for a session to `timed_out`. **Payload:** - `session_id` (string, required) @@ -12,4 +12,5 @@ Sweep all pending approval records for a session to `timed_out` with reason `ses **Behavior:** - Only records with `status: "pending"` are flipped. - Non-pending records (already resolved, executed, denied, etc.) are left untouched. -- Intended to be called by the session worker or turn-orchestrator when a session is being deleted, so that pending approvals don't dangle forever. +- The flipped records carry no `Denial` — `status: "timed_out"` is self-describing per the Denial refactor. Callers that need to distinguish session-delete from run-stop sweeps should log that context in their own worker. +- Intended to be called by the session worker or turn-orchestrator when a session is being deleted or a run is stopped, so pending approvals don't dangle forever. diff --git a/approval-gate/src/config.rs b/approval-gate/src/config.rs index e7cb6942..3c406f13 100644 --- a/approval-gate/src/config.rs +++ b/approval-gate/src/config.rs @@ -1,4 +1,14 @@ //! YAML-backed runtime settings for [`WorkerConfig`]. +//! +//! Post-refactor surface (T12): +//! - `topic` — hook bus topic the gate subscribes to. +//! - `approval_state_scope` — iii-state scope for approval records. +//! - `default_timeout_ms` — Pending-row TTL. +//! - `rules` — the layered ruleset (default + operator-shipped), +//! evaluated in order with last-match winning. +//! +//! Deleted in T12: `interceptors`, `sweeper_interval_ms`, +//! `InterceptorRule` (the classifier surface is gone). use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; @@ -15,21 +25,16 @@ fn default_default_timeout_ms() -> u64 { 300_000 } -fn default_classifier_timeout_ms() -> u64 { - 2000 -} - -/// Per-function iii intercept rule: optional classifier trigger before pending + -/// optional `__from_approval` injection on post-resolve `iii.trigger`. +/// Temporary alias retained while register.rs's classifier-alias warning +/// loop still references the symbol. The struct is structurally unused +/// (no fields populated from config) and will be deleted alongside the +/// warning loop when there are no more callers. Provided here so the +/// crate builds. #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] pub struct InterceptorRule { pub function_id: String, #[serde(default)] pub classifier: Option, - #[serde(default = "default_classifier_timeout_ms")] - pub classifier_timeout_ms: u64, - #[serde(default)] - pub inject_approval_marker: bool, } #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] @@ -40,8 +45,11 @@ pub struct WorkerConfig { pub approval_state_scope: String, #[serde(default = "default_default_timeout_ms")] pub default_timeout_ms: u64, + /// Layered permission ruleset. Allow / Deny / Ask actions. Evaluated + /// last-match-wins; the YAML's curated defaults ship at the bottom, + /// operator overrides stack on top. See [`crate::rules`]. #[serde(default)] - pub interceptors: Vec, + pub rules: crate::rules::Ruleset, } impl Default for WorkerConfig { @@ -50,7 +58,7 @@ impl Default for WorkerConfig { topic: default_topic(), approval_state_scope: default_approval_state_scope(), default_timeout_ms: default_default_timeout_ms(), - interceptors: Vec::new(), + rules: Vec::new(), } } } @@ -69,6 +77,7 @@ pub fn load_config(path: &str) -> Result { #[cfg(test)] mod tests { use super::*; + use crate::rules::{Action, Rule}; #[test] fn defaults_from_empty_yaml_mapping() { @@ -76,50 +85,27 @@ mod tests { assert_eq!(cfg.topic, default_topic()); assert_eq!(cfg.approval_state_scope, "approvals"); assert_eq!(cfg.default_timeout_ms, 300_000); - assert!(cfg.interceptors.is_empty()); - } - - #[test] - fn interceptors_default_empty() { - assert!(WorkerConfig::default().interceptors.is_empty()); - } - - #[test] - fn interceptors_parse_from_nested_config_block() { - let yaml = r#" -interceptors: - - function_id: shell::exec - classifier: shell::classify_argv - classifier_timeout_ms: 1500 - inject_approval_marker: true - - function_id: other::fn - classifier: null -"#; - let cfg: WorkerConfig = serde_yaml::from_str(yaml).unwrap(); - assert_eq!(cfg.interceptors.len(), 2); - assert_eq!(cfg.interceptors[0].function_id, "shell::exec"); - assert_eq!( - cfg.interceptors[0].classifier.as_deref(), - Some("shell::classify_argv") - ); - assert_eq!(cfg.interceptors[0].classifier_timeout_ms, 1500); - assert!(cfg.interceptors[0].inject_approval_marker); - assert_eq!(cfg.interceptors[1].function_id, "other::fn"); - assert!(cfg.interceptors[1].classifier.is_none()); - assert!(!cfg.interceptors[1].inject_approval_marker); + assert!(cfg.rules.is_empty()); } #[test] - fn interceptor_rule_marker_defaults_false() { + fn rules_parse_from_yaml() { let yaml = r#" -interceptors: - - function_id: x::y - classifier: c::f +rules: + - { permission: "shell::exec", pattern: "git status*", action: allow } + - { permission: "shell::exec", pattern: "*", action: ask } "#; let cfg: WorkerConfig = serde_yaml::from_str(yaml).unwrap(); - assert_eq!(cfg.interceptors.len(), 1); - assert!(!cfg.interceptors[0].inject_approval_marker); - assert_eq!(cfg.interceptors[0].classifier_timeout_ms, 2000); + assert_eq!(cfg.rules.len(), 2); + assert_eq!(cfg.rules[0].permission, "shell::exec"); + assert_eq!(cfg.rules[0].pattern, "git status*"); + assert_eq!(cfg.rules[0].action, Action::Allow); + assert_eq!(cfg.rules[1].action, Action::Ask); + let _ = Rule { // smoke check on the imported type + permission: "x".into(), + pattern: "*".into(), + action: Action::Deny, + }; } #[test] diff --git a/approval-gate/src/delivery.rs b/approval-gate/src/delivery.rs new file mode 100644 index 00000000..e2f73281 --- /dev/null +++ b/approval-gate/src/delivery.rs @@ -0,0 +1,361 @@ +//! Delivery-tracking handlers. +//! +//! Three RPCs make up the gate's read/drain surface: +//! +//! - [`handle_list_pending`] — UI-facing list of in-flight prompts. +//! Applies lazy timeout flip on read: a Pending row past `expires_at` +//! flips to `Done(TimedOut)` and disappears from the list. +//! - [`handle_consume`] — atomic drain: returns Done rows and deletes +//! them in the same call. Defensive `session_id` filter; cap + +//! `omitted` counter; sort by `resolved_at` for deterministic LLM +//! replay across multi-row consumes (cascade case). +//! - [`handle_sweep_session`] — force-cancellation for `run::stop`: +//! flips every Pending and InFlight row to `Done(TimedOut)`. + +use serde_json::{json, Value}; + +use crate::record::{Outcome, Record, Status}; +use crate::state::StateBus; +use crate::wire::pending_key; + +/// Default per-call cap on `handle_consume`. Bounds the response size — +/// `Outcome::Executed.result` can carry MB-sized stdout/stderr payloads, +/// and we don't want one consume to blow the trigger wire or the next +/// LLM turn. +pub const CONSUME_DEFAULT_LIMIT: usize = 50; + +/// List Pending rows for a session. Applies lazy timeout flip on read — +/// expired Pending rows are persisted as `Done(TimedOut)` and dropped +/// from the response. +pub async fn handle_list_pending(bus: &dyn StateBus, state_scope: &str, payload: Value) -> Value { + let session_id = payload.get("session_id").and_then(Value::as_str).unwrap_or(""); + if session_id.is_empty() { + return json!({ "pending": [] }); + } + let now_ms = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis() as u64) + .unwrap_or(0); + + let prefix = format!("{session_id}/"); + let rows = bus.list_prefix(state_scope, &prefix).await; + + let mut pending = Vec::new(); + for raw in rows { + let Some(record) = Record::from_value(raw) else { continue }; + if record.session_id != session_id { continue; } // defensive + // Lazy flip + persist; expired rows leave the Pending list. + if let Some(flipped) = record.flipped_to_timed_out_if_expired(now_ms) { + let key = pending_key(session_id, &flipped.function_call_id); + let _ = bus.set(state_scope, &key, flipped.to_value()).await; + continue; + } + if record.status == Status::Pending { + pending.push(record.to_value()); + } + } + json!({ "pending": pending }) +} + +/// Atomic drain: returns Done rows for a session and deletes them in the +/// same call. Pending and InFlight rows stay in state. Pending rows past +/// `expires_at` are lazy-flipped to `Done(TimedOut)` and returned. +/// +/// Three phases: +/// 1. gather Done candidates (no state mutation); +/// 2. sort by `resolved_at`, apply cap, report `omitted` count; +/// 3. delete-and-return — only rows whose delete succeeded are returned, +/// so a partial failure leaves the row to be retried next consume. +/// +/// Sort order matters when cascade auto-resolves multiple rows that all +/// surface to the same consume — `resolved_at` produces deterministic +/// LLM message order. +pub async fn handle_consume( + bus: &dyn StateBus, + state_scope: &str, + payload: Value, + now_ms: u64, +) -> Value { + let session_id = payload.get("session_id").and_then(Value::as_str).unwrap_or(""); + if session_id.is_empty() { + return json!({ "ok": false, "error": "missing_session_id" }); + } + let limit = payload + .get("limit") + .and_then(Value::as_u64) + .map(|n| n as usize) + .unwrap_or(CONSUME_DEFAULT_LIMIT); + + let prefix = format!("{session_id}/"); + let rows = bus.list_prefix(state_scope, &prefix).await; + + // Phase 1: gather Done candidates without mutating state. + let mut candidates: Vec = Vec::new(); + for raw in rows { + let Some(record) = Record::from_value(raw) else { continue }; + // Defensive session_id filter: some state-bus backends ignore the + // prefix arg and return every row in the scope. Drop anything not + // stamped with the session_id we're consuming for — otherwise a + // faulty backend could cross-session delete. + if record.session_id != session_id { continue; } + // Lazy flip (Pending → Done(TimedOut)). No persist needed — we're + // about to delete this row. + let record = record.flipped_to_timed_out_if_expired(now_ms).unwrap_or(record); + // Only drain Done. Pending (awaiting operator) and InFlight + // (invoke in progress) stay in state. + if record.status != Status::Done { continue; } + candidates.push(record); + } + + // Phase 2: sort + cap. + candidates.sort_by_key(|r| r.resolved_at.unwrap_or(u64::MAX)); + let total = candidates.len(); + let omitted = total.saturating_sub(limit) as u64; + candidates.truncate(limit); + + // Phase 3: delete-and-return. + let mut entries: Vec = Vec::with_capacity(candidates.len()); + for record in candidates { + let key = pending_key(session_id, &record.function_call_id); + if bus.delete(state_scope, &key).await.is_ok() { + entries.push(record.to_value()); + } + } + json!({ "ok": true, "entries": entries, "omitted": omitted }) +} + +/// Force-cancel every non-terminal row in a session by flipping it to +/// `Done(TimedOut)`. Called from `run::stop` so a stale UI modal cannot +/// still execute its function after the operator clicks Stop. Lazy +/// timeout is not a substitute — default `expires_at` is 5 min and we +/// cannot leave a 5-min stale-modal window after Stop. +pub async fn handle_sweep_session( + bus: &dyn StateBus, + state_scope: &str, + payload: Value, +) -> Value { + let session_id = payload.get("session_id").and_then(Value::as_str).unwrap_or(""); + if session_id.is_empty() { + return json!({ "ok": false, "error": "missing_session_id", "swept": 0 }); + } + let now_ms = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis() as u64) + .unwrap_or(0); + + let prefix = format!("{session_id}/"); + let rows = bus.list_prefix(state_scope, &prefix).await; + let mut swept = 0u64; + + for raw in rows { + let Some(record) = Record::from_value(raw) else { continue }; + if record.session_id != session_id { continue; } // defensive + if record.status == Status::Done { continue; } // already terminal + + let key = pending_key(session_id, &record.function_call_id); + let timed_out = record.done_at(now_ms, Outcome::TimedOut); + if bus.set(state_scope, &key, timed_out.to_value()).await.is_ok() { + swept += 1; + } + } + json!({ "ok": true, "swept": swept }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::record::{Outcome, Record}; + use serde_json::json; + use std::sync::Mutex; + + #[derive(Default)] + struct InMemBus { + rows: Mutex>, + } + #[async_trait::async_trait] + impl StateBus for InMemBus { + async fn set(&self, scope: &str, key: &str, value: Value) -> Result<(), iii_sdk::IIIError> { + self.rows.lock().unwrap().insert((scope.into(), key.into()), value); + Ok(()) + } + async fn get(&self, scope: &str, key: &str) -> Option { + self.rows.lock().unwrap().get(&(scope.into(), key.into())).cloned() + } + async fn list_prefix(&self, scope: &str, prefix: &str) -> Vec { + self.rows.lock().unwrap() + .iter() + .filter(|((s, k), _)| s == scope && k.starts_with(prefix)) + .map(|(_, v)| v.clone()) + .collect() + } + async fn delete(&self, scope: &str, key: &str) -> Result<(), iii_sdk::IIIError> { + self.rows.lock().unwrap().remove(&(scope.into(), key.into())); + Ok(()) + } + } + + async fn seed_done(bus: &InMemBus, session: &str, cid: &str, resolved_at: u64) { + let r = Record::pending( + cid.into(), "shell::exec".into(), + json!({"command": "ls"}), session.into(), 0, 60_000, + ).in_flight(resolved_at).done(Outcome::Executed { result: json!({"cid": cid}) }); + bus.set("approvals", &format!("{session}/{cid}"), r.to_value()).await.unwrap(); + } + + async fn seed_pending(bus: &InMemBus, session: &str, cid: &str, expires_at: u64) { + let mut r = Record::pending( + cid.into(), "shell::exec".into(), + json!({}), session.into(), 0, 60_000); + r.expires_at = expires_at; + bus.set("approvals", &format!("{session}/{cid}"), r.to_value()).await.unwrap(); + } + + #[tokio::test] + async fn consume_returns_done_rows_and_deletes_them() { + let bus = InMemBus::default(); + seed_done(&bus, "sess_a", "tc-1", 100).await; + seed_done(&bus, "sess_a", "tc-2", 200).await; + let reply = handle_consume(&bus, "approvals", + json!({"session_id": "sess_a"}), 1_000).await; + assert_eq!(reply["ok"], true); + assert_eq!(reply["omitted"], 0); + let entries = reply["entries"].as_array().unwrap(); + assert_eq!(entries.len(), 2); + assert!(bus.get("approvals", "sess_a/tc-1").await.is_none()); + assert!(bus.get("approvals", "sess_a/tc-2").await.is_none()); + } + + #[tokio::test] + async fn consume_skips_pending_rows() { + let bus = InMemBus::default(); + seed_done(&bus, "sess_a", "tc-1", 100).await; + seed_pending(&bus, "sess_a", "tc-2", 999_999).await; + let reply = handle_consume(&bus, "approvals", + json!({"session_id": "sess_a"}), 1_000).await; + let entries = reply["entries"].as_array().unwrap(); + assert_eq!(entries.len(), 1); + assert!(bus.get("approvals", "sess_a/tc-2").await.is_some()); + } + + #[tokio::test] + async fn consume_lazy_flips_expired_pending_then_returns_and_deletes() { + let bus = InMemBus::default(); + seed_pending(&bus, "sess_a", "tc-1", 500).await; + let reply = handle_consume(&bus, "approvals", + json!({"session_id": "sess_a"}), 1_000).await; + let entries = reply["entries"].as_array().unwrap(); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0]["status"], "done"); + assert_eq!(entries[0]["outcome"]["kind"], "timed_out"); + assert!(bus.get("approvals", "sess_a/tc-1").await.is_none()); + } + + #[tokio::test] + async fn consume_sorts_by_resolved_at_ascending() { + let bus = InMemBus::default(); + seed_done(&bus, "sess_a", "tc-z-late", 300).await; + seed_done(&bus, "sess_a", "tc-a-early", 100).await; + seed_done(&bus, "sess_a", "tc-m-mid", 200).await; + let reply = handle_consume(&bus, "approvals", + json!({"session_id": "sess_a"}), 1_000).await; + let entries = reply["entries"].as_array().unwrap(); + assert_eq!(entries[0]["function_call_id"], "tc-a-early"); + assert_eq!(entries[1]["function_call_id"], "tc-m-mid"); + assert_eq!(entries[2]["function_call_id"], "tc-z-late"); + } + + #[tokio::test] + async fn consume_cap_with_omitted_counter() { + let bus = InMemBus::default(); + for i in 0..60 { + seed_done(&bus, "sess_a", &format!("tc-{i:02}"), i as u64).await; + } + let reply = handle_consume(&bus, "approvals", + json!({"session_id": "sess_a", "limit": 50}), 1_000).await; + let entries = reply["entries"].as_array().unwrap(); + assert_eq!(entries.len(), 50); + assert_eq!(reply["omitted"], 10); + let still_there = bus.list_prefix("approvals", "sess_a/").await; + assert_eq!(still_there.len(), 10); + } + + #[tokio::test] + async fn consume_missing_session_id_returns_error() { + let bus = InMemBus::default(); + let reply = handle_consume(&bus, "approvals", json!({}), 1_000).await; + assert_eq!(reply["ok"], false); + assert_eq!(reply["error"], "missing_session_id"); + } + + #[tokio::test] + async fn consume_defensive_session_id_filter_drops_foreign_rows() { + let bus = InMemBus::default(); + let r = Record::pending( + "tc-x".into(), "shell::exec".into(), json!({}), + "sess_b".into(), // WRONG session in data + 0, 60_000, + ).in_flight(100).done(Outcome::Executed { result: json!({}) }); + bus.set("approvals", "sess_a/tc-x", r.to_value()).await.unwrap(); + + let reply = handle_consume(&bus, "approvals", + json!({"session_id": "sess_a"}), 1_000).await; + let entries = reply["entries"].as_array().unwrap(); + assert_eq!(entries.len(), 0); + assert!(bus.get("approvals", "sess_a/tc-x").await.is_some(), + "defensive: row stays in state, NOT deleted"); + } + + #[tokio::test] + async fn sweep_flips_pending_and_done_untouched() { + let bus = InMemBus::default(); + let pending = Record::pending( + "tc-1".into(), "shell::exec".into(), json!({}), + "sess_a".into(), 0, 60_000); + bus.set("approvals", "sess_a/tc-1", pending.to_value()).await.unwrap(); + + let in_flight = Record::pending( + "tc-2".into(), "shell::exec".into(), json!({}), + "sess_a".into(), 0, 60_000).in_flight(500); + bus.set("approvals", "sess_a/tc-2", in_flight.to_value()).await.unwrap(); + + let done = Record::pending( + "tc-3".into(), "shell::exec".into(), json!({}), + "sess_a".into(), 0, 60_000) + .in_flight(100).done(Outcome::Executed { result: json!({}) }); + bus.set("approvals", "sess_a/tc-3", done.to_value()).await.unwrap(); + + let reply = handle_sweep_session(&bus, "approvals", + json!({"session_id": "sess_a"})).await; + assert_eq!(reply["swept"], 2); + + let r1 = Record::from_value(bus.get("approvals", "sess_a/tc-1").await.unwrap()).unwrap(); + assert!(matches!(r1.outcome, Some(Outcome::TimedOut))); + let r2 = Record::from_value(bus.get("approvals", "sess_a/tc-2").await.unwrap()).unwrap(); + assert!(matches!(r2.outcome, Some(Outcome::TimedOut))); + let r3 = Record::from_value(bus.get("approvals", "sess_a/tc-3").await.unwrap()).unwrap(); + assert!(matches!(r3.outcome, Some(Outcome::Executed { .. })), + "already-Done rows must not be re-stamped"); + } + + #[tokio::test] + async fn list_pending_lazy_flips_expired_rows_out_of_the_list() { + let bus = InMemBus::default(); + // tc-live: expires far in the future (year ~5138). tc-expired: + // expires near epoch — definitely past now. + seed_pending(&bus, "sess_a", "tc-live", u64::MAX).await; + seed_pending(&bus, "sess_a", "tc-expired", 500).await; + // Advance the system clock indirectly: just trust the inline now_ms + // in handle_list_pending. expires_at=500 < now_ms, so it should flip. + // Wait briefly to ensure SystemTime::now() > 500ms since UNIX_EPOCH + // (it's well past 1970, so any current time satisfies this). + let reply = handle_list_pending(&bus, "approvals", + json!({"session_id": "sess_a"})).await; + let pending = reply["pending"].as_array().unwrap(); + assert_eq!(pending.len(), 1); + assert_eq!(pending[0]["function_call_id"], "tc-live"); + + // Expired row is now persisted as Done(TimedOut). + let r = Record::from_value(bus.get("approvals", "sess_a/tc-expired").await.unwrap()).unwrap(); + assert!(matches!(r.outcome, Some(Outcome::TimedOut))); + } +} diff --git a/approval-gate/src/intercept.rs b/approval-gate/src/intercept.rs new file mode 100644 index 00000000..5e5d0a24 --- /dev/null +++ b/approval-gate/src/intercept.rs @@ -0,0 +1,273 @@ +//! Intercept decision flow. +//! +//! One async entry point: [`handle_intercept`]. The classifier surface and +//! per-function `InterceptorRule` flow are gone — the layered rules engine +//! (`crate::rules`) is the only policy decision. `handle_intercept` reads +//! the verdict via [`crate::verdict_for`] and writes a `Pending` row when +//! the verdict is `Ask`. Allow/Deny verdicts return synchronous replies +//! and never touch state. +//! +//! Replay defense recognises all three persisted states: +//! - `Pending` → reply `{replay:"in_flight", status:"pending"}` +//! - `InFlight` → reply `{replay:"in_flight", status:"in_flight"}` +//! - `Done` → reply `{replay:"already_resolved", status:"done"}` +//! +//! None of these overwrite the existing row. + +use serde_json::{json, Value}; + +use crate::record::{Record, Status}; +use crate::rules::Ruleset; +use crate::state::StateBus; +use crate::wire::{pending_key, Denial, IncomingCall}; + +/// Subscriber-side entry point. Decides via `verdict_for` (rules layer); +/// on Ask, persists a Pending record. State-write failure fails closed +/// with `Denial::StateError` so a transient kv outage cannot silently +/// bypass an approval check. +pub async fn handle_intercept( + bus: &dyn StateBus, + state_scope: &str, + call: &IncomingCall, + rules: &Ruleset, + now_ms: u64, + timeout_ms: u64, +) -> Value { + // 1. Rules pre-check. + match crate::verdict_for(&call.function_id, &call.args, rules) { + crate::Verdict::Allow => return json!({ "block": false }), + crate::Verdict::Deny(denial) => { + return json!({ + "block": true, + "status": "denied", + "denial": denial, + "call_id": call.function_call_id, + "function_id": call.function_id, + }); + } + crate::Verdict::Ask => { /* fall through */ } + } + + // 2. Replay defense — never overwrite an existing row. + let key = pending_key(&call.session_id, &call.function_call_id); + if let Some(existing_raw) = bus.get(state_scope, &key).await { + if let Some(existing) = Record::from_value(existing_raw) { + return match existing.status { + Status::Done => json!({ + "block": true, + "status": "done", + "replay": "already_resolved", + "call_id": call.function_call_id, + "function_id": call.function_id, + }), + Status::Pending => json!({ + "block": true, + "status": "pending", + "replay": "in_flight", + "call_id": call.function_call_id, + "function_id": call.function_id, + }), + Status::InFlight => json!({ + "block": true, + "status": "in_flight", + "replay": "in_flight", + "call_id": call.function_call_id, + "function_id": call.function_id, + }), + }; + } + // Malformed row → fall through and overwrite (defensive). + } + + // 3. Fresh Pending write. + let record = Record::pending( + call.function_call_id.clone(), + call.function_id.clone(), + call.args.clone(), + call.session_id.clone(), + now_ms, + timeout_ms, + ); + if let Err(err) = bus.set(state_scope, &key, record.to_value()).await { + tracing::error!( + "approval-gate: failed to write pending record for {}/{}: {err} — failing closed", + call.session_id, call.function_call_id, + ); + let denial = Denial::StateError { + phase: "intercept_write_pending".to_string(), + error: err.to_string(), + }; + return json!({ + "block": true, + "denial": denial, + "status": "denied", + "call_id": call.function_call_id, + "function_id": call.function_id, + }); + } + json!({ + "block": true, + "status": "pending", + "call_id": call.function_call_id, + "function_id": call.function_id, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::record::{Outcome, Record}; + use crate::rules::{Action, Rule}; + use serde_json::json; + use std::sync::Mutex; + + #[derive(Default)] + struct InMemBus { + rows: Mutex>, + } + #[async_trait::async_trait] + impl StateBus for InMemBus { + async fn set(&self, scope: &str, key: &str, value: Value) -> Result<(), iii_sdk::IIIError> { + self.rows.lock().unwrap().insert((scope.into(), key.into()), value); + Ok(()) + } + async fn get(&self, scope: &str, key: &str) -> Option { + self.rows.lock().unwrap().get(&(scope.into(), key.into())).cloned() + } + async fn list_prefix(&self, scope: &str, prefix: &str) -> Vec { + self.rows.lock().unwrap() + .iter() + .filter(|((s, k), _)| s == scope && k.starts_with(prefix)) + .map(|(_, v)| v.clone()) + .collect() + } + async fn delete(&self, scope: &str, key: &str) -> Result<(), iii_sdk::IIIError> { + self.rows.lock().unwrap().remove(&(scope.into(), key.into())); + Ok(()) + } + } + + fn call(fc_id: &str, fn_id: &str, args: Value) -> IncomingCall { + IncomingCall { + session_id: "sess_a".into(), + function_call_id: fc_id.into(), + function_id: fn_id.into(), + args, + approval_required: Vec::new(), + event_id: "evt-1".into(), + reply_stream: "hk-1".into(), + } + } + + #[tokio::test] + async fn allow_rule_returns_block_false_no_state_write() { + let bus = InMemBus::default(); + let rs: Ruleset = vec![Rule { + permission: "shell::exec".into(), + pattern: "git status*".into(), + action: Action::Allow, + }]; + let c = call("tc-1", "shell::exec", json!({"command": "git", "args": ["status"]})); + let reply = handle_intercept(&bus, "approvals", &c, &rs, 1_000, 60_000).await; + assert_eq!(reply["block"], false); + assert!(bus.list_prefix("approvals", "sess_a/").await.is_empty()); + } + + #[tokio::test] + async fn deny_rule_returns_block_true_structured_policy_denial() { + let bus = InMemBus::default(); + let rs: Ruleset = vec![Rule { + permission: "shell::exec".into(), + pattern: "rm -rf*".into(), + action: Action::Deny, + }]; + let c = call("tc-1", "shell::exec", json!({"command": "rm", "args": ["-rf", "/"]})); + let reply = handle_intercept(&bus, "approvals", &c, &rs, 1_000, 60_000).await; + assert_eq!(reply["block"], true); + assert_eq!(reply["denial"]["kind"], "policy"); + assert_eq!(reply["denial"]["detail"]["rule_permission"], "shell::exec"); + assert_eq!(reply["denial"]["detail"]["rule_pattern"], "rm -rf*"); + assert!(bus.list_prefix("approvals", "sess_a/").await.is_empty()); + } + + #[tokio::test] + async fn no_match_defaults_to_ask_writes_pending() { + let bus = InMemBus::default(); + let rs: Ruleset = vec![]; + let c = call("tc-1", "shell::exec", json!({"command": "git", "args": ["push"]})); + let reply = handle_intercept(&bus, "approvals", &c, &rs, 1_000, 60_000).await; + assert_eq!(reply["block"], true); + assert_eq!(reply["status"], "pending"); + let stored = bus.get("approvals", "sess_a/tc-1").await.expect("pending row"); + let r = Record::from_value(stored).unwrap(); + assert_eq!(r.status, Status::Pending); + } + + #[tokio::test] + async fn replay_on_pending_returns_in_flight_no_state_churn() { + let bus = InMemBus::default(); + let rs: Ruleset = vec![]; + let c = call("tc-1", "shell::exec", json!({"command": "ls"})); + + handle_intercept(&bus, "approvals", &c, &rs, 1_000, 60_000).await; + let r2 = handle_intercept(&bus, "approvals", &c, &rs, 2_000, 60_000).await; + assert_eq!(r2["replay"], "in_flight"); + assert_eq!(r2["status"], "pending"); + + let r = Record::from_value(bus.get("approvals", "sess_a/tc-1").await.unwrap()).unwrap(); + assert_eq!(r.expires_at, 61_000, "second call must NOT have re-written expires_at"); + } + + #[tokio::test] + async fn replay_on_in_flight_returns_in_flight_marker() { + let bus = InMemBus::default(); + let in_flight = Record::pending( + "tc-1".into(), "shell::exec".into(), json!({}), + "sess_a".into(), 0, 60_000, + ).in_flight(500); + bus.set("approvals", "sess_a/tc-1", in_flight.to_value()).await.unwrap(); + + let rs: Ruleset = vec![]; + let c = call("tc-1", "shell::exec", json!({})); + let reply = handle_intercept(&bus, "approvals", &c, &rs, 1_000, 60_000).await; + assert_eq!(reply["replay"], "in_flight"); + assert_eq!(reply["status"], "in_flight"); + } + + #[tokio::test] + async fn replay_on_done_returns_already_resolved_marker() { + let bus = InMemBus::default(); + let done = Record::pending( + "tc-1".into(), "shell::exec".into(), json!({}), + "sess_a".into(), 0, 60_000, + ).in_flight(500).done(Outcome::Executed { result: json!({"ok": true}) }); + bus.set("approvals", "sess_a/tc-1", done.to_value()).await.unwrap(); + + let rs: Ruleset = vec![]; + let c = call("tc-1", "shell::exec", json!({})); + let reply = handle_intercept(&bus, "approvals", &c, &rs, 1_000, 60_000).await; + assert_eq!(reply["replay"], "already_resolved"); + assert_eq!(reply["status"], "done"); + } + + #[tokio::test] + async fn state_write_failure_fails_closed_with_state_error_denial() { + struct FailBus; + #[async_trait::async_trait] + impl StateBus for FailBus { + async fn set(&self, _: &str, _: &str, _: Value) -> Result<(), iii_sdk::IIIError> { + Err(iii_sdk::IIIError::Runtime("kv down".into())) + } + async fn get(&self, _: &str, _: &str) -> Option { None } + async fn list_prefix(&self, _: &str, _: &str) -> Vec { Vec::new() } + async fn delete(&self, _: &str, _: &str) -> Result<(), iii_sdk::IIIError> { Ok(()) } + } + let rs: Ruleset = vec![]; + let c = call("tc-1", "shell::exec", json!({})); + let reply = handle_intercept(&FailBus, "approvals", &c, &rs, 1_000, 60_000).await; + assert_eq!(reply["block"], true); + assert_eq!(reply["status"], "denied"); + assert_eq!(reply["denial"]["kind"], "state_error"); + assert_eq!(reply["denial"]["detail"]["phase"], "intercept_write_pending"); + } +} diff --git a/approval-gate/src/lib.rs b/approval-gate/src/lib.rs index 19616766..80c0a65a 100644 --- a/approval-gate/src/lib.rs +++ b/approval-gate/src/lib.rs @@ -1,3910 +1,63 @@ -//! Approval gate. Subscribes to `agent::before_function_call` and blocks calls -//! whose `function_call.function_id` appears in the run's `approval_required` list, -//! waiting for the UI to call `approval::resolve` (or for a timeout). +//! Approval gate. Subscribes to `agent::before_function_call` and decides +//! every call via the layered rules engine (`rules::evaluate`). Allow → +//! `{block:false}`. Deny → `{block:true, denial:Policy{rule_permission, +//! rule_pattern}}`. Ask → write a Pending record and wait for +//! `approval::resolve`. pub mod config; +pub mod delivery; +pub mod intercept; pub mod manifest; +pub mod record; +pub mod register; +pub mod resolve; +pub mod rules; +pub mod state; +pub mod wire; pub use config::{InterceptorRule, WorkerConfig}; - -use std::sync::Arc; - -use iii_sdk::{ - FunctionRef, IIIError, RegisterFunctionMessage, RegisterTriggerInput, TriggerRequest, III, +pub use delivery::{ + handle_consume, handle_list_pending, handle_sweep_session, CONSUME_DEFAULT_LIMIT, +}; +pub use intercept::handle_intercept; +pub use record::{Outcome, Record, Status}; +pub use register::{ + register, Refs, FN_CONSUME, FN_LIST_PENDING, FN_LOOKUP_RECORD, FN_RESOLVE, FN_SWEEP_SESSION, + STATE_SCOPE, +}; +pub use resolve::{handle_lookup_record, handle_resolve}; +pub use state::{FunctionExecutor, IiiFunctionExecutor, IiiStateBus, StateBus}; +pub use wire::{ + block_reply_for, extract_call, pending_key, Decision, Denial, IncomingCall, WireDecision, }; -use serde_json::{json, Value}; - -pub const FN_RESOLVE: &str = "approval::resolve"; -pub const FN_LIST_PENDING: &str = "approval::list_pending"; -pub const FN_LIST_UNDELIVERED: &str = "approval::list_undelivered"; -pub const FN_CONSUME_UNDELIVERED: &str = "approval::consume_undelivered"; -pub const FN_ACK_DELIVERED: &str = "approval::ack_delivered"; -pub const FN_FLUSH_DELIVERED: &str = "approval::flush_delivered"; -pub const FN_SWEEP_SESSION: &str = "approval::sweep_session"; -pub const FN_LOOKUP_RECORD: &str = "approval::lookup_record"; -/// Default `approval_state_scope` (matches [`WorkerConfig::default`]). -pub const STATE_SCOPE: &str = "approvals"; - -fn rule_for<'a>(rules: &'a [InterceptorRule], function_id: &str) -> Option<&'a InterceptorRule> { - rules.iter().find(|r| r.function_id == function_id) -} - -/// What the subscriber should do with an incoming call. Decided by the -/// matching interceptor rule (authoritative) with a fallback to the run's -/// `approval_required` list when no rule exists. -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum InterceptAction { - /// No rule, no `approval_required` listing — let the call through. - Pass, - /// Pause and create a pending record; no classifier consulted. - Pause, - /// Run the classifier first; on `ask`, pause; on `auto`, pass; on `deny`, block. - Classify { - classifier_fn: String, - classifier_timeout_ms: u64, - }, -} - -/// Pure decision: given a matching rule (or none) and whether the run -/// explicitly listed this function id in `approval_required`, what should -/// the subscriber do? Interceptor rules are authoritative — an operator -/// who registered a rule meant for every call to go through it, regardless -/// of per-run opt-in. -pub(crate) fn decide_intercept_action( - rule: Option<&InterceptorRule>, - requires_approval: bool, -) -> InterceptAction { - match rule { - Some(r) if r.classifier.as_ref().is_some_and(|s| !s.is_empty()) => { - InterceptAction::Classify { - classifier_fn: r.classifier.clone().unwrap(), - classifier_timeout_ms: r.classifier_timeout_ms, - } - } - Some(_) => InterceptAction::Pause, - None if requires_approval => InterceptAction::Pause, - None => InterceptAction::Pass, - } -} - -fn merge_from_approval_marker_if_needed( - inject: bool, - args: Value, - function_call_id: &str, - session_id: &str, -) -> Value { - if !inject { - return args; - } - let marker = json!({ - "call_id": function_call_id, - "session_id": session_id, - }); - match args { - Value::Object(mut m) => { - m.insert("__from_approval".into(), marker); - Value::Object(m) - } - other if other.is_null() => json!({ "__from_approval": marker }), - other => json!({ - "payload": other, - "__from_approval": marker, - }), - } -} - -/// Structured deny payload carried on wire replies, persisted records, and -/// `approval_resolved` stream events. Replaces the legacy free-form -/// `decision_reason` / `reason` strings so consumers (turn-orchestrator -/// stitching, UIs, the LLM) can branch on `kind` instead of parsing prose. -/// -/// Wire shape (serde tag=kind, content=detail, snake_case): -/// `{ "kind": "policy", "detail": { "classifier_reason": "...", "classifier_fn": "..." } }` -/// `{ "kind": "user_rejected", "detail": null }` -/// `{ "kind": "user_corrected", "detail": { "feedback": "..." } }` -/// `{ "kind": "state_error", "detail": { "phase": "...", "error": "..." } }` -/// `{ "kind": "legacy", "detail": { "reason": "..." } }` -/// -/// `Legacy` is the read-time landing pad for records persisted before this -/// type existed (see [`migrate_legacy_record`]). New writes never emit it. -#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -#[serde(tag = "kind", content = "detail", rename_all = "snake_case")] -pub enum Denial { - Policy { - classifier_reason: String, - classifier_fn: String, - }, - UserRejected, - UserCorrected { - feedback: String, - }, - StateError { - phase: String, - error: String, - }, - Legacy { - reason: String, - }, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) enum ClassifierDecision { - Auto, - Deny(Denial), - Ask, -} - -/// Parse classifier JSON (`decision` tag: auto | deny | ask). On `deny` -/// the reply may carry `reason` (free-form classifier text) and optionally -/// `classifier_fn` — both get folded into a [`Denial::Policy`]. -pub(crate) fn interpret_classifier_reply( - value: &Value, - classifier_fn: &str, -) -> Result { - let tag = value.get("decision").and_then(Value::as_str).ok_or(())?; - match tag { - "auto" => Ok(ClassifierDecision::Auto), - "deny" => { - let classifier_reason = value - .get("reason") - .and_then(Value::as_str) - .unwrap_or("denied") - .to_string(); - Ok(ClassifierDecision::Deny(Denial::Policy { - classifier_reason, - classifier_fn: classifier_fn.to_string(), - })) - } - "ask" => Ok(ClassifierDecision::Ask), - _ => Err(()), - } -} - -/// True if `status` is one of the terminal states a stitched system message -/// should be built from. `pending` and `approved` are intermediate. -pub fn is_terminal_status(status: &str) -> bool { - matches!(status, "executed" | "failed" | "denied" | "timed_out") -} - -#[derive(Debug, Clone, PartialEq)] -pub struct IncomingCall { - pub session_id: String, - pub function_call_id: String, - pub function_id: String, - pub args: Value, - pub approval_required: Vec, - pub event_id: String, - pub reply_stream: String, -} - -impl IncomingCall { - pub fn requires_approval(&self) -> bool { - self.approval_required - .iter() - .any(|n| n == &self.function_id) - } -} +/// Subscriber's terminal verdict for an incoming call. #[derive(Debug, Clone, PartialEq, Eq)] -pub enum Decision { +pub(crate) enum Verdict { Allow, Deny(Denial), + Ask, } -/// Wire-format decision string used by `approval::resolve` and stored -/// as the `status` field of resolved approval records. -/// -/// Serializes / deserializes as `"allow"` or `"deny"`. -#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum WireDecision { - Allow, - Deny, -} - -/// Build the state-store key for a pending approval entry. -/// -/// `session_id` and `function_call_id` must not contain `/`. They are caller-controlled -/// IDs minted by turn-orchestrator; today neither format uses the separator. -pub fn pending_key(session_id: &str, function_call_id: &str) -> String { - debug_assert!(!session_id.contains('/'), "session_id must not contain '/'"); - debug_assert!( - !function_call_id.contains('/'), - "function_call_id must not contain '/'" - ); - format!("{session_id}/{function_call_id}") -} - -pub fn extract_call(envelope: &Value) -> Option { - let event_id = envelope - .get("event_id") - .and_then(Value::as_str)? - .to_string(); - let reply_stream = envelope - .get("reply_stream") - .and_then(Value::as_str)? - .to_string(); - let inner = envelope.get("payload").unwrap_or(envelope); - let session_id = inner.get("session_id").and_then(Value::as_str)?.to_string(); - let fc = inner - .get("function_call") - .or_else(|| inner.get("tool_call"))?; - let function_id = fc - .get("function_id") - .or_else(|| fc.get("name")) - .and_then(Value::as_str)? - .to_string(); - Some(IncomingCall { - session_id, - function_call_id: fc.get("id").and_then(Value::as_str)?.to_string(), - function_id, - args: fc.get("arguments").cloned().unwrap_or_else(|| json!({})), - approval_required: inner - .get("approval_required") - .and_then(|v| serde_json::from_value(v.clone()).ok()) - .unwrap_or_default(), - event_id, - reply_stream, - }) -} - -pub fn build_pending_record( - function_call_id: &str, +/// Apply the layered rules to an incoming call. Last-matching rule wins; +/// no match defaults to Ask (operator-safe default — paired with the +/// curated default ruleset shipped in `iii.worker.yaml`). +pub(crate) fn verdict_for( function_id: &str, - args: &Value, - now_ms: u64, - timeout_ms: u64, -) -> Value { - json!({ - "function_call_id": function_call_id, - "function_id": function_id, - "args": args, - "status": "pending", - "expires_at": now_ms.saturating_add(timeout_ms), - }) -} - -/// Build a new record by transitioning a pending base record to a terminal -/// status. All terminal fields (`result`, `error`, `denial`) are optional; -/// only the ones provided are attached. Existing fields on the base -/// (including `delivered_in_turn_id` and `resolved_at` if present) are -/// preserved. The first transition into a terminal status stamps -/// `resolved_at`. -pub fn transition_record( - base: &Value, - new_status: &str, - result: Option, - error: Option, - denial: Option, -) -> Value { - let now_ms = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_millis() as u64) - .unwrap_or(0); - transition_record_with_now(base, new_status, result, error, denial, now_ms) -} - -/// Testable variant of [`transition_record`] that takes `now_ms` directly. -pub fn transition_record_with_now( - base: &Value, - new_status: &str, - result: Option, - error: Option, - denial: Option, - now_ms: u64, -) -> Value { - let mut rec = base.clone(); - if let Some(obj) = rec.as_object_mut() { - obj.insert("status".into(), Value::String(new_status.to_string())); - if let Some(r) = result { - obj.insert("result".into(), r); - } - if let Some(e) = error { - obj.insert("error".into(), Value::String(e)); - } - if let Some(d) = denial { - obj.insert( - "denial".into(), - serde_json::to_value(&d).expect("Denial is always serializable"), - ); - } - if is_terminal_status(new_status) && !obj.contains_key("resolved_at") { - obj.insert("resolved_at".into(), Value::Number(now_ms.into())); - } - } - rec -} - -/// Build the hook block reply for a [`Decision`]. Deny replies carry the -/// structured [`Denial`] under `denial`; consumers (turn-orchestrator -/// stitching, UIs, the LLM) branch on `denial.kind` rather than parsing a -/// free-form `reason` string. -pub fn block_reply_for(decision: &Decision) -> Value { - match decision { - Decision::Allow => json!({ "block": false }), - Decision::Deny(denial) => json!({ - "block": true, - "denial": denial, - }), - } -} - -pub struct Refs { - pub resolve: FunctionRef, - pub list_pending: FunctionRef, - pub list_undelivered: FunctionRef, - pub consume_undelivered: FunctionRef, - pub ack_delivered: FunctionRef, - pub flush_delivered: FunctionRef, - pub sweep_session: FunctionRef, - pub lookup_record: FunctionRef, - pub subscriber_fn: FunctionRef, - pub subscriber_trigger: iii_sdk::Trigger, - /// Background task that flips expired pending records to `timed_out` and - /// emits the corresponding `approval_resolved` events. Kept alive by - /// virtue of being held here; aborts when the worker shuts down. - pub sweeper: tokio::task::JoinHandle<()>, -} - -#[async_trait::async_trait] -pub trait StateBus: Send + Sync { - async fn set(&self, scope: &str, key: &str, value: Value) -> Result<(), iii_sdk::IIIError>; - async fn get(&self, scope: &str, key: &str) -> Option; - async fn list_prefix(&self, scope: &str, prefix: &str) -> Vec; -} - -/// Invokes an iii function with arguments and returns its result or an error -/// string. Abstracted so tests can stub the underlying call. -#[async_trait::async_trait] -pub trait FunctionExecutor: Send + Sync { - async fn invoke( - &self, - function_id: &str, - args: Value, - function_call_id: &str, - session_id: &str, - ) -> Result; -} - -/// Production [`FunctionExecutor`] backed by `iii.trigger`. -pub struct IiiFunctionExecutor { - pub iii: III, - pub rules: Arc>, -} - -#[async_trait::async_trait] -impl FunctionExecutor for IiiFunctionExecutor { - async fn invoke( - &self, - function_id: &str, - args: Value, - function_call_id: &str, - session_id: &str, - ) -> Result { - let inject = - rule_for(self.rules.as_slice(), function_id).is_some_and(|r| r.inject_approval_marker); - let payload = - merge_from_approval_marker_if_needed(inject, args, function_call_id, session_id); - self.iii - .trigger(TriggerRequest { - function_id: function_id.to_string(), - payload, - action: None, - timeout_ms: None, - }) - .await - .map_err(|e| e.to_string()) - } -} - -/// Decide whether a call is gated; if so, write a pending record and return -/// the structured pending hook reply. If not gated, return `{block: false}` -/// and do nothing. -/// -/// Stamps `session_id` onto the persisted record so the timeout sweeper can -/// emit `approval_resolved` to the right session stream without consulting -/// the storage layer's keys. -/// -/// State-write failure is treated as fail-closed: the gate replies -/// `{block:true, status:"denied"}` so a transient kv outage cannot silently -/// bypass an approval check. -pub async fn handle_intercept( - bus: &dyn StateBus, - state_scope: &str, - call: &IncomingCall, - now_ms: u64, - timeout_ms: u64, - force_pending: bool, -) -> Value { - if !force_pending && !call.requires_approval() { - return json!({ "block": false }); - } - - // Defense in depth: if a record for this (session, call_id) already - // exists, don't blow it away. Re-intercept of an already-decided call - // would otherwise revert a terminal record back to `pending`, losing - // the audit trail and any `delivered_in_turn_id` stamp. Surfaced by - // the state-machine proptest in tests::state_machine_invariants. - let key = pending_key(&call.session_id, &call.function_call_id); - if let Some(existing) = bus.get(state_scope, &key).await { - let status = existing - .get("status") - .and_then(Value::as_str) - .unwrap_or("") - .to_string(); - if is_terminal_status(&status) { - // Replay of an already-resolved call: the prior status carries - // the meaning. No fresh Denial is synthesized — consumers that - // need to render the historical decision read the persisted - // record via approval::lookup_record. - return json!({ - "block": true, - "status": status, - "replay": "already_resolved", - "call_id": call.function_call_id, - "function_id": call.function_id, - }); - } - if status == "pending" || status == "approved" { - // Replay of an in-flight intercept — keep the existing row, - // re-emit the pending reply. No state churn. - return json!({ - "block": true, - "status": "pending", - "replay": "in_flight", - "call_id": call.function_call_id, - "function_id": call.function_id, - }); - } - } - - let mut record = build_pending_record( - &call.function_call_id, - &call.function_id, - &call.args, - now_ms, - timeout_ms, - ); - if let Some(obj) = record.as_object_mut() { - obj.insert("session_id".into(), Value::String(call.session_id.clone())); - } - if let Err(err) = bus - .set( - state_scope, - &pending_key(&call.session_id, &call.function_call_id), - record, - ) - .await - { - tracing::error!( - "approval-gate: failed to write pending record for {}/{}: {err} — failing closed", - call.session_id, - call.function_call_id - ); - let denial = Denial::StateError { - phase: "intercept_write_pending".to_string(), - error: err.to_string(), - }; - return json!({ - "block": true, - "denial": denial, - "status": "denied", - "call_id": call.function_call_id, - "function_id": call.function_id, - }); - } - json!({ - "block": true, - "status": "pending", - "call_id": call.function_call_id, - "function_id": call.function_id, - }) -} - -/// Lookup a single approval record by session + call id (for shell bypass validation). -pub async fn handle_lookup_record(bus: &dyn StateBus, state_scope: &str, payload: Value) -> Value { - let session_id = payload - .get("session_id") - .and_then(Value::as_str) - .unwrap_or(""); - let function_call_id = payload - .get("function_call_id") - .and_then(Value::as_str) - .unwrap_or(""); - if session_id.is_empty() || function_call_id.is_empty() { - return Value::Null; - } - let key = pending_key(session_id, function_call_id); - bus.get(state_scope, &key).await.unwrap_or(Value::Null) -} - -/// For a bag of pending records, return the subset that have expired at -/// `now_ms` along with the metadata needed to commit the flip and notify the -/// owning session. Records without a stamped `session_id` (legacy rows -/// written before that field existed) are skipped — they'll still be picked -/// up lazily by `handle_list_undelivered` on the next read. -pub fn collect_timed_out_for_sweep( - records: &[Value], - now_ms: u64, -) -> Vec<(String, Value, String, String)> { - records - .iter() - .filter_map(|rec| { - let flipped = maybe_flip_timed_out(rec, now_ms)?; - let session_id = flipped - .get("session_id") - .and_then(Value::as_str)? - .to_string(); - let function_call_id = flipped - .get("function_call_id") - .and_then(Value::as_str)? - .to_string(); - if session_id.is_empty() || function_call_id.is_empty() { - return None; - } - let key = pending_key(&session_id, &function_call_id); - Some((key, flipped, session_id, function_call_id)) - }) - .collect() -} - -/// Return Some(timed_out_record) if `rec` is pending and `now_ms` is past -/// `expires_at`; otherwise None. Pure function — does not write state. -pub fn maybe_flip_timed_out(rec: &Value, now_ms: u64) -> Option { - if rec.get("status").and_then(Value::as_str) != Some("pending") { - return None; - } - let exp = rec.get("expires_at").and_then(Value::as_u64)?; - if now_ms < exp { - return None; - } - // Timeout flip carries no Denial: the `timed_out` status itself is the - // explanation. Downstream renderers (turn-orchestrator stitching, UIs) - // branch on the status, not on a redundant reason string. - Some(transition_record(rec, "timed_out", None, None, None)) -} - -/// Map a legacy approval record to the current shape. Covers two -/// generations of drift: -/// -/// 1. **Pre-trigger-model** (status `"allow"` / `"deny"`, free-form -/// `reason`): rewritten as `"executed"` / `"denied"`, with the old -/// `reason` folded into a [`Denial::Legacy`]. -/// 2. **Pre-Denial** (`decision_reason: ` field, no `denial`): the -/// string is moved into [`Denial::Legacy { reason }`] and the old key is -/// stripped so writers never resurface it. -/// -/// Returns `None` only when the record is already current. -pub fn migrate_legacy_record(rec: &Value) -> Option { - let status = rec.get("status").and_then(Value::as_str)?; - // Path 1: pre-trigger-model status rename. - let (new_status, denial_to_carry) = match status { - "allow" => ("executed", None), - "deny" => ( - "denied", - rec.get("reason") - .and_then(Value::as_str) - .map(|s| Denial::Legacy { - reason: s.to_string(), - }), - ), - _ => { - // Path 2: status already current, but the record may carry the - // pre-Denial `decision_reason` flat string. Lift it into Denial - // and strip the legacy key; otherwise return None. - let legacy_reason = rec - .get("decision_reason") - .and_then(Value::as_str) - .map(str::to_string); - let needs_lift = legacy_reason.is_some() && rec.get("denial").is_none(); - if !needs_lift { - return None; - } - let denial = Denial::Legacy { - reason: legacy_reason.expect("checked Some above"), - }; - let mut migrated = rec.clone(); - if let Some(obj) = migrated.as_object_mut() { - obj.remove("decision_reason"); - obj.insert( - "denial".into(), - serde_json::to_value(&denial).expect("Denial is always serializable"), - ); - obj.insert("legacy_migrated".into(), Value::Bool(true)); - } - return Some(migrated); - } - }; - let mut migrated = transition_record(rec, new_status, None, None, denial_to_carry); - if let Some(obj) = migrated.as_object_mut() { - // Strip the pre-trigger-model `reason` once it has been folded into - // `denial`; leaving it would create a dead field on the new shape. - obj.remove("reason"); - obj.insert("legacy_migrated".into(), Value::Bool(true)); - } - Some(migrated) -} - -pub async fn handle_resolve( - bus: &dyn StateBus, - exec: &dyn FunctionExecutor, - state_scope: &str, - payload: Value, - now_ms: u64, -) -> Value { - let session_id = payload - .get("session_id") - .and_then(Value::as_str) - .unwrap_or(""); - let function_call_id = payload - .get("function_call_id") - .or_else(|| payload.get("tool_call_id")) - .and_then(Value::as_str) - .unwrap_or(""); - if session_id.is_empty() || function_call_id.is_empty() { - return json!({ "ok": false, "error": "missing_id" }); - } - let decision: WireDecision = match payload.get("decision").cloned() { - Some(v) => match serde_json::from_value(v) { - Ok(d) => d, - Err(_) => return json!({ "ok": false, "error": "bad_decision" }), + args: &serde_json::Value, + rules: &rules::Ruleset, +) -> Verdict { + let pattern = rules::pattern_for(function_id, args); + match rules::evaluate(function_id, &pattern, rules) { + Some(r) => match r.action { + rules::Action::Allow => Verdict::Allow, + rules::Action::Deny => Verdict::Deny(Denial::Policy { + rule_permission: r.permission.clone(), + rule_pattern: r.pattern.clone(), + }), + rules::Action::Ask => Verdict::Ask, }, - None => return json!({ "ok": false, "error": "bad_decision" }), - }; - let key = pending_key(session_id, function_call_id); - let Some(existing) = bus.get(state_scope, &key).await else { - return json!({ "ok": false, "error": "not_found" }); - }; - - // Lazy timeout flip (covered by Task 7 tests). - let existing = match maybe_flip_timed_out(&existing, now_ms) { - Some(flipped) => { - let _ = bus.set(state_scope, &key, flipped.clone()).await; - return json!({ "ok": false, "error": "timed_out" }); - } - None => existing, - }; - - if existing.get("status").and_then(Value::as_str) != Some("pending") { - return json!({ "ok": false, "error": "already_resolved" }); - } - - match decision { - WireDecision::Deny => { - // Caller supplies a structured Denial. Accepted shapes: - // { "decision": "deny", "denial": { "kind": "user_rejected", ... } } - // { "decision": "deny", "denial": { "kind": "user_corrected", "detail": { "feedback": "..." } } } - // Missing `denial` is treated as a bare UserRejected (no feedback) - // so the simplest UI flow stays one-click. - let denial = match payload.get("denial").cloned() { - Some(v) => match serde_json::from_value::(v) { - Ok(d) => d, - Err(_) => return json!({ "ok": false, "error": "bad_denial" }), - }, - None => Denial::UserRejected, - }; - let denied = transition_record(&existing, "denied", None, None, Some(denial)); - if let Err(e) = bus.set(state_scope, &key, denied).await { - tracing::error!("approval-gate: failed to write denied record: {e}"); - return json!({ "ok": false, "error": "state_write_failed" }); - } - json!({ "ok": true }) - } - WireDecision::Allow => { - let function_id = existing - .get("function_id") - .and_then(Value::as_str) - .unwrap_or("") - .to_string(); - let args = existing.get("args").cloned().unwrap_or(json!({})); - let approved = transition_record(&existing, "approved", None, None, None); - // Best-effort intermediate write; if it fails, still try to invoke. - let _ = bus.set(state_scope, &key, approved.clone()).await; - match exec - .invoke(&function_id, args, function_call_id, session_id) - .await - { - Ok(result) => { - let executed = - transition_record(&approved, "executed", Some(result), None, None); - if let Err(e) = bus.set(state_scope, &key, executed).await { - tracing::error!("approval-gate: failed to write executed record: {e}"); - return json!({ "ok": false, "error": "state_write_failed" }); - } - } - Err(error) => { - let failed = transition_record(&approved, "failed", None, Some(error), None); - if let Err(e) = bus.set(state_scope, &key, failed).await { - tracing::error!("approval-gate: failed to write failed record: {e}"); - return json!({ "ok": false, "error": "state_write_failed" }); - } - } - } - json!({ "ok": true }) - } - } -} - -pub async fn handle_list_pending(bus: &dyn StateBus, state_scope: &str, payload: Value) -> Value { - let session_id = payload - .get("session_id") - .and_then(Value::as_str) - .unwrap_or(""); - if session_id.is_empty() { - return json!({ "pending": [] }); - } - let prefix = format!("{session_id}/"); - let all = bus.list_prefix(state_scope, &prefix).await; - let pending: Vec = all - .into_iter() - .filter(|v| { - if migrate_legacy_record(v).is_some() { - return false; - } - v.get("status").and_then(Value::as_str) == Some("pending") - }) - .collect(); - json!({ "pending": pending }) -} - -/// Default cap for `handle_list_undelivered` responses. A single LLM turn -/// should never be asked to ingest more than this many stitched approval -/// messages; older entries beyond the cap stay unacked and are reported via -/// the `omitted` counter so the caller can render a summary line. -pub const LIST_UNDELIVERED_DEFAULT_LIMIT: usize = 50; - -/// Return terminal-status records for a session that haven't been stamped -/// with `delivered_in_turn_id`. Lazy timeout: pending records past -/// `expires_at` (as observed at `now_ms`) are flipped to `timed_out` before -/// the filter so they surface here in the same call. -/// -/// Sorted oldest-first by `resolved_at` (records missing `resolved_at` sort -/// last as `u64::MAX`). Capped at `limit` (default -/// [`LIST_UNDELIVERED_DEFAULT_LIMIT`]); the response always includes an -/// `omitted` field counting entries left behind. -pub async fn handle_list_undelivered( - bus: &dyn StateBus, - state_scope: &str, - payload: Value, - now_ms: u64, -) -> Value { - let session_id = payload - .get("session_id") - .and_then(Value::as_str) - .unwrap_or(""); - if session_id.is_empty() { - return json!({ "entries": [], "omitted": 0 }); - } - let limit = payload - .get("limit") - .and_then(Value::as_u64) - .map(|n| n as usize) - .unwrap_or(LIST_UNDELIVERED_DEFAULT_LIMIT); - let prefix = format!("{session_id}/"); - let all = bus.list_prefix(state_scope, &prefix).await; - let mut entries: Vec = Vec::new(); - for rec in all { - // Defensive scope: some bus backends ignore the prefix and return - // every record in `state_scope`. Filter by stamped `session_id`: - // - // - record has session_id matching ours → keep - // - record has session_id different from ours → drop - // - record lacks session_id AND is in "allow"/"deny" pre-trigger - // legacy form → keep (`migrate_legacy_record` below re-keys it - // under our session) - // - record lacks session_id AND is already terminal → drop - // (orphan from before session-id stamping; cannot be attributed) - match rec.get("session_id").and_then(Value::as_str) { - Some(sid) if sid == session_id => {} - Some(_) => continue, - None => { - let status = rec.get("status").and_then(Value::as_str).unwrap_or(""); - if status != "allow" && status != "deny" { - continue; - } - } - } - let rec = if let Some(migrated) = migrate_legacy_record(&rec) { - let call_id = migrated - .get("function_call_id") - .and_then(Value::as_str) - .unwrap_or(""); - if !call_id.is_empty() { - let _ = bus - .set( - state_scope, - &pending_key(session_id, call_id), - migrated.clone(), - ) - .await; - } - migrated - } else { - rec - }; - let rec = if let Some(flipped) = maybe_flip_timed_out(&rec, now_ms) { - let call_id = flipped - .get("function_call_id") - .and_then(Value::as_str) - .unwrap_or(""); - let _ = bus - .set( - state_scope, - &pending_key(session_id, call_id), - flipped.clone(), - ) - .await; - flipped - } else { - rec - }; - let status = rec.get("status").and_then(Value::as_str).unwrap_or(""); - if !is_terminal_status(status) { - continue; - } - if rec - .get("delivered_in_turn_id") - .is_some_and(|v| !v.is_null()) - { - continue; - } - entries.push(rec); - } - entries.sort_by_key(|e| { - e.get("resolved_at") - .and_then(Value::as_u64) - .unwrap_or(u64::MAX) - }); - let total = entries.len(); - let omitted = total.saturating_sub(limit); - entries.truncate(limit); - json!({ "entries": entries, "omitted": omitted }) -} - -/// Stamp `delivered_in_turn_id` on terminal-status records named in -/// `call_ids` for the given session. Idempotent: records already stamped -/// (non-null `delivered_in_turn_id`) are not overwritten. Unknown call ids -/// are silently skipped. -pub async fn handle_ack_delivered(bus: &dyn StateBus, state_scope: &str, payload: Value) -> Value { - let session_id = payload - .get("session_id") - .and_then(Value::as_str) - .unwrap_or(""); - let turn_id = payload.get("turn_id").and_then(Value::as_str).unwrap_or(""); - let call_ids: Vec = payload - .get("call_ids") - .and_then(|v| v.as_array()) - .map(|arr| { - arr.iter() - .filter_map(|v| v.as_str().map(str::to_string)) - .collect() - }) - .unwrap_or_default(); - if session_id.is_empty() || turn_id.is_empty() || call_ids.is_empty() { - return json!({ "ok": true, "stamped": 0 }); - } - let mut stamped = 0_u64; - for cid in call_ids { - let key = pending_key(session_id, &cid); - let Some(rec) = bus.get(state_scope, &key).await else { - continue; - }; - if rec - .get("delivered_in_turn_id") - .is_some_and(|v| !v.is_null()) - { - continue; - } - let mut next = rec; - next.as_object_mut().unwrap().insert( - "delivered_in_turn_id".into(), - Value::String(turn_id.to_string()), - ); - if bus.set(state_scope, &key, next).await.is_ok() { - stamped += 1; - } - } - json!({ "ok": true, "stamped": stamped }) -} - -/// Atomic list+ack: returns the same entries `handle_list_undelivered` would -/// surface (subject to the same FIFO+cap rules) and stamps each one with -/// `delivered_in_turn_id` before returning. Eliminates the list→LLM→ack -/// race window: if the caller crashes after receiving the response, the -/// entries are still considered delivered and will not resurface, which is -/// acceptable because terminal records are informational (the side-effect -/// already executed inside the gate). -/// -/// Required payload: `{ session_id, turn_id, limit? }`. -pub async fn handle_consume_undelivered( - bus: &dyn StateBus, - state_scope: &str, - payload: Value, - now_ms: u64, -) -> Value { - let turn_id = payload.get("turn_id").and_then(Value::as_str).unwrap_or(""); - if turn_id.is_empty() { - return json!({ "ok": false, "error": "missing_turn_id", "entries": [], "omitted": 0 }); - } - let listed = handle_list_undelivered(bus, state_scope, payload.clone(), now_ms).await; - let session_id = payload - .get("session_id") - .and_then(Value::as_str) - .unwrap_or(""); - let entries = listed["entries"].as_array().cloned().unwrap_or_default(); - let omitted = listed["omitted"].as_u64().unwrap_or(0); - for rec in &entries { - let cid = rec - .get("function_call_id") - .and_then(Value::as_str) - .unwrap_or(""); - if cid.is_empty() { - continue; - } - let key = pending_key(session_id, cid); - let mut stamped = rec.clone(); - stamped.as_object_mut().unwrap().insert( - "delivered_in_turn_id".into(), - Value::String(turn_id.to_string()), - ); - let _ = bus.set(state_scope, &key, stamped).await; - } - json!({ "ok": true, "entries": entries, "omitted": omitted }) -} - -/// One-shot drain: stamp every terminal-status record in `session_id` that -/// lacks `delivered_in_turn_id`. Intended for operator recovery after a -/// large backlog accumulates (e.g. when the orchestrator was offline or -/// `consume_undelivered` was unreachable). Pending records are untouched — -/// use `sweep_session` if you want to expire them first. -pub async fn handle_flush_delivered(bus: &dyn StateBus, state_scope: &str, payload: Value) -> Value { - let session_id = payload - .get("session_id") - .and_then(Value::as_str) - .unwrap_or(""); - let turn_id = payload.get("turn_id").and_then(Value::as_str).unwrap_or(""); - if session_id.is_empty() || turn_id.is_empty() { - return json!({ "ok": false, "error": "missing_session_or_turn_id", "stamped": 0 }); - } - let prefix = format!("{session_id}/"); - let all = bus.list_prefix(state_scope, &prefix).await; - let mut stamped = 0_u64; - for rec in all { - let status = rec.get("status").and_then(Value::as_str).unwrap_or(""); - if !is_terminal_status(status) { - continue; - } - if rec - .get("delivered_in_turn_id") - .is_some_and(|v| !v.is_null()) - { - continue; - } - let cid = rec - .get("function_call_id") - .and_then(Value::as_str) - .map(str::to_string) - .unwrap_or_default(); - if cid.is_empty() { - continue; - } - let mut next = rec; - next.as_object_mut().unwrap().insert( - "delivered_in_turn_id".into(), - Value::String(turn_id.to_string()), - ); - if bus - .set(state_scope, &pending_key(session_id, &cid), next) - .await - .is_ok() - { - stamped += 1; - } - } - json!({ "ok": true, "stamped": stamped }) -} - -/// Sweep all still-pending approvals for a session to timed_out. -/// -/// Reason defaults to `"session_deleted"` (legacy callers) but can be -/// overridden via the `reason` payload field — `run::stop` passes -/// `"run_stopped"` so consumers can distinguish a manual abort from a -/// session delete. -pub async fn handle_sweep_session(bus: &dyn StateBus, state_scope: &str, payload: Value) -> Value { - let session_id = payload - .get("session_id") - .and_then(Value::as_str) - .unwrap_or(""); - if session_id.is_empty() { - return json!({ "ok": false, "error": "missing_session_id", "swept": 0 }); - } - // The optional `reason` payload field used to be persisted as - // `decision_reason` on the resulting record. With Denial now the only - // structured reason channel and timed_out carrying no denial, the - // sweep_session reason is informational for the caller only — we log - // it but do not stamp it on the record. - if let Some(r) = payload.get("reason").and_then(Value::as_str) { - tracing::info!(session_id, reason = r, "approval-gate: sweep_session"); - } - let prefix = format!("{session_id}/"); - let all = bus.list_prefix(state_scope, &prefix).await; - let mut swept = 0_u64; - for rec in all { - if rec.get("status").and_then(Value::as_str) != Some("pending") { - continue; - } - let call_id = rec - .get("function_call_id") - .and_then(Value::as_str) - .unwrap_or(""); - if call_id.is_empty() { - continue; - } - let flipped = transition_record(&rec, "timed_out", None, None, None); - if bus - .set(state_scope, &pending_key(session_id, call_id), flipped) - .await - .is_ok() - { - swept += 1; - } + None => Verdict::Ask, } - json!({ "ok": true, "swept": swept }) } -fn uuid_like() -> String { - // Lightweight unique-ish id without pulling uuid in: ns timestamp + counter. - use std::sync::atomic::{AtomicU64, Ordering}; - static C: AtomicU64 = AtomicU64::new(0); - let n = C.fetch_add(1, Ordering::Relaxed); - let t = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_nanos()) - .unwrap_or(0); - format!("{t:x}-{n:x}") -} - -async fn write_event(iii: &III, session_id: &str, event: &Value) { - let _ = iii - .trigger(TriggerRequest { - function_id: "stream::set".into(), - payload: json!({ - "stream_name": "agent::events", - "group_id": session_id, - "item_id": format!("approval-{}", uuid_like()), - "data": event, - }), - action: None, - timeout_ms: None, - }) - .await; -} - -/// Build the `approval_resolved` event a sweeper emits when it auto-flips an -/// expired pending record. Pure — caller pumps the result onto the stream. -fn timeout_resolved_event(function_call_id: &str) -> Value { - json!({ - "type": "approval_resolved", - "function_call_id": function_call_id, - "tool_call_id": function_call_id, - "decision": "deny", - "status": "timed_out", - "decision_reason": "timeout", - }) -} - -/// Spawn the periodic timeout sweeper. The task ticks every `interval_ms`, -/// scans the configured state scope, and for any pending record whose -/// `expires_at` is in the past: writes the flipped record back and emits an -/// `approval_resolved` (status=timed_out) frame on `agent::events/`. -/// -/// The previous design relied on lazy timeout flips during -/// `handle_resolve`/`handle_list_undelivered`. Operators who never opened the -/// UI for a session would leave its pending rows in `pending` forever and -/// the paused turn-orchestrator would never see a decision. Active sweeping -/// closes that hole. -fn spawn_timeout_sweeper( - iii: III, - bus: Arc, - state_scope: String, - interval_ms: u64, -) -> tokio::task::JoinHandle<()> { - tokio::spawn(async move { - let mut ticker = - tokio::time::interval(std::time::Duration::from_millis(interval_ms.max(50))); - ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - // Drop the immediate first tick so we don't sweep before any - // pending row could possibly exist. - ticker.tick().await; - loop { - ticker.tick().await; - let now_ms = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_millis() as u64) - .unwrap_or(0); - let all = bus.list_prefix(&state_scope, "").await; - for (key, flipped, session_id, call_id) in collect_timed_out_for_sweep(&all, now_ms) { - if let Err(err) = bus.set(&state_scope, &key, flipped).await { - tracing::warn!( - "approval-gate sweeper: failed to flip {key} → timed_out: {err}" - ); - continue; - } - write_event(&iii, &session_id, &timeout_resolved_event(&call_id)).await; - } - } - }) -} - -async fn write_hook_reply(iii: &III, stream_name: &str, event_id: &str, reply: &Value) { - if stream_name.is_empty() || event_id.is_empty() { - return; - } - let _ = iii - .trigger(TriggerRequest { - function_id: "stream::set".into(), - payload: json!({ - "stream_name": stream_name, - "group_id": event_id, - "item_id": uuid_like(), - "data": reply, - }), - action: None, - timeout_ms: None, - }) - .await; -} - -/// Production [`StateBus`] backed by a real iii-sdk [`III`] connection. -pub struct IiiStateBus(pub III); - -#[async_trait::async_trait] -impl StateBus for IiiStateBus { - async fn set(&self, scope: &str, key: &str, value: Value) -> Result<(), iii_sdk::IIIError> { - self.0 - .trigger(TriggerRequest { - function_id: "state::set".into(), - payload: json!({ "scope": scope, "key": key, "value": value }), - action: None, - timeout_ms: None, - }) - .await - .map(|_| ()) - } - async fn get(&self, scope: &str, key: &str) -> Option { - self.0 - .trigger(TriggerRequest { - function_id: "state::get".into(), - payload: json!({ "scope": scope, "key": key }), - action: None, - timeout_ms: None, - }) - .await - .ok() - .filter(|v| !v.is_null()) - } - async fn list_prefix(&self, scope: &str, prefix: &str) -> Vec { - let resp = self - .0 - .trigger(TriggerRequest { - function_id: "state::list".into(), - payload: json!({ "scope": scope, "prefix": prefix }), - action: None, - timeout_ms: None, - }) - .await - .unwrap_or_else(|_| json!({ "items": [] })); - // Engine may return either {"items": [...]} or a plain Array. - if let Some(arr) = resp.as_array() { - return arr.clone(); - } - resp.get("items") - .and_then(|v| v.as_array().cloned()) - .unwrap_or_default() - .into_iter() - .map(|entry| entry.get("value").cloned().unwrap_or(entry)) - .collect() - } -} - -/// Return the list of function ids whose interceptor asks the gate to -/// inject `__from_approval` without asserting that the target validates it. -/// Empty list ⇒ config is safe to register. Pure — exposed for tests and -/// for the boot-time check in [`register`]. -pub fn unverified_marker_targets(rules: &[InterceptorRule]) -> Vec<&str> { - rules - .iter() - .filter(|r| r.inject_approval_marker && !r.marker_target_verified) - .map(|r| r.function_id.as_str()) - .collect() -} - -pub fn register(iii: &III, cfg: &WorkerConfig) -> anyhow::Result { - let rules: Arc> = Arc::new(cfg.interceptors.clone()); - - // Fail fast on honor-system markers: any interceptor that asks the gate - // to inject `__from_approval` MUST also assert the target validates it. - // Without that assertion the marker is purely decorative and the gate - // has no way to know whether bypass-through-direct-trigger is contained. - let unverified = unverified_marker_targets(rules.as_slice()); - if !unverified.is_empty() { - return Err(anyhow::anyhow!( - "approval-gate: refusing to start — interceptors with inject_approval_marker=true \ - must also set marker_target_verified=true (target is asserted to validate \ - __from_approval against approval::lookup_record). Unverified: {unverified:?}" - )); - } - - for rule in rules.iter() { - if let Some(cid) = rule.classifier.as_deref() { - if cid == FN_LOOKUP_RECORD - || cid == FN_RESOLVE - || cid == FN_LIST_PENDING - || cid == FN_LIST_UNDELIVERED - || cid == FN_ACK_DELIVERED - || cid == FN_SWEEP_SESSION - { - tracing::warn!( - "approval-gate: interceptor for {:?} uses classifier {:?} which aliases an approval endpoint; fix config", - rule.function_id, - cid - ); - } - } - } - - let bus: Arc = Arc::new(IiiStateBus(iii.clone())); - let timeout_ms = cfg.default_timeout_ms; - let topic = cfg.topic.clone(); - let state_scope = cfg.approval_state_scope.clone(); - - let bus_for_resolve = bus.clone(); - let scope_resolve = state_scope.clone(); - let exec_for_resolve: Arc = Arc::new(IiiFunctionExecutor { - iii: iii.clone(), - rules: rules.clone(), - }); - let iii_for_resolve = iii.clone(); - let resolve = iii.register_function(( - RegisterFunctionMessage::with_id(FN_RESOLVE.into()).with_description( - "Resolve a pending approval. On allow, invokes the underlying function; \ - on deny, records the denial. The result is stitched into the agent's \ - next turn as a system message." - .into(), - ), - move |payload: Value| { - let bus = bus_for_resolve.clone(); - let exec = exec_for_resolve.clone(); - let scope_resolve = scope_resolve.clone(); - let iii = iii_for_resolve.clone(); - async move { - let now_ms = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_millis() as u64) - .unwrap_or(0); - let resp = handle_resolve( - bus.as_ref(), - exec.as_ref(), - &scope_resolve, - payload.clone(), - now_ms, - ) - .await; - - if resp.get("ok").and_then(Value::as_bool) == Some(true) { - let session_id = payload - .get("session_id") - .and_then(Value::as_str) - .unwrap_or(""); - let call_id = payload - .get("function_call_id") - .or_else(|| payload.get("tool_call_id")) - .and_then(Value::as_str) - .unwrap_or(""); - if !session_id.is_empty() && !call_id.is_empty() { - let key = pending_key(session_id, call_id); - if let Some(final_rec) = bus.get(&scope_resolve, &key).await { - let mut evt = json!({ - "type": "approval_resolved", - "function_call_id": call_id, - "tool_call_id": call_id, - }); - if let Some(status) = final_rec.get("status").and_then(Value::as_str) { - evt["decision"] = match status { - "executed" | "approved" => json!("allow"), - _ => json!("deny"), - }; - evt["status"] = json!(status); - } - if let Some(r) = final_rec.get("result") { - evt["result"] = json!(r); - } - if let Some(e) = final_rec.get("error") { - evt["error"] = json!(e); - } - if let Some(denial) = final_rec.get("denial") { - evt["denial"] = denial.clone(); - } - write_event(&iii, session_id, &evt).await; - } - } - } - Ok::<_, IIIError>(resp) - } - }, - )); - - let bus_for_list = bus.clone(); - let scope_list = state_scope.clone(); - let list_pending = iii.register_function(( - RegisterFunctionMessage::with_id(FN_LIST_PENDING.into()) - .with_description("Return pending approvals for a session.".into()), - move |payload: Value| { - let bus = bus_for_list.clone(); - let scope_list = scope_list.clone(); - async move { - Ok::<_, IIIError>(handle_list_pending(bus.as_ref(), &scope_list, payload).await) - } - }, - )); - - let bus_for_list_undelivered = bus.clone(); - let scope_list_undelivered = state_scope.clone(); - let list_undelivered = iii.register_function(( - RegisterFunctionMessage::with_id(FN_LIST_UNDELIVERED.into()).with_description( - "Return resolved approval records for a session that haven't yet been stitched \ - into an LLM turn. Lazy-flips expired pendings to timed_out." - .into(), - ), - move |payload: Value| { - let bus = bus_for_list_undelivered.clone(); - let scope = scope_list_undelivered.clone(); - async move { - let now_ms = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_millis() as u64) - .unwrap_or(0); - Ok::<_, IIIError>( - handle_list_undelivered(bus.as_ref(), &scope, payload, now_ms).await, - ) - } - }, - )); - - let bus_for_consume = bus.clone(); - let scope_consume = state_scope.clone(); - let consume_undelivered = iii.register_function(( - RegisterFunctionMessage::with_id(FN_CONSUME_UNDELIVERED.into()).with_description( - "Atomic list+ack of resolved approval records. Returns the same FIFO-capped \ - slice as list_undelivered AND stamps each entry with delivered_in_turn_id \ - before returning. Required payload: {session_id, turn_id, limit?}." - .into(), - ), - move |payload: Value| { - let bus = bus_for_consume.clone(); - let scope = scope_consume.clone(); - async move { - let now_ms = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_millis() as u64) - .unwrap_or(0); - Ok::<_, IIIError>( - handle_consume_undelivered(bus.as_ref(), &scope, payload, now_ms).await, - ) - } - }, - )); - - let bus_for_ack = bus.clone(); - let scope_ack = state_scope.clone(); - let ack_delivered = - iii.register_function(( - RegisterFunctionMessage::with_id(FN_ACK_DELIVERED.into()).with_description( - "Stamp delivered_in_turn_id on resolved approvals so they aren't replayed \ - in subsequent turns. Idempotent." - .into(), - ), - move |payload: Value| { - let bus = bus_for_ack.clone(); - let scope = scope_ack.clone(); - async move { - Ok::<_, IIIError>(handle_ack_delivered(bus.as_ref(), &scope, payload).await) - } - }, - )); - - let bus_for_flush = bus.clone(); - let scope_flush = state_scope.clone(); - let flush_delivered = iii.register_function(( - RegisterFunctionMessage::with_id(FN_FLUSH_DELIVERED.into()).with_description( - "Stamp every unacked terminal approval record in a session as \ - delivered. One-shot operator recovery for backlog accumulation. \ - Required payload: {session_id, turn_id}." - .into(), - ), - move |payload: Value| { - let bus = bus_for_flush.clone(); - let scope = scope_flush.clone(); - async move { - Ok::<_, IIIError>(handle_flush_delivered(bus.as_ref(), &scope, payload).await) - } - }, - )); - - let bus_for_sweep = bus.clone(); - let scope_sweep = state_scope.clone(); - let sweep_session = - iii.register_function(( - RegisterFunctionMessage::with_id(FN_SWEEP_SESSION.into()).with_description( - "Sweep all pending approvals for a session to timed_out. \ - Called when a session is deleted." - .into(), - ), - move |payload: Value| { - let bus = bus_for_sweep.clone(); - let scope = scope_sweep.clone(); - async move { - Ok::<_, IIIError>(handle_sweep_session(bus.as_ref(), &scope, payload).await) - } - }, - )); - - let bus_for_lookup = bus.clone(); - let scope_lookup = state_scope.clone(); - let lookup_record = - iii.register_function(( - RegisterFunctionMessage::with_id(FN_LOOKUP_RECORD.into()).with_description( - "Return the approval state-store record for a session/function_call_id pair; \ - null when absent. Used by shell bypass validation." - .into(), - ), - move |payload: Value| { - let bus = bus_for_lookup.clone(); - let scope = scope_lookup.clone(); - async move { - Ok::<_, IIIError>(handle_lookup_record(bus.as_ref(), &scope, payload).await) - } - }, - )); - - let iii_for_sub = iii.clone(); - let bus_for_sub = bus.clone(); - let subscriber_scope = state_scope.clone(); - let rules_for_sub = rules.clone(); - let subscriber_fn = iii.register_function(( - RegisterFunctionMessage::with_id("policy::approval_gate".into()) - .with_description("Pause function calls listed in approval_required.".into()), - move |envelope: Value| { - let iii = iii_for_sub.clone(); - let bus = bus_for_sub.clone(); - let sc = subscriber_scope.clone(); - let intercept_rules = rules_for_sub.clone(); - async move { - let Some(call) = extract_call(&envelope) else { - return Ok::<_, IIIError>(json!({ "block": false })); - }; - let now_ms = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_millis() as u64) - .unwrap_or(0); - - let action = decide_intercept_action( - rule_for(intercept_rules.as_slice(), &call.function_id), - call.requires_approval(), - ); - let reply = match action { - InterceptAction::Pass => json!({ "block": false }), - InterceptAction::Pause => { - handle_intercept(bus.as_ref(), &sc, &call, now_ms, timeout_ms, false).await - } - InterceptAction::Classify { - classifier_fn, - classifier_timeout_ms, - } => match iii - .trigger(TriggerRequest { - function_id: classifier_fn.clone(), - payload: call.args.clone(), - action: None, - timeout_ms: Some(classifier_timeout_ms), - }) - .await - { - Ok(v) => match interpret_classifier_reply(&v, &classifier_fn) { - Ok(ClassifierDecision::Auto) => json!({ "block": false }), - Ok(ClassifierDecision::Deny(denial)) => json!({ - "block": true, - "denial": denial, - "status": "denied", - "call_id": call.function_call_id, - "function_id": call.function_id, - }), - Ok(ClassifierDecision::Ask) | Err(()) => { - handle_intercept( - bus.as_ref(), - &sc, - &call, - now_ms, - timeout_ms, - true, - ) - .await - } - }, - Err(_) => { - handle_intercept(bus.as_ref(), &sc, &call, now_ms, timeout_ms, true) - .await - } - }, - }; - - if reply.get("status").and_then(Value::as_str) == Some("pending") { - write_event( - &iii, - &call.session_id, - &json!({ - "type": "approval_requested", - "function_call_id": call.function_call_id, - "tool_call_id": call.function_call_id, - "function_id": call.function_id, - "tool_name": call.function_id, - "args": call.args, - "expires_at": now_ms.saturating_add(timeout_ms), - }), - ) - .await; - } - write_hook_reply(&iii, &call.reply_stream, &call.event_id, &reply).await; - Ok(reply) - } - }, - )); - - let subscriber_trigger = iii - .register_trigger(RegisterTriggerInput { - trigger_type: "durable:subscriber".into(), - function_id: "policy::approval_gate".into(), - config: json!({ "topic": topic }), - metadata: None, - }) - .map_err(|e| anyhow::anyhow!(e.to_string()))?; - - let sweeper = spawn_timeout_sweeper( - iii.clone(), - bus.clone(), - state_scope.clone(), - cfg.sweeper_interval_ms, - ); - - Ok(Refs { - resolve, - list_pending, - list_undelivered, - consume_undelivered, - ack_delivered, - flush_delivered, - sweep_session, - lookup_record, - subscriber_fn, - subscriber_trigger, - sweeper, - }) -} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn maybe_flip_timed_out_returns_some_when_pending_and_expired() { - let rec = build_pending_record("tc-1", "shell::fs::write", &json!({}), 1_000, 60_000); - let flipped = maybe_flip_timed_out(&rec, 70_000).expect("should flip"); - assert_eq!(flipped["status"], "timed_out"); - // Timeout carries no Denial — the status alone explains the outcome. - assert!(flipped.get("denial").is_none()); - assert!(flipped.get("decision_reason").is_none()); - } - - #[test] - fn maybe_flip_timed_out_returns_none_when_pending_and_not_expired() { - let rec = build_pending_record("tc-1", "shell::fs::write", &json!({}), 1_000, 60_000); - assert!(maybe_flip_timed_out(&rec, 60_000).is_none()); - assert!(maybe_flip_timed_out(&rec, 1_500).is_none()); - } - - #[test] - fn maybe_flip_timed_out_returns_none_when_not_pending() { - let rec = json!({ - "function_call_id": "tc-1", - "status": "executed", - "expires_at": 1_000_u64, - }); - assert!(maybe_flip_timed_out(&rec, 999_999_999).is_none()); - } - - #[test] - fn transition_record_stamps_resolved_at_for_terminal_status() { - let base = build_pending_record("c1", "shell::fs::write", &json!({}), 1_000, 60_000); - let rec = transition_record_with_now( - &base, - "executed", - Some(json!({"ok": true})), - None, - None, - 12_345, - ); - assert_eq!(rec["resolved_at"].as_u64(), Some(12_345)); - } - - #[test] - fn transition_record_preserves_existing_resolved_at_on_relift() { - let base = build_pending_record("c1", "shell::fs::write", &json!({}), 1_000, 60_000); - let first = transition_record_with_now( - &base, - "executed", - Some(json!({"ok": true})), - None, - None, - 12_345, - ); - let second = transition_record_with_now( - &first, - "executed", - Some(json!({"ok": true})), - None, - None, - 99_999, - ); - assert_eq!(second["resolved_at"].as_u64(), Some(12_345)); - } - - #[test] - fn transition_record_does_not_stamp_resolved_at_for_intermediate_status() { - let base = build_pending_record("c1", "shell::fs::write", &json!({}), 1_000, 60_000); - let rec = - transition_record_with_now(&base, "approved", None, None, None, 12_345); - assert!(rec.get("resolved_at").is_none()); - } - - #[tokio::test] - async fn handle_list_undelivered_caps_at_default_limit_and_reports_omitted() { - let bus = InMemoryStateBus::new(); - for i in 0..75 { - let cid = format!("c{i}"); - let mut rec = transition_record_with_now( - &build_pending_record(&cid, "shell::fs::write", &json!({}), 1_000, 60_000), - "executed", - Some(json!({"ok": true})), - None, - None, - 1_000 + i as u64, - ); - rec.as_object_mut() - .unwrap() - .insert("session_id".into(), Value::String("s1".into())); - bus.set(STATE_SCOPE, &pending_key("s1", &cid), rec) - .await - .unwrap(); - } - let resp = - handle_list_undelivered(&bus, STATE_SCOPE, json!({"session_id": "s1"}), 100_000).await; - assert_eq!(resp["entries"].as_array().unwrap().len(), 50); - assert_eq!(resp["omitted"].as_u64(), Some(25)); - } - - #[tokio::test] - async fn handle_list_undelivered_honors_explicit_limit() { - let bus = InMemoryStateBus::new(); - for i in 0..10 { - let cid = format!("c{i}"); - let mut rec = transition_record_with_now( - &build_pending_record(&cid, "shell::fs::write", &json!({}), 1_000, 60_000), - "executed", - Some(json!({"ok": true})), - None, - None, - 1_000 + i as u64, - ); - rec.as_object_mut() - .unwrap() - .insert("session_id".into(), Value::String("s1".into())); - bus.set(STATE_SCOPE, &pending_key("s1", &cid), rec) - .await - .unwrap(); - } - let resp = handle_list_undelivered( - &bus, - STATE_SCOPE, - json!({"session_id": "s1", "limit": 3}), - 100_000, - ) - .await; - assert_eq!(resp["entries"].as_array().unwrap().len(), 3); - assert_eq!(resp["omitted"].as_u64(), Some(7)); - } - - #[tokio::test] - async fn handle_list_undelivered_returns_oldest_first_by_resolved_at() { - let bus = InMemoryStateBus::new(); - for (i, ts) in [(0_u32, 5_000_u64), (1, 1_000), (2, 3_000)] { - let cid = format!("c{i}"); - let mut rec = transition_record_with_now( - &build_pending_record(&cid, "shell::fs::write", &json!({}), 1_000, 60_000), - "executed", - Some(json!({"ok": true})), - None, - None, - ts, - ); - rec.as_object_mut() - .unwrap() - .insert("session_id".into(), Value::String("s1".into())); - bus.set(STATE_SCOPE, &pending_key("s1", &cid), rec) - .await - .unwrap(); - } - let resp = handle_list_undelivered( - &bus, - STATE_SCOPE, - json!({"session_id": "s1", "limit": 10}), - 100_000, - ) - .await; - let entries = resp["entries"].as_array().unwrap(); - let ids: Vec<&str> = entries - .iter() - .map(|e| e["function_call_id"].as_str().unwrap()) - .collect(); - assert_eq!(ids, vec!["c1", "c2", "c0"]); - } - - #[tokio::test] - async fn handle_list_undelivered_omitted_is_zero_when_under_limit() { - let bus = InMemoryStateBus::new(); - let mut rec = transition_record_with_now( - &build_pending_record("c1", "shell::fs::write", &json!({}), 1_000, 60_000), - "executed", - Some(json!({"ok": true})), - None, - None, - 1_500, - ); - rec.as_object_mut() - .unwrap() - .insert("session_id".into(), Value::String("s1".into())); - bus.set(STATE_SCOPE, &pending_key("s1", "c1"), rec) - .await - .unwrap(); - let resp = - handle_list_undelivered(&bus, STATE_SCOPE, json!({"session_id": "s1"}), 100_000).await; - assert_eq!(resp["entries"].as_array().unwrap().len(), 1); - assert_eq!(resp["omitted"].as_u64(), Some(0)); - } - - #[tokio::test] - async fn handle_consume_undelivered_stamps_returned_entries() { - let bus = InMemoryStateBus::new(); - for i in 0..3 { - let cid = format!("c{i}"); - let mut rec = transition_record_with_now( - &build_pending_record(&cid, "shell::fs::write", &json!({}), 1_000, 60_000), - "executed", - Some(json!({"ok": true})), - None, - None, - 1_000 + i as u64, - ); - rec.as_object_mut() - .unwrap() - .insert("session_id".into(), Value::String("s1".into())); - bus.set(STATE_SCOPE, &pending_key("s1", &cid), rec) - .await - .unwrap(); - } - let resp = handle_consume_undelivered( - &bus, - STATE_SCOPE, - json!({"session_id": "s1", "turn_id": "turn-7", "limit": 10}), - 100_000, - ) - .await; - assert_eq!(resp["ok"], json!(true)); - assert_eq!(resp["entries"].as_array().unwrap().len(), 3); - assert_eq!(resp["omitted"].as_u64(), Some(0)); - let next = - handle_list_undelivered(&bus, STATE_SCOPE, json!({"session_id": "s1"}), 100_000).await; - assert_eq!(next["entries"].as_array().unwrap().len(), 0); - } - - #[tokio::test] - async fn handle_consume_undelivered_respects_limit_and_leaves_remainder() { - let bus = InMemoryStateBus::new(); - for i in 0..5 { - let cid = format!("c{i}"); - let mut rec = transition_record_with_now( - &build_pending_record(&cid, "shell::fs::write", &json!({}), 1_000, 60_000), - "executed", - Some(json!({"ok": true})), - None, - None, - 1_000 + i as u64, - ); - rec.as_object_mut() - .unwrap() - .insert("session_id".into(), Value::String("s1".into())); - bus.set(STATE_SCOPE, &pending_key("s1", &cid), rec) - .await - .unwrap(); - } - let resp = handle_consume_undelivered( - &bus, - STATE_SCOPE, - json!({"session_id": "s1", "turn_id": "turn-7", "limit": 2}), - 100_000, - ) - .await; - assert_eq!(resp["entries"].as_array().unwrap().len(), 2); - assert_eq!(resp["omitted"].as_u64(), Some(3)); - let next = - handle_list_undelivered(&bus, STATE_SCOPE, json!({"session_id": "s1"}), 100_000).await; - assert_eq!(next["entries"].as_array().unwrap().len(), 3); - } - - #[tokio::test] - async fn handle_consume_undelivered_missing_turn_id_returns_error() { - let bus = InMemoryStateBus::new(); - let resp = handle_consume_undelivered( - &bus, - STATE_SCOPE, - json!({"session_id": "s1"}), - 100_000, - ) - .await; - assert_eq!(resp["ok"], json!(false)); - assert_eq!(resp["error"], json!("missing_turn_id")); - } - - #[tokio::test] - async fn handle_flush_delivered_stamps_all_unacked_terminals() { - let bus = InMemoryStateBus::new(); - for i in 0..5 { - let cid = format!("c{i}"); - let mut rec = transition_record_with_now( - &build_pending_record(&cid, "shell::fs::write", &json!({}), 1_000, 60_000), - "executed", - Some(json!({"ok": true})), - None, - None, - 1_000 + i as u64, - ); - rec.as_object_mut() - .unwrap() - .insert("session_id".into(), Value::String("s1".into())); - bus.set(STATE_SCOPE, &pending_key("s1", &cid), rec) - .await - .unwrap(); - } - let resp = handle_flush_delivered( - &bus, - STATE_SCOPE, - json!({"session_id": "s1", "turn_id": "manual-flush"}), - ) - .await; - assert_eq!(resp["ok"], json!(true)); - assert_eq!(resp["stamped"].as_u64(), Some(5)); - let next = - handle_list_undelivered(&bus, STATE_SCOPE, json!({"session_id": "s1"}), 100_000).await; - assert_eq!(next["entries"].as_array().unwrap().len(), 0); - } - - #[tokio::test] - async fn handle_flush_delivered_skips_pending_records() { - let bus = InMemoryStateBus::new(); - bus.set( - STATE_SCOPE, - &pending_key("s1", "c1"), - build_pending_record("c1", "shell::fs::write", &json!({}), 1_000, 60_000), - ) - .await - .unwrap(); - let resp = handle_flush_delivered( - &bus, - STATE_SCOPE, - json!({"session_id": "s1", "turn_id": "manual-flush"}), - ) - .await; - assert_eq!(resp["stamped"].as_u64(), Some(0)); - let still = bus - .get(STATE_SCOPE, &pending_key("s1", "c1")) - .await - .unwrap(); - assert_eq!(still["status"].as_str(), Some("pending")); - assert!(still.get("delivered_in_turn_id").is_none()); - } - - #[tokio::test] - async fn handle_flush_delivered_idempotent_on_already_stamped() { - let bus = InMemoryStateBus::new(); - let mut rec = transition_record_with_now( - &build_pending_record("c1", "shell::fs::write", &json!({}), 1_000, 60_000), - "executed", - Some(json!({"ok": true})), - None, - None, - 1_500, - ); - { - let obj = rec.as_object_mut().unwrap(); - obj.insert( - "delivered_in_turn_id".into(), - Value::String("turn-prev".into()), - ); - obj.insert("session_id".into(), Value::String("s1".into())); - } - bus.set(STATE_SCOPE, &pending_key("s1", "c1"), rec) - .await - .unwrap(); - let resp = handle_flush_delivered( - &bus, - STATE_SCOPE, - json!({"session_id": "s1", "turn_id": "manual-flush"}), - ) - .await; - assert_eq!(resp["stamped"].as_u64(), Some(0)); - let still = bus - .get(STATE_SCOPE, &pending_key("s1", "c1")) - .await - .unwrap(); - assert_eq!(still["delivered_in_turn_id"].as_str(), Some("turn-prev")); - } - - #[tokio::test] - async fn handle_list_undelivered_returns_terminal_records_with_no_delivered_stamp() { - let bus = InMemoryStateBus::new(); - let mut r1 = transition_record( - &build_pending_record("c1", "shell::fs::write", &json!({}), 1_000, 60_000), - "executed", - Some(json!({"ok": true})), - None, - None, - ); - r1.as_object_mut() - .unwrap() - .insert("session_id".into(), Value::String("s1".into())); - bus.set(STATE_SCOPE, &pending_key("s1", "c1"), r1) - .await - .unwrap(); - let mut r2 = transition_record( - &build_pending_record("c2", "shell::fs::write", &json!({}), 1_000, 60_000), - "denied", - None, - None, - Some(Denial::UserCorrected { - feedback: "nope".into(), - }), - ); - r2.as_object_mut() - .unwrap() - .insert("session_id".into(), Value::String("s1".into())); - bus.set(STATE_SCOPE, &pending_key("s1", "c2"), r2) - .await - .unwrap(); - - let resp = - handle_list_undelivered(&bus, STATE_SCOPE, json!({"session_id": "s1"}), 100_000).await; - let entries = resp["entries"].as_array().unwrap(); - assert_eq!(entries.len(), 2); - assert_eq!(resp["omitted"].as_u64(), Some(0)); - } - - #[tokio::test] - async fn handle_list_undelivered_excludes_pending_records() { - let bus = InMemoryStateBus::new(); - bus.set( - STATE_SCOPE, - &pending_key("s1", "c1"), - build_pending_record("c1", "shell::fs::write", &json!({}), 1_000, 60_000), - ) - .await - .unwrap(); - - let resp = - handle_list_undelivered(&bus, STATE_SCOPE, json!({"session_id": "s1"}), 1_500).await; - assert_eq!(resp["entries"].as_array().unwrap().len(), 0); - } - - #[tokio::test] - async fn handle_list_undelivered_empty_session_returns_empty() { - let bus = InMemoryStateBus::new(); - let resp = - handle_list_undelivered(&bus, STATE_SCOPE, json!({"session_id": "s1"}), 1_500).await; - assert_eq!(resp["entries"], json!([])); - } - - #[tokio::test] - async fn handle_list_undelivered_excludes_records_stamped_with_delivered_turn_id() { - let bus = InMemoryStateBus::new(); - let mut rec = transition_record( - &build_pending_record("c1", "shell::fs::write", &json!({}), 1_000, 60_000), - "executed", - Some(json!({"ok": true})), - None, - None, - ); - { - let obj = rec.as_object_mut().unwrap(); - obj.insert( - "delivered_in_turn_id".into(), - Value::String("turn-prev".into()), - ); - obj.insert("session_id".into(), Value::String("s1".into())); - } - bus.set(STATE_SCOPE, &pending_key("s1", "c1"), rec) - .await - .unwrap(); - - let mut r2 = transition_record( - &build_pending_record("c2", "shell::fs::write", &json!({}), 1_000, 60_000), - "executed", - Some(json!({"ok": true})), - None, - None, - ); - r2.as_object_mut() - .unwrap() - .insert("session_id".into(), Value::String("s1".into())); - bus.set(STATE_SCOPE, &pending_key("s1", "c2"), r2) - .await - .unwrap(); - - let resp = - handle_list_undelivered(&bus, STATE_SCOPE, json!({"session_id": "s1"}), 100_000).await; - let entries = resp["entries"].as_array().unwrap(); - assert_eq!(entries.len(), 1); - assert_eq!(entries[0]["function_call_id"], "c2"); - } - - #[tokio::test] - async fn handle_list_undelivered_returns_empty_when_session_id_missing() { - let bus = InMemoryStateBus::new(); - let resp = handle_list_undelivered(&bus, STATE_SCOPE, json!({}), 1_500).await; - assert_eq!(resp["entries"], json!([])); - } - - #[tokio::test] - async fn handle_ack_delivered_stamps_records_with_turn_id() { - let bus = InMemoryStateBus::new(); - bus.set( - STATE_SCOPE, - &pending_key("s1", "c1"), - transition_record( - &build_pending_record("c1", "shell::fs::write", &json!({}), 1_000, 60_000), - "executed", - Some(json!({"ok": true})), - None, - None, - ), - ) - .await - .unwrap(); - - let resp = handle_ack_delivered( - &bus, - STATE_SCOPE, - json!({ - "session_id": "s1", - "call_ids": ["c1"], - "turn_id": "turn-1", - }), - ) - .await; - assert_eq!(resp["ok"], json!(true)); - assert_eq!(resp["stamped"], json!(1)); - - let rec = bus - .get(STATE_SCOPE, &pending_key("s1", "c1")) - .await - .unwrap(); - assert_eq!(rec["delivered_in_turn_id"], "turn-1"); - } - - #[tokio::test] - async fn handle_ack_delivered_is_idempotent_keeps_first_turn_id() { - let bus = InMemoryStateBus::new(); - bus.set( - STATE_SCOPE, - &pending_key("s1", "c1"), - transition_record( - &build_pending_record("c1", "shell::fs::write", &json!({}), 1_000, 60_000), - "executed", - Some(json!({"ok": true})), - None, - None, - ), - ) - .await - .unwrap(); - - let _ = handle_ack_delivered( - &bus, - STATE_SCOPE, - json!({ - "session_id": "s1", "call_ids": ["c1"], "turn_id": "turn-first", - }), - ) - .await; - let resp = handle_ack_delivered( - &bus, - STATE_SCOPE, - json!({ - "session_id": "s1", "call_ids": ["c1"], "turn_id": "turn-second", - }), - ) - .await; - assert_eq!(resp["stamped"], json!(0), "second ack must not re-stamp"); - - let rec = bus - .get(STATE_SCOPE, &pending_key("s1", "c1")) - .await - .unwrap(); - assert_eq!(rec["delivered_in_turn_id"], "turn-first"); - } - - #[tokio::test] - async fn handle_ack_delivered_skips_unknown_call_ids_silently() { - let bus = InMemoryStateBus::new(); - let resp = handle_ack_delivered( - &bus, - STATE_SCOPE, - json!({ - "session_id": "s1", "call_ids": ["ghost"], "turn_id": "turn-1", - }), - ) - .await; - assert_eq!(resp["ok"], json!(true)); - assert_eq!(resp["stamped"], json!(0)); - } - - #[tokio::test] - async fn handle_resolve_on_expired_pending_flips_to_timed_out_and_ignores_decision() { - let bus = InMemoryStateBus::new(); - let exec = FakeExecutor::default(); - bus.set( - STATE_SCOPE, - &pending_key("s1", "tc-1"), - build_pending_record("tc-1", "shell::fs::write", &json!({}), 1_000, 60_000), - ) - .await - .unwrap(); - - let resp = handle_resolve( - &bus, - &exec, - STATE_SCOPE, - json!({"session_id":"s1","function_call_id":"tc-1","decision":"allow"}), - 70_000, - ) - .await; - assert_eq!(resp["ok"], json!(false)); - assert_eq!(resp["error"], "timed_out"); - - assert!(exec.calls.lock().unwrap().is_empty()); - - let rec = bus - .get(STATE_SCOPE, &pending_key("s1", "tc-1")) - .await - .unwrap(); - assert_eq!(rec["status"], "timed_out"); - } - - #[test] - fn migrate_legacy_record_maps_allow_to_executed_without_result() { - let legacy = json!({ - "function_call_id": "c1", - "function_id": "shell::fs::write", - "args": {}, - "status": "allow", - "expires_at": 1_000_u64, - }); - let migrated = migrate_legacy_record(&legacy).expect("migrates"); - assert_eq!(migrated["status"], "executed"); - assert!( - migrated["result"].is_null() - || migrated.get("result").is_none() - || migrated["result"] == json!(null) - ); - assert_eq!(migrated["legacy_migrated"], json!(true)); - } - - #[test] - fn migrate_legacy_record_maps_deny_to_denied_with_legacy_denial() { - let legacy = json!({ - "function_call_id": "c1", - "status": "deny", - "reason": "manual", - "expires_at": 1_000_u64, - }); - let migrated = migrate_legacy_record(&legacy).expect("migrates"); - assert_eq!(migrated["status"], "denied"); - assert_eq!(migrated["denial"]["kind"], "legacy"); - assert_eq!(migrated["denial"]["detail"]["reason"], "manual"); - assert_eq!(migrated["legacy_migrated"], json!(true)); - assert!( - migrated.get("decision_reason").is_none(), - "legacy decision_reason must be stripped: {migrated}" - ); - assert!( - migrated.get("reason").is_none(), - "pre-trigger-model `reason` must be stripped after lifting into denial: {migrated}" - ); - } - - #[test] - fn migrate_legacy_record_lifts_decision_reason_when_status_is_already_current() { - // Pre-Denial records carry the flat `decision_reason: ` — - // migration lifts it into Denial::Legacy and strips the old field. - let legacy = json!({ - "function_call_id": "c1", - "status": "denied", - "decision_reason": "user typed nope", - "expires_at": 1_000_u64, - }); - let migrated = migrate_legacy_record(&legacy).expect("should lift"); - assert_eq!(migrated["status"], "denied"); - assert_eq!(migrated["denial"]["kind"], "legacy"); - assert_eq!(migrated["denial"]["detail"]["reason"], "user typed nope"); - assert!(migrated.get("decision_reason").is_none()); - assert_eq!(migrated["legacy_migrated"], json!(true)); - } - - #[test] - fn migrate_legacy_record_returns_none_for_new_status_strings() { - for new_status in [ - "pending", - "executed", - "failed", - "denied", - "timed_out", - "approved", - ] { - let rec = json!({"status": new_status}); - assert!( - migrate_legacy_record(&rec).is_none(), - "should not migrate already-new status '{}'", - new_status - ); - } - } - - #[test] - fn fn_constants_match_spec_strings() { - assert_eq!(FN_RESOLVE, "approval::resolve"); - assert_eq!(FN_LIST_PENDING, "approval::list_pending"); - assert_eq!(FN_LIST_UNDELIVERED, "approval::list_undelivered"); - assert_eq!(FN_ACK_DELIVERED, "approval::ack_delivered"); - assert_eq!(FN_LOOKUP_RECORD, "approval::lookup_record"); - } - - #[test] - fn interpret_classifier_reply_reads_decision_tags() { - assert!(matches!( - interpret_classifier_reply(&json!({"decision": "auto"}), "shell::classify_argv"), - Ok(ClassifierDecision::Auto) - )); - match interpret_classifier_reply( - &json!({"decision":"deny","reason":"nope"}), - "shell::classify_argv", - ) { - Ok(ClassifierDecision::Deny(Denial::Policy { - classifier_reason, - classifier_fn, - })) => { - assert_eq!(classifier_reason, "nope"); - assert_eq!(classifier_fn, "shell::classify_argv"); - } - o => panic!("expected Policy denial {:?}", o), - } - assert!(matches!( - interpret_classifier_reply( - &json!({"decision":"ask","summary":"x"}), - "shell::classify_argv" - ), - Ok(ClassifierDecision::Ask) - )); - assert!(interpret_classifier_reply(&json!({}), "shell::classify_argv").is_err()); - } - - #[test] - fn merge_from_approval_inserts_marker_when_inject_true() { - let m = merge_from_approval_marker_if_needed( - true, - json!({"command": "git"}), - "call-1", - "sess-1", - ); - let inner = m.get("__from_approval").unwrap(); - assert_eq!(inner["call_id"], "call-1"); - assert_eq!(inner["session_id"], "sess-1"); - assert_eq!(m["command"], "git"); - } - - #[test] - fn merge_from_approval_noop_when_inject_false() { - let j = json!({"a": 1}); - let out = merge_from_approval_marker_if_needed(false, j.clone(), "c", "s"); - assert_eq!(out, j); - } - - #[test] - fn rule_for_returns_matching_rule() { - let rules = vec![ - InterceptorRule { - function_id: "shell::exec".into(), - classifier: Some("shell::classify_argv".into()), - classifier_timeout_ms: 2000, - inject_approval_marker: true, - marker_target_verified: true, - }, - InterceptorRule { - function_id: "other::fn".into(), - classifier: None, - classifier_timeout_ms: 2000, - inject_approval_marker: false, - marker_target_verified: false, - }, - ]; - let r = rule_for(&rules, "shell::exec").expect("match"); - assert_eq!(r.classifier.as_deref(), Some("shell::classify_argv")); - assert!(r.inject_approval_marker); - } - - #[test] - fn rule_for_returns_none_when_absent() { - let rules = vec![InterceptorRule { - function_id: "x::y".into(), - classifier: None, - classifier_timeout_ms: 2000, - inject_approval_marker: false, - marker_target_verified: false, - }]; - assert!(rule_for(&rules, "missing::id").is_none()); - } - - /// An operator-registered rule is authoritative: every call to that - /// function id runs through the classifier, even when the run's - /// `approval_required` list is empty. This is the inverted contract - /// vs. the original "approval_required ANDs the rule" gate. - #[test] - fn decide_intercept_action_classifies_when_rule_has_classifier_regardless_of_approval_required() { - let rule = InterceptorRule { - function_id: "shell::exec".into(), - classifier: Some("shell::classify_argv".into()), - classifier_timeout_ms: 2000, - inject_approval_marker: true, - marker_target_verified: true, - }; - let action = decide_intercept_action(Some(&rule), false); - assert_eq!( - action, - InterceptAction::Classify { - classifier_fn: "shell::classify_argv".into(), - classifier_timeout_ms: 2000, - } - ); - assert_eq!(action, decide_intercept_action(Some(&rule), true)); - } - - #[test] - fn decide_intercept_action_pauses_when_rule_has_no_classifier_regardless_of_approval_required() { - let rule = InterceptorRule { - function_id: "shell::fs::write".into(), - classifier: None, - classifier_timeout_ms: 2000, - inject_approval_marker: false, - marker_target_verified: false, - }; - assert_eq!( - decide_intercept_action(Some(&rule), false), - InterceptAction::Pause - ); - assert_eq!( - decide_intercept_action(Some(&rule), true), - InterceptAction::Pause - ); - } - - #[test] - fn decide_intercept_action_pauses_when_no_rule_but_run_listed_approval_required() { - assert_eq!(decide_intercept_action(None, true), InterceptAction::Pause); - } - - #[test] - fn decide_intercept_action_passes_when_no_rule_and_not_approval_required() { - assert_eq!(decide_intercept_action(None, false), InterceptAction::Pass); - } - - #[test] - fn decide_intercept_action_classifier_empty_string_treated_as_no_classifier() { - let rule = InterceptorRule { - function_id: "shell::exec".into(), - classifier: Some(String::new()), - classifier_timeout_ms: 2000, - inject_approval_marker: false, - marker_target_verified: false, - }; - assert_eq!( - decide_intercept_action(Some(&rule), false), - InterceptAction::Pause - ); - } - - #[test] - fn is_terminal_status_returns_true_for_terminal_states() { - assert!(is_terminal_status("executed")); - assert!(is_terminal_status("failed")); - assert!(is_terminal_status("denied")); - assert!(is_terminal_status("timed_out")); - } - - #[test] - fn is_terminal_status_returns_false_for_in_progress_states() { - assert!(!is_terminal_status("pending")); - assert!(!is_terminal_status("approved")); - assert!(!is_terminal_status("anything_else")); - assert!(!is_terminal_status("")); - } - - #[test] - fn pending_key_includes_session_and_tool_call_id() { - assert_eq!(pending_key("s1", "tc-1"), "s1/tc-1"); - } - - #[test] - fn extract_call_reads_session_id_and_function_call_from_envelope() { - let envelope = json!({ - "event_id": "evt-1", - "reply_stream": "rs-1", - "payload": { - "function_call": { "id": "tc-1", "function_id": "write", "arguments": {"path": "/tmp/x"} }, - "approval_required": ["write"], - "session_id": "s1", - } - }); - let call = extract_call(&envelope).expect("decoded"); - assert_eq!(call.session_id, "s1"); - assert_eq!(call.function_call_id, "tc-1"); - assert_eq!(call.function_id, "write"); - assert_eq!(call.event_id, "evt-1"); - assert_eq!(call.reply_stream, "rs-1"); - assert!(call.approval_required.iter().any(|s| s == "write")); - } - - #[test] - fn extract_call_accepts_legacy_tool_call_envelope_with_name() { - let envelope = json!({ - "event_id": "evt-1", - "reply_stream": "rs-1", - "payload": { - "tool_call": { "id": "tc-1", "name": "write", "arguments": {} }, - "approval_required": ["write"], - "session_id": "s1", - } - }); - let call = extract_call(&envelope).expect("decoded"); - assert_eq!(call.function_call_id, "tc-1"); - assert_eq!(call.function_id, "write"); - } - - #[test] - fn requires_approval_only_for_listed_functions() { - let call = IncomingCall { - session_id: "s1".into(), - function_call_id: "tc-1".into(), - function_id: "ls".into(), - args: json!({}), - approval_required: vec!["write".into()], - event_id: "e".into(), - reply_stream: "r".into(), - }; - assert!(!call.requires_approval()); - - let call2 = IncomingCall { - function_id: "write".into(), - ..call - }; - assert!(call2.requires_approval()); - } - - #[test] - fn build_pending_record_sets_status_and_expiry() { - let now = 1_000_000; - let rec = build_pending_record("tc-1", "write", &json!({"x": 1}), now, 60_000); - assert_eq!(rec["status"], "pending"); - assert_eq!(rec["function_call_id"], "tc-1"); - assert_eq!(rec["expires_at"], 1_060_000); - } - - #[test] - fn block_reply_for_decision_allow_does_not_block() { - let reply = block_reply_for(&Decision::Allow); - assert_eq!(reply["block"], false); - } - - #[test] - fn block_reply_for_deny_emits_structured_denial() { - let reply = block_reply_for(&Decision::Deny(Denial::UserRejected)); - assert_eq!(reply["block"], true); - assert_eq!(reply["denial"]["kind"], "user_rejected"); - assert!(reply.as_object().unwrap().get("reason").is_none()); - } - - #[test] - fn block_reply_for_policy_deny_carries_classifier_detail() { - let reply = block_reply_for(&Decision::Deny(Denial::Policy { - classifier_reason: "command matches denylist".into(), - classifier_fn: "shell::classify_argv".into(), - })); - assert_eq!(reply["block"], true); - assert_eq!(reply["denial"]["kind"], "policy"); - assert_eq!( - reply["denial"]["detail"]["classifier_reason"], - "command matches denylist" - ); - assert_eq!( - reply["denial"]["detail"]["classifier_fn"], - "shell::classify_argv" - ); - } - - #[test] - fn block_reply_for_user_corrected_carries_feedback() { - let reply = block_reply_for(&Decision::Deny(Denial::UserCorrected { - feedback: "use git diff instead".into(), - })); - assert_eq!(reply["denial"]["kind"], "user_corrected"); - assert_eq!( - reply["denial"]["detail"]["feedback"], - "use git diff instead" - ); - } - - #[test] - fn extract_call_returns_none_when_function_call_absent() { - let envelope = json!({ - "event_id": "evt-1", - "reply_stream": "rs-1", - "payload": { "session_id": "s1", "approval_required": ["write"] } - }); - assert!(extract_call(&envelope).is_none()); - } - - #[test] - fn extract_call_returns_none_when_session_id_absent() { - let envelope = json!({ - "event_id": "evt-1", - "reply_stream": "rs-1", - "payload": { - "tool_call": { "id": "tc-1", "name": "write", "arguments": {} } - } - }); - assert!(extract_call(&envelope).is_none()); - } - - #[test] - fn block_reply_for_allow_omits_denial_and_reason() { - let reply = block_reply_for(&Decision::Allow); - assert_eq!(reply["block"], false); - assert!( - reply.get("reason").is_none(), - "Allow must not include reason: {reply}" - ); - assert!( - reply.get("denial").is_none(), - "Allow must not include denial: {reply}" - ); - } - - use std::sync::Mutex; - - fn sample_call() -> IncomingCall { - IncomingCall { - session_id: "s1".into(), - function_call_id: "tc-1".into(), - function_id: "shell::fs::write".into(), - args: json!({"path": "/tmp/a"}), - approval_required: vec!["shell::fs::write".into()], - event_id: "evt-1".into(), - reply_stream: "rs-1".into(), - } - } - - #[tokio::test] - async fn handle_intercept_returns_pending_envelope_when_call_is_gated() { - let bus = InMemoryStateBus::new(); - let call = sample_call(); - let reply = handle_intercept(&bus, STATE_SCOPE, &call, 1_000, 60_000, false).await; - assert_eq!(reply["block"], json!(true)); - assert_eq!(reply["status"], json!("pending")); - assert_eq!(reply["call_id"], json!("tc-1")); - assert_eq!(reply["function_id"], json!("shell::fs::write")); - // Pending status is self-describing — no `reason` or `denial` field - // is emitted while the call is in-flight. - assert!(reply.get("reason").is_none()); - assert!(reply.get("denial").is_none()); - } - - #[tokio::test] - async fn handle_intercept_writes_pending_record_to_state() { - let bus = InMemoryStateBus::new(); - let call = sample_call(); - let _ = handle_intercept(&bus, STATE_SCOPE, &call, 1_000, 60_000, false).await; - let key = pending_key(&call.session_id, &call.function_call_id); - let rec = bus - .get(STATE_SCOPE, &key) - .await - .expect("pending record written"); - assert_eq!(rec["status"], "pending"); - assert_eq!(rec["function_call_id"], "tc-1"); - assert_eq!(rec["expires_at"], 61_000); - } - - #[tokio::test] - async fn handle_intercept_passes_through_when_call_is_not_gated() { - let bus = InMemoryStateBus::new(); - let mut call = sample_call(); - call.approval_required = vec!["other".into()]; - let reply = handle_intercept(&bus, STATE_SCOPE, &call, 1_000, 60_000, false).await; - assert_eq!(reply["block"], json!(false)); - let key = pending_key(&call.session_id, &call.function_call_id); - assert!( - bus.get(STATE_SCOPE, &key).await.is_none(), - "no record written" - ); - } - - #[tokio::test] - async fn handle_intercept_force_pending_writes_when_not_on_required_list() { - let bus = InMemoryStateBus::new(); - let mut call = sample_call(); - call.approval_required = vec!["other".into()]; - let reply = handle_intercept(&bus, STATE_SCOPE, &call, 1_000, 60_000, true).await; - assert_eq!(reply["block"], json!(true)); - assert_eq!(reply["status"], json!("pending")); - let key = pending_key(&call.session_id, &call.function_call_id); - assert!(bus.get(STATE_SCOPE, &key).await.is_some()); - } - - #[tokio::test] - async fn handle_lookup_record_returns_null_when_missing() { - let bus = InMemoryStateBus::new(); - let v = handle_lookup_record( - &bus, - STATE_SCOPE, - json!({"session_id": "s1", "function_call_id": "c1"}), - ) - .await; - assert!(v.is_null()); - } - - #[tokio::test] - async fn handle_lookup_record_returns_record_when_present() { - let bus = InMemoryStateBus::new(); - let call = sample_call(); - let _ = handle_intercept(&bus, STATE_SCOPE, &call, 1_000, 60_000, false).await; - let v = handle_lookup_record( - &bus, - STATE_SCOPE, - json!({"session_id": "s1", "function_call_id": "tc-1"}), - ) - .await; - assert_eq!(v["status"], json!("pending")); - assert_eq!(v["function_id"], json!("shell::fs::write")); - } - - #[derive(Default)] - struct FakeExecutor { - calls: Mutex>, - response: Mutex>>, - } - - #[async_trait::async_trait] - impl FunctionExecutor for FakeExecutor { - async fn invoke( - &self, - function_id: &str, - args: Value, - function_call_id: &str, - session_id: &str, - ) -> Result { - self.calls.lock().unwrap().push(( - function_id.to_string(), - args, - function_call_id.to_string(), - session_id.to_string(), - )); - self.response - .lock() - .unwrap() - .clone() - .unwrap_or_else(|| Ok(json!({"ok": true}))) - } - } - - #[tokio::test] - async fn handle_resolve_allow_invokes_function_and_records_executed() { - let bus = InMemoryStateBus::new(); - let exec = FakeExecutor::default(); - bus.set( - STATE_SCOPE, - &pending_key("s1", "tc-1"), - build_pending_record( - "tc-1", - "shell::fs::write", - &json!({"path":"/a"}), - 1_000, - 60_000, - ), - ) - .await - .unwrap(); - - let resp = handle_resolve( - &bus, - &exec, - STATE_SCOPE, - json!({ - "session_id": "s1", - "function_call_id": "tc-1", - "decision": "allow", - }), - 1_500, - ) - .await; - assert_eq!(resp["ok"], json!(true)); - - let calls = exec.calls.lock().unwrap().clone(); - assert_eq!(calls.len(), 1); - assert_eq!(calls[0].0, "shell::fs::write"); - assert_eq!(calls[0].1, json!({"path":"/a"})); - assert_eq!(calls[0].2, "tc-1"); - assert_eq!(calls[0].3, "s1"); - - let rec = bus - .get(STATE_SCOPE, &pending_key("s1", "tc-1")) - .await - .unwrap(); - assert_eq!(rec["status"], "executed"); - assert_eq!(rec["result"], json!({"ok": true})); - } - - #[tokio::test] - async fn handle_resolve_deny_does_not_invoke_function() { - let bus = InMemoryStateBus::new(); - let exec = FakeExecutor::default(); - bus.set( - STATE_SCOPE, - &pending_key("s1", "tc-1"), - build_pending_record("tc-1", "shell::fs::write", &json!({}), 1_000, 60_000), - ) - .await - .unwrap(); - - let resp = handle_resolve( - &bus, - &exec, - STATE_SCOPE, - json!({ - "session_id": "s1", - "function_call_id": "tc-1", - "decision": "deny", - "denial": { - "kind": "user_corrected", - "detail": { "feedback": "not authorized" } - }, - }), - 1_500, - ) - .await; - assert_eq!(resp["ok"], json!(true)); - - assert!(exec.calls.lock().unwrap().is_empty()); - - let rec = bus - .get(STATE_SCOPE, &pending_key("s1", "tc-1")) - .await - .unwrap(); - assert_eq!(rec["status"], "denied"); - assert_eq!(rec["denial"]["kind"], "user_corrected"); - assert_eq!(rec["denial"]["detail"]["feedback"], "not authorized"); - } - - #[tokio::test] - async fn handle_resolve_allow_records_failed_when_function_errors() { - let bus = InMemoryStateBus::new(); - let exec = FakeExecutor::default(); - *exec.response.lock().unwrap() = Some(Err("EACCES".into())); - bus.set( - STATE_SCOPE, - &pending_key("s1", "tc-1"), - build_pending_record("tc-1", "shell::fs::write", &json!({}), 1_000, 60_000), - ) - .await - .unwrap(); - - let resp = handle_resolve( - &bus, - &exec, - STATE_SCOPE, - json!({"session_id":"s1","function_call_id":"tc-1","decision":"allow"}), - 1_500, - ) - .await; - assert_eq!(resp["ok"], json!(true)); - - let rec = bus - .get(STATE_SCOPE, &pending_key("s1", "tc-1")) - .await - .unwrap(); - assert_eq!(rec["status"], "failed"); - assert_eq!(rec["error"], "EACCES"); - } - - #[tokio::test] - async fn fake_executor_records_calls() { - let exec = FakeExecutor::default(); - let out = exec - .invoke("shell::fs::write", json!({"x": 1}), "cid", "sid") - .await - .unwrap(); - assert_eq!(out, json!({"ok": true})); - let calls = exec.calls.lock().unwrap().clone(); - assert_eq!(calls.len(), 1); - assert_eq!(calls[0].0, "shell::fs::write"); - assert_eq!(calls[0].2, "cid"); - assert_eq!(calls[0].3, "sid"); - } - - struct InMemoryStateBus { - store: Mutex>, - } - - impl InMemoryStateBus { - fn new() -> Self { - Self { - store: Mutex::new(std::collections::HashMap::new()), - } - } - } - - #[async_trait::async_trait] - impl StateBus for InMemoryStateBus { - async fn set(&self, scope: &str, key: &str, value: Value) -> Result<(), iii_sdk::IIIError> { - self.store - .lock() - .unwrap() - .insert(format!("{scope}/{key}"), value); - Ok(()) - } - async fn get(&self, scope: &str, key: &str) -> Option { - self.store - .lock() - .unwrap() - .get(&format!("{scope}/{key}")) - .cloned() - } - async fn list_prefix(&self, scope: &str, prefix: &str) -> Vec { - let map = self.store.lock().unwrap(); - map.iter() - .filter(|(k, _)| k.starts_with(&format!("{scope}/{prefix}"))) - .map(|(_, v)| v.clone()) - .collect() - } - } - - #[tokio::test] - async fn resolve_flips_status_when_pending() { - let bus = InMemoryStateBus::new(); - bus.set( - STATE_SCOPE, - &pending_key("s1", "tc-1"), - build_pending_record("tc-1", "write", &json!({}), 0, 60_000), - ) - .await - .unwrap(); - - let exec = FakeExecutor::default(); - let out = handle_resolve( - &bus, - &exec, - STATE_SCOPE, - json!({ - "function_call_id": "tc-1", - "session_id": "s1", - "decision": "allow", - }), - 1_500, - ) - .await; - - assert_eq!(out["ok"], true); - let stored = bus - .get(STATE_SCOPE, &pending_key("s1", "tc-1")) - .await - .unwrap(); - assert_eq!(stored["status"], "executed"); - } - - #[tokio::test] - async fn resolve_accepts_legacy_tool_call_id_field() { - let bus = InMemoryStateBus::new(); - bus.set( - STATE_SCOPE, - &pending_key("s1", "tc-1"), - build_pending_record("tc-1", "write", &json!({}), 0, 60_000), - ) - .await - .unwrap(); - - let exec = FakeExecutor::default(); - let out = handle_resolve( - &bus, - &exec, - STATE_SCOPE, - json!({ - "tool_call_id": "tc-1", - "session_id": "s1", - "decision": "allow", - }), - 1_500, - ) - .await; - - assert_eq!(out["ok"], true); - } - - #[tokio::test] - async fn resolve_rejects_already_resolved_entry() { - let bus = InMemoryStateBus::new(); - let mut rec = build_pending_record("tc-1", "write", &json!({}), 0, 60_000); - rec["status"] = json!("allow"); - bus.set(STATE_SCOPE, &pending_key("s1", "tc-1"), rec) - .await - .unwrap(); - - let exec = FakeExecutor::default(); - let out = handle_resolve( - &bus, - &exec, - STATE_SCOPE, - json!({"function_call_id": "tc-1", "session_id": "s1", "decision": "deny"}), - 1_500, - ) - .await; - assert_eq!(out["ok"], false); - assert_eq!(out["error"], "already_resolved"); - } - - #[tokio::test] - async fn list_pending_returns_only_pending_for_session() { - let bus = InMemoryStateBus::new(); - bus.set( - STATE_SCOPE, - &pending_key("s1", "tc-1"), - build_pending_record("tc-1", "write", &json!({}), 0, 60_000), - ) - .await - .unwrap(); - let mut resolved = build_pending_record("tc-2", "write", &json!({}), 0, 60_000); - resolved["status"] = json!("allow"); - bus.set(STATE_SCOPE, &pending_key("s1", "tc-2"), resolved) - .await - .unwrap(); - bus.set( - STATE_SCOPE, - &pending_key("other", "tc-3"), - build_pending_record("tc-3", "write", &json!({}), 0, 60_000), - ) - .await - .unwrap(); - - let out = handle_list_pending(&bus, STATE_SCOPE, json!({ "session_id": "s1" })).await; - let items = out["pending"].as_array().unwrap(); - assert_eq!(items.len(), 1); - assert_eq!(items[0]["function_call_id"], "tc-1"); - } - - #[tokio::test] - async fn resolve_deny_without_denial_defaults_to_user_rejected() { - let bus = InMemoryStateBus::new(); - let _ = bus - .set( - STATE_SCOPE, - &pending_key("s1", "tc-1"), - build_pending_record("tc-1", "write", &json!({}), 0, 60_000), - ) - .await; - - let exec = FakeExecutor::default(); - let out = handle_resolve( - &bus, - &exec, - STATE_SCOPE, - json!({ - "session_id": "s1", - "function_call_id": "tc-1", - "decision": "deny", - }), - 1_500, - ) - .await; - assert_eq!(out["ok"], true); - - let stored = bus - .get(STATE_SCOPE, &pending_key("s1", "tc-1")) - .await - .unwrap(); - assert_eq!(stored["status"], "denied"); - assert_eq!(stored["denial"]["kind"], "user_rejected"); - } - - #[tokio::test] - async fn resolve_deny_rejects_malformed_denial() { - let bus = InMemoryStateBus::new(); - let _ = bus - .set( - STATE_SCOPE, - &pending_key("s1", "tc-1"), - build_pending_record("tc-1", "write", &json!({}), 0, 60_000), - ) - .await; - - let exec = FakeExecutor::default(); - let out = handle_resolve( - &bus, - &exec, - STATE_SCOPE, - json!({ - "session_id": "s1", - "function_call_id": "tc-1", - "decision": "deny", - "denial": { "kind": "not_a_real_kind" }, - }), - 1_500, - ) - .await; - assert_eq!(out["ok"], false); - assert_eq!(out["error"], "bad_denial"); - } - - #[test] - fn transition_record_to_executed_attaches_result() { - let base = build_pending_record( - "tc-1", - "shell::fs::write", - &json!({"path":"/a"}), - 1_000, - 60_000, - ); - let rec = transition_record(&base, "executed", Some(json!({"ok": true})), None, None); - assert_eq!(rec["status"], "executed"); - assert_eq!(rec["result"], json!({"ok": true})); - assert!(rec.get("error").is_none() || rec["error"].is_null()); - assert_eq!(rec["function_call_id"], "tc-1"); - assert_eq!(rec["function_id"], "shell::fs::write"); - } - - #[test] - fn transition_record_to_failed_attaches_error() { - let base = build_pending_record("tc-1", "shell::fs::write", &json!({}), 1_000, 60_000); - let rec = transition_record(&base, "failed", None, Some("EACCES".into()), None); - assert_eq!(rec["status"], "failed"); - assert_eq!(rec["error"], "EACCES"); - assert!(rec.get("result").is_none() || rec["result"].is_null()); - } - - #[test] - fn transition_record_to_denied_attaches_structured_denial() { - let base = build_pending_record("tc-1", "shell::fs::write", &json!({}), 1_000, 60_000); - let rec = transition_record( - &base, - "denied", - None, - None, - Some(Denial::Policy { - classifier_reason: "not authorized".into(), - classifier_fn: "shell::classify_argv".into(), - }), - ); - assert_eq!(rec["status"], "denied"); - assert_eq!(rec["denial"]["kind"], "policy"); - assert_eq!(rec["denial"]["detail"]["classifier_reason"], "not authorized"); - assert!( - rec.get("decision_reason").is_none(), - "legacy decision_reason must not be written: {rec}" - ); - } - - #[test] - fn transition_record_to_timed_out_carries_no_denial() { - // Timeout status is self-describing — no Denial attached. - let base = build_pending_record("tc-1", "shell::fs::write", &json!({}), 1_000, 60_000); - let rec = transition_record(&base, "timed_out", None, None, None); - assert_eq!(rec["status"], "timed_out"); - assert!(rec.get("denial").is_none()); - assert!(rec.get("decision_reason").is_none()); - } - - #[test] - fn transition_record_preserves_delivered_in_turn_id_when_set() { - let mut base = build_pending_record("tc-1", "shell::fs::write", &json!({}), 1_000, 60_000); - base.as_object_mut().unwrap().insert( - "delivered_in_turn_id".into(), - Value::String("turn-X".into()), - ); - let rec = transition_record(&base, "executed", Some(json!({"ok": true})), None, None); - assert_eq!(rec["delivered_in_turn_id"], "turn-X"); - } - - #[tokio::test] - async fn handle_sweep_session_flips_pending_records_to_timed_out_with_reason_session_deleted() { - let bus = InMemoryStateBus::new(); - bus.set( - STATE_SCOPE, - &pending_key("s1", "c1"), - build_pending_record("c1", "shell::fs::write", &json!({}), 1_000, 60_000), - ) - .await - .unwrap(); - - let resp = handle_sweep_session(&bus, STATE_SCOPE, json!({"session_id": "s1"})).await; - assert_eq!(resp["swept"], json!(1)); - - let rec = bus - .get(STATE_SCOPE, &pending_key("s1", "c1")) - .await - .unwrap(); - assert_eq!(rec["status"], "timed_out"); - // sweep_session no longer stamps a reason string — timed_out is - // self-describing and the sweep cause is logged at info, not - // surfaced on the record. - assert!(rec.get("denial").is_none()); - assert!(rec.get("decision_reason").is_none()); - } - - #[tokio::test] - async fn handle_sweep_session_skips_non_pending_records() { - let bus = InMemoryStateBus::new(); - bus.set( - STATE_SCOPE, - &pending_key("s1", "c1"), - transition_record( - &build_pending_record("c1", "shell::fs::write", &json!({}), 1_000, 60_000), - "executed", - Some(json!({"ok": true})), - None, - None, - ), - ) - .await - .unwrap(); - - let resp = handle_sweep_session(&bus, STATE_SCOPE, json!({"session_id": "s1"})).await; - assert_eq!(resp["swept"], json!(0)); - - let rec = bus - .get(STATE_SCOPE, &pending_key("s1", "c1")) - .await - .unwrap(); - assert_eq!(rec["status"], "executed"); - } - - #[tokio::test] - async fn handle_sweep_session_returns_error_when_session_id_missing() { - let bus = InMemoryStateBus::new(); - let resp = handle_sweep_session(&bus, STATE_SCOPE, json!({})).await; - assert_eq!(resp["ok"], json!(false)); - assert_eq!(resp["error"], "missing_session_id"); - assert_eq!(resp["swept"], json!(0)); - } - - // ── New reliability fixes ───────────────────────────────────────────── - - /// A bus that always refuses writes, to exercise fail-closed semantics. - struct FailingStateBus; - - #[async_trait::async_trait] - impl StateBus for FailingStateBus { - async fn set( - &self, - _scope: &str, - _key: &str, - _value: Value, - ) -> Result<(), iii_sdk::IIIError> { - Err(iii_sdk::IIIError::Runtime("kv unreachable".into())) - } - async fn get(&self, _scope: &str, _key: &str) -> Option { - None - } - async fn list_prefix(&self, _scope: &str, _prefix: &str) -> Vec { - Vec::new() - } - } - - #[tokio::test] - async fn handle_intercept_fails_closed_on_state_write_error() { - let bus = FailingStateBus; - let call = sample_call(); - let reply = handle_intercept(&bus, STATE_SCOPE, &call, 1_000, 60_000, false).await; - assert_eq!( - reply["block"], - json!(true), - "state write failure must NOT fail-open" - ); - assert_eq!(reply["status"], json!("denied")); - assert_eq!(reply["denial"]["kind"], json!("state_error")); - assert_eq!( - reply["denial"]["detail"]["phase"], - json!("intercept_write_pending") - ); - // The underlying error message is present but its exact text is - // bus-implementation-specific; just check it's non-empty. - assert!( - reply["denial"]["detail"]["error"] - .as_str() - .map(|s| !s.is_empty()) - .unwrap_or(false), - "state_error detail must include error message: {reply}" - ); - assert_eq!(reply["function_id"], json!("shell::fs::write")); - } - - #[tokio::test] - async fn handle_intercept_stamps_session_id_into_pending_record() { - let bus = InMemoryStateBus::new(); - let call = sample_call(); - let _ = handle_intercept(&bus, STATE_SCOPE, &call, 1_000, 60_000, false).await; - let rec = bus - .get( - STATE_SCOPE, - &pending_key(&call.session_id, &call.function_call_id), - ) - .await - .expect("pending record"); - assert_eq!(rec["session_id"], json!(call.session_id)); - } - - #[test] - fn collect_timed_out_for_sweep_returns_expired_records_with_session_id() { - let mut rec = build_pending_record("tc-1", "shell::fs::write", &json!({}), 0, 60_000); - rec.as_object_mut() - .unwrap() - .insert("session_id".into(), json!("s-42")); - let pile = vec![ - rec.clone(), - build_pending_record("tc-2", "shell::fs::write", &json!({}), 0, 999_999_999), - ]; - let out = collect_timed_out_for_sweep(&pile, 70_000); - assert_eq!(out.len(), 1); - let (key, flipped, session_id, call_id) = &out[0]; - assert_eq!(key, "s-42/tc-1"); - assert_eq!(session_id, "s-42"); - assert_eq!(call_id, "tc-1"); - assert_eq!(flipped["status"], json!("timed_out")); - // Timeout carries no Denial — status is self-describing. - assert!(flipped.get("denial").is_none()); - assert!(flipped.get("decision_reason").is_none()); - } - - #[test] - fn collect_timed_out_for_sweep_skips_records_without_session_id() { - // Legacy row (pre-session_id-stamping fix). The sweeper can't - // address the right session stream, so it must skip silently — - // lazy-flip on read will still pick it up. - let pile = vec![build_pending_record( - "tc-legacy", - "shell::fs::write", - &json!({}), - 0, - 60_000, - )]; - let out = collect_timed_out_for_sweep(&pile, 70_000); - assert!( - out.is_empty(), - "legacy record without session_id must not be swept" - ); - } - - #[test] - fn timeout_resolved_event_shape() { - let evt = timeout_resolved_event("tc-1"); - assert_eq!(evt["type"], "approval_resolved"); - assert_eq!(evt["function_call_id"], "tc-1"); - assert_eq!(evt["tool_call_id"], "tc-1"); - assert_eq!(evt["decision"], "deny"); - assert_eq!(evt["status"], "timed_out"); - assert_eq!(evt["decision_reason"], "timeout"); - } - - #[test] - fn unverified_marker_targets_lists_unasserted_rules() { - let rules = vec![ - InterceptorRule { - function_id: "shell::exec".into(), - classifier: None, - classifier_timeout_ms: 2000, - inject_approval_marker: true, - marker_target_verified: false, - }, - InterceptorRule { - function_id: "shell::exec_bg".into(), - classifier: None, - classifier_timeout_ms: 2000, - inject_approval_marker: true, - marker_target_verified: true, - }, - InterceptorRule { - function_id: "no_marker::fn".into(), - classifier: None, - classifier_timeout_ms: 2000, - inject_approval_marker: false, - marker_target_verified: false, - }, - ]; - assert_eq!(unverified_marker_targets(&rules), vec!["shell::exec"]); - } - - #[test] - fn unverified_marker_targets_empty_when_all_verified_or_marker_off() { - let rules = vec![ - InterceptorRule { - function_id: "shell::exec".into(), - classifier: None, - classifier_timeout_ms: 2000, - inject_approval_marker: true, - marker_target_verified: true, - }, - InterceptorRule { - function_id: "other".into(), - classifier: None, - classifier_timeout_ms: 2000, - inject_approval_marker: false, - marker_target_verified: false, - }, - ]; - assert!(unverified_marker_targets(&rules).is_empty()); - } - - // ── Boundary + edge-case tests prompted by cargo-mutants survivors ──── - // - // Each test corresponds to a mutant the test suite previously didn't - // catch. Test name → mutated line in src/lib.rs. - - #[test] - fn merge_from_approval_wraps_null_args_in_marker_only() { - // mutant L48: replace `other.is_null()` match guard - let out = merge_from_approval_marker_if_needed(true, Value::Null, "c1", "s1"); - assert!(out.get("__from_approval").is_some()); - assert!( - out.get("payload").is_none(), - "null-arg branch must NOT wrap as payload" - ); - } - - #[test] - fn merge_from_approval_wraps_scalar_args_in_payload() { - // mutant L48: same guard, the other branch - let out = merge_from_approval_marker_if_needed(true, json!("scalar"), "c1", "s1"); - assert!(out.get("__from_approval").is_some()); - assert_eq!( - out.get("payload"), - Some(&json!("scalar")), - "scalar-arg branch must wrap original under `payload`" - ); - } - - #[tokio::test] - async fn handle_intercept_replay_of_terminal_record_returns_already_resolved() { - // mutant L331: replace `==` with `!=` in the replay defense — if - // flipped, terminal records would be overwritten with fresh pending. - let bus = InMemoryStateBus::new(); - let call = sample_call(); - let key = pending_key(&call.session_id, &call.function_call_id); - let terminal = transition_record( - &build_pending_record( - &call.function_call_id, - &call.function_id, - &call.args, - 0, - 60_000, - ), - "executed", - Some(json!({"ok": true})), - None, - None, - ); - bus.set(STATE_SCOPE, &key, terminal).await.unwrap(); - - let reply = handle_intercept(&bus, STATE_SCOPE, &call, 1_000, 60_000, false).await; - assert_eq!(reply["block"], json!(true)); - assert_eq!(reply["status"], json!("executed")); - // Replay reply: status carries the prior outcome, `replay` discriminator - // says we're echoing rather than denying afresh, and no `denial` is - // synthesized (the historical record is the source of truth). - assert_eq!(reply["replay"], json!("already_resolved")); - assert!(reply.get("denial").is_none()); - assert!(reply.get("reason").is_none()); - - // Crucial: the stored row is still `executed`, not overwritten. - let stored = bus.get(STATE_SCOPE, &key).await.unwrap(); - assert_eq!(stored["status"], json!("executed")); - assert_eq!(stored["result"], json!({"ok": true})); - } - - #[tokio::test] - async fn handle_intercept_replay_of_pending_record_preserves_expires_at() { - // mutant L331: same branch, pending side. New pending must not bump - // the expires_at on the existing row. - let bus = InMemoryStateBus::new(); - let call = sample_call(); - let key = pending_key(&call.session_id, &call.function_call_id); - let pending = build_pending_record( - &call.function_call_id, - &call.function_id, - &call.args, - 0, - 60_000, - ); - bus.set(STATE_SCOPE, &key, pending.clone()).await.unwrap(); - - let _ = handle_intercept(&bus, STATE_SCOPE, &call, 999_000, 60_000, false).await; - let stored = bus.get(STATE_SCOPE, &key).await.unwrap(); - assert_eq!( - stored["expires_at"], pending["expires_at"], - "replay must not bump expires_at on the live row" - ); - } - - #[tokio::test] - async fn handle_lookup_record_rejects_when_only_one_id_is_empty() { - // mutant L395: `||` → `&&` would let one-empty slip through. - let bus = InMemoryStateBus::new(); - let v1 = handle_lookup_record( - &bus, - STATE_SCOPE, - json!({"session_id": "", "function_call_id": "c"}), - ) - .await; - assert!(v1.is_null()); - let v2 = handle_lookup_record( - &bus, - STATE_SCOPE, - json!({"session_id": "s", "function_call_id": ""}), - ) - .await; - assert!(v2.is_null()); - } - - #[tokio::test] - async fn handle_resolve_rejects_when_only_one_id_is_empty() { - // mutant L489: same `||` pattern in handle_resolve guard. - let bus = InMemoryStateBus::new(); - let exec = FakeExecutor::default(); - let r1 = handle_resolve( - &bus, - &exec, - STATE_SCOPE, - json!({"session_id": "", "function_call_id": "c", "decision": "allow"}), - 0, - ) - .await; - assert_eq!(r1["error"], json!("missing_id")); - let r2 = handle_resolve( - &bus, - &exec, - STATE_SCOPE, - json!({"session_id": "s", "function_call_id": "", "decision": "allow"}), - 0, - ) - .await; - assert_eq!(r2["error"], json!("missing_id")); - } - - #[tokio::test] - async fn handle_ack_delivered_returns_zero_when_only_one_field_is_empty() { - // mutant L677: two `||` operators in the empty-field guard. - let bus = InMemoryStateBus::new(); - // empty turn_id - let r1 = handle_ack_delivered( - &bus, - STATE_SCOPE, - json!({"session_id": "s", "turn_id": "", "call_ids": ["c"]}), - ) - .await; - assert_eq!(r1["stamped"], json!(0)); - // empty call_ids - let r2 = handle_ack_delivered( - &bus, - STATE_SCOPE, - json!({"session_id": "s", "turn_id": "t", "call_ids": []}), - ) - .await; - assert_eq!(r2["stamped"], json!(0)); - // empty session_id - let r3 = handle_ack_delivered( - &bus, - STATE_SCOPE, - json!({"session_id": "", "turn_id": "t", "call_ids": ["c"]}), - ) - .await; - assert_eq!(r3["stamped"], json!(0)); - } - - #[test] - fn collect_timed_out_for_sweep_rejects_record_missing_only_call_id() { - // mutant L423: `||` → `&&` would let one-empty records sweep. - let mut rec = build_pending_record("c1", "shell::fs::write", &json!({}), 0, 60_000); - rec.as_object_mut() - .unwrap() - .insert("session_id".into(), json!("s1")); - rec.as_object_mut() - .unwrap() - .insert("function_call_id".into(), json!("")); - let out = collect_timed_out_for_sweep(&[rec], 70_000); - assert!(out.is_empty(), "empty function_call_id must skip sweep"); - } - - #[tokio::test] - async fn handle_intercept_replay_of_approved_record_preserves_state() { - // mutant L331:42 — replace `==` with `!=` on the "approved" side. - // The L331:19 mutation is killed by the *_pending_* test above; - // this one requires an approved record specifically. - let bus = InMemoryStateBus::new(); - let call = sample_call(); - let key = pending_key(&call.session_id, &call.function_call_id); - let approved = transition_record( - &build_pending_record( - &call.function_call_id, - &call.function_id, - &call.args, - 0, - 60_000, - ), - "approved", - None, - None, - None, - ); - bus.set(STATE_SCOPE, &key, approved.clone()).await.unwrap(); - - let _ = handle_intercept(&bus, STATE_SCOPE, &call, 999_000, 60_000, false).await; - let stored = bus.get(STATE_SCOPE, &key).await.unwrap(); - assert_eq!( - stored["status"], - json!("approved"), - "replay of approved row must keep status; mutant would overwrite with pending" - ); - } - - #[tokio::test] - async fn handle_lookup_record_short_circuits_before_bus_get_on_one_empty_id() { - // mutant L395 — `||` → `&&` would let one-empty slip into bus.get. - // Seed a record at the address the mutant would compute (pending_key("", "c") = "/c"), - // so the mutant returns the seeded row while original code stays at Null. - let bus = InMemoryStateBus::new(); - bus.set(STATE_SCOPE, "/c", json!({"sentinel": "should_not_leak"})) - .await - .unwrap(); - let v = handle_lookup_record( - &bus, - STATE_SCOPE, - json!({"session_id": "", "function_call_id": "c"}), - ) - .await; - assert!( - v.is_null(), - "must short-circuit; the seeded sentinel must not leak through" - ); - } - - #[tokio::test] - async fn handle_ack_delivered_short_circuits_before_stamping_on_one_empty_field() { - // mutant L677 — two `||` operators. If either flips to `&&`, the - // function falls through and stamps a record even when a required - // field is empty. Seed a record so the stamping path can be - // observed. - let bus = InMemoryStateBus::new(); - let terminal = transition_record( - &build_pending_record("c", "shell::fs::write", &json!({}), 0, 60_000), - "executed", - Some(json!({"ok": true})), - None, - None, - ); - bus.set(STATE_SCOPE, &pending_key("s", "c"), terminal) - .await - .unwrap(); - - // empty turn_id — must NOT stamp the seeded record. - let r = handle_ack_delivered( - &bus, - STATE_SCOPE, - json!({"session_id": "s", "turn_id": "", "call_ids": ["c"]}), - ) - .await; - assert_eq!(r["stamped"], json!(0)); - let stored = bus.get(STATE_SCOPE, &pending_key("s", "c")).await.unwrap(); - assert!( - stored.get("delivered_in_turn_id").is_none(), - "must not stamp when turn_id is empty; mutant would stamp" - ); - - // empty call_ids — same property. - let r = handle_ack_delivered( - &bus, - STATE_SCOPE, - json!({"session_id": "s", "turn_id": "t", "call_ids": []}), - ) - .await; - assert_eq!(r["stamped"], json!(0)); - let stored = bus.get(STATE_SCOPE, &pending_key("s", "c")).await.unwrap(); - assert!( - stored.get("delivered_in_turn_id").is_none(), - "must not stamp when call_ids is empty" - ); - } - - #[tokio::test] - async fn handle_list_undelivered_persists_migrated_legacy_record() { - // mutant L614 — `delete !` on the `if !call_id.is_empty()` guard. - // The legacy migration block writes the migrated row back to state - // so subsequent reads use the new shape. The mutant inverts the - // guard, suppressing the write. Verify the write happens. - let bus = InMemoryStateBus::new(); - // Pre-trigger-model row: status="allow" (legacy form). - let legacy = json!({ - "function_call_id": "c1", - "function_id": "shell::fs::write", - "args": {}, - "status": "allow", - "expires_at": 1_000u64, - }); - bus.set(STATE_SCOPE, &pending_key("s1", "c1"), legacy) - .await - .unwrap(); - - let _ = - handle_list_undelivered(&bus, STATE_SCOPE, json!({"session_id": "s1"}), 5_000).await; - - // Storage now reflects the migrated shape. - let stored = bus - .get(STATE_SCOPE, &pending_key("s1", "c1")) - .await - .expect("migrated row persisted"); - assert_eq!(stored["status"], json!("executed")); - assert_eq!(stored["legacy_migrated"], json!(true)); - } - - #[test] - fn maybe_flip_timed_out_flips_at_exact_expires_at() { - // mutant L439: `<` → `<=` would not flip at the exact boundary. - let rec = build_pending_record("c1", "f", &json!({}), 0, 60_000); - // expires_at = 0 + 60_000 = 60_000. At now=60_000 the gate - // considers the record expired (strictly past or AT expiry). - assert!( - maybe_flip_timed_out(&rec, 60_000).is_some(), - "must flip at exactly expires_at" - ); - assert!( - maybe_flip_timed_out(&rec, 59_999).is_none(), - "must not flip one ms before expires_at" - ); - } - - // ── proptest: state-machine invariants ──────────────────────────────── - // - // Random sequences of intercept/resolve/sweep/ack/lazy-flip operations - // on a single (session, call) record. After every step we assert four - // invariants that the lifecycle is supposed to guarantee: - // - // I1. status ∈ {pending, approved, executed, failed, denied, timed_out}. - // Any other string is a corrupt record. - // I2. Once a terminal status is observed, the record never returns to - // `pending`. Terminal = executed | failed | denied | timed_out. - // I3. Every `pending` record carries an `expires_at: u64`. Without it - // the sweeper and lazy-flip paths can't classify the record. - // I4. `delivered_in_turn_id` is monotonic: once a non-null value is - // written it is never unset, never replaced with a different turn. - // - // If any future change can produce a sequence that violates one of - // these, proptest will shrink to the minimal failing sequence and - // surface it as a counterexample. - - use proptest::prelude::*; - - #[derive(Debug, Clone)] - enum Op { - InterceptRequired, - InterceptNotRequired, - ResolveAllow, - ResolveDeny, - AdvanceClockAndLazyFlip, // bumps clock past expires_at, hits list_undelivered - SweepSession, - AckDelivered, - } - - fn arb_op() -> impl Strategy { - prop_oneof![ - Just(Op::InterceptRequired), - Just(Op::InterceptNotRequired), - Just(Op::ResolveAllow), - Just(Op::ResolveDeny), - Just(Op::AdvanceClockAndLazyFlip), - Just(Op::SweepSession), - Just(Op::AckDelivered), - ] - } - - fn make_call(approval_required_self: bool) -> IncomingCall { - IncomingCall { - session_id: "s".into(), - function_call_id: "c".into(), - function_id: "test::write".into(), - args: json!({}), - approval_required: if approval_required_self { - vec!["test::write".into()] - } else { - vec!["other::fn".into()] - }, - event_id: "e".into(), - reply_stream: "r".into(), - } - } - - proptest! { - #![proptest_config(ProptestConfig { - cases: 256, - .. ProptestConfig::default() - })] - - #[test] - fn state_machine_invariants(ops in prop::collection::vec(arb_op(), 1..30)) { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("tokio runtime"); - - rt.block_on(async { - let bus = InMemoryStateBus::new(); - let exec = FakeExecutor::default(); - let session_id = "s"; - let call_id = "c"; - let timeout_ms: u64 = 60_000; - let mut now_ms: u64 = 1_000; - - let mut ever_terminal = false; - let mut last_delivered: Option = None; - - for op in &ops { - match op { - Op::InterceptRequired => { - let call = make_call(true); - let _ = handle_intercept(&bus, STATE_SCOPE, &call, now_ms, timeout_ms, false).await; - } - Op::InterceptNotRequired => { - let call = make_call(false); - let _ = handle_intercept(&bus, STATE_SCOPE, &call, now_ms, timeout_ms, false).await; - } - Op::ResolveAllow => { - let _ = handle_resolve( - &bus, &exec, STATE_SCOPE, - json!({ - "session_id": session_id, - "function_call_id": call_id, - "decision": "allow", - }), - now_ms, - ).await; - } - Op::ResolveDeny => { - let _ = handle_resolve( - &bus, &exec, STATE_SCOPE, - json!({ - "session_id": session_id, - "function_call_id": call_id, - "decision": "deny", - "reason": "user", - }), - now_ms, - ).await; - } - Op::AdvanceClockAndLazyFlip => { - now_ms = now_ms.saturating_add(timeout_ms + 1); - let _ = handle_list_undelivered( - &bus, STATE_SCOPE, - json!({ "session_id": session_id }), - now_ms, - ).await; - } - Op::SweepSession => { - let _ = handle_sweep_session( - &bus, STATE_SCOPE, - json!({ "session_id": session_id }), - ).await; - } - Op::AckDelivered => { - let _ = handle_ack_delivered( - &bus, STATE_SCOPE, - json!({ - "session_id": session_id, - "turn_id": format!("turn-{now_ms}"), - "call_ids": [call_id], - }), - ).await; - } - } - - // Assert invariants on whatever the record currently is. - let key = pending_key(session_id, call_id); - let Some(rec) = bus.get(STATE_SCOPE, &key).await else { - // No record yet (e.g. only InterceptNotRequired so far). Skip. - continue; - }; - - // I1: legal status - let status = rec.get("status").and_then(Value::as_str).unwrap_or(""); - assert!( - matches!( - status, - "pending" | "approved" | "executed" | "failed" | "denied" | "timed_out" - ), - "I1 violated: illegal status {status:?} after ops {ops:?}; record={rec:?}" - ); - - // I2: no reverting terminal → pending - if matches!(status, "executed" | "failed" | "denied" | "timed_out") { - ever_terminal = true; - } - if ever_terminal { - assert!( - status != "pending", - "I2 violated: reverted to pending after terminal; ops={ops:?}; record={rec:?}" - ); - } - - // I3: pending records always have expires_at: u64 - if status == "pending" { - let exp = rec.get("expires_at").and_then(Value::as_u64); - assert!( - exp.is_some(), - "I3 violated: pending record missing expires_at; ops={ops:?}; record={rec:?}" - ); - } - - // I4: delivered_in_turn_id is monotonic — once set non-null, never unset / never replaced - let cur_delivered = rec - .get("delivered_in_turn_id") - .and_then(Value::as_str) - .map(str::to_string); - if let Some(prev) = &last_delivered { - match &cur_delivered { - Some(cur) => { - assert_eq!( - cur, prev, - "I4 violated: delivered_in_turn_id replaced {prev:?} → {cur:?}; ops={ops:?}" - ); - } - None => { - panic!( - "I4 violated: delivered_in_turn_id unset after being {prev:?}; ops={ops:?}; record={rec:?}" - ); - } - } - } - if cur_delivered.is_some() { - last_delivered = cur_delivered; - } - } - }); - } - } -} diff --git a/approval-gate/src/record.rs b/approval-gate/src/record.rs new file mode 100644 index 00000000..3ea2bdb8 --- /dev/null +++ b/approval-gate/src/record.rs @@ -0,0 +1,274 @@ +//! Approval-gate record schema. +//! +//! `Pending → InFlight → Done(Outcome)`. The intermediate InFlight write +//! between operator-approve and the executor `iii.trigger` is what closes +//! the duplicate-execution race — a second `approval::resolve` arriving +//! during the invoke await sees a non-Pending row and bails. +//! +//! `lifecycle.rs` is gone; its only surviving helper +//! (`flipped_to_timed_out_if_expired`) lives here as a `Record` method +//! because it operates on a `Record`. + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::wire::Denial; + +/// Lifecycle status. Wire format is snake_case so iii-state dumps stay +/// human-readable. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Status { + /// Waiting for the operator's decision (no outcome attached). + Pending, + /// Operator approved; underlying `iii.trigger` is in flight. Persisted + /// to close the dup-exec race across concurrent `approval::resolve` + /// calls within a worker process. + InFlight, + /// Terminal. `outcome` is `Some`. + Done, +} + +/// Outcome data attached to terminal records. Tagged enum on the wire +/// (`{ "kind": "...", "detail": { ... } }`). +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "kind", content = "detail", rename_all = "snake_case")] +pub enum Outcome { + Executed { result: Value }, + Failed { error: String }, + Denied { denial: Denial }, + TimedOut, +} + +/// Persisted approval record. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Record { + pub function_call_id: String, + pub function_id: String, + pub args: Value, + pub session_id: String, + pub expires_at: u64, + pub status: Status, + /// `Some` iff `status == Done`. Constructors enforce this invariant. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub outcome: Option, + /// Unix ms stamped on the first non-Pending transition. `handle_consume` + /// sorts entries by this field so multi-row consumes (cascade case) + /// produce deterministic LLM message order. Provider-minted + /// `function_call_id` (Anthropic `toolu_*`, OpenAI `call_*`) is not + /// lex-monotonic and can't substitute. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub resolved_at: Option, +} + +impl Record { + /// Fresh Pending row. `expires_at = now_ms + timeout_ms`, saturating on + /// overflow so a buggy caller can't underflow the deadline. + pub fn pending( + function_call_id: String, + function_id: String, + args: Value, + session_id: String, + now_ms: u64, + timeout_ms: u64, + ) -> Self { + Self { + function_call_id, + function_id, + args, + session_id, + expires_at: now_ms.saturating_add(timeout_ms), + status: Status::Pending, + outcome: None, + resolved_at: None, + } + } + + /// Pending → InFlight. Stamps `resolved_at` (the "first non-Pending" + /// marker for ordering). Caller is responsible for ensuring the row + /// was actually Pending before calling; this is enforced at the + /// callsite (`handle_resolve`) via a Status check. + pub fn in_flight(self, now_ms: u64) -> Self { + Self { + status: Status::InFlight, + resolved_at: Some(self.resolved_at.unwrap_or(now_ms)), + ..self + } + } + + /// InFlight → Done. Preserves `resolved_at` from the InFlight write + /// (so audit timestamps reflect when the row left Pending, not when + /// the invoke finished). + pub fn done(self, outcome: Outcome) -> Self { + Self { + status: Status::Done, + outcome: Some(outcome), + ..self + } + } + + /// Pending → Done directly (deny path, timeout flip — paths that + /// don't run an invoke). Stamps `resolved_at` with `now_ms`. + pub fn done_at(self, now_ms: u64, outcome: Outcome) -> Self { + Self { + status: Status::Done, + outcome: Some(outcome), + resolved_at: Some(self.resolved_at.unwrap_or(now_ms)), + ..self + } + } + + /// Lazy timeout flip. Returns `Some(flipped)` iff the row is Pending + /// AND `now_ms >= expires_at`. InFlight rows are owned by an + /// in-progress invoke and are never touched here. Done rows are + /// already terminal. + pub fn flipped_to_timed_out_if_expired(&self, now_ms: u64) -> Option { + if self.status == Status::Pending && now_ms >= self.expires_at { + Some(self.clone().done_at(now_ms, Outcome::TimedOut)) + } else { + None + } + } + + /// Wire JSON shape (infallible — only serializable fields). + pub fn to_value(&self) -> Value { + serde_json::to_value(self).expect("Record is always serializable") + } + + /// Parse from wire JSON. `None` means the row doesn't match the + /// schema; callers skip such rows. + pub fn from_value(v: Value) -> Option { + serde_json::from_value(v).ok() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + fn pending_record() -> Record { + Record::pending( + "tc-1".into(), + "shell::exec".into(), + json!({"command": "ls"}), + "sess_a".into(), + 1_000, + 60_000, + ) + } + + #[test] + fn pending_has_no_outcome_and_no_resolved_at() { + let r = pending_record(); + assert_eq!(r.status, Status::Pending); + assert!(r.outcome.is_none()); + assert!(r.resolved_at.is_none()); + assert_eq!(r.expires_at, 61_000); + } + + #[test] + fn pending_expires_at_saturates_on_overflow() { + let r = Record::pending( + "tc-1".into(), "f".into(), json!({}), "s".into(), u64::MAX - 5, 100); + assert_eq!(r.expires_at, u64::MAX); + } + + #[test] + fn in_flight_preserves_fields_and_clears_outcome_state() { + let p = pending_record(); + let i = p.clone().in_flight(2_000); + assert_eq!(i.status, Status::InFlight); + assert_eq!(i.function_call_id, p.function_call_id); + assert_eq!(i.session_id, p.session_id); + assert_eq!(i.args, p.args); + assert!(i.outcome.is_none()); + assert_eq!(i.resolved_at, Some(2_000), "InFlight stamps resolved_at"); + } + + #[test] + fn done_stamps_outcome_and_preserves_in_flight_resolved_at() { + let i = pending_record().in_flight(2_000); + let d = i.clone().done(Outcome::Executed { result: json!({"ok": true}) }); + assert_eq!(d.status, Status::Done); + assert!(matches!(d.outcome, Some(Outcome::Executed { .. }))); + // resolved_at was set at InFlight time and must NOT be re-stamped on Done. + assert_eq!(d.resolved_at, Some(2_000)); + } + + #[test] + fn done_directly_from_pending_stamps_resolved_at() { + // Deny path skips InFlight; we still need a resolved_at for ordering. + let p = pending_record(); + let d = p.done_at(3_000, Outcome::Denied { denial: Denial::UserRejected }); + assert_eq!(d.status, Status::Done); + assert_eq!(d.resolved_at, Some(3_000)); + } + + #[test] + fn outcome_round_trip_via_json() { + for o in [ + Outcome::Executed { result: json!({"x": 1}) }, + Outcome::Failed { error: "boom".into() }, + Outcome::Denied { denial: Denial::UserRejected }, + Outcome::TimedOut, + ] { + let v = serde_json::to_value(&o).unwrap(); + let back: Outcome = serde_json::from_value(v).unwrap(); + // Exhaustive equality is verbose; just round-trip the discriminant. + assert_eq!(std::mem::discriminant(&o), std::mem::discriminant(&back)); + } + } + + #[test] + fn record_round_trip_pending() { + let r = pending_record(); + let v = r.to_value(); + let back = Record::from_value(v).expect("deserialize"); + assert_eq!(back.status, Status::Pending); + assert_eq!(back.function_call_id, "tc-1"); + } + + #[test] + fn record_round_trip_done_carries_outcome_and_resolved_at() { + let r = pending_record() + .in_flight(2_000) + .done(Outcome::Executed { result: json!({"out": "hi"}) }); + let v = r.to_value(); + let back = Record::from_value(v).expect("deserialize"); + assert_eq!(back.status, Status::Done); + assert_eq!(back.resolved_at, Some(2_000)); + assert!(matches!(back.outcome, Some(Outcome::Executed { .. }))); + } + + #[test] + fn flip_returns_none_when_not_expired() { + let r = pending_record(); + assert!(r.flipped_to_timed_out_if_expired(60_000).is_none()); + } + + #[test] + fn flip_returns_done_timed_out_for_expired_pending() { + let r = pending_record(); + let flipped = r.flipped_to_timed_out_if_expired(70_000) + .expect("expired pending should flip"); + assert_eq!(flipped.status, Status::Done); + assert!(matches!(flipped.outcome, Some(Outcome::TimedOut))); + assert_eq!(flipped.resolved_at, Some(70_000)); + } + + #[test] + fn flip_does_not_touch_in_flight_rows() { + let r = pending_record().in_flight(2_000); + assert!(r.flipped_to_timed_out_if_expired(70_000).is_none(), + "InFlight rows are owned by an in-progress invoke; lazy flip must not steal them"); + } + + #[test] + fn flip_does_not_touch_already_done_rows() { + let r = pending_record() + .in_flight(2_000) + .done(Outcome::Executed { result: json!({}) }); + assert!(r.flipped_to_timed_out_if_expired(70_000).is_none()); + } +} diff --git a/approval-gate/src/register.rs b/approval-gate/src/register.rs new file mode 100644 index 00000000..a1487195 --- /dev/null +++ b/approval-gate/src/register.rs @@ -0,0 +1,381 @@ +//! iii function/trigger wiring. +//! +//! [`register`] is the entry point the binary calls at startup. It +//! constructs the shared `StateBus` + `FunctionExecutor`, hooks every +//! `approval::*` function id, registers the `policy::approval_gate` +//! subscriber on the configured topic, spawns the timeout sweeper, and +//! returns a [`Refs`] handle whose contents keep all the function +//! registrations and the sweeper task alive for the worker's lifetime. +//! +//! The subscriber closure is the only piece of non-trivial logic in +//! this module — it composes the three decision layers documented in +//! [`crate::intercept`] and writes the resulting hook reply onto the +//! envelope's reply stream. + +use std::sync::{Arc, RwLock}; + +use iii_sdk::{ + FunctionRef, IIIError, RegisterFunctionMessage, RegisterTriggerInput, TriggerRequest, III, +}; +use serde_json::{json, Value}; + +use crate::config::{InterceptorRule, WorkerConfig}; +use crate::delivery::{handle_consume, handle_list_pending, handle_sweep_session}; +use crate::intercept::handle_intercept; +use crate::resolve::{handle_lookup_record, handle_resolve}; +use crate::rules; +use crate::state::{FunctionExecutor, IiiFunctionExecutor, IiiStateBus, StateBus}; +use crate::wire::{extract_call, pending_key}; + +/// The iii function ids registered by [`register`]. +pub const FN_RESOLVE: &str = "approval::resolve"; +pub const FN_LIST_PENDING: &str = "approval::list_pending"; +pub const FN_CONSUME: &str = "approval::consume"; +pub const FN_SWEEP_SESSION: &str = "approval::sweep_session"; +pub const FN_LOOKUP_RECORD: &str = "approval::lookup_record"; + +/// Default `approval_state_scope` (matches [`WorkerConfig::default`]). +pub const STATE_SCOPE: &str = "approvals"; + +/// Handles returned from [`register`]; holding them keeps every iii +/// function registration alive for the worker's lifetime. The 2-second +/// background sweeper task is gone — timeouts now flip lazily on read. +pub struct Refs { + pub resolve: FunctionRef, + pub list_pending: FunctionRef, + pub consume: FunctionRef, + pub sweep_session: FunctionRef, + pub lookup_record: FunctionRef, + pub subscriber_fn: FunctionRef, + pub subscriber_trigger: iii_sdk::Trigger, +} + +pub fn register(iii: &III, cfg: &WorkerConfig) -> anyhow::Result { + // Layered policy ruleset, wrapped in RwLock so cascade-on-`always:true` + // can push a runtime Allow rule (see resolve.rs::cascade_allow_for_session). + let policy_rules: Arc> = Arc::new(RwLock::new(cfg.rules.clone())); + + // No-op alias-warning loop kept as a no-op for backward source + // compatibility (no interceptors are configured anymore). Empty vec + // so the loop body never runs. + let rules: Arc> = Arc::new(Vec::new()); + for rule in rules.iter() { + if let Some(cid) = rule.classifier.as_deref() { + if cid == FN_LOOKUP_RECORD + || cid == FN_RESOLVE + || cid == FN_LIST_PENDING + || cid == FN_CONSUME + || cid == FN_SWEEP_SESSION + { + tracing::warn!( + "approval-gate: interceptor for {:?} uses classifier {:?} which aliases an approval endpoint; fix config", + rule.function_id, + cid + ); + } + } + } + + let bus: Arc = Arc::new(IiiStateBus(iii.clone())); + let timeout_ms = cfg.default_timeout_ms; + let topic = cfg.topic.clone(); + let state_scope = cfg.approval_state_scope.clone(); + + let bus_for_resolve = bus.clone(); + let scope_resolve = state_scope.clone(); + let exec_for_resolve: Arc = + Arc::new(IiiFunctionExecutor { iii: iii.clone() }); + let iii_for_resolve = iii.clone(); + let policy_rules_for_resolve = policy_rules.clone(); + let resolve = iii.register_function(( + RegisterFunctionMessage::with_id(FN_RESOLVE.into()).with_description( + "Resolve a pending approval. On allow, invokes the underlying function; \ + on deny, records the denial. With `always: true` on an allow reply, \ + a runtime rule is added so future calls to this function id auto-allow, \ + and the session's other pending calls newly matching are cascade-resolved. \ + The result is stitched into the agent's next turn as a system message." + .into(), + ), + move |payload: Value| { + let bus = bus_for_resolve.clone(); + let exec = exec_for_resolve.clone(); + let scope_resolve = scope_resolve.clone(); + let iii = iii_for_resolve.clone(); + let policy_rules = policy_rules_for_resolve.clone(); + async move { + let now_ms = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis() as u64) + .unwrap_or(0); + let resp = handle_resolve( + bus.as_ref(), + exec.as_ref(), + &scope_resolve, + &policy_rules, + payload.clone(), + now_ms, + ) + .await; + + if resp.get("ok").and_then(Value::as_bool) == Some(true) { + let session_id = payload + .get("session_id") + .and_then(Value::as_str) + .unwrap_or(""); + let call_id = payload + .get("function_call_id") + .or_else(|| payload.get("tool_call_id")) + .and_then(Value::as_str) + .unwrap_or(""); + if !session_id.is_empty() && !call_id.is_empty() { + let key = pending_key(session_id, call_id); + if let Some(final_rec) = bus.get(&scope_resolve, &key).await { + let mut evt = json!({ + "type": "approval_resolved", + "function_call_id": call_id, + "tool_call_id": call_id, + }); + if let Some(status) = final_rec.get("status").and_then(Value::as_str) { + evt["decision"] = match status { + "executed" | "approved" => json!("allow"), + _ => json!("deny"), + }; + evt["status"] = json!(status); + } + if let Some(r) = final_rec.get("result") { + evt["result"] = json!(r); + } + if let Some(e) = final_rec.get("error") { + evt["error"] = json!(e); + } + if let Some(denial) = final_rec.get("denial") { + evt["denial"] = denial.clone(); + } + write_event(&iii, session_id, &evt).await; + } + } + } + Ok::<_, IIIError>(resp) + } + }, + )); + + let bus_for_list = bus.clone(); + let scope_list = state_scope.clone(); + let list_pending = iii.register_function(( + RegisterFunctionMessage::with_id(FN_LIST_PENDING.into()) + .with_description("Return pending approvals for a session.".into()), + move |payload: Value| { + let bus = bus_for_list.clone(); + let scope_list = scope_list.clone(); + async move { + Ok::<_, IIIError>(handle_list_pending(bus.as_ref(), &scope_list, payload).await) + } + }, + )); + + let bus_for_consume = bus.clone(); + let scope_consume = state_scope.clone(); + let consume = iii.register_function(( + RegisterFunctionMessage::with_id(FN_CONSUME.into()).with_description( + "Atomic drain: returns Done rows for a session and deletes them in the \ + same call. Pending and InFlight rows stay in state. Pending rows past \ + expires_at are lazy-flipped to Done(TimedOut) before return. \ + Required payload: {session_id, limit?}. Response: {ok, entries, omitted}." + .into(), + ), + move |payload: Value| { + let bus = bus_for_consume.clone(); + let scope = scope_consume.clone(); + async move { + let now_ms = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis() as u64) + .unwrap_or(0); + Ok::<_, IIIError>( + handle_consume(bus.as_ref(), &scope, payload, now_ms).await, + ) + } + }, + )); + + let bus_for_sweep = bus.clone(); + let scope_sweep = state_scope.clone(); + let sweep_session = iii.register_function(( + RegisterFunctionMessage::with_id(FN_SWEEP_SESSION.into()).with_description( + "Sweep all pending approvals for a session to timed_out. \ + Called when a session is deleted." + .into(), + ), + move |payload: Value| { + let bus = bus_for_sweep.clone(); + let scope = scope_sweep.clone(); + async move { + Ok::<_, IIIError>(handle_sweep_session(bus.as_ref(), &scope, payload).await) + } + }, + )); + + let bus_for_lookup = bus.clone(); + let scope_lookup = state_scope.clone(); + let lookup_record = iii.register_function(( + RegisterFunctionMessage::with_id(FN_LOOKUP_RECORD.into()).with_description( + "Return the approval state-store record for a session/function_call_id pair; \ + null when absent. Used by shell bypass validation." + .into(), + ), + move |payload: Value| { + let bus = bus_for_lookup.clone(); + let scope = scope_lookup.clone(); + async move { + Ok::<_, IIIError>(handle_lookup_record(bus.as_ref(), &scope, payload).await) + } + }, + )); + + let iii_for_sub = iii.clone(); + let bus_for_sub = bus.clone(); + let subscriber_scope = state_scope.clone(); + let rules_for_sub = rules.clone(); + let policy_rules_for_sub = policy_rules.clone(); + let subscriber_fn = iii.register_function(( + RegisterFunctionMessage::with_id("policy::approval_gate".into()) + .with_description("Pause function calls listed in approval_required.".into()), + move |envelope: Value| { + let iii = iii_for_sub.clone(); + let bus = bus_for_sub.clone(); + let sc = subscriber_scope.clone(); + let intercept_rules = rules_for_sub.clone(); + let policy_rules = policy_rules_for_sub.clone(); + async move { + let Some(call) = extract_call(&envelope) else { + return Ok::<_, IIIError>(json!({ "block": false })); + }; + let now_ms = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis() as u64) + .unwrap_or(0); + + // Take a snapshot of the rules ruleset under the read lock, + // then drop the guard before any .await. std::sync::RwLock + // is not async-safe to hold across suspension points, and + // a held guard would block every concurrent intercept. + let rules_snapshot: rules::Ruleset = { + let guard = policy_rules + .read() + .expect("approval-gate policy rules lock poisoned"); + guard.clone() + }; + + // One decision call. Verdict::Allow → {block:false}. + // Verdict::Deny → {block:true, denial:Policy{...}}. + // Verdict::Ask → write Pending + reply {block:true, status:pending}. + let reply = handle_intercept( + bus.as_ref(), + &sc, + &call, + &rules_snapshot, + now_ms, + timeout_ms, + ).await; + + if reply.get("status").and_then(Value::as_str) == Some("pending") { + write_event( + &iii, + &call.session_id, + &json!({ + "type": "approval_requested", + "function_call_id": call.function_call_id, + "tool_call_id": call.function_call_id, + "function_id": call.function_id, + "tool_name": call.function_id, + "args": call.args, + "expires_at": now_ms.saturating_add(timeout_ms), + }), + ) + .await; + } + write_hook_reply(&iii, &call.reply_stream, &call.event_id, &reply).await; + Ok(reply) + } + }, + )); + + let subscriber_trigger = iii + .register_trigger(RegisterTriggerInput { + trigger_type: "durable:subscriber".into(), + function_id: "policy::approval_gate".into(), + config: json!({ "topic": topic }), + metadata: None, + }) + .map_err(|e| anyhow::anyhow!(e.to_string()))?; + + Ok(Refs { + resolve, + list_pending, + consume, + sweep_session, + lookup_record, + subscriber_fn, + subscriber_trigger, + }) +} + +// ───────────────────────────────────────────────────────────────────────── +// Inline stream helpers (used by the subscriber to write the +// `approval_requested` stream frame and the hook reply). These used to +// live in `sweeper.rs` but that file is gone now that the background +// polling task is deleted; the helpers move here as their only consumer. +// ───────────────────────────────────────────────────────────────────────── + +pub(crate) fn uuid_like() -> String { + use std::sync::atomic::{AtomicU64, Ordering}; + static C: AtomicU64 = AtomicU64::new(0); + let n = C.fetch_add(1, Ordering::Relaxed); + let t = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_nanos()) + .unwrap_or(0); + format!("{t:x}-{n:x}") +} + +/// Append `event` to the `agent::events` stream for `session_id`. Fire- +/// and-forget: errors are swallowed because the persisted record is the +/// source of truth — orchestrators re-derive state from +/// `approval::consume` if a frame is lost. +pub(crate) async fn write_event(iii: &III, session_id: &str, event: &Value) { + let _ = iii + .trigger(TriggerRequest { + function_id: "stream::set".into(), + payload: json!({ + "stream_name": "agent::events", + "group_id": session_id, + "item_id": format!("approval-{}", uuid_like()), + "data": event, + }), + action: None, + timeout_ms: None, + }) + .await; +} + +/// Append a hook reply onto `stream_name` keyed by `event_id`. No-op when +/// either id is empty so a malformed envelope can't crash the gate. +pub(crate) async fn write_hook_reply(iii: &III, stream_name: &str, event_id: &str, reply: &Value) { + if stream_name.is_empty() || event_id.is_empty() { + return; + } + let _ = iii + .trigger(TriggerRequest { + function_id: "stream::set".into(), + payload: json!({ + "stream_name": stream_name, + "group_id": event_id, + "item_id": uuid_like(), + "data": reply, + }), + action: None, + timeout_ms: None, + }) + .await; +} diff --git a/approval-gate/src/resolve.rs b/approval-gate/src/resolve.rs new file mode 100644 index 00000000..56d1d90a --- /dev/null +++ b/approval-gate/src/resolve.rs @@ -0,0 +1,267 @@ +//! Resolve flow — handles `approval::resolve` and the cascading-allow +//! behavior that fires when a reply carries `always: true`. +//! +//! ## Three-phase allow path +//! +//! [`handle_resolve`] is the entry point. On allow it routes through +//! [`approve_and_execute`]: +//! 1. write `InFlight` (closes the dup-exec race — a second resolve +//! arriving during the invoke await sees a non-Pending row and bails); +//! 2. `iii.trigger(function_id, args)` and await; +//! 3. write `Done(Executed{result})` or `Done(Failed{error})`. +//! +//! Deny is a single Pending → Done(Denied) write — no invoke, no InFlight. +//! +//! ## Cascade +//! +//! On `allow + always:true`, [`cascade_allow_for_session`] pushes a runtime +//! `Allow` rule with the originator's **exact pattern** (via +//! [`crate::rules::pattern_for`]) — not a blanket `pattern: "*"`. "Always +//! allow git status" does NOT auto-allow `rm -rf /` via the same +//! `shell::exec` function id. Same-session pending rows whose +//! `verdict_for` returns `Allow` under the new rule are driven through +//! `approve_and_execute`. + +use std::sync::RwLock; + +use serde_json::{json, Value}; + +use crate::record::{Record, Status, Outcome}; +use crate::rules::{self, Action, Rule, Ruleset}; +use crate::state::{FunctionExecutor, StateBus}; +use crate::wire::{pending_key, Denial, WireDecision}; + +/// Lookup a single approval record by session + call id (for shell bypass +/// validation). Stays on the old free-form Value shape so shell-side +/// readers don't break — shell strip in T13 deletes the callsite there. +pub async fn handle_lookup_record(bus: &dyn StateBus, state_scope: &str, payload: Value) -> Value { + let session_id = payload.get("session_id").and_then(Value::as_str).unwrap_or(""); + let function_call_id = payload.get("function_call_id").and_then(Value::as_str).unwrap_or(""); + if session_id.is_empty() || function_call_id.is_empty() { + return Value::Null; + } + let key = pending_key(session_id, function_call_id); + bus.get(state_scope, &key).await.unwrap_or(Value::Null) +} + +/// Resolve a pending approval. Wire-format errors return +/// `{ok:false, error:""}`. Success returns `{ok:true}` plus an +/// optional `cascaded: N` count when an `always:true` reply triggered the +/// session sweep. +pub async fn handle_resolve( + bus: &dyn StateBus, + exec: &dyn FunctionExecutor, + state_scope: &str, + policy_rules: &RwLock, + payload: Value, + now_ms: u64, +) -> Value { + let session_id = payload.get("session_id").and_then(Value::as_str).unwrap_or(""); + let function_call_id = payload + .get("function_call_id") + .or_else(|| payload.get("tool_call_id")) + .and_then(Value::as_str) + .unwrap_or(""); + if session_id.is_empty() || function_call_id.is_empty() { + return json!({ "ok": false, "error": "missing_id" }); + } + + let decision: WireDecision = match payload.get("decision").cloned() { + Some(v) => match serde_json::from_value(v) { + Ok(d) => d, + Err(_) => return json!({ "ok": false, "error": "bad_decision" }), + }, + None => return json!({ "ok": false, "error": "bad_decision" }), + }; + + let key = pending_key(session_id, function_call_id); + let Some(raw) = bus.get(state_scope, &key).await else { + return json!({ "ok": false, "error": "not_found" }); + }; + let Some(record) = Record::from_value(raw) else { + return json!({ "ok": false, "error": "corrupt_record" }); + }; + + // Lazy timeout flip — Pending rows past expires_at flip to + // Done(TimedOut) on read. + if let Some(flipped) = record.flipped_to_timed_out_if_expired(now_ms) { + let _ = bus.set(state_scope, &key, flipped.to_value()).await; + return json!({ "ok": false, "error": "timed_out" }); + } + + // Dup-exec guard: only Pending rows are resolvable. InFlight means a + // concurrent resolve is still mid-invoke; Done means terminal. + match record.status { + Status::Pending => { /* fall through */ } + Status::InFlight => return json!({ "ok": false, "error": "in_flight" }), + Status::Done => return json!({ "ok": false, "error": "already_resolved" }), + } + + match decision { + WireDecision::Deny => { + // Optional structured denial from caller; missing → UserRejected. + let denial = match payload.get("denial").cloned() { + Some(v) => match serde_json::from_value::(v) { + Ok(d) => d, + Err(_) => return json!({ "ok": false, "error": "bad_denial" }), + }, + None => Denial::UserRejected, + }; + let denied = record.done_at(now_ms, Outcome::Denied { denial }); + if let Err(e) = bus.set(state_scope, &key, denied.to_value()).await { + tracing::error!("approval-gate: failed to write denied record: {e}"); + return json!({ "ok": false, "error": "state_write_failed" }); + } + json!({ "ok": true }) + } + WireDecision::Allow => { + // Snapshot args + function_id before consuming `record` in + // approve_and_execute — cascade needs them for the rule push. + let function_id = record.function_id.clone(); + let args = record.args.clone(); + + if let Err(err) = approve_and_execute( + bus, exec, state_scope, record, session_id, function_call_id, now_ms, + ).await { + tracing::error!("approval-gate: failed to execute approved call: {err}"); + return json!({ "ok": false, "error": "state_write_failed" }); + } + + // Cascade on `always:true`. Push a runtime Allow rule with the + // ORIGINATOR'S EXACT PATTERN (via pattern_for), then sweep the + // session's other Pending rows. + let cascaded = if payload.get("always").and_then(Value::as_bool).unwrap_or(false) { + cascade_allow_for_session( + bus, exec, state_scope, policy_rules, + session_id, function_call_id, + &function_id, &args, + now_ms, + ).await + } else { + 0 + }; + + if cascaded > 0 { + json!({ "ok": true, "cascaded": cascaded }) + } else { + json!({ "ok": true }) + } + } + } +} + +/// Push an exact-pattern Allow rule into the shared ruleset, then sweep +/// the session's other Pending rows. Returns the number of rows +/// auto-resolved (originator excluded). +/// +/// **Lock-ordering invariant**: the write/read guards on `policy_rules` +/// are released before any `.await`. `std::sync::RwLock` is not async-safe +/// to hold across suspension; a held guard would block every concurrent +/// intercept. +async fn cascade_allow_for_session( + bus: &dyn StateBus, + exec: &dyn FunctionExecutor, + state_scope: &str, + policy_rules: &RwLock, + session_id: &str, + originator_call_id: &str, + originator_function_id: &str, + originator_args: &Value, + now_ms: u64, +) -> u64 { + // 1. Push the exact-pattern Allow rule under the write lock. + // pattern_for is the same extractor used at intercept time, so + // "always allow git status" means literally that argv shape — NOT + // a blanket "*" pattern that would auto-allow rm -rf /. + let pushed_pattern = rules::pattern_for(originator_function_id, originator_args); + { + let mut guard = policy_rules + .write() + .expect("approval-gate policy rules lock poisoned"); + guard.push(Rule { + permission: originator_function_id.to_string(), + pattern: pushed_pattern, + action: Action::Allow, + }); + } + + // 2. Snapshot the session's pending rows. + let prefix = format!("{session_id}/"); + let session_rows = bus.list_prefix(state_scope, &prefix).await; + + let mut cascaded = 0u64; + for raw in session_rows { + let Some(record) = Record::from_value(raw) else { continue }; + if record.session_id != session_id { continue; } // defensive + if record.function_call_id == originator_call_id { continue; } // skip originator + if record.status != Status::Pending { continue; } // skip non-pending + + // 3. Re-evaluate against the updated ruleset. + let verdict = { + let guard = policy_rules + .read() + .expect("approval-gate policy rules lock poisoned"); + crate::verdict_for(&record.function_id, &record.args, &guard) + }; + if !matches!(verdict, crate::Verdict::Allow) { + continue; + } + + // 4. Drive through the same approve_and_execute path as the + // user-driven allow (InFlight → invoke → Done). + let cid = record.function_call_id.clone(); + if let Err(err) = approve_and_execute( + bus, exec, state_scope, record, session_id, &cid, now_ms, + ).await { + tracing::warn!( + session_id, call_id = %cid, + "approval-gate: cascade auto-resolve failed: {err}", + ); + continue; + } + cascaded += 1; + } + cascaded +} + +/// Drive a Pending row through InFlight → invoke → Done. Used by both +/// the user-driven allow path and the cascade sweep so the lifecycle +/// transitions stay in one place. +/// +/// Phase 1 (InFlight) is the dup-exec guard: a concurrent resolve seeing +/// a non-Pending row in `handle_resolve` returns `in_flight` and skips +/// the second invoke. +pub(crate) async fn approve_and_execute( + bus: &dyn StateBus, + exec: &dyn FunctionExecutor, + state_scope: &str, + pending: Record, + session_id: &str, + function_call_id: &str, + now_ms: u64, +) -> Result<(), String> { + let key = pending_key(session_id, function_call_id); + let function_id = pending.function_id.clone(); + let args = pending.args.clone(); + + // Phase 1: InFlight write. Closes the dup-exec race. + let in_flight = pending.in_flight(now_ms); + bus.set(state_scope, &key, in_flight.to_value()) + .await + .map_err(|e| e.to_string())?; + + // Phase 2: invoke. Result/error captured on the record below. + let outcome = match exec + .invoke(&function_id, args, function_call_id, session_id) + .await + { + Ok(result) => Outcome::Executed { result }, + Err(error) => Outcome::Failed { error }, + }; + + // Phase 3: Done write. resolved_at preserved from the InFlight write. + let done = in_flight.done(outcome); + bus.set(state_scope, &key, done.to_value()) + .await + .map_err(|e| e.to_string()) +} diff --git a/approval-gate/src/rules.rs b/approval-gate/src/rules.rs new file mode 100644 index 00000000..0cdd7d7c --- /dev/null +++ b/approval-gate/src/rules.rs @@ -0,0 +1,371 @@ +//! Layered permission rules — first-class policy primitive ported from +//! opencode's `Permission.evaluate` / `Wildcard.match`. +//! +//! ## Shape +//! +//! A [`Rule`] pairs a permission glob (matched against the iii function id) +//! with a pattern glob (matched against a caller-supplied pattern string, +//! always `"*"` in v1 — see [`evaluate`] for the forward-compatible call +//! shape). An [`Action`] tells the gate what to do on match: +//! [`Action::Allow`] passes the call through, [`Action::Deny`] short-circuits +//! with a policy [`crate::Denial`], [`Action::Ask`] falls back to the existing +//! per-function [`crate::config::InterceptorRule`] flow. +//! +//! ## Layering +//! +//! Operators stack rules — a workspace-default ruleset, plus a per-session +//! override, plus an operator-pinned global. [`evaluate`] flattens N +//! rulesets in caller order and returns the **last** matching rule. +//! Last-wins is the standard policy-stacking semantic: a more-specific +//! later layer overrides an earlier permissive default without surgery on +//! the earlier list. +//! +//! ## Wildcard match +//! +//! [`wildcard_match`] supports `*` (zero or more of any character) and +//! literal text. No regex, no `?`, no character classes — the surface is +//! intentionally tiny to match opencode's `Wildcard.match` behaviour and +//! keep the rule language operator-readable. `*` is greedy via dynamic +//! programming so `"a*b*c"` matches `"axxxbxxxc"` correctly. + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +/// Decision a [`Rule`] expresses when it matches an incoming call. +/// +/// Wire format is the lowercase string `"allow"` | `"deny"` | `"ask"` so +/// rules are operator-readable in YAML / JSON config. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Action { + Allow, + Deny, + Ask, +} + +/// A single permission rule. +/// +/// `permission` is matched against the iii function id (e.g. `shell::exec`, +/// `shell::fs::*`). `pattern` is matched against a caller-supplied pattern +/// string; in v1 every call site passes `"*"`, so `pattern: "*"` is the +/// only useful value today. The field is kept on the type so the forward +/// path to per-function pattern extractors (shell::exec → joined argv, +/// shell::fs::* → path) is a config-level change, not a schema break. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Rule { + /// Wildcard pattern matched against the iii function id. + pub permission: String, + /// Wildcard pattern matched against a caller-supplied pattern string. + /// In v1 callers pass `"*"`; setting `pattern: "*"` here matches them. + pub pattern: String, + pub action: Action, +} + +/// A list of rules, evaluated in order. Stacked rulesets are flattened by +/// [`evaluate`] in caller order so the **last** matching rule across all +/// layers wins. +pub type Ruleset = Vec; + +/// True if `text` matches the wildcard `pattern`. Supports `*` (zero or +/// more of any character) and literal text. Tiny on purpose — operators +/// should be able to read a rule and know what it matches without a regex +/// engine in their head. +/// +/// Dynamic-programming implementation so `"a*b*c"` matches `"axxxbxxxc"` +/// without exponential backtracking on patterns with many `*`. +pub fn wildcard_match(pattern: &str, text: &str) -> bool { + let p: Vec = pattern.chars().collect(); + let t: Vec = text.chars().collect(); + let (np, nt) = (p.len(), t.len()); + // dp[i][j] = true iff p[..i] matches t[..j]. + let mut dp = vec![vec![false; nt + 1]; np + 1]; + dp[0][0] = true; + // A leading run of '*' can match the empty string. + for i in 1..=np { + if p[i - 1] == '*' { + dp[i][0] = dp[i - 1][0]; + } + } + for i in 1..=np { + for j in 1..=nt { + dp[i][j] = if p[i - 1] == '*' { + // '*' matches empty (dp[i-1][j]) or extends by one char (dp[i][j-1]). + dp[i - 1][j] || dp[i][j - 1] + } else { + p[i - 1] == t[j - 1] && dp[i - 1][j - 1] + }; + } + } + dp[np][nt] +} + +/// Find the **last** rule in `rules` whose `permission` and `pattern` +/// both wildcard-match the given inputs. Takes any iterator of rule +/// references so callers can pass a single [`Ruleset`] directly +/// (`&Vec` is `IntoIterator`) or chain several layers +/// via `global.iter().chain(session.iter())` without temporary borrows. +/// Returns the matched rule by reference so the caller can read its +/// [`Action`] and report the matching pattern in audit / Denial detail. +/// +/// `None` means no rule matched — the caller should fall back to whatever +/// it would do without a rules layer (in approval-gate: the existing +/// per-function [`crate::config::InterceptorRule`] path). +pub fn evaluate<'a, I>(permission: &str, pattern: &str, rules: I) -> Option<&'a Rule> +where + I: IntoIterator, +{ + rules + .into_iter() + .filter(|r| wildcard_match(&r.permission, permission) && wildcard_match(&r.pattern, pattern)) + .last() +} + +/// Per-function pattern extractor. The pattern is the second axis a rule +/// matches on (alongside `function_id`); for `shell::exec` we derive it +/// from `{command, args}` so operators can write rules like +/// `permission: "shell::exec", pattern: "git status*"` and get +/// argv-level granularity. Other function ids default to `"*"`, which +/// matches only wildcard rules. +pub fn pattern_for(function_id: &str, args: &Value) -> String { + match function_id { + "shell::exec" | "shell::exec_bg" => extract_shell_pattern(args), + _ => "*".to_string(), + } +} + +/// Shell ExecRequest is `{ command: String, args: Option> }` +/// per `shell/src/functions/types.rs`. There is no `argv` field. Two +/// modes: +/// - `args = None` → `command` is a shell-words string, use as-is. +/// - `args = Some(list)` → join `command + " " + list.join(" ")`. +/// Malformed input (missing/non-string command) falls back to `"*"` so +/// the row matches only wildcard rules. +/// +/// Known conflation: argv `[git, log, "--grep=foo bar"]` joins to +/// `"git log --grep=foo bar"`, same pattern string as +/// `[git, log, "--grep=foo", bar]`. Documented; acceptable for v1. +fn extract_shell_pattern(args: &Value) -> String { + let cmd = args.get("command").and_then(Value::as_str); + let argv = args.get("args").and_then(Value::as_array); + match (cmd, argv) { + (Some(c), Some(arr)) if !arr.is_empty() => { + let mut parts = vec![c.to_string()]; + parts.extend(arr.iter().filter_map(Value::as_str).map(str::to_string)); + parts.join(" ") + } + (Some(c), _) => c.to_string(), + _ => "*".to_string(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn r(permission: &str, pattern: &str, action: Action) -> Rule { + Rule { + permission: permission.to_string(), + pattern: pattern.to_string(), + action, + } + } + + #[test] + fn wildcard_literal_match() { + assert!(wildcard_match("shell::exec", "shell::exec")); + assert!(!wildcard_match("shell::exec", "shell::fs::read")); + } + + #[test] + fn wildcard_star_matches_empty() { + assert!(wildcard_match("*", "")); + assert!(wildcard_match("*", "anything")); + } + + #[test] + fn wildcard_star_matches_prefix() { + assert!(wildcard_match("shell::*", "shell::exec")); + assert!(wildcard_match("shell::*", "shell::fs::write")); + assert!(!wildcard_match("shell::*", "approval::resolve")); + } + + #[test] + fn wildcard_star_matches_suffix_and_middle() { + assert!(wildcard_match("*::exec", "shell::exec")); + assert!(wildcard_match("shell::*::write", "shell::fs::write")); + assert!(!wildcard_match("shell::*::write", "shell::fs::read")); + } + + #[test] + fn wildcard_multiple_stars_no_backtracking_blowup() { + // The dp implementation must not blow up on many '*'. + let pat = "*a*a*a*a*a*a*a*a*a*a*a*a*a*b"; + let text: String = "a".repeat(50); + assert!(!wildcard_match(pat, &text)); + let text_ok: String = format!("{}b", "a".repeat(50)); + assert!(wildcard_match(pat, &text_ok)); + } + + #[test] + fn evaluate_returns_none_for_empty_ruleset() { + let empty: Ruleset = vec![]; + assert!(evaluate("shell::exec", "*", &empty).is_none()); + } + + #[test] + fn evaluate_returns_none_when_nothing_matches() { + let rs: Ruleset = vec![r("approval::*", "*", Action::Allow)]; + assert!(evaluate("shell::exec", "*", &rs).is_none()); + } + + #[test] + fn evaluate_matches_exact_permission() { + let rs: Ruleset = vec![r("shell::exec", "*", Action::Allow)]; + let m = evaluate("shell::exec", "*", &rs).expect("match"); + assert_eq!(m.action, Action::Allow); + } + + #[test] + fn evaluate_matches_wildcard_permission() { + let rs: Ruleset = vec![r("shell::*", "*", Action::Allow)]; + let m = evaluate("shell::fs::write", "*", &rs).expect("match"); + assert_eq!(m.action, Action::Allow); + } + + #[test] + fn evaluate_last_wins_within_single_ruleset() { + // Two matching rules in the same ruleset; the later one wins. + let rs: Ruleset = vec![ + r("shell::*", "*", Action::Allow), + r("shell::exec", "*", Action::Deny), + ]; + let m = evaluate("shell::exec", "*", &rs).expect("match"); + assert_eq!( + m.action, + Action::Deny, + "more-specific later rule must override earlier permissive default" + ); + } + + #[test] + fn evaluate_last_wins_across_layered_rulesets() { + // global allows everything; session denies shell::exec. Session + // (passed last) overrides global. + let global: Ruleset = vec![r("*", "*", Action::Allow)]; + let session: Ruleset = vec![r("shell::exec", "*", Action::Deny)]; + let m = evaluate( + "shell::exec", + "*", + global.iter().chain(session.iter()), + ) + .expect("match"); + assert_eq!(m.action, Action::Deny); + + // For a permission only matched by global, global still wins. + let m2 = evaluate( + "approval::resolve", + "*", + global.iter().chain(session.iter()), + ) + .expect("match"); + assert_eq!(m2.action, Action::Allow); + } + + #[test] + fn evaluate_ask_is_a_valid_action() { + let rs: Ruleset = vec![r("shell::exec", "*", Action::Ask)]; + let m = evaluate("shell::exec", "*", &rs).expect("match"); + assert_eq!(m.action, Action::Ask); + } + + #[test] + fn evaluate_pattern_matches_when_both_globs_pass() { + let rs: Ruleset = vec![r("shell::exec", "git*", Action::Allow)]; + // pattern matches + let m = evaluate("shell::exec", "git checkout main", &rs).expect("match"); + assert_eq!(m.action, Action::Allow); + // pattern doesn't match → no rule selected + assert!(evaluate("shell::exec", "rm -rf /", &rs).is_none()); + } + + #[test] + fn rule_serde_round_trip() { + let original = r("shell::exec", "*", Action::Deny); + let json = serde_json::to_value(&original).unwrap(); + assert_eq!(json["permission"], "shell::exec"); + assert_eq!(json["pattern"], "*"); + assert_eq!(json["action"], "deny"); + let back: Rule = serde_json::from_value(json).unwrap(); + assert_eq!(back, original); + } + + #[test] + fn action_yaml_round_trip() { + for a in [Action::Allow, Action::Deny, Action::Ask] { + let y = serde_yaml::to_string(&a).unwrap(); + let back: Action = serde_yaml::from_str(&y).unwrap(); + assert_eq!(back, a); + } + } + + // -------------------- pattern_for / extract_shell_pattern -------------------- + + use serde_json::json; + + #[test] + fn pattern_for_shell_exec_joins_command_with_args() { + let pat = pattern_for("shell::exec", &json!({"command": "git", "args": ["status"]})); + assert_eq!(pat, "git status"); + } + + #[test] + fn pattern_for_shell_exec_bg_joins_command_with_args() { + let pat = pattern_for("shell::exec_bg", + &json!({"command": "tail", "args": ["-f", "/var/log/x"]})); + assert_eq!(pat, "tail -f /var/log/x"); + } + + #[test] + fn pattern_for_shell_exec_single_string_command_no_args() { + // shell::exec supports the "command is a shell-words string" mode + // (args: None). The pattern is just the command string. + let pat = pattern_for("shell::exec", &json!({"command": "git status"})); + assert_eq!(pat, "git status"); + } + + #[test] + fn pattern_for_shell_exec_empty_args_list_treated_as_no_args() { + let pat = pattern_for("shell::exec", &json!({"command": "ls", "args": []})); + assert_eq!(pat, "ls"); + } + + #[test] + fn pattern_for_shell_exec_missing_command_falls_back_to_star() { + let pat = pattern_for("shell::exec", &json!({"args": ["foo"]})); + assert_eq!(pat, "*"); + } + + #[test] + fn pattern_for_shell_exec_completely_malformed_args_falls_back_to_star() { + let pat = pattern_for("shell::exec", &json!(null)); + assert_eq!(pat, "*"); + } + + #[test] + fn pattern_for_non_shell_function_id_returns_star() { + let pat = pattern_for("http::fetch", &json!({"url": "https://x"})); + assert_eq!(pat, "*"); + } + + #[test] + fn pattern_for_known_conflation_documented() { + // Documented in spec: an arg containing a space conflates with two + // separate args. This is acceptable for v1. + let with_inner_space = pattern_for("shell::exec", + &json!({"command": "git", "args": ["log", "--grep=foo bar"]})); + let split_args = pattern_for("shell::exec", + &json!({"command": "git", "args": ["log", "--grep=foo", "bar"]})); + assert_eq!(with_inner_space, split_args, + "v1 conflates space-in-arg with arg boundary; see spec"); + } +} diff --git a/approval-gate/src/state.rs b/approval-gate/src/state.rs new file mode 100644 index 00000000..9eaffadb --- /dev/null +++ b/approval-gate/src/state.rs @@ -0,0 +1,134 @@ +//! State-store and function-executor traits, plus their iii-backed +//! implementations. +//! +//! The traits exist as test seams — unit tests swap in +//! `InMemoryStateBus` / `FakeExecutor` while production code uses the +//! `Iii*` implementations that call iii directly. +//! +//! The `__from_approval` marker plumbing is gone (per the refactor's +//! threat-model decision: bus access ≡ shell access in the new model; +//! defense-in-depth via per-target marker verification is out of scope). + +use async_trait::async_trait; +use iii_sdk::{IIIError, TriggerRequest, III}; +use serde_json::{json, Value}; + +/// Abstraction over the iii state bus — the kv layer where pending and +/// resolved approval records live. Exists so unit tests can swap in a +/// `BTreeMap`-backed fake; production uses [`IiiStateBus`]. +#[async_trait] +pub trait StateBus: Send + Sync { + async fn set(&self, scope: &str, key: &str, value: Value) -> Result<(), IIIError>; + async fn get(&self, scope: &str, key: &str) -> Option; + async fn list_prefix(&self, scope: &str, prefix: &str) -> Vec; + /// Remove a key. Required by `approval::consume`, which returns Done + /// rows and deletes them in the same call. Idempotent (deleting a + /// missing key returns Ok). + async fn delete(&self, scope: &str, key: &str) -> Result<(), IIIError>; +} + +/// Invokes an iii function with arguments and returns its result or an +/// error string. Abstracted so tests can stub the underlying call. +#[async_trait] +pub trait FunctionExecutor: Send + Sync { + async fn invoke( + &self, + function_id: &str, + args: Value, + function_call_id: &str, + session_id: &str, + ) -> Result; +} + +/// Production [`FunctionExecutor`] backed by `iii.trigger`. +/// +/// Forwards `function_id` + `args` directly to `iii.trigger`. No +/// `__from_approval` marker injection — the target trusts the bus. +pub struct IiiFunctionExecutor { + pub iii: III, +} + +#[async_trait] +impl FunctionExecutor for IiiFunctionExecutor { + async fn invoke( + &self, + function_id: &str, + args: Value, + _function_call_id: &str, + _session_id: &str, + ) -> Result { + self.iii + .trigger(TriggerRequest { + function_id: function_id.to_string(), + payload: args, + action: None, + timeout_ms: None, + }) + .await + .map_err(|e| e.to_string()) + } +} + +/// Production [`StateBus`] backed by iii's `state::*` builtins. +pub struct IiiStateBus(pub III); + +#[async_trait] +impl StateBus for IiiStateBus { + async fn set(&self, scope: &str, key: &str, value: Value) -> Result<(), IIIError> { + self.0 + .trigger(TriggerRequest { + function_id: "state::set".into(), + payload: json!({ "scope": scope, "key": key, "value": value }), + action: None, + timeout_ms: None, + }) + .await + .map(|_| ()) + } + async fn get(&self, scope: &str, key: &str) -> Option { + self.0 + .trigger(TriggerRequest { + function_id: "state::get".into(), + payload: json!({ "scope": scope, "key": key }), + action: None, + timeout_ms: None, + }) + .await + .ok() + .filter(|v| !v.is_null()) + } + async fn list_prefix(&self, scope: &str, prefix: &str) -> Vec { + let resp = self + .0 + .trigger(TriggerRequest { + function_id: "state::list".into(), + payload: json!({ "scope": scope, "prefix": prefix }), + action: None, + timeout_ms: None, + }) + .await + .unwrap_or_else(|_| json!({ "items": [] })); + // Engine may return either {"items": [...]} or a plain Array. + if let Some(arr) = resp.as_array() { + return arr.clone(); + } + resp.get("items") + .and_then(|v| v.as_array().cloned()) + .unwrap_or_default() + .into_iter() + .map(|entry| entry.get("value").cloned().unwrap_or(entry)) + .collect() + } + async fn delete(&self, scope: &str, key: &str) -> Result<(), IIIError> { + self.0 + .trigger(TriggerRequest { + function_id: "state::delete".into(), + payload: json!({ "scope": scope, "key": key }), + action: None, + timeout_ms: None, + }) + .await + .map(|_| ()) + } +} + diff --git a/approval-gate/src/wire.rs b/approval-gate/src/wire.rs new file mode 100644 index 00000000..2aaa98b0 --- /dev/null +++ b/approval-gate/src/wire.rs @@ -0,0 +1,159 @@ +//! Wire-format types for the approval gate. +//! +//! Pure data shapes and small wire-shape helpers — no I/O, no `iii_sdk` +//! deps, no async. Anything a downstream worker would need to +//! understand the approval-gate protocol lives here: +//! +//! - [`Denial`] — structured deny payload (`kind` + `detail`) carried on +//! hook replies, persisted records, and `approval_resolved` events. +//! - [`Decision`] — internal allow/deny choice; pairs `Deny` with its +//! [`Denial`] so the type system rules out structureless deny. +//! - [`WireDecision`] — coarse `"allow"` / `"deny"` enum used at the +//! `approval::resolve` RPC boundary, where the UI / orchestrator +//! doesn't yet know the full [`Denial`]. +//! - [`IncomingCall`] — parsed `agent::before_function_call` envelope. +//! - [`pending_key`], [`extract_call`], [`block_reply_for`] — pure +//! helpers for going to / from the wire. +//! +//! The handler crate re-exports the public items from [`crate`] so +//! existing call sites don't need to import the module directly. + +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; + +/// Structured deny payload carried on wire replies, persisted records, and +/// `approval_resolved` stream events. Consumers (turn-orchestrator +/// stitching, UIs, the LLM) branch on `kind` instead of parsing prose. +/// +/// Wire shape (serde tag=kind, content=detail, snake_case): +/// `{ "kind": "policy", "detail": { "rule_permission": "...", "rule_pattern": "..." } }` +/// `{ "kind": "user_rejected", "detail": null }` +/// `{ "kind": "user_corrected", "detail": { "feedback": "..." } }` +/// `{ "kind": "state_error", "detail": { "phase": "...", "error": "..." } }` +/// +/// `Policy` names the matching rule from the layered ruleset +/// (`approval-gate/src/rules.rs`). The old `classifier_reason` / +/// `classifier_fn` shape went away when the classifier surface was +/// deleted in favor of pure rules-based decisions. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "kind", content = "detail", rename_all = "snake_case")] +pub enum Denial { + Policy { + rule_permission: String, + rule_pattern: String, + }, + UserRejected, + UserCorrected { + feedback: String, + }, + StateError { + phase: String, + error: String, + }, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct IncomingCall { + pub session_id: String, + pub function_call_id: String, + pub function_id: String, + pub args: Value, + pub approval_required: Vec, + pub event_id: String, + pub reply_stream: String, +} + +impl IncomingCall { + pub fn requires_approval(&self) -> bool { + self.approval_required + .iter() + .any(|n| n == &self.function_id) + } +} + +/// Internal allow/deny choice. Paired with a structured [`Denial`] on +/// the `Deny` arm so callers that emit a wire reply can't accidentally +/// drop the deny reason on the floor. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Decision { + Allow, + Deny(Denial), +} + +/// Wire-format decision string used by `approval::resolve` and stored +/// as the `status` field of resolved approval records. +/// +/// Serializes / deserializes as `"allow"` or `"deny"`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum WireDecision { + Allow, + Deny, +} + +/// Build the state-store key for a pending approval entry. +/// +/// `session_id` and `function_call_id` must not contain `/`. They are caller-controlled +/// IDs minted by turn-orchestrator; today neither format uses the separator. +pub fn pending_key(session_id: &str, function_call_id: &str) -> String { + debug_assert!(!session_id.contains('/'), "session_id must not contain '/'"); + debug_assert!( + !function_call_id.contains('/'), + "function_call_id must not contain '/'" + ); + format!("{session_id}/{function_call_id}") +} + +/// Parse the `agent::before_function_call` envelope into the +/// [`IncomingCall`] the gate's intercept logic operates on. Accepts both +/// the modern `function_call` shape and the legacy `tool_call` alias so +/// older sessions in-flight at upgrade time keep working. +/// +/// Returns `None` if any required field is missing — handlers treat that +/// as "not our concern" and pass the envelope through. +pub fn extract_call(envelope: &Value) -> Option { + let event_id = envelope + .get("event_id") + .and_then(Value::as_str)? + .to_string(); + let reply_stream = envelope + .get("reply_stream") + .and_then(Value::as_str)? + .to_string(); + let inner = envelope.get("payload").unwrap_or(envelope); + let session_id = inner.get("session_id").and_then(Value::as_str)?.to_string(); + let fc = inner + .get("function_call") + .or_else(|| inner.get("tool_call"))?; + let function_id = fc + .get("function_id") + .or_else(|| fc.get("name")) + .and_then(Value::as_str)? + .to_string(); + Some(IncomingCall { + session_id, + function_call_id: fc.get("id").and_then(Value::as_str)?.to_string(), + function_id, + args: fc.get("arguments").cloned().unwrap_or_else(|| json!({})), + approval_required: inner + .get("approval_required") + .and_then(|v| serde_json::from_value(v.clone()).ok()) + .unwrap_or_default(), + event_id, + reply_stream, + }) +} + +/// Build the hook block reply for a [`Decision`]. Deny replies carry the +/// structured [`Denial`] under `denial`; consumers (turn-orchestrator +/// stitching, UIs, the LLM) branch on `denial.kind` rather than parsing a +/// free-form `reason` string. +pub fn block_reply_for(decision: &Decision) -> Value { + match decision { + Decision::Allow => json!({ "block": false }), + Decision::Deny(denial) => json!({ + "block": true, + "denial": denial, + }), + } +} diff --git a/approval-gate/tests/approval_lifecycle.rs b/approval-gate/tests/approval_lifecycle.rs deleted file mode 100644 index d6d6461d..00000000 --- a/approval-gate/tests/approval_lifecycle.rs +++ /dev/null @@ -1,332 +0,0 @@ -//! End-to-end approval lifecycle: register a fake gated function, intercept -//! a call, resolve it, drive a synthetic next turn, and assert the stitched -//! system message reaches the message log. Skips cleanly when no engine. - -use std::time::Duration; - -use approval_gate::{ - register, WorkerConfig, FN_ACK_DELIVERED, FN_LIST_UNDELIVERED, FN_RESOLVE, STATE_SCOPE, -}; -use iii_sdk::{register_worker, IIIError, InitOptions, RegisterFunctionMessage, TriggerRequest}; -use serde_json::{json, Value}; - -const DEFAULT_ENGINE_URL: &str = "ws://127.0.0.1:49134"; -const ENGINE_PROBE_TIMEOUT_MS: u64 = 500; - -async fn skip_if_no_engine(url: &str) -> Option { - let iii = register_worker(url, InitOptions::default()); - let probe = iii - .trigger(TriggerRequest { - function_id: "state::get".into(), - payload: json!({ "scope": STATE_SCOPE, "key": "__probe__" }), - action: None, - timeout_ms: Some(ENGINE_PROBE_TIMEOUT_MS), - }) - .await; - if probe.is_err() { - eprintln!("skipping: no engine at {url}"); - return None; - } - Some(iii) -} - -#[tokio::test] -async fn allow_path_executes_function_and_stitches_into_next_turn() { - let url = std::env::var("III_URL").unwrap_or_else(|_| DEFAULT_ENGINE_URL.to_string()); - let Some(iii) = skip_if_no_engine(&url).await else { - return; - }; - - let nonce = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_nanos()) - .unwrap_or(0); - let session_id = format!("turn-orch-it-{nonce}"); - let function_call_id = format!("tc-{nonce}"); - let topic = format!("agent::before_function_call::ito_{nonce}"); - - let target_calls: std::sync::Arc>> = - std::sync::Arc::new(std::sync::Mutex::new(Vec::new())); - let target_calls_for_handler = target_calls.clone(); - let _target = iii.register_function(( - RegisterFunctionMessage::with_id(format!("test::write_{nonce}")) - .with_description("fake write".into()), - move |payload: Value| { - let log = target_calls_for_handler.clone(); - async move { - log.lock().unwrap().push(payload); - Ok::<_, IIIError>(json!({"ok": true, "bytes": 42})) - } - }, - )); - - let _refs = register( - &iii, - &WorkerConfig { - topic: topic.clone(), - default_timeout_ms: 30_000, - ..WorkerConfig::default() - }, - ) - .expect("register approval-gate"); - - let target_fn = format!("test::write_{nonce}"); - let envelope = json!({ - "event_id": format!("evt-{nonce}"), - "reply_stream": format!("rs-{nonce}"), - "payload": { - "session_id": session_id, - "function_call": { - "id": function_call_id, - "function_id": target_fn, - "arguments": {"path": "/tmp/foo"}, - }, - "approval_required": [target_fn.clone()], - } - }); - let intercept_resp = iii - .trigger(TriggerRequest { - function_id: "policy::approval_gate".into(), - payload: envelope, - action: None, - timeout_ms: Some(5_000), - }) - .await - .expect("intercept ok"); - - assert_eq!(intercept_resp["block"], json!(true)); - assert_eq!(intercept_resp["status"], json!("pending")); - assert!( - target_calls.lock().unwrap().is_empty(), - "function ran before approval" - ); - - let resolve_resp = iii - .trigger(TriggerRequest { - function_id: FN_RESOLVE.into(), - payload: json!({ - "session_id": session_id, - "function_call_id": function_call_id, - "decision": "allow", - }), - action: None, - timeout_ms: Some(5_000), - }) - .await - .expect("resolve ok"); - assert_eq!(resolve_resp["ok"], json!(true)); - - tokio::time::sleep(Duration::from_millis(50)).await; - let calls = target_calls.lock().unwrap().clone(); - assert_eq!(calls.len(), 1, "expected one invocation; got {calls:?}"); - assert_eq!(calls[0]["path"], json!("/tmp/foo")); - - let undelivered = iii - .trigger(TriggerRequest { - function_id: FN_LIST_UNDELIVERED.into(), - payload: json!({"session_id": session_id}), - action: None, - timeout_ms: Some(5_000), - }) - .await - .expect("list_undelivered ok"); - let entries = undelivered["entries"].as_array().expect("entries array"); - let our_entry = entries - .iter() - .find(|e| e["function_call_id"] == function_call_id) - .expect("our entry in undelivered list"); - assert_eq!(our_entry["status"], "executed"); - assert_eq!(our_entry["result"], json!({"ok": true, "bytes": 42})); - - let ack = iii - .trigger(TriggerRequest { - function_id: FN_ACK_DELIVERED.into(), - payload: json!({ - "session_id": session_id, - "call_ids": [function_call_id.clone()], - "turn_id": "turn-1", - }), - action: None, - timeout_ms: Some(5_000), - }) - .await - .expect("ack ok"); - assert_eq!(ack["ok"], json!(true)); - assert_eq!(ack["stamped"], json!(1)); - - let after = iii - .trigger(TriggerRequest { - function_id: FN_LIST_UNDELIVERED.into(), - payload: json!({"session_id": session_id}), - action: None, - timeout_ms: Some(5_000), - }) - .await - .expect("ok"); - let after_entries = after["entries"].as_array().unwrap(); - assert!( - after_entries - .iter() - .all(|e| e["function_call_id"] != function_call_id), - "after ack, our entry must not be in undelivered list" - ); -} - -#[tokio::test] -async fn deny_path_does_not_invoke_function_and_stitches_denied() { - let url = std::env::var("III_URL").unwrap_or_else(|_| DEFAULT_ENGINE_URL.to_string()); - let Some(iii) = skip_if_no_engine(&url).await else { - return; - }; - - let nonce = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_nanos()) - .unwrap_or(0); - let session_id = format!("turn-orch-deny-{nonce}"); - let function_call_id = format!("tc-deny-{nonce}"); - let topic = format!("agent::before_function_call::itd_{nonce}"); - - let target_calls: std::sync::Arc>> = - std::sync::Arc::new(std::sync::Mutex::new(Vec::new())); - let log = target_calls.clone(); - let _target = iii.register_function(( - RegisterFunctionMessage::with_id(format!("test::write_d_{nonce}")) - .with_description("fake write".into()), - move |payload: Value| { - let log = log.clone(); - async move { - log.lock().unwrap().push(payload); - Ok::<_, IIIError>(json!({"ok": true})) - } - }, - )); - - let _refs = register( - &iii, - &WorkerConfig { - topic: topic.clone(), - default_timeout_ms: 30_000, - ..WorkerConfig::default() - }, - ) - .expect("register approval-gate"); - - let target_fn = format!("test::write_d_{nonce}"); - iii.trigger(TriggerRequest { - function_id: "policy::approval_gate".into(), - payload: json!({ - "event_id": format!("evt-{nonce}"), - "reply_stream": format!("rs-{nonce}"), - "payload": { - "session_id": session_id, - "function_call": {"id": function_call_id, "function_id": target_fn, "arguments": {}}, - "approval_required": [target_fn.clone()], - } - }), - action: None, timeout_ms: Some(5_000), - }).await.expect("intercept"); - - iii.trigger(TriggerRequest { - function_id: FN_RESOLVE.into(), - payload: json!({ - "session_id": session_id, - "function_call_id": function_call_id, - "decision": "deny", - "reason": "test-deny", - }), - action: None, - timeout_ms: Some(5_000), - }) - .await - .expect("resolve deny"); - - tokio::time::sleep(Duration::from_millis(50)).await; - assert!( - target_calls.lock().unwrap().is_empty(), - "function must not be invoked on deny" - ); - - let undelivered = iii - .trigger(TriggerRequest { - function_id: FN_LIST_UNDELIVERED.into(), - payload: json!({"session_id": session_id}), - action: None, - timeout_ms: Some(5_000), - }) - .await - .expect("ok"); - let entries = undelivered["entries"].as_array().unwrap(); - let our_entry = entries - .iter() - .find(|e| e["function_call_id"] == function_call_id) - .expect("our entry in undelivered list"); - assert_eq!(our_entry["status"], "denied"); - assert_eq!(our_entry["decision_reason"], "test-deny"); -} - -#[tokio::test] -async fn timeout_path_lazy_flips_pending_to_timed_out_on_read() { - let url = std::env::var("III_URL").unwrap_or_else(|_| DEFAULT_ENGINE_URL.to_string()); - let Some(iii) = skip_if_no_engine(&url).await else { - return; - }; - - let nonce = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_nanos()) - .unwrap_or(0); - let session_id = format!("turn-orch-to-{nonce}"); - let function_call_id = format!("tc-to-{nonce}"); - let topic = format!("agent::before_function_call::itt_{nonce}"); - - let _target = iii.register_function(( - RegisterFunctionMessage::with_id(format!("test::write_t_{nonce}")) - .with_description("fake".into()), - move |_p: Value| async move { Ok::<_, IIIError>(json!({"ok": true})) }, - )); - - let _refs = register( - &iii, - &WorkerConfig { - topic: topic.clone(), - default_timeout_ms: 100, - ..WorkerConfig::default() - }, - ) - .expect("register approval-gate"); - - let target_fn = format!("test::write_t_{nonce}"); - iii.trigger(TriggerRequest { - function_id: "policy::approval_gate".into(), - payload: json!({ - "event_id": format!("evt-{nonce}"), - "reply_stream": format!("rs-{nonce}"), - "payload": { - "session_id": session_id, - "function_call": {"id": function_call_id, "function_id": target_fn, "arguments": {}}, - "approval_required": [target_fn.clone()], - } - }), - action: None, timeout_ms: Some(5_000), - }).await.expect("intercept"); - - tokio::time::sleep(Duration::from_millis(200)).await; - - let undelivered = iii - .trigger(TriggerRequest { - function_id: FN_LIST_UNDELIVERED.into(), - payload: json!({"session_id": session_id}), - action: None, - timeout_ms: Some(5_000), - }) - .await - .expect("ok"); - let entries = undelivered["entries"].as_array().unwrap(); - let our_entry = entries - .iter() - .find(|e| e["function_call_id"] == function_call_id) - .expect("expired pending must lazy-flip to timed_out and surface"); - assert_eq!(our_entry["status"], "timed_out"); - assert_eq!(our_entry["decision_reason"], "timeout"); -} diff --git a/approval-gate/tests/common/mod.rs b/approval-gate/tests/common/mod.rs new file mode 100644 index 00000000..67d73db7 --- /dev/null +++ b/approval-gate/tests/common/mod.rs @@ -0,0 +1,150 @@ +//! Shared fakes for the approval-gate test suite. +//! +//! Production code goes through `StateBus` and `FunctionExecutor` traits +//! exactly so unit tests can swap in these in-memory fakes. The trait +//! contracts are documented on the production types; the fakes here +//! mirror the wire shape closely enough that any handler behavior tied +//! to bus semantics surfaces in the tests. + +#![allow(dead_code)] // Individual test binaries pull in subsets of these. + +use std::collections::HashMap; +use std::sync::Mutex; + +use approval_gate::{FunctionExecutor, IncomingCall, StateBus}; +use serde_json::{json, Value}; + +/// Records every invocation and replays a canned response. By default +/// the fake returns `Ok({"ok": true})`; set [`Self::response`] to +/// override. +pub struct FakeExecutor { + pub calls: Mutex>, + pub response: Mutex>>, +} + +impl Default for FakeExecutor { + fn default() -> Self { + Self { + calls: Mutex::new(Vec::new()), + response: Mutex::new(None), + } + } +} + +#[async_trait::async_trait] +impl FunctionExecutor for FakeExecutor { + async fn invoke( + &self, + function_id: &str, + args: Value, + function_call_id: &str, + session_id: &str, + ) -> Result { + self.calls.lock().unwrap().push(( + function_id.to_string(), + args, + function_call_id.to_string(), + session_id.to_string(), + )); + self.response + .lock() + .unwrap() + .clone() + .unwrap_or_else(|| Ok(json!({ "ok": true }))) + } +} + +/// In-memory implementation of [`approval_gate::StateBus`]. Keys are +/// `"/"`; `list_prefix` filters by that flattened prefix +/// (same shape the production iii bus exposes). +pub struct InMemoryStateBus { + store: Mutex>, +} + +impl InMemoryStateBus { + pub fn new() -> Self { + Self { + store: Mutex::new(HashMap::new()), + } + } +} + +#[async_trait::async_trait] +impl StateBus for InMemoryStateBus { + async fn set(&self, scope: &str, key: &str, value: Value) -> Result<(), iii_sdk::IIIError> { + self.store + .lock() + .unwrap() + .insert(format!("{scope}/{key}"), value); + Ok(()) + } + async fn get(&self, scope: &str, key: &str) -> Option { + self.store + .lock() + .unwrap() + .get(&format!("{scope}/{key}")) + .cloned() + } + async fn list_prefix(&self, scope: &str, prefix: &str) -> Vec { + let map = self.store.lock().unwrap(); + map.iter() + .filter(|(k, _)| k.starts_with(&format!("{scope}/{prefix}"))) + .map(|(_, v)| v.clone()) + .collect() + } + async fn delete(&self, scope: &str, key: &str) -> Result<(), iii_sdk::IIIError> { + self.store + .lock() + .unwrap() + .remove(&format!("{scope}/{key}")); + Ok(()) + } +} + +/// `StateBus` whose `set` always errors. Used to exercise the gate's +/// fail-closed behavior on transient kv outages. +pub struct FailingStateBus; + +#[async_trait::async_trait] +impl StateBus for FailingStateBus { + async fn set( + &self, + _scope: &str, + _key: &str, + _value: Value, + ) -> Result<(), iii_sdk::IIIError> { + Err(iii_sdk::IIIError::Runtime("kv unreachable".into())) + } + async fn get(&self, _scope: &str, _key: &str) -> Option { + None + } + async fn list_prefix(&self, _scope: &str, _prefix: &str) -> Vec { + Vec::new() + } + async fn delete(&self, _scope: &str, _key: &str) -> Result<(), iii_sdk::IIIError> { + Err(iii_sdk::IIIError::Runtime("kv unreachable".into())) + } +} + +/// A canonical `shell::fs::write` call gated by the run's +/// `approval_required` list. Most handler tests use this so the only +/// thing they need to vary is the session/call id + whether the run +/// opts in. +pub fn sample_call() -> IncomingCall { + IncomingCall { + session_id: "s1".into(), + function_call_id: "tc-1".into(), + function_id: "shell::fs::write".into(), + args: json!({ "path": "/tmp/x" }), + approval_required: vec!["shell::fs::write".into()], + event_id: "evt-1".into(), + reply_stream: "rs-1".into(), + } +} + +/// Empty runtime ruleset for handler tests that don't care about the +/// cascade-on-`always` path. Each call freshly constructs the lock so +/// tests stay independent — there's no shared mutable state. +pub fn empty_policy_rules() -> std::sync::RwLock { + std::sync::RwLock::new(approval_gate::rules::Ruleset::new()) +} diff --git a/approval-gate/tests/integration.rs b/approval-gate/tests/integration.rs deleted file mode 100644 index 80b05835..00000000 --- a/approval-gate/tests/integration.rs +++ /dev/null @@ -1,150 +0,0 @@ -//! Engine-backed test for approval-gate. Connects to an in-process / -//! local iii engine, registers the gate, fires a `before_function_call` -//! envelope on a per-test topic, posts `approval::resolve`, and asserts -//! the trigger model behavior. -//! -//! Skips cleanly when no engine is reachable so `cargo test` stays green -//! in CI without a running engine. - -use std::time::Duration; - -use approval_gate::{register, WorkerConfig, FN_LIST_UNDELIVERED, FN_RESOLVE, STATE_SCOPE}; -use iii_sdk::{register_worker, InitOptions, TriggerRequest}; -use serde_json::json; - -const DEFAULT_ENGINE_URL: &str = "ws://127.0.0.1:49134"; -const ENGINE_PROBE_TIMEOUT_MS: u64 = 500; - -#[tokio::test] -async fn round_trip_allow_returns_pending_immediately_and_executes_on_resolve() { - let url = std::env::var("III_URL").unwrap_or_else(|_| DEFAULT_ENGINE_URL.to_string()); - let iii = register_worker(&url, InitOptions::default()); - - // Probe the engine with a short-timeout state::get; if it errors, - // assume no engine is running locally and skip cleanly. - let probe = iii - .trigger(TriggerRequest { - function_id: "state::get".into(), - payload: json!({ "scope": STATE_SCOPE, "key": "__probe__" }), - action: None, - timeout_ms: Some(ENGINE_PROBE_TIMEOUT_MS), - }) - .await; - if probe.is_err() { - eprintln!("skipping: no engine at {url}"); - return; - } - - // Use a unique topic per run so concurrent test runs don't collide, - // and so we don't race the production approval-gate worker if one is - // already subscribed to the default topic. - let nonce = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_nanos()) - .unwrap_or(0); - let topic = format!("agent::before_function_call::it_{nonce}"); - let session_id = format!("approval-it-{nonce}"); - let function_call_id = format!("tc-it-{nonce}"); - let event_id = format!("evt-it-{nonce}"); - let reply_stream = format!("rs-it-{nonce}"); - - let _refs = register( - &iii, - &WorkerConfig { - topic: topic.clone(), - default_timeout_ms: 5_000, - ..WorkerConfig::default() - }, - ) - .expect("register approval-gate"); - - let envelope = json!({ - "event_id": event_id, - "reply_stream": reply_stream, - "payload": { - "session_id": session_id, - "function_call": { - "id": function_call_id, - "function_id": "shell::filesystem::write", - "arguments": {}, - }, - "approval_required": ["shell::filesystem::write"], - } - }); - - // Drive the subscriber by directly triggering its function id. - // In the trigger model, it returns immediately with block=true + pending. - let reply = iii - .trigger(TriggerRequest { - function_id: "policy::approval_gate".into(), - payload: envelope, - action: None, - timeout_ms: Some(10_000), - }) - .await - .expect("subscriber trigger ok"); - - assert_eq!(reply["block"], true, "subscriber reply: {reply}"); - assert_eq!(reply["status"], "pending", "subscriber reply: {reply}"); - - // Wait for the gate to write the pending record before we resolve. - let key = format!("{session_id}/{function_call_id}"); - let mut tries = 0; - loop { - let v = iii - .trigger(TriggerRequest { - function_id: "state::get".into(), - payload: json!({ "scope": STATE_SCOPE, "key": key }), - action: None, - timeout_ms: Some(1_000), - }) - .await - .unwrap_or(json!(null)); - if v.get("status").and_then(|s| s.as_str()) == Some("pending") { - break; - } - tries += 1; - assert!(tries < 40, "pending entry never appeared (key={key})"); - tokio::time::sleep(Duration::from_millis(50)).await; - } - - // Post the allow decision. - let resolve = iii - .trigger(TriggerRequest { - function_id: FN_RESOLVE.into(), - payload: json!({ - "session_id": session_id, - "function_call_id": function_call_id, - "decision": "allow", - }), - action: None, - timeout_ms: Some(5_000), - }) - .await - .expect("resolve trigger"); - assert_eq!(resolve["ok"], true, "resolve response: {resolve}"); - - // The underlying function "shell::filesystem::write" doesn't exist in - // the test engine, so the invocation will fail and the record should be - // "failed". Verify it surfaced in list_undelivered. - tokio::time::sleep(Duration::from_millis(100)).await; - let undelivered = iii - .trigger(TriggerRequest { - function_id: FN_LIST_UNDELIVERED.into(), - payload: json!({ "session_id": session_id }), - action: None, - timeout_ms: Some(5_000), - }) - .await - .expect("list_undelivered ok"); - let entries = undelivered["entries"].as_array().expect("entries array"); - let our_entry = entries - .iter() - .find(|e| e["function_call_id"] == function_call_id) - .expect("our entry in undelivered list"); - assert!( - our_entry["status"] == "failed" || our_entry["status"] == "executed", - "unexpected status: {}", - our_entry["status"] - ); -} diff --git a/approval-gate/tests/lifecycle.rs b/approval-gate/tests/lifecycle.rs new file mode 100644 index 00000000..6e603e75 --- /dev/null +++ b/approval-gate/tests/lifecycle.rs @@ -0,0 +1,244 @@ +//! End-to-end lifecycle tests for the approval-gate simplification (T17). +//! +//! Each test exercises the full intercept → resolve → consume flow against +//! the in-memory `StateBus` + `FunctionExecutor` fakes. These are the +//! integration safety net for the whole refactor — they assert that the +//! pieces snap together correctly without an iii engine. +//! +//! Four flows: +//! 1. Allow path — full Ask → Allow → Executed → consume drains it +//! 2. Deny path — full Ask → Deny(UserCorrected) → consume drains it +//! 3. Timeout path — Pending past expires_at → consume lazy-flips → drains it +//! 4. Cascade path — two Pending rows; allow+always; both end up consumed + +mod common; + +use approval_gate::record::{Outcome, Record, Status}; +use approval_gate::rules::{Action, Rule, Ruleset}; +use approval_gate::{ + handle_consume, handle_intercept, handle_resolve, IncomingCall, StateBus, STATE_SCOPE, +}; +use common::{FakeExecutor, InMemoryStateBus}; +use serde_json::{json, Value}; +use std::sync::{Arc, RwLock}; + +fn call(session: &str, cid: &str, fn_id: &str, args: Value) -> IncomingCall { + IncomingCall { + session_id: session.into(), + function_call_id: cid.into(), + function_id: fn_id.into(), + args, + approval_required: Vec::new(), + event_id: format!("evt-{cid}"), + reply_stream: format!("rs-{cid}"), + } +} + +fn ruleset_with(rules: Vec) -> Arc> { + Arc::new(RwLock::new(rules)) +} + +#[tokio::test] +async fn allow_path_end_to_end() { + let bus = InMemoryStateBus::new(); + let exec = FakeExecutor::default(); + *exec.response.lock().unwrap() = Some(Ok(json!({ "stdout": "hello\n" }))); + // Empty ruleset → verdict defaults to Ask → intercept writes Pending. + let policy_rules = ruleset_with(vec![]); + + // 1. Hook fires; gate writes Pending. + let incoming = call("sess_allow", "tc-1", "shell::exec", + json!({"command": "echo", "args": ["hello"]})); + let reply = handle_intercept( + &bus, STATE_SCOPE, &incoming, &policy_rules.read().unwrap(), + 1_000, 60_000, + ).await; + assert_eq!(reply["status"], "pending"); + + // 2. Operator resolves allow. + let payload = json!({ + "session_id": "sess_allow", + "function_call_id": "tc-1", + "decision": "allow", + }); + let resolve_reply = handle_resolve(&bus, &exec, STATE_SCOPE, &policy_rules, payload, 2_000).await; + assert_eq!(resolve_reply["ok"], true); + + // Executor must have been called exactly once with the original argv. + let calls = exec.calls.lock().unwrap(); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].0, "shell::exec"); + assert_eq!(calls[0].1["command"], "echo"); + drop(calls); + + // 3. Consume drains the Done row. + let consume_reply = handle_consume( + &bus, STATE_SCOPE, + json!({ "session_id": "sess_allow" }), 3_000, + ).await; + assert_eq!(consume_reply["ok"], true); + let entries = consume_reply["entries"].as_array().unwrap(); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0]["outcome"]["kind"], "executed"); + assert_eq!(entries[0]["outcome"]["detail"]["result"]["stdout"], "hello\n"); + + // 4. Row gone from state. + let leftover = bus.list_prefix(STATE_SCOPE, "sess_allow/").await; + assert!(leftover.is_empty(), "consume must delete drained rows"); +} + +#[tokio::test] +async fn deny_path_with_user_corrected_feedback_end_to_end() { + let bus = InMemoryStateBus::new(); + let exec = FakeExecutor::default(); + let policy_rules = ruleset_with(vec![]); + + let incoming = call("sess_deny", "tc-2", "shell::exec", + json!({"command": "rm", "args": ["-rf", "/tmp/x"]})); + handle_intercept( + &bus, STATE_SCOPE, &incoming, &policy_rules.read().unwrap(), + 1_000, 60_000, + ).await; + + // Operator denies with a correction message. + let payload = json!({ + "session_id": "sess_deny", + "function_call_id": "tc-2", + "decision": "deny", + "denial": { + "kind": "user_corrected", + "detail": { "feedback": "wrong path, use /tmp/y" }, + }, + }); + let r = handle_resolve(&bus, &exec, STATE_SCOPE, &policy_rules, payload, 2_000).await; + assert_eq!(r["ok"], true); + // Shell must NOT have been invoked. + assert_eq!(exec.calls.lock().unwrap().len(), 0); + + let consume = handle_consume( + &bus, STATE_SCOPE, + json!({ "session_id": "sess_deny" }), 3_000, + ).await; + let entries = consume["entries"].as_array().unwrap(); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0]["outcome"]["kind"], "denied"); + assert_eq!( + entries[0]["outcome"]["detail"]["denial"]["kind"], + "user_corrected", + ); + assert_eq!( + entries[0]["outcome"]["detail"]["denial"]["detail"]["feedback"], + "wrong path, use /tmp/y", + ); + + let leftover = bus.list_prefix(STATE_SCOPE, "sess_deny/").await; + assert!(leftover.is_empty()); +} + +#[tokio::test] +async fn timeout_path_lazy_flips_on_consume_end_to_end() { + let bus = InMemoryStateBus::new(); + let policy_rules = ruleset_with(vec![]); + + // Seed a Pending row with a very short timeout. + let incoming = call("sess_timeout", "tc-3", "shell::exec", + json!({"command": "ls"})); + handle_intercept( + &bus, STATE_SCOPE, &incoming, &policy_rules.read().unwrap(), + 1_000, // now_ms + 1, // timeout_ms → expires_at = 1_001 + ).await; + + // Consume at now=2_000, well past expires_at. Lazy flip + return. + let consume = handle_consume( + &bus, STATE_SCOPE, + json!({ "session_id": "sess_timeout" }), 2_000, + ).await; + let entries = consume["entries"].as_array().unwrap(); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0]["outcome"]["kind"], "timed_out"); + + let leftover = bus.list_prefix(STATE_SCOPE, "sess_timeout/").await; + assert!(leftover.is_empty()); +} + +#[tokio::test] +async fn cascade_path_end_to_end() { + let bus = InMemoryStateBus::new(); + let exec = FakeExecutor::default(); + *exec.response.lock().unwrap() = Some(Ok(json!({ "stdout": "" }))); + let policy_rules = ruleset_with(vec![]); + + // Two pending rows for the SAME argv shape — `cascade_allow_for_session` + // pushes an exact-pattern Allow rule from the first row's args; the + // second row should then verdict as Allow and auto-resolve. + for cid in ["tc-4", "tc-5"] { + let incoming = call("sess_cascade", cid, "shell::exec", + json!({"command": "echo", "args": ["go"]})); + handle_intercept( + &bus, STATE_SCOPE, &incoming, &policy_rules.read().unwrap(), + 1_000, 60_000, + ).await; + } + + // Resolve tc-4 with always:true. + let payload = json!({ + "session_id": "sess_cascade", + "function_call_id": "tc-4", + "decision": "allow", + "always": true, + }); + let r = handle_resolve(&bus, &exec, STATE_SCOPE, &policy_rules, payload, 2_000).await; + assert_eq!(r["ok"], true); + assert_eq!(r["cascaded"], 1, "one extra row (tc-5) must have auto-resolved"); + + // Executor called twice (originator + cascade). + assert_eq!(exec.calls.lock().unwrap().len(), 2); + + // Both rows in state as Done(Executed). Consume drains them. + let consume = handle_consume( + &bus, STATE_SCOPE, + json!({ "session_id": "sess_cascade" }), 3_000, + ).await; + let entries = consume["entries"].as_array().unwrap(); + assert_eq!(entries.len(), 2); + for e in entries { + assert_eq!(e["outcome"]["kind"], "executed"); + } + + // Ruleset gained the runtime Allow rule with the exact pattern. + let rs = policy_rules.read().unwrap(); + let pushed = rs.last().expect("cascade must push a rule"); + assert_eq!(pushed.action, Action::Allow); + assert_eq!(pushed.permission, "shell::exec"); + assert_eq!( + pushed.pattern, "echo go", + "exact-pattern push (NOT blanket '*') — 'always allow echo go' does not grant rm -rf /", + ); + + let leftover = bus.list_prefix(STATE_SCOPE, "sess_cascade/").await; + assert!(leftover.is_empty()); +} + +#[tokio::test] +async fn allow_rule_short_circuits_with_no_state_write() { + // Bonus: not from the plan, but worth pinning. A Verdict::Allow at + // intercept time must NOT write a Pending row (no state, no consume). + let bus = InMemoryStateBus::new(); + let policy_rules = ruleset_with(vec![Rule { + permission: "shell::exec".into(), + pattern: "git status*".into(), + action: Action::Allow, + }]); + + let incoming = call("sess_pass", "tc-6", "shell::exec", + json!({"command": "git", "args": ["status"]})); + let reply = handle_intercept( + &bus, STATE_SCOPE, &incoming, &policy_rules.read().unwrap(), + 1_000, 60_000, + ).await; + assert_eq!(reply["block"], false); + + let leftover = bus.list_prefix(STATE_SCOPE, "sess_pass/").await; + assert!(leftover.is_empty(), "Allow path must not touch state"); +} diff --git a/approval-gate/tests/wire.rs b/approval-gate/tests/wire.rs new file mode 100644 index 00000000..dc1ab59d --- /dev/null +++ b/approval-gate/tests/wire.rs @@ -0,0 +1,156 @@ +//! Wire-shape helpers: extract_call envelope parsing, block_reply_for +//! hook reply, IncomingCall::requires_approval semantics. + +mod common; + +use approval_gate::*; +use common::{empty_policy_rules, sample_call, FailingStateBus, FakeExecutor, InMemoryStateBus}; +use serde_json::{json, Value}; +use std::sync::Mutex; + + + + #[test] + fn extract_call_reads_session_id_and_function_call_from_envelope() { + let envelope = json!({ + "event_id": "evt-1", + "reply_stream": "rs-1", + "payload": { + "function_call": { "id": "tc-1", "function_id": "write", "arguments": {"path": "/tmp/x"} }, + "approval_required": ["write"], + "session_id": "s1", + } + }); + let call = extract_call(&envelope).expect("decoded"); + assert_eq!(call.session_id, "s1"); + assert_eq!(call.function_call_id, "tc-1"); + assert_eq!(call.function_id, "write"); + assert_eq!(call.event_id, "evt-1"); + assert_eq!(call.reply_stream, "rs-1"); + assert!(call.approval_required.iter().any(|s| s == "write")); + } + + + #[test] + fn extract_call_accepts_legacy_tool_call_envelope_with_name() { + let envelope = json!({ + "event_id": "evt-1", + "reply_stream": "rs-1", + "payload": { + "tool_call": { "id": "tc-1", "name": "write", "arguments": {} }, + "approval_required": ["write"], + "session_id": "s1", + } + }); + let call = extract_call(&envelope).expect("decoded"); + assert_eq!(call.function_call_id, "tc-1"); + assert_eq!(call.function_id, "write"); + } + + + #[test] + fn requires_approval_only_for_listed_functions() { + let call = IncomingCall { + session_id: "s1".into(), + function_call_id: "tc-1".into(), + function_id: "ls".into(), + args: json!({}), + approval_required: vec!["write".into()], + event_id: "e".into(), + reply_stream: "r".into(), + }; + assert!(!call.requires_approval()); + + let call2 = IncomingCall { + function_id: "write".into(), + ..call + }; + assert!(call2.requires_approval()); + } + + + #[test] + fn block_reply_for_decision_allow_does_not_block() { + let reply = block_reply_for(&Decision::Allow); + assert_eq!(reply["block"], false); + } + + + #[test] + fn block_reply_for_deny_emits_structured_denial() { + let reply = block_reply_for(&Decision::Deny(Denial::UserRejected)); + assert_eq!(reply["block"], true); + assert_eq!(reply["denial"]["kind"], "user_rejected"); + assert!(reply.as_object().unwrap().get("reason").is_none()); + } + + + #[test] + fn block_reply_for_policy_deny_carries_classifier_detail() { + let reply = block_reply_for(&Decision::Deny(Denial::Policy { + rule_permission: "shell::exec".into(), + rule_pattern: "rm -rf*".into(), + })); + assert_eq!(reply["block"], true); + assert_eq!(reply["denial"]["kind"], "policy"); + assert_eq!( + reply["denial"]["detail"]["rule_permission"], + "shell::exec" + ); + assert_eq!( + reply["denial"]["detail"]["rule_pattern"], + "rm -rf*" + ); + } + + + #[test] + fn block_reply_for_user_corrected_carries_feedback() { + let reply = block_reply_for(&Decision::Deny(Denial::UserCorrected { + feedback: "use git diff instead".into(), + })); + assert_eq!(reply["denial"]["kind"], "user_corrected"); + assert_eq!( + reply["denial"]["detail"]["feedback"], + "use git diff instead" + ); + } + + + #[test] + fn extract_call_returns_none_when_function_call_absent() { + let envelope = json!({ + "event_id": "evt-1", + "reply_stream": "rs-1", + "payload": { "session_id": "s1", "approval_required": ["write"] } + }); + assert!(extract_call(&envelope).is_none()); + } + + + #[test] + fn extract_call_returns_none_when_session_id_absent() { + let envelope = json!({ + "event_id": "evt-1", + "reply_stream": "rs-1", + "payload": { + "tool_call": { "id": "tc-1", "name": "write", "arguments": {} } + } + }); + assert!(extract_call(&envelope).is_none()); + } + + + #[test] + fn block_reply_for_allow_omits_denial_and_reason() { + let reply = block_reply_for(&Decision::Allow); + assert_eq!(reply["block"], false); + assert!( + reply.get("reason").is_none(), + "Allow must not include reason: {reply}" + ); + assert!( + reply.get("denial").is_none(), + "Allow must not include denial: {reply}" + ); + } diff --git a/harness/crates/harness-types/src/agent_event.rs b/harness/crates/harness-types/src/agent_event.rs index 6b37f381..f87ad305 100644 --- a/harness/crates/harness-types/src/agent_event.rs +++ b/harness/crates/harness-types/src/agent_event.rs @@ -23,7 +23,6 @@ pub enum ApprovalDecision { /// `{ "kind": "user_rejected", "detail": null }` /// `{ "kind": "user_corrected", "detail": { "feedback": "..." } }` /// `{ "kind": "state_error", "detail": { "phase": "...", "error": "..." } }` -/// `{ "kind": "legacy", "detail": { "reason": "..." } }` #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(tag = "kind", content = "detail", rename_all = "snake_case")] pub enum Denial { @@ -39,9 +38,6 @@ pub enum Denial { phase: String, error: String, }, - Legacy { - reason: String, - }, } /// Stable wire format emitted by the loop on `agent::events/`. diff --git a/harness/src/fanout.rs b/harness/src/fanout.rs index b5d2e6ad..4d9bbbe7 100644 --- a/harness/src/fanout.rs +++ b/harness/src/fanout.rs @@ -101,6 +101,14 @@ const SESSIONS_POLL_INTERVAL_MS: u64 = 1_000; /// Approval poll cadence. Hook-driven push (via the agent::events stream /// pump) covers low-latency notification; this poll catches missed states /// and clears resolved approvals. +/// +/// T16 follow-up (deferred): replace this poll loop with a per-session +/// agent::events subscription that reconciles requested / resolved +/// frames against a client-side cache. `approval::list_pending` would +/// stay as the one-shot rehydration call on connect/reload. Polling +/// works today against the new wire shape (Records carry +/// `function_call_id` at the top level, which is what `diff_approvals` +/// reads) — this is purely a perf cleanup. const APPROVAL_POLL_INTERVAL_MS: u64 = 1_000; /// Cost summary poll cadence. Each tick performs a `budget::list` and (for diff --git a/harness/web/src/components/ApprovalRow.tsx b/harness/web/src/components/ApprovalRow.tsx index 25ef1b99..be82413b 100644 --- a/harness/web/src/components/ApprovalRow.tsx +++ b/harness/web/src/components/ApprovalRow.tsx @@ -1,4 +1,4 @@ -import { useState } from "react"; +import { useEffect, useState } from "react"; import { bridge, BridgeError } from "../bridge"; import type { PendingApproval } from "../types"; @@ -7,22 +7,140 @@ interface Props { pending: PendingApproval[]; } +type ResolveDecision = "allow" | "deny"; +type DenialPayload = + | { kind: "user_rejected" } + | { kind: "user_corrected"; detail: { feedback: string } }; + +/** + * Subscribe to a 1s tick and report `expiresAt - now` (ms). Returns the + * raw remaining number so the parent can render its own format. Negative + * once expired; the parent disables actions on `remaining <= 0`. + */ +function useCountdown(expiresAt: number | undefined): number { + const [now, setNow] = useState(() => Date.now()); + useEffect(() => { + if (!expiresAt) return; + const t = setInterval(() => setNow(Date.now()), 1000); + return () => clearInterval(t); + }, [expiresAt]); + if (!expiresAt) return Number.POSITIVE_INFINITY; + return expiresAt - now; +} + +function formatRemaining(ms: number): string { + if (!Number.isFinite(ms)) return ""; + if (ms <= 0) return "expired"; + const s = Math.floor(ms / 1000); + const m = Math.floor(s / 60); + const r = s % 60; + return `${m}:${String(r).padStart(2, "0")}`; +} + +interface ApprovalCardProps { + sessionId: string; + approval: PendingApproval; + callId: string; + fnId: string; + busyId: string | null; + onResolve: (functionCallId: string, decision: ResolveDecision, opts?: { always?: boolean; denial?: DenialPayload }) => void; +} + +/** + * One pending-approval card. Lives in its own component so the + * countdown hook is rendered once per row (hooks can't sit inside + * a .map callback without violating the rules-of-hooks contract). + */ +function ApprovalCard({ sessionId: _sessionId, approval, callId, fnId, busyId, onResolve }: ApprovalCardProps) { + const remaining = useCountdown(approval.expires_at); + const expired = remaining <= 0; + const [feedback, setFeedback] = useState(""); + + const denyClick = () => { + const trimmed = feedback.trim(); + if (trimmed.length > 0) { + onResolve(callId, "deny", { + denial: { kind: "user_corrected", detail: { feedback: trimmed } }, + }); + } else { + onResolve(callId, "deny"); + } + }; + + return ( +
+
+ approval needed + {fnId} + {approval.expires_at ? ( + {formatRemaining(remaining)} + ) : null} +
+
{JSON.stringify(approval.args, null, 2)}
+
+ add correction (optional, sent to the model on deny) +