diff --git a/Cargo.lock b/Cargo.lock index 1eec573865..11a1692aec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -542,7 +542,7 @@ dependencies = [ "pin-project-lite", "serde_core", "sync_wrapper", - "tower", + "tower 0.5.3", "tower-layer", "tower-service", ] @@ -791,6 +791,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e4de3bc4ea267985becf712dc6d9eed8b04c953b3fcfb339ebc87acd9804901" +[[package]] +name = "cidr" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "579504560394e388085d0c080ea587dfa5c15f7e251b4d5247d1e1a61d1d6928" + [[package]] name = "clap" version = "4.6.1" @@ -1878,14 +1884,20 @@ dependencies = [ ] [[package]] -name = "eventsource-stream" -version = "0.2.3" +name = "eventsource-client" +version = "0.17.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab" +checksum = "4f2808c25d229d2f854182ba2b8098bfb8592f439f199b408d16aaf186f7b5e8" dependencies = [ - "futures-core", - "nom 7.1.3", - "pin-project-lite", + "base64 0.22.1", + "bytes", + "futures", + "http 1.4.0", + "launchdarkly-sdk-transport", + "log", + "pin-project", + "rand 0.10.1", + "tokio", ] [[package]] @@ -2052,6 +2064,7 @@ dependencies = [ "dashmap 7.0.0-rc2", "derive_more", "derive_setters", + "eventsource-client", "fake", "forge_config", "forge_display", @@ -2072,7 +2085,6 @@ dependencies = [ "pretty_assertions", "regex", "reqwest 0.12.28", - "reqwest-eventsource", "schemars 1.2.1", "serde", "serde_json", @@ -2165,7 +2177,7 @@ dependencies = [ "is_ci", "lazy_static", "merge", - "nom 8.0.0", + "nom", "pretty_assertions", "regex", "schemars 1.2.1", @@ -2221,6 +2233,7 @@ dependencies = [ "diesel_migrations", "dirs", "dotenvy", + "eventsource-client", "fake", "forge_app", "forge_config", @@ -2234,12 +2247,12 @@ dependencies = [ "glob", "google-cloud-auth", "http 1.4.0", + "launchdarkly-sdk-transport", "libsqlite3-sys", "oauth2", "open", "pretty_assertions", "reqwest 0.12.28", - "reqwest-eventsource", "rmcp", "schemars 1.2.1", "serde", @@ -2369,7 +2382,7 @@ dependencies = [ "diesel", "diesel_migrations", "dirs", - "eventsource-stream", + "eventsource-client", "fake", "forge_app", "forge_config", @@ -2393,7 +2406,6 @@ dependencies = [ "prost-types", "regex", "reqwest 0.12.28", - "reqwest-eventsource", "schemars 1.2.1", "serde", "serde_json", @@ -2438,6 +2450,7 @@ dependencies = [ "dashmap 7.0.0-rc2", "derive_more", "derive_setters", + "eventsource-client", "fake", "forge_app", "forge_config", @@ -2463,7 +2476,6 @@ dependencies = [ "pretty_assertions", "regex", "reqwest 0.12.28", - "reqwest-eventsource", "serde", "serde_json", "serde_urlencoded", @@ -2681,12 +2693,6 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" -[[package]] -name = "futures-timer" -version = "3.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" - [[package]] name = "futures-util" version = "0.3.32" @@ -3999,6 +4005,30 @@ dependencies = [ "hashbrown 0.15.5", ] +[[package]] +name = "headers" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3314d5adb5d94bcdf56771f2e50dbbc80bb4bdf88967526706205ac9eff24eb" +dependencies = [ + "base64 0.22.1", + "bytes", + "headers-core", + "http 1.4.0", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" +dependencies = [ + "http 1.4.0", +] + [[package]] name = "heapless" version = "0.8.0" @@ -4274,6 +4304,23 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-http-proxy" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ad4b0a1e37510028bc4ba81d0e38d239c39671b0f0ce9e02dfa93a8133f7c08" +dependencies = [ + "bytes", + "futures-util", + "headers", + "http 1.4.0", + "hyper 1.9.0", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "hyper-rustls" version = "0.24.2" @@ -4856,6 +4903,25 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "launchdarkly-sdk-transport" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fe83622d04dfcaaeac0b5e3aaa1cc156eb1e70c8b68dfcaffaee4365faa00d3" +dependencies = [ + "bytes", + "futures", + "http 1.4.0", + "http-body-util", + "hyper 1.9.0", + "hyper-http-proxy", + "hyper-timeout", + "hyper-util", + "log", + "no-proxy", + "tower 0.4.13", +] + [[package]] name = "lazy-regex" version = "3.6.0" @@ -5165,12 +5231,6 @@ dependencies = [ "once_cell", ] -[[package]] -name = "minimal-lexical" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" - [[package]] name = "miniz_oxide" version = "0.8.9" @@ -5291,13 +5351,12 @@ dependencies = [ ] [[package]] -name = "nom" -version = "7.1.3" +name = "no-proxy" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +checksum = "9f79c902b31ceac6856e262af5dbaffef75390cf4647c9fef7b55da69a4b912e" dependencies = [ - "memchr", - "minimal-lexical", + "cidr", ] [[package]] @@ -6369,7 +6428,7 @@ dependencies = [ "tokio", "tokio-rustls 0.26.4", "tokio-util", - "tower", + "tower 0.5.3", "tower-http", "tower-service", "url", @@ -6411,7 +6470,7 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-rustls 0.26.4", - "tower", + "tower 0.5.3", "tower-http", "tower-service", "url", @@ -6420,22 +6479,6 @@ dependencies = [ "web-sys", ] -[[package]] -name = "reqwest-eventsource" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "632c55746dbb44275691640e7b40c907c16a2dc1a5842aa98aaec90da6ec6bde" -dependencies = [ - "eventsource-stream", - "futures-core", - "futures-timer", - "mime", - "nom 7.1.3", - "pin-project-lite", - "reqwest 0.12.28", - "thiserror 1.0.69", -] - [[package]] name = "resolv-conf" version = "0.7.6" @@ -7954,7 +7997,7 @@ dependencies = [ "tokio", "tokio-rustls 0.26.4", "tokio-stream", - "tower", + "tower 0.5.3", "tower-layer", "tower-service", "tracing", @@ -8000,6 +8043,17 @@ dependencies = [ "tonic-build", ] +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower" version = "0.5.3" @@ -8037,7 +8091,7 @@ dependencies = [ "pin-project-lite", "tokio", "tokio-util", - "tower", + "tower 0.5.3", "tower-layer", "tower-service", ] @@ -8060,6 +8114,7 @@ version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", diff --git a/Cargo.toml b/Cargo.toml index de4d71c631..c706119510 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,8 +73,7 @@ reqwest = { version = "0.12.23", features = [ "http2", ], default-features = false } rustls = { version = "0.23", features = ["ring"], default-features = false } -reqwest-eventsource = "0.6.0" -eventsource-stream = "0.2.3" +eventsource-client = "0.17.2" include_dir = "0.7.4" schemars = "1.2" serde = { version = "1.0.217", features = ["derive"] } diff --git a/crates/forge_app/Cargo.toml b/crates/forge_app/Cargo.toml index 8f5f1873b5..b89be07010 100644 --- a/crates/forge_app/Cargo.toml +++ b/crates/forge_app/Cargo.toml @@ -41,7 +41,7 @@ dashmap.workspace = true url.workspace = true reqwest.workspace = true bytes.workspace = true -reqwest-eventsource.workspace = true +eventsource-client.workspace = true schemars.workspace = true glob.workspace = true lazy_static.workspace = true diff --git a/crates/forge_app/src/infra.rs b/crates/forge_app/src/infra.rs index e4b49bfb66..8aeae02fa1 100644 --- a/crates/forge_app/src/infra.rs +++ b/crates/forge_app/src/infra.rs @@ -10,11 +10,10 @@ use forge_domain::{ }; use reqwest::Response; use reqwest::header::HeaderMap; -use reqwest_eventsource::EventSource; use serde::de::DeserializeOwned; use url::Url; -use crate::{WalkedFile, Walker}; +use crate::{EventSource, WalkedFile, Walker}; /// Infrastructure trait for accessing environment configuration, system /// variables, and persisted application configuration. diff --git a/crates/forge_app/src/lib.rs b/crates/forge_app/src/lib.rs index 66de3e618d..a8a7d96d3f 100644 --- a/crates/forge_app/src/lib.rs +++ b/crates/forge_app/src/lib.rs @@ -1,3 +1,11 @@ +use std::pin::Pin; + +use eventsource_client::SSE; +use futures::Stream; + +/// Type alias for a server-sent events stream +pub type EventSource = Pin> + Send + Sync>>; + mod agent; mod agent_executor; mod agent_provider_resolver; diff --git a/crates/forge_app/src/services.rs b/crates/forge_app/src/services.rs index 59f88f3be7..b326e3228a 100644 --- a/crates/forge_app/src/services.rs +++ b/crates/forge_app/src/services.rs @@ -12,11 +12,10 @@ use forge_domain::{ }; use reqwest::Response; use reqwest::header::HeaderMap; -use reqwest_eventsource::EventSource; use url::Url; use crate::user::{User, UserUsage}; -use crate::{EnvironmentInfra, Walker}; +use crate::{EnvironmentInfra, EventSource, Walker}; #[derive(Debug, Clone)] pub struct ShellOutput { diff --git a/crates/forge_infra/Cargo.toml b/crates/forge_infra/Cargo.toml index 232926dbdb..7edf89571b 100644 --- a/crates/forge_infra/Cargo.toml +++ b/crates/forge_infra/Cargo.toml @@ -31,7 +31,8 @@ forge_app.workspace = true forge_walker.workspace = true -reqwest-eventsource.workspace = true +eventsource-client.workspace = true +launchdarkly-sdk-transport = "0.1.1" glob.workspace = true futures.workspace = true diesel = { version= "2.3.7", features = ["sqlite", "r2d2", "chrono"] } diff --git a/crates/forge_infra/src/forge_infra.rs b/crates/forge_infra/src/forge_infra.rs index 3a3e602d17..78102f0025 100644 --- a/crates/forge_infra/src/forge_infra.rs +++ b/crates/forge_infra/src/forge_infra.rs @@ -1,9 +1,11 @@ use std::collections::BTreeMap; use std::path::{Path, PathBuf}; +use std::pin::Pin; use std::process::ExitStatus; use std::sync::Arc; use bytes::Bytes; +use eventsource_client::SSE; use forge_app::{ CommandInfra, DirectoryReaderInfra, EnvironmentInfra, FileDirectoryInfra, FileInfoInfra, FileReaderInfra, FileRemoverInfra, FileWriterInfra, GrpcInfra, HttpInfra, McpServerInfra, @@ -12,9 +14,9 @@ use forge_app::{ use forge_domain::{ AuthMethod, CommandOutput, FileInfo as FileInfoData, McpServerConfig, ProviderId, URLParamSpec, }; +use futures::Stream; use reqwest::header::HeaderMap; use reqwest::{Response, Url}; -use reqwest_eventsource::EventSource; use crate::auth::{AnyAuthStrategy, ForgeAuthStrategyFactory}; use crate::console::StdConsoleWriter; @@ -320,7 +322,7 @@ impl HttpInfra for ForgeInfra { url: &Url, headers: Option, body: Bytes, - ) -> anyhow::Result { + ) -> anyhow::Result> + Send + Sync>>> { self.http_service.http_eventsource(url, headers, body).await } } diff --git a/crates/forge_infra/src/http.rs b/crates/forge_infra/src/http.rs index 60f0743760..b7e56ccb04 100644 --- a/crates/forge_infra/src/http.rs +++ b/crates/forge_infra/src/http.rs @@ -1,18 +1,22 @@ use std::fs; use std::path::PathBuf; +use std::pin::Pin; use std::sync::Arc; use std::time::Duration; use anyhow::Context; use bytes::Bytes; +use eventsource_client::{Client as EsClient, ClientBuilder, ReconnectOptions, SSE}; use forge_app::HttpInfra; use forge_config::{ForgeConfig, TlsBackend, TlsVersion}; +use futures::{Stream, StreamExt}; use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue}; use reqwest::redirect::Policy; use reqwest::{Certificate, Client, Response, StatusCode, Url}; -use reqwest_eventsource::{EventSource, RequestBuilderExt}; use tracing::{debug, warn}; +use crate::transport::ReqwestTransport; + const VERSION: &str = match option_env!("APP_VERSION") { None => env!("CARGO_PKG_VERSION"), Some(v) => v, @@ -24,6 +28,28 @@ pub struct ForgeHttpInfra { file: Arc, } +#[derive(Debug, thiserror::Error)] +enum EventSourceRequestError { + #[error("Invalid SSE URL")] + InvalidUrl { + #[source] + source: eventsource_client::Error, + }, + #[error("Invalid SSE header '{name}'")] + InvalidHeader { + name: String, + #[source] + source: eventsource_client::Error, + }, + #[error("UnexpectedResponse(status: {status}, body: [omitted to avoid blocking stream])")] + UnexpectedResponse { status: String }, + #[error("EventSource error")] + EventSource { + #[source] + source: eventsource_client::Error, + }, +} + fn to_reqwest_tls(tls: TlsVersion) -> reqwest::tls::Version { use reqwest::tls::Version; match tls { @@ -250,21 +276,74 @@ impl ForgeHttpInfra { url: &Url, headers: Option, body: Bytes, - ) -> anyhow::Result { + ) -> anyhow::Result> + Send + Sync>>> { let mut request_headers = self.headers(headers); request_headers.insert("Content-Type", HeaderValue::from_static("application/json")); self.write_debug_request(&body); - self.client - .post(url.clone()) - .headers(request_headers) - .body(body) - .eventsource() - .with_context(|| format_http_context(None, "POST (EventSource)", url)) + // Build the URL string + let url_str = url.to_string(); + + // Create the client builder + let mut builder = ClientBuilder::for_url(&url_str) + .map_err(|source| EventSourceRequestError::InvalidUrl { source })? + .method("POST".to_string()) + .body(String::from_utf8_lossy(&body).to_string()) + // Disable auto-reconnect: LLM streaming endpoints end the stream + // after each response; reconnecting would send duplicate POSTs. + .reconnect(ReconnectOptions::reconnect(false).build()); + + // Add headers to the builder + for (name, value) in request_headers.iter() { + if let Ok(value_str) = value.to_str() { + builder = builder.header(name.as_str(), value_str).map_err(|source| { + EventSourceRequestError::InvalidHeader { name: name.to_string(), source } + })?; + } + } + + // Create the transport using our reqwest client + let transport = ReqwestTransport::new(self.client.clone()); + + // Build the client with our transport + let client = builder.build_with_transport(transport); + + // Return the stream with error mapping + let stream = client.stream().take_while(|result| { + let should_continue = match result { + Ok(_) => true, + Err(error) => !is_terminal_eventsource_error(error), + }; + futures::future::ready(should_continue) + }); + + Ok(Box::pin(stream.then(|result| async move { + match result { + Ok(event) => Ok(event), + Err(eventsource_client::Error::UnexpectedResponse(response, _body)) => { + let status = response.status(); + let status_display = StatusCode::from_u16(status) + .map(|status| status.to_string()) + .unwrap_or_else(|_| status.to_string()); + Err( + EventSourceRequestError::UnexpectedResponse { status: status_display } + .into(), + ) + } + Err(error) => Err(EventSourceRequestError::EventSource { source: error }.into()), + } + }))) } } +fn is_terminal_eventsource_error(error: &eventsource_client::Error) -> bool { + matches!( + error, + eventsource_client::Error::Eof | eventsource_client::Error::UnexpectedEof + ) +} + /// Helper function to format HTTP request/response context for logging and /// error reporting fn format_http_context>(status: Option, method: &str, url: U) -> String { @@ -299,7 +378,7 @@ impl HttpInfra for ForgeHttpInfra { url: &Url, headers: Option, body: Bytes, - ) -> anyhow::Result { + ) -> anyhow::Result> + Send + Sync>>> { self.eventsource(url, headers, body).await } } @@ -507,6 +586,36 @@ mod tests { assert_eq!(writes[0].1, Bytes::from(expected)); } + #[test] + fn test_is_terminal_eventsource_error_with_eof() { + let fixture = eventsource_client::Error::Eof; + + let actual = is_terminal_eventsource_error(&fixture); + + let expected = true; + assert_eq!(actual, expected); + } + + #[test] + fn test_is_terminal_eventsource_error_with_unexpected_eof() { + let fixture = eventsource_client::Error::UnexpectedEof; + + let actual = is_terminal_eventsource_error(&fixture); + + let expected = true; + assert_eq!(actual, expected); + } + + #[test] + fn test_is_terminal_eventsource_error_with_non_terminal_error() { + let fixture = eventsource_client::Error::TimedOut; + + let actual = is_terminal_eventsource_error(&fixture); + + let expected = false; + assert_eq!(actual, expected); + } + #[test] fn test_sanitize_headers_redacts_sensitive_values() { use reqwest::header::HeaderValue; diff --git a/crates/forge_infra/src/lib.rs b/crates/forge_infra/src/lib.rs index a6a726d477..096d0f998e 100644 --- a/crates/forge_infra/src/lib.rs +++ b/crates/forge_infra/src/lib.rs @@ -16,6 +16,7 @@ mod inquire; mod kv_storage; mod mcp_client; mod mcp_server; +mod transport; mod walker; pub use console::StdConsoleWriter; diff --git a/crates/forge_infra/src/transport.rs b/crates/forge_infra/src/transport.rs new file mode 100644 index 0000000000..26d9b607a2 --- /dev/null +++ b/crates/forge_infra/src/transport.rs @@ -0,0 +1,130 @@ +use bytes::Bytes; +use futures::TryStreamExt; +use launchdarkly_sdk_transport::{ByteStream, HttpTransport, ResponseFuture, TransportError}; +use reqwest::Client; + +fn to_http_headers( + headers: &reqwest::header::HeaderMap, +) -> Result { + headers + .iter() + .try_fold(http::HeaderMap::new(), |mut mapped, (name, value)| { + let header_name = http::header::HeaderName::from_bytes(name.as_str().as_bytes()) + .map_err(|error| { + TransportError::new(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Invalid response header name '{}': {}", name, error), + )) + })?; + let header_value = + http::header::HeaderValue::from_bytes(value.as_bytes()).map_err(|error| { + TransportError::new(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Invalid response header '{}': {}", name, error), + )) + })?; + mapped.insert(header_name, header_value); + Ok(mapped) + }) +} + +/// Reqwest-based HTTP transport for eventsource-client. +#[derive(Clone)] +pub struct ReqwestTransport { + client: Client, +} + +impl ReqwestTransport { + /// Create a new ReqwestTransport from a reqwest Client. + pub fn new(client: Client) -> Self { + Self { client } + } +} + +impl HttpTransport for ReqwestTransport { + fn request(&self, request: http::Request>) -> ResponseFuture { + let client = self.client.clone(); + + Box::pin(async move { + // Convert http::Request to reqwest::Request + let method = reqwest::Method::from_bytes(request.method().as_str().as_bytes()) + .map_err(|e| { + TransportError::new(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("Invalid HTTP method: {}", e), + )) + })?; + + let url = request.uri().to_string(); + let mut reqwest_request = client.request(method, &url); + + // Add headers + for (name, value) in request.headers() { + if let Ok(value_str) = value.to_str() { + reqwest_request = reqwest_request.header(name.as_str(), value_str); + } + } + + // Add body if present + if let Some(body) = request.body() { + reqwest_request = reqwest_request.body(body.clone()); + } + + // Execute the request + let response = reqwest_request.send().await.map_err(|e| { + TransportError::new(std::io::Error::other(format!("Request failed: {}", e))) + })?; + + // Convert reqwest::Response to http::Response + let status = response.status(); + let response_headers = to_http_headers(response.headers())?; + + // Create a byte stream from the response body + let byte_stream: ByteStream = Box::pin(response.bytes_stream().map_err(|e| { + TransportError::new(std::io::Error::other(format!("Stream error: {}", e))) + })); + + let mut http_response = http::Response::builder() + .status(status.as_u16()) + .body(byte_stream) + .map_err(|e| { + TransportError::new(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Failed to build response: {}", e), + )) + })?; + + *http_response.headers_mut() = response_headers; + + Ok(http_response) + }) + } +} + +#[cfg(test)] +mod tests { + use reqwest::header::{HeaderMap, HeaderValue}; + + use super::to_http_headers; + + #[test] + fn test_to_http_headers_preserves_content_type_for_sse() { + let mut fixture = HeaderMap::new(); + fixture.insert( + "content-type", + HeaderValue::from_static("text/event-stream"), + ); + fixture.insert("cache-control", HeaderValue::from_static("no-cache")); + + let actual = to_http_headers(&fixture).unwrap(); + + assert_eq!( + actual.get("content-type").unwrap(), + &http::HeaderValue::from_static("text/event-stream") + ); + assert_eq!( + actual.get("cache-control").unwrap(), + &http::HeaderValue::from_static("no-cache") + ); + } +} diff --git a/crates/forge_repo/Cargo.toml b/crates/forge_repo/Cargo.toml index 0d115c0b50..727fc05afa 100644 --- a/crates/forge_repo/Cargo.toml +++ b/crates/forge_repo/Cargo.toml @@ -26,8 +26,7 @@ reqwest.workspace = true url.workspace = true bytes.workspace = true strum.workspace = true -reqwest-eventsource.workspace = true -eventsource-stream.workspace = true +eventsource-client.workspace = true handlebars.workspace = true merge.workspace = true aws-sdk-bedrockruntime.workspace = true diff --git a/crates/forge_repo/src/forge_repo.rs b/crates/forge_repo/src/forge_repo.rs index 34d1bb8498..be82fde0ee 100644 --- a/crates/forge_repo/src/forge_repo.rs +++ b/crates/forge_repo/src/forge_repo.rs @@ -1,8 +1,10 @@ use std::collections::BTreeMap; use std::path::{Path, PathBuf}; +use std::pin::Pin; use std::sync::Arc; use bytes::Bytes; +use eventsource_client::SSE; use forge_app::{ AgentRepository, CommandInfra, DirectoryReaderInfra, EnvironmentInfra, FileDirectoryInfra, FileInfoInfra, FileReaderInfra, FileRemoverInfra, FileWriterInfra, GrpcInfra, HttpInfra, @@ -20,7 +22,7 @@ use forge_domain::{ pub use forge_infra::CacacheStorage; use reqwest::Response; use reqwest::header::HeaderMap; -use reqwest_eventsource::EventSource; +use tokio_stream::Stream; use url::Url; use crate::agent::ForgeAgentRepository; @@ -285,7 +287,7 @@ impl HttpInfra for ForgeRepo { url: &Url, headers: Option, body: Bytes, - ) -> anyhow::Result { + ) -> anyhow::Result> + Send + Sync>>> { self.infra.http_eventsource(url, headers, body).await } } diff --git a/crates/forge_repo/src/provider/anthropic.rs b/crates/forge_repo/src/provider/anthropic.rs index 3292f5ab9f..a152812ab2 100644 --- a/crates/forge_repo/src/provider/anthropic.rs +++ b/crates/forge_repo/src/provider/anthropic.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use anyhow::Context as _; -use eventsource_stream::Eventsource; use forge_app::domain::{ ChatCompletionMessage, Context, Model, ModelId, ResultStream, Transformer, }; @@ -19,6 +18,7 @@ use tracing::debug; use crate::provider::event::into_chat_completion_message; use crate::provider::retry::into_retry; +use crate::provider::sse_parser::parse_sse_stream; use crate::provider::utils::{create_headers, format_http_context}; #[derive(Clone)] @@ -188,37 +188,43 @@ impl Anthropic { } let request_url = parsed_url.clone(); - let stream = response - .bytes_stream() - .eventsource() - .filter_map(move |event_result| { + let stream = parse_sse_stream(response.bytes_stream()).filter_map( + move |event_result: anyhow::Result| { let request_url = request_url.clone(); async move { match event_result { - Ok(event) if ["[DONE]", ""].contains(&event.data.as_str()) => None, - Ok(event) => Some( - serde_json::from_str::(&event.data) - .with_context(|| { - format!("Failed to parse provider response: {}", event.data) - }) - .and_then(|response| { - ChatCompletionMessage::try_from(response).with_context(|| { - format!( - "Failed to create completion message: {}", - event.data + Ok(event) => { + let data = event.data(); + if ["[DONE]", ""].contains(&data) { + return None; + } + Some( + serde_json::from_str::(data) + .with_context(|| { + format!("Failed to parse provider response: {}", data) + }) + .and_then(|response| { + ChatCompletionMessage::try_from(response).with_context( + || { + format!( + "Failed to create completion message: {}", + data + ) + }, ) }) - }) - .with_context(|| { - format_http_context(None, "POST", request_url.clone()) - }), - ), - Err(error) => Some(Err(into_sse_parse_error(error)).with_context(|| { + .with_context(|| { + format_http_context(None, "POST", request_url.clone()) + }), + ) + } + Err(error) => Some(Err(error).with_context(|| { format_http_context(None, "POST", request_url.clone()) })), } } - }); + }, + ); Ok(Box::pin(stream)) } @@ -269,20 +275,6 @@ impl Anthropic { } } -fn into_sse_parse_error(error: eventsource_stream::EventStreamError) -> anyhow::Error -where - E: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static, -{ - let is_retryable = matches!(&error, eventsource_stream::EventStreamError::Transport(_)); - let error = anyhow::anyhow!("SSE parse error: {}", error); - - if is_retryable { - forge_domain::Error::Retryable(error).into() - } else { - error - } -} - /// Repository for Anthropic provider responses pub struct AnthropicResponseRepository { infra: Arc, @@ -371,7 +363,6 @@ mod tests { ToolResult, }; use reqwest::header::HeaderMap; - use reqwest_eventsource::EventSource; use super::*; use crate::provider::mock_server::{MockServer, normalize_ports}; @@ -420,7 +411,15 @@ mod tests { _url: &Url, _headers: Option, _body: Bytes, - ) -> anyhow::Result { + ) -> anyhow::Result< + std::pin::Pin< + Box< + dyn futures::Stream> + + Send + + Sync, + >, + >, + > { // For now, return an error since eventsource is not used in the failing tests Err(anyhow::anyhow!("EventSource not implemented in mock")) } diff --git a/crates/forge_repo/src/provider/bedrock.rs b/crates/forge_repo/src/provider/bedrock.rs index c5e9653167..18c0c54f51 100644 --- a/crates/forge_repo/src/provider/bedrock.rs +++ b/crates/forge_repo/src/provider/bedrock.rs @@ -143,19 +143,29 @@ impl BedrockProvider { } } + /// Returns true when Bedrock indicates a connection-capacity limit. + fn is_connection_limit_message(message: Option<&str>) -> bool { + message + .map(str::to_ascii_lowercase) + .is_some_and(|message| message.contains("too many connections")) + } + /// Checks if a ConverseStreamError service error is retryable fn is_retryable_converse_error( err: &aws_sdk_bedrockruntime::operation::converse_stream::ConverseStreamError, ) -> bool { use aws_sdk_bedrockruntime::operation::converse_stream::ConverseStreamError; - matches!( - err, + + match err { ConverseStreamError::ThrottlingException(_) - | ConverseStreamError::ServiceUnavailableException(_) - | ConverseStreamError::InternalServerException(_) - | ConverseStreamError::ModelStreamErrorException(_) - | ConverseStreamError::ModelNotReadyException(_) - ) + | ConverseStreamError::InternalServerException(_) + | ConverseStreamError::ModelStreamErrorException(_) + | ConverseStreamError::ModelNotReadyException(_) => true, + ConverseStreamError::ServiceUnavailableException(error) => { + !Self::is_connection_limit_message(error.message()) + } + _ => false, + } } /// Checks if a ConverseStreamOutputError service error is retryable @@ -163,13 +173,16 @@ impl BedrockProvider { err: &aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError, ) -> bool { use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError; - matches!( - err, + + match err { ConverseStreamOutputError::ThrottlingException(_) - | ConverseStreamOutputError::ServiceUnavailableException(_) - | ConverseStreamOutputError::InternalServerException(_) - | ConverseStreamOutputError::ModelStreamErrorException(_) - ) + | ConverseStreamOutputError::InternalServerException(_) + | ConverseStreamOutputError::ModelStreamErrorException(_) => true, + ConverseStreamOutputError::ServiceUnavailableException(error) => { + !Self::is_connection_limit_message(error.message()) + } + _ => false, + } } /// Checks if an SDK error is retryable based on error type (network/timeout @@ -228,26 +241,26 @@ impl BedrockProvider { _ => Self::is_retryable_sdk_error(&sdk_error), }; - // Extract the source error for better error messages - // SAFETY: into_source() always returns Ok for all SdkError variants - // (see aws-smithy-runtime-api/src/client/result.rs:448-459) - let source = sdk_error.into_source().unwrap(); + let source_error = match sdk_error.into_source() { + Ok(source) => anyhow::Error::from_boxed(source), + Err(error) => anyhow::Error::new(error), + }; if is_retryable { - forge_domain::Error::Retryable(anyhow::anyhow!("{}", source)).into() + forge_domain::Error::Retryable(source_error).into() } else { - anyhow::anyhow!("{}", source) + source_error } })?; // Convert the Bedrock event stream to ChatCompletionMessage stream - let stream = futures::stream::unfold(output.stream, |mut event_stream| async move { + let stream = futures::stream::try_unfold(output.stream, |mut event_stream| async move { match event_stream.recv().await { Ok(Some(event)) => { let message = event.into_domain(); - Some((Ok(message), event_stream)) + Ok(Some((message, event_stream))) } - Ok(None) => None, // End of stream + Ok(None) => Ok(None), // End of stream Err(stream_error) => { use aws_sdk_bedrockruntime::error::SdkError; @@ -259,16 +272,18 @@ impl BedrockProvider { _ => Self::is_retryable_sdk_error(&stream_error), }; + let source_error = match stream_error.into_source() { + Ok(source) => anyhow::Error::from_boxed(source), + Err(error) => anyhow::Error::new(error), + }; + let error = if is_retryable { - forge_domain::Error::Retryable(anyhow::anyhow!( - "Bedrock stream error: {:?}", - stream_error - )) - .into() + forge_domain::Error::Retryable(source_error).into() } else { - anyhow::anyhow!("Bedrock stream error: {:?}", stream_error) + source_error }; - Some((Err(error), event_stream)) + + Err(error) } } }); @@ -1148,6 +1163,66 @@ mod tests { assert_eq!(actual, expected); } + #[test] + fn test_retryable_converse_service_unavailable_non_capacity_message() { + use aws_sdk_bedrockruntime::operation::converse_stream::ConverseStreamError; + + let fixture = ConverseStreamError::ServiceUnavailableException( + aws_sdk_bedrockruntime::types::error::ServiceUnavailableException::builder() + .message("temporarily unavailable") + .build(), + ); + + let actual = BedrockProvider::is_retryable_converse_error(&fixture); + let expected = true; + + assert_eq!(actual, expected); + } + + #[test] + fn test_non_retryable_converse_service_unavailable_too_many_connections() { + use aws_sdk_bedrockruntime::operation::converse_stream::ConverseStreamError; + + let fixture = ConverseStreamError::ServiceUnavailableException( + aws_sdk_bedrockruntime::types::error::ServiceUnavailableException::builder() + .message("Too many connections, please wait before trying again.") + .build(), + ); + + let actual = BedrockProvider::is_retryable_converse_error(&fixture); + let expected = false; + + assert_eq!(actual, expected); + } + + #[test] + fn test_retryable_stream_output_service_unavailable_non_capacity_message() { + let fixture = aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError::ServiceUnavailableException( + aws_sdk_bedrockruntime::types::error::ServiceUnavailableException::builder() + .message("temporarily unavailable") + .build(), + ); + + let actual = BedrockProvider::is_retryable_stream_output_error(&fixture); + let expected = true; + + assert_eq!(actual, expected); + } + + #[test] + fn test_non_retryable_stream_output_service_unavailable_too_many_connections() { + let fixture = aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError::ServiceUnavailableException( + aws_sdk_bedrockruntime::types::error::ServiceUnavailableException::builder() + .message("Too many connections, please wait before trying again.") + .build(), + ); + + let actual = BedrockProvider::is_retryable_stream_output_error(&fixture); + let expected = false; + + assert_eq!(actual, expected); + } + #[test] fn test_json_value_to_document_null() { let fixture = serde_json::Value::Null; diff --git a/crates/forge_repo/src/provider/event.rs b/crates/forge_repo/src/provider/event.rs index 7ed7433267..6281677c56 100644 --- a/crates/forge_repo/src/provider/event.rs +++ b/crates/forge_repo/src/provider/event.rs @@ -1,8 +1,9 @@ use anyhow::Context; +use eventsource_client::{Event, SSE}; +use forge_app::EventSource; use forge_app::domain::ChatCompletionMessage; use forge_app::dto::openai::Error; -use reqwest::Url; -use reqwest_eventsource::{Event, EventSource}; +use reqwest::{StatusCode, Url}; use serde::de::DeserializeOwned; use tokio_stream::{Stream, StreamExt}; use tracing::debug; @@ -18,66 +19,183 @@ where ChatCompletionMessage: TryFrom, { source - .take_while(|message| !matches!(message, Err(reqwest_eventsource::Error::StreamEnded))) - .then(|event| async { + .then(move |event| { + let url = url.clone(); + async move { match event { - Ok(event) => match event { - Event::Open => None, - Event::Message(event) if ["[DONE]", ""].contains(&event.data.as_str()) => { - - debug!("Received completion from Upstream"); - None - } - Event::Message(message) => Some( - serde_json::from_str::(&message.data) - .with_context(|| { - format!( - "Failed to parse provider response: {}", - message.data - ) - }) - .and_then(|response| { - ChatCompletionMessage::try_from(response).with_context( - || { - format!( - "Failed to create completion message: {}", - message.data - ) - }, - ) - }) - ), - }, - Err(error) => match error { - reqwest_eventsource::Error::StreamEnded => None, - reqwest_eventsource::Error::InvalidStatusCode(_, response) => { - let status = response.status(); - let body = response.text().await.ok(); - Some(Err(Error::InvalidStatusCode(status.as_u16())).with_context( - || match body { - Some(body) => { - format!("{status} Reason: {body}") - } - None => { - format!("{status} Reason: [Unknown]") - } - }, - )) - } - reqwest_eventsource::Error::InvalidContentType(_, ref response) => { - let status_code = response.status(); - debug!(response = ?response, "Invalid content type"); - Some(Err(error).with_context(|| format!("Http Status: {status_code}"))) - } - error => { - tracing::error!(error = ?error, "Failed to receive chat completion event"); - Some(Err(error.into())) - } - }, + Ok(SSE::Connected(_)) => None, + Ok(SSE::Comment(_)) => None, + Ok(SSE::Event(event)) => handle_event::(event, url).await, + Err(error) => handle_error(error, url).await, } - }) - .filter_map(move |response| { - response - .map(|result| result.with_context(|| format_http_context(None, "POST", url.clone()))) - }) + } + }) + .filter_map(|response| response) +} + +async fn handle_event( + event: Event, + url: Url, +) -> Option> +where + Response: DeserializeOwned, + ChatCompletionMessage: TryFrom, +{ + // Check for completion markers + if ["[DONE]", ""].contains(&event.data.as_str()) { + debug!("Received completion from Upstream"); + return None; + } + + // Parse the JSON response + let result = serde_json::from_str::(&event.data) + .with_context(|| format!("Failed to parse provider response: {}", event.data)) + .and_then(|response| { + ChatCompletionMessage::try_from(response) + .with_context(|| format!("Failed to create completion message: {}", event.data)) + }) + .with_context(|| format_http_context(None, "POST", url)); + + Some(result) +} + +async fn handle_error( + error: anyhow::Error, + url: Url, +) -> Option> { + let error_msg = error.to_string(); + + // Check for specific error patterns from eventsource-client + // The error types are different from reqwest-eventsource + if error_msg.to_lowercase().contains("eof") || error_msg.contains("stream ended") { + return None; + } + + // Check for HTTP status errors that we might extract from the error message + // eventsource-client wraps HTTP errors differently + if error_msg.contains("UnexpectedResponse") { + let status_code = extract_status_code(&error_msg); + if let Some(status) = status_code { + let status_display = StatusCode::from_u16(status) + .map(|status| status.to_string()) + .unwrap_or_else(|_| status.to_string()); + let reason = + extract_unexpected_response_reason(&error_msg).unwrap_or_else(|| error_msg.clone()); + return Some( + Err(Error::InvalidStatusCode(status)) + .with_context(|| format!("{} Reason: {}", status_display, reason)) + .with_context(|| format_http_context(None, "POST", &url)), + ); + } + } + + tracing::error!(error = ?error, "Failed to receive chat completion event"); + Some(Err(error).with_context(|| format_http_context(None, "POST", url))) +} + +/// Extract a status code from an error message string +fn extract_status_code(error_msg: &str) -> Option { + // Look for patterns like "401 Unauthorized" or "status: 401" + use regex::Regex; + + // Try to find a 3-digit status code in the error message + let re = Regex::new(r"\b(\d{3})\b").ok()?; + if let Some(captures) = re.captures(error_msg) + && let Some(code) = captures.get(1) + { + return code.as_str().parse().ok(); + } + None +} + +fn extract_unexpected_response_reason(error_msg: &str) -> Option { + let body_marker = "body: "; + let body_start = error_msg.find(body_marker)? + body_marker.len(); + let body = error_msg.get(body_start..)?; + Some(body.trim_end_matches(')').to_string()) +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + use serde::{Deserialize, Serialize}; + use tokio_stream::StreamExt; + + use super::*; + + #[derive(Debug, Serialize)] + struct FixtureApiErrorBody { + r#type: String, + message: String, + } + + #[derive(Debug, thiserror::Error)] + enum FixtureEventSourceError { + #[error("UnexpectedResponse(status: {status}, body: {body})")] + UnexpectedResponse { + status: StatusCode, + body: serde_json::Value, + }, + #[error("eof")] + Eof, + } + + #[derive(Debug, Deserialize)] + struct FixtureResponse; + + impl TryFrom for ChatCompletionMessage { + type Error = anyhow::Error; + + fn try_from(_value: FixtureResponse) -> Result { + Ok(ChatCompletionMessage::default()) + } + } + + #[tokio::test] + async fn test_into_chat_completion_message_preserves_unexpected_response_error() { + let url = Url::parse("https://example.com/v1/chat/completions").unwrap(); + let fixture = FixtureApiErrorBody { + r#type: "error".to_string(), + message: "Subscription quota exceeded".to_string(), + }; + let reason = serde_json::to_value(fixture).unwrap(); + let source: EventSource = Box::pin(tokio_stream::iter(vec![Err( + FixtureEventSourceError::UnexpectedResponse { + status: StatusCode::TOO_MANY_REQUESTS, + body: reason.clone(), + } + .into(), + )])); + + let mut actual = Box::pin(into_chat_completion_message::(url, source)); + let error = actual + .next() + .await + .expect("stream should yield an error") + .expect_err("stream item should be an error"); + let expected = vec![ + "POST https://example.com/v1/chat/completions".to_string(), + format!("429 Too Many Requests Reason: {}", reason), + "Invalid Status Code: 429".to_string(), + ]; + let actual = error + .chain() + .map(|error| error.to_string()) + .collect::>(); + + assert_eq!(actual, expected); + } + + #[tokio::test] + async fn test_into_chat_completion_message_ignores_eof_error() { + let url = Url::parse("https://example.com/v1/chat/completions").unwrap(); + let source: EventSource = Box::pin(tokio_stream::iter(vec![Err( + FixtureEventSourceError::Eof.into(), + )])); + + let mut actual = Box::pin(into_chat_completion_message::(url, source)); + let actual = actual.next().await; + + assert!(actual.is_none()); + } } diff --git a/crates/forge_repo/src/provider/google.rs b/crates/forge_repo/src/provider/google.rs index 390f1cd7f6..b62f5911cc 100644 --- a/crates/forge_repo/src/provider/google.rs +++ b/crates/forge_repo/src/provider/google.rs @@ -230,12 +230,21 @@ mod tests { Context, ContextMessage, ToolCallFull, ToolCallId, ToolChoice, ToolName, ToolOutput, ToolResult, }; + use futures::StreamExt; use reqwest::header::HeaderMap; - use reqwest_eventsource::EventSource; use super::*; use crate::provider::mock_server::{MockServer, normalize_ports}; + #[derive(Debug, thiserror::Error)] + enum MockGoogleEventSourceError { + #[error("Mock Google EventSource stream error")] + Stream { + #[source] + source: reqwest::Error, + }, + } + // Mock implementation of HttpInfra for testing #[derive(Clone)] struct MockHttpClient { @@ -280,14 +289,58 @@ mod tests { url: &Url, headers: Option, body: Bytes, - ) -> anyhow::Result { - let mut request = self.client.post(url.clone()); + ) -> anyhow::Result< + std::pin::Pin< + Box< + dyn futures::Stream> + + Send + + Sync, + >, + >, + > { + // For tests, make an actual HTTP request and parse SSE from the response + let mut request = self.client.post(url.clone()).body(body); if let Some(headers) = headers { request = request.headers(headers); } - request = request.body(body); - let request_builder = request; - Ok(EventSource::new(request_builder).map_err(|e| anyhow::anyhow!(e))?) + let response = request.send().await?; + + // Create a stream that yields SSE events from the response + let stream = response.bytes_stream().flat_map(|result| { + match result { + Ok(bytes) => { + // Simple parsing: treat each chunk as SSE event data + if let Ok(text) = String::from_utf8(bytes.to_vec()) { + let events: Vec<_> = text + .lines() + .filter_map(|line| { + line.strip_prefix("data: ").map(|data| { + Ok(eventsource_client::SSE::Event( + eventsource_client::Event { + data: data.to_string(), + event_type: String::new(), + id: None, + retry: None, + }, + )) + }) + }) + .collect(); + futures::stream::iter(events) + } else { + futures::stream::iter(vec![]) + } + } + Err(e) => { + futures::stream::iter(vec![Err(MockGoogleEventSourceError::Stream { + source: e, + } + .into())]) + } + } + }); + + Ok(Box::pin(stream)) } } diff --git a/crates/forge_repo/src/provider/mod.rs b/crates/forge_repo/src/provider/mod.rs index cd24f07887..bd760acf24 100644 --- a/crates/forge_repo/src/provider/mod.rs +++ b/crates/forge_repo/src/provider/mod.rs @@ -12,6 +12,7 @@ mod openai_responses; mod opencode; mod provider_repo; mod retry; +mod sse_parser; mod utils; pub use chat::*; diff --git a/crates/forge_repo/src/provider/openai.rs b/crates/forge_repo/src/provider/openai.rs index 31eccd8592..a0902bd130 100644 --- a/crates/forge_repo/src/provider/openai.rs +++ b/crates/forge_repo/src/provider/openai.rs @@ -373,7 +373,6 @@ mod tests { use forge_app::domain::{Provider, ProviderId, ProviderResponse}; use forge_app::dto::openai::{ContentPart, ImageUrl, Message, MessageContent, Role}; use reqwest::header::HeaderMap; - use reqwest_eventsource::EventSource; use url::Url; use super::*; @@ -498,7 +497,15 @@ mod tests { _url: &Url, _headers: Option, _body: Bytes, - ) -> anyhow::Result { + ) -> anyhow::Result< + std::pin::Pin< + Box< + dyn futures::Stream> + + Send + + Sync, + >, + >, + > { unimplemented!() } } diff --git a/crates/forge_repo/src/provider/openai_responses/repository.rs b/crates/forge_repo/src/provider/openai_responses/repository.rs index e311d13c20..b0dbfcf4bb 100644 --- a/crates/forge_repo/src/provider/openai_responses/repository.rs +++ b/crates/forge_repo/src/provider/openai_responses/repository.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use anyhow::Context as _; use async_openai::types::responses as oai; -use eventsource_stream::Eventsource; +use eventsource_client::SSE; use forge_app::domain::{ ChatCompletionMessage, Context as ChatContext, Model, ModelId, ResultStream, }; @@ -16,6 +16,7 @@ use url::Url; use crate::provider::FromDomain; use crate::provider::retry::into_retry; +use crate::provider::sse_parser::parse_sse_stream; use crate::provider::utils::{create_headers, format_http_context}; #[derive(Clone)] @@ -187,47 +188,11 @@ impl OpenAIResponsesProvider { .await .with_context(|| format_http_context(None, "POST", &self.responses_url))?; - // Parse SSE stream into domain messages and convert to domain type - use reqwest_eventsource::Event; - let event_stream = source - .take_while(|message| { - let should_continue = - !matches!(message, Err(reqwest_eventsource::Error::StreamEnded)); - async move { should_continue } - }) - .filter_map(|event_result| async move { - match event_result { - Ok(Event::Open) => None, - Ok(Event::Message(msg)) if ["[DONE]", ""].contains(&msg.data.as_str()) => None, - Ok(Event::Message(msg)) => { - let result = serde_json::from_str::( - &msg.data, - ) - .with_context(|| format!("Failed to parse SSE event: {}", msg.data)); - - match result { - Ok(super::response::ResponsesStreamEvent::Keepalive { .. }) => None, - Ok(super::response::ResponsesStreamEvent::Ping { cost }) => { - let usage = - forge_domain::Usage { cost: Some(cost), ..Default::default() }; - Some(Ok(super::response::StreamItem::Message(Box::new( - ChatCompletionMessage::assistant(forge_domain::Content::part( - "", - )) - .usage(usage), - )))) - } - Ok(super::response::ResponsesStreamEvent::Unknown(_)) => None, - Ok(super::response::ResponsesStreamEvent::Response(inner)) => { - Some(Ok(super::response::StreamItem::Event(inner))) - } - Err(e) => Some(Err(e)), - } - } - Err(reqwest_eventsource::Error::StreamEnded) => None, - Err(e) => Some(Err(anyhow::Error::from(e))), - } - }); + let responses_url = self.responses_url.clone(); + let event_stream = source.filter_map(move |event_result: anyhow::Result| { + let responses_url = responses_url.clone(); + async move { into_response_stream_item(event_result, &responses_url) } + }); // Convert to domain messages using the existing conversion logic use crate::provider::IntoDomain; @@ -237,7 +202,7 @@ impl OpenAIResponsesProvider { /// Streams a Codex chat response by making a direct HTTP POST and /// parsing SSE from the raw byte stream, bypassing Content-Type - /// validation that `reqwest-eventsource` enforces. + /// validation that reqwest-eventsource enforced. async fn chat_codex_stream( &self, headers: reqwest::header::HeaderMap, @@ -259,21 +224,22 @@ impl OpenAIResponsesProvider { .with_context(|| format_http_context(Some(status), "POST", &self.responses_url)); } - // Parse the raw byte stream as SSE events using eventsource-stream. + // Parse the raw byte stream as SSE events using our sse_parser module. // This mirrors the AI SDK approach: TextDecoderStream -> // EventSourceParserStream -> JSON parse, without any Content-Type // requirement. let byte_stream = response.bytes_stream(); - let event_stream = byte_stream - .eventsource() - .filter_map(|event_result| async move { + let event_stream = parse_sse_stream(byte_stream).filter_map( + |event_result: anyhow::Result| async move { match event_result { - Ok(event) if ["[DONE]", ""].contains(&event.data.as_str()) => None, Ok(event) => { - let result = serde_json::from_str::( - &event.data, - ) - .with_context(|| format!("Failed to parse SSE event: {}", event.data)); + let data = event.data(); + if ["[DONE]", ""].contains(&data) { + return None; + } + let result = + serde_json::from_str::(data) + .with_context(|| format!("Failed to parse SSE event: {}", data)); match result { Ok(super::response::ResponsesStreamEvent::Keepalive { .. }) => None, Ok(super::response::ResponsesStreamEvent::Ping { cost }) => { @@ -295,7 +261,8 @@ impl OpenAIResponsesProvider { } Err(e) => Some(Err(into_sse_parse_error(e))), } - }); + }, + ); use crate::provider::IntoDomain; let stream: BoxStream = Box::pin(event_stream); @@ -303,11 +270,62 @@ impl OpenAIResponsesProvider { } } -fn into_sse_parse_error(error: eventsource_stream::EventStreamError) -> anyhow::Error +fn into_response_stream_item( + event_result: anyhow::Result, + responses_url: &Url, +) -> Option> { + match event_result { + Ok(SSE::Event(event)) => parse_response_stream_event(event), + Ok(_) => None, + Err(error) => into_eventsource_error(error, responses_url).map(Err), + } +} + +fn parse_response_stream_event( + event: eventsource_client::Event, +) -> Option> { + let data = &event.data; + if ["[DONE]", ""].contains(&data.as_str()) { + return None; + } + + let result = serde_json::from_str::(data) + .with_context(|| format!("Failed to parse SSE event: {}", data)); + + match result { + Ok(super::response::ResponsesStreamEvent::Keepalive { .. }) => None, + Ok(super::response::ResponsesStreamEvent::Ping { cost }) => { + let usage = forge_domain::Usage { cost: Some(cost), ..Default::default() }; + Some(Ok(super::response::StreamItem::Message(Box::new( + ChatCompletionMessage::assistant(forge_domain::Content::part("")).usage(usage), + )))) + } + Ok(super::response::ResponsesStreamEvent::Unknown(_)) => None, + Ok(super::response::ResponsesStreamEvent::Response(inner)) => { + Some(Ok(super::response::StreamItem::Event(inner))) + } + Err(error) => Some(Err(error)), + } +} + +fn into_eventsource_error(error: anyhow::Error, responses_url: &Url) -> Option { + let error_message = error.to_string(); + let error_message_lowercase = error_message.to_lowercase(); + + if error_message_lowercase.contains("eof") || error_message_lowercase.contains("stream ended") { + return None; + } + + Some(error.context(format_http_context(None, "POST", responses_url))) +} + +fn into_sse_parse_error(error: E) -> anyhow::Error where E: std::fmt::Debug + std::fmt::Display + Send + Sync + 'static, { - let is_retryable = matches!(&error, eventsource_stream::EventStreamError::Transport(_)); + // Check if the error is retryable based on the display string + let error_str = format!("{}", error); + let is_retryable = error_str.contains("transport") || error_str.contains("network"); let error = anyhow::anyhow!("SSE parse error: {}", error); if is_retryable { @@ -439,6 +457,23 @@ mod tests { use super::*; use crate::provider::mock_server::MockServer; + #[derive(Debug, thiserror::Error)] + enum MockResponsesEventSourceError { + #[error("Mock Responses EventSource stream error")] + Stream { + #[source] + source: reqwest::Error, + }, + } + + #[derive(Debug, thiserror::Error)] + enum FixtureStreamError { + #[error("transport connection failed")] + TransportConnectionFailed, + #[error("EventSource error: eof")] + EventSourceEof, + } + fn is_retryable(error: &anyhow::Error) -> bool { error .downcast_ref::() @@ -511,12 +546,80 @@ mod tests { url: &reqwest::Url, headers: Option, body: bytes::Bytes, - ) -> anyhow::Result { + ) -> anyhow::Result< + std::pin::Pin< + Box< + dyn futures::Stream> + + Send + + Sync, + >, + >, + > { + use futures::StreamExt; + + // For tests, make an actual HTTP request and parse SSE from the response let mut request = self.client.post(url.clone()).body(body); if let Some(headers) = headers { request = request.headers(headers); } - Ok(reqwest_eventsource::EventSource::new(request)?) + let response = request.send().await?; + + // Create a stream that yields SSE events from the response + let stream = + response.bytes_stream().flat_map(|result| { + match result { + Ok(bytes) => { + // Simple parsing: treat each chunk as SSE event data + if let Ok(text) = String::from_utf8(bytes.to_vec()) { + let mut events = Vec::new(); + let mut current_event = String::new(); + let mut current_event_type = String::new(); + + for line in text.lines() { + if line.starts_with("event: ") { + current_event_type = + line.strip_prefix("event: ").unwrap().to_string(); + } else if let Some(data) = line.strip_prefix("data: ") { + current_event = data.to_string(); + } else if line.is_empty() && !current_event.is_empty() { + // End of event, yield it + events.push(Ok(eventsource_client::SSE::Event( + eventsource_client::Event { + data: current_event.clone(), + event_type: current_event_type.clone(), + id: None, + retry: None, + }, + ))); + current_event.clear(); + current_event_type.clear(); + } + } + + // Handle last event if no trailing newline + if !current_event.is_empty() { + events.push(Ok(eventsource_client::SSE::Event( + eventsource_client::Event { + data: current_event, + event_type: current_event_type, + id: None, + retry: None, + }, + ))); + } + + futures::stream::iter(events) + } else { + futures::stream::iter(vec![]) + } + } + Err(e) => futures::stream::iter(vec![Err( + MockResponsesEventSourceError::Stream { source: e }.into(), + )]), + } + }); + + Ok(Box::pin(stream)) } } @@ -931,29 +1034,44 @@ mod tests { #[test] fn test_into_sse_parse_error_marks_transport_errors_retryable() { - let error = into_sse_parse_error(eventsource_stream::EventStreamError::Transport( - anyhow::anyhow!("error decoding response body"), - )); + // Create a custom error with "transport" in the message + let custom_error = FixtureStreamError::TransportConnectionFailed; + let error = into_sse_parse_error(custom_error); assert!(is_retryable(&error)); - assert_eq!( - error.to_string(), - "SSE parse error: Transport error: error decoding response body" - ); + assert!(error.to_string().contains("transport")); } #[test] - fn test_into_sse_parse_error_keeps_utf8_errors_non_retryable() { - let error = - into_sse_parse_error(eventsource_stream::EventStreamError::::Utf8( - String::from_utf8(vec![0xFF]).unwrap_err(), - )); + fn test_into_sse_parse_error_keeps_eof_errors_non_retryable() { + // Eof error should be non-retryable + let error = into_sse_parse_error(eventsource_client::Error::Eof); assert!(!is_retryable(&error)); - assert_eq!( - error.to_string(), - "SSE parse error: UTF8 error: invalid utf-8 sequence of 1 bytes from index 0" + // The error message contains "eof" from the Display impl + assert!(error.to_string().contains("eof")); + } + + #[test] + fn test_into_eventsource_error_ignores_eof_errors() { + let fixture = FixtureStreamError::EventSourceEof.into(); + let actual = into_eventsource_error( + fixture, + &Url::parse("https://example.com/v1/responses").unwrap(), + ); + + assert!(actual.is_none()); + } + + #[test] + fn test_into_response_stream_item_ignores_eof_events() { + let fixture = Err(FixtureStreamError::EventSourceEof.into()); + let actual = into_response_stream_item( + fixture, + &Url::parse("https://example.com/v1/responses").unwrap(), ); + + assert!(actual.is_none()); } #[test] diff --git a/crates/forge_repo/src/provider/provider_repo.rs b/crates/forge_repo/src/provider/provider_repo.rs index d17c1c25c5..1d51b2043c 100644 --- a/crates/forge_repo/src/provider/provider_repo.rs +++ b/crates/forge_repo/src/provider/provider_repo.rs @@ -923,7 +923,15 @@ mod env_tests { _url: &reqwest::Url, _headers: Option, _body: bytes::Bytes, - ) -> anyhow::Result { + ) -> anyhow::Result< + std::pin::Pin< + Box< + dyn futures::Stream> + + Send + + Sync, + >, + >, + > { Err(anyhow::anyhow!("HTTP not implemented in mock")) } } @@ -1409,7 +1417,15 @@ mod env_tests { _url: &reqwest::Url, _headers: Option, _body: bytes::Bytes, - ) -> anyhow::Result { + ) -> anyhow::Result< + std::pin::Pin< + Box< + dyn futures::Stream> + + Send + + Sync, + >, + >, + > { Err(anyhow::anyhow!("HTTP not implemented in mock")) } } diff --git a/crates/forge_repo/src/provider/retry.rs b/crates/forge_repo/src/provider/retry.rs index b90d9b0c08..dbec092956 100644 --- a/crates/forge_repo/src/provider/retry.rs +++ b/crates/forge_repo/src/provider/retry.rs @@ -7,7 +7,7 @@ const TRANSPORT_ERROR_CODES: [&str; 3] = ["ERR_STREAM_PREMATURE_CLOSE", "ECONNRE pub fn into_retry(error: anyhow::Error, retry_config: &RetryConfig) -> anyhow::Error { if let Some(code) = get_req_status_code(&error) - .or(get_event_req_status_code(&error)) + .or(get_sse_status_code(&error)) .or(get_api_status_code(&error)) && retry_config.status_codes.contains(&code) { @@ -16,7 +16,7 @@ pub fn into_retry(error: anyhow::Error, retry_config: &RetryConfig) -> anyhow::E if is_api_transport_error(&error) || is_req_transport_error(&error) - || is_event_transport_error(&error) + || is_sse_transport_error(&error) || is_empty_error(&error) || is_anthropic_overloaded_error(&error) { @@ -51,18 +51,47 @@ fn get_req_status_code(error: &anyhow::Error) -> Option { .map(|status| status.as_u16()) } -fn get_event_req_status_code(error: &anyhow::Error) -> Option { - error - .downcast_ref::() - .and_then(|error| match error { - reqwest_eventsource::Error::InvalidStatusCode(_, response) => { - Some(response.status().as_u16()) - } - reqwest_eventsource::Error::InvalidContentType(_, response) => { - Some(response.status().as_u16()) +/// Extract status code from eventsource-client errors +/// Handles UnexpectedResponse and other error types from eventsource-client +fn get_sse_status_code(error: &anyhow::Error) -> Option { + // Check if this is an eventsource-client error + if let Some(error_str) = error.downcast_ref::() { + // Try to extract status code from error message + // Format is often: "UnexpectedResponse(status: XXX, ...)" + if error_str.contains("UnexpectedResponse") { + return extract_status_from_message(error_str); + } + } + + // Check in the error chain for HTTP-related errors + let error_msg = error.to_string(); + if error_msg.contains("UnexpectedResponse") || error_msg.contains("status") { + return extract_status_from_message(&error_msg); + } + + None +} + +/// Extract status code from error message text +fn extract_status_from_message(msg: &str) -> Option { + // Look for patterns like "status: 401" or "401" in error messages + let patterns = ["status: ", "status code ", "HTTP ", "("]; + for pattern in patterns { + if let Some(pos) = msg.find(pattern) { + let after_pattern = msg.get(pos + pattern.len()..).unwrap_or(""); + // Try to parse a number after the pattern + let num_str: String = after_pattern + .chars() + .take_while(|c| c.is_ascii_digit()) + .collect(); + if let Ok(status) = num_str.parse::() + && (100..=599).contains(&status) + { + return Some(status); } - _ => None, - }) + } + } + None } fn has_transport_error_code(error: &ErrorResponse) -> bool { @@ -109,10 +138,28 @@ fn is_req_transport_error(error: &anyhow::Error) -> bool { .is_some_and(|e| e.is_timeout() || e.is_connect() || e.is_request()) } -fn is_event_transport_error(error: &anyhow::Error) -> bool { - error - .downcast_ref::() - .is_some_and(|e| matches!(e, reqwest_eventsource::Error::Transport(_))) +/// Check if error is an SSE transport error from eventsource-client +/// Checks for network/connection related errors in the error message +fn is_sse_transport_error(error: &anyhow::Error) -> bool { + let error_msg = error.to_string(); + + // Check for transport-related keywords in error message + let transport_keywords = [ + "transport", + "network", + "connection", + "EOF", + "stream ended", + "UnexpectedResponse", + "io error", + "broken pipe", + "connection reset", + "timeout", + ]; + + transport_keywords + .iter() + .any(|kw| error_msg.to_lowercase().contains(kw)) } #[cfg(test)] @@ -299,10 +346,8 @@ mod tests { let error = anyhow!("Generic error"); assert!(!is_api_transport_error(&error)); assert!(!is_req_transport_error(&error)); - assert!(!is_event_transport_error(&error)); assert!(get_api_status_code(&error).is_none()); assert!(get_req_status_code(&error).is_none()); - assert!(get_event_req_status_code(&error).is_none()); } #[test] diff --git a/crates/forge_repo/src/provider/sse_parser.rs b/crates/forge_repo/src/provider/sse_parser.rs new file mode 100644 index 0000000000..8bea7011b1 --- /dev/null +++ b/crates/forge_repo/src/provider/sse_parser.rs @@ -0,0 +1,182 @@ +use std::pin::Pin; +use std::task::{Context, Poll}; + +use bytes::Bytes; +use futures::Stream; + +/// A simple SSE event parsed from a byte stream. +#[derive(Debug, Clone)] +pub struct SSEEvent { + #[allow(dead_code)] + pub event_type: Option, + pub data: String, + #[allow(dead_code)] + pub id: Option, +} + +impl SSEEvent { + /// Get the event data. + pub fn data(&self) -> &str { + &self.data + } +} + +#[derive(Debug, thiserror::Error)] +enum SseParserError +where + E: std::error::Error + Send + Sync + 'static, +{ + #[error("Invalid UTF-8 in SSE stream")] + InvalidUtf8 { + #[source] + source: std::string::FromUtf8Error, + }, + #[error("SSE stream read error")] + Stream { + #[source] + source: E, + }, +} + +/// A stream adapter that parses Server-Sent Events from a bytes stream. +/// +/// This is a simple SSE parser for cases where we need to parse raw byte +/// streams (e.g., for providers that don't return proper Content-Type headers). +pub struct BytesToSSE { + inner: S, + buffer: String, + _phantom: std::marker::PhantomData, +} + +impl BytesToSSE +where + S: Stream>, +{ + /// Create a new BytesToSSE stream from a bytes stream. + pub fn new(inner: S) -> Self { + Self { + inner, + buffer: String::new(), + _phantom: std::marker::PhantomData, + } + } +} + +// Implement Unpin when the inner stream is Unpin +impl Unpin for BytesToSSE {} + +impl Stream for BytesToSSE +where + S: Stream> + Unpin, + E: std::error::Error + Send + Sync + 'static, +{ + type Item = anyhow::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + // Try to find a complete event in the buffer (double newline marks end of + // event) + if let Some(pos) = self.buffer.find("\n\n") { + let event_text = self.buffer.get(..pos).unwrap_or("").to_string(); + self.buffer = self.buffer.get(pos + 2..).unwrap_or("").to_string(); + + // Parse the event + let mut event_type = None; + let mut data = String::new(); + let mut id = None; + + for line in event_text.lines() { + if let Some(colon_pos) = line.find(':') { + let field = line.get(..colon_pos).unwrap_or(""); + let value = line.get(colon_pos + 1..).unwrap_or("").trim_start(); + + match field { + "event" => event_type = Some(value.to_string()), + "data" => { + if !data.is_empty() { + data.push('\n'); + } + data.push_str(value); + } + "id" => id = Some(value.to_string()), + _ => {} // Ignore unknown fields + } + } else if !line.is_empty() { + // Line without colon is treated as data (some servers send this) + if !data.is_empty() { + data.push('\n'); + } + data.push_str(line); + } + } + + return Poll::Ready(Some(Ok(SSEEvent { event_type, data, id }))); + } + + // Need more data - poll the inner stream + match Pin::new(&mut self.inner).poll_next(cx) { + Poll::Ready(Some(Ok(bytes))) => match String::from_utf8(bytes.to_vec()) { + Ok(text) => { + self.buffer.push_str(&text); + } + Err(e) => { + return Poll::Ready(Some(Err(SseParserError::::InvalidUtf8 { + source: e, + } + .into()))); + } + }, + Poll::Ready(Some(Err(e))) => { + return Poll::Ready(Some( + Err(SseParserError::::Stream { source: e }.into()), + )); + } + Poll::Ready(None) => { + // Stream ended - if there's remaining data, return it as an event + if !self.buffer.trim().is_empty() { + let remaining = std::mem::take(&mut self.buffer); + // Try to parse the remaining buffer + let mut event_type = None; + let mut data = String::new(); + let mut id = None; + + for line in remaining.lines() { + if let Some(colon_pos) = line.find(':') { + let field = line.get(..colon_pos).unwrap_or(""); + let value = line.get(colon_pos + 1..).unwrap_or("").trim_start(); + + match field { + "event" => event_type = Some(value.to_string()), + "data" => { + if !data.is_empty() { + data.push('\n'); + } + data.push_str(value); + } + "id" => id = Some(value.to_string()), + _ => {} + } + } + } + + if !data.is_empty() || event_type.is_some() { + return Poll::Ready(Some(Ok(SSEEvent { event_type, data, id }))); + } + } + return Poll::Ready(None); + } + Poll::Pending => { + return Poll::Pending; + } + } + } + } +} + +/// Parse a byte stream as Server-Sent Events. +pub fn parse_sse_stream(stream: S) -> BytesToSSE +where + S: Stream> + Unpin, +{ + BytesToSSE::new(stream) +} diff --git a/crates/forge_services/Cargo.toml b/crates/forge_services/Cargo.toml index 6784647880..11e5e640bf 100644 --- a/crates/forge_services/Cargo.toml +++ b/crates/forge_services/Cargo.toml @@ -43,7 +43,7 @@ merge.workspace = true strip-ansi-escapes.workspace = true forge_app.workspace = true url.workspace = true -reqwest-eventsource.workspace = true +eventsource-client.workspace = true lazy_static = "1.5.0" forge_domain.workspace = true forge_config.workspace = true