diff --git a/rig-integrations/rig-bedrock/src/types/assistant_content.rs b/rig-integrations/rig-bedrock/src/types/assistant_content.rs index 964d6524d..9cc473cab 100644 --- a/rig-integrations/rig-bedrock/src/types/assistant_content.rs +++ b/rig-integrations/rig-bedrock/src/types/assistant_content.rs @@ -1,7 +1,6 @@ use aws_sdk_bedrockruntime::types as aws_bedrock; use rig::{ - OneOrMany, completion::CompletionError, message::{AssistantContent, Text, ToolCall, ToolFunction}, }; @@ -114,21 +113,25 @@ impl TryFrom for completion::CompletionResponse None, }) { return Ok(completion::CompletionResponse { - choice: OneOrMany::one(AssistantContent::ToolCall(ToolCall::new( - tool_use.id, - ToolFunction::new(tool_use.function.name, tool_use.function.arguments), - ))), + choice: completion::AssistantChoice::one(AssistantContent::ToolCall( + ToolCall::new( + tool_use.id, + ToolFunction::new(tool_use.function.name, tool_use.function.arguments), + ), + )), usage, raw_response: value, message_id: None, + stop_reason: None, }); } Ok(completion::CompletionResponse { - choice, + choice: choice.into(), usage, raw_response: value, message_id: None, + stop_reason: None, }) } } diff --git a/rig-integrations/rig-gemini-grpc/src/completion.rs b/rig-integrations/rig-gemini-grpc/src/completion.rs index 7a5b08563..ad4448114 100644 --- a/rig-integrations/rig-gemini-grpc/src/completion.rs +++ b/rig-integrations/rig-gemini-grpc/src/completion.rs @@ -450,11 +450,7 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse for CompletionResponse for CompletionResponse { + Some(AssistantContent::ToolCall(ToolCall { id, function, .. })) => { assert_eq!(id, "add"); assert_eq!(function.name, "add"); assert_eq!(function.arguments, args); diff --git a/rig/rig-core/examples/manual_tool_calls.rs b/rig/rig-core/examples/manual_tool_calls.rs index 7dc09fbe3..a25b70c94 100644 --- a/rig/rig-core/examples/manual_tool_calls.rs +++ b/rig/rig-core/examples/manual_tool_calls.rs @@ -10,9 +10,8 @@ //! 5. repeats until the model returns a final text answer. 
 use anyhow::{Result, bail};
-use rig::OneOrMany;
 use rig::client::{CompletionClient, ProviderClient};
-use rig::completion::{Completion, ToolDefinition};
+use rig::completion::{AssistantChoice, Completion, ToolDefinition};
 use rig::message::{AssistantContent, Message, ToolCall, ToolChoice};
 use rig::providers::openai;
 use rig::tool::{Tool, ToolSet};
@@ -87,7 +86,7 @@ impl Tool for Subtract {
     }
 }
 
-fn collect_tool_calls(choice: &OneOrMany<AssistantContent>) -> Vec<ToolCall> {
+fn collect_tool_calls(choice: &AssistantChoice) -> Vec<ToolCall> {
     choice
         .iter()
         .filter_map(|content| match content {
@@ -97,7 +96,7 @@ fn collect_tool_calls(choice: &OneOrMany<AssistantContent>) -> Vec<ToolCall> {
         .collect()
 }
 
-fn extract_text(choice: &OneOrMany<AssistantContent>) -> String {
+fn extract_text(choice: &AssistantChoice) -> String {
     choice
         .iter()
         .filter_map(|content| match content {
@@ -151,10 +150,12 @@ async fn main() -> Result<()> {
         let tool_calls = collect_tool_calls(&response.choice);
 
         history.push(current_prompt.clone());
-        history.push(Message::Assistant {
-            id: response.message_id.clone(),
-            content: response.choice.clone(),
-        });
+        if let Ok(content) = response.choice.to_one_or_many() {
+            history.push(Message::Assistant {
+                id: response.message_id.clone(),
+                content,
+            });
+        }
 
         if tool_calls.is_empty() {
             let final_text = extract_text(&response.choice);
diff --git a/rig/rig-core/src/agent/prompt_request/mod.rs b/rig/rig-core/src/agent/prompt_request/mod.rs
index 4951d5b59..b5e08d198 100644
--- a/rig/rig-core/src/agent/prompt_request/mod.rs
+++ b/rig/rig-core/src/agent/prompt_request/mod.rs
@@ -1,5 +1,6 @@
 pub mod hooks;
 pub mod streaming;
+mod turns;
 
 use super::{
     Agent,
@@ -26,6 +27,7 @@ use std::{
 };
 use tracing::info_span;
 use tracing::{Instrument, span::Id};
+use turns::{AssistantTextAccumulator, AssistantTurnSummary};
 
 pub trait PromptType {}
 pub struct Standard;
@@ -339,6 +341,7 @@ where
         let mut current_max_turns = 0;
         let mut usage = Usage::new();
+        let mut assistant_text_accumulator = AssistantTextAccumulator::default();
         let current_span_id: AtomicU64 = AtomicU64::new(0);
 
         // We need to do at least 2 loops for 1 roundtrip (user expects normal message)
@@ -447,34 +450,30 @@ where
                 ));
             }
 
-            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
+            let turn_summary = AssistantTurnSummary::from_response(&resp);
+            let turn_text = turn_summary.visible_text("\n");
+            assistant_text_accumulator.observe(&turn_text);
+
+            let (tool_calls, _texts): (Vec<_>, Vec<_>) = resp
                 .choice
                 .iter()
                 .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));
 
-            new_messages.push(Message::Assistant {
-                id: resp.message_id.clone(),
-                content: resp.choice.clone(),
-            });
+            if let Ok(content) = resp.choice.to_one_or_many() {
+                new_messages.push(Message::Assistant {
+                    id: resp.message_id.clone(),
+                    content,
+                });
+            }
 
             if tool_calls.is_empty() {
-                let merged_texts = texts
-                    .into_iter()
-                    .filter_map(|content| {
-                        if let AssistantContent::Text(text) = content {
-                            Some(text.text.clone())
-                        } else {
-                            None
-                        }
-                    })
-                    .collect::<Vec<_>>()
-                    .join("\n");
+                let final_output = assistant_text_accumulator.final_output(&turn_text);
 
                 if self.max_turns > 1 {
                     tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
                 }
 
-                agent_span.record("gen_ai.completion", &merged_texts);
+                agent_span.record("gen_ai.completion", &final_output);
                 agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                 agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);
                 agent_span.record(
@@ -486,7 +485,7 @@ where
                     usage.cache_creation_input_tokens,
                 );
 
-                return Ok(PromptResponse::new(merged_texts,
usage).with_messages(new_messages)); + return Ok(PromptResponse::new(final_output, usage).with_messages(new_messages)); } let hook = self.hook.clone(); diff --git a/rig/rig-core/src/agent/prompt_request/streaming.rs b/rig/rig-core/src/agent/prompt_request/streaming.rs index adcf766a9..641a20787 100644 --- a/rig/rig-core/src/agent/prompt_request/streaming.rs +++ b/rig/rig-core/src/agent/prompt_request/streaming.rs @@ -16,6 +16,7 @@ use tracing::info_span; use tracing_futures::Instrument; use super::ToolCallHookAction; +use super::turns::{AssistantTextAccumulator, AssistantTurnSummary}; use crate::{ agent::Agent, completion::{CompletionError, CompletionModel, PromptError}, @@ -170,16 +171,6 @@ fn tool_result_to_user_message( } } -fn assistant_text_from_choice(choice: &OneOrMany) -> String { - choice - .iter() - .filter_map(|content| match content { - AssistantContent::Text(text) => Some(text.text.as_str()), - _ => None, - }) - .collect() -} - #[derive(Debug, thiserror::Error)] pub enum StreamingError { #[error("CompletionError: {0}")] @@ -398,6 +389,7 @@ where let output_schema = self.output_schema; let mut aggregated_usage = crate::completion::Usage::new(); + let mut assistant_text_accumulator = AssistantTextAccumulator::default(); // NOTE: We use .instrument(agent_span) instead of span.enter() to avoid // span context leaking to other concurrent tasks. Using span.enter() inside @@ -659,7 +651,9 @@ where accumulated_reasoning.push(assembled); } - let turn_text_response = assistant_text_from_choice(&stream.choice); + let turn_summary = AssistantTurnSummary::from_stream_response(&stream); + let turn_text_response = turn_summary.visible_text(""); + assistant_text_accumulator.observe(&turn_text_response); tracing::Span::current().record("gen_ai.completion", &turn_text_response); // Add text, reasoning, and tool calls to chat history. 
@@ -692,9 +686,18 @@ where } if !saw_tool_call_this_turn { + let final_response_text = + assistant_text_accumulator.final_output(&turn_text_response); + // Add user message and assistant response to history before finishing if !turn_text_response.is_empty() { new_messages.push(Message::assistant(&turn_text_response)); + } else if !final_response_text.is_empty() { + tracing::info!( + agent_name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME), + message_id = ?stream.message_id, + "Streaming turn completed without assistant text; using prior turn text for final response" + ); } else { tracing::warn!( agent_name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME), @@ -715,7 +718,7 @@ where None }; yield Ok(MultiTurnStreamItem::final_response_with_history( - &turn_text_response, + &final_response_text, aggregated_usage, final_messages, )); @@ -1224,6 +1227,86 @@ mod tests { assert_eq!(final_response_text.as_deref(), Some("")); } + #[derive(Clone, Default)] + struct StreamingEmptyTerminalTurnFallbackModel { + turn_counter: Arc, + } + + #[allow(refining_impl_trait)] + impl CompletionModel for StreamingEmptyTerminalTurnFallbackModel { + type Response = (); + type StreamingResponse = MockStreamingResponse; + type Client = (); + + fn make(_: &Self::Client, _: impl Into) -> Self { + Self::default() + } + + async fn completion( + &self, + _request: CompletionRequest, + ) -> Result, CompletionError> { + Err(CompletionError::ProviderError( + "completion is unused in this streaming test".to_string(), + )) + } + + async fn stream( + &self, + _request: CompletionRequest, + ) -> Result, CompletionError> { + let turn = self.turn_counter.fetch_add(1, Ordering::SeqCst); + let stream = async_stream::stream! { + if turn == 0 { + yield Ok(RawStreamingChoice::Message("The answer is 5".to_string())); + yield Ok(RawStreamingChoice::ToolCall( + RawStreamingToolCall::new( + "tool_call_2".to_string(), + "calculator".to_string(), + serde_json::json!({"op": "add", "a": 2, "b": 3}), + ), + )); + yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(2))); + } else { + yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(1))); + } + }; + + let pinned_stream: crate::streaming::StreamingResult = + Box::pin(stream); + Ok(StreamingCompletionResponse::stream(pinned_stream)) + } + } + + #[tokio::test] + async fn final_response_falls_back_to_prior_assistant_text_when_terminal_turn_is_empty() { + let model = StreamingEmptyTerminalTurnFallbackModel::default(); + let turn_counter = model.turn_counter.clone(); + let agent = AgentBuilder::new(model).build(); + + let mut stream = agent.stream_prompt("What is 2 + 3?").multi_turn(3).await; + let mut streamed_text = String::new(); + let mut final_response_text = None; + + while let Some(item) = stream.next().await { + match item { + Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text( + text, + ))) => streamed_text.push_str(&text.text), + Ok(MultiTurnStreamItem::FinalResponse(res)) => { + final_response_text = Some(res.response().to_owned()); + break; + } + Ok(_) => {} + Err(err) => panic!("unexpected streaming error: {err:?}"), + } + } + + assert_eq!(streamed_text, "The answer is 5"); + assert_eq!(final_response_text.as_deref(), Some("The answer is 5")); + assert_eq!(turn_counter.load(Ordering::SeqCst), 2); + } + /// Background task that logs periodically to detect span leakage. /// If span leakage occurs, these logs will be prefixed with `invoke_agent{...}`. 
async fn background_logger(stop: Arc, leak_count: Arc) { diff --git a/rig/rig-core/src/agent/prompt_request/turns.rs b/rig/rig-core/src/agent/prompt_request/turns.rs new file mode 100644 index 000000000..65c5d1fb8 --- /dev/null +++ b/rig/rig-core/src/agent/prompt_request/turns.rs @@ -0,0 +1,108 @@ +//! Shared helpers for deriving user-visible assistant text across agent turns. + +use crate::completion::normalized::NormalizedTurn; +use crate::completion::{AssistantChoice, CompletionResponse, normalized::NormalizedItem}; +use crate::streaming::StreamingCompletionResponse; + +/// Summary of the visible assistant text in a normalized turn. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub(crate) struct AssistantTurnSummary { + visible_text_blocks: Vec, +} + +impl AssistantTurnSummary { + fn from_turn(turn: NormalizedTurn) -> Self { + let visible_text_blocks = turn + .items + .into_iter() + .filter_map(|item| match item { + NormalizedItem::Text(text) => Some(text), + _ => None, + }) + .collect(); + + Self { + visible_text_blocks, + } + } + + /// Extract non-empty text blocks from a normalized assistant choice. + #[cfg_attr(not(test), allow(dead_code))] + pub(crate) fn from_choice(choice: &AssistantChoice) -> Self { + Self::from_turn(NormalizedTurn::from_choice(choice, None, None)) + } + + /// Extract non-empty text blocks from a normalized completion response. + pub(crate) fn from_response(response: &CompletionResponse) -> Self { + Self::from_turn(NormalizedTurn::from_completion_response(response)) + } + + /// Extract non-empty text blocks from an aggregated streaming response. + pub(crate) fn from_stream_response(response: &StreamingCompletionResponse) -> Self + where + R: Clone + Unpin + crate::completion::GetTokenUsage, + { + Self::from_turn(NormalizedTurn::from_choice( + &response.choice, + response.message_id.clone(), + response.stop_reason.clone(), + )) + } + + /// Render the visible text blocks using the caller's preferred separator. + pub(crate) fn visible_text(&self, separator: &str) -> String { + self.visible_text_blocks.join(separator) + } +} + +/// Tracks non-empty assistant text from earlier turns so a textless final turn +/// can still return the last user-visible answer the model produced. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub(crate) struct AssistantTextAccumulator { + prior_turn_texts: Vec, +} + +impl AssistantTextAccumulator { + /// Record a turn's visible text when it is not empty. + pub(crate) fn observe(&mut self, turn_text: &str) { + if !turn_text.is_empty() { + self.prior_turn_texts.push(turn_text.to_owned()); + } + } + + /// Prefer the current turn's visible text and otherwise fall back to the + /// accumulated text from earlier turns in the same request. 
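+    ///
+    /// A minimal sketch of the intended call pattern (crate-internal type, so
+    /// the example is not compile-checked here):
+    ///
+    /// ```ignore
+    /// let mut accumulator = AssistantTextAccumulator::default();
+    /// accumulator.observe("The answer is 5"); // earlier turn produced text
+    /// accumulator.observe("");                // empty turns are ignored
+    /// assert_eq!(accumulator.final_output(""), "The answer is 5");
+    /// assert_eq!(accumulator.final_output("Done"), "Done");
+    /// ```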
+    pub(crate) fn final_output(&self, current_turn_text: &str) -> String {
+        if current_turn_text.is_empty() {
+            self.prior_turn_texts.join("\n")
+        } else {
+            current_turn_text.to_owned()
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{AssistantTextAccumulator, AssistantTurnSummary};
+    use crate::{completion::AssistantChoice, message::AssistantContent};
+
+    #[test]
+    fn summary_ignores_empty_text_blocks() {
+        let choice = AssistantChoice::many(vec![
+            AssistantContent::text("visible"),
+            AssistantContent::text(""),
+        ]);
+
+        let summary = AssistantTurnSummary::from_choice(&choice);
+
+        assert_eq!(summary.visible_text("\n"), "visible");
+    }
+
+    #[test]
+    fn accumulator_falls_back_when_terminal_turn_is_empty() {
+        let mut accumulator = AssistantTextAccumulator::default();
+        accumulator.observe("first turn");
+
+        assert_eq!(accumulator.final_output(""), "first turn");
+    }
+}
diff --git a/rig/rig-core/src/completion/mod.rs b/rig/rig-core/src/completion/mod.rs
index 7dcef0ac6..7010a7ced 100644
--- a/rig/rig-core/src/completion/mod.rs
+++ b/rig/rig-core/src/completion/mod.rs
@@ -1,5 +1,7 @@
 pub mod message;
+pub(crate) mod normalized;
 pub mod request;
 
 pub use message::{AssistantContent, Message, MessageError};
+pub use normalized::StopReason;
 pub use request::*;
diff --git a/rig/rig-core/src/completion/normalized.rs b/rig/rig-core/src/completion/normalized.rs
new file mode 100644
index 000000000..6bb136558
--- /dev/null
+++ b/rig/rig-core/src/completion/normalized.rs
@@ -0,0 +1,152 @@
+//! Internal normalization helpers for provider responses.
+//!
+//! Providers often expose different stop-reason enums and content block shapes.
+//! This module defines a small shared semantic representation that can be used by
+//! higher-level code, test harnesses, and future provider adapters without
+//! depending on provider-specific wire types.
+
+use serde_json::Value;
+
+use super::{AssistantChoice, CompletionResponse};
+use crate::message::AssistantContent;
+
+/// Provider-agnostic assistant content used by internal normalization flows.
+#[derive(Debug, Clone, PartialEq)]
+pub(crate) enum NormalizedItem {
+    Text(String),
+    ToolCall {
+        id: String,
+        name: String,
+        arguments: Value,
+    },
+    Reasoning(String),
+}
+
+/// Provider-agnostic reasons why a provider stopped generating a turn.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum StopReason {
+    /// The model reached a natural stop point.
+    Stop,
+    /// The model stopped in order to request tool execution.
+    ToolCalls,
+    /// The provider explicitly ended the assistant turn.
+    EndTurn,
+    /// The model hit the configured output-token limit.
+    MaxTokens,
+    /// The provider filtered the output content.
+    ContentFilter,
+    /// The provider stopped for a safety-related reason.
+    Safety,
+    /// The provider reported a stop reason Rig does not normalize yet.
+    Other(String),
+}
+
+/// Provider-agnostic normalized assistant turn.
+#[derive(Debug, Clone, PartialEq, Default)]
+pub(crate) struct NormalizedTurn {
+    pub(crate) items: Vec<NormalizedItem>,
+    pub(crate) message_id: Option<String>,
+    pub(crate) stop_reason: Option<StopReason>,
+}
+
+impl NormalizedTurn {
+    /// Build a normalized turn from a normalized assistant choice plus metadata.
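+    ///
+    /// A small usage sketch (crate-internal types, so the example is not
+    /// compile-checked here):
+    ///
+    /// ```ignore
+    /// let choice = AssistantChoice::many(vec![
+    ///     AssistantContent::text("visible"),
+    ///     AssistantContent::text(""), // empty text blocks are dropped
+    /// ]);
+    /// let turn = NormalizedTurn::from_choice(&choice, None, Some(StopReason::EndTurn));
+    /// assert_eq!(turn.items, vec![NormalizedItem::Text("visible".to_string())]);
+    /// ```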
+    pub(crate) fn from_choice(
+        choice: &AssistantChoice,
+        message_id: Option<String>,
+        stop_reason: Option<StopReason>,
+    ) -> Self {
+        let mut items = Vec::new();
+
+        for item in choice.iter() {
+            let normalized = match item {
+                AssistantContent::Text(text) if !text.text.is_empty() => {
+                    Some(NormalizedItem::Text(text.text.clone()))
+                }
+                AssistantContent::ToolCall(tool_call) => Some(NormalizedItem::ToolCall {
+                    id: tool_call.id.clone(),
+                    name: tool_call.function.name.clone(),
+                    arguments: tool_call.function.arguments.clone(),
+                }),
+                AssistantContent::Reasoning(reasoning) => {
+                    let text = reasoning.display_text();
+                    if text.is_empty() {
+                        None
+                    } else {
+                        Some(NormalizedItem::Reasoning(text))
+                    }
+                }
+                _ => None,
+            };
+
+            if let Some(normalized) = normalized {
+                let duplicate_reasoning = matches!(
+                    (&normalized, items.last()),
+                    (NormalizedItem::Reasoning(current), Some(NormalizedItem::Reasoning(previous)))
+                        if current == previous
+                );
+
+                if !duplicate_reasoning {
+                    items.push(normalized);
+                }
+            }
+        }
+
+        Self {
+            items,
+            message_id,
+            stop_reason,
+        }
+    }
+
+    /// Build a normalized turn from a completion response.
+    pub(crate) fn from_completion_response<T>(response: &CompletionResponse<T>) -> Self {
+        Self::from_choice(
+            &response.choice,
+            response.message_id.clone(),
+            response.stop_reason.clone(),
+        )
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{NormalizedItem, NormalizedTurn, StopReason};
+    use crate::completion::AssistantChoice;
+    use crate::message::{AssistantContent, Reasoning};
+
+    #[test]
+    fn normalized_turn_ignores_empty_text_blocks() {
+        let choice = AssistantChoice::many(vec![
+            AssistantContent::text("visible"),
+            AssistantContent::text(""),
+        ]);
+
+        let turn = NormalizedTurn::from_choice(&choice, Some("msg_1".to_string()), None);
+
+        assert_eq!(
+            turn,
+            NormalizedTurn {
+                items: vec![NormalizedItem::Text("visible".to_string())],
+                message_id: Some("msg_1".to_string()),
+                stop_reason: None,
+            }
+        );
+    }
+
+    #[test]
+    fn normalized_turn_deduplicates_adjacent_reasoning_blocks() {
+        let choice = AssistantChoice::many(vec![
+            AssistantContent::Reasoning(Reasoning::new("step one")),
+            AssistantContent::Reasoning(Reasoning::new("step one")),
+        ]);
+
+        let turn = NormalizedTurn::from_choice(&choice, None, Some(StopReason::EndTurn));
+
+        assert_eq!(
+            turn.items,
+            vec![NormalizedItem::Reasoning("step one".to_string())]
+        );
+        assert_eq!(turn.stop_reason, Some(StopReason::EndTurn));
+    }
+}
diff --git a/rig/rig-core/src/completion/request.rs b/rig/rig-core/src/completion/request.rs
index 1dfd07690..071d2bf2b 100644
--- a/rig/rig-core/src/completion/request.rs
+++ b/rig/rig-core/src/completion/request.rs
@@ -63,12 +63,13 @@
 //! For more information on how to use the completion functionality, refer to the documentation of
 //! the individual traits, structs, and enums defined in this module.
 
+use super::StopReason;
 use super::message::{AssistantContent, DocumentMediaType};
 use crate::message::ToolChoice;
 use crate::streaming::StreamingCompletionResponse;
 use crate::tool::server::ToolServerError;
 use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
-use crate::{OneOrMany, http_client};
+use crate::{EmptyListError, OneOrMany, http_client};
 use crate::{
     json_utils,
     message::{Message, UserContent},
@@ -359,12 +360,13 @@ pub trait Completion {
 }
 
 /// General completion response struct that contains the high-level completion choice
-/// and the raw response. The completion choice contains one or more assistant content.
+/// and the raw response. Providers may validly return an empty assistant choice
+/// when a turn ends without emitting text, tool calls, reasoning, or images.
 #[derive(Debug)]
 pub struct CompletionResponse<T> {
-    /// The completion choice (represented by one or more assistant message content)
+    /// The completion choice (represented by zero or more assistant message content)
     /// returned by the completion model provider
-    pub choice: OneOrMany<AssistantContent>,
+    pub choice: AssistantChoice,
     /// Tokens used during prompting and responding
     pub usage: Usage,
     /// The raw response returned by the completion model provider
@@ -372,6 +374,200 @@ pub struct CompletionResponse<T> {
     /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID).
     /// Used to pair reasoning input items with their output items in multi-turn.
     pub message_id: Option<String>,
+    /// Provider-agnostic reason why the model stopped generating this turn.
+    pub stop_reason: Option<StopReason>,
+}
+
+/// Zero or more assistant content items returned by a provider.
+///
+/// Unlike [`OneOrMany`], this type preserves legitimate empty assistant turns
+/// without synthesizing placeholder text.
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+#[serde(transparent)]
+pub struct AssistantChoice {
+    items: Vec<AssistantContent>,
+}
+
+impl AssistantChoice {
+    /// Create an empty assistant choice.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Create an assistant choice containing a single item.
+    pub fn one(item: AssistantContent) -> Self {
+        Self { items: vec![item] }
+    }
+
+    /// Create an assistant choice from zero or more items.
+    pub fn many<I>(items: I) -> Self
+    where
+        I: IntoIterator<Item = AssistantContent>,
+    {
+        Self {
+            items: items.into_iter().collect(),
+        }
+    }
+
+    /// Returns the first assistant item, if present.
+    pub fn first(&self) -> Option<AssistantContent> {
+        self.items.first().cloned()
+    }
+
+    /// Returns a reference to the first assistant item, if present.
+    pub fn first_ref(&self) -> Option<&AssistantContent> {
+        self.items.first()
+    }
+
+    /// Returns the first assistant item mutably, if present.
+    pub fn first_mut(&mut self) -> Option<&mut AssistantContent> {
+        self.items.first_mut()
+    }
+
+    /// Returns the last assistant item, if present.
+    pub fn last(&self) -> Option<AssistantContent> {
+        self.items.last().cloned()
+    }
+
+    /// Returns a reference to the last assistant item, if present.
+    pub fn last_ref(&self) -> Option<&AssistantContent> {
+        self.items.last()
+    }
+
+    /// Returns the last assistant item mutably, if present.
+    pub fn last_mut(&mut self) -> Option<&mut AssistantContent> {
+        self.items.last_mut()
+    }
+
+    /// Returns the number of assistant items.
+    pub fn len(&self) -> usize {
+        self.items.len()
+    }
+
+    /// Returns `true` when the provider emitted no assistant items.
+    pub fn is_empty(&self) -> bool {
+        self.items.is_empty()
+    }
+
+    /// Append an assistant item.
+    pub fn push(&mut self, item: AssistantContent) {
+        self.items.push(item);
+    }
+
+    /// Insert an assistant item at a given index.
+    pub fn insert(&mut self, index: usize, item: AssistantContent) {
+        self.items.insert(index, item);
+    }
+
+    /// Iterate over assistant items.
+    pub fn iter(&self) -> std::slice::Iter<'_, AssistantContent> {
+        self.items.iter()
+    }
+
+    /// Iterate over assistant items mutably.
+    pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, AssistantContent> {
+        self.items.iter_mut()
+    }
+
+    /// Convert into a plain vector.
+    pub fn into_vec(self) -> Vec<AssistantContent> {
+        self.items
+    }
+
+    /// Convert into [`OneOrMany`] when the choice is non-empty.
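+    ///
+    /// A brief sketch of the conversion behavior (an empty choice is the
+    /// error case, mirroring the tests below):
+    ///
+    /// ```ignore
+    /// use rig::completion::AssistantChoice;
+    /// use rig::message::AssistantContent;
+    ///
+    /// let choice = AssistantChoice::one(AssistantContent::text("hi"));
+    /// assert!(choice.into_one_or_many().is_ok());
+    /// assert!(AssistantChoice::new().into_one_or_many().is_err());
+    /// ```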
+    pub fn into_one_or_many(self) -> Result<OneOrMany<AssistantContent>, EmptyListError> {
+        OneOrMany::many(self.items)
+    }
+
+    /// Clone into [`OneOrMany`] when the choice is non-empty.
+    pub fn to_one_or_many(&self) -> Result<OneOrMany<AssistantContent>, EmptyListError> {
+        OneOrMany::many(self.items.clone())
+    }
+}
+
+impl From<AssistantContent> for AssistantChoice {
+    fn from(item: AssistantContent) -> Self {
+        Self::one(item)
+    }
+}
+
+impl From<Vec<AssistantContent>> for AssistantChoice {
+    fn from(items: Vec<AssistantContent>) -> Self {
+        Self { items }
+    }
+}
+
+impl From<OneOrMany<AssistantContent>> for AssistantChoice {
+    fn from(items: OneOrMany<AssistantContent>) -> Self {
+        Self::many(items)
+    }
+}
+
+impl From<AssistantChoice> for Vec<AssistantContent> {
+    fn from(choice: AssistantChoice) -> Self {
+        choice.items
+    }
+}
+
+impl std::iter::FromIterator<AssistantContent> for AssistantChoice {
+    fn from_iter<T: IntoIterator<Item = AssistantContent>>(iter: T) -> Self {
+        Self::many(iter)
+    }
+}
+
+impl IntoIterator for AssistantChoice {
+    type Item = AssistantContent;
+    type IntoIter = std::vec::IntoIter<AssistantContent>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.items.into_iter()
+    }
+}
+
+impl<'a> IntoIterator for &'a AssistantChoice {
+    type Item = &'a AssistantContent;
+    type IntoIter = std::slice::Iter<'a, AssistantContent>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.items.iter()
+    }
+}
+
+impl<'a> IntoIterator for &'a mut AssistantChoice {
+    type Item = &'a mut AssistantContent;
+    type IntoIter = std::slice::IterMut<'a, AssistantContent>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.items.iter_mut()
+    }
+}
+
+impl TryFrom<AssistantChoice> for OneOrMany<AssistantContent> {
+    type Error = EmptyListError;
+
+    fn try_from(choice: AssistantChoice) -> Result<Self, Self::Error> {
+        choice.into_one_or_many()
+    }
+}
+
+impl TryFrom<&AssistantChoice> for OneOrMany<AssistantContent> {
+    type Error = EmptyListError;
+
+    fn try_from(choice: &AssistantChoice) -> Result<Self, Self::Error> {
+        choice.to_one_or_many()
+    }
+}
+
+impl PartialEq<OneOrMany<AssistantContent>> for AssistantChoice {
+    fn eq(&self, other: &OneOrMany<AssistantContent>) -> bool {
+        self.iter().eq(other.iter())
+    }
+}
+
+impl PartialEq<AssistantChoice> for OneOrMany<AssistantContent> {
+    fn eq(&self, other: &AssistantChoice) -> bool {
+        self.iter().eq(other.iter())
+    }
 }
 
 /// A trait for grabbing the token usage of a completion response.
@@ -988,6 +1184,30 @@ mod tests { assert_eq!(format!("{doc}"), expected); } + #[test] + fn assistant_choice_can_be_empty_without_placeholder_text() { + let choice = AssistantChoice::new(); + + assert!(choice.is_empty()); + assert!(choice.first().is_none()); + assert!(choice.to_one_or_many().is_err()); + } + + #[test] + fn assistant_choice_roundtrips_non_empty_content() { + let choice = AssistantChoice::many(vec![ + AssistantContent::text("hello"), + AssistantContent::tool_call("call_1", "echo", serde_json::json!({"value": 1})), + ]); + + let normalized = choice + .clone() + .into_one_or_many() + .expect("non-empty assistant choice should convert"); + + assert_eq!(choice, normalized); + } + #[test] fn test_normalize_documents_with_documents() { let doc1 = Document { diff --git a/rig/rig-core/src/providers/anthropic/completion.rs b/rig/rig-core/src/providers/anthropic/completion.rs index 442f38b22..e7cc0097c 100644 --- a/rig/rig-core/src/providers/anthropic/completion.rs +++ b/rig/rig-core/src/providers/anthropic/completion.rs @@ -60,6 +60,15 @@ pub struct CompletionResponse { pub usage: Usage, } +pub(crate) fn map_stop_reason(reason: &str) -> completion::StopReason { + match reason { + "end_turn" => completion::StopReason::EndTurn, + "tool_use" => completion::StopReason::ToolCalls, + "max_tokens" => completion::StopReason::MaxTokens, + other => completion::StopReason::Other(other.to_string()), + } +} + impl ProviderResponseExt for CompletionResponse { type OutputMessage = Content; type Usage = Usage; @@ -214,11 +223,7 @@ impl TryFrom for completion::CompletionResponse, _>>()?; - let choice = OneOrMany::many(content).map_err(|_| { - CompletionError::ResponseError( - "Response contained no message or tool call (empty)".to_owned(), - ) - })?; + let choice = completion::AssistantChoice::from(content); let usage = completion::Usage { input_tokens: response.usage.input_tokens, @@ -234,6 +239,7 @@ impl TryFrom for completion::CompletionResponse { Error(ApiErrorResponse), } +#[cfg(test)] +#[path = "conformance_tests.rs"] +mod conformance_tests; + #[cfg(test)] mod tests { use super::*; diff --git a/rig/rig-core/src/providers/anthropic/conformance_tests.rs b/rig/rig-core/src/providers/anthropic/conformance_tests.rs new file mode 100644 index 000000000..2db4053c2 --- /dev/null +++ b/rig/rig-core/src/providers/anthropic/conformance_tests.rs @@ -0,0 +1,245 @@ +use bytes::Bytes; +use serde_json::json; + +use super::*; +use crate::{ + OneOrMany, + completion::{self, CompletionError}, + http_client::mock::MockStreamingClient, + providers::conformance::{ + BoxFuture, Fixture, Harness, NormalizedItem, Outcome, StopReason, Turn, drain_stream, + normalize_completion_response, provider_conformance_tests, + }, +}; + +struct AnthropicHarness; + +impl Harness for AnthropicHarness { + fn family_name() -> &'static str { + "anthropic-messages" + } + + fn expected(case: Fixture) -> Outcome { + match case { + Fixture::EmptyAssistantTurnAfterToolResult => Outcome::Supported(Turn { + items: vec![], + message_id: None, + stop_reason: Some(StopReason::EndTurn), + }), + Fixture::ToolOnlyTurn => Outcome::Supported(Turn { + items: vec![NormalizedItem::ToolCall { + id: "toolu_lookup".to_string(), + name: "lookup_weather".to_string(), + arguments: json!({"city": "Paris"}), + }], + message_id: None, + stop_reason: Some(StopReason::ToolCalls), + }), + Fixture::TextAndToolCallTurn => Outcome::Supported(Turn { + items: vec![ + NormalizedItem::Text("Need weather data first.".to_string()), + NormalizedItem::ToolCall { + id: 
"toolu_lookup".to_string(), + name: "lookup_weather".to_string(), + arguments: json!({"city": "Paris"}), + }, + ], + message_id: None, + stop_reason: Some(StopReason::ToolCalls), + }), + Fixture::EmptyTextBlocks => Outcome::Supported(Turn { + items: vec![], + message_id: None, + stop_reason: Some(StopReason::EndTurn), + }), + Fixture::ReasoningOnlyTurn => Outcome::Supported(Turn { + items: vec![NormalizedItem::Reasoning( + "Need to reason about the tool result.".to_string(), + )], + message_id: None, + stop_reason: Some(StopReason::EndTurn), + }), + Fixture::MessageIdPreservation => { + Outcome::Unsupported("Anthropic Messages responses do not expose message IDs") + } + Fixture::StopReasonMapping => Outcome::Supported(Turn { + items: vec![NormalizedItem::Text("Truncated response".to_string())], + message_id: None, + stop_reason: Some(StopReason::MaxTokens), + }), + } + } + + fn non_stream(case: Fixture) -> Result, CompletionError> { + match case { + Fixture::MessageIdPreservation => Ok(Self::expected(case)), + _ => { + let raw = non_stream_response(case); + let response: completion::CompletionResponse = + raw.try_into()?; + Ok(Outcome::Supported(normalize_completion_response(&response))) + } + } + } + + fn stream(case: Fixture) -> BoxFuture, CompletionError>> { + Box::pin(async move { + match case { + Fixture::MessageIdPreservation => Ok(Self::expected(case)), + _ => { + let client = crate::providers::anthropic::Client::builder() + .http_client(MockStreamingClient { + sse_bytes: Bytes::from(streaming_sse(case)), + }) + .api_key("test-key") + .build() + .expect("client should build"); + let model = CompletionModel::new(client, "claude-test"); + let stream = model.stream(stream_request()).await?; + let response = drain_stream(stream).await?; + Ok(Outcome::Supported(normalize_completion_response(&response))) + } + } + }) + } +} + +fn non_stream_response(case: Fixture) -> CompletionResponse { + match case { + Fixture::EmptyAssistantTurnAfterToolResult => response_with_content(vec![], "end_turn"), + Fixture::ToolOnlyTurn => response_with_content( + vec![Content::ToolUse { + id: "toolu_lookup".to_string(), + name: "lookup_weather".to_string(), + input: json!({"city": "Paris"}), + }], + "tool_use", + ), + Fixture::TextAndToolCallTurn => response_with_content( + vec![ + Content::Text { + text: "Need weather data first.".to_string(), + cache_control: None, + }, + Content::ToolUse { + id: "toolu_lookup".to_string(), + name: "lookup_weather".to_string(), + input: json!({"city": "Paris"}), + }, + ], + "tool_use", + ), + Fixture::EmptyTextBlocks => response_with_content( + vec![Content::Text { + text: String::new(), + cache_control: None, + }], + "end_turn", + ), + Fixture::ReasoningOnlyTurn => response_with_content( + vec![Content::Thinking { + thinking: "Need to reason about the tool result.".to_string(), + signature: Some("sig_1".to_string()), + }], + "end_turn", + ), + Fixture::StopReasonMapping => response_with_content( + vec![Content::Text { + text: "Truncated response".to_string(), + cache_control: None, + }], + "max_tokens", + ), + Fixture::MessageIdPreservation => unreachable!(), + } +} + +fn response_with_content(content: Vec, stop_reason: &str) -> CompletionResponse { + CompletionResponse { + content, + id: "msg_123".to_string(), + model: "claude-test".to_string(), + role: "assistant".to_string(), + stop_reason: Some(stop_reason.to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 10, + cache_read_input_tokens: None, + cache_creation_input_tokens: None, + output_tokens: 5, + 
}, + } +} + +fn stream_request() -> completion::CompletionRequest { + completion::CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(completion::Message::user("hello")), + documents: vec![], + tools: vec![], + temperature: None, + max_tokens: Some(32), + tool_choice: None, + additional_params: None, + output_schema: None, + } +} + +fn streaming_sse(case: Fixture) -> String { + match case { + Fixture::EmptyAssistantTurnAfterToolResult => concat!( + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-test\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"cache_read_input_tokens\":null,\"cache_creation_input_tokens\":null,\"output_tokens\":0}}}\n\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"output_tokens\":0}}\n\n", + "data: {\"type\":\"message_stop\"}\n\n", + ) + .to_string(), + Fixture::ToolOnlyTurn => concat!( + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-test\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"cache_read_input_tokens\":null,\"cache_creation_input_tokens\":null,\"output_tokens\":0}}}\n\n", + "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_lookup\",\"name\":\"lookup_weather\",\"input\":{\"city\":\"Paris\"}}}\n\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\":\\\"Paris\\\"}\"}}\n\n", + "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"output_tokens\":5}}\n\n", + "data: {\"type\":\"message_stop\"}\n\n", + ) + .to_string(), + Fixture::TextAndToolCallTurn => concat!( + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-test\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"cache_read_input_tokens\":null,\"cache_creation_input_tokens\":null,\"output_tokens\":0}}}\n\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Need weather data first.\"}}\n\n", + "data: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_lookup\",\"name\":\"lookup_weather\",\"input\":{\"city\":\"Paris\"}}}\n\n", + "data: {\"type\":\"content_block_delta\",\"index\":1,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\":\\\"Paris\\\"}\"}}\n\n", + "data: {\"type\":\"content_block_stop\",\"index\":1}\n\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"output_tokens\":7}}\n\n", + "data: {\"type\":\"message_stop\"}\n\n", + ) + .to_string(), + Fixture::EmptyTextBlocks => concat!( + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-test\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"cache_read_input_tokens\":null,\"cache_creation_input_tokens\":null,\"output_tokens\":0}}}\n\n", + "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"\"}}\n\n", + "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n", + 
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"output_tokens\":0}}\n\n", + "data: {\"type\":\"message_stop\"}\n\n", + ) + .to_string(), + Fixture::ReasoningOnlyTurn => concat!( + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-test\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"cache_read_input_tokens\":null,\"cache_creation_input_tokens\":null,\"output_tokens\":0}}}\n\n", + "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\",\"signature\":null}}\n\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"Need to reason about the tool result.\"}}\n\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"signature_delta\",\"signature\":\"sig_1\"}}\n\n", + "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\",\"stop_sequence\":null},\"usage\":{\"output_tokens\":5}}\n\n", + "data: {\"type\":\"message_stop\"}\n\n", + ) + .to_string(), + Fixture::StopReasonMapping => concat!( + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-test\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":10,\"cache_read_input_tokens\":null,\"cache_creation_input_tokens\":null,\"output_tokens\":0}}}\n\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Truncated response\"}}\n\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"max_tokens\",\"stop_sequence\":null},\"usage\":{\"output_tokens\":2}}\n\n", + "data: {\"type\":\"message_stop\"}\n\n", + ) + .to_string(), + Fixture::MessageIdPreservation => unreachable!(), + } +} + +provider_conformance_tests!(AnthropicHarness); diff --git a/rig/rig-core/src/providers/anthropic/streaming.rs b/rig/rig-core/src/providers/anthropic/streaming.rs index a338e876c..b6206aa80 100644 --- a/rig/rig-core/src/providers/anthropic/streaming.rs +++ b/rig/rig-core/src/providers/anthropic/streaming.rs @@ -7,7 +7,7 @@ use tracing_futures::Instrument; use super::completion::{ AnthropicCompatibleProvider, CacheControl, Content, GenericCompletionModel, Message, - SystemContent, ToolChoice, ToolDefinition, Usage, apply_cache_control, + SystemContent, ToolChoice, ToolDefinition, Usage, apply_cache_control, map_stop_reason, split_system_messages_from_history, }; use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage}; @@ -319,6 +319,12 @@ where }, StreamingEvent::MessageDelta { delta, usage } => { if delta.stop_reason.is_some() { + if let Some(stop_reason) = delta.stop_reason.as_deref() { + yield Ok(RawStreamingChoice::StopReason( + map_stop_reason(stop_reason), + )); + } + // cache_creation_input_tokens and cache_read_input_tokens // are cumulative totals on message_delta.usage per the // Anthropic streaming API spec — use them directly. 
diff --git a/rig/rig-core/src/providers/azure.rs b/rig/rig-core/src/providers/azure.rs index 9573d6f1a..ce7e64410 100644 --- a/rig/rig-core/src/providers/azure.rs +++ b/rig/rig-core/src/providers/azure.rs @@ -579,76 +579,34 @@ impl TryFrom<(&str, CompletionRequest)> for AzureOpenAICompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - //FIXME: Must fix! - if req.tool_choice.is_some() { - tracing::warn!( - "Tool choice is currently not supported in Azure OpenAI. This should be fixed by Rig 0.25." - ); - } - - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![openai::Message::system(preamble)], - None => vec![], - }; - - if let Some(docs) = req.normalized_documents() { - let docs: Vec = docs.try_into()?; - full_history.extend(docs); - } - - let chat_history: Vec = req - .chat_history - .clone() - .into_iter() - .map(|message| message.try_into()) - .collect::>, _>>()? - .into_iter() - .flatten() - .collect(); - - full_history.extend(chat_history); - - let tool_choice = req - .tool_choice - .clone() + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools, + tool_choice, + additional_params, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("Azure OpenAI") + .native_response_format(), + openai::Message::system, + None, + |message| matches!(message, openai::Message::ToolResult { .. }), + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; + + let tool_choice = tool_choice .map(crate::providers::openai::ToolChoice::try_from) .transpose()?; - let additional_params = if let Some(schema) = req.output_schema { - let name = schema - .as_object() - .and_then(|o| o.get("title")) - .and_then(|v| v.as_str()) - .unwrap_or("response_schema") - .to_string(); - let mut schema_value = schema.to_value(); - openai::sanitize_schema(&mut schema_value); - let response_format = serde_json::json!({ - "response_format": { - "type": "json_schema", - "json_schema": { - "name": name, - "strict": true, - "schema": schema_value - } - } - }); - Some(match req.additional_params { - Some(existing) => json_utils::merge(existing, response_format), - None => response_format, - }) - } else { - req.additional_params - }; - Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - tools: req - .tools - .clone() + model, + messages, + temperature, + tools: tools .into_iter() .map(openai::ToolDefinition::from) .collect::>(), @@ -1072,6 +1030,97 @@ mod azure_tests { use crate::embeddings::EmbeddingModel; use crate::prelude::TypedPrompt; use crate::providers::openai::GPT_5_MINI; + use crate::providers::openai::completion::{CompatibleChatProfile, request_conformance}; + + struct AzureRequestHarness; + + impl request_conformance::Harness for AzureRequestHarness { + fn family_name() -> &'static str { + "azure-openai" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + AzureOpenAICompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + 
CompatibleChatProfile::new("Azure OpenAI").native_response_format(), + ), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(AzureRequestHarness); + + #[test] + fn azure_request_uses_request_model_override() { + let request = CompletionRequest { + model: Some("azure-override".to_string()), + preamble: Some("system".to_string()), + chat_history: OneOrMany::one("hello".into()), + documents: vec![], + max_tokens: None, + temperature: None, + tools: vec![], + tool_choice: None, + additional_params: None, + output_schema: None, + }; + + let converted = AzureOpenAICompletionRequest::try_from(("azure-default", request)) + .expect("request should convert"); + + assert_eq!(converted.model, "azure-override"); + } + + #[test] + fn azure_request_merges_structured_output_into_additional_params() { + let request = CompletionRequest { + model: None, + preamble: Some("system".to_string()), + chat_history: OneOrMany::one("hello".into()), + documents: vec![], + max_tokens: None, + temperature: None, + tools: vec![], + tool_choice: None, + additional_params: Some(serde_json::json!({"foo":"bar"})), + output_schema: Some( + serde_json::from_value(serde_json::json!({ + "title": "WeatherResponse", + "type": "object", + "properties": { + "city": { "type": "string" } + } + })) + .expect("schema should deserialize"), + ), + }; + + let converted = AzureOpenAICompletionRequest::try_from(("azure-default", request)) + .expect("request should convert"); + let params = converted + .additional_params + .expect("additional params should be present"); + + assert_eq!(params["foo"], serde_json::json!("bar")); + assert_eq!( + params["response_format"]["json_schema"]["name"], + serde_json::json!("WeatherResponse") + ); + } #[tokio::test] #[ignore] diff --git a/rig/rig-core/src/providers/cohere/completion.rs b/rig/rig-core/src/providers/cohere/completion.rs index a54270dd6..a86a093a0 100644 --- a/rig/rig-core/src/providers/cohere/completion.rs +++ b/rig/rig-core/src/providers/cohere/completion.rs @@ -91,6 +91,15 @@ pub enum FinishReason { ToolCall, } +fn map_finish_reason(reason: &FinishReason) -> completion::StopReason { + match reason { + FinishReason::MaxTokens => completion::StopReason::MaxTokens, + FinishReason::StopSequence | FinishReason::Complete => completion::StopReason::Stop, + FinishReason::ToolCall => completion::StopReason::ToolCalls, + FinishReason::Error => completion::StopReason::Other("ERROR".to_string()), + } +} + #[derive(Debug, Deserialize, Clone, Serialize)] pub struct Usage { #[serde(default)] @@ -137,10 +146,11 @@ impl TryFrom for completion::CompletionResponse Result { + let stop_reason = Some(map_finish_reason(&response.finish_reason)); let (content, _, tool_calls) = response.message(); let model_response = if !tool_calls.is_empty() { - OneOrMany::many( + completion::AssistantChoice::many( tool_calls .into_iter() .filter_map(|tool_call| { @@ -151,19 +161,13 @@ impl TryFrom for completion::CompletionResponse>(), ) - .expect("We have atleast 1 tool call in this if block") } else { - OneOrMany::many(content.into_iter().map(|content| match content { + completion::AssistantChoice::many(content.into_iter().map(|content| match content { AssistantContent::Text { text } => completion::AssistantContent::text(text), AssistantContent::Thinking { thinking } => { completion::AssistantContent::Reasoning(Reasoning::new(&thinking)) } })) - .map_err(|_| { - CompletionError::ResponseError( - "Response contained no message or tool call (empty)".to_owned(), - ) - 
})? }; let usage = response @@ -185,10 +189,11 @@ impl TryFrom for completion::CompletionResponse { + Supported(T), + Unsupported(&'static str), +} + +pub(crate) type BoxFuture = crate::wasm_compat::WasmBoxedFuture<'static, T>; + +pub(crate) trait Harness { + fn family_name() -> &'static str; + + fn expected(case: Fixture) -> Outcome; + + fn non_stream(case: Fixture) -> Result, CompletionError>; + + fn stream(case: Fixture) -> BoxFuture, CompletionError>>; +} + +pub(crate) fn normalize_completion_response( + response: &completion::CompletionResponse, +) -> Turn { + Turn::from_completion_response(response) +} + +pub(crate) async fn drain_stream( + mut stream: crate::streaming::StreamingCompletionResponse, +) -> Result>, CompletionError> +where + R: Clone + Unpin + GetTokenUsage, +{ + while let Some(item) = stream.next().await { + item?; + } + + Ok(stream.into()) +} + +pub(crate) fn assert_non_stream_case(case: Fixture) { + let expected = H::expected(case); + let actual = H::non_stream(case) + .unwrap_or_else(|err| panic!("{} non-stream {:?} failed: {err}", H::family_name(), case)); + + assert_eq!( + actual, + expected, + "{} non-stream {:?} mismatch", + H::family_name(), + case + ); +} + +pub(crate) async fn assert_stream_matches_non_stream(case: Fixture) { + let non_stream = H::non_stream(case) + .unwrap_or_else(|err| panic!("{} non-stream {:?} failed: {err}", H::family_name(), case)); + let stream = H::stream(case) + .await + .unwrap_or_else(|err| panic!("{} stream {:?} failed: {err}", H::family_name(), case)); + + match (non_stream, stream) { + (Outcome::Supported(expected), Outcome::Supported(actual)) => { + assert_eq!( + actual.items, + expected.items, + "{} stream {:?} items diverged", + H::family_name(), + case + ); + assert_eq!( + actual.message_id, + expected.message_id, + "{} stream {:?} message_id diverged", + H::family_name(), + case + ); + assert_eq!( + actual.stop_reason, + expected.stop_reason, + "{} stream {:?} stop_reason diverged", + H::family_name(), + case + ); + } + (Outcome::Unsupported(_), Outcome::Unsupported(_)) => {} + (expected, actual) => panic!( + "{} stream/non-stream support mismatch for {:?}: non-stream={expected:?}, stream={actual:?}", + H::family_name(), + case + ), + } +} + +macro_rules! 
provider_conformance_tests { + ($harness:ty) => { + #[test] + fn conformance_empty_assistant_turn_after_tool_result_non_stream() { + crate::providers::conformance::assert_non_stream_case::<$harness>( + crate::providers::conformance::Fixture::EmptyAssistantTurnAfterToolResult, + ); + } + + #[test] + fn conformance_tool_only_turn_non_stream() { + crate::providers::conformance::assert_non_stream_case::<$harness>( + crate::providers::conformance::Fixture::ToolOnlyTurn, + ); + } + + #[test] + fn conformance_text_and_tool_call_turn_non_stream() { + crate::providers::conformance::assert_non_stream_case::<$harness>( + crate::providers::conformance::Fixture::TextAndToolCallTurn, + ); + } + + #[test] + fn conformance_empty_text_blocks_non_stream() { + crate::providers::conformance::assert_non_stream_case::<$harness>( + crate::providers::conformance::Fixture::EmptyTextBlocks, + ); + } + + #[test] + fn conformance_reasoning_only_turn_non_stream() { + crate::providers::conformance::assert_non_stream_case::<$harness>( + crate::providers::conformance::Fixture::ReasoningOnlyTurn, + ); + } + + #[test] + fn conformance_message_id_preservation_non_stream() { + crate::providers::conformance::assert_non_stream_case::<$harness>( + crate::providers::conformance::Fixture::MessageIdPreservation, + ); + } + + #[test] + fn conformance_stop_reason_mapping_non_stream() { + crate::providers::conformance::assert_non_stream_case::<$harness>( + crate::providers::conformance::Fixture::StopReasonMapping, + ); + } + + #[tokio::test] + async fn conformance_empty_assistant_turn_after_tool_result_stream_equivalence() { + crate::providers::conformance::assert_stream_matches_non_stream::<$harness>( + crate::providers::conformance::Fixture::EmptyAssistantTurnAfterToolResult, + ) + .await; + } + + #[tokio::test] + async fn conformance_tool_only_turn_stream_equivalence() { + crate::providers::conformance::assert_stream_matches_non_stream::<$harness>( + crate::providers::conformance::Fixture::ToolOnlyTurn, + ) + .await; + } + + #[tokio::test] + async fn conformance_text_and_tool_call_turn_stream_equivalence() { + crate::providers::conformance::assert_stream_matches_non_stream::<$harness>( + crate::providers::conformance::Fixture::TextAndToolCallTurn, + ) + .await; + } + + #[tokio::test] + async fn conformance_empty_text_blocks_stream_equivalence() { + crate::providers::conformance::assert_stream_matches_non_stream::<$harness>( + crate::providers::conformance::Fixture::EmptyTextBlocks, + ) + .await; + } + + #[tokio::test] + async fn conformance_reasoning_only_turn_stream_equivalence() { + crate::providers::conformance::assert_stream_matches_non_stream::<$harness>( + crate::providers::conformance::Fixture::ReasoningOnlyTurn, + ) + .await; + } + + #[tokio::test] + async fn conformance_message_id_preservation_stream_equivalence() { + crate::providers::conformance::assert_stream_matches_non_stream::<$harness>( + crate::providers::conformance::Fixture::MessageIdPreservation, + ) + .await; + } + }; +} + +pub(crate) use provider_conformance_tests; diff --git a/rig/rig-core/src/providers/copilot/mod.rs b/rig/rig-core/src/providers/copilot/mod.rs index ddb1e975d..ffb565c30 100644 --- a/rig/rig-core/src/providers/copilot/mod.rs +++ b/rig/rig-core/src/providers/copilot/mod.rs @@ -26,6 +26,10 @@ use crate::completion::{self, CompletionError, GetTokenUsage}; use crate::embeddings::{self, EmbeddingError}; use crate::http_client::{self, HttpClientExt}; use crate::providers::openai; +use crate::providers::openai::completion::{ + 
CompatibleStreamingToolCall, ToolCallConflictPolicy, apply_compatible_tool_call_deltas, + take_finalized_tool_calls, take_tool_calls, +}; use crate::providers::openai::responses_api::{self, CompletionRequest as ResponsesRequest}; use crate::streaming::{self, RawStreamingChoice, StreamingCompletionResponse}; use crate::wasm_compat::{WasmCompatSend, WasmCompatSync}; @@ -480,6 +484,10 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse for completion::CompletionResponse Err(CompletionError::ProviderError( @@ -784,6 +790,7 @@ where usage: core.usage, raw_response: CopilotCompletionResponse::Responses(Box::new(response)), message_id: core.message_id, + stop_reason: core.stop_reason, }) } else { let body = http_client::text(response).await?; @@ -1259,6 +1266,16 @@ enum ChatFinishReason { Other(String), } +fn map_chat_finish_reason(reason: &ChatFinishReason) -> completion::StopReason { + match reason { + ChatFinishReason::ToolCalls => completion::StopReason::ToolCalls, + ChatFinishReason::Stop => completion::StopReason::Stop, + ChatFinishReason::ContentFilter => completion::StopReason::ContentFilter, + ChatFinishReason::Length => completion::StopReason::MaxTokens, + ChatFinishReason::Other(other) => completion::StopReason::Other(other.clone()), + } +} + #[derive(Deserialize, Debug)] struct ChatStreamingChoice { delta: ChatStreamingDelta, @@ -1310,77 +1327,17 @@ where let delta = &choice.delta; if !delta.tool_calls.is_empty() { - for tool_call in &delta.tool_calls { - let index = tool_call.index; - - if let Some(new_id) = &tool_call.id - && !new_id.is_empty() - && let Some(new_name) = &tool_call.function.name - && !new_name.is_empty() - && let Some(existing) = tool_calls.get(&index) - && !existing.id.is_empty() - && existing.id != *new_id - && !existing.name.is_empty() - && existing.name != *new_name - { - let evicted = tool_calls.remove(&index).expect("checked above"); - yield Ok(RawStreamingChoice::ToolCall( - finalize_completed_streaming_tool_call(evicted), - )); - } - - let existing_tool_call = tool_calls - .entry(index) - .or_insert_with(streaming::RawStreamingToolCall::empty); - - if let Some(id) = &tool_call.id - && !id.is_empty() - { - existing_tool_call.id = id.clone(); - } - - if let Some(name) = &tool_call.function.name - && !name.is_empty() - { - existing_tool_call.name = name.clone(); - yield Ok(RawStreamingChoice::ToolCallDelta { - id: existing_tool_call.id.clone(), - internal_call_id: existing_tool_call.internal_call_id.clone(), - content: streaming::ToolCallDeltaContent::Name(name.clone()), - }); - } - - if let Some(chunk) = &tool_call.function.arguments - && !chunk.is_empty() - { - let current_args = match &existing_tool_call.arguments { - serde_json::Value::Null => String::new(), - serde_json::Value::String(s) => s.clone(), - value => value.to_string(), - }; - let combined = format!("{current_args}{chunk}"); - - if combined.trim_start().starts_with('{') - && combined.trim_end().ends_with('}') - { - match serde_json::from_str(&combined) { - Ok(parsed) => existing_tool_call.arguments = parsed, - Err(_) => { - existing_tool_call.arguments = - serde_json::Value::String(combined) - } - } - } else { - existing_tool_call.arguments = - serde_json::Value::String(combined); - } - - yield Ok(RawStreamingChoice::ToolCallDelta { - id: existing_tool_call.id.clone(), - internal_call_id: existing_tool_call.internal_call_id.clone(), - content: streaming::ToolCallDeltaContent::Delta(chunk.clone()), - }); - } + for event in apply_compatible_tool_call_deltas( + &mut 
tool_calls, + delta.tool_calls.iter().map(|tool_call| CompatibleStreamingToolCall { + index: tool_call.index, + id: tool_call.id.as_deref(), + name: tool_call.function.name.as_deref(), + arguments: tool_call.function.arguments.as_deref(), + }), + ToolCallConflictPolicy::EvictDistinctIdAndName, + ) { + yield Ok(event); } } @@ -1402,12 +1359,15 @@ where if let Some(finish_reason) = &choice.finish_reason && *finish_reason == ChatFinishReason::ToolCalls { - for (_idx, tool_call) in tool_calls.into_iter() { - yield Ok(RawStreamingChoice::ToolCall( - finalize_completed_streaming_tool_call(tool_call), - )); + for tool_call in take_finalized_tool_calls(&mut tool_calls) { + yield Ok(tool_call); } - tool_calls = HashMap::new(); + } + + if let Some(finish_reason) = &choice.finish_reason { + yield Ok(RawStreamingChoice::StopReason(map_chat_finish_reason( + finish_reason, + ))); } } Err(crate::http_client::Error::StreamEnded) => break, @@ -1425,8 +1385,8 @@ where return; } - for (_idx, tool_call) in tool_calls.into_iter() { - yield Ok(RawStreamingChoice::ToolCall(tool_call)); + for tool_call in take_tool_calls(&mut tool_calls) { + yield Ok(tool_call); } let final_usage = final_usage.unwrap_or_default(); @@ -1457,16 +1417,6 @@ where Ok(StreamingCompletionResponse::stream(Box::pin(stream))) } -fn finalize_completed_streaming_tool_call( - mut tool_call: streaming::RawStreamingToolCall, -) -> streaming::RawStreamingToolCall { - if tool_call.arguments.is_null() { - tool_call.arguments = serde_json::Value::Object(serde_json::Map::new()); - } - - tool_call -} - fn default_token_dir() -> Option { config_dir().map(|dir| dir.join("github_copilot")) } @@ -1490,7 +1440,7 @@ mod tests { use super::{ ChatApiErrorResponse, ChatCompletionResponse, Client, CompletionRoute, TEXT_EMBEDDING_3_SMALL, env_api_key, env_base_url, env_github_access_token, - route_for_model, + route_for_model, send_copilot_chat_streaming_request, }; use crate::client::CompletionClient; use crate::completion::CompletionModel; @@ -2015,6 +1965,35 @@ mod tests { ); } + #[tokio::test] + async fn chat_stream_captures_stop_reason() { + let sse_bytes = Bytes::from(concat!( + "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"finish_reason\":\"stop\"}],\"usage\":null}\n\n", + "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n\n", + "data: [DONE]\n\n", + )); + + let http_client = MockStreamingClient { sse_bytes }; + let req = http::Request::builder() + .method("POST") + .uri("http://localhost/v1/chat/completions") + .body(Vec::new()) + .expect("request should build"); + + let mut stream = send_copilot_chat_streaming_request(http_client, req) + .await + .expect("stream should start"); + + while let Some(chunk) = stream.next().await { + chunk.expect("stream chunk should deserialize"); + } + + assert_eq!( + stream.stop_reason, + Some(crate::completion::StopReason::Stop) + ); + } + #[test] fn env_api_key_prefers_github_prefixed_vars() { let env = env_map(&[ diff --git a/rig/rig-core/src/providers/deepseek.rs b/rig/rig-core/src/providers/deepseek.rs index 156b4854c..8f3827309 100644 --- a/rig/rig-core/src/providers/deepseek.rs +++ b/rig/rig-core/src/providers/deepseek.rs @@ -9,7 +9,6 @@ //! let deepseek_chat = client.completion_model(deepseek::DEEPSEEK_CHAT); //! 
``` -use crate::json_utils::empty_or_none; use async_stream::stream; use bytes::Bytes; use futures::StreamExt; @@ -25,8 +24,12 @@ use crate::completion::GetTokenUsage; use crate::http_client::sse::{Event, GenericEventSource}; use crate::http_client::{self, HttpClientExt}; use crate::message::{Document, DocumentSourceKind}; +use crate::providers::openai::completion::{ + CompatibleChatProfile, CompatibleStreamingToolCall, ToolCallConflictPolicy, + apply_compatible_tool_call_deltas, build_compatible_request_core, map_finish_reason, + take_finalized_tool_calls, take_tool_calls, +}; use crate::{ - OneOrMany, completion::{self, CompletionError, CompletionRequest}, json_utils, message, }; @@ -381,9 +384,10 @@ impl TryFrom for completion::CompletionResponse Result { - let choice = response.choices.first().ok_or_else(|| { - CompletionError::ResponseError("Response contained no choices".to_owned()) - })?; + let choice = crate::providers::openai::completion::first_choice(&response.choices)?; + let stop_reason = Some(crate::providers::openai::completion::map_finish_reason( + &choice.finish_reason, + )); let content = match &choice.message { Message::Assistant { content, @@ -421,12 +425,6 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse, #[serde(skip_serializing_if = "Option::is_none")] - tool_choice: Option, + tool_choice: Option, #[serde(flatten, skip_serializing_if = "Option::is_none")] pub additional_params: Option, } @@ -468,50 +469,37 @@ impl TryFrom<(&str, CompletionRequest)> for DeepseekCompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for DeepSeek"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![Message::system(preamble)], - None => vec![], - }; - - if let Some(docs) = req.normalized_documents() { - let docs: Vec = docs.try_into()?; - full_history.extend(docs); - } - - let chat_history: Vec = req - .chat_history - .clone() - .into_iter() - .map(|message| message.try_into()) - .collect::>, _>>()? - .into_iter() - .flatten() - .collect(); - - full_history.extend(chat_history); - - let tool_choice = req - .tool_choice - .clone() - .map(crate::providers::openrouter::ToolChoice::try_from) - .transpose()?; + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools, + tool_choice, + additional_params, + } = build_compatible_request_core( + model, + req, + CompatibleChatProfile::new("DeepSeek"), + Message::system, + None, + |_| false, + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; + + let tool_choice = + tool_choice.map(crate::providers::openai::completion::CompatibleToolChoice::from); Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - tools: req - .tools - .clone() + model, + messages, + temperature, + tools: tools .into_iter() .map(ToolDefinition::from) .collect::>(), tool_choice, - additional_params: req.additional_params, + additional_params, }) } } @@ -697,6 +685,7 @@ pub struct StreamingDelta { #[derive(Deserialize, Debug)] struct StreamingChoice { delta: StreamingDelta, + finish_reason: Option, } #[derive(Deserialize, Debug)] @@ -742,8 +731,7 @@ where let stream = stream! 
{ let mut final_usage = Usage::new(); - let mut text_response = String::new(); - let mut calls: HashMap = HashMap::new(); + let mut tool_calls: HashMap = HashMap::new(); while let Some(event_result) = event_source.next().await { match event_result { @@ -767,59 +755,47 @@ where let delta = &choice.delta; if !delta.tool_calls.is_empty() { - for tool_call in &delta.tool_calls { - let function = &tool_call.function; - - // Start of tool call - if function.name.as_ref().map(|s| !s.is_empty()).unwrap_or(false) - && empty_or_none(&function.arguments) - { - let id = tool_call.id.clone().unwrap_or_default(); - let name = function.name.clone().unwrap(); - calls.insert(tool_call.index, (id, name, String::new())); - } - // Continuation of tool call - else if function.name.as_ref().map(|s| s.is_empty()).unwrap_or(true) - && let Some(arguments) = &function.arguments - && !arguments.is_empty() - { - if let Some((id, name, existing_args)) = calls.get(&tool_call.index) { - let combined = format!("{}{}", existing_args, arguments); - calls.insert(tool_call.index, (id.clone(), name.clone(), combined)); - } else { - tracing::debug!("Partial tool call received but tool call was never started."); - } - } - // Complete tool call - else { - let id = tool_call.id.clone().unwrap_or_default(); - let name = function.name.clone().unwrap_or_default(); - let arguments_str = function.arguments.clone().unwrap_or_default(); - - let Ok(arguments_json) = json_utils::parse_tool_arguments(&arguments_str) else { - tracing::debug!("Couldn't parse tool call args '{}'", arguments_str); - continue; - }; - - yield Ok(crate::streaming::RawStreamingChoice::ToolCall( - crate::streaming::RawStreamingToolCall::new(id, name, arguments_json) - )); - } + for event in apply_compatible_tool_call_deltas( + &mut tool_calls, + delta.tool_calls.iter().map(|tool_call| CompatibleStreamingToolCall { + index: tool_call.index, + id: tool_call.id.as_deref(), + name: tool_call.function.name.as_deref(), + arguments: tool_call.function.arguments.as_deref(), + }), + ToolCallConflictPolicy::KeepIndex, + ) { + yield Ok(event); } } // DeepSeek-specific reasoning stream - if let Some(content) = &delta.reasoning_content { + if let Some(content) = &delta.reasoning_content + && !content.is_empty() + { yield Ok(crate::streaming::RawStreamingChoice::ReasoningDelta { id: None, reasoning: content.to_string() }); } - if let Some(content) = &delta.content { - text_response += content; + if let Some(content) = &delta.content + && !content.is_empty() + { yield Ok(crate::streaming::RawStreamingChoice::Message(content.clone())); } + + if let Some(finish_reason) = &choice.finish_reason { + if finish_reason == "tool_calls" { + for tool_call in take_finalized_tool_calls(&mut tool_calls) { + yield Ok(tool_call); + } + } + + yield Ok(crate::streaming::RawStreamingChoice::StopReason( + map_finish_reason(finish_reason), + )); + } } if let Some(usage) = data.usage { @@ -839,25 +815,8 @@ where event_source.close(); - let mut tool_calls = Vec::new(); - // Flush accumulated tool calls - for (index, (id, name, arguments)) in calls { - let Ok(arguments_json) = json_utils::parse_tool_arguments(&arguments) else { - continue; - }; - - tool_calls.push(ToolCall { - id: id.clone(), - index, - r#type: ToolType::Function, - function: Function { - name: name.clone(), - arguments: arguments_json.clone() - } - }); - yield Ok(crate::streaming::RawStreamingChoice::ToolCall( - crate::streaming::RawStreamingToolCall::new(id, name, arguments_json) - )); + for tool_call in take_tool_calls(&mut 
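
The streams in this patch now surface the provider's raw `finish_reason` string as a typed stop reason instead of dropping it. A self-contained sketch of that mapping, using a stand-in enum (rig's real `completion::StopReason` variants may differ):

```rust
/// Stand-in for rig's `completion::StopReason`; variants are illustrative.
#[derive(Debug, PartialEq)]
enum StopReason {
    Stop,
    MaxTokens,
    ToolCalls,
    ContentFilter,
    Other(String),
}

/// Maps OpenAI-compatible finish_reason strings to the typed enum,
/// preserving unknown values instead of erroring on them.
fn map_finish_reason(reason: &str) -> StopReason {
    match reason {
        "stop" => StopReason::Stop,
        "length" => StopReason::MaxTokens,
        "tool_calls" => StopReason::ToolCalls,
        "content_filter" => StopReason::ContentFilter,
        other => StopReason::Other(other.to_string()),
    }
}

fn main() {
    assert_eq!(map_finish_reason("stop"), StopReason::Stop);
    assert_eq!(map_finish_reason("length"), StopReason::MaxTokens);
    assert_eq!(
        map_finish_reason("insufficient_system_resource"),
        StopReason::Other("insufficient_system_resource".into())
    );
}
```
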
tool_calls) { + yield Ok(tool_call); } yield Ok(crate::streaming::RawStreamingChoice::FinalResponse( @@ -880,6 +839,44 @@ pub const DEEPSEEK_REASONER: &str = "deepseek-reasoner"; #[cfg(test)] mod tests { use super::*; + use crate::http_client::mock::MockStreamingClient; + use crate::providers::openai::completion::request_conformance; + use crate::streaming::StreamedAssistantContent; + use bytes::Bytes; + use futures::StreamExt; + + struct DeepSeekRequestHarness; + + impl request_conformance::Harness for DeepSeekRequestHarness { + fn family_name() -> &'static str { + "deepseek" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + DeepseekCompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new(CompatibleChatProfile::new( + "DeepSeek", + )) + .preserves_reasoning_assistant_history(), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(DeepSeekRequestHarness); #[test] fn test_deserialize_vec_choice() { @@ -1032,7 +1029,7 @@ mod tests { use crate::completion::message::{Message as RigMessage, UserContent}; let rig_msg = RigMessage::User { - content: OneOrMany::many(vec![ + content: crate::OneOrMany::many(vec![ UserContent::text("first part"), UserContent::text("second part"), ]) @@ -1065,7 +1062,7 @@ mod tests { let rig_msg = RigMessage::Assistant { id: None, - content: OneOrMany::many(vec![ + content: crate::OneOrMany::many(vec![ AssistantContent::reasoning("thinking about the problem"), AssistantContent::text("I'll call the tool"), AssistantContent::tool_call( @@ -1105,7 +1102,7 @@ mod tests { let rig_msg = RigMessage::Assistant { id: None, - content: OneOrMany::many(vec![ + content: crate::OneOrMany::many(vec![ AssistantContent::text("calling tool"), AssistantContent::tool_call("call_1", "add", serde_json::json!({"a": 1, "b": 2})), ]) @@ -1137,4 +1134,75 @@ mod tests { .build() .expect("Client::builder() failed"); } + + #[tokio::test] + async fn test_stream_captures_stop_reason() { + let sse = concat!( + "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"finish_reason\":\"stop\"}],\"usage\":null}\n\n", + "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n\n", + "data: [DONE]\n\n", + ); + let client = MockStreamingClient { + sse_bytes: Bytes::from(sse), + }; + let req = http::Request::builder() + .method("POST") + .uri("http://localhost/v1/chat/completions") + .body(Vec::new()) + .expect("request should build"); + + let mut stream = send_compatible_streaming_request(client, req) + .await + .expect("stream should start"); + + while let Some(chunk) = stream.next().await { + chunk.expect("stream chunk should deserialize"); + } + + assert_eq!( + stream.stop_reason, + Some(crate::completion::StopReason::Stop) + ); + } + + #[tokio::test] + async fn test_stream_accumulates_tool_call_arguments_until_finish_reason() { + let sse = concat!( + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_123\",\"function\":{\"name\":\"subtract\",\"arguments\":\"\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n", + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"x\\\":2,\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n", + 
"data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\"y\\\":5}\"}}]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}\n\n", + "data: [DONE]\n\n", + ); + let client = MockStreamingClient { + sse_bytes: Bytes::from(sse), + }; + let req = http::Request::builder() + .method("POST") + .uri("http://localhost/v1/chat/completions") + .body(Vec::new()) + .expect("request should build"); + + let mut stream = send_compatible_streaming_request(client, req) + .await + .expect("stream should start"); + let mut collected_tool_calls = Vec::new(); + + while let Some(chunk) = stream.next().await { + if let StreamedAssistantContent::ToolCall { + tool_call, + internal_call_id: _, + } = chunk.expect("stream chunk should deserialize") + { + collected_tool_calls.push(tool_call); + } + } + + assert_eq!(collected_tool_calls.len(), 1); + assert_eq!(collected_tool_calls[0].id, "call_123"); + assert_eq!(collected_tool_calls[0].function.name, "subtract"); + assert_eq!( + collected_tool_calls[0].function.arguments, + serde_json::json!({"x": 2, "y": 5}) + ); + } } diff --git a/rig/rig-core/src/providers/galadriel.rs b/rig/rig-core/src/providers/galadriel.rs index ea4e6e643..b68784d23 100644 --- a/rig/rig-core/src/providers/galadriel.rs +++ b/rig/rig-core/src/providers/galadriel.rs @@ -235,15 +235,21 @@ impl TryFrom for completion::CompletionResponse Result { - let Choice { message, .. } = response.choices.first().ok_or_else(|| { - CompletionError::ResponseError("Response contained no choices".to_owned()) - })?; + let Choice { + message, + finish_reason, + .. + } = openai::completion::first_choice(&response.choices)?; + let stop_reason = Some(crate::providers::openai::completion::map_finish_reason( + finish_reason, + )); let mut content = message .content .as_ref() - .map(|c| vec![completion::AssistantContent::text(c)]) - .unwrap_or_default(); + .and_then(openai::completion::non_empty_text) + .into_iter() + .collect::>(); content.extend(message.tool_calls.iter().map(|call| { completion::AssistantContent::tool_call( @@ -253,11 +259,6 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse, } -impl TryFrom<(&str, CompletionRequest)> for GaladrielCompletionRequest { - type Error = CompletionError; +fn galadriel_request_messages(message: message::Message) -> Result, CompletionError> { + match message { + message::Message::System { content } => Ok(vec![Message { + role: "system".to_string(), + content: Some(content), + tool_calls: vec![], + }]), + message::Message::User { content } => { + let text = content + .into_iter() + .filter_map(|item| match item { + message::UserContent::Text(text) => Some(text.text), + message::UserContent::Document(document) => match document.data { + crate::message::DocumentSourceKind::Base64(content) + | crate::message::DocumentSourceKind::String(content) => Some(content), + _ => None, + }, + _ => None, + }) + .collect::>() + .join("\n"); - fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for Galadriel"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - // Build up the order of messages (context, chat_history, prompt) - let mut partial_history = vec![]; - if let Some(docs) = req.normalized_documents() { - partial_history.push(docs); + if text.is_empty() { + Ok(vec![]) + } else { + Ok(vec![Message { + role: "user".to_string(), + content: Some(text), + tool_calls: vec![], + }]) + } } - 
partial_history.extend(req.chat_history); + message::Message::Assistant { content, .. } => { + let mut text_content: Option = None; + let mut tool_calls = vec![]; + + for item in content { + match item { + message::AssistantContent::Text(text) => { + let text = text.text; + text_content = Some( + text_content + .map(|mut existing| { + existing.push('\n'); + existing.push_str(&text); + existing + }) + .unwrap_or(text), + ); + } + message::AssistantContent::ToolCall(tool_call) => { + tool_calls.push(tool_call.into()); + } + message::AssistantContent::Reasoning(_) => {} + message::AssistantContent::Image(_) => { + return Err(CompletionError::RequestError( + MessageError::ConversionError( + "Galadriel currently doesn't support images.".into(), + ) + .into(), + )); + } + } + } - // Add preamble to chat history (if available) - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![Message::system(preamble)], - None => vec![], - }; + if text_content.is_none() && tool_calls.is_empty() { + Ok(vec![]) + } else { + Ok(vec![Message { + role: "assistant".to_string(), + content: text_content, + tool_calls, + }]) + } + } + } +} - // Convert and extend the rest of the history - full_history.extend( - partial_history - .into_iter() - .map(message::Message::try_into) - .collect::, _>>()?, - ); +impl TryFrom<(&str, CompletionRequest)> for GaladrielCompletionRequest { + type Error = CompletionError; - let tool_choice = req - .tool_choice - .clone() + fn try_from((model, req): (&str, CompletionRequest)) -> Result { + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools, + tool_choice, + additional_params, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("Galadriel"), + Message::system, + None, + |_| false, + galadriel_request_messages, + )?; + + let tool_choice = tool_choice .map(crate::providers::openai::completion::ToolChoice::try_from) .transpose()?; Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - tools: req - .tools - .clone() + model, + messages, + temperature, + tools: tools .into_iter() .map(ToolDefinition::from) .collect::>(), tool_choice, - additional_params: req.additional_params, + additional_params, }) } } @@ -668,6 +737,41 @@ where } #[cfg(test)] mod tests { + use super::GaladrielCompletionRequest; + use crate::providers::openai::completion::{CompatibleChatProfile, request_conformance}; + + struct GaladrielRequestHarness; + + impl request_conformance::Harness for GaladrielRequestHarness { + fn family_name() -> &'static str { + "galadriel" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + GaladrielCompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new(CompatibleChatProfile::new( + "Galadriel", + )), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(GaladrielRequestHarness); + #[test] fn test_client_initialization() { let _client = diff --git a/rig/rig-core/src/providers/gemini/completion.rs b/rig/rig-core/src/providers/gemini/completion.rs index 80d908848..7e94805ce 100644 --- 
a/rig/rig-core/src/providers/gemini/completion.rs +++ b/rig/rig-core/src/providers/gemini/completion.rs @@ -27,15 +27,12 @@ use self::gemini_api_types::Schema; use crate::http_client::HttpClientExt; use crate::message::{self, MimeType, Reasoning}; +use crate::completion::{self, CompletionError, CompletionRequest}; use crate::providers::gemini::completion::gemini_api_types::{ AdditionalParameters, FunctionCallingMode, ToolConfig, }; use crate::providers::gemini::streaming::StreamingCompletionResponse; use crate::telemetry::SpanCombinator; -use crate::{ - OneOrMany, - completion::{self, CompletionError, CompletionRequest}, -}; use gemini_api_types::{ Content, FunctionDeclaration, GenerateContentRequest, GenerateContentResponse, GenerationConfig, Part, PartKind, Role, Tool, @@ -51,6 +48,15 @@ use super::Client; // Rig Implementation Types // ================================================================= +pub(crate) fn map_finish_reason(reason: &gemini_api_types::FinishReason) -> completion::StopReason { + match reason { + gemini_api_types::FinishReason::Stop => completion::StopReason::Stop, + gemini_api_types::FinishReason::MaxTokens => completion::StopReason::MaxTokens, + gemini_api_types::FinishReason::Safety => completion::StopReason::Safety, + other => completion::StopReason::Other(format!("{other:?}")), + } +} + #[derive(Clone, Debug)] pub struct CompletionModel { pub(crate) client: Client, @@ -493,11 +499,7 @@ impl TryFrom for completion::CompletionResponse, _>>()?; - let choice = OneOrMany::many(content).map_err(|_| { - CompletionError::ResponseError( - "Response contained no message or tool call (empty)".to_owned(), - ) - })?; + let choice = completion::AssistantChoice::from(content); let usage = response .usage_metadata @@ -514,6 +516,11 @@ impl TryFrom for completion::CompletionResponse &'static str { + "gemini-generate-content" + } + + fn expected(case: Fixture) -> Outcome { + match case { + Fixture::EmptyAssistantTurnAfterToolResult => Outcome::Supported(Turn { + items: vec![], + message_id: None, + stop_reason: Some(StopReason::Stop), + }), + Fixture::ToolOnlyTurn => Outcome::Supported(Turn { + items: vec![NormalizedItem::ToolCall { + id: "lookup_weather".to_string(), + name: "lookup_weather".to_string(), + arguments: json!({"city": "Paris"}), + }], + message_id: None, + stop_reason: Some(StopReason::Stop), + }), + Fixture::TextAndToolCallTurn => Outcome::Supported(Turn { + items: vec![ + NormalizedItem::Text("Need weather data first.".to_string()), + NormalizedItem::ToolCall { + id: "lookup_weather".to_string(), + name: "lookup_weather".to_string(), + arguments: json!({"city": "Paris"}), + }, + ], + message_id: None, + stop_reason: Some(StopReason::Stop), + }), + Fixture::EmptyTextBlocks => Outcome::Supported(Turn { + items: vec![], + message_id: None, + stop_reason: Some(StopReason::Stop), + }), + Fixture::ReasoningOnlyTurn => Outcome::Supported(Turn { + items: vec![NormalizedItem::Reasoning( + "Need to reason about the tool result.".to_string(), + )], + message_id: None, + stop_reason: Some(StopReason::Stop), + }), + Fixture::MessageIdPreservation => { + Outcome::Unsupported("Gemini GenerateContent responses do not expose message IDs") + } + Fixture::StopReasonMapping => Outcome::Supported(Turn { + items: vec![NormalizedItem::Text("Truncated response".to_string())], + message_id: None, + stop_reason: Some(StopReason::MaxTokens), + }), + } + } + + fn non_stream(case: Fixture) -> Result, CompletionError> { + match case { + Fixture::MessageIdPreservation => 
Ok(Self::expected(case)), + _ => { + let raw = non_stream_response(case); + let response: completion::CompletionResponse = + raw.try_into()?; + Ok(Outcome::Supported(normalize_completion_response(&response))) + } + } + } + + fn stream(case: Fixture) -> BoxFuture, CompletionError>> { + Box::pin(async move { + match case { + Fixture::MessageIdPreservation => Ok(Self::expected(case)), + _ => { + let client = crate::providers::gemini::Client::builder() + .http_client(MockStreamingClient { + sse_bytes: Bytes::from(streaming_sse(case)), + }) + .api_key("test-key") + .build() + .expect("client should build"); + let model = CompletionModel::new(client, "gemini-test"); + let stream = model.stream(stream_request()).await?; + let response = drain_stream(stream).await?; + Ok(Outcome::Supported(normalize_completion_response(&response))) + } + } + }) + } +} + +fn non_stream_response(case: Fixture) -> GenerateContentResponse { + let candidate = match case { + Fixture::EmptyAssistantTurnAfterToolResult => candidate(vec![], FinishReason::Stop), + Fixture::ToolOnlyTurn => candidate( + vec![Part { + thought: None, + thought_signature: None, + part: PartKind::FunctionCall(gemini_api_types::FunctionCall { + name: "lookup_weather".to_string(), + args: json!({"city": "Paris"}), + }), + additional_params: None, + }], + FinishReason::Stop, + ), + Fixture::TextAndToolCallTurn => candidate( + vec![ + Part { + thought: Some(false), + thought_signature: None, + part: PartKind::Text("Need weather data first.".to_string()), + additional_params: None, + }, + Part { + thought: None, + thought_signature: None, + part: PartKind::FunctionCall(gemini_api_types::FunctionCall { + name: "lookup_weather".to_string(), + args: json!({"city": "Paris"}), + }), + additional_params: None, + }, + ], + FinishReason::Stop, + ), + Fixture::EmptyTextBlocks => candidate( + vec![Part { + thought: Some(false), + thought_signature: None, + part: PartKind::Text(String::new()), + additional_params: None, + }], + FinishReason::Stop, + ), + Fixture::ReasoningOnlyTurn => candidate( + vec![Part { + thought: Some(true), + thought_signature: Some("sig_1".to_string()), + part: PartKind::Text("Need to reason about the tool result.".to_string()), + additional_params: None, + }], + FinishReason::Stop, + ), + Fixture::StopReasonMapping => candidate( + vec![Part { + thought: Some(false), + thought_signature: None, + part: PartKind::Text("Truncated response".to_string()), + additional_params: None, + }], + FinishReason::MaxTokens, + ), + Fixture::MessageIdPreservation => unreachable!(), + }; + + GenerateContentResponse { + response_id: "resp_123".to_string(), + candidates: vec![candidate], + prompt_feedback: None, + usage_metadata: None, + model_version: None, + } +} + +fn candidate(parts: Vec, finish_reason: FinishReason) -> ContentCandidate { + ContentCandidate { + content: Some(Content { + parts, + role: Some(Role::Model), + }), + finish_reason: Some(finish_reason), + safety_ratings: None, + citation_metadata: None, + token_count: None, + avg_logprobs: None, + logprobs_result: None, + index: Some(0), + finish_message: None, + } +} + +fn stream_request() -> completion::CompletionRequest { + completion::CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(completion::Message::user("hello")), + documents: vec![], + tools: vec![], + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + } +} + +fn streaming_sse(case: Fixture) -> String { + match case { + 
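
The conformance harnesses added throughout this patch all follow the same shape: a shared fixture enum, a per-provider `expected` outcome, and a `run` path driven through the real conversion code. A minimal stand-alone version of the idea; the trait and names here are hypothetical, not rig's `request_conformance`/`provider_conformance_tests!` machinery:

```rust
#[derive(Clone, Copy, Debug)]
enum Fixture {
    ToolOnlyTurn,
    StopReasonMapping,
}

#[derive(Debug, PartialEq)]
enum Outcome {
    Supported(&'static str),
    Unsupported(&'static str),
}

trait Harness {
    /// What the provider should produce for this fixture.
    fn expected(case: Fixture) -> Outcome;
    /// What the provider actually produced.
    fn run(case: Fixture) -> Outcome;
}

struct DemoHarness;

impl Harness for DemoHarness {
    fn expected(case: Fixture) -> Outcome {
        match case {
            Fixture::ToolOnlyTurn => Outcome::Supported("tool call emitted"),
            Fixture::StopReasonMapping => Outcome::Supported("MAX_TOKENS mapped"),
        }
    }

    fn run(case: Fixture) -> Outcome {
        // A real harness drives the provider's request/response conversion here.
        Self::expected(case)
    }
}

fn main() {
    for case in [Fixture::ToolOnlyTurn, Fixture::StopReasonMapping] {
        assert_eq!(DemoHarness::run(case), DemoHarness::expected(case), "{case:?}");
    }
}
```
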
Fixture::EmptyAssistantTurnAfterToolResult => "data: {\"candidates\":[{\"content\":{\"parts\":[],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":0,\"totalTokenCount\":10}}\n\n".to_string(), + Fixture::ToolOnlyTurn => "data: {\"candidates\":[{\"content\":{\"parts\":[{\"functionCall\":{\"name\":\"lookup_weather\",\"args\":{\"city\":\"Paris\"}}}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":5,\"totalTokenCount\":15}}\n\n".to_string(), + Fixture::TextAndToolCallTurn => "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Need weather data first.\"},{\"functionCall\":{\"name\":\"lookup_weather\",\"args\":{\"city\":\"Paris\"}}}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":7,\"totalTokenCount\":17}}\n\n".to_string(), + Fixture::EmptyTextBlocks => "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"\"}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":0,\"totalTokenCount\":10}}\n\n".to_string(), + Fixture::ReasoningOnlyTurn => "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Need to reason about the tool result.\",\"thought\":true,\"thoughtSignature\":\"sig_1\"}],\"role\":\"model\"},\"finishReason\":\"STOP\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":5,\"thoughtsTokenCount\":3,\"totalTokenCount\":18}}\n\n".to_string(), + Fixture::StopReasonMapping => "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Truncated response\"}],\"role\":\"model\"},\"finishReason\":\"MAX_TOKENS\",\"index\":0}],\"usageMetadata\":{\"promptTokenCount\":10,\"candidatesTokenCount\":2,\"totalTokenCount\":12}}\n\n".to_string(), + Fixture::MessageIdPreservation => unreachable!(), + } +} + +provider_conformance_tests!(GeminiHarness); diff --git a/rig/rig-core/src/providers/gemini/interactions_api/mod.rs b/rig/rig-core/src/providers/gemini/interactions_api/mod.rs index cba503db2..65a77b135 100644 --- a/rig/rig-core/src/providers/gemini/interactions_api/mod.rs +++ b/rig/rig-core/src/providers/gemini/interactions_api/mod.rs @@ -1,7 +1,6 @@ //! Google Gemini Interactions API integration. //! 
From -use crate::OneOrMany; use crate::completion::{self, CompletionError, CompletionRequest, GetTokenUsage}; use crate::http_client::HttpClientExt; use crate::message::{self, MimeType, Reasoning}; @@ -490,11 +489,7 @@ impl TryFrom for completion::CompletionResponse { }) .collect::, _>>()?; - let choice = OneOrMany::many(content).map_err(|_| { - CompletionError::ResponseError( - "Response contained no message or tool call (empty)".to_owned(), - ) - })?; + let choice = completion::AssistantChoice::from(content); let usage = response .usage @@ -507,6 +502,7 @@ impl TryFrom for completion::CompletionResponse { usage, raw_response: response, message_id: None, + stop_reason: None, }) } } @@ -2502,7 +2498,7 @@ mod tests { let choice = response.choice.first(); match choice { - completion::AssistantContent::ToolCall(tool_call) => { + Some(completion::AssistantContent::ToolCall(tool_call)) => { assert_eq!(tool_call.function.name, "get_weather"); assert_eq!(tool_call.call_id.as_deref(), Some("call-123")); } diff --git a/rig/rig-core/src/providers/gemini/streaming.rs b/rig/rig-core/src/providers/gemini/streaming.rs index 1bb2f7b1e..7da7b2fc5 100644 --- a/rig/rig-core/src/providers/gemini/streaming.rs +++ b/rig/rig-core/src/providers/gemini/streaming.rs @@ -6,7 +6,8 @@ use tracing_futures::Instrument; use super::completion::gemini_api_types::{ContentCandidate, Part, PartKind}; use super::completion::{ - CompletionModel, create_request_body, resolve_request_model, streaming_endpoint, + CompletionModel, create_request_body, map_finish_reason, resolve_request_model, + streaming_endpoint, }; use crate::completion::message::ReasoningContent; use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage}; @@ -211,6 +212,11 @@ where // Check if this is the final response if choice.finish_reason.is_some() { + if let Some(stop_reason) = + choice.finish_reason.as_ref().map(map_finish_reason) + { + yield Ok(streaming::RawStreamingChoice::StopReason(stop_reason)); + } let span = tracing::Span::current(); span.record_token_usage(&data.usage_metadata); final_usage = data.usage_metadata; diff --git a/rig/rig-core/src/providers/groq.rs b/rig/rig-core/src/providers/groq.rs index 88ba5cc4e..80e0006d9 100644 --- a/rig/rig-core/src/providers/groq.rs +++ b/rig/rig-core/src/providers/groq.rs @@ -34,7 +34,6 @@ use futures::StreamExt; use crate::{ completion::{self, CompletionError, CompletionRequest}, json_utils, - message::{self}, providers::openai::ToolDefinition, transcription::{self, TranscriptionError}, }; @@ -184,42 +183,30 @@ pub(super) struct StreamOptions { impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest { type Error = CompletionError; - fn try_from((model, mut req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for Groq"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - // Build up the order of messages (context, chat_history, prompt) - let mut partial_history = vec![]; - if let Some(docs) = req.normalized_documents() { - partial_history.push(docs); - } - partial_history.extend(req.chat_history); - - // Add preamble to chat history (if available) - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![OpenAIMessage::system(preamble)], - None => vec![], - }; - - // Convert and extend the rest of the history - full_history.extend( - partial_history - .into_iter() - .map(message::Message::try_into) - .collect::>, _>>()? 
- .into_iter() - .flatten() - .collect::>(), - ); - - let tool_choice = req - .tool_choice - .clone() + fn try_from((model, req): (&str, CompletionRequest)) -> Result { + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools, + tool_choice, + additional_params, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("Groq"), + OpenAIMessage::system, + None, + |_| false, + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; + + let tool_choice = tool_choice .map(crate::providers::openai::ToolChoice::try_from) .transpose()?; - let mut additional_params_payload = req.additional_params.take().unwrap_or(Value::Null); + let mut additional_params_payload = additional_params.unwrap_or(Value::Null); let native_tools = extract_native_tools_from_additional_params(&mut additional_params_payload)?; @@ -232,12 +219,10 @@ impl TryFrom<(&str, CompletionRequest)> for GroqCompletionRequest { apply_native_tools_to_additional_params(&mut additional_params, native_tools); Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - tools: req - .tools - .clone() + model, + messages, + temperature, + tools: tools .into_iter() .map(ToolDefinition::from) .collect::>(), @@ -822,12 +807,46 @@ where mod tests { use crate::{ OneOrMany, + completion::CompletionRequest, providers::{ groq::{GroqAdditionalParameters, GroqCompletionRequest}, + openai::completion::{CompatibleChatProfile, request_conformance}, openai::{Message, UserContent}, }, }; + struct GroqRequestHarness; + + impl request_conformance::Harness for GroqRequestHarness { + fn family_name() -> &'static str { + "groq" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + GroqCompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new(CompatibleChatProfile::new( + "Groq", + )), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(GroqRequestHarness); + #[test] fn serialize_groq_request() { let additional_params = GroqAdditionalParameters { @@ -870,6 +889,28 @@ mod tests { }) ) } + + #[test] + fn groq_request_uses_request_model_override() { + let request = CompletionRequest { + model: Some("groq-override".to_string()), + preamble: Some("system".to_string()), + chat_history: OneOrMany::one("hello".into()), + documents: vec![], + tools: vec![], + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + }; + + let converted = GroqCompletionRequest::try_from(("groq-default", request)) + .expect("request should convert"); + + assert_eq!(converted.model, "groq-override"); + } + #[test] fn test_client_initialization() { let _client = diff --git a/rig/rig-core/src/providers/huggingface/completion.rs b/rig/rig-core/src/providers/huggingface/completion.rs index fe521313d..0d9b7eaef 100644 --- a/rig/rig-core/src/providers/huggingface/completion.rs +++ b/rig/rig-core/src/providers/huggingface/completion.rs @@ -549,9 +549,10 @@ impl TryFrom for completion::CompletionResponse Result { - let choice = 
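
Each provider above previously hand-rolled the same history assembly that the shared request core now enforces: system preamble first, then document context, then chat history. A stand-alone sketch of that ordering rule under hypothetical types (not rig's `build_compatible_request_core` signature):

```rust
struct Msg {
    role: &'static str,
    content: String,
}

/// Assembles provider message history in the canonical order:
/// preamble -> documents -> chat history.
fn build_history(
    preamble: Option<&str>,
    documents: &[&str],
    chat_history: Vec<Msg>,
) -> Vec<Msg> {
    let mut out = Vec::new();
    if let Some(p) = preamble {
        out.push(Msg { role: "system", content: p.to_string() });
    }
    if !documents.is_empty() {
        // Documents are folded into a single user-context message.
        out.push(Msg { role: "user", content: documents.join("\n") });
    }
    out.extend(chat_history);
    out
}

fn main() {
    let history = build_history(
        Some("You are terse."),
        &["doc one"],
        vec![Msg { role: "user", content: "hello".into() }],
    );
    assert_eq!(history[0].role, "system");
    assert_eq!(history[1].content, "doc one");
    assert_eq!(history[2].content, "hello");
}
```
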
response.choices.first().ok_or_else(|| { - CompletionError::ResponseError("Response contained no choices".to_owned()) - })?; + let choice = crate::providers::openai::completion::first_choice(&response.choices)?; + let stop_reason = Some(crate::providers::openai::completion::map_finish_reason( + &choice.finish_reason, + )); let content = match &choice.message { Message::Assistant { @@ -561,8 +562,10 @@ impl TryFrom for completion::CompletionResponse { let mut content = content .iter() - .map(|c| match c { - AssistantContent::Text { text } => message::AssistantContent::text(text), + .filter_map(|c| match c { + AssistantContent::Text { text } => { + crate::providers::openai::completion::non_empty_text(text) + } }) .collect::>(); @@ -585,12 +588,6 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse for HuggingfaceCompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for Huggingface"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![Message::system(preamble)], - None => vec![], - }; - if let Some(docs) = req.normalized_documents() { - let docs: Vec = docs.try_into()?; - full_history.extend(docs); - } - - let chat_history: Vec = req - .chat_history - .clone() - .into_iter() - .map(|message| message.try_into()) - .collect::>, _>>()? - .into_iter() - .flatten() - .collect(); - - full_history.extend(chat_history); - - if full_history.is_empty() { - return Err(CompletionError::RequestError( - std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "HuggingFace request has no provider-compatible messages after conversion", - ) - .into(), - )); - } - - let tool_choice = req - .tool_choice - .clone() + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools, + tool_choice, + additional_params, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("HuggingFace") + .require_messages(), + Message::system, + None, + |_| false, + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; + + let tool_choice = tool_choice .map(crate::providers::openai::completion::ToolChoice::try_from) .transpose()?; Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - tools: req - .tools - .clone() + model, + messages, + temperature, + tools: tools .into_iter() .map(ToolDefinition::from) .collect::>(), tool_choice, - additional_params: req.additional_params, + additional_params, }) } } @@ -818,8 +798,42 @@ where #[cfg(test)] mod tests { use super::*; + use crate::providers::openai::completion::request_conformance; use serde_path_to_error::deserialize; + struct HuggingFaceRequestHarness; + + impl request_conformance::Harness for HuggingFaceRequestHarness { + fn family_name() -> &'static str { + "huggingface" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + HuggingfaceCompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + 
request_conformance::CompatibleChatExpectation::new( + crate::providers::openai::completion::CompatibleChatProfile::new("HuggingFace") + .require_messages(), + ), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(HuggingFaceRequestHarness); + #[test] fn test_huggingface_request_uses_request_model_override() { let request = CompletionRequest { diff --git a/rig/rig-core/src/providers/hyperbolic.rs b/rig/rig-core/src/providers/hyperbolic.rs index 38345249c..cf95814cc 100644 --- a/rig/rig-core/src/providers/hyperbolic.rs +++ b/rig/rig-core/src/providers/hyperbolic.rs @@ -17,7 +17,6 @@ use crate::streaming::StreamingCompletionResponse; use crate::providers::openai; use crate::{ - OneOrMany, completion::{self, CompletionError, CompletionRequest}, json_utils, providers::openai::Message, @@ -179,9 +178,10 @@ impl TryFrom for completion::CompletionResponse Result { - let choice = response.choices.first().ok_or_else(|| { - CompletionError::ResponseError("Response contained no choices".to_owned()) - })?; + let choice = openai::completion::first_choice(&response.choices)?; + let stop_reason = Some(crate::providers::openai::completion::map_finish_reason( + &choice.finish_reason, + )); let content = match &choice.message { Message::Assistant { @@ -191,10 +191,10 @@ impl TryFrom for completion::CompletionResponse { let mut content = content .iter() - .map(|c| match c { - AssistantContent::Text { text } => completion::AssistantContent::text(text), + .filter_map(|c| match c { + AssistantContent::Text { text } => openai::completion::non_empty_text(text), AssistantContent::Refusal { refusal } => { - completion::AssistantContent::text(refusal) + openai::completion::non_empty_text(refusal) } }) .collect::>(); @@ -218,12 +218,6 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse for HyperbolicCompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for Hyperbolic"); - } - - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - if req.tool_choice.is_some() { - tracing::warn!("WARNING: `tool_choice` not supported on Hyperbolic"); - } - - if !req.tools.is_empty() { - tracing::warn!("WARNING: `tools` not supported on Hyperbolic"); - } - - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![Message::system(preamble)], - None => vec![], - }; - - if let Some(docs) = req.normalized_documents() { - let docs: Vec = docs.try_into()?; - full_history.extend(docs); - } - - let chat_history: Vec = req - .chat_history - .clone() - .into_iter() - .map(|message| message.try_into()) - .collect::>, _>>()? 
- .into_iter() - .flatten() - .collect(); - - full_history.extend(chat_history); + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools: _, + tool_choice: _, + additional_params, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("Hyperbolic") + .unsupported_tools() + .unsupported_tool_choice(), + Message::system, + None, + |_| false, + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - additional_params: req.additional_params, + model, + messages, + temperature, + additional_params, }) } } @@ -708,6 +688,43 @@ mod audio_generation { #[cfg(test)] mod tests { + use super::HyperbolicCompletionRequest; + use crate::providers::openai::completion::{CompatibleChatProfile, request_conformance}; + + struct HyperbolicRequestHarness; + + impl request_conformance::Harness for HyperbolicRequestHarness { + fn family_name() -> &'static str { + "hyperbolic" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + HyperbolicCompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + CompatibleChatProfile::new("Hyperbolic") + .unsupported_tools() + .unsupported_tool_choice(), + ), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(HyperbolicRequestHarness); + #[test] fn test_client_initialization() { let _client = diff --git a/rig/rig-core/src/providers/llamafile.rs b/rig/rig-core/src/providers/llamafile.rs index 8fabb721f..db97497a5 100644 --- a/rig/rig-core/src/providers/llamafile.rs +++ b/rig/rig-core/src/providers/llamafile.rs @@ -164,45 +164,35 @@ impl TryFrom<(&str, CompletionRequest)> for LlamafileCompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs may not be supported by llamafile"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - - // Build message history: preamble -> documents -> chat history - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![openai::Message::system(preamble)], - None => vec![], - }; - - if let Some(docs) = req.normalized_documents() { - let docs: Vec = docs.try_into()?; - full_history.extend(docs); - } - - let chat_history: Vec = req - .chat_history - .clone() - .into_iter() - .map(|msg| msg.try_into()) - .collect::>, _>>()? 
- .into_iter() - .flatten() - .collect(); - - full_history.extend(chat_history); + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens, + tools, + tool_choice: _, + additional_params, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("llamafile") + .unsupported_tool_choice(), + openai::Message::system, + None, + |_| false, + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; Ok(Self { model, - messages: full_history, - temperature: req.temperature, - max_tokens: req.max_tokens, - tools: req - .tools + messages, + temperature, + max_tokens, + tools: tools .into_iter() .map(openai::ToolDefinition::from) .collect(), - additional_params: req.additional_params, + additional_params, }) } } @@ -651,6 +641,40 @@ where mod tests { use super::*; use crate::client::Nothing; + use crate::providers::openai::completion::request_conformance; + + struct LlamafileRequestHarness; + + impl request_conformance::Harness for LlamafileRequestHarness { + fn family_name() -> &'static str { + "llamafile" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + LlamafileCompletionRequest::try_from((LLAMA_CPP, request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + crate::providers::openai::completion::CompatibleChatProfile::new("llamafile") + .unsupported_tool_choice(), + ), + LLAMA_CPP, + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(LlamafileRequestHarness); #[test] fn test_client_initialization() { @@ -698,4 +722,25 @@ mod tests { assert_eq!(request.temperature, Some(0.7)); assert_eq!(request.max_tokens, Some(256)); } + + #[test] + fn test_llamafile_request_uses_request_model_override() { + let completion_request = CompletionRequest { + model: Some("custom-llamafile".to_string()), + preamble: None, + chat_history: crate::OneOrMany::one(crate::completion::Message::user("Hello")), + documents: vec![], + tools: vec![], + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + }; + + let request = LlamafileCompletionRequest::try_from((LLAMA_CPP, completion_request)) + .expect("Failed to create request"); + + assert_eq!(request.model, "custom-llamafile"); + } } diff --git a/rig/rig-core/src/providers/mira.rs b/rig/rig-core/src/providers/mira.rs index 9b9448007..8fd9052dd 100644 --- a/rig/rig-core/src/providers/mira.rs +++ b/rig/rig-core/src/providers/mira.rs @@ -221,82 +221,90 @@ pub(super) struct MiraCompletionRequest { pub stream: bool, } -impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest { - type Error = CompletionError; - - fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for Mira"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - let mut messages = Vec::new(); - - if let Some(content) = &req.preamble { - messages.push(RawMessage { - role: "user".to_string(), - content: content.to_string(), - }); - } - - if let Some(Message::User { content }) = req.normalized_documents() { +fn mira_request_messages(message: 
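
The profile passed into the shared core declares the capabilities a provider lacks (tools, tool_choice, additional params), replacing the scattered `tracing::warn!` calls these hunks delete. A minimal sketch of the builder idea; field and method names are hypothetical approximations of `CompatibleChatProfile`:

```rust
/// Declares what a provider cannot accept so the shared request core
/// can warn once and strip the unsupported fields.
struct Profile {
    name: &'static str,
    tools_supported: bool,
    tool_choice_supported: bool,
}

impl Profile {
    fn new(name: &'static str) -> Self {
        Self { name, tools_supported: true, tool_choice_supported: true }
    }

    fn unsupported_tools(mut self) -> Self {
        self.tools_supported = false;
        self
    }

    fn unsupported_tool_choice(mut self) -> Self {
        self.tool_choice_supported = false;
        self
    }
}

fn main() {
    let profile = Profile::new("Mira AI")
        .unsupported_tools()
        .unsupported_tool_choice();
    assert!(!profile.tools_supported);
    assert!(!profile.tool_choice_supported);
    println!("{} accepts tools: {}", profile.name, profile.tools_supported);
}
```
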
Message) -> Result, CompletionError> { + match message { + Message::System { content } => Ok(vec![RawMessage { + role: "system".to_string(), + content, + }]), + Message::User { content } => { let text = content .into_iter() - .filter_map(|doc| match doc { + .filter_map(|item| match item { + UserContent::Text(text) => Some(text.text), UserContent::Document(Document { data: DocumentSourceKind::Base64(data) | DocumentSourceKind::String(data), .. }) => Some(data), - UserContent::Text(text) => Some(text.text), - - // This should always be `Document` _ => None, }) .collect::>() .join("\n"); - messages.push(RawMessage { - role: "user".to_string(), - content: text, - }); + if text.is_empty() { + Ok(vec![]) + } else { + Ok(vec![RawMessage { + role: "user".to_string(), + content: text, + }]) + } } + Message::Assistant { content, .. } => { + let text = content + .into_iter() + .filter_map(|item| match item { + AssistantContent::Text(text) => Some(text.text), + _ => None, + }) + .collect::>() + .join("\n"); - for msg in req.chat_history { - let (role, content) = match msg { - Message::System { content } => ("system", content), - Message::User { content } => { - let text = content - .iter() - .map(|c| match c { - UserContent::Text(text) => &text.text, - _ => "", - }) - .collect::>() - .join("\n"); - ("user", text) - } - Message::Assistant { content, .. } => { - let text = content - .iter() - .map(|c| match c { - AssistantContent::Text(text) => &text.text, - _ => "", - }) - .collect::>() - .join("\n"); - ("assistant", text) - } - }; - messages.push(RawMessage { - role: role.to_string(), - content, - }); + if text.is_empty() { + Ok(vec![]) + } else { + Ok(vec![RawMessage { + role: "assistant".to_string(), + content: text, + }]) + } } + } +} + +impl TryFrom<(&str, CompletionRequest)> for MiraCompletionRequest { + type Error = CompletionError; + + fn try_from((model, req): (&str, CompletionRequest)) -> Result { + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens, + tools: _, + tool_choice: _, + additional_params: _, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("Mira AI") + .unsupported_tools() + .unsupported_tool_choice() + .unsupported_additional_params(), + |content| RawMessage { + role: "user".to_string(), + content: content.to_string(), + }, + None, + |_| false, + mira_request_messages, + )?; Ok(Self { - model: model.to_string(), + model, messages, - temperature: req.temperature, - max_tokens: req.max_tokens, + temperature, + max_tokens, stream: false, }) } @@ -355,21 +363,6 @@ where span.record("gen_ai.system_instructions", &completion_request.preamble); - if !completion_request.tools.is_empty() { - tracing::warn!(target: "rig::completions", - "Tool calls are not supported by Mira AI. 
{len} tools will be ignored.", - len = completion_request.tools.len() - ); - } - - if completion_request.tool_choice.is_some() { - tracing::warn!("WARNING: `tool_choice` not supported on Mira AI"); - } - - if completion_request.additional_params.is_some() { - tracing::warn!("WARNING: Additional parameters not supported on Mira AI"); - } - let request = MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?; if tracing::enabled!(tracing::Level::TRACE) { @@ -460,20 +453,6 @@ where span.record("gen_ai.system_instructions", &completion_request.preamble); - if !completion_request.tools.is_empty() { - tracing::warn!(target: "rig::completions", - "Tool calls are not supported by Mira AI. {len} tools will be ignored.", - len = completion_request.tools.len() - ); - } - - if completion_request.tool_choice.is_some() { - tracing::warn!("WARNING: `tool_choice` not supported on Mira AI"); - } - - if completion_request.additional_params.is_some() { - tracing::warn!("WARNING: Additional parameters not supported on Mira AI"); - } let mut request = MiraCompletionRequest::try_from((self.model.as_ref(), completion_request))?; request.stream = true; @@ -577,17 +556,21 @@ impl TryFrom for completion::CompletionResponse choices + .first() + .and_then(|choice| choice.finish_reason.as_deref()) + .map(crate::providers::openai::completion::map_finish_reason), + CompletionResponse::Simple(_) => None, + }; Ok(completion::CompletionResponse { choice, usage, raw_response: response, message_id: None, + stop_reason, }) } } @@ -701,8 +684,45 @@ impl TryFrom for Message { mod tests { use super::*; use crate::message::UserContent; + use crate::providers::openai::completion::{CompatibleChatProfile, request_conformance}; use serde_json::json; + struct MiraRequestHarness; + + impl request_conformance::Harness for MiraRequestHarness { + fn family_name() -> &'static str { + "mira" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + MiraCompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + CompatibleChatProfile::new("Mira AI") + .unsupported_tools() + .unsupported_tool_choice() + .unsupported_additional_params(), + ) + .preamble_as_user(), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(MiraRequestHarness); + #[test] fn test_deserialize_message() { // Test string content format @@ -810,7 +830,7 @@ mod tests { assert_eq!( completion_response.choice.first(), - completion::AssistantContent::text("Test response") + Some(completion::AssistantContent::text("Test response")) ); } #[test] diff --git a/rig/rig-core/src/providers/mistral/completion.rs b/rig/rig-core/src/providers/mistral/completion.rs index 101d644ce..9b06753d5 100644 --- a/rig/rig-core/src/providers/mistral/completion.rs +++ b/rig/rig-core/src/providers/mistral/completion.rs @@ -8,7 +8,6 @@ use crate::completion::GetTokenUsage; use crate::http_client::{self, HttpClientExt}; use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, StreamingCompletionResponse}; use crate::{ - OneOrMany, completion::{self, CompletionError, CompletionRequest}, json_utils, message, providers::mistral::client::ApiResponse, @@ -347,59 +346,40 @@ impl TryFrom<(&str, CompletionRequest)> for 
MistralCompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for Mistral"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![Message::system(preamble.clone())], - None => vec![], - }; - if let Some(docs) = req.normalized_documents() { - let docs: Vec = docs.try_into()?; - full_history.extend(docs); - } - - let chat_history: Vec = req - .chat_history - .clone() - .into_iter() - .map(|message| message.try_into()) - .collect::>, _>>()? - .into_iter() - .flatten() - .collect(); - - full_history.extend(chat_history); - - if full_history.is_empty() { - return Err(CompletionError::RequestError( - std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "Mistral request has no provider-compatible messages after conversion", - ) - .into(), - )); - } - - let tool_choice = req - .tool_choice + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools, + tool_choice, + additional_params, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("Mistral") + .require_messages(), + |preamble| Message::system(preamble.to_owned()), + None, + |_| false, + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; + + let tool_choice = tool_choice .clone() .map(crate::providers::openai::completion::ToolChoice::try_from) .transpose()?; Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - tools: req - .tools - .clone() + model, + messages, + temperature, + tools: tools .into_iter() .map(ToolDefinition::from) .collect::>(), tool_choice, - additional_params: req.additional_params, + additional_params, }) } } @@ -489,20 +469,19 @@ impl TryFrom for completion::CompletionResponse Result { - let choice = response.choices.first().ok_or_else(|| { - CompletionError::ResponseError("Response contained no choices".to_owned()) - })?; + let choice = crate::providers::openai::completion::first_choice(&response.choices)?; + let stop_reason = Some(crate::providers::openai::completion::map_finish_reason( + &choice.finish_reason, + )); let content = match &choice.message { Message::Assistant { content, tool_calls, .. 
} => { - let mut content = if content.is_empty() { - vec![] - } else { - vec![completion::AssistantContent::text(content.clone())] - }; + let mut content = crate::providers::openai::completion::non_empty_text(content) + .into_iter() + .collect::>(); content.extend( tool_calls @@ -523,12 +502,6 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse &'static str { + "mistral" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + MistralCompletionRequest::try_from((MISTRAL_SMALL, request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + crate::providers::openai::completion::CompatibleChatProfile::new("Mistral") + .require_messages(), + ) + .omits_document_messages(), + MISTRAL_SMALL, + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(MistralRequestHarness); #[test] fn test_response_deserialization() { @@ -737,7 +748,7 @@ mod tests { fn test_assistant_reasoning_is_skipped_in_message_conversion() { let assistant = message::Message::Assistant { id: None, - content: OneOrMany::one(message::AssistantContent::reasoning("hidden")), + content: crate::OneOrMany::one(message::AssistantContent::reasoning("hidden")), }; let converted: Vec = assistant.try_into().expect("conversion should work"); @@ -748,7 +759,7 @@ mod tests { fn test_assistant_text_and_tool_call_are_preserved_when_reasoning_present() { let assistant = message::Message::Assistant { id: None, - content: OneOrMany::many(vec![ + content: crate::OneOrMany::many(vec![ message::AssistantContent::reasoning("hidden"), message::AssistantContent::text("visible"), message::AssistantContent::tool_call( @@ -818,9 +829,9 @@ mod tests { fn test_request_conversion_errors_when_all_messages_are_filtered() { let request = CompletionRequest { preamble: None, - chat_history: OneOrMany::one(message::Message::Assistant { + chat_history: crate::OneOrMany::one(message::Message::Assistant { id: None, - content: OneOrMany::one(message::AssistantContent::reasoning("hidden")), + content: crate::OneOrMany::one(message::AssistantContent::reasoning("hidden")), }), documents: vec![], tools: vec![], @@ -835,4 +846,25 @@ mod tests { let result = MistralCompletionRequest::try_from((MISTRAL_SMALL, request)); assert!(matches!(result, Err(CompletionError::RequestError(_)))); } + + #[test] + fn test_mistral_request_uses_request_model_override() { + let request = CompletionRequest { + model: Some("mistral-custom".to_string()), + preamble: Some("System".to_string()), + chat_history: crate::OneOrMany::one(message::Message::user("Hello")), + documents: vec![], + tools: vec![], + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + }; + + let request = + MistralCompletionRequest::try_from((MISTRAL_SMALL, request)).expect("request"); + + assert_eq!(request.model, "mistral-custom"); + } } diff --git a/rig/rig-core/src/providers/mod.rs b/rig/rig-core/src/providers/mod.rs index 24e9bece3..f1f825e99 100644 --- a/rig/rig-core/src/providers/mod.rs +++ b/rig/rig-core/src/providers/mod.rs @@ -49,6 +49,8 @@ pub mod anthropic; pub mod azure; pub mod chatgpt; pub mod cohere; +#[cfg(test)] +pub(crate) mod conformance; pub mod copilot; pub mod deepseek; pub mod galadriel; diff --git 
a/rig/rig-core/src/providers/moonshot.rs b/rig/rig-core/src/providers/moonshot.rs index 13a034345..ca54224e7 100644 --- a/rig/rig-core/src/providers/moonshot.rs +++ b/rig/rig-core/src/providers/moonshot.rs @@ -39,7 +39,7 @@ use crate::{ }; use crate::{http_client, message}; use serde::{Deserialize, Serialize}; -use serde_json::{Value, json}; +use serde_json::Value; use tracing::{Instrument, info_span}; // ================================================================ @@ -319,61 +319,52 @@ impl TryFrom<(&str, CompletionRequest)> for MoonshotCompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for Moonshot"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - // Build up the order of messages (context, chat_history, prompt) - let mut partial_history = vec![]; - if let Some(docs) = req.normalized_documents() { - partial_history.push(docs); - } - partial_history.extend(req.chat_history); - - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![serde_json::to_value(openai::Message::system(preamble))?], - None => vec![], - }; - - full_history.extend(moonshot_history_values(partial_history)?); - - let mut tool_choice = None; - let mut tool_choice_required = false; - if let Some(choice) = req.tool_choice.clone() { - match choice { - message::ToolChoice::Required => { - tool_choice_required = true; - tool_choice = Some(crate::providers::openai::completion::ToolChoice::Auto); - } - other => { - tool_choice = Some(crate::providers::openai::ToolChoice::try_from(other)?); - } - } - } - - if tool_choice_required { - tracing::warn!( - "Moonshot does not support tool_choice=required; coercing to auto with an additional steering message" - ); - full_history.push(json!({ - "role": "user", - "content": "Please select a tool to handle the current issue." 
- })); - } + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages: full_history, + temperature, + max_tokens, + tools, + tool_choice: request_tool_choice, + additional_params, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("Moonshot") + .coerce_required_tool_choice_to_auto( + "Please select a tool to handle the current issue.", + ), + |preamble| { + serde_json::json!({ + "role": "system", + "content": preamble, + }) + }, + Some(|content| { + serde_json::json!({ + "role": "user", + "content": content, + }) + }), + |_| false, + |message| moonshot_history_values(vec![message]), + )?; + + let tool_choice = request_tool_choice + .map(crate::providers::openai::ToolChoice::try_from) + .transpose()?; Ok(Self { - model: model.to_string(), + model, messages: full_history, - temperature: req.temperature, - max_tokens: req.max_tokens, - tools: req - .tools - .clone() + temperature, + max_tokens, + tools: tools .into_iter() .map(openai::ToolDefinition::from) .collect::>(), tool_choice, - additional_params: req.additional_params, + additional_params, }) } } @@ -655,6 +646,42 @@ mod tests { use crate::message::{ AssistantContent, Message, Reasoning, ToolCall, ToolChoice, ToolFunction, }; + use crate::providers::openai::completion::{CompatibleChatProfile, request_conformance}; + + struct MoonshotRequestHarness; + + impl request_conformance::Harness for MoonshotRequestHarness { + fn family_name() -> &'static str { + "moonshot" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + MoonshotCompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + CompatibleChatProfile::new("Moonshot").coerce_required_tool_choice_to_auto( + "Please select a tool to handle the current issue.", + ), + ) + .preserves_reasoning_assistant_history(), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(MoonshotRequestHarness); #[test] fn test_client_initialization() { diff --git a/rig/rig-core/src/providers/ollama.rs b/rig/rig-core/src/providers/ollama.rs index 660bef677..60f74c810 100644 --- a/rig/rig-core/src/providers/ollama.rs +++ b/rig/rig-core/src/providers/ollama.rs @@ -351,6 +351,10 @@ pub struct CompletionResponse { impl TryFrom for completion::CompletionResponse { type Error = CompletionError; fn try_from(resp: CompletionResponse) -> Result { + let stop_reason = resp + .done_reason + .as_deref() + .map(crate::providers::openai::completion::map_finish_reason); match resp.message { // Process only if an assistant message is present. 
Message::Assistant { @@ -373,9 +377,7 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse Err(CompletionError::ResponseError( diff --git a/rig/rig-core/src/providers/openai/completion/conformance_tests.rs b/rig/rig-core/src/providers/openai/completion/conformance_tests.rs new file mode 100644 index 000000000..f4c8b97fa --- /dev/null +++ b/rig/rig-core/src/providers/openai/completion/conformance_tests.rs @@ -0,0 +1,248 @@ +use bytes::Bytes; +use serde_json::json; + +use super::*; +use crate::{ + completion::{self, CompletionError}, + http_client::mock::MockStreamingClient, + providers::conformance::{ + BoxFuture, Fixture, Harness, NormalizedItem, Outcome, StopReason, Turn, drain_stream, + normalize_completion_response, provider_conformance_tests, + }, +}; + +struct OpenAiChatHarness; + +impl Harness for OpenAiChatHarness { + fn family_name() -> &'static str { + "openai-chat" + } + + fn expected(case: Fixture) -> Outcome { + match case { + Fixture::EmptyAssistantTurnAfterToolResult => Outcome::Supported(Turn { + items: vec![], + message_id: None, + stop_reason: Some(StopReason::Stop), + }), + Fixture::ToolOnlyTurn => Outcome::Supported(Turn { + items: vec![NormalizedItem::ToolCall { + id: "call_lookup".to_string(), + name: "lookup_weather".to_string(), + arguments: json!({"city": "Paris"}), + }], + message_id: None, + stop_reason: Some(StopReason::ToolCalls), + }), + Fixture::TextAndToolCallTurn => Outcome::Supported(Turn { + items: vec![ + NormalizedItem::Text("Need weather data first.".to_string()), + NormalizedItem::ToolCall { + id: "call_lookup".to_string(), + name: "lookup_weather".to_string(), + arguments: json!({"city": "Paris"}), + }, + ], + message_id: None, + stop_reason: Some(StopReason::ToolCalls), + }), + Fixture::EmptyTextBlocks => Outcome::Supported(Turn { + items: vec![], + message_id: None, + stop_reason: Some(StopReason::Stop), + }), + Fixture::ReasoningOnlyTurn => Outcome::Unsupported( + "OpenAI chat completions do not expose normalized reasoning blocks", + ), + Fixture::MessageIdPreservation => { + Outcome::Unsupported("OpenAI chat completions do not expose message IDs") + } + Fixture::StopReasonMapping => Outcome::Supported(Turn { + items: vec![NormalizedItem::Text("Truncated response".to_string())], + message_id: None, + stop_reason: Some(StopReason::MaxTokens), + }), + } + } + + fn non_stream(case: Fixture) -> Result, CompletionError> { + match case { + Fixture::ReasoningOnlyTurn => Ok(Self::expected(case)), + Fixture::MessageIdPreservation => Ok(Self::expected(case)), + _ => { + let raw = non_stream_response(case); + let response: completion::CompletionResponse = + raw.try_into()?; + Ok(Outcome::Supported(normalize_completion_response(&response))) + } + } + } + + fn stream(case: Fixture) -> BoxFuture, CompletionError>> { + Box::pin(async move { + match case { + Fixture::ReasoningOnlyTurn => Ok(Self::expected(case)), + Fixture::MessageIdPreservation => Ok(Self::expected(case)), + _ => { + let client = MockStreamingClient { + sse_bytes: Bytes::from(streaming_sse(case)), + }; + let request = http::Request::builder() + .method("POST") + .uri("http://localhost/v1/chat/completions") + .body(Vec::new()) + .expect("request should build"); + let stream = + streaming::send_compatible_streaming_request(client, request).await?; + let response = drain_stream(stream).await?; + Ok(Outcome::Supported(normalize_completion_response(&response))) + } + } + }) + } +} + +fn non_stream_response(case: Fixture) -> CompletionResponse { + match case { + 
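+    // Minimal hand-written raw responses, one per supported fixture; the
+    // harness converts each through `TryFrom` and compares the normalized
+    // turn against `expected(case)`.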
Fixture::EmptyAssistantTurnAfterToolResult => response_with_message( + Message::Assistant { + content: vec![], + refusal: None, + audio: None, + name: None, + tool_calls: vec![], + }, + "stop", + ), + Fixture::ToolOnlyTurn => response_with_message( + Message::Assistant { + content: vec![], + refusal: None, + audio: None, + name: None, + tool_calls: vec![tool_call( + "call_lookup", + "lookup_weather", + json!({"city": "Paris"}), + )], + }, + "tool_calls", + ), + Fixture::TextAndToolCallTurn => response_with_message( + Message::Assistant { + content: vec![AssistantContent::Text { + text: "Need weather data first.".to_string(), + }], + refusal: None, + audio: None, + name: None, + tool_calls: vec![tool_call( + "call_lookup", + "lookup_weather", + json!({"city": "Paris"}), + )], + }, + "tool_calls", + ), + Fixture::EmptyTextBlocks => response_with_message( + Message::Assistant { + content: vec![AssistantContent::Text { + text: String::new(), + }], + refusal: None, + audio: None, + name: None, + tool_calls: vec![], + }, + "stop", + ), + Fixture::StopReasonMapping => response_with_message( + Message::Assistant { + content: vec![AssistantContent::Text { + text: "Truncated response".to_string(), + }], + refusal: None, + audio: None, + name: None, + tool_calls: vec![], + }, + "length", + ), + Fixture::ReasoningOnlyTurn | Fixture::MessageIdPreservation => { + unreachable!("unsupported cases are handled before construction") + } + } +} + +fn response_with_message(message: Message, finish_reason: &str) -> CompletionResponse { + CompletionResponse { + id: "chatcmpl_123".to_string(), + object: "chat.completion".to_string(), + created: 0, + model: "gpt-4o-mini".to_string(), + system_fingerprint: None, + choices: vec![Choice { + index: 0, + message, + logprobs: None, + finish_reason: finish_reason.to_string(), + }], + usage: Some(Usage { + prompt_tokens: 10, + total_tokens: 15, + prompt_tokens_details: None, + }), + } +} + +fn tool_call(id: &str, name: &str, arguments: serde_json::Value) -> ToolCall { + ToolCall { + id: id.to_string(), + r#type: ToolType::Function, + function: Function { + name: name.to_string(), + arguments, + }, + } +} + +fn streaming_sse(case: Fixture) -> String { + match case { + Fixture::EmptyAssistantTurnAfterToolResult => concat!( + "data: {\"choices\":[{\"delta\":{\"content\":\"\",\"tool_calls\":[]},\"finish_reason\":\"stop\"}],\"usage\":null}\n\n", + "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":0,\"total_tokens\":10}}\n\n", + "data: [DONE]\n\n", + ) + .to_string(), + Fixture::ToolOnlyTurn => concat!( + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_lookup\",\"function\":{\"name\":\"lookup_weather\",\"arguments\":\"{\\\"city\\\":\\\"Paris\\\"}\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n", + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}\n\n", + "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":5,\"total_tokens\":15}}\n\n", + "data: [DONE]\n\n", + ) + .to_string(), + Fixture::TextAndToolCallTurn => concat!( + "data: {\"choices\":[{\"delta\":{\"content\":\"Need weather data first.\",\"tool_calls\":[]},\"finish_reason\":null}],\"usage\":null}\n\n", + "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_lookup\",\"function\":{\"name\":\"lookup_weather\",\"arguments\":\"{\\\"city\\\":\\\"Paris\\\"}\"}}]},\"finish_reason\":null}],\"usage\":null}\n\n", + "data: 
{\"choices\":[{\"delta\":{\"tool_calls\":[]},\"finish_reason\":\"tool_calls\"}],\"usage\":null}\n\n", + "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":7,\"total_tokens\":17}}\n\n", + "data: [DONE]\n\n", + ) + .to_string(), + Fixture::EmptyTextBlocks => concat!( + "data: {\"choices\":[{\"delta\":{\"content\":\"\",\"tool_calls\":[]},\"finish_reason\":\"stop\"}],\"usage\":null}\n\n", + "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":0,\"total_tokens\":10}}\n\n", + "data: [DONE]\n\n", + ) + .to_string(), + Fixture::StopReasonMapping => concat!( + "data: {\"choices\":[{\"delta\":{\"content\":\"Truncated response\",\"tool_calls\":[]},\"finish_reason\":\"length\"}],\"usage\":null}\n\n", + "data: {\"choices\":[],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":2,\"total_tokens\":12}}\n\n", + "data: [DONE]\n\n", + ) + .to_string(), + Fixture::ReasoningOnlyTurn | Fixture::MessageIdPreservation => unreachable!(), + } +} + +provider_conformance_tests!(OpenAiChatHarness); diff --git a/rig/rig-core/src/providers/openai/completion/family.rs b/rig/rig-core/src/providers/openai/completion/family.rs new file mode 100644 index 000000000..676a3561e --- /dev/null +++ b/rig/rig-core/src/providers/openai/completion/family.rs @@ -0,0 +1,1513 @@ +use std::collections::HashMap; + +use crate::completion::{self, CompletionError, CompletionRequest as CoreCompletionRequest}; +use crate::json_utils; +use crate::message; +use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent}; +use serde::{Deserialize, Serialize}; + +pub(crate) fn first_choice(choices: &[T]) -> Result<&T, CompletionError> { + choices + .first() + .ok_or_else(|| CompletionError::ResponseError("Response contained no choices".to_owned())) +} + +pub(crate) fn map_finish_reason(reason: &str) -> completion::StopReason { + match reason { + "stop" => completion::StopReason::Stop, + "tool_calls" => completion::StopReason::ToolCalls, + "content_filter" => completion::StopReason::ContentFilter, + "length" => completion::StopReason::MaxTokens, + other => completion::StopReason::Other(other.to_string()), + } +} + +pub(crate) fn non_empty_text(text: impl AsRef) -> Option { + let text = text.as_ref(); + if text.is_empty() { + None + } else { + Some(completion::AssistantContent::text(text)) + } +} + +pub(crate) fn build_completion_response( + raw_response: R, + usage: completion::Usage, + message_id: Option, + stop_reason: Option, + choice: C, +) -> completion::CompletionResponse +where + C: Into, +{ + completion::CompletionResponse { + choice: choice.into(), + usage, + raw_response, + message_id, + stop_reason, + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum CompatibleToolChoiceKeyword { + None, + Auto, + Required, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "type", content = "function")] +pub enum CompatibleToolChoiceFunctionKind { + Function { name: String }, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(untagged)] +pub enum CompatibleToolChoice { + Keyword(CompatibleToolChoiceKeyword), + Function(Vec), +} + +impl From for CompatibleToolChoice { + fn from(value: crate::message::ToolChoice) -> Self { + match value { + crate::message::ToolChoice::None => Self::Keyword(CompatibleToolChoiceKeyword::None), + crate::message::ToolChoice::Auto => Self::Keyword(CompatibleToolChoiceKeyword::Auto), + crate::message::ToolChoice::Required => { 
+                Self::Keyword(CompatibleToolChoiceKeyword::Required)
+            }
+            crate::message::ToolChoice::Specific { function_names } => Self::Function(
+                function_names
+                    .into_iter()
+                    .map(|name| CompatibleToolChoiceFunctionKind::Function { name })
+                    .collect(),
+            ),
+        }
+    }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub(crate) enum ToolsPolicy {
+    Supported,
+    Unsupported,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub(crate) enum ToolChoicePolicy {
+    PassThrough,
+    Unsupported,
+    RejectRequired,
+    CoerceRequiredToAuto { steering_message: &'static str },
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub(crate) enum OutputSchemaPolicy {
+    Unsupported,
+    NativeResponseFormat,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub(crate) enum AdditionalParamsPolicy {
+    PassThrough,
+    Unsupported,
+}
+
+/// Shared unsupported-feature policy for request builders.
+#[derive(Debug, Clone, Copy)]
+pub(crate) struct CompatibleFeaturePolicy {
+    tools_policy: ToolsPolicy,
+    tool_choice_policy: ToolChoicePolicy,
+    output_schema_policy: OutputSchemaPolicy,
+    additional_params_policy: AdditionalParamsPolicy,
+}
+
+#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
+pub(crate) struct CompatibleRequestAdjustments {
+    steering_user_message: Option<&'static str>,
+}
+
+impl Default for CompatibleFeaturePolicy {
+    fn default() -> Self {
+        Self {
+            tools_policy: ToolsPolicy::Supported,
+            tool_choice_policy: ToolChoicePolicy::PassThrough,
+            output_schema_policy: OutputSchemaPolicy::Unsupported,
+            additional_params_policy: AdditionalParamsPolicy::PassThrough,
+        }
+    }
+}
+
+impl CompatibleFeaturePolicy {
+    pub(crate) const fn with_tools_policy(mut self, policy: ToolsPolicy) -> Self {
+        self.tools_policy = policy;
+        self
+    }
+
+    pub(crate) const fn with_tool_choice_policy(mut self, policy: ToolChoicePolicy) -> Self {
+        self.tool_choice_policy = policy;
+        self
+    }
+
+    pub(crate) const fn with_output_schema_policy(mut self, policy: OutputSchemaPolicy) -> Self {
+        self.output_schema_policy = policy;
+        self
+    }
+
+    pub(crate) const fn with_additional_params_policy(
+        mut self,
+        policy: AdditionalParamsPolicy,
+    ) -> Self {
+        self.additional_params_policy = policy;
+        self
+    }
+
+    pub(crate) fn apply(
+        self,
+        provider_name: &'static str,
+        req: &mut CoreCompletionRequest,
+    ) -> Result<CompatibleRequestAdjustments, CompletionError> {
+        if req.output_schema.is_some()
+            && matches!(self.output_schema_policy, OutputSchemaPolicy::Unsupported)
+        {
+            tracing::warn!(
+                "Structured outputs currently not supported for {}",
+                provider_name
+            );
+            req.output_schema = None;
+        }
+
+        if !req.tools.is_empty() && matches!(self.tools_policy, ToolsPolicy::Unsupported) {
+            tracing::warn!("WARNING: `tools` not supported on {}", provider_name);
+            req.tools.clear();
+        }
+
+        let mut adjustments = CompatibleRequestAdjustments::default();
+        if let Some(choice) = req.tool_choice.clone() {
+            match self.tool_choice_policy {
+                ToolChoicePolicy::PassThrough => {}
+                ToolChoicePolicy::Unsupported => {
+                    tracing::warn!("WARNING: `tool_choice` not supported on {}", provider_name);
+                    req.tool_choice = None;
+                }
+                ToolChoicePolicy::RejectRequired => {
+                    if matches!(choice, crate::message::ToolChoice::Required) {
+                        return Err(CompletionError::RequestError(
+                            std::io::Error::new(
+                                std::io::ErrorKind::InvalidInput,
+                                format!("{provider_name} does not support tool_choice=required"),
+                            )
+                            .into(),
+                        ));
+                    }
+                }
+                ToolChoicePolicy::CoerceRequiredToAuto { steering_message } => {
+                    if matches!(choice, crate::message::ToolChoice::Required) {
+                        tracing::warn!(
+                            "{} does not support tool_choice=required; coercing to auto with an additional steering message",
+                            provider_name
+                        );
+                        req.tool_choice = Some(crate::message::ToolChoice::Auto);
+                        adjustments.steering_user_message = Some(steering_message);
+                    }
+                }
+            }
+        }
+
+        if req.additional_params.is_some()
+            && matches!(
+                self.additional_params_policy,
+                AdditionalParamsPolicy::Unsupported
+            )
+        {
+            tracing::warn!(
+                "WARNING: `additional_params` not supported on {}",
+                provider_name
+            );
+            req.additional_params = None;
+        }
+
+        Ok(adjustments)
+    }
+}
+
+/// Shared request-shaping profile for OpenAI-compatible chat providers.
+#[derive(Debug, Clone, Copy)]
+pub struct CompatibleChatProfile {
+    provider_name: &'static str,
+    require_messages: bool,
+    feature_policy: CompatibleFeaturePolicy,
+}
+
+impl CompatibleChatProfile {
+    pub(crate) const fn new(provider_name: &'static str) -> Self {
+        Self {
+            provider_name,
+            require_messages: false,
+            feature_policy: CompatibleFeaturePolicy {
+                tools_policy: ToolsPolicy::Supported,
+                tool_choice_policy: ToolChoicePolicy::PassThrough,
+                output_schema_policy: OutputSchemaPolicy::Unsupported,
+                additional_params_policy: AdditionalParamsPolicy::PassThrough,
+            },
+        }
+    }
+
+    pub(crate) const fn openai_chat_completions(provider_name: &'static str) -> Self {
+        Self::new(provider_name)
+            .require_messages()
+            .native_response_format()
+    }
+
+    pub(crate) const fn require_messages(mut self) -> Self {
+        self.require_messages = true;
+        self
+    }
+
+    pub(crate) const fn native_response_format(mut self) -> Self {
+        self.feature_policy = self
+            .feature_policy
+            .with_output_schema_policy(OutputSchemaPolicy::NativeResponseFormat);
+        self
+    }
+
+    pub(crate) const fn unsupported_tools(mut self) -> Self {
+        self.feature_policy = self
+            .feature_policy
+            .with_tools_policy(ToolsPolicy::Unsupported);
+        self
+    }
+
+    pub(crate) const fn unsupported_tool_choice(mut self) -> Self {
+        self.feature_policy = self
+            .feature_policy
+            .with_tool_choice_policy(ToolChoicePolicy::Unsupported);
+        self
+    }
+
+    pub(crate) const fn reject_required_tool_choice(mut self) -> Self {
+        self.feature_policy = self
+            .feature_policy
+            .with_tool_choice_policy(ToolChoicePolicy::RejectRequired);
+        self
+    }
+
+    pub(crate) const fn coerce_required_tool_choice_to_auto(
+        mut self,
+        steering_message: &'static str,
+    ) -> Self {
+        self.feature_policy = self
+            .feature_policy
+            .with_tool_choice_policy(ToolChoicePolicy::CoerceRequiredToAuto { steering_message });
+        self
+    }
+
+    pub(crate) const fn unsupported_additional_params(mut self) -> Self {
+        self.feature_policy = self
+            .feature_policy
+            .with_additional_params_policy(AdditionalParamsPolicy::Unsupported);
+        self
+    }
+}
+
+/// Provider-agnostic core request fields shared by OpenAI-compatible chat families.
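+///
+/// Callers destructure this and assemble their provider-specific wire request;
+/// the Mistral and Moonshot `TryFrom` impls above follow that pattern.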
+#[derive(Debug)]
+pub(crate) struct CompatibleRequestCore<M> {
+    pub model: String,
+    pub messages: Vec<M>,
+    pub temperature: Option<f64>,
+    pub max_tokens: Option<u64>,
+    pub tools: Vec<completion::ToolDefinition>,
+    pub tool_choice: Option<message::ToolChoice>,
+    pub additional_params: Option<serde_json::Value>,
+}
+
+fn merge_native_response_format(
+    additional_params: Option<serde_json::Value>,
+    schema: schemars::Schema,
+) -> Option<serde_json::Value> {
+    let name = schema
+        .as_object()
+        .and_then(|o| o.get("title"))
+        .and_then(|v| v.as_str())
+        .unwrap_or("response_schema")
+        .to_string();
+    let mut schema_value = schema.to_value();
+    super::super::sanitize_schema(&mut schema_value);
+    let response_format = serde_json::json!({
+        "response_format": {
+            "type": "json_schema",
+            "json_schema": {
+                "name": name,
+                "strict": true,
+                "schema": schema_value
+            }
+        }
+    });
+
+    Some(match additional_params {
+        Some(existing) => json_utils::merge(existing, response_format),
+        None => response_format,
+    })
+}
+
+pub(crate) fn build_compatible_request_core<M, F>(
+    default_model: &str,
+    mut req: CoreCompletionRequest,
+    profile: CompatibleChatProfile,
+    system_message: fn(&str) -> M,
+    user_message: Option<fn(&str) -> M>,
+    message_has_tool_result: fn(&M) -> bool,
+    mut convert_message: F,
+) -> Result<CompatibleRequestCore<M>, CompletionError>
+where
+    F: FnMut(message::Message) -> Result<Vec<M>, CompletionError>,
+{
+    let adjustments = profile
+        .feature_policy
+        .apply(profile.provider_name, &mut req)?;
+
+    let normalized_documents = req.normalized_documents();
+    let CoreCompletionRequest {
+        model: request_model,
+        preamble,
+        chat_history,
+        tools,
+        temperature,
+        max_tokens,
+        tool_choice,
+        additional_params,
+        output_schema,
+        ..
+    } = req;
+
+    let mut partial_history = Vec::new();
+    if let Some(docs) = normalized_documents {
+        partial_history.push(docs);
+    }
+    partial_history.extend(chat_history);
+
+    let mut messages = preamble
+        .as_deref()
+        .map_or_else(Vec::new, |preamble| vec![system_message(preamble)]);
+
+    messages.extend(
+        partial_history
+            .into_iter()
+            .map(&mut convert_message)
+            .collect::<Result<Vec<Vec<M>>, _>>()?
+            .into_iter()
+            .flatten(),
+    );
+
+    if let Some(steering_message) = adjustments.steering_user_message {
+        let Some(user_message) = user_message else {
+            return Err(CompletionError::RequestError(
+                std::io::Error::new(
+                    std::io::ErrorKind::InvalidInput,
+                    format!(
+                        "{} requires a user-message adapter when coercing tool_choice",
+                        profile.provider_name
+                    ),
+                )
+                .into(),
+            ));
+        };
+        messages.push(user_message(steering_message));
+    }
+
+    if profile.require_messages && messages.is_empty() {
+        return Err(CompletionError::RequestError(
+            std::io::Error::new(
+                std::io::ErrorKind::InvalidInput,
+                format!(
+                    "{} request has no provider-compatible messages after conversion",
+                    profile.provider_name
+                ),
+            )
+            .into(),
+        ));
+    }
+
+    let mut additional_params = additional_params;
+    if let Some(schema) = output_schema {
+        match profile.feature_policy.output_schema_policy {
+            OutputSchemaPolicy::NativeResponseFormat => {
+                let should_apply_response_format =
+                    tools.is_empty() || messages.iter().any(message_has_tool_result);
+                if should_apply_response_format {
+                    additional_params = merge_native_response_format(additional_params, schema);
+                }
+            }
+            OutputSchemaPolicy::Unsupported => {}
+        }
+    }
+
+    Ok(CompatibleRequestCore {
+        model: request_model.unwrap_or_else(|| default_model.to_owned()),
+        messages,
+        temperature,
+        max_tokens,
+        tools,
+        tool_choice,
+        additional_params,
+    })
+}
+
+/// Internal provider profile used by the OpenAI-compatible chat family.
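+///
+/// Implementors override only what differs from stock OpenAI defaults. A
+/// minimal sketch (the `ExampleExt` marker type is hypothetical):
+///
+/// ```ignore
+/// impl OpenAiChatProviderProfile for ExampleExt {
+///     fn provider_name() -> &'static str { "Example" }
+///     fn telemetry_provider_name() -> &'static str { "example" }
+/// }
+/// ```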
+#[doc(hidden)]
+pub trait OpenAiChatProviderProfile {
+    fn provider_name() -> &'static str;
+
+    fn telemetry_provider_name() -> &'static str {
+        Self::provider_name()
+    }
+
+    fn completions_path() -> &'static str {
+        "/chat/completions"
+    }
+
+    fn request_profile() -> CompatibleChatProfile {
+        CompatibleChatProfile::openai_chat_completions(Self::provider_name())
+    }
+
+    fn stream_tool_call_conflict_policy() -> ToolCallConflictPolicy {
+        ToolCallConflictPolicy::EvictDistinctIdAndName
+    }
+}
+
+impl OpenAiChatProviderProfile for super::super::client::OpenAICompletionsExt {
+    fn provider_name() -> &'static str {
+        "OpenAI Chat Completions"
+    }
+
+    fn telemetry_provider_name() -> &'static str {
+        "openai"
+    }
+}
+
+impl OpenAiChatProviderProfile for crate::providers::minimax::MiniMaxExt {
+    fn provider_name() -> &'static str {
+        "MiniMax"
+    }
+
+    fn telemetry_provider_name() -> &'static str {
+        "minimax"
+    }
+}
+
+impl OpenAiChatProviderProfile for crate::providers::zai::ZAiExt {
+    fn provider_name() -> &'static str {
+        "Z.AI"
+    }
+
+    fn telemetry_provider_name() -> &'static str {
+        "z.ai"
+    }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ToolCallConflictPolicy {
+    KeepIndex,
+    EvictDistinctIdAndName,
+}
+
+#[derive(Debug, Clone, Copy)]
+pub(crate) struct CompatibleStreamingToolCall<'a> {
+    pub index: usize,
+    pub id: Option<&'a str>,
+    pub name: Option<&'a str>,
+    pub arguments: Option<&'a str>,
+}
+
+pub(crate) fn apply_compatible_tool_call_deltas<'a, R>(
+    tool_calls: &mut HashMap<usize, RawStreamingToolCall>,
+    incoming: impl IntoIterator<Item = CompatibleStreamingToolCall<'a>>,
+    conflict_policy: ToolCallConflictPolicy,
+) -> Vec<RawStreamingChoice<R>>
+where
+    R: Clone,
+{
+    let mut events = Vec::new();
+
+    for tool_call in incoming {
+        let index = tool_call.index;
+
+        if conflict_policy == ToolCallConflictPolicy::EvictDistinctIdAndName
+            && should_evict_existing_tool_call(tool_calls.get(&index), tool_call.id, tool_call.name)
+            && let Some(evicted) = tool_calls.remove(&index)
+        {
+            events.push(RawStreamingChoice::ToolCall(
+                finalize_completed_streaming_tool_call(evicted),
+            ));
+        }
+
+        let existing_tool_call = tool_calls
+            .entry(index)
+            .or_insert_with(RawStreamingToolCall::empty);
+
+        if let Some(id) = tool_call.id
+            && !id.is_empty()
+        {
+            existing_tool_call.id = id.to_owned();
+        }
+
+        if let Some(name) = tool_call.name
+            && !name.is_empty()
+        {
+            existing_tool_call.name = name.to_owned();
+            events.push(RawStreamingChoice::ToolCallDelta {
+                id: existing_tool_call.id.clone(),
+                internal_call_id: existing_tool_call.internal_call_id.clone(),
+                content: ToolCallDeltaContent::Name(name.to_owned()),
+            });
+        }
+
+        if let Some(chunk) = tool_call.arguments
+            && !chunk.is_empty()
+        {
+            append_tool_call_arguments(&mut existing_tool_call.arguments, chunk);
+            events.push(RawStreamingChoice::ToolCallDelta {
+                id: existing_tool_call.id.clone(),
+                internal_call_id: existing_tool_call.internal_call_id.clone(),
+                content: ToolCallDeltaContent::Delta(chunk.to_owned()),
+            });
+        }
+    }
+
+    events
+}
+
+pub(crate) fn take_finalized_tool_calls<R>(
+    tool_calls: &mut HashMap<usize, RawStreamingToolCall>,
+) -> Vec<RawStreamingChoice<R>>
+where
+    R: Clone,
+{
+    std::mem::take(tool_calls)
+        .into_values()
+        .map(|tool_call| {
+            RawStreamingChoice::ToolCall(finalize_completed_streaming_tool_call(tool_call))
+        })
+        .collect()
+}
+
+pub(crate) fn take_tool_calls<R>(
+    tool_calls: &mut HashMap<usize, RawStreamingToolCall>,
+) -> Vec<RawStreamingChoice<R>>
+where
+    R: Clone,
+{
+    std::mem::take(tool_calls)
+        .into_values()
+        .map(RawStreamingChoice::ToolCall)
+        .collect()
+}
+
+pub(crate) fn finalize_completed_streaming_tool_call(
+    mut tool_call: RawStreamingToolCall,
+) ->
RawStreamingToolCall { + if tool_call.arguments.is_null() { + tool_call.arguments = serde_json::Value::Object(serde_json::Map::new()); + } + + tool_call +} + +fn should_evict_existing_tool_call( + existing: Option<&RawStreamingToolCall>, + new_id: Option<&str>, + new_name: Option<&str>, +) -> bool { + let Some(existing) = existing else { + return false; + }; + + let Some(new_id) = new_id.filter(|id| !id.is_empty()) else { + return false; + }; + let Some(new_name) = new_name.filter(|name| !name.is_empty()) else { + return false; + }; + + !existing.id.is_empty() + && existing.id != new_id + && !existing.name.is_empty() + && existing.name != new_name +} + +fn append_tool_call_arguments(arguments: &mut serde_json::Value, chunk: &str) { + let current_arguments = match arguments { + serde_json::Value::Null => String::new(), + serde_json::Value::String(value) => value.clone(), + ref value => value.to_string(), + }; + let combined = format!("{current_arguments}{chunk}"); + + if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') { + match serde_json::from_str(&combined) { + Ok(parsed) => *arguments = parsed, + Err(_) => *arguments = serde_json::Value::String(combined), + } + } else { + *arguments = serde_json::Value::String(combined); + } +} + +#[cfg(test)] +pub(crate) mod request_conformance { + use serde_json::{Value, json}; + + use super::{ + AdditionalParamsPolicy, CompatibleChatProfile, OutputSchemaPolicy, ToolChoicePolicy, + ToolsPolicy, + }; + use crate::completion::{CompletionError, CompletionRequest, Document, ToolDefinition}; + use crate::message::{self, Message, ToolChoice, UserContent}; + use crate::{OneOrMany, completion}; + + #[derive(Debug, Clone, Copy)] + pub(crate) enum Fixture { + ModelOverridePrecedence, + PreambleDocumentHistoryOrdering, + EmptyMessageRejection, + UnsupportedFieldStripping, + ToolChoiceCoercion, + OutputSchemaHandling, + AdditionalParamsPassthrough, + AdditionalParamsDrop, + } + + #[derive(Debug, Clone, PartialEq)] + pub(crate) enum Outcome { + Supported(T), + RequestError, + } + + pub(crate) trait Harness { + #[allow(dead_code)] + fn family_name() -> &'static str; + + fn run(case: Fixture) -> Outcome; + + fn assert(case: Fixture, actual: Outcome); + } + + pub(crate) fn assert_case(case: Fixture) { + let actual = H::run(case); + H::assert(case, actual); + } + + pub(crate) fn serialize_request(request: &T) -> Value { + serde_json::to_value(request).expect("request should serialize") + } + + pub(crate) fn serialize_case(case: Fixture, convert: F) -> Outcome + where + T: serde::Serialize, + F: FnOnce(CompletionRequest) -> Result, + { + match convert(fixture_request(case)) { + Ok(request) => Outcome::Supported(serialize_request(&request)), + Err(CompletionError::RequestError(_)) => Outcome::RequestError, + Err(err) => panic!("unexpected request conversion error for {case:?}: {err}"), + } + } + + pub(crate) fn fixture_request(case: Fixture) -> CompletionRequest { + match case { + Fixture::ModelOverridePrecedence => CompletionRequest { + model: Some("override-model".to_owned()), + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: Vec::new(), + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + }, + Fixture::PreambleDocumentHistoryOrdering => CompletionRequest { + model: None, + preamble: Some("system".to_owned()), + chat_history: OneOrMany::one(Message::User { + content: 
OneOrMany::one(UserContent::text("hello")), + }), + documents: vec![Document { + id: "doc-1".to_owned(), + text: "Document body".to_owned(), + additional_props: Default::default(), + }], + tools: Vec::new(), + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + }, + Fixture::EmptyMessageRejection => CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::Assistant { + id: None, + content: OneOrMany::one(message::AssistantContent::reasoning("hidden")), + }), + documents: vec![], + tools: vec![], + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + }, + Fixture::UnsupportedFieldStripping => CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: vec![ToolDefinition { + name: "lookup_weather".to_owned(), + description: "Lookup weather".to_owned(), + parameters: json!({ + "type": "object", + "properties": { + "city": { "type": "string" } + } + }), + }], + temperature: None, + max_tokens: None, + tool_choice: Some(ToolChoice::Auto), + additional_params: Some(json!({"vendor_flag": true})), + output_schema: Some(sample_schema()), + }, + Fixture::ToolChoiceCoercion => CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: vec![ToolDefinition { + name: "lookup_weather".to_owned(), + description: "Lookup weather".to_owned(), + parameters: json!({"type":"object"}), + }], + temperature: None, + max_tokens: None, + tool_choice: Some(ToolChoice::Required), + additional_params: None, + output_schema: None, + }, + Fixture::OutputSchemaHandling => CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: Vec::new(), + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: Some(json!({"vendor_flag": true})), + output_schema: Some(sample_schema()), + }, + Fixture::AdditionalParamsPassthrough | Fixture::AdditionalParamsDrop => { + CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: Vec::new(), + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: Some(json!({"vendor_flag": true})), + output_schema: None, + } + } + } + } + + pub(crate) fn model(value: &Value) -> Option<&str> { + value.get("model").and_then(Value::as_str) + } + + pub(crate) fn tools_len(value: &Value) -> usize { + value + .get("tools") + .and_then(Value::as_array) + .map_or(0, Vec::len) + } + + pub(crate) fn tool_choice(value: &Value) -> Option<&Value> { + value.get("tool_choice") + } + + pub(crate) fn top_level<'a>(value: &'a Value, key: &str) -> Option<&'a Value> { + value.get(key) + } + + pub(crate) fn tool_choice_text(value: &Value) -> Option<&str> { + value.get("tool_choice").and_then(Value::as_str) + } + + pub(crate) fn message_summaries(value: &Value) -> Vec { + value + .get("messages") + .and_then(Value::as_array) + .into_iter() + .flatten() + .map(|message| { + let role = message + .get("role") + .and_then(Value::as_str) + .unwrap_or("unknown"); + format!("{role}:{}", 
extract_text(message)) + }) + .collect() + } + + fn extract_text(message: &Value) -> String { + let content = extract_text_from_value(message.get("content").unwrap_or(&Value::Null)); + if !content.is_empty() { + return content; + } + + if let Some(reasoning) = message.get("reasoning_content").and_then(Value::as_str) { + return reasoning.to_owned(); + } + + extract_text_from_value(message.get("reasoning_details").unwrap_or(&Value::Null)) + } + + fn extract_text_from_value(value: &Value) -> String { + match value { + Value::String(text) => text.clone(), + Value::Array(items) => items + .iter() + .filter_map(extract_text_part) + .collect::>() + .join("\n"), + Value::Object(object) => { + if let Some(text) = object.get("text").and_then(Value::as_str) { + return text.to_owned(); + } + + if let Some(text) = object.get("content").and_then(Value::as_str) { + return text.to_owned(); + } + + if let Some(content) = object.get("content") { + return extract_text_from_value(content); + } + + String::new() + } + _ => String::new(), + } + } + + fn extract_text_part(value: &Value) -> Option { + match value { + Value::String(text) => Some(text.clone()), + Value::Object(object) => object + .get("text") + .and_then(Value::as_str) + .map(ToOwned::to_owned) + .or_else(|| { + object + .get("content") + .and_then(Value::as_str) + .map(ToOwned::to_owned) + }) + .or_else(|| object.get("content").map(extract_text_from_value)) + .filter(|text| !text.is_empty()), + _ => None, + } + } + + pub(crate) fn expected_document_text() -> String { + completion::Document { + id: "doc-1".to_owned(), + text: "Document body".to_owned(), + additional_props: Default::default(), + } + .to_string() + } + + fn sample_schema() -> schemars::Schema { + serde_json::from_value(json!({ + "title": "fixture_schema", + "type": "object", + "properties": { + "answer": { "type": "string" } + } + })) + .expect("schema should deserialize") + } + + #[derive(Debug, Clone, Copy)] + enum EmptyMessageBehavior { + DropEmpty, + PreserveReasoningAssistantHistory, + RequestError, + } + + #[derive(Debug, Clone, Copy)] + pub(crate) struct CompatibleChatExpectation { + profile: CompatibleChatProfile, + empty_message_behavior: EmptyMessageBehavior, + preamble_role: &'static str, + document_role: Option<&'static str>, + } + + impl CompatibleChatExpectation { + pub(crate) const fn new(profile: CompatibleChatProfile) -> Self { + Self { + profile, + empty_message_behavior: if profile.require_messages { + EmptyMessageBehavior::RequestError + } else { + EmptyMessageBehavior::DropEmpty + }, + preamble_role: "system", + document_role: Some("user"), + } + } + + pub(crate) const fn preserves_reasoning_assistant_history(mut self) -> Self { + self.empty_message_behavior = EmptyMessageBehavior::PreserveReasoningAssistantHistory; + self + } + + pub(crate) const fn preamble_as_user(mut self) -> Self { + self.preamble_role = "user"; + self + } + + pub(crate) const fn omits_document_messages(mut self) -> Self { + self.document_role = None; + self + } + } + + pub(crate) fn assert_compatible_chat_case( + expectation: CompatibleChatExpectation, + default_model: &str, + case: Fixture, + actual: Outcome, + ) { + let profile = expectation.profile; + match case { + Fixture::ModelOverridePrecedence => { + let request = expect_supported(case, actual); + assert_eq!(model(&request), Some("override-model")); + } + Fixture::PreambleDocumentHistoryOrdering => { + let request = expect_supported(case, actual); + let mut expected = vec![format!("{}:system", expectation.preamble_role)]; + if let 
Some(document_role) = expectation.document_role { + expected.push(format!("{document_role}:{}", expected_document_text())); + } + expected.push("user:hello".to_owned()); + assert_eq!(message_summaries(&request), expected); + } + Fixture::EmptyMessageRejection => match expectation.empty_message_behavior { + EmptyMessageBehavior::RequestError => { + assert!( + matches!(actual, Outcome::RequestError), + "expected request rejection for {case:?}, got {actual:?}" + ); + } + EmptyMessageBehavior::PreserveReasoningAssistantHistory => { + let request = expect_supported(case, actual); + let summaries = message_summaries(&request); + assert_eq!(summaries, vec!["assistant:hidden".to_owned()]); + } + EmptyMessageBehavior::DropEmpty => { + let request = expect_supported(case, actual); + assert!( + message_summaries(&request).is_empty(), + "expected no provider-compatible messages: {request:?}" + ); + } + }, + Fixture::UnsupportedFieldStripping => { + let request = expect_supported(case, actual); + assert_eq!(model(&request), Some(default_model)); + assert_eq!( + tools_len(&request), + if matches!(profile.feature_policy.tools_policy, ToolsPolicy::Supported) { + 1 + } else { + 0 + } + ); + assert_tool_choice(&request, profile.feature_policy.tool_choice_policy, "auto"); + assert_vendor_flag( + &request, + matches!( + profile.feature_policy.additional_params_policy, + AdditionalParamsPolicy::PassThrough + ), + ); + assert!( + top_level(&request, "response_format").is_none(), + "tool-bearing turns should omit response_format: {request:?}" + ); + } + Fixture::ToolChoiceCoercion => { + if matches!( + profile.feature_policy.tool_choice_policy, + ToolChoicePolicy::RejectRequired + ) { + assert!( + matches!(actual, Outcome::RequestError), + "expected request rejection for {case:?}, got {actual:?}" + ); + return; + } + + let request = expect_supported(case, actual); + match profile.feature_policy.tool_choice_policy { + ToolChoicePolicy::PassThrough => { + assert_eq!(tool_choice_text(&request), Some("required")); + } + ToolChoicePolicy::Unsupported => { + assert!( + tool_choice(&request).is_none(), + "expected tool_choice to be stripped: {request:?}" + ); + } + ToolChoicePolicy::RejectRequired => unreachable!(), + ToolChoicePolicy::CoerceRequiredToAuto { steering_message } => { + assert_eq!(tool_choice_text(&request), Some("auto")); + let expected_message = format!("user:{steering_message}"); + assert_eq!( + message_summaries(&request).last().map(String::as_str), + Some(expected_message.as_str()), + ); + } + } + } + Fixture::OutputSchemaHandling => { + let request = expect_supported(case, actual); + match profile.feature_policy.output_schema_policy { + OutputSchemaPolicy::NativeResponseFormat => { + assert_eq!( + top_level(&request, "response_format") + .and_then(|value| value.get("json_schema")) + .and_then(|value| value.get("name")) + .and_then(Value::as_str), + Some("fixture_schema"), + ); + } + OutputSchemaPolicy::Unsupported => { + assert!( + top_level(&request, "response_format").is_none(), + "unexpected response_format: {request:?}" + ); + } + } + assert_vendor_flag( + &request, + matches!( + profile.feature_policy.additional_params_policy, + AdditionalParamsPolicy::PassThrough + ), + ); + } + Fixture::AdditionalParamsPassthrough | Fixture::AdditionalParamsDrop => { + let request = expect_supported(case, actual); + assert_vendor_flag( + &request, + matches!( + profile.feature_policy.additional_params_policy, + AdditionalParamsPolicy::PassThrough + ), + ); + } + } + } + + fn expect_supported(case: Fixture, 
actual: Outcome) -> Value { + match actual { + Outcome::Supported(request) => request, + Outcome::RequestError => { + panic!("expected supported request for {case:?}, got request rejection") + } + } + } + + fn assert_vendor_flag(request: &Value, expected: bool) { + if expected { + assert_eq!(top_level(request, "vendor_flag"), Some(&json!(true))); + } else { + assert!( + top_level(request, "vendor_flag").is_none(), + "expected vendor_flag to be stripped: {request:?}" + ); + } + } + + fn assert_tool_choice(request: &Value, policy: ToolChoicePolicy, expected_pass_through: &str) { + match policy { + ToolChoicePolicy::PassThrough => { + assert_eq!(tool_choice_text(request), Some(expected_pass_through)); + } + ToolChoicePolicy::Unsupported => { + assert!( + tool_choice(request).is_none(), + "expected tool_choice to be stripped: {request:?}" + ); + } + ToolChoicePolicy::RejectRequired => { + assert_eq!(tool_choice_text(request), Some(expected_pass_through)); + } + ToolChoicePolicy::CoerceRequiredToAuto { .. } => { + assert_eq!(tool_choice_text(request), Some(expected_pass_through)); + } + } + } + + macro_rules! provider_request_conformance_tests { + ($harness:ty) => { + #[test] + fn request_conformance_model_override_precedence() { + crate::providers::openai::completion::request_conformance::assert_case::<$harness>( + crate::providers::openai::completion::request_conformance::Fixture::ModelOverridePrecedence, + ); + } + + #[test] + fn request_conformance_preamble_document_history_ordering() { + crate::providers::openai::completion::request_conformance::assert_case::<$harness>( + crate::providers::openai::completion::request_conformance::Fixture::PreambleDocumentHistoryOrdering, + ); + } + + #[test] + fn request_conformance_empty_message_rejection() { + crate::providers::openai::completion::request_conformance::assert_case::<$harness>( + crate::providers::openai::completion::request_conformance::Fixture::EmptyMessageRejection, + ); + } + + #[test] + fn request_conformance_unsupported_field_stripping() { + crate::providers::openai::completion::request_conformance::assert_case::<$harness>( + crate::providers::openai::completion::request_conformance::Fixture::UnsupportedFieldStripping, + ); + } + + #[test] + fn request_conformance_tool_choice_coercion() { + crate::providers::openai::completion::request_conformance::assert_case::<$harness>( + crate::providers::openai::completion::request_conformance::Fixture::ToolChoiceCoercion, + ); + } + + #[test] + fn request_conformance_output_schema_handling() { + crate::providers::openai::completion::request_conformance::assert_case::<$harness>( + crate::providers::openai::completion::request_conformance::Fixture::OutputSchemaHandling, + ); + } + + #[test] + fn request_conformance_additional_params_passthrough() { + crate::providers::openai::completion::request_conformance::assert_case::<$harness>( + crate::providers::openai::completion::request_conformance::Fixture::AdditionalParamsPassthrough, + ); + } + + #[test] + fn request_conformance_additional_params_drop() { + crate::providers::openai::completion::request_conformance::assert_case::<$harness>( + crate::providers::openai::completion::request_conformance::Fixture::AdditionalParamsDrop, + ); + } + }; + } + + pub(crate) use provider_request_conformance_tests; +} + +#[cfg(test)] +mod tests { + use super::{ + AdditionalParamsPolicy, CompatibleChatProfile, CompatibleFeaturePolicy, + CompatibleStreamingToolCall, OutputSchemaPolicy, ToolCallConflictPolicy, ToolChoicePolicy, + ToolsPolicy, 
apply_compatible_tool_call_deltas, build_compatible_request_core, + finalize_completed_streaming_tool_call, take_finalized_tool_calls, + }; + use crate::OneOrMany; + use crate::completion::{CompletionRequest, ToolDefinition}; + use crate::message::{Message, ToolChoice, UserContent}; + use crate::streaming::{RawStreamingChoice, RawStreamingToolCall}; + use std::collections::HashMap; + + #[test] + fn requires_non_empty_messages_when_profile_demands_it() { + let req = CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: Vec::new(), + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + }; + + let err = build_compatible_request_core( + "test-model", + req, + CompatibleChatProfile::new("Example Provider").require_messages(), + |preamble| preamble.to_owned(), + None, + |_| false, + |_message| Ok(Vec::::new()), + ) + .expect_err("empty converted messages should fail"); + + assert!(err.to_string().contains( + "Example Provider request has no provider-compatible messages after conversion" + )); + } + + #[test] + fn preserves_model_override_and_additional_params() { + let req = CompletionRequest { + model: Some("override-model".to_owned()), + preamble: Some("system".to_owned()), + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: Vec::new(), + temperature: Some(0.5), + max_tokens: Some(42), + tool_choice: None, + additional_params: Some(serde_json::json!({"foo":"bar"})), + output_schema: None, + }; + + let result = build_compatible_request_core( + "default-model", + req, + CompatibleChatProfile::new("Example Provider"), + |preamble| preamble.to_owned(), + None, + |_| false, + |message| Ok(vec![format!("{message:?}")]), + ) + .expect("request conversion should succeed"); + + assert_eq!(result.model, "override-model"); + assert_eq!(result.temperature, Some(0.5)); + assert_eq!(result.max_tokens, Some(42)); + assert_eq!( + result.additional_params, + Some(serde_json::json!({"foo":"bar"})) + ); + assert_eq!(result.messages.len(), 2); + } + + #[test] + fn feature_policy_strips_unsupported_fields() { + let mut req = CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: vec![ToolDefinition { + name: "lookup_weather".to_owned(), + description: "Lookup the weather".to_owned(), + parameters: serde_json::json!({"type":"object"}), + }], + temperature: None, + max_tokens: None, + tool_choice: Some(ToolChoice::Auto), + additional_params: Some(serde_json::json!({"foo":"bar"})), + output_schema: Some( + serde_json::from_value(serde_json::json!({ + "title": "example", + "type": "object" + })) + .expect("schema should deserialize"), + ), + }; + + CompatibleFeaturePolicy::default() + .with_tools_policy(ToolsPolicy::Unsupported) + .with_tool_choice_policy(ToolChoicePolicy::Unsupported) + .with_output_schema_policy(OutputSchemaPolicy::Unsupported) + .with_additional_params_policy(AdditionalParamsPolicy::Unsupported) + .apply("Example Provider", &mut req) + .expect("policy application should succeed"); + + assert!(req.tools.is_empty()); + assert!(req.tool_choice.is_none()); + assert!(req.additional_params.is_none()); + assert!(req.output_schema.is_none()); + } + + #[test] + fn 
tool_choice_policy_can_coerce_required_to_auto_with_steering_message() { + let req = CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: vec![ToolDefinition { + name: "lookup_weather".to_owned(), + description: "Lookup the weather".to_owned(), + parameters: serde_json::json!({"type":"object"}), + }], + temperature: None, + max_tokens: None, + tool_choice: Some(ToolChoice::Required), + additional_params: None, + output_schema: None, + }; + + let result = build_compatible_request_core( + "test-model", + req, + CompatibleChatProfile::new("Example Provider") + .coerce_required_tool_choice_to_auto("Use a tool."), + |preamble| preamble.to_owned(), + Some(|text| text.to_owned()), + |_| false, + |message| Ok(vec![format!("{message:?}")]), + ) + .expect("request conversion should succeed"); + + assert_eq!(result.tool_choice, Some(ToolChoice::Auto)); + assert_eq!(result.messages.last(), Some(&"Use a tool.".to_owned())); + } + + #[test] + fn native_response_format_merges_schema_into_additional_params() { + let req = CompletionRequest { + model: None, + preamble: None, + chat_history: OneOrMany::one(Message::User { + content: OneOrMany::one(UserContent::text("hello")), + }), + documents: Vec::new(), + tools: Vec::new(), + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: Some(serde_json::json!({"vendor_flag": true})), + output_schema: Some( + serde_json::from_value(serde_json::json!({ + "title": "example", + "type": "object", + "properties": { + "answer": { "type": "string" } + } + })) + .expect("schema should deserialize"), + ), + }; + + let result = build_compatible_request_core( + "test-model", + req, + CompatibleChatProfile::new("Example Provider").native_response_format(), + |preamble| preamble.to_owned(), + None, + |_| false, + |message| Ok(vec![format!("{message:?}")]), + ) + .expect("request conversion should succeed"); + + let additional_params = result + .additional_params + .expect("response_format should be merged"); + assert_eq!( + additional_params.get("vendor_flag"), + Some(&serde_json::json!(true)) + ); + assert!(additional_params.get("response_format").is_some()); + } + + #[test] + fn evicts_distinct_tool_calls_that_reuse_the_same_index() { + let mut tool_calls = HashMap::from([( + 0, + RawStreamingToolCall { + id: "call_1".to_owned(), + internal_call_id: "internal_1".to_owned(), + call_id: None, + name: "weather".to_owned(), + arguments: serde_json::json!({"city":"Paris"}), + signature: None, + additional_params: None, + }, + )]); + + let events = apply_compatible_tool_call_deltas::<()>( + &mut tool_calls, + [CompatibleStreamingToolCall { + index: 0, + id: Some("call_2"), + name: Some("time"), + arguments: Some("{"), + }], + ToolCallConflictPolicy::EvictDistinctIdAndName, + ); + + assert!( + matches!(events.first(), Some(RawStreamingChoice::ToolCall(tool_call)) if tool_call.id == "call_1") + ); + assert_eq!( + tool_calls.get(&0).map(|tool_call| tool_call.id.as_str()), + Some("call_2") + ); + assert_eq!( + tool_calls.get(&0).map(|tool_call| tool_call.name.as_str()), + Some("time") + ); + } + + #[test] + fn finalizes_null_arguments_into_empty_objects() { + let finalized = finalize_completed_streaming_tool_call(RawStreamingToolCall::empty()); + assert_eq!(finalized.arguments, serde_json::json!({})); + } + + #[test] + fn drains_finalized_tool_calls() { + let mut tool_calls = HashMap::from([(0, 
RawStreamingToolCall::empty())]); + + let events = take_finalized_tool_calls::<()>(&mut tool_calls); + + assert!(tool_calls.is_empty()); + assert!( + matches!(events.as_slice(), [RawStreamingChoice::ToolCall(tool_call)] if tool_call.arguments == serde_json::json!({})) + ); + } +} diff --git a/rig/rig-core/src/providers/openai/completion/mod.rs b/rig/rig-core/src/providers/openai/completion/mod.rs index a7763d7c3..8f444c4cf 100644 --- a/rig/rig-core/src/providers/openai/completion/mod.rs +++ b/rig/rig-core/src/providers/openai/completion/mod.rs @@ -22,8 +22,21 @@ use tracing::{Instrument, Level, enabled, info_span}; use std::str::FromStr; +mod family; pub mod streaming; +#[cfg(test)] +pub(crate) use family::request_conformance; +pub(crate) use family::{ + CompatibleChatProfile, CompatibleRequestCore, CompatibleStreamingToolCall, + OpenAiChatProviderProfile, ToolCallConflictPolicy, apply_compatible_tool_call_deltas, + build_compatible_request_core, build_completion_response, first_choice, map_finish_reason, + non_empty_text, take_finalized_tool_calls, take_tool_calls, +}; +pub use family::{ + CompatibleToolChoice, CompatibleToolChoiceFunctionKind, CompatibleToolChoiceKeyword, +}; + /// Serializes user content as a plain string when there's a single text item, /// otherwise as an array of content parts. fn serialize_user_content( @@ -185,10 +198,11 @@ impl Message { } } -fn history_contains_tool_result(messages: &[Message]) -> bool { - messages - .iter() - .any(|message| matches!(message, Message::ToolResult { .. })) +fn openai_user_message(content: &str) -> Message { + Message::User { + content: OneOrMany::one(UserContent::from(content.to_owned())), + name: None, + } } #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] @@ -792,9 +806,8 @@ impl TryFrom for completion::CompletionResponse Result { - let choice = response.choices.first().ok_or_else(|| { - CompletionError::ResponseError("Response contained no choices".to_owned()) - })?; + let choice = first_choice(&response.choices)?; + let stop_reason = Some(map_finish_reason(&choice.finish_reason)); let content = match &choice.message { Message::Assistant { @@ -809,11 +822,7 @@ impl TryFrom for completion::CompletionResponse text, AssistantContent::Refusal { refusal } => refusal, }; - if s.is_empty() { - None - } else { - Some(completion::AssistantContent::text(s)) - } + non_empty_text(s) }) .collect::>(); @@ -836,12 +845,6 @@ impl TryFrom for completion::CompletionResponse for completion::CompletionResponse for CompletionRequest { - type Error = CompletionError; - - fn try_from(params: OpenAIRequestParams) -> Result { - let OpenAIRequestParams { - model, - request: req, - strict_tools, - tool_result_array_content, - } = params; - - let mut partial_history = vec![]; - if let Some(docs) = req.normalized_documents() { - partial_history.push(docs); - } - let CoreCompletionRequest { - model: request_model, - preamble, - chat_history, - tools, - temperature, - max_tokens, - additional_params, - tool_choice, - output_schema, - .. - } = req; - - partial_history.extend(chat_history); - - let mut full_history: Vec = - preamble.map_or_else(Vec::new, |preamble| vec![Message::system(&preamble)]); - - full_history.extend( - partial_history - .into_iter() - .map(message::Message::try_into) - .collect::>, _>>()? 
- .into_iter() - .flatten() - .collect::>(), - ); - - if full_history.is_empty() { - return Err(CompletionError::RequestError( - std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "OpenAI Chat Completions request has no provider-compatible messages after conversion", - ) - .into(), - )); - } - - if tool_result_array_content { - for msg in &mut full_history { - if let Message::ToolResult { content, .. } = msg { - *content = content.to_array(); - } +fn build_openai_completion_request( + params: OpenAIRequestParams, + profile: CompatibleChatProfile, +) -> Result { + let OpenAIRequestParams { + model, + request: req, + strict_tools, + tool_result_array_content, + } = params; + let CompatibleRequestCore { + model, + messages: mut full_history, + temperature, + max_tokens, + tools, + tool_choice, + additional_params, + } = build_compatible_request_core( + &model, + req, + profile, + Message::system, + Some(openai_user_message), + |message| matches!(message, Message::ToolResult { .. }), + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; + + if tool_result_array_content { + for msg in &mut full_history { + if let Message::ToolResult { content, .. } = msg { + *content = content.to_array(); } } + } - let history_has_tool_result = history_contains_tool_result(&full_history); + let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?; - let tool_choice = tool_choice.map(ToolChoice::try_from).transpose()?; + let tools: Vec = tools + .into_iter() + .map(|tool| { + let def = ToolDefinition::from(tool); + if strict_tools { def.with_strict() } else { def } + }) + .collect(); - let tools: Vec = tools - .into_iter() - .map(|tool| { - let def = ToolDefinition::from(tool); - if strict_tools { def.with_strict() } else { def } - }) - .collect(); - - // Some OpenAI-compatible backends such as llama.cpp will skip tool execution - // if `response_format` is sent on the first turn alongside tools. Delay the - // schema until after the conversation contains a tool result. 
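// A minimal standalone sketch of the gating that the comment above describes,
// lifted from the removed code that follows; the shared
// build_compatible_request_core presumably preserves this behavior (an
// assumption). Packaging it as a free function here is illustrative only.
fn should_apply_response_format(
    has_output_schema: bool,
    tools_empty: bool,
    history_has_tool_result: bool,
) -> bool {
    // Delaying the schema until the history contains a tool result keeps
    // llama.cpp-style backends from skipping tool execution on the first turn.
    has_output_schema && (tools_empty || history_has_tool_result)
}

fn main() {
    // First turn with tools: hold the schema back.
    assert!(!should_apply_response_format(true, false, false));
    // After a tool result lands in history, attach response_format.
    assert!(should_apply_response_format(true, false, true));
    // No tools involved: attach immediately.
    assert!(should_apply_response_format(true, true, false));
}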
- let should_apply_response_format = - output_schema.is_some() && (tools.is_empty() || history_has_tool_result); - - // Map output_schema to OpenAI's response_format and merge into additional_params - let additional_params = if let Some(schema) = output_schema - && should_apply_response_format - { - let name = schema - .as_object() - .and_then(|o| o.get("title")) - .and_then(|v| v.as_str()) - .unwrap_or("response_schema") - .to_string(); - let mut schema_value = schema.to_value(); - super::sanitize_schema(&mut schema_value); - let response_format = serde_json::json!({ - "response_format": { - "type": "json_schema", - "json_schema": { - "name": name, - "strict": true, - "schema": schema_value - } - } - }); - Some(match additional_params { - Some(existing) => json_utils::merge(existing, response_format), - None => response_format, - }) - } else { - additional_params - }; + Ok(CompletionRequest { + model, + messages: full_history, + tools, + tool_choice, + temperature, + max_tokens, + additional_params, + }) +} - let res = Self { - model: request_model.unwrap_or(model), - messages: full_history, - tools, - tool_choice, - temperature, - max_tokens, - additional_params, - }; +impl TryFrom for CompletionRequest { + type Error = CompletionError; - Ok(res) + fn try_from(params: OpenAIRequestParams) -> Result { + build_openai_completion_request( + params, + CompatibleChatProfile::openai_chat_completions("OpenAI Chat Completions"), + ) } } @@ -1235,6 +1184,7 @@ where crate::client::Client: HttpClientExt + Clone + WasmCompatSend + WasmCompatSync + 'static, Ext: crate::client::Provider + + OpenAiChatProviderProfile + crate::client::DebugExt + Clone + WasmCompatSend @@ -1260,7 +1210,7 @@ where target: "rig::completions", "chat", gen_ai.operation.name = "chat", - gen_ai.provider.name = "openai", + gen_ai.provider.name = Ext::telemetry_provider_name(), gen_ai.request.model = self.model, gen_ai.system_instructions = &completion_request.preamble, gen_ai.response.id = tracing::field::Empty, @@ -1273,17 +1223,21 @@ where tracing::Span::current() }; - let request = CompletionRequest::try_from(OpenAIRequestParams { - model: self.model.to_owned(), - request: completion_request, - strict_tools: self.strict_tools, - tool_result_array_content: self.tool_result_array_content, - })?; + let request = build_openai_completion_request( + OpenAIRequestParams { + model: self.model.to_owned(), + request: completion_request, + strict_tools: self.strict_tools, + tool_result_array_content: self.tool_result_array_content, + }, + Ext::request_profile(), + )?; if enabled!(Level::TRACE) { tracing::trace!( target: "rig::completions", - "OpenAI Chat Completions completion request: {}", + "{} completion request: {}", + Ext::provider_name(), serde_json::to_string_pretty(&request)? ); } @@ -1292,7 +1246,7 @@ where let req = self .client - .post("/chat/completions")? + .post(Ext::completions_path())? .body(body) .map_err(|e| CompletionError::HttpError(e.into()))?; @@ -1311,7 +1265,8 @@ where if enabled!(Level::TRACE) { tracing::trace!( target: "rig::completions", - "OpenAI Chat Completions completion response: {}", + "{} completion response: {}", + Ext::provider_name(), serde_json::to_string_pretty(&response)? 
); } @@ -1354,10 +1309,50 @@ where } } +#[cfg(test)] +mod conformance_tests; + #[cfg(test)] mod tests { use super::*; + struct OpenAiRequestHarness; + + impl request_conformance::Harness for OpenAiRequestHarness { + fn family_name() -> &'static str { + "openai-chat" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + CompletionRequest::try_from(OpenAIRequestParams { + model: "default-model".to_owned(), + request, + strict_tools: false, + tool_result_array_content: false, + }) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + CompatibleChatProfile::openai_chat_completions("OpenAI Chat Completions"), + ), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(OpenAiRequestHarness); + #[test] fn test_openai_request_uses_request_model_override() { let request = crate::completion::CompletionRequest { diff --git a/rig/rig-core/src/providers/openai/completion/streaming.rs b/rig/rig-core/src/providers/openai/completion/streaming.rs index add63cf4b..a727300ca 100644 --- a/rig/rig-core/src/providers/openai/completion/streaming.rs +++ b/rig/rig-core/src/providers/openai/completion/streaming.rs @@ -12,7 +12,11 @@ use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage}; use crate::http_client::HttpClientExt; use crate::http_client::sse::{Event, GenericEventSource}; use crate::json_utils::{self, merge}; -use crate::providers::openai::completion::{GenericCompletionModel, OpenAIRequestParams, Usage}; +use crate::providers::openai::completion::{ + CompatibleStreamingToolCall, GenericCompletionModel, OpenAIRequestParams, + OpenAiChatProviderProfile, ToolCallConflictPolicy, Usage, apply_compatible_tool_call_deltas, + take_finalized_tool_calls, take_tool_calls, +}; use crate::streaming::{self, RawStreamingChoice}; // ================================================================ @@ -84,22 +88,35 @@ impl GetTokenUsage for StreamingCompletionResponse { } } +fn map_finish_reason(reason: &FinishReason) -> crate::completion::StopReason { + match reason { + FinishReason::ToolCalls => crate::completion::StopReason::ToolCalls, + FinishReason::Stop => crate::completion::StopReason::Stop, + FinishReason::ContentFilter => crate::completion::StopReason::ContentFilter, + FinishReason::Length => crate::completion::StopReason::MaxTokens, + FinishReason::Other(other) => crate::completion::StopReason::Other(other.clone()), + } +} + impl GenericCompletionModel where crate::client::Client: HttpClientExt + Clone + 'static, - Ext: crate::client::Provider + Clone + 'static, + Ext: crate::client::Provider + OpenAiChatProviderProfile + Clone + 'static, { pub(crate) async fn stream( &self, completion_request: CompletionRequest, ) -> Result, CompletionError> { - let request = super::CompletionRequest::try_from(OpenAIRequestParams { - model: self.model.clone(), - request: completion_request, - strict_tools: self.strict_tools, - tool_result_array_content: self.tool_result_array_content, - })?; + let request = super::build_openai_completion_request( + OpenAIRequestParams { + model: self.model.clone(), + request: completion_request, + strict_tools: self.strict_tools, + tool_result_array_content: self.tool_result_array_content, + }, + Ext::request_profile(), + )?; let request_messages = 
serde_json::to_string(&request.messages) .expect("Converting to JSON from a Rust struct shouldn't fail"); let mut request_as_json = serde_json::to_value(request).expect("this should never fail"); @@ -112,7 +129,8 @@ where if enabled!(Level::TRACE) { tracing::trace!( target: "rig::completions", - "OpenAI Chat Completions streaming completion request: {}", + "{} streaming completion request: {}", + Ext::provider_name(), serde_json::to_string_pretty(&request_as_json)? ); } @@ -121,7 +139,7 @@ where let req = self .client - .post("/chat/completions")? + .post(Ext::completions_path())? .body(req_body) .map_err(|e| CompletionError::HttpError(e.into()))?; @@ -130,7 +148,7 @@ where target: "rig::completions", "chat", gen_ai.operation.name = "chat", - gen_ai.provider.name = "openai", + gen_ai.provider.name = Ext::telemetry_provider_name(), gen_ai.request.model = self.model, gen_ai.response.id = tracing::field::Empty, gen_ai.response.model = self.model, @@ -146,7 +164,15 @@ where let client = self.client.clone(); - tracing::Instrument::instrument(send_compatible_streaming_request(client, req), span).await + tracing::Instrument::instrument( + send_compatible_streaming_request_with_policy( + client, + req, + Ext::stream_tool_call_conflict_policy(), + ), + span, + ) + .await } } @@ -154,6 +180,22 @@ pub async fn send_compatible_streaming_request( http_client: T, req: Request>, ) -> Result, CompletionError> +where + T: HttpClientExt + Clone + 'static, +{ + send_compatible_streaming_request_with_policy( + http_client, + req, + ToolCallConflictPolicy::EvictDistinctIdAndName, + ) + .await +} + +pub(crate) async fn send_compatible_streaming_request_with_policy( + http_client: T, + req: Request>, + conflict_policy: ToolCallConflictPolicy, +) -> Result, CompletionError> where T: HttpClientExt + Clone + 'static, { @@ -202,75 +244,17 @@ where let delta = &choice.delta; if !delta.tool_calls.is_empty() { - for tool_call in &delta.tool_calls { - let index = tool_call.index; - - // Some API gateways (e.g. LiteLLM, OneAPI) emit multiple - // distinct tool calls all sharing index 0. Detect this by - // comparing both the `id` and `name`: only evict when a new - // chunk carries a different non-empty id AND a different - // non-empty name. Checking the name prevents false evictions - // from providers (e.g. GLM-4) that send a unique id on every - // SSE chunk for the same logical tool call. 
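// A minimal, self-contained sketch of the eviction heuristic the comment
// above describes. `PartialCall`, `ConflictPolicy`, and `apply_delta` are
// hypothetical stand-ins for illustration, not rig's actual
// `apply_compatible_tool_call_deltas` API.
use std::collections::HashMap;

#[derive(Debug, Default, Clone)]
struct PartialCall {
    id: String,
    name: String,
    arguments: String,
}

#[derive(Clone, Copy)]
enum ConflictPolicy {
    // Evict the buffered call when a chunk carries a different non-empty id
    // AND a different non-empty name (gateways that reuse index 0).
    EvictDistinctIdAndName,
    // Always merge into the same index (providers that send a fresh id on
    // every SSE chunk for one logical call, e.g. the GLM-4 case above).
    KeepIndex,
}

// Merges one streamed chunk into the per-index buffer; returns a call that
// had to be evicted and should be emitted as complete, if any.
fn apply_delta(
    buffer: &mut HashMap<usize, PartialCall>,
    index: usize,
    id: Option<&str>,
    name: Option<&str>,
    args_chunk: Option<&str>,
    policy: ConflictPolicy,
) -> Option<PartialCall> {
    let mut evicted = None;
    if let (ConflictPolicy::EvictDistinctIdAndName, Some(new_id), Some(new_name)) =
        (policy, id, name)
    {
        if let Some(existing) = buffer.get(&index) {
            if !new_id.is_empty()
                && !new_name.is_empty()
                && !existing.id.is_empty()
                && existing.id != new_id
                && !existing.name.is_empty()
                && existing.name != new_name
            {
                evicted = buffer.remove(&index);
            }
        }
    }
    let entry = buffer.entry(index).or_default();
    if let Some(id) = id.filter(|s| !s.is_empty()) {
        entry.id = id.to_owned();
    }
    if let Some(name) = name.filter(|s| !s.is_empty()) {
        entry.name = name.to_owned();
    }
    if let Some(chunk) = args_chunk {
        entry.arguments.push_str(chunk);
    }
    evicted
}

fn main() {
    // Two distinct calls arriving on the same index: the first is evicted.
    let mut buffer = HashMap::new();
    apply_delta(&mut buffer, 0, Some("call_1"), Some("weather"), Some("{}"),
        ConflictPolicy::EvictDistinctIdAndName);
    let evicted = apply_delta(&mut buffer, 0, Some("call_2"), Some("time"), Some("{"),
        ConflictPolicy::EvictDistinctIdAndName);
    assert_eq!(evicted.unwrap().id, "call_1");
    assert_eq!(buffer[&0].name, "time");

    // KeepIndex merges fresh ids into the same logical call instead.
    let mut keep = HashMap::new();
    apply_delta(&mut keep, 0, Some("id_a"), Some("weather"), Some("{\"ci"),
        ConflictPolicy::KeepIndex);
    let none = apply_delta(&mut keep, 0, Some("id_b"), None, Some("ty\":\"Paris\"}"),
        ConflictPolicy::KeepIndex);
    assert!(none.is_none());
    assert_eq!(keep[&0].arguments, "{\"city\":\"Paris\"}");
}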
- if let Some(new_id) = &tool_call.id - && !new_id.is_empty() - && let Some(new_name) = &tool_call.function.name - && !new_name.is_empty() - && let Some(existing) = tool_calls.get(&index) - && !existing.id.is_empty() - && existing.id != *new_id - && !existing.name.is_empty() - && existing.name != *new_name - { - let evicted = tool_calls.remove(&index).expect("checked above"); - yield Ok(streaming::RawStreamingChoice::ToolCall( - finalize_completed_streaming_tool_call(evicted), - )); - } - - let existing_tool_call = tool_calls.entry(index).or_insert_with(streaming::RawStreamingToolCall::empty); - - if let Some(id) = &tool_call.id && !id.is_empty() { - existing_tool_call.id = id.clone(); - } - - if let Some(name) = &tool_call.function.name && !name.is_empty() { - existing_tool_call.name = name.clone(); - yield Ok(streaming::RawStreamingChoice::ToolCallDelta { - id: existing_tool_call.id.clone(), - internal_call_id: existing_tool_call.internal_call_id.clone(), - content: streaming::ToolCallDeltaContent::Name(name.clone()), - }); - } - - // Convert current arguments to string if needed - if let Some(chunk) = &tool_call.function.arguments && !chunk.is_empty() { - let current_args = match &existing_tool_call.arguments { - serde_json::Value::Null => String::new(), - serde_json::Value::String(s) => s.clone(), - v => v.to_string(), - }; - - // Concatenate the new chunk - let combined = format!("{current_args}{chunk}"); - - // Try to parse as JSON if it looks complete - if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') { - match serde_json::from_str(&combined) { - Ok(parsed) => existing_tool_call.arguments = parsed, - Err(_) => existing_tool_call.arguments = serde_json::Value::String(combined), - } - } else { - existing_tool_call.arguments = serde_json::Value::String(combined); - } - - // Emit the delta so UI can show progress - yield Ok(streaming::RawStreamingChoice::ToolCallDelta { - id: existing_tool_call.id.clone(), - internal_call_id: existing_tool_call.internal_call_id.clone(), - content: streaming::ToolCallDeltaContent::Delta(chunk.clone()), - }); - } + for event in apply_compatible_tool_call_deltas( + &mut tool_calls, + delta.tool_calls.iter().map(|tool_call| CompatibleStreamingToolCall { + index: tool_call.index, + id: tool_call.id.as_deref(), + name: tool_call.function.name.as_deref(), + arguments: tool_call.function.arguments.as_deref(), + }), + conflict_policy, + ) { + yield Ok(event); } } @@ -289,13 +273,16 @@ where } // Finish reason - if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls { - for (_idx, tool_call) in tool_calls.into_iter() { - yield Ok(streaming::RawStreamingChoice::ToolCall( - finalize_completed_streaming_tool_call(tool_call), - )); + if let Some(finish_reason) = &choice.finish_reason { + if *finish_reason == FinishReason::ToolCalls { + for tool_call in take_finalized_tool_calls(&mut tool_calls) { + yield Ok(tool_call); + } } - tool_calls = HashMap::new(); + + yield Ok(streaming::RawStreamingChoice::StopReason(map_finish_reason( + finish_reason, + ))); } } Err(crate::http_client::Error::StreamEnded) => { @@ -314,8 +301,8 @@ where event_source.close(); // Flush any accumulated tool calls (that weren't emitted as ToolCall earlier) - for (_idx, tool_call) in tool_calls.into_iter() { - yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call)); + for tool_call in take_tool_calls(&mut tool_calls) { + yield Ok(tool_call); } let final_usage = final_usage.unwrap_or_default(); @@ -342,16 +329,6 @@ where 
)))
}

-fn finalize_completed_streaming_tool_call(
-    mut tool_call: streaming::RawStreamingToolCall,
-) -> streaming::RawStreamingToolCall {
-    if tool_call.arguments.is_null() {
-        tool_call.arguments = serde_json::Value::Object(serde_json::Map::new());
-    }
-
-    tool_call
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/rig/rig-core/src/providers/openai/responses_api/mod.rs b/rig/rig-core/src/providers/openai/responses_api/mod.rs
index b17493ef7..64efa29ce 100644
--- a/rig/rig-core/src/providers/openai/responses_api/mod.rs
+++ b/rig/rig-core/src/providers/openai/responses_api/mod.rs
@@ -1427,12 +1427,6 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse>
     fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
-        if response.output.is_empty() {
-            return Err(CompletionError::ResponseError(
-                "Response contained no parts".to_owned(),
-            ));
-        }
-
         // Extract the msg_ ID from the first Output::Message item
         let message_id = response.output.iter().find_map(|item| match item {
             Output::Message(msg) => Some(msg.id.clone()),
@@ -1446,11 +1440,7 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse>
             .map(Vec::<completion::AssistantContent>::from)
             .collect();
 
-        let choice = OneOrMany::many(content).map_err(|_| {
-            CompletionError::ResponseError(
-                "Response contained no message or tool call (empty)".to_owned(),
-            )
-        })?;
+        let choice = completion::AssistantChoice::from(content);
 
         let usage = response
             .usage
@@ -1473,6 +1463,7 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse>
             raw_response: response,
             message_id,
+            stop_reason,
         })
     }
 }
 
-fn is_empty_choice(choice: &OneOrMany<completion::AssistantContent>) -> bool {
-    choice.iter().all(|content| match content {
-        completion::AssistantContent::Text(text) => text.text.trim().is_empty(),
-        completion::AssistantContent::Reasoning(reasoning) => reasoning.content.is_empty(),
-        completion::AssistantContent::Image(_) => false,
-        completion::AssistantContent::ToolCall(_) => false,
-    })
-}
diff --git a/rig/rig-core/src/providers/openrouter/completion.rs b/rig/rig-core/src/providers/openrouter/completion.rs
index 3b2f841ec..0802ece60 100644
--- a/rig/rig-core/src/providers/openrouter/completion.rs
+++ b/rig/rig-core/src/providers/openrouter/completion.rs
@@ -592,9 +592,11 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse>
     fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
-        let choice = response.choices.first().ok_or_else(|| {
-            CompletionError::ResponseError("Response contained no choices".to_owned())
-        })?;
+        let choice = crate::providers::openai::completion::first_choice(&response.choices)?;
+        let stop_reason = choice
+            .finish_reason
+            .as_deref()
+            .map(crate::providers::openai::completion::map_finish_reason);
 
         let content = match &choice.message {
             Message::Assistant {
@@ -704,12 +706,6 @@ impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse>
 impl TryFrom<Message> for Vec<Message> {
     }
 }
 
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(untagged, rename_all = "snake_case")]
-pub enum ToolChoice {
-    None,
-    Auto,
-    Required,
-    Function(Vec<ToolChoiceFunctionKind>),
-}
-
-impl TryFrom<crate::message::ToolChoice> for ToolChoice {
-    type Error = CompletionError;
-
-    fn try_from(value: crate::message::ToolChoice) -> Result<Self, Self::Error> {
-        let res = match value {
-            crate::message::ToolChoice::None => Self::None,
-            crate::message::ToolChoice::Auto => Self::Auto,
-            crate::message::ToolChoice::Required => Self::Required,
-            crate::message::ToolChoice::Specific { function_names } => {
-                let vec: Vec<ToolChoiceFunctionKind> = function_names
-                    .into_iter()
-                    .map(|name| ToolChoiceFunctionKind::Function { name })
-                    .collect();
-
-                Self::Function(vec)
-            }
-        };
-
-        Ok(res)
-    }
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(tag = "type", content = "function")]
-pub enum ToolChoiceFunctionKind {
-    Function { name: String },
-}
+pub type ToolChoice =
crate::providers::openai::completion::CompatibleToolChoice; +pub type ToolChoiceFunctionKind = + crate::providers::openai::completion::CompatibleToolChoiceFunctionKind; +pub type ToolChoiceKeyword = crate::providers::openai::completion::CompatibleToolChoiceKeyword; #[derive(Debug, Serialize, Deserialize)] pub(super) struct OpenrouterCompletionRequest { @@ -1561,7 +1528,7 @@ pub(super) struct OpenrouterCompletionRequest { #[serde(skip_serializing_if = "Vec::is_empty")] tools: Vec, #[serde(skip_serializing_if = "Option::is_none")] - tool_choice: Option, + tool_choice: Option, #[serde(flatten, skip_serializing_if = "Option::is_none")] pub additional_params: Option, } @@ -1582,42 +1549,27 @@ impl TryFrom> for OpenrouterCompletionRequest { request: req, strict_tools, } = params; - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for OpenRouter"); - } - - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![Message::system(preamble)], - None => vec![], - }; - if let Some(docs) = req.normalized_documents() { - let docs: Vec = docs.try_into()?; - full_history.extend(docs); - } - - let chat_history: Vec = req - .chat_history - .clone() - .into_iter() - .map(|message| message.try_into()) - .collect::>, _>>()? - .into_iter() - .flatten() - .collect(); - - full_history.extend(chat_history); + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools, + tool_choice, + additional_params, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("OpenRouter"), + Message::system, + None, + |_| false, + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; - let tool_choice = req - .tool_choice - .clone() - .map(crate::providers::openai::completion::ToolChoice::try_from) - .transpose()?; + let tool_choice = tool_choice.map(ToolChoice::from); - let tools: Vec = req - .tools - .clone() + let tools: Vec = tools .into_iter() .map(|tool| { let def = crate::providers::openai::completion::ToolDefinition::from(tool); @@ -1627,11 +1579,11 @@ impl TryFrom> for OpenrouterCompletionRequest { Ok(Self { model, - messages: full_history, - temperature: req.temperature, + messages, + temperature, tools, tool_choice, - additional_params: req.additional_params, + additional_params, }) } } @@ -1796,6 +1748,47 @@ mod tests { use super::*; use serde_json::json; + struct OpenRouterRequestHarness; + + impl crate::providers::openai::completion::request_conformance::Harness + for OpenRouterRequestHarness + { + fn family_name() -> &'static str { + "openrouter" + } + + fn run( + case: crate::providers::openai::completion::request_conformance::Fixture, + ) -> crate::providers::openai::completion::request_conformance::Outcome + { + crate::providers::openai::completion::request_conformance::serialize_case( + case, + |request| OpenrouterCompletionRequest::try_from(("default-model", request)), + ) + } + + fn assert( + case: crate::providers::openai::completion::request_conformance::Fixture, + actual: crate::providers::openai::completion::request_conformance::Outcome< + serde_json::Value, + >, + ) { + crate::providers::openai::completion::request_conformance::assert_compatible_chat_case( + crate::providers::openai::completion::request_conformance::CompatibleChatExpectation::new( + 
crate::providers::openai::completion::CompatibleChatProfile::new("OpenRouter"), + ) + .preserves_reasoning_assistant_history(), + "default-model", + case, + actual, + ); + } + } + + crate::providers::openai::completion::request_conformance::provider_request_conformance_tests!( + OpenRouterRequestHarness + ); + #[test] fn test_openrouter_request_uses_request_model_override() { let request = CompletionRequest { diff --git a/rig/rig-core/src/providers/openrouter/streaming.rs b/rig/rig-core/src/providers/openrouter/streaming.rs index b057f7d8c..3a389727d 100644 --- a/rig/rig-core/src/providers/openrouter/streaming.rs +++ b/rig/rig-core/src/providers/openrouter/streaming.rs @@ -12,6 +12,10 @@ use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage}; use crate::http_client::HttpClientExt; use crate::http_client::sse::{Event, GenericEventSource}; use crate::json_utils; +use crate::providers::openai::completion::{ + CompatibleStreamingToolCall, ToolCallConflictPolicy, apply_compatible_tool_call_deltas, + take_finalized_tool_calls, take_tool_calls, +}; use crate::providers::openrouter::{ OpenRouterRequestParams, OpenrouterCompletionRequest, ReasoningDetails, }; @@ -34,6 +38,17 @@ impl GetTokenUsage for StreamingCompletionResponse { } } +fn map_finish_reason(reason: &FinishReason) -> crate::completion::StopReason { + match reason { + FinishReason::ToolCalls => crate::completion::StopReason::ToolCalls, + FinishReason::Stop => crate::completion::StopReason::Stop, + FinishReason::Error => crate::completion::StopReason::Other("error".to_owned()), + FinishReason::ContentFilter => crate::completion::StopReason::ContentFilter, + FinishReason::Length => crate::completion::StopReason::MaxTokens, + FinishReason::Other(other) => crate::completion::StopReason::Other(other.clone()), + } +} + #[derive(Deserialize, Debug, PartialEq)] #[serde(rename_all = "snake_case")] pub enum FinishReason { @@ -212,54 +227,17 @@ where let delta = &choice.delta; if !delta.tool_calls.is_empty() { - for tool_call in &delta.tool_calls { - let index = tool_call.index; - - // Get or create tool call entry - let existing_tool_call = tool_calls.entry(index).or_insert_with(streaming::RawStreamingToolCall::empty); - - // Update fields if present - if let Some(id) = &tool_call.id && !id.is_empty() { - existing_tool_call.id = id.clone(); - } - - if let Some(name) = &tool_call.function.name && !name.is_empty() { - existing_tool_call.name = name.clone(); - yield Ok(streaming::RawStreamingChoice::ToolCallDelta { - id: existing_tool_call.id.clone(), - internal_call_id: existing_tool_call.internal_call_id.clone(), - content: streaming::ToolCallDeltaContent::Name(name.clone()), - }); - } - - // Convert current arguments to string if needed - if let Some(chunk) = &tool_call.function.arguments && !chunk.is_empty() { - let current_args = match &existing_tool_call.arguments { - serde_json::Value::Null => String::new(), - serde_json::Value::String(s) => s.clone(), - v => v.to_string(), - }; - - // Concatenate the new chunk - let combined = format!("{current_args}{chunk}"); - - // Try to parse as JSON if it looks complete - if combined.trim_start().starts_with('{') && combined.trim_end().ends_with('}') { - match serde_json::from_str(&combined) { - Ok(parsed) => existing_tool_call.arguments = parsed, - Err(_) => existing_tool_call.arguments = serde_json::Value::String(combined), - } - } else { - existing_tool_call.arguments = serde_json::Value::String(combined); - } - - // Emit the delta so UI can show progress - yield 
Ok(streaming::RawStreamingChoice::ToolCallDelta { - id: existing_tool_call.id.clone(), - internal_call_id: existing_tool_call.internal_call_id.clone(), - content: streaming::ToolCallDeltaContent::Delta(chunk.clone()), - }); - } + for event in apply_compatible_tool_call_deltas( + &mut tool_calls, + delta.tool_calls.iter().map(|tool_call| CompatibleStreamingToolCall { + index: tool_call.index, + id: tool_call.id.as_deref(), + name: tool_call.function.name.as_deref(), + arguments: tool_call.function.arguments.as_deref(), + }), + ToolCallConflictPolicy::KeepIndex, + ) { + yield Ok(event); } // Update the signature and the additional params of the tool call if present @@ -294,12 +272,15 @@ where // Finish reason if let Some(finish_reason) = &choice.finish_reason && *finish_reason == FinishReason::ToolCalls { - for (_idx, tool_call) in tool_calls.into_iter() { - yield Ok(streaming::RawStreamingChoice::ToolCall( - finalize_completed_streaming_tool_call(tool_call), - )); + for tool_call in take_finalized_tool_calls(&mut tool_calls) { + yield Ok(tool_call); } - tool_calls = HashMap::new(); + } + + if let Some(finish_reason) = &choice.finish_reason { + yield Ok(streaming::RawStreamingChoice::StopReason(map_finish_reason( + finish_reason, + ))); } } Err(crate::http_client::Error::StreamEnded) => { @@ -317,8 +298,8 @@ where event_source.close(); // Flush any accumulated tool calls (that weren't emitted as ToolCall earlier) - for (_idx, tool_call) in tool_calls.into_iter() { - yield Ok(streaming::RawStreamingChoice::ToolCall(tool_call)); + for tool_call in take_tool_calls(&mut tool_calls) { + yield Ok(tool_call); } // Final response with usage @@ -332,16 +313,6 @@ where ))) } -fn finalize_completed_streaming_tool_call( - mut tool_call: streaming::RawStreamingToolCall, -) -> streaming::RawStreamingToolCall { - if tool_call.arguments.is_null() { - tool_call.arguments = Value::Object(serde_json::Map::new()); - } - - tool_call -} - #[cfg(test)] mod tests { use super::*; @@ -548,4 +519,39 @@ mod tests { assert_eq!(error.code, 500); assert_eq!(error.message, "Provider disconnected"); } + + #[tokio::test] + async fn test_stream_captures_stop_reason() { + use crate::http_client::mock::MockStreamingClient; + use bytes::Bytes; + use futures::StreamExt; + + let sse = concat!( + "data: {\"id\":\"gen-1\",\"model\":\"gpt-4\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"},\"finish_reason\":\"stop\"}],\"usage\":null}\n\n", + "data: {\"id\":\"gen-2\",\"model\":\"gpt-4\",\"choices\":[],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n\n", + "data: [DONE]\n\n", + ); + + let client = MockStreamingClient { + sse_bytes: Bytes::from(sse), + }; + let req = http::Request::builder() + .method("POST") + .uri("http://localhost/v1/chat/completions") + .body(Vec::new()) + .unwrap(); + + let mut stream = send_compatible_streaming_request(client, req) + .await + .expect("stream should start"); + + while let Some(chunk) = stream.next().await { + chunk.expect("stream chunk should deserialize"); + } + + assert_eq!( + stream.stop_reason, + Some(crate::completion::StopReason::Stop) + ); + } } diff --git a/rig/rig-core/src/providers/perplexity.rs b/rig/rig-core/src/providers/perplexity.rs index 50a7f64b7..5cb942a67 100644 --- a/rig/rig-core/src/providers/perplexity.rs +++ b/rig/rig-core/src/providers/perplexity.rs @@ -14,7 +14,6 @@ use crate::providers::openai; use crate::providers::openai::send_compatible_streaming_request; use crate::streaming::StreamingCompletionResponse; use crate::{ - 
OneOrMany, client::{ self, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder, ProviderClient, }, @@ -176,26 +175,38 @@ impl TryFrom for completion::CompletionResponse Result { - let choice = response.choices.first().ok_or_else(|| { - CompletionError::ResponseError("Response contained no choices".to_owned()) - })?; + let choice = crate::providers::openai::completion::first_choice(&response.choices)?; + let stop_reason = Some(crate::providers::openai::completion::map_finish_reason( + &choice.finish_reason, + )); match &choice.message { Message { role: Role::Assistant, content, - } => Ok(completion::CompletionResponse { - choice: OneOrMany::one(content.clone().into()), - usage: completion::Usage { + } => { + let normalized_content = + crate::providers::openai::completion::non_empty_text(content) + .into_iter() + .collect::>(); + let usage = completion::Usage { input_tokens: response.usage.prompt_tokens as u64, output_tokens: response.usage.completion_tokens as u64, total_tokens: response.usage.total_tokens as u64, cached_input_tokens: 0, cache_creation_input_tokens: 0, - }, - raw_response: response, - message_id: None, - }), + }; + + Ok( + crate::providers::openai::completion::build_completion_response( + response, + usage, + None, + stop_reason, + normalized_content, + ), + ) + } _ => Err(CompletionError::ResponseError( "Response contained no assistant message".to_owned(), )), @@ -216,42 +227,94 @@ pub(super) struct PerplexityCompletionRequest { pub stream: bool, } +fn perplexity_request_messages(message: message::Message) -> Result, MessageError> { + match message { + message::Message::System { content } => Ok(vec![Message { + role: Role::System, + content, + }]), + message::Message::User { content } => { + let collapsed_content = content + .into_iter() + .filter_map(|content| match content { + message::UserContent::Text(message::Text { text }) => Some(text), + message::UserContent::Document(document) => match document.data { + crate::message::DocumentSourceKind::Base64(content) + | crate::message::DocumentSourceKind::String(content) => Some(content), + _ => None, + }, + _ => None, + }) + .collect::>() + .join("\n"); + + if collapsed_content.is_empty() { + Ok(vec![]) + } else { + Ok(vec![Message { + role: Role::User, + content: collapsed_content, + }]) + } + } + message::Message::Assistant { content, .. 
} => { + let collapsed_content = content + .into_iter() + .filter_map(|content| match content { + message::AssistantContent::Text(message::Text { text }) => Some(text), + _ => None, + }) + .collect::>() + .join("\n"); + + if collapsed_content.is_empty() { + Ok(vec![]) + } else { + Ok(vec![Message { + role: Role::Assistant, + content: collapsed_content, + }]) + } + } + } +} + impl TryFrom<(&str, CompletionRequest)> for PerplexityCompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for Perplexity"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - let mut partial_history = vec![]; - if let Some(docs) = req.normalized_documents() { - partial_history.push(docs); - } - partial_history.extend(req.chat_history); - - // Initialize full history with preamble (or empty if non-existent) - let mut full_history: Vec = req.preamble.map_or_else(Vec::new, |preamble| { - vec![Message { + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens, + tools: _, + tool_choice: _, + additional_params, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("Perplexity") + .unsupported_tools() + .unsupported_tool_choice(), + |preamble| Message { role: Role::System, - content: preamble, - }] - }); - - // Convert and extend the rest of the history - full_history.extend( - partial_history - .into_iter() - .map(message::Message::try_into) - .collect::, _>>()?, - ); + content: preamble.to_owned(), + }, + None, + |_| false, + |message| { + perplexity_request_messages(message) + .map_err(|err| CompletionError::RequestError(err.into())) + }, + )?; Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - max_tokens: req.max_tokens, - additional_params: req.additional_params, + model, + messages, + temperature, + max_tokens, + additional_params, stream: false, }) } @@ -373,13 +436,6 @@ where span.record("gen_ai.system_instructions", &completion_request.preamble); - if completion_request.tool_choice.is_some() { - tracing::warn!("WARNING: `tool_choice` not supported on Perplexity"); - } - - if !completion_request.tools.is_empty() { - tracing::warn!("WARNING: `tools` not supported on Perplexity"); - } let request = PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?; @@ -459,14 +515,6 @@ where span.record("gen_ai.system_instructions", &completion_request.preamble); - if completion_request.tool_choice.is_some() { - tracing::warn!("WARNING: `tool_choice` not supported on Perplexity"); - } - - if !completion_request.tools.is_empty() { - tracing::warn!("WARNING: `tools` not supported on Perplexity"); - } - let mut request = PerplexityCompletionRequest::try_from((self.model.as_ref(), completion_request))?; request.stream = true; @@ -495,6 +543,41 @@ where #[cfg(test)] mod tests { use super::*; + use crate::providers::openai::completion::{CompatibleChatProfile, request_conformance}; + + struct PerplexityRequestHarness; + + impl request_conformance::Harness for PerplexityRequestHarness { + fn family_name() -> &'static str { + "perplexity" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + 
PerplexityCompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + CompatibleChatProfile::new("Perplexity") + .unsupported_tools() + .unsupported_tool_choice(), + ), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(PerplexityRequestHarness); #[test] fn test_deserialize_message() { diff --git a/rig/rig-core/src/providers/together/completion.rs b/rig/rig-core/src/providers/together/completion.rs index e5b24ee05..9f688edbd 100644 --- a/rig/rig-core/src/providers/together/completion.rs +++ b/rig/rig-core/src/providers/together/completion.rs @@ -147,58 +147,38 @@ impl TryFrom<(&str, CompletionRequest)> for TogetherAICompletionRequest { type Error = CompletionError; fn try_from((model, req): (&str, CompletionRequest)) -> Result { - if req.output_schema.is_some() { - tracing::warn!("Structured outputs currently not supported for TogetherAI"); - } - let model = req.model.clone().unwrap_or_else(|| model.to_string()); - let mut full_history: Vec = match &req.preamble { - Some(preamble) => vec![openai::Message::system(preamble)], - None => vec![], - }; - if let Some(docs) = req.normalized_documents() { - let docs: Vec = docs.try_into()?; - full_history.extend(docs); - } - - let chat_history: Vec = req - .chat_history - .into_iter() - .map(|message| message.try_into()) - .collect::>, _>>()? - .into_iter() - .flatten() - .collect(); - - full_history.extend(chat_history); - - if full_history.is_empty() { - return Err(CompletionError::RequestError( - std::io::Error::new( - std::io::ErrorKind::InvalidInput, - "Together request has no provider-compatible messages after conversion", - ) - .into(), - )); - } - - let tool_choice = req - .tool_choice - .clone() - .map(ToolChoice::try_from) - .transpose()?; + let crate::providers::openai::completion::CompatibleRequestCore { + model, + messages, + temperature, + max_tokens: _, + tools, + tool_choice, + additional_params, + } = crate::providers::openai::completion::build_compatible_request_core( + model, + req, + crate::providers::openai::completion::CompatibleChatProfile::new("TogetherAI") + .require_messages() + .reject_required_tool_choice(), + openai::Message::system, + None, + |_| false, + |message| Vec::::try_from(message).map_err(CompletionError::from), + )?; + + let tool_choice = tool_choice.map(ToolChoice::from); Ok(Self { - model: model.to_string(), - messages: full_history, - temperature: req.temperature, - tools: req - .tools - .clone() + model, + messages, + temperature, + tools: tools .into_iter() .map(crate::providers::openai::completion::ToolDefinition::from) .collect::>(), tool_choice, - additional_params: req.additional_params, + additional_params, }) } } @@ -324,51 +304,51 @@ where } } -#[derive(Debug, Serialize, Deserialize)] -#[serde(untagged, rename_all = "snake_case")] -pub enum ToolChoice { - None, - Auto, - Function(Vec), -} - -impl TryFrom for ToolChoice { - type Error = CompletionError; - - fn try_from(value: crate::message::ToolChoice) -> Result { - let res = match value { - crate::message::ToolChoice::None => Self::None, - crate::message::ToolChoice::Auto => Self::Auto, - crate::message::ToolChoice::Specific { function_names } => { - let vec: Vec = function_names - .into_iter() - .map(|name| ToolChoiceFunctionKind::Function { name }) - .collect(); - - 
Self::Function(vec) - } - choice => { - return Err(CompletionError::ProviderError(format!( - "Unsupported tool choice type: {choice:?}" - ))); - } - }; - - Ok(res) - } -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(tag = "type", content = "function")] -pub enum ToolChoiceFunctionKind { - Function { name: String }, -} +pub type ToolChoice = crate::providers::openai::completion::CompatibleToolChoice; +pub type ToolChoiceFunctionKind = + crate::providers::openai::completion::CompatibleToolChoiceFunctionKind; +pub type ToolChoiceKeyword = crate::providers::openai::completion::CompatibleToolChoiceKeyword; #[cfg(test)] mod tests { use super::*; + use crate::providers::openai::completion::request_conformance; use crate::{OneOrMany, message}; + struct TogetherRequestHarness; + + impl request_conformance::Harness for TogetherRequestHarness { + fn family_name() -> &'static str { + "together-ai" + } + + fn run( + case: request_conformance::Fixture, + ) -> request_conformance::Outcome { + request_conformance::serialize_case(case, |request| { + TogetherAICompletionRequest::try_from(("default-model", request)) + }) + } + + fn assert( + case: request_conformance::Fixture, + actual: request_conformance::Outcome, + ) { + request_conformance::assert_compatible_chat_case( + request_conformance::CompatibleChatExpectation::new( + crate::providers::openai::completion::CompatibleChatProfile::new("TogetherAI") + .require_messages() + .reject_required_tool_choice(), + ), + "default-model", + case, + actual, + ); + } + } + + request_conformance::provider_request_conformance_tests!(TogetherRequestHarness); + #[test] fn together_request_conversion_errors_when_all_messages_are_filtered() { let request = CompletionRequest { @@ -390,4 +370,25 @@ mod tests { let result = TogetherAICompletionRequest::try_from(("meta-llama/test-model", request)); assert!(matches!(result, Err(CompletionError::RequestError(_)))); } + + #[test] + fn together_request_uses_request_model_override() { + let request = CompletionRequest { + model: Some("together-override".to_string()), + preamble: Some("system".to_string()), + chat_history: OneOrMany::one(message::Message::user("hello")), + documents: vec![], + tools: vec![], + temperature: None, + max_tokens: None, + tool_choice: None, + additional_params: None, + output_schema: None, + }; + + let converted = TogetherAICompletionRequest::try_from(("together-default", request)) + .expect("request should convert"); + + assert_eq!(converted.model, "together-override"); + } } diff --git a/rig/rig-core/src/providers/xai/completion.rs b/rig/rig-core/src/providers/xai/completion.rs index 1891379c1..80f5fb3b8 100644 --- a/rig/rig-core/src/providers/xai/completion.rs +++ b/rig/rig-core/src/providers/xai/completion.rs @@ -9,7 +9,6 @@ use tracing::{Instrument, Level, enabled, info_span}; use super::api::{ApiResponse, Message, ToolDefinition}; use super::client::Client; -use crate::OneOrMany; use crate::completion::{self, CompletionError, CompletionRequest}; use crate::http_client::HttpClientExt; use crate::providers::openai::completion::ToolChoice; @@ -132,6 +131,10 @@ impl TryFrom for completion::CompletionResponse Result { + let stop_reason = response + .status + .as_deref() + .map(|status| completion::StopReason::Other(status.to_string())); let content: Vec = response .output .iter() @@ -139,9 +142,7 @@ impl TryFrom for completion::CompletionResponse>::from) .collect(); - let choice = OneOrMany::many(content).map_err(|_| { - CompletionError::ResponseError("Response contained no 
output".to_owned()) - })?; + let choice = completion::AssistantChoice::from(content); let usage = response .usage @@ -164,6 +165,7 @@ impl TryFrom for completion::CompletionResponse, /// The final aggregated message from the stream /// contains all text and tool calls generated - pub choice: OneOrMany, + pub choice: AssistantChoice, /// The final response from the stream, may be `None` /// if the provider didn't yield it during the stream pub response: Option, pub final_response_yielded: AtomicBool, /// Provider-assigned message ID (e.g. OpenAI Responses API `msg_` ID). pub message_id: Option, + /// Provider-agnostic reason why the model stopped generating this turn. + pub stop_reason: Option, } impl StreamingCompletionResponse @@ -231,10 +236,11 @@ where assistant_items: vec![], text_item_index: None, reasoning_item_index: None, - choice: OneOrMany::one(AssistantContent::text("")), + choice: AssistantChoice::new(), response: None, final_response_yielded: AtomicBool::new(false), message_id: None, + stop_reason: None, } } @@ -304,6 +310,7 @@ where usage: Usage::new(), // Usage is not tracked in streaming responses raw_response: value.response, message_id: value.message_id, + stop_reason: value.stop_reason, } } } @@ -327,12 +334,7 @@ where Poll::Ready(None) => { // This is run at the end of the inner stream to collect all tokens into // a single unified `Message`. - if stream.assistant_items.is_empty() { - stream.assistant_items.push(AssistantContent::text("")); - } - - stream.choice = OneOrMany::many(std::mem::take(&mut stream.assistant_items)) - .expect("There should be at least one assistant message"); + stream.choice = AssistantChoice::from(std::mem::take(&mut stream.assistant_items)); Poll::Ready(None) } @@ -412,6 +414,10 @@ where stream.message_id = Some(id); stream.poll_next_unpin(cx) } + RawStreamingChoice::StopReason(reason) => { + stream.stop_reason = Some(reason); + stream.poll_next_unpin(cx) + } }, } } @@ -685,6 +691,23 @@ mod tests { StreamingCompletionResponse::stream(to_stream_result(stream)) } + fn create_empty_stream() -> StreamingCompletionResponse { + let stream = stream! { + yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 0 })); + }; + + StreamingCompletionResponse::stream(to_stream_result(stream)) + } + + fn create_stop_reason_stream() -> StreamingCompletionResponse { + let stream = stream! { + yield Ok(RawStreamingChoice::StopReason(crate::completion::StopReason::MaxTokens)); + yield Ok(RawStreamingChoice::FinalResponse(MockResponse { token_count: 0 })); + }; + + StreamingCompletionResponse::stream(to_stream_result(stream)) + } + fn create_interleaved_stream() -> StreamingCompletionResponse { let stream = stream! 
{ yield Ok(RawStreamingChoice::Reasoning { @@ -837,6 +860,25 @@ mod tests { )); } + #[tokio::test] + async fn test_empty_stream_keeps_empty_choice() { + let mut stream = create_empty_stream(); + while stream.next().await.is_some() {} + + assert!(stream.choice.is_empty()); + } + + #[tokio::test] + async fn test_stream_captures_stop_reason() { + let mut stream = create_stop_reason_stream(); + while stream.next().await.is_some() {} + + assert_eq!( + stream.stop_reason, + Some(crate::completion::StopReason::MaxTokens) + ); + } + #[tokio::test] async fn test_stream_aggregates_assistant_items_in_arrival_order() { let mut stream = create_interleaved_stream(); diff --git a/rig/rig-core/tests/chatgpt/completion.rs b/rig/rig-core/tests/chatgpt/completion.rs index f897dfaf2..b1cea1ae6 100644 --- a/rig/rig-core/tests/chatgpt/completion.rs +++ b/rig/rig-core/tests/chatgpt/completion.rs @@ -2,7 +2,7 @@ use futures::StreamExt; use rig::client::CompletionClient; -use rig::completion::CompletionModel; +use rig::completion::{AssistantChoice, CompletionModel}; use rig::message::AssistantContent; use rig::message::Message; use rig::streaming::{StreamedAssistantContent, StreamingPrompt}; @@ -12,7 +12,7 @@ use crate::support::{ assert_contains_any_case_insensitive, assert_nonempty_response, collect_stream_final_response, }; -fn aggregated_text(choice: &rig::OneOrMany) -> String { +fn aggregated_text(choice: &AssistantChoice) -> String { choice .iter() .filter_map(|content| match content { diff --git a/rig/rig-core/tests/common/reasoning.rs b/rig/rig-core/tests/common/reasoning.rs index 8674536d9..181c8e7c0 100644 --- a/rig/rig-core/tests/common/reasoning.rs +++ b/rig/rig-core/tests/common/reasoning.rs @@ -206,7 +206,10 @@ where let turn1_assistant = Message::Assistant { id: response.message_id, - content: response.choice, + content: response + .choice + .into_one_or_many() + .expect("reasoning roundtrip response should not be empty"), }; let turn2_prompt = Message::User { diff --git a/rig/rig-core/tests/core/prompt_response_messages.rs b/rig/rig-core/tests/core/prompt_response_messages.rs index cdf94a0df..e69ca09d3 100644 --- a/rig/rig-core/tests/core/prompt_response_messages.rs +++ b/rig/rig-core/tests/core/prompt_response_messages.rs @@ -1,10 +1,10 @@ //! Integration tests for `PromptResponse.messages` using mock models. //! Exercises the real agent loop code path with mocked LLM responses. 
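// A minimal sketch of how a `StopReason` event flows from the raw stream
// into the aggregated response, mirroring the `poll_next` arm and the
// streaming tests above. `Event`, `Stop`, `Aggregated`, and `drain` are
// hypothetical stand-ins, not rig's streaming types.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
enum Stop {
    Stop,
    MaxTokens,
    ToolCalls,
    ContentFilter,
    Other(String),
}

#[allow(dead_code)]
enum Event {
    Text(String),
    StopReason(Stop),
    Final, // stands in for RawStreamingChoice::FinalResponse
}

#[derive(Default)]
struct Aggregated {
    text: String,
    stop_reason: Option<Stop>,
}

// A StopReason event is recorded on the aggregate and polling continues,
// rather than the reason being yielded downstream as a content item.
fn drain(events: Vec<Event>) -> Aggregated {
    let mut out = Aggregated::default();
    for event in events {
        match event {
            Event::Text(t) => out.text.push_str(&t),
            Event::StopReason(reason) => out.stop_reason = Some(reason),
            Event::Final => break,
        }
    }
    out
}

fn main() {
    let agg = drain(vec![
        Event::Text("Hello".into()),
        Event::StopReason(Stop::MaxTokens),
        Event::Final,
    ]);
    assert_eq!(agg.text, "Hello");
    assert_eq!(agg.stop_reason, Some(Stop::MaxTokens));
}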
-use rig::OneOrMany; use rig::agent::AgentBuilder; use rig::completion::{ - CompletionError, CompletionModel, CompletionRequest, CompletionResponse, Message, Prompt, Usage, + AssistantChoice, CompletionError, CompletionModel, CompletionRequest, CompletionResponse, + Message, Prompt, Usage, }; use rig::message::{AssistantContent, Text, ToolCall, ToolFunction, UserContent}; use rig::streaming::{StreamingCompletionResponse, StreamingResult}; @@ -34,7 +34,7 @@ impl CompletionModel for SimpleTextModel { _request: CompletionRequest, ) -> Result, CompletionError> { Ok(CompletionResponse { - choice: OneOrMany::one(AssistantContent::Text(Text { + choice: AssistantChoice::one(AssistantContent::Text(Text { text: "hello from mock".to_string(), })), usage: Usage { @@ -46,6 +46,7 @@ impl CompletionModel for SimpleTextModel { }, raw_response: (), message_id: Some("msg_mock_1".to_string()), + stop_reason: None, }) } @@ -92,7 +93,7 @@ impl CompletionModel for ToolThenTextModel { if turn == 0 { // First turn: return a tool call Ok(CompletionResponse { - choice: OneOrMany::one(AssistantContent::ToolCall(ToolCall::new( + choice: AssistantChoice::one(AssistantContent::ToolCall(ToolCall::new( "tc_1".to_string(), ToolFunction::new( "calculator".to_string(), @@ -108,11 +109,12 @@ impl CompletionModel for ToolThenTextModel { }, raw_response: (), message_id: Some("msg_tool".to_string()), + stop_reason: None, }) } else { // Second turn: return a text response Ok(CompletionResponse { - choice: OneOrMany::one(AssistantContent::Text(Text { + choice: AssistantChoice::one(AssistantContent::Text(Text { text: "The answer is 5".to_string(), })), usage: Usage { @@ -124,6 +126,89 @@ impl CompletionModel for ToolThenTextModel { }, raw_response: (), message_id: Some("msg_text".to_string()), + stop_reason: None, + }) + } + } + + async fn stream( + &self, + _request: CompletionRequest, + ) -> Result, CompletionError> { + let stream: StreamingResult<()> = Box::pin(futures::stream::empty()); + Ok(StreamingCompletionResponse::stream(stream)) + } +} + +/// A mock model that emits text and a tool call, then ends with an empty +/// assistant turn. This exercises fallback to prior visible assistant text. 
+#[derive(Clone)]
+struct ToolThenEmptyTerminalTurnModel {
+    turn: Arc<AtomicUsize>,
+}
+
+impl ToolThenEmptyTerminalTurnModel {
+    fn new() -> Self {
+        Self {
+            turn: Arc::new(AtomicUsize::new(0)),
+        }
+    }
+}
+
+#[allow(refining_impl_trait)]
+impl CompletionModel for ToolThenEmptyTerminalTurnModel {
+    type Response = ();
+    type StreamingResponse = ();
+    type Client = ();
+
+    fn make(_: &Self::Client, _: impl Into<String>) -> Self {
+        Self::new()
+    }
+
+    async fn completion(
+        &self,
+        _request: CompletionRequest,
+    ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
+        let turn = self.turn.fetch_add(1, Ordering::SeqCst);
+
+        if turn == 0 {
+            Ok(CompletionResponse {
+                choice: AssistantChoice::many(vec![
+                    AssistantContent::Text(Text {
+                        text: "The answer is 5".to_string(),
+                    }),
+                    AssistantContent::ToolCall(ToolCall::new(
+                        "tc_2".to_string(),
+                        ToolFunction::new(
+                            "calculator".to_string(),
+                            serde_json::json!({"op": "add", "a": 2, "b": 3}),
+                        ),
+                    )),
+                ]),
+                usage: Usage {
+                    input_tokens: 12,
+                    output_tokens: 7,
+                    total_tokens: 19,
+                    cached_input_tokens: 0,
+                    cache_creation_input_tokens: 0,
+                },
+                raw_response: (),
+                message_id: Some("msg_text_tool".to_string()),
+                stop_reason: None,
+            })
+        } else {
+            Ok(CompletionResponse {
+                choice: AssistantChoice::new(),
+                usage: Usage {
+                    input_tokens: 8,
+                    output_tokens: 1,
+                    total_tokens: 9,
+                    cached_input_tokens: 0,
+                    cache_creation_input_tokens: 0,
+                },
+                raw_response: (),
+                message_id: Some("msg_empty".to_string()),
+                stop_reason: None,
+            })
+        }
+    }
@@ -157,13 +242,14 @@ impl CompletionModel for AlwaysToolCallModel {
         _request: CompletionRequest,
     ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
         Ok(CompletionResponse {
-            choice: OneOrMany::one(AssistantContent::ToolCall(ToolCall::new(
+            choice: AssistantChoice::one(AssistantContent::ToolCall(ToolCall::new(
                 "tc_loop".to_string(),
                 ToolFunction::new("infinite_tool".to_string(), serde_json::json!({"x": 1})),
             ))),
             usage: Usage::new(),
             raw_response: (),
             message_id: None,
+            stop_reason: None,
         })
     }
 
@@ -348,6 +434,21 @@ async fn multi_turn_messages_include_tool_calls() {
     assert_eq!(resp.usage.output_tokens, 12); // 8 + 4
 }
 
+/// If the terminal assistant turn is empty after a prior visible text
+/// turn, the agent should fall back to the earlier assistant text.
+#[tokio::test]
+async fn multi_turn_prompt_falls_back_to_prior_assistant_text_when_terminal_turn_is_empty() {
+    let agent = AgentBuilder::new(ToolThenEmptyTerminalTurnModel::new()).build();
+
+    let result = agent
+        .prompt("What is 2 + 3?")
+        .max_turns(5)
+        .await
+        .expect("prompt should succeed");
+
+    assert_eq!(result, "The answer is 5");
+}
+
 /// Test 6: `PromptResponse::new()` backward compatibility — 2-argument constructor
 /// should still work, and `messages` should be `None`.
 #[tokio::test]
diff --git a/rig/rig-core/tests/gemini/interactions_api.rs b/rig/rig-core/tests/gemini/interactions_api.rs
index 6d45e4ae3..a67579f12 100644
--- a/rig/rig-core/tests/gemini/interactions_api.rs
+++ b/rig/rig-core/tests/gemini/interactions_api.rs
@@ -1,9 +1,8 @@
 //! Migrated from `examples/gemini_interactions_api.rs`.
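// A minimal sketch of the `AssistantChoice` surface the test helpers around
// here rely on (`one`, `many`, `new`, `From<Vec<_>>`, `is_empty`, `iter`,
// `to_one_or_many`), assuming a plain Vec-backed stand-in: `Choice` and
// `EmptyChoice` are hypothetical simplifications; rig's real type converts
// into `OneOrMany` rather than a `Vec`.
#[derive(Debug, Clone)]
struct Choice<T>(Vec<T>);

#[derive(Debug)]
struct EmptyChoice;

impl<T: Clone> Choice<T> {
    fn new() -> Self {
        Self(Vec::new())
    }
    fn one(item: T) -> Self {
        Self(vec![item])
    }
    fn many(items: Vec<T>) -> Self {
        Self(items)
    }
    fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
    fn iter(&self) -> std::slice::Iter<'_, T> {
        self.0.iter()
    }
    // Unlike OneOrMany::many, emptiness only becomes an error at the edge
    // that needs a non-empty guarantee, mirroring the
    // to_one_or_many()/into_one_or_many() guards in the migrations above.
    fn to_one_or_many(&self) -> Result<Vec<T>, EmptyChoice> {
        if self.0.is_empty() {
            Err(EmptyChoice)
        } else {
            Ok(self.0.clone())
        }
    }
}

impl<T: Clone> From<Vec<T>> for Choice<T> {
    fn from(items: Vec<T>) -> Self {
        Self(items)
    }
}

fn main() {
    let empty: Choice<&str> = Choice::new();
    assert!(empty.is_empty());
    // Empty turns never enter chat history.
    assert!(empty.to_one_or_many().is_err());

    let turn = Choice::many(vec!["The answer is 5", "tool call"]);
    assert_eq!(turn.iter().count(), 2);
    assert!(Choice::from(vec!["x"]).to_one_or_many().is_ok());
    assert_eq!(Choice::one("only").iter().count(), 1);
}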
use futures::StreamExt; -use rig::OneOrMany; use rig::client::{CompletionClient, ProviderClient}; -use rig::completion::{CompletionModel, GetTokenUsage}; +use rig::completion::{AssistantChoice, CompletionModel, GetTokenUsage}; use rig::message::{AssistantContent, Message, ToolCall, ToolChoice}; use rig::providers::gemini; use rig::providers::gemini::interactions_api::{AdditionalParameters, Tool}; @@ -11,7 +10,7 @@ use rig::streaming::StreamedAssistantContent; use crate::support::assert_nonempty_response; -fn extract_text(choice: &OneOrMany) -> String { +fn extract_text(choice: &AssistantChoice) -> String { choice .iter() .filter_map(|content| match content { @@ -22,7 +21,7 @@ fn extract_text(choice: &OneOrMany) -> String { .join("") } -fn first_tool_call(choice: &OneOrMany) -> Option { +fn first_tool_call(choice: &AssistantChoice) -> Option { choice.iter().find_map(|content| match content { AssistantContent::ToolCall(tool_call) => Some(tool_call.clone()), _ => None, diff --git a/rig/rig-core/tests/moonshot/reasoning_history.rs b/rig/rig-core/tests/moonshot/reasoning_history.rs index 81cae1baf..6c651b1e5 100644 --- a/rig/rig-core/tests/moonshot/reasoning_history.rs +++ b/rig/rig-core/tests/moonshot/reasoning_history.rs @@ -2,13 +2,13 @@ use rig::OneOrMany; use rig::client::{CompletionClient, ProviderClient}; -use rig::completion::CompletionModel; +use rig::completion::{AssistantChoice, CompletionModel}; use rig::message::{AssistantContent, Message, Reasoning}; use rig::providers::moonshot; use crate::support::{assert_contains_any_case_insensitive, assert_nonempty_response}; -fn response_text(choice: &rig::OneOrMany) -> String { +fn response_text(choice: &AssistantChoice) -> String { choice .iter() .filter_map(|content| match content { diff --git a/rig/rig-core/tests/openai/websocket.rs b/rig/rig-core/tests/openai/websocket.rs index 70f3b0654..ccef4ed5d 100644 --- a/rig/rig-core/tests/openai/websocket.rs +++ b/rig/rig-core/tests/openai/websocket.rs @@ -2,7 +2,7 @@ use anyhow::Result; use rig::client::{CompletionClient, ProviderClient}; -use rig::completion::CompletionModel; +use rig::completion::{AssistantChoice, CompletionModel}; use rig::message::AssistantContent; use rig::providers::openai; use rig::providers::openai::responses_api::streaming::{ItemChunkKind, ResponseChunkKind}; @@ -10,7 +10,7 @@ use rig::providers::openai::responses_api::websocket::ResponsesWebSocketEvent; use crate::support::assert_nonempty_response; -fn extract_text(choice: &rig::OneOrMany) -> String { +fn extract_text(choice: &AssistantChoice) -> String { choice .iter() .filter_map(|content| match content {
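// A minimal sketch of the shared request-conformance pattern used by the
// OpenAI, OpenRouter, Together, and Perplexity test modules above, assuming
// simplified stand-ins: `Fixture`, `Outcome`, `Harness`, `EchoHarness`, and
// `run_all` are hypothetical names, and the real
// `provider_request_conformance_tests!` macro presumably expands fixtures
// into individual #[test] functions rather than a runtime loop.
struct Fixture {
    name: &'static str,
    input: &'static str,
}

#[derive(Debug)]
enum Outcome<T> {
    Serialized(T),
    Rejected(String),
}

// Each provider implements this once; shared fixtures then exercise every
// provider's request conversion in exactly the same way.
trait Harness {
    fn family_name() -> &'static str;
    fn run(case: &Fixture) -> Outcome<String>;
    fn assert(case: &Fixture, actual: Outcome<String>) {
        match actual {
            Outcome::Serialized(body) => {
                assert!(
                    !body.is_empty(),
                    "{}/{} produced an empty request",
                    Self::family_name(),
                    case.name
                )
            }
            Outcome::Rejected(err) => {
                panic!("{}/{} was rejected: {}", Self::family_name(), case.name, err)
            }
        }
    }
}

struct EchoHarness;

impl Harness for EchoHarness {
    fn family_name() -> &'static str {
        "echo"
    }

    fn run(case: &Fixture) -> Outcome<String> {
        Outcome::Serialized(case.input.to_owned())
    }
}

fn run_all<H: Harness>(fixtures: &[Fixture]) {
    for case in fixtures {
        H::assert(case, H::run(case));
    }
}

fn main() {
    run_all::<EchoHarness>(&[Fixture {
        name: "user_text",
        input: "hello",
    }]);
}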