From 05ac0e01154f1778446d6955bc72765f1d14bccc Mon Sep 17 00:00:00 2001 From: stephen Date: Fri, 17 Apr 2026 16:56:14 -0700 Subject: [PATCH 1/9] tests and assistant choice fix --- rig/rig-core/src/completion/mod.rs | 1 + rig/rig-core/src/completion/request.rs | 12 + .../src/providers/anthropic/completion.rs | 10 +- .../providers/anthropic/conformance_tests.rs | 260 ++++++++++++++++ rig/rig-core/src/providers/conformance.rs | 286 ++++++++++++++++++ .../src/providers/gemini/completion.rs | 19 +- .../src/providers/gemini/conformance_tests.rs | 264 ++++++++++++++++ rig/rig-core/src/providers/mod.rs | 2 + .../openai/completion/conformance_tests.rs | 267 ++++++++++++++++ .../src/providers/openai/completion/mod.rs | 9 +- .../src/providers/openai/responses_api/mod.rs | 6 +- 11 files changed, 1110 insertions(+), 26 deletions(-) create mode 100644 rig/rig-core/src/providers/anthropic/conformance_tests.rs create mode 100644 rig/rig-core/src/providers/conformance.rs create mode 100644 rig/rig-core/src/providers/gemini/conformance_tests.rs create mode 100644 rig/rig-core/src/providers/openai/completion/conformance_tests.rs diff --git a/rig/rig-core/src/completion/mod.rs b/rig/rig-core/src/completion/mod.rs index 7dcef0ac6..83c7d1508 100644 --- a/rig/rig-core/src/completion/mod.rs +++ b/rig/rig-core/src/completion/mod.rs @@ -2,4 +2,5 @@ pub mod message; pub mod request; pub use message::{AssistantContent, Message, MessageError}; +pub(crate) use request::assistant_choice_from_vec; pub use request::*; diff --git a/rig/rig-core/src/completion/request.rs b/rig/rig-core/src/completion/request.rs index 1dfd07690..250d751d3 100644 --- a/rig/rig-core/src/completion/request.rs +++ b/rig/rig-core/src/completion/request.rs @@ -374,6 +374,18 @@ pub struct CompletionResponse { pub message_id: Option, } +pub(crate) fn assistant_choice_from_vec( + content: Vec, +) -> Result, CompletionError> { + if content.is_empty() { + return Ok(OneOrMany::one(AssistantContent::text(""))); + } + + OneOrMany::many(content).map_err(|_| { + CompletionError::ResponseError("Response contained no message or tool call (empty)".into()) + }) +} + /// A trait for grabbing the token usage of a completion response. /// /// Primarily designed for streamed completion responses in streamed multi-turn, as otherwise it would be impossible to do. diff --git a/rig/rig-core/src/providers/anthropic/completion.rs b/rig/rig-core/src/providers/anthropic/completion.rs index 442f38b22..832052a22 100644 --- a/rig/rig-core/src/providers/anthropic/completion.rs +++ b/rig/rig-core/src/providers/anthropic/completion.rs @@ -214,11 +214,7 @@ impl TryFrom for completion::CompletionResponse, _>>()?; - let choice = OneOrMany::many(content).map_err(|_| { - CompletionError::ResponseError( - "Response contained no message or tool call (empty)".to_owned(), - ) - })?; + let choice = completion::assistant_choice_from_vec(content)?; let usage = completion::Usage { input_tokens: response.usage.input_tokens, @@ -1490,6 +1486,10 @@ enum ApiResponse { Error(ApiErrorResponse), } +#[cfg(test)] +#[path = "conformance_tests.rs"] +mod conformance_tests; + #[cfg(test)] mod tests { use super::*; diff --git a/rig/rig-core/src/providers/anthropic/conformance_tests.rs b/rig/rig-core/src/providers/anthropic/conformance_tests.rs new file mode 100644 index 000000000..d8b6b7d38 --- /dev/null +++ b/rig/rig-core/src/providers/anthropic/conformance_tests.rs @@ -0,0 +1,260 @@ +use bytes::Bytes; +use serde_json::json; + +use super::*; +use crate::{ + OneOrMany, + completion::{self, CompletionError}, + http_client::mock::MockStreamingClient, + providers::conformance::{ + BoxFuture, Fixture, Harness, NormalizedItem, Outcome, StopReason, Turn, drain_stream, + normalize_completion_response, provider_conformance_tests, + }, +}; + +struct AnthropicHarness; + +impl Harness for AnthropicHarness { + fn family_name() -> &'static str { + "anthropic-messages" + } + + fn expected(case: Fixture) -> Outcome { + match case { + Fixture::EmptyAssistantTurnAfterToolResult => Outcome::Supported(Turn { + items: vec![], + message_id: None, + stop_reason: Some(StopReason::EndTurn), + }), + Fixture::ToolOnlyTurn => Outcome::Supported(Turn { + items: vec![NormalizedItem::ToolCall { + id: "toolu_lookup".to_string(), + name: "lookup_weather".to_string(), + arguments: json!({"city": "Paris"}), + }], + message_id: None, + stop_reason: Some(StopReason::ToolCalls), + }), + Fixture::TextAndToolCallTurn => Outcome::Supported(Turn { + items: vec![ + NormalizedItem::Text("Need weather data first.".to_string()), + NormalizedItem::ToolCall { + id: "toolu_lookup".to_string(), + name: "lookup_weather".to_string(), + arguments: json!({"city": "Paris"}), + }, + ], + message_id: None, + stop_reason: Some(StopReason::ToolCalls), + }), + Fixture::EmptyTextBlocks => Outcome::Supported(Turn { + items: vec![], + message_id: None, + stop_reason: Some(StopReason::EndTurn), + }), + Fixture::ReasoningOnlyTurn => Outcome::Supported(Turn { + items: vec![NormalizedItem::Reasoning( + "Need to reason about the tool result.".to_string(), + )], + message_id: None, + stop_reason: Some(StopReason::EndTurn), + }), + Fixture::MessageIdPreservation => { + Outcome::Unsupported("Anthropic Messages responses do not expose message IDs") + } + Fixture::StopReasonMapping => Outcome::Supported(Turn { + items: vec![NormalizedItem::Text("Truncated response".to_string())], + message_id: None, + stop_reason: Some(StopReason::MaxTokens), + }), + } + } + + fn non_stream(case: Fixture) -> Result, CompletionError> { + match case { + Fixture::MessageIdPreservation => Ok(Self::expected(case)), + _ => { + let raw = non_stream_response(case); + let stop_reason = raw.stop_reason.as_deref().map(map_stop_reason); + let response: completion::CompletionResponse = + raw.try_into()?; + Ok(Outcome::Supported(normalize_completion_response( + &response, + stop_reason, + ))) + } + } + } + + fn stream(case: Fixture) -> BoxFuture, CompletionError>> { + Box::pin(async move { + match case { + Fixture::MessageIdPreservation => Ok(Self::expected(case)), + _ => { + let client = crate::providers::anthropic::Client::builder() + .http_client(MockStreamingClient { + sse_bytes: Bytes::from(streaming_sse(case)), + }) + .api_key("test-key") + .build() + .expect("client should build"); + let model = CompletionModel::new(client, "claude-test"); + let stream = model.stream(stream_request()).await?; + let response = drain_stream(stream).await?; + Ok(Outcome::Supported(normalize_completion_response( + &response, None, + ))) + } + } + }) + } +} + +fn non_stream_response(case: Fixture) -> CompletionResponse { + match case { + Fixture::EmptyAssistantTurnAfterToolResult => response_with_content(vec![], "end_turn"), + Fixture::ToolOnlyTurn => response_with_content( + vec![Content::ToolUse { + id: "toolu_lookup".to_string(), + name: "lookup_weather".to_string(), + input: json!({"city": "Paris"}), + }], + "tool_use", + ), + Fixture::TextAndToolCallTurn => response_with_content( + vec![ + Content::Text { + text: "Need weather data first.".to_string(), + cache_control: None, + }, + Content::ToolUse { + id: "toolu_lookup".to_string(), + name: "lookup_weather".to_string(), + input: json!({"city": "Paris"}), + }, + ], + "tool_use", + ), + Fixture::EmptyTextBlocks => response_with_content( + vec![Content::Text { + text: String::new(), + cache_control: None, + }], + "end_turn", + ), + Fixture::ReasoningOnlyTurn => response_with_content( + vec![Content::Thinking { + thinking: "Need to reason about the tool result.".to_string(), + signature: Some("sig_1".to_string()), + }], + "end_turn", + ), + Fixture::StopReasonMapping => response_with_content( + vec![Content::Text { + text: "Truncated response".to_string(), + cache_control: None, + }], + "max_tokens", + ), + Fixture::MessageIdPreservation => unreachable!(), + } +} + +fn response_with_content(content: Vec, stop_reason: &str) -> CompletionResponse { + CompletionResponse { + content, + id: "msg_123".to_string(), + model: "claude-test".to_string(), + role: "assistant".to_string(), + stop_reason: Some(stop_reason.to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 10, + cache_read_input_tokens: None, + cache_creation_input_tokens: None, + output_tokens: 5, + }, + } +} + +fn map_stop_reason(reason: &str) -> StopReason { + match reason { + "end_turn" => StopReason::EndTurn, + "tool_use" => StopReason::ToolCalls, + "max_tokens" => StopReason::MaxTokens, + other => StopReason::Other(other.to_string()), + } +} + +fn stream_request() -> completion::CompletionRequest { + completion::CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(completion::Message::user("hello")), + documents: vec![], + tools: vec![], + temperature: None, + max_tokens: Some(32), + tool_choice: None, + additional_params: None, + output_schema: None, + } +} + +fn streaming_sse(case: Fixture) -> String { + match case { + Fixture::EmptyAssistantTurnAfterToolResult => concat!( + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-test\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"cache_read_input_tokens\":null,\"cache_creation_input_tokens\":null,\"output_tokens\":0}}}\n\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"output_tokens\":0}}\n\n", + "data: {\"type\":\"message_stop\"}\n\n", + ) + .to_string(), + Fixture::ToolOnlyTurn => concat!( + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-test\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"cache_read_input_tokens\":null,\"cache_creation_input_tokens\":null,\"output_tokens\":0}}}\n\n", + "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_lookup\",\"name\":\"lookup_weather\",\"input\":{\"city\":\"Paris\"}}}\n\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\":\\\"Paris\\\"}\"}}\n\n", + "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"output_tokens\":5}}\n\n", + "data: {\"type\":\"message_stop\"}\n\n", + ) + .to_string(), + Fixture::TextAndToolCallTurn => concat!( + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-test\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"cache_read_input_tokens\":null,\"cache_creation_input_tokens\":null,\"output_tokens\":0}}}\n\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Need weather data first.\"}}\n\n", + "data: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_lookup\",\"name\":\"lookup_weather\",\"input\":{\"city\":\"Paris\"}}}\n\n", + "data: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\":\\\"Paris\\\"}\"}}\n\n", + "data: {\"type\":\"content_block_stop\",\"index\":1}\n\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"output_tokens\":7}}\n\n", + "data: {\"type\":\"message_stop\"}\n\n", + ) + .to_string(), + Fixture::EmptyTextBlocks => concat!( + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-test\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"cache_read_input_tokens\":null,\"cache_creation_input_tokens\":null,\"output_tokens\":0}}}\n\n", + "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n", + "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"output_tokens\":0}}\n\n", + "data: {\"type\":\"message_stop\"}\n\n", + ) + .to_string(), + Fixture::ReasoningOnlyTurn => concat!( + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-test\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"cache_read_input_tokens\":null,\"cache_creation_input_tokens\":null,\"output_tokens\":0}}}\n\n", + "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\",\"signature\":null}}\n\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"Need to reason about the tool result.\"}}\n\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"signature_delta\",\"signature\":\"sig_1\"}}\n\n", + "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"output_tokens\":5}}\n\n", + "data: {\"type\":\"message_stop\"}\n\n", + ) + .to_string(), + Fixture::StopReasonMapping => concat!( + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-test\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"cache_read_input_tokens\":null,\"cache_creation_input_tokens\":null,\"output_tokens\":0}}}\n\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Truncated response\"}}\n\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"max_tokens\",\"stop_sequence\":null},\"usage\":{\"output_tokens\":2}}\n\n", + "data: {\"type\":\"message_stop\"}\n\n", + ) + .to_string(), + Fixture::MessageIdPreservation => unreachable!(), + } +} + +provider_conformance_tests!(AnthropicHarness); diff --git a/rig/rig-core/src/providers/conformance.rs b/rig/rig-core/src/providers/conformance.rs new file mode 100644 index 000000000..5267809e1 --- /dev/null +++ b/rig/rig-core/src/providers/conformance.rs @@ -0,0 +1,286 @@ +use std::{future::Future, pin::Pin}; + +use futures::StreamExt; +use serde_json::Value; + +use crate::{ + OneOrMany, + completion::{self, CompletionError, GetTokenUsage}, + message::AssistantContent, +}; + +#[derive(Debug, Clone, Copy)] +pub(crate) enum Fixture { + EmptyAssistantTurnAfterToolResult, + ToolOnlyTurn, + TextAndToolCallTurn, + EmptyTextBlocks, + ReasoningOnlyTurn, + MessageIdPreservation, + StopReasonMapping, +} + +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum NormalizedItem { + Text(String), + ToolCall { + id: String, + name: String, + arguments: Value, + }, + Reasoning(String), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum StopReason { + Stop, + ToolCalls, + EndTurn, + MaxTokens, + ContentFilter, + Safety, + Other(String), +} + +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct Turn { + pub(crate) items: Vec, + pub(crate) message_id: Option, + pub(crate) stop_reason: Option, +} + +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum Outcome { + Supported(T), + Unsupported(&'static str), +} + +pub(crate) type BoxFuture = Pin + Send>>; + +pub(crate) trait Harness { + fn family_name() -> &'static str; + + fn expected(case: Fixture) -> Outcome; + + fn non_stream(case: Fixture) -> Result, CompletionError>; + + fn stream(case: Fixture) -> BoxFuture, CompletionError>>; +} + +pub(crate) fn normalize_turn( + choice: &OneOrMany, + message_id: Option, + stop_reason: Option, +) -> Turn { + let mut items = Vec::new(); + + for item in choice.iter() { + let normalized = match item { + AssistantContent::Text(text) if !text.text.is_empty() => { + Some(NormalizedItem::Text(text.text.clone())) + } + AssistantContent::ToolCall(tool_call) => Some(NormalizedItem::ToolCall { + id: tool_call.id.clone(), + name: tool_call.function.name.clone(), + arguments: tool_call.function.arguments.clone(), + }), + AssistantContent::Reasoning(reasoning) => { + let text = reasoning.display_text(); + if text.is_empty() { + None + } else { + Some(NormalizedItem::Reasoning(text)) + } + } + _ => None, + }; + + if let Some(normalized) = normalized { + let duplicate_reasoning = matches!( + (&normalized, items.last()), + (NormalizedItem::Reasoning(current), Some(NormalizedItem::Reasoning(previous))) + if current == previous + ); + + if !duplicate_reasoning { + items.push(normalized); + } + } + } + + Turn { + items, + message_id, + stop_reason, + } +} + +pub(crate) fn normalize_completion_response( + response: &completion::CompletionResponse, + stop_reason: Option, +) -> Turn { + normalize_turn(&response.choice, response.message_id.clone(), stop_reason) +} + +pub(crate) async fn drain_stream( + mut stream: crate::streaming::StreamingCompletionResponse, +) -> Result>, CompletionError> +where + R: Clone + Unpin + GetTokenUsage, +{ + while let Some(item) = stream.next().await { + item?; + } + + Ok(stream.into()) +} + +pub(crate) fn assert_non_stream_case(case: Fixture) { + let expected = H::expected(case); + let actual = H::non_stream(case) + .unwrap_or_else(|err| panic!("{} non-stream {:?} failed: {err}", H::family_name(), case)); + + assert_eq!( + actual, + expected, + "{} non-stream {:?} mismatch", + H::family_name(), + case + ); +} + +pub(crate) async fn assert_stream_matches_non_stream(case: Fixture) { + let non_stream = H::non_stream(case) + .unwrap_or_else(|err| panic!("{} non-stream {:?} failed: {err}", H::family_name(), case)); + let stream = H::stream(case) + .await + .unwrap_or_else(|err| panic!("{} stream {:?} failed: {err}", H::family_name(), case)); + + match (non_stream, stream) { + (Outcome::Supported(expected), Outcome::Supported(actual)) => { + assert_eq!( + actual.items, + expected.items, + "{} stream {:?} items diverged", + H::family_name(), + case + ); + assert_eq!( + actual.message_id, + expected.message_id, + "{} stream {:?} message_id diverged", + H::family_name(), + case + ); + } + (Outcome::Unsupported(_), Outcome::Unsupported(_)) => {} + (expected, actual) => panic!( + "{} stream/non-stream support mismatch for {:?}: non-stream={expected:?}, stream={actual:?}", + H::family_name(), + case + ), + } +} + +macro_rules! provider_conformance_tests { + ($harness:ty) => { + #[test] + fn conformance_empty_assistant_turn_after_tool_result_non_stream() { + crate::providers::conformance::assert_non_stream_case::<$harness>( + crate::providers::conformance::Fixture::EmptyAssistantTurnAfterToolResult, + ); + } + + #[test] + fn conformance_tool_only_turn_non_stream() { + crate::providers::conformance::assert_non_stream_case::<$harness>( + crate::providers::conformance::Fixture::ToolOnlyTurn, + ); + } + + #[test] + fn conformance_text_and_tool_call_turn_non_stream() { + crate::providers::conformance::assert_non_stream_case::<$harness>( + crate::providers::conformance::Fixture::TextAndToolCallTurn, + ); + } + + #[test] + fn conformance_empty_text_blocks_non_stream() { + crate::providers::conformance::assert_non_stream_case::<$harness>( + crate::providers::conformance::Fixture::EmptyTextBlocks, + ); + } + + #[test] + fn conformance_reasoning_only_turn_non_stream() { + crate::providers::conformance::assert_non_stream_case::<$harness>( + crate::providers::conformance::Fixture::ReasoningOnlyTurn, + ); + } + + #[test] + fn conformance_message_id_preservation_non_stream() { + crate::providers::conformance::assert_non_stream_case::<$harness>( + crate::providers::conformance::Fixture::MessageIdPreservation, + ); + } + + #[test] + fn conformance_stop_reason_mapping_non_stream() { + crate::providers::conformance::assert_non_stream_case::<$harness>( + crate::providers::conformance::Fixture::StopReasonMapping, + ); + } + + #[tokio::test] + async fn conformance_empty_assistant_turn_after_tool_result_stream_equivalence() { + crate::providers::conformance::assert_stream_matches_non_stream::<$harness>( + crate::providers::conformance::Fixture::EmptyAssistantTurnAfterToolResult, + ) + .await; + } + + #[tokio::test] + async fn conformance_tool_only_turn_stream_equivalence() { + crate::providers::conformance::assert_stream_matches_non_stream::<$harness>( + crate::providers::conformance::Fixture::ToolOnlyTurn, + ) + .await; + } + + #[tokio::test] + async fn conformance_text_and_tool_call_turn_stream_equivalence() { + crate::providers::conformance::assert_stream_matches_non_stream::<$harness>( + crate::providers::conformance::Fixture::TextAndToolCallTurn, + ) + .await; + } + + #[tokio::test] + async fn conformance_empty_text_blocks_stream_equivalence() { + crate::providers::conformance::assert_stream_matches_non_stream::<$harness>( + crate::providers::conformance::Fixture::EmptyTextBlocks, + ) + .await; + } + + #[tokio::test] + async fn conformance_reasoning_only_turn_stream_equivalence() { + crate::providers::conformance::assert_stream_matches_non_stream::<$harness>( + crate::providers::conformance::Fixture::ReasoningOnlyTurn, + ) + .await; + } + + #[tokio::test] + async fn conformance_message_id_preservation_stream_equivalence() { + crate::providers::conformance::assert_stream_matches_non_stream::<$harness>( + crate::providers::conformance::Fixture::MessageIdPreservation, + ) + .await; + } + }; +} + +pub(crate) use provider_conformance_tests; diff --git a/rig/rig-core/src/providers/gemini/completion.rs b/rig/rig-core/src/providers/gemini/completion.rs index 80d908848..52ec7a9f8 100644 --- a/rig/rig-core/src/providers/gemini/completion.rs +++ b/rig/rig-core/src/providers/gemini/completion.rs @@ -27,15 +27,12 @@ use self::gemini_api_types::Schema; use crate::http_client::HttpClientExt; use crate::message::{self, MimeType, Reasoning}; +use crate::completion::{self, CompletionError, CompletionRequest}; use crate::providers::gemini::completion::gemini_api_types::{ AdditionalParameters, FunctionCallingMode, ToolConfig, }; use crate::providers::gemini::streaming::StreamingCompletionResponse; use crate::telemetry::SpanCombinator; -use crate::{ - OneOrMany, - completion::{self, CompletionError, CompletionRequest}, -}; use gemini_api_types::{ Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, GenerationConfig, Part, PartKind, Role, Tool, @@ -493,11 +490,7 @@ impl TryFrom for completion::CompletionResponse, _>>()?; - let choice = OneOrMany::many(content).map_err(|_| { - CompletionError::ResponseError( - "Response contained no message or tool call (empty)".to_owned(), - ) - })?; + let choice = completion::assistant_choice_from_vec(content)?; let usage = response .usage_metadata @@ -2022,6 +2015,10 @@ pub mod gemini_api_types { } } +#[cfg(test)] +#[path = "conformance_tests.rs"] +mod conformance_tests; + #[cfg(test)] mod tests { use crate::{ @@ -2301,7 +2298,7 @@ mod tests { fn test_reasoning_signature_is_emitted_in_gemini_part() { let msg = message::Message::Assistant { id: None, - content: OneOrMany::one(message::AssistantContent::Reasoning( + content: crate::OneOrMany::one(message::AssistantContent::Reasoning( message::Reasoning::new_with_signature( "structured thought", Some("reuse_sig_456".to_string()), @@ -2334,7 +2331,7 @@ mod tests { let msg = message::Message::Assistant { id: None, - content: OneOrMany::one(message::AssistantContent::ToolCall(tool_call)), + content: crate::OneOrMany::one(message::AssistantContent::ToolCall(tool_call)), }; let content: Content = msg.try_into().unwrap(); diff --git a/rig/rig-core/src/providers/gemini/conformance_tests.rs b/rig/rig-core/src/providers/gemini/conformance_tests.rs new file mode 100644 index 000000000..fe73bbeb9 --- /dev/null +++ b/rig/rig-core/src/providers/gemini/conformance_tests.rs @@ -0,0 +1,264 @@ +use bytes::Bytes; +use serde_json::json; + +use super::*; +use crate::{ + OneOrMany, + completion::{self, CompletionError}, + http_client::mock::MockStreamingClient, + providers::conformance::{ + BoxFuture, Fixture, Harness, NormalizedItem, Outcome, StopReason, Turn, drain_stream, + normalize_completion_response, provider_conformance_tests, + }, +}; + +use super::gemini_api_types::{Content, ContentCandidate, FinishReason, Part, PartKind, Role}; + +struct GeminiHarness; + +impl Harness for GeminiHarness { + fn family_name() -> &'static str { + "gemini-generate-content" + } + + fn expected(case: Fixture) -> Outcome { + match case { + Fixture::EmptyAssistantTurnAfterToolResult => Outcome::Supported(Turn { + items: vec![], + message_id: None, + stop_reason: Some(StopReason::Stop), + }), + Fixture::ToolOnlyTurn => Outcome::Supported(Turn { + items: vec![NormalizedItem::ToolCall { + id: "lookup_weather".to_string(), + name: "lookup_weather".to_string(), + arguments: json!({"city": "Paris"}), + }], + message_id: None, + stop_reason: Some(StopReason::Stop), + }), + Fixture::TextAndToolCallTurn => Outcome::Supported(Turn { + items: vec![ + NormalizedItem::Text("Need weather data first.".to_string()), + NormalizedItem::ToolCall { + id: "lookup_weather".to_string(), + name: "lookup_weather".to_string(), + arguments: json!({"city": "Paris"}), + }, + ], + message_id: None, + stop_reason: Some(StopReason::Stop), + }), + Fixture::EmptyTextBlocks => Outcome::Supported(Turn { + items: vec![], + message_id: None, + stop_reason: Some(StopReason::Stop), + }), + Fixture::ReasoningOnlyTurn => Outcome::Supported(Turn { + items: vec![NormalizedItem::Reasoning( + "Need to reason about the tool result.".to_string(), + )], + message_id: None, + stop_reason: Some(StopReason::Stop), + }), + Fixture::MessageIdPreservation => { + Outcome::Unsupported("Gemini GenerateContent responses do not expose message IDs") + } + Fixture::StopReasonMapping => Outcome::Supported(Turn { + items: vec![NormalizedItem::Text("Truncated response".to_string())], + message_id: None, + stop_reason: Some(StopReason::MaxTokens), + }), + } + } + + fn non_stream(case: Fixture) -> Result, CompletionError> { + match case { + Fixture::MessageIdPreservation => Ok(Self::expected(case)), + _ => { + let raw = non_stream_response(case); + let stop_reason = raw + .candidates + .first() + .and_then(|candidate| candidate.finish_reason.clone()) + .map(map_finish_reason); + let response: completion::CompletionResponse = + raw.try_into()?; + Ok(Outcome::Supported(normalize_completion_response( + &response, + stop_reason, + ))) + } + } + } + + fn stream(case: Fixture) -> BoxFuture, CompletionError>> { + Box::pin(async move { + match case { + Fixture::MessageIdPreservation => Ok(Self::expected(case)), + _ => { + let client = crate::providers::gemini::Client::builder() + .http_client(MockStreamingClient { + sse_bytes: Bytes::from(streaming_sse(case)), + }) + .api_key("test-key") + .build() + .expect("client should build"); + let model = CompletionModel::new(client, "gemini-test"); + let stream = model.stream(stream_request()).await?; + let response = drain_stream(stream).await?; + Ok(Outcome::Supported(normalize_completion_response( + &response, None, + ))) + } + } + }) + } +} + +fn non_stream_response(case: Fixture) -> GenerateContentResponse { + let candidate = match case { + Fixture::EmptyAssistantTurnAfterToolResult => candidate(vec![], FinishReason::Stop), + Fixture::ToolOnlyTurn => candidate( + vec![Part { + thought: None, + thought_signature: None, + part: PartKind::FunctionCall(gemini_api_types::FunctionCall { + name: "lookup_weather".to_string(), + args: json!({"city": "Paris"}), + }), + additional_params: None, + }], + FinishReason::Stop, + ), + Fixture::TextAndToolCallTurn => candidate( + vec![ + Part { + thought: Some(false), + thought_signature: None, + part: PartKind::Text("Need weather data first.".to_string()), + additional_params: None, + }, + Part { + thought: None, + thought_signature: None, + part: PartKind::FunctionCall(gemini_api_types::FunctionCall { + name: "lookup_weather".to_string(), + args: json!({"city": "Paris"}), + }), + additional_params: None, + }, + ], + FinishReason::Stop, + ), + Fixture::EmptyTextBlocks => candidate( + vec![Part { + thought: Some(false), + thought_signature: None, + part: PartKind::Text(String::new()), + additional_params: None, + }], + FinishReason::Stop, + ), + Fixture::ReasoningOnlyTurn => candidate( + vec![Part { + thought: Some(true), + thought_signature: Some("sig_1".to_string()), + part: PartKind::Text("Need to reason about the tool result.".to_string()), + additional_params: None, + }], + FinishReason::Stop, + ), + Fixture::StopReasonMapping => candidate( + vec![Part { + thought: Some(false), + thought_signature: None, + part: PartKind::Text("Truncated response".to_string()), + additional_params: None, + }], + FinishReason::MaxTokens, + ), + Fixture::MessageIdPreservation => unreachable!(), + }; + + GenerateContentResponse { + response_id: "resp_123".to_string(), + candidates: vec![candidate], + prompt_feedback: None, + usage_metadata: None, + model_version: None, + } +} + +fn candidate(parts: Vec, finish_reason: FinishReason) -> ContentCandidate { + ContentCandidate { + content: Some(Content { + parts, + role: Some(Role::Model), + }), + finish_reason: Some(finish_reason), + safety_ratings: None, + citation_metadata: None, + token_count: None, + avg_logprobs: None, + logprobs_result: None, + index: Some(0), + finish_message: None, + } +} + +fn map_finish_reason(reason: FinishReason) -> StopReason { + match reason { + FinishReason::Stop => StopReason::Stop, + FinishReason::MaxTokens => StopReason::MaxTokens, + FinishReason::Safety => StopReason::Safety, + other => StopReason::Other(format!("{other:?}")), + } +} + +fn stream_request() -> completion::CompletionRequest { + completion::CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(completion::Message::user("hello")), + documents: vec![], + tools: vec![], + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + } +} + +fn streaming_sse(case: Fixture) -> String { + match case { + Fixture::EmptyAssistantTurnAfterToolResult => concat!( + "data: {\"candidates\":[{\"content\":{\"parts\":[],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":0,\"totalTokenCount\":10}}\n\n", + ) + .to_string(), + Fixture::ToolOnlyTurn => concat!( + "data: {\"candidates\":[{\"content\":{\"parts\":[{\"functionCall\":{\"name\":\"lookup_weather\",\"args\":{\"city\":\"Paris\"}}}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":5,\"totalTokenCount\":15}}\n\n", + ) + .to_string(), + Fixture::TextAndToolCallTurn => concat!( + "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Need weather data first.\"},{\"functionCall\":{\"name\":\"lookup_weather\",\"args\":{\"city\":\"Paris\"}}}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":7,\"totalTokenCount\":17}}\n\n", + ) + .to_string(), + Fixture::EmptyTextBlocks => concat!( + "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"\"}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":0,\"totalTokenCount\":10}}\n\n", + ) + .to_string(), + Fixture::ReasoningOnlyTurn => concat!( + "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Need to reason about the tool result.\",\"thought\":true,\"thoughtSignature\":\"sig_1\"}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":5,\"thoughtsTokenCount\":3,\"totalTokenCount\":18}}\n\n", + ) + .to_string(), + Fixture::StopReasonMapping => concat!( + "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Truncated response\"}],\"role\":\"model\"},\"finishReason\":\"MAX_TOKENS\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":2,\"totalTokenCount\":12}}\n\n", + ) + .to_string(), + Fixture::MessageIdPreservation => unreachable!(), + } +} + +provider_conformance_tests!(GeminiHarness); diff --git a/rig/rig-core/src/providers/mod.rs b/rig/rig-core/src/providers/mod.rs index 24e9bece3..f1f825e99 100644 --- a/rig/rig-core/src/providers/mod.rs +++ b/rig/rig-core/src/providers/mod.rs @@ -49,6 +49,8 @@ pub mod anthropic; pub mod azure; pub mod chatgpt; pub mod cohere; +#[cfg(test)] +pub(crate) mod conformance; pub mod copilot; pub mod deepseek; pub mod galadriel; diff --git a/rig/rig-core/src/providers/openai/completion/conformance_tests.rs b/rig/rig-core/src/providers/openai/completion/conformance_tests.rs new file mode 100644 index 000000000..a5795591e --- /dev/null +++ b/rig/rig-core/src/providers/openai/completion/conformance_tests.rs @@ -0,0 +1,267 @@ +use bytes::Bytes; +use serde_json::json; + +use super::*; +use crate::{ + completion::{self, CompletionError}, + http_client::mock::MockStreamingClient, + providers::conformance::{ + BoxFuture, Fixture, Harness, NormalizedItem, Outcome, StopReason, Turn, drain_stream, + normalize_completion_response, provider_conformance_tests, + }, +}; + +struct OpenAiChatHarness; + +impl Harness for OpenAiChatHarness { + fn family_name() -> &'static str { + "openai-chat" + } + + fn expected(case: Fixture) -> Outcome { + match case { + Fixture::EmptyAssistantTurnAfterToolResult => Outcome::Supported(Turn { + items: vec![], + message_id: None, + stop_reason: Some(StopReason::Stop), + }), + Fixture::ToolOnlyTurn => Outcome::Supported(Turn { + items: vec![NormalizedItem::ToolCall { + id: "call_lookup".to_string(), + name: "lookup_weather".to_string(), + arguments: json!({"city": "Paris"}), + }], + message_id: None, + stop_reason: Some(StopReason::ToolCalls), + }), + Fixture::TextAndToolCallTurn => Outcome::Supported(Turn { + items: vec![ + NormalizedItem::Text("Need weather data first.".to_string()), + NormalizedItem::ToolCall { + id: "call_lookup".to_string(), + name: "lookup_weather".to_string(), + arguments: json!({"city": "Paris"}), + }, + ], + message_id: None, + stop_reason: Some(StopReason::ToolCalls), + }), + Fixture::EmptyTextBlocks => Outcome::Supported(Turn { + items: vec![], + message_id: None, + stop_reason: Some(StopReason::Stop), + }), + Fixture::ReasoningOnlyTurn => Outcome::Unsupported( + "OpenAI chat completions do not expose normalized reasoning blocks", + ), + Fixture::MessageIdPreservation => { + Outcome::Unsupported("OpenAI chat completions do not expose message IDs") + } + Fixture::StopReasonMapping => Outcome::Supported(Turn { + items: vec![NormalizedItem::Text("Truncated response".to_string())], + message_id: None, + stop_reason: Some(StopReason::MaxTokens), + }), + } + } + + fn non_stream(case: Fixture) -> Result, CompletionError> { + match case { + Fixture::ReasoningOnlyTurn => Ok(Self::expected(case)), + Fixture::MessageIdPreservation => Ok(Self::expected(case)), + _ => { + let raw = non_stream_response(case); + let stop_reason = raw + .choices + .first() + .map(|choice| map_finish_reason(&choice.finish_reason)); + let response: completion::CompletionResponse = + raw.try_into()?; + Ok(Outcome::Supported(normalize_completion_response( + &response, + stop_reason, + ))) + } + } + } + + fn stream(case: Fixture) -> BoxFuture, CompletionError>> { + Box::pin(async move { + match case { + Fixture::ReasoningOnlyTurn => Ok(Self::expected(case)), + Fixture::MessageIdPreservation => Ok(Self::expected(case)), + _ => { + let client = MockStreamingClient { + sse_bytes: Bytes::from(streaming_sse(case)), + }; + let request = http::Request::builder() + .method("POST") + .uri("http://localhost/v1/chat/completions") + .body(Vec::new()) + .expect("request should build"); + let stream = + streaming::send_compatible_streaming_request(client, request).await?; + let response = drain_stream(stream).await?; + Ok(Outcome::Supported(normalize_completion_response( + &response, None, + ))) + } + } + }) + } +} + +fn non_stream_response(case: Fixture) -> CompletionResponse { + match case { + Fixture::EmptyAssistantTurnAfterToolResult => response_with_message( + Message::Assistant { + content: vec![], + refusal: None, + audio: None, + name: None, + tool_calls: vec![], + }, + "stop", + ), + Fixture::ToolOnlyTurn => response_with_message( + Message::Assistant { + content: vec![], + refusal: None, + audio: None, + name: None, + tool_calls: vec![tool_call( + "call_lookup", + "lookup_weather", + json!({"city": "Paris"}), + )], + }, + "tool_calls", + ), + Fixture::TextAndToolCallTurn => response_with_message( + Message::Assistant { + content: vec![AssistantContent::Text { + text: "Need weather data first.".to_string(), + }], + refusal: None, + audio: None, + name: None, + tool_calls: vec![tool_call( + "call_lookup", + "lookup_weather", + json!({"city": "Paris"}), + )], + }, + "tool_calls", + ), + Fixture::EmptyTextBlocks => response_with_message( + Message::Assistant { + content: vec![AssistantContent::Text { + text: String::new(), + }], + refusal: None, + audio: None, + name: None, + tool_calls: vec![], + }, + "stop", + ), + Fixture::StopReasonMapping => response_with_message( + Message::Assistant { + content: vec![AssistantContent::Text { + text: "Truncated response".to_string(), + }], + refusal: None, + audio: None, + name: None, + tool_calls: vec![], + }, + "length", + ), + Fixture::ReasoningOnlyTurn | Fixture::MessageIdPreservation => { + unreachable!("unsupported cases are handled before construction") + } + } +} + +fn response_with_message(message: Message, finish_reason: &str) -> CompletionResponse { + CompletionResponse { + id: "chatcmpl_123".to_string(), + object: "chat.completion".to_string(), + created: 0, + model: "gpt-4o-mini".to_string(), + system_fingerprint: None, + choices: vec![Choice { + index: 0, + message, + logprobs: None, + finish_reason: finish_reason.to_string(), + }], + usage: Some(Usage { + prompt_tokens: 10, + total_tokens: 15, + prompt_tokens_details: None, + }), + } +} + +fn tool_call(id: &str, name: &str, arguments: serde_json::Value) -> ToolCall { + ToolCall { + id: id.to_string(), + r#type: ToolType::Function, + function: Function { + name: name.to_string(), + arguments, + }, + } +} + +fn map_finish_reason(reason: &str) -> StopReason { + match reason { + "stop" => StopReason::Stop, + "tool_calls" => StopReason::ToolCalls, + "content_filter" => StopReason::ContentFilter, + "length" => StopReason::MaxTokens, + other => StopReason::Other(other.to_string()), + } +} + +fn streaming_sse(case: Fixture) -> String { + match case { + Fixture::EmptyAssistantTurnAfterToolResult => concat!( + "data: {\"choices\":[{\"delta\":{\"content\":\"\",\"tool_calls\":[]},\"finish_reason\":\"stop\"}],\"usage\":null}\n\n", + "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":0,\"total_tokens\":10}}\n\n", + "data: [DONE]\n\n", + ) + .to_string(), + Fixture::ToolOnlyTurn => concat!( + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_lookup\",\"function\":{\"name\":\"lookup_weather\",\"arguments\":\"{\\\"city\\\":\\\"Paris\\\"}\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n", + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}\n\n", + "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":5,\"total_tokens\":15}}\n\n", + "data: [DONE]\n\n", + ) + .to_string(), + Fixture::TextAndToolCallTurn => concat!( + "data: {\"choices\":[{\"delta\":{\"content\":\"Need weather data first.\",\"tool_calls\":[]},\"finish_reason\":null}],\"usage\":null}\n\n", + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_lookup\",\"function\":{\"name\":\"lookup_weather\",\"arguments\":\"{\\\"city\\\":\\\"Paris\\\"}\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n", + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}\n\n", + "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":7,\"total_tokens\":17}}\n\n", + "data: [DONE]\n\n", + ) + .to_string(), + Fixture::EmptyTextBlocks => concat!( + "data: {\"choices\":[{\"delta\":{\"content\":\"\",\"tool_calls\":[]},\"finish_reason\":\"stop\"}],\"usage\":null}\n\n", + "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":0,\"total_tokens\":10}}\n\n", + "data: [DONE]\n\n", + ) + .to_string(), + Fixture::StopReasonMapping => concat!( + "data: {\"choices\":[{\"delta\":{\"content\":\"Truncated response\",\"tool_calls\":[]},\"finish_reason\":\"length\"}],\"usage\":null}\n\n", + "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":2,\"total_tokens\":12}}\n\n", + "data: [DONE]\n\n", + ) + .to_string(), + Fixture::ReasoningOnlyTurn | Fixture::MessageIdPreservation => unreachable!(), + } +} + +provider_conformance_tests!(OpenAiChatHarness); diff --git a/rig/rig-core/src/providers/openai/completion/mod.rs b/rig/rig-core/src/providers/openai/completion/mod.rs index a7763d7c3..43b9cf64d 100644 --- a/rig/rig-core/src/providers/openai/completion/mod.rs +++ b/rig/rig-core/src/providers/openai/completion/mod.rs @@ -836,11 +836,7 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse>::from) .collect(); - let choice = OneOrMany::many(content).map_err(|_| { - CompletionError::ResponseError( - "Response contained no message or tool call (empty)".to_owned(), - ) - })?; + let choice = completion::assistant_choice_from_vec(content)?; let usage = response .usage From b36e9dabafe4a196475dde795bd83fdf68f59e70 Mon Sep 17 00:00:00 2001 From: stephen Date: Fri, 17 Apr 2026 17:08:59 -0700 Subject: [PATCH 2/9] helpers added --- rig/rig-core/src/agent/prompt_request/mod.rs | 25 ++-- .../src/agent/prompt_request/streaming.rs | 107 ++++++++++++++++-- .../src/agent/prompt_request/turns.rs | 84 ++++++++++++++ .../tests/core/prompt_response_messages.rs | 96 ++++++++++++++++ 4 files changed, 286 insertions(+), 26 deletions(-) create mode 100644 rig/rig-core/src/agent/prompt_request/turns.rs diff --git a/rig/rig-core/src/agent/prompt_request/mod.rs b/rig/rig-core/src/agent/prompt_request/mod.rs index 4951d5b59..3dbae4e34 100644 --- a/rig/rig-core/src/agent/prompt_request/mod.rs +++ b/rig/rig-core/src/agent/prompt_request/mod.rs @@ -1,5 +1,6 @@ pub mod hooks; pub mod streaming; +mod turns; use super::{ Agent, @@ -26,6 +27,7 @@ use std::{ }; use tracing::info_span; use tracing::{Instrument, span::Id}; +use turns::{AssistantTextAccumulator, AssistantTurnSummary}; pub trait PromptType {} pub struct Standard; @@ -339,6 +341,7 @@ where let mut current_max_turns = 0; let mut usage = Usage::new(); + let mut assistant_text_accumulator = AssistantTextAccumulator::default(); let current_span_id: AtomicU64 = AtomicU64::new(0); // We need to do at least 2 loops for 1 roundtrip (user expects normal message) @@ -447,7 +450,11 @@ where )); } - let (tool_calls, texts): (Vec<_>, Vec<_>) = resp + let turn_summary = AssistantTurnSummary::from_choice(&resp.choice); + let turn_text = turn_summary.visible_text("\n"); + assistant_text_accumulator.observe(&turn_text); + + let (tool_calls, _texts): (Vec<_>, Vec<_>) = resp .choice .iter() .partition(|choice| matches!(choice, AssistantContent::ToolCall(_))); @@ -458,23 +465,13 @@ where }); if tool_calls.is_empty() { - let merged_texts = texts - .into_iter() - .filter_map(|content| { - if let AssistantContent::Text(text) = content { - Some(text.text.clone()) - } else { - None - } - }) - .collect::>() - .join("\n"); + let final_output = assistant_text_accumulator.final_output(&turn_text); if self.max_turns > 1 { tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns); } - agent_span.record("gen_ai.completion", &merged_texts); + agent_span.record("gen_ai.completion", &final_output); agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens); agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens); agent_span.record( @@ -486,7 +483,7 @@ where usage.cache_creation_input_tokens, ); - return Ok(PromptResponse::new(merged_texts, usage).with_messages(new_messages)); + return Ok(PromptResponse::new(final_output, usage).with_messages(new_messages)); } let hook = self.hook.clone(); diff --git a/rig/rig-core/src/agent/prompt_request/streaming.rs b/rig/rig-core/src/agent/prompt_request/streaming.rs index adcf766a9..6691ea849 100644 --- a/rig/rig-core/src/agent/prompt_request/streaming.rs +++ b/rig/rig-core/src/agent/prompt_request/streaming.rs @@ -16,6 +16,7 @@ use tracing::info_span; use tracing_futures::Instrument; use super::ToolCallHookAction; +use super::turns::{AssistantTextAccumulator, AssistantTurnSummary}; use crate::{ agent::Agent, completion::{CompletionError, CompletionModel, PromptError}, @@ -170,16 +171,6 @@ fn tool_result_to_user_message( } } -fn assistant_text_from_choice(choice: &OneOrMany) -> String { - choice - .iter() - .filter_map(|content| match content { - AssistantContent::Text(text) => Some(text.text.as_str()), - _ => None, - }) - .collect() -} - #[derive(Debug, thiserror::Error)] pub enum StreamingError { #[error("CompletionError: {0}")] @@ -398,6 +389,7 @@ where let output_schema = self.output_schema; let mut aggregated_usage = crate::completion::Usage::new(); + let mut assistant_text_accumulator = AssistantTextAccumulator::default(); // NOTE: We use .instrument(agent_span) instead of span.enter() to avoid // span context leaking to other concurrent tasks. Using span.enter() inside @@ -659,7 +651,9 @@ where accumulated_reasoning.push(assembled); } - let turn_text_response = assistant_text_from_choice(&stream.choice); + let turn_summary = AssistantTurnSummary::from_choice(&stream.choice); + let turn_text_response = turn_summary.visible_text(""); + assistant_text_accumulator.observe(&turn_text_response); tracing::Span::current().record("gen_ai.completion", &turn_text_response); // Add text, reasoning, and tool calls to chat history. @@ -692,9 +686,18 @@ where } if !saw_tool_call_this_turn { + let final_response_text = + assistant_text_accumulator.final_output(&turn_text_response); + // Add user message and assistant response to history before finishing if !turn_text_response.is_empty() { new_messages.push(Message::assistant(&turn_text_response)); + } else if !final_response_text.is_empty() { + tracing::info!( + agent_name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME), + message_id = ?stream.message_id, + "Streaming turn completed without assistant text; using prior turn text for final response" + ); } else { tracing::warn!( agent_name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME), @@ -715,7 +718,7 @@ where None }; yield Ok(MultiTurnStreamItem::final_response_with_history( - &turn_text_response, + &final_response_text, aggregated_usage, final_messages, )); @@ -1224,6 +1227,86 @@ mod tests { assert_eq!(final_response_text.as_deref(), Some("")); } + #[derive(Clone, Default)] + struct StreamingEmptyTerminalTurnFallbackModel { + turn_counter: Arc, + } + + #[allow(refining_impl_trait)] + impl CompletionModel for StreamingEmptyTerminalTurnFallbackModel { + type Response = (); + type StreamingResponse = MockStreamingResponse; + type Client = (); + + fn make(_: &Self::Client, _: impl Into) -> Self { + Self::default() + } + + async fn completion( + &self, + _request: CompletionRequest, + ) -> Result, CompletionError> { + Err(CompletionError::ProviderError( + "completion is unused in this streaming test".to_string(), + )) + } + + async fn stream( + &self, + _request: CompletionRequest, + ) -> Result, CompletionError> { + let turn = self.turn_counter.fetch_add(1, Ordering::SeqCst); + let stream = async_stream::stream! { + if turn == 0 { + yield Ok(RawStreamingChoice::Message("The answer is 5".to_string())); + yield Ok(RawStreamingChoice::ToolCall( + RawStreamingToolCall::new( + "tool_call_2".to_string(), + "calculator".to_string(), + serde_json::json!({"op": "add", "a": 2, "b": 3}), + ), + )); + yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(2))); + } else { + yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(1))); + } + }; + + let pinned_stream: crate::streaming::StreamingResult = + Box::pin(stream); + Ok(StreamingCompletionResponse::stream(pinned_stream)) + } + } + + #[tokio::test] + async fn final_response_falls_back_to_prior_assistant_text_when_terminal_turn_is_empty() { + let model = StreamingEmptyTerminalTurnFallbackModel::default(); + let turn_counter = model.turn_counter.clone(); + let agent = AgentBuilder::new(model).build(); + + let mut stream = agent.stream_prompt("What is 2 + 3?").multi_turn(3).await; + let mut streamed_text = String::new(); + let mut final_response_text = None; + + while let Some(item) = stream.next().await { + match item { + Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text( + text, + ))) => streamed_text.push_str(&text.text), + Ok(MultiTurnStreamItem::FinalResponse(res)) => { + final_response_text = Some(res.response().to_owned()); + break; + } + Ok(_) => {} + Err(err) => panic!("unexpected streaming error: {err:?}"), + } + } + + assert_eq!(streamed_text, "The answer is 5"); + assert_eq!(final_response_text.as_deref(), Some("The answer is 5")); + assert_eq!(turn_counter.load(Ordering::SeqCst), 2); + } + /// Background task that logs periodically to detect span leakage. /// If span leakage occurs, these logs will be prefixed with `invoke_agent{...}`. async fn background_logger(stop: Arc, leak_count: Arc) { diff --git a/rig/rig-core/src/agent/prompt_request/turns.rs b/rig/rig-core/src/agent/prompt_request/turns.rs new file mode 100644 index 000000000..79c86c90e --- /dev/null +++ b/rig/rig-core/src/agent/prompt_request/turns.rs @@ -0,0 +1,84 @@ +//! Shared helpers for deriving user-visible assistant text across agent turns. + +use crate::{OneOrMany, message::AssistantContent}; + +/// Summary of the visible assistant text in a normalized turn. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub(crate) struct AssistantTurnSummary { + visible_text_blocks: Vec, +} + +impl AssistantTurnSummary { + /// Extract non-empty text blocks from a normalized assistant choice. + pub(crate) fn from_choice(choice: &OneOrMany) -> Self { + let visible_text_blocks = choice + .iter() + .filter_map(|content| match content { + AssistantContent::Text(text) if !text.text.is_empty() => Some(text.text.clone()), + _ => None, + }) + .collect(); + + Self { + visible_text_blocks, + } + } + + /// Render the visible text blocks using the caller's preferred separator. + pub(crate) fn visible_text(&self, separator: &str) -> String { + self.visible_text_blocks.join(separator) + } +} + +/// Tracks non-empty assistant text from earlier turns so a textless final turn +/// can still return the last user-visible answer the model produced. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub(crate) struct AssistantTextAccumulator { + prior_turn_texts: Vec, +} + +impl AssistantTextAccumulator { + /// Record a turn's visible text when it is not empty. + pub(crate) fn observe(&mut self, turn_text: &str) { + if !turn_text.is_empty() { + self.prior_turn_texts.push(turn_text.to_owned()); + } + } + + /// Prefer the current turn's visible text and otherwise fall back to the + /// accumulated text from earlier turns in the same request. + pub(crate) fn final_output(&self, current_turn_text: &str) -> String { + if current_turn_text.is_empty() { + self.prior_turn_texts.join("\n") + } else { + current_turn_text.to_owned() + } + } +} + +#[cfg(test)] +mod tests { + use super::{AssistantTextAccumulator, AssistantTurnSummary}; + use crate::{OneOrMany, message::AssistantContent}; + + #[test] + fn summary_ignores_empty_text_blocks() { + let choice = OneOrMany::many(vec![ + AssistantContent::text("visible"), + AssistantContent::text(""), + ]) + .expect("non-empty assistant choice"); + + let summary = AssistantTurnSummary::from_choice(&choice); + + assert_eq!(summary.visible_text("\n"), "visible"); + } + + #[test] + fn accumulator_falls_back_when_terminal_turn_is_empty() { + let mut accumulator = AssistantTextAccumulator::default(); + accumulator.observe("first turn"); + + assert_eq!(accumulator.final_output(""), "first turn"); + } +} diff --git a/rig/rig-core/tests/core/prompt_response_messages.rs b/rig/rig-core/tests/core/prompt_response_messages.rs index cdf94a0df..5302a088d 100644 --- a/rig/rig-core/tests/core/prompt_response_messages.rs +++ b/rig/rig-core/tests/core/prompt_response_messages.rs @@ -137,6 +137,87 @@ impl CompletionModel for ToolThenTextModel { } } +/// A mock model that emits text and a tool call, then ends with an empty +/// assistant turn. This exercises fallback to prior visible assistant text. +#[derive(Clone)] +struct ToolThenEmptyTerminalTurnModel { + turn: Arc, +} + +impl ToolThenEmptyTerminalTurnModel { + fn new() -> Self { + Self { + turn: Arc::new(AtomicUsize::new(0)), + } + } +} + +#[allow(refining_impl_trait)] +impl CompletionModel for ToolThenEmptyTerminalTurnModel { + type Response = (); + type StreamingResponse = (); + type Client = (); + + fn make(_: &Self::Client, _: impl Into) -> Self { + Self::new() + } + + async fn completion( + &self, + _request: CompletionRequest, + ) -> Result, CompletionError> { + let turn = self.turn.fetch_add(1, Ordering::SeqCst); + + if turn == 0 { + Ok(CompletionResponse { + choice: OneOrMany::many(vec![ + AssistantContent::Text(Text { + text: "The answer is 5".to_string(), + }), + AssistantContent::ToolCall(ToolCall::new( + "tc_2".to_string(), + ToolFunction::new( + "calculator".to_string(), + serde_json::json!({"op": "add", "a": 2, "b": 3}), + ), + )), + ]) + .expect("assistant turn has text and a tool call"), + usage: Usage { + input_tokens: 12, + output_tokens: 7, + total_tokens: 19, + cached_input_tokens: 0, + cache_creation_input_tokens: 0, + }, + raw_response: (), + message_id: Some("msg_text_tool".to_string()), + }) + } else { + Ok(CompletionResponse { + choice: OneOrMany::one(AssistantContent::text("")), + usage: Usage { + input_tokens: 8, + output_tokens: 1, + total_tokens: 9, + cached_input_tokens: 0, + cache_creation_input_tokens: 0, + }, + raw_response: (), + message_id: Some("msg_empty".to_string()), + }) + } + } + + async fn stream( + &self, + _request: CompletionRequest, + ) -> Result, CompletionError> { + let stream: StreamingResult<()> = Box::pin(futures::stream::empty()); + Ok(StreamingCompletionResponse::stream(stream)) + } +} + /// A mock model that always returns tool calls, never text. /// Used to test the MaxTurnsError path. #[derive(Clone)] @@ -348,6 +429,21 @@ async fn multi_turn_messages_include_tool_calls() { assert_eq!(resp.usage.output_tokens, 12); // 8 + 4 } +/// Test 6: If the terminal assistant turn is empty after a prior visible text +/// turn, the agent should fall back to the earlier assistant text. +#[tokio::test] +async fn multi_turn_prompt_falls_back_to_prior_assistant_text_when_terminal_turn_is_empty() { + let agent = AgentBuilder::new(ToolThenEmptyTerminalTurnModel::new()).build(); + + let result = agent + .prompt("What is 2 + 3?") + .max_turns(5) + .await + .expect("prompt should succeed"); + + assert_eq!(result, "The answer is 5"); +} + /// Test 6: `PromptResponse::new()` backward compatibility — 2-argument constructor /// should still work, and `messages` should be `None`. #[tokio::test] From 2cad9e9506e0ce7a262d535b2f4fc9ce4ac810f8 Mon Sep 17 00:00:00 2001 From: stephen Date: Fri, 17 Apr 2026 17:30:14 -0700 Subject: [PATCH 3/9] vec --- .../src/types/assistant_content.rs | 13 +- .../rig-gemini-grpc/src/completion.rs | 6 +- .../src/types/completion_response.rs | 15 +- rig/rig-core/examples/manual_tool_calls.rs | 17 +- rig/rig-core/src/agent/prompt_request/mod.rs | 10 +- .../src/agent/prompt_request/turns.rs | 11 +- rig/rig-core/src/completion/mod.rs | 1 - rig/rig-core/src/completion/request.rs | 229 +++++++++++++++++- .../src/providers/anthropic/completion.rs | 2 +- .../src/providers/cohere/completion.rs | 12 +- rig/rig-core/src/providers/conformance.rs | 5 +- rig/rig-core/src/providers/copilot/mod.rs | 6 +- rig/rig-core/src/providers/deepseek.rs | 13 +- rig/rig-core/src/providers/galadriel.rs | 6 +- .../src/providers/gemini/completion.rs | 4 +- .../providers/gemini/interactions_api/mod.rs | 9 +- .../src/providers/huggingface/completion.rs | 6 +- rig/rig-core/src/providers/hyperbolic.rs | 7 +- rig/rig-core/src/providers/mira.rs | 8 +- .../src/providers/mistral/completion.rs | 15 +- rig/rig-core/src/providers/ollama.rs | 4 +- .../src/providers/openai/completion/mod.rs | 2 +- .../src/providers/openai/responses_api/mod.rs | 8 +- .../openai/responses_api/streaming.rs | 15 -- .../src/providers/openrouter/completion.rs | 6 +- rig/rig-core/src/providers/perplexity.rs | 3 +- rig/rig-core/src/providers/xai/completion.rs | 5 +- rig/rig-core/src/streaming.rs | 32 ++- rig/rig-core/tests/chatgpt/completion.rs | 4 +- rig/rig-core/tests/common/reasoning.rs | 5 +- .../tests/core/prompt_response_messages.rs | 19 +- rig/rig-core/tests/gemini/interactions_api.rs | 7 +- .../tests/moonshot/reasoning_history.rs | 4 +- rig/rig-core/tests/openai/websocket.rs | 4 +- 34 files changed, 322 insertions(+), 191 deletions(-) diff --git a/rig-integrations/rig-bedrock/src/types/assistant_content.rs b/rig-integrations/rig-bedrock/src/types/assistant_content.rs index 964d6524d..7564a9478 100644 --- a/rig-integrations/rig-bedrock/src/types/assistant_content.rs +++ b/rig-integrations/rig-bedrock/src/types/assistant_content.rs @@ -1,7 +1,6 @@ use aws_sdk_bedrockruntime::types as aws_bedrock; use rig::{ - OneOrMany, completion::CompletionError, message::{AssistantContent, Text, ToolCall, ToolFunction}, }; @@ -114,10 +113,12 @@ impl TryFrom for completion::CompletionResponse None, }) { return Ok(completion::CompletionResponse { - choice: OneOrMany::one(AssistantContent::ToolCall(ToolCall::new( - tool_use.id, - ToolFunction::new(tool_use.function.name, tool_use.function.arguments), - ))), + choice: completion::AssistantChoice::one(AssistantContent::ToolCall( + ToolCall::new( + tool_use.id, + ToolFunction::new(tool_use.function.name, tool_use.function.arguments), + ), + )), usage, raw_response: value, message_id: None, @@ -125,7 +126,7 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse for CompletionResponse { + Some(AssistantContent::ToolCall(ToolCall { id, function, .. })) => { assert_eq!(id, "add"); assert_eq!(function.name, "add"); assert_eq!(function.arguments, args); diff --git a/rig/rig-core/examples/manual_tool_calls.rs b/rig/rig-core/examples/manual_tool_calls.rs index 7dc09fbe3..a25b70c94 100644 --- a/rig/rig-core/examples/manual_tool_calls.rs +++ b/rig/rig-core/examples/manual_tool_calls.rs @@ -10,9 +10,8 @@ //! 5. repeats until the model returns a final text answer. use anyhow::{Result, bail}; -use rig::OneOrMany; use rig::client::{CompletionClient, ProviderClient}; -use rig::completion::{Completion, ToolDefinition}; +use rig::completion::{AssistantChoice, Completion, ToolDefinition}; use rig::message::{AssistantContent, Message, ToolCall, ToolChoice}; use rig::providers::openai; use rig::tool::{Tool, ToolSet}; @@ -87,7 +86,7 @@ impl Tool for Subtract { } } -fn collect_tool_calls(choice: &OneOrMany) -> Vec { +fn collect_tool_calls(choice: &AssistantChoice) -> Vec { choice .iter() .filter_map(|content| match content { @@ -97,7 +96,7 @@ fn collect_tool_calls(choice: &OneOrMany) -> Vec { .collect() } -fn extract_text(choice: &OneOrMany) -> String { +fn extract_text(choice: &AssistantChoice) -> String { choice .iter() .filter_map(|content| match content { @@ -151,10 +150,12 @@ async fn main() -> Result<()> { let tool_calls = collect_tool_calls(&response.choice); history.push(current_prompt.clone()); - history.push(Message::Assistant { - id: response.message_id.clone(), - content: response.choice.clone(), - }); + if let Ok(content) = response.choice.to_one_or_many() { + history.push(Message::Assistant { + id: response.message_id.clone(), + content, + }); + } if tool_calls.is_empty() { let final_text = extract_text(&response.choice); diff --git a/rig/rig-core/src/agent/prompt_request/mod.rs b/rig/rig-core/src/agent/prompt_request/mod.rs index 3dbae4e34..2d9aaeab1 100644 --- a/rig/rig-core/src/agent/prompt_request/mod.rs +++ b/rig/rig-core/src/agent/prompt_request/mod.rs @@ -459,10 +459,12 @@ where .iter() .partition(|choice| matches!(choice, AssistantContent::ToolCall(_))); - new_messages.push(Message::Assistant { - id: resp.message_id.clone(), - content: resp.choice.clone(), - }); + if let Ok(content) = resp.choice.to_one_or_many() { + new_messages.push(Message::Assistant { + id: resp.message_id.clone(), + content, + }); + } if tool_calls.is_empty() { let final_output = assistant_text_accumulator.final_output(&turn_text); diff --git a/rig/rig-core/src/agent/prompt_request/turns.rs b/rig/rig-core/src/agent/prompt_request/turns.rs index 79c86c90e..cc72cf414 100644 --- a/rig/rig-core/src/agent/prompt_request/turns.rs +++ b/rig/rig-core/src/agent/prompt_request/turns.rs @@ -1,6 +1,6 @@ //! Shared helpers for deriving user-visible assistant text across agent turns. -use crate::{OneOrMany, message::AssistantContent}; +use crate::{completion::AssistantChoice, message::AssistantContent}; /// Summary of the visible assistant text in a normalized turn. #[derive(Debug, Clone, Default, PartialEq, Eq)] @@ -10,7 +10,7 @@ pub(crate) struct AssistantTurnSummary { impl AssistantTurnSummary { /// Extract non-empty text blocks from a normalized assistant choice. - pub(crate) fn from_choice(choice: &OneOrMany) -> Self { + pub(crate) fn from_choice(choice: &AssistantChoice) -> Self { let visible_text_blocks = choice .iter() .filter_map(|content| match content { @@ -59,15 +59,14 @@ impl AssistantTextAccumulator { #[cfg(test)] mod tests { use super::{AssistantTextAccumulator, AssistantTurnSummary}; - use crate::{OneOrMany, message::AssistantContent}; + use crate::{completion::AssistantChoice, message::AssistantContent}; #[test] fn summary_ignores_empty_text_blocks() { - let choice = OneOrMany::many(vec![ + let choice = AssistantChoice::many(vec![ AssistantContent::text("visible"), AssistantContent::text(""), - ]) - .expect("non-empty assistant choice"); + ]); let summary = AssistantTurnSummary::from_choice(&choice); diff --git a/rig/rig-core/src/completion/mod.rs b/rig/rig-core/src/completion/mod.rs index 83c7d1508..7dcef0ac6 100644 --- a/rig/rig-core/src/completion/mod.rs +++ b/rig/rig-core/src/completion/mod.rs @@ -2,5 +2,4 @@ pub mod message; pub mod request; pub use message::{AssistantContent, Message, MessageError}; -pub(crate) use request::assistant_choice_from_vec; pub use request::*; diff --git a/rig/rig-core/src/completion/request.rs b/rig/rig-core/src/completion/request.rs index 250d751d3..118403de2 100644 --- a/rig/rig-core/src/completion/request.rs +++ b/rig/rig-core/src/completion/request.rs @@ -68,7 +68,7 @@ use crate::message::ToolChoice; use crate::streaming::StreamingCompletionResponse; use crate::tool::server::ToolServerError; use crate::wasm_compat::{WasmCompatSend, WasmCompatSync}; -use crate::{OneOrMany, http_client}; +use crate::{EmptyListError, OneOrMany, http_client}; use crate::{ json_utils, message::{Message, UserContent}, @@ -359,12 +359,13 @@ pub trait Completion { } /// General completion response struct that contains the high-level completion choice -/// and the raw response. The completion choice contains one or more assistant content. +/// and the raw response. Providers may validly return an empty assistant choice +/// when a turn ends without emitting text, tool calls, reasoning, or images. #[derive(Debug)] pub struct CompletionResponse { - /// The completion choice (represented by one or more assistant message content) + /// The completion choice (represented by zero or more assistant message content) /// returned by the completion model provider - pub choice: OneOrMany, + pub choice: AssistantChoice, /// Tokens used during prompting and responding pub usage: Usage, /// The raw response returned by the completion model provider @@ -374,16 +375,196 @@ pub struct CompletionResponse { pub message_id: Option, } -pub(crate) fn assistant_choice_from_vec( - content: Vec, -) -> Result, CompletionError> { - if content.is_empty() { - return Ok(OneOrMany::one(AssistantContent::text(""))); +/// Zero or more assistant content items returned by a provider. +/// +/// Unlike [`OneOrMany`], this type preserves legitimate empty assistant turns +/// without synthesizing placeholder text. +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +#[serde(transparent)] +pub struct AssistantChoice { + items: Vec, +} + +impl AssistantChoice { + /// Create an empty assistant choice. + pub fn new() -> Self { + Self::default() + } + + /// Create an assistant choice containing a single item. + pub fn one(item: AssistantContent) -> Self { + Self { items: vec![item] } + } + + /// Create an assistant choice from zero or more items. + pub fn many(items: I) -> Self + where + I: IntoIterator, + { + Self { + items: items.into_iter().collect(), + } } - OneOrMany::many(content).map_err(|_| { - CompletionError::ResponseError("Response contained no message or tool call (empty)".into()) - }) + /// Returns the first assistant item, if present. + pub fn first(&self) -> Option { + self.items.first().cloned() + } + + /// Returns a reference to the first assistant item, if present. + pub fn first_ref(&self) -> Option<&AssistantContent> { + self.items.first() + } + + /// Returns the first assistant item mutably, if present. + pub fn first_mut(&mut self) -> Option<&mut AssistantContent> { + self.items.first_mut() + } + + /// Returns the last assistant item, if present. + pub fn last(&self) -> Option { + self.items.last().cloned() + } + + /// Returns a reference to the last assistant item, if present. + pub fn last_ref(&self) -> Option<&AssistantContent> { + self.items.last() + } + + /// Returns the last assistant item mutably, if present. + pub fn last_mut(&mut self) -> Option<&mut AssistantContent> { + self.items.last_mut() + } + + /// Returns the number of assistant items. + pub fn len(&self) -> usize { + self.items.len() + } + + /// Returns `true` when the provider emitted no assistant items. + pub fn is_empty(&self) -> bool { + self.items.is_empty() + } + + /// Append an assistant item. + pub fn push(&mut self, item: AssistantContent) { + self.items.push(item); + } + + /// Insert an assistant item at a given index. + pub fn insert(&mut self, index: usize, item: AssistantContent) { + self.items.insert(index, item); + } + + /// Iterate over assistant items. + pub fn iter(&self) -> std::slice::Iter<'_, AssistantContent> { + self.items.iter() + } + + /// Iterate over assistant items mutably. + pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, AssistantContent> { + self.items.iter_mut() + } + + /// Convert into a plain vector. + pub fn into_vec(self) -> Vec { + self.items + } + + /// Convert into [`OneOrMany`] when the choice is non-empty. + pub fn into_one_or_many(self) -> Result, EmptyListError> { + OneOrMany::many(self.items) + } + + /// Clone into [`OneOrMany`] when the choice is non-empty. + pub fn to_one_or_many(&self) -> Result, EmptyListError> { + OneOrMany::many(self.items.clone()) + } +} + +impl From for AssistantChoice { + fn from(item: AssistantContent) -> Self { + Self::one(item) + } +} + +impl From> for AssistantChoice { + fn from(items: Vec) -> Self { + Self { items } + } +} + +impl From> for AssistantChoice { + fn from(items: OneOrMany) -> Self { + Self::many(items) + } +} + +impl From for Vec { + fn from(choice: AssistantChoice) -> Self { + choice.items + } +} + +impl std::iter::FromIterator for AssistantChoice { + fn from_iter>(iter: T) -> Self { + Self::many(iter) + } +} + +impl IntoIterator for AssistantChoice { + type Item = AssistantContent; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.items.into_iter() + } +} + +impl<'a> IntoIterator for &'a AssistantChoice { + type Item = &'a AssistantContent; + type IntoIter = std::slice::Iter<'a, AssistantContent>; + + fn into_iter(self) -> Self::IntoIter { + self.items.iter() + } +} + +impl<'a> IntoIterator for &'a mut AssistantChoice { + type Item = &'a mut AssistantContent; + type IntoIter = std::slice::IterMut<'a, AssistantContent>; + + fn into_iter(self) -> Self::IntoIter { + self.items.iter_mut() + } +} + +impl TryFrom for OneOrMany { + type Error = EmptyListError; + + fn try_from(choice: AssistantChoice) -> Result { + choice.into_one_or_many() + } +} + +impl TryFrom<&AssistantChoice> for OneOrMany { + type Error = EmptyListError; + + fn try_from(choice: &AssistantChoice) -> Result { + choice.to_one_or_many() + } +} + +impl PartialEq> for AssistantChoice { + fn eq(&self, other: &OneOrMany) -> bool { + self.iter().eq(other.iter()) + } +} + +impl PartialEq for OneOrMany { + fn eq(&self, other: &AssistantChoice) -> bool { + self.iter().eq(other.iter()) + } } /// A trait for grabbing the token usage of a completion response. @@ -1000,6 +1181,30 @@ mod tests { assert_eq!(format!("{doc}"), expected); } + #[test] + fn assistant_choice_can_be_empty_without_placeholder_text() { + let choice = AssistantChoice::new(); + + assert!(choice.is_empty()); + assert!(choice.first().is_none()); + assert!(choice.to_one_or_many().is_err()); + } + + #[test] + fn assistant_choice_roundtrips_non_empty_content() { + let choice = AssistantChoice::many(vec![ + AssistantContent::text("hello"), + AssistantContent::tool_call("call_1", "echo", serde_json::json!({"value": 1})), + ]); + + let normalized = choice + .clone() + .into_one_or_many() + .expect("non-empty assistant choice should convert"); + + assert_eq!(choice, normalized); + } + #[test] fn test_normalize_documents_with_documents() { let doc1 = Document { diff --git a/rig/rig-core/src/providers/anthropic/completion.rs b/rig/rig-core/src/providers/anthropic/completion.rs index 832052a22..2e817a600 100644 --- a/rig/rig-core/src/providers/anthropic/completion.rs +++ b/rig/rig-core/src/providers/anthropic/completion.rs @@ -214,7 +214,7 @@ impl TryFrom for completion::CompletionResponse, _>>()?; - let choice = completion::assistant_choice_from_vec(content)?; + let choice = completion::AssistantChoice::from(content); let usage = completion::Usage { input_tokens: response.usage.input_tokens, diff --git a/rig/rig-core/src/providers/cohere/completion.rs b/rig/rig-core/src/providers/cohere/completion.rs index a54270dd6..5aa21c2ed 100644 --- a/rig/rig-core/src/providers/cohere/completion.rs +++ b/rig/rig-core/src/providers/cohere/completion.rs @@ -140,7 +140,7 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse>(), ) - .expect("We have atleast 1 tool call in this if block") } else { - OneOrMany::many(content.into_iter().map(|content| match content { + completion::AssistantChoice::many(content.into_iter().map(|content| match content { AssistantContent::Text { text } => completion::AssistantContent::text(text), AssistantContent::Thinking { thinking } => { completion::AssistantContent::Reasoning(Reasoning::new(&thinking)) } })) - .map_err(|_| { - CompletionError::ResponseError( - "Response contained no message or tool call (empty)".to_owned(), - ) - })? }; let usage = response @@ -185,7 +179,7 @@ impl TryFrom for completion::CompletionResponse, + choice: &AssistantChoice, message_id: Option, stop_reason: Option, ) -> Turn { diff --git a/rig/rig-core/src/providers/copilot/mod.rs b/rig/rig-core/src/providers/copilot/mod.rs index ddb1e975d..f9e46080b 100644 --- a/rig/rig-core/src/providers/copilot/mod.rs +++ b/rig/rig-core/src/providers/copilot/mod.rs @@ -521,11 +521,7 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse, _>>()?; - let choice = completion::assistant_choice_from_vec(content)?; + let choice = completion::AssistantChoice::from(content); let usage = response .usage_metadata @@ -2283,7 +2283,7 @@ mod tests { let first = converted.choice.first(); assert!(matches!( first, - message::AssistantContent::Reasoning(message::Reasoning { content, .. }) + Some(message::AssistantContent::Reasoning(message::Reasoning { content, .. })) if matches!( content.first(), Some(message::ReasoningContent::Text { diff --git a/rig/rig-core/src/providers/gemini/interactions_api/mod.rs b/rig/rig-core/src/providers/gemini/interactions_api/mod.rs index cba503db2..60346dd38 100644 --- a/rig/rig-core/src/providers/gemini/interactions_api/mod.rs +++ b/rig/rig-core/src/providers/gemini/interactions_api/mod.rs @@ -1,7 +1,6 @@ //! Google Gemini Interactions API integration. //! From -use crate::OneOrMany; use crate::completion::{self, CompletionError, CompletionRequest, GetTokenUsage}; use crate::http_client::HttpClientExt; use crate::message::{self, MimeType, Reasoning}; @@ -490,11 +489,7 @@ impl TryFrom for completion::CompletionResponse { }) .collect::, _>>()?; - let choice = OneOrMany::many(content).map_err(|_| { - CompletionError::ResponseError( - "Response contained no message or tool call (empty)".to_owned(), - ) - })?; + let choice = completion::AssistantChoice::from(content); let usage = response .usage @@ -2502,7 +2497,7 @@ mod tests { let choice = response.choice.first(); match choice { - completion::AssistantContent::ToolCall(tool_call) => { + Some(completion::AssistantContent::ToolCall(tool_call)) => { assert_eq!(tool_call.function.name, "get_weather"); assert_eq!(tool_call.call_id.as_deref(), Some("call-123")); } diff --git a/rig/rig-core/src/providers/huggingface/completion.rs b/rig/rig-core/src/providers/huggingface/completion.rs index fe521313d..468b76220 100644 --- a/rig/rig-core/src/providers/huggingface/completion.rs +++ b/rig/rig-core/src/providers/huggingface/completion.rs @@ -585,11 +585,7 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse = assistant.try_into().expect("conversion should work"); @@ -748,7 +743,7 @@ mod tests { fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() { let assistant = message::Message::Assistant { id: None, - content: OneOrMany::many(vec![ + content: crate::OneOrMany::many(vec![ message::AssistantContent::reasoning("hidden"), message::AssistantContent::text("visible"), message::AssistantContent::tool_call( @@ -818,9 +813,9 @@ mod tests { fn test_request_conversion_errors_when_all_messages_are_filtered() { let request = CompletionRequest { preamble: None, - chat_history: OneOrMany::one(message::Message::Assistant { + chat_history: crate::OneOrMany::one(message::Message::Assistant { id: None, - content: OneOrMany::one(message::AssistantContent::reasoning("hidden")), + content: crate::OneOrMany::one(message::AssistantContent::reasoning("hidden")), }), documents: vec![], tools: vec![], diff --git a/rig/rig-core/src/providers/ollama.rs b/rig/rig-core/src/providers/ollama.rs index 660bef677..a0b37597d 100644 --- a/rig/rig-core/src/providers/ollama.rs +++ b/rig/rig-core/src/providers/ollama.rs @@ -373,9 +373,7 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse Result { - if response.output.is_empty() { - return Err(CompletionError::ResponseError( - "Response contained no parts".to_owned(), - )); - } - // Extract the msg_ ID from the first Output::Message item let message_id = response.output.iter().find_map(|item| match item { Output::Message(msg) => Some(msg.id.clone()), @@ -1446,7 +1440,7 @@ impl TryFrom for completion::CompletionResponse>::from) .collect(); - let choice = completion::assistant_choice_from_vec(content)?; + let choice = completion::AssistantChoice::from(content); let usage = response .usage diff --git a/rig/rig-core/src/providers/openai/responses_api/streaming.rs b/rig/rig-core/src/providers/openai/responses_api/streaming.rs index d273f8376..9c9750c88 100644 --- a/rig/rig-core/src/providers/openai/responses_api/streaming.rs +++ b/rig/rig-core/src/providers/openai/responses_api/streaming.rs @@ -521,12 +521,6 @@ pub(crate) async fn completion_response_from_sse_body( item?; } - if choice_is_empty(&stream.choice) { - return Err(CompletionError::ResponseError( - "Response contained no parts".to_owned(), - )); - } - Ok(completion::CompletionResponse { usage: stream .response @@ -542,15 +536,6 @@ pub(crate) async fn completion_response_from_sse_body( }) } -fn choice_is_empty(choice: &crate::OneOrMany) -> bool { - choice.iter().all(|content| match content { - completion::AssistantContent::Text(text) => text.text.trim().is_empty(), - completion::AssistantContent::Reasoning(reasoning) => reasoning.content.is_empty(), - completion::AssistantContent::Image(_) => false, - completion::AssistantContent::ToolCall(_) => false, - }) -} - fn message_id_from_response(response: &CompletionResponse) -> Option { response.output.iter().find_map(|item| match item { Output::Message(message) => Some(message.id.clone()), diff --git a/rig/rig-core/src/providers/openrouter/completion.rs b/rig/rig-core/src/providers/openrouter/completion.rs index 3b2f841ec..b5a1014b1 100644 --- a/rig/rig-core/src/providers/openrouter/completion.rs +++ b/rig/rig-core/src/providers/openrouter/completion.rs @@ -704,11 +704,7 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse Ok(completion::CompletionResponse { - choice: OneOrMany::one(content.clone().into()), + choice: completion::AssistantChoice::one(content.clone().into()), usage: completion::Usage { input_tokens: response.usage.prompt_tokens as u64, output_tokens: response.usage.completion_tokens as u64, diff --git a/rig/rig-core/src/providers/xai/completion.rs b/rig/rig-core/src/providers/xai/completion.rs index 1891379c1..b698c771a 100644 --- a/rig/rig-core/src/providers/xai/completion.rs +++ b/rig/rig-core/src/providers/xai/completion.rs @@ -9,7 +9,6 @@ use tracing::{Instrument, Level, enabled, info_span}; use super::api::{ApiResponse, Message, ToolDefinition}; use super::client::Client; -use crate::OneOrMany; use crate::completion::{self, CompletionError, CompletionRequest}; use crate::http_client::HttpClientExt; use crate::providers::openai::completion::ToolChoice; @@ -139,9 +138,7 @@ impl TryFrom for completion::CompletionResponse>::from) .collect(); - let choice = OneOrMany::many(content).map_err(|_| { - CompletionError::ResponseError("Response contained no output".to_owned()) - })?; + let choice = completion::AssistantChoice::from(content); let usage = response .usage diff --git a/rig/rig-core/src/streaming.rs b/rig/rig-core/src/streaming.rs index 07b47a03f..78d4882de 100644 --- a/rig/rig-core/src/streaming.rs +++ b/rig/rig-core/src/streaming.rs @@ -8,13 +8,12 @@ //! - [StreamingCompletion]: Defines a low-level streaming LLM completion interface //! -use crate::OneOrMany; use crate::agent::Agent; use crate::agent::prompt_request::hooks::PromptHook; use crate::agent::prompt_request::streaming::StreamingPromptRequest; use crate::completion::{ - CompletionError, CompletionModel, CompletionRequestBuilder, CompletionResponse, GetTokenUsage, - Message, Usage, + AssistantChoice, CompletionError, CompletionModel, CompletionRequestBuilder, + CompletionResponse, GetTokenUsage, Message, Usage, }; use crate::message::{ AssistantContent, Reasoning, ReasoningContent, Text, ToolCall, ToolFunction, ToolResult, @@ -207,7 +206,7 @@ where reasoning_item_index: Option, /// The final aggregated message from the stream /// contains all text and tool calls generated - pub choice: OneOrMany, + pub choice: AssistantChoice, /// The final response from the stream, may be `None` /// if the provider didn't yield it during the stream pub response: Option, @@ -231,7 +230,7 @@ where assistant_items: vec![], text_item_index: None, reasoning_item_index: None, - choice: OneOrMany::one(AssistantContent::text("")), + choice: AssistantChoice::new(), response: None, final_response_yielded: AtomicBool::new(false), message_id: None, @@ -327,12 +326,7 @@ where Poll::Ready(None) => { // This is run at the end of the inner stream to collect all tokens into // a single unified `Message`. - if stream.assistant_items.is_empty() { - stream.assistant_items.push(AssistantContent::text("")); - } - - stream.choice = OneOrMany::many(std::mem::take(&mut stream.assistant_items)) - .expect("There should be at least one assistant message"); + stream.choice = AssistantChoice::from(std::mem::take(&mut stream.assistant_items)); Poll::Ready(None) } @@ -685,6 +679,14 @@ mod tests { StreamingCompletionResponse::stream(to_stream_result(stream)) } + fn create_empty_stream() -> StreamingCompletionResponse { + let stream = stream! { + yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 0 })); + }; + + StreamingCompletionResponse::stream(to_stream_result(stream)) + } + fn create_interleaved_stream() -> StreamingCompletionResponse { let stream = stream! { yield Ok(RawStreamingChoice::Reasoning { @@ -837,6 +839,14 @@ mod tests { )); } + #[tokio::test] + async fn test_empty_stream_keeps_empty_choice() { + let mut stream = create_empty_stream(); + while stream.next().await.is_some() {} + + assert!(stream.choice.is_empty()); + } + #[tokio::test] async fn test_stream_aggregates_assistant_items_in_arrival_order() { let mut stream = create_interleaved_stream(); diff --git a/rig/rig-core/tests/chatgpt/completion.rs b/rig/rig-core/tests/chatgpt/completion.rs index f897dfaf2..b1cea1ae6 100644 --- a/rig/rig-core/tests/chatgpt/completion.rs +++ b/rig/rig-core/tests/chatgpt/completion.rs @@ -2,7 +2,7 @@ use futures::StreamExt; use rig::client::CompletionClient; -use rig::completion::CompletionModel; +use rig::completion::{AssistantChoice, CompletionModel}; use rig::message::AssistantContent; use rig::message::Message; use rig::streaming::{StreamedAssistantContent, StreamingPrompt}; @@ -12,7 +12,7 @@ use crate::support::{ assert_contains_any_case_insensitive, assert_nonempty_response, collect_stream_final_response, }; -fn aggregated_text(choice: &rig::OneOrMany) -> String { +fn aggregated_text(choice: &AssistantChoice) -> String { choice .iter() .filter_map(|content| match content { diff --git a/rig/rig-core/tests/common/reasoning.rs b/rig/rig-core/tests/common/reasoning.rs index 8674536d9..181c8e7c0 100644 --- a/rig/rig-core/tests/common/reasoning.rs +++ b/rig/rig-core/tests/common/reasoning.rs @@ -206,7 +206,10 @@ where let turn1_assistant = Message::Assistant { id: response.message_id, - content: response.choice, + content: response + .choice + .into_one_or_many() + .expect("reasoning roundtrip response should not be empty"), }; let turn2_prompt = Message::User { diff --git a/rig/rig-core/tests/core/prompt_response_messages.rs b/rig/rig-core/tests/core/prompt_response_messages.rs index 5302a088d..3377eee0a 100644 --- a/rig/rig-core/tests/core/prompt_response_messages.rs +++ b/rig/rig-core/tests/core/prompt_response_messages.rs @@ -1,10 +1,10 @@ //! Integration tests for `PromptResponse.messages` using mock models. //! Exercises the real agent loop code path with mocked LLM responses. -use rig::OneOrMany; use rig::agent::AgentBuilder; use rig::completion::{ - CompletionError, CompletionModel, CompletionRequest, CompletionResponse, Message, Prompt, Usage, + AssistantChoice, CompletionError, CompletionModel, CompletionRequest, CompletionResponse, + Message, Prompt, Usage, }; use rig::message::{AssistantContent, Text, ToolCall, ToolFunction, UserContent}; use rig::streaming::{StreamingCompletionResponse, StreamingResult}; @@ -34,7 +34,7 @@ impl CompletionModel for SimpleTextModel { _request: CompletionRequest, ) -> Result, CompletionError> { Ok(CompletionResponse { - choice: OneOrMany::one(AssistantContent::Text(Text { + choice: AssistantChoice::one(AssistantContent::Text(Text { text: "hello from mock".to_string(), })), usage: Usage { @@ -92,7 +92,7 @@ impl CompletionModel for ToolThenTextModel { if turn == 0 { // First turn: return a tool call Ok(CompletionResponse { - choice: OneOrMany::one(AssistantContent::ToolCall(ToolCall::new( + choice: AssistantChoice::one(AssistantContent::ToolCall(ToolCall::new( "tc_1".to_string(), ToolFunction::new( "calculator".to_string(), @@ -112,7 +112,7 @@ impl CompletionModel for ToolThenTextModel { } else { // Second turn: return a text response Ok(CompletionResponse { - choice: OneOrMany::one(AssistantContent::Text(Text { + choice: AssistantChoice::one(AssistantContent::Text(Text { text: "The answer is 5".to_string(), })), usage: Usage { @@ -170,7 +170,7 @@ impl CompletionModel for ToolThenEmptyTerminalTurnModel { if turn == 0 { Ok(CompletionResponse { - choice: OneOrMany::many(vec![ + choice: AssistantChoice::many(vec![ AssistantContent::Text(Text { text: "The answer is 5".to_string(), }), @@ -181,8 +181,7 @@ impl CompletionModel for ToolThenEmptyTerminalTurnModel { serde_json::json!({"op": "add", "a": 2, "b": 3}), ), )), - ]) - .expect("assistant turn has text and a tool call"), + ]), usage: Usage { input_tokens: 12, output_tokens: 7, @@ -195,7 +194,7 @@ impl CompletionModel for ToolThenEmptyTerminalTurnModel { }) } else { Ok(CompletionResponse { - choice: OneOrMany::one(AssistantContent::text("")), + choice: AssistantChoice::new(), usage: Usage { input_tokens: 8, output_tokens: 1, @@ -238,7 +237,7 @@ impl CompletionModel for AlwaysToolCallModel { _request: CompletionRequest, ) -> Result, CompletionError> { Ok(CompletionResponse { - choice: OneOrMany::one(AssistantContent::ToolCall(ToolCall::new( + choice: AssistantChoice::one(AssistantContent::ToolCall(ToolCall::new( "tc_loop".to_string(), ToolFunction::new("infinite_tool".to_string(), serde_json::json!({"x": 1})), ))), diff --git a/rig/rig-core/tests/gemini/interactions_api.rs b/rig/rig-core/tests/gemini/interactions_api.rs index 6d45e4ae3..a67579f12 100644 --- a/rig/rig-core/tests/gemini/interactions_api.rs +++ b/rig/rig-core/tests/gemini/interactions_api.rs @@ -1,9 +1,8 @@ //! Migrated from `examples/gemini_interactions_api.rs`. use futures::StreamExt; -use rig::OneOrMany; use rig::client::{CompletionClient, ProviderClient}; -use rig::completion::{CompletionModel, GetTokenUsage}; +use rig::completion::{AssistantChoice, CompletionModel, GetTokenUsage}; use rig::message::{AssistantContent, Message, ToolCall, ToolChoice}; use rig::providers::gemini; use rig::providers::gemini::interactions_api::{AdditionalParameters, Tool}; @@ -11,7 +10,7 @@ use rig::streaming::StreamedAssistantContent; use crate::support::assert_nonempty_response; -fn extract_text(choice: &OneOrMany) -> String { +fn extract_text(choice: &AssistantChoice) -> String { choice .iter() .filter_map(|content| match content { @@ -22,7 +21,7 @@ fn extract_text(choice: &OneOrMany) -> String { .join("") } -fn first_tool_call(choice: &OneOrMany) -> Option { +fn first_tool_call(choice: &AssistantChoice) -> Option { choice.iter().find_map(|content| match content { AssistantContent::ToolCall(tool_call) => Some(tool_call.clone()), _ => None, diff --git a/rig/rig-core/tests/moonshot/reasoning_history.rs b/rig/rig-core/tests/moonshot/reasoning_history.rs index 81cae1baf..6c651b1e5 100644 --- a/rig/rig-core/tests/moonshot/reasoning_history.rs +++ b/rig/rig-core/tests/moonshot/reasoning_history.rs @@ -2,13 +2,13 @@ use rig::OneOrMany; use rig::client::{CompletionClient, ProviderClient}; -use rig::completion::CompletionModel; +use rig::completion::{AssistantChoice, CompletionModel}; use rig::message::{AssistantContent, Message, Reasoning}; use rig::providers::moonshot; use crate::support::{assert_contains_any_case_insensitive, assert_nonempty_response}; -fn response_text(choice: &rig::OneOrMany) -> String { +fn response_text(choice: &AssistantChoice) -> String { choice .iter() .filter_map(|content| match content { diff --git a/rig/rig-core/tests/openai/websocket.rs b/rig/rig-core/tests/openai/websocket.rs index 70f3b0654..ccef4ed5d 100644 --- a/rig/rig-core/tests/openai/websocket.rs +++ b/rig/rig-core/tests/openai/websocket.rs @@ -2,7 +2,7 @@ use anyhow::Result; use rig::client::{CompletionClient, ProviderClient}; -use rig::completion::CompletionModel; +use rig::completion::{AssistantChoice, CompletionModel}; use rig::message::AssistantContent; use rig::providers::openai; use rig::providers::openai::responses_api::streaming::{ItemChunkKind, ResponseChunkKind}; @@ -10,7 +10,7 @@ use rig::providers::openai::responses_api::websocket::ResponsesWebSocketEvent; use crate::support::assert_nonempty_response; -fn extract_text(choice: &rig::OneOrMany) -> String { +fn extract_text(choice: &AssistantChoice) -> String { choice .iter() .filter_map(|content| match content { From c42e1e02ce120b8ea34605a7bfbd7ca3bcebc9a2 Mon Sep 17 00:00:00 2001 From: stephen Date: Fri, 17 Apr 2026 18:04:50 -0700 Subject: [PATCH 4/9] stop_reason --- .../src/types/assistant_content.rs | 2 + .../rig-gemini-grpc/src/completion.rs | 1 + .../src/types/completion_response.rs | 1 + rig/rig-core/src/agent/prompt_request/mod.rs | 2 +- .../src/agent/prompt_request/streaming.rs | 2 +- .../src/agent/prompt_request/turns.rs | 39 ++++- rig/rig-core/src/completion/mod.rs | 2 + rig/rig-core/src/completion/normalized.rs | 152 ++++++++++++++++++ rig/rig-core/src/completion/request.rs | 3 + .../src/providers/anthropic/completion.rs | 10 ++ .../providers/anthropic/conformance_tests.rs | 19 +-- .../src/providers/anthropic/streaming.rs | 8 +- .../src/providers/cohere/completion.rs | 11 ++ rig/rig-core/src/providers/conformance.rs | 99 ++---------- rig/rig-core/src/providers/copilot/mod.rs | 7 + rig/rig-core/src/providers/deepseek.rs | 4 + rig/rig-core/src/providers/galadriel.rs | 10 +- .../src/providers/gemini/completion.rs | 14 ++ .../src/providers/gemini/conformance_tests.rs | 23 +-- .../providers/gemini/interactions_api/mod.rs | 1 + .../src/providers/gemini/streaming.rs | 8 +- .../src/providers/huggingface/completion.rs | 4 + rig/rig-core/src/providers/hyperbolic.rs | 4 + rig/rig-core/src/providers/mira.rs | 8 + .../src/providers/mistral/completion.rs | 4 + rig/rig-core/src/providers/ollama.rs | 5 + .../openai/completion/conformance_tests.rs | 23 +-- .../src/providers/openai/completion/mod.rs | 12 ++ .../providers/openai/completion/streaming.rs | 28 +++- .../src/providers/openai/responses_api/mod.rs | 1 + .../openai/responses_api/streaming.rs | 26 +-- .../src/providers/openrouter/completion.rs | 5 + rig/rig-core/src/providers/perplexity.rs | 4 + rig/rig-core/src/providers/xai/completion.rs | 5 + rig/rig-core/src/streaming.rs | 32 ++++ .../tests/core/prompt_response_messages.rs | 6 + 36 files changed, 412 insertions(+), 173 deletions(-) create mode 100644 rig/rig-core/src/completion/normalized.rs diff --git a/rig-integrations/rig-bedrock/src/types/assistant_content.rs b/rig-integrations/rig-bedrock/src/types/assistant_content.rs index 7564a9478..9cc473cab 100644 --- a/rig-integrations/rig-bedrock/src/types/assistant_content.rs +++ b/rig-integrations/rig-bedrock/src/types/assistant_content.rs @@ -122,6 +122,7 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse for CompletionResponse Self { - let visible_text_blocks = choice - .iter() - .filter_map(|content| match content { - AssistantContent::Text(text) if !text.text.is_empty() => Some(text.text.clone()), + fn from_turn(turn: NormalizedTurn) -> Self { + let visible_text_blocks = turn + .items + .into_iter() + .filter_map(|item| match item { + NormalizedItem::Text(text) => Some(text), _ => None, }) .collect(); @@ -24,6 +26,29 @@ impl AssistantTurnSummary { } } + /// Extract non-empty text blocks from a normalized assistant choice. + #[cfg_attr(not(test), allow(dead_code))] + pub(crate) fn from_choice(choice: &AssistantChoice) -> Self { + Self::from_turn(NormalizedTurn::from_choice(choice, None, None)) + } + + /// Extract non-empty text blocks from a normalized completion response. + pub(crate) fn from_response(response: &CompletionResponse) -> Self { + Self::from_turn(NormalizedTurn::from_completion_response(response)) + } + + /// Extract non-empty text blocks from an aggregated streaming response. + pub(crate) fn from_stream_response(response: &StreamingCompletionResponse) -> Self + where + R: Clone + Unpin + crate::completion::GetTokenUsage, + { + Self::from_turn(NormalizedTurn::from_choice( + &response.choice, + response.message_id.clone(), + response.stop_reason.clone(), + )) + } + /// Render the visible text blocks using the caller's preferred separator. pub(crate) fn visible_text(&self, separator: &str) -> String { self.visible_text_blocks.join(separator) diff --git a/rig/rig-core/src/completion/mod.rs b/rig/rig-core/src/completion/mod.rs index 7dcef0ac6..7010a7ced 100644 --- a/rig/rig-core/src/completion/mod.rs +++ b/rig/rig-core/src/completion/mod.rs @@ -1,5 +1,7 @@ pub mod message; +pub(crate) mod normalized; pub mod request; pub use message::{AssistantContent, Message, MessageError}; +pub use normalized::StopReason; pub use request::*; diff --git a/rig/rig-core/src/completion/normalized.rs b/rig/rig-core/src/completion/normalized.rs new file mode 100644 index 000000000..6bb136558 --- /dev/null +++ b/rig/rig-core/src/completion/normalized.rs @@ -0,0 +1,152 @@ +//! Internal normalization helpers for provider responses. +//! +//! Providers often expose different stop-reason enums and content block shapes. +//! This module defines a small shared semantic representation that can be used by +//! higher-level code, test harnesses, and future provider adapters without +//! depending on provider-specific wire types. + +use serde_json::Value; + +use super::{AssistantChoice, CompletionResponse}; +use crate::message::AssistantContent; + +/// Provider-agnostic assistant content used by internal normalization flows. +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum NormalizedItem { + Text(String), + ToolCall { + id: String, + name: String, + arguments: Value, + }, + Reasoning(String), +} + +/// Provider-agnostic reasons why a provider stopped generating a turn. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum StopReason { + /// The model reached a natural stop point. + Stop, + /// The model stopped in order to request tool execution. + ToolCalls, + /// The provider explicitly ended the assistant turn. + EndTurn, + /// The model hit the configured output-token limit. + MaxTokens, + /// The provider filtered the output content. + ContentFilter, + /// The provider stopped for a safety-related reason. + Safety, + /// The provider reported a stop reason Rig does not normalize yet. + Other(String), +} + +/// Provider-agnostic normalized assistant turn. +#[derive(Debug, Clone, PartialEq, Default)] +pub(crate) struct NormalizedTurn { + pub(crate) items: Vec, + pub(crate) message_id: Option, + pub(crate) stop_reason: Option, +} + +impl NormalizedTurn { + /// Build a normalized turn from a normalized assistant choice plus metadata. + pub(crate) fn from_choice( + choice: &AssistantChoice, + message_id: Option, + stop_reason: Option, + ) -> Self { + let mut items = Vec::new(); + + for item in choice.iter() { + let normalized = match item { + AssistantContent::Text(text) if !text.text.is_empty() => { + Some(NormalizedItem::Text(text.text.clone())) + } + AssistantContent::ToolCall(tool_call) => Some(NormalizedItem::ToolCall { + id: tool_call.id.clone(), + name: tool_call.function.name.clone(), + arguments: tool_call.function.arguments.clone(), + }), + AssistantContent::Reasoning(reasoning) => { + let text = reasoning.display_text(); + if text.is_empty() { + None + } else { + Some(NormalizedItem::Reasoning(text)) + } + } + _ => None, + }; + + if let Some(normalized) = normalized { + let duplicate_reasoning = matches!( + (&normalized, items.last()), + (NormalizedItem::Reasoning(current), Some(NormalizedItem::Reasoning(previous))) + if current == previous + ); + + if !duplicate_reasoning { + items.push(normalized); + } + } + } + + Self { + items, + message_id, + stop_reason, + } + } + + /// Build a normalized turn from a completion response. + pub(crate) fn from_completion_response(response: &CompletionResponse) -> Self { + Self::from_choice( + &response.choice, + response.message_id.clone(), + response.stop_reason.clone(), + ) + } +} + +#[cfg(test)] +mod tests { + use super::{NormalizedItem, NormalizedTurn, StopReason}; + use crate::completion::AssistantChoice; + use crate::message::{AssistantContent, Reasoning}; + + #[test] + fn normalized_turn_ignores_empty_text_blocks() { + let choice = AssistantChoice::many(vec![ + AssistantContent::text("visible"), + AssistantContent::text(""), + ]); + + let turn = NormalizedTurn::from_choice(&choice, Some("msg_1".to_string()), None); + + assert_eq!( + turn, + NormalizedTurn { + items: vec![NormalizedItem::Text("visible".to_string())], + message_id: Some("msg_1".to_string()), + stop_reason: None, + } + ); + } + + #[test] + fn normalized_turn_deduplicates_adjacent_reasoning_blocks() { + let choice = AssistantChoice::many(vec![ + AssistantContent::Reasoning(Reasoning::new("step one")), + AssistantContent::Reasoning(Reasoning::new("step one")), + ]); + + let turn = NormalizedTurn::from_choice(&choice, None, Some(StopReason::EndTurn)); + + assert_eq!( + turn.items, + vec![NormalizedItem::Reasoning("step one".to_string())] + ); + assert_eq!(turn.stop_reason, Some(StopReason::EndTurn)); + } +} diff --git a/rig/rig-core/src/completion/request.rs b/rig/rig-core/src/completion/request.rs index 118403de2..071d2bf2b 100644 --- a/rig/rig-core/src/completion/request.rs +++ b/rig/rig-core/src/completion/request.rs @@ -63,6 +63,7 @@ //! For more information on how to use the completion functionality, refer to the documentation of //! the individual traits, structs, and enums defined in this module. +use super::StopReason; use super::message::{AssistantContent, DocumentMediaType}; use crate::message::ToolChoice; use crate::streaming::StreamingCompletionResponse; @@ -373,6 +374,8 @@ pub struct CompletionResponse { /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID). /// Used to pair reasoning input items with their output items in multi-turn. pub message_id: Option, + /// Provider-agnostic reason why the model stopped generating this turn. + pub stop_reason: Option, } /// Zero or more assistant content items returned by a provider. diff --git a/rig/rig-core/src/providers/anthropic/completion.rs b/rig/rig-core/src/providers/anthropic/completion.rs index 2e817a600..e7cc0097c 100644 --- a/rig/rig-core/src/providers/anthropic/completion.rs +++ b/rig/rig-core/src/providers/anthropic/completion.rs @@ -60,6 +60,15 @@ pub struct CompletionResponse { pub usage: Usage, } +pub(crate) fn map_stop_reason(reason: &str) -> completion::StopReason { + match reason { + "end_turn" => completion::StopReason::EndTurn, + "tool_use" => completion::StopReason::ToolCalls, + "max_tokens" => completion::StopReason::MaxTokens, + other => completion::StopReason::Other(other.to_string()), + } +} + impl ProviderResponseExt for CompletionResponse { type OutputMessage = Content; type Usage = Usage; @@ -230,6 +239,7 @@ impl TryFrom for completion::CompletionResponse Ok(Self::expected(case)), _ => { let raw = non_stream_response(case); - let stop_reason = raw.stop_reason.as_deref().map(map_stop_reason); let response: completion::CompletionResponse = raw.try_into()?; - Ok(Outcome::Supported(normalize_completion_response( - &response, - stop_reason, - ))) + Ok(Outcome::Supported(normalize_completion_response(&response))) } } } @@ -101,9 +97,7 @@ impl Harness for AnthropicHarness { let model = CompletionModel::new(client, "claude-test"); let stream = model.stream(stream_request()).await?; let response = drain_stream(stream).await?; - Ok(Outcome::Supported(normalize_completion_response( - &response, None, - ))) + Ok(Outcome::Supported(normalize_completion_response(&response))) } } }) @@ -177,15 +171,6 @@ fn response_with_content(content: Vec, stop_reason: &str) -> Completion } } -fn map_stop_reason(reason: &str) -> StopReason { - match reason { - "end_turn" => StopReason::EndTurn, - "tool_use" => StopReason::ToolCalls, - "max_tokens" => StopReason::MaxTokens, - other => StopReason::Other(other.to_string()), - } -} - fn stream_request() -> completion::CompletionRequest { completion::CompletionRequest { model: None, diff --git a/rig/rig-core/src/providers/anthropic/streaming.rs b/rig/rig-core/src/providers/anthropic/streaming.rs index a338e876c..b6206aa80 100644 --- a/rig/rig-core/src/providers/anthropic/streaming.rs +++ b/rig/rig-core/src/providers/anthropic/streaming.rs @@ -7,7 +7,7 @@ use tracing_futures::Instrument; use super::completion::{ AnthropicCompatibleProvider, CacheControl, Content, GenericCompletionModel, Message, - SystemContent, ToolChoice, ToolDefinition, Usage, apply_cache_control, + SystemContent, ToolChoice, ToolDefinition, Usage, apply_cache_control, map_stop_reason, split_system_messages_from_history, }; use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage}; @@ -319,6 +319,12 @@ where }, StreamingEvent::MessageDelta { delta, usage } => { if delta.stop_reason.is_some() { + if let Some(stop_reason) = delta.stop_reason.as_deref() { + yield Ok(RawStreamingChoice::StopReason( + map_stop_reason(stop_reason), + )); + } + // cache_creation_input_tokens and cache_read_input_tokens // are cumulative totals on message_delta.usage per the // Anthropic streaming API spec — use them directly. diff --git a/rig/rig-core/src/providers/cohere/completion.rs b/rig/rig-core/src/providers/cohere/completion.rs index 5aa21c2ed..a86a093a0 100644 --- a/rig/rig-core/src/providers/cohere/completion.rs +++ b/rig/rig-core/src/providers/cohere/completion.rs @@ -91,6 +91,15 @@ pub enum FinishReason { ToolCall, } +fn map_finish_reason(reason: &FinishReason) -> completion::StopReason { + match reason { + FinishReason::MaxTokens => completion::StopReason::MaxTokens, + FinishReason::StopSequence | FinishReason::Complete => completion::StopReason::Stop, + FinishReason::ToolCall => completion::StopReason::ToolCalls, + FinishReason::Error => completion::StopReason::Other("ERROR".to_string()), + } +} + #[derive(Debug, Deserialize, Clone, Serialize)] pub struct Usage { #[serde(default)] @@ -137,6 +146,7 @@ impl TryFrom for completion::CompletionResponse Result { + let stop_reason = Some(map_finish_reason(&response.finish_reason)); let (content, _, tool_calls) = response.message(); let model_response = if !tool_calls.is_empty() { @@ -183,6 +193,7 @@ impl TryFrom for completion::CompletionResponse, - pub(crate) message_id: Option, - pub(crate) stop_reason: Option, -} - #[derive(Debug, Clone, PartialEq)] pub(crate) enum Outcome { Supported(T), Unsupported(&'static str), } -pub(crate) type BoxFuture = Pin + Send>>; +pub(crate) type BoxFuture = crate::wasm_compat::WasmBoxedFuture<'static, T>; pub(crate) trait Harness { fn family_name() -> &'static str; @@ -66,59 +35,10 @@ pub(crate) trait Harness { fn stream(case: Fixture) -> BoxFuture, CompletionError>>; } -pub(crate) fn normalize_turn( - choice: &AssistantChoice, - message_id: Option, - stop_reason: Option, -) -> Turn { - let mut items = Vec::new(); - - for item in choice.iter() { - let normalized = match item { - AssistantContent::Text(text) if !text.text.is_empty() => { - Some(NormalizedItem::Text(text.text.clone())) - } - AssistantContent::ToolCall(tool_call) => Some(NormalizedItem::ToolCall { - id: tool_call.id.clone(), - name: tool_call.function.name.clone(), - arguments: tool_call.function.arguments.clone(), - }), - AssistantContent::Reasoning(reasoning) => { - let text = reasoning.display_text(); - if text.is_empty() { - None - } else { - Some(NormalizedItem::Reasoning(text)) - } - } - _ => None, - }; - - if let Some(normalized) = normalized { - let duplicate_reasoning = matches!( - (&normalized, items.last()), - (NormalizedItem::Reasoning(current), Some(NormalizedItem::Reasoning(previous))) - if current == previous - ); - - if !duplicate_reasoning { - items.push(normalized); - } - } - } - - Turn { - items, - message_id, - stop_reason, - } -} - pub(crate) fn normalize_completion_response( response: &completion::CompletionResponse, - stop_reason: Option, ) -> Turn { - normalize_turn(&response.choice, response.message_id.clone(), stop_reason) + Turn::from_completion_response(response) } pub(crate) async fn drain_stream( @@ -171,6 +91,13 @@ pub(crate) async fn assert_stream_matches_non_stream(case: Fixture) H::family_name(), case ); + assert_eq!( + actual.stop_reason, + expected.stop_reason, + "{} stream {:?} stop_reason diverged", + H::family_name(), + case + ); } (Outcome::Unsupported(_), Outcome::Unsupported(_)) => {} (expected, actual) => panic!( diff --git a/rig/rig-core/src/providers/copilot/mod.rs b/rig/rig-core/src/providers/copilot/mod.rs index f9e46080b..99f98d1bd 100644 --- a/rig/rig-core/src/providers/copilot/mod.rs +++ b/rig/rig-core/src/providers/copilot/mod.rs @@ -480,6 +480,10 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse Err(CompletionError::ProviderError( @@ -780,6 +786,7 @@ where usage: core.usage, raw_response: CopilotCompletionResponse::Responses(Box::new(response)), message_id: core.message_id, + stop_reason: core.stop_reason, }) } else { let body = http_client::text(response).await?; diff --git a/rig/rig-core/src/providers/deepseek.rs b/rig/rig-core/src/providers/deepseek.rs index c89dafc4c..98fff1530 100644 --- a/rig/rig-core/src/providers/deepseek.rs +++ b/rig/rig-core/src/providers/deepseek.rs @@ -383,6 +383,9 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse Result { - let Choice { message, .. } = response.choices.first().ok_or_else(|| { + let Choice { + message, + finish_reason, + .. + } = response.choices.first().ok_or_else(|| { CompletionError::ResponseError("Response contained no choices".to_owned()) })?; + let stop_reason = Some(crate::providers::openai::completion::map_finish_reason( + finish_reason, + )); let mut content = message .content @@ -271,6 +278,7 @@ impl TryFrom for completion::CompletionResponse completion::StopReason { + match reason { + gemini_api_types::FinishReason::Stop => completion::StopReason::Stop, + gemini_api_types::FinishReason::MaxTokens => completion::StopReason::MaxTokens, + gemini_api_types::FinishReason::Safety => completion::StopReason::Safety, + other => completion::StopReason::Other(format!("{other:?}")), + } +} + #[derive(Clone, Debug)] pub struct CompletionModel { pub(crate) client: Client, @@ -507,6 +516,11 @@ impl TryFrom for completion::CompletionResponse Ok(Self::expected(case)), _ => { let raw = non_stream_response(case); - let stop_reason = raw - .candidates - .first() - .and_then(|candidate| candidate.finish_reason.clone()) - .map(map_finish_reason); let response: completion::CompletionResponse = raw.try_into()?; - Ok(Outcome::Supported(normalize_completion_response( - &response, - stop_reason, - ))) + Ok(Outcome::Supported(normalize_completion_response(&response))) } } } @@ -107,9 +99,7 @@ impl Harness for GeminiHarness { let model = CompletionModel::new(client, "gemini-test"); let stream = model.stream(stream_request()).await?; let response = drain_stream(stream).await?; - Ok(Outcome::Supported(normalize_completion_response( - &response, None, - ))) + Ok(Outcome::Supported(normalize_completion_response(&response))) } } }) @@ -207,15 +197,6 @@ fn candidate(parts: Vec, finish_reason: FinishReason) -> ContentCandidate } } -fn map_finish_reason(reason: FinishReason) -> StopReason { - match reason { - FinishReason::Stop => StopReason::Stop, - FinishReason::MaxTokens => StopReason::MaxTokens, - FinishReason::Safety => StopReason::Safety, - other => StopReason::Other(format!("{other:?}")), - } -} - fn stream_request() -> completion::CompletionRequest { completion::CompletionRequest { model: None, diff --git a/rig/rig-core/src/providers/gemini/interactions_api/mod.rs b/rig/rig-core/src/providers/gemini/interactions_api/mod.rs index 60346dd38..65a77b135 100644 --- a/rig/rig-core/src/providers/gemini/interactions_api/mod.rs +++ b/rig/rig-core/src/providers/gemini/interactions_api/mod.rs @@ -502,6 +502,7 @@ impl TryFrom for completion::CompletionResponse { usage, raw_response: response, message_id: None, + stop_reason: None, }) } } diff --git a/rig/rig-core/src/providers/gemini/streaming.rs b/rig/rig-core/src/providers/gemini/streaming.rs index 1bb2f7b1e..7da7b2fc5 100644 --- a/rig/rig-core/src/providers/gemini/streaming.rs +++ b/rig/rig-core/src/providers/gemini/streaming.rs @@ -6,7 +6,8 @@ use tracing_futures::Instrument; use super::completion::gemini_api_types::{ContentCandidate, Part, PartKind}; use super::completion::{ - CompletionModel, create_request_body, resolve_request_model, streaming_endpoint, + CompletionModel, create_request_body, map_finish_reason, resolve_request_model, + streaming_endpoint, }; use crate::completion::message::ReasoningContent; use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage}; @@ -211,6 +212,11 @@ where // Check if this is the final response if choice.finish_reason.is_some() { + if let Some(stop_reason) = + choice.finish_reason.as_ref().map(map_finish_reason) + { + yield Ok(streaming::RawStreamingChoice::StopReason(stop_reason)); + } let span = tracing::Span::current(); span.record_token_usage(&data.usage_metadata); final_usage = data.usage_metadata; diff --git a/rig/rig-core/src/providers/huggingface/completion.rs b/rig/rig-core/src/providers/huggingface/completion.rs index 468b76220..8100b7ae4 100644 --- a/rig/rig-core/src/providers/huggingface/completion.rs +++ b/rig/rig-core/src/providers/huggingface/completion.rs @@ -552,6 +552,9 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse choices + .first() + .and_then(|choice| choice.finish_reason.as_deref()) + .map(crate::providers::openai::completion::map_finish_reason), + CompletionResponse::Simple(_) => None, + }; Ok(completion::CompletionResponse { choice, usage, raw_response: response, message_id: None, + stop_reason, }) } } diff --git a/rig/rig-core/src/providers/mistral/completion.rs b/rig/rig-core/src/providers/mistral/completion.rs index 1f0b74fda..beec3344f 100644 --- a/rig/rig-core/src/providers/mistral/completion.rs +++ b/rig/rig-core/src/providers/mistral/completion.rs @@ -491,6 +491,9 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse { type Error = CompletionError; fn try_from(resp: CompletionResponse) -> Result { + let stop_reason = resp + .done_reason + .as_deref() + .map(crate::providers::openai::completion::map_finish_reason); match resp.message { // Process only if an assistant message is present. Message::Assistant { @@ -408,6 +412,7 @@ impl TryFrom for completion::CompletionResponse Err(CompletionError::ResponseError( diff --git a/rig/rig-core/src/providers/openai/completion/conformance_tests.rs b/rig/rig-core/src/providers/openai/completion/conformance_tests.rs index a5795591e..f4c8b97fa 100644 --- a/rig/rig-core/src/providers/openai/completion/conformance_tests.rs +++ b/rig/rig-core/src/providers/openai/completion/conformance_tests.rs @@ -71,16 +71,9 @@ impl Harness for OpenAiChatHarness { Fixture::MessageIdPreservation => Ok(Self::expected(case)), _ => { let raw = non_stream_response(case); - let stop_reason = raw - .choices - .first() - .map(|choice| map_finish_reason(&choice.finish_reason)); let response: completion::CompletionResponse = raw.try_into()?; - Ok(Outcome::Supported(normalize_completion_response( - &response, - stop_reason, - ))) + Ok(Outcome::Supported(normalize_completion_response(&response))) } } } @@ -102,9 +95,7 @@ impl Harness for OpenAiChatHarness { let stream = streaming::send_compatible_streaming_request(client, request).await?; let response = drain_stream(stream).await?; - Ok(Outcome::Supported(normalize_completion_response( - &response, None, - ))) + Ok(Outcome::Supported(normalize_completion_response(&response))) } } }) @@ -215,16 +206,6 @@ fn tool_call(id: &str, name: &str, arguments: serde_json::Value) -> ToolCall { } } -fn map_finish_reason(reason: &str) -> StopReason { - match reason { - "stop" => StopReason::Stop, - "tool_calls" => StopReason::ToolCalls, - "content_filter" => StopReason::ContentFilter, - "length" => StopReason::MaxTokens, - other => StopReason::Other(other.to_string()), - } -} - fn streaming_sse(case: Fixture) -> String { match case { Fixture::EmptyAssistantTurnAfterToolResult => concat!( diff --git a/rig/rig-core/src/providers/openai/completion/mod.rs b/rig/rig-core/src/providers/openai/completion/mod.rs index 1c9190c87..fb85e455b 100644 --- a/rig/rig-core/src/providers/openai/completion/mod.rs +++ b/rig/rig-core/src/providers/openai/completion/mod.rs @@ -24,6 +24,16 @@ use std::str::FromStr; pub mod streaming; +pub(crate) fn map_finish_reason(reason: &str) -> completion::StopReason { + match reason { + "stop" => completion::StopReason::Stop, + "tool_calls" => completion::StopReason::ToolCalls, + "content_filter" => completion::StopReason::ContentFilter, + "length" => completion::StopReason::MaxTokens, + other => completion::StopReason::Other(other.to_string()), + } +} + /// Serializes user content as a plain string when there's a single text item, /// otherwise as an array of content parts. fn serialize_user_content( @@ -795,6 +805,7 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse crate::completion::StopReason { + match reason { + FinishReason::ToolCalls => crate::completion::StopReason::ToolCalls, + FinishReason::Stop => crate::completion::StopReason::Stop, + FinishReason::ContentFilter => crate::completion::StopReason::ContentFilter, + FinishReason::Length => crate::completion::StopReason::MaxTokens, + FinishReason::Other(other) => crate::completion::StopReason::Other(other.clone()), + } +} + impl GenericCompletionModel where crate::client::Client: HttpClientExt + Clone + 'static, @@ -289,13 +299,19 @@ where } // Finish reason - if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls { - for (_idx, tool_call) in tool_calls.into_iter() { - yield Ok(streaming::RawStreamingChoice::ToolCall( - finalize_completed_streaming_tool_call(tool_call), - )); + if let Some(finish_reason) = &choice.finish_reason { + if *finish_reason == FinishReason::ToolCalls { + for (_idx, tool_call) in tool_calls.into_iter() { + yield Ok(streaming::RawStreamingChoice::ToolCall( + finalize_completed_streaming_tool_call(tool_call), + )); + } + tool_calls = HashMap::new(); } - tool_calls = HashMap::new(); + + yield Ok(streaming::RawStreamingChoice::StopReason(map_finish_reason( + finish_reason, + ))); } } Err(crate::http_client::Error::StreamEnded) => { diff --git a/rig/rig-core/src/providers/openai/responses_api/mod.rs b/rig/rig-core/src/providers/openai/responses_api/mod.rs index 9e510c861..64efa29ce 100644 --- a/rig/rig-core/src/providers/openai/responses_api/mod.rs +++ b/rig/rig-core/src/providers/openai/responses_api/mod.rs @@ -1463,6 +1463,7 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse Err(CompletionError::ResponseError( "Response contained no assistant message".to_owned(), diff --git a/rig/rig-core/src/providers/xai/completion.rs b/rig/rig-core/src/providers/xai/completion.rs index b698c771a..80f5fb3b8 100644 --- a/rig/rig-core/src/providers/xai/completion.rs +++ b/rig/rig-core/src/providers/xai/completion.rs @@ -131,6 +131,10 @@ impl TryFrom for completion::CompletionResponse Result { + let stop_reason = response + .status + .as_deref() + .map(|status| completion::StopReason::Other(status.to_string())); let content: Vec = response .output .iter() @@ -161,6 +165,7 @@ impl TryFrom for completion::CompletionResponse, + /// Provider-agnostic reason why the model stopped generating this turn. + pub stop_reason: Option, } impl StreamingCompletionResponse @@ -234,6 +240,7 @@ where response: None, final_response_yielded: AtomicBool::new(false), message_id: None, + stop_reason: None, } } @@ -303,6 +310,7 @@ where usage: Usage::new(), // Usage is not tracked in streaming responses raw_response: value.response, message_id: value.message_id, + stop_reason: value.stop_reason, } } } @@ -406,6 +414,10 @@ where stream.message_id = Some(id); stream.poll_next_unpin(cx) } + RawStreamingChoice::StopReason(reason) => { + stream.stop_reason = Some(reason); + stream.poll_next_unpin(cx) + } }, } } @@ -687,6 +699,15 @@ mod tests { StreamingCompletionResponse::stream(to_stream_result(stream)) } + fn create_stop_reason_stream() -> StreamingCompletionResponse { + let stream = stream! { + yield Ok(RawStreamingChoice::StopReason(crate::completion::StopReason::MaxTokens)); + yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 0 })); + }; + + StreamingCompletionResponse::stream(to_stream_result(stream)) + } + fn create_interleaved_stream() -> StreamingCompletionResponse { let stream = stream! { yield Ok(RawStreamingChoice::Reasoning { @@ -847,6 +868,17 @@ mod tests { assert!(stream.choice.is_empty()); } + #[tokio::test] + async fn test_stream_captures_stop_reason() { + let mut stream = create_stop_reason_stream(); + while stream.next().await.is_some() {} + + assert_eq!( + stream.stop_reason, + Some(crate::completion::StopReason::MaxTokens) + ); + } + #[tokio::test] async fn test_stream_aggregates_assistant_items_in_arrival_order() { let mut stream = create_interleaved_stream(); diff --git a/rig/rig-core/tests/core/prompt_response_messages.rs b/rig/rig-core/tests/core/prompt_response_messages.rs index 3377eee0a..e69ca09d3 100644 --- a/rig/rig-core/tests/core/prompt_response_messages.rs +++ b/rig/rig-core/tests/core/prompt_response_messages.rs @@ -46,6 +46,7 @@ impl CompletionModel for SimpleTextModel { }, raw_response: (), message_id: Some("msg_mock_1".to_string()), + stop_reason: None, }) } @@ -108,6 +109,7 @@ impl CompletionModel for ToolThenTextModel { }, raw_response: (), message_id: Some("msg_tool".to_string()), + stop_reason: None, }) } else { // Second turn: return a text response @@ -124,6 +126,7 @@ impl CompletionModel for ToolThenTextModel { }, raw_response: (), message_id: Some("msg_text".to_string()), + stop_reason: None, }) } } @@ -191,6 +194,7 @@ impl CompletionModel for ToolThenEmptyTerminalTurnModel { }, raw_response: (), message_id: Some("msg_text_tool".to_string()), + stop_reason: None, }) } else { Ok(CompletionResponse { @@ -204,6 +208,7 @@ impl CompletionModel for ToolThenEmptyTerminalTurnModel { }, raw_response: (), message_id: Some("msg_empty".to_string()), + stop_reason: None, }) } } @@ -244,6 +249,7 @@ impl CompletionModel for AlwaysToolCallModel { usage: Usage::new(), raw_response: (), message_id: None, + stop_reason: None, }) } From 67e1068e531912fe57354d7c640b25a34b27ba29 Mon Sep 17 00:00:00 2001 From: stephen Date: Fri, 17 Apr 2026 18:38:33 -0700 Subject: [PATCH 5/9] migration --- rig/rig-core/src/providers/copilot/mod.rs | 154 ++++++------- rig/rig-core/src/providers/deepseek.rs | 197 ++++++++++------ rig/rig-core/src/providers/galadriel.rs | 20 +- .../src/providers/huggingface/completion.rs | 27 +-- rig/rig-core/src/providers/hyperbolic.rs | 21 +- .../src/providers/mistral/completion.rs | 30 ++- .../src/providers/openai/completion/compat.rs | 45 ++++ .../src/providers/openai/completion/mod.rs | 40 ++-- .../openai/completion/stream_compat.rs | 218 ++++++++++++++++++ .../providers/openai/completion/streaming.rs | 107 ++------- .../src/providers/openrouter/completion.rs | 22 +- .../src/providers/openrouter/streaming.rs | 136 +++++------ rig/rig-core/src/providers/perplexity.rs | 30 ++- 13 files changed, 631 insertions(+), 416 deletions(-) create mode 100644 rig/rig-core/src/providers/openai/completion/compat.rs create mode 100644 rig/rig-core/src/providers/openai/completion/stream_compat.rs diff --git a/rig/rig-core/src/providers/copilot/mod.rs b/rig/rig-core/src/providers/copilot/mod.rs index 99f98d1bd..ffb565c30 100644 --- a/rig/rig-core/src/providers/copilot/mod.rs +++ b/rig/rig-core/src/providers/copilot/mod.rs @@ -26,6 +26,10 @@ use crate::completion::{self, CompletionError, GetTokenUsage}; use crate::embeddings::{self, EmbeddingError}; use crate::http_client::{self, HttpClientExt}; use crate::providers::openai; +use crate::providers::openai::completion::{ + CompatibleStreamingToolCall, ToolCallConflictPolicy, apply_compatible_tool_call_deltas, + take_finalized_tool_calls, take_tool_calls, +}; use crate::providers::openai::responses_api::{self, CompletionRequest as ResponsesRequest}; use crate::streaming::{self, RawStreamingChoice, StreamingCompletionResponse}; use crate::wasm_compat::{WasmCompatSend, WasmCompatSync}; @@ -1262,6 +1266,16 @@ enum ChatFinishReason { Other(String), } +fn map_chat_finish_reason(reason: &ChatFinishReason) -> completion::StopReason { + match reason { + ChatFinishReason::ToolCalls => completion::StopReason::ToolCalls, + ChatFinishReason::Stop => completion::StopReason::Stop, + ChatFinishReason::ContentFilter => completion::StopReason::ContentFilter, + ChatFinishReason::Length => completion::StopReason::MaxTokens, + ChatFinishReason::Other(other) => completion::StopReason::Other(other.clone()), + } +} + #[derive(Deserialize, Debug)] struct ChatStreamingChoice { delta: ChatStreamingDelta, @@ -1313,77 +1327,17 @@ where let delta = &choice.delta; if !delta.tool_calls.is_empty() { - for tool_call in &delta.tool_calls { - let index = tool_call.index; - - if let Some(new_id) = &tool_call.id - && !new_id.is_empty() - && let Some(new_name) = &tool_call.function.name - && !new_name.is_empty() - && let Some(existing) = tool_calls.get(&index) - && !existing.id.is_empty() - && existing.id != *new_id - && !existing.name.is_empty() - && existing.name != *new_name - { - let evicted = tool_calls.remove(&index).expect("checked above"); - yield Ok(RawStreamingChoice::ToolCall( - finalize_completed_streaming_tool_call(evicted), - )); - } - - let existing_tool_call = tool_calls - .entry(index) - .or_insert_with(streaming::RawStreamingToolCall::empty); - - if let Some(id) = &tool_call.id - && !id.is_empty() - { - existing_tool_call.id = id.clone(); - } - - if let Some(name) = &tool_call.function.name - && !name.is_empty() - { - existing_tool_call.name = name.clone(); - yield Ok(RawStreamingChoice::ToolCallDelta { - id: existing_tool_call.id.clone(), - internal_call_id: existing_tool_call.internal_call_id.clone(), - content: streaming::ToolCallDeltaContent::Name(name.clone()), - }); - } - - if let Some(chunk) = &tool_call.function.arguments - && !chunk.is_empty() - { - let current_args = match &existing_tool_call.arguments { - serde_json::Value::Null => String::new(), - serde_json::Value::String(s) => s.clone(), - value => value.to_string(), - }; - let combined = format!("{current_args}{chunk}"); - - if combined.trim_start().starts_with('{') - && combined.trim_end().ends_with('}') - { - match serde_json::from_str(&combined) { - Ok(parsed) => existing_tool_call.arguments = parsed, - Err(_) => { - existing_tool_call.arguments = - serde_json::Value::String(combined) - } - } - } else { - existing_tool_call.arguments = - serde_json::Value::String(combined); - } - - yield Ok(RawStreamingChoice::ToolCallDelta { - id: existing_tool_call.id.clone(), - internal_call_id: existing_tool_call.internal_call_id.clone(), - content: streaming::ToolCallDeltaContent::Delta(chunk.clone()), - }); - } + for event in apply_compatible_tool_call_deltas( + &mut tool_calls, + delta.tool_calls.iter().map(|tool_call| CompatibleStreamingToolCall { + index: tool_call.index, + id: tool_call.id.as_deref(), + name: tool_call.function.name.as_deref(), + arguments: tool_call.function.arguments.as_deref(), + }), + ToolCallConflictPolicy::EvictDistinctIdAndName, + ) { + yield Ok(event); } } @@ -1405,12 +1359,15 @@ where if let Some(finish_reason) = &choice.finish_reason && *finish_reason == ChatFinishReason::ToolCalls { - for (_idx, tool_call) in tool_calls.into_iter() { - yield Ok(RawStreamingChoice::ToolCall( - finalize_completed_streaming_tool_call(tool_call), - )); + for tool_call in take_finalized_tool_calls(&mut tool_calls) { + yield Ok(tool_call); } - tool_calls = HashMap::new(); + } + + if let Some(finish_reason) = &choice.finish_reason { + yield Ok(RawStreamingChoice::StopReason(map_chat_finish_reason( + finish_reason, + ))); } } Err(crate::http_client::Error::StreamEnded) => break, @@ -1428,8 +1385,8 @@ where return; } - for (_idx, tool_call) in tool_calls.into_iter() { - yield Ok(RawStreamingChoice::ToolCall(tool_call)); + for tool_call in take_tool_calls(&mut tool_calls) { + yield Ok(tool_call); } let final_usage = final_usage.unwrap_or_default(); @@ -1460,16 +1417,6 @@ where Ok(StreamingCompletionResponse::stream(Box::pin(stream))) } -fn finalize_completed_streaming_tool_call( - mut tool_call: streaming::RawStreamingToolCall, -) -> streaming::RawStreamingToolCall { - if tool_call.arguments.is_null() { - tool_call.arguments = serde_json::Value::Object(serde_json::Map::new()); - } - - tool_call -} - fn default_token_dir() -> Option { config_dir().map(|dir| dir.join("github_copilot")) } @@ -1493,7 +1440,7 @@ mod tests { use super::{ ChatApiErrorResponse, ChatCompletionResponse, Client, CompletionRoute, TEXT_EMBEDDING_3_SMALL, env_api_key, env_base_url, env_github_access_token, - route_for_model, + route_for_model, send_copilot_chat_streaming_request, }; use crate::client::CompletionClient; use crate::completion::CompletionModel; @@ -2018,6 +1965,35 @@ mod tests { ); } + #[tokio::test] + async fn chat_stream_captures_stop_reason() { + let sse_bytes = Bytes::from(concat!( + "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"finish_reason\":\"stop\"}],\"usage\":null}\n\n", + "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n\n", + "data: [DONE]\n\n", + )); + + let http_client = MockStreamingClient { sse_bytes }; + let req = http::Request::builder() + .method("POST") + .uri("http://localhost/v1/chat/completions") + .body(Vec::new()) + .expect("request should build"); + + let mut stream = send_copilot_chat_streaming_request(http_client, req) + .await + .expect("stream should start"); + + while let Some(chunk) = stream.next().await { + chunk.expect("stream chunk should deserialize"); + } + + assert_eq!( + stream.stop_reason, + Some(crate::completion::StopReason::Stop) + ); + } + #[test] fn env_api_key_prefers_github_prefixed_vars() { let env = env_map(&[ diff --git a/rig/rig-core/src/providers/deepseek.rs b/rig/rig-core/src/providers/deepseek.rs index 98fff1530..fa36d891f 100644 --- a/rig/rig-core/src/providers/deepseek.rs +++ b/rig/rig-core/src/providers/deepseek.rs @@ -9,7 +9,6 @@ //! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT); //! ``` -use crate::json_utils::empty_or_none; use async_stream::stream; use bytes::Bytes; use futures::StreamExt; @@ -25,6 +24,10 @@ use crate::completion::GetTokenUsage; use crate::http_client::sse::{Event, GenericEventSource}; use crate::http_client::{self, HttpClientExt}; use crate::message::{Document, DocumentSourceKind}; +use crate::providers::openai::completion::{ + CompatibleStreamingToolCall, ToolCallConflictPolicy, apply_compatible_tool_call_deltas, + map_finish_reason, take_finalized_tool_calls, take_tool_calls, +}; use crate::{ completion::{self, CompletionError, CompletionRequest}, json_utils, message, @@ -380,9 +383,7 @@ impl TryFrom for completion::CompletionResponse Result { - let choice = response.choices.first().ok_or_else(|| { - CompletionError::ResponseError("Response contained no choices".to_owned()) - })?; + let choice = crate::providers::openai::completion::first_choice(&response.choices)?; let stop_reason = Some(crate::providers::openai::completion::map_finish_reason( &choice.finish_reason, )); @@ -423,8 +424,6 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse, } #[derive(Deserialize, Debug)] @@ -741,8 +743,7 @@ where let stream = stream! { let mut final_usage = Usage::new(); - let mut text_response = String::new(); - let mut calls: HashMap = HashMap::new(); + let mut tool_calls: HashMap = HashMap::new(); while let Some(event_result) = event_source.next().await { match event_result { @@ -766,59 +767,47 @@ where let delta = &choice.delta; if !delta.tool_calls.is_empty() { - for tool_call in &delta.tool_calls { - let function = &tool_call.function; - - // Start of tool call - if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false) - && empty_or_none(&function.arguments) - { - let id = tool_call.id.clone().unwrap_or_default(); - let name = function.name.clone().unwrap(); - calls.insert(tool_call.index, (id, name, String::new())); - } - // Continuation of tool call - else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true) - && let Some(arguments) = &function.arguments - && !arguments.is_empty() - { - if let Some((id, name, existing_args)) = calls.get(&tool_call.index) { - let combined = format!("{}{}", existing_args, arguments); - calls.insert(tool_call.index, (id.clone(), name.clone(), combined)); - } else { - tracing::debug!("Partial tool call received but tool call was never started."); - } - } - // Complete tool call - else { - let id = tool_call.id.clone().unwrap_or_default(); - let name = function.name.clone().unwrap_or_default(); - let arguments_str = function.arguments.clone().unwrap_or_default(); - - let Ok(arguments_json) = json_utils::parse_tool_arguments(&arguments_str) else { - tracing::debug!("Couldn't parse tool call args '{}'", arguments_str); - continue; - }; - - yield Ok(crate::streaming::RawStreamingChoice::ToolCall( - crate::streaming::RawStreamingToolCall::new(id, name, arguments_json) - )); - } + for event in apply_compatible_tool_call_deltas( + &mut tool_calls, + delta.tool_calls.iter().map(|tool_call| CompatibleStreamingToolCall { + index: tool_call.index, + id: tool_call.id.as_deref(), + name: tool_call.function.name.as_deref(), + arguments: tool_call.function.arguments.as_deref(), + }), + ToolCallConflictPolicy::KeepIndex, + ) { + yield Ok(event); } } // DeepSeek-specific reasoning stream - if let Some(content) = &delta.reasoning_content { + if let Some(content) = &delta.reasoning_content + && !content.is_empty() + { yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta { id: None, reasoning: content.to_string() }); } - if let Some(content) = &delta.content { - text_response += content; + if let Some(content) = &delta.content + && !content.is_empty() + { yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone())); } + + if let Some(finish_reason) = &choice.finish_reason { + if finish_reason == "tool_calls" { + for tool_call in take_finalized_tool_calls(&mut tool_calls) { + yield Ok(tool_call); + } + } + + yield Ok(crate::streaming::RawStreamingChoice::StopReason( + map_finish_reason(finish_reason), + )); + } } if let Some(usage) = data.usage { @@ -838,25 +827,8 @@ where event_source.close(); - let mut tool_calls = Vec::new(); - // Flush accumulated tool calls - for (index, (id, name, arguments)) in calls { - let Ok(arguments_json) = json_utils::parse_tool_arguments(&arguments) else { - continue; - }; - - tool_calls.push(ToolCall { - id: id.clone(), - index, - r#type: ToolType::Function, - function: Function { - name: name.clone(), - arguments: arguments_json.clone() - } - }); - yield Ok(crate::streaming::RawStreamingChoice::ToolCall( - crate::streaming::RawStreamingToolCall::new(id, name, arguments_json) - )); + for tool_call in take_tool_calls(&mut tool_calls) { + yield Ok(tool_call); } yield Ok(crate::streaming::RawStreamingChoice::FinalResponse( @@ -879,6 +851,10 @@ pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner"; #[cfg(test)] mod tests { use super::*; + use crate::http_client::mock::MockStreamingClient; + use crate::streaming::StreamedAssistantContent; + use bytes::Bytes; + use futures::StreamExt; #[test] fn test_deserialize_vec_choice() { @@ -1136,4 +1112,75 @@ mod tests { .build() .expect("Client::builder() failed"); } + + #[tokio::test] + async fn test_stream_captures_stop_reason() { + let sse = concat!( + "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"finish_reason\":\"stop\"}],\"usage\":null}\n\n", + "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n\n", + "data: [DONE]\n\n", + ); + let client = MockStreamingClient { + sse_bytes: Bytes::from(sse), + }; + let req = http::Request::builder() + .method("POST") + .uri("http://localhost/v1/chat/completions") + .body(Vec::new()) + .expect("request should build"); + + let mut stream = send_compatible_streaming_request(client, req) + .await + .expect("stream should start"); + + while let Some(chunk) = stream.next().await { + chunk.expect("stream chunk should deserialize"); + } + + assert_eq!( + stream.stop_reason, + Some(crate::completion::StopReason::Stop) + ); + } + + #[tokio::test] + async fn test_stream_accumulates_tool_call_arguments_until_finish_reason() { + let sse = concat!( + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"function\":{\"name\":\"subtract\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n", + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"x\\\":2,\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n", + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\"y\\\":5}\"}}]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}\n\n", + "data: [DONE]\n\n", + ); + let client = MockStreamingClient { + sse_bytes: Bytes::from(sse), + }; + let req = http::Request::builder() + .method("POST") + .uri("http://localhost/v1/chat/completions") + .body(Vec::new()) + .expect("request should build"); + + let mut stream = send_compatible_streaming_request(client, req) + .await + .expect("stream should start"); + let mut collected_tool_calls = Vec::new(); + + while let Some(chunk) = stream.next().await { + if let StreamedAssistantContent::ToolCall { + tool_call, + internal_call_id: _, + } = chunk.expect("stream chunk should deserialize") + { + collected_tool_calls.push(tool_call); + } + } + + assert_eq!(collected_tool_calls.len(), 1); + assert_eq!(collected_tool_calls[0].id, "call_123"); + assert_eq!(collected_tool_calls[0].function.name, "subtract"); + assert_eq!( + collected_tool_calls[0].function.arguments, + serde_json::json!({"x": 2, "y": 5}) + ); + } } diff --git a/rig/rig-core/src/providers/galadriel.rs b/rig/rig-core/src/providers/galadriel.rs index 1f1e1dca0..0f849d821 100644 --- a/rig/rig-core/src/providers/galadriel.rs +++ b/rig/rig-core/src/providers/galadriel.rs @@ -239,9 +239,7 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse>(); content.extend(message.tool_calls.iter().map(|call| { completion::AssistantContent::tool_call( @@ -260,7 +259,6 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse Result { - let choice = response.choices.first().ok_or_else(|| { - CompletionError::ResponseError("Response contained no choices".to_owned()) - })?; + let choice = crate::providers::openai::completion::first_choice(&response.choices)?; let stop_reason = Some(crate::providers::openai::completion::map_finish_reason( &choice.finish_reason, )); @@ -565,8 +563,11 @@ impl TryFrom for completion::CompletionResponse message::AssistantContent::text(text), + AssistantContent::Text { text } => { + crate::providers::openai::completion::non_empty_text(text) + } }) + .flatten() .collect::>(); content.extend( @@ -588,8 +589,6 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse Result { - let choice = response.choices.first().ok_or_else(|| { - CompletionError::ResponseError("Response contained no choices".to_owned()) - })?; + let choice = openai::completion::first_choice(&response.choices)?; let stop_reason = Some(crate::providers::openai::completion::map_finish_reason( &choice.finish_reason, )); @@ -194,11 +192,12 @@ impl TryFrom for completion::CompletionResponse completion::AssistantContent::text(text), + AssistantContent::Text { text } => openai::completion::non_empty_text(text), AssistantContent::Refusal { refusal } => { - completion::AssistantContent::text(refusal) + openai::completion::non_empty_text(refusal) } }) + .flatten() .collect::>(); content.extend( @@ -220,8 +219,6 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse Result { - let choice = response.choices.first().ok_or_else(|| { - CompletionError::ResponseError("Response contained no choices".to_owned()) - })?; + let choice = crate::providers::openai::completion::first_choice(&response.choices)?; let stop_reason = Some(crate::providers::openai::completion::map_finish_reason( &choice.finish_reason, )); @@ -500,11 +498,9 @@ impl TryFrom for completion::CompletionResponse { - let mut content = if content.is_empty() { - vec![] - } else { - vec![completion::AssistantContent::text(content.clone())] - }; + let mut content = crate::providers::openai::completion::non_empty_text(content) + .into_iter() + .collect::>(); content.extend( tool_calls @@ -525,8 +521,6 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse(choices: &[T]) -> Result<&T, CompletionError> { + choices + .first() + .ok_or_else(|| CompletionError::ResponseError("Response contained no choices".to_owned())) +} + +pub(crate) fn map_finish_reason(reason: &str) -> completion::StopReason { + match reason { + "stop" => completion::StopReason::Stop, + "tool_calls" => completion::StopReason::ToolCalls, + "content_filter" => completion::StopReason::ContentFilter, + "length" => completion::StopReason::MaxTokens, + other => completion::StopReason::Other(other.to_string()), + } +} + +pub(crate) fn non_empty_text(text: impl AsRef) -> Option { + let text = text.as_ref(); + if text.is_empty() { + None + } else { + Some(completion::AssistantContent::text(text)) + } +} + +pub(crate) fn build_completion_response( + raw_response: R, + usage: completion::Usage, + message_id: Option, + stop_reason: Option, + choice: C, +) -> completion::CompletionResponse +where + C: Into, +{ + completion::CompletionResponse { + choice: choice.into(), + usage, + raw_response, + message_id, + stop_reason, + } +} diff --git a/rig/rig-core/src/providers/openai/completion/mod.rs b/rig/rig-core/src/providers/openai/completion/mod.rs index fb85e455b..acf8926b6 100644 --- a/rig/rig-core/src/providers/openai/completion/mod.rs +++ b/rig/rig-core/src/providers/openai/completion/mod.rs @@ -22,17 +22,17 @@ use tracing::{Instrument, Level, enabled, info_span}; use std::str::FromStr; +mod compat; +mod stream_compat; pub mod streaming; -pub(crate) fn map_finish_reason(reason: &str) -> completion::StopReason { - match reason { - "stop" => completion::StopReason::Stop, - "tool_calls" => completion::StopReason::ToolCalls, - "content_filter" => completion::StopReason::ContentFilter, - "length" => completion::StopReason::MaxTokens, - other => completion::StopReason::Other(other.to_string()), - } -} +pub(crate) use compat::{ + build_completion_response, first_choice, map_finish_reason, non_empty_text, +}; +pub(crate) use stream_compat::{ + CompatibleStreamingToolCall, ToolCallConflictPolicy, apply_compatible_tool_call_deltas, + take_finalized_tool_calls, take_tool_calls, +}; /// Serializes user content as a plain string when there's a single text item, /// otherwise as an array of content parts. @@ -802,9 +802,7 @@ impl TryFrom for completion::CompletionResponse Result { - let choice = response.choices.first().ok_or_else(|| { - CompletionError::ResponseError("Response contained no choices".to_owned()) - })?; + let choice = first_choice(&response.choices)?; let stop_reason = Some(map_finish_reason(&choice.finish_reason)); let content = match &choice.message { @@ -820,11 +818,7 @@ impl TryFrom for completion::CompletionResponse text, AssistantContent::Refusal { refusal } => refusal, }; - if s.is_empty() { - None - } else { - Some(completion::AssistantContent::text(s)) - } + non_empty_text(s) }) .collect::>(); @@ -847,8 +841,6 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse { + pub index: usize, + pub id: Option<&'a str>, + pub name: Option<&'a str>, + pub arguments: Option<&'a str>, +} + +pub(crate) fn apply_compatible_tool_call_deltas<'a, R>( + tool_calls: &mut HashMap, + incoming: impl IntoIterator>, + conflict_policy: ToolCallConflictPolicy, +) -> Vec> +where + R: Clone, +{ + let mut events = Vec::new(); + + for tool_call in incoming { + let index = tool_call.index; + + if conflict_policy == ToolCallConflictPolicy::EvictDistinctIdAndName + && should_evict_existing_tool_call(tool_calls.get(&index), tool_call.id, tool_call.name) + && let Some(evicted) = tool_calls.remove(&index) + { + events.push(RawStreamingChoice::ToolCall( + finalize_completed_streaming_tool_call(evicted), + )); + } + + let existing_tool_call = tool_calls + .entry(index) + .or_insert_with(RawStreamingToolCall::empty); + + if let Some(id) = tool_call.id + && !id.is_empty() + { + existing_tool_call.id = id.to_owned(); + } + + if let Some(name) = tool_call.name + && !name.is_empty() + { + existing_tool_call.name = name.to_owned(); + events.push(RawStreamingChoice::ToolCallDelta { + id: existing_tool_call.id.clone(), + internal_call_id: existing_tool_call.internal_call_id.clone(), + content: ToolCallDeltaContent::Name(name.to_owned()), + }); + } + + if let Some(chunk) = tool_call.arguments + && !chunk.is_empty() + { + append_tool_call_arguments(&mut existing_tool_call.arguments, chunk); + events.push(RawStreamingChoice::ToolCallDelta { + id: existing_tool_call.id.clone(), + internal_call_id: existing_tool_call.internal_call_id.clone(), + content: ToolCallDeltaContent::Delta(chunk.to_owned()), + }); + } + } + + events +} + +pub(crate) fn take_finalized_tool_calls( + tool_calls: &mut HashMap, +) -> Vec> +where + R: Clone, +{ + std::mem::take(tool_calls) + .into_values() + .map(|tool_call| { + RawStreamingChoice::ToolCall(finalize_completed_streaming_tool_call(tool_call)) + }) + .collect() +} + +pub(crate) fn take_tool_calls( + tool_calls: &mut HashMap, +) -> Vec> +where + R: Clone, +{ + std::mem::take(tool_calls) + .into_values() + .map(RawStreamingChoice::ToolCall) + .collect() +} + +pub(crate) fn finalize_completed_streaming_tool_call( + mut tool_call: RawStreamingToolCall, +) -> RawStreamingToolCall { + if tool_call.arguments.is_null() { + tool_call.arguments = serde_json::Value::Object(serde_json::Map::new()); + } + + tool_call +} + +fn should_evict_existing_tool_call( + existing: Option<&RawStreamingToolCall>, + new_id: Option<&str>, + new_name: Option<&str>, +) -> bool { + let Some(existing) = existing else { + return false; + }; + + let Some(new_id) = new_id.filter(|id| !id.is_empty()) else { + return false; + }; + let Some(new_name) = new_name.filter(|name| !name.is_empty()) else { + return false; + }; + + !existing.id.is_empty() + && existing.id != new_id + && !existing.name.is_empty() + && existing.name != new_name +} + +fn append_tool_call_arguments(arguments: &mut serde_json::Value, chunk: &str) { + let current_arguments = match arguments { + serde_json::Value::Null => String::new(), + serde_json::Value::String(value) => value.clone(), + ref value => value.to_string(), + }; + let combined = format!("{current_arguments}{chunk}"); + + if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') { + match serde_json::from_str(&combined) { + Ok(parsed) => *arguments = parsed, + Err(_) => *arguments = serde_json::Value::String(combined), + } + } else { + *arguments = serde_json::Value::String(combined); + } +} + +#[cfg(test)] +mod tests { + use super::{ + CompatibleStreamingToolCall, ToolCallConflictPolicy, apply_compatible_tool_call_deltas, + finalize_completed_streaming_tool_call, take_finalized_tool_calls, + }; + use crate::streaming::{RawStreamingChoice, RawStreamingToolCall}; + use std::collections::HashMap; + + #[test] + fn evicts_distinct_tool_calls_that_reuse_the_same_index() { + let mut tool_calls = HashMap::from([( + 0, + RawStreamingToolCall { + id: "call_1".to_owned(), + internal_call_id: "internal_1".to_owned(), + call_id: None, + name: "weather".to_owned(), + arguments: serde_json::json!({"city":"Paris"}), + signature: None, + additional_params: None, + }, + )]); + + let events = apply_compatible_tool_call_deltas::<()>( + &mut tool_calls, + [CompatibleStreamingToolCall { + index: 0, + id: Some("call_2"), + name: Some("time"), + arguments: Some("{"), + }], + ToolCallConflictPolicy::EvictDistinctIdAndName, + ); + + assert!( + matches!(events.first(), Some(RawStreamingChoice::ToolCall(tool_call)) if tool_call.id == "call_1") + ); + assert_eq!( + tool_calls.get(&0).map(|tool_call| tool_call.id.as_str()), + Some("call_2") + ); + assert_eq!( + tool_calls.get(&0).map(|tool_call| tool_call.name.as_str()), + Some("time") + ); + } + + #[test] + fn finalizes_null_arguments_into_empty_objects() { + let finalized = finalize_completed_streaming_tool_call(RawStreamingToolCall::empty()); + assert_eq!(finalized.arguments, serde_json::json!({})); + } + + #[test] + fn drains_finalized_tool_calls() { + let mut tool_calls = HashMap::from([(0, RawStreamingToolCall::empty())]); + + let events = take_finalized_tool_calls::<()>(&mut tool_calls); + + assert!(tool_calls.is_empty()); + assert!( + matches!(events.as_slice(), [RawStreamingChoice::ToolCall(tool_call)] if tool_call.arguments == serde_json::json!({})) + ); + } +} diff --git a/rig/rig-core/src/providers/openai/completion/streaming.rs b/rig/rig-core/src/providers/openai/completion/streaming.rs index 8b459d6f7..b2fe808da 100644 --- a/rig/rig-core/src/providers/openai/completion/streaming.rs +++ b/rig/rig-core/src/providers/openai/completion/streaming.rs @@ -12,7 +12,11 @@ use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage}; use crate::http_client::HttpClientExt; use crate::http_client::sse::{Event, GenericEventSource}; use crate::json_utils::{self, merge}; -use crate::providers::openai::completion::{GenericCompletionModel, OpenAIRequestParams, Usage}; +use crate::providers::openai::completion::{ + CompatibleStreamingToolCall, GenericCompletionModel, OpenAIRequestParams, + ToolCallConflictPolicy, Usage, apply_compatible_tool_call_deltas, take_finalized_tool_calls, + take_tool_calls, +}; use crate::streaming::{self, RawStreamingChoice}; // ================================================================ @@ -212,75 +216,17 @@ where let delta = &choice.delta; if !delta.tool_calls.is_empty() { - for tool_call in &delta.tool_calls { - let index = tool_call.index; - - // Some API gateways (e.g. LiteLLM, OneAPI) emit multiple - // distinct tool calls all sharing index 0. Detect this by - // comparing both the `id` and `name`: only evict when a new - // chunk carries a different non-empty id AND a different - // non-empty name. Checking the name prevents false evictions - // from providers (e.g. GLM-4) that send a unique id on every - // SSE chunk for the same logical tool call. - if let Some(new_id) = &tool_call.id - && !new_id.is_empty() - && let Some(new_name) = &tool_call.function.name - && !new_name.is_empty() - && let Some(existing) = tool_calls.get(&index) - && !existing.id.is_empty() - && existing.id != *new_id - && !existing.name.is_empty() - && existing.name != *new_name - { - let evicted = tool_calls.remove(&index).expect("checked above"); - yield Ok(streaming::RawStreamingChoice::ToolCall( - finalize_completed_streaming_tool_call(evicted), - )); - } - - let existing_tool_call = tool_calls.entry(index).or_insert_with(streaming::RawStreamingToolCall::empty); - - if let Some(id) = &tool_call.id && !id.is_empty() { - existing_tool_call.id = id.clone(); - } - - if let Some(name) = &tool_call.function.name && !name.is_empty() { - existing_tool_call.name = name.clone(); - yield Ok(streaming::RawStreamingChoice::ToolCallDelta { - id: existing_tool_call.id.clone(), - internal_call_id: existing_tool_call.internal_call_id.clone(), - content: streaming::ToolCallDeltaContent::Name(name.clone()), - }); - } - - // Convert current arguments to string if needed - if let Some(chunk) = &tool_call.function.arguments && !chunk.is_empty() { - let current_args = match &existing_tool_call.arguments { - serde_json::Value::Null => String::new(), - serde_json::Value::String(s) => s.clone(), - v => v.to_string(), - }; - - // Concatenate the new chunk - let combined = format!("{current_args}{chunk}"); - - // Try to parse as JSON if it looks complete - if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') { - match serde_json::from_str(&combined) { - Ok(parsed) => existing_tool_call.arguments = parsed, - Err(_) => existing_tool_call.arguments = serde_json::Value::String(combined), - } - } else { - existing_tool_call.arguments = serde_json::Value::String(combined); - } - - // Emit the delta so UI can show progress - yield Ok(streaming::RawStreamingChoice::ToolCallDelta { - id: existing_tool_call.id.clone(), - internal_call_id: existing_tool_call.internal_call_id.clone(), - content: streaming::ToolCallDeltaContent::Delta(chunk.clone()), - }); - } + for event in apply_compatible_tool_call_deltas( + &mut tool_calls, + delta.tool_calls.iter().map(|tool_call| CompatibleStreamingToolCall { + index: tool_call.index, + id: tool_call.id.as_deref(), + name: tool_call.function.name.as_deref(), + arguments: tool_call.function.arguments.as_deref(), + }), + ToolCallConflictPolicy::EvictDistinctIdAndName, + ) { + yield Ok(event); } } @@ -301,12 +247,9 @@ where // Finish reason if let Some(finish_reason) = &choice.finish_reason { if *finish_reason == FinishReason::ToolCalls { - for (_idx, tool_call) in tool_calls.into_iter() { - yield Ok(streaming::RawStreamingChoice::ToolCall( - finalize_completed_streaming_tool_call(tool_call), - )); + for tool_call in take_finalized_tool_calls(&mut tool_calls) { + yield Ok(tool_call); } - tool_calls = HashMap::new(); } yield Ok(streaming::RawStreamingChoice::StopReason(map_finish_reason( @@ -330,8 +273,8 @@ where event_source.close(); // Flush any accumulated tool calls (that weren't emitted as ToolCall earlier) - for (_idx, tool_call) in tool_calls.into_iter() { - yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call)); + for tool_call in take_tool_calls(&mut tool_calls) { + yield Ok(tool_call); } let final_usage = final_usage.unwrap_or_default(); @@ -358,16 +301,6 @@ where ))) } -fn finalize_completed_streaming_tool_call( - mut tool_call: streaming::RawStreamingToolCall, -) -> streaming::RawStreamingToolCall { - if tool_call.arguments.is_null() { - tool_call.arguments = serde_json::Value::Object(serde_json::Map::new()); - } - - tool_call -} - #[cfg(test)] mod tests { use super::*; diff --git a/rig/rig-core/src/providers/openrouter/completion.rs b/rig/rig-core/src/providers/openrouter/completion.rs index 8b0ac391b..fcf658168 100644 --- a/rig/rig-core/src/providers/openrouter/completion.rs +++ b/rig/rig-core/src/providers/openrouter/completion.rs @@ -592,9 +592,7 @@ impl TryFrom for completion::CompletionResponse Result { - let choice = response.choices.first().ok_or_else(|| { - CompletionError::ResponseError("Response contained no choices".to_owned()) - })?; + let choice = crate::providers::openai::completion::first_choice(&response.choices)?; let stop_reason = choice .finish_reason .as_deref() @@ -708,8 +706,6 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse crate::completion::StopReason { + match reason { + FinishReason::ToolCalls => crate::completion::StopReason::ToolCalls, + FinishReason::Stop => crate::completion::StopReason::Stop, + FinishReason::Error => crate::completion::StopReason::Other("error".to_owned()), + FinishReason::ContentFilter => crate::completion::StopReason::ContentFilter, + FinishReason::Length => crate::completion::StopReason::MaxTokens, + FinishReason::Other(other) => crate::completion::StopReason::Other(other.clone()), + } +} + #[derive(Deserialize, Debug, PartialEq)] #[serde(rename_all = "snake_case")] pub enum FinishReason { @@ -212,54 +227,17 @@ where let delta = &choice.delta; if !delta.tool_calls.is_empty() { - for tool_call in &delta.tool_calls { - let index = tool_call.index; - - // Get or create tool call entry - let existing_tool_call = tool_calls.entry(index).or_insert_with(streaming::RawStreamingToolCall::empty); - - // Update fields if present - if let Some(id) = &tool_call.id && !id.is_empty() { - existing_tool_call.id = id.clone(); - } - - if let Some(name) = &tool_call.function.name && !name.is_empty() { - existing_tool_call.name = name.clone(); - yield Ok(streaming::RawStreamingChoice::ToolCallDelta { - id: existing_tool_call.id.clone(), - internal_call_id: existing_tool_call.internal_call_id.clone(), - content: streaming::ToolCallDeltaContent::Name(name.clone()), - }); - } - - // Convert current arguments to string if needed - if let Some(chunk) = &tool_call.function.arguments && !chunk.is_empty() { - let current_args = match &existing_tool_call.arguments { - serde_json::Value::Null => String::new(), - serde_json::Value::String(s) => s.clone(), - v => v.to_string(), - }; - - // Concatenate the new chunk - let combined = format!("{current_args}{chunk}"); - - // Try to parse as JSON if it looks complete - if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') { - match serde_json::from_str(&combined) { - Ok(parsed) => existing_tool_call.arguments = parsed, - Err(_) => existing_tool_call.arguments = serde_json::Value::String(combined), - } - } else { - existing_tool_call.arguments = serde_json::Value::String(combined); - } - - // Emit the delta so UI can show progress - yield Ok(streaming::RawStreamingChoice::ToolCallDelta { - id: existing_tool_call.id.clone(), - internal_call_id: existing_tool_call.internal_call_id.clone(), - content: streaming::ToolCallDeltaContent::Delta(chunk.clone()), - }); - } + for event in apply_compatible_tool_call_deltas( + &mut tool_calls, + delta.tool_calls.iter().map(|tool_call| CompatibleStreamingToolCall { + index: tool_call.index, + id: tool_call.id.as_deref(), + name: tool_call.function.name.as_deref(), + arguments: tool_call.function.arguments.as_deref(), + }), + ToolCallConflictPolicy::KeepIndex, + ) { + yield Ok(event); } // Update the signature and the additional params of the tool call if present @@ -294,12 +272,15 @@ where // Finish reason if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls { - for (_idx, tool_call) in tool_calls.into_iter() { - yield Ok(streaming::RawStreamingChoice::ToolCall( - finalize_completed_streaming_tool_call(tool_call), - )); + for tool_call in take_finalized_tool_calls(&mut tool_calls) { + yield Ok(tool_call); } - tool_calls = HashMap::new(); + } + + if let Some(finish_reason) = &choice.finish_reason { + yield Ok(streaming::RawStreamingChoice::StopReason(map_finish_reason( + finish_reason, + ))); } } Err(crate::http_client::Error::StreamEnded) => { @@ -317,8 +298,8 @@ where event_source.close(); // Flush any accumulated tool calls (that weren't emitted as ToolCall earlier) - for (_idx, tool_call) in tool_calls.into_iter() { - yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call)); + for tool_call in take_tool_calls(&mut tool_calls) { + yield Ok(tool_call); } // Final response with usage @@ -332,16 +313,6 @@ where ))) } -fn finalize_completed_streaming_tool_call( - mut tool_call: streaming::RawStreamingToolCall, -) -> streaming::RawStreamingToolCall { - if tool_call.arguments.is_null() { - tool_call.arguments = Value::Object(serde_json::Map::new()); - } - - tool_call -} - #[cfg(test)] mod tests { use super::*; @@ -548,4 +519,39 @@ mod tests { assert_eq!(error.code, 500); assert_eq!(error.message, "Provider disconnected"); } + + #[tokio::test] + async fn test_stream_captures_stop_reason() { + use crate::http_client::mock::MockStreamingClient; + use bytes::Bytes; + use futures::StreamExt; + + let sse = concat!( + "data: {\"id\":\"gen-1\",\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":\"stop\"}],\"usage\":null}\n\n", + "data: {\"id\":\"gen-2\",\"model\":\"gpt-4\",\"choices\":[],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n\n", + "data: [DONE]\n\n", + ); + + let client = MockStreamingClient { + sse_bytes: Bytes::from(sse), + }; + let req = http::Request::builder() + .method("POST") + .uri("http://localhost/v1/chat/completions") + .body(Vec::new()) + .unwrap(); + + let mut stream = send_compatible_streaming_request(client, req) + .await + .expect("stream should start"); + + while let Some(chunk) = stream.next().await { + chunk.expect("stream chunk should deserialize"); + } + + assert_eq!( + stream.stop_reason, + Some(crate::completion::StopReason::Stop) + ); + } } diff --git a/rig/rig-core/src/providers/perplexity.rs b/rig/rig-core/src/providers/perplexity.rs index 04b3548d0..5991f97f8 100644 --- a/rig/rig-core/src/providers/perplexity.rs +++ b/rig/rig-core/src/providers/perplexity.rs @@ -175,9 +175,7 @@ impl TryFrom for completion::CompletionResponse Result { - let choice = response.choices.first().ok_or_else(|| { - CompletionError::ResponseError("Response contained no choices".to_owned()) - })?; + let choice = crate::providers::openai::completion::first_choice(&response.choices)?; let stop_reason = Some(crate::providers::openai::completion::map_finish_reason( &choice.finish_reason, )); @@ -186,19 +184,29 @@ impl TryFrom for completion::CompletionResponse Ok(completion::CompletionResponse { - choice: completion::AssistantChoice::one(content.clone().into()), - usage: completion::Usage { + } => { + let normalized_content = + crate::providers::openai::completion::non_empty_text(content) + .into_iter() + .collect::>(); + let usage = completion::Usage { input_tokens: response.usage.prompt_tokens as u64, output_tokens: response.usage.completion_tokens as u64, total_tokens: response.usage.total_tokens as u64, cached_input_tokens: 0, cache_creation_input_tokens: 0, - }, - raw_response: response, - message_id: None, - stop_reason, - }), + }; + + Ok( + crate::providers::openai::completion::build_completion_response( + response, + usage, + None, + stop_reason, + normalized_content, + ), + ) + } _ => Err(CompletionError::ResponseError( "Response contained no assistant message".to_owned(), )), From d4f83d7713efc8336be0d0529fea4718ba060c26 Mon Sep 17 00:00:00 2001 From: stephen Date: Fri, 17 Apr 2026 19:17:16 -0700 Subject: [PATCH 6/9] migration --- rig/rig-core/src/providers/azure.rs | 126 ++++--- rig/rig-core/src/providers/deepseek.rs | 64 ++-- rig/rig-core/src/providers/galadriel.rs | 58 ++-- .../src/providers/gemini/conformance_tests.rs | 30 +- rig/rig-core/src/providers/groq.rs | 87 ++--- .../src/providers/huggingface/completion.rs | 72 ++-- rig/rig-core/src/providers/hyperbolic.rs | 61 ++-- rig/rig-core/src/providers/mira.rs | 40 +-- rig/rig-core/src/providers/moonshot.rs | 56 ++-- .../src/providers/openai/completion/mod.rs | 59 ++-- .../openai/completion/request_compat.rs | 317 ++++++++++++++++++ .../src/providers/openrouter/completion.rs | 56 ++-- rig/rig-core/src/providers/perplexity.rs | 69 ++-- .../src/providers/together/completion.rs | 91 +++-- 14 files changed, 701 insertions(+), 485 deletions(-) create mode 100644 rig/rig-core/src/providers/openai/completion/request_compat.rs diff --git a/rig/rig-core/src/providers/azure.rs b/rig/rig-core/src/providers/azure.rs index 9573d6f1a..829db7be5 100644 --- a/rig/rig-core/src/providers/azure.rs +++ b/rig/rig-core/src/providers/azure.rs @@ -579,43 +579,29 @@ impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - //FIXME: Must fix! - if req.tool_choice.is_some() { - tracing::warn!( - "Tool choice is currently not supported in Azure OpenAI. This should be fixed by Rig 0.25." - ); - } - - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![openai::Message::system(preamble)], - None => vec![], - }; - - if let Some(docs) = req.normalized_documents() { - let docs: Vec = docs.try_into()?; - full_history.extend(docs); - } - - let chat_history: Vec = req - .chat_history - .clone() - .into_iter() - .map(|message| message.try_into()) - .collect::>, _>>()? - .into_iter() - .flatten() - .collect(); - - full_history.extend(chat_history); - - let tool_choice = req - .tool_choice - .clone() + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools, + tool_choice, + additional_params, + output_schema, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("Azure OpenAI") + .supports_output_schema(), + openai::Message::system, + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; + + let tool_choice = tool_choice .map(crate::providers::openai::ToolChoice::try_from) .transpose()?; - let additional_params = if let Some(schema) = req.output_schema { + let additional_params = if let Some(schema) = output_schema { let name = schema .as_object() .and_then(|o| o.get("title")) @@ -634,21 +620,19 @@ impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest { } } }); - Some(match req.additional_params { + Some(match additional_params { Some(existing) => json_utils::merge(existing, response_format), None => response_format, }) } else { - req.additional_params + additional_params }; Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - tools: req - .tools - .clone() + model, + messages, + temperature, + tools: tools .into_iter() .map(openai::ToolDefinition::from) .collect::>(), @@ -1073,6 +1057,64 @@ mod azure_tests { use crate::prelude::TypedPrompt; use crate::providers::openai::GPT_5_MINI; + #[test] + fn azure_request_uses_request_model_override() { + let request = CompletionRequest { + model: Some("azure-override".to_string()), + preamble: Some("system".to_string()), + chat_history: OneOrMany::one("hello".into()), + documents: vec![], + max_tokens: None, + temperature: None, + tools: vec![], + tool_choice: None, + additional_params: None, + output_schema: None, + }; + + let converted = AzureOpenAICompletionRequest::try_from(("azure-default", request)) + .expect("request should convert"); + + assert_eq!(converted.model, "azure-override"); + } + + #[test] + fn azure_request_merges_structured_output_into_additional_params() { + let request = CompletionRequest { + model: None, + preamble: Some("system".to_string()), + chat_history: OneOrMany::one("hello".into()), + documents: vec![], + max_tokens: None, + temperature: None, + tools: vec![], + tool_choice: None, + additional_params: Some(serde_json::json!({"foo":"bar"})), + output_schema: Some( + serde_json::from_value(serde_json::json!({ + "title": "WeatherResponse", + "type": "object", + "properties": { + "city": { "type": "string" } + } + })) + .expect("schema should deserialize"), + ), + }; + + let converted = AzureOpenAICompletionRequest::try_from(("azure-default", request)) + .expect("request should convert"); + let params = converted + .additional_params + .expect("additional params should be present"); + + assert_eq!(params["foo"], serde_json::json!("bar")); + assert_eq!( + params["response_format"]["json_schema"]["name"], + serde_json::json!("WeatherResponse") + ); + } + #[tokio::test] #[ignore] async fn test_azure_embedding() { diff --git a/rig/rig-core/src/providers/deepseek.rs b/rig/rig-core/src/providers/deepseek.rs index fa36d891f..3daefaf50 100644 --- a/rig/rig-core/src/providers/deepseek.rs +++ b/rig/rig-core/src/providers/deepseek.rs @@ -25,8 +25,9 @@ use crate::http_client::sse::{Event, GenericEventSource}; use crate::http_client::{self, HttpClientExt}; use crate::message::{Document, DocumentSourceKind}; use crate::providers::openai::completion::{ - CompatibleStreamingToolCall, ToolCallConflictPolicy, apply_compatible_tool_call_deltas, - map_finish_reason, take_finalized_tool_calls, take_tool_calls, + CompatibleChatProfile, CompatibleStreamingToolCall, ToolCallConflictPolicy, + apply_compatible_tool_call_deltas, build_compatible_request_core, map_finish_reason, + take_finalized_tool_calls, take_tool_calls, }; use crate::{ completion::{self, CompletionError, CompletionRequest}, @@ -468,50 +469,37 @@ impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for DeepSeek"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![Message::system(preamble)], - None => vec![], - }; - - if let Some(docs) = req.normalized_documents() { - let docs: Vec = docs.try_into()?; - full_history.extend(docs); - } - - let chat_history: Vec = req - .chat_history - .clone() - .into_iter() - .map(|message| message.try_into()) - .collect::>, _>>()? - .into_iter() - .flatten() - .collect(); - - full_history.extend(chat_history); - - let tool_choice = req - .tool_choice - .clone() + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools, + tool_choice, + additional_params, + output_schema: _, + } = build_compatible_request_core( + model, + req, + CompatibleChatProfile::new("DeepSeek"), + Message::system, + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; + + let tool_choice = tool_choice .map(crate::providers::openrouter::ToolChoice::try_from) .transpose()?; Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - tools: req - .tools - .clone() + model, + messages, + temperature, + tools: tools .into_iter() .map(ToolDefinition::from) .collect::>(), tool_choice, - additional_params: req.additional_params, + additional_params, }) } } diff --git a/rig/rig-core/src/providers/galadriel.rs b/rig/rig-core/src/providers/galadriel.rs index 0f849d821..680ceb1b5 100644 --- a/rig/rig-core/src/providers/galadriel.rs +++ b/rig/rig-core/src/providers/galadriel.rs @@ -453,49 +453,37 @@ impl TryFrom<(&str, CompletionRequest)> for GaladrielCompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for Galadriel"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - // Build up the order of messages (context, chat_history, prompt) - let mut partial_history = vec![]; - if let Some(docs) = req.normalized_documents() { - partial_history.push(docs); - } - partial_history.extend(req.chat_history); - - // Add preamble to chat history (if available) - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![Message::system(preamble)], - None => vec![], - }; - - // Convert and extend the rest of the history - full_history.extend( - partial_history - .into_iter() - .map(message::Message::try_into) - .collect::, _>>()?, - ); - - let tool_choice = req - .tool_choice - .clone() + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools, + tool_choice, + additional_params, + output_schema: _, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("Galadriel"), + Message::system, + |message| Ok(vec![message.try_into()?]), + )?; + + let tool_choice = tool_choice .map(crate::providers::openai::completion::ToolChoice::try_from) .transpose()?; Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - tools: req - .tools - .clone() + model, + messages, + temperature, + tools: tools .into_iter() .map(ToolDefinition::from) .collect::>(), tool_choice, - additional_params: req.additional_params, + additional_params, }) } } diff --git a/rig/rig-core/src/providers/gemini/conformance_tests.rs b/rig/rig-core/src/providers/gemini/conformance_tests.rs index e816b398a..0ae339147 100644 --- a/rig/rig-core/src/providers/gemini/conformance_tests.rs +++ b/rig/rig-core/src/providers/gemini/conformance_tests.rs @@ -214,30 +214,12 @@ fn stream_request() -> completion::CompletionRequest { fn streaming_sse(case: Fixture) -> String { match case { - Fixture::EmptyAssistantTurnAfterToolResult => concat!( - "data: {\"candidates\":[{\"content\":{\"parts\":[],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":0,\"totalTokenCount\":10}}\n\n", - ) - .to_string(), - Fixture::ToolOnlyTurn => concat!( - "data: {\"candidates\":[{\"content\":{\"parts\":[{\"functionCall\":{\"name\":\"lookup_weather\",\"args\":{\"city\":\"Paris\"}}}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":5,\"totalTokenCount\":15}}\n\n", - ) - .to_string(), - Fixture::TextAndToolCallTurn => concat!( - "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Need weather data first.\"},{\"functionCall\":{\"name\":\"lookup_weather\",\"args\":{\"city\":\"Paris\"}}}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":7,\"totalTokenCount\":17}}\n\n", - ) - .to_string(), - Fixture::EmptyTextBlocks => concat!( - "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"\"}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":0,\"totalTokenCount\":10}}\n\n", - ) - .to_string(), - Fixture::ReasoningOnlyTurn => concat!( - "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Need to reason about the tool result.\",\"thought\":true,\"thoughtSignature\":\"sig_1\"}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":5,\"thoughtsTokenCount\":3,\"totalTokenCount\":18}}\n\n", - ) - .to_string(), - Fixture::StopReasonMapping => concat!( - "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Truncated response\"}],\"role\":\"model\"},\"finishReason\":\"MAX_TOKENS\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":2,\"totalTokenCount\":12}}\n\n", - ) - .to_string(), + Fixture::EmptyAssistantTurnAfterToolResult => "data: {\"candidates\":[{\"content\":{\"parts\":[],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":0,\"totalTokenCount\":10}}\n\n".to_string(), + Fixture::ToolOnlyTurn => "data: {\"candidates\":[{\"content\":{\"parts\":[{\"functionCall\":{\"name\":\"lookup_weather\",\"args\":{\"city\":\"Paris\"}}}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":5,\"totalTokenCount\":15}}\n\n".to_string(), + Fixture::TextAndToolCallTurn => "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Need weather data first.\"},{\"functionCall\":{\"name\":\"lookup_weather\",\"args\":{\"city\":\"Paris\"}}}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":7,\"totalTokenCount\":17}}\n\n".to_string(), + Fixture::EmptyTextBlocks => "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"\"}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":0,\"totalTokenCount\":10}}\n\n".to_string(), + Fixture::ReasoningOnlyTurn => "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Need to reason about the tool result.\",\"thought\":true,\"thoughtSignature\":\"sig_1\"}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":5,\"thoughtsTokenCount\":3,\"totalTokenCount\":18}}\n\n".to_string(), + Fixture::StopReasonMapping => "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Truncated response\"}],\"role\":\"model\"},\"finishReason\":\"MAX_TOKENS\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":2,\"totalTokenCount\":12}}\n\n".to_string(), Fixture::MessageIdPreservation => unreachable!(), } } diff --git a/rig/rig-core/src/providers/groq.rs b/rig/rig-core/src/providers/groq.rs index 88ba5cc4e..97bca110e 100644 --- a/rig/rig-core/src/providers/groq.rs +++ b/rig/rig-core/src/providers/groq.rs @@ -34,7 +34,6 @@ use futures::StreamExt; use crate::{ completion::{self, CompletionError, CompletionRequest}, json_utils, - message::{self}, providers::openai::ToolDefinition, transcription::{self, TranscriptionError}, }; @@ -184,42 +183,29 @@ pub(super) struct StreamOptions { impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest { type Error = CompletionError; - fn try_from((model, mut req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for Groq"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - // Build up the order of messages (context, chat_history, prompt) - let mut partial_history = vec![]; - if let Some(docs) = req.normalized_documents() { - partial_history.push(docs); - } - partial_history.extend(req.chat_history); - - // Add preamble to chat history (if available) - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![OpenAIMessage::system(preamble)], - None => vec![], - }; - - // Convert and extend the rest of the history - full_history.extend( - partial_history - .into_iter() - .map(message::Message::try_into) - .collect::>, _>>()? - .into_iter() - .flatten() - .collect::>(), - ); - - let tool_choice = req - .tool_choice - .clone() + fn try_from((model, req): (&str, CompletionRequest)) -> Result { + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools, + tool_choice, + additional_params, + output_schema: _, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("Groq"), + OpenAIMessage::system, + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; + + let tool_choice = tool_choice .map(crate::providers::openai::ToolChoice::try_from) .transpose()?; - let mut additional_params_payload = req.additional_params.take().unwrap_or(Value::Null); + let mut additional_params_payload = additional_params.unwrap_or(Value::Null); let native_tools = extract_native_tools_from_additional_params(&mut additional_params_payload)?; @@ -232,12 +218,10 @@ impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest { apply_native_tools_to_additional_params(&mut additional_params, native_tools); Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - tools: req - .tools - .clone() + model, + messages, + temperature, + tools: tools .into_iter() .map(ToolDefinition::from) .collect::>(), @@ -822,6 +806,7 @@ where mod tests { use crate::{ OneOrMany, + completion::CompletionRequest, providers::{ groq::{GroqAdditionalParameters, GroqCompletionRequest}, openai::{Message, UserContent}, @@ -870,6 +855,28 @@ mod tests { }) ) } + + #[test] + fn groq_request_uses_request_model_override() { + let request = CompletionRequest { + model: Some("groq-override".to_string()), + preamble: Some("system".to_string()), + chat_history: OneOrMany::one("hello".into()), + documents: vec![], + tools: vec![], + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + }; + + let converted = GroqCompletionRequest::try_from(("groq-default", request)) + .expect("request should convert"); + + assert_eq!(converted.model, "groq-override"); + } + #[test] fn test_client_initialization() { let _client = diff --git a/rig/rig-core/src/providers/huggingface/completion.rs b/rig/rig-core/src/providers/huggingface/completion.rs index 47182606f..9caa58a38 100644 --- a/rig/rig-core/src/providers/huggingface/completion.rs +++ b/rig/rig-core/src/providers/huggingface/completion.rs @@ -562,12 +562,11 @@ impl TryFrom for completion::CompletionResponse { let mut content = content .iter() - .map(|c| match c { + .filter_map(|c| match c { AssistantContent::Text { text } => { crate::providers::openai::completion::non_empty_text(text) } }) - .flatten() .collect::>(); content.extend( @@ -627,59 +626,38 @@ impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for Huggingface"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![Message::system(preamble)], - None => vec![], - }; - if let Some(docs) = req.normalized_documents() { - let docs: Vec = docs.try_into()?; - full_history.extend(docs); - } - - let chat_history: Vec = req - .chat_history - .clone() - .into_iter() - .map(|message| message.try_into()) - .collect::>, _>>()? - .into_iter() - .flatten() - .collect(); - - full_history.extend(chat_history); - - if full_history.is_empty() { - return Err(CompletionError::RequestError( - std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "HuggingFace request has no provider-compatible messages after conversion", - ) - .into(), - )); - } - - let tool_choice = req - .tool_choice - .clone() + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools, + tool_choice, + additional_params, + output_schema: _, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("HuggingFace") + .require_messages(), + Message::system, + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; + + let tool_choice = tool_choice .map(crate::providers::openai::completion::ToolChoice::try_from) .transpose()?; Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - tools: req - .tools - .clone() + model, + messages, + temperature, + tools: tools .into_iter() .map(ToolDefinition::from) .collect::>(), tool_choice, - additional_params: req.additional_params, + additional_params, }) } } diff --git a/rig/rig-core/src/providers/hyperbolic.rs b/rig/rig-core/src/providers/hyperbolic.rs index 7df38bb2e..4e0f2507d 100644 --- a/rig/rig-core/src/providers/hyperbolic.rs +++ b/rig/rig-core/src/providers/hyperbolic.rs @@ -191,13 +191,12 @@ impl TryFrom for completion::CompletionResponse { let mut content = content .iter() - .map(|c| match c { + .filter_map(|c| match c { AssistantContent::Text { text } => openai::completion::non_empty_text(text), AssistantContent::Refusal { refusal } => { openai::completion::non_empty_text(refusal) } }) - .flatten() .collect::>(); content.extend( @@ -262,46 +261,38 @@ impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for Hyperbolic"); - } - - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - if req.tool_choice.is_some() { + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools, + tool_choice, + additional_params, + output_schema: _, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("Hyperbolic") + .without_tools() + .without_tool_choice(), + Message::system, + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; + + if tool_choice.is_some() { tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic"); } - if !req.tools.is_empty() { + if !tools.is_empty() { tracing::warn!("WARNING: `tools` not supported on Hyperbolic"); } - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![Message::system(preamble)], - None => vec![], - }; - - if let Some(docs) = req.normalized_documents() { - let docs: Vec = docs.try_into()?; - full_history.extend(docs); - } - - let chat_history: Vec = req - .chat_history - .clone() - .into_iter() - .map(|message| message.try_into()) - .collect::>, _>>()? - .into_iter() - .flatten() - .collect(); - - full_history.extend(chat_history); - Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - additional_params: req.additional_params, + model, + messages, + temperature, + additional_params, }) } } diff --git a/rig/rig-core/src/providers/mira.rs b/rig/rig-core/src/providers/mira.rs index 6e390f764..c9988980e 100644 --- a/rig/rig-core/src/providers/mira.rs +++ b/rig/rig-core/src/providers/mira.rs @@ -224,10 +224,13 @@ pub(super) struct MiraCompletionRequest { impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest { type Error = CompletionError; - fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for Mira"); - } + fn try_from((model, mut req): (&str, CompletionRequest)) -> Result { + crate::providers::openai::completion::CompatibleFeaturePolicy::default() + .without_tools() + .without_tool_choice() + .without_additional_params() + .apply("Mira AI", &mut req); + let model = req.model.clone().unwrap_or_else(|| model.to_string()); let mut messages = Vec::new(); @@ -355,21 +358,6 @@ where span.record("gen_ai.system_instructions", &completion_request.preamble); - if !completion_request.tools.is_empty() { - tracing::warn!(target: "rig::completions", - "Tool calls are not supported by Mira AI. {len} tools will be ignored.", - len = completion_request.tools.len() - ); - } - - if completion_request.tool_choice.is_some() { - tracing::warn!("WARNING: `tool_choice` not supported on Mira AI"); - } - - if completion_request.additional_params.is_some() { - tracing::warn!("WARNING: Additional parameters not supported on Mira AI"); - } - let request = MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?; if tracing::enabled!(tracing::Level::TRACE) { @@ -460,20 +448,6 @@ where span.record("gen_ai.system_instructions", &completion_request.preamble); - if !completion_request.tools.is_empty() { - tracing::warn!(target: "rig::completions", - "Tool calls are not supported by Mira AI. {len} tools will be ignored.", - len = completion_request.tools.len() - ); - } - - if completion_request.tool_choice.is_some() { - tracing::warn!("WARNING: `tool_choice` not supported on Mira AI"); - } - - if completion_request.additional_params.is_some() { - tracing::warn!("WARNING: Additional parameters not supported on Mira AI"); - } let mut request = MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?; request.stream = true; diff --git a/rig/rig-core/src/providers/moonshot.rs b/rig/rig-core/src/providers/moonshot.rs index 13a034345..81c9817ef 100644 --- a/rig/rig-core/src/providers/moonshot.rs +++ b/rig/rig-core/src/providers/moonshot.rs @@ -319,27 +319,31 @@ impl TryFrom<(&str, CompletionRequest)> for MoonshotCompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for Moonshot"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - // Build up the order of messages (context, chat_history, prompt) - let mut partial_history = vec![]; - if let Some(docs) = req.normalized_documents() { - partial_history.push(docs); - } - partial_history.extend(req.chat_history); - - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![serde_json::to_value(openai::Message::system(preamble))?], - None => vec![], - }; - - full_history.extend(moonshot_history_values(partial_history)?); - - let mut tool_choice = None; + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages: mut full_history, + temperature, + max_tokens, + tools, + tool_choice: request_tool_choice, + additional_params, + output_schema: _, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("Moonshot"), + |preamble| { + serde_json::json!({ + "role": "system", + "content": preamble, + }) + }, + |message| moonshot_history_values(vec![message]), + )?; + + let mut tool_choice: Option = None; let mut tool_choice_required = false; - if let Some(choice) = req.tool_choice.clone() { + if let Some(choice) = request_tool_choice { match choice { message::ToolChoice::Required => { tool_choice_required = true; @@ -362,18 +366,16 @@ impl TryFrom<(&str, CompletionRequest)> for MoonshotCompletionRequest { } Ok(Self { - model: model.to_string(), + model, messages: full_history, - temperature: req.temperature, - max_tokens: req.max_tokens, - tools: req - .tools - .clone() + temperature, + max_tokens, + tools: tools .into_iter() .map(openai::ToolDefinition::from) .collect::>(), tool_choice, - additional_params: req.additional_params, + additional_params, }) } } diff --git a/rig/rig-core/src/providers/openai/completion/mod.rs b/rig/rig-core/src/providers/openai/completion/mod.rs index acf8926b6..e783464b3 100644 --- a/rig/rig-core/src/providers/openai/completion/mod.rs +++ b/rig/rig-core/src/providers/openai/completion/mod.rs @@ -23,12 +23,17 @@ use tracing::{Instrument, Level, enabled, info_span}; use std::str::FromStr; mod compat; +mod request_compat; mod stream_compat; pub mod streaming; pub(crate) use compat::{ build_completion_response, first_choice, map_finish_reason, non_empty_text, }; +pub(crate) use request_compat::{ + CompatibleChatProfile, CompatibleFeaturePolicy, CompatibleRequestCore, + build_compatible_request_core, +}; pub(crate) use stream_compat::{ CompatibleStreamingToolCall, ToolCallConflictPolicy, apply_compatible_tool_call_deltas, take_finalized_tool_calls, take_tool_calls, @@ -1060,48 +1065,24 @@ impl TryFrom for CompletionRequest { strict_tools, tool_result_array_content, } = params; - - let mut partial_history = vec![]; - if let Some(docs) = req.normalized_documents() { - partial_history.push(docs); - } - let CoreCompletionRequest { - model: request_model, - preamble, - chat_history, - tools, + let CompatibleRequestCore { + model, + messages: mut full_history, temperature, max_tokens, - additional_params, + tools, tool_choice, + additional_params, output_schema, - .. - } = req; - - partial_history.extend(chat_history); - - let mut full_history: Vec = - preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]); - - full_history.extend( - partial_history - .into_iter() - .map(message::Message::try_into) - .collect::>, _>>()? - .into_iter() - .flatten() - .collect::>(), - ); - - if full_history.is_empty() { - return Err(CompletionError::RequestError( - std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "OpenAI Chat Completions request has no provider-compatible messages after conversion", - ) - .into(), - )); - } + } = build_compatible_request_core( + &model, + req, + CompatibleChatProfile::new("OpenAI Chat Completions") + .require_messages() + .supports_output_schema(), + Message::system, + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; if tool_result_array_content { for msg in &mut full_history { @@ -1160,7 +1141,7 @@ impl TryFrom for CompletionRequest { }; let res = Self { - model: request_model.unwrap_or(model), + model, messages: full_history, tools, tool_choice, diff --git a/rig/rig-core/src/providers/openai/completion/request_compat.rs b/rig/rig-core/src/providers/openai/completion/request_compat.rs new file mode 100644 index 000000000..37155313e --- /dev/null +++ b/rig/rig-core/src/providers/openai/completion/request_compat.rs @@ -0,0 +1,317 @@ +use crate::completion::{CompletionError, CompletionRequest as CoreCompletionRequest}; +use crate::message; + +/// Shared unsupported-feature policy for request builders. +#[derive(Debug, Clone, Copy)] +pub(crate) struct CompatibleFeaturePolicy { + supports_output_schema: bool, + supports_tools: bool, + supports_tool_choice: bool, + supports_additional_params: bool, +} + +impl Default for CompatibleFeaturePolicy { + fn default() -> Self { + Self { + supports_output_schema: false, + supports_tools: true, + supports_tool_choice: true, + supports_additional_params: true, + } + } +} + +impl CompatibleFeaturePolicy { + pub(crate) const fn supports_output_schema(mut self) -> Self { + self.supports_output_schema = true; + self + } + + pub(crate) const fn without_tools(mut self) -> Self { + self.supports_tools = false; + self + } + + pub(crate) const fn without_tool_choice(mut self) -> Self { + self.supports_tool_choice = false; + self + } + + pub(crate) const fn without_additional_params(mut self) -> Self { + self.supports_additional_params = false; + self + } + + pub(crate) fn apply(self, provider_name: &'static str, req: &mut CoreCompletionRequest) { + if req.output_schema.is_some() && !self.supports_output_schema { + tracing::warn!( + "Structured outputs currently not supported for {}", + provider_name + ); + req.output_schema = None; + } + + if !req.tools.is_empty() && !self.supports_tools { + tracing::warn!("WARNING: `tools` not supported on {}", provider_name); + req.tools.clear(); + } + + if req.tool_choice.is_some() && !self.supports_tool_choice { + tracing::warn!("WARNING: `tool_choice` not supported on {}", provider_name); + req.tool_choice = None; + } + + if req.additional_params.is_some() && !self.supports_additional_params { + tracing::warn!( + "WARNING: `additional_params` not supported on {}", + provider_name + ); + req.additional_params = None; + } + } +} + +/// Shared request-shaping profile for OpenAI-compatible chat providers. +#[derive(Debug, Clone, Copy)] +pub(crate) struct CompatibleChatProfile { + provider_name: &'static str, + require_messages: bool, + feature_policy: CompatibleFeaturePolicy, +} + +impl CompatibleChatProfile { + pub(crate) const fn new(provider_name: &'static str) -> Self { + Self { + provider_name, + require_messages: false, + feature_policy: CompatibleFeaturePolicy { + supports_output_schema: false, + supports_tools: true, + supports_tool_choice: true, + supports_additional_params: true, + }, + } + } + + pub(crate) const fn require_messages(mut self) -> Self { + self.require_messages = true; + self + } + + pub(crate) const fn supports_output_schema(mut self) -> Self { + self.feature_policy = self.feature_policy.supports_output_schema(); + self + } + + pub(crate) const fn without_tools(mut self) -> Self { + self.feature_policy = self.feature_policy.without_tools(); + self + } + + pub(crate) const fn without_tool_choice(mut self) -> Self { + self.feature_policy = self.feature_policy.without_tool_choice(); + self + } +} + +/// Provider-agnostic core request fields shared by OpenAI-compatible chat families. +#[derive(Debug)] +pub(crate) struct CompatibleRequestCore { + pub model: String, + pub messages: Vec, + pub temperature: Option, + pub max_tokens: Option, + pub tools: Vec, + pub tool_choice: Option, + pub additional_params: Option, + pub output_schema: Option, +} + +pub(crate) fn build_compatible_request_core( + default_model: &str, + mut req: CoreCompletionRequest, + profile: CompatibleChatProfile, + system_message: impl Fn(&str) -> M, + mut convert_message: F, +) -> Result, CompletionError> +where + F: FnMut(message::Message) -> Result, CompletionError>, +{ + profile + .feature_policy + .apply(profile.provider_name, &mut req); + + let normalized_documents = req.normalized_documents(); + let CoreCompletionRequest { + model: request_model, + preamble, + chat_history, + tools, + temperature, + max_tokens, + tool_choice, + additional_params, + output_schema, + .. + } = req; + + let mut partial_history = Vec::new(); + if let Some(docs) = normalized_documents { + partial_history.push(docs); + } + partial_history.extend(chat_history); + + let mut messages = preamble + .as_deref() + .map_or_else(Vec::new, |preamble| vec![system_message(preamble)]); + + messages.extend( + partial_history + .into_iter() + .map(&mut convert_message) + .collect::>, _>>()? + .into_iter() + .flatten(), + ); + + if profile.require_messages && messages.is_empty() { + return Err(CompletionError::RequestError( + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!( + "{} request has no provider-compatible messages after conversion", + profile.provider_name + ), + ) + .into(), + )); + } + + Ok(CompatibleRequestCore { + model: request_model.unwrap_or_else(|| default_model.to_owned()), + messages, + temperature, + max_tokens, + tools, + tool_choice, + additional_params, + output_schema, + }) +} + +#[cfg(test)] +mod tests { + use super::{CompatibleChatProfile, CompatibleFeaturePolicy, build_compatible_request_core}; + use crate::OneOrMany; + use crate::completion::{CompletionRequest, ToolDefinition}; + use crate::message::{Message, ToolChoice, UserContent}; + + #[test] + fn requires_non_empty_messages_when_profile_demands_it() { + let req = CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: Vec::new(), + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + }; + + let err = build_compatible_request_core( + "test-model", + req, + CompatibleChatProfile::new("Example Provider").require_messages(), + |preamble| preamble.to_owned(), + |_message| Ok(Vec::::new()), + ) + .expect_err("empty converted messages should fail"); + + assert!(err.to_string().contains( + "Example Provider request has no provider-compatible messages after conversion" + )); + } + + #[test] + fn preserves_model_override_and_additional_params() { + let req = CompletionRequest { + model: Some("override-model".to_owned()), + preamble: Some("system".to_owned()), + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: Vec::new(), + temperature: Some(0.5), + max_tokens: Some(42), + tool_choice: None, + additional_params: Some(serde_json::json!({"foo":"bar"})), + output_schema: None, + }; + + let result = build_compatible_request_core( + "default-model", + req, + CompatibleChatProfile::new("Example Provider"), + |preamble| format!("system:{preamble}"), + |_message| Ok(vec!["history".to_owned()]), + ) + .expect("request conversion should succeed"); + + assert_eq!(result.model, "override-model"); + assert_eq!(result.temperature, Some(0.5)); + assert_eq!(result.max_tokens, Some(42)); + assert_eq!( + result.messages, + vec!["system:system".to_owned(), "history".to_owned()] + ); + assert_eq!( + result.additional_params, + Some(serde_json::json!({"foo":"bar"})) + ); + } + + #[test] + fn feature_policy_strips_unsupported_fields() { + let mut req = CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: vec![ToolDefinition { + name: "ping".to_owned(), + description: "Ping tool".to_owned(), + parameters: serde_json::json!({"type":"object"}), + }], + temperature: None, + max_tokens: None, + tool_choice: Some(ToolChoice::Required), + additional_params: Some(serde_json::json!({"foo":"bar"})), + output_schema: Some( + serde_json::from_value(serde_json::json!({ + "title": "Example", + "type": "object" + })) + .expect("schema should deserialize"), + ), + }; + + CompatibleFeaturePolicy::default() + .without_tools() + .without_tool_choice() + .without_additional_params() + .apply("Example Provider", &mut req); + + assert!(req.tools.is_empty()); + assert!(req.tool_choice.is_none()); + assert!(req.additional_params.is_none()); + assert!(req.output_schema.is_none()); + } +} diff --git a/rig/rig-core/src/providers/openrouter/completion.rs b/rig/rig-core/src/providers/openrouter/completion.rs index fcf658168..c58e46fbc 100644 --- a/rig/rig-core/src/providers/openrouter/completion.rs +++ b/rig/rig-core/src/providers/openrouter/completion.rs @@ -1581,42 +1581,28 @@ impl TryFrom> for OpenrouterCompletionRequest { request: req, strict_tools, } = params; - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for OpenRouter"); - } - - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![Message::system(preamble)], - None => vec![], - }; - if let Some(docs) = req.normalized_documents() { - let docs: Vec = docs.try_into()?; - full_history.extend(docs); - } - - let chat_history: Vec = req - .chat_history - .clone() - .into_iter() - .map(|message| message.try_into()) - .collect::>, _>>()? - .into_iter() - .flatten() - .collect(); - - full_history.extend(chat_history); + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools, + tool_choice, + additional_params, + output_schema: _, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("OpenRouter"), + Message::system, + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; - let tool_choice = req - .tool_choice - .clone() + let tool_choice = tool_choice .map(crate::providers::openai::completion::ToolChoice::try_from) .transpose()?; - let tools: Vec = req - .tools - .clone() + let tools: Vec = tools .into_iter() .map(|tool| { let def = crate::providers::openai::completion::ToolDefinition::from(tool); @@ -1626,11 +1612,11 @@ impl TryFrom> for OpenrouterCompletionRequest { Ok(Self { model, - messages: full_history, - temperature: req.temperature, + messages, + temperature, tools, tool_choice, - additional_params: req.additional_params, + additional_params, }) } } diff --git a/rig/rig-core/src/providers/perplexity.rs b/rig/rig-core/src/providers/perplexity.rs index 5991f97f8..a34d4dc23 100644 --- a/rig/rig-core/src/providers/perplexity.rs +++ b/rig/rig-core/src/providers/perplexity.rs @@ -231,38 +231,34 @@ impl TryFrom<(&str, CompletionRequest)> for PerplexityCompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for Perplexity"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - let mut partial_history = vec![]; - if let Some(docs) = req.normalized_documents() { - partial_history.push(docs); - } - partial_history.extend(req.chat_history); - - // Initialize full history with preamble (or empty if non-existent) - let mut full_history: Vec = req.preamble.map_or_else(Vec::new, |preamble| { - vec![Message { + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens, + tools: _, + tool_choice: _, + additional_params, + output_schema: _, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("Perplexity") + .without_tools() + .without_tool_choice(), + |preamble| Message { role: Role::System, - content: preamble, - }] - }); - - // Convert and extend the rest of the history - full_history.extend( - partial_history - .into_iter() - .map(message::Message::try_into) - .collect::, _>>()?, - ); + content: preamble.to_owned(), + }, + |message| Ok(vec![message.try_into()?]), + )?; Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - max_tokens: req.max_tokens, - additional_params: req.additional_params, + model, + messages, + temperature, + max_tokens, + additional_params, stream: false, }) } @@ -384,13 +380,6 @@ where span.record("gen_ai.system_instructions", &completion_request.preamble); - if completion_request.tool_choice.is_some() { - tracing::warn!("WARNING: `tool_choice` not supported on Perplexity"); - } - - if !completion_request.tools.is_empty() { - tracing::warn!("WARNING: `tools` not supported on Perplexity"); - } let request = PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?; @@ -470,14 +459,6 @@ where span.record("gen_ai.system_instructions", &completion_request.preamble); - if completion_request.tool_choice.is_some() { - tracing::warn!("WARNING: `tool_choice` not supported on Perplexity"); - } - - if !completion_request.tools.is_empty() { - tracing::warn!("WARNING: `tools` not supported on Perplexity"); - } - let mut request = PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?; request.stream = true; diff --git a/rig/rig-core/src/providers/together/completion.rs b/rig/rig-core/src/providers/together/completion.rs index e5b24ee05..bbe30363c 100644 --- a/rig/rig-core/src/providers/together/completion.rs +++ b/rig/rig-core/src/providers/together/completion.rs @@ -147,58 +147,36 @@ impl TryFrom<(&str, CompletionRequest)> for TogetherAICompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for TogetherAI"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![openai::Message::system(preamble)], - None => vec![], - }; - if let Some(docs) = req.normalized_documents() { - let docs: Vec = docs.try_into()?; - full_history.extend(docs); - } - - let chat_history: Vec = req - .chat_history - .into_iter() - .map(|message| message.try_into()) - .collect::>, _>>()? - .into_iter() - .flatten() - .collect(); - - full_history.extend(chat_history); - - if full_history.is_empty() { - return Err(CompletionError::RequestError( - std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "Together request has no provider-compatible messages after conversion", - ) - .into(), - )); - } - - let tool_choice = req - .tool_choice - .clone() - .map(ToolChoice::try_from) - .transpose()?; + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools, + tool_choice, + additional_params, + output_schema: _, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("TogetherAI") + .require_messages(), + openai::Message::system, + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; + + let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?; Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - tools: req - .tools - .clone() + model, + messages, + temperature, + tools: tools .into_iter() .map(crate::providers::openai::completion::ToolDefinition::from) .collect::>(), tool_choice, - additional_params: req.additional_params, + additional_params, }) } } @@ -390,4 +368,25 @@ mod tests { let result = TogetherAICompletionRequest::try_from(("meta-llama/test-model", request)); assert!(matches!(result, Err(CompletionError::RequestError(_)))); } + + #[test] + fn together_request_uses_request_model_override() { + let request = CompletionRequest { + model: Some("together-override".to_string()), + preamble: Some("system".to_string()), + chat_history: OneOrMany::one(message::Message::user("hello")), + documents: vec![], + tools: vec![], + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + }; + + let converted = TogetherAICompletionRequest::try_from(("together-default", request)) + .expect("request should convert"); + + assert_eq!(converted.model, "together-override"); + } } From 3c49a3afd3bc777e944c46c5b75006e43c36b215 Mon Sep 17 00:00:00 2001 From: stephen Date: Fri, 17 Apr 2026 19:30:28 -0700 Subject: [PATCH 7/9] further migration --- rig/rig-core/src/providers/hyperbolic.rs | 12 +-- rig/rig-core/src/providers/llamafile.rs | 76 +++++++++------- .../src/providers/mistral/completion.rs | 89 ++++++++++--------- 3 files changed, 90 insertions(+), 87 deletions(-) diff --git a/rig/rig-core/src/providers/hyperbolic.rs b/rig/rig-core/src/providers/hyperbolic.rs index 4e0f2507d..0ffeff68a 100644 --- a/rig/rig-core/src/providers/hyperbolic.rs +++ b/rig/rig-core/src/providers/hyperbolic.rs @@ -266,8 +266,8 @@ impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest { messages, temperature, max_tokens: _, - tools, - tool_choice, + tools: _, + tool_choice: _, additional_params, output_schema: _, } = crate::providers::openai::completion::build_compatible_request_core( @@ -280,14 +280,6 @@ impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest { |message| Vec::::try_from(message).map_err(CompletionError::from), )?; - if tool_choice.is_some() { - tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic"); - } - - if !tools.is_empty() { - tracing::warn!("WARNING: `tools` not supported on Hyperbolic"); - } - Ok(Self { model, messages, diff --git a/rig/rig-core/src/providers/llamafile.rs b/rig/rig-core/src/providers/llamafile.rs index 8fabb721f..d032329f6 100644 --- a/rig/rig-core/src/providers/llamafile.rs +++ b/rig/rig-core/src/providers/llamafile.rs @@ -164,45 +164,34 @@ impl TryFrom<(&str, CompletionRequest)> for LlamafileCompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs may not be supported by llamafile"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - - // Build message history: preamble -> documents -> chat history - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![openai::Message::system(preamble)], - None => vec![], - }; - - if let Some(docs) = req.normalized_documents() { - let docs: Vec = docs.try_into()?; - full_history.extend(docs); - } - - let chat_history: Vec = req - .chat_history - .clone() - .into_iter() - .map(|msg| msg.try_into()) - .collect::>, _>>()? - .into_iter() - .flatten() - .collect(); - - full_history.extend(chat_history); + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens, + tools, + tool_choice: _, + additional_params, + output_schema: _, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("llamafile") + .without_tool_choice(), + openai::Message::system, + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; Ok(Self { model, - messages: full_history, - temperature: req.temperature, - max_tokens: req.max_tokens, - tools: req - .tools + messages, + temperature, + max_tokens, + tools: tools .into_iter() .map(openai::ToolDefinition::from) .collect(), - additional_params: req.additional_params, + additional_params, }) } } @@ -698,4 +687,25 @@ mod tests { assert_eq!(request.temperature, Some(0.7)); assert_eq!(request.max_tokens, Some(256)); } + + #[test] + fn test_llamafile_request_uses_request_model_override() { + let completion_request = CompletionRequest { + model: Some("custom-llamafile".to_string()), + preamble: None, + chat_history: crate::OneOrMany::one(crate::completion::Message::user("Hello")), + documents: vec![], + tools: vec![], + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + }; + + let request = LlamafileCompletionRequest::try_from((LLAMA_CPP, completion_request)) + .expect("Failed to create request"); + + assert_eq!(request.model, "custom-llamafile"); + } } diff --git a/rig/rig-core/src/providers/mistral/completion.rs b/rig/rig-core/src/providers/mistral/completion.rs index eae8b580a..7e5fd77fd 100644 --- a/rig/rig-core/src/providers/mistral/completion.rs +++ b/rig/rig-core/src/providers/mistral/completion.rs @@ -346,59 +346,39 @@ impl TryFrom<(&str, CompletionRequest)> for MistralCompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for Mistral"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![Message::system(preamble.clone())], - None => vec![], - }; - if let Some(docs) = req.normalized_documents() { - let docs: Vec = docs.try_into()?; - full_history.extend(docs); - } - - let chat_history: Vec = req - .chat_history - .clone() - .into_iter() - .map(|message| message.try_into()) - .collect::>, _>>()? - .into_iter() - .flatten() - .collect(); - - full_history.extend(chat_history); - - if full_history.is_empty() { - return Err(CompletionError::RequestError( - std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "Mistral request has no provider-compatible messages after conversion", - ) - .into(), - )); - } - - let tool_choice = req - .tool_choice + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools, + tool_choice, + additional_params, + output_schema: _, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("Mistral") + .require_messages(), + |preamble| Message::system(preamble.to_owned()), + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; + + let tool_choice = tool_choice .clone() .map(crate::providers::openai::completion::ToolChoice::try_from) .transpose()?; Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - tools: req - .tools - .clone() + model, + messages, + temperature, + tools: tools .into_iter() .map(ToolDefinition::from) .collect::>(), tool_choice, - additional_params: req.additional_params, + additional_params, }) } } @@ -830,4 +810,25 @@ mod tests { let result = MistralCompletionRequest::try_from((MISTRAL_SMALL, request)); assert!(matches!(result, Err(CompletionError::RequestError(_)))); } + + #[test] + fn test_mistral_request_uses_request_model_override() { + let request = CompletionRequest { + model: Some("mistral-custom".to_string()), + preamble: Some("System".to_string()), + chat_history: crate::OneOrMany::one(message::Message::user("Hello")), + documents: vec![], + tools: vec![], + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + }; + + let request = + MistralCompletionRequest::try_from((MISTRAL_SMALL, request)).expect("request"); + + assert_eq!(request.model, "mistral-custom"); + } } From 08b0ca363cf0de1eb56f8fd131cf99e21ed7d117 Mon Sep 17 00:00:00 2001 From: stephen Date: Fri, 17 Apr 2026 20:11:45 -0700 Subject: [PATCH 8/9] family --- rig/rig-core/src/providers/azure.rs | 65 +- rig/rig-core/src/providers/deepseek.rs | 37 +- rig/rig-core/src/providers/galadriel.rs | 3 +- rig/rig-core/src/providers/groq.rs | 36 +- .../src/providers/huggingface/completion.rs | 37 +- rig/rig-core/src/providers/hyperbolic.rs | 44 +- rig/rig-core/src/providers/llamafile.rs | 39 +- rig/rig-core/src/providers/mira.rs | 12 +- .../src/providers/mistral/completion.rs | 38 +- rig/rig-core/src/providers/moonshot.rs | 79 +- .../src/providers/openai/completion/compat.rs | 45 - .../src/providers/openai/completion/family.rs | 1447 +++++++++++++++++ .../src/providers/openai/completion/mod.rs | 241 +-- .../openai/completion/request_compat.rs | 317 ---- .../openai/completion/stream_compat.rs | 218 --- .../providers/openai/completion/streaming.rs | 56 +- .../src/providers/openrouter/completion.rs | 60 +- rig/rig-core/src/providers/perplexity.rs | 7 +- .../src/providers/together/completion.rs | 55 +- 19 files changed, 2041 insertions(+), 795 deletions(-) delete mode 100644 rig/rig-core/src/providers/openai/completion/compat.rs create mode 100644 rig/rig-core/src/providers/openai/completion/family.rs delete mode 100644 rig/rig-core/src/providers/openai/completion/request_compat.rs delete mode 100644 rig/rig-core/src/providers/openai/completion/stream_compat.rs diff --git a/rig/rig-core/src/providers/azure.rs b/rig/rig-core/src/providers/azure.rs index 829db7be5..ce7e64410 100644 --- a/rig/rig-core/src/providers/azure.rs +++ b/rig/rig-core/src/providers/azure.rs @@ -587,13 +587,14 @@ impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest { tools, tool_choice, additional_params, - output_schema, } = crate::providers::openai::completion::build_compatible_request_core( model, req, crate::providers::openai::completion::CompatibleChatProfile::new("Azure OpenAI") - .supports_output_schema(), + .native_response_format(), openai::Message::system, + None, + |message| matches!(message, openai::Message::ToolResult { .. }), |message| Vec::::try_from(message).map_err(CompletionError::from), )?; @@ -601,33 +602,6 @@ impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest { .map(crate::providers::openai::ToolChoice::try_from) .transpose()?; - let additional_params = if let Some(schema) = output_schema { - let name = schema - .as_object() - .and_then(|o| o.get("title")) - .and_then(|v| v.as_str()) - .unwrap_or("response_schema") - .to_string(); - let mut schema_value = schema.to_value(); - openai::sanitize_schema(&mut schema_value); - let response_format = serde_json::json!({ - "response_format": { - "type": "json_schema", - "json_schema": { - "name": name, - "strict": true, - "schema": schema_value - } - } - }); - Some(match additional_params { - Some(existing) => json_utils::merge(existing, response_format), - None => response_format, - }) - } else { - additional_params - }; - Ok(Self { model, messages, @@ -1056,6 +1030,39 @@ mod azure_tests { use crate::embeddings::EmbeddingModel; use crate::prelude::TypedPrompt; use crate::providers::openai::GPT_5_MINI; + use crate::providers::openai::completion::{CompatibleChatProfile, request_conformance}; + + struct AzureRequestHarness; + + impl request_conformance::Harness for AzureRequestHarness { + fn family_name() -> &'static str { + "azure-openai" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + AzureOpenAICompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + CompatibleChatProfile::new("Azure OpenAI").native_response_format(), + ), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(AzureRequestHarness); #[test] fn azure_request_uses_request_model_override() { diff --git a/rig/rig-core/src/providers/deepseek.rs b/rig/rig-core/src/providers/deepseek.rs index 3daefaf50..04ed03621 100644 --- a/rig/rig-core/src/providers/deepseek.rs +++ b/rig/rig-core/src/providers/deepseek.rs @@ -477,12 +477,13 @@ impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest { tools, tool_choice, additional_params, - output_schema: _, } = build_compatible_request_core( model, req, CompatibleChatProfile::new("DeepSeek"), Message::system, + None, + |_| false, |message| Vec::::try_from(message).map_err(CompletionError::from), )?; @@ -840,10 +841,44 @@ pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner"; mod tests { use super::*; use crate::http_client::mock::MockStreamingClient; + use crate::providers::openai::completion::request_conformance; use crate::streaming::StreamedAssistantContent; use bytes::Bytes; use futures::StreamExt; + struct DeepSeekRequestHarness; + + impl request_conformance::Harness for DeepSeekRequestHarness { + fn family_name() -> &'static str { + "deepseek" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + DeepseekCompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new(CompatibleChatProfile::new( + "DeepSeek", + )) + .preserves_reasoning_assistant_history(), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(DeepSeekRequestHarness); + #[test] fn test_deserialize_vec_choice() { let data = r#"[{ diff --git a/rig/rig-core/src/providers/galadriel.rs b/rig/rig-core/src/providers/galadriel.rs index 680ceb1b5..25f594c45 100644 --- a/rig/rig-core/src/providers/galadriel.rs +++ b/rig/rig-core/src/providers/galadriel.rs @@ -461,12 +461,13 @@ impl TryFrom<(&str, CompletionRequest)> for GaladrielCompletionRequest { tools, tool_choice, additional_params, - output_schema: _, } = crate::providers::openai::completion::build_compatible_request_core( model, req, crate::providers::openai::completion::CompatibleChatProfile::new("Galadriel"), Message::system, + None, + |_| false, |message| Ok(vec![message.try_into()?]), )?; diff --git a/rig/rig-core/src/providers/groq.rs b/rig/rig-core/src/providers/groq.rs index 97bca110e..80e0006d9 100644 --- a/rig/rig-core/src/providers/groq.rs +++ b/rig/rig-core/src/providers/groq.rs @@ -192,12 +192,13 @@ impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest { tools, tool_choice, additional_params, - output_schema: _, } = crate::providers::openai::completion::build_compatible_request_core( model, req, crate::providers::openai::completion::CompatibleChatProfile::new("Groq"), OpenAIMessage::system, + None, + |_| false, |message| Vec::::try_from(message).map_err(CompletionError::from), )?; @@ -809,10 +810,43 @@ mod tests { completion::CompletionRequest, providers::{ groq::{GroqAdditionalParameters, GroqCompletionRequest}, + openai::completion::{CompatibleChatProfile, request_conformance}, openai::{Message, UserContent}, }, }; + struct GroqRequestHarness; + + impl request_conformance::Harness for GroqRequestHarness { + fn family_name() -> &'static str { + "groq" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + GroqCompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new(CompatibleChatProfile::new( + "Groq", + )), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(GroqRequestHarness); + #[test] fn serialize_groq_request() { let additional_params = GroqAdditionalParameters { diff --git a/rig/rig-core/src/providers/huggingface/completion.rs b/rig/rig-core/src/providers/huggingface/completion.rs index 9caa58a38..0d9b7eaef 100644 --- a/rig/rig-core/src/providers/huggingface/completion.rs +++ b/rig/rig-core/src/providers/huggingface/completion.rs @@ -634,13 +634,14 @@ impl TryFrom<(&str, CompletionRequest)> for HuggingfaceCompletionRequest { tools, tool_choice, additional_params, - output_schema: _, } = crate::providers::openai::completion::build_compatible_request_core( model, req, crate::providers::openai::completion::CompatibleChatProfile::new("HuggingFace") .require_messages(), Message::system, + None, + |_| false, |message| Vec::::try_from(message).map_err(CompletionError::from), )?; @@ -797,8 +798,42 @@ where #[cfg(test)] mod tests { use super::*; + use crate::providers::openai::completion::request_conformance; use serde_path_to_error::deserialize; + struct HuggingFaceRequestHarness; + + impl request_conformance::Harness for HuggingFaceRequestHarness { + fn family_name() -> &'static str { + "huggingface" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + HuggingfaceCompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + crate::providers::openai::completion::CompatibleChatProfile::new("HuggingFace") + .require_messages(), + ), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(HuggingFaceRequestHarness); + #[test] fn test_huggingface_request_uses_request_model_override() { let request = CompletionRequest { diff --git a/rig/rig-core/src/providers/hyperbolic.rs b/rig/rig-core/src/providers/hyperbolic.rs index 0ffeff68a..cf95814cc 100644 --- a/rig/rig-core/src/providers/hyperbolic.rs +++ b/rig/rig-core/src/providers/hyperbolic.rs @@ -269,14 +269,15 @@ impl TryFrom<(&str, CompletionRequest)> for HyperbolicCompletionRequest { tools: _, tool_choice: _, additional_params, - output_schema: _, } = crate::providers::openai::completion::build_compatible_request_core( model, req, crate::providers::openai::completion::CompatibleChatProfile::new("Hyperbolic") - .without_tools() - .without_tool_choice(), + .unsupported_tools() + .unsupported_tool_choice(), Message::system, + None, + |_| false, |message| Vec::::try_from(message).map_err(CompletionError::from), )?; @@ -687,6 +688,43 @@ mod audio_generation { #[cfg(test)] mod tests { + use super::HyperbolicCompletionRequest; + use crate::providers::openai::completion::{CompatibleChatProfile, request_conformance}; + + struct HyperbolicRequestHarness; + + impl request_conformance::Harness for HyperbolicRequestHarness { + fn family_name() -> &'static str { + "hyperbolic" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + HyperbolicCompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + CompatibleChatProfile::new("Hyperbolic") + .unsupported_tools() + .unsupported_tool_choice(), + ), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(HyperbolicRequestHarness); + #[test] fn test_client_initialization() { let _client = diff --git a/rig/rig-core/src/providers/llamafile.rs b/rig/rig-core/src/providers/llamafile.rs index d032329f6..db97497a5 100644 --- a/rig/rig-core/src/providers/llamafile.rs +++ b/rig/rig-core/src/providers/llamafile.rs @@ -172,13 +172,14 @@ impl TryFrom<(&str, CompletionRequest)> for LlamafileCompletionRequest { tools, tool_choice: _, additional_params, - output_schema: _, } = crate::providers::openai::completion::build_compatible_request_core( model, req, crate::providers::openai::completion::CompatibleChatProfile::new("llamafile") - .without_tool_choice(), + .unsupported_tool_choice(), openai::Message::system, + None, + |_| false, |message| Vec::::try_from(message).map_err(CompletionError::from), )?; @@ -640,6 +641,40 @@ where mod tests { use super::*; use crate::client::Nothing; + use crate::providers::openai::completion::request_conformance; + + struct LlamafileRequestHarness; + + impl request_conformance::Harness for LlamafileRequestHarness { + fn family_name() -> &'static str { + "llamafile" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + LlamafileCompletionRequest::try_from((LLAMA_CPP, request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + crate::providers::openai::completion::CompatibleChatProfile::new("llamafile") + .unsupported_tool_choice(), + ), + LLAMA_CPP, + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(LlamafileRequestHarness); #[test] fn test_client_initialization() { diff --git a/rig/rig-core/src/providers/mira.rs b/rig/rig-core/src/providers/mira.rs index c9988980e..a82c3c054 100644 --- a/rig/rig-core/src/providers/mira.rs +++ b/rig/rig-core/src/providers/mira.rs @@ -226,10 +226,14 @@ impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest { fn try_from((model, mut req): (&str, CompletionRequest)) -> Result { crate::providers::openai::completion::CompatibleFeaturePolicy::default() - .without_tools() - .without_tool_choice() - .without_additional_params() - .apply("Mira AI", &mut req); + .with_tools_policy(crate::providers::openai::completion::ToolsPolicy::Unsupported) + .with_tool_choice_policy( + crate::providers::openai::completion::ToolChoicePolicy::Unsupported, + ) + .with_additional_params_policy( + crate::providers::openai::completion::AdditionalParamsPolicy::Unsupported, + ) + .apply("Mira AI", &mut req)?; let model = req.model.clone().unwrap_or_else(|| model.to_string()); let mut messages = Vec::new(); diff --git a/rig/rig-core/src/providers/mistral/completion.rs b/rig/rig-core/src/providers/mistral/completion.rs index 7e5fd77fd..9b06753d5 100644 --- a/rig/rig-core/src/providers/mistral/completion.rs +++ b/rig/rig-core/src/providers/mistral/completion.rs @@ -354,13 +354,14 @@ impl TryFrom<(&str, CompletionRequest)> for MistralCompletionRequest { tools, tool_choice, additional_params, - output_schema: _, } = crate::providers::openai::completion::build_compatible_request_core( model, req, crate::providers::openai::completion::CompatibleChatProfile::new("Mistral") .require_messages(), |preamble| Message::system(preamble.to_owned()), + None, + |_| false, |message| Vec::::try_from(message).map_err(CompletionError::from), )?; @@ -641,6 +642,41 @@ where #[cfg(test)] mod tests { use super::*; + use crate::providers::openai::completion::request_conformance; + + struct MistralRequestHarness; + + impl request_conformance::Harness for MistralRequestHarness { + fn family_name() -> &'static str { + "mistral" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + MistralCompletionRequest::try_from((MISTRAL_SMALL, request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + crate::providers::openai::completion::CompatibleChatProfile::new("Mistral") + .require_messages(), + ) + .omits_document_messages(), + MISTRAL_SMALL, + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(MistralRequestHarness); #[test] fn test_response_deserialization() { diff --git a/rig/rig-core/src/providers/moonshot.rs b/rig/rig-core/src/providers/moonshot.rs index 81c9817ef..ca54224e7 100644 --- a/rig/rig-core/src/providers/moonshot.rs +++ b/rig/rig-core/src/providers/moonshot.rs @@ -39,7 +39,7 @@ use crate::{ }; use crate::{http_client, message}; use serde::{Deserialize, Serialize}; -use serde_json::{Value, json}; +use serde_json::Value; use tracing::{Instrument, info_span}; // ================================================================ @@ -321,49 +321,38 @@ impl TryFrom<(&str, CompletionRequest)> for MoonshotCompletionRequest { fn try_from((model, req): (&str, CompletionRequest)) -> Result { let crate::providers::openai::completion::CompatibleRequestCore { model, - messages: mut full_history, + messages: full_history, temperature, max_tokens, tools, tool_choice: request_tool_choice, additional_params, - output_schema: _, } = crate::providers::openai::completion::build_compatible_request_core( model, req, - crate::providers::openai::completion::CompatibleChatProfile::new("Moonshot"), + crate::providers::openai::completion::CompatibleChatProfile::new("Moonshot") + .coerce_required_tool_choice_to_auto( + "Please select a tool to handle the current issue.", + ), |preamble| { serde_json::json!({ "role": "system", "content": preamble, }) }, + Some(|content| { + serde_json::json!({ + "role": "user", + "content": content, + }) + }), + |_| false, |message| moonshot_history_values(vec![message]), )?; - let mut tool_choice: Option = None; - let mut tool_choice_required = false; - if let Some(choice) = request_tool_choice { - match choice { - message::ToolChoice::Required => { - tool_choice_required = true; - tool_choice = Some(crate::providers::openai::completion::ToolChoice::Auto); - } - other => { - tool_choice = Some(crate::providers::openai::ToolChoice::try_from(other)?); - } - } - } - - if tool_choice_required { - tracing::warn!( - "Moonshot does not support tool_choice=required; coercing to auto with an additional steering message" - ); - full_history.push(json!({ - "role": "user", - "content": "Please select a tool to handle the current issue." - })); - } + let tool_choice = request_tool_choice + .map(crate::providers::openai::ToolChoice::try_from) + .transpose()?; Ok(Self { model, @@ -657,6 +646,42 @@ mod tests { use crate::message::{ AssistantContent, Message, Reasoning, ToolCall, ToolChoice, ToolFunction, }; + use crate::providers::openai::completion::{CompatibleChatProfile, request_conformance}; + + struct MoonshotRequestHarness; + + impl request_conformance::Harness for MoonshotRequestHarness { + fn family_name() -> &'static str { + "moonshot" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + MoonshotCompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + CompatibleChatProfile::new("Moonshot").coerce_required_tool_choice_to_auto( + "Please select a tool to handle the current issue.", + ), + ) + .preserves_reasoning_assistant_history(), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(MoonshotRequestHarness); #[test] fn test_client_initialization() { diff --git a/rig/rig-core/src/providers/openai/completion/compat.rs b/rig/rig-core/src/providers/openai/completion/compat.rs deleted file mode 100644 index 12a19fe53..000000000 --- a/rig/rig-core/src/providers/openai/completion/compat.rs +++ /dev/null @@ -1,45 +0,0 @@ -use crate::completion::{self, CompletionError}; - -pub(crate) fn first_choice(choices: &[T]) -> Result<&T, CompletionError> { - choices - .first() - .ok_or_else(|| CompletionError::ResponseError("Response contained no choices".to_owned())) -} - -pub(crate) fn map_finish_reason(reason: &str) -> completion::StopReason { - match reason { - "stop" => completion::StopReason::Stop, - "tool_calls" => completion::StopReason::ToolCalls, - "content_filter" => completion::StopReason::ContentFilter, - "length" => completion::StopReason::MaxTokens, - other => completion::StopReason::Other(other.to_string()), - } -} - -pub(crate) fn non_empty_text(text: impl AsRef) -> Option { - let text = text.as_ref(); - if text.is_empty() { - None - } else { - Some(completion::AssistantContent::text(text)) - } -} - -pub(crate) fn build_completion_response( - raw_response: R, - usage: completion::Usage, - message_id: Option, - stop_reason: Option, - choice: C, -) -> completion::CompletionResponse -where - C: Into, -{ - completion::CompletionResponse { - choice: choice.into(), - usage, - raw_response, - message_id, - stop_reason, - } -} diff --git a/rig/rig-core/src/providers/openai/completion/family.rs b/rig/rig-core/src/providers/openai/completion/family.rs new file mode 100644 index 000000000..f76d0488c --- /dev/null +++ b/rig/rig-core/src/providers/openai/completion/family.rs @@ -0,0 +1,1447 @@ +use std::collections::HashMap; + +use crate::completion::{self, CompletionError, CompletionRequest as CoreCompletionRequest}; +use crate::json_utils; +use crate::message; +use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent}; + +pub(crate) fn first_choice(choices: &[T]) -> Result<&T, CompletionError> { + choices + .first() + .ok_or_else(|| CompletionError::ResponseError("Response contained no choices".to_owned())) +} + +pub(crate) fn map_finish_reason(reason: &str) -> completion::StopReason { + match reason { + "stop" => completion::StopReason::Stop, + "tool_calls" => completion::StopReason::ToolCalls, + "content_filter" => completion::StopReason::ContentFilter, + "length" => completion::StopReason::MaxTokens, + other => completion::StopReason::Other(other.to_string()), + } +} + +pub(crate) fn non_empty_text(text: impl AsRef) -> Option { + let text = text.as_ref(); + if text.is_empty() { + None + } else { + Some(completion::AssistantContent::text(text)) + } +} + +pub(crate) fn build_completion_response( + raw_response: R, + usage: completion::Usage, + message_id: Option, + stop_reason: Option, + choice: C, +) -> completion::CompletionResponse +where + C: Into, +{ + completion::CompletionResponse { + choice: choice.into(), + usage, + raw_response, + message_id, + stop_reason, + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum ToolsPolicy { + Supported, + Unsupported, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum ToolChoicePolicy { + PassThrough, + Unsupported, + RejectRequired, + CoerceRequiredToAuto { steering_message: &'static str }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum OutputSchemaPolicy { + Unsupported, + NativeResponseFormat, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum AdditionalParamsPolicy { + PassThrough, + Unsupported, +} + +/// Shared unsupported-feature policy for request builders. +#[derive(Debug, Clone, Copy)] +pub(crate) struct CompatibleFeaturePolicy { + tools_policy: ToolsPolicy, + tool_choice_policy: ToolChoicePolicy, + output_schema_policy: OutputSchemaPolicy, + additional_params_policy: AdditionalParamsPolicy, +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub(crate) struct CompatibleRequestAdjustments { + steering_user_message: Option<&'static str>, +} + +impl Default for CompatibleFeaturePolicy { + fn default() -> Self { + Self { + tools_policy: ToolsPolicy::Supported, + tool_choice_policy: ToolChoicePolicy::PassThrough, + output_schema_policy: OutputSchemaPolicy::Unsupported, + additional_params_policy: AdditionalParamsPolicy::PassThrough, + } + } +} + +impl CompatibleFeaturePolicy { + pub(crate) const fn with_tools_policy(mut self, policy: ToolsPolicy) -> Self { + self.tools_policy = policy; + self + } + + pub(crate) const fn with_tool_choice_policy(mut self, policy: ToolChoicePolicy) -> Self { + self.tool_choice_policy = policy; + self + } + + pub(crate) const fn with_output_schema_policy(mut self, policy: OutputSchemaPolicy) -> Self { + self.output_schema_policy = policy; + self + } + + pub(crate) const fn with_additional_params_policy( + mut self, + policy: AdditionalParamsPolicy, + ) -> Self { + self.additional_params_policy = policy; + self + } + + pub(crate) fn apply( + self, + provider_name: &'static str, + req: &mut CoreCompletionRequest, + ) -> Result { + if req.output_schema.is_some() + && matches!(self.output_schema_policy, OutputSchemaPolicy::Unsupported) + { + tracing::warn!( + "Structured outputs currently not supported for {}", + provider_name + ); + req.output_schema = None; + } + + if !req.tools.is_empty() && matches!(self.tools_policy, ToolsPolicy::Unsupported) { + tracing::warn!("WARNING: `tools` not supported on {}", provider_name); + req.tools.clear(); + } + + let mut adjustments = CompatibleRequestAdjustments::default(); + if let Some(choice) = req.tool_choice.clone() { + match self.tool_choice_policy { + ToolChoicePolicy::PassThrough => {} + ToolChoicePolicy::Unsupported => { + tracing::warn!("WARNING: `tool_choice` not supported on {}", provider_name); + req.tool_choice = None; + } + ToolChoicePolicy::RejectRequired => { + if matches!(choice, crate::message::ToolChoice::Required) { + return Err(CompletionError::RequestError( + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("{provider_name} does not support tool_choice=required"), + ) + .into(), + )); + } + } + ToolChoicePolicy::CoerceRequiredToAuto { steering_message } => { + if matches!(choice, crate::message::ToolChoice::Required) { + tracing::warn!( + "{} does not support tool_choice=required; coercing to auto with an additional steering message", + provider_name + ); + req.tool_choice = Some(crate::message::ToolChoice::Auto); + adjustments.steering_user_message = Some(steering_message); + } + } + } + } + + if req.additional_params.is_some() + && matches!( + self.additional_params_policy, + AdditionalParamsPolicy::Unsupported + ) + { + tracing::warn!( + "WARNING: `additional_params` not supported on {}", + provider_name + ); + req.additional_params = None; + } + + Ok(adjustments) + } +} + +/// Shared request-shaping profile for OpenAI-compatible chat providers. +#[derive(Debug, Clone, Copy)] +pub struct CompatibleChatProfile { + provider_name: &'static str, + require_messages: bool, + feature_policy: CompatibleFeaturePolicy, +} + +impl CompatibleChatProfile { + pub(crate) const fn new(provider_name: &'static str) -> Self { + Self { + provider_name, + require_messages: false, + feature_policy: CompatibleFeaturePolicy { + tools_policy: ToolsPolicy::Supported, + tool_choice_policy: ToolChoicePolicy::PassThrough, + output_schema_policy: OutputSchemaPolicy::Unsupported, + additional_params_policy: AdditionalParamsPolicy::PassThrough, + }, + } + } + + pub(crate) const fn openai_chat_completions(provider_name: &'static str) -> Self { + Self::new(provider_name) + .require_messages() + .native_response_format() + } + + pub(crate) const fn require_messages(mut self) -> Self { + self.require_messages = true; + self + } + + pub(crate) const fn native_response_format(mut self) -> Self { + self.feature_policy = self + .feature_policy + .with_output_schema_policy(OutputSchemaPolicy::NativeResponseFormat); + self + } + + pub(crate) const fn unsupported_tools(mut self) -> Self { + self.feature_policy = self + .feature_policy + .with_tools_policy(ToolsPolicy::Unsupported); + self + } + + pub(crate) const fn unsupported_tool_choice(mut self) -> Self { + self.feature_policy = self + .feature_policy + .with_tool_choice_policy(ToolChoicePolicy::Unsupported); + self + } + + pub(crate) const fn reject_required_tool_choice(mut self) -> Self { + self.feature_policy = self + .feature_policy + .with_tool_choice_policy(ToolChoicePolicy::RejectRequired); + self + } + + pub(crate) const fn coerce_required_tool_choice_to_auto( + mut self, + steering_message: &'static str, + ) -> Self { + self.feature_policy = self + .feature_policy + .with_tool_choice_policy(ToolChoicePolicy::CoerceRequiredToAuto { steering_message }); + self + } +} + +/// Provider-agnostic core request fields shared by OpenAI-compatible chat families. +#[derive(Debug)] +pub(crate) struct CompatibleRequestCore { + pub model: String, + pub messages: Vec, + pub temperature: Option, + pub max_tokens: Option, + pub tools: Vec, + pub tool_choice: Option, + pub additional_params: Option, +} + +fn merge_native_response_format( + additional_params: Option, + schema: schemars::Schema, +) -> Option { + let name = schema + .as_object() + .and_then(|o| o.get("title")) + .and_then(|v| v.as_str()) + .unwrap_or("response_schema") + .to_string(); + let mut schema_value = schema.to_value(); + super::super::sanitize_schema(&mut schema_value); + let response_format = serde_json::json!({ + "response_format": { + "type": "json_schema", + "json_schema": { + "name": name, + "strict": true, + "schema": schema_value + } + } + }); + + Some(match additional_params { + Some(existing) => json_utils::merge(existing, response_format), + None => response_format, + }) +} + +pub(crate) fn build_compatible_request_core( + default_model: &str, + mut req: CoreCompletionRequest, + profile: CompatibleChatProfile, + system_message: fn(&str) -> M, + user_message: Option M>, + message_has_tool_result: fn(&M) -> bool, + mut convert_message: F, +) -> Result, CompletionError> +where + F: FnMut(message::Message) -> Result, CompletionError>, +{ + let adjustments = profile + .feature_policy + .apply(profile.provider_name, &mut req)?; + + let normalized_documents = req.normalized_documents(); + let CoreCompletionRequest { + model: request_model, + preamble, + chat_history, + tools, + temperature, + max_tokens, + tool_choice, + additional_params, + output_schema, + .. + } = req; + + let mut partial_history = Vec::new(); + if let Some(docs) = normalized_documents { + partial_history.push(docs); + } + partial_history.extend(chat_history); + + let mut messages = preamble + .as_deref() + .map_or_else(Vec::new, |preamble| vec![system_message(preamble)]); + + messages.extend( + partial_history + .into_iter() + .map(&mut convert_message) + .collect::>, _>>()? + .into_iter() + .flatten(), + ); + + if let Some(steering_message) = adjustments.steering_user_message { + let Some(user_message) = user_message else { + return Err(CompletionError::RequestError( + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!( + "{} requires a user-message adapter when coercing tool_choice", + profile.provider_name + ), + ) + .into(), + )); + }; + messages.push(user_message(steering_message)); + } + + if profile.require_messages && messages.is_empty() { + return Err(CompletionError::RequestError( + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!( + "{} request has no provider-compatible messages after conversion", + profile.provider_name + ), + ) + .into(), + )); + } + + let mut additional_params = additional_params; + if let Some(schema) = output_schema { + match profile.feature_policy.output_schema_policy { + OutputSchemaPolicy::NativeResponseFormat => { + let should_apply_response_format = + tools.is_empty() || messages.iter().any(message_has_tool_result); + if should_apply_response_format { + additional_params = merge_native_response_format(additional_params, schema); + } + } + OutputSchemaPolicy::Unsupported => {} + } + } + + Ok(CompatibleRequestCore { + model: request_model.unwrap_or_else(|| default_model.to_owned()), + messages, + temperature, + max_tokens, + tools, + tool_choice, + additional_params, + }) +} + +/// Internal provider profile used by the OpenAI-compatible chat family. +#[doc(hidden)] +pub trait OpenAiChatProviderProfile { + fn provider_name() -> &'static str; + + fn telemetry_provider_name() -> &'static str { + Self::provider_name() + } + + fn completions_path() -> &'static str { + "/chat/completions" + } + + fn request_profile() -> CompatibleChatProfile { + CompatibleChatProfile::openai_chat_completions(Self::provider_name()) + } + + fn stream_tool_call_conflict_policy() -> ToolCallConflictPolicy { + ToolCallConflictPolicy::EvictDistinctIdAndName + } +} + +impl OpenAiChatProviderProfile for super::super::client::OpenAICompletionsExt { + fn provider_name() -> &'static str { + "OpenAI Chat Completions" + } + + fn telemetry_provider_name() -> &'static str { + "openai" + } +} + +impl OpenAiChatProviderProfile for crate::providers::minimax::MiniMaxExt { + fn provider_name() -> &'static str { + "MiniMax" + } + + fn telemetry_provider_name() -> &'static str { + "minimax" + } +} + +impl OpenAiChatProviderProfile for crate::providers::zai::ZAiExt { + fn provider_name() -> &'static str { + "Z.AI" + } + + fn telemetry_provider_name() -> &'static str { + "z.ai" + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ToolCallConflictPolicy { + KeepIndex, + EvictDistinctIdAndName, +} + +#[derive(Debug, Clone, Copy)] +pub(crate) struct CompatibleStreamingToolCall<'a> { + pub index: usize, + pub id: Option<&'a str>, + pub name: Option<&'a str>, + pub arguments: Option<&'a str>, +} + +pub(crate) fn apply_compatible_tool_call_deltas<'a, R>( + tool_calls: &mut HashMap, + incoming: impl IntoIterator>, + conflict_policy: ToolCallConflictPolicy, +) -> Vec> +where + R: Clone, +{ + let mut events = Vec::new(); + + for tool_call in incoming { + let index = tool_call.index; + + if conflict_policy == ToolCallConflictPolicy::EvictDistinctIdAndName + && should_evict_existing_tool_call(tool_calls.get(&index), tool_call.id, tool_call.name) + && let Some(evicted) = tool_calls.remove(&index) + { + events.push(RawStreamingChoice::ToolCall( + finalize_completed_streaming_tool_call(evicted), + )); + } + + let existing_tool_call = tool_calls + .entry(index) + .or_insert_with(RawStreamingToolCall::empty); + + if let Some(id) = tool_call.id + && !id.is_empty() + { + existing_tool_call.id = id.to_owned(); + } + + if let Some(name) = tool_call.name + && !name.is_empty() + { + existing_tool_call.name = name.to_owned(); + events.push(RawStreamingChoice::ToolCallDelta { + id: existing_tool_call.id.clone(), + internal_call_id: existing_tool_call.internal_call_id.clone(), + content: ToolCallDeltaContent::Name(name.to_owned()), + }); + } + + if let Some(chunk) = tool_call.arguments + && !chunk.is_empty() + { + append_tool_call_arguments(&mut existing_tool_call.arguments, chunk); + events.push(RawStreamingChoice::ToolCallDelta { + id: existing_tool_call.id.clone(), + internal_call_id: existing_tool_call.internal_call_id.clone(), + content: ToolCallDeltaContent::Delta(chunk.to_owned()), + }); + } + } + + events +} + +pub(crate) fn take_finalized_tool_calls( + tool_calls: &mut HashMap, +) -> Vec> +where + R: Clone, +{ + std::mem::take(tool_calls) + .into_values() + .map(|tool_call| { + RawStreamingChoice::ToolCall(finalize_completed_streaming_tool_call(tool_call)) + }) + .collect() +} + +pub(crate) fn take_tool_calls( + tool_calls: &mut HashMap, +) -> Vec> +where + R: Clone, +{ + std::mem::take(tool_calls) + .into_values() + .map(RawStreamingChoice::ToolCall) + .collect() +} + +pub(crate) fn finalize_completed_streaming_tool_call( + mut tool_call: RawStreamingToolCall, +) -> RawStreamingToolCall { + if tool_call.arguments.is_null() { + tool_call.arguments = serde_json::Value::Object(serde_json::Map::new()); + } + + tool_call +} + +fn should_evict_existing_tool_call( + existing: Option<&RawStreamingToolCall>, + new_id: Option<&str>, + new_name: Option<&str>, +) -> bool { + let Some(existing) = existing else { + return false; + }; + + let Some(new_id) = new_id.filter(|id| !id.is_empty()) else { + return false; + }; + let Some(new_name) = new_name.filter(|name| !name.is_empty()) else { + return false; + }; + + !existing.id.is_empty() + && existing.id != new_id + && !existing.name.is_empty() + && existing.name != new_name +} + +fn append_tool_call_arguments(arguments: &mut serde_json::Value, chunk: &str) { + let current_arguments = match arguments { + serde_json::Value::Null => String::new(), + serde_json::Value::String(value) => value.clone(), + ref value => value.to_string(), + }; + let combined = format!("{current_arguments}{chunk}"); + + if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') { + match serde_json::from_str(&combined) { + Ok(parsed) => *arguments = parsed, + Err(_) => *arguments = serde_json::Value::String(combined), + } + } else { + *arguments = serde_json::Value::String(combined); + } +} + +#[cfg(test)] +pub(crate) mod request_conformance { + use serde_json::{Value, json}; + + use super::{ + AdditionalParamsPolicy, CompatibleChatProfile, OutputSchemaPolicy, ToolChoicePolicy, + ToolsPolicy, + }; + use crate::completion::{CompletionError, CompletionRequest, Document, ToolDefinition}; + use crate::message::{self, Message, ToolChoice, UserContent}; + use crate::{OneOrMany, completion}; + + #[derive(Debug, Clone, Copy)] + pub(crate) enum Fixture { + ModelOverridePrecedence, + PreambleDocumentHistoryOrdering, + EmptyMessageRejection, + UnsupportedFieldStripping, + ToolChoiceCoercion, + OutputSchemaHandling, + AdditionalParamsPassthrough, + AdditionalParamsDrop, + } + + #[derive(Debug, Clone, PartialEq)] + pub(crate) enum Outcome { + Supported(T), + RequestError, + } + + pub(crate) trait Harness { + #[allow(dead_code)] + fn family_name() -> &'static str; + + fn run(case: Fixture) -> Outcome; + + fn assert(case: Fixture, actual: Outcome); + } + + pub(crate) fn assert_case(case: Fixture) { + let actual = H::run(case); + H::assert(case, actual); + } + + pub(crate) fn serialize_request(request: &T) -> Value { + serde_json::to_value(request).expect("request should serialize") + } + + pub(crate) fn serialize_case(case: Fixture, convert: F) -> Outcome + where + T: serde::Serialize, + F: FnOnce(CompletionRequest) -> Result, + { + match convert(fixture_request(case)) { + Ok(request) => Outcome::Supported(serialize_request(&request)), + Err(CompletionError::RequestError(_)) => Outcome::RequestError, + Err(err) => panic!("unexpected request conversion error for {case:?}: {err}"), + } + } + + pub(crate) fn fixture_request(case: Fixture) -> CompletionRequest { + match case { + Fixture::ModelOverridePrecedence => CompletionRequest { + model: Some("override-model".to_owned()), + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: Vec::new(), + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + }, + Fixture::PreambleDocumentHistoryOrdering => CompletionRequest { + model: None, + preamble: Some("system".to_owned()), + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: vec![Document { + id: "doc-1".to_owned(), + text: "Document body".to_owned(), + additional_props: Default::default(), + }], + tools: Vec::new(), + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + }, + Fixture::EmptyMessageRejection => CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::Assistant { + id: None, + content: OneOrMany::one(message::AssistantContent::reasoning("hidden")), + }), + documents: vec![], + tools: vec![], + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + }, + Fixture::UnsupportedFieldStripping => CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: vec![ToolDefinition { + name: "lookup_weather".to_owned(), + description: "Lookup weather".to_owned(), + parameters: json!({ + "type": "object", + "properties": { + "city": { "type": "string" } + } + }), + }], + temperature: None, + max_tokens: None, + tool_choice: Some(ToolChoice::Auto), + additional_params: Some(json!({"vendor_flag": true})), + output_schema: Some(sample_schema()), + }, + Fixture::ToolChoiceCoercion => CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: vec![ToolDefinition { + name: "lookup_weather".to_owned(), + description: "Lookup weather".to_owned(), + parameters: json!({"type":"object"}), + }], + temperature: None, + max_tokens: None, + tool_choice: Some(ToolChoice::Required), + additional_params: None, + output_schema: None, + }, + Fixture::OutputSchemaHandling => CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: Vec::new(), + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: Some(json!({"vendor_flag": true})), + output_schema: Some(sample_schema()), + }, + Fixture::AdditionalParamsPassthrough | Fixture::AdditionalParamsDrop => { + CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: Vec::new(), + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: Some(json!({"vendor_flag": true})), + output_schema: None, + } + } + } + } + + pub(crate) fn model(value: &Value) -> Option<&str> { + value.get("model").and_then(Value::as_str) + } + + pub(crate) fn tools_len(value: &Value) -> usize { + value + .get("tools") + .and_then(Value::as_array) + .map_or(0, Vec::len) + } + + pub(crate) fn tool_choice(value: &Value) -> Option<&Value> { + value.get("tool_choice") + } + + pub(crate) fn top_level<'a>(value: &'a Value, key: &str) -> Option<&'a Value> { + value.get(key) + } + + pub(crate) fn tool_choice_text(value: &Value) -> Option<&str> { + value.get("tool_choice").and_then(Value::as_str) + } + + pub(crate) fn message_summaries(value: &Value) -> Vec { + value + .get("messages") + .and_then(Value::as_array) + .into_iter() + .flatten() + .map(|message| { + let role = message + .get("role") + .and_then(Value::as_str) + .unwrap_or("unknown"); + format!("{role}:{}", extract_text(message)) + }) + .collect() + } + + fn extract_text(message: &Value) -> String { + let content = extract_text_from_value(message.get("content").unwrap_or(&Value::Null)); + if !content.is_empty() { + return content; + } + + if let Some(reasoning) = message.get("reasoning_content").and_then(Value::as_str) { + return reasoning.to_owned(); + } + + extract_text_from_value(message.get("reasoning_details").unwrap_or(&Value::Null)) + } + + fn extract_text_from_value(value: &Value) -> String { + match value { + Value::String(text) => text.clone(), + Value::Array(items) => items + .iter() + .filter_map(extract_text_part) + .collect::>() + .join("\n"), + Value::Object(object) => { + if let Some(text) = object.get("text").and_then(Value::as_str) { + return text.to_owned(); + } + + if let Some(text) = object.get("content").and_then(Value::as_str) { + return text.to_owned(); + } + + if let Some(content) = object.get("content") { + return extract_text_from_value(content); + } + + String::new() + } + _ => String::new(), + } + } + + fn extract_text_part(value: &Value) -> Option { + match value { + Value::String(text) => Some(text.clone()), + Value::Object(object) => object + .get("text") + .and_then(Value::as_str) + .map(ToOwned::to_owned) + .or_else(|| { + object + .get("content") + .and_then(Value::as_str) + .map(ToOwned::to_owned) + }) + .or_else(|| object.get("content").map(extract_text_from_value)) + .filter(|text| !text.is_empty()), + _ => None, + } + } + + pub(crate) fn expected_document_text() -> String { + completion::Document { + id: "doc-1".to_owned(), + text: "Document body".to_owned(), + additional_props: Default::default(), + } + .to_string() + } + + fn sample_schema() -> schemars::Schema { + serde_json::from_value(json!({ + "title": "fixture_schema", + "type": "object", + "properties": { + "answer": { "type": "string" } + } + })) + .expect("schema should deserialize") + } + + #[derive(Debug, Clone, Copy)] + pub(crate) struct CompatibleChatExpectation { + profile: CompatibleChatProfile, + preserves_reasoning_assistant_history: bool, + includes_document_messages: bool, + } + + impl CompatibleChatExpectation { + pub(crate) const fn new(profile: CompatibleChatProfile) -> Self { + Self { + profile, + preserves_reasoning_assistant_history: false, + includes_document_messages: true, + } + } + + pub(crate) const fn preserves_reasoning_assistant_history(mut self) -> Self { + self.preserves_reasoning_assistant_history = true; + self + } + + pub(crate) const fn omits_document_messages(mut self) -> Self { + self.includes_document_messages = false; + self + } + } + + pub(crate) fn assert_compatible_chat_case( + expectation: CompatibleChatExpectation, + default_model: &str, + case: Fixture, + actual: Outcome, + ) { + let profile = expectation.profile; + match case { + Fixture::ModelOverridePrecedence => { + let request = expect_supported(case, actual); + assert_eq!(model(&request), Some("override-model")); + } + Fixture::PreambleDocumentHistoryOrdering => { + let request = expect_supported(case, actual); + let mut expected = vec!["system:system".to_owned()]; + if expectation.includes_document_messages { + expected.push(format!("user:{}", expected_document_text())); + } + expected.push("user:hello".to_owned()); + assert_eq!(message_summaries(&request), expected,); + } + Fixture::EmptyMessageRejection => { + if profile.require_messages { + assert!( + matches!(actual, Outcome::RequestError), + "expected request rejection for {case:?}, got {actual:?}" + ); + } else { + let request = expect_supported(case, actual); + let summaries = message_summaries(&request); + if expectation.preserves_reasoning_assistant_history { + assert_eq!(summaries, vec!["assistant:hidden".to_owned()]); + } else { + assert!( + summaries.is_empty(), + "expected no provider-compatible messages: {request:?}" + ); + } + } + } + Fixture::UnsupportedFieldStripping => { + let request = expect_supported(case, actual); + assert_eq!(model(&request), Some(default_model)); + assert_eq!( + tools_len(&request), + if matches!(profile.feature_policy.tools_policy, ToolsPolicy::Supported) { + 1 + } else { + 0 + } + ); + assert_tool_choice(&request, profile.feature_policy.tool_choice_policy, "auto"); + assert_vendor_flag( + &request, + matches!( + profile.feature_policy.additional_params_policy, + AdditionalParamsPolicy::PassThrough + ), + ); + assert!( + top_level(&request, "response_format").is_none(), + "tool-bearing turns should omit response_format: {request:?}" + ); + } + Fixture::ToolChoiceCoercion => { + if matches!( + profile.feature_policy.tool_choice_policy, + ToolChoicePolicy::RejectRequired + ) { + assert!( + matches!(actual, Outcome::RequestError), + "expected request rejection for {case:?}, got {actual:?}" + ); + return; + } + + let request = expect_supported(case, actual); + match profile.feature_policy.tool_choice_policy { + ToolChoicePolicy::PassThrough => { + assert_eq!(tool_choice_text(&request), Some("required")); + } + ToolChoicePolicy::Unsupported => { + assert!( + tool_choice(&request).is_none(), + "expected tool_choice to be stripped: {request:?}" + ); + } + ToolChoicePolicy::RejectRequired => unreachable!(), + ToolChoicePolicy::CoerceRequiredToAuto { steering_message } => { + assert_eq!(tool_choice_text(&request), Some("auto")); + let expected_message = format!("user:{steering_message}"); + assert_eq!( + message_summaries(&request).last().map(String::as_str), + Some(expected_message.as_str()), + ); + } + } + } + Fixture::OutputSchemaHandling => { + let request = expect_supported(case, actual); + match profile.feature_policy.output_schema_policy { + OutputSchemaPolicy::NativeResponseFormat => { + assert_eq!( + top_level(&request, "response_format") + .and_then(|value| value.get("json_schema")) + .and_then(|value| value.get("name")) + .and_then(Value::as_str), + Some("fixture_schema"), + ); + } + OutputSchemaPolicy::Unsupported => { + assert!( + top_level(&request, "response_format").is_none(), + "unexpected response_format: {request:?}" + ); + } + } + assert_vendor_flag( + &request, + matches!( + profile.feature_policy.additional_params_policy, + AdditionalParamsPolicy::PassThrough + ), + ); + } + Fixture::AdditionalParamsPassthrough | Fixture::AdditionalParamsDrop => { + let request = expect_supported(case, actual); + assert_vendor_flag( + &request, + matches!( + profile.feature_policy.additional_params_policy, + AdditionalParamsPolicy::PassThrough + ), + ); + } + } + } + + fn expect_supported(case: Fixture, actual: Outcome) -> Value { + match actual { + Outcome::Supported(request) => request, + Outcome::RequestError => { + panic!("expected supported request for {case:?}, got request rejection") + } + } + } + + fn assert_vendor_flag(request: &Value, expected: bool) { + if expected { + assert_eq!(top_level(request, "vendor_flag"), Some(&json!(true))); + } else { + assert!( + top_level(request, "vendor_flag").is_none(), + "expected vendor_flag to be stripped: {request:?}" + ); + } + } + + fn assert_tool_choice(request: &Value, policy: ToolChoicePolicy, expected_pass_through: &str) { + match policy { + ToolChoicePolicy::PassThrough => { + assert_eq!(tool_choice_text(request), Some(expected_pass_through)); + } + ToolChoicePolicy::Unsupported => { + assert!( + tool_choice(request).is_none(), + "expected tool_choice to be stripped: {request:?}" + ); + } + ToolChoicePolicy::RejectRequired => { + assert_eq!(tool_choice_text(request), Some(expected_pass_through)); + } + ToolChoicePolicy::CoerceRequiredToAuto { .. } => { + assert_eq!(tool_choice_text(request), Some(expected_pass_through)); + } + } + } + + macro_rules! provider_request_conformance_tests { + ($harness:ty) => { + #[test] + fn request_conformance_model_override_precedence() { + crate::providers::openai::completion::request_conformance::assert_case::<$harness>( + crate::providers::openai::completion::request_conformance::Fixture::ModelOverridePrecedence, + ); + } + + #[test] + fn request_conformance_preamble_document_history_ordering() { + crate::providers::openai::completion::request_conformance::assert_case::<$harness>( + crate::providers::openai::completion::request_conformance::Fixture::PreambleDocumentHistoryOrdering, + ); + } + + #[test] + fn request_conformance_empty_message_rejection() { + crate::providers::openai::completion::request_conformance::assert_case::<$harness>( + crate::providers::openai::completion::request_conformance::Fixture::EmptyMessageRejection, + ); + } + + #[test] + fn request_conformance_unsupported_field_stripping() { + crate::providers::openai::completion::request_conformance::assert_case::<$harness>( + crate::providers::openai::completion::request_conformance::Fixture::UnsupportedFieldStripping, + ); + } + + #[test] + fn request_conformance_tool_choice_coercion() { + crate::providers::openai::completion::request_conformance::assert_case::<$harness>( + crate::providers::openai::completion::request_conformance::Fixture::ToolChoiceCoercion, + ); + } + + #[test] + fn request_conformance_output_schema_handling() { + crate::providers::openai::completion::request_conformance::assert_case::<$harness>( + crate::providers::openai::completion::request_conformance::Fixture::OutputSchemaHandling, + ); + } + + #[test] + fn request_conformance_additional_params_passthrough() { + crate::providers::openai::completion::request_conformance::assert_case::<$harness>( + crate::providers::openai::completion::request_conformance::Fixture::AdditionalParamsPassthrough, + ); + } + + #[test] + fn request_conformance_additional_params_drop() { + crate::providers::openai::completion::request_conformance::assert_case::<$harness>( + crate::providers::openai::completion::request_conformance::Fixture::AdditionalParamsDrop, + ); + } + }; + } + + pub(crate) use provider_request_conformance_tests; +} + +#[cfg(test)] +mod tests { + use super::{ + AdditionalParamsPolicy, CompatibleChatProfile, CompatibleFeaturePolicy, + CompatibleStreamingToolCall, OutputSchemaPolicy, ToolCallConflictPolicy, ToolChoicePolicy, + ToolsPolicy, apply_compatible_tool_call_deltas, build_compatible_request_core, + finalize_completed_streaming_tool_call, take_finalized_tool_calls, + }; + use crate::OneOrMany; + use crate::completion::{CompletionRequest, ToolDefinition}; + use crate::message::{Message, ToolChoice, UserContent}; + use crate::streaming::{RawStreamingChoice, RawStreamingToolCall}; + use std::collections::HashMap; + + #[test] + fn requires_non_empty_messages_when_profile_demands_it() { + let req = CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: Vec::new(), + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + }; + + let err = build_compatible_request_core( + "test-model", + req, + CompatibleChatProfile::new("Example Provider").require_messages(), + |preamble| preamble.to_owned(), + None, + |_| false, + |_message| Ok(Vec::::new()), + ) + .expect_err("empty converted messages should fail"); + + assert!(err.to_string().contains( + "Example Provider request has no provider-compatible messages after conversion" + )); + } + + #[test] + fn preserves_model_override_and_additional_params() { + let req = CompletionRequest { + model: Some("override-model".to_owned()), + preamble: Some("system".to_owned()), + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: Vec::new(), + temperature: Some(0.5), + max_tokens: Some(42), + tool_choice: None, + additional_params: Some(serde_json::json!({"foo":"bar"})), + output_schema: None, + }; + + let result = build_compatible_request_core( + "default-model", + req, + CompatibleChatProfile::new("Example Provider"), + |preamble| preamble.to_owned(), + None, + |_| false, + |message| Ok(vec![format!("{message:?}")]), + ) + .expect("request conversion should succeed"); + + assert_eq!(result.model, "override-model"); + assert_eq!(result.temperature, Some(0.5)); + assert_eq!(result.max_tokens, Some(42)); + assert_eq!( + result.additional_params, + Some(serde_json::json!({"foo":"bar"})) + ); + assert_eq!(result.messages.len(), 2); + } + + #[test] + fn feature_policy_strips_unsupported_fields() { + let mut req = CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: vec![ToolDefinition { + name: "lookup_weather".to_owned(), + description: "Lookup the weather".to_owned(), + parameters: serde_json::json!({"type":"object"}), + }], + temperature: None, + max_tokens: None, + tool_choice: Some(ToolChoice::Auto), + additional_params: Some(serde_json::json!({"foo":"bar"})), + output_schema: Some( + serde_json::from_value(serde_json::json!({ + "title": "example", + "type": "object" + })) + .expect("schema should deserialize"), + ), + }; + + CompatibleFeaturePolicy::default() + .with_tools_policy(ToolsPolicy::Unsupported) + .with_tool_choice_policy(ToolChoicePolicy::Unsupported) + .with_output_schema_policy(OutputSchemaPolicy::Unsupported) + .with_additional_params_policy(AdditionalParamsPolicy::Unsupported) + .apply("Example Provider", &mut req) + .expect("policy application should succeed"); + + assert!(req.tools.is_empty()); + assert!(req.tool_choice.is_none()); + assert!(req.additional_params.is_none()); + assert!(req.output_schema.is_none()); + } + + #[test] + fn tool_choice_policy_can_coerce_required_to_auto_with_steering_message() { + let req = CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: vec![ToolDefinition { + name: "lookup_weather".to_owned(), + description: "Lookup the weather".to_owned(), + parameters: serde_json::json!({"type":"object"}), + }], + temperature: None, + max_tokens: None, + tool_choice: Some(ToolChoice::Required), + additional_params: None, + output_schema: None, + }; + + let result = build_compatible_request_core( + "test-model", + req, + CompatibleChatProfile::new("Example Provider") + .coerce_required_tool_choice_to_auto("Use a tool."), + |preamble| preamble.to_owned(), + Some(|text| text.to_owned()), + |_| false, + |message| Ok(vec![format!("{message:?}")]), + ) + .expect("request conversion should succeed"); + + assert_eq!(result.tool_choice, Some(ToolChoice::Auto)); + assert_eq!(result.messages.last(), Some(&"Use a tool.".to_owned())); + } + + #[test] + fn native_response_format_merges_schema_into_additional_params() { + let req = CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: Vec::new(), + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: Some(serde_json::json!({"vendor_flag": true})), + output_schema: Some( + serde_json::from_value(serde_json::json!({ + "title": "example", + "type": "object", + "properties": { + "answer": { "type": "string" } + } + })) + .expect("schema should deserialize"), + ), + }; + + let result = build_compatible_request_core( + "test-model", + req, + CompatibleChatProfile::new("Example Provider").native_response_format(), + |preamble| preamble.to_owned(), + None, + |_| false, + |message| Ok(vec![format!("{message:?}")]), + ) + .expect("request conversion should succeed"); + + let additional_params = result + .additional_params + .expect("response_format should be merged"); + assert_eq!( + additional_params.get("vendor_flag"), + Some(&serde_json::json!(true)) + ); + assert!(additional_params.get("response_format").is_some()); + } + + #[test] + fn evicts_distinct_tool_calls_that_reuse_the_same_index() { + let mut tool_calls = HashMap::from([( + 0, + RawStreamingToolCall { + id: "call_1".to_owned(), + internal_call_id: "internal_1".to_owned(), + call_id: None, + name: "weather".to_owned(), + arguments: serde_json::json!({"city":"Paris"}), + signature: None, + additional_params: None, + }, + )]); + + let events = apply_compatible_tool_call_deltas::<()>( + &mut tool_calls, + [CompatibleStreamingToolCall { + index: 0, + id: Some("call_2"), + name: Some("time"), + arguments: Some("{"), + }], + ToolCallConflictPolicy::EvictDistinctIdAndName, + ); + + assert!( + matches!(events.first(), Some(RawStreamingChoice::ToolCall(tool_call)) if tool_call.id == "call_1") + ); + assert_eq!( + tool_calls.get(&0).map(|tool_call| tool_call.id.as_str()), + Some("call_2") + ); + assert_eq!( + tool_calls.get(&0).map(|tool_call| tool_call.name.as_str()), + Some("time") + ); + } + + #[test] + fn finalizes_null_arguments_into_empty_objects() { + let finalized = finalize_completed_streaming_tool_call(RawStreamingToolCall::empty()); + assert_eq!(finalized.arguments, serde_json::json!({})); + } + + #[test] + fn drains_finalized_tool_calls() { + let mut tool_calls = HashMap::from([(0, RawStreamingToolCall::empty())]); + + let events = take_finalized_tool_calls::<()>(&mut tool_calls); + + assert!(tool_calls.is_empty()); + assert!( + matches!(events.as_slice(), [RawStreamingChoice::ToolCall(tool_call)] if tool_call.arguments == serde_json::json!({})) + ); + } +} diff --git a/rig/rig-core/src/providers/openai/completion/mod.rs b/rig/rig-core/src/providers/openai/completion/mod.rs index e783464b3..ee10a93bd 100644 --- a/rig/rig-core/src/providers/openai/completion/mod.rs +++ b/rig/rig-core/src/providers/openai/completion/mod.rs @@ -22,21 +22,17 @@ use tracing::{Instrument, Level, enabled, info_span}; use std::str::FromStr; -mod compat; -mod request_compat; -mod stream_compat; +mod family; pub mod streaming; -pub(crate) use compat::{ - build_completion_response, first_choice, map_finish_reason, non_empty_text, -}; -pub(crate) use request_compat::{ - CompatibleChatProfile, CompatibleFeaturePolicy, CompatibleRequestCore, - build_compatible_request_core, -}; -pub(crate) use stream_compat::{ - CompatibleStreamingToolCall, ToolCallConflictPolicy, apply_compatible_tool_call_deltas, - take_finalized_tool_calls, take_tool_calls, +#[cfg(test)] +pub(crate) use family::request_conformance; +pub(crate) use family::{ + AdditionalParamsPolicy, CompatibleChatProfile, CompatibleFeaturePolicy, CompatibleRequestCore, + CompatibleStreamingToolCall, OpenAiChatProviderProfile, ToolCallConflictPolicy, + ToolChoicePolicy, ToolsPolicy, apply_compatible_tool_call_deltas, + build_compatible_request_core, build_completion_response, first_choice, map_finish_reason, + non_empty_text, take_finalized_tool_calls, take_tool_calls, }; /// Serializes user content as a plain string when there's a single text item, @@ -200,10 +196,11 @@ impl Message { } } -fn history_contains_tool_result(messages: &[Message]) -> bool { - messages - .iter() - .any(|message| matches!(message, Message::ToolResult { .. })) +fn openai_user_message(content: &str) -> Message { + Message::User { + content: OneOrMany::one(UserContent::from(content.to_owned())), + name: None, + } } #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] @@ -1055,102 +1052,71 @@ pub struct OpenAIRequestParams { pub tool_result_array_content: bool, } -impl TryFrom for CompletionRequest { - type Error = CompletionError; - - fn try_from(params: OpenAIRequestParams) -> Result { - let OpenAIRequestParams { - model, - request: req, - strict_tools, - tool_result_array_content, - } = params; - let CompatibleRequestCore { - model, - messages: mut full_history, - temperature, - max_tokens, - tools, - tool_choice, - additional_params, - output_schema, - } = build_compatible_request_core( - &model, - req, - CompatibleChatProfile::new("OpenAI Chat Completions") - .require_messages() - .supports_output_schema(), - Message::system, - |message| Vec::::try_from(message).map_err(CompletionError::from), - )?; - - if tool_result_array_content { - for msg in &mut full_history { - if let Message::ToolResult { content, .. } = msg { - *content = content.to_array(); - } +fn build_openai_completion_request( + params: OpenAIRequestParams, + profile: CompatibleChatProfile, +) -> Result { + let OpenAIRequestParams { + model, + request: req, + strict_tools, + tool_result_array_content, + } = params; + let CompatibleRequestCore { + model, + messages: mut full_history, + temperature, + max_tokens, + tools, + tool_choice, + additional_params, + } = build_compatible_request_core( + &model, + req, + profile, + Message::system, + Some(openai_user_message), + |message| matches!(message, Message::ToolResult { .. }), + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; + + if tool_result_array_content { + for msg in &mut full_history { + if let Message::ToolResult { content, .. } = msg { + *content = content.to_array(); } } + } - let history_has_tool_result = history_contains_tool_result(&full_history); + let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?; - let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?; + let tools: Vec = tools + .into_iter() + .map(|tool| { + let def = ToolDefinition::from(tool); + if strict_tools { def.with_strict() } else { def } + }) + .collect(); - let tools: Vec = tools - .into_iter() - .map(|tool| { - let def = ToolDefinition::from(tool); - if strict_tools { def.with_strict() } else { def } - }) - .collect(); - - // Some OpenAI-compatible backends such as llama.cpp will skip tool execution - // if `response_format` is sent on the first turn alongside tools. Delay the - // schema until after the conversation contains a tool result. - let should_apply_response_format = - output_schema.is_some() && (tools.is_empty() || history_has_tool_result); - - // Map output_schema to OpenAI's response_format and merge into additional_params - let additional_params = if let Some(schema) = output_schema - && should_apply_response_format - { - let name = schema - .as_object() - .and_then(|o| o.get("title")) - .and_then(|v| v.as_str()) - .unwrap_or("response_schema") - .to_string(); - let mut schema_value = schema.to_value(); - super::sanitize_schema(&mut schema_value); - let response_format = serde_json::json!({ - "response_format": { - "type": "json_schema", - "json_schema": { - "name": name, - "strict": true, - "schema": schema_value - } - } - }); - Some(match additional_params { - Some(existing) => json_utils::merge(existing, response_format), - None => response_format, - }) - } else { - additional_params - }; + Ok(CompletionRequest { + model, + messages: full_history, + tools, + tool_choice, + temperature, + max_tokens, + additional_params, + }) +} - let res = Self { - model, - messages: full_history, - tools, - tool_choice, - temperature, - max_tokens, - additional_params, - }; +impl TryFrom for CompletionRequest { + type Error = CompletionError; - Ok(res) + fn try_from(params: OpenAIRequestParams) -> Result { + build_openai_completion_request( + params, + CompatibleChatProfile::openai_chat_completions("OpenAI Chat Completions"), + ) } } @@ -1216,6 +1182,7 @@ where crate::client::Client: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static, Ext: crate::client::Provider + + OpenAiChatProviderProfile + crate::client::DebugExt + Clone + WasmCompatSend @@ -1241,7 +1208,7 @@ where target: "rig::completions", "chat", gen_ai.operation.name = "chat", - gen_ai.provider.name = "openai", + gen_ai.provider.name = Ext::telemetry_provider_name(), gen_ai.request.model = self.model, gen_ai.system_instructions = &completion_request.preamble, gen_ai.response.id = tracing::field::Empty, @@ -1254,17 +1221,21 @@ where tracing::Span::current() }; - let request = CompletionRequest::try_from(OpenAIRequestParams { - model: self.model.to_owned(), - request: completion_request, - strict_tools: self.strict_tools, - tool_result_array_content: self.tool_result_array_content, - })?; + let request = build_openai_completion_request( + OpenAIRequestParams { + model: self.model.to_owned(), + request: completion_request, + strict_tools: self.strict_tools, + tool_result_array_content: self.tool_result_array_content, + }, + Ext::request_profile(), + )?; if enabled!(Level::TRACE) { tracing::trace!( target: "rig::completions", - "OpenAI Chat Completions completion request: {}", + "{} completion request: {}", + Ext::provider_name(), serde_json::to_string_pretty(&request)? ); } @@ -1273,7 +1244,7 @@ where let req = self .client - .post("/chat/completions")? + .post(Ext::completions_path())? .body(body) .map_err(|e| CompletionError::HttpError(e.into()))?; @@ -1292,7 +1263,8 @@ where if enabled!(Level::TRACE) { tracing::trace!( target: "rig::completions", - "OpenAI Chat Completions completion response: {}", + "{} completion response: {}", + Ext::provider_name(), serde_json::to_string_pretty(&response)? ); } @@ -1342,6 +1314,43 @@ mod conformance_tests; mod tests { use super::*; + struct OpenAiRequestHarness; + + impl request_conformance::Harness for OpenAiRequestHarness { + fn family_name() -> &'static str { + "openai-chat" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + CompletionRequest::try_from(OpenAIRequestParams { + model: "default-model".to_owned(), + request, + strict_tools: false, + tool_result_array_content: false, + }) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + CompatibleChatProfile::openai_chat_completions("OpenAI Chat Completions"), + ), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(OpenAiRequestHarness); + #[test] fn test_openai_request_uses_request_model_override() { let request = crate::completion::CompletionRequest { diff --git a/rig/rig-core/src/providers/openai/completion/request_compat.rs b/rig/rig-core/src/providers/openai/completion/request_compat.rs deleted file mode 100644 index 37155313e..000000000 --- a/rig/rig-core/src/providers/openai/completion/request_compat.rs +++ /dev/null @@ -1,317 +0,0 @@ -use crate::completion::{CompletionError, CompletionRequest as CoreCompletionRequest}; -use crate::message; - -/// Shared unsupported-feature policy for request builders. -#[derive(Debug, Clone, Copy)] -pub(crate) struct CompatibleFeaturePolicy { - supports_output_schema: bool, - supports_tools: bool, - supports_tool_choice: bool, - supports_additional_params: bool, -} - -impl Default for CompatibleFeaturePolicy { - fn default() -> Self { - Self { - supports_output_schema: false, - supports_tools: true, - supports_tool_choice: true, - supports_additional_params: true, - } - } -} - -impl CompatibleFeaturePolicy { - pub(crate) const fn supports_output_schema(mut self) -> Self { - self.supports_output_schema = true; - self - } - - pub(crate) const fn without_tools(mut self) -> Self { - self.supports_tools = false; - self - } - - pub(crate) const fn without_tool_choice(mut self) -> Self { - self.supports_tool_choice = false; - self - } - - pub(crate) const fn without_additional_params(mut self) -> Self { - self.supports_additional_params = false; - self - } - - pub(crate) fn apply(self, provider_name: &'static str, req: &mut CoreCompletionRequest) { - if req.output_schema.is_some() && !self.supports_output_schema { - tracing::warn!( - "Structured outputs currently not supported for {}", - provider_name - ); - req.output_schema = None; - } - - if !req.tools.is_empty() && !self.supports_tools { - tracing::warn!("WARNING: `tools` not supported on {}", provider_name); - req.tools.clear(); - } - - if req.tool_choice.is_some() && !self.supports_tool_choice { - tracing::warn!("WARNING: `tool_choice` not supported on {}", provider_name); - req.tool_choice = None; - } - - if req.additional_params.is_some() && !self.supports_additional_params { - tracing::warn!( - "WARNING: `additional_params` not supported on {}", - provider_name - ); - req.additional_params = None; - } - } -} - -/// Shared request-shaping profile for OpenAI-compatible chat providers. -#[derive(Debug, Clone, Copy)] -pub(crate) struct CompatibleChatProfile { - provider_name: &'static str, - require_messages: bool, - feature_policy: CompatibleFeaturePolicy, -} - -impl CompatibleChatProfile { - pub(crate) const fn new(provider_name: &'static str) -> Self { - Self { - provider_name, - require_messages: false, - feature_policy: CompatibleFeaturePolicy { - supports_output_schema: false, - supports_tools: true, - supports_tool_choice: true, - supports_additional_params: true, - }, - } - } - - pub(crate) const fn require_messages(mut self) -> Self { - self.require_messages = true; - self - } - - pub(crate) const fn supports_output_schema(mut self) -> Self { - self.feature_policy = self.feature_policy.supports_output_schema(); - self - } - - pub(crate) const fn without_tools(mut self) -> Self { - self.feature_policy = self.feature_policy.without_tools(); - self - } - - pub(crate) const fn without_tool_choice(mut self) -> Self { - self.feature_policy = self.feature_policy.without_tool_choice(); - self - } -} - -/// Provider-agnostic core request fields shared by OpenAI-compatible chat families. -#[derive(Debug)] -pub(crate) struct CompatibleRequestCore { - pub model: String, - pub messages: Vec, - pub temperature: Option, - pub max_tokens: Option, - pub tools: Vec, - pub tool_choice: Option, - pub additional_params: Option, - pub output_schema: Option, -} - -pub(crate) fn build_compatible_request_core( - default_model: &str, - mut req: CoreCompletionRequest, - profile: CompatibleChatProfile, - system_message: impl Fn(&str) -> M, - mut convert_message: F, -) -> Result, CompletionError> -where - F: FnMut(message::Message) -> Result, CompletionError>, -{ - profile - .feature_policy - .apply(profile.provider_name, &mut req); - - let normalized_documents = req.normalized_documents(); - let CoreCompletionRequest { - model: request_model, - preamble, - chat_history, - tools, - temperature, - max_tokens, - tool_choice, - additional_params, - output_schema, - .. - } = req; - - let mut partial_history = Vec::new(); - if let Some(docs) = normalized_documents { - partial_history.push(docs); - } - partial_history.extend(chat_history); - - let mut messages = preamble - .as_deref() - .map_or_else(Vec::new, |preamble| vec![system_message(preamble)]); - - messages.extend( - partial_history - .into_iter() - .map(&mut convert_message) - .collect::>, _>>()? - .into_iter() - .flatten(), - ); - - if profile.require_messages && messages.is_empty() { - return Err(CompletionError::RequestError( - std::io::Error::new( - std::io::ErrorKind::InvalidInput, - format!( - "{} request has no provider-compatible messages after conversion", - profile.provider_name - ), - ) - .into(), - )); - } - - Ok(CompatibleRequestCore { - model: request_model.unwrap_or_else(|| default_model.to_owned()), - messages, - temperature, - max_tokens, - tools, - tool_choice, - additional_params, - output_schema, - }) -} - -#[cfg(test)] -mod tests { - use super::{CompatibleChatProfile, CompatibleFeaturePolicy, build_compatible_request_core}; - use crate::OneOrMany; - use crate::completion::{CompletionRequest, ToolDefinition}; - use crate::message::{Message, ToolChoice, UserContent}; - - #[test] - fn requires_non_empty_messages_when_profile_demands_it() { - let req = CompletionRequest { - model: None, - preamble: None, - chat_history: OneOrMany::one(Message::User { - content: OneOrMany::one(UserContent::text("hello")), - }), - documents: Vec::new(), - tools: Vec::new(), - temperature: None, - max_tokens: None, - tool_choice: None, - additional_params: None, - output_schema: None, - }; - - let err = build_compatible_request_core( - "test-model", - req, - CompatibleChatProfile::new("Example Provider").require_messages(), - |preamble| preamble.to_owned(), - |_message| Ok(Vec::::new()), - ) - .expect_err("empty converted messages should fail"); - - assert!(err.to_string().contains( - "Example Provider request has no provider-compatible messages after conversion" - )); - } - - #[test] - fn preserves_model_override_and_additional_params() { - let req = CompletionRequest { - model: Some("override-model".to_owned()), - preamble: Some("system".to_owned()), - chat_history: OneOrMany::one(Message::User { - content: OneOrMany::one(UserContent::text("hello")), - }), - documents: Vec::new(), - tools: Vec::new(), - temperature: Some(0.5), - max_tokens: Some(42), - tool_choice: None, - additional_params: Some(serde_json::json!({"foo":"bar"})), - output_schema: None, - }; - - let result = build_compatible_request_core( - "default-model", - req, - CompatibleChatProfile::new("Example Provider"), - |preamble| format!("system:{preamble}"), - |_message| Ok(vec!["history".to_owned()]), - ) - .expect("request conversion should succeed"); - - assert_eq!(result.model, "override-model"); - assert_eq!(result.temperature, Some(0.5)); - assert_eq!(result.max_tokens, Some(42)); - assert_eq!( - result.messages, - vec!["system:system".to_owned(), "history".to_owned()] - ); - assert_eq!( - result.additional_params, - Some(serde_json::json!({"foo":"bar"})) - ); - } - - #[test] - fn feature_policy_strips_unsupported_fields() { - let mut req = CompletionRequest { - model: None, - preamble: None, - chat_history: OneOrMany::one(Message::User { - content: OneOrMany::one(UserContent::text("hello")), - }), - documents: Vec::new(), - tools: vec![ToolDefinition { - name: "ping".to_owned(), - description: "Ping tool".to_owned(), - parameters: serde_json::json!({"type":"object"}), - }], - temperature: None, - max_tokens: None, - tool_choice: Some(ToolChoice::Required), - additional_params: Some(serde_json::json!({"foo":"bar"})), - output_schema: Some( - serde_json::from_value(serde_json::json!({ - "title": "Example", - "type": "object" - })) - .expect("schema should deserialize"), - ), - }; - - CompatibleFeaturePolicy::default() - .without_tools() - .without_tool_choice() - .without_additional_params() - .apply("Example Provider", &mut req); - - assert!(req.tools.is_empty()); - assert!(req.tool_choice.is_none()); - assert!(req.additional_params.is_none()); - assert!(req.output_schema.is_none()); - } -} diff --git a/rig/rig-core/src/providers/openai/completion/stream_compat.rs b/rig/rig-core/src/providers/openai/completion/stream_compat.rs deleted file mode 100644 index ab8cccb22..000000000 --- a/rig/rig-core/src/providers/openai/completion/stream_compat.rs +++ /dev/null @@ -1,218 +0,0 @@ -use std::collections::HashMap; - -use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(crate) enum ToolCallConflictPolicy { - KeepIndex, - EvictDistinctIdAndName, -} - -#[derive(Debug, Clone, Copy)] -pub(crate) struct CompatibleStreamingToolCall<'a> { - pub index: usize, - pub id: Option<&'a str>, - pub name: Option<&'a str>, - pub arguments: Option<&'a str>, -} - -pub(crate) fn apply_compatible_tool_call_deltas<'a, R>( - tool_calls: &mut HashMap, - incoming: impl IntoIterator>, - conflict_policy: ToolCallConflictPolicy, -) -> Vec> -where - R: Clone, -{ - let mut events = Vec::new(); - - for tool_call in incoming { - let index = tool_call.index; - - if conflict_policy == ToolCallConflictPolicy::EvictDistinctIdAndName - && should_evict_existing_tool_call(tool_calls.get(&index), tool_call.id, tool_call.name) - && let Some(evicted) = tool_calls.remove(&index) - { - events.push(RawStreamingChoice::ToolCall( - finalize_completed_streaming_tool_call(evicted), - )); - } - - let existing_tool_call = tool_calls - .entry(index) - .or_insert_with(RawStreamingToolCall::empty); - - if let Some(id) = tool_call.id - && !id.is_empty() - { - existing_tool_call.id = id.to_owned(); - } - - if let Some(name) = tool_call.name - && !name.is_empty() - { - existing_tool_call.name = name.to_owned(); - events.push(RawStreamingChoice::ToolCallDelta { - id: existing_tool_call.id.clone(), - internal_call_id: existing_tool_call.internal_call_id.clone(), - content: ToolCallDeltaContent::Name(name.to_owned()), - }); - } - - if let Some(chunk) = tool_call.arguments - && !chunk.is_empty() - { - append_tool_call_arguments(&mut existing_tool_call.arguments, chunk); - events.push(RawStreamingChoice::ToolCallDelta { - id: existing_tool_call.id.clone(), - internal_call_id: existing_tool_call.internal_call_id.clone(), - content: ToolCallDeltaContent::Delta(chunk.to_owned()), - }); - } - } - - events -} - -pub(crate) fn take_finalized_tool_calls( - tool_calls: &mut HashMap, -) -> Vec> -where - R: Clone, -{ - std::mem::take(tool_calls) - .into_values() - .map(|tool_call| { - RawStreamingChoice::ToolCall(finalize_completed_streaming_tool_call(tool_call)) - }) - .collect() -} - -pub(crate) fn take_tool_calls( - tool_calls: &mut HashMap, -) -> Vec> -where - R: Clone, -{ - std::mem::take(tool_calls) - .into_values() - .map(RawStreamingChoice::ToolCall) - .collect() -} - -pub(crate) fn finalize_completed_streaming_tool_call( - mut tool_call: RawStreamingToolCall, -) -> RawStreamingToolCall { - if tool_call.arguments.is_null() { - tool_call.arguments = serde_json::Value::Object(serde_json::Map::new()); - } - - tool_call -} - -fn should_evict_existing_tool_call( - existing: Option<&RawStreamingToolCall>, - new_id: Option<&str>, - new_name: Option<&str>, -) -> bool { - let Some(existing) = existing else { - return false; - }; - - let Some(new_id) = new_id.filter(|id| !id.is_empty()) else { - return false; - }; - let Some(new_name) = new_name.filter(|name| !name.is_empty()) else { - return false; - }; - - !existing.id.is_empty() - && existing.id != new_id - && !existing.name.is_empty() - && existing.name != new_name -} - -fn append_tool_call_arguments(arguments: &mut serde_json::Value, chunk: &str) { - let current_arguments = match arguments { - serde_json::Value::Null => String::new(), - serde_json::Value::String(value) => value.clone(), - ref value => value.to_string(), - }; - let combined = format!("{current_arguments}{chunk}"); - - if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') { - match serde_json::from_str(&combined) { - Ok(parsed) => *arguments = parsed, - Err(_) => *arguments = serde_json::Value::String(combined), - } - } else { - *arguments = serde_json::Value::String(combined); - } -} - -#[cfg(test)] -mod tests { - use super::{ - CompatibleStreamingToolCall, ToolCallConflictPolicy, apply_compatible_tool_call_deltas, - finalize_completed_streaming_tool_call, take_finalized_tool_calls, - }; - use crate::streaming::{RawStreamingChoice, RawStreamingToolCall}; - use std::collections::HashMap; - - #[test] - fn evicts_distinct_tool_calls_that_reuse_the_same_index() { - let mut tool_calls = HashMap::from([( - 0, - RawStreamingToolCall { - id: "call_1".to_owned(), - internal_call_id: "internal_1".to_owned(), - call_id: None, - name: "weather".to_owned(), - arguments: serde_json::json!({"city":"Paris"}), - signature: None, - additional_params: None, - }, - )]); - - let events = apply_compatible_tool_call_deltas::<()>( - &mut tool_calls, - [CompatibleStreamingToolCall { - index: 0, - id: Some("call_2"), - name: Some("time"), - arguments: Some("{"), - }], - ToolCallConflictPolicy::EvictDistinctIdAndName, - ); - - assert!( - matches!(events.first(), Some(RawStreamingChoice::ToolCall(tool_call)) if tool_call.id == "call_1") - ); - assert_eq!( - tool_calls.get(&0).map(|tool_call| tool_call.id.as_str()), - Some("call_2") - ); - assert_eq!( - tool_calls.get(&0).map(|tool_call| tool_call.name.as_str()), - Some("time") - ); - } - - #[test] - fn finalizes_null_arguments_into_empty_objects() { - let finalized = finalize_completed_streaming_tool_call(RawStreamingToolCall::empty()); - assert_eq!(finalized.arguments, serde_json::json!({})); - } - - #[test] - fn drains_finalized_tool_calls() { - let mut tool_calls = HashMap::from([(0, RawStreamingToolCall::empty())]); - - let events = take_finalized_tool_calls::<()>(&mut tool_calls); - - assert!(tool_calls.is_empty()); - assert!( - matches!(events.as_slice(), [RawStreamingChoice::ToolCall(tool_call)] if tool_call.arguments == serde_json::json!({})) - ); - } -} diff --git a/rig/rig-core/src/providers/openai/completion/streaming.rs b/rig/rig-core/src/providers/openai/completion/streaming.rs index b2fe808da..a727300ca 100644 --- a/rig/rig-core/src/providers/openai/completion/streaming.rs +++ b/rig/rig-core/src/providers/openai/completion/streaming.rs @@ -14,8 +14,8 @@ use crate::http_client::sse::{Event, GenericEventSource}; use crate::json_utils::{self, merge}; use crate::providers::openai::completion::{ CompatibleStreamingToolCall, GenericCompletionModel, OpenAIRequestParams, - ToolCallConflictPolicy, Usage, apply_compatible_tool_call_deltas, take_finalized_tool_calls, - take_tool_calls, + OpenAiChatProviderProfile, ToolCallConflictPolicy, Usage, apply_compatible_tool_call_deltas, + take_finalized_tool_calls, take_tool_calls, }; use crate::streaming::{self, RawStreamingChoice}; @@ -101,19 +101,22 @@ fn map_finish_reason(reason: &FinishReason) -> crate::completion::StopReason { impl GenericCompletionModel where crate::client::Client: HttpClientExt + Clone + 'static, - Ext: crate::client::Provider + Clone + 'static, + Ext: crate::client::Provider + OpenAiChatProviderProfile + Clone + 'static, { pub(crate) async fn stream( &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { - let request = super::CompletionRequest::try_from(OpenAIRequestParams { - model: self.model.clone(), - request: completion_request, - strict_tools: self.strict_tools, - tool_result_array_content: self.tool_result_array_content, - })?; + let request = super::build_openai_completion_request( + OpenAIRequestParams { + model: self.model.clone(), + request: completion_request, + strict_tools: self.strict_tools, + tool_result_array_content: self.tool_result_array_content, + }, + Ext::request_profile(), + )?; let request_messages = serde_json::to_string(&request.messages) .expect("Converting to JSON from a Rust struct shouldn't fail"); let mut request_as_json = serde_json::to_value(request).expect("this should never fail"); @@ -126,7 +129,8 @@ where if enabled!(Level::TRACE) { tracing::trace!( target: "rig::completions", - "OpenAI Chat Completions streaming completion request: {}", + "{} streaming completion request: {}", + Ext::provider_name(), serde_json::to_string_pretty(&request_as_json)? ); } @@ -135,7 +139,7 @@ where let req = self .client - .post("/chat/completions")? + .post(Ext::completions_path())? .body(req_body) .map_err(|e| CompletionError::HttpError(e.into()))?; @@ -144,7 +148,7 @@ where target: "rig::completions", "chat", gen_ai.operation.name = "chat", - gen_ai.provider.name = "openai", + gen_ai.provider.name = Ext::telemetry_provider_name(), gen_ai.request.model = self.model, gen_ai.response.id = tracing::field::Empty, gen_ai.response.model = self.model, @@ -160,7 +164,15 @@ where let client = self.client.clone(); - tracing::Instrument::instrument(send_compatible_streaming_request(client, req), span).await + tracing::Instrument::instrument( + send_compatible_streaming_request_with_policy( + client, + req, + Ext::stream_tool_call_conflict_policy(), + ), + span, + ) + .await } } @@ -168,6 +180,22 @@ pub async fn send_compatible_streaming_request( http_client: T, req: Request>, ) -> Result, CompletionError> +where + T: HttpClientExt + Clone + 'static, +{ + send_compatible_streaming_request_with_policy( + http_client, + req, + ToolCallConflictPolicy::EvictDistinctIdAndName, + ) + .await +} + +pub(crate) async fn send_compatible_streaming_request_with_policy( + http_client: T, + req: Request>, + conflict_policy: ToolCallConflictPolicy, +) -> Result, CompletionError> where T: HttpClientExt + Clone + 'static, { @@ -224,7 +252,7 @@ where name: tool_call.function.name.as_deref(), arguments: tool_call.function.arguments.as_deref(), }), - ToolCallConflictPolicy::EvictDistinctIdAndName, + conflict_policy, ) { yield Ok(event); } diff --git a/rig/rig-core/src/providers/openrouter/completion.rs b/rig/rig-core/src/providers/openrouter/completion.rs index c58e46fbc..f2570ffd2 100644 --- a/rig/rig-core/src/providers/openrouter/completion.rs +++ b/rig/rig-core/src/providers/openrouter/completion.rs @@ -1515,11 +1515,17 @@ impl TryFrom for Vec { } #[derive(Debug, Serialize, Deserialize)] -#[serde(untagged, rename_all = "snake_case")] -pub enum ToolChoice { +#[serde(rename_all = "snake_case")] +pub enum ToolChoiceKeyword { None, Auto, Required, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ToolChoice { + Keyword(ToolChoiceKeyword), Function(Vec), } @@ -1528,9 +1534,9 @@ impl TryFrom for ToolChoice { fn try_from(value: crate::message::ToolChoice) -> Result { let res = match value { - crate::message::ToolChoice::None => Self::None, - crate::message::ToolChoice::Auto => Self::Auto, - crate::message::ToolChoice::Required => Self::Required, + crate::message::ToolChoice::None => Self::Keyword(ToolChoiceKeyword::None), + crate::message::ToolChoice::Auto => Self::Keyword(ToolChoiceKeyword::Auto), + crate::message::ToolChoice::Required => Self::Keyword(ToolChoiceKeyword::Required), crate::message::ToolChoice::Specific { function_names } => { let vec: Vec = function_names .into_iter() @@ -1589,12 +1595,13 @@ impl TryFrom> for OpenrouterCompletionRequest { tools, tool_choice, additional_params, - output_schema: _, } = crate::providers::openai::completion::build_compatible_request_core( model, req, crate::providers::openai::completion::CompatibleChatProfile::new("OpenRouter"), Message::system, + None, + |_| false, |message| Vec::::try_from(message).map_err(CompletionError::from), )?; @@ -1781,6 +1788,47 @@ mod tests { use super::*; use serde_json::json; + struct OpenRouterRequestHarness; + + impl crate::providers::openai::completion::request_conformance::Harness + for OpenRouterRequestHarness + { + fn family_name() -> &'static str { + "openrouter" + } + + fn run( + case: crate::providers::openai::completion::request_conformance::Fixture, + ) -> crate::providers::openai::completion::request_conformance::Outcome + { + crate::providers::openai::completion::request_conformance::serialize_case( + case, + |request| OpenrouterCompletionRequest::try_from(("default-model", request)), + ) + } + + fn assert( + case: crate::providers::openai::completion::request_conformance::Fixture, + actual: crate::providers::openai::completion::request_conformance::Outcome< + serde_json::Value, + >, + ) { + crate::providers::openai::completion::request_conformance::assert_compatible_chat_case( + crate::providers::openai::completion::request_conformance::CompatibleChatExpectation::new( + crate::providers::openai::completion::CompatibleChatProfile::new("OpenRouter"), + ) + .preserves_reasoning_assistant_history(), + "default-model", + case, + actual, + ); + } + } + + crate::providers::openai::completion::request_conformance::provider_request_conformance_tests!( + OpenRouterRequestHarness + ); + #[test] fn test_openrouter_request_uses_request_model_override() { let request = CompletionRequest { diff --git a/rig/rig-core/src/providers/perplexity.rs b/rig/rig-core/src/providers/perplexity.rs index a34d4dc23..baa81169f 100644 --- a/rig/rig-core/src/providers/perplexity.rs +++ b/rig/rig-core/src/providers/perplexity.rs @@ -239,17 +239,18 @@ impl TryFrom<(&str, CompletionRequest)> for PerplexityCompletionRequest { tools: _, tool_choice: _, additional_params, - output_schema: _, } = crate::providers::openai::completion::build_compatible_request_core( model, req, crate::providers::openai::completion::CompatibleChatProfile::new("Perplexity") - .without_tools() - .without_tool_choice(), + .unsupported_tools() + .unsupported_tool_choice(), |preamble| Message { role: Role::System, content: preamble.to_owned(), }, + None, + |_| false, |message| Ok(vec![message.try_into()?]), )?; diff --git a/rig/rig-core/src/providers/together/completion.rs b/rig/rig-core/src/providers/together/completion.rs index bbe30363c..de3a899c4 100644 --- a/rig/rig-core/src/providers/together/completion.rs +++ b/rig/rig-core/src/providers/together/completion.rs @@ -155,13 +155,15 @@ impl TryFrom<(&str, CompletionRequest)> for TogetherAICompletionRequest { tools, tool_choice, additional_params, - output_schema: _, } = crate::providers::openai::completion::build_compatible_request_core( model, req, crate::providers::openai::completion::CompatibleChatProfile::new("TogetherAI") - .require_messages(), + .require_messages() + .reject_required_tool_choice(), openai::Message::system, + None, + |_| false, |message| Vec::::try_from(message).map_err(CompletionError::from), )?; @@ -303,10 +305,16 @@ where } #[derive(Debug, Serialize, Deserialize)] -#[serde(untagged, rename_all = "snake_case")] -pub enum ToolChoice { +#[serde(rename_all = "snake_case")] +pub enum ToolChoiceKeyword { None, Auto, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ToolChoice { + Keyword(ToolChoiceKeyword), Function(Vec), } @@ -315,8 +323,8 @@ impl TryFrom for ToolChoice { fn try_from(value: crate::message::ToolChoice) -> Result { let res = match value { - crate::message::ToolChoice::None => Self::None, - crate::message::ToolChoice::Auto => Self::Auto, + crate::message::ToolChoice::None => Self::Keyword(ToolChoiceKeyword::None), + crate::message::ToolChoice::Auto => Self::Keyword(ToolChoiceKeyword::Auto), crate::message::ToolChoice::Specific { function_names } => { let vec: Vec = function_names .into_iter() @@ -345,8 +353,43 @@ pub enum ToolChoiceFunctionKind { #[cfg(test)] mod tests { use super::*; + use crate::providers::openai::completion::request_conformance; use crate::{OneOrMany, message}; + struct TogetherRequestHarness; + + impl request_conformance::Harness for TogetherRequestHarness { + fn family_name() -> &'static str { + "together-ai" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + TogetherAICompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + crate::providers::openai::completion::CompatibleChatProfile::new("TogetherAI") + .require_messages() + .reject_required_tool_choice(), + ), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(TogetherRequestHarness); + #[test] fn together_request_conversion_errors_when_all_messages_are_filtered() { let request = CompletionRequest { From 196a215eaa59475a54a8fe1fa9c65daa2488747e Mon Sep 17 00:00:00 2001 From: stephen Date: Fri, 17 Apr 2026 21:30:28 -0700 Subject: [PATCH 9/9] s --- rig/rig-core/src/providers/deepseek.rs | 7 +- rig/rig-core/src/providers/galadriel.rs | 115 +++++++++++- rig/rig-core/src/providers/mira.rs | 170 +++++++++++------- .../src/providers/openai/completion/family.rs | 110 +++++++++--- .../src/providers/openai/completion/mod.rs | 8 +- .../src/providers/openrouter/completion.rs | 52 +----- rig/rig-core/src/providers/perplexity.rs | 92 +++++++++- .../src/providers/together/completion.rs | 51 +----- 8 files changed, 416 insertions(+), 189 deletions(-) diff --git a/rig/rig-core/src/providers/deepseek.rs b/rig/rig-core/src/providers/deepseek.rs index 04ed03621..8f3827309 100644 --- a/rig/rig-core/src/providers/deepseek.rs +++ b/rig/rig-core/src/providers/deepseek.rs @@ -460,7 +460,7 @@ pub(super) struct DeepseekCompletionRequest { #[serde(skip_serializing_if = "Vec::is_empty")] tools: Vec, #[serde(skip_serializing_if = "Option::is_none")] - tool_choice: Option, + tool_choice: Option, #[serde(flatten, skip_serializing_if = "Option::is_none")] pub additional_params: Option, } @@ -487,9 +487,8 @@ impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest { |message| Vec::::try_from(message).map_err(CompletionError::from), )?; - let tool_choice = tool_choice - .map(crate::providers::openrouter::ToolChoice::try_from) - .transpose()?; + let tool_choice = + tool_choice.map(crate::providers::openai::completion::CompatibleToolChoice::from); Ok(Self { model, diff --git a/rig/rig-core/src/providers/galadriel.rs b/rig/rig-core/src/providers/galadriel.rs index 25f594c45..b68784d23 100644 --- a/rig/rig-core/src/providers/galadriel.rs +++ b/rig/rig-core/src/providers/galadriel.rs @@ -449,6 +449,84 @@ pub(super) struct GaladrielCompletionRequest { pub additional_params: Option, } +fn galadriel_request_messages(message: message::Message) -> Result, CompletionError> { + match message { + message::Message::System { content } => Ok(vec![Message { + role: "system".to_string(), + content: Some(content), + tool_calls: vec![], + }]), + message::Message::User { content } => { + let text = content + .into_iter() + .filter_map(|item| match item { + message::UserContent::Text(text) => Some(text.text), + message::UserContent::Document(document) => match document.data { + crate::message::DocumentSourceKind::Base64(content) + | crate::message::DocumentSourceKind::String(content) => Some(content), + _ => None, + }, + _ => None, + }) + .collect::>() + .join("\n"); + + if text.is_empty() { + Ok(vec![]) + } else { + Ok(vec![Message { + role: "user".to_string(), + content: Some(text), + tool_calls: vec![], + }]) + } + } + message::Message::Assistant { content, .. } => { + let mut text_content: Option = None; + let mut tool_calls = vec![]; + + for item in content { + match item { + message::AssistantContent::Text(text) => { + let text = text.text; + text_content = Some( + text_content + .map(|mut existing| { + existing.push('\n'); + existing.push_str(&text); + existing + }) + .unwrap_or(text), + ); + } + message::AssistantContent::ToolCall(tool_call) => { + tool_calls.push(tool_call.into()); + } + message::AssistantContent::Reasoning(_) => {} + message::AssistantContent::Image(_) => { + return Err(CompletionError::RequestError( + MessageError::ConversionError( + "Galadriel currently doesn't support images.".into(), + ) + .into(), + )); + } + } + } + + if text_content.is_none() && tool_calls.is_empty() { + Ok(vec![]) + } else { + Ok(vec![Message { + role: "assistant".to_string(), + content: text_content, + tool_calls, + }]) + } + } + } +} + impl TryFrom<(&str, CompletionRequest)> for GaladrielCompletionRequest { type Error = CompletionError; @@ -468,7 +546,7 @@ impl TryFrom<(&str, CompletionRequest)> for GaladrielCompletionRequest { Message::system, None, |_| false, - |message| Ok(vec![message.try_into()?]), + galadriel_request_messages, )?; let tool_choice = tool_choice @@ -659,6 +737,41 @@ where } #[cfg(test)] mod tests { + use super::GaladrielCompletionRequest; + use crate::providers::openai::completion::{CompatibleChatProfile, request_conformance}; + + struct GaladrielRequestHarness; + + impl request_conformance::Harness for GaladrielRequestHarness { + fn family_name() -> &'static str { + "galadriel" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + GaladrielCompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new(CompatibleChatProfile::new( + "Galadriel", + )), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(GaladrielRequestHarness); + #[test] fn test_client_initialization() { let _client = diff --git a/rig/rig-core/src/providers/mira.rs b/rig/rig-core/src/providers/mira.rs index a82c3c054..8fd9052dd 100644 --- a/rig/rig-core/src/providers/mira.rs +++ b/rig/rig-core/src/providers/mira.rs @@ -221,89 +221,90 @@ pub(super) struct MiraCompletionRequest { pub stream: bool, } -impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest { - type Error = CompletionError; - - fn try_from((model, mut req): (&str, CompletionRequest)) -> Result { - crate::providers::openai::completion::CompatibleFeaturePolicy::default() - .with_tools_policy(crate::providers::openai::completion::ToolsPolicy::Unsupported) - .with_tool_choice_policy( - crate::providers::openai::completion::ToolChoicePolicy::Unsupported, - ) - .with_additional_params_policy( - crate::providers::openai::completion::AdditionalParamsPolicy::Unsupported, - ) - .apply("Mira AI", &mut req)?; - - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - let mut messages = Vec::new(); - - if let Some(content) = &req.preamble { - messages.push(RawMessage { - role: "user".to_string(), - content: content.to_string(), - }); - } - - if let Some(Message::User { content }) = req.normalized_documents() { +fn mira_request_messages(message: Message) -> Result, CompletionError> { + match message { + Message::System { content } => Ok(vec![RawMessage { + role: "system".to_string(), + content, + }]), + Message::User { content } => { let text = content .into_iter() - .filter_map(|doc| match doc { + .filter_map(|item| match item { + UserContent::Text(text) => Some(text.text), UserContent::Document(Document { data: DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data), .. }) => Some(data), - UserContent::Text(text) => Some(text.text), - - // This should always be `Document` _ => None, }) .collect::>() .join("\n"); - messages.push(RawMessage { - role: "user".to_string(), - content: text, - }); + if text.is_empty() { + Ok(vec![]) + } else { + Ok(vec![RawMessage { + role: "user".to_string(), + content: text, + }]) + } } + Message::Assistant { content, .. } => { + let text = content + .into_iter() + .filter_map(|item| match item { + AssistantContent::Text(text) => Some(text.text), + _ => None, + }) + .collect::>() + .join("\n"); - for msg in req.chat_history { - let (role, content) = match msg { - Message::System { content } => ("system", content), - Message::User { content } => { - let text = content - .iter() - .map(|c| match c { - UserContent::Text(text) => &text.text, - _ => "", - }) - .collect::>() - .join("\n"); - ("user", text) - } - Message::Assistant { content, .. } => { - let text = content - .iter() - .map(|c| match c { - AssistantContent::Text(text) => &text.text, - _ => "", - }) - .collect::>() - .join("\n"); - ("assistant", text) - } - }; - messages.push(RawMessage { - role: role.to_string(), - content, - }); + if text.is_empty() { + Ok(vec![]) + } else { + Ok(vec![RawMessage { + role: "assistant".to_string(), + content: text, + }]) + } } + } +} + +impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest { + type Error = CompletionError; + + fn try_from((model, req): (&str, CompletionRequest)) -> Result { + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens, + tools: _, + tool_choice: _, + additional_params: _, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("Mira AI") + .unsupported_tools() + .unsupported_tool_choice() + .unsupported_additional_params(), + |content| RawMessage { + role: "user".to_string(), + content: content.to_string(), + }, + None, + |_| false, + mira_request_messages, + )?; Ok(Self { - model: model.to_string(), + model, messages, - temperature: req.temperature, - max_tokens: req.max_tokens, + temperature, + max_tokens, stream: false, }) } @@ -683,8 +684,45 @@ impl TryFrom for Message { mod tests { use super::*; use crate::message::UserContent; + use crate::providers::openai::completion::{CompatibleChatProfile, request_conformance}; use serde_json::json; + struct MiraRequestHarness; + + impl request_conformance::Harness for MiraRequestHarness { + fn family_name() -> &'static str { + "mira" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + MiraCompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + CompatibleChatProfile::new("Mira AI") + .unsupported_tools() + .unsupported_tool_choice() + .unsupported_additional_params(), + ) + .preamble_as_user(), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(MiraRequestHarness); + #[test] fn test_deserialize_message() { // Test string content format diff --git a/rig/rig-core/src/providers/openai/completion/family.rs b/rig/rig-core/src/providers/openai/completion/family.rs index f76d0488c..676a3561e 100644 --- a/rig/rig-core/src/providers/openai/completion/family.rs +++ b/rig/rig-core/src/providers/openai/completion/family.rs @@ -4,6 +4,7 @@ use crate::completion::{self, CompletionError, CompletionRequest as CoreCompleti use crate::json_utils; use crate::message; use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent}; +use serde::{Deserialize, Serialize}; pub(crate) fn first_choice(choices: &[T]) -> Result<&T, CompletionError> { choices @@ -49,6 +50,45 @@ where } } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum CompatibleToolChoiceKeyword { + None, + Auto, + Required, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "type", content = "function")] +pub enum CompatibleToolChoiceFunctionKind { + Function { name: String }, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(untagged)] +pub enum CompatibleToolChoice { + Keyword(CompatibleToolChoiceKeyword), + Function(Vec), +} + +impl From for CompatibleToolChoice { + fn from(value: crate::message::ToolChoice) -> Self { + match value { + crate::message::ToolChoice::None => Self::Keyword(CompatibleToolChoiceKeyword::None), + crate::message::ToolChoice::Auto => Self::Keyword(CompatibleToolChoiceKeyword::Auto), + crate::message::ToolChoice::Required => { + Self::Keyword(CompatibleToolChoiceKeyword::Required) + } + crate::message::ToolChoice::Specific { function_names } => Self::Function( + function_names + .into_iter() + .map(|name| CompatibleToolChoiceFunctionKind::Function { name }) + .collect(), + ), + } + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum ToolsPolicy { Supported, @@ -263,6 +303,13 @@ impl CompatibleChatProfile { .with_tool_choice_policy(ToolChoicePolicy::CoerceRequiredToAuto { steering_message }); self } + + pub(crate) const fn unsupported_additional_params(mut self) -> Self { + self.feature_policy = self + .feature_policy + .with_additional_params_policy(AdditionalParamsPolicy::Unsupported); + self + } } /// Provider-agnostic core request fields shared by OpenAI-compatible chat families. @@ -910,29 +957,47 @@ pub(crate) mod request_conformance { .expect("schema should deserialize") } + #[derive(Debug, Clone, Copy)] + enum EmptyMessageBehavior { + DropEmpty, + PreserveReasoningAssistantHistory, + RequestError, + } + #[derive(Debug, Clone, Copy)] pub(crate) struct CompatibleChatExpectation { profile: CompatibleChatProfile, - preserves_reasoning_assistant_history: bool, - includes_document_messages: bool, + empty_message_behavior: EmptyMessageBehavior, + preamble_role: &'static str, + document_role: Option<&'static str>, } impl CompatibleChatExpectation { pub(crate) const fn new(profile: CompatibleChatProfile) -> Self { Self { profile, - preserves_reasoning_assistant_history: false, - includes_document_messages: true, + empty_message_behavior: if profile.require_messages { + EmptyMessageBehavior::RequestError + } else { + EmptyMessageBehavior::DropEmpty + }, + preamble_role: "system", + document_role: Some("user"), } } pub(crate) const fn preserves_reasoning_assistant_history(mut self) -> Self { - self.preserves_reasoning_assistant_history = true; + self.empty_message_behavior = EmptyMessageBehavior::PreserveReasoningAssistantHistory; + self + } + + pub(crate) const fn preamble_as_user(mut self) -> Self { + self.preamble_role = "user"; self } pub(crate) const fn omits_document_messages(mut self) -> Self { - self.includes_document_messages = false; + self.document_role = None; self } } @@ -951,32 +1016,33 @@ pub(crate) mod request_conformance { } Fixture::PreambleDocumentHistoryOrdering => { let request = expect_supported(case, actual); - let mut expected = vec!["system:system".to_owned()]; - if expectation.includes_document_messages { - expected.push(format!("user:{}", expected_document_text())); + let mut expected = vec![format!("{}:system", expectation.preamble_role)]; + if let Some(document_role) = expectation.document_role { + expected.push(format!("{document_role}:{}", expected_document_text())); } expected.push("user:hello".to_owned()); - assert_eq!(message_summaries(&request), expected,); + assert_eq!(message_summaries(&request), expected); } - Fixture::EmptyMessageRejection => { - if profile.require_messages { + Fixture::EmptyMessageRejection => match expectation.empty_message_behavior { + EmptyMessageBehavior::RequestError => { assert!( matches!(actual, Outcome::RequestError), "expected request rejection for {case:?}, got {actual:?}" ); - } else { + } + EmptyMessageBehavior::PreserveReasoningAssistantHistory => { let request = expect_supported(case, actual); let summaries = message_summaries(&request); - if expectation.preserves_reasoning_assistant_history { - assert_eq!(summaries, vec!["assistant:hidden".to_owned()]); - } else { - assert!( - summaries.is_empty(), - "expected no provider-compatible messages: {request:?}" - ); - } + assert_eq!(summaries, vec!["assistant:hidden".to_owned()]); } - } + EmptyMessageBehavior::DropEmpty => { + let request = expect_supported(case, actual); + assert!( + message_summaries(&request).is_empty(), + "expected no provider-compatible messages: {request:?}" + ); + } + }, Fixture::UnsupportedFieldStripping => { let request = expect_supported(case, actual); assert_eq!(model(&request), Some(default_model)); diff --git a/rig/rig-core/src/providers/openai/completion/mod.rs b/rig/rig-core/src/providers/openai/completion/mod.rs index ee10a93bd..8f444c4cf 100644 --- a/rig/rig-core/src/providers/openai/completion/mod.rs +++ b/rig/rig-core/src/providers/openai/completion/mod.rs @@ -28,12 +28,14 @@ pub mod streaming; #[cfg(test)] pub(crate) use family::request_conformance; pub(crate) use family::{ - AdditionalParamsPolicy, CompatibleChatProfile, CompatibleFeaturePolicy, CompatibleRequestCore, - CompatibleStreamingToolCall, OpenAiChatProviderProfile, ToolCallConflictPolicy, - ToolChoicePolicy, ToolsPolicy, apply_compatible_tool_call_deltas, + CompatibleChatProfile, CompatibleRequestCore, CompatibleStreamingToolCall, + OpenAiChatProviderProfile, ToolCallConflictPolicy, apply_compatible_tool_call_deltas, build_compatible_request_core, build_completion_response, first_choice, map_finish_reason, non_empty_text, take_finalized_tool_calls, take_tool_calls, }; +pub use family::{ + CompatibleToolChoice, CompatibleToolChoiceFunctionKind, CompatibleToolChoiceKeyword, +}; /// Serializes user content as a plain string when there's a single text item, /// otherwise as an array of content parts. diff --git a/rig/rig-core/src/providers/openrouter/completion.rs b/rig/rig-core/src/providers/openrouter/completion.rs index f2570ffd2..0802ece60 100644 --- a/rig/rig-core/src/providers/openrouter/completion.rs +++ b/rig/rig-core/src/providers/openrouter/completion.rs @@ -1514,48 +1514,10 @@ impl TryFrom for Vec { } } -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum ToolChoiceKeyword { - None, - Auto, - Required, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(untagged)] -pub enum ToolChoice { - Keyword(ToolChoiceKeyword), - Function(Vec), -} - -impl TryFrom for ToolChoice { - type Error = CompletionError; - - fn try_from(value: crate::message::ToolChoice) -> Result { - let res = match value { - crate::message::ToolChoice::None => Self::Keyword(ToolChoiceKeyword::None), - crate::message::ToolChoice::Auto => Self::Keyword(ToolChoiceKeyword::Auto), - crate::message::ToolChoice::Required => Self::Keyword(ToolChoiceKeyword::Required), - crate::message::ToolChoice::Specific { function_names } => { - let vec: Vec = function_names - .into_iter() - .map(|name| ToolChoiceFunctionKind::Function { name }) - .collect(); - - Self::Function(vec) - } - }; - - Ok(res) - } -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(tag = "type", content = "function")] -pub enum ToolChoiceFunctionKind { - Function { name: String }, -} +pub type ToolChoice = crate::providers::openai::completion::CompatibleToolChoice; +pub type ToolChoiceFunctionKind = + crate::providers::openai::completion::CompatibleToolChoiceFunctionKind; +pub type ToolChoiceKeyword = crate::providers::openai::completion::CompatibleToolChoiceKeyword; #[derive(Debug, Serialize, Deserialize)] pub(super) struct OpenrouterCompletionRequest { @@ -1566,7 +1528,7 @@ pub(super) struct OpenrouterCompletionRequest { #[serde(skip_serializing_if = "Vec::is_empty")] tools: Vec, #[serde(skip_serializing_if = "Option::is_none")] - tool_choice: Option, + tool_choice: Option, #[serde(flatten, skip_serializing_if = "Option::is_none")] pub additional_params: Option, } @@ -1605,9 +1567,7 @@ impl TryFrom> for OpenrouterCompletionRequest { |message| Vec::::try_from(message).map_err(CompletionError::from), )?; - let tool_choice = tool_choice - .map(crate::providers::openai::completion::ToolChoice::try_from) - .transpose()?; + let tool_choice = tool_choice.map(ToolChoice::from); let tools: Vec = tools .into_iter() diff --git a/rig/rig-core/src/providers/perplexity.rs b/rig/rig-core/src/providers/perplexity.rs index baa81169f..5cb942a67 100644 --- a/rig/rig-core/src/providers/perplexity.rs +++ b/rig/rig-core/src/providers/perplexity.rs @@ -227,6 +227,58 @@ pub(super) struct PerplexityCompletionRequest { pub stream: bool, } +fn perplexity_request_messages(message: message::Message) -> Result, MessageError> { + match message { + message::Message::System { content } => Ok(vec![Message { + role: Role::System, + content, + }]), + message::Message::User { content } => { + let collapsed_content = content + .into_iter() + .filter_map(|content| match content { + message::UserContent::Text(message::Text { text }) => Some(text), + message::UserContent::Document(document) => match document.data { + crate::message::DocumentSourceKind::Base64(content) + | crate::message::DocumentSourceKind::String(content) => Some(content), + _ => None, + }, + _ => None, + }) + .collect::>() + .join("\n"); + + if collapsed_content.is_empty() { + Ok(vec![]) + } else { + Ok(vec![Message { + role: Role::User, + content: collapsed_content, + }]) + } + } + message::Message::Assistant { content, .. } => { + let collapsed_content = content + .into_iter() + .filter_map(|content| match content { + message::AssistantContent::Text(message::Text { text }) => Some(text), + _ => None, + }) + .collect::>() + .join("\n"); + + if collapsed_content.is_empty() { + Ok(vec![]) + } else { + Ok(vec![Message { + role: Role::Assistant, + content: collapsed_content, + }]) + } + } + } +} + impl TryFrom<(&str, CompletionRequest)> for PerplexityCompletionRequest { type Error = CompletionError; @@ -251,7 +303,10 @@ impl TryFrom<(&str, CompletionRequest)> for PerplexityCompletionRequest { }, None, |_| false, - |message| Ok(vec![message.try_into()?]), + |message| { + perplexity_request_messages(message) + .map_err(|err| CompletionError::RequestError(err.into())) + }, )?; Ok(Self { @@ -488,6 +543,41 @@ where #[cfg(test)] mod tests { use super::*; + use crate::providers::openai::completion::{CompatibleChatProfile, request_conformance}; + + struct PerplexityRequestHarness; + + impl request_conformance::Harness for PerplexityRequestHarness { + fn family_name() -> &'static str { + "perplexity" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + PerplexityCompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + CompatibleChatProfile::new("Perplexity") + .unsupported_tools() + .unsupported_tool_choice(), + ), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(PerplexityRequestHarness); #[test] fn test_deserialize_message() { diff --git a/rig/rig-core/src/providers/together/completion.rs b/rig/rig-core/src/providers/together/completion.rs index de3a899c4..9f688edbd 100644 --- a/rig/rig-core/src/providers/together/completion.rs +++ b/rig/rig-core/src/providers/together/completion.rs @@ -167,7 +167,7 @@ impl TryFrom<(&str, CompletionRequest)> for TogetherAICompletionRequest { |message| Vec::::try_from(message).map_err(CompletionError::from), )?; - let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?; + let tool_choice = tool_choice.map(ToolChoice::from); Ok(Self { model, @@ -304,51 +304,10 @@ where } } -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum ToolChoiceKeyword { - None, - Auto, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(untagged)] -pub enum ToolChoice { - Keyword(ToolChoiceKeyword), - Function(Vec), -} - -impl TryFrom for ToolChoice { - type Error = CompletionError; - - fn try_from(value: crate::message::ToolChoice) -> Result { - let res = match value { - crate::message::ToolChoice::None => Self::Keyword(ToolChoiceKeyword::None), - crate::message::ToolChoice::Auto => Self::Keyword(ToolChoiceKeyword::Auto), - crate::message::ToolChoice::Specific { function_names } => { - let vec: Vec = function_names - .into_iter() - .map(|name| ToolChoiceFunctionKind::Function { name }) - .collect(); - - Self::Function(vec) - } - choice => { - return Err(CompletionError::ProviderError(format!( - "Unsupported tool choice type: {choice:?}" - ))); - } - }; - - Ok(res) - } -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(tag = "type", content = "function")] -pub enum ToolChoiceFunctionKind { - Function { name: String }, -} +pub type ToolChoice = crate::providers::openai::completion::CompatibleToolChoice; +pub type ToolChoiceFunctionKind = + crate::providers::openai::completion::CompatibleToolChoiceFunctionKind; +pub type ToolChoiceKeyword = crate::providers::openai::completion::CompatibleToolChoiceKeyword; #[cfg(test)] mod tests {