diff --git a/Cargo.lock b/Cargo.lock index 936a96f5e..8f96b16b9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3319,6 +3319,7 @@ dependencies = [ "base_mem_check", "base_rt", "bytes", + "dashmap", "deno_core", "deno_fetch", "deno_fs", @@ -3335,6 +3336,7 @@ dependencies = [ "log", "memmem", "once_cell", + "regex", "scopeguard", "serde", "thiserror 2.0.8", diff --git a/cli/src/flags.rs b/cli/src/flags.rs index c7e13abcc..136b3e0a9 100644 --- a/cli/src/flags.rs +++ b/cli/src/flags.rs @@ -232,6 +232,15 @@ fn get_start_command() -> Command { )) .value_parser(value_parser!(u64)), ) + .arg( + arg!(--"rate-limit-table-cleanup-interval" ) + .help(concat!( + "Interval in seconds between sweeps of the outbound rate-limit ", + "table to remove expired entries (default: 60)" + )) + .default_value("60") + .value_parser(value_parser!(u64)), + ) .arg( arg!(--"inspect"[HOST_AND_PORT]) .help("Activate inspector on host:port") diff --git a/cli/src/main.rs b/cli/src/main.rs index 2071d7207..f12d5f889 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -229,7 +229,13 @@ fn main() -> Result { .get_one::("request-buffer-size") .copied() .unwrap(); - + let rate_limit_cleanup_interval_sec = sub_matches + .get_one::("rate-limit-table-cleanup-interval") + .copied() + .unwrap_or(60); + if rate_limit_cleanup_interval_sec == 0 { + bail!("--rate-limit-table-cleanup-interval must be >= 1 second"); + } let flags = ServerFlags { otel: if !enable_otel.is_empty() { if enable_otel.len() > 1 { @@ -264,6 +270,8 @@ fn main() -> Result { beforeunload_wall_clock_pct: maybe_beforeunload_wall_clock_pct, beforeunload_cpu_pct: maybe_beforeunload_cpu_pct, beforeunload_memory_pct: maybe_beforeunload_memory_pct, + + rate_limit_cleanup_interval_sec, }; let mut builder = Builder::new(addr, &main_service_path); diff --git a/crates/base/src/server.rs b/crates/base/src/server.rs index 9ab56aa57..4a649afd1 100644 --- a/crates/base/src/server.rs +++ b/crates/base/src/server.rs @@ -332,6 +332,8 @@ pub struct ServerFlags 
{ pub beforeunload_wall_clock_pct: Option, pub beforeunload_cpu_pct: Option, pub beforeunload_memory_pct: Option, + + pub rate_limit_cleanup_interval_sec: u64, } #[derive(Debug)] diff --git a/crates/base/src/worker/pool.rs b/crates/base/src/worker/pool.rs index 6cface98e..ee9142b28 100644 --- a/crates/base/src/worker/pool.rs +++ b/crates/base/src/worker/pool.rs @@ -18,6 +18,8 @@ use either::Either::Left; use enum_as_inner::EnumAsInner; use ext_event_worker::events::WorkerEventWithMetadata; use ext_runtime::SharedMetricSource; +use ext_runtime::SharedRateLimitTable; +use ext_runtime::TraceRateLimiterConfig; use ext_workers::context::CreateUserWorkerResult; use ext_workers::context::SendRequestResult; use ext_workers::context::Timing; @@ -235,6 +237,7 @@ pub struct WorkerPool { pub flags: Arc, pub policy: WorkerPoolPolicy, pub metric_src: SharedMetricSource, + pub shared_rate_limit_table: SharedRateLimitTable, pub user_workers: HashMap, pub active_workers: HashMap, pub worker_pool_msgs_tx: mpsc::UnboundedSender, @@ -253,11 +256,19 @@ impl WorkerPool { worker_event_sender: Option>, worker_pool_msgs_tx: mpsc::UnboundedSender, inspector: Option, + cancel: CancellationToken, ) -> Self { + let shared_rate_limit_table = SharedRateLimitTable::default(); + shared_rate_limit_table.spawn_cleanup_task( + Duration::from_secs(flags.rate_limit_cleanup_interval_sec), + cancel, + ); + Self { flags, policy, metric_src, + shared_rate_limit_table, worker_event_sender, user_workers: HashMap::new(), active_workers: HashMap::new(), @@ -384,6 +395,7 @@ impl WorkerPool { let worker_pool_msgs_tx = self.worker_pool_msgs_tx.clone(); let events_msg_tx = self.worker_event_sender.clone(); let supervisor_policy = self.policy.supervisor_policy; + let shared_rate_limit_table = self.shared_rate_limit_table.clone(); drop(tokio::spawn(async move { let (permit, tx) = match wait_fence_fut.await { @@ -462,6 +474,17 @@ impl WorkerPool { user_worker_rt_opts.events_msg_tx = events_msg_tx; 
user_worker_rt_opts.cancel = Some(cancel.clone()); + if let ext_runtime::RateLimiterOpts::Rules { rules, global_key } = + std::mem::take(&mut user_worker_rt_opts.rate_limiter) + { + user_worker_rt_opts.rate_limiter = + ext_runtime::RateLimiterOpts::Configured(TraceRateLimiterConfig { + table: shared_rate_limit_table, + rules, + global_key: Some(global_key), + }); + } + worker_options.timing = Some(Timing { early_drop_rx, status: status.clone(), @@ -792,6 +815,7 @@ pub async fn create_user_worker_pool( async move { let token = termination_token.as_ref(); let mut termination_requested = false; + let cleanup_cancel = token.map(|t| t.inbound.clone()).unwrap_or_default(); let mut worker_pool = WorkerPool::new( flags, policy, @@ -799,6 +823,7 @@ pub async fn create_user_worker_pool( worker_event_sender, user_worker_msgs_tx_clone, inspector, + cleanup_cancel, ); // Note: Keep this loop non-blocking. Spawn a task to run blocking calls. diff --git a/crates/base/src/worker/worker_inner.rs b/crates/base/src/worker/worker_inner.rs index 28d7ee1a1..372b8be80 100644 --- a/crates/base/src/worker/worker_inner.rs +++ b/crates/base/src/worker/worker_inner.rs @@ -12,7 +12,9 @@ use ext_event_worker::events::UncaughtExceptionEvent; use ext_event_worker::events::WorkerEventWithMetadata; use ext_event_worker::events::WorkerEvents; use ext_runtime::MetricSource; +use ext_runtime::RateLimiterOpts; use ext_runtime::RuntimeMetricSource; +use ext_runtime::TraceRateLimiter; use ext_runtime::WorkerMetricSource; use ext_workers::context::UserWorkerMsgs; use ext_workers::context::WorkerContextInitOpts; @@ -270,6 +272,20 @@ impl Worker { state_mut.put(metric_src.clone()); MetricSource::Runtime(metric_src) } else { + if let Some(opts) = new_runtime.conf.as_user_worker().cloned() { + if let RateLimiterOpts::Configured(config) = opts.rate_limiter { + match TraceRateLimiter::new(config) { + Ok(limiter) => { + let state = new_runtime.js_runtime.op_state(); + let mut state_mut = state.borrow_mut(); + 
state_mut.put(limiter); + } + Err(err) => { + error!("failed to compile rate limit rules: {err}"); + } + } + } + } MetricSource::Worker(metric_src) } }; diff --git a/crates/base/test_cases/rate-limit-a/index.ts b/crates/base/test_cases/rate-limit-a/index.ts new file mode 100644 index 000000000..b91d2ebf0 --- /dev/null +++ b/crates/base/test_cases/rate-limit-a/index.ts @@ -0,0 +1,81 @@ +// Worker A: forwards to worker B. Supports two outbound HTTP modes selected +// via the x-http-mode header: "fetch" (default) or "node". +import * as http from "node:http"; + +function requestViaNode( + url: string, + headers: Record, +): Promise<{ status: number; body: string }> { + return new Promise((resolve, reject) => { + const parsed = new URL(url); + const req = http.request( + { + hostname: parsed.hostname, + port: parsed.port, + path: parsed.pathname, + method: "GET", + headers, + }, + (res) => { + let body = ""; + res.on("data", (chunk) => { + body += chunk; + }); + res.on("end", () => resolve({ status: res.statusCode ?? 500, body })); + }, + ); + req.on("error", reject); + req.end(); + }); +} + +Deno.serve(async (req: Request) => { + if (!req.headers.has("traceparent")) { + return new Response( + JSON.stringify({ msg: "missing traceparent header" }), + { status: 400, headers: { "Content-Type": "application/json" } }, + ); + } + + const serverUrl = req.headers.get("x-test-server-url"); + if (!serverUrl) { + return new Response( + JSON.stringify({ msg: "missing x-test-server-url header" }), + { status: 400, headers: { "Content-Type": "application/json" } }, + ); + } + + const mode = req.headers.get("x-http-mode") ?? 
"fetch"; + const forwardHeaders: Record = { + "x-test-server-url": serverUrl, + "x-http-mode": mode, + }; + + try { + let status: number; + let body: string; + + if (mode === "node") { + ({ status, body } = await requestViaNode( + `${serverUrl}/rate-limit-b`, + forwardHeaders, + )); + } else { + const resp = await fetch(`${serverUrl}/rate-limit-b`, { + headers: forwardHeaders, + }); + status = resp.status; + body = await resp.text(); + } + + return new Response(body, { + status, + headers: { "Content-Type": "application/json" }, + }); + } catch (e) { + return new Response( + JSON.stringify({ msg: e.toString() }), + { status: 500, headers: { "Content-Type": "application/json" } }, + ); + } +}); diff --git a/crates/base/test_cases/rate-limit-b/index.ts b/crates/base/test_cases/rate-limit-b/index.ts new file mode 100644 index 000000000..b2d787245 --- /dev/null +++ b/crates/base/test_cases/rate-limit-b/index.ts @@ -0,0 +1,81 @@ +// Worker B: forwards back to worker A. Supports two outbound HTTP modes +// selected via the x-http-mode header: "fetch" (default) or "node". +import * as http from "node:http"; + +function requestViaNode( + url: string, + headers: Record, +): Promise<{ status: number; body: string }> { + return new Promise((resolve, reject) => { + const parsed = new URL(url); + const req = http.request( + { + hostname: parsed.hostname, + port: parsed.port, + path: parsed.pathname, + method: "GET", + headers, + }, + (res) => { + let body = ""; + res.on("data", (chunk) => { + body += chunk; + }); + res.on("end", () => resolve({ status: res.statusCode ?? 
500, body })); + }, + ); + req.on("error", reject); + req.end(); + }); +} + +Deno.serve(async (req: Request) => { + if (!req.headers.has("traceparent")) { + return new Response( + JSON.stringify({ msg: "missing traceparent header" }), + { status: 400, headers: { "Content-Type": "application/json" } }, + ); + } + + const serverUrl = req.headers.get("x-test-server-url"); + if (!serverUrl) { + return new Response( + JSON.stringify({ msg: "missing x-test-server-url header" }), + { status: 400, headers: { "Content-Type": "application/json" } }, + ); + } + + const mode = req.headers.get("x-http-mode") ?? "fetch"; + const forwardHeaders: Record = { + "x-test-server-url": serverUrl, + "x-http-mode": mode, + }; + + try { + let status: number; + let body: string; + + if (mode === "node") { + ({ status, body } = await requestViaNode( + `${serverUrl}/rate-limit-a`, + forwardHeaders, + )); + } else { + const resp = await fetch(`${serverUrl}/rate-limit-a`, { + headers: forwardHeaders, + }); + status = resp.status; + body = await resp.text(); + } + + return new Response(body, { + status, + headers: { "Content-Type": "application/json" }, + }); + } catch (e) { + return new Response( + JSON.stringify({ msg: e.toString() }), + { status: 500, headers: { "Content-Type": "application/json" } }, + ); + } +}); diff --git a/crates/base/test_cases/rate-limit-echo/index.ts b/crates/base/test_cases/rate-limit-echo/index.ts new file mode 100644 index 000000000..cb74525fb --- /dev/null +++ b/crates/base/test_cases/rate-limit-echo/index.ts @@ -0,0 +1,15 @@ +// Worker that echoes back the trace ID from the AsyncVariable context. +// Exposed via globalThis.getRequestTraceId when exposeRequestTraceId context +// flag is set. Used to verify AsyncVariable isolation across concurrent +// requests. +Deno.serve(async (_req: Request) => { + // Small delay so concurrent requests actually overlap inside the event loop. 
+ await new Promise((resolve) => setTimeout(resolve, 30)); + + const traceId = (globalThis as any).getRequestTraceId?.() ?? null; + + return new Response( + JSON.stringify({ traceId }), + { status: 200, headers: { "Content-Type": "application/json" } }, + ); +}); diff --git a/crates/base/test_cases/rate-limit-main/index.ts b/crates/base/test_cases/rate-limit-main/index.ts new file mode 100644 index 000000000..f4a34f7e8 --- /dev/null +++ b/crates/base/test_cases/rate-limit-main/index.ts @@ -0,0 +1,79 @@ +console.log("rate-limit-main started"); + +Deno.serve(async (req: Request) => { + const url = new URL(req.url); + const { pathname } = url; + + console.log( + `Received request for ${pathname}, traceparent header: ${ + req.headers.get("traceparent") + }`, + ); + + const path_parts = pathname.split("/"); + const service_name = path_parts[1]; + + if (!service_name || service_name === "") { + return new Response( + JSON.stringify({ msg: "missing function name in request" }), + { status: 400, headers: { "Content-Type": "application/json" } }, + ); + } + + const servicePath = `./test_cases/${service_name}`; + + const createWorker = async () => { + const memoryLimitMb = 150; + const workerTimeoutMs = 10 * 60 * 1000; + const cpuTimeSoftLimitMs = 10 * 60 * 1000; + const cpuTimeHardLimitMs = 10 * 60 * 1000; + const noModuleCache = false; + const envVarsObj = Deno.env.toObject(); + const envVars = Object.keys(envVarsObj).map((k) => [k, envVarsObj[k]]); + + const tracingOpts = service_name.endsWith("-untraced") ? {} : { + otelConfig: { + tracing_enabled: true, + propagators: ["TraceContext"], + }, + }; + + const debugOpts = service_name.endsWith("-echo") + ? 
{ context: { exposeRequestTraceId: true } } + : {}; + + return await EdgeRuntime.userWorkers.create({ + servicePath, + memoryLimitMb, + workerTimeoutMs, + cpuTimeSoftLimitMs, + cpuTimeHardLimitMs, + noModuleCache, + envVars, + ...tracingOpts, + ...debugOpts, + traceRateLimitOptions: { + key: servicePath, + rules: [ + { + matches: ".*", + ttl: 60, + budget: { local: 10, global: 10 }, + }, + ], + }, + }); + }; + + try { + const worker = await createWorker(); + return await worker.fetch(req); + } catch (e) { + console.error(e); + const error = { msg: e.toString() }; + return new Response( + JSON.stringify(error), + { status: 500, headers: { "Content-Type": "application/json" } }, + ); + } +}); diff --git a/crates/base/test_cases/rate-limit-untraced/index.ts b/crates/base/test_cases/rate-limit-untraced/index.ts new file mode 100644 index 000000000..880eeda09 --- /dev/null +++ b/crates/base/test_cases/rate-limit-untraced/index.ts @@ -0,0 +1,74 @@ +// Worker C: calls itself repeatedly to exhaust the global (untraced) budget. +// Supports two outbound HTTP modes via x-http-mode: "fetch" (default) or "node". +import * as http from "node:http"; + +function requestViaNode( + url: string, + headers: Record, +): Promise<{ status: number; body: string }> { + return new Promise((resolve, reject) => { + const parsed = new URL(url); + const req = http.request( + { + hostname: parsed.hostname, + port: parsed.port, + path: parsed.pathname, + method: "GET", + headers, + }, + (res) => { + let body = ""; + res.on("data", (chunk) => { + body += chunk; + }); + res.on("end", () => resolve({ status: res.statusCode ?? 
500, body })); + }, + ); + req.on("error", reject); + req.end(); + }); +} + +Deno.serve(async (req: Request) => { + const serverUrl = req.headers.get("x-test-server-url"); + if (!serverUrl) { + return new Response( + JSON.stringify({ msg: "missing x-test-server-url header" }), + { status: 400, headers: { "Content-Type": "application/json" } }, + ); + } + + const mode = req.headers.get("x-http-mode") ?? "fetch"; + const forwardHeaders: Record = { + "x-test-server-url": serverUrl, + "x-http-mode": mode, + }; + + try { + let status: number; + let body: string; + + if (mode === "node") { + ({ status, body } = await requestViaNode( + `${serverUrl}/rate-limit-untraced`, + forwardHeaders, + )); + } else { + const resp = await fetch(`${serverUrl}/rate-limit-untraced`, { + headers: forwardHeaders, + }); + status = resp.status; + body = await resp.text(); + } + + return new Response(body, { + status, + headers: { "Content-Type": "application/json" }, + }); + } catch (e) { + return new Response( + JSON.stringify({ msg: e.toString() }), + { status: 500, headers: { "Content-Type": "application/json" } }, + ); + } +}); diff --git a/crates/base/tests/integration_tests.rs b/crates/base/tests/integration_tests.rs index 04b5e72c5..7ad6cb56a 100644 --- a/crates/base/tests/integration_tests.rs +++ b/crates/base/tests/integration_tests.rs @@ -37,6 +37,7 @@ use base::utils::test_utils::TestBedBuilder; use base::worker; use base::worker::TerminationToken; use base::WorkerKind; +use deno::deno_telemetry::OtelConfig; use deno::DenoOptionsBuilder; use deno_core::error::AnyError; use deno_core::serde_json::json; @@ -106,6 +107,18 @@ const NON_SECURE_PORT: u16 = 8498; const SECURE_PORT: u16 = 4433; const TESTBED_DEADLINE_SEC: u64 = 20; +static OTEL_INIT: std::sync::OnceLock<()> = std::sync::OnceLock::new(); + +fn init_otel() { + OTEL_INIT.get_or_init(|| { + deno::deno_telemetry::init( + deno::versions::otel_runtime_config(), + OtelConfig::default(), + ) + .unwrap(); + }); +} + const 
TLS_LOCALHOST_ROOT_CA: &[u8] = include_bytes!("./fixture/tls/root-ca.pem"); const TLS_LOCALHOST_CERT: &[u8] = include_bytes!("./fixture/tls/localhost.pem"); @@ -4158,6 +4171,196 @@ async fn test_brotli_async() { ); } +async fn assert_rate_limit_error( + resp: Result, +) { + let res = resp.unwrap(); + assert_eq!( + res.status().as_u16(), + StatusCode::INTERNAL_SERVER_ERROR, + "expected the chain to be rate-limited" + ); + let body = res.text().await.unwrap(); + assert!( + body.contains("RateLimitError"), + "expected RateLimitError in body, got: {body}" + ); +} + +/// Verifies that the local (traced) budget cuts off A→B→A circular chains. +async fn test_outbound_rate_limit_circular(http_mode: &'static str) { + const TRACEPARENT: &str = + "00-12345678901234567890123456789012-1234567890123456-01"; + + init_otel(); + + integration_test_with_server_flag!( + ServerFlags { + rate_limit_cleanup_interval_sec: 60, + request_wait_timeout_ms: Some(30_000), + ..Default::default() + }, + "./test_cases/rate-limit-main", + NON_SECURE_PORT, + "rate-limit-a", + None, + Some( + reqwest::Client::new() + .get(format!("http://localhost:{}/rate-limit-a", NON_SECURE_PORT)) + .header("traceparent", TRACEPARENT) + .header("x-http-mode", http_mode) + .header( + "x-test-server-url", + format!("http://localhost:{}", NON_SECURE_PORT), + ) + ), + None::, + (|resp| async { assert_rate_limit_error(resp).await }), + TerminationToken::new() + ); +} + +#[tokio::test] +#[serial] +async fn test_outbound_rate_limit_circular_fetch() { + test_outbound_rate_limit_circular("fetch").await; +} + +#[tokio::test] +#[serial] +async fn test_outbound_rate_limit_circular_node_http() { + test_outbound_rate_limit_circular("node").await; +} + +/// Verifies that the global (untraced) budget cuts off a self-calling worker. 
+async fn test_outbound_rate_limit_global(http_mode: &'static str) { + init_otel(); + + integration_test_with_server_flag!( + ServerFlags { + rate_limit_cleanup_interval_sec: 60, + request_wait_timeout_ms: Some(30_000), + ..Default::default() + }, + "./test_cases/rate-limit-main", + NON_SECURE_PORT, + "rate-limit-untraced", + None, + Some( + reqwest::Client::new() + .get(format!( + "http://localhost:{}/rate-limit-untraced", + NON_SECURE_PORT + )) + .header("x-http-mode", http_mode) + .header( + "x-test-server-url", + format!("http://localhost:{}", NON_SECURE_PORT), + ) + ), + None::, + (|resp| async { assert_rate_limit_error(resp).await }), + TerminationToken::new() + ); +} + +#[tokio::test] +#[serial] +async fn test_outbound_rate_limit_global_budget_fetch() { + test_outbound_rate_limit_global("fetch").await; +} + +#[tokio::test] +#[serial] +async fn test_outbound_rate_limit_global_budget_node_http() { + test_outbound_rate_limit_global("node").await; +} + +/// Verifies that `AsyncVariable` isolates trace IDs per request. +/// +/// Sends N concurrent requests to the `rate-limit-echo` worker, each carrying +/// a distinct `traceparent`. The worker reads the `AsyncVariable` set by +/// `http.js` and echoes the trace ID back. If any response contains a trace +/// ID that does not match the one sent in that request, context has leaked. 
+#[tokio::test] +#[serial] +async fn test_request_trace_id_isolation() { + const N: usize = 20; + + init_otel(); + + integration_test_with_server_flag!( + ServerFlags { + rate_limit_cleanup_interval_sec: 60, + request_wait_timeout_ms: Some(30_000), + ..Default::default() + }, + "./test_cases/rate-limit-main", + NON_SECURE_PORT, + "rate-limit-echo", + None, + None::, + None::, + ( + |(port, _url, _req_builder, _event_rx, _metric_src)| async move { + let client = std::sync::Arc::new(reqwest::Client::new()); + let base = std::sync::Arc::new(format!( + "http://localhost:{}/rate-limit-echo", + port + )); + + // Build N futures, each with a unique trace ID. + let futs: Vec<_> = (0..N) + .map(|i| { + let client = client.clone(); + let base = base.clone(); + // Each trace ID is a 32-hex-char string unique to this request. + let trace_id = format!("{:032x}", i); + let traceparent = format!("00-{}-1234567890123456-01", trace_id); + async move { + let resp = client + .get(base.as_str()) + .header("traceparent", &traceparent) + .send() + .await + .unwrap(); + assert!( + resp.status().is_success(), + "request {i} failed with status {}", + resp.status() + ); + let body: serde_json::Value = resp.json().await.unwrap(); + let returned = body["traceId"].as_str().unwrap_or("").to_string(); + (trace_id, returned) + } + }) + .collect(); + + // Run all requests concurrently. + let results = futures_util::future::join_all(futs).await; + + for (expected, got) in &results { + assert_eq!( + expected, got, + "AsyncVariable context leaked: expected trace_id={expected} but worker saw {got}" + ); + } + + // Return the last response to satisfy the macro's type requirement. 
+ Some(Ok( + reqwest::Client::new() + .get(base.as_str()) + .send() + .await + .unwrap(), + )) + }, + |_resp| async {} + ), + TerminationToken::new() + ); +} + #[derive(Deserialize)] struct ErrorResponsePayload { msg: String, diff --git a/ext/node/polyfills/http.ts b/ext/node/polyfills/http.ts index b389f4bfb..b35c336ae 100644 --- a/ext/node/polyfills/http.ts +++ b/ext/node/polyfills/http.ts @@ -5,6 +5,7 @@ import { core, internals, primordials } from "ext:core/mod.js"; import { + op_check_outbound_rate_limit, op_http_upgrade_raw2, op_node_http_fetch_response_upgrade, op_node_http_fetch_send, @@ -493,6 +494,21 @@ class ClientRequest extends OutgoingMessage { span.setAttribute("url.query", parsedUrl.search.slice(1)); } + const traceId = internals.getRequestTraceId?.(); + const isTraced = traceId !== null && traceId !== undefined; + const rlKey = isTraced ? traceId : ""; + const allowed = op_check_outbound_rate_limit( + parsedUrl.href, + rlKey, + isTraced, + ); + if (!allowed) { + const msg = isTraced + ? 
`Rate limit exceeded for trace ${rlKey}` + : `Rate limit exceeded for function`; + throw new Deno.errors.RateLimitError(msg); + } + this._req = op_node_http_request( this.method, url, @@ -598,7 +614,7 @@ class ClientRequest extends OutgoingMessage { updateSpanFromError(span, err); } - if (this._req.cancelHandleRid !== null) { + if (this._req !== undefined && this._req.cancelHandleRid !== null) { core.tryClose(this._req.cancelHandleRid); } diff --git a/ext/runtime/Cargo.toml b/ext/runtime/Cargo.toml index 12dbbe326..771ba4af9 100644 --- a/ext/runtime/Cargo.toml +++ b/ext/runtime/Cargo.toml @@ -23,6 +23,7 @@ base_rt.workspace = true anyhow.workspace = true bytes.workspace = true +dashmap.workspace = true enum-as-inner.workspace = true futures.workspace = true http.workspace = true @@ -31,6 +32,7 @@ hyper.workspace = true hyper_v014.workspace = true log.workspace = true once_cell.workspace = true +regex.workspace = true scopeguard.workspace = true serde.workspace = true thiserror.workspace = true diff --git a/ext/runtime/js/bootstrap.js b/ext/runtime/js/bootstrap.js index 9035efc3b..492e9a8fd 100644 --- a/ext/runtime/js/bootstrap.js +++ b/ext/runtime/js/bootstrap.js @@ -1,4 +1,5 @@ import { core, internals, primordials } from "ext:core/mod.js"; +import "ext:runtime/request_context.js"; import * as abortSignal from "ext:deno_web/03_abort_signal.js"; import * as base64 from "ext:deno_web/05_base64.js"; @@ -602,6 +603,16 @@ globalThis.bootstrapSBEdge = (opts, ctx) => { }); } + if (ctx?.exposeRequestTraceId) { + ObjectDefineProperty(globalThis, "getRequestTraceId", { + get() { + return internals.getRequestTraceId; + }, + configurable: true, + enumerable: false, + }); + } + bootstrapOtel(otel); ObjectDefineProperty(globalThis, "Deno", readOnly(denoOverrides)); diff --git a/ext/runtime/js/errors.js b/ext/runtime/js/errors.js index 2e6f790e5..844c39250 100644 --- a/ext/runtime/js/errors.js +++ b/ext/runtime/js/errors.js @@ -56,6 +56,7 @@ const 
DOMExceptionInvalidCharacterError = buildDomErrorClass( "InvalidCharacterError", ); const DOMExceptionDataError = buildDomErrorClass("DOMExceptionDataError"); +const RateLimitError = buildErrorClass("RateLimitError"); function registerErrors() { core.registerErrorClass("InvalidWorkerResponse", InvalidWorkerResponse); @@ -107,6 +108,7 @@ function registerErrors() { "DOMExceptionDataError", DOMExceptionDataError, ); + core.registerErrorClass("RateLimitError", RateLimitError); } const errors = knownErrors; diff --git a/ext/runtime/js/http.js b/ext/runtime/js/http.js index 1a3608e65..5c19aa2f2 100644 --- a/ext/runtime/js/http.js +++ b/ext/runtime/js/http.js @@ -1,6 +1,7 @@ import "ext:deno_http/01_http.js"; import { core, internals, primordials } from "ext:core/mod.js"; +import { enterRequestContext } from "ext:runtime/request_context.js"; import { RequestPrototype } from "ext:deno_fetch/23_request.js"; import { fromInnerResponse, @@ -28,6 +29,7 @@ const ops = core.ops; const { BadResourcePrototype, internalRidSymbol, + setAsyncContext, } = core; const { ArrayPrototypeFind, @@ -223,66 +225,82 @@ async function respond(requestEvent, httpConn, options, snapshot) { const mapped = async function (requestEvent, httpConn, options, span) { /** @type {Response} */ let response; - try { - if (span) { - updateSpanFromRequest(span, requestEvent.request); - } - response = await options["handler"](requestEvent.request, { - remoteAddr: { - port: options.port, - hostname: options.hostname, - transport: options.transport, - }, - }); - } catch (error) { - if (options["onError"] !== void 0) { - /** @throwable */ - response = await options["onError"](error); - } else { - console.error(error); - response = internalServerError(); + const traceParent = requestEvent.request.headers.get("traceparent"); + let traceId = null; + if (traceParent) { + // traceparent format: 00-{trace-id}-{parent-id}-{flags} + const parts = traceParent.split("-"); + if (parts.length >= 4 && parts[1].length === 32) 
{ + traceId = parts[1]; } } + const prevCtx = enterRequestContext(traceId); - if (ObjectPrototypeIsPrototypeOf(ResponsePrototype, response) && span) { - updateSpanFromResponse(span, response); - } + try { + try { + if (span) { + updateSpanFromRequest(span, requestEvent.request); + } - if (response === internals.RAW_UPGRADE_RESPONSE_SENTINEL) { - const { fenceRid } = getSupabaseTag(requestEvent.request); + response = await options["handler"](requestEvent.request, { + remoteAddr: { + port: options.port, + hostname: options.hostname, + transport: options.transport, + }, + }); + } catch (error) { + if (options["onError"] !== void 0) { + /** @throwable */ + response = await options["onError"](error); + } else { + console.error(error); + response = internalServerError(); + } + } - if (fenceRid === void 0) { - throw TypeError("Cannot find a fence for upgrading response"); + if (ObjectPrototypeIsPrototypeOf(ResponsePrototype, response) && span) { + updateSpanFromResponse(span, response); } - setTimeout(async () => { - const { - status, - headers, - } = await ops.op_http_upgrade_raw2_fence(fenceRid); + if (response === internals.RAW_UPGRADE_RESPONSE_SENTINEL) { + const { fenceRid } = getSupabaseTag(requestEvent.request); + + if (fenceRid === void 0) { + throw TypeError("Cannot find a fence for upgrading response"); + } + setTimeout(async () => { + const { + status, + headers, + } = await ops.op_http_upgrade_raw2_fence(fenceRid); + + try { + await requestEvent.respondWith( + new Response(null, { + headers, + status, + }), + ); + } catch (error) { + closeHttpConn(httpConn); + } + }); + } else { try { - await requestEvent.respondWith( - new Response(null, { - headers, - status, - }), - ); - } catch (error) { - closeHttpConn(httpConn); + // send the response + await requestEvent.respondWith(response); + } catch { + // respondWith() fails when the connection has already been closed, + // or there is some other error with responding on this connection + // that prompts us to 
close it and open a new connection. + return closeHttpConn(httpConn); } - }); - } else { - try { - // send the response - await requestEvent.respondWith(response); - } catch { - // respondWith() fails when the connection has already been closed, - // or there is some other error with responding on this connection - // that prompts us to close it and open a new connection. - return closeHttpConn(httpConn); } + } finally { + setAsyncContext(prevCtx); } }; diff --git a/ext/runtime/js/request_context.js b/ext/runtime/js/request_context.js new file mode 100644 index 000000000..4fdda5902 --- /dev/null +++ b/ext/runtime/js/request_context.js @@ -0,0 +1,18 @@ +import { core, internals } from "ext:core/mod.js"; + +const { AsyncVariable } = core; + +const requestTraceIdVar = new AsyncVariable(); + +function enterRequestContext(traceId) { + return requestTraceIdVar.enter(traceId); +} + +function getRequestTraceId() { + return requestTraceIdVar.get(); +} + +internals.enterRequestContext = enterRequestContext; +internals.getRequestTraceId = getRequestTraceId; + +export { enterRequestContext, getRequestTraceId }; diff --git a/ext/runtime/lib.rs b/ext/runtime/lib.rs index dcfa3d589..3a2c74423 100644 --- a/ext/runtime/lib.rs +++ b/ext/runtime/lib.rs @@ -29,6 +29,13 @@ pub mod cert; pub mod conn_sync; pub mod external_memory; pub mod ops; +pub mod rate_limit; + +pub use rate_limit::RateLimiterOpts; +pub use rate_limit::SharedRateLimitTable; +pub use rate_limit::TraceRateLimitRule; +pub use rate_limit::TraceRateLimiter; +pub use rate_limit::TraceRateLimiterConfig; pub use ops::bootstrap::runtime_bootstrap; pub use ops::http::runtime_http; @@ -442,6 +449,19 @@ pub fn op_bootstrap_unstable_args(_state: &mut OpState) -> Vec { vec![] } +#[op2(fast)] +pub fn op_check_outbound_rate_limit( + state: &mut OpState, + #[string] url: &str, + #[string] key: &str, + is_traced: bool, +) -> bool { + let Some(limiter) = state.try_borrow::() else { + return true; + }; + 
limiter.check_and_increment(url, key, is_traced) +} + deno_core::extension!( runtime, ops = [ @@ -459,6 +479,7 @@ deno_core::extension!( op_raise_segfault, op_tap_promise_metrics, op_cancel_drop_token, + op_check_outbound_rate_limit, ], esm_entry_point = "ext:runtime/bootstrap.js", esm = [ @@ -472,6 +493,7 @@ deno_core::extension!( "errors.js", "fieldUtils.js", "http.js", + "request_context.js", "namespaces.js", "navigator.js", "permissions.js", diff --git a/ext/runtime/rate_limit.rs b/ext/runtime/rate_limit.rs new file mode 100644 index 000000000..f135abf49 --- /dev/null +++ b/ext/runtime/rate_limit.rs @@ -0,0 +1,197 @@ +use std::sync::Arc; +use std::time::Duration; +use std::time::Instant; + +use dashmap::DashMap; +use regex::Regex; +use serde::Deserialize; +use serde::Serialize; +use tokio_util::sync::CancellationToken; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TraceRateLimitBudget { + pub local: u32, + pub global: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TraceRateLimitRule { + pub matches: String, + pub ttl: u64, + pub budget: TraceRateLimitBudget, +} + +struct CompiledRule { + matches: Regex, + ttl: Duration, + budget: TraceRateLimitBudget, +} + +#[derive(Debug)] +struct Entry { + count: u32, + expires_at: Instant, +} + +#[derive(Debug, Clone, Default)] +pub struct SharedRateLimitTable { + table: Arc>, +} + +impl SharedRateLimitTable { + /// Spawns a background task that removes expired entries every `interval`. + /// The task shuts down when `cancel` is cancelled. + pub fn spawn_cleanup_task( + &self, + interval: Duration, + cancel: CancellationToken, + ) { + let table = self.table.clone(); + tokio::spawn(async move { + loop { + tokio::select! 
{ + _ = cancel.cancelled() => break, + _ = tokio::time::sleep(interval) => { + let now = Instant::now(); + table.retain(|_, entry| entry.expires_at > now); + } + } + } + }); + } + + pub fn check_and_increment( + &self, + key: &str, + budget: u32, + ttl: Duration, + ) -> bool { + let now = Instant::now(); + + let mut entry = + self.table.entry(key.to_string()).or_insert_with(|| Entry { + count: 0, + expires_at: now + ttl, + }); + + if now >= entry.expires_at { + tracing::debug!( + key, + count = entry.count, + budget, + "rate limit entry expired, resetting" + ); + entry.count = 0; + entry.expires_at = now + ttl; + } + + let allowed = entry.count < budget; + tracing::trace!( + key, + count = entry.count, + budget, + ?ttl, + allowed, + "rate limit check" + ); + + if !allowed { + return false; + } + + entry.count += 1; + true + } +} + +/// Bundles the shared table and per-worker rules together. +/// `Some` only when both are present; passed through the worker creation chain. +#[derive(Debug, Clone)] +pub struct TraceRateLimiterConfig { + pub table: SharedRateLimitTable, + pub rules: Vec, + /// Caller-supplied stable key shared across all instances of the same + /// function. Used as the rate-limit key for untraced requests so the global + /// budget accumulates correctly regardless of how many worker instances exist. + pub global_key: Option, +} + +/// Rate-limit configuration as it travels through the worker creation pipeline. +/// +/// - `Rules`: rules provided by JS; the pool hasn't attached a shared table yet. +/// - `Configured`: the pool has assembled the full config; ready for compilation. 
+#[derive(Debug, Clone, Default)] +pub enum RateLimiterOpts { + #[default] + Disabled, + Rules { + rules: Vec, + global_key: String, + }, + Configured(TraceRateLimiterConfig), +} + +#[derive(Clone)] +pub struct TraceRateLimiter { + table: SharedRateLimitTable, + rules: Arc>, + global_key: Option, +} + +impl TraceRateLimiter { + pub fn new( + TraceRateLimiterConfig { + table, + rules, + global_key, + }: TraceRateLimiterConfig, + ) -> Result { + let compiled = rules + .into_iter() + .map(|r| { + Ok(CompiledRule { + matches: Regex::new(&r.matches)?, + ttl: Duration::from_secs(r.ttl), + budget: r.budget, + }) + }) + .collect::, regex::Error>>()?; + + Ok(Self { + table, + rules: Arc::new(compiled), + global_key, + }) + } + + pub fn check_and_increment( + &self, + url: &str, + key: &str, + is_traced: bool, + ) -> bool { + let rule = self.rules.iter().find(|r| r.matches.is_match(url)); + + let Some(rule) = rule else { + return true; + }; + + if is_traced { + self + .table + .check_and_increment(key, rule.budget.local, rule.ttl) + } else { + // For untraced requests a stable global_key is required so the global + // budget accumulates correctly across worker instances. Deny the request + // if the caller did not supply one. 
+ let Some(fid) = self.global_key.as_deref() else { + return false; + }; + self + .table + .check_and_increment(fid, rule.budget.global, rule.ttl) + } + } +} diff --git a/ext/workers/context.rs b/ext/workers/context.rs index d875ff500..60beeccfd 100644 --- a/ext/workers/context.rs +++ b/ext/workers/context.rs @@ -17,6 +17,7 @@ use enum_as_inner::EnumAsInner; use ext_event_worker::events::UncaughtExceptionEvent; use ext_event_worker::events::WorkerEventWithMetadata; use ext_runtime::MetricSource; +use ext_runtime::RateLimiterOpts; use ext_runtime::SharedMetricSource; use fs::s3_fs::S3FsConfig; use fs::tmp_fs::TmpFsConfig; @@ -92,6 +93,7 @@ pub struct UserWorkerRuntimeOpts { pub permissions: Option, pub context: Option, + pub rate_limiter: RateLimiterOpts, } impl Default for UserWorkerRuntimeOpts { @@ -135,6 +137,7 @@ impl Default for UserWorkerRuntimeOpts { permissions: None, context: None, + rate_limiter: RateLimiterOpts::Disabled, } } } diff --git a/ext/workers/lib.rs b/ext/workers/lib.rs index c6959f09b..0a0380d15 100644 --- a/ext/workers/lib.rs +++ b/ext/workers/lib.rs @@ -38,6 +38,7 @@ use deno_telemetry::OtelConsoleConfig; use deno_telemetry::OtelPropagators; use errors::WorkerError; use ext_runtime::conn_sync::ConnWatcher; +use ext_runtime::TraceRateLimitRule; use fs::s3_fs::S3FsConfig; use fs::tmp_fs::TmpFsConfig; use http_utils::utils::get_upgrade_type; @@ -134,6 +135,15 @@ pub struct UserWorkerCreateOptions { context: Option, #[serde(default)] static_patterns: Vec, + trace_rate_limit_options: Option, +} + +#[derive(Deserialize, Serialize, Default, Debug)] +#[serde(rename_all = "camelCase")] +pub struct JsTraceRateLimitOptions { + pub key: String, + #[serde(default)] + pub rules: Vec, } /// It is identical to [`PermissionsOptions`], except for `prompt`. 
@@ -220,6 +230,7 @@ pub async fn op_user_worker_create( context, static_patterns, + trace_rate_limit_options, } = opts; let maybe_otel_config = maybe_otel_config.map(|it| OtelConfig { @@ -259,6 +270,16 @@ pub async fn op_user_worker_create( .map(JsPermissionsOptions::into_permissions_options), context, + rate_limiter: match trace_rate_limit_options { + None => ext_runtime::RateLimiterOpts::Disabled, + Some(opts) if opts.rules.is_empty() => { + ext_runtime::RateLimiterOpts::Disabled + } + Some(opts) => ext_runtime::RateLimiterOpts::Rules { + rules: opts.rules, + global_key: opts.key, + }, + }, ..Default::default() } diff --git a/types/global.d.ts b/types/global.d.ts index fce11cb53..11d0a2f98 100644 --- a/types/global.d.ts +++ b/types/global.d.ts @@ -64,6 +64,8 @@ interface UserWorkerCreateContext { otel?: { [attribute: string]: string; }; + + exposeRequestTraceId?: boolean | null; } interface UserWorkerCreateOptions { @@ -93,6 +95,42 @@ interface UserWorkerCreateOptions { otelConfig?: OtelConfig | null; context?: UserWorkerCreateContext | null; + traceRateLimitOptions?: TraceRateLimitOptions | null; +} + +/** Per-URL budget split between traced (local) and untraced (global) requests. */ +interface TraceRateLimitBudget { + /** Max outbound requests allowed per trace ID within the TTL window. */ + local: number; + /** Max outbound requests allowed across all untraced requests within the TTL window. */ + global: number; +} + +/** A single rate-limit rule applied to outbound URLs matching `matches`. */ +interface TraceRateLimitRule { + /** Regular expression matched against the outbound request URL. */ + matches: string; + /** Window duration in seconds. The counter resets after this period. */ + ttl: number; + budget: TraceRateLimitBudget; +} + +/** + * Rate-limit configuration for outbound HTTP requests made by a user worker. + * + * Rules are evaluated in order; the first matching rule applies. 
+ * Traced requests (those carrying a `traceparent` header) share a budget + * identified by their trace ID (`local` budget). Untraced requests share a + * single global budget identified by `key` (`global` budget). + */ +interface TraceRateLimitOptions { + /** + * Stable identifier shared across all instances of the same function. + * Used as the rate-limit key for untraced requests so the global budget + * accumulates correctly regardless of how many worker instances exist. + */ + key: string; + rules: TraceRateLimitRule[]; } interface HeapStatistics { @@ -214,5 +252,8 @@ declare namespace Deno { export namespace errors { class WorkerRequestCancelled extends Error {} class WorkerAlreadyRetired extends Error {} + + /** Thrown when an outbound HTTP request is blocked by the rate limiter. */ + class RateLimitError extends Error {} } } diff --git a/vendor/deno_fetch/26_fetch.js b/vendor/deno_fetch/26_fetch.js index b3e23fe75..7ddd5b457 100644 --- a/vendor/deno_fetch/26_fetch.js +++ b/vendor/deno_fetch/26_fetch.js @@ -10,8 +10,9 @@ /// /// -import { core, primordials } from "ext:core/mod.js"; +import { core, internals, primordials } from "ext:core/mod.js"; import { + op_check_outbound_rate_limit, op_fetch, op_fetch_promise_is_settled, op_fetch_send, @@ -387,11 +388,30 @@ function fetch(input, init = { __proto__: null }) { // 3. const request = toInnerRequest(requestObject); + // 4. if (requestObject.signal.aborted) { reject(abortFetch(request, null, requestObject.signal.reason)); return; } + + // Rate limit check. + const traceId = internals.getRequestTraceId?.(); + const isTraced = traceId !== null && traceId !== undefined; + const rlKey = isTraced ? traceId : ""; + const allowed = op_check_outbound_rate_limit( + requestObject.url, + rlKey, + isTraced, + ); + if (!allowed) { + const msg = isTraced + ? `Rate limit exceeded for trace ${rlKey}` + : `Rate limit exceeded for function`; + reject(new Deno.errors.RateLimitError(msg)); + return; + } + // 7. 
let responseObject = null; // 9.