15 changes: 9 additions & 6 deletions rig-integrations/rig-bedrock/src/types/assistant_content.rs
@@ -1,7 +1,6 @@
use aws_sdk_bedrockruntime::types as aws_bedrock;

use rig::{
OneOrMany,
completion::CompletionError,
message::{AssistantContent, Text, ToolCall, ToolFunction},
};
@@ -114,21 +113,25 @@ impl TryFrom<AwsConverseOutput> for completion::CompletionResponse<AwsConverseOu
_ => None,
}) {
return Ok(completion::CompletionResponse {
choice: OneOrMany::one(AssistantContent::ToolCall(ToolCall::new(
tool_use.id,
ToolFunction::new(tool_use.function.name, tool_use.function.arguments),
))),
choice: completion::AssistantChoice::one(AssistantContent::ToolCall(
ToolCall::new(
tool_use.id,
ToolFunction::new(tool_use.function.name, tool_use.function.arguments),
),
)),
usage,
raw_response: value,
message_id: None,
stop_reason: None,
});
}

Ok(completion::CompletionResponse {
choice,
choice: choice.into(),
usage,
raw_response: value,
message_id: None,
stop_reason: None,
})
}
}
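Every file in this PR swaps `OneOrMany<AssistantContent>` for a new `AssistantChoice` type whose definition is not part of the diff. As a review aid, here is a minimal sketch of the surface its call sites imply (a single-item constructor, `From<Vec<_>>`, an `Option`-returning `first`, `iter`, and a fallible `to_one_or_many`). The `AssistantContent` stand-in, the error type, and the `Vec` return value (standing in for rig's `OneOrMany`) are assumptions, not rig-core's actual code.

```rust
// Sketch only: the AssistantChoice surface implied by this PR's call sites.

#[derive(Clone, Debug)]
pub enum AssistantContent {
    Text(String),     // stand-in for rig::message::Text
    ToolCall(String), // stand-in for rig::message::ToolCall
}

#[derive(Clone, Debug, Default)]
pub struct AssistantChoice(Vec<AssistantContent>);

#[derive(Debug)]
pub struct EmptyChoiceError; // placeholder error type

impl AssistantChoice {
    /// Mirrors the old `OneOrMany::one` used in the bedrock path above.
    pub fn one(content: AssistantContent) -> Self {
        Self(vec![content])
    }

    /// Returns an Option because a choice may now be empty; this is why
    /// the vertexai test below matches on `Some(..)`.
    pub fn first(&self) -> Option<&AssistantContent> {
        self.0.first()
    }

    pub fn iter(&self) -> impl Iterator<Item = &AssistantContent> {
        self.0.iter()
    }

    /// Fallible conversion back to a non-empty collection, used as
    /// `if let Ok(content) = choice.to_one_or_many()` before pushing a
    /// `Message::Assistant` into chat history.
    pub fn to_one_or_many(&self) -> Result<Vec<AssistantContent>, EmptyChoiceError> {
        if self.0.is_empty() {
            Err(EmptyChoiceError)
        } else {
            Ok(self.0.clone())
        }
    }
}

/// Unlike `OneOrMany::many`, an empty Vec is representable, which is what
/// lets the gemini-grpc and vertexai conversions below drop their
/// "empty response" error branches.
impl From<Vec<AssistantContent>> for AssistantChoice {
    fn from(contents: Vec<AssistantContent>) -> Self {
        Self(contents)
    }
}
```

The `choice: choice.into()` at the end of the bedrock conversion also suggests a `From<OneOrMany<AssistantContent>>` impl, omitted from this sketch.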
7 changes: 2 additions & 5 deletions rig-integrations/rig-gemini-grpc/src/completion.rs
@@ -450,11 +450,7 @@ impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<Generat
assistant_contents.push(assistant_content);
}

let choice = OneOrMany::many(assistant_contents).map_err(|_| {
CompletionError::ResponseError(
"Response contained no message or tool call (empty)".to_owned(),
)
})?;
let choice = completion::AssistantChoice::from(assistant_contents);

let usage = response
.usage_metadata
@@ -473,6 +469,7 @@ impl TryFrom<GenerateContentResponse> for completion::CompletionResponse<Generat
usage,
raw_response: response,
message_id: None,
stop_reason: None,
})
}
}
16 changes: 4 additions & 12 deletions rig-integrations/rig-vertexai/src/types/completion_response.rs
@@ -1,6 +1,5 @@
use google_cloud_aiplatform_v1 as vertexai;
use rig::OneOrMany;
use rig::completion::{CompletionError, CompletionResponse, Usage};
use rig::completion::{AssistantChoice, CompletionError, CompletionResponse, Usage};
use rig::message::{AssistantContent, Text, ToolCall, ToolFunction};
use serde::{Deserialize, Serialize};

@@ -43,15 +42,7 @@ impl TryFrom<VertexGenerateContentOutput> for CompletionResponse<VertexGenerateC
}
}

if assistant_contents.is_empty() {
return Err(CompletionError::ProviderError(
"No text or tool call content found in response".to_string(),
));
}

let choice = OneOrMany::many(assistant_contents).map_err(|e| {
CompletionError::ProviderError(format!("Failed to create OneOrMany: {e}"))
})?;
let choice = AssistantChoice::from(assistant_contents);

let usage = response
.usage_metadata
@@ -70,6 +61,7 @@ impl TryFrom<VertexGenerateContentOutput> for CompletionResponse<VertexGenerateC
usage,
raw_response: value,
message_id: None,
stop_reason: None,
})
}
}
@@ -145,7 +137,7 @@ mod tests {
let response = completion_response.unwrap();

match response.choice.first() {
AssistantContent::ToolCall(ToolCall { id, function, .. }) => {
Some(AssistantContent::ToolCall(ToolCall { id, function, .. })) => {
assert_eq!(id, "add");
assert_eq!(function.name, "add");
assert_eq!(function.arguments, args);
17 changes: 9 additions & 8 deletions rig/rig-core/examples/manual_tool_calls.rs
@@ -10,9 +10,8 @@
//! 5. repeats until the model returns a final text answer.

use anyhow::{Result, bail};
use rig::OneOrMany;
use rig::client::{CompletionClient, ProviderClient};
use rig::completion::{Completion, ToolDefinition};
use rig::completion::{AssistantChoice, Completion, ToolDefinition};
use rig::message::{AssistantContent, Message, ToolCall, ToolChoice};
use rig::providers::openai;
use rig::tool::{Tool, ToolSet};
@@ -87,7 +86,7 @@ impl Tool for Subtract {
}
}

fn collect_tool_calls(choice: &OneOrMany<AssistantContent>) -> Vec<ToolCall> {
fn collect_tool_calls(choice: &AssistantChoice) -> Vec<ToolCall> {
choice
.iter()
.filter_map(|content| match content {
@@ -97,7 +96,7 @@ fn collect_tool_calls(choice: &OneOrMany<AssistantContent>) -> Vec<ToolCall> {
.collect()
}

fn extract_text(choice: &OneOrMany<AssistantContent>) -> String {
fn extract_text(choice: &AssistantChoice) -> String {
choice
.iter()
.filter_map(|content| match content {
@@ -151,10 +150,12 @@ async fn main() -> Result<()> {
let tool_calls = collect_tool_calls(&response.choice);

history.push(current_prompt.clone());
history.push(Message::Assistant {
id: response.message_id.clone(),
content: response.choice.clone(),
});
if let Ok(content) = response.choice.to_one_or_many() {
history.push(Message::Assistant {
id: response.message_id.clone(),
content,
});
}

if tool_calls.is_empty() {
let final_text = extract_text(&response.choice);
35 changes: 17 additions & 18 deletions rig/rig-core/src/agent/prompt_request/mod.rs
@@ -1,5 +1,6 @@
pub mod hooks;
pub mod streaming;
mod turns;

use super::{
Agent,
@@ -26,6 +27,7 @@ use std::{
};
use tracing::info_span;
use tracing::{Instrument, span::Id};
use turns::{AssistantTextAccumulator, AssistantTurnSummary};

pub trait PromptType {}
pub struct Standard;
@@ -339,6 +341,7 @@

let mut current_max_turns = 0;
let mut usage = Usage::new();
let mut assistant_text_accumulator = AssistantTextAccumulator::default();
let current_span_id: AtomicU64 = AtomicU64::new(0);

// We need to do at least 2 loops for 1 roundtrip (user expects normal message)
@@ -447,34 +450,30 @@
));
}

let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
let turn_summary = AssistantTurnSummary::from_response(&resp);
let turn_text = turn_summary.visible_text("\n");
assistant_text_accumulator.observe(&turn_text);

let (tool_calls, _texts): (Vec<_>, Vec<_>) = resp
.choice
.iter()
.partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

new_messages.push(Message::Assistant {
id: resp.message_id.clone(),
content: resp.choice.clone(),
});
if let Ok(content) = resp.choice.to_one_or_many() {
new_messages.push(Message::Assistant {
id: resp.message_id.clone(),
content,
});
}

if tool_calls.is_empty() {
let merged_texts = texts
.into_iter()
.filter_map(|content| {
if let AssistantContent::Text(text) = content {
Some(text.text.clone())
} else {
None
}
})
.collect::<Vec<_>>()
.join("\n");
let final_output = assistant_text_accumulator.final_output(&turn_text);

if self.max_turns > 1 {
tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
}

agent_span.record("gen_ai.completion", &merged_texts);
agent_span.record("gen_ai.completion", &final_output);
agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);
agent_span.record(
@@ -486,7 +485,7 @@
usage.cache_creation_input_tokens,
);

return Ok(PromptResponse::new(merged_texts, usage).with_messages(new_messages));
return Ok(PromptResponse::new(final_output, usage).with_messages(new_messages));
}

let hook = self.hook.clone();
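The new private `turns` module is referenced above but its body is not included in this diff. Reconstructed from the call sites (`from_response(..).visible_text(sep)`, `observe`, `final_output`) and from the fallback behavior the new streaming test below asserts, a plausible sketch, reusing the same stand-in `AssistantContent` as the earlier sketch:

```rust
// Reconstructed sketch of the private `turns` module; names and details
// here are assumptions, since the PR does not show the module itself.

pub enum AssistantContent {
    Text(String),     // stand-in, as in the earlier sketch
    ToolCall(String), // stand-in
}

/// Per-turn view over a response's assistant content.
pub struct AssistantTurnSummary {
    texts: Vec<String>,
}

impl AssistantTurnSummary {
    // The real type is built via `from_response` / `from_stream_response`;
    // a plain slice keeps this sketch self-contained.
    pub fn from_parts(parts: &[AssistantContent]) -> Self {
        let texts = parts
            .iter()
            .filter_map(|part| match part {
                AssistantContent::Text(text) => Some(text.clone()),
                AssistantContent::ToolCall(_) => None, // no visible text
            })
            .collect();
        Self { texts }
    }

    /// Joins the turn's visible text: "\n" in the non-streaming path,
    /// "" in the streaming path above.
    pub fn visible_text(&self, sep: &str) -> String {
        self.texts.join(sep)
    }
}

/// Remembers the last non-empty turn text so a terminal turn that produced
/// no text (e.g. only a final-response marker) can still yield a final
/// answer, which is the behavior the new streaming test asserts.
#[derive(Default)]
pub struct AssistantTextAccumulator {
    last_non_empty: Option<String>,
}

impl AssistantTextAccumulator {
    /// Called once per turn with that turn's visible text.
    pub fn observe(&mut self, turn_text: &str) {
        if !turn_text.is_empty() {
            self.last_non_empty = Some(turn_text.to_owned());
        }
    }

    /// Prefers the terminal turn's own text, falling back to the most
    /// recent non-empty earlier turn.
    pub fn final_output(&self, terminal_turn_text: &str) -> String {
        if !terminal_turn_text.is_empty() {
            terminal_turn_text.to_owned()
        } else {
            self.last_non_empty.clone().unwrap_or_default()
        }
    }
}
```

Under this reading, `final_output` is what lets the new test's terminal turn return "The answer is 5" even though that turn emits no text of its own.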
107 changes: 95 additions & 12 deletions rig/rig-core/src/agent/prompt_request/streaming.rs
@@ -16,6 +16,7 @@ use tracing::info_span;
use tracing_futures::Instrument;

use super::ToolCallHookAction;
use super::turns::{AssistantTextAccumulator, AssistantTurnSummary};
use crate::{
agent::Agent,
completion::{CompletionError, CompletionModel, PromptError},
@@ -170,16 +171,6 @@ fn tool_result_to_user_message(
}
}

fn assistant_text_from_choice(choice: &OneOrMany<AssistantContent>) -> String {
choice
.iter()
.filter_map(|content| match content {
AssistantContent::Text(text) => Some(text.text.as_str()),
_ => None,
})
.collect()
}

#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
#[error("CompletionError: {0}")]
@@ -398,6 +389,7 @@
let output_schema = self.output_schema;

let mut aggregated_usage = crate::completion::Usage::new();
let mut assistant_text_accumulator = AssistantTextAccumulator::default();

// NOTE: We use .instrument(agent_span) instead of span.enter() to avoid
// span context leaking to other concurrent tasks. Using span.enter() inside
@@ -659,7 +651,9 @@
accumulated_reasoning.push(assembled);
}

let turn_text_response = assistant_text_from_choice(&stream.choice);
let turn_summary = AssistantTurnSummary::from_stream_response(&stream);
let turn_text_response = turn_summary.visible_text("");
assistant_text_accumulator.observe(&turn_text_response);
tracing::Span::current().record("gen_ai.completion", &turn_text_response);

// Add text, reasoning, and tool calls to chat history.
@@ -692,9 +686,18 @@
}

if !saw_tool_call_this_turn {
let final_response_text =
assistant_text_accumulator.final_output(&turn_text_response);

// Add user message and assistant response to history before finishing
if !turn_text_response.is_empty() {
new_messages.push(Message::assistant(&turn_text_response));
} else if !final_response_text.is_empty() {
tracing::info!(
agent_name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
message_id = ?stream.message_id,
"Streaming turn completed without assistant text; using prior turn text for final response"
);
} else {
tracing::warn!(
agent_name = agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
@@ -715,7 +718,7 @@
None
};
yield Ok(MultiTurnStreamItem::final_response_with_history(
&turn_text_response,
&final_response_text,
aggregated_usage,
final_messages,
));
@@ -1224,6 +1227,86 @@ mod tests {
assert_eq!(final_response_text.as_deref(), Some(""));
}

#[derive(Clone, Default)]
struct StreamingEmptyTerminalTurnFallbackModel {
turn_counter: Arc<AtomicUsize>,
}

#[allow(refining_impl_trait)]
impl CompletionModel for StreamingEmptyTerminalTurnFallbackModel {
type Response = ();
type StreamingResponse = MockStreamingResponse;
type Client = ();

fn make(_: &Self::Client, _: impl Into<String>) -> Self {
Self::default()
}

async fn completion(
&self,
_request: CompletionRequest,
) -> Result<CompletionResponse<Self::Response>, CompletionError> {
Err(CompletionError::ProviderError(
"completion is unused in this streaming test".to_string(),
))
}

async fn stream(
&self,
_request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
let turn = self.turn_counter.fetch_add(1, Ordering::SeqCst);
let stream = async_stream::stream! {
if turn == 0 {
yield Ok(RawStreamingChoice::Message("The answer is 5".to_string()));
yield Ok(RawStreamingChoice::ToolCall(
RawStreamingToolCall::new(
"tool_call_2".to_string(),
"calculator".to_string(),
serde_json::json!({"op": "add", "a": 2, "b": 3}),
),
));
yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(2)));
} else {
yield Ok(RawStreamingChoice::FinalResponse(MockStreamingResponse::new(1)));
}
};

let pinned_stream: crate::streaming::StreamingResult<Self::StreamingResponse> =
Box::pin(stream);
Ok(StreamingCompletionResponse::stream(pinned_stream))
}
}

#[tokio::test]
async fn final_response_falls_back_to_prior_assistant_text_when_terminal_turn_is_empty() {
let model = StreamingEmptyTerminalTurnFallbackModel::default();
let turn_counter = model.turn_counter.clone();
let agent = AgentBuilder::new(model).build();

let mut stream = agent.stream_prompt("What is 2 + 3?").multi_turn(3).await;
let mut streamed_text = String::new();
let mut final_response_text = None;

while let Some(item) = stream.next().await {
match item {
Ok(MultiTurnStreamItem::StreamAssistantItem(StreamedAssistantContent::Text(
text,
))) => streamed_text.push_str(&text.text),
Ok(MultiTurnStreamItem::FinalResponse(res)) => {
final_response_text = Some(res.response().to_owned());
break;
}
Ok(_) => {}
Err(err) => panic!("unexpected streaming error: {err:?}"),
}
}

assert_eq!(streamed_text, "The answer is 5");
assert_eq!(final_response_text.as_deref(), Some("The answer is 5"));
assert_eq!(turn_counter.load(Ordering::SeqCst), 2);
}

/// Background task that logs periodically to detect span leakage.
/// If span leakage occurs, these logs will be prefixed with `invoke_agent{...}`.
async fn background_logger(stop: Arc<AtomicBool>, leak_count: Arc<AtomicU32>) {