Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,17 @@ exclude = [

[workspace.lints.clippy]
dbg_macro = "forbid"
await_holding_lock = "deny"
await_holding_refcell_ref = "deny"
expect_used = "deny"
expect_fun_call = "deny"
indexing_slicing = "deny"
panic = "deny"
panic_in_result_fn = "deny"
todo = "forbid"
unimplemented = "forbid"
unreachable = "deny"
unwrap_used = "deny"

[profile.release]
lto = true
Expand Down
20 changes: 9 additions & 11 deletions rig-integrations/rig-bedrock/examples/agent_with_bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,18 @@ async fn main() -> Result<(), anyhow::Error> {
Ok(())
}

fn client() -> Client {
Client::from_env()
fn client() -> Result<Client, anyhow::Error> {
Ok(Client::from_env()?)
}

async fn partial_agent() -> AgentBuilder<rig_bedrock::completion::CompletionModel> {
let client = client();
client.agent(AMAZON_NOVA_LITE)
fn partial_agent() -> Result<AgentBuilder<rig_bedrock::completion::CompletionModel>, anyhow::Error>
{
Ok(client()?.agent(AMAZON_NOVA_LITE))
}

/// Create an AWS Bedrock agent with a system prompt
async fn basic() -> Result<(), anyhow::Error> {
let agent = partial_agent()
.await
let agent = partial_agent()?
.preamble("Answer with json format only")
.build();

Expand All @@ -53,8 +52,7 @@ async fn basic() -> Result<(), anyhow::Error> {

/// Create an AWS Bedrock with tools
async fn tools() -> Result<(), anyhow::Error> {
let calculator_agent = partial_agent()
.await
let calculator_agent = partial_agent()?
.preamble("You must only do math by using a tool.")
.max_tokens(1024)
.tool(common::Adder)
Expand All @@ -69,7 +67,7 @@ async fn tools() -> Result<(), anyhow::Error> {
}

async fn context() -> Result<(), anyhow::Error> {
let model = client().completion_model(AMAZON_NOVA_LITE);
let model = client()?.completion_model(AMAZON_NOVA_LITE);

// Create an agent with multiple context documents
let agent = AgentBuilder::new(model)
Expand All @@ -92,7 +90,7 @@ async fn context() -> Result<(), anyhow::Error> {
/// This example loads in all the rust examples from the rig-core crate and uses them as\\
/// context for the agent
async fn loaders() -> Result<(), anyhow::Error> {
let model = client().completion_model(AMAZON_NOVA_LITE);
let model = client()?.completion_model(AMAZON_NOVA_LITE);

// Load in all the rust examples
let examples = FileLoader::with_glob("rig-core/examples/*.rs")?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async fn main() -> Result<(), anyhow::Error> {
.with_target(false)
.init();

let client = Client::from_env();
let client = Client::from_env()?;
let agent = client
.agent(AMAZON_NOVA_LITE)
.preamble("Describe this document")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async fn main() -> Result<(), anyhow::Error> {
.with_target(false)
.init();

let client = Client::from_env();
let client = Client::from_env()?;
let embeddings = client
.embeddings_with_ndims(AMAZON_TITAN_EMBED_TEXT_V2_0, 256)
.document(Greetings {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,12 @@ async fn main() -> Result<(), anyhow::Error> {
.with_target(false)
.init();

let client = Client::from_env();
let client = Client::from_env()?;
let data_extractor = client.extractor::<Person>(AMAZON_NOVA_LITE).build();
let person = data_extractor
.extract("Hello my name is John Doe! I am a software engineer.")
.await?;

info!(
"AWS Bedrock: {}",
serde_json::to_string_pretty(&person).unwrap()
);
info!("AWS Bedrock: {}", serde_json::to_string_pretty(&person)?);
Ok(())
}
12 changes: 7 additions & 5 deletions rig-integrations/rig-bedrock/examples/image_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,20 @@ use std::path::Path;
const DEFAULT_PATH: &str = "./output.png";

#[tokio::main]
async fn main() {
let client = Client::from_env();
async fn main() -> Result<(), anyhow::Error> {
let client = Client::from_env()?;
let image_generation_model = client.image_generation_model(AMAZON_NOVA_CANVAS);
let response = image_generation_model
.image_generation_request()
.prompt("A castle sitting upon a large mountain, overlooking the water.")
.width(512)
.height(512)
.send()
.await;
.await?;

// save image
let mut file = File::create_new(Path::new(&DEFAULT_PATH)).expect("Failed to create file");
let _ = file.write(&response.unwrap().image);
let mut file = File::create_new(Path::new(DEFAULT_PATH))?;
file.write_all(&response.image)?;

Ok(())
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async fn main() -> Result<(), anyhow::Error> {
.with_target(false)
.init();

let client = rig_bedrock::client::Client::from_env();
let client = rig_bedrock::client::Client::from_env()?;
let agent = client
.agent(AMAZON_NOVA_LITE)
.preamble("You are an image describer.")
Expand Down
2 changes: 1 addition & 1 deletion rig-integrations/rig-bedrock/examples/rag_with_bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async fn main() -> Result<(), anyhow::Error> {
.with_target(false)
.init();

let client = Client::from_env();
let client = Client::from_env()?;
let embedding_model = client.embedding_model_with_ndims(AMAZON_TITAN_EMBED_TEXT_V2_0, 256);

// Generate embeddings for the definitions of all the documents using the specified embedding model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use rig_bedrock::{client::Client, completion::AMAZON_NOVA_LITE};
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Create streaming agent with a single context prompt
let agent = Client::from_env()
let agent = Client::from_env()?
.agent(AMAZON_NOVA_LITE)
.preamble("Be precise and concise.")
.temperature(0.5)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod common;
async fn main() -> Result<(), anyhow::Error> {
tracing_subscriber::fmt().init();
// Create agent with a single context prompt and two tools
let agent = Client::from_env()
let agent = Client::from_env()?
.agent(AMAZON_NOVA_LITE)
.preamble(
"You are a calculator here to help the user perform arithmetic
Expand Down
13 changes: 7 additions & 6 deletions rig-integrations/rig-bedrock/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,21 +96,22 @@ impl Client {

impl ProviderClient for Client {
type Input = Nothing;
type Error = rig::client::ProviderClientError;

fn from_env() -> Self
fn from_env() -> Result<Self, Self::Error>
where
Self: Sized,
{
Client::new()
Ok(Client::new())
}

fn from_val(_: Nothing) -> Self
fn from_val(_: Nothing) -> Result<Self, Self::Error>
where
Self: Sized,
{
panic!(
"Please use `Client::from_env` or `Client::with_profile_name(\"aws_profile\")` instead"
);
Err(rig::client::ProviderClientError::InvalidConfiguration(
"use `Client::from_env()` or `Client::with_profile_name(\"aws_profile\")` instead",
))
}
}

Expand Down
2 changes: 1 addition & 1 deletion rig-integrations/rig-bedrock/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ impl completion::CompletionModel for CompletionModel {
.set_additional_model_request_fields(request.additional_params())
.set_inference_config(request.inference_config())
.set_tool_config(tool_config)
.set_system(request.system_prompt())
.set_system(request.system_prompt()?)
.set_messages(Some(messages));

async move {
Expand Down
11 changes: 11 additions & 0 deletions rig-integrations/rig-bedrock/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
#![cfg_attr(
test,
allow(
clippy::expect_used,
clippy::indexing_slicing,
clippy::panic,
clippy::unwrap_used,
clippy::unreachable
)
)]

pub mod client;
pub mod completion;
pub mod embedding;
Expand Down
2 changes: 1 addition & 1 deletion rig-integrations/rig-bedrock/src/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ impl CompletionModel {
.set_additional_model_request_fields(request.additional_params())
.set_inference_config(request.inference_config())
.set_tool_config(tool_config)
.set_system(request.system_prompt())
.set_system(request.system_prompt()?)
.set_messages(Some(prompt_with_history));

let response = converse_builder.send().await.map_err(|sdk_error| {
Expand Down
80 changes: 44 additions & 36 deletions rig-integrations/rig-bedrock/src/types/completion_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ pub struct AwsCompletionRequest {
pub prompt_caching: bool,
}

fn cache_point_block() -> CachePointBlock {
fn cache_point_block() -> Result<CachePointBlock, CompletionError> {
CachePointBlock::builder()
.r#type(CachePointType::Default)
.build()
.expect("CachePointBlock type is set")
.map_err(|e| CompletionError::RequestError(e.into()))
}

impl AwsCompletionRequest {
Expand Down Expand Up @@ -65,31 +65,33 @@ impl AwsCompletionRequest {
if !tools.is_empty() {
// Convert rig's ToolChoice to AWS Bedrock ToolChoice
use aws_sdk_bedrockruntime::types as aws_bedrock;
let tool_choice = self.inner.tool_choice.as_ref().and_then(|choice| {
match choice {
rig::message::ToolChoice::Auto => Some(aws_bedrock::ToolChoice::Auto(
let tool_choice = self
.inner
.tool_choice
.as_ref()
.map(|choice| match choice {
rig::message::ToolChoice::Auto => Ok(Some(aws_bedrock::ToolChoice::Auto(
aws_bedrock::AutoToolChoice::builder().build(),
)),
rig::message::ToolChoice::Required => Some(aws_bedrock::ToolChoice::Any(
))),
rig::message::ToolChoice::Required => Ok(Some(aws_bedrock::ToolChoice::Any(
aws_bedrock::AnyToolChoice::builder().build(),
)),
rig::message::ToolChoice::None => {
// Bedrock doesn't have a "None" option - just omit tool_choice
None
}
rig::message::ToolChoice::Specific { function_names } => {
// Use the first function name for Bedrock's specific tool choice
function_names.first().map(|name| {
aws_bedrock::ToolChoice::Tool(
aws_bedrock::SpecificToolChoice::builder()
.name(name.clone())
.build()
.expect("Failed to build SpecificToolChoice"),
)
))),
rig::message::ToolChoice::None => Ok(None),
rig::message::ToolChoice::Specific { function_names } => function_names
.first()
.map(|name| {
aws_bedrock::SpecificToolChoice::builder()
.name(name.clone())
.build()
.map(aws_bedrock::ToolChoice::Tool)
.map(Some)
.map_err(|e| CompletionError::RequestError(e.into()))
})
}
}
});
.transpose()
.map(Option::flatten),
})
.transpose()?
.flatten();

let config = ToolConfiguration::builder()
.set_tools(Some(tools))
Expand All @@ -103,7 +105,7 @@ impl AwsCompletionRequest {
}
}

pub fn system_prompt(&self) -> Option<Vec<SystemContentBlock>> {
pub fn system_prompt(&self) -> Result<Option<Vec<SystemContentBlock>>, CompletionError> {
let mut system_blocks = Vec::new();

if let Some(system_prompt) = self.inner.preamble.to_owned()
Expand All @@ -121,12 +123,12 @@ impl AwsCompletionRequest {
}

if system_blocks.is_empty() {
None
Ok(None)
} else {
if self.prompt_caching {
system_blocks.push(SystemContentBlock::CachePoint(cache_point_block()));
system_blocks.push(SystemContentBlock::CachePoint(cache_point_block()?));
}
Some(system_blocks)
Ok(Some(system_blocks))
}
}

Expand Down Expand Up @@ -165,7 +167,7 @@ impl AwsCompletionRequest {
&& let Some(last_msg) = messages.last_mut()
{
let mut content = last_msg.content.clone();
content.push(aws_bedrock::ContentBlock::CachePoint(cache_point_block()));
content.push(aws_bedrock::ContentBlock::CachePoint(cache_point_block()?));
*last_msg = aws_bedrock::Message::builder()
.role(last_msg.role.clone())
.set_content(Some(content))
Expand Down Expand Up @@ -465,14 +467,17 @@ mod tests {
};

let aws_request = aws_request(request, false);
let system_prompt = aws_request.system_prompt();
let system_prompt = aws_request
.system_prompt()
.expect("system prompt should build")
.expect("system prompt should exist");

assert!(system_prompt.is_some());
let system_prompt = system_prompt.unwrap();
assert_eq!(system_prompt.len(), 1);
assert_eq!(
system_prompt[0],
aws_bedrock::SystemContentBlock::Text("History system instruction".to_string())
system_prompt.first(),
Some(&aws_bedrock::SystemContentBlock::Text(
"History system instruction".to_string()
))
);
}

Expand All @@ -486,12 +491,15 @@ mod tests {
let aws_request = aws_request(request, true);
let system_prompt = aws_request
.system_prompt()
.expect("system prompt should build")
.expect("system prompt should exist");

assert_eq!(system_prompt.len(), 2);
assert_eq!(
system_prompt[0],
aws_bedrock::SystemContentBlock::Text("System prompt".to_string())
system_prompt.first(),
Some(&aws_bedrock::SystemContentBlock::Text(
"System prompt".to_string()
))
);
assert!(matches!(
system_prompt.last(),
Expand Down
Loading
Loading